```python
import numpy as np
import torch

# Assumption: the classes documented on this page live in the diffpass.train module
from diffpass.train import (
    BestHitsPairing,
    GraphAlignment,
    InformationPairing,
    MirrortreePairing,
)


def test_information_bootstrap():
    # Data: two highly correlated MSAs
    n_classes = 3
    length = 5
    size_each_group = 10
    n_groups = 10
    # Define first MSA group by group
    x_tok_by_group = [torch.randint(0, n_classes, (size_each_group, length)) for _ in range(n_groups)]
    # Within-group shuffling to control for algorithmic biases towards the identity permutation
    x_tok_by_group_shuffle = [x[torch.randperm(size_each_group)] for x in x_tok_by_group]
    x_tok = torch.cat(x_tok_by_group, dim=0)
    x_tok_shuffle = torch.cat(x_tok_by_group_shuffle, dim=0)
    y_tok = (x_tok + 1) % n_classes
    x = torch.nn.functional.one_hot(x_tok).to(torch.get_default_dtype())
    x_shuffle = torch.nn.functional.one_hot(x_tok_shuffle).to(torch.get_default_dtype())
    y = torch.nn.functional.one_hot(y_tok).to(torch.get_default_dtype())
    group_sizes = [size_each_group] * n_groups
    # Model
    model = InformationPairing(group_sizes=group_sizes)
    results = model.fit_bootstrap(x_shuffle, y)
    hard_loss_identity_perm = model.compute_losses_identity_perm(x, y)["hard"]
    # Check that the hard loss of the optimized permutation is close to the ground truth
    assert np.abs(results.hard_losses[-2][-1] - hard_loss_identity_perm) < 1e-4


test_information_bootstrap()
```

train
Perform optimization using DiffPaSS models
Type aliases
```python
IndexPair = tuple[int, int]  # Pair of indices
IndexPairsInGroup = list[IndexPair]  # Pairs of indices in a group of sequences
IndexPairsInGroups = list[IndexPairsInGroup]  # Pairs of indices in groups of sequences
```
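For concreteness, here is a minimal sketch of a `fixed_pairings` value (of type `IndexPairsInGroups`) for two groups, with one fixed pair in the first group and none in the second; the index values are purely illustrative:

```python
# Group 1: index 0 in the MSA to permute (x) is pinned to index 1 in the
# reference MSA (y). Group 2: no fixed pairs.
fixed_pairings: IndexPairsInGroups = [[(0, 1)], []]
```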
InformationPairing

```
InformationPairing (group_sizes:collections.abc.Sequence[int],
                    fixed_pairings:Optional[list[list[tuple[int,int]]]]=None,
                    permutation_cfg:Optional[dict[str,Any]]=None,
                    information_measure:Literal['MI','TwoBodyEntropy']='TwoBodyEntropy')
```
DiffPaSS model for information-theoretic pairing of multiple sequence alignments (MSAs).
| | Type | Default | Details |
|---|---|---|---|
| group_sizes | Sequence | | Number of sequences in each group (e.g. species) of the two MSAs |
| fixed_pairings | Optional | None | If not None, fixed pairings between groups, of the form [[(i1, j1), (i2, j2), …], …], where (i1, j1) are the indices of the first fixed pair in the first group to be paired, etc. |
| permutation_cfg | Optional | None | If not None, configuration dictionary containing init parameters for the internal GeneralizedPermutation object to compute soft/hard permutations |
| information_measure | Literal | TwoBodyEntropy | Information-theoretic measure to use. For hard permutations, the two measures are equivalent |
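As a usage sketch building on the test above, one can fix a known pair and switch to the mutual information measure; the `fixed_pairings` value here is illustrative only:

```python
# Two groups of 10 sequences; pin index 0 of x to index 1 of y in the first group.
model = InformationPairing(
    group_sizes=[10, 10],
    fixed_pairings=[[(0, 1)], []],
    information_measure="MI",
)
results = model.fit_bootstrap(x_shuffle, y)  # one-hot inputs as in the test above
final_hard_losses = results.hard_losses[-1]  # hard losses from the last bootstrap iteration
```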
BestHitsPairing
```
BestHitsPairing (group_sizes:collections.abc.Sequence[int],
                 fixed_pairings:Optional[list[list[tuple[int,int]]]]=None,
                 permutation_cfg:Optional[dict[str,Any]]=None,
                 similarity_kind:Literal['Hamming','Blosum62']='Hamming',
                 similarities_cfg:Optional[dict[str,Any]]=None,
                 compute_in_group_best_hits:bool=True,
                 best_hits_cfg:Optional[dict[str,Any]]=None,
                 similarities_comparison_loss:Optional[<built-in function callable>]=None,
                 compare_soft_best_hits_to_hard:bool=True)
```

DiffPaSS model for pairing of multiple sequence alignments (MSAs) by aligning their orthology networks, constructed using (reciprocal) best hits.
| | Type | Default | Details |
|---|---|---|---|
| group_sizes | Sequence | | Number of sequences in each group (e.g. species) of the two MSAs |
| fixed_pairings | Optional | None | If not None, fixed pairings between groups, of the form [[(i1, j1), (i2, j2), …], …], where (i1, j1) are the indices of the first fixed pair in the first group to be paired, etc. |
| permutation_cfg | Optional | None | If not None, configuration dictionary containing init parameters for the internal GeneralizedPermutation object to compute soft/hard permutations |
| similarity_kind | Literal | Hamming | (Smoothly extended) similarity metric to use on all pairs of aligned sequences |
| similarities_cfg | Optional | None | If not None, configuration dictionary containing init parameters for the internal HammingSimilarities or Blosum62Similarities object to compute similarity matrices |
| compute_in_group_best_hits | bool | True | Whether to also compute best hits within each group (in addition to between different groups) |
| best_hits_cfg | Optional | None | If not None, configuration dictionary containing init parameters for the internal BestHits object to compute soft/hard (reciprocal) best hits |
| similarities_comparison_loss | Optional | None | If not None, custom callable to compute the differentiable loss between the soft/hard best hits matrices of the two MSAs |
| compare_soft_best_hits_to_hard | bool | True | If True, compare the soft best hits from the MSA to permute (x) to the hard best hits from the reference MSA (y); otherwise, compare them to its soft best hits |
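A configuration sketch with non-default settings (hypothetical choices; Blosum62 similarities presuppose one-hot encodings over the amino-acid alphabet expected by the internal Blosum62Similarities object):

```python
# Blosum62-based similarities, and best hits computed only between groups.
model = BestHitsPairing(
    group_sizes=[10, 10],
    similarity_kind="Blosum62",
    compute_in_group_best_hits=False,
)
results = model.fit_bootstrap(x_shuffle, y)  # one-hot encoded MSAs
```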
```python
def test_besthits_bootstrap():
    # Data: two highly correlated MSAs
    n_classes = 3
    length = 100
    size_each_group = 10
    n_groups = 10
    # Define first MSA group by group
    x_tok_by_group = [torch.randint(0, n_classes, (size_each_group, length)) for _ in range(n_groups)]
    # Within-group shuffling to control for algorithmic biases towards the identity permutation
    x_tok_by_group_shuffle = [x[torch.randperm(size_each_group)] for x in x_tok_by_group]
    x_tok = torch.cat(x_tok_by_group, dim=0)
    x_tok_shuffle = torch.cat(x_tok_by_group_shuffle, dim=0)
    y_tok = (x_tok + 1) % n_classes
    x = torch.nn.functional.one_hot(x_tok).to(torch.get_default_dtype())
    x_shuffle = torch.nn.functional.one_hot(x_tok_shuffle).to(torch.get_default_dtype())
    y = torch.nn.functional.one_hot(y_tok).to(torch.get_default_dtype())
    group_sizes = [size_each_group] * n_groups
    # Model
    model = BestHitsPairing(group_sizes=group_sizes)
    results = model.fit_bootstrap(x_shuffle, y)
    target_hard_loss = model.compute_losses_identity_perm(x, y)["hard"]
    # Check that the hard loss of the optimized permutation is close to the ground truth
    assert results.hard_losses[-2][-1] / target_hard_loss > 0.7


test_besthits_bootstrap()
```

MirrortreePairing
```
MirrortreePairing (group_sizes:collections.abc.Sequence[int],
                   fixed_pairings:Optional[list[list[tuple[int,int]]]]=None,
                   permutation_cfg:Optional[dict[str,Any]]=None,
                   similarity_kind:Literal['Hamming','Blosum62']='Hamming',
                   similarities_cfg:Optional[dict[str,Any]]=None,
                   similarities_comparison_loss:Optional[<built-in function callable>]=None)
```
DiffPaSS model for pairing of multiple sequence alignments (MSAs) by aligning their sequence distance networks as in the Mirrortree method.
| | Type | Default | Details |
|---|---|---|---|
| group_sizes | Sequence | | Number of sequences in each group (e.g. species) of the two MSAs |
| fixed_pairings | Optional | None | If not None, fixed pairings between groups, of the form [[(i1, j1), (i2, j2), …], …], where (i1, j1) are the indices of the first fixed pair in the first group to be paired, etc. |
| permutation_cfg | Optional | None | If not None, configuration dictionary containing init parameters for the internal GeneralizedPermutation object to compute soft/hard permutations |
| similarity_kind | Literal | Hamming | (Smoothly extended) similarity metric to use on all pairs of aligned sequences |
| similarities_cfg | Optional | None | If not None, configuration dictionary containing init parameters for the internal HammingSimilarities or Blosum62Similarities object to compute similarity matrices |
| similarities_comparison_loss | Optional | None | If not None, custom callable to compute the differentiable loss between the similarity matrices of the two MSAs. Default: IntraGroupSimilarityLoss |
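A sketch of plugging in a custom similarities_comparison_loss; the two-argument signature (soft/hard-permuted similarity matrix of x, similarity matrix of y) is an assumption here, not confirmed by this page:

```python
# Hypothetical custom loss: negative Frobenius inner product, so that more
# similar distance networks yield a lower loss.
def neg_frobenius_inner(sim_x: torch.Tensor, sim_y: torch.Tensor) -> torch.Tensor:
    return -(sim_x * sim_y).sum()

model = MirrortreePairing(
    group_sizes=[10, 10],
    similarities_comparison_loss=neg_frobenius_inner,
)
```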
```python
def test_mirrortree_bootstrap():
    # Data: two highly correlated MSAs
    n_classes = 3
    length = 100
    size_each_group = 10
    n_groups = 10
    # Define first MSA group by group
    x_tok_by_group = [torch.randint(0, n_classes, (size_each_group, length)) for _ in range(n_groups)]
    # Within-group shuffling to control for algorithmic biases towards the identity permutation
    x_tok_by_group_shuffle = [x[torch.randperm(size_each_group)] for x in x_tok_by_group]
    x_tok = torch.cat(x_tok_by_group, dim=0)
    x_tok_shuffle = torch.cat(x_tok_by_group_shuffle, dim=0)
    y_tok = (x_tok + 1) % n_classes
    x = torch.nn.functional.one_hot(x_tok).to(torch.get_default_dtype())
    x_shuffle = torch.nn.functional.one_hot(x_tok_shuffle).to(torch.get_default_dtype())
    y = torch.nn.functional.one_hot(y_tok).to(torch.get_default_dtype())
    group_sizes = [size_each_group] * n_groups
    # Model
    model = MirrortreePairing(group_sizes=group_sizes)
    results = model.fit_bootstrap(x_shuffle, y)
    target_hard_loss = model.compute_losses_identity_perm(x, y)["hard"]
    # Check that the hard loss of the optimized permutation is close to the ground truth
    assert results.hard_losses[-2][-1] / target_hard_loss > 0.95


test_mirrortree_bootstrap()
```

GraphAlignment
```
GraphAlignment (group_sizes:collections.abc.Sequence[int],
                fixed_pairings:Optional[list[list[tuple[int,int]]]]=None,
                permutation_cfg:Optional[dict[str,Any]]=None,
                comparison_loss:Optional[<built-in function callable>]=None)
```
DiffPaSS model for general graph alignment starting from the weighted adjacency matrices of two graphs.
| | Type | Default | Details |
|---|---|---|---|
| group_sizes | Sequence | | Number of graph nodes in each group (e.g. species), assumed the same between the two graphs to align |
| fixed_pairings | Optional | None | If not None, fixed pairings between groups, of the form [[(i1, j1), (i2, j2), …], …], where (i1, j1) are the indices of the first fixed pair in the first group to be paired, etc. |
| permutation_cfg | Optional | None | If not None, configuration dictionary containing init parameters for the internal GeneralizedPermutation object to compute soft/hard permutations. Soft/hard permutations P act on adjacency matrices X via P @ X @ P.T |
| comparison_loss | Optional | None | If not None, custom callable to compute the differentiable loss between the soft/hard-permuted adjacency matrix of graph x and the adjacency matrix of graph y. Defaults to the dot product between all upper triangular elements |
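The permutation action P @ X @ P.T mentioned in the table permutes the rows and columns of an adjacency matrix simultaneously; this small self-contained check illustrates the identity (P @ X @ P.T)[i, j] = X[perm[i], perm[j]]:

```python
import torch

n = 5
X = torch.exp(torch.randn(n, n))  # random weighted adjacency matrix
perm = torch.randperm(n)
P = torch.zeros(n, n)
P[torch.arange(n), perm] = 1.0  # hard permutation matrix, P[i, perm[i]] = 1

# Permuting via P @ X @ P.T matches fancy-indexing rows and columns by perm
assert torch.allclose(P @ X @ P.T, X[perm][:, perm])
```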
```python
def test_graph_alignment_bootstrap():
    # Data: two identical weighted adjacency matrices
    size_each_group = 10
    n_groups = 10
    n_samples = size_each_group * n_groups
    x = torch.exp(torch.randn((n_samples, n_samples))).to(torch.get_default_dtype())
    y = x.clone()
    # Within-group shuffling to control for algorithmic biases towards the identity permutation
    rand_perm_mats = []
    for _ in range(n_groups):
        rp_mat = torch.zeros(
            (size_each_group, size_each_group), dtype=x.dtype, device=x.device
        )
        rp_mat[torch.arange(size_each_group), torch.randperm(size_each_group)] = 1
        rand_perm_mats.append(rp_mat)
    # apply_hard_permutation_batch_to_similarity is a DiffPaSS helper (imported
    # elsewhere in the package) that applies the block permutations in `perms`
    # to both the rows and the columns of x
    x_shuffle = apply_hard_permutation_batch_to_similarity(x=x, perms=rand_perm_mats)
    group_sizes = [size_each_group] * n_groups
    # Model
    model = GraphAlignment(group_sizes=group_sizes)
    results = model.fit_bootstrap(x_shuffle, y)
    target_hard_loss = model.compute_losses_identity_perm(x, y)["hard"]
    # Check that the hard loss of the optimized permutation is close to the ground truth
    assert results.hard_losses[-2][-1] / target_hard_loss > 0.95


test_graph_alignment_bootstrap()
```