learning3d 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115) hide show
  1. learning3d/__init__.py +2 -0
  2. learning3d/data_utils/__init__.py +4 -0
  3. learning3d/data_utils/dataloaders.py +454 -0
  4. learning3d/data_utils/user_data.py +119 -0
  5. learning3d/examples/test_dcp.py +139 -0
  6. learning3d/examples/test_deepgmr.py +144 -0
  7. learning3d/examples/test_flownet.py +113 -0
  8. learning3d/examples/test_masknet.py +159 -0
  9. learning3d/examples/test_masknet2.py +162 -0
  10. learning3d/examples/test_pcn.py +118 -0
  11. learning3d/examples/test_pcrnet.py +120 -0
  12. learning3d/examples/test_pnlk.py +121 -0
  13. learning3d/examples/test_pointconv.py +126 -0
  14. learning3d/examples/test_pointnet.py +121 -0
  15. learning3d/examples/test_prnet.py +126 -0
  16. learning3d/examples/test_rpmnet.py +120 -0
  17. learning3d/examples/train_PointNetLK.py +240 -0
  18. learning3d/examples/train_dcp.py +249 -0
  19. learning3d/examples/train_deepgmr.py +244 -0
  20. learning3d/examples/train_flownet.py +259 -0
  21. learning3d/examples/train_masknet.py +239 -0
  22. learning3d/examples/train_pcn.py +216 -0
  23. learning3d/examples/train_pcrnet.py +228 -0
  24. learning3d/examples/train_pointconv.py +245 -0
  25. learning3d/examples/train_pointnet.py +244 -0
  26. learning3d/examples/train_prnet.py +229 -0
  27. learning3d/examples/train_rpmnet.py +228 -0
  28. learning3d/losses/__init__.py +12 -0
  29. learning3d/losses/chamfer_distance.py +51 -0
  30. learning3d/losses/classification.py +14 -0
  31. learning3d/losses/correspondence_loss.py +10 -0
  32. learning3d/losses/cuda/chamfer_distance/__init__.py +1 -0
  33. learning3d/losses/cuda/chamfer_distance/chamfer_distance.cpp +185 -0
  34. learning3d/losses/cuda/chamfer_distance/chamfer_distance.cu +209 -0
  35. learning3d/losses/cuda/chamfer_distance/chamfer_distance.py +66 -0
  36. learning3d/losses/cuda/emd_torch/pkg/emd_loss_layer.py +41 -0
  37. learning3d/losses/cuda/emd_torch/pkg/include/cuda/emd.cuh +347 -0
  38. learning3d/losses/cuda/emd_torch/pkg/include/cuda_helper.h +18 -0
  39. learning3d/losses/cuda/emd_torch/pkg/include/emd.h +54 -0
  40. learning3d/losses/cuda/emd_torch/pkg/layer/__init__.py +1 -0
  41. learning3d/losses/cuda/emd_torch/pkg/layer/emd_loss_layer.py +40 -0
  42. learning3d/losses/cuda/emd_torch/pkg/src/cuda/emd.cu +70 -0
  43. learning3d/losses/cuda/emd_torch/pkg/src/emd.cpp +1 -0
  44. learning3d/losses/cuda/emd_torch/setup.py +29 -0
  45. learning3d/losses/emd.py +16 -0
  46. learning3d/losses/frobenius_norm.py +21 -0
  47. learning3d/losses/rmse_features.py +16 -0
  48. learning3d/models/__init__.py +23 -0
  49. learning3d/models/classifier.py +41 -0
  50. learning3d/models/dcp.py +92 -0
  51. learning3d/models/deepgmr.py +165 -0
  52. learning3d/models/dgcnn.py +92 -0
  53. learning3d/models/flownet3d.py +446 -0
  54. learning3d/models/masknet.py +84 -0
  55. learning3d/models/masknet2.py +264 -0
  56. learning3d/models/pcn.py +164 -0
  57. learning3d/models/pcrnet.py +74 -0
  58. learning3d/models/pointconv.py +108 -0
  59. learning3d/models/pointnet.py +108 -0
  60. learning3d/models/pointnetlk.py +173 -0
  61. learning3d/models/pooling.py +15 -0
  62. learning3d/models/ppfnet.py +102 -0
  63. learning3d/models/prnet.py +431 -0
  64. learning3d/models/rpmnet.py +359 -0
  65. learning3d/models/segmentation.py +38 -0
  66. learning3d/ops/__init__.py +0 -0
  67. learning3d/ops/data_utils.py +45 -0
  68. learning3d/ops/invmat.py +134 -0
  69. learning3d/ops/quaternion.py +218 -0
  70. learning3d/ops/se3.py +157 -0
  71. learning3d/ops/sinc.py +229 -0
  72. learning3d/ops/so3.py +213 -0
  73. learning3d/ops/transform_functions.py +342 -0
  74. learning3d/utils/__init__.py +9 -0
  75. learning3d/utils/lib/build/lib.linux-x86_64-3.5/pointnet2_cuda.cpython-35m-x86_64-linux-gnu.so +0 -0
  76. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/ball_query.o +0 -0
  77. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/ball_query_gpu.o +0 -0
  78. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/group_points.o +0 -0
  79. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/group_points_gpu.o +0 -0
  80. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/interpolate.o +0 -0
  81. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/interpolate_gpu.o +0 -0
  82. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/pointnet2_api.o +0 -0
  83. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/sampling.o +0 -0
  84. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/sampling_gpu.o +0 -0
  85. learning3d/utils/lib/dist/pointnet2-0.0.0-py3.5-linux-x86_64.egg +0 -0
  86. learning3d/utils/lib/pointnet2.egg-info/SOURCES.txt +14 -0
  87. learning3d/utils/lib/pointnet2.egg-info/dependency_links.txt +1 -0
  88. learning3d/utils/lib/pointnet2.egg-info/top_level.txt +1 -0
  89. learning3d/utils/lib/pointnet2_modules.py +160 -0
  90. learning3d/utils/lib/pointnet2_utils.py +318 -0
  91. learning3d/utils/lib/pytorch_utils.py +236 -0
  92. learning3d/utils/lib/setup.py +23 -0
  93. learning3d/utils/lib/src/ball_query.cpp +25 -0
  94. learning3d/utils/lib/src/ball_query_gpu.cu +67 -0
  95. learning3d/utils/lib/src/ball_query_gpu.h +15 -0
  96. learning3d/utils/lib/src/cuda_utils.h +15 -0
  97. learning3d/utils/lib/src/group_points.cpp +36 -0
  98. learning3d/utils/lib/src/group_points_gpu.cu +86 -0
  99. learning3d/utils/lib/src/group_points_gpu.h +22 -0
  100. learning3d/utils/lib/src/interpolate.cpp +65 -0
  101. learning3d/utils/lib/src/interpolate_gpu.cu +233 -0
  102. learning3d/utils/lib/src/interpolate_gpu.h +36 -0
  103. learning3d/utils/lib/src/pointnet2_api.cpp +25 -0
  104. learning3d/utils/lib/src/sampling.cpp +46 -0
  105. learning3d/utils/lib/src/sampling_gpu.cu +253 -0
  106. learning3d/utils/lib/src/sampling_gpu.h +29 -0
  107. learning3d/utils/pointconv_util.py +382 -0
  108. learning3d/utils/ppfnet_util.py +244 -0
  109. learning3d/utils/svd.py +59 -0
  110. learning3d/utils/transformer.py +243 -0
  111. learning3d-0.0.1.dist-info/LICENSE +21 -0
  112. learning3d-0.0.1.dist-info/METADATA +271 -0
  113. learning3d-0.0.1.dist-info/RECORD +115 -0
  114. learning3d-0.0.1.dist-info/WHEEL +5 -0
  115. learning3d-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,359 @@
