learning3d 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- learning3d/__init__.py +2 -0
- learning3d/data_utils/__init__.py +4 -0
- learning3d/data_utils/dataloaders.py +454 -0
- learning3d/data_utils/user_data.py +119 -0
- learning3d/examples/test_dcp.py +139 -0
- learning3d/examples/test_deepgmr.py +144 -0
- learning3d/examples/test_flownet.py +113 -0
- learning3d/examples/test_masknet.py +159 -0
- learning3d/examples/test_masknet2.py +162 -0
- learning3d/examples/test_pcn.py +118 -0
- learning3d/examples/test_pcrnet.py +120 -0
- learning3d/examples/test_pnlk.py +121 -0
- learning3d/examples/test_pointconv.py +126 -0
- learning3d/examples/test_pointnet.py +121 -0
- learning3d/examples/test_prnet.py +126 -0
- learning3d/examples/test_rpmnet.py +120 -0
- learning3d/examples/train_PointNetLK.py +240 -0
- learning3d/examples/train_dcp.py +249 -0
- learning3d/examples/train_deepgmr.py +244 -0
- learning3d/examples/train_flownet.py +259 -0
- learning3d/examples/train_masknet.py +239 -0
- learning3d/examples/train_pcn.py +216 -0
- learning3d/examples/train_pcrnet.py +228 -0
- learning3d/examples/train_pointconv.py +245 -0
- learning3d/examples/train_pointnet.py +244 -0
- learning3d/examples/train_prnet.py +229 -0
- learning3d/examples/train_rpmnet.py +228 -0
- learning3d/losses/__init__.py +12 -0
- learning3d/losses/chamfer_distance.py +51 -0
- learning3d/losses/classification.py +14 -0
- learning3d/losses/correspondence_loss.py +10 -0
- learning3d/losses/cuda/chamfer_distance/__init__.py +1 -0
- learning3d/losses/cuda/chamfer_distance/chamfer_distance.cpp +185 -0
- learning3d/losses/cuda/chamfer_distance/chamfer_distance.cu +209 -0
- learning3d/losses/cuda/chamfer_distance/chamfer_distance.py +66 -0
- learning3d/losses/cuda/emd_torch/pkg/emd_loss_layer.py +41 -0
- learning3d/losses/cuda/emd_torch/pkg/include/cuda/emd.cuh +347 -0
- learning3d/losses/cuda/emd_torch/pkg/include/cuda_helper.h +18 -0
- learning3d/losses/cuda/emd_torch/pkg/include/emd.h +54 -0
- learning3d/losses/cuda/emd_torch/pkg/layer/__init__.py +1 -0
- learning3d/losses/cuda/emd_torch/pkg/layer/emd_loss_layer.py +40 -0
- learning3d/losses/cuda/emd_torch/pkg/src/cuda/emd.cu +70 -0
- learning3d/losses/cuda/emd_torch/pkg/src/emd.cpp +1 -0
- learning3d/losses/cuda/emd_torch/setup.py +29 -0
- learning3d/losses/emd.py +16 -0
- learning3d/losses/frobenius_norm.py +21 -0
- learning3d/losses/rmse_features.py +16 -0
- learning3d/models/__init__.py +23 -0
- learning3d/models/classifier.py +41 -0
- learning3d/models/dcp.py +92 -0
- learning3d/models/deepgmr.py +165 -0
- learning3d/models/dgcnn.py +92 -0
- learning3d/models/flownet3d.py +446 -0
- learning3d/models/masknet.py +84 -0
- learning3d/models/masknet2.py +264 -0
- learning3d/models/pcn.py +164 -0
- learning3d/models/pcrnet.py +74 -0
- learning3d/models/pointconv.py +108 -0
- learning3d/models/pointnet.py +108 -0
- learning3d/models/pointnetlk.py +173 -0
- learning3d/models/pooling.py +15 -0
- learning3d/models/ppfnet.py +102 -0
- learning3d/models/prnet.py +431 -0
- learning3d/models/rpmnet.py +359 -0
- learning3d/models/segmentation.py +38 -0
- learning3d/ops/__init__.py +0 -0
- learning3d/ops/data_utils.py +45 -0
- learning3d/ops/invmat.py +134 -0
- learning3d/ops/quaternion.py +218 -0
- learning3d/ops/se3.py +157 -0
- learning3d/ops/sinc.py +229 -0
- learning3d/ops/so3.py +213 -0
- learning3d/ops/transform_functions.py +342 -0
- learning3d/utils/__init__.py +9 -0
- learning3d/utils/lib/build/lib.linux-x86_64-3.5/pointnet2_cuda.cpython-35m-x86_64-linux-gnu.so +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/ball_query.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/ball_query_gpu.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/group_points.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/group_points_gpu.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/interpolate.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/interpolate_gpu.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/pointnet2_api.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/sampling.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/sampling_gpu.o +0 -0
- learning3d/utils/lib/dist/pointnet2-0.0.0-py3.5-linux-x86_64.egg +0 -0
- learning3d/utils/lib/pointnet2.egg-info/SOURCES.txt +14 -0
- learning3d/utils/lib/pointnet2.egg-info/dependency_links.txt +1 -0
- learning3d/utils/lib/pointnet2.egg-info/top_level.txt +1 -0
- learning3d/utils/lib/pointnet2_modules.py +160 -0
- learning3d/utils/lib/pointnet2_utils.py +318 -0
- learning3d/utils/lib/pytorch_utils.py +236 -0
- learning3d/utils/lib/setup.py +23 -0
- learning3d/utils/lib/src/ball_query.cpp +25 -0
- learning3d/utils/lib/src/ball_query_gpu.cu +67 -0
- learning3d/utils/lib/src/ball_query_gpu.h +15 -0
- learning3d/utils/lib/src/cuda_utils.h +15 -0
- learning3d/utils/lib/src/group_points.cpp +36 -0
- learning3d/utils/lib/src/group_points_gpu.cu +86 -0
- learning3d/utils/lib/src/group_points_gpu.h +22 -0
- learning3d/utils/lib/src/interpolate.cpp +65 -0
- learning3d/utils/lib/src/interpolate_gpu.cu +233 -0
- learning3d/utils/lib/src/interpolate_gpu.h +36 -0
- learning3d/utils/lib/src/pointnet2_api.cpp +25 -0
- learning3d/utils/lib/src/sampling.cpp +46 -0
- learning3d/utils/lib/src/sampling_gpu.cu +253 -0
- learning3d/utils/lib/src/sampling_gpu.h +29 -0
- learning3d/utils/pointconv_util.py +382 -0
- learning3d/utils/ppfnet_util.py +244 -0
- learning3d/utils/svd.py +59 -0
- learning3d/utils/transformer.py +243 -0
- learning3d-0.0.1.dist-info/LICENSE +21 -0
- learning3d-0.0.1.dist-info/METADATA +271 -0
- learning3d-0.0.1.dist-info/RECORD +115 -0
- learning3d-0.0.1.dist-info/WHEEL +5 -0
- learning3d-0.0.1.dist-info/top_level.txt +1 -0
learning3d/utils/pointconv_util.py
ADDED
@@ -0,0 +1,382 @@
"""
Utility functions for PointConv.
Originally from: https://github.com/yanx27/Pointnet_Pointnet2_pytorch/blob/master/utils.py
Modified by Wenxuan Wu
Date: September 2019
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from time import time
import numpy as np
from sklearn.neighbors import KernelDensity  # public import path; sklearn.neighbors._kde is private and breaks across versions


def timeit(tag, t):
    print("{}: {}s".format(tag, time() - t))
    return time()


def square_distance(src, dst):
    """
    Calculate the squared Euclidean distance between each pair of points.

    src^T * dst = xn * xm + yn * ym + zn * zm;
    sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
    sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
    dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
         = sum(src**2, dim=-1) + sum(dst**2, dim=-1) - 2 * src^T * dst

    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-point squared distance, [B, N, M]
    """
    B, N, _ = src.shape
    _, M, _ = dst.shape
    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
    dist += torch.sum(src ** 2, -1).view(B, N, 1)
    dist += torch.sum(dst ** 2, -1).view(B, 1, M)
    return dist


def index_points(points, idx):
    """
    Input:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S]
    Return:
        new_points: indexed points data, [B, S, C]
    """
    device = points.device
    B = points.shape[0]
    view_shape = list(idx.shape)
    view_shape[1:] = [1] * (len(view_shape) - 1)
    repeat_shape = list(idx.shape)
    repeat_shape[0] = 1
    batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
    new_points = points[batch_indices, idx, :]
    return new_points


def farthest_point_sample(xyz, npoint):
    """
    Input:
        xyz: pointcloud data, [B, N, C]
        npoint: number of samples
    Return:
        centroids: sampled pointcloud index, [B, npoint]
    """
    device = xyz.device
    B, N, C = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
    distance = torch.ones(B, N).to(device) * 1e10
    # farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
    farthest = torch.zeros(B, dtype=torch.long).to(device)
    batch_indices = torch.arange(B, dtype=torch.long).to(device)
    for i in range(npoint):
        centroids[:, i] = farthest
        centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
        dist = torch.sum((xyz - centroid) ** 2, -1)
        mask = dist < distance
        distance[mask] = dist[mask]
        farthest = torch.max(distance, -1)[1]
    return centroids


def query_ball_point(radius, nsample, xyz, new_xyz):
    """
    Input:
        radius: local region radius
        nsample: max sample number in local region
        xyz: all points, [B, N, C]
        new_xyz: query points, [B, S, C]
    Return:
        group_idx: grouped points index, [B, S, nsample]
    """
    device = xyz.device
    B, N, C = xyz.shape
    _, S, _ = new_xyz.shape
    group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
    sqrdists = square_distance(new_xyz, xyz)
    group_idx[sqrdists > radius ** 2] = N  # mark out-of-radius points with the sentinel index N
    group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
    group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
    mask = group_idx == N
    group_idx[mask] = group_first[mask]  # pad short neighborhoods with the closest index
    return group_idx


def knn_point(nsample, xyz, new_xyz):
    """
    Input:
        nsample: max sample number in local region
        xyz: all points, [B, N, C]
        new_xyz: query points, [B, S, C]
    Return:
        group_idx: grouped points index, [B, S, nsample]
    """
    sqrdists = square_distance(new_xyz, xyz)
    _, group_idx = torch.topk(sqrdists, nsample, dim=-1, largest=False, sorted=False)
    return group_idx


def sample_and_group(npoint, nsample, xyz, points, density_scale=None):
    """
    Input:
        npoint: number of points to sample
        nsample: number of neighbors per sampled point
        xyz: input points position data, [B, N, C]
        points: input points data, [B, N, D]
    Return:
        new_xyz: sampled points position data, [B, npoint, C]
        new_points: sampled points data, [B, npoint, nsample, C+D]
    """
    B, N, C = xyz.shape
    S = npoint
    fps_idx = farthest_point_sample(xyz, npoint)  # [B, npoint]
    new_xyz = index_points(xyz, fps_idx)
    idx = knn_point(nsample, xyz, new_xyz)
    grouped_xyz = index_points(xyz, idx)  # [B, npoint, nsample, C]
    grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)
    if points is not None:
        grouped_points = index_points(points, idx)
        new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1)  # [B, npoint, nsample, C+D]
    else:
        new_points = grouped_xyz_norm

    if density_scale is None:
        return new_xyz, new_points, grouped_xyz_norm, idx
    else:
        grouped_density = index_points(density_scale, idx)
        return new_xyz, new_points, grouped_xyz_norm, idx, grouped_density


def sample_and_group_all(xyz, points, density_scale=None):
    """
    Input:
        xyz: input points position data, [B, N, C]
        points: input points data, [B, N, D]
    Return:
        new_xyz: sampled points position data, [B, 1, C]
        new_points: sampled points data, [B, 1, N, C+D]
    """
    device = xyz.device
    B, N, C = xyz.shape
    # new_xyz = torch.zeros(B, 1, C).to(device)
    new_xyz = xyz.mean(dim=1, keepdim=True)
    grouped_xyz = xyz.view(B, 1, N, C) - new_xyz.view(B, 1, 1, C)
    if points is not None:
        new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)
    else:
        new_points = grouped_xyz
    if density_scale is None:
        return new_xyz, new_points, grouped_xyz
    else:
        grouped_density = density_scale.view(B, 1, N, 1)
        return new_xyz, new_points, grouped_xyz, grouped_density


def group(nsample, xyz, points):
    """
    Group every point with its nsample nearest neighbors (no subsampling).

    Input:
        nsample: number of neighbors per point
        xyz: input points position data, [B, N, C]
        points: input points data, [B, N, D]
    Return:
        new_points: grouped points data, [B, N, nsample, C+D]
        grouped_xyz_norm: neighbor offsets, [B, N, nsample, C]
    """
    B, N, C = xyz.shape
    S = N
    new_xyz = xyz
    idx = knn_point(nsample, xyz, new_xyz)
    grouped_xyz = index_points(xyz, idx)  # [B, N, nsample, C]
    grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)
    if points is not None:
        grouped_points = index_points(points, idx)
        new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1)  # [B, N, nsample, C+D]
    else:
        new_points = grouped_xyz_norm

    return new_points, grouped_xyz_norm


def compute_density(xyz, bandwidth):
    '''
    Estimate per-point density with an isotropic Gaussian kernel.
    xyz: input points position data, [B, N, C]
    '''
    B, N, C = xyz.shape
    sqrdists = square_distance(xyz, xyz)
    gaussian_density = torch.exp(-sqrdists / (2.0 * bandwidth * bandwidth)) / (2.5 * bandwidth)
    xyz_density = gaussian_density.mean(dim=-1)

    return xyz_density


class DensityNet(nn.Module):
    def __init__(self, hidden_unit=[16, 8]):
        super(DensityNet, self).__init__()
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()

        self.mlp_convs.append(nn.Conv2d(1, hidden_unit[0], 1))
        self.mlp_bns.append(nn.BatchNorm2d(hidden_unit[0]))
        for i in range(1, len(hidden_unit)):
            self.mlp_convs.append(nn.Conv2d(hidden_unit[i - 1], hidden_unit[i], 1))
            self.mlp_bns.append(nn.BatchNorm2d(hidden_unit[i]))
        self.mlp_convs.append(nn.Conv2d(hidden_unit[-1], 1, 1))
        self.mlp_bns.append(nn.BatchNorm2d(1))

    def forward(self, density_scale):
        for i, conv in enumerate(self.mlp_convs):
            bn = self.mlp_bns[i]
            density_scale = bn(conv(density_scale))
            # Sigmoid on the last layer. The original tested `i == len(self.mlp_convs)`,
            # which can never be true inside the loop, so the sigmoid was never applied.
            if i == len(self.mlp_convs) - 1:
                density_scale = torch.sigmoid(density_scale)
            else:
                density_scale = F.relu(density_scale)

        return density_scale


class WeightNet(nn.Module):

    def __init__(self, in_channel, out_channel, hidden_unit=[8, 8]):
        super(WeightNet, self).__init__()

        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        if hidden_unit is None or len(hidden_unit) == 0:
            self.mlp_convs.append(nn.Conv2d(in_channel, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm2d(out_channel))
        else:
            self.mlp_convs.append(nn.Conv2d(in_channel, hidden_unit[0], 1))
            self.mlp_bns.append(nn.BatchNorm2d(hidden_unit[0]))
            for i in range(1, len(hidden_unit)):
                self.mlp_convs.append(nn.Conv2d(hidden_unit[i - 1], hidden_unit[i], 1))
                self.mlp_bns.append(nn.BatchNorm2d(hidden_unit[i]))
            self.mlp_convs.append(nn.Conv2d(hidden_unit[-1], out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm2d(out_channel))

    def forward(self, localized_xyz):
        # xyz: B x C x K x N
        weights = localized_xyz
        for i, conv in enumerate(self.mlp_convs):
            bn = self.mlp_bns[i]
            weights = F.relu(bn(conv(weights)))

        return weights


class PointConvSetAbstraction(nn.Module):
    def __init__(self, npoint, nsample, in_channel, mlp, group_all):
        super(PointConvSetAbstraction, self).__init__()
        self.npoint = npoint
        self.nsample = nsample
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = in_channel
        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm2d(out_channel))
            last_channel = out_channel

        self.weightnet = WeightNet(3, 16)
        self.linear = nn.Linear(16 * mlp[-1], mlp[-1])
        self.bn_linear = nn.BatchNorm1d(mlp[-1])
        self.group_all = group_all

    def forward(self, xyz, points):
        """
        Input:
            xyz: input points position data, [B, C, N]
            points: input points data, [B, D, N]
        Return:
            new_xyz: sampled points position data, [B, C, S]
            new_points_concat: sample points feature data, [B, D', S]
        """
        B = xyz.shape[0]
        xyz = xyz.permute(0, 2, 1)
        if points is not None:
            points = points.permute(0, 2, 1)

        if self.group_all:
            new_xyz, new_points, grouped_xyz_norm = sample_and_group_all(xyz, points)
        else:
            new_xyz, new_points, grouped_xyz_norm, _ = sample_and_group(self.npoint, self.nsample, xyz, points)
        # new_xyz: sampled points position data, [B, npoint, C]
        # new_points: sampled points data, [B, npoint, nsample, C+D]
        new_points = new_points.permute(0, 3, 2, 1)  # [B, C+D, nsample, npoint]
        for i, conv in enumerate(self.mlp_convs):
            bn = self.mlp_bns[i]
            new_points = F.relu(bn(conv(new_points)))

        # PointConv core: per-neighbor weights are predicted from the local
        # coordinates, then features are aggregated by a weighted matmul.
        grouped_xyz = grouped_xyz_norm.permute(0, 3, 2, 1)
        weights = self.weightnet(grouped_xyz)
        new_points = torch.matmul(input=new_points.permute(0, 3, 1, 2), other=weights.permute(0, 3, 2, 1)).view(B, self.npoint, -1)
        new_points = self.linear(new_points)
        new_points = self.bn_linear(new_points.permute(0, 2, 1))
        new_points = F.relu(new_points)
        new_xyz = new_xyz.permute(0, 2, 1)

        return new_xyz, new_points


class PointConvDensitySetAbstraction(nn.Module):
    def __init__(self, npoint, nsample, in_channel, mlp, bandwidth, group_all):
        super(PointConvDensitySetAbstraction, self).__init__()
        self.npoint = npoint
        self.nsample = nsample
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = in_channel
        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm2d(out_channel))
            last_channel = out_channel

        self.weightnet = WeightNet(3, 16)
        self.linear = nn.Linear(16 * mlp[-1], mlp[-1])
        self.bn_linear = nn.BatchNorm1d(mlp[-1])
        self.densitynet = DensityNet()
        self.group_all = group_all
        self.bandwidth = bandwidth

    def forward(self, xyz, points):
        """
        Input:
            xyz: input points position data, [B, C, N]
            points: input points data, [B, D, N]
        Return:
            new_xyz: sampled points position data, [B, C, S]
            new_points_concat: sample points feature data, [B, D', S]
        """
        B = xyz.shape[0]
        N = xyz.shape[2]
        xyz = xyz.permute(0, 2, 1)
        if points is not None:
            points = points.permute(0, 2, 1)

        # Inverse density compensates for non-uniform sampling of the cloud.
        xyz_density = compute_density(xyz, self.bandwidth)
        inverse_density = 1.0 / xyz_density

        if self.group_all:
            new_xyz, new_points, grouped_xyz_norm, grouped_density = sample_and_group_all(xyz, points, inverse_density.view(B, N, 1))
        else:
            new_xyz, new_points, grouped_xyz_norm, _, grouped_density = sample_and_group(self.npoint, self.nsample, xyz, points, inverse_density.view(B, N, 1))
        # new_xyz: sampled points position data, [B, npoint, C]
        # new_points: sampled points data, [B, npoint, nsample, C+D]
        new_points = new_points.permute(0, 3, 2, 1)  # [B, C+D, nsample, npoint]
        for i, conv in enumerate(self.mlp_convs):
            bn = self.mlp_bns[i]
            new_points = F.relu(bn(conv(new_points)))

        # Normalize the density scale per neighborhood, then learn a correction.
        inverse_max_density = grouped_density.max(dim=2, keepdim=True)[0]
        density_scale = grouped_density / inverse_max_density
        density_scale = self.densitynet(density_scale.permute(0, 3, 2, 1))
        new_points = new_points * density_scale

        grouped_xyz = grouped_xyz_norm.permute(0, 3, 2, 1)
        weights = self.weightnet(grouped_xyz)
        new_points = torch.matmul(input=new_points.permute(0, 3, 1, 2), other=weights.permute(0, 3, 2, 1)).view(B, self.npoint, -1)
        new_points = self.linear(new_points)
        new_points = self.bn_linear(new_points.permute(0, 2, 1))
        new_points = F.relu(new_points)
        new_xyz = new_xyz.permute(0, 2, 1)

        return new_xyz, new_points
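For orientation, here is a minimal smoke test for the utilities above. It is a sketch, not part of the package: it assumes the file is importable as learning3d.utils.pointconv_util, and all sizes (point counts, channel widths, bandwidth) are illustrative.

# Hypothetical usage sketch for pointconv_util.py; not shipped with the wheel.
import torch
from learning3d.utils.pointconv_util import compute_density, PointConvDensitySetAbstraction

B, N = 2, 1024
xyz = torch.rand(B, 3, N)            # coordinates, channels-first [B, C, N] as forward() expects
feats = xyz.clone()                  # reuse coordinates as dummy per-point features

# Per-point Gaussian density estimate; operates on [B, N, C] clouds.
density = compute_density(xyz.permute(0, 2, 1), 0.1)   # [B, N]

sa = PointConvDensitySetAbstraction(npoint=256, nsample=32,
                                    in_channel=3 + 3,  # 3 neighbor offsets + 3 feature dims
                                    mlp=[32, 64], bandwidth=0.1, group_all=False)
sa.eval()                            # BatchNorm in eval mode for a single forward pass
with torch.no_grad():
    new_xyz, new_points = sa(xyz, feats)
print(new_xyz.shape, new_points.shape)   # torch.Size([2, 3, 256]) torch.Size([2, 64, 256])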
learning3d/utils/ppfnet_util.py
ADDED
@@ -0,0 +1,244 @@
"""Utilities for PointNet related functions

Modified from:
    Pytorch Implementation of PointNet and PointNet++
    https://github.com/yanx27/Pointnet_Pointnet2_pytorch
"""

import torch


def angle_difference(src, dst):
    """Calculate the angle between each pair of vectors.
    Assumes points are l2-normalized to unit length.

    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-pair angle, [B, N, M]
    """
    B, N, _ = src.shape
    _, M, _ = dst.shape
    dist = torch.matmul(src, dst.permute(0, 2, 1))
    # Clamp to the valid acos domain to avoid NaNs from floating-point error.
    dist = torch.acos(torch.clamp(dist, -1.0, 1.0))

    return dist


def square_distance(src, dst):
    """Calculate the squared Euclidean distance between each two points.

    src^T * dst = xn * xm + yn * ym + zn * zm;
    sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
    sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
    dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
         = sum(src**2, dim=-1) + sum(dst**2, dim=-1) - 2 * src^T * dst

    Args:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Returns:
        dist: per-point squared distance, [B, N, M]
    """
    B, N, _ = src.shape
    _, M, _ = dst.shape
    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
    dist += torch.sum(src ** 2, dim=-1)[:, :, None]
    dist += torch.sum(dst ** 2, dim=-1)[:, None, :]
    return dist


def index_points(points, idx):
    """Array indexing, i.e. retrieves relevant points based on indices

    Args:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S]. S can be 2 dimensional
    Returns:
        new_points: indexed points data, [B, S, C]
    """
    device = points.device
    B = points.shape[0]
    view_shape = list(idx.shape)
    view_shape[1:] = [1] * (len(view_shape) - 1)
    repeat_shape = list(idx.shape)
    repeat_shape[0] = 1
    batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
    new_points = points[batch_indices, idx, :]
    return new_points


def farthest_point_sample(xyz, npoint):
    """Iterative farthest point sampling

    Args:
        xyz: pointcloud data, [B, N, C]
        npoint: number of samples
    Returns:
        centroids: sampled pointcloud index, [B, npoint]
    """
    device = xyz.device
    B, N, C = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
    distance = torch.ones(B, N).to(device) * 1e10
    farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
    batch_indices = torch.arange(B, dtype=torch.long).to(device)
    for i in range(npoint):
        centroids[:, i] = farthest
        centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
        dist = torch.sum((xyz - centroid) ** 2, -1)
        mask = dist < distance
        distance[mask] = dist[mask]
        farthest = torch.max(distance, -1)[1]
    return centroids


def query_ball_point(radius, nsample, xyz, new_xyz, itself_indices=None):
    """ Grouping layer in PointNet++.

    Inputs:
        radius: local region radius
        nsample: max sample number in local region
        xyz: all points, (B, N, C)
        new_xyz: query points, (B, S, C)
        itself_indices (Optional): Indices of new_xyz into xyz (B, S).
            Used to try and prevent grouping the point itself into the neighborhood.
            If there are insufficient points in the neighborhood, or if itself_indices
            is None, the resulting cluster will still contain the center point.
    Returns:
        group_idx: grouped points index, [B, S, nsample]
    """
    device = xyz.device
    B, N, C = xyz.shape
    _, S, _ = new_xyz.shape
    group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])  # (B, S, N)
    sqrdists = square_distance(new_xyz, xyz)

    if itself_indices is not None:
        # Remove indices of the center points so they will not be chosen
        batch_indices = torch.arange(B, dtype=torch.long).to(device)[:, None].repeat(1, S)  # (B, S)
        row_indices = torch.arange(S, dtype=torch.long).to(device)[None, :].repeat(B, 1)  # (B, S)
        group_idx[batch_indices, row_indices, itself_indices] = N

    group_idx[sqrdists > radius ** 2] = N
    group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
    if itself_indices is not None:
        group_first = itself_indices[:, :, None].repeat([1, 1, nsample])
    else:
        group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
    mask = group_idx == N
    group_idx[mask] = group_first[mask]
    return group_idx


def sample_and_group(npoint: int, radius: float, nsample: int, xyz: torch.Tensor, points: torch.Tensor,
                     returnfps: bool = False):
    """
    Args:
        npoint (int): Number of points to sample. Set to negative to compute for all points
        radius (float): Radius of the ball query
        nsample (int): Maximum number of points per cluster
        xyz: input points position data, [B, N, C]
        points: input points data, [B, N, D]
        returnfps (bool): Whether to return farthest point indices
    Returns:
        new_xyz: sampled points position data, [B, npoint, C]
        new_points: sampled points data, [B, npoint, nsample, C+D]
    """
    B, N, C = xyz.shape

    if npoint > 0:
        S = npoint
        fps_idx = farthest_point_sample(xyz, npoint)  # [B, npoint]
        new_xyz = index_points(xyz, fps_idx)
    else:
        S = xyz.shape[1]
        # Keep the identity indices on the same device as xyz (the original omitted .to()).
        fps_idx = torch.arange(0, xyz.shape[1])[None, ...].repeat(xyz.shape[0], 1).to(xyz.device)
        new_xyz = xyz

    idx = query_ball_point(radius, nsample, xyz, new_xyz)  # (B, npoint, nsample)
    grouped_xyz = index_points(xyz, idx)  # (B, npoint, nsample, C)
    grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)
    if points is not None:
        grouped_points = index_points(points, idx)
        new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1)  # [B, npoint, nsample, C+D]
    else:
        new_points = grouped_xyz_norm
    if returnfps:
        return new_xyz, new_points, grouped_xyz, fps_idx
    else:
        return new_xyz, new_points


def angle(v1: torch.Tensor, v2: torch.Tensor):
    """Compute angle between 2 vectors

    For robustness, we use the same formulation as in PPFNet, i.e.
        angle(v1, v2) = atan2(cross(v1, v2), dot(v1, v2)).
    This handles the case where one of the vectors is 0.0, since torch.atan2(0.0, 0.0)=0.0

    Args:
        v1: (B, *, 3)
        v2: (B, *, 3)

    Returns:
        Angles in radians, (B, *)
    """

    cross_prod = torch.stack([v1[..., 1] * v2[..., 2] - v1[..., 2] * v2[..., 1],
                              v1[..., 2] * v2[..., 0] - v1[..., 0] * v2[..., 2],
                              v1[..., 0] * v2[..., 1] - v1[..., 1] * v2[..., 0]], dim=-1)
    cross_prod_norm = torch.norm(cross_prod, dim=-1)
    dot_prod = torch.sum(v1 * v2, dim=-1)

    return torch.atan2(cross_prod_norm, dot_prod)


def sample_and_group_multi(npoint: int, radius: float, nsample: int, xyz: torch.Tensor, normals: torch.Tensor,
                           returnfps: bool = False):
    """Sample and group for xyz, dxyz and ppf features

    Args:
        npoint (int): Number of clusters (equivalently, keypoints) to sample.
            Set to negative to compute for all points
        radius (float): Radius of cluster for computing local features
        nsample: Maximum number of points to consider per cluster
        xyz: XYZ coordinates of the points
        normals: Corresponding normals for the points (required for ppf computation)
        returnfps: Whether to return indices of FPS points and their neighborhood

    Returns:
        Dictionary containing the following fields ['xyz', 'dxyz', 'ppf'].
        If returnfps is True, also returns: grouped_xyz, fps_idx
    """

    B, N, C = xyz.shape

    if npoint > 0:
        S = npoint
        fps_idx = farthest_point_sample(xyz, npoint)  # [B, npoint]
        new_xyz = index_points(xyz, fps_idx)
        nr = index_points(normals, fps_idx)[:, :, None, :]
    else:
        S = xyz.shape[1]
        fps_idx = torch.arange(0, xyz.shape[1])[None, ...].repeat(xyz.shape[0], 1).to(xyz.device)
        new_xyz = xyz
        nr = normals[:, :, None, :]

    idx = query_ball_point(radius, nsample, xyz, new_xyz, fps_idx)  # (B, npoint, nsample)
    grouped_xyz = index_points(xyz, idx)  # (B, npoint, nsample, C)
    d = grouped_xyz - new_xyz.view(B, S, 1, C)  # d = p_r - p_i  (B, npoint, nsample, 3)
    ni = index_points(normals, idx)

    nr_d = angle(nr, d)
    ni_d = angle(ni, d)
    nr_ni = angle(nr, ni)
    d_norm = torch.norm(d, dim=-1)

    xyz_feat = d  # (B, npoint, n_sample, 3)
    ppf_feat = torch.stack([nr_d, ni_d, nr_ni, d_norm], dim=-1)  # (B, npoint, n_sample, 4)

    if returnfps:
        return {'xyz': new_xyz, 'dxyz': xyz_feat, 'ppf': ppf_feat}, grouped_xyz, fps_idx
    else:
        return {'xyz': new_xyz, 'dxyz': xyz_feat, 'ppf': ppf_feat}
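The grouping utilities above produce PPFNet-style features. A small, hypothetical example (assuming the module is importable as learning3d.utils.ppfnet_util; the sizes and radius are illustrative):

# Hypothetical usage sketch for ppfnet_util.py; not shipped with the wheel.
import torch
import torch.nn.functional as F
from learning3d.utils.ppfnet_util import sample_and_group_multi

B, N = 2, 512
xyz = torch.rand(B, N, 3)                             # [B, N, 3] coordinates
normals = F.normalize(torch.randn(B, N, 3), dim=-1)  # unit normals (required for PPF)

features = sample_and_group_multi(npoint=128, radius=0.3, nsample=64,
                                  xyz=xyz, normals=normals)
print(features['xyz'].shape)   # torch.Size([2, 128, 3])     cluster centers
print(features['dxyz'].shape)  # torch.Size([2, 128, 64, 3]) neighbor offsets d = p_r - p_i
print(features['ppf'].shape)   # torch.Size([2, 128, 64, 4]) angles (nr,d), (ni,d), (nr,ni) and ||d||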
learning3d/utils/svd.py
ADDED
@@ -0,0 +1,59 @@
import torch
import torch.nn as nn
import math


class SVDHead(nn.Module):
    """Differentiable SVD (Procrustes) head: computes soft correspondences from
    per-point embeddings, then solves for the rigid rotation R and translation t."""

    def __init__(self, emb_dims, input_shape="bnc"):
        super(SVDHead, self).__init__()
        self.emb_dims = emb_dims
        self.reflect = nn.Parameter(torch.eye(3), requires_grad=False)
        self.reflect[2, 2] = -1
        self.input_shape = input_shape

    def forward(self, src_embedding, tgt_embedding, src, tgt):
        batch_size = src.size(0)
        if self.input_shape == "bnc":
            src = src.permute(0, 2, 1)
            tgt = tgt.permute(0, 2, 1)

        # Soft correspondences: attention of every source point over the target.
        d_k = src_embedding.size(1)
        scores = torch.matmul(src_embedding.transpose(2, 1).contiguous(), tgt_embedding) / math.sqrt(d_k)
        scores = torch.softmax(scores, dim=2)

        src_corr = torch.matmul(tgt, scores.transpose(2, 1).contiguous())

        src_centered = src - src.mean(dim=2, keepdim=True)
        src_corr_centered = src_corr - src_corr.mean(dim=2, keepdim=True)

        # Cross-covariance matrix of the Procrustes problem.
        H = torch.matmul(src_centered, src_corr_centered.transpose(2, 1).contiguous())

        U, S, V = [], [], []
        R = []

        for i in range(src.size(0)):
            u, s, v = torch.svd(H[i])
            r = torch.matmul(v, u.transpose(1, 0).contiguous())
            r_det = torch.det(r)
            if r_det < 0:
                # Fix an improper rotation by flipping the last singular vector
                # (the original redundantly recomputed the SVD here).
                v = torch.matmul(v, self.reflect)
                r = torch.matmul(v, u.transpose(1, 0).contiguous())
            R.append(r)

            U.append(u)
            S.append(s)
            V.append(v)

        U = torch.stack(U, dim=0)
        V = torch.stack(V, dim=0)
        S = torch.stack(S, dim=0)
        R = torch.stack(R, dim=0)

        t = torch.matmul(-R, src.mean(dim=2, keepdim=True)) + src_corr.mean(dim=2, keepdim=True)
        return R, t.view(batch_size, 3)
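A shape check for the head above on random inputs (a sketch assuming the module is importable as learning3d.utils.svd; with random embeddings the recovered transform is meaningless, but R is always a proper rotation):

# Hypothetical usage sketch for svd.py; not shipped with the wheel.
import torch
from learning3d.utils.svd import SVDHead

B, N, D = 2, 128, 512
head = SVDHead(emb_dims=D, input_shape="bnc")

src = torch.rand(B, N, 3)        # source cloud in "bnc" layout
tgt = torch.rand(B, N, 3)        # target cloud
src_emb = torch.randn(B, D, N)   # per-point embeddings, [B, emb_dims, N]
tgt_emb = torch.randn(B, D, N)

R, t = head(src_emb, tgt_emb, src, tgt)
print(R.shape, t.shape)          # torch.Size([2, 3, 3]) torch.Size([2, 3])
print(torch.det(R))              # approximately 1.0 per batch element (no reflections)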