update basic framework
This commit is contained in:
parent
73dcd592df
commit
f977fd4b8e
@ -1,7 +0,0 @@
|
||||
EXTERNAL_FREEZE_MODULES = set()
|
||||
|
||||
def external_freeze(cls):
|
||||
if not hasattr(cls, 'load') or not callable(getattr(cls, 'load')):
|
||||
raise TypeError(f"external module <{cls.__name__}> must implement a 'load' method")
|
||||
EXTERNAL_FREEZE_MODULES.add(cls)
|
||||
return cls
|
@ -1,34 +0,0 @@
|
||||
# --- Classes --- #
|
||||
|
||||
def dataset():
|
||||
pass
|
||||
|
||||
def module():
|
||||
pass
|
||||
|
||||
def pipeline():
|
||||
pass
|
||||
|
||||
def runner():
|
||||
pass
|
||||
|
||||
def factory():
|
||||
pass
|
||||
|
||||
# --- Functions --- #
|
||||
|
||||
evaluation_methods = {}
|
||||
def evaluation_method(eval_type):
|
||||
def decorator(func):
|
||||
evaluation_methods[eval_type] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
|
||||
def loss_function():
|
||||
pass
|
||||
|
||||
|
||||
# --- Main --- #
|
||||
|
||||
|
8
app_generate.py
Normal file
8
app_generate.py
Normal file
@ -0,0 +1,8 @@
|
||||
from PytorchBoot.application import PytorchBootApplication
|
||||
from runners.strategy_generator import StrategyGenerator
|
||||
|
||||
@PytorchBootApplication("generate")
|
||||
class Generator:
|
||||
@staticmethod
|
||||
def start():
|
||||
StrategyGenerator("configs\generate_config.yaml").run()
|
@ -1,74 +0,0 @@
|
||||
import argparse
|
||||
import os.path
|
||||
import shutil
|
||||
import yaml
|
||||
|
||||
|
||||
class ConfigManager:
|
||||
config = None
|
||||
config_path = None
|
||||
|
||||
@staticmethod
|
||||
def get(*args):
|
||||
result = ConfigManager.config
|
||||
for arg in args:
|
||||
result = result[arg]
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def load_config_with(config_file_path):
|
||||
ConfigManager.config_path = config_file_path
|
||||
if not os.path.exists(ConfigManager.config_path):
|
||||
raise ValueError(f"Config file <{config_file_path}> does not exist")
|
||||
with open(config_file_path, 'r') as file:
|
||||
ConfigManager.config = yaml.safe_load(file)
|
||||
|
||||
@staticmethod
|
||||
def backup_config_to(target_config_dir, file_name, prefix="config"):
|
||||
file_name = f"{prefix}_{file_name}.yaml"
|
||||
target_config_file_path = str(os.path.join(target_config_dir, file_name))
|
||||
shutil.copy(ConfigManager.config_path, target_config_file_path)
|
||||
|
||||
@staticmethod
|
||||
def load_config():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--config', type=str, default='', help='config file path')
|
||||
args = parser.parse_args()
|
||||
if args.config:
|
||||
ConfigManager.load_config_with(args.config)
|
||||
|
||||
@staticmethod
|
||||
def print_config(key: str = None, group: dict = None, level=0):
|
||||
table_size = 80
|
||||
if key and group:
|
||||
value = group[key]
|
||||
if type(value) is dict:
|
||||
print("\t" * level + f"+-{key}:")
|
||||
for k in value:
|
||||
ConfigManager.print_config(k, value, level=level + 1)
|
||||
else:
|
||||
print("\t" * level + f"| {key}: {value}")
|
||||
elif key:
|
||||
ConfigManager.print_config(key, ConfigManager.config, level=level)
|
||||
else:
|
||||
print("+" + "-" * table_size + "+")
|
||||
print(f"| Configurations in <{ConfigManager.config_path}>:")
|
||||
print("+" + "-" * table_size + "+")
|
||||
for key in ConfigManager.config:
|
||||
ConfigManager.print_config(key, level=level + 1)
|
||||
print("+" + "-" * table_size + "+")
|
||||
|
||||
|
||||
''' ------------ Debug ------------ '''
|
||||
if __name__ == "__main__":
|
||||
test_args = ['--config', r'configs\train_config.yaml']
|
||||
test_parser = argparse.ArgumentParser()
|
||||
test_parser.add_argument('--config', type=str, default='', help='config file path')
|
||||
test_args = test_parser.parse_args(test_args)
|
||||
if test_args.config:
|
||||
ConfigManager.load_config_with(test_args.config)
|
||||
ConfigManager.print_config()
|
||||
print()
|
||||
pipeline = ConfigManager.get('settings', 'train', "dataset", 'batch_size')
|
||||
ConfigManager.print_config('settings')
|
||||
print(pipeline)
|
23
configs/generate_config.yaml
Normal file
23
configs/generate_config.yaml
Normal file
@ -0,0 +1,23 @@
|
||||
|
||||
runners:
|
||||
general:
|
||||
seed: 0
|
||||
device: cpu
|
||||
cuda_visible_devices: "0,1,2,3,4,5,6,7"
|
||||
|
||||
experiment:
|
||||
name: debug
|
||||
root_dir: "experiments"
|
||||
|
||||
generate:
|
||||
- name: OmniObject3d_train
|
||||
component: OmniObject3d
|
||||
data_type: train
|
||||
|
||||
datasets:
|
||||
general:
|
||||
components:
|
||||
OmniObject3d:
|
||||
root_dir: "C:\\Document\\Local Project\\nbv_rec\\output"
|
||||
|
||||
|
@ -1,73 +0,0 @@
|
||||
# Train config file
|
||||
|
||||
settings:
|
||||
general:
|
||||
seed: 0
|
||||
device: cuda
|
||||
cuda_visible_devices: "0,1,2,3,4,5,6,7"
|
||||
parallel: True
|
||||
|
||||
experiment:
|
||||
name: train_experiment
|
||||
root_dir: "experiments"
|
||||
use_checkpoint: True
|
||||
epoch: -1 # -1 stands for last epoch
|
||||
max_epochs: 5000
|
||||
save_checkpoint_interval: 1
|
||||
test_first: True
|
||||
|
||||
train:
|
||||
optimizer:
|
||||
type: adam
|
||||
lr: 0.0001
|
||||
losses: # loss type : weight
|
||||
gf_loss: 1.0
|
||||
dataset:
|
||||
name: synthetic_train_train_dataset
|
||||
source: nbv1
|
||||
data_type: train
|
||||
ratio: 0.1
|
||||
batch_size: 128
|
||||
num_workers: 96
|
||||
|
||||
test:
|
||||
batch_size: 16
|
||||
frequency: 3
|
||||
dataset_list:
|
||||
- name: synthetic_test_train_dataset
|
||||
source: nbv1
|
||||
data_type: train
|
||||
eval_list:
|
||||
ratio: 0.00001
|
||||
batch_size: 16
|
||||
num_workers: 16
|
||||
|
||||
pipeline:
|
||||
pts_encoder: pointnet
|
||||
view_finder: gradient_field
|
||||
|
||||
datasets:
|
||||
general:
|
||||
data_dir: "../data"
|
||||
|
||||
modules:
|
||||
general:
|
||||
pts_channels: 3
|
||||
feature_dim: 1024
|
||||
per_point_feature: False
|
||||
pts_encoder:
|
||||
pointnet:
|
||||
pointnet++:
|
||||
params_name: light
|
||||
pointnet++rgb:
|
||||
params_name: light
|
||||
target_layer: 3
|
||||
rgb_feat_dim: 384
|
||||
view_finder:
|
||||
gradient_field:
|
||||
pose_mode: rot_matrix
|
||||
regression_head: Rx_Ry
|
||||
sample_mode: ode
|
||||
sample_repeat: 50
|
||||
sampling_steps: 500
|
||||
sde_mode: ve
|
@ -1,35 +0,0 @@
|
||||
from abc import ABC
|
||||
|
||||
import numpy as np
|
||||
|
||||
from torch.utils.data import Dataset
|
||||
from torch.utils.data import DataLoader, Subset
|
||||
|
||||
class BaseDataset(ABC, Dataset):
|
||||
def __init__(self, config):
|
||||
super(BaseDataset, self).__init__()
|
||||
self.config = config
|
||||
|
||||
@staticmethod
|
||||
def process_batch(batch, device):
|
||||
for key in batch.keys():
|
||||
if isinstance(batch[key], list):
|
||||
continue
|
||||
batch[key] = batch[key].to(device)
|
||||
return batch
|
||||
|
||||
def get_loader(self, shuffle=False):
|
||||
ratio = self.config["ratio"]
|
||||
if ratio > 1 or ratio <= 0:
|
||||
raise ValueError(
|
||||
f"dataset ratio should be between (0,1], found {ratio} in {self.config['name']}"
|
||||
)
|
||||
subset_size = int(len(self) * ratio)
|
||||
indices = np.random.permutation(len(self))[:subset_size]
|
||||
subset = Subset(self, indices)
|
||||
return DataLoader(
|
||||
subset,
|
||||
batch_size=self.config["batch_size"],
|
||||
num_workers=self.config["num_workers"],
|
||||
shuffle=shuffle,
|
||||
)
|
@ -1,30 +0,0 @@
|
||||
import sys
|
||||
import os
|
||||
path = os.path.abspath(__file__)
|
||||
for i in range(2):
|
||||
path = os.path.dirname(path)
|
||||
PROJECT_ROOT = path
|
||||
sys.path.append(PROJECT_ROOT)
|
||||
|
||||
from datasets.dataset import BaseDataset
|
||||
|
||||
class DatasetFactory:
|
||||
@staticmethod
|
||||
def create(config) -> BaseDataset:
|
||||
pass
|
||||
|
||||
|
||||
''' ------------ Debug ------------ '''
|
||||
if __name__ == "__main__":
|
||||
|
||||
from configs.config import ConfigManager
|
||||
|
||||
ConfigManager.load_config_with('/home/data/hofee/project/ActivePerception/ActivePerception/configs/server_train_config.yaml')
|
||||
ConfigManager.print_config()
|
||||
dataset = DatasetFactory.create(ConfigManager.get("settings", "test", "dataset_list")[1])
|
||||
print(len(dataset))
|
||||
data_test = dataset.__getitem__(107000)
|
||||
print(data_test['src_path'])
|
||||
import pickle
|
||||
# with open("data_sample_new.pkl", "wb") as f:
|
||||
# pickle.dump(data_test, f)
|
@ -1,35 +0,0 @@
|
||||
from annotations.stereotype import evaluation_methods
|
||||
import importlib
|
||||
import pkgutil
|
||||
import os
|
||||
|
||||
package_name = os.path.dirname("evaluations")
|
||||
package = importlib.import_module("evaluations")
|
||||
for _, module_name, _ in pkgutil.walk_packages(package.__path__, package.__name__ + "."):
|
||||
importlib.import_module(module_name)
|
||||
|
||||
class EvalFunctionFactory:
|
||||
@staticmethod
|
||||
def create(eval_type_list):
|
||||
def eval_func(output, data):
|
||||
temp_results = {"scalars": {}, "points": {}, "images": {}}
|
||||
for eval_type in eval_type_list:
|
||||
if eval_type in evaluation_methods:
|
||||
result = evaluation_methods[eval_type](output, data)
|
||||
for k, v in result.items():
|
||||
temp_results[k].update(v)
|
||||
results = {}
|
||||
for k, v in temp_results.items():
|
||||
if len(v) > 0:
|
||||
results[k] = v
|
||||
return results
|
||||
|
||||
return eval_func
|
||||
|
||||
|
||||
''' ------------ Debug ------------ '''
|
||||
if __name__ == "__main__":
|
||||
from configs.config import ConfigManager
|
||||
|
||||
ConfigManager.load_config_with('../configs/local_train_config.yaml')
|
||||
ConfigManager.print_config()
|
12
losses/gf_loss.py
Normal file
12
losses/gf_loss.py
Normal file
@ -0,0 +1,12 @@
|
||||
import torch
|
||||
import PytorchBoot.stereotype as stereotype
|
||||
|
||||
@stereotype.loss_function("gf_loss")
|
||||
def compute_loss(output, data):
|
||||
estimated_score = output['estimated_score']
|
||||
target_score = output['target_score']
|
||||
std = output['std']
|
||||
bs = estimated_score.shape[0]
|
||||
loss_weighting = std ** 2
|
||||
loss = torch.mean(torch.sum((loss_weighting * (estimated_score - target_score) ** 2).view(bs, -1), dim=-1))
|
||||
return loss
|
@ -1,12 +0,0 @@
|
||||
class LossFunctionFactory:
|
||||
@staticmethod
|
||||
def create(function_name):
|
||||
raise ValueError("Unknown loss function {}".format(function_name))
|
||||
|
||||
|
||||
''' ------------ Debug ------------ '''
|
||||
if __name__ == "__main__":
|
||||
from configs.config import ConfigManager
|
||||
|
||||
ConfigManager.load_config_with('../configs/local_train_config.yaml')
|
||||
ConfigManager.print_config()
|
7
modules/func_lib/__init__.py
Normal file
7
modules/func_lib/__init__.py
Normal file
@ -0,0 +1,7 @@
|
||||
from modules.func_lib.samplers import (
|
||||
cond_pc_sampler,
|
||||
cond_ode_sampler
|
||||
)
|
||||
from modules.func_lib.sde import (
|
||||
init_sde
|
||||
)
|
280
modules/func_lib/samplers.py
Normal file
280
modules/func_lib/samplers.py
Normal file
@ -0,0 +1,280 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from scipy import integrate
|
||||
from utils.pose import PoseUtil
|
||||
|
||||
|
||||
def global_prior_likelihood(z, sigma_max):
|
||||
"""The likelihood of a Gaussian distribution with mean zero and
|
||||
standard deviation sigma."""
|
||||
# z: [bs, pose_dim]
|
||||
shape = z.shape
|
||||
N = np.prod(shape[1:]) # pose_dim
|
||||
return -N / 2. * torch.log(2 * np.pi * sigma_max ** 2) - torch.sum(z ** 2, dim=-1) / (2 * sigma_max ** 2)
|
||||
|
||||
|
||||
def cond_ode_likelihood(
|
||||
score_model,
|
||||
data,
|
||||
prior,
|
||||
sde_coeff,
|
||||
marginal_prob_fn,
|
||||
atol=1e-5,
|
||||
rtol=1e-5,
|
||||
device='cuda',
|
||||
eps=1e-5,
|
||||
num_steps=None,
|
||||
pose_mode='quat_wxyz',
|
||||
init_x=None,
|
||||
):
|
||||
pose_dim = PoseUtil.get_pose_dim(pose_mode)
|
||||
batch_size = data['pts'].shape[0]
|
||||
epsilon = prior((batch_size, pose_dim)).to(device)
|
||||
init_x = data['sampled_pose'].clone().cpu().numpy() if init_x is None else init_x
|
||||
shape = init_x.shape
|
||||
init_logp = np.zeros((shape[0],)) # [bs]
|
||||
init_inp = np.concatenate([init_x.reshape(-1), init_logp], axis=0)
|
||||
|
||||
def score_eval_wrapper(data):
|
||||
"""A wrapper of the score-based model for use by the ODE solver."""
|
||||
with torch.no_grad():
|
||||
score = score_model(data)
|
||||
return score.cpu().numpy().reshape((-1,))
|
||||
|
||||
def divergence_eval(data, epsilon):
|
||||
"""Compute the divergence of the score-based model with Skilling-Hutchinson."""
|
||||
# save ckpt of sampled_pose
|
||||
origin_sampled_pose = data['sampled_pose'].clone()
|
||||
with torch.enable_grad():
|
||||
# make sampled_pose differentiable
|
||||
data['sampled_pose'].requires_grad_(True)
|
||||
score = score_model(data)
|
||||
score_energy = torch.sum(score * epsilon) # [, ]
|
||||
grad_score_energy = torch.autograd.grad(score_energy, data['sampled_pose'])[0] # [bs, pose_dim]
|
||||
# reset sampled_pose
|
||||
data['sampled_pose'] = origin_sampled_pose
|
||||
return torch.sum(grad_score_energy * epsilon, dim=-1) # [bs, 1]
|
||||
|
||||
def divergence_eval_wrapper(data):
|
||||
"""A wrapper for evaluating the divergence of score for the black-box ODE solver."""
|
||||
with torch.no_grad():
|
||||
# Compute likelihood.
|
||||
div = divergence_eval(data, epsilon) # [bs, 1]
|
||||
return div.cpu().numpy().reshape((-1,)).astype(np.float64)
|
||||
|
||||
def ode_func(t, inp):
|
||||
"""The ODE function for use by the ODE solver."""
|
||||
# split x, logp from inp
|
||||
x = inp[:-shape[0]]
|
||||
# calc x-grad
|
||||
x = torch.tensor(x.reshape(-1, pose_dim), dtype=torch.float32, device=device)
|
||||
time_steps = torch.ones(batch_size, device=device).unsqueeze(-1) * t
|
||||
drift, diffusion = sde_coeff(torch.tensor(t))
|
||||
drift = drift.cpu().numpy()
|
||||
diffusion = diffusion.cpu().numpy()
|
||||
data['sampled_pose'] = x
|
||||
data['t'] = time_steps
|
||||
x_grad = drift - 0.5 * (diffusion ** 2) * score_eval_wrapper(data)
|
||||
# calc logp-grad
|
||||
logp_grad = drift - 0.5 * (diffusion ** 2) * divergence_eval_wrapper(data)
|
||||
# concat curr grad
|
||||
return np.concatenate([x_grad, logp_grad], axis=0)
|
||||
|
||||
# Run the black-box ODE solver, note the
|
||||
res = integrate.solve_ivp(ode_func, (eps, 1.0), init_inp, rtol=rtol, atol=atol, method='RK45')
|
||||
zp = torch.tensor(res.y[:, -1], device=device) # [bs * (pose_dim + 1)]
|
||||
z = zp[:-shape[0]].reshape(shape) # [bs, pose_dim]
|
||||
delta_logp = zp[-shape[0]:].reshape(shape[0]) # [bs,] logp
|
||||
_, sigma_max = marginal_prob_fn(None, torch.tensor(1.).to(device)) # we assume T = 1
|
||||
prior_logp = global_prior_likelihood(z, sigma_max)
|
||||
log_likelihoods = (prior_logp + delta_logp) / np.log(2) # negative log-likelihoods (nlls)
|
||||
return z, log_likelihoods
|
||||
|
||||
|
||||
def cond_pc_sampler(
|
||||
score_model,
|
||||
data,
|
||||
prior,
|
||||
sde_coeff,
|
||||
num_steps=500,
|
||||
snr=0.16,
|
||||
device='cuda',
|
||||
eps=1e-5,
|
||||
pose_mode='quat_wxyz',
|
||||
init_x=None,
|
||||
):
|
||||
pose_dim = PoseUtil.get_pose_dim(pose_mode)
|
||||
batch_size = data['target_pts_feat'].shape[0]
|
||||
init_x = prior((batch_size, pose_dim)).to(device) if init_x is None else init_x
|
||||
time_steps = torch.linspace(1., eps, num_steps, device=device)
|
||||
step_size = time_steps[0] - time_steps[1]
|
||||
noise_norm = np.sqrt(pose_dim)
|
||||
x = init_x
|
||||
poses = []
|
||||
with torch.no_grad():
|
||||
for time_step in time_steps:
|
||||
batch_time_step = torch.ones(batch_size, device=device).unsqueeze(-1) * time_step
|
||||
# Corrector step (Langevin MCMC)
|
||||
data['sampled_pose'] = x
|
||||
data['t'] = batch_time_step
|
||||
grad = score_model(data)
|
||||
grad_norm = torch.norm(grad.reshape(batch_size, -1), dim=-1).mean()
|
||||
langevin_step_size = 2 * (snr * noise_norm / grad_norm) ** 2
|
||||
x = x + langevin_step_size * grad + torch.sqrt(2 * langevin_step_size) * torch.randn_like(x)
|
||||
|
||||
# normalisation
|
||||
if pose_mode == 'quat_wxyz' or pose_mode == 'quat_xyzw':
|
||||
# quat, should be normalised
|
||||
x[:, :4] /= torch.norm(x[:, :4], dim=-1, keepdim=True)
|
||||
elif pose_mode == 'euler_xyz':
|
||||
pass
|
||||
else:
|
||||
# rotation(x axis, y axis), should be normalised
|
||||
x[:, :3] /= torch.norm(x[:, :3], dim=-1, keepdim=True)
|
||||
x[:, 3:6] /= torch.norm(x[:, 3:6], dim=-1, keepdim=True)
|
||||
|
||||
# Predictor step (Euler-Maruyama)
|
||||
drift, diffusion = sde_coeff(batch_time_step)
|
||||
drift = drift - diffusion ** 2 * grad # R-SDE
|
||||
mean_x = x + drift * step_size
|
||||
x = mean_x + diffusion * torch.sqrt(step_size) * torch.randn_like(x)
|
||||
|
||||
# normalisation
|
||||
x[:, :-3] = PoseUtil.normalize_rotation(x[:, :-3], pose_mode)
|
||||
poses.append(x.unsqueeze(0))
|
||||
|
||||
xs = torch.cat(poses, dim=0)
|
||||
xs[:, :, -3:] += data['pts_center'].unsqueeze(0).repeat(xs.shape[0], 1, 1)
|
||||
mean_x[:, -3:] += data['pts_center']
|
||||
mean_x[:, :-3] = PoseUtil.normalize_rotation(mean_x[:, :-3], pose_mode)
|
||||
# The last step does not include any noise
|
||||
return xs.permute(1, 0, 2), mean_x
|
||||
|
||||
|
||||
def cond_ode_sampler(
|
||||
score_model,
|
||||
data,
|
||||
prior,
|
||||
sde_coeff,
|
||||
atol=1e-5,
|
||||
rtol=1e-5,
|
||||
device='cuda',
|
||||
eps=1e-5,
|
||||
T=1.0,
|
||||
num_steps=None,
|
||||
pose_mode='quat_wxyz',
|
||||
denoise=True,
|
||||
init_x=None,
|
||||
):
|
||||
pose_dim = PoseUtil.get_pose_dim(pose_mode)
|
||||
batch_size = data['target_feat'].shape[0]
|
||||
init_x = prior((batch_size, pose_dim), T=T).to(device) if init_x is None else init_x + prior((batch_size, pose_dim),
|
||||
T=T).to(device)
|
||||
shape = init_x.shape
|
||||
|
||||
def score_eval_wrapper(data):
|
||||
"""A wrapper of the score-based model for use by the ODE solver."""
|
||||
with torch.no_grad():
|
||||
score = score_model(data)
|
||||
return score.cpu().numpy().reshape((-1,))
|
||||
|
||||
def ode_func(t, x):
|
||||
"""The ODE function for use by the ODE solver."""
|
||||
x = torch.tensor(x.reshape(-1, pose_dim), dtype=torch.float32, device=device)
|
||||
time_steps = torch.ones(batch_size, device=device).unsqueeze(-1) * t
|
||||
drift, diffusion = sde_coeff(torch.tensor(t))
|
||||
drift = drift.cpu().numpy()
|
||||
diffusion = diffusion.cpu().numpy()
|
||||
data['sampled_pose'] = x
|
||||
data['t'] = time_steps
|
||||
return drift - 0.5 * (diffusion ** 2) * score_eval_wrapper(data)
|
||||
|
||||
# Run the black-box ODE solver, note the
|
||||
t_eval = None
|
||||
if num_steps is not None:
|
||||
# num_steps, from T -> eps
|
||||
t_eval = np.linspace(T, eps, num_steps)
|
||||
res = integrate.solve_ivp(ode_func, (T, eps), init_x.reshape(-1).cpu().numpy(), rtol=rtol, atol=atol, method='RK45',
|
||||
t_eval=t_eval)
|
||||
xs = torch.tensor(res.y, device=device).T.view(-1, batch_size, pose_dim) # [num_steps, bs, pose_dim]
|
||||
x = torch.tensor(res.y[:, -1], device=device).reshape(shape) # [bs, pose_dim]
|
||||
# denoise, using the predictor step in P-C sampler
|
||||
if denoise:
|
||||
# Reverse diffusion predictor for denoising
|
||||
vec_eps = torch.ones((x.shape[0], 1), device=x.device) * eps
|
||||
drift, diffusion = sde_coeff(vec_eps)
|
||||
data['sampled_pose'] = x.float()
|
||||
data['t'] = vec_eps
|
||||
grad = score_model(data)
|
||||
drift = drift - diffusion ** 2 * grad # R-SDE
|
||||
mean_x = x + drift * ((1 - eps) / (1000 if num_steps is None else num_steps))
|
||||
x = mean_x
|
||||
|
||||
num_steps = xs.shape[0]
|
||||
xs = xs.reshape(batch_size * num_steps, -1)
|
||||
xs = PoseUtil.normalize_rotation(xs, pose_mode)
|
||||
xs = xs.reshape(num_steps, batch_size, -1)
|
||||
x = PoseUtil.normalize_rotation(x, pose_mode)
|
||||
return xs.permute(1, 0, 2), x
|
||||
|
||||
|
||||
def cond_edm_sampler(
|
||||
decoder_model, data, prior_fn, randn_like=torch.randn_like,
|
||||
num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
|
||||
S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
|
||||
pose_mode='quat_wxyz', device='cuda'
|
||||
):
|
||||
pose_dim = PoseUtil.get_pose_dim(pose_mode)
|
||||
batch_size = data['pts'].shape[0]
|
||||
latents = prior_fn((batch_size, pose_dim)).to(device)
|
||||
|
||||
# Time step discretion. note that sigma and t is interchangeable
|
||||
step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
|
||||
t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (
|
||||
sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
|
||||
t_steps = torch.cat([torch.as_tensor(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0
|
||||
|
||||
def decoder_wrapper(decoder, data, x, t):
|
||||
# save temp
|
||||
x_, t_ = data['sampled_pose'], data['t']
|
||||
# init data
|
||||
data['sampled_pose'], data['t'] = x, t
|
||||
# denoise
|
||||
data, denoised = decoder(data)
|
||||
# recover data
|
||||
data['sampled_pose'], data['t'] = x_, t_
|
||||
return denoised.to(torch.float64)
|
||||
|
||||
# Main sampling loop.
|
||||
x_next = latents.to(torch.float64) * t_steps[0]
|
||||
xs = []
|
||||
for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
|
||||
x_cur = x_next
|
||||
|
||||
# Increase noise temporarily.
|
||||
gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
|
||||
t_hat = torch.as_tensor(t_cur + gamma * t_cur)
|
||||
x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur)
|
||||
|
||||
# Euler step.
|
||||
denoised = decoder_wrapper(decoder_model, data, x_hat, t_hat)
|
||||
d_cur = (x_hat - denoised) / t_hat
|
||||
x_next = x_hat + (t_next - t_hat) * d_cur
|
||||
|
||||
# Apply 2nd order correction.
|
||||
if i < num_steps - 1:
|
||||
denoised = decoder_wrapper(decoder_model, data, x_next, t_next)
|
||||
d_prime = (x_next - denoised) / t_next
|
||||
x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)
|
||||
xs.append(x_next.unsqueeze(0))
|
||||
|
||||
xs = torch.stack(xs, dim=0) # [num_steps, bs, pose_dim]
|
||||
x = xs[-1] # [bs, pose_dim]
|
||||
|
||||
# post-processing
|
||||
xs = xs.reshape(batch_size * num_steps, -1)
|
||||
xs = PoseUtil.normalize_rotation(xs, pose_mode)
|
||||
xs = xs.reshape(num_steps, batch_size, -1)
|
||||
x = PoseUtil.normalize_rotation(x, pose_mode)
|
||||
return xs.permute(1, 0, 2), x
|
121
modules/func_lib/sde.py
Normal file
121
modules/func_lib/sde.py
Normal file
@ -0,0 +1,121 @@
|
||||
import functools
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
# ----- VE SDE -----
|
||||
# ------------------
|
||||
def ve_marginal_prob(x, t, sigma_min=0.01, sigma_max=90):
|
||||
std = sigma_min * (sigma_max / sigma_min) ** t
|
||||
mean = x
|
||||
return mean, std
|
||||
|
||||
|
||||
def ve_sde(t, sigma_min=0.01, sigma_max=90):
|
||||
sigma = sigma_min * (sigma_max / sigma_min) ** t
|
||||
drift_coeff = torch.tensor(0)
|
||||
diffusion_coeff = sigma * torch.sqrt(torch.tensor(2 * (np.log(sigma_max) - np.log(sigma_min)), device=t.device))
|
||||
return drift_coeff, diffusion_coeff
|
||||
|
||||
|
||||
def ve_prior(shape, sigma_min=0.01, sigma_max=90, T=1.0):
|
||||
_, sigma_max_prior = ve_marginal_prob(None, T, sigma_min=sigma_min, sigma_max=sigma_max)
|
||||
return torch.randn(*shape) * sigma_max_prior
|
||||
|
||||
|
||||
# ----- VP SDE -----
|
||||
# ------------------
|
||||
def vp_marginal_prob(x, t, beta_0=0.1, beta_1=20):
|
||||
log_mean_coeff = -0.25 * t ** 2 * (beta_1 - beta_0) - 0.5 * t * beta_0
|
||||
mean = torch.exp(log_mean_coeff) * x
|
||||
std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff))
|
||||
return mean, std
|
||||
|
||||
|
||||
def vp_sde(t, beta_0=0.1, beta_1=20):
|
||||
beta_t = beta_0 + t * (beta_1 - beta_0)
|
||||
drift_coeff = -0.5 * beta_t
|
||||
diffusion_coeff = torch.sqrt(beta_t)
|
||||
return drift_coeff, diffusion_coeff
|
||||
|
||||
|
||||
def vp_prior(shape, beta_0=0.1, beta_1=20):
|
||||
return torch.randn(*shape)
|
||||
|
||||
|
||||
# ----- sub-VP SDE -----
|
||||
# ----------------------
|
||||
def subvp_marginal_prob(x, t, beta_0, beta_1):
|
||||
log_mean_coeff = -0.25 * t ** 2 * (beta_1 - beta_0) - 0.5 * t * beta_0
|
||||
mean = torch.exp(log_mean_coeff) * x
|
||||
std = 1 - torch.exp(2. * log_mean_coeff)
|
||||
return mean, std
|
||||
|
||||
|
||||
def subvp_sde(t, beta_0, beta_1):
|
||||
beta_t = beta_0 + t * (beta_1 - beta_0)
|
||||
drift_coeff = -0.5 * beta_t
|
||||
discount = 1. - torch.exp(-2 * beta_0 * t - (beta_1 - beta_0) * t ** 2)
|
||||
diffusion_coeff = torch.sqrt(beta_t * discount)
|
||||
return drift_coeff, diffusion_coeff
|
||||
|
||||
|
||||
def subvp_prior(shape, beta_0=0.1, beta_1=20):
|
||||
return torch.randn(*shape)
|
||||
|
||||
|
||||
# ----- EDM SDE -----
|
||||
# ------------------
|
||||
def edm_marginal_prob(x, t, sigma_min=0.002, sigma_max=80):
|
||||
std = t
|
||||
mean = x
|
||||
return mean, std
|
||||
|
||||
|
||||
def edm_sde(t, sigma_min=0.002, sigma_max=80):
|
||||
drift_coeff = torch.tensor(0)
|
||||
diffusion_coeff = torch.sqrt(2 * t)
|
||||
return drift_coeff, diffusion_coeff
|
||||
|
||||
|
||||
def edm_prior(shape, sigma_min=0.002, sigma_max=80):
|
||||
return torch.randn(*shape) * sigma_max
|
||||
|
||||
|
||||
def init_sde(sde_mode):
|
||||
# the SDE-related hyperparameters are copied from https://github.com/yang-song/score_sde_pytorch
|
||||
if sde_mode == 'edm':
|
||||
sigma_min = 0.002
|
||||
sigma_max = 80
|
||||
eps = 0.002
|
||||
prior_fn = functools.partial(edm_prior, sigma_min=sigma_min, sigma_max=sigma_max)
|
||||
marginal_prob_fn = functools.partial(edm_marginal_prob, sigma_min=sigma_min, sigma_max=sigma_max)
|
||||
sde_fn = functools.partial(edm_sde, sigma_min=sigma_min, sigma_max=sigma_max)
|
||||
T = sigma_max
|
||||
elif sde_mode == 've':
|
||||
sigma_min = 0.01
|
||||
sigma_max = 50
|
||||
eps = 1e-5
|
||||
marginal_prob_fn = functools.partial(ve_marginal_prob, sigma_min=sigma_min, sigma_max=sigma_max)
|
||||
sde_fn = functools.partial(ve_sde, sigma_min=sigma_min, sigma_max=sigma_max)
|
||||
T = 1.0
|
||||
prior_fn = functools.partial(ve_prior, sigma_min=sigma_min, sigma_max=sigma_max)
|
||||
elif sde_mode == 'vp':
|
||||
beta_0 = 0.1
|
||||
beta_1 = 20
|
||||
eps = 1e-3
|
||||
prior_fn = functools.partial(vp_prior, beta_0=beta_0, beta_1=beta_1)
|
||||
marginal_prob_fn = functools.partial(vp_marginal_prob, beta_0=beta_0, beta_1=beta_1)
|
||||
sde_fn = functools.partial(vp_sde, beta_0=beta_0, beta_1=beta_1)
|
||||
T = 1.0
|
||||
elif sde_mode == 'subvp':
|
||||
beta_0 = 0.1
|
||||
beta_1 = 20
|
||||
eps = 1e-3
|
||||
prior_fn = functools.partial(subvp_prior, beta_0=beta_0, beta_1=beta_1)
|
||||
marginal_prob_fn = functools.partial(subvp_marginal_prob, beta_0=beta_0, beta_1=beta_1)
|
||||
sde_fn = functools.partial(subvp_sde, beta_0=beta_0, beta_1=beta_1)
|
||||
T = 1.0
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return prior_fn, marginal_prob_fn, sde_fn, eps, T
|
17
modules/module_lib/gaussian_fourier_projection.py
Normal file
17
modules/module_lib/gaussian_fourier_projection.py
Normal file
@ -0,0 +1,17 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class GaussianFourierProjection(nn.Module):
|
||||
"""Gaussian random features for encoding time steps."""
|
||||
|
||||
def __init__(self, embed_dim, scale=30.):
|
||||
super().__init__()
|
||||
# Randomly sample weights during initialization. These weights are fixed
|
||||
# during optimization and are not trainable.
|
||||
self.W = nn.Parameter(torch.randn(embed_dim // 2) * scale, requires_grad=False)
|
||||
|
||||
def forward(self, x):
|
||||
x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
|
||||
return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
|
30
modules/module_lib/linear.py
Normal file
30
modules/module_lib/linear.py
Normal file
@ -0,0 +1,30 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
def weight_init(shape, mode, fan_in, fan_out):
|
||||
if mode == 'xavier_uniform':
|
||||
return np.sqrt(6 / (fan_in + fan_out)) * (torch.rand(*shape) * 2 - 1)
|
||||
if mode == 'xavier_normal':
|
||||
return np.sqrt(2 / (fan_in + fan_out)) * torch.randn(*shape)
|
||||
if mode == 'kaiming_uniform':
|
||||
return np.sqrt(3 / fan_in) * (torch.rand(*shape) * 2 - 1)
|
||||
if mode == 'kaiming_normal':
|
||||
return np.sqrt(1 / fan_in) * torch.randn(*shape)
|
||||
raise ValueError(f'Invalid init mode "{mode}"')
|
||||
|
||||
|
||||
class Linear(torch.nn.Module):
|
||||
def __init__(self, in_features, out_features, bias=True, init_mode='kaiming_normal', init_weight=1, init_bias=0):
|
||||
super().__init__()
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
init_kwargs = dict(mode=init_mode, fan_in=in_features, fan_out=out_features)
|
||||
self.weight = torch.nn.Parameter(weight_init([out_features, in_features], **init_kwargs) * init_weight)
|
||||
self.bias = torch.nn.Parameter(weight_init([out_features], **init_kwargs) * init_bias) if bias else None
|
||||
|
||||
def forward(self, x):
|
||||
x = x @ self.weight.to(x.dtype).t()
|
||||
if self.bias is not None:
|
||||
x = x.add_(self.bias.to(x.dtype))
|
||||
return x
|
@ -1,20 +0,0 @@
|
||||
from torch import nn
|
||||
from configs.config import ConfigManager
|
||||
|
||||
|
||||
class Pipeline(nn.Module):
|
||||
TRAIN_MODE: str = "train"
|
||||
TEST_MODE: str = "test"
|
||||
|
||||
def __init__(self, pipeline_config):
|
||||
super(Pipeline, self).__init__()
|
||||
|
||||
self.modules_config = ConfigManager.get("modules")
|
||||
self.device = ConfigManager.get("settings", "general", "device")
|
||||
|
||||
def forward(self, data, mode):
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pass
|
12
modules/pts_encoder/abstract_pts_encoder.py
Normal file
12
modules/pts_encoder/abstract_pts_encoder.py
Normal file
@ -0,0 +1,12 @@
|
||||
from abc import abstractmethod
|
||||
|
||||
from torch import nn
|
||||
|
||||
|
||||
class PointsEncoder(nn.Module):
|
||||
def __init__(self):
|
||||
super(PointsEncoder, self).__init__()
|
||||
|
||||
@abstractmethod
|
||||
def encode_points(self, pts):
|
||||
pass
|
110
modules/pts_encoder/pointnet_encoder.py
Normal file
110
modules/pts_encoder/pointnet_encoder.py
Normal file
@ -0,0 +1,110 @@
|
||||
from __future__ import print_function
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.parallel
|
||||
import torch.utils.data
|
||||
from torch.autograd import Variable
|
||||
import numpy as np
|
||||
import torch.nn.functional as F
|
||||
|
||||
from modules.pts_encoder.abstract_pts_encoder import PointsEncoder
|
||||
import PytorchBoot.stereotype as stereotype
|
||||
|
||||
class STNkd(nn.Module):
|
||||
def __init__(self, k=64):
|
||||
super(STNkd, self).__init__()
|
||||
self.conv1 = torch.nn.Conv1d(k, 64, 1)
|
||||
self.conv2 = torch.nn.Conv1d(64, 128, 1)
|
||||
self.conv3 = torch.nn.Conv1d(128, 1024, 1)
|
||||
self.fc1 = nn.Linear(1024, 512)
|
||||
self.fc2 = nn.Linear(512, 256)
|
||||
self.fc3 = nn.Linear(256, k * k)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
self.k = k
|
||||
|
||||
def forward(self, x):
|
||||
batchsize = x.size()[0]
|
||||
x = F.relu(self.conv1(x))
|
||||
x = F.relu(self.conv2(x))
|
||||
x = F.relu(self.conv3(x))
|
||||
x = torch.max(x, 2, keepdim=True)[0]
|
||||
x = x.view(-1, 1024)
|
||||
|
||||
x = F.relu(self.fc1(x))
|
||||
x = F.relu(self.fc2(x))
|
||||
x = self.fc3(x)
|
||||
|
||||
iden = (
|
||||
Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32)))
|
||||
.view(1, self.k * self.k)
|
||||
.repeat(batchsize, 1)
|
||||
)
|
||||
if x.is_cuda:
|
||||
iden = iden.to(x.get_device())
|
||||
x = x + iden
|
||||
x = x.view(-1, self.k, self.k)
|
||||
return x
|
||||
|
||||
|
||||
@stereotype.module("pointnet_encoder")
|
||||
class PointNetEncoder(PointsEncoder):
|
||||
|
||||
def __init__(self, global_feat=True, in_dim=3, out_dim=1024, feature_transform=False):
|
||||
super(PointNetEncoder, self).__init__()
|
||||
self.out_dim = out_dim
|
||||
self.feature_transform = feature_transform
|
||||
self.stn = STNkd(k=in_dim)
|
||||
self.conv1 = torch.nn.Conv1d(in_dim, 64, 1)
|
||||
self.conv2 = torch.nn.Conv1d(64, 128, 1)
|
||||
self.conv3 = torch.nn.Conv1d(128, 512, 1)
|
||||
self.conv4 = torch.nn.Conv1d(512, out_dim, 1)
|
||||
self.global_feat = global_feat
|
||||
if self.feature_transform:
|
||||
self.f_stn = STNkd(k=64)
|
||||
|
||||
def forward(self, x):
|
||||
n_pts = x.shape[2]
|
||||
trans = self.stn(x)
|
||||
x = x.transpose(2, 1)
|
||||
x = torch.bmm(x, trans)
|
||||
x = x.transpose(2, 1)
|
||||
x = F.relu(self.conv1(x))
|
||||
|
||||
if self.feature_transform:
|
||||
trans_feat = self.f_stn(x)
|
||||
x = x.transpose(2, 1)
|
||||
x = torch.bmm(x, trans_feat)
|
||||
x = x.transpose(2, 1)
|
||||
|
||||
point_feat = x
|
||||
x = F.relu(self.conv2(x))
|
||||
x = F.relu(self.conv3(x))
|
||||
x = self.conv4(x)
|
||||
x = torch.max(x, 2, keepdim=True)[0]
|
||||
x = x.view(-1, self.out_dim)
|
||||
if self.global_feat:
|
||||
return x
|
||||
else:
|
||||
x = x.view(-1, self.out_dim, 1).repeat(1, 1, n_pts)
|
||||
return torch.cat([x, point_feat], 1)
|
||||
|
||||
def encode_points(self, pts):
|
||||
pts = pts.transpose(2, 1)
|
||||
if not self.global_feat:
|
||||
pts_feature = self(pts).transpose(2, 1)
|
||||
else:
|
||||
pts_feature = self(pts)
|
||||
return pts_feature
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sim_data = Variable(torch.rand(32, 2500, 3))
|
||||
|
||||
pointnet_global = PointNetEncoder(global_feat=True)
|
||||
out = pointnet_global.encode_points(sim_data)
|
||||
print("global feat", out.size())
|
||||
|
||||
pointnet = PointNetEncoder(global_feat=False)
|
||||
out = pointnet.encode_points(sim_data)
|
||||
print("point feat", out.size())
|
12
modules/view_finder/abstract_view_finder.py
Normal file
12
modules/view_finder/abstract_view_finder.py
Normal file
@ -0,0 +1,12 @@
|
||||
from abc import abstractmethod
|
||||
|
||||
from torch import nn
|
||||
|
||||
|
||||
class ViewFinder(nn.Module):
|
||||
def __init__(self):
|
||||
super(ViewFinder, self).__init__()
|
||||
|
||||
@abstractmethod
|
||||
def next_best_view(self, scene_pts_feat, target_pts_feat):
|
||||
pass
|
168
modules/view_finder/gf_view_finder.py
Normal file
168
modules/view_finder/gf_view_finder.py
Normal file
@ -0,0 +1,168 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import PytorchBoot.stereotype as stereotype
|
||||
|
||||
from utils.pose import PoseUtil
|
||||
from modules.view_finder.abstract_view_finder import ViewFinder
|
||||
import modules.module_lib as mlib
|
||||
import modules.func_lib as flib
|
||||
|
||||
|
||||
def zero_module(module):
|
||||
"""
|
||||
Zero out the parameters of a module and return it.
|
||||
"""
|
||||
for p in module.parameters():
|
||||
p.detach().zero_()
|
||||
return module
|
||||
|
||||
|
||||
@stereotype.module("gf_view_finder")
|
||||
class GradientFieldViewFinder(ViewFinder):
|
||||
def __init__(self, pose_mode='rot_matrix', regression_head='Rx_Ry', per_point_feature=False,
|
||||
sample_mode="ode", sampling_steps=None, sde_mode="ve"):
|
||||
|
||||
super(GradientFieldViewFinder, self).__init__()
|
||||
self.regression_head = regression_head
|
||||
self.per_point_feature = per_point_feature
|
||||
self.act = nn.ReLU(True)
|
||||
self.sample_mode = sample_mode
|
||||
self.pose_mode = pose_mode
|
||||
pose_dim = PoseUtil.get_pose_dim(pose_mode)
|
||||
self.prior_fn, self.marginal_prob_fn, self.sde_fn, self.sampling_eps, self.T = flib.init_sde(sde_mode)
|
||||
self.sampling_steps = sampling_steps
|
||||
|
||||
''' encode pose '''
|
||||
self.pose_encoder = nn.Sequential(
|
||||
nn.Linear(pose_dim, 256),
|
||||
self.act,
|
||||
nn.Linear(256, 256),
|
||||
self.act,
|
||||
)
|
||||
|
||||
''' encode t '''
|
||||
self.t_encoder = nn.Sequential(
|
||||
mlib.GaussianFourierProjection(embed_dim=128),
|
||||
nn.Linear(128, 128),
|
||||
self.act,
|
||||
)
|
||||
|
||||
''' fusion tail '''
|
||||
if self.regression_head == 'Rx_Ry':
|
||||
if pose_mode != 'rot_matrix':
|
||||
raise NotImplementedError
|
||||
if not per_point_feature:
|
||||
''' rotation_x_axis regress head '''
|
||||
self.fusion_tail_rot_x = nn.Sequential(
|
||||
nn.Linear(128 + 256 + 1024 + 1024, 256),
|
||||
self.act,
|
||||
zero_module(nn.Linear(256, 3)),
|
||||
)
|
||||
self.fusion_tail_rot_y = nn.Sequential(
|
||||
nn.Linear(128 + 256 + 1024 + 1024, 256),
|
||||
self.act,
|
||||
zero_module(nn.Linear(256, 3)),
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, data):
|
||||
"""
|
||||
Args:
|
||||
data, dict {
|
||||
'target_pts_feat': [bs, c]
|
||||
'scene_pts_feat': [bs, c]
|
||||
'pose_sample': [bs, pose_dim]
|
||||
't': [bs, 1]
|
||||
}
|
||||
"""
|
||||
|
||||
scene_pts_feat = data['scene_feat']
|
||||
target_pts_feat = data['target_feat']
|
||||
sampled_pose = data['sampled_pose']
|
||||
t = data['t']
|
||||
t_feat = self.t_encoder(t.squeeze(1))
|
||||
pose_feat = self.pose_encoder(sampled_pose)
|
||||
|
||||
if self.per_point_feature:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
total_feat = torch.cat([scene_pts_feat, target_pts_feat, t_feat, pose_feat], dim=-1)
|
||||
_, std = self.marginal_prob_fn(total_feat, t)
|
||||
|
||||
if self.regression_head == 'Rx_Ry':
|
||||
rot_x = self.fusion_tail_rot_x(total_feat)
|
||||
rot_y = self.fusion_tail_rot_y(total_feat)
|
||||
out_score = torch.cat([rot_x, rot_y], dim=-1) / (std + 1e-7) # normalisation
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return out_score
|
||||
|
||||
def marginal_prob(self, x, t):
|
||||
return self.marginal_prob_fn(x,t)
|
||||
|
||||
def sample(self, data, atol=1e-5, rtol=1e-5, snr=0.16, denoise=True, init_x=None, T0=None):
|
||||
|
||||
if self.sample_mode == 'pc':
|
||||
in_process_sample, res = flib.cond_pc_sampler(
|
||||
score_model=self,
|
||||
data=data,
|
||||
prior=self.prior_fn,
|
||||
sde_coeff=self.sde_fn,
|
||||
num_steps=self.sampling_steps,
|
||||
snr=snr,
|
||||
eps=self.sampling_eps,
|
||||
pose_mode=self.pose_mode,
|
||||
init_x=init_x
|
||||
)
|
||||
|
||||
elif self.sample_mode == 'ode':
|
||||
T0 = self.T if T0 is None else T0
|
||||
in_process_sample, res = flib.cond_ode_sampler(
|
||||
score_model=self,
|
||||
data=data,
|
||||
prior=self.prior_fn,
|
||||
sde_coeff=self.sde_fn,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
eps=self.sampling_eps,
|
||||
T=T0,
|
||||
num_steps=self.sampling_steps,
|
||||
pose_mode=self.pose_mode,
|
||||
denoise=denoise,
|
||||
init_x=init_x
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return in_process_sample, res
|
||||
|
||||
def next_best_view(self, scene_pts_feat, target_pts_feat):
|
||||
data = {
|
||||
'scene_feat': scene_pts_feat,
|
||||
'target_feat': target_pts_feat,
|
||||
}
|
||||
in_process_sample, res = self.sample(data)
|
||||
return res.to(dtype=torch.float32), in_process_sample
|
||||
|
||||
|
||||
''' ----------- DEBUG -----------'''
|
||||
if __name__ == "__main__":
|
||||
test_scene_feat = torch.rand(32, 1024).to("cuda:0")
|
||||
test_target_feat = torch.rand(32, 1024).to("cuda:0")
|
||||
test_pose = torch.rand(32, 6).to("cuda:0")
|
||||
test_t = torch.rand(32, 1).to("cuda:0")
|
||||
view_finder = GradientFieldViewFinder().to("cuda:0")
|
||||
test_data = {
|
||||
'target_feat': test_target_feat,
|
||||
'scene_feat': test_scene_feat,
|
||||
'sampled_pose': test_pose,
|
||||
't': test_t
|
||||
}
|
||||
score = view_finder(test_data)
|
||||
|
||||
result = view_finder.next_best_view(test_scene_feat, test_target_feat)
|
||||
print(result)
|
@ -1,32 +0,0 @@
|
||||
import torch.optim as optim
|
||||
|
||||
|
||||
class OptimizerFactory:
|
||||
@staticmethod
|
||||
def create(config, params):
|
||||
optim_type = config["type"]
|
||||
lr = config.get("lr", 1e-3)
|
||||
if optim_type == "sgd":
|
||||
return optim.SGD(
|
||||
params,
|
||||
lr=lr,
|
||||
momentum=config.get("momentum", 0.9),
|
||||
weight_decay=config.get("weight_decay", 1e-4),
|
||||
)
|
||||
elif optim_type == "adam":
|
||||
return optim.Adam(
|
||||
params,
|
||||
lr=lr,
|
||||
betas=config.get("betas", (0.9, 0.999)),
|
||||
eps=config.get("eps", 1e-8),
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Unknown optimizers: {}".format(optim_type))
|
||||
|
||||
|
||||
""" ------------ Debug ------------ """
|
||||
if __name__ == "__main__":
|
||||
from configs.config import ConfigManager
|
||||
|
||||
ConfigManager.load_config_with("../configs/local_train_config.yaml")
|
||||
ConfigManager.print_config()
|
@ -1,59 +0,0 @@
|
||||
import os
|
||||
import time
|
||||
from abc import abstractmethod, ABC
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from configs.config import ConfigManager
|
||||
|
||||
class Runner(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, config_path):
|
||||
ConfigManager.load_config_with(config_path)
|
||||
ConfigManager.print_config()
|
||||
seed = ConfigManager.get("settings", "general", "seed")
|
||||
self.device = ConfigManager.get("settings", "general", "device")
|
||||
self.cuda_visible_devices = ConfigManager.get("settings","general","cuda_visible_devices")
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = self.cuda_visible_devices
|
||||
self.experiments_config = ConfigManager.get("settings", "experiment")
|
||||
self.experiment_path = os.path.join(self.experiments_config["root_dir"], self.experiments_config["name"])
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
lt = time.localtime()
|
||||
self.file_name = f"{lt.tm_year}_{lt.tm_mon}_{lt.tm_mday}_{lt.tm_hour}h{lt.tm_min}m{lt.tm_sec}s"
|
||||
|
||||
@abstractmethod
|
||||
def run(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_experiment(self, backup_name=None):
|
||||
if not os.path.exists(self.experiment_path):
|
||||
print(f"experiments environment {self.experiments_config['name']} does not exists.")
|
||||
self.create_experiment(backup_name)
|
||||
else:
|
||||
print(f"experiments environment {self.experiments_config['name']}")
|
||||
backup_config_dir = os.path.join(str(self.experiment_path), "configs")
|
||||
if not os.path.exists(backup_config_dir):
|
||||
os.makedirs(backup_config_dir)
|
||||
ConfigManager.backup_config_to(backup_config_dir, self.file_name, backup_name)
|
||||
|
||||
@abstractmethod
|
||||
def create_experiment(self, backup_name=None):
|
||||
print("creating experiment: " + self.experiments_config["name"])
|
||||
os.makedirs(self.experiment_path)
|
||||
backup_config_dir = os.path.join(str(self.experiment_path), "configs")
|
||||
os.makedirs(backup_config_dir)
|
||||
ConfigManager.backup_config_to(backup_config_dir, self.file_name, backup_name)
|
||||
log_dir = os.path.join(str(self.experiment_path), "log")
|
||||
os.makedirs(log_dir)
|
||||
cache_dir = os.path.join(str(self.experiment_path), "cache")
|
||||
os.makedirs(cache_dir)
|
||||
|
||||
def print_info(self):
|
||||
table_size = 80
|
||||
print("+" + "-" * table_size + "+")
|
||||
print(f"| Experiment <{self.experiments_config['name']}>")
|
||||
print("+" + "-" * table_size + "+")
|
73
runners/strategy_generator.py
Normal file
73
runners/strategy_generator.py
Normal file
@ -0,0 +1,73 @@
|
||||
import os
|
||||
from PytorchBoot.runners.runner import Runner
|
||||
from PytorchBoot.config import ConfigManager
|
||||
import PytorchBoot.stereotype as stereotype
|
||||
|
||||
@stereotype.runner("strategy_generator")
|
||||
class StrategyGenerator(Runner):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.load_experiment("generate")
|
||||
|
||||
def run(self):
|
||||
self.demo(seq=16,num=100)
|
||||
|
||||
def create_experiment(self, backup_name=None):
|
||||
super().create_experiment(backup_name)
|
||||
output_dir = os.path.join(str(self.experiment_path), "output")
|
||||
os.makedirs(output_dir)
|
||||
|
||||
def load_experiment(self, backup_name=None):
|
||||
super().load_experiment(backup_name)
|
||||
|
||||
def demo(self, seq, num=100):
|
||||
import os
|
||||
from utils.data_load import DataLoadUtil
|
||||
from utils.reconstruction import ReconstructionUtil
|
||||
import numpy as np
|
||||
|
||||
component = self.config["generate"][0]["component"] #r"C:\Document\Local Project\nbv_rec\output"
|
||||
data_dir = ConfigManager.get("datasets", "components", component, "root_dir")
|
||||
model_path = os.path.join(data_dir, f"sequence.{seq}\\world_points.txt")
|
||||
model_pts = np.loadtxt(model_path)
|
||||
|
||||
output_dir = os.path.join(str(self.experiment_path), "output")
|
||||
pts_list = []
|
||||
for idx in range(0,num):
|
||||
path = DataLoadUtil.get_path(data_dir, seq, idx)
|
||||
point_cloud = DataLoadUtil.get_point_cloud_world_from_path(path)
|
||||
|
||||
sampled_point_cloud = ReconstructionUtil.downsample_point_cloud(point_cloud, 0.005)
|
||||
pts_list.append(sampled_point_cloud)
|
||||
|
||||
sampled_model_pts = ReconstructionUtil.downsample_point_cloud(model_pts, 0.005)
|
||||
np.savetxt(os.path.join(output_dir,"sampled_model_points.txt"), sampled_model_pts)
|
||||
thre = 0.005
|
||||
|
||||
useful_view, useless_view = ReconstructionUtil.compute_next_best_view_sequence(model_pts, pts_list, threshold=thre)
|
||||
print("useful:", useful_view)
|
||||
print("useless:", useless_view)
|
||||
|
||||
selected_full_views = ReconstructionUtil.combine_point_with_view_sequence(pts_list, useful_view)
|
||||
downsampled_selected_full_views = ReconstructionUtil.downsample_point_cloud(selected_full_views, thre)
|
||||
np.savetxt(os.path.join(output_dir,"selected_full_views.txt"), downsampled_selected_full_views)
|
||||
|
||||
limited_useful_view, limited_useless_view = ReconstructionUtil.compute_next_best_view_sequence_with_overlap(model_pts, pts_list, threshold=thre, overlap_threshold=0.3)
|
||||
print("limited_useful:", limited_useful_view)
|
||||
print("limited_useless:", limited_useless_view)
|
||||
|
||||
limited_selected_full_views = ReconstructionUtil.combine_point_with_view_sequence(pts_list, limited_useful_view)
|
||||
downsampled_limited_selected_full_views = ReconstructionUtil.downsample_point_cloud(limited_selected_full_views, thre)
|
||||
np.savetxt(os.path.join(output_dir,"selected_full_views_limited.txt"), downsampled_limited_selected_full_views)
|
||||
import json
|
||||
for idx, score in limited_useful_view:
|
||||
path = DataLoadUtil.get_path(data_dir, seq, idx)
|
||||
point_cloud = DataLoadUtil.get_point_cloud_world_from_path(path)
|
||||
print("saving useful view: ", idx, " | score: ", score)
|
||||
np.savetxt(os.path.join(output_dir,f"useful_view_{idx}.txt"), point_cloud)
|
||||
with open(os.path.join(output_dir,f"useful_view.json"), 'w') as f:
|
||||
json.dump(limited_useful_view, f)
|
||||
print("seq length: ", len(useful_view), "limited seq length: ", len(limited_useful_view))
|
||||
|
||||
|
||||
|
@ -1,261 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
path = os.path.abspath(__file__)
|
||||
for i in range(2):
|
||||
path = os.path.dirname(path)
|
||||
PROJECT_ROOT = path
|
||||
sys.path.append(PROJECT_ROOT)
|
||||
|
||||
from configs.config import ConfigManager
|
||||
from datasets.dataset_factory import DatasetFactory
|
||||
from optimizers.optimizer_factory import OptimizerFactory
|
||||
from evaluations.eval_function_factory import EvalFunctionFactory
|
||||
from losses.loss_function_factory import LossFunctionFactory
|
||||
from modules.pipeline import Pipeline
|
||||
from runners.runner import Runner
|
||||
from utils.tensorboard_util import TensorboardWriter
|
||||
from annotations.external_module import EXTERNAL_FREEZE_MODULES
|
||||
|
||||
|
||||
class Trainer(Runner):
|
||||
CHECKPOINT_DIR_NAME: str = 'checkpoints'
|
||||
TENSORBOARD_DIR_NAME: str = 'tensorboard'
|
||||
LOG_DIR_NAME: str = 'log'
|
||||
|
||||
def __init__(self, config_path):
|
||||
super().__init__(config_path)
|
||||
tensorboard_path = os.path.join(self.experiment_path, Trainer.TENSORBOARD_DIR_NAME)
|
||||
|
||||
''' Pipeline '''
|
||||
self.pipeline_config = ConfigManager.get("settings", "pipeline")
|
||||
self.parallel = ConfigManager.get("settings","general","parallel")
|
||||
self.pipeline = Pipeline(self.pipeline_config)
|
||||
if self.parallel and self.device == "cuda":
|
||||
self.pipeline = torch.nn.DataParallel(self.pipeline)
|
||||
self.pipeline = self.pipeline.to(self.device)
|
||||
|
||||
''' Experiment '''
|
||||
self.current_epoch = 0
|
||||
self.max_epochs = self.experiments_config["max_epochs"]
|
||||
self.test_first = self.experiments_config["test_first"]
|
||||
self.load_experiment("train")
|
||||
|
||||
''' Train '''
|
||||
self.train_config = ConfigManager.get("settings", "train")
|
||||
self.train_dataset_config = self.train_config["dataset"]
|
||||
self.train_set = DatasetFactory.create(self.train_dataset_config)
|
||||
self.optimizer = OptimizerFactory.create(self.train_config["optimizer"], self.pipeline.parameters())
|
||||
self.train_writer = SummaryWriter(
|
||||
log_dir=os.path.join(tensorboard_path, f"[train]{self.train_dataset_config['name']}"))
|
||||
|
||||
''' Test '''
|
||||
self.test_config = ConfigManager.get("settings", "test")
|
||||
self.test_dataset_config_list = self.test_config["dataset_list"]
|
||||
self.test_set_list = []
|
||||
self.test_writer_list = []
|
||||
seen_name = set()
|
||||
for test_dataset_config in self.test_dataset_config_list:
|
||||
if test_dataset_config["name"] not in seen_name:
|
||||
seen_name.add(test_dataset_config["name"])
|
||||
else:
|
||||
raise ValueError("Duplicate test dataset name: {}".format(test_dataset_config["name"]))
|
||||
test_set = DatasetFactory.create(test_dataset_config)
|
||||
test_writer = SummaryWriter(
|
||||
log_dir=os.path.join(tensorboard_path, f"[test]{test_dataset_config['name']}"))
|
||||
self.test_set_list.append(test_set)
|
||||
self.test_writer_list.append(test_writer)
|
||||
del seen_name
|
||||
|
||||
self.print_info()
|
||||
|
||||
def run(self):
|
||||
save_interval = self.experiments_config["save_checkpoint_interval"]
|
||||
if self.current_epoch != 0:
|
||||
print("Continue training from epoch {}.".format(self.current_epoch))
|
||||
else:
|
||||
print("Start training from initial model.")
|
||||
if self.test_first:
|
||||
print("Do test first.")
|
||||
self.test()
|
||||
while self.current_epoch < self.max_epochs:
|
||||
self.current_epoch += 1
|
||||
self.train()
|
||||
self.test()
|
||||
if self.current_epoch % save_interval == 0:
|
||||
self.save_checkpoint()
|
||||
self.save_checkpoint(is_last=True)
|
||||
|
||||
def train(self):
|
||||
self.pipeline.train()
|
||||
train_set_name = self.train_dataset_config["name"]
|
||||
ratio = self.train_dataset_config["ratio"]
|
||||
train_loader = self.train_set.get_loader(device="cuda", shuffle=True)
|
||||
|
||||
loop = tqdm(enumerate(train_loader), total=len(train_loader))
|
||||
loader_length = len(train_loader)
|
||||
for i, data in loop:
|
||||
self.train_set.process_batch(data, self.device)
|
||||
loss_dict = self.train_step(data)
|
||||
loop.set_description(
|
||||
f'Epoch [{self.current_epoch}/{self.max_epochs}] (Train: {train_set_name}, ratio={ratio})')
|
||||
loop.set_postfix(loss=loss_dict)
|
||||
curr_iters = (self.current_epoch - 1) * loader_length + i
|
||||
TensorboardWriter.write_tensorboard(self.train_writer, "iter", loss_dict, curr_iters)
|
||||
|
||||
def train_step(self, data):
|
||||
self.optimizer.zero_grad()
|
||||
output = self.pipeline(data, Pipeline.TRAIN_MODE)
|
||||
total_loss, loss_dict = self.loss_fn(output, data)
|
||||
total_loss.backward()
|
||||
self.optimizer.step()
|
||||
for k, v in loss_dict.items():
|
||||
loss_dict[k] = round(v, 5)
|
||||
return loss_dict
|
||||
|
||||
def loss_fn(self, output, data):
|
||||
loss_config = self.train_config["losses"]
|
||||
loss_dict = {}
|
||||
total_loss = torch.tensor(0.0, dtype=torch.float32, device=self.device)
|
||||
for key in loss_config:
|
||||
weight = loss_config[key]
|
||||
target_loss_fn = LossFunctionFactory.create(key)
|
||||
loss = target_loss_fn(output, data)
|
||||
loss_dict[key] = loss.item()
|
||||
total_loss += weight * loss
|
||||
|
||||
loss_dict['total_loss'] = total_loss.item()
|
||||
return total_loss, loss_dict
|
||||
|
||||
def test(self):
|
||||
self.pipeline.eval()
|
||||
with torch.no_grad():
|
||||
for dataset_idx, test_set in enumerate(self.test_set_list):
|
||||
eval_list = self.test_dataset_config_list[dataset_idx]["eval_list"]
|
||||
test_set_name = self.test_dataset_config_list[dataset_idx]["name"]
|
||||
ratio = self.test_dataset_config_list[dataset_idx]["ratio"]
|
||||
writer = self.test_writer_list[dataset_idx]
|
||||
output_list = []
|
||||
data_list = []
|
||||
test_loader = test_set.get_loader("cpu")
|
||||
loop = tqdm(enumerate(test_loader), total=int(len(test_loader)))
|
||||
for i, data in loop:
|
||||
test_set.process_batch(data, self.device)
|
||||
output = self.pipeline(data, Pipeline.TEST_MODE)
|
||||
output_list.append(output)
|
||||
data_list.append(data)
|
||||
loop.set_description(
|
||||
f'Epoch [{self.current_epoch}/{self.max_epochs}] (Test: {test_set_name}, ratio={ratio})')
|
||||
result_dict = self.eval_fn(output_list, data_list, eval_list)
|
||||
TensorboardWriter.write_tensorboard(writer, "epoch", result_dict, self.current_epoch - 1)
|
||||
|
||||
@staticmethod
|
||||
def eval_fn(output_list, data_list, eval_list):
|
||||
target_eval_fn = EvalFunctionFactory.create(eval_list)
|
||||
result_dict = target_eval_fn(output_list, data_list)
|
||||
return result_dict
|
||||
|
||||
def get_checkpoint_path(self, is_last=False):
|
||||
return os.path.join(self.experiment_path, Trainer.CHECKPOINT_DIR_NAME,
|
||||
"Epoch_{}.pth".format(
|
||||
self.current_epoch if self.current_epoch != -1 and not is_last else "last"))
|
||||
|
||||
def load_checkpoint(self, is_last=False):
|
||||
self.load(self.get_checkpoint_path(is_last))
|
||||
print(f"Loaded checkpoint from {self.get_checkpoint_path(is_last)}")
|
||||
if is_last:
|
||||
checkpoint_root = os.path.join(self.experiment_path, Trainer.CHECKPOINT_DIR_NAME)
|
||||
meta_path = os.path.join(checkpoint_root, "meta.json")
|
||||
if not os.path.exists(meta_path):
|
||||
raise FileNotFoundError(
|
||||
"No checkpoint meta.json file in the experiment {}".format(self.experiments_config["name"]))
|
||||
file_path = os.path.join(checkpoint_root, "meta.json")
|
||||
with open(file_path, "r") as f:
|
||||
meta = json.load(f)
|
||||
self.current_epoch = meta["last_epoch"]
|
||||
|
||||
def save_checkpoint(self, is_last=False):
|
||||
self.save(self.get_checkpoint_path(is_last))
|
||||
if not is_last:
|
||||
print(f"Checkpoint at epoch {self.current_epoch} saved to {self.get_checkpoint_path(is_last)}")
|
||||
else:
|
||||
meta = {
|
||||
"last_epoch": self.current_epoch,
|
||||
"time": str(datetime.now())
|
||||
}
|
||||
checkpoint_root = os.path.join(self.experiment_path, Trainer.CHECKPOINT_DIR_NAME)
|
||||
file_path = os.path.join(checkpoint_root, "meta.json")
|
||||
with open(file_path, "w") as f:
|
||||
json.dump(meta, f)
|
||||
|
||||
def load_experiment(self, backup_name=None):
|
||||
super().load_experiment(backup_name)
|
||||
if self.experiments_config["use_checkpoint"]:
|
||||
self.current_epoch = self.experiments_config["epoch"]
|
||||
self.load_checkpoint(is_last=(self.current_epoch == -1))
|
||||
|
||||
def create_experiment(self, backup_name=None):
|
||||
super().create_experiment(backup_name)
|
||||
ckpt_dir = os.path.join(str(self.experiment_path), Trainer.CHECKPOINT_DIR_NAME)
|
||||
os.makedirs(ckpt_dir)
|
||||
tensorboard_dir = os.path.join(str(self.experiment_path), Trainer.TENSORBOARD_DIR_NAME)
|
||||
os.makedirs(tensorboard_dir)
|
||||
|
||||
def load(self, path):
|
||||
state_dict = torch.load(path)
|
||||
if self.parallel:
|
||||
self.pipeline.module.load_state_dict(state_dict)
|
||||
else:
|
||||
self.pipeline.load_state_dict(state_dict)
|
||||
|
||||
def save(self, path):
|
||||
if self.parallel:
|
||||
state_dict = self.pipeline.module.state_dict()
|
||||
else:
|
||||
state_dict = self.pipeline.state_dict()
|
||||
|
||||
for name, module in self.pipeline.named_modules():
|
||||
if module.__class__ in EXTERNAL_FREEZE_MODULES:
|
||||
if name in state_dict:
|
||||
del state_dict[name]
|
||||
|
||||
torch.save(state_dict, path)
|
||||
|
||||
|
||||
def print_info(self):
|
||||
def print_dataset(config, dataset):
|
||||
print("\t name: {}".format(config["name"]))
|
||||
print("\t source: {}".format(config["source"]))
|
||||
print("\t data_type: {}".format(config["data_type"]))
|
||||
print("\t total_length: {}".format(len(dataset)))
|
||||
print("\t ratio: {}".format(config["ratio"]))
|
||||
print()
|
||||
|
||||
super().print_info()
|
||||
table_size = 70
|
||||
print(f"{'+' + '-' * (table_size // 2)} Pipeline {'-' * (table_size // 2)}" + '+')
|
||||
print(self.pipeline)
|
||||
print(f"{'+' + '-' * (table_size // 2)} Datasets {'-' * (table_size // 2)}" + '+')
|
||||
print("train dataset: ")
|
||||
print_dataset(self.train_dataset_config, self.train_set)
|
||||
for i, test_dataset_config in enumerate(self.test_dataset_config_list):
|
||||
print(f"test dataset {i}: ")
|
||||
print_dataset(test_dataset_config, self.test_set_list[i])
|
||||
|
||||
print(f"{'+' + '-' * (table_size // 2)}----------{'-' * (table_size // 2)}" + '+')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--config", type=str, default="configs/train_config.yaml")
|
||||
args = parser.parse_args()
|
||||
trainer = Trainer(args.config)
|
||||
trainer.run()
|
135
utils/data_load.py
Normal file
135
utils/data_load.py
Normal file
@ -0,0 +1,135 @@
|
||||
import os
|
||||
import OpenEXR
|
||||
import Imath
|
||||
import numpy as np
|
||||
import json
|
||||
import cv2
|
||||
|
||||
class DataLoadUtil:
|
||||
|
||||
@staticmethod
|
||||
def get_path(root, scene_idx, frame_idx):
|
||||
path = os.path.join(root, f"sequence.{scene_idx}", f"step{frame_idx}")
|
||||
return path
|
||||
|
||||
@staticmethod
|
||||
def read_exr_depth(depth_path):
|
||||
file = OpenEXR.InputFile(depth_path)
|
||||
|
||||
dw = file.header()['dataWindow']
|
||||
width = dw.max.x - dw.min.x + 1
|
||||
height = dw.max.y - dw.min.y + 1
|
||||
|
||||
pix_type = Imath.PixelType(Imath.PixelType.FLOAT)
|
||||
depth_map = np.frombuffer(file.channel('R', pix_type), dtype=np.float32)
|
||||
|
||||
depth_map.shape = (height, width)
|
||||
|
||||
return depth_map
|
||||
|
||||
@staticmethod
|
||||
def load_depth(path):
|
||||
depth_path = path + ".camera.Depth.exr"
|
||||
depth_map = DataLoadUtil.read_exr_depth(depth_path)
|
||||
return depth_map
|
||||
|
||||
@staticmethod
|
||||
def load_rgb(path):
|
||||
rgb_path = path + ".camera.png"
|
||||
rgb_image = cv2.imread(rgb_path, cv2.IMREAD_COLOR)
|
||||
return rgb_image
|
||||
|
||||
@staticmethod
|
||||
def load_seg(path):
|
||||
seg_path = path + ".camera.semantic segmentation.png"
|
||||
seg_image = cv2.imread(seg_path, cv2.IMREAD_COLOR)
|
||||
return seg_image
|
||||
|
||||
@staticmethod
|
||||
def load_cam_info(path):
|
||||
label_path = path + ".camera_params.json"
|
||||
with open(label_path, 'r') as f:
|
||||
label_data = json.load(f)
|
||||
cam_transform = np.asarray(label_data['cam_to_world']).reshape(
|
||||
(4, 4)
|
||||
).T
|
||||
|
||||
offset = np.asarray([
|
||||
[1, 0, 0, 0],
|
||||
[0, -1, 0, 0],
|
||||
[0, 0, 1, 0],
|
||||
[0, 0, 0, 1]])
|
||||
|
||||
cam_to_world = cam_transform @ offset
|
||||
|
||||
|
||||
|
||||
f_x = label_data['f_x']
|
||||
f_y = label_data['f_y']
|
||||
c_x = label_data['c_x']
|
||||
c_y = label_data['c_y']
|
||||
cam_intrinsic = np.array([[f_x, 0, c_x], [0, f_y, c_y], [0, 0, 1]])
|
||||
|
||||
return {
|
||||
"cam_to_world": cam_to_world,
|
||||
"cam_intrinsic": cam_intrinsic
|
||||
}
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_target_point_cloud(depth, cam_intrinsic, cam_extrinsic, mask, target_mask_label=(255,255,255)):
|
||||
h, w = depth.shape
|
||||
i, j = np.meshgrid(np.arange(w), np.arange(h), indexing='xy')
|
||||
|
||||
z = depth
|
||||
x = (i - cam_intrinsic[0, 2]) * z / cam_intrinsic[0, 0]
|
||||
y = (j - cam_intrinsic[1, 2]) * z / cam_intrinsic[1, 1]
|
||||
|
||||
points_camera = np.stack((x, y, z), axis=-1).reshape(-1, 3)
|
||||
points_camera_aug = np.concatenate([points_camera, np.ones((points_camera.shape[0], 1))], axis=-1)
|
||||
|
||||
points_world = np.dot(cam_extrinsic, points_camera_aug.T).T[:, :3]
|
||||
mask = mask.reshape(-1, 3)
|
||||
target_mask = np.all(mask == target_mask_label, axis=-1)
|
||||
return {
|
||||
"points_world": points_world[target_mask],
|
||||
"points_camera": points_camera[target_mask]
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_target_point_cloud(depth, cam_intrinsic, cam_extrinsic, mask, target_mask_label=(255,255,255)):
|
||||
h, w = depth.shape
|
||||
i, j = np.meshgrid(np.arange(w), np.arange(h), indexing='xy')
|
||||
|
||||
z = depth
|
||||
x = (i - cam_intrinsic[0, 2]) * z / cam_intrinsic[0, 0]
|
||||
y = (j - cam_intrinsic[1, 2]) * z / cam_intrinsic[1, 1]
|
||||
|
||||
points_camera = np.stack((x, y, z), axis=-1).reshape(-1, 3)
|
||||
points_camera_aug = np.concatenate([points_camera, np.ones((points_camera.shape[0], 1))], axis=-1)
|
||||
|
||||
points_world = np.dot(cam_extrinsic, points_camera_aug.T).T[:, :3]
|
||||
mask = mask.reshape(-1, 3)
|
||||
target_mask = np.all(mask == target_mask_label, axis=-1)
|
||||
return {
|
||||
"points_world": points_world[target_mask],
|
||||
"points_camera": points_camera[target_mask]
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_point_cloud_world_from_path(path):
|
||||
cam_info = DataLoadUtil.load_cam_info(path)
|
||||
depth = DataLoadUtil.load_depth(path)
|
||||
mask = DataLoadUtil.load_seg(path)
|
||||
point_cloud = DataLoadUtil.get_target_point_cloud(depth, cam_info['cam_intrinsic'], cam_info['cam_to_world'], mask)
|
||||
return point_cloud['points_world']
|
||||
|
||||
@staticmethod
|
||||
def get_point_cloud_list_from_seq(root, seq_idx, num_frames):
|
||||
point_cloud_list = []
|
||||
for idx in range(num_frames):
|
||||
path = DataLoadUtil.get_path(root, seq_idx, idx)
|
||||
point_cloud = DataLoadUtil.get_point_cloud_world_from_path(path)
|
||||
point_cloud_list.append(point_cloud)
|
||||
return point_cloud_list
|
||||
|
246
utils/pose.py
Normal file
246
utils/pose.py
Normal file
@ -0,0 +1,246 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class PoseUtil:
|
||||
ROTATION = 1
|
||||
TRANSLATION = 2
|
||||
SCALE = 3
|
||||
|
||||
@staticmethod
|
||||
def get_uniform_translation(trans_m_min, trans_m_max, trans_unit, debug=False):
|
||||
if isinstance(trans_m_min, list):
|
||||
x_min, y_min, z_min = trans_m_min
|
||||
x_max, y_max, z_max = trans_m_max
|
||||
else:
|
||||
x_min, y_min, z_min = trans_m_min, trans_m_min, trans_m_min
|
||||
x_max, y_max, z_max = trans_m_max, trans_m_max, trans_m_max
|
||||
|
||||
x = np.random.uniform(x_min, x_max)
|
||||
y = np.random.uniform(y_min, y_max)
|
||||
z = np.random.uniform(z_min, z_max)
|
||||
translation = np.array([x, y, z])
|
||||
if trans_unit == "cm":
|
||||
translation = translation / 100
|
||||
if debug:
|
||||
print("uniform translation:", translation)
|
||||
return translation
|
||||
|
||||
@staticmethod
|
||||
def get_uniform_rotation(rot_degree_min=0, rot_degree_max=180, debug=False):
|
||||
axis = np.random.randn(3)
|
||||
axis /= np.linalg.norm(axis)
|
||||
theta = np.random.uniform(
|
||||
rot_degree_min / 180 * np.pi, rot_degree_max / 180 * np.pi
|
||||
)
|
||||
|
||||
K = np.array(
|
||||
[[0, -axis[2], axis[1]], [axis[2], 0, -axis[0]], [-axis[1], axis[0], 0]]
|
||||
)
|
||||
R = np.eye(3) + np.sin(theta) * K + (1 - np.cos(theta)) * (K @ K)
|
||||
if debug:
|
||||
print("uniform rotation:", theta * 180 / np.pi)
|
||||
return R
|
||||
|
||||
@staticmethod
|
||||
def get_uniform_pose(
|
||||
trans_min, trans_max, rot_min=0, rot_max=180, trans_unit="cm", debug=False
|
||||
):
|
||||
translation = PoseUtil.get_uniform_translation(
|
||||
trans_min, trans_max, trans_unit, debug
|
||||
)
|
||||
rotation = PoseUtil.get_uniform_rotation(rot_min, rot_max, debug)
|
||||
pose = np.eye(4)
|
||||
pose[:3, :3] = rotation
|
||||
pose[:3, 3] = translation
|
||||
return pose
|
||||
|
||||
@staticmethod
|
||||
def get_n_uniform_pose(
|
||||
trans_min,
|
||||
trans_max,
|
||||
rot_min=0,
|
||||
rot_max=180,
|
||||
n=1,
|
||||
trans_unit="cm",
|
||||
fix=None,
|
||||
contain_canonical=True,
|
||||
debug=False,
|
||||
):
|
||||
if fix == PoseUtil.ROTATION:
|
||||
translations = np.zeros((n, 3))
|
||||
for i in range(n):
|
||||
translations[i] = PoseUtil.get_uniform_translation(
|
||||
trans_min, trans_max, trans_unit, debug
|
||||
)
|
||||
if contain_canonical:
|
||||
translations[0] = np.zeros(3)
|
||||
rotations = PoseUtil.get_uniform_rotation(rot_min, rot_max, debug)
|
||||
elif fix == PoseUtil.TRANSLATION:
|
||||
rotations = np.zeros((n, 3, 3))
|
||||
for i in range(n):
|
||||
rotations[i] = PoseUtil.get_uniform_rotation(rot_min, rot_max, debug)
|
||||
if contain_canonical:
|
||||
rotations[0] = np.eye(3)
|
||||
translations = PoseUtil.get_uniform_translation(
|
||||
trans_min, trans_max, trans_unit, debug
|
||||
)
|
||||
else:
|
||||
translations = np.zeros((n, 3))
|
||||
rotations = np.zeros((n, 3, 3))
|
||||
for i in range(n):
|
||||
translations[i] = PoseUtil.get_uniform_translation(
|
||||
trans_min, trans_max, trans_unit, debug
|
||||
)
|
||||
for i in range(n):
|
||||
rotations[i] = PoseUtil.get_uniform_rotation(rot_min, rot_max, debug)
|
||||
if contain_canonical:
|
||||
translations[0] = np.zeros(3)
|
||||
rotations[0] = np.eye(3)
|
||||
|
||||
pose = np.eye(4, 4, k=0)[np.newaxis, :].repeat(n, axis=0)
|
||||
pose[:, :3, :3] = rotations
|
||||
pose[:, :3, 3] = translations
|
||||
|
||||
return pose
|
||||
|
||||
@staticmethod
|
||||
def get_n_uniform_pose_batch(
|
||||
trans_min,
|
||||
trans_max,
|
||||
rot_min=0,
|
||||
rot_max=180,
|
||||
n=1,
|
||||
batch_size=1,
|
||||
trans_unit="cm",
|
||||
fix=None,
|
||||
contain_canonical=False,
|
||||
debug=False,
|
||||
):
|
||||
|
||||
batch_poses = []
|
||||
for i in range(batch_size):
|
||||
pose = PoseUtil.get_n_uniform_pose(
|
||||
trans_min,
|
||||
trans_max,
|
||||
rot_min,
|
||||
rot_max,
|
||||
n,
|
||||
trans_unit,
|
||||
fix,
|
||||
contain_canonical,
|
||||
debug,
|
||||
)
|
||||
batch_poses.append(pose)
|
||||
pose_batch = np.stack(batch_poses, axis=0)
|
||||
return pose_batch
|
||||
|
||||
@staticmethod
|
||||
def get_uniform_scale(scale_min, scale_max, debug=False):
|
||||
if isinstance(scale_min, list):
|
||||
x_min, y_min, z_min = scale_min
|
||||
x_max, y_max, z_max = scale_max
|
||||
else:
|
||||
x_min, y_min, z_min = scale_min, scale_min, scale_min
|
||||
x_max, y_max, z_max = scale_max, scale_max, scale_max
|
||||
|
||||
x = np.random.uniform(x_min, x_max)
|
||||
y = np.random.uniform(y_min, y_max)
|
||||
z = np.random.uniform(z_min, z_max)
|
||||
scale = np.array([x, y, z])
|
||||
if debug:
|
||||
print("uniform scale:", scale)
|
||||
return scale
|
||||
|
||||
@staticmethod
|
||||
def normalize_rotation(rotation, rotation_mode):
|
||||
if rotation_mode == "quat_wxyz" or rotation_mode == "quat_xyzw":
|
||||
rotation /= torch.norm(rotation, dim=-1, keepdim=True)
|
||||
elif rotation_mode == "rot_matrix":
|
||||
rot_matrix = PoseUtil.rotation_6d_to_matrix_tensor_batch(rotation)
|
||||
rotation[:, :3] = rot_matrix[:, 0, :]
|
||||
rotation[:, 3:6] = rot_matrix[:, 1, :]
|
||||
elif rotation_mode == "euler_xyz_sx_cx":
|
||||
rot_sin_theta = rotation[:, :3]
|
||||
rot_cos_theta = rotation[:, 3:6]
|
||||
theta = torch.atan2(rot_sin_theta, rot_cos_theta)
|
||||
rotation[:, :3] = torch.sin(theta)
|
||||
rotation[:, 3:6] = torch.cos(theta)
|
||||
elif rotation_mode == "euler_xyz":
|
||||
pass
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return rotation
|
||||
|
||||
@staticmethod
|
||||
def get_pose_dim(rot_mode):
|
||||
assert rot_mode in [
|
||||
"quat_wxyz",
|
||||
"quat_xyzw",
|
||||
"euler_xyz",
|
||||
"euler_xyz_sx_cx",
|
||||
"rot_matrix",
|
||||
], f"the rotation mode {rot_mode} is not supported!"
|
||||
|
||||
if rot_mode == "quat_wxyz" or rot_mode == "quat_xyzw":
|
||||
pose_dim = 4
|
||||
elif rot_mode == "euler_xyz":
|
||||
pose_dim = 3
|
||||
elif rot_mode == "euler_xyz_sx_cx" or rot_mode == "rot_matrix":
|
||||
pose_dim = 6
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return pose_dim
|
||||
|
||||
@staticmethod
|
||||
def rotation_6d_to_matrix_tensor_batch(d6: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
a1, a2 = d6[..., :3], d6[..., 3:]
|
||||
b1 = F.normalize(a1, dim=-1)
|
||||
b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
|
||||
b2 = F.normalize(b2, dim=-1)
|
||||
b3 = torch.cross(b1, b2, dim=-1)
|
||||
return torch.stack((b1, b2, b3), dim=-2)
|
||||
|
||||
@staticmethod
|
||||
def matrix_to_rotation_6d_tensor_batch(matrix: torch.Tensor) -> torch.Tensor:
|
||||
batch_dim = matrix.size()[:-2]
|
||||
return matrix[..., :2, :].clone().reshape(batch_dim + (6,))
|
||||
|
||||
@staticmethod
|
||||
def rotation_6d_to_matrix_numpy(d6):
|
||||
a1, a2 = d6[:3], d6[3:]
|
||||
b1 = a1 / np.linalg.norm(a1)
|
||||
b2 = a2 - np.dot(b1, a2) * b1
|
||||
b2 = b2 / np.linalg.norm(b2)
|
||||
b3 = np.cross(b1, b2)
|
||||
return np.stack((b1, b2, b3), axis=-2)
|
||||
|
||||
@staticmethod
|
||||
def matrix_to_rotation_6d_numpy(matrix):
|
||||
return np.copy(matrix[:2, :]).reshape((6,))
|
||||
|
||||
|
||||
""" ------------ Debug ------------ """
|
||||
|
||||
if __name__ == "__main__":
|
||||
for _ in range(1):
|
||||
PoseUtil.get_uniform_pose(
|
||||
trans_min=[-25, -25, 10],
|
||||
trans_max=[25, 25, 60],
|
||||
rot_min=0,
|
||||
rot_max=10,
|
||||
debug=True,
|
||||
)
|
||||
PoseUtil.get_uniform_scale(scale_min=0.25, scale_max=0.30, debug=True)
|
||||
PoseUtil.get_n_uniform_pose_batch(
|
||||
trans_min=[-25, -25, 10],
|
||||
trans_max=[25, 25, 60],
|
||||
rot_min=0,
|
||||
rot_max=10,
|
||||
batch_size=2,
|
||||
n=2,
|
||||
fix=PoseUtil.TRANSLATION,
|
||||
debug=True,
|
||||
)
|
139
utils/reconstruction.py
Normal file
139
utils/reconstruction.py
Normal file
@ -0,0 +1,139 @@
|
||||
import numpy as np
|
||||
import open3d as o3d
|
||||
from scipy.spatial import cKDTree
|
||||
|
||||
class ReconstructionUtil:
|
||||
|
||||
@staticmethod
|
||||
def compute_coverage_rate(target_point_cloud, combined_point_cloud, threshold=0.01):
|
||||
kdtree = cKDTree(combined_point_cloud)
|
||||
distances, _ = kdtree.query(target_point_cloud)
|
||||
covered_points = np.sum(distances < threshold)
|
||||
coverage_rate = covered_points / target_point_cloud.shape[0]
|
||||
return coverage_rate
|
||||
|
||||
@staticmethod
|
||||
def compute_overlap_rate(point_cloud1, point_cloud2, threshold=0.01):
|
||||
kdtree1 = cKDTree(point_cloud1)
|
||||
kdtree2 = cKDTree(point_cloud2)
|
||||
distances1, _ = kdtree2.query(point_cloud1)
|
||||
distances2, _ = kdtree1.query(point_cloud2)
|
||||
overlapping_points1 = np.sum(distances1 < threshold)
|
||||
overlapping_points2 = np.sum(distances2 < threshold)
|
||||
|
||||
overlap_rate1 = overlapping_points1 / point_cloud1.shape[0]
|
||||
overlap_rate2 = overlapping_points2 / point_cloud2.shape[0]
|
||||
|
||||
return (overlap_rate1 + overlap_rate2) / 2
|
||||
|
||||
@staticmethod
|
||||
def combine_point_with_view_sequence(point_list, view_sequence):
|
||||
selected_views = []
|
||||
for view_index, _ in view_sequence:
|
||||
selected_views.append(point_list[view_index])
|
||||
return np.vstack(selected_views)
|
||||
|
||||
@staticmethod
|
||||
def compute_next_view_coverage_list(views, combined_point_cloud, target_point_cloud, threshold=0.01):
|
||||
best_view = None
|
||||
best_coverage_increase = -1
|
||||
current_coverage = ReconstructionUtil.compute_coverage_rate(target_point_cloud, combined_point_cloud, threshold)
|
||||
|
||||
for view_index, view in enumerate(views):
|
||||
candidate_views = combined_point_cloud + [view]
|
||||
down_sampled_combined_point_cloud = ReconstructionUtil.downsample_point_cloud(candidate_views, threshold)
|
||||
new_coverage = ReconstructionUtil.compute_coverage_rate(target_point_cloud, down_sampled_combined_point_cloud, threshold)
|
||||
coverage_increase = new_coverage - current_coverage
|
||||
if coverage_increase > best_coverage_increase:
|
||||
best_coverage_increase = coverage_increase
|
||||
best_view = view_index
|
||||
return best_view, best_coverage_increase
|
||||
|
||||
@staticmethod
|
||||
def compute_next_best_view_sequence(target_point_cloud, point_cloud_list, threshold=0.01):
|
||||
selected_views = []
|
||||
current_coverage = 0.0
|
||||
remaining_views = list(range(len(point_cloud_list)))
|
||||
view_sequence = []
|
||||
target_point_cloud = ReconstructionUtil.downsample_point_cloud(target_point_cloud, threshold)
|
||||
while remaining_views:
|
||||
best_view = None
|
||||
best_coverage_increase = -1
|
||||
|
||||
for view_index in remaining_views:
|
||||
candidate_views = selected_views + [point_cloud_list[view_index]]
|
||||
combined_point_cloud = np.vstack(candidate_views)
|
||||
|
||||
down_sampled_combined_point_cloud = ReconstructionUtil.downsample_point_cloud(combined_point_cloud,threshold)
|
||||
new_coverage = ReconstructionUtil.compute_coverage_rate(target_point_cloud, down_sampled_combined_point_cloud, threshold)
|
||||
coverage_increase = new_coverage - current_coverage
|
||||
|
||||
if coverage_increase > best_coverage_increase:
|
||||
best_coverage_increase = coverage_increase
|
||||
best_view = view_index
|
||||
|
||||
if best_view is not None:
|
||||
if best_coverage_increase <=1e-3:
|
||||
break
|
||||
selected_views.append(point_cloud_list[best_view])
|
||||
current_coverage += best_coverage_increase
|
||||
view_sequence.append((best_view, current_coverage))
|
||||
remaining_views.remove(best_view)
|
||||
return view_sequence, remaining_views
|
||||
|
||||
|
||||
@staticmethod
|
||||
def compute_next_best_view_sequence_with_overlap(target_point_cloud, point_cloud_list, threshold=0.01, overlap_threshold=0.3):
|
||||
selected_views = []
|
||||
current_coverage = 0.0
|
||||
remaining_views = list(range(len(point_cloud_list)))
|
||||
view_sequence = []
|
||||
target_point_cloud = ReconstructionUtil.downsample_point_cloud(target_point_cloud, threshold)
|
||||
|
||||
while remaining_views:
|
||||
best_view = None
|
||||
best_coverage_increase = -1
|
||||
|
||||
for view_index in remaining_views:
|
||||
|
||||
if selected_views:
|
||||
combined_old_point_cloud = np.vstack(selected_views)
|
||||
down_sampled_old_point_cloud = ReconstructionUtil.downsample_point_cloud(combined_old_point_cloud,threshold)
|
||||
down_sampled_new_view_point_cloud = ReconstructionUtil.downsample_point_cloud(point_cloud_list[view_index],threshold)
|
||||
overlap_rate = ReconstructionUtil.compute_overlap_rate(down_sampled_old_point_cloud,down_sampled_new_view_point_cloud , threshold)
|
||||
if overlap_rate < overlap_threshold:
|
||||
continue
|
||||
|
||||
candidate_views = selected_views + [point_cloud_list[view_index]]
|
||||
combined_point_cloud = np.vstack(candidate_views)
|
||||
down_sampled_combined_point_cloud = ReconstructionUtil.downsample_point_cloud(combined_point_cloud,threshold)
|
||||
new_coverage = ReconstructionUtil.compute_coverage_rate(target_point_cloud, down_sampled_combined_point_cloud, threshold)
|
||||
coverage_increase = new_coverage - current_coverage
|
||||
#print(f"view_index: {view_index}, coverage_increase: {coverage_increase}")
|
||||
if coverage_increase > best_coverage_increase:
|
||||
best_coverage_increase = coverage_increase
|
||||
best_view = view_index
|
||||
|
||||
if best_view is not None:
|
||||
if best_coverage_increase <=1e-3:
|
||||
break
|
||||
selected_views.append(point_cloud_list[best_view])
|
||||
remaining_views.remove(best_view)
|
||||
if best_coverage_increase > 0:
|
||||
current_coverage += best_coverage_increase
|
||||
|
||||
view_sequence.append((best_view, current_coverage))
|
||||
|
||||
else:
|
||||
break
|
||||
|
||||
return view_sequence, remaining_views
|
||||
|
||||
|
||||
def downsample_point_cloud(point_cloud, voxel_size=0.005):
|
||||
o3d_pc = o3d.geometry.PointCloud()
|
||||
o3d_pc.points = o3d.utility.Vector3dVector(point_cloud)
|
||||
downsampled_pc = o3d_pc.voxel_down_sample(voxel_size)
|
||||
return np.asarray(downsampled_pc.points)
|
||||
|
||||
|
@ -1,47 +0,0 @@
|
||||
import torch
|
||||
|
||||
|
||||
class TensorboardWriter:
|
||||
@staticmethod
|
||||
def write_tensorboard(writer, panel, data_dict, step):
|
||||
complex_dict = False
|
||||
if "scalars" in data_dict:
|
||||
scalar_data_dict = data_dict["scalars"]
|
||||
TensorboardWriter.write_scalar_tensorboard(writer, panel, scalar_data_dict, step)
|
||||
complex_dict = True
|
||||
if "images" in data_dict:
|
||||
image_data_dict = data_dict["images"]
|
||||
TensorboardWriter.write_image_tensorboard(writer, panel, image_data_dict, step)
|
||||
complex_dict = True
|
||||
if "points" in data_dict:
|
||||
point_data_dict = data_dict["points"]
|
||||
TensorboardWriter.write_points_tensorboard(writer, panel, point_data_dict, step)
|
||||
complex_dict = True
|
||||
|
||||
if not complex_dict:
|
||||
TensorboardWriter.write_scalar_tensorboard(writer, panel, data_dict, step)
|
||||
|
||||
@staticmethod
|
||||
def write_scalar_tensorboard(writer, panel, data_dict, step):
|
||||
for key, value in data_dict.items():
|
||||
if isinstance(value, dict):
|
||||
writer.add_scalars(f'{panel}/{key}', value, step)
|
||||
else:
|
||||
writer.add_scalar(f'{panel}/{key}', value, step)
|
||||
|
||||
@staticmethod
|
||||
def write_image_tensorboard(writer, panel, data_dict, step):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def write_points_tensorboard(writer, panel, data_dict, step):
|
||||
for key, value in data_dict.items():
|
||||
if value.shape[-1] == 3:
|
||||
colors = torch.zeros_like(value)
|
||||
vertices = torch.cat([value, colors], dim=-1)
|
||||
elif value.shape[-1] == 6:
|
||||
vertices = value
|
||||
else:
|
||||
raise ValueError(f'Unexpected value shape: {value.shape}')
|
||||
faces = None
|
||||
writer.add_mesh(f'{panel}/{key}', vertices=vertices, faces=faces, global_step=step)
|
Loading…
x
Reference in New Issue
Block a user