Models

class multiviewae.models.AE(cfg=None, input_dim=None, z_dim=None)[source]

Multi-view Autoencoder model with a separate latent representation for each view.

Parameters
  • cfg (str) – Path to configuration file.

  • input_dim (list) – Dimensionality of the input data.

  • z_dim (int) – Number of latent dimensions.
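
A minimal usage sketch (the import path, dimensions, and data are illustrative; the class is documented as multiviewae.models.AE) showing the forward and loss_function methods described below:

import torch
from multiviewae.models import AE  # illustrative import path

model = AE(input_dim=[20, 10], z_dim=2)         # two views of width 20 and 10
x = [torch.randn(64, 20), torch.randn(64, 10)]  # one tensor per view
fwd_rtn = model.forward(x)                      # dict with x_recon and z
losses = model.loss_function(x, fwd_rtn)        # dict with reconstruction loss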

encode(x)[source]

Forward pass through encoder networks.

Parameters

x (list) – list of input data of type torch.Tensor.

Returns

list of latent dimensions for each view of type torch.Tensor.

Return type

z (list)

decode(z)[source]

Forward pass through decoder networks. Each latent is passed through all of the decoders.

Parameters

z (list) – list of latent dimensions for each view of type torch.Tensor.

Returns

list of data reconstructions.

Return type

x_recon (list)

forward(x)[source]

Apply encode and decode methods to input data to generate latent dimensions and data reconstructions.

Parameters

x (list) – list of input data of type torch.Tensor.

Returns

dictionary containing list of data reconstructions (x_recon) and latent dimensions (z).

Return type

fwd_rtn (dict)

loss_function(x, fwd_rtn)[source]

Calculate reconstruction loss.

Parameters
  • x (list) – list of input data of type torch.Tensor.

  • fwd_rtn (dict) – dictionary containing list of data reconstructions (x_recon) and latent dimensions (z).

Returns

dictionary containing reconstruction loss.

Return type

losses (dict)

class multiviewae.models.mAAE(cfg=None, input_dim=None, z_dim=None)[source]

Multi-view Adversarial Autoencoder model with a separate latent representation for each view.

Parameters
  • cfg (str) –

    Path to configuration file. Model specific parameters in addition to default parameters:

    • eps (float): Value added for numerical stability.

    • discriminator._target_ (multiviewae.architectures.mlp.Discriminator): Discriminator network class.

    • discriminator.hidden_layer_dim (list): Number of nodes per hidden layer.

    • discriminator.bias (bool): Whether to include a bias term in hidden layers.

    • discriminator.non_linear (bool): Whether to include a ReLU() function between layers.

    • discriminator.dropout_threshold (float): Dropout threshold of layers.

  • input_dim (list) – Dimensionality of the input data.

  • z_dim (int) – Number of latent dimensions.

encode(x)[source]

Forward pass through encoder networks.

Parameters

x (list) – list of input data of type torch.Tensor.

Returns

list of latent dimensions for each view of type torch.Tensor.

Return type

z (list)

decode(z)[source]

Forward pass through decoder networks. Each latent is passed through all of the decoders.

Parameters

z (list) – list of latent dimensions for each view of type torch.Tensor.

Returns

list of decoding distributions.

Return type

px_zs (list)

disc(z)[source]

Forward pass of “real” samples from the Gaussian prior and “fake” samples from the encoders through the discriminator network.

Parameters

z (list) – list of latent dimensions for each view of type torch.Tensor.

Returns
  • d_real (torch.Tensor) – discriminator network output for “real” samples.

  • d_fake (list) – list of discriminator network outputs for “fake” samples.

forward_recon(x)[source]

Apply encode and decode methods to input data to generate latent dimensions and data reconstructions.

Parameters

x (list) – list of input data of type torch.Tensor.

Returns

dictionary containing decoding distributions (px_zs) and latent dimensions (z).

Return type

fwd_rtn (dict)

forward_discrim(x)[source]

Apply encode and disc methods to input data to generate discriminator prediction on the latent dimensions and train discriminator parameters.

Parameters

x (list) – list of input data of type torch.Tensor.

Returns

dictionary containing discriminator output from “real” samples (d_real), discriminator output from “fake” samples (d_fake), and latent dimensions (z).

Return type

fwd_rtn (dict)

forward_gen(x)[source]

Apply encode and disc methods to input data to generate discriminator prediction on the latent dimensions and train encoder parameters.

Parameters

x (list) – list of input data of type torch.Tensor.

Returns

dictionary containing discriminator output from “fake” samples (d_fake) and latent dimensions (z).

Return type

fwd_rtn (dict)

recon_loss(x, fwd_rtn)[source]

Calculate reconstruction loss.

Parameters
  • x (list) – list of input data of type torch.Tensor.

  • fwd_rtn (dict) – fwd_rtn from the forward_recon method.

Returns

Reconstruction error.

Return type

ll (torch.Tensor)

generator_loss(fwd_rtn)[source]

Calculate the generator loss.

Parameters

fwd_rtn (dict) – fwd_rtn from the forward_gen method.

Returns

Generator loss.

Return type

gen_loss (torch.Tensor)

discriminator_loss(fwd_rtn)[source]

Calculate the discriminator loss.

Parameters

fwd_rtn (dict) – fwd_rtn from the forward_discrim method.

Returns

Discriminator loss.

Return type

disc_loss (torch.Tensor)
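
The three forward passes pair with the three losses above; a schematic sketch of one adversarial update (the actual optimisation is handled by the training loop, and model and x are assumed to be set up as in the AE example):

# 1) reconstruction step: update encoder/decoder parameters
fwd = model.forward_recon(x)
recon = model.recon_loss(x, fwd)

# 2) discriminator step: separate prior samples ("real") from encodings ("fake")
fwd = model.forward_discrim(x)
d_loss = model.discriminator_loss(fwd)

# 3) generator step: push encodings towards the prior
fwd = model.forward_gen(x)
g_loss = model.generator_loss(fwd)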

class multiviewae.models.mWAE(cfg=None, input_dim=None, z_dim=None)[source]

Multi-view Adversarial Autoencoder model with Wasserstein loss.

Wasserstein autoencoders: https://arxiv.org/abs/1711.01558

Parameters
  • cfg (str) –

    Path to configuration file. Model specific parameters in addition to default parameters:

    • discriminator._target_ (multiviewae.architectures.mlp.Discriminator): Discriminator network class.

    • discriminator.hidden_layer_dim (list): Number of nodes per hidden layer.

    • discriminator.bias (bool): Whether to include a bias term in hidden layers.

    • discriminator.non_linear (bool): Whether to include a ReLU() function between layers.

    • discriminator.dropout_threshold (float): Dropout threshold of layers.

  • input_dim (list) – Dimensionality of the input data.

  • z_dim (int) – Number of latent dimensions.

encode(x)[source]

Forward pass through encoder networks.

Parameters

x (list) – list of input data of type torch.Tensor.

Returns

list of latent dimensions for each view of type torch.Tensor.

Return type

z (list)

decode(z)[source]

Forward pass through decoder networks. Each latent is passed through all of the decoders.

Parameters

z (list) – list of latent dimensions for each view of type torch.Tensor.

Returns

list of decoding distributions.

Return type

px_zs (list)

disc(z)[source]

Forward pass of “real” samples from the Gaussian prior and “fake” samples from the encoders through the discriminator network.

Parameters

z (list) – list of latent dimensions for each view of type torch.Tensor.

Returns
  • d_real (torch.Tensor) – discriminator network output for “real” samples.

  • d_fake (list) – list of discriminator network outputs for “fake” samples.

forward_recon(x)[source]

Apply encode and decode methods to input data to generate latent dimensions and data reconstructions.

Parameters

x (list) – list of input data of type torch.Tensor.

Returns

dictionary containing decoding distributions (px_zs) and latent dimensions (z).

Return type

fwd_rtn (dict)

forward_discrim(x)[source]

Apply encode and disc methods to input data to generate discriminator prediction on the latent dimensions and train discriminator parameters.

Parameters

x (list) – list of input data of type torch.Tensor.

Returns

dictionary containing discriminator output from “real” samples (d_real), discriminator output from “fake” samples (d_fake), and latent dimensions (z).

Return type

fwd_rtn (dict)

forward_gen(x)[source]

Apply encode and disc methods to input data to generate discriminator prediction on the latent dimensions and train encoder parameters.

Parameters

x (list) – list of input data of type torch.Tensor.

Returns

dictionary containing discriminator output from “fake” samples (d_fake) and latent dimensions (z).

Return type

fwd_rtn (dict)

recon_loss(x, fwd_rtn)[source]

Calculate reconstruction loss.

Parameters
  • x (list) – list of input data of type torch.Tensor.

  • fwd_rtn (dict) – fwd_rtn from the forward_recon method.

Returns

Reconstruction error.

Return type

ll (torch.Tensor)

generator_loss(fwd_rtn)[source]

Calculate the generator loss.

Parameters

fwd_rtn (dict) – fwd_rtn from the forward_gen method.

Returns

Generator loss.

Return type

gen_loss (torch.Tensor)

discriminator_loss(fwd_rtn)[source]

Calculate the discriminator loss.

Parameters

fwd_rtn (dict) – fwd_rtn from the forward_discrim method.

Returns

Discriminator loss.

Return type

disc_loss (torch.Tensor)

class multiviewae.models.mcVAE(cfg=None, input_dim=None, z_dim=None)[source]

Multi-Channel Variational Autoencoder and Sparse Multi-Channel Variational Autoencoder.

Code is based on: https://github.com/ggbioing/mcvae

Parameters
  • cfg (str) –

    Path to configuration file. Model specific parameters in addition to default parameters:

    • model.beta (int, float): KL divergence weighting term.

    • model.sparse (bool): Whether to enforce sparsity of the encoding distribution.

    • model.threshold (float): Dropout threshold applied to the latent dimensions. Default is 0.

    • encoder.default._target_ (multiviewae.architectures.mlp.VariationalEncoder): Type of encoder class to use.

    • encoder.default.enc_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Encoding distribution.

    • decoder.default._target_ (multiviewae.architectures.mlp.VariationalDecoder): Type of decoder class to use.

    • decoder.default.init_logvar (int, float): Initial value for log variance of decoder.

    • decoder.default.dec_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Decoding distribution.

  • input_dim (list) – Dimensionality of the input data.

  • z_dim (int) – Number of latent dimensions.

References

Antelmi, Luigi & Ayache, Nicholas & Robert, Philippe & Lorenzi, Marco. (2019). Sparse Multi-Channel Variational Autoencoder for the Joint Analysis of Heterogeneous Data.

encode(x)[source]

Forward pass through encoder networks.

Parameters

x (list) – list of input data of type torch.Tensor.

Returns

list of encoding distributions for each view.

Return type

qz_xs (list)

decode(qz_xs)[source]

Forward pass through decoder networks. Each latent is passed through all of the decoders.

Parameters

qz_xs (list) – list of encoding distributions for each view.

Returns

A nested list of decoding distributions. The outer list has n_view elements, with position indicating the index of the latent used; the inner list has n_view elements, with position indicating the decoder index.

Return type

px_zs (list)
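
To make the nesting concrete, a hypothetical two-view indexing example (variable names are illustrative):

qz_xs = model.encode(x)      # one encoding distribution per view
px_zs = model.decode(qz_xs)  # px_zs[i][j]: view i's latent through view j's decoder
cross = px_zs[0][1]          # reconstruction of view 1 from view 0's latent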

forward(x)[source]

Apply encode and decode methods to input data to generate latent dimensions and data reconstructions.

Parameters

x (list) – list of input data of type torch.Tensor.

Returns

dictionary containing encoding (qz_xs) and decoding (px_zs) distributions.

Return type

fwd_rtn (dict)

loss_function(x, fwd_rtn)[source]

Calculate mcVAE loss.

Parameters
  • x (list) – list of input data of type torch.Tensor.

  • fwd_rtn (dict) – dictionary containing encoding and decoding distributions.

Returns

dictionary containing each element of the mcVAE loss.

Return type

losses (dict)

calc_kl(qz_xs)[source]

Calculate mcVAE KL-divergence loss.

Parameters

qz_xs (list) – list of encoding distributions.

Returns

KL-divergence loss across all views.

Return type

(torch.Tensor)

calc_ll(x, px_zs)[source]

Calculate log-likelihood loss.

Parameters
  • x (list) – list of input data of type torch.Tensor.

  • px_zs (list) – list of decoding distributions.

Returns

Log-likelihood loss.

Return type

ll (torch.Tensor)

class multiviewae.models.mVAE(cfg=None, input_dim=None, z_dim=None)[source]

Multimodal Variational Autoencoder (MVAE).

Parameters
  • cfg (str) –

    Path to configuration file. Model specific parameters in addition to default parameters:

    • model.beta (int, float): KL divergence weighting term.

    • model.join_type (str): Method of combining encoding distributions.

    • model.warmup (int): KL term weighted by beta linearly increased to 1 over this many epochs.

    • model.use_prior (bool): Whether to use a prior expert when combining encoding distributions.

    • model.sparse (bool): Whether to enforce sparsity of the encoding distribution.

    • model.threshold (float): Dropout threshold applied to the latent dimensions. Default is 0.

    • model.weight_ll (bool): Whether to weight the log-likelihood loss by 1/n_views.

    • encoder.default._target_ (multiviewae.architectures.mlp.VariationalEncoder): Type of encoder class to use.

    • encoder.default.enc_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Encoding distribution.

    • decoder.default._target_ (multiviewae.architectures.mlp.VariationalDecoder): Type of decoder class to use.

    • decoder.default.init_logvar (int, float): Initial value for log variance of decoder.

    • decoder.default.dec_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Decoding distribution.

  • input_dim (list) – Dimensionality of the input data.

  • z_dim (int) – Number of latent dimensions.

References

Wu, M., & Goodman, N.D. (2018). Multimodal Generative Models for Scalable Weakly-Supervised Learning. NeurIPS.
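
When the encoding distributions are combined as a product of experts, the joint Gaussian posterior is the precision-weighted combination of the per-view posteriors, optionally including a standard-normal prior expert (model.use_prior). A self-contained numerical sketch of that fusion, not the library's internal code:

import torch

# Gaussian experts: a standard-normal prior expert plus one expert per view
mus = torch.stack([torch.zeros(2), torch.ones(2), -torch.ones(2)])
logvars = torch.zeros(3, 2)

precision = torch.exp(-logvars)  # 1 / sigma^2 per expert
joint_var = 1.0 / precision.sum(dim=0)
joint_mu = joint_var * (precision * mus).sum(dim=0)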

encode(x)[source]

Forward pass through encoder networks.

Parameters

x (list) – list of input data of type torch.Tensor.

Returns

Single element list of joint encoding distribution.

Return type

(list)

encode_subset(x, subset)[source]

Forward pass through encoder networks for the specified subset.

Parameters
  • x (list) – list of input data of type torch.Tensor.

  • subset (list) – list of modalities to encode.

Returns

Single element list of joint encoding distribution.

Return type

(list)

decode(qz_x)[source]

Forward pass of joint latent dimensions through decoder networks.

Parameters

qz_x (list) – Single element list containing the joint encoding distribution.

Returns

A nested list of decoding distributions, px_zs. The outer list has a single element indicating the shared latent dimensions. The inner list is a n_view element list with the position in the list indicating the decoder index.

Return type

(list)

decode_subset(qz_x, subset)[source]

Forward pass of joint latent dimensions through decoder networks for the specified subset.

Parameters
  • qz_x (list) – Single element list containing the joint encoding distribution.

  • subset (list) – list of modalities to decode.

Returns

A nested list of decoding distributions, px_zs. The outer list has a single element indicating the shared latent dimensions. The inner list is a n_view element list with the position in the list indicating the decoder index.

Return type

(list)
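
Together, encode_subset and decode_subset support cross-view prediction; a hypothetical example that infers the joint latent from view 0 alone and reconstructs view 1:

qz_x = model.encode_subset(x, subset=[0])      # condition on view 0 only
px_zs = model.decode_subset(qz_x, subset=[1])  # decode view 1 from that latent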

forward(x)[source]

Apply encode and decode methods to input data to generate the joint latent dimensions and data reconstructions.

Parameters

x (list) – list of input data of type torch.Tensor.

Returns

dictionary containing encoding and decoding distributions.

Return type

fwd_rtn (dict)

calc_kl(qz_x)[source]

Calculate KL-divergence loss.

Parameters

qz_x (list) – Single element list containing joint encoding distribution.

Returns

KL-divergence loss.

Return type

(torch.Tensor)

calc_ll(x, px_zs)[source]

Calculate log-likelihood loss.

Parameters
  • x (list) – list of input data of type torch.Tensor.

  • px_zs (list) – list of decoding distributions.

Returns

Log-likelihood loss.

Return type

ll (torch.Tensor)

loss_function(x, fwd_rtn)[source]

Calculate Multimodal VAE loss.

Parameters
  • x (list) – list of input data of type torch.Tensor.

  • fwd_rtn (dict) – dictionary containing encoding and decoding distributions.

Returns

dictionary containing each element of the MVAE loss.

Return type

losses (dict)

calc_nll(x, K=1000, batch_size_K=100)[source]

Calculate negative log-likelihood used to evaluate model performance.

Parameters
  • x (list) – list of input data of type torch.Tensor.

  • K (int, optional) – number of latent samples used to estimate the log-likelihood.

  • batch_size_K (int, optional) – number of samples evaluated per batch.

Returns

Negative log-likelihood.

Return type

nll (torch.Tensor)
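
K and batch_size_K suggest an importance-sampling estimate of log p(x): K posterior samples are drawn (in chunks of batch_size_K) and the evidence is averaged in log space. A sketch of the underlying estimator for one datum, with all names illustrative:

import math
import torch

def iw_nll(log_px_z, log_pz, log_qz_x):
    # each input: [K] tensor of log-densities for K posterior samples
    log_w = log_px_z + log_pz - log_qz_x
    return -(torch.logsumexp(log_w, dim=0) - math.log(log_w.shape[0]))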

class multiviewae.models.JMVAE(cfg=None, input_dim=None, z_dim=None)[source]

JMVAE-kl.

Parameters
  • cfg (str) –

    Path to configuration file. Model specific parameters in addition to default parameters:

    • alpha (float): Weighting of KL-divergence loss from individual encoders.

    • encoder.default._target_ (multiviewae.architectures.mlp.VariationalEncoder): Type of encoder class to use.

    • encoder.default.enc_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Encoding distribution.

    • decoder.default._target_ (multiviewae.architectures.mlp.VariationalDecoder): Type of decoder class to use.

    • decoder.default.init_logvar (int, float): Initial value for log variance of decoder.

    • decoder.default.dec_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Decoding distribution.

  • input_dim (list) – Dimensionality of the input data.

  • z_dim (int) – Number of latent dimensions.

References

Suzuki, Masahiro & Nakayama, Kotaro & Matsuo, Yutaka. (2016). Joint Multimodal Learning with Deep Generative Models.

encode(x)[source]

Forward pass through joint encoder network.

Parameters

x (list) – list of input data of type torch.Tensor.

Returns

Single element list containing joint encoding distribution, qz_xy.

Return type

(list)

encode_separate(x)[source]

Forward pass through separate encoder networks.

Parameters

x (list) – list of input data of type torch.Tensor.

Returns
  • qz_x – Encoding distribution for modality X.

  • qz_y – Encoding distribution for modality Y.

decode(qz_x)[source]

Forward pass of joint latent dimensions through decoder networks.

Parameters

qz_x (list) – Single element list containing the joint encoding distribution.

Returns

A nested list of decoding distributions, px_zs. The outer list has a single element indicating the shared latent dimensions. The inner list is a 2 element list with the position in the list indicating the decoder index.

Return type

(list)

forward(x)[source]

Apply encode and decode methods to input data to generate latent dimensions and data reconstructions.

Parameters

x (list) – list of input data of type torch.Tensor.

Returns

dictionary containing encoding and decoding distributions.

Return type

fwd_rtn (dict)

calc_kl(qz_xy, qz_x, qz_y)[source]

Calculate JMVAE-kl KL-divergence loss.

Parameters
  • qz_xy (list) – Single element list containing shared encoding distribution.

  • qz_x (list) – Single element list containing encoding distribution for modality X.

  • qz_y (list) – Single element list containing encoding distribution for modality Y.

Returns

KL-divergence loss.

Return type

(torch.Tensor)
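
A sketch of the JMVAE-kl divergence term using torch.distributions stand-ins (the library wraps its own distribution classes; the weighting follows the alpha parameter above):

import torch
from torch.distributions import Normal, kl_divergence

pz = Normal(torch.zeros(2), torch.ones(2))           # prior
qz_xy = Normal(torch.zeros(2), 0.5 * torch.ones(2))  # joint posterior
qz_x = Normal(0.5 * torch.ones(2), torch.ones(2))    # modality X posterior
qz_y = Normal(-0.5 * torch.ones(2), torch.ones(2))   # modality Y posterior

alpha = 1.0
kl = kl_divergence(qz_xy, pz).sum() + alpha * (
    kl_divergence(qz_xy, qz_x).sum() + kl_divergence(qz_xy, qz_y).sum())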

calc_ll(x, px_zs)[source]

Calculate log-likelihood loss.

Parameters
  • x (list) – list of input data of type torch.Tensor.

  • px_zs (list) – list of decoding distributions.

Returns

Log-likelihood loss.

Return type

ll (torch.Tensor)

loss_function(x, fwd_rtn)[source]

Calculate JMVAE-kl loss.

Parameters
  • x (list) – list of input data of type torch.Tensor.

  • fwd_rtn (dict) – dictionary containing encoding and decoding distributions.

Returns

dictionary containing each element of the JMVAE loss.

Return type

losses (dict)

calc_nll(x, K=1000, batch_size_K=100)[source]

Calculate negative log-likelihood used to evaluate model performance.

Parameters
  • x (list) – list of input data of type torch.Tensor.

  • K (int, optional) – number of latent samples used to estimate the log-likelihood.

  • batch_size_K (int, optional) – number of samples evaluated per batch.

Returns

Negative log-likelihood.

Return type

nll (torch.Tensor)

class multiviewae.models.me_mVAE(cfg=None, input_dim=None, z_dim=None)[source]

Multimodal Variational Autoencoder (MVAE).

The loss optimises the ELBO term from the joint posterior distribution as well as the separate ELBO terms for each view; me_mVAE stands for multi-ELBO Multimodal VAE.

Parameters
  • cfg (str) –

    Path to configuration file. Model specific parameters in addition to default parameters:

    • model.beta (int, float): KL divergence weighting term.

    • model.join_type (str): Method of combining encoding distributions.

    • model.warmup (int): KL term weighted by beta linearly increased to 1 over this many epochs.

    • model.use_prior (bool): Whether to use a prior expert when combining encoding distributions.

    • model.sparse (bool): Whether to enforce sparsity of the encoding distribution.

    • model.threshold (float): Dropout threshold applied to the latent dimensions. Default is 0.

    • model.weight_kld (bool): Whether to weight the KL term by the number of views.

    • model.weight_ll (bool): Whether to weight the log-likelihood term by the number of views.

    • encoder.default._target_ (multiviewae.architectures.mlp.VariationalEncoder): Type of encoder class to use.

    • encoder.default.enc_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Encoding distribution.

    • decoder.default._target_ (multiviewae.architectures.mlp.VariationalDecoder): Type of decoder class to use.

    • decoder.default.init_logvar (int, float): Initial value for log variance of decoder.

    • decoder.default.dec_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Decoding distribution.

  • input_dim (list) – Dimensionality of the input data.

  • z_dim (int) – Number of latent dimensions.

References

Wu, M., & Goodman, N.D. (2018). Multimodal Generative Models for Scalable Weakly-Supervised Learning. NeurIPS.

encode(x)[source]

Forward pass through encoder networks.

Parameters

x (list) – list of input data of type torch.Tensor.

Returns
  • (list) – Single element list containing the joint encoding distribution.

  • (list) – list containing the separate encoding distributions.

encode_subset(x, subset)[source]

Forward pass through encoder networks for the specified subset.

Parameters
  • x (list) – list of input data of type torch.Tensor.

  • subset (list) – list of modalities to encode.

Returns

Single element list of joint encoding distribution.

Return type

(list)

decode(qz_x)[source]

Forward pass of joint latent dimensions through decoder networks.

Parameters

qz_x (list) – Single element list containing the joint encoding distribution.

Returns

A nested list of decoding distributions, px_zs. The outer list has a single element indicating the shared latent dimensions. The inner list is a n_view element list with the position in the list indicating the decoder index.

Return type

(list)

decode_subset(qz_x, subset)[source]

Forward pass of joint latent dimensions through decoder networks for the specified subset.

Parameters
  • qz_x (list) – Single element list containing the joint encoding distribution.

  • subset (list) – list of modalities to decode.

Returns

A nested list of decoding distributions, px_zs. The outer list has a single element indicating the shared latent dimensions. The inner list is a n_view element list with the position in the list indicating the decoder index.

Return type

(list)

decode_separate(qz_xs)[source]

Forward pass of each view specific latent dimensions through the respective decoder network.

Parameters

qz_xs (list) – list of encoding distributions for each view.

Returns

A nested list of decoding distributions, px_zs. The outer list has a single element indicating the view specific latent dimensions. The inner list is a n_view element list with the position in the list indicating the decoder index.

Return type

(list)

forward(x)[source]

Apply encode, decode, encode_separate and decode_separate methods to input data to generate the joint and modality specific latent dimensions and data reconstructions.

Parameters

x (list) – list of input data of type torch.Tensor.

Returns

dictionary containing encoding and decoding distributions.

Return type

fwd_rtn (dict)

calc_kl(qz_xs)[source]

Calculate KL-divergence loss.

Parameters

qz_xs (list) – list of encoding distributions.

Returns

KL-divergence loss.

Return type

(torch.Tensor)

calc_ll(x, px_zs)[source]

Calculate log-likelihood loss.

Parameters
  • x (list) – list of input data of type torch.Tensor.

  • px_zs (list) – list of decoding distributions.

Returns

Log-likelihood loss.

Return type

ll (torch.Tensor)

loss_function(x, fwd_rtn)[source]

Calculate multi ELBO Multimodal VAE loss.

Parameters
  • x (list) – list of input data of type torch.Tensor.

  • fwd_rtn (dict) – dictionary containing encoding and decoding distributions.

Returns

dictionary containing each element of the MVAE loss.

Return type

losses (dict)

calc_nll(x, K=1000, batch_size_K=100)[source]

Calculate negative log-likelihood used to evaluate model performance.

Parameters
  • x (list) – list of input data of type torch.Tensor.

  • K (int, optional) – number of latent samples used to estimate the log-likelihood.

  • batch_size_K (int, optional) – number of samples evaluated per batch.

Returns

Negative log-likelihood.

Return type

nll (torch.Tensor)

class multiviewae.models.mmVAE(cfg=None, input_dim=None, z_dim=None)[source]

Mixture-of-Experts Multimodal Variational Autoencoder (MMVAE).

Code is based on: https://github.com/iffsid/mmvae

Parameters
  • cfg (str) –

    Path to configuration file. Model specific parameters in addition to default parameters:

    • model.K (int): Number of samples to take from encoding distribution.

    • model.DREG_loss (bool): Whether to use DReG estimator when using large K value.

    • encoder.default._target_ (multiviewae.architectures.mlp.VariationalEncoder): Type of encoder class to use.

    • encoder.default.enc_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Encoding distribution.

    • decoder.default._target_ (multiviewae.architectures.mlp.VariationalDecoder): Type of decoder class to use.

    • decoder.default.init_logvar (int, float): Initial value for log variance of decoder.

    • decoder.default.dec_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Decoding distribution.

  • input_dim (list) – Dimensionality of the input data.

  • z_dim (int) – Number of latent dimensions.

References

Shi, Y., Siddharth, N., Paige, B., & Torr, P.H. (2019). Variational Mixture-of-Experts Autoencoders for Multi-Modal Deep Generative Models. ArXiv, abs/1911.03393.

encode(x)[source]

Forward pass through encoder networks.

Parameters

x (list) – list of input data of type torch.Tensor.

Returns

list of encoding distributions.

Return type

(list)

encode_subset(x, subset)[source]

Forward pass through encoder networks for a subset of modalities.

Parameters
  • x (list) – list of input data of type torch.Tensor.

  • subset (list) – list of modalities to encode.

Returns

list of encoding distributions.

Return type

(list)

decode(qz_xs)[source]

Forward pass through decoder networks. Each latent is passed through all of the decoders.

Parameters

qz_xs (list) – list of encoding distributions.

Returns

A nested list of decoding distributions. The outer list has n_view elements, with position indicating the index of the latent used; the inner list has n_view elements, with position indicating the decoder index.

Return type

(list)

decode_subset(qz_xs, subset)[source]

Forward pass through decoder networks for a subset of modalities. Each latent is passed through its own decoder.

Parameters
  • qz_xs (list) – list of encoding distributions.

  • subset (list) – list of modalities to decode.

Returns

A nested list of decoding distributions. The outer list is a single element list, the inner list is a subset element list of decoding distributions.

Return type

(list)

forward(x)[source]

Apply encode and decode methods to input data to generate latent dimensions and data reconstructions.

Parameters

x (list) – list of input data of type torch.Tensor.

Returns

dictionary containing encoding (qz_xs) and decoding (px_zs) distributions.

Return type

fwd_rtn (dict)

loss_function(x, fwd_rtn)[source]

Wrapper function for mmVAE loss.

Parameters
  • x (list) – list of input data of type torch.Tensor.

  • fwd_rtn (dict) – dictionary containing encoding and decoding distributions.

Returns

dictionary containing mmVAE loss.

Return type

losses (dict)

moe_iwae(x, qz_xs, px_zs)[source]

Calculate Mixture-of-Experts importance weighted autoencoder (IWAE) loss used for the mmVAE model.

Parameters
  • x (list) – list of input data of type torch.Tensor.

  • qz_xs (list) – list of encoding distributions.

  • px_zs (list) – nested list of decoding distributions.

Returns

Mixture-of-Experts IWAE loss.

Return type

(torch.Tensor)

log_mean_exp(value, dim=0, keepdim=False)[source]

Returns the log of the mean of the exponentials along the given dimension (dim).

Parameters
  • value (torch.Tensor) – the input tensor.

  • dim (int, optional) – the dimension along which to take the mean.

  • keepdim (bool, optional) – whether the output tensor has dim retained or not.

Returns

the output tensor.

Return type

(torch.Tensor)
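
log_mean_exp(value, dim) is numerically equivalent to torch.logsumexp(value, dim) minus the log of the size of dim; a quick check:

import torch

value = torch.randn(4, 3)
lme = torch.logsumexp(value, dim=0) - torch.log(torch.tensor(4.0))
assert torch.allclose(lme, value.exp().mean(dim=0).log())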

class multiviewae.models.mvtCAE(cfg=None, input_dim=None, z_dim=None)[source]

Multi-View Total Correlation Auto-Encoder (MVTCAE).

Code is based on: https://github.com/gr8joo/MVTCAE

NOTE: This implementation currently only caters for a PoE posterior distribution. MoE and MoPoE posteriors will be included in further work.

Parameters
  • cfg (str) –

    Path to configuration file. Model specific parameters in addition to default parameters:

    • model.beta (int, float): KL divergence weighting term.

    • model.alpha (int, float): Log likelihood, Conditional VIB and VIB weighting term.

    • encoder.default._target_ (multiviewae.architectures.mlp.VariationalEncoder): Type of encoder class to use.

    • encoder.default.enc_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Encoding distribution.

    • decoder.default._target_ (multiviewae.architectures.mlp.VariationalDecoder): Type of decoder class to use.

    • decoder.default.init_logvar (int, float): Initial value for log variance of decoder.

    • decoder.default.dec_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Decoding distribution.

  • input_dim (list) – Dimensionality of the input data.

  • z_dim (int) – Number of latent dimensions.

References

Hwang, HyeongJoo and Kim, Geon-Hyeong and Hong, Seunghoon and Kim, Kee-Eung. Multi-View Representation Learning via Total Correlation Objective. 2021. NeurIPS

encode(x)[source]

Forward pass through encoder networks.

Parameters

x (list) – list of input data of type torch.Tensor.

Returns

The separate and/or joint encoding distributions, depending on whether the model is in the training stage:
  • qz_xs (list) – list containing the separate encoding distributions.

  • qz_x (list) – Single element list containing the PoE joint encoding distribution.

encode_subset(x, subset)[source]

Forward pass through encoder networks for a subset of modalities.

Parameters
  • x (list) – list of input data of type torch.Tensor.

  • subset (list) – list of modalities to encode.

Returns

Either the joint or the separate encoding distributions, depending on whether the model is in the training stage:
  • qz_xs (list) – list containing the separate encoding distributions.

  • qz_x (list) – Single element list containing the PoE joint encoding distribution.

decode(qz_x)[source]

Forward pass of joint latent dimensions through decoder networks.

Parameters

qz_x (list) – list of joint encoding distribution.

Returns

A nested list of decoding distributions, px_zs. The outer list has a single element indicating the shared latent dimensions. The inner list is a n_view element list with the position in the list indicating the decoder index.

Return type

(list)

decode_subset(qz_x, subset)[source]

Forward pass of joint latent dimensions through decoder networks for a subset of modalities.

Parameters
  • qz_x (list) – list of joint encoding distribution.

  • subset (list) – list of modalities to decode.

forward(x)[source]

Apply encode and decode methods to input data to generate the joint latent dimensions and data reconstructions.

Parameters

x (list) – list of input data of type torch.Tensor.

Returns

dictionary containing encoding and decoding distributions.

Return type

fwd_rtn (dict)

loss_function(x, fwd_rtn)[source]

Calculate MVTCAE loss.

Parameters
  • x (list) – list of input data of type torch.Tensor.

  • fwd_rtn (dict) – dictionary containing encoding and decoding distributions.

Returns

dictionary containing each element of the MVTCAE loss.

Return type

losses (dict)

calc_kl_cvib(qz_x, qz_xs)[source]

Calculate KL-divergence between PoE joint encoding distribution and the encoding distribution for each view.

Parameters
  • qz_x (list) – Single element list containing the PoE joint encoding distribution.

  • qz_xs (list) – list of encoding distributions of each view.

Returns

KL-divergence loss.

Return type

kl (torch.Tensor)

calc_kl_groupwise(qz_x)[source]

Calculate KL-divergence between the PoE joint encoding distribution and the prior distribution.

Parameters

qz_x (list) – Single element list containing the PoE joint encoding distribution.

Returns

KL-divergence loss.

Return type

kl (torch.Tensor)

calc_ll(x, px_zs)[source]

Calculate log-likelihood loss.

Parameters
  • x (list) – list of input data of type torch.Tensor.

  • px_zs (list) – list of decoding distributions.

Returns

Log-likelihood loss.

Return type

ll (torch.Tensor)

class multiviewae.models.DVCCA(cfg=None, input_dim=None, z_dim=None)[source]

Deep Variational Canonical Correlation Analysis (DVCCA).

Parameters
  • cfg (str) –

    Path to configuration file. Model specific parameters in addition to default parameters:

    • model.beta (int, float): KL divergence weighting term.

    • model.private (bool): Whether to include private view-specific latent dimensions.

    • model.sparse (bool): Whether to enforce sparsity of the encoding distribution.

    • model.threshold (float): Dropout threshold applied to the latent dimensions. Default is 0.

    • encoder.default._target_ (multiviewae.architectures.mlp.VariationalEncoder): Type of encoder class to use.

    • encoder.default.enc_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Encoding distribution.

    • decoder.default._target_ (multiviewae.architectures.mlp.VariationalDecoder): Type of decoder class to use.

    • decoder.default.init_logvar (int, float): Initial value for log variance of decoder.

    • decoder.default.dec_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Decoding distribution.

  • input_dim (list) – Dimensionality of the input data.

  • z_dim (int) – Number of latent dimensions.

References

Wang, Weiran & Lee, Honglak & Livescu, Karen. (2016). Deep Variational Canonical Correlation Analysis.

configure_optimizers()[source]

Configure optimizers for encoder, private encoder, and decoder network parameters.

Returns

list of Adam optimizers for encoders and decoders.

Return type

optimizers (list)
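
A hedged sketch of what such a configuration can look like (the attribute names encoders and decoders are illustrative; see the source for the actual module lists and learning rates):

import torch

opt_enc = [torch.optim.Adam(e.parameters(), lr=1e-3) for e in model.encoders]
opt_dec = [torch.optim.Adam(d.parameters(), lr=1e-3) for d in model.decoders]
optimizers = opt_enc + opt_dec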

encode(x)[source]

Forward pass through encoder network. For DVCCA-private a forward pass is performed through each private encoder and the output latent is concatenated with the shared latent.

Parameters

x (list) – list of input data of type torch.Tensor.

Returns

A combination of the following, depending on the training stage and model type:
  • qz_x (list) – list containing the shared encoding distribution.

  • qz_xs (list) – list of encoding distributions for shared and private latents of DVCCA-private.

  • qh_xs (list) – list of encoding distributions for private latents of DVCCA-private.

decode(qz_x)[source]

Forward pass through decoder networks.

Parameters

qz_x (list) – list of encoding distributions (shared, or shared and private for DVCCA-private).

Returns

A nested list of decoding distributions, px_zs. The outer list has a single element indicating the shared or shared and private latent dimensions. The inner list is a n_view element list with the position in the list indicating the decoder index.

Return type

(list)

forward(x)[source]

Apply encode and decode methods to input data to generate latent dimensions and data reconstructions. For DVCCA, the shared encoding distribution is passed to the decode method. For DVCCA-private, the joint distribution of the shared and private latents for each view is passed to the decode method.

Parameters

x (list) – list of input data of type torch.Tensor.

Returns

dictionary containing list of decoding distributions (px_zs), shared encoding distribution (qz_x), and (for DVCCA-private) private encoding distributions (qh_xs).

Return type

fwd_rtn (dict)

loss_function(x, fwd_rtn)[source]

Calculate DVCCA loss.

Parameters
  • x (list) – list of input data of type torch.Tensor.

  • fwd_rtn (dict) – dictionary containing list of decoding distributions (px_zs), shared encoding distribution (qz_x), and (for DVCCA-private) private encoding distributions (qh_xs).

Returns

dictionary containing each element of the DVCCA loss.

Return type

losses (dict)

calc_kl(qz_x, qh_xs)[source]

Wrapper function for calculating KL-divergence loss.

Parameters
  • qz_x (list) – Single element list containing shared encoding distribution.

  • qh_xs (list) – list of private encoding distributions for DVCCA-private.

Returns

KL-divergence loss across all views.

Return type

(torch.Tensor)

calc_kl_(dist)[source]

Calculate KL-divergence.

Parameters

dist – Distribution object.

Returns

KL-divergence.

Return type

(torch.Tensor)

calc_ll(x, px_zs)[source]

Calculate log-likelihood loss.

Parameters
  • x (list) – list of input data of type torch.Tensor.

  • px_zs (list) – list of decoding distributions.

Returns

Log-likelihood loss.

Return type

ll (torch.Tensor)

class multiviewae.models.MoPoEVAE(cfg=None, input_dim=None, z_dim=None)[source]

Mixture-of-Product-of-Experts Variational Autoencoder.

Code is based on: https://github.com/thomassutter/MoPoE

Parameters
  • cfg (str) –

    Path to configuration file. Model specific parameters in addition to default parameters:

    • model.beta (int, float): KL divergence weighting term.

    • encoder.default._target_ (multiviewae.architectures.mlp.VariationalEncoder): Type of encoder class to use.

    • encoder.default.enc_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Encoding distribution.

    • decoder.default._target_ (multiviewae.architectures.mlp.VariationalDecoder): Type of decoder class to use.

    • decoder.default.init_logvar (int, float): Initial value for log variance of decoder.

    • decoder.default.dec_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Decoding distribution.

  • input_dim (list) – Dimensionality of the input data.

  • z_dim (int) – Number of latent dimensions.

References

Sutter, Thomas & Daunhawer, Imant & Vogt, Julia. (2021). Generalized Multimodal ELBO.

encode(x)[source]

Forward pass through encoder networks.

Parameters

x (list) – list of input data of type torch.Tensor.

Returns

list containing the MoE joint encoding distribution. If training, the model also returns the encoding distribution for each subset.

Return type

(list)

encode_subset(x, subset)[source]

Forward pass through encoder networks for a subset of modalities.

Parameters
  • x (list) – list of input data of type torch.Tensor.

  • subset (list) – list of modalities to encode.

Returns

list containing the MoE joint encoding distribution.

Return type

(list)

decode(qz_x)[source]

Forward pass of joint latent dimensions through decoder networks.

Parameters

qz_x (list) – Single element list containing the MoE joint encoding distribution.

Returns

A nested list of decoding distributions, px_zs. The outer list has a single element indicating the shared latent dimensions. The inner list is a n_view element list with the position in the list indicating the decoder index.

Return type

(list)

forward(x)[source]

Apply encode and decode methods to input data to generate the joint and subset latent dimensions and data reconstructions.

Parameters

x (list) – list of input data of type torch.Tensor.

Returns

dictionary containing encoding and decoding distributions.

Return type

fwd_rtn (dict)

loss_function(x, fwd_rtn)[source]

Calculate MoPoE VAE loss.

Parameters
  • x (list) – list of input data of type torch.Tensor.

  • fwd_rtn (dict) – dictionary containing encoding and decoding distributions.

Returns

dictionary containing each element of the MoPoE VAE loss.

Return type

losses (dict)

calc_kl_moe(qz_xs)[source]

Calculate KL-divergence between each PoE subset posterior and the prior distribution.

Parameters

qz_xs (list) – list of encoding distributions.

Returns

KL-divergence loss.

Return type

(torch.Tensor)

set_subsets(n_views=None)[source]

Create combinations of subsets of views.

Returns

list of unique combinations of n_views.

Return type

subset_list (list)
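
A plausible enumeration of the unique non-empty view subsets (a sketch; the source may order or filter them differently):

from itertools import combinations

n_views = 3
subset_list = [list(c)
               for k in range(1, n_views + 1)
               for c in combinations(range(n_views), k)]
# [[0], [1], [2], [0, 1], [0, 2], [1, 2], [0, 1, 2]]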

calc_ll(x, px_zs)[source]

Calculate log-likelihood loss.

Parameters
  • x (list) – list of input data of type torch.Tensor.

  • px_zs (list) – list of decoding distributions.

Returns

Log-likelihood loss.

Return type

ll (torch.Tensor)

class multiviewae.models.mmJSD(cfg=None, input_dim=None, z_dim=None)[source]

Multimodal Jensen-Shannon divergence (mmJSD) model with Product-of-Experts dynamic prior.

Code is based on: https://github.com/thomassutter/mmjsd

Parameters
  • cfg (str) –

    Path to configuration file. Model specific parameters in addition to default parameters:

    • model.private (bool): Whether to include private modality-specific latent dimensions.

    • model.beta (int, float): KL divergence weighting term.

    • model.alpha (int, float): JSD divergence weighting term.

    • model.s_dim (int): Number of private latent dimensions.

    • encoder.default._target_ (multiviewae.architectures.mlp.VariationalEncoder): Type of encoder class to use.

    • encoder.default.enc_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Encoding distribution.

    • decoder.default._target_ (multiviewae.architectures.mlp.VariationalDecoder): Type of decoder class to use.

    • decoder.default.init_logvar (int, float): Initial value for log variance of decoder.

    • decoder.default.dec_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Decoding distribution.

  • input_dim (list) – Dimensionality of the input data.

  • z_dim (int) – Number of latent dimensions.

References

Sutter, Thomas & Daunhawer, Imant & Vogt, Julia. (2021). Multimodal Generative Learning Utilizing Jensen-Shannon-Divergence. Advances in Neural Information Processing Systems. 33.

encode(x)[source]

Forward pass through encoder network. If self.private=True, the first s_dim dimensions of each latent are used for the modality-specific part and the remaining dimensions for the joint content.

Parameters

x (list) – list of input data of type torch.Tensor.

Returns

A combination of the following, depending on the training stage and model type:
  • qz_xs (list) – Single element list containing the MoE encoding distribution (self.private=False).

  • qzs_xs (list) – list containing each encoding distribution (self.private=False).

  • qz_x (list) – Single element list containing the PoE encoding distribution (self.private=False).

  • qc_x (list) – Single element list containing the PoE shared encoding distribution.

  • qscs_xs (list) – list containing combined shared and private latents.

  • qs_xs (list) – list of encoding distributions for private latents.

  • qcs_xs (list) – list containing encoding distributions for shared latent dimensions for each view.

encode_subset(x, subset)[source]

Forward pass through encoder networks for a subset of modalities.

Parameters
  • x (list) – list of input data of type torch.Tensor.

  • subset (list) – list of modalities to encode.

Returns

list containing the PoE joint encoding distribution.

Return type

(list)

decode(qz_x)[source]

Forward pass of latent dimensions through decoder networks.

Parameters

qz_x (list) – Single element list containing the encoding distribution.

Returns

A nested list of decoding distributions, px_zs. The outer list has a single element, the inner list is a n_view element list with the position in the list indicating the decoder index.

Return type

(list)

decode_subset(qz_x, subset)[source]

Forward pass of latent dimensions through decoder networks for a subset of modalities.

Parameters
  • qz_x (list) – Single element list containing the encoding distribution.

  • subset (list) – list of modalities to decode.

Returns

A nested list of decoding distributions, px_zs. The outer list has a single element, the inner list is a n_view element list with the position in the list indicating the decoder index.

Return type

(list)

forward(x)[source]

Apply encode and decode methods to input data to generate the joint and modality specific latent dimensions and data reconstructions.

Parameters

x (list) – list of input data of type torch.Tensor.

Returns

dictionary containing encoding and decoding distributions.

Return type

fwd_rtn (dict)

calc_kl(qz_xs)[source]

Calculate KL-divergence loss.

Parameters

qz_xs (list) – list of encoding distributions.

Returns

KL-divergence loss.

Return type

(torch.Tensor)

calc_ll(x, px_zs)[source]

Calculate log-likelihood loss.

Parameters
  • x (list) – list of input data of type torch.Tensor.

  • px_zs (list) – list of decoding distributions.

Returns

Log-likelihood loss.

Return type

ll (torch.Tensor)

calc_jsd(qcs_xs, qc_x)[source]

Calculate Jensen-Shannon Divergence loss.

Parameters
  • qcs_xs (list) – list of encoding distributions of each view for shared latent dimensions.

  • qc_x (list) – Dynamic prior given by PoE of shared encoding distributions.

Returns

Jensen-Shannon Divergence loss.

Return type

jsd (torch.Tensor)
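
A sketch of the JSD term as an average KL between each view's shared posterior and the dynamic PoE prior, using torch.distributions stand-ins (the exact weighting in the library may differ):

import torch
from torch.distributions import Normal, kl_divergence

qcs_xs = [Normal(torch.zeros(2), torch.ones(2)),
          Normal(torch.ones(2), torch.ones(2))]
qc_x = Normal(0.5 * torch.ones(2), torch.ones(2))  # dynamic PoE prior

jsd = torch.stack([kl_divergence(q, qc_x).sum() for q in qcs_xs]).mean()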

loss_function(x, fwd_rtn)[source]

Calculate mmJSD loss.

Parameters
  • x (list) – list of input data of type torch.Tensor.

  • fwd_rtn (dict) – dictionary containing encoding and decoding distributions.

Returns

dictionary containing each element of the mmJSD loss.

Return type

losses (dict)

class multiviewae.models.weighted_mVAE(cfg=None, input_dim=None, z_dim=None)[source]

Generalised Product-of-Experts Variational Autoencoder (gPoE-MVAE).

Parameters
  • cfg (str) –

    Path to configuration file. Model specific parameters in addition to default parameters:

    • model.beta (int, float): KL divergence weighting term.

    • model.private (bool): Whether to include private view-specific latent dimensions.

    • model.sparse (bool): Whether to enforce sparsity of the encoding distribution.

    • model.threshold (float): Dropout threshold applied to the latent dimensions. Default is 0.

    • encoder.default._target_ (multiviewae.architectures.mlp.VariationalEncoder): Type of encoder class to use.

    • encoder.default.enc_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Encoding distribution.

    • decoder.default._target_ (multiviewae.architectures.mlp.VariationalDecoder): Type of decoder class to use.

    • decoder.default.init_logvar (int, float): Initial value for log variance of decoder.

    • decoder.default.dec_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Decoding distribution.

  • input_dim (list) – Dimensionality of the input data.

  • z_dim (int) – Number of latent dimensions.

References

Cao, Y., & Fleet, D. (2014). Generalized Product of Experts for Automatic and Principled Fusion of Gaussian Process Predictions. arXiv.

Lawry Aguila, A., Chapman, J., Altmann, A. (2023). Multi-modal Variational Autoencoders for normative modelling across multiple imaging modalities. arXiv.
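
In the generalised product of experts, each Gaussian expert contributes with a weight alpha_i; a numerical sketch of the weighted fusion (weights and dimensions illustrative):

import torch

mu = torch.stack([torch.zeros(2), torch.ones(2)])        # per-view means
var = torch.stack([torch.ones(2), 0.5 * torch.ones(2)])  # per-view variances
alpha = torch.tensor([[0.6], [0.4]])                     # gPoE weights

weighted_precision = alpha / var
joint_var = 1.0 / weighted_precision.sum(dim=0)
joint_mu = joint_var * (weighted_precision * mu).sum(dim=0)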

encode(x)[source]

Forward pass through encoder networks.

Parameters

x (list) – list of input data of type torch.Tensor.

Returns

A combination of the following, depending on the training stage and model type:
  • qz_x (list) – Single element list containing the PoE encoding distribution (self.private=False).

  • qc_x (list) – Single element list containing the PoE shared encoding distribution.

  • qs_xs (list) – list of encoding distributions for private latents.

  • qscs_xs (list) – list containing combined shared and private latents.

  • qcs_xs (list) – list containing encoding distributions for shared latent dimensions for each view.

decode(qz_x)[source]

Forward pass of joint latent dimensions through decoder networks.

Parameters

qz_x (list) – Single element list containing the joint encoding distribution.

Returns

A nested list of decoding distributions, px_zs. The outer list has a single element indicating the shared latent dimensions. The inner list is a n_view element list with the position in the list indicating the decoder index.

Return type

(list)

forward(x)[source]

Apply encode and decode methods to input data to generate the joint latent dimensions and data reconstructions.

Parameters

x (list) – list of input data of type torch.Tensor.

Returns

dictionary containing encoding and decoding distributions.

Return type

fwd_rtn (dict)

calc_kl(qz_x)[source]

Calculate KL-divergence loss.

Parameters

qz_x (list) – Single element list containing joint encoding distribution.

Returns

KL-divergence loss.

Return type

(torch.Tensor)

calc_kl_separate(qc_xs)[source]

Calculate KL-divergence loss.

Parameters

qc_xs (list) – list of encoding distributions for private/shared latent dimensions for each view.

Returns

KL-divergence loss.

Return type

(torch.Tensor)

calc_ll(x, px_zs)[source]

Calculate log-likelihood loss.

Parameters
  • x (list) – list of input data of type torch.Tensor.

  • px_zs (list) – list of decoding distributions.

Returns

Log-likelihood loss.

Return type

ll (torch.Tensor)

loss_function(x, fwd_rtn)[source]

Calculate Multimodal VAE loss.

Parameters
  • x (list) – list of input data of type torch.Tensor.

  • fwd_rtn (dict) – dictionary containing encoding and decoding distributions.

Returns

dictionary containing each element of the MVAE loss.

Return type

losses (dict)

class multiviewae.models.DMVAE(cfg=None, input_dim=None, z_dim=None)[source]

Disentangled multi-modal variational autoencoder (DMVAE).

Parameters
  • cfg (str) –

    Path to configuration file. Model specific parameters in addition to default parameters:

    • model._lambda (list, optional): Log likelihood weighting term for each modality.

    • model.s_dim (int): Number of private latent dimensions.

    • model.beta (int, float): KL divergence weighting term.

    • encoder.default._target_ (multiviewae.architectures.mlp.VariationalEncoder): Type of encoder class to use.

    • encoder.default.enc_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Encoding distribution.

    • decoder.default._target_ (multiviewae.architectures.mlp.VariationalDecoder): Type of decoder class to use.

    • decoder.default.init_logvar (int, float): Initial value for log variance of decoder.

    • decoder.default.dec_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Decoding distribution.

  • input_dim (list) – Dimensionality of the input data.

  • z_dim (int) – Number of latent dimensions.

References

M. Lee and V. Pavlovic, “Private-Shared Disentangled Multimodal VAE for Learning of Latent Representations,” 2021 IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops (CVPRW), Nashville, TN, USA, 2021, pp. 1692-1700, doi: 10.1109/CVPRW53098.2021.00185.

encode(x)[source]

Forward pass through encoder networks.

Parameters

x (list) – list of input data of type torch.Tensor.

Returns
  • qc_x (list) – Single element list containing the PoE shared encoding distribution.

  • qcs_xs (list) – list containing encoding distributions for shared latent dimensions for each view.

  • qs_xs (list) – list of encoding distributions for private latents.

  • qscs_xs (list) – nested list containing combined PoE shared and private latents.

  • qscss_xs (list) – nested list containing combined shared latents from each modality and private latents, for same and cross view reconstruction.

decode(qz_x)[source]

Forward pass of joint latent dimensions through decoder networks.

Parameters

qz_x (list) – Single element list containing the joint encoding distribution.

Returns

A nested list of decoding distributions, px_zs. The outer list has a single element indicating the shared latent dimensions. The inner list is a n_view element list with the position in the list indicating the decoder index.

Return type

(list)

decode_separate(qz_xs)[source]

Forward pass through decoder networks. Each shared latent is passed through all of the decoders with the private latents from the same view.

Parameters

qz_xs (list) – list of encoding distributions for each view.

Returns

A nested list of decoding distributions, px_zs. The outer list has n_view elements, with position indicating the decoder index; the inner list has n_view elements, with position indicating the index of the latent used. NOTE: This ordering is the reverse of the other models.

Return type

(list)
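
Because decode_separate reverses the nesting used elsewhere, indexing goes decoder first; an illustrative two-view example (qz_xs is the output of encode):

pxs_zs = model.decode_separate(qz_xs)  # pxs_zs[j][i]: decoder j, view i's shared latent
same_recon = pxs_zs[0][0]   # view 0 reconstructed from its own shared latent
cross_recon = pxs_zs[0][1]  # view 0 reconstructed using view 1's shared latent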

forward(x)[source]

Apply encode and decode methods to input data to generate the latent dimensions and data reconstructions.

Parameters

x (list) – list of input data of type torch.Tensor.

Returns

dictionary containing encoding and decoding distributions.

Return type

fwd_rtn (dict)

calc_kl_joint_latent(qz_x, qs_xs)[source]

Calculate KL-divergence loss for the first terms in Equation 3.

Parameters
  • qz_x (list) – Single element list containing joint encoding distribution.

  • qs_xs (list) – list of encoding distributions for private latent dimensions for each view.

Returns

KL-divergence loss.

Return type

(torch.Tensor)

calc_kl_separate_latent(qcs_xs, qs_xs)[source]

Calculate KL-divergence loss for the second terms in Equation 3.

Parameters
  • qcs_xs (list) – list of the shared encoding distributions calculated from each view.

  • qs_xs (list) – list of encoding distributions for private latent dimensions for each view.

Returns

KL-divergence loss.

Return type

(torch.Tensor)

calc_ll_joint(x, px_zs)[source]

Calculate log-likelihood loss from the joint encoding distribution for each modality.

Parameters
  • x (list) – list of input data of type torch.Tensor.

  • px_zs (list) – list of decoding distributions.

Returns

Log-likelihood loss.

Return type

ll (torch.Tensor)

calc_ll_separate(x, pxs_zs)[source]

Calculate cross-modal and self-reconstruction log-likelihood loss from the shared encoding distribution for each modality and private latents.

Parameters
  • x (list) – list of input data of type torch.Tensor.

  • pxs_zs (list) – nested list of decoding distributions. NOTE: The ordering of decoding distributions is the reverse compared to other models.

Returns

Log-likelihood loss.

Return type

ll (torch.Tensor)

loss_function(x, fwd_rtn)[source]

Calculate DMVAE loss.

Parameters
  • x (list) – list of input data of type torch.Tensor.

  • fwd_rtn (dict) – dictionary containing encoding and decoding distributions.

Returns

dictionary containing each element of the DMVAE loss.

Return type

losses (dict)

configure_optimizers()[source]

Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you’d need one. But in the case of GANs or similar you might have multiple.

Returns

Any of these 6 options.

  • Single optimizer.

  • List or Tuple of optimizers.

  • Two lists - The first list has multiple optimizers, and the second has multiple LR schedulers (or multiple lr_scheduler_config).

  • Dictionary, with an "optimizer" key, and (optionally) a "lr_scheduler" key whose value is a single LR scheduler or lr_scheduler_config.

  • Tuple of dictionaries as described above, with an optional "frequency" key.

  • None - Fit will run without any optimizer.

The lr_scheduler_config is a dictionary which contains the scheduler and its associated configuration. The default configuration is shown below.

lr_scheduler_config = {
    # REQUIRED: The scheduler instance
    "scheduler": lr_scheduler,
    # The unit of the scheduler's step size, could also be 'step'.
    # 'epoch' updates the scheduler on epoch end whereas 'step'
    # updates it after an optimizer update.
    "interval": "epoch",
    # How many epochs/steps should pass between calls to
    # `scheduler.step()`. 1 corresponds to updating the learning
    # rate after every epoch/step.
    "frequency": 1,
    # Metric to monitor for schedulers like `ReduceLROnPlateau`
    "monitor": "val_loss",
    # If set to `True`, will enforce that the value specified in 'monitor'
    # is available when the scheduler is updated, thus stopping
    # training if not found. If set to `False`, it will only produce a warning
    "strict": True,
    # If using the `LearningRateMonitor` callback to monitor the
    # learning rate progress, this keyword can be used to specify
    # a custom logged name
    "name": None,
}

When there are schedulers in which the .step() method is conditioned on a value, such as the torch.optim.lr_scheduler.ReduceLROnPlateau scheduler, Lightning requires that the lr_scheduler_config contains the keyword "monitor" set to the metric name that the scheduler should be conditioned on.

# The ReduceLROnPlateau scheduler requires a monitor
def configure_optimizers(self):
    optimizer = Adam(...)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": ReduceLROnPlateau(optimizer, ...),
            "monitor": "metric_to_track",
            "frequency": "indicates how often the metric is updated"
            # If "monitor" references validation metrics, then "frequency" should be set to a
            # multiple of "trainer.check_val_every_n_epoch".
        },
    }


# In the case of two optimizers, only one using the ReduceLROnPlateau scheduler
def configure_optimizers(self):
    optimizer1 = Adam(...)
    optimizer2 = SGD(...)
    scheduler1 = ReduceLROnPlateau(optimizer1, ...)
    scheduler2 = LambdaLR(optimizer2, ...)
    return (
        {
            "optimizer": optimizer1,
            "lr_scheduler": {
                "scheduler": scheduler1,
                "monitor": "metric_to_track",
            },
        },
        {"optimizer": optimizer2, "lr_scheduler": scheduler2},
    )

Metrics can be made available to monitor by simply logging them using self.log('metric_to_track', metric_val) in your LightningModule.

Note

The frequency value specified in a dict along with the optimizer key is an int corresponding to the number of sequential batches optimized with the specific optimizer. It should be given to none or to all of the optimizers. There is a difference between passing multiple optimizers in a list, and passing multiple optimizers in dictionaries with a frequency of 1:

  • In the former case, all optimizers will operate on the given batch in each optimization step.

  • In the latter, only one optimizer will operate on the given batch at every step.

This is different from the frequency value specified in the lr_scheduler_config mentioned above.

def configure_optimizers(self):
    optimizer_one = torch.optim.SGD(self.model.parameters(), lr=0.01)
    optimizer_two = torch.optim.SGD(self.model.parameters(), lr=0.01)
    return [
        {"optimizer": optimizer_one, "frequency": 5},
        {"optimizer": optimizer_two, "frequency": 10},
    ]

In this example, the first optimizer will be used for the first 5 steps, the second optimizer for the next 10 steps and that cycle will continue. If an LR scheduler is specified for an optimizer using the lr_scheduler key in the above dict, the scheduler will only be updated when its optimizer is being used.

Examples:

# most cases. no learning rate scheduler
def configure_optimizers(self):
    return Adam(self.parameters(), lr=1e-3)

# multiple optimizer case (e.g.: GAN)
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    return gen_opt, dis_opt

# example with learning rate schedulers
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    dis_sch = CosineAnnealing(dis_opt, T_max=10)
    return [gen_opt, dis_opt], [dis_sch]

# example with step-based learning rate schedulers
# each optimizer has its own scheduler
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    gen_sch = {
        'scheduler': ExponentialLR(gen_opt, 0.99),
        'interval': 'step'  # called after each training step
    }
    dis_sch = CosineAnnealing(dis_opt, T_max=10) # called every epoch
    return [gen_opt, dis_opt], [gen_sch, dis_sch]

# example with optimizer frequencies
# see training procedure in `Improved Training of Wasserstein GANs`, Algorithm 1
# https://arxiv.org/abs/1704.00028
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    n_critic = 5
    return (
        {'optimizer': dis_opt, 'frequency': n_critic},
        {'optimizer': gen_opt, 'frequency': 1}
    )

Note

Some things to know:

  • Lightning calls .backward() and .step() on each optimizer and learning rate scheduler as needed.

  • If you use 16-bit precision (precision=16), Lightning will automatically handle the optimizers.

  • If you use multiple optimizers, training_step() will have an additional optimizer_idx parameter.

  • If you use torch.optim.LBFGS, Lightning handles the closure function automatically for you.

  • If you use multiple optimizers, gradients will be calculated only for the parameters of the current optimizer at each training step.

  • If you need to control how often those optimizers step or override the default .step() schedule, override the optimizer_step() hook.

class multiviewae.models.weighted_DMVAE(cfg=None, input_dim=None, z_dim=None)[source]

Variant of the Disentangled Multi-modal Variational Autoencoder (DMVAE) using a weighted Product-of-Experts for the joint encoding distribution; a sketch of the weighted-PoE arithmetic follows the references below.

Parameters
  • cfg (str) –

    Path to configuration file. Model specific parameters in addition to default parameters:

    • model._lambda (list, optional): Log likelihood weighting term for each modality.

    • model.s_dim (int): Number of private latent dimensions.

    • model.beta (int, float): KL divergence weighting term.

    • encoder.default._target_ (multiviewae.architectures.mlp.VariationalEncoder): Type of encoder class to use.

    • encoder.default.enc_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Encoding distribution.

    • decoder.default._target_ (multiviewae.architectures.mlp.VariationalDecoder): Type of decoder class to use.

    • decoder.default.init_logvar (int, float): Initial value for log variance of decoder.

    • decoder.default.dec_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Decoding distribution.

  • input_dim (list) – Dimensionality of the input data.

  • z_dim (int) – Number of latent dimensions.

References

M. Lee and V. Pavlovic, “Private-Shared Disentangled Multimodal VAE for Learning of Latent Representations,” 2021 IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops (CVPRW), Nashville, TN, USA, 2021, pp. 1692-1700, doi: 10.1109/CVPRW53098.2021.00185.
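
A weighted Product-of-Experts combines the per-view Gaussian encoders by scaling each expert's precision by its weight before taking the product. The sketch below illustrates only the underlying arithmetic for diagonal Gaussians; the function name and the handling of the prior expert are assumptions, not the library's implementation:

import torch

def weighted_poe(mus, logvars, weights):
    # Weighted Product-of-Experts for diagonal Gaussians.
    # mus, logvars: lists of (batch, z_dim) tensors; weights: list of floats.
    # A unit-Gaussian prior expert can be included by appending
    # mu=0, logvar=0 with weight 1 to the inputs.
    precisions = [w * torch.exp(-lv) for w, lv in zip(weights, logvars)]
    joint_precision = torch.stack(precisions).sum(dim=0)
    joint_var = 1.0 / joint_precision
    joint_mu = joint_var * torch.stack(
        [p * mu for p, mu in zip(precisions, mus)]
    ).sum(dim=0)
    return joint_mu, torch.log(joint_var)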

encode(x)[source]

Forward pass through encoder networks.

Parameters

x (list) – list of input data of type torch.Tensor.

Returns

Single element list containing the PoE shared encoding distribution. qcs_xs (list): list containing encoding distributions for shared latent dimensions for each view. qs_xs (list): list of encoding distributions for private latents. qscs_xs (list): nested list containing combined PoE shared and private latents. qscss_xs (list): nested list containing combined shared latents from each modality and private latents for same and cross view reconstruction.

Return type

qc_x (list)

decode(qz_x)[source]

Forward pass of joint latent dimensions through decoder networks.

Parameters

qz_x (list) – single element list containing the joint encoding distribution.

Returns

A nested list of decoding distributions, px_zs. The outer list has a single element corresponding to the shared latent dimensions. The inner list is an n_view element list with the position in the list indicating the decoder index.

Return type

(list)

decode_separate(qz_xs)[source]

Forward pass through decoder networks. Each shared latent is passed through all of the decoders with the private latents from the same view.

Parameters

qz_xs (list) – list of combined shared and private encoding distributions for each view.

Returns

A nested list of decoding distributions, px_zs. The outer list is an n_view element list with the position in the list indicating the decoder index. The inner list is an n_view element list with the position in the list indicating the latent dimensions index. NOTE: This ordering is the reverse of that in other models.

Return type

(list)

forward(x)[source]

Apply encode and decode methods to input data to generate the latent dimensions and data reconstructions.

Parameters

x (list) – list of input data of type torch.Tensor.

Returns

dictionary containing encoding and decoding distributions.

Return type

fwd_rtn (dict)

calc_kl_joint_latent(qz_x, qs_xs)[source]

Calculate KL-divergence loss for the first terms in Equation 3.

Parameters
  • qz_x (list) – Single element list containing joint encoding distribution.

  • qs_xs (list) – list of encoding distributions for private latent dimensions for each view.

Returns

KL-divergence loss.

Return type

(torch.Tensor)

calc_kl_separate_latent(qcs_xs, qs_xs)[source]

Calculate KL-divergence loss for the second terms in Equation 3.

Parameters
  • qcs_xs (list) – list of the shared encoding distributions calculated from each view.

  • qs_xs (list) – list of encoding distributions for private latent dimensions for each view.

Returns

KL-divergence loss.

Return type

(torch.Tensor)

calc_ll_joint(x, px_zs)[source]

Calculate log-likelihood loss from the joint encoding distribution for each modality.

Parameters
  • x (list) – list of input data of type torch.Tensor.

  • px_zs (list) – list of decoding distributions.

Returns

Log-likelihood loss.

Return type

ll (torch.Tensor)

calc_ll_separate(x, pxs_zs)[source]

Calculate cross-modal and self-reconstruction log-likelihood loss from the shared encoding distribution for each modality and the private latents.

Parameters
  • x (list) – list of input data of type torch.Tensor.

  • pxs_zs (list) – nested list of decoding distributions. NOTE: The ordering of the decoding distributions is the reverse of that in other models.

Returns

Log-likelihood loss.

Return type

ll (torch.Tensor)

loss_function(x, fwd_rtn)[source]

Calculate DMVAE loss.

Parameters
  • x (list) – list of input data of type torch.Tensor.

  • fwd_rtn (dict) – dictionary containing encoding and decoding distributions.

Returns

dictionary containing each element of the DMVAE loss.

Return type

losses (dict)

configure_optimizers()[source]

Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you’d need one. But in the case of GANs or similar you might have multiple.

Returns

Any of these 6 options.

  • Single optimizer.

  • List or Tuple of optimizers.

  • Two lists - The first list has multiple optimizers, and the second has multiple LR schedulers (or multiple lr_scheduler_config).

  • Dictionary, with an "optimizer" key, and (optionally) a "lr_scheduler" key whose value is a single LR scheduler or lr_scheduler_config.

  • Tuple of dictionaries as described above, with an optional "frequency" key.

  • None - Fit will run without any optimizer.

The lr_scheduler_config is a dictionary which contains the scheduler and its associated configuration. The default configuration is shown below.

lr_scheduler_config = {
    # REQUIRED: The scheduler instance
    "scheduler": lr_scheduler,
    # The unit of the scheduler's step size, could also be 'step'.
    # 'epoch' updates the scheduler on epoch end whereas 'step'
    # updates it after an optimizer update.
    "interval": "epoch",
    # How many epochs/steps should pass between calls to
    # `scheduler.step()`. 1 corresponds to updating the learning
    # rate after every epoch/step.
    "frequency": 1,
    # Metric to monitor for schedulers like `ReduceLROnPlateau`.
    "monitor": "val_loss",
    # If set to `True`, will enforce that the value specified in 'monitor'
    # is available when the scheduler is updated, thus stopping
    # training if not found. If set to `False`, it will only produce a warning.
    "strict": True,
    # If using the `LearningRateMonitor` callback to monitor the
    # learning rate progress, this keyword can be used to specify
    # a custom logged name
    "name": None,
}

When there are schedulers in which the .step() method is conditioned on a value, such as the torch.optim.lr_scheduler.ReduceLROnPlateau scheduler, Lightning requires that the lr_scheduler_config contains the keyword "monitor" set to the metric name that the scheduler should be conditioned on.

# The ReduceLROnPlateau scheduler requires a monitor
def configure_optimizers(self):
    optimizer = Adam(...)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": ReduceLROnPlateau(optimizer, ...),
            "monitor": "metric_to_track",
            "frequency": "indicates how often the metric is updated"
            # If "monitor" references validation metrics, then "frequency" should be set to a
            # multiple of "trainer.check_val_every_n_epoch".
        },
    }


# In the case of two optimizers, only one using the ReduceLROnPlateau scheduler
def configure_optimizers(self):
    optimizer1 = Adam(...)
    optimizer2 = SGD(...)
    scheduler1 = ReduceLROnPlateau(optimizer1, ...)
    scheduler2 = LambdaLR(optimizer2, ...)
    return (
        {
            "optimizer": optimizer1,
            "lr_scheduler": {
                "scheduler": scheduler1,
                "monitor": "metric_to_track",
            },
        },
        {"optimizer": optimizer2, "lr_scheduler": scheduler2},
    )

Metrics can be made available to monitor by simply logging them using self.log('metric_to_track', metric_val) in your LightningModule.

Note

The frequency value specified in a dict along with the optimizer key is an int corresponding to the number of sequential batches optimized with the specific optimizer. It should be given to none or to all of the optimizers. There is a difference between passing multiple optimizers in a list, and passing multiple optimizers in dictionaries with a frequency of 1:

  • In the former case, all optimizers will operate on the given batch in each optimization step.

  • In the latter, only one optimizer will operate on the given batch at every step.

This is different from the frequency value specified in the lr_scheduler_config mentioned above.

def configure_optimizers(self):
    optimizer_one = torch.optim.SGD(self.model.parameters(), lr=0.01)
    optimizer_two = torch.optim.SGD(self.model.parameters(), lr=0.01)
    return [
        {"optimizer": optimizer_one, "frequency": 5},
        {"optimizer": optimizer_two, "frequency": 10},
    ]

In this example, the first optimizer will be used for the first 5 steps, the second optimizer for the next 10 steps and that cycle will continue. If an LR scheduler is specified for an optimizer using the lr_scheduler key in the above dict, the scheduler will only be updated when its optimizer is being used.

Examples:

# most cases. no learning rate scheduler
def configure_optimizers(self):
    return Adam(self.parameters(), lr=1e-3)

# multiple optimizer case (e.g.: GAN)
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    return gen_opt, dis_opt

# example with learning rate schedulers
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    dis_sch = CosineAnnealing(dis_opt, T_max=10)
    return [gen_opt, dis_opt], [dis_sch]

# example with step-based learning rate schedulers
# each optimizer has its own scheduler
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    gen_sch = {
        'scheduler': ExponentialLR(gen_opt, 0.99),
        'interval': 'step'  # called after each training step
    }
    dis_sch = CosineAnnealing(dis_opt, T_max=10) # called every epoch
    return [gen_opt, dis_opt], [gen_sch, dis_sch]

# example with optimizer frequencies
# see training procedure in `Improved Training of Wasserstein GANs`, Algorithm 1
# https://arxiv.org/abs/1704.00028
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    n_critic = 5
    return (
        {'optimizer': dis_opt, 'frequency': n_critic},
        {'optimizer': gen_opt, 'frequency': 1}
    )

Note

Some things to know:

  • Lightning calls .backward() and .step() on each optimizer and learning rate scheduler as needed.

  • If you use 16-bit precision (precision=16), Lightning will automatically handle the optimizers.

  • If you use multiple optimizers, training_step() will have an additional optimizer_idx parameter.

  • If you use torch.optim.LBFGS, Lightning handles the closure function automatically for you.

  • If you use multiple optimizers, gradients will be calculated only for the parameters of the current optimizer at each training step.

  • If you need to control how often those optimizers step or override the default .step() schedule, override the optimizer_step() hook.

class multiviewae.models.mmVAEPlus(cfg=None, input_dim=None, z_dim=None)[source]

Mixture-of-Experts Multimodal Variational Autoencoder Plus (MMVAE+), which augments the MMVAE with separate shared and private latent variables.

Code is based on: https://github.com/iffsid/mmvae

Parameters
  • cfg (str) –

    Path to configuration file. Model specific parameters in addition to default parameters:

    • model.K (int): Number of samples to take from encoding distribution.

    • model.DREG_loss (bool): Whether to use DReG estimator when using large K value.

    • encoder.default._target_ (multiviewae.architectures.mlp.VariationalEncoder): Type of encoder class to use.

    • encoder.default.enc_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Encoding distribution.

    • decoder.default._target_ (multiviewae.architectures.mlp.VariationalDecoder): Type of decoder class to use.

    • decoder.default.init_logvar (int, float): Initial value for log variance of decoder.

    • decoder.default.dec_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Decoding distribution.

  • input_dim (list) – Dimensionality of the input data.

  • z_dim (int) – Number of latent dimensions.

References

Shi, Y., Siddharth, N., Paige, B., & Torr, P.H. (2019). Variational Mixture-of-Experts Autoencoders for Multi-Modal Deep Generative Models. ArXiv, abs/1911.03393.

encode(x)[source]

Forward pass through encoder networks.

Parameters

x (list) – list of input data of type torch.Tensor.

Returns

list of encoding distributions for the shared (qu_xs) and private (qw_xs) latent dimensions during training; otherwise, samples from the encoding distributions.

Return type

(list)

encode_subset(x, subset)[source]

Forward pass through encoder networks for a subset of modalities. For modalities not in subset, shared latents are sampled from a random modality and private latents from the shared prior.

Parameters
  • x (list) – list of input data of type torch.Tensor.

  • subset (list) – list of modalities to encode.

Returns

list of samples from encoding distributions.

Return type

(list)
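
Together, encode_subset and decode_subset support cross-modal generation. A minimal sketch assuming a two-modality setup (placeholder dimensions; the default configuration is assumed to supply the remaining model parameters):

import torch
from multiviewae.models import mmVAEPlus

model = mmVAEPlus(input_dim=[20, 50], z_dim=10)
model.eval()  # encode returns latent samples rather than distributions

x = [torch.randn(8, 20), torch.randn(8, 50)]

# Encode using modality 0 only; the missing modality's shared latents
# come from a random modality and its private latents from the prior.
zss = model.encode_subset(x, subset=[0])

# Decode modality 1 only, i.e. cross-modal generation 0 -> 1.
px_zs = model.decode_subset(zss, subset=[1])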

decode(zss)[source]

Forward pass through decoder networks. Each latent is passed through all of the decoders.

Parameters
  • zss (list) – list of latent samples if not training, or a list containing qw_xs (list of private encoding distributions) and qu_xs (list of shared encoding distributions) during training.

Returns

A nested list of decoding distributions. The outer list is an n_view element list with the position in the list indicating the latent dimensions index. The inner list is an n_view element list with the position in the list indicating the decoder index.

Return type

(list)

decode_subset(zss, subset)[source]

Forward pass through decoder networks for a subset of modalities. Each latent is passed through its own decoder.

Parameters
  • zss (list) – list of latent samples for each modality.

  • subset (list) – list of modalities to decode.

Returns

A list of decoding distributions for each modality in subset.

Return type

(list)

forward(x)[source]

Apply encode and decode methods to input data to generate latent dimensions and data reconstructions.

Parameters

x (list) – list of input data of type torch.Tensor.

Returns

dictionary containing encoding (qw_xs and qu_xs) and decoding (px_zs) distributions.

Return type

fwd_rtn (dict)

loss_function(x, fwd_rtn)[source]

Wrapper function for mmVAE loss.

Parameters
  • x (list) – list of input data of type torch.Tensor.

  • fwd_rtn (dict) – dictionary containing encoding and decoding distributions.

Returns

dictionary containing mmVAE loss.

Return type

losses (dict)

moe_iwae(x, qw_xs, qu_xs, px_zs)[source]

Calculate Mixture-of-Experts importance weighted autoencoder (IWAE) loss used for the mmVAE model.

Parameters
  • x (list) – list of input data of type torch.Tensor.

  • qw_xs (list) – list of private encoding distributions.

  • qu_xs (list) – list of shared encoding distributions.

  • px_zs (list) – list of decoding distributions.

Returns

the MoE IWAE loss.

Return type

(torch.Tensor)

log_mean_exp(value, dim=0, keepdim=False)[source]

Returns the log of the mean of the exponentials along the given dimension (dim).

Parameters
  • value (torch.Tensor) – the input tensor.

  • dim (int, optional) – the dimension along which to take the mean.

  • keepdim (bool, optional) – whether the output tensor has dim retained or not.

Returns

the output tensor.

Return type

(torch.Tensor)
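
One standard, numerically stable way to realise this, sketched here for reference, is logsumexp shifted by log N:

import math
import torch

def log_mean_exp(value, dim=0, keepdim=False):
    # log(mean(exp(value))) along `dim`: logsumexp avoids overflow,
    # and subtracting log N turns the sum into a mean.
    return torch.logsumexp(value, dim=dim, keepdim=keepdim) - math.log(value.size(dim))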

class multiviewae.models.DCCAE(cfg=None, input_dim=None, z_dim=None)[source]

Deep Canonically Correlated Autoencoder (DCCAE). CCA implementation adapted from: https://github.com/jameschapman19/cca_zoo

Parameters
  • cfg (str) –

    Path to configuration file. Model specific parameters in addition to default parameters:

    • model._lambda (int, float): Reconstruction weighting term

  • input_dim (list) – Dimensionality of the input data.

  • z_dim (int) – Number of latent dimensions.

References

Wang, Weiran & Arora, Raman & Livescu, Karen & Bilmes, Jeff. (2016). On Deep Multi-View Representation Learning: Objectives and Optimization.

encode(x)[source]

Forward pass through encoder networks.

Parameters

x (list) – list of input data of type torch.Tensor.

Returns

list of latent dimensions for each view of type torch.Tensor.

Return type

z (list)

decode(z)[source]

Forward pass through decoder networks. Each latent is passed through all of the decoders.

Parameters

z (list) – list of latent dimensions for each view of type torch.Tensor.

Returns

list of data reconstructions.

Return type

x_recon (list)

cca(zs)[source]

CCA loss calculation. Adapted from: https://github.com/jameschapman19/cca_zoo/blob/main/cca_zoo/deep/_discriminative/_dmcca.py

Parameters

zs (list) – list of latent dimensions for each view of type torch.Tensor.

Returns

CCA loss.

Return type

cca_loss (torch.Tensor)
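
For orientation, the two-view deep CCA objective maximises the sum of canonical correlations between the latent views. The sketch below is a simplified two-view illustration of that objective, not the multi-view cca_zoo implementation linked above; the regularisation constant eps is an assumption:

import torch

def two_view_cca_loss(z1, z2, eps=1e-4):
    # Negative sum of canonical correlations between z1 and z2,
    # each of shape (batch, z_dim).
    n = z1.shape[0]
    z1 = z1 - z1.mean(dim=0, keepdim=True)
    z2 = z2 - z2.mean(dim=0, keepdim=True)
    eye1 = torch.eye(z1.shape[1], device=z1.device)
    eye2 = torch.eye(z2.shape[1], device=z2.device)
    s11 = z1.T @ z1 / (n - 1) + eps * eye1  # regularised view-1 covariance
    s22 = z2.T @ z2 / (n - 1) + eps * eye2  # regularised view-2 covariance
    s12 = z1.T @ z2 / (n - 1)               # cross-covariance

    def inv_sqrt(s):
        # Inverse matrix square root via eigendecomposition.
        evals, evecs = torch.linalg.eigh(s)
        return evecs @ torch.diag(evals.clamp_min(eps).rsqrt()) @ evecs.T

    t = inv_sqrt(s11) @ s12 @ inv_sqrt(s22)
    return -torch.linalg.svdvals(t).sum()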

forward(x)[source]

Apply encode and decode methods to input data to generate latent dimensions and data reconstructions.

Parameters

x (list) – list of input data of type torch.Tensor.

Returns

dictionary containing list of data reconstructions (x_recon) and latent dimensions (z).

Return type

fwd_rtn (dict)

loss_function(x, fwd_rtn)[source]

Calculate reconstruction loss.

Parameters
  • x (list) – list of input data of type torch.Tensor.

  • fwd_rtn (dict) – dictionary containing list of data reconstructions (x_recon) and latent dimensions (z).

Returns

dictionary containing reconstruction loss.

Return type

losses (dict)
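
A minimal usage sketch mirroring the documented API (two views with placeholder dimensions):

import torch
from multiviewae.models import DCCAE

model = DCCAE(input_dim=[20, 50], z_dim=5)
x = [torch.randn(16, 20), torch.randn(16, 50)]

fwd_rtn = model.forward(x)                # contains x_recon and z
losses = model.loss_function(x, fwd_rtn)  # reconstruction (and CCA) loss terms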