learning3d 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- learning3d/__init__.py +2 -0
- learning3d/data_utils/__init__.py +4 -0
- learning3d/data_utils/dataloaders.py +454 -0
- learning3d/data_utils/user_data.py +119 -0
- learning3d/examples/test_dcp.py +139 -0
- learning3d/examples/test_deepgmr.py +144 -0
- learning3d/examples/test_flownet.py +113 -0
- learning3d/examples/test_masknet.py +159 -0
- learning3d/examples/test_masknet2.py +162 -0
- learning3d/examples/test_pcn.py +118 -0
- learning3d/examples/test_pcrnet.py +120 -0
- learning3d/examples/test_pnlk.py +121 -0
- learning3d/examples/test_pointconv.py +126 -0
- learning3d/examples/test_pointnet.py +121 -0
- learning3d/examples/test_prnet.py +126 -0
- learning3d/examples/test_rpmnet.py +120 -0
- learning3d/examples/train_PointNetLK.py +240 -0
- learning3d/examples/train_dcp.py +249 -0
- learning3d/examples/train_deepgmr.py +244 -0
- learning3d/examples/train_flownet.py +259 -0
- learning3d/examples/train_masknet.py +239 -0
- learning3d/examples/train_pcn.py +216 -0
- learning3d/examples/train_pcrnet.py +228 -0
- learning3d/examples/train_pointconv.py +245 -0
- learning3d/examples/train_pointnet.py +244 -0
- learning3d/examples/train_prnet.py +229 -0
- learning3d/examples/train_rpmnet.py +228 -0
- learning3d/losses/__init__.py +12 -0
- learning3d/losses/chamfer_distance.py +51 -0
- learning3d/losses/classification.py +14 -0
- learning3d/losses/correspondence_loss.py +10 -0
- learning3d/losses/cuda/chamfer_distance/__init__.py +1 -0
- learning3d/losses/cuda/chamfer_distance/chamfer_distance.cpp +185 -0
- learning3d/losses/cuda/chamfer_distance/chamfer_distance.cu +209 -0
- learning3d/losses/cuda/chamfer_distance/chamfer_distance.py +66 -0
- learning3d/losses/cuda/emd_torch/pkg/emd_loss_layer.py +41 -0
- learning3d/losses/cuda/emd_torch/pkg/include/cuda/emd.cuh +347 -0
- learning3d/losses/cuda/emd_torch/pkg/include/cuda_helper.h +18 -0
- learning3d/losses/cuda/emd_torch/pkg/include/emd.h +54 -0
- learning3d/losses/cuda/emd_torch/pkg/layer/__init__.py +1 -0
- learning3d/losses/cuda/emd_torch/pkg/layer/emd_loss_layer.py +40 -0
- learning3d/losses/cuda/emd_torch/pkg/src/cuda/emd.cu +70 -0
- learning3d/losses/cuda/emd_torch/pkg/src/emd.cpp +1 -0
- learning3d/losses/cuda/emd_torch/setup.py +29 -0
- learning3d/losses/emd.py +16 -0
- learning3d/losses/frobenius_norm.py +21 -0
- learning3d/losses/rmse_features.py +16 -0
- learning3d/models/__init__.py +23 -0
- learning3d/models/classifier.py +41 -0
- learning3d/models/dcp.py +92 -0
- learning3d/models/deepgmr.py +165 -0
- learning3d/models/dgcnn.py +92 -0
- learning3d/models/flownet3d.py +446 -0
- learning3d/models/masknet.py +84 -0
- learning3d/models/masknet2.py +264 -0
- learning3d/models/pcn.py +164 -0
- learning3d/models/pcrnet.py +74 -0
- learning3d/models/pointconv.py +108 -0
- learning3d/models/pointnet.py +108 -0
- learning3d/models/pointnetlk.py +173 -0
- learning3d/models/pooling.py +15 -0
- learning3d/models/ppfnet.py +102 -0
- learning3d/models/prnet.py +431 -0
- learning3d/models/rpmnet.py +359 -0
- learning3d/models/segmentation.py +38 -0
- learning3d/ops/__init__.py +0 -0
- learning3d/ops/data_utils.py +45 -0
- learning3d/ops/invmat.py +134 -0
- learning3d/ops/quaternion.py +218 -0
- learning3d/ops/se3.py +157 -0
- learning3d/ops/sinc.py +229 -0
- learning3d/ops/so3.py +213 -0
- learning3d/ops/transform_functions.py +342 -0
- learning3d/utils/__init__.py +9 -0
- learning3d/utils/lib/build/lib.linux-x86_64-3.5/pointnet2_cuda.cpython-35m-x86_64-linux-gnu.so +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/ball_query.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/ball_query_gpu.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/group_points.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/group_points_gpu.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/interpolate.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/interpolate_gpu.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/pointnet2_api.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/sampling.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/sampling_gpu.o +0 -0
- learning3d/utils/lib/dist/pointnet2-0.0.0-py3.5-linux-x86_64.egg +0 -0
- learning3d/utils/lib/pointnet2.egg-info/SOURCES.txt +14 -0
- learning3d/utils/lib/pointnet2.egg-info/dependency_links.txt +1 -0
- learning3d/utils/lib/pointnet2.egg-info/top_level.txt +1 -0
- learning3d/utils/lib/pointnet2_modules.py +160 -0
- learning3d/utils/lib/pointnet2_utils.py +318 -0
- learning3d/utils/lib/pytorch_utils.py +236 -0
- learning3d/utils/lib/setup.py +23 -0
- learning3d/utils/lib/src/ball_query.cpp +25 -0
- learning3d/utils/lib/src/ball_query_gpu.cu +67 -0
- learning3d/utils/lib/src/ball_query_gpu.h +15 -0
- learning3d/utils/lib/src/cuda_utils.h +15 -0
- learning3d/utils/lib/src/group_points.cpp +36 -0
- learning3d/utils/lib/src/group_points_gpu.cu +86 -0
- learning3d/utils/lib/src/group_points_gpu.h +22 -0
- learning3d/utils/lib/src/interpolate.cpp +65 -0
- learning3d/utils/lib/src/interpolate_gpu.cu +233 -0
- learning3d/utils/lib/src/interpolate_gpu.h +36 -0
- learning3d/utils/lib/src/pointnet2_api.cpp +25 -0
- learning3d/utils/lib/src/sampling.cpp +46 -0
- learning3d/utils/lib/src/sampling_gpu.cu +253 -0
- learning3d/utils/lib/src/sampling_gpu.h +29 -0
- learning3d/utils/pointconv_util.py +382 -0
- learning3d/utils/ppfnet_util.py +244 -0
- learning3d/utils/svd.py +59 -0
- learning3d/utils/transformer.py +243 -0
- learning3d-0.0.1.dist-info/LICENSE +21 -0
- learning3d-0.0.1.dist-info/METADATA +271 -0
- learning3d-0.0.1.dist-info/RECORD +115 -0
- learning3d-0.0.1.dist-info/WHEEL +5 -0
- learning3d-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,359 @@
|
|
1
|
+
import argparse
|
2
|
+
import logging
|
3
|
+
import numpy as np
|
4
|
+
import torch
|
5
|
+
import torch.nn as nn
|
6
|
+
import torch.nn.functional as F
|
7
|
+
|
8
|
+
from .. utils import square_distance, angle_difference
|
9
|
+
from .. ops.transform_functions import convert2transformation
|
10
|
+
from .ppfnet import PPFNet
|
11
|
+
_EPS = 1e-5 # To prevent division by zero
|
12
|
+
|
13
|
+
|
14
|
+
class ParameterPredictionNet(nn.Module):
    """PointNet-style network that predicts the scalars ``beta`` and ``alpha``.

    Both point clouds are tagged with an extra channel, concatenated, passed
    through shared 1x1 convolutions, max-pooled over all points and finally
    regressed by a small MLP.
    """

    def __init__(self, weights_dim):
        """PointNet based Parameter prediction network

        Args:
            weights_dim: Number of weights to predict (excluding beta), should be something like
                         [3], or [64, 3], for 3 types of features
        """

        super().__init__()

        self._logger = logging.getLogger(self.__class__.__name__)

        self.weights_dim = weights_dim

        # Pointnet: per-point feature extraction with shared (kernel size 1) convolutions.
        self.prepool = nn.Sequential(
            nn.Conv1d(4, 64, 1),
            nn.GroupNorm(8, 64),
            nn.ReLU(),

            nn.Conv1d(64, 64, 1),
            nn.GroupNorm(8, 64),
            nn.ReLU(),

            nn.Conv1d(64, 64, 1),
            nn.GroupNorm(8, 64),
            nn.ReLU(),

            nn.Conv1d(64, 128, 1),
            nn.GroupNorm(8, 128),
            nn.ReLU(),

            nn.Conv1d(128, 1024, 1),
            nn.GroupNorm(16, 1024),
            nn.ReLU(),
        )
        self.pooling = nn.AdaptiveMaxPool1d(1)

        # np.prod returns a NumPy scalar (a float for an empty list); convert it
        # so that nn.Linear is given, and stores, a plain Python int.
        num_outputs = 2 + int(np.prod(weights_dim))
        self.postpool = nn.Sequential(
            nn.Linear(1024, 512),
            nn.GroupNorm(16, 512),
            nn.ReLU(),

            nn.Linear(512, 256),
            nn.GroupNorm(16, 256),
            nn.ReLU(),

            nn.Linear(256, num_outputs),
        )

        self._logger.info('Predicting weights with dim {}.'.format(self.weights_dim))

    def forward(self, x):
        """ Returns beta and alpha

        Args:
            x: List containing two point clouds, x[0] = src (B, J, 3), x[1] = ref (B, K, 3)

        Returns:
            beta, alpha: Two tensors of shape (B,), each passed through softplus.
            Outputs of the last linear layer beyond the first two are not used here.
        """

        # Append a 4th channel marking which cloud a point came from (0 = src, 1 = ref).
        src_padded = F.pad(x[0], (0, 1), mode='constant', value=0)
        ref_padded = F.pad(x[1], (0, 1), mode='constant', value=1)
        concatenated = torch.cat([src_padded, ref_padded], dim=1)

        # Conv1d expects channels first: (B, J + K, 4) -> (B, 4, J + K).
        prepool_feat = self.prepool(concatenated.permute(0, 2, 1))
        pooled = torch.flatten(self.pooling(prepool_feat), start_dim=-2)
        raw_weights = self.postpool(pooled)

        beta = F.softplus(raw_weights[:, 0])
        alpha = F.softplus(raw_weights[:, 1])

        return beta, alpha
