ocnn 2.2.6__py3-none-any.whl → 2.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ocnn/octree/octree.py CHANGED
@@ -1,661 +1,661 @@
1
- # --------------------------------------------------------
2
- # Octree-based Sparse Convolutional Neural Networks
3
- # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
- # Licensed under The MIT License [see LICENSE for details]
5
- # Written by Peng-Shuai Wang
6
- # --------------------------------------------------------
7
-
8
- import torch
9
- import torch.nn.functional as F
10
- from typing import Union, List
11
-
12
- import ocnn
13
- from ocnn.octree.points import Points
14
- from ocnn.octree.shuffled_key import xyz2key, key2xyz
15
- from ocnn.utils import range_grid, scatter_add, cumsum, trunc_div
16
-
17
-
18
- class Octree:
19
- r''' Builds an octree from an input point cloud.
20
-
21
- Args:
22
- depth (int): The octree depth.
23
- full_depth (int): The octree layers with a depth small than
24
- :attr:`full_depth` are forced to be full.
25
- batch_size (int): The octree batch size.
26
- device (torch.device or str): Choose from :obj:`cpu` and :obj:`gpu`.
27
- (default: :obj:`cpu`)
28
-
29
- .. note::
30
- The octree data structure requires that if an octree node has children nodes,
31
- the number of children nodes is exactly 8, in which some of the nodes are
32
- empty and some nodes are non-empty. The properties of an octree, including
33
- :obj:`keys`, :obj:`children` and :obj:`neighs`, contain both non-empty and
34
- empty nodes, and other properties, including :obj:`features`, :obj:`normals`
35
- and :obj:`points`, contain only non-empty nodes.
36
-
37
- .. note::
38
- The point cloud must be strictly in range :obj:`[-1, 1]`. A good practice
39
- is to normalize it into :obj:`[-0.99, 0.99]` or :obj:`[0.9, 0.9]` to retain
40
- some margin.
41
- '''
42
-
43
- def __init__(self, depth: int, full_depth: int = 2, batch_size: int = 1,
44
- device: Union[torch.device, str] = 'cpu', **kwargs):
45
- super().__init__()
46
- self.depth = depth
47
- self.full_depth = full_depth
48
- self.batch_size = batch_size
49
- self.device = device
50
-
51
- self.reset()
52
-
53
- def reset(self):
54
- r''' Resets the Octree status and constructs several lookup tables.
55
- '''
56
-
57
- # octree features in each octree layers
58
- num = self.depth + 1
59
- self.keys = [None] * num
60
- self.children = [None] * num
61
- self.neighs = [None] * num
62
- self.features = [None] * num
63
- self.normals = [None] * num
64
- self.points = [None] * num
65
-
66
- # octree node numbers in each octree layers.
67
- # TODO: decide whether to settle them to 'gpu' or not?
68
- self.nnum = torch.zeros(num, dtype=torch.long)
69
- self.nnum_nempty = torch.zeros(num, dtype=torch.long)
70
-
71
- # the following properties are valid after `merge_octrees`.
72
- # TODO: make them valid after `octree_grow`, `octree_split` and `build_octree`
73
- batch_size = self.batch_size
74
- self.batch_nnum = torch.zeros(num, batch_size, dtype=torch.long)
75
- self.batch_nnum_nempty = torch.zeros(num, batch_size, dtype=torch.long)
76
-
77
- # construct the look up tables for neighborhood searching
78
- device = self.device
79
- center_grid = range_grid(2, 3, device) # (8, 3)
80
- displacement = range_grid(-1, 1, device) # (27, 3)
81
- neigh_grid = center_grid.unsqueeze(1) + displacement # (8, 27, 3)
82
- parent_grid = trunc_div(neigh_grid, 2)
83
- child_grid = neigh_grid % 2
84
- self.lut_parent = torch.sum(
85
- parent_grid * torch.tensor([9, 3, 1], device=device), dim=2)
86
- self.lut_child = torch.sum(
87
- child_grid * torch.tensor([4, 2, 1], device=device), dim=2)
88
-
89
- # lookup tables for different kernel sizes
90
- self.lut_kernel = {
91
- '222': torch.tensor([13, 14, 16, 17, 22, 23, 25, 26], device=device),
92
- '311': torch.tensor([4, 13, 22], device=device),
93
- '131': torch.tensor([10, 13, 16], device=device),
94
- '113': torch.tensor([12, 13, 14], device=device),
95
- '331': torch.tensor([1, 4, 7, 10, 13, 16, 19, 22, 25], device=device),
96
- '313': torch.tensor([3, 4, 5, 12, 13, 14, 21, 22, 23], device=device),
97
- '133': torch.tensor([9, 10, 11, 12, 13, 14, 15, 16, 17], device=device),
98
- }
99
-
100
- def key(self, depth: int, nempty: bool = False):
101
- r''' Returns the shuffled key of each octree node.
102
-
103
- Args:
104
- depth (int): The depth of the octree.
105
- nempty (bool): If True, returns the results of non-empty octree nodes.
106
- '''
107
-
108
- key = self.keys[depth]
109
- if nempty:
110
- mask = self.nempty_mask(depth)
111
- key = key[mask]
112
- return key
113
-
114
- def xyzb(self, depth: int, nempty: bool = False):
115
- r''' Returns the xyz coordinates and the batch indices of each octree node.
116
-
117
- Args:
118
- depth (int): The depth of the octree.
119
- nempty (bool): If True, returns the results of non-empty octree nodes.
120
- '''
121
-
122
- key = self.key(depth, nempty)
123
- return key2xyz(key, depth)
124
-
125
- def batch_id(self, depth: int, nempty: bool = False):
126
- r''' Returns the batch indices of each octree node.
127
-
128
- Args:
129
- depth (int): The depth of the octree.
130
- nempty (bool): If True, returns the results of non-empty octree nodes.
131
- '''
132
-
133
- batch_id = self.keys[depth] >> 48
134
- if nempty:
135
- mask = self.nempty_mask(depth)
136
- batch_id = batch_id[mask]
137
- return batch_id
138
-
139
- def nempty_mask(self, depth: int):
140
- r''' Returns a binary mask which indicates whether the cooreponding octree
141
- node is empty or not.
142
-
143
- Args:
144
- depth (int): The depth of the octree.
145
- '''
146
-
147
- return self.children[depth] >= 0
148
-
149
- def build_octree(self, point_cloud: Points):
150
- r''' Builds an octree from a point cloud.
151
-
152
- Args:
153
- point_cloud (Points): The input point cloud.
154
-
155
- .. note::
156
- The point cloud must be strictly in range :obj:`[-1, 1]`. A good practice
157
- is to normalize it into :obj:`[-0.99, 0.99]` or :obj:`[0.9, 0.9]` to retain
158
- some margin.
159
- '''
160
-
161
- self.device = point_cloud.device
162
- assert point_cloud.batch_size == self.batch_size, 'Inconsistent batch_size'
163
-
164
- # normalize points from [-1, 1] to [0, 2^depth]. #[L:Scale]
165
- scale = 2 ** (self.depth - 1)
166
- points = (point_cloud.points + 1.0) * scale
167
-
168
- # get the shuffled key and sort
169
- x, y, z = points[:, 0], points[:, 1], points[:, 2]
170
- b = None if self.batch_size == 1 else point_cloud.batch_id.view(-1)
171
- key = xyz2key(x, y, z, b, self.depth)
172
- node_key, idx, counts = torch.unique(
173
- key, sorted=True, return_inverse=True, return_counts=True)
174
-
175
- # layer 0 to full_layer: the octree is full in these layers
176
- for d in range(self.full_depth+1):
177
- self.octree_grow_full(d, update_neigh=False)
178
-
179
- # layer depth_ to full_layer_
180
- for d in range(self.depth, self.full_depth, -1):
181
- # compute parent key, i.e. keys of layer (d -1)
182
- pkey = node_key >> 3
183
- pkey, pidx, _ = torch.unique_consecutive(
184
- pkey, return_inverse=True, return_counts=True)
185
-
186
- # augmented key
187
- key = (pkey.unsqueeze(-1) << 3) + torch.arange(8, device=self.device)
188
- self.keys[d] = key.view(-1)
189
- self.nnum[d] = key.numel()
190
- self.nnum_nempty[d] = node_key.numel()
191
-
192
- # children
193
- addr = (pidx << 3) | (node_key % 8)
194
- children = -torch.ones(
195
- self.nnum[d].item(), dtype=torch.int32, device=self.device)
196
- children[addr] = torch.arange(
197
- self.nnum_nempty[d], dtype=torch.int32, device=self.device)
198
- self.children[d] = children
199
-
200
- # cache pkey for the next iteration
201
- # Use `pkey >> 45` instead of `pkey >> 48` in L199 since pkey is already
202
- # shifted to the right by 3 bits in L177
203
- node_key = pkey if self.batch_size == 1 else \
204
- ((pkey >> 45) << 48) | (pkey & ((1 << 45) - 1))
205
-
206
- # set the children for the layer full_layer,
207
- # now the node_keys are the key for full_layer
208
- d = self.full_depth
209
- children = -torch.ones_like(self.children[d])
210
- nempty_idx = node_key if self.batch_size == 1 else \
211
- ((node_key >> 48) << (3 * d)) | (node_key & ((1 << 48) - 1))
212
- children[nempty_idx] = torch.arange(
213
- node_key.numel(), dtype=torch.int32, device=self.device)
214
- self.children[d] = children
215
- self.nnum_nempty[d] = node_key.numel()
216
-
217
- # average the signal for the last octree layer
218
- d = self.depth
219
- points = scatter_add(points, idx, dim=0) # points is rescaled in [L:Scale]
220
- self.points[d] = points / counts.unsqueeze(1)
221
- if point_cloud.normals is not None:
222
- normals = scatter_add(point_cloud.normals, idx, dim=0)
223
- self.normals[d] = F.normalize(normals)
224
- if point_cloud.features is not None:
225
- features = scatter_add(point_cloud.features, idx, dim=0)
226
- self.features[d] = features / counts.unsqueeze(1)
227
-
228
- return idx
229
-
230
- def octree_grow_full(self, depth: int, update_neigh: bool = True):
231
- r''' Builds the full octree, which is essentially a dense volumetric grid.
232
-
233
- Args:
234
- depth (int): The depth of the octree.
235
- update_neigh (bool): If True, construct the neighborhood indices.
236
- '''
237
-
238
- # check
239
- assert depth <= self.full_depth, 'error'
240
-
241
- # node number
242
- num = 1 << (3 * depth)
243
- self.nnum[depth] = num * self.batch_size
244
- self.nnum_nempty[depth] = num * self.batch_size
245
-
246
- # update key
247
- key = torch.arange(num, dtype=torch.long, device=self.device)
248
- bs = torch.arange(self.batch_size, dtype=torch.long, device=self.device)
249
- key = key.unsqueeze(0) | (bs.unsqueeze(1) << 48)
250
- self.keys[depth] = key.view(-1)
251
-
252
- # update children
253
- self.children[depth] = torch.arange(
254
- num * self.batch_size, dtype=torch.int32, device=self.device)
255
-
256
- # update neigh if needed
257
- if update_neigh:
258
- self.construct_neigh(depth)
259
-
260
- def octree_split(self, split: torch.Tensor, depth: int):
261
- r''' Sets whether the octree nodes in :attr:`depth` are splitted or not.
262
-
263
- Args:
264
- split (torch.Tensor): The input tensor with its element indicating status
265
- of each octree node: 0 - empty, 1 - non-empty or splitted.
266
- depth (int): The depth of current octree.
267
- '''
268
-
269
- # split -> children
270
- empty = split == 0
271
- sum = cumsum(split, dim=0, exclusive=True)
272
- children, nnum_nempty = torch.split(sum, [split.shape[0], 1])
273
- children[empty] = -1
274
-
275
- # boundary case, make sure that at least one octree node is splitted
276
- if nnum_nempty == 0:
277
- nnum_nempty = 1
278
- children[0] = 0
279
-
280
- # update octree
281
- self.children[depth] = children.int()
282
- self.nnum_nempty[depth] = nnum_nempty
283
-
284
- def octree_grow(self, depth: int, update_neigh: bool = True):
285
- r''' Grows the octree and updates the relevant properties. And in most
286
- cases, call :func:`Octree.octree_split` to update the splitting status of
287
- the octree before this function.
288
-
289
- Args:
290
- depth (int): The depth of the octree.
291
- update_neigh (bool): If True, construct the neighborhood indices.
292
- '''
293
-
294
- # increase the octree depth if required
295
- if depth > self.depth:
296
- assert depth == self.depth + 1
297
- self.depth = depth
298
- self.keys.append(None)
299
- self.children.append(None)
300
- self.neighs.append(None)
301
- self.features.append(None)
302
- self.normals.append(None)
303
- self.points.append(None)
304
- zero = torch.zeros(1, dtype=torch.long)
305
- self.nnum = torch.cat([self.nnum, zero])
306
- self.nnum_nempty = torch.cat([self.nnum_nempty, zero])
307
- zero = zero.view(1, 1)
308
- self.batch_nnum = torch.cat([self.batch_nnum, zero], dim=0)
309
- self.batch_nnum_nempty = torch.cat([self.batch_nnum_nempty, zero], dim=0)
310
-
311
- # node number
312
- nnum = self.nnum_nempty[depth-1] * 8
313
- self.nnum[depth] = nnum
314
- self.nnum_nempty[depth] = nnum # initialize self.nnum_nempty
315
-
316
- # update keys
317
- key = self.key(depth-1, nempty=True)
318
- batch_id = (key >> 48) << 48
319
- key = (key & ((1 << 48) - 1)) << 3
320
- key = key | batch_id
321
- key = key.unsqueeze(1) + torch.arange(8, device=key.device)
322
- self.keys[depth] = key.view(-1)
323
-
324
- # update children
325
- self.children[depth] = torch.arange(
326
- nnum, dtype=torch.int32, device=self.device)
327
-
328
- # update neighs
329
- if update_neigh:
330
- self.construct_neigh(depth)
331
-
332
- def construct_neigh(self, depth: int):
333
- r''' Constructs the :obj:`3x3x3` neighbors for each octree node.
334
-
335
- Args:
336
- depth (int): The octree depth with a value larger than 0 (:obj:`>0`).
337
- '''
338
-
339
- if depth <= self.full_depth:
340
- device = self.device
341
- nnum = 1 << (3 * depth)
342
- key = torch.arange(nnum, dtype=torch.long, device=device)
343
- x, y, z, _ = key2xyz(key, depth)
344
- xyz = torch.stack([x, y, z], dim=-1) # (N, 3)
345
- grid = range_grid(-1, 1, device) # (27, 3)
346
- xyz = xyz.unsqueeze(1) + grid # (N, 27, 3)
347
- xyz = xyz.view(-1, 3) # (N*27, 3)
348
- neigh = xyz2key(xyz[:, 0], xyz[:, 1], xyz[:, 2], depth=depth)
349
-
350
- bs = torch.arange(self.batch_size, dtype=torch.long, device=device)
351
- neigh = neigh + bs.unsqueeze(1) * nnum # (N*27,) + (B, 1) -> (B, N*27)
352
-
353
- bound = 1 << depth
354
- invalid = torch.logical_or((xyz < 0).any(1), (xyz >= bound).any(1))
355
- neigh[:, invalid] = -1
356
- self.neighs[depth] = neigh.view(-1, 27) # (B*N, 27)
357
-
358
- else:
359
- child_p = self.children[depth-1]
360
- nempty = child_p >= 0
361
- neigh_p = self.neighs[depth-1][nempty] # (N, 27)
362
- neigh_p = neigh_p[:, self.lut_parent] # (N, 8, 27)
363
- child_p = child_p[neigh_p] # (N, 8, 27)
364
- invalid = torch.logical_or(child_p < 0, neigh_p < 0) # (N, 8, 27)
365
- neigh = child_p * 8 + self.lut_child
366
- neigh[invalid] = -1
367
- self.neighs[depth] = neigh.view(-1, 27)
368
-
369
- def construct_all_neigh(self):
370
- r''' A convenient handler for constructing all neighbors.
371
- '''
372
-
373
- for depth in range(1, self.depth+1):
374
- self.construct_neigh(depth)
375
-
376
- def search_xyzb(self, query: torch.Tensor, depth: int, nempty: bool = False):
377
- r''' Searches the octree nodes given the query points.
378
-
379
- Args:
380
- query (torch.Tensor): The coordinates of query points with shape
381
- :obj:`(N, 4)`. The first 3 channels of the coordinates are :obj:`x`,
382
- :obj:`y`, and :obj:`z`, and the last channel is the batch index. Note
383
- that the coordinates must be in range :obj:`[0, 2^depth)`.
384
- depth (int): The depth of the octree layer. nemtpy (bool): If true, only
385
- searches the non-empty octree nodes.
386
- '''
387
-
388
- key = xyz2key(query[:, 0], query[:, 1], query[:, 2], query[:, 3], depth)
389
- idx = self.search_key(key, depth, nempty)
390
- return idx
391
-
392
- def search_key(self, query: torch.Tensor, depth: int, nempty: bool = False):
393
- r''' Searches the octree nodes given the query points.
394
-
395
- Args:
396
- query (torch.Tensor): The keys of query points with shape :obj:`(N,)`,
397
- which are computed from the coordinates of query points.
398
- depth (int): The depth of the octree layer. nemtpy (bool): If true, only
399
- searches the non-empty octree nodes.
400
- '''
401
-
402
- key = self.key(depth, nempty)
403
- # `torch.bucketize` is similar to `torch.searchsorted`.
404
- # I choose `torch.bucketize` here because it has fewer dimension checks,
405
- # resulting in slightly better performance according to the docs of
406
- # pytorch-1.9.1, since `key` is always 1-D sorted sequence.
407
- # https://pytorch.org/docs/1.9.1/generated/torch.searchsorted.html
408
- idx = torch.bucketize(query, key)
409
-
410
- valid = idx < key.shape[0] # valid if in-bound
411
- found = key[idx[valid]] == query[valid]
412
- valid[valid.clone()] = found # valid if found
413
- idx[valid.logical_not()] = -1 # set to -1 if invalid
414
- return idx
415
-
416
- def get_neigh(self, depth: int, kernel: str = '333', stride: int = 1,
417
- nempty: bool = False):
418
- r''' Returns the neighborhoods given the depth and a kernel shape.
419
-
420
- Args:
421
- depth (int): The octree depth with a value larger than 0 (:obj:`>0`).
422
- kernel (str): The kernel shape from :obj:`333`, :obj:`311`, :obj:`131`,
423
- :obj:`113`, :obj:`222`, :obj:`331`, :obj:`133`, and :obj:`313`.
424
- stride (int): The stride of neighborhoods (:obj:`1` or :obj:`2`). If the
425
- stride is :obj:`2`, always returns the neighborhood of the first
426
- siblings.
427
- nempty (bool): If True, only returns the neighborhoods of the non-empty
428
- octree nodes.
429
- '''
430
-
431
- if stride == 1:
432
- neigh = self.neighs[depth]
433
- elif stride == 2:
434
- # clone neigh to avoid self.neigh[depth] being modified
435
- neigh = self.neighs[depth][::8].clone()
436
- else:
437
- raise ValueError('Unsupported stride {}'.format(stride))
438
-
439
- if nempty:
440
- child = self.children[depth]
441
- if stride == 1:
442
- nempty_node = child >= 0
443
- neigh = neigh[nempty_node]
444
- valid = neigh >= 0
445
- neigh[valid] = child[neigh[valid]].long() # remap the index
446
-
447
- if kernel == '333':
448
- return neigh
449
- elif kernel in self.lut_kernel:
450
- lut = self.lut_kernel[kernel]
451
- return neigh[:, lut]
452
- else:
453
- raise ValueError('Unsupported kernel {}'.format(kernel))
454
-
455
- def get_input_feature(self, feature: str, nempty: bool = False):
456
- r''' Returns the initial input feature stored in octree.
457
-
458
- Args:
459
- feature (str): A string used to indicate which features to extract from
460
- the input octree. If the character :obj:`N` is in :attr:`feature`, the
461
- normal signal is extracted (3 channels). Similarly, if :obj:`D` is in
462
- :attr:`feature`, the local displacement is extracted (1 channels). If
463
- :obj:`L` is in :attr:`feature`, the local coordinates of the averaged
464
- points in each octree node is extracted (3 channels). If :attr:`P` is
465
- in :attr:`feature`, the global coordinates are extracted (3 channels).
466
- If :attr:`F` is in :attr:`feature`, other features (like colors) are
467
- extracted (k channels).
468
- nempty (bool): If false, gets the features of all octree nodes.
469
- '''
470
-
471
- features = list()
472
- depth = self.depth
473
- feature = feature.upper()
474
- if 'N' in feature:
475
- features.append(self.normals[depth])
476
-
477
- if 'L' in feature or 'D' in feature:
478
- local_points = self.points[depth].frac() - 0.5
479
-
480
- if 'D' in feature:
481
- dis = torch.sum(local_points * self.normals[depth], dim=1, keepdim=True)
482
- features.append(dis)
483
-
484
- if 'L' in feature:
485
- features.append(local_points)
486
-
487
- if 'P' in feature:
488
- scale = 2 ** (1 - depth) # normalize [0, 2^depth] -> [-1, 1]
489
- global_points = self.points[depth] * scale - 1.0
490
- features.append(global_points)
491
-
492
- if 'F' in feature:
493
- features.append(self.features[depth])
494
-
495
- out = torch.cat(features, dim=1)
496
- if not nempty:
497
- out = ocnn.nn.octree_pad(out, self, depth)
498
- return out
499
-
500
- def to_points(self, rescale: bool = True):
501
- r''' Converts averaged points in the octree to a point cloud.
502
-
503
- Args:
504
- rescale (bool): rescale the xyz coordinates to [-1, 1] if True.
505
- '''
506
-
507
- depth = self.depth
508
- batch_size = self.batch_size
509
-
510
- # by default, use the average points generated when building the octree
511
- # from the input point cloud
512
- xyz = self.points[depth]
513
- batch_id = self.batch_id(depth, nempty=True)
514
-
515
- # xyz is None when the octree is predicted by a neural network
516
- if xyz is None:
517
- x, y, z, batch_id = self.xyzb(depth, nempty=True)
518
- xyz = torch.stack([x, y, z], dim=1) + 0.5
519
-
520
- # normalize xyz to [-1, 1] since the average points are in range [0, 2^d]
521
- if rescale:
522
- scale = 2 ** (1 - depth)
523
- xyz = xyz * scale - 1.0
524
-
525
- # construct Points
526
- out = Points(xyz, self.normals[depth], self.features[depth],
527
- batch_id=batch_id, batch_size=batch_size)
528
- return out
529
-
530
- def to(self, device: Union[torch.device, str], non_blocking: bool = False):
531
- r''' Moves the octree to a specified device.
532
-
533
- Args:
534
- device (torch.device or str): The destination device.
535
- non_blocking (bool): If True and the source is in pinned memory, the copy
536
- will be asynchronous with respect to the host. Otherwise, the argument
537
- has no effect. Default: False.
538
- '''
539
-
540
- if isinstance(device, str):
541
- device = torch.device(device)
542
-
543
- # If on the save device, directly retrun self
544
- if self.device == device:
545
- return self
546
-
547
- def list_to_device(prop):
548
- return [p.to(device, non_blocking=non_blocking)
549
- if isinstance(p, torch.Tensor) else None for p in prop]
550
-
551
- # Construct a new Octree on the specified device
552
- octree = Octree(self.depth, self.full_depth, self.batch_size, device)
553
- octree.keys = list_to_device(self.keys)
554
- octree.children = list_to_device(self.children)
555
- octree.neighs = list_to_device(self.neighs)
556
- octree.features = list_to_device(self.features)
557
- octree.normals = list_to_device(self.normals)
558
- octree.points = list_to_device(self.points)
559
- octree.nnum = self.nnum.clone() # TODO: whether to move nnum to the self.device?
560
- octree.nnum_nempty = self.nnum_nempty.clone()
561
- octree.batch_nnum = self.batch_nnum.clone()
562
- octree.batch_nnum_nempty = self.batch_nnum_nempty.clone()
563
- return octree
564
-
565
- def cuda(self, non_blocking: bool = False):
566
- r''' Moves the octree to the GPU. '''
567
-
568
- return self.to('cuda', non_blocking)
569
-
570
- def cpu(self):
571
- r''' Moves the octree to the CPU. '''
572
-
573
- return self.to('cpu')
574
-
575
-
576
- def merge_octrees(octrees: List['Octree']):
577
- r''' Merges a list of octrees into one batch.
578
-
579
- Args:
580
- octrees (List[Octree]): A list of octrees to merge.
581
-
582
- Returns:
583
- Octree: The merged octree.
584
- '''
585
-
586
- # init and check
587
- octree = Octree(depth=octrees[0].depth, full_depth=octrees[0].full_depth,
588
- batch_size=len(octrees), device=octrees[0].device)
589
- for i in range(1, octree.batch_size):
590
- condition = (octrees[i].depth == octree.depth and
591
- octrees[i].full_depth == octree.full_depth and
592
- octrees[i].device == octree.device)
593
- assert condition, 'The check of merge_octrees failed'
594
-
595
- # node num
596
- batch_nnum = torch.stack(
597
- [octrees[i].nnum for i in range(octree.batch_size)], dim=1)
598
- batch_nnum_nempty = torch.stack(
599
- [octrees[i].nnum_nempty for i in range(octree.batch_size)], dim=1)
600
- octree.nnum = torch.sum(batch_nnum, dim=1)
601
- octree.nnum_nempty = torch.sum(batch_nnum_nempty, dim=1)
602
- octree.batch_nnum = batch_nnum
603
- octree.batch_nnum_nempty = batch_nnum_nempty
604
- nnum_cum = cumsum(batch_nnum_nempty, dim=1, exclusive=True)
605
-
606
- # merge octre properties
607
- for d in range(octree.depth+1):
608
- # key
609
- keys = [None] * octree.batch_size
610
- for i in range(octree.batch_size):
611
- key = octrees[i].keys[d] & ((1 << 48) - 1) # clear the highest bits
612
- keys[i] = key | (i << 48)
613
- octree.keys[d] = torch.cat(keys, dim=0)
614
-
615
- # children
616
- children = [None] * octree.batch_size
617
- for i in range(octree.batch_size):
618
- child = octrees[i].children[d].clone() # !! `clone` is used here to avoid
619
- mask = child >= 0 # !! modifying the original octrees
620
- child[mask] = child[mask] + nnum_cum[d, i]
621
- children[i] = child
622
- octree.children[d] = torch.cat(children, dim=0)
623
-
624
- # features
625
- if octrees[0].features[d] is not None and d == octree.depth:
626
- features = [octrees[i].features[d] for i in range(octree.batch_size)]
627
- octree.features[d] = torch.cat(features, dim=0)
628
-
629
- # normals
630
- if octrees[0].normals[d] is not None and d == octree.depth:
631
- normals = [octrees[i].normals[d] for i in range(octree.batch_size)]
632
- octree.normals[d] = torch.cat(normals, dim=0)
633
-
634
- # points
635
- if octrees[0].points[d] is not None and d == octree.depth:
636
- points = [octrees[i].points[d] for i in range(octree.batch_size)]
637
- octree.points[d] = torch.cat(points, dim=0)
638
-
639
- return octree
640
-
641
-
642
- def init_octree(depth: int, full_depth: int = 2, batch_size: int = 1,
643
- device: Union[torch.device, str] = 'cpu'):
644
- r'''
645
- Initializes an octree to :attr:`full_depth`.
646
-
647
- Args:
648
- depth (int): The depth of the octree.
649
- full_depth (int): The octree layers with a depth small than
650
- :attr:`full_depth` are forced to be full.
651
- batch_size (int, optional): The batch size.
652
- device (torch.device or str): The device to use for computation.
653
-
654
- Returns:
655
- Octree: The initialized Octree object.
656
- '''
657
-
658
- octree = Octree(depth, full_depth, batch_size, device)
659
- for d in range(full_depth+1):
660
- octree.octree_grow_full(depth=d)
661
- return octree
1
+ # --------------------------------------------------------
2
+ # Octree-based Sparse Convolutional Neural Networks
3
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Peng-Shuai Wang
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from typing import Union, List
11
+
12
+ import ocnn
13
+ from ocnn.octree.points import Points
14
+ from ocnn.octree.shuffled_key import xyz2key, key2xyz
15
+ from ocnn.utils import range_grid, scatter_add, cumsum, trunc_div
16
+
17
+
18
+ class Octree:
19
+ r''' Builds an octree from an input point cloud.
20
+
21
+ Args:
22
+ depth (int): The octree depth.
23
+ full_depth (int): The octree layers with a depth small than
24
+ :attr:`full_depth` are forced to be full.
25
+ batch_size (int): The octree batch size.
26
+ device (torch.device or str): Choose from :obj:`cpu` and :obj:`gpu`.
27
+ (default: :obj:`cpu`)
28
+
29
+ .. note::
30
+ The octree data structure requires that if an octree node has children nodes,
31
+ the number of children nodes is exactly 8, in which some of the nodes are
32
+ empty and some nodes are non-empty. The properties of an octree, including
33
+ :obj:`keys`, :obj:`children` and :obj:`neighs`, contain both non-empty and
34
+ empty nodes, and other properties, including :obj:`features`, :obj:`normals`
35
+ and :obj:`points`, contain only non-empty nodes.
36
+
37
+ .. note::
38
+ The point cloud must be strictly in range :obj:`[-1, 1]`. A good practice
39
+ is to normalize it into :obj:`[-0.99, 0.99]` or :obj:`[0.9, 0.9]` to retain
40
+ some margin.
41
+ '''
42
+
43
+ def __init__(self, depth: int, full_depth: int = 2, batch_size: int = 1,
44
+ device: Union[torch.device, str] = 'cpu', **kwargs):
45
+ super().__init__()
46
+ self.depth = depth
47
+ self.full_depth = full_depth
48
+ self.batch_size = batch_size
49
+ self.device = device
50
+
51
+ self.reset()
52
+
53
+ def reset(self):
54
+ r''' Resets the Octree status and constructs several lookup tables.
55
+ '''
56
+
57
+ # octree features in each octree layers
58
+ num = self.depth + 1
59
+ self.keys = [None] * num
60
+ self.children = [None] * num
61
+ self.neighs = [None] * num
62
+ self.features = [None] * num
63
+ self.normals = [None] * num
64
+ self.points = [None] * num
65
+
66
+ # octree node numbers in each octree layers.
67
+ # TODO: decide whether to settle them to 'gpu' or not?
68
+ self.nnum = torch.zeros(num, dtype=torch.long)
69
+ self.nnum_nempty = torch.zeros(num, dtype=torch.long)
70
+
71
+ # the following properties are valid after `merge_octrees`.
72
+ # TODO: make them valid after `octree_grow`, `octree_split` and `build_octree`
73
+ batch_size = self.batch_size
74
+ self.batch_nnum = torch.zeros(num, batch_size, dtype=torch.long)
75
+ self.batch_nnum_nempty = torch.zeros(num, batch_size, dtype=torch.long)
76
+
77
+ # construct the look up tables for neighborhood searching
78
+ device = self.device
79
+ center_grid = range_grid(2, 3, device) # (8, 3)
80
+ displacement = range_grid(-1, 1, device) # (27, 3)
81
+ neigh_grid = center_grid.unsqueeze(1) + displacement # (8, 27, 3)
82
+ parent_grid = trunc_div(neigh_grid, 2)
83
+ child_grid = neigh_grid % 2
84
+ self.lut_parent = torch.sum(
85
+ parent_grid * torch.tensor([9, 3, 1], device=device), dim=2)
86
+ self.lut_child = torch.sum(
87
+ child_grid * torch.tensor([4, 2, 1], device=device), dim=2)
88
+
89
+ # lookup tables for different kernel sizes
90
+ self.lut_kernel = {
91
+ '222': torch.tensor([13, 14, 16, 17, 22, 23, 25, 26], device=device),
92
+ '311': torch.tensor([4, 13, 22], device=device),
93
+ '131': torch.tensor([10, 13, 16], device=device),
94
+ '113': torch.tensor([12, 13, 14], device=device),
95
+ '331': torch.tensor([1, 4, 7, 10, 13, 16, 19, 22, 25], device=device),
96
+ '313': torch.tensor([3, 4, 5, 12, 13, 14, 21, 22, 23], device=device),
97
+ '133': torch.tensor([9, 10, 11, 12, 13, 14, 15, 16, 17], device=device),
98
+ }
99
+
100
+ def key(self, depth: int, nempty: bool = False):
101
+ r''' Returns the shuffled key of each octree node.
102
+
103
+ Args:
104
+ depth (int): The depth of the octree.
105
+ nempty (bool): If True, returns the results of non-empty octree nodes.
106
+ '''
107
+
108
+ key = self.keys[depth]
109
+ if nempty:
110
+ mask = self.nempty_mask(depth)
111
+ key = key[mask]
112
+ return key
113
+
114
+ def xyzb(self, depth: int, nempty: bool = False):
115
+ r''' Returns the xyz coordinates and the batch indices of each octree node.
116
+
117
+ Args:
118
+ depth (int): The depth of the octree.
119
+ nempty (bool): If True, returns the results of non-empty octree nodes.
120
+ '''
121
+
122
+ key = self.key(depth, nempty)
123
+ return key2xyz(key, depth)
124
+
125
+ def batch_id(self, depth: int, nempty: bool = False):
126
+ r''' Returns the batch indices of each octree node.
127
+
128
+ Args:
129
+ depth (int): The depth of the octree.
130
+ nempty (bool): If True, returns the results of non-empty octree nodes.
131
+ '''
132
+
133
+ batch_id = self.keys[depth] >> 48
134
+ if nempty:
135
+ mask = self.nempty_mask(depth)
136
+ batch_id = batch_id[mask]
137
+ return batch_id
138
+
139
+ def nempty_mask(self, depth: int):
140
+ r''' Returns a binary mask which indicates whether the cooreponding octree
141
+ node is empty or not.
142
+
143
+ Args:
144
+ depth (int): The depth of the octree.
145
+ '''
146
+
147
+ return self.children[depth] >= 0
148
+
149
def build_octree(self, point_cloud: Points):
  r''' Builds an octree from a point cloud.

  Args:
    point_cloud (Points): The input point cloud.

  Returns:
    torch.Tensor: For each input point, the index of the non-empty node at
    the deepest layer (:attr:`self.depth`) that the point falls into.

  .. note::
    The point cloud must be strictly in range :obj:`[-1, 1]`. A good practice
    is to normalize it into :obj:`[-0.99, 0.99]` or :obj:`[-0.9, 0.9]` to
    retain some margin.
  '''

  self.device = point_cloud.device
  assert point_cloud.batch_size == self.batch_size, 'Inconsistent batch_size'

  # normalize points from [-1, 1] to [0, 2^depth]. #[L:Scale]
  scale = 2 ** (self.depth - 1)
  points = (point_cloud.points + 1.0) * scale

  # get the shuffled keys and sort them; `idx` maps every input point to the
  # position of its key among the sorted unique keys
  x, y, z = points[:, 0], points[:, 1], points[:, 2]
  b = None if self.batch_size == 1 else point_cloud.batch_id.view(-1)
  key = xyz2key(x, y, z, b, self.depth)
  node_key, idx, counts = torch.unique(
      key, sorted=True, return_inverse=True, return_counts=True)

  # layers 0 to full_depth: the octree is full in these layers
  for d in range(self.full_depth + 1):
    self.octree_grow_full(d, update_neigh=False)

  # layers from self.depth down to full_depth + 1, built bottom-up
  for d in range(self.depth, self.full_depth, -1):
    # compute parent keys, i.e. keys of layer (d - 1); node_key is sorted, so
    # duplicated parent keys are consecutive
    pkey = node_key >> 3
    pkey, pidx = torch.unique_consecutive(pkey, return_inverse=True)

    # augmented key: every parent is expanded to all of its 8 children
    key = (pkey.unsqueeze(-1) << 3) + torch.arange(8, device=self.device)
    self.keys[d] = key.view(-1)
    self.nnum[d] = key.numel()
    self.nnum_nempty[d] = node_key.numel()

    # children: non-empty slots get consecutive indices, empty slots get -1
    addr = (pidx << 3) | (node_key % 8)
    children = -torch.ones(
        self.nnum[d].item(), dtype=torch.int32, device=self.device)
    children[addr] = torch.arange(
        self.nnum_nempty[d], dtype=torch.int32, device=self.device)
    self.children[d] = children

    # cache pkey for the next iteration.
    # `node_key >> 3` above also moved the batch index (stored from bit 48)
    # down by 3 bits, so it now starts at bit 45: use `pkey >> 45` to read it,
    # put it back at bit 48 and keep the lower 45 bits as the spatial key.
    node_key = pkey if self.batch_size == 1 else \
        ((pkey >> 45) << 48) | (pkey & ((1 << 45) - 1))

  # set the children for the layer full_depth;
  # now node_key holds the non-empty keys of layer full_depth
  d = self.full_depth
  children = -torch.ones_like(self.children[d])
  # a full layer stores nodes densely per batch element, so the node with
  # batch index b and spatial key k sits at b * 8^d + k
  nempty_idx = node_key if self.batch_size == 1 else \
      ((node_key >> 48) << (3 * d)) | (node_key & ((1 << 48) - 1))
  children[nempty_idx] = torch.arange(
      node_key.numel(), dtype=torch.int32, device=self.device)
  self.children[d] = children
  self.nnum_nempty[d] = node_key.numel()

  # average the signal for the last octree layer
  d = self.depth
  points = scatter_add(points, idx, dim=0)  # points is rescaled in [L:Scale]
  self.points[d] = points / counts.unsqueeze(1)
  if point_cloud.normals is not None:
    normals = scatter_add(point_cloud.normals, idx, dim=0)
    self.normals[d] = F.normalize(normals)
  if point_cloud.features is not None:
    features = scatter_add(point_cloud.features, idx, dim=0)
    self.features[d] = features / counts.unsqueeze(1)

  return idx
