Source code for multiviewae.models.mcvae

import torch
import hydra

from ..base.constants import MODEL_MCVAE
from ..base.base_model import BaseModelVAE

class mcVAE(BaseModelVAE):
    r"""
    Multi-Channel Variational Autoencoder and Sparse Multi-Channel Variational Autoencoder.

    Code is based on: https://github.com/ggbioing/mcvae

    Args:
        cfg (str): Path to configuration file. Model specific parameters in addition to default parameters:

            - model.beta (int, float): KL divergence weighting term.
            - model.sparse (bool): Whether to enforce sparsity of the encoding distribution.
            - model.threshold (float): Dropout threshold applied to the latent dimensions. Default is 0.
            - encoder.default._target_ (multiviewae.architectures.mlp.VariationalEncoder): Type of encoder class to use.
            - encoder.default.enc_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Encoding distribution.
            - decoder.default._target_ (multiviewae.architectures.mlp.VariationalDecoder): Type of decoder class to use.
            - decoder.default.init_logvar (int, float): Initial value for log variance of decoder.
            - decoder.default.dec_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Decoding distribution.

        input_dim (list): Dimensionality of the input data.
        z_dim (int): Number of latent dimensions.

    References
    ----------
    Antelmi, Luigi & Ayache, Nicholas & Robert, Philippe & Lorenzi, Marco. (2019). Sparse Multi-Channel Variational Autoencoder for the Joint Analysis of Heterogeneous Data.
    """

    def __init__(
        self,
        cfg=None,
        input_dim=None,
        z_dim=None,
    ):
        # All set-up is delegated to the base class; this subclass only
        # fixes the model name.
        super().__init__(
            model_name=MODEL_MCVAE,
            cfg=cfg,
            input_dim=input_dim,
            z_dim=z_dim,
        )

    def encode(self, x):
        r"""Forward pass through encoder networks.

        Args:
            x (list): list of input data of type torch.Tensor.

        Returns:
            qz_xs (list): list of encoding distributions, one for each view.
        """
        qz_xs = []
        for i in range(self.n_views):
            mu, logvar = self.encoders[i](x[i])
            # Look up the per-view encoder config ("enc0", "enc1", ...) by
            # attribute name. getattr performs the same lookup as the dotted
            # path, without evaluating a constructed string as code.
            enc_cfg = getattr(self.cfg.encoder, f"enc{i}")
            qz_x = hydra.utils.instantiate(enc_cfg.enc_dist, loc=mu, logvar=logvar)
            qz_xs.append(qz_x)
        return qz_xs

    def decode(self, qz_xs):
        r"""Forward pass through decoder networks. Each latent is passed through all of the decoders.

        Args:
            qz_xs (list): list of encoding distributions, one for each view.

        Returns:
            px_zs (list): A nested list of decoding distributions. The outer list has a n_view element indicating latent dimensions index.
            The inner list is a n_view element list with the position in the list indicating the decoder index.
        """
        px_zs = []
        for qz_x in qz_xs:
            # _sample is called once per decoder, so every decoder receives
            # its own draw from this view's encoding distribution.
            px_z = [
                self.decoders[i](
                    qz_x._sample(training=self._training, return_mean=self.return_mean)
                )
                for i in range(self.n_views)
            ]
            px_zs.append(px_z)
        return px_zs

    def forward(self, x):
        r"""Apply encode and decode methods to input data to generate latent dimensions and data reconstructions.

        Args:
            x (list): list of input data of type torch.Tensor.

        Returns:
            fwd_rtn (dict): dictionary containing encoding (qz_xs) and decoding (px_zs) distributions.
        """
        qz_xs = self.encode(x)
        px_zs = self.decode(qz_xs)
        fwd_rtn = {"px_zs": px_zs, "qz_xs": qz_xs}
        return fwd_rtn

    def loss_function(self, x, fwd_rtn):
        r"""Calculate mcVAE loss.

        Args:
            x (list): list of input data of type torch.Tensor.
            fwd_rtn (dict): dictionary containing encoding and decoding distributions.

        Returns:
            losses (dict): dictionary containing each element of the mcVAE loss:
            "loss" (kl - ll), "kl" and "ll".
        """
        px_zs = fwd_rtn["px_zs"]
        qz_xs = fwd_rtn["qz_xs"]

        kl = self.calc_kl(qz_xs)
        ll = self.calc_ll(x, px_zs)

        # The log-likelihood is subtracted from the KL term to form the total.
        total = kl - ll

        losses = {"loss": total, "kl": kl, "ll": ll}
        return losses

    def calc_kl(self, qz_xs):
        r"""Calculate mcVAE KL-divergence loss.

        Args:
            qz_xs (list): list of encoding distributions.

        Returns:
            (torch.Tensor): KL-divergence loss across all views, scaled by ``self.beta``.
        """
        kl = 0
        for qz_x in qz_xs:
            if self.sparse:
                # Sparse variant: the distribution supplies its own KL term
                # and no prior is passed in.
                kl += qz_x.sparse_kl_divergence().mean(0).sum()
            else:
                kl += qz_x.kl_divergence(self.prior).mean(0).sum()
        return self.beta * kl

    def calc_ll(self, x, px_zs):
        r"""Calculate log-likelihood loss.

        Args:
            x (list): list of input data of type torch.Tensor.
            px_zs (list): nested list of decoding distributions, indexed as
                ``px_zs[latent][view]``.

        Returns:
            ll (torch.Tensor): Log-likelihood loss, summed over every
            (latent, view) pair and divided by ``n_views`` twice.
        """
        ll = 0
        for i in range(self.n_views):
            for j in range(self.n_views):
                # first index is latent, second index is view
                ll += px_zs[j][i].log_likelihood(x[i]).mean(0).sum()
        return ll / self.n_views / self.n_views