ocnn 2.2.3__tar.gz → 2.2.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {ocnn-2.2.3/ocnn.egg-info → ocnn-2.2.4}/PKG-INFO +2 -1
- {ocnn-2.2.3 → ocnn-2.2.4}/README.md +1 -0
- {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/__init__.py +1 -1
- {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/models/autoencoder.py +0 -1
- {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/models/ounet.py +3 -2
- {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/models/resnet.py +1 -1
- {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/models/segnet.py +2 -2
- {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/modules/__init__.py +9 -3
- {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/modules/modules.py +117 -7
- {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/modules/resblocks.py +39 -5
- {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/nn/__init__.py +3 -2
- {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/nn/octree_norm.py +46 -6
- {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/octree/octree.py +8 -7
- {ocnn-2.2.3 → ocnn-2.2.4/ocnn.egg-info}/PKG-INFO +2 -1
- {ocnn-2.2.3 → ocnn-2.2.4}/setup.py +1 -1
- {ocnn-2.2.3 → ocnn-2.2.4}/LICENSE +0 -0
- {ocnn-2.2.3 → ocnn-2.2.4}/MANIFEST.in +0 -0
- {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/dataset.py +0 -0
- {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/models/__init__.py +0 -0
- {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/models/hrnet.py +0 -0
- {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/models/image2shape.py +0 -0
- {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/models/lenet.py +0 -0
- {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/models/unet.py +0 -0
- {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/nn/octree2col.py +0 -0
- {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/nn/octree2vox.py +0 -0
- {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/nn/octree_align.py +0 -0
- {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/nn/octree_conv.py +0 -0
- {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/nn/octree_drop.py +0 -0
- {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/nn/octree_dwconv.py +0 -0
- {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/nn/octree_gconv.py +0 -0
- {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/nn/octree_interp.py +0 -0
- {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/nn/octree_pad.py +0 -0
- {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/nn/octree_pool.py +0 -0
- {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/octree/__init__.py +0 -0
- {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/octree/points.py +0 -0
- {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/octree/shuffled_key.py +0 -0
- {ocnn-2.2.3 → ocnn-2.2.4}/ocnn/utils.py +0 -0
- {ocnn-2.2.3 → ocnn-2.2.4}/ocnn.egg-info/SOURCES.txt +0 -0
- {ocnn-2.2.3 → ocnn-2.2.4}/ocnn.egg-info/dependency_links.txt +0 -0
- {ocnn-2.2.3 → ocnn-2.2.4}/ocnn.egg-info/not-zip-safe +0 -0
- {ocnn-2.2.3 → ocnn-2.2.4}/ocnn.egg-info/requires.txt +0 -0
- {ocnn-2.2.3 → ocnn-2.2.4}/ocnn.egg-info/top_level.txt +0 -0
- {ocnn-2.2.3 → ocnn-2.2.4}/setup.cfg +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: ocnn
|
|
3
|
-
Version: 2.2.3
|
|
3
|
+
Version: 2.2.4
|
|
4
4
|
Summary: Octree-based Sparse Convolutional Neural Networks
|
|
5
5
|
Home-page: https://github.com/octree-nn/ocnn-pytorch
|
|
6
6
|
Author: Peng-Shuai Wang
|
|
@@ -23,6 +23,7 @@ Requires-Dist: packaging
|
|
|
23
23
|
|
|
24
24
|
[](https://ocnn-pytorch.readthedocs.io/en/latest/?badge=latest)
|
|
25
25
|
[](https://pepy.tech/project/ocnn)
|
|
26
|
+
[](https://pepy.tech/project/ocnn)
|
|
26
27
|
[](https://pypi.org/project/ocnn/)
|
|
27
28
|
|
|
28
29
|
This repository contains the **pure PyTorch**-based implementation of
|
|
@@ -4,6 +4,7 @@
|
|
|
4
4
|
|
|
5
5
|
[](https://ocnn-pytorch.readthedocs.io/en/latest/?badge=latest)
|
|
6
6
|
[](https://pepy.tech/project/ocnn)
|
|
7
|
+
[](https://pepy.tech/project/ocnn)
|
|
7
8
|
[](https://pypi.org/project/ocnn/)
|
|
8
9
|
|
|
9
10
|
This repository contains the **pure PyTorch**-based implementation of
|
|
@@ -23,12 +23,13 @@ class OUNet(AutoEncoder):
|
|
|
23
23
|
self.proj = None # remove this module used in AutoEncoder
|
|
24
24
|
|
|
25
25
|
def encoder(self, octree):
|
|
26
|
-
r''' The encoder network for extracting heirarchy features.
|
|
26
|
+
r''' The encoder network for extracting heirarchy features.
|
|
27
27
|
'''
|
|
28
28
|
|
|
29
29
|
convs = dict()
|
|
30
30
|
depth, full_depth = self.depth, self.full_depth
|
|
31
|
-
data =
|
|
31
|
+
data = octree.get_input_feature(self.feature, nempty=False)
|
|
32
|
+
assert data.size(1) == self.channel_in
|
|
32
33
|
convs[depth] = self.conv1(data, octree, depth)
|
|
33
34
|
for i, d in enumerate(range(depth, full_depth-1, -1)):
|
|
34
35
|
convs[d] = self.encoder_blks[i](convs[d], octree, d)
|
|
@@ -31,7 +31,7 @@ class ResNet(torch.nn.Module):
|
|
|
31
31
|
channels[i], channels[i+1], resblock_num, nempty=nempty)
|
|
32
32
|
for i in range(stages-1)])
|
|
33
33
|
self.pools = torch.nn.ModuleList([ocnn.nn.OctreeMaxPool(
|
|
34
|
-
nempty) for
|
|
34
|
+
nempty) for _ in range(stages-1)])
|
|
35
35
|
self.global_pool = ocnn.nn.OctreeGlobalPool(nempty)
|
|
36
36
|
# self.header = torch.nn.Linear(channels[-1], out_channels, bias=True)
|
|
37
37
|
self.header = torch.nn.Sequential(
|
|
@@ -29,7 +29,7 @@ class SegNet(torch.nn.Module):
|
|
|
29
29
|
[ocnn.modules.OctreeConvBnRelu(channels[i], channels[i+1], nempty=nempty)
|
|
30
30
|
for i in range(stages)])
|
|
31
31
|
self.pools = torch.nn.ModuleList(
|
|
32
|
-
[ocnn.nn.OctreeMaxPool(nempty, return_indices) for
|
|
32
|
+
[ocnn.nn.OctreeMaxPool(nempty, return_indices) for _ in range(stages)])
|
|
33
33
|
|
|
34
34
|
self.bottleneck = ocnn.modules.OctreeConvBnRelu(channels[-1], channels[-1])
|
|
35
35
|
|
|
@@ -38,7 +38,7 @@ class SegNet(torch.nn.Module):
|
|
|
38
38
|
[ocnn.modules.OctreeConvBnRelu(channels[i], channels[i+1], nempty=nempty)
|
|
39
39
|
for i in range(0, stages)])
|
|
40
40
|
self.unpools = torch.nn.ModuleList(
|
|
41
|
-
[ocnn.nn.OctreeMaxUnpool(nempty) for
|
|
41
|
+
[ocnn.nn.OctreeMaxUnpool(nempty) for _ in range(stages)])
|
|
42
42
|
|
|
43
43
|
self.octree_interp = ocnn.nn.OctreeInterp(interp, nempty)
|
|
44
44
|
self.header = torch.nn.Sequential(
|
|
@@ -7,14 +7,20 @@
|
|
|
7
7
|
|
|
8
8
|
from .modules import (InputFeature,
|
|
9
9
|
OctreeConvBn, OctreeConvBnRelu, OctreeDeconvBnRelu,
|
|
10
|
-
Conv1x1, Conv1x1Bn, Conv1x1BnRelu, FcBnRelu,
|
|
11
|
-
|
|
10
|
+
Conv1x1, Conv1x1Bn, Conv1x1BnRelu, FcBnRelu,
|
|
11
|
+
OctreeConvGn, OctreeConvGnRelu, OctreeDeconvGnRelu,
|
|
12
|
+
Conv1x1, Conv1x1Gn, Conv1x1GnRelu)
|
|
13
|
+
from .resblocks import (OctreeResBlock, OctreeResBlock2, OctreeResBlockGn,
|
|
14
|
+
OctreeResBlocks,)
|
|
12
15
|
|
|
13
16
|
__all__ = [
|
|
14
17
|
'InputFeature',
|
|
15
18
|
'OctreeConvBn', 'OctreeConvBnRelu', 'OctreeDeconvBnRelu',
|
|
16
19
|
'Conv1x1', 'Conv1x1Bn', 'Conv1x1BnRelu', 'FcBnRelu',
|
|
17
|
-
'
|
|
20
|
+
'OctreeConvGn', 'OctreeConvGnRelu', 'OctreeDeconvGnRelu',
|
|
21
|
+
'Conv1x1', 'Conv1x1Gn', 'Conv1x1GnRelu',
|
|
22
|
+
'OctreeResBlock', 'OctreeResBlock2', 'OctreeResBlockGn',
|
|
23
|
+
'OctreeResBlocks',
|
|
18
24
|
]
|
|
19
25
|
|
|
20
26
|
classes = __all__
|
|
@@ -9,7 +9,7 @@ import torch
|
|
|
9
9
|
import torch.utils.checkpoint
|
|
10
10
|
from typing import List
|
|
11
11
|
|
|
12
|
-
from ocnn.nn import OctreeConv, OctreeDeconv
|
|
12
|
+
from ocnn.nn import OctreeConv, OctreeDeconv, OctreeGroupNorm
|
|
13
13
|
from ocnn.octree import Octree
|
|
14
14
|
|
|
15
15
|
|
|
@@ -40,7 +40,7 @@ class OctreeConvBn(torch.nn.Module):
|
|
|
40
40
|
super().__init__()
|
|
41
41
|
self.conv = OctreeConv(
|
|
42
42
|
in_channels, out_channels, kernel_size, stride, nempty)
|
|
43
|
-
self.bn = torch.nn.BatchNorm1d(out_channels)
|
|
43
|
+
self.bn = torch.nn.BatchNorm1d(out_channels)
|
|
44
44
|
|
|
45
45
|
def forward(self, data: torch.Tensor, octree: Octree, depth: int):
|
|
46
46
|
r''''''
|
|
@@ -62,7 +62,7 @@ class OctreeConvBnRelu(torch.nn.Module):
|
|
|
62
62
|
super().__init__()
|
|
63
63
|
self.conv = OctreeConv(
|
|
64
64
|
in_channels, out_channels, kernel_size, stride, nempty)
|
|
65
|
-
self.bn = torch.nn.BatchNorm1d(out_channels)
|
|
65
|
+
self.bn = torch.nn.BatchNorm1d(out_channels)
|
|
66
66
|
self.relu = torch.nn.ReLU(inplace=True)
|
|
67
67
|
|
|
68
68
|
def forward(self, data: torch.Tensor, octree: Octree, depth: int):
|
|
@@ -86,7 +86,7 @@ class OctreeDeconvBnRelu(torch.nn.Module):
|
|
|
86
86
|
super().__init__()
|
|
87
87
|
self.deconv = OctreeDeconv(
|
|
88
88
|
in_channels, out_channels, kernel_size, stride, nempty)
|
|
89
|
-
self.bn = torch.nn.BatchNorm1d(out_channels)
|
|
89
|
+
self.bn = torch.nn.BatchNorm1d(out_channels)
|
|
90
90
|
self.relu = torch.nn.ReLU(inplace=True)
|
|
91
91
|
|
|
92
92
|
def forward(self, data: torch.Tensor, octree: Octree, depth: int):
|
|
@@ -123,7 +123,7 @@ class Conv1x1Bn(torch.nn.Module):
|
|
|
123
123
|
def __init__(self, in_channels: int, out_channels: int):
|
|
124
124
|
super().__init__()
|
|
125
125
|
self.conv = Conv1x1(in_channels, out_channels, use_bias=False)
|
|
126
|
-
self.bn = torch.nn.BatchNorm1d(out_channels)
|
|
126
|
+
self.bn = torch.nn.BatchNorm1d(out_channels)
|
|
127
127
|
|
|
128
128
|
def forward(self, data: torch.Tensor):
|
|
129
129
|
r''''''
|
|
@@ -140,7 +140,7 @@ class Conv1x1BnRelu(torch.nn.Module):
|
|
|
140
140
|
def __init__(self, in_channels: int, out_channels: int):
|
|
141
141
|
super().__init__()
|
|
142
142
|
self.conv = Conv1x1(in_channels, out_channels, use_bias=False)
|
|
143
|
-
self.bn = torch.nn.BatchNorm1d(out_channels)
|
|
143
|
+
self.bn = torch.nn.BatchNorm1d(out_channels)
|
|
144
144
|
self.relu = torch.nn.ReLU(inplace=True)
|
|
145
145
|
|
|
146
146
|
def forward(self, data: torch.Tensor):
|
|
@@ -160,7 +160,7 @@ class FcBnRelu(torch.nn.Module):
|
|
|
160
160
|
super().__init__()
|
|
161
161
|
self.flatten = torch.nn.Flatten(start_dim=1)
|
|
162
162
|
self.fc = torch.nn.Linear(in_channels, out_channels, bias=False)
|
|
163
|
-
self.bn = torch.nn.BatchNorm1d(out_channels)
|
|
163
|
+
self.bn = torch.nn.BatchNorm1d(out_channels)
|
|
164
164
|
self.relu = torch.nn.ReLU(inplace=True)
|
|
165
165
|
|
|
166
166
|
def forward(self, data):
|
|
@@ -173,6 +173,116 @@ class FcBnRelu(torch.nn.Module):
|
|
|
173
173
|
return out
|
|
174
174
|
|
|
175
175
|
|
|
176
|
+
class OctreeConvGn(torch.nn.Module):
|
|
177
|
+
r''' A sequence of :class:`OctreeConv` and :obj:`OctreeGroupNorm`.
|
|
178
|
+
|
|
179
|
+
Please refer to :class:`ocnn.nn.OctreeConv` for details on the parameters.
|
|
180
|
+
'''
|
|
181
|
+
|
|
182
|
+
def __init__(self, in_channels: int, out_channels: int, group: int,
|
|
183
|
+
kernel_size: List[int] = [3], stride: int = 1,
|
|
184
|
+
nempty: bool = False):
|
|
185
|
+
super().__init__()
|
|
186
|
+
self.conv = OctreeConv(
|
|
187
|
+
in_channels, out_channels, kernel_size, stride, nempty)
|
|
188
|
+
self.gn = OctreeGroupNorm(out_channels, group=group, nempty=nempty)
|
|
189
|
+
|
|
190
|
+
def forward(self, data: torch.Tensor, octree: Octree, depth: int):
|
|
191
|
+
r''''''
|
|
192
|
+
|
|
193
|
+
out = self.conv(data, octree, depth)
|
|
194
|
+
out = self.gn(out, octree, depth)
|
|
195
|
+
return out
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
class OctreeConvGnRelu(torch.nn.Module):
|
|
199
|
+
r''' A sequence of :class:`OctreeConv`, :obj:`OctreeGroupNorm`, and :obj:`Relu`.
|
|
200
|
+
|
|
201
|
+
Please refer to :class:`ocnn.nn.OctreeConv` for details on the parameters.
|
|
202
|
+
'''
|
|
203
|
+
|
|
204
|
+
def __init__(self, in_channels: int, out_channels: int, group: int,
|
|
205
|
+
kernel_size: List[int] = [3], stride: int = 1,
|
|
206
|
+
nempty: bool = False):
|
|
207
|
+
super().__init__()
|
|
208
|
+
self.stride = stride
|
|
209
|
+
self.conv = OctreeConv(
|
|
210
|
+
in_channels, out_channels, kernel_size, stride, nempty)
|
|
211
|
+
self.gn = OctreeGroupNorm(out_channels, group=group, nempty=nempty)
|
|
212
|
+
self.relu = torch.nn.ReLU(inplace=True)
|
|
213
|
+
|
|
214
|
+
def forward(self, data: torch.Tensor, octree: Octree, depth: int):
|
|
215
|
+
r''''''
|
|
216
|
+
|
|
217
|
+
out = self.conv(data, octree, depth)
|
|
218
|
+
out = self.gn(out, octree, depth if self.stride == 1 else depth - 1)
|
|
219
|
+
out = self.relu(out)
|
|
220
|
+
return out
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
class OctreeDeconvGnRelu(torch.nn.Module):
|
|
224
|
+
r''' A sequence of :class:`OctreeDeconv`, :obj:`OctreeGroupNorm`, and :obj:`Relu`.
|
|
225
|
+
|
|
226
|
+
Please refer to :class:`ocnn.nn.OctreeConv` for details on the parameters.
|
|
227
|
+
'''
|
|
228
|
+
|
|
229
|
+
def __init__(self, in_channels: int, out_channels: int, group: int,
|
|
230
|
+
kernel_size: List[int] = [3], stride: int = 1,
|
|
231
|
+
nempty: bool = False):
|
|
232
|
+
super().__init__()
|
|
233
|
+
self.stride = stride
|
|
234
|
+
self.deconv = OctreeDeconv(
|
|
235
|
+
in_channels, out_channels, kernel_size, stride, nempty)
|
|
236
|
+
self.gn = OctreeGroupNorm(out_channels, group=group, nempty=nempty)
|
|
237
|
+
self.relu = torch.nn.ReLU(inplace=True)
|
|
238
|
+
|
|
239
|
+
def forward(self, data: torch.Tensor, octree: Octree, depth: int):
|
|
240
|
+
r''''''
|
|
241
|
+
|
|
242
|
+
out = self.deconv(data, octree, depth)
|
|
243
|
+
out = self.gn(out, octree, depth if self.stride == 1 else depth + 1)
|
|
244
|
+
out = self.relu(out)
|
|
245
|
+
return out
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
class Conv1x1Gn(torch.nn.Module):
|
|
249
|
+
r''' A sequence of :class:`Conv1x1`, :class:`OctreeGroupNorm`.
|
|
250
|
+
'''
|
|
251
|
+
|
|
252
|
+
def __init__(self, in_channels: int, out_channels: int, group: int,
|
|
253
|
+
nempty: bool = False):
|
|
254
|
+
super().__init__()
|
|
255
|
+
self.conv = Conv1x1(in_channels, out_channels, use_bias=False)
|
|
256
|
+
self.gn = OctreeGroupNorm(out_channels, group=group, nempty=nempty)
|
|
257
|
+
|
|
258
|
+
def forward(self, data: torch.Tensor, octree: Octree, depth: int):
|
|
259
|
+
r''''''
|
|
260
|
+
|
|
261
|
+
out = self.conv(data)
|
|
262
|
+
out = self.gn(out, octree, depth)
|
|
263
|
+
return out
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
class Conv1x1GnRelu(torch.nn.Module):
|
|
267
|
+
r''' A sequence of :class:`Conv1x1`, :class:`OctreeGroupNorm` and :class:`Relu`.
|
|
268
|
+
'''
|
|
269
|
+
|
|
270
|
+
def __init__(self, in_channels: int, out_channels: int, group: int,
|
|
271
|
+
nempty: bool = False):
|
|
272
|
+
super().__init__()
|
|
273
|
+
self.conv = Conv1x1(in_channels, out_channels, use_bias=False)
|
|
274
|
+
self.gn = OctreeGroupNorm(out_channels, group=group, nempty=nempty)
|
|
275
|
+
self.relu = torch.nn.ReLU(inplace=True)
|
|
276
|
+
|
|
277
|
+
def forward(self, data: torch.Tensor, octree: Octree, depth: int):
|
|
278
|
+
r''''''
|
|
279
|
+
|
|
280
|
+
out = self.conv(data)
|
|
281
|
+
out = self.gn(out, octree, depth)
|
|
282
|
+
out = self.relu(out)
|
|
283
|
+
return out
|
|
284
|
+
|
|
285
|
+
|
|
176
286
|
class InputFeature(torch.nn.Module):
|
|
177
287
|
r''' Returns the initial input feature stored in octree.
|
|
178
288
|
|
|
@@ -10,7 +10,9 @@ import torch.utils.checkpoint
|
|
|
10
10
|
|
|
11
11
|
from ocnn.octree import Octree
|
|
12
12
|
from ocnn.nn import OctreeMaxPool
|
|
13
|
-
from ocnn.modules import Conv1x1BnRelu, OctreeConvBnRelu, Conv1x1Bn,
|
|
13
|
+
from ocnn.modules import (Conv1x1BnRelu, OctreeConvBnRelu, Conv1x1Bn,
|
|
14
|
+
OctreeConvBn, OctreeConvGnRelu, Conv1x1Gn,
|
|
15
|
+
OctreeConvGn,)
|
|
14
16
|
|
|
15
17
|
|
|
16
18
|
class OctreeResBlock(torch.nn.Module):
|
|
@@ -97,6 +99,38 @@ class OctreeResBlock2(torch.nn.Module):
|
|
|
97
99
|
return out
|
|
98
100
|
|
|
99
101
|
|
|
102
|
+
class OctreeResBlockGn(torch.nn.Module):
|
|
103
|
+
|
|
104
|
+
def __init__(self, in_channels: int, out_channels: int, stride: int = 1,
|
|
105
|
+
bottleneck: int = 4, nempty: bool = False, group: int = 32):
|
|
106
|
+
super().__init__()
|
|
107
|
+
self.in_channels = in_channels
|
|
108
|
+
self.out_channels = out_channels
|
|
109
|
+
self.stride = stride
|
|
110
|
+
channelb = int(out_channels / bottleneck)
|
|
111
|
+
|
|
112
|
+
if self.stride == 2:
|
|
113
|
+
self.maxpool = OctreeMaxPool(self.depth)
|
|
114
|
+
self.conv3x3a = OctreeConvGnRelu(in_channels, channelb, group, nempty=nempty)
|
|
115
|
+
self.conv3x3b = OctreeConvGn(channelb, out_channels, group, nempty=nempty)
|
|
116
|
+
if self.in_channels != self.out_channels:
|
|
117
|
+
self.conv1x1 = Conv1x1Gn(in_channels, out_channels, group)
|
|
118
|
+
self.relu = torch.nn.ReLU(inplace=True)
|
|
119
|
+
|
|
120
|
+
def forward(self, data: torch.Tensor, octree: Octree, depth: int):
|
|
121
|
+
r''''''
|
|
122
|
+
|
|
123
|
+
if self.stride == 2:
|
|
124
|
+
data = self.maxpool(data, octree, depth)
|
|
125
|
+
depth = depth - 1
|
|
126
|
+
conv1 = self.conv3x3a(data, octree, depth)
|
|
127
|
+
conv2 = self.conv3x3b(conv1, octree, depth)
|
|
128
|
+
if self.in_channels != self.out_channels:
|
|
129
|
+
data = self.conv1x1(data, octree, depth)
|
|
130
|
+
out = self.relu(conv2 + data)
|
|
131
|
+
return out
|
|
132
|
+
|
|
133
|
+
|
|
100
134
|
class OctreeResBlocks(torch.nn.Module):
|
|
101
135
|
r''' A sequence of :attr:`resblk_num` ResNet blocks.
|
|
102
136
|
'''
|
|
@@ -108,9 +142,9 @@ class OctreeResBlocks(torch.nn.Module):
|
|
|
108
142
|
self.use_checkpoint = use_checkpoint
|
|
109
143
|
channels = [in_channels] + [out_channels] * resblk_num
|
|
110
144
|
|
|
111
|
-
self.resblks = torch.nn.ModuleList(
|
|
112
|
-
|
|
113
|
-
|
|
145
|
+
self.resblks = torch.nn.ModuleList([resblk(
|
|
146
|
+
channels[i], channels[i+1], 1, bottleneck, nempty)
|
|
147
|
+
for i in range(self.resblk_num)])
|
|
114
148
|
|
|
115
149
|
def forward(self, data: torch.Tensor, octree: Octree, depth: int):
|
|
116
150
|
r''''''
|
|
@@ -118,7 +152,7 @@ class OctreeResBlocks(torch.nn.Module):
|
|
|
118
152
|
for i in range(self.resblk_num):
|
|
119
153
|
if self.use_checkpoint:
|
|
120
154
|
data = torch.utils.checkpoint.checkpoint(
|
|
121
|
-
self.resblks[i], data, octree, depth)
|
|
155
|
+
self.resblks[i], data, octree, depth, use_reentrant=False)
|
|
122
156
|
else:
|
|
123
157
|
data = self.resblks[i](data, octree, depth)
|
|
124
158
|
return data
|
|
@@ -17,7 +17,8 @@ from .octree_pool import (octree_max_pool, OctreeMaxPool,
|
|
|
17
17
|
from .octree_conv import OctreeConv, OctreeDeconv
|
|
18
18
|
from .octree_gconv import OctreeGroupConv
|
|
19
19
|
from .octree_dwconv import OctreeDWConv
|
|
20
|
-
from .octree_norm import OctreeBatchNorm, OctreeGroupNorm,
|
|
20
|
+
from .octree_norm import (OctreeBatchNorm, OctreeGroupNorm,
|
|
21
|
+
OctreeInstanceNorm, OctreeNorm)
|
|
21
22
|
from .octree_drop import OctreeDropPath
|
|
22
23
|
from .octree_align import search_value, octree_align
|
|
23
24
|
|
|
@@ -35,7 +36,7 @@ __all__ = [
|
|
|
35
36
|
'OctreeConv', 'OctreeDeconv',
|
|
36
37
|
'OctreeGroupConv', 'OctreeDWConv',
|
|
37
38
|
'OctreeInterp', 'OctreeUpsample',
|
|
38
|
-
'OctreeInstanceNorm', 'OctreeBatchNorm', 'OctreeGroupNorm',
|
|
39
|
+
'OctreeInstanceNorm', 'OctreeBatchNorm', 'OctreeGroupNorm', 'OctreeNorm',
|
|
39
40
|
'OctreeDropPath',
|
|
40
41
|
'search_value', 'octree_align',
|
|
41
42
|
]
|
|
@@ -7,6 +7,7 @@
|
|
|
7
7
|
|
|
8
8
|
import torch
|
|
9
9
|
import torch.nn
|
|
10
|
+
from typing import Optional
|
|
10
11
|
|
|
11
12
|
from ocnn.octree import Octree
|
|
12
13
|
from ocnn.utils import scatter_add
|
|
@@ -19,15 +20,19 @@ class OctreeGroupNorm(torch.nn.Module):
|
|
|
19
20
|
r''' An group normalization layer for the octree.
|
|
20
21
|
'''
|
|
21
22
|
|
|
22
|
-
def __init__(self, in_channels: int, group: int, nempty: bool = False):
|
|
23
|
+
def __init__(self, in_channels: int, group: int, nempty: bool = False,
|
|
24
|
+
min_group_channels: int = 4):
|
|
23
25
|
super().__init__()
|
|
24
26
|
self.eps = 1e-5
|
|
25
27
|
self.nempty = nempty
|
|
26
28
|
self.group = group
|
|
27
29
|
self.in_channels = in_channels
|
|
30
|
+
self.min_group_channels = min_group_channels
|
|
31
|
+
if self.min_group_channels * self.group > in_channels:
|
|
32
|
+
self.group = in_channels // self.min_group_channels
|
|
28
33
|
|
|
29
|
-
assert in_channels % group == 0
|
|
30
|
-
self.channels_per_group = in_channels // group
|
|
34
|
+
assert in_channels % self.group == 0
|
|
35
|
+
self.channels_per_group = in_channels // self.group
|
|
31
36
|
|
|
32
37
|
self.weights = torch.nn.Parameter(torch.Tensor(1, in_channels))
|
|
33
38
|
self.bias = torch.nn.Parameter(torch.Tensor(1, in_channels))
|
|
@@ -71,8 +76,8 @@ class OctreeGroupNorm(torch.nn.Module):
|
|
|
71
76
|
return tensor
|
|
72
77
|
|
|
73
78
|
def extra_repr(self) -> str:
|
|
74
|
-
return ('in_channels={}, group={}, nempty={}').format(
|
|
75
|
-
|
|
79
|
+
return ('in_channels={}, group={}, nempty={}, min_group_channels={}').format(
|
|
80
|
+
self.in_channels, self.group, self.nempty, self.min_group_channels)
|
|
76
81
|
|
|
77
82
|
|
|
78
83
|
class OctreeInstanceNorm(OctreeGroupNorm):
|
|
@@ -80,7 +85,42 @@ class OctreeInstanceNorm(OctreeGroupNorm):
|
|
|
80
85
|
'''
|
|
81
86
|
|
|
82
87
|
def __init__(self, in_channels: int, nempty: bool = False):
|
|
83
|
-
super().__init__(in_channels=in_channels, group=in_channels, nempty=nempty)
|
|
88
|
+
super().__init__(in_channels=in_channels, group=in_channels, nempty=nempty,
|
|
89
|
+
min_group_channels=1) # NOTE: group=in_channels
|
|
84
90
|
|
|
85
91
|
def extra_repr(self) -> str:
|
|
86
92
|
return ('in_channels={}, nempty={}').format(self.in_channels, self.nempty)
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
class OctreeNorm(torch.nn.Module):
|
|
96
|
+
r''' A normalization layer for the octree. It encapsulates octree-based batch,
|
|
97
|
+
group and instance normalization.
|
|
98
|
+
'''
|
|
99
|
+
|
|
100
|
+
def __init__(self, in_channels: int, norm_type: str = 'batch_norm',
|
|
101
|
+
group: int = 32, min_group_channels: int = 4):
|
|
102
|
+
super().__init__()
|
|
103
|
+
self.in_channels = in_channels
|
|
104
|
+
self.norm_type = norm_type
|
|
105
|
+
self.group = group
|
|
106
|
+
self.min_group_channels = min_group_channels
|
|
107
|
+
|
|
108
|
+
if self.norm_type == 'batch_norm':
|
|
109
|
+
self.norm = torch.nn.BatchNorm1d(in_channels)
|
|
110
|
+
elif self.norm_type == 'group_norm':
|
|
111
|
+
self.norm = OctreeGroupNorm(in_channels, group, min_group_channels)
|
|
112
|
+
elif self.norm_type == 'instance_norm':
|
|
113
|
+
self.norm = OctreeInstanceNorm(in_channels)
|
|
114
|
+
else:
|
|
115
|
+
raise ValueError
|
|
116
|
+
|
|
117
|
+
def forward(self, x: torch.Tensor, octree: Optional[Octree] = None,
|
|
118
|
+
depth: Optional[int] = None):
|
|
119
|
+
if self.norm_type == 'batch_norm':
|
|
120
|
+
output = self.norm(x)
|
|
121
|
+
elif (self.norm_type == 'group_norm' or
|
|
122
|
+
self.norm_type == 'instance_norm'):
|
|
123
|
+
output = self.norm(x, octree, depth)
|
|
124
|
+
else:
|
|
125
|
+
raise ValueError
|
|
126
|
+
return output
|
|
@@ -63,14 +63,14 @@ class Octree:
|
|
|
63
63
|
|
|
64
64
|
# octree node numbers in each octree layers.
|
|
65
65
|
# TODO: decide whether to settle them to 'gpu' or not?
|
|
66
|
-
self.nnum = torch.zeros(num, dtype=torch.
|
|
67
|
-
self.nnum_nempty = torch.zeros(num, dtype=torch.
|
|
66
|
+
self.nnum = torch.zeros(num, dtype=torch.long)
|
|
67
|
+
self.nnum_nempty = torch.zeros(num, dtype=torch.long)
|
|
68
68
|
|
|
69
69
|
# the following properties are valid after `merge_octrees`.
|
|
70
70
|
# TODO: make them valid after `octree_grow`, `octree_split` and `build_octree`
|
|
71
71
|
batch_size = self.batch_size
|
|
72
|
-
self.batch_nnum = torch.zeros(num, batch_size, dtype=torch.
|
|
73
|
-
self.batch_nnum_nempty = torch.zeros(num, batch_size, dtype=torch.
|
|
72
|
+
self.batch_nnum = torch.zeros(num, batch_size, dtype=torch.long)
|
|
73
|
+
self.batch_nnum_nempty = torch.zeros(num, batch_size, dtype=torch.long)
|
|
74
74
|
|
|
75
75
|
# construct the look up tables for neighborhood searching
|
|
76
76
|
device = self.device
|
|
@@ -290,7 +290,7 @@ class Octree:
|
|
|
290
290
|
# node number
|
|
291
291
|
nnum = self.nnum_nempty[depth-1] * 8
|
|
292
292
|
self.nnum[depth] = nnum
|
|
293
|
-
self.nnum_nempty[depth] = nnum
|
|
293
|
+
self.nnum_nempty[depth] = nnum # initialize self.nnum_nempty
|
|
294
294
|
|
|
295
295
|
# update keys
|
|
296
296
|
key = self.key(depth-1, nempty=True)
|
|
@@ -326,7 +326,7 @@ class Octree:
|
|
|
326
326
|
xyz = xyz.view(-1, 3) # (N*27, 3)
|
|
327
327
|
neigh = xyz2key(xyz[:, 0], xyz[:, 1], xyz[:, 2], depth=depth)
|
|
328
328
|
|
|
329
|
-
bs = torch.arange(self.batch_size, dtype=torch.
|
|
329
|
+
bs = torch.arange(self.batch_size, dtype=torch.long, device=device)
|
|
330
330
|
neigh = neigh + bs.unsqueeze(1) * nnum # (N*27,) + (B, 1) -> (B, N*27)
|
|
331
331
|
|
|
332
332
|
bound = 1 << depth
|
|
@@ -383,9 +383,10 @@ class Octree:
|
|
|
383
383
|
# I choose `torch.bucketize` here because it has fewer dimension checks,
|
|
384
384
|
# resulting in slightly better performance according to the docs of
|
|
385
385
|
# pytorch-1.9.1, since `key` is always 1-D sorted sequence.
|
|
386
|
+
# https://pytorch.org/docs/1.9.1/generated/torch.searchsorted.html
|
|
386
387
|
idx = torch.bucketize(query, key)
|
|
387
388
|
|
|
388
|
-
valid = idx < key.shape[0] # valid if
|
|
389
|
+
valid = idx < key.shape[0] # valid if in-bound
|
|
389
390
|
found = key[idx[valid]] == query[valid]
|
|
390
391
|
valid[valid.clone()] = found # valid if found
|
|
391
392
|
idx[valid.logical_not()] = -1 # set to -1 if invalid
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: ocnn
|
|
3
|
-
Version: 2.2.3
|
|
3
|
+
Version: 2.2.4
|
|
4
4
|
Summary: Octree-based Sparse Convolutional Neural Networks
|
|
5
5
|
Home-page: https://github.com/octree-nn/ocnn-pytorch
|
|
6
6
|
Author: Peng-Shuai Wang
|
|
@@ -23,6 +23,7 @@ Requires-Dist: packaging
|
|
|
23
23
|
|
|
24
24
|
[](https://ocnn-pytorch.readthedocs.io/en/latest/?badge=latest)
|
|
25
25
|
[](https://pepy.tech/project/ocnn)
|
|
26
|
+
[](https://pepy.tech/project/ocnn)
|
|
26
27
|
[](https://pypi.org/project/ocnn/)
|
|
27
28
|
|
|
28
29
|
This repository contains the **pure PyTorch**-based implementation of
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|