1
+ import argparse
2
+ import logging
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from .. utils import square_distance, angle_difference
9
+ from .. ops.transform_functions import convert2transformation
10
+ from .ppfnet import PPFNet
11
+ _EPS = 1e-5 # To prevent division by zero
12
+
13
+
14
class ParameterPredictionNet(nn.Module):
    """PointNet-style network that regresses the two scalars ``beta`` and ``alpha``."""

    def __init__(self, weights_dim):
        """PointNet based Parameter prediction network

        Args:
            weights_dim: Number of weights to predict (excluding beta), should be something like
                         [3], or [64, 3], for 3 types of features. The last linear layer has
                         ``2 + prod(weights_dim)`` outputs.
        """
        super().__init__()

        self._logger = logging.getLogger(self.__class__.__name__)
        self.weights_dim = weights_dim

        # Shared per-point encoder; each stage is Conv1d(k=1) -> GroupNorm -> ReLU.
        # Entries are (in_channels, out_channels, norm_groups).
        encoder_stages = [
            (4, 64, 8),
            (64, 64, 8),
            (64, 64, 8),
            (64, 128, 8),
            (128, 1024, 16),
        ]
        encoder_layers = []
        for in_channels, out_channels, groups in encoder_stages:
            encoder_layers.append(nn.Conv1d(in_channels, out_channels, 1))
            encoder_layers.append(nn.GroupNorm(groups, out_channels))
            encoder_layers.append(nn.ReLU())
        self.prepool = nn.Sequential(*encoder_layers)

        # Max over the point dimension gives one global descriptor per sample
        self.pooling = nn.AdaptiveMaxPool1d(1)

        # Regression head on the pooled 1024-d descriptor
        head_layers = []
        for in_features, out_features in [(1024, 512), (512, 256)]:
            head_layers.append(nn.Linear(in_features, out_features))
            head_layers.append(nn.GroupNorm(16, out_features))
            head_layers.append(nn.ReLU())
        head_layers.append(nn.Linear(256, 2 + np.prod(weights_dim)))
        self.postpool = nn.Sequential(*head_layers)

        self._logger.info('Predicting weights with dim {}.'.format(self.weights_dim))

    def forward(self, x):
        """ Predicts beta and alpha from a pair of point clouds

        Args:
            x: List containing two point clouds, x[0] = src (B, J, 3), x[1] = ref (B, K, 3)

        Returns:
            beta, alpha: two (B,) tensors, both passed through softplus and therefore positive.
            Any additional head outputs (columns 2 onwards) are not returned.
        """
        # A 4th channel tags which cloud a point belongs to: 0 = source, 1 = reference
        tagged_src = F.pad(x[0], (0, 1), mode='constant', value=0)
        tagged_ref = F.pad(x[1], (0, 1), mode='constant', value=1)
        both_clouds = torch.cat([tagged_src, tagged_ref], dim=1)  # (B, J + K, 4)

        point_features = self.prepool(both_clouds.permute(0, 2, 1))  # (B, 1024, J + K)
        global_feature = torch.flatten(self.pooling(point_features), start_dim=-2)  # (B, 1024)
        head_output = self.postpool(global_feature)

        beta = F.softplus(head_output[:, 0])
        alpha = F.softplus(head_output[:, 1])
        return beta, alpha
