ocnn 2.2.1__py3-none-any.whl → 2.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ocnn/octree/octree.py CHANGED
@@ -1,601 +1,639 @@
1
- # --------------------------------------------------------
2
- # Octree-based Sparse Convolutional Neural Networks
3
- # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
- # Licensed under The MIT License [see LICENSE for details]
5
- # Written by Peng-Shuai Wang
6
- # --------------------------------------------------------
7
-
8
- import torch
9
- import torch.nn.functional as F
10
- from typing import Union, List
11
-
12
- from ocnn.utils import meshgrid, scatter_add, cumsum, trunc_div
13
- from .points import Points
14
- from .shuffled_key import xyz2key, key2xyz
15
-
16
-
17
- class Octree:
18
- r''' Builds an octree from an input point cloud.
19
-
20
- Args:
21
- depth (int): The octree depth.
22
- full_depth (int): The octree layers with a depth small than
23
- :attr:`full_depth` are forced to be full.
24
- batch_size (int): The octree batch size.
25
- device (torch.device or str): Choose from :obj:`cpu` and :obj:`gpu`.
26
- (default: :obj:`cpu`)
27
-
28
- .. note::
29
- The octree data structure requires that if an octree node has children nodes,
30
- the number of children nodes is exactly 8, in which some of the nodes are
31
- empty and some nodes are non-empty. The properties of an octree, including
32
- :obj:`keys`, :obj:`children` and :obj:`neighs`, contain both non-empty and
33
- empty nodes, and other properties, including :obj:`features`, :obj:`normals`
34
- and :obj:`points`, contain only non-empty nodes.
35
-
36
- .. note::
37
- The point cloud must be in range :obj:`[-1, 1]`.
38
- '''
39
-
40
- def __init__(self, depth: int, full_depth: int = 2, batch_size: int = 1,
41
- device: Union[torch.device, str] = 'cpu', **kwargs):
42
- super().__init__()
43
- self.depth = depth
44
- self.full_depth = full_depth
45
- self.batch_size = batch_size
46
- self.device = device
47
-
48
- self.reset()
49
-
50
- def reset(self):
51
- r''' Resets the Octree status and constructs several lookup tables.
52
- '''
53
-
54
- # octree features in each octree layers
55
- num = self.depth + 1
56
- self.keys = [None] * num
57
- self.children = [None] * num
58
- self.neighs = [None] * num
59
- self.features = [None] * num
60
- self.normals = [None] * num
61
- self.points = [None] * num
62
-
63
- # octree node numbers in each octree layers.
64
- # TODO: decide whether to settle them to 'gpu' or not?
65
- self.nnum = torch.zeros(num, dtype=torch.int32)
66
- self.nnum_nempty = torch.zeros(num, dtype=torch.int32)
67
-
68
- # the following properties are valid after `merge_octrees`.
69
- # TODO: make them valid after `octree_grow`, `octree_split` and `build_octree`
70
- batch_size = self.batch_size
71
- self.batch_nnum = torch.zeros(num, batch_size, dtype=torch.int32)
72
- self.batch_nnum_nempty = torch.zeros(num, batch_size, dtype=torch.int32)
73
-
74
- # construct the look up tables for neighborhood searching
75
- device = self.device
76
- center_grid = self.rng_grid(2, 3) # (8, 3)
77
- displacement = self.rng_grid(-1, 1) # (27, 3)
78
- neigh_grid = center_grid.unsqueeze(1) + displacement # (8, 27, 3)
79
- parent_grid = trunc_div(neigh_grid, 2)
80
- child_grid = neigh_grid % 2
81
- self.lut_parent = torch.sum(
82
- parent_grid * torch.tensor([9, 3, 1], device=device), dim=2)
83
- self.lut_child = torch.sum(
84
- child_grid * torch.tensor([4, 2, 1], device=device), dim=2)
85
-
86
- # lookup tables for different kernel sizes
87
- self.lut_kernel = {
88
- '222': torch.tensor([13, 14, 16, 17, 22, 23, 25, 26], device=device),
89
- '311': torch.tensor([4, 13, 22], device=device),
90
- '131': torch.tensor([10, 13, 16], device=device),
91
- '113': torch.tensor([12, 13, 14], device=device),
92
- '331': torch.tensor([1, 4, 7, 10, 13, 16, 19, 22, 25], device=device),
93
- '313': torch.tensor([3, 4, 5, 12, 13, 14, 21, 22, 23], device=device),
94
- '133': torch.tensor([9, 10, 11, 12, 13, 14, 15, 16, 17], device=device),
95
- }
96
-
97
- def key(self, depth: int, nempty: bool = False):
98
- r''' Returns the shuffled key of each octree node.
99
-
100
- Args:
101
- depth (int): The depth of the octree.
102
- nempty (bool): If True, returns the results of non-empty octree nodes.
103
- '''
104
-
105
- key = self.keys[depth]
106
- if nempty:
107
- mask = self.nempty_mask(depth)
108
- key = key[mask]
109
- return key
110
-
111
- def xyzb(self, depth: int, nempty: bool = False):
112
- r''' Returns the xyz coordinates and the batch indices of each octree node.
113
-
114
- Args:
115
- depth (int): The depth of the octree.
116
- nempty (bool): If True, returns the results of non-empty octree nodes.
117
- '''
118
-
119
- key = self.key(depth, nempty)
120
- return key2xyz(key, depth)
121
-
122
- def batch_id(self, depth: int, nempty: bool = False):
123
- r''' Returns the batch indices of each octree node.
124
-
125
- Args:
126
- depth (int): The depth of the octree.
127
- nempty (bool): If True, returns the results of non-empty octree nodes.
128
- '''
129
-
130
- batch_id = self.keys[depth] >> 48
131
- if nempty:
132
- mask = self.nempty_mask(depth)
133
- batch_id = batch_id[mask]
134
- return batch_id
135
-
136
- def nempty_mask(self, depth: int):
137
- r''' Returns a binary mask which indicates whether the cooreponding octree
138
- node is empty or not.
139
-
140
- Args:
141
- depth (int): The depth of the octree.
142
- '''
143
-
144
- return self.children[depth] >= 0
145
-
146
- def build_octree(self, point_cloud: Points):
147
- r''' Builds an octree from a point cloud.
148
-
149
- Args:
150
- point_cloud (Points): The input point cloud.
151
-
152
- .. note::
153
- Currently, the batch size of the point cloud must be 1.
154
- '''
155
-
156
- self.device = point_cloud.device
157
- assert point_cloud.batch_size == self.batch_size, 'Inconsistent batch_size'
158
-
159
- # normalize points from [-1, 1] to [0, 2^depth]. #[L:Scale]
160
- scale = 2 ** (self.depth - 1)
161
- points = (point_cloud.points + 1.0) * scale
162
-
163
- # get the shuffled key and sort
164
- x, y, z = points[:, 0], points[:, 1], points[:, 2]
165
- b = None if self.batch_size == 1 else point_cloud.batch_id.view(-1)
166
- key = xyz2key(x, y, z, b, self.depth)
167
- node_key, idx, counts = torch.unique(
168
- key, sorted=True, return_inverse=True, return_counts=True)
169
-
170
- # layer 0 to full_layer: the octree is full in these layers
171
- for d in range(self.full_depth+1):
172
- self.octree_grow_full(d, update_neigh=False)
173
-
174
- # layer depth_ to full_layer_
175
- for d in range(self.depth, self.full_depth, -1):
176
- # compute parent key, i.e. keys of layer (d -1)
177
- pkey = node_key >> 3
178
- pkey, pidx, pcounts = torch.unique_consecutive(
179
- pkey, return_inverse=True, return_counts=True)
180
-
181
- # augmented key
182
- key = (pkey.unsqueeze(-1) << 3) + torch.arange(8, device=self.device)
183
- self.keys[d] = key.view(-1)
184
- self.nnum[d] = key.numel()
185
- self.nnum_nempty[d] = node_key.numel()
186
-
187
- # children
188
- addr = (pidx << 3) | (node_key % 8)
189
- children = -torch.ones(
190
- self.nnum[d].item(), dtype=torch.int32, device=self.device)
191
- children[addr] = torch.arange(
192
- self.nnum_nempty[d], dtype=torch.int32, device=self.device)
193
- self.children[d] = children
194
-
195
- # cache pkey for the next iteration
196
- # Use `pkey >> 45` instead of `pkey >> 48` in L199 since pkey is already
197
- # shifted to the right by 3 bits in L177
198
- node_key = pkey if self.batch_size == 1 else \
199
- ((pkey >> 45) << 48) | (pkey & ((1 << 45) - 1))
200
-
201
- # set the children for the layer full_layer,
202
- # now the node_keys are the key for full_layer
203
- d = self.full_depth
204
- children = -torch.ones_like(self.children[d])
205
- nempty_idx = node_key if self.batch_size == 1 else \
206
- ((node_key >> 48) << (3 * d)) | (node_key & ((1 << 48) - 1))
207
- children[nempty_idx] = torch.arange(
208
- node_key.numel(), dtype=torch.int32, device=self.device)
209
- self.children[d] = children
210
- self.nnum_nempty[d] = node_key.numel()
211
-
212
- # average the signal for the last octree layer
213
- d = self.depth
214
- points = scatter_add(points, idx, dim=0) # points is rescaled in [L:Scale]
215
- self.points[d] = points / counts.unsqueeze(1)
216
- if point_cloud.normals is not None:
217
- normals = scatter_add(point_cloud.normals, idx, dim=0)
218
- self.normals[d] = F.normalize(normals)
219
- if point_cloud.features is not None:
220
- features = scatter_add(point_cloud.features, idx, dim=0)
221
- self.features[d] = features / counts.unsqueeze(1)
222
-
223
- return idx
224
-
225
- def octree_grow_full(self, depth: int, update_neigh: bool = True):
226
- r''' Builds the full octree, which is essentially a dense volumetric grid.
227
-
228
- Args:
229
- depth (int): The depth of the octree.
230
- update_neigh (bool): If True, construct the neighborhood indices.
231
- '''
232
-
233
- # check
234
- assert depth <= self.full_depth, 'error'
235
-
236
- # node number
237
- num = 1 << (3 * depth)
238
- self.nnum[depth] = num * self.batch_size
239
- self.nnum_nempty[depth] = num * self.batch_size
240
-
241
- # update key
242
- key = torch.arange(num, dtype=torch.long, device=self.device)
243
- bs = torch.arange(self.batch_size, dtype=torch.long, device=self.device)
244
- key = key.unsqueeze(0) | (bs.unsqueeze(1) << 48)
245
- self.keys[depth] = key.view(-1)
246
-
247
- # update children
248
- self.children[depth] = torch.arange(
249
- num * self.batch_size, dtype=torch.int32, device=self.device)
250
-
251
- # update neigh if needed
252
- if update_neigh:
253
- self.construct_neigh(depth)
254
-
255
- def octree_split(self, split: torch.Tensor, depth: int):
256
- r''' Sets whether the octree nodes in :attr:`depth` are splitted or not.
257
-
258
- Args:
259
- split (torch.Tensor): The input tensor with its element indicating status
260
- of each octree node: 0 - empty, 1 - non-empty or splitted.
261
- depth (int): The depth of current octree.
262
- '''
263
-
264
- # split -> children
265
- empty = split == 0
266
- sum = cumsum(split, dim=0, exclusive=True)
267
- children, nnum_nempty = torch.split(sum, [split.shape[0], 1])
268
- children[empty] = -1
269
-
270
- # boundary case, make sure that at least one octree node is splitted
271
- if nnum_nempty == 0:
272
- nnum_nempty = 1
273
- children[0] = 0
274
-
275
- # update octree
276
- self.children[depth] = children
277
- self.nnum_nempty[depth] = nnum_nempty
278
-
279
- def octree_grow(self, depth: int, update_neigh: bool = True):
280
- r''' Grows the octree and updates the relevant properties. And in most
281
- cases, call :func:`Octree.octree_split` to update the splitting status of
282
- the octree before this function.
283
-
284
- Args:
285
- depth (int): The depth of the octree.
286
- update_neigh (bool): If True, construct the neighborhood indices.
287
- '''
288
-
289
- # node number
290
- nnum = self.nnum_nempty[depth-1] * 8
291
- self.nnum[depth] = nnum
292
- self.nnum_nempty[depth] = nnum
293
-
294
- # update keys
295
- key = self.key(depth-1, nempty=True)
296
- batch_id = (key >> 48) << 48
297
- key = (key & ((1 << 48) - 1)) << 3
298
- key = key | batch_id
299
- key = key.unsqueeze(1) + torch.arange(8, device=key.device)
300
- self.keys[depth] = key.view(-1)
301
-
302
- # update children
303
- self.children[depth] = torch.arange(
304
- nnum, dtype=torch.int32, device=self.device)
305
-
306
- # update neighs
307
- if update_neigh:
308
- self.construct_neigh(depth)
309
-
310
- def construct_neigh(self, depth: int):
311
- r''' Constructs the :obj:`3x3x3` neighbors for each octree node.
312
-
313
- Args:
314
- depth (int): The octree depth with a value larger than 0 (:obj:`>0`).
315
- '''
316
-
317
- if depth <= self.full_depth:
318
- nnum = 1 << (3 * depth)
319
- key = torch.arange(nnum, dtype=torch.long, device=self.device)
320
- x, y, z, _ = key2xyz(key, depth)
321
- xyz = torch.stack([x, y, z], dim=-1) # (N, 3)
322
- grid = self.rng_grid(min=-1, max=1) # (27, 3)
323
- xyz = xyz.unsqueeze(1) + grid # (N, 27, 3)
324
- xyz = xyz.view(-1, 3) # (N*27, 3)
325
- neigh = xyz2key(xyz[:, 0], xyz[:, 1], xyz[:, 2], depth=depth)
326
-
327
- bs = torch.arange(self.batch_size, dtype=torch.int32, device=self.device)
328
- neigh = neigh + bs.unsqueeze(1) * nnum # (N*27,) + (B, 1) -> (B, N*27)
329
-
330
- bound = 1 << depth
331
- invalid = torch.logical_or((xyz < 0).any(1), (xyz >= bound).any(1))
332
- neigh[:, invalid] = -1
333
- self.neighs[depth] = neigh.view(-1, 27) # (B*N, 27)
334
-
335
- else:
336
- child_p = self.children[depth-1]
337
- nempty = child_p >= 0
338
- neigh_p = self.neighs[depth-1][nempty] # (N, 27)
339
- neigh_p = neigh_p[:, self.lut_parent] # (N, 8, 27)
340
- child_p = child_p[neigh_p] # (N, 8, 27)
341
- invalid = torch.logical_or(child_p < 0, neigh_p < 0) # (N, 8, 27)
342
- neigh = child_p * 8 + self.lut_child
343
- neigh[invalid] = -1
344
- self.neighs[depth] = neigh.view(-1, 27)
345
-
346
- def construct_all_neigh(self):
347
- r''' A convenient handler for constructing all neighbors.
348
- '''
349
-
350
- for depth in range(1, self.depth+1):
351
- self.construct_neigh(depth)
352
-
353
- def search_xyzb(self, query: torch.Tensor, depth: int, nempty: bool = False):
354
- r''' Searches the octree nodes given the query points.
355
-
356
- Args:
357
- query (torch.Tensor): The coordinates of query points with shape
358
- :obj:`(N, 4)`. The first 3 channels of the coordinates are :obj:`x`,
359
- :obj:`y`, and :obj:`z`, and the last channel is the batch index. Note
360
- that the coordinates must be in range :obj:`[0, 2^depth)`.
361
- depth (int): The depth of the octree layer. nemtpy (bool): If true, only
362
- searches the non-empty octree nodes.
363
- '''
364
-
365
- key = xyz2key(query[:, 0], query[:, 1], query[:, 2], query[:, 3], depth)
366
- idx = self.search_key(key, depth, nempty)
367
- return idx
368
-
369
- def search_key(self, query: torch.Tensor, depth: int, nempty: bool = False):
370
- r''' Searches the octree nodes given the query points.
371
-
372
- Args:
373
- query (torch.Tensor): The keys of query points with shape :obj:`(N,)`,
374
- which are computed from the coordinates of query points.
375
- depth (int): The depth of the octree layer. nemtpy (bool): If true, only
376
- searches the non-empty octree nodes.
377
- '''
378
-
379
- key = self.key(depth, nempty)
380
- # `torch.bucketize` is similar to `torch.searchsorted`.
381
- # I choose `torch.bucketize` here because it has fewer dimension checks,
382
- # resulting in slightly better performance according to the docs of
383
- # pytorch-1.9.1, since `key` is always 1-D sorted sequence.
384
- idx = torch.bucketize(query, key)
385
-
386
- valid = idx < key.shape[0] # invalid if out of bound
387
- found = key[idx[valid]] == query[valid]
388
- valid[valid.clone()] = found
389
- idx[valid.logical_not()] = -1
390
- return idx
391
-
392
- def get_neigh(self, depth: int, kernel: str = '333', stride: int = 1,
393
- nempty: bool = False):
394
- r''' Returns the neighborhoods given the depth and a kernel shape.
395
-
396
- Args:
397
- depth (int): The octree depth with a value larger than 0 (:obj:`>0`).
398
- kernel (str): The kernel shape from :obj:`333`, :obj:`311`, :obj:`131`,
399
- :obj:`113`, :obj:`222`, :obj:`331`, :obj:`133`, and :obj:`313`.
400
- stride (int): The stride of neighborhoods (:obj:`1` or :obj:`2`). If the
401
- stride is :obj:`2`, always returns the neighborhood of the first
402
- siblings.
403
- nempty (bool): If True, only returns the neighborhoods of the non-empty
404
- octree nodes.
405
- '''
406
-
407
- if stride == 1:
408
- neigh = self.neighs[depth]
409
- elif stride == 2:
410
- # clone neigh to avoid self.neigh[depth] being modified
411
- neigh = self.neighs[depth][::8].clone()
412
- else:
413
- raise ValueError('Unsupported stride {}'.format(stride))
414
-
415
- if nempty:
416
- child = self.children[depth]
417
- if stride == 1:
418
- nempty_node = child >= 0
419
- neigh = neigh[nempty_node]
420
- valid = neigh >= 0
421
- neigh[valid] = child[neigh[valid]].long() # remap the index
422
-
423
- if kernel == '333':
424
- return neigh
425
- elif kernel in self.lut_kernel:
426
- lut = self.lut_kernel[kernel]
427
- return neigh[:, lut]
428
- else:
429
- raise ValueError('Unsupported kernel {}'.format(kernel))
430
-
431
- def get_input_feature(self):
432
- r''' Gets the initial input features.
433
- '''
434
-
435
- # normals
436
- features = list()
437
- depth = self.depth
438
- has_normal = self.normals[depth] is not None
439
- if has_normal:
440
- features.append(self.normals[depth])
441
-
442
- # local points
443
- points = self.points[depth].frac() - 0.5
444
- if has_normal:
445
- dis = torch.sum(points * self.normals[depth], dim=1, keepdim=True)
446
- features.append(dis)
447
- else:
448
- features.append(points)
449
-
450
- # features
451
- if self.features[depth] is not None:
452
- features.append(self.features[depth])
453
-
454
- return torch.cat(features, dim=1)
455
-
456
- def to_points(self, rescale: bool = True):
457
- r''' Converts averaged points in the octree to a point cloud.
458
-
459
- Args:
460
- rescale (bool): rescale the xyz coordinates to [-1, 1] if True.
461
- '''
462
-
463
- depth = self.depth
464
- batch_size = self.batch_size
465
-
466
- # by default, use the average points generated when building the octree
467
- # from the input point cloud
468
- xyz = self.points[depth]
469
- batch_id = self.batch_id(depth, nempty=True)
470
-
471
- # xyz is None when the octree is predicted by a neural network
472
- if xyz is None:
473
- x, y, z, batch_id = self.xyzb(depth, nempty=True)
474
- xyz = torch.stack([x, y, z], dim=1) + 0.5
475
-
476
- # normalize xyz to [-1, 1] since the average points are in range [0, 2^d]
477
- if rescale:
478
- scale = 2 ** (1 - depth)
479
- xyz = self.points[depth] * scale - 1.0
480
-
481
- # construct Points
482
- out = Points(xyz, self.normals[depth], self.features[depth],
483
- batch_id=batch_id, batch_size=batch_size)
484
- return out
485
-
486
- def to(self, device: Union[torch.device, str], non_blocking: bool = False):
487
- r''' Moves the octree to a specified device.
488
-
489
- Args:
490
- device (torch.device or str): The destination device.
491
- non_blocking (bool): If True and the source is in pinned memory, the copy
492
- will be asynchronous with respect to the host. Otherwise, the argument
493
- has no effect. Default: False.
494
- '''
495
-
496
- if isinstance(device, str):
497
- device = torch.device(device)
498
-
499
- # If on the save device, directly retrun self
500
- if self.device == device:
501
- return self
502
-
503
- def list_to_device(prop):
504
- return [p.to(device, non_blocking=non_blocking)
505
- if isinstance(p, torch.Tensor) else None for p in prop]
506
-
507
- # Construct a new Octree on the specified device
508
- octree = Octree(self.depth, self.full_depth, self.batch_size, device)
509
- octree.keys = list_to_device(self.keys)
510
- octree.children = list_to_device(self.children)
511
- octree.neighs = list_to_device(self.neighs)
512
- octree.features = list_to_device(self.features)
513
- octree.normals = list_to_device(self.normals)
514
- octree.points = list_to_device(self.points)
515
- octree.nnum = self.nnum.clone() # TODO: whether to move nnum to the self.device?
516
- octree.nnum_nempty = self.nnum_nempty.clone()
517
- octree.batch_nnum = self.batch_nnum.clone()
518
- octree.batch_nnum_nempty = self.batch_nnum_nempty.clone()
519
- return octree
520
-
521
- def cuda(self, non_blocking: bool = False):
522
- r''' Moves the octree to the GPU. '''
523
-
524
- return self.to('cuda', non_blocking)
525
-
526
- def cpu(self):
527
- r''' Moves the octree to the CPU. '''
528
-
529
- return self.to('cpu')
530
-
531
- def rng_grid(self, min, max):
532
- r''' Builds a mesh grid in :obj:`[min, max]` (:attr:`max` included).
533
- '''
534
-
535
- rng = torch.arange(min, max+1, dtype=torch.long, device=self.device)
536
- grid = meshgrid(rng, rng, rng, indexing='ij')
537
- grid = torch.stack(grid, dim=-1).view(-1, 3) # (27, 3)
538
- return grid
539
-
540
-
541
- def merge_octrees(octrees: List['Octree']):
542
- r''' Merges a list of octrees into one batch.
543
-
544
- Args:
545
- octrees (List[Octree]): A list of octrees to merge.
546
- '''
547
-
548
- # init and check
549
- octree = Octree(depth=octrees[0].depth, full_depth=octrees[0].full_depth,
550
- batch_size=len(octrees), device=octrees[0].device)
551
- for i in range(1, octree.batch_size):
552
- condition = (octrees[i].depth == octree.depth and
553
- octrees[i].full_depth == octree.full_depth and
554
- octrees[i].device == octree.device)
555
- assert condition, 'The check of merge_octrees failed'
556
-
557
- # node num
558
- batch_nnum = torch.stack(
559
- [octrees[i].nnum for i in range(octree.batch_size)], dim=1)
560
- batch_nnum_nempty = torch.stack(
561
- [octrees[i].nnum_nempty for i in range(octree.batch_size)], dim=1)
562
- octree.nnum = torch.sum(batch_nnum, dim=1)
563
- octree.nnum_nempty = torch.sum(batch_nnum_nempty, dim=1)
564
- octree.batch_nnum = batch_nnum
565
- octree.batch_nnum_nempty = batch_nnum_nempty
566
- nnum_cum = cumsum(batch_nnum_nempty, dim=1, exclusive=True)
567
-
568
- # merge octre properties
569
- for d in range(octree.depth+1):
570
- # key
571
- keys = [None] * octree.batch_size
572
- for i in range(octree.batch_size):
573
- key = octrees[i].keys[d] & ((1 << 48) - 1) # clear the highest bits
574
- keys[i] = key | (i << 48)
575
- octree.keys[d] = torch.cat(keys, dim=0)
576
-
577
- # children
578
- children = [None] * octree.batch_size
579
- for i in range(octree.batch_size):
580
- child = octrees[i].children[d].clone() # !! `clone` is used here to avoid
581
- mask = child >= 0 # !! modifying the original octrees
582
- child[mask] = child[mask] + nnum_cum[d, i]
583
- children[i] = child
584
- octree.children[d] = torch.cat(children, dim=0)
585
-
586
- # features
587
- if octrees[0].features[d] is not None and d == octree.depth:
588
- features = [octrees[i].features[d] for i in range(octree.batch_size)]
589
- octree.features[d] = torch.cat(features, dim=0)
590
-
591
- # normals
592
- if octrees[0].normals[d] is not None and d == octree.depth:
593
- normals = [octrees[i].normals[d] for i in range(octree.batch_size)]
594
- octree.normals[d] = torch.cat(normals, dim=0)
595
-
596
- # points
597
- if octrees[0].points[d] is not None and d == octree.depth:
598
- points = [octrees[i].points[d] for i in range(octree.batch_size)]
599
- octree.points[d] = torch.cat(points, dim=0)
600
-
601
- return octree
1
+ # --------------------------------------------------------
2
+ # Octree-based Sparse Convolutional Neural Networks
3
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Peng-Shuai Wang
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from typing import Union, List
11
+
12
+ import ocnn
13
+ from ocnn.octree.points import Points
14
+ from ocnn.octree.shuffled_key import xyz2key, key2xyz
15
+ from ocnn.utils import range_grid, scatter_add, cumsum, trunc_div
16
+
17
+
18
+ class Octree:
19
+ r''' Builds an octree from an input point cloud.
20
+
21
+ Args:
22
+ depth (int): The octree depth.
23
+ full_depth (int): The octree layers with a depth small than
24
+ :attr:`full_depth` are forced to be full.
25
+ batch_size (int): The octree batch size.
26
+ device (torch.device or str): Choose from :obj:`cpu` and :obj:`gpu`.
27
+ (default: :obj:`cpu`)
28
+
29
+ .. note::
30
+ The octree data structure requires that if an octree node has children nodes,
31
+ the number of children nodes is exactly 8, in which some of the nodes are
32
+ empty and some nodes are non-empty. The properties of an octree, including
33
+ :obj:`keys`, :obj:`children` and :obj:`neighs`, contain both non-empty and
34
+ empty nodes, and other properties, including :obj:`features`, :obj:`normals`
35
+ and :obj:`points`, contain only non-empty nodes.
36
+
37
+ .. note::
38
+ The point cloud must be in range :obj:`[-1, 1]`.
39
+ '''
40
+
41
+ def __init__(self, depth: int, full_depth: int = 2, batch_size: int = 1,
42
+ device: Union[torch.device, str] = 'cpu', **kwargs):
43
+ super().__init__()
44
+ self.depth = depth
45
+ self.full_depth = full_depth
46
+ self.batch_size = batch_size
47
+ self.device = device
48
+
49
+ self.reset()
50
+
51
+ def reset(self):
52
+ r''' Resets the Octree status and constructs several lookup tables.
53
+ '''
54
+
55
+ # octree features in each octree layers
56
+ num = self.depth + 1
57
+ self.keys = [None] * num
58
+ self.children = [None] * num
59
+ self.neighs = [None] * num
60
+ self.features = [None] * num
61
+ self.normals = [None] * num
62
+ self.points = [None] * num
63
+
64
+ # octree node numbers in each octree layers.
65
+ # TODO: decide whether to settle them to 'gpu' or not?
66
+ self.nnum = torch.zeros(num, dtype=torch.int32)
67
+ self.nnum_nempty = torch.zeros(num, dtype=torch.int32)
68
+
69
+ # the following properties are valid after `merge_octrees`.
70
+ # TODO: make them valid after `octree_grow`, `octree_split` and `build_octree`
71
+ batch_size = self.batch_size
72
+ self.batch_nnum = torch.zeros(num, batch_size, dtype=torch.int32)
73
+ self.batch_nnum_nempty = torch.zeros(num, batch_size, dtype=torch.int32)
74
+
75
+ # construct the look up tables for neighborhood searching
76
+ device = self.device
77
+ center_grid = range_grid(2, 3, device) # (8, 3)
78
+ displacement = range_grid(-1, 1, device) # (27, 3)
79
+ neigh_grid = center_grid.unsqueeze(1) + displacement # (8, 27, 3)
80
+ parent_grid = trunc_div(neigh_grid, 2)
81
+ child_grid = neigh_grid % 2
82
+ self.lut_parent = torch.sum(
83
+ parent_grid * torch.tensor([9, 3, 1], device=device), dim=2)
84
+ self.lut_child = torch.sum(
85
+ child_grid * torch.tensor([4, 2, 1], device=device), dim=2)
86
+
87
+ # lookup tables for different kernel sizes
88
+ self.lut_kernel = {
89
+ '222': torch.tensor([13, 14, 16, 17, 22, 23, 25, 26], device=device),
90
+ '311': torch.tensor([4, 13, 22], device=device),
91
+ '131': torch.tensor([10, 13, 16], device=device),
92
+ '113': torch.tensor([12, 13, 14], device=device),
93
+ '331': torch.tensor([1, 4, 7, 10, 13, 16, 19, 22, 25], device=device),
94
+ '313': torch.tensor([3, 4, 5, 12, 13, 14, 21, 22, 23], device=device),
95
+ '133': torch.tensor([9, 10, 11, 12, 13, 14, 15, 16, 17], device=device),
96
+ }
97
+
98
+ def key(self, depth: int, nempty: bool = False):
99
+ r''' Returns the shuffled key of each octree node.
100
+
101
+ Args:
102
+ depth (int): The depth of the octree.
103
+ nempty (bool): If True, returns the results of non-empty octree nodes.
104
+ '''
105
+
106
+ key = self.keys[depth]
107
+ if nempty:
108
+ mask = self.nempty_mask(depth)
109
+ key = key[mask]
110
+ return key
111
+
112
+ def xyzb(self, depth: int, nempty: bool = False):
113
+ r''' Returns the xyz coordinates and the batch indices of each octree node.
114
+
115
+ Args:
116
+ depth (int): The depth of the octree.
117
+ nempty (bool): If True, returns the results of non-empty octree nodes.
118
+ '''
119
+
120
+ key = self.key(depth, nempty)
121
+ return key2xyz(key, depth)
122
+
123
+ def batch_id(self, depth: int, nempty: bool = False):
124
+ r''' Returns the batch indices of each octree node.
125
+
126
+ Args:
127
+ depth (int): The depth of the octree.
128
+ nempty (bool): If True, returns the results of non-empty octree nodes.
129
+ '''
130
+
131
+ batch_id = self.keys[depth] >> 48
132
+ if nempty:
133
+ mask = self.nempty_mask(depth)
134
+ batch_id = batch_id[mask]
135
+ return batch_id
136
+
137
+ def nempty_mask(self, depth: int):
138
+ r''' Returns a binary mask which indicates whether the cooreponding octree
139
+ node is empty or not.
140
+
141
+ Args:
142
+ depth (int): The depth of the octree.
143
+ '''
144
+
145
+ return self.children[depth] >= 0
146
+
147
+ def build_octree(self, point_cloud: Points):
148
+ r''' Builds an octree from a point cloud.
149
+
150
+ Args:
151
+ point_cloud (Points): The input point cloud.
152
+
153
+ .. note::
154
+ Currently, the batch size of the point cloud must be 1.
155
+ '''
156
+
157
+ self.device = point_cloud.device
158
+ assert point_cloud.batch_size == self.batch_size, 'Inconsistent batch_size'
159
+
160
+ # normalize points from [-1, 1] to [0, 2^depth]. #[L:Scale]
161
+ scale = 2 ** (self.depth - 1)
162
+ points = (point_cloud.points + 1.0) * scale
163
+
164
+ # get the shuffled key and sort
165
+ x, y, z = points[:, 0], points[:, 1], points[:, 2]
166
+ b = None if self.batch_size == 1 else point_cloud.batch_id.view(-1)
167
+ key = xyz2key(x, y, z, b, self.depth)
168
+ node_key, idx, counts = torch.unique(
169
+ key, sorted=True, return_inverse=True, return_counts=True)
170
+
171
+ # layer 0 to full_layer: the octree is full in these layers
172
+ for d in range(self.full_depth+1):
173
+ self.octree_grow_full(d, update_neigh=False)
174
+
175
+ # layer depth_ to full_layer_
176
+ for d in range(self.depth, self.full_depth, -1):
177
+ # compute parent key, i.e. keys of layer (d -1)
178
+ pkey = node_key >> 3
179
+ pkey, pidx, pcounts = torch.unique_consecutive(
180
+ pkey, return_inverse=True, return_counts=True)
181
+
182
+ # augmented key
183
+ key = (pkey.unsqueeze(-1) << 3) + torch.arange(8, device=self.device)
184
+ self.keys[d] = key.view(-1)
185
+ self.nnum[d] = key.numel()
186
+ self.nnum_nempty[d] = node_key.numel()
187
+
188
+ # children
189
+ addr = (pidx << 3) | (node_key % 8)
190
+ children = -torch.ones(
191
+ self.nnum[d].item(), dtype=torch.int32, device=self.device)
192
+ children[addr] = torch.arange(
193
+ self.nnum_nempty[d], dtype=torch.int32, device=self.device)
194
+ self.children[d] = children
195
+
196
+ # cache pkey for the next iteration
197
+ # Use `pkey >> 45` instead of `pkey >> 48` in L199 since pkey is already
198
+ # shifted to the right by 3 bits in L177
199
+ node_key = pkey if self.batch_size == 1 else \
200
+ ((pkey >> 45) << 48) | (pkey & ((1 << 45) - 1))
201
+
202
+ # set the children for the layer full_layer,
203
+ # now the node_keys are the key for full_layer
204
+ d = self.full_depth
205
+ children = -torch.ones_like(self.children[d])
206
+ nempty_idx = node_key if self.batch_size == 1 else \
207
+ ((node_key >> 48) << (3 * d)) | (node_key & ((1 << 48) - 1))
208
+ children[nempty_idx] = torch.arange(
209
+ node_key.numel(), dtype=torch.int32, device=self.device)
210
+ self.children[d] = children
211
+ self.nnum_nempty[d] = node_key.numel()
212
+
213
+ # average the signal for the last octree layer
214
+ d = self.depth
215
+ points = scatter_add(points, idx, dim=0) # points is rescaled in [L:Scale]
216
+ self.points[d] = points / counts.unsqueeze(1)
217
+ if point_cloud.normals is not None:
218
+ normals = scatter_add(point_cloud.normals, idx, dim=0)
219
+ self.normals[d] = F.normalize(normals)
220
+ if point_cloud.features is not None:
221
+ features = scatter_add(point_cloud.features, idx, dim=0)
222
+ self.features[d] = features / counts.unsqueeze(1)
223
+
224
+ return idx
225
+
226
def octree_grow_full(self, depth: int, update_neigh: bool = True):
  r''' Builds the full octree layer at :attr:`depth`, which is essentially a
  dense volumetric grid: every one of the :obj:`8^depth` nodes of each batch
  item exists and is non-empty.

  Args:
    depth (int): The depth of the octree layer to build. Must not exceed
        :attr:`full_depth`.
    update_neigh (bool): If True, construct the neighborhood indices.

  Raises:
    AssertionError: If :attr:`depth` is larger than :attr:`full_depth`.
  '''

  # Only the layers up to full_depth are allowed to be dense.
  assert depth <= self.full_depth, \
      'depth ({}) must not exceed full_depth ({})'.format(depth, self.full_depth)

  # node number: all nodes are non-empty in a full layer
  num = 1 << (3 * depth)
  self.nnum[depth] = num * self.batch_size
  self.nnum_nempty[depth] = num * self.batch_size

  # keys: node index in the low bits, batch index stored from bit 48 upward;
  # the (B, num) grid is flattened batch-major
  key = torch.arange(num, dtype=torch.long, device=self.device)
  bs = torch.arange(self.batch_size, dtype=torch.long, device=self.device)
  key = key.unsqueeze(0) | (bs.unsqueeze(1) << 48)
  self.keys[depth] = key.view(-1)

  # children: every node is non-empty, so its index among the non-empty
  # nodes is simply its own position
  self.children[depth] = torch.arange(
      num * self.batch_size, dtype=torch.int32, device=self.device)

  # update neigh if needed
  if update_neigh:
    self.construct_neigh(depth)
