learning3d 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115)
  1. learning3d/__init__.py +2 -0
  2. learning3d/data_utils/__init__.py +4 -0
  3. learning3d/data_utils/dataloaders.py +454 -0
  4. learning3d/data_utils/user_data.py +119 -0
  5. learning3d/examples/test_dcp.py +139 -0
  6. learning3d/examples/test_deepgmr.py +144 -0
  7. learning3d/examples/test_flownet.py +113 -0
  8. learning3d/examples/test_masknet.py +159 -0
  9. learning3d/examples/test_masknet2.py +162 -0
  10. learning3d/examples/test_pcn.py +118 -0
  11. learning3d/examples/test_pcrnet.py +120 -0
  12. learning3d/examples/test_pnlk.py +121 -0
  13. learning3d/examples/test_pointconv.py +126 -0
  14. learning3d/examples/test_pointnet.py +121 -0
  15. learning3d/examples/test_prnet.py +126 -0
  16. learning3d/examples/test_rpmnet.py +120 -0
  17. learning3d/examples/train_PointNetLK.py +240 -0
  18. learning3d/examples/train_dcp.py +249 -0
  19. learning3d/examples/train_deepgmr.py +244 -0
  20. learning3d/examples/train_flownet.py +259 -0
  21. learning3d/examples/train_masknet.py +239 -0
  22. learning3d/examples/train_pcn.py +216 -0
  23. learning3d/examples/train_pcrnet.py +228 -0
  24. learning3d/examples/train_pointconv.py +245 -0
  25. learning3d/examples/train_pointnet.py +244 -0
  26. learning3d/examples/train_prnet.py +229 -0
  27. learning3d/examples/train_rpmnet.py +228 -0
  28. learning3d/losses/__init__.py +12 -0
  29. learning3d/losses/chamfer_distance.py +51 -0
  30. learning3d/losses/classification.py +14 -0
  31. learning3d/losses/correspondence_loss.py +10 -0
  32. learning3d/losses/cuda/chamfer_distance/__init__.py +1 -0
  33. learning3d/losses/cuda/chamfer_distance/chamfer_distance.cpp +185 -0
  34. learning3d/losses/cuda/chamfer_distance/chamfer_distance.cu +209 -0
  35. learning3d/losses/cuda/chamfer_distance/chamfer_distance.py +66 -0
  36. learning3d/losses/cuda/emd_torch/pkg/emd_loss_layer.py +41 -0
  37. learning3d/losses/cuda/emd_torch/pkg/include/cuda/emd.cuh +347 -0
  38. learning3d/losses/cuda/emd_torch/pkg/include/cuda_helper.h +18 -0
  39. learning3d/losses/cuda/emd_torch/pkg/include/emd.h +54 -0
  40. learning3d/losses/cuda/emd_torch/pkg/layer/__init__.py +1 -0
  41. learning3d/losses/cuda/emd_torch/pkg/layer/emd_loss_layer.py +40 -0
  42. learning3d/losses/cuda/emd_torch/pkg/src/cuda/emd.cu +70 -0
  43. learning3d/losses/cuda/emd_torch/pkg/src/emd.cpp +1 -0
  44. learning3d/losses/cuda/emd_torch/setup.py +29 -0
  45. learning3d/losses/emd.py +16 -0
  46. learning3d/losses/frobenius_norm.py +21 -0
  47. learning3d/losses/rmse_features.py +16 -0
  48. learning3d/models/__init__.py +23 -0
  49. learning3d/models/classifier.py +41 -0
  50. learning3d/models/dcp.py +92 -0
  51. learning3d/models/deepgmr.py +165 -0
  52. learning3d/models/dgcnn.py +92 -0
  53. learning3d/models/flownet3d.py +446 -0
  54. learning3d/models/masknet.py +84 -0
  55. learning3d/models/masknet2.py +264 -0
  56. learning3d/models/pcn.py +164 -0
  57. learning3d/models/pcrnet.py +74 -0
  58. learning3d/models/pointconv.py +108 -0
  59. learning3d/models/pointnet.py +108 -0
  60. learning3d/models/pointnetlk.py +173 -0
  61. learning3d/models/pooling.py +15 -0
  62. learning3d/models/ppfnet.py +102 -0
  63. learning3d/models/prnet.py +431 -0
  64. learning3d/models/rpmnet.py +359 -0
  65. learning3d/models/segmentation.py +38 -0
  66. learning3d/ops/__init__.py +0 -0
  67. learning3d/ops/data_utils.py +45 -0
  68. learning3d/ops/invmat.py +134 -0
  69. learning3d/ops/quaternion.py +218 -0
  70. learning3d/ops/se3.py +157 -0
  71. learning3d/ops/sinc.py +229 -0
  72. learning3d/ops/so3.py +213 -0
  73. learning3d/ops/transform_functions.py +342 -0
  74. learning3d/utils/__init__.py +9 -0
  75. learning3d/utils/lib/build/lib.linux-x86_64-3.5/pointnet2_cuda.cpython-35m-x86_64-linux-gnu.so +0 -0
  76. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/ball_query.o +0 -0
  77. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/ball_query_gpu.o +0 -0
  78. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/group_points.o +0 -0
  79. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/group_points_gpu.o +0 -0
  80. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/interpolate.o +0 -0
  81. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/interpolate_gpu.o +0 -0
  82. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/pointnet2_api.o +0 -0
  83. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/sampling.o +0 -0
  84. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/sampling_gpu.o +0 -0
  85. learning3d/utils/lib/dist/pointnet2-0.0.0-py3.5-linux-x86_64.egg +0 -0
  86. learning3d/utils/lib/pointnet2.egg-info/SOURCES.txt +14 -0
  87. learning3d/utils/lib/pointnet2.egg-info/dependency_links.txt +1 -0
  88. learning3d/utils/lib/pointnet2.egg-info/top_level.txt +1 -0
  89. learning3d/utils/lib/pointnet2_modules.py +160 -0
  90. learning3d/utils/lib/pointnet2_utils.py +318 -0
  91. learning3d/utils/lib/pytorch_utils.py +236 -0
  92. learning3d/utils/lib/setup.py +23 -0
  93. learning3d/utils/lib/src/ball_query.cpp +25 -0
  94. learning3d/utils/lib/src/ball_query_gpu.cu +67 -0
  95. learning3d/utils/lib/src/ball_query_gpu.h +15 -0
  96. learning3d/utils/lib/src/cuda_utils.h +15 -0
  97. learning3d/utils/lib/src/group_points.cpp +36 -0
  98. learning3d/utils/lib/src/group_points_gpu.cu +86 -0
  99. learning3d/utils/lib/src/group_points_gpu.h +22 -0
  100. learning3d/utils/lib/src/interpolate.cpp +65 -0
  101. learning3d/utils/lib/src/interpolate_gpu.cu +233 -0
  102. learning3d/utils/lib/src/interpolate_gpu.h +36 -0
  103. learning3d/utils/lib/src/pointnet2_api.cpp +25 -0
  104. learning3d/utils/lib/src/sampling.cpp +46 -0
  105. learning3d/utils/lib/src/sampling_gpu.cu +253 -0
  106. learning3d/utils/lib/src/sampling_gpu.h +29 -0
  107. learning3d/utils/pointconv_util.py +382 -0
  108. learning3d/utils/ppfnet_util.py +244 -0
  109. learning3d/utils/svd.py +59 -0
  110. learning3d/utils/transformer.py +243 -0
  111. learning3d-0.0.1.dist-info/LICENSE +21 -0
  112. learning3d-0.0.1.dist-info/METADATA +271 -0
  113. learning3d-0.0.1.dist-info/RECORD +115 -0
  114. learning3d-0.0.1.dist-info/WHEEL +5 -0
  115. learning3d-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,431 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ import os
