from abc import ABC

import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader, Subset

from PytorchBoot.component import Component


class BaseDataset(ABC, Dataset, Component):
    """Base class for datasets that can build their own ``DataLoader``.

    The ``config`` mapping passed to the constructor is read for the keys
    ``"ratio"``, ``"name"``, ``"batch_size"`` and ``"num_workers"`` (see
    :meth:`get_loader`).
    """

    def __init__(self, config):
        """Store ``config`` on the instance after initialising the bases."""
        super().__init__()
        self.config = config

    @staticmethod
    def process_batch(batch, device):
        """Move every movable entry of ``batch`` onto ``device`` in place.

        Args:
            batch: Mutable mapping of batch entries. Entries that are lists,
                or that do not expose a ``.to`` method (e.g. strings, ints,
                ``None``), are left untouched.
            device: Target passed verbatim to each entry's ``.to`` method.

        Returns:
            The same ``batch`` object, with movable entries replaced by the
            result of ``entry.to(device)``.
        """
        # Only existing keys are reassigned, so the mapping's size does not
        # change during iteration.
        for key, value in batch.items():
            if isinstance(value, list) or not hasattr(value, "to"):
                continue
            batch[key] = value.to(device)
        return batch

    def get_collate_fn(self):
        """Return the collate function for the loader.

        ``None`` makes ``DataLoader`` fall back to its default collation;
        subclasses may override this to supply their own.
        """
        return None

    def get_loader(self, shuffle=False):
        """Build a ``DataLoader`` over a random subset of this dataset.

        A random permutation of all indices is drawn with NumPy's global
        random state and truncated to ``max(1, int(len(self) * ratio))``
        entries. Note that the permutation is applied even when ``ratio`` is
        1 and ``shuffle`` is False, so sample order is randomised either way.

        Args:
            shuffle: Forwarded to ``DataLoader(shuffle=...)``.

        Returns:
            A ``DataLoader`` wrapping a ``Subset`` of ``self``, configured
            with ``batch_size`` and ``num_workers`` from ``self.config`` and
            the collate function from :meth:`get_collate_fn`.

        Raises:
            ValueError: If ``config["ratio"]`` is not within ``(0, 1]``
                (this includes NaN).
        """
        ratio = self.config["ratio"]
        # A chained comparison also rejects NaN, which would slip through
        # ``ratio > 1 or ratio <= 0`` and only fail later inside ``int()``.
        if not 0 < ratio <= 1:
            raise ValueError(
                f"dataset ratio should be between (0,1], found {ratio} in {self.config['name']}"
            )

        total = len(self)
        subset_size = max(1, int(total * ratio))
        # ``tolist()`` yields plain Python ints so the wrapped dataset's
        # ``__getitem__`` never receives ``np.int64`` indices.
        indices = np.random.permutation(total)[:subset_size].tolist()
        subset = Subset(self, indices)

        return DataLoader(
            subset,
            batch_size=self.config["batch_size"],
            num_workers=self.config["num_workers"],
            shuffle=shuffle,
            collate_fn=self.get_collate_fn(),
        )