learning3d 0.0.1__py3-none-any.whl
This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- learning3d/__init__.py +2 -0
- learning3d/data_utils/__init__.py +4 -0
- learning3d/data_utils/dataloaders.py +454 -0
- learning3d/data_utils/user_data.py +119 -0
- learning3d/examples/test_dcp.py +139 -0
- learning3d/examples/test_deepgmr.py +144 -0
- learning3d/examples/test_flownet.py +113 -0
- learning3d/examples/test_masknet.py +159 -0
- learning3d/examples/test_masknet2.py +162 -0
- learning3d/examples/test_pcn.py +118 -0
- learning3d/examples/test_pcrnet.py +120 -0
- learning3d/examples/test_pnlk.py +121 -0
- learning3d/examples/test_pointconv.py +126 -0
- learning3d/examples/test_pointnet.py +121 -0
- learning3d/examples/test_prnet.py +126 -0
- learning3d/examples/test_rpmnet.py +120 -0
- learning3d/examples/train_PointNetLK.py +240 -0
- learning3d/examples/train_dcp.py +249 -0
- learning3d/examples/train_deepgmr.py +244 -0
- learning3d/examples/train_flownet.py +259 -0
- learning3d/examples/train_masknet.py +239 -0
- learning3d/examples/train_pcn.py +216 -0
- learning3d/examples/train_pcrnet.py +228 -0
- learning3d/examples/train_pointconv.py +245 -0
- learning3d/examples/train_pointnet.py +244 -0
- learning3d/examples/train_prnet.py +229 -0
- learning3d/examples/train_rpmnet.py +228 -0
- learning3d/losses/__init__.py +12 -0
- learning3d/losses/chamfer_distance.py +51 -0
- learning3d/losses/classification.py +14 -0
- learning3d/losses/correspondence_loss.py +10 -0
- learning3d/losses/cuda/chamfer_distance/__init__.py +1 -0
- learning3d/losses/cuda/chamfer_distance/chamfer_distance.cpp +185 -0
- learning3d/losses/cuda/chamfer_distance/chamfer_distance.cu +209 -0
- learning3d/losses/cuda/chamfer_distance/chamfer_distance.py +66 -0
- learning3d/losses/cuda/emd_torch/pkg/emd_loss_layer.py +41 -0
- learning3d/losses/cuda/emd_torch/pkg/include/cuda/emd.cuh +347 -0
- learning3d/losses/cuda/emd_torch/pkg/include/cuda_helper.h +18 -0
- learning3d/losses/cuda/emd_torch/pkg/include/emd.h +54 -0
- learning3d/losses/cuda/emd_torch/pkg/layer/__init__.py +1 -0
- learning3d/losses/cuda/emd_torch/pkg/layer/emd_loss_layer.py +40 -0
- learning3d/losses/cuda/emd_torch/pkg/src/cuda/emd.cu +70 -0
- learning3d/losses/cuda/emd_torch/pkg/src/emd.cpp +1 -0
- learning3d/losses/cuda/emd_torch/setup.py +29 -0
- learning3d/losses/emd.py +16 -0
- learning3d/losses/frobenius_norm.py +21 -0
- learning3d/losses/rmse_features.py +16 -0
- learning3d/models/__init__.py +23 -0
- learning3d/models/classifier.py +41 -0
- learning3d/models/dcp.py +92 -0
- learning3d/models/deepgmr.py +165 -0
- learning3d/models/dgcnn.py +92 -0
- learning3d/models/flownet3d.py +446 -0
- learning3d/models/masknet.py +84 -0
- learning3d/models/masknet2.py +264 -0
- learning3d/models/pcn.py +164 -0
- learning3d/models/pcrnet.py +74 -0
- learning3d/models/pointconv.py +108 -0
- learning3d/models/pointnet.py +108 -0
- learning3d/models/pointnetlk.py +173 -0
- learning3d/models/pooling.py +15 -0
- learning3d/models/ppfnet.py +102 -0
- learning3d/models/prnet.py +431 -0
- learning3d/models/rpmnet.py +359 -0
- learning3d/models/segmentation.py +38 -0
- learning3d/ops/__init__.py +0 -0
- learning3d/ops/data_utils.py +45 -0
- learning3d/ops/invmat.py +134 -0
- learning3d/ops/quaternion.py +218 -0
- learning3d/ops/se3.py +157 -0
- learning3d/ops/sinc.py +229 -0
- learning3d/ops/so3.py +213 -0
- learning3d/ops/transform_functions.py +342 -0
- learning3d/utils/__init__.py +9 -0
- learning3d/utils/lib/build/lib.linux-x86_64-3.5/pointnet2_cuda.cpython-35m-x86_64-linux-gnu.so +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/ball_query.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/ball_query_gpu.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/group_points.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/group_points_gpu.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/interpolate.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/interpolate_gpu.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/pointnet2_api.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/sampling.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/sampling_gpu.o +0 -0
- learning3d/utils/lib/dist/pointnet2-0.0.0-py3.5-linux-x86_64.egg +0 -0
- learning3d/utils/lib/pointnet2.egg-info/SOURCES.txt +14 -0
- learning3d/utils/lib/pointnet2.egg-info/dependency_links.txt +1 -0
- learning3d/utils/lib/pointnet2.egg-info/top_level.txt +1 -0
- learning3d/utils/lib/pointnet2_modules.py +160 -0
- learning3d/utils/lib/pointnet2_utils.py +318 -0
- learning3d/utils/lib/pytorch_utils.py +236 -0
- learning3d/utils/lib/setup.py +23 -0
- learning3d/utils/lib/src/ball_query.cpp +25 -0
- learning3d/utils/lib/src/ball_query_gpu.cu +67 -0
- learning3d/utils/lib/src/ball_query_gpu.h +15 -0
- learning3d/utils/lib/src/cuda_utils.h +15 -0
- learning3d/utils/lib/src/group_points.cpp +36 -0
- learning3d/utils/lib/src/group_points_gpu.cu +86 -0
- learning3d/utils/lib/src/group_points_gpu.h +22 -0
- learning3d/utils/lib/src/interpolate.cpp +65 -0
- learning3d/utils/lib/src/interpolate_gpu.cu +233 -0
- learning3d/utils/lib/src/interpolate_gpu.h +36 -0
- learning3d/utils/lib/src/pointnet2_api.cpp +25 -0
- learning3d/utils/lib/src/sampling.cpp +46 -0
- learning3d/utils/lib/src/sampling_gpu.cu +253 -0
- learning3d/utils/lib/src/sampling_gpu.h +29 -0
- learning3d/utils/pointconv_util.py +382 -0
- learning3d/utils/ppfnet_util.py +244 -0
- learning3d/utils/svd.py +59 -0
- learning3d/utils/transformer.py +243 -0
- learning3d-0.0.1.dist-info/LICENSE +21 -0
- learning3d-0.0.1.dist-info/METADATA +271 -0
- learning3d-0.0.1.dist-info/RECORD +115 -0
- learning3d-0.0.1.dist-info/WHEEL +5 -0
- learning3d-0.0.1.dist-info/top_level.txt +1 -0
--- /dev/null
+++ b/learning3d/utils/lib/pointnet2_modules.py
@@ -0,0 +1,160 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from . import pointnet2_utils
+from . import pytorch_utils as pt_utils
+from typing import List
+
+
+class _PointnetSAModuleBase(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.npoint = None
+        self.groupers = None
+        self.mlps = None
+        self.pool_method = 'max_pool'
+
+    def forward(self, xyz: torch.Tensor, features: torch.Tensor = None, new_xyz=None) -> (torch.Tensor, torch.Tensor):
+        """
+        :param xyz: (B, N, 3) tensor of the xyz coordinates of the features
+        :param features: (B, N, C) tensor of the descriptors of the features
+        :param new_xyz:
+        :return:
+            new_xyz: (B, npoint, 3) tensor of the new features' xyz
+            new_features: (B, npoint, \sum_k(mlps[k][-1])) tensor of the new_features descriptors
+        """
+        new_features_list = []
+
+        xyz_flipped = xyz.transpose(1, 2).contiguous()
+        if new_xyz is None:
+            new_xyz = pointnet2_utils.gather_operation(
+                xyz_flipped,
+                pointnet2_utils.furthest_point_sample(xyz, self.npoint)
+            ).transpose(1, 2).contiguous() if self.npoint is not None else None
+
+        for i in range(len(self.groupers)):
+            new_features = self.groupers[i](xyz, new_xyz, features)  # (B, C, npoint, nsample)
+
+            new_features = self.mlps[i](new_features)  # (B, mlp[-1], npoint, nsample)
+            if self.pool_method == 'max_pool':
+                new_features = F.max_pool2d(
+                    new_features, kernel_size=[1, new_features.size(3)]
+                )  # (B, mlp[-1], npoint, 1)
+            elif self.pool_method == 'avg_pool':
+                new_features = F.avg_pool2d(
+                    new_features, kernel_size=[1, new_features.size(3)]
+                )  # (B, mlp[-1], npoint, 1)
+            else:
+                raise NotImplementedError
+
+            new_features = new_features.squeeze(-1)  # (B, mlp[-1], npoint)
+            new_features_list.append(new_features)
+
+        return new_xyz, torch.cat(new_features_list, dim=1)
+
+
+class PointnetSAModuleMSG(_PointnetSAModuleBase):
+    """Pointnet set abstraction layer with multiscale grouping"""
+
+    def __init__(self, *, npoint: int, radii: List[float], nsamples: List[int], mlps: List[List[int]], bn: bool = True,
+                 use_xyz: bool = True, pool_method='max_pool', instance_norm=False):
+        """
+        :param npoint: int
+        :param radii: list of float, list of radii to group with
+        :param nsamples: list of int, number of samples in each ball query
+        :param mlps: list of list of int, spec of the pointnet before the global pooling for each scale
+        :param bn: whether to use batchnorm
+        :param use_xyz:
+        :param pool_method: max_pool / avg_pool
+        :param instance_norm: whether to use instance_norm
+        """
+        super().__init__()
+
+        assert len(radii) == len(nsamples) == len(mlps)
+
+        self.npoint = npoint
+        self.groupers = nn.ModuleList()
+        self.mlps = nn.ModuleList()
+        for i in range(len(radii)):
+            radius = radii[i]
+            nsample = nsamples[i]
+            self.groupers.append(
+                pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz)
+                if npoint is not None else pointnet2_utils.GroupAll(use_xyz)
+            )
+            mlp_spec = mlps[i]
+            if use_xyz:
+                mlp_spec[0] += 3
+
+            self.mlps.append(pt_utils.SharedMLP(mlp_spec, bn=bn, instance_norm=instance_norm))
+        self.pool_method = pool_method
+
+
+class PointnetSAModule(PointnetSAModuleMSG):
+    """Pointnet set abstraction layer"""
+
+    def __init__(self, *, mlp: List[int], npoint: int = None, radius: float = None, nsample: int = None,
+                 bn: bool = True, use_xyz: bool = True, pool_method='max_pool', instance_norm=False):
+        """
+        :param mlp: list of int, spec of the pointnet before the global max_pool
+        :param npoint: int, number of features
+        :param radius: float, radius of ball
+        :param nsample: int, number of samples in the ball query
+        :param bn: whether to use batchnorm
+        :param use_xyz:
+        :param pool_method: max_pool / avg_pool
+        :param instance_norm: whether to use instance_norm
+        """
+        super().__init__(
+            mlps=[mlp], npoint=npoint, radii=[radius], nsamples=[nsample], bn=bn, use_xyz=use_xyz,
+            pool_method=pool_method, instance_norm=instance_norm
+        )
+
+
+class PointnetFPModule(nn.Module):
+    r"""Propagates the features of one set to another"""
+
+    def __init__(self, *, mlp: List[int], bn: bool = True):
+        """
+        :param mlp: list of int
+        :param bn: whether to use batchnorm
+        """
+        super().__init__()
+        self.mlp = pt_utils.SharedMLP(mlp, bn=bn)
+
+    def forward(
+            self, unknown: torch.Tensor, known: torch.Tensor, unknow_feats: torch.Tensor, known_feats: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        :param unknown: (B, n, 3) tensor of the xyz positions of the unknown features
+        :param known: (B, m, 3) tensor of the xyz positions of the known features
+        :param unknow_feats: (B, C1, n) tensor of the features to be propagated to
+        :param known_feats: (B, C2, m) tensor of features to be propagated
+        :return:
+            new_features: (B, mlp[-1], n) tensor of the features of the unknown features
+        """
+        if known is not None:
+            dist, idx = pointnet2_utils.three_nn(unknown, known)
+            dist_recip = 1.0 / (dist + 1e-8)
+            norm = torch.sum(dist_recip, dim=2, keepdim=True)
+            weight = dist_recip / norm
+
+            interpolated_feats = pointnet2_utils.three_interpolate(known_feats, idx, weight)
+        else:
+            interpolated_feats = known_feats.expand(*known_feats.size()[0:2], unknown.size(1))
+
+        if unknow_feats is not None:
+            new_features = torch.cat([interpolated_feats, unknow_feats], dim=1)  # (B, C2 + C1, n)
+        else:
+            new_features = interpolated_feats
+
+        new_features = new_features.unsqueeze(-1)
+        new_features = self.mlp(new_features)
+
+        return new_features.squeeze(-1)
+
+
+if __name__ == "__main__":
+    pass
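Usage note: a minimal sketch of how the two modules above fit together, assuming the bundled pointnet2_cuda extension has been compiled and a CUDA device is available; the flat import path is illustrative (inside the wheel the module ships as learning3d/utils/lib/pointnet2_modules.py). PointnetSAModule downsamples and encodes; PointnetFPModule interpolates the encoded features back onto the full point set with inverse-distance weights.

    import torch
    from pointnet2_modules import PointnetSAModule, PointnetFPModule  # illustrative import path

    B, N = 2, 1024
    xyz = torch.rand(B, N, 3).cuda()  # (B, N, 3) coordinates; the CUDA kernels need GPU tensors

    # Set abstraction: sample 256 FPS centroids, ball-query 32 neighbors within radius 0.2,
    # then run a shared MLP. With use_xyz=True the first width gains +3 internally, so
    # mlp=[0, 32, 64] becomes an MLP over the 3 grouped xyz channels.
    sa = PointnetSAModule(mlp=[0, 32, 64], npoint=256, radius=0.2, nsample=32).cuda()
    new_xyz, new_feats = sa(xyz)  # (B, 256, 3), (B, 64, 256)

    # Feature propagation: push the 256-point features back onto all N points.
    fp = PointnetFPModule(mlp=[64, 64]).cuda()
    dense_feats = fp(xyz, new_xyz, None, new_feats)  # (B, 64, N)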
--- /dev/null
+++ b/learning3d/utils/lib/pointnet2_utils.py
@@ -0,0 +1,318 @@
+import torch
+from torch.autograd import Variable
+from torch.autograd import Function
+import torch.nn as nn
+from typing import Tuple
+
+import pointnet2_cuda as pointnet2
+
+
+class FurthestPointSampling(Function):
+    @staticmethod
+    def forward(ctx, xyz: torch.Tensor, npoint: int) -> torch.Tensor:
+        """
+        Uses iterative furthest point sampling to select a set of npoint features that have the largest
+        minimum distance
+        :param ctx:
+        :param xyz: (B, N, 3) where N > npoint
+        :param npoint: int, number of features in the sampled set
+        :return:
+            output: (B, npoint) tensor containing the set
+        """
+        assert xyz.is_contiguous()
+
+        B, N, _ = xyz.size()
+        output = torch.cuda.IntTensor(B, npoint)
+        temp = torch.cuda.FloatTensor(B, N).fill_(1e10)
+
+        pointnet2.furthest_point_sampling_wrapper(B, N, npoint, xyz, temp, output)
+        return output
+
+    @staticmethod
+    def backward(ctx, a=None):
+        return None, None
+
+
+furthest_point_sample = FurthestPointSampling.apply
+
+
+class GatherOperation(Function):
+
+    @staticmethod
+    def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
+        """
+        :param ctx:
+        :param features: (B, C, N)
+        :param idx: (B, npoint) index tensor of the features to gather
+        :return:
+            output: (B, C, npoint)
+        """
+        assert features.is_contiguous()
+        assert idx.is_contiguous()
+
+        B, npoint = idx.size()
+        _, C, N = features.size()
+        output = torch.cuda.FloatTensor(B, C, npoint)
+
+        pointnet2.gather_points_wrapper(B, C, N, npoint, features, idx, output)
+
+        ctx.for_backwards = (idx, C, N)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_out):
+        idx, C, N = ctx.for_backwards
+        B, npoint = idx.size()
+
+        grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_())
+        grad_out_data = grad_out.data.contiguous()
+        pointnet2.gather_points_grad_wrapper(B, C, N, npoint, grad_out_data, idx, grad_features.data)
+        return grad_features, None
+
+
+gather_operation = GatherOperation.apply
+
+class KNN(Function):
+
+    @staticmethod
+    def forward(ctx, k: int, unknown: torch.Tensor, known: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Find the k nearest neighbors of unknown in known
+        :param ctx:
+        :param unknown: (B, N, 3)
+        :param known: (B, M, 3)
+        :return:
+            dist: (B, N, k) l2 distance to the k nearest neighbors
+            idx: (B, N, k) index of the k nearest neighbors
+        """
+        assert unknown.is_contiguous()
+        assert known.is_contiguous()
+
+        B, N, _ = unknown.size()
+        m = known.size(1)
+        dist2 = torch.cuda.FloatTensor(B, N, k)
+        idx = torch.cuda.IntTensor(B, N, k)
+
+        pointnet2.knn_wrapper(B, N, m, k, unknown, known, dist2, idx)
+        return torch.sqrt(dist2), idx
+
+    @staticmethod
+    def backward(ctx, a=None, b=None):
+        return None, None, None
+knn = KNN.apply
+
+class ThreeNN(Function):
+
+    @staticmethod
+    def forward(ctx, unknown: torch.Tensor, known: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Find the three nearest neighbors of unknown in known
+        :param ctx:
+        :param unknown: (B, N, 3)
+        :param known: (B, M, 3)
+        :return:
+            dist: (B, N, 3) l2 distance to the three nearest neighbors
+            idx: (B, N, 3) index of the 3 nearest neighbors
+        """
+        assert unknown.is_contiguous()
+        assert known.is_contiguous()
+
+        B, N, _ = unknown.size()
+        m = known.size(1)
+        dist2 = torch.cuda.FloatTensor(B, N, 3)
+        idx = torch.cuda.IntTensor(B, N, 3)
+
+        pointnet2.three_nn_wrapper(B, N, m, unknown, known, dist2, idx)
+        return torch.sqrt(dist2), idx
+
+    @staticmethod
+    def backward(ctx, a=None, b=None):
+        return None, None
+
+
+three_nn = ThreeNN.apply
+
+
+class ThreeInterpolate(Function):
+
+    @staticmethod
+    def forward(ctx, features: torch.Tensor, idx: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
+        """
+        Performs weighted linear interpolation on 3 features
+        :param ctx:
+        :param features: (B, C, M) features descriptors to be interpolated from
+        :param idx: (B, n, 3) three nearest neighbors of the target features in features
+        :param weight: (B, n, 3) weights
+        :return:
+            output: (B, C, N) tensor of the interpolated features
+        """
+        assert features.is_contiguous()
+        assert idx.is_contiguous()
+        assert weight.is_contiguous()
+
+        B, c, m = features.size()
+        n = idx.size(1)
+        ctx.three_interpolate_for_backward = (idx, weight, m)
+        output = torch.cuda.FloatTensor(B, c, n)
+
+        pointnet2.three_interpolate_wrapper(B, c, m, n, features, idx, weight, output)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        :param ctx:
+        :param grad_out: (B, C, N) tensor with gradients of outputs
+        :return:
+            grad_features: (B, C, M) tensor with gradients of features
+            None:
+            None:
+        """
+        idx, weight, m = ctx.three_interpolate_for_backward
+        B, c, n = grad_out.size()
+
+        grad_features = Variable(torch.cuda.FloatTensor(B, c, m).zero_())
+        grad_out_data = grad_out.data.contiguous()
+
+        pointnet2.three_interpolate_grad_wrapper(B, c, n, m, grad_out_data, idx, weight, grad_features.data)
+        return grad_features, None, None
+
+
+three_interpolate = ThreeInterpolate.apply
+
+
+class GroupingOperation(Function):
+
+    @staticmethod
+    def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
+        """
+        :param ctx:
+        :param features: (B, C, N) tensor of features to group
+        :param idx: (B, npoint, nsample) tensor containing the indices of features to group with
+        :return:
+            output: (B, C, npoint, nsample) tensor
+        """
+        assert features.is_contiguous()
+        assert idx.is_contiguous()
+        idx = idx.int()
+        B, nfeatures, nsample = idx.size()
+        _, C, N = features.size()
+        output = torch.cuda.FloatTensor(B, C, nfeatures, nsample)
+
+        pointnet2.group_points_wrapper(B, C, N, nfeatures, nsample, features, idx, output)
+
+        ctx.for_backwards = (idx, N)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        :param ctx:
+        :param grad_out: (B, C, npoint, nsample) tensor of the gradients of the output from forward
+        :return:
+            grad_features: (B, C, N) gradient of the features
+        """
+        idx, N = ctx.for_backwards
+
+        B, C, npoint, nsample = grad_out.size()
+        grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_())
+
+        grad_out_data = grad_out.data.contiguous()
+        pointnet2.group_points_grad_wrapper(B, C, N, npoint, nsample, grad_out_data, idx, grad_features.data)
+        return grad_features, None
+
+
+grouping_operation = GroupingOperation.apply
+
+
+class BallQuery(Function):
+
+    @staticmethod
+    def forward(ctx, radius: float, nsample: int, xyz: torch.Tensor, new_xyz: torch.Tensor) -> torch.Tensor:
+        """
+        :param ctx:
+        :param radius: float, radius of the balls
+        :param nsample: int, maximum number of features in the balls
+        :param xyz: (B, N, 3) xyz coordinates of the features
+        :param new_xyz: (B, npoint, 3) centers of the ball query
+        :return:
+            idx: (B, npoint, nsample) tensor with the indices of the features that form the query balls
+        """
+        assert new_xyz.is_contiguous()
+        assert xyz.is_contiguous()
+
+        B, N, _ = xyz.size()
+        npoint = new_xyz.size(1)
+        idx = torch.cuda.IntTensor(B, npoint, nsample).zero_()
+
+        pointnet2.ball_query_wrapper(B, N, npoint, radius, nsample, new_xyz, xyz, idx)
+        return idx
+
+    @staticmethod
+    def backward(ctx, a=None):
+        return None, None, None, None
+
+
+ball_query = BallQuery.apply
+
+
+class QueryAndGroup(nn.Module):
+    def __init__(self, radius: float, nsample: int, use_xyz: bool = True):
+        """
+        :param radius: float, radius of ball
+        :param nsample: int, maximum number of features to gather in the ball
+        :param use_xyz:
+        """
+        super().__init__()
+        self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz
+
+    def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: torch.Tensor = None) -> Tuple[torch.Tensor]:
+        """
+        :param xyz: (B, N, 3) xyz coordinates of the features
+        :param new_xyz: (B, npoint, 3) centroids
+        :param features: (B, C, N) descriptors of the features
+        :return:
+            new_features: (B, 3 + C, npoint, nsample)
+        """
+        idx = ball_query(self.radius, self.nsample, xyz, new_xyz)
+        xyz_trans = xyz.transpose(1, 2).contiguous()
+        grouped_xyz = grouping_operation(xyz_trans, idx)  # (B, 3, npoint, nsample)
+        grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1)
+
+        if features is not None:
+            grouped_features = grouping_operation(features, idx)
+            if self.use_xyz:
+                new_features = torch.cat([grouped_xyz, grouped_features], dim=1)  # (B, C + 3, npoint, nsample)
+            else:
+                new_features = grouped_features
+        else:
+            assert self.use_xyz, "Cannot have features=None and use_xyz=False at the same time!"
+            new_features = grouped_xyz
+
+        return new_features
+
+
+class GroupAll(nn.Module):
+    def __init__(self, use_xyz: bool = True):
+        super().__init__()
+        self.use_xyz = use_xyz
+
+    def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: torch.Tensor = None):
+        """
+        :param xyz: (B, N, 3) xyz coordinates of the features
+        :param new_xyz: ignored
+        :param features: (B, C, N) descriptors of the features
+        :return:
+            new_features: (B, C + 3, 1, N)
+        """
+        grouped_xyz = xyz.transpose(1, 2).unsqueeze(2)
+        if features is not None:
+            grouped_features = features.unsqueeze(2)
+            if self.use_xyz:
+                new_features = torch.cat([grouped_xyz, grouped_features], dim=1)  # (B, 3 + C, 1, N)
+            else:
+                new_features = grouped_features
+        else:
+            new_features = grouped_xyz
+
+        return new_features
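Usage note: a minimal sketch of the low-level wrappers above, under the same assumptions (compiled pointnet2_cuda extension, GPU-resident tensors, illustrative import path). It reproduces the sample-gather-group pipeline that _PointnetSAModuleBase.forward performs internally.

    import torch
    from pointnet2_utils import furthest_point_sample, gather_operation, QueryAndGroup  # illustrative path

    B, N, C = 2, 4096, 16
    xyz = torch.rand(B, N, 3).cuda()        # point coordinates
    feats = torch.rand(B, C, N).cuda()      # per-point descriptors

    idx = furthest_point_sample(xyz, 128)   # (B, 128) int32 indices of the FPS centroids
    new_xyz = gather_operation(xyz.transpose(1, 2).contiguous(), idx)  # (B, 3, 128)
    new_xyz = new_xyz.transpose(1, 2).contiguous()                     # (B, 128, 3)

    grouper = QueryAndGroup(radius=0.1, nsample=32, use_xyz=True)
    grouped = grouper(xyz, new_xyz, feats)  # (B, 3 + C, 128, 32), xyz centered on each centroid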
--- /dev/null
+++ b/learning3d/utils/lib/pytorch_utils.py
@@ -0,0 +1,236 @@
+import torch.nn as nn
+from typing import List, Tuple
+
+
+class SharedMLP(nn.Sequential):
+
+    def __init__(
+            self,
+            args: List[int],
+            *,
+            bn: bool = False,
+            activation=nn.ReLU(inplace=True),
+            preact: bool = False,
+            first: bool = False,
+            name: str = "",
+            instance_norm: bool = False,
+    ):
+        super().__init__()
+
+        for i in range(len(args) - 1):
+            self.add_module(
+                name + 'layer{}'.format(i),
+                Conv2d(
+                    args[i],
+                    args[i + 1],
+                    bn=(not first or not preact or (i != 0)) and bn,
+                    activation=activation
+                    if (not first or not preact or (i != 0)) else None,
+                    preact=preact,
+                    instance_norm=instance_norm
+                )
+            )
+
+
+class _ConvBase(nn.Sequential):
+
+    def __init__(
+            self,
+            in_size,
+            out_size,
+            kernel_size,
+            stride,
+            padding,
+            activation,
+            bn,
+            init,
+            conv=None,
+            batch_norm=None,
+            bias=True,
+            preact=False,
+            name="",
+            instance_norm=False,
+            instance_norm_func=None
+    ):
+        super().__init__()
+
+        bias = bias and (not bn)
+        conv_unit = conv(
+            in_size,
+            out_size,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            bias=bias
+        )
+        init(conv_unit.weight)
+        if bias:
+            nn.init.constant_(conv_unit.bias, 0)
+
+        if bn:
+            if not preact:
+                bn_unit = batch_norm(out_size)
+            else:
+                bn_unit = batch_norm(in_size)
+        if instance_norm:
+            if not preact:
+                in_unit = instance_norm_func(out_size, affine=False, track_running_stats=False)
+            else:
+                in_unit = instance_norm_func(in_size, affine=False, track_running_stats=False)
+
+        if preact:
+            if bn:
+                self.add_module(name + 'bn', bn_unit)
+
+            if activation is not None:
+                self.add_module(name + 'activation', activation)
+
+            if not bn and instance_norm:
+                self.add_module(name + 'in', in_unit)
+
+        self.add_module(name + 'conv', conv_unit)
+
+        if not preact:
+            if bn:
+                self.add_module(name + 'bn', bn_unit)
+
+            if activation is not None:
+                self.add_module(name + 'activation', activation)
+
+            if not bn and instance_norm:
+                self.add_module(name + 'in', in_unit)
+
+
+class _BNBase(nn.Sequential):
+
+    def __init__(self, in_size, batch_norm=None, name=""):
+        super().__init__()
+        self.add_module(name + "bn", batch_norm(in_size))
+
+        nn.init.constant_(self[0].weight, 1.0)
+        nn.init.constant_(self[0].bias, 0)
+
+
+class BatchNorm1d(_BNBase):
+
+    def __init__(self, in_size: int, *, name: str = ""):
+        super().__init__(in_size, batch_norm=nn.BatchNorm1d, name=name)
+
+
+class BatchNorm2d(_BNBase):
+
+    def __init__(self, in_size: int, name: str = ""):
+        super().__init__(in_size, batch_norm=nn.BatchNorm2d, name=name)
+
+
+class Conv1d(_ConvBase):
+
+    def __init__(
+            self,
+            in_size: int,
+            out_size: int,
+            *,
+            kernel_size: int = 1,
+            stride: int = 1,
+            padding: int = 0,
+            activation=nn.ReLU(inplace=True),
+            bn: bool = False,
+            init=nn.init.kaiming_normal_,
+            bias: bool = True,
+            preact: bool = False,
+            name: str = "",
+            instance_norm=False
+    ):
+        super().__init__(
+            in_size,
+            out_size,
+            kernel_size,
+            stride,
+            padding,
+            activation,
+            bn,
+            init,
+            conv=nn.Conv1d,
+            batch_norm=BatchNorm1d,
+            bias=bias,
+            preact=preact,
+            name=name,
+            instance_norm=instance_norm,
+            instance_norm_func=nn.InstanceNorm1d
+        )
+
+
+class Conv2d(_ConvBase):
+
+    def __init__(
+            self,
+            in_size: int,
+            out_size: int,
+            *,
+            kernel_size: Tuple[int, int] = (1, 1),
+            stride: Tuple[int, int] = (1, 1),
+            padding: Tuple[int, int] = (0, 0),
+            activation=nn.ReLU(inplace=True),
+            bn: bool = False,
+            init=nn.init.kaiming_normal_,
+            bias: bool = True,
+            preact: bool = False,
+            name: str = "",
+            instance_norm=False
+    ):
+        super().__init__(
+            in_size,
+            out_size,
+            kernel_size,
+            stride,
+            padding,
+            activation,
+            bn,
+            init,
+            conv=nn.Conv2d,
+            batch_norm=BatchNorm2d,
+            bias=bias,
+            preact=preact,
+            name=name,
+            instance_norm=instance_norm,
+            instance_norm_func=nn.InstanceNorm2d
+        )
+
+
+class FC(nn.Sequential):
+
+    def __init__(
+            self,
+            in_size: int,
+            out_size: int,
+            *,
+            activation=nn.ReLU(inplace=True),
+            bn: bool = False,
+            init=None,
+            preact: bool = False,
+            name: str = ""
+    ):
+        super().__init__()
+
+        fc = nn.Linear(in_size, out_size, bias=not bn)
+        if init is not None:
+            init(fc.weight)
+        if not bn:
+            nn.init.constant_(fc.bias, 0)
+
+        if preact:
+            if bn:
+                self.add_module(name + 'bn', BatchNorm1d(in_size))
+
+            if activation is not None:
+                self.add_module(name + 'activation', activation)
+
+        self.add_module(name + 'fc', fc)
+
+        if not preact:
+            if bn:
+                self.add_module(name + 'bn', BatchNorm1d(out_size))
+
+            if activation is not None:
+                self.add_module(name + 'activation', activation)
+
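Usage note: a minimal sketch of the SharedMLP and FC helpers above. Unlike the CUDA wrappers, these are plain PyTorch modules and run on CPU; the import path is again illustrative (the file ships as learning3d/utils/lib/pytorch_utils.py).

    import torch
    from pytorch_utils import SharedMLP, FC  # illustrative import path

    # SharedMLP stacks 1x1 Conv2d layers (here with BN + ReLU), applied pointwise
    # over a (B, C, npoint, nsample) tensor of grouped features.
    mlp = SharedMLP([6, 32, 64], bn=True)
    x = torch.rand(4, 6, 256, 32)             # (B, C_in, npoint, nsample)
    y = mlp(x)                                # (4, 64, 256, 32)

    # FC is a linear layer with optional BN/activation; here a bare classifier head
    # applied after a global max-pool over the npoint and nsample dimensions.
    head = FC(64, 10, activation=None, bn=False)
    logits = head(y.amax(dim=(2, 3)))         # (4, 10)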