6
+ import sys
7
+ import glob
8
+ import h5py
9
+ import copy
10
+ import math
11
+ import json
12
+ import numpy as np
13
+ from tqdm import tqdm
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+
18
+ from .. ops import transform_functions as transform
19
+ from .. utils import Transformer, Identity
20
+
21
+ from sklearn.metrics import r2_score
22
+
23
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
24
+
25
+
26
def pairwise_distance(src, tgt):
    """Return the Euclidean distance between every point of src and tgt.

    Both inputs are indexed as (batch, dims, points): coordinates are summed
    over dim 1 and the two point axes are paired by the matrix product.

    Args:
        src: tensor of shape (batch, dims, num_src).
        tgt: tensor of shape (batch, dims, num_tgt).

    Returns:
        Tensor of shape (batch, num_src, num_tgt) holding ||src_i - tgt_j||.
    """
    # ||a - b||^2 expanded as ||a||^2 - 2 a.b + ||b||^2.
    inner = -2 * torch.matmul(src.transpose(2, 1).contiguous(), tgt)
    xx = torch.sum(src ** 2, dim=1, keepdim=True)
    yy = torch.sum(tgt ** 2, dim=1, keepdim=True)
    distances = xx.transpose(2, 1).contiguous() + inner + yy
    # Rounding in the expansion can leave tiny negative values for
    # (near-)coincident points; sqrt of those would be NaN, so clamp first.
    return torch.sqrt(torch.clamp(distances, min=0.0))