|
88
|
+
|
89
|
+
|
90
|
+
|
91
|
+
def to_numpy(tensor):
    """Return ``tensor`` as a NumPy array.

    Torch tensors are detached, moved to the CPU and converted; NumPy arrays
    are handed back unchanged.

    Raises:
        NotImplementedError: If ``tensor`` is neither a torch.Tensor nor a np.ndarray.
    """
    if isinstance(tensor, torch.Tensor):
        detached = tensor.detach()
        return detached.cpu().numpy()
    if isinstance(tensor, np.ndarray):
        return tensor
    raise NotImplementedError
|
99
|
+
|
100
|
+
|
101
|
+
def se3_transform(g, a, normals=None):
    """ Applies the SE3 transform

    Args:
        g: SE3 transformation matrix of size (3/4, 4) or (B, 3/4, 4)
        a: Points to be transformed (N, 3) or (B, N, 3)
        normals: (Optional). If provided, normals will be rotated (not translated)

    Returns:
        transformed points of size (N, 3) or (B, N, 3); when normals are given,
        a tuple (transformed points, rotated normals)

    Raises:
        NotImplementedError: If g and a do not have the same number of dimensions.
    """
    R = g[..., :3, :3]  # (B, 3, 3)
    p = g[..., :3, 3]  # (B, 3)

    # Only the case where g and a have matching rank is supported
    # (batched transform with batched points, or single with single).
    if len(g.size()) != len(a.size()):
        raise NotImplementedError(
            'se3_transform requires g and a to have the same number of dimensions, '
            'got {} and {}'.format(len(g.size()), len(a.size())))

    # Row-vector form of R @ a + p.
    b = torch.matmul(a, R.transpose(-1, -2)) + p[..., None, :]

    if normals is not None:
        rotated_normals = normals @ R.transpose(-1, -2)
        return b, rotated_normals

    return b
