ocnn 2.2.7__py3-none-any.whl → 2.2.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ocnn/octree/octree.py CHANGED
@@ -1,661 +1,770 @@
1
- # --------------------------------------------------------
2
- # Octree-based Sparse Convolutional Neural Networks
3
- # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
- # Licensed under The MIT License [see LICENSE for details]
5
- # Written by Peng-Shuai Wang
6
- # --------------------------------------------------------
7
-
8
- import torch
9
- import torch.nn.functional as F
10
- from typing import Union, List
11
-
12
- import ocnn
13
- from ocnn.octree.points import Points
14
- from ocnn.octree.shuffled_key import xyz2key, key2xyz
15
- from ocnn.utils import range_grid, scatter_add, cumsum, trunc_div
16
-
17
-
18
- class Octree:
19
- r''' Builds an octree from an input point cloud.
20
-
21
- Args:
22
- depth (int): The octree depth.
23
- full_depth (int): The octree layers with a depth small than
24
- :attr:`full_depth` are forced to be full.
25
- batch_size (int): The octree batch size.
26
- device (torch.device or str): Choose from :obj:`cpu` and :obj:`gpu`.
27
- (default: :obj:`cpu`)
28
-
29
- .. note::
30
- The octree data structure requires that if an octree node has children nodes,
31
- the number of children nodes is exactly 8, in which some of the nodes are
32
- empty and some nodes are non-empty. The properties of an octree, including
33
- :obj:`keys`, :obj:`children` and :obj:`neighs`, contain both non-empty and
34
- empty nodes, and other properties, including :obj:`features`, :obj:`normals`
35
- and :obj:`points`, contain only non-empty nodes.
36
-
37
- .. note::
38
- The point cloud must be strictly in range :obj:`[-1, 1]`. A good practice
39
- is to normalize it into :obj:`[-0.99, 0.99]` or :obj:`[0.9, 0.9]` to retain
40
- some margin.
41
- '''
42
-
43
- def __init__(self, depth: int, full_depth: int = 2, batch_size: int = 1,
44
- device: Union[torch.device, str] = 'cpu', **kwargs):
45
- super().__init__()
46
- self.depth = depth
47
- self.full_depth = full_depth
48
- self.batch_size = batch_size
49
- self.device = device
50
-
51
- self.reset()
52
-
53
- def reset(self):
54
- r''' Resets the Octree status and constructs several lookup tables.
55
- '''
56
-
57
- # octree features in each octree layers
58
- num = self.depth + 1
59
- self.keys = [None] * num
60
- self.children = [None] * num
61
- self.neighs = [None] * num
62
- self.features = [None] * num
63
- self.normals = [None] * num
64
- self.points = [None] * num
65
-
66
- # octree node numbers in each octree layers.
67
- # TODO: decide whether to settle them to 'gpu' or not?
68
- self.nnum = torch.zeros(num, dtype=torch.long)
69
- self.nnum_nempty = torch.zeros(num, dtype=torch.long)
70
-
71
- # the following properties are valid after `merge_octrees`.
72
- # TODO: make them valid after `octree_grow`, `octree_split` and `build_octree`
73
- batch_size = self.batch_size
74
- self.batch_nnum = torch.zeros(num, batch_size, dtype=torch.long)
75
- self.batch_nnum_nempty = torch.zeros(num, batch_size, dtype=torch.long)
76
-
77
- # construct the look up tables for neighborhood searching
78
- device = self.device
79
- center_grid = range_grid(2, 3, device) # (8, 3)
80
- displacement = range_grid(-1, 1, device) # (27, 3)
81
- neigh_grid = center_grid.unsqueeze(1) + displacement # (8, 27, 3)
82
- parent_grid = trunc_div(neigh_grid, 2)
83
- child_grid = neigh_grid % 2
84
- self.lut_parent = torch.sum(
85
- parent_grid * torch.tensor([9, 3, 1], device=device), dim=2)
86
- self.lut_child = torch.sum(
87
- child_grid * torch.tensor([4, 2, 1], device=device), dim=2)
88
-
89
- # lookup tables for different kernel sizes
90
- self.lut_kernel = {
91
- '222': torch.tensor([13, 14, 16, 17, 22, 23, 25, 26], device=device),
92
- '311': torch.tensor([4, 13, 22], device=device),
93
- '131': torch.tensor([10, 13, 16], device=device),
94
- '113': torch.tensor([12, 13, 14], device=device),
95
- '331': torch.tensor([1, 4, 7, 10, 13, 16, 19, 22, 25], device=device),
96
- '313': torch.tensor([3, 4, 5, 12, 13, 14, 21, 22, 23], device=device),
97
- '133': torch.tensor([9, 10, 11, 12, 13, 14, 15, 16, 17], device=device),
98
- }
99
-
100
- def key(self, depth: int, nempty: bool = False):
101
- r''' Returns the shuffled key of each octree node.
102
-
103
- Args:
104
- depth (int): The depth of the octree.
105
- nempty (bool): If True, returns the results of non-empty octree nodes.
106
- '''
107
-
108
- key = self.keys[depth]
109
- if nempty:
110
- mask = self.nempty_mask(depth)
111
- key = key[mask]
112
- return key
113
-
114
- def xyzb(self, depth: int, nempty: bool = False):
115
- r''' Returns the xyz coordinates and the batch indices of each octree node.
116
-
117
- Args:
118
- depth (int): The depth of the octree.
119
- nempty (bool): If True, returns the results of non-empty octree nodes.
120
- '''
121
-
122
- key = self.key(depth, nempty)
123
- return key2xyz(key, depth)
124
-
125
- def batch_id(self, depth: int, nempty: bool = False):
126
- r''' Returns the batch indices of each octree node.
127
-
128
- Args:
129
- depth (int): The depth of the octree.
130
- nempty (bool): If True, returns the results of non-empty octree nodes.
131
- '''
132
-
133
- batch_id = self.keys[depth] >> 48
134
- if nempty:
135
- mask = self.nempty_mask(depth)
136
- batch_id = batch_id[mask]
137
- return batch_id
138
-
139
- def nempty_mask(self, depth: int):
140
- r''' Returns a binary mask which indicates whether the cooreponding octree
141
- node is empty or not.
142
-
143
- Args:
144
- depth (int): The depth of the octree.
145
- '''
146
-
147
- return self.children[depth] >= 0
148
-
149
- def build_octree(self, point_cloud: Points):
150
- r''' Builds an octree from a point cloud.
151
-
152
- Args:
153
- point_cloud (Points): The input point cloud.
154
-
155
- .. note::
156
- The point cloud must be strictly in range :obj:`[-1, 1]`. A good practice
157
- is to normalize it into :obj:`[-0.99, 0.99]` or :obj:`[0.9, 0.9]` to retain
158
- some margin.
159
- '''
160
-
161
- self.device = point_cloud.device
162
- assert point_cloud.batch_size == self.batch_size, 'Inconsistent batch_size'
163
-
164
- # normalize points from [-1, 1] to [0, 2^depth]. #[L:Scale]
165
- scale = 2 ** (self.depth - 1)
166
- points = (point_cloud.points + 1.0) * scale
167
-
168
- # get the shuffled key and sort
169
- x, y, z = points[:, 0], points[:, 1], points[:, 2]
170
- b = None if self.batch_size == 1 else point_cloud.batch_id.view(-1)
171
- key = xyz2key(x, y, z, b, self.depth)
172
- node_key, idx, counts = torch.unique(
173
- key, sorted=True, return_inverse=True, return_counts=True)
174
-
175
- # layer 0 to full_layer: the octree is full in these layers
176
- for d in range(self.full_depth+1):
177
- self.octree_grow_full(d, update_neigh=False)
178
-
179
- # layer depth_ to full_layer_
180
- for d in range(self.depth, self.full_depth, -1):
181
- # compute parent key, i.e. keys of layer (d -1)
182
- pkey = node_key >> 3
183
- pkey, pidx, _ = torch.unique_consecutive(
184
- pkey, return_inverse=True, return_counts=True)
185
-
186
- # augmented key
187
- key = (pkey.unsqueeze(-1) << 3) + torch.arange(8, device=self.device)
188
- self.keys[d] = key.view(-1)
189
- self.nnum[d] = key.numel()
190
- self.nnum_nempty[d] = node_key.numel()
191
-
192
- # children
193
- addr = (pidx << 3) | (node_key % 8)
194
- children = -torch.ones(
195
- self.nnum[d].item(), dtype=torch.int32, device=self.device)
196
- children[addr] = torch.arange(
197
- self.nnum_nempty[d], dtype=torch.int32, device=self.device)
198
- self.children[d] = children
199
-
200
- # cache pkey for the next iteration
201
- # Use `pkey >> 45` instead of `pkey >> 48` in L199 since pkey is already
202
- # shifted to the right by 3 bits in L177
203
- node_key = pkey if self.batch_size == 1 else \
204
- ((pkey >> 45) << 48) | (pkey & ((1 << 45) - 1))
205
-
206
- # set the children for the layer full_layer,
207
- # now the node_keys are the key for full_layer
208
- d = self.full_depth
209
- children = -torch.ones_like(self.children[d])
210
- nempty_idx = node_key if self.batch_size == 1 else \
211
- ((node_key >> 48) << (3 * d)) | (node_key & ((1 << 48) - 1))
212
- children[nempty_idx] = torch.arange(
213
- node_key.numel(), dtype=torch.int32, device=self.device)
214
- self.children[d] = children
215
- self.nnum_nempty[d] = node_key.numel()
216
-
217
- # average the signal for the last octree layer
218
- d = self.depth
219
- points = scatter_add(points, idx, dim=0) # points is rescaled in [L:Scale]
220
- self.points[d] = points / counts.unsqueeze(1)
221
- if point_cloud.normals is not None:
222
- normals = scatter_add(point_cloud.normals, idx, dim=0)
223
- self.normals[d] = F.normalize(normals)
224
- if point_cloud.features is not None:
225
- features = scatter_add(point_cloud.features, idx, dim=0)
226
- self.features[d] = features / counts.unsqueeze(1)
227
-
228
- return idx
229
-
230
- def octree_grow_full(self, depth: int, update_neigh: bool = True):
231
- r''' Builds the full octree, which is essentially a dense volumetric grid.
232
-
233
- Args:
234
- depth (int): The depth of the octree.
235
- update_neigh (bool): If True, construct the neighborhood indices.
236
- '''
237
-
238
- # check
239
- assert depth <= self.full_depth, 'error'
240
-
241
- # node number
242
- num = 1 << (3 * depth)
243
- self.nnum[depth] = num * self.batch_size
244
- self.nnum_nempty[depth] = num * self.batch_size
245
-
246
- # update key
247
- key = torch.arange(num, dtype=torch.long, device=self.device)
248
- bs = torch.arange(self.batch_size, dtype=torch.long, device=self.device)
249
- key = key.unsqueeze(0) | (bs.unsqueeze(1) << 48)
250
- self.keys[depth] = key.view(-1)
251
-
252
- # update children
253
- self.children[depth] = torch.arange(
254
- num * self.batch_size, dtype=torch.int32, device=self.device)
255
-
256
- # update neigh if needed
257
- if update_neigh:
258
- self.construct_neigh(depth)
259
-
260
- def octree_split(self, split: torch.Tensor, depth: int):
261
- r''' Sets whether the octree nodes in :attr:`depth` are splitted or not.
262
-
263
- Args:
264
- split (torch.Tensor): The input tensor with its element indicating status
265
- of each octree node: 0 - empty, 1 - non-empty or splitted.
266
- depth (int): The depth of current octree.
267
- '''
268
-
269
- # split -> children
270
- empty = split == 0
271
- sum = cumsum(split, dim=0, exclusive=True)
272
- children, nnum_nempty = torch.split(sum, [split.shape[0], 1])
273
- children[empty] = -1
274
-
275
- # boundary case, make sure that at least one octree node is splitted
276
- if nnum_nempty == 0:
277
- nnum_nempty = 1
278
- children[0] = 0
279
-
280
- # update octree
281
- self.children[depth] = children.int()
282
- self.nnum_nempty[depth] = nnum_nempty
283
-
284
- def octree_grow(self, depth: int, update_neigh: bool = True):
285
- r''' Grows the octree and updates the relevant properties. And in most
286
- cases, call :func:`Octree.octree_split` to update the splitting status of
287
- the octree before this function.
288
-
289
- Args:
290
- depth (int): The depth of the octree.
291
- update_neigh (bool): If True, construct the neighborhood indices.
292
- '''
293
-
294
- # increase the octree depth if required
295
- if depth > self.depth:
296
- assert depth == self.depth + 1
297
- self.depth = depth
298
- self.keys.append(None)
299
- self.children.append(None)
300
- self.neighs.append(None)
301
- self.features.append(None)
302
- self.normals.append(None)
303
- self.points.append(None)
304
- zero = torch.zeros(1, dtype=torch.long)
305
- self.nnum = torch.cat([self.nnum, zero])
306
- self.nnum_nempty = torch.cat([self.nnum_nempty, zero])
307
- zero = zero.view(1, 1)
308
- self.batch_nnum = torch.cat([self.batch_nnum, zero], dim=0)
309
- self.batch_nnum_nempty = torch.cat([self.batch_nnum_nempty, zero], dim=0)
310
-
311
- # node number
312
- nnum = self.nnum_nempty[depth-1] * 8
313
- self.nnum[depth] = nnum
314
- self.nnum_nempty[depth] = nnum # initialize self.nnum_nempty
315
-
316
- # update keys
317
- key = self.key(depth-1, nempty=True)
318
- batch_id = (key >> 48) << 48
319
- key = (key & ((1 << 48) - 1)) << 3
320
- key = key | batch_id
321
- key = key.unsqueeze(1) + torch.arange(8, device=key.device)
322
- self.keys[depth] = key.view(-1)
323
-
324
- # update children
325
- self.children[depth] = torch.arange(
326
- nnum, dtype=torch.int32, device=self.device)
327
-
328
- # update neighs
329
- if update_neigh:
330
- self.construct_neigh(depth)
331
-
332
- def construct_neigh(self, depth: int):
333
- r''' Constructs the :obj:`3x3x3` neighbors for each octree node.
334
-
335
- Args:
336
- depth (int): The octree depth with a value larger than 0 (:obj:`>0`).
337
- '''
338
-
339
- if depth <= self.full_depth:
340
- device = self.device
341
- nnum = 1 << (3 * depth)
342
- key = torch.arange(nnum, dtype=torch.long, device=device)
343
- x, y, z, _ = key2xyz(key, depth)
344
- xyz = torch.stack([x, y, z], dim=-1) # (N, 3)
345
- grid = range_grid(-1, 1, device) # (27, 3)
346
- xyz = xyz.unsqueeze(1) + grid # (N, 27, 3)
347
- xyz = xyz.view(-1, 3) # (N*27, 3)
348
- neigh = xyz2key(xyz[:, 0], xyz[:, 1], xyz[:, 2], depth=depth)
349
-
350
- bs = torch.arange(self.batch_size, dtype=torch.long, device=device)
351
- neigh = neigh + bs.unsqueeze(1) * nnum # (N*27,) + (B, 1) -> (B, N*27)
352
-
353
- bound = 1 << depth
354
- invalid = torch.logical_or((xyz < 0).any(1), (xyz >= bound).any(1))
355
- neigh[:, invalid] = -1
356
- self.neighs[depth] = neigh.view(-1, 27) # (B*N, 27)
357
-
358
- else:
359
- child_p = self.children[depth-1]
360
- nempty = child_p >= 0
361
- neigh_p = self.neighs[depth-1][nempty] # (N, 27)
362
- neigh_p = neigh_p[:, self.lut_parent] # (N, 8, 27)
363
- child_p = child_p[neigh_p] # (N, 8, 27)
364
- invalid = torch.logical_or(child_p < 0, neigh_p < 0) # (N, 8, 27)
365
- neigh = child_p * 8 + self.lut_child
366
- neigh[invalid] = -1
367
- self.neighs[depth] = neigh.view(-1, 27)
368
-
369
- def construct_all_neigh(self):
370
- r''' A convenient handler for constructing all neighbors.
371
- '''
372
-
373
- for depth in range(1, self.depth+1):
374
- self.construct_neigh(depth)
375
-
376
- def search_xyzb(self, query: torch.Tensor, depth: int, nempty: bool = False):
377
- r''' Searches the octree nodes given the query points.
378
-
379
- Args:
380
- query (torch.Tensor): The coordinates of query points with shape
381
- :obj:`(N, 4)`. The first 3 channels of the coordinates are :obj:`x`,
382
- :obj:`y`, and :obj:`z`, and the last channel is the batch index. Note
383
- that the coordinates must be in range :obj:`[0, 2^depth)`.
384
- depth (int): The depth of the octree layer. nemtpy (bool): If true, only
385
- searches the non-empty octree nodes.
386
- '''
387
-
388
- key = xyz2key(query[:, 0], query[:, 1], query[:, 2], query[:, 3], depth)
389
- idx = self.search_key(key, depth, nempty)
390
- return idx
391
-
392
- def search_key(self, query: torch.Tensor, depth: int, nempty: bool = False):
393
- r''' Searches the octree nodes given the query points.
394
-
395
- Args:
396
- query (torch.Tensor): The keys of query points with shape :obj:`(N,)`,
397
- which are computed from the coordinates of query points.
398
- depth (int): The depth of the octree layer. nemtpy (bool): If true, only
399
- searches the non-empty octree nodes.
400
- '''
401
-
402
- key = self.key(depth, nempty)
403
- # `torch.bucketize` is similar to `torch.searchsorted`.
404
- # I choose `torch.bucketize` here because it has fewer dimension checks,
405
- # resulting in slightly better performance according to the docs of
406
- # pytorch-1.9.1, since `key` is always 1-D sorted sequence.
407
- # https://pytorch.org/docs/1.9.1/generated/torch.searchsorted.html
408
- idx = torch.bucketize(query, key)
409
-
410
- valid = idx < key.shape[0] # valid if in-bound
411
- found = key[idx[valid]] == query[valid]
412
- valid[valid.clone()] = found # valid if found
413
- idx[valid.logical_not()] = -1 # set to -1 if invalid
414
- return idx
415
-
416
- def get_neigh(self, depth: int, kernel: str = '333', stride: int = 1,
417
- nempty: bool = False):
418
- r''' Returns the neighborhoods given the depth and a kernel shape.
419
-
420
- Args:
421
- depth (int): The octree depth with a value larger than 0 (:obj:`>0`).
422
- kernel (str): The kernel shape from :obj:`333`, :obj:`311`, :obj:`131`,
423
- :obj:`113`, :obj:`222`, :obj:`331`, :obj:`133`, and :obj:`313`.
424
- stride (int): The stride of neighborhoods (:obj:`1` or :obj:`2`). If the
425
- stride is :obj:`2`, always returns the neighborhood of the first
426
- siblings.
427
- nempty (bool): If True, only returns the neighborhoods of the non-empty
428
- octree nodes.
429
- '''
430
-
431
- if stride == 1:
432
- neigh = self.neighs[depth]
433
- elif stride == 2:
434
- # clone neigh to avoid self.neigh[depth] being modified
435
- neigh = self.neighs[depth][::8].clone()
436
- else:
437
- raise ValueError('Unsupported stride {}'.format(stride))
438
-
439
- if nempty:
440
- child = self.children[depth]
441
- if stride == 1:
442
- nempty_node = child >= 0
443
- neigh = neigh[nempty_node]
444
- valid = neigh >= 0
445
- neigh[valid] = child[neigh[valid]].long() # remap the index
446
-
447
- if kernel == '333':
448
- return neigh
449
- elif kernel in self.lut_kernel:
450
- lut = self.lut_kernel[kernel]
451
- return neigh[:, lut]
452
- else:
453
- raise ValueError('Unsupported kernel {}'.format(kernel))
454
-
455
- def get_input_feature(self, feature: str, nempty: bool = False):
456
- r''' Returns the initial input feature stored in octree.
457
-
458
- Args:
459
- feature (str): A string used to indicate which features to extract from
460
- the input octree. If the character :obj:`N` is in :attr:`feature`, the
461
- normal signal is extracted (3 channels). Similarly, if :obj:`D` is in
462
- :attr:`feature`, the local displacement is extracted (1 channels). If
463
- :obj:`L` is in :attr:`feature`, the local coordinates of the averaged
464
- points in each octree node is extracted (3 channels). If :attr:`P` is
465
- in :attr:`feature`, the global coordinates are extracted (3 channels).
466
- If :attr:`F` is in :attr:`feature`, other features (like colors) are
467
- extracted (k channels).
468
- nempty (bool): If false, gets the features of all octree nodes.
469
- '''
470
-
471
- features = list()
472
- depth = self.depth
473
- feature = feature.upper()
474
- if 'N' in feature:
475
- features.append(self.normals[depth])
476
-
477
- if 'L' in feature or 'D' in feature:
478
- local_points = self.points[depth].frac() - 0.5
479
-
480
- if 'D' in feature:
481
- dis = torch.sum(local_points * self.normals[depth], dim=1, keepdim=True)
482
- features.append(dis)
483
-
484
- if 'L' in feature:
485
- features.append(local_points)
486
-
487
- if 'P' in feature:
488
- scale = 2 ** (1 - depth) # normalize [0, 2^depth] -> [-1, 1]
489
- global_points = self.points[depth] * scale - 1.0
490
- features.append(global_points)
491
-
492
- if 'F' in feature:
493
- features.append(self.features[depth])
494
-
495
- out = torch.cat(features, dim=1)
496
- if not nempty:
497
- out = ocnn.nn.octree_pad(out, self, depth)
498
- return out
499
-
500
- def to_points(self, rescale: bool = True):
501
- r''' Converts averaged points in the octree to a point cloud.
502
-
503
- Args:
504
- rescale (bool): rescale the xyz coordinates to [-1, 1] if True.
505
- '''
506
-
507
- depth = self.depth
508
- batch_size = self.batch_size
509
-
510
- # by default, use the average points generated when building the octree
511
- # from the input point cloud
512
- xyz = self.points[depth]
513
- batch_id = self.batch_id(depth, nempty=True)
514
-
515
- # xyz is None when the octree is predicted by a neural network
516
- if xyz is None:
517
- x, y, z, batch_id = self.xyzb(depth, nempty=True)
518
- xyz = torch.stack([x, y, z], dim=1) + 0.5
519
-
520
- # normalize xyz to [-1, 1] since the average points are in range [0, 2^d]
521
- if rescale:
522
- scale = 2 ** (1 - depth)
523
- xyz = xyz * scale - 1.0
524
-
525
- # construct Points
526
- out = Points(xyz, self.normals[depth], self.features[depth],
527
- batch_id=batch_id, batch_size=batch_size)
528
- return out
529
-
530
- def to(self, device: Union[torch.device, str], non_blocking: bool = False):
531
- r''' Moves the octree to a specified device.
532
-
533
- Args:
534
- device (torch.device or str): The destination device.
535
- non_blocking (bool): If True and the source is in pinned memory, the copy
536
- will be asynchronous with respect to the host. Otherwise, the argument
537
- has no effect. Default: False.
538
- '''
539
-
540
- if isinstance(device, str):
541
- device = torch.device(device)
542
-
543
- # If on the save device, directly retrun self
544
- if self.device == device:
545
- return self
546
-
547
- def list_to_device(prop):
548
- return [p.to(device, non_blocking=non_blocking)
549
- if isinstance(p, torch.Tensor) else None for p in prop]
550
-
551
- # Construct a new Octree on the specified device
552
- octree = Octree(self.depth, self.full_depth, self.batch_size, device)
553
- octree.keys = list_to_device(self.keys)
554
- octree.children = list_to_device(self.children)
555
- octree.neighs = list_to_device(self.neighs)
556
- octree.features = list_to_device(self.features)
557
- octree.normals = list_to_device(self.normals)
558
- octree.points = list_to_device(self.points)
559
- octree.nnum = self.nnum.clone() # TODO: whether to move nnum to the self.device?
560
- octree.nnum_nempty = self.nnum_nempty.clone()
561
- octree.batch_nnum = self.batch_nnum.clone()
562
- octree.batch_nnum_nempty = self.batch_nnum_nempty.clone()
563
- return octree
564
-
565
- def cuda(self, non_blocking: bool = False):
566
- r''' Moves the octree to the GPU. '''
567
-
568
- return self.to('cuda', non_blocking)
569
-
570
- def cpu(self):
571
- r''' Moves the octree to the CPU. '''
572
-
573
- return self.to('cpu')
574
-
575
-
576
- def merge_octrees(octrees: List['Octree']):
577
- r''' Merges a list of octrees into one batch.
578
-
579
- Args:
580
- octrees (List[Octree]): A list of octrees to merge.
581
-
582
- Returns:
583
- Octree: The merged octree.
584
- '''
585
-
586
- # init and check
587
- octree = Octree(depth=octrees[0].depth, full_depth=octrees[0].full_depth,
588
- batch_size=len(octrees), device=octrees[0].device)
589
- for i in range(1, octree.batch_size):
590
- condition = (octrees[i].depth == octree.depth and
591
- octrees[i].full_depth == octree.full_depth and
592
- octrees[i].device == octree.device)
593
- assert condition, 'The check of merge_octrees failed'
594
-
595
- # node num
596
- batch_nnum = torch.stack(
597
- [octrees[i].nnum for i in range(octree.batch_size)], dim=1)
598
- batch_nnum_nempty = torch.stack(
599
- [octrees[i].nnum_nempty for i in range(octree.batch_size)], dim=1)
600
- octree.nnum = torch.sum(batch_nnum, dim=1)
601
- octree.nnum_nempty = torch.sum(batch_nnum_nempty, dim=1)
602
- octree.batch_nnum = batch_nnum
603
- octree.batch_nnum_nempty = batch_nnum_nempty
604
- nnum_cum = cumsum(batch_nnum_nempty, dim=1, exclusive=True)
605
-
606
- # merge octre properties
607
- for d in range(octree.depth+1):
608
- # key
609
- keys = [None] * octree.batch_size
610
- for i in range(octree.batch_size):
611
- key = octrees[i].keys[d] & ((1 << 48) - 1) # clear the highest bits
612
- keys[i] = key | (i << 48)
613
- octree.keys[d] = torch.cat(keys, dim=0)
614
-
615
- # children
616
- children = [None] * octree.batch_size
617
- for i in range(octree.batch_size):
618
- child = octrees[i].children[d].clone() # !! `clone` is used here to avoid
619
- mask = child >= 0 # !! modifying the original octrees
620
- child[mask] = child[mask] + nnum_cum[d, i]
621
- children[i] = child
622
- octree.children[d] = torch.cat(children, dim=0)
623
-
624
- # features
625
- if octrees[0].features[d] is not None and d == octree.depth:
626
- features = [octrees[i].features[d] for i in range(octree.batch_size)]
627
- octree.features[d] = torch.cat(features, dim=0)
628
-
629
- # normals
630
- if octrees[0].normals[d] is not None and d == octree.depth:
631
- normals = [octrees[i].normals[d] for i in range(octree.batch_size)]
632
- octree.normals[d] = torch.cat(normals, dim=0)
633
-
634
- # points
635
- if octrees[0].points[d] is not None and d == octree.depth:
636
- points = [octrees[i].points[d] for i in range(octree.batch_size)]
637
- octree.points[d] = torch.cat(points, dim=0)
638
-
639
- return octree
640
-
641
-
642
- def init_octree(depth: int, full_depth: int = 2, batch_size: int = 1,
643
- device: Union[torch.device, str] = 'cpu'):
644
- r'''
645
- Initializes an octree to :attr:`full_depth`.
646
-
647
- Args:
648
- depth (int): The depth of the octree.
649
- full_depth (int): The octree layers with a depth small than
650
- :attr:`full_depth` are forced to be full.
651
- batch_size (int, optional): The batch size.
652
- device (torch.device or str): The device to use for computation.
653
-
654
- Returns:
655
- Octree: The initialized Octree object.
656
- '''
657
-
658
- octree = Octree(depth, full_depth, batch_size, device)
659
- for d in range(full_depth+1):
660
- octree.octree_grow_full(depth=d)
661
- return octree
1
+ # --------------------------------------------------------
2
+ # Octree-based Sparse Convolutional Neural Networks
3
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Peng-Shuai Wang
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from typing import Union, List
11
+
12
+ import ocnn
13
+ from ocnn.octree.points import Points
14
+ from ocnn.octree.shuffled_key import xyz2key, key2xyz
15
+ from ocnn.utils import range_grid, scatter_add, cumsum, trunc_div
16
+
17
+
18
+ class Octree:
19
+ r''' Builds an octree from an input point cloud.
20
+
21
+ Args:
22
+ depth (int): The octree depth.
23
+ full_depth (int): The octree layers with a depth small than
24
+ :attr:`full_depth` are forced to be full.
25
+ batch_size (int): The octree batch size.
26
+ device (torch.device or str): Choose from :obj:`cpu` and :obj:`gpu`.
27
+ (default: :obj:`cpu`)
28
+
29
+ .. note::
30
+ The octree data structure requires that if an octree node has children nodes,
31
+ the number of children nodes is exactly 8, in which some of the nodes are
32
+ empty and some nodes are non-empty. The properties of an octree, including
33
+ :obj:`keys`, :obj:`children` and :obj:`neighs`, contain both non-empty and
34
+ empty nodes, and other properties, including :obj:`features`, :obj:`normals`
35
+ and :obj:`points`, contain only non-empty nodes.
36
+
37
+ .. note::
38
+ The point cloud must be strictly in range :obj:`[-1, 1]`. A good practice
39
+ is to normalize it into :obj:`[-0.99, 0.99]` or :obj:`[0.9, 0.9]` to retain
40
+ some margin.
41
+ '''
42
+
43
+ def __init__(self, depth: int, full_depth: int = 2, batch_size: int = 1,
44
+ device: Union[torch.device, str] = 'cpu', **kwargs):
45
+ super().__init__()
46
+ # configurations for initialization
47
+ self.depth = depth
48
+ self.full_depth = full_depth
49
+ self.batch_size = batch_size
50
+ self.device = device
51
+
52
+ # properties after building the octree
53
+ self.reset()
54
+
55
+ def reset(self):
56
+ r''' Resets the Octree status and constructs several lookup tables.
57
+ '''
58
+
59
+ # octree features in each octree layers
60
+ num = self.depth + 1
61
+ self.keys = [None] * num
62
+ self.children = [None] * num
63
+ self.neighs = [None] * num
64
+ self.features = [None] * num
65
+ self.normals = [None] * num
66
+ self.points = [None] * num
67
+
68
+ # self.nempty_masks, self.nempty_indices and self.nempty_neighs are
69
+ # for handling of non-empty nodes and are constructed on demand
70
+ self.nempty_masks = [None] * num
71
+ self.nempty_indices = [None] * num
72
+ self.nempty_neighs = [None] * num
73
+
74
+ # octree node numbers in each octree layers.
75
+ # These are small 1-D tensors; just keep them on CPUs
76
+ self.nnum = torch.zeros(num, dtype=torch.long)
77
+ self.nnum_nempty = torch.zeros(num, dtype=torch.long)
78
+
79
+ # the following properties are only valid after `merge_octrees`.
80
+ # TODO: make them valid after `octree_grow`, `octree_split` and `build_octree`
81
+ batch_size = self.batch_size
82
+ self.batch_nnum = torch.zeros(num, batch_size, dtype=torch.long)
83
+ self.batch_nnum_nempty = torch.zeros(num, batch_size, dtype=torch.long)
84
+
85
+ # construct the look up tables for neighborhood searching
86
+ device = self.device
87
+ center_grid = range_grid(2, 3, device) # (8, 3)
88
+ displacement = range_grid(-1, 1, device) # (27, 3)
89
+ neigh_grid = center_grid.unsqueeze(1) + displacement # (8, 27, 3)
90
+ parent_grid = trunc_div(neigh_grid, 2)
91
+ child_grid = neigh_grid % 2
92
+ self.lut_parent = torch.sum(
93
+ parent_grid * torch.tensor([9, 3, 1], device=device), dim=2)
94
+ self.lut_child = torch.sum(
95
+ child_grid * torch.tensor([4, 2, 1], device=device), dim=2)
96
+
97
+ # lookup tables for different kernel sizes
98
+ self.lut_kernel = {
99
+ '222': torch.tensor([13, 14, 16, 17, 22, 23, 25, 26], device=device),
100
+ '311': torch.tensor([4, 13, 22], device=device),
101
+ '131': torch.tensor([10, 13, 16], device=device),
102
+ '113': torch.tensor([12, 13, 14], device=device),
103
+ '331': torch.tensor([1, 4, 7, 10, 13, 16, 19, 22, 25], device=device),
104
+ '313': torch.tensor([3, 4, 5, 12, 13, 14, 21, 22, 23], device=device),
105
+ '133': torch.tensor([9, 10, 11, 12, 13, 14, 15, 16, 17], device=device),
106
+ }
107
+
108
+ def key(self, depth: int, nempty: bool = False):
109
+ r''' Returns the shuffled key of each octree node.
110
+
111
+ Args:
112
+ depth (int): The depth of the octree.
113
+ nempty (bool): If True, returns the results of non-empty octree nodes.
114
+ '''
115
+
116
+ key = self.keys[depth]
117
+ if nempty:
118
+ idx = self.nempty_index(depth)
119
+ key = key[idx]
120
+ return key
121
+
122
+ def xyzb(self, depth: int, nempty: bool = False):
123
+ r''' Returns the xyz coordinates and the batch indices of each octree node.
124
+
125
+ Args:
126
+ depth (int): The depth of the octree.
127
+ nempty (bool): If True, returns the results of non-empty octree nodes.
128
+ '''
129
+
130
+ key = self.key(depth, nempty)
131
+ return key2xyz(key, depth)
132
+
133
+ def batch_id(self, depth: int, nempty: bool = False):
134
+ r''' Returns the batch indices of each octree node.
135
+
136
+ Args:
137
+ depth (int): The depth of the octree.
138
+ nempty (bool): If True, returns the results of non-empty octree nodes.
139
+ '''
140
+
141
+ batch_id = self.keys[depth] >> 48
142
+ if nempty:
143
+ idx = self.nempty_index(depth)
144
+ batch_id = batch_id[idx]
145
+ return batch_id
146
+
147
+ def nempty_mask(self, depth: int, reset: bool = False):
148
+ r''' Returns a binary mask which indicates whether the cooreponding octree
149
+ node is empty or not.
150
+
151
+ Args:
152
+ depth (int): The depth of the octree.
153
+ reset (bool): If True, recomputes the mask.
154
+ '''
155
+
156
+ if self.nempty_masks[depth] is None or reset:
157
+ self.nempty_masks[depth] = self.children[depth] >= 0
158
+ return self.nempty_masks[depth]
159
+
160
def nempty_index(self, depth: int, reset: bool = False):
  r''' Returns the indices of non-empty octree nodes.

  The result is cached in :obj:`self.nempty_indices[depth]` and recomputed
  only when the cache is empty or :attr:`reset` is True.

  Args:
    depth (int): The depth of the octree.
    reset (bool): If True, recomputes the indices, and also the non-empty
        mask they are derived from.

  Returns:
    torch.Tensor: A 1D long tensor with, in ascending order, the indices of
    the nodes at :attr:`depth` whose mask entry is True.
  '''

  if self.nempty_indices[depth] is None or reset:
    # Propagate `reset` so that a forced recomputation does not reuse a stale
    # cached mask.
    mask = self.nempty_mask(depth, reset=reset)
    rng = torch.arange(mask.shape[0], device=mask.device, dtype=torch.long)
    self.nempty_indices[depth] = rng[mask]
  return self.nempty_indices[depth]
