learning3d 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87)
  1. {learning3d/data_utils → data_utils}/dataloaders.py +16 -14
  2. examples/test_curvenet.py +118 -0
  3. {learning3d/examples → examples}/test_dcp.py +3 -5
  4. {learning3d/examples → examples}/test_deepgmr.py +3 -5
  5. {learning3d/examples → examples}/test_masknet.py +1 -3
  6. {learning3d/examples → examples}/test_masknet2.py +1 -3
  7. {learning3d/examples → examples}/test_pcn.py +2 -4
  8. {learning3d/examples → examples}/test_pcrnet.py +1 -3
  9. {learning3d/examples → examples}/test_pnlk.py +1 -3
  10. {learning3d/examples → examples}/test_pointconv.py +1 -3
  11. {learning3d/examples → examples}/test_pointnet.py +1 -3
  12. {learning3d/examples → examples}/test_prnet.py +3 -5
  13. {learning3d/examples → examples}/test_rpmnet.py +1 -3
  14. {learning3d/examples → examples}/train_PointNetLK.py +2 -4
  15. {learning3d/examples → examples}/train_dcp.py +2 -4
  16. {learning3d/examples → examples}/train_deepgmr.py +2 -4
  17. {learning3d/examples → examples}/train_masknet.py +2 -4
  18. {learning3d/examples → examples}/train_pcn.py +2 -4
  19. {learning3d/examples → examples}/train_pcrnet.py +2 -4
  20. {learning3d/examples → examples}/train_pointconv.py +2 -4
  21. {learning3d/examples → examples}/train_pointnet.py +2 -4
  22. {learning3d/examples → examples}/train_prnet.py +2 -4
  23. {learning3d/examples → examples}/train_rpmnet.py +2 -4
  24. {learning3d-0.1.0.dist-info → learning3d-0.2.1.dist-info}/METADATA +57 -12
  25. learning3d-0.2.1.dist-info/RECORD +70 -0
  26. {learning3d-0.1.0.dist-info → learning3d-0.2.1.dist-info}/WHEEL +1 -1
  27. learning3d-0.2.1.dist-info/top_level.txt +6 -0
  28. {learning3d/models → models}/__init__.py +7 -1
  29. models/curvenet.py +130 -0
  30. {learning3d/models → models}/dgcnn.py +1 -35
  31. {learning3d/models → models}/prnet.py +5 -39
  32. utils/__init__.py +23 -0
  33. utils/curvenet_util.py +540 -0
  34. utils/model_common_utils.py +156 -0
  35. learning3d/losses/cuda/chamfer_distance/__init__.py +0 -1
  36. learning3d/losses/cuda/chamfer_distance/chamfer_distance.cpp +0 -185
  37. learning3d/losses/cuda/chamfer_distance/chamfer_distance.cu +0 -209
  38. learning3d/losses/cuda/chamfer_distance/chamfer_distance.py +0 -66
  39. learning3d/losses/cuda/emd_torch/pkg/emd_loss_layer.py +0 -41
  40. learning3d/losses/cuda/emd_torch/pkg/include/cuda/emd.cuh +0 -347
  41. learning3d/losses/cuda/emd_torch/pkg/include/cuda_helper.h +0 -18
  42. learning3d/losses/cuda/emd_torch/pkg/include/emd.h +0 -54
  43. learning3d/losses/cuda/emd_torch/pkg/layer/__init__.py +0 -1
  44. learning3d/losses/cuda/emd_torch/pkg/layer/emd_loss_layer.py +0 -40
  45. learning3d/losses/cuda/emd_torch/pkg/src/cuda/emd.cu +0 -70
  46. learning3d/losses/cuda/emd_torch/pkg/src/emd.cpp +0 -1
  47. learning3d/losses/cuda/emd_torch/setup.py +0 -29
  48. learning3d/ops/__init__.py +0 -0
  49. learning3d/utils/__init__.py +0 -4
  50. learning3d-0.1.0.dist-info/RECORD +0 -80
  51. learning3d-0.1.0.dist-info/top_level.txt +0 -1
  52. {learning3d/data_utils → data_utils}/__init__.py +0 -0
  53. {learning3d/data_utils → data_utils}/user_data.py +0 -0
  54. {learning3d-0.1.0.dist-info → learning3d-0.2.1.dist-info}/LICENSE +0 -0
  55. {learning3d/losses → losses}/__init__.py +0 -0
  56. {learning3d/losses → losses}/chamfer_distance.py +0 -0
  57. {learning3d/losses → losses}/classification.py +0 -0
  58. {learning3d/losses → losses}/correspondence_loss.py +0 -0
  59. {learning3d/losses → losses}/emd.py +0 -0
  60. {learning3d/losses → losses}/frobenius_norm.py +0 -0
  61. {learning3d/losses → losses}/rmse_features.py +0 -0
  62. {learning3d/models → models}/classifier.py +0 -0
  63. {learning3d/models → models}/dcp.py +0 -0
  64. {learning3d/models → models}/deepgmr.py +0 -0
  65. {learning3d/models → models}/masknet.py +0 -0
  66. {learning3d/models → models}/masknet2.py +0 -0
  67. {learning3d/models → models}/pcn.py +0 -0
  68. {learning3d/models → models}/pcrnet.py +0 -0
  69. {learning3d/models → models}/pointconv.py +0 -0
  70. {learning3d/models → models}/pointnet.py +0 -0
  71. {learning3d/models → models}/pointnetlk.py +0 -0
  72. {learning3d/models → models}/pooling.py +0 -0
  73. {learning3d/models → models}/ppfnet.py +0 -0
  74. {learning3d/models → models}/rpmnet.py +0 -0
  75. {learning3d/models → models}/segmentation.py +0 -0
  76. {learning3d → ops}/__init__.py +0 -0
  77. {learning3d/ops → ops}/data_utils.py +0 -0
  78. {learning3d/ops → ops}/invmat.py +0 -0
  79. {learning3d/ops → ops}/quaternion.py +0 -0
  80. {learning3d/ops → ops}/se3.py +0 -0
  81. {learning3d/ops → ops}/sinc.py +0 -0
  82. {learning3d/ops → ops}/so3.py +0 -0
  83. {learning3d/ops → ops}/transform_functions.py +0 -0
  84. {learning3d/utils → utils}/pointconv_util.py +0 -0
  85. {learning3d/utils → utils}/ppfnet_util.py +0 -0
  86. {learning3d/utils → utils}/svd.py +0 -0
  87. {learning3d/utils → utils}/transformer.py +0 -0