|
128
|
+
|
129
|
+
|
130
|
+
def match_features(feat_src, feat_ref, metric='l2'):
    """ Pairwise distance between two sets of feature vectors

    Args:
        feat_src: (B, J, C)
        feat_ref: (B, K, C)
        metric: either 'angle' or 'l2' (squared euclidean)

    Returns:
        Matching matrix (B, J, K). The i'th row describes how well the i'th
        point in the src agrees with every point in the ref.

    Raises:
        NotImplementedError: For any other value of ``metric``.
    """
    assert feat_src.shape[-1] == feat_ref.shape[-1]

    if metric == 'l2':
        return square_distance(feat_src, feat_ref)

    if metric == 'angle':
        # Scale every feature vector to (almost) unit length; _EPS guards
        # against division by zero for all-zero features.
        src_length = torch.norm(feat_src, dim=-1, keepdim=True)
        ref_length = torch.norm(feat_ref, dim=-1, keepdim=True)
        unit_src = feat_src / (src_length + _EPS)
        unit_ref = feat_ref / (ref_length + _EPS)
        return angle_difference(unit_src, unit_ref)

    raise NotImplementedError
|
155
|
+
|
156
|
+
|
157
|
+
def sinkhorn(log_alpha, n_iters: int = 5, slack: bool = True, eps: float = -1) -> torch.Tensor:
    """ Sinkhorn normalization in log space, yielding a near doubly stochastic matrix
    whose rows and columns each sum to <= 1

    Args:
        log_alpha: log of positive matrix to apply sinkhorn normalization (B, J, K)
        n_iters (int): Number of normalization iterations
        slack (bool): Whether to include slack row and column
        eps: eps for early termination (Used only for handcrafted RPM). Set to negative to disable.

    Returns:
        log(perm_matrix): Doubly stochastic matrix (B, J, K)

    Modified from original source taken from:
        Learning Latent Permutations with Gumbel-Sinkhorn Networks
        https://github.com/HeddaCohenIndelman/Learning-Gumbel-Sinkhorn-Permutations-w-Pytorch
    """
    early_stop = eps > 0
    previous = None  # exp() of the previous iterate, only tracked when early_stop is on

    if not slack:
        for _ in range(n_iters):
            # Each row sums to 1, then each column sums to 1.
            log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=2, keepdim=True)
            log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=1, keepdim=True)

            if early_stop:
                current = torch.exp(log_alpha)
                if previous is not None:
                    change = torch.sum(torch.abs(current - previous), dim=[1, 2])
                    if torch.max(change) < eps:
                        break
                previous = current

        return log_alpha

    # Append one zero-valued slack row and one slack column: (B, J + 1, K + 1).
    padded = F.pad(log_alpha, (0, 1, 0, 1), mode='constant', value=0)

    for _ in range(n_iters):
        # Row normalization; the slack (last) row is passed through untouched.
        rows = padded[:, :-1, :]
        slack_row = padded[:, -1, None, :]
        padded = torch.cat(
            (rows - torch.logsumexp(rows, dim=2, keepdim=True), slack_row), dim=1)

        # Column normalization; the slack (last) column is passed through untouched.
        cols = padded[:, :, :-1]
        slack_col = padded[:, :, -1, None]
        padded = torch.cat(
            (cols - torch.logsumexp(cols, dim=1, keepdim=True), slack_col), dim=2)

        if early_stop:
            current = torch.exp(padded[:, :-1, :-1])
            if previous is not None:
                change = torch.sum(torch.abs(current - previous), dim=[1, 2])
                if torch.max(change) < eps:
                    break
            previous = current

    # Drop the slack row and column again.
    return padded[:, :-1, :-1]
