learning3d 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- learning3d/__init__.py +2 -0
- learning3d/data_utils/__init__.py +4 -0
- learning3d/data_utils/dataloaders.py +454 -0
- learning3d/data_utils/user_data.py +119 -0
- learning3d/examples/test_dcp.py +139 -0
- learning3d/examples/test_deepgmr.py +144 -0
- learning3d/examples/test_flownet.py +113 -0
- learning3d/examples/test_masknet.py +159 -0
- learning3d/examples/test_masknet2.py +162 -0
- learning3d/examples/test_pcn.py +118 -0
- learning3d/examples/test_pcrnet.py +120 -0
- learning3d/examples/test_pnlk.py +121 -0
- learning3d/examples/test_pointconv.py +126 -0
- learning3d/examples/test_pointnet.py +121 -0
- learning3d/examples/test_prnet.py +126 -0
- learning3d/examples/test_rpmnet.py +120 -0
- learning3d/examples/train_PointNetLK.py +240 -0
- learning3d/examples/train_dcp.py +249 -0
- learning3d/examples/train_deepgmr.py +244 -0
- learning3d/examples/train_flownet.py +259 -0
- learning3d/examples/train_masknet.py +239 -0
- learning3d/examples/train_pcn.py +216 -0
- learning3d/examples/train_pcrnet.py +228 -0
- learning3d/examples/train_pointconv.py +245 -0
- learning3d/examples/train_pointnet.py +244 -0
- learning3d/examples/train_prnet.py +229 -0
- learning3d/examples/train_rpmnet.py +228 -0
- learning3d/losses/__init__.py +12 -0
- learning3d/losses/chamfer_distance.py +51 -0
- learning3d/losses/classification.py +14 -0
- learning3d/losses/correspondence_loss.py +10 -0
- learning3d/losses/cuda/chamfer_distance/__init__.py +1 -0
- learning3d/losses/cuda/chamfer_distance/chamfer_distance.cpp +185 -0
- learning3d/losses/cuda/chamfer_distance/chamfer_distance.cu +209 -0
- learning3d/losses/cuda/chamfer_distance/chamfer_distance.py +66 -0
- learning3d/losses/cuda/emd_torch/pkg/emd_loss_layer.py +41 -0
- learning3d/losses/cuda/emd_torch/pkg/include/cuda/emd.cuh +347 -0
- learning3d/losses/cuda/emd_torch/pkg/include/cuda_helper.h +18 -0
- learning3d/losses/cuda/emd_torch/pkg/include/emd.h +54 -0
- learning3d/losses/cuda/emd_torch/pkg/layer/__init__.py +1 -0
- learning3d/losses/cuda/emd_torch/pkg/layer/emd_loss_layer.py +40 -0
- learning3d/losses/cuda/emd_torch/pkg/src/cuda/emd.cu +70 -0
- learning3d/losses/cuda/emd_torch/pkg/src/emd.cpp +1 -0
- learning3d/losses/cuda/emd_torch/setup.py +29 -0
- learning3d/losses/emd.py +16 -0
- learning3d/losses/frobenius_norm.py +21 -0
- learning3d/losses/rmse_features.py +16 -0
- learning3d/models/__init__.py +23 -0
- learning3d/models/classifier.py +41 -0
- learning3d/models/dcp.py +92 -0
- learning3d/models/deepgmr.py +165 -0
- learning3d/models/dgcnn.py +92 -0
- learning3d/models/flownet3d.py +446 -0
- learning3d/models/masknet.py +84 -0
- learning3d/models/masknet2.py +264 -0
- learning3d/models/pcn.py +164 -0
- learning3d/models/pcrnet.py +74 -0
- learning3d/models/pointconv.py +108 -0
- learning3d/models/pointnet.py +108 -0
- learning3d/models/pointnetlk.py +173 -0
- learning3d/models/pooling.py +15 -0
- learning3d/models/ppfnet.py +102 -0
- learning3d/models/prnet.py +431 -0
- learning3d/models/rpmnet.py +359 -0
- learning3d/models/segmentation.py +38 -0
- learning3d/ops/__init__.py +0 -0
- learning3d/ops/data_utils.py +45 -0
- learning3d/ops/invmat.py +134 -0
- learning3d/ops/quaternion.py +218 -0
- learning3d/ops/se3.py +157 -0
- learning3d/ops/sinc.py +229 -0
- learning3d/ops/so3.py +213 -0
- learning3d/ops/transform_functions.py +342 -0
- learning3d/utils/__init__.py +9 -0
- learning3d/utils/lib/build/lib.linux-x86_64-3.5/pointnet2_cuda.cpython-35m-x86_64-linux-gnu.so +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/ball_query.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/ball_query_gpu.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/group_points.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/group_points_gpu.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/interpolate.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/interpolate_gpu.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/pointnet2_api.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/sampling.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/sampling_gpu.o +0 -0
- learning3d/utils/lib/dist/pointnet2-0.0.0-py3.5-linux-x86_64.egg +0 -0
- learning3d/utils/lib/pointnet2.egg-info/SOURCES.txt +14 -0
- learning3d/utils/lib/pointnet2.egg-info/dependency_links.txt +1 -0
- learning3d/utils/lib/pointnet2.egg-info/top_level.txt +1 -0
- learning3d/utils/lib/pointnet2_modules.py +160 -0
- learning3d/utils/lib/pointnet2_utils.py +318 -0
- learning3d/utils/lib/pytorch_utils.py +236 -0
- learning3d/utils/lib/setup.py +23 -0
- learning3d/utils/lib/src/ball_query.cpp +25 -0
- learning3d/utils/lib/src/ball_query_gpu.cu +67 -0
- learning3d/utils/lib/src/ball_query_gpu.h +15 -0
- learning3d/utils/lib/src/cuda_utils.h +15 -0
- learning3d/utils/lib/src/group_points.cpp +36 -0
- learning3d/utils/lib/src/group_points_gpu.cu +86 -0
- learning3d/utils/lib/src/group_points_gpu.h +22 -0
- learning3d/utils/lib/src/interpolate.cpp +65 -0
- learning3d/utils/lib/src/interpolate_gpu.cu +233 -0
- learning3d/utils/lib/src/interpolate_gpu.h +36 -0
- learning3d/utils/lib/src/pointnet2_api.cpp +25 -0
- learning3d/utils/lib/src/sampling.cpp +46 -0
- learning3d/utils/lib/src/sampling_gpu.cu +253 -0
- learning3d/utils/lib/src/sampling_gpu.h +29 -0
- learning3d/utils/pointconv_util.py +382 -0
- learning3d/utils/ppfnet_util.py +244 -0
- learning3d/utils/svd.py +59 -0
- learning3d/utils/transformer.py +243 -0
- learning3d-0.0.1.dist-info/LICENSE +21 -0
- learning3d-0.0.1.dist-info/METADATA +271 -0
- learning3d-0.0.1.dist-info/RECORD +115 -0
- learning3d-0.0.1.dist-info/WHEEL +5 -0
- learning3d-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,431 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
# -*- coding: utf-8 -*-
|
3
|
+
|
4
|
+
|
5
|
+
import os
|
6
|
+
import sys
|
7
|
+
import glob
|
8
|
+
import h5py
|
9
|
+
import copy
|
10
|
+
import math
|
11
|
+
import json
|
12
|
+
import numpy as np
|
13
|
+
from tqdm import tqdm
|
14
|
+
import torch
|
15
|
+
import torch.nn as nn
|
16
|
+
import torch.nn.functional as F
|
17
|
+
|
18
|
+
from .. ops import transform_functions as transform
|
19
|
+
from .. utils import Transformer, Identity
|
20
|
+
|
21
|
+
from sklearn.metrics import r2_score
|
22
|
+
|
23
|
+
# Module-wide default device: CUDA when torch reports it as available, else CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
24
|
+
|
25
|
+
|
26
|
+
def pairwise_distance(src, tgt):
    """Return the Euclidean distance between every source/target point pair.

    The squared distance is expanded as ``|a|^2 - 2 a.b + |b|^2`` so that the
    whole table is produced by a single batched matrix product.

    Args:
        src: Tensor indexed as (batch, channels, num_src); dim 1 is the axis
            that is summed over.
        tgt: Tensor indexed as (batch, channels, num_tgt).

    Returns:
        Tensor of shape (batch, num_src, num_tgt) holding the distances.
    """
    inner = -2 * torch.matmul(src.transpose(2, 1).contiguous(), tgt)
    xx = torch.sum(src ** 2, dim=1, keepdim=True)
    yy = torch.sum(tgt ** 2, dim=1, keepdim=True)
    distances = xx.transpose(2, 1).contiguous() + inner + yy
    # Floating-point cancellation can leave tiny negative values where the
    # true squared distance is zero; sqrt of those would be NaN.
    return torch.sqrt(torch.clamp(distances, min=0.0))
