ocnn 2.2.8__py3-none-any.whl → 2.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. ocnn/__init__.py +24 -24
  2. ocnn/dataset.py +160 -160
  3. ocnn/models/__init__.py +29 -29
  4. ocnn/models/autoencoder.py +155 -155
  5. ocnn/models/hrnet.py +192 -192
  6. ocnn/models/image2shape.py +128 -128
  7. ocnn/models/lenet.py +46 -46
  8. ocnn/models/ounet.py +94 -94
  9. ocnn/models/resnet.py +53 -53
  10. ocnn/models/segnet.py +72 -72
  11. ocnn/models/unet.py +105 -105
  12. ocnn/modules/__init__.py +26 -26
  13. ocnn/modules/modules.py +303 -303
  14. ocnn/modules/resblocks.py +158 -158
  15. ocnn/nn/__init__.py +45 -44
  16. ocnn/nn/kernels/__init__.py +14 -0
  17. ocnn/nn/kernels/autotuner.py +416 -0
  18. ocnn/nn/kernels/config.py +67 -0
  19. ocnn/nn/kernels/conv_bwd_implicit_gemm.py +229 -0
  20. ocnn/nn/kernels/conv_bwd_implicit_gemm_splitk.py +347 -0
  21. ocnn/nn/kernels/conv_fwd_implicit_gemm.py +109 -0
  22. ocnn/nn/kernels/conv_fwd_implicit_gemm_splitk.py +150 -0
  23. ocnn/nn/kernels/utils.py +44 -0
  24. ocnn/nn/octree2col.py +53 -53
  25. ocnn/nn/octree2vox.py +50 -50
  26. ocnn/nn/octree_align.py +46 -46
  27. ocnn/nn/octree_conv.py +430 -429
  28. ocnn/nn/octree_conv_t.py +148 -0
  29. ocnn/nn/octree_drop.py +55 -55
  30. ocnn/nn/octree_dwconv.py +222 -222
  31. ocnn/nn/octree_gconv.py +79 -79
  32. ocnn/nn/octree_interp.py +196 -196
  33. ocnn/nn/octree_norm.py +126 -126
  34. ocnn/nn/octree_pad.py +39 -39
  35. ocnn/nn/octree_pool.py +200 -200
  36. ocnn/octree/__init__.py +22 -22
  37. ocnn/octree/octree.py +770 -770
  38. ocnn/octree/points.py +384 -323
  39. ocnn/octree/shuffled_key.py +115 -115
  40. ocnn/utils.py +205 -205
  41. {ocnn-2.2.8.dist-info → ocnn-2.3.0.dist-info}/METADATA +117 -111
  42. ocnn-2.3.0.dist-info/RECORD +45 -0
  43. {ocnn-2.2.8.dist-info → ocnn-2.3.0.dist-info}/WHEEL +1 -1
  44. {ocnn-2.2.8.dist-info → ocnn-2.3.0.dist-info}/licenses/LICENSE +21 -21
  45. ocnn-2.2.8.dist-info/RECORD +0 -36
  46. {ocnn-2.2.8.dist-info → ocnn-2.3.0.dist-info}/top_level.txt +0 -0