utils/curvenet_util.py ADDED
@@ -0,0 +1,540 @@
1
+ """
2
+ @Author: Yue Wang
3
+ @Contact: yuewangx@mit.edu
4
+ @File: pointnet_util.py
5
+ @Time: 2018/10/13 10:39 PM
6
+
7
+ Modified by
8
+ @Author: Tiange Xiang
9
+ @Contact: txia7609@uni.sydney.edu.au
10
+ @Time: 2021/01/21 3:10 PM
11
+ """
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ from time import time
17
+ import numpy as np
18
+ from .model_common_utils import (
19
+ knn,
20
+ square_distance,
21
+ index_points,
22
+ farthest_point_sample,
23
+ query_ball_point,
24
+ )
25
+
26
def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
    """
    Sample `npoint` centres by farthest point sampling (starting from the first
    point) and gather the features of up to `nsample` neighbours that lie inside
    a ball of `radius` around each centre.

    Input:
        npoint: number of centres to sample
        radius: ball-query radius
        nsample: maximum number of neighbours gathered per centre
        xyz: input points position data, [B, N, 3]
        points: input points data, [B, N, D]
        returnfps: if True, additionally return the ball-query neighbour
            indices (note: these are the grouping indices, NOT the farthest
            point sampling indices, despite the parameter name)
    Return:
        new_xyz: sampled centre positions, [B, npoint, 3]
        new_points: grouped neighbour features, [B, npoint, nsample, D]
        idx (only when returnfps): neighbour indices into N, [B, npoint, nsample]
    """
    new_xyz = index_points(xyz, farthest_point_sample(xyz, npoint, start_with_first_point=True))
    # Release cached allocator blocks between the large intermediate tensors.
    torch.cuda.empty_cache()

    idx = query_ball_point(radius, nsample, xyz, new_xyz, get_cnt=False)
    torch.cuda.empty_cache()

    new_points = index_points(points, idx)
    torch.cuda.empty_cache()

    if returnfps:
        return new_xyz, new_points, idx
    return new_xyz, new_points
51
+
52
def batched_index_select(input, dim, index):
    """Gather slices of `input` along `dim`, using a separate index row per batch item.

    `index` holds, for every batch element, the positions to pick along `dim`.
    It is reshaped to [B, ..., -1 (at dim), ..., 1] and broadcast over every
    other non-batch axis before calling `torch.gather`.
    """
    ndim = len(input.shape)
    # Batch axis keeps its size; the selected axis is inferred; all others broadcast.
    view_shape = [input.shape[0]]
    view_shape.extend(-1 if axis == dim else 1 for axis in range(1, ndim))

    expand_shape = list(input.shape)
    expand_shape[0] = expand_shape[dim] = -1

    broadcast_index = index.view(view_shape).expand(expand_shape)
    return torch.gather(input, dim, broadcast_index)
60
+
61
def gumbel_softmax(logits, dim, temperature=1):
    """
    Straight-through softmax (the "gumbel-softmax" estimator without random
    Gumbel noise).

    Forward pass: a one-hot tensor marking the arg-max of
    softmax(logits / temperature) along `dim`.
    Backward pass: gradients flow through the soft softmax probabilities.

    Input:
        logits: tensor of any shape
        dim: axis over which the softmax and the one-hot selection are taken
        temperature: softmax temperature
    Return:
        tensor with the same shape as `logits`
    """
    y = F.softmax(logits / temperature, dim=dim)

    # Take the hard one-hot along the SAME axis as the softmax; previously it
    # was always taken along the last axis, which was wrong for dim != -1.
    index = y.argmax(dim=dim, keepdim=True)
    y_hard = torch.zeros_like(y).scatter_(dim, index, 1.0)

    # Straight-through trick: value of y_hard, gradient of y.
    return (y_hard - y).detach() + y
