train

Perform optimization using DiffPaSS models

Type aliases

# A single pair of sequence indices.
IndexPair = tuple[int, int]
# Every index pair belonging to one group (e.g. one species) of sequences.
IndexPairsInGroup = list[IndexPair]
# Index pairs for several groups: one inner list per group.
IndexPairsInGroups = list[IndexPairsInGroup]

source

InformationPairing

 InformationPairing (group_sizes:collections.abc.Sequence[int],
                     fixed_pairings:Optional[list[list[tuple[int,int]]]]=None,
                     permutation_cfg:Optional[dict[str,Any]]=None,
                     information_measure:Literal['MI','TwoBodyEntropy']='TwoBodyEntropy')

DiffPaSS model for information-theoretic pairing of multiple sequence alignments (MSAs).

Type Default Details
group_sizes Sequence Number of sequences in each group (e.g. species) of the two MSAs
fixed_pairings Optional None If not None, fixed pairings between groups, of the form [[(i1, j1), (i2, j2), …], …] where (i1, j1) are the indices of the first fixed pair in the first group to be paired, etc.
permutation_cfg Optional None If not None, configuration dictionary containing init parameters for the internal GeneralizedPermutation object to compute soft/hard permutations
information_measure Literal TwoBodyEntropy Information-theoretic measure to use. For hard permutations, these two measures are equivalent
def test_information_bootstrap():
    """Check that ``InformationPairing.fit_bootstrap`` recovers a known pairing.

    Two MSAs are built so that ``y`` is a token-wise bijective relabelling of
    ``x``. The correct pairing between ``x`` and ``y`` is therefore the
    identity. The model is fitted on a within-group shuffled copy of ``x``,
    and the test asserts that the final hard loss it reaches is within a small
    tolerance of the hard loss of the identity permutation on unshuffled data.
    """
    # Fixed seed: the assertion uses a numeric tolerance on randomly generated
    # data, so an unseeded run could fail intermittently.
    torch.manual_seed(0)

    # Data: two highly correlated MSAs
    n_classes = 3
    length = 5
    size_each_group = 10
    n_groups = 10
    # Define first MSA group by group
    x_tok_by_group = [torch.randint(0, n_classes, (size_each_group, length)) for _ in range(n_groups)]
    # Within-group shuffling to control for algorithmic biases towards identity permutation
    x_tok_by_group_shuffle = [x[torch.randperm(size_each_group)] for x in x_tok_by_group]
    x_tok = torch.cat(x_tok_by_group, dim=0)
    x_tok_shuffle = torch.cat(x_tok_by_group_shuffle, dim=0)
    # (t + 1) % n_classes is a bijection on {0, ..., n_classes - 1}, so y is
    # fully determined by x.
    y_tok = (x_tok + 1) % n_classes
    # Pin num_classes: otherwise one_hot infers the alphabet size from the
    # largest token present, and x / y could end up with different widths.
    dtype = torch.get_default_dtype()
    x = torch.nn.functional.one_hot(x_tok, num_classes=n_classes).to(dtype)
    x_shuffle = torch.nn.functional.one_hot(x_tok_shuffle, num_classes=n_classes).to(dtype)
    y = torch.nn.functional.one_hot(y_tok, num_classes=n_classes).to(dtype)

    group_sizes = [size_each_group] * n_groups

    # Model
    model = InformationPairing(group_sizes=group_sizes)
    results = model.fit_bootstrap(x_shuffle, y)
    hard_loss_identity_perm = model.compute_losses_identity_perm(x, y)["hard"]

    # Check that the hard loss of the optimized permutation is close to the ground truth.
    # NOTE(review): which bootstrap iteration / step hard_losses[-2][-1] refers to
    # depends on the layout of fit_bootstrap's results — TODO confirm.
    assert np.abs(results.hard_losses[-2][-1] - hard_loss_identity_perm) < 1e-4

test_information_bootstrap()

source

BestHitsPairing

 BestHitsPairing (group_sizes:collections.abc.Sequence[int],
                  fixed_pairings:Optional[list[list[tuple[int,int]]]]=None,
                  permutation_cfg:Optional[dict[str,Any]]=None,
                  similarity_kind:Literal['Hamming','Blosum62']='Hamming',
                  similarities_cfg:Optional[dict[str,Any]]=None,
                  compute_in_group_best_hits:bool=True,
                  best_hits_cfg:Optional[dict[str,Any]]=None,
                  similarities_comparison_loss:Optional[<built-in function callable>]=None,
                  compare_soft_best_hits_to_hard:bool=True)

DiffPaSS model for pairing of multiple sequence alignments (MSAs) by aligning their orthology networks, constructed using (reciprocal) best hits.

