553 lines
16 KiB
Python
Executable File
553 lines
16 KiB
Python
Executable File
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
''' Modified based on: https://github.com/erikwijmans/Pointnet2_PyTorch '''
|
|
from __future__ import (
|
|
division,
|
|
absolute_import,
|
|
with_statement,
|
|
print_function,
|
|
unicode_literals,
|
|
)
|
|
import torch
|
|
from torch.autograd import Function
|
|
import torch.nn as nn
|
|
import pytorch_utils as pt_utils
|
|
import sys
|
|
|
|
try:
|
|
import builtins
|
|
except:
|
|
import __builtin__ as builtins
|
|
|
|
try:
|
|
import pointnet2._ext as _ext
|
|
except ImportError:
|
|
if not getattr(builtins, "__POINTNET2_SETUP__", False):
|
|
raise ImportError(
|
|
"Could not import _ext module.\n"
|
|
"Please see the setup instructions in the README: "
|
|
"https://github.com/erikwijmans/Pointnet2_PyTorch/blob/master/README.rst"
|
|
)
|
|
|
|
if False:
|
|
# Workaround for type hints without depending on the `typing` module
|
|
from typing import *
|
|
|
|
|
|
class RandomDropout(nn.Module):
    r"""
    Feature dropout whose drop rate is re-drawn on every forward pass

    Parameters
    ---------
    p : float
        Upper end of the uniform range [0, p) the drop rate is sampled from
    inplace : bool
        Forwarded unchanged to ``pt_utils.feature_dropout_no_scaling``
    """

    def __init__(self, p=0.5, inplace=False):
        super(RandomDropout, self).__init__()
        self.p = p
        self.inplace = inplace

    def forward(self, X):
        r"""
        Parameters
        ----------
        X : torch.Tensor
            Features handed to ``pt_utils.feature_dropout_no_scaling``

        Returns
        -------
        Whatever ``pt_utils.feature_dropout_no_scaling`` returns for X
        """
        # Sample this call's drop rate from U(0, p).
        theta = torch.Tensor(1).uniform_(0, self.p)[0]
        # Pass the module's mode flag (`self.training`, toggled by
        # train()/eval()).  The previous code passed `self.train`, which is
        # the bound method itself and therefore always truthy, so the
        # dropout helper was told "training" even in eval mode.
        return pt_utils.feature_dropout_no_scaling(
            X, theta, self.training, self.inplace
        )
|
|
|
|
|
|
class FurthestPointSampling(Function):
    r"""
    Autograd wrapper around ``_ext.furthest_point_sampling``
    """

    @staticmethod
    def forward(ctx, xyz, npoint):
        # type: (Any, torch.Tensor, int) -> torch.Tensor
        r"""
        Iterative furthest point sampling: picks npoint features so that the
        minimum distance between the picked features is as large as possible

        Parameters
        ----------
        xyz : torch.Tensor
            (B, N, 3) tensor where N > npoint
        npoint : int32
            size of the sampled set

        Returns
        -------
        torch.Tensor
            (B, npoint) tensor containing the set
        """
        sampled = _ext.furthest_point_sampling(xyz, npoint)
        return sampled

    @staticmethod
    def backward(ctx, grad_out=None):
        # No gradient is propagated: one None per forward input
        # (xyz, npoint).
        no_grads = (None, None)
        return no_grads


furthest_point_sample = FurthestPointSampling.apply
|
|
|
|
|
|
class GatherOperation(Function):
    r"""
    Autograd wrapper around ``_ext.gather_points`` and its gradient kernel
    """

    @staticmethod
    def forward(ctx, features, idx):
        # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor
        r"""
        Parameters
        ----------
        features : torch.Tensor
            (B, C, N) tensor
        idx : torch.Tensor
            (B, npoint) tensor of the features to gather

        Returns
        -------
        torch.Tensor
            (B, C, npoint) tensor
        """
        # Unpacking into exactly three names also rejects non 3-D input.
        _, num_channels, num_points = features.size()

        # Stash what backward() reads back.
        ctx.for_backwards = (idx, num_channels, num_points)

        gathered = _ext.gather_points(features, idx)
        return gathered

    @staticmethod
    def backward(ctx, grad_out):
        r"""
        Parameters
        ----------
        grad_out : torch.Tensor
            gradient w.r.t. the output of forward

        Returns
        -------
        grad_features : torch.Tensor
            result of ``_ext.gather_points_grad``, the gradient for features
        None
            idx receives no gradient
        """
        idx, _, num_points = ctx.for_backwards

        # The kernel is handed a contiguous copy of the incoming gradient.
        grad_features = _ext.gather_points_grad(
            grad_out.contiguous(), idx, num_points
        )
        return grad_features, None