77
+
78
class Walk(nn.Module):
    '''
    Walk in the cloud.

    Starting from `curve_num` seed points, grows curves of `curve_length`
    steps over a k-nearest-neighbour graph of point features. At each step an
    agent MLP scores the k neighbours of the current point, a crossover
    suppression term down-weights neighbours that would make the curve turn
    back, and a straight-through hard selection picks the next point.
    '''
    def __init__(self, in_channel, k, curve_num, curve_length):
        super(Walk, self).__init__()
        self.curve_num = curve_num
        self.curve_length = curve_length
        self.k = k

        # Scores each candidate neighbour from [neighbour feature, curve descriptor].
        self.agent_mlp = nn.Sequential(
            nn.Conv2d(in_channel * 2,
                      1,
                      kernel_size=1,
                      bias=False), nn.BatchNorm2d(1))
        # Produces 2 mixing weights between the current point feature and the
        # previous curve descriptor ("dynamic momentum").
        self.momentum_mlp = nn.Sequential(
            nn.Conv1d(in_channel * 2,
                      2,
                      kernel_size=1,
                      bias=False), nn.BatchNorm1d(2))

    def crossover_suppression(self, cur, neighbor, bn, n, k):
        """Cosine-based factor in [0, 1] that suppresses backward-pointing moves.

        cur: (bs*n, c) direction of the last step
        neighbor: (bs*n, c, k) direction from the current point to each neighbour
        Returns a detached (bs*n, k) tensor equal to clamp(1 + cos, 0, 1).
        `bn`, `n` and `k` are accepted for interface compatibility and unused.
        """
        neighbor = neighbor.detach()
        cur = cur.unsqueeze(-1).detach()
        dot = torch.bmm(cur.transpose(1, 2), neighbor)  # bs*n, 1, k
        norm1 = torch.norm(cur, dim=1, keepdim=True)
        norm2 = torch.norm(neighbor, dim=1, keepdim=True)
        divider = torch.clamp(norm1 * norm2, min=1e-8)
        # Squeeze only the singleton axis so bs*n == 1 or k == 1 keeps rank 2.
        ans = torch.div(dot, divider).squeeze(1)  # bs*n, k

        # Map cosine from [-1, 1] to [0, 2] and clamp to [0, 1]: neighbours
        # at an angle of 90 degrees or more from the previous step get weight < 1.
        ans = 1. + ans
        ans = torch.clamp(ans, 0., 1.0)

        return ans.detach()

    def forward(self, xyz, x, adj, cur):
        """Grow the curves.

        xyz: accepted for interface compatibility; not used by the walk.
        x: point features, (bs, c, n)
        adj: per-cloud neighbour indices, presumably (bs, n, k) — TODO confirm
        cur: per-cloud starting indices, presumably (bs, curve_num, 1)
        Returns:
            curves: (bs, c, curve_num, curve_length) features of the visited points
            flatten_curve_idxs: (bs*curve_num, curve_length + 1) batch-flattened
                indices (offset by batch * n) of the seed and every visited point
        """
        bn, c, tot_points = x.size()
        device = x.device

        # point features
        x = x.transpose(1, 2).contiguous()  # bs, n, c

        flatten_x = x.view(bn * tot_points, -1)
        batch_offset = torch.arange(0, bn, device=device).detach() * tot_points

        # batch-flattened neighbour indices for every point
        tmp_adj = (adj + batch_offset.view(-1, 1, 1)).view(adj.size(0) * adj.size(1), -1)  # bs*n, k

        # batch-flattened indices for the starting points
        flatten_cur = (cur + batch_offset.view(-1, 1, 1)).view(-1)

        curves = []
        flatten_curve_idxs = [flatten_cur.unsqueeze(1)]

        # one step at a time
        for step in range(self.curve_length):

            if step == 0:
                # get starting point features using flattened indices
                starting_points = flatten_x[flatten_cur, :].contiguous()
                pre_feature = starting_points.view(bn, self.curve_num, -1, 1).transpose(1, 2)  # bs, c, curve_num, 1
            else:
                # dynamic momentum
                cat_feature = torch.cat((cur_feature.squeeze(-1), pre_feature.squeeze(-1)), dim=1)
                att_feature = F.softmax(self.momentum_mlp(cat_feature), dim=1).view(bn, 1, self.curve_num, 2)  # bs, 1, n, 2
                cat_feature = torch.cat((cur_feature, pre_feature), dim=-1)  # bs, c, n, 2

                # update curve descriptor
                pre_feature = torch.sum(cat_feature * att_feature, dim=-1, keepdim=True)  # bs, c, n, 1
                pre_feature_cos = pre_feature.transpose(1, 2).contiguous().view(bn * self.curve_num, -1)

            pick_idx = tmp_adj[flatten_cur]  # bs*n, k

            # get the neighbors of current points
            pick_values = flatten_x[pick_idx.view(-1), :]

            # reshape to fit crossover suppression below
            pick_values_cos = pick_values.view(bn * self.curve_num, self.k, c)
            pick_values = pick_values_cos.view(bn, self.curve_num, self.k, c)
            pick_values_cos = pick_values_cos.transpose(1, 2).contiguous()

            pick_values = pick_values.permute(0, 3, 1, 2)  # bs, c, n, k

            pre_feature_expand = pre_feature.expand_as(pick_values)

            # concat neighbour features with curve descriptors
            pre_feature_expand = torch.cat((pick_values, pre_feature_expand), dim=1)

            # which node to pick next?
            pre_feature_expand = self.agent_mlp(pre_feature_expand)  # bs, 1, n, k

            if step != 0:
                # crossover suppression
                d = self.crossover_suppression(cur_feature_cos - pre_feature_cos,
                                               pick_values_cos - cur_feature_cos.unsqueeze(-1),
                                               bn, self.curve_num, self.k)
                d = d.view(bn, self.curve_num, self.k).unsqueeze(1)  # bs, 1, n, k
                pre_feature_expand = torch.mul(pre_feature_expand, d)

            pre_feature_expand = gumbel_softmax(pre_feature_expand, -1)  # bs, 1, n, k

            cur_feature = torch.sum(pick_values * pre_feature_expand, dim=-1, keepdim=True)  # bs, c, n, 1

            cur_feature_cos = cur_feature.transpose(1, 2).contiguous().view(bn * self.curve_num, c)

            cur = torch.argmax(pre_feature_expand, dim=-1).view(-1, 1)  # bs * n, 1

            # squeeze(-1), not squeeze(): with bs*curve_num == 1 an all-axis
            # squeeze produced a 0-d tensor that broke unsqueeze(1) below.
            flatten_cur = batched_index_select(pick_idx, 1, cur).squeeze(-1)  # bs * n

            # collect curve progress
            curves.append(cur_feature)
            flatten_curve_idxs.append(flatten_cur.unsqueeze(1))

        return torch.cat(curves, dim=-1), torch.cat(flatten_curve_idxs, dim=1)