|
32
|
+
|
33
|
+
|
34
|
+
def knn(x, k):
    """Return, for every point, the indices of its ``k`` closest points.

    Closeness is ranked by the negated squared Euclidean distance, so the
    largest scores picked by ``topk`` are the nearest points (each point
    itself included).

    Args:
        x: Tensor indexed as (batch, channels, num_points).
        k: Number of neighbours to keep per point.

    Returns:
        Integer tensor of shape (batch, num_points, k).
    """
    gram = torch.matmul(x.transpose(2, 1).contiguous(), x)
    sq_norm = (x ** 2).sum(dim=1, keepdim=True)
    # -(|a|^2 - 2 a.b + |b|^2): zero on the diagonal, negative elsewhere.
    neg_sq_dist = -sq_norm + 2 * gram - sq_norm.transpose(2, 1).contiguous()

    _, nearest = neg_sq_dist.topk(k=k, dim=-1)  # (batch_size, num_points, k)
    return nearest
|
41
|
+
|
42
|
+
|
43
|
+
def get_graph_feature(x, k=20):
    """Build k-nearest-neighbour edge features for every point.

    For each point the features of its ``k`` neighbours (as chosen by
    ``knn``) are gathered and concatenated with the point's own features.

    Args:
        x: Tensor whose first three dimensions are
            (batch, num_dims, num_points); it is viewed down to those three
            dimensions, so any further dimensions must have size 1.
        k: Number of neighbours per point.

    Returns:
        Tensor of shape (batch, 2 * num_dims, num_points, k): neighbour
        features in the first ``num_dims`` channels, the repeated centre
        point features in the remaining ones.
    """
    x = x.view(*x.size()[:3])
    idx = knn(x, k=k)  # (batch_size, num_points, k)
    batch_size, num_points, _ = idx.size()

    # Per-batch offsets that turn per-cloud indices into indices of the
    # flattened (batch_size * num_points) table. They are created on the
    # device of `idx` rather than a module-level default so that CPU inputs
    # keep working on a machine where CUDA is available.
    idx_base = torch.arange(0, batch_size, device=idx.device).view(-1, 1, 1) * num_points
    idx = (idx + idx_base).view(-1)

    _, num_dims, _ = x.size()

    # (batch_size, num_dims, num_points) -> (batch_size, num_points, num_dims)
    x = x.transpose(2, 1).contiguous()
    feature = x.view(batch_size * num_points, -1)[idx, :]
    feature = feature.view(batch_size, num_points, k, num_dims)
    x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)

    return torch.cat((feature, x), dim=3).permute(0, 3, 1, 2)