229
+
230
def octree_grow_full(self, depth: int, update_neigh: bool = True):
  r''' Builds a full octree layer at :attr:`depth`, which is essentially a
  dense volumetric grid with :obj:`8^depth` nodes per batch element.

  Args:
    depth (int): The depth of the octree; must not exceed :attr:`full_depth`.
    update_neigh (bool): If True, construct the neighborhood indices.

  Raises:
    AssertionError: If :attr:`depth` is larger than :attr:`full_depth`.
  '''

  # check
  assert depth <= self.full_depth, \
      'depth ({}) must not exceed full_depth ({})'.format(depth, self.full_depth)

  # node number: every node of a full layer is non-empty
  num = 1 << (3 * depth)
  self.nnum[depth] = num * self.batch_size
  self.nnum_nempty[depth] = num * self.batch_size

  # update key: spatial key in the low bits, batch index from bit 48
  key = torch.arange(num, dtype=torch.long, device=self.device)
  bs = torch.arange(self.batch_size, dtype=torch.long, device=self.device)
  key = key.unsqueeze(0) | (bs.unsqueeze(1) << 48)
  self.keys[depth] = key.view(-1)

  # update children: node i is the i-th non-empty node
  self.children[depth] = torch.arange(
      num * self.batch_size, dtype=torch.int32, device=self.device)

  # update neigh if needed
  if update_neigh:
    self.construct_neigh(depth)
