ocnn 2.2.5__py3-none-any.whl → 2.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ocnn/__init__.py +24 -24
- ocnn/dataset.py +160 -160
- ocnn/models/__init__.py +29 -29
- ocnn/models/autoencoder.py +155 -155
- ocnn/models/hrnet.py +192 -192
- ocnn/models/image2shape.py +128 -128
- ocnn/models/lenet.py +46 -46
- ocnn/models/ounet.py +94 -94
- ocnn/models/resnet.py +53 -53
- ocnn/models/segnet.py +72 -72
- ocnn/models/unet.py +105 -105
- ocnn/modules/__init__.py +26 -26
- ocnn/modules/modules.py +303 -303
- ocnn/modules/resblocks.py +158 -158
- ocnn/nn/__init__.py +44 -44
- ocnn/nn/octree2col.py +53 -53
- ocnn/nn/octree2vox.py +50 -50
- ocnn/nn/octree_align.py +46 -46
- ocnn/nn/octree_conv.py +429 -429
- ocnn/nn/octree_drop.py +55 -55
- ocnn/nn/octree_dwconv.py +222 -222
- ocnn/nn/octree_gconv.py +79 -79
- ocnn/nn/octree_interp.py +196 -196
- ocnn/nn/octree_norm.py +126 -126
- ocnn/nn/octree_pad.py +39 -39
- ocnn/nn/octree_pool.py +200 -200
- ocnn/octree/__init__.py +22 -22
- ocnn/octree/octree.py +661 -659
- ocnn/octree/points.py +323 -322
- ocnn/octree/shuffled_key.py +115 -115
- ocnn/utils.py +205 -205
- ocnn-2.2.7.dist-info/METADATA +112 -0
- ocnn-2.2.7.dist-info/RECORD +36 -0
- {ocnn-2.2.5.dist-info → ocnn-2.2.7.dist-info}/WHEEL +1 -1
- {ocnn-2.2.5.dist-info → ocnn-2.2.7.dist-info/licenses}/LICENSE +21 -21
- ocnn-2.2.5.dist-info/METADATA +0 -80
- ocnn-2.2.5.dist-info/RECORD +0 -36
- {ocnn-2.2.5.dist-info → ocnn-2.2.7.dist-info}/top_level.txt +0 -0
ocnn/models/lenet.py
CHANGED
|
@@ -1,46 +1,46 @@
|
|
|
1
|
-
# --------------------------------------------------------
|
|
2
|
-
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
-
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
-
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
-
# Written by Peng-Shuai Wang
|
|
6
|
-
# --------------------------------------------------------
|
|
7
|
-
|
|
8
|
-
import torch
|
|
9
|
-
import ocnn
|
|
10
|
-
from ocnn.octree import Octree
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
class LeNet(torch.nn.Module):
  r''' Octree-based LeNet for classification.

  The network applies `stages` blocks of (octree conv + BN + ReLU, octree max
  pooling), each working one octree level coarser than the previous one. The
  features at depth `depth - stages` are then converted to a dense voxel grid
  and passed through a dropout / fully-connected classification header.

  Args:
    in_channels (int): Number of channels of the input features.
    out_channels (int): Number of outputs of the final linear layer.
    stages (int): Number of conv + pooling stages.
    nempty (bool): Forwarded to every octree conv, pooling and voxelization
        module (presumably: features are defined on non-empty nodes only).
  '''

  def __init__(self, in_channels: int, out_channels: int, stages: int,
               nempty: bool = False):
    super().__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.stages = stages
    self.nempty = nempty

    # Stage widths double per stage and end at 2 ** 6 = 64 for the last stage;
    # early stages are clamped from below at 2 ** 2 = 4.
    widths = [in_channels]
    widths.extend(2 ** max(s + 7 - stages, 2) for s in range(stages))

    self.convs = torch.nn.ModuleList(
        ocnn.modules.OctreeConvBnRelu(c_in, c_out, nempty=nempty)
        for c_in, c_out in zip(widths[:-1], widths[1:]))
    self.pools = torch.nn.ModuleList(
        [ocnn.nn.OctreeMaxPool(nempty) for _ in range(stages)])
    self.octree2voxel = ocnn.nn.Octree2Voxel(self.nempty)

    # NOTE(review): the fc1 input size 64 * 64 appears to assume 64 channels
    # on a 4x4x4 voxel grid, i.e. `depth - stages == 2` in `forward`.
    # Confirm before using other depth/stage combinations.
    self.header = torch.nn.Sequential(
        torch.nn.Dropout(p=0.5),               # drop1
        ocnn.modules.FcBnRelu(64 * 64, 128),   # fc1
        torch.nn.Dropout(p=0.5),               # drop2
        torch.nn.Linear(128, out_channels))    # fc2

  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
    r''' Runs the network on features `data` living at octree depth `depth`.

    Args:
      data (torch.Tensor): Input features at depth `depth`.
      octree (Octree): The octree the features are defined on.
      depth (int): Octree depth of `data`.

    Returns:
      torch.Tensor: The output of the classification header.
    '''

    for stage in range(self.stages):
      level = depth - stage
      conv, pool = self.convs[stage], self.pools[stage]
      data = pool(conv(data, octree, level), octree, level)
    voxels = self.octree2voxel(data, octree, depth - self.stages)
    return self.header(voxels)
