model

DiffPaSS modules for optimizing permutations and computing soft scores

Type aliases

IndexPair = tuple[int, int]  # Pair of indices
IndexPairsInGroup = list[IndexPair]  # Pairs of indices in a group of sequences
IndexPairsInGroups = list[IndexPairsInGroup]  # Pairs of indices in groups of sequences

Sinkhorn/matching layer for soft/hard permutations


source

apply_hard_permutation_batch_to_similarity

 apply_hard_permutation_batch_to_similarity (x:torch.Tensor,
                                             perms:list[torch.Tensor])

*Conjugate a single similarity matrix by a batch of hard permutations.

Args: perms: List of batches of permutation matrices of shape (…, D, D). x: Similarity matrix of shape (D, D).

Returns: Batch of conjugated matrices of shape (…, D, D).*


source

global_argmax_from_group_argmaxes

 global_argmax_from_group_argmaxes
                                    (mats:collections.abc.Iterable[torch.Tensor])

source

PermutationConjugate

 PermutationConjugate (group_sizes:collections.abc.Sequence[int])

Conjugate blocks of a square 2D tensor of shape (n_samples, n_samples) by permutation matrices.


source

MatrixApply

 MatrixApply (group_sizes:collections.abc.Sequence[int])

Apply matrices to chunks of a tensor of shape (n_samples, length, alphabet_size) and collate the results.


source

GeneralizedPermutation

 GeneralizedPermutation (group_sizes:collections.abc.Sequence[int],
                         fixed_pairings:Optional[list[list[tuple[int,int]]]]=None,
                         tau:float=1.0, n_iter:int=1,
                         noise:bool=False, noise_factor:float=1.0,
                         noise_std:bool=False,
                         mode:Literal['soft','hard']='soft')

Generalized permutation layer implementing both soft and hard permutations.

# Test for GeneralizedPermutation

def test_generalizedpermutation(*, length, alphabet_size, init_kwargs):
    """Smoke-test ``GeneralizedPermutation`` combined with ``MatrixApply``.

    A random tensor of shape ``(sum(group_sizes), length, alphabet_size)`` is
    transformed group-wise by the matrices produced by calling the layer with
    no arguments. The result must keep the input's shape and be attached to
    the autograd graph. The layer is then switched to hard mode and its
    ``mode`` attribute is checked.

    Args:
        length: Second dimension of the random input tensor.
        alphabet_size: Last dimension of the random input tensor.
        init_kwargs: Keyword arguments for ``GeneralizedPermutation``; must
            contain ``"group_sizes"``.
    """
    sizes = init_kwargs["group_sizes"]
    total = sum(sizes)

    inputs = torch.randn(total, length, alphabet_size)
    layer = GeneralizedPermutation(**init_kwargs)
    # Compute the matrices before building MatrixApply, as in the original.
    matrices = layer()
    applier = MatrixApply(sizes)
    applied = applier(inputs, mats=matrices)

    assert applied.shape == inputs.shape
    assert applied.requires_grad

    # Switching to hard mode must be reflected in the public `mode` flag.
    layer.hard_()
    assert layer.mode == "hard"


# Three groups of sizes 3, 2 and 4, each with its own list of fixed pairings.
test_generalizedpermutation(
    init_kwargs={
        "group_sizes": [3, 2, 4],
        "fixed_pairings": [[(0, 1)], [(0, 0)], [(1, 0), (2, 3)]],
        "tau": 0.1,
    },
    alphabet_size=10,
    length=5,
)


