learning3d 0.0.7__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87)
  1. {learning3d/data_utils → data_utils}/dataloaders.py +16 -14
  2. examples/test_curvenet.py +118 -0
  3. {learning3d/examples → examples}/test_dcp.py +1 -1
  4. {learning3d/examples → examples}/test_deepgmr.py +1 -1
  5. {learning3d/examples → examples}/test_prnet.py +1 -1
  6. {learning3d-0.0.7.dist-info → learning3d-0.2.0.dist-info}/METADATA +56 -11
  7. learning3d-0.2.0.dist-info/RECORD +70 -0
  8. {learning3d-0.0.7.dist-info → learning3d-0.2.0.dist-info}/WHEEL +1 -1
  9. learning3d-0.2.0.dist-info/top_level.txt +6 -0
  10. {learning3d/models → models}/__init__.py +7 -1
  11. models/curvenet.py +130 -0
  12. {learning3d/models → models}/dgcnn.py +1 -35
  13. {learning3d/models → models}/prnet.py +5 -39
  14. utils/__init__.py +23 -0
  15. utils/curvenet_util.py +540 -0
  16. utils/model_common_utils.py +156 -0
  17. learning3d/losses/cuda/chamfer_distance/__init__.py +0 -1
  18. learning3d/losses/cuda/chamfer_distance/chamfer_distance.cpp +0 -185
  19. learning3d/losses/cuda/chamfer_distance/chamfer_distance.cu +0 -209
  20. learning3d/losses/cuda/chamfer_distance/chamfer_distance.py +0 -66
  21. learning3d/losses/cuda/emd_torch/pkg/emd_loss_layer.py +0 -41
  22. learning3d/losses/cuda/emd_torch/pkg/include/cuda/emd.cuh +0 -347
  23. learning3d/losses/cuda/emd_torch/pkg/include/cuda_helper.h +0 -18
  24. learning3d/losses/cuda/emd_torch/pkg/include/emd.h +0 -54
  25. learning3d/losses/cuda/emd_torch/pkg/layer/__init__.py +0 -1
  26. learning3d/losses/cuda/emd_torch/pkg/layer/emd_loss_layer.py +0 -40
  27. learning3d/losses/cuda/emd_torch/pkg/src/cuda/emd.cu +0 -70
  28. learning3d/losses/cuda/emd_torch/pkg/src/emd.cpp +0 -1
  29. learning3d/losses/cuda/emd_torch/setup.py +0 -29
  30. learning3d/ops/__init__.py +0 -0
  31. learning3d/utils/__init__.py +0 -4
  32. learning3d-0.0.7.dist-info/RECORD +0 -80
  33. learning3d-0.0.7.dist-info/top_level.txt +0 -1
  34. {learning3d/data_utils → data_utils}/__init__.py +0 -0
  35. {learning3d/data_utils → data_utils}/user_data.py +0 -0
  36. {learning3d/examples → examples}/test_masknet.py +0 -0
  37. {learning3d/examples → examples}/test_masknet2.py +0 -0
  38. {learning3d/examples → examples}/test_pcn.py +0 -0
  39. {learning3d/examples → examples}/test_pcrnet.py +0 -0
  40. {learning3d/examples → examples}/test_pnlk.py +0 -0
  41. {learning3d/examples → examples}/test_pointconv.py +0 -0
  42. {learning3d/examples → examples}/test_pointnet.py +0 -0
  43. {learning3d/examples → examples}/test_rpmnet.py +0 -0
  44. {learning3d/examples → examples}/train_PointNetLK.py +0 -0
  45. {learning3d/examples → examples}/train_dcp.py +0 -0
  46. {learning3d/examples → examples}/train_deepgmr.py +0 -0
  47. {learning3d/examples → examples}/train_masknet.py +0 -0
  48. {learning3d/examples → examples}/train_pcn.py +0 -0
  49. {learning3d/examples → examples}/train_pcrnet.py +0 -0
  50. {learning3d/examples → examples}/train_pointconv.py +0 -0
  51. {learning3d/examples → examples}/train_pointnet.py +0 -0
  52. {learning3d/examples → examples}/train_prnet.py +0 -0
  53. {learning3d/examples → examples}/train_rpmnet.py +0 -0
  54. {learning3d-0.0.7.dist-info → learning3d-0.2.0.dist-info}/LICENSE +0 -0
  55. {learning3d/losses → losses}/__init__.py +0 -0
  56. {learning3d/losses → losses}/chamfer_distance.py +0 -0
  57. {learning3d/losses → losses}/classification.py +0 -0
  58. {learning3d/losses → losses}/correspondence_loss.py +0 -0
  59. {learning3d/losses → losses}/emd.py +0 -0
  60. {learning3d/losses → losses}/frobenius_norm.py +0 -0
  61. {learning3d/losses → losses}/rmse_features.py +0 -0
  62. {learning3d/models → models}/classifier.py +0 -0
  63. {learning3d/models → models}/dcp.py +0 -0
  64. {learning3d/models → models}/deepgmr.py +0 -0
  65. {learning3d/models → models}/masknet.py +0 -0
  66. {learning3d/models → models}/masknet2.py +0 -0
  67. {learning3d/models → models}/pcn.py +0 -0
  68. {learning3d/models → models}/pcrnet.py +0 -0
  69. {learning3d/models → models}/pointconv.py +0 -0
  70. {learning3d/models → models}/pointnet.py +0 -0
  71. {learning3d/models → models}/pointnetlk.py +0 -0
  72. {learning3d/models → models}/pooling.py +0 -0
  73. {learning3d/models → models}/ppfnet.py +0 -0
  74. {learning3d/models → models}/rpmnet.py +0 -0
  75. {learning3d/models → models}/segmentation.py +0 -0
  76. {learning3d → ops}/__init__.py +0 -0
  77. {learning3d/ops → ops}/data_utils.py +0 -0
  78. {learning3d/ops → ops}/invmat.py +0 -0
  79. {learning3d/ops → ops}/quaternion.py +0 -0
  80. {learning3d/ops → ops}/se3.py +0 -0
  81. {learning3d/ops → ops}/sinc.py +0 -0
  82. {learning3d/ops → ops}/so3.py +0 -0
  83. {learning3d/ops → ops}/transform_functions.py +0 -0
  84. {learning3d/utils → utils}/pointconv_util.py +0 -0
  85. {learning3d/utils → utils}/ppfnet_util.py +0 -0
  86. {learning3d/utils → utils}/svd.py +0 -0
  87. {learning3d/utils → utils}/transformer.py +0 -0
