learning3d 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115)
  1. learning3d/__init__.py +2 -0
  2. learning3d/data_utils/__init__.py +4 -0
  3. learning3d/data_utils/dataloaders.py +454 -0
  4. learning3d/data_utils/user_data.py +119 -0
  5. learning3d/examples/test_dcp.py +139 -0
  6. learning3d/examples/test_deepgmr.py +144 -0
  7. learning3d/examples/test_flownet.py +113 -0
  8. learning3d/examples/test_masknet.py +159 -0
  9. learning3d/examples/test_masknet2.py +162 -0
  10. learning3d/examples/test_pcn.py +118 -0
  11. learning3d/examples/test_pcrnet.py +120 -0
  12. learning3d/examples/test_pnlk.py +121 -0
  13. learning3d/examples/test_pointconv.py +126 -0
  14. learning3d/examples/test_pointnet.py +121 -0
  15. learning3d/examples/test_prnet.py +126 -0
  16. learning3d/examples/test_rpmnet.py +120 -0
  17. learning3d/examples/train_PointNetLK.py +240 -0
  18. learning3d/examples/train_dcp.py +249 -0
  19. learning3d/examples/train_deepgmr.py +244 -0
  20. learning3d/examples/train_flownet.py +259 -0
  21. learning3d/examples/train_masknet.py +239 -0
  22. learning3d/examples/train_pcn.py +216 -0
  23. learning3d/examples/train_pcrnet.py +228 -0
  24. learning3d/examples/train_pointconv.py +245 -0
  25. learning3d/examples/train_pointnet.py +244 -0
  26. learning3d/examples/train_prnet.py +229 -0
  27. learning3d/examples/train_rpmnet.py +228 -0
  28. learning3d/losses/__init__.py +12 -0
  29. learning3d/losses/chamfer_distance.py +51 -0
  30. learning3d/losses/classification.py +14 -0
  31. learning3d/losses/correspondence_loss.py +10 -0
  32. learning3d/losses/cuda/chamfer_distance/__init__.py +1 -0
  33. learning3d/losses/cuda/chamfer_distance/chamfer_distance.cpp +185 -0
  34. learning3d/losses/cuda/chamfer_distance/chamfer_distance.cu +209 -0
  35. learning3d/losses/cuda/chamfer_distance/chamfer_distance.py +66 -0
  36. learning3d/losses/cuda/emd_torch/pkg/emd_loss_layer.py +41 -0
  37. learning3d/losses/cuda/emd_torch/pkg/include/cuda/emd.cuh +347 -0
  38. learning3d/losses/cuda/emd_torch/pkg/include/cuda_helper.h +18 -0
  39. learning3d/losses/cuda/emd_torch/pkg/include/emd.h +54 -0
  40. learning3d/losses/cuda/emd_torch/pkg/layer/__init__.py +1 -0
  41. learning3d/losses/cuda/emd_torch/pkg/layer/emd_loss_layer.py +40 -0
  42. learning3d/losses/cuda/emd_torch/pkg/src/cuda/emd.cu +70 -0
  43. learning3d/losses/cuda/emd_torch/pkg/src/emd.cpp +1 -0
  44. learning3d/losses/cuda/emd_torch/setup.py +29 -0
  45. learning3d/losses/emd.py +16 -0
  46. learning3d/losses/frobenius_norm.py +21 -0
  47. learning3d/losses/rmse_features.py +16 -0
  48. learning3d/models/__init__.py +23 -0
  49. learning3d/models/classifier.py +41 -0
  50. learning3d/models/dcp.py +92 -0
  51. learning3d/models/deepgmr.py +165 -0
  52. learning3d/models/dgcnn.py +92 -0
  53. learning3d/models/flownet3d.py +446 -0
  54. learning3d/models/masknet.py +84 -0
  55. learning3d/models/masknet2.py +264 -0
  56. learning3d/models/pcn.py +164 -0
  57. learning3d/models/pcrnet.py +74 -0
  58. learning3d/models/pointconv.py +108 -0
  59. learning3d/models/pointnet.py +108 -0
  60. learning3d/models/pointnetlk.py +173 -0
  61. learning3d/models/pooling.py +15 -0
  62. learning3d/models/ppfnet.py +102 -0
  63. learning3d/models/prnet.py +431 -0
  64. learning3d/models/rpmnet.py +359 -0
  65. learning3d/models/segmentation.py +38 -0
  66. learning3d/ops/__init__.py +0 -0
  67. learning3d/ops/data_utils.py +45 -0
  68. learning3d/ops/invmat.py +134 -0
  69. learning3d/ops/quaternion.py +218 -0
  70. learning3d/ops/se3.py +157 -0
  71. learning3d/ops/sinc.py +229 -0
  72. learning3d/ops/so3.py +213 -0
  73. learning3d/ops/transform_functions.py +342 -0
  74. learning3d/utils/__init__.py +9 -0
  75. learning3d/utils/lib/build/lib.linux-x86_64-3.5/pointnet2_cuda.cpython-35m-x86_64-linux-gnu.so +0 -0
  76. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/ball_query.o +0 -0
  77. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/ball_query_gpu.o +0 -0
  78. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/group_points.o +0 -0
  79. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/group_points_gpu.o +0 -0
  80. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/interpolate.o +0 -0
  81. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/interpolate_gpu.o +0 -0
  82. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/pointnet2_api.o +0 -0
  83. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/sampling.o +0 -0
  84. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/sampling_gpu.o +0 -0
  85. learning3d/utils/lib/dist/pointnet2-0.0.0-py3.5-linux-x86_64.egg +0 -0
  86. learning3d/utils/lib/pointnet2.egg-info/SOURCES.txt +14 -0
  87. learning3d/utils/lib/pointnet2.egg-info/dependency_links.txt +1 -0
  88. learning3d/utils/lib/pointnet2.egg-info/top_level.txt +1 -0
  89. learning3d/utils/lib/pointnet2_modules.py +160 -0
  90. learning3d/utils/lib/pointnet2_utils.py +318 -0
  91. learning3d/utils/lib/pytorch_utils.py +236 -0
  92. learning3d/utils/lib/setup.py +23 -0
  93. learning3d/utils/lib/src/ball_query.cpp +25 -0
  94. learning3d/utils/lib/src/ball_query_gpu.cu +67 -0
  95. learning3d/utils/lib/src/ball_query_gpu.h +15 -0
  96. learning3d/utils/lib/src/cuda_utils.h +15 -0
  97. learning3d/utils/lib/src/group_points.cpp +36 -0
  98. learning3d/utils/lib/src/group_points_gpu.cu +86 -0
  99. learning3d/utils/lib/src/group_points_gpu.h +22 -0
  100. learning3d/utils/lib/src/interpolate.cpp +65 -0
  101. learning3d/utils/lib/src/interpolate_gpu.cu +233 -0
  102. learning3d/utils/lib/src/interpolate_gpu.h +36 -0
  103. learning3d/utils/lib/src/pointnet2_api.cpp +25 -0
  104. learning3d/utils/lib/src/sampling.cpp +46 -0
  105. learning3d/utils/lib/src/sampling_gpu.cu +253 -0
  106. learning3d/utils/lib/src/sampling_gpu.h +29 -0
  107. learning3d/utils/pointconv_util.py +382 -0
  108. learning3d/utils/ppfnet_util.py +244 -0
  109. learning3d/utils/svd.py +59 -0
  110. learning3d/utils/transformer.py +243 -0
  111. learning3d-0.0.1.dist-info/LICENSE +21 -0
  112. learning3d-0.0.1.dist-info/METADATA +271 -0
  113. learning3d-0.0.1.dist-info/RECORD +115 -0
  114. learning3d-0.0.1.dist-info/WHEEL +5 -0
  115. learning3d-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,382 @@
