ocnn 2.2.8__py3-none-any.whl → 2.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. ocnn/__init__.py +24 -24
  2. ocnn/dataset.py +160 -160
  3. ocnn/models/__init__.py +29 -29
  4. ocnn/models/autoencoder.py +155 -155
  5. ocnn/models/hrnet.py +192 -192
  6. ocnn/models/image2shape.py +128 -128
  7. ocnn/models/lenet.py +46 -46
  8. ocnn/models/ounet.py +94 -94
  9. ocnn/models/resnet.py +53 -53
  10. ocnn/models/segnet.py +72 -72
  11. ocnn/models/unet.py +105 -105
  12. ocnn/modules/__init__.py +26 -26
  13. ocnn/modules/modules.py +303 -303
  14. ocnn/modules/resblocks.py +158 -158
  15. ocnn/nn/__init__.py +45 -44
  16. ocnn/nn/kernels/__init__.py +14 -0
  17. ocnn/nn/kernels/autotuner.py +416 -0
  18. ocnn/nn/kernels/config.py +67 -0
  19. ocnn/nn/kernels/conv_bwd_implicit_gemm.py +229 -0
  20. ocnn/nn/kernels/conv_bwd_implicit_gemm_splitk.py +347 -0
  21. ocnn/nn/kernels/conv_fwd_implicit_gemm.py +109 -0
  22. ocnn/nn/kernels/conv_fwd_implicit_gemm_splitk.py +150 -0
  23. ocnn/nn/kernels/utils.py +44 -0
  24. ocnn/nn/octree2col.py +53 -53
  25. ocnn/nn/octree2vox.py +50 -50
  26. ocnn/nn/octree_align.py +46 -46
  27. ocnn/nn/octree_conv.py +430 -429
  28. ocnn/nn/octree_conv_t.py +148 -0
  29. ocnn/nn/octree_drop.py +55 -55
  30. ocnn/nn/octree_dwconv.py +222 -222
  31. ocnn/nn/octree_gconv.py +79 -79
  32. ocnn/nn/octree_interp.py +196 -196
  33. ocnn/nn/octree_norm.py +126 -126
  34. ocnn/nn/octree_pad.py +39 -39
  35. ocnn/nn/octree_pool.py +200 -200
  36. ocnn/octree/__init__.py +22 -22
  37. ocnn/octree/octree.py +770 -770
  38. ocnn/octree/points.py +384 -323
  39. ocnn/octree/shuffled_key.py +115 -115
  40. ocnn/utils.py +205 -205
  41. {ocnn-2.2.8.dist-info → ocnn-2.3.0.dist-info}/METADATA +117 -111
  42. ocnn-2.3.0.dist-info/RECORD +45 -0
  43. {ocnn-2.2.8.dist-info → ocnn-2.3.0.dist-info}/WHEEL +1 -1
  44. {ocnn-2.2.8.dist-info → ocnn-2.3.0.dist-info}/licenses/LICENSE +21 -21
  45. ocnn-2.2.8.dist-info/RECORD +0 -36
  46. {ocnn-2.2.8.dist-info → ocnn-2.3.0.dist-info}/top_level.txt +0 -0