@@ -16,7 +16,7 @@ import torch.nn as nn
16
16
  import torch.nn.functional as F
17
17
 
18
18
  from .. ops import transform_functions as transform
19
- from .. utils import Transformer, Identity
19
+ from .. utils import Transformer, Identity, knn, get_graph_feature
20
20
 
21
21
  from sklearn.metrics import r2_score
22
22
 
@@ -30,40 +30,6 @@ def pairwise_distance(src, tgt):
30
30
  distances = xx.transpose(2, 1).contiguous() + inner + yy
31
31
  return torch.sqrt(distances)
32
32
 
33
-
34
- def knn(x, k):
35
- inner = -2 * torch.matmul(x.transpose(2, 1).contiguous(), x)
36
- xx = torch.sum(x ** 2, dim=1, keepdim=True)
37
- distance = -xx - inner - xx.transpose(2, 1).contiguous()
38
-
39
- idx = distance.topk(k=k, dim=-1)[1] # (batch_size, num_points, k)
40
- return idx
41
-
42
-
43
- def get_graph_feature(x, k=20):
44
- # x = x.squeeze()
45
- x = x.view(*x.size()[:3])
46
- idx = knn(x, k=k) # (batch_size, num_points, k)
47
- batch_size, num_points, _ = idx.size()
48
-
49
- idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1) * num_points
50
-
51
- idx = idx + idx_base
52
-
53
- idx = idx.view(-1)
54
-
55
- _, num_dims, _ = x.size()
56
-
57
- x = x.transpose(2, 1).contiguous() # (batch_size, num_points, num_dims) -> (batch_size*num_points, num_dims) # batch_size * num_points * k + range(0, batch_size*num_points)
58
- feature = x.view(batch_size * num_points, -1)[idx, :]
59
- feature = feature.view(batch_size, num_points, k, num_dims)
60
- x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)
61
-
62
- feature = torch.cat((feature, x), dim=3).permute(0, 3, 1, 2)
63
-
64
- return feature
65
-
66
-
67
33
  def cycle_consistency(rotation_ab, translation_ab, rotation_ba, translation_ba):
68
34
  batch_size = rotation_ab.size(0)
69
35
  identity = torch.eye(3, device=rotation_ab.device).unsqueeze(0).repeat(batch_size, 1, 1)
@@ -109,19 +75,19 @@ class DGCNN(nn.Module):
109
75
 
110
76
  def forward(self, x):
111
77
  batch_size, num_dims, num_points = x.size()
112
- x = get_graph_feature(x)
78
+ x = get_graph_feature(x, device=device)
113
79
  x = F.leaky_relu(self.bn1(self.conv1(x)), negative_slope=0.2)
114
80
  x1 = x.max(dim=-1, keepdim=True)[0]
115
81
 
116
- x = get_graph_feature(x1)
82
+ x = get_graph_feature(x1, device=device)
117
83
  x = F.leaky_relu(self.bn2(self.conv2(x)), negative_slope=0.2)
118
84
  x2 = x.max(dim=-1, keepdim=True)[0]
