import ast
import torch
import hydra
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
class Encoder(nn.Module):
    """MLP Encoder.

    A stack of fully connected layers mapping ``input_dim`` features to
    ``z_dim`` outputs, with an optional ReLU between consecutive layers
    (never after the final layer).

    Args:
        input_dim (int): Dimensionality of the input data.
        z_dim (int): Number of latent dimensions.
        hidden_layer_dim (Sequence[int]): Number of nodes per hidden layer.
            Any sequence of ints (list, tuple, ...) is accepted; it is copied
            into a list and stored as ``self.hidden_layer_dim``.
        non_linear (bool): Whether to include a ReLU() function between layers.
        bias (bool): Whether to include a bias term in every linear layer.
        enc_dist (multiviewae.base.distributions.Default): Encoder distribution.
            Stored on the instance; not used by this class's ``forward``.
    """

    def __init__(
        self,
        input_dim,
        z_dim,
        hidden_layer_dim,
        non_linear,
        bias,
        enc_dist
    ):
        super().__init__()
        self.input_size = input_dim
        self.z_dim = z_dim
        # Copy into a list so tuples / other sequences can be concatenated
        # below (and by subclasses that concatenate ``self.hidden_layer_dim``).
        self.hidden_layer_dim = list(hidden_layer_dim)
        self.bias = bias
        self.enc_dist = enc_dist
        self.non_linear = non_linear
        self.layer_sizes = [input_dim] + self.hidden_layer_dim + [z_dim]
        lin_layers = [
            nn.Linear(dim0, dim1, bias=self.bias)
            for dim0, dim1 in zip(self.layer_sizes[:-1], self.layer_sizes[1:])
        ]
        self.encoder_layers = nn.Sequential(*lin_layers)

    def forward(self, x):
        """Encode ``x``.

        Args:
            x (torch.Tensor): Input whose last dimension is ``input_dim``.

        Returns:
            torch.Tensor: Output of the final linear layer (last dimension
            ``z_dim``). No activation is applied to it.
        """
        h1 = x
        for layer in self.encoder_layers[:-1]:
            h1 = layer(h1)
            if self.non_linear:
                h1 = F.relu(h1)
        return self.encoder_layers[-1](h1)