1
+ """
2
+ Utility functions for PointConv
3
+ Originally from: https://github.com/yanx27/Pointnet_Pointnet2_pytorch/blob/master/utils.py
4
+ Modified by Wenxuan Wu
5
+ Date: September 2019
6
+ """
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from time import time
11
+ import numpy as np
12
+ from sklearn.neighbors._kde import KernelDensity
13
+
14
def timeit(tag, t):
    """Print the seconds elapsed since ``t`` under the label ``tag``.

    Returns a fresh ``time()`` reading so consecutive stages can be timed by
    feeding the result into the next call.
    """
    elapsed = time() - t
    print("{}: {}s".format(tag, elapsed))
    return time()
17
+
18
def square_distance(src, dst):
    """Pairwise squared Euclidean distance between two batched point sets.

    Uses the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, so the whole
    batch is computed with one batched matmul.

    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-point square distance, [B, N, M]
    """
    batch, n_src, _ = src.shape
    _, n_dst, _ = dst.shape
    cross = torch.matmul(src, dst.permute(0, 2, 1))
    src_sq = torch.sum(src ** 2, -1).view(batch, n_src, 1)
    dst_sq = torch.sum(dst ** 2, -1).view(batch, 1, n_dst)
    return -2 * cross + src_sq + dst_sq
40
+
41
def index_points(points, idx):
    """Gather points per batch element using an index tensor.

    Input:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S] (extra trailing index dims such as
            [B, S, K] are allowed)
    Return:
        new_points: indexed points data, [B, S, C] (or [B, S, K, C])
    """
    batch = points.shape[0]
    # Batch index shaped [B, 1, ..., 1], tiled to match idx's trailing dims.
    lead_shape = [idx.shape[0]] + [1] * (idx.dim() - 1)
    tile_shape = [1] + list(idx.shape[1:])
    batch_indices = torch.arange(batch, dtype=torch.long).to(points.device)
    batch_indices = batch_indices.view(lead_shape).repeat(tile_shape)
    return points[batch_indices, idx, :]
