learning3d-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115)
  1. learning3d/__init__.py +2 -0
  2. learning3d/data_utils/__init__.py +4 -0
  3. learning3d/data_utils/dataloaders.py +454 -0
  4. learning3d/data_utils/user_data.py +119 -0
  5. learning3d/examples/test_dcp.py +139 -0
  6. learning3d/examples/test_deepgmr.py +144 -0
  7. learning3d/examples/test_flownet.py +113 -0
  8. learning3d/examples/test_masknet.py +159 -0
  9. learning3d/examples/test_masknet2.py +162 -0
  10. learning3d/examples/test_pcn.py +118 -0
  11. learning3d/examples/test_pcrnet.py +120 -0
  12. learning3d/examples/test_pnlk.py +121 -0
  13. learning3d/examples/test_pointconv.py +126 -0
  14. learning3d/examples/test_pointnet.py +121 -0
  15. learning3d/examples/test_prnet.py +126 -0
  16. learning3d/examples/test_rpmnet.py +120 -0
  17. learning3d/examples/train_PointNetLK.py +240 -0
  18. learning3d/examples/train_dcp.py +249 -0
  19. learning3d/examples/train_deepgmr.py +244 -0
  20. learning3d/examples/train_flownet.py +259 -0
  21. learning3d/examples/train_masknet.py +239 -0
  22. learning3d/examples/train_pcn.py +216 -0
  23. learning3d/examples/train_pcrnet.py +228 -0
  24. learning3d/examples/train_pointconv.py +245 -0
  25. learning3d/examples/train_pointnet.py +244 -0
  26. learning3d/examples/train_prnet.py +229 -0
  27. learning3d/examples/train_rpmnet.py +228 -0
  28. learning3d/losses/__init__.py +12 -0
  29. learning3d/losses/chamfer_distance.py +51 -0
  30. learning3d/losses/classification.py +14 -0
  31. learning3d/losses/correspondence_loss.py +10 -0
  32. learning3d/losses/cuda/chamfer_distance/__init__.py +1 -0
  33. learning3d/losses/cuda/chamfer_distance/chamfer_distance.cpp +185 -0
  34. learning3d/losses/cuda/chamfer_distance/chamfer_distance.cu +209 -0
  35. learning3d/losses/cuda/chamfer_distance/chamfer_distance.py +66 -0
  36. learning3d/losses/cuda/emd_torch/pkg/emd_loss_layer.py +41 -0
  37. learning3d/losses/cuda/emd_torch/pkg/include/cuda/emd.cuh +347 -0
  38. learning3d/losses/cuda/emd_torch/pkg/include/cuda_helper.h +18 -0
  39. learning3d/losses/cuda/emd_torch/pkg/include/emd.h +54 -0
  40. learning3d/losses/cuda/emd_torch/pkg/layer/__init__.py +1 -0
  41. learning3d/losses/cuda/emd_torch/pkg/layer/emd_loss_layer.py +40 -0
  42. learning3d/losses/cuda/emd_torch/pkg/src/cuda/emd.cu +70 -0
  43. learning3d/losses/cuda/emd_torch/pkg/src/emd.cpp +1 -0
  44. learning3d/losses/cuda/emd_torch/setup.py +29 -0
  45. learning3d/losses/emd.py +16 -0
  46. learning3d/losses/frobenius_norm.py +21 -0
  47. learning3d/losses/rmse_features.py +16 -0
  48. learning3d/models/__init__.py +23 -0
  49. learning3d/models/classifier.py +41 -0
  50. learning3d/models/dcp.py +92 -0
  51. learning3d/models/deepgmr.py +165 -0
  52. learning3d/models/dgcnn.py +92 -0
  53. learning3d/models/flownet3d.py +446 -0
  54. learning3d/models/masknet.py +84 -0
  55. learning3d/models/masknet2.py +264 -0
  56. learning3d/models/pcn.py +164 -0
  57. learning3d/models/pcrnet.py +74 -0
  58. learning3d/models/pointconv.py +108 -0
  59. learning3d/models/pointnet.py +108 -0
  60. learning3d/models/pointnetlk.py +173 -0
  61. learning3d/models/pooling.py +15 -0
  62. learning3d/models/ppfnet.py +102 -0
  63. learning3d/models/prnet.py +431 -0
  64. learning3d/models/rpmnet.py +359 -0
  65. learning3d/models/segmentation.py +38 -0
  66. learning3d/ops/__init__.py +0 -0
  67. learning3d/ops/data_utils.py +45 -0
  68. learning3d/ops/invmat.py +134 -0
  69. learning3d/ops/quaternion.py +218 -0
  70. learning3d/ops/se3.py +157 -0
  71. learning3d/ops/sinc.py +229 -0
  72. learning3d/ops/so3.py +213 -0
  73. learning3d/ops/transform_functions.py +342 -0
  74. learning3d/utils/__init__.py +9 -0
  75. learning3d/utils/lib/build/lib.linux-x86_64-3.5/pointnet2_cuda.cpython-35m-x86_64-linux-gnu.so +0 -0
  76. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/ball_query.o +0 -0
  77. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/ball_query_gpu.o +0 -0
  78. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/group_points.o +0 -0
  79. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/group_points_gpu.o +0 -0
  80. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/interpolate.o +0 -0
  81. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/interpolate_gpu.o +0 -0
  82. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/pointnet2_api.o +0 -0
  83. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/sampling.o +0 -0
  84. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/sampling_gpu.o +0 -0
  85. learning3d/utils/lib/dist/pointnet2-0.0.0-py3.5-linux-x86_64.egg +0 -0
  86. learning3d/utils/lib/pointnet2.egg-info/SOURCES.txt +14 -0
  87. learning3d/utils/lib/pointnet2.egg-info/dependency_links.txt +1 -0
  88. learning3d/utils/lib/pointnet2.egg-info/top_level.txt +1 -0
  89. learning3d/utils/lib/pointnet2_modules.py +160 -0
  90. learning3d/utils/lib/pointnet2_utils.py +318 -0
  91. learning3d/utils/lib/pytorch_utils.py +236 -0
  92. learning3d/utils/lib/setup.py +23 -0
  93. learning3d/utils/lib/src/ball_query.cpp +25 -0
  94. learning3d/utils/lib/src/ball_query_gpu.cu +67 -0
  95. learning3d/utils/lib/src/ball_query_gpu.h +15 -0
  96. learning3d/utils/lib/src/cuda_utils.h +15 -0
  97. learning3d/utils/lib/src/group_points.cpp +36 -0
  98. learning3d/utils/lib/src/group_points_gpu.cu +86 -0
  99. learning3d/utils/lib/src/group_points_gpu.h +22 -0
  100. learning3d/utils/lib/src/interpolate.cpp +65 -0
  101. learning3d/utils/lib/src/interpolate_gpu.cu +233 -0
  102. learning3d/utils/lib/src/interpolate_gpu.h +36 -0
  103. learning3d/utils/lib/src/pointnet2_api.cpp +25 -0
  104. learning3d/utils/lib/src/sampling.cpp +46 -0
  105. learning3d/utils/lib/src/sampling_gpu.cu +253 -0
  106. learning3d/utils/lib/src/sampling_gpu.h +29 -0
  107. learning3d/utils/pointconv_util.py +382 -0
  108. learning3d/utils/ppfnet_util.py +244 -0
  109. learning3d/utils/svd.py +59 -0
  110. learning3d/utils/transformer.py +243 -0
  111. learning3d-0.0.1.dist-info/LICENSE +21 -0
  112. learning3d-0.0.1.dist-info/METADATA +271 -0
  113. learning3d-0.0.1.dist-info/RECORD +115 -0
  114. learning3d-0.0.1.dist-info/WHEEL +5 -0
  115. learning3d-0.0.1.dist-info/top_level.txt +1 -0
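
The layout above maps to the package's import surface: networks under learning3d.models, loss functions under learning3d.losses, and dataset helpers under learning3d.data_utils. As a rough sketch of how the bundled code is driven (hedged: the module path comes from the file list above, and the constructor argument mirrors masknet.py below; neither is verified against the package METADATA):

    import torch
    from learning3d.models.pointnet import PointNet  # path per the file list above

    pn = PointNet(use_bn=True)   # constructor argument as used in masknet.py
    x = torch.rand(2, 1024, 3)   # batch of point clouds, [B, N, 3]
    feat = pn(x)                 # per-point features, [B, C, N] per masknet.py's comments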
learning3d/models/flownet3d.py
@@ -0,0 +1,446 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from time import time
+ import numpy as np
+
+ try:
+     from ..utils import pointnet2_utils as pointutils
+ except ImportError:
+     print("Error in pointnet2_utils! Retry setup for pointnet2_utils.")
+
+ def timeit(tag, t):
+     print("{}: {}s".format(tag, time() - t))
+     return time()
+
+ def pc_normalize(pc):
+     centroid = np.mean(pc, axis=0)
+     pc = pc - centroid
+     m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
+     pc = pc / m
+     return pc
+
+ def square_distance(src, dst):
+     """
+     Calculate squared Euclidean distance between each pair of points.
+     src^T * dst = xn * xm + yn * ym + zn * zm;
+     sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
+     sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
+     dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
+          = sum(src**2, dim=-1) + sum(dst**2, dim=-1) - 2 * src^T * dst
+     Input:
+         src: source points, [B, N, C]
+         dst: target points, [B, M, C]
+     Output:
+         dist: per-point squared distance, [B, N, M]
+     """
+     B, N, _ = src.shape
+     _, M, _ = dst.shape
+     dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
+     dist += torch.sum(src ** 2, -1).view(B, N, 1)
+     dist += torch.sum(dst ** 2, -1).view(B, 1, M)
+     return dist
+
+
+ def index_points(points, idx):
+     """
+     Input:
+         points: input points data, [B, N, C]
+         idx: sample index data, [B, S]
+     Return:
+         new_points: indexed points data, [B, S, C]
+     """
+     device = points.device
+     B = points.shape[0]
+     view_shape = list(idx.shape)
+     view_shape[1:] = [1] * (len(view_shape) - 1)
+     repeat_shape = list(idx.shape)
+     repeat_shape[0] = 1
+     batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
+     new_points = points[batch_indices, idx, :]
+     return new_points
+
+
+ def farthest_point_sample(xyz, npoint):
+     """
+     Input:
+         xyz: pointcloud data, [B, N, C]
+         npoint: number of samples
+     Return:
+         centroids: sampled pointcloud index, [B, npoint]
+     """
+     device = xyz.device
+     B, N, C = xyz.shape
+     centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
+     distance = torch.ones(B, N).to(device) * 1e10
+     farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
+     batch_indices = torch.arange(B, dtype=torch.long).to(device)
+     for i in range(npoint):
+         centroids[:, i] = farthest
+         centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
+         dist = torch.sum((xyz - centroid) ** 2, -1)
+         mask = dist < distance
+         distance[mask] = dist[mask]
+         farthest = torch.max(distance, -1)[1]
+     return centroids
+
+ def knn_point(k, pos1, pos2):
+     '''
+     Input:
+         k: int32, number of k in k-nn search
+         pos1: (batch_size, ndataset, c) float32 array, input points
+         pos2: (batch_size, npoint, c) float32 array, query points
+     Output:
+         val: (batch_size, npoint, k) float32 array, L2 distances
+         idx: (batch_size, npoint, k) int32 array, indices to input points
+     '''
+     B, N, C = pos1.shape
+     M = pos2.shape[1]
+     pos1 = pos1.view(B, 1, N, -1).repeat(1, M, 1, 1)
+     pos2 = pos2.view(B, M, 1, -1).repeat(1, 1, N, 1)
+     # negate the squared distances so that topk returns the k smallest ones
+     dist = torch.sum(-(pos1 - pos2) ** 2, -1)
+     val, idx = dist.topk(k=k, dim=-1)
+     return torch.sqrt(-val), idx
+
+
+ def query_ball_point(radius, nsample, xyz, new_xyz):
+     """
+     Input:
+         radius: local region radius
+         nsample: max sample number in local region
+         xyz: all points, [B, N, C]
+         new_xyz: query points, [B, S, C]
+     Return:
+         group_idx: grouped points index, [B, S, nsample]
+         cnt: number of points found inside each ball, [B, S]
+     """
+     device = xyz.device
+     B, N, C = xyz.shape
+     _, S, _ = new_xyz.shape
+     group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
+     sqrdists = square_distance(new_xyz, xyz)
+     group_idx[sqrdists > radius ** 2] = N
+     mask = group_idx != N
+     cnt = mask.sum(dim=-1)
+     group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
+     # pad out-of-radius slots with the first (closest) in-radius index
+     group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
+     mask = group_idx == N
+     group_idx[mask] = group_first[mask]
+     return group_idx, cnt
+
+
+ def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
+     """
+     Input:
+         npoint: number of centroids to sample
+         radius: ball query radius
+         nsample: max points per local region
+         xyz: input points position data, [B, N, C]
+         points: input points data, [B, N, D]
+     Return:
+         new_xyz: sampled points position data, [B, npoint, C]
+         new_points: sampled points data, [B, npoint, nsample, C+D]
+     """
+     B, N, C = xyz.shape
+     S = npoint
+     fps_idx = farthest_point_sample(xyz, npoint)  # [B, npoint]
+     new_xyz = index_points(xyz, fps_idx)
+     idx, _ = query_ball_point(radius, nsample, xyz, new_xyz)
+     grouped_xyz = index_points(xyz, idx)  # [B, npoint, nsample, C]
+     grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)
+     if points is not None:
+         grouped_points = index_points(points, idx)
+         new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1)  # [B, npoint, nsample, C+D]
+     else:
+         new_points = grouped_xyz_norm
+     if returnfps:
+         return new_xyz, new_points, grouped_xyz, fps_idx
+     else:
+         return new_xyz, new_points
+
+
+ def sample_and_group_all(xyz, points):
+     """
+     Input:
+         xyz: input points position data, [B, N, C]
+         points: input points data, [B, N, D]
+     Return:
+         new_xyz: sampled points position data, [B, 1, C]
+         new_points: sampled points data, [B, 1, N, C+D]
+     """
+     device = xyz.device
+     B, N, C = xyz.shape
+     new_xyz = torch.zeros(B, 1, C).to(device)
+     grouped_xyz = xyz.view(B, 1, N, C)
+     if points is not None:
+         new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)
+     else:
+         new_points = grouped_xyz
+     return new_xyz, new_points
+
+ class PointNetSetAbstraction(nn.Module):
+     def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
+         super(PointNetSetAbstraction, self).__init__()
+         self.npoint = npoint
+         self.radius = radius
+         self.nsample = nsample
+         self.group_all = group_all
+         self.mlp_convs = nn.ModuleList()
+         self.mlp_bns = nn.ModuleList()
+         last_channel = in_channel + 3  # +3 for the grouped xyz coordinates
+         for out_channel in mlp:
+             self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1, bias=False))
+             self.mlp_bns.append(nn.BatchNorm2d(out_channel))
+             last_channel = out_channel
+
+         if group_all:
+             self.queryandgroup = pointutils.GroupAll()
+         else:
+             self.queryandgroup = pointutils.QueryAndGroup(radius, nsample)
+
+     def forward(self, xyz, points):
+         """
+         Input:
+             xyz: input points position data, [B, C, N]
+             points: input points data, [B, D, N]
+         Return:
+             new_xyz: sampled points position data, [B, C, S]
+             new_points: sampled points feature data, [B, D', S]
+         """
+         device = xyz.device
+         B, C, N = xyz.shape
+         xyz_t = xyz.permute(0, 2, 1).contiguous()
+
+         # select neighborhood points
+         if not self.group_all:
+             fps_idx = pointutils.furthest_point_sample(xyz_t, self.npoint)  # [B, npoint]
+             new_xyz = pointutils.gather_operation(xyz, fps_idx)  # [B, C, npoint]
+         else:
+             new_xyz = xyz
+         new_points = self.queryandgroup(xyz_t, new_xyz.transpose(2, 1).contiguous(), points)  # [B, C+D, npoint, nsample]
+
+         for i, conv in enumerate(self.mlp_convs):
+             bn = self.mlp_bns[i]
+             new_points = F.relu(bn(conv(new_points)))
+
+         # max-pool over each local neighborhood
+         new_points = torch.max(new_points, -1)[0]
+         return new_xyz, new_points
+
+ class FlowEmbedding(nn.Module):
+     def __init__(self, radius, nsample, in_channel, mlp, pooling='max', corr_func='concat', knn=True):
+         super(FlowEmbedding, self).__init__()
+         self.radius = radius
+         self.nsample = nsample
+         self.knn = knn
+         self.pooling = pooling
+         self.corr_func = corr_func
+         self.mlp_convs = nn.ModuleList()
+         self.mlp_bns = nn.ModuleList()
+         if corr_func == 'concat':
+             last_channel = in_channel * 2 + 3
+         for out_channel in mlp:
+             self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1, bias=False))
+             self.mlp_bns.append(nn.BatchNorm2d(out_channel))
+             last_channel = out_channel
+
+     def forward(self, pos1, pos2, feature1, feature2):
+         """
+         Input:
+             pos1: (batch_size, 3, npoint)
+             pos2: (batch_size, 3, npoint)
+             feature1: (batch_size, channel, npoint)
+             feature2: (batch_size, channel, npoint)
+         Output:
+             pos1: (batch_size, 3, npoint)
+             feat1_new: (batch_size, mlp[-1], npoint)
+         """
+         pos1_t = pos1.permute(0, 2, 1).contiguous()
+         pos2_t = pos2.permute(0, 2, 1).contiguous()
+         B, N, C = pos1_t.shape
+         if self.knn:
+             _, idx = pointutils.knn(self.nsample, pos1_t, pos2_t)
+         else:
+             # If the ball neighborhood has fewer than nsample points,
+             # fall back to the knn neighborhood points.
+             idx, cnt = query_ball_point(self.radius, self.nsample, pos2_t, pos1_t)
+             # take the nearest points via knn
+             _, idx_knn = pointutils.knn(self.nsample, pos1_t, pos2_t)
+             cnt = cnt.view(B, -1, 1).repeat(1, 1, self.nsample)
+             idx[cnt <= (self.nsample - 1)] = idx_knn[cnt <= (self.nsample - 1)]
+
+         pos2_grouped = pointutils.grouping_operation(pos2, idx)  # [B, 3, N, S]
+         pos_diff = pos2_grouped - pos1.view(B, -1, N, 1)  # [B, 3, N, S]
+
+         feat2_grouped = pointutils.grouping_operation(feature2, idx)  # [B, C, N, S]
+         if self.corr_func == 'concat':
+             feat_diff = torch.cat([feat2_grouped, feature1.view(B, -1, N, 1).repeat(1, 1, 1, self.nsample)], dim=1)
+
+         feat1_new = torch.cat([pos_diff, feat_diff], dim=1)  # [B, 2*C+3, N, S]
+         for i, conv in enumerate(self.mlp_convs):
+             bn = self.mlp_bns[i]
+             feat1_new = F.relu(bn(conv(feat1_new)))
+
+         feat1_new = torch.max(feat1_new, -1)[0]  # [B, mlp[-1], npoint]
+         return pos1, feat1_new
+
+ class PointNetSetUpConv(nn.Module):
+     def __init__(self, nsample, radius, f1_channel, f2_channel, mlp, mlp2, knn=True):
+         super(PointNetSetUpConv, self).__init__()
+         self.nsample = nsample
+         self.radius = radius
+         self.knn = knn
+         self.mlp1_convs = nn.ModuleList()
+         self.mlp2_convs = nn.ModuleList()
+         last_channel = f2_channel + 3
+         for out_channel in mlp:
+             self.mlp1_convs.append(nn.Sequential(nn.Conv2d(last_channel, out_channel, 1, bias=False),
+                                                  nn.BatchNorm2d(out_channel),
+                                                  nn.ReLU(inplace=False)))
+             last_channel = out_channel
+         if len(mlp) != 0:
+             last_channel = mlp[-1] + f1_channel
+         else:
+             last_channel = last_channel + f1_channel
+         for out_channel in mlp2:
+             self.mlp2_convs.append(nn.Sequential(nn.Conv1d(last_channel, out_channel, 1, bias=False),
+                                                  nn.BatchNorm1d(out_channel),
+                                                  nn.ReLU(inplace=False)))
+             last_channel = out_channel
+
+     def forward(self, pos1, pos2, feature1, feature2):
+         """
+         Feature propagation from xyz2 (fewer points) to xyz1 (more points).
+         Inputs:
+             pos1: (batch_size, 3, npoint1)
+             pos2: (batch_size, 3, npoint2)
+             feature1: (batch_size, channel1, npoint1) features for xyz1 points (earlier layers, more points)
+             feature2: (batch_size, channel2, npoint2) features for xyz2 points
+         Output:
+             feat_new: (batch_size, mlp2[-1] or mlp[-1] or channel1+3, npoint1)
+         TODO: Add support for skip links. Study how delta(XYZ) plays a role in feature updating.
+         """
+         pos1_t = pos1.permute(0, 2, 1).contiguous()
+         pos2_t = pos2.permute(0, 2, 1).contiguous()
+         B, C, N = pos1.shape
+         if self.knn:
+             _, idx = pointutils.knn(self.nsample, pos1_t, pos2_t)
+         else:
+             idx, _ = query_ball_point(self.radius, self.nsample, pos2_t, pos1_t)
+
+         pos2_grouped = pointutils.grouping_operation(pos2, idx)
+         pos_diff = pos2_grouped - pos1.view(B, -1, N, 1)  # [B, 3, N1, S]
+
+         feat2_grouped = pointutils.grouping_operation(feature2, idx)
+         feat_new = torch.cat([feat2_grouped, pos_diff], dim=1)  # [B, C2+3, N1, S]
+         for conv in self.mlp1_convs:
+             feat_new = conv(feat_new)
+         # max pooling over the neighborhood
+         feat_new = feat_new.max(-1)[0]  # [B, mlp[-1], N1]
+         # concatenate features from the earlier layer
+         if feature1 is not None:
+             feat_new = torch.cat([feat_new, feature1], dim=1)
+         for conv in self.mlp2_convs:
+             feat_new = conv(feat_new)
+
+         return feat_new
+
+ class PointNetFeaturePropogation(nn.Module):
+     def __init__(self, in_channel, mlp):
+         super(PointNetFeaturePropogation, self).__init__()
+         self.mlp_convs = nn.ModuleList()
+         self.mlp_bns = nn.ModuleList()
+         last_channel = in_channel
+         for out_channel in mlp:
+             self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
+             self.mlp_bns.append(nn.BatchNorm1d(out_channel))
+             last_channel = out_channel
+
+     def forward(self, pos1, pos2, feature1, feature2):
+         """
+         Input:
+             pos1: input points position data, [B, C, N]
+             pos2: sampled input points position data, [B, C, S]
+             feature1: input points data, [B, D, N]
+             feature2: input points data, [B, D, S]
+         Return:
+             feat_new: upsampled points data, [B, D', N]
+         """
+         pos1_t = pos1.permute(0, 2, 1).contiguous()
+         pos2_t = pos2.permute(0, 2, 1).contiguous()
+         B, C, N = pos1.shape
+
+         # inverse-distance-weighted interpolation over the three nearest neighbors
+         dists, idx = pointutils.three_nn(pos1_t, pos2_t)
+         dists[dists < 1e-10] = 1e-10
+         weight = 1.0 / dists
+         weight = weight / torch.sum(weight, -1, keepdim=True)  # [B, N, 3]
+         interpolated_feat = torch.sum(pointutils.grouping_operation(feature2, idx) * weight.view(B, 1, N, 3), dim=-1)  # [B, C, N]
+
+         if feature1 is not None:
+             feat_new = torch.cat([interpolated_feat, feature1], 1)
+         else:
+             feat_new = interpolated_feat
+
+         for i, conv in enumerate(self.mlp_convs):
+             bn = self.mlp_bns[i]
+             feat_new = F.relu(bn(conv(feat_new)))
+         return feat_new
+
+
+ class FlowNet3D(nn.Module):
+     def __init__(self):
+         super(FlowNet3D, self).__init__()
+
+         self.sa1 = PointNetSetAbstraction(npoint=1024, radius=0.5, nsample=16, in_channel=3, mlp=[32, 32, 64], group_all=False)
+         self.sa2 = PointNetSetAbstraction(npoint=256, radius=1.0, nsample=16, in_channel=64, mlp=[64, 64, 128], group_all=False)
+         self.sa3 = PointNetSetAbstraction(npoint=64, radius=2.0, nsample=8, in_channel=128, mlp=[128, 128, 256], group_all=False)
+         self.sa4 = PointNetSetAbstraction(npoint=16, radius=4.0, nsample=8, in_channel=256, mlp=[256, 256, 512], group_all=False)
+
+         self.fe_layer = FlowEmbedding(radius=10.0, nsample=64, in_channel=128, mlp=[128, 128, 128], pooling='max', corr_func='concat')
+
+         self.su1 = PointNetSetUpConv(nsample=8, radius=2.4, f1_channel=256, f2_channel=512, mlp=[], mlp2=[256, 256])
+         self.su2 = PointNetSetUpConv(nsample=8, radius=1.2, f1_channel=128+128, f2_channel=256, mlp=[128, 128, 256], mlp2=[256])
+         self.su3 = PointNetSetUpConv(nsample=8, radius=0.6, f1_channel=64, f2_channel=256, mlp=[128, 128, 256], mlp2=[256])
+         self.fp = PointNetFeaturePropogation(in_channel=256+3, mlp=[256, 256])
+
+         self.conv1 = nn.Conv1d(256, 128, kernel_size=1, bias=False)
+         self.bn1 = nn.BatchNorm1d(128)
+         self.conv2 = nn.Conv1d(128, 3, kernel_size=1, bias=True)
+
+     def forward(self, pc1, pc2, feature1, feature2):
+         # set-abstraction pyramid for both clouds
+         l1_pc1, l1_feature1 = self.sa1(pc1, feature1)
+         l2_pc1, l2_feature1 = self.sa2(l1_pc1, l1_feature1)
+
+         l1_pc2, l1_feature2 = self.sa1(pc2, feature2)
+         l2_pc2, l2_feature2 = self.sa2(l1_pc2, l1_feature2)
+
+         # correlate the two clouds at the middle level
+         _, l2_feature1_new = self.fe_layer(l2_pc1, l2_pc2, l2_feature1, l2_feature2)
+
+         l3_pc1, l3_feature1 = self.sa3(l2_pc1, l2_feature1_new)
+         l4_pc1, l4_feature1 = self.sa4(l3_pc1, l3_feature1)
+
+         # upsample back to the full resolution of pc1
+         l3_fnew1 = self.su1(l3_pc1, l4_pc1, l3_feature1, l4_feature1)
+         l2_fnew1 = self.su2(l2_pc1, l3_pc1, torch.cat([l2_feature1, l2_feature1_new], dim=1), l3_fnew1)
+         l1_fnew1 = self.su3(l1_pc1, l2_pc1, l1_feature1, l2_fnew1)
+         l0_fnew1 = self.fp(pc1, l1_pc1, feature1, l1_fnew1)
+
+         x = F.relu(self.bn1(self.conv1(l0_fnew1)))
+         sf = self.conv2(x)  # per-point 3-D scene flow
+         return sf
+
+ if __name__ == '__main__':
+     import os
+     import torch
+     os.environ["CUDA_VISIBLE_DEVICES"] = '0'
+     input = torch.randn((8, 3, 2048))
+     model = FlowNet3D()
+     # forward takes two clouds plus per-point features for each
+     # (requires the compiled pointnet2 CUDA extension)
+     output = model(input, input, input, input)
+     print(output.size())
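
The square_distance helper at the top of flownet3d.py relies on the expansion ||x - y||^2 = ||x||^2 + ||y||^2 - 2·x·y to get all pairwise distances from a single matmul. A quick CPU sanity check against torch.cdist (a standalone sketch; unlike the module above it needs no CUDA extension):

    import torch

    def square_distance(src, dst):
        # same expansion as flownet3d.py: |x-y|^2 = |x|^2 + |y|^2 - 2*x.y
        B, N, _ = src.shape
        _, M, _ = dst.shape
        dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
        dist += torch.sum(src ** 2, -1).view(B, N, 1)
        dist += torch.sum(dst ** 2, -1).view(B, 1, M)
        return dist

    src, dst = torch.rand(2, 8, 3), torch.rand(2, 5, 3)
    ref = torch.cdist(src, dst) ** 2  # reference pairwise squared distances, [B, N, M]
    assert torch.allclose(square_distance(src, dst), ref, atol=1e-5)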
learning3d/models/masknet.py
@@ -0,0 +1,84 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from .pointnet import PointNet
+ from .pooling import Pooling
+
+ class PointNetMask(nn.Module):
+     def __init__(self, template_feature_size=1024, source_feature_size=1024, feature_model=PointNet()):
+         super().__init__()
+         self.feature_model = feature_model
+         self.pooling = Pooling()
+
+         input_size = template_feature_size + source_feature_size
+         self.h3 = nn.Sequential(nn.Conv1d(input_size, 1024, 1), nn.ReLU(),
+                                 nn.Conv1d(1024, 512, 1), nn.ReLU(),
+                                 nn.Conv1d(512, 256, 1), nn.ReLU(),
+                                 nn.Conv1d(256, 128, 1), nn.ReLU(),
+                                 nn.Conv1d(128, 1, 1), nn.Sigmoid())
+
+     def find_mask(self, x, t_out_h1):
+         # tile the global source feature, concatenate it with every
+         # per-point template feature, and predict a per-point mask value
+         batch_size, _, num_points = t_out_h1.size()
+         x = x.unsqueeze(2)
+         x = x.repeat(1, 1, num_points)
+         x = torch.cat([t_out_h1, x], dim=1)
+         x = self.h3(x)
+         return x.view(batch_size, -1)
+
+     def forward(self, template, source):
+         source_features = self.feature_model(source)      # [B, C, N]
+         template_features = self.feature_model(template)  # [B, C, N]
+
+         source_features = self.pooling(source_features)   # global source feature
+         mask = self.find_mask(source_features, template_features)
+         return mask
+
+
+ class MaskNet(nn.Module):
+     def __init__(self, feature_model=PointNet(use_bn=True), is_training=True):
+         super().__init__()
+         self.maskNet = PointNetMask(feature_model=feature_model)
+         self.is_training = is_training
+
+     @staticmethod
+     def index_points(points, idx):
+         """
+         Input:
+             points: input points data, [B, N, C]
+             idx: sample index data, [B, S]
+         Return:
+             new_points: indexed points data, [B, S, C]
+         """
+         device = points.device
+         B = points.shape[0]
+         view_shape = list(idx.shape)
+         view_shape[1:] = [1] * (len(view_shape) - 1)
+         repeat_shape = list(idx.shape)
+         repeat_shape[0] = 1
+         batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
+         new_points = points[batch_indices, idx, :]
+         return new_points
+
+     # This function is only useful for testing with a single pair of point clouds.
+     @staticmethod
+     def find_index(mask_val):
+         mask_idx = torch.nonzero((mask_val[0] > 0.5) * 1.0)
+         return mask_idx.view(1, -1)
+
+     def forward(self, template, source, point_selection='threshold'):
+         mask = self.maskNet(template, source)
+
+         if point_selection == 'topk' or self.is_training:
+             _, self.mask_idx = torch.topk(mask, source.shape[1], dim=1, sorted=False)
+         elif point_selection == 'threshold':
+             self.mask_idx = self.find_index(mask)
+
+         template = self.index_points(template, self.mask_idx)
+         return template, mask
+
+
+ if __name__ == '__main__':
+     template, source = torch.rand(10, 1024, 3), torch.rand(10, 1024, 3)
+     net = MaskNet()
+     result = net(template, source)
+     import ipdb; ipdb.set_trace()
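
Since the __main__ block above drops straight into ipdb, here is a plainer way to exercise MaskNet's two point-selection modes (a sketch assuming the wheel is installed; this path is pure PyTorch, so no CUDA extension is needed):

    import torch
    from learning3d.models.masknet import MaskNet  # path per the file list

    template, source = torch.rand(1, 1024, 3), torch.rand(1, 1024, 3)
    net = MaskNet(is_training=False)

    # 'threshold' keeps template points with predicted mask value > 0.5
    # (single-pair batches only, per the comment on find_index above);
    # 'topk' keeps as many template points as the source has.
    masked_template, mask = net(template, source, point_selection='threshold')
    print(masked_template.shape, mask.shape)  # e.g. [1, K, 3] and [1, 1024]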