Models
- class multiviewae.models.AE(cfg=None, input_dim=None, z_dim=None)[source]
Multi-view Autoencoder model with a separate latent representation for each view.
- Parameters
cfg (str) – Path to configuration file.
input_dim (list) – Dimensionality of the input data.
z_dim (int) – Number of latent dimensions.
- encode(x)[source]
Forward pass through encoder networks.
- Parameters
x (list) – list of input data of type torch.Tensor.
- Returns
list of latent dimensions for each view of type torch.Tensor.
- Return type
z (list)
- decode(z)[source]
Forward pass through decoder networks. Each latent is passed through all of the decoders.
- Parameters
z (list) – list of latent dimensions for each view of type torch.Tensor.
- Returns
list of data reconstructions.
- Return type
x_recon (list)
- forward(x)[source]
Apply encode and decode methods to input data to generate latent dimensions and data reconstructions.
- Parameters
x (list) – list of input data of type torch.Tensor.
- Returns
dictionary containing list of data reconstructions (x_recon) and latent dimensions (z).
- Return type
fwd_rtn (dict)
- loss_function(x, fwd_rtn)[source]
Calculate reconstruction loss.
- Parameters
x (list) – list of input data of type torch.Tensor.
fwd_rtn (dict) – dictionary containing list of data reconstructions (x_recon) and latent dimensions (z).
- Returns
dictionary containing reconstruction loss.
- Return type
losses (dict)
- class multiviewae.models.mAAE(cfg=None, input_dim=None, z_dim=None)[source]
Multi-view Adversarial Autoencoder model with a separate latent representation for each view.
- Parameters
cfg (str) –
Path to configuration file. Model specific parameters in addition to default parameters:
eps (float): Value added for numerical stability.
discriminator._target_ (multiviewae.architectures.mlp.Discriminator): Discriminator network class.
discriminator.hidden_layer_dim (list): Number of nodes per hidden layer.
discriminator.bias (bool): Whether to include a bias term in hidden layers.
discriminator.non_linear (bool): Whether to include a ReLU() function between layers.
discriminator.dropout_threshold (float): Dropout threshold of layers.
input_dim (list) – Dimensionality of the input data.
z_dim (int) – Number of latent dimensions.
- encode(x)[source]
Forward pass through encoder networks.
- Parameters
x (list) – list of input data of type torch.Tensor.
- Returns
list of latent dimensions for each view of type torch.Tensor.
- Return type
z (list)
- decode(z)[source]
Forward pass through decoder networks. Each latent is passed through all of the decoders.
- Parameters
z (list) – list of latent dimensions for each view of type torch.Tensor.
- Returns
list of decoding distributions.
- Return type
px_zs (list)
- disc(z)[source]
Forward pass of “real” samples from gaussian prior and “fake” samples from encoders through the discriminator network.
- Parameters
z (list) – list of latent dimensions for each view of type torch.Tensor.
- Returns
Discriminator network output for “real” samples (d_real) and list of discriminator network outputs for “fake” samples (d_fake).
- Return type
d_real (torch.Tensor), d_fake (list)
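The real/fake split described for disc can be sketched in plain Python (no torch) as below. This is an illustration only: `toy_discriminator` and `disc_sketch` are hypothetical stand-ins rather than multiviewae functions, and the standard normal prior is an assumption based on the "gaussian prior" wording above.

```python
import math
import random

def toy_discriminator(sample):
    """Hypothetical discriminator: logistic function of the summed latent."""
    return 1.0 / (1.0 + math.exp(-sum(sample)))

def disc_sketch(z, rng):
    """z: list (one entry per view) of batches of latent vectors."""
    batch_size = len(z[0])
    z_dim = len(z[0][0])
    # "Real" samples are drawn from the prior (assumed standard normal here).
    z_real = [[rng.gauss(0.0, 1.0) for _ in range(z_dim)] for _ in range(batch_size)]
    d_real = [toy_discriminator(s) for s in z_real]
    # "Fake" samples are the encoder outputs, scored separately for each view.
    d_fake = [[toy_discriminator(s) for s in z_view] for z_view in z]
    return d_real, d_fake

rng = random.Random(0)
# Two views, batch of two samples, z_dim = 2.
z = [[[0.1, -0.2], [0.3, 0.4]], [[1.0, 1.0], [-1.0, 0.5]]]
d_real, d_fake = disc_sketch(z, rng)
```

Here d_real holds one score per prior sample, while d_fake holds one list of scores per view, matching the return types listed above.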
- forward_recon(x)[source]
Apply encode and decode methods to input data to generate latent dimensions and data reconstructions.
- Parameters
x (list) – list of input data of type torch.Tensor.
- Returns
dictionary containing decoding distributions (px_zs) and latent dimensions (z).
- Return type
fwd_rtn (dict)
- forward_discrim(x)[source]
Apply encode and disc methods to input data to generate discriminator prediction on the latent dimensions and train discriminator parameters.
- Parameters
x (list) – list of input data of type torch.Tensor.
- Returns
dictionary containing discriminator output from “real” samples (d_real), discriminator output from “fake” samples (d_fake), and latent dimensions (z).
- Return type
fwd_rtn (dict)
- forward_gen(x)[source]
Apply encode and disc methods to input data to generate discriminator prediction on the latent dimensions and train encoder parameters.
- Parameters
x (list) – list of input data of type torch.Tensor.
- Returns
dictionary containing discriminator output from “fake” samples (d_fake) and latent dimensions (z).
- Return type
fwd_rtn (dict)
- recon_loss(x, fwd_rtn)[source]
Calculate reconstruction loss.
- Parameters
x (list) – list of input data of type torch.Tensor.
fwd_rtn (dict) – fwd_rtn from the forward_recon method.
- Returns
Reconstruction error.
- Return type
ll (torch.Tensor)
- class multiviewae.models.mWAE(cfg=None, input_dim=None, z_dim=None)[source]
Multi-view Adversarial Autoencoder model with Wasserstein loss.
Wasserstein autoencoders: https://arxiv.org/abs/1711.01558
- Parameters
cfg (str) –
Path to configuration file. Model specific parameters in addition to default parameters:
discriminator._target_ (multiviewae.architectures.mlp.Discriminator): Discriminator network class.
discriminator.hidden_layer_dim (list): Number of nodes per hidden layer.
discriminator.bias (bool): Whether to include a bias term in hidden layers.
discriminator.non_linear (bool): Whether to include a ReLU() function between layers.
discriminator.dropout_threshold (float): Dropout threshold of layers.
input_dim (list) – Dimensionality of the input data.
z_dim (int) – Number of latent dimensions.
- encode(x)[source]
Forward pass through encoder networks.
- Parameters
x (list) – list of input data of type torch.Tensor.
- Returns
list of latent dimensions for each view of type torch.Tensor.
- Return type
z (list)
- decode(z)[source]
Forward pass through decoder networks. Each latent is passed through all of the decoders.
- Parameters
z (list) – list of latent dimensions for each view of type torch.Tensor.
- Returns
list of decoding distributions.
- Return type
px_zs (list)
- disc(z)[source]
Forward pass of “real” samples from gaussian prior and “fake” samples from encoders through the discriminator network.
- Parameters
z (list) – list of latent dimensions for each view of type torch.Tensor.
- Returns
Discriminator network output for “real” samples (d_real) and list of discriminator network outputs for “fake” samples (d_fake).
- Return type
d_real (torch.Tensor), d_fake (list)
- forward_recon(x)[source]
Apply encode and decode methods to input data to generate latent dimensions and data reconstructions.
- Parameters
x (list) – list of input data of type torch.Tensor.
- Returns
dictionary containing decoding distributions (px_zs) and latent dimensions (z).
- Return type
fwd_rtn (dict)
- forward_discrim(x)[source]
Apply encode and disc methods to input data to generate discriminator prediction on the latent dimensions and train discriminator parameters.
- Parameters
x (list) – list of input data of type torch.Tensor.
- Returns
dictionary containing discriminator output from “real” samples (d_real), discriminator output from “fake” samples (d_fake), and latent dimensions (z).
- Return type
fwd_rtn (dict)
- forward_gen(x)[source]
Apply encode and disc methods to input data to generate discriminator prediction on the latent dimensions and train encoder parameters.
- Parameters
x (list) – list of input data of type torch.Tensor.
- Returns
dictionary containing discriminator output from “fake” samples (d_fake) and latent dimensions (z).
- Return type
fwd_rtn (dict)
- recon_loss(x, fwd_rtn)[source]
Calculate reconstruction loss.
- Parameters
x (list) – list of input data of type torch.Tensor.
fwd_rtn (dict) – fwd_rtn from the forward_recon method.
- Returns
Reconstruction error.
- Return type
ll (torch.Tensor)
- class multiviewae.models.mcVAE(cfg=None, input_dim=None, z_dim=None)[source]
Multi-Channel Variational Autoencoder and Sparse Multi-Channel Variational Autoencoder.
Code is based on: https://github.com/ggbioing/mcvae
- Parameters
cfg (str) –
Path to configuration file. Model specific parameters in addition to default parameters:
model.beta (int, float): KL divergence weighting term.
model.sparse (bool): Whether to enforce sparsity of the encoding distribution.
model.threshold (float): Dropout threshold applied to the latent dimensions. Default is 0.
encoder.default._target_ (multiviewae.architectures.mlp.VariationalEncoder): Type of encoder class to use.
encoder.default.enc_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Encoding distribution.
decoder.default._target_ (multiviewae.architectures.mlp.VariationalDecoder): Type of decoder class to use.
decoder.default.init_logvar (int, float): Initial value for log variance of decoder.
decoder.default.dec_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Decoding distribution.
input_dim (list) – Dimensionality of the input data.
z_dim (int) – Number of latent dimensions.
References
Antelmi, Luigi & Ayache, Nicholas & Robert, Philippe & Lorenzi, Marco. (2019). Sparse Multi-Channel Variational Autoencoder for the Joint Analysis of Heterogeneous Data.
- encode(x)[source]
Forward pass through encoder networks.
- Parameters
x (list) – list of input data of type torch.Tensor.
- Returns
list of encoding distributions for each view.
- Return type
qz_xs (list)
- decode(qz_xs)[source]
Forward pass through decoder networks. Each latent is passed through all of the decoders.
- Parameters
qz_xs (list) – list of encoding distributions for each view.
- Returns
A nested list of decoding distributions. The outer list has n_view elements, with the position in the list indicating the latent index. Each inner list has n_view elements, with the position in the list indicating the decoder index.
- Return type
px_zs (list)
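The nested return structure described above can be illustrated with a small plain-Python sketch. `decode_sketch` and the two toy decoders are hypothetical; they only mimic the indexing, not the actual decoder networks or distribution objects.

```python
def decode_sketch(latents, decoders):
    """Pass every latent through every decoder.

    Returns a nested list where out[i][j] is decoder j applied to latent i.
    """
    return [[decoder(z) for decoder in decoders] for z in latents]

# Hypothetical stand-ins for two view-specific decoders.
decoders = [
    lambda z: [2.0 * v for v in z],
    lambda z: [v + 1.0 for v in z],
]
latents = [[1.0, 2.0], [3.0, 4.0]]  # one latent per view

px_zs = decode_sketch(latents, decoders)
# Same-view reconstructions sit on the diagonal of the nested list.
same_view = [px_zs[i][i] for i in range(len(latents))]
# View 1 reconstructed from the latent of view 0.
cross_view = px_zs[0][1]
```

With this layout, cross-view reconstructions are read off the off-diagonal entries.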
- forward(x)[source]
Apply encode and decode methods to input data to generate latent dimensions and data reconstructions.
- Parameters
x (list) – list of input data of type torch.Tensor.
- Returns
dictionary containing encoding (qz_xs) and decoding (px_zs) distributions.
- Return type
fwd_rtn (dict)
- loss_function(x, fwd_rtn)[source]
Calculate mcVAE loss.
- Parameters
x (list) – list of input data of type torch.Tensor.
fwd_rtn (dict) – dictionary containing encoding and decoding distributions.
- Returns
dictionary containing each element of the mcVAE loss.
- Return type
losses (dict)
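model.beta is described above as the KL divergence weighting term. The sketch below shows how such a weight typically enters a VAE-style loss, assuming diagonal Gaussian encoding distributions and a standard normal prior. The function names are hypothetical, and the exact reduction over views and batch used by mcVAE is not specified in this reference.

```python
import math

def kl_to_standard_normal(mu, var):
    """KL( N(mu, diag(var)) || N(0, I) ), summed over latent dimensions."""
    return sum(0.5 * (v + m * m - 1.0 - math.log(v)) for m, v in zip(mu, var))

def weighted_loss(neg_log_lik, kls, beta=1.0):
    """Negative log-likelihood plus the beta-weighted sum of per-view KL terms."""
    return neg_log_lik + beta * sum(kls)

# View 0 matches the prior exactly; view 1 has one mean shifted by 1.
kl_view_0 = kl_to_standard_normal([0.0, 0.0], [1.0, 1.0])
kl_view_1 = kl_to_standard_normal([1.0, 0.0], [1.0, 1.0])
loss = weighted_loss(neg_log_lik=3.0, kls=[kl_view_0, kl_view_1], beta=2.0)
```

A larger beta penalises encoding distributions that drift away from the prior more heavily.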
- class multiviewae.models.mVAE(cfg=None, input_dim=None, z_dim=None)[source]
Multimodal Variational Autoencoder (MVAE).
- Parameters
cfg (str) –
Path to configuration file. Model specific parameters in addition to default parameters:
model.beta (int, float): KL divergence weighting term.
model.join_type (str): Method of combining encoding distributions.
model.warmup (int): KL term weighted by beta linearly increased to 1 over this many epochs.
model.use_prior (bool): Whether to use a prior expert when combining encoding distributions.
model.sparse (bool): Whether to enforce sparsity of the encoding distribution.
model.threshold (float): Dropout threshold applied to the latent dimensions. Default is 0.
model.weight_ll (bool): Whether to weight the log-likelihood loss by 1/n_views.
encoder.default._target_ (multiviewae.architectures.mlp.VariationalEncoder): Type of encoder class to use.
encoder.default.enc_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Encoding distribution.
decoder.default._target_ (multiviewae.architectures.mlp.VariationalDecoder): Type of decoder class to use.
decoder.default.init_logvar (int, float): Initial value for log variance of decoder.
decoder.default.dec_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Decoding distribution.
input_dim (list) – Dimensionality of the input data.
z_dim (int) – Number of latent dimensions.
References
Wu, M., & Goodman, N.D. (2018). Multimodal Generative Models for Scalable Weakly-Supervised Learning. NeurIPS.
- encode(x)[source]
Forward pass through encoder networks.
- Parameters
x (list) – list of input data of type torch.Tensor.
- Returns
Single element list of joint encoding distribution.
- Return type
(list)
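The accepted values of model.join_type are not listed in this reference. The cited Wu & Goodman (2018) paper combines the per-view encoding distributions with a Product-of-Experts, optionally including a prior expert. A minimal sketch for univariate Gaussian experts, assuming a standard normal prior expert, follows. `product_of_experts` is a hypothetical helper, not the library's implementation.

```python
def product_of_experts(mus, variances, use_prior=False):
    """Combine univariate Gaussian experts N(mu_i, var_i) into one Gaussian.

    The product of Gaussians is Gaussian, with precision equal to the sum of
    the expert precisions and a precision-weighted mean.
    """
    mus = list(mus)
    variances = list(variances)
    if use_prior:
        # Assumed prior expert: standard normal N(0, 1).
        mus.append(0.0)
        variances.append(1.0)
    precisions = [1.0 / v for v in variances]
    total_precision = sum(precisions)
    joint_var = 1.0 / total_precision
    joint_mu = sum(m * p for m, p in zip(mus, precisions)) / total_precision
    return joint_mu, joint_var

mu, var = product_of_experts([1.0, 3.0], [1.0, 1.0])
mu_p, var_p = product_of_experts([1.0, 3.0], [1.0, 1.0], use_prior=True)
```

Adding the prior expert pulls the joint mean towards zero and shrinks the joint variance, which is the effect model.use_prior would toggle under these assumptions.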
- encode_subset(x, subset)[source]
Forward pass through encoder networks for the specified subset.
- Parameters
x (list) – list of input data of type torch.Tensor.
subset (list) – list of modalities to encode.
- Returns
Single element list of joint encoding distribution.
- Return type
(list)
- decode(qz_x)[source]
Forward pass of joint latent dimensions through decoder networks.
- Parameters
qz_x (list) – Single element list containing joint encoding distribution.
- Returns
A nested list of decoding distributions, px_zs. The outer list has a single element indicating the shared latent dimensions. The inner list is a n_view element list with the position in the list indicating the decoder index.
- Return type
(list)
- decode_subset(qz_x, subset)[source]
Forward pass of joint latent dimensions through decoder networks for the specified subset.
- Parameters
qz_x (list) – Single element list containing joint encoding distribution.
subset (list) – list of modalities to decode.
- Returns
A nested list of decoding distributions, px_zs. The outer list has a single element indicating the shared latent dimensions. The inner list is a n_view element list with the position in the list indicating the decoder index.
- Return type
(list)
- forward(x)[source]
Apply encode and decode methods to input data to generate the joint latent dimensions and data reconstructions.
- Parameters
x (list) – list of input data of type torch.Tensor.
- Returns
dictionary containing encoding and decoding distributions.
- Return type
fwd_rtn (dict)
- calc_kl(qz_x)[source]
Calculate KL-divergence loss.
- Parameters
qz_x (list) – Single element list containing joint encoding distribution.
- Returns
KL-divergence loss.
- Return type
(torch.Tensor)
- calc_ll(x, px_zs)[source]
Calculate log-likelihood loss.
- Parameters
x (list) – list of input data of type torch.Tensor.
px_zs (list) – list of decoding distributions.
- Returns
Log-likelihood loss.
- Return type
ll (torch.Tensor)
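The model.weight_ll option above scales the log-likelihood loss by 1/n_views. The sketch below assumes Gaussian decoding distributions with diagonal variance and assumes the loss is the negative log-likelihood summed over views. Both are assumptions for illustration, and the helper names are hypothetical.

```python
import math

def gaussian_log_prob(x, mu, var):
    """Log density of x under N(mu, diag(var)), summed over features."""
    return sum(
        -0.5 * (math.log(2.0 * math.pi * v) + (xi - m) ** 2 / v)
        for xi, m, v in zip(x, mu, var)
    )

def log_likelihood_loss(x_views, recon_params, weight_ll=False):
    """Negative log-likelihood summed over views, optionally scaled by 1/n_views."""
    ll = sum(
        gaussian_log_prob(x, mu, var)
        for x, (mu, var) in zip(x_views, recon_params)
    )
    if weight_ll:
        ll = ll / len(x_views)
    return -ll

# Two views with 2 and 1 features; each decoder predicts its input exactly.
x_views = [[0.0, 0.0], [1.0]]
recon_params = [([0.0, 0.0], [1.0, 1.0]), ([1.0], [1.0])]
loss = log_likelihood_loss(x_views, recon_params)
loss_weighted = log_likelihood_loss(x_views, recon_params, weight_ll=True)
```

Scaling by 1/n_views keeps the reconstruction term on a comparable scale to the KL term as the number of views grows.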
- class multiviewae.models.JMVAE(cfg=None, input_dim=None, z_dim=None)[source]
JMVAE-kl.
- Parameters
cfg (str) –
Path to configuration file. Model specific parameters in addition to default parameters:
alpha (float): Weighting of KL-divergence loss from individual encoders.
encoder.default._target_ (multiviewae.architectures.mlp.VariationalEncoder): Type of encoder class to use.
encoder.default.enc_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Encoding distribution.
decoder.default._target_ (multiviewae.architectures.mlp.VariationalDecoder): Type of decoder class to use.
decoder.default.init_logvar (int, float): Initial value for log variance of decoder.
decoder.default.dec_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Decoding distribution.
input_dim (list) – Dimensionality of the input data.
z_dim (int) – Number of latent dimensions.
References
Suzuki, Masahiro & Nakayama, Kotaro & Matsuo, Yutaka. (2016). Joint Multimodal Learning with Deep Generative Models.
- encode(x)[source]
Forward pass through joint encoder network.
- Parameters
x (list) – list of input data of type torch.Tensor.
- Returns
Single element list containing joint encoding distribution, qz_xy.
- Return type
(list)
- encode_separate(x)[source]
Forward pass through separate encoder networks.
- Parameters
x (list) – list of input data of type torch.Tensor.
- Returns
Encoding distribution for modality X (qz_x) and encoding distribution for modality Y (qz_y).
- Return type
qz_x, qz_y
- decode(qz_x)[source]
Forward pass of joint latent dimensions through decoder networks.
- Parameters
qz_x (list) – Single element list containing joint encoding distribution.
- Returns
A nested list of decoding distributions, px_zs. The outer list has a single element indicating the shared latent dimensions. The inner list is a 2 element list with the position in the list indicating the decoder index.
- Return type
(list)
- forward(x)[source]
Apply encode and decode methods to input data to generate latent dimensions and data reconstructions.
- Parameters
x (list) – list of input data of type torch.Tensor.
- Returns
dictionary containing encoding and decoding distributions.
- Return type
fwd_rtn (dict)
- calc_kl(qz_xy, qz_x, qz_y)[source]
Calculate JMVAE-kl KL-divergence loss.
- Parameters
qz_xy (list) – Single element list containing shared encoding distribution.
qz_x (list) – Single element list containing encoding distribution for modality X.
qz_y (list) – Single element list containing encoding distribution for modality Y.
- Returns
KL-divergence loss.
- Return type
(torch.Tensor)
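The alpha-weighted part of the JMVAE-kl objective (Suzuki et al., 2016) pulls each single-modality encoder towards the joint encoder. A minimal sketch for diagonal Gaussians is given below. The helper names are hypothetical, and the remaining ELBO terms (reconstruction and KL to the prior) are omitted.

```python
import math

def kl_gaussians(mu_q, var_q, mu_p, var_p):
    """KL( N(mu_q, var_q) || N(mu_p, var_p) ) for diagonal Gaussians."""
    return sum(
        0.5 * (math.log(vp / vq) + (vq + (mq - mp) ** 2) / vp - 1.0)
        for mq, vq, mp, vp in zip(mu_q, var_q, mu_p, var_p)
    )

def jmvae_kl_term(joint, enc_x, enc_y, alpha=1.0):
    """alpha * [ KL(q(z|x,y) || q(z|x)) + KL(q(z|x,y) || q(z|y)) ]."""
    return alpha * (kl_gaussians(*joint, *enc_x) + kl_gaussians(*joint, *enc_y))

joint = ([0.0], [1.0])  # (means, variances) of q(z|x,y)
enc_x = ([0.0], [1.0])  # q(z|x): identical to the joint encoder
enc_y = ([1.0], [1.0])  # q(z|y): mean shifted by 1
penalty = jmvae_kl_term(joint, enc_x, enc_y, alpha=0.1)
```

In this toy case only the modality Y encoder disagrees with the joint encoder, so it alone contributes to the penalty.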
- calc_ll(x, px_zs)[source]
Calculate log-likelihood loss.
- Parameters
x (list) – list of input data of type torch.Tensor.
px_zs (list) – list of decoding distributions.
- Returns
Log-likelihood loss.
- Return type
ll (torch.Tensor)
- class multiviewae.models.me_mVAE(cfg=None, input_dim=None, z_dim=None)[source]
Multi-ELBO Multimodal Variational Autoencoder (me_mVAE).
The loss optimises the ELBO term from the joint posterior distribution, as well as the separate ELBO terms for each view. me_mVAE stands for multi-ELBO Multimodal VAE.
- Parameters
cfg (str) –
Path to configuration file. Model specific parameters in addition to default parameters:
model.beta (int, float): KL divergence weighting term.
model.join_type (str): Method of combining encoding distributions.
model.warmup (int): KL term weighted by beta linearly increased to 1 over this many epochs.
model.use_prior (bool): Whether to use a prior expert when combining encoding distributions.
model.sparse (bool): Whether to enforce sparsity of the encoding distribution.
model.threshold (float): Dropout threshold applied to the latent dimensions. Default is 0.
model.weight_kld (bool): Whether to weight the KL term by the number of views.
model.weight_ll (bool): Whether to weight the log-likelihood term by the number of views.
encoder.default._target_ (multiviewae.architectures.mlp.VariationalEncoder): Type of encoder class to use.
encoder.default.enc_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Encoding distribution.
decoder.default._target_ (multiviewae.architectures.mlp.VariationalDecoder): Type of decoder class to use.
decoder.default.init_logvar (int, float): Initial value for log variance of decoder.
decoder.default.dec_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Decoding distribution.
input_dim (list) – Dimensionality of the input data.
z_dim (int) – Number of latent dimensions.
References
Wu, M., & Goodman, N.D. (2018). Multimodal Generative Models for Scalable Weakly-Supervised Learning. NeurIPS.
- encode(x)[source]
Forward pass through encoder networks.
- Parameters
x (list) – list of input data of type torch.Tensor.
- Returns
Single element list containing the joint encoding distribution, and a list containing the separate encoding distributions.
- Return type
(list), (list)
- encode_subset(x, subset)[source]
Forward pass through encoder networks for the specified subset.
- Parameters
x (list) – list of input data of type torch.Tensor.
subset (list) – list of modalities to encode.
- Returns
Single element list of joint encoding distribution.
- Return type
(list)
- decode(qz_x)[source]
Forward pass of joint latent dimensions through decoder networks.
- Parameters
qz_x (list) – Single element list containing joint encoding distribution.
- Returns
A nested list of decoding distributions, px_zs. The outer list has a single element indicating the shared latent dimensions. The inner list is a n_view element list with the position in the list indicating the decoder index.
- Return type
(list)
- decode_subset(qz_x, subset)[source]
Forward pass of joint latent dimensions through decoder networks for the specified subset.
- Parameters
qz_x (list) – Single element list containing joint encoding distribution.
subset (list) – list of modalities to decode.
- Returns
A nested list of decoding distributions, px_zs. The outer list has a single element indicating the shared latent dimensions. The inner list is a n_view element list with the position in the list indicating the decoder index.
- Return type
(list)
- decode_separate(qz_xs)[source]
Forward pass of each view specific latent dimensions through the respective decoder network.
- Parameters
qz_xs (list) – list of view-specific encoding distributions.
- Returns
A nested list of decoding distributions, px_zs. The outer list has a single element indicating the view specific latent dimensions. The inner list is a n_view element list with the position in the list indicating the decoder index.
- Return type
(list)
- forward(x)[source]
Apply encode, decode, encode_separate and decode_separate methods to input data to generate the joint and modality specific latent dimensions and data reconstructions.
- Parameters
x (list) – list of input data of type torch.Tensor.
- Returns
dictionary containing encoding and decoding distributions.
- Return type
fwd_rtn (dict)
- calc_kl(qz_xs)[source]
Calculate KL-divergence loss.
- Parameters
qz_xs (list) – list of encoding distributions.
- Returns
KL-divergence loss.
- Return type
(torch.Tensor)
- calc_ll(x, px_zs)[source]
Calculate log-likelihood loss.
- Parameters
x (list) – list of input data of type torch.Tensor.
px_zs (list) – list of decoding distributions.
- Returns
Log-likelihood loss.
- Return type
ll (torch.Tensor)
- loss_function(x, fwd_rtn)[source]
Calculate multi ELBO Multimodal VAE loss.
- Parameters
x (list) – list of input data of type torch.Tensor.
fwd_rtn (dict) – dictionary containing encoding and decoding distributions.
- Returns
dictionary containing each element of the MVAE loss.
- Return type
losses (dict)
- class multiviewae.models.mmVAE(cfg=None, input_dim=None, z_dim=None)[source]
Mixture-of-Experts Multimodal Variational Autoencoder (MMVAE).
Code is based on: https://github.com/iffsid/mmvae
- Parameters
cfg (str) –
Path to configuration file. Model specific parameters in addition to default parameters:
model.K (int): Number of samples to take from encoding distribution.
model.DREG_loss (bool): Whether to use DReG estimator when using large K value.
encoder.default._target_ (multiviewae.architectures.mlp.VariationalEncoder): Type of encoder class to use.
encoder.default.enc_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Encoding distribution.
decoder.default._target_ (multiviewae.architectures.mlp.VariationalDecoder): Type of decoder class to use.
decoder.default.init_logvar (int, float): Initial value for log variance of decoder.
decoder.default.dec_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Decoding distribution.
input_dim (list) – Dimensionality of the input data.
z_dim (int) – Number of latent dimensions.
References
Shi, Y., Siddharth, N., Paige, B., & Torr, P.H. (2019). Variational Mixture-of-Experts Autoencoders for Multi-Modal Deep Generative Models. ArXiv, abs/1911.03393.
- encode(x)[source]
Forward pass through encoder networks.
- Parameters
x (list) – list of input data of type torch.Tensor.
- Returns
list of encoding distributions.
- Return type
(list)
- encode_subset(x, subset)[source]
Forward pass through encoder networks for a subset of modalities.
- Parameters
x (list) – list of input data of type torch.Tensor.
subset (list) – list of modalities to encode.
- Returns
list of encoding distributions.
- Return type
(list)
- decode(qz_xs)[source]
Forward pass through decoder networks. Each latent is passed through all of the decoders.
- Parameters
qz_xs (list) – list of encoding distributions.
- Returns
A nested list of decoding distributions. The outer list has n_view elements, with the position in the list indicating the latent index. Each inner list has n_view elements, with the position in the list indicating the decoder index.
- Return type
(list)
- decode_subset(qz_xs, subset)[source]
Forward pass through decoder networks for a subset of modalities. Each latent is passed through its own decoder.
- Parameters
qz_xs (list) – list of encoding distributions.
subset (list) – list of modalities to decode.
- Returns
A nested list of decoding distributions. The outer list is a single element list, the inner list is a subset element list of decoding distributions.
- Return type
(list)
- forward(x)[source]
Apply encode and decode methods to input data to generate latent dimensions and data reconstructions.
- Parameters
x (list) – list of input data of type torch.Tensor.
- Returns
dictionary containing encoding (qz_xs) and decoding (px_zs) distributions.
- Return type
fwd_rtn (dict)
- loss_function(x, fwd_rtn)[source]
Wrapper function for mmVAE loss.
- Parameters
x (list) – list of input data of type torch.Tensor.
fwd_rtn (dict) – dictionary containing encoding and decoding distributions.
- Returns
dictionary containing mmVAE loss.
- Return type
losses (dict)
- moe_iwae(x, qz_xs, px_zs)[source]
Calculate Mixture-of-Experts importance weighted autoencoder (IWAE) loss used for the mmVAE model.
- Parameters
x (list) – list of input data of type torch.Tensor.
qz_xs (list) – list of encoding distributions.
px_zs (list) – list of decoding distributions.
- Returns
Mixture-of-Experts IWAE loss.
- Return type
(torch.Tensor)
- log_mean_exp(value, dim=0, keepdim=False)[source]
Returns the log of the mean of the exponentials along the given dimension (dim).
- Parameters
value (torch.Tensor) – the input tensor.
dim (int, optional) – the dimension along which to take the mean.
keepdim (bool, optional) – whether the output tensor has dim retained or not.
- Returns
the output tensor.
- Return type
(torch.Tensor)
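The quantity computed by log_mean_exp can be written as log-sum-exp minus the log of the number of terms. The plain-Python sketch below works on a flat list rather than along a tensor dimension, so it illustrates the arithmetic only and does not reproduce the dim and keepdim arguments.

```python
import math

def log_mean_exp(values):
    """Numerically stable log(mean(exp(v))) over a flat list of floats."""
    m = max(values)
    # Subtracting the maximum keeps every exponent <= 0, avoiding overflow.
    return m + math.log(sum(math.exp(v - m) for v in values) / len(values))

# Hypothetical log importance weights for K = 3 samples; a naive
# log(mean(exp(...))) would underflow to log(0) for values this small.
log_weights = [-1000.0, -1001.0, -1002.0]
iwae_estimate = log_mean_exp(log_weights)
```

In an importance weighted bound, the values would be the per-sample log weights for the K samples drawn from the encoding distribution.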
- class multiviewae.models.mvtCAE(cfg=None, input_dim=None, z_dim=None)[source]
Multi-View Total Correlation Auto-Encoder (MVTCAE).
Code is based on: https://github.com/gr8joo/MVTCAE
NOTE: This implementation currently only caters for a PoE posterior distribution. MoE and MoPoE posteriors will be included in further work.
- Parameters
cfg (str) –
Path to configuration file. Model specific parameters in addition to default parameters:
model.beta (int, float): KL divergence weighting term.
model.alpha (int, float): Log likelihood, Conditional VIB and VIB weighting term.
encoder.default._target_ (multiviewae.architectures.mlp.VariationalEncoder): Type of encoder class to use.
encoder.default.enc_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Encoding distribution.
decoder.default._target_ (multiviewae.architectures.mlp.VariationalDecoder): Type of decoder class to use.
decoder.default.init_logvar (int, float): Initial value for log variance of decoder.
decoder.default.dec_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Decoding distribution.
input_dim (list) – Dimensionality of the input data.
z_dim (int) – Number of latent dimensions.
References
Hwang, HyeongJoo and Kim, Geon-Hyeong and Hong, Seunghoon and Kim, Kee-Eung. Multi-View Representation Learning via Total Correlation Objective. 2021. NeurIPS
- encode(x)[source]
Forward pass through encoder networks.
- Parameters
x (list) – list of input data of type torch.Tensor.
- Returns
The separate and/or joint encoding distributions, depending on whether the model is in the training stage: list containing separate encoding distributions (qz_xs) and single element list containing the PoE joint encoding distribution (qz_x).
- Return type
qz_xs (list), qz_x (list)
- encode_subset(x, subset)[source]
Forward pass through encoder networks for a subset of modalities.
- Parameters
x (list) – list of input data of type torch.Tensor.
subset (list) – list of modalities to encode.
- Returns
Either the joint or the separate encoding distributions, depending on whether the model is in the training stage: list containing separate encoding distributions (qz_xs) or single element list containing the PoE joint encoding distribution (qz_x).
- Return type
qz_xs (list), qz_x (list)
- decode(qz_x)[source]
Forward pass of joint latent dimensions through decoder networks.
- Parameters
qz_x (list) – list of joint encoding distribution.
- Returns
A nested list of decoding distributions, px_zs. The outer list has a single element indicating the shared latent dimensions. The inner list is a n_view element list with the position in the list indicating the decoder index.
- Return type
(list)
- decode_subset(qz_x, subset)[source]
Forward pass of joint latent dimensions through decoder networks for a subset of modalities.
- forward(x)[source]
Apply encode and decode methods to input data to generate the joint latent dimensions and data reconstructions.
- Parameters
x (list) – list of input data of type torch.Tensor.
- Returns
dictionary containing encoding and decoding distributions.
- Return type
fwd_rtn (dict)
- loss_function(x, fwd_rtn)[source]
Calculate MVTCAE loss.
- Parameters
x (list) – list of input data of type torch.Tensor.
fwd_rtn (dict) – dictionary containing encoding and decoding distributions.
- Returns
dictionary containing each element of the MVTCAE loss.
- Return type
losses (dict)
- calc_kl_cvib(qz_x, qz_xs)[source]
Calculate KL-divergence between PoE joint encoding distribution and the encoding distribution for each view.
- Parameters
qz_x (list) – Single element list containing PoE joint encoding distribution.
qz_xs (list) – list of encoding distributions of each view.
- Returns
KL-divergence loss.
- Return type
kl (torch.Tensor)
- class multiviewae.models.DVCCA(cfg=None, input_dim=None, z_dim=None)[source]
Deep Variational Canonical Correlation Analysis (DVCCA).
- Parameters
cfg (str) –
Path to configuration file. Model specific parameters in addition to default parameters:
model.beta (int, float): KL divergence weighting term.
model.private (bool): Whether to include private view-specific latent dimensions.
model.sparse (bool): Whether to enforce sparsity of the encoding distribution.
model.threshold (float): Dropout threshold applied to the latent dimensions. Default is 0.
encoder.default._target_ (multiviewae.architectures.mlp.VariationalEncoder): Type of encoder class to use.
encoder.default.enc_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Encoding distribution.
decoder.default._target_ (multiviewae.architectures.mlp.VariationalDecoder): Type of decoder class to use.
decoder.default.init_logvar (int, float): Initial value for log variance of decoder.
decoder.default.dec_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Decoding distribution.
input_dim (list) – Dimensionality of the input data.
z_dim (int) – Number of latent dimensions.
References
Wang, Weiran & Lee, Honglak & Livescu, Karen. (2016). Deep Variational Canonical Correlation Analysis.
- configure_optimizers()[source]
Configure optimizers for encoder, private encoder, and decoder network parameters.
- Returns
list of Adam optimizers for encoders and decoders.
- Return type
optimizers (list)
- encode(x)[source]
Forward pass through encoder network. For DVCCA-private a forward pass is performed through each private encoder and the output latent is concatenated with the shared latent.
- Parameters
x (list) – list of input data of type torch.Tensor.
- Returns
A combination of the following, depending on the training stage and model type: list containing the shared encoding distribution (qz_x); list of encoding distributions for shared and private latents of DVCCA-private (qz_xs); list of encoding distributions for private latents of DVCCA-private (qh_xs).
- Return type
qz_x (list), qz_xs (list), qh_xs (list)
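For DVCCA-private, the encode description above says each private latent is concatenated with the shared latent. The sketch below shows this on (means, variances) pairs for diagonal Gaussians. `concat_shared_private` is hypothetical, and placing the shared dimensions first is an assumption.

```python
def concat_shared_private(shared, private):
    """Join the parameters of a shared and a private diagonal Gaussian latent.

    Each argument is a (means, variances) pair; the result describes a latent
    whose first dimensions are shared and whose remaining ones are private.
    """
    (mu_s, var_s), (mu_p, var_p) = shared, private
    return mu_s + mu_p, var_s + var_p

shared = ([0.1, 0.2], [1.0, 1.0])  # shared latent, z_dim = 2
privates = [([0.5], [0.9]), ([-0.3], [1.1])]  # one private latent per view

qz_xs = [concat_shared_private(shared, p) for p in privates]
```

Each view then has its own combined latent, which is what the decode step receives for DVCCA-private.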
- decode(qz_x)[source]
Forward pass through decoder networks.
- Parameters
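qz_x (list) – list of encoding distributions: the shared encoding distribution for DVCCA, or the joint distribution of the shared and private latents for each view for DVCCA-private.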
- Returns
A nested list of decoding distributions, px_zs. The outer list has a single element indicating the shared or shared and private latent dimensions. The inner list is a n_view element list with the position in the list indicating the decoder index.
- Return type
(list)
- forward(x)[source]
Apply encode and decode methods to input data to generate latent dimensions and data reconstructions. For DVCCA, the shared encoding distribution is passed to the decode method. For DVCCA-private, the joint distribution of the shared and private latents for each view is passed to the decode method.
- Parameters
x (list) – list of input data of type torch.Tensor.
- Returns
dictionary containing list of decoding distributions (px_zs), shared encoding distribution (qz_x), and (for DVCCA-private) private encoding distributions (qh_xs).
- Return type
fwd_rtn (dict)
- loss_function(x, fwd_rtn)[source]
Calculate DVCCA loss.
- Parameters
x (list) – list of input data of type torch.Tensor.
fwd_rtn (dict) – dictionary containing list of decoding distributions (px_zs), shared encoding distribution (qz_x), and (for DVCCA-private) private encoding distributions (qh_xs).
- Returns
dictionary containing each element of the DVCCA loss.
- Return type
losses (dict)
- calc_kl(qz_x, qh_xs)[source]
Wrapper function for calculating KL-divergence loss.
- Parameters
qz_x (list) – Single element list containing shared encoding distribution.
qh_xs (list) – list of private encoding distributions for DVCCA-private.
- Returns
KL-divergence loss across all views.
- Return type
(torch.Tensor)
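As an illustration of the quantity calc_kl wraps, the following is a minimal sketch of the KL divergence between a diagonal Gaussian encoding distribution and a standard-normal prior, summed over the shared latent and any private latents. It is not the library's implementation: the standard-normal prior, the plain-list inputs, and the helper names kl_to_standard_normal and total_kl are assumptions made for this example.

```python
import math

def kl_to_standard_normal(mu, logvar):
    """KL( N(mu, diag(exp(logvar))) || N(0, I) ), summed over latent dimensions."""
    return 0.5 * sum(math.exp(lv) + m * m - 1.0 - lv for m, lv in zip(mu, logvar))

def total_kl(shared, privates=()):
    """Add the shared-latent KL to the KL of every private latent (hypothetical helper)."""
    kl = kl_to_standard_normal(*shared)
    for private in privates:
        kl += kl_to_standard_normal(*private)
    return kl

# Each distribution is a (mu, logvar) pair of plain Python lists.
shared = ([0.0, 0.0], [0.0, 0.0])
privates = [([1.0], [0.0]), ([0.0], [0.0])]

kl_shared_only = total_kl(shared)              # DVCCA-style: shared latent only
kl_with_private = total_kl(shared, privates)   # DVCCA-private-style: shared + private
print(kl_shared_only, kl_with_private)
```

A distribution that already equals the prior contributes nothing, so only the first private latent (mean 1) adds to the total here.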
- class multiviewae.models.MoPoEVAE(cfg=None, input_dim=None, z_dim=None)[source]
Mixture-of-Product-of-Experts Variational Autoencoder.
Code is based on: https://github.com/thomassutter/MoPoE
- Parameters
cfg (str) –
Path to configuration file. Model specific parameters in addition to default parameters:
model.beta (int, float): KL divergence weighting term.
encoder.default._target_ (multiviewae.architectures.mlp.VariationalEncoder): Type of encoder class to use.
encoder.default.enc_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Encoding distribution.
decoder.default._target_ (multiviewae.architectures.mlp.VariationalDecoder): Type of decoder class to use.
decoder.default.init_logvar (int, float): Initial value for log variance of decoder.
decoder.default.dec_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Decoding distribution.
input_dim (list) – Dimensionality of the input data.
z_dim (int) – Number of latent dimensions.
References
Sutter, Thomas & Daunhawer, Imant & Vogt, Julia. (2021). Generalized Multimodal ELBO.
- encode(x)[source]
Forward pass through encoder networks.
- Parameters
x (list) – list of input data of type torch.Tensor.
- Returns
list containing the MoE joint encoding distribution. If training, the model also returns the encoding distribution for each subset.
- Return type
(list)
- encode_subset(x, subset)[source]
Forward pass through encoder networks for a subset of modalities.
- Parameters
x (list) – list of input data of type torch.Tensor.
subset (list) – list of modalities to encode.
- Returns
list containing the MoE joint encoding distribution.
- Return type
(list)
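The subset encoding can be pictured with a small sketch: enumerate every non-empty subset of views and fuse the per-view Gaussian experts of each subset with a product of experts. This is a simplified illustration rather than the library's code; it assumes one-dimensional Gaussian experts, no prior expert in the product, and the helper names nonempty_subsets and product_of_experts are hypothetical.

```python
import math
from itertools import combinations

def nonempty_subsets(n_views):
    """All non-empty subsets of view indices, smallest subsets first."""
    views = range(n_views)
    return [list(c) for r in range(1, n_views + 1) for c in combinations(views, r)]

def product_of_experts(mus, logvars):
    """Precision-weighted fusion of 1-D Gaussian experts; returns (mean, log-variance)."""
    precisions = [math.exp(-lv) for lv in logvars]
    total = sum(precisions)
    mu = sum(m * p for m, p in zip(mus, precisions)) / total
    return mu, -math.log(total)

# Per-view encoder outputs for three views (illustrative values only).
mus, logvars = [0.0, 2.0, 4.0], [0.0, 0.0, 0.0]

subsets = nonempty_subsets(3)
experts = [
    product_of_experts([mus[i] for i in s], [logvars[i] for i in s])
    for s in subsets
]
for s, (mu, lv) in zip(subsets, experts):
    print(s, round(mu, 3), round(lv, 3))
```

Under these assumptions, a mixture-of-experts joint would then treat the subset posteriors as equally weighted mixture components.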
- decode(qz_x)[source]
Forward pass of joint latent dimensions through decoder networks.
- Parameters
qz_x (list) – list containing the joint encoding distribution.
- Returns
A nested list of decoding distributions, px_zs. The outer list has a single element indicating the shared latent dimensions. The inner list is an n_view-element list in which the position indicates the decoder index.
- Return type
(list)
- forward(x)[source]
Apply encode and decode methods to input data to generate the joint and subset latent dimensions and data reconstructions.
- Parameters
x (list) – list of input data of type torch.Tensor.
- Returns
dictionary containing encoding and decoding distributions.
- Return type
fwd_rtn (dict)
- loss_function(x, fwd_rtn)[source]
Calculate MoPoE VAE loss.
- Parameters
x (list) – list of input data of type torch.Tensor.
fwd_rtn (dict) – dictionary containing encoding and decoding distributions.
- Returns
dictionary containing each element of the MoPoE VAE loss.
- Return type
losses (dict)
- calc_kl_moe(qz_xs)[source]
Calculate KL-divergence between each PoE subset posterior and the prior distribution.
- Parameters
qz_xs (list) – list of encoding distributions.
- Returns
KL-divergence loss.
- Return type
(torch.Tensor)
- class multiviewae.models.mmJSD(cfg=None, input_dim=None, z_dim=None)[source]
Multimodal Jensen-Shannon divergence (mmJSD) model with Product-of-Experts dynamic prior.
Code is based on: https://github.com/thomassutter/mmjsd
- Parameters
cfg (str) –
Path to configuration file. Model specific parameters in addition to default parameters:
model.private (bool): Whether to include private modality-specific latent dimensions.
model.beta (int, float): KL divergence weighting term.
model.alpha (int, float): JSD divergence weighting term.
model.s_dim (int): Number of private latent dimensions.
encoder.default._target_ (multiviewae.architectures.mlp.VariationalEncoder): Type of encoder class to use.
encoder.default.enc_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Encoding distribution.
decoder.default._target_ (multiviewae.architectures.mlp.VariationalDecoder): Type of decoder class to use.
decoder.default.init_logvar (int, float): Initial value for log variance of decoder.
decoder.default.dec_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Decoding distribution.
input_dim (list) – Dimensionality of the input data.
z_dim (int) – Number of latent dimensions.
References
Sutter, Thomas & Daunhawer, Imant & Vogt, Julia. (2021). Multimodal Generative Learning Utilizing Jensen-Shannon-Divergence. Advances in Neural Information Processing Systems. 33.
- encode(x)[source]
Forward pass through encoder network. If self.private=True, the first two dimensions of each latent are used for the modality-specific part and the remaining dimensions for the joint content.
- Parameters
x (list) – list of input data of type torch.Tensor.
- Returns
A combination of the following, depending on the training stage and model type:
qz_xs (list): Single element list containing the MoE encoding distribution for self.private=False.
qzs_xs (list): list containing each encoding distribution for self.private=False.
qz_x (list): Single element list containing the PoE encoding distribution for self.private=False.
qc_x (list): Single element list containing the PoE shared encoding distribution.
qscs_xs (list): list containing combined shared and private latents.
qs_xs (list): list of encoding distributions for private latents.
qcs_xs (list): list containing encoding distributions for shared latent dimensions for each view.
- encode_subset(x, subset)[source]
Forward pass through encoder networks for a subset of modalities.
- Parameters
x (list) – list of input data of type torch.Tensor.
subset (list) – list of modalities to encode.
- Returns
list containing the PoE joint encoding distribution.
- Return type
(list)
- decode(qz_x)[source]
Forward pass of latent dimensions through decoder networks.
- Parameters
qz_x (list) – list of encoding distributions.
- Returns
A nested list of decoding distributions, px_zs. The outer list has a single element; the inner list is an n_view-element list in which the position indicates the decoder index.
- Return type
(list)
- decode_subset(qz_x, subset)[source]
Forward pass of latent dimensions through decoder networks for a subset of modalities.
- Parameters
qz_x (list) – list of encoding distributions.
subset (list) – list of modalities to decode.
- Returns
A nested list of decoding distributions, px_zs. The outer list has a single element; the inner list is an n_view-element list in which the position indicates the decoder index.
- Return type
(list)
- forward(x)[source]
Apply encode and decode methods to input data to generate the joint and modality specific latent dimensions and data reconstructions.
- Parameters
x (list) – list of input data of type torch.Tensor.
- Returns
dictionary containing encoding and decoding distributions.
- Return type
fwd_rtn (dict)
- calc_kl(qz_xs)[source]
Calculate KL-divergence loss.
- Parameters
qz_xs (list) – list of encoding distributions.
- Returns
KL-divergence loss.
- Return type
(torch.Tensor)
- calc_ll(x, px_zs)[source]
Calculate log-likelihood loss.
- Parameters
x (list) – list of input data of type torch.Tensor.
px_zs (list) – list of decoding distributions.
- Returns
Log-likelihood loss.
- Return type
ll (torch.Tensor)
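To make the log-likelihood term concrete, here is a minimal sketch of a Gaussian log-likelihood with a single log-variance per decoder, summed over features and views. It is an illustration only: the Normal decoding distribution, the scalar logvar (standing in for a value such as decoder.default.init_logvar), and the function names are assumptions, and the sign convention the library uses when turning this into a loss is not shown here.

```python
import math

def gaussian_log_likelihood(x, mu, logvar):
    """Sum over features of log N(x_j; mu_j, exp(logvar)) for one view."""
    var = math.exp(logvar)
    return sum(
        -0.5 * (math.log(2.0 * math.pi) + logvar + (xj - mj) ** 2 / var)
        for xj, mj in zip(x, mu)
    )

def total_log_likelihood(xs, recon_mus, logvar=0.0):
    """Add up the per-view log-likelihoods (hypothetical helper)."""
    return sum(gaussian_log_likelihood(x, mu, logvar) for x, mu in zip(xs, recon_mus))

# Two views with 2 and 1 features; plain lists stand in for tensors.
xs = [[1.0, 2.0], [3.0]]
perfect = total_log_likelihood(xs, recon_mus=[[1.0, 2.0], [3.0]])
shifted = total_log_likelihood(xs, recon_mus=[[0.0, 2.0], [3.0]])
print(perfect, shifted)
```

A reconstruction that misses the data by one unit in a single feature lowers the log-likelihood by 0.5 when the variance is 1.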
- calc_jsd(qcs_xs, qc_x)[source]
Calculate Jensen-Shannon Divergence loss.
- Parameters
qcs_xs (list) – list of encoding distributions of each view for shared latent dimensions.
qc_x (list) – Dynamic prior given by PoE of shared encoding distributions.
- Returns
Jensen-Shannon Divergence loss.
- Return type
jsd (torch.Tensor)
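The idea of a Jensen-Shannon divergence against a dynamic prior can be sketched as follows: build the prior as a product of experts over the per-view shared distributions, then average the KL divergence from each view to that prior. This is a simplified sketch, not the library's implementation; it assumes one-dimensional Gaussians, uniform weights, and no static prior term, and the function names are hypothetical.

```python
import math

def kl_gaussians(mu_q, logvar_q, mu_p, logvar_p):
    """KL( N(mu_q, exp(logvar_q)) || N(mu_p, exp(logvar_p)) ) for 1-D Gaussians."""
    return 0.5 * (
        logvar_p - logvar_q
        + (math.exp(logvar_q) + (mu_q - mu_p) ** 2) / math.exp(logvar_p)
        - 1.0
    )

def poe(mus, logvars):
    """Product of 1-D Gaussian experts; returns (mean, log-variance)."""
    precisions = [math.exp(-lv) for lv in logvars]
    total = sum(precisions)
    mu = sum(m * p for m, p in zip(mus, precisions)) / total
    return mu, -math.log(total)

def jsd_to_dynamic_prior(mus, logvars):
    """Uniformly weighted average of KL(q_i || PoE prior) over views."""
    prior_mu, prior_logvar = poe(mus, logvars)
    kls = [kl_gaussians(m, lv, prior_mu, prior_logvar) for m, lv in zip(mus, logvars)]
    return sum(kls) / len(kls)

agree = jsd_to_dynamic_prior([1.0, 1.0], [0.0, 0.0])
disagree = jsd_to_dynamic_prior([0.0, 2.0], [0.0, 0.0])
print(agree, disagree)
```

Even when the views agree on the mean, the product of experts is narrower than each expert, so the value is small but not zero; it grows as the view means move apart.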
- class multiviewae.models.weighted_mVAE(cfg=None, input_dim=None, z_dim=None)[source]
Generalised Product-of-Experts Variational Autoencoder (gPoE-MVAE).
- Parameters
cfg (str) –
Path to configuration file. Model specific parameters in addition to default parameters:
model.beta (int, float): KL divergence weighting term.
model.private (bool): Whether to include private view-specific latent dimensions.
model.sparse (bool): Whether to enforce sparsity of the encoding distribution.
model.threshold (float): Dropout threshold applied to the latent dimensions. Default is 0.
encoder.default._target_ (multiviewae.architectures.mlp.VariationalEncoder): Type of encoder class to use.
encoder.default.enc_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Encoding distribution.
decoder.default._target_ (multiviewae.architectures.mlp.VariationalDecoder): Type of decoder class to use.
decoder.default.init_logvar (int, float): Initial value for log variance of decoder.
decoder.default.dec_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Decoding distribution.
input_dim (list) – Dimensionality of the input data.
z_dim (int) – Number of latent dimensions.
References
Cao, Y., & Fleet, D. (2014). Generalized Product of Experts for Automatic and Principled Fusion of Gaussian Process Predictions. arXiv.
Lawry Aguila, A., Chapman, J., Altmann, A. (2023). Multi-modal Variational Autoencoders for normative modelling across multiple imaging modalities. arXiv.
- encode(x)[source]
Forward pass through encoder networks.
- Parameters
x (list) – list of input data of type torch.Tensor.
- Returns
A combination of the following, depending on the training stage and model type:
qz_x (list): Single element list containing the PoE encoding distribution for self.private=False.
qc_x (list): Single element list containing the PoE shared encoding distribution.
qs_xs (list): list of encoding distributions for private latents.
qscs_xs (list): list containing combined shared and private latents.
qcs_xs (list): list containing encoding distributions for shared latent dimensions for each view.
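A generalised (weighted) product of experts differs from the plain product only in that each expert's precision is scaled by a weight before fusion. The sketch below shows this for one-dimensional Gaussians. It is an illustration under assumptions: the weights are fixed numbers passed in by hand, no prior expert is included, and weighted_poe is a hypothetical helper rather than a function of the library.

```python
import math

def weighted_poe(mus, logvars, weights):
    """Fuse 1-D Gaussian experts with per-expert weights; returns (mean, variance)."""
    precisions = [w * math.exp(-lv) for w, lv in zip(weights, logvars)]
    total = sum(precisions)
    mu = sum(m * p for m, p in zip(mus, precisions)) / total
    return mu, 1.0 / total

mus, logvars = [0.0, 4.0], [0.0, 0.0]

equal = weighted_poe(mus, logvars, weights=[1.0, 1.0])    # plain PoE
skewed = weighted_poe(mus, logvars, weights=[3.0, 1.0])   # trust the first view more
print(equal, skewed)
```

With equal weights the fused mean sits midway between the experts; increasing the first weight pulls the mean towards the first view and shrinks the fused variance.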
- decode(qz_x)[source]
Forward pass of joint latent dimensions through decoder networks.
- Parameters
qz_x (list) – list containing the joint encoding distribution.
- Returns
A nested list of decoding distributions, px_zs. The outer list has a single element indicating the shared latent dimensions. The inner list is an n_view-element list in which the position indicates the decoder index.
- Return type
(list)
- forward(x)[source]
Apply encode and decode methods to input data to generate the joint latent dimensions and data reconstructions.
- Parameters
x (list) – list of input data of type torch.Tensor.
- Returns
dictionary containing encoding and decoding distributions.
- Return type
fwd_rtn (dict)
- calc_kl(qz_x)[source]
Calculate KL-divergence loss.
- Parameters
qz_x (list) – Single element list containing joint encoding distribution.
- Returns
KL-divergence loss.
- Return type
(torch.Tensor)
- calc_kl_separate(qc_xs)[source]
Calculate KL-divergence loss.
- Parameters
qc_xs (list) – list of encoding distributions for private/shared latent dimensions for each view.
- Returns
KL-divergence loss.
- Return type
(torch.Tensor)
- class multiviewae.models.DMVAE(cfg=None, input_dim=None, z_dim=None)[source]
Disentangled multi-modal variational autoencoder (DMVAE).
- Parameters
cfg (str) –
Path to configuration file. Model specific parameters in addition to default parameters:
model._lambda (list, optional): Log likelihood weighting term for each modality.
model.s_dim (int): Number of private latent dimensions.
model.beta (int, float): KL divergence weighting term.
encoder.default._target_ (multiviewae.architectures.mlp.VariationalEncoder): Type of encoder class to use.
encoder.default.enc_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Encoding distribution.
decoder.default._target_ (multiviewae.architectures.mlp.VariationalDecoder): Type of decoder class to use.
decoder.default.init_logvar (int, float): Initial value for log variance of decoder.
decoder.default.dec_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Decoding distribution.
input_dim (list) – Dimensionality of the input data.
z_dim (int) – Number of latent dimensions.
References
M. Lee and V. Pavlovic, “Private-Shared Disentangled Multimodal VAE for Learning of Latent Representations,” 2021 IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops (CVPRW), Nashville, TN, USA, 2021, pp. 1692-1700, doi: 10.1109/CVPRW53098.2021.00185.
- encode(x)[source]
Forward pass through encoder networks.
- Parameters
x (list) – list of input data of type torch.Tensor.
- Returns
qc_x (list): Single element list containing the PoE shared encoding distribution.
qcs_xs (list): list containing encoding distributions for shared latent dimensions for each view.
qs_xs (list): list of encoding distributions for private latents.
qscs_xs (list): nested list containing combined PoE shared and private latents.
qscss_xs (list): nested list containing combined shared latents from each modality and private latents for same and cross view reconstruction.
- decode(qz_x)[source]
Forward pass of joint latent dimensions through decoder networks.
- Parameters
qz_x (list) – list containing the joint encoding distribution.
- Returns
A nested list of decoding distributions, px_zs. The outer list has a single element indicating the shared latent dimensions. The inner list is an n_view-element list in which the position indicates the decoder index.
- Return type
(list)
- decode_separate(qz_xs)[source]
Forward pass through decoder networks. Each shared latent is passed through all of the decoders with the private latents from the same view.
- Parameters
qz_xs (list) – list of encoding distributions.
- Returns
A nested list of decoding distributions, px_zs. The outer list is an n_view-element list in which the position indicates the decoder index. The inner list is an n_view-element list in which the position indicates the latent dimensions index. NOTE: This is the reverse of the ordering used by other models.
- Return type
(list)
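The reversed nesting is easy to get wrong, so a toy sketch may help. It contrasts the latent-major ordering described for the other models with the decoder-major ordering described for decode_separate, using strings in place of distributions. The functions latent_major and decoder_major and the stand-in decoders are hypothetical and exist only for this illustration.

```python
def latent_major(latents, decoders):
    """Ordering described for most models: outer index = latent, inner index = decoder."""
    return [[dec(z) for dec in decoders] for z in latents]

def decoder_major(latents, decoders):
    """Ordering described for decode_separate: outer index = decoder, inner index = latent."""
    return [[dec(z) for z in latents] for dec in decoders]

# Stand-in decoders that just label their input.
decoders = [lambda z, i=i: f"dec{i}({z})" for i in range(2)]
latents = ["z0", "z1"]

a = latent_major(latents, decoders)
b = decoder_major(latents, decoders)

# The same reconstruction sits at transposed positions in the two layouts.
print(a[0][1], b[1][0])
```

Indexing b[d][l] therefore reads "decoder d applied to latent l", whereas a[l][d] reads "latent l passed through decoder d".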
- forward(x)[source]
Apply encode and decode methods to input data to generate the latent dimensions and data reconstructions.
- Parameters
x (list) – list of input data of type torch.Tensor.
- Returns
dictionary containing encoding and decoding distributions.
- Return type
fwd_rtn (dict)
- calc_kl_joint_latent(qz_x, qs_xs)[source]
Calculate KL-divergence loss for the first terms in Equation 3.
- Parameters
qz_x (list) – Single element list containing joint encoding distribution.
qs_xs (list) – list of encoding distributions for private latent dimensions for each view.
- Returns
KL-divergence loss.
- Return type
(torch.Tensor)
- calc_kl_separate_latent(qcs_xs, qs_xs)[source]
Calculate KL-divergence loss for the second terms in Equation 3.
- Parameters
qcs_xs (list) – list of the shared encoding distributions calculated from each view.
qs_xs (list) – list of encoding distributions for private latent dimensions for each view.
- Returns
KL-divergence loss.
- Return type
(torch.Tensor)
- calc_ll_joint(x, px_zs)[source]
Calculate log-likelihood loss from the joint encoding distribution for each modality.
- Parameters
x (list) – list of input data of type torch.Tensor.
px_zs (list) – list of decoding distributions.
- Returns
Log-likelihood loss.
- Return type
ll (torch.Tensor)
- calc_ll_separate(x, pxs_zs)[source]
Calculate cross-modal and self-reconstruction log-likelihood loss from the shared encoding distribution for each modality and private latents.
- Parameters
x (list) – list of input data of type torch.Tensor.
pxs_zs (list) – nested list of decoding distributions. NOTE: The ordering of decoding distributions is the reverse of that used by other models.
- Returns
Log-likelihood loss.
- Return type
ll (torch.Tensor)
- loss_function(x, fwd_rtn)[source]
Calculate DMVAE loss.
- Parameters
x (list) – list of input data of type torch.Tensor.
fwd_rtn (dict) – dictionary containing encoding and decoding distributions.
- Returns
dictionary containing each element of the DMVAE loss.
- Return type
losses (dict)
- configure_optimizers()[source]
Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you’d need one. But in the case of GANs or similar you might have multiple.
- Returns
Any of these 6 options:
Single optimizer.
List or Tuple of optimizers.
Two lists - The first list has multiple optimizers, and the second has multiple LR schedulers (or multiple lr_scheduler_config).
Dictionary, with an "optimizer" key, and (optionally) a "lr_scheduler" key whose value is a single LR scheduler or lr_scheduler_config.
Tuple of dictionaries as described above, with an optional "frequency" key.
None - Fit will run without any optimizer.
The lr_scheduler_config is a dictionary which contains the scheduler and its associated configuration. The default configuration is shown below.
    lr_scheduler_config = {
        # REQUIRED: The scheduler instance
        "scheduler": lr_scheduler,
        # The unit of the scheduler's step size, could also be 'step'.
        # 'epoch' updates the scheduler on epoch end whereas 'step'
        # updates it after an optimizer update.
        "interval": "epoch",
        # How many epochs/steps should pass between calls to
        # `scheduler.step()`. 1 corresponds to updating the learning
        # rate after every epoch/step.
        "frequency": 1,
        # Metric to monitor for schedulers like `ReduceLROnPlateau`
        "monitor": "val_loss",
        # If set to `True`, will enforce that the value specified in 'monitor'
        # is available when the scheduler is updated, thus stopping
        # training if not found. If set to `False`, it will only produce a warning
        "strict": True,
        # If using the `LearningRateMonitor` callback to monitor the
        # learning rate progress, this keyword can be used to specify
        # a custom logged name
        "name": None,
    }
When there are schedulers in which the .step() method is conditioned on a value, such as the torch.optim.lr_scheduler.ReduceLROnPlateau scheduler, Lightning requires that the lr_scheduler_config contains the keyword "monitor" set to the metric name that the scheduler should be conditioned on.
    # The ReduceLROnPlateau scheduler requires a monitor
    def configure_optimizers(self):
        optimizer = Adam(...)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": ReduceLROnPlateau(optimizer, ...),
                "monitor": "metric_to_track",
                "frequency": "indicates how often the metric is updated",
                # If "monitor" references validation metrics, then "frequency" should be set to a
                # multiple of "trainer.check_val_every_n_epoch".
            },
        }

    # In the case of two optimizers, only one using the ReduceLROnPlateau scheduler
    def configure_optimizers(self):
        optimizer1 = Adam(...)
        optimizer2 = SGD(...)
        scheduler1 = ReduceLROnPlateau(optimizer1, ...)
        scheduler2 = LambdaLR(optimizer2, ...)
        return (
            {
                "optimizer": optimizer1,
                "lr_scheduler": {
                    "scheduler": scheduler1,
                    "monitor": "metric_to_track",
                },
            },
            {"optimizer": optimizer2, "lr_scheduler": scheduler2},
        )
Metrics can be made available to monitor by simply logging them using self.log('metric_to_track', metric_val) in your LightningModule.
Note
The frequency value specified in a dict along with the optimizer key is an int corresponding to the number of sequential batches optimized with the specific optimizer. It should be given to none or to all of the optimizers. There is a difference between passing multiple optimizers in a list, and passing multiple optimizers in dictionaries with a frequency of 1:
In the former case, all optimizers will operate on the given batch in each optimization step.
In the latter, only one optimizer will operate on the given batch at every step.
This is different from the frequency value specified in the lr_scheduler_config mentioned above.
    def configure_optimizers(self):
        optimizer_one = torch.optim.SGD(self.model.parameters(), lr=0.01)
        optimizer_two = torch.optim.SGD(self.model.parameters(), lr=0.01)
        return [
            {"optimizer": optimizer_one, "frequency": 5},
            {"optimizer": optimizer_two, "frequency": 10},
        ]
In this example, the first optimizer will be used for the first 5 steps, the second optimizer for the next 10 steps and that cycle will continue. If an LR scheduler is specified for an optimizer using the lr_scheduler key in the above dict, the scheduler will only be updated when its optimizer is being used.
Examples:
    # most cases. no learning rate scheduler
    def configure_optimizers(self):
        return Adam(self.parameters(), lr=1e-3)

    # multiple optimizer case (e.g.: GAN)
    def configure_optimizers(self):
        gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
        dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
        return gen_opt, dis_opt

    # example with learning rate schedulers
    def configure_optimizers(self):
        gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
        dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
        dis_sch = CosineAnnealingLR(dis_opt, T_max=10)
        return [gen_opt, dis_opt], [dis_sch]

    # example with step-based learning rate schedulers
    # each optimizer has its own scheduler
    def configure_optimizers(self):
        gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
        dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
        gen_sch = {
            'scheduler': ExponentialLR(gen_opt, 0.99),
            'interval': 'step'  # called after each training step
        }
        dis_sch = CosineAnnealingLR(dis_opt, T_max=10)  # called every epoch
        return [gen_opt, dis_opt], [gen_sch, dis_sch]

    # example with optimizer frequencies
    # see training procedure in `Improved Training of Wasserstein GANs`, Algorithm 1
    # https://arxiv.org/abs/1704.00028
    def configure_optimizers(self):
        gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
        dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
        n_critic = 5
        return (
            {'optimizer': dis_opt, 'frequency': n_critic},
            {'optimizer': gen_opt, 'frequency': 1}
        )
Note
Some things to know:
Lightning calls .backward() and .step() on each optimizer and learning rate scheduler as needed.
If you use 16-bit precision (precision=16), Lightning will automatically handle the optimizers.
If you use multiple optimizers, training_step() will have an additional optimizer_idx parameter.
If you use torch.optim.LBFGS, Lightning handles the closure function automatically for you.
If you use multiple optimizers, gradients will be calculated only for the parameters of the current optimizer at each training step.
If you need to control how often those optimizers step or override the default .step() schedule, override the optimizer_step() hook.
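The optimizer frequency cycle described above (5 steps for the first optimizer, then 10 for the second, repeating) can be simulated without any training code. The sketch below is not Lightning's implementation; active_optimizer is a hypothetical helper that only reproduces the documented behaviour of picking one optimizer per batch.

```python
from itertools import accumulate

def active_optimizer(batch_idx, frequencies):
    """Index of the optimizer that handles a given batch when each has a 'frequency'."""
    bounds = list(accumulate(frequencies))   # e.g. [5, 15] for frequencies [5, 10]
    position = batch_idx % bounds[-1]        # position inside the repeating cycle
    for opt_idx, bound in enumerate(bounds):
        if position < bound:
            return opt_idx

# Frequencies taken from the example above: optimizer_one=5, optimizer_two=10.
schedule = [active_optimizer(i, [5, 10]) for i in range(30)]
print(schedule)
```

The printed schedule shows the first optimizer on batches 0 to 4, the second on batches 5 to 14, and the same pattern starting again at batch 15.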
- class multiviewae.models.weighted_DMVAE(cfg=None, input_dim=None, z_dim=None)[source]
Variant of Disentangled multi-modal variational autoencoder (DMVAE) using weighted Product-of-Experts for joint encoding distribution.
- Parameters
cfg (str) –
Path to configuration file. Model specific parameters in addition to default parameters:
model._lambda (list, optional): Log likelihood weighting term for each modality.
model.s_dim (int): Number of private latent dimensions.
model.beta (int, float): KL divergence weighting term.
encoder.default._target_ (multiviewae.architectures.mlp.VariationalEncoder): Type of encoder class to use.
encoder.default.enc_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Encoding distribution.
decoder.default._target_ (multiviewae.architectures.mlp.VariationalDecoder): Type of decoder class to use.
decoder.default.init_logvar (int, float): Initial value for log variance of decoder.
decoder.default.dec_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Decoding distribution.
input_dim (list) – Dimensionality of the input data.
z_dim (int) – Number of latent dimensions.
References
M. Lee and V. Pavlovic, “Private-Shared Disentangled Multimodal VAE for Learning of Latent Representations,” 2021 IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops (CVPRW), Nashville, TN, USA, 2021, pp. 1692-1700, doi: 10.1109/CVPRW53098.2021.00185.
- encode(x)[source]
Forward pass through encoder networks.
- Parameters
x (list) – list of input data of type torch.Tensor.
- Returns
qc_x (list): Single element list containing the PoE shared encoding distribution.
qcs_xs (list): list containing encoding distributions for shared latent dimensions for each view.
qs_xs (list): list of encoding distributions for private latents.
qscs_xs (list): nested list containing combined PoE shared and private latents.
qscss_xs (list): nested list containing combined shared latents from each modality and private latents for same and cross view reconstruction.
- decode(qz_x)[source]
Forward pass of joint latent dimensions through decoder networks.
- Parameters
qz_x (list) – list containing the joint encoding distribution.
- Returns
A nested list of decoding distributions, px_zs. The outer list has a single element indicating the shared latent dimensions. The inner list is an n_view-element list in which the position indicates the decoder index.
- Return type
(list)
- decode_separate(qz_xs)[source]
Forward pass through decoder networks. Each shared latent is passed through all of the decoders with the private latents from the same view.
- Parameters
qz_xs (list) – list of encoding distributions.
- Returns
A nested list of decoding distributions, px_zs. The outer list is an n_view-element list in which the position indicates the decoder index. The inner list is an n_view-element list in which the position indicates the latent dimensions index. NOTE: This is the reverse of the ordering used by other models.
- Return type
(list)
- forward(x)[source]
Apply encode and decode methods to input data to generate the latent dimensions and data reconstructions.
- Parameters
x (list) – list of input data of type torch.Tensor.
- Returns
dictionary containing encoding and decoding distributions.
- Return type
fwd_rtn (dict)
- calc_kl_joint_latent(qz_x, qs_xs)[source]
Calculate KL-divergence loss for the first terms in Equation 3.
- Parameters
qz_x (list) – Single element list containing joint encoding distribution.
qs_xs (list) – list of encoding distributions for private latent dimensions for each view.
- Returns
KL-divergence loss.
- Return type
(torch.Tensor)
- calc_kl_separate_latent(qcs_xs, qs_xs)[source]
Calculate KL-divergence loss for the second terms in Equation 3.
- Parameters
qcs_xs (list) – list of the shared encoding distributions calculated from each view.
qs_xs (list) – list of encoding distributions for private latent dimensions for each view.
- Returns
KL-divergence loss.
- Return type
(torch.Tensor)
- calc_ll_joint(x, px_zs)[source]
Calculate log-likelihood loss from the joint encoding distribution for each modality.
- Parameters
x (list) – list of input data of type torch.Tensor.
px_zs (list) – list of decoding distributions.
- Returns
Log-likelihood loss.
- Return type
ll (torch.Tensor)
- calc_ll_separate(x, pxs_zs)[source]
Calculate cross-modal and self-reconstruction log-likelihood loss from the shared encoding distribution for each modality and private latents.
- Parameters
x (list) – list of input data of type torch.Tensor.
pxs_zs (list) – nested list of decoding distributions. NOTE: The ordering of decoding distributions is the reverse of that used by other models.
- Returns
Log-likelihood loss.
- Return type
ll (torch.Tensor)
- loss_function(x, fwd_rtn)[source]
Calculate DMVAE loss.
- Parameters
x (list) – list of input data of type torch.Tensor.
fwd_rtn (dict) – dictionary containing encoding and decoding distributions.
- Returns
dictionary containing each element of the DMVAE loss.
- Return type
losses (dict)
- configure_optimizers()[source]
Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you’d need one. But in the case of GANs or similar you might have multiple.
- Returns
Any of these 6 options.
Single optimizer.
List or Tuple of optimizers.
Two lists - The first list has multiple optimizers, and the second has multiple LR schedulers (or multiple
lr_scheduler_config
).Dictionary, with an
"optimizer"
key, and (optionally) a"lr_scheduler"
key whose value is a single LR scheduler orlr_scheduler_config
.Tuple of dictionaries as described above, with an optional
"frequency"
key.None - Fit will run without any optimizer.
The
lr_scheduler_config
is a dictionary which contains the scheduler and its associated configuration. The default configuration is shown below.lr_scheduler_config = { # REQUIRED: The scheduler instance "scheduler": lr_scheduler, # The unit of the scheduler's step size, could also be 'step'. # 'epoch' updates the scheduler on epoch end whereas 'step' # updates it after a optimizer update. "interval": "epoch", # How many epochs/steps should pass between calls to # `scheduler.step()`. 1 corresponds to updating the learning # rate after every epoch/step. "frequency": 1, # Metric to to monitor for schedulers like `ReduceLROnPlateau` "monitor": "val_loss", # If set to `True`, will enforce that the value specified 'monitor' # is available when the scheduler is updated, thus stopping # training if not found. If set to `False`, it will only produce a warning "strict": True, # If using the `LearningRateMonitor` callback to monitor the # learning rate progress, this keyword can be used to specify # a custom logged name "name": None, }
When there are schedulers in which the
.step()
method is conditioned on a value, such as thetorch.optim.lr_scheduler.ReduceLROnPlateau
scheduler, Lightning requires that thelr_scheduler_config
contains the keyword"monitor"
set to the metric name that the scheduler should be conditioned on.# The ReduceLROnPlateau scheduler requires a monitor def configure_optimizers(self): optimizer = Adam(...) return { "optimizer": optimizer, "lr_scheduler": { "scheduler": ReduceLROnPlateau(optimizer, ...), "monitor": "metric_to_track", "frequency": "indicates how often the metric is updated" # If "monitor" references validation metrics, then "frequency" should be set to a # multiple of "trainer.check_val_every_n_epoch". }, } # In the case of two optimizers, only one using the ReduceLROnPlateau scheduler def configure_optimizers(self): optimizer1 = Adam(...) optimizer2 = SGD(...) scheduler1 = ReduceLROnPlateau(optimizer1, ...) scheduler2 = LambdaLR(optimizer2, ...) return ( { "optimizer": optimizer1, "lr_scheduler": { "scheduler": scheduler1, "monitor": "metric_to_track", }, }, {"optimizer": optimizer2, "lr_scheduler": scheduler2}, )
Metrics can be made available to monitor by simply logging them using self.log('metric_to_track', metric_val) in your LightningModule.

Note
The frequency value specified in a dict along with the optimizer key is an int corresponding to the number of sequential batches optimized with the specific optimizer. It should be given to none or to all of the optimizers. There is a difference between passing multiple optimizers in a list, and passing multiple optimizers in dictionaries with a frequency of 1:
In the former case, all optimizers will operate on the given batch in each optimization step.
In the latter, only one optimizer will operate on the given batch at every step.
This is different from the frequency value specified in the lr_scheduler_config mentioned above.

    def configure_optimizers(self):
        optimizer_one = torch.optim.SGD(self.model.parameters(), lr=0.01)
        optimizer_two = torch.optim.SGD(self.model.parameters(), lr=0.01)
        return [
            {"optimizer": optimizer_one, "frequency": 5},
            {"optimizer": optimizer_two, "frequency": 10},
        ]

In this example, the first optimizer will be used for the first 5 steps, the second optimizer for the next 10 steps and that cycle will continue. If an LR scheduler is specified for an optimizer using the lr_scheduler key in the above dict, the scheduler will only be updated when its optimizer is being used.

Examples:
    from torch.optim import Adam
    from torch.optim.lr_scheduler import CosineAnnealingLR, ExponentialLR

    # most cases. no learning rate scheduler
    def configure_optimizers(self):
        return Adam(self.parameters(), lr=1e-3)

    # multiple optimizer case (e.g.: GAN)
    def configure_optimizers(self):
        gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
        dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
        return gen_opt, dis_opt

    # example with learning rate schedulers
    def configure_optimizers(self):
        gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
        dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
        dis_sch = CosineAnnealingLR(dis_opt, T_max=10)
        return [gen_opt, dis_opt], [dis_sch]

    # example with step-based learning rate schedulers
    # each optimizer has its own scheduler
    def configure_optimizers(self):
        gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
        dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
        gen_sch = {
            'scheduler': ExponentialLR(gen_opt, 0.99),
            'interval': 'step'  # called after each training step
        }
        dis_sch = CosineAnnealingLR(dis_opt, T_max=10)  # called every epoch
        return [gen_opt, dis_opt], [gen_sch, dis_sch]

    # example with optimizer frequencies
    # see training procedure in `Improved Training of Wasserstein GANs`, Algorithm 1
    # https://arxiv.org/abs/1704.00028
    def configure_optimizers(self):
        gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
        dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
        n_critic = 5
        return (
            {'optimizer': dis_opt, 'frequency': n_critic},
            {'optimizer': gen_opt, 'frequency': 1}
        )

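The optimizer frequencies described above define a repeating cycle over batches. The following is a minimal, framework-free sketch of that cycle, not Lightning's actual implementation; the helper name active_optimizer is hypothetical and only illustrates which optimizer would handle a given batch index.

```python
from itertools import accumulate


def active_optimizer(batch_idx, frequencies):
    """Return the index of the optimizer that handles `batch_idx`.

    `frequencies` lists, per optimizer, how many consecutive batches
    it is used for before the next optimizer takes over.
    """
    boundaries = list(accumulate(frequencies))  # e.g. [5, 15] for [5, 10]
    position = batch_idx % boundaries[-1]       # position within one full cycle
    for opt_idx, end in enumerate(boundaries):
        if position < end:
            return opt_idx
    raise ValueError("frequencies must be positive integers")


# Frequencies 5 and 10: optimizer 0 for 5 batches, optimizer 1 for 10, then repeat.
schedule = [active_optimizer(i, [5, 10]) for i in range(17)]
```

With frequencies [5, 10], batches 0 to 4 map to the first optimizer, batches 5 to 14 to the second, and batch 15 starts the cycle again. With the WGAN-style frequencies [5, 1], the generator optimizer is selected once after every five critic batches.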
Note
Some things to know:
Lightning calls .backward() and .step() on each optimizer and learning rate scheduler as needed.
If you use 16-bit precision (precision=16), Lightning will automatically handle the optimizers.
If you use multiple optimizers, training_step() will have an additional optimizer_idx parameter.
If you use torch.optim.LBFGS, Lightning handles the closure function automatically for you.
If you use multiple optimizers, gradients will be calculated only for the parameters of the current optimizer at each training step.
If you need to control how often those optimizers step or override the default .step() schedule, override the optimizer_step() hook.
- class multiviewae.models.mmVAEPlus(cfg=None, input_dim=None, z_dim=None)[source]
Mixture-of-Experts Multimodal Variational Autoencoder (MMVAE).
Code is based on: https://github.com/iffsid/mmvae
- Parameters
cfg (str) –
Path to configuration file. Model specific parameters in addition to default parameters:
model.K (int): Number of samples to take from encoding distribution.
model.DREG_loss (bool): Whether to use DReG estimator when using large K value.
encoder.default._target_ (multiviewae.architectures.mlp.VariationalEncoder): Type of encoder class to use.
encoder.default.enc_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Encoding distribution.
decoder.default._target_ (multiviewae.architectures.mlp.VariationalDecoder): Type of decoder class to use.
decoder.default.init_logvar (int, float): Initial value for log variance of decoder.
decoder.default.dec_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Decoding distribution.
input_dim (list) – Dimensionality of the input data.
z_dim (int) – Number of latent dimensions.
References
Shi, Y., Siddharth, N., Paige, B., & Torr, P.H. (2019). Variational Mixture-of-Experts Autoencoders for Multi-Modal Deep Generative Models. ArXiv, abs/1911.03393.
- encode(x)[source]
Forward pass through encoder networks.
- Parameters
x (list) – list of input data of type torch.Tensor.
- Returns
list of encoding distributions for shared (qu_xs) and private (qw_xs) latent dimensions during training, otherwise return samples from encoding distributions.
- Return type
(list)
- encode_subset(x, subset)[source]
Forward pass through encoder networks for a subset of modalities. For modalities not in subset, shared latents are sampled from a random modality and private latents from the shared prior.
- Parameters
x (list) – list of input data of type torch.Tensor.
subset (list) – list of modalities to encode.
- Returns
list of samples from encoding distributions.
- Return type
(list)
- decode(zss)[source]
Forward pass through decoder networks. Each latent is passed through all of the decoders.
- Parameters
zss (list) – list of latent samples if not training, or a list containing qw_xs (list of private encoding distributions) and qu_xs (list of shared encoding distributions).
- Returns
A nested list of decoding distributions. The outer list has n_view elements, with the position indicating the latent index. Each inner list also has n_view elements, with the position indicating the decoder index.
- Return type
(list)
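The nested return structure can be illustrated with a small sketch. This is not the library's code: decode_all and the string-producing decoders are hypothetical stand-ins, used only to show how the outer index selects the latent and the inner index selects the decoder.

```python
def decode_all(latents, decoders):
    """Pass every latent through every decoder.

    Returns a nested list where out[i][j] is latent i decoded by decoder j.
    """
    return [[decoder(z) for decoder in decoders] for z in latents]


# Hypothetical stand-ins: two views, decoders that just label their input.
latents = ["z0", "z1"]
decoders = [lambda z: f"dec0({z})", lambda z: f"dec1({z})"]

px_zs = decode_all(latents, decoders)
same_view = px_zs[0][0]   # latent 0 through decoder 0 (self-reconstruction)
cross_view = px_zs[0][1]  # latent 0 through decoder 1 (cross-view generation)
```

Entries on the diagonal (px_zs[i][i]) correspond to reconstructing a view from its own latent, and off-diagonal entries to generating one view from another view's latent.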
- decode_subset(zss, subset)[source]
Forward pass through decoder networks for a subset of modalities. Each latent is passed through its own decoder.
- Parameters
zss (list) – list of latent samples for each modality.
subset (list) – list of modalities to decode.
- Returns
A list of decoding distributions for each modality in subset.
- Return type
(list)
- forward(x)[source]
Apply encode and decode methods to input data to generate latent dimensions and data reconstructions.
- Parameters
x (list) – list of input data of type torch.Tensor.
- Returns
dictionary containing encoding (qw_xs and qu_xs) and decoding (px_zs) distributions.
- Return type
fwd_rtn (dict)
- loss_function(x, fwd_rtn)[source]
Wrapper function for mmVAE loss.
- Parameters
x (list) – list of input data of type torch.Tensor.
fwd_rtn (dict) – dictionary containing encoding and decoding distributions.
- Returns
dictionary containing mmVAE loss.
- Return type
losses (dict)
- moe_iwae(x, qw_xs, qu_xs, px_zs)[source]
Calculate Mixture-of-Experts importance weighted autoencoder (IWAE) loss used for the mmVAE model.
- Parameters
x (list) – list of input data of type torch.Tensor.
qw_xs (list) – list of private encoding distributions.
qu_xs (list) – list of shared encoding distributions.
px_zs (list) – list of decoding distributions.
- Returns
the output tensor.
- Return type
(torch.Tensor)
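For orientation, the Mixture-of-Experts IWAE objective from the cited paper (Shi et al., 2019) can be written as below. This is the form given for the original MMVAE with a single latent per view; it is an assumption that the implementation here follows the same structure, and because this class separates shared (qu_xs) and private (qw_xs) distributions, the implemented loss may compose the latent differently. Here M is the number of views and K is the number of samples (model.K).

```latex
\mathcal{L}^{\mathrm{MoE}}_{\mathrm{IWAE}}(x_{1:M})
  = \frac{1}{M} \sum_{m=1}^{M}
    \mathbb{E}_{z_m^{1:K} \sim q_{\phi_m}(z \mid x_m)}
    \left[ \log \frac{1}{K} \sum_{k=1}^{K}
      \frac{p_{\Theta}(z_m^{k}, x_{1:M})}{q_{\Phi}(z_m^{k} \mid x_{1:M})} \right],
\qquad
q_{\Phi}(z \mid x_{1:M}) = \frac{1}{M} \sum_{m=1}^{M} q_{\phi_m}(z \mid x_m)
```

The inner log of an average over K importance weights is the quantity that a log-mean-exp helper computes in a numerically stable way.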
- log_mean_exp(value, dim=0, keepdim=False)[source]
Returns the log of the mean of the exponentials along the given dimension (dim).
- Parameters
value (torch.Tensor) – the input tensor.
dim (int, optional) – the dimension along which to take the mean.
keepdim (bool, optional) – whether the output tensor has dim retained or not.
- Returns
the output tensor.
- Return type
(torch.Tensor)
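The computation can be sketched without any tensor library. The version below is a plain-Python illustration over a list of floats, not the method's actual implementation (which operates on a torch.Tensor along dim); it uses the usual trick of subtracting the maximum before exponentiating so that large negative log-weights do not underflow to zero.

```python
import math


def log_mean_exp(values):
    """Numerically stable log(mean(exp(v))) over a non-empty list of floats."""
    m = max(values)
    if math.isinf(m):
        # All entries are -inf (result -inf) or some entry is +inf (result +inf).
        return m
    # exp(v - m) <= 1 for every v, so the sum cannot overflow,
    # and the largest term is exactly 1, so it cannot underflow to 0.
    total = sum(math.exp(v - m) for v in values)
    return m + math.log(total / len(values))


# Log-weights this small would make a naive log(mean(exp(v))) fail,
# because exp(-1000.0) underflows to 0.0.
stable = log_mean_exp([-1000.0, -1000.0])
```

For tensors, the same quantity can be obtained as torch.logsumexp(value, dim) minus the log of the size of that dimension.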
- class multiviewae.models.DCCAE(cfg=None, input_dim=None, z_dim=None)[source]
Deep Canonically Correlated Autoencoder (DCCAE). CCA implementation adapted from: https://github.com/jameschapman19/cca_zoo
- Parameters
cfg (str) –
Path to configuration file. Model specific parameters in addition to default parameters:
model._lambda (int, float): Reconstruction weighting term.
input_dim (list) – Dimensionality of the input data.
z_dim (int) – Number of latent dimensions.
References
Wang, Weiran & Arora, Raman & Livescu, Karen & Bilmes, Jeff. (2016). On Deep Multi-View Representation Learning: Objectives and Optimization.
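As a reminder of what the reconstruction weighting term controls, the two-view DCCAE objective in the cited paper (Wang et al.) has roughly the following form, where f and g are the encoders, p and q the decoders, U and V the CCA projection matrices, and N the number of samples. It is an assumption that λ here corresponds to model._lambda; the cca method in this class is adapted from a multi-view implementation in cca_zoo, so the exact correlation term and scaling may differ.

```latex
\min_{f,\, g,\, p,\, q,\, U,\, V}\;
  -\frac{1}{N} \operatorname{tr}\!\left( U^{\top} f(X)\, g(Y)^{\top} V \right)
  + \frac{\lambda}{N} \sum_{i=1}^{N}
    \left( \lVert x_i - p(f(x_i)) \rVert^{2}
         + \lVert y_i - q(g(y_i)) \rVert^{2} \right)
```

The first term rewards correlation between the projected latents of the two views (subject to the usual CCA whitening constraints on U and V), and the second term is the autoencoder reconstruction error weighted by λ.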
- encode(x)[source]
Forward pass through encoder networks.
- Parameters
x (list) – list of input data of type torch.Tensor.
- Returns
list of latent dimensions for each view of type torch.Tensor.
- Return type
z (list)
- decode(z)[source]
Forward pass through decoder networks. Each latent is passed through all of the decoders.
- Parameters
z (list) – list of latent dimensions for each view of type torch.Tensor.
- Returns
list of data reconstructions.
- Return type
x_recon (list)
- cca(zs)[source]
CCA loss calculation. Adapted from: https://github.com/jameschapman19/cca_zoo/blob/main/cca_zoo/deep/_discriminative/_dmcca.py
- Parameters
zs (list) – list of latent dimensions for each view of type torch.Tensor.
- Returns
CCA loss.
- Return type
cca_loss (torch.Tensor)
- forward(x)[source]
Apply encode and decode methods to input data to generate latent dimensions and data reconstructions.
- Parameters
x (list) – list of input data of type torch.Tensor.
- Returns
dictionary containing list of data reconstructions (x_recon) and latent dimensions (z).
- Return type
fwd_rtn (dict)
- loss_function(x, fwd_rtn)[source]
Calculate reconstruction loss.
- Parameters
x (list) – list of input data of type torch.Tensor.
fwd_rtn (dict) – dictionary containing list of data reconstructions (x_recon) and latent dimensions (z).
- Returns
dictionary containing reconstruction loss.
- Return type
losses (dict)