2024-10-09 16:13:22 +00:00

64 lines
2.3 KiB
Python
Executable File

from typing import Sized
import os
import numpy as np
import torch
import pickle
from abc import ABC, abstractmethod
from torch.utils.data import Dataset
from torch.utils.data import DataLoader, Subset
from configs.config import ConfigManager
class AdvancedDataset(ABC, Dataset, Sized):
    """Abstract map-style dataset with an optional per-item on-disk pickle cache.

    Subclasses implement ``getitem(index)`` (build one sample as a dict) and
    ``__len__`` (abstract via ``Sized``; ``get_loader`` relies on it).

    When the experiment setting ``use_cache`` is truthy, each item produced by
    ``getitem`` is pickled to
    ``<root_dir>/<experiment name>/cache/<config["name"]>/<index>.pkl`` and
    later requests for the same index load that file instead of recomputing.
    NOTE(review): the cache key is only the index, so stale entries survive
    changes to the dataset definition — clear the cache dir when that happens.

    Keys read from ``config``: ``"name"``, ``"ratio"``, ``"batch_size"``,
    ``"num_workers"``.
    """

    def __init__(self, config):
        """Store ``config`` and resolve (and, when caching, create) the cache dir.

        Args:
            config: per-dataset mapping; ``config["name"]`` names the cache
                sub-directory.
        """
        super().__init__()
        self.config = config
        self.use_cache = ConfigManager.get("settings", "experiment", "use_cache")
        exp_root = ConfigManager.get("settings", "experiment", "root_dir")
        exp_name = ConfigManager.get("settings", "experiment", "name")
        self.cache_path = os.path.join(exp_root, exp_name, "cache", self.config["name"])
        if self.use_cache:
            # exist_ok avoids the check-then-create race when several
            # processes construct the same dataset concurrently.
            os.makedirs(self.cache_path, exist_ok=True)

    @staticmethod
    def process_batch(batch, device):
        """Move every tensor-like value of ``batch`` to ``device``, in place.

        Values without a ``.to`` method (lists such as collated strings, plain
        Python scalars, NumPy arrays, ...) are passed through unchanged.

        Args:
            batch: mapping of field name -> value, e.g. as yielded by the loader.
            device: target handed to ``value.to(device)``.

        Returns:
            The same ``batch`` object, mutated in place.
        """
        for key, value in batch.items():
            # Reassigning an existing key does not resize the dict, so it is
            # safe while iterating.
            if hasattr(value, "to"):
                batch[key] = value.to(device)
        return batch

    @abstractmethod
    def getitem(self, index) -> dict:
        """Build and return the uncached sample at ``index`` (subclass hook)."""
        raise NotImplementedError

    def __getitem__(self, index) -> dict:
        """Return sample ``index``, served from / stored to the cache when enabled.

        A truncated or corrupt cache file is treated as a miss and rebuilt.
        Cache files are trusted input (written by this class); never point the
        cache directory at data from an untrusted source, since ``pickle.load``
        can execute arbitrary code.
        """
        if not self.use_cache:
            return self.getitem(index)
        cache_data_path = os.path.join(self.cache_path, f"{index}.pkl")
        if os.path.exists(cache_data_path):
            try:
                with open(cache_data_path, "rb") as f:
                    return pickle.load(f)
            except (EOFError, pickle.UnpicklingError):
                # Half-written entry (e.g. from an interrupted run): fall
                # through and regenerate it instead of failing forever.
                pass
        item = self.getitem(index)
        self._write_cache(cache_data_path, item)
        return item

    @staticmethod
    def _write_cache(cache_data_path, item):
        """Pickle ``item`` to ``cache_data_path`` without exposing partial files.

        The data goes to a PID-suffixed temp file in the same directory and is
        then moved into place with ``os.replace`` (an atomic rename on POSIX),
        so readers never observe a half-written pickle and concurrent
        DataLoader worker processes never share a temp file.
        """
        tmp_path = f"{cache_data_path}.{os.getpid()}.tmp"
        try:
            with open(tmp_path, "wb") as f:
                pickle.dump(item, f)
            os.replace(tmp_path, cache_data_path)
        finally:
            # Only still present if dump or replace failed.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    def get_loader(self, device, shuffle=False):
        """Build a DataLoader over a random ``config["ratio"]`` fraction of the data.

        Args:
            device: currently unused; kept for API compatibility (apparently
                intended for a device-specific ``torch.Generator``).
            shuffle: forwarded to ``DataLoader``.

        Returns:
            A ``DataLoader`` over a ``Subset`` of ``int(len(self) * ratio)``
            randomly chosen indices, but at least one when the dataset is
            non-empty. Indices come from NumPy's global RNG, so the subset is
            reproducible only if the caller seeds ``np.random``.

        Raises:
            ValueError: if ``config["ratio"]`` is not in (0, 1] (NaN included).
        """
        ratio = self.config["ratio"]
        # Chained comparison also rejects NaN, which the old `or` test let through.
        if not 0 < ratio <= 1:
            raise ValueError(
                f"dataset ratio should be between (0,1], found {ratio} in {self.config['name']}"
            )
        total = len(self)
        # int() truncates: without the floor of 1, a small dataset with a small
        # positive ratio would silently produce an empty loader.
        subset_size = max(1, int(total * ratio)) if total else 0
        indices = np.random.permutation(total)[:subset_size]
        return DataLoader(
            Subset(self, indices),
            batch_size=self.config["batch_size"],
            num_workers=self.config["num_workers"],
            shuffle=shuffle,
        )