class VariationalEncoder(Encoder):
    """Variational MLP Encoder.

    Reuses the hidden layers built by :class:`Encoder` but replaces the final
    linear layer with separate mean / log-variance heads.

    Args:
        input_dim (int): Dimensionality of the input data.
        z_dim (int): Number of latent dimensions.
        hidden_layer_dim (list): Number of nodes per hidden layer.
        non_linear (bool): Whether to include a ReLU() function between layers.
            Applied after every hidden layer, never after the output heads.
        bias (bool): Whether to include a bias term in every linear layer.
        sparse (bool): Whether to enforce sparsity of the encoding distribution.
            When True, no log-variance layer is built and the log-variance is
            derived from the mean as ``log_alpha + 2 * log(|mu| + 1e-8)``.
        log_alpha (float): Log of the dropout parameter. Only used when
            ``sparse`` is True.
        enc_dist (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal):
            Encoder distribution.
        multiple_latents (bool, optional): Whether the model uses separate
            linear layers for shared and private latent spaces. The shared
            heads then output ``u_dim`` features; the private heads (``w_dim``
            features) are only built when ``sparse`` is False.
        u_dim (int, optional): Dimensionality of the shared latent space.
        w_dim (int, optional): Dimensionality of the private latent space.

    Raises:
        ValueError: If ``multiple_latents`` is True and ``u_dim`` or ``w_dim``
            is missing, or ``u_dim + w_dim != z_dim``.
    """

    def __init__(
        self,
        input_dim,
        z_dim,
        hidden_layer_dim,
        non_linear,
        bias,
        sparse,
        log_alpha,
        enc_dist,
        multiple_latents=False,
        u_dim=None,
        w_dim=None
    ):
        super().__init__(input_dim=input_dim,
                         z_dim=z_dim,
                         hidden_layer_dim=hidden_layer_dim,
                         bias=bias,
                         non_linear=non_linear,
                         enc_dist=enc_dist)
        self.sparse = sparse
        self.log_alpha = log_alpha
        self.multiple_latents = multiple_latents
        self.u_dim = u_dim
        self.w_dim = w_dim
        if self.multiple_latents:
            # Raise instead of assert so validation survives ``python -O``.
            if self.u_dim is None or self.w_dim is None:
                raise ValueError("u_dim and w_dim must be specified for multiple_latents=True")
            if self.u_dim + self.w_dim != self.z_dim:
                raise ValueError("u_dim + w_dim must equal z_dim")
            # The shared heads output u_dim features instead of z_dim.
            self.layer_sizes[-1] = self.u_dim
        # Drop the final linear layer built by Encoder; the heads below
        # replace it. (The layer is still created by Encoder.__init__, so
        # seeded parameter initialization is unchanged.)
        self.encoder_layers = self.encoder_layers[:-1]
        self.enc_mean_layer = nn.Linear(
            self.layer_sizes[-2],
            self.layer_sizes[-1],
            bias=self.bias,
        )
        if not self.sparse:
            self.enc_logvar_layer = nn.Linear(
                self.layer_sizes[-2],
                self.layer_sizes[-1],
                bias=self.bias,
            )
            if self.multiple_latents:
                self.enc_mean_layer_private = nn.Linear(
                    self.layer_sizes[-2],
                    self.w_dim,
                    bias=self.bias,
                )
                self.enc_logvar_layer_private = nn.Linear(
                    self.layer_sizes[-2],
                    self.w_dim,
                    bias=self.bias,
                )

    def forward(self, x):
        """Encode ``x`` into distribution parameters.

        Args:
            x (torch.Tensor): Input whose last dimension is ``input_dim``.

        Returns:
            tuple: ``(mu, logvar)``. When ``sparse`` is False and
            ``multiple_latents`` is True, returns
            ``(mu, logvar, mu_private, logvar_private)`` instead.
        """
        h1 = x
        for layer in self.encoder_layers:
            h1 = layer(h1)
            if self.non_linear:
                h1 = F.relu(h1)
        mu = self.enc_mean_layer(h1)
        if self.sparse:
            # Log-variance tied to the mean; 1e-8 guards against log(0).
            logvar = self.log_alpha + 2 * torch.log(torch.abs(mu) + 1e-8)
            return mu, logvar
        logvar = self.enc_logvar_layer(h1)
        if self.multiple_latents:
            mu_private = self.enc_mean_layer_private(h1)
            logvar_private = self.enc_logvar_layer_private(h1)
            return mu, logvar, mu_private, logvar_private
        return mu, logvar
