# Source code for multiviewae.models.weighted_mvae

import torch
import hydra

from ..base.constants import MODEL_WEIGHTEDMVAE
from ..base.base_model import BaseModelVAE
from ..base.representations import weightedProductOfExperts

class weighted_mVAE(BaseModelVAE):
    r"""
    Generalised Product-of-Experts Variational Autoencoder (gPoE-MVAE).

    Args:
        cfg (str): Path to configuration file. Model specific parameters in addition to default parameters:

            - model.beta (int, float): KL divergence weighting term.
            - model.private (bool): Whether to include private view-specific latent dimensions.
            - model.sparse (bool): Whether to enforce sparsity of the encoding distribution.
            - model.threshold (float): Dropout threshold applied to the latent dimensions. Default is 0.
            - encoder.default._target_ (multiviewae.architectures.mlp.VariationalEncoder): Type of encoder class to use.
            - encoder.default.enc_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Encoding distribution.
            - decoder.default._target_ (multiviewae.architectures.mlp.VariationalDecoder): Type of decoder class to use.
            - decoder.default.init_logvar (int, float): Initial value for log variance of decoder.
            - decoder.default.dec_dist._target_ (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal): Decoding distribution.

        input_dim (list): Dimensionality of the input data.
        z_dim (int): Number of latent dimensions.

    The per-view encodings are fused with a learnable weighted product of experts
    (``self.poe_weight``, one weight per view and fused latent dimension). The weights
    are clamped to [0, 1] in place each time ``encode`` runs.

    References
    ----------
    Cao, Y., & Fleet, D. (2014). Generalized Product of Experts for Automatic and Principled Fusion of Gaussian Process Predictions. arXiv.
    Lawry Aguila, A., Chapman, J., Altmann, A. (2023). Multi-modal Variational Autoencoders for normative modelling across multiple imaging modalities. arXiv
    """

    def __init__(
        self,
        cfg=None,
        input_dim=None,
        z_dim=None,
    ):
        super().__init__(model_name=MODEL_WEIGHTEDMVAE,
                         cfg=cfg,
                         input_dim=input_dim,
                         z_dim=z_dim)

        self.join_z = weightedProductOfExperts()

        # Configs that do not define model.private fall back to a purely joint latent space.
        if not hasattr(self, 'private'):
            self.private = False

        # With private latents only the shared columns (those after the first s_dim of each
        # encoder output) are fused, so the weight matrix is narrower by s_dim.
        # self.n_views (set by the base class, used throughout this class) is preferred over
        # len(input_dim) so the sizing does not depend on input_dim being passed explicitly.
        fused_dim = self.z_dim - self.s_dim if self.private else self.z_dim
        init_weight = torch.full((self.n_views, fused_dim), 1 / self.n_views, dtype=torch.float32)
        self.poe_weight = torch.nn.Parameter(data=init_weight, requires_grad=True)

    def encode(self, x):
        r"""Forward pass through encoder networks.

        Args:
            x (list): list of input data of type torch.Tensor.

        Returns:
            Returns a combination of the following depending on the training stage and model type:
            qz_x (list): Single element list containing the PoE encoding distribution for self.private=False.
            qc_x (list): Single element list containing the PoE shared encoding distribution.
            qs_xs (list): list of encoding distributions for private latents.
            qscs_xs (list): list containing combined shared and private latents.
            qcs_xs (list): list containing encoding distributions for shared latent dimensions for each view.

            With self.private=True the result is ``[[qc_x], qcs_xs, qs_xs, qscs_xs]`` while
            ``self._training`` is truthy and ``qscs_xs`` otherwise; with self.private=False it
            is ``[qz_x]``.
        """
        if not hasattr(self, 'private'):
            self.private = False
        if self.private:
            return self._encode_private_shared(x)
        return self._encode_joint(x)

    def _view_enc_dist(self, i):
        """Return the encoding-distribution config ``cfg.encoder.enc{i}.enc_dist`` of view ``i``."""
        # Plain attribute lookup; replaces the former eval() of a formatted attribute path.
        return getattr(self.cfg.encoder, f"enc{i}").enc_dist

    def _clamp_poe_weight(self):
        """Clamp the PoE weights to [0, 1] in place without recording the op in autograd."""
        # Called after the weights have been used in the current pass, so the clamped
        # values take effect from the next call onwards.
        with torch.no_grad():
            self.poe_weight.clamp_(0, 1)

    def _encode_private_shared(self, x):
        """Encode every view into private and shared parts; fuse the shared parts with the weighted PoE."""
        qs_xs = []
        qcs_xs = []
        mu_s = []
        logvar_s = []
        mu_c = []
        logvar_c = []
        for i in range(self.n_views):
            mu, logvar = self.encoders[i](x[i])
            enc_dist = self._view_enc_dist(i)
            # The first s_dim columns are private to view i, the remaining ones are shared.
            mu_priv, logvar_priv = mu[:, :self.s_dim], logvar[:, :self.s_dim]
            mu_shared, logvar_shared = mu[:, self.s_dim:], logvar[:, self.s_dim:]
            mu_s.append(mu_priv)
            logvar_s.append(logvar_priv)
            mu_c.append(mu_shared)
            logvar_c.append(logvar_shared)
            qs_xs.append(hydra.utils.instantiate(enc_dist, loc=mu_priv, logvar=logvar_priv))
            qcs_xs.append(hydra.utils.instantiate(enc_dist, loc=mu_shared, logvar=logvar_shared))

        # Fuse the stacked (view-major) shared parameters into one shared posterior.
        mu_c, logvar_c = self.join_z(torch.stack(mu_c), torch.stack(logvar_c), self.poe_weight)
        qc_x = hydra.utils.instantiate(
            self.cfg.encoder.default.enc_dist, loc=mu_c, logvar=logvar_c
        )
        self._clamp_poe_weight()

        # Each view's combined latent is [its private part | the fused shared part].
        qscs_xs = []
        for i in range(self.n_views):
            qscs_xs.append(hydra.utils.instantiate(
                self._view_enc_dist(i),
                loc=torch.cat((mu_s[i], mu_c), 1),
                logvar=torch.cat((logvar_s[i], logvar_c), 1),
            ))

        if self._training:
            return [[qc_x], qcs_xs, qs_xs, qscs_xs]
        return qscs_xs

    def _encode_joint(self, x):
        """Encode every view and fuse all latent dimensions with the weighted PoE."""
        mu = []
        logvar = []
        for i in range(self.n_views):
            mu_, logvar_ = self.encoders[i](x[i])
            mu.append(mu_)
            logvar.append(logvar_)
        mu_out, logvar_out = self.join_z(torch.stack(mu), torch.stack(logvar), self.poe_weight)
        qz_x = hydra.utils.instantiate(
            self.cfg.encoder.default.enc_dist, loc=mu_out, logvar=logvar_out
        )
        self._clamp_poe_weight()
        return [qz_x]

    def decode(self, qz_x):
        r"""Forward pass of joint latent dimensions through decoder networks.

        Args:
            qz_x (list): list of encoding distributions. With self.private=True view ``i`` is
                decoded from ``qz_x[i]`` (its combined private and shared latents); otherwise
                every view is decoded from the single joint distribution ``qz_x[0]``.

        Returns:
            (list): A nested list of decoding distributions, px_zs. The outer list has a single element indicating the shared latent dimensions.
            The inner list is a n_view element list with the position in the list indicating the decoder index.
        """
        px_zs = []
        for i in range(self.n_views):
            qz = qz_x[i] if self.private else qz_x[0]
            z = qz._sample(training=self._training, return_mean=self.return_mean)
            px_zs.append(self.decoders[i](z))
        return [px_zs]

    def forward(self, x):
        r"""Apply encode and decode methods to input data to generate the joint latent dimensions and data reconstructions.

        Args:
            x (list): list of input data of type torch.Tensor.

        Returns:
            fwd_rtn (dict): dictionary containing encoding and decoding distributions.
        """
        if self.private:
            # NOTE: encode only returns this 4-element structure while self._training is
            # truthy; otherwise it returns the per-view list qscs_xs.
            qc_x, qcs_xs, qs_xs, qscs_xs = self.encode(x)
            px_zs = self.decode(qscs_xs)
            fwd_rtn = {"px_zs": px_zs, "qcs_xs": qcs_xs, "qs_xs": qs_xs, "qc_x": qc_x}
        else:
            qz_x = self.encode(x)
            px_zs = self.decode(qz_x)
            fwd_rtn = {"px_zs": px_zs, "qz_x": qz_x}
        return fwd_rtn

    def calc_kl(self, qz_x):
        r"""Calculate KL-divergence loss.

        Args:
            qz_x (list): Single element list containing joint encoding distribution.

        Returns:
            (torch.Tensor): KL-divergence loss, scaled by self.beta.
        """
        kl = qz_x[0].kl_divergence(self.prior).mean(0).sum()
        return self.beta * kl

    def calc_kl_separate(self, qc_xs):
        r"""Calculate KL-divergence loss.

        Args:
            qc_xs (list): list of encoding distributions for private/shared latent dimensions for each view.

        Returns:
            (torch.Tensor): KL-divergence loss averaged over views and scaled by self.beta.
        """
        kl = 0
        for i in range(self.n_views):
            kl += qc_xs[i].kl_divergence(self.prior).mean(0).sum()
        return self.beta * kl / self.n_views

    def calc_ll(self, x, px_zs):
        r"""Calculate log-likelihood loss.

        Args:
            x (list): list of input data of type torch.Tensor.
            px_zs (list): list of decoding distributions.

        Returns:
            ll (torch.Tensor): Log-likelihood loss averaged over views.
        """
        ll = 0
        for i in range(self.n_views):
            # first index is latent, second index is view
            ll += px_zs[0][i].log_likelihood(x[i]).mean(0).sum()
        return ll / self.n_views

    def loss_function(self, x, fwd_rtn):
        r"""Calculate Multimodal VAE loss.

        Args:
            x (list): list of input data of type torch.Tensor.
            fwd_rtn (dict): dictionary containing encoding and decoding distributions.

        Returns:
            losses (dict): dictionary containing each element of the MVAE loss.
        """
        px_zs = fwd_rtn["px_zs"]
        ll = self.calc_ll(x, px_zs)
        if self.private:
            kl = self.calc_kl_separate(fwd_rtn["qs_xs"])  # calc kl for private latents
            kl += self.calc_kl(fwd_rtn["qc_x"])  # calc kl for shared latents
        else:
            kl = self.calc_kl(fwd_rtn["qz_x"])
        total = kl - ll
        losses = {"loss": total, "kl": kl, "ll": ll}
        return losses