2024-10-09 16:13:22 +00:00

122 lines
4.3 KiB
Python
Executable File

""" Tools for loss computation.
Author: chenxi-wang
"""
import torch
import numpy as np
# Module-level configuration constants. None of them is referenced by the
# functions in this part of the file; presumably they are imported by other
# modules -- TODO confirm against callers.
GRASP_MAX_WIDTH = 0.1  # presumably the maximum gripper opening width (units not shown here -- verify)
GRASPNESS_THRESHOLD = 0.1  # presumably a cut-off applied to a graspness score -- verify
NUM_VIEW = 300  # same value as the default N of generate_grasp_views
NUM_ANGLE = 12  # presumably the number of in-plane rotation angle bins -- verify
NUM_DEPTH = 4  # presumably the number of grasp depth bins -- verify
M_POINT = 1024  # presumably a number of sampled points -- verify
def transform_point_cloud(cloud, transform, format='4x4'):
    """ Apply a rigid transformation matrix to a set of 3D points.

        Input:
            cloud: [torch.FloatTensor, (N,3)]
                points expressed in the source frame
            transform: [torch.FloatTensor, (3,3)/(3,4)/(4,4)]
                rotation only, or rotation plus translation
            format: [string, '3x3'/'3x4'/'4x4']
                declares the layout of `transform`
                '3x3' --> rotation matrix
                '3x4'/'4x4' --> rotation matrix + translation

        Output:
            cloud_transformed: [torch.FloatTensor, (N,3)]
                points expressed in the target frame

        Raises:
            ValueError: if `format` is not one of the supported strings.
    """
    if format not in ('3x3', '4x4', '3x4'):
        raise ValueError("Unknown transformation format, only support '3x3' or '4x4' or '3x4'.")

    # Pure rotation: no homogeneous coordinate is required.
    if format == '3x3':
        return torch.matmul(transform, cloud.T).T

    # Append a column of ones so the translation column of `transform`
    # is added by the same matrix product.
    unit_column = cloud.new_ones((cloud.size(0), 1))
    homogeneous = torch.cat([cloud, unit_column], dim=1)
    moved = torch.matmul(transform, homogeneous.T).T

    # A 4x4 transform yields a fourth (homogeneous) column; keep xyz only.
    return moved[:, :3]
def generate_grasp_views(N=300, phi=(np.sqrt(5) - 1) / 2, center=np.zeros(3), r=1):
    """ View sampling on a sphere using Fibonacci lattices.
        Ref: https://arxiv.org/abs/0912.4540

        Input:
            N: [int]
                number of sampled views; 0 yields an empty (0,3) tensor
            phi: [float]
                constant for view coordinate calculation, different phi's bring
                different distributions, default: (sqrt(5)-1)/2
            center: [np.ndarray, (3,), np.float32]
                sphere center (only read, never modified, so the shared
                default array is safe)
            r: [float]
                sphere radius

        Output:
            views: [torch.FloatTensor, (N,3)]
                sampled view coordinates

        Raises:
            ValueError: if N is negative.
    """
    if N < 0:
        raise ValueError('N must be non-negative, got {}.'.format(N))
    if N == 0:
        # Avoid dividing by zero and keep the documented (N,3) shape.
        return torch.zeros((0, 3), dtype=torch.float32)

    # All N lattice points are computed at once instead of in a Python loop.
    idx = np.arange(N, dtype=np.float64)
    # Heights are evenly spaced in (-1, 1); the azimuth advances by a fixed
    # multiple of 2*pi*phi per point.
    zi = (2 * idx + 1) / N - 1
    ring_radius = np.sqrt(1 - zi ** 2)
    azimuth = 2 * idx * np.pi * phi
    xi = ring_radius * np.cos(azimuth)
    yi = ring_radius * np.sin(azimuth)

    views = r * np.stack([xi, yi, zi], axis=-1) + center
    return torch.from_numpy(views.astype(np.float32))
def batch_viewpoint_params_to_matrix(batch_towards, batch_angle):
    """ Transform approach vectors and in-plane rotation angles to rotation matrices.

        Input:
            batch_towards: [torch.FloatTensor, (N,3)]
                approach vectors in batch (need not be unit length)
            batch_angle: [torch.FloatTensor, (N,)]
                in-plane rotation angles in batch

        Output:
            batch_matrix: [torch.FloatTensor, (N,3,3)]
                rotation matrices in batch; the first column of each matrix
                is the normalized approach vector
    """
    axis_x = batch_towards
    ones = torch.ones(axis_x.shape[0], dtype=axis_x.dtype, device=axis_x.device)
    zeros = torch.zeros(axis_x.shape[0], dtype=axis_x.dtype, device=axis_x.device)

    # A vector perpendicular to axis_x that lies in the xy-plane.
    axis_y = torch.stack([-axis_x[:, 1], axis_x[:, 0], zeros], dim=-1)
    # When axis_x has no xy component the vector above is zero; fall back to
    # the y unit vector so the normalization below does not divide by zero.
    mask_y = (torch.norm(axis_y, dim=-1) == 0)
    axis_y[mask_y, 1] = 1

    axis_x = axis_x / torch.norm(axis_x, dim=-1, keepdim=True)
    axis_y = axis_y / torch.norm(axis_y, dim=-1, keepdim=True)
    # dim must be explicit: without it torch.cross uses the first dimension of
    # size 3, which is the batch dimension whenever N == 3.
    axis_z = torch.cross(axis_x, axis_y, dim=-1)

    # R1: rotation by batch_angle about the x axis.
    sin = torch.sin(batch_angle)
    cos = torch.cos(batch_angle)
    R1 = torch.stack([ones, zeros, zeros, zeros, cos, -sin, zeros, sin, cos], dim=-1)
    R1 = R1.reshape([-1, 3, 3])

    # R2: frame whose columns are the x, y and z axes built above.
    R2 = torch.stack([axis_x, axis_y, axis_z], dim=-1)
    batch_matrix = torch.matmul(R2, R1)
    return batch_matrix
def huber_loss(error, delta=1.0):
    """ Element-wise Huber loss.

        Args:
            error: Torch tensor (d1,d2,...,dk)
                x = error = pred - gt or dist(pred,gt)
            delta: threshold d between the quadratic and the linear regime

        Returns:
            loss: Torch tensor (d1,d2,...,dk)
                0.5 * |x|^2                 if |x|<=d
                0.5 * d^2 + d * (|x|-d)     if |x|>d

        Author: Charles R. Qi
        Ref: https://github.com/charlesq34/frustum-pointnets/blob/master/models/model_util.py
    """
    magnitude = torch.abs(error)
    # Part of |x| that falls inside the quadratic zone (capped at delta).
    capped = magnitude.clamp(max=delta)
    # Whatever exceeds delta is penalized linearly.
    excess = magnitude - capped
    return 0.5 * capped ** 2 + delta * excess