255
+
256
def octree_split(self, split: torch.Tensor, depth: int):
  r''' Sets whether the octree nodes in :attr:`depth` are split or not.

  Args:
    split (torch.Tensor): A tensor (expected to be 1-D) whose elements
        indicate the status of each octree node: 0 - empty, 1 - non-empty or
        split.
    depth (int): The depth of the octree layer to update.

  Updates :obj:`children[depth]` with each node's index among the non-empty
  nodes (-1 for empty nodes) and :obj:`nnum_nempty[depth]` with the number of
  non-empty nodes.
  '''

  # An exclusive prefix sum of the flags gives every node its index among the
  # non-empty nodes; the result has one extra trailing element (the total),
  # which is split off as the non-empty node count.
  prefix = cumsum(split, dim=0, exclusive=True)
  children, nnum_nempty = torch.split(prefix, [split.shape[0], 1])
  children[split == 0] = -1

  # Boundary case: make sure that at least one octree node is split, so that
  # growing the next layer (nnum_nempty * 8 nodes) never yields zero nodes.
  if nnum_nempty == 0:
    nnum_nempty = 1
    children[0] = 0

  # update octree
  self.children[depth] = children
  self.nnum_nempty[depth] = nnum_nempty
279
+
280
def octree_grow(self, depth: int, update_neigh: bool = True):
  r''' Grows layer :attr:`depth` of the octree from the non-empty nodes of
  layer :attr:`depth` - 1, creating exactly 8 children per non-empty parent.
  In most cases :func:`Octree.octree_split` should be called first so that the
  splitting status of the parent layer is up to date.

  Args:
    depth (int): The depth of the layer to create.
    update_neigh (bool): If True, construct the neighborhood indices.
  '''

  parent = depth - 1

  # All freshly grown nodes are treated as non-empty.
  nnum = self.nnum_nempty[parent] * 8
  self.nnum[depth] = self.nnum_nempty[depth] = nnum

  # Child keys: move the parent's node bits up by 3 and keep the batch bits
  # (bit 48 and above) where they are; then enumerate the 8 octants.
  parent_key = self.key(parent, nempty=True)
  low_bits = (parent_key & ((1 << 48) - 1)) << 3
  batch_bits = (parent_key >> 48) << 48
  octants = torch.arange(8, device=parent_key.device)
  self.keys[depth] = ((low_bits | batch_bits).unsqueeze(1) + octants).view(-1)

  # Every new node is non-empty, so its children index is its own position.
  self.children[depth] = torch.arange(
      nnum, dtype=torch.int32, device=self.device)

  if update_neigh:
    self.construct_neigh(depth)