ocnn/octree/octree.py CHANGED
@@ -1,770 +1,770 @@
1
- # --------------------------------------------------------
2
- # Octree-based Sparse Convolutional Neural Networks
3
- # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
- # Licensed under The MIT License [see LICENSE for details]
5
- # Written by Peng-Shuai Wang
6
- # --------------------------------------------------------
7
-
8
- import torch
9
- import torch.nn.functional as F
10
- from typing import Union, List
11
-
12
- import ocnn
13
- from ocnn.octree.points import Points
14
- from ocnn.octree.shuffled_key import xyz2key, key2xyz
15
- from ocnn.utils import range_grid, scatter_add, cumsum, trunc_div
16
-
17
-
18
class Octree:
  r''' Builds an octree from an input point cloud.

  Args:
    depth (int): The octree depth.
    full_depth (int): The octree layers with a depth smaller than or equal to
        :attr:`full_depth` are forced to be full.
    batch_size (int): The octree batch size.
    device (torch.device or str): Choose from :obj:`cpu` and :obj:`cuda`.
        (default: :obj:`cpu`)

  .. note::
    The octree data structure requires that if an octree node has children nodes,
    the number of children nodes is exactly 8, in which some of the nodes are
    empty and some nodes are non-empty. The properties of an octree, including
    :obj:`keys`, :obj:`children` and :obj:`neighs`, contain both non-empty and
    empty nodes, and other properties, including :obj:`features`, :obj:`normals`
    and :obj:`points`, contain only non-empty nodes.

  .. note::
    The point cloud must be strictly in range :obj:`[-1, 1]`. A good practice
    is to normalize it into :obj:`[-0.99, 0.99]` or :obj:`[-0.9, 0.9]` to
    retain some margin.
  '''

  def __init__(self, depth: int, full_depth: int = 2, batch_size: int = 1,
               device: Union[torch.device, str] = 'cpu', **kwargs):
    super().__init__()
    # configurations for initialization
    self.depth = depth
    self.full_depth = full_depth
    self.batch_size = batch_size
    self.device = device

    # properties after building the octree
    self.reset()

  def reset(self):
    r''' Resets the Octree status and constructs several lookup tables.
    '''

    # octree properties of each octree layer
    num = self.depth + 1
    self.keys = [None] * num
    self.children = [None] * num
    self.neighs = [None] * num
    self.features = [None] * num
    self.normals = [None] * num
    self.points = [None] * num

    # self.nempty_masks, self.nempty_indices and self.nempty_neighs are
    # for handling of non-empty nodes and are constructed on demand
    self.nempty_masks = [None] * num
    self.nempty_indices = [None] * num
    self.nempty_neighs = [None] * num

    # octree node numbers in each octree layer.
    # These are small 1-D tensors; just keep them on CPUs
    self.nnum = torch.zeros(num, dtype=torch.long)
    self.nnum_nempty = torch.zeros(num, dtype=torch.long)

    # the following properties are only valid after `merge_octrees`.
    # TODO: make them valid after `octree_grow`, `octree_split` and `build_octree`
    batch_size = self.batch_size
    self.batch_nnum = torch.zeros(num, batch_size, dtype=torch.long)
    self.batch_nnum_nempty = torch.zeros(num, batch_size, dtype=torch.long)

    # construct the look up tables for neighborhood searching
    self._construct_lookup_tables()

  def _construct_lookup_tables(self):
    r''' Builds the neighborhood-search lookup tables on :attr:`device`.

    :obj:`lut_parent` and :obj:`lut_child` have shape :obj:`(8, 27)`. For each
    of the 8 sibling positions and each of the 27 offsets of a :obj:`3x3x3`
    neighborhood, :obj:`lut_parent` gives the slot in the parent's
    :obj:`3x3x3` neighborhood that contains the neighbor, and :obj:`lut_child`
    gives the neighbor's child index (0-7) inside that parent.
    :obj:`lut_kernel` selects the neighborhood slots used by kernel shapes
    other than :obj:`333`.
    '''

    device = self.device
    center_grid = range_grid(2, 3, device)               # (8, 3)
    displacement = range_grid(-1, 1, device)             # (27, 3)
    neigh_grid = center_grid.unsqueeze(1) + displacement  # (8, 27, 3)
    parent_grid = trunc_div(neigh_grid, 2)
    child_grid = neigh_grid % 2
    self.lut_parent = torch.sum(
        parent_grid * torch.tensor([9, 3, 1], device=device), dim=2)
    self.lut_child = torch.sum(
        child_grid * torch.tensor([4, 2, 1], device=device), dim=2)

    # lookup tables for different kernel sizes
    self.lut_kernel = {
        '222': torch.tensor([13, 14, 16, 17, 22, 23, 25, 26], device=device),
        '311': torch.tensor([4, 13, 22], device=device),
        '131': torch.tensor([10, 13, 16], device=device),
        '113': torch.tensor([12, 13, 14], device=device),
        '331': torch.tensor([1, 4, 7, 10, 13, 16, 19, 22, 25], device=device),
        '313': torch.tensor([3, 4, 5, 12, 13, 14, 21, 22, 23], device=device),
        '133': torch.tensor([9, 10, 11, 12, 13, 14, 15, 16, 17], device=device),
    }

  def key(self, depth: int, nempty: bool = False):
    r''' Returns the shuffled key of each octree node.

    Args:
      depth (int): The depth of the octree.
      nempty (bool): If True, returns the results of non-empty octree nodes.
    '''

    key = self.keys[depth]
    if nempty:
      idx = self.nempty_index(depth)
      key = key[idx]
    return key

  def xyzb(self, depth: int, nempty: bool = False):
    r''' Returns the xyz coordinates and the batch indices of each octree node.

    Args:
      depth (int): The depth of the octree.
      nempty (bool): If True, returns the results of non-empty octree nodes.
    '''

    key = self.key(depth, nempty)
    return key2xyz(key, depth)

  def batch_id(self, depth: int, nempty: bool = False):
    r''' Returns the batch indices of each octree node, which are stored in
    the bits above bit 48 of the keys.

    Args:
      depth (int): The depth of the octree.
      nempty (bool): If True, returns the results of non-empty octree nodes.
    '''

    batch_id = self.keys[depth] >> 48
    if nempty:
      idx = self.nempty_index(depth)
      batch_id = batch_id[idx]
    return batch_id

  def nempty_mask(self, depth: int, reset: bool = False):
    r''' Returns a binary mask which indicates whether the corresponding octree
    node is non-empty (i.e. its :obj:`children` entry is :obj:`>= 0`).

    Args:
      depth (int): The depth of the octree.
      reset (bool): If True, recomputes the mask.
    '''

    if self.nempty_masks[depth] is None or reset:
      self.nempty_masks[depth] = self.children[depth] >= 0
    return self.nempty_masks[depth]

  def nempty_index(self, depth: int, reset: bool = False):
    r''' Returns the indices of non-empty octree nodes.

    Args:
      depth (int): The depth of the octree.
      reset (bool): If True, recomputes the indices.
    '''

    if self.nempty_indices[depth] is None or reset:
      mask = self.nempty_mask(depth)
      rng = torch.arange(mask.shape[0], device=mask.device, dtype=torch.long)
      self.nempty_indices[depth] = rng[mask]
    return self.nempty_indices[depth]

  def nempty_neigh(self, depth: int, reset: bool = False):
    r''' Returns the neighborhoods of non-empty octree nodes, with the neighbor
    indices remapped to index the non-empty nodes.

    Args:
      depth (int): The depth of the octree.
      reset (bool): If True, recomputes the neighborhoods.
    '''

    if self.nempty_neighs[depth] is None or reset:
      neigh = self.neighs[depth]
      idx = self.nempty_index(depth)
      # `neigh[idx]` is a copy, so the in-place remap does not touch self.neighs
      neigh = self.remap_nempty_neigh(neigh[idx], depth)
      self.nempty_neighs[depth] = neigh
    return self.nempty_neighs[depth]

  def remap_nempty_neigh(self, neigh: torch.Tensor, depth: int):
    r''' Remaps the neighborhood indices to the non-empty octree nodes.

    Valid (:obj:`>= 0`) entries are replaced by the :obj:`children` value of
    the referenced node. The input tensor is modified in place and returned.

    Args:
      neigh (torch.Tensor): The input neighborhoods with shape :obj:`(N, 27)`.
      depth (int): The depth of the octree.
    '''

    valid = neigh >= 0
    child = self.children[depth]
    neigh[valid] = child[neigh[valid]].long()  # remap the index
    return neigh

  def build_octree(self, point_cloud: Points):
    r''' Builds an octree from a point cloud.

    Args:
      point_cloud (Points): The input point cloud.

    Returns:
      torch.Tensor: The inverse indices produced by :func:`torch.unique` on the
      point keys, i.e. for each input point, the index of the finest-level
      non-empty octree node containing it.

    .. note::
      The point cloud must be strictly in range :obj:`[-1, 1]`. A good practice
      is to normalize it into :obj:`[-0.99, 0.99]` or :obj:`[-0.9, 0.9]` to
      retain some margin.
    '''

    self.device = point_cloud.device
    assert point_cloud.batch_size == self.batch_size, 'Inconsistent batch_size'

    # The lookup tables were created on the device given at construction time.
    # Rebuild them on the point cloud's device so that `construct_neigh` does
    # not mix devices (e.g. adding a CPU `lut_child` to CUDA tensors).
    self._construct_lookup_tables()

    # normalize points from [-1, 1] to [0, 2^depth]. #[L:Scale]
    scale = 2 ** (self.depth - 1)
    points = (point_cloud.points + 1.0) * scale

    # get the shuffled key and sort
    x, y, z = points[:, 0], points[:, 1], points[:, 2]
    b = None if self.batch_size == 1 else point_cloud.batch_id.view(-1)
    key = xyz2key(x, y, z, b, self.depth)
    node_key, idx, counts = torch.unique(
        key, sorted=True, return_inverse=True, return_counts=True)

    # layer 0 to full_layer: the octree is full in these layers
    for d in range(self.full_depth + 1):
      self.octree_grow_full(d, update_neigh=False)

    # layer depth_ to full_layer_
    for d in range(self.depth, self.full_depth, -1):
      # compute parent key, i.e. keys of layer (d -1)
      pkey = node_key >> 3
      pkey, pidx, _ = torch.unique_consecutive(
          pkey, return_inverse=True, return_counts=True)

      # augmented key: every parent gets all of its 8 children
      key = (pkey.unsqueeze(-1) << 3) + torch.arange(8, device=self.device)
      self.keys[d] = key.view(-1)
      self.nnum[d] = key.numel()
      self.nnum_nempty[d] = node_key.numel()

      # children: -1 for empty nodes, otherwise the index among non-empty ones
      addr = (pidx << 3) | (node_key % 8)
      children = -torch.ones(
          self.nnum[d].item(), dtype=torch.int32, device=self.device)
      children[addr] = torch.arange(
          self.nnum_nempty[d], dtype=torch.int32, device=self.device)
      self.children[d] = children

      # cache pkey for the next iteration.
      # Use `pkey >> 45` instead of `pkey >> 48`: the batch bits, which start
      # at bit 48 in `node_key`, were shifted right by 3 bits in `node_key >> 3`
      # above, so they are moved back to bit 48 here.
      node_key = pkey if self.batch_size == 1 else \
          ((pkey >> 45) << 48) | (pkey & ((1 << 45) - 1))

    # set the children for the layer full_layer,
    # now the node_keys are the key for full_layer
    d = self.full_depth
    children = -torch.ones_like(self.children[d])
    nempty_idx = node_key if self.batch_size == 1 else \
        ((node_key >> 48) << (3 * d)) | (node_key & ((1 << 48) - 1))
    children[nempty_idx] = torch.arange(
        node_key.numel(), dtype=torch.int32, device=self.device)
    self.children[d] = children
    self.nnum_nempty[d] = node_key.numel()

    # average the signal for the last octree layer
    d = self.depth
    points = scatter_add(points, idx, dim=0)  # points is rescaled in [L:Scale]
    self.points[d] = points / counts.unsqueeze(1)
    if point_cloud.normals is not None:
      normals = scatter_add(point_cloud.normals, idx, dim=0)
      self.normals[d] = F.normalize(normals)
    if point_cloud.features is not None:
      features = scatter_add(point_cloud.features, idx, dim=0)
      self.features[d] = features / counts.unsqueeze(1)

    # reset nempty_masks and nempty_indices, which will be updated on demand
    self.nempty_masks = [None] * (self.depth + 1)
    self.nempty_indices = [None] * (self.depth + 1)
    self.nempty_neighs = [None] * (self.depth + 1)
    return idx

  def octree_grow_full(self, depth: int, update_neigh: bool = True):
    r''' Builds the full octree, which is essentially a dense volumetric grid.

    Args:
      depth (int): The depth of the octree.
      update_neigh (bool): If True, construct the neighborhood indices.
    '''

    # check
    assert depth <= self.full_depth, 'error'

    # node number
    num = 1 << (3 * depth)
    batch_size = self.batch_size
    self.nnum[depth] = num * batch_size
    self.nnum_nempty[depth] = num * batch_size

    # update key: the batch index is stored from bit 48 upwards
    key = torch.arange(num, dtype=torch.long, device=self.device)
    bs = torch.arange(batch_size, dtype=torch.long, device=self.device)
    key = key.unsqueeze(0) | (bs.unsqueeze(1) << 48)
    self.keys[depth] = key.view(-1)

    # update children
    self.children[depth] = torch.arange(
        num * batch_size, dtype=torch.int32, device=self.device)

    # nempty_masks, nempty_indices, and nempty_neighs
    # need not be reset for full octrees

    # update neigh if needed
    if update_neigh:
      self.construct_neigh(depth)

  def octree_split(self, split: torch.Tensor, depth: int):
    r''' Sets whether the octree nodes in :attr:`depth` are split or not.

    Args:
      split (torch.Tensor): The input tensor with its element indicating status
          of each octree node: 0 - empty, 1 - non-empty or split.
      depth (int): The depth of current octree.
    '''

    # split -> children
    empty = split == 0
    cum = cumsum(split, dim=0, exclusive=True)
    children, nnum_nempty = torch.split(cum, [split.shape[0], 1])
    children[empty] = -1

    # boundary case, make sure that at least one octree node is split
    if nnum_nempty == 0:
      nnum_nempty = 1
      children[0] = 0

    # update octree
    self.children[depth] = children.int()
    self.nnum_nempty[depth] = nnum_nempty

    # reset nempty_masks, nempty_indices, and nempty_neighs as they depend on
    # children[depth] and are invalid now
    self.nempty_masks[depth] = None
    self.nempty_indices[depth] = None
    self.nempty_neighs[depth] = None

  def octree_grow(self, depth: int, update_neigh: bool = True):
    r''' Grows the octree and updates the relevant properties. In most cases,
    call :func:`Octree.octree_split` to update the splitting status of the
    octree before this function.

    Args:
      depth (int): The depth of the octree.
      update_neigh (bool): If True, construct the neighborhood indices.
    '''

    # increase the octree depth if required
    if depth > self.depth:
      assert depth == self.depth + 1
      self.depth = depth
      self.keys.append(None)
      self.children.append(None)
      self.neighs.append(None)
      self.features.append(None)
      self.normals.append(None)
      self.points.append(None)
      self.nempty_masks.append(None)
      self.nempty_indices.append(None)
      self.nempty_neighs.append(None)
      zero = torch.zeros(1, dtype=torch.long)
      self.nnum = torch.cat([self.nnum, zero])
      self.nnum_nempty = torch.cat([self.nnum_nempty, zero])
      # batch_nnum(_nempty) have shape (depth + 1, batch_size); the appended
      # row must span all batch columns, otherwise `torch.cat` fails for
      # batch_size > 1.
      zero_row = torch.zeros(1, self.batch_nnum.shape[1], dtype=torch.long)
      self.batch_nnum = torch.cat([self.batch_nnum, zero_row], dim=0)
      self.batch_nnum_nempty = torch.cat(
          [self.batch_nnum_nempty, zero_row], dim=0)

    # node number
    nnum = self.nnum_nempty[depth-1] * 8
    self.nnum[depth] = nnum
    self.nnum_nempty[depth] = nnum  # initialize self.nnum_nempty

    # update keys: shift the spatial part by 3 bits, keep the batch bits
    key = self.key(depth-1, nempty=True)
    batch_id = (key >> 48) << 48
    key = (key & ((1 << 48) - 1)) << 3
    key = key | batch_id
    key = key.unsqueeze(1) + torch.arange(8, device=key.device)
    self.keys[depth] = key.view(-1)

    # update children
    self.children[depth] = torch.arange(
        nnum, dtype=torch.int32, device=self.device)

    # update neighs
    if update_neigh:
      self.construct_neigh(depth)

  def construct_neigh(self, depth: int):
    r''' Constructs the :obj:`3x3x3` neighbors for each octree node.

    Args:
      depth (int): The octree depth with a value larger than 0 (:obj:`>0`).
    '''

    if depth <= self.full_depth:
      device = self.device
      nnum = 1 << (3 * depth)
      key = torch.arange(nnum, dtype=torch.long, device=device)
      x, y, z, _ = key2xyz(key, depth)
      xyz = torch.stack([x, y, z], dim=-1)  # (N, 3)
      grid = range_grid(-1, 1, device)      # (27, 3)
      xyz = xyz.unsqueeze(1) + grid         # (N, 27, 3)
      xyz = xyz.view(-1, 3)                 # (N*27, 3)
      neigh = xyz2key(xyz[:, 0], xyz[:, 1], xyz[:, 2], depth=depth)

      bs = torch.arange(self.batch_size, dtype=torch.long, device=device)
      neigh = neigh + bs.unsqueeze(1) * nnum  # (N*27,) + (B, 1) -> (B, N*27)

      # neighbors outside the [0, 2^depth) grid do not exist
      bound = 1 << depth
      invalid = torch.logical_or((xyz < 0).any(1), (xyz >= bound).any(1))
      neigh[:, invalid] = -1
      self.neighs[depth] = neigh.view(-1, 27)  # (B*N, 27)

    else:
      child_p = self.children[depth-1]
      nempty = child_p >= 0
      neigh_p = self.neighs[depth-1][nempty]  # (N, 27)
      neigh_p = neigh_p[:, self.lut_parent]   # (N, 8, 27)
      child_p = child_p[neigh_p]              # (N, 8, 27)
      invalid = torch.logical_or(child_p < 0, neigh_p < 0)  # (N, 8, 27)
      neigh = child_p * 8 + self.lut_child
      neigh[invalid] = -1
      self.neighs[depth] = neigh.view(-1, 27)

  def construct_all_neigh(self):
    r''' A convenient handler for constructing all neighbors.
    '''

    for depth in range(1, self.depth+1):
      self.construct_neigh(depth)

  def search_xyzb(self, query: torch.Tensor, depth: int, nempty: bool = False):
    r''' Searches the octree nodes given the query points.

    Args:
      query (torch.Tensor): The coordinates of query points with shape
          :obj:`(N, 4)`. The first 3 channels of the coordinates are :obj:`x`,
          :obj:`y`, and :obj:`z`, and the last channel is the batch index. Note
          that the coordinates must be in range :obj:`[0, 2^depth)`.
      depth (int): The depth of the octree layer.
      nempty (bool): If true, only searches the non-empty octree nodes.
    '''

    key = xyz2key(query[:, 0], query[:, 1], query[:, 2], query[:, 3], depth)
    idx = self.search_key(key, depth, nempty)
    return idx

  def search_key(self, query: torch.Tensor, depth: int, nempty: bool = False):
    r''' Searches the octree nodes given the query keys.

    Args:
      query (torch.Tensor): The keys of query points with shape :obj:`(N,)`,
          which are computed from the coordinates of query points.
      depth (int): The depth of the octree layer.
      nempty (bool): If true, only searches the non-empty octree nodes.

    Returns:
      torch.Tensor: For each query, the index into :obj:`self.key(depth,
      nempty)` of the matching node, or :obj:`-1` if no node matches.
    '''

    key = self.key(depth, nempty)
    idx = torch.searchsorted(key, query)

    # `torch.bucketize` can also be used here; it is similar to
    # `torch.searchsorted`, and it has fewer dimension checks, resulting in
    # slightly better performance for 1D sorted sequences according to the docs
    # of pytorch-1.9.1. `key` is always a 1D sorted sequence.
    # https://pytorch.org/docs/1.9.1/generated/torch.searchsorted.html
    # idx = torch.bucketize(query, key)

    valid = idx < key.shape[0]  # valid if in-bound
    vi = torch.arange(query.shape[0], device=query.device)[valid]
    valid[vi] = key[idx[vi]] == query[vi]  # valid if found
    idx[valid.logical_not()] = -1  # set to -1 if invalid
    return idx

  def get_neigh(self, depth: int, kernel: str = '333', stride: int = 1,
                nempty: bool = False):
    r''' Returns the neighborhoods given the depth and a kernel shape.

    Args:
      depth (int): The octree depth with a value larger than 0 (:obj:`>0`).
      kernel (str): The kernel shape from :obj:`333`, :obj:`311`, :obj:`131`,
          :obj:`113`, :obj:`222`, :obj:`331`, :obj:`133`, and :obj:`313`.
      stride (int): The stride of neighborhoods (:obj:`1` or :obj:`2`). If the
          stride is :obj:`2`, always returns the neighborhood of the first
          siblings.
      nempty (bool): If True, only returns the neighborhoods of the non-empty
          octree nodes.

    Raises:
      ValueError: If :attr:`stride` or :attr:`kernel` is unsupported.
    '''

    if stride == 1 and not nempty:
      neigh = self.neighs[depth]
    elif stride == 2 and not nempty:
      neigh = self.neighs[depth][::8].clone()
    elif stride == 1 and nempty:
      neigh = self.nempty_neigh(depth)
    elif stride == 2 and nempty:
      neigh = self.neighs[depth][::8].clone()
      neigh = self.remap_nempty_neigh(neigh, depth)
    else:
      raise ValueError('Unsupported stride {}'.format(stride))

    if kernel == '333':
      return neigh
    elif kernel in self.lut_kernel:
      lut = self.lut_kernel[kernel]
      return neigh[:, lut]
    else:
      raise ValueError('Unsupported kernel {}'.format(kernel))

  def get_input_feature(self, feature: str, nempty: bool = False):
    r''' Returns the initial input feature stored in octree.

    Args:
      feature (str): A string used to indicate which features to extract from
          the input octree. If the character :obj:`N` is in :attr:`feature`, the
          normal signal is extracted (3 channels). Similarly, if :obj:`D` is in
          :attr:`feature`, the local displacement is extracted (1 channel). If
          :obj:`L` is in :attr:`feature`, the local coordinates of the averaged
          points in each octree node are extracted (3 channels). If :attr:`P` is
          in :attr:`feature`, the global coordinates are extracted (3 channels).
          If :attr:`F` is in :attr:`feature`, other features (like colors) are
          extracted (k channels).
      nempty (bool): If false, gets the features of all octree nodes.

    Raises:
      ValueError: If :attr:`feature` contains none of the supported letters.
    '''

    features = list()
    depth = self.depth
    feature = feature.upper()
    if 'N' in feature:
      features.append(self.normals[depth])

    if 'L' in feature or 'D' in feature:
      local_points = self.points[depth].frac() - 0.5

    if 'D' in feature:
      dis = torch.sum(local_points * self.normals[depth], dim=1, keepdim=True)
      features.append(dis)

    if 'L' in feature:
      features.append(local_points)

    if 'P' in feature:
      scale = 2 ** (1 - depth)  # normalize [0, 2^depth] -> [-1, 1]
      global_points = self.points[depth] * scale - 1.0
      features.append(global_points)

    if 'F' in feature:
      features.append(self.features[depth])

    # `torch.cat([])` would fail with an unhelpful message; report the cause
    if not features:
      raise ValueError(
          'No supported input feature (N, D, L, P, F) in {}'.format(feature))

    out = torch.cat(features, dim=1)
    if not nempty:
      out = ocnn.nn.octree_pad(out, self, depth)
    return out

  def to_points(self, rescale: bool = True):
    r''' Converts averaged points in the octree to a point cloud.

    Args:
      rescale (bool): rescale the xyz coordinates to [-1, 1] if True.
    '''

    depth = self.depth
    batch_size = self.batch_size

    # by default, use the average points generated when building the octree
    # from the input point cloud
    xyz = self.points[depth]
    batch_id = self.batch_id(depth, nempty=True)

    # xyz is None when the octree is predicted by a neural network
    if xyz is None:
      x, y, z, batch_id = self.xyzb(depth, nempty=True)
      xyz = torch.stack([x, y, z], dim=1) + 0.5

    # normalize xyz to [-1, 1] since the average points are in range [0, 2^d]
    if rescale:
      scale = 2 ** (1 - depth)
      xyz = xyz * scale - 1.0

    # construct Points
    out = Points(xyz, self.normals[depth], self.features[depth],
                 batch_id=batch_id, batch_size=batch_size)
    return out

  def to(self, device: Union[torch.device, str], non_blocking: bool = False):
    r''' Moves the octree to a specified device.

    Args:
      device (torch.device or str): The destination device.
      non_blocking (bool): If True and the source is in pinned memory, the copy
          will be asynchronous with respect to the host. Otherwise, the argument
          has no effect. Default: False.
    '''

    if isinstance(device, str):
      device = torch.device(device)

    # If on the same device, directly return self. `self.device` may still be
    # the raw string passed to the constructor, so normalize it first;
    # otherwise 'cpu' and torch.device('cpu') would never compare equal.
    current = self.device
    if isinstance(current, str):
      current = torch.device(current)
    if current == device:
      return self

    def list_to_device(prop):
      return [p.to(device, non_blocking=non_blocking)
              if isinstance(p, torch.Tensor) else None for p in prop]

    # Construct a new Octree on the specified device.
    # During the initialization, self.device is used to set up the new Octree;
    # the look-up tables (including self.lut_kernel, self.lut_parent, and
    # self.lut_child), will be already created on the correct device.
    octree = Octree.init_like(self, device)

    # Move all the other properties to the specified device
    octree.keys = list_to_device(self.keys)
    octree.children = list_to_device(self.children)
    octree.neighs = list_to_device(self.neighs)
    octree.features = list_to_device(self.features)
    octree.normals = list_to_device(self.normals)
    octree.points = list_to_device(self.points)

    # The following are small tensors, keep them on CPU to avoid frequent device
    # switching, so just clone them.
    octree.nnum = self.nnum.clone()
    octree.nnum_nempty = self.nnum_nempty.clone()
    octree.batch_nnum = self.batch_nnum.clone()
    octree.batch_nnum_nempty = self.batch_nnum_nempty.clone()
    return octree

  def cuda(self, non_blocking: bool = False):
    r''' Moves the octree to the GPU. '''

    return self.to('cuda', non_blocking)

  def cpu(self):
    r''' Moves the octree to the CPU. '''

    return self.to('cpu')

  def merge_octrees(self, octrees: List['Octree']):
    r''' Merges a list of octrees into one batch.

    Args:
      octrees (List[Octree]): A list of octrees to merge.

    Returns:
      Octree: The merged octree.
    '''

    # init and check
    batch_size = len(octrees)
    self.batch_size = batch_size
    for i in range(1, batch_size):
      condition = (octrees[i].depth == self.depth and
                   octrees[i].full_depth == self.full_depth and
                   octrees[i].device == self.device)
      assert condition, 'The check of merge_octrees failed'

    # node num
    batch_nnum = torch.stack(
        [octrees[i].nnum for i in range(batch_size)], dim=1)
    batch_nnum_nempty = torch.stack(
        [octrees[i].nnum_nempty for i in range(batch_size)], dim=1)
    self.nnum = torch.sum(batch_nnum, dim=1)
    self.nnum_nempty = torch.sum(batch_nnum_nempty, dim=1)
    self.batch_nnum = batch_nnum
    self.batch_nnum_nempty = batch_nnum_nempty
    nnum_cum = cumsum(batch_nnum_nempty, dim=1, exclusive=True)

    # merge octree properties
    for d in range(self.depth + 1):
      # key: replace the batch bits (bit 48 and above) with the batch index
      keys = [None] * batch_size
      for i in range(batch_size):
        key = octrees[i].keys[d] & ((1 << 48) - 1)  # clear the highest bits
        keys[i] = key | (i << 48)
      self.keys[d] = torch.cat(keys, dim=0)

      # children: offset the non-empty indices by the preceding octrees
      children = [None] * batch_size
      for i in range(batch_size):
        # !! `clone` is used here to avoid modifying the original octrees
        child = octrees[i].children[d].clone()
        mask = child >= 0
        child[mask] = child[mask] + nnum_cum[d, i]
        children[i] = child
      self.children[d] = torch.cat(children, dim=0)

      # features
      if octrees[0].features[d] is not None and d == self.depth:
        features = [octrees[i].features[d] for i in range(batch_size)]
        self.features[d] = torch.cat(features, dim=0)

      # normals
      if octrees[0].normals[d] is not None and d == self.depth:
        normals = [octrees[i].normals[d] for i in range(batch_size)]
        self.normals[d] = torch.cat(normals, dim=0)

      # points
      if octrees[0].points[d] is not None and d == self.depth:
        points = [octrees[i].points[d] for i in range(batch_size)]
        self.points[d] = torch.cat(points, dim=0)

    # `children` was replaced above, so any cached non-empty masks, indices,
    # or neighborhoods are stale; drop them so they are rebuilt on demand
    num = self.depth + 1
    self.nempty_masks = [None] * num
    self.nempty_indices = [None] * num
    self.nempty_neighs = [None] * num
    return self

  @classmethod
  def init_like(cls, octree: 'Octree', device: Union[torch.device, str, None] = None):
    r''' Initializes the octree like another octree.

    Args:
      octree (Octree): The reference octree.
      device (torch.device or str): The device to use for computation. If
          None, uses the device of :attr:`octree`.
    '''

    device = device if device is not None else octree.device
    return cls(depth=octree.depth, full_depth=octree.full_depth,
               batch_size=octree.batch_size, device=device)

  @classmethod
  def init_octree(cls, depth: int, full_depth: int = 2, batch_size: int = 1,
                  device: Union[torch.device, str] = 'cpu'):
    r'''
    Initializes an octree to :attr:`full_depth`.

    Args:
      depth (int): The depth of the octree.
      full_depth (int): The octree layers with a depth smaller than or equal
          to :attr:`full_depth` are forced to be full.
      batch_size (int, optional): The batch size.
      device (torch.device or str): The device to use for computation.

    Returns:
      Octree: The initialized Octree object.
    '''

    octree = cls(depth, full_depth, batch_size, device)
    for d in range(full_depth + 1):
      octree.octree_grow_full(depth=d)
    return octree
