# Source code for multiviewae.models.mmvae

import math
import torch
import hydra

from ..base.constants import MODEL_MMVAE
from ..base.base_model import BaseModelVAE
import numpy as np
# Seed NumPy's global random generator once, at import time. Within this file
# NumPy randomness is consumed by mmVAE.encode_subset (np.random.choice), so
# the fixed seed makes that random modality choice repeatable between runs.
# NOTE(review): this changes process-wide NumPy RNG state for every importer.
np.random.seed(0)

class mmVAE(BaseModelVAE):
    r"""
    Mixture-of-Experts Multimodal Variational Autoencoder (MMVAE).

    Code is based on: https://github.com/iffsid/mmvae

    Args:
        cfg (str): Path to configuration file. Model specific parameters in addition to default parameters:

            - model.K (int): Number of samples to take from encoding distribution.
            - model.DREG_loss (bool): Whether to use DReG estimator when using large K value.
            - encoder.default._target_ (multiviewae.architectures.mlp.VariationalEncoder): Type of encoder class to use.
            - encoder.default.enc_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Encoding distribution.
            - decoder.default._target_ (multiviewae.architectures.mlp.VariationalDecoder): Type of decoder class to use.
            - decoder.default.init_logvar (int, float): Initial value for log variance of decoder.
            - decoder.default.dec_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Decoding distribution.

        input_dim (list): Dimensionality of the input data.
        z_dim (int): Number of latent dimensions.

    References
    ----------
    Shi, Y., Siddharth, N., Paige, B., & Torr, P.H. (2019). Variational Mixture-of-Experts Autoencoders for Multi-Modal Deep Generative Models. ArXiv, abs/1911.03393.
    """

    def __init__(
        self,
        cfg = None,
        input_dim = None,
        z_dim = None
    ):
        super().__init__(model_name=MODEL_MMVAE,
                         cfg=cfg,
                         input_dim=input_dim,
                         z_dim=z_dim)

    def _enc_dist_cfg(self, view):
        r"""Return the encoding-distribution config node of encoder ``view``.

        Equivalent to ``self.cfg.encoder.enc<view>.enc_dist``; resolved with
        ``getattr`` so that no string has to be passed to ``eval``.

        Args:
            view (int): Index of the encoder.

        Returns:
            The ``enc_dist`` config node for that encoder.
        """
        return getattr(self.cfg.encoder, f"enc{view}").enc_dist

    def _encode_view(self, x, view):
        r"""Encode a single view and wrap the result in its encoding distribution.

        Args:
            x (list): list of input data of type torch.Tensor.
            view (int): Index of the view (and encoder) to use.

        Returns:
            The encoding distribution instantiated from the encoder outputs.
        """
        mu, logvar = self.encoders[view](x[view])
        return hydra.utils.instantiate(
            self._enc_dist_cfg(view), loc=mu, logvar=logvar
        )

    def encode(self, x):
        r"""Forward pass through encoder networks.

        Args:
            x (list): list of input data of type torch.Tensor.

        Returns:
            (list): list of encoding distributions, one per view.
        """
        return [self._encode_view(x, i) for i in range(self.n_views)]

    def encode_subset(self, x, subset):
        r"""Forward pass through encoder networks for a subset of modalities.

        The returned list always has ``n_views`` entries. Positions whose index
        is in ``subset`` are encoded from their own view; every other position
        is encoded from a modality drawn at random from ``subset``.

        Args:
            x (list): list of input data of type torch.Tensor.
            subset (list): list of modalities to encode.

        Returns:
            (list): list of encoding distributions.
        """
        qz_xs = []
        for i in range(self.n_views):
            if i in subset:
                source = i
            else:
                # Choose one of the subset modalities at random. Cast the NumPy
                # scalar to a plain int before using it as an index / config key.
                source = int(np.random.choice(subset))
            qz_xs.append(self._encode_view(x, source))
        return qz_xs

    def decode(self, qz_xs):
        r"""Forward pass through decoder networks. Each latent is passed through all of the decoders.

        A separate sample is drawn from the encoding distribution for every
        decoder call: ``K`` reparameterised samples while training, otherwise
        whatever ``_sample()`` returns with its default arguments.

        Args:
            qz_xs (list): list of encoding distributions.

        Returns:
            (list): A nested list of decoding distributions. The outer list has one entry per
            encoding distribution (latent index). Each inner list has ``n_views`` entries, with
            the position in the list indicating the decoder index.
        """
        px_zs = []
        for qz_x in qz_xs:
            if self._training:
                px_z = [
                    self.decoders[j](qz_x.rsample(torch.Size([self.K])))
                    for j in range(self.n_views)
                ]
            else:
                px_z = [
                    self.decoders[j](qz_x._sample())
                    for j in range(self.n_views)
                ]
            px_zs.append(px_z)
        return px_zs

    def decode_subset(self, qz_xs, subset):
        r"""Forward pass through decoder networks for a subset of modalities. Each latent is passed through its own decoder.

        Args:
            qz_xs (list): list of encoding distributions.
            subset (list): list of modalities to decode.

        Returns:
            (list): A nested list of decoding distributions. The outer list is a single element list,
            the inner list holds one decoding distribution per modality in ``subset``.
        """
        px_zs = [
            self.decoders[i](
                qz_x._sample(training=self._training, return_mean=self.return_mean)
            )
            for i, qz_x in enumerate(qz_xs)
            if i in subset
        ]
        return [px_zs]

    def forward(self, x):
        r"""Apply encode and decode methods to input data to generate latent dimensions and data reconstructions.

        Args:
            x (list): list of input data of type torch.Tensor.

        Returns:
            fwd_rtn (dict): dictionary containing encoding (qz_xs) and decoding (px_zs) distributions.
        """
        qz_xs = self.encode(x)
        px_zs = self.decode(qz_xs)
        return {"qz_xs": qz_xs, "px_zs": px_zs}

    def loss_function(self, x, fwd_rtn):
        r"""Wrapper function for mmVAE loss.

        Args:
            x (list): list of input data of type torch.Tensor.
            fwd_rtn (dict): dictionary containing encoding and decoding distributions.

        Returns:
            losses (dict): dictionary containing mmVAE loss under the key ``"loss"``.
        """
        qz_xs, px_zs = fwd_rtn["qz_xs"], fwd_rtn["px_zs"]
        # moe_iwae returns a bound to maximise, so the loss is its negative.
        total = -self.moe_iwae(x, qz_xs, px_zs)
        return {"loss": total}

    def moe_iwae(self, x, qz_xs, px_zs):
        r"""Calculate Mixture-of-Experts importance weighted autoencoder (IWAE) loss used for the mmVAE model.

        Args:
            x (list): list of input data of type torch.Tensor.
            qz_xs (list): list of encoding distributions, one per view.
            px_zs (list): nested list of decoding distributions as returned by ``decode``
                (outer index: latent, inner index: decoder).

        Returns:
            (torch.Tensor): the output tensor. The DReG estimate when ``DREG_loss`` is set,
            otherwise the looser IWAE bound.
        """
        if self.DREG_loss:
            # DReG: rebuild every posterior from detached parameters so that
            # the density terms below do not backpropagate into loc / logvar.
            qz_xs = [
                hydra.utils.instantiate(
                    self._enc_dist_cfg(i),
                    loc=qz_x.loc.detach(),
                    logvar=qz_x.logvar.detach(),
                )
                for i, qz_x in enumerate(qz_xs)
            ]

        # One set of latent samples per view.
        zss = []
        for i in range(self.n_views):
            if self._training:
                zs = qz_xs[i].rsample(torch.Size([self.K]))
            else:
                zs = qz_xs[i]._sample()
            zss.append(zs)

        lws = []
        for r, zs in enumerate(zss):
            lpz = self.prior.log_likelihood(zs).sum(-1)
            # log of the mean, over all M modalities, of q(z|x_m) evaluated at
            # these samples, i.e. the mixture-of-experts posterior q(z|x1:M).
            lqz_x = self.log_mean_exp(
                torch.stack([qz_x.log_likelihood(zs).sum(-1) for qz_x in qz_xs])
            )
            # Per decoder: reshape the log-likelihood to the first two dims of
            # a decoder sample and sum over everything else.
            lpx_z = [
                px_z.log_likelihood(x[d]).view(*px_z._sample().size()[:2], -1).sum(-1)
                for d, px_z in enumerate(px_zs[r])
            ]
            lpx_z = torch.stack(lpx_z).sum(0)  # summing over each decoder
            lws.append(lpx_z + self.beta * (lpz - lqz_x))

        if self.DREG_loss:
            # Both lists must become tensors: torch.logsumexp and the
            # subtraction below do not accept a Python list.
            lws = torch.stack(lws)
            zss = torch.stack(zss)
            with torch.no_grad():
                grad_wt = (lws - torch.logsumexp(lws, 1, keepdim=True)).exp()
                if zss.requires_grad:
                    zss.register_hook(lambda grad: grad_wt.unsqueeze(-1) * grad)
            # NOTE(review): only the view axis is reduced here, so unlike the
            # IWAE branch this is not a scalar - confirm the caller reduces it.
            return (grad_wt * lws).sum(0) / self.n_views  # DReG loss

        return (
            self.log_mean_exp(torch.stack(lws), dim=1).mean(0).sum() / self.n_views
        )  # looser iwae bound

    def log_mean_exp(self, value, dim=0, keepdim=False):
        r"""Returns the log of the mean of the exponentials along the given dimension (dim).

        Computed as ``logsumexp(value, dim) - log(value.size(dim))``.

        Args:
            value (torch.Tensor): the input tensor.
            dim (int, optional): the dimension along which to take the mean.
            keepdim (bool, optional): whether the output tensor has dim retained or not.

        Returns:
            (torch.Tensor): the output tensor.
        """
        return torch.logsumexp(value, dim, keepdim=keepdim) - math.log(value.size(dim))