learning3d 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115)
  1. learning3d/__init__.py +2 -0
  2. learning3d/data_utils/__init__.py +4 -0
  3. learning3d/data_utils/dataloaders.py +454 -0
  4. learning3d/data_utils/user_data.py +119 -0
  5. learning3d/examples/test_dcp.py +139 -0
  6. learning3d/examples/test_deepgmr.py +144 -0
  7. learning3d/examples/test_flownet.py +113 -0
  8. learning3d/examples/test_masknet.py +159 -0
  9. learning3d/examples/test_masknet2.py +162 -0
  10. learning3d/examples/test_pcn.py +118 -0
  11. learning3d/examples/test_pcrnet.py +120 -0
  12. learning3d/examples/test_pnlk.py +121 -0
  13. learning3d/examples/test_pointconv.py +126 -0
  14. learning3d/examples/test_pointnet.py +121 -0
  15. learning3d/examples/test_prnet.py +126 -0
  16. learning3d/examples/test_rpmnet.py +120 -0
  17. learning3d/examples/train_PointNetLK.py +240 -0
  18. learning3d/examples/train_dcp.py +249 -0
  19. learning3d/examples/train_deepgmr.py +244 -0
  20. learning3d/examples/train_flownet.py +259 -0
  21. learning3d/examples/train_masknet.py +239 -0
  22. learning3d/examples/train_pcn.py +216 -0
  23. learning3d/examples/train_pcrnet.py +228 -0
  24. learning3d/examples/train_pointconv.py +245 -0
  25. learning3d/examples/train_pointnet.py +244 -0
  26. learning3d/examples/train_prnet.py +229 -0
  27. learning3d/examples/train_rpmnet.py +228 -0
  28. learning3d/losses/__init__.py +12 -0
  29. learning3d/losses/chamfer_distance.py +51 -0
  30. learning3d/losses/classification.py +14 -0
  31. learning3d/losses/correspondence_loss.py +10 -0
  32. learning3d/losses/cuda/chamfer_distance/__init__.py +1 -0
  33. learning3d/losses/cuda/chamfer_distance/chamfer_distance.cpp +185 -0
  34. learning3d/losses/cuda/chamfer_distance/chamfer_distance.cu +209 -0
  35. learning3d/losses/cuda/chamfer_distance/chamfer_distance.py +66 -0
  36. learning3d/losses/cuda/emd_torch/pkg/emd_loss_layer.py +41 -0
  37. learning3d/losses/cuda/emd_torch/pkg/include/cuda/emd.cuh +347 -0
  38. learning3d/losses/cuda/emd_torch/pkg/include/cuda_helper.h +18 -0
  39. learning3d/losses/cuda/emd_torch/pkg/include/emd.h +54 -0
  40. learning3d/losses/cuda/emd_torch/pkg/layer/__init__.py +1 -0
  41. learning3d/losses/cuda/emd_torch/pkg/layer/emd_loss_layer.py +40 -0
  42. learning3d/losses/cuda/emd_torch/pkg/src/cuda/emd.cu +70 -0
  43. learning3d/losses/cuda/emd_torch/pkg/src/emd.cpp +1 -0
  44. learning3d/losses/cuda/emd_torch/setup.py +29 -0
  45. learning3d/losses/emd.py +16 -0
  46. learning3d/losses/frobenius_norm.py +21 -0
  47. learning3d/losses/rmse_features.py +16 -0
  48. learning3d/models/__init__.py +23 -0
  49. learning3d/models/classifier.py +41 -0
  50. learning3d/models/dcp.py +92 -0
  51. learning3d/models/deepgmr.py +165 -0
  52. learning3d/models/dgcnn.py +92 -0
  53. learning3d/models/flownet3d.py +446 -0
  54. learning3d/models/masknet.py +84 -0
  55. learning3d/models/masknet2.py +264 -0
  56. learning3d/models/pcn.py +164 -0
  57. learning3d/models/pcrnet.py +74 -0
  58. learning3d/models/pointconv.py +108 -0
  59. learning3d/models/pointnet.py +108 -0
  60. learning3d/models/pointnetlk.py +173 -0
  61. learning3d/models/pooling.py +15 -0
  62. learning3d/models/ppfnet.py +102 -0
  63. learning3d/models/prnet.py +431 -0
  64. learning3d/models/rpmnet.py +359 -0
  65. learning3d/models/segmentation.py +38 -0
  66. learning3d/ops/__init__.py +0 -0
  67. learning3d/ops/data_utils.py +45 -0
  68. learning3d/ops/invmat.py +134 -0
  69. learning3d/ops/quaternion.py +218 -0
  70. learning3d/ops/se3.py +157 -0
  71. learning3d/ops/sinc.py +229 -0
  72. learning3d/ops/so3.py +213 -0
  73. learning3d/ops/transform_functions.py +342 -0
  74. learning3d/utils/__init__.py +9 -0
  75. learning3d/utils/lib/build/lib.linux-x86_64-3.5/pointnet2_cuda.cpython-35m-x86_64-linux-gnu.so +0 -0
  76. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/ball_query.o +0 -0
  77. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/ball_query_gpu.o +0 -0
  78. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/group_points.o +0 -0
  79. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/group_points_gpu.o +0 -0
  80. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/interpolate.o +0 -0
  81. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/interpolate_gpu.o +0 -0
  82. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/pointnet2_api.o +0 -0
  83. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/sampling.o +0 -0
  84. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/sampling_gpu.o +0 -0
  85. learning3d/utils/lib/dist/pointnet2-0.0.0-py3.5-linux-x86_64.egg +0 -0
  86. learning3d/utils/lib/pointnet2.egg-info/SOURCES.txt +14 -0
  87. learning3d/utils/lib/pointnet2.egg-info/dependency_links.txt +1 -0
  88. learning3d/utils/lib/pointnet2.egg-info/top_level.txt +1 -0
  89. learning3d/utils/lib/pointnet2_modules.py +160 -0
  90. learning3d/utils/lib/pointnet2_utils.py +318 -0
  91. learning3d/utils/lib/pytorch_utils.py +236 -0
  92. learning3d/utils/lib/setup.py +23 -0
  93. learning3d/utils/lib/src/ball_query.cpp +25 -0
  94. learning3d/utils/lib/src/ball_query_gpu.cu +67 -0
  95. learning3d/utils/lib/src/ball_query_gpu.h +15 -0
  96. learning3d/utils/lib/src/cuda_utils.h +15 -0
  97. learning3d/utils/lib/src/group_points.cpp +36 -0
  98. learning3d/utils/lib/src/group_points_gpu.cu +86 -0
  99. learning3d/utils/lib/src/group_points_gpu.h +22 -0
  100. learning3d/utils/lib/src/interpolate.cpp +65 -0
  101. learning3d/utils/lib/src/interpolate_gpu.cu +233 -0
  102. learning3d/utils/lib/src/interpolate_gpu.h +36 -0
  103. learning3d/utils/lib/src/pointnet2_api.cpp +25 -0
  104. learning3d/utils/lib/src/sampling.cpp +46 -0
  105. learning3d/utils/lib/src/sampling_gpu.cu +253 -0
  106. learning3d/utils/lib/src/sampling_gpu.h +29 -0
  107. learning3d/utils/pointconv_util.py +382 -0
  108. learning3d/utils/ppfnet_util.py +244 -0
  109. learning3d/utils/svd.py +59 -0
  110. learning3d/utils/transformer.py +243 -0
  111. learning3d-0.0.1.dist-info/LICENSE +21 -0
  112. learning3d-0.0.1.dist-info/METADATA +271 -0
  113. learning3d-0.0.1.dist-info/RECORD +115 -0
  114. learning3d-0.0.1.dist-info/WHEEL +5 -0
  115. learning3d-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,160 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from . import pointnet2_utils