259
+
260
def octree_split(self, split: torch.Tensor, depth: int):
  r''' Sets whether the octree nodes at :attr:`depth` are split or not.

  Args:
    split (torch.Tensor): The input tensor with its element indicating status
      of each octree node: 0 - empty, 1 - non-empty or split.
    depth (int): The depth of current octree.
  '''

  # split -> children: an exclusive prefix sum gives every split node its
  # index among the split nodes; the trailing entry (cumsum is expected to
  # return N + 1 values here) is the total number of split nodes
  empty = split == 0
  offsets = cumsum(split, dim=0, exclusive=True)
  children, nnum_nempty = torch.split(offsets, [split.shape[0], 1])
  children[empty] = -1

  # boundary case, make sure that at least one octree node is split
  if nnum_nempty == 0:
    nnum_nempty = 1
    children[0] = 0

  # update octree
  self.children[depth] = children.int()
  self.nnum_nempty[depth] = nnum_nempty
283
+
284
def octree_grow(self, depth: int, update_neigh: bool = True):
  r''' Grows the octree and updates the relevant properties. And in most
  cases, call :func:`Octree.octree_split` to update the splitting status of
  the octree before this function.

  Args:
    depth (int): The depth of the octree. It may exceed the current depth by
      at most one, in which case a new (empty) layer is appended first.
    update_neigh (bool): If True, construct the neighborhood indices.
  '''

  # increase the octree depth if required
  if depth > self.depth:
    assert depth == self.depth + 1
    self.depth = depth
    self.keys.append(None)
    self.children.append(None)
    self.neighs.append(None)
    self.features.append(None)
    self.normals.append(None)
    self.points.append(None)
    zero = torch.zeros(1, dtype=torch.long, device=self.nnum.device)
    self.nnum = torch.cat([self.nnum, zero])
    self.nnum_nempty = torch.cat([self.nnum_nempty, zero])
    # batch_nnum / batch_nnum_nempty have one row per depth and one column per
    # batch element (see merge_octrees), so the appended zero row must match
    # their trailing shape instead of being a fixed (1, 1) block.
    row = torch.zeros((1,) + tuple(self.batch_nnum.shape[1:]),
                      dtype=torch.long, device=self.batch_nnum.device)
    self.batch_nnum = torch.cat([self.batch_nnum, row], dim=0)
    row = torch.zeros((1,) + tuple(self.batch_nnum_nempty.shape[1:]),
                      dtype=torch.long, device=self.batch_nnum_nempty.device)
    self.batch_nnum_nempty = torch.cat([self.batch_nnum_nempty, row], dim=0)

  # node number: every non-empty parent gets 8 children
  nnum = self.nnum_nempty[depth-1] * 8
  self.nnum[depth] = nnum
  self.nnum_nempty[depth] = nnum  # initialize self.nnum_nempty

  # update keys: shift the spatial part of each parent key left by 3 bits,
  # keep the batch index at bit 48, and enumerate the 8 children
  key = self.key(depth-1, nempty=True)
  batch_id = (key >> 48) << 48
  key = (key & ((1 << 48) - 1)) << 3
  key = key | batch_id
  key = key.unsqueeze(1) + torch.arange(8, device=key.device)
  self.keys[depth] = key.view(-1)

  # update children
  self.children[depth] = torch.arange(
      nnum, dtype=torch.int32, device=self.device)

  # update neighs
  if update_neigh:
    self.construct_neigh(depth)