def nempty_neigh(self, depth: int, reset: bool = False):
  r''' Returns the neighborhoods of non-empty octree nodes.

  The rows of :obj:`self.neighs[depth]` that belong to non-empty nodes are
  selected, and their neighbor indices are remapped to index the non-empty
  nodes (see :meth:`remap_nempty_neigh`). The result is cached in
  :obj:`self.nempty_neighs[depth]`.

  Args:
    depth (int): The depth of the octree.
    reset (bool): If True, recomputes the neighborhoods, together with the
        non-empty indices and mask they depend on.
  '''

  if self.nempty_neighs[depth] is None or reset:
    idx = self.nempty_index(depth, reset=reset)
    # Advanced indexing copies the selected rows, so the in-place remapping
    # does not modify self.neighs[depth].
    neigh = self.remap_nempty_neigh(self.neighs[depth][idx], depth)
    self.nempty_neighs[depth] = neigh
  return self.nempty_neighs[depth]
def remap_nempty_neigh(self, neigh: torch.Tensor, depth: int):
  r''' Remaps the neighbor indices of layer :attr:`depth` to the indices of
  the non-empty octree nodes.

  Every non-negative entry ``j`` of :attr:`neigh` is replaced by
  :obj:`self.children[depth][j]`, which is -1 when node ``j`` is empty.
  Negative entries (missing neighbors) are kept unchanged.

  .. note::
    :attr:`neigh` is modified in place and also returned.

  Args:
    neigh (torch.Tensor): The input neighborhoods with shape :obj:`(N, 27)`.
    depth (int): The depth of the octree.
  '''

  has_neigh = neigh >= 0
  lookup = self.children[depth]
  neigh[has_neigh] = lookup[neigh[has_neigh]].long()
  return neigh