6
+ from . import pytorch_utils as pt_utils
7
+ from typing import List
8
+
9
+
10
+ class _PointnetSAModuleBase(nn.Module):
11
+
12
+ def __init__(self):
13
+ super().__init__()
14
+ self.npoint = None
15
+ self.groupers = None
16
+ self.mlps = None
17
+ self.pool_method = 'max_pool'
18
+
19
+ def forward(self, xyz: torch.Tensor, features: torch.Tensor = None, new_xyz=None) -> (torch.Tensor, torch.Tensor):
20
+ """
21
+ :param xyz: (B, N, 3) tensor of the xyz coordinates of the features
22
+ :param features: (B, N, C) tensor of the descriptors of the the features
23
+ :param new_xyz:
24
+ :return:
25
+ new_xyz: (B, npoint, 3) tensor of the new features' xyz
26
+ new_features: (B, npoint, \sum_k(mlps[k][-1])) tensor of the new_features descriptors
27
+ """
28
+ new_features_list = []
29
+
30
+ xyz_flipped = xyz.transpose(1, 2).contiguous()
31
+ if new_xyz is None:
32
+ new_xyz = pointnet2_utils.gather_operation(
33
+ xyz_flipped,
34
+ pointnet2_utils.furthest_point_sample(xyz, self.npoint)
35
+ ).transpose(1, 2).contiguous() if self.npoint is not None else None
36
+
37
+ for i in range(len(self.groupers)):
38
+ new_features = self.groupers[i](xyz, new_xyz, features) # (B, C, npoint, nsample)
39
+
40
+ new_features = self.mlps[i](new_features) # (B, mlp[-1], npoint, nsample)
41
+ if self.pool_method == 'max_pool':
42
+ new_features = F.max_pool2d(
43
+ new_features, kernel_size=[1, new_features.size(3)]
44
+ ) # (B, mlp[-1], npoint, 1)
45
+ elif self.pool_method == 'avg_pool':
46
+ new_features = F.avg_pool2d(
47
+ new_features, kernel_size=[1, new_features.size(3)]
48
+ ) # (B, mlp[-1], npoint, 1)
49
+ else:
50
+ raise NotImplementedError
51
+
52
+ new_features = new_features.squeeze(-1) # (B, mlp[-1], npoint)
53
+ new_features_list.append(new_features)
54
+
55
+ return new_xyz, torch.cat(new_features_list, dim=1)
56
+
57
+
58
class PointnetSAModuleMSG(_PointnetSAModuleBase):
    """Pointnet set abstraction layer with multiscale grouping"""

    def __init__(self, *, npoint: int, radii: List[float], nsamples: List[int], mlps: List[List[int]], bn: bool = True,
                 use_xyz: bool = True, pool_method='max_pool', instance_norm=False):
        """
        :param npoint: int, number of centroids to sample; None groups the whole cloud (GroupAll)
        :param radii: list of float, list of radii to group with
        :param nsamples: list of int, number of samples in each ball query
        :param mlps: list of list of int, channel spec of the shared MLP before pooling for each
            scale. When use_xyz is set, 3 is added to the first entry of an internal copy; the
            caller's lists are never modified.
        :param bn: whether to use batchnorm
        :param use_xyz: whether the relative xyz of each neighbour is concatenated to its features
        :param pool_method: max_pool / avg_pool
        :param instance_norm: whether to use instance_norm
        """
        super().__init__()

        assert len(radii) == len(nsamples) == len(mlps)

        self.npoint = npoint
        self.groupers = nn.ModuleList()
        self.mlps = nn.ModuleList()
        for radius, nsample, spec in zip(radii, nsamples, mlps):
            if npoint is not None:
                grouper = pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz)
            else:
                grouper = pointnet2_utils.GroupAll(use_xyz)
            self.groupers.append(grouper)

            # Work on a copy: the previous in-place "mlps[i][0] += 3" corrupted the
            # caller's spec, so reusing one spec list for several modules kept adding 3.
            mlp_spec = list(spec)
            if use_xyz:
                mlp_spec[0] += 3  # the 3 grouped xyz channels are concatenated to the features

            self.mlps.append(pt_utils.SharedMLP(mlp_spec, bn=bn, instance_norm=instance_norm))
        self.pool_method = pool_method
