learning3d 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- learning3d/__init__.py +2 -0
- learning3d/data_utils/__init__.py +4 -0
- learning3d/data_utils/dataloaders.py +454 -0
- learning3d/data_utils/user_data.py +119 -0
- learning3d/examples/test_dcp.py +139 -0
- learning3d/examples/test_deepgmr.py +144 -0
- learning3d/examples/test_flownet.py +113 -0
- learning3d/examples/test_masknet.py +159 -0
- learning3d/examples/test_masknet2.py +162 -0
- learning3d/examples/test_pcn.py +118 -0
- learning3d/examples/test_pcrnet.py +120 -0
- learning3d/examples/test_pnlk.py +121 -0
- learning3d/examples/test_pointconv.py +126 -0
- learning3d/examples/test_pointnet.py +121 -0
- learning3d/examples/test_prnet.py +126 -0
- learning3d/examples/test_rpmnet.py +120 -0
- learning3d/examples/train_PointNetLK.py +240 -0
- learning3d/examples/train_dcp.py +249 -0
- learning3d/examples/train_deepgmr.py +244 -0
- learning3d/examples/train_flownet.py +259 -0
- learning3d/examples/train_masknet.py +239 -0
- learning3d/examples/train_pcn.py +216 -0
- learning3d/examples/train_pcrnet.py +228 -0
- learning3d/examples/train_pointconv.py +245 -0
- learning3d/examples/train_pointnet.py +244 -0
- learning3d/examples/train_prnet.py +229 -0
- learning3d/examples/train_rpmnet.py +228 -0
- learning3d/losses/__init__.py +12 -0
- learning3d/losses/chamfer_distance.py +51 -0
- learning3d/losses/classification.py +14 -0
- learning3d/losses/correspondence_loss.py +10 -0
- learning3d/losses/cuda/chamfer_distance/__init__.py +1 -0
- learning3d/losses/cuda/chamfer_distance/chamfer_distance.cpp +185 -0
- learning3d/losses/cuda/chamfer_distance/chamfer_distance.cu +209 -0
- learning3d/losses/cuda/chamfer_distance/chamfer_distance.py +66 -0
- learning3d/losses/cuda/emd_torch/pkg/emd_loss_layer.py +41 -0
- learning3d/losses/cuda/emd_torch/pkg/include/cuda/emd.cuh +347 -0
- learning3d/losses/cuda/emd_torch/pkg/include/cuda_helper.h +18 -0
- learning3d/losses/cuda/emd_torch/pkg/include/emd.h +54 -0
- learning3d/losses/cuda/emd_torch/pkg/layer/__init__.py +1 -0
- learning3d/losses/cuda/emd_torch/pkg/layer/emd_loss_layer.py +40 -0
- learning3d/losses/cuda/emd_torch/pkg/src/cuda/emd.cu +70 -0
- learning3d/losses/cuda/emd_torch/pkg/src/emd.cpp +1 -0
- learning3d/losses/cuda/emd_torch/setup.py +29 -0
- learning3d/losses/emd.py +16 -0
- learning3d/losses/frobenius_norm.py +21 -0
- learning3d/losses/rmse_features.py +16 -0
- learning3d/models/__init__.py +23 -0
- learning3d/models/classifier.py +41 -0
- learning3d/models/dcp.py +92 -0
- learning3d/models/deepgmr.py +165 -0
- learning3d/models/dgcnn.py +92 -0
- learning3d/models/flownet3d.py +446 -0
- learning3d/models/masknet.py +84 -0
- learning3d/models/masknet2.py +264 -0
- learning3d/models/pcn.py +164 -0
- learning3d/models/pcrnet.py +74 -0
- learning3d/models/pointconv.py +108 -0
- learning3d/models/pointnet.py +108 -0
- learning3d/models/pointnetlk.py +173 -0
- learning3d/models/pooling.py +15 -0
- learning3d/models/ppfnet.py +102 -0
- learning3d/models/prnet.py +431 -0
- learning3d/models/rpmnet.py +359 -0
- learning3d/models/segmentation.py +38 -0
- learning3d/ops/__init__.py +0 -0
- learning3d/ops/data_utils.py +45 -0
- learning3d/ops/invmat.py +134 -0
- learning3d/ops/quaternion.py +218 -0
- learning3d/ops/se3.py +157 -0
- learning3d/ops/sinc.py +229 -0
- learning3d/ops/so3.py +213 -0
- learning3d/ops/transform_functions.py +342 -0
- learning3d/utils/__init__.py +9 -0
- learning3d/utils/lib/build/lib.linux-x86_64-3.5/pointnet2_cuda.cpython-35m-x86_64-linux-gnu.so +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/ball_query.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/ball_query_gpu.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/group_points.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/group_points_gpu.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/interpolate.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/interpolate_gpu.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/pointnet2_api.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/sampling.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/sampling_gpu.o +0 -0
- learning3d/utils/lib/dist/pointnet2-0.0.0-py3.5-linux-x86_64.egg +0 -0
- learning3d/utils/lib/pointnet2.egg-info/SOURCES.txt +14 -0
- learning3d/utils/lib/pointnet2.egg-info/dependency_links.txt +1 -0
- learning3d/utils/lib/pointnet2.egg-info/top_level.txt +1 -0
- learning3d/utils/lib/pointnet2_modules.py +160 -0
- learning3d/utils/lib/pointnet2_utils.py +318 -0
- learning3d/utils/lib/pytorch_utils.py +236 -0
- learning3d/utils/lib/setup.py +23 -0
- learning3d/utils/lib/src/ball_query.cpp +25 -0
- learning3d/utils/lib/src/ball_query_gpu.cu +67 -0
- learning3d/utils/lib/src/ball_query_gpu.h +15 -0
- learning3d/utils/lib/src/cuda_utils.h +15 -0
- learning3d/utils/lib/src/group_points.cpp +36 -0
- learning3d/utils/lib/src/group_points_gpu.cu +86 -0
- learning3d/utils/lib/src/group_points_gpu.h +22 -0
- learning3d/utils/lib/src/interpolate.cpp +65 -0
- learning3d/utils/lib/src/interpolate_gpu.cu +233 -0
- learning3d/utils/lib/src/interpolate_gpu.h +36 -0
- learning3d/utils/lib/src/pointnet2_api.cpp +25 -0
- learning3d/utils/lib/src/sampling.cpp +46 -0
- learning3d/utils/lib/src/sampling_gpu.cu +253 -0
- learning3d/utils/lib/src/sampling_gpu.h +29 -0
- learning3d/utils/pointconv_util.py +382 -0
- learning3d/utils/ppfnet_util.py +244 -0
- learning3d/utils/svd.py +59 -0
- learning3d/utils/transformer.py +243 -0
- learning3d-0.0.1.dist-info/LICENSE +21 -0
- learning3d-0.0.1.dist-info/METADATA +271 -0
- learning3d-0.0.1.dist-info/RECORD +115 -0
- learning3d-0.0.1.dist-info/WHEEL +5 -0
- learning3d-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,446 @@
|
|
1
|
+
import torch
|
2
|
+
import torch.nn as nn
|
3
|
+
import torch.nn.functional as F
|
4
|
+
from time import time
|
5
|
+
import numpy as np
|
6
|
+
|
7
|
+
try:
    from .. utils import pointnet2_utils as pointutils
