# Test for GeneralizedPermutation
def test_generalizedpermutation(*, length, alphabet_size, init_kwargs):
    group_sizes = init_kwargs["group_sizes"]
    n_samples = sum(group_sizes)

    x = torch.randn(n_samples, length, alphabet_size)
    perm = GeneralizedPermutation(**init_kwargs)
    mats = perm()
    mat_apply = MatrixApply(group_sizes)
    y = mat_apply(x, mats=mats)

    assert y.shape == x.shape
    assert y.requires_grad

    perm.hard_()
    assert perm.mode == "hard"


test_generalizedpermutation(
    length=5,
    alphabet_size=10,
    init_kwargs={
        "group_sizes": [3, 2, 4],
        "fixed_pairings": [[(0, 1)], [(0, 0)], [(1, 0), (2, 3)]],
        "tau": 0.1,
    }
)
def test_batch_perm(shape: tuple[int, int, int, int]):
    perms = torch.randn(*shape)
    x = torch.randn(shape[-2], shape[-1])

    argmax = perms.argmax(-1)
    x_permuted_rows = x[argmax]
    index = argmax.view(*argmax.shape[:-1], 1, -1).expand_as(perms)
    output = torch.gather(x_permuted_rows, -1, index)

    expected = torch.stack([
        torch.stack([
            x[argmax[i, j], :][:, argmax[i, j]] for j in range(shape[1])
        ], dim=0) for i in range(shape[0])
    ], dim=0)

    assert torch.equal(output, expected)


test_batch_perm((2, 5, 4, 4))
model
Type aliases
IndexPair = tuple[int, int]  # Pair of indices
IndexPairsInGroup = list[IndexPair]  # Pairs of indices in a group of sequences
IndexPairsInGroups = list[IndexPairsInGroup]  # Pairs of indices in groups of sequences
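For instance, the fixed_pairings argument used in the GeneralizedPermutation test above can be annotated with these aliases. The snippet below is purely illustrative: the variable name is ours, and only the nested-list structure is taken from the test.

# Illustrative annotation with the aliases defined above (hypothetical example).
# One inner list per group of sequences; each tuple is a pair of indices within that group.
fixed_pairings: IndexPairsInGroups = [
    [(0, 1)],          # group of size 3: one fixed pairing
    [(0, 0)],          # group of size 2: one fixed pairing
    [(1, 0), (2, 3)],  # group of size 4: two fixed pairings
]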
Sinkhorn/matching layer for soft/hard permutations
apply_hard_permutation_batch_to_similarity
apply_hard_permutation_batch_to_similarity (x:torch.Tensor, perms:list[torch.Tensor])
*Conjugate a single similarity matrix by a batch of hard permutations.
Args: perms: List of batches of permutation matrices of shape (…, D, D). x: Similarity matrix of shape (D, D).
Returns: Batch of conjugated matrices of shape (…, D, D).*
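As a minimal sketch of what conjugation means here (our own illustration, not the library's implementation): each hard permutation matrix P in the batch reorders the rows and columns of x consistently, so the result per batch element is P x Pᵀ.

import torch

# Sketch: conjugating a (D, D) similarity matrix by one hard permutation matrix.
D = 4
x = torch.randn(D, D)
x = 0.5 * (x + x.T)          # symmetric, like a similarity matrix
perm = torch.randperm(D)
P = torch.eye(D)[perm]       # hard permutation matrix
conjugated = P @ x @ P.T     # rows and columns reordered consistently
assert torch.allclose(conjugated, x[perm][:, perm])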
global_argmax_from_group_argmaxes
global_argmax_from_group_argmaxes (mats:collections.abc.Iterable[torch.Tensor])
PermutationConjugate
PermutationConjugate (group_sizes:collections.abc.Sequence[int])
Conjugate blocks of a square 2D tensor of shape (n_samples, n_samples) by permutation matrices.
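A rough sketch of the block structure implied by group_sizes, under the assumption that block (i, j) of the similarity matrix is conjugated by the permutations of groups i and j (illustrative only; the module's exact behaviour may differ):

import torch

# Assumed block-wise conjugation for group_sizes = [2, 3]: block (i, j) is
# multiplied by P_i on the left and by P_j transposed on the right.
group_sizes = [2, 3]
n = sum(group_sizes)
sim = torch.randn(n, n)
perms = [torch.eye(g)[torch.randperm(g)] for g in group_sizes]

slices, start = [], 0
for g in group_sizes:
    slices.append(slice(start, start + g))
    start += g

out = torch.empty_like(sim)
for P_i, sl_i in zip(perms, slices):
    for P_j, sl_j in zip(perms, slices):
        out[sl_i, sl_j] = P_i @ sim[sl_i, sl_j] @ P_j.T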
MatrixApply
MatrixApply (group_sizes:collections.abc.Sequence[int])
Apply matrices to chunks of a tensor of shape (n_samples, length, alphabet_size) and collate the results.
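A hedged sketch of the assumed chunk-and-collate semantics: x is split into groups of samples along its first dimension, each chunk is mixed by the corresponding matrix (for example a permutation returned by GeneralizedPermutation, as in the test at the top of this page), and the results are concatenated back.

import torch

# Assumed reference behaviour (not MatrixApply's own code): one (g, g) matrix per group,
# applied along the sample dimension of the corresponding chunk of x.
group_sizes = [3, 2]
length, alphabet_size = 5, 4
x = torch.randn(sum(group_sizes), length, alphabet_size)
mats = [torch.eye(g)[torch.randperm(g)] for g in group_sizes]

chunks, start = [], 0
for mat, g in zip(mats, group_sizes):
    chunk = x[start:start + g]                       # (g, length, alphabet_size)
    chunks.append(torch.einsum("ij,jla->ila", mat, chunk))
    start += g

y = torch.cat(chunks, dim=0)
assert y.shape == x.shape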
GeneralizedPermutation
GeneralizedPermutation (group_sizes:collections.abc.Sequence[int], fixed_pairings:Optional[list[list[tuple[int,int]]]]=None, tau:float=1.0, n_iter:int=1, noise:bool=False, noise_factor:float=1.0, noise_std:bool=False, mode:Literal['soft','hard']='soft')
Generalized permutation layer implementing both soft and hard permutations.
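A minimal usage sketch, consistent with the test at the top of this page (keyword names and values are taken from that test; the library import is assumed, as elsewhere on this page):

# Soft permutations (differentiable), then switch to hard mode, as in the test above.
perm = GeneralizedPermutation(
    group_sizes=[3, 2, 4],
    fixed_pairings=[[(0, 1)], [(0, 0)], [(1, 0), (2, 3)]],
    tau=0.1,
)
soft_mats = perm()   # soft permutation matrices (assumed: one per group)

perm.hard_()
assert perm.mode == "hard"
hard_mats = perm()   # hard permutation matrices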
Information-theory losses
TwoBodyEntropyLoss
TwoBodyEntropyLoss ()
Differentiable extension of the mean of estimated two-body entropies between all pairs of columns from two one-hot encoded tensors.
MILoss
MILoss ()
Differentiable extension of minus the mean of estimated mutual informations between all pairs of columns from two one-hot encoded tensors.
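As a hedged sketch of the quantity these losses estimate (a plug-in estimator written by us; the library's exact extension to soft inputs may differ): for every pair of columns (i, j), the joint symbol frequencies are estimated from the (soft) one-hot encodings and their Shannon entropy, in bits, is averaged over all column pairs. MILoss analogously uses the identity MI(i, j) = H(i) + H(j) - H(i, j) and returns minus the mean.

import torch
from torch.nn.functional import softmax

# Assumed plug-in estimator of the mean two-body entropy between columns of x and y,
# both of shape (n_samples, length, alphabet_size) and (soft) one-hot on the last axis.
def mean_two_body_entropy(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    n_samples = x.shape[0]
    # joint[i, j, a, b] = empirical frequency of symbol pair (a, b) in columns (i, j)
    joint = torch.einsum("nia,njb->ijab", x, y) / n_samples
    entropies = -(joint * torch.log2(joint.clamp_min(1e-30))).sum(dim=(-1, -2))
    return entropies.mean()   # mean over all column pairs (i, j)

x_soft = softmax(torch.randn(10_000, 3, 4), dim=-1)
y_soft = softmax(torch.randn(10_000, 4, 4), dim=-1)
print(mean_two_body_entropy(x_soft, y_soft))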
# Test for TwoBodyEntropyLoss
def test_twobodyentropyloss(
    *,
    n_samples, length_x, length_y, alphabet_size
):
    x = torch.randn(
        n_samples, length_x, alphabet_size,
        requires_grad=True
    )
    y = torch.randn(n_samples, length_y, alphabet_size)
    x_soft = softmax(x, dim=-1)
    y_soft = softmax(y, dim=-1)
    two_body_entropy_loss = TwoBodyEntropyLoss()
    loss = two_body_entropy_loss(x_soft, y_soft)

    assert loss.requires_grad

    # In the following scenario, the score should be close to log2(alphabet_size)
    x_almost_hard = softmax(x / 1e-5, dim=-1)
    first_x_almost_hard_length_1 = x_almost_hard[:, :1, :]
    loss = two_body_entropy_loss(
        first_x_almost_hard_length_1, first_x_almost_hard_length_1
    )
    torch.testing.assert_close(
        loss, torch.log2(torch.tensor(alphabet_size)), atol=1e-3, rtol=1e-7
    )


test_twobodyentropyloss(
    n_samples=10_000,
    length_x=3,
    length_y=4,
    alphabet_size=3,
)
Sequence similarities (Hamming and Blosum62)
HammingSimilarities
HammingSimilarities (group_sizes:Optional[collections.abc.Sequence[int]]=None, use_dot:bool=True, p:Optional[float]=None)
*Compute Hamming similarities between sequences using differentiable operations.
Optionally, if the sequences are arranged in groups, the computation of similarities can be restricted to within groups. The similarities are computed using differentiable operations and can be based either on dot products or on an L^p distance.*
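A minimal sketch of the dot-product variant (our own reference formula; normalization may differ in the library): with one-hot encodings, the per-position dot product between two sequences is 1 for a match and 0 otherwise, so summing over positions gives a differentiable Hamming-type similarity. The use_dot=False, p=1. variant used in the test below instead compares encodings with an L^p distance.

import torch
from torch.nn.functional import softmax

# Assumed dot-product Hamming similarity between all pairs of sequences in x,
# with x of shape (n_samples, length, alphabet_size), (soft) one-hot on the last axis.
def hamming_similarities_dot(x: torch.Tensor) -> torch.Tensor:
    # sim[i, j] = sum over positions and symbols of x[i] * x[j]
    return torch.einsum("ila,jla->ij", x, x)

x_soft = softmax(torch.randn(9, 3, 10), dim=-1)
sim = hamming_similarities_dot(x_soft)
assert sim.shape == (9, 9)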
Blosum62Similarities
Blosum62Similarities (group_sizes:Optional[collections.abc.Sequence[int]]=None, use_dot:bool=True, p:Optional[float]=None, use_scoredist:bool=False, aa_to_int:Optional[dict[str,int]]=None, gaps_as_stars:bool=True)
*Compute Blosum62-based similarities between sequences using differentiable operations.
Optionally, if the sequences are arranged in groups, the computation of similarities can be restricted to within groups. The similarities are computed using differentiable operations and can be based either on dot products or on an L^p distance.*
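A hedged sketch of how a substitution matrix could enter such a similarity (illustrative only; the library's use of aa_to_int, gaps_as_stars and the optional ScoreDist correction is not reproduced here): the per-position comparison is weighted by the Blosum62 score of the two symbols rather than a plain match/mismatch.

import torch

# Assumed substitution-weighted similarity; B stands in for the Blosum62 matrix,
# indexed consistently with the one-hot alphabet.
def substitution_similarities(x: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    # sim[i, j] = sum over positions l of x[i, l] @ B @ x[j, l]
    return torch.einsum("ila,ab,jlb->ij", x, B, x)

alphabet_size = 21
B = torch.randn(alphabet_size, alphabet_size)                         # placeholder for Blosum62 scores
x = torch.eye(alphabet_size)[torch.randint(alphabet_size, (9, 3))]    # hard one-hot sequences
sim = substitution_similarities(x, B)
assert sim.shape == (9, 9)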
# Tests for HammingSimilarities and Blosum62Similarities
def test_similarities(
    *,
    cls,
    length, alphabet_size,
    init_kwargs
):
    group_sizes = init_kwargs["group_sizes"]
    n_samples = sum(group_sizes)

    x = torch.randn(
        n_samples, length, alphabet_size,
        requires_grad=True
    )
    x_soft = softmax(x, dim=-1)

    _init_kwargs = deepcopy(init_kwargs)
    _init_kwargs["group_sizes"] = None
    similarities = cls(**_init_kwargs)
    out_all = similarities(x_soft)

    assert out_all.shape == (n_samples, n_samples)

    similarities = cls(**init_kwargs)
    out = similarities(x_soft)

    for sl in similarities._group_slices:
        assert torch.allclose(
            out[..., sl, sl], out_all[..., sl, sl]
        )


test_similarities(
    cls=HammingSimilarities,
    length=3,
    alphabet_size=10,
    init_kwargs={"group_sizes": [3, 2, 4], "use_dot": False, "p": 1.}
)
test_similarities(
    cls=Blosum62Similarities,
    length=3,
    alphabet_size=21,
    init_kwargs={"group_sizes": [3, 2, 4]}
)
Best hits from similarities
BestHits
BestHits (reciprocal:bool=True, group_sizes:Optional[collections.abc.Sequence[int]]=None, tau:float=0.1, mode:Literal['soft','hard']='soft')
*Compute (reciprocal) best hits within and between groups of sequences, starting from a similarity matrix.
Best hits can be either ‘hard’, in which cases they are computed using the argmax, or ‘soft’, in which case they are computed using the softmax with a temperature parameter tau
. In both cases, the main diagonal in the similarity matrix is excluded by setting its entries to minus infinity.*
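A minimal sketch of the core computation described above, ignoring the restriction to groups and the reciprocal option (our own illustration, not the module's code):

import torch
from torch.nn.functional import softmax

# Assumed core of best-hit computation on an (n, n) similarity matrix: exclude the
# diagonal with -inf, then use a row-wise softmax (soft, temperature tau) or a
# row-wise argmax turned into one-hot rows (hard).
def best_hits_from_similarities(sim: torch.Tensor, tau: float = 0.1, hard: bool = False) -> torch.Tensor:
    sim = sim.clone()
    sim.fill_diagonal_(float("-inf"))
    if hard:
        return torch.eye(sim.shape[-1])[sim.argmax(dim=-1)]
    return softmax(sim / tau, dim=-1)

sim = torch.randn(6, 6)
soft_bh = best_hits_from_similarities(sim, tau=0.1)
hard_bh = best_hits_from_similarities(sim, hard=True)

Reciprocal best hits would additionally require the hit to be mutual, for example by combining this row-wise result with its transpose; that bookkeeping, and the within/between-group restriction, is handled by the module itself.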
Losses based on comparing similarity matrices
InterGroupSimilarityLoss
InterGroupSimilarityLoss (group_sizes:collections.abc.Sequence[int], score_fn:Optional[<built-in function callable>]=None)
*Compute a loss that compares similarity matrices restricted to inter-group relationships.
Similarity matrices are expected to be square and symmetric. The loss is computed by comparing the (flattened and concatenated) blocks containing inter-group similarities.*
IntraGroupSimilarityLoss
IntraGroupSimilarityLoss (group_sizes:Optional[collections.abc.Sequence[int]]=None, score_fn:Optional[<built-in function callable>]=None, exclude_diagonal:bool=True)
*Compute a loss that compares similarity matrices restricted to intra-group relationships.
Similarity matrices are expected to be square and symmetric. Their diagonal elements are ignored if exclude_diagonal is set to True. If group_sizes is provided, the loss is computed by comparing the flattened and concatenated upper triangular blocks containing intra-group similarities. Otherwise, the loss is computed by comparing the upper triangular part of the full similarity matrices.*
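A hedged sketch of the intra-group case with exclude_diagonal=True, using cosine similarity as score_fn as in the test below (the block extraction and the sign convention are assumptions on our part; the test's expectation that identical inputs give a loss close to -1 is consistent with returning minus the score):

import torch

# Assumed reference behaviour: gather the strictly upper-triangular entries of each
# intra-group block, concatenate them, and compare the two resulting vectors.
def intra_group_loss(sim_x, sim_y, group_sizes, score_fn):
    vals_x, vals_y, start = [], [], 0
    for g in group_sizes:
        sl = slice(start, start + g)
        iu = torch.triu_indices(g, g, offset=1)      # diagonal excluded
        vals_x.append(sim_x[sl, sl][iu[0], iu[1]])
        vals_y.append(sim_y[sl, sl][iu[0], iu[1]])
        start += g
    return -score_fn(torch.cat(vals_x), torch.cat(vals_y))

group_sizes = [3, 2, 4]
n = sum(group_sizes)
sim = torch.randn(n, n)
sim = 0.5 * (sim + sim.T)
loss = intra_group_loss(sim, sim, group_sizes, torch.nn.CosineSimilarity(dim=-1))
# Identical inputs give cosine similarity 1, i.e. a loss of -1, as in the test below.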
# Test for BestHits, InterGroupSimilarityLoss and IntraGroupSimilarityLoss
def test_sequence_similarity_losses(
    *,
    group_sizes,
    length_x, length_y, alphabet_size,
    extra_init_kwargs_bh, extra_init_kwargs_loss
):
    similarities = HammingSimilarities(group_sizes=None)
    best_hits = BestHits(group_sizes=group_sizes, **extra_init_kwargs_bh)
    n_samples = sum(group_sizes)

    y = torch.randn(n_samples, length_y, alphabet_size)
    y.scatter_(-1, y.argmax(dim=-1, keepdim=True), 1.)
    similarities_y = similarities(y)

    best_hits.hard_()
    best_hits_y = best_hits(similarities_y)
    best_hits.soft_()

    x = torch.randn(
        n_samples, length_x, alphabet_size,
        requires_grad=True
    )
    x_soft = softmax(x, dim=-1)
    similarities_x = similarities(x_soft)
    best_hits_x = best_hits(similarities_x)

    #### Best hits loss ####
    inter_group_similarity_loss = InterGroupSimilarityLoss(group_sizes=group_sizes, **extra_init_kwargs_loss)
    loss = inter_group_similarity_loss(best_hits_x, best_hits_y)

    assert loss.requires_grad

    # In the following scenario, the loss should be close to -1
    extra_init_kwargs_bh = deepcopy(extra_init_kwargs_bh)
    extra_init_kwargs_bh["tau"] = 1e-4
    best_hits = BestHits(group_sizes=group_sizes, **extra_init_kwargs_bh)
    best_hits_x = best_hits(similarities_y)
    loss = inter_group_similarity_loss(best_hits_x, best_hits_y)

    torch.testing.assert_close(loss, torch.tensor(-1.))

    #### Mirrortree-like loss ####
    intra_group_similarity_loss = IntraGroupSimilarityLoss(**extra_init_kwargs_loss)
    loss = intra_group_similarity_loss(similarities_x, similarities_y)

    assert loss.requires_grad

    # In the following scenario, the loss should be close to -1
    loss = intra_group_similarity_loss(similarities_y, similarities_y)

    torch.testing.assert_close(loss, torch.tensor(-1.))


test_sequence_similarity_losses(
    length_x=3,
    length_y=4,
    alphabet_size=3,
    group_sizes=[3, 2, 4],
    extra_init_kwargs_bh={
        "tau": 0.1,
    },
    extra_init_kwargs_loss={
        "score_fn": torch.nn.CosineSimilarity(dim=-1)
    }
)