119
85
 
120
- x = get_graph_feature(x2)
86
+ x = get_graph_feature(x2, device=device)
121
87
  x = F.leaky_relu(self.bn3(self.conv3(x)), negative_slope=0.2)
122
88
  x3 = x.max(dim=-1, keepdim=True)[0]
123
89
 
124
- x = get_graph_feature(x3)
90
+ x = get_graph_feature(x3, device=device)
125
91
  x = F.leaky_relu(self.bn4(self.conv4(x)), negative_slope=0.2)
126
92
  x4 = x.max(dim=-1, keepdim=True)[0]
127
93
 
utils/__init__.py ADDED
@@ -0,0 +1,23 @@
1
+ from .svd import SVDHead
2
+ from .transformer import Transformer, Identity
3
+ from .ppfnet_util import angle_difference, square_distance, index_points, farthest_point_sample, query_ball_point, sample_and_group, sample_and_group_multi
4
+ from .pointconv_util import PointConvDensitySetAbstraction
5
+ from .model_common_utils import (
6
+ knn,
7
+ pc_normalize,
8
+ square_distance,
9
+ index_points,
10
+ farthest_point_sample,
11
+ knn_point,
12
+ query_ball_point,
13
+ get_graph_feature
14
+ )
15
+ from .curvenet_util import (
16
+ LPFA,
17
+ CIC,
18
+ )
19
+
20
# pointnet2_utils presumably depends on a separately built extension
# (TODO confirm); importing it is best-effort so the rest of ``utils`` stays
# usable without it.
try:
    from .lib import pointnet2_utils
except Exception:  # previously a bare ``except:``, which also swallowed KeyboardInterrupt/SystemExit
    print("Error raised in pointnet2 module in utils!\nEither don't use pointnet2_utils or retry it's setup.")
utils/curvenet_util.py ADDED
@@ -0,0 +1,540 @@
1
+ """
2
+ @Author: Yue Wang
3
+ @Contact: yuewangx@mit.edu
4
+ @File: pointnet_util.py
5
+ @Time: 2018/10/13 10:39 PM
6
+
7
+ Modified by
8
+ @Author: Tiange Xiang
9
+ @Contact: txia7609@uni.sydney.edu.au
10
+ @Time: 2021/01/21 3:10 PM
11
+ """
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ from time import time
17
+ import numpy as np
18
+ from .model_common_utils import (
19
+ knn,
20
+ square_distance,
21
+ index_points,
22
+ farthest_point_sample,
23
+ query_ball_point,
24
+ )
25
+
26
def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
    """Pick ``npoint`` centroids by farthest point sampling and group their neighbours.

    Args:
        npoint: number of centroids to sample from ``xyz``.
        radius: ball-query radius around each centroid.
        nsample: number of neighbours gathered per centroid.
        xyz: point coordinates, presumably [B, N, 3] (TODO confirm).
        points: per-point features, presumably [B, N, D] (TODO confirm).
        returnfps: when True, also return the ball-query neighbour indices
            (note: these are the ``query_ball_point`` indices, not the FPS ones).

    Returns:
        ``(new_xyz, new_points)``, or ``(new_xyz, new_points, idx)`` when
        ``returnfps`` is set. ``new_xyz`` holds the sampled centroid positions
        and ``new_points`` the grouped neighbour features. Their exact shapes
        depend on ``index_points``; presumably [B, npoint, 3] and
        [B, npoint, nsample, D].
    """
    centroid_idx = farthest_point_sample(xyz, npoint, start_with_first_point=True)
    new_xyz = index_points(xyz, centroid_idx)
    # Release cached CUDA memory after each large intermediate (as upstream does).
    torch.cuda.empty_cache()

    group_idx = query_ball_point(radius, nsample, xyz, new_xyz, get_cnt=False)
    torch.cuda.empty_cache()

    new_points = index_points(points, group_idx)
    torch.cuda.empty_cache()

    if not returnfps:
        return new_xyz, new_points
    return new_xyz, new_points, group_idx
51
+
52
def batched_index_select(input, dim, index):
    """Gather entries of ``input`` along ``dim`` with a separate index list per batch item.

    ``index`` is reshaped so that its entries run along ``dim``, with the batch
    axis first and every other axis of size 1. It is then broadcast over the
    remaining axes of ``input`` and handed to ``torch.gather``.
    """
    view_shape = [input.shape[0]]
    for axis in range(1, input.dim()):
        view_shape.append(-1 if axis == dim else 1)

    expand_shape = [-1, *input.shape[1:]]
    expand_shape[dim] = -1

    return torch.gather(input, dim, index.view(view_shape).expand(expand_shape))