ocnn/modules/resblocks.py CHANGED
@@ -1,158 +1,158 @@
1
- # --------------------------------------------------------
2
- # Octree-based Sparse Convolutional Neural Networks
3
- # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
- # Licensed under The MIT License [see LICENSE for details]
5
- # Written by Peng-Shuai Wang
6
- # --------------------------------------------------------
7
-
8
- import torch
9
- import torch.utils.checkpoint
10
-
11
- from ocnn.octree import Octree
12
- from ocnn.nn import OctreeMaxPool
13
- from ocnn.modules import (Conv1x1BnRelu, OctreeConvBnRelu, Conv1x1Bn,
14
- OctreeConvBn, OctreeConvGnRelu, Conv1x1Gn,
15
- OctreeConvGn,)
16
-
17
-
18
- class OctreeResBlock(torch.nn.Module):
19
- r''' Octree-based ResNet block in a bottleneck style. The block is composed of
20
- a series of :obj:`Conv1x1`, :obj:`Conv3x3`, and :obj:`Conv1x1`.
21
-
22
- Args:
23
- in_channels (int): Number of input channels.
24
- out_channels (int): Number of output channels.
25
- stride (int): The stride of the block (:obj:`1` or :obj:`2`).
26
- bottleneck (int): The input and output channels of the :obj:`Conv3x3` is
27
- equal to the input channel divided by :attr:`bottleneck`.
28
- nempty (bool): If True, only performs the convolution on non-empty
29
- octree nodes.
30
- '''
31
-
32
- def __init__(self, in_channels: int, out_channels: int, stride: int = 1,
33
- bottleneck: int = 4, nempty: bool = False):
34
- super().__init__()
35
- self.in_channels = in_channels
36
- self.out_channels = out_channels
37
- self.bottleneck = bottleneck
38
- self.stride = stride
39
- channelb = int(out_channels / bottleneck)
40
-
41
- if self.stride == 2:
42
- self.max_pool = OctreeMaxPool(nempty)
43
- self.conv1x1a = Conv1x1BnRelu(in_channels, channelb)
44
- self.conv3x3 = OctreeConvBnRelu(channelb, channelb, nempty=nempty)
45
- self.conv1x1b = Conv1x1Bn(channelb, out_channels)
46
- if self.in_channels != self.out_channels:
47
- self.conv1x1c = Conv1x1Bn(in_channels, out_channels)
48
- self.relu = torch.nn.ReLU(inplace=True)
49
-
50
- def forward(self, data: torch.Tensor, octree: Octree, depth: int):
51
- r''''''
52
-
53
- if self.stride == 2:
54
- data = self.max_pool(data, octree, depth)
55
- depth = depth - 1
56
- conv1 = self.conv1x1a(data)
57
- conv2 = self.conv3x3(conv1, octree, depth)
58
- conv3 = self.conv1x1b(conv2)
59
- if self.in_channels != self.out_channels:
60
- data = self.conv1x1c(data)
61
- out = self.relu(conv3 + data)
62
- return out
63
-
64
-
65
- class OctreeResBlock2(torch.nn.Module):
66
- r''' Basic Octree-based ResNet block. The block is composed of
67
- a series of :obj:`Conv3x3` and :obj:`Conv3x3`.
68
-
69
- Refer to :class:`OctreeResBlock` for the details of arguments.
70
- '''
71
-
72
- def __init__(self, in_channels, out_channels, stride=1, bottleneck=1,
73
- nempty=False):
74
- super().__init__()
75
- self.in_channels = in_channels
76
- self.out_channels = out_channels
77
- self.stride = stride
78
- channelb = int(out_channels / bottleneck)
79
-
80
- if self.stride == 2:
81
- self.maxpool = OctreeMaxPool(self.depth)
82
- self.conv3x3a = OctreeConvBnRelu(in_channels, channelb, nempty=nempty)
83
- self.conv3x3b = OctreeConvBn(channelb, out_channels, nempty=nempty)
84
- if self.in_channels != self.out_channels:
85
- self.conv1x1 = Conv1x1Bn(in_channels, out_channels)
86
- self.relu = torch.nn.ReLU(inplace=True)
87
-
88
- def forward(self, data: torch.Tensor, octree: Octree, depth: int):
89
- r''''''
90
-
91
- if self.stride == 2:
92
- data = self.maxpool(data, octree, depth)
93
- depth = depth - 1
94
- conv1 = self.conv3x3a(data, octree, depth)
95
- conv2 = self.conv3x3b(conv1, octree, depth)
96
- if self.in_channels != self.out_channels:
97
- data = self.conv1x1(data)
98
- out = self.relu(conv2 + data)
99
- return out
100
-
101
-
102
- class OctreeResBlockGn(torch.nn.Module):
103
-
104
- def __init__(self, in_channels: int, out_channels: int, stride: int = 1,
105
- bottleneck: int = 4, nempty: bool = False, group: int = 32):
106
- super().__init__()
107
- self.in_channels = in_channels
108
- self.out_channels = out_channels
109
- self.stride = stride
110
- channelb = int(out_channels / bottleneck)
111
-
112
- if self.stride == 2:
113
- self.maxpool = OctreeMaxPool(self.depth)
114
- self.conv3x3a = OctreeConvGnRelu(in_channels, channelb, group, nempty=nempty)
115
- self.conv3x3b = OctreeConvGn(channelb, out_channels, group, nempty=nempty)
116
- if self.in_channels != self.out_channels:
117
- self.conv1x1 = Conv1x1Gn(in_channels, out_channels, group)
118
- self.relu = torch.nn.ReLU(inplace=True)
119
-
120
- def forward(self, data: torch.Tensor, octree: Octree, depth: int):
121
- r''''''
122
-
123
- if self.stride == 2:
124
- data = self.maxpool(data, octree, depth)
125
- depth = depth - 1
126
- conv1 = self.conv3x3a(data, octree, depth)
127
- conv2 = self.conv3x3b(conv1, octree, depth)
128
- if self.in_channels != self.out_channels:
129
- data = self.conv1x1(data, octree, depth)
130
- out = self.relu(conv2 + data)
131
- return out
132
-
133
-
134
- class OctreeResBlocks(torch.nn.Module):
135
- r''' A sequence of :attr:`resblk_num` ResNet blocks.
136
- '''
137
-
138
- def __init__(self, in_channels, out_channels, resblk_num, bottleneck=4,
139
- nempty=False, resblk=OctreeResBlock, use_checkpoint=False):
140
- super().__init__()
141
- self.resblk_num = resblk_num
142
- self.use_checkpoint = use_checkpoint
143
- channels = [in_channels] + [out_channels] * resblk_num
144
-
145
- self.resblks = torch.nn.ModuleList([resblk(
146
- channels[i], channels[i+1], 1, bottleneck, nempty)
147
- for i in range(self.resblk_num)])
148
-
149
- def forward(self, data: torch.Tensor, octree: Octree, depth: int):
150
- r''''''
151
-
152
- for i in range(self.resblk_num):
153
- if self.use_checkpoint:
154
- data = torch.utils.checkpoint.checkpoint(
155
- self.resblks[i], data, octree, depth, use_reentrant=False)
156
- else:
157
- data = self.resblks[i](data, octree, depth)
158
- return data
1
+ # --------------------------------------------------------
2
+ # Octree-based Sparse Convolutional Neural Networks
3
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Peng-Shuai Wang
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+ import torch.utils.checkpoint
10
+
11
+ from ocnn.octree import Octree
12
+ from ocnn.nn import OctreeMaxPool
13
+ from ocnn.modules import (Conv1x1BnRelu, OctreeConvBnRelu, Conv1x1Bn,
14
+ OctreeConvBn, OctreeConvGnRelu, Conv1x1Gn,
15
+ OctreeConvGn,)
16
+
17
+
class OctreeResBlock(torch.nn.Module):
  r''' Bottleneck-style residual block operating on octree features.

  The main branch applies :obj:`Conv1x1`, :obj:`Conv3x3` and :obj:`Conv1x1` in
  sequence; its output is added to the (optionally projected) input and the sum
  is passed through a ReLU.

  Args:
    in_channels (int): Number of input channels.
    out_channels (int): Number of output channels.
    stride (int): The stride of the block (:obj:`1` or :obj:`2`). With
        :obj:`2` the input is max-pooled first and the rest of the block runs
        at :obj:`depth - 1`.
    bottleneck (int): The :obj:`Conv3x3` has
        :obj:`int(out_channels / bottleneck)` input and output channels.
    nempty (bool): If True, only performs the convolution on non-empty
        octree nodes.
  '''

  def __init__(self, in_channels: int, out_channels: int, stride: int = 1,
               bottleneck: int = 4, nempty: bool = False):
    super().__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.bottleneck = bottleneck
    self.stride = stride
    mid_channels = int(out_channels / bottleneck)

    # NOTE: sub-modules are created in a fixed order; keep it, since the order
    # determines parameter registration (and therefore seeded initialisation).
    if stride == 2:
      self.max_pool = OctreeMaxPool(nempty)
    self.conv1x1a = Conv1x1BnRelu(in_channels, mid_channels)
    self.conv3x3 = OctreeConvBnRelu(mid_channels, mid_channels, nempty=nempty)
    self.conv1x1b = Conv1x1Bn(mid_channels, out_channels)
    if in_channels != out_channels:
      # Projection for the skip path, only needed when the widths differ.
      self.conv1x1c = Conv1x1Bn(in_channels, out_channels)
    self.relu = torch.nn.ReLU(inplace=True)

  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
    r''' Runs the block on :attr:`data`, the features of :attr:`octree` at
    :attr:`depth`, and returns the activated sum of both branches.
    '''

    if self.stride == 2:
      data = self.max_pool(data, octree, depth)
      depth -= 1  # pooled features are handled one level up

    branch = self.conv1x1a(data)
    branch = self.conv3x3(branch, octree, depth)
    branch = self.conv1x1b(branch)

    shortcut = data
    if self.in_channels != self.out_channels:
      shortcut = self.conv1x1c(data)
    return self.relu(branch + shortcut)