88
+
89
+
90
+
91
def to_numpy(tensor):
    """Return *tensor* as a NumPy array.

    A ``torch.Tensor`` goes through ``.detach().cpu().numpy()``; an
    ``np.ndarray`` is handed back unchanged (same object).

    Raises:
        NotImplementedError: for any other input type.
    """
    if isinstance(tensor, np.ndarray):
        return tensor
    if isinstance(tensor, torch.Tensor):
        return tensor.detach().cpu().numpy()
    raise NotImplementedError
99
+
100
+
101
def se3_transform(g, a, normals=None):
    """ Applies the SE3 transform

    Args:
        g: SE3 transformation matrix of size (3/4, 4) or (B, 3/4, 4). Only the
           top-left 3x3 block (rotation) and the first three rows of the last
           column (translation) are read.
        a: Points to be transformed (N, 3) or (B, N, 3). Must have the same
           number of dimensions as g.
        normals: (Optional). If provided, normals will be rotated (not translated)

    Returns:
        transformed points of size (N, 3) or (B, N, 3); when normals are given,
        a tuple (transformed points, rotated normals).

    Raises:
        NotImplementedError: if g and a do not have the same number of
            dimensions (e.g. a batched transform with an unbatched point set).
    """
    if g.dim() != a.dim():
        raise NotImplementedError(
            'se3_transform requires g and a to have the same number of dimensions, '
            'got {} and {}'.format(g.dim(), a.dim()))

    R = g[..., :3, :3]  # (B, 3, 3)
    p = g[..., :3, 3]  # (B, 3)

    # Points are row vectors, so R is applied as a right-multiplication by R^T
    b = torch.matmul(a, R.transpose(-1, -2)) + p[..., None, :]

    if normals is None:
        return b

    rotated_normals = normals @ R.transpose(-1, -2)
    return b, rotated_normals
128
+
129
+
130
def match_features(feat_src, feat_ref, metric='l2'):
    """ Compute pairwise distance between features

    Args:
        feat_src: (B, J, C)
        feat_ref: (B, K, C)
        metric: either 'angle' or 'l2' (squared euclidean)

    Returns:
        Matching matrix (B, J, K). i'th row describes how well the i'th point
         in the src agrees with every point in the ref.

    Raises:
        NotImplementedError: if metric is neither 'l2' nor 'angle'.
    """
    assert feat_src.shape[-1] == feat_ref.shape[-1]

    if metric == 'l2':
        return square_distance(feat_src, feat_ref)

    if metric == 'angle':
        # Scale every feature vector to (approximately) unit length first;
        # _EPS keeps the division finite for all-zero features.
        src_unit = feat_src / (torch.norm(feat_src, dim=-1, keepdim=True) + _EPS)
        ref_unit = feat_ref / (torch.norm(feat_ref, dim=-1, keepdim=True) + _EPS)
        return angle_difference(src_unit, ref_unit)

    raise NotImplementedError
