ocnn 2.2.8__py3-none-any.whl → 2.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ocnn/__init__.py +24 -24
- ocnn/dataset.py +160 -160
- ocnn/models/__init__.py +29 -29
- ocnn/models/autoencoder.py +155 -155
- ocnn/models/hrnet.py +192 -192
- ocnn/models/image2shape.py +128 -128
- ocnn/models/lenet.py +46 -46
- ocnn/models/ounet.py +94 -94
- ocnn/models/resnet.py +53 -53
- ocnn/models/segnet.py +72 -72
- ocnn/models/unet.py +105 -105
- ocnn/modules/__init__.py +26 -26
- ocnn/modules/modules.py +303 -303
- ocnn/modules/resblocks.py +158 -158
- ocnn/nn/__init__.py +45 -44
- ocnn/nn/kernels/__init__.py +14 -0
- ocnn/nn/kernels/autotuner.py +416 -0
- ocnn/nn/kernels/config.py +67 -0
- ocnn/nn/kernels/conv_bwd_implicit_gemm.py +229 -0
- ocnn/nn/kernels/conv_bwd_implicit_gemm_splitk.py +347 -0
- ocnn/nn/kernels/conv_fwd_implicit_gemm.py +109 -0
- ocnn/nn/kernels/conv_fwd_implicit_gemm_splitk.py +150 -0
- ocnn/nn/kernels/utils.py +44 -0
- ocnn/nn/octree2col.py +53 -53
- ocnn/nn/octree2vox.py +50 -50
- ocnn/nn/octree_align.py +46 -46
- ocnn/nn/octree_conv.py +430 -429
- ocnn/nn/octree_conv_t.py +148 -0
- ocnn/nn/octree_drop.py +55 -55
- ocnn/nn/octree_dwconv.py +222 -222
- ocnn/nn/octree_gconv.py +79 -79
- ocnn/nn/octree_interp.py +196 -196
- ocnn/nn/octree_norm.py +126 -126
- ocnn/nn/octree_pad.py +39 -39
- ocnn/nn/octree_pool.py +200 -200
- ocnn/octree/__init__.py +22 -22
- ocnn/octree/octree.py +770 -770
- ocnn/octree/points.py +384 -323
- ocnn/octree/shuffled_key.py +115 -115
- ocnn/utils.py +205 -205
- {ocnn-2.2.8.dist-info → ocnn-2.3.0.dist-info}/METADATA +117 -111
- ocnn-2.3.0.dist-info/RECORD +45 -0
- {ocnn-2.2.8.dist-info → ocnn-2.3.0.dist-info}/WHEEL +1 -1
- {ocnn-2.2.8.dist-info → ocnn-2.3.0.dist-info}/licenses/LICENSE +21 -21
- ocnn-2.2.8.dist-info/RECORD +0 -36
- {ocnn-2.2.8.dist-info → ocnn-2.3.0.dist-info}/top_level.txt +0 -0
ocnn/models/unet.py
CHANGED
|
@@ -1,105 +1,105 @@
|
|
|
1
|
-
# --------------------------------------------------------
|
|
2
|
-
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
-
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
-
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
-
# Written by Peng-Shuai Wang
|
|
6
|
-
# --------------------------------------------------------
|
|
7
|
-
|
|
8
|
-
import torch
|
|
9
|
-
import torch.nn
|
|
10
|
-
from typing import Dict
|
|
11
|
-
|
|
12
|
-
import ocnn
|
|
13
|
-
from ocnn.octree import Octree
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
class UNet(torch.nn.Module):
|
|
17
|
-
r''' Octree-based UNet for segmentation.
|
|
18
|
-
'''
|
|
19
|
-
|
|
20
|
-
def __init__(self, in_channels: int, out_channels: int, interp: str = 'linear',
|
|
21
|
-
nempty: bool = False, **kwargs):
|
|
22
|
-
super(UNet, self).__init__()
|
|
23
|
-
self.in_channels = in_channels
|
|
24
|
-
self.out_channels = out_channels
|
|
25
|
-
self.nempty = nempty
|
|
26
|
-
self.config_network()
|
|
27
|
-
self.encoder_stages = len(self.encoder_blocks)
|
|
28
|
-
self.decoder_stages = len(self.decoder_blocks)
|
|
29
|
-
|
|
30
|
-
# encoder
|
|
31
|
-
self.conv1 = ocnn.modules.OctreeConvBnRelu(
|
|
32
|
-
in_channels, self.encoder_channel[0], nempty=nempty)
|
|
33
|
-
self.downsample = torch.nn.ModuleList([ocnn.modules.OctreeConvBnRelu(
|
|
34
|
-
self.encoder_channel[i], self.encoder_channel[i+1], kernel_size=[2],
|
|
35
|
-
stride=2, nempty=nempty) for i in range(self.encoder_stages)])
|
|
36
|
-
self.encoder = torch.nn.ModuleList([ocnn.modules.OctreeResBlocks(
|
|
37
|
-
self.encoder_channel[i+1], self.encoder_channel[i + 1],
|
|
38
|
-
self.encoder_blocks[i], self.bottleneck, nempty, self.resblk)
|
|
39
|
-
for i in range(self.encoder_stages)])
|
|
40
|
-
|
|
41
|
-
# decoder
|
|
42
|
-
channel = [self.decoder_channel[i+1] + self.encoder_channel[-i-2]
|
|
43
|
-
for i in range(self.decoder_stages)]
|
|
44
|
-
self.upsample = torch.nn.ModuleList([ocnn.modules.OctreeDeconvBnRelu(
|
|
45
|
-
self.decoder_channel[i], self.decoder_channel[i+1], kernel_size=[2],
|
|
46
|
-
stride=2, nempty=nempty) for i in range(self.decoder_stages)])
|
|
47
|
-
self.decoder = torch.nn.ModuleList([ocnn.modules.OctreeResBlocks(
|
|
48
|
-
channel[i], self.decoder_channel[i+1],
|
|
49
|
-
self.decoder_blocks[i], self.bottleneck, nempty, self.resblk)
|
|
50
|
-
for i in range(self.decoder_stages)])
|
|
51
|
-
|
|
52
|
-
# header
|
|
53
|
-
# channel = self.decoder_channel[self.decoder_stages]
|
|
54
|
-
self.octree_interp = ocnn.nn.OctreeInterp(interp, nempty)
|
|
55
|
-
self.header = torch.nn.Sequential(
|
|
56
|
-
ocnn.modules.Conv1x1BnRelu(self.decoder_channel[-1], self.head_channel),
|
|
57
|
-
ocnn.modules.Conv1x1(self.head_channel, self.out_channels, use_bias=True))
|
|
58
|
-
|
|
59
|
-
def config_network(self):
|
|
60
|
-
r''' Configure the network channels and Resblock numbers.
|
|
61
|
-
'''
|
|
62
|
-
|
|
63
|
-
self.encoder_channel = [32, 32, 64, 128, 256]
|
|
64
|
-
self.decoder_channel = [256, 256, 128, 96, 96]
|
|
65
|
-
self.encoder_blocks = [2, 3, 4, 6]
|
|
66
|
-
self.decoder_blocks = [2, 2, 2, 2]
|
|
67
|
-
self.head_channel = 64
|
|
68
|
-
self.bottleneck = 1
|
|
69
|
-
self.resblk = ocnn.modules.OctreeResBlock2
|
|
70
|
-
|
|
71
|
-
def unet_encoder(self, data: torch.Tensor, octree: Octree, depth: int):
|
|
72
|
-
r''' The encoder of the U-Net.
|
|
73
|
-
'''
|
|
74
|
-
|
|
75
|
-
convd = dict()
|
|
76
|
-
convd[depth] = self.conv1(data, octree, depth)
|
|
77
|
-
for i in range(self.encoder_stages):
|
|
78
|
-
d = depth - i
|
|
79
|
-
conv = self.downsample[i](convd[d], octree, d)
|
|
80
|
-
convd[d-1] = self.encoder[i](conv, octree, d-1)
|
|
81
|
-
return convd
|
|
82
|
-
|
|
83
|
-
def unet_decoder(self, convd: Dict[int, torch.Tensor], octree: Octree, depth: int):
|
|
84
|
-
r''' The decoder of the U-Net.
|
|
85
|
-
'''
|
|
86
|
-
|
|
87
|
-
deconv = convd[depth]
|
|
88
|
-
for i in range(self.decoder_stages):
|
|
89
|
-
d = depth + i
|
|
90
|
-
deconv = self.upsample[i](deconv, octree, d)
|
|
91
|
-
deconv = torch.cat([convd[d+1], deconv], dim=1) # skip connections
|
|
92
|
-
deconv = self.decoder[i](deconv, octree, d+1)
|
|
93
|
-
return deconv
|
|
94
|
-
|
|
95
|
-
def forward(self, data: torch.Tensor, octree: Octree, depth: int,
|
|
96
|
-
query_pts: torch.Tensor):
|
|
97
|
-
r''''''
|
|
98
|
-
|
|
99
|
-
convd = self.unet_encoder(data, octree, depth)
|
|
100
|
-
deconv = self.unet_decoder(convd, octree, depth - self.encoder_stages)
|
|
101
|
-
|
|
102
|
-
interp_depth = depth - self.encoder_stages + self.decoder_stages
|
|
103
|
-
feature = self.octree_interp(deconv, octree, interp_depth, query_pts)
|
|
104
|
-
logits = self.header(feature)
|
|
105
|
-
return logits
|
|
1
|
+
# --------------------------------------------------------
|
|
2
|
+
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
+
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
+
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
+
# Written by Peng-Shuai Wang
|
|
6
|
+
# --------------------------------------------------------
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
import torch.nn
|
|
10
|
+
from typing import Dict
|
|
11
|
+
|
|
12
|
+
import ocnn
|
|
13
|
+
from ocnn.octree import Octree
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class UNet(torch.nn.Module):
|
|
17
|
+
r''' Octree-based UNet for segmentation.
|
|
18
|
+
'''
|
|
19
|
+
|
|
20
|
+
def __init__(self, in_channels: int, out_channels: int, interp: str = 'linear',
|
|
21
|
+
nempty: bool = False, **kwargs):
|
|
22
|
+
super(UNet, self).__init__()
|
|
23
|
+
self.in_channels = in_channels
|
|
24
|
+
self.out_channels = out_channels
|
|
25
|
+
self.nempty = nempty
|
|
26
|
+
self.config_network()
|
|
27
|
+
self.encoder_stages = len(self.encoder_blocks)
|
|
28
|
+
self.decoder_stages = len(self.decoder_blocks)
|
|
29
|
+
|
|
30
|
+
# encoder
|
|
31
|
+
self.conv1 = ocnn.modules.OctreeConvBnRelu(
|
|
32
|
+
in_channels, self.encoder_channel[0], nempty=nempty)
|
|
33
|
+
self.downsample = torch.nn.ModuleList([ocnn.modules.OctreeConvBnRelu(
|
|
34
|
+
self.encoder_channel[i], self.encoder_channel[i+1], kernel_size=[2],
|
|
35
|
+
stride=2, nempty=nempty) for i in range(self.encoder_stages)])
|
|
36
|
+
self.encoder = torch.nn.ModuleList([ocnn.modules.OctreeResBlocks(
|
|
37
|
+
self.encoder_channel[i+1], self.encoder_channel[i + 1],
|
|
38
|
+
self.encoder_blocks[i], self.bottleneck, nempty, self.resblk)
|
|
39
|
+
for i in range(self.encoder_stages)])
|
|
40
|
+
|
|
41
|
+
# decoder
|
|
42
|
+
channel = [self.decoder_channel[i+1] + self.encoder_channel[-i-2]
|
|
43
|
+
for i in range(self.decoder_stages)]
|
|
44
|
+
self.upsample = torch.nn.ModuleList([ocnn.modules.OctreeDeconvBnRelu(
|
|
45
|
+
self.decoder_channel[i], self.decoder_channel[i+1], kernel_size=[2],
|
|
46
|
+
stride=2, nempty=nempty) for i in range(self.decoder_stages)])
|
|
47
|
+
self.decoder = torch.nn.ModuleList([ocnn.modules.OctreeResBlocks(
|
|
48
|
+
channel[i], self.decoder_channel[i+1],
|
|
49
|
+
self.decoder_blocks[i], self.bottleneck, nempty, self.resblk)
|
|
50
|
+
for i in range(self.decoder_stages)])
|
|
51
|
+
|
|
52
|
+
# header
|
|
53
|
+
# channel = self.decoder_channel[self.decoder_stages]
|
|
54
|
+
self.octree_interp = ocnn.nn.OctreeInterp(interp, nempty)
|
|
55
|
+
self.header = torch.nn.Sequential(
|
|
56
|
+
ocnn.modules.Conv1x1BnRelu(self.decoder_channel[-1], self.head_channel),
|
|
57
|
+
ocnn.modules.Conv1x1(self.head_channel, self.out_channels, use_bias=True))
|
|
58
|
+
|
|
59
|
+
def config_network(self):
|
|
60
|
+
r''' Configure the network channels and Resblock numbers.
|
|
61
|
+
'''
|
|
62
|
+
|
|
63
|
+
self.encoder_channel = [32, 32, 64, 128, 256]
|
|
64
|
+
self.decoder_channel = [256, 256, 128, 96, 96]
|
|
65
|
+
self.encoder_blocks = [2, 3, 4, 6]
|
|
66
|
+
self.decoder_blocks = [2, 2, 2, 2]
|
|
67
|
+
self.head_channel = 64
|
|
68
|
+
self.bottleneck = 1
|
|
69
|
+
self.resblk = ocnn.modules.OctreeResBlock2
|
|
70
|
+
|
|
71
|
+
def unet_encoder(self, data: torch.Tensor, octree: Octree, depth: int):
|
|
72
|
+
r''' The encoder of the U-Net.
|
|
73
|
+
'''
|
|
74
|
+
|
|
75
|
+
convd = dict()
|
|
76
|
+
convd[depth] = self.conv1(data, octree, depth)
|
|
77
|
+
for i in range(self.encoder_stages):
|
|
78
|
+
d = depth - i
|
|
79
|
+
conv = self.downsample[i](convd[d], octree, d)
|
|
80
|
+
convd[d-1] = self.encoder[i](conv, octree, d-1)
|
|
81
|
+
return convd
|
|
82
|
+
|
|
83
|
+
def unet_decoder(self, convd: Dict[int, torch.Tensor], octree: Octree, depth: int):
|
|
84
|
+
r''' The decoder of the U-Net.
|
|
85
|
+
'''
|
|
86
|
+
|
|
87
|
+
deconv = convd[depth]
|
|
88
|
+
for i in range(self.decoder_stages):
|
|
89
|
+
d = depth + i
|
|
90
|
+
deconv = self.upsample[i](deconv, octree, d)
|
|
91
|
+
deconv = torch.cat([convd[d+1], deconv], dim=1) # skip connections
|
|
92
|
+
deconv = self.decoder[i](deconv, octree, d+1)
|
|
93
|
+
return deconv
|
|
94
|
+
|
|
95
|
+
def forward(self, data: torch.Tensor, octree: Octree, depth: int,
|
|
96
|
+
query_pts: torch.Tensor):
|
|
97
|
+
r''''''
|
|
98
|
+
|
|
99
|
+
convd = self.unet_encoder(data, octree, depth)
|
|
100
|
+
deconv = self.unet_decoder(convd, octree, depth - self.encoder_stages)
|
|
101
|
+
|
|
102
|
+
interp_depth = depth - self.encoder_stages + self.decoder_stages
|
|
103
|
+
feature = self.octree_interp(deconv, octree, interp_depth, query_pts)
|
|
104
|
+
logits = self.header(feature)
|
|
105
|
+
return logits
|
ocnn/modules/__init__.py
CHANGED
|
@@ -1,26 +1,26 @@
|
|
|
1
|
-
# --------------------------------------------------------
|
|
2
|
-
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
-
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
-
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
-
# Written by Peng-Shuai Wang
|
|
6
|
-
# --------------------------------------------------------
|
|
7
|
-
|
|
8
|
-
from .modules import (InputFeature,
|
|
9
|
-
OctreeConvBn, OctreeConvBnRelu, OctreeDeconvBnRelu,
|
|
10
|
-
Conv1x1, Conv1x1Bn, Conv1x1BnRelu, FcBnRelu,
|
|
11
|
-
OctreeConvGn, OctreeConvGnRelu, OctreeDeconvGnRelu,
|
|
12
|
-
Conv1x1, Conv1x1Gn, Conv1x1GnRelu)
|
|
13
|
-
from .resblocks import (OctreeResBlock, OctreeResBlock2, OctreeResBlockGn,
|
|
14
|
-
OctreeResBlocks,)
|
|
15
|
-
|
|
16
|
-
__all__ = [
|
|
17
|
-
'InputFeature',
|
|
18
|
-
'OctreeConvBn', 'OctreeConvBnRelu', 'OctreeDeconvBnRelu',
|
|
19
|
-
'Conv1x1', 'Conv1x1Bn', 'Conv1x1BnRelu', 'FcBnRelu',
|
|
20
|
-
'OctreeConvGn', 'OctreeConvGnRelu', 'OctreeDeconvGnRelu',
|
|
21
|
-
'Conv1x1', 'Conv1x1Gn', 'Conv1x1GnRelu',
|
|
22
|
-
'OctreeResBlock', 'OctreeResBlock2', 'OctreeResBlockGn',
|
|
23
|
-
'OctreeResBlocks',
|
|
24
|
-
]
|
|
25
|
-
|
|
26
|
-
classes = __all__
|
|
1
|
+
# --------------------------------------------------------
|
|
2
|
+
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
+
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
+
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
+
# Written by Peng-Shuai Wang
|
|
6
|
+
# --------------------------------------------------------
|
|
7
|
+
|
|
8
|
+
from .modules import (InputFeature,
|
|
9
|
+
OctreeConvBn, OctreeConvBnRelu, OctreeDeconvBnRelu,
|
|
10
|
+
Conv1x1, Conv1x1Bn, Conv1x1BnRelu, FcBnRelu,
|
|
11
|
+
OctreeConvGn, OctreeConvGnRelu, OctreeDeconvGnRelu,
|
|
12
|
+
Conv1x1, Conv1x1Gn, Conv1x1GnRelu)
|
|
13
|
+
from .resblocks import (OctreeResBlock, OctreeResBlock2, OctreeResBlockGn,
|
|
14
|
+
OctreeResBlocks,)
|
|
15
|
+
|
|
16
|
+
__all__ = [
|
|
17
|
+
'InputFeature',
|
|
18
|
+
'OctreeConvBn', 'OctreeConvBnRelu', 'OctreeDeconvBnRelu',
|
|
19
|
+
'Conv1x1', 'Conv1x1Bn', 'Conv1x1BnRelu', 'FcBnRelu',
|
|
20
|
+
'OctreeConvGn', 'OctreeConvGnRelu', 'OctreeDeconvGnRelu',
|
|
21
|
+
'Conv1x1', 'Conv1x1Gn', 'Conv1x1GnRelu',
|
|
22
|
+
'OctreeResBlock', 'OctreeResBlock2', 'OctreeResBlockGn',
|
|
23
|
+
'OctreeResBlocks',
|
|
24
|
+
]
|
|
25
|
+
|
|
26
|
+
classes = __all__
|