def test_batch_perm(shape: tuple[int, ...] = (2, 5, 4, 4)):
    """Check the vectorized conjugation of a matrix by a batch of hard index maps.

    Random matrices of shape ``(*batch, D, D)`` are turned into index maps via
    ``argmax`` over the last dimension. These are hard "permutations", though
    not necessarily bijective for random inputs. A random ``(D, D)`` matrix
    ``x`` is then conjugated by every index map ``p`` at once, using advanced
    indexing on the rows and ``torch.gather`` on the columns. The result is
    compared exactly with the per-map reference ``x[p][:, p]``.

    Args:
        shape: Shape ``(*batch, D, D)`` of the batch of random matrices. Any
            number of leading batch dimensions (including none) is supported.

    Raises:
        ValueError: If ``shape`` does not end with two equal dimensions.
    """
    if len(shape) < 2 or shape[-2] != shape[-1]:
        raise ValueError(
            f"shape must end with two equal dimensions (D, D), got {tuple(shape)}"
        )
    perms = torch.randn(*shape)
    x = torch.randn(shape[-2], shape[-1])

    # argmax[..., k] is the row of x that ends up in row k.
    argmax = perms.argmax(-1)
    # Rows: x_permuted_rows[..., k, :] = x[argmax[..., k], :]
    x_permuted_rows = x[argmax]
    # Columns: broadcast argmax across rows so that
    # output[..., k, l] = x_permuted_rows[..., k, argmax[..., l]]
    index = argmax.view(*argmax.shape[:-1], 1, -1).expand_as(perms)
    output = torch.gather(x_permuted_rows, -1, index)

    # Reference: flatten all leading batch dims into one, conjugate x by each
    # index map separately, then restore the batch shape.
    flat_argmax = argmax.reshape(-1, shape[-1])
    expected = torch.stack(
        [x[p][:, p] for p in flat_argmax], dim=0
    ).reshape(tuple(shape))

    assert torch.equal(output, expected)


test_batch_perm(shape=(2, 5, 4, 4))

Information-theory losses


source

TwoBodyEntropyLoss

 TwoBodyEntropyLoss ()

Differentiable extension of the mean of estimated two-body entropies between all pairs of columns from two one-hot encoded tensors.


source

MILoss

 MILoss ()

Differentiable extension of minus the mean of estimated mutual informations between all pairs of columns from two one-hot encoded tensors.

# Test for TwoBodyEntropyLoss

def test_twobodyentropyloss(
        *,
        n_samples, length_x, length_y, alphabet_size
):
    """Test ``TwoBodyEntropyLoss`` on soft inputs and on a near-hard limit case.

    Checks that:

    1. the loss between softmax-normalized random ``x`` and ``y`` is attached
       to the autograd graph (``x`` is created with ``requires_grad=True``);
    2. the loss between the first column of a nearly one-hot version of ``x``
       and itself is within ``atol=1e-3`` of ``log2(alphabet_size)``.

    Args:
        n_samples: First dimension of both random tensors.
        length_x: Number of columns of the differentiable tensor ``x``.
        length_y: Number of columns of the tensor ``y``.
        alphabet_size: Size of the last (softmax-normalized) dimension.
    """
    x = torch.randn(
        n_samples, length_x, alphabet_size,
        requires_grad=True
    )
    y = torch.randn(n_samples, length_y, alphabet_size)
    x_soft = softmax(x, dim=-1)
    y_soft = softmax(y, dim=-1)
    two_body_entropy_loss = TwoBodyEntropyLoss()
    loss = two_body_entropy_loss(x_soft, y_soft)

    assert loss.requires_grad

    # In the following scenario, the loss should be close to log2(alphabet_size).
    # Dividing the logits by 1e-5 makes the softmax numerically almost one-hot.
    # The argmax of i.i.d. Gaussian logits is uniform over the alphabet, and
    # the joint entropy of a column with itself equals its own entropy, so for
    # large n_samples the estimate approaches log2(alphabet_size).
    x_almost_hard = softmax(x / 1e-5, dim=-1)
    first_x_almost_hard_length_1 = x_almost_hard[:, :1, :]
    loss = two_body_entropy_loss(
        first_x_almost_hard_length_1, first_x_almost_hard_length_1
    )

    torch.testing.assert_close(
        loss, torch.log2(torch.tensor(alphabet_size)), atol=1e-3, rtol=1e-7
    )


test_twobodyentropyloss(
    alphabet_size=3,
    length_y=4,
    length_x=3,
    n_samples=10_000,
)

Sequence similarities (Hamming and Blosum62)


source

HammingSimilarities

 HammingSimilarities (group_sizes:Optional[collections.abc.Sequence[int]]=None,
                      use_dot:bool=True, p:Optional[float]=None)

*Compute Hamming similarities between sequences using differentiable operations.

Optionally, if the sequences are arranged in groups, the computation of similarities can be restricted to within groups. Differentiable operations are used to compute the similarities, which can be either dot products or an L^p distance function.*


source

Blosum62Similarities

 Blosum62Similarities (group_sizes:Optional[collections.abc.Sequence[int]]=None,
                       use_dot:bool=True, p:Optional[float]=None,
                       use_scoredist:bool=False,
                       aa_to_int:Optional[dict[str,int]]=None,
                       gaps_as_stars:bool=True)