331
+
332
def construct_neigh(self, depth: int):
  r''' Constructs the :obj:`3x3x3` neighbors for each octree node.

  The result is stored in :obj:`self.neighs[depth]` as a tensor with 27
  columns, one row per octree node at :attr:`depth`; :obj:`-1` marks a
  missing neighbor.

  Args:
    depth (int): The octree depth with a value larger than 0 (:obj:`>0`).
  '''

  if depth <= self.full_depth:
    # Full layer: the nodes form a dense grid, so the neighbors are computed
    # directly from the node coordinates.
    device = self.device
    nnum = 1 << (3 * depth)
    key = torch.arange(nnum, dtype=torch.long, device=device)
    x, y, z, _ = key2xyz(key, depth)
    xyz = torch.stack([x, y, z], dim=-1)  # (N, 3)
    grid = range_grid(-1, 1, device)  # (27, 3)
    xyz = xyz.unsqueeze(1) + grid  # (N, 27, 3)
    xyz = xyz.view(-1, 3)  # (N*27, 3)
    neigh = xyz2key(xyz[:, 0], xyz[:, 1], xyz[:, 2], depth=depth)

    # Offset by batch: in a full layer, the nodes of batch element b start at
    # index b * nnum.
    bs = torch.arange(self.batch_size, dtype=torch.long, device=device)
    neigh = neigh + bs.unsqueeze(1) * nnum  # (N*27,) + (B, 1) -> (B, N*27)

    # Neighbors outside the [0, 2^depth) grid do not exist.
    bound = 1 << depth
    invalid = torch.logical_or((xyz < 0).any(1), (xyz >= bound).any(1))
    neigh[:, invalid] = -1
    self.neighs[depth] = neigh.view(-1, 27)  # (B*N, 27)

  else:
    # Non-full layer: derive the neighbors from the parent layer. Only the
    # non-empty parents (children index >= 0) have children at this depth.
    child_p = self.children[depth-1]
    nempty = child_p >= 0
    neigh_p = self.neighs[depth-1][nempty]  # (N, 27)
    # NOTE(review): lut_parent presumably selects, for each of the 8 children,
    # the parent-level neighbors covering the child's 3x3x3 window.
    neigh_p = neigh_p[:, self.lut_parent]  # (N, 8, 27)
    # A -1 entry in neigh_p indexes the last element here; such entries are
    # masked out by `invalid` below.
    child_p = child_p[neigh_p]  # (N, 8, 27)
    invalid = torch.logical_or(child_p < 0, neigh_p < 0)  # (N, 8, 27)
    # Children of a node with index c occupy indices 8c .. 8c+7.
    neigh = child_p * 8 + self.lut_child
    neigh[invalid] = -1
    self.neighs[depth] = neigh.view(-1, 27)