155
+
156
+
157
def sinkhorn(log_alpha, n_iters: int = 5, slack: bool = True, eps: float = -1) -> torch.Tensor:
    """ Run sinkhorn iterations to generate a near doubly stochastic matrix, where each row or column sum to <=1

    All arithmetic is done in the log domain: a normalisation is a subtraction
    of the logsumexp along the corresponding axis.

    Args:
        log_alpha: log of positive matrix to apply sinkhorn normalization (B, J, K)
        n_iters (int): Number of normalization iterations
        slack (bool): Whether to include slack row and column
        eps: eps for early termination (Used only for handcrafted RPM). Set to negative to disable.

    Returns:
        log(perm_matrix): Doubly stochastic matrix (B, J, K)

    Modified from original source taken from:
        Learning Latent Permutations with Gumbel-Sinkhorn Networks
        https://github.com/HeddaCohenIndelman/Learning-Gumbel-Sinkhorn-Permutations-w-Pytorch
    """
    prev_alpha = None

    if not slack:
        for _ in range(n_iters):
            # Row normalization (i.e. each row sum to 1)
            log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=2, keepdim=True)
            # Column normalization (i.e. each column sum to 1)
            log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=1, keepdim=True)

            if eps > 0:
                current = torch.exp(log_alpha)
                if prev_alpha is not None:
                    deviation = torch.abs(current - prev_alpha)
                    if torch.max(torch.sum(deviation, dim=[1, 2])) < eps:
                        break
                prev_alpha = current.clone()
        return log_alpha

    # Append one slack row and one slack column, initialised to log-value 0
    pad_layer = nn.ZeroPad2d((0, 1, 0, 1))
    padded = pad_layer(log_alpha[:, None, :, :]).squeeze(dim=1)  # (B, J + 1, K + 1)

    for _ in range(n_iters):
        # Row normalization; the slack (last) row is carried over untouched
        rows = padded[:, :-1, :]
        rows = rows - torch.logsumexp(rows, dim=2, keepdim=True)
        padded = torch.cat((rows, padded[:, -1, None, :]), dim=1)

        # Column normalization; the slack (last) column is carried over untouched
        cols = padded[:, :, :-1]
        cols = cols - torch.logsumexp(cols, dim=1, keepdim=True)
        padded = torch.cat((cols, padded[:, :, -1, None]), dim=2)

        if eps > 0:
            current = torch.exp(padded[:, :-1, :-1])
            if prev_alpha is not None:
                deviation = torch.abs(current - prev_alpha)
                if torch.max(torch.sum(deviation, dim=[1, 2])) < eps:
                    break
            prev_alpha = current.clone()

    # Drop the slack row/column again
    return padded[:, :-1, :-1]
219
+
220
+
221
def compute_rigid_transform(a: torch.Tensor, b: torch.Tensor, weights: torch.Tensor):
    """Compute rigid transforms between two point sets

    Args:
        a (torch.Tensor): (B, M, 3) points
        b (torch.Tensor): (B, M, 3) points, paired row-by-row with ``a``
            (both are multiplied by the same per-point weights)
        weights (torch.Tensor): (B, M) non-negative per-correspondence weights

    Returns:
        Transform T (B, 3, 4) to get from a to b, i.e. T*a = b
    """
    # Normalise the weights so that they (almost) sum to one per batch element
    w = weights[..., None]
    w = w / (torch.sum(w, dim=1, keepdim=True) + _EPS)

    # Weighted centroids and the weighted cross-covariance of the centred sets
    centroid_a = torch.sum(a * w, dim=1)
    centroid_b = torch.sum(b * w, dim=1)
    a_centered = a - centroid_a[:, None, :]
    b_centered = b - centroid_b[:, None, :]
    cov = a_centered.transpose(-2, -1) @ (b_centered * w)

    # Kabsch algorithm. Two candidates are built, one with the last column of V
    # negated, and the one with positive determinant is kept to avoid reflections.
    u, s, v = torch.svd(cov, some=False, compute_uv=True)
    u_t = u.transpose(-1, -2)
    v_flipped = v.clone()
    v_flipped[:, :, 2] *= -1
    candidate_pos = v @ u_t
    candidate_neg = v_flipped @ u_t
    is_proper = torch.det(candidate_pos)[:, None, None] > 0
    rot_mat = torch.where(is_proper, candidate_pos, candidate_neg)
    assert torch.all(torch.det(rot_mat) > 0)

    # Translation maps the rotated centroid of a onto the centroid of b
    translation = -rot_mat @ centroid_a[:, :, None] + centroid_b[:, :, None]

    return torch.cat((rot_mat, translation), dim=2)