310
+
311
def construct_neigh(self, depth: int):
  r''' Builds the :obj:`3x3x3` neighbor indices of every node in layer
  :attr:`depth` and stores them in :obj:`neighs[depth]` with shape
  :obj:`(num_nodes, 27)`. Missing neighbors are marked with -1.

  Args:
    depth (int): The octree depth with a value larger than 0 (:obj:`>0`).
  '''

  if depth > self.full_depth:
    # Derive this layer from the parent layer: for each non-empty parent, look
    # up the parent-level neighbors relevant to each of its 8 children via
    # `lut_parent`, then turn them into child indices with `lut_child`.
    # (The contents of both lookup tables are defined elsewhere.)
    parent_child = self.children[depth - 1]
    parent_neigh = self.neighs[depth - 1][parent_child >= 0]  # (N, 27)
    parent_neigh = parent_neigh[:, self.lut_parent]           # (N, 8, 27)
    # A -1 in parent_neigh wraps around to the last element here; those
    # entries are masked out right below.
    neigh_child = parent_child[parent_neigh]                  # (N, 8, 27)
    missing = (neigh_child < 0) | (parent_neigh < 0)
    neigh = neigh_child * 8 + self.lut_child
    neigh[missing] = -1
    self.neighs[depth] = neigh.view(-1, 27)
    return

  # Dense layer: compute the neighbors directly from the node coordinates.
  device = self.device
  nnum = 1 << (3 * depth)
  x, y, z, _ = key2xyz(
      torch.arange(nnum, dtype=torch.long, device=device), depth)
  offsets = range_grid(-1, 1, device)                            # (27, 3)
  xyz = (torch.stack([x, y, z], dim=-1).unsqueeze(1) + offsets)  # (N, 27, 3)
  xyz = xyz.view(-1, 3)                                          # (N*27, 3)
  neigh = xyz2key(xyz[:, 0], xyz[:, 1], xyz[:, 2], depth=depth)

  # Replicate for every batch item, shifting indices by nnum per item.
  batch = torch.arange(self.batch_size, dtype=torch.int32, device=device)
  neigh = neigh + batch.unsqueeze(1) * nnum  # (B, N*27)

  # Neighbors falling outside the [0, 2^depth) cube do not exist.
  bound = 1 << depth
  outside = torch.logical_or((xyz < 0).any(1), (xyz >= bound).any(1))
  neigh[:, outside] = -1
  self.neighs[depth] = neigh.view(-1, 27)  # (B*N, 27)