Type Default Details
group_sizes Sequence Number of sequences in each group (e.g. species) of the two MSAs
fixed_pairings Optional None If not None, fixed pairings between groups, of the form [[(i1, j1), (i2, j2), …], …] where (i1, j1) are the indices of the first fixed pair in the first group to be paired, etc.
permutation_cfg Optional None If not None, configuration dictionary containing init parameters for the internal GeneralizedPermutation object to compute soft/hard permutations
similarity_kind Literal Hamming (Smoothly extended) similarity metric to use on all pairs of aligned sequences
similarities_cfg Optional None If not None, configuration dictionary containing init parameters for the internal HammingSimilarities or Blosum62Similarities object to compute similarity matrices
compute_in_group_best_hits bool True Whether to also compute best hits within each group (in addition to between different groups)
best_hits_cfg Optional None If not None, configuration dictionary containing init parameters for the internal BestHits object to compute soft/hard (reciprocal) best hits
similarities_comparison_loss Optional None If not None, custom callable to compute the differentiable loss between the soft/hard best hits matrices of the two MSAs
compare_soft_best_hits_to_hard bool True Whether to compare the soft best hits from the MSA to permute (x) to the hard (if True) or soft (if False) best hits from the reference MSA (y)
def test_besthits_bootstrap():
    """Check that ``BestHitsPairing.fit_bootstrap`` approaches a known pairing.

    ``y`` is a token-wise bijective relabelling of ``x``, so the correct
    pairing is the identity. The model is fitted on a within-group shuffled
    copy of ``x``. The test asserts that the ratio between the final hard loss
    reached and the identity-permutation hard loss on unshuffled data exceeds
    0.7.
    """
    # Fixed seed: the threshold assertion runs on random data, so an unseeded
    # run could fail intermittently.
    torch.manual_seed(0)

    # Data: two highly correlated MSAs
    n_classes = 3
    length = 100
    size_each_group = 10
    n_groups = 10
    # Define first MSA group by group
    x_tok_by_group = [torch.randint(0, n_classes, (size_each_group, length)) for _ in range(n_groups)]
    # Within-group shuffling to control for algorithmic biases towards identity permutation
    x_tok_by_group_shuffle = [x[torch.randperm(size_each_group)] for x in x_tok_by_group]
    x_tok = torch.cat(x_tok_by_group, dim=0)
    x_tok_shuffle = torch.cat(x_tok_by_group_shuffle, dim=0)
    # Bijective relabelling of x's tokens: y is fully determined by x.
    y_tok = (x_tok + 1) % n_classes
    # Pin num_classes so one_hot does not infer the alphabet size from the data.
    dtype = torch.get_default_dtype()
    x = torch.nn.functional.one_hot(x_tok, num_classes=n_classes).to(dtype)
    x_shuffle = torch.nn.functional.one_hot(x_tok_shuffle, num_classes=n_classes).to(dtype)
    y = torch.nn.functional.one_hot(y_tok, num_classes=n_classes).to(dtype)

    group_sizes = [size_each_group] * n_groups

    # Model
    model = BestHitsPairing(group_sizes=group_sizes)
    results = model.fit_bootstrap(x_shuffle, y)
    target_hard_loss = model.compute_losses_identity_perm(x, y)["hard"]

    # Check that the hard loss of the optimized permutation is close to the ground truth.
    # NOTE(review): the ratio test presumes both losses share a sign — TODO confirm
    # the loss sign convention and the meaning of hard_losses[-2][-1].
    assert results.hard_losses[-2][-1] / target_hard_loss > 0.7

test_besthits_bootstrap()

source

MirrortreePairing

 MirrortreePairing (group_sizes:collections.abc.Sequence[int],
                    fixed_pairings:Optional[list[list[tuple[int,int]]]]=None,
                    permutation_cfg:Optional[dict[str,Any]]=None,
                    similarity_kind:Literal['Hamming','Blosum62']='Hamming',
                    similarities_cfg:Optional[dict[str,Any]]=None,
                    similarities_comparison_loss:Optional[<built-in function callable>]=None)

DiffPaSS model for pairing of multiple sequence alignments (MSAs) by aligning their sequence distance networks as in the Mirrortree method.