class ConditionalVariationalEncoder(Encoder):
    """MLP Variational Conditional Encoder.

    Like a variational encoder, but the input is concatenated with a label
    vector (set via :meth:`set_labels`) before the hidden layers.

    Args:
        input_dim (int): Dimensionality of the input data.
        z_dim (int): Number of latent dimensions.
        hidden_layer_dim (list): Number of nodes per hidden layer.
        non_linear (bool): Whether to include a ReLU() function between layers.
        bias (bool): Whether to include a bias term in every linear layer.
        sparse (bool): Whether to enforce sparsity of the encoding distribution.
            When True, the log-variance is derived from the mean as
            ``log_alpha + 2 * log(|mu| + 1e-8)``.
        log_alpha (float): Log of the dropout parameter (used when ``sparse``).
        enc_dist (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal):
            Encoder distribution.
        num_cat (int): Number of categories of the labels. The first layer's
            input size is ``input_dim + num_cat``.
        one_hot (bool): Whether to one-hot encode the labels. If False, the
            labels are concatenated to the input as given.
        multiple_latents (bool, optional): Whether the model uses separate
            linear layers for shared and private latent spaces.
        u_dim (int, optional): Dimensionality of the shared latent space.
        w_dim (int, optional): Dimensionality of the private latent space.

    Raises:
        ValueError: If ``multiple_latents`` is True and ``u_dim`` or ``w_dim``
            is missing, or ``u_dim + w_dim != z_dim``.
    """

    def __init__(
        self,
        input_dim,
        z_dim,
        hidden_layer_dim,
        non_linear,
        bias,
        sparse,
        log_alpha,
        enc_dist,
        num_cat,
        one_hot,
        multiple_latents=False,
        u_dim=None,
        w_dim=None
    ):
        super().__init__(input_dim=input_dim,
                         z_dim=z_dim,
                         hidden_layer_dim=hidden_layer_dim,
                         bias=bias,
                         non_linear=non_linear,
                         enc_dist=enc_dist)
        self.num_cat = num_cat
        self.sparse = sparse
        self.log_alpha = log_alpha
        self.one_hot = one_hot
        self.multiple_latents = multiple_latents
        self.u_dim = u_dim
        self.w_dim = w_dim
        # The first layer also receives the num_cat-wide label vector.
        self.layer_sizes = [input_dim + num_cat] + list(self.hidden_layer_dim) + [z_dim]
        if self.multiple_latents:
            # Raise instead of assert so validation survives ``python -O``.
            if self.u_dim is None or self.w_dim is None:
                raise ValueError("u_dim and w_dim must be specified for multiple_latents=True")
            if self.u_dim + self.w_dim != self.z_dim:
                raise ValueError("u_dim + w_dim must equal z_dim")
            self.layer_sizes[-1] = self.u_dim
        # Rebuild the layer stack for the widened input, then drop the final
        # layer (replaced by the heads below). Creation order is kept as-is so
        # seeded parameter initialization is unchanged.
        lin_layers = [
            nn.Linear(dim0, dim1, bias=self.bias)
            for dim0, dim1 in zip(self.layer_sizes[:-1], self.layer_sizes[1:])
        ]
        self.encoder_layers = nn.Sequential(*lin_layers)[:-1]
        self.enc_mean_layer = nn.Linear(
            self.layer_sizes[-2],
            self.layer_sizes[-1],
            bias=self.bias,
        )
        if not self.sparse:
            self.enc_logvar_layer = nn.Linear(
                self.layer_sizes[-2],
                self.layer_sizes[-1],
                bias=self.bias,
            )
            if self.multiple_latents:
                self.enc_mean_layer_private = nn.Linear(
                    self.layer_sizes[-2],
                    self.w_dim,
                    bias=self.bias,
                )
                self.enc_logvar_layer_private = nn.Linear(
                    self.layer_sizes[-2],
                    self.w_dim,
                    bias=self.bias,
                )

    def set_labels(self, labels):
        """Store the labels used to condition the next ``forward`` calls."""
        self.labels = labels

    def forward(self, x):
        """Encode ``x`` conditioned on the stored labels.

        ``set_labels`` must have been called first.

        Args:
            x (torch.Tensor): Input; concatenated with the label tensor via
                ``torch.hstack`` (assumes matching batch size — confirm with caller).

        Returns:
            tuple: ``(mu, logvar)``. When ``sparse`` is False and
            ``multiple_latents`` is True, returns
            ``(mu, logvar, mu_private, logvar_private)`` instead.
        """
        if self.one_hot:
            c = F.one_hot(self.labels.long(), self.num_cat)
        else:
            c = self.labels
        # Cast to x's dtype so the concatenated input keeps x's dtype. One-hot
        # output is int64, and user labels may be float64; the latter would
        # otherwise promote the input and clash with float32 Linear weights.
        h1 = torch.hstack((x, c.to(dtype=x.dtype)))
        for layer in self.encoder_layers:
            h1 = layer(h1)
            if self.non_linear:
                h1 = F.relu(h1)
        mu = self.enc_mean_layer(h1)
        if self.sparse:
            # Log-variance tied to the mean; 1e-8 guards against log(0).
            logvar = self.log_alpha + 2 * torch.log(torch.abs(mu) + 1e-8)
            return mu, logvar
        logvar = self.enc_logvar_layer(h1)
        if self.multiple_latents:
            mu_private = self.enc_mean_layer_private(h1)
            logvar_private = self.enc_logvar_layer_private(h1)
            return mu, logvar, mu_private, logvar_private
        return mu, logvar