347
+
348
def construct_all_neigh(self):
  r''' Builds the neighborhood indices of every layer from depth 1 down to
  :attr:`depth` by calling :func:`construct_neigh` on each in turn.
  '''

  for layer in range(1, self.depth + 1):
    self.construct_neigh(layer)
354
+
355
def search_xyzb(self, query: torch.Tensor, depth: int, nempty: bool = False):
  r''' Searches the octree nodes given the query points.

  Args:
    query (torch.Tensor): The coordinates of query points with shape
        :obj:`(N, 4)`. The first 3 channels of the coordinates are :obj:`x`,
        :obj:`y`, and :obj:`z`, and the last channel is the batch index. Note
        that the coordinates must be in range :obj:`[0, 2^depth)`.
    depth (int): The depth of the octree layer.
    nempty (bool): If True, only searches the non-empty octree nodes.

  Returns:
    torch.Tensor: The index of the matching node for every query point, or
        -1 where no node matches (see :func:`search_key`).
  '''

  # Convert the coordinates into keys, then reuse the key-based search.
  key = xyz2key(query[:, 0], query[:, 1], query[:, 2], query[:, 3], depth)
  idx = self.search_key(key, depth, nempty)
  return idx
370
+
371
def search_key(self, query: torch.Tensor, depth: int, nempty: bool = False):
  r''' Searches the octree nodes given the keys of query points.

  Args:
    query (torch.Tensor): The keys of query points with shape :obj:`(N,)`,
        which are computed from the coordinates of query points.
    depth (int): The depth of the octree layer.
    nempty (bool): If True, only searches the non-empty octree nodes.

  Returns:
    torch.Tensor: For every query key, the index of the node with exactly
        that key, or -1 if no such node exists.
  '''

  key = self.key(depth, nempty)
  # `torch.bucketize` is similar to `torch.searchsorted`.
  # I choose `torch.bucketize` here because it has fewer dimension checks,
  # resulting in slightly better performance according to the docs of
  # pytorch-1.9.1, since `key` is always 1-D sorted sequence.
  # With the default right=False it returns the first position whose key is
  # >= the query, i.e. the position of the query if it is present.
  idx = torch.bucketize(query, key)

  valid = idx < key.shape[0]  # valid if NOT out-of-bound
  found = key[idx[valid]] == query[valid]
  valid[valid.clone()] = found  # valid if found (clone: mask read while written)
  idx[valid.logical_not()] = -1  # set to -1 if invalid
  return idx
