ocnn 2.2.0__py3-none-any.whl → 2.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ocnn/__init__.py +24 -24
- ocnn/dataset.py +160 -158
- ocnn/models/__init__.py +29 -24
- ocnn/models/autoencoder.py +155 -165
- ocnn/models/hrnet.py +192 -191
- ocnn/models/image2shape.py +128 -0
- ocnn/models/lenet.py +46 -46
- ocnn/models/ounet.py +94 -0
- ocnn/models/resnet.py +53 -53
- ocnn/models/segnet.py +72 -72
- ocnn/models/unet.py +105 -105
- ocnn/modules/__init__.py +20 -20
- ocnn/modules/modules.py +193 -231
- ocnn/modules/resblocks.py +124 -124
- ocnn/nn/__init__.py +42 -40
- ocnn/nn/octree2col.py +53 -53
- ocnn/nn/octree2vox.py +50 -50
- ocnn/nn/octree_align.py +46 -0
- ocnn/nn/octree_conv.py +411 -411
- ocnn/nn/octree_drop.py +55 -55
- ocnn/nn/octree_dwconv.py +204 -204
- ocnn/nn/octree_gconv.py +79 -0
- ocnn/nn/octree_interp.py +196 -196
- ocnn/nn/octree_norm.py +86 -56
- ocnn/nn/octree_pad.py +39 -39
- ocnn/nn/octree_pool.py +200 -197
- ocnn/octree/__init__.py +22 -21
- ocnn/octree/octree.py +639 -581
- ocnn/octree/points.py +317 -298
- ocnn/octree/shuffled_key.py +115 -115
- ocnn/utils.py +202 -153
- {ocnn-2.2.0.dist-info → ocnn-2.2.2.dist-info}/LICENSE +21 -21
- {ocnn-2.2.0.dist-info → ocnn-2.2.2.dist-info}/METADATA +67 -65
- ocnn-2.2.2.dist-info/RECORD +36 -0
- {ocnn-2.2.0.dist-info → ocnn-2.2.2.dist-info}/WHEEL +1 -1
- ocnn-2.2.0.dist-info/RECORD +0 -32
- {ocnn-2.2.0.dist-info → ocnn-2.2.2.dist-info}/top_level.txt +0 -0
ocnn/models/autoencoder.py
CHANGED
|
@@ -1,165 +1,155 @@
|
|
|
1
|
-
# --------------------------------------------------------
|
|
2
|
-
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
-
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
-
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
-
# Written by Peng-Shuai Wang
|
|
6
|
-
# --------------------------------------------------------
|
|
7
|
-
|
|
8
|
-
import torch
|
|
9
|
-
import torch.nn
|
|
10
|
-
from typing import Optional
|
|
11
|
-
|
|
12
|
-
import ocnn
|
|
13
|
-
from ocnn.octree import Octree
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
class AutoEncoder(torch.nn.Module):
|
|
17
|
-
r''' Octree-based AutoEncoder for shape encoding and decoding.
|
|
18
|
-
|
|
19
|
-
Args:
|
|
20
|
-
channel_in (int): The channel of the input signal.
|
|
21
|
-
channel_out (int): The channel of the output signal.
|
|
22
|
-
depth (int): The depth of the octree.
|
|
23
|
-
full_depth (int): The full depth of the octree.
|
|
24
|
-
feature (str): The feature type of the input signal. For details of this
|
|
25
|
-
argument, please refer to :class:`ocnn.modules.InputFeature`.
|
|
26
|
-
'''
|
|
27
|
-
|
|
28
|
-
def __init__(self, channel_in: int, channel_out: int, depth: int,
|
|
29
|
-
full_depth: int = 2, feature: str = 'ND'):
|
|
30
|
-
super().__init__()
|
|
31
|
-
self.channel_in = channel_in
|
|
32
|
-
self.channel_out = channel_out
|
|
33
|
-
self.depth = depth
|
|
34
|
-
self.full_depth = full_depth
|
|
35
|
-
self.feature = feature
|
|
36
|
-
self.resblk_num = 2
|
|
37
|
-
self.
|
|
38
|
-
self.channels = [512, 512, 256, 256, 128, 128, 32, 32, 16, 16]
|
|
39
|
-
|
|
40
|
-
# encoder
|
|
41
|
-
self.conv1 = ocnn.modules.OctreeConvBnRelu(
|
|
42
|
-
channel_in, self.channels[depth], nempty=False)
|
|
43
|
-
self.encoder_blks = torch.nn.ModuleList([ocnn.modules.OctreeResBlocks(
|
|
44
|
-
self.channels[d], self.channels[d], self.resblk_num, nempty=False)
|
|
45
|
-
for d in range(depth, full_depth-1, -1)])
|
|
46
|
-
self.downsample = torch.nn.ModuleList([ocnn.modules.OctreeConvBnRelu(
|
|
47
|
-
self.channels[d], self.channels[d-1], kernel_size=[2], stride=2,
|
|
48
|
-
nempty=False) for d in range(depth, full_depth, -1)])
|
|
49
|
-
self.proj = torch.nn.Linear(
|
|
50
|
-
self.channels[full_depth], self.
|
|
51
|
-
|
|
52
|
-
# decoder
|
|
53
|
-
self.channels[full_depth] = self.
|
|
54
|
-
self.upsample = torch.nn.ModuleList([ocnn.modules.OctreeDeconvBnRelu(
|
|
55
|
-
self.channels[d-1], self.channels[d], kernel_size=[2], stride=2,
|
|
56
|
-
nempty=False) for d in range(full_depth+1, depth+1)])
|
|
57
|
-
self.decoder_blks = torch.nn.ModuleList([ocnn.modules.OctreeResBlocks(
|
|
58
|
-
self.channels[d], self.channels[d], self.resblk_num, nempty=False)
|
|
59
|
-
for d in range(full_depth, depth+1)])
|
|
60
|
-
|
|
61
|
-
# header
|
|
62
|
-
self.predict = torch.nn.ModuleList([self._make_predict_module(
|
|
63
|
-
self.channels[d], 2) for d in range(full_depth, depth + 1)])
|
|
64
|
-
self.header = self._make_predict_module(self.channels[depth], channel_out)
|
|
65
|
-
|
|
66
|
-
def _make_predict_module(self, channel_in, channel_out=2, num_hidden=64):
|
|
67
|
-
return torch.nn.Sequential(
|
|
68
|
-
ocnn.modules.Conv1x1BnRelu(channel_in, num_hidden),
|
|
69
|
-
ocnn.modules.Conv1x1(num_hidden, channel_out, use_bias=True))
|
|
70
|
-
|
|
71
|
-
def
|
|
72
|
-
r'''
|
|
73
|
-
'''
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
if
|
|
119
|
-
octree.
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
return octree
|
|
157
|
-
|
|
158
|
-
def forward(self, octree: Octree, update_octree: bool):
|
|
159
|
-
r''''''
|
|
160
|
-
|
|
161
|
-
shape_code = self.ae_encoder(octree)
|
|
162
|
-
if update_octree:
|
|
163
|
-
octree = self.init_octree(shape_code)
|
|
164
|
-
out = self.ae_decoder(shape_code, octree, update_octree)
|
|
165
|
-
return out
|
|
1
|
+
# --------------------------------------------------------
|
|
2
|
+
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
+
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
+
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
+
# Written by Peng-Shuai Wang
|
|
6
|
+
# --------------------------------------------------------
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
import torch.nn
|
|
10
|
+
from typing import Optional
|
|
11
|
+
|
|
12
|
+
import ocnn
|
|
13
|
+
from ocnn.octree import Octree
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class AutoEncoder(torch.nn.Module):
  r''' Octree-based AutoEncoder for shape encoding and decoding.

  The encoder reduces the input octree from :attr:`depth` down to
  :attr:`full_depth` and projects every node feature at :attr:`full_depth`
  to a :attr:`code_channel`-dim shape code. The decoder upsamples the shape
  code back to :attr:`depth`. On every level it predicts a 2-class splitting
  logit per node, and on the finest level it also predicts the output signal.

  Args:
    channel_in (int): The channel of the input signal.
    channel_out (int): The channel of the output signal.
    depth (int): The depth of the octree.
    full_depth (int): The full depth of the octree. Must not exceed
        :attr:`depth`.
    feature (str): The feature type of the input signal. For details of this
        argument, please refer to :class:`ocnn.modules.InputFeature`.

  Raises:
    ValueError: If :attr:`full_depth` is larger than :attr:`depth`.
  '''

  def __init__(self, channel_in: int, channel_out: int, depth: int,
               full_depth: int = 2, feature: str = 'ND'):
    super().__init__()
    # Both the encoder and the decoder walk the levels in
    # [full_depth, depth]. An empty range would only surface later as a
    # KeyError (encoder) or UnboundLocalError (decoder), so reject it here.
    if full_depth > depth:
      raise ValueError(
          f'full_depth ({full_depth}) must not exceed depth ({depth})')

    self.channel_in = channel_in
    self.channel_out = channel_out
    self.depth = depth
    self.full_depth = full_depth
    self.feature = feature
    self.resblk_num = 2
    self.code_channel = 64  # dim-of-code = code_channel * 2**(3*full_depth)
    self.channels = [512, 512, 256, 256, 128, 128, 32, 32, 16, 16]

    # encoder
    self.conv1 = ocnn.modules.OctreeConvBnRelu(
        channel_in, self.channels[depth], nempty=False)
    self.encoder_blks = torch.nn.ModuleList([ocnn.modules.OctreeResBlocks(
        self.channels[d], self.channels[d], self.resblk_num, nempty=False)
        for d in range(depth, full_depth-1, -1)])
    self.downsample = torch.nn.ModuleList([ocnn.modules.OctreeConvBnRelu(
        self.channels[d], self.channels[d-1], kernel_size=[2], stride=2,
        nempty=False) for d in range(depth, full_depth, -1)])
    self.proj = torch.nn.Linear(
        self.channels[full_depth], self.code_channel, bias=True)

    # decoder
    # The decoder starts from the shape code, so its first level has
    # `code_channel` channels instead of the encoder's channels[full_depth].
    self.channels[full_depth] = self.code_channel  # update `channels`
    self.upsample = torch.nn.ModuleList([ocnn.modules.OctreeDeconvBnRelu(
        self.channels[d-1], self.channels[d], kernel_size=[2], stride=2,
        nempty=False) for d in range(full_depth+1, depth+1)])
    self.decoder_blks = torch.nn.ModuleList([ocnn.modules.OctreeResBlocks(
        self.channels[d], self.channels[d], self.resblk_num, nempty=False)
        for d in range(full_depth, depth+1)])

    # header
    # One 2-class splitting predictor per decoded level, plus the signal head.
    self.predict = torch.nn.ModuleList([self._make_predict_module(
        self.channels[d], 2) for d in range(full_depth, depth + 1)])
    self.header = self._make_predict_module(self.channels[depth], channel_out)

  def _make_predict_module(self, channel_in, channel_out=2, num_hidden=64):
    r''' Builds a two-layer 1x1 prediction head:
    ``Conv1x1BnRelu(channel_in, num_hidden)`` followed by a biased
    ``Conv1x1(num_hidden, channel_out)``.
    '''

    return torch.nn.Sequential(
        ocnn.modules.Conv1x1BnRelu(channel_in, num_hidden),
        ocnn.modules.Conv1x1(num_hidden, channel_out, use_bias=True))

  def encoder(self, octree: Octree):
    r''' The encoder network of the AutoEncoder.

    Args:
      octree (Octree): The input octree. Its input features are extracted
          with the feature type :attr:`feature` (``nempty=False``).

    Returns:
      torch.Tensor: The shape code, i.e. the features at :attr:`full_depth`
      projected to :attr:`code_channel` channels and squashed into [-1, 1].
    '''

    convs = dict()
    depth, full_depth = self.depth, self.full_depth
    data = octree.get_input_feature(self.feature, nempty=False)
    assert data.size(1) == self.channel_in, (
        f'expected {self.channel_in} input channels, got {data.size(1)}')
    convs[depth] = self.conv1(data, octree, depth)
    for i, d in enumerate(range(depth, full_depth-1, -1)):
      convs[d] = self.encoder_blks[i](convs[d], octree, d)
      if d > full_depth:
        # stride-2 conv: features at depth d -> features at depth d-1
        convs[d-1] = self.downsample[i](convs[d], octree, d)

    # NOTE: here tanh is used to constrain the shape code in [-1, 1]
    shape_code = self.proj(convs[full_depth]).tanh()
    return shape_code

  def decoder(self, shape_code: torch.Tensor, octree: Octree,
              update_octree: bool = False):
    r''' The decoder network of the AutoEncoder.

    Args:
      shape_code (torch.Tensor): The shape code, e.g. from :meth:`encoder`.
      octree (Octree): The octree that guides the decoding. When
          :attr:`update_octree` is True, it is split and grown via its
          ``octree_split`` / ``octree_grow`` methods, and its features at
          :attr:`depth` are overwritten with the predicted signal.
      update_octree (bool): Whether to update :attr:`octree` according to
          the predicted splitting labels.

    Returns:
      dict: ``'logits'`` maps each depth in [full_depth, depth] to the
      2-channel splitting logits of that level. ``'signal'`` is the output
      signal at :attr:`depth`. ``'octree_out'`` is :attr:`octree` itself.
    '''

    logits = dict()
    deconv = shape_code
    depth, full_depth = self.depth, self.full_depth
    for i, d in enumerate(range(full_depth, depth+1)):
      if d > full_depth:
        # stride-2 deconv: features at depth d-1 -> features at depth d
        deconv = self.upsample[i-1](deconv, octree, d-1)
      deconv = self.decoder_blks[i](deconv, octree, d)

      # predict the splitting label
      logit = self.predict[i](deconv)
      logits[d] = logit

      # update the octree according to predicted labels
      if update_octree:
        split = logit.argmax(1).int()
        octree.octree_split(split, d)
        if d < depth:
          octree.octree_grow(d + 1)

      # predict the signal
      if d == depth:
        signal = self.header(deconv)
        signal = torch.tanh(signal)
        # NOTE(review): presumably drops the entries of empty nodes, since the
        # decoder runs with nempty=False -- confirm against octree_depad.
        signal = ocnn.nn.octree_depad(signal, octree, depth)
        if update_octree:
          octree.features[depth] = signal

    return {'logits': logits, 'signal': signal, 'octree_out': octree}

  def decode_code(self, shape_code: torch.Tensor):
    r''' Decodes the shape code to an output octree.

    Args:
      shape_code (torch.Tensor): The shape code for decoding.

    Returns:
      dict: The output of :meth:`decoder` with ``update_octree=True``.
    '''

    octree_out = self.init_octree(shape_code)
    out = self.decoder(shape_code, octree_out, update_octree=True)
    return out

  def init_octree(self, shape_code: torch.Tensor):
    r''' Initialize a full octree for decoding.

    Args:
      shape_code (torch.Tensor): The shape code for decoding, used to get
          the `batch_size` and `device` to initialize the output octree.
    '''

    # Each shape in the batch owns 8**full_depth rows of the shape code.
    node_num = 2 ** (3 * self.full_depth)
    batch_size = shape_code.size(0) // node_num
    octree = ocnn.octree.init_octree(
        self.depth, self.full_depth, batch_size, shape_code.device)
    return octree

  def forward(self, octree: Octree, update_octree: bool = False):
    r''' Encodes :attr:`octree` and decodes the resulting shape code.

    Args:
      octree (Octree): The input octree.
      update_octree (bool): If True, decode into a freshly initialized full
          octree (see :meth:`init_octree`) that is updated by the predicted
          splitting labels. Otherwise, decode on the input octree itself.

    Returns:
      dict: The output of :meth:`decoder`.
    '''

    shape_code = self.encoder(octree)
    if update_octree:
      octree = self.init_octree(shape_code)
    out = self.decoder(shape_code, octree, update_octree)
    return out
|