93
+
94
+
95
class PointnetSAModule(PointnetSAModuleMSG):
    """Pointnet set abstraction layer (single scale).

    Convenience wrapper around PointnetSAModuleMSG with exactly one radius,
    one sample count and one MLP spec.
    """

    def __init__(self, *, mlp: List[int], npoint: int = None, radius: float = None, nsample: int = None,
                 bn: bool = True, use_xyz: bool = True, pool_method='max_pool', instance_norm=False):
        """
        :param mlp: list of int, channel spec of the shared MLP applied before pooling
        :param npoint: int, number of centroids to sample (None groups all points)
        :param radius: float, radius of the ball query
        :param nsample: int, number of samples in the ball query
        :param bn: whether to use batchnorm
        :param use_xyz: whether relative xyz is concatenated to the grouped features
        :param pool_method: max_pool / avg_pool
        :param instance_norm: whether to use instance_norm
        """
        single_scale = dict(radii=[radius], nsamples=[nsample], mlps=[mlp])
        super().__init__(
            npoint=npoint,
            bn=bn,
            use_xyz=use_xyz,
            pool_method=pool_method,
            instance_norm=instance_norm,
            **single_scale
        )
114
+
115
+
116
class PointnetFPModule(nn.Module):
    r"""Propagates the features of one point set to another (PointNet++ feature propagation)."""

    def __init__(self, *, mlp: List[int], bn: bool = True):
        """
        :param mlp: list of int, channel spec of the shared MLP applied after interpolation
        :param bn: whether to use batchnorm
        """
        super().__init__()
        self.mlp = pt_utils.SharedMLP(mlp, bn=bn)

    def forward(
            self, unknown: torch.Tensor, known: torch.Tensor, unknow_feats: torch.Tensor, known_feats: torch.Tensor
    ) -> torch.Tensor:
        """
        :param unknown: (B, n, 3) xyz positions of the points that receive features
        :param known: (B, m, 3) xyz positions of the points that carry features, or None
        :param unknow_feats: (B, C1, n) features already present on the target points, or None
        :param known_feats: (B, C2, m) features to be propagated
        :return:
            new_features: (B, mlp[-1], n) features of the target points
        """
        if known is None:
            # No source positions: expand known_feats along the point axis to n points
            # (expand requires that axis to be 1 or already n).
            interpolated_feats = known_feats.expand(*known_feats.size()[0:2], unknown.size(1))
        else:
            # Inverse-distance weighting over the three nearest known points.
            dist, idx = pointnet2_utils.three_nn(unknown, known)
            inv_dist = 1.0 / (dist + 1e-8)
            weight = inv_dist / torch.sum(inv_dist, dim=2, keepdim=True)
            interpolated_feats = pointnet2_utils.three_interpolate(known_feats, idx, weight)

        if unknow_feats is None:
            stacked = interpolated_feats
        else:
            stacked = torch.cat([interpolated_feats, unknow_feats], dim=1)  # (B, C2 + C1, n)

        # SharedMLP works on 4-D input, so add and then remove a trailing singleton axis.
        return self.mlp(stacked.unsqueeze(-1)).squeeze(-1)