60
+
61
def gumbel_softmax(logits, dim, temperature=1):
    """Straight-through hard softmax (no random Gumbel noise is sampled).

    In the forward pass the result is a one-hot tensor marking the argmax of
    ``softmax(logits / temperature)`` along ``dim``. In the backward pass,
    gradients flow through the soft probabilities (straight-through estimator).

    Args:
        logits: tensor of unnormalised scores.
        dim: axis along which both the softmax and the one-hot selection are taken.
        temperature: softmax temperature.

    Returns:
        Tensor shaped like ``logits``, one-hot along ``dim`` in the forward pass.
    """
    y = F.softmax(logits / temperature, dim=dim)

    # Select the winner along the same axis as the softmax. The previous
    # version always used the last axis, which was only correct for dim=-1.
    index = y.argmax(dim=dim, keepdim=True)
    y_hard = torch.zeros_like(y).scatter_(dim, index, 1.0)

    # Forward value equals y_hard; the gradient is that of y.
    return (y_hard - y).detach() + y
77
+
78
class Walk(nn.Module):
    '''
    Walk in the cloud.

    Grows ``curve_num`` curves of ``curve_length`` steps over a neighbour graph.
    At every step a learned scorer (``agent_mlp``) rates the ``k`` neighbours of
    each curve's current point. The next point is picked with a
    straight-through hard softmax (``gumbel_softmax``).
    '''
    def __init__(self, in_channel, k, curve_num, curve_length):
        super(Walk, self).__init__()
        self.curve_num = curve_num
        self.curve_length = curve_length
        self.k = k

        # Scores each candidate from [candidate feature, curve descriptor].
        self.agent_mlp = nn.Sequential(
            nn.Conv2d(in_channel * 2,
                      1,
                      kernel_size=1,
                      bias=False), nn.BatchNorm2d(1))
        # Produces two blending weights for the current / previous descriptors.
        self.momentum_mlp = nn.Sequential(
            nn.Conv1d(in_channel * 2,
                      2,
                      kernel_size=1,
                      bias=False), nn.BatchNorm1d(2))

    def crossover_suppression(self, cur, neighbor, bn, n, k):
        """Weight candidate neighbours by how well they align with the current step.

        Args:
            cur: current step direction, (bs*curve_num, c) as passed by ``forward``.
            neighbor: offsets from the current point to each candidate,
                (bs*curve_num, c, k) as passed by ``forward``.
            bn, n, k: unused; kept for call compatibility.

        Returns:
            Detached weights of shape (bs*curve_num, k), equal to
            ``clamp(1 + cos(cur, neighbor), 0, 1)``. Candidates with a negative
            cosine are damped; all others keep weight 1.
        """
        neighbor = neighbor.detach()
        cur = cur.unsqueeze(-1).detach()
        dot = torch.bmm(cur.transpose(1, 2), neighbor)  # bs*n, 1, k
        norm1 = torch.norm(cur, dim=1, keepdim=True)
        norm2 = torch.norm(neighbor, dim=1, keepdim=True)
        divider = torch.clamp(norm1 * norm2, min=1e-8)
        # Squeeze only the singleton row axis so the result stays (bs*n, k).
        ans = torch.div(dot, divider).squeeze(1)

        # normalize to [0, 1]
        ans = 1. + ans
        ans = torch.clamp(ans, 0., 1.0)

        return ans.detach()

    def forward(self, xyz, x, adj, cur):
        """Walk the curves.

        Args:
            xyz: accepted for interface compatibility; not used by this method.
            x: point features, (bs, c, n).
            adj: neighbour indices per point, presumably (bs, n, k) (TODO confirm).
            cur: starting point indices per curve, (bs, curve_num, 1) as built
                by ``CurveGrouping``.

        Returns:
            A pair:
            - curve features, (bs, c, curve_num, curve_length);
            - batch-flattened point indices visited by each curve,
              (bs*curve_num, curve_length + 1).
        """
        bn, c, tot_points = x.size()
        device = x.device

        # point features, flattened over the batch so indices can be offset
        x = x.transpose(1, 2).contiguous()  # bs, n, c
        flatten_x = x.view(bn * tot_points, -1)
        batch_offset = torch.arange(0, bn, device=device).detach() * tot_points

        # neighbour indices of every point, in batch-flattened numbering
        tmp_adj = (adj + batch_offset.view(-1, 1, 1)).view(adj.size(0) * adj.size(1), -1)  # bs*n, k

        # batch-flattened indices of the starting points
        flatten_cur = (cur + batch_offset.view(-1, 1, 1)).view(-1)

        curves = []
        flatten_curve_idxs = [flatten_cur.unsqueeze(1)]

        # one step at a time
        for step in range(self.curve_length):

            if step == 0:
                # get starting point features using flattened indices
                starting_points = flatten_x[flatten_cur, :].contiguous()
                pre_feature = starting_points.view(bn, self.curve_num, -1, 1).transpose(1, 2)  # bs, c, n, 1
            else:
                # dynamic momentum
                cat_feature = torch.cat((cur_feature.squeeze(-1), pre_feature.squeeze(-1)), dim=1)
                # NOTE(review): this ``.view`` of a (bs, 2, n) tensor into (bs, 1, n, 2)
                # is a reshape, not a permute. It pairs weights of different curves.
                # Kept as-is to match upstream CurveNet and pretrained weights; confirm intent.
                att_feature = F.softmax(self.momentum_mlp(cat_feature), dim=1).view(bn, 1, self.curve_num, 2)  # bs, 1, n, 2
                cat_feature = torch.cat((cur_feature, pre_feature), dim=-1)  # bs, c, n, 2

                # update curve descriptor
                pre_feature = torch.sum(cat_feature * att_feature, dim=-1, keepdim=True)  # bs, c, n, 1
                pre_feature_cos = pre_feature.transpose(1, 2).contiguous().view(bn * self.curve_num, -1)

            pick_idx = tmp_adj[flatten_cur]  # bs*n, k

            # get the neighbors of current points
            pick_values = flatten_x[pick_idx.view(-1), :]

            # reshape to fit crossover suppression below
            pick_values_cos = pick_values.view(bn * self.curve_num, self.k, c)
            pick_values = pick_values_cos.view(bn, self.curve_num, self.k, c)
            pick_values_cos = pick_values_cos.transpose(1, 2).contiguous()  # bs*n, c, k

            pick_values = pick_values.permute(0, 3, 1, 2)  # bs, c, n, k

            pre_feature_expand = pre_feature.expand_as(pick_values)

            # concat candidate features with curve descriptors
            pre_feature_expand = torch.cat((pick_values, pre_feature_expand), dim=1)

            # which node to pick next?
            pre_feature_expand = self.agent_mlp(pre_feature_expand)  # bs, 1, n, k

            if step != 0:
                # crossover suppression
                d = self.crossover_suppression(cur_feature_cos - pre_feature_cos,
                                               pick_values_cos - cur_feature_cos.unsqueeze(-1),
                                               bn, self.curve_num, self.k)
                d = d.view(bn, self.curve_num, self.k).unsqueeze(1)  # bs, 1, n, k
                pre_feature_expand = torch.mul(pre_feature_expand, d)

            pre_feature_expand = gumbel_softmax(pre_feature_expand, -1)  # bs, 1, n, k

            cur_feature = torch.sum(pick_values * pre_feature_expand, dim=-1, keepdim=True)  # bs, c, n, 1

            cur_feature_cos = cur_feature.transpose(1, 2).contiguous().view(bn * self.curve_num, c)

            cur = torch.argmax(pre_feature_expand, dim=-1).view(-1, 1)  # bs*n, 1

            # squeeze(-1) rather than squeeze(): with bs * curve_num == 1 a bare
            # squeeze() produced a 0-d tensor and the next step crashed.
            flatten_cur = batched_index_select(pick_idx, 1, cur).squeeze(-1)  # bs*n

            # collect curve progress
            curves.append(cur_feature)
            flatten_curve_idxs.append(flatten_cur.unsqueeze(1))

        return torch.cat(curves, dim=-1), torch.cat(flatten_curve_idxs, dim=1)