except Exception:
    # Keep the module importable when the pointnet2 helpers cannot be loaded,
    # but no longer swallow KeyboardInterrupt/SystemExit. Binding the name to
    # None makes later uses fail on this placeholder instead of on an
    # undefined global.
    pointutils = None
    print("Error in pointnet2_utils! Retry setup for pointnet2_utils.")
|
11
|
+
|
12
|
+
def timeit(tag, t):
    """Report the time elapsed since ``t`` and hand back a fresh timestamp.

    Input:
        tag: label printed in front of the elapsed time
        t: earlier timestamp obtained from ``time()``
    Return:
        the current ``time()`` value, so calls can be chained
    """
    elapsed = time() - t
    print("{}: {}s".format(tag, elapsed))
    return time()
|
15
|
+
|
16
|
+
def pc_normalize(pc):
    """Centre a point cloud on its centroid and scale it by its largest radius.

    Input:
        pc: numpy array of points, one point per row ([N, C])
    Return:
        array of the same shape whose centroid is the origin and whose
        farthest point lies at distance 1 from it. A degenerate cloud (all
        points identical) is returned centred but unscaled, i.e. all zeros,
        instead of NaN.
    """
    centroid = np.mean(pc, axis=0)
    pc = pc - centroid
    m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
    # Guard the division: m == 0 means every point sits on the centroid.
    if m > 0:
        pc = pc / m
    return pc
|
23
|
+
|
24
|
+
def square_distance(src, dst):
    """
    Pairwise squared Euclidean distance between two batched point sets.

    Uses the expansion
        |s - d|^2 = |s|^2 + |d|^2 - 2 * <s, d>
    so the whole table is produced by one batched matrix product.

    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-point square distance, [B, N, M]
    """
    batch, n_src, _ = src.shape
    _, n_dst, _ = dst.shape
    cross = torch.matmul(src, dst.permute(0, 2, 1))          # [B, N, M]
    src_sq = torch.sum(src ** 2, -1).view(batch, n_src, 1)   # |s|^2 as a column
    dst_sq = torch.sum(dst ** 2, -1).view(batch, 1, n_dst)   # |d|^2 as a row
    return -2 * cross + src_sq + dst_sq