*Compute Blosum62-based similarities between sequences using differentiable operations.

Optionally, if the sequences are arranged in groups, the computation of similarities can be restricted to within groups. Differentiable operations are used to compute the similarities, which can be either dot products or an L^p distance function.*

# Tests for HammingSimilarities and Blosum62Similarities

def test_similarities(
        *,
        cls,
        length, alphabet_size,
        init_kwargs
):
    """Check that grouped similarities match the ungrouped ones on each group.

    The similarity module ``cls`` is built twice from ``init_kwargs``: once
    with ``group_sizes`` replaced by ``None`` (all pairs), and once as given.
    The ungrouped output must be ``(n_samples, n_samples)``. On every diagonal
    block selected by the grouped module's ``_group_slices``, the two outputs
    must agree (``torch.allclose``).

    Args:
        cls: Similarity module class to test.
        length: Second dimension of the random input tensor.
        alphabet_size: Last dimension of the random input tensor.
        init_kwargs: Keyword arguments for ``cls``; must contain
            ``"group_sizes"``. It is not modified.
    """
    n_samples = sum(init_kwargs["group_sizes"])

    logits = torch.randn(
        n_samples, length, alphabet_size,
        requires_grad=True
    )
    probs = softmax(logits, dim=-1)

    # Deep copy so the caller's dict is left untouched.
    ungrouped_kwargs = {**deepcopy(init_kwargs), "group_sizes": None}
    full_matrix = cls(**ungrouped_kwargs)(probs)
    assert full_matrix.shape == (n_samples, n_samples)

    grouped = cls(**init_kwargs)
    grouped_matrix = grouped(probs)

    for block in grouped._group_slices:
        lhs = grouped_matrix[..., block, block]
        rhs = full_matrix[..., block, block]
        assert torch.allclose(lhs, rhs)


# Hamming similarities from an L^1 distance (p=1) instead of dot products.
test_similarities(
    init_kwargs={"group_sizes": [3, 2, 4], "use_dot": False, "p": 1.},
    alphabet_size=10,
    length=3,
    cls=HammingSimilarities,
)

# Blosum62 similarities with default settings, over 21 symbols.
test_similarities(
    init_kwargs={"group_sizes": [3, 2, 4]},
    alphabet_size=21,
    length=3,
    cls=Blosum62Similarities,
)

Best hits from similarities


source

BestHits

 BestHits (reciprocal:bool=True,
           group_sizes:Optional[collections.abc.Sequence[int]],
           tau:float=0.1, mode:Literal['soft','hard']='soft')

*Compute (reciprocal) best hits within and between groups of sequences, starting from a similarity matrix.

Best hits can be either ‘hard’, in which case they are computed using the argmax, or ‘soft’, in which case they are computed using the softmax with a temperature parameter tau. In both cases, the main diagonal in the similarity matrix is excluded by setting its entries to minus infinity.*

Losses based on comparing similarity matrices


source

InterGroupSimilarityLoss

 InterGroupSimilarityLoss (group_sizes:collections.abc.Sequence[int],
                           score_fn:Optional[callable]=None)

*Compute a loss that compares similarity matrices restricted to inter-group relationships.

Similarity matrices are expected to be square and symmetric. The loss is computed by comparing the (flattened and concatenated) blocks containing inter-group similarities.*


source

IntraGroupSimilarityLoss

 IntraGroupSimilarityLoss (group_sizes:Optional[collections.abc.Sequence[int]]=None,
                           score_fn:Optional[callable]=None,
                           exclude_diagonal:bool=True)

*Compute a loss that compares similarity matrices restricted to intra-group relationships.

Similarity matrices are expected to be square and symmetric. Their diagonal elements are ignored if exclude_diagonal is set to True. If group_sizes is provided, the loss is computed by comparing the flattened and concatenated upper triangular blocks containing intra-group similarities. Otherwise, the loss is computed by comparing the upper triangular part of the full similarity matrices.*

# Test for BestHits, InterGroupSimilarityLoss and IntraGroupSimilarityLoss

