learning3d 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- learning3d/__init__.py +2 -0
- learning3d/data_utils/__init__.py +4 -0
- learning3d/data_utils/dataloaders.py +454 -0
- learning3d/data_utils/user_data.py +119 -0
- learning3d/examples/test_dcp.py +139 -0
- learning3d/examples/test_deepgmr.py +144 -0
- learning3d/examples/test_flownet.py +113 -0
- learning3d/examples/test_masknet.py +159 -0
- learning3d/examples/test_masknet2.py +162 -0
- learning3d/examples/test_pcn.py +118 -0
- learning3d/examples/test_pcrnet.py +120 -0
- learning3d/examples/test_pnlk.py +121 -0
- learning3d/examples/test_pointconv.py +126 -0
- learning3d/examples/test_pointnet.py +121 -0
- learning3d/examples/test_prnet.py +126 -0
- learning3d/examples/test_rpmnet.py +120 -0
- learning3d/examples/train_PointNetLK.py +240 -0
- learning3d/examples/train_dcp.py +249 -0
- learning3d/examples/train_deepgmr.py +244 -0
- learning3d/examples/train_flownet.py +259 -0
- learning3d/examples/train_masknet.py +239 -0
- learning3d/examples/train_pcn.py +216 -0
- learning3d/examples/train_pcrnet.py +228 -0
- learning3d/examples/train_pointconv.py +245 -0
- learning3d/examples/train_pointnet.py +244 -0
- learning3d/examples/train_prnet.py +229 -0
- learning3d/examples/train_rpmnet.py +228 -0
- learning3d/losses/__init__.py +12 -0
- learning3d/losses/chamfer_distance.py +51 -0
- learning3d/losses/classification.py +14 -0
- learning3d/losses/correspondence_loss.py +10 -0
- learning3d/losses/cuda/chamfer_distance/__init__.py +1 -0
- learning3d/losses/cuda/chamfer_distance/chamfer_distance.cpp +185 -0
- learning3d/losses/cuda/chamfer_distance/chamfer_distance.cu +209 -0
- learning3d/losses/cuda/chamfer_distance/chamfer_distance.py +66 -0
- learning3d/losses/cuda/emd_torch/pkg/emd_loss_layer.py +41 -0
- learning3d/losses/cuda/emd_torch/pkg/include/cuda/emd.cuh +347 -0
- learning3d/losses/cuda/emd_torch/pkg/include/cuda_helper.h +18 -0
- learning3d/losses/cuda/emd_torch/pkg/include/emd.h +54 -0
- learning3d/losses/cuda/emd_torch/pkg/layer/__init__.py +1 -0
- learning3d/losses/cuda/emd_torch/pkg/layer/emd_loss_layer.py +40 -0
- learning3d/losses/cuda/emd_torch/pkg/src/cuda/emd.cu +70 -0
- learning3d/losses/cuda/emd_torch/pkg/src/emd.cpp +1 -0
- learning3d/losses/cuda/emd_torch/setup.py +29 -0
- learning3d/losses/emd.py +16 -0
- learning3d/losses/frobenius_norm.py +21 -0
- learning3d/losses/rmse_features.py +16 -0
- learning3d/models/__init__.py +23 -0
- learning3d/models/classifier.py +41 -0
- learning3d/models/dcp.py +92 -0
- learning3d/models/deepgmr.py +165 -0
- learning3d/models/dgcnn.py +92 -0
- learning3d/models/flownet3d.py +446 -0
- learning3d/models/masknet.py +84 -0
- learning3d/models/masknet2.py +264 -0
- learning3d/models/pcn.py +164 -0
- learning3d/models/pcrnet.py +74 -0
- learning3d/models/pointconv.py +108 -0
- learning3d/models/pointnet.py +108 -0
- learning3d/models/pointnetlk.py +173 -0
- learning3d/models/pooling.py +15 -0
- learning3d/models/ppfnet.py +102 -0
- learning3d/models/prnet.py +431 -0
- learning3d/models/rpmnet.py +359 -0
- learning3d/models/segmentation.py +38 -0
- learning3d/ops/__init__.py +0 -0
- learning3d/ops/data_utils.py +45 -0
- learning3d/ops/invmat.py +134 -0
- learning3d/ops/quaternion.py +218 -0
- learning3d/ops/se3.py +157 -0
- learning3d/ops/sinc.py +229 -0
- learning3d/ops/so3.py +213 -0
- learning3d/ops/transform_functions.py +342 -0
- learning3d/utils/__init__.py +9 -0
- learning3d/utils/lib/build/lib.linux-x86_64-3.5/pointnet2_cuda.cpython-35m-x86_64-linux-gnu.so +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/ball_query.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/ball_query_gpu.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/group_points.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/group_points_gpu.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/interpolate.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/interpolate_gpu.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/pointnet2_api.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/sampling.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/sampling_gpu.o +0 -0
- learning3d/utils/lib/dist/pointnet2-0.0.0-py3.5-linux-x86_64.egg +0 -0
- learning3d/utils/lib/pointnet2.egg-info/SOURCES.txt +14 -0
- learning3d/utils/lib/pointnet2.egg-info/dependency_links.txt +1 -0
- learning3d/utils/lib/pointnet2.egg-info/top_level.txt +1 -0
- learning3d/utils/lib/pointnet2_modules.py +160 -0
- learning3d/utils/lib/pointnet2_utils.py +318 -0
- learning3d/utils/lib/pytorch_utils.py +236 -0
- learning3d/utils/lib/setup.py +23 -0
- learning3d/utils/lib/src/ball_query.cpp +25 -0
- learning3d/utils/lib/src/ball_query_gpu.cu +67 -0
- learning3d/utils/lib/src/ball_query_gpu.h +15 -0
- learning3d/utils/lib/src/cuda_utils.h +15 -0
- learning3d/utils/lib/src/group_points.cpp +36 -0
- learning3d/utils/lib/src/group_points_gpu.cu +86 -0
- learning3d/utils/lib/src/group_points_gpu.h +22 -0
- learning3d/utils/lib/src/interpolate.cpp +65 -0
- learning3d/utils/lib/src/interpolate_gpu.cu +233 -0
- learning3d/utils/lib/src/interpolate_gpu.h +36 -0
- learning3d/utils/lib/src/pointnet2_api.cpp +25 -0
- learning3d/utils/lib/src/sampling.cpp +46 -0
- learning3d/utils/lib/src/sampling_gpu.cu +253 -0
- learning3d/utils/lib/src/sampling_gpu.h +29 -0
- learning3d/utils/pointconv_util.py +382 -0
- learning3d/utils/ppfnet_util.py +244 -0
- learning3d/utils/svd.py +59 -0
- learning3d/utils/transformer.py +243 -0
- learning3d-0.0.1.dist-info/LICENSE +21 -0
- learning3d-0.0.1.dist-info/METADATA +271 -0
- learning3d-0.0.1.dist-info/RECORD +115 -0
- learning3d-0.0.1.dist-info/WHEEL +5 -0
- learning3d-0.0.1.dist-info/top_level.txt +1 -0
learning3d/data_utils/dataloaders.py
ADDED
@@ -0,0 +1,454 @@
|
|
1
|
+
import torch
|
2
|
+
import torch.nn as nn
|
3
|
+
import torch.nn.functional as F
|
4
|
+
from torch.utils.data import Dataset
|
5
|
+
from torch.utils.data import DataLoader
|
6
|
+
import numpy as np
|
7
|
+
import os
|
8
|
+
import h5py
|
9
|
+
import subprocess
|
10
|
+
import shlex
|
11
|
+
import json
|
12
|
+
import glob
|
13
|
+
from .. ops import transform_functions, se3
|
14
|
+
from sklearn.neighbors import NearestNeighbors
|
15
|
+
from scipy.spatial.distance import minkowski
|
16
|
+
from scipy.spatial import cKDTree
|
17
|
+
from torch.utils.data import Dataset
|
18
|
+
|
19
|
+
def download_modelnet40():
    """Download and extract the ModelNet40 HDF5 archive into ``<package>/../data``.

    Nothing is done when ``data/modelnet40_ply_hdf5_2048`` already exists.
    Requires the ``wget`` and ``unzip`` command-line tools.

    Raises:
        subprocess.CalledProcessError: if ``wget`` or ``unzip`` exits with a
            non-zero status (failures used to be silently ignored).
        FileNotFoundError: if ``wget`` or ``unzip`` is not installed.
    """
    BASE_DIR = os.path.dirname(os.path.abspath(__file__))
    DATA_DIR = os.path.join(BASE_DIR, os.pardir, 'data')
    os.makedirs(DATA_DIR, exist_ok=True)
    if os.path.exists(os.path.join(DATA_DIR, 'modelnet40_ply_hdf5_2048')):
        return

    www = 'https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip'
    # NOTE(review): if this host is unreachable, fetch the archive manually and
    # extract it into DATA_DIR so that DATA_DIR/modelnet40_ply_hdf5_2048 exists.
    zip_path = os.path.join(DATA_DIR, os.path.basename(www))
    try:
        # Argument lists (no shell) keep paths with spaces/metacharacters safe,
        # and downloading/extracting inside DATA_DIR avoids touching the cwd.
        subprocess.run(['wget', '--no-check-certificate', '-O', zip_path, www], check=True)
        subprocess.run(['unzip', zip_path, '-d', DATA_DIR], check=True)
    finally:
        # Remove the archive (or a partial download) in every case.
        if os.path.exists(zip_path):
            os.remove(zip_path)