393
+
394
def get_neigh(self, depth: int, kernel: str = '333', stride: int = 1,
              nempty: bool = False):
  r''' Returns the neighbor indices of layer :attr:`depth` for a given kernel
  shape.

  Args:
    depth (int): The octree depth with a value larger than 0 (:obj:`>0`).
    kernel (str): The kernel shape from :obj:`333`, :obj:`311`, :obj:`131`,
        :obj:`113`, :obj:`222`, :obj:`331`, :obj:`133`, and :obj:`313`.
    stride (int): The stride of neighborhoods (:obj:`1` or :obj:`2`). With a
        stride of :obj:`2`, only the neighborhood of the first sibling in
        every group of 8 is returned.
    nempty (bool): If True, only returns the neighborhoods of the non-empty
        octree nodes, with indices remapped to non-empty node indices.

  Raises:
    ValueError: If :attr:`stride` or :attr:`kernel` is unsupported.
  '''

  if stride not in (1, 2):
    raise ValueError('Unsupported stride {}'.format(stride))

  neigh = self.neighs[depth]
  if stride == 2:
    # Clone so the in-place remapping below cannot modify self.neighs[depth].
    neigh = neigh[::8].clone()

  if nempty:
    child = self.children[depth]
    if stride == 1:
      # Boolean indexing already yields a copy, so no clone is needed here.
      neigh = neigh[child >= 0]
    existing = neigh >= 0
    neigh[existing] = child[neigh[existing]].long()  # remap the index

  if kernel == '333':
    return neigh
  if kernel not in self.lut_kernel:
    raise ValueError('Unsupported kernel {}'.format(kernel))
  return neigh[:, self.lut_kernel[kernel]]