if __name__ == "__main__":
    pass
@@ -0,0 +1,318 @@
1
+ import torch
2
+ from torch.autograd import Variable
3
+ from torch.autograd import Function
4
+ import torch.nn as nn
5
+ from typing import Tuple
6
+
7
+ import pointnet2_cuda as pointnet2
8
+
9
+
10
class FurthestPointSampling(Function):
    """Autograd wrapper around the furthest-point-sampling CUDA kernel (no gradient)."""

    @staticmethod
    def forward(ctx, xyz: torch.Tensor, npoint: int) -> torch.Tensor:
        """
        Iteratively select npoint points so that the selected set has the largest
        minimum pairwise distance.
        :param ctx: autograd context (unused)
        :param xyz: (B, N, 3) contiguous coordinates, where N > npoint
        :param npoint: int, number of points to select
        :return:
            output: (B, npoint) int32 tensor of selected point indices
        """
        assert xyz.is_contiguous()

        batch_size, num_points, _ = xyz.size()
        sampled_idx = torch.cuda.IntTensor(batch_size, npoint)
        # Scratch buffer handed to the kernel, seeded with a large value.
        min_dist = torch.cuda.FloatTensor(batch_size, num_points).fill_(1e10)

        pointnet2.furthest_point_sampling_wrapper(batch_size, num_points, npoint, xyz, min_dist, sampled_idx)
        return sampled_idx

    @staticmethod
    def backward(ctx, grad_out=None):
        # Index selection has no gradient w.r.t. xyz or npoint.
        return None, None


furthest_point_sample = FurthestPointSampling.apply
37
+
38
+
39
class GatherOperation(Function):
    """Select, per batch, the feature columns listed in an index tensor (CUDA kernel)."""

    @staticmethod
    def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
        """
        :param ctx: autograd context; stores idx and the feature shape for backward
        :param features: (B, C, N) contiguous features
        :param idx: (B, npoint) contiguous index tensor of the features to gather
        :return:
            output: (B, C, npoint)
        """
        assert features.is_contiguous()
        assert idx.is_contiguous()

        batch_size, npoint = idx.size()
        _, channels, num_points = features.size()
        gathered = torch.cuda.FloatTensor(batch_size, channels, npoint)

        pointnet2.gather_points_wrapper(batch_size, channels, num_points, npoint, features, idx, gathered)

        # Backward only needs the indices and the input shape.
        ctx.for_backwards = (idx, channels, num_points)
        return gathered

    @staticmethod
    def backward(ctx, grad_out):
        idx, channels, num_points = ctx.for_backwards
        batch_size, npoint = idx.size()

        grad_features = Variable(torch.cuda.FloatTensor(batch_size, channels, num_points).zero_())
        pointnet2.gather_points_grad_wrapper(
            batch_size, channels, num_points, npoint, grad_out.data.contiguous(), idx, grad_features.data
        )
        # No gradient for the integer indices.
        return grad_features, None