def build_octree(self, point_cloud: Points):
  r''' Builds an octree from a point cloud.

  The layers from 0 to :attr:`full_depth` are built as full (dense) layers.
  The deeper layers are built bottom-up from the shuffled keys of the input
  points. The points, normals, and features are averaged per non-empty node
  of the finest layer :attr:`depth`.

  Args:
    point_cloud (Points): The input point cloud.

  Returns:
    torch.Tensor: For each input point, the index of the non-empty node at
    :attr:`depth` that contains it. These are the inverse indices returned
    by :func:`torch.unique`.

  .. note::
    The point cloud must be strictly in range :obj:`[-1, 1]`. A good practice
    is to normalize it into :obj:`[-0.99, 0.99]` or :obj:`[-0.9, 0.9]` to
    retain some margin.
  '''

  self.device = point_cloud.device
  assert point_cloud.batch_size == self.batch_size, 'Inconsistent batch_size'

  # normalize points from [-1, 1] to [0, 2^depth]. #[L:Scale]
  scale = 2 ** (self.depth - 1)
  points = (point_cloud.points + 1.0) * scale

  # get the shuffled key and sort; `node_key` holds the sorted unique keys
  # (the non-empty nodes of layer `depth`), `idx` maps each point to its node
  x, y, z = points[:, 0], points[:, 1], points[:, 2]
  b = None if self.batch_size == 1 else point_cloud.batch_id.view(-1)
  key = xyz2key(x, y, z, b, self.depth)
  node_key, idx, counts = torch.unique(
      key, sorted=True, return_inverse=True, return_counts=True)

  # layer 0 to full_depth: the octree is full in these layers
  for d in range(self.full_depth+1):
    self.octree_grow_full(d, update_neigh=False)

  # layer depth down to full_depth + 1, built bottom-up
  for d in range(self.depth, self.full_depth, -1):
    # compute parent key, i.e. keys of layer (d -1)
    pkey = node_key >> 3
    pkey, pidx, _ = torch.unique_consecutive(
        pkey, return_inverse=True, return_counts=True)

    # augmented key: every non-empty parent gets all 8 children (empty or not)
    key = (pkey.unsqueeze(-1) << 3) + torch.arange(8, device=self.device)
    self.keys[d] = key.view(-1)
    self.nnum[d] = key.numel()
    self.nnum_nempty[d] = node_key.numel()

    # children: the slot of each non-empty node among the 8 siblings is
    # (parent index * 8 + lowest 3 key bits); all other slots stay -1
    addr = (pidx << 3) | (node_key % 8)
    children = -torch.ones(
        self.nnum[d].item(), dtype=torch.int32, device=self.device)
    children[addr] = torch.arange(
        self.nnum_nempty[d], dtype=torch.int32, device=self.device)
    self.children[d] = children

    # cache pkey for the next iteration
    # The batch index lives in bits 48 and above of a key. Because pkey was
    # obtained as `node_key >> 3` above, its batch bits start at bit 45, so
    # `pkey >> 45` (not `>> 48`) extracts the batch index, which is then moved
    # back to bit 48.
    node_key = pkey if self.batch_size == 1 else \
        ((pkey >> 45) << 48) | (pkey & ((1 << 45) - 1))

  # set the children for the layer full_depth,
  # now the node_keys are the keys of layer full_depth; in a full layer the
  # node index is (batch index * 8^d + key)
  d = self.full_depth
  children = -torch.ones_like(self.children[d])
  nempty_idx = node_key if self.batch_size == 1 else \
      ((node_key >> 48) << (3 * d)) | (node_key & ((1 << 48) - 1))
  children[nempty_idx] = torch.arange(
      node_key.numel(), dtype=torch.int32, device=self.device)
  self.children[d] = children
  self.nnum_nempty[d] = node_key.numel()

  # average the signal for the last octree layer
  d = self.depth
  points = scatter_add(points, idx, dim=0)  # points is rescaled in [L:Scale]
  self.points[d] = points / counts.unsqueeze(1)
  if point_cloud.normals is not None:
    # normals are summed and re-normalized rather than divided by counts
    normals = scatter_add(point_cloud.normals, idx, dim=0)
    self.normals[d] = F.normalize(normals)
  if point_cloud.features is not None:
    features = scatter_add(point_cloud.features, idx, dim=0)
    self.features[d] = features / counts.unsqueeze(1)

  # reset nempty_masks, nempty_indices, and nempty_neighs, which will be
  # updated on demand
  self.nempty_masks = [None] * (self.depth + 1)
  self.nempty_indices = [None] * (self.depth + 1)
  self.nempty_neighs = [None] * (self.depth + 1)
  return idx
