```python
import numpy as np
import torch

from diffpass.train import (
    BestHitsPairing,
    GraphAlignment,
    InformationPairing,
    MirrortreePairing,
)


def test_information_bootstrap():
    # Data: two highly correlated MSAs
    n_classes = 3
    length = 5
    size_each_group = 10
    n_groups = 10
    # Define first MSA group by group
    x_tok_by_group = [
        torch.randint(0, n_classes, (size_each_group, length)) for _ in range(n_groups)
    ]
    # Within-group shuffling to control for algorithmic biases towards identity permutation
    x_tok_by_group_shuffle = [x[torch.randperm(size_each_group)] for x in x_tok_by_group]
    x_tok = torch.cat(x_tok_by_group, dim=0)
    x_tok_shuffle = torch.cat(x_tok_by_group_shuffle, dim=0)
    y_tok = (x_tok + 1) % n_classes
    x = torch.nn.functional.one_hot(x_tok).to(torch.get_default_dtype())
    x_shuffle = torch.nn.functional.one_hot(x_tok_shuffle).to(torch.get_default_dtype())
    y = torch.nn.functional.one_hot(y_tok).to(torch.get_default_dtype())

    group_sizes = [size_each_group] * n_groups

    # Model
    model = InformationPairing(group_sizes=group_sizes)
    results = model.fit_bootstrap(x_shuffle, y)
    hard_loss_identity_perm = model.compute_losses_identity_perm(x, y)["hard"]

    # Check that the hard loss of the optimized permutation is close to the ground truth
    assert np.abs(results.hard_losses[-2][-1] - hard_loss_identity_perm) < 1e-4


test_information_bootstrap()
```
train
Perform optimization using DiffPaSS models
Type aliases
```python
IndexPair = tuple[int, int]  # Pair of indices
IndexPairsInGroup = list[IndexPair]  # Pairs of indices in a group of sequences
IndexPairsInGroups = list[IndexPairsInGroup]  # Pairs of indices in groups of sequences
```
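These aliases describe the nesting of the `fixed_pairings` arguments accepted by all the models below. A minimal illustration, with invented index values:

```python
# One list of index pairs per group of sequences; the values are made up.
fixed_pairings: IndexPairsInGroups = [
    [(0, 0)],          # first group: one fixed pair
    [],                # second group: no fixed pairings
    [(2, 1), (3, 3)],  # third group: two fixed pairs
]
```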
InformationPairing
```
InformationPairing (group_sizes:collections.abc.Sequence[int],
                    fixed_pairings:Optional[list[list[tuple[int,int]]]]=None,
                    permutation_cfg:Optional[dict[str,Any]]=None,
                    information_measure:Literal['MI','TwoBodyEntropy']='TwoBodyEntropy')
```
DiffPaSS model for information-theoretic pairing of multiple sequence alignments (MSAs).
|  | Type | Default | Details |
|---|---|---|---|
| group_sizes | Sequence |  | Number of sequences in each group (e.g. species) of the two MSAs |
| fixed_pairings | Optional | None | If not `None`, fixed pairings between groups, of the form `[[(i1, j1), (i2, j2), …], …]`, where `(i1, j1)` are the indices of the first fixed pair in the first group to be paired, etc. |
| permutation_cfg | Optional | None | If not `None`, configuration dictionary containing init parameters for the internal `GeneralizedPermutation` object used to compute soft/hard permutations |
| information_measure | Literal | TwoBodyEntropy | Information-theoretic measure to use. For hard permutations, the two measures are equivalent |
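A minimal usage sketch (assuming `x` and `y` are one-hot encoded MSA tensors and `group_sizes` is a list of group sizes, as in the test above):

```python
model = InformationPairing(group_sizes=group_sizes, information_measure="TwoBodyEntropy")
results = model.fit_bootstrap(x, y)
# Losses are recorded per bootstrap iteration; the tests on this page read the
# final optimized hard loss as results.hard_losses[-2][-1].
```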
BestHitsPairing
```
BestHitsPairing (group_sizes:collections.abc.Sequence[int],
                 fixed_pairings:Optional[list[list[tuple[int,int]]]]=None,
                 permutation_cfg:Optional[dict[str,Any]]=None,
                 similarity_kind:Literal['Hamming','Blosum62']='Hamming',
                 similarities_cfg:Optional[dict[str,Any]]=None,
                 compute_in_group_best_hits:bool=True,
                 best_hits_cfg:Optional[dict[str,Any]]=None,
                 similarities_comparison_loss:Optional[callable]=None,
                 compare_soft_best_hits_to_hard:bool=True)
```
DiffPaSS model for pairing of multiple sequence alignments (MSAs) by aligning their orthology networks, constructed using (reciprocal) best hits.
|  | Type | Default | Details |
|---|---|---|---|
| group_sizes | Sequence |  | Number of sequences in each group (e.g. species) of the two MSAs |
| fixed_pairings | Optional | None | If not `None`, fixed pairings between groups, of the form `[[(i1, j1), (i2, j2), …], …]`, where `(i1, j1)` are the indices of the first fixed pair in the first group to be paired, etc. |
| permutation_cfg | Optional | None | If not `None`, configuration dictionary containing init parameters for the internal `GeneralizedPermutation` object used to compute soft/hard permutations |
| similarity_kind | Literal | Hamming | (Smoothly extended) similarity metric to use on all pairs of aligned sequences |
| similarities_cfg | Optional | None | If not `None`, configuration dictionary containing init parameters for the internal `HammingSimilarities` or `Blosum62Similarities` object used to compute similarity matrices |
| compute_in_group_best_hits | bool | True | Whether to also compute best hits within each group (in addition to between different groups) |
| best_hits_cfg | Optional | None | If not `None`, configuration dictionary containing init parameters for the internal `BestHits` object used to compute soft/hard (reciprocal) best hits |
| similarities_comparison_loss | Optional | None | If not `None`, custom callable to compute the differentiable loss between the soft/hard best hits matrices of the two MSAs |
| compare_soft_best_hits_to_hard | bool | True | Whether to compare the soft best hits from the MSA to permute (`x`) to the hard (True) or soft (False) best hits from the reference MSA (`y`) |
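A sketch of a non-default configuration, using only the parameters documented above; the values are illustrative rather than recommendations:

```python
model = BestHitsPairing(
    group_sizes=group_sizes,
    similarity_kind="Blosum62",        # smoothly extended BLOSUM62-based similarities instead of Hamming
    compute_in_group_best_hits=False,  # only compute best hits between different groups
)
results = model.fit_bootstrap(x, y)
```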
```python
def test_besthits_bootstrap():
    # Data: two highly correlated MSAs
    n_classes = 3
    length = 100
    size_each_group = 10
    n_groups = 10
    # Define first MSA group by group
    x_tok_by_group = [
        torch.randint(0, n_classes, (size_each_group, length)) for _ in range(n_groups)
    ]
    # Within-group shuffling to control for algorithmic biases towards identity permutation
    x_tok_by_group_shuffle = [x[torch.randperm(size_each_group)] for x in x_tok_by_group]
    x_tok = torch.cat(x_tok_by_group, dim=0)
    x_tok_shuffle = torch.cat(x_tok_by_group_shuffle, dim=0)
    y_tok = (x_tok + 1) % n_classes
    x = torch.nn.functional.one_hot(x_tok).to(torch.get_default_dtype())
    x_shuffle = torch.nn.functional.one_hot(x_tok_shuffle).to(torch.get_default_dtype())
    y = torch.nn.functional.one_hot(y_tok).to(torch.get_default_dtype())

    group_sizes = [size_each_group] * n_groups

    # Model
    model = BestHitsPairing(group_sizes=group_sizes)
    results = model.fit_bootstrap(x_shuffle, y)
    target_hard_loss = model.compute_losses_identity_perm(x, y)["hard"]

    # Check that the hard loss of the optimized permutation is close to the ground truth
    assert results.hard_losses[-2][-1] / target_hard_loss > 0.7


test_besthits_bootstrap()
```
MirrortreePairing
```
MirrortreePairing (group_sizes:collections.abc.Sequence[int],
                   fixed_pairings:Optional[list[list[tuple[int,int]]]]=None,
                   permutation_cfg:Optional[dict[str,Any]]=None,
                   similarity_kind:Literal['Hamming','Blosum62']='Hamming',
                   similarities_cfg:Optional[dict[str,Any]]=None,
                   similarities_comparison_loss:Optional[callable]=None)
```
DiffPaSS model for pairing of multiple sequence alignments (MSAs) by aligning their sequence distance networks as in the Mirrortree method.
|  | Type | Default | Details |
|---|---|---|---|
| group_sizes | Sequence |  | Number of sequences in each group (e.g. species) of the two MSAs |
| fixed_pairings | Optional | None | If not `None`, fixed pairings between groups, of the form `[[(i1, j1), (i2, j2), …], …]`, where `(i1, j1)` are the indices of the first fixed pair in the first group to be paired, etc. |
| permutation_cfg | Optional | None | If not `None`, configuration dictionary containing init parameters for the internal `GeneralizedPermutation` object used to compute soft/hard permutations |
| similarity_kind | Literal | Hamming | (Smoothly extended) similarity metric to use on all pairs of aligned sequences |
| similarities_cfg | Optional | None | If not `None`, configuration dictionary containing init parameters for the internal `HammingSimilarities` or `Blosum62Similarities` object used to compute similarity matrices |
| similarities_comparison_loss | Optional | None | If not `None`, custom callable to compute the differentiable loss between the similarity matrices of the two MSAs. Default: `IntraGroupSimilarityLoss` |
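As an illustration of `similarities_comparison_loss`, here is a hypothetical custom callable. Its assumed signature, two similarity matrices in and a scalar loss out, is inferred from the description above and is not guaranteed by the library:

```python
def neg_dot_loss(sim_x: torch.Tensor, sim_y: torch.Tensor) -> torch.Tensor:
    # Hypothetical loss: negative elementwise dot product of the two
    # similarity matrices (lower is better, differentiable in sim_x)
    return -(sim_x * sim_y).sum()


model = MirrortreePairing(group_sizes=group_sizes, similarities_comparison_loss=neg_dot_loss)
```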
```python
def test_mirrortree_bootstrap():
    # Data: two highly correlated MSAs
    n_classes = 3
    length = 100
    size_each_group = 10
    n_groups = 10
    # Define first MSA group by group
    x_tok_by_group = [
        torch.randint(0, n_classes, (size_each_group, length)) for _ in range(n_groups)
    ]
    # Within-group shuffling to control for algorithmic biases towards identity permutation
    x_tok_by_group_shuffle = [x[torch.randperm(size_each_group)] for x in x_tok_by_group]
    x_tok = torch.cat(x_tok_by_group, dim=0)
    x_tok_shuffle = torch.cat(x_tok_by_group_shuffle, dim=0)
    y_tok = (x_tok + 1) % n_classes
    x = torch.nn.functional.one_hot(x_tok).to(torch.get_default_dtype())
    x_shuffle = torch.nn.functional.one_hot(x_tok_shuffle).to(torch.get_default_dtype())
    y = torch.nn.functional.one_hot(y_tok).to(torch.get_default_dtype())

    group_sizes = [size_each_group] * n_groups

    # Model
    model = MirrortreePairing(group_sizes=group_sizes)
    results = model.fit_bootstrap(x_shuffle, y)
    target_hard_loss = model.compute_losses_identity_perm(x, y)["hard"]

    # Check that the hard loss of the optimized permutation is close to the ground truth
    assert results.hard_losses[-2][-1] / target_hard_loss > 0.95


test_mirrortree_bootstrap()
```
GraphAlignment
```
GraphAlignment (group_sizes:collections.abc.Sequence[int],
                fixed_pairings:Optional[list[list[tuple[int,int]]]]=None,
                permutation_cfg:Optional[dict[str,Any]]=None,
                comparison_loss:Optional[callable]=None)
```
DiffPaSS model for general graph alignment starting from the weighted adjacency matrices of two graphs.
|  | Type | Default | Details |
|---|---|---|---|
| group_sizes | Sequence |  | Number of graph nodes in each group (e.g. species), assumed the same between the two graphs to align |
| fixed_pairings | Optional | None | If not `None`, fixed pairings between groups, of the form `[[(i1, j1), (i2, j2), …], …]`, where `(i1, j1)` are the indices of the first fixed pair in the first group to be paired, etc. |
| permutation_cfg | Optional | None | If not `None`, configuration dictionary containing init parameters for the internal `GeneralizedPermutation` object used to compute soft/hard permutations. Soft/hard permutations `P` act on adjacency matrices `X` via `P @ X @ P.T` |
| comparison_loss | Optional | None | If not `None`, custom callable to compute the differentiable loss between the soft/hard-permuted adjacency matrix of graph `x` and the adjacency matrix of graph `y`. Defaults to the dot product between all upper triangular elements |
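The `P @ X @ P.T` convention mentioned above can be checked on a toy example; this snippet is independent of DiffPaSS:

```python
import torch

X = torch.rand(4, 4)    # toy weighted adjacency matrix
perm = torch.randperm(4)
P = torch.eye(4)[perm]  # hard permutation matrix built from the row permutation
X_perm = P @ X @ P.T    # relabels both rows and columns of X
assert torch.allclose(X_perm, X[perm][:, perm])
```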
```python
def test_graph_alignment_bootstrap():
    # Data: two identical weighted adjacency matrices
    size_each_group = 10
    n_groups = 10
    n_samples = size_each_group * n_groups
    x = torch.exp(torch.randn((n_samples, n_samples))).to(torch.get_default_dtype())
    y = x.clone()
    # Within-group shuffling to control for algorithmic biases towards identity permutation
    rand_perm_mats = []
    for _ in range(n_groups):
        rp_mat = torch.zeros(
            (size_each_group, size_each_group), dtype=x.dtype, device=x.device
        )
        rp_mat[torch.arange(size_each_group), torch.randperm(size_each_group)] = 1
        rand_perm_mats.append(rp_mat)
    # NB: apply_hard_permutation_batch_to_similarity is assumed to be in scope
    # (it is defined elsewhere in the DiffPaSS codebase)
    x_shuffle = apply_hard_permutation_batch_to_similarity(x=x, perms=rand_perm_mats)

    group_sizes = [size_each_group] * n_groups

    # Model
    model = GraphAlignment(group_sizes=group_sizes)
    results = model.fit_bootstrap(x_shuffle, y)
    target_hard_loss = model.compute_losses_identity_perm(x, y)["hard"]

    # Check that the hard loss of the optimized permutation is close to the ground truth
    assert results.hard_losses[-2][-1] / target_hard_loss > 0.95


test_graph_alignment_bootstrap()
```