2024-09-13 16:58:34 +08:00

45 lines
1.3 KiB
Python

from abc import ABC
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader, Subset
from PytorchBoot.component import Component
class BaseDataset(ABC, Dataset, Component):
    """Base dataset that can move batches to a device and build its own loader.

    Subclasses supply the ``Dataset`` protocol (``__len__`` / ``__getitem__``).
    ``config`` is stored as-is; :meth:`get_loader` reads its ``"ratio"``,
    ``"name"``, ``"batch_size"`` and ``"num_workers"`` entries.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

    @staticmethod
    def process_batch(batch, device):
        """Move every device-movable value of ``batch`` to ``device`` in place.

        Args:
            batch: Mapping of field name to batched value (for example the
                output of the loader's collate function).
            device: Target passed to each value's ``.to()`` method.

        Returns:
            The same ``batch`` object, mutated in place.

        Lists and any other value without a ``.to`` method (strings, numbers,
        ``None``, ...) are passed through unchanged rather than raising
        ``AttributeError``.
        """
        for key, value in batch.items():
            # Only existing keys are reassigned, so mutating during iteration
            # is safe. Non-tensor fields cannot be moved to a device; keep them.
            if isinstance(value, list) or not hasattr(value, "to"):
                continue
            batch[key] = value.to(device)
        return batch

    def get_collate_fn(self):
        """Return the collate function used by :meth:`get_loader`.

        ``None`` makes ``DataLoader`` fall back to its default collation;
        subclasses may override this to batch custom sample structures.
        """
        return None

    def get_loader(self, shuffle=False):
        """Build a ``DataLoader`` over a random fraction of this dataset.

        On every call, ``max(1, int(len(self) * ratio))`` distinct indices are
        drawn using NumPy's global RNG. The chosen indices are kept in
        ascending dataset order. With ``shuffle=False`` the loader therefore
        iterates deterministically in dataset order. A ratio of 1 yields the
        full dataset in its natural order.

        Args:
            shuffle: Forwarded to ``DataLoader``.

        Returns:
            A ``DataLoader`` over a ``Subset`` of this dataset, configured
            with ``config["batch_size"]``, ``config["num_workers"]`` and the
            result of :meth:`get_collate_fn`.

        Raises:
            ValueError: If ``config["ratio"]`` is not in ``(0, 1]``.
        """
        ratio = self.config["ratio"]
        if ratio > 1 or ratio <= 0:
            raise ValueError(
                f"dataset ratio should be between (0,1], found {ratio} in {self.config['name']}"
            )
        total = len(self)
        subset_size = max(1, int(total * ratio))
        # Pick the kept samples at random (same RNG draw as before), but sort
        # them. Iterating a raw permutation would randomize the order even
        # with shuffle=False, making e.g. evaluation runs non-reproducible.
        # .tolist() hands plain Python ints to __getitem__.
        indices = np.sort(np.random.permutation(total)[:subset_size]).tolist()
        subset = Subset(self, indices)
        return DataLoader(
            subset,
            batch_size=self.config["batch_size"],
            num_workers=self.config["num_workers"],
            shuffle=shuffle,
            collate_fn=self.get_collate_fn(),
        )