59
+
60
def farthest_point_sample(xyz, npoint):
    """Iterative farthest point sampling with a deterministic start at index 0.

    Input:
        xyz: pointcloud data, [B, N, C]
        npoint: number of samples
    Return:
        centroids: sampled pointcloud index, [B, npoint]
    """
    device = xyz.device
    B, N, C = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
    # Squared distance from each point to its nearest already-selected centroid.
    distance = torch.ones(B, N).to(device) * 1e10
    # Always start from point 0 so the sampling is deterministic.
    farthest = torch.zeros(B, dtype=torch.long).to(device)
    batch_indices = torch.arange(B, dtype=torch.long).to(device)
    for i in range(npoint):
        centroids[:, i] = farthest
        # Reshape with C (previously a hard-coded 3) so any coordinate
        # dimensionality works.
        centroid = xyz[batch_indices, farthest, :].view(B, 1, C)
        dist = torch.sum((xyz - centroid) ** 2, -1)
        mask = dist < distance
        distance[mask] = dist[mask]
        farthest = torch.max(distance, -1)[1]
    return centroids
84
+
85
def query_ball_point(radius, nsample, xyz, new_xyz):
    """Ball query: up to ``nsample`` neighbour indices per query point.

    Points farther than ``radius`` (squared distance > radius ** 2) are
    replaced by the out-of-range marker N. The remaining indices are sorted
    ascending and truncated to ``nsample``. Unfilled slots are padded with the
    first entry of that query's sorted list; if no point is inside the ball,
    that entry is itself N.

    Input:
        radius: local region radius
        nsample: max sample number in local region
        xyz: all points, [B, N, C]
        new_xyz: query points, [B, S, C]
    Return:
        group_idx: grouped points index, [B, S, nsample]
    """
    B, N, _ = xyz.shape
    _, S, _ = new_xyz.shape
    sentinel = N  # sorts after every valid index
    candidates = torch.arange(N, dtype=torch.long).to(xyz.device).view(1, 1, N).repeat([B, S, 1])
    outside = square_distance(new_xyz, xyz) > radius ** 2
    candidates[outside] = sentinel
    group_idx = candidates.sort(dim=-1)[0][:, :, :nsample]
    fill = group_idx[:, :, :1].repeat([1, 1, nsample])
    pad = group_idx == sentinel
    group_idx[pad] = fill[pad]
    return group_idx
106
+
107
def knn_point(nsample, xyz, new_xyz):
    """Indices of the ``nsample`` nearest points in ``xyz`` for every query.

    The neighbours come back in no particular order
    (``torch.topk(..., sorted=False)``).

    Input:
        nsample: max sample number in local region
        xyz: all points, [B, N, C]
        new_xyz: query points, [B, S, C]
    Return:
        group_idx: grouped points index, [B, S, nsample]
    """
    sqrdists = square_distance(new_xyz, xyz)
    return torch.topk(sqrdists, nsample, dim=-1, largest=False, sorted=False)[1]
119
+
120
def sample_and_group(npoint, nsample, xyz, points, density_scale = None):
    """FPS-sample ``npoint`` centroids and gather each one's kNN neighbourhood.

    Input:
        npoint: number of centroids chosen by farthest point sampling
        nsample: number of nearest neighbours gathered per centroid
        xyz: input points position data, [B, N, C]
        points: input points data, [B, N, D], or None
        density_scale: optional per-point values to group alongside,
            e.g. [B, N, 1]
    Return:
        new_xyz: sampled centroids, [B, npoint, C]
        new_points: grouped data, [B, npoint, nsample, C+D]
            ([B, npoint, nsample, C] when points is None)
        grouped_xyz_norm: neighbour offsets from their centroid,
            [B, npoint, nsample, C]
        idx: neighbour indices, [B, npoint, nsample]
        grouped_density: only returned when density_scale is given
    """
    B, _, C = xyz.shape
    centroid_idx = farthest_point_sample(xyz, npoint)  # [B, npoint]
    new_xyz = index_points(xyz, centroid_idx)  # [B, npoint, C]
    idx = knn_point(nsample, xyz, new_xyz)  # [B, npoint, nsample]
    grouped_xyz_norm = index_points(xyz, idx) - new_xyz.view(B, npoint, 1, C)
    if points is None:
        new_points = grouped_xyz_norm
    else:
        new_points = torch.cat([grouped_xyz_norm, index_points(points, idx)], dim=-1)

    if density_scale is None:
        return new_xyz, new_points, grouped_xyz_norm, idx
    return new_xyz, new_points, grouped_xyz_norm, idx, index_points(density_scale, idx)