|
44
|
+
|
45
|
+
|
46
|
+
def index_points(points, idx):
    """
    Gather rows of ``points`` per batch element according to ``idx``.

    Input:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S] (extra trailing index dims are allowed)
    Return:
        new_points: indexed points data, [B, S, C] (idx.shape + [C] in general)
    """
    batch = points.shape[0]
    idx_dims = list(idx.shape)
    # Batch ids shaped [B, 1, ..., 1] and tiled up to idx's shape, so that
    # every entry of idx is paired with the batch element it belongs to.
    column_shape = [idx_dims[0]] + [1] * (len(idx_dims) - 1)
    tiling = [1] + idx_dims[1:]
    batch_ids = torch.arange(batch, dtype=torch.long).to(points.device)
    batch_ids = batch_ids.view(column_shape).repeat(tiling)
    return points[batch_ids, idx, :]
|
63
|
+
|
64
|
+
|
65
|
+
def farthest_point_sample(xyz, npoint):
    """
    Iterative farthest point sampling.

    Starting from one random point per batch element, repeatedly picks the
    point whose distance to the already selected set is largest.

    Input:
        xyz: pointcloud data, [B, N, C] (any C, any floating dtype)
        npoint: number of samples
    Return:
        centroids: sampled pointcloud index, [B, npoint] (int64)
    """
    device = xyz.device
    B, N, C = xyz.shape
    # Track distances in the input's own floating dtype so float64 clouds work.
    dist_dtype = xyz.dtype if xyz.is_floating_point() else torch.get_default_dtype()
    centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
    # Distance from every point to the selected set; nothing selected yet.
    distance = torch.full((B, N), float("inf"), dtype=dist_dtype).to(device)
    farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
    batch_indices = torch.arange(B, dtype=torch.long).to(device)
    for i in range(npoint):
        centroids[:, i] = farthest
        # Use the real channel count instead of a hard-coded 3.
        centroid = xyz[batch_indices, farthest, :].view(B, 1, C)
        dist = torch.sum((xyz - centroid) ** 2, -1).to(dist_dtype)
        # Keep, per point, the distance to its nearest selected centroid.
        distance = torch.min(distance, dist)
        farthest = torch.max(distance, -1)[1]
    return centroids
|
87
|
+
|
88
|
+
def knn_point(k, pos1, pos2):
    '''
    Brute-force k-nearest-neighbour search.

    Input:
        k: int, number of neighbours to return
        pos1: (batch_size, ndataset, c) float array, input points
        pos2: (batch_size, npoint, c) float array, query points
    Output:
        val: (batch_size, npoint, k) float array, L2 distances, nearest first
        idx: (batch_size, npoint, k) int64 array, indices to input points
    '''
    # Broadcast [B, 1, N, C] against [B, M, 1, C] instead of materialising two
    # repeated B x M x N x C copies of the inputs.
    diff = pos1.unsqueeze(1) - pos2.unsqueeze(2)
    sq_dist = torch.sum(diff ** 2, -1)                      # [B, M, N]
    val, idx = sq_dist.topk(k=k, dim=-1, largest=False)
    return torch.sqrt(val), idx
|
105
|
+
|
106
|
+
|
107
|
+
def query_ball_point(radius, nsample, xyz, new_xyz):
    """
    Ball query: for every query point, collect up to ``nsample`` indices of
    points lying within ``radius``.

    Input:
        radius: local region radius
        nsample: max sample number in local region
        xyz: all points, [B, N, C]
        new_xyz: query points, [B, S, C]
    Return:
        group_idx: grouped points index, [B, S, nsample]; when fewer than
            nsample points fall inside the ball the remaining slots repeat
            the first in-ball index
        cnt: number of points inside each ball, [B, S]
    """
    device = xyz.device
    B, N, C = xyz.shape
    _, S, _ = new_xyz.shape
    group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
    sqrdists = square_distance(new_xyz, xyz)
    # N is used as an "outside the ball" sentinel; it sorts after every valid index.
    group_idx[sqrdists > radius ** 2] = N
    cnt = (group_idx != N).sum(dim=-1)
    group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
    # With nsample > N the slice is only N wide; pad with the sentinel so the
    # result is always [B, S, nsample] and the fill-in below lines up.
    missing = nsample - group_idx.shape[-1]
    if missing > 0:
        padding = torch.full((B, S, missing), N, dtype=torch.long).to(device)
        group_idx = torch.cat([group_idx, padding], dim=-1)
    group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
    mask = group_idx == N
    # NOTE(review): a query with an empty ball (cnt == 0) keeps the sentinel N
    # in every slot, which is out of range for indexing; callers must check cnt.
    group_idx[mask] = group_first[mask]
    return group_idx, cnt