gather_operation = GatherOperation.apply
74
+
75
class KNN(Function):
    """k-nearest-neighbour search of ``unknown`` points in ``known`` (CUDA kernel, no gradient)."""

    @staticmethod
    def forward(ctx, k: int, unknown: torch.Tensor, known: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Find the k nearest neighbors of every point of unknown among the points of known.
        :param ctx: autograd context (unused)
        :param k: int, number of neighbors returned per query point
        :param unknown: (B, N, 3) contiguous query coordinates
        :param known: (B, M, 3) contiguous reference coordinates
        :return:
            dist: (B, N, k) l2 distance to the k nearest neighbors
            idx: (B, N, k) int32 indices (into the M axis) of the k nearest neighbors
        """
        assert unknown.is_contiguous()
        assert known.is_contiguous()

        B, N, _ = unknown.size()
        m = known.size(1)
        dist2 = torch.cuda.FloatTensor(B, N, k)
        idx = torch.cuda.IntTensor(B, N, k)

        pointnet2.knn_wrapper(B, N, m, k, unknown, known, dist2, idx)
        # dist2 presumably holds squared distances (hence the sqrt) — kernel not visible here.
        return torch.sqrt(dist2), idx

    @staticmethod
    def backward(ctx, grad_dist=None, grad_idx=None):
        # One None per forward input: k, unknown, known.
        return None, None, None


knn = KNN.apply
103
+
104
class ThreeNN(Function):
    """Three-nearest-neighbour search of ``unknown`` points in ``known`` (CUDA kernel, no gradient)."""

    @staticmethod
    def forward(ctx, unknown: torch.Tensor, known: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        For every query point in unknown, find its three closest points in known.
        :param ctx: autograd context (unused)
        :param unknown: (B, N, 3) contiguous query coordinates
        :param known: (B, M, 3) contiguous reference coordinates
        :return:
            dist: (B, N, 3) l2 distance to the three nearest neighbors
            idx: (B, N, 3) int32 indices (into the M axis) of the three nearest neighbors
        """
        assert unknown.is_contiguous()
        assert known.is_contiguous()

        batch_size, num_query, _ = unknown.size()
        num_ref = known.size(1)
        sq_dist = torch.cuda.FloatTensor(batch_size, num_query, 3)
        neighbour_idx = torch.cuda.IntTensor(batch_size, num_query, 3)

        pointnet2.three_nn_wrapper(batch_size, num_query, num_ref, unknown, known, sq_dist, neighbour_idx)
        # The buffer presumably holds squared distances, hence the sqrt.
        return torch.sqrt(sq_dist), neighbour_idx

    @staticmethod
    def backward(ctx, grad_dist=None, grad_idx=None):
        # Neither input receives a gradient.
        return None, None


three_nn = ThreeNN.apply
134
+
135
+
136
class ThreeInterpolate(Function):
    """Weighted interpolation of features from three neighbours per target point (CUDA kernel)."""

    @staticmethod
    def forward(ctx, features: torch.Tensor, idx: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        """
        :param ctx: autograd context; stores idx, weight and M for backward
        :param features: (B, C, M) contiguous features to interpolate from
        :param idx: (B, n, 3) contiguous indices (into the M axis) of the three neighbours
        :param weight: (B, n, 3) contiguous interpolation weights
        :return:
            output: (B, C, n) interpolated features
        """
        assert features.is_contiguous()
        assert idx.is_contiguous()
        assert weight.is_contiguous()

        batch_size, channels, num_known = features.size()
        num_unknown = idx.size(1)
        ctx.three_interpolate_for_backward = (idx, weight, num_known)
        interpolated = torch.cuda.FloatTensor(batch_size, channels, num_unknown)

        pointnet2.three_interpolate_wrapper(
            batch_size, channels, num_known, num_unknown, features, idx, weight, interpolated
        )
        return interpolated

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        :param ctx: autograd context filled by forward
        :param grad_out: (B, C, n) gradient of the output
        :return:
            grad_features: (B, C, M) gradient of features
            None: no gradient for idx
            None: no gradient for weight
        """
        idx, weight, num_known = ctx.three_interpolate_for_backward
        batch_size, channels, num_unknown = grad_out.size()

        grad_features = Variable(torch.cuda.FloatTensor(batch_size, channels, num_known).zero_())
        pointnet2.three_interpolate_grad_wrapper(
            batch_size, channels, num_unknown, num_known, grad_out.data.contiguous(), idx, weight, grad_features.data
        )
        return grad_features, None, None


three_interpolate = ThreeInterpolate.apply
182
+
183
+
184
class GroupingOperation(Function):
    """Gather feature columns into (npoint, nsample) neighbourhoods by index (CUDA kernel)."""

    @staticmethod
    def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
        """
        :param ctx: autograd context; stores the int32 idx and N for backward
        :param features: (B, C, N) contiguous features to group
        :param idx: (B, npoint, nsample) contiguous indices of the features to group
        :return:
            output: (B, C, npoint, nsample) tensor
        """
        assert features.is_contiguous()
        assert idx.is_contiguous()
        idx = idx.int()  # the kernel wrapper is fed int32 indices

        batch_size, npoint, nsample = idx.size()
        _, channels, num_points = features.size()
        grouped = torch.cuda.FloatTensor(batch_size, channels, npoint, nsample)

        pointnet2.group_points_wrapper(batch_size, channels, num_points, npoint, nsample, features, idx, grouped)

        ctx.for_backwards = (idx, num_points)
        return grouped

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        :param ctx: autograd context filled by forward
        :param grad_out: (B, C, npoint, nsample) gradient of the output
        :return:
            grad_features: (B, C, N) gradient of the features
            None: no gradient for idx
        """
        idx, num_points = ctx.for_backwards
        batch_size, channels, npoint, nsample = grad_out.size()

        grad_features = Variable(torch.cuda.FloatTensor(batch_size, channels, num_points).zero_())
        pointnet2.group_points_grad_wrapper(
            batch_size, channels, num_points, npoint, nsample, grad_out.data.contiguous(), idx, grad_features.data
        )
        return grad_features, None


grouping_operation = GroupingOperation.apply
226
+
227
+
228
class BallQuery(Function):
    """Ball query: indices of up to nsample points within radius of each centre (CUDA kernel, no gradient)."""

    @staticmethod
    def forward(ctx, radius: float, nsample: int, xyz: torch.Tensor, new_xyz: torch.Tensor) -> torch.Tensor:
        """
        :param ctx: autograd context (unused)
        :param radius: float, radius of the balls
        :param nsample: int, maximum number of features in the balls
        :param xyz: (B, N, 3) contiguous xyz coordinates of the features
        :param new_xyz: (B, npoint, 3) contiguous centers of the ball query
        :return:
            idx: (B, npoint, nsample) int32 indices of the features that form the query balls
        """
        assert new_xyz.is_contiguous()
        assert xyz.is_contiguous()

        batch_size, num_points, _ = xyz.size()
        num_centers = new_xyz.size(1)
        # Zero-initialised before the kernel writes into it.
        ball_idx = torch.cuda.IntTensor(batch_size, num_centers, nsample).zero_()

        pointnet2.ball_query_wrapper(batch_size, num_points, num_centers, radius, nsample, new_xyz, xyz, ball_idx)
        return ball_idx

    @staticmethod
    def backward(ctx, grad_idx=None):
        # One None per forward input: radius, nsample, xyz, new_xyz.
        return None, None, None, None


ball_query = BallQuery.apply
257
+
258
+
259
class QueryAndGroup(nn.Module):
    """Group neighbours found by a ball query around each centroid."""

    def __init__(self, radius: float, nsample: int, use_xyz: bool = True):
        """
        :param radius: float, radius of ball
        :param nsample: int, maximum number of features to gather in the ball
        :param use_xyz: whether the centroid-relative xyz is concatenated in front of the features
        """
        super().__init__()
        self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz

    def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: torch.Tensor = None) -> torch.Tensor:
        """
        :param xyz: (B, N, 3) xyz coordinates of the features
        :param new_xyz: (B, npoint, 3) centroids
        :param features: (B, C, N) descriptors of the features, or None
        :return:
            new_features: a single tensor of shape
                (B, 3 + C, npoint, nsample) when features are given and use_xyz is set,
                (B, C, npoint, nsample) when features are given and use_xyz is not set,
                (B, 3, npoint, nsample) when features is None
        """
        idx = ball_query(self.radius, self.nsample, xyz, new_xyz)
        xyz_trans = xyz.transpose(1, 2).contiguous()
        grouped_xyz = grouping_operation(xyz_trans, idx)  # (B, 3, npoint, nsample)
        # Express neighbour coordinates relative to their centroid.
        grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1)

        if features is None:
            assert self.use_xyz, "Cannot have not features and not use xyz as a feature!"
            return grouped_xyz

        grouped_features = grouping_operation(features, idx)  # (B, C, npoint, nsample)
        if not self.use_xyz:
            return grouped_features
        return torch.cat([grouped_xyz, grouped_features], dim=1)  # (B, C + 3, npoint, nsample)