class Decoder(nn.Module):
    """MLP Decoder.

    A stack of fully connected layers mapping ``z_dim`` latents to
    ``input_dim`` outputs, with an optional ReLU between consecutive layers.
    The final layer's output is passed to ``dec_dist`` via hydra.

    Args:
        input_dim (int): Dimensionality of the input data (the output size).
        z_dim (int): Number of latent dimensions.
        hidden_layer_dim (Sequence[int]): Number of nodes per hidden layer,
            used in the order given, from the latent side to the output side.
            This class does not reverse it: to mirror an encoder built with
            [100, 50, 5], pass [5, 50, 100].
        non_linear (bool): Whether to include a ReLU() function between layers.
        bias (bool): Whether to include a bias term in every linear layer.
        dec_dist (multiviewae.base.distributions.Default, multiviewae.base.distributions.Bernoulli):
            Decoder distribution config, instantiated in ``forward``.
        init_logvar (int, float): Initial value for log variance of decoder.
            Unused in Decoder class.
    """

    def __init__(
        self,
        input_dim,
        z_dim,
        hidden_layer_dim,
        non_linear,
        bias,
        dec_dist,
        init_logvar=None
    ):
        super().__init__()
        self.input_size = input_dim
        self.z_dim = z_dim
        # Copy into a list so tuples / other sequences can be concatenated
        # below (and by subclasses that concatenate ``self.hidden_layer_dim``).
        self.hidden_layer_dim = list(hidden_layer_dim)
        self.bias = bias
        self.dec_dist = dec_dist
        self.non_linear = non_linear
        self.layer_sizes = [z_dim] + self.hidden_layer_dim + [input_dim]
        lin_layers = [
            nn.Linear(dim0, dim1, bias=self.bias)
            for dim0, dim1 in zip(self.layer_sizes[:-1], self.layer_sizes[1:])
        ]
        self.decoder_layers = nn.Sequential(*lin_layers)

    def forward(self, z):
        """Decode ``z``.

        Args:
            z (torch.Tensor): Latent tensor whose last dimension is ``z_dim``.

        Returns:
            The object built by ``hydra.utils.instantiate(self.dec_dist, x=...)``
            from the final layer's output (presumably a distribution — depends
            on the ``dec_dist`` config).
        """
        x_rec = z
        for layer in self.decoder_layers[:-1]:
            x_rec = layer(x_rec)
            if self.non_linear:
                x_rec = F.relu(x_rec)
        x_rec = self.decoder_layers[-1](x_rec)
        return hydra.utils.instantiate(self.dec_dist, x=x_rec)
class VariationalDecoder(Decoder):
    """MLP Variational Decoder.

    Replaces the final linear layer of :class:`Decoder` with a mean head and
    adds a learnable per-feature log-variance.

    Args:
        input_dim (int): Dimensionality of the input data (the output size).
        z_dim (int): Number of latent dimensions.
        hidden_layer_dim (list): Number of nodes per hidden layer, used in the
            order given (latent side first); this class does not reverse it.
        bias (bool): Whether to include a bias term in every linear layer.
        non_linear (bool): Whether to include a ReLU() function between layers.
        init_logvar (int, float): Initial value for log variance of decoder.
        dec_dist (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal):
            Decoder distribution config, instantiated with ``loc``/``scale``.
    """

    def __init__(
        self,
        input_dim,
        z_dim,
        hidden_layer_dim,
        bias,
        non_linear,
        init_logvar,
        dec_dist
    ):
        super().__init__(input_dim=input_dim,
                         z_dim=z_dim,
                         hidden_layer_dim=hidden_layer_dim,
                         non_linear=non_linear,
                         bias=bias, dec_dist=dec_dist)
        self.init_logvar = init_logvar
        # Drop the final linear layer built by Decoder; dec_mean_layer
        # replaces it. (It is still created, so seeded init is unchanged.)
        self.decoder_layers = self.decoder_layers[:-1]
        self.dec_mean_layer = nn.Linear(
            self.layer_sizes[-2],
            self.layer_sizes[-1],
            bias=self.bias,
        )
        # One learnable float32 log-variance per output feature, shape
        # (1, input_dim). torch.full replaces the legacy FloatTensor().fill_().
        self.logvar_out = Parameter(
            data=torch.full((1, self.input_size), self.init_logvar, dtype=torch.float32),
            requires_grad=True,
        )

    def forward(self, z):
        """Decode ``z`` into the decoder distribution.

        Args:
            z (torch.Tensor): Latent tensor whose last dimension is ``z_dim``.

        Returns:
            The object built by ``hydra.utils.instantiate(self.dec_dist,
            loc=mean, scale=exp(0.5 * logvar_out))``.
        """
        x_rec = z
        for layer in self.decoder_layers:
            x_rec = layer(x_rec)
            if self.non_linear:
                x_rec = F.relu(x_rec)
        loc = self.dec_mean_layer(x_rec)
        # Standard deviation from log-variance: exp(logvar / 2).
        scale = torch.exp(0.5 * self.logvar_out)
        return hydra.utils.instantiate(self.dec_dist, loc=loc, scale=scale)