|
130
|
+
|
131
|
+
|
132
|
+
def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
    """
    Pick ``npoint`` centroids by farthest point sampling and group the
    neighbours of each one with a ball query.

    Input:
        npoint: number of centroids to sample
        radius: ball query radius
        nsample: max number of neighbours per centroid
        xyz: input points position data, [B, N, C]
        points: input points data, [B, N, D], or None
        returnfps: also return the grouped coordinates and the sampled indices
    Return:
        new_xyz: sampled points position data, [B, npoint, C]
        new_points: grouped data, [B, npoint, nsample, C+D] (C when points is None);
            the coordinate part is relative to the group's centroid
        grouped_xyz, fps_idx: only when returnfps is True
    """
    batch, _, channels = xyz.shape
    fps_idx = farthest_point_sample(xyz, npoint)  # [B, npoint]
    new_xyz = index_points(xyz, fps_idx)
    idx, _ = query_ball_point(radius, nsample, xyz, new_xyz)
    grouped_xyz = index_points(xyz, idx)  # [B, npoint, nsample, C]
    # Express every neighbour relative to the centroid of its group.
    grouped_xyz_norm = grouped_xyz - new_xyz.view(batch, npoint, 1, channels)

    new_points = grouped_xyz_norm
    if points is not None:
        neighbour_feats = index_points(points, idx)
        new_points = torch.cat([grouped_xyz_norm, neighbour_feats], dim=-1)

    if returnfps:
        return new_xyz, new_points, grouped_xyz, fps_idx
    return new_xyz, new_points
|
160
|
+
|
161
|
+
|
162
|
+
def sample_and_group_all(xyz, points):
    """
    Treat the whole cloud as a single group centred on the origin.

    Input:
        xyz: input points position data, [B, N, C]
        points: input points data, [B, N, D], or None
    Return:
        new_xyz: sampled points position data, [B, 1, C] (all zeros)
        new_points: sampled points data, [B, 1, N, C+D] (C when points is None)
    """
    batch, n_points, channels = xyz.shape
    new_xyz = torch.zeros(batch, 1, channels).to(xyz.device)
    grouped_xyz = xyz.view(batch, 1, n_points, channels)
    if points is None:
        return new_xyz, grouped_xyz
    grouped_feats = points.view(batch, 1, n_points, -1)
    return new_xyz, torch.cat([grouped_xyz, grouped_feats], dim=-1)
|
180
|
+
|
181
|
+
class PointNetSetAbstraction(nn.Module):
    """Set-abstraction layer: sample centroids, group neighbours, run a shared
    1x1-conv MLP over every group and max-pool each group to one feature."""

    def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
        """
        Input:
            npoint: number of centroids sampled in forward (unused when group_all)
            radius: grouping radius handed to pointutils.QueryAndGroup
            nsample: neighbours per group handed to pointutils.QueryAndGroup
            in_channel: channel count of the incoming point features
            mlp: output widths of the successive 1x1 conv layers
            group_all: use pointutils.GroupAll and skip sampling
        """
        super(PointNetSetAbstraction, self).__init__()
        self.npoint = npoint
        self.radius = radius
        self.nsample = nsample
        self.group_all = group_all
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        # +3 for the xyz channels the grouper adds to the features
        # (see the [B, 3+C, N, S] note in forward).
        last_channel = in_channel + 3
        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1, bias=False))
            self.mlp_bns.append(nn.BatchNorm2d(out_channel))
            last_channel = out_channel

        if group_all:
            self.queryandgroup = pointutils.GroupAll()
        else:
            self.queryandgroup = pointutils.QueryAndGroup(radius, nsample)

    def forward(self, xyz, points):
        """
        Input:
            xyz: input points position data, [B, C, N]
            points: input points data, [B, D, N]
        Return:
            new_xyz: sampled points position data, [B, C, npoint]
                (the input xyz itself when group_all is set)
            new_points: pooled group features, [B, mlp[-1], npoint]
        """
        xyz_t = xyz.permute(0, 2, 1).contiguous()

        # Select the neighbourhood centres. Truthiness is used here so the
        # branch always agrees with the grouper chosen in __init__.
        if not self.group_all:
            fps_idx = pointutils.furthest_point_sample(xyz_t, self.npoint)  # [B, N]
            new_xyz = pointutils.gather_operation(xyz, fps_idx)  # [B, C, N]
        else:
            new_xyz = xyz
        new_points = self.queryandgroup(xyz_t, new_xyz.transpose(2, 1).contiguous(), points)  # [B, 3+C, N, S]

        # new_xyz: sampled points position data, [B, C, npoint]
        # new_points: sampled points data, [B, C+D, npoint, nsample]
        for conv, bn in zip(self.mlp_convs, self.mlp_bns):
            new_points = F.relu(bn(conv(new_points)))

        # Max-pool over the neighbours of each group.
        new_points = torch.max(new_points, -1)[0]
        return new_xyz, new_points
