learning3d 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {learning3d/data_utils → data_utils}/dataloaders.py +16 -14
- examples/test_curvenet.py +118 -0
- {learning3d/examples → examples}/test_dcp.py +3 -5
- {learning3d/examples → examples}/test_deepgmr.py +3 -5
- {learning3d/examples → examples}/test_masknet.py +1 -3
- {learning3d/examples → examples}/test_masknet2.py +1 -3
- {learning3d/examples → examples}/test_pcn.py +2 -4
- {learning3d/examples → examples}/test_pcrnet.py +1 -3
- {learning3d/examples → examples}/test_pnlk.py +1 -3
- {learning3d/examples → examples}/test_pointconv.py +1 -3
- {learning3d/examples → examples}/test_pointnet.py +1 -3
- {learning3d/examples → examples}/test_prnet.py +3 -5
- {learning3d/examples → examples}/test_rpmnet.py +1 -3
- {learning3d/examples → examples}/train_PointNetLK.py +2 -4
- {learning3d/examples → examples}/train_dcp.py +2 -4
- {learning3d/examples → examples}/train_deepgmr.py +2 -4
- {learning3d/examples → examples}/train_masknet.py +2 -4
- {learning3d/examples → examples}/train_pcn.py +2 -4
- {learning3d/examples → examples}/train_pcrnet.py +2 -4
- {learning3d/examples → examples}/train_pointconv.py +2 -4
- {learning3d/examples → examples}/train_pointnet.py +2 -4
- {learning3d/examples → examples}/train_prnet.py +2 -4
- {learning3d/examples → examples}/train_rpmnet.py +2 -4
- {learning3d-0.1.0.dist-info → learning3d-0.2.1.dist-info}/METADATA +57 -12
- learning3d-0.2.1.dist-info/RECORD +70 -0
- {learning3d-0.1.0.dist-info → learning3d-0.2.1.dist-info}/WHEEL +1 -1
- learning3d-0.2.1.dist-info/top_level.txt +6 -0
- {learning3d/models → models}/__init__.py +7 -1
- models/curvenet.py +130 -0
- {learning3d/models → models}/dgcnn.py +1 -35
- {learning3d/models → models}/prnet.py +5 -39
- utils/__init__.py +23 -0
- utils/curvenet_util.py +540 -0
- utils/model_common_utils.py +156 -0
- learning3d/losses/cuda/chamfer_distance/__init__.py +0 -1
- learning3d/losses/cuda/chamfer_distance/chamfer_distance.cpp +0 -185
- learning3d/losses/cuda/chamfer_distance/chamfer_distance.cu +0 -209
- learning3d/losses/cuda/chamfer_distance/chamfer_distance.py +0 -66
- learning3d/losses/cuda/emd_torch/pkg/emd_loss_layer.py +0 -41
- learning3d/losses/cuda/emd_torch/pkg/include/cuda/emd.cuh +0 -347
- learning3d/losses/cuda/emd_torch/pkg/include/cuda_helper.h +0 -18
- learning3d/losses/cuda/emd_torch/pkg/include/emd.h +0 -54
- learning3d/losses/cuda/emd_torch/pkg/layer/__init__.py +0 -1
- learning3d/losses/cuda/emd_torch/pkg/layer/emd_loss_layer.py +0 -40
- learning3d/losses/cuda/emd_torch/pkg/src/cuda/emd.cu +0 -70
- learning3d/losses/cuda/emd_torch/pkg/src/emd.cpp +0 -1
- learning3d/losses/cuda/emd_torch/setup.py +0 -29
- learning3d/ops/__init__.py +0 -0
- learning3d/utils/__init__.py +0 -4
- learning3d-0.1.0.dist-info/RECORD +0 -80
- learning3d-0.1.0.dist-info/top_level.txt +0 -1
- {learning3d/data_utils → data_utils}/__init__.py +0 -0
- {learning3d/data_utils → data_utils}/user_data.py +0 -0
- {learning3d-0.1.0.dist-info → learning3d-0.2.1.dist-info}/LICENSE +0 -0
- {learning3d/losses → losses}/__init__.py +0 -0
- {learning3d/losses → losses}/chamfer_distance.py +0 -0
- {learning3d/losses → losses}/classification.py +0 -0
- {learning3d/losses → losses}/correspondence_loss.py +0 -0
- {learning3d/losses → losses}/emd.py +0 -0
- {learning3d/losses → losses}/frobenius_norm.py +0 -0
- {learning3d/losses → losses}/rmse_features.py +0 -0
- {learning3d/models → models}/classifier.py +0 -0
- {learning3d/models → models}/dcp.py +0 -0
- {learning3d/models → models}/deepgmr.py +0 -0
- {learning3d/models → models}/masknet.py +0 -0
- {learning3d/models → models}/masknet2.py +0 -0
- {learning3d/models → models}/pcn.py +0 -0
- {learning3d/models → models}/pcrnet.py +0 -0
- {learning3d/models → models}/pointconv.py +0 -0
- {learning3d/models → models}/pointnet.py +0 -0
- {learning3d/models → models}/pointnetlk.py +0 -0
- {learning3d/models → models}/pooling.py +0 -0
- {learning3d/models → models}/ppfnet.py +0 -0
- {learning3d/models → models}/rpmnet.py +0 -0
- {learning3d/models → models}/segmentation.py +0 -0
- {learning3d → ops}/__init__.py +0 -0
- {learning3d/ops → ops}/data_utils.py +0 -0
- {learning3d/ops → ops}/invmat.py +0 -0
- {learning3d/ops → ops}/quaternion.py +0 -0
- {learning3d/ops → ops}/se3.py +0 -0
- {learning3d/ops → ops}/sinc.py +0 -0
- {learning3d/ops → ops}/so3.py +0 -0
- {learning3d/ops → ops}/transform_functions.py +0 -0
- {learning3d/utils → utils}/pointconv_util.py +0 -0
- {learning3d/utils → utils}/ppfnet_util.py +0 -0
- {learning3d/utils → utils}/svd.py +0 -0
- {learning3d/utils → utils}/transformer.py +0 -0
utils/curvenet_util.py
ADDED
@@ -0,0 +1,540 @@
|
|
1
|
+
"""
|
2
|
+
@Author: Yue Wang
|
3
|
+
@Contact: yuewangx@mit.edu
|
4
|
+
@File: pointnet_util.py
|
5
|
+
@Time: 2018/10/13 10:39 PM
|
6
|
+
|
7
|
+
Modified by
|
8
|
+
@Author: Tiange Xiang
|
9
|
+
@Contact: txia7609@uni.sydney.edu.au
|
10
|
+
@Time: 2021/01/21 3:10 PM
|
11
|
+
"""
|
12
|
+
|
13
|
+
import torch
|
14
|
+
import torch.nn as nn
|
15
|
+
import torch.nn.functional as F
|
16
|
+
from time import time
|
17
|
+
import numpy as np
|
18
|
+
from .model_common_utils import (
|
19
|
+
knn,
|
20
|
+
square_distance,
|
21
|
+
index_points,
|
22
|
+
farthest_point_sample,
|
23
|
+
query_ball_point,
|
24
|
+
)
|
25
|
+
|
26
|
+
def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
    """
    Sample centroids with farthest point sampling and gather a fixed-size
    ball neighbourhood of features around each centroid.

    Input:
        npoint: number of centroids to sample (FPS always starts at point 0)
        radius: ball-query radius around each centroid
        nsample: number of neighbours gathered per centroid
        xyz: input points position data, [B, N, 3]
        points: input points data, [B, N, D]
        returnfps: if True, additionally return the neighbour indices
    Return:
        new_xyz: sampled centroid positions, [B, npoint, 3]
        new_points: grouped neighbour features, [B, npoint, nsample, D]
            (coordinates are NOT concatenated onto the features)
        idx: only when returnfps is True -- the ball-query neighbour indices
            into the N input points, [B, npoint, nsample]
    """
    new_xyz = index_points(xyz, farthest_point_sample(xyz, npoint, start_with_first_point=True))
    # Release cached allocator blocks between the memory-heavy steps.
    torch.cuda.empty_cache()

    idx = query_ball_point(radius, nsample, xyz, new_xyz, get_cnt=False)
    torch.cuda.empty_cache()

    new_points = index_points(points, idx)
    torch.cuda.empty_cache()

    if returnfps:
        # NOTE: despite the flag name, these are the ball-query indices.
        return new_xyz, new_points, idx
    else:
        return new_xyz, new_points