gather_operation = GatherOperation.apply
|
|
|
|
|
|
class ThreeNN(Function):
    r"""
    Autograd wrapper around ``_ext.three_nn``
    """

    @staticmethod
    def forward(ctx, unknown, known):
        # type: (Any, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
        r"""
        For every point of unknown, find its three nearest neighbors in known

        Parameters
        ----------
        unknown : torch.Tensor
            (B, n, 3) tensor of the points whose neighbors are looked up
        known : torch.Tensor
            (B, m, 3) tensor of the points that are searched

        Returns
        -------
        dist : torch.Tensor
            (B, n, 3) l2 distance to the three nearest neighbors
        idx : torch.Tensor
            (B, n, 3) index of 3 nearest neighbors
        """
        squared_dist, idx = _ext.three_nn(unknown, known)

        # The extension's first output is square-rooted before returning.
        dist = torch.sqrt(squared_dist)
        return dist, idx

    @staticmethod
    def backward(ctx, grad_dist=None, grad_idx=None):
        # No gradient is propagated: one None per forward input
        # (unknown, known).
        no_grads = (None, None)
        return no_grads


three_nn = ThreeNN.apply
|
|
|
|
|
|
class ThreeInterpolate(Function):
    r"""
    Autograd wrapper around ``_ext.three_interpolate`` and its gradient kernel
    """

    @staticmethod
    def forward(ctx, features, idx, weight):
        # type: (Any, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
        r"""
        Weighted linear interpolation over 3 features

        Parameters
        ----------
        features : torch.Tensor
            (B, c, m) Features descriptors to be interpolated from
        idx : torch.Tensor
            (B, n, 3) three nearest neighbors of the target features in features
        weight : torch.Tensor
            (B, n, 3) weights

        Returns
        -------
        torch.Tensor
            (B, c, n) tensor of the interpolated features
        """
        # Only m is needed later; the 3-way unpack also rejects input that
        # is not 3-D.
        _batch, _channels, num_source = features.size()
        _num_target = idx.size(1)

        # Stash what backward() reads back.
        ctx.three_interpolate_for_backward = (idx, weight, num_source)

        interpolated = _ext.three_interpolate(features, idx, weight)
        return interpolated

    @staticmethod
    def backward(ctx, grad_out):
        # type: (Any, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        r"""
        Parameters
        ----------
        grad_out : torch.Tensor
            (B, c, n) tensor with gradients of outputs

        Returns
        -------
        grad_features : torch.Tensor
            (B, c, m) tensor with gradients of features

        None
            idx receives no gradient

        None
            weight receives no gradient
        """
        idx, weight, num_source = ctx.three_interpolate_for_backward

        # The kernel is handed a contiguous copy of the incoming gradient.
        upstream = grad_out.contiguous()
        grad_features = _ext.three_interpolate_grad(
            upstream, idx, weight, num_source
        )

        return grad_features, None, None


three_interpolate = ThreeInterpolate.apply
|
|
|
|
|
|
class GroupingOperation(Function):
    r"""
    Autograd wrapper around ``_ext.group_points`` and its gradient kernel
    """

    @staticmethod
    def forward(ctx, features, idx):
        # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor
        r"""
        Parameters
        ----------
        features : torch.Tensor
            (B, C, N) tensor of features to group
        idx : torch.Tensor
            (B, npoint, nsample) tensor containing the indices of features to group with

        Returns
        -------
        torch.Tensor
            (B, C, npoint, nsample) tensor
        """
        # Both 3-way unpacks reject tensors that are not 3-D; only N is
        # actually used afterwards.
        _batch, _npoint, _nsample = idx.size()
        _, _channels, num_points = features.size()

        # Stash what backward() reads back.
        ctx.for_backwards = (idx, num_points)

        grouped = _ext.group_points(features, idx)
        return grouped

    @staticmethod
    def backward(ctx, grad_out):
        # type: (Any, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
        r"""
        Parameters
        ----------
        grad_out : torch.Tensor
            (B, C, npoint, nsample) tensor of the gradients of the output from forward

        Returns
        -------
        torch.Tensor
            (B, C, N) gradient of the features
        None
            idx receives no gradient
        """
        idx, num_points = ctx.for_backwards

        # The kernel is handed a contiguous copy of the incoming gradient.
        upstream = grad_out.contiguous()
        grad_features = _ext.group_points_grad(upstream, idx, num_points)

        return grad_features, None


grouping_operation = GroupingOperation.apply
|
|
|
|
|
|
class BallQuery(Function):
    r"""
    Autograd wrapper around ``_ext.ball_query``
    """

    @staticmethod
    def forward(ctx, radius, nsample, xyz, new_xyz):
        # type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor
        r"""
        Parameters
        ----------
        radius : float
            radius of the balls
        nsample : int
            maximum number of features in the balls
        xyz : torch.Tensor
            (B, N, 3) xyz coordinates of the features
        new_xyz : torch.Tensor
            (B, npoint, 3) centers of the ball query

        Returns
        -------
        torch.Tensor
            (B, npoint, nsample) tensor with the indices of the features that form the query balls
        """
        # Note the argument order expected by the extension: centers first,
        # then the point cloud, then the scalars.
        ball_idx = _ext.ball_query(new_xyz, xyz, radius, nsample)
        return ball_idx

    @staticmethod
    def backward(ctx, grad_idx=None):
        # No gradient is propagated: one None per forward input
        # (radius, nsample, xyz, new_xyz).
        no_grads = (None, None, None, None)
        return no_grads


ball_query = BallQuery.apply
|
|
|
|
|
|
class QueryAndGroup(nn.Module):
    r"""
    Groups with a ball query of radius

    Parameters
    ---------
    radius : float32
        Radius of ball
    nsample : int32
        Maximum number of features to gather in the ball
    use_xyz : bool
        Concatenate the centered xyz offsets in front of the grouped features
    ret_grouped_xyz : bool
        Additionally return the grouped xyz offsets
    normalize_xyz : bool
        Divide the grouped xyz offsets by radius
    sample_uniformly : bool
        Rewrite every ball's indices as its unique indices padded with
        random re-draws from those unique indices
    ret_unique_cnt : bool
        Additionally return the per-ball count of unique indices
        (requires sample_uniformly)
    """

    def __init__(
        self,
        radius,
        nsample,
        use_xyz=True,
        ret_grouped_xyz=False,
        normalize_xyz=False,
        sample_uniformly=False,
        ret_unique_cnt=False,
    ):
        # type: (QueryAndGroup, float, int, bool, bool, bool, bool, bool) -> None
        super(QueryAndGroup, self).__init__()
        self.radius = radius
        self.nsample = nsample
        self.use_xyz = use_xyz
        self.ret_grouped_xyz = ret_grouped_xyz
        self.normalize_xyz = normalize_xyz
        self.sample_uniformly = sample_uniformly
        self.ret_unique_cnt = ret_unique_cnt
        # unique_cnt is only computed inside the uniform-sampling branch.
        if self.ret_unique_cnt:
            assert(self.sample_uniformly)

    def forward(self, xyz, new_xyz, features=None):
        # type: (QueryAndGroup, torch.Tensor, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor]
        r"""
        Parameters
        ----------
        xyz : torch.Tensor
            xyz coordinates of the features (B, N, 3)
        new_xyz : torch.Tensor
            centroids (B, npoint, 3)
        features : torch.Tensor
            Descriptors of the features (B, C, N)

        Returns
        -------
        new_features : torch.Tensor
            (B, 3 + C, npoint, nsample) tensor; followed by the grouped xyz
            and/or the unique counts (as a tuple) when the corresponding
            ret_* flags are set
        """
        idx = ball_query(self.radius, self.nsample, xyz, new_xyz)

        if self.sample_uniformly:
            num_batches, num_regions = idx.shape[0], idx.shape[1]
            unique_cnt = torch.zeros((num_batches, num_regions))
            for b in range(num_batches):
                for r in range(num_regions):
                    distinct = torch.unique(idx[b, r, :])
                    n_distinct = distinct.shape[0]
                    unique_cnt[b, r] = n_distinct
                    # Fill the remaining slots with random picks (with
                    # replacement) among the distinct indices.
                    filler = torch.randint(
                        0,
                        n_distinct,
                        (self.nsample - n_distinct,),
                        dtype=torch.long,
                    )
                    idx[b, r, :] = torch.cat((distinct, distinct[filler]))

        # (B, 3, npoint, nsample), then centered on each centroid in place.
        grouped_xyz = grouping_operation(xyz.transpose(1, 2).contiguous(), idx)
        grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1)
        if self.normalize_xyz:
            grouped_xyz /= self.radius

        if features is None:
            assert (
                self.use_xyz
            ), "Cannot have not features and not use xyz as a feature!"
            new_features = grouped_xyz
        else:
            grouped_features = grouping_operation(features, idx)
            if self.use_xyz:
                # (B, C + 3, npoint, nsample)
                new_features = torch.cat(
                    [grouped_xyz, grouped_features], dim=1
                )
            else:
                new_features = grouped_features

        outputs = [new_features]
        if self.ret_grouped_xyz:
            outputs.append(grouped_xyz)
        if self.ret_unique_cnt:
            outputs.append(unique_cnt)

        # A lone result is returned bare, several as a tuple.
        if len(outputs) == 1:
            return outputs[0]
        return tuple(outputs)
|
|
|
|
|
|
class GroupAll(nn.Module):
    r"""
    Groups all features

    Parameters
    ---------
    use_xyz : bool
        Concatenate the xyz coordinates in front of the features
    ret_grouped_xyz : bool
        Additionally return the grouped xyz coordinates
    """

    def __init__(self, use_xyz=True, ret_grouped_xyz=False):
        # type: (GroupAll, bool, bool) -> None
        super(GroupAll, self).__init__()
        self.use_xyz = use_xyz
        # forward() reads this flag; it was accepted here but never stored,
        # which made every forward() call fail with AttributeError.
        self.ret_grouped_xyz = ret_grouped_xyz

    def forward(self, xyz, new_xyz, features=None):
        # type: (GroupAll, torch.Tensor, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor]
        r"""
        Parameters
        ----------
        xyz : torch.Tensor
            xyz coordinates of the features (B, N, 3)
        new_xyz : torch.Tensor
            Ignored
        features : torch.Tensor
            Descriptors of the features (B, C, N)

        Returns
        -------
        new_features : torch.Tensor
            (B, C + 3, 1, N) tensor; (B, C, 1, N) when use_xyz is False and
            (B, 3, 1, N) when features is None
        grouped_xyz : torch.Tensor
            (B, 3, 1, N) tensor, only returned when ret_grouped_xyz is True
        """
        # (B, N, 3) -> (B, 3, N) -> (B, 3, 1, N): a single group holding
        # every point.
        grouped_xyz = xyz.transpose(1, 2).unsqueeze(2)
        if features is not None:
            grouped_features = features.unsqueeze(2)
            if self.use_xyz:
                new_features = torch.cat(
                    [grouped_xyz, grouped_features], dim=1
                )  # (B, 3 + C, 1, N)
            else:
                new_features = grouped_features
        else:
            # Without features the coordinates are the only thing to return,
            # regardless of use_xyz.
            new_features = grouped_xyz

        if self.ret_grouped_xyz:
            return new_features, grouped_xyz
        else:
            return new_features
|
|
|
|
|
|
class CylinderQuery(Function):
    r"""
    Autograd wrapper around ``_ext.cylinder_query``
    """

    @staticmethod
    def forward(ctx, radius, hmin, hmax, nsample, xyz, new_xyz, rot):
        # type: (Any, float, float, float, int, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
        r"""
        Parameters
        ----------
        radius : float
            radius of the cylinders
        hmin, hmax : float
            endpoints of cylinder height in x-rotation axis
        nsample : int
            maximum number of features in the cylinders
        xyz : torch.Tensor
            (B, N, 3) xyz coordinates of the features
        new_xyz : torch.Tensor
            (B, npoint, 3) centers of the cylinder query
        rot: torch.Tensor
            (B, npoint, 9) flatten rotation matrices from
            cylinder frame to world frame

        Returns
        -------
        torch.Tensor
            (B, npoint, nsample) tensor with the indices of the features that form the query cylinders
        """
        # Note the argument order expected by the extension: centers, point
        # cloud and rotations first, then the scalars.
        cylinder_idx = _ext.cylinder_query(
            new_xyz, xyz, rot, radius, hmin, hmax, nsample
        )
        return cylinder_idx

    @staticmethod
    def backward(ctx, grad_idx=None):
        # No gradient is propagated: one None per forward input
        # (radius, hmin, hmax, nsample, xyz, new_xyz, rot).
        no_grads = (None, None, None, None, None, None, None)
        return no_grads


cylinder_query = CylinderQuery.apply
|
|
|
|
|
|
class CylinderQueryAndGroup(nn.Module):
    r"""
    Groups with a cylinder query of radius and height

    Parameters
    ---------
    radius : float32
        Radius of cylinder
    hmin, hmax: float32
        endpoints of cylinder height in x-rotation axis
    nsample : int32
        Maximum number of features to gather in the cylinder
    use_xyz : bool
        Concatenate the centered xyz offsets in front of the grouped features
    ret_grouped_xyz : bool
        Additionally return the grouped xyz offsets
    normalize_xyz : bool
        Divide the grouped xyz offsets by radius
    rotate_xyz : bool
        Multiply the grouped xyz offsets by the per-region rotation matrices
    sample_uniformly : bool
        Rewrite every cylinder's indices as its unique indices padded with
        random re-draws from those unique indices
    ret_unique_cnt : bool
        Additionally return the per-cylinder count of unique indices
        (requires sample_uniformly)
    """

    def __init__(
        self,
        radius,
        hmin,
        hmax,
        nsample,
        use_xyz=True,
        ret_grouped_xyz=False,
        normalize_xyz=False,
        rotate_xyz=True,
        sample_uniformly=False,
        ret_unique_cnt=False,
    ):
        super(CylinderQueryAndGroup, self).__init__()
        self.radius = radius
        self.nsample = nsample
        self.hmin = hmin
        self.hmax = hmax
        self.use_xyz = use_xyz
        self.ret_grouped_xyz = ret_grouped_xyz
        self.normalize_xyz = normalize_xyz
        self.rotate_xyz = rotate_xyz
        self.sample_uniformly = sample_uniformly
        self.ret_unique_cnt = ret_unique_cnt
        # unique_cnt is only computed inside the uniform-sampling branch.
        if self.ret_unique_cnt:
            assert(self.sample_uniformly)

    def forward(self, xyz, new_xyz, rot, features=None):
        r"""
        Parameters
        ----------
        xyz : torch.Tensor
            xyz coordinates of the features (B, N, 3)
        new_xyz : torch.Tensor
            centroids (B, npoint, 3)
        rot : torch.Tensor
            rotation matrices (B, npoint, 3, 3)
        features : torch.Tensor
            Descriptors of the features (B, C, N)

        Returns
        -------
        new_features : torch.Tensor
            (B, 3 + C, npoint, nsample) tensor; followed by the grouped xyz
            and/or the unique counts (as a tuple) when the corresponding
            ret_* flags are set
        """
        num_batches, num_regions, _ = new_xyz.size()
        # The query takes the rotation matrices flattened to 9 values each.
        flat_rot = rot.view(num_batches, num_regions, 9)
        idx = cylinder_query(
            self.radius, self.hmin, self.hmax, self.nsample,
            xyz, new_xyz, flat_rot,
        )

        if self.sample_uniformly:
            unique_cnt = torch.zeros((idx.shape[0], idx.shape[1]))
            for b in range(idx.shape[0]):
                for r in range(idx.shape[1]):
                    distinct = torch.unique(idx[b, r, :])
                    n_distinct = distinct.shape[0]
                    unique_cnt[b, r] = n_distinct
                    # Fill the remaining slots with random picks (with
                    # replacement) among the distinct indices.
                    filler = torch.randint(
                        0,
                        n_distinct,
                        (self.nsample - n_distinct,),
                        dtype=torch.long,
                    )
                    idx[b, r, :] = torch.cat((distinct, distinct[filler]))

        # (B, 3, npoint, nsample), then centered on each centroid in place.
        grouped_xyz = grouping_operation(xyz.transpose(1, 2).contiguous(), idx)
        grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1)
        if self.normalize_xyz:
            grouped_xyz /= self.radius
        if self.rotate_xyz:
            # Channels-last view (B, npoint, nsample, 3) so each offset is a
            # row vector right-multiplied by its region's 3x3 matrix, then
            # back to channels-first.
            offsets = grouped_xyz.permute(0, 2, 3, 1).contiguous()
            offsets = torch.matmul(offsets, rot)
            grouped_xyz = offsets.permute(0, 3, 1, 2).contiguous()

        if features is None:
            assert (
                self.use_xyz
            ), "Cannot have not features and not use xyz as a feature!"
            new_features = grouped_xyz
        else:
            grouped_features = grouping_operation(features, idx)
            if self.use_xyz:
                # (B, C + 3, npoint, nsample)
                new_features = torch.cat(
                    [grouped_xyz, grouped_features], dim=1
                )
            else:
                new_features = grouped_features

        outputs = [new_features]
        if self.ret_grouped_xyz:
            outputs.append(grouped_xyz)
        if self.ret_unique_cnt:
            outputs.append(unique_cnt)

        # A lone result is returned bare, several as a tuple.
        if len(outputs) == 1:
            return outputs[0]
        return tuple(outputs)