Source code for multiviewae.base.dataloaders

import random
import numpy as np
import pytorch_lightning as pl
import hydra
from torch.utils.data import DataLoader

class MultiviewDataModule(pl.LightningDataModule):
    """LightningDataModule for multi-view data.

    Args:
        n_views (int): Number of views in the data.
        batch_size (int): Batch size. If not an ``int`` (e.g. ``None``), the
            whole dataset (``data[0].shape[0]`` samples) is used as one batch.
        is_validate (bool): Whether to hold out a validation set.
        train_size (float): Proportion of the samples to use for training,
            between 0 and 1. The remainder is used for validation when
            ``is_validate`` is True.
        dataset: Config passed to ``hydra.utils.instantiate`` together with
            ``data``, ``labels`` and ``n_views`` to build each dataset.
        data (list): Input data. list of torch.Tensors, one per view, indexed
            by sample along the first dimension.
        labels (np.array): Dataset labels, or ``None``.
    """

    def __init__(
        self, n_views, batch_size, is_validate, train_size, dataset, data, labels
    ):
        super().__init__()
        self.n_views = n_views
        self.batch_size = batch_size
        self.is_validate = is_validate
        self.train_size = train_size
        self.data = data
        self.labels = labels
        self.dataset = dataset
        # Set once setup() has built the datasets and released self.data.
        self._datasets_built = False
        # A non-int batch size (e.g. None from the config) means full-batch training.
        if not isinstance(self.batch_size, int):
            self.batch_size = self.data[0].shape[0]

    def setup(self, stage=None):
        """Build the training dataset and, if ``is_validate``, the validation dataset.

        The raw ``self.data`` is deleted afterwards to free memory. Repeated calls
        (Lightning calls ``setup`` for every fit/validate/test run) therefore reuse
        the datasets built the first time instead of failing on the missing
        attribute.

        Args:
            stage (str, optional): Lightning stage name; not used.
        """
        if getattr(self, "_datasets_built", False):
            return
        if self.is_validate:
            train_data, test_data, train_labels, test_labels = self.train_test_split()
            self.train_dataset = hydra.utils.instantiate(
                self.dataset, data=train_data, labels=train_labels, n_views=self.n_views
            )
            self.test_dataset = hydra.utils.instantiate(
                self.dataset, data=test_data, labels=test_labels, n_views=self.n_views
            )
        else:
            self.train_dataset = hydra.utils.instantiate(
                self.dataset, data=self.data, labels=self.labels, n_views=self.n_views
            )
            self.test_dataset = None
        del self.data
        self._datasets_built = True

    def train_test_split(self):
        """Randomly split the samples into a training and a validation subset.

        ``int(N * train_size)`` sample indices are drawn with ``random.sample``.
        All remaining indices, in ascending order, form the validation subset.

        Returns:
            tuple: ``(train_data, test_data, train_labels, test_labels)``. The data
            entries are lists with one indexed view per entry of ``self.data``.
            The label entries are ``None`` when ``self.labels`` is ``None``.
        """
        N = self.data[0].shape[0]
        train_idx = list(random.sample(range(N), int(N * self.train_size)))
        # Sorted complement of train_idx.
        test_idx = np.setdiff1d(list(range(N)), train_idx)
        # Index only the sample (first) axis so views of any rank, including 1-D, work.
        train_data = [dt[train_idx] for dt in self.data]
        test_data = [dt[test_idx] for dt in self.data]
        train_labels = None
        test_labels = None
        if self.labels is not None:
            train_labels = self.labels[train_idx]
            test_labels = self.labels[test_idx]
        return train_data, test_data, train_labels, test_labels

    def train_dataloader(self):
        """Return the DataLoader over the training dataset."""
        # NOTE(review): training batches are not shuffled (shuffle=False) — confirm intended.
        # num_workers=0 (load in the main process) avoids multiprocessing problems on Windows.
        return DataLoader(
            self.train_dataset, batch_size=self.batch_size, shuffle=False, num_workers=0
        )

    def val_dataloader(self):
        """Return the validation DataLoader, or ``None`` when ``is_validate`` is False."""
        if self.is_validate:
            return DataLoader(
                self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=0
            )
        return None
class IndexDataModule(MultiviewDataModule):
    """LightningDataModule for multi-view data given as sample identifiers.

    Args:
        n_views (int): Number of views in the data.
        batch_size (int): Batch size. If not an ``int`` (e.g. ``None``), the
            number of identifiers is used, i.e. full-batch training.
        is_validate (bool): Whether to hold out a validation set.
        train_size (float): Proportion of the samples to use for training,
            between 0 and 1. The remainder is used for validation when
            ``is_validate`` is True.
        dataset: Config passed to ``hydra.utils.instantiate`` (by the inherited
            ``setup``) to build each dataset.
        data (list): Input data. Only ``data[0]`` is used: the list of
            identifiers to load data from.
        labels (np.array): Dataset labels, or ``None``.
    """

    def __init__(
        self,
        n_views,
        batch_size,
        is_validate,
        train_size,
        dataset,
        data,
        labels,
    ):
        # Keep only the identifier list; it becomes self.data in the parent.
        data_ = data[0]
        # Resolve the full-batch default here: the parent would read
        # self.data[0].shape[0], which an identifier does not have.
        if not isinstance(batch_size, int):
            batch_size = len(data_)
        super().__init__(
            n_views=n_views,
            batch_size=batch_size,
            is_validate=is_validate,
            train_size=train_size,
            dataset=dataset,
            data=data_,
            labels=labels,
        )

    def train_test_split(self):
        """Randomly split the identifiers into a training and a validation subset.

        ``int(N * train_size)`` positions are drawn with ``random.sample``. All
        remaining positions, in ascending order (as in the parent class), form
        the validation subset.

        Returns:
            tuple: ``([train_ids], [test_ids], train_labels, test_labels)``. Each
            identifier list is wrapped in an outer list. The label entries are
            ``None`` when ``self.labels`` is ``None``.
        """
        # NOTE(review): the split data are wrapped as [ids], but the inherited
        # setup() passes the unwrapped self.data when is_validate is False —
        # confirm which form the dataset class expects.
        N = len(self.data)
        train_idx = list(random.sample(range(N), int(N * self.train_size)))
        # Ascending complement of train_idx; a set difference would yield an
        # order that depends on hash-table layout.
        train_set = set(train_idx)
        test_idx = [i for i in range(N) if i not in train_set]
        data = self.data
        train_data = [data[i] for i in train_idx]
        test_data = [data[i] for i in test_idx]
        train_labels = None
        test_labels = None
        if self.labels is not None:
            train_labels = self.labels[train_idx]
            test_labels = self.labels[test_idx]
        return [train_data], [test_data], train_labels, test_labels