|
65
|
+
|
66
|
+
|
67
|
+
def cycle_consistency(rotation_ab, translation_ab, rotation_ba, translation_ba):
    """Penalise disagreement between a forward and a backward transform.

    The loss is the sum of two mean-squared errors: the product
    ``rotation_ab @ rotation_ba`` against the identity, and
    ``translation_ab`` against ``-translation_ba``.

    Args:
        rotation_ab: Tensor of shape (batch, 3, 3).
        translation_ab: Translation tensor for the forward direction.
        rotation_ba: Tensor of shape (batch, 3, 3).
        translation_ba: Translation tensor for the backward direction.

    Returns:
        Scalar tensor with the combined loss.
    """
    n_batch = rotation_ab.size(0)
    eye = torch.eye(3, device=rotation_ab.device).unsqueeze(0).repeat(n_batch, 1, 1)
    rotation_term = F.mse_loss(torch.matmul(rotation_ab, rotation_ba), eye)
    translation_term = F.mse_loss(translation_ab, -translation_ba)
    return rotation_term + translation_term
|
71
|
+
|
72
|
+
|
73
|
+
class PointNet(nn.Module):
    """Per-point feature extractor made of five 1x1 Conv1d + BatchNorm + ReLU stages.

    Channel widths are 3 -> 64 -> 64 -> 64 -> 128 -> ``emb_dims``. The
    kernel size is 1, so the length of the last axis is preserved.
    """

    def __init__(self, emb_dims=512):
        """Create the layers ``conv1..conv5`` and ``bn1..bn5``.

        Args:
            emb_dims: Number of output channels of the final stage.
        """
        super().__init__()
        widths = (3, 64, 64, 64, 128, emb_dims)
        # Convolutions are registered first, then the norms, so the
        # attribute names and their registration order stay conv1..5, bn1..5.
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            setattr(self, 'conv%d' % stage, nn.Conv1d(c_in, c_out, kernel_size=1, bias=False))
        for stage, c_out in enumerate(widths[1:], start=1):
            setattr(self, 'bn%d' % stage, nn.BatchNorm1d(c_out))

    def forward(self, x):
        """Apply the five stages in order.

        Args:
            x: Tensor of shape (batch, 3, num_points).

        Returns:
            Tensor of shape (batch, emb_dims, num_points).
        """
        for stage in range(1, 6):
            conv = getattr(self, 'conv%d' % stage)
            norm = getattr(self, 'bn%d' % stage)
            x = F.relu(norm(conv(x)))
        return x