750
-
751
-
752
def merge_octrees(octrees: List['Octree']):
  r''' Batches several octrees into one (deprecated module-level helper).

  A blank octree configured like the first entry of :attr:`octrees` is
  created with :meth:`Octree.init_like`, and the inputs are then merged into
  it with :meth:`Octree.merge_octrees`.

  .. deprecated::
    Use :meth:`Octree.merge_octrees` instead.
  '''

  template = octrees[0]
  merged = Octree.init_like(template)
  return merged.merge_octrees(octrees)
760
-
761
-
762
def init_octree(depth: int, full_depth: int = 2, batch_size: int = 1,
                device: Union[torch.device, str] = 'cpu'):
  r''' Creates an octree grown to :attr:`full_depth` (deprecated helper).

  Forwards every argument unchanged to :meth:`Octree.init_octree`.

  .. deprecated::
    Use :meth:`Octree.init_octree` instead.
  '''

  return Octree.init_octree(depth=depth, full_depth=full_depth,
                            batch_size=batch_size, device=device)
1
+ # --------------------------------------------------------
2
+ # Octree-based Sparse Convolutional Neural Networks
3
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Peng-Shuai Wang
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from typing import Union, List
11
+
12
+ import ocnn
13
+ from ocnn.octree.points import Points
14
+ from ocnn.octree.shuffled_key import xyz2key, key2xyz
15
+ from ocnn.utils import range_grid, scatter_add, cumsum, trunc_div
16
+
17
+
18
class Octree:
  r''' Builds an octree from an input point cloud.

  Args:
    depth (int): The octree depth.
    full_depth (int): The octree layers with a depth not larger than
        :attr:`full_depth` are forced to be full.
    batch_size (int): The octree batch size.
    device (torch.device or str): Choose from :obj:`cpu` and :obj:`cuda`.
        (default: :obj:`cpu`)

  .. note::
    The octree data structure requires that if an octree node has children nodes,
    the number of children nodes is exactly 8, in which some of the nodes are
    empty and some nodes are non-empty. The properties of an octree, including
    :obj:`keys`, :obj:`children` and :obj:`neighs`, contain both non-empty and
    empty nodes, and other properties, including :obj:`features`, :obj:`normals`
    and :obj:`points`, contain only non-empty nodes.

  .. note::
    The point cloud must be strictly in range :obj:`[-1, 1]`. A good practice
    is to normalize it into :obj:`[-0.99, 0.99]` or :obj:`[-0.9, 0.9]` to retain
    some margin.
  '''

  def __init__(self, depth: int, full_depth: int = 2, batch_size: int = 1,
               device: Union[torch.device, str] = 'cpu', **kwargs):
    super().__init__()
    # Configurations for initialization. Extra keyword arguments are accepted
    # and ignored.
    self.depth = depth
    self.full_depth = full_depth
    self.batch_size = batch_size
    self.device = device

    # properties after building the octree
    self.reset()

  def reset(self):
    r''' Resets the Octree status and constructs several lookup tables.
    '''

    # octree features in each octree layer
    num = self.depth + 1
    self.keys = [None] * num
    self.children = [None] * num
    self.neighs = [None] * num
    self.features = [None] * num
    self.normals = [None] * num
    self.points = [None] * num

    # self.nempty_masks, self.nempty_indices and self.nempty_neighs are
    # for handling of non-empty nodes and are constructed on demand
    self.nempty_masks = [None] * num
    self.nempty_indices = [None] * num
    self.nempty_neighs = [None] * num

    # octree node numbers in each octree layer.
    # These are small 1-D tensors; just keep them on CPUs
    self.nnum = torch.zeros(num, dtype=torch.long)
    self.nnum_nempty = torch.zeros(num, dtype=torch.long)

    # the following properties are only valid after `merge_octrees`.
    # TODO: make them valid after `octree_grow`, `octree_split` and `build_octree`
    batch_size = self.batch_size
    self.batch_nnum = torch.zeros(num, batch_size, dtype=torch.long)
    self.batch_nnum_nempty = torch.zeros(num, batch_size, dtype=torch.long)

    # construct the look up tables for neighborhood searching
    self._build_lookup_tables()

  def _build_lookup_tables(self):
    r''' Builds the lookup tables used for neighborhood searching on
    :attr:`self.device`.

    :obj:`lut_parent` and :obj:`lut_child` have shape :obj:`(8, 27)`: for each
    of the 8 children of a node and each of its 27 neighbors, they give the
    neighbor's position in the parent's 3x3x3 neighborhood and the neighbor's
    child slot (in :obj:`[0, 8)`). :obj:`lut_kernel` selects the neighbor
    columns used by kernels other than :obj:`333`.
    '''

    device = self.device
    center_grid = range_grid(2, 3, device)    # (8, 3)
    displacement = range_grid(-1, 1, device)  # (27, 3)
    neigh_grid = center_grid.unsqueeze(1) + displacement  # (8, 27, 3)
    parent_grid = trunc_div(neigh_grid, 2)
    child_grid = neigh_grid % 2
    self.lut_parent = torch.sum(
        parent_grid * torch.tensor([9, 3, 1], device=device), dim=2)
    self.lut_child = torch.sum(
        child_grid * torch.tensor([4, 2, 1], device=device), dim=2)

    # lookup tables for different kernel sizes
    self.lut_kernel = {
        '222': torch.tensor([13, 14, 16, 17, 22, 23, 25, 26], device=device),
        '311': torch.tensor([4, 13, 22], device=device),
        '131': torch.tensor([10, 13, 16], device=device),
        '113': torch.tensor([12, 13, 14], device=device),
        '331': torch.tensor([1, 4, 7, 10, 13, 16, 19, 22, 25], device=device),
        '313': torch.tensor([3, 4, 5, 12, 13, 14, 21, 22, 23], device=device),
        '133': torch.tensor([9, 10, 11, 12, 13, 14, 15, 16, 17], device=device),
    }

  def key(self, depth: int, nempty: bool = False):
    r''' Returns the shuffled key of each octree node. The batch index is
    stored in the bits starting from bit 48 of each key.

    Args:
      depth (int): The depth of the octree.
      nempty (bool): If True, returns the results of non-empty octree nodes.
    '''

    key = self.keys[depth]
    if nempty:
      idx = self.nempty_index(depth)
      key = key[idx]
    return key

  def xyzb(self, depth: int, nempty: bool = False):
    r''' Returns the xyz coordinates and the batch indices of each octree node.

    Args:
      depth (int): The depth of the octree.
      nempty (bool): If True, returns the results of non-empty octree nodes.
    '''

    key = self.key(depth, nempty)
    return key2xyz(key, depth)

  def batch_id(self, depth: int, nempty: bool = False):
    r''' Returns the batch indices of each octree node.

    Args:
      depth (int): The depth of the octree.
      nempty (bool): If True, returns the results of non-empty octree nodes.
    '''

    batch_id = self.keys[depth] >> 48
    if nempty:
      idx = self.nempty_index(depth)
      batch_id = batch_id[idx]
    return batch_id

  def nempty_mask(self, depth: int, reset: bool = False):
    r''' Returns a binary mask which indicates whether the corresponding octree
    node is non-empty (:obj:`True`) or empty (:obj:`False`). The result is
    cached.

    Args:
      depth (int): The depth of the octree.
      reset (bool): If True, recomputes the mask.
    '''

    if self.nempty_masks[depth] is None or reset:
      self.nempty_masks[depth] = self.children[depth] >= 0
    return self.nempty_masks[depth]

  def nempty_index(self, depth: int, reset: bool = False):
    r''' Returns the indices of non-empty octree nodes. The result is cached.

    Args:
      depth (int): The depth of the octree.
      reset (bool): If True, recomputes the indices.
    '''

    if self.nempty_indices[depth] is None or reset:
      mask = self.nempty_mask(depth)
      rng = torch.arange(mask.shape[0], device=mask.device, dtype=torch.long)
      self.nempty_indices[depth] = rng[mask]
    return self.nempty_indices[depth]

  def nempty_neigh(self, depth: int, reset: bool = False):
    r''' Returns the neighborhoods of non-empty octree nodes, with indices
    remapped to the non-empty nodes. The result is cached.

    Args:
      depth (int): The depth of the octree.
      reset (bool): If True, recomputes the neighborhoods.
    '''

    if self.nempty_neighs[depth] is None or reset:
      neigh = self.neighs[depth]
      idx = self.nempty_index(depth)
      neigh = self.remap_nempty_neigh(neigh[idx], depth)
      self.nempty_neighs[depth] = neigh
    return self.nempty_neighs[depth]

  def remap_nempty_neigh(self, neigh: torch.Tensor, depth: int):
    r''' Remaps the neighborhood indices to the non-empty octree nodes.

    Note that :attr:`neigh` is modified in place and also returned.

    Args:
      neigh (torch.Tensor): The input neighborhoods with shape :obj:`(N, 27)`.
      depth (int): The depth of the octree.
    '''

    valid = neigh >= 0
    child = self.children[depth]
    neigh[valid] = child[neigh[valid]].long()  # remap the index
    return neigh

  def build_octree(self, point_cloud: 'Points'):
    r''' Builds an octree from a point cloud.

    Args:
      point_cloud (Points): The input point cloud.

    Returns:
      torch.Tensor: For each input point, the index of the non-empty node in
      the finest layer that contains it (the inverse indices returned by
      :func:`torch.unique`).

    .. note::
      The point cloud must be strictly in range :obj:`[-1, 1]`. A good practice
      is to normalize it into :obj:`[-0.99, 0.99]` or :obj:`[-0.9, 0.9]` to
      retain some margin.
    '''

    # Follow the device of the input point cloud. The lookup tables were built
    # on the previous device, so rebuild them when the device changes;
    # otherwise `construct_neigh` would mix tensors on different devices.
    device_changed = (torch.device(point_cloud.device) !=
                      torch.device(self.device))
    self.device = point_cloud.device
    if device_changed:
      self._build_lookup_tables()
    assert point_cloud.batch_size == self.batch_size, 'Inconsistent batch_size'

    # normalize points from [-1, 1] to [0, 2^depth]
    scale = 2 ** (self.depth - 1)
    points = (point_cloud.points + 1.0) * scale

    # get the shuffled key and sort
    x, y, z = points[:, 0], points[:, 1], points[:, 2]
    b = None if self.batch_size == 1 else point_cloud.batch_id.view(-1)
    key = xyz2key(x, y, z, b, self.depth)
    node_key, idx, counts = torch.unique(
        key, sorted=True, return_inverse=True, return_counts=True)

    # layer 0 to full_depth: the octree is full in these layers
    for d in range(self.full_depth + 1):
      self.octree_grow_full(d, update_neigh=False)

    # layer depth down to full_depth + 1, from fine to coarse
    for d in range(self.depth, self.full_depth, -1):
      # compute parent key, i.e. keys of layer (d - 1)
      pkey = node_key >> 3
      pkey, pidx, _ = torch.unique_consecutive(
          pkey, return_inverse=True, return_counts=True)

      # augmented key: every parent gets all 8 children, empty ones included
      key = (pkey.unsqueeze(-1) << 3) + torch.arange(8, device=self.device)
      self.keys[d] = key.view(-1)
      self.nnum[d] = key.numel()
      self.nnum_nempty[d] = node_key.numel()

      # children: -1 for empty nodes, otherwise the index among non-empty nodes
      addr = (pidx << 3) | (node_key % 8)
      children = -torch.ones(
          self.nnum[d].item(), dtype=torch.int32, device=self.device)
      children[addr] = torch.arange(
          self.nnum_nempty[d], dtype=torch.int32, device=self.device)
      self.children[d] = children

      # Cache pkey for the next iteration. The batch index lives in the bits
      # starting from 48; since `pkey` was shifted right by 3 bits above, its
      # batch index starts at bit 45 and is moved back to bit 48 here.
      node_key = pkey if self.batch_size == 1 else \
          ((pkey >> 45) << 48) | (pkey & ((1 << 45) - 1))

    # set the children for the layer full_depth;
    # node_key now holds the keys of the non-empty nodes of that layer
    d = self.full_depth
    children = -torch.ones_like(self.children[d])
    # in a full layer, node index = batch_index * 8^d + spatial key
    nempty_idx = node_key if self.batch_size == 1 else \
        ((node_key >> 48) << (3 * d)) | (node_key & ((1 << 48) - 1))
    children[nempty_idx] = torch.arange(
        node_key.numel(), dtype=torch.int32, device=self.device)
    self.children[d] = children
    self.nnum_nempty[d] = node_key.numel()

    # average the signal for the last octree layer
    d = self.depth
    points = scatter_add(points, idx, dim=0)  # points were rescaled above
    self.points[d] = points / counts.unsqueeze(1)
    if point_cloud.normals is not None:
      normals = scatter_add(point_cloud.normals, idx, dim=0)
      self.normals[d] = F.normalize(normals)
    if point_cloud.features is not None:
      features = scatter_add(point_cloud.features, idx, dim=0)
      self.features[d] = features / counts.unsqueeze(1)

    # reset nempty_masks, nempty_indices and nempty_neighs, which will be
    # updated on demand
    self.nempty_masks = [None] * (self.depth + 1)
    self.nempty_indices = [None] * (self.depth + 1)
    self.nempty_neighs = [None] * (self.depth + 1)
    return idx

  def octree_grow_full(self, depth: int, update_neigh: bool = True):
    r''' Builds the full octree, which is essentially a dense volumetric grid.

    Args:
      depth (int): The depth of the octree; must not exceed :attr:`full_depth`.
      update_neigh (bool): If True, construct the neighborhood indices.
    '''

    # check
    assert depth <= self.full_depth, 'error'

    # node number
    num = 1 << (3 * depth)
    batch_size = self.batch_size
    self.nnum[depth] = num * batch_size
    self.nnum_nempty[depth] = num * batch_size

    # update key: spatial key in the low bits, batch index from bit 48
    key = torch.arange(num, dtype=torch.long, device=self.device)
    bs = torch.arange(batch_size, dtype=torch.long, device=self.device)
    key = key.unsqueeze(0) | (bs.unsqueeze(1) << 48)
    self.keys[depth] = key.view(-1)

    # update children: every node of a full layer is non-empty
    self.children[depth] = torch.arange(
        num * batch_size, dtype=torch.int32, device=self.device)

    # nempty_masks, nempty_indices, and nempty_neighs
    # need not be reset for full octrees

    # update neigh if needed
    if update_neigh:
      self.construct_neigh(depth)

  def octree_split(self, split: torch.Tensor, depth: int):
    r''' Sets whether the octree nodes in :attr:`depth` are split or not.

    Args:
      split (torch.Tensor): The input tensor with its element indicating status
          of each octree node: 0 - empty, 1 - non-empty or split.
      depth (int): The depth of current octree.
    '''

    # split -> children. The exclusive cumsum is expected to have N + 1
    # entries: N child indices followed by the total number of split nodes.
    empty = split == 0
    csum = cumsum(split, dim=0, exclusive=True)
    children, nnum_nempty = torch.split(csum, [split.shape[0], 1])
    children[empty] = -1

    # boundary case, make sure that at least one octree node is split
    if nnum_nempty == 0:
      nnum_nempty = 1
      children[0] = 0

    # update octree
    self.children[depth] = children.int()
    self.nnum_nempty[depth] = nnum_nempty

    # reset nempty_masks, nempty_indices, and nempty_neighs as they depend on
    # children[depth] and are invalid now
    self.nempty_masks[depth] = None
    self.nempty_indices[depth] = None
    self.nempty_neighs[depth] = None

  def octree_grow(self, depth: int, update_neigh: bool = True):
    r''' Grows the octree and updates the relevant properties. In most cases,
    call :func:`Octree.octree_split` to update the splitting status of the
    octree before this function.

    Args:
      depth (int): The depth of the octree; at most :obj:`self.depth + 1`.
      update_neigh (bool): If True, construct the neighborhood indices.
    '''

    # increase the octree depth if required
    if depth > self.depth:
      assert depth == self.depth + 1
      self.depth = depth
      self.keys.append(None)
      self.children.append(None)
      self.neighs.append(None)
      self.features.append(None)
      self.normals.append(None)
      self.points.append(None)
      self.nempty_masks.append(None)
      self.nempty_indices.append(None)
      self.nempty_neighs.append(None)
      zero = torch.zeros(1, dtype=torch.long)
      self.nnum = torch.cat([self.nnum, zero])
      self.nnum_nempty = torch.cat([self.nnum_nempty, zero])
      # batch_nnum* have shape (num_layers, batch_size); the appended row must
      # match the batch dimension, otherwise `torch.cat` fails for batch_size > 1
      zero_row = torch.zeros(1, self.batch_nnum.shape[1], dtype=torch.long)
      self.batch_nnum = torch.cat([self.batch_nnum, zero_row], dim=0)
      self.batch_nnum_nempty = torch.cat(
          [self.batch_nnum_nempty, zero_row], dim=0)

    # node number
    nnum = self.nnum_nempty[depth-1] * 8
    self.nnum[depth] = nnum
    self.nnum_nempty[depth] = nnum  # initialize self.nnum_nempty

    # update keys: shift the spatial part of the parent keys by 3 bits while
    # keeping the batch index at bit 48, then enumerate the 8 children
    key = self.key(depth-1, nempty=True)
    batch_id = (key >> 48) << 48
    key = (key & ((1 << 48) - 1)) << 3
    key = key | batch_id
    key = key.unsqueeze(1) + torch.arange(8, device=key.device)
    self.keys[depth] = key.view(-1)

    # update children
    self.children[depth] = torch.arange(
        nnum, dtype=torch.int32, device=self.device)

    # update neighs
    if update_neigh:
      self.construct_neigh(depth)

  def construct_neigh(self, depth: int):
    r''' Constructs the :obj:`3x3x3` neighbors for each octree node.

    Args:
      depth (int): The octree depth with a value larger than 0 (:obj:`>0`).
    '''

    if depth <= self.full_depth:
      # full layer: compute the neighbors directly from the coordinates
      device = self.device
      nnum = 1 << (3 * depth)
      key = torch.arange(nnum, dtype=torch.long, device=device)
      x, y, z, _ = key2xyz(key, depth)
      xyz = torch.stack([x, y, z], dim=-1)  # (N, 3)
      grid = range_grid(-1, 1, device)      # (27, 3)
      xyz = xyz.unsqueeze(1) + grid         # (N, 27, 3)
      xyz = xyz.view(-1, 3)                 # (N*27, 3)
      neigh = xyz2key(xyz[:, 0], xyz[:, 1], xyz[:, 2], depth=depth)

      bs = torch.arange(self.batch_size, dtype=torch.long, device=device)
      neigh = neigh + bs.unsqueeze(1) * nnum  # (N*27,) + (B, 1) -> (B, N*27)

      # neighbors outside [0, 2^depth) do not exist
      bound = 1 << depth
      invalid = torch.logical_or((xyz < 0).any(1), (xyz >= bound).any(1))
      neigh[:, invalid] = -1
      self.neighs[depth] = neigh.view(-1, 27)  # (B*N, 27)

    else:
      # derive the neighbors from the neighbors of the parent layer
      child_p = self.children[depth-1]
      nempty = child_p >= 0
      neigh_p = self.neighs[depth-1][nempty]  # (N, 27)
      neigh_p = neigh_p[:, self.lut_parent]   # (N, 8, 27)
      child_p = child_p[neigh_p]              # (N, 8, 27)
      invalid = torch.logical_or(child_p < 0, neigh_p < 0)  # (N, 8, 27)
      neigh = child_p * 8 + self.lut_child
      neigh[invalid] = -1
      self.neighs[depth] = neigh.view(-1, 27)

  def construct_all_neigh(self):
    r''' A convenient handler for constructing all neighbors.
    '''

    for depth in range(1, self.depth+1):
      self.construct_neigh(depth)

  def search_xyzb(self, query: torch.Tensor, depth: int, nempty: bool = False):
    r''' Searches the octree nodes given the query points.

    Args:
      query (torch.Tensor): The coordinates of query points with shape
          :obj:`(N, 4)`. The first 3 channels of the coordinates are :obj:`x`,
          :obj:`y`, and :obj:`z`, and the last channel is the batch index. Note
          that the coordinates must be in range :obj:`[0, 2^depth)`.
      depth (int): The depth of the octree layer.
      nempty (bool): If true, only searches the non-empty octree nodes.

    Returns:
      torch.Tensor: The node index of each query, or :obj:`-1` if not found.
    '''

    key = xyz2key(query[:, 0], query[:, 1], query[:, 2], query[:, 3], depth)
    idx = self.search_key(key, depth, nempty)
    return idx

  def search_key(self, query: torch.Tensor, depth: int, nempty: bool = False):
    r''' Searches the octree nodes given the query points.

    Args:
      query (torch.Tensor): The keys of query points with shape :obj:`(N,)`,
          which are computed from the coordinates of query points.
      depth (int): The depth of the octree layer.
      nempty (bool): If true, only searches the non-empty octree nodes.

    Returns:
      torch.Tensor: The node index of each query, or :obj:`-1` if not found.
    '''

    key = self.key(depth, nempty)
    idx = torch.searchsorted(key, query)

    # `torch.bucketize` can also be used here; it is similar to
    # `torch.searchsorted`, and it has fewer dimension checks, resulting in
    # slightly better performance for 1D sorted sequences according to the docs
    # of pytorch-1.9.1. `key` is always a 1D sorted sequence.
    # https://pytorch.org/docs/1.9.1/generated/torch.searchsorted.html
    # idx = torch.bucketize(query, key)

    valid = idx < key.shape[0]  # valid if in-bound
    vi = torch.arange(query.shape[0], device=query.device)[valid]
    valid[vi] = key[idx[vi]] == query[vi]  # valid if found
    idx[valid.logical_not()] = -1  # set to -1 if invalid
    return idx

  def get_neigh(self, depth: int, kernel: str = '333', stride: int = 1,
                nempty: bool = False):
    r''' Returns the neighborhoods given the depth and a kernel shape.

    Args:
      depth (int): The octree depth with a value larger than 0 (:obj:`>0`).
      kernel (str): The kernel shape from :obj:`333`, :obj:`311`, :obj:`131`,
          :obj:`113`, :obj:`222`, :obj:`331`, :obj:`133`, and :obj:`313`.
      stride (int): The stride of neighborhoods (:obj:`1` or :obj:`2`). If the
          stride is :obj:`2`, always returns the neighborhood of the first
          siblings.
      nempty (bool): If True, only returns the neighborhoods of the non-empty
          octree nodes.

    Raises:
      ValueError: If :attr:`stride` or :attr:`kernel` is unsupported.
    '''

    if stride == 1 and not nempty:
      neigh = self.neighs[depth]
    elif stride == 2 and not nempty:
      neigh = self.neighs[depth][::8].clone()
    elif stride == 1 and nempty:
      neigh = self.nempty_neigh(depth)
    elif stride == 2 and nempty:
      neigh = self.neighs[depth][::8].clone()
      neigh = self.remap_nempty_neigh(neigh, depth)
    else:
      raise ValueError('Unsupported stride {}'.format(stride))

    if kernel == '333':
      return neigh
    elif kernel in self.lut_kernel:
      lut = self.lut_kernel[kernel]
      return neigh[:, lut]
    else:
      raise ValueError('Unsupported kernel {}'.format(kernel))

  def get_input_feature(self, feature: str, nempty: bool = False):
    r''' Returns the initial input feature stored in octree.

    Args:
      feature (str): A string used to indicate which features to extract from
          the input octree. If the character :obj:`N` is in :attr:`feature`, the
          normal signal is extracted (3 channels). Similarly, if :obj:`D` is in
          :attr:`feature`, the local displacement is extracted (1 channel). If
          :obj:`L` is in :attr:`feature`, the local coordinates of the averaged
          points in each octree node is extracted (3 channels). If :attr:`P` is
          in :attr:`feature`, the global coordinates are extracted (3 channels).
          If :attr:`F` is in :attr:`feature`, other features (like colors) are
          extracted (k channels). The match is case-insensitive.
      nempty (bool): If false, gets the features of all octree nodes.
    '''

    features = list()
    depth = self.depth
    feature = feature.upper()
    if 'N' in feature:
      features.append(self.normals[depth])

    if 'L' in feature or 'D' in feature:
      # offsets of the averaged points from the node centers
      local_points = self.points[depth].frac() - 0.5

    if 'D' in feature:
      dis = torch.sum(local_points * self.normals[depth], dim=1, keepdim=True)
      features.append(dis)

    if 'L' in feature:
      features.append(local_points)

    if 'P' in feature:
      scale = 2 ** (1 - depth)  # normalize [0, 2^depth] -> [-1, 1]
      global_points = self.points[depth] * scale - 1.0
      features.append(global_points)

    if 'F' in feature:
      features.append(self.features[depth])

    out = torch.cat(features, dim=1)
    if not nempty:
      out = ocnn.nn.octree_pad(out, self, depth)
    return out

  def to_points(self, rescale: bool = True):
    r''' Converts averaged points in the octree to a point cloud.

    Args:
      rescale (bool): rescale the xyz coordinates to [-1, 1] if True.
    '''

    depth = self.depth
    batch_size = self.batch_size

    # by default, use the average points generated when building the octree
    # from the input point cloud
    xyz = self.points[depth]
    batch_id = self.batch_id(depth, nempty=True)

    # xyz is None when the octree is predicted by a neural network
    if xyz is None:
      x, y, z, batch_id = self.xyzb(depth, nempty=True)
      xyz = torch.stack([x, y, z], dim=1) + 0.5

    # normalize xyz to [-1, 1] since the average points are in range [0, 2^d]
    if rescale:
      scale = 2 ** (1 - depth)
      xyz = xyz * scale - 1.0

    # construct Points
    out = Points(xyz, self.normals[depth], self.features[depth],
                 batch_id=batch_id, batch_size=batch_size)
    return out

  def to(self, device: Union[torch.device, str], non_blocking: bool = False):
    r''' Moves the octree to a specified device.

    Args:
      device (torch.device or str): The destination device.
      non_blocking (bool): If True and the source is in pinned memory, the copy
          will be asynchronous with respect to the host. Otherwise, the argument
          has no effect. Default: False.

    Returns:
      Octree: :obj:`self` if already on :attr:`device`, otherwise a new octree.
    '''

    if isinstance(device, str):
      device = torch.device(device)

    # If already on the same device, directly return self. `self.device` may be
    # a str (e.g. the default 'cpu'), so normalize it before comparing.
    if torch.device(self.device) == device:
      return self

    def list_to_device(prop):
      return [p.to(device, non_blocking=non_blocking)
              if isinstance(p, torch.Tensor) else None for p in prop]

    # Construct a new Octree on the specified device.
    # During the initialization, the given device is used to set up the new
    # Octree; the look-up tables (including lut_kernel, lut_parent, and
    # lut_child) will be already created on the correct device.
    octree = Octree.init_like(self, device)

    # Move all the other properties to the specified device
    octree.keys = list_to_device(self.keys)
    octree.children = list_to_device(self.children)
    octree.neighs = list_to_device(self.neighs)
    octree.features = list_to_device(self.features)
    octree.normals = list_to_device(self.normals)
    octree.points = list_to_device(self.points)

    # The following are small tensors, keep them on CPU to avoid frequent device
    # switching, so just clone them.
    octree.nnum = self.nnum.clone()
    octree.nnum_nempty = self.nnum_nempty.clone()
    octree.batch_nnum = self.batch_nnum.clone()
    octree.batch_nnum_nempty = self.batch_nnum_nempty.clone()
    return octree

  def cuda(self, non_blocking: bool = False):
    r''' Moves the octree to the GPU. '''

    return self.to('cuda', non_blocking)

  def cpu(self):
    r''' Moves the octree to the CPU. '''

    return self.to('cpu')

  def merge_octrees(self, octrees: List['Octree']):
    r''' Merges a list of octrees into one batch, storing the result in
    :obj:`self`.

    Only the keys, children and node numbers of every layer, plus the
    features, normals and points of the finest layer, are merged; the
    neighborhoods are not, so call :meth:`construct_all_neigh` afterwards if
    needed.

    Args:
      octrees (List[Octree]): A list of octrees to merge.

    Returns:
      Octree: The merged octree (:obj:`self`).
    '''

    # init and check
    batch_size = len(octrees)
    self.batch_size = batch_size
    for i in range(1, batch_size):
      condition = (octrees[i].depth == self.depth and
                   octrees[i].full_depth == self.full_depth and
                   octrees[i].device == self.device)
      assert condition, 'The check of merge_octrees failed'

    # node num
    batch_nnum = torch.stack(
        [octrees[i].nnum for i in range(batch_size)], dim=1)
    batch_nnum_nempty = torch.stack(
        [octrees[i].nnum_nempty for i in range(batch_size)], dim=1)
    self.nnum = torch.sum(batch_nnum, dim=1)
    self.nnum_nempty = torch.sum(batch_nnum_nempty, dim=1)
    self.batch_nnum = batch_nnum
    self.batch_nnum_nempty = batch_nnum_nempty
    nnum_cum = cumsum(batch_nnum_nempty, dim=1, exclusive=True)

    # merge octree properties
    for d in range(self.depth + 1):
      # key: replace the batch index (bits from 48) with the position i
      keys = [None] * batch_size
      for i in range(batch_size):
        key = octrees[i].keys[d] & ((1 << 48) - 1)  # clear the highest bits
        keys[i] = key | (i << 48)
      self.keys[d] = torch.cat(keys, dim=0)

      # children: offset non-empty indices by the preceding octrees' counts
      children = [None] * batch_size
      for i in range(batch_size):
        # !! `clone` is used here to avoid modifying the original octrees
        child = octrees[i].children[d].clone()
        mask = child >= 0
        child[mask] = child[mask] + nnum_cum[d, i]
        children[i] = child
      self.children[d] = torch.cat(children, dim=0)

      # features
      if octrees[0].features[d] is not None and d == self.depth:
        features = [octrees[i].features[d] for i in range(batch_size)]
        self.features[d] = torch.cat(features, dim=0)

      # normals
      if octrees[0].normals[d] is not None and d == self.depth:
        normals = [octrees[i].normals[d] for i in range(batch_size)]
        self.normals[d] = torch.cat(normals, dim=0)

      # points
      if octrees[0].points[d] is not None and d == self.depth:
        points = [octrees[i].points[d] for i in range(batch_size)]
        self.points[d] = torch.cat(points, dim=0)

    return self

  @classmethod
  def init_like(cls, octree: 'Octree', device: Union[torch.device, str, None] = None):
    r''' Initializes an empty octree with the same depth, full depth and batch
    size as another octree.

    Args:
      octree (Octree): The reference octree.
      device (torch.device or str): The device to use for computation. Defaults
          to the device of :attr:`octree`.
    '''

    device = device if device is not None else octree.device
    return cls(depth=octree.depth, full_depth=octree.full_depth,
               batch_size=octree.batch_size, device=device)

  @classmethod
  def init_octree(cls, depth: int, full_depth: int = 2, batch_size: int = 1,
                  device: Union[torch.device, str] = 'cpu'):
    r''' Initializes an octree whose layers :obj:`0` to :attr:`full_depth`
    (inclusive) are full.

    Args:
      depth (int): The depth of the octree.
      full_depth (int): The octree layers with a depth not larger than
          :attr:`full_depth` are forced to be full.
      batch_size (int, optional): The batch size.
      device (torch.device or str): The device to use for computation.

    Returns:
      Octree: The initialized Octree object.
    '''

    octree = cls(depth, full_depth, batch_size, device)
    for d in range(full_depth + 1):
      octree.octree_grow_full(depth=d)
    return octree
750
+
751
+
752
def merge_octrees(octrees: List['Octree']):
  r''' A thin wrapper around :meth:`Octree.merge_octrees`.

  The merge target is a new octree configured like the first element of
  :attr:`octrees`; it is returned after the merge.

  .. deprecated:: 2.2.7
    Use :meth:`Octree.merge_octrees` instead.
  '''

  reference = octrees[0]
  merged = Octree.init_like(reference)
  return merged.merge_octrees(octrees)
760
+
761
+
762
def init_octree(depth: int, full_depth: int = 2, batch_size: int = 1,
                device: Union[torch.device, str] = 'cpu'):
  r''' A thin wrapper around :meth:`Octree.init_octree`; all arguments are
  forwarded unchanged.

  .. deprecated:: 2.2.7
    Use :meth:`Octree.init_octree` instead.
  '''

  return Octree.init_octree(
      depth=depth, full_depth=full_depth, batch_size=batch_size, device=device)