|
51
|
+
|
52
|
+
def batched_index_select(input, dim, index):
    """Gather entries of ``input`` along ``dim`` with a separate index per batch.

    ``index`` holds, for each batch element (dimension 0), the positions to
    pick along ``dim``. It is reshaped so that ``dim`` carries the index and
    every other non-batch dimension is broadcast to the size of ``input``.
    The result is then passed to ``torch.gather``.
    """
    view_shape = [input.shape[0]]
    view_shape.extend(-1 if axis == dim else 1 for axis in range(1, input.dim()))

    expand_shape = list(input.shape)
    expand_shape[0] = -1
    expand_shape[dim] = -1

    expanded_index = index.view(view_shape).expand(expand_shape)
    return torch.gather(input, dim, expanded_index)
|
60
|
+
|
61
|
+
def gumbel_softmax(logits, dim, temperature=1):
    """
    Straight-through softmax (no random Gumbel samples are drawn).

    The forward value is a hard one-hot vector marking the arg-max along
    ``dim``. The backward pass uses the gradient of the tempered softmax.

    Args:
        logits: tensor of unnormalised scores.
        dim: dimension over which both the softmax and the one-hot are taken.
        temperature: softmax temperature (logits are divided by it).

    Returns:
        Tensor with the same shape as ``logits``, one-hot along ``dim``.
    """
    y = F.softmax(logits / temperature, dim=dim)

    # Take the arg-max along the same dimension as the softmax, so the
    # one-hot is valid for any ``dim``, not only the last one.
    index = y.argmax(dim=dim, keepdim=True)
    y_hard = torch.zeros_like(y).scatter_(dim, index, 1.0)

    # Straight-through estimator: forward value is y_hard, gradient is that of y.
    return (y_hard - y).detach() + y
