ocnn 2.2.7__py3-none-any.whl → 2.2.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ocnn/models/hrnet.py CHANGED
@@ -1,192 +1,192 @@
1
- # --------------------------------------------------------
2
- # Octree-based Sparse Convolutional Neural Networks
3
- # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
- # Licensed under The MIT License [see LICENSE for details]
5
- # Written by Peng-Shuai Wang
6
- # --------------------------------------------------------
7
-
8
- import torch
9
- from typing import List
10
-
11
- import ocnn
12
- from ocnn.octree import Octree
13
-
14
-
15
- class Branches(torch.nn.Module):
16
-
17
- def __init__(self, channels: List[int], resblk_num: int, nempty: bool = False):
18
- super().__init__()
19
- self.channels = channels
20
- self.resblk_num = resblk_num
21
- bottlenecks = [4 if c < 256 else 8 for c in channels] # to save parameters
22
- self.resblocks = torch.nn.ModuleList([
23
- ocnn.modules.OctreeResBlocks(ch, ch, resblk_num, bnk, nempty=nempty)
24
- for ch, bnk in zip(channels, bottlenecks)])
25
-
26
- def forward(self, datas: List[torch.Tensor], octree: Octree, depth: int):
27
- num = len(self.channels)
28
- torch._assert(len(datas) == num, 'Error')
29
-
30
- out = [None] * num
31
- for i in range(num):
32
- depth_i = depth - i
33
- out[i] = self.resblocks[i](datas[i], octree, depth_i)
34
- return out
35
-
36
-
37
- class TransFunc(torch.nn.Module):
38
-
39
- def __init__(self, in_channels: int, out_channels: int, nempty: bool = False):
40
- super().__init__()
41
- self.in_channels = in_channels
42
- self.out_channels = out_channels
43
- self.nempty = nempty
44
- self.maxpool = ocnn.nn.OctreeMaxPool(nempty=nempty)
45
- self.upsample = ocnn.nn.OctreeUpsample(method='nearest', nempty=nempty)
46
- if in_channels != out_channels:
47
- self.conv1x1 = ocnn.modules.Conv1x1BnRelu(in_channels, out_channels)
48
-
49
- def forward(self, data: torch.Tensor, octree: Octree,
50
- in_depth: int, out_depth: int):
51
- out = data
52
- if in_depth > out_depth:
53
- for d in range(in_depth, out_depth, -1):
54
- out = self.maxpool(out, octree, d)
55
- if self.in_channels != self.out_channels:
56
- out = self.conv1x1(out)
57
-
58
- if in_depth < out_depth:
59
- if self.in_channels != self.out_channels:
60
- out = self.conv1x1(out)
61
- for d in range(in_depth, out_depth, 1):
62
- out = self.upsample(out, octree, d)
63
- return out
64
-
65
-
66
- class Transitions(torch.nn.Module):
67
-
68
- def __init__(self, channels: List[int], nempty: bool = False):
69
- super().__init__()
70
- self.channels = channels
71
- self.nempty = nempty
72
-
73
- num = len(self.channels)
74
- self.trans_func = torch.nn.ModuleList()
75
- for i in range(num - 1):
76
- for j in range(num):
77
- self.trans_func.append(TransFunc(channels[i], channels[j], nempty))
78
-
79
- def forward(self, data: List[torch.Tensor], octree: Octree, depth: int):
80
- num = len(self.channels)
81
- features = [[None] * (num - 1) for _ in range(num)]
82
- for i in range(num - 1):
83
- for j in range(num):
84
- k = i * num + j
85
- in_depth = depth - i
86
- out_depth = depth - j
87
- features[j][i] = self.trans_func[k](
88
- data[i], octree, in_depth, out_depth)
89
-
90
- out = [None] * num
91
- for j in range(num):
92
- # In the original tensorflow implmentation, a relu is added after the sum.
93
- out[j] = torch.stack(features[j], dim=0).sum(dim=0)
94
- return out
95
-
96
-
97
- class FrontLayer(torch.nn.Module):
98
-
99
- def __init__(self, channels: List[int], nempty: bool = False):
100
- super().__init__()
101
- self.channels = channels
102
- self.num = len(channels) - 1
103
- self.nempty = nempty
104
-
105
- self.conv = torch.nn.ModuleList([
106
- ocnn.modules.OctreeConvBnRelu(channels[i], channels[i+1], nempty=nempty)
107
- for i in range(self.num)])
108
- self.maxpool = torch.nn.ModuleList([
109
- ocnn.nn.OctreeMaxPool(nempty) for i in range(self.num - 1)])
110
-
111
- def forward(self, data: torch.Tensor, octree: Octree, depth: int):
112
- out = data
113
- for i in range(self.num - 1):
114
- depth_i = depth - i
115
- out = self.conv[i](out, octree, depth_i)
116
- out = self.maxpool[i](out, octree, depth_i)
117
- out = self.conv[-1](out, octree, depth - self.num + 1)
118
- return out
119
-
120
-
121
- class ClsHeader(torch.nn.Module):
122
-
123
- def __init__(self, channels: List[int], out_channels: int, nempty: bool = False):
124
- super().__init__()
125
- self.channels = channels
126
- self.out_channels = out_channels
127
- self.nempty = nempty
128
-
129
- in_channels = int(torch.Tensor(channels).sum())
130
- self.conv1x1 = ocnn.modules.Conv1x1BnRelu(in_channels, 1024)
131
- self.global_pool = ocnn.nn.OctreeGlobalPool(nempty)
132
- self.header = torch.nn.Sequential(
133
- torch.nn.Flatten(start_dim=1),
134
- torch.nn.Linear(1024, out_channels, bias=True))
135
- # self.header = torch.nn.Sequential(
136
- # ocnn.modules.FcBnRelu(512, 256),
137
- # torch.nn.Dropout(p=0.5),
138
- # torch.nn.Linear(256, out_channels))
139
-
140
- def forward(self, data: List[torch.Tensor], octree: Octree, depth: int):
141
- full_depth = 2
142
- num = len(data)
143
- outs = [x for x in data] # avoid modifying the input data
144
- for i in range(num):
145
- depth_i = depth - i
146
- for d in range(depth_i, full_depth, -1):
147
- outs[i] = ocnn.nn.octree_max_pool(outs[i], octree, d, self.nempty)
148
-
149
- out = torch.cat(outs, dim=1)
150
- out = self.conv1x1(out)
151
- out = self.global_pool(out, octree, full_depth)
152
- logit = self.header(out)
153
- return logit
154
-
155
-
156
- class HRNet(torch.nn.Module):
157
- r''' Octree-based HRNet for classification and segmentation. '''
158
-
159
- def __init__(self, in_channels: int, out_channels: int, stages: int = 3,
160
- interp: str = 'linear', nempty: bool = False):
161
- super().__init__()
162
- self.in_channels = in_channels
163
- self.out_channels = out_channels
164
- self.interp = interp
165
- self.nempty = nempty
166
- self.stages = stages
167
-
168
- self.resblk_num = 3
169
- self.channels = [128, 256, 512, 512]
170
-
171
- self.front = FrontLayer([in_channels, 32, self.channels[0]], nempty)
172
- self.branches = torch.nn.ModuleList([
173
- Branches(self.channels[:i+1], self.resblk_num, nempty)
174
- for i in range(stages)])
175
- self.transitions = torch.nn.ModuleList([
176
- Transitions(self.channels[:i+2], nempty)
177
- for i in range(stages-1)])
178
-
179
- self.cls_header = ClsHeader(self.channels[:stages], out_channels, nempty)
180
-
181
- def forward(self, data: torch.Tensor, octree: Octree, depth: int):
182
- r''''''
183
- convs = [self.front(data, octree, depth)]
184
- depth = depth - 1 # the data is downsampled in `front`
185
- for i in range(self.stages):
186
- convs = self.branches[i](convs, octree, depth)
187
- if i < self.stages - 1:
188
- convs = self.transitions[i](convs, octree, depth)
189
-
190
- logits = self.cls_header(convs, octree, depth)
191
-
192
- return logits
1
+ # --------------------------------------------------------
2
+ # Octree-based Sparse Convolutional Neural Networks
3
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Peng-Shuai Wang
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+ from typing import List
10
+
11
+ import ocnn
12
+ from ocnn.octree import Octree
13
+
14
+
15
+ class Branches(torch.nn.Module):
16
+
17
+ def __init__(self, channels: List[int], resblk_num: int, nempty: bool = False):
18
+ super().__init__()
19
+ self.channels = channels
20
+ self.resblk_num = resblk_num
21
+ bottlenecks = [4 if c < 256 else 8 for c in channels] # to save parameters
22
+ self.resblocks = torch.nn.ModuleList([
23
+ ocnn.modules.OctreeResBlocks(ch, ch, resblk_num, bnk, nempty=nempty)
24
+ for ch, bnk in zip(channels, bottlenecks)])
25
+
26
+ def forward(self, datas: List[torch.Tensor], octree: Octree, depth: int):
27
+ num = len(self.channels)
28
+ torch._assert(len(datas) == num, 'Error')
29
+
30
+ out = [None] * num
31
+ for i in range(num):
32
+ depth_i = depth - i
33
+ out[i] = self.resblocks[i](datas[i], octree, depth_i)
34
+ return out
35
+
36
+
37
+ class TransFunc(torch.nn.Module):
38
+
39
+ def __init__(self, in_channels: int, out_channels: int, nempty: bool = False):
40
+ super().__init__()
41
+ self.in_channels = in_channels
42
+ self.out_channels = out_channels
43
+ self.nempty = nempty
44
+ self.maxpool = ocnn.nn.OctreeMaxPool(nempty=nempty)
45
+ self.upsample = ocnn.nn.OctreeUpsample(method='nearest', nempty=nempty)
46
+ if in_channels != out_channels:
47
+ self.conv1x1 = ocnn.modules.Conv1x1BnRelu(in_channels, out_channels)
48
+
49
+ def forward(self, data: torch.Tensor, octree: Octree,
50
+ in_depth: int, out_depth: int):
51
+ out = data
52
+ if in_depth > out_depth:
53
+ for d in range(in_depth, out_depth, -1):
54
+ out = self.maxpool(out, octree, d)
55
+ if self.in_channels != self.out_channels:
56
+ out = self.conv1x1(out)
57
+
58
+ if in_depth < out_depth:
59
+ if self.in_channels != self.out_channels:
60
+ out = self.conv1x1(out)
61
+ for d in range(in_depth, out_depth, 1):
62
+ out = self.upsample(out, octree, d)
63
+ return out
64
+
65
+
66
+ class Transitions(torch.nn.Module):
67
+
68
+ def __init__(self, channels: List[int], nempty: bool = False):
69
+ super().__init__()
70
+ self.channels = channels
71
+ self.nempty = nempty
72
+
73
+ num = len(self.channels)
74
+ self.trans_func = torch.nn.ModuleList()
75
+ for i in range(num - 1):
76
+ for j in range(num):
77
+ self.trans_func.append(TransFunc(channels[i], channels[j], nempty))
78
+
79
+ def forward(self, data: List[torch.Tensor], octree: Octree, depth: int):
80
+ num = len(self.channels)
81
+ features = [[None] * (num - 1) for _ in range(num)]
82
+ for i in range(num - 1):
83
+ for j in range(num):
84
+ k = i * num + j
85
+ in_depth = depth - i
86
+ out_depth = depth - j
87
+ features[j][i] = self.trans_func[k](
88
+ data[i], octree, in_depth, out_depth)
89
+
90
+ out = [None] * num
91
+ for j in range(num):
92
+ # In the original tensorflow implmentation, a relu is added after the sum.
93
+ out[j] = torch.stack(features[j], dim=0).sum(dim=0)
94
+ return out
95
+
96
+
97
+ class FrontLayer(torch.nn.Module):
98
+
99
+ def __init__(self, channels: List[int], nempty: bool = False):
100
+ super().__init__()
101
+ self.channels = channels
102
+ self.num = len(channels) - 1
103
+ self.nempty = nempty
104
+
105
+ self.conv = torch.nn.ModuleList([
106
+ ocnn.modules.OctreeConvBnRelu(channels[i], channels[i+1], nempty=nempty)
107
+ for i in range(self.num)])
108
+ self.maxpool = torch.nn.ModuleList([
109
+ ocnn.nn.OctreeMaxPool(nempty) for i in range(self.num - 1)])
110
+
111
+ def forward(self, data: torch.Tensor, octree: Octree, depth: int):
112
+ out = data
113
+ for i in range(self.num - 1):
114
+ depth_i = depth - i
115
+ out = self.conv[i](out, octree, depth_i)
116
+ out = self.maxpool[i](out, octree, depth_i)
117
+ out = self.conv[-1](out, octree, depth - self.num + 1)
118
+ return out
119
+
120
+
121
+ class ClsHeader(torch.nn.Module):
122
+
123
+ def __init__(self, channels: List[int], out_channels: int, nempty: bool = False):
124
+ super().__init__()
125
+ self.channels = channels
126
+ self.out_channels = out_channels
127
+ self.nempty = nempty
128
+
129
+ in_channels = int(torch.Tensor(channels).sum())
130
+ self.conv1x1 = ocnn.modules.Conv1x1BnRelu(in_channels, 1024)
131
+ self.global_pool = ocnn.nn.OctreeGlobalPool(nempty)
132
+ self.header = torch.nn.Sequential(
133
+ torch.nn.Flatten(start_dim=1),
134
+ torch.nn.Linear(1024, out_channels, bias=True))
135
+ # self.header = torch.nn.Sequential(
136
+ # ocnn.modules.FcBnRelu(512, 256),
137
+ # torch.nn.Dropout(p=0.5),
138
+ # torch.nn.Linear(256, out_channels))
139
+
140
+ def forward(self, data: List[torch.Tensor], octree: Octree, depth: int):
141
+ full_depth = 2
142
+ num = len(data)
143
+ outs = [x for x in data] # avoid modifying the input data
144
+ for i in range(num):
145
+ depth_i = depth - i
146
+ for d in range(depth_i, full_depth, -1):
147
+ outs[i] = ocnn.nn.octree_max_pool(outs[i], octree, d, self.nempty)
148
+
149
+ out = torch.cat(outs, dim=1)
150
+ out = self.conv1x1(out)
151
+ out = self.global_pool(out, octree, full_depth)
152
+ logit = self.header(out)
153
+ return logit
154
+
155
+
156
+ class HRNet(torch.nn.Module):
157
+ r''' Octree-based HRNet for classification and segmentation. '''
158
+
159
+ def __init__(self, in_channels: int, out_channels: int, stages: int = 3,
160
+ interp: str = 'linear', nempty: bool = False):
161
+ super().__init__()
162
+ self.in_channels = in_channels
163
+ self.out_channels = out_channels
164
+ self.interp = interp
165
+ self.nempty = nempty
166
+ self.stages = stages
167
+
168
+ self.resblk_num = 3
169
+ self.channels = [128, 256, 512, 512]
170
+
171
+ self.front = FrontLayer([in_channels, 32, self.channels[0]], nempty)
172
+ self.branches = torch.nn.ModuleList([
173
+ Branches(self.channels[:i+1], self.resblk_num, nempty)
174
+ for i in range(stages)])
175
+ self.transitions = torch.nn.ModuleList([
176
+ Transitions(self.channels[:i+2], nempty)
177
+ for i in range(stages-1)])
178
+
179
+ self.cls_header = ClsHeader(self.channels[:stages], out_channels, nempty)
180
+
181
+ def forward(self, data: torch.Tensor, octree: Octree, depth: int):
182
+ r''''''
183
+ convs = [self.front(data, octree, depth)]
184
+ depth = depth - 1 # the data is downsampled in `front`
185
+ for i in range(self.stages):
186
+ convs = self.branches[i](convs, octree, depth)
187
+ if i < self.stages - 1:
188
+ convs = self.transitions[i](convs, octree, depth)
189
+
190
+ logits = self.cls_header(convs, octree, depth)
191
+
192
+ return logits
@@ -1,128 +1,128 @@
1
- # --------------------------------------------------------
2
- # Octree-based Sparse Convolutional Neural Networks
3
- # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
- # Licensed under The MIT License [see LICENSE for details]
5
- # Written by Peng-Shuai Wang
6
- # --------------------------------------------------------
7
-
8
- import torch
9
- import torch.nn
10
- from typing import Optional
11
- from torchvision.models import resnet18
12
-
13
- import ocnn
14
- from ocnn.octree import Octree
15
-
16
-
17
- class Image2Shape(torch.nn.Module):
18
- r''' Octree-based AutoEncoder for shape encoding and decoding.
19
-
20
- Args:
21
- channel_out (int): The channel of the output signal.
22
- depth (int): The depth of the octree.
23
- full_depth (int): The full depth of the octree.
24
- '''
25
-
26
- def __init__(self, channel_out: int, depth: int, full_depth: int = 2,
27
- code_channel: int = 32):
28
- super().__init__()
29
- self.depth = depth
30
- self.full_depth = full_depth
31
- self.channel_out = channel_out
32
- self.resblk_num = 2
33
- self.channels = [512, 512, 256, 256, 128, 128, 64, 64, 32, 32]
34
- self.code_channel = code_channel
35
-
36
- # encoder
37
- self.resnet18 = resnet18()
38
- channel = self.code_channel * 2 ** (3 * full_depth)
39
- self.resnet18.fc = torch.nn.Linear(512, channel, bias=True)
40
-
41
- # decoder
42
- self.channels[full_depth] = self.code_channel # update `channels`
43
- self.upsample = torch.nn.ModuleList([ocnn.modules.OctreeDeconvBnRelu(
44
- self.channels[d-1], self.channels[d], kernel_size=[2], stride=2,
45
- nempty=False) for d in range(full_depth+1, depth+1)])
46
- self.decoder_blks = torch.nn.ModuleList([ocnn.modules.OctreeResBlocks(
47
- self.channels[d], self.channels[d], self.resblk_num, nempty=False)
48
- for d in range(full_depth, depth+1)])
49
-
50
- # header
51
- self.predict = torch.nn.ModuleList([self._make_predict_module(
52
- self.channels[d], 2) for d in range(full_depth, depth + 1)])
53
- self.header = self._make_predict_module(self.channels[depth], channel_out)
54
-
55
- def _make_predict_module(self, channel_in, channel_out=2, num_hidden=64):
56
- return torch.nn.Sequential(
57
- ocnn.modules.Conv1x1BnRelu(channel_in, num_hidden),
58
- ocnn.modules.Conv1x1(num_hidden, channel_out, use_bias=True))
59
-
60
- def decoder(self, shape_code: torch.Tensor, octree: Octree,
61
- update_octree: bool = False):
62
- r''' The decoder network of the AutoEncoder.
63
- '''
64
-
65
- logits = dict()
66
- deconv = shape_code
67
- depth, full_depth = self.depth, self.full_depth
68
- for i, d in enumerate(range(full_depth, depth+1)):
69
- if d > full_depth:
70
- deconv = self.upsample[i-1](deconv, octree, d-1)
71
- deconv = self.decoder_blks[i](deconv, octree, d)
72
-
73
- # predict the splitting label
74
- logit = self.predict[i](deconv)
75
- logits[d] = logit
76
-
77
- # update the octree according to predicted labels
78
- if update_octree:
79
- split = logit.argmax(1).int()
80
- octree.octree_split(split, d)
81
- if d < depth:
82
- octree.octree_grow(d + 1)
83
-
84
- # predict the signal
85
- if d == depth:
86
- signal = self.header(deconv)
87
- signal = torch.tanh(signal)
88
- signal = ocnn.nn.octree_depad(signal, octree, depth)
89
- if update_octree:
90
- octree.features[depth] = signal
91
-
92
- return {'logits': logits, 'signal': signal, 'octree_out': octree}
93
-
94
- def decode_code(self, shape_code: torch.Tensor):
95
- r''' Decodes the shape code to an output octree.
96
-
97
- Args:
98
- shape_code (torch.Tensor): The shape code for decoding.
99
- '''
100
-
101
- octree_out = self.init_octree(shape_code)
102
- out = self.decoder(shape_code, octree_out, update_octree=True)
103
- return out
104
-
105
- def init_octree(self, shape_code: torch.Tensor):
106
- r''' Initialize a full octree for decoding.
107
-
108
- Args:
109
- shape_code (torch.Tensor): The shape code for decoding, used to getting
110
- the `batch_size` and `device` to initialize the output octree.
111
- '''
112
-
113
- node_num = 2 ** (3 * self.full_depth)
114
- batch_size = shape_code.size(0) // node_num
115
- octree = ocnn.octree.init_octree(
116
- self.depth, self.full_depth, batch_size, shape_code.device)
117
- return octree
118
-
119
- def forward(self, image: torch.Tensor, octree: Optional[Octree] = None,
120
- update_octree: bool = False):
121
- r''''''
122
-
123
- shape_code = self.resnet18(image)
124
- shape_code = shape_code.view(-1, self.code_channel)
125
- if update_octree:
126
- octree = self.init_octree(shape_code)
127
- out = self.decoder(shape_code, octree, update_octree)
128
- return out
1
+ # --------------------------------------------------------
2
+ # Octree-based Sparse Convolutional Neural Networks
3
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Peng-Shuai Wang
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+ import torch.nn
10
+ from typing import Optional
11
+ from torchvision.models import resnet18
12
+
13
+ import ocnn
14
+ from ocnn.octree import Octree
15
+
16
+
17
+ class Image2Shape(torch.nn.Module):
18
+ r''' Octree-based AutoEncoder for shape encoding and decoding.
19
+
20
+ Args:
21
+ channel_out (int): The channel of the output signal.
22
+ depth (int): The depth of the octree.
23
+ full_depth (int): The full depth of the octree.
24
+ '''
25
+
26
+ def __init__(self, channel_out: int, depth: int, full_depth: int = 2,
27
+ code_channel: int = 32):
28
+ super().__init__()
29
+ self.depth = depth
30
+ self.full_depth = full_depth
31
+ self.channel_out = channel_out
32
+ self.resblk_num = 2
33
+ self.channels = [512, 512, 256, 256, 128, 128, 64, 64, 32, 32]
34
+ self.code_channel = code_channel
35
+
36
+ # encoder
37
+ self.resnet18 = resnet18()
38
+ channel = self.code_channel * 2 ** (3 * full_depth)
39
+ self.resnet18.fc = torch.nn.Linear(512, channel, bias=True)
40
+
41
+ # decoder
42
+ self.channels[full_depth] = self.code_channel # update `channels`
43
+ self.upsample = torch.nn.ModuleList([ocnn.modules.OctreeDeconvBnRelu(
44
+ self.channels[d-1], self.channels[d], kernel_size=[2], stride=2,
45
+ nempty=False) for d in range(full_depth+1, depth+1)])
46
+ self.decoder_blks = torch.nn.ModuleList([ocnn.modules.OctreeResBlocks(
47
+ self.channels[d], self.channels[d], self.resblk_num, nempty=False)
48
+ for d in range(full_depth, depth+1)])
49
+
50
+ # header
51
+ self.predict = torch.nn.ModuleList([self._make_predict_module(
52
+ self.channels[d], 2) for d in range(full_depth, depth + 1)])
53
+ self.header = self._make_predict_module(self.channels[depth], channel_out)
54
+
55
+ def _make_predict_module(self, channel_in, channel_out=2, num_hidden=64):
56
+ return torch.nn.Sequential(
57
+ ocnn.modules.Conv1x1BnRelu(channel_in, num_hidden),
58
+ ocnn.modules.Conv1x1(num_hidden, channel_out, use_bias=True))
59
+
60
+ def decoder(self, shape_code: torch.Tensor, octree: Octree,
61
+ update_octree: bool = False):
62
+ r''' The decoder network of the AutoEncoder.
63
+ '''
64
+
65
+ logits = dict()
66
+ deconv = shape_code
67
+ depth, full_depth = self.depth, self.full_depth
68
+ for i, d in enumerate(range(full_depth, depth+1)):
69
+ if d > full_depth:
70
+ deconv = self.upsample[i-1](deconv, octree, d-1)
71
+ deconv = self.decoder_blks[i](deconv, octree, d)
72
+
73
+ # predict the splitting label
74
+ logit = self.predict[i](deconv)
75
+ logits[d] = logit
76
+
77
+ # update the octree according to predicted labels
78
+ if update_octree:
79
+ split = logit.argmax(1).int()
80
+ octree.octree_split(split, d)
81
+ if d < depth:
82
+ octree.octree_grow(d + 1)
83
+
84
+ # predict the signal
85
+ if d == depth:
86
+ signal = self.header(deconv)
87
+ signal = torch.tanh(signal)
88
+ signal = ocnn.nn.octree_depad(signal, octree, depth)
89
+ if update_octree:
90
+ octree.features[depth] = signal
91
+
92
+ return {'logits': logits, 'signal': signal, 'octree_out': octree}
93
+
94
+ def decode_code(self, shape_code: torch.Tensor):
95
+ r''' Decodes the shape code to an output octree.
96
+
97
+ Args:
98
+ shape_code (torch.Tensor): The shape code for decoding.
99
+ '''
100
+
101
+ octree_out = self.init_octree(shape_code)
102
+ out = self.decoder(shape_code, octree_out, update_octree=True)
103
+ return out
104
+
105
+ def init_octree(self, shape_code: torch.Tensor):
106
+ r''' Initialize a full octree for decoding.
107
+
108
+ Args:
109
+ shape_code (torch.Tensor): The shape code for decoding, used to getting
110
+ the `batch_size` and `device` to initialize the output octree.
111
+ '''
112
+
113
+ node_num = 2 ** (3 * self.full_depth)
114
+ batch_size = shape_code.size(0) // node_num
115
+ octree = ocnn.octree.init_octree(
116
+ self.depth, self.full_depth, batch_size, shape_code.device)
117
+ return octree
118
+
119
+ def forward(self, image: torch.Tensor, octree: Optional[Octree] = None,
120
+ update_octree: bool = False):
121
+ r''''''
122
+
123
+ shape_code = self.resnet18(image)
124
+ shape_code = shape_code.view(-1, self.code_channel)
125
+ if update_octree:
126
+ octree = self.init_octree(shape_code)
127
+ out = self.decoder(shape_code, octree, update_octree)
128
+ return out