198
+
199
+
200
class Attention_block(nn.Module):
    '''
    Additive attention gate, as used in Attention U-Net.

    Projects the gating signal `g` and the skip features `x` to `F_int`
    channels, adds them, applies LeakyReLU(0.2), and maps the result to a
    single sigmoid attention map. Returns the map and its complement.
    '''
    def __init__(self, F_g, F_l, F_int):
        super(Attention_block, self).__init__()

        def conv_bn(in_ch, out_ch):
            # 1x1 Conv1d (with bias) followed by BatchNorm1d.
            return [nn.Conv1d(in_ch, out_ch, kernel_size=1, stride=1, padding=0, bias=True),
                    nn.BatchNorm1d(out_ch)]

        self.W_g = nn.Sequential(*conv_bn(F_g, F_int))
        self.W_x = nn.Sequential(*conv_bn(F_l, F_int))
        self.psi = nn.Sequential(*conv_bn(F_int, 1), nn.Sigmoid())

    def forward(self, g, x):
        gate = F.leaky_relu(self.W_g(g) + self.W_x(x), negative_slope=0.2)
        attention = self.psi(gate)
        return attention, 1. - attention
229
+
230
+
231
class LPFA(nn.Module):
    '''
    Local point feature aggregation block.

    For every point, gathers its k neighbours and builds 9-channel geometric
    features [centre xyz, neighbour xyz, neighbour - centre]. In the initial
    block (`initial=True`) these geometric features are fed directly through
    the MLP and max-pooled over the neighbours. Otherwise they are projected to
    `in_channel` channels, added to the relative neighbour features
    (neighbour - centre), passed through the MLP and mean-pooled over the
    neighbours.
    '''
    def __init__(self, in_channel, out_channel, k, mlp_num=2, initial=False):
        super(LPFA, self).__init__()
        self.k = k
        self.initial = initial

        if not initial:
            self.xyz2feature = nn.Sequential(
                nn.Conv2d(9, in_channel, kernel_size=1, bias=False),
                nn.BatchNorm2d(in_channel))

        layers = []
        channels = in_channel
        for _ in range(mlp_num):
            layers.append(nn.Sequential(
                nn.Conv2d(channels, out_channel, 1, bias=False),
                nn.BatchNorm2d(out_channel),
                nn.LeakyReLU(0.2)))
            channels = out_channel
        self.mlp = nn.Sequential(*layers)

    def forward(self, x, xyz, idx=None):
        grouped = self.mlp(self.group_feature(x, xyz, idx))  # bs, c, n, k
        if not self.initial:
            return grouped.mean(dim=-1, keepdim=False)
        return grouped.max(dim=-1, keepdim=False)[0]

    def group_feature(self, x, xyz, idx):
        batch_size, num_dims, num_points = x.size()

        if idx is None:
            # k + 1 neighbours (the point itself included), truncated to k.
            idx = knn(xyz, k=self.k, add_one_to_k=True)[:, :, :self.k]  # bs, n, k

        # Shift per-cloud indices into the flattened (bs * n) row space.
        offsets = torch.arange(0, batch_size, device=x.device).view(-1, 1, 1) * num_points
        flat_idx = (idx + offsets).view(-1)

        xyz_rows = xyz.transpose(2, 1).contiguous()  # bs, n, 3
        neighbours = xyz_rows.view(batch_size * num_points, -1)[flat_idx, :]
        neighbours = neighbours.view(batch_size, num_points, self.k, -1)  # bs, n, k, 3
        centres = xyz_rows.view(batch_size, num_points, 1, 3).expand(-1, -1, self.k, -1)  # bs, n, k, 3

        geometry = torch.cat((centres, neighbours, neighbours - centres), dim=3)
        geometry = geometry.permute(0, 3, 1, 2).contiguous()  # bs, 9, n, k

        if self.initial:
            return geometry

        x_rows = x.transpose(2, 1).contiguous()  # bs, n, c
        feature = x_rows.view(batch_size * num_points, -1)[flat_idx, :]
        feature = feature.view(batch_size, num_points, self.k, num_dims)  # bs, n, k, c
        feature = feature - x_rows.view(batch_size, num_points, 1, num_dims)
        feature = feature.permute(0, 3, 1, 2).contiguous()  # bs, c, n, k

        return F.leaky_relu(feature + self.xyz2feature(geometry), 0.2)  # bs, c, n, k