|
77
|
+
|
78
|
+
class Walk(nn.Module):
    '''
    Walk in the cloud: grow ``curve_num`` curves of ``curve_length`` steps
    over a kNN graph.

    At every step each curve looks at the ``k`` neighbours of its current
    point and scores them with ``agent_mlp``, using the neighbour feature
    concatenated with the running curve descriptor. From the second step on,
    it damps candidates that turn back (crossover suppression). It then moves
    to the arg-max neighbour through a straight-through one-hot
    (``gumbel_softmax``).

    Args:
        in_channel: feature width c of the point features.
        k: number of neighbours per point in ``adj``.
        curve_num: number of curves (starting points) per cloud.
        curve_length: number of walking steps per curve.
    '''
    def __init__(self, in_channel, k, curve_num, curve_length):
        super(Walk, self).__init__()
        self.curve_num = curve_num
        self.curve_length = curve_length
        self.k = k

        # Scores each candidate neighbour from [neighbour feature, curve descriptor].
        self.agent_mlp = nn.Sequential(
            nn.Conv2d(in_channel * 2,
                      1,
                      kernel_size=1,
                      bias=False), nn.BatchNorm2d(1))
        # Produces the 2-way softmax that mixes current and previous descriptors.
        self.momentum_mlp = nn.Sequential(
            nn.Conv1d(in_channel * 2,
                      2,
                      kernel_size=1,
                      bias=False), nn.BatchNorm1d(2))

    def crossover_suppression(self, cur, neighbor, bn, n, k):
        """Return a [bs*n, k] damping factor in [0, 1] for each candidate step.

        The factor is ``clamp(1 + cos(cur, neighbor), 0, 1)``: candidates whose
        step direction has a non-negative cosine with the current direction
        keep a factor of 1, and those pointing back are scaled down. Gradients
        are blocked.

        Args:
            cur: current step direction, [bs*n, c].
            neighbor: candidate step directions, [bs*n, c, k].
            bn, n, k: unused here; kept for interface compatibility.
        """
        neighbor = neighbor.detach()
        cur = cur.unsqueeze(-1).detach()
        dot = torch.bmm(cur.transpose(1, 2), neighbor)  # bs*n, 1, k
        norm1 = torch.norm(cur, dim=1, keepdim=True)
        norm2 = torch.norm(neighbor, dim=1, keepdim=True)
        divider = torch.clamp(norm1 * norm2, min=1e-8)
        # Squeeze only the singleton row dim so bs*n == 1 keeps its batch axis.
        ans = torch.div(dot, divider).squeeze(1)  # bs*n, k

        # normalize to [0, 1]
        ans = 1. + ans
        ans = torch.clamp(ans, 0., 1.0)

        return ans.detach()

    def forward(self, xyz, x, adj, cur):
        """
        Args:
            xyz: point coordinates; accepted for interface compatibility but
                not used by the walk.
            x: point features, [bs, c, n].
            adj: neighbour indices of every point, [bs, n, k] (per-cloud indices).
            cur: starting point indices; presumably [bs, curve_num, 1] as
                produced by CurveGrouping -- TODO confirm for other callers.

        Returns:
            curves: [bs, c, curve_num, curve_length] features of the visited points.
            flatten_curve_idxs: [bs * curve_num, curve_length + 1] visited point
                indices in batch-flattened numbering, starting point first.
        """
        bn, c, tot_points = x.size()
        device = x.device

        # point features
        x = x.transpose(1, 2).contiguous()  # bs, n, c

        flatten_x = x.view(bn * tot_points, -1)
        batch_offset = torch.arange(0, bn, device=device).detach() * tot_points

        # indices of neighbors for the starting points, flattened over the batch
        tmp_adj = (adj + batch_offset.view(-1, 1, 1)).view(adj.size(0) * adj.size(1), -1)  # bs*n, k

        # batch flattened indices for the starting points
        flatten_cur = (cur + batch_offset.view(-1, 1, 1)).view(-1)

        curves = []
        flatten_curve_idxs = [flatten_cur.unsqueeze(1)]

        # one step at a time
        for step in range(self.curve_length):

            if step == 0:
                # get starting point features using flattened indices
                starting_points = flatten_x[flatten_cur, :].contiguous()
                pre_feature = starting_points.view(bn, self.curve_num, -1, 1).transpose(1, 2)  # bs, c, n, 1
            else:
                # dynamic momentum
                cat_feature = torch.cat((cur_feature.squeeze(-1), pre_feature.squeeze(-1)), dim=1)
                att_feature = F.softmax(self.momentum_mlp(cat_feature), dim=1).view(bn, 1, self.curve_num, 2)  # bs, 1, n, 2
                cat_feature = torch.cat((cur_feature, pre_feature), dim=-1)  # bs, c, n, 2

                # update curve descriptor
                pre_feature = torch.sum(cat_feature * att_feature, dim=-1, keepdim=True)  # bs, c, n, 1
                pre_feature_cos = pre_feature.transpose(1, 2).contiguous().view(bn * self.curve_num, -1)

            pick_idx = tmp_adj[flatten_cur]  # bs*n, k

            # get the neighbors of current points
            pick_values = flatten_x[pick_idx.view(-1), :]

            # reshape to fit crossover suppression below
            pick_values_cos = pick_values.view(bn * self.curve_num, self.k, c)
            pick_values = pick_values_cos.view(bn, self.curve_num, self.k, c)
            pick_values_cos = pick_values_cos.transpose(1, 2).contiguous()  # bs*n, c, k

            pick_values = pick_values.permute(0, 3, 1, 2)  # bs, c, n, k

            pre_feature_expand = pre_feature.expand_as(pick_values)

            # concat current point features with curve descriptors
            pre_feature_expand = torch.cat((pick_values, pre_feature_expand), dim=1)

            # which node to pick next?
            pre_feature_expand = self.agent_mlp(pre_feature_expand)  # bs, 1, n, k

            if step != 0:
                # cross over suppression
                d = self.crossover_suppression(cur_feature_cos - pre_feature_cos,
                                               pick_values_cos - cur_feature_cos.unsqueeze(-1),
                                               bn, self.curve_num, self.k)
                d = d.view(bn, self.curve_num, self.k).unsqueeze(1)  # bs, 1, n, k
                pre_feature_expand = torch.mul(pre_feature_expand, d)

            pre_feature_expand = gumbel_softmax(pre_feature_expand, -1)  # bs, 1, n, k

            cur_feature = torch.sum(pick_values * pre_feature_expand, dim=-1, keepdim=True)  # bs, c, n, 1

            cur_feature_cos = cur_feature.transpose(1, 2).contiguous().view(bn * self.curve_num, c)

            cur = torch.argmax(pre_feature_expand, dim=-1).view(-1, 1)  # bs * n, 1

            # Squeeze only dim 1 so a single curve in a single cloud stays 1-D.
            flatten_cur = batched_index_select(pick_idx, 1, cur).squeeze(1)  # bs * n

            # collect curve progress
            curves.append(cur_feature)
            flatten_curve_idxs.append(flatten_cur.unsqueeze(1))

        return torch.cat(curves, dim=-1), torch.cat(flatten_curve_idxs, dim=1)
