From f977fd4b8ef7782d6f9803f74b41f5c22d879174 Mon Sep 17 00:00:00 2001 From: hofee <64160135+GitHofee@users.noreply.github.com> Date: Wed, 21 Aug 2024 17:11:56 +0800 Subject: [PATCH] update basic framework --- annotations/external_module.py | 7 - annotations/stereotype.py | 34 --- app_generate.py | 8 + configs/config.py | 74 ----- configs/generate_config.yaml | 23 ++ configs/train_config.yaml | 73 ----- datasets/dataset.py | 35 --- datasets/dataset_factory.py | 30 -- evaluations/eval_function_factory.py | 35 --- losses/gf_loss.py | 12 + losses/loss_function_factory.py | 12 - modules/func_lib/__init__.py | 7 + modules/func_lib/samplers.py | 280 ++++++++++++++++++ modules/func_lib/sde.py | 121 ++++++++ .../module_lib/gaussian_fourier_projection.py | 17 ++ modules/module_lib/linear.py | 30 ++ modules/pipeline.py | 20 -- modules/pts_encoder/abstract_pts_encoder.py | 12 + modules/pts_encoder/pointnet_encoder.py | 110 +++++++ modules/view_finder/abstract_view_finder.py | 12 + modules/view_finder/gf_view_finder.py | 168 +++++++++++ optimizers/optimizer_factory.py | 32 -- runners/runner.py | 59 ---- runners/strategy_generator.py | 73 +++++ runners/trainer.py | 261 ---------------- utils/data_load.py | 135 +++++++++ utils/pose.py | 246 +++++++++++++++ utils/reconstruction.py | 139 +++++++++ utils/tensorboard_util.py | 47 --- 29 files changed, 1393 insertions(+), 719 deletions(-) delete mode 100644 annotations/external_module.py delete mode 100644 annotations/stereotype.py create mode 100644 app_generate.py delete mode 100644 configs/config.py create mode 100644 configs/generate_config.yaml delete mode 100644 configs/train_config.yaml delete mode 100644 datasets/dataset.py delete mode 100644 datasets/dataset_factory.py delete mode 100644 evaluations/eval_function_factory.py create mode 100644 losses/gf_loss.py delete mode 100644 losses/loss_function_factory.py create mode 100644 modules/func_lib/__init__.py create mode 100644 modules/func_lib/samplers.py create mode 100644 modules/func_lib/sde.py create mode 100644 modules/module_lib/gaussian_fourier_projection.py create mode 100644 modules/module_lib/linear.py delete mode 100644 modules/pipeline.py create mode 100644 modules/pts_encoder/abstract_pts_encoder.py create mode 100644 modules/pts_encoder/pointnet_encoder.py create mode 100644 modules/view_finder/abstract_view_finder.py create mode 100644 modules/view_finder/gf_view_finder.py delete mode 100644 optimizers/optimizer_factory.py delete mode 100644 runners/runner.py create mode 100644 runners/strategy_generator.py delete mode 100644 runners/trainer.py create mode 100644 utils/data_load.py create mode 100644 utils/pose.py create mode 100644 utils/reconstruction.py delete mode 100644 utils/tensorboard_util.py diff --git a/annotations/external_module.py b/annotations/external_module.py deleted file mode 100644 index 5fd15ba..0000000 --- a/annotations/external_module.py +++ /dev/null @@ -1,7 +0,0 @@ -EXTERNAL_FREEZE_MODULES = set() - -def external_freeze(cls): - if not hasattr(cls, 'load') or not callable(getattr(cls, 'load')): - raise TypeError(f"external module <{cls.__name__}> must implement a 'load' method") - EXTERNAL_FREEZE_MODULES.add(cls) - return cls \ No newline at end of file diff --git a/annotations/stereotype.py b/annotations/stereotype.py deleted file mode 100644 index e7280fc..0000000 --- a/annotations/stereotype.py +++ /dev/null @@ -1,34 +0,0 @@ -# --- Classes --- # - -def dataset(): - pass - -def module(): - pass - -def pipeline(): - pass - -def runner(): - pass - -def factory(): 
- pass - -# --- Functions --- # - -evaluation_methods = {} -def evaluation_method(eval_type): - def decorator(func): - evaluation_methods[eval_type] = func - return func - return decorator - - -def loss_function(): - pass - - -# --- Main --- # - - \ No newline at end of file diff --git a/app_generate.py b/app_generate.py new file mode 100644 index 0000000..a74a2e8 --- /dev/null +++ b/app_generate.py @@ -0,0 +1,8 @@ +from PytorchBoot.application import PytorchBootApplication +from runners.strategy_generator import StrategyGenerator + +@PytorchBootApplication("generate") +class Generator: + @staticmethod + def start(): + StrategyGenerator("configs\generate_config.yaml").run() \ No newline at end of file diff --git a/configs/config.py b/configs/config.py deleted file mode 100644 index 2e68fa6..0000000 --- a/configs/config.py +++ /dev/null @@ -1,74 +0,0 @@ -import argparse -import os.path -import shutil -import yaml - - -class ConfigManager: - config = None - config_path = None - - @staticmethod - def get(*args): - result = ConfigManager.config - for arg in args: - result = result[arg] - return result - - @staticmethod - def load_config_with(config_file_path): - ConfigManager.config_path = config_file_path - if not os.path.exists(ConfigManager.config_path): - raise ValueError(f"Config file <{config_file_path}> does not exist") - with open(config_file_path, 'r') as file: - ConfigManager.config = yaml.safe_load(file) - - @staticmethod - def backup_config_to(target_config_dir, file_name, prefix="config"): - file_name = f"{prefix}_{file_name}.yaml" - target_config_file_path = str(os.path.join(target_config_dir, file_name)) - shutil.copy(ConfigManager.config_path, target_config_file_path) - - @staticmethod - def load_config(): - parser = argparse.ArgumentParser() - parser.add_argument('--config', type=str, default='', help='config file path') - args = parser.parse_args() - if args.config: - ConfigManager.load_config_with(args.config) - - @staticmethod - def print_config(key: str = None, group: dict = None, level=0): - table_size = 80 - if key and group: - value = group[key] - if type(value) is dict: - print("\t" * level + f"+-{key}:") - for k in value: - ConfigManager.print_config(k, value, level=level + 1) - else: - print("\t" * level + f"| {key}: {value}") - elif key: - ConfigManager.print_config(key, ConfigManager.config, level=level) - else: - print("+" + "-" * table_size + "+") - print(f"| Configurations in <{ConfigManager.config_path}>:") - print("+" + "-" * table_size + "+") - for key in ConfigManager.config: - ConfigManager.print_config(key, level=level + 1) - print("+" + "-" * table_size + "+") - - -''' ------------ Debug ------------ ''' -if __name__ == "__main__": - test_args = ['--config', r'configs\train_config.yaml'] - test_parser = argparse.ArgumentParser() - test_parser.add_argument('--config', type=str, default='', help='config file path') - test_args = test_parser.parse_args(test_args) - if test_args.config: - ConfigManager.load_config_with(test_args.config) - ConfigManager.print_config() - print() - pipeline = ConfigManager.get('settings', 'train', "dataset", 'batch_size') - ConfigManager.print_config('settings') - print(pipeline) diff --git a/configs/generate_config.yaml b/configs/generate_config.yaml new file mode 100644 index 0000000..70a808a --- /dev/null +++ b/configs/generate_config.yaml @@ -0,0 +1,23 @@ + +runners: + general: + seed: 0 + device: cpu + cuda_visible_devices: "0,1,2,3,4,5,6,7" + + experiment: + name: debug + root_dir: "experiments" + + generate: + - name: 
OmniObject3d_train + component: OmniObject3d + data_type: train + +datasets: + general: + components: + OmniObject3d: + root_dir: "C:\\Document\\Local Project\\nbv_rec\\output" + + diff --git a/configs/train_config.yaml b/configs/train_config.yaml deleted file mode 100644 index e8e4fbc..0000000 --- a/configs/train_config.yaml +++ /dev/null @@ -1,73 +0,0 @@ -# Train config file - -settings: - general: - seed: 0 - device: cuda - cuda_visible_devices: "0,1,2,3,4,5,6,7" - parallel: True - - experiment: - name: train_experiment - root_dir: "experiments" - use_checkpoint: True - epoch: -1 # -1 stands for last epoch - max_epochs: 5000 - save_checkpoint_interval: 1 - test_first: True - - train: - optimizer: - type: adam - lr: 0.0001 - losses: # loss type : weight - gf_loss: 1.0 - dataset: - name: synthetic_train_train_dataset - source: nbv1 - data_type: train - ratio: 0.1 - batch_size: 128 - num_workers: 96 - - test: - batch_size: 16 - frequency: 3 - dataset_list: - - name: synthetic_test_train_dataset - source: nbv1 - data_type: train - eval_list: - ratio: 0.00001 - batch_size: 16 - num_workers: 16 - - pipeline: - pts_encoder: pointnet - view_finder: gradient_field - -datasets: - general: - data_dir: "../data" - -modules: - general: - pts_channels: 3 - feature_dim: 1024 - per_point_feature: False - pts_encoder: - pointnet: - pointnet++: - params_name: light - pointnet++rgb: - params_name: light - target_layer: 3 - rgb_feat_dim: 384 - view_finder: - gradient_field: - pose_mode: rot_matrix - regression_head: Rx_Ry - sample_mode: ode - sample_repeat: 50 - sampling_steps: 500 - sde_mode: ve \ No newline at end of file diff --git a/datasets/dataset.py b/datasets/dataset.py deleted file mode 100644 index e58cf06..0000000 --- a/datasets/dataset.py +++ /dev/null @@ -1,35 +0,0 @@ -from abc import ABC - -import numpy as np - -from torch.utils.data import Dataset -from torch.utils.data import DataLoader, Subset - -class BaseDataset(ABC, Dataset): - def __init__(self, config): - super(BaseDataset, self).__init__() - self.config = config - - @staticmethod - def process_batch(batch, device): - for key in batch.keys(): - if isinstance(batch[key], list): - continue - batch[key] = batch[key].to(device) - return batch - - def get_loader(self, shuffle=False): - ratio = self.config["ratio"] - if ratio > 1 or ratio <= 0: - raise ValueError( - f"dataset ratio should be between (0,1], found {ratio} in {self.config['name']}" - ) - subset_size = int(len(self) * ratio) - indices = np.random.permutation(len(self))[:subset_size] - subset = Subset(self, indices) - return DataLoader( - subset, - batch_size=self.config["batch_size"], - num_workers=self.config["num_workers"], - shuffle=shuffle, - ) diff --git a/datasets/dataset_factory.py b/datasets/dataset_factory.py deleted file mode 100644 index f97d1df..0000000 --- a/datasets/dataset_factory.py +++ /dev/null @@ -1,30 +0,0 @@ -import sys -import os -path = os.path.abspath(__file__) -for i in range(2): - path = os.path.dirname(path) -PROJECT_ROOT = path -sys.path.append(PROJECT_ROOT) - -from datasets.dataset import BaseDataset - -class DatasetFactory: - @staticmethod - def create(config) -> BaseDataset: - pass - - -''' ------------ Debug ------------ ''' -if __name__ == "__main__": - - from configs.config import ConfigManager - - ConfigManager.load_config_with('/home/data/hofee/project/ActivePerception/ActivePerception/configs/server_train_config.yaml') - ConfigManager.print_config() - dataset = DatasetFactory.create(ConfigManager.get("settings", "test", "dataset_list")[1]) - 
print(len(dataset)) - data_test = dataset.__getitem__(107000) - print(data_test['src_path']) - import pickle - # with open("data_sample_new.pkl", "wb") as f: - # pickle.dump(data_test, f) diff --git a/evaluations/eval_function_factory.py b/evaluations/eval_function_factory.py deleted file mode 100644 index 89ef3b1..0000000 --- a/evaluations/eval_function_factory.py +++ /dev/null @@ -1,35 +0,0 @@ -from annotations.stereotype import evaluation_methods -import importlib -import pkgutil -import os - -package_name = os.path.dirname("evaluations") -package = importlib.import_module("evaluations") -for _, module_name, _ in pkgutil.walk_packages(package.__path__, package.__name__ + "."): - importlib.import_module(module_name) - -class EvalFunctionFactory: - @staticmethod - def create(eval_type_list): - def eval_func(output, data): - temp_results = {"scalars": {}, "points": {}, "images": {}} - for eval_type in eval_type_list: - if eval_type in evaluation_methods: - result = evaluation_methods[eval_type](output, data) - for k, v in result.items(): - temp_results[k].update(v) - results = {} - for k, v in temp_results.items(): - if len(v) > 0: - results[k] = v - return results - - return eval_func - - -''' ------------ Debug ------------ ''' -if __name__ == "__main__": - from configs.config import ConfigManager - - ConfigManager.load_config_with('../configs/local_train_config.yaml') - ConfigManager.print_config() diff --git a/losses/gf_loss.py b/losses/gf_loss.py new file mode 100644 index 0000000..a4320a2 --- /dev/null +++ b/losses/gf_loss.py @@ -0,0 +1,12 @@ +import torch +import PytorchBoot.stereotype as stereotype + +@stereotype.loss_function("gf_loss") +def compute_loss(output, data): + estimated_score = output['estimated_score'] + target_score = output['target_score'] + std = output['std'] + bs = estimated_score.shape[0] + loss_weighting = std ** 2 + loss = torch.mean(torch.sum((loss_weighting * (estimated_score - target_score) ** 2).view(bs, -1), dim=-1)) + return loss diff --git a/losses/loss_function_factory.py b/losses/loss_function_factory.py deleted file mode 100644 index ac1f817..0000000 --- a/losses/loss_function_factory.py +++ /dev/null @@ -1,12 +0,0 @@ -class LossFunctionFactory: - @staticmethod - def create(function_name): - raise ValueError("Unknown loss function {}".format(function_name)) - - -''' ------------ Debug ------------ ''' -if __name__ == "__main__": - from configs.config import ConfigManager - - ConfigManager.load_config_with('../configs/local_train_config.yaml') - ConfigManager.print_config() diff --git a/modules/func_lib/__init__.py b/modules/func_lib/__init__.py new file mode 100644 index 0000000..5d3879a --- /dev/null +++ b/modules/func_lib/__init__.py @@ -0,0 +1,7 @@ +from modules.func_lib.samplers import ( + cond_pc_sampler, + cond_ode_sampler +) +from modules.func_lib.sde import ( + init_sde +) diff --git a/modules/func_lib/samplers.py b/modules/func_lib/samplers.py new file mode 100644 index 0000000..cdefdcc --- /dev/null +++ b/modules/func_lib/samplers.py @@ -0,0 +1,280 @@ +import torch +import numpy as np + +from scipy import integrate +from utils.pose import PoseUtil + + +def global_prior_likelihood(z, sigma_max): + """The likelihood of a Gaussian distribution with mean zero and + standard deviation sigma.""" + # z: [bs, pose_dim] + shape = z.shape + N = np.prod(shape[1:]) # pose_dim + return -N / 2. 
* torch.log(2 * np.pi * sigma_max ** 2) - torch.sum(z ** 2, dim=-1) / (2 * sigma_max ** 2) + + +def cond_ode_likelihood( + score_model, + data, + prior, + sde_coeff, + marginal_prob_fn, + atol=1e-5, + rtol=1e-5, + device='cuda', + eps=1e-5, + num_steps=None, + pose_mode='quat_wxyz', + init_x=None, +): + pose_dim = PoseUtil.get_pose_dim(pose_mode) + batch_size = data['pts'].shape[0] + epsilon = prior((batch_size, pose_dim)).to(device) + init_x = data['sampled_pose'].clone().cpu().numpy() if init_x is None else init_x + shape = init_x.shape + init_logp = np.zeros((shape[0],)) # [bs] + init_inp = np.concatenate([init_x.reshape(-1), init_logp], axis=0) + + def score_eval_wrapper(data): + """A wrapper of the score-based model for use by the ODE solver.""" + with torch.no_grad(): + score = score_model(data) + return score.cpu().numpy().reshape((-1,)) + + def divergence_eval(data, epsilon): + """Compute the divergence of the score-based model with Skilling-Hutchinson.""" + # save ckpt of sampled_pose + origin_sampled_pose = data['sampled_pose'].clone() + with torch.enable_grad(): + # make sampled_pose differentiable + data['sampled_pose'].requires_grad_(True) + score = score_model(data) + score_energy = torch.sum(score * epsilon) # [, ] + grad_score_energy = torch.autograd.grad(score_energy, data['sampled_pose'])[0] # [bs, pose_dim] + # reset sampled_pose + data['sampled_pose'] = origin_sampled_pose + return torch.sum(grad_score_energy * epsilon, dim=-1) # [bs, 1] + + def divergence_eval_wrapper(data): + """A wrapper for evaluating the divergence of score for the black-box ODE solver.""" + with torch.no_grad(): + # Compute likelihood. + div = divergence_eval(data, epsilon) # [bs, 1] + return div.cpu().numpy().reshape((-1,)).astype(np.float64) + + def ode_func(t, inp): + """The ODE function for use by the ODE solver.""" + # split x, logp from inp + x = inp[:-shape[0]] + # calc x-grad + x = torch.tensor(x.reshape(-1, pose_dim), dtype=torch.float32, device=device) + time_steps = torch.ones(batch_size, device=device).unsqueeze(-1) * t + drift, diffusion = sde_coeff(torch.tensor(t)) + drift = drift.cpu().numpy() + diffusion = diffusion.cpu().numpy() + data['sampled_pose'] = x + data['t'] = time_steps + x_grad = drift - 0.5 * (diffusion ** 2) * score_eval_wrapper(data) + # calc logp-grad + logp_grad = drift - 0.5 * (diffusion ** 2) * divergence_eval_wrapper(data) + # concat curr grad + return np.concatenate([x_grad, logp_grad], axis=0) + + # Run the black-box ODE solver, note the + res = integrate.solve_ivp(ode_func, (eps, 1.0), init_inp, rtol=rtol, atol=atol, method='RK45') + zp = torch.tensor(res.y[:, -1], device=device) # [bs * (pose_dim + 1)] + z = zp[:-shape[0]].reshape(shape) # [bs, pose_dim] + delta_logp = zp[-shape[0]:].reshape(shape[0]) # [bs,] logp + _, sigma_max = marginal_prob_fn(None, torch.tensor(1.).to(device)) # we assume T = 1 + prior_logp = global_prior_likelihood(z, sigma_max) + log_likelihoods = (prior_logp + delta_logp) / np.log(2) # negative log-likelihoods (nlls) + return z, log_likelihoods + + +def cond_pc_sampler( + score_model, + data, + prior, + sde_coeff, + num_steps=500, + snr=0.16, + device='cuda', + eps=1e-5, + pose_mode='quat_wxyz', + init_x=None, +): + pose_dim = PoseUtil.get_pose_dim(pose_mode) + batch_size = data['target_pts_feat'].shape[0] + init_x = prior((batch_size, pose_dim)).to(device) if init_x is None else init_x + time_steps = torch.linspace(1., eps, num_steps, device=device) + step_size = time_steps[0] - time_steps[1] + noise_norm = np.sqrt(pose_dim) + x = 
init_x + poses = [] + with torch.no_grad(): + for time_step in time_steps: + batch_time_step = torch.ones(batch_size, device=device).unsqueeze(-1) * time_step + # Corrector step (Langevin MCMC) + data['sampled_pose'] = x + data['t'] = batch_time_step + grad = score_model(data) + grad_norm = torch.norm(grad.reshape(batch_size, -1), dim=-1).mean() + langevin_step_size = 2 * (snr * noise_norm / grad_norm) ** 2 + x = x + langevin_step_size * grad + torch.sqrt(2 * langevin_step_size) * torch.randn_like(x) + + # normalisation + if pose_mode == 'quat_wxyz' or pose_mode == 'quat_xyzw': + # quat, should be normalised + x[:, :4] /= torch.norm(x[:, :4], dim=-1, keepdim=True) + elif pose_mode == 'euler_xyz': + pass + else: + # rotation(x axis, y axis), should be normalised + x[:, :3] /= torch.norm(x[:, :3], dim=-1, keepdim=True) + x[:, 3:6] /= torch.norm(x[:, 3:6], dim=-1, keepdim=True) + + # Predictor step (Euler-Maruyama) + drift, diffusion = sde_coeff(batch_time_step) + drift = drift - diffusion ** 2 * grad # R-SDE + mean_x = x + drift * step_size + x = mean_x + diffusion * torch.sqrt(step_size) * torch.randn_like(x) + + # normalisation + x[:, :-3] = PoseUtil.normalize_rotation(x[:, :-3], pose_mode) + poses.append(x.unsqueeze(0)) + + xs = torch.cat(poses, dim=0) + xs[:, :, -3:] += data['pts_center'].unsqueeze(0).repeat(xs.shape[0], 1, 1) + mean_x[:, -3:] += data['pts_center'] + mean_x[:, :-3] = PoseUtil.normalize_rotation(mean_x[:, :-3], pose_mode) + # The last step does not include any noise + return xs.permute(1, 0, 2), mean_x + + +def cond_ode_sampler( + score_model, + data, + prior, + sde_coeff, + atol=1e-5, + rtol=1e-5, + device='cuda', + eps=1e-5, + T=1.0, + num_steps=None, + pose_mode='quat_wxyz', + denoise=True, + init_x=None, +): + pose_dim = PoseUtil.get_pose_dim(pose_mode) + batch_size = data['target_feat'].shape[0] + init_x = prior((batch_size, pose_dim), T=T).to(device) if init_x is None else init_x + prior((batch_size, pose_dim), + T=T).to(device) + shape = init_x.shape + + def score_eval_wrapper(data): + """A wrapper of the score-based model for use by the ODE solver.""" + with torch.no_grad(): + score = score_model(data) + return score.cpu().numpy().reshape((-1,)) + + def ode_func(t, x): + """The ODE function for use by the ODE solver.""" + x = torch.tensor(x.reshape(-1, pose_dim), dtype=torch.float32, device=device) + time_steps = torch.ones(batch_size, device=device).unsqueeze(-1) * t + drift, diffusion = sde_coeff(torch.tensor(t)) + drift = drift.cpu().numpy() + diffusion = diffusion.cpu().numpy() + data['sampled_pose'] = x + data['t'] = time_steps + return drift - 0.5 * (diffusion ** 2) * score_eval_wrapper(data) + + # Run the black-box ODE solver, note the + t_eval = None + if num_steps is not None: + # num_steps, from T -> eps + t_eval = np.linspace(T, eps, num_steps) + res = integrate.solve_ivp(ode_func, (T, eps), init_x.reshape(-1).cpu().numpy(), rtol=rtol, atol=atol, method='RK45', + t_eval=t_eval) + xs = torch.tensor(res.y, device=device).T.view(-1, batch_size, pose_dim) # [num_steps, bs, pose_dim] + x = torch.tensor(res.y[:, -1], device=device).reshape(shape) # [bs, pose_dim] + # denoise, using the predictor step in P-C sampler + if denoise: + # Reverse diffusion predictor for denoising + vec_eps = torch.ones((x.shape[0], 1), device=x.device) * eps + drift, diffusion = sde_coeff(vec_eps) + data['sampled_pose'] = x.float() + data['t'] = vec_eps + grad = score_model(data) + drift = drift - diffusion ** 2 * grad # R-SDE + mean_x = x + drift * ((1 - eps) / (1000 if num_steps 
is None else num_steps))
+        x = mean_x
+
+    num_steps = xs.shape[0]
+    xs = xs.reshape(batch_size * num_steps, -1)
+    xs = PoseUtil.normalize_rotation(xs, pose_mode)
+    xs = xs.reshape(num_steps, batch_size, -1)
+    x = PoseUtil.normalize_rotation(x, pose_mode)
+    return xs.permute(1, 0, 2), x
+
+
+def cond_edm_sampler(
+        decoder_model, data, prior_fn, randn_like=torch.randn_like,
+        num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
+        S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
+        pose_mode='quat_wxyz', device='cuda'
+):
+    pose_dim = PoseUtil.get_pose_dim(pose_mode)
+    batch_size = data['pts'].shape[0]
+    latents = prior_fn((batch_size, pose_dim)).to(device)
+
+    # Time step discretization. Note that sigma and t are interchangeable here.
+    step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
+    t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (
+            sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
+    t_steps = torch.cat([torch.as_tensor(t_steps), torch.zeros_like(t_steps[:1])])  # t_N = 0
+
+    def decoder_wrapper(decoder, data, x, t):
+        # save temp
+        x_, t_ = data['sampled_pose'], data['t']
+        # init data
+        data['sampled_pose'], data['t'] = x, t
+        # denoise
+        data, denoised = decoder(data)
+        # recover data
+        data['sampled_pose'], data['t'] = x_, t_
+        return denoised.to(torch.float64)
+
+    # Main sampling loop.
+    x_next = latents.to(torch.float64) * t_steps[0]
+    xs = []
+    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):  # 0, ..., N-1
+        x_cur = x_next
+
+        # Increase noise temporarily.
+        gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
+        t_hat = torch.as_tensor(t_cur + gamma * t_cur)
+        x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur)
+
+        # Euler step.
+        denoised = decoder_wrapper(decoder_model, data, x_hat, t_hat)
+        d_cur = (x_hat - denoised) / t_hat
+        x_next = x_hat + (t_next - t_hat) * d_cur
+
+        # Apply 2nd order correction.
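+        # (Heun's method: re-evaluate the slope at t_next and average it with
+        # d_cur; skipped on the final step, where t_next = 0 would make the
+        # division in d_prime ill-defined.)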
+ if i < num_steps - 1: + denoised = decoder_wrapper(decoder_model, data, x_next, t_next) + d_prime = (x_next - denoised) / t_next + x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) + xs.append(x_next.unsqueeze(0)) + + xs = torch.stack(xs, dim=0) # [num_steps, bs, pose_dim] + x = xs[-1] # [bs, pose_dim] + + # post-processing + xs = xs.reshape(batch_size * num_steps, -1) + xs = PoseUtil.normalize_rotation(xs, pose_mode) + xs = xs.reshape(num_steps, batch_size, -1) + x = PoseUtil.normalize_rotation(x, pose_mode) + return xs.permute(1, 0, 2), x diff --git a/modules/func_lib/sde.py b/modules/func_lib/sde.py new file mode 100644 index 0000000..d93c999 --- /dev/null +++ b/modules/func_lib/sde.py @@ -0,0 +1,121 @@ +import functools +import torch +import numpy as np + + +# ----- VE SDE ----- +# ------------------ +def ve_marginal_prob(x, t, sigma_min=0.01, sigma_max=90): + std = sigma_min * (sigma_max / sigma_min) ** t + mean = x + return mean, std + + +def ve_sde(t, sigma_min=0.01, sigma_max=90): + sigma = sigma_min * (sigma_max / sigma_min) ** t + drift_coeff = torch.tensor(0) + diffusion_coeff = sigma * torch.sqrt(torch.tensor(2 * (np.log(sigma_max) - np.log(sigma_min)), device=t.device)) + return drift_coeff, diffusion_coeff + + +def ve_prior(shape, sigma_min=0.01, sigma_max=90, T=1.0): + _, sigma_max_prior = ve_marginal_prob(None, T, sigma_min=sigma_min, sigma_max=sigma_max) + return torch.randn(*shape) * sigma_max_prior + + +# ----- VP SDE ----- +# ------------------ +def vp_marginal_prob(x, t, beta_0=0.1, beta_1=20): + log_mean_coeff = -0.25 * t ** 2 * (beta_1 - beta_0) - 0.5 * t * beta_0 + mean = torch.exp(log_mean_coeff) * x + std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff)) + return mean, std + + +def vp_sde(t, beta_0=0.1, beta_1=20): + beta_t = beta_0 + t * (beta_1 - beta_0) + drift_coeff = -0.5 * beta_t + diffusion_coeff = torch.sqrt(beta_t) + return drift_coeff, diffusion_coeff + + +def vp_prior(shape, beta_0=0.1, beta_1=20): + return torch.randn(*shape) + + +# ----- sub-VP SDE ----- +# ---------------------- +def subvp_marginal_prob(x, t, beta_0, beta_1): + log_mean_coeff = -0.25 * t ** 2 * (beta_1 - beta_0) - 0.5 * t * beta_0 + mean = torch.exp(log_mean_coeff) * x + std = 1 - torch.exp(2. * log_mean_coeff) + return mean, std + + +def subvp_sde(t, beta_0, beta_1): + beta_t = beta_0 + t * (beta_1 - beta_0) + drift_coeff = -0.5 * beta_t + discount = 1. 
- torch.exp(-2 * beta_0 * t - (beta_1 - beta_0) * t ** 2) + diffusion_coeff = torch.sqrt(beta_t * discount) + return drift_coeff, diffusion_coeff + + +def subvp_prior(shape, beta_0=0.1, beta_1=20): + return torch.randn(*shape) + + +# ----- EDM SDE ----- +# ------------------ +def edm_marginal_prob(x, t, sigma_min=0.002, sigma_max=80): + std = t + mean = x + return mean, std + + +def edm_sde(t, sigma_min=0.002, sigma_max=80): + drift_coeff = torch.tensor(0) + diffusion_coeff = torch.sqrt(2 * t) + return drift_coeff, diffusion_coeff + + +def edm_prior(shape, sigma_min=0.002, sigma_max=80): + return torch.randn(*shape) * sigma_max + + +def init_sde(sde_mode): + # the SDE-related hyperparameters are copied from https://github.com/yang-song/score_sde_pytorch + if sde_mode == 'edm': + sigma_min = 0.002 + sigma_max = 80 + eps = 0.002 + prior_fn = functools.partial(edm_prior, sigma_min=sigma_min, sigma_max=sigma_max) + marginal_prob_fn = functools.partial(edm_marginal_prob, sigma_min=sigma_min, sigma_max=sigma_max) + sde_fn = functools.partial(edm_sde, sigma_min=sigma_min, sigma_max=sigma_max) + T = sigma_max + elif sde_mode == 've': + sigma_min = 0.01 + sigma_max = 50 + eps = 1e-5 + marginal_prob_fn = functools.partial(ve_marginal_prob, sigma_min=sigma_min, sigma_max=sigma_max) + sde_fn = functools.partial(ve_sde, sigma_min=sigma_min, sigma_max=sigma_max) + T = 1.0 + prior_fn = functools.partial(ve_prior, sigma_min=sigma_min, sigma_max=sigma_max) + elif sde_mode == 'vp': + beta_0 = 0.1 + beta_1 = 20 + eps = 1e-3 + prior_fn = functools.partial(vp_prior, beta_0=beta_0, beta_1=beta_1) + marginal_prob_fn = functools.partial(vp_marginal_prob, beta_0=beta_0, beta_1=beta_1) + sde_fn = functools.partial(vp_sde, beta_0=beta_0, beta_1=beta_1) + T = 1.0 + elif sde_mode == 'subvp': + beta_0 = 0.1 + beta_1 = 20 + eps = 1e-3 + prior_fn = functools.partial(subvp_prior, beta_0=beta_0, beta_1=beta_1) + marginal_prob_fn = functools.partial(subvp_marginal_prob, beta_0=beta_0, beta_1=beta_1) + sde_fn = functools.partial(subvp_sde, beta_0=beta_0, beta_1=beta_1) + T = 1.0 + else: + raise NotImplementedError + return prior_fn, marginal_prob_fn, sde_fn, eps, T diff --git a/modules/module_lib/gaussian_fourier_projection.py b/modules/module_lib/gaussian_fourier_projection.py new file mode 100644 index 0000000..13a7e4a --- /dev/null +++ b/modules/module_lib/gaussian_fourier_projection.py @@ -0,0 +1,17 @@ +import torch +import numpy as np +import torch.nn as nn + + +class GaussianFourierProjection(nn.Module): + """Gaussian random features for encoding time steps.""" + + def __init__(self, embed_dim, scale=30.): + super().__init__() + # Randomly sample weights during initialization. These weights are fixed + # during optimization and are not trainable. 
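+        # forward() below maps each scalar t to [sin(2*pi*t*W), cos(2*pi*t*W)],
+        # i.e. embed_dim fixed random Fourier features of the time step;
+        # `scale` sets the frequency bandwidth of this embedding.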
+ self.W = nn.Parameter(torch.randn(embed_dim // 2) * scale, requires_grad=False) + + def forward(self, x): + x_proj = x[:, None] * self.W[None, :] * 2 * np.pi + return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) diff --git a/modules/module_lib/linear.py b/modules/module_lib/linear.py new file mode 100644 index 0000000..e79e498 --- /dev/null +++ b/modules/module_lib/linear.py @@ -0,0 +1,30 @@ +import torch +import numpy as np + + +def weight_init(shape, mode, fan_in, fan_out): + if mode == 'xavier_uniform': + return np.sqrt(6 / (fan_in + fan_out)) * (torch.rand(*shape) * 2 - 1) + if mode == 'xavier_normal': + return np.sqrt(2 / (fan_in + fan_out)) * torch.randn(*shape) + if mode == 'kaiming_uniform': + return np.sqrt(3 / fan_in) * (torch.rand(*shape) * 2 - 1) + if mode == 'kaiming_normal': + return np.sqrt(1 / fan_in) * torch.randn(*shape) + raise ValueError(f'Invalid init mode "{mode}"') + + +class Linear(torch.nn.Module): + def __init__(self, in_features, out_features, bias=True, init_mode='kaiming_normal', init_weight=1, init_bias=0): + super().__init__() + self.in_features = in_features + self.out_features = out_features + init_kwargs = dict(mode=init_mode, fan_in=in_features, fan_out=out_features) + self.weight = torch.nn.Parameter(weight_init([out_features, in_features], **init_kwargs) * init_weight) + self.bias = torch.nn.Parameter(weight_init([out_features], **init_kwargs) * init_bias) if bias else None + + def forward(self, x): + x = x @ self.weight.to(x.dtype).t() + if self.bias is not None: + x = x.add_(self.bias.to(x.dtype)) + return x diff --git a/modules/pipeline.py b/modules/pipeline.py deleted file mode 100644 index d775742..0000000 --- a/modules/pipeline.py +++ /dev/null @@ -1,20 +0,0 @@ -from torch import nn -from configs.config import ConfigManager - - -class Pipeline(nn.Module): - TRAIN_MODE: str = "train" - TEST_MODE: str = "test" - - def __init__(self, pipeline_config): - super(Pipeline, self).__init__() - - self.modules_config = ConfigManager.get("modules") - self.device = ConfigManager.get("settings", "general", "device") - - def forward(self, data, mode): - pass - - -if __name__ == '__main__': - pass diff --git a/modules/pts_encoder/abstract_pts_encoder.py b/modules/pts_encoder/abstract_pts_encoder.py new file mode 100644 index 0000000..a7e33ab --- /dev/null +++ b/modules/pts_encoder/abstract_pts_encoder.py @@ -0,0 +1,12 @@ +from abc import abstractmethod + +from torch import nn + + +class PointsEncoder(nn.Module): + def __init__(self): + super(PointsEncoder, self).__init__() + + @abstractmethod + def encode_points(self, pts): + pass diff --git a/modules/pts_encoder/pointnet_encoder.py b/modules/pts_encoder/pointnet_encoder.py new file mode 100644 index 0000000..8e30261 --- /dev/null +++ b/modules/pts_encoder/pointnet_encoder.py @@ -0,0 +1,110 @@ +from __future__ import print_function +import torch +import torch.nn as nn +import torch.nn.parallel +import torch.utils.data +from torch.autograd import Variable +import numpy as np +import torch.nn.functional as F + +from modules.pts_encoder.abstract_pts_encoder import PointsEncoder +import PytorchBoot.stereotype as stereotype + +class STNkd(nn.Module): + def __init__(self, k=64): + super(STNkd, self).__init__() + self.conv1 = torch.nn.Conv1d(k, 64, 1) + self.conv2 = torch.nn.Conv1d(64, 128, 1) + self.conv3 = torch.nn.Conv1d(128, 1024, 1) + self.fc1 = nn.Linear(1024, 512) + self.fc2 = nn.Linear(512, 256) + self.fc3 = nn.Linear(256, k * k) + self.relu = nn.ReLU() + + self.k = k + + def forward(self, x): + 
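        # x: [bs, k, n_pts]; predict a [bs, k, k] input-alignment transform
+        # (the flattened eye(k) added below biases fc3 toward the identity).
+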
batchsize = x.size()[0] + x = F.relu(self.conv1(x)) + x = F.relu(self.conv2(x)) + x = F.relu(self.conv3(x)) + x = torch.max(x, 2, keepdim=True)[0] + x = x.view(-1, 1024) + + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + + iden = ( + Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32))) + .view(1, self.k * self.k) + .repeat(batchsize, 1) + ) + if x.is_cuda: + iden = iden.to(x.get_device()) + x = x + iden + x = x.view(-1, self.k, self.k) + return x + + +@stereotype.module("pointnet_encoder") +class PointNetEncoder(PointsEncoder): + + def __init__(self, global_feat=True, in_dim=3, out_dim=1024, feature_transform=False): + super(PointNetEncoder, self).__init__() + self.out_dim = out_dim + self.feature_transform = feature_transform + self.stn = STNkd(k=in_dim) + self.conv1 = torch.nn.Conv1d(in_dim, 64, 1) + self.conv2 = torch.nn.Conv1d(64, 128, 1) + self.conv3 = torch.nn.Conv1d(128, 512, 1) + self.conv4 = torch.nn.Conv1d(512, out_dim, 1) + self.global_feat = global_feat + if self.feature_transform: + self.f_stn = STNkd(k=64) + + def forward(self, x): + n_pts = x.shape[2] + trans = self.stn(x) + x = x.transpose(2, 1) + x = torch.bmm(x, trans) + x = x.transpose(2, 1) + x = F.relu(self.conv1(x)) + + if self.feature_transform: + trans_feat = self.f_stn(x) + x = x.transpose(2, 1) + x = torch.bmm(x, trans_feat) + x = x.transpose(2, 1) + + point_feat = x + x = F.relu(self.conv2(x)) + x = F.relu(self.conv3(x)) + x = self.conv4(x) + x = torch.max(x, 2, keepdim=True)[0] + x = x.view(-1, self.out_dim) + if self.global_feat: + return x + else: + x = x.view(-1, self.out_dim, 1).repeat(1, 1, n_pts) + return torch.cat([x, point_feat], 1) + + def encode_points(self, pts): + pts = pts.transpose(2, 1) + if not self.global_feat: + pts_feature = self(pts).transpose(2, 1) + else: + pts_feature = self(pts) + return pts_feature + + +if __name__ == "__main__": + sim_data = Variable(torch.rand(32, 2500, 3)) + + pointnet_global = PointNetEncoder(global_feat=True) + out = pointnet_global.encode_points(sim_data) + print("global feat", out.size()) + + pointnet = PointNetEncoder(global_feat=False) + out = pointnet.encode_points(sim_data) + print("point feat", out.size()) diff --git a/modules/view_finder/abstract_view_finder.py b/modules/view_finder/abstract_view_finder.py new file mode 100644 index 0000000..b688c16 --- /dev/null +++ b/modules/view_finder/abstract_view_finder.py @@ -0,0 +1,12 @@ +from abc import abstractmethod + +from torch import nn + + +class ViewFinder(nn.Module): + def __init__(self): + super(ViewFinder, self).__init__() + + @abstractmethod + def next_best_view(self, scene_pts_feat, target_pts_feat): + pass diff --git a/modules/view_finder/gf_view_finder.py b/modules/view_finder/gf_view_finder.py new file mode 100644 index 0000000..030760d --- /dev/null +++ b/modules/view_finder/gf_view_finder.py @@ -0,0 +1,168 @@ +import torch +import torch.nn as nn +import PytorchBoot.stereotype as stereotype + +from utils.pose import PoseUtil +from modules.view_finder.abstract_view_finder import ViewFinder +import modules.module_lib as mlib +import modules.func_lib as flib + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. 
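+    Used on the last Linear of each regression head below so the predicted
+    score starts at exactly zero, a common initialisation for diffusion-style
+    score heads.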
+ """ + for p in module.parameters(): + p.detach().zero_() + return module + + +@stereotype.module("gf_view_finder") +class GradientFieldViewFinder(ViewFinder): + def __init__(self, pose_mode='rot_matrix', regression_head='Rx_Ry', per_point_feature=False, + sample_mode="ode", sampling_steps=None, sde_mode="ve"): + + super(GradientFieldViewFinder, self).__init__() + self.regression_head = regression_head + self.per_point_feature = per_point_feature + self.act = nn.ReLU(True) + self.sample_mode = sample_mode + self.pose_mode = pose_mode + pose_dim = PoseUtil.get_pose_dim(pose_mode) + self.prior_fn, self.marginal_prob_fn, self.sde_fn, self.sampling_eps, self.T = flib.init_sde(sde_mode) + self.sampling_steps = sampling_steps + + ''' encode pose ''' + self.pose_encoder = nn.Sequential( + nn.Linear(pose_dim, 256), + self.act, + nn.Linear(256, 256), + self.act, + ) + + ''' encode t ''' + self.t_encoder = nn.Sequential( + mlib.GaussianFourierProjection(embed_dim=128), + nn.Linear(128, 128), + self.act, + ) + + ''' fusion tail ''' + if self.regression_head == 'Rx_Ry': + if pose_mode != 'rot_matrix': + raise NotImplementedError + if not per_point_feature: + ''' rotation_x_axis regress head ''' + self.fusion_tail_rot_x = nn.Sequential( + nn.Linear(128 + 256 + 1024 + 1024, 256), + self.act, + zero_module(nn.Linear(256, 3)), + ) + self.fusion_tail_rot_y = nn.Sequential( + nn.Linear(128 + 256 + 1024 + 1024, 256), + self.act, + zero_module(nn.Linear(256, 3)), + ) + else: + raise NotImplementedError + else: + raise NotImplementedError + + def forward(self, data): + """ + Args: + data, dict { + 'target_pts_feat': [bs, c] + 'scene_pts_feat': [bs, c] + 'pose_sample': [bs, pose_dim] + 't': [bs, 1] + } + """ + + scene_pts_feat = data['scene_feat'] + target_pts_feat = data['target_feat'] + sampled_pose = data['sampled_pose'] + t = data['t'] + t_feat = self.t_encoder(t.squeeze(1)) + pose_feat = self.pose_encoder(sampled_pose) + + if self.per_point_feature: + raise NotImplementedError + else: + total_feat = torch.cat([scene_pts_feat, target_pts_feat, t_feat, pose_feat], dim=-1) + _, std = self.marginal_prob_fn(total_feat, t) + + if self.regression_head == 'Rx_Ry': + rot_x = self.fusion_tail_rot_x(total_feat) + rot_y = self.fusion_tail_rot_y(total_feat) + out_score = torch.cat([rot_x, rot_y], dim=-1) / (std + 1e-7) # normalisation + else: + raise NotImplementedError + + return out_score + + def marginal_prob(self, x, t): + return self.marginal_prob_fn(x,t) + + def sample(self, data, atol=1e-5, rtol=1e-5, snr=0.16, denoise=True, init_x=None, T0=None): + + if self.sample_mode == 'pc': + in_process_sample, res = flib.cond_pc_sampler( + score_model=self, + data=data, + prior=self.prior_fn, + sde_coeff=self.sde_fn, + num_steps=self.sampling_steps, + snr=snr, + eps=self.sampling_eps, + pose_mode=self.pose_mode, + init_x=init_x + ) + + elif self.sample_mode == 'ode': + T0 = self.T if T0 is None else T0 + in_process_sample, res = flib.cond_ode_sampler( + score_model=self, + data=data, + prior=self.prior_fn, + sde_coeff=self.sde_fn, + atol=atol, + rtol=rtol, + eps=self.sampling_eps, + T=T0, + num_steps=self.sampling_steps, + pose_mode=self.pose_mode, + denoise=denoise, + init_x=init_x + ) + else: + raise NotImplementedError + + return in_process_sample, res + + def next_best_view(self, scene_pts_feat, target_pts_feat): + data = { + 'scene_feat': scene_pts_feat, + 'target_feat': target_pts_feat, + } + in_process_sample, res = self.sample(data) + return res.to(dtype=torch.float32), in_process_sample + + +''' ----------- 
DEBUG -----------''' +if __name__ == "__main__": + test_scene_feat = torch.rand(32, 1024).to("cuda:0") + test_target_feat = torch.rand(32, 1024).to("cuda:0") + test_pose = torch.rand(32, 6).to("cuda:0") + test_t = torch.rand(32, 1).to("cuda:0") + view_finder = GradientFieldViewFinder().to("cuda:0") + test_data = { + 'target_feat': test_target_feat, + 'scene_feat': test_scene_feat, + 'sampled_pose': test_pose, + 't': test_t + } + score = view_finder(test_data) + + result = view_finder.next_best_view(test_scene_feat, test_target_feat) + print(result) diff --git a/optimizers/optimizer_factory.py b/optimizers/optimizer_factory.py deleted file mode 100644 index 4c5e1cf..0000000 --- a/optimizers/optimizer_factory.py +++ /dev/null @@ -1,32 +0,0 @@ -import torch.optim as optim - - -class OptimizerFactory: - @staticmethod - def create(config, params): - optim_type = config["type"] - lr = config.get("lr", 1e-3) - if optim_type == "sgd": - return optim.SGD( - params, - lr=lr, - momentum=config.get("momentum", 0.9), - weight_decay=config.get("weight_decay", 1e-4), - ) - elif optim_type == "adam": - return optim.Adam( - params, - lr=lr, - betas=config.get("betas", (0.9, 0.999)), - eps=config.get("eps", 1e-8), - ) - else: - raise NotImplementedError("Unknown optimizers: {}".format(optim_type)) - - -""" ------------ Debug ------------ """ -if __name__ == "__main__": - from configs.config import ConfigManager - - ConfigManager.load_config_with("../configs/local_train_config.yaml") - ConfigManager.print_config() diff --git a/runners/runner.py b/runners/runner.py deleted file mode 100644 index 5c98c00..0000000 --- a/runners/runner.py +++ /dev/null @@ -1,59 +0,0 @@ -import os -import time -from abc import abstractmethod, ABC - -import numpy as np -import torch - -from configs.config import ConfigManager - -class Runner(ABC): - - @abstractmethod - def __init__(self, config_path): - ConfigManager.load_config_with(config_path) - ConfigManager.print_config() - seed = ConfigManager.get("settings", "general", "seed") - self.device = ConfigManager.get("settings", "general", "device") - self.cuda_visible_devices = ConfigManager.get("settings","general","cuda_visible_devices") - os.environ["CUDA_VISIBLE_DEVICES"] = self.cuda_visible_devices - self.experiments_config = ConfigManager.get("settings", "experiment") - self.experiment_path = os.path.join(self.experiments_config["root_dir"], self.experiments_config["name"]) - np.random.seed(seed) - torch.manual_seed(seed) - lt = time.localtime() - self.file_name = f"{lt.tm_year}_{lt.tm_mon}_{lt.tm_mday}_{lt.tm_hour}h{lt.tm_min}m{lt.tm_sec}s" - - @abstractmethod - def run(self): - pass - - @abstractmethod - def load_experiment(self, backup_name=None): - if not os.path.exists(self.experiment_path): - print(f"experiments environment {self.experiments_config['name']} does not exists.") - self.create_experiment(backup_name) - else: - print(f"experiments environment {self.experiments_config['name']}") - backup_config_dir = os.path.join(str(self.experiment_path), "configs") - if not os.path.exists(backup_config_dir): - os.makedirs(backup_config_dir) - ConfigManager.backup_config_to(backup_config_dir, self.file_name, backup_name) - - @abstractmethod - def create_experiment(self, backup_name=None): - print("creating experiment: " + self.experiments_config["name"]) - os.makedirs(self.experiment_path) - backup_config_dir = os.path.join(str(self.experiment_path), "configs") - os.makedirs(backup_config_dir) - ConfigManager.backup_config_to(backup_config_dir, self.file_name, backup_name) 
- log_dir = os.path.join(str(self.experiment_path), "log") - os.makedirs(log_dir) - cache_dir = os.path.join(str(self.experiment_path), "cache") - os.makedirs(cache_dir) - - def print_info(self): - table_size = 80 - print("+" + "-" * table_size + "+") - print(f"| Experiment <{self.experiments_config['name']}>") - print("+" + "-" * table_size + "+") diff --git a/runners/strategy_generator.py b/runners/strategy_generator.py new file mode 100644 index 0000000..adc1888 --- /dev/null +++ b/runners/strategy_generator.py @@ -0,0 +1,73 @@ +import os +from PytorchBoot.runners.runner import Runner +from PytorchBoot.config import ConfigManager +import PytorchBoot.stereotype as stereotype + +@stereotype.runner("strategy_generator") +class StrategyGenerator(Runner): + def __init__(self, config): + super().__init__(config) + self.load_experiment("generate") + + def run(self): + self.demo(seq=16,num=100) + + def create_experiment(self, backup_name=None): + super().create_experiment(backup_name) + output_dir = os.path.join(str(self.experiment_path), "output") + os.makedirs(output_dir) + + def load_experiment(self, backup_name=None): + super().load_experiment(backup_name) + + def demo(self, seq, num=100): + import os + from utils.data_load import DataLoadUtil + from utils.reconstruction import ReconstructionUtil + import numpy as np + + component = self.config["generate"][0]["component"] #r"C:\Document\Local Project\nbv_rec\output" + data_dir = ConfigManager.get("datasets", "components", component, "root_dir") + model_path = os.path.join(data_dir, f"sequence.{seq}\\world_points.txt") + model_pts = np.loadtxt(model_path) + + output_dir = os.path.join(str(self.experiment_path), "output") + pts_list = [] + for idx in range(0,num): + path = DataLoadUtil.get_path(data_dir, seq, idx) + point_cloud = DataLoadUtil.get_point_cloud_world_from_path(path) + + sampled_point_cloud = ReconstructionUtil.downsample_point_cloud(point_cloud, 0.005) + pts_list.append(sampled_point_cloud) + + sampled_model_pts = ReconstructionUtil.downsample_point_cloud(model_pts, 0.005) + np.savetxt(os.path.join(output_dir,"sampled_model_points.txt"), sampled_model_pts) + thre = 0.005 + + useful_view, useless_view = ReconstructionUtil.compute_next_best_view_sequence(model_pts, pts_list, threshold=thre) + print("useful:", useful_view) + print("useless:", useless_view) + + selected_full_views = ReconstructionUtil.combine_point_with_view_sequence(pts_list, useful_view) + downsampled_selected_full_views = ReconstructionUtil.downsample_point_cloud(selected_full_views, thre) + np.savetxt(os.path.join(output_dir,"selected_full_views.txt"), downsampled_selected_full_views) + + limited_useful_view, limited_useless_view = ReconstructionUtil.compute_next_best_view_sequence_with_overlap(model_pts, pts_list, threshold=thre, overlap_threshold=0.3) + print("limited_useful:", limited_useful_view) + print("limited_useless:", limited_useless_view) + + limited_selected_full_views = ReconstructionUtil.combine_point_with_view_sequence(pts_list, limited_useful_view) + downsampled_limited_selected_full_views = ReconstructionUtil.downsample_point_cloud(limited_selected_full_views, thre) + np.savetxt(os.path.join(output_dir,"selected_full_views_limited.txt"), downsampled_limited_selected_full_views) + import json + for idx, score in limited_useful_view: + path = DataLoadUtil.get_path(data_dir, seq, idx) + point_cloud = DataLoadUtil.get_point_cloud_world_from_path(path) + print("saving useful view: ", idx, " | score: ", score) + 
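            # one Nx3 world-space point cloud is written per selected view;
+            # the (idx, score) pairs are also dumped to useful_view.json below
+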
np.savetxt(os.path.join(output_dir,f"useful_view_{idx}.txt"), point_cloud) + with open(os.path.join(output_dir,f"useful_view.json"), 'w') as f: + json.dump(limited_useful_view, f) + print("seq length: ", len(useful_view), "limited seq length: ", len(limited_useful_view)) + + + \ No newline at end of file diff --git a/runners/trainer.py b/runners/trainer.py deleted file mode 100644 index d32b993..0000000 --- a/runners/trainer.py +++ /dev/null @@ -1,261 +0,0 @@ -import os -import sys -import json -from datetime import datetime - -import torch -from tqdm import tqdm -from torch.utils.tensorboard import SummaryWriter - -path = os.path.abspath(__file__) -for i in range(2): - path = os.path.dirname(path) -PROJECT_ROOT = path -sys.path.append(PROJECT_ROOT) - -from configs.config import ConfigManager -from datasets.dataset_factory import DatasetFactory -from optimizers.optimizer_factory import OptimizerFactory -from evaluations.eval_function_factory import EvalFunctionFactory -from losses.loss_function_factory import LossFunctionFactory -from modules.pipeline import Pipeline -from runners.runner import Runner -from utils.tensorboard_util import TensorboardWriter -from annotations.external_module import EXTERNAL_FREEZE_MODULES - - -class Trainer(Runner): - CHECKPOINT_DIR_NAME: str = 'checkpoints' - TENSORBOARD_DIR_NAME: str = 'tensorboard' - LOG_DIR_NAME: str = 'log' - - def __init__(self, config_path): - super().__init__(config_path) - tensorboard_path = os.path.join(self.experiment_path, Trainer.TENSORBOARD_DIR_NAME) - - ''' Pipeline ''' - self.pipeline_config = ConfigManager.get("settings", "pipeline") - self.parallel = ConfigManager.get("settings","general","parallel") - self.pipeline = Pipeline(self.pipeline_config) - if self.parallel and self.device == "cuda": - self.pipeline = torch.nn.DataParallel(self.pipeline) - self.pipeline = self.pipeline.to(self.device) - - ''' Experiment ''' - self.current_epoch = 0 - self.max_epochs = self.experiments_config["max_epochs"] - self.test_first = self.experiments_config["test_first"] - self.load_experiment("train") - - ''' Train ''' - self.train_config = ConfigManager.get("settings", "train") - self.train_dataset_config = self.train_config["dataset"] - self.train_set = DatasetFactory.create(self.train_dataset_config) - self.optimizer = OptimizerFactory.create(self.train_config["optimizer"], self.pipeline.parameters()) - self.train_writer = SummaryWriter( - log_dir=os.path.join(tensorboard_path, f"[train]{self.train_dataset_config['name']}")) - - ''' Test ''' - self.test_config = ConfigManager.get("settings", "test") - self.test_dataset_config_list = self.test_config["dataset_list"] - self.test_set_list = [] - self.test_writer_list = [] - seen_name = set() - for test_dataset_config in self.test_dataset_config_list: - if test_dataset_config["name"] not in seen_name: - seen_name.add(test_dataset_config["name"]) - else: - raise ValueError("Duplicate test dataset name: {}".format(test_dataset_config["name"])) - test_set = DatasetFactory.create(test_dataset_config) - test_writer = SummaryWriter( - log_dir=os.path.join(tensorboard_path, f"[test]{test_dataset_config['name']}")) - self.test_set_list.append(test_set) - self.test_writer_list.append(test_writer) - del seen_name - - self.print_info() - - def run(self): - save_interval = self.experiments_config["save_checkpoint_interval"] - if self.current_epoch != 0: - print("Continue training from epoch {}.".format(self.current_epoch)) - else: - print("Start training from initial model.") - if self.test_first: - 
print("Do test first.") - self.test() - while self.current_epoch < self.max_epochs: - self.current_epoch += 1 - self.train() - self.test() - if self.current_epoch % save_interval == 0: - self.save_checkpoint() - self.save_checkpoint(is_last=True) - - def train(self): - self.pipeline.train() - train_set_name = self.train_dataset_config["name"] - ratio = self.train_dataset_config["ratio"] - train_loader = self.train_set.get_loader(device="cuda", shuffle=True) - - loop = tqdm(enumerate(train_loader), total=len(train_loader)) - loader_length = len(train_loader) - for i, data in loop: - self.train_set.process_batch(data, self.device) - loss_dict = self.train_step(data) - loop.set_description( - f'Epoch [{self.current_epoch}/{self.max_epochs}] (Train: {train_set_name}, ratio={ratio})') - loop.set_postfix(loss=loss_dict) - curr_iters = (self.current_epoch - 1) * loader_length + i - TensorboardWriter.write_tensorboard(self.train_writer, "iter", loss_dict, curr_iters) - - def train_step(self, data): - self.optimizer.zero_grad() - output = self.pipeline(data, Pipeline.TRAIN_MODE) - total_loss, loss_dict = self.loss_fn(output, data) - total_loss.backward() - self.optimizer.step() - for k, v in loss_dict.items(): - loss_dict[k] = round(v, 5) - return loss_dict - - def loss_fn(self, output, data): - loss_config = self.train_config["losses"] - loss_dict = {} - total_loss = torch.tensor(0.0, dtype=torch.float32, device=self.device) - for key in loss_config: - weight = loss_config[key] - target_loss_fn = LossFunctionFactory.create(key) - loss = target_loss_fn(output, data) - loss_dict[key] = loss.item() - total_loss += weight * loss - - loss_dict['total_loss'] = total_loss.item() - return total_loss, loss_dict - - def test(self): - self.pipeline.eval() - with torch.no_grad(): - for dataset_idx, test_set in enumerate(self.test_set_list): - eval_list = self.test_dataset_config_list[dataset_idx]["eval_list"] - test_set_name = self.test_dataset_config_list[dataset_idx]["name"] - ratio = self.test_dataset_config_list[dataset_idx]["ratio"] - writer = self.test_writer_list[dataset_idx] - output_list = [] - data_list = [] - test_loader = test_set.get_loader("cpu") - loop = tqdm(enumerate(test_loader), total=int(len(test_loader))) - for i, data in loop: - test_set.process_batch(data, self.device) - output = self.pipeline(data, Pipeline.TEST_MODE) - output_list.append(output) - data_list.append(data) - loop.set_description( - f'Epoch [{self.current_epoch}/{self.max_epochs}] (Test: {test_set_name}, ratio={ratio})') - result_dict = self.eval_fn(output_list, data_list, eval_list) - TensorboardWriter.write_tensorboard(writer, "epoch", result_dict, self.current_epoch - 1) - - @staticmethod - def eval_fn(output_list, data_list, eval_list): - target_eval_fn = EvalFunctionFactory.create(eval_list) - result_dict = target_eval_fn(output_list, data_list) - return result_dict - - def get_checkpoint_path(self, is_last=False): - return os.path.join(self.experiment_path, Trainer.CHECKPOINT_DIR_NAME, - "Epoch_{}.pth".format( - self.current_epoch if self.current_epoch != -1 and not is_last else "last")) - - def load_checkpoint(self, is_last=False): - self.load(self.get_checkpoint_path(is_last)) - print(f"Loaded checkpoint from {self.get_checkpoint_path(is_last)}") - if is_last: - checkpoint_root = os.path.join(self.experiment_path, Trainer.CHECKPOINT_DIR_NAME) - meta_path = os.path.join(checkpoint_root, "meta.json") - if not os.path.exists(meta_path): - raise FileNotFoundError( - "No checkpoint meta.json file in the experiment 
{}".format(self.experiments_config["name"])) - file_path = os.path.join(checkpoint_root, "meta.json") - with open(file_path, "r") as f: - meta = json.load(f) - self.current_epoch = meta["last_epoch"] - - def save_checkpoint(self, is_last=False): - self.save(self.get_checkpoint_path(is_last)) - if not is_last: - print(f"Checkpoint at epoch {self.current_epoch} saved to {self.get_checkpoint_path(is_last)}") - else: - meta = { - "last_epoch": self.current_epoch, - "time": str(datetime.now()) - } - checkpoint_root = os.path.join(self.experiment_path, Trainer.CHECKPOINT_DIR_NAME) - file_path = os.path.join(checkpoint_root, "meta.json") - with open(file_path, "w") as f: - json.dump(meta, f) - - def load_experiment(self, backup_name=None): - super().load_experiment(backup_name) - if self.experiments_config["use_checkpoint"]: - self.current_epoch = self.experiments_config["epoch"] - self.load_checkpoint(is_last=(self.current_epoch == -1)) - - def create_experiment(self, backup_name=None): - super().create_experiment(backup_name) - ckpt_dir = os.path.join(str(self.experiment_path), Trainer.CHECKPOINT_DIR_NAME) - os.makedirs(ckpt_dir) - tensorboard_dir = os.path.join(str(self.experiment_path), Trainer.TENSORBOARD_DIR_NAME) - os.makedirs(tensorboard_dir) - - def load(self, path): - state_dict = torch.load(path) - if self.parallel: - self.pipeline.module.load_state_dict(state_dict) - else: - self.pipeline.load_state_dict(state_dict) - - def save(self, path): - if self.parallel: - state_dict = self.pipeline.module.state_dict() - else: - state_dict = self.pipeline.state_dict() - - for name, module in self.pipeline.named_modules(): - if module.__class__ in EXTERNAL_FREEZE_MODULES: - if name in state_dict: - del state_dict[name] - - torch.save(state_dict, path) - - - def print_info(self): - def print_dataset(config, dataset): - print("\t name: {}".format(config["name"])) - print("\t source: {}".format(config["source"])) - print("\t data_type: {}".format(config["data_type"])) - print("\t total_length: {}".format(len(dataset))) - print("\t ratio: {}".format(config["ratio"])) - print() - - super().print_info() - table_size = 70 - print(f"{'+' + '-' * (table_size // 2)} Pipeline {'-' * (table_size // 2)}" + '+') - print(self.pipeline) - print(f"{'+' + '-' * (table_size // 2)} Datasets {'-' * (table_size // 2)}" + '+') - print("train dataset: ") - print_dataset(self.train_dataset_config, self.train_set) - for i, test_dataset_config in enumerate(self.test_dataset_config_list): - print(f"test dataset {i}: ") - print_dataset(test_dataset_config, self.test_set_list[i]) - - print(f"{'+' + '-' * (table_size // 2)}----------{'-' * (table_size // 2)}" + '+') - - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument("--config", type=str, default="configs/train_config.yaml") - args = parser.parse_args() - trainer = Trainer(args.config) - trainer.run() diff --git a/utils/data_load.py b/utils/data_load.py new file mode 100644 index 0000000..c48726e --- /dev/null +++ b/utils/data_load.py @@ -0,0 +1,135 @@ +import os +import OpenEXR +import Imath +import numpy as np +import json +import cv2 + +class DataLoadUtil: + + @staticmethod + def get_path(root, scene_idx, frame_idx): + path = os.path.join(root, f"sequence.{scene_idx}", f"step{frame_idx}") + return path + + @staticmethod + def read_exr_depth(depth_path): + file = OpenEXR.InputFile(depth_path) + + dw = file.header()['dataWindow'] + width = dw.max.x - dw.min.x + 1 + height = dw.max.y - dw.min.y + 1 + + pix_type = 
Imath.PixelType(Imath.PixelType.FLOAT)
+        depth_map = np.frombuffer(file.channel('R', pix_type), dtype=np.float32)
+
+        depth_map.shape = (height, width)
+
+        return depth_map
+
+    @staticmethod
+    def load_depth(path):
+        depth_path = path + ".camera.Depth.exr"
+        depth_map = DataLoadUtil.read_exr_depth(depth_path)
+        return depth_map
+
+    @staticmethod
+    def load_rgb(path):
+        rgb_path = path + ".camera.png"
+        rgb_image = cv2.imread(rgb_path, cv2.IMREAD_COLOR)
+        return rgb_image
+
+    @staticmethod
+    def load_seg(path):
+        seg_path = path + ".camera.semantic segmentation.png"
+        seg_image = cv2.imread(seg_path, cv2.IMREAD_COLOR)
+        return seg_image
+
+    @staticmethod
+    def load_cam_info(path):
+        label_path = path + ".camera_params.json"
+        with open(label_path, 'r') as f:
+            label_data = json.load(f)
+        cam_transform = np.asarray(label_data['cam_to_world']).reshape(
+            (4, 4)
+        ).T
+
+        offset = np.asarray([
+            [1, 0, 0, 0],
+            [0, -1, 0, 0],
+            [0, 0, 1, 0],
+            [0, 0, 0, 1]])
+
+        cam_to_world = cam_transform @ offset
+
+        f_x = label_data['f_x']
+        f_y = label_data['f_y']
+        c_x = label_data['c_x']
+        c_y = label_data['c_y']
+        cam_intrinsic = np.array([[f_x, 0, c_x], [0, f_y, c_y], [0, 0, 1]])
+
+        return {
+            "cam_to_world": cam_to_world,
+            "cam_intrinsic": cam_intrinsic
+        }
+
+    @staticmethod
+    def get_target_point_cloud(depth, cam_intrinsic, cam_extrinsic, mask, target_mask_label=(255,255,255)):
+        h, w = depth.shape
+        i, j = np.meshgrid(np.arange(w), np.arange(h), indexing='xy')
+
+        z = depth
+        x = (i - cam_intrinsic[0, 2]) * z / cam_intrinsic[0, 0]
+        y = (j - cam_intrinsic[1, 2]) * z / cam_intrinsic[1, 1]
+
+        points_camera = np.stack((x, y, z), axis=-1).reshape(-1, 3)
+        points_camera_aug = np.concatenate([points_camera, np.ones((points_camera.shape[0], 1))], axis=-1)
+
+        points_world = np.dot(cam_extrinsic, points_camera_aug.T).T[:, :3]
+        mask = mask.reshape(-1, 3)
+        target_mask = np.all(mask == target_mask_label, axis=-1)
+        return {
+            "points_world": points_world[target_mask],
+            "points_camera": points_camera[target_mask]
+        }
+
+    @staticmethod
+    def get_point_cloud_world_from_path(path):
+        cam_info = DataLoadUtil.load_cam_info(path)
+        depth = DataLoadUtil.load_depth(path)
+        mask = DataLoadUtil.load_seg(path)
+        point_cloud = DataLoadUtil.get_target_point_cloud(depth, cam_info['cam_intrinsic'], cam_info['cam_to_world'], mask)
+        return point_cloud['points_world']
+
+    @staticmethod
+    def get_point_cloud_list_from_seq(root, seq_idx, num_frames):
+        point_cloud_list = []
+        for idx in range(num_frames):
+            path = DataLoadUtil.get_path(root, seq_idx, idx)
+            point_cloud = DataLoadUtil.get_point_cloud_world_from_path(path)
+            point_cloud_list.append(point_cloud)
+        return point_cloud_list
+    
\ No newline at end of file
diff --git a/utils/pose.py b/utils/pose.py
new
+
+    @staticmethod
+    def get_uniform_pose(
+        trans_min, trans_max, rot_min=0, rot_max=180, trans_unit="cm", debug=False
+    ):
+        translation = PoseUtil.get_uniform_translation(
+            trans_min, trans_max, trans_unit, debug
+        )
+        rotation = PoseUtil.get_uniform_rotation(rot_min, rot_max, debug)
+        pose = np.eye(4)
+        pose[:3, :3] = rotation
+        pose[:3, 3] = translation
+        return pose
+
+    @staticmethod
+    def get_n_uniform_pose(
+        trans_min,
+        trans_max,
+        rot_min=0,
+        rot_max=180,
+        n=1,
+        trans_unit="cm",
+        fix=None,
+        contain_canonical=True,
+        debug=False,
+    ):
+        if fix == PoseUtil.ROTATION:
+            translations = np.zeros((n, 3))
+            for i in range(n):
+                translations[i] = PoseUtil.get_uniform_translation(
+                    trans_min, trans_max, trans_unit, debug
+                )
+            if contain_canonical:
+                translations[0] = np.zeros(3)
+            rotations = PoseUtil.get_uniform_rotation(rot_min, rot_max, debug)
+        elif fix == PoseUtil.TRANSLATION:
+            rotations = np.zeros((n, 3, 3))
+            for i in range(n):
+                rotations[i] = PoseUtil.get_uniform_rotation(rot_min, rot_max, debug)
+            if contain_canonical:
+                rotations[0] = np.eye(3)
+            translations = PoseUtil.get_uniform_translation(
+                trans_min, trans_max, trans_unit, debug
+            )
+        else:
+            translations = np.zeros((n, 3))
+            rotations = np.zeros((n, 3, 3))
+            for i in range(n):
+                translations[i] = PoseUtil.get_uniform_translation(
+                    trans_min, trans_max, trans_unit, debug
+                )
+            for i in range(n):
+                rotations[i] = PoseUtil.get_uniform_rotation(rot_min, rot_max, debug)
+            if contain_canonical:
+                translations[0] = np.zeros(3)
+                rotations[0] = np.eye(3)
+
+        pose = np.eye(4)[np.newaxis, :].repeat(n, axis=0)
+        pose[:, :3, :3] = rotations
+        pose[:, :3, 3] = translations
+
+        return pose
+
+    @staticmethod
+    def get_n_uniform_pose_batch(
+        trans_min,
+        trans_max,
+        rot_min=0,
+        rot_max=180,
+        n=1,
+        batch_size=1,
+        trans_unit="cm",
+        fix=None,
+        contain_canonical=False,
+        debug=False,
+    ):
+        batch_poses = []
+        for _ in range(batch_size):
+            pose = PoseUtil.get_n_uniform_pose(
+                trans_min,
+                trans_max,
+                rot_min,
+                rot_max,
+                n,
+                trans_unit,
+                fix,
+                contain_canonical,
+                debug,
+            )
+            batch_poses.append(pose)
+        pose_batch = np.stack(batch_poses, axis=0)
+        return pose_batch
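+
+    # Example (illustrative): sample 2 poses per batch element that share one
+    # random translation but differ in rotation; the result has shape
+    # (batch_size, n, 4, 4).
+    #
+    #   poses = PoseUtil.get_n_uniform_pose_batch(
+    #       trans_min=[-25, -25, 10], trans_max=[25, 25, 60],
+    #       n=2, batch_size=4, fix=PoseUtil.TRANSLATION)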
+
+    @staticmethod
+    def get_uniform_scale(scale_min, scale_max, debug=False):
+        if isinstance(scale_min, list):
+            x_min, y_min, z_min = scale_min
+            x_max, y_max, z_max = scale_max
+        else:
+            x_min, y_min, z_min = scale_min, scale_min, scale_min
+            x_max, y_max, z_max = scale_max, scale_max, scale_max
+
+        x = np.random.uniform(x_min, x_max)
+        y = np.random.uniform(y_min, y_max)
+        z = np.random.uniform(z_min, z_max)
+        scale = np.array([x, y, z])
+        if debug:
+            print("uniform scale:", scale)
+        return scale
+
+    @staticmethod
+    def normalize_rotation(rotation, rotation_mode):
+        if rotation_mode == "quat_wxyz" or rotation_mode == "quat_xyzw":
+            rotation /= torch.norm(rotation, dim=-1, keepdim=True)
+        elif rotation_mode == "rot_matrix":
+            rot_matrix = PoseUtil.rotation_6d_to_matrix_tensor_batch(rotation)
+            rotation[:, :3] = rot_matrix[:, 0, :]
+            rotation[:, 3:6] = rot_matrix[:, 1, :]
+        elif rotation_mode == "euler_xyz_sx_cx":
+            rot_sin_theta = rotation[:, :3]
+            rot_cos_theta = rotation[:, 3:6]
+            theta = torch.atan2(rot_sin_theta, rot_cos_theta)
+            rotation[:, :3] = torch.sin(theta)
+            rotation[:, 3:6] = torch.cos(theta)
+        elif rotation_mode == "euler_xyz":
+            pass
+        else:
+            raise NotImplementedError
+        return rotation
+
+    @staticmethod
+    def get_pose_dim(rot_mode):
+        assert rot_mode in [
+            "quat_wxyz",
+            "quat_xyzw",
+            "euler_xyz",
+            "euler_xyz_sx_cx",
+            "rot_matrix",
+        ], f"the rotation mode {rot_mode} is not supported!"
+
+        if rot_mode == "quat_wxyz" or rot_mode == "quat_xyzw":
+            pose_dim = 4
+        elif rot_mode == "euler_xyz":
+            pose_dim = 3
+        elif rot_mode == "euler_xyz_sx_cx" or rot_mode == "rot_matrix":
+            pose_dim = 6
+        else:
+            raise NotImplementedError
+        return pose_dim
+
+    @staticmethod
+    def rotation_6d_to_matrix_tensor_batch(d6: torch.Tensor) -> torch.Tensor:
+        # Gram-Schmidt on the two 3-vectors of the 6D representation
+        # (Zhou et al., "On the Continuity of Rotation Representations").
+        a1, a2 = d6[..., :3], d6[..., 3:]
+        b1 = F.normalize(a1, dim=-1)
+        b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
+        b2 = F.normalize(b2, dim=-1)
+        b3 = torch.cross(b1, b2, dim=-1)
+        return torch.stack((b1, b2, b3), dim=-2)
+
+    @staticmethod
+    def matrix_to_rotation_6d_tensor_batch(matrix: torch.Tensor) -> torch.Tensor:
+        batch_dim = matrix.size()[:-2]
+        return matrix[..., :2, :].clone().reshape(batch_dim + (6,))
+
+    @staticmethod
+    def rotation_6d_to_matrix_numpy(d6):
+        a1, a2 = d6[:3], d6[3:]
+        b1 = a1 / np.linalg.norm(a1)
+        b2 = a2 - np.dot(b1, a2) * b1
+        b2 = b2 / np.linalg.norm(b2)
+        b3 = np.cross(b1, b2)
+        return np.stack((b1, b2, b3), axis=-2)
+
+    @staticmethod
+    def matrix_to_rotation_6d_numpy(matrix):
+        return np.copy(matrix[:2, :]).reshape((6,))
+
+
+""" ------------ Debug ------------ """
+
+if __name__ == "__main__":
+    for _ in range(1):
+        PoseUtil.get_uniform_pose(
+            trans_min=[-25, -25, 10],
+            trans_max=[25, 25, 60],
+            rot_min=0,
+            rot_max=10,
+            debug=True,
+        )
+        PoseUtil.get_uniform_scale(scale_min=0.25, scale_max=0.30, debug=True)
+        PoseUtil.get_n_uniform_pose_batch(
+            trans_min=[-25, -25, 10],
+            trans_max=[25, 25, 60],
+            rot_min=0,
+            rot_max=10,
+            batch_size=2,
+            n=2,
+            fix=PoseUtil.TRANSLATION,
+            debug=True,
+        )
diff --git a/utils/reconstruction.py b/utils/reconstruction.py
new file mode 100644
index 0000000..0fb025d
--- /dev/null
+++ b/utils/reconstruction.py
@@ -0,0 +1,139 @@
+import numpy as np
+import open3d as o3d
+from scipy.spatial import cKDTree
+
+class ReconstructionUtil:
+
+    @staticmethod
+    def compute_coverage_rate(target_point_cloud, combined_point_cloud, threshold=0.01):
+        kdtree = cKDTree(combined_point_cloud)
+        distances, _ = kdtree.query(target_point_cloud)
+        covered_points = np.sum(distances < threshold)
+        coverage_rate = covered_points / target_point_cloud.shape[0]
+        return coverage_rate
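+
+    # Example (illustrative, hypothetical variable names): with the default
+    # 1 cm threshold on metre-scale clouds, this is the fraction of target
+    # points that lie within 1 cm of any scanned point.
+    #
+    #   rate = ReconstructionUtil.compute_coverage_rate(target_pts, scan_pts)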
+
+    @staticmethod
+    def compute_overlap_rate(point_cloud1, point_cloud2, threshold=0.01):
+        kdtree1 = cKDTree(point_cloud1)
+        kdtree2 = cKDTree(point_cloud2)
+        distances1, _ = kdtree2.query(point_cloud1)
+        distances2, _ = kdtree1.query(point_cloud2)
+        overlapping_points1 = np.sum(distances1 < threshold)
+        overlapping_points2 = np.sum(distances2 < threshold)
+
+        overlap_rate1 = overlapping_points1 / point_cloud1.shape[0]
+        overlap_rate2 = overlapping_points2 / point_cloud2.shape[0]
+
+        return (overlap_rate1 + overlap_rate2) / 2
+
+    @staticmethod
+    def combine_point_with_view_sequence(point_list, view_sequence):
+        selected_views = []
+        for view_index, _ in view_sequence:
+            selected_views.append(point_list[view_index])
+        return np.vstack(selected_views)
+
+    @staticmethod
+    def compute_next_view_coverage_list(views, combined_point_cloud, target_point_cloud, threshold=0.01):
+        best_view = None
+        best_coverage_increase = -1
+        current_coverage = ReconstructionUtil.compute_coverage_rate(target_point_cloud, combined_point_cloud, threshold)
+
+        for view_index, view in enumerate(views):
+            candidate_point_cloud = np.vstack([combined_point_cloud, view])
+            down_sampled_combined_point_cloud = ReconstructionUtil.downsample_point_cloud(candidate_point_cloud, threshold)
+            new_coverage = ReconstructionUtil.compute_coverage_rate(target_point_cloud, down_sampled_combined_point_cloud, threshold)
+            coverage_increase = new_coverage - current_coverage
+            if coverage_increase > best_coverage_increase:
+                best_coverage_increase = coverage_increase
+                best_view = view_index
+        return best_view, best_coverage_increase
+
+    @staticmethod
+    def compute_next_best_view_sequence(target_point_cloud, point_cloud_list, threshold=0.01):
+        # Greedy selection: at each step, add the view whose points most
+        # increase coverage of the (downsampled) target cloud.
+        selected_views = []
+        current_coverage = 0.0
+        remaining_views = list(range(len(point_cloud_list)))
+        view_sequence = []
+        target_point_cloud = ReconstructionUtil.downsample_point_cloud(target_point_cloud, threshold)
+        while remaining_views:
+            best_view = None
+            best_coverage_increase = -1
+
+            for view_index in remaining_views:
+                candidate_views = selected_views + [point_cloud_list[view_index]]
+                combined_point_cloud = np.vstack(candidate_views)
+
+                down_sampled_combined_point_cloud = ReconstructionUtil.downsample_point_cloud(combined_point_cloud, threshold)
+                new_coverage = ReconstructionUtil.compute_coverage_rate(target_point_cloud, down_sampled_combined_point_cloud, threshold)
+                coverage_increase = new_coverage - current_coverage
+
+                if coverage_increase > best_coverage_increase:
+                    best_coverage_increase = coverage_increase
+                    best_view = view_index
+
+            if best_view is not None:
+                if best_coverage_increase <= 1e-3:
+                    break
+                selected_views.append(point_cloud_list[best_view])
+                current_coverage += best_coverage_increase
+                view_sequence.append((best_view, current_coverage))
+                remaining_views.remove(best_view)
+        return view_sequence, remaining_views
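+
+    # Example (illustrative, hypothetical variable names): greedily order the
+    # per-frame clouds of a sequence by coverage gain over the target model.
+    #
+    #   clouds = DataLoadUtil.get_point_cloud_list_from_seq(root, seq_idx, 60)
+    #   seq, leftover = ReconstructionUtil.compute_next_best_view_sequence(
+    #       target_cloud, clouds, threshold=0.01)
+    #   # seq is a list of (view_index, cumulative_coverage) tuples.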
+
+    @staticmethod
+    def compute_next_best_view_sequence_with_overlap(target_point_cloud, point_cloud_list, threshold=0.01, overlap_threshold=0.3):
+        selected_views = []
+        current_coverage = 0.0
+        remaining_views = list(range(len(point_cloud_list)))
+        view_sequence = []
+        target_point_cloud = ReconstructionUtil.downsample_point_cloud(target_point_cloud, threshold)
+
+        while remaining_views:
+            best_view = None
+            best_coverage_increase = -1
+
+            for view_index in remaining_views:
+                # Skip candidates that overlap too little with the already
+                # selected views.
+                if selected_views:
+                    combined_old_point_cloud = np.vstack(selected_views)
+                    down_sampled_old_point_cloud = ReconstructionUtil.downsample_point_cloud(combined_old_point_cloud, threshold)
+                    down_sampled_new_view_point_cloud = ReconstructionUtil.downsample_point_cloud(point_cloud_list[view_index], threshold)
+                    overlap_rate = ReconstructionUtil.compute_overlap_rate(down_sampled_old_point_cloud, down_sampled_new_view_point_cloud, threshold)
+                    if overlap_rate < overlap_threshold:
+                        continue
+
+                candidate_views = selected_views + [point_cloud_list[view_index]]
+                combined_point_cloud = np.vstack(candidate_views)
+                down_sampled_combined_point_cloud = ReconstructionUtil.downsample_point_cloud(combined_point_cloud, threshold)
+                new_coverage = ReconstructionUtil.compute_coverage_rate(target_point_cloud, down_sampled_combined_point_cloud, threshold)
+                coverage_increase = new_coverage - current_coverage
+                if coverage_increase > best_coverage_increase:
+                    best_coverage_increase = coverage_increase
+                    best_view = view_index
+
+            if best_view is not None:
+                if best_coverage_increase <= 1e-3:
+                    break
+                selected_views.append(point_cloud_list[best_view])
+                remaining_views.remove(best_view)
+                if best_coverage_increase > 0:
+                    current_coverage += best_coverage_increase
+                view_sequence.append((best_view, current_coverage))
+            else:
+                break
+
+        return view_sequence, remaining_views
+
+    @staticmethod
+    def downsample_point_cloud(point_cloud, voxel_size=0.005):
+        o3d_pc = o3d.geometry.PointCloud()
+        o3d_pc.points = o3d.utility.Vector3dVector(point_cloud)
+        downsampled_pc = o3d_pc.voxel_down_sample(voxel_size)
+        return np.asarray(downsampled_pc.points)
\ No newline at end of file
diff --git a/utils/tensorboard_util.py b/utils/tensorboard_util.py
deleted file mode 100644
index d9f85f2..0000000
--- a/utils/tensorboard_util.py
+++ /dev/null
@@ -1,47 +0,0 @@
-import torch
-
-
-class TensorboardWriter:
-    @staticmethod
-    def write_tensorboard(writer, panel, data_dict, step):
-        complex_dict = False
-        if "scalars" in data_dict:
-            scalar_data_dict = data_dict["scalars"]
-            TensorboardWriter.write_scalar_tensorboard(writer, panel, scalar_data_dict, step)
-            complex_dict = True
-        if "images" in data_dict:
-            image_data_dict = data_dict["images"]
-            TensorboardWriter.write_image_tensorboard(writer, panel, image_data_dict, step)
-            complex_dict = True
-        if "points" in data_dict:
-            point_data_dict = data_dict["points"]
-            TensorboardWriter.write_points_tensorboard(writer, panel, point_data_dict, step)
-            complex_dict = True
-
-        if not complex_dict:
-            TensorboardWriter.write_scalar_tensorboard(writer, panel, data_dict, step)
-
-    @staticmethod
-    def write_scalar_tensorboard(writer, panel, data_dict, step):
-        for key, value in data_dict.items():
-            if isinstance(value, dict):
-                writer.add_scalars(f'{panel}/{key}', value, step)
-            else:
-                writer.add_scalar(f'{panel}/{key}', value, step)
-
-    @staticmethod
-    def write_image_tensorboard(writer, panel, data_dict, step):
-        pass
-
-    @staticmethod
-    def write_points_tensorboard(writer, panel, data_dict, step):
-        for key, value in data_dict.items():
-            if value.shape[-1] == 3:
-                colors = torch.zeros_like(value)
-                vertices = torch.cat([value, colors], dim=-1)
-            elif value.shape[-1] == 6:
-                vertices = value
-            else:
-                raise ValueError(f'Unexpected value shape: {value.shape}')
-            faces = None
-            writer.add_mesh(f'{panel}/{key}', vertices=vertices, faces=faces, global_step=step)