63
+
64
+
class OctreeResBlock2(torch.nn.Module):
  r''' Basic Octree-based ResNet block. The block is composed of
  a series of :obj:`Conv3x3` and :obj:`Conv3x3`.

  The output of the two convolutions is added to the (optionally projected)
  input and passed through a ReLU. With :attr:`stride` equal to :obj:`2` the
  input is max-pooled first and the block runs at :obj:`depth - 1`.

  Refer to :class:`OctreeResBlock` for the details of arguments. Here the first
  :obj:`Conv3x3` outputs :obj:`int(out_channels / bottleneck)` channels.
  '''

  def __init__(self, in_channels, out_channels, stride=1, bottleneck=1,
               nempty=False):
    super().__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.stride = stride
    channelb = int(out_channels / bottleneck)

    if self.stride == 2:
      # OctreeMaxPool is configured with the nempty flag, exactly as in
      # OctreeResBlock. The previous argument, self.depth, is never assigned
      # on this module, so building a stride-2 block raised AttributeError.
      self.maxpool = OctreeMaxPool(nempty)
    self.conv3x3a = OctreeConvBnRelu(in_channels, channelb, nempty=nempty)
    self.conv3x3b = OctreeConvBn(channelb, out_channels, nempty=nempty)
    if self.in_channels != self.out_channels:
      # Projection for the skip path, only needed when the widths differ.
      self.conv1x1 = Conv1x1Bn(in_channels, out_channels)
    self.relu = torch.nn.ReLU(inplace=True)

  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
    r''' Runs the block on :attr:`data`, the features of :attr:`octree` at
    :attr:`depth`, and returns the activated sum of both branches.
    '''

    if self.stride == 2:
      data = self.maxpool(data, octree, depth)
      depth = depth - 1  # pooled features are handled one level up
    conv1 = self.conv3x3a(data, octree, depth)
    conv2 = self.conv3x3b(conv1, octree, depth)
    if self.in_channels != self.out_channels:
      data = self.conv1x1(data)
    out = self.relu(conv2 + data)
    return out