def octree_grow_full(self, depth: int, update_neigh: bool = True):
  r''' Builds the full octree, which is essentially a dense volumetric grid.

  Every node of layer :attr:`depth` is created for every batch element and
  marked non-empty. There are :obj:`8^depth` nodes per batch element. Their
  keys carry the batch index in the bits from 48 upward. The children indices
  are simply :obj:`0, 1, ..., 8^depth * batch_size - 1`.

  Args:
    depth (int): The depth of the octree; it must not exceed
        :attr:`full_depth`.
    update_neigh (bool): If True, construct the neighborhood indices.
  '''

  # check (the message reports the offending values instead of 'error')
  assert depth <= self.full_depth, \
      'octree_grow_full: depth ({}) must not exceed full_depth ({})'.format(
          depth, self.full_depth)

  # node number
  num = 1 << (3 * depth)
  batch_size = self.batch_size
  self.nnum[depth] = num * batch_size
  self.nnum_nempty[depth] = num * batch_size

  # update key: (B, num) grid of `key | (batch << 48)`, flattened batch-major
  key = torch.arange(num, dtype=torch.long, device=self.device)
  bs = torch.arange(batch_size, dtype=torch.long, device=self.device)
  key = key.unsqueeze(0) | (bs.unsqueeze(1) << 48)
  self.keys[depth] = key.view(-1)

  # update children: every node is non-empty, so they are consecutive
  self.children[depth] = torch.arange(
      num * batch_size, dtype=torch.int32, device=self.device)

  # nempty_masks, nempty_indices, and nempty_neighs
  # need not be reset for full octrees

  # update neigh if needed
  if update_neigh:
    self.construct_neigh(depth)