293
+
294
+
295
class GroupAll(nn.Module):
    """Treat the whole point cloud as a single group (no sampling, no ball query)."""

    def __init__(self, use_xyz: bool = True):
        super().__init__()
        self.use_xyz = use_xyz

    def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: torch.Tensor = None):
        """
        :param xyz: (B, N, 3) xyz coordinates of the features
        :param new_xyz: ignored
        :param features: (B, C, N) descriptors of the features, or None
        :return:
            new_features: (B, C + 3, 1, N) with use_xyz, (B, C, 1, N) without,
                or (B, 3, 1, N) when features is None
        """
        grouped_xyz = xyz.transpose(1, 2).unsqueeze(2)  # (B, 3, 1, N)
        if features is None:
            # Without descriptors the coordinates are returned regardless of use_xyz.
            return grouped_xyz

        grouped_features = features.unsqueeze(2)  # (B, C, 1, N)
        if not self.use_xyz:
            return grouped_features
        return torch.cat([grouped_xyz, grouped_features], dim=1)  # (B, 3 + C, 1, N)
@@ -0,0 +1,236 @@
1
+ import torch.nn as nn
2
+ from typing import List, Tuple
3
+
4
+
5
class SharedMLP(nn.Sequential):
    """Stack of Conv2d units (1x1 kernels by default) mapping args[0] -> args[1] -> ... -> args[-1] channels.

    Layers are registered as ``name + 'layer<i>'``. When both ``first`` and
    ``preact`` are set, layer 0 gets neither batch norm nor activation.
    """

    def __init__(
            self,
            args: List[int],
            *,
            bn: bool = False,
            activation=nn.ReLU(inplace=True),
            preact: bool = False,
            first: bool = False,
            name: str = "",
            instance_norm: bool = False,
    ):
        super().__init__()

        channel_pairs = zip(args[:-1], args[1:])
        for layer_no, (c_in, c_out) in enumerate(channel_pairs):
            # Only the very first layer of a pre-activated "first" block is left bare.
            bare = first and preact and layer_no == 0
            unit = Conv2d(
                c_in,
                c_out,
                bn=False if bare else bn,
                activation=None if bare else activation,
                preact=preact,
                instance_norm=instance_norm
            )
            self.add_module(name + 'layer%d' % layer_no, unit)