|
219
|
+
|
220
|
+
|
221
|
+
def compute_rigid_transform(a: torch.Tensor, b: torch.Tensor, weights: torch.Tensor):
    """Weighted rigid alignment of two point sets

    Args:
        a (torch.Tensor): (B, M, 3) points
        b (torch.Tensor): (B, N, 3) points. Both a and b are multiplied by the
            same (B, M, 1) weights, so N has to be compatible with M.
        weights (torch.Tensor): (B, M)

    Returns:
        Transform T (B, 3, 4) to get from a to b, i.e. T*a = b
    """
    # Normalize the weights over the points; _EPS avoids division by zero.
    w = weights[..., None]
    w = w / (torch.sum(w, dim=1, keepdim=True) + _EPS)

    # Weighted centroids and centered coordinates.
    mean_a = torch.sum(a * w, dim=1)
    mean_b = torch.sum(b * w, dim=1)
    a_zero = a - mean_a[:, None, :]
    b_zero = b - mean_b[:, None, :]
    cov = a_zero.transpose(-2, -1) @ (b_zero * w)

    # Kabsch algorithm. Build the rotation twice, once with the third column
    # of V negated, and keep per batch element the one with positive
    # determinant so that reflections are avoided.
    u, s, v = torch.svd(cov, some=False, compute_uv=True)
    u_t = u.transpose(-1, -2)
    candidate = v @ u_t
    v_flipped = v.clone()
    v_flipped[:, :, 2] *= -1
    candidate_flipped = v_flipped @ u_t
    proper = torch.det(candidate)[:, None, None] > 0
    rot_mat = torch.where(proper, candidate, candidate_flipped)
    assert torch.all(torch.det(rot_mat) > 0)

    # Translation that maps the rotated centroid of a onto the centroid of b.
    translation = -rot_mat @ mean_a[:, :, None] + mean_b[:, :, None]

    return torch.cat((rot_mat, translation), dim=2)
