# Test for GeneralizedPermutation
def test_generalizedpermutation(*, length, alphabet_size, init_kwargs):
    group_sizes = init_kwargs["group_sizes"]
    n_samples = sum(group_sizes)
    x = torch.randn(n_samples, length, alphabet_size)
    perm = GeneralizedPermutation(**init_kwargs)
    # In soft mode, the layer outputs one differentiable matrix per group
    mats = perm()
    mat_apply = MatrixApply(group_sizes)
    y = mat_apply(x, mats=mats)
    assert y.shape == x.shape
    assert y.requires_grad
    # Switch to hard (argmax-based) permutations in place
    perm.hard_()
    assert perm.mode == "hard"


test_generalizedpermutation(
    length=5,
    alphabet_size=10,
    init_kwargs={
        "group_sizes": [3, 2, 4],
        "fixed_pairings": [[(0, 1)], [(0, 0)], [(1, 0), (2, 3)]],
        "tau": 0.1,
    },
)
def test_batch_perm(shape: tuple[int, int, int, int]):
    perms = torch.randn(*shape)
    x = torch.randn(shape[-2], shape[-1])
    # Index of the maximum of each row of every (batched) matrix in perms
    argmax = perms.argmax(-1)
    # Permute the rows of x for all batch elements at once
    x_permuted_rows = x[argmax]
    # Permute the columns with a batched gather
    index = argmax.view(*argmax.shape[:-1], 1, -1).expand_as(perms)
    output = torch.gather(x_permuted_rows, -1, index)
    # Reference result: conjugate x explicitly, one batch element at a time
    expected = torch.stack([
        torch.stack([
            x[argmax[i, j], :][:, argmax[i, j]] for j in range(shape[1])
        ], dim=0) for i in range(shape[0])
    ], dim=0)
    assert torch.equal(output, expected)


test_batch_perm((2, 5, 4, 4))
Type aliases
IndexPair = tuple[int, int] # Pair of indices
IndexPairsInGroup = list[IndexPair] # Pairs of indices in a group of sequences
IndexPairsInGroups = list[IndexPairsInGroup] # Pairs of indices in groups of sequences

Sinkhorn/matching layer for soft/hard permutations
apply_hard_permutation_batch_to_similarity
apply_hard_permutation_batch_to_similarity (x:torch.Tensor, perms:list[torch.Tensor])
*Conjugate a single similarity matrix by a batch of hard permutations.

Args:
    x: Similarity matrix of shape (D, D).
    perms: List of batches of permutation matrices of shape (…, D, D).

Returns:
    Batch of conjugated matrices of shape (…, D, D).*
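As a minimal illustration of the shapes involved, taking the docstring at face value and passing a single batch of hard permutations, a usage sketch might look like the following (the exact conjugation convention, P x Pᵀ versus Pᵀ x P, is not specified here):

D = 4
x = torch.randn(D, D)
x = (x + x.T) / 2  # symmetric similarity matrix
# A batch of 3 hard D x D permutation matrices, wrapped in a single-element list
perm_batch = torch.stack([torch.eye(D)[torch.randperm(D)] for _ in range(3)])  # (3, D, D)
out = apply_hard_permutation_batch_to_similarity(x=x, perms=[perm_batch])
# out is expected to have shape (3, D, D)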
global_argmax_from_group_argmaxes
global_argmax_from_group_argmaxes (mats:collections.abc.Iterable[torch.Tensor])
PermutationConjugate
PermutationConjugate (group_sizes:collections.abc.Sequence[int])
Conjugate blocks of a square 2D tensor of shape (n_samples, n_samples) by permutation matrices.
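A usage sketch, assuming the forward call mirrors the MatrixApply convention used in the test above (per-group matrices passed via mats=); this is an assumption, not a documented signature:

group_sizes = [3, 2]
n_samples = sum(group_sizes)
sim = torch.randn(n_samples, n_samples)
sim = (sim + sim.T) / 2  # symmetric similarity matrix
# One hard permutation matrix per group
mats = [torch.eye(3)[torch.randperm(3)], torch.eye(2)[torch.randperm(2)]]
conj = PermutationConjugate(group_sizes)
out = conj(sim, mats=mats)  # assumed call convention; expected shape (n_samples, n_samples)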
MatrixApply
MatrixApply (group_sizes:collections.abc.Sequence[int])
Apply matrices to chunks of a tensor of shape (n_samples, length, alphabet_size) and collate the results.
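Minimal usage following the call convention from the test above; with hard permutation matrices, the effect is to re-order samples within each group (exact row/column convention aside):

group_sizes = [2, 3]
x = torch.randn(5, 7, 4)  # (n_samples, length, alphabet_size)
# Hard permutation matrices, one per group
mats = [torch.eye(2)[[1, 0]], torch.eye(3)[[2, 0, 1]]]
y = MatrixApply(group_sizes)(x, mats=mats)
assert y.shape == x.shape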
GeneralizedPermutation
GeneralizedPermutation (group_sizes:collections.abc.Sequence[int], fixed_pairings:Optional[list[list[tuple[int,int]]]]=None, tau:float=1.0, n_iter:int=1, noise:bool=False, noise_factor:float=1.0, noise_std:bool=False, mode:Literal['soft','hard']='soft')
Generalized permutation layer implementing both soft and hard permutations.
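Tying this to the type aliases above: fixed_pairings is an IndexPairsInGroups value, one list of index pairs per group (see the test at the top of this page for a concrete value). A short sketch:

perm = GeneralizedPermutation(
    group_sizes=[3, 2, 4],
    fixed_pairings=[[(0, 1)], [(0, 0)], [(1, 0), (2, 3)]],  # one list of pairs per group
    tau=0.1,
)
mats = perm()   # one square matrix per group, differentiable in "soft" mode
perm.hard_()    # switch to hard (argmax-based) permutations in place
hard_mats = perm()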
Information-theory losses
TwoBodyEntropyLoss
TwoBodyEntropyLoss ()
Differentiable extension of the mean of estimated two-body entropies between all pairs of columns from two one-hot encoded tensors.
MILoss
MILoss ()
Differentiable extension of minus the mean of estimated mutual informations between all pairs of columns from two one-hot encoded tensors.
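MILoss is not exercised by the test below; a minimal sketch, assuming the same two-tensor call convention as TwoBodyEntropyLoss:

logits = torch.randn(1000, 5, 4, requires_grad=True)
x = softmax(logits, dim=-1)                   # (n_samples, length_x, alphabet_size)
y = softmax(torch.randn(1000, 6, 4), dim=-1)  # (n_samples, length_y, alphabet_size)
loss = MILoss()(x, y)  # scalar: minus the mean estimated mutual information
loss.backward()        # gradients flow back to logits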
# Test for TwoBodyEntropyLoss
def test_twobodyentropyloss(
    *,
    n_samples, length_x, length_y, alphabet_size
):
    x = torch.randn(
        n_samples, length_x, alphabet_size,
        requires_grad=True
    )
    y = torch.randn(n_samples, length_y, alphabet_size)
    x_soft = softmax(x, dim=-1)
    y_soft = softmax(y, dim=-1)
    two_body_entropy_loss = TwoBodyEntropyLoss()
    loss = two_body_entropy_loss(x_soft, y_soft)
    assert loss.requires_grad
    # In the following scenario (a nearly one-hot column paired with itself),
    # the loss should be close to log2(alphabet_size)
    x_almost_hard = softmax(x / 1e-5, dim=-1)
    first_x_almost_hard_length_1 = x_almost_hard[:, :1, :]
    loss = two_body_entropy_loss(
        first_x_almost_hard_length_1, first_x_almost_hard_length_1
    )
    torch.testing.assert_close(
        loss, torch.log2(torch.tensor(alphabet_size)), atol=1e-3, rtol=1e-7
    )


test_twobodyentropyloss(
    n_samples=10_000,
    length_x=3,
    length_y=4,
    alphabet_size=3,
)

Sequence similarities (Hamming and Blosum62)
HammingSimilarities
HammingSimilarities (group_sizes:Optional[collections.abc.Sequence[int]]=None, use_dot:bool=True, p:Optional[float]=None)
*Compute Hamming similarities between sequences using differentiable operations.
Optionally, if the sequences are arranged in groups, the computation of similarities can be restricted to within groups. Differentiable operations are used to compute the similarities, which can be either dot products or an L^p distance function.*
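For intuition, on exactly one-hot inputs the dot-product variant reflects the number of positions at which two sequences agree (up to any normalization applied internally). A minimal sketch:

tokens = torch.randint(0, 4, (6, 10))  # 6 sequences of length 10 over a 4-letter alphabet
one_hot = torch.nn.functional.one_hot(tokens, num_classes=4).float()
sims = HammingSimilarities(use_dot=True)(one_hot)  # expected shape (6, 6)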
Blosum62Similarities
Blosum62Similarities (group_sizes:Optional[collections.abc.Sequence[int]]=None, use_dot:bool=True, p:Optional[float]=None, use_scoredist:bool=False, aa_to_int:Optional[dict[str,int]]=None, gaps_as_stars:bool=True)
*Compute Blosum62-based similarities between sequences using differentiable operations.
Optionally, if the sequences are arranged in groups, the computation of similarities can be restricted to within groups. Differentiable operations are used to compute the similarities, which can be either dot products or an L^p distance function.*
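A corresponding sketch with default settings, using alphabet_size=21 as in the test below:

tokens = torch.randint(0, 21, (6, 10))
one_hot = torch.nn.functional.one_hot(tokens, num_classes=21).float()
sims = Blosum62Similarities()(one_hot)  # expected shape (6, 6)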
# Tests for HammingSimilarities and Blosum62Similarities
def test_similarities(
    *,
    cls,
    length, alphabet_size,
    init_kwargs
):
    group_sizes = init_kwargs["group_sizes"]
    n_samples = sum(group_sizes)
    x = torch.randn(
        n_samples, length, alphabet_size,
        requires_grad=True
    )
    x_soft = softmax(x, dim=-1)
    # Without group_sizes, similarities are computed between all pairs of sequences
    _init_kwargs = deepcopy(init_kwargs)
    _init_kwargs["group_sizes"] = None
    similarities = cls(**_init_kwargs)
    out_all = similarities(x_soft)
    assert out_all.shape == (n_samples, n_samples)
    # With group_sizes, the within-group blocks must match the unrestricted result
    similarities = cls(**init_kwargs)
    out = similarities(x_soft)
    for sl in similarities._group_slices:
        assert torch.allclose(
            out[..., sl, sl], out_all[..., sl, sl]
        )


test_similarities(
    cls=HammingSimilarities,
    length=3,
    alphabet_size=10,
    init_kwargs={"group_sizes": [3, 2, 4], "use_dot": False, "p": 1.}
)
test_similarities(
    cls=Blosum62Similarities,
    length=3,
    alphabet_size=21,
    init_kwargs={"group_sizes": [3, 2, 4]}
)

Best hits from similarities
BestHits
BestHits (reciprocal:bool=True, group_sizes:Optional[collections.abc.Sequence[int]]=None, tau:float=0.1, mode:Literal['soft','hard']='soft')
*Compute (reciprocal) best hits within and between groups of sequences, starting from a similarity matrix.
Best hits can be either ‘hard’, in which case they are computed using the argmax, or ‘soft’, in which case they are computed using the softmax with a temperature parameter tau. In both cases, the main diagonal in the similarity matrix is excluded by setting its entries to minus infinity.*
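A usage sketch following the call convention in the test below (a symmetric similarity matrix in, a best-hits matrix of the same shape out; the output shape here is an expectation, not a documented guarantee):

group_sizes = [3, 2]
sim = torch.randn(5, 5)
sim = (sim + sim.T) / 2                      # symmetric similarity matrix
bh = BestHits(group_sizes=group_sizes, tau=0.1)
soft_hits = bh(sim)                          # soft (differentiable) best hits
bh.hard_()                                   # switch to argmax-based best hits in place
hard_hits = bh(sim)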
Losses based on comparing similarity matrices
InterGroupSimilarityLoss
InterGroupSimilarityLoss (group_sizes:collections.abc.Sequence[int], score_fn:Optional[callable]=None)
*Compute a loss that compares similarity matrices restricted to inter-group relationships.
Similarity matrices are expected to be square and symmetric. The loss is computed by comparing the (flattened and concatenated) blocks containing inter-group similarities.*
IntraGroupSimilarityLoss
IntraGroupSimilarityLoss (group_sizes:Optional[collections.abc.Sequence[int]]=None, score_fn:Optional[callable]=None, exclude_diagonal:bool=True)
*Compute a loss that compares similarity matrices restricted to intra-group relationships.
Similarity matrices are expected to be square and symmetric. Their diagonal elements are ignored if exclude_diagonal is set to True. If group_sizes is provided, the loss is computed by comparing the flattened and concatenated upper triangular blocks containing intra-group similarities. Otherwise, the loss is computed by comparing the upper triangular part of the full similarity matrices.*
# Test for BestHits, InterGroupSimilarityLoss and IntraGroupSimilarityLoss
def test_sequence_similarity_losses(
    *,
    group_sizes,
    length_x, length_y, alphabet_size,
    extra_init_kwargs_bh, extra_init_kwargs_loss
):
    similarities = HammingSimilarities(group_sizes=None)
    best_hits = BestHits(group_sizes=group_sizes, **extra_init_kwargs_bh)
    n_samples = sum(group_sizes)
    # Reference tensor y: set the per-position argmax entries to 1 to make it more one-hot-like
    y = torch.randn(n_samples, length_y, alphabet_size)
    y.scatter_(-1, y.argmax(dim=-1, keepdim=True), 1.)
    similarities_y = similarities(y)
    best_hits.hard_()
    best_hits_y = best_hits(similarities_y)
    best_hits.soft_()
    x = torch.randn(
        n_samples, length_x, alphabet_size,
        requires_grad=True
    )
    x_soft = softmax(x, dim=-1)
    similarities_x = similarities(x_soft)
    best_hits_x = best_hits(similarities_x)

    #### Best hits loss ####
    inter_group_similarity_loss = InterGroupSimilarityLoss(
        group_sizes=group_sizes, **extra_init_kwargs_loss
    )
    loss = inter_group_similarity_loss(best_hits_x, best_hits_y)
    assert loss.requires_grad
    # With a very small tau, the soft best hits of similarities_y approach the hard
    # best hits best_hits_y, so the loss should be close to -1
    extra_init_kwargs_bh = deepcopy(extra_init_kwargs_bh)
    extra_init_kwargs_bh["tau"] = 1e-4
    best_hits = BestHits(group_sizes=group_sizes, **extra_init_kwargs_bh)
    best_hits_x = best_hits(similarities_y)
    loss = inter_group_similarity_loss(best_hits_x, best_hits_y)
    torch.testing.assert_close(loss, torch.tensor(-1.))

    #### Mirrortree-like loss ####
    intra_group_similarity_loss = IntraGroupSimilarityLoss(**extra_init_kwargs_loss)
    loss = intra_group_similarity_loss(similarities_x, similarities_y)
    assert loss.requires_grad
    # Comparing a similarity matrix with itself should give a loss close to -1
    loss = intra_group_similarity_loss(similarities_y, similarities_y)
    torch.testing.assert_close(loss, torch.tensor(-1.))


test_sequence_similarity_losses(
    length_x=3,
    length_y=4,
    alphabet_size=3,
    group_sizes=[3, 2, 4],
    extra_init_kwargs_bh={
        "tau": 0.1,
    },
    extra_init_kwargs_loss={
        "score_fn": torch.nn.CosineSimilarity(dim=-1)
    }
)