|
232
|
+
|
233
|
+
class FlowEmbedding(nn.Module):
    """Flow-embedding layer: for every point of cloud 1, gather neighbours from
    cloud 2, concatenate position offsets with both feature sets, run a shared
    MLP and max-pool over the neighbours."""

    def __init__(self, radius, nsample, in_channel, mlp, pooling='max', corr_func='concat', knn = True):
        """
        Input:
            radius: ball radius, used only when knn is False
            nsample: number of neighbours gathered per point
            in_channel: channel count of feature1 / feature2
            mlp: output widths of the successive 1x1 conv layers
            pooling: stored on the module; forward always max-pools
            corr_func: correlation function, only 'concat' is implemented
            knn: use k-nearest neighbours instead of ball query
        Raises:
            ValueError: corr_func is not 'concat'
        """
        super(FlowEmbedding, self).__init__()
        # Compare by value, not identity, and fail with a clear message instead
        # of an unbound local further down.
        if corr_func != 'concat':
            raise ValueError(
                "Unsupported corr_func {!r}; only 'concat' is implemented".format(corr_func))
        self.radius = radius
        self.nsample = nsample
        self.knn = knn
        self.pooling = pooling
        self.corr_func = corr_func
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        # Both feature sets plus the 3 position-offset channels.
        last_channel = in_channel * 2 + 3
        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1, bias=False))
            self.mlp_bns.append(nn.BatchNorm2d(out_channel))
            last_channel = out_channel

    def forward(self, pos1, pos2, feature1, feature2):
        """
        Input:
            pos1: (batch_size, 3, npoint)
            pos2: (batch_size, 3, npoint)
            feature1: (batch_size, channel, npoint)
            feature2: (batch_size, channel, npoint)
        Output:
            pos1: (batch_size, 3, npoint), returned unchanged
            feat1_new: (batch_size, mlp[-1], npoint)
        """
        pos1_t = pos1.permute(0, 2, 1).contiguous()
        pos2_t = pos2.permute(0, 2, 1).contiguous()
        B, N, C = pos1_t.shape
        if self.knn:
            _, idx = pointutils.knn(self.nsample, pos1_t, pos2_t)
        else:
            # If the ball neighbourhood holds fewer than nsample points,
            # fall back to the knn neighbourhood for that query point.
            idx_ball, cnt = query_ball_point(self.radius, self.nsample, pos2_t, pos1_t)
            _, idx_knn = pointutils.knn(self.nsample, pos1_t, pos2_t)
            enough = (cnt >= self.nsample).view(B, -1, 1).expand_as(idx_knn)
            # NOTE(review): assumes pointutils.knn returns indices shaped
            # [B, N, nsample]; the ball indices are cast to its dtype - confirm.
            idx = torch.where(enough, idx_ball.to(idx_knn.dtype), idx_knn)

        pos2_grouped = pointutils.grouping_operation(pos2, idx)  # [B, 3, N, S]
        pos_diff = pos2_grouped - pos1.view(B, -1, N, 1)  # [B, 3, N, S]

        feat2_grouped = pointutils.grouping_operation(feature2, idx)  # [B, C, N, S]
        # 'concat' correlation: pair each neighbour feature with the query feature.
        feat_diff = torch.cat(
            [feat2_grouped, feature1.view(B, -1, N, 1).repeat(1, 1, 1, self.nsample)], dim=1)

        feat1_new = torch.cat([pos_diff, feat_diff], dim=1)  # [B, 2*C+3, N, S]
        for conv, bn in zip(self.mlp_convs, self.mlp_bns):
            feat1_new = F.relu(bn(conv(feat1_new)))

        feat1_new = torch.max(feat1_new, -1)[0]  # [B, mlp[-1], npoint]
        return pos1, feat1_new