149
+
150
def sample_and_group_all(xyz, points, density_scale = None):
    """Group the whole cloud into a single region centred on its mean.

    Input:
        xyz: input points position data, [B, N, C]
        points: input points data, [B, N, D], or None
        density_scale: optional per-point values (B*N elements)
    Return:
        new_xyz: region centre (mean of xyz), [B, 1, C]
        new_points: [B, 1, N, C+D] ([B, 1, N, C] when points is None)
        grouped_xyz: xyz offsets from the centre, [B, 1, N, C]
        grouped_density: only when density_scale is given, [B, 1, N, 1]
    """
    B, N, C = xyz.shape
    new_xyz = xyz.mean(dim = 1, keepdim = True)
    grouped_xyz = xyz.view(B, 1, N, C) - new_xyz.view(B, 1, 1, C)
    if points is None:
        new_points = grouped_xyz
    else:
        new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)

    outputs = (new_xyz, new_points, grouped_xyz)
    if density_scale is not None:
        outputs += (density_scale.view(B, 1, N, 1),)
    return outputs
173
+
174
def group(nsample, xyz, points):
    """Gather the kNN neighbourhood of every point (no subsampling).

    Input:
        nsample: number of nearest neighbours per point
        xyz: input points position data, [B, N, C]
        points: input points data, [B, N, D], or None
    Return:
        new_points: [B, N, nsample, C+D] ([B, N, nsample, C] when points is None)
        grouped_xyz_norm: neighbour offsets from each point, [B, N, nsample, C]
    """
    B, N, C = xyz.shape
    idx = knn_point(nsample, xyz, xyz)
    grouped_xyz_norm = index_points(xyz, idx) - xyz.view(B, N, 1, C)
    if points is None:
        return grouped_xyz_norm, grouped_xyz_norm
    grouped_points = index_points(points, idx)
    return torch.cat([grouped_xyz_norm, grouped_points], dim=-1), grouped_xyz_norm
198
+
199
def compute_density(xyz, bandwidth):
    '''Gaussian-kernel density estimate at every point of the cloud.

    xyz: input points position data, [B, N, C]
    bandwidth: kernel width (sigma)
    Returns: per-point density, [B, N], i.e. the mean over all N points of
        exp(-d^2 / (2 * bandwidth^2)) / (2.5 * bandwidth).
    '''
    # 2.5 presumably approximates sqrt(2*pi) in the Gaussian normaliser.
    sqrdists = square_distance(xyz, xyz)
    kernel = torch.exp(- sqrdists / (2.0 * bandwidth * bandwidth)) / (2.5 * bandwidth)
    return kernel.mean(dim = -1)
210
+
211
class DensityNet(nn.Module):
    """1x1-conv MLP mapping a per-neighbour density map to a scale in (0, 1).

    Layer widths are 1 -> hidden_unit... -> 1. Hidden layers are
    conv -> batch-norm -> ReLU. The final layer is conv -> batch-norm ->
    sigmoid.

    Args:
        hidden_unit: widths of the hidden layers (any non-empty or empty
            sequence of ints).
    """

    def __init__(self, hidden_unit = (16, 8)):
        super(DensityNet, self).__init__()
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()

        channels = [1] + list(hidden_unit) + [1]
        for in_ch, out_ch in zip(channels[:-1], channels[1:]):
            self.mlp_convs.append(nn.Conv2d(in_ch, out_ch, 1))
            self.mlp_bns.append(nn.BatchNorm2d(out_ch))

    def forward(self, density_scale):
        last = len(self.mlp_convs) - 1
        for i, (conv, bn) in enumerate(zip(self.mlp_convs, self.mlp_bns)):
            density_scale = bn(conv(density_scale))
            # Previously compared against len(self.mlp_convs), which enumerate
            # never reaches, so the sigmoid branch was dead code.
            if i == last:
                density_scale = torch.sigmoid(density_scale)
            else:
                density_scale = F.relu(density_scale)

        return density_scale
235
+
236
class WeightNet(nn.Module):
    """1x1-conv MLP turning local coordinates into per-neighbour weights.

    Every layer, including the last, is conv -> batch-norm -> ReLU. With no
    hidden layers (None or empty), a single in_channel -> out_channel layer is
    built.
    """

    def __init__(self, in_channel, out_channel, hidden_unit = [8, 8]):
        super(WeightNet, self).__init__()

        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        if hidden_unit is None or len(hidden_unit) == 0:
            widths = [in_channel, out_channel]
        else:
            widths = [in_channel] + list(hidden_unit) + [out_channel]
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            self.mlp_convs.append(nn.Conv2d(c_in, c_out, 1))
            self.mlp_bns.append(nn.BatchNorm2d(c_out))

    def forward(self, localized_xyz):
        # localized_xyz: B x C x K x N (channels first)
        weights = localized_xyz
        for conv, bn in zip(self.mlp_convs, self.mlp_bns):
            weights = F.relu(bn(conv(weights)))

        return weights