293
+
294
+
295
class PointNetFeaturePropagation(nn.Module):
    '''
    Upsample features from a sparse point set (xyz2, points2) onto a dense
    point set (xyz1) by inverse-distance interpolation, optionally gate the
    dense skip features with an attention block, concatenate, and apply a
    pointwise Conv1d + BatchNorm1d + LeakyReLU(0.2) MLP.
    '''
    def __init__(self, in_channel, mlp, att=None):
        super(PointNetFeaturePropagation, self).__init__()
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = in_channel
        self.att = None
        if att is not None:
            self.att = Attention_block(F_g=att[0], F_l=att[1], F_int=att[2])

        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm1d(out_channel))
            last_channel = out_channel

    def forward(self, xyz1, xyz2, points1, points2):
        """
        Input:
            xyz1: input points position data, [B, C, N]
            xyz2: sampled input points position data, [B, C, S], skipped xyz
            points1: input points data, [B, D, N] (may be None)
            points2: input points data, [B, D, S], skipped features
        Return:
            new_points: upsampled points data, [B, D', N]
        """
        xyz1 = xyz1.permute(0, 2, 1)
        xyz2 = xyz2.permute(0, 2, 1)

        points2 = points2.permute(0, 2, 1)
        B, N, _ = xyz1.shape
        _, S, _ = xyz2.shape

        if S == 1:
            interpolated_points = points2.repeat(1, N, 1)
        else:
            # Inverse-distance weighting over the (up to) three nearest sparse
            # points; clamping to S keeps S == 2 from breaking the reshape.
            num_nn = min(3, S)
            dists = square_distance(xyz1, xyz2)
            dists, idx = dists.sort(dim=-1)
            dists, idx = dists[:, :, :num_nn], idx[:, :, :num_nn]  # [B, N, num_nn]

            dist_recip = 1.0 / (dists + 1e-8)
            norm = torch.sum(dist_recip, dim=2, keepdim=True)
            weight = dist_recip / norm
            interpolated_points = torch.sum(
                index_points(points2, idx) * weight.view(B, N, num_nn, 1), dim=2)

        # skip attention (only possible when skip features are provided)
        if self.att is not None and points1 is not None:
            psix, _ = self.att(interpolated_points.permute(0, 2, 1), points1)
            points1 = points1 * psix

        if points1 is not None:
            points1 = points1.permute(0, 2, 1)
            new_points = torch.cat([points1, interpolated_points], dim=-1)
        else:
            new_points = interpolated_points

        new_points = new_points.permute(0, 2, 1)

        for conv, bn in zip(self.mlp_convs, self.mlp_bns):
            new_points = F.leaky_relu(bn(conv(new_points)), 0.2)

        return new_points