100
+
101
+
class OctreeResBlockGn(torch.nn.Module):
  r''' Basic Octree-based ResNet block built from group-normalized layers.

  The block applies :obj:`OctreeConvGnRelu` followed by :obj:`OctreeConvGn`,
  adds the (optionally projected) input and passes the sum through a ReLU.

  Args:
    in_channels (int): Number of input channels.
    out_channels (int): Number of output channels.
    stride (int): The stride of the block (:obj:`1` or :obj:`2`). With
        :obj:`2` the input is max-pooled first and the block runs at
        :obj:`depth - 1`.
    bottleneck (int): The first convolution outputs
        :obj:`int(out_channels / bottleneck)` channels.
    nempty (bool): Forwarded to the pooling and convolution layers.
    group (int): Forwarded to every group-normalized layer of the block
        (presumably the number of normalization groups -- confirm against
        :obj:`OctreeConvGn`).
  '''

  def __init__(self, in_channels: int, out_channels: int, stride: int = 1,
               bottleneck: int = 4, nempty: bool = False, group: int = 32):
    super().__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.stride = stride
    channelb = int(out_channels / bottleneck)

    if self.stride == 2:
      # OctreeMaxPool is configured with the nempty flag, exactly as in
      # OctreeResBlock. The previous argument, self.depth, is never assigned
      # on this module, so building a stride-2 block raised AttributeError.
      self.maxpool = OctreeMaxPool(nempty)
    self.conv3x3a = OctreeConvGnRelu(in_channels, channelb, group, nempty=nempty)
    self.conv3x3b = OctreeConvGn(channelb, out_channels, group, nempty=nempty)
    if self.in_channels != self.out_channels:
      # Projection for the skip path, only needed when the widths differ.
      self.conv1x1 = Conv1x1Gn(in_channels, out_channels, group)
    self.relu = torch.nn.ReLU(inplace=True)

  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
    r''' Runs the block on :attr:`data`, the features of :attr:`octree` at
    :attr:`depth`, and returns the activated sum of both branches.
    '''

    if self.stride == 2:
      data = self.maxpool(data, octree, depth)
      depth = depth - 1  # pooled features are handled one level up
    conv1 = self.conv3x3a(data, octree, depth)
    conv2 = self.conv3x3b(conv1, octree, depth)
    if self.in_channels != self.out_channels:
      # Unlike the batch-norm blocks, the projection here is also given the
      # octree and depth.
      data = self.conv1x1(data, octree, depth)
    out = self.relu(conv2 + data)
    return out