264
+
265
class PointConvSetAbstraction(nn.Module):
    """PointConv set-abstraction layer (no density re-weighting).

    Args:
        npoint: number of FPS centroids (not used when group_all is True)
        nsample: number of kNN neighbours per centroid
        in_channel: channels of the grouped input (C + D)
        mlp: output widths of the shared 1x1-conv MLP
        group_all: group the whole cloud into one region around its mean
    """

    def __init__(self, npoint, nsample, in_channel, mlp, group_all):
        super(PointConvSetAbstraction, self).__init__()
        self.npoint = npoint
        self.nsample = nsample
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = in_channel
        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm2d(out_channel))
            last_channel = out_channel

        self.weightnet = WeightNet(3, 16)
        self.linear = nn.Linear(16 * mlp[-1], mlp[-1])
        self.bn_linear = nn.BatchNorm1d(mlp[-1])
        self.group_all = group_all

    def forward(self, xyz, points):
        """
        Input:
            xyz: input points position data, [B, C, N]
            points: input points data, [B, D, N], or None
        Return:
            new_xyz: sampled points position data, [B, C, S]
            new_points_concat: sample points feature data, [B, D', S]
        """
        B = xyz.shape[0]
        xyz = xyz.permute(0, 2, 1)
        if points is not None:
            points = points.permute(0, 2, 1)

        if self.group_all:
            new_xyz, new_points, grouped_xyz_norm = sample_and_group_all(xyz, points)
        else:
            new_xyz, new_points, grouped_xyz_norm, _ = sample_and_group(self.npoint, self.nsample, xyz, points)
        # Take the region count from the grouped output (1 for group_all,
        # npoint otherwise) rather than trusting self.npoint, which need not
        # be 1 when group_all is set.
        S = new_xyz.shape[1]
        # new_points: [B, S, nsample, C+D] -> [B, C+D, nsample, S]
        new_points = new_points.permute(0, 3, 2, 1)
        for conv, bn in zip(self.mlp_convs, self.mlp_bns):
            new_points = F.relu(bn(conv(new_points)))

        grouped_xyz = grouped_xyz_norm.permute(0, 3, 2, 1)
        weights = self.weightnet(grouped_xyz)
        # Per region: [mlp[-1], nsample] @ [nsample, 16] -> flattened to 16 * mlp[-1].
        new_points = torch.matmul(input=new_points.permute(0, 3, 1, 2), other = weights.permute(0, 3, 2, 1)).view(B, S, -1)
        new_points = self.linear(new_points)
        new_points = self.bn_linear(new_points.permute(0, 2, 1))
        new_points = F.relu(new_points)
        new_xyz = new_xyz.permute(0, 2, 1)

        return new_xyz, new_points
317
+
318
class PointConvDensitySetAbstraction(nn.Module):
    """PointConv set-abstraction layer with inverse-density re-weighting.

    Args:
        npoint: number of FPS centroids (not used when group_all is True)
        nsample: number of kNN neighbours per centroid
        in_channel: channels of the grouped input (C + D)
        mlp: output widths of the shared 1x1-conv MLP
        bandwidth: Gaussian kernel width passed to compute_density
        group_all: group the whole cloud into one region around its mean
    """

    def __init__(self, npoint, nsample, in_channel, mlp, bandwidth, group_all):
        super(PointConvDensitySetAbstraction, self).__init__()
        self.npoint = npoint
        self.nsample = nsample
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = in_channel
        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm2d(out_channel))
            last_channel = out_channel

        self.weightnet = WeightNet(3, 16)
        self.linear = nn.Linear(16 * mlp[-1], mlp[-1])
        self.bn_linear = nn.BatchNorm1d(mlp[-1])
        self.densitynet = DensityNet()
        self.group_all = group_all
        self.bandwidth = bandwidth

    def forward(self, xyz, points):
        """
        Input:
            xyz: input points position data, [B, C, N]
            points: input points data, [B, D, N], or None
        Return:
            new_xyz: sampled points position data, [B, C, S]
            new_points_concat: sample points feature data, [B, D', S]
        """
        B = xyz.shape[0]
        N = xyz.shape[2]
        xyz = xyz.permute(0, 2, 1)
        if points is not None:
            points = points.permute(0, 2, 1)

        # Inverse kernel density: lower local density -> larger value.
        xyz_density = compute_density(xyz, self.bandwidth)
        inverse_density = (1.0 / xyz_density).view(B, N, 1)

        if self.group_all:
            new_xyz, new_points, grouped_xyz_norm, grouped_density = sample_and_group_all(xyz, points, inverse_density)
        else:
            new_xyz, new_points, grouped_xyz_norm, _, grouped_density = sample_and_group(self.npoint, self.nsample, xyz, points, inverse_density)
        # Take the region count from the grouped output (1 for group_all,
        # npoint otherwise) rather than trusting self.npoint.
        S = new_xyz.shape[1]
        # new_points: [B, S, nsample, C+D] -> [B, C+D, nsample, S]
        new_points = new_points.permute(0, 3, 2, 1)
        for conv, bn in zip(self.mlp_convs, self.mlp_bns):
            new_points = F.relu(bn(conv(new_points)))

        # Normalise by the largest inverse density within each region, then
        # let DensityNet learn the final per-neighbour scale.
        inverse_max_density = grouped_density.max(dim = 2, keepdim=True)[0]
        density_scale = grouped_density / inverse_max_density
        density_scale = self.densitynet(density_scale.permute(0, 3, 2, 1))
        new_points = new_points * density_scale

        grouped_xyz = grouped_xyz_norm.permute(0, 3, 2, 1)
        weights = self.weightnet(grouped_xyz)
        # Per region: [mlp[-1], nsample] @ [nsample, 16] -> flattened to 16 * mlp[-1].
        new_points = torch.matmul(input=new_points.permute(0, 3, 1, 2), other = weights.permute(0, 3, 2, 1)).view(B, S, -1)
        new_points = self.linear(new_points)
        new_points = self.bn_linear(new_points.permute(0, 2, 1))
        new_points = F.relu(new_points)
        new_xyz = new_xyz.permute(0, 2, 1)

        return new_xyz, new_points