|
198
|
+
|
199
|
+
|
200
|
+
class Attention_block(nn.Module):
    '''
    Additive attention gate (as used in attention U-Net) for 1-D feature maps.

    ``forward(g, x)`` projects both inputs to ``F_int`` channels and sums them.
    After a LeakyReLU, it maps the sum to a single sigmoid gate channel. The
    result is the gate ``psi`` together with its complement ``1 - psi``.
    '''
    def __init__(self, F_g, F_l, F_int):
        super(Attention_block, self).__init__()

        def pointwise(in_ch, out_ch):
            # 1x1 convolution (with bias) followed by batch norm.
            return [nn.Conv1d(in_ch, out_ch, kernel_size=1, stride=1, padding=0, bias=True),
                    nn.BatchNorm1d(out_ch)]

        self.W_g = nn.Sequential(*pointwise(F_g, F_int))
        self.W_x = nn.Sequential(*pointwise(F_l, F_int))
        self.psi = nn.Sequential(*pointwise(F_int, 1), nn.Sigmoid())

    def forward(self, g, x):
        combined = self.W_g(g) + self.W_x(x)
        gate = self.psi(F.leaky_relu(combined, negative_slope=0.2))
        return gate, 1. - gate
|
229
|
+
|
230
|
+
|
231
|
+
class LPFA(nn.Module):
    """Local point-feature aggregation over ``k`` neighbours.

    For every point, a 9-channel geometric descriptor is built from
    [centre xyz, neighbour xyz, neighbour - centre] for each neighbour.
    In ``initial`` mode that descriptor is the input to the MLP, so
    ``in_channel`` must be 9. Otherwise the descriptor is lifted to
    ``in_channel`` channels and added to the neighbour-minus-centre features.
    After ``mlp_num`` Conv2d/BN/LeakyReLU layers the neighbour axis is reduced
    with max (``initial``) or mean (otherwise).
    """
    def __init__(self, in_channel, out_channel, k, mlp_num=2, initial=False):
        super(LPFA, self).__init__()
        self.k = k
        self.initial = initial

        if not initial:
            self.xyz2feature = nn.Sequential(
                nn.Conv2d(9, in_channel, kernel_size=1, bias=False),
                nn.BatchNorm2d(in_channel))

        layers = []
        channels = in_channel
        for _ in range(mlp_num):
            layers.append(nn.Sequential(nn.Conv2d(channels, out_channel, 1, bias=False),
                                        nn.BatchNorm2d(out_channel),
                                        nn.LeakyReLU(0.2)))
            channels = out_channel
        self.mlp = nn.Sequential(*layers)

    def forward(self, x, xyz, idx=None):
        """x: [bs, c, n] features, xyz: [bs, 3, n], idx: optional [bs, n, k]."""
        grouped = self.mlp(self.group_feature(x, xyz, idx))
        if self.initial:
            return grouped.max(dim=-1, keepdim=False)[0]
        return grouped.mean(dim=-1, keepdim=False)

    def group_feature(self, x, xyz, idx):
        """Return the grouped per-neighbour features, [bs, c, n, k]."""
        batch_size, num_dims, num_points = x.size()

        if idx is None:
            idx = knn(xyz, k=self.k, add_one_to_k=True)[:, :, :self.k]  # (batch_size, num_points, k)

        # Shift each cloud's indices into the batch-flattened point numbering.
        offsets = torch.arange(0, batch_size, device=x.device).view(-1, 1, 1) * num_points
        flat_idx = (idx + offsets).view(-1)

        xyz = xyz.transpose(2, 1).contiguous()  # bs, n, 3
        neighbor_xyz = xyz.view(batch_size * num_points, -1)[flat_idx, :]
        neighbor_xyz = neighbor_xyz.view(batch_size, num_points, self.k, -1)  # bs, n, k, 3
        center_xyz = xyz.view(batch_size, num_points, 1, 3).expand(-1, -1, self.k, -1)  # bs, n, k, 3

        point_feature = torch.cat(
            (center_xyz, neighbor_xyz, neighbor_xyz - center_xyz), dim=3
        ).permute(0, 3, 1, 2).contiguous()  # bs, 9, n, k

        if self.initial:
            return point_feature

        x = x.transpose(2, 1).contiguous()  # bs, n, c
        neighbors = x.view(batch_size * num_points, -1)[flat_idx, :]
        neighbors = neighbors.view(batch_size, num_points, self.k, num_dims)  # bs, n, k, c
        relative = neighbors - x.view(batch_size, num_points, 1, num_dims)
        relative = relative.permute(0, 3, 1, 2).contiguous()

        return F.leaky_relu(relative + self.xyz2feature(point_feature), 0.2)  # bs, c, n, k
