Distributions

class multiviewae.base.distributions.Default(**kwargs)

Artificial distribution for data whose distribution is unspecified. It allows the model class to call the log_likelihood and _sample methods on such data.

Parameters

x (list) – List of input data.

log_likelihood(x)

Calculates the mean squared error between input data and reconstruction.

Parameters

x (torch.Tensor) – Data reconstruction.

Returns

Negative mean squared error.

Return type

torch.Tensor
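A minimal sketch of the negative mean squared error described above. `DefaultSketch` is a hypothetical stand-in that stores the input data passed at construction; averaging over every element is an assumption, and the library may reduce the error differently.

```python
import torch
import torch.nn.functional as F


class DefaultSketch:
    """Hypothetical stand-in for Default: stores the observed data x."""

    def __init__(self, x: torch.Tensor):
        self.x = x

    def log_likelihood(self, recon: torch.Tensor) -> torch.Tensor:
        # Negative mean squared error between the stored data and the reconstruction.
        # The mean over all elements is an assumed reduction.
        return -F.mse_loss(recon, self.x, reduction="mean")


data = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
recon = torch.tensor([[1.0, 2.0], [3.0, 2.0]])
# Squared errors are 0, 0, 0 and 4, so the mean is 1 and the log-likelihood is -1.
print(DefaultSketch(data).log_likelihood(recon))
```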

class multiviewae.base.distributions.Categorical(**kwargs)

Artificial distribution for categorical data. It allows the model class to call the log_likelihood and _sample methods on such data.

Parameters

x (list) – List of input data.

log_likelihood(x, eps=1e-06)

Calculates the k-class cross entropy between input data and reconstruction.

Parameters

x (torch.Tensor) – Data reconstruction.

Returns

Negative k-class cross entropy.

Return type

torch.Tensor
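A minimal sketch of a negative k-class cross entropy, assuming the stored data are one-hot rows and the reconstruction holds per-class probabilities. `CategoricalSketch` is hypothetical, and treating `eps` as a guard that keeps `log()` finite for zero probabilities is an assumption about its role.

```python
import torch


class CategoricalSketch:
    """Hypothetical stand-in for Categorical: stores one-hot encoded data x."""

    def __init__(self, x: torch.Tensor):
        self.x = x  # shape (batch, k), one-hot rows (assumption)

    def log_likelihood(self, recon: torch.Tensor, eps: float = 1e-06) -> torch.Tensor:
        # recon: predicted class probabilities, shape (batch, k) (assumption).
        # Negative cross entropy per sample: sum_k x_k * log(p_k + eps).
        return (self.x * torch.log(recon + eps)).sum(dim=-1)


x = torch.tensor([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]])
probs = torch.tensor([[0.2, 0.7, 0.1], [0.5, 0.25, 0.25]])
# Each entry is the log-probability assigned to the true class (about log 0.7 and log 0.5).
ll = CategoricalSketch(x).log_likelihood(probs)
```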

class multiviewae.base.distributions.Normal(**kwargs)

Univariate normal distribution. Inherits from torch.distributions.Normal.

Parameters
  • loc (int, torch.Tensor) – Mean of distribution.

  • scale (int, torch.Tensor) – Standard deviation of distribution.

property variance

Returns the variance of the distribution.
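The variance is the square of the scale (standard deviation). The check below uses the parent class torch.distributions.Normal and assumes the multiviewae subclass keeps the same definition.

```python
import torch
from torch.distributions import Normal

# A univariate normal with mean 0 and standard deviation 2.
dist = Normal(loc=torch.tensor(0.0), scale=torch.tensor(2.0))

# variance = scale ** 2 = 4
print(dist.variance)
```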

sparse_kl_divergence()

Approximate KL divergence for sparse variational dropout. Implementation from: https://github.com/senya-ashukha/variational-dropout-sparsifies-dnn/blob/master/KL%20approximation.ipynb
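A sketch of the KL approximation from the linked notebook (Molchanov et al., 2017): -KL ≈ k1·sigmoid(k2 + k3·log α) − 0.5·log(1 + 1/α) − k1, with k1 = 0.63576, k2 = 1.87320, k3 = 1.48695. Computing the dropout rate as α = scale² / loc² per element is an assumption, as is the small `eps` guard; the library's method may derive α or reduce the result differently.

```python
import torch
import torch.nn.functional as F


def sparse_kl_sketch(loc: torch.Tensor, scale: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Hypothetical helper: element-wise approximate KL for sparse variational dropout."""
    k1, k2, k3 = 0.63576, 1.87320, 1.48695
    # Assumed dropout rate alpha = scale^2 / loc^2, computed in log space.
    log_alpha = torch.log(scale ** 2 + eps) - torch.log(loc ** 2 + eps)
    # softplus(-log_alpha) = log(1 + 1/alpha), evaluated stably.
    neg_kl = k1 * torch.sigmoid(k2 + k3 * log_alpha) - 0.5 * F.softplus(-log_alpha) - k1
    # The constant -k1 makes the KL tend to 0 as alpha grows (weight can be pruned).
    return -neg_kl


# alpha = 1 (log_alpha = 0) gives a KL of roughly 0.431.
kl_at_one = sparse_kl_sketch(torch.tensor([1.0]), torch.tensor([1.0]))
```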

class multiviewae.base.distributions.MultivariateNormal(**kwargs)

Multivariate normal distribution with diagonal covariance matrix. Inherits from torch.distributions.multivariate_normal.MultivariateNormal.

Parameters
  • loc (list, torch.Tensor) – Mean of distribution.

  • scale (int, torch.Tensor) – Standard deviation of distribution.

property variance

Returns the variance of the distribution.
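A multivariate normal with diagonal covariance is equivalent to independent univariate normals, one per dimension. The sketch below uses torch's own classes and assumes that the per-dimension standard deviations `scale` are squared onto the covariance diagonal; how multiviewae builds the covariance internally is not stated in the source.

```python
import torch
from torch.distributions import Independent, MultivariateNormal, Normal

loc = torch.tensor([0.0, 1.0, -1.0])
scale = torch.tensor([1.0, 0.5, 2.0])  # per-dimension standard deviations

# Diagonal covariance: variances scale**2 on the diagonal, zeros elsewhere.
mvn = MultivariateNormal(loc, covariance_matrix=torch.diag(scale ** 2))

# The same distribution expressed as independent univariate normals.
indep = Independent(Normal(loc, scale), 1)

x = torch.tensor([0.3, 0.8, -2.0])
print(mvn.log_prob(x), indep.log_prob(x))  # equal up to floating-point error
print(mvn.variance)  # equals scale ** 2
```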

class multiviewae.base.distributions.Bernoulli(**kwargs)

Bernoulli distribution. Inherits from torch.distributions.Bernoulli.

Parameters

x (list) – List of input data.

rsample()

Generates a sample_shape shaped reparameterized sample or sample_shape shaped batch of reparameterized samples if the distribution parameters are batched.
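torch.distributions.Bernoulli itself does not implement rsample(), which is presumably why this subclass overrides it. The source does not say how the override produces reparameterized samples; one standard option, shown here only as an illustration and not as the library's method, is torch's RelaxedBernoulli (the Concrete / Gumbel-Softmax relaxation).

```python
import torch
from torch.distributions import Bernoulli, RelaxedBernoulli

probs = torch.tensor([0.1, 0.5, 0.9], requires_grad=True)

# torch's own Bernoulli supports sample() but not rsample().
print(Bernoulli(probs=probs).has_rsample)  # False

# RelaxedBernoulli supports rsample(): samples lie strictly in (0, 1)
# and gradients flow back to probs.
relaxed = RelaxedBernoulli(temperature=torch.tensor(0.5), probs=probs)
z = relaxed.rsample()
z.sum().backward()
print(probs.grad is not None)  # True
```

Lower temperatures push the relaxed samples closer to 0 or 1, at the cost of higher-variance gradients.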

class multiviewae.base.distributions.ApproxBernoulli(**kwargs)

Artificial distribution for (approximately) Bernoulli-distributed data. The data need not be strictly Bernoulli distributed: this class serves as a wrapper for the log_likelihood() method required by the multiview models.

Parameters

x (list) – List of input data.
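A minimal sketch of such a wrapper, assuming the log-likelihood is the negative binary cross-entropy. `ApproxBernoulliSketch` is hypothetical; binary cross-entropy accepts any target in [0, 1], so it also applies to data that are not strictly 0 or 1.

```python
import torch
import torch.nn.functional as F


class ApproxBernoulliSketch:
    """Hypothetical stand-in for ApproxBernoulli: stores data x with values in [0, 1]."""

    def __init__(self, x: torch.Tensor):
        self.x = x

    def log_likelihood(self, recon: torch.Tensor) -> torch.Tensor:
        # Element-wise x * log(p) + (1 - x) * log(1 - p); targets may be
        # fractional values in [0, 1], not only 0 or 1 (assumed likelihood).
        return -F.binary_cross_entropy(recon, self.x, reduction="none")


x = torch.tensor([[0.0, 0.25, 1.0]])
recon = torch.tensor([[0.1, 0.25, 0.9]])
ll = ApproxBernoulliSketch(x).log_likelihood(recon)
```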

class multiviewae.base.distributions.Laplace(**kwargs)

Laplace distribution. Inherits from torch.distributions.Laplace.

Parameters
  • loc (list, torch.Tensor) – Mean of distribution.

  • scale (int, torch.Tensor) – Scale of distribution. This is not the standard deviation, which equals √2 × scale.
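For a Laplace distribution the variance is 2 × scale², so the scale parameter is smaller than the standard deviation by a factor of √2. The check below uses the parent class torch.distributions.Laplace.

```python
import math

import torch
from torch.distributions import Laplace

dist = Laplace(loc=torch.tensor(0.0), scale=torch.tensor(1.5))

# variance = 2 * scale**2 = 4.5, so stddev = sqrt(2) * scale
print(dist.stddev)  # about 2.1213, not 1.5
```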