357
+
358
+
359
class CIC(nn.Module):
    '''
    Residual bottleneck block of CurveNet.

    Pipeline: optional masked max-pool down to `npoint` points -> 1x1
    bottleneck (in_channels // bottleneck_ratio) -> optional curve grouping and
    curve aggregation (when `curve_config` = (curve_num, curve_length) is given)
    -> LPFA over the k nearest neighbours -> 1x1 expansion -> residual add
    (with a projection shortcut when channel counts differ) -> LeakyReLU(0.2).
    '''
    def __init__(self, npoint, radius, k, in_channels, output_channels, bottleneck_ratio=2, mlp_num=2, curve_config=None):
        super(CIC, self).__init__()
        self.in_channels = in_channels
        self.output_channels = output_channels
        self.bottleneck_ratio = bottleneck_ratio
        self.radius = radius
        self.k = k
        self.npoint = npoint

        planes = in_channels // bottleneck_ratio

        self.use_curve = curve_config is not None
        if self.use_curve:
            curve_num, curve_length = curve_config[0], curve_config[1]
            self.curveaggregation = CurveAggregation(planes)
            self.curvegrouping = CurveGrouping(planes, k, curve_num, curve_length)

        self.conv1 = nn.Sequential(
            nn.Conv1d(in_channels, planes, kernel_size=1, bias=False),
            nn.BatchNorm1d(planes),
            nn.LeakyReLU(negative_slope=0.2, inplace=True))

        self.conv2 = nn.Sequential(
            nn.Conv1d(planes, output_channels, kernel_size=1, bias=False),
            nn.BatchNorm1d(output_channels))

        if in_channels != output_channels:
            # Projection shortcut so the residual matches the output width.
            self.shortcut = nn.Sequential(
                nn.Conv1d(in_channels, output_channels, kernel_size=1, bias=False),
                nn.BatchNorm1d(output_channels))

        self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        self.maxpool = MaskedMaxPool(npoint, radius, k)

        self.lpfa = LPFA(planes, planes, k, mlp_num=mlp_num, initial=False)

    def forward(self, xyz, x):
        # Down-sample only when the cloud does not already have npoint points.
        if xyz.size(-1) != self.npoint:
            pooled_xyz, x = self.maxpool(xyz.transpose(1, 2).contiguous(), x)
            xyz = pooled_xyz.transpose(1, 2)

        residual = x
        x = self.conv1(x)  # bs, c', n

        idx = knn(xyz, self.k, add_one_to_k=True)

        flatten_curve_idxs = None
        if self.use_curve:
            # Drop the first neighbour column to avoid self-loops in the walk.
            curves, flatten_curve_idxs = self.curvegrouping(x, xyz, idx[:, :, 1:])
            x = self.curveaggregation(x, curves)

        x = self.lpfa(x, xyz, idx=idx[:, :, :self.k])  # bs, c', n
        x = self.conv2(x)  # bs, c, n

        if self.in_channels != self.output_channels:
            residual = self.shortcut(residual)

        return xyz, self.relu(x + residual), flatten_curve_idxs
433
+
434
+
435
class CurveAggregation(nn.Module):
    '''
    Fuse curve features back into per-point features.

    Curves are summarised two ways using one attention map: "inter"
    (pooled along each curve's length -> one vector per curve) and "intra"
    (pooled across curves -> one vector per step). Each point attends over
    both summaries, and the result is added residually to the input before
    a LeakyReLU(0.2).
    '''
    def __init__(self, in_channel):
        super(CurveAggregation, self).__init__()
        self.in_channel = in_channel
        mid_feature = in_channel // 2

        def pointwise(cin, cout):
            # 1x1 Conv1d without bias.
            return nn.Conv1d(cin, cout, kernel_size=1, bias=False)

        self.conva = pointwise(in_channel, mid_feature)
        self.convb = pointwise(in_channel, mid_feature)
        self.convc = pointwise(in_channel, mid_feature)
        self.convn = pointwise(mid_feature, mid_feature)
        self.convl = pointwise(mid_feature, mid_feature)
        self.convd = nn.Sequential(
            pointwise(mid_feature * 2, in_channel),
            nn.BatchNorm1d(in_channel))
        self.line_conv_att = nn.Conv2d(in_channel, 1, kernel_size=1, bias=False)

    def forward(self, x, curves):
        att_logits = self.line_conv_att(curves)  # bs, 1, c_n, c_l

        inter = torch.sum(curves * F.softmax(att_logits, dim=-1), dim=-1)  # bs, c, c_n
        intra = torch.sum(curves * F.softmax(att_logits, dim=-2), dim=-2)  # bs, c, c_l

        inter = self.conva(inter)  # bs, mid, c_n
        intra = self.convb(intra)  # bs, mid, c_l

        query = self.convc(x).transpose(1, 2).contiguous()  # bs, n, mid
        inter_weights = F.softmax(torch.bmm(query, inter), dim=-1)  # bs, n, c_n
        intra_weights = F.softmax(torch.bmm(query, intra), dim=-1)  # bs, n, c_l

        inter = self.convn(inter).transpose(1, 2).contiguous()  # bs, c_n, mid
        intra = self.convl(intra).transpose(1, 2).contiguous()  # bs, c_l, mid

        point_inter = torch.bmm(inter_weights, inter)  # bs, n, mid
        point_intra = torch.bmm(intra_weights, intra)  # bs, n, mid

        curve_features = torch.cat((point_inter, point_intra), dim=-1).transpose(1, 2).contiguous()
        x = x + self.convd(curve_features)

        return F.leaky_relu(x, negative_slope=0.2)