255
+
256
+
257
class RPMNet(nn.Module):
    """Iterative point cloud registration network.

    Every iteration predicts the scalars beta/alpha, extracts per-point
    features for both clouds, turns the feature distances into a soft
    correspondence matrix via Sinkhorn normalisation and solves for a weighted
    rigid transform from the source to the template.
    """

    def __init__(self, feature_model=None):
        """
        Args:
            feature_model: Module called as ``feature_model(xyz, normals)`` that
                returns per-point features. When omitted (None), a new
                ``PPFNet()`` is created for this instance.
        """
        super().__init__()

        self.add_slack = True
        self.num_sk_iter = 5

        self.weights_net = ParameterPredictionNet(weights_dim=[0])
        # The default extractor is created here, not in the signature: a default
        # of ``PPFNet()`` is evaluated once when the class is defined, which
        # makes every RPMNet instance share (and jointly train) one module.
        self.feat_extractor = PPFNet() if feature_model is None else feature_model

    def compute_affinity(self, beta, feat_distance, alpha=0.5):
        """Compute logarithm of Initial match matrix values, i.e. log(m_jk)

        Args:
            beta: (B,) tensor scaling the affinity.
            feat_distance: (B, J, K) pairwise feature distances.
            alpha: Python float shared by the batch, or (B,) tensor; distances
                below alpha yield a positive log-affinity.

        Returns:
            (B, J, K) tensor ``-beta * (feat_distance - alpha)``.
        """
        if isinstance(alpha, float):
            hybrid_affinity = -beta[:, None, None] * (feat_distance - alpha)
        else:
            hybrid_affinity = -beta[:, None, None] * (feat_distance - alpha[:, None, None])
        return hybrid_affinity

    @staticmethod
    def split_normals(data):
        """Split a (B, N, 3) or (B, N, 6) cloud into coordinates and normals.

        Clouds without normals get all-zero normals of the same shape.

        Raises:
            ValueError: if the last dimension is neither 3 nor 6.
        """
        channels = data.shape[2]
        if channels == 6:
            return data[:, :, :3], data[:, :, 3:6]
        if channels == 3:
            return data, torch.zeros(data.shape).to(data.device)
        raise ValueError(
            'Expected points with 3 (xyz) or 6 (xyz + normals) channels, got {}'.format(channels))

    def spam(self, xyz_template, norm_template, xyz_source, norm_source):
        """Compute soft correspondences and the matched template point per source point.

        Stores beta, alpha, both feature sets, the affinity and the permutation
        matrix on ``self`` as a side effect.

        Returns:
            (B, J, 3) tensor: for every source point, the average of the
            template points weighted by its row of the permutation matrix.
        """
        self.beta, self.alpha = self.weights_net([xyz_source, xyz_template])
        self.feat_source = self.feat_extractor(xyz_source, norm_source)
        self.feat_template = self.feat_extractor(xyz_template, norm_template)

        feat_distance = match_features(self.feat_source, self.feat_template)
        self.affinity = self.compute_affinity(self.beta, feat_distance, alpha=self.alpha)

        # Compute weighted coordinates
        log_perm_matrix = sinkhorn(self.affinity, n_iters=self.num_sk_iter, slack=self.add_slack)
        self.perm_matrix = torch.exp(log_perm_matrix)
        row_mass = torch.sum(self.perm_matrix, dim=2, keepdim=True)
        weighted_template = self.perm_matrix @ xyz_template / (row_mass + _EPS)

        return weighted_template

    def forward(self, template, source, max_iterations: int = 1):
        """Forward pass for RPMNet

        Args:
            template: Reference points, (B, K, 3) or (B, K, 6) with normals
            source: Source points, (B, J, 3) or (B, J, 6) with normals
            max_iterations (int): Number of registration iterations (>= 1)

        Returns:
            Dict with:
                'est_R', 'est_t', 'est_T': rotation, translation and full
                    transform (source -> template) of the last iteration
                'transformed_source': source coordinates after applying est_T
                'perm_matrices_init': per-iteration exp(affinity)
                'perm_matrices': per-iteration permutation matrices
                'weighted_template': per-iteration matched template points
                'beta', 'alpha': numpy arrays stacked over iterations
                'transforms': per-iteration (B, 3, 4) transforms

        Raises:
            ValueError: if max_iterations is smaller than 1.
        """
        if max_iterations < 1:
            raise ValueError('max_iterations must be at least 1, got {}'.format(max_iterations))

        xyz_template, norm_template = self.split_normals(template)
        xyz_source, norm_source = self.split_normals(source)

        xyz_source_t, norm_source_t = xyz_source, norm_source

        transforms = []
        all_gamma, all_perm_matrices, all_weighted_template = [], [], []
        all_beta, all_alpha = [], []

        for _ in range(max_iterations):
            # Finding better correspondences after each iteration.
            weighted_template = self.spam(xyz_template, norm_template, xyz_source_t, norm_source_t)

            # Compute transform and transform points
            transform = compute_rigid_transform(xyz_source, weighted_template,
                                                weights=torch.sum(self.perm_matrix, dim=2))
            # Apply transformation to original source. The transform is detached
            # so the next iteration's inputs do not backpropagate into this one.
            xyz_source_t, norm_source_t = se3_transform(transform.detach(), xyz_source, norm_source)

            transforms.append(transform)
            all_gamma.append(torch.exp(self.affinity))
            all_perm_matrices.append(self.perm_matrix)
            all_weighted_template.append(weighted_template)
            all_beta.append(to_numpy(self.beta))
            all_alpha.append(to_numpy(self.alpha))

        final_transform = transforms[-1]
        est_T = convert2transformation(final_transform[:, :3, :3], final_transform[:, :3, 3])
        est_R = est_T[:, :3, :3]
        est_t = est_T[:, :3, 3]
        transformed_source = torch.bmm(est_R, source[:, :, :3].permute(0, 2, 1)).permute(0, 2, 1) + est_t.unsqueeze(1)

        result = {'est_R': est_R,  # source -> template
                  'est_t': est_t,  # source -> template
                  'est_T': est_T,  # source -> template
                  'transformed_source': transformed_source}

        result['perm_matrices_init'] = all_gamma
        result['perm_matrices'] = all_perm_matrices
        result['weighted_template'] = all_weighted_template
        result['beta'] = np.stack(all_beta, axis=0)
        result['alpha'] = np.stack(all_alpha, axis=0)
        result['transforms'] = transforms

        return result