|
255
|
+
|
256
|
+
|
257
|
+
class RPMNet(nn.Module):
    """Iterative registration network built from a feature extractor, a
    parameter prediction network, Sinkhorn normalization and a weighted
    rigid-transform solver.
    """

    def __init__(self, feature_model=None):
        """
        Args:
            feature_model: Module called as ``feature_model(xyz, normals)`` to
                produce per-point features. Defaults to a new ``PPFNet()``.
        """
        super().__init__()

        self.add_slack = True
        self.num_sk_iter = 5

        self.weights_net = ParameterPredictionNet(weights_dim=[0])
        # Create the default extractor here rather than in the signature: a
        # default argument is evaluated once, so every RPMNet would otherwise
        # share (and train) the very same PPFNet instance.
        self.feat_extractor = PPFNet() if feature_model is None else feature_model

    def compute_affinity(self, beta, feat_distance, alpha=0.5):
        """Compute logarithm of Initial match matrix values, i.e. log(m_jk)

        Args:
            beta: (B,) tensor
            feat_distance: (B, J, K) pairwise feature distances
            alpha: Python scalar, or (B,) tensor with one value per batch element

        Returns:
            (B, J, K) tensor ``-beta * (feat_distance - alpha)``
        """
        if isinstance(alpha, (int, float)):
            hybrid_affinity = -beta[:, None, None] * (feat_distance - alpha)
        else:
            hybrid_affinity = -beta[:, None, None] * (feat_distance - alpha[:, None, None])
        return hybrid_affinity

    @staticmethod
    def split_normals(data):
        """Split a point cloud into coordinates and normals.

        Args:
            data: (B, N, 6) tensor holding xyz followed by normals, or
                (B, N, 3) tensor holding xyz only.

        Returns:
            xyz, normals. For 3-channel input the normals are all zeros.

        Raises:
            ValueError: If the last dimension is neither 3 nor 6.
        """
        channels = data.shape[2]
        if channels == 6:
            return data[:, :, :3], data[:, :, 3:6]
        if channels == 3:
            # zeros_like keeps dtype and device in line with the coordinates.
            return data, torch.zeros_like(data)
        raise ValueError(
            'Expected points with 3 or 6 channels, got {}'.format(channels))

    def spam(self, xyz_template, norm_template, xyz_source, norm_source):
        """Compute soft correspondences of the source points in the template.

        Stores beta, alpha, the features of both clouds, the affinity and the
        permutation matrix as attributes for later use in ``forward``.

        Returns:
            For every source point, the permutation-weighted average of the
            template coordinates.
        """
        self.beta, self.alpha = self.weights_net([xyz_source, xyz_template])
        self.feat_source = self.feat_extractor(xyz_source, norm_source)
        self.feat_template = self.feat_extractor(xyz_template, norm_template)

        feat_distance = match_features(self.feat_source, self.feat_template)
        self.affinity = self.compute_affinity(self.beta, feat_distance, alpha=self.alpha)

        # Compute weighted coordinates
        log_perm_matrix = sinkhorn(self.affinity, n_iters=self.num_sk_iter, slack=self.add_slack)
        self.perm_matrix = torch.exp(log_perm_matrix)
        weighted_template = self.perm_matrix @ xyz_template / (torch.sum(self.perm_matrix, dim=2, keepdim=True) + _EPS)

        return weighted_template

    def forward(self, template, source, max_iterations: int = 1):
        """Forward pass for RPMNet

        Args:
            template: Reference points (B, K, 6), or (B, K, 3) without normals
            source: Source points (B, J, 6), or (B, J, 3) without normals
            max_iterations (int): Number of registration iterations (>= 1)

        Returns:
            Dict with the following fields:
                'est_R', 'est_t', 'est_T': Rotation, translation and full
                    transform of the last iteration (source -> template)
                'transformed_source': Source coordinates after applying est_T
                'perm_matrices_init': exp(affinity) of every iteration
                'perm_matrices': Permutation matrix of every iteration
                'weighted_template': Weighted template of every iteration
                'beta', 'alpha': Predicted parameters stacked over iterations (numpy)
                'transforms': (B, 3, 4) transform of every iteration

        Raises:
            ValueError: If max_iterations is smaller than 1.
        """
        if max_iterations < 1:
            raise ValueError(
                'max_iterations must be at least 1, got {}'.format(max_iterations))

        xyz_template, norm_template = self.split_normals(template)
        xyz_source, norm_source = self.split_normals(source)

        xyz_source_t, norm_source_t = xyz_source, norm_source

        transforms = []
        all_gamma, all_perm_matrices, all_weighted_template = [], [], []
        all_beta, all_alpha = [], []

        for _ in range(max_iterations):
            # Finding better correspondences after each iteration.
            weighted_template = self.spam(xyz_template, norm_template, xyz_source_t, norm_source_t)

            # Compute transform and transform points
            transform = compute_rigid_transform(xyz_source, weighted_template, weights=torch.sum(self.perm_matrix, dim=2))
            # Apply transformation to original source; detached so that the
            # next iteration's input carries no gradient from this estimate.
            xyz_source_t, norm_source_t = se3_transform(transform.detach(), xyz_source, norm_source)

            transforms.append(transform)
            all_gamma.append(torch.exp(self.affinity))
            all_perm_matrices.append(self.perm_matrix)
            all_weighted_template.append(weighted_template)
            all_beta.append(to_numpy(self.beta))
            all_alpha.append(to_numpy(self.alpha))

        final = transforms[-1]
        est_T = convert2transformation(final[:, :3, :3], final[:, :3, 3])
        transformed_source = torch.bmm(est_T[:, :3, :3], source[:, :, :3].permute(0, 2, 1)).permute(0, 2, 1) + est_T[:, :3, 3].unsqueeze(1)

        result = {'est_R': est_T[:, :3, :3],  # source -> template
                  'est_t': est_T[:, :3, 3],  # source -> template
                  'est_T': est_T,  # source -> template
                  'transformed_source': transformed_source}

        result['perm_matrices_init'] = all_gamma
        result['perm_matrices'] = all_perm_matrices
        result['weighted_template'] = all_weighted_template
        result['beta'] = np.stack(all_beta, axis=0)
        result['alpha'] = np.stack(all_alpha, axis=0)
        result['transforms'] = transforms

        return result