|
|
1
|
+
# --------------------------------------------------------
|
|
2
|
+
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
+
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
+
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
+
# Written by Peng-Shuai Wang
|
|
6
|
+
# --------------------------------------------------------
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
import ocnn
|
|
10
|
+
from ocnn.octree import Octree
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class LeNet(torch.nn.Module):
  r''' Octree-based LeNet for classification.

  The network applies `stages` blocks of (octree conv + BN + ReLU, octree max
  pooling), each working one octree level coarser than the previous one. The
  features at depth `depth - stages` are then converted to a dense voxel grid
  and passed through a dropout / fully-connected classification header.

  Args:
    in_channels (int): Number of channels of the input features.
    out_channels (int): Number of outputs of the final linear layer.
    stages (int): Number of conv + pooling stages.
    nempty (bool): Forwarded to every octree conv, pooling and voxelization
        module (presumably: features are defined on non-empty nodes only).
  '''

  def __init__(self, in_channels: int, out_channels: int, stages: int,
               nempty: bool = False):
    super().__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.stages = stages
    self.nempty = nempty

    # Stage widths double per stage and end at 2 ** 6 = 64 for the last stage;
    # early stages are clamped from below at 2 ** 2 = 4.
    widths = [in_channels]
    widths.extend(2 ** max(s + 7 - stages, 2) for s in range(stages))

    self.convs = torch.nn.ModuleList(
        ocnn.modules.OctreeConvBnRelu(c_in, c_out, nempty=nempty)
        for c_in, c_out in zip(widths[:-1], widths[1:]))
    self.pools = torch.nn.ModuleList(
        [ocnn.nn.OctreeMaxPool(nempty) for _ in range(stages)])
    self.octree2voxel = ocnn.nn.Octree2Voxel(self.nempty)

    # NOTE(review): the fc1 input size 64 * 64 appears to assume 64 channels
    # on a 4x4x4 voxel grid, i.e. `depth - stages == 2` in `forward`.
    # Confirm before using other depth/stage combinations.
    self.header = torch.nn.Sequential(
        torch.nn.Dropout(p=0.5),               # drop1
        ocnn.modules.FcBnRelu(64 * 64, 128),   # fc1
        torch.nn.Dropout(p=0.5),               # drop2
        torch.nn.Linear(128, out_channels))    # fc2

  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
    r''' Runs the network on features `data` living at octree depth `depth`.

    Args:
      data (torch.Tensor): Input features at depth `depth`.
      octree (Octree): The octree the features are defined on.
      depth (int): Octree depth of `data`.

    Returns:
      torch.Tensor: The output of the classification header.
    '''

    for stage in range(self.stages):
      level = depth - stage
      conv, pool = self.convs[stage], self.pools[stage]
      data = pool(conv(data, octree, level), octree, level)
    voxels = self.octree2voxel(data, octree, depth - self.stages)
    return self.header(voxels)
