import inspect

import torch
from torch import nn

from configs.config import ConfigManager
from modules.pts_encoder.pts_encoder_factory import PointsEncoderFactory
from modules.view_finder.view_finder_factory import ViewFinderFactory
from modules.module_lib.fusion_layer import FeatureFusion
from modules.rgb_encoder.rgb_encoder_factory import RGBEncoderFactory


class Pipeline(nn.Module):
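    """End-to-end next-best-view pipeline.

    Encodes target and scene point clouds (optionally fused with RGB
    features) and drives a score-based view finder: score matching in
    train mode, reverse-time sampling in test mode.
    """
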
    TRAIN_MODE: str = "train"
    TEST_MODE: str = "test"

    def __init__(self, pipeline_config):
        super().__init__()

        self.modules_config = ConfigManager.get("modules")
        self.device = ConfigManager.get("settings", "general", "device")
        self.rgb_feat_cache = ConfigManager.get("datasets", "general", "rgb_feat_cache")
        self.pts_encoder = PointsEncoderFactory.create(pipeline_config["pts_encoder"], self.modules_config)
        self.view_finder = ViewFinderFactory.create(pipeline_config["view_finder"], self.modules_config)
        self.has_rgb_encoder = "rgb_encoder" in pipeline_config
        # The RGB encoder is only built when features are not precomputed and cached.
        if self.has_rgb_encoder and not self.rgb_feat_cache:
            self.rgb_encoder = RGBEncoderFactory.create(pipeline_config["rgb_encoder"], self.modules_config)
        self.eps = 1e-5
        self.fusion_layer = FeatureFusion(rgb_dim=384, pts_dim=1024, output_dim=1024)

        self.to(self.device)

    def forward(self, data, mode):
        if mode == self.TRAIN_MODE:
            return self.forward_gradient(data)
        elif mode == self.TEST_MODE:
            return self.forward_view(data)
        raise ValueError("Unknown mode: {}".format(mode))

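    # Training pass: denoising score matching. The ground-truth rotation delta
    # is perturbed at a random diffusion time t, and the view finder predicts
    # the score of the perturbed sample.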
    def forward_gradient(self, data):
        target_pts = data["target_pts"]
        scene_pts = data["scene_pts"]
        gt_delta_rot_6d = data["delta_rot_6d"]

        if self.has_rgb_encoder:
            # Encode the raw image only when an encoder was built and an image
            # is provided; otherwise fall back to precomputed RGB features.
            if "rgb" in data and not self.rgb_feat_cache:
                rgb_feat = self.rgb_encoder.encode_rgb(data["rgb"])
            else:
                rgb_feat = data["rgb_feat"]
            if "rgb_feat" not in inspect.signature(self.pts_encoder.encode_points).parameters:
                # The point encoder cannot take RGB features directly, so fuse
                # them with the point features afterwards.
                target_feat = self.pts_encoder.encode_points(target_pts)
                scene_feat = self.pts_encoder.encode_points(scene_pts)
                target_feat = self.fusion_layer(rgb_feat, target_feat)
                scene_feat = self.fusion_layer(rgb_feat, scene_feat)
            else:
                target_feat = self.pts_encoder.encode_points(target_pts, rgb_feat)
                scene_feat = self.pts_encoder.encode_points(scene_pts, rgb_feat)
        else:
            target_feat = self.pts_encoder.encode_points(target_pts)
            scene_feat = self.pts_encoder.encode_points(scene_pts)

        ''' get std '''
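        # Sample a diffusion time t uniformly from (eps, 1]; eps avoids the
        # degenerate std at t = 0. marginal_prob is assumed to return the mean
        # and std of the perturbation kernel p(x_t | x_0) at time t.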
        bs = target_pts.shape[0]
        random_t = torch.rand(bs, device=self.device) * (1. - self.eps) + self.eps
        random_t = random_t.unsqueeze(-1)
        mu, std = self.view_finder.marginal_prob(gt_delta_rot_6d, random_t)
        std = std.view(-1, 1)

        ''' perturb data and get estimated score '''
        z = torch.randn_like(gt_delta_rot_6d)
        perturbed_x = mu + z * std
        input_data = {
            "sampled_pose": perturbed_x,
            "t": random_t,
            "scene_feat": scene_feat,
            "target_feat": target_feat,
        }
        estimated_score = self.view_finder(input_data)

        ''' get target score '''
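        # For x_t = mu + std * z with z ~ N(0, I), the score of the Gaussian
        # perturbation kernel is grad log p(x_t | x_0) = -(x_t - mu) / std**2
        # = -z / std, which the expression below computes.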
        target_score = - z * std / (std ** 2)

        result = {
            "estimated_score": estimated_score,
            "target_score": target_score,
            "std": std,
        }
        return result

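    # Inference pass: run the view finder's reverse-time sampler to estimate
    # the next-best-view rotation delta from the encoded features.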
    def forward_view(self, data):
        target_pts = data["target_pts"]
        scene_pts = data["scene_pts"]

        if self.has_rgb_encoder:
            if self.rgb_feat_cache:
                rgb_feat = data["rgb_feat"]
            else:
                rgb = data["rgb"]
                rgb_feat = self.rgb_encoder.encode_rgb(rgb)
            if "rgb_feat" not in inspect.signature(self.pts_encoder.encode_points).parameters:
                target_feat = self.pts_encoder.encode_points(target_pts)
                scene_feat = self.pts_encoder.encode_points(scene_pts)
                target_feat = self.fusion_layer(rgb_feat, target_feat)
                scene_feat = self.fusion_layer(rgb_feat, scene_feat)
            else:
                target_feat = self.pts_encoder.encode_points(target_pts, rgb_feat)
                scene_feat = self.pts_encoder.encode_points(scene_pts, rgb_feat)
        else:
            target_feat = self.pts_encoder.encode_points(target_pts)
            scene_feat = self.pts_encoder.encode_points(scene_pts)

        estimated_delta_rot_6d, in_process_sample = self.view_finder.next_best_view(scene_feat, target_feat)
        result = {
            "estimated_delta_rot_6d": estimated_delta_rot_6d,
            "in_process_sample": in_process_sample,
        }
        return result


if __name__ == '__main__':
    ConfigManager.load_config_with('../configs/local_train_config.yaml')
    ConfigManager.print_config()
    test_pipeline_config = ConfigManager.get("settings", "pipeline")
    pipeline = Pipeline(test_pipeline_config)
    test_scene = torch.rand(32, 1024, 3).to("cuda:0")
    test_target = torch.rand(32, 1024, 3).to("cuda:0")
    test_delta_rot_6d = torch.rand(32, 6).to("cuda:0")
    # Normalize each half of the 6D rotation representation to unit length.
    a = test_delta_rot_6d[:, :3]
    b = test_delta_rot_6d[:, 3:]
    a_norm = a / a.norm(dim=1, keepdim=True)
    b_norm = b / b.norm(dim=1, keepdim=True)
    normalized_test_delta_rot_6d = torch.cat((a_norm, b_norm), dim=1)
    test_data = {
        'target_pts': test_target,
        'scene_pts': test_scene,
        'delta_rot_6d': normalized_test_delta_rot_6d,
    }
    out_data = pipeline(test_data, "train")
    print(out_data.keys())
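    # A minimal sketch of a denoising-score-matching loss built from these
    # outputs (an illustrative assumption, not necessarily the repo's trainer):
    # weighting by std**2 gives || (estimated - target) * std ||^2.
    dsm_loss = (((out_data["estimated_score"] - out_data["target_score"])
                 * out_data["std"]) ** 2).sum(dim=-1).mean()
    print("example dsm loss:", dsm_loss.item())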
    out_data_test = pipeline(test_data, "test")
    print(out_data_test.keys())