|
352
|
+
|
353
|
+
|
354
|
+
if __name__ == '__main__':
    # Smoke test: register two random clouds that carry normals (6 channels).
    template, source = torch.rand(10, 1024, 6), torch.rand(10, 1024, 6)

    net = RPMNet()
    result = net(template, source)

    # Report the output shapes instead of dropping into a debugger, so the
    # script runs unattended and needs no extra debugging package.
    print('Estimated Transform Shape: {}\n Transformed Source Shape: {}'
          .format(result['est_T'].shape, result['transformed_source'].shape))
|
@@ -0,0 +1,38 @@
|
|
1
|
+
import torch
|
2
|
+
import torch.nn as nn
|
3
|
+
import torch.nn.functional as F
|
4
|
+
|
5
|
+
|
6
|
+
class Segmentation(nn.Module):
    """Segmentation head stacked on top of a feature model.

    The output of the feature model is passed through four 1x1 convolutions
    (the first three followed by batch norm and ReLU) that map it to
    ``num_classes`` scores.
    """

    def __init__(self, feature_model, num_classes=40):
        """
        Args:
            feature_model: Module exposing an ``emb_dims`` attribute. Its output
                is fed to a Conv1d with ``emb_dims + 64`` input channels.
            num_classes: Number of output channels of the last convolution.
        """
        super().__init__()
        self.feature_model = feature_model
        self.num_classes = num_classes

        in_channels = self.feature_model.emb_dims + 64
        self.conv1 = nn.Conv1d(in_channels, 512, 1)
        self.conv2 = nn.Conv1d(512, 256, 1)
        self.conv3 = nn.Conv1d(256, 128, 1)
        self.conv4 = nn.Conv1d(128, self.num_classes, 1)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.bn3 = nn.BatchNorm1d(128)

    def forward(self, input_data):
        """Run the feature model and the convolutional head.

        Returns:
            Scores with the channel axis moved last: B x N x num_classes.
        """
        features = self.feature_model(input_data)
        hidden_layers = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        for conv, norm in hidden_layers:
            features = F.relu(norm(conv(features)))
        scores = self.conv4(features)
        return scores.permute(0, 2, 1)  # B x N x num_classes
|
28
|
+
|
29
|
+
if __name__ == '__main__':
    # Smoke test: segment a random batch with a PointNet backbone.
    from pointnet import PointNet

    x = torch.rand(10, 1024, 3)

    pn = PointNet(global_feat=False)
    seg = Segmentation(pn)
    seg_result = seg(x)

    summary = 'Input Shape: {}\n Segmentation Output Shape: {}'
    print(summary.format(x.shape, seg_result.shape))
|
File without changes
|
@@ -0,0 +1,45 @@
|
|
1
|
+
import torch
|
2
|
+
|
3
|
+
def _append_translation(rotation, translation):
    """Build (B, 4, 4) homogeneous matrices from (B, 3, 3) and (B, 3) inputs."""
    top = torch.cat([rotation, translation.unsqueeze(-1)], dim=2)  # (B, 3, 4)
    bottom = torch.tensor([[[0.0, 0.0, 0.0, 1.0]]]).repeat(top.shape[0], 1, 1).to(top)  # (B, 1, 4)
    return torch.cat([top, bottom], dim=1)


def mean_shift(template, source, p0_zero_mean, p1_zero_mean):
    """Optionally move the template and/or source cloud to zero mean.

    Args:
        template: [B, N, 3] template points
        source: [B, N, 3] source points
        p0_zero_mean: If true, subtract the per-batch mean from template
        p1_zero_mean: If true, subtract the per-batch mean from source

    Returns:
        template, source, template_mean, source_mean. The clouds are centered
        where requested. template_mean is the [B, 4, 4] translation by the
        template mean and source_mean the [B, 4, 4] translation by the negated
        source mean; a matrix whose flag is false stays a [B, 3, 3] identity.
    """
    template_mean = torch.eye(3).view(1, 3, 3).expand(template.size(0), 3, 3).to(template)  # [B, 3, 3]
    source_mean = torch.eye(3).view(1, 3, 3).expand(source.size(0), 3, 3).to(source)  # [B, 3, 3]

    if p0_zero_mean:
        p0_m = template.mean(dim=1)  # [B, N, 3] -> [B, 3]
        template_mean = _append_translation(template_mean, p0_m)
        template = template - p0_m.unsqueeze(1)

    if p1_zero_mean:
        p1_m = source.mean(dim=1)  # [B, N, 3] -> [B, 3]
        # Use the SOURCE mean here: the transform is undone afterwards as
        # trans(p0_m) * output * trans(-p1_m).
        source_mean = _append_translation(source_mean, -p1_m)
        source = source - p1_m.unsqueeze(1)

    return template, source, template_mean, source_mean