def test_sequence_similarity_losses(
        *,
        group_sizes,
        length_x, length_y, alphabet_size,
        extra_init_kwargs_bh, extra_init_kwargs_loss
):
    """Test ``BestHits`` together with the inter- and intra-group losses.

    A target tensor ``y`` and a differentiable soft tensor ``x`` are built.
    Hamming similarities (without grouping) are computed for both, and then:

    * Best-hits loss: soft best hits of ``x`` are compared with hard best hits
      of ``y`` using ``InterGroupSimilarityLoss``. The loss must be attached
      to the autograd graph. It must also be close to -1 when soft best hits
      at temperature ``tau=1e-4`` are computed from ``y``'s own similarities.
    * Mirrortree-like loss: the raw similarity matrices are compared using
      ``IntraGroupSimilarityLoss``. The loss must be attached to the autograd
      graph, and close to -1 when ``y``'s similarities are compared with
      themselves.

    NOTE(review): the "-1" expectations presumably assume that ``score_fn``
    (passed via ``extra_init_kwargs_loss``, e.g. cosine similarity) is a
    similarity and that the losses return its negation. Confirm against the
    loss classes.

    Args:
        group_sizes: Sizes of the groups of sequences; their sum is the number
            of sequences.
        length_x: Second dimension of ``x``.
        length_y: Second dimension of ``y``.
        alphabet_size: Last dimension of both ``x`` and ``y``.
        extra_init_kwargs_bh: Extra keyword arguments for ``BestHits``. Not
            mutated; a deep copy is modified below.
        extra_init_kwargs_loss: Extra keyword arguments for both loss classes.
    """
    similarities = HammingSimilarities(group_sizes=None)
    best_hits = BestHits(group_sizes=group_sizes, **extra_init_kwargs_bh)
    n_samples = sum(group_sizes)

    y = torch.randn(n_samples, length_y, alphabet_size)
    # NOTE(review): scatter_ only writes 1. at each argmax position. The other
    # entries keep their Gaussian values, so y is not strictly one-hot. This
    # probably keeps similarities_y free of ties, which the "close to -1"
    # best-hits check below may rely on. Confirm before making y one-hot.
    y.scatter_(-1, y.argmax(dim=-1, keepdim=True), 1.)
    similarities_y = similarities(y)
    # Target best hits are computed in hard (argmax) mode. The same module is
    # switched back to soft (softmax with temperature tau) mode for x.
    best_hits.hard_()
    best_hits_y = best_hits(similarities_y)
    best_hits.soft_()

    x = torch.randn(
        n_samples, length_x, alphabet_size,
        requires_grad=True
    )
    x_soft = softmax(x, dim=-1)
    similarities_x = similarities(x_soft)
    best_hits_x = best_hits(similarities_x)

    #### Best hits loss ####
    inter_group_similarity_loss = InterGroupSimilarityLoss(group_sizes=group_sizes, **extra_init_kwargs_loss)
    loss = inter_group_similarity_loss(best_hits_x, best_hits_y)

    assert loss.requires_grad

    # In the following scenario, the loss should be close to -1.
    # A very low temperature makes soft best hits computed from y's own
    # similarities approach the hard target ones. The kwargs are deep-copied
    # so the caller's dict is not mutated.
    extra_init_kwargs_bh = deepcopy(extra_init_kwargs_bh)
    extra_init_kwargs_bh["tau"] = 1e-4
    best_hits = BestHits(group_sizes=group_sizes, **extra_init_kwargs_bh)
    best_hits_x = best_hits(similarities_y)
    loss = inter_group_similarity_loss(best_hits_x, best_hits_y)

    torch.testing.assert_close(loss, torch.tensor(-1.))

    #### Mirrortree-like loss ####
    # Built without group_sizes, so the whole (upper-triangular) similarity
    # matrices are compared.
    intra_group_similarity_loss = IntraGroupSimilarityLoss(**extra_init_kwargs_loss)
    loss = intra_group_similarity_loss(similarities_x, similarities_y)

    assert loss.requires_grad

    # In the following scenario, the loss should be close to -1
    # (identical similarity matrices).
    loss = intra_group_similarity_loss(similarities_y, similarities_y)

    torch.testing.assert_close(loss, torch.tensor(-1.))


# Cosine similarity as score_fn, so perfect agreement gives a loss of -1.
test_sequence_similarity_losses(
    group_sizes=[3, 2, 4],
    alphabet_size=3,
    length_y=4,
    length_x=3,
    extra_init_kwargs_loss={
        "score_fn": torch.nn.CosineSimilarity(dim=-1),
    },
    extra_init_kwargs_bh={
        "tau": 0.1,
    },
)