# Source code for multiviewae.models.mmJSD

import torch
import hydra

from ..base.constants import MODEL_MMJSD
from ..base.base_model import BaseModelVAE
from ..base.representations import MixtureOfExperts, alphaProductOfExperts

class mmJSD(BaseModelVAE):
    r"""
    Multimodal Jensen-Shannon divergence (mmJSD) model with Product-of-Experts dynamic prior.

    Code is based on: https://github.com/thomassutter/mmjsd

    Args:
        cfg (str): Path to configuration file. Model specific parameters in addition to default parameters:

            - model.private (bool): Whether to include private modality-specific latent dimensions.
            - model.beta (int, float): KL divergence weighting term.
            - model.alpha (int, float): JSD divergence weighting term.
            - model.s_dim (int): Number of private latent dimensions.
            - encoder.default._target_ (multiviewae.architectures.mlp.VariationalEncoder): Type of encoder class to use.
            - encoder.default.enc_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Encoding distribution.
            - decoder.default._target_ (multiviewae.architectures.mlp.VariationalDecoder): Type of decoder class to use.
            - decoder.default.init_logvar(int, float): Initial value for log variance of decoder.
            - decoder.default.dec_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Decoding distribution.

        input_dim (list): Dimensionality of the input data.
        z_dim (int): Number of latent dimensions.

    References
    ----------
    Sutter, Thomas & Daunhawer, Imant & Vogt, Julia. (2021). Multimodal Generative Learning Utilizing Jensen-Shannon-Divergence. Advances in Neural Information Processing Systems. 33.
    """

    def __init__(
        self,
        cfg = None,
        input_dim = None,
        z_dim = None
    ):
        super().__init__(model_name=MODEL_MMJSD,
                        cfg=cfg,
                        input_dim=input_dim,
                        z_dim=z_dim)
        # When weight_ll is set, every view's log-likelihood is scaled by
        # 1/n_views; otherwise the per-view terms are summed unweighted.
        self.ll_weighting = 1/self.n_views if self.weight_ll else 1

    def _view_enc_dist(self, view, loc, logvar):
        r"""Instantiate the encoding distribution configured for encoder ``view``.

        Args:
            view (int): Index of the encoder whose ``enc{view}.enc_dist`` config entry is used.
            loc (torch.Tensor): Location parameter passed to the distribution.
            logvar (torch.Tensor): Log-variance parameter passed to the distribution.

        Returns:
            The object built by ``hydra.utils.instantiate`` from ``self.cfg.encoder.enc{view}.enc_dist``.
        """
        # getattr performs the same attribute lookup the dotted path would,
        # without evaluating a dynamically built source string.
        dist_cfg = getattr(self.cfg.encoder, f"enc{view}").enc_dist
        return hydra.utils.instantiate(dist_cfg, loc=loc, logvar=logvar)

    def _joint_enc_dist(self, loc, logvar):
        r"""Instantiate the default encoding distribution (``self.cfg.encoder.default.enc_dist``).

        Args:
            loc (torch.Tensor): Location parameter passed to the distribution.
            logvar (torch.Tensor): Log-variance parameter passed to the distribution.

        Returns:
            The object built by ``hydra.utils.instantiate`` from the default encoder distribution config.
        """
        return hydra.utils.instantiate(
            self.cfg.encoder.default.enc_dist, loc=loc, logvar=logvar
        )

    def _prior_expert(self, mu_ref, logvar_ref):
        r"""Build the prior's mean and log-variance, expanded to match the encoder outputs.

        Args:
            mu_ref (torch.Tensor): Encoder mean whose shape and device the prior mean should match.
            logvar_ref (torch.Tensor): Encoder log-variance whose shape the prior log-variance should match.

        Returns:
            (tuple): ``(mu, logvar)`` tensors of the prior expert.
        """
        prior_mu = self.prior.mean
        prior_mu = prior_mu.expand(mu_ref.shape).to(mu_ref.device)
        prior_logvar = torch.log(self.prior.variance).to(mu_ref.device)
        prior_logvar = prior_logvar.expand(logvar_ref.shape)
        return prior_mu, prior_logvar

    def _encode_private(self, x):
        r"""Encode with private (modality-specific) and shared latent dimensions.

        Args:
            x (list): list of input data of type torch.Tensor.

        Returns:
            ``[[qc_x], qcs_xs, qs_xs, qscs_xs]`` when ``self._training`` is truthy, otherwise ``qscs_xs`` only.
        """
        mu_s, logvar_s = [], []
        mu_c, logvar_c = [], []
        qs_xs, qcs_xs = [], []
        for i in range(self.n_views):
            mu, logvar = self.encoders[i](x[i])
            # The first s_dim columns are the private part, the rest is shared.
            mu_s.append(mu[:, :self.s_dim])
            logvar_s.append(logvar[:, :self.s_dim])
            mu_c.append(mu[:, self.s_dim:])
            logvar_c.append(logvar[:, self.s_dim:])
            qs_xs.append(self._view_enc_dist(i, mu_s[-1], logvar_s[-1]))
            qcs_xs.append(self._view_enc_dist(i, mu_c[-1], logvar_c[-1]))

        mu_c = torch.stack(mu_c)
        logvar_c = torch.stack(logvar_c)
        # Both experts are always evaluated, in this order, as in the original code.
        moe_mu_c, moe_logvar_c = MixtureOfExperts()(mu_c, logvar_c)
        poe_mu_c, poe_logvar_c = alphaProductOfExperts()(mu_c, logvar_c)

        # Training combines the private part with the MoE shared latent,
        # evaluation with the PoE shared latent. Only the set that is
        # actually returned is built.
        if self._training:
            shared_mu, shared_logvar = moe_mu_c, moe_logvar_c
        else:
            shared_mu, shared_logvar = poe_mu_c, poe_logvar_c

        qscs_xs = [
            self._view_enc_dist(
                i,
                torch.cat((mu_s[i], shared_mu), 1),
                torch.cat((logvar_s[i], shared_logvar), 1),
            )
            for i in range(self.n_views)
        ]

        if self._training:
            qc_x = self._joint_enc_dist(poe_mu_c, poe_logvar_c)
            return [[qc_x], qcs_xs, qs_xs, qscs_xs]
        return qscs_xs

    def _encode_joint(self, x):
        r"""Encode without private latents; a prior expert is added to the per-view experts.

        Args:
            x (list): list of input data of type torch.Tensor.

        Returns:
            ``[[qz_xs], qzs_xs, [qz_x]]`` when ``self._training`` is truthy, otherwise ``[qz_x]``.
        """
        mu, logvar, qzs_xs = [], [], []
        for i in range(self.n_views):
            mu_, logvar_ = self.encoders[i](x[i])
            mu.append(mu_)
            logvar.append(logvar_)
            qzs_xs.append(self._view_enc_dist(i, mu_, logvar_))

        # Add prior expert.
        prior_mu, prior_logvar = self._prior_expert(mu[0], logvar[0])
        mu.append(prior_mu)
        logvar.append(prior_logvar)
        # The prior expert uses the last view's distribution config. The
        # original relied on the loop index leaking out of the loop above;
        # the index is spelled out here so the choice is visible.
        qzs_xs.append(self._view_enc_dist(self.n_views - 1, prior_mu, prior_logvar))

        mu = torch.stack(mu)
        logvar = torch.stack(logvar)

        moe_mu, moe_logvar = MixtureOfExperts()(mu, logvar)
        qz_xs = self._joint_enc_dist(moe_mu, moe_logvar)

        poe_mu, poe_logvar = alphaProductOfExperts()(mu, logvar)
        qz_x = self._joint_enc_dist(poe_mu, poe_logvar)

        if self._training:
            return [[qz_xs], qzs_xs, [qz_x]]
        return [qz_x]

    def encode(self, x):
        r"""Forward pass through encoder network.

        If self.private=True, the first ``self.s_dim`` dimensions of each latent are used for the
        modality-specific part and the remaining dimensions for the joint content.

        Args:
            x (list): list of input data of type torch.Tensor.

        Returns:
            Returns a combination of the following depending on the training stage and model type:

            - self.private=True, training: ``[[qc_x], qcs_xs, qs_xs, qscs_xs]``.
            - self.private=True, not training: ``qscs_xs``.
            - self.private=False, training: ``[[qz_xs], qzs_xs, [qz_x]]``.
            - self.private=False, not training: ``[qz_x]``.

            qz_xs (list): Single element list containing the MoE encoding distribution for self.private=False.
            qzs_xs (list): list containing each encoding distribution for self.private=False, followed by the prior expert.
            qz_x (list): Single element list containing the PoE encoding distribution for self.private=False.
            qc_x (list): Single element list containing the PoE shared encoding distribution.
            qscs_xs (list): list containing combined shared and private latents (MoE shared part when training, PoE shared part otherwise).
            qs_xs (list): list of encoding distributions for private latents.
            qcs_xs (list): list containing encoding distributions for shared latent dimensions for each view.
        """
        if self.private:
            return self._encode_private(x)
        return self._encode_joint(x)

    def encode_subset(self, x, subset):
        r"""Forward pass through encoder networks for a subset of modalities.

        Args:
            x (list): list of input data of type torch.Tensor.
            subset (list): list of modalities to encode.

        Returns:
            (list): list containing the PoE joint encoding distribution.

        Raises:
            AssertionError: If the model was configured with private latents.
        """
        assert not self.private, \
            "Subset feature only works for private=False (for now)."

        mu, logvar = [], []
        for i in subset:
            mu_, logvar_ = self.encoders[i](x[i])
            mu.append(mu_)
            logvar.append(logvar_)

        # The prior expert is only added when more than one modality is encoded.
        if len(subset) > 1:
            prior_mu, prior_logvar = self._prior_expert(mu[0], logvar[0])
            mu.append(prior_mu)
            logvar.append(prior_logvar)

        mu = torch.stack(mu)
        logvar = torch.stack(logvar)

        poe_mu, poe_logvar = alphaProductOfExperts()(mu, logvar)
        return [self._joint_enc_dist(poe_mu, poe_logvar)]

    def decode(self, qz_x):
        r"""Forward pass of latent dimensions through decoder networks.

        Args:
            qz_x (list): list of encoding distributions. For self.private=True element ``i`` is
                decoded by decoder ``i``; otherwise element 0 is decoded by every decoder.

        Returns:
            (list): A nested list of decoding distributions, px_zs. The outer list has a single element, the inner list is a n_view element list with the position in the list indicating the decoder index.
        """
        px_zs = []
        for i in range(self.n_views):
            latent = qz_x[i] if self.private else qz_x[0]
            # A separate sample is drawn for every decoder.
            z = latent._sample(training=self._training, return_mean=self.return_mean)
            px_zs.append(self.decoders[i](z))
        return [px_zs]

    def decode_subset(self, qz_x, subset):
        r"""Forward pass of latent dimensions through decoder networks for a subset of modalities.

        Args:
            qz_x (list): list whose first element is the joint encoding distribution to decode.
            subset (list): list of modalities to decode.

        Returns:
            (list): A nested list of decoding distributions, px_zs. The outer list has a single element, the inner list has one element per entry of ``subset``, in the same order.

        Raises:
            AssertionError: If the model was configured with private latents.
        """
        assert not self.private, \
            "Subset feature only works for private=False (for now)."

        px_zs = []
        for i in subset:
            z = qz_x[0]._sample(training=self._training, return_mean=self.return_mean)
            px_zs.append(self.decoders[i](z))
        return [px_zs]

    def forward(self, x):
        r"""Apply encode and decode methods to input data to generate the joint and modality specific latent dimensions and data reconstructions.

        Args:
            x (list): list of input data of type torch.Tensor.

        Returns:
            fwd_rtn (dict): dictionary containing encoding and decoding distributions.
        """
        # NOTE(review): the unpacking below matches the training-mode return
        # of encode(); outside training encode() returns a different
        # structure -- confirm forward() is only called with _training set.
        if self.private:
            qc_x, qcs_xs, qs_xs, qscs_xs = self.encode(x)
            px_zs = self.decode(qscs_xs)
            fwd_rtn = {"px_zs": px_zs, "qcs_xs": qcs_xs, "qs_xs": qs_xs, "qc_x": qc_x}
        else:
            qz_xs, qzs_xs, qz_x = self.encode(x)
            px_zs = self.decode(qz_xs)
            fwd_rtn = {"px_zs": px_zs, "qzs_xs": qzs_xs, "qz_x": qz_x}
        return fwd_rtn

    def calc_kl(self, qz_xs):
        r"""Calculate KL-divergence loss.

        Args:
            qz_xs (list): list of encoding distributions.

        Returns:
            (torch.Tensor): KL-divergence loss, weighted by ``self.beta`` and divided by ``self.n_views``.
        """
        kl = 0
        for qz_x in qz_xs:
            kl += qz_x.kl_divergence(self.prior).sum(1, keepdims=True).mean(0)
        return self.beta*kl/self.n_views

    def calc_ll(self, x, px_zs):
        r"""Calculate log-likelihood loss.

        Args:
            x (list): list of input data of type torch.Tensor.
            px_zs (list): list of decoding distributions.

        Returns:
            ll (torch.Tensor): Log-likelihood loss.
        """
        ll = 0
        for i in range(self.n_views):
            # First index is latent, second index is view.
            ll += px_zs[0][i].log_likelihood(x[i]).mean(0).sum()*self.ll_weighting
        return ll

    def calc_jsd(self, qcs_xs, qc_x):
        r"""Calculate Jensen-Shannon Divergence loss.

        Args:
            qcs_xs (list): list of encoding distributions of each view for shared latent dimensions.
            qc_x (list): Dynamic prior given by PoE of shared encoding distributions.

        Returns:
            jsd (torch.Tensor): Jensen-Shannon Divergence loss, weighted by ``self.alpha`` and divided by ``self.n_views + 1``.
        """
        # NOTE(review): only the first n_views entries are summed. For
        # self.private=False the list passed in also carries the prior expert
        # as a final entry, which is therefore skipped while the normaliser is
        # n_views + 1 -- confirm this is intended.
        jsd = 0
        for i in range(self.n_views):
            jsd += qcs_xs[i].kl_divergence(qc_x[0]).sum(1, keepdims=True).mean(0)
        return self.alpha*jsd/(self.n_views+1)

    def loss_function(self, x, fwd_rtn):
        r"""Calculate mmJSD loss.

        Args:
            x (list): list of input data of type torch.Tensor.
            fwd_rtn (dict): dictionary containing encoding and decoding distributions.

        Returns:
            losses (dict): dictionary containing each element of the mmJSD loss.
        """
        px_zs = fwd_rtn["px_zs"]
        ll = self.calc_ll(x, px_zs)

        if self.private:
            qcs_xs = fwd_rtn["qcs_xs"]
            qs_xs = fwd_rtn["qs_xs"]
            qc_x = fwd_rtn["qc_x"]

            jsd = self.calc_jsd(qcs_xs, qc_x)
            kl = self.calc_kl(qs_xs)
            total = kl + jsd - ll
            losses = {"loss": total, "kl": kl, "ll": ll, "jsd": jsd}
        else:
            qz_x = fwd_rtn["qz_x"]
            qzs_xs = fwd_rtn["qzs_xs"]

            jsd = self.calc_jsd(qzs_xs, qz_x)
            total = jsd - ll
            losses = {"loss": total, "ll": ll, "jsd": jsd}
        return losses