ocnn 2.2.1__py3-none-any.whl → 2.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ocnn/__init__.py +24 -24
- ocnn/dataset.py +160 -158
- ocnn/models/__init__.py +29 -27
- ocnn/models/autoencoder.py +155 -165
- ocnn/models/hrnet.py +192 -192
- ocnn/models/image2shape.py +128 -0
- ocnn/models/lenet.py +46 -46
- ocnn/models/ounet.py +94 -94
- ocnn/models/resnet.py +53 -53
- ocnn/models/segnet.py +72 -72
- ocnn/models/unet.py +105 -105
- ocnn/modules/__init__.py +20 -20
- ocnn/modules/modules.py +193 -231
- ocnn/modules/resblocks.py +124 -124
- ocnn/nn/__init__.py +43 -42
- ocnn/nn/octree2col.py +53 -53
- ocnn/nn/octree2vox.py +50 -50
- ocnn/nn/octree_align.py +46 -46
- ocnn/nn/octree_conv.py +429 -411
- ocnn/nn/octree_drop.py +55 -55
- ocnn/nn/octree_dwconv.py +222 -204
- ocnn/nn/octree_gconv.py +79 -0
- ocnn/nn/octree_interp.py +196 -196
- ocnn/nn/octree_norm.py +86 -86
- ocnn/nn/octree_pad.py +39 -39
- ocnn/nn/octree_pool.py +200 -200
- ocnn/octree/__init__.py +22 -21
- ocnn/octree/octree.py +639 -601
- ocnn/octree/points.py +322 -298
- ocnn/octree/shuffled_key.py +115 -115
- ocnn/utils.py +205 -153
- {ocnn-2.2.1.dist-info → ocnn-2.2.3.dist-info}/LICENSE +21 -21
- {ocnn-2.2.1.dist-info → ocnn-2.2.3.dist-info}/METADATA +79 -65
- ocnn-2.2.3.dist-info/RECORD +36 -0
- {ocnn-2.2.1.dist-info → ocnn-2.2.3.dist-info}/WHEEL +1 -1
- ocnn-2.2.1.dist-info/RECORD +0 -34
- {ocnn-2.2.1.dist-info → ocnn-2.2.3.dist-info}/top_level.txt +0 -0
ocnn/octree/points.py
CHANGED
|
@@ -1,298 +1,322 @@
|
|
|
1
|
-
# --------------------------------------------------------
|
|
2
|
-
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
-
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
-
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
-
# Written by Peng-Shuai Wang
|
|
6
|
-
# --------------------------------------------------------
|
|
7
|
-
|
|
8
|
-
import torch
|
|
9
|
-
import numpy as np
|
|
10
|
-
from typing import Optional, Union, List
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
class Points:
|
|
14
|
-
r''' Represents a point cloud and contains some elementary transformations.
|
|
15
|
-
|
|
16
|
-
Args:
|
|
17
|
-
points (torch.Tensor): The coordinates of the points with a shape of
|
|
18
|
-
:obj:`(N, 3)`, where :obj:`N` is the number of points.
|
|
19
|
-
normals (torch.Tensor or None): The point normals with a shape of
|
|
20
|
-
:obj:`(N, 3)`.
|
|
21
|
-
features (torch.Tensor or None): The point features with a shape of
|
|
22
|
-
:obj:`(N, C)`, where :obj:`C` is the channel of features.
|
|
23
|
-
labels (torch.Tensor or None): The point labels with a shape of
|
|
24
|
-
:obj:`(N, K)`, where :obj:`K` is the channel of labels.
|
|
25
|
-
batch_id (torch.Tensor or None): The batch indices for each point with a
|
|
26
|
-
shape of :obj:`(N, 1)`.
|
|
27
|
-
batch_size (int): The batch size.
|
|
28
|
-
'''
|
|
29
|
-
|
|
30
|
-
def __init__(self, points: torch.Tensor,
|
|
31
|
-
normals: Optional[torch.Tensor] = None,
|
|
32
|
-
features: Optional[torch.Tensor] = None,
|
|
33
|
-
labels: Optional[torch.Tensor] = None,
|
|
34
|
-
batch_id: Optional[torch.Tensor] = None,
|
|
35
|
-
batch_size: int = 1):
|
|
36
|
-
super().__init__()
|
|
37
|
-
self.points = points
|
|
38
|
-
self.normals = normals
|
|
39
|
-
self.features = features
|
|
40
|
-
self.labels = labels
|
|
41
|
-
self.batch_id = batch_id
|
|
42
|
-
self.batch_size = batch_size
|
|
43
|
-
self.device = points.device
|
|
44
|
-
self.batch_npt = None # valid after `merge_points`
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
def
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
if self.
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
if
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
'''
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
def
|
|
135
|
-
r'''
|
|
136
|
-
|
|
137
|
-
Args:
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
'''
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
1
|
+
# --------------------------------------------------------
|
|
2
|
+
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
+
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
+
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
+
# Written by Peng-Shuai Wang
|
|
6
|
+
# --------------------------------------------------------
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
import numpy as np
|
|
10
|
+
from typing import Optional, Union, List
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class Points:
  r''' Represents a point cloud and contains some elementary transformations.

  Args:
    points (torch.Tensor): The coordinates of the points with a shape of
        :obj:`(N, 3)`, where :obj:`N` is the number of points.
    normals (torch.Tensor or None): The point normals with a shape of
        :obj:`(N, 3)`.
    features (torch.Tensor or None): The point features with a shape of
        :obj:`(N, C)`, where :obj:`C` is the channel of features.
    labels (torch.Tensor or None): The point labels with a shape of
        :obj:`(N, K)`, where :obj:`K` is the channel of labels. A 1-D tensor
        of shape :obj:`(N,)` is also accepted and reshaped to :obj:`(N, 1)`.
    batch_id (torch.Tensor or None): The batch indices for each point with a
        shape of :obj:`(N, 1)`. A 1-D tensor of shape :obj:`(N,)` is also
        accepted and reshaped to :obj:`(N, 1)`.
    batch_size (int): The batch size.
  '''

  def __init__(self, points: torch.Tensor,
               normals: Optional[torch.Tensor] = None,
               features: Optional[torch.Tensor] = None,
               labels: Optional[torch.Tensor] = None,
               batch_id: Optional[torch.Tensor] = None,
               batch_size: int = 1):
    super().__init__()
    self.points = points
    self.normals = normals
    self.features = features
    self.labels = labels
    self.batch_id = batch_id
    self.batch_size = batch_size
    self.device = points.device
    self.batch_npt = None  # valid after `merge_points`
    self.check_input()

  def check_input(self):
    r''' Checks the shapes of the input tensors.

    1-D :attr:`labels` and :attr:`batch_id` are reshaped in place into column
    vectors of shape :obj:`(N, 1)`.
    '''

    assert self.points.dim() == 2 and self.points.size(1) == 3
    if self.normals is not None:
      assert self.normals.dim() == 2 and self.normals.size(1) == 3
      assert self.normals.size(0) == self.points.size(0)
    if self.features is not None:
      assert self.features.dim() == 2
      assert self.features.size(0) == self.points.size(0)
    if self.labels is not None:
      assert self.labels.dim() == 2 or self.labels.dim() == 1
      assert self.labels.size(0) == self.points.size(0)
      if self.labels.dim() == 1:
        self.labels = self.labels.unsqueeze(1)
    if self.batch_id is not None:
      assert self.batch_id.dim() == 2 or self.batch_id.dim() == 1
      assert self.batch_id.size(0) == self.points.size(0)
      if self.batch_id.dim() == 1:
        self.batch_id = self.batch_id.unsqueeze(1)
      # Checked only after reshaping: a 1-D batch_id (accepted above) has no
      # dimension 1, so checking `size(1)` first would raise an IndexError.
      assert self.batch_id.size(1) == 1

  @property
  def npt(self):
    r''' The number of points, i.e. :obj:`points.shape[0]`. '''
    return self.points.shape[0]

  def orient_normal(self, axis: str = 'x'):
    r''' Orients the point normals along a given axis.

    Args:
      axis (str): Choose from :obj:`x`, :obj:`y` and :obj:`z` to flip every
          normal so that its component along that axis is positive (normals
          whose component is :obj:`<= 0` are negated), or :obj:`xyz` to take
          the absolute value of every component in place.
          (default: :obj:`x`)
    '''

    if self.normals is None:
      return

    axis_map = {'x': 0, 'y': 1, 'z': 2, 'xyz': 3}
    idx = axis_map[axis]
    if idx < 3:
      flags = self.normals[:, idx] > 0
      flags = flags.float() * 2.0 - 1.0  # [0, 1] -> [-1, 1]
      self.normals = self.normals * flags.unsqueeze(1)
    else:
      self.normals.abs_()

  def scale(self, factor: torch.Tensor):
    r''' Rescales the point cloud.

    The points are multiplied component-wise by :attr:`factor`. Normals, if
    present, are multiplied by the inverse factor and re-normalized. This
    inverse-transpose of the diagonal scaling is only needed when the scaling
    is non-uniform or contains a negative entry (a reflection). A uniform
    positive scaling leaves normal directions unchanged.

    Args:
      factor (torch.Tensor): The scale factor with shape :obj:`(3,)`.
    '''

    non_zero = (factor != 0).all()
    all_ones = (factor == 1.0).all()
    non_uniform = (factor != factor[0]).any()
    has_negative = (factor < 0).any()
    assert non_zero, 'The scale factor must not contain 0.'
    if all_ones:
      return

    factor = factor.to(self.device)
    self.points = self.points * factor
    # A uniform negative factor (e.g. -1 for a point reflection) must still
    # negate the normals, so the normals are also updated in that case.
    if self.normals is not None and (non_uniform or has_negative):
      ifactor = 1.0 / factor
      self.normals = self.normals * ifactor
      norm2 = torch.sqrt(torch.sum(self.normals ** 2, dim=1, keepdim=True))
      self.normals = self.normals / torch.clamp(norm2, min=1.0e-12)

  def rotate(self, angle: torch.Tensor):
    r''' Rotates the point cloud.

    Args:
      angle (torch.Tensor): The rotation angles in radian around the x, y and
          z axes, with shape :obj:`(3,)`.
    '''

    cos, sin = angle.cos().tolist(), angle.sin().tolist()
    # Points are row vectors multiplied on the right (`points @ rot`), so
    # rotx, roty, rotz below are the transposes of the rotation matrices.
    rotx = torch.Tensor([[1, 0, 0], [0, cos[0], sin[0]], [0, -sin[0], cos[0]]])
    roty = torch.Tensor([[cos[1], 0, -sin[1]], [0, 1, 0], [sin[1], 0, cos[1]]])
    rotz = torch.Tensor([[cos[2], sin[2], 0], [-sin[2], cos[2], 0], [0, 0, 1]])
    rot = rotx @ roty @ rotz

    # `@` does not promote dtypes, so the matrix must match the dtype of the
    # operand as well as its device (e.g. float64 point clouds).
    self.points = self.points @ rot.to(device=self.device,
                                       dtype=self.points.dtype)
    if self.normals is not None:
      self.normals = self.normals @ rot.to(device=self.device,
                                           dtype=self.normals.dtype)

  def translate(self, dis: torch.Tensor):
    r''' Translates the point cloud.

    Args:
      dis (torch.Tensor): The displacement with shape :obj:`(3,)`.
    '''

    dis = dis.to(self.device)
    self.points = self.points + dis

  def flip(self, axis: str):
    r''' Flips the point cloud along the given :attr:`axis` (in place).

    Args:
      axis (str): The flipping axis, chosen from :obj:`x`, :obj:`y`, and
          :obj:`z`. Several axes can be combined, e.g. :obj:`xy`.
    '''

    axis_map = {'x': 0, 'y': 1, 'z': 2}
    for x in axis:
      idx = axis_map[x]
      self.points[:, idx] *= -1.0
      if self.normals is not None:
        self.normals[:, idx] *= -1.0

  def clip(self, min: float = -1.0, max: float = 1.0, esp: float = 0.01):
    r''' Clips the point cloud to :obj:`(min+esp, max-esp)` and returns the mask.

    Points outside the box are removed from this object in place, together
    with their normals, features, labels and batch ids.

    Args:
      min (float): The minimum value to clip.
      max (float): The maximum value to clip.
      esp (float): The margin.

    Returns:
      torch.Tensor: The boolean mask of the points that were kept.
    '''

    mask = self.inbox_mask(min + esp, max - esp)
    tmp = self.__getitem__(mask)
    self.__dict__.update(tmp.__dict__)
    return mask

  def __getitem__(self, mask: torch.Tensor):
    r''' Slices the point cloud according to a given :attr:`mask` and returns
    a new :class:`Points`. The :attr:`batch_size` is copied, while
    :attr:`batch_npt` is not carried over and is :obj:`None` in the result.
    '''

    dummy_pts = torch.zeros(1, 3, device=self.device)
    out = Points(dummy_pts, batch_size=self.batch_size)

    out.points = self.points[mask]
    if self.normals is not None:
      out.normals = self.normals[mask]
    if self.features is not None:
      out.features = self.features[mask]
    if self.labels is not None:
      out.labels = self.labels[mask]
    if self.batch_id is not None:
      out.batch_id = self.batch_id[mask]
    return out

  def inbox_mask(self, bbmin: Union[float, torch.Tensor] = -1.0,
                 bbmax: Union[float, torch.Tensor] = 1.0):
    r''' Returns a mask indicating whether the points are within the specified
    bounding box or not. The comparison is strict: points lying exactly on
    the box boundary are outside.
    '''

    mask_min = torch.all(self.points > bbmin, dim=1)
    mask_max = torch.all(self.points < bbmax, dim=1)
    mask = torch.logical_and(mask_min, mask_max)
    return mask

  def bbox(self):
    r''' Returns the bounding box as a tuple :obj:`(bbmin, bbmax)` of
    per-axis minimum and maximum coordinates.
    '''

    # torch.min and torch.max return (value, indices)
    bbmin = self.points.min(dim=0)
    bbmax = self.points.max(dim=0)
    return bbmin[0], bbmax[0]

  def normalize(self, bbmin: torch.Tensor, bbmax: torch.Tensor,
                scale: float = 1.0):
    r''' Normalizes the point cloud to :obj:`[-scale, scale]`.

    The points are centered at the box center and scaled uniformly so that
    the longest side of the box spans about :obj:`2 * scale`.

    Args:
      bbmin (torch.Tensor): The minimum coordinates of the bounding box.
      bbmax (torch.Tensor): The maximum coordinates of the bounding box.
      scale (float): The scale factor
    '''

    center = (bbmin + bbmax) * 0.5
    box_size = (bbmax - bbmin).max() + 1.0e-6  # avoid dividing by zero
    self.points = (self.points - center) * (2.0 * scale / box_size)

  def to(self, device: Union[torch.device, str], non_blocking: bool = False):
    r''' Moves the Points to a specified device.

    Args:
      device (torch.device or str): The destination device.
      non_blocking (bool): If True and the source is in pinned memory, the copy
          will be asynchronous with respect to the host. Otherwise, the argument
          has no effect. Default: False.
    '''

    if isinstance(device, str):
      device = torch.device(device)

    # If already on the same device, return self directly.
    if self.device == device:
      return self

    # Construct a new Points on the specified device, keeping the batch info.
    points = Points(torch.zeros(1, 3, device=device),
                    batch_size=self.batch_size)
    points.batch_npt = self.batch_npt
    points.points = self.points.to(device, non_blocking=non_blocking)
    if self.normals is not None:
      points.normals = self.normals.to(device, non_blocking=non_blocking)
    if self.features is not None:
      points.features = self.features.to(device, non_blocking=non_blocking)
    if self.labels is not None:
      points.labels = self.labels.to(device, non_blocking=non_blocking)
    if self.batch_id is not None:
      points.batch_id = self.batch_id.to(device, non_blocking=non_blocking)
    return points

  def cuda(self, non_blocking: bool = False):
    r''' Moves the Points to the GPU. '''

    return self.to('cuda', non_blocking)

  def cpu(self):
    r''' Moves the Points to the CPU. '''

    return self.to('cpu')

  def save(self, filename: str, info: str = 'PNFL'):
    r''' Saves the Points into npz or xyz files.

    Args:
      filename (str): The output filename. A name ending with :obj:`npz` is
          written with :func:`numpy.savez`, and a name ending with :obj:`xyz`
          is written with :func:`numpy.savetxt` (one row per point).
      info (str): The information to save: 'P' -> 'points', 'N' -> 'normals',
          'F' -> 'features', 'L' -> 'labels', 'B' -> 'batch_id'. Attributes
          that are None are skipped.

    Raises:
      ValueError: If :attr:`filename` ends with neither npz nor xyz.
    '''

    mapping = {
        'P': ('points', self.points), 'N': ('normals', self.normals),
        'F': ('features', self.features), 'L': ('labels', self.labels),
        'B': ('batch_id', self.batch_id), }

    names, outs = [], []
    for key in info.upper():
      name, out = mapping[key]
      if out is not None:
        names.append(name)
        if out.dim() == 1:
          out = out.unsqueeze(1)
        # detach() so that tensors that require grad can be exported as well
        outs.append(out.detach().cpu().numpy())

    if filename.endswith('npz'):
      out_dict = dict(zip(names, outs))
      np.savez(filename, **out_dict)
    elif filename.endswith('xyz'):
      out_array = np.concatenate(outs, axis=1)
      np.savetxt(filename, out_array, fmt='%.6f')
    else:
      raise ValueError(
          'Unsupported file type: %s (expected a name ending with npz or xyz)'
          % filename)
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
def merge_points(points: List['Points'], update_batch_info: bool = True):
  r''' Merges a list of points into one batch.

  Args:
    points (List[Points]): A list of points to merge. The batch size of each
        points in the list is assumed to be 1, and the :obj:`batch_size`,
        :obj:`batch_id`, and :obj:`batch_npt` in the points are ignored.
    update_batch_info (bool): If True, sets :obj:`batch_size` to
        :obj:`len(points)`, :obj:`batch_npt` to the per-element point counts
        (an int64 tensor), and :obj:`batch_id` to the list index of every
        point. Otherwise these keep their defaults (:obj:`1`, :obj:`None`,
        :obj:`None`).

  Returns:
    Points: The concatenated point cloud. Normals, features and labels are
    concatenated only if the first element of :attr:`points` has them.
  '''

  out = Points(torch.zeros(1, 3))
  out.points = torch.cat([p.points for p in points], dim=0)
  if points[0].normals is not None:
    out.normals = torch.cat([p.normals for p in points], dim=0)
  if points[0].features is not None:
    out.features = torch.cat([p.features for p in points], dim=0)
  if points[0].labels is not None:
    out.labels = torch.cat([p.labels for p in points], dim=0)
  out.device = points[0].device

  if update_batch_info:
    out.batch_size = len(points)
    # Built directly as int64: going through a float32 tensor (torch.Tensor)
    # would round point counts larger than 2**24.
    out.batch_npt = torch.tensor([p.npt for p in points], dtype=torch.long)
    out.batch_id = torch.cat([p.points.new_full((p.npt, 1), i)
                              for i, p in enumerate(points)], dim=0)
  return out
|