ocnn 2.2.8__py3-none-any.whl → 2.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ocnn/__init__.py +24 -24
- ocnn/dataset.py +160 -160
- ocnn/models/__init__.py +29 -29
- ocnn/models/autoencoder.py +155 -155
- ocnn/models/hrnet.py +192 -192
- ocnn/models/image2shape.py +128 -128
- ocnn/models/lenet.py +46 -46
- ocnn/models/ounet.py +94 -94
- ocnn/models/resnet.py +53 -53
- ocnn/models/segnet.py +72 -72
- ocnn/models/unet.py +105 -105
- ocnn/modules/__init__.py +26 -26
- ocnn/modules/modules.py +303 -303
- ocnn/modules/resblocks.py +158 -158
- ocnn/nn/__init__.py +45 -44
- ocnn/nn/kernels/__init__.py +14 -0
- ocnn/nn/kernels/autotuner.py +416 -0
- ocnn/nn/kernels/config.py +67 -0
- ocnn/nn/kernels/conv_bwd_implicit_gemm.py +229 -0
- ocnn/nn/kernels/conv_bwd_implicit_gemm_splitk.py +347 -0
- ocnn/nn/kernels/conv_fwd_implicit_gemm.py +109 -0
- ocnn/nn/kernels/conv_fwd_implicit_gemm_splitk.py +150 -0
- ocnn/nn/kernels/utils.py +44 -0
- ocnn/nn/octree2col.py +53 -53
- ocnn/nn/octree2vox.py +50 -50
- ocnn/nn/octree_align.py +46 -46
- ocnn/nn/octree_conv.py +430 -429
- ocnn/nn/octree_conv_t.py +148 -0
- ocnn/nn/octree_drop.py +55 -55
- ocnn/nn/octree_dwconv.py +222 -222
- ocnn/nn/octree_gconv.py +79 -79
- ocnn/nn/octree_interp.py +196 -196
- ocnn/nn/octree_norm.py +126 -126
- ocnn/nn/octree_pad.py +39 -39
- ocnn/nn/octree_pool.py +200 -200
- ocnn/octree/__init__.py +22 -22
- ocnn/octree/octree.py +770 -770
- ocnn/octree/points.py +384 -323
- ocnn/octree/shuffled_key.py +115 -115
- ocnn/utils.py +205 -205
- {ocnn-2.2.8.dist-info → ocnn-2.3.0.dist-info}/METADATA +117 -111
- ocnn-2.3.0.dist-info/RECORD +45 -0
- {ocnn-2.2.8.dist-info → ocnn-2.3.0.dist-info}/WHEEL +1 -1
- {ocnn-2.2.8.dist-info → ocnn-2.3.0.dist-info}/licenses/LICENSE +21 -21
- ocnn-2.2.8.dist-info/RECORD +0 -36
- {ocnn-2.2.8.dist-info → ocnn-2.3.0.dist-info}/top_level.txt +0 -0
ocnn/modules/resblocks.py
CHANGED
|
@@ -1,158 +1,158 @@
|
|
|
1
|
-
# --------------------------------------------------------
|
|
2
|
-
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
-
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
-
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
-
# Written by Peng-Shuai Wang
|
|
6
|
-
# --------------------------------------------------------
|
|
7
|
-
|
|
8
|
-
import torch
|
|
9
|
-
import torch.utils.checkpoint
|
|
10
|
-
|
|
11
|
-
from ocnn.octree import Octree
|
|
12
|
-
from ocnn.nn import OctreeMaxPool
|
|
13
|
-
from ocnn.modules import (Conv1x1BnRelu, OctreeConvBnRelu, Conv1x1Bn,
|
|
14
|
-
OctreeConvBn, OctreeConvGnRelu, Conv1x1Gn,
|
|
15
|
-
OctreeConvGn,)
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
class OctreeResBlock(torch.nn.Module):
  r''' Octree-based ResNet block in a bottleneck style. The block is composed of
  a series of :obj:`Conv1x1`, :obj:`Conv3x3`, and :obj:`Conv1x1`; the result
  is added to a (possibly projected) shortcut and passed through a ReLU.

  Args:
    in_channels (int): Number of input channels.
    out_channels (int): Number of output channels.
    stride (int): The stride of the block (:obj:`1` or :obj:`2`). With
        :obj:`2`, the input is max-pooled first and the convolutions run at
        ``depth - 1``; any other value behaves like :obj:`1`.
    bottleneck (int): The input and output channels of the :obj:`Conv3x3`
        are equal to the *output* channels divided by :attr:`bottleneck`.
    nempty (bool): If True, only performs the convolution on non-empty
        octree nodes.
  '''

  def __init__(self, in_channels: int, out_channels: int, stride: int = 1,
               bottleneck: int = 4, nempty: bool = False):
    super().__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.bottleneck = bottleneck
    self.stride = stride
    # Width of the inner 3x3 conv, derived from the OUTPUT channels.
    channelb = int(out_channels / bottleneck)

    if self.stride == 2:
      self.max_pool = OctreeMaxPool(nempty)
    self.conv1x1a = Conv1x1BnRelu(in_channels, channelb)
    self.conv3x3 = OctreeConvBnRelu(channelb, channelb, nempty=nempty)
    self.conv1x1b = Conv1x1Bn(channelb, out_channels)
    if self.in_channels != self.out_channels:
      # 1x1 projection so the shortcut matches the output channels.
      self.conv1x1c = Conv1x1Bn(in_channels, out_channels)
    self.relu = torch.nn.ReLU(inplace=True)

  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
    r''' Applies the block.

    Args:
      data (torch.Tensor): Input features of the octree nodes at
          :attr:`depth`.
      octree (Octree): The octree the features are defined on.
      depth (int): The octree depth of :attr:`data`.

    Returns:
      torch.Tensor: Features with :attr:`out_channels` channels, located at
      ``depth - 1`` when :attr:`stride` is :obj:`2` and at :attr:`depth`
      otherwise.
    '''

    if self.stride == 2:
      # Downsample first; everything below operates one level coarser.
      data = self.max_pool(data, octree, depth)
      depth = depth - 1
    conv1 = self.conv1x1a(data)
    conv2 = self.conv3x3(conv1, octree, depth)
    conv3 = self.conv1x1b(conv2)
    if self.in_channels != self.out_channels:
      data = self.conv1x1c(data)
    out = self.relu(conv3 + data)
    return out
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
class OctreeResBlock2(torch.nn.Module):
  r''' Basic Octree-based ResNet block. The block is composed of
  a series of :obj:`Conv3x3` and :obj:`Conv3x3`; the result is added to a
  (possibly projected) shortcut and passed through a ReLU.

  Args:
    in_channels (int): Number of input channels.
    out_channels (int): Number of output channels.
    stride (int): The stride of the block (:obj:`1` or :obj:`2`). With
        :obj:`2`, the input is max-pooled first and the convolutions run at
        ``depth - 1``.
    bottleneck (int): The output channels of the first :obj:`Conv3x3` are
        :attr:`out_channels` divided by :attr:`bottleneck` (default
        :obj:`1`, i.e. no reduction).
    nempty (bool): If True, only performs the convolution on non-empty
        octree nodes.
  '''

  def __init__(self, in_channels, out_channels, stride=1, bottleneck=1,
               nempty=False):
    super().__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.stride = stride
    channelb = int(out_channels / bottleneck)

    if self.stride == 2:
      # Configure the pooling with `nempty`, as OctreeResBlock does. This
      # module has no `depth` attribute, so `self.depth` would raise.
      self.maxpool = OctreeMaxPool(nempty)
    self.conv3x3a = OctreeConvBnRelu(in_channels, channelb, nempty=nempty)
    self.conv3x3b = OctreeConvBn(channelb, out_channels, nempty=nempty)
    if self.in_channels != self.out_channels:
      # 1x1 projection so the shortcut matches the output channels.
      self.conv1x1 = Conv1x1Bn(in_channels, out_channels)
    self.relu = torch.nn.ReLU(inplace=True)

  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
    r''' Applies the block.

    Args:
      data (torch.Tensor): Input features of the octree nodes at
          :attr:`depth`.
      octree (Octree): The octree the features are defined on.
      depth (int): The octree depth of :attr:`data`.

    Returns:
      torch.Tensor: Features with :attr:`out_channels` channels, located at
      ``depth - 1`` when :attr:`stride` is :obj:`2` and at :attr:`depth`
      otherwise.
    '''

    if self.stride == 2:
      # Downsample first; everything below operates one level coarser.
      data = self.maxpool(data, octree, depth)
      depth = depth - 1
    conv1 = self.conv3x3a(data, octree, depth)
    conv2 = self.conv3x3b(conv1, octree, depth)
    if self.in_channels != self.out_channels:
      data = self.conv1x1(data)
    out = self.relu(conv2 + data)
    return out
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
class OctreeResBlockGn(torch.nn.Module):
  r''' Octree-based ResNet block built from two :obj:`Conv3x3` layers whose
  normalization layers are the GroupNorm variants (:obj:`OctreeConvGnRelu`,
  :obj:`OctreeConvGn`, :obj:`Conv1x1Gn`). The result is added to a (possibly
  projected) shortcut and passed through a ReLU.

  Args:
    in_channels (int): Number of input channels.
    out_channels (int): Number of output channels.
    stride (int): The stride of the block (:obj:`1` or :obj:`2`). With
        :obj:`2`, the input is max-pooled first and the convolutions run at
        ``depth - 1``.
    bottleneck (int): The output channels of the first :obj:`Conv3x3` are
        :attr:`out_channels` divided by :attr:`bottleneck`.
    nempty (bool): If True, only performs the convolution on non-empty
        octree nodes.
    group (int): Forwarded to every GroupNorm-based layer (presumably the
        number of normalization groups; see those layers' docs).
  '''

  def __init__(self, in_channels: int, out_channels: int, stride: int = 1,
               bottleneck: int = 4, nempty: bool = False, group: int = 32):
    super().__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.stride = stride
    channelb = int(out_channels / bottleneck)

    if self.stride == 2:
      # Configure the pooling with `nempty`, as OctreeResBlock does. This
      # module has no `depth` attribute, so `self.depth` would raise.
      self.maxpool = OctreeMaxPool(nempty)
    self.conv3x3a = OctreeConvGnRelu(in_channels, channelb, group, nempty=nempty)
    self.conv3x3b = OctreeConvGn(channelb, out_channels, group, nempty=nempty)
    if self.in_channels != self.out_channels:
      # 1x1 projection so the shortcut matches the output channels.
      self.conv1x1 = Conv1x1Gn(in_channels, out_channels, group)
    self.relu = torch.nn.ReLU(inplace=True)

  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
    r''' Applies the block.

    Args:
      data (torch.Tensor): Input features of the octree nodes at
          :attr:`depth`.
      octree (Octree): The octree the features are defined on.
      depth (int): The octree depth of :attr:`data`.

    Returns:
      torch.Tensor: Features with :attr:`out_channels` channels, located at
      ``depth - 1`` when :attr:`stride` is :obj:`2` and at :attr:`depth`
      otherwise.
    '''

    if self.stride == 2:
      # Downsample first; everything below operates one level coarser.
      data = self.maxpool(data, octree, depth)
      depth = depth - 1
    conv1 = self.conv3x3a(data, octree, depth)
    conv2 = self.conv3x3b(conv1, octree, depth)
    if self.in_channels != self.out_channels:
      # Unlike Conv1x1Bn, Conv1x1Gn is called with the octree and depth.
      data = self.conv1x1(data, octree, depth)
    out = self.relu(conv2 + data)
    return out
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
class OctreeResBlocks(torch.nn.Module):
  r''' A sequence of :attr:`resblk_num` ResNet blocks.

  The first block maps :attr:`in_channels` to :attr:`out_channels`; every
  following block keeps :attr:`out_channels`. All blocks use stride 1.

  Args:
    in_channels (int): Number of input channels of the first block.
    out_channels (int): Number of output channels of every block.
    resblk_num (int): Number of stacked blocks.
    bottleneck (int): Forwarded to each block.
    nempty (bool): Forwarded to each block.
    resblk (type): The block class, instantiated as
        ``resblk(in_ch, out_ch, 1, bottleneck, nempty)``.
    use_checkpoint (bool): If True, each block is run through
        :func:`torch.utils.checkpoint.checkpoint` with
        ``use_reentrant=False``.
  '''

  def __init__(self, in_channels, out_channels, resblk_num, bottleneck=4,
               nempty=False, resblk=OctreeResBlock, use_checkpoint=False):
    super().__init__()
    self.resblk_num = resblk_num
    self.use_checkpoint = use_checkpoint

    blocks = []
    for idx in range(resblk_num):
      # Only the first block changes the channel count.
      blk_in = in_channels if idx == 0 else out_channels
      blocks.append(resblk(blk_in, out_channels, 1, bottleneck, nempty))
    self.resblks = torch.nn.ModuleList(blocks)

  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
    r''' Feeds :attr:`data` through every block in order at the same
    :attr:`depth` and returns the output of the last one.
    '''

    for idx in range(self.resblk_num):
      block = self.resblks[idx]
      if not self.use_checkpoint:
        data = block(data, octree, depth)
        continue
      data = torch.utils.checkpoint.checkpoint(
          block, data, octree, depth, use_reentrant=False)
    return data