|
ocnn/models/ounet.py
CHANGED
|
@@ -1,94 +1,94 @@
|
|
|
1
|
-
# --------------------------------------------------------
|
|
2
|
-
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
-
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
-
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
-
# Written by Peng-Shuai Wang
|
|
6
|
-
# --------------------------------------------------------
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
import ocnn
|
|
10
|
-
import torch
|
|
11
|
-
import torch.nn
|
|
12
|
-
|
|
13
|
-
from ocnn.octree import Octree
|
|
14
|
-
from ocnn.models.autoencoder import AutoEncoder
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
class OUNet(AutoEncoder):
  r''' Octree-based U-Net that predicts an output octree and a signal on it.

  Reuses the layers built by `AutoEncoder` (its `proj` module is dropped) and
  adds output-guided skip connections: encoder features computed on the input
  octree are aligned onto the output octree and added to the decoder features.

  Args:
    channel_in (int): Number of channels of the input feature; checked in
        `encoder`.
    channel_out (int): Forwarded to `AutoEncoder`.
    depth (int): Finest octree depth handled by the network.
    full_depth (int): Coarsest depth handled; `init_octree` grows the output
        octree fully up to this depth.
    feature (str): Feature type requested via `octree.get_input_feature`.
  '''

  def __init__(self, channel_in: int, channel_out: int, depth: int,
               full_depth: int = 2, feature: str = 'ND'):
    super().__init__(channel_in, channel_out, depth, full_depth, feature)
    # The projection layer of AutoEncoder is not used by OUNet.
    self.proj = None

  def encoder(self, octree):
    r''' Extracts hierarchical features from `octree`.

    Returns:
      dict: Maps every depth from `self.depth` down to `self.full_depth` to
      the encoder feature computed at that depth.
    '''

    max_depth, min_depth = self.depth, self.full_depth
    feats = octree.get_input_feature(self.feature, nempty=False)
    assert feats.size(1) == self.channel_in

    convs = {max_depth: self.conv1(feats, octree, max_depth)}
    for blk_idx, d in enumerate(range(max_depth, min_depth - 1, -1)):
      convs[d] = self.encoder_blks[blk_idx](convs[d], octree, d)
      if d > min_depth:
        convs[d - 1] = self.downsample[blk_idx](convs[d], octree, d)
    return convs

  def decoder(self, convs: dict, octree_in: Octree, octree_out: Octree,
              update_octree: bool = False):
    r''' Decodes the output octree from coarse (`full_depth`) to fine (`depth`).

    Args:
      convs (dict): Encoder features keyed by depth, as returned by `encoder`.
      octree_in (Octree): The input octree the encoder features live on.
      octree_out (Octree): The output octree the decoder runs on.
      update_octree (bool): If True, `octree_out` is split (and grown to the
          next depth) in place from the argmax of the predicted logits, and
          the final signal is stored in `octree_out.features[depth]`.

    Returns:
      dict: `'logits'` (split logits keyed by depth), `'signal'` (tanh output
      of the header at the finest depth, depadded on `octree_out`) and
      `'octree_out'` (the output octree).
    '''

    max_depth, min_depth = self.depth, self.full_depth
    logits = dict()
    deconv = convs[min_depth]
    for blk_idx, d in enumerate(range(min_depth, max_depth + 1)):
      if d > min_depth:
        deconv = self.upsample[blk_idx - 1](deconv, octree_out, d - 1)
        # output-guided skip connection: move encoder features onto octree_out
        skip = ocnn.nn.octree_align(convs[d], octree_in, octree_out, d)
        deconv = deconv + skip
      deconv = self.decoder_blks[blk_idx](deconv, octree_out, d)

      # predict the splitting label at this depth
      logit = self.predict[blk_idx](deconv)
      logits[d] = logit

      # refine the output octree from the predicted labels
      if update_octree:
        split = logit.argmax(1).int()
        octree_out.octree_split(split, d)
        if d < max_depth:
          octree_out.octree_grow(d + 1)

      # predict the signal once the finest depth is reached
      if d == max_depth:
        signal = torch.tanh(self.header(deconv))
        signal = ocnn.nn.octree_depad(signal, octree_out, max_depth)
        if update_octree:
          octree_out.features[max_depth] = signal

    return {'logits': logits, 'signal': signal, 'octree_out': octree_out}

  def init_octree(self, octree_in: Octree):
    r''' Creates the initial output octree for decoding.

    The octree uses the device and batch size of `octree_in` and is fully
    grown at every depth from 0 to `self.full_depth`.
    '''

    device, batch_size = octree_in.device, octree_in.batch_size
    octree = Octree(self.depth, self.full_depth, batch_size, device)
    for d in range(self.full_depth + 1):
      octree.octree_grow_full(depth=d)
    return octree

  def forward(self, octree_in, octree_out=None, update_octree: bool = False):
    r''' Encodes `octree_in` and decodes an output octree.

    When `octree_out` is None, a fresh octree is built by `init_octree` and
    `update_octree` is forced to True so the decoder grows it.

    Returns:
      dict: The dictionary produced by `decoder`.
    '''

    if octree_out is None:
      octree_out = self.init_octree(octree_in)
      update_octree = True
    convs = self.encoder(octree_in)
    return self.decoder(convs, octree_in, octree_out, update_octree)
