ocnn 2.2.7__py3-none-any.whl → 2.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ocnn/__init__.py +1 -1
- ocnn/models/resnet.py +2 -2
- ocnn/nn/__init__.py +2 -1
- ocnn/nn/kernels/__init__.py +14 -0
- ocnn/nn/kernels/autotuner.py +416 -0
- ocnn/nn/kernels/config.py +67 -0
- ocnn/nn/kernels/conv_bwd_implicit_gemm.py +229 -0
- ocnn/nn/kernels/conv_bwd_implicit_gemm_splitk.py +347 -0
- ocnn/nn/kernels/conv_fwd_implicit_gemm.py +109 -0
- ocnn/nn/kernels/conv_fwd_implicit_gemm_splitk.py +150 -0
- ocnn/nn/kernels/utils.py +44 -0
- ocnn/nn/octree_conv.py +2 -1
- ocnn/nn/octree_conv_t.py +148 -0
- ocnn/nn/octree_pad.py +4 -4
- ocnn/octree/octree.py +218 -109
- ocnn/octree/points.py +95 -34
- {ocnn-2.2.7.dist-info → ocnn-2.3.0.dist-info}/METADATA +11 -6
- {ocnn-2.2.7.dist-info → ocnn-2.3.0.dist-info}/RECORD +21 -12
- {ocnn-2.2.7.dist-info → ocnn-2.3.0.dist-info}/WHEEL +1 -1
- {ocnn-2.2.7.dist-info → ocnn-2.3.0.dist-info}/licenses/LICENSE +0 -0
- {ocnn-2.2.7.dist-info → ocnn-2.3.0.dist-info}/top_level.txt +0 -0
ocnn/octree/octree.py
CHANGED
|
@@ -43,11 +43,13 @@ class Octree:
|
|
|
43
43
|
def __init__(self, depth: int, full_depth: int = 2, batch_size: int = 1,
|
|
44
44
|
device: Union[torch.device, str] = 'cpu', **kwargs):
|
|
45
45
|
super().__init__()
|
|
46
|
+
# configurations for initialization
|
|
46
47
|
self.depth = depth
|
|
47
48
|
self.full_depth = full_depth
|
|
48
49
|
self.batch_size = batch_size
|
|
49
50
|
self.device = device
|
|
50
51
|
|
|
52
|
+
# properties after building the octree
|
|
51
53
|
self.reset()
|
|
52
54
|
|
|
53
55
|
def reset(self):
|
|
@@ -63,12 +65,18 @@ class Octree:
|
|
|
63
65
|
self.normals = [None] * num
|
|
64
66
|
self.points = [None] * num
|
|
65
67
|
|
|
68
|
+
# self.nempty_masks, self.nempty_indices and self.nempty_neighs are
|
|
69
|
+
# for handling of non-empty nodes and are constructed on demand
|
|
70
|
+
self.nempty_masks = [None] * num
|
|
71
|
+
self.nempty_indices = [None] * num
|
|
72
|
+
self.nempty_neighs = [None] * num
|
|
73
|
+
|
|
66
74
|
# octree node numbers in each octree layers.
|
|
67
|
-
#
|
|
75
|
+
# These are small 1-D tensors; just keep them on CPUs
|
|
68
76
|
self.nnum = torch.zeros(num, dtype=torch.long)
|
|
69
77
|
self.nnum_nempty = torch.zeros(num, dtype=torch.long)
|
|
70
78
|
|
|
71
|
-
# the following properties are valid after `merge_octrees`.
|
|
79
|
+
# the following properties are only valid after `merge_octrees`.
|
|
72
80
|
# TODO: make them valid after `octree_grow`, `octree_split` and `build_octree`
|
|
73
81
|
batch_size = self.batch_size
|
|
74
82
|
self.batch_nnum = torch.zeros(num, batch_size, dtype=torch.long)
|
|
@@ -107,8 +115,8 @@ class Octree:
|
|
|
107
115
|
|
|
108
116
|
key = self.keys[depth]
|
|
109
117
|
if nempty:
|
|
110
|
-
|
|
111
|
-
key = key[
|
|
118
|
+
idx = self.nempty_index(depth)
|
|
119
|
+
key = key[idx]
|
|
112
120
|
return key
|
|
113
121
|
|
|
114
122
|
def xyzb(self, depth: int, nempty: bool = False):
|
|
@@ -132,19 +140,63 @@ class Octree:
|
|
|
132
140
|
|
|
133
141
|
batch_id = self.keys[depth] >> 48
|
|
134
142
|
if nempty:
|
|
135
|
-
|
|
136
|
-
batch_id = batch_id[
|
|
143
|
+
idx = self.nempty_index(depth)
|
|
144
|
+
batch_id = batch_id[idx]
|
|
137
145
|
return batch_id
|
|
138
146
|
|
|
139
|
-
def nempty_mask(self, depth: int):
|
|
147
|
+
def nempty_mask(self, depth: int, reset: bool = False):
|
|
140
148
|
r''' Returns a binary mask which indicates whether the cooreponding octree
|
|
141
149
|
node is empty or not.
|
|
142
150
|
|
|
143
151
|
Args:
|
|
144
152
|
depth (int): The depth of the octree.
|
|
153
|
+
reset (bool): If True, recomputes the mask.
|
|
145
154
|
'''
|
|
146
155
|
|
|
147
|
-
|
|
156
|
+
if self.nempty_masks[depth] is None or reset:
|
|
157
|
+
self.nempty_masks[depth] = self.children[depth] >= 0
|
|
158
|
+
return self.nempty_masks[depth]
|
|
159
|
+
|
|
160
|
+
def nempty_index(self, depth: int, reset: bool = False):
|
|
161
|
+
r''' Returns the indices of non-empty octree nodes.
|
|
162
|
+
|
|
163
|
+
Args:
|
|
164
|
+
depth (int): The depth of the octree.
|
|
165
|
+
reset (bool): If True, recomputes the indices.
|
|
166
|
+
'''
|
|
167
|
+
|
|
168
|
+
if self.nempty_indices[depth] is None or reset:
|
|
169
|
+
mask = self.nempty_mask(depth)
|
|
170
|
+
rng = torch.arange(mask.shape[0], device=mask.device, dtype=torch.long)
|
|
171
|
+
self.nempty_indices[depth] = rng[mask]
|
|
172
|
+
return self.nempty_indices[depth]
|
|
173
|
+
|
|
174
|
+
def nempty_neigh(self, depth: int, reset: bool = False):
|
|
175
|
+
r''' Returns the neighborhoods of non-empty octree nodes.
|
|
176
|
+
Args:
|
|
177
|
+
|
|
178
|
+
depth (int): The depth of the octree.
|
|
179
|
+
reset (bool): If True, recomputes the neighborhoods.
|
|
180
|
+
'''
|
|
181
|
+
if self.nempty_neighs[depth] is None or reset:
|
|
182
|
+
neigh = self.neighs[depth]
|
|
183
|
+
idx = self.nempty_index(depth)
|
|
184
|
+
neigh = self.remap_nempty_neigh(neigh[idx], depth)
|
|
185
|
+
self.nempty_neighs[depth] = neigh
|
|
186
|
+
return self.nempty_neighs[depth]
|
|
187
|
+
|
|
188
|
+
def remap_nempty_neigh(self, neigh: torch.Tensor, depth: int):
|
|
189
|
+
r''' Remaps the neighborhood indices to the non-empty octree nodes.
|
|
190
|
+
|
|
191
|
+
Args:
|
|
192
|
+
neigh (torch.Tensor): The input neighborhoods with shape :obj:`(N, 27)`.
|
|
193
|
+
depth (int): The depth of the octree.
|
|
194
|
+
'''
|
|
195
|
+
|
|
196
|
+
valid = neigh >= 0
|
|
197
|
+
child = self.children[depth]
|
|
198
|
+
neigh[valid] = child[neigh[valid]].long() # remap the index
|
|
199
|
+
return neigh
|
|
148
200
|
|
|
149
201
|
def build_octree(self, point_cloud: Points):
|
|
150
202
|
r''' Builds an octree from a point cloud.
|
|
@@ -225,6 +277,10 @@ class Octree:
|
|
|
225
277
|
features = scatter_add(point_cloud.features, idx, dim=0)
|
|
226
278
|
self.features[d] = features / counts.unsqueeze(1)
|
|
227
279
|
|
|
280
|
+
# reset nempty_masks and nempty_indices, which will be updated on demand
|
|
281
|
+
self.nempty_masks = [None] * (self.depth + 1)
|
|
282
|
+
self.nempty_indices = [None] * (self.depth + 1)
|
|
283
|
+
self.nempty_neighs = [None] * (self.depth + 1)
|
|
228
284
|
return idx
|
|
229
285
|
|
|
230
286
|
def octree_grow_full(self, depth: int, update_neigh: bool = True):
|
|
@@ -240,18 +296,22 @@ class Octree:
|
|
|
240
296
|
|
|
241
297
|
# node number
|
|
242
298
|
num = 1 << (3 * depth)
|
|
243
|
-
|
|
244
|
-
self.
|
|
299
|
+
batch_size = self.batch_size
|
|
300
|
+
self.nnum[depth] = num * batch_size
|
|
301
|
+
self.nnum_nempty[depth] = num * batch_size
|
|
245
302
|
|
|
246
303
|
# update key
|
|
247
304
|
key = torch.arange(num, dtype=torch.long, device=self.device)
|
|
248
|
-
bs = torch.arange(
|
|
305
|
+
bs = torch.arange(batch_size, dtype=torch.long, device=self.device)
|
|
249
306
|
key = key.unsqueeze(0) | (bs.unsqueeze(1) << 48)
|
|
250
307
|
self.keys[depth] = key.view(-1)
|
|
251
308
|
|
|
252
309
|
# update children
|
|
253
310
|
self.children[depth] = torch.arange(
|
|
254
|
-
num *
|
|
311
|
+
num * batch_size, dtype=torch.int32, device=self.device)
|
|
312
|
+
|
|
313
|
+
# nempty_masks, nempty_indices, and nempty_neighs
|
|
314
|
+
# need not be reset for full octrees
|
|
255
315
|
|
|
256
316
|
# update neigh if needed
|
|
257
317
|
if update_neigh:
|
|
@@ -281,6 +341,12 @@ class Octree:
|
|
|
281
341
|
self.children[depth] = children.int()
|
|
282
342
|
self.nnum_nempty[depth] = nnum_nempty
|
|
283
343
|
|
|
344
|
+
# reset nempty_masks, nempty_indices, and nempty_neighs as they depend on
|
|
345
|
+
# children[depth] and are invalid now
|
|
346
|
+
self.nempty_masks[depth] = None
|
|
347
|
+
self.nempty_indices[depth] = None
|
|
348
|
+
self.nempty_neighs[depth] = None
|
|
349
|
+
|
|
284
350
|
def octree_grow(self, depth: int, update_neigh: bool = True):
|
|
285
351
|
r''' Grows the octree and updates the relevant properties. And in most
|
|
286
352
|
cases, call :func:`Octree.octree_split` to update the splitting status of
|
|
@@ -301,6 +367,9 @@ class Octree:
|
|
|
301
367
|
self.features.append(None)
|
|
302
368
|
self.normals.append(None)
|
|
303
369
|
self.points.append(None)
|
|
370
|
+
self.nempty_masks.append(None)
|
|
371
|
+
self.nempty_indices.append(None)
|
|
372
|
+
self.nempty_neighs.append(None)
|
|
304
373
|
zero = torch.zeros(1, dtype=torch.long)
|
|
305
374
|
self.nnum = torch.cat([self.nnum, zero])
|
|
306
375
|
self.nnum_nempty = torch.cat([self.nnum_nempty, zero])
|
|
@@ -342,7 +411,7 @@ class Octree:
|
|
|
342
411
|
key = torch.arange(nnum, dtype=torch.long, device=device)
|
|
343
412
|
x, y, z, _ = key2xyz(key, depth)
|
|
344
413
|
xyz = torch.stack([x, y, z], dim=-1) # (N, 3)
|
|
345
|
-
grid = range_grid(-1, 1, device)
|
|
414
|
+
grid = range_grid(-1, 1, device) # (27, 3)
|
|
346
415
|
xyz = xyz.unsqueeze(1) + grid # (N, 27, 3)
|
|
347
416
|
xyz = xyz.view(-1, 3) # (N*27, 3)
|
|
348
417
|
neigh = xyz2key(xyz[:, 0], xyz[:, 1], xyz[:, 2], depth=depth)
|
|
@@ -400,16 +469,18 @@ class Octree:
|
|
|
400
469
|
'''
|
|
401
470
|
|
|
402
471
|
key = self.key(depth, nempty)
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
#
|
|
406
|
-
#
|
|
472
|
+
idx = torch.searchsorted(key, query)
|
|
473
|
+
|
|
474
|
+
# `torch.bucketize` can also be used here; it is similar to
|
|
475
|
+
# `torch.searchsorted`, and it has fewer dimension checks, resulting in
|
|
476
|
+
# slightly better performance for 1D sorted sequences according to the docs
|
|
477
|
+
# of pytorch-1.9.1. `key` is always a 1D sorted sequence.
|
|
407
478
|
# https://pytorch.org/docs/1.9.1/generated/torch.searchsorted.html
|
|
408
|
-
idx = torch.bucketize(query, key)
|
|
479
|
+
# idx = torch.bucketize(query, key)
|
|
409
480
|
|
|
410
481
|
valid = idx < key.shape[0] # valid if in-bound
|
|
411
|
-
|
|
412
|
-
valid[
|
|
482
|
+
vi = torch.arange(query.shape[0], device=query.device)[valid]
|
|
483
|
+
valid[vi] = key[idx[vi]] == query[vi] # valid if found
|
|
413
484
|
idx[valid.logical_not()] = -1 # set to -1 if invalid
|
|
414
485
|
return idx
|
|
415
486
|
|
|
@@ -428,22 +499,18 @@ class Octree:
|
|
|
428
499
|
octree nodes.
|
|
429
500
|
'''
|
|
430
501
|
|
|
431
|
-
if stride == 1:
|
|
502
|
+
if stride == 1 and not nempty:
|
|
432
503
|
neigh = self.neighs[depth]
|
|
433
|
-
elif stride == 2:
|
|
434
|
-
# clone neigh to avoid self.neigh[depth] being modified
|
|
504
|
+
elif stride == 2 and not nempty:
|
|
435
505
|
neigh = self.neighs[depth][::8].clone()
|
|
506
|
+
elif stride == 1 and nempty:
|
|
507
|
+
neigh = self.nempty_neigh(depth)
|
|
508
|
+
elif stride == 2 and nempty:
|
|
509
|
+
neigh = self.neighs[depth][::8].clone()
|
|
510
|
+
neigh = self.remap_nempty_neigh(neigh, depth)
|
|
436
511
|
else:
|
|
437
512
|
raise ValueError('Unsupported stride {}'.format(stride))
|
|
438
513
|
|
|
439
|
-
if nempty:
|
|
440
|
-
child = self.children[depth]
|
|
441
|
-
if stride == 1:
|
|
442
|
-
nempty_node = child >= 0
|
|
443
|
-
neigh = neigh[nempty_node]
|
|
444
|
-
valid = neigh >= 0
|
|
445
|
-
neigh[valid] = child[neigh[valid]].long() # remap the index
|
|
446
|
-
|
|
447
514
|
if kernel == '333':
|
|
448
515
|
return neigh
|
|
449
516
|
elif kernel in self.lut_kernel:
|
|
@@ -548,15 +615,23 @@ class Octree:
|
|
|
548
615
|
return [p.to(device, non_blocking=non_blocking)
|
|
549
616
|
if isinstance(p, torch.Tensor) else None for p in prop]
|
|
550
617
|
|
|
551
|
-
# Construct a new Octree on the specified device
|
|
552
|
-
|
|
618
|
+
# Construct a new Octree on the specified device.
|
|
619
|
+
# During the initialization, self.device is used to set up the new Octree;
|
|
620
|
+
# the look-up tables (including self.lut_kernel, self.lut_parent, and
|
|
621
|
+
# self.lut_child), will be already created on the correct device.
|
|
622
|
+
octree = Octree.init_like(self, device)
|
|
623
|
+
|
|
624
|
+
# Move all the other properties to the specified device
|
|
553
625
|
octree.keys = list_to_device(self.keys)
|
|
554
626
|
octree.children = list_to_device(self.children)
|
|
555
627
|
octree.neighs = list_to_device(self.neighs)
|
|
556
628
|
octree.features = list_to_device(self.features)
|
|
557
629
|
octree.normals = list_to_device(self.normals)
|
|
558
630
|
octree.points = list_to_device(self.points)
|
|
559
|
-
|
|
631
|
+
|
|
632
|
+
# The following are small tensors, keep them on CPU to avoid frequent device
|
|
633
|
+
# switching, so just clone them.
|
|
634
|
+
octree.nnum = self.nnum.clone()
|
|
560
635
|
octree.nnum_nempty = self.nnum_nempty.clone()
|
|
561
636
|
octree.batch_nnum = self.batch_nnum.clone()
|
|
562
637
|
octree.batch_nnum_nempty = self.batch_nnum_nempty.clone()
|
|
@@ -572,90 +647,124 @@ class Octree:
|
|
|
572
647
|
|
|
573
648
|
return self.to('cpu')
|
|
574
649
|
|
|
650
|
+
def merge_octrees(self, octrees: List['Octree']):
|
|
651
|
+
r''' Merges a list of octrees into one batch.
|
|
575
652
|
|
|
576
|
-
|
|
577
|
-
|
|
653
|
+
Args:
|
|
654
|
+
octrees (List[Octree]): A list of octrees to merge.
|
|
578
655
|
|
|
579
|
-
|
|
580
|
-
|
|
656
|
+
Returns:
|
|
657
|
+
Octree: The merged octree.
|
|
658
|
+
'''
|
|
659
|
+
|
|
660
|
+
# init and check
|
|
661
|
+
batch_size = len(octrees)
|
|
662
|
+
self.batch_size = batch_size
|
|
663
|
+
for i in range(1, batch_size):
|
|
664
|
+
condition = (octrees[i].depth == self.depth and
|
|
665
|
+
octrees[i].full_depth == self.full_depth and
|
|
666
|
+
octrees[i].device == self.device)
|
|
667
|
+
assert condition, 'The check of merge_octrees failed'
|
|
668
|
+
|
|
669
|
+
# node num
|
|
670
|
+
batch_nnum = torch.stack(
|
|
671
|
+
[octrees[i].nnum for i in range(batch_size)], dim=1)
|
|
672
|
+
batch_nnum_nempty = torch.stack(
|
|
673
|
+
[octrees[i].nnum_nempty for i in range(batch_size)], dim=1)
|
|
674
|
+
self.nnum = torch.sum(batch_nnum, dim=1)
|
|
675
|
+
self.nnum_nempty = torch.sum(batch_nnum_nempty, dim=1)
|
|
676
|
+
self.batch_nnum = batch_nnum
|
|
677
|
+
self.batch_nnum_nempty = batch_nnum_nempty
|
|
678
|
+
nnum_cum = cumsum(batch_nnum_nempty, dim=1, exclusive=True)
|
|
679
|
+
|
|
680
|
+
# merge octre properties
|
|
681
|
+
for d in range(self.depth + 1):
|
|
682
|
+
# key
|
|
683
|
+
keys = [None] * batch_size
|
|
684
|
+
for i in range(batch_size):
|
|
685
|
+
key = octrees[i].keys[d] & ((1 << 48) - 1) # clear the highest bits
|
|
686
|
+
keys[i] = key | (i << 48)
|
|
687
|
+
self.keys[d] = torch.cat(keys, dim=0)
|
|
688
|
+
|
|
689
|
+
# children
|
|
690
|
+
children = [None] * batch_size
|
|
691
|
+
for i in range(batch_size):
|
|
692
|
+
# !! `clone` is used here to avoid modifying the original octrees
|
|
693
|
+
child = octrees[i].children[d].clone()
|
|
694
|
+
mask = child >= 0
|
|
695
|
+
child[mask] = child[mask] + nnum_cum[d, i]
|
|
696
|
+
children[i] = child
|
|
697
|
+
self.children[d] = torch.cat(children, dim=0)
|
|
698
|
+
|
|
699
|
+
# features
|
|
700
|
+
if octrees[0].features[d] is not None and d == self.depth:
|
|
701
|
+
features = [octrees[i].features[d] for i in range(batch_size)]
|
|
702
|
+
self.features[d] = torch.cat(features, dim=0)
|
|
703
|
+
|
|
704
|
+
# normals
|
|
705
|
+
if octrees[0].normals[d] is not None and d == self.depth:
|
|
706
|
+
normals = [octrees[i].normals[d] for i in range(batch_size)]
|
|
707
|
+
self.normals[d] = torch.cat(normals, dim=0)
|
|
708
|
+
|
|
709
|
+
# points
|
|
710
|
+
if octrees[0].points[d] is not None and d == self.depth:
|
|
711
|
+
points = [octrees[i].points[d] for i in range(batch_size)]
|
|
712
|
+
self.points[d] = torch.cat(points, dim=0)
|
|
713
|
+
|
|
714
|
+
return self
|
|
715
|
+
|
|
716
|
+
@classmethod
|
|
717
|
+
def init_like(cls, octree: 'Octree', device: Union[torch.device, str, None] = None):
|
|
718
|
+
r''' Initializes the octree like another octree.
|
|
719
|
+
|
|
720
|
+
Args:
|
|
721
|
+
octree (Octree): The reference octree.
|
|
722
|
+
device (torch.device or str): The device to use for computation.
|
|
723
|
+
'''
|
|
724
|
+
|
|
725
|
+
device = device if device is not None else octree.device
|
|
726
|
+
return cls(depth=octree.depth, full_depth=octree.full_depth,
|
|
727
|
+
batch_size=octree.batch_size, device=device)
|
|
728
|
+
|
|
729
|
+
@classmethod
|
|
730
|
+
def init_octree(cls, depth: int, full_depth: int = 2, batch_size: int = 1,
|
|
731
|
+
device: Union[torch.device, str] = 'cpu'):
|
|
732
|
+
r'''
|
|
733
|
+
Initializes an octree to :attr:`full_depth`.
|
|
581
734
|
|
|
582
|
-
|
|
583
|
-
|
|
735
|
+
Args:
|
|
736
|
+
depth (int): The depth of the octree.
|
|
737
|
+
full_depth (int): The octree layers with a depth small than
|
|
738
|
+
:attr:`full_depth` are forced to be full.
|
|
739
|
+
batch_size (int, optional): The batch size.
|
|
740
|
+
device (torch.device or str): The device to use for computation.
|
|
741
|
+
|
|
742
|
+
Returns:
|
|
743
|
+
Octree: The initialized Octree object.
|
|
744
|
+
'''
|
|
745
|
+
|
|
746
|
+
octree = cls(depth, full_depth, batch_size, device)
|
|
747
|
+
for d in range(full_depth + 1):
|
|
748
|
+
octree.octree_grow_full(depth=d)
|
|
749
|
+
return octree
|
|
750
|
+
|
|
751
|
+
|
|
752
|
+
def merge_octrees(octrees: List['Octree']):
|
|
753
|
+
r''' A wrapper of :meth:`Octree.merge_octrees`.
|
|
754
|
+
|
|
755
|
+
.. deprecated:: 2.2.7
|
|
756
|
+
Use :meth:`Octree.merge_octrees` instead.
|
|
584
757
|
'''
|
|
585
758
|
|
|
586
|
-
|
|
587
|
-
octree = Octree(depth=octrees[0].depth, full_depth=octrees[0].full_depth,
|
|
588
|
-
batch_size=len(octrees), device=octrees[0].device)
|
|
589
|
-
for i in range(1, octree.batch_size):
|
|
590
|
-
condition = (octrees[i].depth == octree.depth and
|
|
591
|
-
octrees[i].full_depth == octree.full_depth and
|
|
592
|
-
octrees[i].device == octree.device)
|
|
593
|
-
assert condition, 'The check of merge_octrees failed'
|
|
594
|
-
|
|
595
|
-
# node num
|
|
596
|
-
batch_nnum = torch.stack(
|
|
597
|
-
[octrees[i].nnum for i in range(octree.batch_size)], dim=1)
|
|
598
|
-
batch_nnum_nempty = torch.stack(
|
|
599
|
-
[octrees[i].nnum_nempty for i in range(octree.batch_size)], dim=1)
|
|
600
|
-
octree.nnum = torch.sum(batch_nnum, dim=1)
|
|
601
|
-
octree.nnum_nempty = torch.sum(batch_nnum_nempty, dim=1)
|
|
602
|
-
octree.batch_nnum = batch_nnum
|
|
603
|
-
octree.batch_nnum_nempty = batch_nnum_nempty
|
|
604
|
-
nnum_cum = cumsum(batch_nnum_nempty, dim=1, exclusive=True)
|
|
605
|
-
|
|
606
|
-
# merge octre properties
|
|
607
|
-
for d in range(octree.depth+1):
|
|
608
|
-
# key
|
|
609
|
-
keys = [None] * octree.batch_size
|
|
610
|
-
for i in range(octree.batch_size):
|
|
611
|
-
key = octrees[i].keys[d] & ((1 << 48) - 1) # clear the highest bits
|
|
612
|
-
keys[i] = key | (i << 48)
|
|
613
|
-
octree.keys[d] = torch.cat(keys, dim=0)
|
|
614
|
-
|
|
615
|
-
# children
|
|
616
|
-
children = [None] * octree.batch_size
|
|
617
|
-
for i in range(octree.batch_size):
|
|
618
|
-
child = octrees[i].children[d].clone() # !! `clone` is used here to avoid
|
|
619
|
-
mask = child >= 0 # !! modifying the original octrees
|
|
620
|
-
child[mask] = child[mask] + nnum_cum[d, i]
|
|
621
|
-
children[i] = child
|
|
622
|
-
octree.children[d] = torch.cat(children, dim=0)
|
|
623
|
-
|
|
624
|
-
# features
|
|
625
|
-
if octrees[0].features[d] is not None and d == octree.depth:
|
|
626
|
-
features = [octrees[i].features[d] for i in range(octree.batch_size)]
|
|
627
|
-
octree.features[d] = torch.cat(features, dim=0)
|
|
628
|
-
|
|
629
|
-
# normals
|
|
630
|
-
if octrees[0].normals[d] is not None and d == octree.depth:
|
|
631
|
-
normals = [octrees[i].normals[d] for i in range(octree.batch_size)]
|
|
632
|
-
octree.normals[d] = torch.cat(normals, dim=0)
|
|
633
|
-
|
|
634
|
-
# points
|
|
635
|
-
if octrees[0].points[d] is not None and d == octree.depth:
|
|
636
|
-
points = [octrees[i].points[d] for i in range(octree.batch_size)]
|
|
637
|
-
octree.points[d] = torch.cat(points, dim=0)
|
|
638
|
-
|
|
639
|
-
return octree
|
|
759
|
+
return Octree.init_like(octrees[0]).merge_octrees(octrees)
|
|
640
760
|
|
|
641
761
|
|
|
642
762
|
def init_octree(depth: int, full_depth: int = 2, batch_size: int = 1,
|
|
643
763
|
device: Union[torch.device, str] = 'cpu'):
|
|
644
|
-
r'''
|
|
645
|
-
Initializes an octree to :attr:`full_depth`.
|
|
646
|
-
|
|
647
|
-
Args:
|
|
648
|
-
depth (int): The depth of the octree.
|
|
649
|
-
full_depth (int): The octree layers with a depth small than
|
|
650
|
-
:attr:`full_depth` are forced to be full.
|
|
651
|
-
batch_size (int, optional): The batch size.
|
|
652
|
-
device (torch.device or str): The device to use for computation.
|
|
764
|
+
r''' A wrapper of :meth:`Octree.init_octree`.
|
|
653
765
|
|
|
654
|
-
|
|
655
|
-
|
|
766
|
+
.. deprecated:: 2.2.7
|
|
767
|
+
Use :meth:`Octree.init_octree` instead.
|
|
656
768
|
'''
|
|
657
769
|
|
|
658
|
-
|
|
659
|
-
for d in range(full_depth+1):
|
|
660
|
-
octree.octree_grow_full(depth=d)
|
|
661
|
-
return octree
|
|
770
|
+
return Octree.init_octree(depth, full_depth, batch_size, device)
|
ocnn/octree/points.py
CHANGED
|
@@ -8,6 +8,7 @@
|
|
|
8
8
|
import torch
|
|
9
9
|
import numpy as np
|
|
10
10
|
from typing import Optional, Union, List
|
|
11
|
+
from ocnn.utils import cumsum
|
|
11
12
|
|
|
12
13
|
|
|
13
14
|
class Points:
|
|
@@ -66,6 +67,7 @@ class Points:
|
|
|
66
67
|
if self.batch_id.dim() == 1:
|
|
67
68
|
self.batch_id = self.batch_id.unsqueeze(1)
|
|
68
69
|
assert self.batch_id.size(1) == 1
|
|
70
|
+
assert self.batch_size == self.batch_id.max().item() + 1
|
|
69
71
|
|
|
70
72
|
@property
|
|
71
73
|
def npt(self):
|
|
@@ -165,26 +167,23 @@ class Points:
|
|
|
165
167
|
'''
|
|
166
168
|
|
|
167
169
|
mask = self.inbox_mask(min + esp, max - esp)
|
|
168
|
-
|
|
169
|
-
self.__dict__.update(tmp.__dict__)
|
|
170
|
+
self.copy_from(self[mask])
|
|
170
171
|
return mask
|
|
171
172
|
|
|
172
|
-
def __getitem__(self,
|
|
173
|
-
r''' Slices the point cloud according a given :attr:`
|
|
173
|
+
def __getitem__(self, idx):
|
|
174
|
+
r''' Slices the point cloud according a given :attr:`idx`.
|
|
174
175
|
'''
|
|
175
176
|
|
|
176
|
-
|
|
177
|
-
out =
|
|
178
|
-
|
|
179
|
-
out.points = self.points[mask]
|
|
177
|
+
out = self.init_points(self.device, self.batch_size)
|
|
178
|
+
out.points = self.points[idx]
|
|
180
179
|
if self.normals is not None:
|
|
181
|
-
out.normals = self.normals[
|
|
180
|
+
out.normals = self.normals[idx]
|
|
182
181
|
if self.features is not None:
|
|
183
|
-
out.features = self.features[
|
|
182
|
+
out.features = self.features[idx]
|
|
184
183
|
if self.labels is not None:
|
|
185
|
-
out.labels = self.labels[
|
|
184
|
+
out.labels = self.labels[idx]
|
|
186
185
|
if self.batch_id is not None:
|
|
187
|
-
out.batch_id = self.batch_id[
|
|
186
|
+
out.batch_id = self.batch_id[idx]
|
|
188
187
|
return out
|
|
189
188
|
|
|
190
189
|
def inbox_mask(self, bbmin: Union[float, torch.Tensor] = -1.0,
|
|
@@ -239,8 +238,7 @@ class Points:
|
|
|
239
238
|
return self
|
|
240
239
|
|
|
241
240
|
# Construct a new Points on the specified device
|
|
242
|
-
points =
|
|
243
|
-
points.batch_size = self.batch_size
|
|
241
|
+
points = self.init_points(device, self.batch_size)
|
|
244
242
|
points.batch_npt = self.batch_npt
|
|
245
243
|
points.points = self.points.to(device, non_blocking=non_blocking)
|
|
246
244
|
if self.normals is not None:
|
|
@@ -295,29 +293,92 @@ class Points:
|
|
|
295
293
|
else:
|
|
296
294
|
raise ValueError
|
|
297
295
|
|
|
296
|
+
def copy_from(self, points: 'Points'):
|
|
297
|
+
r''' Shallow copy from another Points.
|
|
298
|
+
'''
|
|
299
|
+
|
|
300
|
+
self.points = points.points
|
|
301
|
+
self.normals = points.normals
|
|
302
|
+
self.features = points.features
|
|
303
|
+
self.labels = points.labels
|
|
304
|
+
self.batch_id = points.batch_id
|
|
305
|
+
self.batch_size = points.batch_size
|
|
306
|
+
self.device = points.device
|
|
307
|
+
self.batch_npt = points.batch_npt
|
|
308
|
+
|
|
309
|
+
def merge_points(self, points: List['Points'], update_batch_info: bool = True):
|
|
310
|
+
r''' Merges a list of points into one batch.
|
|
311
|
+
|
|
312
|
+
Args:
|
|
313
|
+
points (List[Octree]): A list of points to merge. The batch size of each
|
|
314
|
+
points in the list is assumed to be 1, and the :obj:`batch_size`,
|
|
315
|
+
:obj:`batch_id`, and :obj:`batch_npt` in the points are ignored.
|
|
316
|
+
'''
|
|
317
|
+
|
|
318
|
+
self.points = torch.cat([p.points for p in points], dim=0)
|
|
319
|
+
if points[0].normals is not None:
|
|
320
|
+
self.normals = torch.cat([p.normals for p in points], dim=0)
|
|
321
|
+
if points[0].features is not None:
|
|
322
|
+
self.features = torch.cat([p.features for p in points], dim=0)
|
|
323
|
+
if points[0].labels is not None:
|
|
324
|
+
self.labels = torch.cat([p.labels for p in points], dim=0)
|
|
325
|
+
self.device = points[0].device
|
|
326
|
+
|
|
327
|
+
if update_batch_info:
|
|
328
|
+
self.batch_size = len(points)
|
|
329
|
+
self.batch_npt = torch.Tensor([p.npt for p in points]).long()
|
|
330
|
+
self.batch_id = torch.cat([p.points.new_full((p.npt, 1), i)
|
|
331
|
+
for i, p in enumerate(points)], dim=0)
|
|
332
|
+
return self
|
|
333
|
+
|
|
334
|
+
def split_points(self):
|
|
335
|
+
r''' Splits the batched points into a list of Points.
|
|
336
|
+
'''
|
|
337
|
+
|
|
338
|
+
if self.batch_npt is None:
|
|
339
|
+
self.batch_npt = torch.bincount(
|
|
340
|
+
self.batch_id.squeeze(), minlength=self.batch_size)
|
|
341
|
+
|
|
342
|
+
outs = []
|
|
343
|
+
cs = cumsum(self.batch_npt, dim=0, exclusive=True)
|
|
344
|
+
for i in range(self.batch_size):
|
|
345
|
+
rng = range(cs[i], cs[i+1])
|
|
346
|
+
out = Points.init_points(self.device, batch_size=1)
|
|
347
|
+
out.points = self.points[rng]
|
|
348
|
+
if self.normals is not None:
|
|
349
|
+
out.normals = self.normals[rng]
|
|
350
|
+
if self.features is not None:
|
|
351
|
+
out.features = self.features[rng]
|
|
352
|
+
if self.labels is not None:
|
|
353
|
+
out.labels = self.labels[rng]
|
|
354
|
+
outs.append(out)
|
|
355
|
+
return outs
|
|
356
|
+
|
|
357
|
+
@classmethod
|
|
358
|
+
def init_points(cls, device: Union[torch.device, str, None] = None,
|
|
359
|
+
batch_size: int = 1):
|
|
360
|
+
r''' Initialzes a Points object with dummy data on a specified device.
|
|
361
|
+
|
|
362
|
+
Args:
|
|
363
|
+
device (torch.device or str or None): The device of the Points. If
|
|
364
|
+
:obj:`None`, the device is set to :obj:`cpu`.
|
|
365
|
+
batch_size (int): The batch size.
|
|
366
|
+
'''
|
|
367
|
+
|
|
368
|
+
points = torch.zeros(batch_size, 3, device=device)
|
|
369
|
+
batch_id = (torch.arange(batch_size, device=device).unsqueeze(1)
|
|
370
|
+
if batch_size > 1 else None)
|
|
371
|
+
return cls(points, batch_size=batch_size, batch_id=batch_id)
|
|
372
|
+
|
|
298
373
|
|
|
299
374
|
def merge_points(points: List['Points'], update_batch_info: bool = True):
|
|
300
|
-
r'''
|
|
375
|
+
r''' A wrapper of :meth:`Points.merge_points`.
|
|
301
376
|
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
points in the list is assumed to be 1, and the :obj:`batch_size`,
|
|
305
|
-
:obj:`batch_id`, and :obj:`batch_npt` in the points are ignored.
|
|
377
|
+
.. deprecated:: 2.2.7
|
|
378
|
+
Use :meth:`Points.merge_points` instead.
|
|
306
379
|
'''
|
|
307
380
|
|
|
308
|
-
|
|
309
|
-
out
|
|
310
|
-
|
|
311
|
-
out.normals = torch.cat([p.normals for p in points], dim=0)
|
|
312
|
-
if points[0].features is not None:
|
|
313
|
-
out.features = torch.cat([p.features for p in points], dim=0)
|
|
314
|
-
if points[0].labels is not None:
|
|
315
|
-
out.labels = torch.cat([p.labels for p in points], dim=0)
|
|
316
|
-
out.device = points[0].device
|
|
317
|
-
|
|
318
|
-
if update_batch_info:
|
|
319
|
-
out.batch_size = len(points)
|
|
320
|
-
out.batch_npt = torch.Tensor([p.npt for p in points]).long()
|
|
321
|
-
out.batch_id = torch.cat([p.points.new_full((p.npt, 1), i)
|
|
322
|
-
for i, p in enumerate(points)], dim=0)
|
|
381
|
+
assert len(points) > 0, 'The input points list is empty.'
|
|
382
|
+
out = Points.init_points(points[0].device, batch_size=len(points))
|
|
383
|
+
out.merge_points(points, update_batch_info)
|
|
323
384
|
return out
|
|
@@ -1,13 +1,12 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: ocnn
|
|
3
|
-
Version: 2.2.7
|
|
3
|
+
Version: 2.3.0
|
|
4
4
|
Summary: Octree-based Sparse Convolutional Neural Networks
|
|
5
5
|
Home-page: https://github.com/octree-nn/ocnn-pytorch
|
|
6
6
|
Author: Peng-Shuai Wang
|
|
7
7
|
Author-email: wangps@hotmail.com
|
|
8
8
|
License: MIT
|
|
9
9
|
Classifier: Programming Language :: Python :: 3
|
|
10
|
-
Classifier: License :: OSI Approved :: MIT License
|
|
11
10
|
Classifier: Operating System :: OS Independent
|
|
12
11
|
Requires-Python: >=3.6
|
|
13
12
|
Description-Content-Type: text/markdown
|
|
@@ -84,6 +83,13 @@ octrees to perform convolution operations. Of course, it also supports other 3D
|
|
|
84
83
|
data formats, such as meshes and volumetric grids, which can be converted into
|
|
85
84
|
octrees to leverage the library's capabilities.
|
|
86
85
|
|
|
86
|
+
## Updates
|
|
87
|
+
|
|
88
|
+
- **2026.02.02**: Release `v2.3.0`, incorporating Triton to accelerate
|
|
89
|
+
octree-based sparse convolution in the upcoming release. OctreeConv is even
|
|
90
|
+
**2.5 times faster than the latest spconv**!
|
|
91
|
+
- **2025.12.18**: Release `v2.2.8`, improving neighbor search efficiency.
|
|
92
|
+
|
|
87
93
|
|
|
88
94
|
## Key benefits of ocnn-pytorch
|
|
89
95
|
|
|
@@ -93,10 +99,9 @@ octrees to leverage the library's capabilities.
|
|
|
93
99
|
configure the compiling environment.
|
|
94
100
|
|
|
95
101
|
- **Efficiency**. The ocnn-pytorch is very efficient compared with other sparse
|
|
96
|
-
convolution frameworks.
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
takes 30 hours.
|
|
102
|
+
convolution frameworks. It is **even 2.5 times faster than the latest spconv
|
|
103
|
+
implementation**! Check the benchmark [code](test/benchmark_conv.py) and
|
|
104
|
+
[results](test/benchmark/results.png) for details. ✨
|
|
100
105
|
|
|
101
106
|
## Citation
|
|
102
107
|
|