ubuntu version
This commit is contained in:
parent
2503fca572
commit
067a0fe9cc
@ -1,8 +1,10 @@
|
||||
from PytorchBoot.application import PytorchBootApplication
|
||||
from runners.strategy_generator import StrategyGenerator
|
||||
from runners.data_generator import DataGenerator
|
||||
|
||||
@PytorchBootApplication("generate")
|
||||
class Generator:
|
||||
@staticmethod
|
||||
def start():
|
||||
StrategyGenerator("configs\generate_config.yaml").run()
|
||||
#StrategyGenerator("configs\strategy_generate_config.yaml").run()
|
||||
DataGenerator("configs/data_generate_config.yaml").run()
|
24
configs/data_generate_config.yaml
Normal file
24
configs/data_generate_config.yaml
Normal file
@ -0,0 +1,24 @@
|
||||
|
||||
runner:
|
||||
general:
|
||||
seed: 0
|
||||
device: cpu
|
||||
cuda_visible_devices: "0,1,2,3,4,5,6,7"
|
||||
|
||||
|
||||
experiment:
|
||||
name: debug
|
||||
root_dir: "experiments"
|
||||
|
||||
generate:
|
||||
voxel_threshold: 0.005
|
||||
overlap_threshold: 0.3
|
||||
dataset_list:
|
||||
- OmniObject3d
|
||||
|
||||
datasets:
|
||||
OmniObject3d:
|
||||
model_dir: "/media/hofee/data/data/object_meshes"
|
||||
output_dir: "/media/hofee/data/data/omni_sample_output"
|
||||
|
||||
|
@ -19,6 +19,5 @@ runner:
|
||||
datasets:
|
||||
OmniObject3d:
|
||||
root_dir: "C:\\Document\\Local Project\\nbv_rec\\sample_dataset"
|
||||
output_dir: "C:\\Document\\Local Project\\nbv_rec\\sample_output"
|
||||
|
||||
|
34
runners/data_generator.py
Normal file
34
runners/data_generator.py
Normal file
@ -0,0 +1,34 @@
|
||||
import os
|
||||
import json
|
||||
from PytorchBoot.runners.runner import Runner
|
||||
from PytorchBoot.config import ConfigManager
|
||||
from PytorchBoot.utils import Log
|
||||
import PytorchBoot.stereotype as stereotype
|
||||
|
||||
|
||||
@stereotype.runner("data_generator", comment="unfinished")
|
||||
class DataGenerator(Runner):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.load_experiment("generate")
|
||||
|
||||
def run(self):
|
||||
dataset_name_list = ConfigManager.get("runner", "generate" ,"dataset_list")
|
||||
for dataset_name in dataset_name_list:
|
||||
self.generate(dataset_name)
|
||||
|
||||
def generate(self, dataset_name):
|
||||
dataset_config = ConfigManager.get("datasets", dataset_name)
|
||||
model_dir = dataset_config["model_dir"]
|
||||
output_dir = dataset_config["output_dir"]
|
||||
Log.debug(model_dir)
|
||||
Log.debug(output_dir)
|
||||
|
||||
def create_experiment(self, backup_name=None):
|
||||
super().create_experiment(backup_name)
|
||||
output_dir = os.path.join(str(self.experiment_path), "output")
|
||||
os.makedirs(output_dir)
|
||||
|
||||
def load_experiment(self, backup_name=None):
|
||||
super().load_experiment(backup_name)
|
||||
|
Loading…
x
Reference in New Issue
Block a user