|
|
1
|
+
# --------------------------------------------------------
|
|
2
|
+
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
+
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
+
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
+
# Written by Peng-Shuai Wang
|
|
6
|
+
# --------------------------------------------------------
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
import ocnn
|
|
10
|
+
import torch
|
|
11
|
+
import torch.nn
|
|
12
|
+
|
|
13
|
+
from ocnn.octree import Octree
|
|
14
|
+
from ocnn.models.autoencoder import AutoEncoder
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class OUNet(AutoEncoder):
  r''' Octree-based U-Net that predicts an output octree and a signal on it.

  Reuses the layers built by `AutoEncoder` (its `proj` module is dropped) and
  adds output-guided skip connections: encoder features computed on the input
  octree are aligned onto the output octree and added to the decoder features.

  Args:
    channel_in (int): Number of channels of the input feature; checked in
        `encoder`.
    channel_out (int): Forwarded to `AutoEncoder`.
    depth (int): Finest octree depth handled by the network.
    full_depth (int): Coarsest depth handled; `init_octree` grows the output
        octree fully up to this depth.
    feature (str): Feature type requested via `octree.get_input_feature`.
  '''

  def __init__(self, channel_in: int, channel_out: int, depth: int,
               full_depth: int = 2, feature: str = 'ND'):
    super().__init__(channel_in, channel_out, depth, full_depth, feature)
    # The projection layer of AutoEncoder is not used by OUNet.
    self.proj = None

  def encoder(self, octree):
    r''' Extracts hierarchical features from `octree`.

    Returns:
      dict: Maps every depth from `self.depth` down to `self.full_depth` to
      the encoder feature computed at that depth.
    '''

    max_depth, min_depth = self.depth, self.full_depth
    feats = octree.get_input_feature(self.feature, nempty=False)
    assert feats.size(1) == self.channel_in

    convs = {max_depth: self.conv1(feats, octree, max_depth)}
    for blk_idx, d in enumerate(range(max_depth, min_depth - 1, -1)):
      convs[d] = self.encoder_blks[blk_idx](convs[d], octree, d)
      if d > min_depth:
        convs[d - 1] = self.downsample[blk_idx](convs[d], octree, d)
    return convs

  def decoder(self, convs: dict, octree_in: Octree, octree_out: Octree,
              update_octree: bool = False):
    r''' Decodes the output octree from coarse (`full_depth`) to fine (`depth`).

    Args:
      convs (dict): Encoder features keyed by depth, as returned by `encoder`.
      octree_in (Octree): The input octree the encoder features live on.
      octree_out (Octree): The output octree the decoder runs on.
      update_octree (bool): If True, `octree_out` is split (and grown to the
          next depth) in place from the argmax of the predicted logits, and
          the final signal is stored in `octree_out.features[depth]`.

    Returns:
      dict: `'logits'` (split logits keyed by depth), `'signal'` (tanh output
      of the header at the finest depth, depadded on `octree_out`) and
      `'octree_out'` (the output octree).
    '''

    max_depth, min_depth = self.depth, self.full_depth
    logits = dict()
    deconv = convs[min_depth]
    for blk_idx, d in enumerate(range(min_depth, max_depth + 1)):
      if d > min_depth:
        deconv = self.upsample[blk_idx - 1](deconv, octree_out, d - 1)
        # output-guided skip connection: move encoder features onto octree_out
        skip = ocnn.nn.octree_align(convs[d], octree_in, octree_out, d)
        deconv = deconv + skip
      deconv = self.decoder_blks[blk_idx](deconv, octree_out, d)

      # predict the splitting label at this depth
      logit = self.predict[blk_idx](deconv)
      logits[d] = logit

      # refine the output octree from the predicted labels
      if update_octree:
        split = logit.argmax(1).int()
        octree_out.octree_split(split, d)
        if d < max_depth:
          octree_out.octree_grow(d + 1)

      # predict the signal once the finest depth is reached
      if d == max_depth:
        signal = torch.tanh(self.header(deconv))
        signal = ocnn.nn.octree_depad(signal, octree_out, max_depth)
        if update_octree:
          octree_out.features[max_depth] = signal

    return {'logits': logits, 'signal': signal, 'octree_out': octree_out}

  def init_octree(self, octree_in: Octree):
    r''' Creates the initial output octree for decoding.

    The octree uses the device and batch size of `octree_in` and is fully
    grown at every depth from 0 to `self.full_depth`.
    '''

    device, batch_size = octree_in.device, octree_in.batch_size
    octree = Octree(self.depth, self.full_depth, batch_size, device)
    for d in range(self.full_depth + 1):
      octree.octree_grow_full(depth=d)
    return octree

  def forward(self, octree_in, octree_out=None, update_octree: bool = False):
    r''' Encodes `octree_in` and decodes an output octree.

    When `octree_out` is None, a fresh octree is built by `init_octree` and
    `update_octree` is forced to True so the decoder grows it.

    Returns:
      dict: The dictionary produced by `decoder`.
    '''

    if octree_out is None:
      octree_out = self.init_octree(octree_in)
      update_octree = True
    convs = self.encoder(octree_in)
    return self.decoder(convs, octree_in, octree_out, update_octree)