381
+
382
+
@@ -0,0 +1,244 @@
1
+ """Utilities for PointNet related functions
2
+
3
+ Modified from:
4
+ Pytorch Implementation of PointNet and PointNet++
5
+ https://github.com/yanx27/Pointnet_Pointnet2_pytorch
6
+ """
7
+
8
+ import torch
9
+
10
+
11
def angle_difference(src, dst):
    """Calculate the angle between each pair of vectors.

    Assumes points are l2-normalized to unit length.

    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-pair angle in radians, [B, N, M]
    """
    cos_sim = torch.matmul(src, dst.permute(0, 2, 1))
    # Floating-point rounding can push the dot product of (near-)parallel
    # unit vectors slightly outside [-1, 1], where acos returns NaN.
    return torch.acos(torch.clamp(cos_sim, -1.0, 1.0))
27
+
28
+
29
def square_distance(src, dst):
    """Pairwise squared Euclidean distance between two batched point sets.

    Expands ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b so a single batched
    matmul covers all pairs.

    Args:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Returns:
        dist: per-point square distance, [B, N, M]
    """
    _, _, _ = src.shape
    _, _, _ = dst.shape
    inner = torch.matmul(src, dst.permute(0, 2, 1))
    src_norm = torch.sum(src ** 2, dim=-1).unsqueeze(-1)
    dst_norm = torch.sum(dst ** 2, dim=-1).unsqueeze(-2)
    return -2 * inner + src_norm + dst_norm
49
+
50
+
51
def index_points(points, idx):
    """Array indexing, i.e. retrieve points per batch element by index.

    Args:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S]; S may itself be multi-dimensional,
            e.g. [B, S, K]
    Returns:
        new_points: indexed points data, [B, S, C] (or [B, S, K, C])
    """
    batch_size = points.shape[0]
    lead = [idx.shape[0]] + [1] * (idx.dim() - 1)
    batch_indices = torch.arange(batch_size, dtype=torch.long).to(points.device).view(lead)
    batch_indices = batch_indices.repeat([1] + list(idx.shape[1:]))
    return points[batch_indices, idx, :]
69
+
70
+
71
def farthest_point_sample(xyz, npoint):
    """Iterative farthest point sampling from a random starting point.

    Args:
        xyz: pointcloud data, [B, N, C]
        npoint: number of samples
    Returns:
        centroids: sampled pointcloud index, [B, npoint]
    """
    device = xyz.device
    B, N, C = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
    # Squared distance from each point to its nearest already-selected centroid.
    distance = torch.ones(B, N).to(device) * 1e10
    # Random start per batch element (uses torch's global RNG).
    farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
    batch_indices = torch.arange(B, dtype=torch.long).to(device)
    for i in range(npoint):
        centroids[:, i] = farthest
        # Reshape with C (previously a hard-coded 3) so any coordinate
        # dimensionality works.
        centroid = xyz[batch_indices, farthest, :].view(B, 1, C)
        dist = torch.sum((xyz - centroid) ** 2, -1)
        mask = dist < distance
        distance[mask] = dist[mask]
        farthest = torch.max(distance, -1)[1]
    return centroids