198
+
199
+
200
class Attention_block(nn.Module):
    '''
    Additive attention gate, as used in attention U-Net.

    ``forward(g, x)`` projects the gating signal ``g`` and the skip feature ``x``
    to ``F_int`` channels and adds them. It applies LeakyReLU(0.2), then a
    1-channel projection with a sigmoid. It returns the gate and its complement.
    '''
    def __init__(self, F_g, F_l, F_int):
        super(Attention_block, self).__init__()

        def _conv_bn(in_ch, out_ch):
            # 1x1 Conv1d (with bias) followed by BatchNorm1d
            return nn.Sequential(
                nn.Conv1d(in_ch, out_ch, kernel_size=1, stride=1, padding=0, bias=True),
                nn.BatchNorm1d(out_ch),
            )

        self.W_g = _conv_bn(F_g, F_int)
        self.W_x = _conv_bn(F_l, F_int)
        self.psi = nn.Sequential(
            nn.Conv1d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm1d(1),
            nn.Sigmoid(),
        )

    def forward(self, g, x):
        gate = F.leaky_relu(self.W_g(g) + self.W_x(x), negative_slope=0.2)
        psi = self.psi(gate)
        return psi, 1. - psi
229
+
230
+
231
class LPFA(nn.Module):
    """Aggregate neighbourhood features of every point.

    For each point, the ``k`` neighbours (given by ``idx`` or computed with
    ``knn``) yield a 9-channel geometric descriptor [centre, neighbour,
    neighbour - centre]. When ``initial`` is True that descriptor alone is fed
    to the MLP and max-pooled over neighbours. Otherwise it is projected to
    ``in_channel`` channels, added to the neighbour-minus-centre features of
    ``x``, passed through the MLP, and mean-pooled over neighbours.
    """
    def __init__(self, in_channel, out_channel, k, mlp_num=2, initial=False):
        super(LPFA, self).__init__()
        self.k = k
        self.initial = initial

        if not initial:
            self.xyz2feature = nn.Sequential(
                nn.Conv2d(9, in_channel, kernel_size=1, bias=False),
                nn.BatchNorm2d(in_channel))

        layers = []
        channels = in_channel
        for _ in range(mlp_num):
            layers.append(nn.Sequential(
                nn.Conv2d(channels, out_channel, 1, bias=False),
                nn.BatchNorm2d(out_channel),
                nn.LeakyReLU(0.2)))
            channels = out_channel
        self.mlp = nn.Sequential(*layers)

    def forward(self, x, xyz, idx=None):
        grouped = self.mlp(self.group_feature(x, xyz, idx))
        if self.initial:
            return grouped.max(dim=-1, keepdim=False)[0]
        return grouped.mean(dim=-1, keepdim=False)

    def group_feature(self, x, xyz, idx):
        """Build per-neighbour features of shape (bs, channels, n, k)."""
        batch_size, num_dims, num_points = x.size()
        k = self.k

        if idx is None:
            idx = knn(xyz, k=k, add_one_to_k=True)[:, :, :k]  # (batch_size, num_points, k)

        # shift each batch's indices so they address the flattened (bs*n) axis
        offsets = torch.arange(0, batch_size, device=x.device).view(-1, 1, 1) * num_points
        flat_idx = (idx + offsets).view(-1)

        xyz_t = xyz.transpose(2, 1).contiguous()  # bs, n, 3
        neighbours = xyz_t.view(batch_size * num_points, -1)[flat_idx, :]
        neighbours = neighbours.view(batch_size, num_points, k, -1)  # bs, n, k, 3
        centres = xyz_t.view(batch_size, num_points, 1, 3).expand(-1, -1, k, -1)  # bs, n, k, 3

        geometry = torch.cat((centres, neighbours, neighbours - centres), dim=3)
        geometry = geometry.permute(0, 3, 1, 2).contiguous()  # bs, 9, n, k

        if self.initial:
            return geometry

        x_t = x.transpose(2, 1).contiguous()  # bs, n, c
        feature = x_t.view(batch_size * num_points, -1)[flat_idx, :]
        feature = feature.view(batch_size, num_points, k, num_dims)  # bs, n, k, c
        feature = feature - x_t.view(batch_size, num_points, 1, num_dims)
        feature = feature.permute(0, 3, 1, 2).contiguous()  # bs, c, n, k

        geometry = self.xyz2feature(geometry)  # bs, c, n, k
        return F.leaky_relu(feature + geometry, 0.2)
