ocnn-2.2.8-py3-none-any.whl → ocnn-2.3.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. ocnn/__init__.py +24 -24
  2. ocnn/dataset.py +160 -160
  3. ocnn/models/__init__.py +29 -29
  4. ocnn/models/autoencoder.py +155 -155
  5. ocnn/models/hrnet.py +192 -192
  6. ocnn/models/image2shape.py +128 -128
  7. ocnn/models/lenet.py +46 -46
  8. ocnn/models/ounet.py +94 -94
  9. ocnn/models/resnet.py +53 -53
  10. ocnn/models/segnet.py +72 -72
  11. ocnn/models/unet.py +105 -105
  12. ocnn/modules/__init__.py +26 -26
  13. ocnn/modules/modules.py +303 -303
  14. ocnn/modules/resblocks.py +158 -158
  15. ocnn/nn/__init__.py +45 -44
  16. ocnn/nn/kernels/__init__.py +14 -0
  17. ocnn/nn/kernels/autotuner.py +416 -0
  18. ocnn/nn/kernels/config.py +67 -0
  19. ocnn/nn/kernels/conv_bwd_implicit_gemm.py +229 -0
  20. ocnn/nn/kernels/conv_bwd_implicit_gemm_splitk.py +347 -0
  21. ocnn/nn/kernels/conv_fwd_implicit_gemm.py +109 -0
  22. ocnn/nn/kernels/conv_fwd_implicit_gemm_splitk.py +150 -0
  23. ocnn/nn/kernels/utils.py +44 -0
  24. ocnn/nn/octree2col.py +53 -53
  25. ocnn/nn/octree2vox.py +50 -50
  26. ocnn/nn/octree_align.py +46 -46
  27. ocnn/nn/octree_conv.py +430 -429
  28. ocnn/nn/octree_conv_t.py +148 -0
  29. ocnn/nn/octree_drop.py +55 -55
  30. ocnn/nn/octree_dwconv.py +222 -222
  31. ocnn/nn/octree_gconv.py +79 -79
  32. ocnn/nn/octree_interp.py +196 -196
  33. ocnn/nn/octree_norm.py +126 -126
  34. ocnn/nn/octree_pad.py +39 -39
  35. ocnn/nn/octree_pool.py +200 -200
  36. ocnn/octree/__init__.py +22 -22
  37. ocnn/octree/octree.py +770 -770
  38. ocnn/octree/points.py +384 -323
  39. ocnn/octree/shuffled_key.py +115 -115
  40. ocnn/utils.py +205 -205
  41. {ocnn-2.2.8.dist-info → ocnn-2.3.0.dist-info}/METADATA +117 -111
  42. ocnn-2.3.0.dist-info/RECORD +45 -0
  43. {ocnn-2.2.8.dist-info → ocnn-2.3.0.dist-info}/WHEEL +1 -1
  44. {ocnn-2.2.8.dist-info → ocnn-2.3.0.dist-info}/licenses/LICENSE +21 -21
  45. ocnn-2.2.8.dist-info/RECORD +0 -36
  46. {ocnn-2.2.8.dist-info → ocnn-2.3.0.dist-info}/top_level.txt +0 -0
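ocnn/models/unet.py CHANGED (the removed and added lines below are textually identical as rendered, so the change is presumably whitespace or line endings only)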
@@ -1,105 +1,105 @@
1
- # --------------------------------------------------------
2
- # Octree-based Sparse Convolutional Neural Networks
3
- # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
- # Licensed under The MIT License [see LICENSE for details]
5
- # Written by Peng-Shuai Wang
6
- # --------------------------------------------------------
7
-
8
- import torch
9
- import torch.nn
10
- from typing import Dict
11
-
12
- import ocnn
13
- from ocnn.octree import Octree
14
-
15
-
16
- class UNet(torch.nn.Module):
17
- r''' Octree-based UNet for segmentation.
18
- '''
19
-
20
- def __init__(self, in_channels: int, out_channels: int, interp: str = 'linear',
21
- nempty: bool = False, **kwargs):
22
- super(UNet, self).__init__()
23
- self.in_channels = in_channels
24
- self.out_channels = out_channels
25
- self.nempty = nempty
26
- self.config_network()
27
- self.encoder_stages = len(self.encoder_blocks)
28
- self.decoder_stages = len(self.decoder_blocks)
29
-
30
- # encoder
31
- self.conv1 = ocnn.modules.OctreeConvBnRelu(
32
- in_channels, self.encoder_channel[0], nempty=nempty)
33
- self.downsample = torch.nn.ModuleList([ocnn.modules.OctreeConvBnRelu(
34
- self.encoder_channel[i], self.encoder_channel[i+1], kernel_size=[2],
35
- stride=2, nempty=nempty) for i in range(self.encoder_stages)])
36
- self.encoder = torch.nn.ModuleList([ocnn.modules.OctreeResBlocks(
37
- self.encoder_channel[i+1], self.encoder_channel[i + 1],
38
- self.encoder_blocks[i], self.bottleneck, nempty, self.resblk)
39
- for i in range(self.encoder_stages)])
40
-
41
- # decoder
42
- channel = [self.decoder_channel[i+1] + self.encoder_channel[-i-2]
43
- for i in range(self.decoder_stages)]
44
- self.upsample = torch.nn.ModuleList([ocnn.modules.OctreeDeconvBnRelu(
45
- self.decoder_channel[i], self.decoder_channel[i+1], kernel_size=[2],
46
- stride=2, nempty=nempty) for i in range(self.decoder_stages)])
47
- self.decoder = torch.nn.ModuleList([ocnn.modules.OctreeResBlocks(
48
- channel[i], self.decoder_channel[i+1],
49
- self.decoder_blocks[i], self.bottleneck, nempty, self.resblk)
50
- for i in range(self.decoder_stages)])
51
-
52
- # header
53
- # channel = self.decoder_channel[self.decoder_stages]
54
- self.octree_interp = ocnn.nn.OctreeInterp(interp, nempty)
55
- self.header = torch.nn.Sequential(
56
- ocnn.modules.Conv1x1BnRelu(self.decoder_channel[-1], self.head_channel),
57
- ocnn.modules.Conv1x1(self.head_channel, self.out_channels, use_bias=True))
58
-
59
- def config_network(self):
60
- r''' Configure the network channels and Resblock numbers.
61
- '''
62
-
63
- self.encoder_channel = [32, 32, 64, 128, 256]
64
- self.decoder_channel = [256, 256, 128, 96, 96]
65
- self.encoder_blocks = [2, 3, 4, 6]
66
- self.decoder_blocks = [2, 2, 2, 2]
67
- self.head_channel = 64
68
- self.bottleneck = 1
69
- self.resblk = ocnn.modules.OctreeResBlock2
70
-
71
- def unet_encoder(self, data: torch.Tensor, octree: Octree, depth: int):
72
- r''' The encoder of the U-Net.
73
- '''
74
-
75
- convd = dict()
76
- convd[depth] = self.conv1(data, octree, depth)
77
- for i in range(self.encoder_stages):
78
- d = depth - i
79
- conv = self.downsample[i](convd[d], octree, d)
80
- convd[d-1] = self.encoder[i](conv, octree, d-1)
81
- return convd
82
-
83
- def unet_decoder(self, convd: Dict[int, torch.Tensor], octree: Octree, depth: int):
84
- r''' The decoder of the U-Net.
85
- '''
86
-
87
- deconv = convd[depth]
88
- for i in range(self.decoder_stages):
89
- d = depth + i
90
- deconv = self.upsample[i](deconv, octree, d)
91
- deconv = torch.cat([convd[d+1], deconv], dim=1) # skip connections
92
- deconv = self.decoder[i](deconv, octree, d+1)
93
- return deconv
94
-
95
- def forward(self, data: torch.Tensor, octree: Octree, depth: int,
96
- query_pts: torch.Tensor):
97
- r''''''
98
-
99
- convd = self.unet_encoder(data, octree, depth)
100
- deconv = self.unet_decoder(convd, octree, depth - self.encoder_stages)
101
-
102
- interp_depth = depth - self.encoder_stages + self.decoder_stages
103
- feature = self.octree_interp(deconv, octree, interp_depth, query_pts)
104
- logits = self.header(feature)
105
- return logits
1
+ # --------------------------------------------------------
2
+ # Octree-based Sparse Convolutional Neural Networks
3
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Peng-Shuai Wang
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+ import torch.nn
10
+ from typing import Dict
11
+
12
+ import ocnn
13
+ from ocnn.octree import Octree
14
+
15
+
16
+ class UNet(torch.nn.Module):
17
+ r''' Octree-based UNet for segmentation.
18
+ '''
19
+
20
+ def __init__(self, in_channels: int, out_channels: int, interp: str = 'linear',
21
+ nempty: bool = False, **kwargs):
22
+ super(UNet, self).__init__()
23
+ self.in_channels = in_channels
24
+ self.out_channels = out_channels
25
+ self.nempty = nempty
26
+ self.config_network()
27
+ self.encoder_stages = len(self.encoder_blocks)
28
+ self.decoder_stages = len(self.decoder_blocks)
29
+
30
+ # encoder
31
+ self.conv1 = ocnn.modules.OctreeConvBnRelu(
32
+ in_channels, self.encoder_channel[0], nempty=nempty)
33
+ self.downsample = torch.nn.ModuleList([ocnn.modules.OctreeConvBnRelu(
34
+ self.encoder_channel[i], self.encoder_channel[i+1], kernel_size=[2],
35
+ stride=2, nempty=nempty) for i in range(self.encoder_stages)])
36
+ self.encoder = torch.nn.ModuleList([ocnn.modules.OctreeResBlocks(
37
+ self.encoder_channel[i+1], self.encoder_channel[i + 1],
38
+ self.encoder_blocks[i], self.bottleneck, nempty, self.resblk)
39
+ for i in range(self.encoder_stages)])
40
+
41
+ # decoder
42
+ channel = [self.decoder_channel[i+1] + self.encoder_channel[-i-2]
43
+ for i in range(self.decoder_stages)]
44
+ self.upsample = torch.nn.ModuleList([ocnn.modules.OctreeDeconvBnRelu(
45
+ self.decoder_channel[i], self.decoder_channel[i+1], kernel_size=[2],
46
+ stride=2, nempty=nempty) for i in range(self.decoder_stages)])
47
+ self.decoder = torch.nn.ModuleList([ocnn.modules.OctreeResBlocks(
48
+ channel[i], self.decoder_channel[i+1],
49
+ self.decoder_blocks[i], self.bottleneck, nempty, self.resblk)
50
+ for i in range(self.decoder_stages)])
51
+
52
+ # header
53
+ # channel = self.decoder_channel[self.decoder_stages]
54
+ self.octree_interp = ocnn.nn.OctreeInterp(interp, nempty)
55
+ self.header = torch.nn.Sequential(
56
+ ocnn.modules.Conv1x1BnRelu(self.decoder_channel[-1], self.head_channel),
57
+ ocnn.modules.Conv1x1(self.head_channel, self.out_channels, use_bias=True))
58
+
59
+ def config_network(self):
60
+ r''' Configure the network channels and Resblock numbers.
61
+ '''
62
+
63
+ self.encoder_channel = [32, 32, 64, 128, 256]
64
+ self.decoder_channel = [256, 256, 128, 96, 96]
65
+ self.encoder_blocks = [2, 3, 4, 6]
66
+ self.decoder_blocks = [2, 2, 2, 2]
67
+ self.head_channel = 64
68
+ self.bottleneck = 1
69
+ self.resblk = ocnn.modules.OctreeResBlock2
70
+
71
+ def unet_encoder(self, data: torch.Tensor, octree: Octree, depth: int):
72
+ r''' The encoder of the U-Net.
73
+ '''
74
+
75
+ convd = dict()
76
+ convd[depth] = self.conv1(data, octree, depth)
77
+ for i in range(self.encoder_stages):
78
+ d = depth - i
79
+ conv = self.downsample[i](convd[d], octree, d)
80
+ convd[d-1] = self.encoder[i](conv, octree, d-1)
81
+ return convd
82
+
83
+ def unet_decoder(self, convd: Dict[int, torch.Tensor], octree: Octree, depth: int):
84
+ r''' The decoder of the U-Net.
85
+ '''
86
+
87
+ deconv = convd[depth]
88
+ for i in range(self.decoder_stages):
89
+ d = depth + i
90
+ deconv = self.upsample[i](deconv, octree, d)
91
+ deconv = torch.cat([convd[d+1], deconv], dim=1) # skip connections
92
+ deconv = self.decoder[i](deconv, octree, d+1)
93
+ return deconv
94
+
95
+ def forward(self, data: torch.Tensor, octree: Octree, depth: int,
96
+ query_pts: torch.Tensor):
97
+ r''''''
98
+
99
+ convd = self.unet_encoder(data, octree, depth)
100
+ deconv = self.unet_decoder(convd, octree, depth - self.encoder_stages)
101
+
102
+ interp_depth = depth - self.encoder_stages + self.decoder_stages
103
+ feature = self.octree_interp(deconv, octree, interp_depth, query_pts)
104
+ logits = self.header(feature)
105
+ return logits
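ocnn/modules/__init__.py CHANGED (the removed and added lines below are textually identical as rendered, so the change is presumably whitespace or line endings only)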
@@ -1,26 +1,26 @@
1
- # --------------------------------------------------------
2
- # Octree-based Sparse Convolutional Neural Networks
3
- # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
- # Licensed under The MIT License [see LICENSE for details]
5
- # Written by Peng-Shuai Wang
6
- # --------------------------------------------------------
7
-
8
- from .modules import (InputFeature,
9
- OctreeConvBn, OctreeConvBnRelu, OctreeDeconvBnRelu,
10
- Conv1x1, Conv1x1Bn, Conv1x1BnRelu, FcBnRelu,
11
- OctreeConvGn, OctreeConvGnRelu, OctreeDeconvGnRelu,
12
- Conv1x1, Conv1x1Gn, Conv1x1GnRelu)
13
- from .resblocks import (OctreeResBlock, OctreeResBlock2, OctreeResBlockGn,
14
- OctreeResBlocks,)
15
-
16
- __all__ = [
17
- 'InputFeature',
18
- 'OctreeConvBn', 'OctreeConvBnRelu', 'OctreeDeconvBnRelu',
19
- 'Conv1x1', 'Conv1x1Bn', 'Conv1x1BnRelu', 'FcBnRelu',
20
- 'OctreeConvGn', 'OctreeConvGnRelu', 'OctreeDeconvGnRelu',
21
- 'Conv1x1', 'Conv1x1Gn', 'Conv1x1GnRelu',
22
- 'OctreeResBlock', 'OctreeResBlock2', 'OctreeResBlockGn',
23
- 'OctreeResBlocks',
24
- ]
25
-
26
- classes = __all__
1
+ # --------------------------------------------------------
2
+ # Octree-based Sparse Convolutional Neural Networks
3
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Peng-Shuai Wang
6
+ # --------------------------------------------------------
7
+
8
+ from .modules import (InputFeature,
9
+ OctreeConvBn, OctreeConvBnRelu, OctreeDeconvBnRelu,
10
+ Conv1x1, Conv1x1Bn, Conv1x1BnRelu, FcBnRelu,
11
+ OctreeConvGn, OctreeConvGnRelu, OctreeDeconvGnRelu,
12
+ Conv1x1, Conv1x1Gn, Conv1x1GnRelu)
13
+ from .resblocks import (OctreeResBlock, OctreeResBlock2, OctreeResBlockGn,
14
+ OctreeResBlocks,)
15
+
16
+ __all__ = [
17
+ 'InputFeature',
18
+ 'OctreeConvBn', 'OctreeConvBnRelu', 'OctreeDeconvBnRelu',
19
+ 'Conv1x1', 'Conv1x1Bn', 'Conv1x1BnRelu', 'FcBnRelu',
20
+ 'OctreeConvGn', 'OctreeConvGnRelu', 'OctreeDeconvGnRelu',
21
+ 'Conv1x1', 'Conv1x1Gn', 'Conv1x1GnRelu',
22
+ 'OctreeResBlock', 'OctreeResBlock2', 'OctreeResBlockGn',
23
+ 'OctreeResBlocks',
24
+ ]
25
+
26
+ classes = __all__