352
+
353
+
354
if __name__ == '__main__':
    # Smoke test: register two random clouds (xyz + normals) and report the
    # shape of the estimated transform. The interactive debugger breakpoint
    # that used to sit here was removed; it needed the third-party ``ipdb``.
    template, source = torch.rand(10, 1024, 6), torch.rand(10, 1024, 6)

    net = RPMNet()
    result = net(template, source)
    print('Estimated transform shape: {}'.format(result['est_T'].shape))
@@ -0,0 +1,38 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
class Segmentation(nn.Module):
    """Per-point classification head stacked on a point feature extractor."""

    def __init__(self, feature_model, num_classes=40):
        """
        Args:
            feature_model: Module exposing ``emb_dims``; its output is fed to a
                Conv1d expecting ``emb_dims + 64`` channels, i.e. presumably a
                (B, emb_dims + 64, N) tensor -- confirm against the extractor used.
            num_classes: Number of scores predicted for every point.
        """
        super().__init__()
        self.feature_model = feature_model
        self.num_classes = num_classes

        in_channels = self.feature_model.emb_dims + 64
        # Point-wise MLP implemented with kernel-size-1 convolutions
        self.conv1 = nn.Conv1d(in_channels, 512, 1)
        self.conv2 = nn.Conv1d(512, 256, 1)
        self.conv3 = nn.Conv1d(256, 128, 1)
        self.conv4 = nn.Conv1d(128, self.num_classes, 1)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.bn3 = nn.BatchNorm1d(128)

    def forward(self, input_data):
        """Return per-point class scores of shape (B, N, num_classes)."""
        features = self.feature_model(input_data)

        hidden_stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        for conv, norm in hidden_stages:
            features = F.relu(norm(conv(features)))

        scores = self.conv4(features)  # B x num_classes x N
        return scores.permute(0, 2, 1)  # B x N x num_classes
28
+
29
if __name__ == '__main__':
    # Smoke test: segment a random cloud with a PointNet backbone that returns
    # per-point (not only global) features.
    from pointnet import PointNet

    x = torch.rand(10, 1024, 3)
    seg = Segmentation(PointNet(global_feat=False))
    seg_result = seg(x)

    report = 'Input Shape: {}\n Segmentation Output Shape: {}'
    print(report.format(x.shape, seg_result.shape))
File without changes
@@ -0,0 +1,45 @@
1
+ import torch
2
+
3
def mean_shift(template, source, p0_zero_mean, p1_zero_mean):
    """Optionally centre two point clouds and return the matching shift matrices.

    Args:
        template: (B, N, 3) template points.
        source: (B, M, 3) source points.
        p0_zero_mean: If true, subtract the per-cloud mean from the template.
        p1_zero_mean: If true, subtract the per-cloud mean from the source.

    Returns:
        template, source, template_mean, source_mean where
        * template / source are the (possibly centred) clouds,
        * template_mean is the (B, 4, 4) translation by +mean(template) when
          p0_zero_mean is set, otherwise a (B, 3, 3) identity placeholder,
        * source_mean is the (B, 4, 4) translation by -mean(source) when
          p1_zero_mean is set, otherwise a (B, 3, 3) identity placeholder.
        The 4x4 matrices are the a0 / a1 factors consumed by postprocess_data.
    """
    template_mean = torch.eye(3).view(1, 3, 3).expand(template.size(0), 3, 3).to(template)  # [B, 3, 3]
    source_mean = torch.eye(3).view(1, 3, 3).expand(source.size(0), 3, 3).to(source)  # [B, 3, 3]

    if p0_zero_mean:
        p0_m = template.mean(dim=1)  # [B, N, 3] -> [B, 3]
        # [I | +p0_m] stacked on the homogeneous row [0 0 0 1]
        template_mean = torch.cat([template_mean, p0_m.unsqueeze(-1)], dim=2)
        one_ = torch.tensor([[[0.0, 0.0, 0.0, 1.0]]]).repeat(template_mean.shape[0], 1, 1).to(template_mean)  # (Bx1x4)
        template_mean = torch.cat([template_mean, one_], dim=1)
        template = template - p0_m.unsqueeze(1)

    if p1_zero_mean:
        p1_m = source.mean(dim=1)  # [B, M, 3] -> [B, 3]
        # [I | -p1_m]: the source matrix must use the SOURCE mean. Using the
        # template mean here is wrong and fails outright when only
        # p1_zero_mean is enabled.
        source_mean = torch.cat([source_mean, -p1_m.unsqueeze(-1)], dim=2)
        one_ = torch.tensor([[[0.0, 0.0, 0.0, 1.0]]]).repeat(source_mean.shape[0], 1, 1).to(source_mean)  # (Bx1x4)
        source_mean = torch.cat([source_mean, one_], dim=1)
        source = source - p1_m.unsqueeze(1)

    return template, source, template_mean, source_mean