class ConditionalVariationalDecoder(Decoder):
    """MLP Conditional Variational Decoder.

    Like a variational decoder, but the latent is concatenated with a label
    vector (set via :meth:`set_labels`) before the hidden layers.

    Args:
        input_dim (int): Dimensionality of the input data (the output size).
        z_dim (int): Number of latent dimensions.
        hidden_layer_dim (list): Number of nodes per hidden layer, used in the
            order given (latent side first); this class does not reverse it.
        bias (bool): Whether to include a bias term in every linear layer.
        non_linear (bool): Whether to include a ReLU() function between layers.
        init_logvar (int, float): Initial value for log variance of decoder.
        dec_dist (multiviewae.base.distributions.Normal, multiviewae.base.distributions.MultivariateNormal):
            Decoder distribution config, instantiated with ``loc``/``scale``.
        num_cat (int): Number of categories of the labels. The first layer's
            input size is ``z_dim + num_cat``.
        one_hot (bool): Whether to one-hot encode the labels. If False, the
            labels are concatenated to the latent as given.
    """

    def __init__(
        self,
        input_dim,
        z_dim,
        hidden_layer_dim,
        bias,
        non_linear,
        init_logvar,
        dec_dist,
        num_cat,
        one_hot
    ):
        super().__init__(input_dim=input_dim,
                         z_dim=z_dim,
                         hidden_layer_dim=hidden_layer_dim,
                         non_linear=non_linear,
                         bias=bias, dec_dist=dec_dist)
        self.num_cat = num_cat
        self.init_logvar = init_logvar
        self.one_hot = one_hot
        # The first layer also receives the num_cat-wide label vector.
        self.layer_sizes = [z_dim + num_cat] + list(self.hidden_layer_dim) + [input_dim]
        # Rebuild for the widened input, then drop the final layer (replaced
        # by dec_mean_layer). Creation order kept so seeded init is unchanged.
        lin_layers = [
            nn.Linear(dim0, dim1, bias=self.bias)
            for dim0, dim1 in zip(self.layer_sizes[:-1], self.layer_sizes[1:])
        ]
        self.decoder_layers = nn.Sequential(*lin_layers)[:-1]
        self.dec_mean_layer = nn.Linear(
            self.layer_sizes[-2],
            self.layer_sizes[-1],
            bias=self.bias,
        )
        # One learnable float32 log-variance per output feature, shape
        # (1, input_dim). torch.full replaces the legacy FloatTensor().fill_().
        self.logvar_out = Parameter(
            data=torch.full((1, self.input_size), self.init_logvar, dtype=torch.float32),
            requires_grad=True,
        )

    def set_labels(self, labels):
        """Store the labels used to condition the next ``forward`` calls."""
        self.labels = labels

    def forward(self, z):
        """Decode ``z`` conditioned on the stored labels.

        ``set_labels`` must have been called first.

        Args:
            z (torch.Tensor): Latent tensor. If ``z`` is 3-D and the label
                tensor is 2-D, the labels are tiled along ``z``'s first
                dimension and concatenated on dim 2; otherwise ``torch.hstack``
                is used.

        Returns:
            The object built by ``hydra.utils.instantiate(self.dec_dist,
            loc=mean, scale=exp(0.5 * logvar_out))``.
        """
        if self.one_hot:
            c = F.one_hot(self.labels.long(), self.num_cat)
        else:
            c = self.labels
        # Cast to z's dtype so the concatenated input keeps z's dtype. One-hot
        # output is int64, and user labels may be float64; the latter would
        # otherwise promote the input and clash with float32 Linear weights.
        c = c.to(dtype=z.dtype)
        if z.dim() == 3 and c.dim() == 2:
            # NOTE: for mmvae which uses rsample() instead of sample(); z then
            # presumably has a leading sample dimension (K, batch, z_dim), so
            # the labels are tiled K times along a new dim 0.
            c = c.unsqueeze(0).expand(z.size(0), -1, -1)
            z_cond = torch.cat((z, c), dim=2)
        else:
            z_cond = torch.hstack((z, c))
        x_rec = z_cond
        for layer in self.decoder_layers:
            x_rec = layer(x_rec)
            if self.non_linear:
                x_rec = F.relu(x_rec)
        loc = self.dec_mean_layer(x_rec)
        # Standard deviation from log-variance: exp(logvar / 2).
        scale = torch.exp(0.5 * self.logvar_out)
        return hydra.utils.instantiate(self.dec_dist, loc=loc, scale=scale)