|
293
|
+
|
294
|
+
|
295
|
+
class PointNetFeaturePropagation(nn.Module):
    """Feature propagation (upsampling) layer.

    Features of a sparse point set are interpolated onto a dense point set by
    inverse-distance weighting of the (up to) 3 nearest sparse points. The
    result is optionally concatenated with gated skip features of the dense
    set, then passed through stacked Conv1d/BN/LeakyReLU layers.

    Args:
        in_channel: input width of the first conv. It should equal D1 + D2
            when skip features ``points1`` are given, otherwise D2.
        mlp: list of output widths, one Conv1d/BN/LeakyReLU per entry.
        att: optional (F_g, F_l, F_int) triple. When given, ``points1`` is
            multiplied by an Attention_block gate.
    """
    def __init__(self, in_channel, mlp, att=None):
        super(PointNetFeaturePropagation, self).__init__()
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = in_channel
        self.att = None
        if att is not None:
            self.att = Attention_block(F_g=att[0], F_l=att[1], F_int=att[2])

        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm1d(out_channel))
            last_channel = out_channel

    def forward(self, xyz1, xyz2, points1, points2):
        """
        Input:
            xyz1: input points position data, [B, C, N]
            xyz2: sampled input points position data, [B, C, S], skipped xyz
            points1: input points data, [B, D, N] (may be None)
            points2: input points data, [B, D, S], skipped features
        Return:
            new_points: upsampled points data, [B, D', N]
        """
        xyz1 = xyz1.permute(0, 2, 1)
        xyz2 = xyz2.permute(0, 2, 1)

        points2 = points2.permute(0, 2, 1)
        B, N, C = xyz1.shape
        _, S, _ = xyz2.shape

        if S == 1:
            # A single source point: every target receives its features.
            interpolated_points = points2.repeat(1, N, 1)
        else:
            # Use at most 3 neighbours; fewer when the source set is smaller
            # (a fixed view(B, N, 3, 1) would fail for S == 2).
            num_neighbors = min(3, S)
            dists = square_distance(xyz1, xyz2)
            dists, idx = dists.sort(dim=-1)
            dists, idx = dists[:, :, :num_neighbors], idx[:, :, :num_neighbors]  # [B, N, num_neighbors]

            # Inverse-distance weights; 1e-8 avoids division by zero on exact hits.
            dist_recip = 1.0 / (dists + 1e-8)
            norm = torch.sum(dist_recip, dim=2, keepdim=True)
            weight = dist_recip / norm
            interpolated_points = torch.sum(index_points(points2, idx) * weight.unsqueeze(-1), dim=2)

        # skip attention
        if self.att is not None:
            psix, _ = self.att(interpolated_points.permute(0, 2, 1), points1)
            points1 = points1 * psix

        if points1 is not None:
            points1 = points1.permute(0, 2, 1)
            new_points = torch.cat([points1, interpolated_points], dim=-1)
        else:
            new_points = interpolated_points

        new_points = new_points.permute(0, 2, 1)

        for conv, bn in zip(self.mlp_convs, self.mlp_bns):
            new_points = F.leaky_relu(bn(conv(new_points)), 0.2)

        return new_points
