Source code for multiviewae.models.dccae

from ..base.constants import MODEL_DCCAE, EPS
from ..base.base_model import BaseModelAE
import torch

class DCCAE(BaseModelAE):
    r"""Deep Canonically Correlated Autoencoder (DCCAE).

    CCA implementation adapted from: https://github.com/jameschapman19/cca_zoo

    Args:
        cfg (str): Path to configuration file. Model specific parameters in
            addition to default parameters:

            - model._lambda (int, float): Reconstruction weighting term.

        input_dim (list): Dimensionality of the input data.
        z_dim (int): Number of latent dimensions.

    References
    ----------
    Wang, Weiran & Arora, Raman & Livescu, Karen & Bilmes, Jeff. (2016).
    On Deep Multi-View Representation Learning: Objectives and Optimization.
    """

    def __init__(
        self,
        cfg=None,
        input_dim=None,
        z_dim=None,
    ):
        super().__init__(model_name=MODEL_DCCAE, cfg=cfg, input_dim=input_dim, z_dim=z_dim)

    def encode(self, x):
        r"""Forward pass through encoder networks.

        View ``i`` is passed through encoder ``i``.

        Args:
            x (list): list of input data of type torch.Tensor, one entry per view.

        Returns:
            z (list): list of latent representations for each view of type torch.Tensor.
        """
        # Index over n_views (not zip) so a too-short ``x`` still raises IndexError.
        return [self.encoders[i](x[i]) for i in range(self.n_views)]

    def decode(self, z):
        r"""Forward pass through decoder networks.

        The latent of view ``i`` is passed through decoder ``i`` only.

        Args:
            z (list): list of latent representations for each view of type torch.Tensor.

        Returns:
            x_recon (list): a single-element list whose only entry is the list of
            per-view decoder outputs. ``loss_function`` reads these as
            ``x_recon[0][i]`` and calls ``.log_likelihood`` on them.
        """
        return [[self.decoders[i](z[i]) for i in range(self.n_views)]]

    def cca(self, zs):
        r"""CCA loss calculation.

        Adapted from:
        https://github.com/jameschapman19/cca_zoo/blob/main/cca_zoo/deep/_discriminative/_dmcca.py

        Args:
            zs (list): list of latent representations, one torch.Tensor per view.
                Rows are treated as samples and columns as latent dimensions
                (that is how ``z.mean(dim=0)`` and ``torch.cov(z.T)`` use them).

        Returns:
            cca_loss (torch.Tensor): negated sum of the positive values among the
            top ``z_dim`` eigenvalues of the whitened cross-covariance matrix.
        """
        # Centre each view's representations across samples.
        zs = [z - z.mean(dim=0) for z in zs]
        n_views = len(zs)
        all_views = torch.cat(zs, dim=1)

        # Cross-covariance: covariance of the concatenated views with the
        # within-view (block-diagonal) covariances removed.
        C = torch.cov(all_views.T)
        C = C - torch.block_diag(*[torch.cov(z.T) for z in zs])
        C = C / n_views

        # Block-diagonal within-view covariance, shrunk towards the identity by EPS.
        D = torch.block_diag(
            *[
                (1 - EPS) * torch.cov(z.T)
                + EPS * torch.eye(z.shape[1], device=z.device)
                for z in zs
            ]
        )
        D = D / n_views
        C = C + D

        # Whitening matrix R = U diag(S^-1/2) Vh; for a symmetric positive-definite
        # D this is D^{-1/2}. torch.linalg.svd replaces the deprecated torch.svd
        # (which returned V rather than Vh, hence the former V.transpose).
        U, S, Vh = torch.linalg.svd(D)
        # Enforce positive definiteness by clamping singular values at EPS.
        S = torch.clamp(S, min=EPS)
        inv_sqrt_S = torch.diag_embed(torch.pow(S, -0.5))
        R = U @ inv_sqrt_S @ Vh

        C_whitened = R @ C @ R.T
        eigvals = torch.linalg.eigvalsh(C_whitened)

        # Keep the z_dim largest eigenvalues.
        idx = torch.argsort(eigvals, descending=True)
        eigvals = eigvals[idx[: self.z_dim]]
        # NOTE(review): only strictly positive eigenvalues reach LeakyReLU, so it
        # acts as the identity here; kept for parity with the cca_zoo original.
        eigvals = torch.nn.LeakyReLU()(eigvals[torch.gt(eigvals, 0)])
        corr = eigvals.sum()
        return -corr

    def forward(self, x):
        r"""Apply encode and decode methods to input data to generate latent
        representations and data reconstructions.

        Args:
            x (list): list of input data of type torch.Tensor.

        Returns:
            fwd_rtn (dict): dictionary containing list of data reconstructions
            (``x_recon``, as returned by ``decode``) and latent representations (``z``).
        """
        z = self.encode(x)
        x_recon = self.decode(z)
        return {"x_recon": x_recon, "z": z}

    def loss_function(self, x, fwd_rtn):
        r"""Calculate the DCCAE training loss.

        The total loss is ``_lambda * recon_loss + cca_loss``.

        Args:
            x (list): list of input data of type torch.Tensor.
            fwd_rtn (dict): dictionary containing list of data reconstructions
                (``x_recon``) and latent representations (``z``).

        Returns:
            losses (dict): dictionary with keys ``"loss"`` (weighted total),
            ``"recon_loss"`` (summed negative log-likelihood over views) and
            ``"cca_loss"`` (output of ``cca``).
        """
        x_recon = fwd_rtn["x_recon"]
        z = fwd_rtn["z"]

        # Negative log-likelihood of each view under its reconstruction,
        # averaged over dim 0 and summed over the remaining dims.
        recon = 0
        for i in range(self.n_views):
            # first index is latent, second index is view
            recon += -x_recon[0][i].log_likelihood(x[i]).mean(0).sum()

        cca_loss = self.cca(z)
        total_loss = self._lambda * recon + cca_loss
        return {"loss": total_loss, "recon_loss": recon, "cca_loss": cca_loss}