132
+
133
+
class OctreeResBlocks(torch.nn.Module):
  r''' A stack of :attr:`resblk_num` ResNet blocks applied one after another.

  Args:
    in_channels (int): Number of input channels of the first block.
    out_channels (int): Number of output channels of every block, and the
        number of input channels of every block after the first.
    resblk_num (int): Number of blocks in the stack.
    bottleneck (int): Forwarded to each block.
    nempty (bool): Forwarded to each block.
    resblk: The block class; each block is built as
        :obj:`resblk(in, out, 1, bottleneck, nempty)`.
    use_checkpoint (bool): If True, every block is run through
        :func:`torch.utils.checkpoint.checkpoint` with
        :obj:`use_reentrant=False`.
  '''

  def __init__(self, in_channels, out_channels, resblk_num, bottleneck=4,
               nempty=False, resblk=OctreeResBlock, use_checkpoint=False):
    super().__init__()
    self.resblk_num = resblk_num
    self.use_checkpoint = use_checkpoint

    blocks = []
    for idx in range(resblk_num):
      # Only the first block sees `in_channels`; the rest map out -> out.
      block_in = in_channels if idx == 0 else out_channels
      blocks.append(resblk(block_in, out_channels, 1, bottleneck, nempty))
    self.resblks = torch.nn.ModuleList(blocks)

  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
    r''' Feeds :attr:`data` through the blocks in order and returns the output
    of the last one.
    '''

    for idx in range(self.resblk_num):
      block = self.resblks[idx]
      if not self.use_checkpoint:
        data = block(data, octree, depth)
        continue
      data = torch.utils.checkpoint.checkpoint(
          block, data, octree, depth, use_reentrant=False)
    return data
