Source code for multiviewae.models.jmvae

import torch
import hydra
import numpy as np
from ..base.constants import MODEL_JMVAE
from ..base.base_model import BaseModelVAE

[docs]class JMVAE(BaseModelVAE): r""" JMVAE-kl. Args: cfg (str): Path to configuration file. Model specific parameters in addition to default parameters: - alpha (float): Weighting of KL-divergence loss from individual encoders. - encoder.default._target_ (multiviewae.architectures.mlp.VariationalEncoder): Type of Encoder to use. - encoder.default.enc_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Encoding distribution. - decoder.default._target_ (multiviewae.architectures.mlp.VariationalDecoder): Type of decoder class to use. - decoder.default.init_logvar(int, float): Initial value for log variance of decoder. - decoder.default.dec_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Decoding distribution. input_dim (list): Dimensionality of the input data. z_dim (int): Number of latent dimensions. References ---------- Suzuki, Masahiro & Nakayama, Kotaro & Matsuo, Yutaka. (2016). Joint Multimodal Learning with Deep Generative Models. """ def __init__( self, cfg = None, input_dim = None, z_dim = None ): super().__init__(model_name=MODEL_JMVAE, cfg=cfg, input_dim=input_dim, z_dim=z_dim) def _setencoders(self): r"""Set the joint and individual encoder networks. """ self.encoders = torch.nn.ModuleList( [hydra.utils.instantiate( self.cfg.encoder.default, input_dim=self.input_dim[0]+self.input_dim[1], z_dim=self.z_dim, sparse=False, log_alpha=None, _recursive_=False, _convert_="all" )] + [ hydra.utils.instantiate( eval(f"self.cfg.encoder.enc{i}"), input_dim=d, z_dim=self.z_dim, sparse=False, log_alpha=None, _recursive_=False, _convert_="all" ) for i, d in enumerate(self.input_dim) ] )
[docs] def encode(self, x): r"""Forward pass through joint encoder network. Args: x (list): list of input data of type torch.Tensor. Returns: (list): Single element list containing joint encoding distribution, qz_xy. """ mu, logvar = self.encoders[0](torch.cat((x[0], x[1]),dim=1)) #TODO: current implementation only works for 2D data for both modalities qz_xy = hydra.utils.instantiate( self.cfg.encoder.default.enc_dist, loc=mu, logvar=logvar ) return [qz_xy]
[docs] def encode_separate(self, x): r"""Forward pass through separate encoder networks. Args: x (list): list of input data of type torch.Tensor. Returns: qz_x: Encoding distribution for modality X. qz_y: Encoding distribution for modality Y. """ mu, logvar = self.encoders[1](x[0]) qz_x = hydra.utils.instantiate( self.cfg.encoder.enc0.enc_dist, loc=mu, logvar=logvar ) mu, logvar = self.encoders[2](x[1]) qz_y = hydra.utils.instantiate( self.cfg.encoder.enc1.enc_dist, loc=mu, logvar=logvar ) return qz_x, qz_y
[docs] def decode(self, qz_x): r"""Forward pass of joint latent dimensions through decoder networks. Args: x (list): list of input data of type torch.Tensor. Returns: (list): A nested list of decoding distributions, px_zs. The outer list has a single element indicating the shared latent dimensions. The inner list is a 2 element list with the position in the list indicating the decoder index. """ px_zs = [] for i in range(self.n_views): px_z = self.decoders[i](qz_x[0]._sample(training=self._training, return_mean=self.return_mean)) px_zs.append(px_z) return [px_zs]
[docs] def forward(self, x): r"""Apply encode and decode methods to input data to generate latent dimensions and data reconstructions. Args: x (list): list of input data of type torch.Tensor. Returns: fwd_rtn (dict): dictionary containing encoding and decoding distributions. """ qz_xy = self.encode(x) qz_x, qz_y = self.encode_separate(x) px_z, py_z = self.decode(qz_xy)[0] fwd_rtn = {"px_z": px_z, "py_z": py_z, "qz_x": qz_x, "qz_y": qz_y, "qz_xy": qz_xy} return fwd_rtn
[docs] def calc_kl(self, qz_xy, qz_x, qz_y): r"""Calculate JMVAE-kl KL-divergence loss. Args: qz_xy (list): Single element list containing shared encoding distribution. qz_x (list): Single element list containing encoding distribution for modality X. qz_y (list): Single element list containing encoding distribution for modality Y. Returns: (torch.Tensor): KL-divergence loss. """ kl_prior = qz_xy[0].kl_divergence(self.prior).mean(0).sum() kl_qz_x = qz_xy[0].kl_divergence(qz_x).mean(0).sum() kl_qz_y = qz_xy[0].kl_divergence(qz_y).mean(0).sum() return kl_prior + self.alpha*(kl_qz_x + kl_qz_y)
[docs] def calc_ll(self, x, px_zs): r"""Calculate log-likelihood loss. Args: x (list): list of input data of type torch.Tensor. px_zs (list): list of decoding distributions. Returns: ll (torch.Tensor): Log-likelihood loss. """ ll = 0 for i in range(self.n_views): ll += px_zs[i].log_likelihood(x[i]).mean(0).sum() return ll
[docs] def loss_function(self, x, fwd_rtn): r"""Calculate JMVAE-kl loss. Args: x (list): list of input data of type torch.Tensor. fwd_rtn (dict): dictionary containing encoding and decoding distributions. Returns: losses (dict): dictionary containing each element of the JMVAE loss. """ px_z = fwd_rtn["px_z"] py_z = fwd_rtn["py_z"] qz_x = fwd_rtn["qz_x"] qz_y = fwd_rtn["qz_y"] qz_xy = fwd_rtn["qz_xy"] kls = self.calc_kl(qz_xy, qz_x, qz_y) ll = self.calc_ll(x, [px_z, py_z]) if self.current_epoch > self.warmup - 1: annealing_factor = 1 else: annealing_factor = self.current_epoch/self.warmup total = annealing_factor*kls - ll losses = {"loss": total, "kls": kls, "ll": ll} return losses
[docs] def calc_nll(self, x, K=1000, batch_size_K=100): r"""Calculate negative log-likelihood used to evaluate model performance. Args: x (list): list of input data of type torch.Tensor. Returns: nll (torch.Tensor): Negative log-likelihood. """ mu, logvar = self.encoders[0](torch.cat((x[0], x[1]),dim=1)) #TODO: current implementation only works for 2D data for both modalities qz_xy = hydra.utils.instantiate( self.cfg.encoder.default.enc_dist, loc=mu, logvar=logvar ) # And sample from the posterior zs = qz_xy.rsample([K]) # shape K x n_data x latent_dim ll = 0 start_idx = 0 stop_idx = min(start_idx + batch_size_K, K) lnpxs = [] while start_idx < stop_idx: zs_ = zs[start_idx:stop_idx] lpx_zs = 0 for j in range(self.n_views): px_z = self.decoders[j](zs_) lpx_zs += px_z.log_likelihood(x[j]).sum(dim=-1) lpz = self.prior.log_likelihood(zs_).sum(dim=-1) lqz_xy = qz_xy.log_likelihood(zs_).sum(dim=-1) ln_px = torch.logsumexp(lpx_zs + lpz - lqz_xy, dim=0) lnpxs.append(ln_px) start_idx += batch_size_K stop_idx = min(stop_idx + batch_size_K, K) lnpxs = torch.stack(lnpxs) ll += (torch.logsumexp(lnpxs, dim=0) - np.log(K)).mean() return -ll