|
30
|
+
|
31
|
+
def load_data(train, use_normals):
    """Load one ModelNet40 split from ``<package>/../data/modelnet40_ply_hdf5_2048``.

    Args:
        train: if truthy, read the ``ply_data_train*.h5`` files; otherwise
            read the ``ply_data_test*.h5`` files.
        use_normals: if truthy, concatenate each file's ``normal`` dataset
            onto its ``data`` dataset along the last axis.

    Returns:
        (all_data, all_label): a float32 array and an int64 array, each
        concatenated over all matching files along axis 0.

    Raises:
        FileNotFoundError: if no file of the requested split exists.
    """
    partition = 'train' if train else 'test'
    BASE_DIR = os.path.dirname(os.path.abspath(__file__))
    DATA_DIR = os.path.join(BASE_DIR, os.pardir, 'data')
    pattern = os.path.join(DATA_DIR, 'modelnet40_ply_hdf5_2048', 'ply_data_%s*.h5' % partition)
    # sorted(): glob order depends on the filesystem; sorting keeps the
    # sample order (and therefore every index) reproducible.
    h5_names = sorted(glob.glob(pattern))
    if not h5_names:
        # Previously np.concatenate([]) failed later with an opaque ValueError.
        raise FileNotFoundError('No ModelNet40 files match %s; download the dataset first.' % pattern)

    all_data = []
    all_label = []
    for h5_name in h5_names:
        # Explicit read-only mode plus a context manager so the file is closed
        # even if reading a dataset fails.
        with h5py.File(h5_name, 'r') as f:
            if use_normals:
                data = np.concatenate([f['data'][:], f['normal'][:]], axis=-1).astype('float32')
            else:
                data = f['data'][:].astype('float32')
            label = f['label'][:].astype('int64')
        all_data.append(data)
        all_label.append(label)
    return np.concatenate(all_data, axis=0), np.concatenate(all_label, axis=0)
|
49
|
+
|
50
|
+
def deg_to_rad(deg):
    """Convert an angle (scalar or numpy array) from degrees to radians."""
    radians_per_degree = np.pi / 180
    return radians_per_degree * deg