|
94
|
+
|
95
|
+
|
96
|
+
class DGCNN(nn.Module):
    """Graph-based per-point feature extractor.

    Four edge-convolution stages each rebuild a k-NN graph with
    ``get_graph_feature``, apply a 1x1 Conv2d + BatchNorm + LeakyReLU, and
    max-pool over the neighbour axis. The four pooled outputs
    (64 + 64 + 128 + 256 = 512 channels) are concatenated and projected to
    ``emb_dims`` channels.
    """

    def __init__(self, emb_dims=512):
        """Create the layers ``conv1..conv5`` and ``bn1..bn5``.

        Args:
            emb_dims: Number of output channels of the final projection.
        """
        super().__init__()
        # (in_channels, out_channels) per stage; the inputs of stages 1-4 are
        # doubled because edge features concatenate neighbour and centre.
        conv_shapes = ((6, 64), (64 * 2, 64), (64 * 2, 128), (128 * 2, 256), (512, emb_dims))
        for stage, (c_in, c_out) in enumerate(conv_shapes, start=1):
            setattr(self, 'conv%d' % stage, nn.Conv2d(c_in, c_out, kernel_size=1, bias=False))
        for stage, (_, c_out) in enumerate(conv_shapes, start=1):
            setattr(self, 'bn%d' % stage, nn.BatchNorm2d(c_out))

    def forward(self, x):
        """Compute per-point embeddings.

        Args:
            x: Tensor of shape (batch, num_dims, num_points).

        Returns:
            Tensor of shape (batch, emb_dims, num_points).
        """
        batch_size, _, num_points = x.size()
        edge_stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
        )

        pooled = []
        for conv, norm in edge_stages:
            edges = get_graph_feature(x)
            edges = F.leaky_relu(norm(conv(edges)), negative_slope=0.2)
            # Max over the neighbour axis; the singleton axis is kept so the
            # result can be concatenated and fed to the final Conv2d.
            x = edges.max(dim=-1, keepdim=True)[0]
            pooled.append(x)

        fused = torch.cat(pooled, dim=1)
        fused = F.leaky_relu(self.bn5(self.conv5(fused)), negative_slope=0.2)
        return fused.view(batch_size, -1, num_points)