94
+
95
+
96
def query_ball_point(radius, nsample, xyz, new_xyz, itself_indices=None):
    """Grouping layer in PointNet++ (ball query).

    Inputs:
        radius: local region radius
        nsample: max sample number in local region
        xyz: all points, (B, N, C)
        new_xyz: query points, (B, S, C)
        itself_indices (Optional): indices of new_xyz into xyz, (B, S). When
            given, each query's own index is excluded from its neighbourhood.
            Unused slots (too few neighbours, or none at all) are filled with
            the centre index, so a cluster may still contain the centre.
    Returns:
        group_idx: grouped points index, [B, S, nsample]
    """
    device = xyz.device
    B, N, _ = xyz.shape
    _, S, _ = new_xyz.shape
    candidates = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])  # (B, S, N)
    sqrdists = square_distance(new_xyz, xyz)

    has_self = itself_indices is not None
    if has_self:
        # Mark each query's own index with the out-of-range value N.
        rows_b = torch.arange(B, dtype=torch.long).to(device)[:, None].repeat(1, S)  # (B, S)
        rows_s = torch.arange(S, dtype=torch.long).to(device)[None, :].repeat(B, 1)  # (B, S)
        candidates[rows_b, rows_s, itself_indices] = N

    candidates[sqrdists > radius ** 2] = N
    group_idx = candidates.sort(dim=-1)[0][:, :, :nsample]
    # Fallback for slots still marked N: the centre itself when known,
    # otherwise the first entry of the sorted neighbour list.
    fallback = itself_indices[:, :, None] if has_self else group_idx[:, :, :1]
    fallback = fallback.repeat([1, 1, nsample])
    unused = group_idx == N
    group_idx[unused] = fallback[unused]
    return group_idx
132
+
133
+
134
def sample_and_group(npoint: int, radius: float, nsample: int, xyz: torch.Tensor, points: torch.Tensor,
                     returnfps: bool=False):
    """Sample centroids with FPS (or use every point) and ball-query group them.

    Args:
        npoint (int): number of centroids; set to a non-positive value to use
            all points as centroids
        radius: ball-query radius
        nsample: maximum number of neighbours per centroid
        xyz: input points position data, [B, N, C]
        points: input points data, [B, N, D], or None
        returnfps (bool): whether to also return grouped_xyz and the centroid
            indices
    Returns:
        new_xyz: centroid positions, [B, S, C]
        new_points: grouped data, [B, S, nsample, C+D]
            ([B, S, nsample, C] when points is None)
        grouped_xyz (only if returnfps): absolute neighbour positions,
            [B, S, nsample, C]
        fps_idx (only if returnfps): centroid indices into xyz, [B, S]
    """
    B, N, C = xyz.shape

    if npoint > 0:
        S = npoint
        fps_idx = farthest_point_sample(xyz, npoint)  # [B, npoint]
        new_xyz = index_points(xyz, fps_idx)
    else:
        S = N
        # Keep the indices on xyz's device (as sample_and_group_multi does)
        # so the returned fps_idx can be used directly with GPU tensors.
        fps_idx = torch.arange(0, N)[None, ...].repeat(B, 1).to(xyz.device)
        new_xyz = xyz

    idx = query_ball_point(radius, nsample, xyz, new_xyz)  # (B, S, nsample)
    grouped_xyz = index_points(xyz, idx)  # (B, S, nsample, C)
    grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)
    if points is not None:
        grouped_points = index_points(points, idx)
        new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1)  # [B, S, nsample, C+D]
    else:
        new_points = grouped_xyz_norm
    if returnfps:
        return new_xyz, new_points, grouped_xyz, fps_idx
    else:
        return new_xyz, new_points
171
+
172
+
173
def angle(v1: torch.Tensor, v2: torch.Tensor):
    """Angle between two (broadcastable) batches of 3-vectors, in radians.

    Uses atan2(|v1 x v2|, v1 . v2), the same formulation as PPFNet. This stays
    well-defined when either vector is zero, since torch.atan2(0.0, 0.0) = 0.0.

    Args:
        v1: (B, *, 3)
        v2: (B, *, 3)

    Returns:
        Angles in [0, pi], with the trailing dimension of size 3 removed.
    """
    x1, y1, z1 = v1[..., 0], v1[..., 1], v1[..., 2]
    x2, y2, z2 = v2[..., 0], v2[..., 1], v2[..., 2]
    cross = torch.stack([y1 * z2 - z1 * y2,
                         z1 * x2 - x1 * z2,
                         x1 * y2 - y1 * x2], dim=-1)
    sin_term = torch.norm(cross, dim=-1)
    cos_term = torch.sum(v1 * v2, dim=-1)
    return torch.atan2(sin_term, cos_term)