|
|
1
|
+
# --------------------------------------------------------
|
|
2
|
+
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
+
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
+
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
+
# Written by Peng-Shuai Wang
|
|
6
|
+
# --------------------------------------------------------
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
import torch.utils.checkpoint
|
|
10
|
+
|
|
11
|
+
from ocnn.octree import Octree
|
|
12
|
+
from ocnn.nn import OctreeMaxPool
|
|
13
|
+
from ocnn.modules import (Conv1x1BnRelu, OctreeConvBnRelu, Conv1x1Bn,
|
|
14
|
+
OctreeConvBn, OctreeConvGnRelu, Conv1x1Gn,
|
|
15
|
+
OctreeConvGn,)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class OctreeResBlock(torch.nn.Module):
  r''' Octree-based ResNet block in a bottleneck style. The block is composed of
  a series of :obj:`Conv1x1`, :obj:`Conv3x3`, and :obj:`Conv1x1`; the result
  is added to a (possibly projected) shortcut and passed through a ReLU.

  Args:
    in_channels (int): Number of input channels.
    out_channels (int): Number of output channels.
    stride (int): The stride of the block (:obj:`1` or :obj:`2`). With
        :obj:`2`, the input is max-pooled first and the convolutions run at
        ``depth - 1``; any other value behaves like :obj:`1`.
    bottleneck (int): The input and output channels of the :obj:`Conv3x3`
        are equal to the *output* channels divided by :attr:`bottleneck`.
    nempty (bool): If True, only performs the convolution on non-empty
        octree nodes.
  '''

  def __init__(self, in_channels: int, out_channels: int, stride: int = 1,
               bottleneck: int = 4, nempty: bool = False):
    super().__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.bottleneck = bottleneck
    self.stride = stride
    # Width of the inner 3x3 conv, derived from the OUTPUT channels.
    channelb = int(out_channels / bottleneck)

    if self.stride == 2:
      self.max_pool = OctreeMaxPool(nempty)
    self.conv1x1a = Conv1x1BnRelu(in_channels, channelb)
    self.conv3x3 = OctreeConvBnRelu(channelb, channelb, nempty=nempty)
    self.conv1x1b = Conv1x1Bn(channelb, out_channels)
    if self.in_channels != self.out_channels:
      # 1x1 projection so the shortcut matches the output channels.
      self.conv1x1c = Conv1x1Bn(in_channels, out_channels)
    self.relu = torch.nn.ReLU(inplace=True)

  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
    r''' Applies the block.

    Args:
      data (torch.Tensor): Input features of the octree nodes at
          :attr:`depth`.
      octree (Octree): The octree the features are defined on.
      depth (int): The octree depth of :attr:`data`.

    Returns:
      torch.Tensor: Features with :attr:`out_channels` channels, located at
      ``depth - 1`` when :attr:`stride` is :obj:`2` and at :attr:`depth`
      otherwise.
    '''

    if self.stride == 2:
      # Downsample first; everything below operates one level coarser.
      data = self.max_pool(data, octree, depth)
      depth = depth - 1
    conv1 = self.conv1x1a(data)
    conv2 = self.conv3x3(conv1, octree, depth)
    conv3 = self.conv1x1b(conv2)
    if self.in_channels != self.out_channels:
      data = self.conv1x1c(data)
    out = self.relu(conv3 + data)
    return out
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class OctreeResBlock2(torch.nn.Module):
  r''' Basic Octree-based ResNet block. The block is composed of
  a series of :obj:`Conv3x3` and :obj:`Conv3x3`; the result is added to a
  (possibly projected) shortcut and passed through a ReLU.

  Args:
    in_channels (int): Number of input channels.
    out_channels (int): Number of output channels.
    stride (int): The stride of the block (:obj:`1` or :obj:`2`). With
        :obj:`2`, the input is max-pooled first and the convolutions run at
        ``depth - 1``.
    bottleneck (int): The output channels of the first :obj:`Conv3x3` are
        :attr:`out_channels` divided by :attr:`bottleneck` (default
        :obj:`1`, i.e. no reduction).
    nempty (bool): If True, only performs the convolution on non-empty
        octree nodes.
  '''

  def __init__(self, in_channels, out_channels, stride=1, bottleneck=1,
               nempty=False):
    super().__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.stride = stride
    channelb = int(out_channels / bottleneck)

    if self.stride == 2:
      # Configure the pooling with `nempty`, as OctreeResBlock does. This
      # module has no `depth` attribute, so `self.depth` would raise.
      self.maxpool = OctreeMaxPool(nempty)
    self.conv3x3a = OctreeConvBnRelu(in_channels, channelb, nempty=nempty)
    self.conv3x3b = OctreeConvBn(channelb, out_channels, nempty=nempty)
    if self.in_channels != self.out_channels:
      # 1x1 projection so the shortcut matches the output channels.
      self.conv1x1 = Conv1x1Bn(in_channels, out_channels)
    self.relu = torch.nn.ReLU(inplace=True)

  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
    r''' Applies the block.

    Args:
      data (torch.Tensor): Input features of the octree nodes at
          :attr:`depth`.
      octree (Octree): The octree the features are defined on.
      depth (int): The octree depth of :attr:`data`.

    Returns:
      torch.Tensor: Features with :attr:`out_channels` channels, located at
      ``depth - 1`` when :attr:`stride` is :obj:`2` and at :attr:`depth`
      otherwise.
    '''

    if self.stride == 2:
      # Downsample first; everything below operates one level coarser.
      data = self.maxpool(data, octree, depth)
      depth = depth - 1
    conv1 = self.conv3x3a(data, octree, depth)
    conv2 = self.conv3x3b(conv1, octree, depth)
    if self.in_channels != self.out_channels:
      data = self.conv1x1(data)
    out = self.relu(conv2 + data)
    return out
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
class OctreeResBlockGn(torch.nn.Module):
  r''' Octree-based ResNet block built from two :obj:`Conv3x3` layers whose
  normalization layers are the GroupNorm variants (:obj:`OctreeConvGnRelu`,
  :obj:`OctreeConvGn`, :obj:`Conv1x1Gn`). The result is added to a (possibly
  projected) shortcut and passed through a ReLU.

  Args:
    in_channels (int): Number of input channels.
    out_channels (int): Number of output channels.
    stride (int): The stride of the block (:obj:`1` or :obj:`2`). With
        :obj:`2`, the input is max-pooled first and the convolutions run at
        ``depth - 1``.
    bottleneck (int): The output channels of the first :obj:`Conv3x3` are
        :attr:`out_channels` divided by :attr:`bottleneck`.
    nempty (bool): If True, only performs the convolution on non-empty
        octree nodes.
    group (int): Forwarded to every GroupNorm-based layer (presumably the
        number of normalization groups; see those layers' docs).
  '''

  def __init__(self, in_channels: int, out_channels: int, stride: int = 1,
               bottleneck: int = 4, nempty: bool = False, group: int = 32):
    super().__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.stride = stride
    channelb = int(out_channels / bottleneck)

    if self.stride == 2:
      # Configure the pooling with `nempty`, as OctreeResBlock does. This
      # module has no `depth` attribute, so `self.depth` would raise.
      self.maxpool = OctreeMaxPool(nempty)
    self.conv3x3a = OctreeConvGnRelu(in_channels, channelb, group, nempty=nempty)
    self.conv3x3b = OctreeConvGn(channelb, out_channels, group, nempty=nempty)
    if self.in_channels != self.out_channels:
      # 1x1 projection so the shortcut matches the output channels.
      self.conv1x1 = Conv1x1Gn(in_channels, out_channels, group)
    self.relu = torch.nn.ReLU(inplace=True)

  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
    r''' Applies the block.

    Args:
      data (torch.Tensor): Input features of the octree nodes at
          :attr:`depth`.
      octree (Octree): The octree the features are defined on.
      depth (int): The octree depth of :attr:`data`.

    Returns:
      torch.Tensor: Features with :attr:`out_channels` channels, located at
      ``depth - 1`` when :attr:`stride` is :obj:`2` and at :attr:`depth`
      otherwise.
    '''

    if self.stride == 2:
      # Downsample first; everything below operates one level coarser.
      data = self.maxpool(data, octree, depth)
      depth = depth - 1
    conv1 = self.conv3x3a(data, octree, depth)
    conv2 = self.conv3x3b(conv1, octree, depth)
    if self.in_channels != self.out_channels:
      # Unlike Conv1x1Bn, Conv1x1Gn is called with the octree and depth.
      data = self.conv1x1(data, octree, depth)
    out = self.relu(conv2 + data)
    return out
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
class OctreeResBlocks(torch.nn.Module):
  r''' A sequence of :attr:`resblk_num` ResNet blocks.

  The first block maps :attr:`in_channels` to :attr:`out_channels`; every
  following block keeps :attr:`out_channels`. All blocks use stride 1.

  Args:
    in_channels (int): Number of input channels of the first block.
    out_channels (int): Number of output channels of every block.
    resblk_num (int): Number of stacked blocks.
    bottleneck (int): Forwarded to each block.
    nempty (bool): Forwarded to each block.
    resblk (type): The block class, instantiated as
        ``resblk(in_ch, out_ch, 1, bottleneck, nempty)``.
    use_checkpoint (bool): If True, each block is run through
        :func:`torch.utils.checkpoint.checkpoint` with
        ``use_reentrant=False``.
  '''

  def __init__(self, in_channels, out_channels, resblk_num, bottleneck=4,
               nempty=False, resblk=OctreeResBlock, use_checkpoint=False):
    super().__init__()
    self.resblk_num = resblk_num
    self.use_checkpoint = use_checkpoint

    blocks = []
    for idx in range(resblk_num):
      # Only the first block changes the channel count.
      blk_in = in_channels if idx == 0 else out_channels
      blocks.append(resblk(blk_in, out_channels, 1, bottleneck, nempty))
    self.resblks = torch.nn.ModuleList(blocks)

  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
    r''' Feeds :attr:`data` through every block in order at the same
    :attr:`depth` and returns the output of the last one.
    '''

    for idx in range(self.resblk_num):
      block = self.resblks[idx]
      if not self.use_checkpoint:
        data = block(data, octree, depth)
        continue
      data = torch.utils.checkpoint.checkpoint(
          block, data, octree, depth, use_reentrant=False)
    return data