293
+
294
+
295
class PointNetFeaturePropagation(nn.Module):
    """Feature propagation (upsampling) with optional skip attention.

    Features of a sparse point set (``xyz2``/``points2``) are interpolated onto
    a dense set (``xyz1``). Interpolation uses inverse-distance weighting over
    the (up to) 3 nearest sparse points. The result may be gated against the
    skip features ``points1`` by an ``Attention_block``. It is then concatenated
    with ``points1`` and passed through Conv1d + BatchNorm1d + LeakyReLU(0.2)
    layers.

    Args:
        in_channel: input channels of the first conv (skip channels + sparse
            feature channels when ``points1`` is given).
        mlp: output channels of each conv layer.
        att: optional ``(F_g, F_l, F_int)`` configuring the skip ``Attention_block``.
    """
    def __init__(self, in_channel, mlp, att=None):
        super(PointNetFeaturePropagation, self).__init__()
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = in_channel
        self.att = None
        if att is not None:
            self.att = Attention_block(F_g=att[0], F_l=att[1], F_int=att[2])

        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm1d(out_channel))
            last_channel = out_channel

    def forward(self, xyz1, xyz2, points1, points2):
        """
        Input:
            xyz1: input points position data, [B, C, N]
            xyz2: sampled input points position data, [B, C, S], skipped xyz
            points1: input points data, [B, D, N], or None
            points2: input points data, [B, D, S], skipped features
        Return:
            new_points: upsampled points data, [B, D', N]
        """
        xyz1 = xyz1.permute(0, 2, 1)
        xyz2 = xyz2.permute(0, 2, 1)

        points2 = points2.permute(0, 2, 1)
        B, N, _ = xyz1.shape
        _, S, _ = xyz2.shape

        if S == 1:
            # A single source point: copy its features to every target point.
            interpolated_points = points2.repeat(1, N, 1)
        else:
            # Inverse-distance weighting over the nearest (up to 3) source points.
            dists = square_distance(xyz1, xyz2)
            dists, idx = dists.sort(dim=-1)
            dists, idx = dists[:, :, :3], idx[:, :, :3]  # [B, N, min(3, S)]

            dist_recip = 1.0 / (dists + 1e-8)
            norm = torch.sum(dist_recip, dim=2, keepdim=True)
            weight = dist_recip / norm
            # unsqueeze(-1) instead of view(B, N, 3, 1) so S == 2 also works
            interpolated_points = torch.sum(index_points(points2, idx) * weight.unsqueeze(-1), dim=2)

        # Skip attention gates points1, so it only applies when skip features
        # exist (previously points1=None was passed into Attention_block and crashed).
        if self.att is not None and points1 is not None:
            psix, _ = self.att(interpolated_points.permute(0, 2, 1), points1)
            points1 = points1 * psix

        if points1 is not None:
            points1 = points1.permute(0, 2, 1)
            new_points = torch.cat([points1, interpolated_points], dim=-1)
        else:
            new_points = interpolated_points

        new_points = new_points.permute(0, 2, 1)

        for conv, bn in zip(self.mlp_convs, self.mlp_bns):
            new_points = F.leaky_relu(bn(conv(new_points)), 0.2)

        return new_points
