ocnn 2.2.1-py3-none-any.whl → 2.2.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ocnn/__init__.py +24 -24
- ocnn/dataset.py +160 -158
- ocnn/models/__init__.py +29 -27
- ocnn/models/autoencoder.py +155 -165
- ocnn/models/hrnet.py +192 -192
- ocnn/models/image2shape.py +128 -0
- ocnn/models/lenet.py +46 -46
- ocnn/models/ounet.py +94 -94
- ocnn/models/resnet.py +53 -53
- ocnn/models/segnet.py +72 -72
- ocnn/models/unet.py +105 -105
- ocnn/modules/__init__.py +20 -20
- ocnn/modules/modules.py +193 -231
- ocnn/modules/resblocks.py +124 -124
- ocnn/nn/__init__.py +42 -42
- ocnn/nn/octree2col.py +53 -53
- ocnn/nn/octree2vox.py +50 -50
- ocnn/nn/octree_align.py +46 -46
- ocnn/nn/octree_conv.py +411 -411
- ocnn/nn/octree_drop.py +55 -55
- ocnn/nn/octree_dwconv.py +204 -204
- ocnn/nn/octree_gconv.py +79 -0
- ocnn/nn/octree_interp.py +196 -196
- ocnn/nn/octree_norm.py +86 -86
- ocnn/nn/octree_pad.py +39 -39
- ocnn/nn/octree_pool.py +200 -200
- ocnn/octree/__init__.py +22 -21
- ocnn/octree/octree.py +639 -601
- ocnn/octree/points.py +317 -298
- ocnn/octree/shuffled_key.py +115 -115
- ocnn/utils.py +202 -153
- {ocnn-2.2.1.dist-info → ocnn-2.2.2.dist-info}/LICENSE +21 -21
- {ocnn-2.2.1.dist-info → ocnn-2.2.2.dist-info}/METADATA +67 -65
- ocnn-2.2.2.dist-info/RECORD +36 -0
- {ocnn-2.2.1.dist-info → ocnn-2.2.2.dist-info}/WHEEL +1 -1
- ocnn-2.2.1.dist-info/RECORD +0 -34
- {ocnn-2.2.1.dist-info → ocnn-2.2.2.dist-info}/top_level.txt +0 -0
ocnn/models/hrnet.py
CHANGED
@@ -1,192 +1,192 @@
# --------------------------------------------------------
# Octree-based Sparse Convolutional Neural Networks
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
# Licensed under The MIT License [see LICENSE for details]
# Written by Peng-Shuai Wang
# --------------------------------------------------------

import torch
from typing import List

import ocnn
from ocnn.octree import Octree


class Branches(torch.nn.Module):

  def __init__(self, channels: List[int], resblk_num: int, nempty: bool = False):
    super().__init__()
    self.channels = channels
    self.resblk_num = resblk_num
    bottlenecks = [4 if c < 256 else 8 for c in channels]  # to save parameters
    self.resblocks = torch.nn.ModuleList([
        ocnn.modules.OctreeResBlocks(ch, ch, resblk_num, bnk, nempty=nempty)
        for ch, bnk in zip(channels, bottlenecks)])

  def forward(self, datas: List[torch.Tensor], octree: Octree, depth: int):
    num = len(self.channels)
    torch._assert(len(datas) == num, 'Error')

    out = [None] * num
    for i in range(num):
      depth_i = depth - i
      out[i] = self.resblocks[i](datas[i], octree, depth_i)
    return out


class TransFunc(torch.nn.Module):

  def __init__(self, in_channels: int, out_channels: int, nempty: bool = False):
    super().__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.nempty = nempty
    self.maxpool = ocnn.nn.OctreeMaxPool(nempty=nempty)
    self.upsample = ocnn.nn.OctreeUpsample(method='nearest', nempty=nempty)
    if in_channels != out_channels:
      self.conv1x1 = ocnn.modules.Conv1x1BnRelu(in_channels, out_channels)

  def forward(self, data: torch.Tensor, octree: Octree,
              in_depth: int, out_depth: int):
    out = data
    if in_depth > out_depth:
      for d in range(in_depth, out_depth, -1):
        out = self.maxpool(out, octree, d)
      if self.in_channels != self.out_channels:
        out = self.conv1x1(out)

    if in_depth < out_depth:
      if self.in_channels != self.out_channels:
        out = self.conv1x1(out)
      for d in range(in_depth, out_depth, 1):
        out = self.upsample(out, octree, d)
    return out


class Transitions(torch.nn.Module):

  def __init__(self, channels: List[int], nempty: bool = False):
    super().__init__()
    self.channels = channels
    self.nempty = nempty

    num = len(self.channels)
    self.trans_func = torch.nn.ModuleList()
    for i in range(num - 1):
      for j in range(num):
        self.trans_func.append(TransFunc(channels[i], channels[j], nempty))

  def forward(self, data: List[torch.Tensor], octree: Octree, depth: int):
    num = len(self.channels)
    features = [[None] * (num - 1) for _ in range(num)]
    for i in range(num - 1):
      for j in range(num):
        k = i * num + j
        in_depth = depth - i
        out_depth = depth - j
        features[j][i] = self.trans_func[k](
            data[i], octree, in_depth, out_depth)

    out = [None] * num
    for j in range(num):
      # In the original tensorflow implmentation, a relu is added after the sum.
      out[j] = torch.stack(features[j], dim=0).sum(dim=0)
    return out


class FrontLayer(torch.nn.Module):

  def __init__(self, channels: List[int], nempty: bool = False):
    super().__init__()
    self.channels = channels
    self.num = len(channels) - 1
    self.nempty = nempty

    self.conv = torch.nn.ModuleList([
        ocnn.modules.OctreeConvBnRelu(channels[i], channels[i+1], nempty=nempty)
        for i in range(self.num)])
    self.maxpool = torch.nn.ModuleList([
        ocnn.nn.OctreeMaxPool(nempty) for i in range(self.num - 1)])

  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
    out = data
    for i in range(self.num - 1):
      depth_i = depth - i
      out = self.conv[i](out, octree, depth_i)
      out = self.maxpool[i](out, octree, depth_i)
    out = self.conv[-1](out, octree, depth - self.num + 1)
    return out


class ClsHeader(torch.nn.Module):

  def __init__(self, channels: List[int], out_channels: int, nempty: bool = False):
    super().__init__()
    self.channels = channels
    self.out_channels = out_channels
    self.nempty = nempty

    in_channels = int(torch.Tensor(channels).sum())
    self.conv1x1 = ocnn.modules.Conv1x1BnRelu(in_channels, 1024)
    self.global_pool = ocnn.nn.OctreeGlobalPool(nempty)
    self.header = torch.nn.Sequential(
        torch.nn.Flatten(start_dim=1),
        torch.nn.Linear(1024, out_channels, bias=True))
    # self.header = torch.nn.Sequential(
    #     ocnn.modules.FcBnRelu(512, 256),
    #     torch.nn.Dropout(p=0.5),
    #     torch.nn.Linear(256, out_channels))

  def forward(self, data: List[torch.Tensor], octree: Octree, depth: int):
    full_depth = 2
    num = len(data)
    outs = [x for x in data]  # avoid modifying the input data
    for i in range(num):
      depth_i = depth - i
      for d in range(depth_i, full_depth, -1):
        outs[i] = ocnn.nn.octree_max_pool(outs[i], octree, d, self.nempty)

    out = torch.cat(outs, dim=1)
    out = self.conv1x1(out)
    out = self.global_pool(out, octree, full_depth)
    logit = self.header(out)
    return logit


class HRNet(torch.nn.Module):
  r''' Octree-based HRNet for classification and segmentation. '''

  def __init__(self, in_channels: int, out_channels: int, stages: int = 3,
               interp: str = 'linear', nempty: bool = False):
    super().__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.interp = interp
    self.nempty = nempty
    self.stages = stages

    self.resblk_num = 3
    self.channels = [128, 256, 512, 512]

    self.front = FrontLayer([in_channels, 32, self.channels[0]], nempty)
    self.branches = torch.nn.ModuleList([
        Branches(self.channels[:i+1], self.resblk_num, nempty)
        for i in range(stages)])
    self.transitions = torch.nn.ModuleList([
        Transitions(self.channels[:i+2], nempty)
        for i in range(stages-1)])

    self.cls_header = ClsHeader(self.channels[:stages], out_channels, nempty)

  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
    r''''''
    convs = [self.front(data, octree, depth)]
    depth = depth - 1  # the data is downsampled in `front`
    for i in range(self.stages):
      convs = self.branches[i](convs, octree, depth)
      if i < self.stages - 1:
        convs = self.transitions[i](convs, octree, depth)

    logits = self.cls_header(convs, octree, depth)

    return logits
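
For orientation, here is a minimal usage sketch of the HRNet model listed above. It is not part of the package diff: the octree-construction calls (Points, Octree.build_octree, construct_all_neigh, Octree.nnum) are assumed from the ocnn 2.x API, and the per-node input features are dummy values.

import torch
import ocnn
from ocnn.octree import Octree, Points

# Build a depth-5 octree from a random point cloud with coordinates in [-1, 1].
points = Points(points=torch.rand(2048, 3) * 2 - 1)
octree = Octree(depth=5, full_depth=2)
octree.build_octree(points)
octree.construct_all_neigh()  # neighbor indices required by the octree convolutions

# Dummy 3-channel input signal: one row per octree node at the finest depth.
depth = octree.depth
data = torch.rand(int(octree.nnum[depth]), 3)

model = ocnn.models.HRNet(in_channels=3, out_channels=40, stages=3)
model.eval()
with torch.no_grad():
  logits = model(data, octree, depth)  # (batch_size, out_channels) classification logits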
ocnn/models/image2shape.py
ADDED
@@ -0,0 +1,128 @@
# --------------------------------------------------------
# Octree-based Sparse Convolutional Neural Networks
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
# Licensed under The MIT License [see LICENSE for details]
# Written by Peng-Shuai Wang
# --------------------------------------------------------

import torch
import torch.nn
from typing import Optional
from torchvision.models import resnet18

import ocnn
from ocnn.octree import Octree


class Image2Shape(torch.nn.Module):
  r''' Octree-based AutoEncoder for shape encoding and decoding.

  Args:
    channel_out (int): The channel of the output signal.
    depth (int): The depth of the octree.
    full_depth (int): The full depth of the octree.
  '''

  def __init__(self, channel_out: int, depth: int, full_depth: int = 2,
               code_channel: int = 32):
    super().__init__()
    self.depth = depth
    self.full_depth = full_depth
    self.channel_out = channel_out
    self.resblk_num = 2
    self.channels = [512, 512, 256, 256, 128, 128, 64, 64, 32, 32]
    self.code_channel = code_channel

    # encoder
    self.resnet18 = resnet18()
    channel = self.code_channel * 2 ** (3 * full_depth)
    self.resnet18.fc = torch.nn.Linear(512, channel, bias=True)

    # decoder
    self.channels[full_depth] = self.code_channel  # update `channels`
    self.upsample = torch.nn.ModuleList([ocnn.modules.OctreeDeconvBnRelu(
        self.channels[d-1], self.channels[d], kernel_size=[2], stride=2,
        nempty=False) for d in range(full_depth+1, depth+1)])
    self.decoder_blks = torch.nn.ModuleList([ocnn.modules.OctreeResBlocks(
        self.channels[d], self.channels[d], self.resblk_num, nempty=False)
        for d in range(full_depth, depth+1)])

    # header
    self.predict = torch.nn.ModuleList([self._make_predict_module(
        self.channels[d], 2) for d in range(full_depth, depth + 1)])
    self.header = self._make_predict_module(self.channels[depth], channel_out)

  def _make_predict_module(self, channel_in, channel_out=2, num_hidden=64):
    return torch.nn.Sequential(
        ocnn.modules.Conv1x1BnRelu(channel_in, num_hidden),
        ocnn.modules.Conv1x1(num_hidden, channel_out, use_bias=True))

  def decoder(self, shape_code: torch.Tensor, octree: Octree,
              update_octree: bool = False):
    r''' The decoder network of the AutoEncoder.
    '''

    logits = dict()
    deconv = shape_code
    depth, full_depth = self.depth, self.full_depth
    for i, d in enumerate(range(full_depth, depth+1)):
      if d > full_depth:
        deconv = self.upsample[i-1](deconv, octree, d-1)
      deconv = self.decoder_blks[i](deconv, octree, d)

      # predict the splitting label
      logit = self.predict[i](deconv)
      logits[d] = logit

      # update the octree according to predicted labels
      if update_octree:
        split = logit.argmax(1).int()
        octree.octree_split(split, d)
        if d < depth:
          octree.octree_grow(d + 1)

      # predict the signal
      if d == depth:
        signal = self.header(deconv)
        signal = torch.tanh(signal)
        signal = ocnn.nn.octree_depad(signal, octree, depth)
        if update_octree:
          octree.features[depth] = signal

    return {'logits': logits, 'signal': signal, 'octree_out': octree}

  def decode_code(self, shape_code: torch.Tensor):
    r''' Decodes the shape code to an output octree.

    Args:
      shape_code (torch.Tensor): The shape code for decoding.
    '''

    octree_out = self.init_octree(shape_code)
    out = self.decoder(shape_code, octree_out, update_octree=True)
    return out

  def init_octree(self, shape_code: torch.Tensor):
    r''' Initialize a full octree for decoding.

    Args:
      shape_code (torch.Tensor): The shape code for decoding, used to getting
        the `batch_size` and `device` to initialize the output octree.
    '''

    node_num = 2 ** (3 * self.full_depth)
    batch_size = shape_code.size(0) // node_num
    octree = ocnn.octree.init_octree(
        self.depth, self.full_depth, batch_size, shape_code.device)
    return octree

  def forward(self, image: torch.Tensor, octree: Optional[Octree] = None,
              update_octree: bool = False):
    r''''''

    shape_code = self.resnet18(image)
    shape_code = shape_code.view(-1, self.code_channel)
    if update_octree:
      octree = self.init_octree(shape_code)
    out = self.decoder(shape_code, octree, update_octree)
    return out
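
The new Image2Shape model couples a torchvision resnet18 image encoder with an octree decoder. A hedged sketch of how it might be exercised end to end (again, not part of the diff; it assumes torchvision is installed and leaves the weights at their random initialization):

import torch
from ocnn.models.image2shape import Image2Shape

model = Image2Shape(channel_out=4, depth=6, full_depth=2, code_channel=32)
model.eval()

image = torch.rand(2, 3, 224, 224)  # a batch of two RGB images for the resnet18 encoder
with torch.no_grad():
  # With no input octree and update_octree=True, the decoder grows its own output
  # octree from the predicted splitting labels.
  out = model(image, update_octree=True)

print(out['signal'].shape)      # per-node output signal with `channel_out` channels
print(out['octree_out'].depth)  # the predicted octree, grown up to `depth`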
ocnn/models/lenet.py
CHANGED
@@ -1,46 +1,46 @@
# --------------------------------------------------------
# Octree-based Sparse Convolutional Neural Networks
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
# Licensed under The MIT License [see LICENSE for details]
# Written by Peng-Shuai Wang
# --------------------------------------------------------

import torch
import ocnn
from ocnn.octree import Octree


class LeNet(torch.nn.Module):
  r''' Octree-based LeNet for classification.
  '''

  def __init__(self, in_channels: int, out_channels: int, stages: int,
               nempty: bool = False):
    super().__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.stages = stages
    self.nempty = nempty
    channels = [in_channels] + [2 ** max(i+7-stages, 2) for i in range(stages)]

    self.convs = torch.nn.ModuleList([ocnn.modules.OctreeConvBnRelu(
        channels[i], channels[i+1], nempty=nempty) for i in range(stages)])
    self.pools = torch.nn.ModuleList([ocnn.nn.OctreeMaxPool(
        nempty) for i in range(stages)])
    self.octree2voxel = ocnn.nn.Octree2Voxel(self.nempty)
    self.header = torch.nn.Sequential(
        torch.nn.Dropout(p=0.5),              # drop1
        ocnn.modules.FcBnRelu(64 * 64, 128),   # fc1
        torch.nn.Dropout(p=0.5),              # drop2
        torch.nn.Linear(128, out_channels))    # fc2

  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
    r''''''

    for i in range(self.stages):
      d = depth - i
      data = self.convs[i](data, octree, d)
      data = self.pools[i](data, octree, d)
    data = self.octree2voxel(data, octree, depth-self.stages)
    data = self.header(data)
    return data
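
Likewise, a small sketch for the LeNet classifier (not part of the diff), reusing the octree, data, and depth prepared in the HRNet sketch above. The FcBnRelu(64 * 64, 128) head assumes the voxelized features at depth - stages flatten to 4096 values, so a depth-5 octree with stages=3 is the configuration assumed here.

model = ocnn.models.LeNet(in_channels=3, out_channels=40, stages=3)
model.eval()  # eval mode: the BatchNorm in the fully-connected head needs batches > 1 when training
with torch.no_grad():
  logits = model(data, octree, depth)  # (batch_size, out_channels) class logits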