32
+
33
+
34
def knn(x, k):
    """Return, for every point, the indices of its k nearest points.

    Args:
        x: tensor of shape (batch_size, num_dims, num_points).
        k: number of neighbours to return per point.

    Returns:
        Index tensor of shape (batch_size, num_points, k).
    """
    sq_norms = torch.sum(x ** 2, dim=1, keepdim=True)
    cross = -2 * torch.matmul(x.transpose(2, 1).contiguous(), x)
    # Negated squared distance, so the largest entries are the closest points.
    neg_sq_dist = -sq_norms - cross - sq_norms.transpose(2, 1).contiguous()

    _, neighbours = neg_sq_dist.topk(k=k, dim=-1)
    return neighbours
41
+
42
+
43
def get_graph_feature(x, k=20):
    """Build per-point neighbourhood features from a k-NN graph.

    For every point the features of its k nearest neighbours (as found by
    ``knn``) are gathered and concatenated with the point's own feature.

    Args:
        x: tensor whose first three axes are (batch_size, num_dims,
            num_points); any further axes must be singleton because the
            tensor is viewed down to those three axes.
        k: number of neighbours per point.

    Returns:
        Tensor of shape (batch_size, 2 * num_dims, num_points, k); the first
        num_dims channels are the neighbour features, the rest the centre
        point repeated k times.
    """
    x = x.view(*x.size()[:3])
    idx = knn(x, k=k)  # (batch_size, num_points, k)
    batch_size, num_points, _ = idx.size()

    # Offsets that turn per-cloud indices into indices of the flattened
    # (batch_size * num_points) point list. Created on the input's device so
    # the function works wherever x lives, not only on the module default.
    idx_base = torch.arange(0, batch_size, device=x.device).view(-1, 1, 1) * num_points
    idx = (idx + idx_base).view(-1)

    _, num_dims, _ = x.size()

    # (batch_size, num_dims, num_points) -> (batch_size, num_points, num_dims)
    x = x.transpose(2, 1).contiguous()
    feature = x.view(batch_size * num_points, -1)[idx, :]
    feature = feature.view(batch_size, num_points, k, num_dims)
    x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)

    return torch.cat((feature, x), dim=3).permute(0, 3, 1, 2)
65
+
66
+
67
def cycle_consistency(rotation_ab, translation_ab, rotation_ba, translation_ba):
    """Penalise disagreement between the forward and backward estimates.

    The loss is the MSE between ``rotation_ab @ rotation_ba`` and the
    identity plus the MSE between ``translation_ab`` and ``-translation_ba``.

    Args:
        rotation_ab: (batch, 3, 3) rotation estimates a -> b.
        translation_ab: translation estimates a -> b.
        rotation_ba: (batch, 3, 3) rotation estimates b -> a.
        translation_ba: translation estimates b -> a.

    Returns:
        Scalar tensor with the summed loss.
    """
    num_pairs = rotation_ab.size(0)
    eye = torch.eye(3, device=rotation_ab.device).unsqueeze(0).repeat(num_pairs, 1, 1)

    rotation_term = F.mse_loss(torch.matmul(rotation_ab, rotation_ba), eye)
    translation_term = F.mse_loss(translation_ab, -translation_ba)
    return rotation_term + translation_term
71
+
72
+
73
class PointNet(nn.Module):
    """Point-wise encoder: five stages of 1x1 Conv1d -> BatchNorm1d -> ReLU.

    Channel widths are 3 -> 64 -> 64 -> 64 -> 128 -> emb_dims.

    Args:
        emb_dims: number of output channels of the last stage.
    """

    def __init__(self, emb_dims=512):
        super(PointNet, self).__init__()
        widths = (3, 64, 64, 64, 128, emb_dims)
        # All convolutions are registered before the batch norms so that
        # parameter order and state_dict keys are conv1..conv5, bn1..bn5.
        for stage in range(1, len(widths)):
            conv = nn.Conv1d(widths[stage - 1], widths[stage], kernel_size=1, bias=False)
            setattr(self, 'conv{}'.format(stage), conv)
        for stage in range(1, len(widths)):
            setattr(self, 'bn{}'.format(stage), nn.BatchNorm1d(widths[stage]))

    def forward(self, x):
        """Encode x of shape (batch, 3, num_points) to (batch, emb_dims, num_points)."""
        for stage in range(1, 6):
            conv = getattr(self, 'conv{}'.format(stage))
            norm = getattr(self, 'bn{}'.format(stage))
            x = F.relu(norm(conv(x)))
        return x
