add points_normals under display_table_world_space

This commit is contained in:
hofee 2024-09-19 12:12:48 +00:00
parent 55684e86ba
commit 4e4fcb2ce5
3 changed files with 4 additions and 19 deletions

View File

@ -9,7 +9,7 @@ runner:
experiment: experiment:
name: new_test_overfit_to_world name: new_test_overfit_to_world
root_dir: "experiments" root_dir: "experiments"
use_checkpoint: False use_checkpoint: True
epoch: -1 # -1 stands for last epoch epoch: -1 # -1 stands for last epoch
max_epochs: 5000 max_epochs: 5000
save_checkpoint_interval: 3 save_checkpoint_interval: 3

View File

@ -117,7 +117,6 @@ class StrategyGenerator(Runner):
if self.save_mesh: if self.save_mesh:
DataLoadUtil.save_target_mesh_at_world_space(root, model_dir, scene_name) DataLoadUtil.save_target_mesh_at_world_space(root, model_dir, scene_name)
DataLoadUtil.save_downsampled_world_model_points(root, scene_name, down_sampled_model_pts)
def generate_data_pairs(self, useful_view): def generate_data_pairs(self, useful_view):
data_pairs = [] data_pairs = []

View File

@ -17,27 +17,11 @@ class DataLoadUtil:
path = os.path.join(root,scene_name, f"label.json") path = os.path.join(root,scene_name, f"label.json")
return path return path
@staticmethod
def get_sampled_model_points_path(root, scene_name):
path = os.path.join(root,scene_name, f"sampled_model_points.txt")
return path
@staticmethod @staticmethod
def get_scene_seq_length(root, scene_name): def get_scene_seq_length(root, scene_name):
camera_params_path = os.path.join(root, scene_name, "camera_params") camera_params_path = os.path.join(root, scene_name, "camera_params")
return len(os.listdir(camera_params_path)) return len(os.listdir(camera_params_path))
@staticmethod
def load_downsampled_world_model_points(root, scene_name):
model_path = DataLoadUtil.get_sampled_model_points_path(root, scene_name)
model_points = np.loadtxt(model_path)
return model_points
@staticmethod
def save_downsampled_world_model_points(root, scene_name, model_points):
model_path = DataLoadUtil.get_sampled_model_points_path(root, scene_name)
np.savetxt(model_path, model_points)
@staticmethod @staticmethod
def load_mesh_at(model_dir, object_name, world_object_pose): def load_mesh_at(model_dir, object_name, world_object_pose):
model_path = os.path.join(model_dir, object_name, "mesh.obj") model_path = os.path.join(model_dir, object_name, "mesh.obj")
@ -269,7 +253,9 @@ class DataLoadUtil:
return overlapping_points return overlapping_points
@staticmethod @staticmethod
def load_points_normals(root, scene_name): def load_points_normals(root, scene_name, display_table_as_world_space_origin=True):
points_path = os.path.join(root, scene_name, "points_and_normals.txt") points_path = os.path.join(root, scene_name, "points_and_normals.txt")
points_normals = np.loadtxt(points_path) points_normals = np.loadtxt(points_path)
if display_table_as_world_space_origin:
points_normals[:,:3] = points_normals[:,:3] - DataLoadUtil.DISPLAY_TABLE_POSITION
return points_normals return points_normals