def octree_split(self, split: torch.Tensor, depth: int):
  r''' Sets the splitting status of the octree nodes at :attr:`depth`.

  Args:
    split (torch.Tensor): One entry per octree node at :attr:`depth`
        indicating its status: 0 - empty, 1 - non-empty or split.
    depth (int): The depth of current octree.
  '''

  is_empty = split == 0
  # With exclusive=True the prefix sum has one extra trailing entry. The
  # first N entries give each node's position among the split nodes, and the
  # last entry is their total.
  offsets = cumsum(split, dim=0, exclusive=True)
  children, nnum_nempty = torch.split(offsets, [split.shape[0], 1])
  children[is_empty] = -1

  # Degenerate case: nothing was split. Force the first node to be split so
  # that the octree keeps at least one non-empty node.
  if nnum_nempty == 0:
    nnum_nempty = 1
    children[0] = 0

  self.children[depth] = children.int()
  self.nnum_nempty[depth] = nnum_nempty

  # The cached masks, indices, and neighborhoods of this layer derive from
  # children[depth], so they are invalidated here.
  for cache in (self.nempty_masks, self.nempty_indices, self.nempty_neighs):
    cache[depth] = None
def octree_grow(self, depth: int, update_neigh: bool = True):
  r''' Grows the octree and updates the relevant properties. And in most
  cases, call :func:`Octree.octree_split` to update the splitting status of
  the octree before this function.

  Every non-empty node at :attr:`depth` - 1 gets 8 children at
  :attr:`depth`, all of which are initially marked non-empty. If
  :attr:`depth` equals :attr:`self.depth` + 1, the octree is first deepened
  by one layer. The cached non-empty masks, indices, and neighborhoods of
  layer :attr:`depth` are reset.

  Args:
    depth (int): The depth of the octree.
    update_neigh (bool): If True, construct the neighborhood indices.
  '''

  # increase the octree depth if required
  if depth > self.depth:
    assert depth == self.depth + 1
    self.depth = depth
    self.keys.append(None)
    self.children.append(None)
    self.neighs.append(None)
    self.features.append(None)
    self.normals.append(None)
    self.points.append(None)
    self.nempty_masks.append(None)
    self.nempty_indices.append(None)
    self.nempty_neighs.append(None)
    zero = torch.zeros(1, dtype=torch.long)
    self.nnum = torch.cat([self.nnum, zero])
    self.nnum_nempty = torch.cat([self.nnum_nempty, zero])
    # Append one row with one column per batch element. A (1, 1) zero
    # would fail to concatenate for batched (merged) octrees.
    self.batch_nnum = torch.cat(
        [self.batch_nnum,
         self.batch_nnum.new_zeros(1, self.batch_nnum.shape[1])], dim=0)
    self.batch_nnum_nempty = torch.cat(
        [self.batch_nnum_nempty,
         self.batch_nnum_nempty.new_zeros(1, self.batch_nnum_nempty.shape[1])],
        dim=0)

  # node number
  nnum = self.nnum_nempty[depth-1] * 8
  self.nnum[depth] = nnum
  self.nnum_nempty[depth] = nnum  # initialize self.nnum_nempty

  # update keys: child key = (parent key << 3) + [0..7], keeping batch bits
  key = self.key(depth-1, nempty=True)
  batch_id = (key >> 48) << 48
  key = (key & ((1 << 48) - 1)) << 3
  key = key | batch_id
  key = key.unsqueeze(1) + torch.arange(8, device=key.device)
  self.keys[depth] = key.view(-1)

  # update children
  self.children[depth] = torch.arange(
      nnum, dtype=torch.int32, device=self.device)

  # children[depth] was replaced, so the cached masks, indices, and
  # neighborhoods of this layer are stale; they are rebuilt on demand
  self.nempty_masks[depth] = None
  self.nempty_indices[depth] = None
  self.nempty_neighs[depth] = None

  # update neighs
  if update_neigh:
    self.construct_neigh(depth)
