ocnn 2.2.3__tar.gz → 2.2.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. {ocnn-2.2.3/ocnn.egg-info → ocnn-2.2.4}/PKG-INFO +2 -1
  2. {ocnn-2.2.3 → ocnn-2.2.4}/README.md +1 -0
  3. {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/__init__.py +1 -1
  4. {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/models/autoencoder.py +0 -1
  5. {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/models/ounet.py +3 -2
  6. {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/models/resnet.py +1 -1
  7. {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/models/segnet.py +2 -2
  8. {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/modules/__init__.py +9 -3
  9. {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/modules/modules.py +117 -7
  10. {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/modules/resblocks.py +39 -5
  11. {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/nn/__init__.py +3 -2
  12. {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/nn/octree_norm.py +46 -6
  13. {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/octree/octree.py +8 -7
  14. {ocnn-2.2.3 → ocnn-2.2.4/ocnn.egg-info}/PKG-INFO +2 -1
  15. {ocnn-2.2.3 → ocnn-2.2.4}/setup.py +1 -1
  16. {ocnn-2.2.3 → ocnn-2.2.4}/LICENSE +0 -0
  17. {ocnn-2.2.3 → ocnn-2.2.4}/MANIFEST.in +0 -0
  18. {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/dataset.py +0 -0
  19. {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/models/__init__.py +0 -0
  20. {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/models/hrnet.py +0 -0
  21. {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/models/image2shape.py +0 -0
  22. {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/models/lenet.py +0 -0
  23. {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/models/unet.py +0 -0
  24. {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/nn/octree2col.py +0 -0
  25. {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/nn/octree2vox.py +0 -0
  26. {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/nn/octree_align.py +0 -0
  27. {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/nn/octree_conv.py +0 -0
  28. {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/nn/octree_drop.py +0 -0
  29. {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/nn/octree_dwconv.py +0 -0
  30. {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/nn/octree_gconv.py +0 -0
  31. {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/nn/octree_interp.py +0 -0
  32. {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/nn/octree_pad.py +0 -0
  33. {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/nn/octree_pool.py +0 -0
  34. {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/octree/__init__.py +0 -0
  35. {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/octree/points.py +0 -0
  36. {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/octree/shuffled_key.py +0 -0
  37. {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/utils.py +0 -0
  38. {ocnn-2.2.3 → ocnn-2.2.4}/ocnn.egg-info/SOURCES.txt +0 -0
  39. {ocnn-2.2.3 → ocnn-2.2.4}/ocnn.egg-info/dependency_links.txt +0 -0
  40. {ocnn-2.2.3 → ocnn-2.2.4}/ocnn.egg-info/not-zip-safe +0 -0
  41. {ocnn-2.2.3 → ocnn-2.2.4}/ocnn.egg-info/requires.txt +0 -0
  42. {ocnn-2.2.3 → ocnn-2.2.4}/ocnn.egg-info/top_level.txt +0 -0
  43. {ocnn-2.2.3 → ocnn-2.2.4}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ocnn
3
- Version: 2.2.3
3
+ Version: 2.2.4
4
4
  Summary: Octree-based Sparse Convolutional Neural Networks
5
5
  Home-page: https://github.com/octree-nn/ocnn-pytorch
6
6
  Author: Peng-Shuai Wang
@@ -23,6 +23,7 @@ Requires-Dist: packaging
23
23
 
24
24
  [![Documentation Status](https://readthedocs.org/projects/ocnn-pytorch/badge/?version=latest)](https://ocnn-pytorch.readthedocs.io/en/latest/?badge=latest)
25
25
  [![Downloads](https://static.pepy.tech/badge/ocnn)](https://pepy.tech/project/ocnn)
26
+ [![Downloads](https://static.pepy.tech/badge/ocnn/month)](https://pepy.tech/project/ocnn)
26
27
  [![PyPI](https://img.shields.io/pypi/v/ocnn)](https://pypi.org/project/ocnn/)
27
28
 
28
29
  This repository contains the **pure PyTorch**-based implementation of
@@ -4,6 +4,7 @@
4
4
 
5
5
  [![Documentation Status](https://readthedocs.org/projects/ocnn-pytorch/badge/?version=latest)](https://ocnn-pytorch.readthedocs.io/en/latest/?badge=latest)
6
6
  [![Downloads](https://static.pepy.tech/badge/ocnn)](https://pepy.tech/project/ocnn)
7
+ [![Downloads](https://static.pepy.tech/badge/ocnn/month)](https://pepy.tech/project/ocnn)
7
8
  [![PyPI](https://img.shields.io/pypi/v/ocnn)](https://pypi.org/project/ocnn/)
8
9
 
9
10
  This repository contains the **pure PyTorch**-based implementation of
@@ -12,7 +12,7 @@ from . import models
12
12
  from . import dataset
13
13
  from . import utils
14
14
 
15
- __version__ = '2.2.3'
15
+ __version__ = '2.2.4'
16
16
 
17
17
  __all__ = [
18
18
  'octree',
@@ -7,7 +7,6 @@
7
7
 
8
8
  import torch
9
9
  import torch.nn
10
- from typing import Optional
11
10
 
12
11
  import ocnn
13
12
  from ocnn.octree import Octree
@@ -23,12 +23,13 @@ class OUNet(AutoEncoder):
23
23
  self.proj = None # remove this module used in AutoEncoder
24
24
 
25
25
  def encoder(self, octree):
26
- r''' The encoder network for extracting heirarchy features.
26
+ r''' The encoder network for extracting heirarchy features.
27
27
  '''
28
28
 
29
29
  convs = dict()
30
30
  depth, full_depth = self.depth, self.full_depth
31
- data = self.get_input_feature(octree)
31
+ data = octree.get_input_feature(self.feature, nempty=False)
32
+ assert data.size(1) == self.channel_in
32
33
  convs[depth] = self.conv1(data, octree, depth)
33
34
  for i, d in enumerate(range(depth, full_depth-1, -1)):
34
35
  convs[d] = self.encoder_blks[i](convs[d], octree, d)
@@ -31,7 +31,7 @@ class ResNet(torch.nn.Module):
31
31
  channels[i], channels[i+1], resblock_num, nempty=nempty)
32
32
  for i in range(stages-1)])
33
33
  self.pools = torch.nn.ModuleList([ocnn.nn.OctreeMaxPool(
34
- nempty) for i in range(stages-1)])
34
+ nempty) for _ in range(stages-1)])
35
35
  self.global_pool = ocnn.nn.OctreeGlobalPool(nempty)
36
36
  # self.header = torch.nn.Linear(channels[-1], out_channels, bias=True)
37
37
  self.header = torch.nn.Sequential(
@@ -29,7 +29,7 @@ class SegNet(torch.nn.Module):
29
29
  [ocnn.modules.OctreeConvBnRelu(channels[i], channels[i+1], nempty=nempty)
30
30
  for i in range(stages)])
31
31
  self.pools = torch.nn.ModuleList(
32
- [ocnn.nn.OctreeMaxPool(nempty, return_indices) for i in range(stages)])
32
+ [ocnn.nn.OctreeMaxPool(nempty, return_indices) for _ in range(stages)])
33
33
 
34
34
  self.bottleneck = ocnn.modules.OctreeConvBnRelu(channels[-1], channels[-1])
35
35
 
@@ -38,7 +38,7 @@ class SegNet(torch.nn.Module):
38
38
  [ocnn.modules.OctreeConvBnRelu(channels[i], channels[i+1], nempty=nempty)
39
39
  for i in range(0, stages)])
40
40
  self.unpools = torch.nn.ModuleList(
41
- [ocnn.nn.OctreeMaxUnpool(nempty) for i in range(stages)])
41
+ [ocnn.nn.OctreeMaxUnpool(nempty) for _ in range(stages)])
42
42
 
43
43
  self.octree_interp = ocnn.nn.OctreeInterp(interp, nempty)
44
44
  self.header = torch.nn.Sequential(
@@ -7,14 +7,20 @@
7
7
 
8
8
  from .modules import (InputFeature,
9
9
  OctreeConvBn, OctreeConvBnRelu, OctreeDeconvBnRelu,
10
- Conv1x1, Conv1x1Bn, Conv1x1BnRelu, FcBnRelu,)
11
- from .resblocks import OctreeResBlock, OctreeResBlock2, OctreeResBlocks
10
+ Conv1x1, Conv1x1Bn, Conv1x1BnRelu, FcBnRelu,
11
+ OctreeConvGn, OctreeConvGnRelu, OctreeDeconvGnRelu,
12
+ Conv1x1, Conv1x1Gn, Conv1x1GnRelu)
13
+ from .resblocks import (OctreeResBlock, OctreeResBlock2, OctreeResBlockGn,
14
+ OctreeResBlocks,)
12
15
 
13
16
  __all__ = [
14
17
  'InputFeature',
15
18
  'OctreeConvBn', 'OctreeConvBnRelu', 'OctreeDeconvBnRelu',
16
19
  'Conv1x1', 'Conv1x1Bn', 'Conv1x1BnRelu', 'FcBnRelu',
17
- 'OctreeResBlock', 'OctreeResBlock2', 'OctreeResBlocks',
20
+ 'OctreeConvGn', 'OctreeConvGnRelu', 'OctreeDeconvGnRelu',
21
+ 'Conv1x1', 'Conv1x1Gn', 'Conv1x1GnRelu',
22
+ 'OctreeResBlock', 'OctreeResBlock2', 'OctreeResBlockGn',
23
+ 'OctreeResBlocks',
18
24
  ]
19
25
 
20
26
  classes = __all__
@@ -9,7 +9,7 @@ import torch
9
9
  import torch.utils.checkpoint
10
10
  from typing import List
11
11
 
12
- from ocnn.nn import OctreeConv, OctreeDeconv
12
+ from ocnn.nn import OctreeConv, OctreeDeconv, OctreeGroupNorm
13
13
  from ocnn.octree import Octree
14
14
 
15
15
 
@@ -40,7 +40,7 @@ class OctreeConvBn(torch.nn.Module):
40
40
  super().__init__()
41
41
  self.conv = OctreeConv(
42
42
  in_channels, out_channels, kernel_size, stride, nempty)
43
- self.bn = torch.nn.BatchNorm1d(out_channels) #, bn_eps, bn_momentum)
43
+ self.bn = torch.nn.BatchNorm1d(out_channels)
44
44
 
45
45
  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
46
46
  r''''''
@@ -62,7 +62,7 @@ class OctreeConvBnRelu(torch.nn.Module):
62
62
  super().__init__()
63
63
  self.conv = OctreeConv(
64
64
  in_channels, out_channels, kernel_size, stride, nempty)
65
- self.bn = torch.nn.BatchNorm1d(out_channels) #, bn_eps, bn_momentum)
65
+ self.bn = torch.nn.BatchNorm1d(out_channels)
66
66
  self.relu = torch.nn.ReLU(inplace=True)
67
67
 
68
68
  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
@@ -86,7 +86,7 @@ class OctreeDeconvBnRelu(torch.nn.Module):
86
86
  super().__init__()
87
87
  self.deconv = OctreeDeconv(
88
88
  in_channels, out_channels, kernel_size, stride, nempty)
89
- self.bn = torch.nn.BatchNorm1d(out_channels) #, bn_eps, bn_momentum)
89
+ self.bn = torch.nn.BatchNorm1d(out_channels)
90
90
  self.relu = torch.nn.ReLU(inplace=True)
91
91
 
92
92
  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
@@ -123,7 +123,7 @@ class Conv1x1Bn(torch.nn.Module):
123
123
  def __init__(self, in_channels: int, out_channels: int):
124
124
  super().__init__()
125
125
  self.conv = Conv1x1(in_channels, out_channels, use_bias=False)
126
- self.bn = torch.nn.BatchNorm1d(out_channels) #, bn_eps, bn_momentum)
126
+ self.bn = torch.nn.BatchNorm1d(out_channels)
127
127
 
128
128
  def forward(self, data: torch.Tensor):
129
129
  r''''''
@@ -140,7 +140,7 @@ class Conv1x1BnRelu(torch.nn.Module):
140
140
  def __init__(self, in_channels: int, out_channels: int):
141
141
  super().__init__()
142
142
  self.conv = Conv1x1(in_channels, out_channels, use_bias=False)
143
- self.bn = torch.nn.BatchNorm1d(out_channels) #, bn_eps, bn_momentum)
143
+ self.bn = torch.nn.BatchNorm1d(out_channels)
144
144
  self.relu = torch.nn.ReLU(inplace=True)
145
145
 
146
146
  def forward(self, data: torch.Tensor):
@@ -160,7 +160,7 @@ class FcBnRelu(torch.nn.Module):
160
160
  super().__init__()
161
161
  self.flatten = torch.nn.Flatten(start_dim=1)
162
162
  self.fc = torch.nn.Linear(in_channels, out_channels, bias=False)
163
- self.bn = torch.nn.BatchNorm1d(out_channels) #, bn_eps, bn_momentum)
163
+ self.bn = torch.nn.BatchNorm1d(out_channels)
164
164
  self.relu = torch.nn.ReLU(inplace=True)
165
165
 
166
166
  def forward(self, data):
@@ -173,6 +173,116 @@ class FcBnRelu(torch.nn.Module):
173
173
  return out
174
174
 
175
175
 
176
+ class OctreeConvGn(torch.nn.Module):
177
+ r''' A sequence of :class:`OctreeConv` and :obj:`OctreeGroupNorm`.
178
+
179
+ Please refer to :class:`ocnn.nn.OctreeConv` for details on the parameters.
180
+ '''
181
+
182
+ def __init__(self, in_channels: int, out_channels: int, group: int,
183
+ kernel_size: List[int] = [3], stride: int = 1,
184
+ nempty: bool = False):
185
+ super().__init__()
186
+ self.conv = OctreeConv(
187
+ in_channels, out_channels, kernel_size, stride, nempty)
188
+ self.gn = OctreeGroupNorm(out_channels, group=group, nempty=nempty)
189
+
190
+ def forward(self, data: torch.Tensor, octree: Octree, depth: int):
191
+ r''''''
192
+
193
+ out = self.conv(data, octree, depth)
194
+ out = self.gn(out, octree, depth)
195
+ return out
196
+
197
+
198
+ class OctreeConvGnRelu(torch.nn.Module):
199
+ r''' A sequence of :class:`OctreeConv`, :obj:`OctreeGroupNorm`, and :obj:`Relu`.
200
+
201
+ Please refer to :class:`ocnn.nn.OctreeConv` for details on the parameters.
202
+ '''
203
+
204
+ def __init__(self, in_channels: int, out_channels: int, group: int,
205
+ kernel_size: List[int] = [3], stride: int = 1,
206
+ nempty: bool = False):
207
+ super().__init__()
208
+ self.stride = stride
209
+ self.conv = OctreeConv(
210
+ in_channels, out_channels, kernel_size, stride, nempty)
211
+ self.gn = OctreeGroupNorm(out_channels, group=group, nempty=nempty)
212
+ self.relu = torch.nn.ReLU(inplace=True)
213
+
214
+ def forward(self, data: torch.Tensor, octree: Octree, depth: int):
215
+ r''''''
216
+
217
+ out = self.conv(data, octree, depth)
218
+ out = self.gn(out, octree, depth if self.stride == 1 else depth - 1)
219
+ out = self.relu(out)
220
+ return out
221
+
222
+
223
+ class OctreeDeconvGnRelu(torch.nn.Module):
224
+ r''' A sequence of :class:`OctreeDeconv`, :obj:`OctreeGroupNorm`, and :obj:`Relu`.
225
+
226
+ Please refer to :class:`ocnn.nn.OctreeConv` for details on the parameters.
227
+ '''
228
+
229
+ def __init__(self, in_channels: int, out_channels: int, group: int,
230
+ kernel_size: List[int] = [3], stride: int = 1,
231
+ nempty: bool = False):
232
+ super().__init__()
233
+ self.stride = stride
234
+ self.deconv = OctreeDeconv(
235
+ in_channels, out_channels, kernel_size, stride, nempty)
236
+ self.gn = OctreeGroupNorm(out_channels, group=group, nempty=nempty)
237
+ self.relu = torch.nn.ReLU(inplace=True)
238
+
239
+ def forward(self, data: torch.Tensor, octree: Octree, depth: int):
240
+ r''''''
241
+
242
+ out = self.deconv(data, octree, depth)
243
+ out = self.gn(out, octree, depth if self.stride == 1 else depth + 1)
244
+ out = self.relu(out)
245
+ return out
246
+
247
+
248
+ class Conv1x1Gn(torch.nn.Module):
249
+ r''' A sequence of :class:`Conv1x1`, :class:`OctreeGroupNorm`.
250
+ '''
251
+
252
+ def __init__(self, in_channels: int, out_channels: int, group: int,
253
+ nempty: bool = False):
254
+ super().__init__()
255
+ self.conv = Conv1x1(in_channels, out_channels, use_bias=False)
256
+ self.gn = OctreeGroupNorm(out_channels, group=group, nempty=nempty)
257
+
258
+ def forward(self, data: torch.Tensor, octree: Octree, depth: int):
259
+ r''''''
260
+
261
+ out = self.conv(data)
262
+ out = self.gn(out, octree, depth)
263
+ return out
264
+
265
+
266
+ class Conv1x1GnRelu(torch.nn.Module):
267
+ r''' A sequence of :class:`Conv1x1`, :class:`OctreeGroupNorm` and :class:`Relu`.
268
+ '''
269
+
270
+ def __init__(self, in_channels: int, out_channels: int, group: int,
271
+ nempty: bool = False):
272
+ super().__init__()
273
+ self.conv = Conv1x1(in_channels, out_channels, use_bias=False)
274
+ self.gn = OctreeGroupNorm(out_channels, group=group, nempty=nempty)
275
+ self.relu = torch.nn.ReLU(inplace=True)
276
+
277
+ def forward(self, data: torch.Tensor, octree: Octree, depth: int):
278
+ r''''''
279
+
280
+ out = self.conv(data)
281
+ out = self.gn(out, octree, depth)
282
+ out = self.relu(out)
283
+ return out
284
+
285
+
176
286
  class InputFeature(torch.nn.Module):
177
287
  r''' Returns the initial input feature stored in octree.
178
288
 
@@ -10,7 +10,9 @@ import torch.utils.checkpoint
10
10
 
11
11
  from ocnn.octree import Octree
12
12
  from ocnn.nn import OctreeMaxPool
13
- from ocnn.modules import Conv1x1BnRelu, OctreeConvBnRelu, Conv1x1Bn, OctreeConvBn
13
+ from ocnn.modules import (Conv1x1BnRelu, OctreeConvBnRelu, Conv1x1Bn,
14
+ OctreeConvBn, OctreeConvGnRelu, Conv1x1Gn,
15
+ OctreeConvGn,)
14
16
 
15
17
 
16
18
  class OctreeResBlock(torch.nn.Module):
@@ -97,6 +99,38 @@ class OctreeResBlock2(torch.nn.Module):
97
99
  return out
98
100
 
99
101
 
102
+ class OctreeResBlockGn(torch.nn.Module):
103
+
104
+ def __init__(self, in_channels: int, out_channels: int, stride: int = 1,
105
+ bottleneck: int = 4, nempty: bool = False, group: int = 32):
106
+ super().__init__()
107
+ self.in_channels = in_channels
108
+ self.out_channels = out_channels
109
+ self.stride = stride
110
+ channelb = int(out_channels / bottleneck)
111
+
112
+ if self.stride == 2:
113
+ self.maxpool = OctreeMaxPool(self.depth)
114
+ self.conv3x3a = OctreeConvGnRelu(in_channels, channelb, group, nempty=nempty)
115
+ self.conv3x3b = OctreeConvGn(channelb, out_channels, group, nempty=nempty)
116
+ if self.in_channels != self.out_channels:
117
+ self.conv1x1 = Conv1x1Gn(in_channels, out_channels, group)
118
+ self.relu = torch.nn.ReLU(inplace=True)
119
+
120
+ def forward(self, data: torch.Tensor, octree: Octree, depth: int):
121
+ r''''''
122
+
123
+ if self.stride == 2:
124
+ data = self.maxpool(data, octree, depth)
125
+ depth = depth - 1
126
+ conv1 = self.conv3x3a(data, octree, depth)
127
+ conv2 = self.conv3x3b(conv1, octree, depth)
128
+ if self.in_channels != self.out_channels:
129
+ data = self.conv1x1(data, octree, depth)
130
+ out = self.relu(conv2 + data)
131
+ return out
132
+
133
+
100
134
  class OctreeResBlocks(torch.nn.Module):
101
135
  r''' A sequence of :attr:`resblk_num` ResNet blocks.
102
136
  '''
@@ -108,9 +142,9 @@ class OctreeResBlocks(torch.nn.Module):
108
142
  self.use_checkpoint = use_checkpoint
109
143
  channels = [in_channels] + [out_channels] * resblk_num
110
144
 
111
- self.resblks = torch.nn.ModuleList(
112
- [resblk(channels[i], channels[i+1], 1, bottleneck, nempty)
113
- for i in range(self.resblk_num)])
145
+ self.resblks = torch.nn.ModuleList([resblk(
146
+ channels[i], channels[i+1], 1, bottleneck, nempty)
147
+ for i in range(self.resblk_num)])
114
148
 
115
149
  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
116
150
  r''''''
@@ -118,7 +152,7 @@ class OctreeResBlocks(torch.nn.Module):
118
152
  for i in range(self.resblk_num):
119
153
  if self.use_checkpoint:
120
154
  data = torch.utils.checkpoint.checkpoint(
121
- self.resblks[i], data, octree, depth)
155
+ self.resblks[i], data, octree, depth, use_reentrant=False)
122
156
  else:
123
157
  data = self.resblks[i](data, octree, depth)
124
158
  return data
@@ -17,7 +17,8 @@ from .octree_pool import (octree_max_pool, OctreeMaxPool,
17
17
  from .octree_conv import OctreeConv, OctreeDeconv
18
18
  from .octree_gconv import OctreeGroupConv
19
19
  from .octree_dwconv import OctreeDWConv
20
- from .octree_norm import OctreeBatchNorm, OctreeGroupNorm, OctreeInstanceNorm
20
+ from .octree_norm import (OctreeBatchNorm, OctreeGroupNorm,
21
+ OctreeInstanceNorm, OctreeNorm)
21
22
  from .octree_drop import OctreeDropPath
22
23
  from .octree_align import search_value, octree_align
23
24
 
@@ -35,7 +36,7 @@ __all__ = [
35
36
  'OctreeConv', 'OctreeDeconv',
36
37
  'OctreeGroupConv', 'OctreeDWConv',
37
38
  'OctreeInterp', 'OctreeUpsample',
38
- 'OctreeInstanceNorm', 'OctreeBatchNorm', 'OctreeGroupNorm',
39
+ 'OctreeInstanceNorm', 'OctreeBatchNorm', 'OctreeGroupNorm', 'OctreeNorm',
39
40
  'OctreeDropPath',
40
41
  'search_value', 'octree_align',
41
42
  ]
@@ -7,6 +7,7 @@
7
7
 
8
8
  import torch
9
9
  import torch.nn
10
+ from typing import Optional
10
11
 
11
12
  from ocnn.octree import Octree
12
13
  from ocnn.utils import scatter_add
@@ -19,15 +20,19 @@ class OctreeGroupNorm(torch.nn.Module):
19
20
  r''' An group normalization layer for the octree.
20
21
  '''
21
22
 
22
- def __init__(self, in_channels: int, group: int, nempty: bool = False):
23
+ def __init__(self, in_channels: int, group: int, nempty: bool = False,
24
+ min_group_channels: int = 4):
23
25
  super().__init__()
24
26
  self.eps = 1e-5
25
27
  self.nempty = nempty
26
28
  self.group = group
27
29
  self.in_channels = in_channels
30
+ self.min_group_channels = min_group_channels
31
+ if self.min_group_channels * self.group > in_channels:
32
+ self.group = in_channels // self.min_group_channels
28
33
 
29
- assert in_channels % group == 0
30
- self.channels_per_group = in_channels // group
34
+ assert in_channels % self.group == 0
35
+ self.channels_per_group = in_channels // self.group
31
36
 
32
37
  self.weights = torch.nn.Parameter(torch.Tensor(1, in_channels))
33
38
  self.bias = torch.nn.Parameter(torch.Tensor(1, in_channels))
@@ -71,8 +76,8 @@ class OctreeGroupNorm(torch.nn.Module):
71
76
  return tensor
72
77
 
73
78
  def extra_repr(self) -> str:
74
- return ('in_channels={}, group={}, nempty={}').format(
75
- self.in_channels, self.group, self.nempty) # noqa
79
+ return ('in_channels={}, group={}, nempty={}, min_group_channels={}').format(
80
+ self.in_channels, self.group, self.nempty, self.min_group_channels)
76
81
 
77
82
 
78
83
  class OctreeInstanceNorm(OctreeGroupNorm):
@@ -80,7 +85,42 @@ class OctreeInstanceNorm(OctreeGroupNorm):
80
85
  '''
81
86
 
82
87
  def __init__(self, in_channels: int, nempty: bool = False):
83
- super().__init__(in_channels=in_channels, group=in_channels, nempty=nempty)
88
+ super().__init__(in_channels=in_channels, group=in_channels, nempty=nempty,
89
+ min_group_channels=1) # NOTE: group=in_channels
84
90
 
85
91
  def extra_repr(self) -> str:
86
92
  return ('in_channels={}, nempty={}').format(self.in_channels, self.nempty)
93
+
94
+
95
+ class OctreeNorm(torch.nn.Module):
96
+ r''' A normalization layer for the octree. It encapsulates octree-based batch,
97
+ group and instance normalization.
98
+ '''
99
+
100
+ def __init__(self, in_channels: int, norm_type: str = 'batch_norm',
101
+ group: int = 32, min_group_channels: int = 4):
102
+ super().__init__()
103
+ self.in_channels = in_channels
104
+ self.norm_type = norm_type
105
+ self.group = group
106
+ self.min_group_channels = min_group_channels
107
+
108
+ if self.norm_type == 'batch_norm':
109
+ self.norm = torch.nn.BatchNorm1d(in_channels)
110
+ elif self.norm_type == 'group_norm':
111
+ self.norm = OctreeGroupNorm(in_channels, group, min_group_channels)
112
+ elif self.norm_type == 'instance_norm':
113
+ self.norm = OctreeInstanceNorm(in_channels)
114
+ else:
115
+ raise ValueError
116
+
117
+ def forward(self, x: torch.Tensor, octree: Optional[Octree] = None,
118
+ depth: Optional[int] = None):
119
+ if self.norm_type == 'batch_norm':
120
+ output = self.norm(x)
121
+ elif (self.norm_type == 'group_norm' or
122
+ self.norm_type == 'instance_norm'):
123
+ output = self.norm(x, octree, depth)
124
+ else:
125
+ raise ValueError
126
+ return output
@@ -63,14 +63,14 @@ class Octree:
63
63
 
64
64
  # octree node numbers in each octree layers.
65
65
  # TODO: decide whether to settle them to 'gpu' or not?
66
- self.nnum = torch.zeros(num, dtype=torch.int32)
67
- self.nnum_nempty = torch.zeros(num, dtype=torch.int32)
66
+ self.nnum = torch.zeros(num, dtype=torch.long)
67
+ self.nnum_nempty = torch.zeros(num, dtype=torch.long)
68
68
 
69
69
  # the following properties are valid after `merge_octrees`.
70
70
  # TODO: make them valid after `octree_grow`, `octree_split` and `build_octree`
71
71
  batch_size = self.batch_size
72
- self.batch_nnum = torch.zeros(num, batch_size, dtype=torch.int32)
73
- self.batch_nnum_nempty = torch.zeros(num, batch_size, dtype=torch.int32)
72
+ self.batch_nnum = torch.zeros(num, batch_size, dtype=torch.long)
73
+ self.batch_nnum_nempty = torch.zeros(num, batch_size, dtype=torch.long)
74
74
 
75
75
  # construct the look up tables for neighborhood searching
76
76
  device = self.device
@@ -290,7 +290,7 @@ class Octree:
290
290
  # node number
291
291
  nnum = self.nnum_nempty[depth-1] * 8
292
292
  self.nnum[depth] = nnum
293
- self.nnum_nempty[depth] = nnum
293
+ self.nnum_nempty[depth] = nnum # initialize self.nnum_nempty
294
294
 
295
295
  # update keys
296
296
  key = self.key(depth-1, nempty=True)
@@ -326,7 +326,7 @@ class Octree:
326
326
  xyz = xyz.view(-1, 3) # (N*27, 3)
327
327
  neigh = xyz2key(xyz[:, 0], xyz[:, 1], xyz[:, 2], depth=depth)
328
328
 
329
- bs = torch.arange(self.batch_size, dtype=torch.int32, device=device)
329
+ bs = torch.arange(self.batch_size, dtype=torch.long, device=device)
330
330
  neigh = neigh + bs.unsqueeze(1) * nnum # (N*27,) + (B, 1) -> (B, N*27)
331
331
 
332
332
  bound = 1 << depth
@@ -383,9 +383,10 @@ class Octree:
383
383
  # I choose `torch.bucketize` here because it has fewer dimension checks,
384
384
  # resulting in slightly better performance according to the docs of
385
385
  # pytorch-1.9.1, since `key` is always 1-D sorted sequence.
386
+ # https://pytorch.org/docs/1.9.1/generated/torch.searchsorted.html
386
387
  idx = torch.bucketize(query, key)
387
388
 
388
- valid = idx < key.shape[0] # valid if NOT out-of-bound
389
+ valid = idx < key.shape[0] # valid if in-bound
389
390
  found = key[idx[valid]] == query[valid]
390
391
  valid[valid.clone()] = found # valid if found
391
392
  idx[valid.logical_not()] = -1 # set to -1 if invalid
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ocnn
3
- Version: 2.2.3
3
+ Version: 2.2.4
4
4
  Summary: Octree-based Sparse Convolutional Neural Networks
5
5
  Home-page: https://github.com/octree-nn/ocnn-pytorch
6
6
  Author: Peng-Shuai Wang
@@ -23,6 +23,7 @@ Requires-Dist: packaging
23
23
 
24
24
  [![Documentation Status](https://readthedocs.org/projects/ocnn-pytorch/badge/?version=latest)](https://ocnn-pytorch.readthedocs.io/en/latest/?badge=latest)
25
25
  [![Downloads](https://static.pepy.tech/badge/ocnn)](https://pepy.tech/project/ocnn)
26
+ [![Downloads](https://static.pepy.tech/badge/ocnn/month)](https://pepy.tech/project/ocnn)
26
27
  [![PyPI](https://img.shields.io/pypi/v/ocnn)](https://pypi.org/project/ocnn/)
27
28
 
28
29
  This repository contains the **pure PyTorch**-based implementation of
@@ -7,7 +7,7 @@
7
7
 
8
8
  from setuptools import setup, find_packages
9
9
 
10
- __version__ = '2.2.3'
10
+ __version__ = '2.2.4'
11
11
 
12
12
  with open("README.md", "r", encoding="utf-8") as fid:
13
13
  long_description = fid.read()
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes