commit 73dcd592df2210021d40f4ce68d092e34db9650e Author: hofee <64160135+GitHofee@users.noreply.github.com> Date: Sun Aug 18 00:37:17 2024 +0800 Basic Framework diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..17fc0b6 --- /dev/null +++ b/.gitignore @@ -0,0 +1,14 @@ +__pycache__/ +.DS_Store +.idea +experiments/ +pytorch3d/ +test/ +*.xyz +*.zip +*.txt +*.pkl +*.log +/data_generation/data/* +/data_generation/output/* +test/ \ No newline at end of file diff --git a/annotations/external_module.py b/annotations/external_module.py new file mode 100644 index 0000000..5fd15ba --- /dev/null +++ b/annotations/external_module.py @@ -0,0 +1,7 @@ +EXTERNAL_FREEZE_MODULES = set() + +def external_freeze(cls): + if not hasattr(cls, 'load') or not callable(getattr(cls, 'load')): + raise TypeError(f"external module <{cls.__name__}> must implement a 'load' method") + EXTERNAL_FREEZE_MODULES.add(cls) + return cls \ No newline at end of file diff --git a/annotations/stereotype.py b/annotations/stereotype.py new file mode 100644 index 0000000..e7280fc --- /dev/null +++ b/annotations/stereotype.py @@ -0,0 +1,34 @@ +# --- Classes --- # + +def dataset(): + pass + +def module(): + pass + +def pipeline(): + pass + +def runner(): + pass + +def factory(): + pass + +# --- Functions --- # + +evaluation_methods = {} +def evaluation_method(eval_type): + def decorator(func): + evaluation_methods[eval_type] = func + return func + return decorator + + +def loss_function(): + pass + + +# --- Main --- # + + \ No newline at end of file diff --git a/configs/config.py b/configs/config.py new file mode 100644 index 0000000..2e68fa6 --- /dev/null +++ b/configs/config.py @@ -0,0 +1,74 @@ +import argparse +import os.path +import shutil +import yaml + + +class ConfigManager: + config = None + config_path = None + + @staticmethod + def get(*args): + result = ConfigManager.config + for arg in args: + result = result[arg] + return result + + @staticmethod + def load_config_with(config_file_path): + ConfigManager.config_path = config_file_path + if not os.path.exists(ConfigManager.config_path): + raise ValueError(f"Config file <{config_file_path}> does not exist") + with open(config_file_path, 'r') as file: + ConfigManager.config = yaml.safe_load(file) + + @staticmethod + def backup_config_to(target_config_dir, file_name, prefix="config"): + file_name = f"{prefix}_{file_name}.yaml" + target_config_file_path = str(os.path.join(target_config_dir, file_name)) + shutil.copy(ConfigManager.config_path, target_config_file_path) + + @staticmethod + def load_config(): + parser = argparse.ArgumentParser() + parser.add_argument('--config', type=str, default='', help='config file path') + args = parser.parse_args() + if args.config: + ConfigManager.load_config_with(args.config) + + @staticmethod + def print_config(key: str = None, group: dict = None, level=0): + table_size = 80 + if key and group: + value = group[key] + if type(value) is dict: + print("\t" * level + f"+-{key}:") + for k in value: + ConfigManager.print_config(k, value, level=level + 1) + else: + print("\t" * level + f"| {key}: {value}") + elif key: + ConfigManager.print_config(key, ConfigManager.config, level=level) + else: + print("+" + "-" * table_size + "+") + print(f"| Configurations in <{ConfigManager.config_path}>:") + print("+" + "-" * table_size + "+") + for key in ConfigManager.config: + ConfigManager.print_config(key, level=level + 1) + print("+" + "-" * table_size + "+") + + +''' ------------ Debug ------------ ''' +if __name__ == "__main__": + test_args = ['--config', r'configs\train_config.yaml'] + test_parser = argparse.ArgumentParser() + test_parser.add_argument('--config', type=str, default='', help='config file path') + test_args = test_parser.parse_args(test_args) + if test_args.config: + ConfigManager.load_config_with(test_args.config) + ConfigManager.print_config() + print() + pipeline = ConfigManager.get('settings', 'train', "dataset", 'batch_size') + ConfigManager.print_config('settings') + print(pipeline) diff --git a/configs/train_config.yaml b/configs/train_config.yaml new file mode 100644 index 0000000..e8e4fbc --- /dev/null +++ b/configs/train_config.yaml @@ -0,0 +1,73 @@ +# Train config file + +settings: + general: + seed: 0 + device: cuda + cuda_visible_devices: "0,1,2,3,4,5,6,7" + parallel: True + + experiment: + name: train_experiment + root_dir: "experiments" + use_checkpoint: True + epoch: -1 # -1 stands for last epoch + max_epochs: 5000 + save_checkpoint_interval: 1 + test_first: True + + train: + optimizer: + type: adam + lr: 0.0001 + losses: # loss type : weight + gf_loss: 1.0 + dataset: + name: synthetic_train_train_dataset + source: nbv1 + data_type: train + ratio: 0.1 + batch_size: 128 + num_workers: 96 + + test: + batch_size: 16 + frequency: 3 + dataset_list: + - name: synthetic_test_train_dataset + source: nbv1 + data_type: train + eval_list: + ratio: 0.00001 + batch_size: 16 + num_workers: 16 + + pipeline: + pts_encoder: pointnet + view_finder: gradient_field + +datasets: + general: + data_dir: "../data" + +modules: + general: + pts_channels: 3 + feature_dim: 1024 + per_point_feature: False + pts_encoder: + pointnet: + pointnet++: + params_name: light + pointnet++rgb: + params_name: light + target_layer: 3 + rgb_feat_dim: 384 + view_finder: + gradient_field: + pose_mode: rot_matrix + regression_head: Rx_Ry + sample_mode: ode + sample_repeat: 50 + sampling_steps: 500 + sde_mode: ve \ No newline at end of file diff --git a/datasets/dataset.py b/datasets/dataset.py new file mode 100644 index 0000000..e58cf06 --- /dev/null +++ b/datasets/dataset.py @@ -0,0 +1,35 @@ +from abc import ABC + +import numpy as np + +from torch.utils.data import Dataset +from torch.utils.data import DataLoader, Subset + +class BaseDataset(ABC, Dataset): + def __init__(self, config): + super(BaseDataset, self).__init__() + self.config = config + + @staticmethod + def process_batch(batch, device): + for key in batch.keys(): + if isinstance(batch[key], list): + continue + batch[key] = batch[key].to(device) + return batch + + def get_loader(self, shuffle=False): + ratio = self.config["ratio"] + if ratio > 1 or ratio <= 0: + raise ValueError( + f"dataset ratio should be between (0,1], found {ratio} in {self.config['name']}" + ) + subset_size = int(len(self) * ratio) + indices = np.random.permutation(len(self))[:subset_size] + subset = Subset(self, indices) + return DataLoader( + subset, + batch_size=self.config["batch_size"], + num_workers=self.config["num_workers"], + shuffle=shuffle, + ) diff --git a/datasets/dataset_factory.py b/datasets/dataset_factory.py new file mode 100644 index 0000000..f97d1df --- /dev/null +++ b/datasets/dataset_factory.py @@ -0,0 +1,30 @@ +import sys +import os +path = os.path.abspath(__file__) +for i in range(2): + path = os.path.dirname(path) +PROJECT_ROOT = path +sys.path.append(PROJECT_ROOT) + +from datasets.dataset import BaseDataset + +class DatasetFactory: + @staticmethod + def create(config) -> BaseDataset: + pass + + +''' ------------ Debug ------------ ''' +if __name__ == "__main__": + + from configs.config import ConfigManager + + ConfigManager.load_config_with('/home/data/hofee/project/ActivePerception/ActivePerception/configs/server_train_config.yaml') + ConfigManager.print_config() + dataset = DatasetFactory.create(ConfigManager.get("settings", "test", "dataset_list")[1]) + print(len(dataset)) + data_test = dataset.__getitem__(107000) + print(data_test['src_path']) + import pickle + # with open("data_sample_new.pkl", "wb") as f: + # pickle.dump(data_test, f) diff --git a/evaluations/eval_function_factory.py b/evaluations/eval_function_factory.py new file mode 100644 index 0000000..89ef3b1 --- /dev/null +++ b/evaluations/eval_function_factory.py @@ -0,0 +1,35 @@ +from annotations.stereotype import evaluation_methods +import importlib +import pkgutil +import os + +package_name = os.path.dirname("evaluations") +package = importlib.import_module("evaluations") +for _, module_name, _ in pkgutil.walk_packages(package.__path__, package.__name__ + "."): + importlib.import_module(module_name) + +class EvalFunctionFactory: + @staticmethod + def create(eval_type_list): + def eval_func(output, data): + temp_results = {"scalars": {}, "points": {}, "images": {}} + for eval_type in eval_type_list: + if eval_type in evaluation_methods: + result = evaluation_methods[eval_type](output, data) + for k, v in result.items(): + temp_results[k].update(v) + results = {} + for k, v in temp_results.items(): + if len(v) > 0: + results[k] = v + return results + + return eval_func + + +''' ------------ Debug ------------ ''' +if __name__ == "__main__": + from configs.config import ConfigManager + + ConfigManager.load_config_with('../configs/local_train_config.yaml') + ConfigManager.print_config() diff --git a/losses/loss_function_factory.py b/losses/loss_function_factory.py new file mode 100644 index 0000000..ac1f817 --- /dev/null +++ b/losses/loss_function_factory.py @@ -0,0 +1,12 @@ +class LossFunctionFactory: + @staticmethod + def create(function_name): + raise ValueError("Unknown loss function {}".format(function_name)) + + +''' ------------ Debug ------------ ''' +if __name__ == "__main__": + from configs.config import ConfigManager + + ConfigManager.load_config_with('../configs/local_train_config.yaml') + ConfigManager.print_config() diff --git a/modules/pipeline.py b/modules/pipeline.py new file mode 100644 index 0000000..d775742 --- /dev/null +++ b/modules/pipeline.py @@ -0,0 +1,20 @@ +from torch import nn +from configs.config import ConfigManager + + +class Pipeline(nn.Module): + TRAIN_MODE: str = "train" + TEST_MODE: str = "test" + + def __init__(self, pipeline_config): + super(Pipeline, self).__init__() + + self.modules_config = ConfigManager.get("modules") + self.device = ConfigManager.get("settings", "general", "device") + + def forward(self, data, mode): + pass + + +if __name__ == '__main__': + pass diff --git a/optimizers/optimizer_factory.py b/optimizers/optimizer_factory.py new file mode 100644 index 0000000..4c5e1cf --- /dev/null +++ b/optimizers/optimizer_factory.py @@ -0,0 +1,32 @@ +import torch.optim as optim + + +class OptimizerFactory: + @staticmethod + def create(config, params): + optim_type = config["type"] + lr = config.get("lr", 1e-3) + if optim_type == "sgd": + return optim.SGD( + params, + lr=lr, + momentum=config.get("momentum", 0.9), + weight_decay=config.get("weight_decay", 1e-4), + ) + elif optim_type == "adam": + return optim.Adam( + params, + lr=lr, + betas=config.get("betas", (0.9, 0.999)), + eps=config.get("eps", 1e-8), + ) + else: + raise NotImplementedError("Unknown optimizers: {}".format(optim_type)) + + +""" ------------ Debug ------------ """ +if __name__ == "__main__": + from configs.config import ConfigManager + + ConfigManager.load_config_with("../configs/local_train_config.yaml") + ConfigManager.print_config() diff --git a/runners/runner.py b/runners/runner.py new file mode 100644 index 0000000..5c98c00 --- /dev/null +++ b/runners/runner.py @@ -0,0 +1,59 @@ +import os +import time +from abc import abstractmethod, ABC + +import numpy as np +import torch + +from configs.config import ConfigManager + +class Runner(ABC): + + @abstractmethod + def __init__(self, config_path): + ConfigManager.load_config_with(config_path) + ConfigManager.print_config() + seed = ConfigManager.get("settings", "general", "seed") + self.device = ConfigManager.get("settings", "general", "device") + self.cuda_visible_devices = ConfigManager.get("settings","general","cuda_visible_devices") + os.environ["CUDA_VISIBLE_DEVICES"] = self.cuda_visible_devices + self.experiments_config = ConfigManager.get("settings", "experiment") + self.experiment_path = os.path.join(self.experiments_config["root_dir"], self.experiments_config["name"]) + np.random.seed(seed) + torch.manual_seed(seed) + lt = time.localtime() + self.file_name = f"{lt.tm_year}_{lt.tm_mon}_{lt.tm_mday}_{lt.tm_hour}h{lt.tm_min}m{lt.tm_sec}s" + + @abstractmethod + def run(self): + pass + + @abstractmethod + def load_experiment(self, backup_name=None): + if not os.path.exists(self.experiment_path): + print(f"experiments environment {self.experiments_config['name']} does not exists.") + self.create_experiment(backup_name) + else: + print(f"experiments environment {self.experiments_config['name']}") + backup_config_dir = os.path.join(str(self.experiment_path), "configs") + if not os.path.exists(backup_config_dir): + os.makedirs(backup_config_dir) + ConfigManager.backup_config_to(backup_config_dir, self.file_name, backup_name) + + @abstractmethod + def create_experiment(self, backup_name=None): + print("creating experiment: " + self.experiments_config["name"]) + os.makedirs(self.experiment_path) + backup_config_dir = os.path.join(str(self.experiment_path), "configs") + os.makedirs(backup_config_dir) + ConfigManager.backup_config_to(backup_config_dir, self.file_name, backup_name) + log_dir = os.path.join(str(self.experiment_path), "log") + os.makedirs(log_dir) + cache_dir = os.path.join(str(self.experiment_path), "cache") + os.makedirs(cache_dir) + + def print_info(self): + table_size = 80 + print("+" + "-" * table_size + "+") + print(f"| Experiment <{self.experiments_config['name']}>") + print("+" + "-" * table_size + "+") diff --git a/runners/trainer.py b/runners/trainer.py new file mode 100644 index 0000000..d32b993 --- /dev/null +++ b/runners/trainer.py @@ -0,0 +1,261 @@ +import os +import sys +import json +from datetime import datetime + +import torch +from tqdm import tqdm +from torch.utils.tensorboard import SummaryWriter + +path = os.path.abspath(__file__) +for i in range(2): + path = os.path.dirname(path) +PROJECT_ROOT = path +sys.path.append(PROJECT_ROOT) + +from configs.config import ConfigManager +from datasets.dataset_factory import DatasetFactory +from optimizers.optimizer_factory import OptimizerFactory +from evaluations.eval_function_factory import EvalFunctionFactory +from losses.loss_function_factory import LossFunctionFactory +from modules.pipeline import Pipeline +from runners.runner import Runner +from utils.tensorboard_util import TensorboardWriter +from annotations.external_module import EXTERNAL_FREEZE_MODULES + + +class Trainer(Runner): + CHECKPOINT_DIR_NAME: str = 'checkpoints' + TENSORBOARD_DIR_NAME: str = 'tensorboard' + LOG_DIR_NAME: str = 'log' + + def __init__(self, config_path): + super().__init__(config_path) + tensorboard_path = os.path.join(self.experiment_path, Trainer.TENSORBOARD_DIR_NAME) + + ''' Pipeline ''' + self.pipeline_config = ConfigManager.get("settings", "pipeline") + self.parallel = ConfigManager.get("settings","general","parallel") + self.pipeline = Pipeline(self.pipeline_config) + if self.parallel and self.device == "cuda": + self.pipeline = torch.nn.DataParallel(self.pipeline) + self.pipeline = self.pipeline.to(self.device) + + ''' Experiment ''' + self.current_epoch = 0 + self.max_epochs = self.experiments_config["max_epochs"] + self.test_first = self.experiments_config["test_first"] + self.load_experiment("train") + + ''' Train ''' + self.train_config = ConfigManager.get("settings", "train") + self.train_dataset_config = self.train_config["dataset"] + self.train_set = DatasetFactory.create(self.train_dataset_config) + self.optimizer = OptimizerFactory.create(self.train_config["optimizer"], self.pipeline.parameters()) + self.train_writer = SummaryWriter( + log_dir=os.path.join(tensorboard_path, f"[train]{self.train_dataset_config['name']}")) + + ''' Test ''' + self.test_config = ConfigManager.get("settings", "test") + self.test_dataset_config_list = self.test_config["dataset_list"] + self.test_set_list = [] + self.test_writer_list = [] + seen_name = set() + for test_dataset_config in self.test_dataset_config_list: + if test_dataset_config["name"] not in seen_name: + seen_name.add(test_dataset_config["name"]) + else: + raise ValueError("Duplicate test dataset name: {}".format(test_dataset_config["name"])) + test_set = DatasetFactory.create(test_dataset_config) + test_writer = SummaryWriter( + log_dir=os.path.join(tensorboard_path, f"[test]{test_dataset_config['name']}")) + self.test_set_list.append(test_set) + self.test_writer_list.append(test_writer) + del seen_name + + self.print_info() + + def run(self): + save_interval = self.experiments_config["save_checkpoint_interval"] + if self.current_epoch != 0: + print("Continue training from epoch {}.".format(self.current_epoch)) + else: + print("Start training from initial model.") + if self.test_first: + print("Do test first.") + self.test() + while self.current_epoch < self.max_epochs: + self.current_epoch += 1 + self.train() + self.test() + if self.current_epoch % save_interval == 0: + self.save_checkpoint() + self.save_checkpoint(is_last=True) + + def train(self): + self.pipeline.train() + train_set_name = self.train_dataset_config["name"] + ratio = self.train_dataset_config["ratio"] + train_loader = self.train_set.get_loader(device="cuda", shuffle=True) + + loop = tqdm(enumerate(train_loader), total=len(train_loader)) + loader_length = len(train_loader) + for i, data in loop: + self.train_set.process_batch(data, self.device) + loss_dict = self.train_step(data) + loop.set_description( + f'Epoch [{self.current_epoch}/{self.max_epochs}] (Train: {train_set_name}, ratio={ratio})') + loop.set_postfix(loss=loss_dict) + curr_iters = (self.current_epoch - 1) * loader_length + i + TensorboardWriter.write_tensorboard(self.train_writer, "iter", loss_dict, curr_iters) + + def train_step(self, data): + self.optimizer.zero_grad() + output = self.pipeline(data, Pipeline.TRAIN_MODE) + total_loss, loss_dict = self.loss_fn(output, data) + total_loss.backward() + self.optimizer.step() + for k, v in loss_dict.items(): + loss_dict[k] = round(v, 5) + return loss_dict + + def loss_fn(self, output, data): + loss_config = self.train_config["losses"] + loss_dict = {} + total_loss = torch.tensor(0.0, dtype=torch.float32, device=self.device) + for key in loss_config: + weight = loss_config[key] + target_loss_fn = LossFunctionFactory.create(key) + loss = target_loss_fn(output, data) + loss_dict[key] = loss.item() + total_loss += weight * loss + + loss_dict['total_loss'] = total_loss.item() + return total_loss, loss_dict + + def test(self): + self.pipeline.eval() + with torch.no_grad(): + for dataset_idx, test_set in enumerate(self.test_set_list): + eval_list = self.test_dataset_config_list[dataset_idx]["eval_list"] + test_set_name = self.test_dataset_config_list[dataset_idx]["name"] + ratio = self.test_dataset_config_list[dataset_idx]["ratio"] + writer = self.test_writer_list[dataset_idx] + output_list = [] + data_list = [] + test_loader = test_set.get_loader("cpu") + loop = tqdm(enumerate(test_loader), total=int(len(test_loader))) + for i, data in loop: + test_set.process_batch(data, self.device) + output = self.pipeline(data, Pipeline.TEST_MODE) + output_list.append(output) + data_list.append(data) + loop.set_description( + f'Epoch [{self.current_epoch}/{self.max_epochs}] (Test: {test_set_name}, ratio={ratio})') + result_dict = self.eval_fn(output_list, data_list, eval_list) + TensorboardWriter.write_tensorboard(writer, "epoch", result_dict, self.current_epoch - 1) + + @staticmethod + def eval_fn(output_list, data_list, eval_list): + target_eval_fn = EvalFunctionFactory.create(eval_list) + result_dict = target_eval_fn(output_list, data_list) + return result_dict + + def get_checkpoint_path(self, is_last=False): + return os.path.join(self.experiment_path, Trainer.CHECKPOINT_DIR_NAME, + "Epoch_{}.pth".format( + self.current_epoch if self.current_epoch != -1 and not is_last else "last")) + + def load_checkpoint(self, is_last=False): + self.load(self.get_checkpoint_path(is_last)) + print(f"Loaded checkpoint from {self.get_checkpoint_path(is_last)}") + if is_last: + checkpoint_root = os.path.join(self.experiment_path, Trainer.CHECKPOINT_DIR_NAME) + meta_path = os.path.join(checkpoint_root, "meta.json") + if not os.path.exists(meta_path): + raise FileNotFoundError( + "No checkpoint meta.json file in the experiment {}".format(self.experiments_config["name"])) + file_path = os.path.join(checkpoint_root, "meta.json") + with open(file_path, "r") as f: + meta = json.load(f) + self.current_epoch = meta["last_epoch"] + + def save_checkpoint(self, is_last=False): + self.save(self.get_checkpoint_path(is_last)) + if not is_last: + print(f"Checkpoint at epoch {self.current_epoch} saved to {self.get_checkpoint_path(is_last)}") + else: + meta = { + "last_epoch": self.current_epoch, + "time": str(datetime.now()) + } + checkpoint_root = os.path.join(self.experiment_path, Trainer.CHECKPOINT_DIR_NAME) + file_path = os.path.join(checkpoint_root, "meta.json") + with open(file_path, "w") as f: + json.dump(meta, f) + + def load_experiment(self, backup_name=None): + super().load_experiment(backup_name) + if self.experiments_config["use_checkpoint"]: + self.current_epoch = self.experiments_config["epoch"] + self.load_checkpoint(is_last=(self.current_epoch == -1)) + + def create_experiment(self, backup_name=None): + super().create_experiment(backup_name) + ckpt_dir = os.path.join(str(self.experiment_path), Trainer.CHECKPOINT_DIR_NAME) + os.makedirs(ckpt_dir) + tensorboard_dir = os.path.join(str(self.experiment_path), Trainer.TENSORBOARD_DIR_NAME) + os.makedirs(tensorboard_dir) + + def load(self, path): + state_dict = torch.load(path) + if self.parallel: + self.pipeline.module.load_state_dict(state_dict) + else: + self.pipeline.load_state_dict(state_dict) + + def save(self, path): + if self.parallel: + state_dict = self.pipeline.module.state_dict() + else: + state_dict = self.pipeline.state_dict() + + for name, module in self.pipeline.named_modules(): + if module.__class__ in EXTERNAL_FREEZE_MODULES: + if name in state_dict: + del state_dict[name] + + torch.save(state_dict, path) + + + def print_info(self): + def print_dataset(config, dataset): + print("\t name: {}".format(config["name"])) + print("\t source: {}".format(config["source"])) + print("\t data_type: {}".format(config["data_type"])) + print("\t total_length: {}".format(len(dataset))) + print("\t ratio: {}".format(config["ratio"])) + print() + + super().print_info() + table_size = 70 + print(f"{'+' + '-' * (table_size // 2)} Pipeline {'-' * (table_size // 2)}" + '+') + print(self.pipeline) + print(f"{'+' + '-' * (table_size // 2)} Datasets {'-' * (table_size // 2)}" + '+') + print("train dataset: ") + print_dataset(self.train_dataset_config, self.train_set) + for i, test_dataset_config in enumerate(self.test_dataset_config_list): + print(f"test dataset {i}: ") + print_dataset(test_dataset_config, self.test_set_list[i]) + + print(f"{'+' + '-' * (table_size // 2)}----------{'-' * (table_size // 2)}" + '+') + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str, default="configs/train_config.yaml") + args = parser.parse_args() + trainer = Trainer(args.config) + trainer.run() diff --git a/utils/tensorboard_util.py b/utils/tensorboard_util.py new file mode 100644 index 0000000..d9f85f2 --- /dev/null +++ b/utils/tensorboard_util.py @@ -0,0 +1,47 @@ +import torch + + +class TensorboardWriter: + @staticmethod + def write_tensorboard(writer, panel, data_dict, step): + complex_dict = False + if "scalars" in data_dict: + scalar_data_dict = data_dict["scalars"] + TensorboardWriter.write_scalar_tensorboard(writer, panel, scalar_data_dict, step) + complex_dict = True + if "images" in data_dict: + image_data_dict = data_dict["images"] + TensorboardWriter.write_image_tensorboard(writer, panel, image_data_dict, step) + complex_dict = True + if "points" in data_dict: + point_data_dict = data_dict["points"] + TensorboardWriter.write_points_tensorboard(writer, panel, point_data_dict, step) + complex_dict = True + + if not complex_dict: + TensorboardWriter.write_scalar_tensorboard(writer, panel, data_dict, step) + + @staticmethod + def write_scalar_tensorboard(writer, panel, data_dict, step): + for key, value in data_dict.items(): + if isinstance(value, dict): + writer.add_scalars(f'{panel}/{key}', value, step) + else: + writer.add_scalar(f'{panel}/{key}', value, step) + + @staticmethod + def write_image_tensorboard(writer, panel, data_dict, step): + pass + + @staticmethod + def write_points_tensorboard(writer, panel, data_dict, step): + for key, value in data_dict.items(): + if value.shape[-1] == 3: + colors = torch.zeros_like(value) + vertices = torch.cat([value, colors], dim=-1) + elif value.shape[-1] == 6: + vertices = value + else: + raise ValueError(f'Unexpected value shape: {value.shape}') + faces = None + writer.add_mesh(f'{panel}/{key}', vertices=vertices, faces=faces, global_step=step)