368
+
369
def construct_all_neigh(self):
  r''' Builds the neighborhood indices of every layer from depth 1 up to and
  including :attr:`depth`.
  '''

  depths = range(1, self.depth + 1)
  for d in depths:
    self.construct_neigh(d)
375
+
376
def search_xyzb(self, query: torch.Tensor, depth: int, nempty: bool = False):
  r''' Searches the octree nodes given the query points.

  Args:
    query (torch.Tensor): The coordinates of query points with shape
      :obj:`(N, 4)`. The first 3 channels of the coordinates are :obj:`x`,
      :obj:`y`, and :obj:`z`, and the last channel is the batch index. Note
      that the coordinates must be in range :obj:`[0, 2^depth)`.
    depth (int): The depth of the octree layer.
    nempty (bool): If true, only searches the non-empty octree nodes.

  Returns:
    torch.Tensor: The index of the matching octree node for each query point,
    as produced by :meth:`search_key` (:obj:`-1` if not found).
  '''

  key = xyz2key(query[:, 0], query[:, 1], query[:, 2], query[:, 3], depth)
  idx = self.search_key(key, depth, nempty)
  return idx
391
+
392
def search_key(self, query: torch.Tensor, depth: int, nempty: bool = False):
  r''' Searches the octree nodes given the keys of query points.

  Args:
    query (torch.Tensor): The keys of query points with shape :obj:`(N,)`,
      which are computed from the coordinates of query points.
    depth (int): The depth of the octree layer.
    nempty (bool): If true, only searches the non-empty octree nodes.

  Returns:
    torch.Tensor: The index of the octree node whose key equals each query
    key, or :obj:`-1` if no such node exists.
  '''

  key = self.key(depth, nempty)
  # `key` is always a sorted 1-D sequence, so a binary search yields the
  # candidate position. `torch.bucketize` is similar to `torch.searchsorted`
  # but performs fewer dimension checks, which is slightly faster according
  # to the pytorch-1.9.1 docs:
  # https://pytorch.org/docs/1.9.1/generated/torch.searchsorted.html
  idx = torch.bucketize(query, key)

  valid = idx < key.shape[0]  # valid if in-bound
  found = key[idx[valid]] == query[valid]
  valid[valid.clone()] = found  # valid if found
  idx[valid.logical_not()] = -1  # set to -1 if invalid
  return idx