495
+
496
+
497
class CurveGrouping(nn.Module):
    '''
    Select curve starting points with a sigmoid self-attention score and grow
    curves from them with `Walk`.
    '''
    def __init__(self, in_channel, k, curve_num, curve_length):
        super(CurveGrouping, self).__init__()
        self.curve_num = curve_num
        self.curve_length = curve_length
        self.in_channel = in_channel
        self.k = k

        self.att = nn.Conv1d(in_channel, 1, kernel_size=1, bias=False)

        self.walk = Walk(in_channel, k, curve_num, curve_length)

    def forward(self, x, xyz, idx):
        # Per-point gate in (0, 1); features are re-weighted by it.
        gate = torch.sigmoid(self.att(x))  # bs, 1, n
        x = x * gate

        # The curve_num highest-scoring points become the seeds (order not guaranteed).
        _, seeds = torch.topk(gate, self.curve_num, dim=2, sorted=False)  # bs, 1, curve_num
        seeds = seeds.squeeze(1).unsqueeze(2)  # bs, curve_num, 1

        # Returns (curves [bs, c, c_n, c_l], flattened curve indices).
        return self.walk(xyz, x, idx, seeds)
523
+
524
+
525
class MaskedMaxPool(nn.Module):
    '''
    Down-sample a point cloud to `npoint` centres and max-pool the features of
    up to `k` neighbours inside a ball of `radius` around each centre.
    '''
    def __init__(self, npoint, radius, k):
        super(MaskedMaxPool, self).__init__()
        self.npoint = npoint
        self.radius = radius
        self.k = k

    def forward(self, xyz, features):
        # xyz: [B, N, 3]; features: [B, C, N] (transposed to [B, N, C] for grouping).
        centres, grouped = sample_and_group(
            self.npoint, self.radius, self.k, xyz, features.transpose(1, 2))

        grouped = grouped.permute(0, 3, 1, 2).contiguous()  # bs, c, npoint, k
        window = [1, grouped.shape[3]]
        pooled = F.max_pool2d(grouped, kernel_size=window)  # bs, c, npoint, 1
        return centres, pooled.squeeze(-1)  # bs, c, npoint
utils/model_common_utils.py ADDED
@@ -0,0 +1,156 @@
1
+ import torch
2
+
3
def knn(x, k, add_one_to_k=False):
    """Indices of the k nearest neighbours of every point.

    x: point data laid out as (batch_size, num_dims, num_points).
    k: number of neighbours; when `add_one_to_k` is True, k + 1 are returned
       (the point itself is normally among them, at distance 0).
    Returns a (batch_size, num_points, k) index tensor, nearest first.
    """
    num_neighbours = k + 1 if add_one_to_k else k

    # -||a - b||^2 expanded as -(||a||^2 - 2 a.b + ||b||^2).
    inner = -2 * torch.matmul(x.transpose(2, 1).contiguous(), x)
    sq_norm = torch.sum(x ** 2, dim=1, keepdim=True)
    neg_sq_dist = -sq_norm - inner - sq_norm.transpose(2, 1).contiguous()

    _, idx = neg_sq_dist.topk(k=num_neighbours, dim=-1)
    return idx
10
+
11
def pc_normalize(pc):
    """Centre a point cloud at its centroid and scale it into the unit sphere.

    pc: (N, C) array-like of points.
    Returns the centred points divided by the largest point norm. When all
    points coincide (largest norm 0), the centred points are returned
    unscaled instead of dividing by zero.
    """
    # This module only imports torch at top level; numpy is needed here.
    import numpy as np

    pc = pc - np.mean(pc, axis=0)
    radius = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
    if radius > 0:
        pc = pc / radius
    return pc
18
+
19
def square_distance(src, dst):
    """
    Pairwise squared Euclidean distance between two batched point sets.

    Uses the expansion
        ||s - d||^2 = ||s||^2 + ||d||^2 - 2 * s . d
    so the whole computation is one batched matmul plus two broadcast adds.

    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-point square distance, [B, N, M]
    """
    batch, n_src, _ = src.shape
    _, n_dst, _ = dst.shape

    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
    dist = dist + torch.sum(src ** 2, -1).view(batch, n_src, 1)
    dist = dist + torch.sum(dst ** 2, -1).view(batch, 1, n_dst)
    return dist
39
+
40
def index_points(points, idx):
    """
    Gather points per batch element.

    Input:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S] (or [B, S, K, ...])
    Return:
        new_points: indexed points data, [B, S, C] (or [B, S, K, ..., C])
    """
    trailing = len(idx.shape) - 1
    # Batch index of shape [B, 1, ..., 1], broadcast to idx's full shape.
    batch_indices = torch.arange(points.shape[0], dtype=torch.long).to(points.device)
    batch_indices = batch_indices.view([idx.shape[0]] + [1] * trailing)
    batch_indices = batch_indices.repeat([1] + list(idx.shape[1:]))
    return points[batch_indices, idx, :]
