gumbel_sinkhorn_ops
Throughout, a permutation \(0 \to p_0, 1 \to p_1, \ldots, n-1 \to p_{n-1}\) is encoded by the permutation matrix \(P = (P_{ij})_{i,j=0}^{n-1}\) with \(P_{ij} = 1\) if \(j = p_i\) and \(P_{ij} = 0\) otherwise.
In NumPy/PyTorch, P[arange(n), p] is therefore identically equal to 1, and p can be recovered from P with p = P.argmax(-1).
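A small NumPy example of this encoding, using an arbitrarily chosen permutation of four elements. It builds P from p and then recovers p from P:

```python
import numpy as np

# A permutation of n = 4 elements, encoded as i -> p[i]
p = np.array([2, 0, 3, 1])
n = len(p)

# Permutation matrix: P[i, j] = 1 iff j == p[i]
P = np.zeros((n, n), dtype=int)
P[np.arange(n), p] = 1

# The entries (i, p[i]) are all ones ...
assert np.all(P[np.arange(n), p] == 1)
# ... and the row-wise argmax recovers p
recovered = P.argmax(-1)
print(recovered)  # [2 0 3 1]
```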
gumbel_noise_like
gumbel_noise_like (log_alpha:torch.Tensor, noise_factor:float=1.0, noise_std:bool=False)
Generate rescaled Gumbel noise with the same shape as log_alpha. The noise is rescaled by noise_factor or, if noise_std is True, by noise_factor times the standard deviation of log_alpha.
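Standard Gumbel noise is usually drawn by inverse-CDF sampling, \(g = -\log(-\log u)\) with \(u \sim \mathrm{Uniform}(0, 1)\). The NumPy sketch below mirrors the documented rescaling; it is not the library's PyTorch code, and the name gumbel_noise_like_np, the rng argument, and taking the standard deviation over all entries of log_alpha are assumptions:

```python
import numpy as np

def gumbel_noise_like_np(log_alpha, noise_factor=1.0, noise_std=False, rng=None):
    """Sketch: Gumbel(0, 1) noise shaped like log_alpha, then rescaled."""
    rng = np.random.default_rng() if rng is None else rng
    # u in [tiny, 1) avoids log(0); -log(-log(u)) is then standard Gumbel
    u = rng.uniform(np.finfo(float).tiny, 1.0, size=log_alpha.shape)
    gumbel = -np.log(-np.log(u))
    # Assumption: noise_std uses the std over all entries of log_alpha
    scale = noise_factor * (log_alpha.std() if noise_std else 1.0)
    return scale * gumbel

rng = np.random.default_rng(0)
log_alpha = rng.normal(size=(2, 3, 3))
noise = gumbel_noise_like_np(log_alpha, noise_factor=0.5, rng=rng)
print(noise.shape)  # (2, 3, 3)
```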
unbias_by_randperms
unbias_by_randperms (func:callable)
Decorator to unbias func with two random permutations.
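The docstring does not spell out the mechanism. One plausible reading, which is an assumption here rather than the library's confirmed behavior, is to shuffle the rows and columns of the input with two independent random permutation matrices, call func, and undo both shuffles on the output. For a permutation-equivariant func this changes nothing except how ties are broken, so a solver that always resolves ties the same way no longer favors particular rows or columns. In this NumPy sketch, unbias_by_randperms_np and the toy first_max_onehot solver are hypothetical:

```python
import functools
import numpy as np

def unbias_by_randperms_np(func):
    """Hypothetical sketch: conjugate func by two random permutation matrices."""
    @functools.wraps(func)
    def wrapper(log_alpha, *args, **kwargs):
        n = log_alpha.shape[-1]
        rng = np.random.default_rng()
        # Two independent random permutation matrices (shuffled identity rows)
        r1 = np.eye(n)[rng.permutation(n)]
        r2 = np.eye(n)[rng.permutation(n)]
        # Shuffle rows with r1 and columns with r2, then call func
        out = func(r1 @ log_alpha @ r2.T, *args, **kwargs)
        # Undo both shuffles so the output refers to the original indices
        return r1.T @ out @ r2
    return wrapper

def first_max_onehot(scores):
    # Toy "solver": one-hot of each row's argmax; ties go to the lowest index
    return np.eye(scores.shape[-1])[scores.argmax(-1)]

unbiased = unbias_by_randperms_np(first_max_onehot)
ties = np.zeros((4, 4))
print(first_max_onehot(ties).argmax(-1))  # [0 0 0 0]
print(unbiased(ties).argmax(-1))          # all rows agree on one random column
```

Without ties, the wrapped and unwrapped toy solvers return the same matrix; only the tie-breaking becomes random.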
randperm_mat_like
randperm_mat_like (log_alpha:torch.Tensor)
Generate a random \(n \times n\) permutation matrix, i.e. one with shape log_alpha.shape[-2:]. Assumes log_alpha has shape (batch_size, n, n).
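A NumPy sketch of the same idea, assuming the output is a single (n, n) matrix rather than one per batch element; randperm_mat_like_np and its rng argument are illustrative names, not part of the library:

```python
import numpy as np

def randperm_mat_like_np(log_alpha, rng=None):
    """Sketch: random (n, n) permutation matrix for log_alpha of shape (batch_size, n, n)."""
    rng = np.random.default_rng() if rng is None else rng
    n = log_alpha.shape[-1]
    # Reordering the rows of the identity yields a permutation matrix
    return np.eye(n, dtype=log_alpha.dtype)[rng.permutation(n)]

log_alpha = np.zeros((8, 5, 5))
R = randperm_mat_like_np(log_alpha, rng=np.random.default_rng(0))
print(R.shape)  # (5, 5)
```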
log_sinkhorn_norm
log_sinkhorn_norm (log_alpha:torch.Tensor, n_iter:int=20)
Iterative Sinkhorn normalization in log space, for numerical stability.
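Sinkhorn normalization alternately rescales rows and columns so that both sum to one; in log space each rescaling becomes the subtraction of a log-sum-exp. A NumPy sketch, assuming rows are normalized before columns and that the result is returned in log space (both assumptions); log_sinkhorn_norm_np is a hypothetical name:

```python
import numpy as np

def _logsumexp(x, axis):
    # Stable log(sum(exp(x))) along axis, keeping dims for broadcasting
    m = x.max(axis=axis, keepdims=True)
    return m + np.log(np.exp(x - m).sum(axis=axis, keepdims=True))

def log_sinkhorn_norm_np(log_alpha, n_iter=20):
    """Sketch: alternate row and column normalization in log space."""
    for _ in range(n_iter):
        # exp(log_alpha) now has rows summing to 1
        log_alpha = log_alpha - _logsumexp(log_alpha, axis=-1)
        # exp(log_alpha) now has columns summing to 1
        log_alpha = log_alpha - _logsumexp(log_alpha, axis=-2)
    return log_alpha

rng = np.random.default_rng(0)
log_alpha = 0.5 * rng.normal(size=(3, 4, 4))  # a batch of three 4 x 4 matrices
out = np.exp(log_sinkhorn_norm_np(log_alpha, n_iter=200))
print(np.allclose(out.sum(-1), 1), np.allclose(out.sum(-2), 1))  # True True
```

Because only the shifted values x - m are exponentiated, entries such as 1000 cause no overflow, whereas np.exp(1000.0) is inf.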
sinkhorn_norm
sinkhorn_norm (alpha:torch.Tensor, n_iter:int=20)
Iterative Sinkhorn normalization of non-negative matrices.
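Applied directly to a non-negative matrix, Sinkhorn normalization alternately divides each row and each column by its sum. A NumPy sketch (sinkhorn_norm_np is hypothetical); it assumes no row or column of alpha is entirely zero, since that would cause a division by zero:

```python
import numpy as np

def sinkhorn_norm_np(alpha, n_iter=20):
    """Sketch: alternate row and column normalization of a non-negative matrix."""
    for _ in range(n_iter):
        # Divide each row by its sum so that rows sum to 1
        alpha = alpha / alpha.sum(axis=-1, keepdims=True)
        # Divide each column by its sum so that columns sum to 1
        alpha = alpha / alpha.sum(axis=-2, keepdims=True)
    return alpha

alpha = np.array([[1.0, 2.0],
                  [3.0, 4.0]])
out = sinkhorn_norm_np(alpha, n_iter=50)
print(out.sum(0), out.sum(1))
```

For this \(2 \times 2\) input the limit is \(\begin{pmatrix} a & 1-a \\ 1-a & a \end{pmatrix}\) with \(a = \sqrt{2/3}/(1 + \sqrt{2/3}) \approx 0.449\): row and column scalings leave the ratio \(\alpha_{00}\alpha_{11}/(\alpha_{01}\alpha_{10}) = 4/6\) unchanged, and a bistochastic \(2 \times 2\) matrix with that ratio must satisfy \(a^2/(1-a)^2 = 2/3\).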
gumbel_sinkhorn
gumbel_sinkhorn (log_alpha:torch.Tensor, tau:Union[float,torch.Tensor]=1.0, n_iter:int=10, noise:bool=False, noise_factor:float=1.0, noise_std:bool=False)
Gumbel-Sinkhorn operator with a temperature parameter tau. Given arbitrary square matrices, it outputs bistochastic matrices that are close to permutation matrices when tau is small.
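A self-contained NumPy sketch of the operator: optionally add Gumbel noise, divide by tau, run log-space Sinkhorn normalization, and exponentiate. It is not the library's implementation; gumbel_sinkhorn_np is hypothetical, and it omits noise_std and tensor-valued tau for brevity:

```python
import numpy as np

def _logsumexp(x, axis):
    # Stable log(sum(exp(x))) along axis, keeping dims for broadcasting
    m = x.max(axis=axis, keepdims=True)
    return m + np.log(np.exp(x - m).sum(axis=axis, keepdims=True))

def gumbel_sinkhorn_np(log_alpha, tau=1.0, n_iter=10, noise=False,
                       noise_factor=1.0, rng=None):
    """Sketch: (noisy) log_alpha / tau -> log-space Sinkhorn -> bistochastic matrix."""
    if noise:
        rng = np.random.default_rng() if rng is None else rng
        u = rng.uniform(np.finfo(float).tiny, 1.0, size=log_alpha.shape)
        log_alpha = log_alpha + noise_factor * -np.log(-np.log(u))
    # Dividing by a small temperature sharpens the result toward a permutation
    log_alpha = log_alpha / tau
    for _ in range(n_iter):
        log_alpha = log_alpha - _logsumexp(log_alpha, axis=-1)  # rows
        log_alpha = log_alpha - _logsumexp(log_alpha, axis=-2)  # columns
    return np.exp(log_alpha)

p = np.array([1, 2, 0])
log_alpha = np.eye(3)[p]  # scores favoring the permutation p
soft = gumbel_sinkhorn_np(log_alpha, tau=1.0)
sharp = gumbel_sinkhorn_np(log_alpha, tau=0.05)
print(soft.max(), sharp.max())
print(sharp.argmax(-1))  # [1 2 0]
```

With tau = 1 the largest entries are only \(e/(e+2) \approx 0.576\), while tau = 0.05 pushes them to nearly 1.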
matching
matching (log_alpha:torch.Tensor)
np_matching
np_matching (cost:numpy.ndarray)
Find the assignment matrix with maximum total cost using the Hungarian algorithm, and return it in dense format.
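The same linear assignment problem can be solved with SciPy's scipy.optimize.linear_sum_assignment, whose maximize=True flag yields a matching of maximum total cost (recent SciPy versions use a Jonker-Volgenant variant rather than the textbook Hungarian algorithm, but the optimal total cost is the same). A sketch with the hypothetical name np_matching_sketch, converting the index pairs into a dense matrix in the encoding described at the top of this page:

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def np_matching_sketch(cost):
    """Sketch: dense 0/1 assignment matrix maximizing the total cost."""
    row_ind, col_ind = linear_sum_assignment(cost, maximize=True)
    P = np.zeros(cost.shape)
    # Row row_ind[k] is matched to column col_ind[k]
    P[row_ind, col_ind] = 1.0
    return P

cost = np.array([[1, 5, 2],
                 [4, 1, 3],
                 [2, 3, 6]])
P = np_matching_sketch(cost)
print(P.argmax(-1))      # [1 0 2]
print((P * cost).sum())  # 15.0
```

Checking all six permutations by hand confirms that 5 + 4 + 6 = 15 is the largest achievable total.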
gumbel_matching
gumbel_matching (log_alpha:torch.Tensor, noise:bool=False, noise_factor:float=1.0, noise_std:bool=False, unbias_lsa:bool=False)
Gumbel-matching operator, i.e. the solution of the linear assignment problem with optional Gumbel noise.
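Perturbing the scores with Gumbel noise before solving the assignment is the perturb-and-MAP idea: each solve still returns an exact permutation matrix, while repeated solves with fresh noise explore other permutations. A NumPy/SciPy sketch for a single (n, n) matrix; gumbel_matching_np is hypothetical and omits noise_std and unbias_lsa:

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def gumbel_matching_np(log_alpha, noise=False, noise_factor=1.0, rng=None):
    """Sketch: solve the linear assignment problem on (optionally) Gumbel-perturbed scores."""
    if noise:
        rng = np.random.default_rng() if rng is None else rng
        u = rng.uniform(np.finfo(float).tiny, 1.0, size=log_alpha.shape)
        log_alpha = log_alpha + noise_factor * -np.log(-np.log(u))
    row_ind, col_ind = linear_sum_assignment(log_alpha, maximize=True)
    P = np.zeros(log_alpha.shape)
    P[row_ind, col_ind] = 1.0
    return P

log_alpha = np.array([[2.0, 0.0, 0.0],
                      [0.0, 0.0, 2.0],
                      [0.0, 2.0, 0.0]])
print(gumbel_matching_np(log_alpha).argmax(-1))  # [0 2 1]

# With noise, every sample is still a permutation, most often the noiseless optimum
rng = np.random.default_rng(0)
samples = [tuple(gumbel_matching_np(log_alpha, noise=True, rng=rng).argmax(-1).tolist())
           for _ in range(200)]
```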
inverse_permutation
inverse_permutation (x:torch.Tensor, mats:torch.Tensor)
When mats contains permutation matrices, exchange the rows of x using the inverse(s) of the permutation(s) encoded in mats.
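With the encoding used throughout, P @ x reorders the rows of x so that row i of the result is x[p[i]], and the inverse permutation is encoded by the transpose of P. A NumPy sketch under that assumption (inverse_permutation_np is hypothetical; it also accepts batched mats of shape (..., n, n)):

```python
import numpy as np

def inverse_permutation_np(x, mats):
    """Sketch: reorder the rows of x by the inverse of each permutation in mats."""
    # For a permutation matrix P, the inverse is P^T (swap the last two axes)
    return np.swapaxes(mats, -1, -2) @ x

p = np.array([2, 0, 3, 1])
P = np.eye(4)[p]                    # P[i, p[i]] = 1
x = np.arange(8.0).reshape(4, 2)
shuffled = P @ x                    # row i of shuffled is x[p[i]]
restored = inverse_permutation_np(shuffled, P)
print(np.array_equal(restored, x))  # True
```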