|
ocnn/models/resnet.py
CHANGED
|
@@ -1,53 +1,53 @@
|
|
|
1
|
-
# --------------------------------------------------------
|
|
2
|
-
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
-
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
-
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
-
# Written by Peng-Shuai Wang
|
|
6
|
-
# --------------------------------------------------------
|
|
7
|
-
|
|
8
|
-
import torch
|
|
9
|
-
import ocnn
|
|
10
|
-
from ocnn.octree import Octree
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
class ResNet(torch.nn.Module):
  r''' Octree-based ResNet for classification.

  Layout: an octree conv + BN + ReLU stem followed by max pooling, then
  `stages - 1` stages of residual blocks each followed by max pooling, an
  octree global pooling at depth `depth - stages`, and an FC + BN + ReLU /
  dropout / linear classification header.

  Args:
    in_channels (int): Number of channels of the input features.
    out_channels (int): Number of outputs of the final linear layer.
    resblock_num (int): Forwarded to every `ocnn.modules.OctreeResBlocks`.
    stages (int): Number of stages, including the stem.
    nempty (bool): Forwarded to every octree conv, residual and pooling
        module (presumably: features are defined on non-empty nodes only).
  '''

  def __init__(self, in_channels: int, out_channels: int, resblock_num: int,
               stages: int, nempty: bool = False):
    super().__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.resblk_num = resblock_num
    self.stages = stages
    self.nempty = nempty

    # Stage widths double per stage and end at 2 ** 8 = 256 for the last
    # stage; early stages are clamped from below at 2 ** 2 = 4.
    widths = [2 ** max(k + 9 - stages, 2) for k in range(stages)]

    # stem
    self.conv1 = ocnn.modules.OctreeConvBnRelu(
        in_channels, widths[0], nempty=nempty)
    self.pool1 = ocnn.nn.OctreeMaxPool(nempty)

    # residual stages
    self.resblocks = torch.nn.ModuleList(
        ocnn.modules.OctreeResBlocks(c_in, c_out, resblock_num, nempty=nempty)
        for c_in, c_out in zip(widths[:-1], widths[1:]))
    self.pools = torch.nn.ModuleList(
        [ocnn.nn.OctreeMaxPool(nempty) for _ in range(stages - 1)])

    # classification
    self.global_pool = ocnn.nn.OctreeGlobalPool(nempty)
    self.header = torch.nn.Sequential(
        ocnn.modules.FcBnRelu(widths[-1], 512),
        torch.nn.Dropout(p=0.5),
        torch.nn.Linear(512, out_channels))

  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
    r''' Runs the network on features `data` living at octree depth `depth`.

    Args:
      data (torch.Tensor): Input features at depth `depth`.
      octree (Octree): The octree the features are defined on.
      depth (int): Octree depth of `data`.

    Returns:
      torch.Tensor: The output of the classification header.
    '''

    data = self.pool1(self.conv1(data, octree, depth), octree, depth)
    for k in range(self.stages - 1):
      level = depth - 1 - k
      data = self.resblocks[k](data, octree, level)
      data = self.pools[k](data, octree, level)
    pooled = self.global_pool(data, octree, depth - self.stages)
    return self.header(pooled)