|
52
|
+
|
53
|
+
def create_random_transform(dtype, max_rotation_deg, max_translation):
    """Sample a random rigid transform encoded as ``[quaternion | translation]``.

    Euler angles (``"xyz"`` order) are drawn uniformly from
    ``[-max_rotation_deg, max_rotation_deg]`` degrees, and the translation
    from ``[-max_translation, max_translation]`` per axis.

    Returns:
        A tensor of the requested ``dtype`` that concatenates, along axis 1,
        the output of ``transform_functions.euler_to_quaternion`` and the
        (1, 3) translation. The result is presumably (1, 7); confirm against
        euler_to_quaternion.
    """
    limit = deg_to_rad(max_rotation_deg)
    euler_angles = np.random.uniform(-limit, limit, [1, 3])
    translation = np.random.uniform(-max_translation, max_translation, [1, 3])
    quaternion = transform_functions.euler_to_quaternion(euler_angles, "xyz")
    pose = np.concatenate([quaternion, translation], axis=1)
    return torch.tensor(pose, dtype=dtype)
|
62
|
+
|
63
|
+
def jitter_pointcloud(pointcloud, sigma=0.04, clip=0.05):
    """Add clipped Gaussian noise to ``pointcloud`` in place and return it.

    For each call the noise standard deviation is drawn uniformly from
    ``[0, sigma)``, so the default ``sigma=0.04`` gives the same noise
    distribution as before. Every noise value is clamped to ``[-clip, clip]``.

    Args:
        pointcloud: float tensor (or array) of points; it is modified in place.
        sigma: upper bound of the randomly drawn noise standard deviation.
        clip: absolute bound applied to each noise value.

    Returns:
        The same ``pointcloud`` object after the noise has been added.
    """
    # ``sigma`` used to be overwritten by a hard-coded 0.04, which made the
    # argument ineffective; it now scales the random standard deviation.
    std = sigma * np.random.random_sample()
    pointcloud += torch.empty(pointcloud.shape).normal_(mean=0, std=std).clamp(-clip, clip)
    return pointcloud
|
68
|
+
|
69
|
+
def farthest_subsample_points(pointcloud1, num_subsampled_points=768):
    """Crop a cloud to the points nearest a random far-away anchor.

    The anchor is a random point in ``[0, 1)^3`` shifted by
    ``(500, 500, 500)`` times a random sign, so it lies far outside a
    normalised cloud. The ``num_subsampled_points`` points closest to it
    (Euclidean distance on the first three columns) are kept, which gives a
    one-sided partial view.

    Returns:
        (cropped, gt_mask): the selected rows of ``pointcloud1``, and a float
        tensor of length ``pointcloud1.shape[0]`` with 1 at selected indices
        and 0 elsewhere.
    """
    num_points = pointcloud1.shape[0]
    # The default metric ('minkowski' with p=2) is the same Euclidean distance
    # the old Python callable computed via scipy's minkowski(). Using it lets
    # scikit-learn run its compiled search instead of calling back into Python
    # for every pair of points.
    nbrs1 = NearestNeighbors(n_neighbors=num_subsampled_points, algorithm='auto').fit(pointcloud1[:, :3])
    random_p1 = np.random.random(size=(1, 3)) + np.array([[500, 500, 500]]) * np.random.choice([1, -1, 1, -1])
    idx1 = nbrs1.kneighbors(random_p1, return_distance=False).reshape((num_subsampled_points,))
    gt_mask = torch.zeros(num_points).scatter_(0, torch.tensor(idx1), 1)
    return pointcloud1[idx1, :], gt_mask
|
78
|
+
|
79
|
+
def uniform_2_sphere(num: int = None):
    """Draw direction(s) uniformly distributed on the unit 2-sphere.

    Source: https://gist.github.com/andrewbolster/10274979

    Args:
        num: number of directions to draw, or None for a single direction.

    Returns:
        np.ndarray of unit vectors with shape ``(num, 3)``, or shape ``(3,)``
        when ``num`` is None.
    """
    # numpy treats size=None as "return a scalar", so a single code path
    # covers both cases. Drawing the azimuth uniformly in [0, 2*pi) and
    # cos(polar) uniformly in [-1, 1) gives an area-uniform distribution.
    azimuth = np.random.uniform(0.0, 2 * np.pi, size=num)
    polar = np.arccos(np.random.uniform(-1.0, 1.0, size=num))
    sin_polar = np.sin(polar)
    return np.stack(
        (sin_polar * np.cos(azimuth), sin_polar * np.sin(azimuth), np.cos(polar)),
        axis=-1,
    )
|
105
|
+
|
106
|
+
def planar_crop(points, p_keep=0.7):
    """Keep the points lying on one side of a randomly oriented plane.

    Points are projected (after centring on their xyz centroid) onto a random
    unit normal from ``uniform_2_sphere``. Those whose signed distance exceeds
    the ``(1 - p_keep)`` percentile are kept, which is roughly a ``p_keep``
    fraction.

    Args:
        points: torch tensor with at least three columns (converted via
            ``.numpy()``); only the first three columns are used and returned.
        p_keep: fraction of points to keep.

    Returns:
        (cropped, idx): a float tensor of the kept xyz rows, and a float
        tensor of the kept row indices built from ``np.nonzero`` (shape (1, K)).
    """
    keep_fraction = np.array(p_keep, dtype=np.float32)
    normal = uniform_2_sphere()
    xyz = points.numpy()[:, :3]
    signed_dist = np.dot(xyz - np.mean(xyz, axis=0), normal)
    threshold = np.percentile(signed_dist, (1.0 - keep_fraction) * 100)
    keep = signed_dist > threshold
    kept_indices = torch.Tensor(np.nonzero(keep))
    return torch.Tensor(xyz[keep]), kept_indices
