nbv_grasping/modules/pipeline.py
2024-10-09 16:13:22 +00:00

140 lines
5.9 KiB
Python
Executable File

import torch
from torch import nn
import inspect
from configs.config import ConfigManager
from modules.pts_encoder.pts_encoder_factory import PointsEncoderFactory
from modules.view_finder.view_finder_factory import ViewFinderFactory
from modules.module_lib.fusion_layer import FeatureFusion
from modules.rgb_encoder.rgb_encoder_factory import RGBEncoderFactory
class Pipeline(nn.Module):
TRAIN_MODE: str = "train"
TEST_MODE: str = "test"
def __init__(self, pipeline_config):
super(Pipeline, self).__init__()
self.modules_config = ConfigManager.get("modules")
self.device = ConfigManager.get("settings", "general", "device")
self.rgb_feat_cache = ConfigManager.get("datasets", "general", "rgb_feat_cache")
self.pts_encoder = PointsEncoderFactory.create(pipeline_config["pts_encoder"], self.modules_config)
self.view_finder = ViewFinderFactory.create(pipeline_config["view_finder"], self.modules_config)
self.has_rgb_encoder = "rgb_encoder" in pipeline_config
if self.has_rgb_encoder and not self.rgb_feat_cache:
self.rgb_encoder = RGBEncoderFactory.create(pipeline_config["rgb_encoder"], self.modules_config)
self.eps = 1e-5
self.fusion_layer = FeatureFusion(rgb_dim=384, pts_dim=1024,output_dim=1024)
self.to(self.device)
def forward(self, data, mode):
if mode == self.TRAIN_MODE:
return self.forward_gradient(data)
elif mode == self.TEST_MODE:
return self.forward_view(data)
raise ValueError("Unknown mode: {}".format(self.mode))
def forward_gradient(self, data):
target_pts = data["target_pts"]
scene_pts = data["scene_pts"]
gt_delta_rot_6d = data["delta_rot_6d"]
if hasattr(self,"rgb_encoder"):
if "rgb" in data:
rgb_feat = self.rgb_encoder.encode_rgb(data["rgb"])
else:
rgb_feat = data["rgb_feat"]
if "rgb_feat" not in inspect.signature(self.pts_encoder.encode_points).parameters:
target_feat = self.pts_encoder.encode_points(target_pts)
scene_feat = self.pts_encoder.encode_points(scene_pts)
target_feat = self.fusion_layer(rgb_feat, target_feat)
scene_feat = self.fusion_layer(rgb_feat, scene_feat)
else:
target_feat = self.pts_encoder.encode_points(target_pts, rgb_feat)
scene_feat = self.pts_encoder.encode_points(scene_pts, rgb_feat)
else:
target_feat = self.pts_encoder.encode_points(target_pts)
scene_feat = self.pts_encoder.encode_points(scene_pts)
''' get std '''
bs = target_pts.shape[0]
random_t = torch.rand(bs, device=self.device) * (1. - self.eps) + self.eps
random_t = random_t.unsqueeze(-1)
mu, std = self.view_finder.marginal_prob(gt_delta_rot_6d, random_t)
std = std.view(-1, 1)
''' perturb data and get estimated score '''
z = torch.randn_like(gt_delta_rot_6d)
perturbed_x = mu + z * std
input_data = {
"sampled_pose": perturbed_x,
"t": random_t,
"scene_feat": scene_feat,
"target_feat": target_feat
}
estimated_score = self.view_finder(input_data)
''' get target score '''
target_score = - z * std / (std ** 2)
result = {
"estimated_score": estimated_score,
"target_score": target_score,
"std": std
}
return result
def forward_view(self, data):
target_pts = data["target_pts"]
scene_pts = data["scene_pts"]
if self.has_rgb_encoder :
if self.rgb_feat_cache:
rgb_feat = data["rgb_feat"]
else:
rgb = data["rgb"]
rgb_feat = self.rgb_encoder.encode_rgb(rgb)
if "rgb_feat" not in inspect.signature(self.pts_encoder.encode_points).parameters:
target_feat = self.pts_encoder.encode_points(target_pts)
scene_feat = self.pts_encoder.encode_points(scene_pts)
target_feat = self.fusion_layer(rgb_feat, target_feat)
scene_feat = self.fusion_layer(rgb_feat, scene_feat)
else:
target_feat = self.pts_encoder.encode_points(target_pts, rgb_feat)
scene_feat = self.pts_encoder.encode_points(scene_pts, rgb_feat)
else:
target_feat = self.pts_encoder.encode_points(target_pts)
scene_feat = self.pts_encoder.encode_points(scene_pts)
estimated_delta_rot_6d, in_process_sample = self.view_finder.next_best_view(scene_feat, target_feat)
result = {
"estimated_delta_rot_6d": estimated_delta_rot_6d,
"in_process_sample": in_process_sample
}
return result
if __name__ == '__main__':
ConfigManager.load_config_with('../configs/local_train_config.yaml')
ConfigManager.print_config()
test_pipeline_config = ConfigManager.get("settings", "pipeline")
pipeline = Pipeline(test_pipeline_config)
test_scene = torch.rand(32, 1024, 3).to("cuda:0")
test_target = torch.rand(32, 1024, 3).to("cuda:0")
test_delta_rot_6d = torch.rand(32, 6).to("cuda:0")
a = test_delta_rot_6d[:, :3]
b = test_delta_rot_6d[:, 3:]
a_norm = a / a.norm(dim=1, keepdim=True)
b_norm = b / b.norm(dim=1, keepdim=True)
normalized_test_delta_rot_6d = torch.cat((a_norm, b_norm), dim=1)
test_data = {
'target_pts': test_target,
'scene_pts': test_scene,
'delta_rot_6d': normalized_test_delta_rot_6d
}
out_data = pipeline(test_data, "train")
print(out_data.keys())
out_data_test = pipeline(test_data, "test")
print(out_data_test.keys())