45 lines
1.3 KiB
Python
45 lines
1.3 KiB
Python
from abc import ABC
|
|
|
|
import numpy as np
|
|
|
|
from torch.utils.data import Dataset
|
|
from torch.utils.data import DataLoader, Subset
|
|
|
|
from PytorchBoot.component import Component
|
|
|
|
class BaseDataset(ABC, Dataset, Component):
    """Common base for config-driven datasets that can build their own loader.

    Subclasses provide the usual ``Dataset`` protocol (``__len__`` and
    ``__getitem__``); :meth:`get_loader` relies on both via ``len(self)`` and
    ``Subset(self, ...)``.

    :meth:`get_loader` reads these keys from ``config``: ``"ratio"``,
    ``"batch_size"``, ``"num_workers"``, and ``"name"`` (the last one only
    when building an error message).
    """

    def __init__(self, config):
        """Store the dataset configuration.

        Args:
            config: Mapping-like configuration, indexed with ``[]`` by
                :meth:`get_loader` (see the class docstring for the keys).
        """
        super().__init__()
        self.config = config

    @staticmethod
    def process_batch(batch, device):
        """Move every device-movable value of ``batch`` to ``device``, in place.

        Values that expose a ``.to`` method (e.g. ``torch.Tensor``) are
        replaced by ``value.to(device)``. Lists, and any other value without
        ``.to`` (strings, numbers, ``None``, ...), are left untouched instead
        of raising ``AttributeError``.

        Args:
            batch: Dict-like batch. Presumably produced by the DataLoader's
                collate function; TODO confirm for custom collate functions.
            device: Target passed straight to ``.to`` (e.g. ``"cuda:0"``).

        Returns:
            The same ``batch`` object, mutated in place.
        """
        # Iterate over a snapshot so the in-place writes below are safe for
        # any mapping type, not just the builtin dict.
        for key, value in list(batch.items()):
            if isinstance(value, list) or not hasattr(value, "to"):
                continue
            batch[key] = value.to(device)
        return batch

    def get_collate_fn(self):
        """Return the ``collate_fn`` used by :meth:`get_loader`.

        ``None`` makes ``DataLoader`` fall back to its default collation.
        Subclasses can override this hook to supply their own.
        """
        return None

    def get_loader(self, shuffle=False):
        """Build a ``DataLoader`` over a random ``ratio`` fraction of the dataset.

        Each call draws a fresh random subset of
        ``max(1, int(len(self) * ratio))`` indices, using NumPy's global RNG.
        The chosen indices are sorted into ascending dataset order. This means
        ``shuffle=False`` iterates the subset sequentially, rather than in
        whatever permutation the sampling happened to produce.

        Args:
            shuffle: Forwarded to ``DataLoader(shuffle=...)``.

        Returns:
            A ``torch.utils.data.DataLoader`` over a ``Subset`` of ``self``,
            using ``config["batch_size"]``, ``config["num_workers"]`` and
            :meth:`get_collate_fn`.

        Raises:
            ValueError: If ``config["ratio"]`` is not in ``(0, 1]``
                (NaN included).
        """
        ratio = self.config["ratio"]
        # Positive range test, so NaN is rejected here with a clear message
        # instead of failing later inside int().
        if not 0 < ratio <= 1:
            raise ValueError(
                f"dataset ratio should be between (0,1], found {ratio} in {self.config['name']}"
            )
        num_samples = len(self)
        subset_size = max(1, int(num_samples * ratio))
        # The *selection* of samples is random; the iteration order is left
        # to `shuffle` by sorting the selected indices.
        indices = np.sort(np.random.permutation(num_samples)[:subset_size])
        subset = Subset(self, indices)
        return DataLoader(
            subset,
            batch_size=self.config["batch_size"],
            num_workers=self.config["num_workers"],
            shuffle=shuffle,
            collate_fn=self.get_collate_fn(),
        )
|
|
|
|
|
|
|