ocnn 2.2.3.tar.gz → 2.2.5.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {ocnn-2.2.3/ocnn.egg-info → ocnn-2.2.5}/PKG-INFO +2 -1
- {ocnn-2.2.3 → ocnn-2.2.5}/README.md +1 -0
- {ocnn-2.2.3 → ocnn-2.2.5}/ocnn/__init__.py +1 -1
- {ocnn-2.2.3 → ocnn-2.2.5}/ocnn/models/autoencoder.py +2 -2
- {ocnn-2.2.3 → ocnn-2.2.5}/ocnn/models/ounet.py +4 -4
- {ocnn-2.2.3 → ocnn-2.2.5}/ocnn/models/resnet.py +1 -1
- {ocnn-2.2.3 → ocnn-2.2.5}/ocnn/models/segnet.py +2 -2
- {ocnn-2.2.3 → ocnn-2.2.5}/ocnn/modules/__init__.py +9 -3
- {ocnn-2.2.3 → ocnn-2.2.5}/ocnn/modules/modules.py +117 -7
- {ocnn-2.2.3 → ocnn-2.2.5}/ocnn/modules/resblocks.py +39 -5
- {ocnn-2.2.3 → ocnn-2.2.5}/ocnn/nn/__init__.py +3 -2
- {ocnn-2.2.3 → ocnn-2.2.5}/ocnn/nn/octree_norm.py +46 -6
- {ocnn-2.2.3 → ocnn-2.2.5}/ocnn/octree/octree.py +29 -9
- {ocnn-2.2.3 → ocnn-2.2.5/ocnn.egg-info}/PKG-INFO +2 -1
- {ocnn-2.2.3 → ocnn-2.2.5}/setup.py +1 -1
- {ocnn-2.2.3 → ocnn-2.2.5}/LICENSE +0 -0
- {ocnn-2.2.3 → ocnn-2.2.5}/MANIFEST.in +0 -0
- {ocnn-2.2.3 → ocnn-2.2.5}/ocnn/dataset.py +0 -0
- {ocnn-2.2.3 → ocnn-2.2.5}/ocnn/models/__init__.py +0 -0
- {ocnn-2.2.3 → ocnn-2.2.5}/ocnn/models/hrnet.py +0 -0
- {ocnn-2.2.3 → ocnn-2.2.5}/ocnn/models/image2shape.py +0 -0
- {ocnn-2.2.3 → ocnn-2.2.5}/ocnn/models/lenet.py +0 -0
- {ocnn-2.2.3 → ocnn-2.2.5}/ocnn/models/unet.py +0 -0
- {ocnn-2.2.3 → ocnn-2.2.5}/ocnn/nn/octree2col.py +0 -0
- {ocnn-2.2.3 → ocnn-2.2.5}/ocnn/nn/octree2vox.py +0 -0
- {ocnn-2.2.3 → ocnn-2.2.5}/ocnn/nn/octree_align.py +0 -0
- {ocnn-2.2.3 → ocnn-2.2.5}/ocnn/nn/octree_conv.py +0 -0
- {ocnn-2.2.3 → ocnn-2.2.5}/ocnn/nn/octree_drop.py +0 -0
- {ocnn-2.2.3 → ocnn-2.2.5}/ocnn/nn/octree_dwconv.py +0 -0
- {ocnn-2.2.3 → ocnn-2.2.5}/ocnn/nn/octree_gconv.py +0 -0
- {ocnn-2.2.3 → ocnn-2.2.5}/ocnn/nn/octree_interp.py +0 -0
- {ocnn-2.2.3 → ocnn-2.2.5}/ocnn/nn/octree_pad.py +0 -0
- {ocnn-2.2.3 → ocnn-2.2.5}/ocnn/nn/octree_pool.py +0 -0
- {ocnn-2.2.3 → ocnn-2.2.5}/ocnn/octree/__init__.py +0 -0
- {ocnn-2.2.3 → ocnn-2.2.5}/ocnn/octree/points.py +0 -0
- {ocnn-2.2.3 → ocnn-2.2.5}/ocnn/octree/shuffled_key.py +0 -0
- {ocnn-2.2.3 → ocnn-2.2.5}/ocnn/utils.py +0 -0
- {ocnn-2.2.3 → ocnn-2.2.5}/ocnn.egg-info/SOURCES.txt +0 -0
- {ocnn-2.2.3 → ocnn-2.2.5}/ocnn.egg-info/dependency_links.txt +0 -0
- {ocnn-2.2.3 → ocnn-2.2.5}/ocnn.egg-info/not-zip-safe +0 -0
- {ocnn-2.2.3 → ocnn-2.2.5}/ocnn.egg-info/requires.txt +0 -0
- {ocnn-2.2.3 → ocnn-2.2.5}/ocnn.egg-info/top_level.txt +0 -0
- {ocnn-2.2.3 → ocnn-2.2.5}/setup.cfg +0 -0
{ocnn-2.2.3/ocnn.egg-info → ocnn-2.2.5}/PKG-INFO

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ocnn
-Version: 2.2.3
+Version: 2.2.5
 Summary: Octree-based Sparse Convolutional Neural Networks
 Home-page: https://github.com/octree-nn/ocnn-pytorch
 Author: Peng-Shuai Wang
@@ -23,6 +23,7 @@ Requires-Dist: packaging
 
 [](https://ocnn-pytorch.readthedocs.io/en/latest/?badge=latest)
 [](https://pepy.tech/project/ocnn)
+[](https://pepy.tech/project/ocnn)
 [](https://pypi.org/project/ocnn/)
 
 This repository contains the **pure PyTorch**-based implementation of
```
{ocnn-2.2.3 → ocnn-2.2.5}/README.md

```diff
@@ -4,6 +4,7 @@
 
 [](https://ocnn-pytorch.readthedocs.io/en/latest/?badge=latest)
 [](https://pepy.tech/project/ocnn)
+[](https://pepy.tech/project/ocnn)
 [](https://pypi.org/project/ocnn/)
 
 This repository contains the **pure PyTorch**-based implementation of
```
{ocnn-2.2.3 → ocnn-2.2.5}/ocnn/models/autoencoder.py

```diff
@@ -7,7 +7,6 @@
 
 import torch
 import torch.nn
-from typing import Optional
 
 import ocnn
 from ocnn.octree import Octree
@@ -34,8 +33,9 @@ class AutoEncoder(torch.nn.Module):
     self.full_depth = full_depth
     self.feature = feature
     self.resblk_num = 2
-    self.code_channel = 64  # dim-of-code = code_channel * 2**(3*full_depth)
     self.channels = [512, 512, 256, 256, 128, 128, 32, 32, 16, 16]
+    # dim-of-code = code_channel * 2**(3*full_depth)
+    self.code_channel = self.channels[full_depth]
 
     # encoder
     self.conv1 = ocnn.modules.OctreeConvBnRelu(
```
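For reference, a minimal sketch (not part of the package) of what the new code_channel computation implies for the latent code size, using the channels list and the dim-of-code comment shown above:

```python
# Sketch only: reproduces the arithmetic implied by the diff above.
channels = [512, 512, 256, 256, 128, 128, 32, 32, 16, 16]

full_depth = 2
code_channel = channels[full_depth]              # 256 in 2.2.5; was hard-coded to 64 in 2.2.3
code_dim = code_channel * 2 ** (3 * full_depth)  # "dim-of-code" per the comment
print(code_channel, code_dim)                    # 256 16384
```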
{ocnn-2.2.3 → ocnn-2.2.5}/ocnn/models/ounet.py

```diff
@@ -18,17 +18,17 @@ class OUNet(AutoEncoder):
 
   def __init__(self, channel_in: int, channel_out: int, depth: int,
                full_depth: int = 2, feature: str = 'ND'):
-    super().__init__(channel_in, channel_out, depth, full_depth, feature
-                     code_channel=-1)  # !set code_channe=-1
+    super().__init__(channel_in, channel_out, depth, full_depth, feature)
     self.proj = None  # remove this module used in AutoEncoder
 
   def encoder(self, octree):
-    r''' The encoder network for extracting heirarchy features.
+    r''' The encoder network for extracting heirarchy features.
     '''
 
     convs = dict()
     depth, full_depth = self.depth, self.full_depth
-    data =
+    data = octree.get_input_feature(self.feature, nempty=False)
+    assert data.size(1) == self.channel_in
     convs[depth] = self.conv1(data, octree, depth)
     for i, d in enumerate(range(depth, full_depth-1, -1)):
       convs[d] = self.encoder_blks[i](convs[d], octree, d)
```
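A hedged usage sketch of the updated constructor: OUNet no longer forwards code_channel=-1 to AutoEncoder, and its encoder now reads the input feature directly from the octree and asserts its channel count. The channel_in value of 4 is an assumption matching the default feature string 'ND'; check ocnn.modules.InputFeature for the channel count of your feature setting, and that OUNet is exported from ocnn.models in your install.

```python
import ocnn

# Assumes ocnn >= 2.2.5 is installed and OUNet is exported from ocnn.models.
# channel_in=4 is an assumed value for feature='ND'; the encoder asserts
# data.size(1) == channel_in at runtime.
model = ocnn.models.OUNet(channel_in=4, channel_out=4, depth=6, full_depth=2)
# features = model.encoder(octree)   # octree: an ocnn.octree.Octree built from your points
```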
{ocnn-2.2.3 → ocnn-2.2.5}/ocnn/models/resnet.py

```diff
@@ -31,7 +31,7 @@ class ResNet(torch.nn.Module):
         channels[i], channels[i+1], resblock_num, nempty=nempty)
         for i in range(stages-1)])
     self.pools = torch.nn.ModuleList([ocnn.nn.OctreeMaxPool(
-        nempty) for
+        nempty) for _ in range(stages-1)])
     self.global_pool = ocnn.nn.OctreeGlobalPool(nempty)
     # self.header = torch.nn.Linear(channels[-1], out_channels, bias=True)
     self.header = torch.nn.Sequential(
```
{ocnn-2.2.3 → ocnn-2.2.5}/ocnn/models/segnet.py

```diff
@@ -29,7 +29,7 @@ class SegNet(torch.nn.Module):
         [ocnn.modules.OctreeConvBnRelu(channels[i], channels[i+1], nempty=nempty)
          for i in range(stages)])
     self.pools = torch.nn.ModuleList(
-        [ocnn.nn.OctreeMaxPool(nempty, return_indices) for
+        [ocnn.nn.OctreeMaxPool(nempty, return_indices) for _ in range(stages)])
 
     self.bottleneck = ocnn.modules.OctreeConvBnRelu(channels[-1], channels[-1])
 
@@ -38,7 +38,7 @@ class SegNet(torch.nn.Module):
         [ocnn.modules.OctreeConvBnRelu(channels[i], channels[i+1], nempty=nempty)
          for i in range(0, stages)])
     self.unpools = torch.nn.ModuleList(
-        [ocnn.nn.OctreeMaxUnpool(nempty) for
+        [ocnn.nn.OctreeMaxUnpool(nempty) for _ in range(stages)])
 
     self.octree_interp = ocnn.nn.OctreeInterp(interp, nempty)
     self.header = torch.nn.Sequential(
```
{ocnn-2.2.3 → ocnn-2.2.5}/ocnn/modules/__init__.py

```diff
@@ -7,14 +7,20 @@
 
 from .modules import (InputFeature,
                       OctreeConvBn, OctreeConvBnRelu, OctreeDeconvBnRelu,
-                      Conv1x1, Conv1x1Bn, Conv1x1BnRelu, FcBnRelu,
-
+                      Conv1x1, Conv1x1Bn, Conv1x1BnRelu, FcBnRelu,
+                      OctreeConvGn, OctreeConvGnRelu, OctreeDeconvGnRelu,
+                      Conv1x1, Conv1x1Gn, Conv1x1GnRelu)
+from .resblocks import (OctreeResBlock, OctreeResBlock2, OctreeResBlockGn,
+                        OctreeResBlocks,)
 
 __all__ = [
     'InputFeature',
     'OctreeConvBn', 'OctreeConvBnRelu', 'OctreeDeconvBnRelu',
     'Conv1x1', 'Conv1x1Bn', 'Conv1x1BnRelu', 'FcBnRelu',
-    '
+    'OctreeConvGn', 'OctreeConvGnRelu', 'OctreeDeconvGnRelu',
+    'Conv1x1', 'Conv1x1Gn', 'Conv1x1GnRelu',
+    'OctreeResBlock', 'OctreeResBlock2', 'OctreeResBlockGn',
+    'OctreeResBlocks',
 ]
 
 classes = __all__
```
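A quick import check of the names newly exported here (a sketch, assuming ocnn 2.2.5 is installed):

```python
# The GroupNorm-flavoured blocks now ship alongside the BatchNorm ones.
from ocnn.modules import (OctreeConvGn, OctreeConvGnRelu, OctreeDeconvGnRelu,
                          Conv1x1Gn, Conv1x1GnRelu,
                          OctreeResBlockGn, OctreeResBlocks)
```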
{ocnn-2.2.3 → ocnn-2.2.5}/ocnn/modules/modules.py

```diff
@@ -9,7 +9,7 @@ import torch
 import torch.utils.checkpoint
 from typing import List
 
-from ocnn.nn import OctreeConv, OctreeDeconv
+from ocnn.nn import OctreeConv, OctreeDeconv, OctreeGroupNorm
 from ocnn.octree import Octree
 
 
@@ -40,7 +40,7 @@ class OctreeConvBn(torch.nn.Module):
     super().__init__()
     self.conv = OctreeConv(
         in_channels, out_channels, kernel_size, stride, nempty)
-    self.bn = torch.nn.BatchNorm1d(out_channels)
+    self.bn = torch.nn.BatchNorm1d(out_channels)
 
   def forward(self, data: torch.Tensor, octree: Octree, depth: int):
     r''''''
@@ -62,7 +62,7 @@ class OctreeConvBnRelu(torch.nn.Module):
     super().__init__()
     self.conv = OctreeConv(
         in_channels, out_channels, kernel_size, stride, nempty)
-    self.bn = torch.nn.BatchNorm1d(out_channels)
+    self.bn = torch.nn.BatchNorm1d(out_channels)
     self.relu = torch.nn.ReLU(inplace=True)
 
   def forward(self, data: torch.Tensor, octree: Octree, depth: int):
@@ -86,7 +86,7 @@ class OctreeDeconvBnRelu(torch.nn.Module):
     super().__init__()
     self.deconv = OctreeDeconv(
         in_channels, out_channels, kernel_size, stride, nempty)
-    self.bn = torch.nn.BatchNorm1d(out_channels)
+    self.bn = torch.nn.BatchNorm1d(out_channels)
     self.relu = torch.nn.ReLU(inplace=True)
 
   def forward(self, data: torch.Tensor, octree: Octree, depth: int):
@@ -123,7 +123,7 @@ class Conv1x1Bn(torch.nn.Module):
   def __init__(self, in_channels: int, out_channels: int):
     super().__init__()
     self.conv = Conv1x1(in_channels, out_channels, use_bias=False)
-    self.bn = torch.nn.BatchNorm1d(out_channels)
+    self.bn = torch.nn.BatchNorm1d(out_channels)
 
   def forward(self, data: torch.Tensor):
     r''''''
@@ -140,7 +140,7 @@ class Conv1x1BnRelu(torch.nn.Module):
   def __init__(self, in_channels: int, out_channels: int):
     super().__init__()
     self.conv = Conv1x1(in_channels, out_channels, use_bias=False)
-    self.bn = torch.nn.BatchNorm1d(out_channels)
+    self.bn = torch.nn.BatchNorm1d(out_channels)
     self.relu = torch.nn.ReLU(inplace=True)
 
   def forward(self, data: torch.Tensor):
@@ -160,7 +160,7 @@ class FcBnRelu(torch.nn.Module):
     super().__init__()
     self.flatten = torch.nn.Flatten(start_dim=1)
     self.fc = torch.nn.Linear(in_channels, out_channels, bias=False)
-    self.bn = torch.nn.BatchNorm1d(out_channels)
+    self.bn = torch.nn.BatchNorm1d(out_channels)
     self.relu = torch.nn.ReLU(inplace=True)
 
   def forward(self, data):
@@ -173,6 +173,116 @@ class FcBnRelu(torch.nn.Module):
     return out
 
 
+class OctreeConvGn(torch.nn.Module):
+  r''' A sequence of :class:`OctreeConv` and :obj:`OctreeGroupNorm`.
+
+  Please refer to :class:`ocnn.nn.OctreeConv` for details on the parameters.
+  '''
+
+  def __init__(self, in_channels: int, out_channels: int, group: int,
+               kernel_size: List[int] = [3], stride: int = 1,
+               nempty: bool = False):
+    super().__init__()
+    self.conv = OctreeConv(
+        in_channels, out_channels, kernel_size, stride, nempty)
+    self.gn = OctreeGroupNorm(out_channels, group=group, nempty=nempty)
+
+  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
+    r''''''
+
+    out = self.conv(data, octree, depth)
+    out = self.gn(out, octree, depth)
+    return out
+
+
+class OctreeConvGnRelu(torch.nn.Module):
+  r''' A sequence of :class:`OctreeConv`, :obj:`OctreeGroupNorm`, and :obj:`Relu`.
+
+  Please refer to :class:`ocnn.nn.OctreeConv` for details on the parameters.
+  '''
+
+  def __init__(self, in_channels: int, out_channels: int, group: int,
+               kernel_size: List[int] = [3], stride: int = 1,
+               nempty: bool = False):
+    super().__init__()
+    self.stride = stride
+    self.conv = OctreeConv(
+        in_channels, out_channels, kernel_size, stride, nempty)
+    self.gn = OctreeGroupNorm(out_channels, group=group, nempty=nempty)
+    self.relu = torch.nn.ReLU(inplace=True)
+
+  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
+    r''''''
+
+    out = self.conv(data, octree, depth)
+    out = self.gn(out, octree, depth if self.stride == 1 else depth - 1)
+    out = self.relu(out)
+    return out
+
+
+class OctreeDeconvGnRelu(torch.nn.Module):
+  r''' A sequence of :class:`OctreeDeconv`, :obj:`OctreeGroupNorm`, and :obj:`Relu`.
+
+  Please refer to :class:`ocnn.nn.OctreeConv` for details on the parameters.
+  '''
+
+  def __init__(self, in_channels: int, out_channels: int, group: int,
+               kernel_size: List[int] = [3], stride: int = 1,
+               nempty: bool = False):
+    super().__init__()
+    self.stride = stride
+    self.deconv = OctreeDeconv(
+        in_channels, out_channels, kernel_size, stride, nempty)
+    self.gn = OctreeGroupNorm(out_channels, group=group, nempty=nempty)
+    self.relu = torch.nn.ReLU(inplace=True)
+
+  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
+    r''''''
+
+    out = self.deconv(data, octree, depth)
+    out = self.gn(out, octree, depth if self.stride == 1 else depth + 1)
+    out = self.relu(out)
+    return out
+
+
+class Conv1x1Gn(torch.nn.Module):
+  r''' A sequence of :class:`Conv1x1`, :class:`OctreeGroupNorm`.
+  '''
+
+  def __init__(self, in_channels: int, out_channels: int, group: int,
+               nempty: bool = False):
+    super().__init__()
+    self.conv = Conv1x1(in_channels, out_channels, use_bias=False)
+    self.gn = OctreeGroupNorm(out_channels, group=group, nempty=nempty)
+
+  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
+    r''''''
+
+    out = self.conv(data)
+    out = self.gn(out, octree, depth)
+    return out
+
+
+class Conv1x1GnRelu(torch.nn.Module):
+  r''' A sequence of :class:`Conv1x1`, :class:`OctreeGroupNorm` and :class:`Relu`.
+  '''
+
+  def __init__(self, in_channels: int, out_channels: int, group: int,
+               nempty: bool = False):
+    super().__init__()
+    self.conv = Conv1x1(in_channels, out_channels, use_bias=False)
+    self.gn = OctreeGroupNorm(out_channels, group=group, nempty=nempty)
+    self.relu = torch.nn.ReLU(inplace=True)
+
+  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
+    r''''''
+
+    out = self.conv(data)
+    out = self.gn(out, octree, depth)
+    out = self.relu(out)
+    return out
+
+
 class InputFeature(torch.nn.Module):
   r''' Returns the initial input feature stored in octree.
 
```
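The new *Gn modules mirror the existing *Bn ones but swap torch.nn.BatchNorm1d for the octree-aware OctreeGroupNorm, and they normalize at the coarser (or finer, for deconvolution) depth when stride == 2. A hedged construction sketch; group=32 is only an example value, and forward calls need an Octree and a depth:

```python
import ocnn

# BatchNorm-based block (unchanged API) next to its new GroupNorm counterpart.
conv_bn = ocnn.modules.OctreeConvBnRelu(in_channels=32, out_channels=64)
conv_gn = ocnn.modules.OctreeConvGnRelu(in_channels=32, out_channels=64, group=32)

# Both share the forward signature (data, octree, depth), e.g.:
# out = conv_gn(data, octree, depth)   # data: (num_nodes, 32) features at `depth`
```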
{ocnn-2.2.3 → ocnn-2.2.5}/ocnn/modules/resblocks.py

```diff
@@ -10,7 +10,9 @@ import torch.utils.checkpoint
 
 from ocnn.octree import Octree
 from ocnn.nn import OctreeMaxPool
-from ocnn.modules import Conv1x1BnRelu, OctreeConvBnRelu, Conv1x1Bn,
+from ocnn.modules import (Conv1x1BnRelu, OctreeConvBnRelu, Conv1x1Bn,
+                          OctreeConvBn, OctreeConvGnRelu, Conv1x1Gn,
+                          OctreeConvGn,)
 
 
 class OctreeResBlock(torch.nn.Module):
@@ -97,6 +99,38 @@ class OctreeResBlock2(torch.nn.Module):
     return out
 
 
+class OctreeResBlockGn(torch.nn.Module):
+
+  def __init__(self, in_channels: int, out_channels: int, stride: int = 1,
+               bottleneck: int = 4, nempty: bool = False, group: int = 32):
+    super().__init__()
+    self.in_channels = in_channels
+    self.out_channels = out_channels
+    self.stride = stride
+    channelb = int(out_channels / bottleneck)
+
+    if self.stride == 2:
+      self.maxpool = OctreeMaxPool(self.depth)
+    self.conv3x3a = OctreeConvGnRelu(in_channels, channelb, group, nempty=nempty)
+    self.conv3x3b = OctreeConvGn(channelb, out_channels, group, nempty=nempty)
+    if self.in_channels != self.out_channels:
+      self.conv1x1 = Conv1x1Gn(in_channels, out_channels, group)
+    self.relu = torch.nn.ReLU(inplace=True)
+
+  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
+    r''''''
+
+    if self.stride == 2:
+      data = self.maxpool(data, octree, depth)
+      depth = depth - 1
+    conv1 = self.conv3x3a(data, octree, depth)
+    conv2 = self.conv3x3b(conv1, octree, depth)
+    if self.in_channels != self.out_channels:
+      data = self.conv1x1(data, octree, depth)
+    out = self.relu(conv2 + data)
+    return out
+
+
 class OctreeResBlocks(torch.nn.Module):
   r''' A sequence of :attr:`resblk_num` ResNet blocks.
   '''
@@ -108,9 +142,9 @@ class OctreeResBlocks(torch.nn.Module):
     self.use_checkpoint = use_checkpoint
     channels = [in_channels] + [out_channels] * resblk_num
 
-    self.resblks = torch.nn.ModuleList(
-
-
+    self.resblks = torch.nn.ModuleList([resblk(
+        channels[i], channels[i+1], 1, bottleneck, nempty)
+        for i in range(self.resblk_num)])
 
   def forward(self, data: torch.Tensor, octree: Octree, depth: int):
     r''''''
@@ -118,7 +152,7 @@ class OctreeResBlocks(torch.nn.Module):
     for i in range(self.resblk_num):
       if self.use_checkpoint:
         data = torch.utils.checkpoint.checkpoint(
-            self.resblks[i], data, octree, depth)
+            self.resblks[i], data, octree, depth, use_reentrant=False)
       else:
         data = self.resblks[i](data, octree, depth)
     return data
```
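OctreeResBlocks builds resblk_num blocks from the resblk factory (see the rebuilt ModuleList above), and with use_checkpoint=True each block now runs under torch.utils.checkpoint with use_reentrant=False. A hedged sketch; the keyword names follow the attributes visible in the diff, so verify them against the installed signature:

```python
from ocnn.modules import OctreeResBlocks, OctreeResBlockGn

# Stack two GroupNorm residual blocks with activation checkpointing enabled.
blocks = OctreeResBlocks(in_channels=64, out_channels=64, resblk_num=2,
                         resblk=OctreeResBlockGn, use_checkpoint=True)
# out = blocks(data, octree, depth)
```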
{ocnn-2.2.3 → ocnn-2.2.5}/ocnn/nn/__init__.py

```diff
@@ -17,7 +17,8 @@ from .octree_pool import (octree_max_pool, OctreeMaxPool,
 from .octree_conv import OctreeConv, OctreeDeconv
 from .octree_gconv import OctreeGroupConv
 from .octree_dwconv import OctreeDWConv
-from .octree_norm import OctreeBatchNorm, OctreeGroupNorm,
+from .octree_norm import (OctreeBatchNorm, OctreeGroupNorm,
+                          OctreeInstanceNorm, OctreeNorm)
 from .octree_drop import OctreeDropPath
 from .octree_align import search_value, octree_align
 
@@ -35,7 +36,7 @@ __all__ = [
     'OctreeConv', 'OctreeDeconv',
     'OctreeGroupConv', 'OctreeDWConv',
     'OctreeInterp', 'OctreeUpsample',
-    'OctreeInstanceNorm', 'OctreeBatchNorm', 'OctreeGroupNorm',
+    'OctreeInstanceNorm', 'OctreeBatchNorm', 'OctreeGroupNorm', 'OctreeNorm',
     'OctreeDropPath',
     'search_value', 'octree_align',
 ]
```
{ocnn-2.2.3 → ocnn-2.2.5}/ocnn/nn/octree_norm.py

```diff
@@ -7,6 +7,7 @@
 
 import torch
 import torch.nn
+from typing import Optional
 
 from ocnn.octree import Octree
 from ocnn.utils import scatter_add
@@ -19,15 +20,19 @@ class OctreeGroupNorm(torch.nn.Module):
   r''' An group normalization layer for the octree.
   '''
 
-  def __init__(self, in_channels: int, group: int, nempty: bool = False
+  def __init__(self, in_channels: int, group: int, nempty: bool = False,
+               min_group_channels: int = 4):
     super().__init__()
     self.eps = 1e-5
     self.nempty = nempty
     self.group = group
     self.in_channels = in_channels
+    self.min_group_channels = min_group_channels
+    if self.min_group_channels * self.group > in_channels:
+      self.group = in_channels // self.min_group_channels
 
-    assert in_channels % group == 0
-    self.channels_per_group = in_channels // group
+    assert in_channels % self.group == 0
+    self.channels_per_group = in_channels // self.group
 
     self.weights = torch.nn.Parameter(torch.Tensor(1, in_channels))
     self.bias = torch.nn.Parameter(torch.Tensor(1, in_channels))
@@ -71,8 +76,8 @@ class OctreeGroupNorm(torch.nn.Module):
     return tensor
 
   def extra_repr(self) -> str:
-    return ('in_channels={}, group={}, nempty={}').format(
-
+    return ('in_channels={}, group={}, nempty={}, min_group_channels={}').format(
+        self.in_channels, self.group, self.nempty, self.min_group_channels)
 
 
 class OctreeInstanceNorm(OctreeGroupNorm):
@@ -80,7 +85,42 @@ class OctreeInstanceNorm(OctreeGroupNorm):
   '''
 
   def __init__(self, in_channels: int, nempty: bool = False):
-    super().__init__(in_channels=in_channels, group=in_channels, nempty=nempty
+    super().__init__(in_channels=in_channels, group=in_channels, nempty=nempty,
+                     min_group_channels=1)  # NOTE: group=in_channels
 
   def extra_repr(self) -> str:
     return ('in_channels={}, nempty={}').format(self.in_channels, self.nempty)
+
+
+class OctreeNorm(torch.nn.Module):
+  r''' A normalization layer for the octree. It encapsulates octree-based batch,
+  group and instance normalization.
+  '''
+
+  def __init__(self, in_channels: int, norm_type: str = 'batch_norm',
+               group: int = 32, min_group_channels: int = 4):
+    super().__init__()
+    self.in_channels = in_channels
+    self.norm_type = norm_type
+    self.group = group
+    self.min_group_channels = min_group_channels
+
+    if self.norm_type == 'batch_norm':
+      self.norm = torch.nn.BatchNorm1d(in_channels)
+    elif self.norm_type == 'group_norm':
+      self.norm = OctreeGroupNorm(in_channels, group, min_group_channels)
+    elif self.norm_type == 'instance_norm':
+      self.norm = OctreeInstanceNorm(in_channels)
+    else:
+      raise ValueError
+
+  def forward(self, x: torch.Tensor, octree: Optional[Octree] = None,
+              depth: Optional[int] = None):
+    if self.norm_type == 'batch_norm':
+      output = self.norm(x)
+    elif (self.norm_type == 'group_norm' or
+          self.norm_type == 'instance_norm'):
+      output = self.norm(x, octree, depth)
+    else:
+      raise ValueError
+    return output
```
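A hedged sketch of the two behaviours introduced here: OctreeGroupNorm now shrinks the requested group count so that each group keeps at least min_group_channels channels, and OctreeNorm selects between batch, group and instance normalization behind one interface (octree and depth can be omitted for 'batch_norm'):

```python
import torch
from ocnn.nn import OctreeGroupNorm, OctreeNorm

# 32 groups requested for only 16 channels: with min_group_channels=4 (default),
# the effective group count becomes 16 // 4 = 4.
gn = OctreeGroupNorm(in_channels=16, group=32)
print(gn.group, gn.channels_per_group)   # 4 4

# Unified wrapper; 'batch_norm' falls back to torch.nn.BatchNorm1d.
norm = OctreeNorm(in_channels=16, norm_type='batch_norm')
out = norm(torch.randn(8, 16))
```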
{ocnn-2.2.3 → ocnn-2.2.5}/ocnn/octree/octree.py

```diff
@@ -35,7 +35,9 @@ class Octree:
   and :obj:`points`, contain only non-empty nodes.
 
   .. note::
-    The point cloud must be in range :obj:`[-1, 1]`.
+    The point cloud must be strictly in range :obj:`[-1, 1]`. A good practice
+    is to normalize it into :obj:`[-0.99, 0.99]` or :obj:`[0.9, 0.9]` to retain
+    some margin.
   '''
 
   def __init__(self, depth: int, full_depth: int = 2, batch_size: int = 1,
@@ -63,14 +65,14 @@ class Octree:
 
     # octree node numbers in each octree layers.
     # TODO: decide whether to settle them to 'gpu' or not?
-    self.nnum = torch.zeros(num, dtype=torch.
-    self.nnum_nempty = torch.zeros(num, dtype=torch.
+    self.nnum = torch.zeros(num, dtype=torch.long)
+    self.nnum_nempty = torch.zeros(num, dtype=torch.long)
 
     # the following properties are valid after `merge_octrees`.
     # TODO: make them valid after `octree_grow`, `octree_split` and `build_octree`
     batch_size = self.batch_size
-    self.batch_nnum = torch.zeros(num, batch_size, dtype=torch.
-    self.batch_nnum_nempty = torch.zeros(num, batch_size, dtype=torch.
+    self.batch_nnum = torch.zeros(num, batch_size, dtype=torch.long)
+    self.batch_nnum_nempty = torch.zeros(num, batch_size, dtype=torch.long)
 
     # construct the look up tables for neighborhood searching
     device = self.device
@@ -176,7 +178,7 @@ class Octree:
     for d in range(self.depth, self.full_depth, -1):
       # compute parent key, i.e. keys of layer (d -1)
       pkey = node_key >> 3
-      pkey, pidx,
+      pkey, pidx, _ = torch.unique_consecutive(
           pkey, return_inverse=True, return_counts=True)
 
       # augmented key
@@ -287,10 +289,27 @@ class Octree:
       update_neigh (bool): If True, construct the neighborhood indices.
     '''
 
+    # increase the octree depth if required
+    if depth > self.depth:
+      assert depth == self.depth + 1
+      self.depth = depth
+      self.keys.append(None)
+      self.children.append(None)
+      self.neighs.append(None)
+      self.features.append(None)
+      self.normals.append(None)
+      self.points.append(None)
+      zero = torch.zeros(1, dtype=torch.long)
+      self.nnum = torch.cat([self.nnum, zero])
+      self.nnum_nempty = torch.cat([self.nnum_nempty, zero])
+      zero = zero.view(1, 1)
+      self.batch_nnum = torch.cat([self.batch_nnum, zero], dim=0)
+      self.batch_nnum_nempty = torch.cat([self.batch_nnum_nempty, zero], dim=0)
+
     # node number
     nnum = self.nnum_nempty[depth-1] * 8
     self.nnum[depth] = nnum
-    self.nnum_nempty[depth] = nnum
+    self.nnum_nempty[depth] = nnum  # initialize self.nnum_nempty
 
     # update keys
     key = self.key(depth-1, nempty=True)
@@ -326,7 +345,7 @@ class Octree:
     xyz = xyz.view(-1, 3)  # (N*27, 3)
     neigh = xyz2key(xyz[:, 0], xyz[:, 1], xyz[:, 2], depth=depth)
 
-    bs = torch.arange(self.batch_size, dtype=torch.
+    bs = torch.arange(self.batch_size, dtype=torch.long, device=device)
     neigh = neigh + bs.unsqueeze(1) * nnum  # (N*27,) + (B, 1) -> (B, N*27)
 
     bound = 1 << depth
@@ -383,9 +402,10 @@ class Octree:
     # I choose `torch.bucketize` here because it has fewer dimension checks,
     # resulting in slightly better performance according to the docs of
    # pytorch-1.9.1, since `key` is always 1-D sorted sequence.
+    # https://pytorch.org/docs/1.9.1/generated/torch.searchsorted.html
     idx = torch.bucketize(query, key)
 
-    valid = idx < key.shape[0]  # valid if
+    valid = idx < key.shape[0]  # valid if in-bound
     found = key[idx[valid]] == query[valid]
     valid[valid.clone()] = found  # valid if found
     idx[valid.logical_not()] = -1  # set to -1 if invalid
```
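Following the updated docstring note, a small sketch (not from the package) for scaling a raw point cloud into a cube slightly inside [-1, 1] before constructing the octree; the 0.99 margin mirrors the suggestion above:

```python
import torch

def normalize_points(points: torch.Tensor, margin: float = 0.99) -> torch.Tensor:
  # points: (N, 3) raw coordinates; returns coordinates in [-margin, margin].
  center = (points.max(dim=0).values + points.min(dim=0).values) / 2
  scale = (points - center).abs().max()
  return (points - center) / (scale + 1e-8) * margin

pts = normalize_points(torch.rand(1024, 3) * 10.0)
assert pts.abs().max() <= 0.99
```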
{ocnn-2.2.3 → ocnn-2.2.5/ocnn.egg-info}/PKG-INFO

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ocnn
-Version: 2.2.3
+Version: 2.2.5
 Summary: Octree-based Sparse Convolutional Neural Networks
 Home-page: https://github.com/octree-nn/ocnn-pytorch
 Author: Peng-Shuai Wang
@@ -23,6 +23,7 @@ Requires-Dist: packaging
 
 [](https://ocnn-pytorch.readthedocs.io/en/latest/?badge=latest)
 [](https://pepy.tech/project/ocnn)
+[](https://pepy.tech/project/ocnn)
 [](https://pypi.org/project/ocnn/)
 
 This repository contains the **pure PyTorch**-based implementation of
```
All remaining files listed above with +0 -0 are unchanged between ocnn 2.2.3 and 2.2.5.