ocnn 2.2.7__py3-none-any.whl → 2.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ocnn/octree/octree.py CHANGED
@@ -43,11 +43,13 @@ class Octree:
   def __init__(self, depth: int, full_depth: int = 2, batch_size: int = 1,
                device: Union[torch.device, str] = 'cpu', **kwargs):
     super().__init__()
+    # configurations for initialization
     self.depth = depth
     self.full_depth = full_depth
     self.batch_size = batch_size
     self.device = device

+    # properties after building the octree
     self.reset()

   def reset(self):
@@ -63,12 +65,18 @@ class Octree:
     self.normals = [None] * num
     self.points = [None] * num

+    # self.nempty_masks, self.nempty_indices and self.nempty_neighs are
+    # for handling of non-empty nodes and are constructed on demand
+    self.nempty_masks = [None] * num
+    self.nempty_indices = [None] * num
+    self.nempty_neighs = [None] * num
+
     # octree node numbers in each octree layers.
-    # TODO: decide whether to settle them to 'gpu' or not?
+    # These are small 1-D tensors; just keep them on the CPU.
     self.nnum = torch.zeros(num, dtype=torch.long)
     self.nnum_nempty = torch.zeros(num, dtype=torch.long)

-    # the following properties are valid after `merge_octrees`.
+    # the following properties are only valid after `merge_octrees`.
     # TODO: make them valid after `octree_grow`, `octree_split` and `build_octree`
     batch_size = self.batch_size
     self.batch_nnum = torch.zeros(num, batch_size, dtype=torch.long)
@@ -107,8 +115,8 @@ class Octree:

     key = self.keys[depth]
     if nempty:
-      mask = self.nempty_mask(depth)
-      key = key[mask]
+      idx = self.nempty_index(depth)
+      key = key[idx]
     return key

   def xyzb(self, depth: int, nempty: bool = False):
@@ -132,19 +140,63 @@ class Octree:

     batch_id = self.keys[depth] >> 48
     if nempty:
-      mask = self.nempty_mask(depth)
-      batch_id = batch_id[mask]
+      idx = self.nempty_index(depth)
+      batch_id = batch_id[idx]
     return batch_id

-  def nempty_mask(self, depth: int):
+  def nempty_mask(self, depth: int, reset: bool = False):
     r''' Returns a binary mask which indicates whether the corresponding octree
     node is empty or not.

     Args:
       depth (int): The depth of the octree.
+      reset (bool): If True, recomputes the mask.
     '''

-    return self.children[depth] >= 0
+    if self.nempty_masks[depth] is None or reset:
+      self.nempty_masks[depth] = self.children[depth] >= 0
+    return self.nempty_masks[depth]
+
+  def nempty_index(self, depth: int, reset: bool = False):
+    r''' Returns the indices of non-empty octree nodes.
+
+    Args:
+      depth (int): The depth of the octree.
+      reset (bool): If True, recomputes the indices.
+    '''
+
+    if self.nempty_indices[depth] is None or reset:
+      mask = self.nempty_mask(depth)
+      rng = torch.arange(mask.shape[0], device=mask.device, dtype=torch.long)
+      self.nempty_indices[depth] = rng[mask]
+    return self.nempty_indices[depth]
+
+  def nempty_neigh(self, depth: int, reset: bool = False):
+    r''' Returns the neighborhoods of non-empty octree nodes.
+
+    Args:
+      depth (int): The depth of the octree.
+      reset (bool): If True, recomputes the neighborhoods.
+    '''
+    if self.nempty_neighs[depth] is None or reset:
+      neigh = self.neighs[depth]
+      idx = self.nempty_index(depth)
+      neigh = self.remap_nempty_neigh(neigh[idx], depth)
+      self.nempty_neighs[depth] = neigh
+    return self.nempty_neighs[depth]
+
+  def remap_nempty_neigh(self, neigh: torch.Tensor, depth: int):
+    r''' Remaps the neighborhood indices to the non-empty octree nodes.
+
+    Args:
+      neigh (torch.Tensor): The input neighborhoods with shape :obj:`(N, 27)`.
+      depth (int): The depth of the octree.
+    '''
+
+    valid = neigh >= 0
+    child = self.children[depth]
+    neigh[valid] = child[neigh[valid]].long()  # remap the index
+    return neigh

   def build_octree(self, point_cloud: Points):
     r''' Builds an octree from a point cloud.
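The three new per-depth caches above are filled lazily: `nempty_mask` and `nempty_index` compute their result on first access, return the cached tensor afterwards, and recompute only when `reset=True`. A minimal usage sketch, assuming only the constructor and `build_octree` behavior visible in this diff (the random point cloud is illustrative):

```python
import torch
from ocnn.octree import Octree, Points

# build a small octree from a random point cloud in [-1, 1]
points = Points(torch.rand(1000, 3) * 2 - 1)
octree = Octree(depth=5, full_depth=2)
octree.build_octree(points)

d = octree.depth
mask = octree.nempty_mask(d)   # computed once, then cached per depth
idx = octree.nempty_index(d)   # indices of non-empty nodes, also cached
assert idx.numel() == int(mask.sum())

# after the topology changes, reset=True refreshes the cached mask
mask = octree.nempty_mask(d, reset=True)
```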
@@ -225,6 +277,10 @@ class Octree:
       features = scatter_add(point_cloud.features, idx, dim=0)
       self.features[d] = features / counts.unsqueeze(1)

+    # reset nempty_masks, nempty_indices and nempty_neighs, which will be updated on demand
+    self.nempty_masks = [None] * (self.depth + 1)
+    self.nempty_indices = [None] * (self.depth + 1)
+    self.nempty_neighs = [None] * (self.depth + 1)
     return idx

   def octree_grow_full(self, depth: int, update_neigh: bool = True):
@@ -240,18 +296,22 @@ class Octree:

     # node number
     num = 1 << (3 * depth)
-    self.nnum[depth] = num * self.batch_size
-    self.nnum_nempty[depth] = num * self.batch_size
+    batch_size = self.batch_size
+    self.nnum[depth] = num * batch_size
+    self.nnum_nempty[depth] = num * batch_size

     # update key
     key = torch.arange(num, dtype=torch.long, device=self.device)
-    bs = torch.arange(self.batch_size, dtype=torch.long, device=self.device)
+    bs = torch.arange(batch_size, dtype=torch.long, device=self.device)
     key = key.unsqueeze(0) | (bs.unsqueeze(1) << 48)
     self.keys[depth] = key.view(-1)

     # update children
     self.children[depth] = torch.arange(
-        num * self.batch_size, dtype=torch.int32, device=self.device)
+        num * batch_size, dtype=torch.int32, device=self.device)
+
+    # nempty_masks, nempty_indices, and nempty_neighs
+    # need not be reset for full octrees

     # update neigh if needed
     if update_neigh:
@@ -281,6 +341,12 @@ class Octree:
     self.children[depth] = children.int()
     self.nnum_nempty[depth] = nnum_nempty

+    # reset nempty_masks, nempty_indices, and nempty_neighs as they depend on
+    # children[depth] and are invalid now
+    self.nempty_masks[depth] = None
+    self.nempty_indices[depth] = None
+    self.nempty_neighs[depth] = None
+
   def octree_grow(self, depth: int, update_neigh: bool = True):
     r''' Grows the octree and updates the relevant properties. And in most
     cases, call :func:`Octree.octree_split` to update the splitting status of
@@ -301,6 +367,9 @@ class Octree:
     self.features.append(None)
     self.normals.append(None)
     self.points.append(None)
+    self.nempty_masks.append(None)
+    self.nempty_indices.append(None)
+    self.nempty_neighs.append(None)
     zero = torch.zeros(1, dtype=torch.long)
     self.nnum = torch.cat([self.nnum, zero])
     self.nnum_nempty = torch.cat([self.nnum_nempty, zero])
@@ -342,7 +411,7 @@ class Octree:
     key = torch.arange(nnum, dtype=torch.long, device=device)
     x, y, z, _ = key2xyz(key, depth)
     xyz = torch.stack([x, y, z], dim=-1)  # (N, 3)
-    grid = range_grid(-1, 1, device) # (27, 3)
+    grid = range_grid(-1, 1, device)      # (27, 3)
     xyz = xyz.unsqueeze(1) + grid         # (N, 27, 3)
     xyz = xyz.view(-1, 3)                 # (N*27, 3)
     neigh = xyz2key(xyz[:, 0], xyz[:, 1], xyz[:, 2], depth=depth)
@@ -400,16 +469,18 @@ class Octree:
     '''

     key = self.key(depth, nempty)
-    # `torch.bucketize` is similar to `torch.searchsorted`.
-    # I choose `torch.bucketize` here because it has fewer dimension checks,
-    # resulting in slightly better performance according to the docs of
-    # pytorch-1.9.1, since `key` is always 1-D sorted sequence.
+    idx = torch.searchsorted(key, query)
+
+    # `torch.bucketize` can also be used here; it is similar to
+    # `torch.searchsorted`, and it has fewer dimension checks, resulting in
+    # slightly better performance for 1D sorted sequences according to the docs
+    # of pytorch-1.9.1. `key` is always a 1D sorted sequence.
     # https://pytorch.org/docs/1.9.1/generated/torch.searchsorted.html
-    idx = torch.bucketize(query, key)
+    # idx = torch.bucketize(query, key)

     valid = idx < key.shape[0]  # valid if in-bound
-    found = key[idx[valid]] == query[valid]
-    valid[valid.clone()] = found  # valid if found
+    vi = torch.arange(query.shape[0], device=query.device)[valid]
+    valid[vi] = key[idx[vi]] == query[vi]  # valid if found
     idx[valid.logical_not()] = -1  # set to -1 if invalid
     return idx

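For reference, the lookup pattern used by the rewritten `search_key` can be reproduced in isolation; the key and query values below are illustrative, not taken from the library:

```python
import torch

# sorted octree keys at some depth (illustrative values)
key = torch.tensor([3, 7, 15, 42, 96], dtype=torch.long)
query = torch.tensor([7, 50, 96, 100], dtype=torch.long)

idx = torch.searchsorted(key, query)    # candidate insertion positions
valid = idx < key.shape[0]              # keep only in-bound candidates
vi = torch.arange(query.shape[0])[valid]
valid[vi] = key[idx[vi]] == query[vi]   # keep only exact matches
idx[valid.logical_not()] = -1           # -1 marks keys not present in the octree
print(idx)                              # tensor([ 1, -1,  4, -1])
```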
@@ -428,22 +499,18 @@ class Octree:
     octree nodes.
     '''

-    if stride == 1:
+    if stride == 1 and not nempty:
       neigh = self.neighs[depth]
-    elif stride == 2:
-      # clone neigh to avoid self.neigh[depth] being modified
+    elif stride == 2 and not nempty:
       neigh = self.neighs[depth][::8].clone()
+    elif stride == 1 and nempty:
+      neigh = self.nempty_neigh(depth)
+    elif stride == 2 and nempty:
+      neigh = self.neighs[depth][::8].clone()
+      neigh = self.remap_nempty_neigh(neigh, depth)
     else:
       raise ValueError('Unsupported stride {}'.format(stride))

-    if nempty:
-      child = self.children[depth]
-      if stride == 1:
-        nempty_node = child >= 0
-        neigh = neigh[nempty_node]
-      valid = neigh >= 0
-      neigh[valid] = child[neigh[valid]].long()  # remap the index
-

     if kernel == '333':
       return neigh
     elif kernel in self.lut_kernel:
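A hedged usage sketch of the reorganized branches, reusing the `octree` and `d` from the earlier sketch and assuming the `get_neigh(depth, kernel, stride, nempty)` signature implied by this hunk; `construct_all_neigh()` is assumed from the library's public API and is not part of this diff:

```python
# neighbor tables must exist before get_neigh can index them (assumed helper)
octree.construct_all_neigh()

# (N, 27) neighborhoods of all nodes vs. only the non-empty nodes at depth d
neigh_all = octree.get_neigh(depth=d, kernel='333', stride=1)
neigh_nempty = octree.get_neigh(depth=d, kernel='333', stride=1, nempty=True)
assert neigh_nempty.shape[0] == octree.nnum_nempty[d]
```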
@@ -548,15 +615,23 @@ class Octree:
       return [p.to(device, non_blocking=non_blocking)
               if isinstance(p, torch.Tensor) else None for p in prop]

-    # Construct a new Octree on the specified device
-    octree = Octree(self.depth, self.full_depth, self.batch_size, device)
+    # Construct a new Octree on the specified device.
+    # During the initialization, self.device is used to set up the new Octree;
+    # the look-up tables (including self.lut_kernel, self.lut_parent, and
+    # self.lut_child) will already be created on the correct device.
+    octree = Octree.init_like(self, device)
+
+    # Move all the other properties to the specified device
     octree.keys = list_to_device(self.keys)
     octree.children = list_to_device(self.children)
     octree.neighs = list_to_device(self.neighs)
     octree.features = list_to_device(self.features)
     octree.normals = list_to_device(self.normals)
     octree.points = list_to_device(self.points)
-    octree.nnum = self.nnum.clone()  # TODO: whether to move nnum to the self.device?
+
+    # The following are small tensors; keep them on the CPU to avoid frequent
+    # device switching, so just clone them.
+    octree.nnum = self.nnum.clone()
     octree.nnum_nempty = self.nnum_nempty.clone()
     octree.batch_nnum = self.batch_nnum.clone()
     octree.batch_nnum_nempty = self.batch_nnum_nempty.clone()
@@ -572,90 +647,124 @@ class Octree:

     return self.to('cpu')

+  def merge_octrees(self, octrees: List['Octree']):
+    r''' Merges a list of octrees into one batch.

-def merge_octrees(octrees: List['Octree']):
-  r''' Merges a list of octrees into one batch.
+    Args:
+      octrees (List[Octree]): A list of octrees to merge.

-  Args:
-    octrees (List[Octree]): A list of octrees to merge.
+    Returns:
+      Octree: The merged octree.
+    '''
+
+    # init and check
+    batch_size = len(octrees)
+    self.batch_size = batch_size
+    for i in range(1, batch_size):
+      condition = (octrees[i].depth == self.depth and
+                   octrees[i].full_depth == self.full_depth and
+                   octrees[i].device == self.device)
+      assert condition, 'The check of merge_octrees failed'
+
+    # node num
+    batch_nnum = torch.stack(
+        [octrees[i].nnum for i in range(batch_size)], dim=1)
+    batch_nnum_nempty = torch.stack(
+        [octrees[i].nnum_nempty for i in range(batch_size)], dim=1)
+    self.nnum = torch.sum(batch_nnum, dim=1)
+    self.nnum_nempty = torch.sum(batch_nnum_nempty, dim=1)
+    self.batch_nnum = batch_nnum
+    self.batch_nnum_nempty = batch_nnum_nempty
+    nnum_cum = cumsum(batch_nnum_nempty, dim=1, exclusive=True)
+
+    # merge octree properties
+    for d in range(self.depth + 1):
+      # key
+      keys = [None] * batch_size
+      for i in range(batch_size):
+        key = octrees[i].keys[d] & ((1 << 48) - 1)  # clear the highest bits
+        keys[i] = key | (i << 48)
+      self.keys[d] = torch.cat(keys, dim=0)
+
+      # children
+      children = [None] * batch_size
+      for i in range(batch_size):
+        # !! `clone` is used here to avoid modifying the original octrees
+        child = octrees[i].children[d].clone()
+        mask = child >= 0
+        child[mask] = child[mask] + nnum_cum[d, i]
+        children[i] = child
+      self.children[d] = torch.cat(children, dim=0)
+
+      # features
+      if octrees[0].features[d] is not None and d == self.depth:
+        features = [octrees[i].features[d] for i in range(batch_size)]
+        self.features[d] = torch.cat(features, dim=0)
+
+      # normals
+      if octrees[0].normals[d] is not None and d == self.depth:
+        normals = [octrees[i].normals[d] for i in range(batch_size)]
+        self.normals[d] = torch.cat(normals, dim=0)
+
+      # points
+      if octrees[0].points[d] is not None and d == self.depth:
+        points = [octrees[i].points[d] for i in range(batch_size)]
+        self.points[d] = torch.cat(points, dim=0)
+
+    return self
+
+  @classmethod
+  def init_like(cls, octree: 'Octree', device: Union[torch.device, str, None] = None):
+    r''' Initializes the octree like another octree.
+
+    Args:
+      octree (Octree): The reference octree.
+      device (torch.device or str): The device to use for computation.
+    '''
+
+    device = device if device is not None else octree.device
+    return cls(depth=octree.depth, full_depth=octree.full_depth,
+               batch_size=octree.batch_size, device=device)
+
+  @classmethod
+  def init_octree(cls, depth: int, full_depth: int = 2, batch_size: int = 1,
+                  device: Union[torch.device, str] = 'cpu'):
+    r'''
+    Initializes an octree to :attr:`full_depth`.

-  Returns:
-    Octree: The merged octree.
+    Args:
+      depth (int): The depth of the octree.
+      full_depth (int): The octree layers with a depth smaller than
+        :attr:`full_depth` are forced to be full.
+      batch_size (int, optional): The batch size.
+      device (torch.device or str): The device to use for computation.
+
+    Returns:
+      Octree: The initialized Octree object.
+    '''
+
+    octree = cls(depth, full_depth, batch_size, device)
+    for d in range(full_depth + 1):
+      octree.octree_grow_full(depth=d)
+    return octree
+
+
+def merge_octrees(octrees: List['Octree']):
+  r''' A wrapper of :meth:`Octree.merge_octrees`.
+
+  .. deprecated:: 2.2.7
+    Use :meth:`Octree.merge_octrees` instead.
   '''

-  # init and check
-  octree = Octree(depth=octrees[0].depth, full_depth=octrees[0].full_depth,
-                  batch_size=len(octrees), device=octrees[0].device)
-  for i in range(1, octree.batch_size):
-    condition = (octrees[i].depth == octree.depth and
-                 octrees[i].full_depth == octree.full_depth and
-                 octrees[i].device == octree.device)
-    assert condition, 'The check of merge_octrees failed'
-
-  # node num
-  batch_nnum = torch.stack(
-      [octrees[i].nnum for i in range(octree.batch_size)], dim=1)
-  batch_nnum_nempty = torch.stack(
-      [octrees[i].nnum_nempty for i in range(octree.batch_size)], dim=1)
-  octree.nnum = torch.sum(batch_nnum, dim=1)
-  octree.nnum_nempty = torch.sum(batch_nnum_nempty, dim=1)
-  octree.batch_nnum = batch_nnum
-  octree.batch_nnum_nempty = batch_nnum_nempty
-  nnum_cum = cumsum(batch_nnum_nempty, dim=1, exclusive=True)
-
-  # merge octre properties
-  for d in range(octree.depth+1):
-    # key
-    keys = [None] * octree.batch_size
-    for i in range(octree.batch_size):
-      key = octrees[i].keys[d] & ((1 << 48) - 1)  # clear the highest bits
-      keys[i] = key | (i << 48)
-    octree.keys[d] = torch.cat(keys, dim=0)
-
-    # children
-    children = [None] * octree.batch_size
-    for i in range(octree.batch_size):
-      child = octrees[i].children[d].clone()  # !! `clone` is used here to avoid
-      mask = child >= 0                       # !! modifying the original octrees
-      child[mask] = child[mask] + nnum_cum[d, i]
-      children[i] = child
-    octree.children[d] = torch.cat(children, dim=0)
-
-    # features
-    if octrees[0].features[d] is not None and d == octree.depth:
-      features = [octrees[i].features[d] for i in range(octree.batch_size)]
-      octree.features[d] = torch.cat(features, dim=0)
-
-    # normals
-    if octrees[0].normals[d] is not None and d == octree.depth:
-      normals = [octrees[i].normals[d] for i in range(octree.batch_size)]
-      octree.normals[d] = torch.cat(normals, dim=0)
-
-    # points
-    if octrees[0].points[d] is not None and d == octree.depth:
-      points = [octrees[i].points[d] for i in range(octree.batch_size)]
-      octree.points[d] = torch.cat(points, dim=0)
-
-  return octree
+  return Octree.init_like(octrees[0]).merge_octrees(octrees)


 def init_octree(depth: int, full_depth: int = 2, batch_size: int = 1,
                 device: Union[torch.device, str] = 'cpu'):
-  r'''
-  Initializes an octree to :attr:`full_depth`.
-
-  Args:
-    depth (int): The depth of the octree.
-    full_depth (int): The octree layers with a depth small than
-      :attr:`full_depth` are forced to be full.
-    batch_size (int, optional): The batch size.
-    device (torch.device or str): The device to use for computation.
+  r''' A wrapper of :meth:`Octree.init_octree`.

-  Returns:
-    Octree: The initialized Octree object.
+  .. deprecated:: 2.2.7
+    Use :meth:`Octree.init_octree` instead.
   '''

-  octree = Octree(depth, full_depth, batch_size, device)
-  for d in range(full_depth+1):
-    octree.octree_grow_full(depth=d)
-  return octree
+  return Octree.init_octree(depth, full_depth, batch_size, device)
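The module-level helpers are now thin wrappers and the classmethods carry the logic. A sketch of the intended calling pattern, assuming only the API visible in this diff (the per-sample point clouds are illustrative):

```python
import torch
from ocnn.octree import Octree, Points

# one octree per sample (random point clouds, for illustration only)
octrees = []
for _ in range(4):
  pts = Points(torch.rand(500, 3) * 2 - 1)
  oc = Octree(depth=6, full_depth=2)
  oc.build_octree(pts)
  octrees.append(oc)

# classmethod-based API: initialize a container octree, then merge in place
batch = Octree.init_like(octrees[0])
batch.merge_octrees(octrees)   # batch.batch_size becomes 4

# an octree pre-grown to full_depth, e.g. as a decoder seed
seed = Octree.init_octree(depth=6, full_depth=2, batch_size=4)
```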
ocnn/octree/points.py CHANGED
@@ -8,6 +8,7 @@
 import torch
 import numpy as np
 from typing import Optional, Union, List
+from ocnn.utils import cumsum


 class Points:
@@ -66,6 +67,7 @@ class Points:
       if self.batch_id.dim() == 1:
         self.batch_id = self.batch_id.unsqueeze(1)
       assert self.batch_id.size(1) == 1
+      assert self.batch_size == self.batch_id.max().item() + 1

   @property
   def npt(self):
@@ -165,26 +167,23 @@ class Points:
     '''

     mask = self.inbox_mask(min + esp, max - esp)
-    tmp = self.__getitem__(mask)
-    self.__dict__.update(tmp.__dict__)
+    self.copy_from(self[mask])
     return mask

-  def __getitem__(self, mask: torch.Tensor):
-    r''' Slices the point cloud according a given :attr:`mask`.
+  def __getitem__(self, idx):
+    r''' Slices the point cloud according to a given :attr:`idx`.
     '''

-    dummy_pts = torch.zeros(1, 3, device=self.device)
-    out = Points(dummy_pts, batch_size=self.batch_size)
-
-    out.points = self.points[mask]
+    out = self.init_points(self.device, self.batch_size)
+    out.points = self.points[idx]
     if self.normals is not None:
-      out.normals = self.normals[mask]
+      out.normals = self.normals[idx]
     if self.features is not None:
-      out.features = self.features[mask]
+      out.features = self.features[idx]
     if self.labels is not None:
-      out.labels = self.labels[mask]
+      out.labels = self.labels[idx]
     if self.batch_id is not None:
-      out.batch_id = self.batch_id[mask]
+      out.batch_id = self.batch_id[idx]
     return out

   def inbox_mask(self, bbmin: Union[float, torch.Tensor] = -1.0,
@@ -239,8 +238,7 @@ class Points:
       return self

     # Construct a new Points on the specified device
-    points = Points(torch.zeros(1, 3, device=device))
-    points.batch_size = self.batch_size
+    points = self.init_points(device, self.batch_size)
     points.batch_npt = self.batch_npt
     points.points = self.points.to(device, non_blocking=non_blocking)
     if self.normals is not None:
@@ -295,29 +293,92 @@ class Points:
     else:
       raise ValueError

+  def copy_from(self, points: 'Points'):
+    r''' Shallow copy from another Points.
+    '''
+
+    self.points = points.points
+    self.normals = points.normals
+    self.features = points.features
+    self.labels = points.labels
+    self.batch_id = points.batch_id
+    self.batch_size = points.batch_size
+    self.device = points.device
+    self.batch_npt = points.batch_npt
+
+  def merge_points(self, points: List['Points'], update_batch_info: bool = True):
+    r''' Merges a list of points into one batch.
+
+    Args:
+      points (List[Points]): A list of points to merge. The batch size of each
+        points in the list is assumed to be 1, and the :obj:`batch_size`,
+        :obj:`batch_id`, and :obj:`batch_npt` in the points are ignored.
+    '''
+
+    self.points = torch.cat([p.points for p in points], dim=0)
+    if points[0].normals is not None:
+      self.normals = torch.cat([p.normals for p in points], dim=0)
+    if points[0].features is not None:
+      self.features = torch.cat([p.features for p in points], dim=0)
+    if points[0].labels is not None:
+      self.labels = torch.cat([p.labels for p in points], dim=0)
+    self.device = points[0].device
+
+    if update_batch_info:
+      self.batch_size = len(points)
+      self.batch_npt = torch.Tensor([p.npt for p in points]).long()
+      self.batch_id = torch.cat([p.points.new_full((p.npt, 1), i)
+                                 for i, p in enumerate(points)], dim=0)
+    return self
+
+  def split_points(self):
+    r''' Splits the batched points into a list of Points.
+    '''
+
+    if self.batch_npt is None:
+      self.batch_npt = torch.bincount(
+          self.batch_id.squeeze(), minlength=self.batch_size)
+
+    outs = []
+    cs = cumsum(self.batch_npt, dim=0, exclusive=True)
+    for i in range(self.batch_size):
+      rng = range(cs[i], cs[i+1])
+      out = Points.init_points(self.device, batch_size=1)
+      out.points = self.points[rng]
+      if self.normals is not None:
+        out.normals = self.normals[rng]
+      if self.features is not None:
+        out.features = self.features[rng]
+      if self.labels is not None:
+        out.labels = self.labels[rng]
+      outs.append(out)
+    return outs
+
+  @classmethod
+  def init_points(cls, device: Union[torch.device, str, None] = None,
+                  batch_size: int = 1):
+    r''' Initializes a Points object with dummy data on a specified device.
+
+    Args:
+      device (torch.device or str or None): The device of the Points. If
+        :obj:`None`, the device is set to :obj:`cpu`.
+      batch_size (int): The batch size.
+    '''
+
+    points = torch.zeros(batch_size, 3, device=device)
+    batch_id = (torch.arange(batch_size, device=device).unsqueeze(1)
+                if batch_size > 1 else None)
+    return cls(points, batch_size=batch_size, batch_id=batch_id)
+

 def merge_points(points: List['Points'], update_batch_info: bool = True):
-  r''' Merges a list of points into one batch.
+  r''' A wrapper of :meth:`Points.merge_points`.

-  Args:
-    points (List[Octree]): A list of points to merge. The batch size of each
-      points in the list is assumed to be 1, and the :obj:`batch_size`,
-      :obj:`batch_id`, and :obj:`batch_npt` in the points are ignored.
+  .. deprecated:: 2.2.7
+    Use :meth:`Points.merge_points` instead.
   '''

-  out = Points(torch.zeros(1, 3))
-  out.points = torch.cat([p.points for p in points], dim=0)
-  if points[0].normals is not None:
-    out.normals = torch.cat([p.normals for p in points], dim=0)
-  if points[0].features is not None:
-    out.features = torch.cat([p.features for p in points], dim=0)
-  if points[0].labels is not None:
-    out.labels = torch.cat([p.labels for p in points], dim=0)
-  out.device = points[0].device
-
-  if update_batch_info:
-    out.batch_size = len(points)
-    out.batch_npt = torch.Tensor([p.npt for p in points]).long()
-    out.batch_id = torch.cat([p.points.new_full((p.npt, 1), i)
-                              for i, p in enumerate(points)], dim=0)
+  assert len(points) > 0, 'The input points list is empty.'
+  out = Points.init_points(points[0].device, batch_size=len(points))
+  out.merge_points(points, update_batch_info)
   return out
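Analogously to the octree side, batching of `Points` now lives on the class. A sketch of a round trip through `merge_points` and `split_points`, assuming the API shown above (random data for illustration):

```python
import torch
from ocnn.octree import Points

# two single-sample point clouds (illustrative random data)
pts_a = Points(torch.rand(100, 3) * 2 - 1)
pts_b = Points(torch.rand(150, 3) * 2 - 1)

# merge into one batch, then split back into per-sample Points
batch = Points.init_points(batch_size=2)
batch.merge_points([pts_a, pts_b])
parts = batch.split_points()
assert parts[0].npt == 100 and parts[1].npt == 150
```

The remaining hunks below come from the package METADATA (version bump, classifier removal, and README updates).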
@@ -1,13 +1,12 @@
 Metadata-Version: 2.4
 Name: ocnn
-Version: 2.2.7
+Version: 2.3.0
 Summary: Octree-based Sparse Convolutional Neural Networks
 Home-page: https://github.com/octree-nn/ocnn-pytorch
 Author: Peng-Shuai Wang
 Author-email: wangps@hotmail.com
 License: MIT
 Classifier: Programming Language :: Python :: 3
-Classifier: License :: OSI Approved :: MIT License
 Classifier: Operating System :: OS Independent
 Requires-Python: >=3.6
 Description-Content-Type: text/markdown
@@ -84,6 +83,13 @@ octrees to perform convolution operations. Of course, it also supports other 3D
 data formats, such as meshes and volumetric grids, which can be converted into
 octrees to leverage the library's capabilities.

+## Updates
+
+- **2026.02.02**: Release `v2.3.0`, incorporating Triton to accelerate
+  octree-based sparse convolution in the upcoming release. OctreeConv is even
+  **2.5 times faster than the latest spconv**!
+- **2025.12.18**: Release `v2.2.8`, improving neighbor search efficiency.
+

 ## Key benefits of ocnn-pytorch

@@ -93,10 +99,9 @@ octrees to leverage the library's capabilities.
   configure the compiling environment.

 - **Efficiency**. The ocnn-pytorch is very efficient compared with other sparse
-  convolution frameworks. It only takes 18 hours to train the network on
-  ScanNet for 600 epochs with 4 V100 GPUs. For reference, under the same
-  training settings, MinkowskiNet 0.4.3 takes 60 hours and MinkowskiNet 0.5.4
-  takes 30 hours.
+  convolution frameworks. It is **even 2.5 times faster than the latest spconv
+  implementation**! Check the benchmark [code](test/benchmark_conv.py) and
+  [results](test/benchmark/results.png) for details.

 ## Citation