|
|
1
|
+
# --------------------------------------------------------
|
|
2
|
+
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
+
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
+
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
+
# Written by Peng-Shuai Wang
|
|
6
|
+
# --------------------------------------------------------
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
import ocnn
|
|
10
|
+
from ocnn.octree import Octree
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class ResNet(torch.nn.Module):
  r''' Octree-based ResNet for classification.

  Layout: an octree conv + BN + ReLU stem followed by max pooling, then
  `stages - 1` stages of residual blocks each followed by max pooling, an
  octree global pooling at depth `depth - stages`, and an FC + BN + ReLU /
  dropout / linear classification header.

  Args:
    in_channels (int): Number of channels of the input features.
    out_channels (int): Number of outputs of the final linear layer.
    resblock_num (int): Forwarded to every `ocnn.modules.OctreeResBlocks`.
    stages (int): Number of stages, including the stem.
    nempty (bool): Forwarded to every octree conv, residual and pooling
        module (presumably: features are defined on non-empty nodes only).
  '''

  def __init__(self, in_channels: int, out_channels: int, resblock_num: int,
               stages: int, nempty: bool = False):
    super().__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.resblk_num = resblock_num
    self.stages = stages
    self.nempty = nempty

    # Stage widths double per stage and end at 2 ** 8 = 256 for the last
    # stage; early stages are clamped from below at 2 ** 2 = 4.
    widths = [2 ** max(k + 9 - stages, 2) for k in range(stages)]

    # stem
    self.conv1 = ocnn.modules.OctreeConvBnRelu(
        in_channels, widths[0], nempty=nempty)
    self.pool1 = ocnn.nn.OctreeMaxPool(nempty)

    # residual stages
    self.resblocks = torch.nn.ModuleList(
        ocnn.modules.OctreeResBlocks(c_in, c_out, resblock_num, nempty=nempty)
        for c_in, c_out in zip(widths[:-1], widths[1:]))
    self.pools = torch.nn.ModuleList(
        [ocnn.nn.OctreeMaxPool(nempty) for _ in range(stages - 1)])

    # classification
    self.global_pool = ocnn.nn.OctreeGlobalPool(nempty)
    self.header = torch.nn.Sequential(
        ocnn.modules.FcBnRelu(widths[-1], 512),
        torch.nn.Dropout(p=0.5),
        torch.nn.Linear(512, out_channels))

  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
    r''' Runs the network on features `data` living at octree depth `depth`.

    Args:
      data (torch.Tensor): Input features at depth `depth`.
      octree (Octree): The octree the features are defined on.
      depth (int): Octree depth of `data`.

    Returns:
      torch.Tensor: The output of the classification header.
    '''

    data = self.pool1(self.conv1(data, octree, depth), octree, depth)
    for k in range(self.stages - 1):
      level = depth - 1 - k
      data = self.resblocks[k](data, octree, level)
      data = self.pools[k](data, octree, level)
    pooled = self.global_pool(data, octree, depth - self.stages)
    return self.header(pooled)
