ocnn 2.2.3__py3-none-any.whl → 2.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ocnn/__init__.py CHANGED
@@ -12,7 +12,7 @@ from . import models
12
12
  from . import dataset
13
13
  from . import utils
14
14
 
15
- __version__ = '2.2.3'
15
+ __version__ = '2.2.5'
16
16
 
17
17
  __all__ = [
18
18
  'octree',
@@ -7,7 +7,6 @@
7
7
 
8
8
  import torch
9
9
  import torch.nn
10
- from typing import Optional
11
10
 
12
11
  import ocnn
13
12
  from ocnn.octree import Octree
@@ -34,8 +33,9 @@ class AutoEncoder(torch.nn.Module):
34
33
  self.full_depth = full_depth
35
34
  self.feature = feature
36
35
  self.resblk_num = 2
37
- self.code_channel = 64 # dim-of-code = code_channel * 2**(3*full_depth)
38
36
  self.channels = [512, 512, 256, 256, 128, 128, 32, 32, 16, 16]
37
+ # dim-of-code = code_channel * 2**(3*full_depth)
38
+ self.code_channel = self.channels[full_depth]
39
39
 
40
40
  # encoder
41
41
  self.conv1 = ocnn.modules.OctreeConvBnRelu(
ocnn/models/ounet.py CHANGED
@@ -18,17 +18,17 @@ class OUNet(AutoEncoder):
18
18
 
19
19
  def __init__(self, channel_in: int, channel_out: int, depth: int,
20
20
  full_depth: int = 2, feature: str = 'ND'):
21
- super().__init__(channel_in, channel_out, depth, full_depth, feature,
22
- code_channel=-1) # !set code_channe=-1
21
+ super().__init__(channel_in, channel_out, depth, full_depth, feature)
23
22
  self.proj = None # remove this module used in AutoEncoder
24
23
 
25
24
  def encoder(self, octree):
26
- r''' The encoder network for extracting heirarchy features.
25
+ r''' The encoder network for extracting heirarchy features.
27
26
  '''
28
27
 
29
28
  convs = dict()
30
29
  depth, full_depth = self.depth, self.full_depth
31
- data = self.get_input_feature(octree)
30
+ data = octree.get_input_feature(self.feature, nempty=False)
31
+ assert data.size(1) == self.channel_in
32
32
  convs[depth] = self.conv1(data, octree, depth)
33
33
  for i, d in enumerate(range(depth, full_depth-1, -1)):
34
34
  convs[d] = self.encoder_blks[i](convs[d], octree, d)
ocnn/models/resnet.py CHANGED
@@ -31,7 +31,7 @@ class ResNet(torch.nn.Module):
31
31
  channels[i], channels[i+1], resblock_num, nempty=nempty)
32
32
  for i in range(stages-1)])
33
33
  self.pools = torch.nn.ModuleList([ocnn.nn.OctreeMaxPool(
34
- nempty) for i in range(stages-1)])
34
+ nempty) for _ in range(stages-1)])
35
35
  self.global_pool = ocnn.nn.OctreeGlobalPool(nempty)
36
36
  # self.header = torch.nn.Linear(channels[-1], out_channels, bias=True)
37
37
  self.header = torch.nn.Sequential(
ocnn/models/segnet.py CHANGED
@@ -29,7 +29,7 @@ class SegNet(torch.nn.Module):
29
29
  [ocnn.modules.OctreeConvBnRelu(channels[i], channels[i+1], nempty=nempty)
30
30
  for i in range(stages)])
31
31
  self.pools = torch.nn.ModuleList(
32
- [ocnn.nn.OctreeMaxPool(nempty, return_indices) for i in range(stages)])
32
+ [ocnn.nn.OctreeMaxPool(nempty, return_indices) for _ in range(stages)])
33
33
 
34
34
  self.bottleneck = ocnn.modules.OctreeConvBnRelu(channels[-1], channels[-1])
35
35
 
@@ -38,7 +38,7 @@ class SegNet(torch.nn.Module):
38
38
  [ocnn.modules.OctreeConvBnRelu(channels[i], channels[i+1], nempty=nempty)
39
39
  for i in range(0, stages)])
40
40
  self.unpools = torch.nn.ModuleList(
41
- [ocnn.nn.OctreeMaxUnpool(nempty) for i in range(stages)])
41
+ [ocnn.nn.OctreeMaxUnpool(nempty) for _ in range(stages)])
42
42
 
43
43
  self.octree_interp = ocnn.nn.OctreeInterp(interp, nempty)
44
44
  self.header = torch.nn.Sequential(
ocnn/modules/__init__.py CHANGED
@@ -7,14 +7,20 @@
7
7
 
8
8
  from .modules import (InputFeature,
9
9
  OctreeConvBn, OctreeConvBnRelu, OctreeDeconvBnRelu,
10
- Conv1x1, Conv1x1Bn, Conv1x1BnRelu, FcBnRelu,)
11
- from .resblocks import OctreeResBlock, OctreeResBlock2, OctreeResBlocks
10
+ Conv1x1, Conv1x1Bn, Conv1x1BnRelu, FcBnRelu,
11
+ OctreeConvGn, OctreeConvGnRelu, OctreeDeconvGnRelu,
12
+ Conv1x1, Conv1x1Gn, Conv1x1GnRelu)
13
+ from .resblocks import (OctreeResBlock, OctreeResBlock2, OctreeResBlockGn,
14
+ OctreeResBlocks,)
12
15
 
13
16
  __all__ = [
14
17
  'InputFeature',
15
18
  'OctreeConvBn', 'OctreeConvBnRelu', 'OctreeDeconvBnRelu',
16
19
  'Conv1x1', 'Conv1x1Bn', 'Conv1x1BnRelu', 'FcBnRelu',
17
- 'OctreeResBlock', 'OctreeResBlock2', 'OctreeResBlocks',
20
+ 'OctreeConvGn', 'OctreeConvGnRelu', 'OctreeDeconvGnRelu',
21
+ 'Conv1x1', 'Conv1x1Gn', 'Conv1x1GnRelu',
22
+ 'OctreeResBlock', 'OctreeResBlock2', 'OctreeResBlockGn',
23
+ 'OctreeResBlocks',
18
24
  ]
19
25
 
20
26
  classes = __all__
ocnn/modules/modules.py CHANGED
@@ -9,7 +9,7 @@ import torch
9
9
  import torch.utils.checkpoint
10
10
  from typing import List
11
11
 
12
- from ocnn.nn import OctreeConv, OctreeDeconv
12
+ from ocnn.nn import OctreeConv, OctreeDeconv, OctreeGroupNorm
13
13
  from ocnn.octree import Octree
14
14
 
15
15
 
@@ -40,7 +40,7 @@ class OctreeConvBn(torch.nn.Module):
40
40
  super().__init__()
41
41
  self.conv = OctreeConv(
42
42
  in_channels, out_channels, kernel_size, stride, nempty)
43
- self.bn = torch.nn.BatchNorm1d(out_channels) #, bn_eps, bn_momentum)
43
+ self.bn = torch.nn.BatchNorm1d(out_channels)
44
44
 
45
45
  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
46
46
  r''''''
@@ -62,7 +62,7 @@ class OctreeConvBnRelu(torch.nn.Module):
62
62
  super().__init__()
63
63
  self.conv = OctreeConv(
64
64
  in_channels, out_channels, kernel_size, stride, nempty)
65
- self.bn = torch.nn.BatchNorm1d(out_channels) #, bn_eps, bn_momentum)
65
+ self.bn = torch.nn.BatchNorm1d(out_channels)
66
66
  self.relu = torch.nn.ReLU(inplace=True)
67
67
 
68
68
  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
@@ -86,7 +86,7 @@ class OctreeDeconvBnRelu(torch.nn.Module):
86
86
  super().__init__()
87
87
  self.deconv = OctreeDeconv(
88
88
  in_channels, out_channels, kernel_size, stride, nempty)
89
- self.bn = torch.nn.BatchNorm1d(out_channels) #, bn_eps, bn_momentum)
89
+ self.bn = torch.nn.BatchNorm1d(out_channels)
90
90
  self.relu = torch.nn.ReLU(inplace=True)
91
91
 
92
92
  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
@@ -123,7 +123,7 @@ class Conv1x1Bn(torch.nn.Module):
123
123
  def __init__(self, in_channels: int, out_channels: int):
124
124
  super().__init__()
125
125
  self.conv = Conv1x1(in_channels, out_channels, use_bias=False)
126
- self.bn = torch.nn.BatchNorm1d(out_channels) #, bn_eps, bn_momentum)
126
+ self.bn = torch.nn.BatchNorm1d(out_channels)
127
127
 
128
128
  def forward(self, data: torch.Tensor):
129
129
  r''''''
@@ -140,7 +140,7 @@ class Conv1x1BnRelu(torch.nn.Module):
140
140
  def __init__(self, in_channels: int, out_channels: int):
141
141
  super().__init__()
142
142
  self.conv = Conv1x1(in_channels, out_channels, use_bias=False)
143
- self.bn = torch.nn.BatchNorm1d(out_channels) #, bn_eps, bn_momentum)
143
+ self.bn = torch.nn.BatchNorm1d(out_channels)
144
144
  self.relu = torch.nn.ReLU(inplace=True)
145
145
 
146
146
  def forward(self, data: torch.Tensor):
@@ -160,7 +160,7 @@ class FcBnRelu(torch.nn.Module):
160
160
  super().__init__()
161
161
  self.flatten = torch.nn.Flatten(start_dim=1)
162
162
  self.fc = torch.nn.Linear(in_channels, out_channels, bias=False)
163
- self.bn = torch.nn.BatchNorm1d(out_channels) #, bn_eps, bn_momentum)
163
+ self.bn = torch.nn.BatchNorm1d(out_channels)
164
164
  self.relu = torch.nn.ReLU(inplace=True)
165
165
 
166
166
  def forward(self, data):
@@ -173,6 +173,116 @@ class FcBnRelu(torch.nn.Module):
173
173
  return out
174
174
 
175
175
 
176
+ class OctreeConvGn(torch.nn.Module):
177
+ r''' A sequence of :class:`OctreeConv` and :obj:`OctreeGroupNorm`.
178
+
179
+ Please refer to :class:`ocnn.nn.OctreeConv` for details on the parameters.
180
+ '''
181
+
182
+ def __init__(self, in_channels: int, out_channels: int, group: int,
183
+ kernel_size: List[int] = [3], stride: int = 1,
184
+ nempty: bool = False):
185
+ super().__init__()
186
+ self.conv = OctreeConv(
187
+ in_channels, out_channels, kernel_size, stride, nempty)
188
+ self.gn = OctreeGroupNorm(out_channels, group=group, nempty=nempty)
189
+
190
+ def forward(self, data: torch.Tensor, octree: Octree, depth: int):
191
+ r''''''
192
+
193
+ out = self.conv(data, octree, depth)
194
+ out = self.gn(out, octree, depth)
195
+ return out
196
+
197
+
198
+ class OctreeConvGnRelu(torch.nn.Module):
199
+ r''' A sequence of :class:`OctreeConv`, :obj:`OctreeGroupNorm`, and :obj:`Relu`.
200
+
201
+ Please refer to :class:`ocnn.nn.OctreeConv` for details on the parameters.
202
+ '''
203
+
204
+ def __init__(self, in_channels: int, out_channels: int, group: int,
205
+ kernel_size: List[int] = [3], stride: int = 1,
206
+ nempty: bool = False):
207
+ super().__init__()
208
+ self.stride = stride
209
+ self.conv = OctreeConv(
210
+ in_channels, out_channels, kernel_size, stride, nempty)
211
+ self.gn = OctreeGroupNorm(out_channels, group=group, nempty=nempty)
212
+ self.relu = torch.nn.ReLU(inplace=True)
213
+
214
+ def forward(self, data: torch.Tensor, octree: Octree, depth: int):
215
+ r''''''
216
+
217
+ out = self.conv(data, octree, depth)
218
+ out = self.gn(out, octree, depth if self.stride == 1 else depth - 1)
219
+ out = self.relu(out)
220
+ return out
221
+
222
+
223
+ class OctreeDeconvGnRelu(torch.nn.Module):
224
+ r''' A sequence of :class:`OctreeDeconv`, :obj:`OctreeGroupNorm`, and :obj:`Relu`.
225
+
226
+ Please refer to :class:`ocnn.nn.OctreeConv` for details on the parameters.
227
+ '''
228
+
229
+ def __init__(self, in_channels: int, out_channels: int, group: int,
230
+ kernel_size: List[int] = [3], stride: int = 1,
231
+ nempty: bool = False):
232
+ super().__init__()
233
+ self.stride = stride
234
+ self.deconv = OctreeDeconv(
235
+ in_channels, out_channels, kernel_size, stride, nempty)
236
+ self.gn = OctreeGroupNorm(out_channels, group=group, nempty=nempty)
237
+ self.relu = torch.nn.ReLU(inplace=True)
238
+
239
+ def forward(self, data: torch.Tensor, octree: Octree, depth: int):
240
+ r''''''
241
+
242
+ out = self.deconv(data, octree, depth)
243
+ out = self.gn(out, octree, depth if self.stride == 1 else depth + 1)
244
+ out = self.relu(out)
245
+ return out
246
+
247
+
248
+ class Conv1x1Gn(torch.nn.Module):
249
+ r''' A sequence of :class:`Conv1x1`, :class:`OctreeGroupNorm`.
250
+ '''
251
+
252
+ def __init__(self, in_channels: int, out_channels: int, group: int,
253
+ nempty: bool = False):
254
+ super().__init__()
255
+ self.conv = Conv1x1(in_channels, out_channels, use_bias=False)
256
+ self.gn = OctreeGroupNorm(out_channels, group=group, nempty=nempty)
257
+
258
+ def forward(self, data: torch.Tensor, octree: Octree, depth: int):
259
+ r''''''
260
+
261
+ out = self.conv(data)
262
+ out = self.gn(out, octree, depth)
263
+ return out
264
+
265
+
266
+ class Conv1x1GnRelu(torch.nn.Module):
267
+ r''' A sequence of :class:`Conv1x1`, :class:`OctreeGroupNorm` and :class:`Relu`.
268
+ '''
269
+
270
+ def __init__(self, in_channels: int, out_channels: int, group: int,
271
+ nempty: bool = False):
272
+ super().__init__()
273
+ self.conv = Conv1x1(in_channels, out_channels, use_bias=False)
274
+ self.gn = OctreeGroupNorm(out_channels, group=group, nempty=nempty)
275
+ self.relu = torch.nn.ReLU(inplace=True)
276
+
277
+ def forward(self, data: torch.Tensor, octree: Octree, depth: int):
278
+ r''''''
279
+
280
+ out = self.conv(data)
281
+ out = self.gn(out, octree, depth)
282
+ out = self.relu(out)
283
+ return out
284
+
285
+
176
286
  class InputFeature(torch.nn.Module):
177
287
  r''' Returns the initial input feature stored in octree.
178
288
 
ocnn/modules/resblocks.py CHANGED
@@ -10,7 +10,9 @@ import torch.utils.checkpoint
10
10
 
11
11
  from ocnn.octree import Octree
12
12
  from ocnn.nn import OctreeMaxPool
13
- from ocnn.modules import Conv1x1BnRelu, OctreeConvBnRelu, Conv1x1Bn, OctreeConvBn
13
+ from ocnn.modules import (Conv1x1BnRelu, OctreeConvBnRelu, Conv1x1Bn,
14
+ OctreeConvBn, OctreeConvGnRelu, Conv1x1Gn,
15
+ OctreeConvGn,)
14
16
 
15
17
 
16
18
  class OctreeResBlock(torch.nn.Module):
@@ -97,6 +99,38 @@ class OctreeResBlock2(torch.nn.Module):
97
99
  return out
98
100
 
99
101
 
102
+ class OctreeResBlockGn(torch.nn.Module):
103
+
104
+ def __init__(self, in_channels: int, out_channels: int, stride: int = 1,
105
+ bottleneck: int = 4, nempty: bool = False, group: int = 32):
106
+ super().__init__()
107
+ self.in_channels = in_channels
108
+ self.out_channels = out_channels
109
+ self.stride = stride
110
+ channelb = int(out_channels / bottleneck)
111
+
112
+ if self.stride == 2:
113
+ self.maxpool = OctreeMaxPool(self.depth)
114
+ self.conv3x3a = OctreeConvGnRelu(in_channels, channelb, group, nempty=nempty)
115
+ self.conv3x3b = OctreeConvGn(channelb, out_channels, group, nempty=nempty)
116
+ if self.in_channels != self.out_channels:
117
+ self.conv1x1 = Conv1x1Gn(in_channels, out_channels, group)
118
+ self.relu = torch.nn.ReLU(inplace=True)
119
+
120
+ def forward(self, data: torch.Tensor, octree: Octree, depth: int):
121
+ r''''''
122
+
123
+ if self.stride == 2:
124
+ data = self.maxpool(data, octree, depth)
125
+ depth = depth - 1
126
+ conv1 = self.conv3x3a(data, octree, depth)
127
+ conv2 = self.conv3x3b(conv1, octree, depth)
128
+ if self.in_channels != self.out_channels:
129
+ data = self.conv1x1(data, octree, depth)
130
+ out = self.relu(conv2 + data)
131
+ return out
132
+
133
+
100
134
  class OctreeResBlocks(torch.nn.Module):
101
135
  r''' A sequence of :attr:`resblk_num` ResNet blocks.
102
136
  '''
@@ -108,9 +142,9 @@ class OctreeResBlocks(torch.nn.Module):
108
142
  self.use_checkpoint = use_checkpoint
109
143
  channels = [in_channels] + [out_channels] * resblk_num
110
144
 
111
- self.resblks = torch.nn.ModuleList(
112
- [resblk(channels[i], channels[i+1], 1, bottleneck, nempty)
113
- for i in range(self.resblk_num)])
145
+ self.resblks = torch.nn.ModuleList([resblk(
146
+ channels[i], channels[i+1], 1, bottleneck, nempty)
147
+ for i in range(self.resblk_num)])
114
148
 
115
149
  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
116
150
  r''''''
@@ -118,7 +152,7 @@ class OctreeResBlocks(torch.nn.Module):
118
152
  for i in range(self.resblk_num):
119
153
  if self.use_checkpoint:
120
154
  data = torch.utils.checkpoint.checkpoint(
121
- self.resblks[i], data, octree, depth)
155
+ self.resblks[i], data, octree, depth, use_reentrant=False)
122
156
  else:
123
157
  data = self.resblks[i](data, octree, depth)
124
158
  return data
ocnn/nn/__init__.py CHANGED
@@ -17,7 +17,8 @@ from .octree_pool import (octree_max_pool, OctreeMaxPool,
17
17
  from .octree_conv import OctreeConv, OctreeDeconv
18
18
  from .octree_gconv import OctreeGroupConv
19
19
  from .octree_dwconv import OctreeDWConv
20
- from .octree_norm import OctreeBatchNorm, OctreeGroupNorm, OctreeInstanceNorm
20
+ from .octree_norm import (OctreeBatchNorm, OctreeGroupNorm,
21
+ OctreeInstanceNorm, OctreeNorm)
21
22
  from .octree_drop import OctreeDropPath
22
23
  from .octree_align import search_value, octree_align
23
24
 
@@ -35,7 +36,7 @@ __all__ = [
35
36
  'OctreeConv', 'OctreeDeconv',
36
37
  'OctreeGroupConv', 'OctreeDWConv',
37
38
  'OctreeInterp', 'OctreeUpsample',
38
- 'OctreeInstanceNorm', 'OctreeBatchNorm', 'OctreeGroupNorm',
39
+ 'OctreeInstanceNorm', 'OctreeBatchNorm', 'OctreeGroupNorm', 'OctreeNorm',
39
40
  'OctreeDropPath',
40
41
  'search_value', 'octree_align',
41
42
  ]
ocnn/nn/octree_norm.py CHANGED
@@ -7,6 +7,7 @@
7
7
 
8
8
  import torch
9
9
  import torch.nn
10
+ from typing import Optional
10
11
 
11
12
  from ocnn.octree import Octree
12
13
  from ocnn.utils import scatter_add
@@ -19,15 +20,19 @@ class OctreeGroupNorm(torch.nn.Module):
19
20
  r''' An group normalization layer for the octree.
20
21
  '''
21
22
 
22
- def __init__(self, in_channels: int, group: int, nempty: bool = False):
23
+ def __init__(self, in_channels: int, group: int, nempty: bool = False,
24
+ min_group_channels: int = 4):
23
25
  super().__init__()
24
26
  self.eps = 1e-5
25
27
  self.nempty = nempty
26
28
  self.group = group
27
29
  self.in_channels = in_channels
30
+ self.min_group_channels = min_group_channels
31
+ if self.min_group_channels * self.group > in_channels:
32
+ self.group = in_channels // self.min_group_channels
28
33
 
29
- assert in_channels % group == 0
30
- self.channels_per_group = in_channels // group
34
+ assert in_channels % self.group == 0
35
+ self.channels_per_group = in_channels // self.group
31
36
 
32
37
  self.weights = torch.nn.Parameter(torch.Tensor(1, in_channels))
33
38
  self.bias = torch.nn.Parameter(torch.Tensor(1, in_channels))
@@ -71,8 +76,8 @@ class OctreeGroupNorm(torch.nn.Module):
71
76
  return tensor
72
77
 
73
78
  def extra_repr(self) -> str:
74
- return ('in_channels={}, group={}, nempty={}').format(
75
- self.in_channels, self.group, self.nempty) # noqa
79
+ return ('in_channels={}, group={}, nempty={}, min_group_channels={}').format(
80
+ self.in_channels, self.group, self.nempty, self.min_group_channels)
76
81
 
77
82
 
78
83
  class OctreeInstanceNorm(OctreeGroupNorm):
@@ -80,7 +85,42 @@ class OctreeInstanceNorm(OctreeGroupNorm):
80
85
  '''
81
86
 
82
87
  def __init__(self, in_channels: int, nempty: bool = False):
83
- super().__init__(in_channels=in_channels, group=in_channels, nempty=nempty)
88
+ super().__init__(in_channels=in_channels, group=in_channels, nempty=nempty,
89
+ min_group_channels=1) # NOTE: group=in_channels
84
90
 
85
91
  def extra_repr(self) -> str:
86
92
  return ('in_channels={}, nempty={}').format(self.in_channels, self.nempty)
93
+
94
+
95
+ class OctreeNorm(torch.nn.Module):
96
+ r''' A normalization layer for the octree. It encapsulates octree-based batch,
97
+ group and instance normalization.
98
+ '''
99
+
100
+ def __init__(self, in_channels: int, norm_type: str = 'batch_norm',
101
+ group: int = 32, min_group_channels: int = 4):
102
+ super().__init__()
103
+ self.in_channels = in_channels
104
+ self.norm_type = norm_type
105
+ self.group = group
106
+ self.min_group_channels = min_group_channels
107
+
108
+ if self.norm_type == 'batch_norm':
109
+ self.norm = torch.nn.BatchNorm1d(in_channels)
110
+ elif self.norm_type == 'group_norm':
111
+ self.norm = OctreeGroupNorm(in_channels, group, min_group_channels)
112
+ elif self.norm_type == 'instance_norm':
113
+ self.norm = OctreeInstanceNorm(in_channels)
114
+ else:
115
+ raise ValueError
116
+
117
+ def forward(self, x: torch.Tensor, octree: Optional[Octree] = None,
118
+ depth: Optional[int] = None):
119
+ if self.norm_type == 'batch_norm':
120
+ output = self.norm(x)
121
+ elif (self.norm_type == 'group_norm' or
122
+ self.norm_type == 'instance_norm'):
123
+ output = self.norm(x, octree, depth)
124
+ else:
125
+ raise ValueError
126
+ return output
ocnn/octree/octree.py CHANGED
@@ -35,7 +35,9 @@ class Octree:
35
35
  and :obj:`points`, contain only non-empty nodes.
36
36
 
37
37
  .. note::
38
- The point cloud must be in range :obj:`[-1, 1]`.
38
+ The point cloud must be strictly in range :obj:`[-1, 1]`. A good practice
39
+ is to normalize it into :obj:`[-0.99, 0.99]` or :obj:`[-0.9, 0.9]` to retain
40
+ some margin.
39
41
  '''
40
42
 
41
43
  def __init__(self, depth: int, full_depth: int = 2, batch_size: int = 1,
@@ -63,14 +65,14 @@ class Octree:
63
65
 
64
66
  # octree node numbers in each octree layers.
65
67
  # TODO: decide whether to settle them to 'gpu' or not?
66
- self.nnum = torch.zeros(num, dtype=torch.int32)
67
- self.nnum_nempty = torch.zeros(num, dtype=torch.int32)
68
+ self.nnum = torch.zeros(num, dtype=torch.long)
69
+ self.nnum_nempty = torch.zeros(num, dtype=torch.long)
68
70
 
69
71
  # the following properties are valid after `merge_octrees`.
70
72
  # TODO: make them valid after `octree_grow`, `octree_split` and `build_octree`
71
73
  batch_size = self.batch_size
72
- self.batch_nnum = torch.zeros(num, batch_size, dtype=torch.int32)
73
- self.batch_nnum_nempty = torch.zeros(num, batch_size, dtype=torch.int32)
74
+ self.batch_nnum = torch.zeros(num, batch_size, dtype=torch.long)
75
+ self.batch_nnum_nempty = torch.zeros(num, batch_size, dtype=torch.long)
74
76
 
75
77
  # construct the look up tables for neighborhood searching
76
78
  device = self.device
@@ -176,7 +178,7 @@ class Octree:
176
178
  for d in range(self.depth, self.full_depth, -1):
177
179
  # compute parent key, i.e. keys of layer (d -1)
178
180
  pkey = node_key >> 3
179
- pkey, pidx, pcounts = torch.unique_consecutive(
181
+ pkey, pidx, _ = torch.unique_consecutive(
180
182
  pkey, return_inverse=True, return_counts=True)
181
183
 
182
184
  # augmented key
@@ -287,10 +289,27 @@ class Octree:
287
289
  update_neigh (bool): If True, construct the neighborhood indices.
288
290
  '''
289
291
 
292
+ # increase the octree depth if required
293
+ if depth > self.depth:
294
+ assert depth == self.depth + 1
295
+ self.depth = depth
296
+ self.keys.append(None)
297
+ self.children.append(None)
298
+ self.neighs.append(None)
299
+ self.features.append(None)
300
+ self.normals.append(None)
301
+ self.points.append(None)
302
+ zero = torch.zeros(1, dtype=torch.long)
303
+ self.nnum = torch.cat([self.nnum, zero])
304
+ self.nnum_nempty = torch.cat([self.nnum_nempty, zero])
305
+ zero = zero.view(1, 1)
306
+ self.batch_nnum = torch.cat([self.batch_nnum, zero], dim=0)
307
+ self.batch_nnum_nempty = torch.cat([self.batch_nnum_nempty, zero], dim=0)
308
+
290
309
  # node number
291
310
  nnum = self.nnum_nempty[depth-1] * 8
292
311
  self.nnum[depth] = nnum
293
- self.nnum_nempty[depth] = nnum
312
+ self.nnum_nempty[depth] = nnum # initialize self.nnum_nempty
294
313
 
295
314
  # update keys
296
315
  key = self.key(depth-1, nempty=True)
@@ -326,7 +345,7 @@ class Octree:
326
345
  xyz = xyz.view(-1, 3) # (N*27, 3)
327
346
  neigh = xyz2key(xyz[:, 0], xyz[:, 1], xyz[:, 2], depth=depth)
328
347
 
329
- bs = torch.arange(self.batch_size, dtype=torch.int32, device=device)
348
+ bs = torch.arange(self.batch_size, dtype=torch.long, device=device)
330
349
  neigh = neigh + bs.unsqueeze(1) * nnum # (N*27,) + (B, 1) -> (B, N*27)
331
350
 
332
351
  bound = 1 << depth
@@ -383,9 +402,10 @@ class Octree:
383
402
  # I choose `torch.bucketize` here because it has fewer dimension checks,
384
403
  # resulting in slightly better performance according to the docs of
385
404
  # pytorch-1.9.1, since `key` is always 1-D sorted sequence.
405
+ # https://pytorch.org/docs/1.9.1/generated/torch.searchsorted.html
386
406
  idx = torch.bucketize(query, key)
387
407
 
388
- valid = idx < key.shape[0] # valid if NOT out-of-bound
408
+ valid = idx < key.shape[0] # valid if in-bound
389
409
  found = key[idx[valid]] == query[valid]
390
410
  valid[valid.clone()] = found # valid if found
391
411
  idx[valid.logical_not()] = -1 # set to -1 if invalid
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ocnn
3
- Version: 2.2.3
3
+ Version: 2.2.5
4
4
  Summary: Octree-based Sparse Convolutional Neural Networks
5
5
  Home-page: https://github.com/octree-nn/ocnn-pytorch
6
6
  Author: Peng-Shuai Wang
@@ -23,6 +23,7 @@ Requires-Dist: packaging
23
23
 
24
24
  [![Documentation Status](https://readthedocs.org/projects/ocnn-pytorch/badge/?version=latest)](https://ocnn-pytorch.readthedocs.io/en/latest/?badge=latest)
25
25
  [![Downloads](https://static.pepy.tech/badge/ocnn)](https://pepy.tech/project/ocnn)
26
+ [![Downloads](https://static.pepy.tech/badge/ocnn/month)](https://pepy.tech/project/ocnn)
26
27
  [![PyPI](https://img.shields.io/pypi/v/ocnn)](https://pypi.org/project/ocnn/)
27
28
 
28
29
  This repository contains the **pure PyTorch**-based implementation of
@@ -1,19 +1,19 @@
1
- ocnn/__init__.py,sha256=kCGigMn_30MohB49Hwy0ZBXf5HTUA_i-QOjckrM48Nc,582
1
+ ocnn/__init__.py,sha256=z0LmCb7hdmNZekWucdzmEvhYse0eJ5L3jGScTreJ7ok,582
2
2
  ocnn/dataset.py,sha256=wvclvjlZs9qTeMXWLaO32K5d1VVY9XHSNuVVJEpVeeo,5266
3
3
  ocnn/utils.py,sha256=XhykveOjHoQd94gjJ5-opzXs-9MOCAzZ34ArZ8mG4sE,6726
4
4
  ocnn/models/__init__.py,sha256=F9PJRhOPHc1OrwkqcfywEBW0J6jmVW7-IHgWjGpY15U,724
5
- ocnn/models/autoencoder.py,sha256=TjOet3dbZLexz-PvZdeV0mbIuGpUAeSH6KWU_z_s-d8,5906
5
+ ocnn/models/autoencoder.py,sha256=nkKMtSPPdKhQXRxFaRNZsPjRfWuQet5Gz9FQlyLUlEQ,5904
6
6
  ocnn/models/hrnet.py,sha256=9W2fi7Fuw0JXDBiZtEoUW2K7ghtpWUm_BWd-mKoHLY0,6684
7
7
  ocnn/models/image2shape.py,sha256=5djcOHJh2SQCwd5XdLPeL5vlQDNWnRU3tJY3ojRI8aQ,4589
8
8
  ocnn/models/lenet.py,sha256=ujVBxnn8AKiIKqB4WHxuK728oJe4TZdqZUTTNy8y3zE,1754
9
- ocnn/models/ounet.py,sha256=sGGVHqF1nKw798vyththxp7-P3wHlsDUO39hbQpUVTw,3303
10
- ocnn/models/resnet.py,sha256=9gZKbhFituqPJHCm-bp1xTtHKZ6wu77s31a0q7RtXoQ,2029
11
- ocnn/models/segnet.py,sha256=VfZf8gBMSPgO5m8Agsfccw9snTrI4LGh0SqxhLZ1F8s,2575
9
+ ocnn/models/ounet.py,sha256=Z9bJqt_C8uHNw5IzD7q2vR1CH5bMkOrtI0UWSyp74TU,3306
10
+ ocnn/models/resnet.py,sha256=_bRaLBYK5yVqklYEvfPjcndWvMpOKaygnSl1pyjT-4w,2029
11
+ ocnn/models/segnet.py,sha256=A3KWF-kJQ1M-ByJKv88wfMsdZYccBzP_p9W_0EzNB6w,2575
12
12
  ocnn/models/unet.py,sha256=1FZbTvmWg6sMYkcZyNxcSr_qN6bfOOag-tIwAoqIPKU,4123
13
- ocnn/modules/__init__.py,sha256=BAEZybvtwDQf7yPZpGcptH1Oxf654Alq87G9D0vKw-E,820
14
- ocnn/modules/modules.py,sha256=7VlBsbwN49J8Xea0DXcFHVhN8POxhzC3FNMY3_JHukM,6207
15
- ocnn/modules/resblocks.py,sha256=Xh5EH6aSx8zcLk3D4m4lZiWDejLMu5AQDH7ZDI4p23A,4548
16
- ocnn/nn/__init__.py,sha256=6M3GEbyepHCU7l_QU1blSI98M0X4d3jbCsXmVGtXYj0,1769
13
+ ocnn/modules/__init__.py,sha256=pRJLNGM4F5a3QO2Y6AJqKnaOJqyR4nApVzNQxJ1vJgQ,1132
14
+ ocnn/modules/modules.py,sha256=l9pkJUgg7_pz9310eJ_806T7brACIqtTvbcK1eFBw-U,9846
15
+ ocnn/modules/resblocks.py,sha256=GHRTwzRX-QG9FnHNgSej3RdJJYSxBVLc5GWQEj7bjjI,5866
16
+ ocnn/nn/__init__.py,sha256=GPOng7D-vfs3FCea9jrRKaGumAkyrMz0f5XWGx2v2Fs,1824
17
17
  ocnn/nn/octree2col.py,sha256=07BGGJD0je0VD-VdS_aKtDo7gWKNWdgojeL1a2n4VRQ,2137
18
18
  ocnn/nn/octree2vox.py,sha256=QgyxxZvRvw2taHFWCvcGyCFAnFC6D6nfw71HrvwN3PI,1524
19
19
  ocnn/nn/octree_align.py,sha256=Y12GBKS-F3JtRNaDdDcBFDBvhndoRxeMOVua4RB5HZE,1696
@@ -22,15 +22,15 @@ ocnn/nn/octree_drop.py,sha256=croMHtk0JScDT0nLpdmbiMnkM_b5uVAz6sOEUcta6sY,1963
22
22
  ocnn/nn/octree_dwconv.py,sha256=cIghi7zyMUGhTd6QsU4PS6K2VlwXaYeLNfZFt2xuVi4,7350
23
23
  ocnn/nn/octree_gconv.py,sha256=Ogkn7wE49dDdP0X_aBthNfJXMsqcRn3ufJuYQtEEjUI,2928
24
24
  ocnn/nn/octree_interp.py,sha256=yQVjiKMNLU1XakPeYjW5ArMsw2fxFRZdh90Nn3cWQuE,7108
25
- ocnn/nn/octree_norm.py,sha256=Mbn28Hv-CEWt2WA0Pdhj2p127m10vwhepzCJVjiZYMI,2976
25
+ ocnn/nn/octree_norm.py,sha256=XrSQ7oZKfepYqQIuPU4loaYNae_rgtPpekR0cqTvryg,4524
26
26
  ocnn/nn/octree_pad.py,sha256=suV6Ftb-UlUuoJzKdQ9DCP9oqVQJq7vXb9_6hq-kUk4,1323
27
27
  ocnn/nn/octree_pool.py,sha256=Zn2XLk5SFl6pqMhhKIvu1uZl5ebonFSlPcKr54fOIPA,6664
28
28
  ocnn/octree/__init__.py,sha256=vKZFc5_r6Gxg5KsPWiCZCR-umWWfWPoE7qBx4PIrUGA,630
29
- ocnn/octree/octree.py,sha256=PgWiECd3h49AZHrK_0GxwwAgnG9l37AcJILanM7d1-k,23713
29
+ ocnn/octree/octree.py,sha256=SNhzI-UOce2oPtQ_KLmR2lduY2P1_ne5fe_4BLEhlv0,24590
30
30
  ocnn/octree/points.py,sha256=7_iiN0y0g9SgV8kLzuGW20C-MFOmz5hh1aQhDhUIa0Q,11355
31
31
  ocnn/octree/shuffled_key.py,sha256=UJZ4eKNA_7nLbf9FbEvS_3VyrAqnZCOzk1hsPtJianM,3936
32
- ocnn-2.2.3.dist-info/LICENSE,sha256=YeOS0Plo8Uistv_8ZXdgddmN9GHJKnIiJ5FZ8zTW6Sw,1114
33
- ocnn-2.2.3.dist-info/METADATA,sha256=Uh2UlUwXzRfM9G9d1EBRIiZARcumPCVZPY3CLjR2q8U,3682
34
- ocnn-2.2.3.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
35
- ocnn-2.2.3.dist-info/top_level.txt,sha256=ayZdVOnxlOke3kgzAlrRh2IEL_qOudwOaEU3xhjtpZ0,5
36
- ocnn-2.2.3.dist-info/RECORD,,
32
+ ocnn-2.2.5.dist-info/LICENSE,sha256=YeOS0Plo8Uistv_8ZXdgddmN9GHJKnIiJ5FZ8zTW6Sw,1114
33
+ ocnn-2.2.5.dist-info/METADATA,sha256=1r105xmAHwGG9tYy-xcnXuLCZCWVH_jrtao8wLZgCtk,3773
34
+ ocnn-2.2.5.dist-info/WHEEL,sha256=R06PA3UVYHThwHvxuRWMqaGcr-PuniXahwjmQRFMEkY,91
35
+ ocnn-2.2.5.dist-info/top_level.txt,sha256=ayZdVOnxlOke3kgzAlrRh2IEL_qOudwOaEU3xhjtpZ0,5
36
+ ocnn-2.2.5.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.1.0)
2
+ Generator: setuptools (75.5.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
File without changes