|
120
|
+
|
121
|
+
def knn_idx(pts, k):
    """Return, for every point, the indices of its ``k`` nearest other points.

    A KD-tree is queried for ``k + 1`` neighbours and the first column is
    dropped; that column is the query point itself, assuming no duplicate
    points.

    Returns:
        Integer array of shape (N, k) for k >= 1.
    """
    _, neighbour_ids = cKDTree(pts).query(pts, k=k + 1)
    return neighbour_ids[:, 1:]
|
125
|
+
|
126
|
+
def get_rri(pts, k):
    """Compute per-point neighbourhood ("RRI") features on the CPU.

    RegistrationData uses these for DeepGMR and passes ``pts`` already
    centred on its mean. ``pts`` is presumably an (N, 3) array (confirm with
    callers).

    For each point and each of its ``k`` nearest neighbours, the features are
    the norm of the point, the norm of the neighbour, the angle between them
    seen from the origin, and a projection angle between neighbour tangent
    vectors.

    Returns:
        Array of shape (N, 4 * k).
    """
    # nbrs: (N, k, 3) neighbour coordinates; centre: the point repeated k times.
    nbrs = pts[knn_idx(pts, k)]
    centre = np.repeat(pts[:, None], k, axis=1)
    # Norms, each (N, k, 1).
    r_centre = np.linalg.norm(centre, axis=-1, keepdims=True)
    r_nbrs = np.linalg.norm(nbrs, axis=-1, keepdims=True)
    unit_centre = centre / r_centre
    unit_nbrs = nbrs / r_nbrs
    cos_angle = np.sum(unit_centre * unit_nbrs, -1, keepdims=True)
    # Angle between each point and its neighbour, (N, k, 1).
    theta = np.arccos(np.clip(cos_angle, -1, 1))
    tangent = nbrs - cos_angle * centre
    # Pairwise signed angles between neighbour tangent vectors, (N, k, k).
    sin_psi = np.sum(np.cross(tangent[:, None], tangent[:, :, None]) * unit_centre[:, None], -1)
    cos_psi = np.sum(tangent[:, None] * tangent[:, :, None], -1)
    psi = np.arctan2(sin_psi, cos_psi) % (2 * np.pi)
    # kth=1 selects the second-smallest angle; the smallest is normally the
    # zero angle of a tangent vector with itself.
    second = np.argpartition(psi, 1)[:, :, 1:2]
    phi = np.take_along_axis(psi, second, axis=-1)
    features = np.concatenate([r_centre, r_nbrs, theta, phi], axis=-1)
    return features.reshape(-1, k * 4)