94
+
95
+
96
class DGCNN(nn.Module):
    """Graph-based encoder built on ``get_graph_feature``.

    Four stages each rebuild a k-NN graph feature, apply a 1x1 Conv2d,
    BatchNorm2d and LeakyReLU(0.2), then max-pool over the neighbour axis.
    The four pooled maps (64 + 64 + 128 + 256 = 512 channels) are
    concatenated and projected to emb_dims channels by a fifth stage.

    Args:
        emb_dims: number of output channels.
    """

    def __init__(self, emb_dims=512):
        super(DGCNN, self).__init__()
        conv_io = ((6, 64), (64 * 2, 64), (64 * 2, 128), (128 * 2, 256), (512, emb_dims))
        # Convolutions are registered before the batch norms so that parameter
        # order and state_dict keys are conv1..conv5 followed by bn1..bn5.
        for stage, (c_in, c_out) in enumerate(conv_io, 1):
            setattr(self, 'conv{}'.format(stage), nn.Conv2d(c_in, c_out, kernel_size=1, bias=False))
        for stage, (_, c_out) in enumerate(conv_io, 1):
            setattr(self, 'bn{}'.format(stage), nn.BatchNorm2d(c_out))

    def forward(self, x):
        """Encode x of shape (batch, num_dims, num_points) to (batch, emb_dims, num_points)."""
        batch_size, _, num_points = x.size()

        pooled = []
        for stage in range(1, 5):
            conv = getattr(self, 'conv{}'.format(stage))
            norm = getattr(self, 'bn{}'.format(stage))
            x = get_graph_feature(x)
            x = F.leaky_relu(norm(conv(x)), negative_slope=0.2)
            # Max over the neighbour axis; the pooled map feeds the next stage.
            x = x.max(dim=-1, keepdim=True)[0]
            pooled.append(x)

        x = torch.cat(pooled, dim=1)
        x = F.leaky_relu(self.bn5(self.conv5(x)), negative_slope=0.2)
        return x.view(batch_size, -1, num_points)