33
+
34
+
35
+ class _ConvBase(nn.Sequential):
36
+
37
+ def __init__(
38
+ self,
39
+ in_size,
40
+ out_size,
41
+ kernel_size,
42
+ stride,
43
+ padding,
44
+ activation,
45
+ bn,
46
+ init,
47
+ conv=None,
48
+ batch_norm=None,
49
+ bias=True,
50
+ preact=False,
51
+ name="",
52
+ instance_norm=False,
53
+ instance_norm_func=None
54
+ ):
55
+ super().__init__()
56
+
57
+ bias = bias and (not bn)
58
+ conv_unit = conv(
59
+ in_size,
60
+ out_size,
61
+ kernel_size=kernel_size,
62
+ stride=stride,
63
+ padding=padding,
64
+ bias=bias
65
+ )
66
+ init(conv_unit.weight)
67
+ if bias:
68
+ nn.init.constant_(conv_unit.bias, 0)
69
+
70
+ if bn:
71
+ if not preact:
72
+ bn_unit = batch_norm(out_size)
73
+ else:
74
+ bn_unit = batch_norm(in_size)
75
+ if instance_norm:
76
+ if not preact:
77
+ in_unit = instance_norm_func(out_size, affine=False, track_running_stats=False)
78
+ else:
79
+ in_unit = instance_norm_func(in_size, affine=False, track_running_stats=False)
80
+
81
+ if preact:
82
+ if bn:
83
+ self.add_module(name + 'bn', bn_unit)
84
+
85
+ if activation is not None:
86
+ self.add_module(name + 'activation', activation)
87
+
88
+ if not bn and instance_norm:
89
+ self.add_module(name + 'in', in_unit)
90
+
91
+ self.add_module(name + 'conv', conv_unit)
92
+
93
+ if not preact:
94
+ if bn:
95
+ self.add_module(name + 'bn', bn_unit)
96
+
97
+ if activation is not None:
98
+ self.add_module(name + 'activation', activation)
99
+
100
+ if not bn and instance_norm:
101
+ self.add_module(name + 'in', in_unit)
102
+
103
+
104
+ class _BNBase(nn.Sequential):
105
+
106
+ def __init__(self, in_size, batch_norm=None, name=""):
107
+ super().__init__()
108
+ self.add_module(name + "bn", batch_norm(in_size))
109
+
110
+ nn.init.constant_(self[0].weight, 1.0)
111
+ nn.init.constant_(self[0].bias, 0)
112
+
113
+
114
class BatchNorm1d(_BNBase):
    """nn.BatchNorm1d wrapped by _BNBase (identity-initialised affine parameters)."""

    def __init__(self, in_size: int, *, name: str = ""):
        super().__init__(in_size=in_size, name=name, batch_norm=nn.BatchNorm1d)
