gumbel_sinkhorn_ops

Gumbel-Sinkhorn and Gumbel-matching operators

Throughout, a permutation \(0 \to p_0, 1 \to p_1, ..., n-1 \to p_{n-1}\) is encoded by the permutation matrix \(P = (p_{ij})_{i,j=0}^{n-1}\) with \(p_{ij} = 1\) if and only if \(j = p_i\), and \(0\) otherwise.

In NumPy/PyTorch, this means that every entry of P[arange(n), p] equals 1, and p can be recovered from P as p = P.argmax(-1).
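Here is a small NumPy illustration of this convention. The permutation p is an arbitrary example.

```python
import numpy as np

# Example permutation of n = 4 elements: 0 -> 2, 1 -> 0, 2 -> 3, 3 -> 1
p = np.array([2, 0, 3, 1])
n = len(p)

# Encode p as P with P[i, j] = 1 iff j == p[i]
P = np.zeros((n, n), dtype=int)
P[np.arange(n), p] = 1

# Each row and each column contains exactly one 1
assert (P.sum(axis=0) == 1).all() and (P.sum(axis=1) == 1).all()

# Decode: the column holding the single 1 in row i is p[i]
p_recovered = P.argmax(-1)
assert (p_recovered == p).all()
```

With this convention, (P @ x)[i] = x[p_i] for a vector x. Multiplying by P on the left therefore gathers the entries (or rows) of x in the order given by p.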


gumbel_noise_like

 gumbel_noise_like (log_alpha:torch.Tensor, noise_factor:float=1.0,
                    noise_std:bool=False)

Generate rescaled Gumbel noise with the same shape as log_alpha. The noise is rescaled by noise_factor or, if noise_std is True, by noise_factor times the standard deviation of log_alpha.
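The following is a minimal sketch of how such noise can be generated, not the module's own code. It samples standard Gumbel noise by inverse transform, G = -log(-log U) with U uniform on [0, 1). Three details are assumptions: the name gumbel_noise_like_sketch, the small eps guard, and taking the standard deviation over the whole tensor when noise_std is True.

```python
import torch


def gumbel_noise_like_sketch(log_alpha: torch.Tensor, noise_factor: float = 1.0,
                             noise_std: bool = False, eps: float = 1e-20) -> torch.Tensor:
    # Inverse-transform sampling: if U ~ Uniform(0, 1), then -log(-log U) is standard Gumbel
    u = torch.rand_like(log_alpha)
    gumbel = -torch.log(-torch.log(u + eps) + eps)  # eps guards against log(0)
    scale = noise_factor
    if noise_std:
        # Assumption: the standard deviation is taken over all entries of log_alpha
        scale = noise_factor * log_alpha.std()
    return scale * gumbel


torch.manual_seed(0)
noise = gumbel_noise_like_sketch(torch.zeros(1000, 1000))
# The sample mean should be close to the Gumbel mean (Euler-Mascheroni constant, about 0.5772)
print(noise.mean().item())
```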


unbias_by_randperms

 unbias_by_randperms (func:callable)

Decorator to unbias func with two random permutations.
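The docstring does not spell out the mechanism. One plausible reading is sketched below as a hypothesis, not the module's implementation. In this reading, func (for example a linear assignment solver) runs on log_alpha after its rows and columns are shuffled by two random permutations, and the shuffle is then undone on the output. This removes any systematic preference a deterministic solver has when breaking ties. The names unbias_by_randperms_sketch and always_identity are hypothetical.

```python
import functools

import torch


def unbias_by_randperms_sketch(func):
    """Hypothetical: conjugate func by random row and column shuffles."""
    @functools.wraps(func)
    def wrapper(log_alpha: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        n = log_alpha.shape[-1]
        r = torch.randperm(n, device=log_alpha.device)  # row shuffle
        c = torch.randperm(n, device=log_alpha.device)  # column shuffle
        # shuffled[..., i, j] = log_alpha[..., r[i], c[j]]
        shuffled = log_alpha[..., r, :][..., :, c]
        out = func(shuffled, *args, **kwargs)
        # argsort of a permutation is its inverse; this gives result[..., r[i], c[j]] = out[..., i, j]
        r_inv, c_inv = torch.argsort(r), torch.argsort(c)
        return out[..., r_inv, :][..., :, c_inv]
    return wrapper


def always_identity(log_alpha: torch.Tensor) -> torch.Tensor:
    # Toy, deliberately biased "solver": on a constant matrix every assignment
    # ties, and this one always returns the identity
    n = log_alpha.shape[-1]
    return torch.eye(n).expand_as(log_alpha).clone()


torch.manual_seed(0)
ties = torch.zeros(4, 4)
unbiased = unbias_by_randperms_sketch(always_identity)
P = unbiased(ties)  # a random permutation matrix instead of always the identity
```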


randperm_mat_like

 randperm_mat_like (log_alpha:torch.Tensor)

Generate a random permutation matrix of shape log_alpha.shape[-2:], i.e. (n, n). Assumes log_alpha has shape (batch_size, n, n).
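A minimal sketch, assuming a single (unbatched) n x n matrix is returned and encoded with the P[i, p_i] = 1 convention described at the top of this page. The function name is hypothetical.

```python
import torch


def randperm_mat_like_sketch(log_alpha: torch.Tensor) -> torch.Tensor:
    n = log_alpha.shape[-1]  # assumes the last two dimensions are (n, n)
    p = torch.randperm(n, device=log_alpha.device)
    # Encode p with the module's convention: P[i, p[i]] = 1
    P = torch.zeros(n, n, dtype=log_alpha.dtype, device=log_alpha.device)
    P[torch.arange(n, device=log_alpha.device), p] = 1
    return P


log_alpha = torch.randn(8, 5, 5)
P = randperm_mat_like_sketch(log_alpha)
```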


log_sinkhorn_norm

 log_sinkhorn_norm (log_alpha:torch.Tensor, n_iter:int=20)

Iterative Sinkhorn normalization in log space, for numerical stability.
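A minimal sketch of log-space Sinkhorn normalization, assuming the usual alternation of row and column normalization (rows first here, which is a choice). The stability comes from torch.logsumexp, which computes log(sum(exp(...))) without forming the exponentials directly. The last two lines contrast it with a naive exp-then-normalize step on scores whose exponentials overflow. The function name is hypothetical.

```python
import torch


def log_sinkhorn_norm_sketch(log_alpha: torch.Tensor, n_iter: int = 20) -> torch.Tensor:
    for _ in range(n_iter):
        # Row step: log_alpha[i, j] -= log(sum_k exp(log_alpha[i, k]))
        log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=-1, keepdim=True)
        # Column step: log_alpha[i, j] -= log(sum_k exp(log_alpha[k, j]))
        log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=-2, keepdim=True)
    return log_alpha  # exp() of this is approximately doubly stochastic


big = torch.tensor([[1000.0, 0.0], [0.0, 1000.0]])
stable = log_sinkhorn_norm_sketch(big).exp()           # close to the 2x2 identity
naive = big.exp() / big.exp().sum(-1, keepdim=True)    # exp(1000) = inf, so inf / inf = nan
```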


sinkhorn_norm

 sinkhorn_norm (alpha:torch.Tensor, n_iter:int=20)

Iterative Sinkhorn normalization of non-negative matrices.
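The same iteration can be written directly on non-negative matrices, as in the minimal sketch below. The name and the eps guard against all-zero rows or columns are assumptions. For positive alpha and in exact arithmetic, it matches the log-space version applied to log(alpha). It is cheaper but fragile when alpha comes from exponentiating large-magnitude scores; for example, exp(-200) underflows to 0 in float32.

```python
import torch


def sinkhorn_norm_sketch(alpha: torch.Tensor, n_iter: int = 20,
                         eps: float = 1e-12) -> torch.Tensor:
    # alpha is assumed non-negative; eps (an assumption) avoids 0 / 0 on all-zero rows or columns
    for _ in range(n_iter):
        alpha = alpha / (alpha.sum(dim=-1, keepdim=True) + eps)  # rows sum to 1
        alpha = alpha / (alpha.sum(dim=-2, keepdim=True) + eps)  # columns sum to 1
    return alpha


torch.manual_seed(0)
alpha = torch.rand(2, 4, 4, dtype=torch.float64) + 0.1
ds = sinkhorn_norm_sketch(alpha, n_iter=500)
```

If alpha contains zeros, the iteration need not converge to a doubly stochastic matrix at all; that depends on where the zeros lie.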


gumbel_sinkhorn

 gumbel_sinkhorn (log_alpha:torch.Tensor,
                  tau:Union[float,torch.Tensor]=1.0, n_iter:int=10,
                  noise:bool=False, noise_factor:float=1.0,
                  noise_std:bool=False)

Gumbel-Sinkhorn operator with temperature parameter tau. Given arbitrary square matrices, it outputs approximately bistochastic matrices (only n_iter Sinkhorn iterations are run), which are close to permutation matrices when tau is small.
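A minimal sketch of the standard Gumbel-Sinkhorn construction. It assumes the operator computes exp(log_sinkhorn_norm((log_alpha + noise) / tau)). noise_std is omitted for brevity, and the log-space helper is redefined so the block is self-contained. Both function names are hypothetical.

```python
import torch


def _log_sinkhorn_norm(log_alpha: torch.Tensor, n_iter: int) -> torch.Tensor:
    # Alternate row and column normalization in log space
    for _ in range(n_iter):
        log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=-1, keepdim=True)
        log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=-2, keepdim=True)
    return log_alpha


def gumbel_sinkhorn_sketch(log_alpha: torch.Tensor, tau=1.0, n_iter: int = 10,
                           noise: bool = False, noise_factor: float = 1.0,
                           eps: float = 1e-20) -> torch.Tensor:
    if noise:
        # Perturb the scores with rescaled standard Gumbel noise, -log(-log U)
        u = torch.rand_like(log_alpha)
        log_alpha = log_alpha - noise_factor * torch.log(-torch.log(u + eps) + eps)
    # A lower tau sharpens the scores before normalization
    return torch.exp(_log_sinkhorn_norm(log_alpha / tau, n_iter))


scores = torch.tensor([[0.0, 5.0, 1.0],
                       [4.0, 0.0, 0.0],
                       [1.0, 0.0, 6.0]])
soft = gumbel_sinkhorn_sketch(scores, tau=0.5, n_iter=50)
```

The highest-scoring assignment for scores is 0 -> 1, 1 -> 0, 2 -> 2 (total 15). With tau = 0.5, each row of soft puts almost all of its mass on that entry. Raising tau spreads the mass toward the uniform matrix.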


matching

 matching (log_alpha:torch.Tensor)

np_matching

 np_matching (cost:numpy.ndarray)

Find the assignment matrix with maximum total cost using the Hungarian algorithm, and return it in dense format.
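A minimal sketch using SciPy's linear_sum_assignment with maximize=True. This is an assumption about the backend, not a statement of what np_matching calls. SciPy's documentation describes its solver as a modified Jonker-Volgenant algorithm rather than the textbook Hungarian method, but both return an optimal assignment. The function name is hypothetical.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def np_matching_sketch(cost: np.ndarray) -> np.ndarray:
    # maximize=True turns SciPy's default minimization into maximization
    row_ind, col_ind = linear_sum_assignment(cost, maximize=True)
    # Dense 0/1 matrix with P[i, j] = 1 iff row i is assigned to column j
    P = np.zeros_like(cost)
    P[row_ind, col_ind] = 1
    return P


cost = np.array([[0.0, 5.0, 1.0],
                 [4.0, 0.0, 0.0],
                 [1.0, 0.0, 6.0]])
P = np_matching_sketch(cost)
```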


gumbel_matching

 gumbel_matching (log_alpha:torch.Tensor, noise:bool=False,
                  noise_factor:float=1.0, noise_std:bool=False,
                  unbias_lsa:bool=False)

Gumbel-matching operator, i.e. the solution of the linear assignment problem for log_alpha, optionally perturbed by Gumbel noise.
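A minimal batched sketch, assuming the operator adds rescaled Gumbel noise and then solves one maximum-weight assignment per matrix on the CPU. SciPy is assumed as the solver, and noise_std and unbias_lsa are omitted. The result is piecewise constant in log_alpha, so no gradient flows through it. The function name is hypothetical.

```python
import numpy as np
import torch
from scipy.optimize import linear_sum_assignment


def gumbel_matching_sketch(log_alpha: torch.Tensor, noise: bool = False,
                           noise_factor: float = 1.0, eps: float = 1e-20) -> torch.Tensor:
    if noise:
        # Perturb the scores with rescaled standard Gumbel noise, -log(-log U)
        u = torch.rand_like(log_alpha)
        log_alpha = log_alpha - noise_factor * torch.log(-torch.log(u + eps) + eps)
    scores = log_alpha.detach().cpu().numpy()
    out = np.zeros_like(scores)
    # Iterate over all leading (batch) indices; for a 2-D input this is the single index ()
    for idx in np.ndindex(scores.shape[:-2]):
        rows, cols = linear_sum_assignment(scores[idx], maximize=True)
        out[idx][rows, cols] = 1  # out[idx] is a view, so this writes into out
    return torch.from_numpy(out).to(log_alpha.device)


scores = torch.tensor([[0.0, 5.0, 1.0],
                       [4.0, 0.0, 0.0],
                       [1.0, 0.0, 6.0]])
hard = gumbel_matching_sketch(scores)
batch = gumbel_matching_sketch(torch.stack([scores, scores.T]), noise=True)
```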


inverse_permutation

 inverse_permutation (x:torch.Tensor, mats:torch.Tensor)

When mats contains permutation matrices, permute the rows of x by the inverse(s) of the permutation(s) encoded in mats.
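A minimal sketch, assuming the inverse is applied by multiplying with the transpose. A permutation matrix is orthogonal, so its transpose is its inverse. The function name is hypothetical. The block also works with batched mats of shape (batch_size, n, n) through matmul broadcasting.

```python
import torch


def inverse_permutation_sketch(x: torch.Tensor, mats: torch.Tensor) -> torch.Tensor:
    # (P^T @ x)[j] = x[i] where p_i = j, i.e. row i of x is moved to row p_i
    return mats.transpose(-2, -1) @ x


p = torch.tensor([2, 0, 3, 1])
P = torch.zeros(4, 4)
P[torch.arange(4), p] = 1  # P[i, p_i] = 1
x = torch.tensor([[0.0], [1.0], [2.0], [3.0]])

shuffled = P @ x                                     # rows gathered as x[p_i]
restored = inverse_permutation_sketch(shuffled, P)   # undoes the shuffle
```

The matmul form also accepts the soft, approximately bistochastic outputs of gumbel_sinkhorn. For those it acts only as a soft analogue, not an exact inverse.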