ocnn 2.2.1__py3-none-any.whl → 2.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ocnn/__init__.py +24 -24
- ocnn/dataset.py +160 -158
- ocnn/models/__init__.py +29 -27
- ocnn/models/autoencoder.py +155 -165
- ocnn/models/hrnet.py +192 -192
- ocnn/models/image2shape.py +128 -0
- ocnn/models/lenet.py +46 -46
- ocnn/models/ounet.py +94 -94
- ocnn/models/resnet.py +53 -53
- ocnn/models/segnet.py +72 -72
- ocnn/models/unet.py +105 -105
- ocnn/modules/__init__.py +20 -20
- ocnn/modules/modules.py +193 -231
- ocnn/modules/resblocks.py +124 -124
- ocnn/nn/__init__.py +43 -42
- ocnn/nn/octree2col.py +53 -53
- ocnn/nn/octree2vox.py +50 -50
- ocnn/nn/octree_align.py +46 -46
- ocnn/nn/octree_conv.py +429 -411
- ocnn/nn/octree_drop.py +55 -55
- ocnn/nn/octree_dwconv.py +222 -204
- ocnn/nn/octree_gconv.py +79 -0
- ocnn/nn/octree_interp.py +196 -196
- ocnn/nn/octree_norm.py +86 -86
- ocnn/nn/octree_pad.py +39 -39
- ocnn/nn/octree_pool.py +200 -200
- ocnn/octree/__init__.py +22 -21
- ocnn/octree/octree.py +639 -601
- ocnn/octree/points.py +322 -298
- ocnn/octree/shuffled_key.py +115 -115
- ocnn/utils.py +205 -153
- {ocnn-2.2.1.dist-info → ocnn-2.2.3.dist-info}/LICENSE +21 -21
- {ocnn-2.2.1.dist-info → ocnn-2.2.3.dist-info}/METADATA +79 -65
- ocnn-2.2.3.dist-info/RECORD +36 -0
- {ocnn-2.2.1.dist-info → ocnn-2.2.3.dist-info}/WHEEL +1 -1
- ocnn-2.2.1.dist-info/RECORD +0 -34
- {ocnn-2.2.1.dist-info → ocnn-2.2.3.dist-info}/top_level.txt +0 -0
ocnn/models/ounet.py
CHANGED
|
@@ -1,94 +1,94 @@
|
|
|
1
|
-
# --------------------------------------------------------
|
|
2
|
-
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
-
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
-
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
-
# Written by Peng-Shuai Wang
|
|
6
|
-
# --------------------------------------------------------
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
import ocnn
|
|
10
|
-
import torch
|
|
11
|
-
import torch.nn
|
|
12
|
-
|
|
13
|
-
from ocnn.octree import Octree
|
|
14
|
-
from ocnn.models.autoencoder import AutoEncoder
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
class OUNet(AutoEncoder):
|
|
18
|
-
|
|
19
|
-
def __init__(self, channel_in: int, channel_out: int, depth: int,
|
|
20
|
-
full_depth: int = 2, feature: str = 'ND'):
|
|
21
|
-
super().__init__(channel_in, channel_out, depth, full_depth, feature,
|
|
22
|
-
code_channel=-1) # !set code_channel=-1
|
|
23
|
-
self.proj = None # remove this module used in AutoEncoder
|
|
24
|
-
|
|
25
|
-
def encoder(self, octree):
|
|
26
|
-
r''' The encoder network for extracting hierarchical features.
|
|
27
|
-
'''
|
|
28
|
-
|
|
29
|
-
convs = dict()
|
|
30
|
-
depth, full_depth = self.depth, self.full_depth
|
|
31
|
-
data = self.get_input_feature(octree)
|
|
32
|
-
convs[depth] = self.conv1(data, octree, depth)
|
|
33
|
-
for i, d in enumerate(range(depth, full_depth-1, -1)):
|
|
34
|
-
convs[d] = self.encoder_blks[i](convs[d], octree, d)
|
|
35
|
-
if d > full_depth:
|
|
36
|
-
convs[d-1] = self.downsample[i](convs[d], octree, d)
|
|
37
|
-
return convs
|
|
38
|
-
|
|
39
|
-
def decoder(self, convs: dict, octree_in: Octree, octree_out: Octree,
|
|
40
|
-
update_octree: bool = False):
|
|
41
|
-
r''' The decoder network for decoding the octree.
|
|
42
|
-
'''
|
|
43
|
-
|
|
44
|
-
logits = dict()
|
|
45
|
-
deconv = convs[self.full_depth]
|
|
46
|
-
depth, full_depth = self.depth, self.full_depth
|
|
47
|
-
for i, d in enumerate(range(full_depth, depth + 1)):
|
|
48
|
-
if d > full_depth:
|
|
49
|
-
deconv = self.upsample[i-1](deconv, octree_out, d-1)
|
|
50
|
-
skip = ocnn.nn.octree_align(convs[d], octree_in, octree_out, d)
|
|
51
|
-
deconv = deconv + skip # output-guided skip connections
|
|
52
|
-
deconv = self.decoder_blks[i](deconv, octree_out, d)
|
|
53
|
-
|
|
54
|
-
# predict the splitting label
|
|
55
|
-
logit = self.predict[i](deconv)
|
|
56
|
-
logits[d] = logit
|
|
57
|
-
|
|
58
|
-
# update the octree according to predicted labels
|
|
59
|
-
if update_octree:
|
|
60
|
-
split = logit.argmax(1).int()
|
|
61
|
-
octree_out.octree_split(split, d)
|
|
62
|
-
if d < depth:
|
|
63
|
-
octree_out.octree_grow(d + 1)
|
|
64
|
-
|
|
65
|
-
# predict the signal
|
|
66
|
-
if d == depth:
|
|
67
|
-
signal = self.header(deconv)
|
|
68
|
-
signal = torch.tanh(signal)
|
|
69
|
-
signal = ocnn.nn.octree_depad(signal, octree_out, depth)
|
|
70
|
-
if update_octree:
|
|
71
|
-
octree_out.features[depth] = signal
|
|
72
|
-
|
|
73
|
-
return {'logits': logits, 'signal': signal, 'octree_out': octree_out}
|
|
74
|
-
|
|
75
|
-
def init_octree(self, octree_in: Octree):
|
|
76
|
-
r''' Initialize a full octree for decoding.
|
|
77
|
-
'''
|
|
78
|
-
|
|
79
|
-
device = octree_in.device
|
|
80
|
-
batch_size = octree_in.batch_size
|
|
81
|
-
octree = Octree(self.depth, self.full_depth, batch_size, device)
|
|
82
|
-
for d in range(self.full_depth+1):
|
|
83
|
-
octree.octree_grow_full(depth=d)
|
|
84
|
-
return octree
|
|
85
|
-
|
|
86
|
-
def forward(self, octree_in, octree_out=None, update_octree: bool = False):
|
|
87
|
-
r''''''
|
|
88
|
-
|
|
89
|
-
if octree_out is None:
|
|
90
|
-
update_octree = True
|
|
91
|
-
octree_out = self.init_octree(octree_in)
|
|
92
|
-
convs = self.encoder(octree_in)
|
|
93
|
-
out = self.decoder(convs, octree_in, octree_out, update_octree)
|
|
94
|
-
return out
|
|
1
|
+
# --------------------------------------------------------
|
|
2
|
+
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
+
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
+
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
+
# Written by Peng-Shuai Wang
|
|
6
|
+
# --------------------------------------------------------
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
import ocnn
|
|
10
|
+
import torch
|
|
11
|
+
import torch.nn
|
|
12
|
+
|
|
13
|
+
from ocnn.octree import Octree
|
|
14
|
+
from ocnn.models.autoencoder import AutoEncoder
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class OUNet(AutoEncoder):
|
|
18
|
+
|
|
19
|
+
def __init__(self, channel_in: int, channel_out: int, depth: int,
|
|
20
|
+
full_depth: int = 2, feature: str = 'ND'):
|
|
21
|
+
super().__init__(channel_in, channel_out, depth, full_depth, feature,
|
|
22
|
+
code_channel=-1) # !set code_channel=-1
|
|
23
|
+
self.proj = None # remove this module used in AutoEncoder
|
|
24
|
+
|
|
25
|
+
def encoder(self, octree):
|
|
26
|
+
r''' The encoder network for extracting hierarchical features.
|
|
27
|
+
'''
|
|
28
|
+
|
|
29
|
+
convs = dict()
|
|
30
|
+
depth, full_depth = self.depth, self.full_depth
|
|
31
|
+
data = self.get_input_feature(octree)
|
|
32
|
+
convs[depth] = self.conv1(data, octree, depth)
|
|
33
|
+
for i, d in enumerate(range(depth, full_depth-1, -1)):
|
|
34
|
+
convs[d] = self.encoder_blks[i](convs[d], octree, d)
|
|
35
|
+
if d > full_depth:
|
|
36
|
+
convs[d-1] = self.downsample[i](convs[d], octree, d)
|
|
37
|
+
return convs
|
|
38
|
+
|
|
39
|
+
def decoder(self, convs: dict, octree_in: Octree, octree_out: Octree,
|
|
40
|
+
update_octree: bool = False):
|
|
41
|
+
r''' The decoder network for decoding the octree.
|
|
42
|
+
'''
|
|
43
|
+
|
|
44
|
+
logits = dict()
|
|
45
|
+
deconv = convs[self.full_depth]
|
|
46
|
+
depth, full_depth = self.depth, self.full_depth
|
|
47
|
+
for i, d in enumerate(range(full_depth, depth + 1)):
|
|
48
|
+
if d > full_depth:
|
|
49
|
+
deconv = self.upsample[i-1](deconv, octree_out, d-1)
|
|
50
|
+
skip = ocnn.nn.octree_align(convs[d], octree_in, octree_out, d)
|
|
51
|
+
deconv = deconv + skip # output-guided skip connections
|
|
52
|
+
deconv = self.decoder_blks[i](deconv, octree_out, d)
|
|
53
|
+
|
|
54
|
+
# predict the splitting label
|
|
55
|
+
logit = self.predict[i](deconv)
|
|
56
|
+
logits[d] = logit
|
|
57
|
+
|
|
58
|
+
# update the octree according to predicted labels
|
|
59
|
+
if update_octree:
|
|
60
|
+
split = logit.argmax(1).int()
|
|
61
|
+
octree_out.octree_split(split, d)
|
|
62
|
+
if d < depth:
|
|
63
|
+
octree_out.octree_grow(d + 1)
|
|
64
|
+
|
|
65
|
+
# predict the signal
|
|
66
|
+
if d == depth:
|
|
67
|
+
signal = self.header(deconv)
|
|
68
|
+
signal = torch.tanh(signal)
|
|
69
|
+
signal = ocnn.nn.octree_depad(signal, octree_out, depth)
|
|
70
|
+
if update_octree:
|
|
71
|
+
octree_out.features[depth] = signal
|
|
72
|
+
|
|
73
|
+
return {'logits': logits, 'signal': signal, 'octree_out': octree_out}
|
|
74
|
+
|
|
75
|
+
def init_octree(self, octree_in: Octree):
|
|
76
|
+
r''' Initialize a full octree for decoding.
|
|
77
|
+
'''
|
|
78
|
+
|
|
79
|
+
device = octree_in.device
|
|
80
|
+
batch_size = octree_in.batch_size
|
|
81
|
+
octree = Octree(self.depth, self.full_depth, batch_size, device)
|
|
82
|
+
for d in range(self.full_depth+1):
|
|
83
|
+
octree.octree_grow_full(depth=d)
|
|
84
|
+
return octree
|
|
85
|
+
|
|
86
|
+
def forward(self, octree_in, octree_out=None, update_octree: bool = False):
|
|
87
|
+
r''''''
|
|
88
|
+
|
|
89
|
+
if octree_out is None:
|
|
90
|
+
update_octree = True
|
|
91
|
+
octree_out = self.init_octree(octree_in)
|
|
92
|
+
convs = self.encoder(octree_in)
|
|
93
|
+
out = self.decoder(convs, octree_in, octree_out, update_octree)
|
|
94
|
+
return out
|
ocnn/models/resnet.py
CHANGED
|
@@ -1,53 +1,53 @@
|
|
|
1
|
-
# --------------------------------------------------------
|
|
2
|
-
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
-
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
-
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
-
# Written by Peng-Shuai Wang
|
|
6
|
-
# --------------------------------------------------------
|
|
7
|
-
|
|
8
|
-
import torch
|
|
9
|
-
import ocnn
|
|
10
|
-
from ocnn.octree import Octree
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
class ResNet(torch.nn.Module):
|
|
14
|
-
r''' Octree-based ResNet for classification.
|
|
15
|
-
'''
|
|
16
|
-
|
|
17
|
-
def __init__(self, in_channels: int, out_channels: int, resblock_num: int,
|
|
18
|
-
stages: int, nempty: bool = False):
|
|
19
|
-
super().__init__()
|
|
20
|
-
self.in_channels = in_channels
|
|
21
|
-
self.out_channels = out_channels
|
|
22
|
-
self.resblk_num = resblock_num
|
|
23
|
-
self.stages = stages
|
|
24
|
-
self.nempty = nempty
|
|
25
|
-
channels = [2 ** max(i+9-stages, 2) for i in range(stages)]
|
|
26
|
-
|
|
27
|
-
self.conv1 = ocnn.modules.OctreeConvBnRelu(
|
|
28
|
-
in_channels, channels[0], nempty=nempty)
|
|
29
|
-
self.pool1 = ocnn.nn.OctreeMaxPool(nempty)
|
|
30
|
-
self.resblocks = torch.nn.ModuleList([ocnn.modules.OctreeResBlocks(
|
|
31
|
-
channels[i], channels[i+1], resblock_num, nempty=nempty)
|
|
32
|
-
for i in range(stages-1)])
|
|
33
|
-
self.pools = torch.nn.ModuleList([ocnn.nn.OctreeMaxPool(
|
|
34
|
-
nempty) for i in range(stages-1)])
|
|
35
|
-
self.global_pool = ocnn.nn.OctreeGlobalPool(nempty)
|
|
36
|
-
# self.header = torch.nn.Linear(channels[-1], out_channels, bias=True)
|
|
37
|
-
self.header = torch.nn.Sequential(
|
|
38
|
-
ocnn.modules.FcBnRelu(channels[-1], 512),
|
|
39
|
-
torch.nn.Dropout(p=0.5),
|
|
40
|
-
torch.nn.Linear(512, out_channels))
|
|
41
|
-
|
|
42
|
-
def forward(self, data: torch.Tensor, octree: Octree, depth: int):
|
|
43
|
-
r''''''
|
|
44
|
-
|
|
45
|
-
data = self.conv1(data, octree, depth)
|
|
46
|
-
data = self.pool1(data, octree, depth)
|
|
47
|
-
for i in range(self.stages-1):
|
|
48
|
-
d = depth - i - 1
|
|
49
|
-
data = self.resblocks[i](data, octree, d)
|
|
50
|
-
data = self.pools[i](data, octree, d)
|
|
51
|
-
data = self.global_pool(data, octree, depth-self.stages)
|
|
52
|
-
data = self.header(data)
|
|
53
|
-
return data
|
|
1
|
+
# --------------------------------------------------------
|
|
2
|
+
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
+
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
+
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
+
# Written by Peng-Shuai Wang
|
|
6
|
+
# --------------------------------------------------------
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
import ocnn
|
|
10
|
+
from ocnn.octree import Octree
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class ResNet(torch.nn.Module):
|
|
14
|
+
r''' Octree-based ResNet for classification.
|
|
15
|
+
'''
|
|
16
|
+
|
|
17
|
+
def __init__(self, in_channels: int, out_channels: int, resblock_num: int,
|
|
18
|
+
stages: int, nempty: bool = False):
|
|
19
|
+
super().__init__()
|
|
20
|
+
self.in_channels = in_channels
|
|
21
|
+
self.out_channels = out_channels
|
|
22
|
+
self.resblk_num = resblock_num
|
|
23
|
+
self.stages = stages
|
|
24
|
+
self.nempty = nempty
|
|
25
|
+
channels = [2 ** max(i+9-stages, 2) for i in range(stages)]
|
|
26
|
+
|
|
27
|
+
self.conv1 = ocnn.modules.OctreeConvBnRelu(
|
|
28
|
+
in_channels, channels[0], nempty=nempty)
|
|
29
|
+
self.pool1 = ocnn.nn.OctreeMaxPool(nempty)
|
|
30
|
+
self.resblocks = torch.nn.ModuleList([ocnn.modules.OctreeResBlocks(
|
|
31
|
+
channels[i], channels[i+1], resblock_num, nempty=nempty)
|
|
32
|
+
for i in range(stages-1)])
|
|
33
|
+
self.pools = torch.nn.ModuleList([ocnn.nn.OctreeMaxPool(
|
|
34
|
+
nempty) for i in range(stages-1)])
|
|
35
|
+
self.global_pool = ocnn.nn.OctreeGlobalPool(nempty)
|
|
36
|
+
# self.header = torch.nn.Linear(channels[-1], out_channels, bias=True)
|
|
37
|
+
self.header = torch.nn.Sequential(
|
|
38
|
+
ocnn.modules.FcBnRelu(channels[-1], 512),
|
|
39
|
+
torch.nn.Dropout(p=0.5),
|
|
40
|
+
torch.nn.Linear(512, out_channels))
|
|
41
|
+
|
|
42
|
+
def forward(self, data: torch.Tensor, octree: Octree, depth: int):
|
|
43
|
+
r''''''
|
|
44
|
+
|
|
45
|
+
data = self.conv1(data, octree, depth)
|
|
46
|
+
data = self.pool1(data, octree, depth)
|
|
47
|
+
for i in range(self.stages-1):
|
|
48
|
+
d = depth - i - 1
|
|
49
|
+
data = self.resblocks[i](data, octree, d)
|
|
50
|
+
data = self.pools[i](data, octree, d)
|
|
51
|
+
data = self.global_pool(data, octree, depth-self.stages)
|
|
52
|
+
data = self.header(data)
|
|
53
|
+
return data
|
ocnn/models/segnet.py
CHANGED
|
@@ -1,72 +1,72 @@
|
|
|
1
|
-
# --------------------------------------------------------
|
|
2
|
-
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
-
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
-
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
-
# Written by Peng-Shuai Wang
|
|
6
|
-
# --------------------------------------------------------
|
|
7
|
-
|
|
8
|
-
import torch
|
|
9
|
-
import ocnn
|
|
10
|
-
from ocnn.octree import Octree
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
class SegNet(torch.nn.Module):
|
|
14
|
-
r''' Octree-based SegNet for segmentation.
|
|
15
|
-
'''
|
|
16
|
-
|
|
17
|
-
def __init__(self, in_channels: int, out_channels: int, stages: int,
|
|
18
|
-
interp: str = 'linear', nempty: bool = False, **kwargs):
|
|
19
|
-
super().__init__()
|
|
20
|
-
self.in_channels = in_channels
|
|
21
|
-
self.out_channels = out_channels
|
|
22
|
-
self.stages = stages
|
|
23
|
-
self.nempty = nempty
|
|
24
|
-
return_indices = True
|
|
25
|
-
|
|
26
|
-
channels_stages = [2 ** max(i+8-stages, 2) for i in range(stages)]
|
|
27
|
-
channels = [in_channels] + channels_stages
|
|
28
|
-
self.convs = torch.nn.ModuleList(
|
|
29
|
-
[ocnn.modules.OctreeConvBnRelu(channels[i], channels[i+1], nempty=nempty)
|
|
30
|
-
for i in range(stages)])
|
|
31
|
-
self.pools = torch.nn.ModuleList(
|
|
32
|
-
[ocnn.nn.OctreeMaxPool(nempty, return_indices) for i in range(stages)])
|
|
33
|
-
|
|
34
|
-
self.bottleneck = ocnn.modules.OctreeConvBnRelu(channels[-1], channels[-1])
|
|
35
|
-
|
|
36
|
-
channels = channels_stages[::-1] + [channels_stages[0]]
|
|
37
|
-
self.deconvs = torch.nn.ModuleList(
|
|
38
|
-
[ocnn.modules.OctreeConvBnRelu(channels[i], channels[i+1], nempty=nempty)
|
|
39
|
-
for i in range(0, stages)])
|
|
40
|
-
self.unpools = torch.nn.ModuleList(
|
|
41
|
-
[ocnn.nn.OctreeMaxUnpool(nempty) for i in range(stages)])
|
|
42
|
-
|
|
43
|
-
self.octree_interp = ocnn.nn.OctreeInterp(interp, nempty)
|
|
44
|
-
self.header = torch.nn.Sequential(
|
|
45
|
-
ocnn.modules.Conv1x1BnRelu(channels[-1], 64),
|
|
46
|
-
ocnn.modules.Conv1x1(64, out_channels, use_bias=True))
|
|
47
|
-
|
|
48
|
-
def forward(self, data: torch.Tensor, octree: Octree, depth: int,
|
|
49
|
-
query_pts: torch.Tensor):
|
|
50
|
-
r''''''
|
|
51
|
-
|
|
52
|
-
# encoder
|
|
53
|
-
indices = dict()
|
|
54
|
-
for i in range(self.stages):
|
|
55
|
-
d = depth - i
|
|
56
|
-
data = self.convs[i](data, octree, d)
|
|
57
|
-
data, indices[d] = self.pools[i](data, octree, d)
|
|
58
|
-
|
|
59
|
-
# bottleneck
|
|
60
|
-
data = self.bottleneck(data, octree, depth-self.stages)
|
|
61
|
-
|
|
62
|
-
# decoder
|
|
63
|
-
for i in range(self.stages):
|
|
64
|
-
d = depth - self.stages + i
|
|
65
|
-
data = self.unpools[i](data, indices[d + 1], octree, d)
|
|
66
|
-
data = self.deconvs[i](data, octree, d + 1)
|
|
67
|
-
|
|
68
|
-
# header
|
|
69
|
-
feature = self.octree_interp(data, octree, depth, query_pts)
|
|
70
|
-
logits = self.header(feature)
|
|
71
|
-
|
|
72
|
-
return logits
|
|
1
|
+
# --------------------------------------------------------
|
|
2
|
+
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
+
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
+
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
+
# Written by Peng-Shuai Wang
|
|
6
|
+
# --------------------------------------------------------
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
import ocnn
|
|
10
|
+
from ocnn.octree import Octree
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class SegNet(torch.nn.Module):
|
|
14
|
+
r''' Octree-based SegNet for segmentation.
|
|
15
|
+
'''
|
|
16
|
+
|
|
17
|
+
def __init__(self, in_channels: int, out_channels: int, stages: int,
|
|
18
|
+
interp: str = 'linear', nempty: bool = False, **kwargs):
|
|
19
|
+
super().__init__()
|
|
20
|
+
self.in_channels = in_channels
|
|
21
|
+
self.out_channels = out_channels
|
|
22
|
+
self.stages = stages
|
|
23
|
+
self.nempty = nempty
|
|
24
|
+
return_indices = True
|
|
25
|
+
|
|
26
|
+
channels_stages = [2 ** max(i+8-stages, 2) for i in range(stages)]
|
|
27
|
+
channels = [in_channels] + channels_stages
|
|
28
|
+
self.convs = torch.nn.ModuleList(
|
|
29
|
+
[ocnn.modules.OctreeConvBnRelu(channels[i], channels[i+1], nempty=nempty)
|
|
30
|
+
for i in range(stages)])
|
|
31
|
+
self.pools = torch.nn.ModuleList(
|
|
32
|
+
[ocnn.nn.OctreeMaxPool(nempty, return_indices) for i in range(stages)])
|
|
33
|
+
|
|
34
|
+
self.bottleneck = ocnn.modules.OctreeConvBnRelu(channels[-1], channels[-1])
|
|
35
|
+
|
|
36
|
+
channels = channels_stages[::-1] + [channels_stages[0]]
|
|
37
|
+
self.deconvs = torch.nn.ModuleList(
|
|
38
|
+
[ocnn.modules.OctreeConvBnRelu(channels[i], channels[i+1], nempty=nempty)
|
|
39
|
+
for i in range(0, stages)])
|
|
40
|
+
self.unpools = torch.nn.ModuleList(
|
|
41
|
+
[ocnn.nn.OctreeMaxUnpool(nempty) for i in range(stages)])
|
|
42
|
+
|
|
43
|
+
self.octree_interp = ocnn.nn.OctreeInterp(interp, nempty)
|
|
44
|
+
self.header = torch.nn.Sequential(
|
|
45
|
+
ocnn.modules.Conv1x1BnRelu(channels[-1], 64),
|
|
46
|
+
ocnn.modules.Conv1x1(64, out_channels, use_bias=True))
|
|
47
|
+
|
|
48
|
+
def forward(self, data: torch.Tensor, octree: Octree, depth: int,
|
|
49
|
+
query_pts: torch.Tensor):
|
|
50
|
+
r''''''
|
|
51
|
+
|
|
52
|
+
# encoder
|
|
53
|
+
indices = dict()
|
|
54
|
+
for i in range(self.stages):
|
|
55
|
+
d = depth - i
|
|
56
|
+
data = self.convs[i](data, octree, d)
|
|
57
|
+
data, indices[d] = self.pools[i](data, octree, d)
|
|
58
|
+
|
|
59
|
+
# bottleneck
|
|
60
|
+
data = self.bottleneck(data, octree, depth-self.stages)
|
|
61
|
+
|
|
62
|
+
# decoder
|
|
63
|
+
for i in range(self.stages):
|
|
64
|
+
d = depth - self.stages + i
|
|
65
|
+
data = self.unpools[i](data, indices[d + 1], octree, d)
|
|
66
|
+
data = self.deconvs[i](data, octree, d + 1)
|
|
67
|
+
|
|
68
|
+
# header
|
|
69
|
+
feature = self.octree_interp(data, octree, depth, query_pts)
|
|
70
|
+
logits = self.header(feature)
|
|
71
|
+
|
|
72
|
+
return logits
|
ocnn/models/unet.py
CHANGED
|
@@ -1,105 +1,105 @@
|
|
|
1
|
-
# --------------------------------------------------------
|
|
2
|
-
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
-
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
-
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
-
# Written by Peng-Shuai Wang
|
|
6
|
-
# --------------------------------------------------------
|
|
7
|
-
|
|
8
|
-
import torch
|
|
9
|
-
import torch.nn
|
|
10
|
-
from typing import Dict
|
|
11
|
-
|
|
12
|
-
import ocnn
|
|
13
|
-
from ocnn.octree import Octree
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
class UNet(torch.nn.Module):
|
|
17
|
-
r''' Octree-based UNet for segmentation.
|
|
18
|
-
'''
|
|
19
|
-
|
|
20
|
-
def __init__(self, in_channels: int, out_channels: int, interp: str = 'linear',
|
|
21
|
-
nempty: bool = False, **kwargs):
|
|
22
|
-
super(UNet, self).__init__()
|
|
23
|
-
self.in_channels = in_channels
|
|
24
|
-
self.out_channels = out_channels
|
|
25
|
-
self.nempty = nempty
|
|
26
|
-
self.config_network()
|
|
27
|
-
self.encoder_stages = len(self.encoder_blocks)
|
|
28
|
-
self.decoder_stages = len(self.decoder_blocks)
|
|
29
|
-
|
|
30
|
-
# encoder
|
|
31
|
-
self.conv1 = ocnn.modules.OctreeConvBnRelu(
|
|
32
|
-
in_channels, self.encoder_channel[0], nempty=nempty)
|
|
33
|
-
self.downsample = torch.nn.ModuleList([ocnn.modules.OctreeConvBnRelu(
|
|
34
|
-
self.encoder_channel[i], self.encoder_channel[i+1], kernel_size=[2],
|
|
35
|
-
stride=2, nempty=nempty) for i in range(self.encoder_stages)])
|
|
36
|
-
self.encoder = torch.nn.ModuleList([ocnn.modules.OctreeResBlocks(
|
|
37
|
-
self.encoder_channel[i+1], self.encoder_channel[i + 1],
|
|
38
|
-
self.encoder_blocks[i], self.bottleneck, nempty, self.resblk)
|
|
39
|
-
for i in range(self.encoder_stages)])
|
|
40
|
-
|
|
41
|
-
# decoder
|
|
42
|
-
channel = [self.decoder_channel[i+1] + self.encoder_channel[-i-2]
|
|
43
|
-
for i in range(self.decoder_stages)]
|
|
44
|
-
self.upsample = torch.nn.ModuleList([ocnn.modules.OctreeDeconvBnRelu(
|
|
45
|
-
self.decoder_channel[i], self.decoder_channel[i+1], kernel_size=[2],
|
|
46
|
-
stride=2, nempty=nempty) for i in range(self.decoder_stages)])
|
|
47
|
-
self.decoder = torch.nn.ModuleList([ocnn.modules.OctreeResBlocks(
|
|
48
|
-
channel[i], self.decoder_channel[i+1],
|
|
49
|
-
self.decoder_blocks[i], self.bottleneck, nempty, self.resblk)
|
|
50
|
-
for i in range(self.decoder_stages)])
|
|
51
|
-
|
|
52
|
-
# header
|
|
53
|
-
# channel = self.decoder_channel[self.decoder_stages]
|
|
54
|
-
self.octree_interp = ocnn.nn.OctreeInterp(interp, nempty)
|
|
55
|
-
self.header = torch.nn.Sequential(
|
|
56
|
-
ocnn.modules.Conv1x1BnRelu(self.decoder_channel[-1], self.head_channel),
|
|
57
|
-
ocnn.modules.Conv1x1(self.head_channel, self.out_channels, use_bias=True))
|
|
58
|
-
|
|
59
|
-
def config_network(self):
|
|
60
|
-
r''' Configure the network channels and Resblock numbers.
|
|
61
|
-
'''
|
|
62
|
-
|
|
63
|
-
self.encoder_channel = [32, 32, 64, 128, 256]
|
|
64
|
-
self.decoder_channel = [256, 256, 128, 96, 96]
|
|
65
|
-
self.encoder_blocks = [2, 3, 4, 6]
|
|
66
|
-
self.decoder_blocks = [2, 2, 2, 2]
|
|
67
|
-
self.head_channel = 64
|
|
68
|
-
self.bottleneck = 1
|
|
69
|
-
self.resblk = ocnn.modules.OctreeResBlock2
|
|
70
|
-
|
|
71
|
-
def unet_encoder(self, data: torch.Tensor, octree: Octree, depth: int):
|
|
72
|
-
r''' The encoder of the U-Net.
|
|
73
|
-
'''
|
|
74
|
-
|
|
75
|
-
convd = dict()
|
|
76
|
-
convd[depth] = self.conv1(data, octree, depth)
|
|
77
|
-
for i in range(self.encoder_stages):
|
|
78
|
-
d = depth - i
|
|
79
|
-
conv = self.downsample[i](convd[d], octree, d)
|
|
80
|
-
convd[d-1] = self.encoder[i](conv, octree, d-1)
|
|
81
|
-
return convd
|
|
82
|
-
|
|
83
|
-
def unet_decoder(self, convd: Dict[int, torch.Tensor], octree: Octree, depth: int):
|
|
84
|
-
r''' The decoder of the U-Net.
|
|
85
|
-
'''
|
|
86
|
-
|
|
87
|
-
deconv = convd[depth]
|
|
88
|
-
for i in range(self.decoder_stages):
|
|
89
|
-
d = depth + i
|
|
90
|
-
deconv = self.upsample[i](deconv, octree, d)
|
|
91
|
-
deconv = torch.cat([convd[d+1], deconv], dim=1) # skip connections
|
|
92
|
-
deconv = self.decoder[i](deconv, octree, d+1)
|
|
93
|
-
return deconv
|
|
94
|
-
|
|
95
|
-
def forward(self, data: torch.Tensor, octree: Octree, depth: int,
|
|
96
|
-
query_pts: torch.Tensor):
|
|
97
|
-
r''''''
|
|
98
|
-
|
|
99
|
-
convd = self.unet_encoder(data, octree, depth)
|
|
100
|
-
deconv = self.unet_decoder(convd, octree, depth - self.encoder_stages)
|
|
101
|
-
|
|
102
|
-
interp_depth = depth - self.encoder_stages + self.decoder_stages
|
|
103
|
-
feature = self.octree_interp(deconv, octree, interp_depth, query_pts)
|
|
104
|
-
logits = self.header(feature)
|
|
105
|
-
return logits
|
|
1
|
+
# --------------------------------------------------------
|
|
2
|
+
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
+
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
+
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
+
# Written by Peng-Shuai Wang
|
|
6
|
+
# --------------------------------------------------------
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
import torch.nn
|
|
10
|
+
from typing import Dict
|
|
11
|
+
|
|
12
|
+
import ocnn
|
|
13
|
+
from ocnn.octree import Octree
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class UNet(torch.nn.Module):
|
|
17
|
+
r''' Octree-based UNet for segmentation.
|
|
18
|
+
'''
|
|
19
|
+
|
|
20
|
+
def __init__(self, in_channels: int, out_channels: int, interp: str = 'linear',
|
|
21
|
+
nempty: bool = False, **kwargs):
|
|
22
|
+
super(UNet, self).__init__()
|
|
23
|
+
self.in_channels = in_channels
|
|
24
|
+
self.out_channels = out_channels
|
|
25
|
+
self.nempty = nempty
|
|
26
|
+
self.config_network()
|
|
27
|
+
self.encoder_stages = len(self.encoder_blocks)
|
|
28
|
+
self.decoder_stages = len(self.decoder_blocks)
|
|
29
|
+
|
|
30
|
+
# encoder
|
|
31
|
+
self.conv1 = ocnn.modules.OctreeConvBnRelu(
|
|
32
|
+
in_channels, self.encoder_channel[0], nempty=nempty)
|
|
33
|
+
self.downsample = torch.nn.ModuleList([ocnn.modules.OctreeConvBnRelu(
|
|
34
|
+
self.encoder_channel[i], self.encoder_channel[i+1], kernel_size=[2],
|
|
35
|
+
stride=2, nempty=nempty) for i in range(self.encoder_stages)])
|
|
36
|
+
self.encoder = torch.nn.ModuleList([ocnn.modules.OctreeResBlocks(
|
|
37
|
+
self.encoder_channel[i+1], self.encoder_channel[i + 1],
|
|
38
|
+
self.encoder_blocks[i], self.bottleneck, nempty, self.resblk)
|
|
39
|
+
for i in range(self.encoder_stages)])
|
|
40
|
+
|
|
41
|
+
# decoder
|
|
42
|
+
channel = [self.decoder_channel[i+1] + self.encoder_channel[-i-2]
|
|
43
|
+
for i in range(self.decoder_stages)]
|
|
44
|
+
self.upsample = torch.nn.ModuleList([ocnn.modules.OctreeDeconvBnRelu(
|
|
45
|
+
self.decoder_channel[i], self.decoder_channel[i+1], kernel_size=[2],
|
|
46
|
+
stride=2, nempty=nempty) for i in range(self.decoder_stages)])
|
|
47
|
+
self.decoder = torch.nn.ModuleList([ocnn.modules.OctreeResBlocks(
|
|
48
|
+
channel[i], self.decoder_channel[i+1],
|
|
49
|
+
self.decoder_blocks[i], self.bottleneck, nempty, self.resblk)
|
|
50
|
+
for i in range(self.decoder_stages)])
|
|
51
|
+
|
|
52
|
+
# header
|
|
53
|
+
# channel = self.decoder_channel[self.decoder_stages]
|
|
54
|
+
self.octree_interp = ocnn.nn.OctreeInterp(interp, nempty)
|
|
55
|
+
self.header = torch.nn.Sequential(
|
|
56
|
+
ocnn.modules.Conv1x1BnRelu(self.decoder_channel[-1], self.head_channel),
|
|
57
|
+
ocnn.modules.Conv1x1(self.head_channel, self.out_channels, use_bias=True))
|
|
58
|
+
|
|
59
|
+
def config_network(self):
|
|
60
|
+
r''' Configure the network channels and Resblock numbers.
|
|
61
|
+
'''
|
|
62
|
+
|
|
63
|
+
self.encoder_channel = [32, 32, 64, 128, 256]
|
|
64
|
+
self.decoder_channel = [256, 256, 128, 96, 96]
|
|
65
|
+
self.encoder_blocks = [2, 3, 4, 6]
|
|
66
|
+
self.decoder_blocks = [2, 2, 2, 2]
|
|
67
|
+
self.head_channel = 64
|
|
68
|
+
self.bottleneck = 1
|
|
69
|
+
self.resblk = ocnn.modules.OctreeResBlock2
|
|
70
|
+
|
|
71
|
+
def unet_encoder(self, data: torch.Tensor, octree: Octree, depth: int):
|
|
72
|
+
r''' The encoder of the U-Net.
|
|
73
|
+
'''
|
|
74
|
+
|
|
75
|
+
convd = dict()
|
|
76
|
+
convd[depth] = self.conv1(data, octree, depth)
|
|
77
|
+
for i in range(self.encoder_stages):
|
|
78
|
+
d = depth - i
|
|
79
|
+
conv = self.downsample[i](convd[d], octree, d)
|
|
80
|
+
convd[d-1] = self.encoder[i](conv, octree, d-1)
|
|
81
|
+
return convd
|
|
82
|
+
|
|
83
|
+
def unet_decoder(self, convd: Dict[int, torch.Tensor], octree: Octree, depth: int):
|
|
84
|
+
r''' The decoder of the U-Net.
|
|
85
|
+
'''
|
|
86
|
+
|
|
87
|
+
deconv = convd[depth]
|
|
88
|
+
for i in range(self.decoder_stages):
|
|
89
|
+
d = depth + i
|
|
90
|
+
deconv = self.upsample[i](deconv, octree, d)
|
|
91
|
+
deconv = torch.cat([convd[d+1], deconv], dim=1) # skip connections
|
|
92
|
+
deconv = self.decoder[i](deconv, octree, d+1)
|
|
93
|
+
return deconv
|
|
94
|
+
|
|
95
|
+
def forward(self, data: torch.Tensor, octree: Octree, depth: int,
|
|
96
|
+
query_pts: torch.Tensor):
|
|
97
|
+
r''''''
|
|
98
|
+
|
|
99
|
+
convd = self.unet_encoder(data, octree, depth)
|
|
100
|
+
deconv = self.unet_decoder(convd, octree, depth - self.encoder_stages)
|
|
101
|
+
|
|
102
|
+
interp_depth = depth - self.encoder_stages + self.decoder_stages
|
|
103
|
+
feature = self.octree_interp(deconv, octree, interp_depth, query_pts)
|
|
104
|
+
logits = self.header(feature)
|
|
105
|
+
return logits
|