import os
import pickle
from abc import ABC, abstractmethod
from typing import Sized

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, Subset

from configs.config import ConfigManager


class AdvancedDataset(ABC, Dataset, Sized):
    """Base dataset with optional on-disk caching of individual items.

    Subclasses implement ``getitem`` (and ``__len__``, required by the
    ``Sized`` base); ``__getitem__`` adds transparent pickle caching.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.use_cache = ConfigManager.get("settings", "experiment", "use_cache")
        exp_root = ConfigManager.get("settings", "experiment", "root_dir")
        exp_name = ConfigManager.get("settings", "experiment", "name")
        self.cache_path = os.path.join(exp_root, exp_name, "cache", self.config["name"])
        if self.use_cache:
            # exist_ok avoids a race between the existence check and creation.
            os.makedirs(self.cache_path, exist_ok=True)

    @staticmethod
    def process_batch(batch, device):
        """Move all tensor values of a collated batch dict to ``device``.

        Non-tensor values (e.g. lists of strings) are left untouched.
        """
        for key, value in batch.items():
            if isinstance(value, torch.Tensor):
                batch[key] = value.to(device)
        return batch

    @abstractmethod
    def getitem(self, index) -> dict:
        """Build the raw item for ``index``; implemented by subclasses."""
        raise NotImplementedError

    def __getitem__(self, index) -> dict:
        cache_data_path = os.path.join(self.cache_path, f"{index}.pkl")
        if self.use_cache and os.path.exists(cache_data_path):
            # Cache hit: load the previously pickled item.
            with open(cache_data_path, "rb") as f:
                item = pickle.load(f)
        else:
            item = self.getitem(index)
            if self.use_cache:
                with open(cache_data_path, "wb") as f:
                    pickle.dump(item, f)
        return item

    def get_loader(self, device, shuffle=False):
        """Return a DataLoader over a random subset covering ``ratio`` of the data."""
        ratio = self.config["ratio"]
        if ratio > 1 or ratio <= 0:
            raise ValueError(
                f"dataset ratio must be in (0, 1], found {ratio} in {self.config['name']}"
            )
        subset_size = int(len(self) * ratio)
        indices = np.random.permutation(len(self))[:subset_size]
        subset = Subset(self, indices)
        return DataLoader(
            subset,
            batch_size=self.config["batch_size"],
            num_workers=self.config["num_workers"],
            shuffle=shuffle,
            # `device` is currently only used by the disabled generator below.
            # generator=torch.Generator(device=device),
        )
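

# Hypothetical usage sketch, not part of the original file: a minimal concrete
# subclass plus loader wiring. Assumes the ConfigManager settings read in
# __init__ ("use_cache", "root_dir", "name" under settings/experiment) are
# already registered; the class name and config values below are illustrative.
if __name__ == "__main__":

    class RandomVectorDataset(AdvancedDataset):
        """Toy dataset of 100 fixed random feature/label pairs."""

        def __init__(self, config):
            super().__init__(config)
            rng = np.random.default_rng(0)
            self.features = rng.standard_normal((100, 8)).astype(np.float32)
            self.labels = rng.integers(0, 2, size=100)

        def __len__(self) -> int:
            return len(self.features)

        def getitem(self, index) -> dict:
            return {
                "x": torch.from_numpy(self.features[index]),
                "y": torch.tensor(self.labels[index]),
            }

    dataset = RandomVectorDataset(
        {"name": "toy", "ratio": 1.0, "batch_size": 16, "num_workers": 0}
    )
    for batch in dataset.get_loader(device="cpu", shuffle=True):
        batch = AdvancedDataset.process_batch(batch, device="cpu")
        print(batch["x"].shape, batch["y"].shape)  # torch.Size([16, 8]) torch.Size([16])
        break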