This commit is contained in:
hofee 2024-09-20 15:00:38 +08:00
commit a621749cc9
5 changed files with 5 additions and 26 deletions

View File

@ -9,7 +9,7 @@ runner:
experiment:
name: new_test_overfit_to_world
root_dir: "experiments"
use_checkpoint: False
use_checkpoint: True
epoch: -1 # -1 stands for last epoch
max_epochs: 5000
save_checkpoint_interval: 3

View File

@ -22,7 +22,6 @@ class TransformerSequenceEncoder(nn.Module):
self.fc = nn.Linear(embed_dim, config["output_dim"])
def encode_sequence(self, pts_embedding_list_batch, pose_embedding_list_batch):
# Combine features and pad sequences
combined_features_batch = []
lengths = []
@ -36,16 +35,11 @@ class TransformerSequenceEncoder(nn.Module):
combined_tensor = pad_sequence(combined_features_batch, batch_first=True) # Shape: [batch_size, max_seq_len, embed_dim]
# Prepare mask for padding
max_len = max(lengths)
padding_mask = torch.tensor([([0] * length + [1] * (max_len - length)) for length in lengths], dtype=torch.bool).to(combined_tensor.device)
# Transformer encoding
transformer_output = self.transformer_encoder(combined_tensor, src_key_padding_mask=padding_mask)
# Mean pooling
final_feature = transformer_output.mean(dim=1)
# Fully connected layer
final_output = self.fc(final_feature)
return final_output

View File

@ -117,7 +117,6 @@ class StrategyGenerator(Runner):
if self.save_mesh:
DataLoadUtil.save_target_mesh_at_world_space(root, model_dir, scene_name)
DataLoadUtil.save_downsampled_world_model_points(root, scene_name, down_sampled_model_pts)
def generate_data_pairs(self, useful_view):
data_pairs = []

View File

@ -17,27 +17,11 @@ class DataLoadUtil:
path = os.path.join(root,scene_name, f"label.json")
return path
@staticmethod
def get_sampled_model_points_path(root, scene_name):
path = os.path.join(root,scene_name, f"sampled_model_points.txt")
return path
@staticmethod
def get_scene_seq_length(root, scene_name):
camera_params_path = os.path.join(root, scene_name, "camera_params")
return len(os.listdir(camera_params_path))
@staticmethod
def load_downsampled_world_model_points(root, scene_name):
model_path = DataLoadUtil.get_sampled_model_points_path(root, scene_name)
model_points = np.loadtxt(model_path)
return model_points
@staticmethod
def save_downsampled_world_model_points(root, scene_name, model_points):
model_path = DataLoadUtil.get_sampled_model_points_path(root, scene_name)
np.savetxt(model_path, model_points)
@staticmethod
def load_mesh_at(model_dir, object_name, world_object_pose):
model_path = os.path.join(model_dir, object_name, "mesh.obj")
@ -279,7 +263,9 @@ class DataLoadUtil:
return overlapping_points
@staticmethod
def load_points_normals(root, scene_name):
def load_points_normals(root, scene_name, display_table_as_world_space_origin=True):
points_path = os.path.join(root, scene_name, "points_and_normals.txt")
points_normals = np.loadtxt(points_path)
if display_table_as_world_space_origin:
points_normals[:,:3] = points_normals[:,:3] - DataLoadUtil.DISPLAY_TABLE_POSITION
return points_normals