triton-windows 3.3.1.post21-cp311-cp311-win_amd64.whl → 3.4.0.post21-cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of triton-windows might be problematic.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +4 -1
- triton/_filecheck.py +87 -0
- triton/_internal_testing.py +26 -15
- triton/_utils.py +110 -21
- triton/backends/__init__.py +20 -23
- triton/backends/amd/__init__.py +0 -0
- triton/backends/amd/compiler.py +112 -78
- triton/backends/amd/driver.c +5 -2
- triton/backends/amd/driver.py +143 -46
- triton/backends/compiler.py +7 -21
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +94 -94
- triton/backends/nvidia/driver.c +90 -98
- triton/backends/nvidia/driver.py +296 -125
- triton/compiler/code_generator.py +212 -111
- triton/compiler/compiler.py +110 -25
- triton/experimental/__init__.py +0 -0
- triton/experimental/gluon/__init__.py +4 -0
- triton/experimental/gluon/_compiler.py +0 -0
- triton/experimental/gluon/_runtime.py +99 -0
- triton/experimental/gluon/language/__init__.py +18 -0
- triton/experimental/gluon/language/_core.py +312 -0
- triton/experimental/gluon/language/_layouts.py +230 -0
- triton/experimental/gluon/language/_math.py +12 -0
- triton/experimental/gluon/language/_semantic.py +287 -0
- triton/experimental/gluon/language/_standard.py +47 -0
- triton/experimental/gluon/language/nvidia/__init__.py +4 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +202 -0
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +32 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +11 -0
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +51 -0
- triton/experimental/gluon/language/nvidia/hopper/tma.py +96 -0
- triton/experimental/gluon/nvidia/__init__.py +4 -0
- triton/experimental/gluon/nvidia/blackwell.py +3 -0
- triton/experimental/gluon/nvidia/hopper.py +40 -0
- triton/knobs.py +481 -0
- triton/language/__init__.py +39 -14
- triton/language/core.py +794 -537
- triton/language/extra/cuda/__init__.py +10 -7
- triton/language/extra/cuda/gdc.py +42 -0
- triton/language/extra/cuda/libdevice.py +394 -394
- triton/language/extra/cuda/utils.py +21 -21
- triton/language/extra/hip/libdevice.py +113 -104
- triton/language/math.py +65 -66
- triton/language/random.py +12 -2
- triton/language/semantic.py +1706 -1770
- triton/language/standard.py +116 -51
- triton/runtime/autotuner.py +117 -59
- triton/runtime/build.py +73 -9
- triton/runtime/cache.py +18 -47
- triton/runtime/driver.py +32 -29
- triton/runtime/interpreter.py +72 -35
- triton/runtime/jit.py +146 -110
- triton/testing.py +16 -12
- triton/tools/disasm.py +3 -4
- triton/tools/tensor_descriptor.py +36 -0
- triton/windows_utils.py +47 -83
- {triton_windows-3.3.1.post21.dist-info → triton_windows-3.4.0.post21.dist-info}/METADATA +7 -2
- {triton_windows-3.3.1.post21.dist-info → triton_windows-3.4.0.post21.dist-info}/RECORD +64 -41
- triton_windows-3.4.0.post21.dist-info/entry_points.txt +3 -0
- triton_windows-3.4.0.post21.dist-info/licenses/LICENSE +23 -0
- triton_windows-3.4.0.post21.dist-info/top_level.txt +1 -0
- triton/language/_utils.py +0 -21
- triton/language/extra/cuda/_experimental_tma.py +0 -106
- triton/tools/experimental_descriptor.py +0 -32
- triton_windows-3.3.1.post21.dist-info/top_level.txt +0 -14
- {triton_windows-3.3.1.post21.dist-info → triton_windows-3.4.0.post21.dist-info}/WHEEL +0 -0
triton/language/standard.py
CHANGED
@@ -9,7 +9,7 @@ from . import math
 
 def _log2(i: core.constexpr):
     log2 = 0
-    n = i.value
+    n = core.constexpr(i).value
     while n > 1:
         n >>= 1
         log2 += 1
@@ -50,10 +50,14 @@ def sigmoid(x):
 @core._tensor_member_fn
 @jit
 @math._add_math_1arg_docstr("softmax")
-def softmax(x, ieee_rounding=False):
-    z = x - max(x, 0)
+def softmax(x, dim=None, keep_dims=False, ieee_rounding=False):
+    if dim is None:
+        _dim: core.constexpr = 0
+    else:
+        _dim: core.constexpr = dim
+    z = x - max(x, _dim, keep_dims=keep_dims)
     num = math.exp(z)
-    den = sum(num, 0)
+    den = sum(num, _dim, keep_dims=keep_dims)
     return math.fdiv(num, den, ieee_rounding)
 
 
@@ -302,15 +306,37 @@ def xor_sum(input, axis=None, keep_dims=False):
     return core.reduce(input, axis, _xor_combine, keep_dims=keep_dims)
 
 
+# or reduction
+
+
+@jit
+def _or_combine(x, y):
+    return x | y
+
+
+@core._tensor_member_fn
+@jit
+@core._add_reduction_docstr("reduce_of")
+def reduce_or(input, axis, keep_dims=False):
+    core.static_assert(input.type.scalar.is_int(), "reduce_of only supported for integers")
+    return core.reduce(input, axis, _or_combine, keep_dims=keep_dims)
+
+
 # cumsum
 
 
 @core._tensor_member_fn
 @jit
-@core._add_scan_docstr("cumsum")
-def cumsum(input, axis=0, reverse=False):
+@core._add_scan_docstr("cumsum", dtype_arg="dtype")
+def cumsum(input, axis=0, reverse=False, dtype: core.constexpr = None):
     # todo rename this to a generic function name
+
     input = core._promote_bfloat16_to_float32(input)
+    out_dtype: core.constexpr = _pick_sum_dtype(input.dtype, dtype)
+
+    if out_dtype is not None:
+        input = input.to(out_dtype)
+
     return core.associative_scan(input, axis, _sum_combine, reverse)
 
 
@@ -335,53 +361,63 @@ def cumprod(input, axis=0, reverse=False):
 
 
 @jit
-def _compare_and_swap(…):
-    …
+def _indicator(n_dims: core.constexpr, j: core.constexpr):
+    ar = core.arange(0, 2)
+    ar = core.reshape(ar, [1] * (n_dims - j - 1) + [2] + [1] * j)
+    return ar
+
+
+@jit
+def _compare_and_swap(x, flip, i: core.constexpr):
+    # compare-and-swap on the ith *innermost* dimension
+    n_dims: core.constexpr = _log2(x.numel)
+
+    # flip along middle dimension (the bitwise XORs will be optimised away):
     idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
-    ileft = left.to(idtype, bitcast=True)
-    iright = right.to(idtype, bitcast=True)
     ix = x.to(idtype, bitcast=True)
-    …
+    iy = ix ^ xor_sum(ix, n_dims - 1 - i, True)
+    y = iy.to(x.dtype, bitcast=True)
+
+    # determines whether we are in the right (rather than left) position along the axis:
+    is_right = _indicator(n_dims, i)
+
+    # conditional swap:
+    ret = core.where((x > y) != (flip ^ is_right), y, x)
+    return ret
 
 
 @jit
-def _bitonic_merge(…):
+def _bitonic_merge_hypercube(x, stage: core.constexpr, order: core.constexpr):
     '''
     order_type 0 == ascending
     order_type 1 == descending
     order_type 2 == alternating
     '''
-    n_outer: core.constexpr = x.numel >> n_dims
-    core.static_assert(stage <= n_dims)
     # flip denotes whether to re-arrange sub-sequences of elements in ascending or
     # descending order.
     # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage
     # if flip = 00110011... then all the elements will be re-arranged alternatingly (with
     # a stride of 2) at this stage
     if order == 2:
-        …
-        flip = core.reshape(core.broadcast_to(core.arange(0, 2)[None, :, None], shape), x.shape)
+        flip = _indicator(_log2(x.numel), stage)
     else:
         flip = order
     # perform `stage` rounds of `compare-and-swap`
     for i in core.static_range(stage):
-        x = _compare_and_swap(x, flip, …
+        x = _compare_and_swap(x, flip, stage - 1 - i)
     return x
 
 
-@core._tensor_member_fn
 @jit
-def sort(x, dim: core.constexpr = None, descending: core.constexpr = core.CONSTEXPR_0):
+def _bitonic_merge(x, stage: core.constexpr, order: core.constexpr, n_dims: core.constexpr):
+    h = core.reshape(x, [2] * _log2(x.numel))
+    h = _bitonic_merge_hypercube(h, stage, order)
+    x = core.reshape(h, x.shape)
+    return x
+
+
+@jit
+def sort_impl(x, k: core.constexpr = None, dim: core.constexpr = None, descending: core.constexpr = core.CONSTEXPR_0):
     """
     Sorts a tensor along a specified dimension.
 
@@ -389,20 +425,55 @@ def sort(x, dim: core.constexpr = None, descending: core.constexpr = core.CONSTE
     :type x: Tensor
     :param dim: The dimension along which to sort the tensor. If None, the tensor is sorted along the last dimension. Currently, only sorting along the last dimension is supported.
     :type dim: int, optional
+    :param k: the number of top elements to select. If none, assume k = x.shape[dim]
+    :type k: int, optional
     :param descending: If set to True, the tensor is sorted in descending order. If set to False, the tensor is sorted in ascending order.
     :type descending: bool, optional
     """
     # handle default dimension or check that it is the most minor dim
     _dim: core.constexpr = len(x.shape) - 1 if dim is None else dim
     core.static_assert(_dim == len(x.shape) - 1, "only minor dimension is currently supported")
-    …
+
+    log_n: core.constexpr = _log2(x.shape[_dim])
+    log_k: core.constexpr = log_n if k is None else _log2(k)
+
+    n_dims: core.constexpr = _log2(x.numel)
+
+    # reshape to hypercube:
+    h = core.reshape(x, [2] * n_dims)
+
+    # run first log_k bitonic sort iterations:
+    for i in core.static_range(1, log_k + 1):
+        h = _bitonic_merge_hypercube(h, i, 2 if i < log_n else descending)
+
+    # select top k elements using bitonic top-k
+    # https://www.doc.ic.ac.uk/~hlgr/pdfs/MassivelyParallelTopK.pdf
+    for i in core.static_range(log_k + 1, log_n + 1):
+        h = max(h, axis=(_log2(h.numel) - 1 - log_k)) if descending else min(h, axis=(_log2(h.numel) - 1 - log_k))
+        h = _bitonic_merge_hypercube(h, log_k, 2 if i < log_n else descending)
+
+    # reshape back:
+    x = core.reshape(h, x.shape[:-1] + [2**log_k])
     return x
 
 
-…
+@jit
+def sort(x, dim: core.constexpr = None, descending: core.constexpr = core.CONSTEXPR_0):
+    return sort_impl(x, dim=dim, descending=descending)
+
+
+@jit
+def topk(x, k: core.constexpr, dim: core.constexpr = None):
+    return sort_impl(x, k=k, dim=dim, descending=True)
+
+
+@jit
+def bitonic_merge(x, dim: core.constexpr = None, descending: core.constexpr = core.CONSTEXPR_0):
+    # handle default dimension or check that it is the most minor dim
+    _dim: core.constexpr = len(x.shape) - 1 if dim is None else dim
+    core.static_assert(_dim == len(x.shape) - 1, "only minor dimension is currently supported")
+    n_dims: core.constexpr = _log2(x.shape[-1])
+    return _bitonic_merge(x, n_dims, descending, n_dims)
 
 
 def _get_flip_dim(dim, shape):
@@ -410,7 +481,8 @@ def _get_flip_dim(dim, shape):
     shape = core._unwrap_if_constexpr(shape)
     if dim is None:
         dim = len(shape) - 1
-    …
+    if dim < 0:  # flip doesn't work if dim < 0 because the xor-swap for loop will start/end at the wrong index
+        dim += len(shape)
     return core.constexpr(dim)
 
 
@@ -422,26 +494,19 @@ def flip(x, dim=None):
 
     :param x: the first input tensor
     :type x: Block
-    :param dim: the dimension to flip along
+    :param dim: the dimension to flip along
     :type dim: int
     """
-    core.static_assert( …
-    core. …
-    …
-    steps: core.constexpr = _log2(x.numel)
-    start: core.constexpr = _log2(x.numel) - _log2(x.shape[_get_flip_dim(dim, x.shape)])
+    core.static_assert(-len(x.shape) <= dim and dim < len(x.shape))
+    _dim: core.constexpr = _get_flip_dim(dim, x.shape)
+    core.static_assert(_is_power_of_two(x.shape[_dim]))
+    steps: core.constexpr = _log2(x.shape[_dim])
 
+    # reshape the swap dimension to (2, 2, ..., 2)
     idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
-    y = core.reshape(x.to(idtype, bitcast=True), [2] * steps)
-    …
-    for i in core.static_range(start, steps):
-        flip2 = flip
-        for j in core.static_range(0, steps + 1):
-            if j != i and j != i + 1:
-                flip2 = core.expand_dims(flip2, j)
-        y = sum(y * flip2, i + 1, keep_dims=True, dtype=y.dtype)
+    y = core.reshape(x.to(idtype, bitcast=True), x.shape[:_dim] + [2] * steps + x.shape[_dim + 1:])
+    for i in core.static_range(steps):
+        y = y ^ xor_sum(y, _dim + i, True)
     x = core.reshape(y, x.shape).to(x.dtype, bitcast=True)
     return x
 
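Taken together, the standard.py hunks above change the user-facing surface of triton.language: softmax gains dim/keep_dims, cumsum gains an accumulation dtype, flip accepts negative dims, and sorting is rebuilt on a hypercube bitonic merge, adding sort_impl, topk, bitonic_merge and a reduce_or reduction. The kernel below is a minimal illustrative sketch of that new keyword surface, not code from the package; it assumes softmax and topk are re-exported under triton.language the way the other standard helpers are, and the kernel name, pointer arguments and block sizes are invented for the example.

import triton
import triton.language as tl


@triton.jit
def row_softmax_topk(x_ptr, p_ptr, t_ptr, N: tl.constexpr, BLOCK: tl.constexpr, K: tl.constexpr):
    # One program handles one row; BLOCK and K are powers of two because the
    # bitonic sort/top-k path reshapes the block into a [2] * log2(numel) hypercube.
    offs = tl.arange(0, BLOCK)
    row = tl.load(x_ptr + offs, mask=offs < N, other=float("-inf"))

    # 3.4.0 softmax: explicit reduction dim/keep_dims instead of a hard-coded axis 0.
    p = tl.softmax(row, dim=0)

    # 3.4.0 topk: the K largest values, backed by sort_impl's bitonic top-k.
    top = tl.topk(row, K, dim=0)

    tl.store(p_ptr + offs, p, mask=offs < N)
    tl.store(t_ptr + tl.arange(0, K), top)

Per the MassivelyParallelTopK reference linked in sort_impl, selecting k elements this way runs only the first log2(k) merge stages in full and then repeatedly halves the candidate set with max/min, which is why topk routes through sort_impl with k set instead of performing a full sort.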
triton/runtime/autotuner.py
CHANGED
@@ -1,11 +1,14 @@
 from __future__ import annotations
 
 import builtins
-import os
 import time
 import inspect
+import hashlib
+import json
+from functools import cached_property
 from typing import Dict, Tuple, List, Optional
 
+from .. import knobs
 from .jit import KernelInterface
 from .errors import OutOfResources, PTXASError
 from .driver import driver
@@ -13,22 +16,9 @@ from .driver import driver
 
 class Autotuner(KernelInterface):
 
-    def __init__(
-        self,
-        fn,
-        arg_names,
-        configs,
-        key,
-        reset_to_zero,
-        restore_value,
-        pre_hook=None,
-        post_hook=None,
-        prune_configs_by: Optional[Dict] = None,
-        warmup=None,
-        rep=None,
-        use_cuda_graph=False,
-        do_bench=None,
-    ):
+    def __init__(self, fn, arg_names, configs, key, reset_to_zero, restore_value, pre_hook=None, post_hook=None,
+                 prune_configs_by: Optional[Dict] = None, warmup=None, rep=None, use_cuda_graph=False, do_bench=None,
+                 cache_results=False):
         """
         :param prune_configs_by: a dict of functions that are used to prune configs, fields:
             'perf_model': performance model used to predicate running time with different configs, returns running time
@@ -36,15 +26,13 @@ class Autotuner(KernelInterface):
             'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs.
         """
         if not configs:
-            self.configs = [
-                Config({}, num_warps=4, num_stages=3, num_ctas=1, num_buffers_warp_spec=0, num_consumer_groups=0,
-                       reg_dec_producer=0, reg_inc_consumer=0)
-            ]
+            self.configs = [Config({}, num_warps=4, num_stages=3, num_ctas=1)]
         else:
             self.configs = configs
         self.keys = key
         self.cache: Dict[Tuple, Config] = {}
         self.arg_names = arg_names
+        self.cache_results = cache_results or (knobs.autotuning.cache and not knobs.runtime.interpret)
 
         # Reset to zero or restore values
         self.reset_to_zero = []
@@ -97,6 +85,7 @@ class Autotuner(KernelInterface):
         while not inspect.isfunction(self.base_fn):
             self.base_fn = self.base_fn.fn
 
+        self._do_bench = do_bench
         self.num_warmups = warmup
        self.num_reps = rep
        self.use_cuda_graph = use_cuda_graph
@@ -110,7 +99,7 @@ class Autotuner(KernelInterface):
                           stacklevel=1)
         if use_cuda_graph:
             from ..testing import do_bench_cudagraph
-            self.do_bench = lambda kernel_call, quantiles: do_bench_cudagraph(
+            self._do_bench = lambda kernel_call, quantiles: do_bench_cudagraph(
                 kernel_call,
                 rep=rep if rep is not None else 100,
                 quantiles=quantiles,
@@ -118,7 +107,7 @@ class Autotuner(KernelInterface):
             return
 
         import triton.testing
-        self.do_bench = lambda kernel_call, quantiles: triton.testing.do_bench(
+        self._do_bench = lambda kernel_call, quantiles: triton.testing.do_bench(
             kernel_call,
             warmup=warmup if warmup is not None else 25,
             rep=rep if rep is not None else 100,
@@ -126,15 +115,16 @@ class Autotuner(KernelInterface):
         )
         return
 
-        …
+    @cached_property
+    def do_bench(self):
+        if self._do_bench is None:
+            return driver.active.get_benchmarker()
+        return self._do_bench
 
     def _bench(self, *args, config, **meta):
         from ..compiler.errors import CompileTimeAssertionFailure
 
-        verbose = …
+        verbose = knobs.autotuning.print
         if verbose:
             print(f"Autotuning kernel {self.base_fn.__name__} with config {config}")
 
@@ -173,6 +163,51 @@ class Autotuner(KernelInterface):
                 print(f"Autotuning failed with {e}")
             return [float("inf"), float("inf"), float("inf")]
 
+    def check_disk_cache(self, tuning_key, configs, bench_fn):
+        # We can't serialize prehooks, so just give up and run the benchmarks.
+        if not tuning_key or any(cfg.pre_hook for cfg in configs):
+            bench_fn()
+            return False
+
+        from triton._C.libtriton import get_cache_invalidating_env_vars
+        from triton.compiler.compiler import make_backend, triton_key
+        from triton.runtime.cache import get_cache_manager
+        from triton.runtime.jit import JITFunction
+
+        fn = self.fn
+        while not isinstance(fn, JITFunction):
+            fn = fn.fn
+
+        env_vars = get_cache_invalidating_env_vars()
+        cache_key = [
+            triton_key(),
+            make_backend(driver.active.get_current_target()).hash(),
+            fn.cache_key,
+            str(sorted(env_vars.items())),
+            str(tuning_key),
+        ] + [str(c) for c in configs]
+        cache_key = hashlib.sha256("-".join(cache_key).encode("utf-8")).hexdigest()
+        cache = get_cache_manager(cache_key)
+        file_name = f"{fn.__name__[:150]}.autotune.json"
+        path = cache.get_file(file_name)
+        if path:
+            with open(path, "r") as cached_configs:
+                timings = json.load(cached_configs)["configs_timings"]
+                timings = {Config(**config): timing for config, timing in timings}
+                self.cache[tuning_key] = builtins.min(timings, key=timings.get)
+                self.configs_timings = timings
+            return True
+
+        bench_fn()
+        cache.put(
+            json.dumps({
+                "key":
+                tuning_key,
+                "configs_timings":
+                [(config.__dict__, timings) for config, timings in self.configs_timings.items() if not config.pre_hook],
+            }), file_name, binary=False)
+        return False
+
     def run(self, *args, **kwargs):
         self.nargs = dict(zip(self.arg_names, args))
         used_cached_result = True
@@ -185,24 +220,31 @@ class Autotuner(KernelInterface):
                     key.append(str(arg.dtype))
             key = tuple(key)
             if key not in self.cache:
-                # prune configs
                 used_cached_result = False
                 pruned_configs = self.prune_configs(kwargs)
-                …
+
+                def benchmark():
+                    bench_start = time.perf_counter()
+                    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
+                    bench_end = time.perf_counter()
+                    self.bench_time = bench_end - bench_start
+                    self.cache[key] = builtins.min(timings, key=timings.get)
+                    full_nargs = {**self.nargs, **kwargs, **self.cache[key].all_kwargs()}
+                    self.pre_hook(full_nargs, reset_only=True)
+                    self.configs_timings = timings
+
+                if self.cache_results:
+                    used_cached_result = self.check_disk_cache(key, pruned_configs, benchmark)
+                else:
+                    benchmark()
+
             config = self.cache[key]
         else:
             config = self.configs[0]
         self.best_config = config
-        if …
-            print(f"Triton autotuning for function {self.base_fn.__name__} …
-                  f"{self.bench_time:.2f}s …
+        if knobs.autotuning.print and not used_cached_result:
+            print(f"Triton autotuning for function {self.base_fn.__name__},\nwith key as {key},\n"
+                  f"finished after {self.bench_time:.2f}s,\nbest config selected: {self.best_config};")
         if config.pre_hook is not None:
             full_nargs = {**self.nargs, **kwargs, **config.all_kwargs()}
             config.pre_hook(full_nargs)
@@ -241,11 +283,11 @@ class Autotuner(KernelInterface):
     def warmup(self, *args, **kwargs):
         self.nargs = dict(zip(self.arg_names, args))
         ret = []
-        for …
+        for autotune_config in self.prune_configs(kwargs):
             ret.append(self.fn.warmup(
                 *args,
                 **kwargs,
-                **…
+                **autotune_config.all_kwargs(),
             ))
         self.nargs = None
         return ret
@@ -263,27 +305,34 @@ class Config:
     :type num_warps: int
     :ivar num_stages: the number of stages that the compiler should use when software-pipelining loops.
                       Mostly useful for matrix multiplication workloads on SM80+ GPUs.
-    :type …
+    :type num_stages: int
     :ivar num_ctas: number of blocks in a block cluster. SM90+ only.
+    :type num_ctas: int
     :type maxnreg: Optional[int]
     :ivar maxnreg: maximum number of registers one thread can use. Corresponds
                    to ptx .maxnreg directive. Not supported on all platforms.
     :ivar pre_hook: a function that will be called before the kernel is called. Parameters of this
                     function are args.
+    :ivar ir_override: filename of a user-defined IR (*.{ttgir|llir|ptx|amdgcn}).
     """
 
-    def __init__(self, kwargs, num_warps=4, num_stages=3, num_ctas=1, …
-                 reg_dec_producer=0, reg_inc_consumer=0, maxnreg=None, pre_hook=None):
+    def __init__(self, kwargs, num_warps=4, num_stages=3, num_ctas=1, maxnreg=None, pre_hook=None, ir_override=None):
         self.kwargs = kwargs
         self.num_warps = num_warps
         self.num_ctas = num_ctas
         self.num_stages = num_stages
-        self.num_buffers_warp_spec = num_buffers_warp_spec
-        self.num_consumer_groups = num_consumer_groups
-        self.reg_dec_producer = reg_dec_producer
-        self.reg_inc_consumer = reg_inc_consumer
         self.maxnreg = maxnreg
         self.pre_hook = pre_hook
+        self.ir_override = ir_override
+
+    def __setstate__(self, state):
+        self.kwargs = state.get("kwargs", {})
+        self.num_warps = state.get("num_warps", 4)
+        self.num_stages = state.get("num_stages", 3)
+        self.num_ctas = state.get("num_ctas", 1)
+        self.maxnreg = state.get("maxnreg", None)
+        self.pre_hook = state.get("pre_hook", None)
+        self.ir_override = state.get("ir_override", None)
 
     def all_kwargs(self):
         return {
@@ -293,11 +342,8 @@ class Config:
                     ("num_warps", self.num_warps),
                     ("num_ctas", self.num_ctas),
                     ("num_stages", self.num_stages),
-                    ("num_buffers_warp_spec", self.num_buffers_warp_spec),
-                    ("num_consumer_groups", self.num_consumer_groups),
-                    ("reg_dec_producer", self.reg_dec_producer),
-                    ("reg_inc_consumer", self.reg_inc_consumer),
                     ("maxnreg", self.maxnreg),
+                    ("ir_override", self.ir_override),
                 ) if v is not None
             }
         }
@@ -309,16 +355,26 @@ class Config:
         res.append(f"num_warps: {self.num_warps}")
         res.append(f"num_ctas: {self.num_ctas}")
         res.append(f"num_stages: {self.num_stages}")
-        res.append(f"num_buffers_warp_spec: {self.num_buffers_warp_spec}")
-        res.append(f"num_consumer_groups: {self.num_consumer_groups}")
-        res.append(f"reg_dec_producer: {self.reg_dec_producer}")
-        res.append(f"reg_inc_consumer: {self.reg_inc_consumer}")
         res.append(f"maxnreg: {self.maxnreg}")
         return ", ".join(res)
 
+    def __hash__(self):
+        return hash((*self.all_kwargs().items(), self.pre_hook))
+
+    def __eq__(self, other):
+        self_tuple = tuple((
+            *self.all_kwargs().items(),
+            self.pre_hook,
+        ))
+        other_tuple = tuple((
+            *other.all_kwargs().items(),
+            other.pre_hook,
+        ))
+        return self_tuple == other_tuple
+
 
 def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_value=None, pre_hook=None, post_hook=None,
-             warmup=None, rep=None, use_cuda_graph=False, do_bench=None):
+             warmup=None, rep=None, use_cuda_graph=False, do_bench=None, cache_results=False):
     """
     Decorator for auto-tuning a :code:`triton.jit`'d function.
 
@@ -372,12 +428,14 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_va
     :type rep: int
     :param do_bench: a benchmark function to measure the time of each run.
    :type do_bench: lambda fn, quantiles
+    :param cache_results: whether to cache autotune timings to disk. Defaults to False.
+    "type cache_results: bool
     """
 
     def decorator(fn):
         return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, pre_hook=pre_hook,
                          post_hook=post_hook, prune_configs_by=prune_configs_by, warmup=warmup, rep=rep,
-                         use_cuda_graph=use_cuda_graph, do_bench=do_bench)
+                         use_cuda_graph=use_cuda_graph, do_bench=do_bench, cache_results=cache_results)
 
     return decorator
 
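For orientation, the autotuner.py hunks above do three things: they drop the warp-specialization fields from Config (num_buffers_warp_spec, num_consumer_groups, reg_dec_producer, reg_inc_consumer), they make Config hashable, comparable and restorable (__hash__, __eq__, __setstate__) so timings can round-trip through the JSON disk cache written by check_disk_cache, and they thread a new cache_results flag (or the knobs.autotuning.cache knob) from autotune() down to that cache. Below is a minimal usage sketch, assuming the public triton.autotune and triton.Config wrappers forward to the classes shown above; the kernel and its arguments are invented.

import triton
import triton.language as tl


@triton.autotune(
    configs=[
        # 3.4.0 Config: only kwargs, num_warps, num_stages, num_ctas, maxnreg, pre_hook and ir_override remain.
        triton.Config({"BLOCK": 64}, num_warps=2, num_stages=2),
        triton.Config({"BLOCK": 128}, num_warps=4, num_stages=3),
    ],
    key=["n_elements"],
    cache_results=True,  # persist timings via check_disk_cache instead of re-benchmarking in every process
)
@triton.jit
def scale_kernel(x_ptr, y_ptr, n_elements, scale, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    tl.store(y_ptr + offs, tl.load(x_ptr + offs, mask=mask) * scale, mask=mask)

A Config that still passes the removed warp-specialization keywords would now fail with a TypeError in Config.__init__, which is the main source-level break to watch for when moving from 3.3.x to 3.4.0.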