ocnn 2.2.6__py3-none-any.whl → 2.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ocnn/__init__.py +24 -24
- ocnn/dataset.py +160 -160
- ocnn/models/__init__.py +29 -29
- ocnn/models/autoencoder.py +155 -155
- ocnn/models/hrnet.py +192 -192
- ocnn/models/image2shape.py +128 -128
- ocnn/models/lenet.py +46 -46
- ocnn/models/ounet.py +94 -94
- ocnn/models/resnet.py +53 -53
- ocnn/models/segnet.py +72 -72
- ocnn/models/unet.py +105 -105
- ocnn/modules/__init__.py +26 -26
- ocnn/modules/modules.py +303 -303
- ocnn/modules/resblocks.py +158 -158
- ocnn/nn/__init__.py +44 -44
- ocnn/nn/octree2col.py +53 -53
- ocnn/nn/octree2vox.py +50 -50
- ocnn/nn/octree_align.py +46 -46
- ocnn/nn/octree_conv.py +429 -429
- ocnn/nn/octree_drop.py +55 -55
- ocnn/nn/octree_dwconv.py +222 -222
- ocnn/nn/octree_gconv.py +79 -79
- ocnn/nn/octree_interp.py +196 -196
- ocnn/nn/octree_norm.py +126 -126
- ocnn/nn/octree_pad.py +39 -39
- ocnn/nn/octree_pool.py +200 -200
- ocnn/octree/__init__.py +22 -22
- ocnn/octree/octree.py +661 -661
- ocnn/octree/points.py +323 -322
- ocnn/octree/shuffled_key.py +115 -115
- ocnn/utils.py +205 -205
- {ocnn-2.2.6.dist-info → ocnn-2.2.7.dist-info}/METADATA +112 -91
- ocnn-2.2.7.dist-info/RECORD +36 -0
- {ocnn-2.2.6.dist-info → ocnn-2.2.7.dist-info}/WHEEL +1 -1
- {ocnn-2.2.6.dist-info → ocnn-2.2.7.dist-info/licenses}/LICENSE +21 -21
- ocnn-2.2.6.dist-info/RECORD +0 -36
- {ocnn-2.2.6.dist-info → ocnn-2.2.7.dist-info}/top_level.txt +0 -0
ocnn/octree/octree.py
CHANGED
|
@@ -1,661 +1,661 @@
|
|
|
1
|
-
# --------------------------------------------------------
|
|
2
|
-
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
-
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
-
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
-
# Written by Peng-Shuai Wang
|
|
6
|
-
# --------------------------------------------------------
|
|
7
|
-
|
|
8
|
-
import torch
|
|
9
|
-
import torch.nn.functional as F
|
|
10
|
-
from typing import Union, List
|
|
11
|
-
|
|
12
|
-
import ocnn
|
|
13
|
-
from ocnn.octree.points import Points
|
|
14
|
-
from ocnn.octree.shuffled_key import xyz2key, key2xyz
|
|
15
|
-
from ocnn.utils import range_grid, scatter_add, cumsum, trunc_div
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
class Octree:
  r''' Builds an octree from an input point cloud.

  Args:
    depth (int): The octree depth.
    full_depth (int): The octree layers with a depth no larger than
        :attr:`full_depth` are forced to be full.
    batch_size (int): The octree batch size.
    device (torch.device or str): Choose from :obj:`cpu` and :obj:`cuda`.
        (default: :obj:`cpu`)

  .. note::
    The octree data structure requires that if an octree node has children nodes,
    the number of children nodes is exactly 8, in which some of the nodes are
    empty and some nodes are non-empty. The properties of an octree, including
    :obj:`keys`, :obj:`children` and :obj:`neighs`, contain both non-empty and
    empty nodes, and other properties, including :obj:`features`, :obj:`normals`
    and :obj:`points`, contain only non-empty nodes.

  .. note::
    The point cloud must be strictly in range :obj:`[-1, 1]`. A good practice
    is to normalize it into :obj:`[-0.99, 0.99]` or :obj:`[-0.9, 0.9]` to
    retain some margin.
  '''

  def __init__(self, depth: int, full_depth: int = 2, batch_size: int = 1,
               device: Union[torch.device, str] = 'cpu', **kwargs):
    super().__init__()
    self.depth = depth
    self.full_depth = full_depth
    self.batch_size = batch_size
    self.device = device

    self.reset()

  def reset(self):
    r''' Resets the Octree status and constructs several lookup tables.
    '''

    # per-layer octree properties, indexed by depth in [0, self.depth]
    num = self.depth + 1
    self.keys = [None] * num
    self.children = [None] * num
    self.neighs = [None] * num
    self.features = [None] * num
    self.normals = [None] * num
    self.points = [None] * num

    # octree node numbers in each octree layer.
    # TODO: decide whether to settle them to 'gpu' or not?
    self.nnum = torch.zeros(num, dtype=torch.long)
    self.nnum_nempty = torch.zeros(num, dtype=torch.long)

    # the following properties are valid after `merge_octrees`.
    # TODO: make them valid after `octree_grow`, `octree_split` and `build_octree`
    batch_size = self.batch_size
    self.batch_nnum = torch.zeros(num, batch_size, dtype=torch.long)
    self.batch_nnum_nempty = torch.zeros(num, batch_size, dtype=torch.long)

    self._build_lookup_tables()

  def _build_lookup_tables(self):
    r''' Builds the neighborhood-search and kernel-shape lookup tables on
    :attr:`self.device`.
    '''

    # construct the look up tables for neighborhood searching
    device = self.device
    center_grid = range_grid(2, 3, device)    # (8, 3)
    displacement = range_grid(-1, 1, device)  # (27, 3)
    neigh_grid = center_grid.unsqueeze(1) + displacement  # (8, 27, 3)
    parent_grid = trunc_div(neigh_grid, 2)
    child_grid = neigh_grid % 2
    self.lut_parent = torch.sum(
        parent_grid * torch.tensor([9, 3, 1], device=device), dim=2)
    self.lut_child = torch.sum(
        child_grid * torch.tensor([4, 2, 1], device=device), dim=2)

    # lookup tables for different kernel sizes: each entry selects columns of
    # the 27-column (3x3x3) neighbor table returned by `get_neigh`
    self.lut_kernel = {
        '222': torch.tensor([13, 14, 16, 17, 22, 23, 25, 26], device=device),
        '311': torch.tensor([4, 13, 22], device=device),
        '131': torch.tensor([10, 13, 16], device=device),
        '113': torch.tensor([12, 13, 14], device=device),
        '331': torch.tensor([1, 4, 7, 10, 13, 16, 19, 22, 25], device=device),
        '313': torch.tensor([3, 4, 5, 12, 13, 14, 21, 22, 23], device=device),
        '133': torch.tensor([9, 10, 11, 12, 13, 14, 15, 16, 17], device=device),
    }

  def key(self, depth: int, nempty: bool = False):
    r''' Returns the shuffled key of each octree node.

    Args:
      depth (int): The depth of the octree.
      nempty (bool): If True, returns the results of non-empty octree nodes.
    '''

    key = self.keys[depth]
    if nempty:
      mask = self.nempty_mask(depth)
      key = key[mask]
    return key

  def xyzb(self, depth: int, nempty: bool = False):
    r''' Returns the xyz coordinates and the batch indices of each octree node.

    Args:
      depth (int): The depth of the octree.
      nempty (bool): If True, returns the results of non-empty octree nodes.
    '''

    key = self.key(depth, nempty)
    return key2xyz(key, depth)

  def batch_id(self, depth: int, nempty: bool = False):
    r''' Returns the batch indices of each octree node.

    Args:
      depth (int): The depth of the octree.
      nempty (bool): If True, returns the results of non-empty octree nodes.
    '''

    # the batch index is stored in the bits above bit 48 of each key
    batch_id = self.keys[depth] >> 48
    if nempty:
      mask = self.nempty_mask(depth)
      batch_id = batch_id[mask]
    return batch_id

  def nempty_mask(self, depth: int):
    r''' Returns a binary mask which indicates whether the corresponding octree
    node is empty or not.

    Args:
      depth (int): The depth of the octree.
    '''

    return self.children[depth] >= 0

  def build_octree(self, point_cloud: 'Points'):
    r''' Builds an octree from a point cloud.

    Args:
      point_cloud (Points): The input point cloud.

    Returns:
      torch.Tensor: For each input point, the index of the non-empty octree
      node at :attr:`self.depth` that contains it.

    .. note::
      The point cloud must be strictly in range :obj:`[-1, 1]`. A good practice
      is to normalize it into :obj:`[-0.99, 0.99]` or :obj:`[-0.9, 0.9]` to
      retain some margin.
    '''

    self.device = point_cloud.device
    assert point_cloud.batch_size == self.batch_size, 'Inconsistent batch_size'
    # The lookup tables were built on the device given at construction time;
    # rebuild them on the point cloud's device so that `construct_neigh` and
    # `get_neigh` never mix tensors from different devices.
    self._build_lookup_tables()

    # normalize points from [-1, 1] to [0, 2^depth]. #[L:Scale]
    scale = 2 ** (self.depth - 1)
    points = (point_cloud.points + 1.0) * scale

    # get the shuffled key and sort
    x, y, z = points[:, 0], points[:, 1], points[:, 2]
    b = None if self.batch_size == 1 else point_cloud.batch_id.view(-1)
    key = xyz2key(x, y, z, b, self.depth)
    node_key, idx, counts = torch.unique(
        key, sorted=True, return_inverse=True, return_counts=True)

    # layer 0 to full_layer: the octree is full in these layers
    for d in range(self.full_depth + 1):
      self.octree_grow_full(d, update_neigh=False)

    # layers from `depth` down to `full_depth + 1`
    for d in range(self.depth, self.full_depth, -1):
      # compute parent key, i.e. keys of layer (d - 1)
      pkey = node_key >> 3
      pkey, pidx, _ = torch.unique_consecutive(
          pkey, return_inverse=True, return_counts=True)

      # augmented key
      key = (pkey.unsqueeze(-1) << 3) + torch.arange(8, device=self.device)
      self.keys[d] = key.view(-1)
      self.nnum[d] = key.numel()
      self.nnum_nempty[d] = node_key.numel()

      # children
      addr = (pidx << 3) | (node_key % 8)
      children = -torch.ones(
          self.nnum[d].item(), dtype=torch.int32, device=self.device)
      children[addr] = torch.arange(
          self.nnum_nempty[d], dtype=torch.int32, device=self.device)
      self.children[d] = children

      # cache pkey for the next iteration.
      # Use `pkey >> 45` instead of `pkey >> 48` because `pkey` was already
      # shifted right by 3 bits above, which also moved the batch bits down.
      node_key = pkey if self.batch_size == 1 else \
          ((pkey >> 45) << 48) | (pkey & ((1 << 45) - 1))

    # set the children for the layer full_layer,
    # now the node_keys are the key for full_layer
    d = self.full_depth
    children = -torch.ones_like(self.children[d])
    nempty_idx = node_key if self.batch_size == 1 else \
        ((node_key >> 48) << (3 * d)) | (node_key & ((1 << 48) - 1))
    children[nempty_idx] = torch.arange(
        node_key.numel(), dtype=torch.int32, device=self.device)
    self.children[d] = children
    self.nnum_nempty[d] = node_key.numel()

    # average the signal for the last octree layer
    d = self.depth
    points = scatter_add(points, idx, dim=0)  # points is rescaled in [L:Scale]
    self.points[d] = points / counts.unsqueeze(1)
    if point_cloud.normals is not None:
      normals = scatter_add(point_cloud.normals, idx, dim=0)
      self.normals[d] = F.normalize(normals)
    if point_cloud.features is not None:
      features = scatter_add(point_cloud.features, idx, dim=0)
      self.features[d] = features / counts.unsqueeze(1)

    return idx

  def octree_grow_full(self, depth: int, update_neigh: bool = True):
    r''' Builds the full octree, which is essentially a dense volumetric grid.

    Args:
      depth (int): The depth of the octree.
      update_neigh (bool): If True, construct the neighborhood indices.
    '''

    # check
    assert depth <= self.full_depth, 'error'

    # node number
    num = 1 << (3 * depth)
    self.nnum[depth] = num * self.batch_size
    self.nnum_nempty[depth] = num * self.batch_size

    # update key
    key = torch.arange(num, dtype=torch.long, device=self.device)
    bs = torch.arange(self.batch_size, dtype=torch.long, device=self.device)
    key = key.unsqueeze(0) | (bs.unsqueeze(1) << 48)
    self.keys[depth] = key.view(-1)

    # update children
    self.children[depth] = torch.arange(
        num * self.batch_size, dtype=torch.int32, device=self.device)

    # update neigh if needed
    if update_neigh:
      self.construct_neigh(depth)

  def octree_split(self, split: torch.Tensor, depth: int):
    r''' Sets whether the octree nodes in :attr:`depth` are split or not.

    Args:
      split (torch.Tensor): The input tensor with its element indicating status
          of each octree node: 0 - empty, 1 - non-empty or split.
      depth (int): The depth of current octree.
    '''

    # split -> children: the exclusive prefix sum gives each split node its
    # index among non-empty nodes; the trailing element is the total count
    empty = split == 0
    csum = cumsum(split, dim=0, exclusive=True)
    children, nnum_nempty = torch.split(csum, [split.shape[0], 1])
    children[empty] = -1

    # boundary case, make sure that at least one octree node is split
    if nnum_nempty == 0:
      nnum_nempty = 1
      children[0] = 0

    # update octree
    self.children[depth] = children.int()
    self.nnum_nempty[depth] = nnum_nempty

  def octree_grow(self, depth: int, update_neigh: bool = True):
    r''' Grows the octree and updates the relevant properties. In most cases,
    call :func:`Octree.octree_split` to update the splitting status of the
    octree before this function.

    Args:
      depth (int): The depth of the octree.
      update_neigh (bool): If True, construct the neighborhood indices.
    '''

    # increase the octree depth if required
    if depth > self.depth:
      assert depth == self.depth + 1
      self.depth = depth
      self.keys.append(None)
      self.children.append(None)
      self.neighs.append(None)
      self.features.append(None)
      self.normals.append(None)
      self.points.append(None)
      zero = torch.zeros(1, dtype=torch.long)
      self.nnum = torch.cat([self.nnum, zero])
      self.nnum_nempty = torch.cat([self.nnum_nempty, zero])
      # `batch_nnum*` have shape (num_layers, batch_size), so the appended row
      # must have `batch_size` columns; a (1, 1) row breaks for batch_size > 1
      zero_row = torch.zeros(1, self.batch_size, dtype=torch.long)
      self.batch_nnum = torch.cat([self.batch_nnum, zero_row], dim=0)
      self.batch_nnum_nempty = torch.cat(
          [self.batch_nnum_nempty, zero_row], dim=0)

    # node number
    nnum = self.nnum_nempty[depth-1] * 8
    self.nnum[depth] = nnum
    self.nnum_nempty[depth] = nnum  # initialize self.nnum_nempty

    # update keys: the 8 children of each non-empty parent, keeping the batch
    # index in the bits above bit 48
    key = self.key(depth-1, nempty=True)
    batch_id = (key >> 48) << 48
    key = (key & ((1 << 48) - 1)) << 3
    key = key | batch_id
    key = key.unsqueeze(1) + torch.arange(8, device=key.device)
    self.keys[depth] = key.view(-1)

    # update children
    self.children[depth] = torch.arange(
        nnum, dtype=torch.int32, device=self.device)

    # update neighs
    if update_neigh:
      self.construct_neigh(depth)

  def construct_neigh(self, depth: int):
    r''' Constructs the :obj:`3x3x3` neighbors for each octree node.

    Args:
      depth (int): The octree depth with a value larger than 0 (:obj:`>0`).
    '''

    if depth <= self.full_depth:
      device = self.device
      nnum = 1 << (3 * depth)
      key = torch.arange(nnum, dtype=torch.long, device=device)
      x, y, z, _ = key2xyz(key, depth)
      xyz = torch.stack([x, y, z], dim=-1)  # (N, 3)
      grid = range_grid(-1, 1, device)  # (27, 3)
      xyz = xyz.unsqueeze(1) + grid  # (N, 27, 3)
      xyz = xyz.view(-1, 3)  # (N*27, 3)
      neigh = xyz2key(xyz[:, 0], xyz[:, 1], xyz[:, 2], depth=depth)

      bs = torch.arange(self.batch_size, dtype=torch.long, device=device)
      neigh = neigh + bs.unsqueeze(1) * nnum  # (N*27,) + (B, 1) -> (B, N*27)

      bound = 1 << depth
      invalid = torch.logical_or((xyz < 0).any(1), (xyz >= bound).any(1))
      neigh[:, invalid] = -1
      self.neighs[depth] = neigh.view(-1, 27)  # (B*N, 27)

    else:
      child_p = self.children[depth-1]
      nempty = child_p >= 0
      neigh_p = self.neighs[depth-1][nempty]  # (N, 27)
      neigh_p = neigh_p[:, self.lut_parent]  # (N, 8, 27)
      child_p = child_p[neigh_p]  # (N, 8, 27)
      invalid = torch.logical_or(child_p < 0, neigh_p < 0)  # (N, 8, 27)
      neigh = child_p * 8 + self.lut_child
      neigh[invalid] = -1
      self.neighs[depth] = neigh.view(-1, 27)

  def construct_all_neigh(self):
    r''' A convenient handler for constructing all neighbors.
    '''

    for depth in range(1, self.depth+1):
      self.construct_neigh(depth)

  def search_xyzb(self, query: torch.Tensor, depth: int, nempty: bool = False):
    r''' Searches the octree nodes given the query points.

    Args:
      query (torch.Tensor): The coordinates of query points with shape
          :obj:`(N, 4)`. The first 3 channels of the coordinates are :obj:`x`,
          :obj:`y`, and :obj:`z`, and the last channel is the batch index. Note
          that the coordinates must be in range :obj:`[0, 2^depth)`.
      depth (int): The depth of the octree layer.
      nempty (bool): If true, only searches the non-empty octree nodes.
    '''

    key = xyz2key(query[:, 0], query[:, 1], query[:, 2], query[:, 3], depth)
    idx = self.search_key(key, depth, nempty)
    return idx

  def search_key(self, query: torch.Tensor, depth: int, nempty: bool = False):
    r''' Searches the octree nodes given the query points.

    Args:
      query (torch.Tensor): The keys of query points with shape :obj:`(N,)`,
          which are computed from the coordinates of query points.
      depth (int): The depth of the octree layer.
      nempty (bool): If true, only searches the non-empty octree nodes.

    Returns:
      torch.Tensor: The index of the matching octree node for each query, or
      :obj:`-1` when the key is not present.
    '''

    key = self.key(depth, nempty)
    # `torch.bucketize` is similar to `torch.searchsorted`.
    # I choose `torch.bucketize` here because it has fewer dimension checks,
    # resulting in slightly better performance according to the docs of
    # pytorch-1.9.1, since `key` is always 1-D sorted sequence.
    # https://pytorch.org/docs/1.9.1/generated/torch.searchsorted.html
    idx = torch.bucketize(query, key)

    valid = idx < key.shape[0]  # valid if in-bound
    found = key[idx[valid]] == query[valid]
    valid[valid.clone()] = found  # valid if found
    idx[valid.logical_not()] = -1  # set to -1 if invalid
    return idx

  def get_neigh(self, depth: int, kernel: str = '333', stride: int = 1,
                nempty: bool = False):
    r''' Returns the neighborhoods given the depth and a kernel shape.

    Args:
      depth (int): The octree depth with a value larger than 0 (:obj:`>0`).
      kernel (str): The kernel shape from :obj:`333`, :obj:`311`, :obj:`131`,
          :obj:`113`, :obj:`222`, :obj:`331`, :obj:`133`, and :obj:`313`.
      stride (int): The stride of neighborhoods (:obj:`1` or :obj:`2`). If the
          stride is :obj:`2`, always returns the neighborhood of the first
          siblings.
      nempty (bool): If True, only returns the neighborhoods of the non-empty
          octree nodes.
    '''

    if stride == 1:
      neigh = self.neighs[depth]
    elif stride == 2:
      # clone neigh to avoid self.neigh[depth] being modified
      neigh = self.neighs[depth][::8].clone()
    else:
      raise ValueError('Unsupported stride {}'.format(stride))

    if nempty:
      child = self.children[depth]
      if stride == 1:
        nempty_node = child >= 0
        neigh = neigh[nempty_node]
      valid = neigh >= 0
      neigh[valid] = child[neigh[valid]].long()  # remap the index

    if kernel == '333':
      return neigh
    elif kernel in self.lut_kernel:
      lut = self.lut_kernel[kernel]
      return neigh[:, lut]
    else:
      raise ValueError('Unsupported kernel {}'.format(kernel))

  def get_input_feature(self, feature: str, nempty: bool = False):
    r''' Returns the initial input feature stored in octree.

    Args:
      feature (str): A string used to indicate which features to extract from
          the input octree. If the character :obj:`N` is in :attr:`feature`, the
          normal signal is extracted (3 channels). Similarly, if :obj:`D` is in
          :attr:`feature`, the local displacement is extracted (1 channel). If
          :obj:`L` is in :attr:`feature`, the local coordinates of the averaged
          points in each octree node is extracted (3 channels). If :attr:`P` is
          in :attr:`feature`, the global coordinates are extracted (3 channels).
          If :attr:`F` is in :attr:`feature`, other features (like colors) are
          extracted (k channels).
      nempty (bool): If false, gets the features of all octree nodes.
    '''

    features = list()
    depth = self.depth
    feature = feature.upper()
    if 'N' in feature:
      features.append(self.normals[depth])

    if 'L' in feature or 'D' in feature:
      local_points = self.points[depth].frac() - 0.5

    if 'D' in feature:
      dis = torch.sum(local_points * self.normals[depth], dim=1, keepdim=True)
      features.append(dis)

    if 'L' in feature:
      features.append(local_points)

    if 'P' in feature:
      scale = 2 ** (1 - depth)  # normalize [0, 2^depth] -> [-1, 1]
      global_points = self.points[depth] * scale - 1.0
      features.append(global_points)

    if 'F' in feature:
      features.append(self.features[depth])

    out = torch.cat(features, dim=1)
    if not nempty:
      out = ocnn.nn.octree_pad(out, self, depth)
    return out

  def to_points(self, rescale: bool = True):
    r''' Converts averaged points in the octree to a point cloud.

    Args:
      rescale (bool): rescale the xyz coordinates to [-1, 1] if True.
    '''

    depth = self.depth
    batch_size = self.batch_size

    # by default, use the average points generated when building the octree
    # from the input point cloud
    xyz = self.points[depth]
    batch_id = self.batch_id(depth, nempty=True)

    # xyz is None when the octree is predicted by a neural network
    if xyz is None:
      x, y, z, batch_id = self.xyzb(depth, nempty=True)
      xyz = torch.stack([x, y, z], dim=1) + 0.5

    # normalize xyz to [-1, 1] since the average points are in range [0, 2^d]
    if rescale:
      scale = 2 ** (1 - depth)
      xyz = xyz * scale - 1.0

    # construct Points
    out = Points(xyz, self.normals[depth], self.features[depth],
                 batch_id=batch_id, batch_size=batch_size)
    return out

  def to(self, device: Union[torch.device, str], non_blocking: bool = False):
    r''' Moves the octree to a specified device.

    Args:
      device (torch.device or str): The destination device.
      non_blocking (bool): If True and the source is in pinned memory, the copy
          will be asynchronous with respect to the host. Otherwise, the argument
          has no effect. Default: False.

    Returns:
      Octree: :obj:`self` if already on :attr:`device`, otherwise a new octree.
    '''

    if isinstance(device, str):
      device = torch.device(device)

    # If on the same device, directly return self. `self.device` may be stored
    # as a str (e.g. the default 'cpu'), so normalize it before comparing.
    if torch.device(self.device) == device:
      return self

    def list_to_device(prop):
      return [p.to(device, non_blocking=non_blocking)
              if isinstance(p, torch.Tensor) else None for p in prop]

    # Construct a new Octree on the specified device
    octree = Octree(self.depth, self.full_depth, self.batch_size, device)
    octree.keys = list_to_device(self.keys)
    octree.children = list_to_device(self.children)
    octree.neighs = list_to_device(self.neighs)
    octree.features = list_to_device(self.features)
    octree.normals = list_to_device(self.normals)
    octree.points = list_to_device(self.points)
    octree.nnum = self.nnum.clone()  # TODO: whether to move nnum to the self.device?
    octree.nnum_nempty = self.nnum_nempty.clone()
    octree.batch_nnum = self.batch_nnum.clone()
    octree.batch_nnum_nempty = self.batch_nnum_nempty.clone()
    return octree

  def cuda(self, non_blocking: bool = False):
    r''' Moves the octree to the GPU. '''

    return self.to('cuda', non_blocking)

  def cpu(self):
    r''' Moves the octree to the CPU. '''

    return self.to('cpu')
