Source code for multiviewae.base.datasets

"""
Classes for loading multi-view data into PyTorch float tensors

From: https://gitlab.com/acasamitjana/latentmodels_ad

"""
import numpy as np
import torch
from torch.utils.data import Dataset
from os.path import join

class MVDataset(Dataset):
    """PyTorch Dataset for storing and accessing multi-view data.

    Args:
        data (list): Input data, one entry per view. Entries that are
            ``np.ndarray`` are converted to float tensors; any other entry
            (e.g. a ``torch.Tensor``) is stored as given.
        n_views (int): Number of views. Stored on the instance as
            ``n_views``.
        labels (np.array, torch.Tensor or sequence): Dataset labels.
            Converted with ``torch.as_tensor``. Default is None.
        return_index (bool): Whether to return the sample index together
            with the data.
        transform (torchvision.transforms): Transformation to apply to the
            data. It is called with the list of per-view samples.
            Default is None.
    """

    def __init__(self, data, n_views, labels=None, return_index=False, transform=None):
        self.n_views = n_views
        self.return_index = return_index
        self.transform = transform

        # The dataset length is taken from the first view.
        self.N = len(data[0])

        self.data = [
            torch.from_numpy(d).float() if isinstance(d, np.ndarray) else d
            for d in data
        ]
        self.shape = [np.shape(d) for d in self.data]

        # torch.as_tensor keeps the torch.from_numpy behaviour for ndarrays
        # but also accepts labels that already are tensors (or plain
        # sequences), which torch.from_numpy rejects with a TypeError.
        self.labels = None if labels is None else torch.as_tensor(labels)

    def __getitem__(self, index):
        """Return the sample at ``index``.

        Args:
            index (int): Index

        Returns:
            The list of per-view samples ``x`` (after ``transform``, if one
            is set), optionally followed by the label and/or the index:
            ``x``, ``(x, label)``, ``(x, index)`` or ``(x, label, index)``
            depending on ``labels`` and ``return_index``.
        """
        x = [d[index] for d in self.data]
        if self.transform:
            x = self.transform(x)

        if self.return_index:
            if self.labels is not None:
                return x, self.labels[index], index
            return x, index

        if self.labels is not None:
            return x, self.labels[index]
        return x

    def __len__(self):
        """Return the number of samples (length of the first view)."""
        return self.N
class IndexMVDataset(Dataset):
    """PyTorch Dataset that loads each view of a sample from disk on access.

    Args:
        data (list): Only ``data[0]`` is used; it is indexed per sample and
            the resulting value is substituted into ``filename``
            (presumably a collection of sample identifiers - confirm
            against the caller).
        n_views (int): Number of views to load for each sample.
        labels: Stored on the instance as ``labels``; not returned by
            ``__getitem__``. Default is None.
        filename (str): Format string called as
            ``filename.format(view_index, sample_id)``.
        data_dir (str): Directory that is joined with the formatted
            filename.
    """

    def __init__(self, data, n_views, labels=None, filename="", data_dir=''):
        sample_ids = data[0]
        self.N = len(sample_ids)
        self.data = sample_ids
        self.labels = labels
        self.n_views = n_views
        self.filename = filename
        self.data_dir = data_dir

    def __getitem__(self, index):
        """Load every view of one sample.

        Args:
            index (int): Index

        Returns:
            list: one float torch.Tensor per view, each read with
            ``np.load``.
        """
        sample_id = self.data[index]
        view_paths = (
            join(self.data_dir, self.filename.format(view, sample_id))
            for view in range(self.n_views)
        )
        return [torch.from_numpy(np.load(path)).float() for path in view_paths]

    def __len__(self):
        """Return the number of samples (length of ``data[0]``)."""
        return self.N