base
DiffPaSS base classes
Type aliases
BootstrapList = list  # List indexed by bootstrap iteration
GradientDescentList = list  # List indexed by gradient descent iteration
GroupByGroupList = list  # List indexed by group index
IndexPair = tuple[int, int]  # Pair of indices
IndexPairsInGroup = list[IndexPair]  # Pairs of indices in a group of sequences
IndexPairsInGroups = list[IndexPairsInGroup]  # Pairs of indices in groups of sequences
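The index-pair aliases nest from a single pair up to one list of pairs per group. The short sketch below builds a value of type IndexPairsInGroups to make that nesting concrete. The indices and the helper count_pairs are hypothetical and only illustrate the shape of the data, not what DiffPaSS stores in them.

```python
# Type aliases as listed above (plain Python, no DiffPaSS import needed)
IndexPair = tuple[int, int]  # Pair of indices
IndexPairsInGroup = list[IndexPair]  # Pairs of indices in a group of sequences
IndexPairsInGroups = list[IndexPairsInGroup]  # Pairs of indices in groups of sequences

# Hypothetical example: two groups of sequences, with some index pairs in each
example_pairs: IndexPairsInGroups = [
    [(0, 2), (1, 0)],  # group 0: two pairs
    [(0, 0)],          # group 1: one pair
]

def count_pairs(pairs: IndexPairsInGroups) -> int:
    """Hypothetical helper: total number of index pairs across all groups."""
    return sum(len(group) for group in pairs)

print(count_pairs(example_pairs))  # 3
```

The aliases are ordinary Python generic aliases, so they serve as documentation for type checkers and readers and do not enforce anything at runtime.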
make_pbar
make_pbar (epochs:int, show_pbar:bool)
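Only the signature of make_pbar is documented. Judging from its arguments, it plausibly returns an iterable over epochs that displays a progress bar when show_pbar is True. The sketch below is written under that assumption. The use of tqdm is also an assumption, and the function falls back to a plain range if tqdm is not installed.

```python
from typing import Iterable

def make_pbar(epochs: int, show_pbar: bool) -> Iterable[int]:
    """Hypothetical sketch: iterate over epochs, optionally with a tqdm progress bar."""
    if show_pbar:
        try:
            from tqdm import tqdm  # third-party; only needed when a bar is requested
            return tqdm(range(epochs))
        except ImportError:
            pass  # fall back to a plain range if tqdm is not installed
    return range(epochs)

for epoch in make_pbar(3, show_pbar=False):
    pass  # one gradient descent step per epoch would go here
```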
dccn
dccn (x:torch.Tensor)
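No docstring is given for dccn. Its name and argument type suggest the common PyTorch idiom detach, clone, cpu, numpy, which converts a tensor into an independent NumPy array. The sketch below assumes that reading and is not the library's code. It relies only on duck typing, so it can run without PyTorch.

```python
import numpy as np

def dccn(x) -> np.ndarray:
    """Hypothetical sketch: detach from the autograd graph, clone, move to CPU, convert to NumPy."""
    # detach(): drop gradient tracking; clone(): copy the storage so later
    # in-place updates to x do not change the returned array;
    # cpu(): ensure host memory; numpy(): view the CPU tensor as an ndarray.
    return x.detach().clone().cpu().numpy()
```

With PyTorch installed, dccn(torch.ones(2, requires_grad=True)) would return a float32 NumPy array of ones. The clone step matters when recording intermediate values such as log_alphas during optimization. detach() alone shares storage with the parameter, so a recorded array would silently change as the optimizer updates the parameter in place.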
DiffPaSSResults
DiffPaSSResults (log_alphas:Union[list[list[numpy.ndarray]],list[list[list[numpy.ndarray]]],NoneType], soft_perms:Union[list[list[numpy.ndarray]],list[list[list[numpy.ndarray]]],NoneType], hard_perms:Union[list[list[numpy.ndarray]],list[list[list[numpy.ndarray]]]], hard_losses:Union[list[list[float]],list[list[list[float]]]], soft_losses:Union[list[list[float]],list[list[list[float]]],NoneType])
Container for results of DiffPaSS fits.
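A minimal sketch of a container with the documented fields and types, written as a standard-library dataclass, follows. It assumes, but the page does not state, that the three optional fields (log_alphas, soft_perms, soft_losses) are None when the corresponding record_* options of fit are left off. The aliases ArrayLists and FloatLists are hypothetical shorthands introduced only here.

```python
from dataclasses import dataclass
from typing import Optional, Union

import numpy as np

# Hypothetical shorthands for the two nesting depths allowed by the signature
ArrayLists = Union[list[list[np.ndarray]], list[list[list[np.ndarray]]]]
FloatLists = Union[list[list[float]], list[list[list[float]]]]

@dataclass
class DiffPaSSResults:
    """Container for results of DiffPaSS fits (sketch mirroring the documented signature)."""
    log_alphas: Optional[ArrayLists]  # presumably None unless recorded
    soft_perms: Optional[ArrayLists]  # presumably None unless recorded
    hard_perms: ArrayLists
    hard_losses: FloatLists
    soft_losses: Optional[FloatLists]  # presumably None unless recorded

# Toy instance: one recorded step for one group, with a 3x3 identity as the hard permutation
res = DiffPaSSResults(
    log_alphas=None,
    soft_perms=None,
    hard_perms=[[np.eye(3)]],
    hard_losses=[[0.5]],
    soft_losses=None,
)
```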
DiffPaSSModel
DiffPaSSModel (*args, **kwargs)
Base class for DiffPaSS models.
DiffPaSSModel.fit
DiffPaSSModel.fit (x:torch.Tensor, y:torch.Tensor, epochs:int=1, optimizer_name:Optional[str]='SGD', optimizer_kwargs:Optional[dict[str,Any]]=None, mean_centering:bool=False, show_pbar:bool=False, compute_final_soft:bool=False, record_log_alphas:bool=False, record_soft_perms:bool=False, record_soft_losses:bool=False)
Fit permutations to data using gradient descent.
| Parameter | Type | Default | Details |
|---|---|---|---|
| x | Tensor | | The object (MSA or adjacency matrix of graphs) to be permuted |
| y | Tensor | | The target object (MSA or adjacency matrix of graphs) that the objects represented by x should be paired with. Not acted upon by soft/hard permutations |
| epochs | int | 1 | |
| optimizer_name | Optional[str] | SGD | |
| optimizer_kwargs | Optional[dict[str, Any]] | None | |
| mean_centering | bool | False | |
| show_pbar | bool | False | |
| compute_final_soft | bool | False | |
| record_log_alphas | bool | False | |
| record_soft_perms | bool | False | |
| record_soft_losses | bool | False | |
| Returns | DiffPaSSResults | | |
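The table mentions soft and hard permutations, and DiffPaSSResults can record log_alphas. A common way to turn a square matrix of log-weights into a soft permutation (a doubly stochastic matrix) is Sinkhorn normalization, as used in Gumbel-Sinkhorn-style methods. Whether fit uses exactly this operator, and with which temperature or iteration count, is an assumption here. The names sinkhorn, tau and n_iters are hypothetical. The NumPy sketch alternates row and column normalization in log space for numerical stability.

```python
import numpy as np

def _logsumexp(a: np.ndarray, axis: int) -> np.ndarray:
    """Numerically stable log(sum(exp(a))) along an axis, keeping dims for broadcasting."""
    m = np.max(a, axis=axis, keepdims=True)
    return m + np.log(np.sum(np.exp(a - m), axis=axis, keepdims=True))

def sinkhorn(log_alpha: np.ndarray, tau: float = 1.0, n_iters: int = 50) -> np.ndarray:
    """Map a square matrix of log-weights to an (approximately) doubly stochastic matrix."""
    z = log_alpha / tau  # lower tau pushes the result toward a hard permutation
    for _ in range(n_iters):
        z = z - _logsumexp(z, axis=1)  # normalize rows in log space
        z = z - _logsumexp(z, axis=0)  # normalize columns in log space
    return np.exp(z)

rng = np.random.default_rng(0)
soft = sinkhorn(rng.normal(size=(4, 4)), tau=1.0)
hard_guess = soft.argmax(axis=1)  # row-wise argmax; not guaranteed to be a permutation
```

A row-wise argmax is only a rough way to read off a hard permutation. A proper matching (for example, solving a linear assignment problem) guarantees a valid permutation.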
DiffPaSSModel.fit_bootstrap
DiffPaSSModel.fit_bootstrap (x:torch.Tensor, y:torch.Tensor, n_start:int=1, n_end:Optional[int]=None, step_size:int=1, n_repeats:int=1, show_pbar:bool=True, single_fit_cfg:Optional[dict]=None)
Fit permutations to data using the DiffPaSS bootstrap.
The DiffPaSS bootstrap consists of a sequence of short gradient descent runs (default: one epoch per run). At the end of each run, a subset of the found pairings is chosen uniformly at random and fixed for the next run. The number of pairings fixed at each iteration ranges between n_start (default: 1) and n_end (default: total number of pairs), with a step size of step_size.
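The bootstrap schedule described above can be sketched in plain Python. This is a minimal illustration, not the library's implementation. It assumes n_end is exclusive, as in Python's range, so that the last run does not fix every pairing. The helpers bootstrap_schedule and choose_fixed_pairs are hypothetical, and the toy pairings stand in for the output of a real gradient descent run.

```python
import random
from typing import Optional

def bootstrap_schedule(n_pairs: int, n_start: int = 1, n_end: Optional[int] = None,
                       step_size: int = 1) -> list[int]:
    """Number of pairings to fix at each bootstrap iteration (n_end treated as exclusive, an assumption)."""
    if n_end is None:
        n_end = n_pairs  # default: total number of pairs
    return list(range(n_start, n_end, step_size))

def choose_fixed_pairs(found_pairs: list[tuple[int, int]], n_fixed: int,
                       rng: random.Random) -> list[tuple[int, int]]:
    """Pick n_fixed of the current pairings uniformly at random, without replacement."""
    return rng.sample(found_pairs, n_fixed)

rng = random.Random(0)
found = [(i, (i + 1) % 5) for i in range(5)]  # toy pairings from one gradient descent run
for n_fixed in bootstrap_schedule(n_pairs=5, step_size=2):  # n_fixed = 1, then 3
    fixed = choose_fixed_pairs(found, n_fixed, rng)
    # ...the next short gradient descent run would keep `fixed` pairings unchanged
```

In the real procedure the pairings found would change after every run. Here found stays constant to keep the sketch short.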
| Parameter | Type | Default | Details |
|---|---|---|---|
| x | Tensor | | The object (MSA or adjacency matrix of graphs) to be permuted |
| y | Tensor | | The target object (MSA or adjacency matrix of graphs) that the objects represented by x should be paired with. Not acted upon by soft/hard permutations |
| n_start | int | 1 | |
| n_end | Optional[int] | None | |
| step_size | int | 1 | |
| n_repeats | int | 1 | |
| show_pbar | bool | True | |
| single_fit_cfg | Optional[dict] | None | |
| Returns | DiffPaSSResults | | |
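The page does not say what single_fit_cfg may contain. A natural reading is that it holds keyword arguments for the short fit runs inside the bootstrap, but that is an assumption. The sketch below validates such a dictionary against a stand-in function with the documented signature of DiffPaSSModel.fit, using inspect.signature, so that a typo in a key fails early. The functions fit and check_single_fit_cfg here are hypothetical stand-ins, and "Adam" is used only as an example optimizer name.

```python
import inspect
from typing import Any, Optional

def fit(x, y, epochs: int = 1, optimizer_name: Optional[str] = "SGD",
        optimizer_kwargs: Optional[dict[str, Any]] = None, mean_centering: bool = False,
        show_pbar: bool = False, compute_final_soft: bool = False,
        record_log_alphas: bool = False, record_soft_perms: bool = False,
        record_soft_losses: bool = False):
    """Stand-in with the documented signature of DiffPaSSModel.fit (no actual fitting)."""
    return locals()  # echo the received arguments so they can be inspected

def check_single_fit_cfg(cfg: Optional[dict]) -> dict:
    """Reject keys that fit() would not accept; x and y are supplied separately."""
    cfg = {} if cfg is None else dict(cfg)
    allowed = set(inspect.signature(fit).parameters) - {"x", "y"}
    unknown = set(cfg) - allowed
    if unknown:
        raise ValueError(f"Unknown fit options: {sorted(unknown)}")
    return cfg

cfg = check_single_fit_cfg({"epochs": 1, "optimizer_name": "Adam", "optimizer_kwargs": {"lr": 0.1}})
call = fit(None, None, **cfg)  # options not in cfg keep fit's defaults
```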