triton-windows 3.4.0.post20__cp313-cp313-win_amd64.whl → 3.5.0.post21__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of triton-windows might be problematic.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +8 -2
- triton/_filecheck.py +24 -14
- triton/_internal_testing.py +70 -4
- triton/_utils.py +3 -1
- triton/backends/amd/compiler.py +68 -60
- triton/backends/amd/driver.c +113 -44
- triton/backends/amd/driver.py +133 -57
- triton/backends/driver.py +13 -0
- triton/backends/nvidia/compiler.py +80 -22
- triton/backends/nvidia/driver.c +88 -15
- triton/backends/nvidia/driver.py +130 -123
- triton/compiler/__init__.py +5 -2
- triton/compiler/code_generator.py +270 -163
- triton/compiler/compiler.py +45 -62
- triton/experimental/gluon/__init__.py +3 -2
- triton/experimental/gluon/_runtime.py +9 -6
- triton/experimental/gluon/language/__init__.py +117 -16
- triton/experimental/gluon/language/_core.py +246 -68
- triton/experimental/gluon/language/_layouts.py +398 -45
- triton/experimental/gluon/language/_math.py +17 -9
- triton/experimental/gluon/language/_semantic.py +130 -37
- triton/experimental/gluon/language/_standard.py +55 -22
- triton/experimental/gluon/language/amd/__init__.py +4 -0
- triton/experimental/gluon/language/amd/_layouts.py +96 -0
- triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
- triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
- triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
- triton/experimental/gluon/language/extra/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
- triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +192 -7
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +20 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +124 -3
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +20 -37
- triton/experimental/gluon/language/nvidia/hopper/tma.py +4 -3
- triton/experimental/gluon/nvidia/hopper.py +6 -1
- triton/knobs.py +132 -67
- triton/language/__init__.py +16 -10
- triton/language/core.py +163 -83
- triton/language/extra/cuda/gdc.py +6 -6
- triton/language/extra/hip/__init__.py +3 -1
- triton/language/extra/hip/libdevice.py +7 -0
- triton/language/extra/hip/utils.py +35 -0
- triton/language/extra/libdevice.py +4 -0
- triton/language/semantic.py +76 -23
- triton/language/standard.py +14 -14
- triton/language/target_info.py +54 -0
- triton/runtime/_allocation.py +15 -3
- triton/runtime/_async_compile.py +55 -0
- triton/runtime/autotuner.py +4 -5
- triton/runtime/build.py +11 -9
- triton/runtime/cache.py +44 -1
- triton/runtime/driver.py +16 -41
- triton/runtime/interpreter.py +31 -23
- triton/runtime/jit.py +318 -157
- triton/runtime/tcc/include/_mingw.h +8 -10
- triton/runtime/tcc/include/assert.h +5 -0
- triton/runtime/tcc/include/errno.h +1 -1
- triton/runtime/tcc/include/float.h +21 -3
- triton/runtime/tcc/include/iso646.h +36 -0
- triton/runtime/tcc/include/limits.h +5 -0
- triton/runtime/tcc/include/malloc.h +2 -2
- triton/runtime/tcc/include/math.h +21 -261
- triton/runtime/tcc/include/stdalign.h +16 -0
- triton/runtime/tcc/include/stdarg.h +5 -70
- triton/runtime/tcc/include/stdatomic.h +171 -0
- triton/runtime/tcc/include/stddef.h +7 -19
- triton/runtime/tcc/include/stdlib.h +15 -4
- triton/runtime/tcc/include/stdnoreturn.h +7 -0
- triton/runtime/tcc/include/sys/stat.h +2 -2
- triton/runtime/tcc/include/sys/types.h +5 -0
- triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
- triton/runtime/tcc/include/tccdefs.h +342 -0
- triton/runtime/tcc/include/tgmath.h +89 -0
- triton/runtime/tcc/include/uchar.h +33 -0
- triton/runtime/tcc/include/unistd.h +1 -0
- triton/runtime/tcc/include/winapi/qos.h +72 -0
- triton/runtime/tcc/include/winapi/shellapi.h +59 -0
- triton/runtime/tcc/include/winapi/winbase.h +9 -2
- triton/runtime/tcc/include/winapi/wincon.h +8 -0
- triton/runtime/tcc/include/winapi/windows.h +1 -1
- triton/runtime/tcc/include/winapi/winnls.h +778 -0
- triton/runtime/tcc/include/winapi/winnt.h +9 -7
- triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
- triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
- triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
- triton/runtime/tcc/lib/libtcc1.a +0 -0
- triton/runtime/tcc/lib/python314.def +1800 -0
- triton/runtime/tcc/lib/python314t.def +1809 -0
- triton/runtime/tcc/libtcc.dll +0 -0
- triton/runtime/tcc/tcc.exe +0 -0
- triton/tools/compile.py +62 -14
- triton/tools/extra/cuda/compile.c +1 -0
- triton/tools/extra/hip/compile.cpp +66 -0
- triton/tools/extra/hip/compile.h +13 -0
- triton/tools/ragged_tma.py +92 -0
- triton/tools/tensor_descriptor.py +7 -9
- triton/windows_utils.py +42 -79
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +3 -4
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/RECORD +106 -75
- triton/runtime/tcc/lib/libtcc1-64.a +0 -0
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/entry_points.txt +0 -0
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/licenses/LICENSE +0 -0
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/top_level.txt +0 -0
triton/experimental/gluon/language/nvidia/hopper/__init__.py CHANGED

@@ -1,11 +1,132 @@
-from
-from . import
+from __future__ import annotations
+from triton.compiler.code_generator import unflatten_ir_values
+from ..ampere import async_copy
+from . import mbarrier, tma
 from ... import _core
 
-
+from typing import List, Tuple, TYPE_CHECKING
+if TYPE_CHECKING:
+    from triton._C.libtriton import ir
+
+__all__ = ["async_copy", "fence_async_shared", "mbarrier", "tma", "warpgroup_mma", "warpgroup_mma_wait"]
 
 
 @_core.builtin
 def fence_async_shared(cluster=False, _semantic=None):
+    """
+    Issue a fence to complete asynchronous shared memory operations.
+
+    Args:
+        cluster (bool): Whether to fence across cluster. Defaults to False.
+    """
     cluster = _core._unwrap_if_constexpr(cluster)
     _semantic.builder.create_fence_async_shared(cluster)
+
+
+class warpgroup_mma_accumulator_type(_core.base_type):
+    tensor_type: _core.dtype
+
+    def __init__(self, tensor_type: _core.dtype):
+        self.tensor_type = tensor_type
+
+    def __str__(self) -> str:
+        return f"warpgroup_mma_accumulator<{self.tensor_type}>"
+
+    def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[warpgroup_mma_accumulator, int]:
+        return warpgroup_mma_accumulator(handles[cursor], self.tensor_type), cursor + 1
+
+    def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
+        self.tensor_type._flatten_ir_types(builder, out)
+
+    def __eq__(self, other) -> bool:
+        return type(self) is type(other) and self.tensor_type == other.tensor_type
+
+    def mangle(self) -> str:
+        return f"FT{self.tensor_type.mangle()}FT"
+
+
+class warpgroup_mma_accumulator(_core.base_value):
+    handle: ir.value
+    type: warpgroup_mma_accumulator_type
+
+    def __init__(self, handle, tensor_type: _core.dtype):
+        self.handle = handle
+        self.type = warpgroup_mma_accumulator_type(tensor_type)
+
+    def _flatten_ir(self, handles: List[ir.value]) -> None:
+        handles.append(self.handle)
+
+
+@_core.builtin
+def warpgroup_mma_init(value, _semantic):
+    assert isinstance(value, _core.tensor)
+    return warpgroup_mma_accumulator(value.handle, value.type)
+
+
+@_core.builtin
+def warpgroup_mma(a, b, acc, *, use_acc=True, precision=None, max_num_imprecise_acc=None, is_async=False,
+                  _semantic=None):
+    """
+    Perform warpgroup MMA (Tensor Core) operations.
+    acc = a * b + (acc if use_acc else 0)
+
+    Args:
+        a (tensor or shared_memory_descriptor): Left hand side operand.
+        b (shared_memory_descriptor): Right hand side operand.
+        acc (tensor): Accumulator tensor.
+        use_acc (bool): Whether to use the initial value of the accumulator. Defaults to True.
+        precision (str, optional): Dot input precision. Defaults to builder default.
+        max_num_imprecise_acc (int): Max imprecise accumulations. Used for fp8 -> fp32 dot. Determines how many accumulation are done in limited precision. Defaults to None, which means no upcasting is done.
+        is_async (bool): Whether operation is asynchronous. Defaults to False.
+
+    Returns:
+        tensor or warpgroup_mma_accumulator: Returns the result if synchronous, or a token to load the value once computed if asynchronous.
+    """
+    use_acc = _semantic.to_tensor(use_acc)
+
+    if precision is None:
+        precision = _semantic.builder.options.default_dot_input_precision
+
+    precision = _semantic._str_to_dot_input_precision(precision)
+
+    K = a.type.shape[-1]
+    if max_num_imprecise_acc is None:
+        if a.dtype.is_fp8() and b.dtype.is_fp8():
+            max_num_imprecise_acc = _semantic.builder.options.max_num_imprecise_acc_default
+        else:
+            max_num_imprecise_acc = 0
+    else:
+        if a.dtype.is_fp8() and b.dtype.is_fp8() and max_num_imprecise_acc > K:
+            raise ValueError(f"max_num_imprecise_acc ({max_num_imprecise_acc}) must be <= K ({K})")
+
+    max_num_imprecise_acc = _core._unwrap_if_constexpr(max_num_imprecise_acc)
+    is_async = _core._unwrap_if_constexpr(is_async)
+
+    handle = _semantic.builder.create_warpgroup_mma(a.handle, b.handle, acc.handle, use_acc.handle, precision,
+                                                    max_num_imprecise_acc, is_async)
+    tensor_ty = acc.type.tensor_type if isinstance(acc, warpgroup_mma_accumulator) else acc.type
+    if is_async:
+        return warpgroup_mma_accumulator(handle, tensor_ty)
+    else:
+        return _core.tensor(handle, tensor_ty)
+
+
+@_core.builtin
+def warpgroup_mma_wait(num_outstanding=0, deps=None, _semantic=None):
+    """
+    Wait until `num_outstanding` or less warpgroup MMA operations are in-flight.
+
+    Args:
+        num_outstanding (int): Number of outstanding warpgroup MMA operations to wait for. Defaults to 0.
+        deps (Sequence[tensor]): List of dependencies that need to be kept alive while the mma is unfinished.
+    """
+    if deps is None:
+        raise ValueError("warpgroup_mma_wait deps must be given")
+    deps_handles = [x.handle for x in deps] if deps is not None else []
+    num_outstanding = _core._unwrap_if_constexpr(num_outstanding)
+    results = _semantic.builder.create_warpgroup_mma_wait(deps_handles, num_outstanding)
+    result_types = [dep.type.tensor_type if isinstance(dep, warpgroup_mma_accumulator) else dep.type for dep in deps]
+    results = unflatten_ir_values(results, result_types)
+    if len(deps) == 1:
+        return next(results)
+    return tuple(results)
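The hunk above adds both a synchronous and an asynchronous MMA path for Hopper. A minimal, hedged sketch of the intended call sequence, assuming a Gluon kernel in which `a_smem` and `b_smem` are shared-memory descriptors and `acc` is a distributed accumulator tensor created elsewhere (none of that setup is shown in this hunk):

    from triton.experimental.gluon.language.nvidia import hopper

    # Asynchronous path: returns a warpgroup_mma_accumulator token instead of a tensor.
    token = hopper.warpgroup_mma(a_smem, b_smem, acc, is_async=True)
    # ... independent work can overlap with the in-flight MMA here ...
    # Wait until 0 MMAs remain outstanding; deps keeps the token alive and
    # the call yields the materialized accumulator tensor.
    acc = hopper.warpgroup_mma_wait(num_outstanding=0, deps=[token])

With `is_async=False` (the default) the call returns the result tensor directly.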
triton/experimental/gluon/language/nvidia/hopper/mbarrier.py CHANGED

@@ -1,51 +1,34 @@
-from
-from
+from ..ampere.mbarrier import MBarrierLayout, init, invalidate, wait
+from ..._core import _unwrap_if_constexpr, builtin
 
-__all__ = ["
-
-
-class MBarrierLayout(SwizzledSharedLayout):
-
-    def __init__(self, ctas_per_cga: int = 1, cta_split_num: int = 1):
-        super().__init__(
-            vec=1,
-            per_phase=1,
-            max_phase=1,
-            order=[0],
-            ctas_per_cga=[ctas_per_cga],
-            cta_split_num=[cta_split_num],
-            cta_order=[0],
-        )
-
-
-@builtin
-def init(mbarrier, count, _semantic=None):
-    count = _unwrap_if_constexpr(count)
-    _semantic.builder.create_mbarrier_init(mbarrier.handle, count)
-
-
-@builtin
-def invalidate(mbarrier, _semantic=None):
-    _semantic.builder.create_mbarrier_inval(mbarrier.handle)
+__all__ = ["arrive", "expect", "init", "invalidate", "MBarrierLayout", "wait"]
 
 
 @builtin
 def expect(mbarrier, bytes, pred=True, _semantic=None):
+    """
+    Expect a specific number of bytes being copied. When they are copied, the barrier is signaled.
+
+    Args:
+        mbarrier (shared_memory_descriptor): Barrier that will be signaled when the operation is complete.
+        bytes (int): Expected byte count.
+        pred (bool): Scalar predicate. Operation is skipped if predicate is False. Defaults to True.
+    """
     bytes = _unwrap_if_constexpr(bytes)
     pred = _semantic.to_tensor(pred)
     _semantic.builder.create_mbarrier_expect(mbarrier.handle, bytes, pred.handle)
 
 
 @builtin
-def
-
-
-
-
-
-
-
-
+def arrive(mbarrier, *, count=1, pred=True, _semantic=None):
+    """
+    Arrive at an mbarrier with a specified count.
+
+    Args:
+        mbarrier (shared_memory_descriptor): Barrier to be signalled.
+        count (int): Count to arrive with. Defaults to 1.
+        pred (bool): Scalar predicate. Operation is skipped if predicate is False. Defaults to True.
+    """
    count = _unwrap_if_constexpr(count)
     pred = _semantic.to_tensor(pred)
     _semantic.builder.create_mbarrier_arrive(mbarrier.handle, count, pred.handle)
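For orientation, a hedged sketch of how the helpers re-exported and documented above are meant to be sequenced inside a Gluon kernel; the barrier allocation and the asynchronous copy are placeholders for APIs defined outside this hunk, and the exact `wait` signature comes from the `ampere` module rather than this file, so it is assumed here:

    from triton.experimental.gluon.language.nvidia.hopper import mbarrier

    # bar: a shared-memory descriptor allocated with mbarrier.MBarrierLayout() (allocation not shown)
    mbarrier.init(bar, count=1)              # one arrival completes a phase
    mbarrier.expect(bar, bytes=tile_bytes)   # announce how many bytes the async copy will deposit
    # ... issue an asynchronous copy that signals `bar` on completion ...
    mbarrier.wait(bar, phase)                # consumer blocks until the phase flips (signature assumed)
    mbarrier.invalidate(bar)                 # release the barrier when done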
triton/experimental/gluon/language/nvidia/hopper/tma.py CHANGED

@@ -1,6 +1,7 @@
 from __future__ import annotations
 from typing import List, Tuple, TYPE_CHECKING
 from dataclasses import dataclass
+from triton.language.core import base_type, base_value
 import triton.experimental.gluon.language._core as ttgl
 from triton.experimental.gluon.language._layouts import NVMMASharedLayout
 from triton.experimental.gluon.language._core import builtin, _unwrap_if_constexpr
@@ -12,7 +13,7 @@ __all__ = ["async_copy_global_to_shared", "async_copy_shared_to_global", "store_
 
 
 @dataclass(eq=True)
-class tensor_descriptor_type:
+class tensor_descriptor_type(base_type):
     block_type: ttgl.block_type
     shape_type: ttgl.tuple_type
     strides_type: ttgl.tuple_type
@@ -41,10 +42,10 @@ class tensor_descriptor_type:
         self.strides_type._flatten_ir_types(builder, out)
 
     def mangle(self) -> str:
-        return f"TD{self.block_type.mangle}_{self.layout.mangle()}TD"
+        return f"TD{self.block_type.mangle()}_{self.layout.mangle()}TD"
 
 
-class tensor_descriptor:
+class tensor_descriptor(base_value):
 
     def __init__(self, handle, shape: List[ttgl.tensor], strides: List[ttgl.tensor], block_type: ttgl.block_type,
                  layout: NVMMASharedLayout):
triton/experimental/gluon/nvidia/hopper.py CHANGED

@@ -13,6 +13,7 @@ class TensorDescriptor:
     strides: List[int]
     block_shape: List[int]
     layout: NVMMASharedLayout
+    padding: str = "zero"
 
     def __post_init__(self):
         rank = len(self.shape)
@@ -28,13 +29,17 @@ class TensorDescriptor:
             assert (stride * elem_bytes) % 16 == 0, "strides must be 16-byte aligned"
         assert self.strides[-1] == 1, "Last dimension must be contiguous"
         assert isinstance(self.layout, NVMMASharedLayout), "Layout must be NVMMASharedLayout"
+        assert self.padding == "zero" or self.padding == "nan", "Illegal value for padding"
+        if self.padding == "nan":
+            assert self.base.dtype.is_floating_point, "Padding option `nan` is only supported for floating point tensors"
 
     @staticmethod
-    def from_tensor(tensor: Any, block_shape: List[int], layout: NVMMASharedLayout):
+    def from_tensor(tensor: Any, block_shape: List[int], layout: NVMMASharedLayout, padding="zero"):
         return TensorDescriptor(
             tensor,
             tensor.shape,
             tensor.stride(),
             block_shape,
             layout,
+            padding,
         )
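The new `padding` field is the only user-visible change here. A hedged host-side sketch, where `x` stands for a floating-point tensor and `layout` for an NVMMASharedLayout value (both placeholders, not part of this diff):

    # padding="nan" fills out-of-bounds reads through this descriptor with NaN
    # instead of zero; per the new assert it is only legal for floating point tensors.
    desc = TensorDescriptor.from_tensor(x, block_shape=[64, 64], layout=layout, padding="nan")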
triton/knobs.py CHANGED

@@ -1,15 +1,19 @@
 from __future__ import annotations
 
+import functools
 import importlib
 import os
 import re
 import subprocess
 import sysconfig
+import warnings
 
 from dataclasses import dataclass
 from contextlib import contextmanager
 from typing import cast, Any, Callable, Generator, Generic, Optional, Protocol, Type, TypeVar, TypedDict, TYPE_CHECKING, Union
 
+from triton._C.libtriton import getenv, getenv_bool # type: ignore
+
 if TYPE_CHECKING:
     from .runtime.cache import CacheManager, RemoteCacheBackend
     from .runtime.jit import JitFunctionInfo, KernelParam
@@ -25,11 +29,6 @@ env = Env()
 propagate_env: bool = True
 
 
-def getenv(key: str) -> Optional[str]:
-    res = os.getenv(key)
-    return res.strip() if res is not None else res
-
-
 def setenv(key: str, value: Optional[str]) -> None:
     if not propagate_env:
         return
@@ -62,32 +61,25 @@ def toenv(val: Any) -> Union[None, tuple[Optional[str]]]:
 SetType = TypeVar("SetType")
 GetType = TypeVar("GetType")
 
+_NOTHING = object()
+
 
 class env_base(Generic[SetType, GetType]):
 
-    def __init__(self, key: str
+    def __init__(self, key: str) -> None:
         self.key = key
-        self.default: Callable[[], SetType] = default if callable(default) else lambda: default
 
     def __set_name__(self, objclass: Type[object], name: str) -> None:
         self.name = name
 
     def __get__(self, obj: Optional[object], objclass: Optional[Type[object]]) -> GetType:
-
-
-
-        if self.name in obj.__dict__:
-            return self.transform(obj.__dict__[self.name])
-        else:
+        py_val = obj.__dict__.get(self.name, _NOTHING)
+        if py_val is _NOTHING:
             return self.get()
-
-    @property
-    def env_val(self) -> str | None:
-        return getenv(self.key)
+        return self.transform(py_val)
 
     def get(self) -> GetType:
-
-        return self.transform(self.default() if env is None else self.from_env(env))
+        raise NotImplementedError()
 
     def __set__(self, obj: object, value: Union[SetType, Env]) -> None:
         if isinstance(value, Env):
@@ -105,54 +97,70 @@ class env_base(Generic[SetType, GetType]):
         # if GetType != SetType.
         return cast(GetType, val)
 
-    def from_env(self, val: str) -> SetType:
-        raise NotImplementedError()
-
 
 class env_str(env_base[str, str]):
 
-    def
-
+    def __init__(self, key: str, default: str):
+        super().__init__(key)
+        self.default = default
+
+    def get(self) -> str:
+        return getenv(self.key, self.default)
+
+
+class env_str_callable_default(env_base[str, str]):
+
+    def __init__(self, key: str, default_factory: Callable[[], str]):
+        super().__init__(key)
+        self.default_factory = default_factory
+
+    def get(self) -> str:
+        env_val = getenv(self.key)
+        if env_val is None:
+            return self.default_factory()
+        return env_val
 
 
 class env_bool(env_base[bool, bool]):
 
-    def __init__(self, key: str, default:
-        super().__init__(key
+    def __init__(self, key: str, default: bool = False) -> None:
+        super().__init__(key)
+        self.default = default
 
-    def
-        return
+    def get(self) -> bool:
+        return getenv_bool(self.key, self.default)
 
 
 class env_int(env_base[int, int]):
 
-    def __init__(self, key: str, default:
-        super().__init__(key
+    def __init__(self, key: str, default: int = 0) -> None:
+        super().__init__(key)
+        self.default = default
 
-    def
+    def get(self) -> int:
+        val = getenv(self.key)
+        if val is None:
+            return self.default
         try:
             return int(val)
         except ValueError as exc:
             raise RuntimeError(f"Unable to use {self.key}={val}: expected int") from exc
 
 
-class env_opt_base(Generic[GetType, SetType], env_base[Optional[GetType], Optional[SetType]]):
-
-    def __init__(self, key: str) -> None:
-        super().__init__(key, None)
-
-
 ClassType = TypeVar("ClassType")
 
 
-class env_class(Generic[ClassType],
+class env_class(Generic[ClassType], env_base[Optional[Type[ClassType]], Optional[Type[ClassType]]]):
 
     def __init__(self, key: str, type: str) -> None:
         super().__init__(key)
         # We can't pass the type directly to avoid import cycles
         self.type = type
 
-    def
+    def get(self) -> Optional[Type[ClassType]]:
+        val = getenv(self.key)
+        if val is None:
+            return None
         comps = val.split(":", 1)
         if len(comps) != 2:
             raise RuntimeError(f"Unable to read {self.key}: '{val}' isn't of the form MODULE:CLASS")
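In short, each knob descriptor now answers reads in two steps: a value explicitly assigned on the knobs object wins, otherwise `get()` consults the environment through the native `getenv`/`getenv_bool` helpers. A hedged sketch of the resulting behaviour (the knob name used is defined in a later hunk of this same diff):

    from triton import knobs

    print(knobs.nvidia.dump_ptxas_log)   # falls back to TRITON_DUMP_PTXAS_LOG, False if unset
    knobs.nvidia.dump_ptxas_log = True   # explicit assignment overrides the environment variable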
@@ -170,16 +178,15 @@ class NvidiaTool:
     version: str
 
     @staticmethod
+    @functools.lru_cache
     def from_path(path: str) -> Optional[NvidiaTool]:
         try:
             result = subprocess.check_output([path, "--version"], stderr=subprocess.STDOUT)
-            if result is None:
-                return None
             version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE)
             if version is None:
                 return None
             return NvidiaTool(path, version.group(1))
-        except subprocess.CalledProcessError:
+        except (subprocess.CalledProcessError, FileNotFoundError):
             return None
 
 
@@ -202,6 +209,7 @@ def find_nvidia_tool(binary: str) -> str:
         if os.access(path, os.X_OK):
             return path
 
+    warnings.warn(f"Failed to find {binary}")
     return ""
 
 
@@ -210,34 +218,38 @@ class env_nvidia_tool(env_base[str, NvidiaTool]):
     def __init__(self, binary: str) -> None:
         binary += sysconfig.get_config_var("EXE")
         self.binary = binary
-
+        self.default_path = find_nvidia_tool(binary)
+        super().__init__(f"TRITON_{binary.upper()}_PATH")
+
+    def get(self) -> NvidiaTool:
+        return self.transform(getenv(self.key))
 
     def transform(self, path: str) -> NvidiaTool:
-
-
-
-
-
-
+        # We still add default as fallback in case the pointed binary isn't
+        # accessible.
+        if path is not None:
+            paths = [path, self.default_path]
+        else:
+            paths = [self.default_path]
+
         for path in paths:
-            if not path or not os.access(path, os.X_OK):
-                continue
             if tool := NvidiaTool.from_path(path):
                 return tool
 
         raise RuntimeError(f"Cannot find {self.binary}")
 
-    def from_env(self, val: str) -> str:
-        return val
-
 
 # Separate classes so that types are correct
-class env_opt_str(
-
+class env_opt_str(env_base[Optional[str], Optional[str]]):
+
+    def get(self) -> Optional[str]:
+        return getenv(self.key)
 
 
-class env_opt_bool(
-
+class env_opt_bool(env_base):
+
+    def get(self) -> Optional[str]:
+        return getenv_bool(self.key, None)
 
 
 @dataclass(frozen=True)
@@ -305,7 +317,7 @@ class base_knobs:
     @contextmanager
     def scope(self) -> Generator[None, None, None]:
         try:
-            initial_env = {knob.key: knob.
+            initial_env = {knob.key: getenv(knob.key) for knob in self.knob_descriptors.values()}
             orig = dict(self.__dict__)
             yield
         finally:
@@ -350,11 +362,11 @@ cache: cache_knobs
 
 
 class cache_knobs(base_knobs):
-    home_dir: env_str = env_str("TRITON_HOME",
+    home_dir: env_str = env_str("TRITON_HOME", os.path.expanduser("~/"))
 
-    dump_dir
-    override_dir
-    dir
+    dump_dir = env_str_callable_default("TRITON_DUMP_DIR", lambda: cache.get_triton_dir("dump"))
+    override_dir = env_str_callable_default("TRITON_OVERRIDE_DIR", lambda: cache.get_triton_dir("override"))
+    dir = env_str_callable_default("TRITON_CACHE_DIR", lambda: cache.get_triton_dir("cache"))
 
     manager_class: env_class[CacheManager] = env_class("TRITON_CACHE_MANAGER", "CacheManager")
     remote_manager_class: env_class[RemoteCacheBackend] = env_class("TRITON_REMOTE_CACHE_BACKEND", "RemoteCacheBackend")
@@ -374,6 +386,7 @@ class compilation_knobs(base_knobs):
     disable_line_info: env_bool = env_bool("TRITON_DISABLE_LINE_INFO")
     front_end_debugging: env_bool = env_bool("TRITON_FRONT_END_DEBUGGING")
     allow_non_constexpr_globals: env_bool = env_bool("TRITON_ALLOW_NON_CONSTEXPR_GLOBALS")
+    enable_experimental_consan: env_bool = env_bool("TRITON_ENABLE_EXPERIMENTAL_CONSAN")
     listener: Union[CompilationListener, None] = None
 
 
@@ -383,11 +396,53 @@ class autotuning_knobs(base_knobs):
 
 
 class LaunchHook(Protocol):
+    """Hook invoked before and after kernel launching
+    """
 
     def __call__(self, metadata: LazyDict) -> None:
         ...
 
 
+class InitHandleHook(Protocol):
+    """Hook invoked around kernel binary/module loading.
+    module/function can be None for the *start* hook (before loading).
+    """
+
+    def __call__(
+        self,
+        module: Optional[object],
+        function: Optional[Callable],
+        name: str,
+        metadata_group: dict[str, str],
+        hash: str,
+    ) -> None:
+        ...
+
+
+F = TypeVar("F", bound=Callable)
+
+
+class HookChain(Generic[F]):
+    """A chain of hooks of the same type F to be called in order.
+    """
+
+    def __init__(self, reversed: bool = False):
+        self.calls: list[F] = []
+        self.reversed = reversed
+
+    def add(self, func: F) -> None:
+        if func not in self.calls:
+            self.calls.append(func)
+
+    def remove(self, func: F) -> None:
+        if func in self.calls:
+            self.calls.remove(func)
+
+    def __call__(self, *args, **kwargs):
+        for call in self.calls if not self.reversed else reversed(self.calls):
+            call(*args, **kwargs)
+
+
 # This is of the form [attr_name, attr_val]
 # TODO: Use tuple instead of list for better typing.
 KernelAttr = list[Union[str, int]]
@@ -418,11 +473,15 @@ class JITHook(Protocol):
 
 class runtime_knobs(base_knobs):
     interpret: env_bool = env_bool("TRITON_INTERPRET")
-    debug
+    # debug is on critical path for kernel launches
+    # avoid repeated reads from env-var by calling get directly
+    debug: bool = env_bool("TRITON_DEBUG").get()
     override_arch: env_opt_str = env_opt_str("TRITON_OVERRIDE_ARCH")
 
-    launch_enter_hook:
-    launch_exit_hook:
+    launch_enter_hook: HookChain[LaunchHook] = HookChain()
+    launch_exit_hook: HookChain[LaunchHook] = HookChain(reversed=True)
+    kernel_load_start_hook: HookChain[InitHandleHook] = HookChain()
+    kernel_load_end_hook: HookChain[InitHandleHook] = HookChain(reversed=True)
 
     # Hook for inspecting compiled functions and modules
     jit_cache_hook: Optional[JITHook] = None
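Since the launch hooks are now `HookChain` instances, registration is additive rather than a single assignment. A hedged sketch of attaching and later removing a launch hook (the hook body itself is only illustrative):

    from triton import knobs

    def on_launch(metadata):
        # LaunchHook protocol: called with a LazyDict of launch metadata.
        print("kernel launch", metadata)

    knobs.runtime.launch_enter_hook.add(on_launch)   # enter hooks run in registration order
    knobs.runtime.launch_exit_hook.add(on_launch)    # exit hooks run in reverse order
    # ... later ...
    knobs.runtime.launch_enter_hook.remove(on_launch)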
@@ -444,6 +503,7 @@ class nvidia_knobs(base_knobs):
     dump_nvptx: env_bool = env_bool("NVPTX_ENABLE_DUMP")
     disable_ptxas_opt: env_bool = env_bool("DISABLE_PTXAS_OPT")
     mock_ptx_version: env_opt_str = env_opt_str("TRITON_MOCK_PTX_VERSION")
+    dump_ptxas_log: env_bool = env_bool("TRITON_DUMP_PTXAS_LOG")
 
     libdevice_path: env_opt_str = env_opt_str("TRITON_LIBDEVICE_PATH")
     libcuda_path: env_opt_str = env_opt_str("TRITON_LIBCUDA_PATH")
@@ -451,9 +511,10 @@
 
 class amd_knobs(base_knobs):
     use_buffer_ops: env_bool = env_bool("AMDGCN_USE_BUFFER_OPS")
+    # Note: This requires use_buffer_ops be true to have any effect
+    use_buffer_atomics: env_bool = env_bool("AMDGCN_USE_BUFFER_ATOMICS", True)
     dump_amdgcn: env_bool = env_bool("AMDGCN_ENABLE_DUMP")
     libhip_path: env_opt_str = env_opt_str("TRITON_LIBHIP_PATH")
-    lld_path: env_opt_str = env_opt_str("TRITON_HIP_LLD_PATH")
 
     # We use strs so that we can have a default value based on other runtime info
    use_block_pingpong: env_opt_bool = env_opt_bool("TRITON_HIP_USE_BLOCK_PINGPONG")
@@ -479,3 +540,7 @@ language = language_knobs()
 nvidia = nvidia_knobs()
 amd = amd_knobs()
 proton = proton_knobs()
+
+
+def refresh_knobs():
+    runtime.debug = env_bool("TRITON_DEBUG").get()
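Because `runtime.debug` is now snapshotted once at import time (see the `runtime_knobs` hunk above), changing `TRITON_DEBUG` later in the process requires the new `refresh_knobs()` helper. A hedged usage sketch:

    import os
    from triton import knobs

    os.environ["TRITON_DEBUG"] = "1"
    knobs.refresh_knobs()   # re-reads TRITON_DEBUG into knobs.runtime.debug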