Merge branch 'master' of http://git.hofee.top/hofee/nbv_reconstruction
This commit is contained in:
commit
a621749cc9
@ -9,7 +9,7 @@ runner:
|
|||||||
experiment:
|
experiment:
|
||||||
name: new_test_overfit_to_world
|
name: new_test_overfit_to_world
|
||||||
root_dir: "experiments"
|
root_dir: "experiments"
|
||||||
use_checkpoint: False
|
use_checkpoint: True
|
||||||
epoch: -1 # -1 stands for last epoch
|
epoch: -1 # -1 stands for last epoch
|
||||||
max_epochs: 5000
|
max_epochs: 5000
|
||||||
save_checkpoint_interval: 3
|
save_checkpoint_interval: 3
|
||||||
|
@ -22,7 +22,6 @@ class TransformerSequenceEncoder(nn.Module):
|
|||||||
self.fc = nn.Linear(embed_dim, config["output_dim"])
|
self.fc = nn.Linear(embed_dim, config["output_dim"])
|
||||||
|
|
||||||
def encode_sequence(self, pts_embedding_list_batch, pose_embedding_list_batch):
|
def encode_sequence(self, pts_embedding_list_batch, pose_embedding_list_batch):
|
||||||
# Combine features and pad sequences
|
|
||||||
combined_features_batch = []
|
combined_features_batch = []
|
||||||
lengths = []
|
lengths = []
|
||||||
|
|
||||||
@ -36,16 +35,11 @@ class TransformerSequenceEncoder(nn.Module):
|
|||||||
|
|
||||||
combined_tensor = pad_sequence(combined_features_batch, batch_first=True) # Shape: [batch_size, max_seq_len, embed_dim]
|
combined_tensor = pad_sequence(combined_features_batch, batch_first=True) # Shape: [batch_size, max_seq_len, embed_dim]
|
||||||
|
|
||||||
# Prepare mask for padding
|
|
||||||
max_len = max(lengths)
|
max_len = max(lengths)
|
||||||
padding_mask = torch.tensor([([0] * length + [1] * (max_len - length)) for length in lengths], dtype=torch.bool).to(combined_tensor.device)
|
padding_mask = torch.tensor([([0] * length + [1] * (max_len - length)) for length in lengths], dtype=torch.bool).to(combined_tensor.device)
|
||||||
# Transformer encoding
|
|
||||||
transformer_output = self.transformer_encoder(combined_tensor, src_key_padding_mask=padding_mask)
|
transformer_output = self.transformer_encoder(combined_tensor, src_key_padding_mask=padding_mask)
|
||||||
|
|
||||||
# Mean pooling
|
|
||||||
final_feature = transformer_output.mean(dim=1)
|
final_feature = transformer_output.mean(dim=1)
|
||||||
|
|
||||||
# Fully connected layer
|
|
||||||
final_output = self.fc(final_feature)
|
final_output = self.fc(final_feature)
|
||||||
|
|
||||||
return final_output
|
return final_output
|
||||||
|
@ -117,7 +117,6 @@ class StrategyGenerator(Runner):
|
|||||||
if self.save_mesh:
|
if self.save_mesh:
|
||||||
DataLoadUtil.save_target_mesh_at_world_space(root, model_dir, scene_name)
|
DataLoadUtil.save_target_mesh_at_world_space(root, model_dir, scene_name)
|
||||||
|
|
||||||
DataLoadUtil.save_downsampled_world_model_points(root, scene_name, down_sampled_model_pts)
|
|
||||||
|
|
||||||
def generate_data_pairs(self, useful_view):
|
def generate_data_pairs(self, useful_view):
|
||||||
data_pairs = []
|
data_pairs = []
|
||||||
|
@ -17,27 +17,11 @@ class DataLoadUtil:
|
|||||||
path = os.path.join(root,scene_name, f"label.json")
|
path = os.path.join(root,scene_name, f"label.json")
|
||||||
return path
|
return path
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_sampled_model_points_path(root, scene_name):
|
|
||||||
path = os.path.join(root,scene_name, f"sampled_model_points.txt")
|
|
||||||
return path
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_scene_seq_length(root, scene_name):
|
def get_scene_seq_length(root, scene_name):
|
||||||
camera_params_path = os.path.join(root, scene_name, "camera_params")
|
camera_params_path = os.path.join(root, scene_name, "camera_params")
|
||||||
return len(os.listdir(camera_params_path))
|
return len(os.listdir(camera_params_path))
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def load_downsampled_world_model_points(root, scene_name):
|
|
||||||
model_path = DataLoadUtil.get_sampled_model_points_path(root, scene_name)
|
|
||||||
model_points = np.loadtxt(model_path)
|
|
||||||
return model_points
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def save_downsampled_world_model_points(root, scene_name, model_points):
|
|
||||||
model_path = DataLoadUtil.get_sampled_model_points_path(root, scene_name)
|
|
||||||
np.savetxt(model_path, model_points)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load_mesh_at(model_dir, object_name, world_object_pose):
|
def load_mesh_at(model_dir, object_name, world_object_pose):
|
||||||
model_path = os.path.join(model_dir, object_name, "mesh.obj")
|
model_path = os.path.join(model_dir, object_name, "mesh.obj")
|
||||||
@ -279,7 +263,9 @@ class DataLoadUtil:
|
|||||||
return overlapping_points
|
return overlapping_points
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load_points_normals(root, scene_name):
|
def load_points_normals(root, scene_name, display_table_as_world_space_origin=True):
|
||||||
points_path = os.path.join(root, scene_name, "points_and_normals.txt")
|
points_path = os.path.join(root, scene_name, "points_and_normals.txt")
|
||||||
points_normals = np.loadtxt(points_path)
|
points_normals = np.loadtxt(points_path)
|
||||||
|
if display_table_as_world_space_origin:
|
||||||
|
points_normals[:,:3] = points_normals[:,:3] - DataLoadUtil.DISPLAY_TABLE_POSITION
|
||||||
return points_normals
|
return points_normals
|
Loading…
x
Reference in New Issue
Block a user