ocnn 2.2.3__py3-none-any.whl → 2.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ocnn/__init__.py CHANGED
@@ -12,7 +12,7 @@ from . import models
12
12
  from . import dataset
13
13
  from . import utils
14
14
 
15
- __version__ = '2.2.3'
15
+ __version__ = '2.2.4'
16
16
 
17
17
  __all__ = [
18
18
  'octree',
@@ -7,7 +7,6 @@
7
7
 
8
8
  import torch
9
9
  import torch.nn
10
- from typing import Optional
11
10
 
12
11
  import ocnn
13
12
  from ocnn.octree import Octree
ocnn/models/ounet.py CHANGED
@@ -23,12 +23,13 @@ class OUNet(AutoEncoder):
23
23
  self.proj = None # remove this module used in AutoEncoder
24
24
 
25
25
  def encoder(self, octree):
26
- r''' The encoder network for extracting heirarchy features.
26
+ r''' The encoder network for extracting heirarchy features.
27
27
  '''
28
28
 
29
29
  convs = dict()
30
30
  depth, full_depth = self.depth, self.full_depth
31
- data = self.get_input_feature(octree)
31
+ data = octree.get_input_feature(self.feature, nempty=False)
32
+ assert data.size(1) == self.channel_in
32
33
  convs[depth] = self.conv1(data, octree, depth)
33
34
  for i, d in enumerate(range(depth, full_depth-1, -1)):
34
35
  convs[d] = self.encoder_blks[i](convs[d], octree, d)
ocnn/models/resnet.py CHANGED
@@ -31,7 +31,7 @@ class ResNet(torch.nn.Module):
31
31
  channels[i], channels[i+1], resblock_num, nempty=nempty)
32
32
  for i in range(stages-1)])
33
33
  self.pools = torch.nn.ModuleList([ocnn.nn.OctreeMaxPool(
34
- nempty) for i in range(stages-1)])
34
+ nempty) for _ in range(stages-1)])
35
35
  self.global_pool = ocnn.nn.OctreeGlobalPool(nempty)
36
36
  # self.header = torch.nn.Linear(channels[-1], out_channels, bias=True)
37
37
  self.header = torch.nn.Sequential(
ocnn/models/segnet.py CHANGED
@@ -29,7 +29,7 @@ class SegNet(torch.nn.Module):
29
29
  [ocnn.modules.OctreeConvBnRelu(channels[i], channels[i+1], nempty=nempty)
30
30
  for i in range(stages)])
31
31
  self.pools = torch.nn.ModuleList(
32
- [ocnn.nn.OctreeMaxPool(nempty, return_indices) for i in range(stages)])
32
+ [ocnn.nn.OctreeMaxPool(nempty, return_indices) for _ in range(stages)])
33
33
 
34
34
  self.bottleneck = ocnn.modules.OctreeConvBnRelu(channels[-1], channels[-1])
35
35
 
@@ -38,7 +38,7 @@ class SegNet(torch.nn.Module):
38
38
  [ocnn.modules.OctreeConvBnRelu(channels[i], channels[i+1], nempty=nempty)
39
39
  for i in range(0, stages)])
40
40
  self.unpools = torch.nn.ModuleList(
41
- [ocnn.nn.OctreeMaxUnpool(nempty) for i in range(stages)])
41
+ [ocnn.nn.OctreeMaxUnpool(nempty) for _ in range(stages)])
42
42
 
43
43
  self.octree_interp = ocnn.nn.OctreeInterp(interp, nempty)
44
44
  self.header = torch.nn.Sequential(
ocnn/modules/__init__.py CHANGED
@@ -7,14 +7,20 @@
7
7
 
8
8
  from .modules import (InputFeature,
9
9
  OctreeConvBn, OctreeConvBnRelu, OctreeDeconvBnRelu,
10
- Conv1x1, Conv1x1Bn, Conv1x1BnRelu, FcBnRelu,)
11
- from .resblocks import OctreeResBlock, OctreeResBlock2, OctreeResBlocks
10
+ Conv1x1, Conv1x1Bn, Conv1x1BnRelu, FcBnRelu,
11
+ OctreeConvGn, OctreeConvGnRelu, OctreeDeconvGnRelu,
12
+ Conv1x1, Conv1x1Gn, Conv1x1GnRelu)
13
+ from .resblocks import (OctreeResBlock, OctreeResBlock2, OctreeResBlockGn,
14
+ OctreeResBlocks,)
12
15
 
13
16
  __all__ = [
14
17
  'InputFeature',
15
18
  'OctreeConvBn', 'OctreeConvBnRelu', 'OctreeDeconvBnRelu',
16
19
  'Conv1x1', 'Conv1x1Bn', 'Conv1x1BnRelu', 'FcBnRelu',
17
- 'OctreeResBlock', 'OctreeResBlock2', 'OctreeResBlocks',
20
+ 'OctreeConvGn', 'OctreeConvGnRelu', 'OctreeDeconvGnRelu',
21
+ 'Conv1x1', 'Conv1x1Gn', 'Conv1x1GnRelu',
22
+ 'OctreeResBlock', 'OctreeResBlock2', 'OctreeResBlockGn',
23
+ 'OctreeResBlocks',
18
24
  ]
19
25
 
20
26
  classes = __all__
ocnn/modules/modules.py CHANGED
@@ -9,7 +9,7 @@ import torch
9
9
  import torch.utils.checkpoint
10
10
  from typing import List
11
11
 
12
- from ocnn.nn import OctreeConv, OctreeDeconv
12
+ from ocnn.nn import OctreeConv, OctreeDeconv, OctreeGroupNorm
13
13
  from ocnn.octree import Octree
14
14
 
15
15
 
@@ -40,7 +40,7 @@ class OctreeConvBn(torch.nn.Module):
40
40
  super().__init__()
41
41
  self.conv = OctreeConv(
42
42
  in_channels, out_channels, kernel_size, stride, nempty)
43
- self.bn = torch.nn.BatchNorm1d(out_channels) #, bn_eps, bn_momentum)
43
+ self.bn = torch.nn.BatchNorm1d(out_channels)
44
44
 
45
45
  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
46
46
  r''''''
@@ -62,7 +62,7 @@ class OctreeConvBnRelu(torch.nn.Module):
62
62
  super().__init__()
63
63
  self.conv = OctreeConv(
64
64
  in_channels, out_channels, kernel_size, stride, nempty)
65
- self.bn = torch.nn.BatchNorm1d(out_channels) #, bn_eps, bn_momentum)
65
+ self.bn = torch.nn.BatchNorm1d(out_channels)
66
66
  self.relu = torch.nn.ReLU(inplace=True)
67
67
 
68
68
  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
@@ -86,7 +86,7 @@ class OctreeDeconvBnRelu(torch.nn.Module):
86
86
  super().__init__()
87
87
  self.deconv = OctreeDeconv(
88
88
  in_channels, out_channels, kernel_size, stride, nempty)
89
- self.bn = torch.nn.BatchNorm1d(out_channels) #, bn_eps, bn_momentum)
89
+ self.bn = torch.nn.BatchNorm1d(out_channels)
90
90
  self.relu = torch.nn.ReLU(inplace=True)
91
91
 
92
92
  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
@@ -123,7 +123,7 @@ class Conv1x1Bn(torch.nn.Module):
123
123
  def __init__(self, in_channels: int, out_channels: int):
124
124
  super().__init__()
125
125
  self.conv = Conv1x1(in_channels, out_channels, use_bias=False)
126
- self.bn = torch.nn.BatchNorm1d(out_channels) #, bn_eps, bn_momentum)
126
+ self.bn = torch.nn.BatchNorm1d(out_channels)
127
127
 
128
128
  def forward(self, data: torch.Tensor):
129
129
  r''''''
@@ -140,7 +140,7 @@ class Conv1x1BnRelu(torch.nn.Module):
140
140
  def __init__(self, in_channels: int, out_channels: int):
141
141
  super().__init__()
142
142
  self.conv = Conv1x1(in_channels, out_channels, use_bias=False)
143
- self.bn = torch.nn.BatchNorm1d(out_channels) #, bn_eps, bn_momentum)
143
+ self.bn = torch.nn.BatchNorm1d(out_channels)
144
144
  self.relu = torch.nn.ReLU(inplace=True)
145
145
 
146
146
  def forward(self, data: torch.Tensor):
@@ -160,7 +160,7 @@ class FcBnRelu(torch.nn.Module):
160
160
  super().__init__()
161
161
  self.flatten = torch.nn.Flatten(start_dim=1)
162
162
  self.fc = torch.nn.Linear(in_channels, out_channels, bias=False)
163
- self.bn = torch.nn.BatchNorm1d(out_channels) #, bn_eps, bn_momentum)
163
+ self.bn = torch.nn.BatchNorm1d(out_channels)
164
164
  self.relu = torch.nn.ReLU(inplace=True)
165
165
 
166
166
  def forward(self, data):
@@ -173,6 +173,116 @@ class FcBnRelu(torch.nn.Module):
173
173
  return out
174
174
 
175
175
 
176
class OctreeConvGn(torch.nn.Module):
  r''' A sequence of :class:`OctreeConv` and :obj:`OctreeGroupNorm`.

  Please refer to :class:`ocnn.nn.OctreeConv` for details on the parameters.

  Args:
    in_channels (int): The number of input channels.
    out_channels (int): The number of output channels.
    group (int): The requested number of groups of the group normalization;
        :class:`OctreeGroupNorm` may lower it when there are too few channels
        per group.
    kernel_size (List[int]): The kernel size of the octree convolution.
    stride (int): The stride of the octree convolution.
    nempty (bool): Whether the features only contain non-empty octree nodes.
  '''

  def __init__(self, in_channels: int, out_channels: int, group: int,
               kernel_size: List[int] = [3], stride: int = 1,
               nempty: bool = False):
    super().__init__()
    self.stride = stride
    self.conv = OctreeConv(
        in_channels, out_channels, kernel_size, stride, nempty)
    self.gn = OctreeGroupNorm(out_channels, group=group, nempty=nempty)

  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
    r''' Applies the convolution, then group-normalizes its output.

    Args:
      data (torch.Tensor): The input features at :attr:`depth`.
      octree (Octree): The octree the features live on.
      depth (int): The octree depth of :attr:`data`.
    '''

    out = self.conv(data, octree, depth)
    # With a non-unit stride the convolution output lives one level coarser
    # (same convention as OctreeConvGnRelu), so the normalization must be
    # evaluated at that depth to match the number of rows of `out`.
    out_depth = depth if self.stride == 1 else depth - 1
    out = self.gn(out, octree, out_depth)
    return out
196
+
197
+
198
class OctreeConvGnRelu(torch.nn.Module):
  r''' Chains an :class:`OctreeConv`, an :obj:`OctreeGroupNorm` and a
  :obj:`Relu`.

  The convolution parameters are documented in :class:`ocnn.nn.OctreeConv`;
  :attr:`group` is handed to :class:`OctreeGroupNorm` together with
  :attr:`nempty`.
  '''

  def __init__(self, in_channels: int, out_channels: int, group: int,
               kernel_size: List[int] = [3], stride: int = 1,
               nempty: bool = False):
    super().__init__()
    self.stride = stride
    self.conv = OctreeConv(in_channels, out_channels, kernel_size, stride,
                           nempty)
    self.gn = OctreeGroupNorm(out_channels, group=group, nempty=nempty)
    self.relu = torch.nn.ReLU(inplace=True)

  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
    r''' Runs conv -> group norm -> ReLU on the features at :attr:`depth`.
    '''

    # Any stride other than 1 moves the convolution output one level up,
    # so normalization happens at `depth - 1` in that case.
    norm_depth = depth - 1 if self.stride != 1 else depth
    features = self.conv(data, octree, depth)
    return self.relu(self.gn(features, octree, norm_depth))
221
+
222
+
223
class OctreeDeconvGnRelu(torch.nn.Module):
  r''' Chains an :class:`OctreeDeconv`, an :obj:`OctreeGroupNorm` and a
  :obj:`Relu`.

  The deconvolution parameters are documented in :class:`ocnn.nn.OctreeDeconv`;
  :attr:`group` is handed to :class:`OctreeGroupNorm` together with
  :attr:`nempty`.
  '''

  def __init__(self, in_channels: int, out_channels: int, group: int,
               kernel_size: List[int] = [3], stride: int = 1,
               nempty: bool = False):
    super().__init__()
    self.stride = stride
    self.deconv = OctreeDeconv(in_channels, out_channels, kernel_size, stride,
                               nempty)
    self.gn = OctreeGroupNorm(out_channels, group=group, nempty=nempty)
    self.relu = torch.nn.ReLU(inplace=True)

  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
    r''' Runs deconv -> group norm -> ReLU on the features at :attr:`depth`.
    '''

    # With a non-unit stride the deconvolution output lives one level deeper,
    # so normalization happens at `depth + 1` in that case.
    norm_depth = depth + 1 if self.stride != 1 else depth
    features = self.deconv(data, octree, depth)
    return self.relu(self.gn(features, octree, norm_depth))
246
+
247
+
248
class Conv1x1Gn(torch.nn.Module):
  r''' A bias-free :class:`Conv1x1` whose output is normalized by an
  :class:`OctreeGroupNorm`.
  '''

  def __init__(self, in_channels: int, out_channels: int, group: int,
               nempty: bool = False):
    super().__init__()
    self.conv = Conv1x1(in_channels, out_channels, use_bias=False)
    self.gn = OctreeGroupNorm(out_channels, group=group, nempty=nempty)

  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
    r''' Projects :attr:`data` channel-wise, then group-normalizes it at
    :attr:`depth`.
    '''

    return self.gn(self.conv(data), octree, depth)
264
+
265
+
266
class Conv1x1GnRelu(torch.nn.Module):
  r''' A bias-free :class:`Conv1x1`, then an :class:`OctreeGroupNorm`, then a
  :class:`Relu`.
  '''

  def __init__(self, in_channels: int, out_channels: int, group: int,
               nempty: bool = False):
    super().__init__()
    self.conv = Conv1x1(in_channels, out_channels, use_bias=False)
    self.gn = OctreeGroupNorm(out_channels, group=group, nempty=nempty)
    self.relu = torch.nn.ReLU(inplace=True)

  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
    r''' Projects :attr:`data` channel-wise, group-normalizes it at
    :attr:`depth` and applies ReLU.
    '''

    normalized = self.gn(self.conv(data), octree, depth)
    return self.relu(normalized)
284
+
285
+
176
286
  class InputFeature(torch.nn.Module):
177
287
  r''' Returns the initial input feature stored in octree.
178
288
 
ocnn/modules/resblocks.py CHANGED
@@ -10,7 +10,9 @@ import torch.utils.checkpoint
10
10
 
11
11
  from ocnn.octree import Octree
12
12
  from ocnn.nn import OctreeMaxPool
13
- from ocnn.modules import Conv1x1BnRelu, OctreeConvBnRelu, Conv1x1Bn, OctreeConvBn
13
+ from ocnn.modules import (Conv1x1BnRelu, OctreeConvBnRelu, Conv1x1Bn,
14
+ OctreeConvBn, OctreeConvGnRelu, Conv1x1Gn,
15
+ OctreeConvGn,)
14
16
 
15
17
 
16
18
  class OctreeResBlock(torch.nn.Module):
@@ -97,6 +99,38 @@ class OctreeResBlock2(torch.nn.Module):
97
99
  return out
98
100
 
99
101
 
102
class OctreeResBlockGn(torch.nn.Module):
  r''' An octree-based residual block normalized with group normalization.

  The residual branch is an :class:`OctreeConvGnRelu` (``conv3x3a``) followed
  by an :class:`OctreeConvGn` (``conv3x3b``). When the channel counts differ,
  the shortcut is projected by a :class:`Conv1x1Gn` (``conv1x1``). The sum of
  both branches goes through a ReLU.

  Args:
    in_channels (int): The number of input channels.
    out_channels (int): The number of output channels.
    stride (int): If 2, the input is max-pooled first and the block runs one
        octree level coarser (at ``depth - 1``).
    bottleneck (int): The inner channel count is
        ``out_channels // bottleneck``.
    nempty (bool): Whether the features only contain non-empty octree nodes.
    group (int): The requested number of groups of every group normalization.
  '''

  def __init__(self, in_channels: int, out_channels: int, stride: int = 1,
               bottleneck: int = 4, nempty: bool = False, group: int = 32):
    super().__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.stride = stride
    channelb = out_channels // bottleneck

    if self.stride == 2:
      # The pooling layer is configured by `nempty` (as in ResNet/SegNet).
      self.maxpool = OctreeMaxPool(nempty)
    self.conv3x3a = OctreeConvGnRelu(in_channels, channelb, group,
                                     nempty=nempty)
    self.conv3x3b = OctreeConvGn(channelb, out_channels, group, nempty=nempty)
    if self.in_channels != self.out_channels:
      # The shortcut must normalize over the same node set as the residual
      # branch, otherwise the group statistics do not line up with `data`.
      self.conv1x1 = Conv1x1Gn(in_channels, out_channels, group,
                               nempty=nempty)
    self.relu = torch.nn.ReLU(inplace=True)

  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
    r''' Applies the residual block to the features at :attr:`depth`.
    '''

    if self.stride == 2:
      data = self.maxpool(data, octree, depth)
      depth = depth - 1
    conv1 = self.conv3x3a(data, octree, depth)
    conv2 = self.conv3x3b(conv1, octree, depth)
    if self.in_channels != self.out_channels:
      data = self.conv1x1(data, octree, depth)
    out = self.relu(conv2 + data)
    return out
132
+
133
+
100
134
  class OctreeResBlocks(torch.nn.Module):
101
135
  r''' A sequence of :attr:`resblk_num` ResNet blocks.
102
136
  '''
@@ -108,9 +142,9 @@ class OctreeResBlocks(torch.nn.Module):
108
142
  self.use_checkpoint = use_checkpoint
109
143
  channels = [in_channels] + [out_channels] * resblk_num
110
144
 
111
- self.resblks = torch.nn.ModuleList(
112
- [resblk(channels[i], channels[i+1], 1, bottleneck, nempty)
113
- for i in range(self.resblk_num)])
145
+ self.resblks = torch.nn.ModuleList([resblk(
146
+ channels[i], channels[i+1], 1, bottleneck, nempty)
147
+ for i in range(self.resblk_num)])
114
148
 
115
149
  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
116
150
  r''''''
@@ -118,7 +152,7 @@ class OctreeResBlocks(torch.nn.Module):
118
152
  for i in range(self.resblk_num):
119
153
  if self.use_checkpoint:
120
154
  data = torch.utils.checkpoint.checkpoint(
121
- self.resblks[i], data, octree, depth)
155
+ self.resblks[i], data, octree, depth, use_reentrant=False)
122
156
  else:
123
157
  data = self.resblks[i](data, octree, depth)
124
158
  return data
ocnn/nn/__init__.py CHANGED
@@ -17,7 +17,8 @@ from .octree_pool import (octree_max_pool, OctreeMaxPool,
17
17
  from .octree_conv import OctreeConv, OctreeDeconv
18
18
  from .octree_gconv import OctreeGroupConv
19
19
  from .octree_dwconv import OctreeDWConv
20
- from .octree_norm import OctreeBatchNorm, OctreeGroupNorm, OctreeInstanceNorm
20
+ from .octree_norm import (OctreeBatchNorm, OctreeGroupNorm,
21
+ OctreeInstanceNorm, OctreeNorm)
21
22
  from .octree_drop import OctreeDropPath
22
23
  from .octree_align import search_value, octree_align
23
24
 
@@ -35,7 +36,7 @@ __all__ = [
35
36
  'OctreeConv', 'OctreeDeconv',
36
37
  'OctreeGroupConv', 'OctreeDWConv',
37
38
  'OctreeInterp', 'OctreeUpsample',
38
- 'OctreeInstanceNorm', 'OctreeBatchNorm', 'OctreeGroupNorm',
39
+ 'OctreeInstanceNorm', 'OctreeBatchNorm', 'OctreeGroupNorm', 'OctreeNorm',
39
40
  'OctreeDropPath',
40
41
  'search_value', 'octree_align',
41
42
  ]
ocnn/nn/octree_norm.py CHANGED
@@ -7,6 +7,7 @@
7
7
 
8
8
  import torch
9
9
  import torch.nn
10
+ from typing import Optional
10
11
 
11
12
  from ocnn.octree import Octree
12
13
  from ocnn.utils import scatter_add
@@ -19,15 +20,19 @@ class OctreeGroupNorm(torch.nn.Module):
19
20
  r''' An group normalization layer for the octree.
20
21
  '''
21
22
 
22
- def __init__(self, in_channels: int, group: int, nempty: bool = False):
23
+ def __init__(self, in_channels: int, group: int, nempty: bool = False,
24
+ min_group_channels: int = 4):
23
25
  super().__init__()
24
26
  self.eps = 1e-5
25
27
  self.nempty = nempty
26
28
  self.group = group
27
29
  self.in_channels = in_channels
30
+ self.min_group_channels = min_group_channels
31
+ if self.min_group_channels * self.group > in_channels:
32
+ self.group = in_channels // self.min_group_channels
28
33
 
29
- assert in_channels % group == 0
30
- self.channels_per_group = in_channels // group
34
+ assert in_channels % self.group == 0
35
+ self.channels_per_group = in_channels // self.group
31
36
 
32
37
  self.weights = torch.nn.Parameter(torch.Tensor(1, in_channels))
33
38
  self.bias = torch.nn.Parameter(torch.Tensor(1, in_channels))
@@ -71,8 +76,8 @@ class OctreeGroupNorm(torch.nn.Module):
71
76
  return tensor
72
77
 
73
78
  def extra_repr(self) -> str:
74
- return ('in_channels={}, group={}, nempty={}').format(
75
- self.in_channels, self.group, self.nempty) # noqa
79
+ return ('in_channels={}, group={}, nempty={}, min_group_channels={}').format(
80
+ self.in_channels, self.group, self.nempty, self.min_group_channels)
76
81
 
77
82
 
78
83
  class OctreeInstanceNorm(OctreeGroupNorm):
@@ -80,7 +85,42 @@ class OctreeInstanceNorm(OctreeGroupNorm):
80
85
  '''
81
86
 
82
87
  def __init__(self, in_channels: int, nempty: bool = False):
83
- super().__init__(in_channels=in_channels, group=in_channels, nempty=nempty)
88
+ super().__init__(in_channels=in_channels, group=in_channels, nempty=nempty,
89
+ min_group_channels=1) # NOTE: group=in_channels
84
90
 
85
91
  def extra_repr(self) -> str:
86
92
  return ('in_channels={}, nempty={}').format(self.in_channels, self.nempty)
93
+
94
+
95
class OctreeNorm(torch.nn.Module):
  r''' A normalization layer for the octree. It encapsulates octree-based batch,
  group and instance normalization.

  Args:
    in_channels (int): The number of input channels.
    norm_type (str): One of ``'batch_norm'``, ``'group_norm'`` or
        ``'instance_norm'``.
    group (int): The requested number of groups for ``'group_norm'``.
    min_group_channels (int): The minimum number of channels per group for
        ``'group_norm'``; forwarded to :class:`OctreeGroupNorm`.
    nempty (bool): Whether the features only contain non-empty octree nodes;
        used by ``'group_norm'`` and ``'instance_norm'``.

  Raises:
    ValueError: If :attr:`norm_type` is not a supported type.
  '''

  def __init__(self, in_channels: int, norm_type: str = 'batch_norm',
               group: int = 32, min_group_channels: int = 4,
               nempty: bool = False):
    super().__init__()
    self.in_channels = in_channels
    self.norm_type = norm_type
    self.group = group
    self.min_group_channels = min_group_channels
    self.nempty = nempty

    if self.norm_type == 'batch_norm':
      self.norm = torch.nn.BatchNorm1d(in_channels)
    elif self.norm_type == 'group_norm':
      # Pass by keyword: the third positional parameter of OctreeGroupNorm is
      # `nempty`, not `min_group_channels`.
      self.norm = OctreeGroupNorm(
          in_channels, group=group, nempty=nempty,
          min_group_channels=min_group_channels)
    elif self.norm_type == 'instance_norm':
      self.norm = OctreeInstanceNorm(in_channels, nempty=nempty)
    else:
      raise ValueError(
          "Unsupported norm_type '{}'; expected 'batch_norm', 'group_norm' "
          "or 'instance_norm'.".format(norm_type))

  def forward(self, x: torch.Tensor, octree: Optional[Octree] = None,
              depth: Optional[int] = None):
    r''' Normalizes :attr:`x`.

    Args:
      x (torch.Tensor): The input features.
      octree (Octree): Required for group/instance norm; ignored for batch norm.
      depth (int): Required for group/instance norm; ignored for batch norm.

    Raises:
      ValueError: If group/instance norm is used without :attr:`octree` or
          :attr:`depth`, or if :attr:`norm_type` is unsupported.
    '''

    if self.norm_type == 'batch_norm':
      output = self.norm(x)
    elif self.norm_type in ('group_norm', 'instance_norm'):
      if octree is None or depth is None:
        raise ValueError(
            "'{}' requires both `octree` and `depth`.".format(self.norm_type))
      output = self.norm(x, octree, depth)
    else:
      raise ValueError("Unsupported norm_type '{}'.".format(self.norm_type))
    return output
ocnn/octree/octree.py CHANGED
@@ -63,14 +63,14 @@ class Octree:
63
63
 
64
64
  # octree node numbers in each octree layers.
65
65
  # TODO: decide whether to settle them to 'gpu' or not?
66
- self.nnum = torch.zeros(num, dtype=torch.int32)
67
- self.nnum_nempty = torch.zeros(num, dtype=torch.int32)
66
+ self.nnum = torch.zeros(num, dtype=torch.long)
67
+ self.nnum_nempty = torch.zeros(num, dtype=torch.long)
68
68
 
69
69
  # the following properties are valid after `merge_octrees`.
70
70
  # TODO: make them valid after `octree_grow`, `octree_split` and `build_octree`
71
71
  batch_size = self.batch_size
72
- self.batch_nnum = torch.zeros(num, batch_size, dtype=torch.int32)
73
- self.batch_nnum_nempty = torch.zeros(num, batch_size, dtype=torch.int32)
72
+ self.batch_nnum = torch.zeros(num, batch_size, dtype=torch.long)
73
+ self.batch_nnum_nempty = torch.zeros(num, batch_size, dtype=torch.long)
74
74
 
75
75
  # construct the look up tables for neighborhood searching
76
76
  device = self.device
@@ -290,7 +290,7 @@ class Octree:
290
290
  # node number
291
291
  nnum = self.nnum_nempty[depth-1] * 8
292
292
  self.nnum[depth] = nnum
293
- self.nnum_nempty[depth] = nnum
293
+ self.nnum_nempty[depth] = nnum # initialize self.nnum_nempty
294
294
 
295
295
  # update keys
296
296
  key = self.key(depth-1, nempty=True)
@@ -326,7 +326,7 @@ class Octree:
326
326
  xyz = xyz.view(-1, 3) # (N*27, 3)
327
327
  neigh = xyz2key(xyz[:, 0], xyz[:, 1], xyz[:, 2], depth=depth)
328
328
 
329
- bs = torch.arange(self.batch_size, dtype=torch.int32, device=device)
329
+ bs = torch.arange(self.batch_size, dtype=torch.long, device=device)
330
330
  neigh = neigh + bs.unsqueeze(1) * nnum # (N*27,) + (B, 1) -> (B, N*27)
331
331
 
332
332
  bound = 1 << depth
@@ -383,9 +383,10 @@ class Octree:
383
383
  # I choose `torch.bucketize` here because it has fewer dimension checks,
384
384
  # resulting in slightly better performance according to the docs of
385
385
  # pytorch-1.9.1, since `key` is always 1-D sorted sequence.
386
+ # https://pytorch.org/docs/1.9.1/generated/torch.searchsorted.html
386
387
  idx = torch.bucketize(query, key)
387
388
 
388
- valid = idx < key.shape[0] # valid if NOT out-of-bound
389
+ valid = idx < key.shape[0] # valid if in-bound
389
390
  found = key[idx[valid]] == query[valid]
390
391
  valid[valid.clone()] = found # valid if found
391
392
  idx[valid.logical_not()] = -1 # set to -1 if invalid
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ocnn
3
- Version: 2.2.3
3
+ Version: 2.2.4
4
4
  Summary: Octree-based Sparse Convolutional Neural Networks
5
5
  Home-page: https://github.com/octree-nn/ocnn-pytorch
6
6
  Author: Peng-Shuai Wang
@@ -23,6 +23,7 @@ Requires-Dist: packaging
23
23
 
24
24
  [![Documentation Status](https://readthedocs.org/projects/ocnn-pytorch/badge/?version=latest)](https://ocnn-pytorch.readthedocs.io/en/latest/?badge=latest)
25
25
  [![Downloads](https://static.pepy.tech/badge/ocnn)](https://pepy.tech/project/ocnn)
26
+ [![Downloads](https://static.pepy.tech/badge/ocnn/month)](https://pepy.tech/project/ocnn)
26
27
  [![PyPI](https://img.shields.io/pypi/v/ocnn)](https://pypi.org/project/ocnn/)
27
28
 
28
29
  This repository contains the **pure PyTorch**-based implementation of
@@ -1,19 +1,19 @@
1
- ocnn/__init__.py,sha256=kCGigMn_30MohB49Hwy0ZBXf5HTUA_i-QOjckrM48Nc,582
1
+ ocnn/__init__.py,sha256=Fdq7gwK0CNhErDf-3rX7M-zbqj1RQbBQNoGI5bBTMOg,582
2
2
  ocnn/dataset.py,sha256=wvclvjlZs9qTeMXWLaO32K5d1VVY9XHSNuVVJEpVeeo,5266
3
3
  ocnn/utils.py,sha256=XhykveOjHoQd94gjJ5-opzXs-9MOCAzZ34ArZ8mG4sE,6726
4
4
  ocnn/models/__init__.py,sha256=F9PJRhOPHc1OrwkqcfywEBW0J6jmVW7-IHgWjGpY15U,724
5
- ocnn/models/autoencoder.py,sha256=TjOet3dbZLexz-PvZdeV0mbIuGpUAeSH6KWU_z_s-d8,5906
5
+ ocnn/models/autoencoder.py,sha256=aCGuouaia0qXxKB8mQ-wLnM7liwnn2JZGLiNLo9g6kg,5877
6
6
  ocnn/models/hrnet.py,sha256=9W2fi7Fuw0JXDBiZtEoUW2K7ghtpWUm_BWd-mKoHLY0,6684
7
7
  ocnn/models/image2shape.py,sha256=5djcOHJh2SQCwd5XdLPeL5vlQDNWnRU3tJY3ojRI8aQ,4589
8
8
  ocnn/models/lenet.py,sha256=ujVBxnn8AKiIKqB4WHxuK728oJe4TZdqZUTTNy8y3zE,1754
9
- ocnn/models/ounet.py,sha256=sGGVHqF1nKw798vyththxp7-P3wHlsDUO39hbQpUVTw,3303
10
- ocnn/models/resnet.py,sha256=9gZKbhFituqPJHCm-bp1xTtHKZ6wu77s31a0q7RtXoQ,2029
11
- ocnn/models/segnet.py,sha256=VfZf8gBMSPgO5m8Agsfccw9snTrI4LGh0SqxhLZ1F8s,2575
9
+ ocnn/models/ounet.py,sha256=PcoOQiEQqWmtvfpA4zNAce8ijATiiYa-17DKdWkUm7A,3368
10
+ ocnn/models/resnet.py,sha256=_bRaLBYK5yVqklYEvfPjcndWvMpOKaygnSl1pyjT-4w,2029
11
+ ocnn/models/segnet.py,sha256=A3KWF-kJQ1M-ByJKv88wfMsdZYccBzP_p9W_0EzNB6w,2575
12
12
  ocnn/models/unet.py,sha256=1FZbTvmWg6sMYkcZyNxcSr_qN6bfOOag-tIwAoqIPKU,4123
13
- ocnn/modules/__init__.py,sha256=BAEZybvtwDQf7yPZpGcptH1Oxf654Alq87G9D0vKw-E,820
14
- ocnn/modules/modules.py,sha256=7VlBsbwN49J8Xea0DXcFHVhN8POxhzC3FNMY3_JHukM,6207
15
- ocnn/modules/resblocks.py,sha256=Xh5EH6aSx8zcLk3D4m4lZiWDejLMu5AQDH7ZDI4p23A,4548
16
- ocnn/nn/__init__.py,sha256=6M3GEbyepHCU7l_QU1blSI98M0X4d3jbCsXmVGtXYj0,1769
13
+ ocnn/modules/__init__.py,sha256=pRJLNGM4F5a3QO2Y6AJqKnaOJqyR4nApVzNQxJ1vJgQ,1132
14
+ ocnn/modules/modules.py,sha256=l9pkJUgg7_pz9310eJ_806T7brACIqtTvbcK1eFBw-U,9846
15
+ ocnn/modules/resblocks.py,sha256=GHRTwzRX-QG9FnHNgSej3RdJJYSxBVLc5GWQEj7bjjI,5866
16
+ ocnn/nn/__init__.py,sha256=GPOng7D-vfs3FCea9jrRKaGumAkyrMz0f5XWGx2v2Fs,1824
17
17
  ocnn/nn/octree2col.py,sha256=07BGGJD0je0VD-VdS_aKtDo7gWKNWdgojeL1a2n4VRQ,2137
18
18
  ocnn/nn/octree2vox.py,sha256=QgyxxZvRvw2taHFWCvcGyCFAnFC6D6nfw71HrvwN3PI,1524
19
19
  ocnn/nn/octree_align.py,sha256=Y12GBKS-F3JtRNaDdDcBFDBvhndoRxeMOVua4RB5HZE,1696
@@ -22,15 +22,15 @@ ocnn/nn/octree_drop.py,sha256=croMHtk0JScDT0nLpdmbiMnkM_b5uVAz6sOEUcta6sY,1963
22
22
  ocnn/nn/octree_dwconv.py,sha256=cIghi7zyMUGhTd6QsU4PS6K2VlwXaYeLNfZFt2xuVi4,7350
23
23
  ocnn/nn/octree_gconv.py,sha256=Ogkn7wE49dDdP0X_aBthNfJXMsqcRn3ufJuYQtEEjUI,2928
24
24
  ocnn/nn/octree_interp.py,sha256=yQVjiKMNLU1XakPeYjW5ArMsw2fxFRZdh90Nn3cWQuE,7108
25
- ocnn/nn/octree_norm.py,sha256=Mbn28Hv-CEWt2WA0Pdhj2p127m10vwhepzCJVjiZYMI,2976
25
+ ocnn/nn/octree_norm.py,sha256=XrSQ7oZKfepYqQIuPU4loaYNae_rgtPpekR0cqTvryg,4524
26
26
  ocnn/nn/octree_pad.py,sha256=suV6Ftb-UlUuoJzKdQ9DCP9oqVQJq7vXb9_6hq-kUk4,1323
27
27
  ocnn/nn/octree_pool.py,sha256=Zn2XLk5SFl6pqMhhKIvu1uZl5ebonFSlPcKr54fOIPA,6664
28
28
  ocnn/octree/__init__.py,sha256=vKZFc5_r6Gxg5KsPWiCZCR-umWWfWPoE7qBx4PIrUGA,630
29
- ocnn/octree/octree.py,sha256=PgWiECd3h49AZHrK_0GxwwAgnG9l37AcJILanM7d1-k,23713
29
+ ocnn/octree/octree.py,sha256=Ir1jP_s4jrTVGtFZAY3zCaTLpekDHqtoSvGDJG84_QQ,23803
30
30
  ocnn/octree/points.py,sha256=7_iiN0y0g9SgV8kLzuGW20C-MFOmz5hh1aQhDhUIa0Q,11355
31
31
  ocnn/octree/shuffled_key.py,sha256=UJZ4eKNA_7nLbf9FbEvS_3VyrAqnZCOzk1hsPtJianM,3936
32
- ocnn-2.2.3.dist-info/LICENSE,sha256=YeOS0Plo8Uistv_8ZXdgddmN9GHJKnIiJ5FZ8zTW6Sw,1114
33
- ocnn-2.2.3.dist-info/METADATA,sha256=Uh2UlUwXzRfM9G9d1EBRIiZARcumPCVZPY3CLjR2q8U,3682
34
- ocnn-2.2.3.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
35
- ocnn-2.2.3.dist-info/top_level.txt,sha256=ayZdVOnxlOke3kgzAlrRh2IEL_qOudwOaEU3xhjtpZ0,5
36
- ocnn-2.2.3.dist-info/RECORD,,
32
+ ocnn-2.2.4.dist-info/LICENSE,sha256=YeOS0Plo8Uistv_8ZXdgddmN9GHJKnIiJ5FZ8zTW6Sw,1114
33
+ ocnn-2.2.4.dist-info/METADATA,sha256=M2KsDEZXmlZ3sli4dNixbGVyNLF7qN-CpLjf9OZP_vg,3773
34
+ ocnn-2.2.4.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
35
+ ocnn-2.2.4.dist-info/top_level.txt,sha256=ayZdVOnxlOke3kgzAlrRh2IEL_qOudwOaEU3xhjtpZ0,5
36
+ ocnn-2.2.4.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.1.0)
2
+ Generator: setuptools (75.2.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
File without changes