57
+
58
def farthest_point_sample(xyz, npoint, start_with_first_point=False):
    """
    Iterative farthest point sampling.

    Input:
        xyz: pointcloud data, [B, N, C] (any coordinate dimension C)
        npoint: number of samples
        start_with_first_point: start every cloud at index 0 instead of a
            random index (deterministic, and no RNG state is consumed)
    Return:
        centroids: sampled pointcloud index, [B, npoint]
    """
    device = xyz.device
    B, N, C = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long, device=device)
    # Distance from every point to its nearest already-selected centroid.
    distance = torch.full((B, N), 1e10, device=device)
    if start_with_first_point:
        farthest = torch.zeros(B, dtype=torch.long, device=device)
    else:
        # Sampled on the CPU generator, then moved, as before.
        farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
    batch_indices = torch.arange(B, dtype=torch.long, device=device)
    for i in range(npoint):
        centroids[:, i] = farthest
        # Use the actual coordinate dimension C (was hard-coded to 3).
        centroid = xyz[batch_indices, farthest, :].view(B, 1, C)
        dist = torch.sum((xyz - centroid) ** 2, -1)
        mask = dist < distance
        distance[mask] = dist[mask]
        farthest = torch.max(distance, -1)[1]
    return centroids
83
+
84
def knn_point(k, pos1, pos2):
    '''
    Brute-force k-nearest-neighbour search of query points against a point set.

    Input:
        k: number of neighbours in the k-nn search
        pos1: (batch_size, ndataset, c) float tensor, input points
        pos2: (batch_size, npoint, c) float tensor, query points
    Output:
        val: (batch_size, npoint, k) float tensor, L2 distances (nearest first)
        idx: (batch_size, npoint, k) int64 tensor, indices into pos1
    '''
    # Broadcast [B, 1, N, C] against [B, M, 1, C] instead of materialising
    # two repeated B x M x N x C copies; the values are identical.
    diff = pos1.unsqueeze(1) - pos2.unsqueeze(2)  # B, M, N, C
    neg_sq_dist = torch.sum(-diff ** 2, -1)
    val, idx = neg_sq_dist.topk(k=k, dim=-1)
    return torch.sqrt(-val), idx
101
+
102
def query_ball_point(radius, nsample, xyz, new_xyz, get_cnt=False):
    """
    Ball query: for every query point, the indices of up to `nsample` points
    of `xyz` whose squared distance is <= radius ** 2, in ascending index
    order. Empty slots are filled with the first index found.

    Input:
        radius: local region radius
        nsample: max sample number in local region
        xyz: all points, [B, N, C]
        new_xyz: query points, [B, S, C]
        get_cnt: also return the number of in-ball points per query
    Return:
        group_idx: grouped points index, [B, S, nsample]
        cnt (only when get_cnt): in-ball point count, [B, S]
    """
    device = xyz.device
    B, N, C = xyz.shape
    _, S, _ = new_xyz.shape

    candidates = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
    sqrdists = square_distance(new_xyz, xyz)
    # Out-of-ball points get the sentinel N so they sort after all real indices.
    candidates[sqrdists > radius ** 2] = N

    cnt = (candidates != N).sum(dim=-1) if get_cnt else None

    group_idx = candidates.sort(dim=-1)[0][:, :, :nsample]
    first_hit = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
    padding = group_idx == N
    group_idx[padding] = first_hit[padding]

    if not get_cnt:
        return group_idx
    return group_idx, cnt
131
+
132
def get_graph_feature(x, k=20, device=None):
    """Build DGCNN-style edge features [neighbour, centre] for every point.

    x: point features (batch_size, num_dims, num_points); extra trailing
       dimensions are dropped via `view` (e.g. a trailing singleton axis).
    k: number of nearest neighbours (computed in feature space by `knn`).
    device: device for the batch-offset tensor; defaults to `x.device`
       (previously it defaulted to CUDA whenever available, which broke CPU inputs).
    Returns a (batch_size, 2 * num_dims, num_points, k) tensor.
    """
    x = x.view(*x.size()[:3])
    idx = knn(x, k=k)  # (batch_size, num_points, k)
    batch_size, num_points, _ = idx.size()

    if device is None:
        device = x.device

    # Shift per-cloud indices into the flattened (batch_size * num_points) row space.
    idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1) * num_points
    idx = (idx + idx_base).view(-1)

    _, num_dims, _ = x.size()

    x = x.transpose(2, 1).contiguous()  # (batch_size, num_points, num_dims)
    feature = x.view(batch_size * num_points, -1)[idx, :]
    feature = feature.view(batch_size, num_points, k, num_dims)
    x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)

    feature = torch.cat((feature, x), dim=3).permute(0, 3, 1, 2)

    return feature
learning3d/losses/cuda/chamfer_distance/__init__.py DELETED
@@ -1 +0,0 @@
1
- from .chamfer_distance import ChamferDistance