|
ocnn/models/segnet.py
CHANGED
|
@@ -1,72 +1,72 @@
|
|
|
1
|
-
# --------------------------------------------------------
|
|
2
|
-
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
-
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
-
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
-
# Written by Peng-Shuai Wang
|
|
6
|
-
# --------------------------------------------------------
|
|
7
|
-
|
|
8
|
-
import torch
|
|
9
|
-
import ocnn
|
|
10
|
-
from ocnn.octree import Octree
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
class SegNet(torch.nn.Module):
  r''' Octree-based SegNet for segmentation.

  Encoder: `stages` blocks of (octree conv + BN + ReLU, octree max pooling
  that also returns its indices). A bottleneck conv runs at depth
  `depth - stages`. Decoder: `stages` blocks of (octree max unpooling with the
  stored indices, octree conv + BN + ReLU). The finest decoder features are
  interpolated at `query_pts` and passed through a 1x1-conv header.

  Args:
    in_channels (int): Number of channels of the input features.
    out_channels (int): Number of outputs of the final 1x1 conv.
    stages (int): Number of encoder (and decoder) stages.
    interp (str): Interpolation method given to `ocnn.nn.OctreeInterp`.
    nempty (bool): Forwarded to every octree conv, pooling, unpooling and
        interpolation module, including the bottleneck.
    **kwargs: Accepted for call-site compatibility and ignored.
  '''

  def __init__(self, in_channels: int, out_channels: int, stages: int,
               interp: str = 'linear', nempty: bool = False, **kwargs):
    super().__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.stages = stages
    self.nempty = nempty
    return_indices = True  # the decoder unpools with the encoder's indices

    channels_stages = [2 ** max(i+8-stages, 2) for i in range(stages)]
    channels = [in_channels] + channels_stages
    self.convs = torch.nn.ModuleList(
        [ocnn.modules.OctreeConvBnRelu(channels[i], channels[i+1], nempty=nempty)
         for i in range(stages)])
    self.pools = torch.nn.ModuleList(
        [ocnn.nn.OctreeMaxPool(nempty, return_indices) for _ in range(stages)])

    # The bottleneck consumes the output of the last `self.pools` layer, which
    # is built with `nempty`, so it must use the same `nempty` setting as every
    # other layer. Previously it silently used the default (nempty=False),
    # which mismatches the pooled features when the model is built with
    # nempty=True. Parameter shapes are unchanged by this argument.
    self.bottleneck = ocnn.modules.OctreeConvBnRelu(
        channels[-1], channels[-1], nempty=nempty)

    channels = channels_stages[::-1] + [channels_stages[0]]
    self.deconvs = torch.nn.ModuleList(
        [ocnn.modules.OctreeConvBnRelu(channels[i], channels[i+1], nempty=nempty)
         for i in range(stages)])
    self.unpools = torch.nn.ModuleList(
        [ocnn.nn.OctreeMaxUnpool(nempty) for _ in range(stages)])

    self.octree_interp = ocnn.nn.OctreeInterp(interp, nempty)
    self.header = torch.nn.Sequential(
        ocnn.modules.Conv1x1BnRelu(channels[-1], 64),
        ocnn.modules.Conv1x1(64, out_channels, use_bias=True))

  def forward(self, data: torch.Tensor, octree: Octree, depth: int,
              query_pts: torch.Tensor):
    r''' Runs the network and evaluates the header at `query_pts`.

    Args:
      data (torch.Tensor): Input features at octree depth `depth`.
      octree (Octree): The octree the features are defined on.
      depth (int): Octree depth of `data`.
      query_pts (torch.Tensor): Query points handed to `self.octree_interp`
          (format as expected by `ocnn.nn.OctreeInterp`).

    Returns:
      torch.Tensor: The output of `self.header` applied to the interpolated
      features.
    '''

    # encoder: keep the pooling indices of each depth for the decoder
    indices = dict()
    for i in range(self.stages):
      d = depth - i
      data = self.convs[i](data, octree, d)
      data, indices[d] = self.pools[i](data, octree, d)

    # bottleneck at the coarsest depth
    data = self.bottleneck(data, octree, depth - self.stages)

    # decoder: unpool with the indices recorded at depth d + 1, then
    # convolve at depth d + 1
    for i in range(self.stages):
      d = depth - self.stages + i
      data = self.unpools[i](data, indices[d + 1], octree, d)
      data = self.deconvs[i](data, octree, d + 1)

    # header
    feature = self.octree_interp(data, octree, depth, query_pts)
    logits = self.header(feature)
    return logits