def construct_neigh(self, depth: int):
  r''' Constructs the :obj:`3x3x3` neighbors for each octree node.

  For full layers (:obj:`depth <= full_depth`), the neighbors are computed
  from the dense grid coordinates. For deeper layers, they are derived from
  the neighbors of the parent layer through the look-up tables
  :obj:`self.lut_parent` and :obj:`self.lut_child`. Missing neighbors are
  set to -1. The result is stored in :obj:`self.neighs[depth]` as a tensor
  with 27 columns.

  Args:
    depth (int): The octree depth with a value larger than 0 (:obj:`>0`).
  '''

  if depth <= self.full_depth:
    device = self.device
    nnum = 1 << (3 * depth)
    key = torch.arange(nnum, dtype=torch.long, device=device)
    x, y, z, _ = key2xyz(key, depth)
    xyz = torch.stack([x, y, z], dim=-1)  # (N, 3)
    grid = range_grid(-1, 1, device)  # (27, 3)
    xyz = xyz.unsqueeze(1) + grid  # (N, 27, 3)
    xyz = xyz.view(-1, 3)  # (N*27, 3)
    neigh = xyz2key(xyz[:, 0], xyz[:, 1], xyz[:, 2], depth=depth)

    # In a full layer the node index is (batch index * nnum + key), matching
    # the key/children layout produced by octree_grow_full.
    bs = torch.arange(self.batch_size, dtype=torch.long, device=device)
    neigh = neigh + bs.unsqueeze(1) * nnum  # (N*27,) + (B, 1) -> (B, N*27)

    # neighbors outside the [0, 2^depth) grid do not exist
    bound = 1 << depth
    invalid = torch.logical_or((xyz < 0).any(1), (xyz >= bound).any(1))
    neigh[:, invalid] = -1
    self.neighs[depth] = neigh.view(-1, 27)  # (B*N, 27)

  else:
    child_p = self.children[depth-1]
    nempty = child_p >= 0
    neigh_p = self.neighs[depth-1][nempty]  # (N, 27)
    neigh_p = neigh_p[:, self.lut_parent]  # (N, 8, 27)
    # Entries of neigh_p equal to -1 index the last element here (negative
    # indexing); those entries are masked out below via `neigh_p < 0`.
    child_p = child_p[neigh_p]  # (N, 8, 27)
    invalid = torch.logical_or(child_p < 0, neigh_p < 0)  # (N, 8, 27)
    # each non-empty parent owns 8 consecutive children at this depth;
    # presumably lut_child broadcasts to (8, 27) -- TODO confirm its shape
    neigh = child_p * 8 + self.lut_child
    neigh[invalid] = -1
    self.neighs[depth] = neigh.view(-1, 27)