|
ocnn/nn/__init__.py
CHANGED
|
@@ -1,44 +1,45 @@
|
|
|
1
|
-
# --------------------------------------------------------
|
|
2
|
-
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
-
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
-
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
-
# Written by Peng-Shuai Wang
|
|
6
|
-
# --------------------------------------------------------
|
|
7
|
-
|
|
8
|
-
from .octree2vox import octree2voxel, Octree2Voxel
|
|
9
|
-
from .octree2col import octree2col, col2octree
|
|
10
|
-
from .octree_pad import octree_pad, octree_depad
|
|
11
|
-
from .octree_interp import (octree_nearest_pts, octree_linear_pts,
|
|
12
|
-
OctreeInterp, OctreeUpsample)
|
|
13
|
-
from .octree_pool import (octree_max_pool, OctreeMaxPool,
|
|
14
|
-
octree_max_unpool, OctreeMaxUnpool,
|
|
15
|
-
octree_global_pool, OctreeGlobalPool,
|
|
16
|
-
octree_avg_pool, OctreeAvgPool,)
|
|
17
|
-
from .octree_conv import OctreeConv, OctreeDeconv
|
|
18
|
-
from .octree_gconv import OctreeGroupConv
|
|
19
|
-
from .octree_dwconv import OctreeDWConv
|
|
20
|
-
from .octree_norm import (OctreeBatchNorm, OctreeGroupNorm,
|
|
21
|
-
OctreeInstanceNorm, OctreeNorm)
|
|
22
|
-
from .octree_drop import OctreeDropPath
|
|
23
|
-
from .octree_align import search_value, octree_align
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
__all__ = [
|
|
27
|
-
'octree2voxel',
|
|
28
|
-
'octree2col', 'col2octree',
|
|
29
|
-
'octree_pad', 'octree_depad',
|
|
30
|
-
'octree_nearest_pts', 'octree_linear_pts',
|
|
31
|
-
'octree_max_pool', 'octree_max_unpool',
|
|
32
|
-
'octree_global_pool', 'octree_avg_pool',
|
|
33
|
-
'Octree2Voxel',
|
|
34
|
-
'OctreeMaxPool', 'OctreeMaxUnpool',
|
|
35
|
-
'OctreeGlobalPool', 'OctreeAvgPool',
|
|
36
|
-
'OctreeConv', 'OctreeDeconv',
|
|
37
|
-
'OctreeGroupConv', 'OctreeDWConv',
|
|
38
|
-
'OctreeInterp', 'OctreeUpsample',
|
|
39
|
-
'OctreeInstanceNorm', 'OctreeBatchNorm', 'OctreeGroupNorm', 'OctreeNorm',
|
|
40
|
-
'OctreeDropPath',
|
|
41
|
-
'search_value', 'octree_align',
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
1
|
+
# --------------------------------------------------------
|
|
2
|
+
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
+
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
+
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
+
# Written by Peng-Shuai Wang
|
|
6
|
+
# --------------------------------------------------------
|
|
7
|
+
|
|
8
|
+
from .octree2vox import octree2voxel, Octree2Voxel
|
|
9
|
+
from .octree2col import octree2col, col2octree
|
|
10
|
+
from .octree_pad import octree_pad, octree_depad
|
|
11
|
+
from .octree_interp import (octree_nearest_pts, octree_linear_pts,
|
|
12
|
+
OctreeInterp, OctreeUpsample)
|
|
13
|
+
from .octree_pool import (octree_max_pool, OctreeMaxPool,
|
|
14
|
+
octree_max_unpool, OctreeMaxUnpool,
|
|
15
|
+
octree_global_pool, OctreeGlobalPool,
|
|
16
|
+
octree_avg_pool, OctreeAvgPool,)
|
|
17
|
+
from .octree_conv import OctreeConv, OctreeDeconv
|
|
18
|
+
from .octree_gconv import OctreeGroupConv
|
|
19
|
+
from .octree_dwconv import OctreeDWConv
|
|
20
|
+
from .octree_norm import (OctreeBatchNorm, OctreeGroupNorm,
|
|
21
|
+
OctreeInstanceNorm, OctreeNorm)
|
|
22
|
+
from .octree_drop import OctreeDropPath
|
|
23
|
+
from .octree_align import search_value, octree_align
|
|
24
|
+
from .octree_conv_t import OctreeConvTriton, OctreeConvT, convert_conv_triton
|
|
25
|
+
|
|
26
|
+
# Public names of this package, listed in a fixed order.
__all__ = [
    # lower-case (functional) names
    'octree2voxel',
    'octree2col',
    'col2octree',
    'octree_pad',
    'octree_depad',
    'octree_nearest_pts',
    'octree_linear_pts',
    'octree_max_pool',
    'octree_max_unpool',
    'octree_global_pool',
    'octree_avg_pool',
    # class names
    'Octree2Voxel',
    'OctreeMaxPool',
    'OctreeMaxUnpool',
    'OctreeGlobalPool',
    'OctreeAvgPool',
    'OctreeConv',
    'OctreeDeconv',
    'OctreeGroupConv',
    'OctreeDWConv',
    'OctreeInterp',
    'OctreeUpsample',
    'OctreeInstanceNorm',
    'OctreeBatchNorm',
    'OctreeGroupNorm',
    'OctreeNorm',
    'OctreeDropPath',
    # from .octree_align
    'search_value',
    'octree_align',
    # from .octree_conv_t
    'OctreeConvTriton',
    'OctreeConvT',
    'convert_conv_triton',
]

# `classes` is an alias bound to the very same list object as `__all__`.
classes = __all__
|
|
ocnn/nn/kernels/__init__.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
1
|
+
from .conv_fwd_implicit_gemm_splitk import conv_fwd_implicit_gemm_splitk
|
|
2
|
+
from .conv_bwd_implicit_gemm_splitk import conv_bwd_implicit_gemm_splitk
|
|
3
|
+
from .conv_bwd_implicit_gemm import conv_bwd_implicit_gemm
|
|
4
|
+
from .conv_fwd_implicit_gemm import conv_fwd_implicit_gemm
|
|
5
|
+
|
|
6
|
+
from .autotuner import load_autotune_cache

# Kernel entry points re-exported by this package.
__all__ = [
    'conv_fwd_implicit_gemm_splitk',
    'conv_bwd_implicit_gemm_splitk',
    'conv_bwd_implicit_gemm',
    'conv_fwd_implicit_gemm',
]

# Import-time side effect: runs once when this package is first imported.
# What gets loaded is defined in `.autotuner` (not visible here) -- presumably
# previously tuned kernel configurations; confirm there before relying on it.
load_autotune_cache()
|