432
+
433
def get_input_feature(self, feature: str, nempty: bool = False):
  r''' Assembles the input feature of the deepest octree layer from the
  signals stored in the octree, concatenated along the channel dimension in
  the fixed order N, D, L, P, F (independent of the order of the letters).

  Args:
    feature (str): Letters (case-insensitive) selecting the signals:
        :obj:`N` - normals (3 channels); :obj:`D` - local displacement, the
        dot product of the local coordinates and the normal (1 channel);
        :obj:`L` - local coordinates of the averaged points relative to the
        node center (3 channels); :obj:`P` - global coordinates rescaled to
        :obj:`[-1, 1]` (3 channels); :obj:`F` - other features such as colors
        (k channels).
    nempty (bool): If False, the features are padded to all octree nodes;
        otherwise only the non-empty nodes are returned.
  '''

  d = self.depth
  flags = feature.upper()
  parts = []

  if 'N' in flags:
    parts.append(self.normals[d])

  if 'L' in flags or 'D' in flags:
    # Fractional part minus 0.5: offset from the node center (for
    # non-negative coordinates this lies in [-0.5, 0.5)).
    local = self.points[d].frac() - 0.5

  if 'D' in flags:
    parts.append((local * self.normals[d]).sum(dim=1, keepdim=True))

  if 'L' in flags:
    parts.append(local)

  if 'P' in flags:
    # [0, 2^d] -> [-1, 1]
    parts.append(self.points[d] * 2 ** (1 - d) - 1.0)

  if 'F' in flags:
    parts.append(self.features[d])

  result = torch.cat(parts, dim=1)
  if not nempty:
    result = ocnn.nn.octree_pad(result, self, d)
  return result