def construct_all_neigh(self):
  r''' A convenient handler that constructs the neighborhoods of every layer
  from depth 1 to :attr:`depth` (inclusive) via :meth:`construct_neigh`.
  '''

  for layer in range(1, self.depth + 1):
    self.construct_neigh(layer)
def search_xyzb(self, query: torch.Tensor, depth: int, nempty: bool = False):
  r''' Searches the octree nodes given the query points.

  The query coordinates are encoded into shuffled keys with :func:`xyz2key`,
  and the lookup is delegated to :meth:`search_key`.

  Args:
    query (torch.Tensor): The coordinates of query points with shape
        :obj:`(N, 4)`. The first 3 channels of the coordinates are :obj:`x`,
        :obj:`y`, and :obj:`z`, and the last channel is the batch index. Note
        that the coordinates must be in range :obj:`[0, 2^depth)`.
    depth (int): The depth of the octree layer.
    nempty (bool): If true, only searches the non-empty octree nodes.
  '''

  x, y, z, b = query[:, 0], query[:, 1], query[:, 2], query[:, 3]
  return self.search_key(xyz2key(x, y, z, b, depth), depth, nempty)
def search_key(self, query: torch.Tensor, depth: int, nempty: bool = False):
  r''' Searches the octree nodes given the query keys.

  Args:
    query (torch.Tensor): The keys of query points with shape :obj:`(N,)`,
        which are computed from the coordinates of query points.
    depth (int): The depth of the octree layer.
    nempty (bool): If true, only searches the non-empty octree nodes.

  Returns:
    torch.Tensor: For each query, the index of the node with the same key,
    or -1 if no such node exists.
  '''

  sorted_keys = self.key(depth, nempty)
  # The keys of a layer form a 1D sorted sequence, so a binary search gives
  # the insertion position of every query (`torch.bucketize` would work too).
  pos = torch.searchsorted(sorted_keys, query)

  # A query is found only if its position is in bounds and the key there
  # matches exactly.
  in_bound = pos < sorted_keys.shape[0]
  found = torch.zeros_like(in_bound)
  found[in_bound] = sorted_keys[pos[in_bound]] == query[in_bound]
  pos[~found] = -1
  return pos
def get_neigh(self, depth: int, kernel: str = '333', stride: int = 1,
              nempty: bool = False):
  r''' Returns the neighborhoods given the depth and a kernel shape.

  Args:
    depth (int): The octree depth with a value larger than 0 (:obj:`>0`).
    kernel (str): The kernel shape from :obj:`333`, :obj:`311`, :obj:`131`,
        :obj:`113`, :obj:`222`, :obj:`331`, :obj:`133`, and :obj:`313`.
    stride (int): The stride of neighborhoods (:obj:`1` or :obj:`2`). If the
        stride is :obj:`2`, always returns the neighborhood of the first
        siblings.
    nempty (bool): If True, only returns the neighborhoods of the non-empty
        octree nodes.

  Raises:
    ValueError: If :attr:`stride` or :attr:`kernel` is unsupported.
  '''

  if stride == 1:
    neigh = self.nempty_neigh(depth) if nempty else self.neighs[depth]
  elif stride == 2:
    # every 8th row is the first sibling of each group of 8 children
    neigh = self.neighs[depth][::8].clone()
    if nempty:
      neigh = self.remap_nempty_neigh(neigh, depth)
  else:
    raise ValueError('Unsupported stride {}'.format(stride))

  if kernel == '333':
    return neigh
  if kernel not in self.lut_kernel:
    raise ValueError('Unsupported kernel {}'.format(kernel))
  return neigh[:, self.lut_kernel[kernel]]
def get_input_feature(self, feature: str, nempty: bool = False):
  r''' Returns the initial input feature stored in octree.

  The signals are read from the finest layer :attr:`self.depth` and
  concatenated along the channel dimension in the order N, D, L, P, F.

  Args:
    feature (str): A string used to indicate which features to extract from
        the input octree (case-insensitive). If the character :obj:`N` is in
        :attr:`feature`, the normal signal is extracted (3 channels).
        Similarly, if :obj:`D` is in :attr:`feature`, the local displacement
        is extracted (1 channel). If :obj:`L` is in :attr:`feature`, the
        local coordinates of the averaged points in each octree node are
        extracted (3 channels). If :obj:`P` is in :attr:`feature`, the global
        coordinates are extracted (3 channels). If :obj:`F` is in
        :attr:`feature`, other features (like colors) are extracted
        (k channels).
    nempty (bool): If False, the result is padded with
        :func:`ocnn.nn.octree_pad` to cover all octree nodes; otherwise only
        the non-empty nodes are returned.

  Raises:
    ValueError: If :attr:`feature` selects no signal, or a selected signal is
        not stored in the octree.
  '''

  features = list()
  depth = self.depth
  feature = feature.upper()
  if 'N' in feature:
    features.append(self.normals[depth])

  if 'L' in feature or 'D' in feature:
    # offset of the averaged point from the center of its node
    local_points = self.points[depth].frac() - 0.5

  if 'D' in feature:
    dis = torch.sum(local_points * self.normals[depth], dim=1, keepdim=True)
    features.append(dis)

  if 'L' in feature:
    features.append(local_points)

  if 'P' in feature:
    scale = 2 ** (1 - depth)  # normalize [0, 2^depth] -> [-1, 1]
    global_points = self.points[depth] * scale - 1.0
    features.append(global_points)

  if 'F' in feature:
    features.append(self.features[depth])

  # Fail with a clear message instead of an opaque `torch.cat` error.
  if not features:
    raise ValueError(
        'No input feature selected by {!r}; expected any of N, D, L, P, '
        'F.'.format(feature))
  if any(f is None for f in features):
    raise ValueError(
        'Some signals requested by {!r} are not stored in the '
        'octree.'.format(feature))

  out = torch.cat(features, dim=1)
  if not nempty:
    out = ocnn.nn.octree_pad(out, self, depth)
  return out