|
26
|
+
|
27
|
+
def postprocess_data(result, p0, p1, a0, a1, p0_zero_mean, p1_zero_mean):
    """Undo the zero-mean shift on the estimated transforms, in place.

    Args:
        result: Dict holding 'est_T' [B, 4, 4] and optionally
            'est_T_series' [M, B, 4, 4]
        p0, p1: Unused; kept so that existing callers keep working
        a0: [B, 4, 4] transform multiplied from the left if p0_zero_mean
        a1: [B, 4, 4] transform multiplied from the right if p1_zero_mean
        p0_zero_mean: Whether the template had been shifted to zero mean
        p1_zero_mean: Whether the source had been shifted to zero mean

    Returns:
        The same ``result`` dict with updated transforms.
    """
    # output' = trans(p0_m) * output * trans(-p1_m)
    #         = [I, p0_m;] * [R, t;] * [I, -p1_m;]
    #           [0, 1    ]   [0, 1 ]   [0, 1     ]
    est_g = result['est_T']
    if p0_zero_mean:
        est_g = a0.to(est_g).bmm(est_g)
    if p1_zero_mean:
        est_g = est_g.bmm(a1.to(est_g))
    result['est_T'] = est_g

    # Not every result carries per-iteration transforms; only touch the
    # series when it is actually there.
    if 'est_T_series' in result:
        est_gs = result['est_T_series']  # [M, B, 4, 4]
        if p0_zero_mean:
            est_gs = a0.unsqueeze(0).contiguous().to(est_gs).matmul(est_gs)
        if p1_zero_mean:
            est_gs = est_gs.matmul(a1.unsqueeze(0).contiguous().to(est_gs))
        result['est_T_series'] = est_gs

    return result
|
learning3d/ops/invmat.py
ADDED
@@ -0,0 +1,134 @@
|
|
1
|
+
""" inverse matrix """
|
2
|
+
|
3
|
+
import torch
|
4
|
+
|
5
|
+
|
6
|
+
def batch_inverse(x):
    """ M(n) -> M(n); x -> x^-1

    Args:
        x: (batch_size, n, n) tensor of square matrices

    Returns:
        (batch_size, n, n) tensor holding the inverse of every matrix
    """
    batch_size, h, w = x.size()
    assert h == w
    # inverse() handles the leading batch dimension itself, so no Python loop
    # over the individual matrices is needed.
    return x.inverse()
|
14
|
+
|
15
|
+
def batch_inverse_dx(y):
    """ backward

    Derivative of the matrix inverse, expressed through the inverse itself.

    Args:
        y: (batch_size, n, n) tensor, y = x^-1

    Returns:
        (batch_size, n, n, n, n) tensor dy with
        dy[b, j, k, m, n] = dy(j,k)/dx(m,n) = - y(j,m) * y(n,k)
    """
    # Let y(x) = x^-1. Then
    #   dy(j,k) = - y(j,m) * dx(m,n) * y(n,k)
    # and therefore
    #   dy(j,k)/dx(m,n) = - y(j,m) * y(n,k)
    batch_size, h, w = y.size()
    assert h == w

    # Line the two factors up on the axes [b, j, k, m, n] and let broadcasting
    # form every product.
    left = y[:, :, None, :, None]  # y(j,m) -> [b, j, 1, m, 1]
    right = y.transpose(1, 2)[:, None, :, None, :]  # y(n,k) -> [b, 1, k, 1, n]
    jacobian = -(left * right)

    # Callers reshape the result with view(), which needs contiguous memory.
    return jacobian.contiguous()