477
+
478
def to_points(self, rescale: bool = True):
  r''' Converts the points of the deepest octree layer to a point cloud.

  The averaged points generated when building the octree from an input point
  cloud are used when available. If they are absent (e.g. the octree was
  predicted by a neural network), the centers of the non-empty nodes are used
  instead.

  Args:
    rescale (bool): rescale the xyz coordinates from :obj:`[0, 2^depth]` to
        :obj:`[-1, 1]` if True.

  Returns:
    Points: The point cloud, carrying the normals, features and batch ids of
        the deepest layer.
  '''

  depth = self.depth
  batch_size = self.batch_size

  xyz = self.points[depth]
  if xyz is None:
    # No stored points: use the node centers of the non-empty nodes.
    x, y, z, batch_id = self.xyzb(depth, nempty=True)
    xyz = torch.stack([x, y, z], dim=1) + 0.5
  else:
    batch_id = self.batch_id(depth, nempty=True)

  # normalize xyz to [-1, 1] since the points are in range [0, 2^d]. Use the
  # local `xyz` (not self.points[depth], which may be None on the path above).
  if rescale:
    scale = 2 ** (1 - depth)
    xyz = xyz * scale - 1.0

  # construct Points
  out = Points(xyz, self.normals[depth], self.features[depth],
               batch_id=batch_id, batch_size=batch_size)
  return out
