ocnn-2.2.7-py3-none-any.whl → ocnn-2.3.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ocnn/__init__.py CHANGED
@@ -12,7 +12,7 @@ from . import models
 from . import dataset
 from . import utils
 
-__version__ = '2.2.7'
+__version__ = '2.3.0'
 
 __all__ = [
     'octree',
ocnn/models/resnet.py CHANGED
@@ -15,7 +15,7 @@ class ResNet(torch.nn.Module):
   '''
 
   def __init__(self, in_channels: int, out_channels: int, resblock_num: int,
-               stages: int, nempty: bool = False):
+               stages: int, nempty: bool = False, dropout: float = 0.5):
     super().__init__()
     self.in_channels = in_channels
     self.out_channels = out_channels
@@ -36,7 +36,7 @@ class ResNet(torch.nn.Module):
     # self.header = torch.nn.Linear(channels[-1], out_channels, bias=True)
     self.header = torch.nn.Sequential(
         ocnn.modules.FcBnRelu(channels[-1], 512),
-        torch.nn.Dropout(p=0.5),
+        torch.nn.Dropout(p=dropout),
         torch.nn.Linear(512, out_channels))
 
   def forward(self, data: torch.Tensor, octree: Octree, depth: int):
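
For reference, a minimal sketch of how the new `dropout` argument is passed through the updated constructor; the channel, class, and stage counts below are placeholders, and `ResNet` is assumed to be reachable as `ocnn.models.ResNet`:

import ocnn

# Placeholder values; only the new `dropout` keyword (default 0.5) is the point here.
model = ocnn.models.ResNet(in_channels=3, out_channels=40, resblock_num=1,
                           stages=3, nempty=False, dropout=0.3)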
ocnn/nn/__init__.py CHANGED
@@ -21,7 +21,7 @@ from .octree_norm import (OctreeBatchNorm, OctreeGroupNorm,
                           OctreeInstanceNorm, OctreeNorm)
 from .octree_drop import OctreeDropPath
 from .octree_align import search_value, octree_align
-
+from .octree_conv_t import OctreeConvTriton, OctreeConvT, convert_conv_triton
 
 __all__ = [
     'octree2voxel',
@@ -39,6 +39,7 @@ __all__ = [
     'OctreeInstanceNorm', 'OctreeBatchNorm', 'OctreeGroupNorm', 'OctreeNorm',
     'OctreeDropPath',
     'search_value', 'octree_align',
+    'OctreeConvTriton', 'OctreeConvT', 'convert_conv_triton',
 ]
 
 classes = __all__
@@ -0,0 +1,14 @@
+from .conv_fwd_implicit_gemm_splitk import conv_fwd_implicit_gemm_splitk
+from .conv_bwd_implicit_gemm_splitk import conv_bwd_implicit_gemm_splitk
+from .conv_bwd_implicit_gemm import conv_bwd_implicit_gemm
+from .conv_fwd_implicit_gemm import conv_fwd_implicit_gemm
+
+__all__ = [
+    'conv_fwd_implicit_gemm_splitk',
+    'conv_bwd_implicit_gemm_splitk',
+    'conv_bwd_implicit_gemm',
+    'conv_fwd_implicit_gemm',
+]
+
+from .autotuner import load_autotune_cache
+load_autotune_cache()
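
The new module above (most likely the package `__init__` for `ocnn.nn.kernels`, judging from the relative `from .autotuner import ...` and the `walk_package('ocnn.nn.kernels', ...)` calls in the autotuner below) loads the persisted autotune cache as a side effect of import. A small sketch of how a user could steer that behaviour with the environment variables defined in the autotuner module; the import path is an assumption, not confirmed by a file header in this diff:

import os

# Set these before the package is imported; they are read at module import time.
os.environ['OCNN_AUTOTUNE_CACHE_PATH'] = '/tmp/ocnn_autotune.json'  # custom cache location (hypothetical path)
os.environ['OCNN_AUTOSAVE_AUTOTUNE'] = '1'    # persist newly tuned configs ('1' is already the default)
os.environ['TRITON_PRINT_AUTOTUNING'] = '1'   # print the winning config after each tuning run

import ocnn.nn.kernels  # importing the package calls load_autotune_cache()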
@@ -0,0 +1,416 @@
+import builtins
+import os
+import json
+import importlib
+import pkgutil
+import torch
+import triton
+import time
+import inspect
+from filelock import FileLock
+from typing import Dict, Mapping
+
+VERBOSE_AUTOTUNE = os.getenv('TRITON_PRINT_AUTOTUNING', '0') == '1'
+AUTOSAVE_AUTOTUNE_CACHE = os.getenv('OCNN_AUTOSAVE_AUTOTUNE', '1') == '1'
+AUTOTUNE_CACHE_PATH = os.getenv('OCNN_AUTOTUNE_CACHE_PATH',
+    os.path.expanduser('~/.ocnnconvt/autotune_cache.json'))
+
+
+class TritonPersistentCacheAutotuner(triton.runtime.Autotuner):
+  def __init__(
+      self,
+      fn,
+      arg_names,
+      configs,
+      key,
+      reset_to_zero,
+      restore_value,
+      pre_hook=None,
+      post_hook=None,
+      prune_configs_by: Dict = None,
+      warmup=None,
+      rep=None,
+      use_cuda_graph=False,
+      do_bench=None,
+  ):
+    super().__init__(
+        fn,
+        arg_names,
+        configs,
+        key,
+        reset_to_zero,
+        restore_value,
+        pre_hook,
+        post_hook,
+        prune_configs_by,
+        warmup,
+        rep,
+        use_cuda_graph,
+        do_bench,
+    )
+
+  def run(self, *args, **kwargs):
+    self.nargs = dict(zip(self.arg_names, args))
+    used_cached_result = True
+    if len(self.configs) > 1:
+      all_args = {**self.nargs, **kwargs}
+      _args = {k: v for (k, v) in all_args.items() if k in self.arg_names}
+      key = [_args[key] for key in self.keys if key in _args]
+      for _, arg in _args.items():
+        if hasattr(arg, "dtype"):
+          key.append(str(arg.dtype))
+      key = str(tuple(key))
+      if key not in self.cache:
+        # prune configs
+        used_cached_result = False
+        pruned_configs = self.prune_configs(kwargs)
+        bench_start = time.time()
+        timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
+        bench_end = time.time()
+        self.bench_time = bench_end - bench_start
+        self.cache[key] = builtins.min(timings, key=timings.get)
+        full_nargs = {**self.nargs, **kwargs, **self.cache[key].all_kwargs()}
+        self.pre_hook(full_nargs, reset_only=True)
+        self.configs_timings = timings
+      config = self.cache[key]
+    else:
+      config = self.configs[0]
+    self.best_config = config
+    if os.getenv("TRITON_PRINT_AUTOTUNING", None) == "1" and not used_cached_result:
+      print(f"Triton autotuning for function {self.base_fn.__name__} finished after "
+            f"{self.bench_time:.2f}s; best config selected: {self.best_config};")
+    if AUTOSAVE_AUTOTUNE_CACHE and not used_cached_result:
+      save_autotune_cache()
+    if config.pre_hook is not None:
+      full_nargs = {**self.nargs, **kwargs, **config.all_kwargs()}
+      config.pre_hook(full_nargs)
+    ret = self.fn.run(
+        *args,
+        **kwargs,
+        **config.all_kwargs(),
+    )
+    self.nargs = None
+    return ret
+
+  def prune_configs(self, kwargs):
+    pruned_configs = self.configs
+    if self.early_config_prune:
+      pruned_configs = self.early_config_prune(self.configs, self.nargs, **kwargs)
+    if self.perf_model:
+      top_k = self.configs_top_k
+      if isinstance(top_k, float) and top_k <= 1.0:
+        top_k = int(len(self.configs) * top_k)
+      if len(pruned_configs) > top_k:
+        est_timing = {
+            config: self.perf_model(
+                **self.nargs,
+                **kwargs,
+                **config.all_kwargs(),
+            )
+            for config in pruned_configs
+        }
+        pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
+    return pruned_configs
+
+  def warmup(self, *args, **kwargs):
+    self.nargs = dict(zip(self.arg_names, args))
+    ret = []
+    for config in self.prune_configs(kwargs):
+      ret.append(self.fn.warmup(
+          *args,
+          **kwargs,
+          **config.all_kwargs(),
+      ))
+    self.nargs = None
+    return ret
+
+
+def triton_autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_value=None, pre_hook=None, post_hook=None,
+                    warmup=None, rep=None, use_cuda_graph=False, do_bench=None):
+  """
+  Decorator for auto-tuning a :code:`triton.jit`'d function.
+
+  .. highlight:: python
+  .. code-block:: python
+
+      @triton_autotune(configs=[
+          triton.Config(kwargs={'BLOCK_SIZE': 128}, num_warps=4),
+          triton.Config(kwargs={'BLOCK_SIZE': 1024}, num_warps=8),
+        ],
+        key=['x_size']  # the two above configs will be evaluated anytime
+                        # the value of x_size changes
+      )
+      @triton.jit
+      def kernel(x_ptr, x_size, **META):
+          BLOCK_SIZE = META['BLOCK_SIZE']
+  :note: When all the configurations are evaluated, the kernel will run multiple times.
+         This means that whatever value the kernel updates will be updated multiple times.
+         To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
+         resets the value of the provided tensor to `zero` before running any configuration.
+
+  If the environment variable :code:`TRITON_PRINT_AUTOTUNING` is set to
+  :code:`"1"`, Triton will print a message to stdout after autotuning each
+  kernel, including the time spent autotuning and the best configuration.
+
+  :param configs: a list of :code:`triton.Config` objects
+  :type configs: list[triton.Config]
+  :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
+  :type key: list[str]
+  :param prune_configs_by: a dict of functions that are used to prune configs, fields:
+      'perf_model': performance model used to predicate running time with different configs, returns running time
+      'top_k': number of configs to bench
+      'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It takes configs:List[Config] as its input, and returns pruned configs.
+  :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
+  :type reset_to_zero: list[str]
+  :param restore_value: a list of argument names whose value will be restored after evaluating any configs.
+  :type restore_value: list[str]
+  :param pre_hook: a function that will be called before the kernel is called.
+      This overrides the default pre_hook used for 'reset_to_zero' and 'restore_value'.
+      'kwargs': a dict of all arguments passed to the kernel.
+      'reset_only': a boolean indicating whether the pre_hook is called to reset the values only, without a corresponding post_hook.
+  :type pre_hook: lambda args, reset_only
+  :param post_hook: a function that will be called after the kernel is called.
+      This overrides the default post_hook used for 'restore_value'.
+      'kwargs': a dict of all arguments passed to the kernel.
+      'exception': the exception raised by the kernel in case of a compilation or runtime error.
+  :type post_hook: lambda args, exception
+  :param warmup: warmup time (in ms) to pass to benchmarking (deprecated).
+  :type warmup: int
+  :param rep: repetition time (in ms) to pass to benchmarking (deprecated).
+  :type rep: int
+  :param do_bench: a benchmark function to measure the time of each run.
+  :type do_bench: lambda fn, quantiles
+  """
+
+  def decorator(fn):
+    return TritonPersistentCacheAutotuner(
+        fn, fn.arg_names, configs, key, reset_to_zero, restore_value, pre_hook=pre_hook,
+        post_hook=post_hook, prune_configs_by=prune_configs_by, warmup=warmup, rep=rep,
+        use_cuda_graph=use_cuda_graph
+    )
+
+  return decorator
+
+
+class PersistentCacheAutoTuner:
+  def __init__(
+      self,
+      kernel,
+      configs=None,
+      key=None,
+      config_fn=None,
+      key_fn=None,
+      warmup=3,
+      runs=10,
+      verbose=False,
+  ):
+    """
+    AutoTuner is a wrapper class for a kernel that automatically tunes the kernel parameters to achieve the best performance.
+
+    Args:
+      kernel: A callable object that takes in input arguments and returns the output.
+      configs: A list of Config objects that define the possible kernel parameters and their values.
+      key: A list of argument names that retune the kernel on change.
+      config_fn: A function that takes in the input arguments and returns configs to be used for autotuning.
+      key_fn: A function that takes in the input arguments and returns the key used to cache the tuning results.
+        Once the key changes, the autotuning will be rerun.
+      warmup: The number of warmup runs to discard before measuring the execution time.
+      runs: The number of runs to measure the execution time.
+      verbose: Whether to print the autotuning results.
+    """
+    assert config_fn or configs, "Either configs or config_fn must be provided"
+    assert key_fn or key, "Either key or key_fn must be provided"
+    self.kernel = kernel
+    self.configs = configs
+    self.key = key
+    self.config_fn = config_fn
+    self.key_fn = key_fn
+    self.warmup = warmup
+    self.runs = runs
+    self.verbose = verbose
+    self.kernel_arg_names = inspect.getfullargspec(kernel).args
+    self.cache = {}
+
+  def _args_to_kwargs(self, args, kwargs):
+    # Convert args to kwargs
+    arg_names = self.kernel_arg_names
+    arg_dict = dict(zip(arg_names, args))
+    arg_dict.update(kwargs)
+    return arg_dict
+
+  def __call__(self, *args, **kwargs):
+    arg_dict = self._args_to_kwargs(args, kwargs)
+
+    # Determine key
+    key = self.key_fn(*args, **kwargs) if self.key_fn else tuple(arg_dict[k] for k in self.key)
+    key = str(key)
+
+    # If key changes, rerun autotune
+    used_cached_result = True
+    if key not in self.cache:
+      used_cached_result = False
+      if self.verbose:
+        print(f"Running autotuning for {self.kernel.__name__} with key {key}")
+      configs = self.configs if self.configs else self.config_fn(*args, **kwargs)
+      if self.verbose:
+        print(f"Configs: {configs}")
+      best_config = self._benchmark(args, kwargs, configs)
+      if self.verbose:
+        print(f"Best config for {self.kernel.__name__} with key {key}: {best_config}")
+      self.cache[key] = best_config
+    else:
+      if self.verbose:
+        print('Using cached config for {} with key {}'.format(self.kernel.__name__, key))
+        print('Config: {}'.format(self.cache[key]))
+
+    if AUTOSAVE_AUTOTUNE_CACHE and not used_cached_result:
+      save_autotune_cache()
+
+    # Run the kernel with the best config
+    return self.kernel(*args, **kwargs, **self.cache[key])
+
+  def _benchmark(self, args, kwargs, configs):
+    best_time = float('inf')
+    best_config = None
+
+    if len(configs) == 1:
+      best_config = configs[0]
+    else:
+      for config in configs:
+        # Run the kernel and measure execution time
+        for _ in range(self.warmup):
+          self.kernel(*args, **kwargs, **config)
+        torch.cuda.synchronize()
+        start = time.time()
+        for _ in range(self.runs):
+          self.kernel(*args, **kwargs, **config)
+        torch.cuda.synchronize()
+        elapsed = (time.time() - start) / self.runs
+        if self.verbose:
+          print(f"Config {config}: {elapsed} seconds")
+        # Update the best config if the execution time is better
+        if elapsed < best_time:
+          best_time = elapsed
+          best_config = config
+
+    return best_config
+
+
+
+
+def autotune(
+    configs=None,
+    key=None,
+    config_fn=None,
+    key_fn=None,
+    warmup=3,
+    runs=10,
+    verbose=VERBOSE_AUTOTUNE
+):
+  def decorator(kernel):
+    return PersistentCacheAutoTuner(kernel, configs, key, config_fn, key_fn, warmup, runs, verbose)
+  return decorator
+
+
+def walk_package(package_name, fn):
+  try:
+    package = importlib.import_module(package_name)
+  except ModuleNotFoundError:
+    print(f"Package {package_name} not found.")
+    return
+
+  if not hasattr(package, '__path__'):
+    print(f"{package_name} is not a package.")
+    return
+
+  for _, module_name, is_pkg in pkgutil.iter_modules(package.__path__):
+    full_module_name = f"{package_name}.{module_name}"
+    if is_pkg:
+      walk_package(full_module_name, fn)
+    else:
+      fn(full_module_name)
+
+
+def get_autotune_cache():
+  cache = {}
+  device_name = torch.cuda.get_device_name()
+  if device_name not in cache:
+    cache[device_name] = {}
+
+  def save_cache(full_module_name):
+    module = importlib.import_module(full_module_name)
+    for attr_name, attr in module.__dict__.items():
+      cache_key = f"{full_module_name}.{attr_name}"
+      if isinstance(attr, PersistentCacheAutoTuner):
+        cache[device_name][cache_key] = attr.cache
+      elif isinstance(attr, TritonPersistentCacheAutotuner):
+        cache[device_name][cache_key] = {k: v.__dict__ for k, v in attr.cache.items()}
+
+  walk_package('ocnn.nn.kernels', save_cache)
+
+  return cache
+
+
+def save_autotune_cache(path=None):
+  path = path or AUTOTUNE_CACHE_PATH
+  lock_path = path + ".lock"
+
+  with FileLock(lock_path):
+    if os.path.exists(path):
+      with open(path, 'r') as f:
+        cache = json.load(f)
+    else:
+      cache = {}
+    # Merge existing cache with new cache
+    cache.update(get_autotune_cache())
+
+    tmp_path = path + ".tmp"
+    with open(tmp_path, 'w') as f:
+      json.dump(cache, f, indent=4)
+      f.flush()
+      os.fsync(f.fileno())
+    os.replace(tmp_path, path)
+
+
+def load_autotune_cache(path_or_cache=None):
+  cache = None
+
+  # Preserve path-based loading, but allow callers to provide a preloaded cache object.
+  if path_or_cache is None or isinstance(path_or_cache, (str, os.PathLike)):
+    path = path_or_cache or AUTOTUNE_CACHE_PATH
+    lock_path = path + ".lock"
+
+    if not os.path.exists(path):
+      return
+
+    with FileLock(lock_path):
+      with open(path, 'r') as f:
+        cache = json.load(f)
+  elif isinstance(path_or_cache, Mapping):
+    cache = path_or_cache
+  else:
+    raise TypeError("load_autotune_cache expects a path or a mapping")
+
+  if cache is None:
+    return
+
+  device_name = torch.cuda.get_device_name()
+  if device_name not in cache and "*" not in cache:
+    return
+  if "*" in cache and device_name not in cache:
+    device_name = "*"
+
+  def load_cache(full_module_name):
+    module = importlib.import_module(full_module_name)
+    for attr_name, attr in module.__dict__.items():
+      cache_key = f"{full_module_name}.{attr_name}"
+      if isinstance(attr, PersistentCacheAutoTuner):
+        if cache_key in cache[device_name]:
+          attr.cache = cache[device_name][cache_key]
+      elif isinstance(attr, TritonPersistentCacheAutotuner):
+        if cache_key in cache[device_name]:
+          for k, v in cache[device_name][cache_key].items():
+            attr.cache[k] = triton.runtime.Config(None)
+            attr.cache[k].__dict__.update(v)
+
+  walk_package('ocnn.nn.kernels', load_cache)
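
Since the `PersistentCacheAutoTuner` docstring above describes a dict-based config interface, a minimal usage sketch of the `autotune` decorator may help; the kernel below and its `BLOCK` parameter are purely illustrative, and the import path is inferred from the diff rather than stated in it:

import torch
from ocnn.nn.kernels.autotuner import autotune  # module path assumed

@autotune(
    configs=[{'BLOCK': 64}, {'BLOCK': 128}, {'BLOCK': 256}],  # each dict is unpacked into the kernel call
    key=['m', 'n'],  # re-benchmark all configs whenever (m, n) changes
)
def scaled_matmul(a, b, m, n, BLOCK=64):
  # Stand-in body; a real kernel would tile its work by BLOCK.
  return a @ b

a = torch.randn(256, 256, device='cuda')
b = torch.randn(256, 256, device='cuda')
out = scaled_matmul(a, b, 256, 256)  # first call benchmarks every config, later calls reuse the cached winner

With `OCNN_AUTOSAVE_AUTOTUNE` enabled, the winning config is written to the JSON cache keyed by `torch.cuda.get_device_name()`, and `load_autotune_cache()` restores it on the next import (a cache entry under the device name `"*"` acts as a wildcard fallback).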
@@ -0,0 +1,67 @@
+import os
+import triton
+from .utils import get_autotune_config
+
+allow_tf32 = os.getenv('OCNN_ALLOW_TF32', '1') == '1'
+
+autotune_config = get_autotune_config(
+    platform={
+        'cuda': [
+            triton.Config({'B1': 128, 'B2': 256, 'BK': 64}, num_stages=3, num_warps=8),
+            triton.Config({'B1': 64, 'B2': 256, 'BK': 32}, num_stages=4, num_warps=4),
+            triton.Config({'B1': 128, 'B2': 128, 'BK': 32}, num_stages=4, num_warps=4),
+            triton.Config({'B1': 128, 'B2': 64, 'BK': 32}, num_stages=4, num_warps=4),
+            triton.Config({'B1': 64, 'B2': 128, 'BK': 32}, num_stages=4, num_warps=4),
+            triton.Config({'B1': 128, 'B2': 32, 'BK': 32}, num_stages=4, num_warps=4),
+            triton.Config({'B1': 64, 'B2': 32, 'BK': 32}, num_stages=5, num_warps=2),
+            triton.Config({'B1': 32, 'B2': 64, 'BK': 32}, num_stages=5, num_warps=2),
+        ],
+        'hip': [
+            triton.Config({'B1': 128, 'B2': 256, 'BK': 16, 'waves_per_eu': 2}, num_warps=4, num_stages=2),
+            triton.Config({'B1': 256, 'B2': 256, 'BK': 16, 'waves_per_eu': 2}, num_warps=8, num_stages=2),
+            triton.Config({'B1': 128, 'B2': 128, 'BK': 32, 'waves_per_eu': 2}, num_warps=8, num_stages=2),
+            triton.Config({'B1': 64, 'B2': 128, 'BK': 32, 'waves_per_eu': 3}, num_warps=4, num_stages=2),
+            triton.Config({'B1': 64, 'B2': 64, 'BK': 32, 'waves_per_eu': 8}, num_warps=4, num_stages=2),
+        ]
+    },
+    device={
+        'A100': [
+            triton.Config({'B1': 256, 'B2': 128, 'BK': 64}, num_stages=4, num_warps=8),
+            triton.Config({'B1': 256, 'B2': 128, 'BK': 32}, num_stages=4, num_warps=8),
+            triton.Config({'B1': 128, 'B2': 256, 'BK': 64}, num_stages=4, num_warps=8),
+            triton.Config({'B1': 128, 'B2': 256, 'BK': 32}, num_stages=4, num_warps=8),
+            triton.Config({'B1': 256, 'B2': 64, 'BK': 64}, num_stages=4, num_warps=4),
+            triton.Config({'B1': 256, 'B2': 64, 'BK': 32}, num_stages=4, num_warps=4),
+            triton.Config({'B1': 64, 'B2': 256, 'BK': 64}, num_stages=4, num_warps=4),
+            triton.Config({'B1': 64, 'B2': 256, 'BK': 32}, num_stages=4, num_warps=4),
+            triton.Config({'B1': 128, 'B2': 128, 'BK': 64}, num_stages=4, num_warps=4),
+            triton.Config({'B1': 128, 'B2': 128, 'BK': 32}, num_stages=4, num_warps=4),
+            triton.Config({'B1': 128, 'B2': 64, 'BK': 32}, num_stages=4, num_warps=4),
+            triton.Config({'B1': 128, 'B2': 64, 'BK': 32}, num_stages=4, num_warps=2),
+            triton.Config({'B1': 64, 'B2': 128, 'BK': 32}, num_stages=4, num_warps=4),
+            triton.Config({'B1': 64, 'B2': 128, 'BK': 32}, num_stages=4, num_warps=2),
+            triton.Config({'B1': 64, 'B2': 64, 'BK': 64}, num_stages=4, num_warps=2),
+            triton.Config({'B1': 64, 'B2': 64, 'BK': 32}, num_stages=4, num_warps=2),
+        ],
+        'MI300X': [
+            triton.Config({'B1': 256, 'B2': 256, 'BK': 64, 'waves_per_eu': 2}, num_stages=2, num_warps=16),
+            triton.Config({'B1': 256, 'B2': 256, 'BK': 64, 'waves_per_eu': 2, 'kpack': 2, 'matrix_instr_nonkdim': 16}, num_stages=2, num_warps=8),
+            triton.Config({'B1': 256, 'B2': 128, 'BK': 64, 'waves_per_eu': 2}, num_stages=2, num_warps=16),
+            triton.Config({'B1': 256, 'B2': 128, 'BK': 64, 'waves_per_eu': 2, 'kpack': 2, 'matrix_instr_nonkdim': 16}, num_stages=2, num_warps=8),
+            triton.Config({'B1': 128, 'B2': 256, 'BK': 64, 'waves_per_eu': 2}, num_stages=2, num_warps=16),
+            triton.Config({'B1': 128, 'B2': 256, 'BK': 64, 'waves_per_eu': 2, 'kpack': 2, 'matrix_instr_nonkdim': 16}, num_stages=2, num_warps=8),
+            triton.Config({'B1': 256, 'B2': 64, 'BK': 32, 'waves_per_eu': 2}, num_stages=2, num_warps=8),
+            triton.Config({'B1': 256, 'B2': 64, 'BK': 32, 'waves_per_eu': 2, 'kpack': 2, 'matrix_instr_nonkdim': 16}, num_stages=2, num_warps=8),
+            triton.Config({'B1': 64, 'B2': 256, 'BK': 32, 'waves_per_eu': 2}, num_stages=2, num_warps=8),
+            triton.Config({'B1': 64, 'B2': 256, 'BK': 32, 'waves_per_eu': 2, 'kpack': 2, 'matrix_instr_nonkdim': 16}, num_stages=2, num_warps=8),
+            triton.Config({'B1': 128, 'B2': 128, 'BK': 64, 'waves_per_eu': 2}, num_stages=2, num_warps=8),
+            triton.Config({'B1': 128, 'B2': 128, 'BK': 64, 'waves_per_eu': 2, 'kpack': 2, 'matrix_instr_nonkdim': 16}, num_stages=2, num_warps=8),
+            triton.Config({'B1': 128, 'B2': 64, 'BK': 64, 'waves_per_eu': 2}, num_stages=2, num_warps=4),
+            triton.Config({'B1': 128, 'B2': 64, 'BK': 64, 'waves_per_eu': 2, 'kpack': 2, 'matrix_instr_nonkdim': 16}, num_stages=2, num_warps=4),
+            triton.Config({'B1': 64, 'B2': 128, 'BK': 64, 'waves_per_eu': 2}, num_stages=2, num_warps=4),
+            triton.Config({'B1': 64, 'B2': 128, 'BK': 64, 'waves_per_eu': 2, 'kpack': 2, 'matrix_instr_nonkdim': 16}, num_stages=2, num_warps=4),
+            triton.Config({'B1': 64, 'B2': 64, 'BK': 64, 'waves_per_eu': 2}, num_stages=2, num_warps=2),
+            triton.Config({'B1': 64, 'B2': 64, 'BK': 64, 'waves_per_eu': 2, 'kpack': 2, 'matrix_instr_nonkdim': 16}, num_stages=2, num_warps=2),
+        ],
+    }
+)
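
A rough sketch of how a config list such as `autotune_config` is attached to a kernel with the `triton_autotune` decorator defined earlier in this diff; the kernel below is a toy stand-in, not one of the implicit-GEMM kernels in this release, and the relative import mirrors the `from .utils import get_autotune_config` at the top of this file. The AMD-specific entries (`waves_per_eu`, `kpack`, `matrix_instr_nonkdim`) are presumably consumed by the ROCm backend as launch options rather than declared as kernel parameters.

import triton
import triton.language as tl

from .autotuner import triton_autotune  # sibling module, path assumed

@triton_autotune(configs=autotune_config, key=['M', 'N', 'K'])
@triton.jit
def _toy_kernel(out_ptr, M, N, K,
                B1: tl.constexpr, B2: tl.constexpr, BK: tl.constexpr):
  # A real implicit-GEMM kernel would tile the M x N output into B1 x B2 blocks
  # and march over the K (reduction) dimension in steps of BK.
  pid = tl.program_id(axis=0)
  tl.store(out_ptr + pid, pid)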