Type Default Details
group_sizes Sequence Number of sequences in each group (e.g. species) of the two MSAs
fixed_pairings Optional None If not None, fixed pairings between groups, of the form [[(i1, j1), (i2, j2), …], …] where (i1, j1) are the indices of the first fixed pair in the first group to be paired, etc.
permutation_cfg Optional None If not None, configuration dictionary containing init parameters for the internal GeneralizedPermutation object to compute soft/hard permutations
similarity_kind Literal Hamming (Smoothly extended) similarity metric to use on all pairs of aligned sequences
similarities_cfg Optional None If not None, configuration dictionary containing init parameters for the internal HammingSimilarities or Blosum62Similarities object to compute similarity matrices
similarities_comparison_loss Optional None If not None, custom callable to compute the differentiable loss between the similarity matrices of the two MSAs. Default: IntraGroupSimilarityLoss
def test_mirrortree_bootstrap():
    """Check that ``MirrortreePairing.fit_bootstrap`` approaches a known pairing.

    ``y`` is a token-wise bijective relabelling of ``x``, so the correct
    pairing is the identity. The model is fitted on a within-group shuffled
    copy of ``x``. The test asserts that the ratio between the final hard loss
    reached and the identity-permutation hard loss on unshuffled data exceeds
    0.95.
    """
    # Fixed seed: the threshold assertion runs on random data, so an unseeded
    # run could fail intermittently.
    torch.manual_seed(0)

    # Data: two highly correlated MSAs
    n_classes = 3
    length = 100
    size_each_group = 10
    n_groups = 10
    # Define first MSA group by group
    x_tok_by_group = [torch.randint(0, n_classes, (size_each_group, length)) for _ in range(n_groups)]
    # Within-group shuffling to control for algorithmic biases towards identity permutation
    x_tok_by_group_shuffle = [x[torch.randperm(size_each_group)] for x in x_tok_by_group]
    x_tok = torch.cat(x_tok_by_group, dim=0)
    x_tok_shuffle = torch.cat(x_tok_by_group_shuffle, dim=0)
    # Bijective relabelling of x's tokens: y is fully determined by x.
    y_tok = (x_tok + 1) % n_classes
    # Pin num_classes so one_hot does not infer the alphabet size from the data.
    dtype = torch.get_default_dtype()
    x = torch.nn.functional.one_hot(x_tok, num_classes=n_classes).to(dtype)
    x_shuffle = torch.nn.functional.one_hot(x_tok_shuffle, num_classes=n_classes).to(dtype)
    y = torch.nn.functional.one_hot(y_tok, num_classes=n_classes).to(dtype)

    group_sizes = [size_each_group] * n_groups

    # Model
    model = MirrortreePairing(group_sizes=group_sizes)
    results = model.fit_bootstrap(x_shuffle, y)
    target_hard_loss = model.compute_losses_identity_perm(x, y)["hard"]

    # Check that the hard loss of the optimized permutation is close to the ground truth.
    # NOTE(review): the ratio test presumes both losses share a sign — TODO confirm
    # the loss sign convention and the meaning of hard_losses[-2][-1].
    assert results.hard_losses[-2][-1] / target_hard_loss > 0.95

test_mirrortree_bootstrap()

source

GraphAlignment

 GraphAlignment (group_sizes:collections.abc.Sequence[int],
                 fixed_pairings:Optional[list[list[tuple[int,int]]]]=None,
                 permutation_cfg:Optional[dict[str,Any]]=None,
                 comparison_loss:Optional[<built-in function callable>]=None)

DiffPaSS model for general graph alignment starting from the weighted adjacency matrices of two graphs.

Type Default Details
group_sizes Sequence Number of graph nodes in each group (e.g. species), assumed the same between the two graphs to align
fixed_pairings Optional None If not None, fixed pairings between groups, of the form [[(i1, j1), (i2, j2), …], …] where (i1, j1) are the indices of the first fixed pair in the first group to be paired, etc.
permutation_cfg Optional None If not None, configuration dictionary containing init parameters for the internal GeneralizedPermutation object to compute soft/hard permutations. Soft/hard permutations P act on adjacency matrices X via P @ X @ P.T
comparison_loss Optional None If not None, custom callable to compute the differentiable loss between the soft/hard-permuted adjacency matrix of graph x and the adjacency matrix of graph y. Defaults to dot product between all upper triangular elements
def test_graph_alignment_bootstrap():
    """Check that ``GraphAlignment.fit_bootstrap`` approaches a known alignment.

    ``y`` is an exact copy of a random positive weighted adjacency matrix
    ``x``, so the correct alignment is the identity. The model is fitted on a
    copy of ``x`` permuted within each group. The test asserts that the ratio
    between the final hard loss reached and the identity-permutation hard loss
    exceeds 0.95.
    """
    # Fixed seed: the threshold assertion runs on random data, so an unseeded
    # run could fail intermittently.
    torch.manual_seed(0)

    # Data: two identical weighted adjacency matrices
    size_each_group = 10
    n_groups = 10
    n_samples = size_each_group * n_groups
    # exp of Gaussian noise gives strictly positive weights; randn already
    # returns the default dtype, so no extra cast is needed.
    x = torch.exp(torch.randn((n_samples, n_samples)))
    y = x.clone()
    # Within-group shuffling to control for algorithmic biases towards identity permutation.
    # Each matrix is a random permutation matrix: row i holds a single 1 at column perm[i]
    # (advanced indexing copies, so the list entries are independent tensors).
    identity = torch.eye(size_each_group, dtype=x.dtype, device=x.device)
    rand_perm_mats = [identity[torch.randperm(size_each_group)] for _ in range(n_groups)]
    x_shuffle = apply_hard_permutation_batch_to_similarity(x=x, perms=rand_perm_mats)

    group_sizes = [size_each_group] * n_groups

    # Model
    model = GraphAlignment(group_sizes=group_sizes)
    results = model.fit_bootstrap(x_shuffle, y)
    target_hard_loss = model.compute_losses_identity_perm(x, y)["hard"]

    # Check that the hard loss of the optimized permutation is close to the ground truth.
    # NOTE(review): the ratio test presumes both losses share a sign — TODO confirm
    # the loss sign convention and the meaning of hard_losses[-2][-1].
    assert results.hard_losses[-2][-1] / target_hard_loss > 0.95

test_graph_alignment_bootstrap()