|
|
574
|
-
|
|
575
|
-
|
|
576
|
-
def merge_octrees(octrees: List['Octree']):
  r''' Concatenates several octrees into one batched octree.

  Args:
    octrees (List[Octree]): The octrees to merge. They must share the same
        :obj:`depth`, :obj:`full_depth` and :obj:`device`.

  Returns:
    Octree: The merged octree, whose batch size is :obj:`len(octrees)`.
  '''

  first = octrees[0]
  merged = Octree(depth=first.depth, full_depth=first.full_depth,
                  batch_size=len(octrees), device=first.device)

  # every octree must be structurally compatible with the first one
  for other in octrees[1:]:
    compatible = (other.depth == merged.depth
                  and other.full_depth == merged.full_depth
                  and other.device == merged.device)
    assert compatible, 'The check of merge_octrees failed'

  # per-layer node counts: columns are batch items, rows are depths
  batch_nnum = torch.stack([o.nnum for o in octrees], dim=1)
  batch_nnum_nempty = torch.stack([o.nnum_nempty for o in octrees], dim=1)
  merged.nnum = batch_nnum.sum(dim=1)
  merged.nnum_nempty = batch_nnum_nempty.sum(dim=1)
  merged.batch_nnum = batch_nnum
  merged.batch_nnum_nempty = batch_nnum_nempty
  # offset of each batch item's non-empty nodes within the merged layer
  offsets = cumsum(batch_nnum_nempty, dim=1, exclusive=True)

  low_bits = (1 << 48) - 1
  for d in range(merged.depth + 1):
    # keys: overwrite the bits above bit 48 with the new batch index
    merged.keys[d] = torch.cat(
        [(o.keys[d] & low_bits) | (i << 48) for i, o in enumerate(octrees)],
        dim=0)

    # children: shift non-empty indices by the batch offset; clone first so
    # the input octrees are left untouched
    shifted = []
    for i, o in enumerate(octrees):
      child = o.children[d].clone()
      is_nempty = child >= 0
      child[is_nempty] = child[is_nempty] + offsets[d, i]
      shifted.append(child)
    merged.children[d] = torch.cat(shifted, dim=0)

    # per-node signals are only merged for the deepest layer
    if d == merged.depth:
      for name in ('features', 'normals', 'points'):
        if getattr(first, name)[d] is not None:
          getattr(merged, name)[d] = torch.cat(
              [getattr(o, name)[d] for o in octrees], dim=0)

  return merged
