learning3d-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115)
  1. learning3d/__init__.py +2 -0
  2. learning3d/data_utils/__init__.py +4 -0
  3. learning3d/data_utils/dataloaders.py +454 -0
  4. learning3d/data_utils/user_data.py +119 -0
  5. learning3d/examples/test_dcp.py +139 -0
  6. learning3d/examples/test_deepgmr.py +144 -0
  7. learning3d/examples/test_flownet.py +113 -0
  8. learning3d/examples/test_masknet.py +159 -0
  9. learning3d/examples/test_masknet2.py +162 -0
  10. learning3d/examples/test_pcn.py +118 -0
  11. learning3d/examples/test_pcrnet.py +120 -0
  12. learning3d/examples/test_pnlk.py +121 -0
  13. learning3d/examples/test_pointconv.py +126 -0
  14. learning3d/examples/test_pointnet.py +121 -0
  15. learning3d/examples/test_prnet.py +126 -0
  16. learning3d/examples/test_rpmnet.py +120 -0
  17. learning3d/examples/train_PointNetLK.py +240 -0
  18. learning3d/examples/train_dcp.py +249 -0
  19. learning3d/examples/train_deepgmr.py +244 -0
  20. learning3d/examples/train_flownet.py +259 -0
  21. learning3d/examples/train_masknet.py +239 -0
  22. learning3d/examples/train_pcn.py +216 -0
  23. learning3d/examples/train_pcrnet.py +228 -0
  24. learning3d/examples/train_pointconv.py +245 -0
  25. learning3d/examples/train_pointnet.py +244 -0
  26. learning3d/examples/train_prnet.py +229 -0
  27. learning3d/examples/train_rpmnet.py +228 -0
  28. learning3d/losses/__init__.py +12 -0
  29. learning3d/losses/chamfer_distance.py +51 -0
  30. learning3d/losses/classification.py +14 -0
  31. learning3d/losses/correspondence_loss.py +10 -0
  32. learning3d/losses/cuda/chamfer_distance/__init__.py +1 -0
  33. learning3d/losses/cuda/chamfer_distance/chamfer_distance.cpp +185 -0
  34. learning3d/losses/cuda/chamfer_distance/chamfer_distance.cu +209 -0
  35. learning3d/losses/cuda/chamfer_distance/chamfer_distance.py +66 -0
  36. learning3d/losses/cuda/emd_torch/pkg/emd_loss_layer.py +41 -0
  37. learning3d/losses/cuda/emd_torch/pkg/include/cuda/emd.cuh +347 -0
  38. learning3d/losses/cuda/emd_torch/pkg/include/cuda_helper.h +18 -0
  39. learning3d/losses/cuda/emd_torch/pkg/include/emd.h +54 -0
  40. learning3d/losses/cuda/emd_torch/pkg/layer/__init__.py +1 -0
  41. learning3d/losses/cuda/emd_torch/pkg/layer/emd_loss_layer.py +40 -0
  42. learning3d/losses/cuda/emd_torch/pkg/src/cuda/emd.cu +70 -0
  43. learning3d/losses/cuda/emd_torch/pkg/src/emd.cpp +1 -0
  44. learning3d/losses/cuda/emd_torch/setup.py +29 -0
  45. learning3d/losses/emd.py +16 -0
  46. learning3d/losses/frobenius_norm.py +21 -0
  47. learning3d/losses/rmse_features.py +16 -0
  48. learning3d/models/__init__.py +23 -0
  49. learning3d/models/classifier.py +41 -0
  50. learning3d/models/dcp.py +92 -0
  51. learning3d/models/deepgmr.py +165 -0
  52. learning3d/models/dgcnn.py +92 -0
  53. learning3d/models/flownet3d.py +446 -0
  54. learning3d/models/masknet.py +84 -0
  55. learning3d/models/masknet2.py +264 -0
  56. learning3d/models/pcn.py +164 -0
  57. learning3d/models/pcrnet.py +74 -0
  58. learning3d/models/pointconv.py +108 -0
  59. learning3d/models/pointnet.py +108 -0
  60. learning3d/models/pointnetlk.py +173 -0
  61. learning3d/models/pooling.py +15 -0
  62. learning3d/models/ppfnet.py +102 -0
  63. learning3d/models/prnet.py +431 -0
  64. learning3d/models/rpmnet.py +359 -0
  65. learning3d/models/segmentation.py +38 -0
  66. learning3d/ops/__init__.py +0 -0
  67. learning3d/ops/data_utils.py +45 -0
  68. learning3d/ops/invmat.py +134 -0
  69. learning3d/ops/quaternion.py +218 -0
  70. learning3d/ops/se3.py +157 -0
  71. learning3d/ops/sinc.py +229 -0
  72. learning3d/ops/so3.py +213 -0
  73. learning3d/ops/transform_functions.py +342 -0
  74. learning3d/utils/__init__.py +9 -0
  75. learning3d/utils/lib/build/lib.linux-x86_64-3.5/pointnet2_cuda.cpython-35m-x86_64-linux-gnu.so +0 -0
  76. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/ball_query.o +0 -0
  77. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/ball_query_gpu.o +0 -0
  78. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/group_points.o +0 -0
  79. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/group_points_gpu.o +0 -0
  80. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/interpolate.o +0 -0
  81. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/interpolate_gpu.o +0 -0
  82. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/pointnet2_api.o +0 -0
  83. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/sampling.o +0 -0
  84. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/sampling_gpu.o +0 -0
  85. learning3d/utils/lib/dist/pointnet2-0.0.0-py3.5-linux-x86_64.egg +0 -0
  86. learning3d/utils/lib/pointnet2.egg-info/SOURCES.txt +14 -0
  87. learning3d/utils/lib/pointnet2.egg-info/dependency_links.txt +1 -0
  88. learning3d/utils/lib/pointnet2.egg-info/top_level.txt +1 -0
  89. learning3d/utils/lib/pointnet2_modules.py +160 -0
  90. learning3d/utils/lib/pointnet2_utils.py +318 -0
  91. learning3d/utils/lib/pytorch_utils.py +236 -0
  92. learning3d/utils/lib/setup.py +23 -0
  93. learning3d/utils/lib/src/ball_query.cpp +25 -0
  94. learning3d/utils/lib/src/ball_query_gpu.cu +67 -0
  95. learning3d/utils/lib/src/ball_query_gpu.h +15 -0
  96. learning3d/utils/lib/src/cuda_utils.h +15 -0
  97. learning3d/utils/lib/src/group_points.cpp +36 -0
  98. learning3d/utils/lib/src/group_points_gpu.cu +86 -0
  99. learning3d/utils/lib/src/group_points_gpu.h +22 -0
  100. learning3d/utils/lib/src/interpolate.cpp +65 -0
  101. learning3d/utils/lib/src/interpolate_gpu.cu +233 -0
  102. learning3d/utils/lib/src/interpolate_gpu.h +36 -0
  103. learning3d/utils/lib/src/pointnet2_api.cpp +25 -0
  104. learning3d/utils/lib/src/sampling.cpp +46 -0
  105. learning3d/utils/lib/src/sampling_gpu.cu +253 -0
  106. learning3d/utils/lib/src/sampling_gpu.h +29 -0
  107. learning3d/utils/pointconv_util.py +382 -0
  108. learning3d/utils/ppfnet_util.py +244 -0
  109. learning3d/utils/svd.py +59 -0
  110. learning3d/utils/transformer.py +243 -0
  111. learning3d-0.0.1.dist-info/LICENSE +21 -0
  112. learning3d-0.0.1.dist-info/METADATA +271 -0
  113. learning3d-0.0.1.dist-info/RECORD +115 -0
  114. learning3d-0.0.1.dist-info/WHEEL +5 -0
  115. learning3d-0.0.1.dist-info/top_level.txt +1 -0
learning3d/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .models import PointNet, create_pointconv, DGCNN, PPFNet, Pooling, Classifier, Segmentation
2
+ from .models import DCP, PRNet, iPCRNet, PointNetLK, RPMNet, PCN, DeepGMR, MaskNet, MaskNet2
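These two re-exports make every bundled network importable directly from the learning3d top level. A minimal, hedged sketch of how that surface might be used for classification; only the import names are confirmed by this diff, while the constructor and forward-call signatures below are assumptions:

import torch
from learning3d import PointNet, Classifier

# Hypothetical usage sketch: argument and call signatures are assumptions,
# only the exported names come from the __init__.py above.
points = torch.rand(8, 1024, 3)                       # batch of 8 clouds, 1024 points each
feature_model = PointNet()                            # assumed default constructor
classifier = Classifier(feature_model=feature_model)  # 'feature_model' kwarg is an assumption
logits = classifier(points)                           # assumed to return per-class scores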
learning3d/data_utils/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .dataloaders import ModelNet40Data
2
+ from .dataloaders import ClassificationData, RegistrationData, SegmentationData, FlowData, SceneflowDataset
3
+ from .dataloaders import download_modelnet40, deg_to_rad, create_random_transform
4
+ from .user_data import UserData
learning3d/data_utils/dataloaders.py ADDED
@@ -0,0 +1,454 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.utils.data import Dataset
5
+ from torch.utils.data import DataLoader
6
+ import numpy as np
7
+ import os
8
+ import h5py
9
+ import subprocess
10
+ import shlex
11
+ import json
12
+ import glob
13
+ from .. ops import transform_functions, se3
14
+ from sklearn.neighbors import NearestNeighbors
15
+ from scipy.spatial.distance import minkowski
16
+ from scipy.spatial import cKDTree
17
+ from torch.utils.data import Dataset
18
+
19
+ def download_modelnet40():
20
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
21
+ DATA_DIR = os.path.join(BASE_DIR, os.pardir, 'data')
22
+ if not os.path.exists(DATA_DIR):
23
+ os.mkdir(DATA_DIR)
24
+ if not os.path.exists(os.path.join(DATA_DIR, 'modelnet40_ply_hdf5_2048')):
25
+ www = 'https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip'
26
+ zipfile = os.path.basename(www)
27
+ os.system('wget --no-check-certificate %s; unzip %s' % (www, zipfile))
28
+ os.system('mv %s %s' % (zipfile[:-4], DATA_DIR))
29
+ os.system('rm %s' % (zipfile))
30
+
31
+ def load_data(train, use_normals):
32
+ if train: partition = 'train'
33
+ else: partition = 'test'
34
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
35
+ DATA_DIR = os.path.join(BASE_DIR, os.pardir, 'data')
36
+ all_data = []
37
+ all_label = []
38
+ for h5_name in glob.glob(os.path.join(DATA_DIR, 'modelnet40_ply_hdf5_2048', 'ply_data_%s*.h5' % partition)):
39
+ f = h5py.File(h5_name, 'r')
40
+ if use_normals: data = np.concatenate([f['data'][:], f['normal'][:]], axis=-1).astype('float32')
41
+ else: data = f['data'][:].astype('float32')
42
+ label = f['label'][:].astype('int64')
43
+ f.close()
44
+ all_data.append(data)
45
+ all_label.append(label)
46
+ all_data = np.concatenate(all_data, axis=0)
47
+ all_label = np.concatenate(all_label, axis=0)
48
+ return all_data, all_label
49
+
50
+ def deg_to_rad(deg):
51
+ return np.pi / 180 * deg
52
+
53
+ def create_random_transform(dtype, max_rotation_deg, max_translation):
54
+ max_rotation = deg_to_rad(max_rotation_deg)
55
+ rot = np.random.uniform(-max_rotation, max_rotation, [1, 3])
56
+ trans = np.random.uniform(-max_translation, max_translation, [1, 3])
57
+ quat = transform_functions.euler_to_quaternion(rot, "xyz")
58
+
59
+ vec = np.concatenate([quat, trans], axis=1)
60
+ vec = torch.tensor(vec, dtype=dtype)
61
+ return vec
62
+
63
+ def jitter_pointcloud(pointcloud, sigma=0.04, clip=0.05):
64
+ # N, C = pointcloud.shape
65
+ sigma = 0.04*np.random.random_sample()
66
+ pointcloud += torch.empty(pointcloud.shape).normal_(mean=0, std=sigma).clamp(-clip, clip)
67
+ return pointcloud
68
+
69
+ def farthest_subsample_points(pointcloud1, num_subsampled_points=768):
70
+ pointcloud1 = pointcloud1
71
+ num_points = pointcloud1.shape[0]
72
+ nbrs1 = NearestNeighbors(n_neighbors=num_subsampled_points, algorithm='auto',
73
+ metric=lambda x, y: minkowski(x, y)).fit(pointcloud1[:, :3])
74
+ random_p1 = np.random.random(size=(1, 3)) + np.array([[500, 500, 500]]) * np.random.choice([1, -1, 1, -1])
75
+ idx1 = nbrs1.kneighbors(random_p1, return_distance=False).reshape((num_subsampled_points,))
76
+ gt_mask = torch.zeros(num_points).scatter_(0, torch.tensor(idx1), 1)
77
+ return pointcloud1[idx1, :], gt_mask
78
+
79
+ def uniform_2_sphere(num: int = None):
80
+ """Uniform sampling on a 2-sphere
81
+
82
+ Source: https://gist.github.com/andrewbolster/10274979
83
+
84
+ Args:
85
+ num: Number of vectors to sample (or None if single)
86
+
87
+ Returns:
88
+ Random Vector (np.ndarray) of size (num, 3) with norm 1.
89
+ If num is None returned value will have size (3,)
90
+
91
+ """
92
+ if num is not None:
93
+ phi = np.random.uniform(0.0, 2 * np.pi, num)
94
+ cos_theta = np.random.uniform(-1.0, 1.0, num)
95
+ else:
96
+ phi = np.random.uniform(0.0, 2 * np.pi)
97
+ cos_theta = np.random.uniform(-1.0, 1.0)
98
+
99
+ theta = np.arccos(cos_theta)
100
+ x = np.sin(theta) * np.cos(phi)
101
+ y = np.sin(theta) * np.sin(phi)
102
+ z = np.cos(theta)
103
+
104
+ return np.stack((x, y, z), axis=-1)
105
+
106
+ def planar_crop(points, p_keep= 0.7):
107
+ p_keep = np.array(p_keep, dtype=np.float32)
108
+
109
+ rand_xyz = uniform_2_sphere()
110
+ pts = points.numpy()
111
+ centroid = np.mean(pts[:, :3], axis=0)
112
+ points_centered = pts[:, :3] - centroid
113
+
114
+ dist_from_plane = np.dot(points_centered, rand_xyz)
115
+
116
+ mask = dist_from_plane > np.percentile(dist_from_plane, (1.0 - p_keep) * 100)
117
+ idx_x = torch.Tensor(np.nonzero(mask))
118
+
119
+ return torch.Tensor(pts[mask, :3]), idx_x
120
+
121
+ def knn_idx(pts, k):
122
+ kdt = cKDTree(pts)
123
+ _, idx = kdt.query(pts, k=k+1)
124
+ return idx[:, 1:]
125
+
126
+ def get_rri(pts, k):
127
+ # pts: N x 3, original points
128
+ # q: N x K x 3, nearest neighbors
129
+ q = pts[knn_idx(pts, k)]
130
+ p = np.repeat(pts[:, None], k, axis=1)
131
+ # rp, rq: N x K x 1, norms
132
+ rp = np.linalg.norm(p, axis=-1, keepdims=True)
133
+ rq = np.linalg.norm(q, axis=-1, keepdims=True)
134
+ pn = p / rp
135
+ qn = q / rq
136
+ dot = np.sum(pn * qn, -1, keepdims=True)
137
+ # theta: N x K x 1, angles
138
+ theta = np.arccos(np.clip(dot, -1, 1))
139
+ T_q = q - dot * p
140
+ sin_psi = np.sum(np.cross(T_q[:, None], T_q[:, :, None]) * pn[:, None], -1)
141
+ cos_psi = np.sum(T_q[:, None] * T_q[:, :, None], -1)
142
+ psi = np.arctan2(sin_psi, cos_psi) % (2*np.pi)
143
+ idx = np.argpartition(psi, 1)[:, :, 1:2]
144
+ # phi: N x K x 1, projection angles
145
+ phi = np.take_along_axis(psi, idx, axis=-1)
146
+ feat = np.concatenate([rp, rq, theta, phi], axis=-1)
147
+ return feat.reshape(-1, k * 4)
148
+
149
+ def get_rri_cuda(pts, k, npts_per_block=1):
150
+ try:
151
+ import pycuda.autoinit
152
+ from pycuda import gpuarray
153
+ from pycuda.compiler import SourceModule
154
+ except Exception as e:
155
+ print("Error raised in pycuda modules! pycuda only works with GPU, ", e)
156
+ raise
157
+
158
+ mod_rri = SourceModule(open('rri.cu').read() % (k, npts_per_block))
159
+ rri_cuda = mod_rri.get_function('get_rri_feature')
160
+
161
+ N = len(pts)
162
+ pts_gpu = gpuarray.to_gpu(pts.astype(np.float32).ravel())
163
+ k_idx = knn_idx(pts, k)
164
+ k_idx_gpu = gpuarray.to_gpu(k_idx.astype(np.int32).ravel())
165
+ feat_gpu = gpuarray.GPUArray((N * k * 4,), np.float32)
166
+
167
+ rri_cuda(pts_gpu, np.int32(N), k_idx_gpu, feat_gpu,
168
+ grid=(((N-1) // npts_per_block)+1, 1),
169
+ block=(npts_per_block, k, 1))
170
+
171
+ feat = feat_gpu.get().reshape(N, k * 4).astype(np.float32)
172
+ return feat
173
+
174
+
175
+ class UnknownDataTypeError(Exception):
176
+ def __init__(self, *args):
177
+ if args: self.message = args[0]
178
+ else: self.message = 'Datatype not understood for dataset.'
179
+
180
+ def __str__(self):
181
+ return self.message
182
+
183
+
184
+ class ModelNet40Data(Dataset):
185
+ def __init__(
186
+ self,
187
+ train=True,
188
+ num_points=1024,
189
+ download=True,
190
+ randomize_data=False,
191
+ use_normals=False
192
+ ):
193
+ super(ModelNet40Data, self).__init__()
194
+ if download: download_modelnet40()
195
+ self.data, self.labels = load_data(train, use_normals)
196
+ if not train: self.shapes = self.read_classes_ModelNet40()
197
+ self.num_points = num_points
198
+ self.randomize_data = randomize_data
199
+
200
+ def __getitem__(self, idx):
201
+ if self.randomize_data: current_points = self.randomize(idx)
202
+ else: current_points = self.data[idx].copy()
203
+
204
+ current_points = torch.from_numpy(current_points[:self.num_points, :]).float()
205
+ label = torch.from_numpy(self.labels[idx]).type(torch.LongTensor)
206
+
207
+ return current_points, label
208
+
209
+ def __len__(self):
210
+ return self.data.shape[0]
211
+
212
+ def randomize(self, idx):
213
+ pt_idxs = np.arange(0, self.num_points)
214
+ np.random.shuffle(pt_idxs)
215
+ return self.data[idx, pt_idxs].copy()
216
+
217
+ def get_shape(self, label):
218
+ return self.shapes[label]
219
+
220
+ def read_classes_ModelNet40(self):
221
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
222
+ DATA_DIR = os.path.join(BASE_DIR, os.pardir, 'data')
223
+ file = open(os.path.join(DATA_DIR, 'modelnet40_ply_hdf5_2048', 'shape_names.txt'), 'r')
224
+ shape_names = file.read()
225
+ shape_names = np.array(shape_names.split('\n')[:-1])
226
+ return shape_names
227
+
228
+
229
+ class ClassificationData(Dataset):
230
+ def __init__(self, data_class=ModelNet40Data()):
231
+ super(ClassificationData, self).__init__()
232
+ self.set_class(data_class)
233
+
234
+ def __len__(self):
235
+ return len(self.data_class)
236
+
237
+ def set_class(self, data_class):
238
+ self.data_class = data_class
239
+
240
+ def get_shape(self, label):
241
+ try:
242
+ return self.data_class.get_shape(label)
243
+ except:
244
+ return -1
245
+
246
+ def __getitem__(self, index):
247
+ return self.data_class[index]
248
+
249
+
250
+ class RegistrationData(Dataset):
251
+ def __init__(self, algorithm, data_class=ModelNet40Data(), partial_source=False, partial_template=False, noise=False, additional_params={}):
252
+ super(RegistrationData, self).__init__()
253
+ available_algorithms = ['PCRNet', 'PointNetLK', 'DCP', 'PRNet', 'iPCRNet', 'RPMNet', 'DeepGMR']
254
+ if algorithm in available_algorithms: self.algorithm = algorithm
255
+ else: raise Exception("Algorithm not available for registration.")
256
+
257
+ self.set_class(data_class)
258
+ self.partial_template = partial_template
259
+ self.partial_source = partial_source
260
+ self.noise = noise
261
+ self.additional_params = additional_params
262
+ self.use_rri = False
263
+
264
+ if self.algorithm == 'PCRNet' or self.algorithm == 'iPCRNet':
265
+ from .. ops.transform_functions import PCRNetTransform
266
+ self.transforms = PCRNetTransform(len(data_class), angle_range=45, translation_range=1)
267
+ if self.algorithm == 'PointNetLK':
268
+ from .. ops.transform_functions import PNLKTransform
269
+ self.transforms = PNLKTransform(0.8, True)
270
+ if self.algorithm == 'RPMNet':
271
+ from .. ops.transform_functions import RPMNetTransform
272
+ self.transforms = RPMNetTransform(0.8, True)
273
+ if self.algorithm == 'DCP' or self.algorithm == 'PRNet':
274
+ from .. ops.transform_functions import DCPTransform
275
+ self.transforms = DCPTransform(angle_range=45, translation_range=1)
276
+ if self.algorithm == 'DeepGMR':
277
+ self.get_rri = get_rri_cuda if torch.cuda.is_available() else get_rri
278
+ from .. ops.transform_functions import DeepGMRTransform
279
+ self.transforms = DeepGMRTransform(angle_range=90, translation_range=1)
280
+ if 'nearest_neighbors' in self.additional_params.keys() and self.additional_params['nearest_neighbors'] > 0:
281
+ self.use_rri = True
282
+ self.nearest_neighbors = self.additional_params['nearest_neighbors']
283
+
284
+ def __len__(self):
285
+ return len(self.data_class)
286
+
287
+ def set_class(self, data_class):
288
+ self.data_class = data_class
289
+
290
+ def __getitem__(self, index):
291
+ template, label = self.data_class[index]
292
+ self.transforms.index = index # for fixed transformations in PCRNet.
293
+ source = self.transforms(template)
294
+
295
+ # Check for Partial Data.
296
+ if self.additional_params.get('partial_point_cloud_method', None) == 'planar_crop':
297
+ source, gt_idx_source = planar_crop(source)
298
+ template, gt_idx_template = planar_crop(template)
299
+ intersect_mask, intersect_x, intersect_y = np.intersect1d(gt_idx_source, gt_idx_template, return_indices=True)
300
+
301
+ self.template_mask = torch.zeros(template.shape[0])
302
+ self.source_mask = torch.zeros(source.shape[0])
303
+ self.template_mask[intersect_y] = 1
304
+ self.source_mask[intersect_x] = 1
305
+ else:
306
+ if self.partial_source: source, self.source_mask = farthest_subsample_points(source)
307
+ if self.partial_template: template, self.template_mask = farthest_subsample_points(template)
308
+
309
+
310
+
311
+ # Check for Noise in Source Data.
312
+ if self.noise: source = jitter_pointcloud(source)
313
+
314
+ if self.use_rri:
315
+ template, source = template.numpy(), source.numpy()
316
+ template = np.concatenate([template, self.get_rri(template - template.mean(axis=0), self.nearest_neighbors)], axis=1)
317
+ source = np.concatenate([source, self.get_rri(source - source.mean(axis=0), self.nearest_neighbors)], axis=1)
318
+ template, source = torch.tensor(template).float(), torch.tensor(source).float()
319
+
320
+ igt = self.transforms.igt
321
+
322
+ if self.additional_params.get('use_masknet', False):
323
+ if self.partial_source and self.partial_template:
324
+ return template, source, igt, self.template_mask, self.source_mask
325
+ elif self.partial_source:
326
+ return template, source, igt, self.source_mask
327
+ elif self.partial_template:
328
+ return template, source, igt, self.template_mask
329
+ else:
330
+ return template, source, igt
331
+
332
+
333
+ class SegmentationData(Dataset):
334
+ def __init__(self):
335
+ super(SegmentationData, self).__init__()
336
+
337
+ def __len__(self):
338
+ pass
339
+
340
+ def __getitem__(self, index):
341
+ pass
342
+
343
+
344
+ class FlowData(Dataset):
345
+ def __init__(self):
346
+ super(FlowData, self).__init__()
347
+ self.pc1, self.pc2, self.flow = self.read_data()
348
+
349
+ def __len__(self):
350
+ if isinstance(self.pc1, np.ndarray):
351
+ return self.pc1.shape[0]
352
+ elif isinstance(self.pc1, list):
353
+ return len(self.pc1)
354
+ else:
355
+ raise UnknownDataTypeError
356
+
357
+ def read_data(self):
358
+ pass
359
+
360
+ def __getitem__(self, index):
361
+ return self.pc1[index], self.pc2[index], self.flow[index]
362
+
363
+
364
+ class SceneflowDataset(Dataset):
365
+ def __init__(self, npoints=1024, root='', partition='train'):
366
+ if root == '':
367
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
368
+ DATA_DIR = os.path.join(BASE_DIR, os.pardir, 'data')
369
+ root = os.path.join(DATA_DIR, 'data_processed_maxcut_35_20k_2k_8192')
370
+ if not os.path.exists(root):
371
+ print("To download dataset, click here: https://drive.google.com/file/d/1CMaxdt-Tg1Wct8v8eGNwuT7qRSIyJPY-/view")
372
+ exit()
373
+ else:
374
+ print("SceneflowDataset Found Successfully!")
375
+
376
+ self.npoints = npoints
377
+ self.partition = partition
378
+ self.root = root
379
+ if self.partition=='train':
380
+ self.datapath = glob.glob(os.path.join(self.root, 'TRAIN*.npz'))
381
+ else:
382
+ self.datapath = glob.glob(os.path.join(self.root, 'TEST*.npz'))
383
+ self.cache = {}
384
+ self.cache_size = 30000
385
+
386
+ ###### deal with one bad datapoint with nan value
387
+ self.datapath = [d for d in self.datapath if 'TRAIN_C_0140_left_0006-0' not in d]
388
+ ######
389
+ print(self.partition, ': ',len(self.datapath))
390
+
391
+ def __getitem__(self, index):
392
+ if index in self.cache:
393
+ pos1, pos2, color1, color2, flow, mask1 = self.cache[index]
394
+ else:
395
+ fn = self.datapath[index]
396
+ with open(fn, 'rb') as fp:
397
+ data = np.load(fp)
398
+ pos1 = data['points1'].astype('float32')
399
+ pos2 = data['points2'].astype('float32')
400
+ color1 = data['color1'].astype('float32')
401
+ color2 = data['color2'].astype('float32')
402
+ flow = data['flow'].astype('float32')
403
+ mask1 = data['valid_mask1']
404
+
405
+ if len(self.cache) < self.cache_size:
406
+ self.cache[index] = (pos1, pos2, color1, color2, flow, mask1)
407
+
408
+ if self.partition == 'train':
409
+ n1 = pos1.shape[0]
410
+ sample_idx1 = np.random.choice(n1, self.npoints, replace=False)
411
+ n2 = pos2.shape[0]
412
+ sample_idx2 = np.random.choice(n2, self.npoints, replace=False)
413
+
414
+ pos1 = pos1[sample_idx1, :]
415
+ pos2 = pos2[sample_idx2, :]
416
+ color1 = color1[sample_idx1, :]
417
+ color2 = color2[sample_idx2, :]
418
+ flow = flow[sample_idx1, :]
419
+ mask1 = mask1[sample_idx1]
420
+ else:
421
+ pos1 = pos1[:self.npoints, :]
422
+ pos2 = pos2[:self.npoints, :]
423
+ color1 = color1[:self.npoints, :]
424
+ color2 = color2[:self.npoints, :]
425
+ flow = flow[:self.npoints, :]
426
+ mask1 = mask1[:self.npoints]
427
+
428
+ pos1_center = np.mean(pos1, 0)
429
+ pos1 -= pos1_center
430
+ pos2 -= pos1_center
431
+
432
+ return pos1, pos2, color1, color2, flow, mask1
433
+
434
+ def __len__(self):
435
+ return len(self.datapath)
436
+
437
+
438
+ if __name__ == '__main__':
439
+ class Data():
440
+ def __init__(self):
441
+ super(Data, self).__init__()
442
+ self.data, self.label = self.read_data()
443
+
444
+ def read_data(self):
445
+ return [4,5,6], [4,5,6]
446
+
447
+ def __len__(self):
448
+ return len(self.data)
449
+
450
+ def __getitem__(self, idx):
451
+ return self.data[idx], self.label[idx]
452
+
453
+ cd = RegistrationData('abc')
454
+ import ipdb; ipdb.set_trace()
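Taken together, the classes above form a standard PyTorch data pipeline: ModelNet40Data reads the HDF5 shards, RegistrationData draws a random rigid transform per sample, and a DataLoader batches the resulting (template, source, igt) triplets. A minimal sketch, assuming the hard-coded ModelNet40 download URL is still reachable and using only constructor arguments that appear in this file:

from torch.utils.data import DataLoader
from learning3d.data_utils import ModelNet40Data, RegistrationData

# Constructor arguments below appear verbatim in dataloaders.py above.
modelnet = ModelNet40Data(train=True, num_points=1024, download=True)
registration = RegistrationData('PCRNet', data_class=modelnet, noise=True)

loader = DataLoader(registration, batch_size=8, shuffle=True)
for template, source, igt in loader:
    # template, source: [8, 1024, 3] point clouds; igt: ground-truth transform per pair
    break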
learning3d/data_utils/user_data.py ADDED
@@ -0,0 +1,119 @@
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+
5
+ class ClassificationData:
6
+ def __init__(self, data_dict):
7
+ self.data_dict = data_dict
8
+ self.pcs = self.find_attribute('pcs')
9
+ self.labels = self.find_attribute('labels')
10
+ self.check_data()
11
+
12
+ def find_attribute(self, attribute):
13
+ try:
14
+ attribute_data = self.data_dict[attribute]
15
+ except:
16
+ print("Given data directory has no key attribute \"{}\"".format(attribute))
17
+ return attribute_data
18
+
19
+ def check_data(self):
20
+ assert 1 < len(self.pcs.shape) < 4, "Error in dimension of point clouds! Given data dimension: {}".format(self.pcs.shape)
21
+ assert 0 < len(self.labels.shape) < 3, "Error in dimension of labels! Given data dimension: {}".format(self.labels.shape)
22
+
23
+ if len(self.pcs.shape)==2: self.pcs = self.pcs.reshape(1, -1, 3)
24
+ if len(self.labels.shape) == 1: self.labels = self.labels.reshape(1, -1)
25
+
26
+ assert self.pcs.shape[0] == self.labels.shape[0], "Inconsistency in the number of point clouds and number of ground truth labels!"
27
+
28
+
29
+ def __len__(self):
30
+ return self.pcs.shape[0]
31
+
32
+ def __getitem__(self, index):
33
+ return torch.tensor(self.pcs[index]).float(), torch.from_numpy(self.labels[index]).type(torch.LongTensor)
34
+
35
+
36
+ class RegistrationData:
37
+ def __init__(self, data_dict):
38
+ self.data_dict = data_dict
39
+ self.template = self.find_attribute('template')
40
+ self.source = self.find_attribute('source')
41
+ self.transformation = self.find_attribute('transformation')
42
+ self.check_data()
43
+
44
+ def find_attribute(self, attribute):
45
+ try:
46
+ attribute_data = self.data_dict[attribute]
47
+ except:
48
+ print("Given data directory has no key attribute \"{}\"".format(attribute))
49
+ return attribute_data
50
+
51
+ def check_data(self):
52
+ assert 1 < len(self.template.shape) < 4, "Error in dimension of point clouds! Given data dimension: {}".format(self.template.shape)
53
+ assert 1 < len(self.source.shape) < 4, "Error in dimension of point clouds! Given data dimension: {}".format(self.source.shape)
54
+ assert 1 < len(self.transformation.shape) < 4, "Error in dimension of transformations! Given data dimension: {}".format(self.transformation.shape)
55
+
56
+ if len(self.template.shape)==2: self.template = self.template.reshape(1, -1, 3)
57
+ if len(self.source.shape)==2: self.source = self.source.reshape(1, -1, 3)
58
+ if len(self.transformation.shape) == 2: self.transformation = self.transformation.reshape(1, 4, 4)
59
+
60
+ assert self.template.shape[0] == self.source.shape[0], "Inconsistency in the number of template and source point clouds!"
61
+ assert self.source.shape[0] == self.transformation.shape[0], "Inconsistency in the number of transformation and source point clouds!"
62
+
63
+ def __len__(self):
64
+ return self.template.shape[0]
65
+
66
+ def __getitem__(self, index):
67
+ return torch.tensor(self.template[index]).float(), torch.tensor(self.source[index]).float(), torch.tensor(self.transformation[index]).float()
68
+
69
+
70
+ class FlowData:
71
+ def __init__(self, data_dict):
72
+ self.data_dict = data_dict
73
+ self.frame1 = self.find_attribute('frame1')
74
+ self.frame2 = self.find_attribute('frame2')
75
+ self.flow = self.find_attribute('flow')
76
+ self.check_data()
77
+
78
+ def find_attribute(self, attribute):
79
+ try:
80
+ attribute_data = self.data_dict[attribute]
81
+ except:
82
+ print("Given data directory has no key attribute \"{}\"".format(attribute))
83
+ return attribute_data
84
+
85
+ def check_data(self):
86
+ assert 1 < len(self.frame1.shape) < 4, "Error in dimension of point clouds! Given data dimension: {}".format(self.frame1.shape)
87
+ assert 1 < len(self.frame2.shape) < 4, "Error in dimension of point clouds! Given data dimension: {}".format(self.frame2.shape)
88
+ assert 1 < len(self.flow.shape) < 4, "Error in dimension of flow! Given data dimension: {}".format(self.flow.shape)
89
+
90
+ if len(self.frame1.shape)==2: self.frame1 = self.frame1.reshape(1, -1, 3)
91
+ if len(self.frame2.shape)==2: self.frame2 = self.frame2.reshape(1, -1, 3)
92
+ if len(self.flow.shape) == 2: self.flow = self.flow.reshape(1, -1, 3)
93
+
94
+ assert self.frame1.shape[0] == self.frame2.shape[0], "Inconsistency in the number of frame1 and frame2 point clouds!"
95
+ assert self.frame2.shape[0] == self.flow.shape[0], "Inconsistency in the number of flow and frame2 point clouds!"
96
+
97
+ def __len__(self):
98
+ return self.frame1.shape[0]
99
+
100
+ def __getitem__(self, index):
101
+ return torch.tensor(self.frame1[index]).float(), torch.tensor(self.frame2[index]).float(), torch.tensor(self.flow[index]).float()
102
+
103
+
104
+ class UserData:
105
+ def __init__(self, application, data_dict):
106
+ self.application = application
107
+
108
+ if self.application == 'classification':
109
+ self.data_class = ClassificationData(data_dict)
110
+ elif self.application == 'registration':
111
+ self.data_class = RegistrationData(data_dict)
112
+ elif self.application == 'flow_estimation':
113
+ self.data_class = FlowData(data_dict)
114
+
115
+ def __len__(self):
116
+ return len(self.data_class)
117
+
118
+ def __getitem__(self, index):
119
+ return self.data_class[index]
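UserData is a thin dispatcher: it selects one of the three wrappers above from the application string and expects the raw arrays under fixed dictionary keys. A minimal sketch for the registration case, using the keys and shapes that check_data() enforces above (the random arrays are placeholders):

import numpy as np
from torch.utils.data import DataLoader
from learning3d.data_utils import UserData

# Keys and shapes follow RegistrationData.check_data() above:
# point clouds (B, N, 3), ground-truth transformations (B, 4, 4).
data_dict = {
    'template': np.random.rand(10, 1024, 3).astype(np.float32),
    'source': np.random.rand(10, 1024, 3).astype(np.float32),
    'transformation': np.tile(np.eye(4, dtype=np.float32), (10, 1, 1)),
}

dataset = UserData('registration', data_dict)
template, source, igt = dataset[0]                # torch.float32 tensors
loader = DataLoader(dataset, batch_size=2, shuffle=False)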