507
+
508
def to(self, device: Union[torch.device, str], non_blocking: bool = False):
  r''' Moves the octree to a specified device.

  Args:
    device (torch.device or str): The destination device.
    non_blocking (bool): If True and the source is in pinned memory, the copy
        will be asynchronous with respect to the host. Otherwise, the argument
        has no effect. Default: False.

  Returns:
    Octree: ``self`` if it already lives on :attr:`device`; otherwise a new
        octree whose tensors have been copied to :attr:`device`.
  '''

  target = torch.device(device) if isinstance(device, str) else device

  # If already on the same device, directly return self.
  if self.device == target:
    return self

  def _move(items):
    # Non-tensor entries (e.g. unset layers) become None.
    return [item.to(target, non_blocking=non_blocking)
            if isinstance(item, torch.Tensor) else None for item in items]

  # Construct a new Octree on the specified device.
  octree = Octree(self.depth, self.full_depth, self.batch_size, target)
  for name in ('keys', 'children', 'neighs', 'features', 'normals', 'points'):
    setattr(octree, name, _move(getattr(self, name)))

  # TODO: whether to move nnum to the self.device?
  for name in ('nnum', 'nnum_nempty', 'batch_nnum', 'batch_nnum_nempty'):
    setattr(octree, name, getattr(self, name).clone())
  return octree
542
+
543
def cuda(self, non_blocking: bool = False):
  r''' Moves the octree to the GPU by delegating to :meth:`to` with the
  device ``'cuda'``. '''

  target = 'cuda'
  return self.to(target, non_blocking)
547
+
548
def cpu(self):
  r''' Moves the octree to the CPU by delegating to :meth:`to` with the
  device ``'cpu'``. '''

  target = 'cpu'
  return self.to(target)
552
+
553
+
554
def merge_octrees(octrees: List['Octree']):
  r''' Merges a list of octrees into one batched octree.

  All octrees must share the same depth, full depth and device. In the merged
  octree, the keys of the i-th octree carry batch index i from bit 48 upward,
  its children indices are offset by the number of non-empty nodes of the
  preceding octrees in the same layer, and the features, normals and points
  of the deepest layer are concatenated.

  Args:
    octrees (List[Octree]): A list of octrees to merge.

  Returns:
    Octree: The merged octree.
  '''

  head = octrees[0]
  octree = Octree(depth=head.depth, full_depth=head.full_depth,
                  batch_size=len(octrees), device=head.device)
  members = [octrees[i] for i in range(octree.batch_size)]

  # All members must be compatible with the first one.
  for other in members[1:]:
    condition = (other.depth == octree.depth and
                 other.full_depth == octree.full_depth and
                 other.device == octree.device)
    assert condition, 'The check of merge_octrees failed'

  # Node numbers: per-octree counts are stacked along dim 1 and summed.
  batch_nnum = torch.stack([o.nnum for o in members], dim=1)
  batch_nnum_nempty = torch.stack([o.nnum_nempty for o in members], dim=1)
  octree.nnum = batch_nnum.sum(dim=1)
  octree.nnum_nempty = batch_nnum_nempty.sum(dim=1)
  octree.batch_nnum = batch_nnum
  octree.batch_nnum_nempty = batch_nnum_nempty
  offsets = cumsum(batch_nnum_nempty, dim=1, exclusive=True)

  low_mask = (1 << 48) - 1
  for d in range(octree.depth + 1):
    # Keys: clear the old batch bits and write the new batch index.
    octree.keys[d] = torch.cat(
        [(o.keys[d] & low_mask) | (i << 48) for i, o in enumerate(members)],
        dim=0)

    # Children: shift the non-empty indices by the running offset. The clone
    # keeps the input octrees unmodified.
    shifted = []
    for i, o in enumerate(members):
      child = o.children[d].clone()
      nonempty = child >= 0
      child[nonempty] = child[nonempty] + offsets[d, i]
      shifted.append(child)
    octree.children[d] = torch.cat(shifted, dim=0)

    # Signals only exist on the deepest layer.
    if d == octree.depth:
      for name in ('features', 'normals', 'points'):
        if getattr(head, name)[d] is not None:
          getattr(octree, name)[d] = torch.cat(
              [getattr(o, name)[d] for o in members], dim=0)

  return octree
618
+
619
+
620
def init_octree(depth: int, full_depth: int = 2, batch_size: int = 1,
                device: Union[torch.device, str] = 'cpu'):
  r''' Creates an octree whose layers 0 through :attr:`full_depth` are dense.

  Args:
    depth (int): The depth of the octree.
    full_depth (int): The octree layers with a depth no larger than
        :attr:`full_depth` are forced to be full.
    batch_size (int, optional): The batch size.
    device (torch.device or str): The device to use for computation.

  Returns:
    Octree: The initialized Octree object.
  '''

  result = Octree(depth, full_depth, batch_size, device)
  for layer in range(full_depth + 1):
    # octree_grow_full also builds the neighbor indices by default.
    result.octree_grow_full(depth=layer)
  return result