|
289
|
+
|
290
|
+
class PointNetSetUpConv(nn.Module):
    """Set up-convolution layer: propagate features from a sparser cloud (pos2)
    to a denser one (pos1) by grouping pos2 neighbours around every pos1 point."""

    def __init__(self, nsample, radius, f1_channel, f2_channel, mlp, mlp2, knn = True):
        """
        Input:
            nsample: number of pos2 neighbours gathered per pos1 point
            radius: ball radius, used only when knn is False
            f1_channel: channel count of feature1 (skip features on pos1)
            f2_channel: channel count of feature2 (features on pos2)
            mlp: widths of the 2-D conv layers applied before pooling (may be empty)
            mlp2: widths of the 1-D conv layers applied after the skip concatenation
            knn: use k-nearest neighbours instead of ball query
        """
        super(PointNetSetUpConv, self).__init__()
        self.nsample = nsample
        self.radius = radius
        self.knn = knn
        self.mlp1_convs = nn.ModuleList()
        self.mlp2_convs = nn.ModuleList()
        # Grouped pos2 features plus the 3 position-offset channels.
        last_channel = f2_channel + 3
        for out_channel in mlp:
            self.mlp1_convs.append(nn.Sequential(nn.Conv2d(last_channel, out_channel, 1, bias=False),
                                                 nn.BatchNorm2d(out_channel),
                                                 nn.ReLU(inplace=False)))
            last_channel = out_channel
        # last_channel is now mlp[-1], or still f2_channel + 3 when mlp is
        # empty; either way the skip features are concatenated on top of it.
        last_channel = last_channel + f1_channel
        for out_channel in mlp2:
            self.mlp2_convs.append(nn.Sequential(nn.Conv1d(last_channel, out_channel, 1, bias=False),
                                                 nn.BatchNorm1d(out_channel),
                                                 nn.ReLU(inplace=False)))
            last_channel = out_channel

    def forward(self, pos1, pos2, feature1, feature2):
        """
        Feature propagation from pos2 (less points) to pos1 (more points)
        Inputs:
            pos1: (batch_size, 3, npoint1)
            pos2: (batch_size, 3, npoint2)
            feature1: (batch_size, channel1, npoint1) features for pos1 points
                (earlier layers, more points), or None
            feature2: (batch_size, channel2, npoint2) features for pos2 points
        Output:
            feat_new: (batch_size, C_out, npoint1), where C_out is mlp2[-1] when
                mlp2 is non-empty
        TODO: Add support for skip links. Study how delta(XYZ) plays a role in feature updating.
        """
        pos1_t = pos1.permute(0, 2, 1).contiguous()
        pos2_t = pos2.permute(0, 2, 1).contiguous()
        B, _, N = pos1.shape
        if self.knn:
            _, idx = pointutils.knn(self.nsample, pos1_t, pos2_t)
        else:
            idx, _ = query_ball_point(self.radius, self.nsample, pos2_t, pos1_t)

        pos2_grouped = pointutils.grouping_operation(pos2, idx)
        pos_diff = pos2_grouped - pos1.view(B, -1, N, 1)  # [B,3,N1,S]

        feat2_grouped = pointutils.grouping_operation(feature2, idx)
        feat_new = torch.cat([feat2_grouped, pos_diff], dim=1)  # [B,C1+3,N1,S]
        for conv in self.mlp1_convs:
            feat_new = conv(feat_new)
        # max pooling over the neighbours
        feat_new = feat_new.max(-1)[0]  # [B,mlp1[-1],N1]
        # concatenate feature in early layer
        # NOTE(review): the mlp2 layers are sized for f1_channel skip channels,
        # so feature1=None only matches when f1_channel was given as 0.
        if feature1 is not None:
            feat_new = torch.cat([feat_new, feature1], dim=1)
        for conv in self.mlp2_convs:
            feat_new = conv(feat_new)

        return feat_new
|
351
|
+
|
352
|
+
class PointNetFeaturePropogation(nn.Module):
    """Feature propagation layer: interpolate features from a sparse cloud onto
    a dense one with inverse-distance weights over the three nearest
    neighbours, then refine them with a shared 1-D conv MLP."""

    def __init__(self, in_channel, mlp):
        """
        Input:
            in_channel: channel count entering the MLP (interpolated features
                plus feature1 when it is given)
            mlp: output widths of the successive 1x1 conv layers
        """
        super(PointNetFeaturePropogation, self).__init__()
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        widths = [in_channel] + list(mlp)
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            self.mlp_convs.append(nn.Conv1d(c_in, c_out, 1))
            self.mlp_bns.append(nn.BatchNorm1d(c_out))

    def forward(self, pos1, pos2, feature1, feature2):
        """
        Input:
            pos1: input points position data, [B, C, N]
            pos2: sampled input points position data, [B, C, S]
            feature1: input points data, [B, D, N], or None
            feature2: input points data, [B, D, S]
        Return:
            feat_new: upsampled points data, [B, D', N]
        """
        dense_t = pos1.permute(0, 2, 1).contiguous()
        sparse_t = pos2.permute(0, 2, 1).contiguous()
        B, _, N = pos1.shape

        dists, idx = pointutils.three_nn(dense_t, sparse_t)
        # Floor the distances so the reciprocal below stays finite.
        dists.clamp_(min=1e-10)
        weight = 1.0 / dists
        weight = weight / torch.sum(weight, -1, keepdim=True)  # [B,N,3]
        neighbours = pointutils.grouping_operation(feature2, idx)  # [B,C,N,3]
        interpolated_feat = torch.sum(neighbours * weight.view(B, 1, N, 3), dim=-1)

        feat_new = interpolated_feat
        if feature1 is not None:
            feat_new = torch.cat([interpolated_feat, feature1], 1)

        for conv, bn in zip(self.mlp_convs, self.mlp_bns):
            feat_new = F.relu(bn(conv(feat_new)))
        return feat_new