|
|
640
|
-
|
|
641
|
-
|
|
642
|
-
def init_octree(depth: int, full_depth: int = 2, batch_size: int = 1,
                device: Union[torch.device, str] = 'cpu'):
  r'''
  Creates an octree whose layers 0 through :attr:`full_depth` are full.

  Args:
    depth (int): The depth of the octree.
    full_depth (int): The octree layers with a depth no larger than
        :attr:`full_depth` are forced to be full.
    batch_size (int, optional): The batch size.
    device (torch.device or str): The device to use for computation.

  Returns:
    Octree: The initialized Octree object.
  '''

  result = Octree(depth, full_depth, batch_size, device)
  layer = 0
  while layer <= full_depth:
    result.octree_grow_full(depth=layer)
    layer += 1
  return result
|
|
1
|
+
# --------------------------------------------------------
|
|
2
|
+
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
+
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
+
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
+
# Written by Peng-Shuai Wang
|
|
6
|
+
# --------------------------------------------------------
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
import torch.nn.functional as F
|
|
10
|
+
from typing import Union, List
|
|
11
|
+
|
|
12
|
+
import ocnn
|
|
13
|
+
from ocnn.octree.points import Points
|
|
14
|
+
from ocnn.octree.shuffled_key import xyz2key, key2xyz
|
|
15
|
+
from ocnn.utils import range_grid, scatter_add, cumsum, trunc_div
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class Octree:
|
|
19
|
+
r''' Builds an octree from an input point cloud.
|
|
20
|
+
|
|
21
|
+
Args:
|
|
22
|
+
depth (int): The octree depth.
|
|
23
|
+
full_depth (int): The octree layers with a depth small than
|
|
24
|
+
:attr:`full_depth` are forced to be full.
|
|
25
|
+
batch_size (int): The octree batch size.
|
|
26
|
+
device (torch.device or str): Choose from :obj:`cpu` and :obj:`gpu`.
|
|
27
|
+
(default: :obj:`cpu`)
|
|
28
|
+
|
|
29
|
+
.. note::
|
|
30
|
+
The octree data structure requires that if an octree node has children nodes,
|
|
31
|
+
the number of children nodes is exactly 8, in which some of the nodes are
|
|
32
|
+
empty and some nodes are non-empty. The properties of an octree, including
|
|
33
|
+
:obj:`keys`, :obj:`children` and :obj:`neighs`, contain both non-empty and
|
|
34
|
+
empty nodes, and other properties, including :obj:`features`, :obj:`normals`
|
|
35
|
+
and :obj:`points`, contain only non-empty nodes.
|
|
36
|
+
|
|
37
|
+
.. note::
|
|
38
|
+
The point cloud must be strictly in range :obj:`[-1, 1]`. A good practice
|
|
39
|
+
is to normalize it into :obj:`[-0.99, 0.99]` or :obj:`[0.9, 0.9]` to retain
|
|
40
|
+
some margin.
|
|
41
|
+
'''
|
|
42
|
+
|
|
43
|
+
def __init__(self, depth: int, full_depth: int = 2, batch_size: int = 1,
|
|
44
|
+
device: Union[torch.device, str] = 'cpu', **kwargs):
|
|
45
|
+
super().__init__()
|
|
46
|
+
self.depth = depth
|
|
47
|
+
self.full_depth = full_depth
|
|
48
|
+
self.batch_size = batch_size
|
|
49
|
+
self.device = device
|
|
50
|
+
|
|
51
|
+
self.reset()
|
|
52
|
+
|
|
53
|
+
def reset(self):
|
|
54
|
+
r''' Resets the Octree status and constructs several lookup tables.
|
|
55
|
+
'''
|
|
56
|
+
|
|
57
|
+
# octree features in each octree layers
|
|
58
|
+
num = self.depth + 1
|
|
59
|
+
self.keys = [None] * num
|
|
60
|
+
self.children = [None] * num
|
|
61
|
+
self.neighs = [None] * num
|
|
62
|
+
self.features = [None] * num
|
|
63
|
+
self.normals = [None] * num
|
|
64
|
+
self.points = [None] * num
|
|
65
|
+
|
|
66
|
+
# octree node numbers in each octree layers.
|
|
67
|
+
# TODO: decide whether to settle them to 'gpu' or not?
|
|
68
|
+
self.nnum = torch.zeros(num, dtype=torch.long)
|
|
69
|
+
self.nnum_nempty = torch.zeros(num, dtype=torch.long)
|
|
70
|
+
|
|
71
|
+
# the following properties are valid after `merge_octrees`.
|
|
72
|
+
# TODO: make them valid after `octree_grow`, `octree_split` and `build_octree`
|
|
73
|
+
batch_size = self.batch_size
|
|
74
|
+
self.batch_nnum = torch.zeros(num, batch_size, dtype=torch.long)
|
|
75
|
+
self.batch_nnum_nempty = torch.zeros(num, batch_size, dtype=torch.long)
|
|
76
|
+
|
|
77
|
+
# construct the look up tables for neighborhood searching
|
|
78
|
+
device = self.device
|
|
79
|
+
center_grid = range_grid(2, 3, device) # (8, 3)
|
|
80
|
+
displacement = range_grid(-1, 1, device) # (27, 3)
|
|
81
|
+
neigh_grid = center_grid.unsqueeze(1) + displacement # (8, 27, 3)
|
|
82
|
+
parent_grid = trunc_div(neigh_grid, 2)
|
|
83
|
+
child_grid = neigh_grid % 2
|
|
84
|
+
self.lut_parent = torch.sum(
|
|
85
|
+
parent_grid * torch.tensor([9, 3, 1], device=device), dim=2)
|
|
86
|
+
self.lut_child = torch.sum(
|
|
87
|
+
child_grid * torch.tensor([4, 2, 1], device=device), dim=2)
|
|
88
|
+
|
|
89
|
+
# lookup tables for different kernel sizes
|
|
90
|
+
self.lut_kernel = {
|
|
91
|
+
'222': torch.tensor([13, 14, 16, 17, 22, 23, 25, 26], device=device),
|
|
92
|
+
'311': torch.tensor([4, 13, 22], device=device),
|
|
93
|
+
'131': torch.tensor([10, 13, 16], device=device),
|
|
94
|
+
'113': torch.tensor([12, 13, 14], device=device),
|
|
95
|
+
'331': torch.tensor([1, 4, 7, 10, 13, 16, 19, 22, 25], device=device),
|
|
96
|
+
'313': torch.tensor([3, 4, 5, 12, 13, 14, 21, 22, 23], device=device),
|
|
97
|
+
'133': torch.tensor([9, 10, 11, 12, 13, 14, 15, 16, 17], device=device),
|
|
98
|
+
}
|
|
99
|
+
|
|
100
|
+
def key(self, depth: int, nempty: bool = False):
|
|
101
|
+
r''' Returns the shuffled key of each octree node.
|
|
102
|
+
|
|
103
|
+
Args:
|
|
104
|
+
depth (int): The depth of the octree.
|
|
105
|
+
nempty (bool): If True, returns the results of non-empty octree nodes.
|
|
106
|
+
'''
|
|
107
|
+
|
|
108
|
+
key = self.keys[depth]
|
|
109
|
+
if nempty:
|
|
110
|
+
mask = self.nempty_mask(depth)
|
|
111
|
+
key = key[mask]
|
|
112
|
+
return key
|
|
113
|
+
|
|
114
|
+
def xyzb(self, depth: int, nempty: bool = False):
|
|
115
|
+
r''' Returns the xyz coordinates and the batch indices of each octree node.
|
|
116
|
+
|
|
117
|
+
Args:
|
|
118
|
+
depth (int): The depth of the octree.
|
|
119
|
+
nempty (bool): If True, returns the results of non-empty octree nodes.
|
|
120
|
+
'''
|
|
121
|
+
|
|
122
|
+
key = self.key(depth, nempty)
|
|
123
|
+
return key2xyz(key, depth)
|
|
124
|
+
|
|
125
|
+
def batch_id(self, depth: int, nempty: bool = False):
|
|
126
|
+
r''' Returns the batch indices of each octree node.
|
|
127
|
+
|
|
128
|
+
Args:
|
|
129
|
+
depth (int): The depth of the octree.
|
|
130
|
+
nempty (bool): If True, returns the results of non-empty octree nodes.
|
|
131
|
+
'''
|
|
132
|
+
|
|
133
|
+
batch_id = self.keys[depth] >> 48
|
|
134
|
+
if nempty:
|
|
135
|
+
mask = self.nempty_mask(depth)
|
|
136
|
+
batch_id = batch_id[mask]
|
|
137
|
+
return batch_id
|
|
138
|
+
|
|
139
|
+
def nempty_mask(self, depth: int):
|
|
140
|
+
r''' Returns a binary mask which indicates whether the cooreponding octree
|
|
141
|
+
node is empty or not.
|
|
142
|
+
|
|
143
|
+
Args:
|
|
144
|
+
depth (int): The depth of the octree.
|
|
145
|
+
'''
|
|
146
|
+
|
|
147
|
+
return self.children[depth] >= 0
|
|
148
|
+
|
|
149
|
+
def build_octree(self, point_cloud: Points):
|
|
150
|
+
r''' Builds an octree from a point cloud.
|
|
151
|
+
|
|
152
|
+
Args:
|
|
153
|
+
point_cloud (Points): The input point cloud.
|
|
154
|
+
|
|
155
|
+
.. note::
|
|
156
|
+
The point cloud must be strictly in range :obj:`[-1, 1]`. A good practice
|
|
157
|
+
is to normalize it into :obj:`[-0.99, 0.99]` or :obj:`[0.9, 0.9]` to retain
|
|
158
|
+
some margin.
|
|
159
|
+
'''
|
|
160
|
+
|
|
161
|
+
self.device = point_cloud.device
|
|
162
|
+
assert point_cloud.batch_size == self.batch_size, 'Inconsistent batch_size'
|
|
163
|
+
|
|
164
|
+
# normalize points from [-1, 1] to [0, 2^depth]. #[L:Scale]
|
|
165
|
+
scale = 2 ** (self.depth - 1)
|
|
166
|
+
points = (point_cloud.points + 1.0) * scale
|
|
167
|
+
|
|
168
|
+
# get the shuffled key and sort
|
|
169
|
+
x, y, z = points[:, 0], points[:, 1], points[:, 2]
|
|
170
|
+
b = None if self.batch_size == 1 else point_cloud.batch_id.view(-1)
|
|
171
|
+
key = xyz2key(x, y, z, b, self.depth)
|
|
172
|
+
node_key, idx, counts = torch.unique(
|
|
173
|
+
key, sorted=True, return_inverse=True, return_counts=True)
|
|
174
|
+
|
|
175
|
+
# layer 0 to full_layer: the octree is full in these layers
|
|
176
|
+
for d in range(self.full_depth+1):
|
|
177
|
+
self.octree_grow_full(d, update_neigh=False)
|
|
178
|
+
|
|
179
|
+
# layer depth_ to full_layer_
|
|
180
|
+
for d in range(self.depth, self.full_depth, -1):
|
|
181
|
+
# compute parent key, i.e. keys of layer (d -1)
|
|
182
|
+
pkey = node_key >> 3
|
|
183
|
+
pkey, pidx, _ = torch.unique_consecutive(
|
|
184
|
+
pkey, return_inverse=True, return_counts=True)
|
|
185
|
+
|
|
186
|
+
# augmented key
|
|
187
|
+
key = (pkey.unsqueeze(-1) << 3) + torch.arange(8, device=self.device)
|
|
188
|
+
self.keys[d] = key.view(-1)
|
|
189
|
+
self.nnum[d] = key.numel()
|
|
190
|
+
self.nnum_nempty[d] = node_key.numel()
|
|
191
|
+
|
|
192
|
+
# children
|
|
193
|
+
addr = (pidx << 3) | (node_key % 8)
|
|
194
|
+
children = -torch.ones(
|
|
195
|
+
self.nnum[d].item(), dtype=torch.int32, device=self.device)
|
|
196
|
+
children[addr] = torch.arange(
|
|
197
|
+
self.nnum_nempty[d], dtype=torch.int32, device=self.device)
|
|
198
|
+
self.children[d] = children
|
|
199
|
+
|
|
200
|
+
# cache pkey for the next iteration
|
|
201
|
+
# Use `pkey >> 45` instead of `pkey >> 48` in L199 since pkey is already
|
|
202
|
+
# shifted to the right by 3 bits in L177
|
|
203
|
+
node_key = pkey if self.batch_size == 1 else \
|
|
204
|
+
((pkey >> 45) << 48) | (pkey & ((1 << 45) - 1))
|
|
205
|
+
|
|
206
|
+
# set the children for the layer full_layer,
|
|
207
|
+
# now the node_keys are the key for full_layer
|
|
208
|
+
d = self.full_depth
|
|
209
|
+
children = -torch.ones_like(self.children[d])
|
|
210
|
+
nempty_idx = node_key if self.batch_size == 1 else \
|
|
211
|
+
((node_key >> 48) << (3 * d)) | (node_key & ((1 << 48) - 1))
|
|
212
|
+
children[nempty_idx] = torch.arange(
|
|
213
|
+
node_key.numel(), dtype=torch.int32, device=self.device)
|
|
214
|
+
self.children[d] = children
|
|
215
|
+
self.nnum_nempty[d] = node_key.numel()
|
|
216
|
+
|
|
217
|
+
# average the signal for the last octree layer
|
|
218
|
+
d = self.depth
|
|
219
|
+
points = scatter_add(points, idx, dim=0) # points is rescaled in [L:Scale]
|
|
220
|
+
self.points[d] = points / counts.unsqueeze(1)
|
|
221
|
+
if point_cloud.normals is not None:
|
|
222
|
+
normals = scatter_add(point_cloud.normals, idx, dim=0)
|
|
223
|
+
self.normals[d] = F.normalize(normals)
|
|
224
|
+
if point_cloud.features is not None:
|
|
225
|
+
features = scatter_add(point_cloud.features, idx, dim=0)
|
|
226
|
+
self.features[d] = features / counts.unsqueeze(1)
|
|
227
|
+
|
|
228
|
+
return idx
|
|
229
|
+
|
|
230
|
+
def octree_grow_full(self, depth: int, update_neigh: bool = True):
|
|
231
|
+
r''' Builds the full octree, which is essentially a dense volumetric grid.
|
|
232
|
+
|
|
233
|
+
Args:
|
|
234
|
+
depth (int): The depth of the octree.
|
|
235
|
+
update_neigh (bool): If True, construct the neighborhood indices.
|
|
236
|
+
'''
|
|
237
|
+
|
|
238
|
+
# check
|
|
239
|
+
assert depth <= self.full_depth, 'error'
|
|
240
|
+
|
|
241
|
+
# node number
|
|
242
|
+
num = 1 << (3 * depth)
|
|
243
|
+
self.nnum[depth] = num * self.batch_size
|
|
244
|
+
self.nnum_nempty[depth] = num * self.batch_size
|
|
245
|
+
|
|
246
|
+
# update key
|
|
247
|
+
key = torch.arange(num, dtype=torch.long, device=self.device)
|
|
248
|
+
bs = torch.arange(self.batch_size, dtype=torch.long, device=self.device)
|
|
249
|
+
key = key.unsqueeze(0) | (bs.unsqueeze(1) << 48)
|
|
250
|
+
self.keys[depth] = key.view(-1)
|
|
251
|
+
|
|
252
|
+
# update children
|
|
253
|
+
self.children[depth] = torch.arange(
|
|
254
|
+
num * self.batch_size, dtype=torch.int32, device=self.device)
|
|
255
|
+
|
|
256
|
+
# update neigh if needed
|
|
257
|
+
if update_neigh:
|
|
258
|
+
self.construct_neigh(depth)
|
|
259
|
+
|
|
260
|
+
def octree_split(self, split: torch.Tensor, depth: int):
|
|
261
|
+
r''' Sets whether the octree nodes in :attr:`depth` are splitted or not.
|
|
262
|
+
|
|
263
|
+
Args:
|
|
264
|
+
split (torch.Tensor): The input tensor with its element indicating status
|
|
265
|
+
of each octree node: 0 - empty, 1 - non-empty or splitted.
|
|
266
|
+
depth (int): The depth of current octree.
|
|
267
|
+
'''
|
|
268
|
+
|
|
269
|
+
# split -> children
|
|
270
|
+
empty = split == 0
|
|
271
|
+
sum = cumsum(split, dim=0, exclusive=True)
|
|
272
|
+
children, nnum_nempty = torch.split(sum, [split.shape[0], 1])
|
|
273
|
+
children[empty] = -1
|
|
274
|
+
|
|
275
|
+
# boundary case, make sure that at least one octree node is splitted
|
|
276
|
+
if nnum_nempty == 0:
|
|
277
|
+
nnum_nempty = 1
|
|
278
|
+
children[0] = 0
|
|
279
|
+
|
|
280
|
+
# update octree
|
|
281
|
+
self.children[depth] = children.int()
|
|
282
|
+
self.nnum_nempty[depth] = nnum_nempty
|
|
283
|
+
|
|
284
|
+
def octree_grow(self, depth: int, update_neigh: bool = True):
|
|
285
|
+
r''' Grows the octree and updates the relevant properties. And in most
|
|
286
|
+
cases, call :func:`Octree.octree_split` to update the splitting status of
|
|
287
|
+
the octree before this function.
|
|
288
|
+
|
|
289
|
+
Args:
|
|
290
|
+
depth (int): The depth of the octree.
|
|
291
|
+
update_neigh (bool): If True, construct the neighborhood indices.
|
|
292
|
+
'''
|
|
293
|
+
|
|
294
|
+
# increase the octree depth if required
|
|
295
|
+
if depth > self.depth:
|
|
296
|
+
assert depth == self.depth + 1
|
|
297
|
+
self.depth = depth
|
|
298
|
+
self.keys.append(None)
|
|
299
|
+
self.children.append(None)
|
|
300
|
+
self.neighs.append(None)
|
|
301
|
+
self.features.append(None)
|
|
302
|
+
self.normals.append(None)
|
|
303
|
+
self.points.append(None)
|
|
304
|
+
zero = torch.zeros(1, dtype=torch.long)
|
|
305
|
+
self.nnum = torch.cat([self.nnum, zero])
|
|
306
|
+
self.nnum_nempty = torch.cat([self.nnum_nempty, zero])
|
|
307
|
+
zero = zero.view(1, 1)
|
|
308
|
+
self.batch_nnum = torch.cat([self.batch_nnum, zero], dim=0)
|
|
309
|
+
self.batch_nnum_nempty = torch.cat([self.batch_nnum_nempty, zero], dim=0)
|
|
310
|
+
|
|
311
|
+
# node number
|
|
312
|
+
nnum = self.nnum_nempty[depth-1] * 8
|
|
313
|
+
self.nnum[depth] = nnum
|
|
314
|
+
self.nnum_nempty[depth] = nnum # initialize self.nnum_nempty
|
|
315
|
+
|
|
316
|
+
# update keys
|
|
317
|
+
key = self.key(depth-1, nempty=True)
|
|
318
|
+
batch_id = (key >> 48) << 48
|
|
319
|
+
key = (key & ((1 << 48) - 1)) << 3
|
|
320
|
+
key = key | batch_id
|
|
321
|
+
key = key.unsqueeze(1) + torch.arange(8, device=key.device)
|
|
322
|
+
self.keys[depth] = key.view(-1)
|
|
323
|
+
|
|
324
|
+
# update children
|
|
325
|
+
self.children[depth] = torch.arange(
|
|
326
|
+
nnum, dtype=torch.int32, device=self.device)
|
|
327
|
+
|
|
328
|
+
# update neighs
|
|
329
|
+
if update_neigh:
|
|
330
|
+
self.construct_neigh(depth)
|
|
331
|
+
|
|
332
|
+
def construct_neigh(self, depth: int):
|
|
333
|
+
r''' Constructs the :obj:`3x3x3` neighbors for each octree node.
|
|
334
|
+
|
|
335
|
+
Args:
|
|
336
|
+
depth (int): The octree depth with a value larger than 0 (:obj:`>0`).
|
|
337
|
+
'''
|
|
338
|
+
|
|
339
|
+
if depth <= self.full_depth:
|
|
340
|
+
device = self.device
|
|
341
|
+
nnum = 1 << (3 * depth)
|
|
342
|
+
key = torch.arange(nnum, dtype=torch.long, device=device)
|
|
343
|
+
x, y, z, _ = key2xyz(key, depth)
|
|
344
|
+
xyz = torch.stack([x, y, z], dim=-1) # (N, 3)
|
|
345
|
+
grid = range_grid(-1, 1, device) # (27, 3)
|
|
346
|
+
xyz = xyz.unsqueeze(1) + grid # (N, 27, 3)
|
|
347
|
+
xyz = xyz.view(-1, 3) # (N*27, 3)
|
|
348
|
+
neigh = xyz2key(xyz[:, 0], xyz[:, 1], xyz[:, 2], depth=depth)
|
|
349
|
+
|
|
350
|
+
bs = torch.arange(self.batch_size, dtype=torch.long, device=device)
|
|
351
|
+
neigh = neigh + bs.unsqueeze(1) * nnum # (N*27,) + (B, 1) -> (B, N*27)
|
|
352
|
+
|
|
353
|
+
bound = 1 << depth
|
|
354
|
+
invalid = torch.logical_or((xyz < 0).any(1), (xyz >= bound).any(1))
|
|
355
|
+
neigh[:, invalid] = -1
|
|
356
|
+
self.neighs[depth] = neigh.view(-1, 27) # (B*N, 27)
|
|
357
|
+
|
|
358
|
+
else:
|
|
359
|
+
child_p = self.children[depth-1]
|
|
360
|
+
nempty = child_p >= 0
|
|
361
|
+
neigh_p = self.neighs[depth-1][nempty] # (N, 27)
|
|
362
|
+
neigh_p = neigh_p[:, self.lut_parent] # (N, 8, 27)
|
|
363
|
+
child_p = child_p[neigh_p] # (N, 8, 27)
|
|
364
|
+
invalid = torch.logical_or(child_p < 0, neigh_p < 0) # (N, 8, 27)
|
|
365
|
+
neigh = child_p * 8 + self.lut_child
|
|
366
|
+
neigh[invalid] = -1
|
|
367
|
+
self.neighs[depth] = neigh.view(-1, 27)
|
|
368
|
+
|
|
369
|
+
def construct_all_neigh(self):
|
|
370
|
+
r''' A convenient handler for constructing all neighbors.
|
|
371
|
+
'''
|
|
372
|
+
|
|
373
|
+
for depth in range(1, self.depth+1):
|
|
374
|
+
self.construct_neigh(depth)
|
|
375
|
+
|
|
376
|
+
def search_xyzb(self, query: torch.Tensor, depth: int, nempty: bool = False):
|
|
377
|
+
r''' Searches the octree nodes given the query points.
|
|
378
|
+
|
|
379
|
+
Args:
|
|
380
|
+
query (torch.Tensor): The coordinates of query points with shape
|
|
381
|
+
:obj:`(N, 4)`. The first 3 channels of the coordinates are :obj:`x`,
|
|
382
|
+
:obj:`y`, and :obj:`z`, and the last channel is the batch index. Note
|
|
383
|
+
that the coordinates must be in range :obj:`[0, 2^depth)`.
|
|
384
|
+
depth (int): The depth of the octree layer. nemtpy (bool): If true, only
|
|
385
|
+
searches the non-empty octree nodes.
|
|
386
|
+
'''
|
|
387
|
+
|
|
388
|
+
key = xyz2key(query[:, 0], query[:, 1], query[:, 2], query[:, 3], depth)
|
|
389
|
+
idx = self.search_key(key, depth, nempty)
|
|
390
|
+
return idx
|
|
391
|
+
|
|
392
|
+
def search_key(self, query: torch.Tensor, depth: int, nempty: bool = False):
|
|
393
|
+
r''' Searches the octree nodes given the query points.
|
|
394
|
+
|
|
395
|
+
Args:
|
|
396
|
+
query (torch.Tensor): The keys of query points with shape :obj:`(N,)`,
|
|
397
|
+
which are computed from the coordinates of query points.
|
|
398
|
+
depth (int): The depth of the octree layer. nemtpy (bool): If true, only
|
|
399
|
+
searches the non-empty octree nodes.
|
|
400
|
+
'''
|
|
401
|
+
|
|
402
|
+
key = self.key(depth, nempty)
|
|
403
|
+
# `torch.bucketize` is similar to `torch.searchsorted`.
|
|
404
|
+
# I choose `torch.bucketize` here because it has fewer dimension checks,
|
|
405
|
+
# resulting in slightly better performance according to the docs of
|
|
406
|
+
# pytorch-1.9.1, since `key` is always 1-D sorted sequence.
|
|
407
|
+
# https://pytorch.org/docs/1.9.1/generated/torch.searchsorted.html
|
|
408
|
+
idx = torch.bucketize(query, key)
|
|
409
|
+
|
|
410
|
+
valid = idx < key.shape[0] # valid if in-bound
|
|
411
|
+
found = key[idx[valid]] == query[valid]
|
|
412
|
+
valid[valid.clone()] = found # valid if found
|
|
413
|
+
idx[valid.logical_not()] = -1 # set to -1 if invalid
|
|
414
|
+
return idx
|
|
415
|
+
|
|
416
|
+
def get_neigh(self, depth: int, kernel: str = '333', stride: int = 1,
|
|
417
|
+
nempty: bool = False):
|
|
418
|
+
r''' Returns the neighborhoods given the depth and a kernel shape.
|
|
419
|
+
|
|
420
|
+
Args:
|
|
421
|
+
depth (int): The octree depth with a value larger than 0 (:obj:`>0`).
|
|
422
|
+
kernel (str): The kernel shape from :obj:`333`, :obj:`311`, :obj:`131`,
|
|
423
|
+
:obj:`113`, :obj:`222`, :obj:`331`, :obj:`133`, and :obj:`313`.
|
|
424
|
+
stride (int): The stride of neighborhoods (:obj:`1` or :obj:`2`). If the
|
|
425
|
+
stride is :obj:`2`, always returns the neighborhood of the first
|
|
426
|
+
siblings.
|
|
427
|
+
nempty (bool): If True, only returns the neighborhoods of the non-empty
|
|
428
|
+
octree nodes.
|
|
429
|
+
'''
|
|
430
|
+
|
|
431
|
+
if stride == 1:
|
|
432
|
+
neigh = self.neighs[depth]
|
|
433
|
+
elif stride == 2:
|
|
434
|
+
# clone neigh to avoid self.neigh[depth] being modified
|
|
435
|
+
neigh = self.neighs[depth][::8].clone()
|
|
436
|
+
else:
|
|
437
|
+
raise ValueError('Unsupported stride {}'.format(stride))
|
|
438
|
+
|
|
439
|
+
if nempty:
|
|
440
|
+
child = self.children[depth]
|
|
441
|
+
if stride == 1:
|
|
442
|
+
nempty_node = child >= 0
|
|
443
|
+
neigh = neigh[nempty_node]
|
|
444
|
+
valid = neigh >= 0
|
|
445
|
+
neigh[valid] = child[neigh[valid]].long() # remap the index
|
|
446
|
+
|
|
447
|
+
if kernel == '333':
|
|
448
|
+
return neigh
|
|
449
|
+
elif kernel in self.lut_kernel:
|
|
450
|
+
lut = self.lut_kernel[kernel]
|
|
451
|
+
return neigh[:, lut]
|
|
452
|
+
else:
|
|
453
|
+
raise ValueError('Unsupported kernel {}'.format(kernel))
|
|
454
|
+
|
|
455
|
+
def get_input_feature(self, feature: str, nempty: bool = False):
|
|
456
|
+
r''' Returns the initial input feature stored in octree.
|
|
457
|
+
|
|
458
|
+
Args:
|
|
459
|
+
feature (str): A string used to indicate which features to extract from
|
|
460
|
+
the input octree. If the character :obj:`N` is in :attr:`feature`, the
|
|
461
|
+
normal signal is extracted (3 channels). Similarly, if :obj:`D` is in
|
|
462
|
+
:attr:`feature`, the local displacement is extracted (1 channels). If
|
|
463
|
+
:obj:`L` is in :attr:`feature`, the local coordinates of the averaged
|
|
464
|
+
points in each octree node is extracted (3 channels). If :attr:`P` is
|
|
465
|
+
in :attr:`feature`, the global coordinates are extracted (3 channels).
|
|
466
|
+
If :attr:`F` is in :attr:`feature`, other features (like colors) are
|
|
467
|
+
extracted (k channels).
|
|
468
|
+
nempty (bool): If false, gets the features of all octree nodes.
|
|
469
|
+
'''
|
|
470
|
+
|
|
471
|
+
features = list()
|
|
472
|
+
depth = self.depth
|
|
473
|
+
feature = feature.upper()
|
|
474
|
+
if 'N' in feature:
|
|
475
|
+
features.append(self.normals[depth])
|
|
476
|
+
|
|
477
|
+
if 'L' in feature or 'D' in feature:
|
|
478
|
+
local_points = self.points[depth].frac() - 0.5
|
|
479
|
+
|
|
480
|
+
if 'D' in feature:
|
|
481
|
+
dis = torch.sum(local_points * self.normals[depth], dim=1, keepdim=True)
|
|
482
|
+
features.append(dis)
|
|
483
|
+
|
|
484
|
+
if 'L' in feature:
|
|
485
|
+
features.append(local_points)
|
|
486
|
+
|
|
487
|
+
if 'P' in feature:
|
|
488
|
+
scale = 2 ** (1 - depth) # normalize [0, 2^depth] -> [-1, 1]
|
|
489
|
+
global_points = self.points[depth] * scale - 1.0
|
|
490
|
+
features.append(global_points)
|
|
491
|
+
|
|
492
|
+
if 'F' in feature:
|
|
493
|
+
features.append(self.features[depth])
|
|
494
|
+
|
|
495
|
+
out = torch.cat(features, dim=1)
|
|
496
|
+
if not nempty:
|
|
497
|
+
out = ocnn.nn.octree_pad(out, self, depth)
|
|
498
|
+
return out
|
|
499
|
+
|
|
500
|
+
def to_points(self, rescale: bool = True):
|
|
501
|
+
r''' Converts averaged points in the octree to a point cloud.
|
|
502
|
+
|
|
503
|
+
Args:
|
|
504
|
+
rescale (bool): rescale the xyz coordinates to [-1, 1] if True.
|
|
505
|
+
'''
|
|
506
|
+
|
|
507
|
+
depth = self.depth
|
|
508
|
+
batch_size = self.batch_size
|
|
509
|
+
|
|
510
|
+
# by default, use the average points generated when building the octree
|
|
511
|
+
# from the input point cloud
|
|
512
|
+
xyz = self.points[depth]
|
|
513
|
+
batch_id = self.batch_id(depth, nempty=True)
|
|
514
|
+
|
|
515
|
+
# xyz is None when the octree is predicted by a neural network
|
|
516
|
+
if xyz is None:
|
|
517
|
+
x, y, z, batch_id = self.xyzb(depth, nempty=True)
|
|
518
|
+
xyz = torch.stack([x, y, z], dim=1) + 0.5
|
|
519
|
+
|
|
520
|
+
# normalize xyz to [-1, 1] since the average points are in range [0, 2^d]
|
|
521
|
+
if rescale:
|
|
522
|
+
scale = 2 ** (1 - depth)
|
|
523
|
+
xyz = xyz * scale - 1.0
|
|
524
|
+
|
|
525
|
+
# construct Points
|
|
526
|
+
out = Points(xyz, self.normals[depth], self.features[depth],
|
|
527
|
+
batch_id=batch_id, batch_size=batch_size)
|
|
528
|
+
return out
|
|
529
|
+
|
|
530
|
+
def to(self, device: Union[torch.device, str], non_blocking: bool = False):
|
|
531
|
+
r''' Moves the octree to a specified device.
|
|
532
|
+
|
|
533
|
+
Args:
|
|
534
|
+
device (torch.device or str): The destination device.
|
|
535
|
+
non_blocking (bool): If True and the source is in pinned memory, the copy
|
|
536
|
+
will be asynchronous with respect to the host. Otherwise, the argument
|
|
537
|
+
has no effect. Default: False.
|
|
538
|
+
'''
|
|
539
|
+
|
|
540
|
+
if isinstance(device, str):
|
|
541
|
+
device = torch.device(device)
|
|
542
|
+
|
|
543
|
+
# If on the save device, directly retrun self
|
|
544
|
+
if self.device == device:
|
|
545
|
+
return self
|
|
546
|
+
|
|
547
|
+
def list_to_device(prop):
|
|
548
|
+
return [p.to(device, non_blocking=non_blocking)
|
|
549
|
+
if isinstance(p, torch.Tensor) else None for p in prop]
|
|
550
|
+
|
|
551
|
+
# Construct a new Octree on the specified device
|
|
552
|
+
octree = Octree(self.depth, self.full_depth, self.batch_size, device)
|
|
553
|
+
octree.keys = list_to_device(self.keys)
|
|
554
|
+
octree.children = list_to_device(self.children)
|
|
555
|
+
octree.neighs = list_to_device(self.neighs)
|
|
556
|
+
octree.features = list_to_device(self.features)
|
|
557
|
+
octree.normals = list_to_device(self.normals)
|
|
558
|
+
octree.points = list_to_device(self.points)
|
|
559
|
+
octree.nnum = self.nnum.clone() # TODO: whether to move nnum to the self.device?
|
|
560
|
+
octree.nnum_nempty = self.nnum_nempty.clone()
|
|
561
|
+
octree.batch_nnum = self.batch_nnum.clone()
|
|
562
|
+
octree.batch_nnum_nempty = self.batch_nnum_nempty.clone()
|
|
563
|
+
return octree
|
|
564
|
+
|
|
565
|
+
def cuda(self, non_blocking: bool = False):
|
|
566
|
+
r''' Moves the octree to the GPU. '''
|
|
567
|
+
|
|
568
|
+
return self.to('cuda', non_blocking)
|
|
569
|
+
|
|
570
|
+
def cpu(self):
|
|
571
|
+
r''' Moves the octree to the CPU. '''
|
|
572
|
+
|
|
573
|
+
return self.to('cpu')
|
|
574
|
+
|
|
575
|
+
|
|
576
|
+
def merge_octrees(octrees: List['Octree']):
|
|
577
|
+
r''' Merges a list of octrees into one batch.
|
|
578
|
+
|
|
579
|
+
Args:
|
|
580
|
+
octrees (List[Octree]): A list of octrees to merge.
|
|
581
|
+
|
|
582
|
+
Returns:
|
|
583
|
+
Octree: The merged octree.
|
|
584
|
+
'''
|
|
585
|
+
|
|
586
|
+
# init and check
|
|
587
|
+
octree = Octree(depth=octrees[0].depth, full_depth=octrees[0].full_depth,
|
|
588
|
+
batch_size=len(octrees), device=octrees[0].device)
|
|
589
|
+
for i in range(1, octree.batch_size):
|
|
590
|
+
condition = (octrees[i].depth == octree.depth and
|
|
591
|
+
octrees[i].full_depth == octree.full_depth and
|
|
592
|
+
octrees[i].device == octree.device)
|
|
593
|
+
assert condition, 'The check of merge_octrees failed'
|
|
594
|
+
|
|
595
|
+
# node num
|
|
596
|
+
batch_nnum = torch.stack(
|
|
597
|
+
[octrees[i].nnum for i in range(octree.batch_size)], dim=1)
|
|
598
|
+
batch_nnum_nempty = torch.stack(
|
|
599
|
+
[octrees[i].nnum_nempty for i in range(octree.batch_size)], dim=1)
|
|
600
|
+
octree.nnum = torch.sum(batch_nnum, dim=1)
|
|
601
|
+
octree.nnum_nempty = torch.sum(batch_nnum_nempty, dim=1)
|
|
602
|
+
octree.batch_nnum = batch_nnum
|
|
603
|
+
octree.batch_nnum_nempty = batch_nnum_nempty
|
|
604
|
+
nnum_cum = cumsum(batch_nnum_nempty, dim=1, exclusive=True)
|
|
605
|
+
|
|
606
|
+
# merge octre properties
|
|
607
|
+
for d in range(octree.depth+1):
|
|
608
|
+
# key
|
|
609
|
+
keys = [None] * octree.batch_size
|
|
610
|
+
for i in range(octree.batch_size):
|
|
611
|
+
key = octrees[i].keys[d] & ((1 << 48) - 1) # clear the highest bits
|
|
612
|
+
keys[i] = key | (i << 48)
|
|
613
|
+
octree.keys[d] = torch.cat(keys, dim=0)
|
|
614
|
+
|
|
615
|
+
# children
|
|
616
|
+
children = [None] * octree.batch_size
|
|
617
|
+
for i in range(octree.batch_size):
|
|
618
|
+
child = octrees[i].children[d].clone() # !! `clone` is used here to avoid
|
|
619
|
+
mask = child >= 0 # !! modifying the original octrees
|
|
620
|
+
child[mask] = child[mask] + nnum_cum[d, i]
|
|
621
|
+
children[i] = child
|
|
622
|
+
octree.children[d] = torch.cat(children, dim=0)
|
|
623
|
+
|
|
624
|
+
# features
|
|
625
|
+
if octrees[0].features[d] is not None and d == octree.depth:
|
|
626
|
+
features = [octrees[i].features[d] for i in range(octree.batch_size)]
|
|
627
|
+
octree.features[d] = torch.cat(features, dim=0)
|
|
628
|
+
|
|
629
|
+
# normals
|
|
630
|
+
if octrees[0].normals[d] is not None and d == octree.depth:
|
|
631
|
+
normals = [octrees[i].normals[d] for i in range(octree.batch_size)]
|
|
632
|
+
octree.normals[d] = torch.cat(normals, dim=0)
|
|
633
|
+
|
|
634
|
+
# points
|
|
635
|
+
if octrees[0].points[d] is not None and d == octree.depth:
|
|
636
|
+
points = [octrees[i].points[d] for i in range(octree.batch_size)]
|
|
637
|
+
octree.points[d] = torch.cat(points, dim=0)
|
|
638
|
+
|
|
639
|
+
return octree
|
|
640
|
+
|
|
641
|
+
|
|
642
|
+
def init_octree(depth: int, full_depth: int = 2, batch_size: int = 1,
|
|
643
|
+
device: Union[torch.device, str] = 'cpu'):
|
|
644
|
+
r'''
|
|
645
|
+
Initializes an octree to :attr:`full_depth`.
|
|
646
|
+
|
|
647
|
+
Args:
|
|
648
|
+
depth (int): The depth of the octree.
|
|
649
|
+
full_depth (int): The octree layers with a depth small than
|
|
650
|
+
:attr:`full_depth` are forced to be full.
|
|
651
|
+
batch_size (int, optional): The batch size.
|
|
652
|
+
device (torch.device or str): The device to use for computation.
|
|
653
|
+
|
|
654
|
+
Returns:
|
|
655
|
+
Octree: The initialized Octree object.
|
|
656
|
+
'''
|
|
657
|
+
|
|
658
|
+
octree = Octree(depth, full_depth, batch_size, device)
|
|
659
|
+
for d in range(full_depth+1):
|
|
660
|
+
octree.octree_grow_full(depth=d)
|
|
661
|
+
return octree
|