|
357
|
+
|
358
|
+
|
359
|
+
class CIC(nn.Module):
    """Residual block used by CurveNet.

    Pipeline:
        1. Optional downsampling to ``npoint`` points (FPS + masked max-pool).
        2. 1x1 bottleneck to ``in_channels // bottleneck_ratio`` channels.
        3. Optional curve grouping and aggregation, when ``curve_config`` is
           a (curve_num, curve_length) pair.
        4. LPFA over ``k`` neighbours, then a 1x1 expansion to
           ``output_channels``.
        5. Residual addition (projected when the widths differ) followed by
           LeakyReLU.
    """
    def __init__(self, npoint, radius, k, in_channels, output_channels, bottleneck_ratio=2, mlp_num=2, curve_config=None):
        super(CIC, self).__init__()
        self.in_channels = in_channels
        self.output_channels = output_channels
        self.bottleneck_ratio = bottleneck_ratio
        self.radius = radius
        self.k = k
        self.npoint = npoint

        planes = in_channels // bottleneck_ratio

        self.use_curve = curve_config is not None
        if self.use_curve:
            curve_num, curve_length = curve_config[0], curve_config[1]
            self.curveaggregation = CurveAggregation(planes)
            self.curvegrouping = CurveGrouping(planes, k, curve_num, curve_length)

        # Bottleneck reduction: in_channels -> planes.
        self.conv1 = nn.Sequential(
            nn.Conv1d(in_channels, planes, kernel_size=1, bias=False),
            nn.BatchNorm1d(planes),
            nn.LeakyReLU(negative_slope=0.2, inplace=True))

        # Expansion: planes -> output_channels (activation applied after the residual sum).
        self.conv2 = nn.Sequential(
            nn.Conv1d(planes, output_channels, kernel_size=1, bias=False),
            nn.BatchNorm1d(output_channels))

        if in_channels != output_channels:
            # Projection so the residual matches the output width.
            self.shortcut = nn.Sequential(
                nn.Conv1d(in_channels, output_channels, kernel_size=1, bias=False),
                nn.BatchNorm1d(output_channels))

        self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        self.maxpool = MaskedMaxPool(npoint, radius, k)

        self.lpfa = LPFA(planes, planes, k, mlp_num=mlp_num, initial=False)

    def forward(self, xyz, x):
        """Return (xyz, features, flatten_curve_idxs or None).

        Presumably xyz is [bs, 3, n] and x is [bs, in_channels, n] -- TODO confirm.
        """
        # Downsample only when the incoming point count differs from npoint.
        if xyz.size(-1) != self.npoint:
            pooled_xyz, x = self.maxpool(xyz.transpose(1, 2).contiguous(), x)
            xyz = pooled_xyz.transpose(1, 2)

        residual = x
        x = self.conv1(x)  # bs, c', n

        # k + 1 neighbours: column 0 is dropped for curve walking to avoid self-loops.
        idx = knn(xyz, self.k, add_one_to_k=True)

        flatten_curve_idxs = None
        if self.use_curve:
            curves, flatten_curve_idxs = self.curvegrouping(x, xyz, idx[:, :, 1:])
            x = self.curveaggregation(x, curves)

        x = self.lpfa(x, xyz, idx=idx[:, :, :self.k])  # bs, c', n
        x = self.conv2(x)  # bs, c, n

        if self.in_channels != self.output_channels:
            residual = self.shortcut(residual)

        return xyz, self.relu(x + residual), flatten_curve_idxs
|
433
|
+
|
434
|
+
|
435
|
+
class CurveAggregation(nn.Module):
    """Fuse walked-curve features back into every point's features.

    Curves ``[bs, c, curve_num, curve_length]`` are pooled two ways with a
    shared attention map: across steps (inter-curve descriptors, one per
    curve) and across curves (intra-curve descriptors, one per step). Every
    point attends over both descriptor sets. The two attended results are
    concatenated, projected back to ``in_channel`` and added residually,
    followed by LeakyReLU(0.2).
    """
    def __init__(self, in_channel):
        super(CurveAggregation, self).__init__()
        self.in_channel = in_channel
        mid_feature = in_channel // 2

        def conv1x1(cin, cout):
            # Bias-free pointwise 1-D convolution.
            return nn.Conv1d(cin, cout, kernel_size=1, bias=False)

        self.conva = conv1x1(in_channel, mid_feature)
        self.convb = conv1x1(in_channel, mid_feature)
        self.convc = conv1x1(in_channel, mid_feature)
        self.convn = conv1x1(mid_feature, mid_feature)
        self.convl = conv1x1(mid_feature, mid_feature)
        self.convd = nn.Sequential(
            conv1x1(mid_feature * 2, in_channel),
            nn.BatchNorm1d(in_channel))
        self.line_conv_att = nn.Conv2d(in_channel, 1, kernel_size=1, bias=False)

    def forward(self, x, curves):
        att_logits = self.line_conv_att(curves)  # bs, 1, c_n, c_l

        inter = torch.sum(curves * F.softmax(att_logits, dim=-1), dim=-1)  # bs, c, c_n
        intra = torch.sum(curves * F.softmax(att_logits, dim=-2), dim=-2)  # bs, c, c_l

        inter = self.conva(inter)  # bs, mid, c_n
        intra = self.convb(intra)  # bs, mid, c_l

        query = self.convc(x).transpose(1, 2).contiguous()  # bs, n, mid
        inter_weights = F.softmax(torch.bmm(query, inter), dim=-1)  # bs, n, c_n
        intra_weights = F.softmax(torch.bmm(query, intra), dim=-1)  # bs, n, c_l

        inter_values = self.convn(inter).transpose(1, 2).contiguous()
        intra_values = self.convl(intra).transpose(1, 2).contiguous()

        fused = torch.cat(
            (torch.bmm(inter_weights, inter_values),
             torch.bmm(intra_weights, intra_values)),
            dim=-1).transpose(1, 2).contiguous()  # bs, 2*mid, n

        return F.leaky_relu(x + self.convd(fused), negative_slope=0.2)
