ocnn 2.2.1__py3-none-any.whl → 2.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,165 +1,155 @@
1
- # --------------------------------------------------------
2
- # Octree-based Sparse Convolutional Neural Networks
3
- # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
- # Licensed under The MIT License [see LICENSE for details]
5
- # Written by Peng-Shuai Wang
6
- # --------------------------------------------------------
7
-
8
- import torch
9
- import torch.nn
10
- from typing import Optional
11
-
12
- import ocnn
13
- from ocnn.octree import Octree
14
-
15
-
16
class AutoEncoder(torch.nn.Module):
    r''' Octree-based AutoEncoder for shape encoding and decoding.

    The encoder processes the input features from `depth` down to `full_depth`
    and projects the features at `full_depth` to a shape code squashed into
    [-1, 1] by `tanh`. The decoder processes the shape code from `full_depth`
    back up to `depth`, predicting a two-class splitting label for the octree
    nodes at every depth and the output signal at `depth`.

    Args:
        channel_in (int): The channel of the input signal.
        channel_out (int): The channel of the output signal.
        depth (int): The depth of the octree.
        full_depth (int): The full depth of the octree; must satisfy
            `0 <= full_depth <= depth`.
        feature (str): The feature type of the input signal. For details of this
            argument, please refer to :class:`ocnn.modules.InputFeature`.
        code_channel (int): The channel of the shape code on each node at
            `full_depth`. A non-positive value falls back to the encoder
            channel at `full_depth`. Default: 128.

    Raises:
        ValueError: If `full_depth` is negative or larger than `depth`.
    '''

    def __init__(self, channel_in: int, channel_out: int, depth: int,
                 full_depth: int = 2, feature: str = 'ND', code_channel: int = 128):
        super().__init__()
        # With full_depth > depth the encoder never produces features at
        # `full_depth`, and the decoder loop never reaches `depth` (no signal).
        if not 0 <= full_depth <= depth:
            raise ValueError(
                f'Expected 0 <= full_depth <= depth, but got '
                f'full_depth={full_depth} and depth={depth}.')

        self.channel_in = channel_in
        self.channel_out = channel_out
        self.depth = depth
        self.full_depth = full_depth
        self.feature = feature
        self.resblk_num = 2
        # Feature channels, indexed by octree depth.
        self.channels = [512, 512, 256, 256, 128, 128, 32, 32, 16, 16]
        self.code_channel = (
            code_channel if code_channel > 0 else self.channels[full_depth])

        # encoder: an input conv at `depth`, residual blocks at every depth from
        # `depth` down to `full_depth`, and stride-2 convs between depths.
        self.conv1 = ocnn.modules.OctreeConvBnRelu(
            channel_in, self.channels[depth], nempty=False)
        self.encoder_blks = torch.nn.ModuleList([ocnn.modules.OctreeResBlocks(
            self.channels[d], self.channels[d], self.resblk_num, nempty=False)
            for d in range(depth, full_depth - 1, -1)])
        self.downsample = torch.nn.ModuleList([ocnn.modules.OctreeConvBnRelu(
            self.channels[d], self.channels[d - 1], kernel_size=[2], stride=2,
            nempty=False) for d in range(depth, full_depth, -1)])
        self.proj = torch.nn.Linear(
            self.channels[full_depth], self.code_channel, bias=True)

        # decoder: the features at `full_depth` are now the shape code, so the
        # first decoder stage must consume `code_channel` channels.
        self.channels[full_depth] = self.code_channel  # update `channels`
        self.upsample = torch.nn.ModuleList([ocnn.modules.OctreeDeconvBnRelu(
            self.channels[d - 1], self.channels[d], kernel_size=[2], stride=2,
            nempty=False) for d in range(full_depth + 1, depth + 1)])
        self.decoder_blks = torch.nn.ModuleList([ocnn.modules.OctreeResBlocks(
            self.channels[d], self.channels[d], self.resblk_num, nempty=False)
            for d in range(full_depth, depth + 1)])

        # header: one splitting-label predictor per depth plus the signal head.
        self.predict = torch.nn.ModuleList([self._make_predict_module(
            self.channels[d], 2) for d in range(full_depth, depth + 1)])
        self.header = self._make_predict_module(self.channels[depth], channel_out)

    def _make_predict_module(self, channel_in, channel_out=2, num_hidden=64):
        r''' Builds a `Conv1x1BnRelu` to `num_hidden` channels followed by a
        biased `Conv1x1` to `channel_out` channels.
        '''
        return torch.nn.Sequential(
            ocnn.modules.Conv1x1BnRelu(channel_in, num_hidden),
            ocnn.modules.Conv1x1(num_hidden, channel_out, use_bias=True))

    def get_input_feature(self, octree: Octree):
        r''' Get the input feature from the input `octree`.

        Raises:
            ValueError: If the feature channel differs from `channel_in`.
        '''
        octree_feature = ocnn.modules.InputFeature(self.feature, nempty=False)
        out = octree_feature(octree)
        if out.size(1) != self.channel_in:
            raise ValueError(
                f'The input feature has {out.size(1)} channels, but '
                f'`channel_in` is {self.channel_in}.')
        return out

    def encoder(self, octree: Octree):
        r''' The encoder network of the AutoEncoder.

        Args:
            octree (Octree): The input octree.

        Returns:
            torch.Tensor: The shape code, i.e. the output of `self.proj` on the
            features at `full_depth`, squashed into [-1, 1] by `tanh`.
        '''
        convs = dict()
        depth, full_depth = self.depth, self.full_depth
        data = self.get_input_feature(octree)
        convs[depth] = self.conv1(data, octree, depth)
        for i, d in enumerate(range(depth, full_depth - 1, -1)):
            convs[d] = self.encoder_blks[i](convs[d], octree, d)
            if d > full_depth:
                convs[d - 1] = self.downsample[i](convs[d], octree, d)

        # NOTE: here tanh is used to constrain the shape code in [-1, 1]
        shape_code = self.proj(convs[full_depth]).tanh()
        return shape_code

    def decoder(self, shape_code: torch.Tensor, octree: Octree,
                update_octree: bool = False):
        r''' The decoder network of the AutoEncoder.

        Args:
            shape_code (torch.Tensor): The shape code produced by `encoder`.
            octree (Octree): The octree guiding the decoding. If
                `update_octree` is True it is modified in place: nodes are split
                and grown from the predicted labels, and the predicted signal
                is stored in `octree.features[depth]`.
            update_octree (bool): Whether to build the output octree from the
                predicted splitting labels.

        Returns:
            dict: `'logits'` maps each depth to its splitting-label logits,
            `'signal'` is the predicted signal at `depth`, and `'octree_out'`
            is `octree`.
        '''
        logits = dict()
        deconv = shape_code
        depth, full_depth = self.depth, self.full_depth
        for i, d in enumerate(range(full_depth, depth + 1)):
            if d > full_depth:
                deconv = self.upsample[i - 1](deconv, octree, d - 1)
            deconv = self.decoder_blks[i](deconv, octree, d)

            # predict the splitting label
            logit = self.predict[i](deconv)
            logits[d] = logit

            # update the octree according to predicted labels
            if update_octree:
                split = logit.argmax(1).int()
                octree.octree_split(split, d)
                if d < depth:
                    octree.octree_grow(d + 1)

            # predict the signal
            if d == depth:
                signal = self.header(deconv)
                signal = torch.tanh(signal)
                signal = ocnn.nn.octree_depad(signal, octree, depth)
                if update_octree:
                    octree.features[depth] = signal

        return {'logits': logits, 'signal': signal, 'octree_out': octree}

    def decode_code(self, shape_code: torch.Tensor):
        r''' Decodes the shape code to an output octree.

        Args:
            shape_code (torch.Tensor): The shape code for decoding.

        Returns:
            dict: The output of `decoder` with `update_octree=True`.
        '''
        octree_out = self.init_octree(shape_code)
        out = self.decoder(shape_code, octree_out, update_octree=True)
        return out

    def init_octree(self, shape_code: torch.Tensor):
        r''' Initialize a full octree for decoding.

        Args:
            shape_code (torch.Tensor): The shape code for decoding, used to get
                the `batch_size` and `device` to initialize the output octree.
        '''
        device = shape_code.device
        # Each shape owns 8**full_depth nodes at `full_depth`, and the code has
        # one row per such node, so the batch size is recovered by division.
        node_num = 2 ** (3 * self.full_depth)
        batch_size = shape_code.size(0) // node_num
        octree = Octree(self.depth, self.full_depth, batch_size, device)
        for d in range(self.full_depth + 1):
            octree.octree_grow_full(depth=d)
        return octree

    def forward(self, octree: Octree, update_octree: bool):
        r''' Encodes `octree` and decodes the resulting shape code.

        Args:
            octree (Octree): The input octree.
            update_octree (bool): If True, decode into a freshly initialized
                octree built from the predicted splitting labels; otherwise
                decode along the structure of the input `octree`.

        Returns:
            dict: The output of `decoder`.
        '''
        shape_code = self.encoder(octree)
        if update_octree:
            octree = self.init_octree(shape_code)
        out = self.decoder(shape_code, octree, update_octree)
        return out