|
|
1
|
+
# --------------------------------------------------------
|
|
2
|
+
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
+
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
+
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
+
# Written by Peng-Shuai Wang
|
|
6
|
+
# --------------------------------------------------------
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
import ocnn
|
|
10
|
+
from ocnn.octree import Octree
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class SegNet(torch.nn.Module):
  r''' Octree-based SegNet for segmentation.

  Encoder: `stages` blocks of (octree conv + BN + ReLU, octree max pooling
  that also returns its indices). A bottleneck conv runs at depth
  `depth - stages`. Decoder: `stages` blocks of (octree max unpooling with the
  stored indices, octree conv + BN + ReLU). The finest decoder features are
  interpolated at `query_pts` and passed through a 1x1-conv header.

  Args:
    in_channels (int): Number of channels of the input features.
    out_channels (int): Number of outputs of the final 1x1 conv.
    stages (int): Number of encoder (and decoder) stages.
    interp (str): Interpolation method given to `ocnn.nn.OctreeInterp`.
    nempty (bool): Forwarded to every octree conv, pooling, unpooling and
        interpolation module, including the bottleneck.
    **kwargs: Accepted for call-site compatibility and ignored.
  '''

  def __init__(self, in_channels: int, out_channels: int, stages: int,
               interp: str = 'linear', nempty: bool = False, **kwargs):
    super().__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.stages = stages
    self.nempty = nempty
    return_indices = True  # the decoder unpools with the encoder's indices

    channels_stages = [2 ** max(i+8-stages, 2) for i in range(stages)]
    channels = [in_channels] + channels_stages
    self.convs = torch.nn.ModuleList(
        [ocnn.modules.OctreeConvBnRelu(channels[i], channels[i+1], nempty=nempty)
         for i in range(stages)])
    self.pools = torch.nn.ModuleList(
        [ocnn.nn.OctreeMaxPool(nempty, return_indices) for _ in range(stages)])

    # The bottleneck consumes the output of the last `self.pools` layer, which
    # is built with `nempty`, so it must use the same `nempty` setting as every
    # other layer. Previously it silently used the default (nempty=False),
    # which mismatches the pooled features when the model is built with
    # nempty=True. Parameter shapes are unchanged by this argument.
    self.bottleneck = ocnn.modules.OctreeConvBnRelu(
        channels[-1], channels[-1], nempty=nempty)

    channels = channels_stages[::-1] + [channels_stages[0]]
    self.deconvs = torch.nn.ModuleList(
        [ocnn.modules.OctreeConvBnRelu(channels[i], channels[i+1], nempty=nempty)
         for i in range(stages)])
    self.unpools = torch.nn.ModuleList(
        [ocnn.nn.OctreeMaxUnpool(nempty) for _ in range(stages)])

    self.octree_interp = ocnn.nn.OctreeInterp(interp, nempty)
    self.header = torch.nn.Sequential(
        ocnn.modules.Conv1x1BnRelu(channels[-1], 64),
        ocnn.modules.Conv1x1(64, out_channels, use_bias=True))

  def forward(self, data: torch.Tensor, octree: Octree, depth: int,
              query_pts: torch.Tensor):
    r''' Runs the network and evaluates the header at `query_pts`.

    Args:
      data (torch.Tensor): Input features at octree depth `depth`.
      octree (Octree): The octree the features are defined on.
      depth (int): Octree depth of `data`.
      query_pts (torch.Tensor): Query points handed to `self.octree_interp`
          (format as expected by `ocnn.nn.OctreeInterp`).

    Returns:
      torch.Tensor: The output of `self.header` applied to the interpolated
      features.
    '''

    # encoder: keep the pooling indices of each depth for the decoder
    indices = dict()
    for i in range(self.stages):
      d = depth - i
      data = self.convs[i](data, octree, d)
      data, indices[d] = self.pools[i](data, octree, d)

    # bottleneck at the coarsest depth
    data = self.bottleneck(data, octree, depth - self.stages)

    # decoder: unpool with the indices recorded at depth d + 1, then
    # convolve at depth d + 1
    for i in range(self.stages):
      d = depth - self.stages + i
      data = self.unpools[i](data, indices[d + 1], octree, d)
      data = self.deconvs[i](data, octree, d + 1)

    # header
    feature = self.octree_interp(data, octree, depth, query_pts)
    logits = self.header(feature)
    return logits
|