|
495
|
+
|
496
|
+
|
497
|
+
class CurveGrouping(nn.Module):
    """Select curve starting points by self-attention and walk curves from them.

    A sigmoid score per point reweights the features. The ``curve_num``
    highest-scoring points become the starting points for ``Walk``.
    """
    def __init__(self, in_channel, k, curve_num, curve_length):
        super(CurveGrouping, self).__init__()
        self.curve_num = curve_num
        self.curve_length = curve_length
        self.in_channel = in_channel
        self.k = k

        # One scalar score per point.
        self.att = nn.Conv1d(in_channel, 1, kernel_size=1, bias=False)

        self.walk = Walk(in_channel, k, curve_num, curve_length)

    def forward(self, x, xyz, idx):
        scores = torch.sigmoid(self.att(x))  # bs, 1, n
        weighted = x * scores

        # Indices of the top-scoring points, reshaped to [bs, curve_num, 1].
        start_index = torch.topk(scores, self.curve_num, dim=2, sorted=False)[1]
        start_index = start_index.squeeze(1).unsqueeze(2)

        # Returns (curves [bs, c, c_n, c_l], flatten_curve_idxs).
        return self.walk(xyz, weighted, idx, start_index)
|
523
|
+
|
524
|
+
|
525
|
+
class MaskedMaxPool(nn.Module):
    """Downsample to ``npoint`` centroids and max-pool features over ball neighbourhoods.

    Input xyz is [B, N, 3] and features are [B, C, N]. Output is the centroid
    positions [B, npoint, 3] and the pooled features [B, C, npoint].
    """
    def __init__(self, npoint, radius, k):
        super(MaskedMaxPool, self).__init__()
        self.npoint = npoint
        self.radius = radius
        self.k = k

    def forward(self, xyz, features):
        sub_xyz, grouped = sample_and_group(
            self.npoint, self.radius, self.k, xyz, features.transpose(1, 2))

        # [B, npoint, nsample, C] -> [B, C, npoint, nsample] for 2-D pooling.
        grouped = grouped.permute(0, 3, 1, 2).contiguous()
        window = [1, grouped.shape[3]]
        pooled = F.max_pool2d(grouped, kernel_size=window)  # bs, c, n, 1
        return sub_xyz, torch.squeeze(pooled, -1)  # bs, c, n
|
@@ -0,0 +1,156 @@
|
|
1
|
+
import torch
|
2
|
+
|
3
|
+
def knn(x, k, add_one_to_k=False):
    """Return indices of the k nearest points for every point.

    Args:
        x: points as [batch_size, dims, num_points].
        k: number of neighbours.
        add_one_to_k: request k + 1 neighbours. The closest hit is normally
            the query point itself.

    Returns:
        LongTensor [batch_size, num_points, k (or k + 1)], nearest first.
    """
    num_neighbors = k + 1 if add_one_to_k else k
    inner = -2 * torch.matmul(x.transpose(2, 1).contiguous(), x)
    squared_norms = torch.sum(x ** 2, dim=1, keepdim=True)
    # Negative squared distance, so that topk (largest) yields the nearest.
    neg_sq_dist = -squared_norms - inner - squared_norms.transpose(2, 1).contiguous()
    return neg_sq_dist.topk(k=num_neighbors, dim=-1)[1]
|
10
|
+
|
11
|
+
def pc_normalize(pc):
    """Centre a point cloud at the origin and scale it into the unit sphere.

    Args:
        pc: numpy array [N, C] with one point per row.

    Returns:
        A new array of the same shape whose farthest point lies at distance 1.
        If every point coincides with the centroid, the centred (all-zero)
        array is returned instead of dividing by zero.
    """
    # numpy is not imported at this module's top level, so import it here.
    import numpy as np

    centroid = np.mean(pc, axis=0)
    pc = pc - centroid
    m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
    if m == 0:
        return pc
    return pc / m
|
18
|
+
|
19
|
+
def square_distance(src, dst):
    """
    Pairwise squared Euclidean distances between two batched point sets.

    Uses the expansion ||s - d||^2 = ||s||^2 + ||d||^2 - 2 <s, d>, so the whole
    [N, M] matrix comes from one batched matmul.

    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-point square distance, [B, N, M]
    """
    B, N, _ = src.shape
    _, M, _ = dst.shape
    cross = torch.matmul(src, dst.permute(0, 2, 1))
    src_sq = torch.sum(src ** 2, -1).view(B, N, 1)
    dst_sq = torch.sum(dst ** 2, -1).view(B, 1, M)
    return -2 * cross + src_sq + dst_sq
|
39
|
+
|
40
|
+
def index_points(points, idx):
    """
    Gather points per batch element with an index tensor of any trailing shape.

    Input:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S] (or [B, S, K, ...])
    Return:
        new_points:, indexed points data, [B, S, C] (or [B, S, K, ..., C])
    """
    batch_ids = torch.arange(points.shape[0], dtype=torch.long, device=points.device)
    # Shape the batch ids as [B, 1, ..., 1] and broadcast them to idx's shape.
    broadcast_shape = [idx.shape[0]] + [1] * (idx.dim() - 1)
    batch_ids = batch_ids.view(broadcast_shape).expand_as(idx)
    return points[batch_ids, idx, :]