|
39
|
+
|
40
|
+
|
41
|
+
def batch_pinv_dx(x):
    """ returns y = (x'*x)^-1 * x' and dy/dx.

    Args:
        x: (batch_size, h, w) tensor

    Returns:
        y: (batch_size, w, h) tensor, (x'*x)^-1 * x'
        dy_dx: (batch_size, w, h, h, w) tensor, indexed [j, k, m, n] = d{y(j,k)}/d{x(m,n)}
    """
    # Derivation, with s = x'*x and b = s^-1, so that y = b * x':
    #   d{y(j,k)}/d{x(m,n)}
    #     = d{b(j,i) * x(k,i)}/d{x(m,n)}
    #     = d{b(j,i)}/d{x(m,n)} * x(k,i) + b(j,i) * d{x(k,i)}/d{x(m,n)}
    #   d{b(j,i)}/d{x(m,n)}
    #     = d{b(j,i)}/d{s(p,q)} * d{s(p,q)}/d{x(m,n)}
    #     = -b(j,p)*b(q,i) * d{s(p,q)}/d{x(m,n)}
    #   d{s(p,q)}/d{x(m,n)}
    #     = d{x(t,p)*x(t,q)}/d{x(m,n)}
    #     = d{x(t,p)}/d{x(m,n)} * x(t,q) + x(t,p) * d{x(t,q)}/d{x(m,n)}
    batch_size, h, w = x.size()
    x_t = x.transpose(1, 2)
    gram = x_t.bmm(x)
    gram_inv = batch_inverse(gram)
    y = gram_inv.bmm(x_t)

    # dx/dx: identity over the flattened (h*w) entries of x
    identity = torch.eye(h*w).to(x).unsqueeze(0).view(1, h, w, h, w)

    # ds/dx = dx(t,_)/dx * x(t,_) + x(t,_) * dx(t,_)/dx
    identity_rows = identity.view(1, h, w*h*w)  # [t, p*m*n]
    half = x_t.matmul(identity_rows).view(batch_size, w, w, h, w)  # [q, p,m,n]
    ds_dx = half.transpose(1, 2) + half  # [p, q, m, n]

    # db/ds
    db_ds = batch_inverse_dx(gram_inv)  # [j, i, p, q]

    # db/dx = db/d{s(p,q)} * d{s(p,q)}/dx
    db_flat = db_ds.view(batch_size, w*w, w*w).bmm(ds_dx.view(batch_size, w*w, h*w))
    db_dx = db_flat.view(batch_size, w, w, h, w)  # [j, i, m, n]

    # dy/dx = db(_,i)/dx * x(_,i) + b(_,i) * dx(_,i)/dx
    first = db_dx.transpose(1, 2).contiguous().view(batch_size, w, w*h*w)
    first = x.matmul(first).view(batch_size, h, w, h, w)  # [k, j, m, n]
    identity_t = identity.transpose(1, 2).contiguous().view(1, w, h*h*w)
    second = gram_inv.matmul(identity_t).view(batch_size, w, h, h, w)  # [j, k, m, n]
    dy_dx = first.transpose(1, 2) + second

    return y, dy_dx
|
80
|
+
|
81
|
+
|
82
|
+
class InvMatrix(torch.autograd.Function):
|
83
|
+
""" M(n) -> M(n); x -> x^-1.
|
84
|
+
"""
|
85
|
+
@staticmethod
|
86
|
+
def forward(ctx, x):
|
87
|
+
y = batch_inverse(x)
|
88
|
+
ctx.save_for_backward(y)
|
89
|
+
return y
|
90
|
+
|
91
|
+
@staticmethod
|
92
|
+
def backward(ctx, grad_output):
|
93
|
+
y, = ctx.saved_tensors # v0.4
|
94
|
+
#y, = ctx.saved_variables # v0.3.1
|
95
|
+
batch_size, h, w = y.size()
|
96
|
+
assert h == w
|
97
|
+
|
98
|
+
# Let y(x) = x^-1 and assume any function f(y(x)).
|
99
|
+
# compute df/dx(m,n)...
|
100
|
+
# df/dx(m,n) = df/dy(j,k) * dy(j,k)/dx(m,n)
|
101
|
+
# well, df/dy is 'grad_output'
|
102
|
+
# and so we will return 'grad_input = df/dy(j,k) * dy(j,k)/dx(m,n)'
|
103
|
+
|
104
|
+
dy = batch_inverse_dx(y) # dy(j,k,m,n) = dy(j,k)/dx(m,n)
|
105
|
+
go = grad_output.contiguous().view(batch_size, 1, h*h) # [1, (j*k)]
|
106
|
+
ym = dy.view(batch_size, h*h, h*h) # [(j*k), (m*n)]
|
107
|
+
r = go.bmm(ym) # [1, (m*n)]
|
108
|
+
grad_input = r.view(batch_size, h, h) # [m, n]
|
109
|
+
|
110
|
+
return grad_input
|
111
|
+
|
112
|
+
|
113
|
+
|
114
|
+
if __name__ == '__main__':
    def test():
        """Print the autograd gradient of the pseudo-inverse next to the analytic one."""
        x = torch.randn(2, 3, 2)
        x_val = x.requires_grad_()

        # Gradient of sum(pinv(x)) obtained by backpropagating through InvMatrix.
        gram = x_val.transpose(1, 2).bmm(x_val)
        gram_inv = InvMatrix.apply(gram)
        pinv = gram_inv.bmm(x_val.transpose(1, 2))
        pinv.sum().backward()
        t1 = x_val.grad

        # The same quantity from the analytic Jacobian, summed over the output axes.
        y, dy_dx = batch_pinv_dx(x)
        t2 = dy_dx.sum(1).sum(1)

        for value in (t1, t2, t1 - t2):
            print(value)

    test()
|
133
|
+
|
134
|
+
#EOF
|