26
+
27
def postprocess_data(result, p0, p1, a0, a1, p0_zero_mean, p1_zero_mean):
    """Undo the mean shift on the estimated transforms stored in *result*.

    output' = trans(p0_m) * output * trans(-p1_m)
            = [I, p0_m;] * [R, t;] * [I, -p1_m;]
              [0, 1    ]   [0, 1 ]   [0,  1    ]

    Args:
        result: Dict holding 'est_T' (B, 4, 4) and optionally 'est_T_series'
            (M, B, 4, 4). Updated in place.
        p0, p1: Unused; kept for interface compatibility.
        a0: (B, 4, 4) left factor, applied when p0_zero_mean is set.
        a1: (B, 4, 4) right factor, applied when p1_zero_mean is set.
        p0_zero_mean, p1_zero_mean: Which of the two factors to apply.

    Returns:
        The same *result* dict with the corrected transforms.
    """
    est_g = result['est_T']
    if p0_zero_mean:
        est_g = a0.to(est_g).bmm(est_g)
    if p1_zero_mean:
        est_g = est_g.bmm(a1.to(est_g))
    result['est_T'] = est_g

    # Not every model reports the per-iteration series; only correct it when
    # present instead of failing with a KeyError.
    if 'est_T_series' in result:
        est_gs = result['est_T_series']  # [M, B, 4, 4]
        if p0_zero_mean:
            est_gs = a0.unsqueeze(0).contiguous().to(est_gs).matmul(est_gs)
        if p1_zero_mean:
            est_gs = est_gs.matmul(a1.unsqueeze(0).contiguous().to(est_gs))
        result['est_T_series'] = est_gs

    return result
@@ -0,0 +1,134 @@
1
+ """ inverse matrix """
2
+
3
+ import torch
4
+
5
+
6
def batch_inverse(x):
    """ M(n) -> M(n); x -> x^-1

    Args:
        x: (B, n, n) batch of square, invertible matrices.

    Returns:
        (B, n, n) tensor holding the inverse of every matrix in the batch.
    """
    _, h, w = x.size()
    assert h == w
    # torch.inverse handles the batch dimension natively, so no Python-level
    # loop over the individual matrices is needed.
    return torch.inverse(x)
14
+
15
def batch_inverse_dx(y):
    """ backward

    Jacobian of the matrix inverse, expressed through the inverse itself.

    Let y(x) = x^-1. Then
        dy(j,k) = - y(j,m) * dx(m,n) * y(n,k)
    and therefore
        dy(j,k)/dx(m,n) = - y(j,m) * y(n,k)

    Args:
        y: (B, h, h) batch of inverses.

    Returns:
        (B, h, h, h, h) tensor dy with dy[b, j, k, m, n] = -y[b, j, m] * y[b, n, k].
    """
    batch_size, h, w = y.size()
    assert h == w

    # Place the two factors on broadcast-compatible axes (b, j, k, m, n):
    #   y[b, j, m]            -> (B, h, 1, h, 1)
    #   y[b, n, k] = y'[k, n] -> (B, 1, h, 1, h)
    left = y[:, :, None, :, None]
    right = y.transpose(1, 2)[:, None, :, None, :]
    dy = -(left * right)

    # Callers reshape the result with .view(), so hand back contiguous memory
    return dy.contiguous()