def to_points(self, rescale: bool = True):
  r''' Converts averaged points in the octree to a point cloud.

  The averaged points stored at the finest layer are used when available.
  Otherwise (e.g., the octree was predicted by a neural network), the centers
  of the non-empty nodes are used.

  Args:
    rescale (bool): If True, rescales the xyz coordinates to [-1, 1].
  '''

  depth = self.depth
  xyz = self.points[depth]
  batch_id = self.batch_id(depth, nempty=True)

  if xyz is None:
    x, y, z, batch_id = self.xyzb(depth, nempty=True)
    xyz = torch.stack([x, y, z], dim=1) + 0.5

  if rescale:
    # the coordinates live in [0, 2^depth]; map them back to [-1, 1]
    xyz = xyz * (2 ** (1 - depth)) - 1.0

  return Points(xyz, self.normals[depth], self.features[depth],
                batch_id=batch_id, batch_size=self.batch_size)
def to(self, device: Union[torch.device, str], non_blocking: bool = False):
  r''' Moves the octree to a specified device.

  Args:
    device (torch.device or str): The destination device.
    non_blocking (bool): If True and the source is in pinned memory, the copy
        will be asynchronous with respect to the host. Otherwise, the argument
        has no effect. Default: False.

  Returns:
    Octree: :attr:`self` if it is already on :attr:`device`; otherwise a new
    octree whose per-layer tensors are copied to :attr:`device`.
  '''

  target = torch.device(device) if isinstance(device, str) else device

  # Nothing to do if the octree is already on the target device.
  if self.device == target:
    return self

  def _move_all(tensors):
    # non-tensor entries (e.g., layers that were never filled) become None
    return [t.to(target, non_blocking=non_blocking)
            if isinstance(t, torch.Tensor) else None for t in tensors]

  # init_like constructs the new octree with the target device, which is
  # where the look-up tables (lut_kernel, lut_parent, lut_child) are created.
  octree = Octree.init_like(self, target)

  for name in ('keys', 'children', 'neighs', 'features', 'normals', 'points'):
    setattr(octree, name, _move_all(getattr(self, name)))

  # The node counters are small; they stay on the CPU to avoid frequent
  # device switching and are only cloned.
  for name in ('nnum', 'nnum_nempty', 'batch_nnum', 'batch_nnum_nempty'):
    setattr(octree, name, getattr(self, name).clone())
  return octree
def cuda(self, non_blocking: bool = False):
  r''' Moves the octree to the GPU (the :obj:`cuda` device).

  Args:
    non_blocking (bool): Forwarded to :meth:`to`.
  '''

  return self.to('cuda', non_blocking=non_blocking)
def cpu(self):
  r''' Moves the octree to the CPU by delegating to :meth:`to`. '''

  target = 'cpu'
  return self.to(target)
def merge_octrees(self, octrees: List['Octree']):
  r''' Merges a list of octrees into one batch.

  The properties of :attr:`self` are overwritten in place:

  * the keys of ``octrees[i]`` get ``i`` as the batch index in their bits
    from 48 upward;
  * the non-negative children indices are shifted so that they index into
    the merged layers;
  * the features, normals, and points of the finest layer are concatenated.

  The neighborhoods are not merged, and the cached non-empty masks, indices,
  and neighborhoods are reset.

  Args:
    octrees (List[Octree]): A list of octrees to merge. They must have the
        same :attr:`depth`, :attr:`full_depth`, and :attr:`device` as
        :attr:`self`.

  Returns:
    Octree: The merged octree, i.e., :attr:`self`.
  '''

  # init and check
  batch_size = len(octrees)
  self.batch_size = batch_size
  for other in octrees[1:]:
    condition = (other.depth == self.depth and
                 other.full_depth == self.full_depth and
                 other.device == self.device)
    assert condition, 'The check of merge_octrees failed'

  # node num: per-octree counts stacked along dim 1, and their totals
  batch_nnum = torch.stack([o.nnum for o in octrees], dim=1)
  batch_nnum_nempty = torch.stack([o.nnum_nempty for o in octrees], dim=1)
  self.nnum = torch.sum(batch_nnum, dim=1)
  self.nnum_nempty = torch.sum(batch_nnum_nempty, dim=1)
  self.batch_nnum = batch_nnum
  self.batch_nnum_nempty = batch_nnum_nempty
  # nnum_cum[d, i]: number of non-empty nodes at depth d in octrees[:i]
  nnum_cum = cumsum(batch_nnum_nempty, dim=1, exclusive=True)

  # merge octree properties
  low_bits = (1 << 48) - 1
  for d in range(self.depth + 1):
    # key: clear the highest bits and store the batch index there
    self.keys[d] = torch.cat(
        [(o.keys[d] & low_bits) | (i << 48) for i, o in enumerate(octrees)],
        dim=0)

    # children: offset the non-empty indices by the preceding octrees' counts
    children = []
    for i, o in enumerate(octrees):
      # !! `clone` is used here to avoid modifying the original octrees
      child = o.children[d].clone()
      mask = child >= 0
      child[mask] = child[mask] + nnum_cum[d, i]
      children.append(child)
    self.children[d] = torch.cat(children, dim=0)

  # features, normals, and points are only merged for the finest layer
  d = self.depth
  if octrees[0].features[d] is not None:
    self.features[d] = torch.cat([o.features[d] for o in octrees], dim=0)
  if octrees[0].normals[d] is not None:
    self.normals[d] = torch.cat([o.normals[d] for o in octrees], dim=0)
  if octrees[0].points[d] is not None:
    self.points[d] = torch.cat([o.points[d] for o in octrees], dim=0)

  # keys and children changed at every depth, so any cached non-empty masks,
  # indices, and neighborhoods are stale; they are rebuilt on demand
  self.nempty_masks = [None] * (self.depth + 1)
  self.nempty_indices = [None] * (self.depth + 1)
  self.nempty_neighs = [None] * (self.depth + 1)
  return self
@classmethod
def init_like(cls, octree: 'Octree', device: Union[torch.device, str, None] = None):
  r''' Creates an empty octree with the same configuration as another one.

  Args:
    octree (Octree): The reference octree, whose :attr:`depth`,
        :attr:`full_depth`, and :attr:`batch_size` are copied.
    device (torch.device or str): The device to use for computation. If
        None, the device of :attr:`octree` is used.
  '''

  if device is None:
    device = octree.device
  return cls(depth=octree.depth, full_depth=octree.full_depth,
             batch_size=octree.batch_size, device=device)
@classmethod
def init_octree(cls, depth: int, full_depth: int = 2, batch_size: int = 1,
                device: Union[torch.device, str] = 'cpu'):
  r'''
  Initializes an octree whose layers 0 to :attr:`full_depth` are full.

  Args:
    depth (int): The depth of the octree.
    full_depth (int): The octree layers with a depth not larger than
        :attr:`full_depth` are forced to be full.
    batch_size (int, optional): The batch size.
    device (torch.device or str): The device to use for computation.

  Returns:
    Octree: The initialized Octree object.
  '''

  octree = cls(depth, full_depth, batch_size, device)
  for level in range(full_depth + 1):
    octree.octree_grow_full(depth=level)
  return octree
def merge_octrees(octrees: List['Octree']):
  r''' A wrapper of :meth:`Octree.merge_octrees`.

  A fresh octree configured like ``octrees[0]`` receives the merged batch.

  .. deprecated::
    Use :meth:`Octree.merge_octrees` instead.
  '''

  merged = Octree.init_like(octrees[0])
  return merged.merge_octrees(octrees)
def init_octree(depth: int, full_depth: int = 2, batch_size: int = 1,
                device: Union[torch.device, str] = 'cpu'):
  r''' A wrapper of :meth:`Octree.init_octree`.

  .. deprecated::
    Use :meth:`Octree.init_octree` instead.
  '''

  return Octree.init_octree(depth=depth, full_depth=full_depth,
                            batch_size=batch_size, device=device)