import torch
from torch.autograd import Variable
from torch.autograd import Function
import torch.nn as nn
from typing import Optional, Tuple

import pointnet2_cuda as pointnet2
class FurthestPointSampling(Function):
    @staticmethod
    def forward(ctx, xyz: torch.Tensor, npoint: int) -> torch.Tensor:
        """
        Uses iterative furthest point sampling to select a set of npoint features that have the
        largest minimum distance to each other
        :param ctx:
        :param xyz: (B, N, 3) where N > npoint
        :param npoint: int, number of features in the sampled set
        :return:
            output: (B, npoint) tensor containing the indices of the sampled set
        """
        assert xyz.is_contiguous()

        B, N, _ = xyz.size()
        output = torch.cuda.IntTensor(B, npoint)
        # running minimum squared distance from each point to the current sample set
        temp = torch.cuda.FloatTensor(B, N).fill_(1e10)

        pointnet2.furthest_point_sampling_wrapper(B, N, npoint, xyz, temp, output)
        return output

    @staticmethod
    def backward(ctx, a=None):
        # sampling is not differentiable
        return None, None


furthest_point_sample = FurthestPointSampling.apply
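# Usage sketch (illustrative addition, not part of the original file; assumes a
# CUDA device and the compiled pointnet2_cuda extension are available):
#   xyz = torch.rand(2, 4096, 3).cuda()                  # (B, N, 3) point cloud
#   idx = furthest_point_sample(xyz.contiguous(), 1024)  # (B, 1024) int32 indices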


class GatherOperation(Function):

    @staticmethod
    def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
        """
        :param ctx:
        :param features: (B, C, N)
        :param idx: (B, npoint) index tensor of the features to gather
        :return:
            output: (B, C, npoint)
        """
        assert features.is_contiguous()
        assert idx.is_contiguous()

        B, npoint = idx.size()
        _, C, N = features.size()
        output = torch.cuda.FloatTensor(B, C, npoint)

        pointnet2.gather_points_wrapper(B, C, N, npoint, features, idx, output)

        ctx.for_backwards = (idx, C, N)
        return output

    @staticmethod
    def backward(ctx, grad_out):
        idx, C, N = ctx.for_backwards
        B, npoint = idx.size()

        grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_())
        grad_out_data = grad_out.data.contiguous()
        pointnet2.gather_points_grad_wrapper(B, C, N, npoint, grad_out_data, idx, grad_features.data)
        return grad_features, None


gather_operation = GatherOperation.apply
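# Usage sketch (illustrative addition; builds on the FPS example above):
#   features = torch.rand(2, 64, 4096).cuda()               # (B, C, N) descriptors
#   sampled = gather_operation(features.contiguous(), idx)  # (B, C, npoint)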


class ThreeNN(Function):

    @staticmethod
    def forward(ctx, unknown: torch.Tensor, known: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Find the three nearest neighbors of unknown in known
        :param ctx:
        :param unknown: (B, N, 3)
        :param known: (B, M, 3)
        :return:
            dist: (B, N, 3) l2 distance to the three nearest neighbors
            idx: (B, N, 3) index of the three nearest neighbors
        """
        assert unknown.is_contiguous()
        assert known.is_contiguous()

        B, N, _ = unknown.size()
        m = known.size(1)
        dist2 = torch.cuda.FloatTensor(B, N, 3)
        idx = torch.cuda.IntTensor(B, N, 3)

        pointnet2.three_nn_wrapper(B, N, m, unknown, known, dist2, idx)
        return torch.sqrt(dist2), idx

    @staticmethod
    def backward(ctx, a=None, b=None):
        return None, None


three_nn = ThreeNN.apply
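# Usage sketch (illustrative addition): for feature propagation, find the three
# nearest sampled points (known) for every original point (unknown).
#   dist, nn_idx = three_nn(xyz.contiguous(), new_xyz.contiguous())  # both (B, N, 3)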


class ThreeInterpolate(Function):

    @staticmethod
    def forward(ctx, features: torch.Tensor, idx: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        """
        Performs weighted linear interpolation on 3 features
        :param ctx:
        :param features: (B, C, M) feature descriptors to be interpolated from
        :param idx: (B, n, 3) three nearest neighbors of the target features in features
        :param weight: (B, n, 3) weights
        :return:
            output: (B, C, n) tensor of the interpolated features
        """
        assert features.is_contiguous()
        assert idx.is_contiguous()
        assert weight.is_contiguous()

        B, c, m = features.size()
        n = idx.size(1)
        ctx.three_interpolate_for_backward = (idx, weight, m)
        output = torch.cuda.FloatTensor(B, c, n)

        pointnet2.three_interpolate_wrapper(B, c, m, n, features, idx, weight, output)
        return output

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        :param ctx:
        :param grad_out: (B, C, n) tensor with gradients of outputs
        :return:
            grad_features: (B, C, M) tensor with gradients of features
            None:
            None:
        """
        idx, weight, m = ctx.three_interpolate_for_backward
        B, c, n = grad_out.size()

        grad_features = Variable(torch.cuda.FloatTensor(B, c, m).zero_())
        grad_out_data = grad_out.data.contiguous()

        pointnet2.three_interpolate_grad_wrapper(B, c, n, m, grad_out_data, idx, weight, grad_features.data)
        return grad_features, None, None


three_interpolate = ThreeInterpolate.apply
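# Usage sketch (illustrative addition): the standard PointNet++ feature-propagation
# pattern, inverse-distance weighting over the three nearest neighbors.
#   dist, nn_idx = three_nn(unknown_xyz, known_xyz)
#   dist_recip = 1.0 / (dist + 1e-8)
#   weight = dist_recip / torch.sum(dist_recip, dim=2, keepdim=True)  # (B, n, 3)
#   interpolated = three_interpolate(known_feats, nn_idx, weight)     # (B, C, n)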


class GroupingOperation(Function):

    @staticmethod
    def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
        """
        :param ctx:
        :param features: (B, C, N) tensor of features to group
        :param idx: (B, npoint, nsample) tensor containing the indices of features to group with
        :return:
            output: (B, C, npoint, nsample) tensor
        """
        assert features.is_contiguous()
        assert idx.is_contiguous()

        B, nfeatures, nsample = idx.size()
        _, C, N = features.size()
        output = torch.cuda.FloatTensor(B, C, nfeatures, nsample)

        pointnet2.group_points_wrapper(B, C, N, nfeatures, nsample, features, idx, output)

        ctx.for_backwards = (idx, N)
        return output

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        :param ctx:
        :param grad_out: (B, C, npoint, nsample) tensor of the gradients of the output from forward
        :return:
            grad_features: (B, C, N) gradient of the features
        """
        idx, N = ctx.for_backwards

        B, C, npoint, nsample = grad_out.size()
        grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_())

        grad_out_data = grad_out.data.contiguous()
        pointnet2.group_points_grad_wrapper(B, C, N, npoint, nsample, grad_out_data, idx, grad_features.data)
        return grad_features, None


grouping_operation = GroupingOperation.apply
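# Usage sketch (illustrative addition): gather each local neighborhood into a
# (B, C, npoint, nsample) tensor, given indices such as those from ball_query below.
#   grouped = grouping_operation(features.contiguous(), ball_idx)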


class BallQuery(Function):

    @staticmethod
    def forward(ctx, radius: float, nsample: int, xyz: torch.Tensor, new_xyz: torch.Tensor) -> torch.Tensor:
        """
        :param ctx:
        :param radius: float, radius of the balls
        :param nsample: int, maximum number of features in the balls
        :param xyz: (B, N, 3) xyz coordinates of the features
        :param new_xyz: (B, npoint, 3) centers of the ball query
        :return:
            idx: (B, npoint, nsample) tensor with the indices of the features that form the query balls
        """
        assert new_xyz.is_contiguous()
        assert xyz.is_contiguous()

        B, N, _ = xyz.size()
        npoint = new_xyz.size(1)
        idx = torch.cuda.IntTensor(B, npoint, nsample).zero_()

        pointnet2.ball_query_wrapper(B, N, npoint, radius, nsample, new_xyz, xyz, idx)
        return idx

    @staticmethod
    def backward(ctx, a=None):
        return None, None, None, None


ball_query = BallQuery.apply
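# Usage sketch (illustrative addition): indices of up to nsample neighbors within
# a 0.4-unit radius around each sampled centroid.
#   ball_idx = ball_query(0.4, 32, xyz.contiguous(), new_xyz.contiguous())  # (B, npoint, 32)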


class QueryAndGroup(nn.Module):
    def __init__(self, radius: float, nsample: int, use_xyz: bool = True):
        """
        :param radius: float, radius of ball
        :param nsample: int, maximum number of features to gather in the ball
        :param use_xyz: whether to concatenate the xyz coordinates to the grouped features
        """
        super().__init__()
        self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz

    def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        :param xyz: (B, N, 3) xyz coordinates of the features
        :param new_xyz: (B, npoint, 3) centroids
        :param features: (B, C, N) descriptors of the features
        :return:
            new_features: (B, 3 + C, npoint, nsample)
        """
        idx = ball_query(self.radius, self.nsample, xyz, new_xyz)
        xyz_trans = xyz.transpose(1, 2).contiguous()
        grouped_xyz = grouping_operation(xyz_trans, idx)  # (B, 3, npoint, nsample)
        # translate each neighborhood so it is centered on its query point
        grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1)

        if features is not None:
            grouped_features = grouping_operation(features, idx)
            if self.use_xyz:
                new_features = torch.cat([grouped_xyz, grouped_features], dim=1)  # (B, C + 3, npoint, nsample)
            else:
                new_features = grouped_features
        else:
            assert self.use_xyz, "use_xyz must be True when no features are given!"
            new_features = grouped_xyz

        return new_features
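# Usage sketch (illustrative addition): a typical set-abstraction grouping step,
# sampling centroids with FPS and grouping their ball neighborhoods.
#   new_xyz = gather_operation(xyz.transpose(1, 2).contiguous(),
#                              furthest_point_sample(xyz, 1024)).transpose(1, 2).contiguous()
#   grouper = QueryAndGroup(radius=0.4, nsample=32, use_xyz=True)
#   new_features = grouper(xyz, new_xyz, features)  # (B, C + 3, 1024, 32)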


class GroupAll(nn.Module):
    def __init__(self, use_xyz: bool = True):
        super().__init__()
        self.use_xyz = use_xyz

    def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        :param xyz: (B, N, 3) xyz coordinates of the features
        :param new_xyz: ignored
        :param features: (B, C, N) descriptors of the features
        :return:
            new_features: (B, C + 3, 1, N)
        """
        grouped_xyz = xyz.transpose(1, 2).unsqueeze(2)
        if features is not None:
            grouped_features = features.unsqueeze(2)
            if self.use_xyz:
                new_features = torch.cat([grouped_xyz, grouped_features], dim=1)  # (B, 3 + C, 1, N)
            else:
                new_features = grouped_features
        else:
            new_features = grouped_xyz

        return new_features
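# Usage sketch (illustrative addition): GroupAll treats the whole cloud as a
# single group, as in the final set-abstraction layer of PointNet++.
#   new_features = GroupAll(use_xyz=True)(xyz, None, features)  # (B, C + 3, 1, N)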