from typing import Sized
import os
import pickle
from abc import ABC, abstractmethod

import numpy as np
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader, Subset

from configs.config import ConfigManager


class AdvancedDataset(ABC, Dataset, Sized):
    """Abstract dataset base with optional on-disk pickle caching.

    Subclasses implement ``getitem`` (and ``__len__``, required via
    ``Sized``); ``__getitem__`` transparently caches each item as
    ``<cache_path>/<index>.pkl`` when the experiment's ``use_cache``
    setting is enabled.
    """

    def __init__(self, config):
        """Read experiment settings and prepare the per-dataset cache dir.

        Args:
            config: mapping with at least the keys ``name``, ``ratio``,
                ``batch_size`` and ``num_workers``.
        """
        super().__init__()
        self.config = config
        self.use_cache = ConfigManager.get("settings", "experiment", "use_cache")
        exp_root = ConfigManager.get("settings", "experiment", "root_dir")
        exp_name = ConfigManager.get("settings", "experiment", "name")
        self.cache_path = os.path.join(
            exp_root, exp_name, "cache", self.config["name"]
        )
        if self.use_cache:
            # exist_ok=True avoids the check-then-create race (e.g. when
            # several DataLoader workers or processes start concurrently).
            os.makedirs(self.cache_path, exist_ok=True)

    @staticmethod
    def process_batch(batch, device):
        """Move every tensor value of *batch* to *device* in place.

        Values that are plain lists are left untouched. Returns the same
        (mutated) dict for convenience.
        """
        for key, value in batch.items():
            if isinstance(value, list):
                continue
            batch[key] = value.to(device)
        return batch

    @abstractmethod
    def getitem(self, index) -> dict:
        """Produce the raw (uncached) item for *index*. Must be overridden."""
        raise NotImplementedError

    def __getitem__(self, index) -> dict:
        """Return item *index*, reading/writing the pickle cache if enabled."""
        cache_data_path = os.path.join(self.cache_path, f"{index}.pkl")
        if self.use_cache and os.path.exists(cache_data_path):
            # NOTE(review): pickle is only safe because the cache dir is
            # written by this process; never point it at untrusted files.
            with open(cache_data_path, "rb") as f:
                item = pickle.load(f)
        else:
            item = self.getitem(index)
            if self.use_cache:
                with open(cache_data_path, "wb") as f:
                    pickle.dump(item, f)
        return item

    def get_loader(self, device, shuffle=False):
        """Build a DataLoader over a random subset of ``ratio * len(self)``.

        Args:
            device: kept for interface compatibility (a per-device RNG
                generator was previously wired here — see commented line).
            shuffle: forwarded to the DataLoader.

        Raises:
            ValueError: if ``config["ratio"]`` is not in (0, 1].
        """
        ratio = self.config["ratio"]
        if ratio > 1 or ratio <= 0:
            raise ValueError(
                f"dataset ratio should be between (0,1], found {ratio} in {self.config['name']}"
            )
        subset_size = int(len(self) * ratio)
        # Random (unseeded) subset selection — even ratio == 1 yields a
        # permuted Subset, matching the original behavior.
        indices = np.random.permutation(len(self))[:subset_size]
        subset = Subset(self, indices)
        return DataLoader(
            subset,
            batch_size=self.config["batch_size"],
            num_workers=self.config["num_workers"],
            shuffle=shuffle,
            # generator=torch.Generator(device=device),
        )