ocnn/nn/__init__.py CHANGED
@@ -1,44 +1,45 @@
1
- # --------------------------------------------------------
2
- # Octree-based Sparse Convolutional Neural Networks
3
- # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
- # Licensed under The MIT License [see LICENSE for details]
5
- # Written by Peng-Shuai Wang
6
- # --------------------------------------------------------
7
-
8
- from .octree2vox import octree2voxel, Octree2Voxel
9
- from .octree2col import octree2col, col2octree
10
- from .octree_pad import octree_pad, octree_depad
11
- from .octree_interp import (octree_nearest_pts, octree_linear_pts,
12
- OctreeInterp, OctreeUpsample)
13
- from .octree_pool import (octree_max_pool, OctreeMaxPool,
14
- octree_max_unpool, OctreeMaxUnpool,
15
- octree_global_pool, OctreeGlobalPool,
16
- octree_avg_pool, OctreeAvgPool,)
17
- from .octree_conv import OctreeConv, OctreeDeconv
18
- from .octree_gconv import OctreeGroupConv
19
- from .octree_dwconv import OctreeDWConv
20
- from .octree_norm import (OctreeBatchNorm, OctreeGroupNorm,
21
- OctreeInstanceNorm, OctreeNorm)
22
- from .octree_drop import OctreeDropPath
23
- from .octree_align import search_value, octree_align
24
-
25
-
26
- __all__ = [
27
- 'octree2voxel',
28
- 'octree2col', 'col2octree',
29
- 'octree_pad', 'octree_depad',
30
- 'octree_nearest_pts', 'octree_linear_pts',
31
- 'octree_max_pool', 'octree_max_unpool',
32
- 'octree_global_pool', 'octree_avg_pool',
33
- 'Octree2Voxel',
34
- 'OctreeMaxPool', 'OctreeMaxUnpool',
35
- 'OctreeGlobalPool', 'OctreeAvgPool',
36
- 'OctreeConv', 'OctreeDeconv',
37
- 'OctreeGroupConv', 'OctreeDWConv',
38
- 'OctreeInterp', 'OctreeUpsample',
39
- 'OctreeInstanceNorm', 'OctreeBatchNorm', 'OctreeGroupNorm', 'OctreeNorm',
40
- 'OctreeDropPath',
41
- 'search_value', 'octree_align',
42
- ]
43
-
44
- classes = __all__
1
+ # --------------------------------------------------------
2
+ # Octree-based Sparse Convolutional Neural Networks
3
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Peng-Shuai Wang
6
+ # --------------------------------------------------------
7
+
8
+ from .octree2vox import octree2voxel, Octree2Voxel
9
+ from .octree2col import octree2col, col2octree
10
+ from .octree_pad import octree_pad, octree_depad
11
+ from .octree_interp import (octree_nearest_pts, octree_linear_pts,
12
+ OctreeInterp, OctreeUpsample)
13
+ from .octree_pool import (octree_max_pool, OctreeMaxPool,
14
+ octree_max_unpool, OctreeMaxUnpool,
15
+ octree_global_pool, OctreeGlobalPool,
16
+ octree_avg_pool, OctreeAvgPool,)
17
+ from .octree_conv import OctreeConv, OctreeDeconv
18
+ from .octree_gconv import OctreeGroupConv
19
+ from .octree_dwconv import OctreeDWConv
20
+ from .octree_norm import (OctreeBatchNorm, OctreeGroupNorm,
21
+ OctreeInstanceNorm, OctreeNorm)
22
+ from .octree_drop import OctreeDropPath
23
+ from .octree_align import search_value, octree_align
24
+ from .octree_conv_t import OctreeConvTriton, OctreeConvT, convert_conv_triton
25
+
# Names exported by this package. The order is kept as is because `classes`
# below exposes this very list object.
__all__ = [
    # functional operators
    'octree2voxel',
    'octree2col',
    'col2octree',
    'octree_pad',
    'octree_depad',
    'octree_nearest_pts',
    'octree_linear_pts',
    'octree_max_pool',
    'octree_max_unpool',
    'octree_global_pool',
    'octree_avg_pool',
    # modules
    'Octree2Voxel',
    'OctreeMaxPool',
    'OctreeMaxUnpool',
    'OctreeGlobalPool',
    'OctreeAvgPool',
    'OctreeConv',
    'OctreeDeconv',
    'OctreeGroupConv',
    'OctreeDWConv',
    'OctreeInterp',
    'OctreeUpsample',
    'OctreeInstanceNorm',
    'OctreeBatchNorm',
    'OctreeGroupNorm',
    'OctreeNorm',
    'OctreeDropPath',
    # alignment helpers
    'search_value',
    'octree_align',
    # names imported from .octree_conv_t
    'OctreeConvTriton',
    'OctreeConvT',
    'convert_conv_triton',
]

# Alias of the export list (same object, not a copy).
classes = __all__
ocnn/nn/kernels/__init__.py ADDED
@@ -0,0 +1,14 @@
# Re-export the implicit-GEMM convolution kernels (forward / backward, each
# with a split-K variant) from their sibling modules.
from .conv_fwd_implicit_gemm_splitk import conv_fwd_implicit_gemm_splitk
from .conv_bwd_implicit_gemm_splitk import conv_bwd_implicit_gemm_splitk
from .conv_bwd_implicit_gemm import conv_bwd_implicit_gemm
from .conv_fwd_implicit_gemm import conv_fwd_implicit_gemm

__all__ = [
    'conv_fwd_implicit_gemm_splitk',
    'conv_bwd_implicit_gemm_splitk',
    'conv_bwd_implicit_gemm',
    'conv_fwd_implicit_gemm',
]

# Import-time side effect: the autotune cache is loaded as soon as this package
# is imported, after the kernel modules above. `load_autotune_cache` is not
# listed in `__all__`.
# NOTE(review): presumably this restores previously saved autotuning results --
# confirm in autotuner.py.
from .autotuner import load_autotune_cache
load_autotune_cache()