|
395
|
+
|
396
|
+
|
397
|
+
class FlowNet3D(nn.Module):
    """Scene-flow network: set-abstraction encoders shared by both clouds, one
    flow-embedding layer mixing them, further abstraction on the embedded
    features, then up-convolutions and feature propagation back to the input
    resolution, finished by a per-point 3-channel head."""

    def __init__(self):
        super().__init__()

        # Encoder (sa1/sa2 are applied to both clouds, sa3/sa4 only to cloud 1).
        self.sa1 = PointNetSetAbstraction(npoint=1024, radius=0.5, nsample=16,
                                          in_channel=3, mlp=[32, 32, 64], group_all=False)
        self.sa2 = PointNetSetAbstraction(npoint=256, radius=1.0, nsample=16,
                                          in_channel=64, mlp=[64, 64, 128], group_all=False)
        self.sa3 = PointNetSetAbstraction(npoint=64, radius=2.0, nsample=8,
                                          in_channel=128, mlp=[128, 128, 256], group_all=False)
        self.sa4 = PointNetSetAbstraction(npoint=16, radius=4.0, nsample=8,
                                          in_channel=256, mlp=[256, 256, 512], group_all=False)

        # Mixes the level-2 features of the two clouds.
        self.fe_layer = FlowEmbedding(radius=10.0, nsample=64, in_channel=128,
                                      mlp=[128, 128, 128], pooling='max', corr_func='concat')

        # Decoder.
        self.su1 = PointNetSetUpConv(nsample=8, radius=2.4, f1_channel=256, f2_channel=512,
                                     mlp=[], mlp2=[256, 256])
        self.su2 = PointNetSetUpConv(nsample=8, radius=1.2, f1_channel=128 + 128, f2_channel=256,
                                     mlp=[128, 128, 256], mlp2=[256])
        self.su3 = PointNetSetUpConv(nsample=8, radius=0.6, f1_channel=64, f2_channel=256,
                                     mlp=[128, 128, 256], mlp2=[256])
        self.fp = PointNetFeaturePropogation(in_channel=256 + 3, mlp=[256, 256])

        # Per-point regression head.
        self.conv1 = nn.Conv1d(256, 128, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm1d(128)
        self.conv2 = nn.Conv1d(128, 3, kernel_size=1, bias=True)

    def forward(self, pc1, pc2, feature1, feature2):
        """
        Input:
            pc1, pc2: point positions of the two clouds, passed to sa1 as xyz
            feature1, feature2: per-point input features of the two clouds,
                passed to sa1 as points (sa1 is built with in_channel=3)
        Return:
            sf: output of the final 3-channel Conv1d, one vector per pc1 point
        """
        # Cloud 1 encoder, levels 1-2.
        l1_pc1, l1_feature1 = self.sa1(pc1, feature1)
        l2_pc1, l2_feature1 = self.sa2(l1_pc1, l1_feature1)

        # Cloud 2 encoder, same layers.
        l1_pc2, l1_feature2 = self.sa1(pc2, feature2)
        l2_pc2, l2_feature2 = self.sa2(l1_pc2, l1_feature2)

        # Flow embedding between the two level-2 clouds.
        _, l2_feature1_new = self.fe_layer(l2_pc1, l2_pc2, l2_feature1, l2_feature2)

        # Deeper abstraction of the embedded features.
        l3_pc1, l3_feature1 = self.sa3(l2_pc1, l2_feature1_new)
        l4_pc1, l4_feature1 = self.sa4(l3_pc1, l3_feature1)

        # Decoder with skip connections back up to the input points.
        l3_fnew1 = self.su1(l3_pc1, l4_pc1, l3_feature1, l4_feature1)
        l2_skip = torch.cat([l2_feature1, l2_feature1_new], dim=1)
        l2_fnew1 = self.su2(l2_pc1, l3_pc1, l2_skip, l3_fnew1)
        l1_fnew1 = self.su3(l1_pc1, l2_pc1, l1_feature1, l2_fnew1)
        l0_fnew1 = self.fp(pc1, l1_pc1, feature1, l1_fnew1)

        hidden = F.relu(self.bn1(self.conv1(l0_fnew1)))
        return self.conv2(hidden)
|
437
|
+
|
438
|
+
if __name__ == '__main__':
    # Smoke test: push one random batch through the network and print the
    # output size.
    import os
    import torch
    os.environ["CUDA_VISIBLE_DEVICES"] = '0'
    points = torch.randn((8, 3, 2048))
    model = FlowNet3D()
    # NOTE(review): the pointnet2 ops are presumably CUDA-only, so run on the
    # GPU when one is visible - confirm against pointnet2_utils.
    if torch.cuda.is_available():
        points = points.cuda()
        model = model.cuda()
    # forward() takes both clouds and both feature sets; the xyz coordinates
    # double as the 3-channel input features here.
    output = model(points, points, points, points)
    print(output.size())
|
@@ -0,0 +1,84 @@
|
|
1
|
+
import torch
|
2
|
+
import torch.nn as nn
|
3
|
+
import torch.nn.functional as F
|
4
|
+
from .pointnet import PointNet
|
5
|
+
from .pooling import Pooling
|
6
|
+
|
7
|
+
class PointNetMask(nn.Module):
    """Predicts a per-point mask over the template: every template point
    feature is concatenated with the pooled source feature and scored by a
    1x1-conv MLP ending in a sigmoid."""

    def __init__(self, template_feature_size=1024, source_feature_size=1024, feature_model=None):
        """
        Input:
            template_feature_size: channels of the per-point template features
            source_feature_size: channels of the pooled source feature
            feature_model: feature extractor applied to both clouds; a fresh
                PointNet() is built when omitted, so instances no longer share
                one module created at class-definition time
        """
        super().__init__()
        if feature_model is None:
            feature_model = PointNet()
        self.feature_model = feature_model
        self.pooling = Pooling()

        input_size = template_feature_size + source_feature_size
        self.h3 = nn.Sequential(nn.Conv1d(input_size, 1024, 1), nn.ReLU(),
                                nn.Conv1d(1024, 512, 1), nn.ReLU(),
                                nn.Conv1d(512, 256, 1), nn.ReLU(),
                                nn.Conv1d(256, 128, 1), nn.ReLU(),
                                nn.Conv1d(128, 1, 1), nn.Sigmoid())

    def find_mask(self, x, t_out_h1):
        """
        Input:
            x: pooled source feature, [B, C_source]
            t_out_h1: per-point template features, [B, C_template, N]
        Return:
            mask values in (0, 1), [B, N]
        """
        batch_size, _ , num_points = t_out_h1.size()
        # Copy the global source feature next to every template point.
        x = x.unsqueeze(2)
        x = x.repeat(1,1,num_points)
        x = torch.cat([t_out_h1, x], dim=1)
        x = self.h3(x)
        return x.view(batch_size, -1)

    def forward(self, template, source):
        """
        Input:
            template, source: point clouds handed to the feature model
        Return:
            mask: one value per template point, [B, N]
        """
        source_features = self.feature_model(source)        # [B x C x N]
        template_features = self.feature_model(template)    # [B x C x N]

        source_features = self.pooling(source_features)
        mask = self.find_mask(source_features, template_features)
        return mask
|
35
|
+
|
36
|
+
|
37
|
+
class MaskNet(nn.Module):
    """Wraps PointNetMask and uses the predicted mask to select template points."""

    def __init__(self, feature_model=None, is_training=True):
        """
        Input:
            feature_model: feature extractor for PointNetMask; a fresh
                PointNet(use_bn=True) is built when omitted, so instances no
                longer share one module created at class-definition time
            is_training: when True, forward always uses top-k selection
        """
        super().__init__()
        if feature_model is None:
            feature_model = PointNet(use_bn=True)
        self.maskNet = PointNetMask(feature_model=feature_model)
        self.is_training = is_training

    @staticmethod
    def index_points(points, idx):
        """
        Input:
            points: input points data, [B, N, C]
            idx: sample index data, [B, S]
        Return:
            new_points:, indexed points data, [B, S, C]
        """
        device = points.device
        B = points.shape[0]
        view_shape = list(idx.shape)
        view_shape[1:] = [1] * (len(view_shape) - 1)
        repeat_shape = list(idx.shape)
        repeat_shape[0] = 1
        # Batch ids tiled to idx's shape so each index is paired with its batch row.
        batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
        new_points = points[batch_indices, idx, :]
        return new_points

    # This function is only useful for testing with a single pair of point clouds.
    @staticmethod
    def find_index(mask_val):
        """Return the indices of the first batch element whose mask value
        exceeds 0.5, shaped [1, K]."""
        mask_idx = torch.nonzero(mask_val[0] > 0.5)
        return mask_idx.view(1, -1)

    def forward(self, template, source, point_selection='threshold'):
        """
        Input:
            template: template point cloud, [B, N, C]
            source: source point cloud; source.shape[1] points are kept in top-k mode
            point_selection: 'topk' or 'threshold' (ignored while is_training is True)
        Return:
            template: the selected template points
            mask: predicted mask values, [B, N]
        Raises:
            ValueError: unknown point_selection outside training mode
        """
        mask = self.maskNet(template, source)

        if point_selection == 'topk' or self.is_training:
            _, self.mask_idx = torch.topk(mask, source.shape[1], dim=1, sorted=False)
        elif point_selection == 'threshold':
            self.mask_idx = self.find_index(mask)
        else:
            # Previously this fell through and reused a stale (or missing) mask_idx.
            raise ValueError(
                "Unknown point_selection {!r}; expected 'topk' or 'threshold'".format(point_selection))

        template = self.index_points(template, self.mask_idx)
        return template, mask
|
78
|
+
|
79
|
+
|
80
|
+
if __name__ == '__main__':
    # Smoke test: run one random batch through MaskNet and report the shapes.
    # The ipdb breakpoint is gone so the script no longer needs that package.
    template, source = torch.rand(10,1024,3), torch.rand(10,1024,3)
    net = MaskNet()
    masked_template, mask = net(template, source)
    print(masked_template.shape, mask.shape)
|