357
+
358
+
359
class CIC(nn.Module):
    """Bottleneck residual block built around ``LPFA`` with optional curve features.

    The point set is first downsampled to ``npoint`` points with
    ``MaskedMaxPool`` when needed. Features are then reduced to
    ``in_channels // bottleneck_ratio`` channels. When ``curve_config`` =
    ``(curve_num, curve_length)`` is given, curve grouping and aggregation are
    applied. ``LPFA`` runs over the k-NN neighbourhood, the result is projected
    to ``output_channels``, and it is added to a (possibly projected) shortcut.
    ``forward`` returns ``(xyz, features, curve_indices_or_None)``.
    """
    def __init__(self, npoint, radius, k, in_channels, output_channels, bottleneck_ratio=2, mlp_num=2, curve_config=None):
        super(CIC, self).__init__()
        self.in_channels = in_channels
        self.output_channels = output_channels
        self.bottleneck_ratio = bottleneck_ratio
        self.radius = radius
        self.k = k
        self.npoint = npoint

        planes = in_channels // bottleneck_ratio

        self.use_curve = curve_config is not None
        if self.use_curve:
            curve_num, curve_length = curve_config[0], curve_config[1]
            self.curveaggregation = CurveAggregation(planes)
            self.curvegrouping = CurveGrouping(planes, k, curve_num, curve_length)

        # channel reduction: in_channels -> planes
        self.conv1 = nn.Sequential(
            nn.Conv1d(in_channels, planes, kernel_size=1, bias=False),
            nn.BatchNorm1d(planes),
            nn.LeakyReLU(negative_slope=0.2, inplace=True))

        # channel expansion: planes -> output_channels
        self.conv2 = nn.Sequential(
            nn.Conv1d(planes, output_channels, kernel_size=1, bias=False),
            nn.BatchNorm1d(output_channels))

        # projection shortcut, only when the residual widths differ
        if in_channels != output_channels:
            self.shortcut = nn.Sequential(
                nn.Conv1d(in_channels, output_channels, kernel_size=1, bias=False),
                nn.BatchNorm1d(output_channels))

        self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        self.maxpool = MaskedMaxPool(npoint, radius, k)

        self.lpfa = LPFA(planes, planes, k, mlp_num=mlp_num, initial=False)

    def forward(self, xyz, x):
        # downsample when the incoming point count differs from npoint
        if xyz.size(-1) != self.npoint:
            pooled_xyz, x = self.maxpool(xyz.transpose(1, 2).contiguous(), x)
            xyz = pooled_xyz.transpose(1, 2)

        residual = x
        x = self.conv1(x)  # bs, c', n

        # k + 1 neighbours; column 0 is presumably the point itself (TODO confirm in knn)
        idx = knn(xyz, self.k, add_one_to_k=True)

        flatten_curve_idxs = None
        if self.use_curve:
            curves, flatten_curve_idxs = self.curvegrouping(x, xyz, idx[:, :, 1:])  # drop self-loop column
            x = self.curveaggregation(x, curves)

        x = self.lpfa(x, xyz, idx=idx[:, :, :self.k])  # bs, c', n (mean over neighbours)
        x = self.conv2(x)  # bs, c, n

        if self.in_channels != self.output_channels:
            residual = self.shortcut(residual)

        return xyz, self.relu(x + residual), flatten_curve_idxs