|
148
|
+
|
149
|
+
def get_rri_cuda(pts, k, npts_per_block=1):
    """GPU (pycuda) counterpart of ``get_rri``.

    Compiles the kernel ``get_rri_feature`` from ``rri.cu`` after
    substituting ``k`` and ``npts_per_block`` into the source with ``%``.
    ``rri.cu`` is looked up in the working directory first and then next to
    this module.

    Returns:
        float32 array of shape (N, 4 * k).

    Raises:
        Any exception from importing/initialising pycuda (re-raised after a
        message is printed), or FileNotFoundError if ``rri.cu`` is missing.
    """
    try:
        import pycuda.autoinit  # imported only for its CUDA initialisation side effect
        from pycuda import gpuarray
        from pycuda.compiler import SourceModule
    except Exception as e:
        print("Error raised in pycuda modules! pycuda only works with GPU, ", e)
        raise

    kernel_path = 'rri.cu'
    if not os.path.exists(kernel_path):
        kernel_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'rri.cu')
    # Context manager so the kernel source file handle is closed (it was leaked).
    with open(kernel_path) as kernel_file:
        kernel_source = kernel_file.read()
    mod_rri = SourceModule(kernel_source % (k, npts_per_block))
    rri_cuda = mod_rri.get_function('get_rri_feature')

    N = len(pts)
    pts_gpu = gpuarray.to_gpu(pts.astype(np.float32).ravel())
    k_idx = knn_idx(pts, k)
    k_idx_gpu = gpuarray.to_gpu(k_idx.astype(np.int32).ravel())
    feat_gpu = gpuarray.GPUArray((N * k * 4,), np.float32)

    # One block per npts_per_block points; the block's second dimension spans the k neighbours.
    rri_cuda(pts_gpu, np.int32(N), k_idx_gpu, feat_gpu,
             grid=(((N - 1) // npts_per_block) + 1, 1),
             block=(npts_per_block, k, 1))

    feat = feat_gpu.get().reshape(N, k * 4).astype(np.float32)
    return feat
|
173
|
+
|
174
|
+
|
175
|
+
class UnknownDataTypeError(Exception):
    """Raised when a dataset's storage type is not understood (e.g. by ``__len__``)."""

    def __init__(self, *args):
        # The first positional argument, if given, becomes the message.
        self.message = args[0] if args else 'Datatype not understood for dataset.'

    def __str__(self):
        return self.message
|
182
|
+
|
183
|
+
|
184
|
+
class ModelNet40Data(Dataset):
    """ModelNet40 point clouds (HDF5 release) with integer class labels.

    Args:
        train: load the training split if True, otherwise the test split.
        num_points: number of leading points returned per shape.
        download: call ``download_modelnet40()`` before loading.
        randomize_data: permute the returned points on each access.
        use_normals: append the per-point normals to the coordinates.
    """

    def __init__(
        self,
        train=True,
        num_points=1024,
        download=True,
        randomize_data=False,
        use_normals=False
    ):
        super(ModelNet40Data, self).__init__()
        if download:
            download_modelnet40()
        self.data, self.labels = load_data(train, use_normals)
        # Class names are only loaded for the test split; get_shape() needs them.
        if not train:
            self.shapes = self.read_classes_ModelNet40()
        self.num_points = num_points
        self.randomize_data = randomize_data

    def __getitem__(self, idx):
        """Return ``(points, label)``: a float tensor of the first
        ``num_points`` rows and a long tensor label."""
        if self.randomize_data:
            current_points = self.randomize(idx)
        else:
            current_points = self.data[idx].copy()

        current_points = torch.from_numpy(current_points[:self.num_points, :]).float()
        label = torch.from_numpy(self.labels[idx]).type(torch.LongTensor)
        return current_points, label

    def __len__(self):
        return self.data.shape[0]

    def randomize(self, idx):
        """Return a copy of the first ``num_points`` points of shape ``idx`` in random order."""
        # NOTE(review): only the first num_points points are permuted; points
        # beyond that index are never selected. Confirm this is intended.
        pt_idxs = np.arange(0, self.num_points)
        np.random.shuffle(pt_idxs)
        return self.data[idx, pt_idxs].copy()

    def get_shape(self, label):
        """Return the class name for ``label`` (only available for the test split)."""
        return self.shapes[label]

    def read_classes_ModelNet40(self):
        """Read the class names from ``shape_names.txt`` (one name per line)."""
        BASE_DIR = os.path.dirname(os.path.abspath(__file__))
        DATA_DIR = os.path.join(BASE_DIR, os.pardir, 'data')
        path = os.path.join(DATA_DIR, 'modelnet40_ply_hdf5_2048', 'shape_names.txt')
        # The with-block closes the file, which used to be left open.
        # splitlines() keeps the last name even without a trailing newline
        # (split('\n')[:-1] dropped it) and strips '\r' from CRLF files.
        with open(path, 'r') as file:
            shape_names = file.read().splitlines()
        return np.array(shape_names)
|
227
|
+
|
228
|
+
|
229
|
+
class ClassificationData(Dataset):
    """Expose an indexable dataset of ``(points, label)`` pairs for classification.

    Args:
        data_class: the wrapped dataset. A new ``ModelNet40Data()`` is created
            when it is omitted or None.
    """

    def __init__(self, data_class=None):
        super(ClassificationData, self).__init__()
        # The default used to be ``ModelNet40Data()`` in the signature, which
        # Python evaluates once at import time. Importing this module therefore
        # downloaded/loaded the dataset, and all default instances shared it.
        if data_class is None:
            data_class = ModelNet40Data()
        self.set_class(data_class)

    def __len__(self):
        return len(self.data_class)

    def set_class(self, data_class):
        self.data_class = data_class

    def get_shape(self, label):
        """Return the class name for ``label``, or -1 if the wrapped dataset
        cannot provide one."""
        try:
            return self.data_class.get_shape(label)
        except Exception:
            # Best effort: the data class may lack get_shape or the class names
            # (e.g. a ModelNet40 training split). Unlike a bare except, this no
            # longer swallows KeyboardInterrupt/SystemExit.
            return -1

    def __getitem__(self, index):
        return self.data_class[index]
|
248
|
+
|
249
|
+
|
250
|
+
class RegistrationData(Dataset):
    """Pair each cloud of a dataset with a randomly transformed copy.

    Args:
        algorithm: one of 'PCRNet', 'PointNetLK', 'DCP', 'PRNet', 'iPCRNet',
            'RPMNet', 'DeepGMR'; selects the transform generator.
        data_class: dataset returning ``(template, label)``. A new
            ``ModelNet40Data()`` is created when it is omitted or None.
        partial_source / partial_template: crop the source / template with
            ``farthest_subsample_points``.
        noise: add ``jitter_pointcloud`` noise to the source.
        additional_params: optional dict. Keys read here:
            'nearest_neighbors' (DeepGMR: append RRI features when > 0),
            'partial_point_cloud_method' (value 'planar_crop' crops both
            clouds with ``planar_crop``), and 'use_masknet' (also return
            masks when partial data is requested).

    Raises:
        ValueError: if ``algorithm`` is not supported. ValueError subclasses
            Exception, so existing ``except Exception`` handlers still apply.
    """

    def __init__(self, algorithm, data_class=None, partial_source=False, partial_template=False, noise=False, additional_params=None):
        super(RegistrationData, self).__init__()
        available_algorithms = ['PCRNet', 'PointNetLK', 'DCP', 'PRNet', 'iPCRNet', 'RPMNet', 'DeepGMR']
        if algorithm not in available_algorithms:
            raise ValueError("Algorithm not available for registration.")
        self.algorithm = algorithm

        # Defaults are created per call. ``ModelNet40Data()`` in the signature
        # ran at import time, and a ``{}`` default would be shared by all instances.
        if data_class is None:
            data_class = ModelNet40Data()
        if additional_params is None:
            additional_params = {}

        self.set_class(data_class)
        self.partial_template = partial_template
        self.partial_source = partial_source
        self.noise = noise
        self.additional_params = additional_params
        self.use_rri = False

        if self.algorithm in ('PCRNet', 'iPCRNet'):
            from .. ops.transform_functions import PCRNetTransform
            self.transforms = PCRNetTransform(len(data_class), angle_range=45, translation_range=1)
        elif self.algorithm == 'PointNetLK':
            from .. ops.transform_functions import PNLKTransform
            self.transforms = PNLKTransform(0.8, True)
        elif self.algorithm == 'RPMNet':
            from .. ops.transform_functions import RPMNetTransform
            self.transforms = RPMNetTransform(0.8, True)
        elif self.algorithm in ('DCP', 'PRNet'):
            from .. ops.transform_functions import DCPTransform
            self.transforms = DCPTransform(angle_range=45, translation_range=1)
        elif self.algorithm == 'DeepGMR':
            self.get_rri = get_rri_cuda if torch.cuda.is_available() else get_rri
            from .. ops.transform_functions import DeepGMRTransform
            self.transforms = DeepGMRTransform(angle_range=90, translation_range=1)
            if self.additional_params.get('nearest_neighbors', 0) > 0:
                self.use_rri = True
                self.nearest_neighbors = self.additional_params['nearest_neighbors']

    def __len__(self):
        return len(self.data_class)

    def set_class(self, data_class):
        self.data_class = data_class

    def __getitem__(self, index):
        """Return ``(template, source, igt)``. When 'use_masknet' is set and
        partial data is requested, the template and/or source masks are
        appended."""
        template, label = self.data_class[index]
        self.transforms.index = index  # for fixed transformations in PCRNet.
        source = self.transforms(template)

        # Check for Partial Data.
        if self.additional_params.get('partial_point_cloud_method', None) == 'planar_crop':
            source, gt_idx_source = planar_crop(source)
            template, gt_idx_template = planar_crop(template)
            # Row indices kept by both crops. intersect_x / intersect_y are
            # their positions inside the cropped source / template.
            _, intersect_x, intersect_y = np.intersect1d(gt_idx_source, gt_idx_template, return_indices=True)

            self.template_mask = torch.zeros(template.shape[0])
            self.source_mask = torch.zeros(source.shape[0])
            self.template_mask[intersect_y] = 1
            self.source_mask[intersect_x] = 1
        else:
            if self.partial_source:
                source, self.source_mask = farthest_subsample_points(source)
            if self.partial_template:
                template, self.template_mask = farthest_subsample_points(template)

        # Check for Noise in Source Data.
        if self.noise:
            source = jitter_pointcloud(source)

        if self.use_rri:
            template, source = template.numpy(), source.numpy()
            template = np.concatenate([template, self.get_rri(template - template.mean(axis=0), self.nearest_neighbors)], axis=1)
            source = np.concatenate([source, self.get_rri(source - source.mean(axis=0), self.nearest_neighbors)], axis=1)
            template, source = torch.tensor(template).float(), torch.tensor(source).float()

        igt = self.transforms.igt

        if self.additional_params.get('use_masknet', False):
            if self.partial_source and self.partial_template:
                return template, source, igt, self.template_mask, self.source_mask
            if self.partial_source:
                return template, source, igt, self.source_mask
            if self.partial_template:
                return template, source, igt, self.template_mask
        # Single fall-through return, so no combination of flags can end the
        # method without a value (which yielded a None sample before).
        return template, source, igt
|
331
|
+
|
332
|
+
|
333
|
+
class SegmentationData(Dataset):
    """Placeholder for a segmentation dataset; not implemented yet.

    Construction succeeds, but ``len()`` and indexing raise
    ``NotImplementedError``. Previously ``__len__`` returned None, which made
    ``len()`` fail with an unrelated TypeError, and ``__getitem__`` silently
    produced None samples.
    """

    def __init__(self):
        super(SegmentationData, self).__init__()

    def __len__(self):
        raise NotImplementedError("SegmentationData is not implemented yet.")

    def __getitem__(self, index):
        raise NotImplementedError("SegmentationData is not implemented yet.")
|
342
|
+
|
343
|
+
|
344
|
+
class FlowData(Dataset):
    """Base class for scene-flow datasets made of ``(pc1, pc2, flow)`` triples.

    Subclasses override ``read_data`` to return three equally long indexable
    containers (numpy arrays, lists, tuples or torch tensors).
    """

    def __init__(self):
        super(FlowData, self).__init__()
        self.pc1, self.pc2, self.flow = self.read_data()

    def __len__(self):
        """Number of samples; raises UnknownDataTypeError for unsupported containers."""
        if isinstance(self.pc1, np.ndarray):
            return self.pc1.shape[0]
        if isinstance(self.pc1, (list, tuple, torch.Tensor)):
            return len(self.pc1)
        raise UnknownDataTypeError

    def read_data(self):
        """Return ``(pc1, pc2, flow)``; must be implemented by subclasses."""
        # The base version used to return None, so __init__ failed with an
        # opaque "cannot unpack non-iterable NoneType" TypeError.
        raise NotImplementedError("Subclasses of FlowData must implement read_data().")

    def __getitem__(self, index):
        return self.pc1[index], self.pc2[index], self.flow[index]
|
362
|
+
|
363
|
+
|
364
|
+
class SceneflowDataset(Dataset):
    """Scene-flow samples stored as ``TRAIN*.npz`` / ``TEST*.npz`` files.

    Each file must contain 'points1', 'points2', 'color1', 'color2', 'flow'
    and 'valid_mask1'. In the 'train' partition ``npoints`` random points are
    drawn per frame; otherwise the first ``npoints`` are taken. Both clouds
    are then shifted by the centroid of the selected first-frame points.

    Args:
        npoints: number of points returned per frame.
        root: directory holding the ``.npz`` files. Defaults to
            ``<package>/../data/data_processed_maxcut_35_20k_2k_8192``.
        partition: 'train' selects ``TRAIN*.npz``; anything else ``TEST*.npz``.

    Raises:
        FileNotFoundError: if ``root`` is not given and the default directory
            does not exist.
    """

    def __init__(self, npoints=1024, root='', partition='train'):
        if root == '':
            BASE_DIR = os.path.dirname(os.path.abspath(__file__))
            DATA_DIR = os.path.join(BASE_DIR, os.pardir, 'data')
            root = os.path.join(DATA_DIR, 'data_processed_maxcut_35_20k_2k_8192')
            if not os.path.exists(root):
                # Raise instead of calling exit(). Library code should not
                # terminate the interpreter, and ``exit`` is a site-module
                # helper intended for interactive use.
                raise FileNotFoundError(
                    "SceneflowDataset not found at %s. To download dataset, click here: "
                    "https://drive.google.com/file/d/1CMaxdt-Tg1Wct8v8eGNwuT7qRSIyJPY-/view" % root)
            else:
                print("SceneflowDataset Found Successfully!")

        self.npoints = npoints
        self.partition = partition
        self.root = root
        pattern = 'TRAIN*.npz' if self.partition == 'train' else 'TEST*.npz'
        # sorted(): glob order is filesystem dependent; sorting keeps the
        # index -> file mapping reproducible.
        self.datapath = sorted(glob.glob(os.path.join(self.root, pattern)))
        self.cache = {}
        self.cache_size = 30000

        # Skip one known bad sample that contains NaN values.
        self.datapath = [d for d in self.datapath if 'TRAIN_C_0140_left_0006-0' not in d]
        print(self.partition, ': ', len(self.datapath))

    def __getitem__(self, index):
        """Return ``(pos1, pos2, color1, color2, flow, mask1)`` for sample ``index``."""
        if index in self.cache:
            pos1, pos2, color1, color2, flow, mask1 = self.cache[index]
        else:
            fn = self.datapath[index]
            with open(fn, 'rb') as fp:
                data = np.load(fp)
                pos1 = data['points1'].astype('float32')
                pos2 = data['points2'].astype('float32')
                color1 = data['color1'].astype('float32')
                color2 = data['color2'].astype('float32')
                flow = data['flow'].astype('float32')
                mask1 = data['valid_mask1']

            if len(self.cache) < self.cache_size:
                self.cache[index] = (pos1, pos2, color1, color2, flow, mask1)

        if self.partition == 'train':
            n1 = pos1.shape[0]
            sample_idx1 = np.random.choice(n1, self.npoints, replace=False)
            n2 = pos2.shape[0]
            sample_idx2 = np.random.choice(n2, self.npoints, replace=False)

            pos1 = pos1[sample_idx1, :]
            pos2 = pos2[sample_idx2, :]
            color1 = color1[sample_idx1, :]
            color2 = color2[sample_idx2, :]
            flow = flow[sample_idx1, :]
            mask1 = mask1[sample_idx1]
        else:
            pos1 = pos1[:self.npoints, :]
            pos2 = pos2[:self.npoints, :]
            color1 = color1[:self.npoints, :]
            color2 = color2[:self.npoints, :]
            flow = flow[:self.npoints, :]
            mask1 = mask1[:self.npoints]

        pos1_center = np.mean(pos1, 0)
        # Out-of-place subtraction. In the test partition pos1/pos2 are slices
        # (views) of the cached arrays, so in-place ``-=`` would rewrite the cache.
        pos1 = pos1 - pos1_center
        pos2 = pos2 - pos1_center

        return pos1, pos2, color1, color2, flow, mask1

    def __len__(self):
        return len(self.datapath)
|
436
|
+
|
437
|
+
|
438
|
+
if __name__ == '__main__':
    # Minimal smoke check of the dataset wrappers using in-memory data.
    # The leftover ``ipdb`` breakpoint (a third-party debugger that halts the
    # script) has been removed.
    class Data():
        def __init__(self):
            super(Data, self).__init__()
            self.data, self.label = self.read_data()

        def read_data(self):
            return [4, 5, 6], [4, 5, 6]

        def __len__(self):
            return len(self.data)

        def __getitem__(self, idx):
            return self.data[idx], self.label[idx]

    cd = ClassificationData(Data())
    print('ClassificationData:', len(cd), cd[0])
    try:
        RegistrationData('abc', data_class=Data())
    except Exception as err:
        # 'abc' is not a supported algorithm, so construction is expected to fail.
        print('RegistrationData rejected unknown algorithm:', err)
|
learning3d/data_utils/user_data.py
ADDED
@@ -0,0 +1,119 @@
|
|
1
|
+
import os
|
2
|
+
import numpy as np
|
3
|
+
import torch
|
4
|
+
|
5
|
+
class ClassificationData:
    """Classification samples supplied by the user in a dict.

    Args:
        data_dict: mapping with 'pcs' (array of shape (P, 3) or (N, P, 3)) and
            'labels' (array of shape (N,) or (N, 1); a single cloud may use
            shape (1,)).

    Raises:
        KeyError: if a required key is missing.
        AssertionError: if the array ranks or counts are inconsistent.
    """

    def __init__(self, data_dict):
        self.data_dict = data_dict
        self.pcs = self.find_attribute('pcs')
        self.labels = self.find_attribute('labels')
        self.check_data()

    def find_attribute(self, attribute):
        """Return ``data_dict[attribute]``, raising a descriptive KeyError if absent."""
        try:
            return self.data_dict[attribute]
        except KeyError:
            # Previously this only printed and then crashed with UnboundLocalError.
            raise KeyError("Given data dictionary has no key attribute \"{}\"".format(attribute)) from None

    def check_data(self):
        assert 1 < len(self.pcs.shape) < 4, "Error in dimension of point clouds! Given data dimension: {}".format(self.pcs.shape)
        assert 0 < len(self.labels.shape) < 3, "Error in dimension of labels! Given data dimension: {}".format(self.labels.shape)

        if len(self.pcs.shape) == 2:
            self.pcs = self.pcs.reshape(1, -1, 3)
        # One label per cloud: (N,) -> (N, 1). The previous reshape(1, -1)
        # produced (1, N), which failed the count check below whenever N > 1.
        if len(self.labels.shape) == 1:
            self.labels = self.labels.reshape(-1, 1)

        assert self.pcs.shape[0] == self.labels.shape[0], "Inconsistency in the number of point clouds and number of ground truth labels!"

    def __len__(self):
        return self.pcs.shape[0]

    def __getitem__(self, index):
        """Return ``(points, label)`` as float and long tensors."""
        # Uses ``index``; the body referenced an undefined ``idx`` (NameError).
        return torch.tensor(self.pcs[index]).float(), torch.from_numpy(self.labels[index]).type(torch.LongTensor)
|
34
|
+
|
35
|
+
|
36
|
+
class RegistrationData:
    """Registration samples supplied by the user in a dict.

    Args:
        data_dict: mapping with 'template' and 'source' (arrays of shape
            (P, 3) or (N, P, 3)) and 'transformation' (array of shape (4, 4)
            or (N, 4, 4)).

    Raises:
        KeyError: if a required key is missing.
        AssertionError: if the array ranks or counts are inconsistent.
    """

    def __init__(self, data_dict):
        self.data_dict = data_dict
        self.template = self.find_attribute('template')
        self.source = self.find_attribute('source')
        self.transformation = self.find_attribute('transformation')
        self.check_data()

    def find_attribute(self, attribute):
        """Return ``data_dict[attribute]``, raising a descriptive KeyError if absent."""
        # Reads ``self.data_dict``. The code read a nonexistent ``self.data``,
        # so construction always failed (and then hit UnboundLocalError).
        try:
            return self.data_dict[attribute]
        except KeyError:
            raise KeyError("Given data dictionary has no key attribute \"{}\"".format(attribute)) from None

    def check_data(self):
        assert 1 < len(self.template.shape) < 4, "Error in dimension of point clouds! Given data dimension: {}".format(self.template.shape)
        assert 1 < len(self.source.shape) < 4, "Error in dimension of point clouds! Given data dimension: {}".format(self.source.shape)
        assert 1 < len(self.transformation.shape) < 4, "Error in dimension of transformations! Given data dimension: {}".format(self.transformation.shape)

        if len(self.template.shape) == 2:
            self.template = self.template.reshape(1, -1, 3)
        if len(self.source.shape) == 2:
            self.source = self.source.reshape(1, -1, 3)
        if len(self.transformation.shape) == 2:
            self.transformation = self.transformation.reshape(1, 4, 4)

        assert self.template.shape[0] == self.source.shape[0], "Inconsistency in the number of template and source point clouds!"
        assert self.source.shape[0] == self.transformation.shape[0], "Inconsistency in the number of transformation and source point clouds!"

    def __len__(self):
        return self.template.shape[0]

    def __getitem__(self, index):
        """Return ``(template, source, transformation)`` as float tensors."""
        return torch.tensor(self.template[index]).float(), torch.tensor(self.source[index]).float(), torch.tensor(self.transformation[index]).float()
|
68
|
+
|
69
|
+
|
70
|
+
class FlowData:
    """Scene-flow samples supplied by the user in a dict.

    Args:
        data_dict: mapping with 'frame1', 'frame2' and 'flow' (arrays of
            shape (P, 3) or (N, P, 3)).

    Raises:
        KeyError: if a required key is missing.
        AssertionError: if the array ranks or counts are inconsistent.
    """

    def __init__(self, data_dict):
        self.data_dict = data_dict
        self.frame1 = self.find_attribute('frame1')
        self.frame2 = self.find_attribute('frame2')
        self.flow = self.find_attribute('flow')
        self.check_data()

    def find_attribute(self, attribute):
        """Return ``data_dict[attribute]``, raising a descriptive KeyError if absent."""
        # Reads ``self.data_dict``. The code read a nonexistent ``self.data``,
        # so construction always failed (and then hit UnboundLocalError).
        try:
            return self.data_dict[attribute]
        except KeyError:
            raise KeyError("Given data dictionary has no key attribute \"{}\"".format(attribute)) from None

    def check_data(self):
        assert 1 < len(self.frame1.shape) < 4, "Error in dimension of point clouds! Given data dimension: {}".format(self.frame1.shape)
        assert 1 < len(self.frame2.shape) < 4, "Error in dimension of point clouds! Given data dimension: {}".format(self.frame2.shape)
        assert 1 < len(self.flow.shape) < 4, "Error in dimension of flow! Given data dimension: {}".format(self.flow.shape)

        if len(self.frame1.shape) == 2:
            self.frame1 = self.frame1.reshape(1, -1, 3)
        if len(self.frame2.shape) == 2:
            self.frame2 = self.frame2.reshape(1, -1, 3)
        if len(self.flow.shape) == 2:
            self.flow = self.flow.reshape(1, -1, 3)

        assert self.frame1.shape[0] == self.frame2.shape[0], "Inconsistency in the number of frame1 and frame2 point clouds!"
        assert self.frame2.shape[0] == self.flow.shape[0], "Inconsistency in the number of flow and frame2 point clouds!"

    def __len__(self):
        return self.frame1.shape[0]

    def __getitem__(self, index):
        """Return ``(frame1, frame2, flow)`` as float tensors."""
        return torch.tensor(self.frame1[index]).float(), torch.tensor(self.frame2[index]).float(), torch.tensor(self.flow[index]).float()
|
102
|
+
|
103
|
+
|
104
|
+
class UserData:
    """Route user-provided data to the wrapper matching ``application``.

    Args:
        application: 'classification', 'registration' or 'flow_estimation'.
        data_dict: passed unchanged to ClassificationData, RegistrationData
            or FlowData respectively.

    Raises:
        ValueError: for any other ``application``.
    """

    def __init__(self, application, data_dict):
        self.application = application

        if self.application == 'classification':
            self.data_class = ClassificationData(data_dict)
        elif self.application == 'registration':
            self.data_class = RegistrationData(data_dict)
        elif self.application == 'flow_estimation':
            self.data_class = FlowData(data_dict)
        else:
            # An unknown name used to leave ``data_class`` unset, so the error
            # only surfaced later as an AttributeError in len() or indexing.
            raise ValueError(
                "Unknown application \"{}\"; expected 'classification', "
                "'registration' or 'flow_estimation'.".format(application))

    def __len__(self):
        return len(self.data_class)

    def __getitem__(self, index):
        return self.data_class[index]
|