|
132
|
+
|
133
|
+
|
134
|
+
class MLPHead(nn.Module):
    """Regress a rotation and a translation from two sets of embeddings.

    The source and target embeddings are concatenated along the channel
    axis, max-pooled over the last axis, passed through a three-stage MLP
    and projected to a quaternion (4 values) and a translation (3 values).
    """

    def __init__(self, emb_dims):
        """Create the MLP and the two projection layers.

        Args:
            emb_dims: Channel count of each input embedding.
        """
        super(MLPHead, self).__init__()
        n_emb_dims = emb_dims
        self.n_emb_dims = n_emb_dims
        self.nn = nn.Sequential(nn.Linear(n_emb_dims*2, n_emb_dims//2),
                                nn.BatchNorm1d(n_emb_dims//2),
                                nn.ReLU(),
                                nn.Linear(n_emb_dims//2, n_emb_dims//4),
                                nn.BatchNorm1d(n_emb_dims//4),
                                nn.ReLU(),
                                nn.Linear(n_emb_dims//4, n_emb_dims//8),
                                nn.BatchNorm1d(n_emb_dims//8),
                                nn.ReLU())
        self.proj_rot = nn.Linear(n_emb_dims//8, 4)
        self.proj_trans = nn.Linear(n_emb_dims//8, 3)

    @staticmethod
    def _quat2mat(quat):
        """Convert unit quaternions, ordered (x, y, z, w), to rotation matrices.

        The original code called an undefined ``quat2mat``; the component
        order used here is this helper's own convention.

        Args:
            quat: Tensor of shape (batch, 4).

        Returns:
            Tensor of shape (batch, 3, 3).
        """
        x, y, z, w = quat[:, 0], quat[:, 1], quat[:, 2], quat[:, 3]
        w2, x2, y2, z2 = w * w, x * x, y * y, z * z
        wx, wy, wz = w * x, w * y, w * z
        xy, xz, yz = x * y, x * z, y * z
        rows = [w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz,
                2 * wz + 2 * xy, w2 - x2 + y2 - z2, 2 * yz - 2 * wx,
                2 * xz - 2 * wy, 2 * wx + 2 * yz, w2 - x2 - y2 + z2]
        return torch.stack(rows, dim=1).reshape(quat.size(0), 3, 3)

    def forward(self, *input):
        """Predict a rigid transform.

        Args:
            *input: ``input[0]`` is the source embedding and ``input[1]`` the
                target embedding, both (batch, emb_dims, num_points). Any
                further positional arguments are ignored.

        Returns:
            Tuple ``(rotation, translation)`` of shapes (batch, 3, 3) and
            (batch, 3).
        """
        src_embedding = input[0]
        tgt_embedding = input[1]
        embedding = torch.cat((src_embedding, tgt_embedding), dim=1)
        embedding = self.nn(embedding.max(dim=-1)[0])
        # F.normalize divides by max(norm, eps), so an all-zero prediction
        # no longer produces NaN.
        rotation = F.normalize(self.proj_rot(embedding), p=2, dim=1)
        translation = self.proj_trans(embedding)
        return self._quat2mat(rotation), translation
|
160
|
+
|
161
|
+
|
162
|
+
class TemperatureNet(nn.Module):
    """Predict a clamped scalar temperature from the gap between two embeddings.

    The embeddings are averaged over their last axis, their absolute
    difference is fed through an MLP ending in a ReLU, and the result is
    clamped to ``[1 / temp_factor, temp_factor]``.
    """

    def __init__(self, emb_dims, temp_factor):
        """Create the MLP.

        Args:
            emb_dims: Channel count of each input embedding.
            temp_factor: Bound used for clamping the predicted temperature.
        """
        super().__init__()
        self.n_emb_dims = emb_dims
        self.temp_factor = temp_factor

        hidden = 128
        stages = []
        fan_in = self.n_emb_dims
        for _ in range(3):
            stages.extend([nn.Linear(fan_in, hidden), nn.BatchNorm1d(hidden), nn.ReLU()])
            fan_in = hidden
        stages.extend([nn.Linear(hidden, 1), nn.ReLU()])
        self.nn = nn.Sequential(*stages)

        # Most recent |mean(src) - mean(tgt)| residual, refreshed by forward().
        self.feature_disparity = None

    def forward(self, *input):
        """Compute the temperature and the feature residual.

        Args:
            *input: ``input[0]`` is the source embedding and ``input[1]`` the
                target embedding, both (batch, emb_dims, num_points).

        Returns:
            Tuple ``(temperature, residual)`` of shapes (batch, 1) and
            (batch, emb_dims).
        """
        src_mean = input[0].mean(dim=2)
        tgt_mean = input[1].mean(dim=2)
        residual = torch.abs(src_mean - tgt_mean)

        self.feature_disparity = residual

        lower = 1.0 / self.temp_factor
        upper = 1.0 * self.temp_factor
        temperature = torch.clamp(self.nn(residual), lower, upper)
        return temperature, residual
|
190
|
+
|
191
|
+
|
192
|
+
class SVDHead(nn.Module):
    """Estimate a rigid transform from soft correspondences via SVD.

    Embedding similarities are turned into a matching matrix (softmax or
    hard Gumbel-softmax), each source point gets a matched target location,
    and the rotation/translation aligning the two sets is obtained from the
    SVD of their cross-covariance.
    """

    def __init__(self, emb_dims, cat_sampler):
        """Store the configuration.

        Args:
            emb_dims: Channel count of the embeddings (stored only).
            cat_sampler: ``'softmax'`` or ``'gumbel_softmax'``; any other
                value makes ``forward`` raise.
        """
        super(SVDHead, self).__init__()
        self.n_emb_dims = emb_dims
        self.cat_sampler = cat_sampler
        self.reflect = nn.Parameter(torch.eye(3), requires_grad=False)
        self.reflect[2, 2] = -1
        self.temperature = nn.Parameter(torch.ones(1)*0.5, requires_grad=True)
        # Counts training-mode forward passes; a plain tensor, so it is not
        # part of the module's parameters or buffers.
        self.my_iter = torch.ones(1)

    def forward(self, *input):
        """Compute the rotation and translation mapping ``src`` onto ``tgt``.

        Args:
            *input: In order: source embedding and target embedding
                (batch, emb_dims, num_points), source points and target
                points (batch, 3, num_points), and a temperature tensor
                with ``batch`` elements.

        Returns:
            Tuple ``(R, t)`` of shapes (batch, 3, 3) and (batch, 3), on the
            same device as ``src``.

        Raises:
            Exception: If ``cat_sampler`` is not a supported sampler.
        """
        src_embedding = input[0]
        tgt_embedding = input[1]
        src = input[2]
        tgt = input[3]
        batch_size, num_dims, num_points = src.size()
        temperature = input[4].view(batch_size, 1, 1)

        if self.cat_sampler not in ('softmax', 'gumbel_softmax'):
            raise Exception('not implemented')

        # Scaled dot-product similarity between every source/target embedding.
        d_k = src_embedding.size(1)
        scores = torch.matmul(src_embedding.transpose(2, 1).contiguous(), tgt_embedding) / math.sqrt(d_k)
        if self.cat_sampler == 'softmax':
            scores = torch.softmax(temperature*scores, dim=2)
        else:
            scores = scores.view(batch_size*num_points, num_points)
            temperature = temperature.repeat(1, num_points, 1).view(-1, 1)
            scores = F.gumbel_softmax(scores, tau=temperature, hard=True)
            scores = scores.view(batch_size, num_points, num_points)

        # Matched target location for every source point.
        src_corr = torch.matmul(tgt, scores.transpose(2, 1).contiguous())

        src_centered = src - src.mean(dim=2, keepdim=True)
        src_corr_centered = src_corr - src_corr.mean(dim=2, keepdim=True)

        # The SVD is done per sample on the CPU, as in the original code.
        H = torch.matmul(src_centered, src_corr_centered.transpose(2, 1).contiguous()).cpu()

        R = []
        for i in range(batch_size):
            u, _, v = torch.svd(H[i])
            r_det = torch.det(torch.matmul(v, u.transpose(1, 0))).item()
            # Scaling the last axis by the determinant's value flips a
            # reflection (det = -1) back into a proper rotation.
            diag = torch.diag(torch.tensor([1.0, 1.0, r_det], dtype=H.dtype, device=v.device))
            R.append(torch.matmul(torch.matmul(v, diag), u.transpose(1, 0)).contiguous())

        # Return to the device of the inputs rather than a module-level
        # default, which may differ from where the data lives.
        R = torch.stack(R, dim=0).to(src.device)

        t = torch.matmul(-R, src.mean(dim=2, keepdim=True)) + src_corr.mean(dim=2, keepdim=True)
        if self.training:
            self.my_iter += 1
        return R, t.view(batch_size, 3)
|
250
|
+
|
251
|
+
|
252
|
+
class KeyPointNet(nn.Module):
    """Keep the ``num_keypoints`` points whose embeddings have the largest L2 norm."""

    def __init__(self, num_keypoints):
        """Store the number of points to keep per cloud.

        Args:
            num_keypoints: How many points survive the selection.
        """
        super().__init__()
        self.num_keypoints = num_keypoints

    def forward(self, *input):
        """Select keypoints and their embeddings for both clouds.

        Args:
            *input: In order: source points and target points
                (batch, 3, num_points), then source embedding and target
                embedding (batch, num_dims, num_points).

        Returns:
            Tuple ``(src_keypoints, tgt_keypoints, src_embedding,
            tgt_embedding)``, each with last axis of length
            ``num_keypoints``. The order of the kept points is unspecified
            because ``topk`` is called with ``sorted=False``.
        """
        src, tgt, src_embedding, tgt_embedding = input[0], input[1], input[2], input[3]
        _, num_dims, _ = src_embedding.size()

        def select(points, embedding):
            # Rank points by the norm of their embedding over the channel axis.
            strength = torch.norm(embedding, dim=1, keepdim=True)
            _, top_idx = torch.topk(strength, k=self.num_keypoints, dim=2, sorted=False)
            # Broadcast the (batch, 1, k) index over coordinate / channel rows.
            kept_points = torch.gather(points, dim=2, index=top_idx.repeat(1, 3, 1))
            kept_embedding = torch.gather(embedding, dim=2, index=top_idx.repeat(1, num_dims, 1))
            return kept_points, kept_embedding

        src_keypoints, src_selected = select(src, src_embedding)
        tgt_keypoints, tgt_selected = select(tgt, tgt_embedding)
        return src_keypoints, tgt_keypoints, src_selected, tgt_selected
|
278
|
+
|
279
|
+
|
280
|
+
class PRNet(nn.Module):
    """Iterative registration network built from the components in this file.

    Each iteration embeds both clouds, optionally refines the embeddings with
    an attention module, optionally keeps only keypoints, predicts a
    temperature, and asks the head for a source->target and a
    target->source transform. The per-iteration source->target transforms
    are composed and applied to the source cloud before the next iteration.
    """

    def __init__(self, emb_nn='dgcnn', attention='transformer', head='svd', emb_dims=512, num_keypoints=512, num_subsampled_points=768, num_iters=3, cycle_consistency_loss=0.1, feature_alignment_loss=0.1, discount_factor = 0.9, input_shape='bnc'):
        """Assemble the sub-modules.

        Args:
            emb_nn: ``'pointnet'`` or ``'dgcnn'``.
            attention: ``'identity'`` or ``'transformer'``.
            head: ``'mlp'`` or ``'svd'``.
            emb_dims: Channel count of the point embeddings.
            num_keypoints: Points kept by the keypoint selector.
            num_subsampled_points: When equal to ``num_keypoints`` the
                keypoint selector is replaced by ``Identity``.
            num_iters: Number of refinement iterations in ``forward``.
            cycle_consistency_loss: Weight of the cycle-consistency term.
            feature_alignment_loss: Weight of the feature-disparity term.
            discount_factor: Per-iteration decay applied to every loss term.
            input_shape: ``'bnc'`` makes ``forward`` permute the last two
                axes of the clouds on entry and exit; any other value uses
                them as given.

        Raises:
            Exception: If ``emb_nn``, ``attention`` or ``head`` is unknown.
        """
        super(PRNet, self).__init__()
        self.emb_dims = emb_dims
        self.num_keypoints = num_keypoints
        self.num_subsampled_points = num_subsampled_points
        self.num_iters = num_iters
        self.discount_factor = discount_factor
        self.feature_alignment_loss = feature_alignment_loss
        self.cycle_consistency_loss = cycle_consistency_loss
        self.input_shape = input_shape

        if emb_nn == 'pointnet':
            self.emb_nn = PointNet(emb_dims=self.emb_dims)
        elif emb_nn == 'dgcnn':
            self.emb_nn = DGCNN(emb_dims=self.emb_dims)
        else:
            raise Exception('Not implemented')

        if attention == 'identity':
            self.attention = Identity()
        elif attention == 'transformer':
            self.attention = Transformer(emb_dims=self.emb_dims, n_blocks=1, dropout=0.0, ff_dims=1024, n_heads=4)
        else:
            raise Exception("Not implemented")

        self.temp_net = TemperatureNet(emb_dims=self.emb_dims, temp_factor=100)

        if head == 'mlp':
            self.head = MLPHead(emb_dims=self.emb_dims)
        elif head == 'svd':
            self.head = SVDHead(emb_dims=self.emb_dims, cat_sampler='softmax')
        else:
            raise Exception('Not implemented')

        if self.num_keypoints != self.num_subsampled_points:
            self.keypointnet = KeyPointNet(num_keypoints=self.num_keypoints)
        else:
            self.keypointnet = Identity()

    def predict_embedding(self, *input):
        """Embed both clouds and predict the matching temperature.

        Args:
            *input: ``input[0]`` is the source cloud, ``input[1]`` the target
                cloud, in the layout expected by the embedding network.

        Returns:
            Tuple ``(src, tgt, src_embedding, tgt_embedding, temperature,
            feature_disparity)`` as produced by the keypoint selector and
            the temperature network.
        """
        src = input[0]
        tgt = input[1]
        src_embedding = self.emb_nn(src)
        tgt_embedding = self.emb_nn(tgt)

        src_embedding_p, tgt_embedding_p = self.attention(src_embedding, tgt_embedding)

        # The attention output is added as a residual to the raw embeddings.
        src_embedding = src_embedding + src_embedding_p
        tgt_embedding = tgt_embedding + tgt_embedding_p

        src, tgt, src_embedding, tgt_embedding = self.keypointnet(src, tgt, src_embedding, tgt_embedding)

        temperature, feature_disparity = self.temp_net(src_embedding, tgt_embedding)

        return src, tgt, src_embedding, tgt_embedding, temperature, feature_disparity

    # Single Pass Alignment Module for PRNet
    def spam(self, *input):
        """Run one alignment pass in both directions.

        Args:
            *input: Forwarded unchanged to ``predict_embedding``.

        Returns:
            Tuple ``(rotation_ab, translation_ab, rotation_ba,
            translation_ba, feature_disparity)``.
        """
        src, tgt, src_embedding, tgt_embedding, temperature, feature_disparity = self.predict_embedding(*input)
        rotation_ab, translation_ab = self.head(src_embedding, tgt_embedding, src, tgt, temperature)
        rotation_ba, translation_ba = self.head(tgt_embedding, src_embedding, tgt, src, temperature)
        return rotation_ab, translation_ab, rotation_ba, translation_ba, feature_disparity

    def predict_keypoint_correspondence(self, *input):
        """Return keypoints and a hard (one-hot per row) matching matrix.

        Args:
            *input: Forwarded unchanged to ``predict_embedding``.

        Returns:
            Tuple ``(src, tgt, scores)`` where ``scores`` has shape
            (batch, num_points, num_points) and is sampled with a hard
            Gumbel-softmax.
        """
        src, tgt, src_embedding, tgt_embedding, temperature, _ = self.predict_embedding(*input)
        batch_size, num_dims, num_points = src.size()
        d_k = src_embedding.size(1)
        scores = torch.matmul(src_embedding.transpose(2, 1).contiguous(), tgt_embedding) / math.sqrt(d_k)
        scores = scores.view(batch_size*num_points, num_points)
        temperature = temperature.repeat(1, num_points, 1).view(-1, 1)
        scores = F.gumbel_softmax(scores, tau=temperature, hard=True)
        scores = scores.view(batch_size, num_points, num_points)
        return src, tgt, scores

    def forward(self, *input):
        """Register ``src`` onto ``tgt`` and optionally compute the loss.

        Args:
            *input: One of
                ``(src, tgt)`` -- inference only;
                ``(src, tgt, T)`` -- ``T[:, :3, :3]`` is used as the
                reference rotation and ``T[:, :3, 3]`` as the reference
                translation;
                ``(src, tgt, rotation_ab, translation_ab)`` -- reference
                pose given explicitly.

        Returns:
            Dict with keys ``'est_R'``, ``'est_t'``, ``'est_T'`` and
            ``'transformed_source'``; ``'loss'`` is added when a reference
            pose was supplied.

        Raises:
            ValueError: If the number of positional arguments is not 2, 3
                or 4.
        """
        calculate_loss = False
        if len(input) == 2:
            src, tgt = input[0], input[1]
        elif len(input) == 3:
            src, tgt, rotation_ab, translation_ab = input[0], input[1], input[2][:, :3, :3], input[2][:, :3, 3].view(-1, 3)
            calculate_loss = True
        elif len(input) == 4:
            src, tgt, rotation_ab, translation_ab = input[0], input[1], input[2], input[3]
            calculate_loss = True
        else:
            # Previously this fell through and failed later on an unbound name.
            raise ValueError('PRNet.forward expects 2, 3 or 4 positional arguments, got {}'.format(len(input)))

        if self.input_shape == 'bnc':
            src, tgt = src.permute(0, 2, 1), tgt.permute(0, 2, 1)

        batch_size = src.size(0)
        identity = torch.eye(3, device=src.device).unsqueeze(0).repeat(batch_size, 1, 1)

        # Accumulated transforms start at the identity.
        rotation_ab_pred = torch.eye(3, device=src.device, dtype=torch.float32).view(1, 3, 3).repeat(batch_size, 1, 1)
        translation_ab_pred = torch.zeros(3, device=src.device, dtype=torch.float32).view(1, 3).repeat(batch_size, 1)

        rotation_ba_pred = torch.eye(3, device=src.device, dtype=torch.float32).view(1, 3, 3).repeat(batch_size, 1, 1)
        translation_ba_pred = torch.zeros(3, device=src.device, dtype=torch.float32).view(1, 3).repeat(batch_size, 1)

        total_loss = 0

        for i in range(self.num_iters):
            rotation_ab_pred_i, translation_ab_pred_i, rotation_ba_pred_i, translation_ba_pred_i, feature_disparity = self.spam(src, tgt)

            # Compose the new increment on the left of the running estimate.
            rotation_ab_pred = torch.matmul(rotation_ab_pred_i, rotation_ab_pred)
            translation_ab_pred = torch.matmul(rotation_ab_pred_i, translation_ab_pred.unsqueeze(2)).squeeze(2) + translation_ab_pred_i

            rotation_ba_pred = torch.matmul(rotation_ba_pred_i, rotation_ba_pred)
            translation_ba_pred = torch.matmul(rotation_ba_pred_i, translation_ba_pred.unsqueeze(2)).squeeze(2) + translation_ba_pred_i

            if calculate_loss:
                discount = self.discount_factor**i
                loss = (F.mse_loss(torch.matmul(rotation_ab_pred.transpose(2, 1), rotation_ab), identity)
                        + F.mse_loss(translation_ab_pred, translation_ab)) * discount

                feature_alignment_loss = feature_disparity.mean() * self.feature_alignment_loss * discount
                cycle_consistency_loss = cycle_consistency(rotation_ab_pred_i, translation_ab_pred_i,
                                                           rotation_ba_pred_i, translation_ba_pred_i) \
                                         * self.cycle_consistency_loss * discount

                total_loss = total_loss + loss + feature_alignment_loss + cycle_consistency_loss

            # Move the source with this iteration's increment before the next pass.
            if self.input_shape == 'bnc':
                src = transform.transform_point_cloud(src.permute(0, 2, 1), rotation_ab_pred_i, translation_ab_pred_i).permute(0, 2, 1)
            else:
                src = transform.transform_point_cloud(src, rotation_ab_pred_i, translation_ab_pred_i)

        # Undo the entry permutation so the output layout matches the input.
        if self.input_shape == 'bnc':
            src, tgt = src.permute(0, 2, 1), tgt.permute(0, 2, 1)

        result = {'est_R': rotation_ab_pred,
                  'est_t': translation_ab_pred,
                  'est_T': transform.convert2transformation(rotation_ab_pred, translation_ab_pred),
                  'transformed_source': src}

        if calculate_loss:
            result['loss'] = total_loss
        return result
|
422
|
+
|
423
|
+
|
424
|
+
if __name__ == '__main__':
    # Smoke test: one forward pass with a reference pose on random clouds.
    model = PRNet().to(device)
    # torch.rand builds tensors of the given shape; torch.tensor(10, 1024, 3)
    # is not a valid call because torch.tensor expects data, not sizes.
    src = torch.rand(10, 1024, 3, device=device)
    tgt = torch.rand(10, 768, 3, device=device)
    rotation_ab = torch.eye(3, device=device).unsqueeze(0).repeat(10, 1, 1)
    translation_ab = torch.zeros(10, 3, device=device)
    # forward() returns a dict, not a tuple.
    result = model(src, tgt, rotation_ab, translation_ab)
    rotation_ab_pred, translation_ab_pred, loss = result['est_R'], result['est_t'], result['loss']
|