Source code for multiviewae.models.mopoevae

import torch
import hydra
from ..base.constants import MODEL_MOPOEVAE
from ..base.base_model import BaseModelVAE
from itertools import combinations
from ..base.representations import ProductOfExperts, MixtureOfExperts
import numpy as np
[docs]class MoPoEVAE(BaseModelVAE): r""" Mixture-of-Product-of-Experts Variational Autoencoder. Code is based on: https://github.com/thomassutter/MoPoE Args: cfg (str): Path to configuration file. Model specific parameters in addition to default parameters: - model.beta (int, float): KL divergence weighting term. - encoder.default._target_ (multiviewae.architectures.mlp.VariationalEncoder): Type of encoder class to use. - encoder.default.enc_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Encoding distribution. - decoder.default._target_ (multiviewae.architectures.mlp.VariationalDecoder): Type of decoder class to use. - decoder.default.init_logvar (int, float): Initial value for log variance of decoder. - decoder.default.dec_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Decoding distribution. input_dim (list): Dimensionality of the input data. z_dim (int): Number of latent dimensions. References ---------- Sutter, Thomas & Daunhawer, Imant & Vogt, Julia. (2021). Generalized Multimodal ELBO. """ def __init__( self, cfg = None, input_dim = None, z_dim = None ): super().__init__(model_name=MODEL_MOPOEVAE, cfg=cfg, input_dim=input_dim, z_dim=z_dim) self.subsets = self.set_subsets() if self.weight_ll: self.ll_weighting = 1/self.n_views else: self.ll_weighting = 1
[docs] def encode(self, x): r"""Forward pass through encoder networks. Args: x (list): list of input data of type torch.Tensor. Returns: (list): list containing the MoE joint encoding distribution. If training, the model also returns the encoding distribution for each subset. """ mu = [] logvar = [] for i in range(self.n_views): mu_, logvar_ = self.encoders[i](x[i]) mu.append(mu_) logvar.append(logvar_) mu = torch.stack(mu) logvar = torch.stack(logvar) mu_out = [] logvar_out = [] if self._training: qz_xs = [] for subset in self.subsets: mu_s = mu[subset] logvar_s = logvar[subset] if len(subset) == self.n_views: mu_ = self.prior.mean mu_ = mu_.expand(mu[0].shape).to(mu[0].device) logvar_ = torch.log(self.prior.variance).to(mu[0].device) logvar_ = logvar_.expand(logvar[0].shape) mu_ = mu_.unsqueeze(0) logvar_ = logvar_.unsqueeze(0) mu_s = torch.cat([mu_s, mu_], dim=0) logvar_s = torch.cat([logvar_s, logvar_], dim=0) mu_s, logvar_s = ProductOfExperts()(mu_s, logvar_s) mu_out.append(mu_s) logvar_out.append(logvar_s) qz_x = hydra.utils.instantiate( eval(f"self.cfg.encoder.enc{i}.enc_dist"), loc=mu_s, logvar=logvar_s ) qz_xs.append(qz_x) mu_out = torch.stack(mu_out) logvar_out = torch.stack(logvar_out) moe_mu, moe_logvar = MixtureOfExperts()(mu_out, logvar_out) qz_x = hydra.utils.instantiate( self.cfg.encoder.default.enc_dist, loc=moe_mu, logvar=moe_logvar ) return [qz_xs, qz_x] else: for i in range(self.n_views): mu_s = mu[i] logvar_s = logvar[i] mu_out.append(mu_s) logvar_out.append(logvar_s) mu_ = self.prior.mean mu_ = mu_.expand(mu[0].shape).to(mu[0].device) logvar_ = torch.log(self.prior.variance).to(mu[0].device) logvar_ = logvar_.expand(logvar[0].shape) mu_out.append(mu_) logvar_out.append(logvar_) mu_out = torch.stack(mu_out) logvar_out = torch.stack(logvar_out) mu, logvar = ProductOfExperts()(mu_out, logvar_out) qz_x = hydra.utils.instantiate( self.cfg.encoder.default.enc_dist, loc=mu, logvar=logvar ) return [qz_x]
[docs] def encode_subset(self, x, subset): r""" Forward pass through encoder networks for a subset of modalities. Args: x (list): list of input data of type torch.Tensor. subset (list): list of modalities to encode. Returns: (list): list containing the MoE joint encoding distribution. """ mu = [] logvar = [] for i in subset: mu_, logvar_ = self.encoders[i](x[i]) mu.append(mu_) logvar.append(logvar_) if len(subset) == self.n_views: mu_ = self.prior.mean mu_ = mu_.expand(mu[0].shape).to(mu[0].device) logvar_ = torch.log(self.prior.variance).to(mu[0].device) logvar_ = logvar_.expand(logvar[0].shape) mu.append(mu_) logvar.append(logvar_) mu = torch.stack(mu) logvar = torch.stack(logvar) mu, logvar = ProductOfExperts()(mu, logvar) qz_x = hydra.utils.instantiate( self.cfg.encoder.default.enc_dist, loc=mu, logvar=logvar ) return [qz_x]
def decode_subset(self, qz_x, subset): px_zs = [] for i in range(self.n_views): if i in subset: px_z = self.decoders[i](qz_x[0]._sample(training=self._training, return_mean=self.return_mean)) px_zs.append(px_z) return [px_zs]
[docs] def decode(self, qz_x): r"""Forward pass of joint latent dimensions through decoder networks. Args: x (list): list of input data of type torch.Tensor. Returns: (list): A nested list of decoding distributions, px_zs. The outer list has a single element indicating the shared latent dimensions. The inner list is a n_view element list with the position in the list indicating the decoder index. """ px_zs = [] for i in range(self.n_views): px_z = self.decoders[i](qz_x[0]._sample(training=self._training, return_mean=self.return_mean)) px_zs.append(px_z) return [px_zs]
[docs] def forward(self, x): r"""Apply encode and decode methods to input data to generate the joint and subset latent dimensions and data reconstructions. Args: x (list): list of input data of type torch.Tensor. Returns: fwd_rtn (dict): dictionary containing encoding and decoding distributions. """ qz_xs, qz_x = self.encode(x) px_zs = self.decode([qz_x]) fwd_rtn = {"px_zs": px_zs, "qz_xs_subsets": qz_xs, "qz_x_joint": qz_x} return fwd_rtn
[docs] def loss_function(self, x, fwd_rtn): r"""Calculate MoPoE VAE loss. Args: x (list): list of input data of type torch.Tensor. fwd_rtn (dict): dictionary containing encoding and decoding distributions. Returns: losses (dict): dictionary containing each element of the MoPoE VAE loss. """ px_zs = fwd_rtn["px_zs"] qz_xs = fwd_rtn["qz_xs_subsets"] kl = self.calc_kl_moe(qz_xs) ll = self.calc_ll(x, px_zs) total = self.beta * kl - ll losses = {"loss": total, "kl": kl, 'll': ll} return losses
[docs] def calc_kl_moe(self, qz_xs): r"""Calculate KL-divergence between the each PoE subset posterior and the prior distribution. Args: qz_xs (list): list of encoding distributions. Returns: (torch.Tensor): KL-divergence loss. """ weight = 1/len(qz_xs) kl = 0 for qz_x in qz_xs: kl +=qz_x.kl_divergence(self.prior).mean(0).sum() return kl*weight
[docs] def set_subsets(self, n_views=None): """Create combinations of subsets of views. Returns: subset_list (list): list of unique combinations of n_views. """ if n_views is None: n_views = self.n_views xs = list(range(0, n_views)) tmp = [list(combinations(xs, n+1)) for n in range(len(xs))] subset_list = [list(item) for sublist in tmp for item in sublist] return subset_list
[docs] def calc_ll(self, x, px_zs): r"""Calculate log-likelihood loss. Args: x (list): list of input data of type torch.Tensor. px_zs (list): list of decoding distributions. Returns: ll (torch.Tensor): Log-likelihood loss. """ ll = 0 for i in range(self.n_views): ll += px_zs[0][i].log_likelihood(x[i]).mean(0).sum()*self.ll_weighting #first index is latent, second index is view return ll