39
+
40
+
41
def batch_pinv_dx(x):
    """ returns y = (x'*x)^-1 * x' and dy/dx.

    Derivation (index notation, repeated indices are summed):
        y = (x'*x)^-1 * x' = s^-1 * x' = b * x'
        d{y(j,k)}/d{x(m,n)}
            = d{b(j,i)}/d{x(m,n)} * x(k,i) + b(j,i) * d{x(k,i)}/d{x(m,n)}
        d{b(j,i)}/d{x(m,n)}
            = d{b(j,i)}/d{s(p,q)} * d{s(p,q)}/d{x(m,n)}
            = -b(j,p)*b(q,i) * d{s(p,q)}/d{x(m,n)}
        d{s(p,q)}/d{x(m,n)}
            = d{x(t,p)}/d{x(m,n)} * x(t,q) + x(t,p) * d{x(t,q)}/d{x(m,n)}

    Args:
        x: (B, h, w) batch of matrices.

    Returns:
        y: (B, w, h) left pseudo-inverse of x.
        dy_dx: (B, w, h, h, w) tensor indexed [j, k, m, n] = d{y(j,k)}/d{x(m,n)}.
    """
    batch_size, h, w = x.size()
    x_t = x.transpose(1, 2)
    gram = x_t.bmm(x)                # s = x' * x
    gram_inv = batch_inverse(gram)   # b = s^-1
    pinv = gram_inv.bmm(x_t)         # y = b * x'

    # dx/dx: identity over the flattened (h*w) entries, viewed as [h, w, h, w]
    dx_dx = torch.eye(h*w).to(x).unsqueeze(0).view(1, h, w, h, w)

    # ds/dx = dx(t,_)/dx * x(t,_) + x(t,_) * dx(t,_)/dx
    dx_dx_rows = dx_dx.view(1, h, w*h*w)  # [t, p*m*n]
    half_ds = x_t.matmul(dx_dx_rows).view(batch_size, w, w, h, w)  # [q, p, m, n]
    ds_dx = half_ds.transpose(1, 2) + half_ds  # [p, q, m, n]

    # db/ds
    db_ds = batch_inverse_dx(gram_inv)  # [j, i, p, q]

    # db/dx = db/d{s(p,q)} * d{s(p,q)}/dx
    db_dx = db_ds.view(batch_size, w*w, w*w).bmm(ds_dx.view(batch_size, w*w, h*w))
    db_dx = db_dx.view(batch_size, w, w, h, w)  # [j, i, m, n]

    # dy/dx = db(_,i)/dx * x(_,i) + b(_,i) * dx(_,i)/dx
    first_term = db_dx.transpose(1, 2).contiguous().view(batch_size, w, w*h*w)
    first_term = x.matmul(first_term).view(batch_size, h, w, h, w)  # [k, j, m, n]
    dx_dx_t = dx_dx.transpose(1, 2).contiguous().view(1, w, h*h*w)
    second_term = gram_inv.matmul(dx_dx_t).view(batch_size, w, h, h, w)  # [j, k, m, n]
    dy_dx = first_term.transpose(1, 2) + second_term

    return pinv, dy_dx
80
+
81
+
82
class InvMatrix(torch.autograd.Function):
    """ M(n) -> M(n); x -> x^-1.

    Batched matrix inverse with a hand-written backward pass.
    """

    @staticmethod
    def forward(ctx, x):
        """Invert every matrix of the (B, n, n) batch; the result is kept for backward."""
        inverse = batch_inverse(x)
        ctx.save_for_backward(inverse)
        return inverse

    @staticmethod
    def backward(ctx, grad_output):
        """Chain rule through the inverse.

        Let y(x) = x^-1 and assume any function f(y(x)). Then
            df/dx(m,n) = df/dy(j,k) * dy(j,k)/dx(m,n)
        where df/dy is ``grad_output``.
        """
        (inverse,) = ctx.saved_tensors
        batch_size, h, w = inverse.size()
        assert h == w

        # dy(j,k,m,n) = dy(j,k)/dx(m,n), flattened to [(j*k), (m*n)]
        jacobian = batch_inverse_dx(inverse).view(batch_size, h*h, h*h)
        # df/dy as a row vector per batch element: [1, (j*k)]
        upstream = grad_output.contiguous().view(batch_size, 1, h*h)

        grad_input = upstream.bmm(jacobian)  # [1, (m*n)]
        return grad_input.view(batch_size, h, h)  # [m, n]
111
+
112
+
113
+
114
if __name__ == '__main__':
    def test():
        """Compare autograd through InvMatrix with the analytic pseudo-inverse Jacobian."""
        x = torch.randn(2, 3, 2)
        x_val = x.requires_grad_()
        x_val_t = x_val.transpose(1, 2)

        # Gradient of sum(pinv(x)) obtained by backpropagating through InvMatrix
        gram_inv = InvMatrix.apply(x_val_t.bmm(x_val))
        gram_inv.bmm(x_val_t).sum().backward()
        t1 = x_val.grad

        # Same quantity from the closed-form Jacobian, summed over the outputs
        y, dy_dx = batch_pinv_dx(x)
        t2 = dy_dx.sum(1).sum(1)

        for value in (t1, t2, t1 - t2):
            print(value)

    test()
133
+
134
+ #EOF