415
+
416
def get_neigh(self, depth: int, kernel: str = '333', stride: int = 1,
              nempty: bool = False):
  r''' Returns the neighborhood indices at :attr:`depth` for a kernel shape.

  Args:
    depth (int): The octree depth with a value larger than 0 (:obj:`>0`).
    kernel (str): The kernel shape from :obj:`333`, :obj:`311`, :obj:`131`,
      :obj:`113`, :obj:`222`, :obj:`331`, :obj:`133`, and :obj:`313`.
    stride (int): The stride of neighborhoods (:obj:`1` or :obj:`2`). With
      stride :obj:`2`, only the neighborhood of the first sibling in each
      group of 8 is returned.
    nempty (bool): If True, only returns the neighborhoods of the non-empty
      octree nodes, with indices remapped to the non-empty numbering.

  Raises:
    ValueError: If :attr:`stride` or :attr:`kernel` is not supported.
  '''

  if stride not in (1, 2):
    raise ValueError('Unsupported stride {}'.format(stride))

  cached = self.neighs[depth]
  if stride == 2:
    # take every 8th row; clone so the cached neighbors stay untouched by the
    # in-place remapping below
    neigh = cached[::8].clone()
  else:
    neigh = cached

  if nempty:
    child = self.children[depth]
    if stride == 1:
      # boolean indexing copies, so the remapping cannot touch the cache
      neigh = neigh[child >= 0]
    has_neigh = neigh >= 0
    neigh[has_neigh] = child[neigh[has_neigh]].long()  # remap the index

  if kernel == '333':
    return neigh
  if kernel not in self.lut_kernel:
    raise ValueError('Unsupported kernel {}'.format(kernel))
  return neigh[:, self.lut_kernel[kernel]]