195
+
196
+
197
def sample_and_group_multi(npoint: int, radius: float, nsample: int, xyz: torch.Tensor, normals: torch.Tensor,
                           returnfps: bool = False):
    """Sample and group, returning xyz, relative-xyz and point-pair features.

    Args:
        npoint(int): number of clusters (equivalently, keypoints) to sample;
            set to a non-positive value to use every point
        radius(float): ball-query radius used to form each cluster
        nsample: maximum number of points per cluster
        xyz: XYZ coordinates of the points, [B, N, C]
        normals: normals for the points (needed for the ppf features), [B, N, 3]
        returnfps: whether to also return the grouped positions and centre indices

    Returns:
        Dictionary with keys 'xyz' (centres, [B, S, C]), 'dxyz' (neighbour minus
        centre, [B, S, nsample, C]) and 'ppf' ([B, S, nsample, 4]).
        If returnfps is True, also returns grouped_xyz and fps_idx.
    """
    B, N, C = xyz.shape

    if npoint > 0:
        S = npoint
        fps_idx = farthest_point_sample(xyz, npoint)  # [B, npoint]
        new_xyz = index_points(xyz, fps_idx)
        nr = index_points(normals, fps_idx)[:, :, None, :]
    else:
        S = N
        fps_idx = torch.arange(0, N)[None, ...].repeat(B, 1).to(xyz.device)
        new_xyz = xyz
        nr = normals[:, :, None, :]

    # Passing fps_idx keeps each centre out of its own neighbourhood where possible.
    idx = query_ball_point(radius, nsample, xyz, new_xyz, fps_idx)  # (B, S, nsample)
    grouped_xyz = index_points(xyz, idx)  # (B, S, nsample, C)
    d = grouped_xyz - new_xyz.view(B, S, 1, C)  # neighbour minus centre
    ni = index_points(normals, idx)

    # PPF: angle(n_r, d), angle(n_i, d), angle(n_r, n_i), |d|.
    ppf_feat = torch.stack([angle(nr, d), angle(ni, d), angle(nr, ni), torch.norm(d, dim=-1)], dim=-1)
    features = {'xyz': new_xyz, 'dxyz': d, 'ppf': ppf_feat}

    if returnfps:
        return features, grouped_xyz, fps_idx
    return features
@@ -0,0 +1,59 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import math
4
+
5
class SVDHead(nn.Module):
    """Soft-correspondence + SVD (Kabsch) rigid-transform estimation head.

    Args:
        emb_dims: embedding dimensionality (stored only; forward reads the
            dimension from the embedding tensor itself)
        input_shape: "bnc" if src/tgt are given as [B, N, 3]; any other value
            means they are already [B, 3, N].
    """

    def __init__(self, emb_dims, input_shape="bnc"):
        super(SVDHead, self).__init__()
        self.emb_dims = emb_dims
        # diag(1, 1, -1): flips the last singular vector to turn a reflection
        # into a proper rotation.
        self.reflect = nn.Parameter(torch.eye(3), requires_grad=False)
        self.reflect[2, 2] = -1
        self.input_shape = input_shape

    def forward(self, *input):
        """Estimate R, t aligning src to its soft correspondences in tgt.

        Positional inputs:
            src_embedding: [B, d, N]
            tgt_embedding: [B, d, M]
            src, tgt: point clouds in the layout given by input_shape
        Returns:
            R: [B, 3, 3] and t: [B, 3] such that R @ src + t approximates the
            softly matched target points.
        """
        src_embedding = input[0]
        tgt_embedding = input[1]
        src = input[2]
        tgt = input[3]
        batch_size = src.size(0)
        if self.input_shape == "bnc":
            src = src.permute(0, 2, 1)
            tgt = tgt.permute(0, 2, 1)

        d_k = src_embedding.size(1)
        scores = torch.matmul(src_embedding.transpose(2, 1).contiguous(), tgt_embedding) / math.sqrt(d_k)
        scores = torch.softmax(scores, dim=2)  # [B, N, M]

        # Soft-matched target point for each source point: [B, 3, N].
        src_corr = torch.matmul(tgt, scores.transpose(2, 1).contiguous())

        src_mean = src.mean(dim=2, keepdim=True)
        src_corr_mean = src_corr.mean(dim=2, keepdim=True)
        H = torch.matmul(src - src_mean, (src_corr - src_corr_mean).transpose(2, 1).contiguous())

        R = []
        for i in range(batch_size):
            u, _, v = torch.svd(H[i])
            r = torch.matmul(v, u.transpose(1, 0).contiguous())
            if torch.det(r) < 0:
                # Reflection: flip the last column of v and rebuild r. The
                # previous code recomputed the identical SVD here.
                v = torch.matmul(v, self.reflect)
                r = torch.matmul(v, u.transpose(1, 0).contiguous())
            R.append(r)
        R = torch.stack(R, dim=0)

        t = torch.matmul(-R, src_mean) + src_corr_mean
        return R, t.view(batch_size, 3)