Source code for multiviewae.base.distributions

import torch
from torch.distributions import Normal, kl_divergence, Laplace, Bernoulli
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.distributions.utils import broadcast_all
from torch.nn.functional import binary_cross_entropy
from .constants import EPS
import torch.nn.functional as F

def compute_log_alpha(mu, logvar):
    # clamp because dropout rate p in 0-99%, where p = alpha/(alpha+1)
    return (logvar - 2 * torch.log(torch.abs(mu) + 1e-8)).clamp(min=-8, max=8)

[docs]class Default(): """Artificial distribution designed for data with unspecified distribution. Used so that log_likelihood and _sample methods can be called by model class. Args: x (list): List of input data. """ def __init__( self, **kwargs, ): self.x = kwargs['x']
[docs] def log_likelihood(self, x): """calculates the mean squared error between input data and reconstruction. Args: x (torch.Tensor): data reconstruction. Returns: torch.Tensor: Negative mean squared error. """ logits, x = broadcast_all(self.x, x) return - (logits - x)**2
def rsample(self): raise NotImplementedError def kl_divergence(self): raise NotImplementedError def sparse_kl_divergence(self): raise NotImplementedError def _sample(self, training=False, return_mean=True): return self.x
[docs]class Categorical(): """Artificial distribution designed for categorical data. Used so that log_likelihood and _sample methods can be called by model class. Args: x (list): List of input data. """ def __init__( self, **kwargs, ): self.x = kwargs['x']
[docs] def log_likelihood(self, x, eps=1e-6): """calculates the k-class cross entropy between input data and reconstruction. Args: x (torch.Tensor): data reconstruction. Returns: torch.Tensor: Negative k-class cross entropy. """ log_input = F.log_softmax(self.x + eps, dim=-1) ce = x * log_input return ce
def rsample(self): raise NotImplementedError def kl_divergence(self): raise NotImplementedError def sparse_kl_divergence(self): raise NotImplementedError def _sample(self, training=False, return_mean=True): return self.x
[docs]class Normal(Normal): """Univariate normal distribution. Inherits from torch.distributions.Normal. Args: loc (int, torch.Tensor): Mean of distribution. scale (int, torch.Tensor): Standard deviation of distribution. """ def __init__( self, **kwargs, ): self.loc = kwargs['loc'] if 'logvar' in kwargs: self.logvar = kwargs['logvar'] self.scale = kwargs['logvar'].mul(0.5).exp_() elif 'scale' in kwargs: self.scale = kwargs['scale'] if not isinstance(self.scale, torch.Tensor): self.scale = torch.tensor(self.scale) self.logvar = 2 * torch.log(self.scale) super().__init__(loc=self.loc, scale=self.scale) @property def variance(self): return self.scale.pow(2) def kl_divergence(self, other): logvar0 = self.logvar mu0 = self.loc logvar1 = other.logvar mu1 = other.loc return -0.5 * (1 - logvar0.exp()/logvar1.exp() - (mu0-mu1).pow(2)/logvar1.exp() + logvar0 - logvar1)
[docs] def sparse_kl_divergence(self): """ Implementation from: https://github.com/senya-ashukha/variational-dropout-sparsifies-dnn/blob/master/KL%20approximation.ipynb """ mu = self.loc logvar = torch.log(self.variance) log_alpha = compute_log_alpha(mu, logvar) k1, k2, k3 = 0.63576, 1.8732, 1.48695 neg_KL = ( k1 * torch.sigmoid(k2 + k3 * log_alpha) - 0.5 * torch.log1p(torch.exp(-log_alpha)) - k1 ) return -neg_KL
def log_likelihood(self, x): return self.log_prob(x) def _sample(self, *kwargs, training=False, return_mean=True): if training: return self.rsample(*kwargs) if return_mean: return self.loc return self.rsample()
[docs]class MultivariateNormal(MultivariateNormal): """Multivariate normal distribution with diagonal covariance matrix. Inherits from torch.distributions.multivariate_normal.MultivariateNormal. Args: loc (list, torch.Tensor): Mean of distribution. scale (int, torch.Tensor): Standard deviation of distribution. """ def __init__( self, **kwargs ): self.loc = torch.as_tensor(kwargs['loc']) if 'logvar' in kwargs: self.logvar = torch.as_tensor(kwargs['logvar']) self.scale = torch.exp(0.5 * kwargs['logvar']) + EPS elif 'scale' in kwargs: self.scale = torch.as_tensor(kwargs['scale']) if not isinstance(self.scale, torch.Tensor): self.scale = torch.tensor(self.scale) self.logvar = 2 * torch.log(self.scale) #used when fitting encoder/decoder distribution or prior distribution with different mean and SD values self.covariance_matrix = torch.diag_embed(self.scale) super().__init__(loc=self.loc, covariance_matrix=self.covariance_matrix) @property def variance(self): return self.scale.pow(2) def kl_divergence(self, other): kl = kl_divergence(torch.distributions.multivariate_normal.MultivariateNormal( \ loc=self.loc, covariance_matrix=self.covariance_matrix), other) return torch.unsqueeze(kl,-1) def sparse_kl_divergence(self): mu = self.loc logvar = torch.log(self.variance) log_alpha = compute_log_alpha(mu, logvar) k1, k2, k3 = 0.63576, 1.8732, 1.48695 neg_KL = ( k1 * torch.sigmoid(k2 + k3 * log_alpha) - 0.5 * torch.log1p(torch.exp(-log_alpha)) - k1 ) return -neg_KL def log_likelihood(self, x): ll = self.log_prob(x) return torch.unsqueeze(ll,-1) def _sample(self, *kwargs, training=False, return_mean=True): if training: return self.rsample(*kwargs) if return_mean: return self.loc return self.rsample()
[docs]class Bernoulli(Bernoulli): """Bernoulli distribution. Inherits from torch.distributions.Bernoulli. Args: x (list): List of input data. """ def __init__( self, **kwargs, ): x = kwargs['x'] super().__init__(logits=x) def log_likelihood(self, target): return self.log_prob(target)
[docs] def rsample(self): raise NotImplementedError
def kl_divergence(self): raise NotImplementedError def sparse_kl_divergence(self): raise NotImplementedError def _sample(self, training=False, return_mean=True): return self.sample()
[docs]class ApproxBernoulli(): """Artificial distribution designed for (approximately) Bernoulli distributed data. The data isn't restricted to bernoulli distribution, this class is designed as a wrapper for the log_likelihood() method which is required for the multiview models. Args: x (list): List of input data. """ def __init__( self, **kwargs, ): self.x = kwargs['x'] def log_likelihood(self, target): x, target = broadcast_all(self.x, target) x = torch.sigmoid(x) bce = binary_cross_entropy(x, target, reduction='none') return -bce def rsample(self): raise NotImplementedError def kl_divergence(self): raise NotImplementedError def sparse_kl_divergence(self): raise NotImplementedError def _sample(self, training=False, return_mean=True): return torch.sigmoid(self.x)
[docs]class Laplace(Laplace): """Laplace distribution. Inherits from torch.distributions.Laplace. Args: loc (list, torch.Tensor): Mean of distribution. scale (int, torch.Tensor): Standard deviation of distribution. """ def __init__( self, **kwargs ): if 'loc' in kwargs: self.loc = torch.as_tensor(kwargs['loc']) elif 'x' in kwargs: self.loc = torch.as_tensor(kwargs['x']) if 'logvar' in kwargs: self.logvar = torch.as_tensor(kwargs['logvar']) if 'with_softmax' in kwargs and kwargs['with_softmax']: self.scale = (F.softmax(self.logvar, dim=-1) * self.logvar.size(-1) + EPS).to(self.loc.device) else: self.scale = torch.exp(0.5 * self.logvar) + EPS elif 'scale' in kwargs: self.scale = torch.as_tensor(kwargs['scale']) else: self.scale = torch.tensor(0.75).to(self.loc.device) super().__init__(loc=self.loc, scale=self.scale) def kl_divergence(self, other): kl = kl_divergence(torch.distributions.laplace.Laplace( \ loc=self.loc, scale=self.scale), other) #check this works return torch.unsqueeze(kl,-1) def sparse_kl_divergence(self): raise NotImplementedError def log_likelihood(self, x): return self.log_prob(x) def _sample(self, *kwargs, training=False, return_mean=True): if training: return self.rsample(*kwargs) if return_mean: return self.loc return self.rsample()