132
+
133
+
134
class MLPHead(nn.Module):
    """Regress a rotation and a translation from a pair of embeddings.

    The source and target embeddings are concatenated along the channel
    axis, max-pooled over the last axis and passed through a three-layer MLP.
    Two linear projections then produce a 4-vector (normalised to unit
    length and converted to a rotation matrix) and a 3-vector translation.

    Args:
        emb_dims: channel count of each input embedding.
    """

    def __init__(self, emb_dims):
        super(MLPHead, self).__init__()
        n_emb_dims = emb_dims
        self.n_emb_dims = n_emb_dims
        self.nn = nn.Sequential(nn.Linear(n_emb_dims*2, n_emb_dims//2),
                                nn.BatchNorm1d(n_emb_dims//2),
                                nn.ReLU(),
                                nn.Linear(n_emb_dims//2, n_emb_dims//4),
                                nn.BatchNorm1d(n_emb_dims//4),
                                nn.ReLU(),
                                nn.Linear(n_emb_dims//4, n_emb_dims//8),
                                nn.BatchNorm1d(n_emb_dims//8),
                                nn.ReLU())
        self.proj_rot = nn.Linear(n_emb_dims//8, 4)
        self.proj_trans = nn.Linear(n_emb_dims//8, 3)

    def forward(self, *input):
        """Return ``(rotation, translation)`` for the given embeddings.

        Only ``input[0]`` (source embedding) and ``input[1]`` (target
        embedding) are read; extra positional arguments are ignored.
        """
        src_embedding = input[0]
        tgt_embedding = input[1]
        embedding = torch.cat((src_embedding, tgt_embedding), dim=1)
        embedding = self.nn(embedding.max(dim=-1)[0])
        rotation = self.proj_rot(embedding)
        rotation = rotation / torch.norm(rotation, p=2, dim=1, keepdim=True)
        translation = self.proj_trans(embedding)
        # A bare `quat2mat` is not defined or imported in this module, so the
        # call is routed through the imported transform_functions module.
        # NOTE(review): assumes transform_functions exports quat2mat - confirm.
        return transform.quat2mat(rotation), translation
160
+
161
+
162
class TemperatureNet(nn.Module):
    """Predict a per-sample temperature from the gap between two embeddings.

    Args:
        emb_dims: channel count of the input embeddings.
        temp_factor: the output is clamped to [1 / temp_factor, temp_factor].
    """

    def __init__(self, emb_dims, temp_factor):
        super(TemperatureNet, self).__init__()
        self.n_emb_dims = emb_dims
        self.temp_factor = temp_factor

        # Three Linear -> BatchNorm1d -> ReLU stages of width 128, followed by
        # a single-output Linear -> ReLU.
        layers = []
        in_features = self.n_emb_dims
        for _ in range(3):
            layers.extend([nn.Linear(in_features, 128), nn.BatchNorm1d(128), nn.ReLU()])
            in_features = 128
        layers.extend([nn.Linear(128, 1), nn.ReLU()])
        self.nn = nn.Sequential(*layers)

        # Last residual seen by forward(); None until the first call.
        self.feature_disparity = None

    def forward(self, *input):
        """Return ``(temperature, residual)`` for ``input[0]`` and ``input[1]``.

        The residual is the absolute difference of the two embeddings after
        averaging each over dim 2; it is also stored on
        ``self.feature_disparity``.
        """
        src_mean = input[0].mean(dim=2)
        tgt_mean = input[1].mean(dim=2)
        residual = torch.abs(src_mean - tgt_mean)
        self.feature_disparity = residual

        lower = 1.0 / self.temp_factor
        upper = 1.0 * self.temp_factor
        temperature = torch.clamp(self.nn(residual), lower, upper)
        return temperature, residual
190
+
191
+
192
class SVDHead(nn.Module):
    """Estimate a rigid transform from soft correspondences via SVD.

    A score matrix between source and target embeddings is turned into a
    (soft or hard) matching, target points are blended into one virtual
    correspondence per source point, and the rotation/translation aligning
    the source to those correspondences is recovered from an SVD of the
    cross-covariance matrix.

    Args:
        emb_dims: channel count of the embeddings (stored only).
        cat_sampler: ``'softmax'`` or ``'gumbel_softmax'``.
    """

    def __init__(self, emb_dims, cat_sampler):
        super(SVDHead, self).__init__()
        self.n_emb_dims = emb_dims
        self.cat_sampler = cat_sampler
        # `reflect` and `temperature` are kept as parameters (they are part of
        # the state_dict) although forward() in this class does not read them.
        self.reflect = nn.Parameter(torch.eye(3), requires_grad=False)
        self.reflect[2, 2] = -1
        self.temperature = nn.Parameter(torch.ones(1)*0.5, requires_grad=True)
        # Plain tensor (not a buffer) counting training-mode forward calls.
        self.my_iter = torch.ones(1)

    def forward(self, *input):
        """Return ``(R, t)`` aligning src to tgt.

        Args:
            input[0]: source embedding, (batch, emb, num_points).
            input[1]: target embedding, (batch, emb, num_points).
            input[2]: source points, (batch, 3, num_points).
            input[3]: target points, (batch, 3, num_points).
            input[4]: per-sample temperature with batch elements.

        Returns:
            R of shape (batch, 3, 3) and t of shape (batch, 3), both on the
            device of the source points.

        Raises:
            Exception: if ``cat_sampler`` is not a supported sampler.
        """
        src_embedding = input[0]
        tgt_embedding = input[1]
        src = input[2]
        tgt = input[3]
        batch_size, num_dims, num_points = src.size()
        temperature = input[4].view(batch_size, 1, 1)

        if self.cat_sampler == 'softmax':
            d_k = src_embedding.size(1)
            scores = torch.matmul(src_embedding.transpose(2, 1).contiguous(), tgt_embedding) / math.sqrt(d_k)
            scores = torch.softmax(temperature*scores, dim=2)
        elif self.cat_sampler == 'gumbel_softmax':
            d_k = src_embedding.size(1)
            scores = torch.matmul(src_embedding.transpose(2, 1).contiguous(), tgt_embedding) / math.sqrt(d_k)
            scores = scores.view(batch_size*num_points, num_points)
            # One temperature per score row, grouped by batch element.
            temperature = temperature.repeat(1, num_points, 1).view(-1, 1)
            scores = F.gumbel_softmax(scores, tau=temperature, hard=True)
            scores = scores.view(batch_size, num_points, num_points)
        else:
            raise Exception('not implemented')

        # Virtual correspondence of every source point.
        src_corr = torch.matmul(tgt, scores.transpose(2, 1).contiguous())

        src_centered = src - src.mean(dim=2, keepdim=True)
        src_corr_centered = src_corr - src_corr.mean(dim=2, keepdim=True)

        # The per-sample SVD is run on the CPU.
        H = torch.matmul(src_centered, src_corr_centered.transpose(2, 1).contiguous()).cpu()

        R = []
        for i in range(src.size(0)):
            u, s, v = torch.svd(H[i])
            r = torch.matmul(v, u.transpose(1, 0)).contiguous()
            r_det = torch.det(r).item()
            # diag(1, 1, det) flips the last axis when v @ u^T is a reflection,
            # so the result is always a proper rotation. Built with v's dtype
            # and device so non-float32 inputs also work.
            diag = torch.diag(torch.tensor([1.0, 1.0, r_det], dtype=v.dtype, device=v.device))
            r = torch.matmul(torch.matmul(v, diag), u.transpose(1, 0)).contiguous()
            R.append(r)

        # Return to the device of the inputs rather than a module-level default,
        # which may differ from where the caller placed the model and data.
        R = torch.stack(R, dim=0).to(src.device)

        t = torch.matmul(-R, src.mean(dim=2, keepdim=True)) + src_corr.mean(dim=2, keepdim=True)
        if self.training:
            self.my_iter += 1
        return R, t.view(batch_size, 3)
250
+
251
+
252
class KeyPointNet(nn.Module):
    """Keep the points whose embeddings have the largest L2 norm.

    Args:
        num_keypoints: number of points kept from each cloud.
    """

    def __init__(self, num_keypoints):
        super(KeyPointNet, self).__init__()
        self.num_keypoints = num_keypoints

    def forward(self, *input):
        """Select keypoints and their embeddings.

        Args:
            input[0]: source points, (batch, 3, num_points).
            input[1]: target points, (batch, 3, num_points).
            input[2]: source embedding, (batch, num_dims, num_points).
            input[3]: target embedding, (batch, num_dims, num_points).

        Returns:
            ``(src_keypoints, tgt_keypoints, src_embedding, tgt_embedding)``
            restricted to the selected points; the order of the selected
            points is whatever the unsorted top-k returns.
        """
        src, tgt = input[0], input[1]
        src_embedding, tgt_embedding = input[2], input[3]
        _, num_dims, _ = src_embedding.size()

        def top_indices(embedding):
            # Per-point norm over the channel axis, then unsorted top-k.
            saliency = torch.norm(embedding, dim=1, keepdim=True)
            return torch.topk(saliency, k=self.num_keypoints, dim=2, sorted=False)[1]

        def pick(values, indices, channels):
            # Broadcast the (batch, 1, k) indices over every channel.
            return torch.gather(values, dim=2, index=indices.repeat(1, channels, 1))

        src_idx = top_indices(src_embedding)
        tgt_idx = top_indices(tgt_embedding)

        return (pick(src, src_idx, 3),
                pick(tgt, tgt_idx, 3),
                pick(src_embedding, src_idx, num_dims),
                pick(tgt_embedding, tgt_idx, num_dims))
278
+
279
+
280
class PRNet(nn.Module):
    """Iterative registration network producing a rigid transform src -> tgt.

    Each iteration embeds both clouds, optionally refines the embeddings with
    an attention module, keeps a subset of keypoints, predicts a temperature
    and lets a head estimate a rotation and translation. The estimates are
    composed over ``num_iters`` iterations.

    Args:
        emb_nn: ``'pointnet'`` or ``'dgcnn'`` embedding network.
        attention: ``'identity'`` or ``'transformer'``.
        head: ``'mlp'`` or ``'svd'``.
        emb_dims: embedding channel count.
        num_keypoints: points kept by the keypoint selector.
        num_subsampled_points: when equal to ``num_keypoints`` the keypoint
            selector is replaced by an identity module.
        num_iters: number of alignment iterations.
        cycle_consistency_loss: weight of the cycle-consistency term.
        feature_alignment_loss: weight of the feature-alignment term.
        discount_factor: per-iteration multiplier applied to all loss terms.
        input_shape: ``'bnc'`` if clouds arrive as (batch, points, 3); any
            other value leaves the inputs un-permuted.

    Raises:
        Exception: for an unknown ``emb_nn``, ``attention`` or ``head``.
    """

    def __init__(self, emb_nn='dgcnn', attention='transformer', head='svd', emb_dims=512, num_keypoints=512, num_subsampled_points=768, num_iters=3, cycle_consistency_loss=0.1, feature_alignment_loss=0.1, discount_factor=0.9, input_shape='bnc'):
        super(PRNet, self).__init__()
        self.emb_dims = emb_dims
        self.num_keypoints = num_keypoints
        self.num_subsampled_points = num_subsampled_points
        self.num_iters = num_iters
        self.discount_factor = discount_factor
        self.feature_alignment_loss = feature_alignment_loss
        self.cycle_consistency_loss = cycle_consistency_loss
        self.input_shape = input_shape

        if emb_nn == 'pointnet':
            self.emb_nn = PointNet(emb_dims=self.emb_dims)
        elif emb_nn == 'dgcnn':
            self.emb_nn = DGCNN(emb_dims=self.emb_dims)
        else:
            raise Exception('Not implemented')

        if attention == 'identity':
            self.attention = Identity()
        elif attention == 'transformer':
            self.attention = Transformer(emb_dims=self.emb_dims, n_blocks=1, dropout=0.0, ff_dims=1024, n_heads=4)
        else:
            raise Exception("Not implemented")

        self.temp_net = TemperatureNet(emb_dims=self.emb_dims, temp_factor=100)

        if head == 'mlp':
            self.head = MLPHead(emb_dims=self.emb_dims)
        elif head == 'svd':
            self.head = SVDHead(emb_dims=self.emb_dims, cat_sampler='softmax')
        else:
            raise Exception('Not implemented')

        if self.num_keypoints != self.num_subsampled_points:
            self.keypointnet = KeyPointNet(num_keypoints=self.num_keypoints)
        else:
            self.keypointnet = Identity()

    def predict_embedding(self, *input):
        """Embed both clouds and select keypoints.

        Args:
            input[0]: source points, channel-first.
            input[1]: target points, channel-first.

        Returns:
            ``(src, tgt, src_embedding, tgt_embedding, temperature,
            feature_disparity)`` after attention and keypoint selection.
        """
        src = input[0]
        tgt = input[1]
        src_embedding = self.emb_nn(src)
        tgt_embedding = self.emb_nn(tgt)

        src_embedding_p, tgt_embedding_p = self.attention(src_embedding, tgt_embedding)

        # Residual connection around the attention module.
        src_embedding = src_embedding + src_embedding_p
        tgt_embedding = tgt_embedding + tgt_embedding_p

        src, tgt, src_embedding, tgt_embedding = self.keypointnet(src, tgt, src_embedding, tgt_embedding)

        temperature, feature_disparity = self.temp_net(src_embedding, tgt_embedding)

        return src, tgt, src_embedding, tgt_embedding, temperature, feature_disparity

    # Single Pass Alignment Module for PRNet
    def spam(self, *input):
        """Run one alignment pass in both directions.

        Returns:
            ``(rotation_ab, translation_ab, rotation_ba, translation_ba,
            feature_disparity)``.
        """
        src, tgt, src_embedding, tgt_embedding, temperature, feature_disparity = self.predict_embedding(*input)
        rotation_ab, translation_ab = self.head(src_embedding, tgt_embedding, src, tgt, temperature)
        rotation_ba, translation_ba = self.head(tgt_embedding, src_embedding, tgt, src, temperature)
        return rotation_ab, translation_ab, rotation_ba, translation_ba, feature_disparity

    def predict_keypoint_correspondence(self, *input):
        """Return keypoints and a hard (one-hot) correspondence matrix.

        Returns:
            ``(src, tgt, scores)`` where ``scores`` has shape
            (batch, num_points, num_points).
        """
        src, tgt, src_embedding, tgt_embedding, temperature, _ = self.predict_embedding(*input)
        batch_size, num_dims, num_points = src.size()
        d_k = src_embedding.size(1)
        scores = torch.matmul(src_embedding.transpose(2, 1).contiguous(), tgt_embedding) / math.sqrt(d_k)
        scores = scores.view(batch_size*num_points, num_points)
        # The temperature net emits one value per sample as a 2-D tensor.
        # It must be made (batch, 1, 1) before repeating: repeating the 2-D
        # tensor with three factors would tile it batch-interleaved and pair
        # score rows with the temperature of the wrong sample when batch > 1.
        temperature = temperature.view(batch_size, 1, 1).repeat(1, num_points, 1).view(-1, 1)
        scores = F.gumbel_softmax(scores, tau=temperature, hard=True)
        scores = scores.view(batch_size, num_points, num_points)
        return src, tgt, scores

    def forward(self, *input):
        """Register src onto tgt, optionally computing the training loss.

        Accepted positional forms:
            ``(src, tgt)``: inference only.
            ``(src, tgt, T)``: ``T`` is a (batch, >=3, 4) transform whose
                top-left 3x3 block is the rotation and whose 4th column is the
                translation.
            ``(src, tgt, rotation_ab, translation_ab)``: explicit ground truth.

        Returns:
            Dict with ``'est_R'``, ``'est_t'``, ``'est_T'``,
            ``'transformed_source'`` and, when ground truth was given,
            ``'loss'``.

        Raises:
            ValueError: if the number of positional inputs is not 2, 3 or 4.
        """
        num_inputs = len(input)
        if num_inputs == 2:
            src, tgt = input
            calculate_loss = False
        elif num_inputs == 3:
            src, tgt, transformation = input
            rotation_ab = transformation[:, :3, :3]
            translation_ab = transformation[:, :3, 3].view(-1, 3)
            calculate_loss = True
        elif num_inputs == 4:
            src, tgt, rotation_ab, translation_ab = input
            calculate_loss = True
        else:
            raise ValueError('PRNet.forward expects 2, 3 or 4 positional inputs, got {}'.format(num_inputs))

        # Internally the clouds are processed channel-first.
        if self.input_shape == 'bnc':
            src, tgt = src.permute(0, 2, 1), tgt.permute(0, 2, 1)

        batch_size = src.size(0)
        identity = torch.eye(3, device=src.device).unsqueeze(0).repeat(batch_size, 1, 1)

        # Running composition of the per-iteration estimates.
        rotation_ab_pred = torch.eye(3, device=src.device, dtype=torch.float32).view(1, 3, 3).repeat(batch_size, 1, 1)
        translation_ab_pred = torch.zeros(3, device=src.device, dtype=torch.float32).view(1, 3).repeat(batch_size, 1)

        total_loss = 0

        for i in range(self.num_iters):
            rotation_ab_pred_i, translation_ab_pred_i, rotation_ba_pred_i, translation_ba_pred_i, feature_disparity = self.spam(src, tgt)

            rotation_ab_pred = torch.matmul(rotation_ab_pred_i, rotation_ab_pred)
            translation_ab_pred = torch.matmul(rotation_ab_pred_i, translation_ab_pred.unsqueeze(2)).squeeze(2) + translation_ab_pred_i

            if calculate_loss:
                discount = self.discount_factor**i
                loss = (F.mse_loss(torch.matmul(rotation_ab_pred.transpose(2, 1), rotation_ab), identity)
                        + F.mse_loss(translation_ab_pred, translation_ab)) * discount

                feature_alignment_loss = feature_disparity.mean() * self.feature_alignment_loss * discount
                cycle_consistency_loss = cycle_consistency(rotation_ab_pred_i, translation_ab_pred_i,
                                                           rotation_ba_pred_i, translation_ba_pred_i) \
                    * self.cycle_consistency_loss * discount

                total_loss = total_loss + loss + feature_alignment_loss + cycle_consistency_loss

            # Move the source by this iteration's estimate before the next pass.
            if self.input_shape == 'bnc':
                src = transform.transform_point_cloud(src.permute(0, 2, 1), rotation_ab_pred_i, translation_ab_pred_i).permute(0, 2, 1)
            else:
                # NOTE(review): here the channel-first tensor is passed
                # unpermuted; confirm transform_point_cloud supports that layout.
                src = transform.transform_point_cloud(src, rotation_ab_pred_i, translation_ab_pred_i)

        # Hand the clouds back in the caller's layout.
        if self.input_shape == 'bnc':
            src, tgt = src.permute(0, 2, 1), tgt.permute(0, 2, 1)

        result = {'est_R': rotation_ab_pred,
                  'est_t': translation_ab_pred,
                  'est_T': transform.convert2transformation(rotation_ab_pred, translation_ab_pred),
                  'transformed_source': src}

        if calculate_loss:
            result['loss'] = total_loss
        return result
422
+
423
+
424
if __name__ == '__main__':
    # Smoke test: one forward pass with ground truth on random clouds.
    # The model is moved to the same device as the inputs, the inputs are
    # created with factory functions (torch.tensor takes data, not a shape),
    # and the outputs are read from the dict that forward() returns.
    model = PRNet().to(device)
    src = torch.rand(10, 1024, 3, device=device)
    tgt = torch.rand(10, 768, 3, device=device)
    rotation_ab = torch.eye(3, device=device).unsqueeze(0).repeat(10, 1, 1)
    translation_ab = torch.zeros(10, 3, device=device)
    result = model(src, tgt, rotation_ab, translation_ab)
    rotation_ab_pred, translation_ab_pred, loss = result['est_R'], result['est_t'], result['loss']