1
+ # --------------------------------------------------------
2
+ # Octree-based Sparse Convolutional Neural Networks
3
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Peng-Shuai Wang
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+ import torch.nn
10
+ from typing import Optional
11
+
12
+ import ocnn
13
+ from ocnn.octree import Octree
14
+
15
+
16
class AutoEncoder(torch.nn.Module):
    r''' Octree-based AutoEncoder for shape encoding and decoding.

    The encoder processes the input features from `depth` down to `full_depth`
    and projects the features at `full_depth` to a shape code squashed into
    [-1, 1] by `tanh`. The decoder processes the shape code from `full_depth`
    back up to `depth`, predicting a two-class splitting label for the octree
    nodes at every depth and the output signal at `depth`.

    Args:
        channel_in (int): The channel of the input signal.
        channel_out (int): The channel of the output signal.
        depth (int): The depth of the octree.
        full_depth (int): The full depth of the octree; must satisfy
            `0 <= full_depth <= depth`.
        feature (str): The feature type of the input signal. For details of this
            argument, please refer to :class:`ocnn.modules.InputFeature`.
        code_channel (int): The channel of the shape code on each node at
            `full_depth`, so one shape is encoded by
            `code_channel * 2**(3*full_depth)` values. A non-positive value
            falls back to the encoder channel at `full_depth`. Default: 64.

    Raises:
        ValueError: If `full_depth` is negative or larger than `depth`.
    '''

    def __init__(self, channel_in: int, channel_out: int, depth: int,
                 full_depth: int = 2, feature: str = 'ND',
                 code_channel: int = 64):
        super().__init__()
        # With full_depth > depth the encoder never produces features at
        # `full_depth`, and the decoder loop never reaches `depth` (no signal).
        if not 0 <= full_depth <= depth:
            raise ValueError(
                f'Expected 0 <= full_depth <= depth, but got '
                f'full_depth={full_depth} and depth={depth}.')

        self.channel_in = channel_in
        self.channel_out = channel_out
        self.depth = depth
        self.full_depth = full_depth
        self.feature = feature
        self.resblk_num = 2
        # Feature channels, indexed by octree depth.
        self.channels = [512, 512, 256, 256, 128, 128, 32, 32, 16, 16]
        # dim-of-code = code_channel * 2**(3*full_depth)
        self.code_channel = (
            code_channel if code_channel > 0 else self.channels[full_depth])

        # encoder: an input conv at `depth`, residual blocks at every depth from
        # `depth` down to `full_depth`, and stride-2 convs between depths.
        self.conv1 = ocnn.modules.OctreeConvBnRelu(
            channel_in, self.channels[depth], nempty=False)
        self.encoder_blks = torch.nn.ModuleList([ocnn.modules.OctreeResBlocks(
            self.channels[d], self.channels[d], self.resblk_num, nempty=False)
            for d in range(depth, full_depth - 1, -1)])
        self.downsample = torch.nn.ModuleList([ocnn.modules.OctreeConvBnRelu(
            self.channels[d], self.channels[d - 1], kernel_size=[2], stride=2,
            nempty=False) for d in range(depth, full_depth, -1)])
        self.proj = torch.nn.Linear(
            self.channels[full_depth], self.code_channel, bias=True)

        # decoder: the features at `full_depth` are now the shape code, so the
        # first decoder stage must consume `code_channel` channels.
        self.channels[full_depth] = self.code_channel  # update `channels`
        self.upsample = torch.nn.ModuleList([ocnn.modules.OctreeDeconvBnRelu(
            self.channels[d - 1], self.channels[d], kernel_size=[2], stride=2,
            nempty=False) for d in range(full_depth + 1, depth + 1)])
        self.decoder_blks = torch.nn.ModuleList([ocnn.modules.OctreeResBlocks(
            self.channels[d], self.channels[d], self.resblk_num, nempty=False)
            for d in range(full_depth, depth + 1)])

        # header: one splitting-label predictor per depth plus the signal head.
        self.predict = torch.nn.ModuleList([self._make_predict_module(
            self.channels[d], 2) for d in range(full_depth, depth + 1)])
        self.header = self._make_predict_module(self.channels[depth], channel_out)

    def _make_predict_module(self, channel_in, channel_out=2, num_hidden=64):
        r''' Builds a `Conv1x1BnRelu` to `num_hidden` channels followed by a
        biased `Conv1x1` to `channel_out` channels.
        '''
        return torch.nn.Sequential(
            ocnn.modules.Conv1x1BnRelu(channel_in, num_hidden),
            ocnn.modules.Conv1x1(num_hidden, channel_out, use_bias=True))

    def encoder(self, octree: Octree):
        r''' The encoder network of the AutoEncoder.

        Args:
            octree (Octree): The input octree.

        Returns:
            torch.Tensor: The shape code, i.e. the output of `self.proj` on the
            features at `full_depth`, squashed into [-1, 1] by `tanh`.

        Raises:
            ValueError: If the input feature channel differs from `channel_in`.
        '''
        convs = dict()
        depth, full_depth = self.depth, self.full_depth
        data = octree.get_input_feature(self.feature, nempty=False)
        if data.size(1) != self.channel_in:
            raise ValueError(
                f'The input feature has {data.size(1)} channels, but '
                f'`channel_in` is {self.channel_in}.')
        convs[depth] = self.conv1(data, octree, depth)
        for i, d in enumerate(range(depth, full_depth - 1, -1)):
            convs[d] = self.encoder_blks[i](convs[d], octree, d)
            if d > full_depth:
                convs[d - 1] = self.downsample[i](convs[d], octree, d)

        # NOTE: here tanh is used to constrain the shape code in [-1, 1]
        shape_code = self.proj(convs[full_depth]).tanh()
        return shape_code

    def decoder(self, shape_code: torch.Tensor, octree: Octree,
                update_octree: bool = False):
        r''' The decoder network of the AutoEncoder.

        Args:
            shape_code (torch.Tensor): The shape code produced by `encoder`.
            octree (Octree): The octree guiding the decoding. If
                `update_octree` is True it is modified in place: nodes are split
                and grown from the predicted labels, and the predicted signal
                is stored in `octree.features[depth]`.
            update_octree (bool): Whether to build the output octree from the
                predicted splitting labels.

        Returns:
            dict: `'logits'` maps each depth to its splitting-label logits,
            `'signal'` is the predicted signal at `depth`, and `'octree_out'`
            is `octree`.
        '''
        logits = dict()
        deconv = shape_code
        depth, full_depth = self.depth, self.full_depth
        for i, d in enumerate(range(full_depth, depth + 1)):
            if d > full_depth:
                deconv = self.upsample[i - 1](deconv, octree, d - 1)
            deconv = self.decoder_blks[i](deconv, octree, d)

            # predict the splitting label
            logit = self.predict[i](deconv)
            logits[d] = logit

            # update the octree according to predicted labels
            if update_octree:
                split = logit.argmax(1).int()
                octree.octree_split(split, d)
                if d < depth:
                    octree.octree_grow(d + 1)

            # predict the signal
            if d == depth:
                signal = self.header(deconv)
                signal = torch.tanh(signal)
                signal = ocnn.nn.octree_depad(signal, octree, depth)
                if update_octree:
                    octree.features[depth] = signal

        return {'logits': logits, 'signal': signal, 'octree_out': octree}

    def decode_code(self, shape_code: torch.Tensor):
        r''' Decodes the shape code to an output octree.

        Args:
            shape_code (torch.Tensor): The shape code for decoding.

        Returns:
            dict: The output of `decoder` with `update_octree=True`.
        '''
        octree_out = self.init_octree(shape_code)
        out = self.decoder(shape_code, octree_out, update_octree=True)
        return out

    def init_octree(self, shape_code: torch.Tensor):
        r''' Initialize a full octree for decoding.

        Args:
            shape_code (torch.Tensor): The shape code for decoding, used to get
                the `batch_size` and `device` to initialize the output octree.
        '''
        # Each shape owns 8**full_depth nodes at `full_depth`, and the code has
        # one row per such node, so the batch size is recovered by division.
        node_num = 2 ** (3 * self.full_depth)
        batch_size = shape_code.size(0) // node_num
        octree = ocnn.octree.init_octree(
            self.depth, self.full_depth, batch_size, shape_code.device)
        return octree

    def forward(self, octree: Octree, update_octree: bool):
        r''' Encodes `octree` and decodes the resulting shape code.

        Args:
            octree (Octree): The input octree.
            update_octree (bool): If True, decode into a freshly initialized
                octree built from the predicted splitting labels; otherwise
                decode along the structure of the input `octree`.

        Returns:
            dict: The output of `decoder`.
        '''
        shape_code = self.encoder(octree)
        if update_octree:
            octree = self.init_octree(shape_code)
        out = self.decoder(shape_code, octree, update_octree)
        return out