454
+
455
def get_input_feature(self, feature: str, nempty: bool = False):
  r''' Assembles the initial input feature stored in the octree.

  Args:
    feature (str): Characters (case-insensitive) selecting which signals to
      concatenate, always in this order:
      :obj:`N` - normals (3 channels);
      :obj:`D` - displacement, the dot product of the local coordinates and
      the normal (1 channel);
      :obj:`L` - local coordinates of the averaged points inside their octree
      node (3 channels);
      :obj:`P` - global coordinates rescaled to :obj:`[-1, 1]` (3 channels);
      :obj:`F` - other stored features such as colors (k channels).
    nempty (bool): If False, the result is padded to all octree nodes of the
      deepest layer; otherwise only non-empty nodes are returned.
  '''

  depth = self.depth
  flags = feature.upper()
  parts = []

  if 'N' in flags:
    parts.append(self.normals[depth])

  if 'L' in flags or 'D' in flags:
    # offset of each averaged point from the center of its octree node
    local_points = self.points[depth].frac() - 0.5
  if 'D' in flags:
    parts.append(
        torch.sum(local_points * self.normals[depth], dim=1, keepdim=True))
  if 'L' in flags:
    parts.append(local_points)

  if 'P' in flags:
    # map the stored range [0, 2^depth] back to [-1, 1]
    parts.append(self.points[depth] * 2 ** (1 - depth) - 1.0)

  if 'F' in flags:
    parts.append(self.features[depth])

  out = torch.cat(parts, dim=1)
  return out if nempty else ocnn.nn.octree_pad(out, self, depth)
