ocnn 2.2.1__py3-none-any.whl → 2.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ocnn/models/ounet.py CHANGED
@@ -1,94 +1,94 @@
1
- # --------------------------------------------------------
2
- # Octree-based Sparse Convolutional Neural Networks
3
- # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
- # Licensed under The MIT License [see LICENSE for details]
5
- # Written by Peng-Shuai Wang
6
- # --------------------------------------------------------
7
-
8
-
9
- import ocnn
10
- import torch
11
- import torch.nn
12
-
13
- from ocnn.octree import Octree
14
- from ocnn.models.autoencoder import AutoEncoder
15
-
16
-
17
- class OUNet(AutoEncoder):
18
-
19
- def __init__(self, channel_in: int, channel_out: int, depth: int,
20
- full_depth: int = 2, feature: str = 'ND'):
21
- super().__init__(channel_in, channel_out, depth, full_depth, feature,
22
- code_channel=-1) # !set code_channe=-1
23
- self.proj = None # remove this module used in AutoEncoder
24
-
25
- def encoder(self, octree):
26
- r''' The encoder network for extracting heirarchy features.
27
- '''
28
-
29
- convs = dict()
30
- depth, full_depth = self.depth, self.full_depth
31
- data = self.get_input_feature(octree)
32
- convs[depth] = self.conv1(data, octree, depth)
33
- for i, d in enumerate(range(depth, full_depth-1, -1)):
34
- convs[d] = self.encoder_blks[i](convs[d], octree, d)
35
- if d > full_depth:
36
- convs[d-1] = self.downsample[i](convs[d], octree, d)
37
- return convs
38
-
39
- def decoder(self, convs: dict, octree_in: Octree, octree_out: Octree,
40
- update_octree: bool = False):
41
- r''' The decoder network for decode the octree.
42
- '''
43
-
44
- logits = dict()
45
- deconv = convs[self.full_depth]
46
- depth, full_depth = self.depth, self.full_depth
47
- for i, d in enumerate(range(full_depth, depth + 1)):
48
- if d > full_depth:
49
- deconv = self.upsample[i-1](deconv, octree_out, d-1)
50
- skip = ocnn.nn.octree_align(convs[d], octree_in, octree_out, d)
51
- deconv = deconv + skip # output-guided skip connections
52
- deconv = self.decoder_blks[i](deconv, octree_out, d)
53
-
54
- # predict the splitting label
55
- logit = self.predict[i](deconv)
56
- logits[d] = logit
57
-
58
- # update the octree according to predicted labels
59
- if update_octree:
60
- split = logit.argmax(1).int()
61
- octree_out.octree_split(split, d)
62
- if d < depth:
63
- octree_out.octree_grow(d + 1)
64
-
65
- # predict the signal
66
- if d == depth:
67
- signal = self.header(deconv)
68
- signal = torch.tanh(signal)
69
- signal = ocnn.nn.octree_depad(signal, octree_out, depth)
70
- if update_octree:
71
- octree_out.features[depth] = signal
72
-
73
- return {'logits': logits, 'signal': signal, 'octree_out': octree_out}
74
-
75
- def init_octree(self, octree_in: Octree):
76
- r''' Initialize a full octree for decoding.
77
- '''
78
-
79
- device = octree_in.device
80
- batch_size = octree_in.batch_size
81
- octree = Octree(self.depth, self.full_depth, batch_size, device)
82
- for d in range(self.full_depth+1):
83
- octree.octree_grow_full(depth=d)
84
- return octree
85
-
86
- def forward(self, octree_in, octree_out=None, update_octree: bool = False):
87
- r''''''
88
-
89
- if octree_out is None:
90
- update_octree = True
91
- octree_out = self.init_octree(octree_in)
92
- convs = self.encoder(octree_in)
93
- out = self.decoder(convs, octree_in, octree_out, update_octree)
94
- return out
1
+ # --------------------------------------------------------
2
+ # Octree-based Sparse Convolutional Neural Networks
3
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Peng-Shuai Wang
6
+ # --------------------------------------------------------
7
+
8
+
9
+ import ocnn
10
+ import torch
11
+ import torch.nn
12
+
13
+ from ocnn.octree import Octree
14
+ from ocnn.models.autoencoder import AutoEncoder
15
+
16
+
17
+ class OUNet(AutoEncoder):
18
+
19
+ def __init__(self, channel_in: int, channel_out: int, depth: int,
20
+ full_depth: int = 2, feature: str = 'ND'):
21
+ super().__init__(channel_in, channel_out, depth, full_depth, feature,
22
+ code_channel=-1) # !set code_channe=-1
23
+ self.proj = None # remove this module used in AutoEncoder
24
+
25
+ def encoder(self, octree):
26
+ r''' The encoder network for extracting heirarchy features.
27
+ '''
28
+
29
+ convs = dict()
30
+ depth, full_depth = self.depth, self.full_depth
31
+ data = self.get_input_feature(octree)
32
+ convs[depth] = self.conv1(data, octree, depth)
33
+ for i, d in enumerate(range(depth, full_depth-1, -1)):
34
+ convs[d] = self.encoder_blks[i](convs[d], octree, d)
35
+ if d > full_depth:
36
+ convs[d-1] = self.downsample[i](convs[d], octree, d)
37
+ return convs
38
+
39
+ def decoder(self, convs: dict, octree_in: Octree, octree_out: Octree,
40
+ update_octree: bool = False):
41
+ r''' The decoder network for decode the octree.
42
+ '''
43
+
44
+ logits = dict()
45
+ deconv = convs[self.full_depth]
46
+ depth, full_depth = self.depth, self.full_depth
47
+ for i, d in enumerate(range(full_depth, depth + 1)):
48
+ if d > full_depth:
49
+ deconv = self.upsample[i-1](deconv, octree_out, d-1)
50
+ skip = ocnn.nn.octree_align(convs[d], octree_in, octree_out, d)
51
+ deconv = deconv + skip # output-guided skip connections
52
+ deconv = self.decoder_blks[i](deconv, octree_out, d)
53
+
54
+ # predict the splitting label
55
+ logit = self.predict[i](deconv)
56
+ logits[d] = logit
57
+
58
+ # update the octree according to predicted labels
59
+ if update_octree:
60
+ split = logit.argmax(1).int()
61
+ octree_out.octree_split(split, d)
62
+ if d < depth:
63
+ octree_out.octree_grow(d + 1)
64
+
65
+ # predict the signal
66
+ if d == depth:
67
+ signal = self.header(deconv)
68
+ signal = torch.tanh(signal)
69
+ signal = ocnn.nn.octree_depad(signal, octree_out, depth)
70
+ if update_octree:
71
+ octree_out.features[depth] = signal
72
+
73
+ return {'logits': logits, 'signal': signal, 'octree_out': octree_out}
74
+
75
+ def init_octree(self, octree_in: Octree):
76
+ r''' Initialize a full octree for decoding.
77
+ '''
78
+
79
+ device = octree_in.device
80
+ batch_size = octree_in.batch_size
81
+ octree = Octree(self.depth, self.full_depth, batch_size, device)
82
+ for d in range(self.full_depth+1):
83
+ octree.octree_grow_full(depth=d)
84
+ return octree
85
+
86
+ def forward(self, octree_in, octree_out=None, update_octree: bool = False):
87
+ r''''''
88
+
89
+ if octree_out is None:
90
+ update_octree = True
91
+ octree_out = self.init_octree(octree_in)
92
+ convs = self.encoder(octree_in)
93
+ out = self.decoder(convs, octree_in, octree_out, update_octree)
94
+ return out
ocnn/models/resnet.py CHANGED
@@ -1,53 +1,53 @@
1
- # --------------------------------------------------------
2
- # Octree-based Sparse Convolutional Neural Networks
3
- # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
- # Licensed under The MIT License [see LICENSE for details]
5
- # Written by Peng-Shuai Wang
6
- # --------------------------------------------------------
7
-
8
- import torch
9
- import ocnn
10
- from ocnn.octree import Octree
11
-
12
-
13
- class ResNet(torch.nn.Module):
14
- r''' Octree-based ResNet for classification.
15
- '''
16
-
17
- def __init__(self, in_channels: int, out_channels: int, resblock_num: int,
18
- stages: int, nempty: bool = False):
19
- super().__init__()
20
- self.in_channels = in_channels
21
- self.out_channels = out_channels
22
- self.resblk_num = resblock_num
23
- self.stages = stages
24
- self.nempty = nempty
25
- channels = [2 ** max(i+9-stages, 2) for i in range(stages)]
26
-
27
- self.conv1 = ocnn.modules.OctreeConvBnRelu(
28
- in_channels, channels[0], nempty=nempty)
29
- self.pool1 = ocnn.nn.OctreeMaxPool(nempty)
30
- self.resblocks = torch.nn.ModuleList([ocnn.modules.OctreeResBlocks(
31
- channels[i], channels[i+1], resblock_num, nempty=nempty)
32
- for i in range(stages-1)])
33
- self.pools = torch.nn.ModuleList([ocnn.nn.OctreeMaxPool(
34
- nempty) for i in range(stages-1)])
35
- self.global_pool = ocnn.nn.OctreeGlobalPool(nempty)
36
- # self.header = torch.nn.Linear(channels[-1], out_channels, bias=True)
37
- self.header = torch.nn.Sequential(
38
- ocnn.modules.FcBnRelu(channels[-1], 512),
39
- torch.nn.Dropout(p=0.5),
40
- torch.nn.Linear(512, out_channels))
41
-
42
- def forward(self, data: torch.Tensor, octree: Octree, depth: int):
43
- r''''''
44
-
45
- data = self.conv1(data, octree, depth)
46
- data = self.pool1(data, octree, depth)
47
- for i in range(self.stages-1):
48
- d = depth - i - 1
49
- data = self.resblocks[i](data, octree, d)
50
- data = self.pools[i](data, octree, d)
51
- data = self.global_pool(data, octree, depth-self.stages)
52
- data = self.header(data)
53
- return data
1
+ # --------------------------------------------------------
2
+ # Octree-based Sparse Convolutional Neural Networks
3
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Peng-Shuai Wang
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+ import ocnn
10
+ from ocnn.octree import Octree
11
+
12
+
13
+ class ResNet(torch.nn.Module):
14
+ r''' Octree-based ResNet for classification.
15
+ '''
16
+
17
+ def __init__(self, in_channels: int, out_channels: int, resblock_num: int,
18
+ stages: int, nempty: bool = False):
19
+ super().__init__()
20
+ self.in_channels = in_channels
21
+ self.out_channels = out_channels
22
+ self.resblk_num = resblock_num
23
+ self.stages = stages
24
+ self.nempty = nempty
25
+ channels = [2 ** max(i+9-stages, 2) for i in range(stages)]
26
+
27
+ self.conv1 = ocnn.modules.OctreeConvBnRelu(
28
+ in_channels, channels[0], nempty=nempty)
29
+ self.pool1 = ocnn.nn.OctreeMaxPool(nempty)
30
+ self.resblocks = torch.nn.ModuleList([ocnn.modules.OctreeResBlocks(
31
+ channels[i], channels[i+1], resblock_num, nempty=nempty)
32
+ for i in range(stages-1)])
33
+ self.pools = torch.nn.ModuleList([ocnn.nn.OctreeMaxPool(
34
+ nempty) for i in range(stages-1)])
35
+ self.global_pool = ocnn.nn.OctreeGlobalPool(nempty)
36
+ # self.header = torch.nn.Linear(channels[-1], out_channels, bias=True)
37
+ self.header = torch.nn.Sequential(
38
+ ocnn.modules.FcBnRelu(channels[-1], 512),
39
+ torch.nn.Dropout(p=0.5),
40
+ torch.nn.Linear(512, out_channels))
41
+
42
+ def forward(self, data: torch.Tensor, octree: Octree, depth: int):
43
+ r''''''
44
+
45
+ data = self.conv1(data, octree, depth)
46
+ data = self.pool1(data, octree, depth)
47
+ for i in range(self.stages-1):
48
+ d = depth - i - 1
49
+ data = self.resblocks[i](data, octree, d)
50
+ data = self.pools[i](data, octree, d)
51
+ data = self.global_pool(data, octree, depth-self.stages)
52
+ data = self.header(data)
53
+ return data
ocnn/models/segnet.py CHANGED
@@ -1,72 +1,72 @@
1
- # --------------------------------------------------------
2
- # Octree-based Sparse Convolutional Neural Networks
3
- # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
- # Licensed under The MIT License [see LICENSE for details]
5
- # Written by Peng-Shuai Wang
6
- # --------------------------------------------------------
7
-
8
- import torch
9
- import ocnn
10
- from ocnn.octree import Octree
11
-
12
-
13
- class SegNet(torch.nn.Module):
14
- r''' Octree-based SegNet for segmentation.
15
- '''
16
-
17
- def __init__(self, in_channels: int, out_channels: int, stages: int,
18
- interp: str = 'linear', nempty: bool = False, **kwargs):
19
- super().__init__()
20
- self.in_channels = in_channels
21
- self.out_channels = out_channels
22
- self.stages = stages
23
- self.nempty = nempty
24
- return_indices = True
25
-
26
- channels_stages = [2 ** max(i+8-stages, 2) for i in range(stages)]
27
- channels = [in_channels] + channels_stages
28
- self.convs = torch.nn.ModuleList(
29
- [ocnn.modules.OctreeConvBnRelu(channels[i], channels[i+1], nempty=nempty)
30
- for i in range(stages)])
31
- self.pools = torch.nn.ModuleList(
32
- [ocnn.nn.OctreeMaxPool(nempty, return_indices) for i in range(stages)])
33
-
34
- self.bottleneck = ocnn.modules.OctreeConvBnRelu(channels[-1], channels[-1])
35
-
36
- channels = channels_stages[::-1] + [channels_stages[0]]
37
- self.deconvs = torch.nn.ModuleList(
38
- [ocnn.modules.OctreeConvBnRelu(channels[i], channels[i+1], nempty=nempty)
39
- for i in range(0, stages)])
40
- self.unpools = torch.nn.ModuleList(
41
- [ocnn.nn.OctreeMaxUnpool(nempty) for i in range(stages)])
42
-
43
- self.octree_interp = ocnn.nn.OctreeInterp(interp, nempty)
44
- self.header = torch.nn.Sequential(
45
- ocnn.modules.Conv1x1BnRelu(channels[-1], 64),
46
- ocnn.modules.Conv1x1(64, out_channels, use_bias=True))
47
-
48
- def forward(self, data: torch.Tensor, octree: Octree, depth: int,
49
- query_pts: torch.Tensor):
50
- r''''''
51
-
52
- # encoder
53
- indices = dict()
54
- for i in range(self.stages):
55
- d = depth - i
56
- data = self.convs[i](data, octree, d)
57
- data, indices[d] = self.pools[i](data, octree, d)
58
-
59
- # bottleneck
60
- data = self.bottleneck(data, octree, depth-self.stages)
61
-
62
- # decoder
63
- for i in range(self.stages):
64
- d = depth - self.stages + i
65
- data = self.unpools[i](data, indices[d + 1], octree, d)
66
- data = self.deconvs[i](data, octree, d + 1)
67
-
68
- # header
69
- feature = self.octree_interp(data, octree, depth, query_pts)
70
- logits = self.header(feature)
71
-
72
- return logits
1
+ # --------------------------------------------------------
2
+ # Octree-based Sparse Convolutional Neural Networks
3
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Peng-Shuai Wang
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+ import ocnn
10
+ from ocnn.octree import Octree
11
+
12
+
13
+ class SegNet(torch.nn.Module):
14
+ r''' Octree-based SegNet for segmentation.
15
+ '''
16
+
17
+ def __init__(self, in_channels: int, out_channels: int, stages: int,
18
+ interp: str = 'linear', nempty: bool = False, **kwargs):
19
+ super().__init__()
20
+ self.in_channels = in_channels
21
+ self.out_channels = out_channels
22
+ self.stages = stages
23
+ self.nempty = nempty
24
+ return_indices = True
25
+
26
+ channels_stages = [2 ** max(i+8-stages, 2) for i in range(stages)]
27
+ channels = [in_channels] + channels_stages
28
+ self.convs = torch.nn.ModuleList(
29
+ [ocnn.modules.OctreeConvBnRelu(channels[i], channels[i+1], nempty=nempty)
30
+ for i in range(stages)])
31
+ self.pools = torch.nn.ModuleList(
32
+ [ocnn.nn.OctreeMaxPool(nempty, return_indices) for i in range(stages)])
33
+
34
+ self.bottleneck = ocnn.modules.OctreeConvBnRelu(channels[-1], channels[-1])
35
+
36
+ channels = channels_stages[::-1] + [channels_stages[0]]
37
+ self.deconvs = torch.nn.ModuleList(
38
+ [ocnn.modules.OctreeConvBnRelu(channels[i], channels[i+1], nempty=nempty)
39
+ for i in range(0, stages)])
40
+ self.unpools = torch.nn.ModuleList(
41
+ [ocnn.nn.OctreeMaxUnpool(nempty) for i in range(stages)])
42
+
43
+ self.octree_interp = ocnn.nn.OctreeInterp(interp, nempty)
44
+ self.header = torch.nn.Sequential(
45
+ ocnn.modules.Conv1x1BnRelu(channels[-1], 64),
46
+ ocnn.modules.Conv1x1(64, out_channels, use_bias=True))
47
+
48
+ def forward(self, data: torch.Tensor, octree: Octree, depth: int,
49
+ query_pts: torch.Tensor):
50
+ r''''''
51
+
52
+ # encoder
53
+ indices = dict()
54
+ for i in range(self.stages):
55
+ d = depth - i
56
+ data = self.convs[i](data, octree, d)
57
+ data, indices[d] = self.pools[i](data, octree, d)
58
+
59
+ # bottleneck
60
+ data = self.bottleneck(data, octree, depth-self.stages)
61
+
62
+ # decoder
63
+ for i in range(self.stages):
64
+ d = depth - self.stages + i
65
+ data = self.unpools[i](data, indices[d + 1], octree, d)
66
+ data = self.deconvs[i](data, octree, d + 1)
67
+
68
+ # header
69
+ feature = self.octree_interp(data, octree, depth, query_pts)
70
+ logits = self.header(feature)
71
+
72
+ return logits
ocnn/models/unet.py CHANGED
@@ -1,105 +1,105 @@
1
- # --------------------------------------------------------
2
- # Octree-based Sparse Convolutional Neural Networks
3
- # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
- # Licensed under The MIT License [see LICENSE for details]
5
- # Written by Peng-Shuai Wang
6
- # --------------------------------------------------------
7
-
8
- import torch
9
- import torch.nn
10
- from typing import Dict
11
-
12
- import ocnn
13
- from ocnn.octree import Octree
14
-
15
-
16
- class UNet(torch.nn.Module):
17
- r''' Octree-based UNet for segmentation.
18
- '''
19
-
20
- def __init__(self, in_channels: int, out_channels: int, interp: str = 'linear',
21
- nempty: bool = False, **kwargs):
22
- super(UNet, self).__init__()
23
- self.in_channels = in_channels
24
- self.out_channels = out_channels
25
- self.nempty = nempty
26
- self.config_network()
27
- self.encoder_stages = len(self.encoder_blocks)
28
- self.decoder_stages = len(self.decoder_blocks)
29
-
30
- # encoder
31
- self.conv1 = ocnn.modules.OctreeConvBnRelu(
32
- in_channels, self.encoder_channel[0], nempty=nempty)
33
- self.downsample = torch.nn.ModuleList([ocnn.modules.OctreeConvBnRelu(
34
- self.encoder_channel[i], self.encoder_channel[i+1], kernel_size=[2],
35
- stride=2, nempty=nempty) for i in range(self.encoder_stages)])
36
- self.encoder = torch.nn.ModuleList([ocnn.modules.OctreeResBlocks(
37
- self.encoder_channel[i+1], self.encoder_channel[i + 1],
38
- self.encoder_blocks[i], self.bottleneck, nempty, self.resblk)
39
- for i in range(self.encoder_stages)])
40
-
41
- # decoder
42
- channel = [self.decoder_channel[i+1] + self.encoder_channel[-i-2]
43
- for i in range(self.decoder_stages)]
44
- self.upsample = torch.nn.ModuleList([ocnn.modules.OctreeDeconvBnRelu(
45
- self.decoder_channel[i], self.decoder_channel[i+1], kernel_size=[2],
46
- stride=2, nempty=nempty) for i in range(self.decoder_stages)])
47
- self.decoder = torch.nn.ModuleList([ocnn.modules.OctreeResBlocks(
48
- channel[i], self.decoder_channel[i+1],
49
- self.decoder_blocks[i], self.bottleneck, nempty, self.resblk)
50
- for i in range(self.decoder_stages)])
51
-
52
- # header
53
- # channel = self.decoder_channel[self.decoder_stages]
54
- self.octree_interp = ocnn.nn.OctreeInterp(interp, nempty)
55
- self.header = torch.nn.Sequential(
56
- ocnn.modules.Conv1x1BnRelu(self.decoder_channel[-1], self.head_channel),
57
- ocnn.modules.Conv1x1(self.head_channel, self.out_channels, use_bias=True))
58
-
59
- def config_network(self):
60
- r''' Configure the network channels and Resblock numbers.
61
- '''
62
-
63
- self.encoder_channel = [32, 32, 64, 128, 256]
64
- self.decoder_channel = [256, 256, 128, 96, 96]
65
- self.encoder_blocks = [2, 3, 4, 6]
66
- self.decoder_blocks = [2, 2, 2, 2]
67
- self.head_channel = 64
68
- self.bottleneck = 1
69
- self.resblk = ocnn.modules.OctreeResBlock2
70
-
71
- def unet_encoder(self, data: torch.Tensor, octree: Octree, depth: int):
72
- r''' The encoder of the U-Net.
73
- '''
74
-
75
- convd = dict()
76
- convd[depth] = self.conv1(data, octree, depth)
77
- for i in range(self.encoder_stages):
78
- d = depth - i
79
- conv = self.downsample[i](convd[d], octree, d)
80
- convd[d-1] = self.encoder[i](conv, octree, d-1)
81
- return convd
82
-
83
- def unet_decoder(self, convd: Dict[int, torch.Tensor], octree: Octree, depth: int):
84
- r''' The decoder of the U-Net.
85
- '''
86
-
87
- deconv = convd[depth]
88
- for i in range(self.decoder_stages):
89
- d = depth + i
90
- deconv = self.upsample[i](deconv, octree, d)
91
- deconv = torch.cat([convd[d+1], deconv], dim=1) # skip connections
92
- deconv = self.decoder[i](deconv, octree, d+1)
93
- return deconv
94
-
95
- def forward(self, data: torch.Tensor, octree: Octree, depth: int,
96
- query_pts: torch.Tensor):
97
- r''''''
98
-
99
- convd = self.unet_encoder(data, octree, depth)
100
- deconv = self.unet_decoder(convd, octree, depth - self.encoder_stages)
101
-
102
- interp_depth = depth - self.encoder_stages + self.decoder_stages
103
- feature = self.octree_interp(deconv, octree, interp_depth, query_pts)
104
- logits = self.header(feature)
105
- return logits
1
+ # --------------------------------------------------------
2
+ # Octree-based Sparse Convolutional Neural Networks
3
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Peng-Shuai Wang
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+ import torch.nn
10
+ from typing import Dict
11
+
12
+ import ocnn
13
+ from ocnn.octree import Octree
14
+
15
+
16
+ class UNet(torch.nn.Module):
17
+ r''' Octree-based UNet for segmentation.
18
+ '''
19
+
20
+ def __init__(self, in_channels: int, out_channels: int, interp: str = 'linear',
21
+ nempty: bool = False, **kwargs):
22
+ super(UNet, self).__init__()
23
+ self.in_channels = in_channels
24
+ self.out_channels = out_channels
25
+ self.nempty = nempty
26
+ self.config_network()
27
+ self.encoder_stages = len(self.encoder_blocks)
28
+ self.decoder_stages = len(self.decoder_blocks)
29
+
30
+ # encoder
31
+ self.conv1 = ocnn.modules.OctreeConvBnRelu(
32
+ in_channels, self.encoder_channel[0], nempty=nempty)
33
+ self.downsample = torch.nn.ModuleList([ocnn.modules.OctreeConvBnRelu(
34
+ self.encoder_channel[i], self.encoder_channel[i+1], kernel_size=[2],
35
+ stride=2, nempty=nempty) for i in range(self.encoder_stages)])
36
+ self.encoder = torch.nn.ModuleList([ocnn.modules.OctreeResBlocks(
37
+ self.encoder_channel[i+1], self.encoder_channel[i + 1],
38
+ self.encoder_blocks[i], self.bottleneck, nempty, self.resblk)
39
+ for i in range(self.encoder_stages)])
40
+
41
+ # decoder
42
+ channel = [self.decoder_channel[i+1] + self.encoder_channel[-i-2]
43
+ for i in range(self.decoder_stages)]
44
+ self.upsample = torch.nn.ModuleList([ocnn.modules.OctreeDeconvBnRelu(
45
+ self.decoder_channel[i], self.decoder_channel[i+1], kernel_size=[2],
46
+ stride=2, nempty=nempty) for i in range(self.decoder_stages)])
47
+ self.decoder = torch.nn.ModuleList([ocnn.modules.OctreeResBlocks(
48
+ channel[i], self.decoder_channel[i+1],
49
+ self.decoder_blocks[i], self.bottleneck, nempty, self.resblk)
50
+ for i in range(self.decoder_stages)])
51
+
52
+ # header
53
+ # channel = self.decoder_channel[self.decoder_stages]
54
+ self.octree_interp = ocnn.nn.OctreeInterp(interp, nempty)
55
+ self.header = torch.nn.Sequential(
56
+ ocnn.modules.Conv1x1BnRelu(self.decoder_channel[-1], self.head_channel),
57
+ ocnn.modules.Conv1x1(self.head_channel, self.out_channels, use_bias=True))
58
+
59
+ def config_network(self):
60
+ r''' Configure the network channels and Resblock numbers.
61
+ '''
62
+
63
+ self.encoder_channel = [32, 32, 64, 128, 256]
64
+ self.decoder_channel = [256, 256, 128, 96, 96]
65
+ self.encoder_blocks = [2, 3, 4, 6]
66
+ self.decoder_blocks = [2, 2, 2, 2]
67
+ self.head_channel = 64
68
+ self.bottleneck = 1
69
+ self.resblk = ocnn.modules.OctreeResBlock2
70
+
71
+ def unet_encoder(self, data: torch.Tensor, octree: Octree, depth: int):
72
+ r''' The encoder of the U-Net.
73
+ '''
74
+
75
+ convd = dict()
76
+ convd[depth] = self.conv1(data, octree, depth)
77
+ for i in range(self.encoder_stages):
78
+ d = depth - i
79
+ conv = self.downsample[i](convd[d], octree, d)
80
+ convd[d-1] = self.encoder[i](conv, octree, d-1)
81
+ return convd
82
+
83
+ def unet_decoder(self, convd: Dict[int, torch.Tensor], octree: Octree, depth: int):
84
+ r''' The decoder of the U-Net.
85
+ '''
86
+
87
+ deconv = convd[depth]
88
+ for i in range(self.decoder_stages):
89
+ d = depth + i
90
+ deconv = self.upsample[i](deconv, octree, d)
91
+ deconv = torch.cat([convd[d+1], deconv], dim=1) # skip connections
92
+ deconv = self.decoder[i](deconv, octree, d+1)
93
+ return deconv
94
+
95
+ def forward(self, data: torch.Tensor, octree: Octree, depth: int,
96
+ query_pts: torch.Tensor):
97
+ r''''''
98
+
99
+ convd = self.unet_encoder(data, octree, depth)
100
+ deconv = self.unet_decoder(convd, octree, depth - self.encoder_stages)
101
+
102
+ interp_depth = depth - self.encoder_stages + self.decoder_stages
103
+ feature = self.octree_interp(deconv, octree, interp_depth, query_pts)
104
+ logits = self.header(feature)
105
+ return logits