433
+
434
+
435
class CurveAggregation(nn.Module):
    """Fuse curve features back into per-point features via two attention paths.

    ``curves`` (bs, c, curve_num, curve_length) is summarised two ways, using
    softmax weights from a 1x1 conv: along the curve-length axis (one
    descriptor per curve) and along the curve-number axis (one per step). Each
    point attends over both summaries. The results are projected back to
    ``in_channel`` channels, added to ``x``, and passed through LeakyReLU(0.2).
    """
    def __init__(self, in_channel):
        super(CurveAggregation, self).__init__()
        self.in_channel = in_channel
        mid_feature = in_channel // 2

        def _pointwise(cin, cout):
            # bias-free 1x1 Conv1d
            return nn.Conv1d(cin, cout, kernel_size=1, bias=False)

        self.conva = _pointwise(in_channel, mid_feature)
        self.convb = _pointwise(in_channel, mid_feature)
        self.convc = _pointwise(in_channel, mid_feature)
        self.convn = _pointwise(mid_feature, mid_feature)
        self.convl = _pointwise(mid_feature, mid_feature)
        self.convd = nn.Sequential(
            _pointwise(mid_feature * 2, in_channel),
            nn.BatchNorm1d(in_channel))
        self.line_conv_att = nn.Conv2d(in_channel, 1, kernel_size=1, bias=False)

    def forward(self, x, curves):
        att_logits = self.line_conv_att(curves)  # bs, 1, c_n, c_l

        per_curve = torch.sum(curves * F.softmax(att_logits, dim=-1), dim=-1)  # bs, c, c_n
        per_step = torch.sum(curves * F.softmax(att_logits, dim=-2), dim=-2)  # bs, c, c_l

        per_curve = self.conva(per_curve)  # bs, mid, c_n
        per_step = self.convb(per_step)  # bs, mid, c_l

        query = self.convc(x).transpose(1, 2).contiguous()  # bs, n, mid
        curve_weights = F.softmax(torch.bmm(query, per_curve), dim=-1)  # bs, n, c_n
        step_weights = F.softmax(torch.bmm(query, per_step), dim=-1)  # bs, n, c_l

        curve_values = self.convn(per_curve).transpose(1, 2).contiguous()  # bs, c_n, mid
        step_values = self.convl(per_step).transpose(1, 2).contiguous()  # bs, c_l, mid

        fused = torch.cat(
            (torch.bmm(curve_weights, curve_values), torch.bmm(step_weights, step_values)),
            dim=-1,
        ).transpose(1, 2).contiguous()  # bs, 2*mid, n

        return F.leaky_relu(x + self.convd(fused), negative_slope=0.2)
495
+
496
+
497
class CurveGrouping(nn.Module):
    """Choose curve starting points by self-attention and walk curves from them.

    A 1x1 conv + sigmoid scores every point. Features are re-weighted by that
    score, the ``curve_num`` highest-scoring points become starting points, and
    ``Walk`` grows the curves. Returns ``(curves, flatten_curve_idxs)`` from ``Walk``.
    """
    def __init__(self, in_channel, k, curve_num, curve_length):
        super(CurveGrouping, self).__init__()
        self.curve_num = curve_num
        self.curve_length = curve_length
        self.in_channel = in_channel
        self.k = k

        self.att = nn.Conv1d(in_channel, 1, kernel_size=1, bias=False)

        self.walk = Walk(in_channel, k, curve_num, curve_length)

    def forward(self, x, xyz, idx):
        scores = torch.sigmoid(self.att(x))  # bs, 1, n
        x = x * scores

        start_index = torch.topk(scores, self.curve_num, dim=2, sorted=False)[1]  # bs, 1, curve_num
        start_index = start_index.squeeze(1).unsqueeze(2)  # bs, curve_num, 1

        # (curves: bs, c, curve_num, curve_length; visited flat indices)
        return self.walk(xyz, x, idx, start_index)
523
+
524
+
525
class MaskedMaxPool(nn.Module):
    """Downsample to ``npoint`` points and max-pool features over each ball-query group.

    ``forward(xyz, features)`` passes ``features`` transposed to
    ``sample_and_group`` and takes the maximum over the neighbour axis. It
    returns ``(sub_xyz, sub_features)``, with ``sub_features`` shaped
    (bs, c, npoint) assuming ``sample_and_group`` yields (bs, npoint, k, c).
    """
    def __init__(self, npoint, radius, k):
        super(MaskedMaxPool, self).__init__()
        self.npoint = npoint
        self.radius = radius
        self.k = k

    def forward(self, xyz, features):
        sub_xyz, grouped = sample_and_group(
            self.npoint, self.radius, self.k, xyz, features.transpose(1, 2))

        grouped = grouped.permute(0, 3, 1, 2).contiguous()  # bs, c, npoint, k
        window = [1, grouped.shape[3]]
        pooled = F.max_pool2d(grouped, kernel_size=window)  # bs, c, npoint, 1
        return sub_xyz, torch.squeeze(pooled, -1)