ocnn 2.2.7.tar.gz → 2.3.0.tar.gz
- {ocnn-2.2.7/ocnn.egg-info → ocnn-2.3.0}/PKG-INFO +11 -6
- {ocnn-2.2.7 → ocnn-2.3.0}/README.md +10 -4
- {ocnn-2.2.7 → ocnn-2.3.0}/ocnn/__init__.py +1 -1
- {ocnn-2.2.7 → ocnn-2.3.0}/ocnn/models/resnet.py +2 -2
- {ocnn-2.2.7 → ocnn-2.3.0}/ocnn/nn/__init__.py +2 -1
- ocnn-2.3.0/ocnn/nn/kernels/__init__.py +14 -0
- ocnn-2.3.0/ocnn/nn/kernels/autotuner.py +416 -0
- ocnn-2.3.0/ocnn/nn/kernels/config.py +67 -0
- ocnn-2.3.0/ocnn/nn/kernels/conv_bwd_implicit_gemm.py +229 -0
- ocnn-2.3.0/ocnn/nn/kernels/conv_bwd_implicit_gemm_splitk.py +347 -0
- ocnn-2.3.0/ocnn/nn/kernels/conv_fwd_implicit_gemm.py +109 -0
- ocnn-2.3.0/ocnn/nn/kernels/conv_fwd_implicit_gemm_splitk.py +150 -0
- ocnn-2.3.0/ocnn/nn/kernels/utils.py +44 -0
- {ocnn-2.2.7 → ocnn-2.3.0}/ocnn/nn/octree_conv.py +2 -1
- ocnn-2.3.0/ocnn/nn/octree_conv_t.py +148 -0
- {ocnn-2.2.7 → ocnn-2.3.0}/ocnn/nn/octree_pad.py +4 -4
- {ocnn-2.2.7 → ocnn-2.3.0}/ocnn/octree/octree.py +218 -109
- {ocnn-2.2.7 → ocnn-2.3.0}/ocnn/octree/points.py +95 -34
- {ocnn-2.2.7 → ocnn-2.3.0/ocnn.egg-info}/PKG-INFO +11 -6
- {ocnn-2.2.7 → ocnn-2.3.0}/ocnn.egg-info/SOURCES.txt +9 -0
- {ocnn-2.2.7 → ocnn-2.3.0}/setup.py +2 -2
- {ocnn-2.2.7 → ocnn-2.3.0}/LICENSE +0 -0
- {ocnn-2.2.7 → ocnn-2.3.0}/MANIFEST.in +0 -0
- {ocnn-2.2.7 → ocnn-2.3.0}/ocnn/dataset.py +0 -0
- {ocnn-2.2.7 → ocnn-2.3.0}/ocnn/models/__init__.py +0 -0
- {ocnn-2.2.7 → ocnn-2.3.0}/ocnn/models/autoencoder.py +0 -0
- {ocnn-2.2.7 → ocnn-2.3.0}/ocnn/models/hrnet.py +0 -0
- {ocnn-2.2.7 → ocnn-2.3.0}/ocnn/models/image2shape.py +0 -0
- {ocnn-2.2.7 → ocnn-2.3.0}/ocnn/models/lenet.py +0 -0
- {ocnn-2.2.7 → ocnn-2.3.0}/ocnn/models/ounet.py +0 -0
- {ocnn-2.2.7 → ocnn-2.3.0}/ocnn/models/segnet.py +0 -0
- {ocnn-2.2.7 → ocnn-2.3.0}/ocnn/models/unet.py +0 -0
- {ocnn-2.2.7 → ocnn-2.3.0}/ocnn/modules/__init__.py +0 -0
- {ocnn-2.2.7 → ocnn-2.3.0}/ocnn/modules/modules.py +0 -0
- {ocnn-2.2.7 → ocnn-2.3.0}/ocnn/modules/resblocks.py +0 -0
- {ocnn-2.2.7 → ocnn-2.3.0}/ocnn/nn/octree2col.py +0 -0
- {ocnn-2.2.7 → ocnn-2.3.0}/ocnn/nn/octree2vox.py +0 -0
- {ocnn-2.2.7 → ocnn-2.3.0}/ocnn/nn/octree_align.py +0 -0
- {ocnn-2.2.7 → ocnn-2.3.0}/ocnn/nn/octree_drop.py +0 -0
- {ocnn-2.2.7 → ocnn-2.3.0}/ocnn/nn/octree_dwconv.py +0 -0
- {ocnn-2.2.7 → ocnn-2.3.0}/ocnn/nn/octree_gconv.py +0 -0
- {ocnn-2.2.7 → ocnn-2.3.0}/ocnn/nn/octree_interp.py +0 -0
- {ocnn-2.2.7 → ocnn-2.3.0}/ocnn/nn/octree_norm.py +0 -0
- {ocnn-2.2.7 → ocnn-2.3.0}/ocnn/nn/octree_pool.py +0 -0
- {ocnn-2.2.7 → ocnn-2.3.0}/ocnn/octree/__init__.py +0 -0
- {ocnn-2.2.7 → ocnn-2.3.0}/ocnn/octree/shuffled_key.py +0 -0
- {ocnn-2.2.7 → ocnn-2.3.0}/ocnn/utils.py +0 -0
- {ocnn-2.2.7 → ocnn-2.3.0}/ocnn.egg-info/dependency_links.txt +0 -0
- {ocnn-2.2.7 → ocnn-2.3.0}/ocnn.egg-info/not-zip-safe +0 -0
- {ocnn-2.2.7 → ocnn-2.3.0}/ocnn.egg-info/requires.txt +0 -0
- {ocnn-2.2.7 → ocnn-2.3.0}/ocnn.egg-info/top_level.txt +0 -0
- {ocnn-2.2.7 → ocnn-2.3.0}/setup.cfg +0 -0
{ocnn-2.2.7/ocnn.egg-info → ocnn-2.3.0}/PKG-INFO

@@ -1,13 +1,12 @@
 Metadata-Version: 2.4
 Name: ocnn
-Version: 2.2.7
+Version: 2.3.0
 Summary: Octree-based Sparse Convolutional Neural Networks
 Home-page: https://github.com/octree-nn/ocnn-pytorch
 Author: Peng-Shuai Wang
 Author-email: wangps@hotmail.com
 License: MIT
 Classifier: Programming Language :: Python :: 3
-Classifier: License :: OSI Approved :: MIT License
 Classifier: Operating System :: OS Independent
 Requires-Python: >=3.6
 Description-Content-Type: text/markdown
@@ -84,6 +83,13 @@ octrees to perform convolution operations. Of course, it also supports other 3D
 data formats, such as meshes and volumetric grids, which can be converted into
 octrees to leverage the library's capabilities.
 
+## Updates
+
+- **2026.02.02**: Release `v2.3.0`, incorporating Triton to accelerate
+  octree-based sparse convolution in the upcoming release. OctreeConv is even
+  **2.5 times faster than the latest spconv**!
+- **2025.12.18**: Release `v2.2.8`, improving neighbor search efficiency.
+
 
 ## Key benefits of ocnn-pytorch
 
@@ -93,10 +99,9 @@ octrees to leverage the library's capabilities.
 configure the compiling environment.
 
 - **Efficiency**. The ocnn-pytorch is very efficient compared with other sparse
-  convolution frameworks.
-
-
-  takes 30 hours.
+  convolution frameworks. It is **even 2.5 times faster than the latest spconv
+  implementation**! Check the benchmark [code](test/benchmark_conv.py) and
+  [results](test/benchmark/results.png) for details. ✨
 
 ## Citation
 
{ocnn-2.2.7 → ocnn-2.3.0}/README.md

@@ -54,6 +54,13 @@ octrees to perform convolution operations. Of course, it also supports other 3D
 data formats, such as meshes and volumetric grids, which can be converted into
 octrees to leverage the library's capabilities.
 
+## Updates
+
+- **2026.02.02**: Release `v2.3.0`, incorporating Triton to accelerate
+  octree-based sparse convolution in the upcoming release. OctreeConv is even
+  **2.5 times faster than the latest spconv**!
+- **2025.12.18**: Release `v2.2.8`, improving neighbor search efficiency.
+
 
 ## Key benefits of ocnn-pytorch
 
@@ -63,10 +70,9 @@ octrees to leverage the library's capabilities.
 configure the compiling environment.
 
 - **Efficiency**. The ocnn-pytorch is very efficient compared with other sparse
-  convolution frameworks.
-
-
-  takes 30 hours.
+  convolution frameworks. It is **even 2.5 times faster than the latest spconv
+  implementation**! Check the benchmark [code](test/benchmark_conv.py) and
+  [results](test/benchmark/results.png) for details. ✨
 
 ## Citation
 
{ocnn-2.2.7 → ocnn-2.3.0}/ocnn/models/resnet.py

@@ -15,7 +15,7 @@ class ResNet(torch.nn.Module):
   '''
 
   def __init__(self, in_channels: int, out_channels: int, resblock_num: int,
-               stages: int, nempty: bool = False):
+               stages: int, nempty: bool = False, dropout: float = 0.5):
     super().__init__()
     self.in_channels = in_channels
     self.out_channels = out_channels
@@ -36,7 +36,7 @@ class ResNet(torch.nn.Module):
     # self.header = torch.nn.Linear(channels[-1], out_channels, bias=True)
     self.header = torch.nn.Sequential(
         ocnn.modules.FcBnRelu(channels[-1], 512),
-        torch.nn.Dropout(p=0.5),
+        torch.nn.Dropout(p=dropout),
        torch.nn.Linear(512, out_channels))
 
   def forward(self, data: torch.Tensor, octree: Octree, depth: int):
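The change above makes the classification-header dropout probability a constructor argument instead of a hard-coded value. A minimal usage sketch; all parameter values below are illustrative placeholders, not taken from the diff:

```python
# Minimal sketch: the new `dropout` argument (added in 2.3.0) feeds the
# torch.nn.Dropout layer in the classification header shown above.
import ocnn

model = ocnn.models.ResNet(in_channels=4, out_channels=40, resblock_num=1,
                           stages=3, nempty=False, dropout=0.2)
```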
{ocnn-2.2.7 → ocnn-2.3.0}/ocnn/nn/__init__.py

@@ -21,7 +21,7 @@ from .octree_norm import (OctreeBatchNorm, OctreeGroupNorm,
                           OctreeInstanceNorm, OctreeNorm)
 from .octree_drop import OctreeDropPath
 from .octree_align import search_value, octree_align
-
+from .octree_conv_t import OctreeConvTriton, OctreeConvT, convert_conv_triton
 
 __all__ = [
     'octree2voxel',
@@ -39,6 +39,7 @@ __all__ = [
     'OctreeInstanceNorm', 'OctreeBatchNorm', 'OctreeGroupNorm', 'OctreeNorm',
     'OctreeDropPath',
     'search_value', 'octree_align',
+    'OctreeConvTriton', 'OctreeConvT', 'convert_conv_triton',
 ]
 
 classes = __all__
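The module behind these exports, octree_conv_t.py (+148 lines), is not displayed in this diff, so the exact signatures of `OctreeConvTriton`, `OctreeConvT`, and `convert_conv_triton` are unknown here. The sketch below is a hypothetical usage based purely on the helper's name, assuming `convert_conv_triton` takes a module and returns it with Triton-backed convolutions swapped in:

```python
# Hypothetical usage sketch -- NOT a confirmed API. The diff only shows that
# ocnn.nn exports these names; the call signature below is an assumption.
import ocnn

model = ocnn.models.ResNet(in_channels=4, out_channels=40, resblock_num=1,
                           stages=3).cuda()
model = ocnn.nn.convert_conv_triton(model)  # assumed: swap in Triton convs
```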
ocnn-2.3.0/ocnn/nn/kernels/__init__.py (new file)

@@ -0,0 +1,14 @@
+from .conv_fwd_implicit_gemm_splitk import conv_fwd_implicit_gemm_splitk
+from .conv_bwd_implicit_gemm_splitk import conv_bwd_implicit_gemm_splitk
+from .conv_bwd_implicit_gemm import conv_bwd_implicit_gemm
+from .conv_fwd_implicit_gemm import conv_fwd_implicit_gemm
+
+__all__ = [
+    'conv_fwd_implicit_gemm_splitk',
+    'conv_bwd_implicit_gemm_splitk',
+    'conv_bwd_implicit_gemm',
+    'conv_fwd_implicit_gemm',
+]
+
+from .autotuner import load_autotune_cache
+load_autotune_cache()
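Note that the last two lines run at import time: importing `ocnn.nn.kernels` loads any persisted autotune cache. A minimal sketch of steering that behaviour through the environment variables read in autotuner.py (shown next):

```python
# Minimal sketch: these variables are read when autotuner.py is first
# imported, so they must be set before the first `import ocnn`.
import os

os.environ['OCNN_AUTOTUNE_CACHE_PATH'] = '/tmp/ocnn_autotune.json'  # cache file location
os.environ['OCNN_AUTOSAVE_AUTOTUNE'] = '0'   # do not write tuning results back to disk
os.environ['TRITON_PRINT_AUTOTUNING'] = '1'  # report the best config after each tuning run

import ocnn  # pulls in ocnn.nn.kernels, whose __init__ (above) loads the cache
```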
ocnn-2.3.0/ocnn/nn/kernels/autotuner.py (new file)

@@ -0,0 +1,416 @@
+import builtins
+import os
+import json
+import importlib
+import pkgutil
+import torch
+import triton
+import time
+import inspect
+from filelock import FileLock
+from typing import Dict, Mapping
+
+VERBOSE_AUTOTUNE = os.getenv('TRITON_PRINT_AUTOTUNING', '0') == '1'
+AUTOSAVE_AUTOTUNE_CACHE = os.getenv('OCNN_AUTOSAVE_AUTOTUNE', '1') == '1'
+AUTOTUNE_CACHE_PATH = os.getenv('OCNN_AUTOTUNE_CACHE_PATH',
+    os.path.expanduser('~/.ocnnconvt/autotune_cache.json'))
+
+
+class TritonPersistentCacheAutotuner(triton.runtime.Autotuner):
+  def __init__(
+      self,
+      fn,
+      arg_names,
+      configs,
+      key,
+      reset_to_zero,
+      restore_value,
+      pre_hook=None,
+      post_hook=None,
+      prune_configs_by: Dict = None,
+      warmup=None,
+      rep=None,
+      use_cuda_graph=False,
+      do_bench=None,
+  ):
+    super().__init__(
+        fn,
+        arg_names,
+        configs,
+        key,
+        reset_to_zero,
+        restore_value,
+        pre_hook,
+        post_hook,
+        prune_configs_by,
+        warmup,
+        rep,
+        use_cuda_graph,
+        do_bench,
+    )
+
+  def run(self, *args, **kwargs):
+    self.nargs = dict(zip(self.arg_names, args))
+    used_cached_result = True
+    if len(self.configs) > 1:
+      all_args = {**self.nargs, **kwargs}
+      _args = {k: v for (k, v) in all_args.items() if k in self.arg_names}
+      key = [_args[key] for key in self.keys if key in _args]
+      for _, arg in _args.items():
+        if hasattr(arg, "dtype"):
+          key.append(str(arg.dtype))
+      key = str(tuple(key))
+      if key not in self.cache:
+        # prune configs
+        used_cached_result = False
+        pruned_configs = self.prune_configs(kwargs)
+        bench_start = time.time()
+        timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
+        bench_end = time.time()
+        self.bench_time = bench_end - bench_start
+        self.cache[key] = builtins.min(timings, key=timings.get)
+        full_nargs = {**self.nargs, **kwargs, **self.cache[key].all_kwargs()}
+        self.pre_hook(full_nargs, reset_only=True)
+        self.configs_timings = timings
+      config = self.cache[key]
+    else:
+      config = self.configs[0]
+    self.best_config = config
+    if os.getenv("TRITON_PRINT_AUTOTUNING", None) == "1" and not used_cached_result:
+      print(f"Triton autotuning for function {self.base_fn.__name__} finished after "
+            f"{self.bench_time:.2f}s; best config selected: {self.best_config};")
+    if AUTOSAVE_AUTOTUNE_CACHE and not used_cached_result:
+      save_autotune_cache()
+    if config.pre_hook is not None:
+      full_nargs = {**self.nargs, **kwargs, **config.all_kwargs()}
+      config.pre_hook(full_nargs)
+    ret = self.fn.run(
+        *args,
+        **kwargs,
+        **config.all_kwargs(),
+    )
+    self.nargs = None
+    return ret
+
+  def prune_configs(self, kwargs):
+    pruned_configs = self.configs
+    if self.early_config_prune:
+      pruned_configs = self.early_config_prune(self.configs, self.nargs, **kwargs)
+    if self.perf_model:
+      top_k = self.configs_top_k
+      if isinstance(top_k, float) and top_k <= 1.0:
+        top_k = int(len(self.configs) * top_k)
+      if len(pruned_configs) > top_k:
+        est_timing = {
+            config: self.perf_model(
+                **self.nargs,
+                **kwargs,
+                **config.all_kwargs(),
+            )
+            for config in pruned_configs
+        }
+        pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
+    return pruned_configs
+
+  def warmup(self, *args, **kwargs):
+    self.nargs = dict(zip(self.arg_names, args))
+    ret = []
+    for config in self.prune_configs(kwargs):
+      ret.append(self.fn.warmup(
+          *args,
+          **kwargs,
+          **config.all_kwargs(),
+      ))
+    self.nargs = None
+    return ret
+
+
+def triton_autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_value=None, pre_hook=None, post_hook=None,
+                    warmup=None, rep=None, use_cuda_graph=False, do_bench=None):
+  """
+  Decorator for auto-tuning a :code:`triton.jit`'d function.
+
+  .. highlight:: python
+  .. code-block:: python
+
+      @triton_autotune(configs=[
+          triton.Config(kwargs={'BLOCK_SIZE': 128}, num_warps=4),
+          triton.Config(kwargs={'BLOCK_SIZE': 1024}, num_warps=8),
+        ],
+        key=['x_size']  # the two above configs will be evaluated anytime
+                        # the value of x_size changes
+      )
+      @triton.jit
+      def kernel(x_ptr, x_size, **META):
+          BLOCK_SIZE = META['BLOCK_SIZE']
+  :note: When all the configurations are evaluated, the kernel will run multiple times.
+         This means that whatever value the kernel updates will be updated multiple times.
+         To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
+         resets the value of the provided tensor to `zero` before running any configuration.
+
+  If the environment variable :code:`TRITON_PRINT_AUTOTUNING` is set to
+  :code:`"1"`, Triton will print a message to stdout after autotuning each
+  kernel, including the time spent autotuning and the best configuration.
+
+  :param configs: a list of :code:`triton.Config` objects
+  :type configs: list[triton.Config]
+  :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
+  :type key: list[str]
+  :param prune_configs_by: a dict of functions that are used to prune configs, fields:
+      'perf_model': performance model used to predicate running time with different configs, returns running time
+      'top_k': number of configs to bench
+      'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It takes configs:List[Config] as its input, and returns pruned configs.
+  :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
+  :type reset_to_zero: list[str]
+  :param restore_value: a list of argument names whose value will be restored after evaluating any configs.
+  :type restore_value: list[str]
+  :param pre_hook: a function that will be called before the kernel is called.
+      This overrides the default pre_hook used for 'reset_to_zero' and 'restore_value'.
+      'kwargs': a dict of all arguments passed to the kernel.
+      'reset_only': a boolean indicating whether the pre_hook is called to reset the values only, without a corresponding post_hook.
+  :type pre_hook: lambda args, reset_only
+  :param post_hook: a function that will be called after the kernel is called.
+      This overrides the default post_hook used for 'restore_value'.
+      'kwargs': a dict of all arguments passed to the kernel.
+      'exception': the exception raised by the kernel in case of a compilation or runtime error.
+  :type post_hook: lambda args, exception
+  :param warmup: warmup time (in ms) to pass to benchmarking (deprecated).
+  :type warmup: int
+  :param rep: repetition time (in ms) to pass to benchmarking (deprecated).
+  :type rep: int
+  :param do_bench: a benchmark function to measure the time of each run.
+  :type do_bench: lambda fn, quantiles
+  """
+
+  def decorator(fn):
+    return TritonPersistentCacheAutotuner(
+        fn, fn.arg_names, configs, key, reset_to_zero, restore_value, pre_hook=pre_hook,
+        post_hook=post_hook, prune_configs_by=prune_configs_by, warmup=warmup, rep=rep,
+        use_cuda_graph=use_cuda_graph
+    )
+
+  return decorator
+
+
+class PersistentCacheAutoTuner:
+  def __init__(
+      self,
+      kernel,
+      configs=None,
+      key=None,
+      config_fn=None,
+      key_fn=None,
+      warmup=3,
+      runs=10,
+      verbose=False,
+  ):
+    """
+    AutoTuner is a wrapper class for a kernel that automatically tunes the kernel parameters to achieve the best performance.
+
+    Args:
+      kernel: A callable object that takes in input arguments and returns the output.
+      configs: A list of Config objects that define the possible kernel parameters and their values.
+      key: A list of argument names that retune the kernel on change.
+      config_fn: A function that takes in the input arguments and returns configs to be used for autotuning.
+      key_fn: A function that takes in the input arguments and returns the key used to cache the tuning results.
+        Once the key changes, the autotuning will be rerun.
+      warmup: The number of warmup runs to discard before measuring the execution time.
+      runs: The number of runs to measure the execution time.
+      verbose: Whether to print the autotuning results.
+    """
+    assert config_fn or configs, "Either configs or config_fn must be provided"
+    assert key_fn or key, "Either key or key_fn must be provided"
+    self.kernel = kernel
+    self.configs = configs
+    self.key = key
+    self.config_fn = config_fn
+    self.key_fn = key_fn
+    self.warmup = warmup
+    self.runs = runs
+    self.verbose = verbose
+    self.kernel_arg_names = inspect.getfullargspec(kernel).args
+    self.cache = {}
+
+  def _args_to_kwargs(self, args, kwargs):
+    # Convert args to kwargs
+    arg_names = self.kernel_arg_names
+    arg_dict = dict(zip(arg_names, args))
+    arg_dict.update(kwargs)
+    return arg_dict
+
+  def __call__(self, *args, **kwargs):
+    arg_dict = self._args_to_kwargs(args, kwargs)
+
+    # Determine key
+    key = self.key_fn(*args, **kwargs) if self.key_fn else tuple(arg_dict[k] for k in self.key)
+    key = str(key)
+
+    # If key changes, rerun autotune
+    used_cached_result = True
+    if key not in self.cache:
+      used_cached_result = False
+      if self.verbose:
+        print(f"Running autotuning for {self.kernel.__name__} with key {key}")
+      configs = self.configs if self.configs else self.config_fn(*args, **kwargs)
+      if self.verbose:
+        print(f"Configs: {configs}")
+      best_config = self._benchmark(args, kwargs, configs)
+      if self.verbose:
+        print(f"Best config for {self.kernel.__name__} with key {key}: {best_config}")
+      self.cache[key] = best_config
+    else:
+      if self.verbose:
+        print('Using cached config for {} with key {}'.format(self.kernel.__name__, key))
+        print('Config: {}'.format(self.cache[key]))
+
+    if AUTOSAVE_AUTOTUNE_CACHE and not used_cached_result:
+      save_autotune_cache()
+
+    # Run the kernel with the best config
+    return self.kernel(*args, **kwargs, **self.cache[key])
+
+  def _benchmark(self, args, kwargs, configs):
+    best_time = float('inf')
+    best_config = None
+
+    if len(configs) == 1:
+      best_config = configs[0]
+    else:
+      for config in configs:
+        # Run the kernel and measure execution time
+        for _ in range(self.warmup):
+          self.kernel(*args, **kwargs, **config)
+        torch.cuda.synchronize()
+        start = time.time()
+        for _ in range(self.runs):
+          self.kernel(*args, **kwargs, **config)
+        torch.cuda.synchronize()
+        elapsed = (time.time() - start) / self.runs
+        if self.verbose:
+          print(f"Config {config}: {elapsed} seconds")
+        # Update the best config if the execution time is better
+        if elapsed < best_time:
+          best_time = elapsed
+          best_config = config
+
+    return best_config
+
+
+
+
+def autotune(
+    configs=None,
+    key=None,
+    config_fn=None,
+    key_fn=None,
+    warmup=3,
+    runs=10,
+    verbose=VERBOSE_AUTOTUNE
+):
+  def decorator(kernel):
+    return PersistentCacheAutoTuner(kernel, configs, key, config_fn, key_fn, warmup, runs, verbose)
+  return decorator
+
+
+def walk_package(package_name, fn):
+  try:
+    package = importlib.import_module(package_name)
+  except ModuleNotFoundError:
+    print(f"Package {package_name} not found.")
+    return
+
+  if not hasattr(package, '__path__'):
+    print(f"{package_name} is not a package.")
+    return
+
+  for _, module_name, is_pkg in pkgutil.iter_modules(package.__path__):
+    full_module_name = f"{package_name}.{module_name}"
+    if is_pkg:
+      walk_package(full_module_name, fn)
+    else:
+      fn(full_module_name)
+
+
+def get_autotune_cache():
+  cache = {}
+  device_name = torch.cuda.get_device_name()
+  if device_name not in cache:
+    cache[device_name] = {}
+
+  def save_cache(full_module_name):
+    module = importlib.import_module(full_module_name)
+    for attr_name, attr in module.__dict__.items():
+      cache_key = f"{full_module_name}.{attr_name}"
+      if isinstance(attr, PersistentCacheAutoTuner):
+        cache[device_name][cache_key] = attr.cache
+      elif isinstance(attr, TritonPersistentCacheAutotuner):
+        cache[device_name][cache_key] = {k: v.__dict__ for k, v in attr.cache.items()}
+
+  walk_package('ocnn.nn.kernels', save_cache)
+
+  return cache
+
+
+def save_autotune_cache(path=None):
+  path = path or AUTOTUNE_CACHE_PATH
+  lock_path = path + ".lock"
+
+  with FileLock(lock_path):
+    if os.path.exists(path):
+      with open(path, 'r') as f:
+        cache = json.load(f)
+    else:
+      cache = {}
+    # Merge existing cache with new cache
+    cache.update(get_autotune_cache())
+
+    tmp_path = path + ".tmp"
+    with open(tmp_path, 'w') as f:
+      json.dump(cache, f, indent=4)
+      f.flush()
+      os.fsync(f.fileno())
+    os.replace(tmp_path, path)
+
+
+def load_autotune_cache(path_or_cache=None):
+  cache = None
+
+  # Preserve path-based loading, but allow callers to provide a preloaded cache object.
+  if path_or_cache is None or isinstance(path_or_cache, (str, os.PathLike)):
+    path = path_or_cache or AUTOTUNE_CACHE_PATH
+    lock_path = path + ".lock"
+
+    if not os.path.exists(path):
+      return
+
+    with FileLock(lock_path):
+      with open(path, 'r') as f:
+        cache = json.load(f)
+  elif isinstance(path_or_cache, Mapping):
+    cache = path_or_cache
+  else:
+    raise TypeError("load_autotune_cache expects a path or a mapping")
+
+  if cache is None:
+    return
+
+  device_name = torch.cuda.get_device_name()
+  if device_name not in cache and "*" not in cache:
+    return
+  if "*" in cache and device_name not in cache:
+    device_name = "*"
+
+  def load_cache(full_module_name):
+    module = importlib.import_module(full_module_name)
+    for attr_name, attr in module.__dict__.items():
+      cache_key = f"{full_module_name}.{attr_name}"
+      if isinstance(attr, PersistentCacheAutoTuner):
+        if cache_key in cache[device_name]:
+          attr.cache = cache[device_name][cache_key]
+      elif isinstance(attr, TritonPersistentCacheAutotuner):
+        if cache_key in cache[device_name]:
+          for k, v in cache[device_name][cache_key].items():
+            attr.cache[k] = triton.runtime.Config(None)
+            attr.cache[k].__dict__.update(v)
+
+  walk_package('ocnn.nn.kernels', load_cache)
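For orientation, here is a minimal sketch of the `autotune` decorator defined above. Configs are plain dicts of keyword arguments (the tuner invokes `kernel(*args, **kwargs, **config)`), `key` lists the argument names whose values select a cache entry, and a CUDA device is required because the benchmark loop calls `torch.cuda.synchronize()`. The kernel body itself is illustrative, not from the package:

```python
import torch
from ocnn.nn.kernels.autotuner import autotune

@autotune(
    configs=[{'chunk': 1}, {'chunk': 4}],  # candidate keyword-argument settings
    key=['n'],                             # retune whenever the value of `n` changes
    warmup=3, runs=10,
)
def row_sums(x, n, chunk=1):
  # Illustrative workload: reduce rows in `chunk`-sized blocks.
  out = torch.empty(n, device=x.device)
  for i in range(0, n, chunk):
    out[i:i + chunk] = x[i:i + chunk].sum(dim=1)
  return out

x = torch.ones(64, 128, device='cuda')
y = row_sums(x, 64)  # first call benchmarks both configs, then caches the winner
```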
ocnn-2.3.0/ocnn/nn/kernels/config.py (new file)

@@ -0,0 +1,67 @@
+import os
+import triton
+from .utils import get_autotune_config
+
+allow_tf32 = os.getenv('OCNN_ALLOW_TF32', '1') == '1'
+
+autotune_config = get_autotune_config(
+    platform={
+        'cuda': [
+            triton.Config({'B1': 128, 'B2': 256, 'BK': 64}, num_stages=3, num_warps=8),
+            triton.Config({'B1': 64, 'B2': 256, 'BK': 32}, num_stages=4, num_warps=4),
+            triton.Config({'B1': 128, 'B2': 128, 'BK': 32}, num_stages=4, num_warps=4),
+            triton.Config({'B1': 128, 'B2': 64, 'BK': 32}, num_stages=4, num_warps=4),
+            triton.Config({'B1': 64, 'B2': 128, 'BK': 32}, num_stages=4, num_warps=4),
+            triton.Config({'B1': 128, 'B2': 32, 'BK': 32}, num_stages=4, num_warps=4),
+            triton.Config({'B1': 64, 'B2': 32, 'BK': 32}, num_stages=5, num_warps=2),
+            triton.Config({'B1': 32, 'B2': 64, 'BK': 32}, num_stages=5, num_warps=2),
+        ],
+        'hip': [
+            triton.Config({'B1': 128, 'B2': 256, 'BK': 16, 'waves_per_eu': 2}, num_warps=4, num_stages=2),
+            triton.Config({'B1': 256, 'B2': 256, 'BK': 16, 'waves_per_eu': 2}, num_warps=8, num_stages=2),
+            triton.Config({'B1': 128, 'B2': 128, 'BK': 32, 'waves_per_eu': 2}, num_warps=8, num_stages=2),
+            triton.Config({'B1': 64, 'B2': 128, 'BK': 32, 'waves_per_eu': 3}, num_warps=4, num_stages=2),
+            triton.Config({'B1': 64, 'B2': 64, 'BK': 32, 'waves_per_eu': 8}, num_warps=4, num_stages=2),
+        ]
+    },
+    device={
+        'A100': [
+            triton.Config({'B1': 256, 'B2': 128, 'BK': 64}, num_stages=4, num_warps=8),
+            triton.Config({'B1': 256, 'B2': 128, 'BK': 32}, num_stages=4, num_warps=8),
+            triton.Config({'B1': 128, 'B2': 256, 'BK': 64}, num_stages=4, num_warps=8),
+            triton.Config({'B1': 128, 'B2': 256, 'BK': 32}, num_stages=4, num_warps=8),
+            triton.Config({'B1': 256, 'B2': 64, 'BK': 64}, num_stages=4, num_warps=4),
+            triton.Config({'B1': 256, 'B2': 64, 'BK': 32}, num_stages=4, num_warps=4),
+            triton.Config({'B1': 64, 'B2': 256, 'BK': 64}, num_stages=4, num_warps=4),
+            triton.Config({'B1': 64, 'B2': 256, 'BK': 32}, num_stages=4, num_warps=4),
+            triton.Config({'B1': 128, 'B2': 128, 'BK': 64}, num_stages=4, num_warps=4),
+            triton.Config({'B1': 128, 'B2': 128, 'BK': 32}, num_stages=4, num_warps=4),
+            triton.Config({'B1': 128, 'B2': 64, 'BK': 32}, num_stages=4, num_warps=4),
+            triton.Config({'B1': 128, 'B2': 64, 'BK': 32}, num_stages=4, num_warps=2),
+            triton.Config({'B1': 64, 'B2': 128, 'BK': 32}, num_stages=4, num_warps=4),
+            triton.Config({'B1': 64, 'B2': 128, 'BK': 32}, num_stages=4, num_warps=2),
+            triton.Config({'B1': 64, 'B2': 64, 'BK': 64}, num_stages=4, num_warps=2),
+            triton.Config({'B1': 64, 'B2': 64, 'BK': 32}, num_stages=4, num_warps=2),
+        ],
+        'MI300X': [
+            triton.Config({'B1': 256, 'B2': 256, 'BK': 64, 'waves_per_eu': 2}, num_stages=2, num_warps=16),
+            triton.Config({'B1': 256, 'B2': 256, 'BK': 64, 'waves_per_eu': 2, 'kpack': 2, 'matrix_instr_nonkdim': 16}, num_stages=2, num_warps=8),
+            triton.Config({'B1': 256, 'B2': 128, 'BK': 64, 'waves_per_eu': 2}, num_stages=2, num_warps=16),
+            triton.Config({'B1': 256, 'B2': 128, 'BK': 64, 'waves_per_eu': 2, 'kpack': 2, 'matrix_instr_nonkdim': 16}, num_stages=2, num_warps=8),
+            triton.Config({'B1': 128, 'B2': 256, 'BK': 64, 'waves_per_eu': 2}, num_stages=2, num_warps=16),
+            triton.Config({'B1': 128, 'B2': 256, 'BK': 64, 'waves_per_eu': 2, 'kpack': 2, 'matrix_instr_nonkdim': 16}, num_stages=2, num_warps=8),
+            triton.Config({'B1': 256, 'B2': 64, 'BK': 32, 'waves_per_eu': 2}, num_stages=2, num_warps=8),
+            triton.Config({'B1': 256, 'B2': 64, 'BK': 32, 'waves_per_eu': 2, 'kpack': 2, 'matrix_instr_nonkdim': 16}, num_stages=2, num_warps=8),
+            triton.Config({'B1': 64, 'B2': 256, 'BK': 32, 'waves_per_eu': 2}, num_stages=2, num_warps=8),
+            triton.Config({'B1': 64, 'B2': 256, 'BK': 32, 'waves_per_eu': 2, 'kpack': 2, 'matrix_instr_nonkdim': 16}, num_stages=2, num_warps=8),
+            triton.Config({'B1': 128, 'B2': 128, 'BK': 64, 'waves_per_eu': 2}, num_stages=2, num_warps=8),
+            triton.Config({'B1': 128, 'B2': 128, 'BK': 64, 'waves_per_eu': 2, 'kpack': 2, 'matrix_instr_nonkdim': 16}, num_stages=2, num_warps=8),
+            triton.Config({'B1': 128, 'B2': 64, 'BK': 64, 'waves_per_eu': 2}, num_stages=2, num_warps=4),
+            triton.Config({'B1': 128, 'B2': 64, 'BK': 64, 'waves_per_eu': 2, 'kpack': 2, 'matrix_instr_nonkdim': 16}, num_stages=2, num_warps=4),
+            triton.Config({'B1': 64, 'B2': 128, 'BK': 64, 'waves_per_eu': 2}, num_stages=2, num_warps=4),
+            triton.Config({'B1': 64, 'B2': 128, 'BK': 64, 'waves_per_eu': 2, 'kpack': 2, 'matrix_instr_nonkdim': 16}, num_stages=2, num_warps=4),
+            triton.Config({'B1': 64, 'B2': 64, 'BK': 64, 'waves_per_eu': 2}, num_stages=2, num_warps=2),
+            triton.Config({'B1': 64, 'B2': 64, 'BK': 64, 'waves_per_eu': 2, 'kpack': 2, 'matrix_instr_nonkdim': 16}, num_stages=2, num_warps=2),
+        ],
+    }
+)
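utils.py (+44 lines) is not displayed in this diff, so the behaviour of `get_autotune_config` is not confirmed here. The call site above suggests it prefers a device-specific config list (e.g. 'A100', 'MI300X') and otherwise falls back to the platform defaults; a hypothetical sketch of that selection logic:

```python
# Hypothetical sketch only -- ocnn's real get_autotune_config lives in
# ocnn/nn/kernels/utils.py, which this diff does not display. The function
# below merely illustrates the structure implied by the call site above.
import torch

def select_autotune_config(platform, device):
  name = torch.cuda.get_device_name()  # e.g. 'NVIDIA A100-SXM4-80GB'
  for dev_key, configs in device.items():
    if dev_key in name:                # assumed: substring match on GPU name
      return configs
  backend = 'hip' if torch.version.hip else 'cuda'  # ROCm builds set torch.version.hip
  return platform[backend]
```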