class Discriminator(nn.Module):
    """MLP Discriminator.

    Args:
        input_dim (int): Dimensionality of the input data.
        output_dim (int): Number of output dimensions.
        hidden_layer_dim (Sequence[int]): Number of nodes per hidden layer.
        non_linear (bool): Whether to include a ReLU() function between layers.
        bias (bool): Whether to include a bias term in every linear layer.
        dropout_threshold (float): Dropout probability applied (in training
            mode) to the output of every linear layer, including the last.
        is_wasserstein (bool): Whether model employs a wasserstein loss. If
            True, ``forward`` returns the raw last-layer output; otherwise it
            applies softmax (``output_dim > 1``) or sigmoid (``output_dim == 1``).
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        hidden_layer_dim,
        non_linear,
        bias,
        dropout_threshold,
        is_wasserstein
    ):
        super().__init__()
        self.bias = bias
        self.non_linear = non_linear
        self.dropout_threshold = dropout_threshold
        self.is_wasserstein = is_wasserstein
        # list() so tuples / other sequences are accepted for hidden_layer_dim.
        self.layer_sizes = [input_dim] + list(hidden_layer_dim) + [output_dim]
        lin_layers = [
            nn.Linear(dim0, dim1, bias=self.bias)
            for dim0, dim1 in zip(self.layer_sizes[:-1], self.layer_sizes[1:])
        ]
        self.linear_layers = nn.Sequential(*lin_layers)

    def forward(self, x):
        """Score ``x``.

        Args:
            x (torch.Tensor): Input whose last dimension is ``input_dim``.

        Returns:
            torch.Tensor: Raw scores if ``is_wasserstein``; otherwise softmax
            probabilities over the last dim (``output_dim > 1``) or sigmoid
            probabilities (``output_dim == 1``).
        """
        last = len(self.linear_layers) - 1
        for it_layer, layer in enumerate(self.linear_layers):
            # NOTE(review): dropout is also applied to the final layer's
            # output (before softmax/sigmoid) during training — confirm this
            # is intended.
            x = F.dropout(layer(x), self.dropout_threshold, training=self.training)
            if it_layer < last and self.non_linear:
                x = F.relu(x)
        # There is always at least one linear layer, so the output activation
        # below is applied exactly once, as in the original in-loop version.
        if self.is_wasserstein:
            return x
        if self.layer_sizes[-1] > 1:
            return F.softmax(x, dim=-1)
        return torch.sigmoid(x)