|
57
|
+
|
58
|
+
def farthest_point_sample(xyz, npoint, start_with_first_point=False):
    """
    Iterative farthest point sampling.

    Input:
        xyz: pointcloud data, [B, N, C] (any coordinate dimension C)
        npoint: number of samples
        start_with_first_point: if True every batch starts at point 0,
            otherwise at a random point drawn with torch.randint.
    Return:
        centroids: sampled pointcloud index, [B, npoint]
    """
    device = xyz.device
    B, N, C = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long, device=device)
    # Squared distance from every point to its nearest already-selected centroid.
    distance = torch.ones(B, N, device=device) * 1e10
    if start_with_first_point:
        farthest = torch.zeros(B, dtype=torch.long, device=device)
    else:
        farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
    batch_indices = torch.arange(B, dtype=torch.long, device=device)
    for i in range(npoint):
        centroids[:, i] = farthest
        # Use the actual coordinate width instead of assuming 3-D points.
        centroid = xyz[batch_indices, farthest, :].view(B, 1, C)
        dist = torch.sum((xyz - centroid) ** 2, -1)
        mask = dist < distance
        distance[mask] = dist[mask]
        farthest = torch.max(distance, -1)[1]
    return centroids
|
83
|
+
|
84
|
+
def knn_point(k, pos1, pos2):
    '''
    Brute-force k-nearest-neighbour search.

    Input:
        k: int, number of k in k-nn search
        pos1: (batch_size, ndataset, c) float tensor, input points
        pos2: (batch_size, npoint, c) float tensor, query points
    Output:
        val: (batch_size, npoint, k) float tensor, L2 distances (ascending)
        idx: (batch_size, npoint, k) int64 tensor, indices to input points
    '''
    # Broadcast [B, 1, N, C] against [B, M, 1, C] instead of materialising two
    # repeated [B, M, N, C] copies.
    diff = pos1.unsqueeze(1) - pos2.unsqueeze(2)
    neg_sq_dist = torch.sum(-diff ** 2, -1)  # [B, M, N]
    val, idx = neg_sq_dist.topk(k=k, dim=-1)
    return torch.sqrt(-val), idx
|
101
|
+
|
102
|
+
def query_ball_point(radius, nsample, xyz, new_xyz, get_cnt=False):
    """
    Ball query: for each query point, up to ``nsample`` indices of points
    within ``radius``, in ascending index order.

    Input:
        radius: local region radius
        nsample: max sample number in local region
        xyz: all points, [B, N, C]
        new_xyz: query points, [B, S, C]
        get_cnt: also return the number of in-ball points per query, [B, S]
    Return:
        group_idx: grouped points index, [B, S, nsample]. Short neighbourhoods
            are padded with their first in-ball index. If no point is in the
            ball, the entries stay N (out of range for xyz).
    """
    device = xyz.device
    B, N, C = xyz.shape
    _, S, _ = new_xyz.shape
    sqrdists = square_distance(new_xyz, xyz)

    # Index N is an "outside the ball" sentinel that sorts after every real index.
    group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
    group_idx[sqrdists > radius ** 2] = N

    if get_cnt:
        cnt = (group_idx != N).sum(dim=-1)

    group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
    first_hit = group_idx[:, :, :1].repeat([1, 1, nsample])
    padding = group_idx == N
    group_idx[padding] = first_hit[padding]

    if get_cnt:
        return group_idx, cnt
    return group_idx
|
131
|
+
|
132
|
+
def get_graph_feature(x, k=20, device=None):
    """Build edge features [neighbour, centre] from the k nearest neighbours.

    Args:
        x: point features. Only the first three dimensions are kept through a
            ``view``, so x must be [B, C, N] possibly followed by singleton dims.
        k: number of neighbours.
        device: device for the batch-offset tensor. Defaults to ``x.device``,
            which is where ``knn`` puts its indices, so both always match.

    Returns:
        Tensor [B, 2*C, N, k]: neighbour features concatenated with the
        (repeated) centre features along the channel dimension.
    """
    x = x.view(*x.size()[:3])
    idx = knn(x, k=k)  # (batch_size, num_points, k)
    batch_size, num_points, _ = idx.size()

    if device is None:
        # Follow the input: a global CUDA default breaks for CPU inputs on a
        # CUDA machine and for tensors on a non-current GPU.
        device = x.device

    # Shift each cloud's indices into the batch-flattened point numbering.
    idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1) * num_points

    idx = idx + idx_base

    idx = idx.view(-1)

    _, num_dims, _ = x.size()

    x = x.transpose(2, 1).contiguous()  # (batch_size, num_points, num_dims)
    feature = x.view(batch_size * num_points, -1)[idx, :]
    feature = feature.view(batch_size, num_points, k, num_dims)
    x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)

    feature = torch.cat((feature, x), dim=3).permute(0, 3, 1, 2)

    return feature
|
@@ -1 +0,0 @@
|
|
1
|
-
from .chamfer_distance import ChamferDistance
|