499
+
500
def to_points(self, rescale: bool = True):
  r''' Converts averaged points in the octree to a point cloud.

  Args:
    rescale (bool): rescale the xyz coordinates to [-1, 1] if True.

  Returns:
    Points: One point per non-empty node of the deepest layer.
  '''

  depth = self.depth
  xyz = self.points[depth]
  if xyz is None:
    # xyz is None when the octree is predicted by a neural network: fall back
    # to the centers of the non-empty octree nodes.
    x, y, z, batch_id = self.xyzb(depth, nempty=True)
    xyz = torch.stack([x, y, z], dim=1) + 0.5
  else:
    # by default, use the average points generated when building the octree
    # from the input point cloud
    batch_id = self.batch_id(depth, nempty=True)

  # normalize xyz to [-1, 1] since the points are in range [0, 2^depth]
  if rescale:
    scale = 2 ** (1 - depth)
    xyz = xyz * scale - 1.0

  # construct Points
  out = Points(xyz, self.normals[depth], self.features[depth],
               batch_id=batch_id, batch_size=self.batch_size)
  return out
529
+
530
def to(self, device: Union[torch.device, str], non_blocking: bool = False):
  r''' Returns a copy of the octree on :attr:`device`, or :obj:`self` when
  it already lives there.

  Args:
    device (torch.device or str): The destination device.
    non_blocking (bool): If True and the source is in pinned memory, the copy
      will be asynchronous with respect to the host. Otherwise, the argument
      has no effect. Default: False.
  '''

  target = torch.device(device) if isinstance(device, str) else device

  # already on the requested device: nothing to copy
  if self.device == target:
    return self

  def _move_all(items):
    # move every tensor, keep None placeholders (and turn non-tensors into None)
    moved = []
    for item in items:
      if isinstance(item, torch.Tensor):
        moved.append(item.to(target, non_blocking=non_blocking))
      else:
        moved.append(None)
    return moved

  # build a fresh Octree on the destination device and fill in its layers
  copy = Octree(self.depth, self.full_depth, self.batch_size, target)
  copy.keys = _move_all(self.keys)
  copy.children = _move_all(self.children)
  copy.neighs = _move_all(self.neighs)
  copy.features = _move_all(self.features)
  copy.normals = _move_all(self.normals)
  copy.points = _move_all(self.points)
  # node counters are cloned but stay on their current device
  # TODO: whether to move nnum to the self.device?
  copy.nnum = self.nnum.clone()
  copy.nnum_nempty = self.nnum_nempty.clone()
  copy.batch_nnum = self.batch_nnum.clone()
  copy.batch_nnum_nempty = self.batch_nnum_nempty.clone()
  return copy
564
+
565
def cuda(self, non_blocking: bool = False):
  r''' Moves the octree to the default CUDA device via :meth:`to`. '''

  target = 'cuda'
  return self.to(target, non_blocking)
569
+
570
def cpu(self):
  r''' Moves the octree to the CPU via :meth:`to`. '''

  target = 'cpu'
  return self.to(target)
574
+
575
+
576
def merge_octrees(octrees: List['Octree']):
  r''' Merges a list of octrees into one batch.

  Args:
    octrees (List[Octree]): A non-empty list of octrees to merge. All of them
      must share the same :obj:`depth`, :obj:`full_depth` and :obj:`device`.

  Returns:
    Octree: The merged octree, whose batch size is :obj:`len(octrees)`.

  Raises:
    ValueError: If :attr:`octrees` is empty.
  '''

  if len(octrees) == 0:
    raise ValueError('merge_octrees requires a non-empty list of octrees')

  # init and check
  first = octrees[0]
  octree = Octree(depth=first.depth, full_depth=first.full_depth,
                  batch_size=len(octrees), device=first.device)
  for i in range(1, octree.batch_size):
    condition = (octrees[i].depth == octree.depth and
                 octrees[i].full_depth == octree.full_depth and
                 octrees[i].device == octree.device)
    assert condition, 'The check of merge_octrees failed'

  # node num: batch_nnum[d, i] is the node count of octree i at depth d
  batch_nnum = torch.stack([o.nnum for o in octrees], dim=1)
  batch_nnum_nempty = torch.stack([o.nnum_nempty for o in octrees], dim=1)
  octree.nnum = torch.sum(batch_nnum, dim=1)
  octree.nnum_nempty = torch.sum(batch_nnum_nempty, dim=1)
  octree.batch_nnum = batch_nnum
  octree.batch_nnum_nempty = batch_nnum_nempty
  # start offset of each octree's non-empty nodes inside a merged layer
  nnum_cum = cumsum(batch_nnum_nempty, dim=1, exclusive=True)

  # merge octree properties
  for d in range(octree.depth + 1):
    # keys: clear the batch bits (from bit 48 upwards) and write the position
    # in the list instead. NOTE(review): this presumes every input octree has
    # batch_size 1 — confirm with callers.
    keys = [(o.keys[d] & ((1 << 48) - 1)) | (i << 48)
            for i, o in enumerate(octrees)]
    octree.keys[d] = torch.cat(keys, dim=0)

    # children: shift the non-empty indices by the offsets above; `clone`
    # avoids modifying the original octrees
    children = []
    for i, o in enumerate(octrees):
      child = o.children[d].clone()
      mask = child >= 0
      child[mask] = child[mask] + nnum_cum[d, i]
      children.append(child)
    octree.children[d] = torch.cat(children, dim=0)

  # features, normals and points are only merged at the deepest layer
  d = octree.depth
  if first.features[d] is not None:
    octree.features[d] = torch.cat([o.features[d] for o in octrees], dim=0)
  if first.normals[d] is not None:
    octree.normals[d] = torch.cat([o.normals[d] for o in octrees], dim=0)
  if first.points[d] is not None:
    octree.points[d] = torch.cat([o.points[d] for o in octrees], dim=0)

  return octree
640
+
641
+
642
def init_octree(depth: int, full_depth: int = 2, batch_size: int = 1,
                device: Union[torch.device, str] = 'cpu'):
  r''' Creates an :class:`Octree` whose layers 0 through :attr:`full_depth`
  (inclusive) are grown as full, dense layers.

  Args:
    depth (int): The depth of the octree.
    full_depth (int): The octree layers with a depth no larger than
      :attr:`full_depth` are forced to be full.
    batch_size (int, optional): The batch size.
    device (torch.device or str): The device to use for computation.

  Returns:
    Octree: The initialized Octree object.
  '''

  result = Octree(depth, full_depth, batch_size, device)
  for layer in range(full_depth + 1):
    result.octree_grow_full(depth=layer)
  return result