118
+
119
+
120
class BatchNorm2d(_BNBase):
    """nn.BatchNorm2d wrapped by _BNBase (identity-initialised affine parameters)."""

    def __init__(self, in_size: int, name: str = ""):
        super().__init__(in_size=in_size, name=name, batch_norm=nn.BatchNorm2d)
124
+
125
+
126
class Conv1d(_ConvBase):
    """_ConvBase specialised to nn.Conv1d, the BatchNorm1d wrapper and nn.InstanceNorm1d."""

    def __init__(
            self,
            in_size: int,
            out_size: int,
            *,
            kernel_size: int = 1,
            stride: int = 1,
            padding: int = 0,
            activation=nn.ReLU(inplace=True),
            bn: bool = False,
            init=nn.init.kaiming_normal_,
            bias: bool = True,
            preact: bool = False,
            name: str = "",
            instance_norm=False
    ):
        super().__init__(
            in_size=in_size,
            out_size=out_size,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            activation=activation,
            bn=bn,
            init=init,
            conv=nn.Conv1d,
            batch_norm=BatchNorm1d,
            bias=bias,
            preact=preact,
            name=name,
            instance_norm=instance_norm,
            instance_norm_func=nn.InstanceNorm1d,
        )
161
+
162
+
163
class Conv2d(_ConvBase):
    """_ConvBase specialised to nn.Conv2d, the BatchNorm2d wrapper and nn.InstanceNorm2d."""

    def __init__(
            self,
            in_size: int,
            out_size: int,
            *,
            kernel_size: Tuple[int, int] = (1, 1),
            stride: Tuple[int, int] = (1, 1),
            padding: Tuple[int, int] = (0, 0),
            activation=nn.ReLU(inplace=True),
            bn: bool = False,
            init=nn.init.kaiming_normal_,
            bias: bool = True,
            preact: bool = False,
            name: str = "",
            instance_norm=False
    ):
        super().__init__(
            in_size=in_size,
            out_size=out_size,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            activation=activation,
            bn=bn,
            init=init,
            conv=nn.Conv2d,
            batch_norm=BatchNorm2d,
            bias=bias,
            preact=preact,
            name=name,
            instance_norm=instance_norm,
            instance_norm_func=nn.InstanceNorm2d,
        )
198
+
199
+
200
class FC(nn.Sequential):
    """Linear layer with optional batch norm and activation, as a Sequential.

    Registration order (each entry only if enabled):
      * preact=False: fc, bn (on out_size), activation
      * preact=True:  bn (on in_size), activation, fc
    The linear bias is omitted when bn is set, and zero-initialised otherwise.
    """

    def __init__(
            self,
            in_size: int,
            out_size: int,
            *,
            activation=nn.ReLU(inplace=True),
            bn: bool = False,
            init=None,
            preact: bool = False,
            name: str = ""
    ):
        super().__init__()

        fc = nn.Linear(in_size, out_size, bias=not bn)
        if init is not None:
            init(fc.weight)
        if not bn:
            # nn.init.constant (no underscore) is deprecated and warns on every call.
            nn.init.constant_(fc.bias, 0)

        if preact:
            if bn:
                self.add_module(name + 'bn', BatchNorm1d(in_size))

            if activation is not None:
                self.add_module(name + 'activation', activation)

        self.add_module(name + 'fc', fc)

        if not preact:
            if bn:
                self.add_module(name + 'bn', BatchNorm1d(out_size))

            if activation is not None:
                self.add_module(name + 'activation', activation)
236
+