triton-windows 3.4.0.post20__cp310-cp310-win_amd64.whl → 3.5.0.post21__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of triton-windows might be problematic.

Files changed (107)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +8 -2
  3. triton/_filecheck.py +24 -14
  4. triton/_internal_testing.py +70 -4
  5. triton/_utils.py +3 -1
  6. triton/backends/amd/compiler.py +68 -60
  7. triton/backends/amd/driver.c +113 -44
  8. triton/backends/amd/driver.py +133 -57
  9. triton/backends/driver.py +13 -0
  10. triton/backends/nvidia/compiler.py +80 -22
  11. triton/backends/nvidia/driver.c +88 -15
  12. triton/backends/nvidia/driver.py +130 -123
  13. triton/compiler/__init__.py +5 -2
  14. triton/compiler/code_generator.py +270 -163
  15. triton/compiler/compiler.py +45 -62
  16. triton/experimental/gluon/__init__.py +3 -2
  17. triton/experimental/gluon/_runtime.py +9 -6
  18. triton/experimental/gluon/language/__init__.py +117 -16
  19. triton/experimental/gluon/language/_core.py +246 -68
  20. triton/experimental/gluon/language/_layouts.py +398 -45
  21. triton/experimental/gluon/language/_math.py +17 -9
  22. triton/experimental/gluon/language/_semantic.py +130 -37
  23. triton/experimental/gluon/language/_standard.py +55 -22
  24. triton/experimental/gluon/language/amd/__init__.py +4 -0
  25. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  26. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  27. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  28. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  29. triton/experimental/gluon/language/extra/__init__.py +3 -0
  30. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  31. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  32. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  33. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +192 -7
  34. triton/experimental/gluon/language/nvidia/blackwell/tma.py +20 -0
  35. triton/experimental/gluon/language/nvidia/hopper/__init__.py +124 -3
  36. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +20 -37
  37. triton/experimental/gluon/language/nvidia/hopper/tma.py +4 -3
  38. triton/experimental/gluon/nvidia/hopper.py +6 -1
  39. triton/knobs.py +132 -67
  40. triton/language/__init__.py +16 -10
  41. triton/language/core.py +163 -83
  42. triton/language/extra/cuda/gdc.py +6 -6
  43. triton/language/extra/hip/__init__.py +3 -1
  44. triton/language/extra/hip/libdevice.py +7 -0
  45. triton/language/extra/hip/utils.py +35 -0
  46. triton/language/extra/libdevice.py +4 -0
  47. triton/language/semantic.py +76 -23
  48. triton/language/standard.py +14 -14
  49. triton/language/target_info.py +54 -0
  50. triton/runtime/_allocation.py +15 -3
  51. triton/runtime/_async_compile.py +55 -0
  52. triton/runtime/autotuner.py +4 -5
  53. triton/runtime/build.py +11 -9
  54. triton/runtime/cache.py +44 -1
  55. triton/runtime/driver.py +16 -41
  56. triton/runtime/interpreter.py +31 -23
  57. triton/runtime/jit.py +318 -157
  58. triton/runtime/tcc/include/_mingw.h +8 -10
  59. triton/runtime/tcc/include/assert.h +5 -0
  60. triton/runtime/tcc/include/errno.h +1 -1
  61. triton/runtime/tcc/include/float.h +21 -3
  62. triton/runtime/tcc/include/iso646.h +36 -0
  63. triton/runtime/tcc/include/limits.h +5 -0
  64. triton/runtime/tcc/include/malloc.h +2 -2
  65. triton/runtime/tcc/include/math.h +21 -261
  66. triton/runtime/tcc/include/stdalign.h +16 -0
  67. triton/runtime/tcc/include/stdarg.h +5 -70
  68. triton/runtime/tcc/include/stdatomic.h +171 -0
  69. triton/runtime/tcc/include/stddef.h +7 -19
  70. triton/runtime/tcc/include/stdlib.h +15 -4
  71. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  72. triton/runtime/tcc/include/sys/stat.h +2 -2
  73. triton/runtime/tcc/include/sys/types.h +5 -0
  74. triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
  75. triton/runtime/tcc/include/tccdefs.h +342 -0
  76. triton/runtime/tcc/include/tgmath.h +89 -0
  77. triton/runtime/tcc/include/uchar.h +33 -0
  78. triton/runtime/tcc/include/unistd.h +1 -0
  79. triton/runtime/tcc/include/winapi/qos.h +72 -0
  80. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  81. triton/runtime/tcc/include/winapi/winbase.h +9 -2
  82. triton/runtime/tcc/include/winapi/wincon.h +8 -0
  83. triton/runtime/tcc/include/winapi/windows.h +1 -1
  84. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  85. triton/runtime/tcc/include/winapi/winnt.h +9 -7
  86. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  87. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  88. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  89. triton/runtime/tcc/lib/libtcc1.a +0 -0
  90. triton/runtime/tcc/lib/python314.def +1800 -0
  91. triton/runtime/tcc/lib/python314t.def +1809 -0
  92. triton/runtime/tcc/libtcc.dll +0 -0
  93. triton/runtime/tcc/tcc.exe +0 -0
  94. triton/tools/compile.py +62 -14
  95. triton/tools/extra/cuda/compile.c +1 -0
  96. triton/tools/extra/hip/compile.cpp +66 -0
  97. triton/tools/extra/hip/compile.h +13 -0
  98. triton/tools/ragged_tma.py +92 -0
  99. triton/tools/tensor_descriptor.py +7 -9
  100. triton/windows_utils.py +42 -79
  101. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +3 -4
  102. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/RECORD +106 -75
  103. triton/runtime/tcc/lib/libtcc1-64.a +0 -0
  104. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
  105. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/entry_points.txt +0 -0
  106. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/licenses/LICENSE +0 -0
  107. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/top_level.txt +0 -0
triton/experimental/gluon/language/nvidia/hopper/__init__.py CHANGED
@@ -1,11 +1,132 @@
-from . import mbarrier
-from . import tma
+from __future__ import annotations
+from triton.compiler.code_generator import unflatten_ir_values
+from ..ampere import async_copy
+from . import mbarrier, tma
 from ... import _core
 
-__all__ = ["fence_async_shared", "mbarrier", "tma"]
+from typing import List, Tuple, TYPE_CHECKING
+if TYPE_CHECKING:
+    from triton._C.libtriton import ir
+
+__all__ = ["async_copy", "fence_async_shared", "mbarrier", "tma", "warpgroup_mma", "warpgroup_mma_wait"]
 
 
 @_core.builtin
 def fence_async_shared(cluster=False, _semantic=None):
+    """
+    Issue a fence to complete asynchronous shared memory operations.
+
+    Args:
+        cluster (bool): Whether to fence across cluster. Defaults to False.
+    """
     cluster = _core._unwrap_if_constexpr(cluster)
     _semantic.builder.create_fence_async_shared(cluster)
+
+
+class warpgroup_mma_accumulator_type(_core.base_type):
+    tensor_type: _core.dtype
+
+    def __init__(self, tensor_type: _core.dtype):
+        self.tensor_type = tensor_type
+
+    def __str__(self) -> str:
+        return f"warpgroup_mma_accumulator<{self.tensor_type}>"
+
+    def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[warpgroup_mma_accumulator, int]:
+        return warpgroup_mma_accumulator(handles[cursor], self.tensor_type), cursor + 1
+
+    def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
+        self.tensor_type._flatten_ir_types(builder, out)
+
+    def __eq__(self, other) -> bool:
+        return type(self) is type(other) and self.tensor_type == other.tensor_type
+
+    def mangle(self) -> str:
+        return f"FT{self.tensor_type.mangle()}FT"
+
+
+class warpgroup_mma_accumulator(_core.base_value):
+    handle: ir.value
+    type: warpgroup_mma_accumulator_type
+
+    def __init__(self, handle, tensor_type: _core.dtype):
+        self.handle = handle
+        self.type = warpgroup_mma_accumulator_type(tensor_type)
+
+    def _flatten_ir(self, handles: List[ir.value]) -> None:
+        handles.append(self.handle)
+
+
+@_core.builtin
+def warpgroup_mma_init(value, _semantic):
+    assert isinstance(value, _core.tensor)
+    return warpgroup_mma_accumulator(value.handle, value.type)
+
+
+@_core.builtin
+def warpgroup_mma(a, b, acc, *, use_acc=True, precision=None, max_num_imprecise_acc=None, is_async=False,
+                  _semantic=None):
+    """
+    Perform warpgroup MMA (Tensor Core) operations.
+    acc = a * b + (acc if use_acc else 0)
+
+    Args:
+        a (tensor or shared_memory_descriptor): Left hand side operand.
+        b (shared_memory_descriptor): Right hand side operand.
+        acc (tensor): Accumulator tensor.
+        use_acc (bool): Whether to use the initial value of the accumulator. Defaults to True.
+        precision (str, optional): Dot input precision. Defaults to builder default.
+        max_num_imprecise_acc (int): Max imprecise accumulations. Used for fp8 -> fp32 dot. Determines how many accumulation are done in limited precision. Defaults to None, which means no upcasting is done.
+        is_async (bool): Whether operation is asynchronous. Defaults to False.
+
+    Returns:
+        tensor or warpgroup_mma_accumulator: Returns the result if synchronous, or a token to load the value once computed if asynchronous.
+    """
+    use_acc = _semantic.to_tensor(use_acc)
+
+    if precision is None:
+        precision = _semantic.builder.options.default_dot_input_precision
+
+    precision = _semantic._str_to_dot_input_precision(precision)
+
+    K = a.type.shape[-1]
+    if max_num_imprecise_acc is None:
+        if a.dtype.is_fp8() and b.dtype.is_fp8():
+            max_num_imprecise_acc = _semantic.builder.options.max_num_imprecise_acc_default
+        else:
+            max_num_imprecise_acc = 0
+    else:
+        if a.dtype.is_fp8() and b.dtype.is_fp8() and max_num_imprecise_acc > K:
+            raise ValueError(f"max_num_imprecise_acc ({max_num_imprecise_acc}) must be <= K ({K})")
+
+    max_num_imprecise_acc = _core._unwrap_if_constexpr(max_num_imprecise_acc)
+    is_async = _core._unwrap_if_constexpr(is_async)
+
+    handle = _semantic.builder.create_warpgroup_mma(a.handle, b.handle, acc.handle, use_acc.handle, precision,
+                                                    max_num_imprecise_acc, is_async)
+    tensor_ty = acc.type.tensor_type if isinstance(acc, warpgroup_mma_accumulator) else acc.type
+    if is_async:
+        return warpgroup_mma_accumulator(handle, tensor_ty)
+    else:
+        return _core.tensor(handle, tensor_ty)
+
+
+@_core.builtin
+def warpgroup_mma_wait(num_outstanding=0, deps=None, _semantic=None):
+    """
+    Wait until `num_outstanding` or less warpgroup MMA operations are in-flight.
+
+    Args:
+        num_outstanding (int): Number of outstanding warpgroup MMA operations to wait for. Defaults to 0.
+        deps (Sequence[tensor]): List of dependencies that need to be kept alive while the mma is unfinished.
+    """
+    if deps is None:
+        raise ValueError("warpgroup_mma_wait deps must be given")
+    deps_handles = [x.handle for x in deps] if deps is not None else []
+    num_outstanding = _core._unwrap_if_constexpr(num_outstanding)
+    results = _semantic.builder.create_warpgroup_mma_wait(deps_handles, num_outstanding)
+    result_types = [dep.type.tensor_type if isinstance(dep, warpgroup_mma_accumulator) else dep.type for dep in deps]
+    results = unflatten_ir_values(results, result_types)
+    if len(deps) == 1:
+        return next(results)
+    return tuple(results)
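
For orientation, a minimal sketch of how the asynchronous warpgroup MMA API added above might be used. This is illustrative only: it shows the call sequence as it would appear inside a Gluon JIT kernel, and the kernel scaffolding (decorator, grid launch, shared-memory allocation and loads) plus the helper name are assumptions, not code from this package.

from triton.experimental.gluon.language.nvidia import hopper

def mma_pipeline_step(acc, a_smem, b_smem):
    # Start the MMA asynchronously; with is_async=True the builtin returns a
    # warpgroup_mma_accumulator token instead of a plain tensor.
    token = hopper.warpgroup_mma(a_smem, b_smem, acc, is_async=True)
    # ... independent work can overlap with the in-flight MMA here ...
    # Block until no MMAs remain outstanding; `deps` keeps the token alive and,
    # with a single dependency, the call returns the accumulated tensor directly.
    return hopper.warpgroup_mma_wait(num_outstanding=0, deps=[token])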
triton/experimental/gluon/language/nvidia/hopper/mbarrier.py CHANGED
@@ -1,51 +1,34 @@
-from triton.experimental.gluon.language._layouts import SwizzledSharedLayout
-from triton.experimental.gluon.language._core import builtin, _unwrap_if_constexpr
+from ..ampere.mbarrier import MBarrierLayout, init, invalidate, wait
+from ..._core import _unwrap_if_constexpr, builtin
 
-__all__ = ["MBarrierLayout", "init", "invalidate", "expect", "wait", "arrive"]
-
-
-class MBarrierLayout(SwizzledSharedLayout):
-
-    def __init__(self, ctas_per_cga: int = 1, cta_split_num: int = 1):
-        super().__init__(
-            vec=1,
-            per_phase=1,
-            max_phase=1,
-            order=[0],
-            ctas_per_cga=[ctas_per_cga],
-            cta_split_num=[cta_split_num],
-            cta_order=[0],
-        )
-
-
-@builtin
-def init(mbarrier, count, _semantic=None):
-    count = _unwrap_if_constexpr(count)
-    _semantic.builder.create_mbarrier_init(mbarrier.handle, count)
-
-
-@builtin
-def invalidate(mbarrier, _semantic=None):
-    _semantic.builder.create_mbarrier_inval(mbarrier.handle)
+__all__ = ["arrive", "expect", "init", "invalidate", "MBarrierLayout", "wait"]
 
 
 @builtin
 def expect(mbarrier, bytes, pred=True, _semantic=None):
+    """
+    Expect a specific number of bytes being copied. When they are copied, the barrier is signaled.
+
+    Args:
+        mbarrier (shared_memory_descriptor): Barrier that will be signaled when the operation is complete.
+        bytes (int): Expected byte count.
+        pred (bool): Scalar predicate. Operation is skipped if predicate is False. Defaults to True.
+    """
     bytes = _unwrap_if_constexpr(bytes)
     pred = _semantic.to_tensor(pred)
     _semantic.builder.create_mbarrier_expect(mbarrier.handle, bytes, pred.handle)
 
 
 @builtin
-def wait(mbarrier, phase, pred=True, deps=(), _semantic=None):
-    phase = _semantic.to_tensor(phase)
-    pred = _semantic.to_tensor(pred)
-    deps = [x.handle for x in deps]
-    _semantic.builder.create_mbarrier_wait(mbarrier.handle, phase.handle, pred.handle, deps)
-
-
-@builtin
-def arrive(mbarrier, count, pred=True, _semantic=None):
+def arrive(mbarrier, *, count=1, pred=True, _semantic=None):
+    """
+    Arrive at an mbarrier with a specified count.
+
+    Args:
+        mbarrier (shared_memory_descriptor): Barrier to be signalled.
+        count (int): Count to arrive with. Defaults to 1.
+        pred (bool): Scalar predicate. Operation is skipped if predicate is False. Defaults to True.
+    """
     count = _unwrap_if_constexpr(count)
     pred = _semantic.to_tensor(pred)
     _semantic.builder.create_mbarrier_arrive(mbarrier.handle, count, pred.handle)
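
A hedged sketch of the producer/consumer pattern these mbarrier builtins support, again as it would appear inside a Gluon kernel body; the barrier allocation, the asynchronous copy itself, and the `wait` signature (now re-exported from the Ampere module) are assumptions based on this diff, not verified against the wheel.

from triton.experimental.gluon.language.nvidia.hopper import mbarrier

def barrier_roundtrip(bar, smem_bytes, phase):
    mbarrier.init(bar, count=1)        # init/invalidate/wait now come from ..ampere.mbarrier
    mbarrier.expect(bar, smem_bytes)   # barrier is signaled once smem_bytes have been copied
    # ... issue the asynchronous copy that targets `bar` here ...
    mbarrier.wait(bar, phase)          # block until the barrier flips to the expected phase
    mbarrier.arrive(bar, count=1)      # note: count is keyword-only in the new signature
    mbarrier.invalidate(bar)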
triton/experimental/gluon/language/nvidia/hopper/tma.py CHANGED
@@ -1,6 +1,7 @@
 from __future__ import annotations
 from typing import List, Tuple, TYPE_CHECKING
 from dataclasses import dataclass
+from triton.language.core import base_type, base_value
 import triton.experimental.gluon.language._core as ttgl
 from triton.experimental.gluon.language._layouts import NVMMASharedLayout
 from triton.experimental.gluon.language._core import builtin, _unwrap_if_constexpr
@@ -12,7 +13,7 @@ __all__ = ["async_copy_global_to_shared", "async_copy_shared_to_global", "store_
 
 
 @dataclass(eq=True)
-class tensor_descriptor_type:
+class tensor_descriptor_type(base_type):
     block_type: ttgl.block_type
     shape_type: ttgl.tuple_type
     strides_type: ttgl.tuple_type
@@ -41,10 +42,10 @@ class tensor_descriptor_type:
         self.strides_type._flatten_ir_types(builder, out)
 
     def mangle(self) -> str:
-        return f"TD{self.block_type.mangle}_{self.layout.mangle()}TD"
+        return f"TD{self.block_type.mangle()}_{self.layout.mangle()}TD"
 
 
-class tensor_descriptor:
+class tensor_descriptor(base_value):
 
     def __init__(self, handle, shape: List[ttgl.tensor], strides: List[ttgl.tensor], block_type: ttgl.block_type,
                  layout: NVMMASharedLayout):
triton/experimental/gluon/nvidia/hopper.py CHANGED
@@ -13,6 +13,7 @@ class TensorDescriptor:
     strides: List[int]
     block_shape: List[int]
     layout: NVMMASharedLayout
+    padding: str = "zero"
 
     def __post_init__(self):
         rank = len(self.shape)
@@ -28,13 +29,17 @@ class TensorDescriptor:
             assert (stride * elem_bytes) % 16 == 0, "strides must be 16-byte aligned"
         assert self.strides[-1] == 1, "Last dimension must be contiguous"
         assert isinstance(self.layout, NVMMASharedLayout), "Layout must be NVMMASharedLayout"
+        assert self.padding == "zero" or self.padding == "nan", "Illegal value for padding"
+        if self.padding == "nan":
+            assert self.base.dtype.is_floating_point, "Padding option `nan` is only supported for floating point tensors"
 
     @staticmethod
-    def from_tensor(tensor: Any, block_shape: List[int], layout: NVMMASharedLayout):
+    def from_tensor(tensor: Any, block_shape: List[int], layout: NVMMASharedLayout, padding="zero"):
         return TensorDescriptor(
             tensor,
             tensor.shape,
             tensor.stride(),
             block_shape,
             layout,
+            padding,
         )
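
Assuming this hunk belongs to triton/experimental/gluon/nvidia/hopper.py (the +6/-1 entry in the file list above), a small sketch of the new padding option; the layout value and block shape below are placeholders chosen for illustration:

import torch
from triton.experimental.gluon.nvidia.hopper import TensorDescriptor

def make_descriptor(x: torch.Tensor, layout):
    # `layout` must be an NVMMASharedLayout instance; its parameters are omitted here.
    # padding="nan" is only accepted for floating-point tensors (see the assert above);
    # the default stays "zero", so existing callers are unaffected.
    return TensorDescriptor.from_tensor(x, block_shape=[64, 64], layout=layout, padding="nan")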
triton/knobs.py CHANGED
@@ -1,15 +1,19 @@
 from __future__ import annotations
 
+import functools
 import importlib
 import os
 import re
 import subprocess
 import sysconfig
+import warnings
 
 from dataclasses import dataclass
 from contextlib import contextmanager
 from typing import cast, Any, Callable, Generator, Generic, Optional, Protocol, Type, TypeVar, TypedDict, TYPE_CHECKING, Union
 
+from triton._C.libtriton import getenv, getenv_bool  # type: ignore
+
 if TYPE_CHECKING:
     from .runtime.cache import CacheManager, RemoteCacheBackend
     from .runtime.jit import JitFunctionInfo, KernelParam
@@ -25,11 +29,6 @@ env = Env()
 propagate_env: bool = True
 
 
-def getenv(key: str) -> Optional[str]:
-    res = os.getenv(key)
-    return res.strip() if res is not None else res
-
-
 def setenv(key: str, value: Optional[str]) -> None:
     if not propagate_env:
         return
@@ -62,32 +61,25 @@ def toenv(val: Any) -> Union[None, tuple[Optional[str]]]:
 SetType = TypeVar("SetType")
 GetType = TypeVar("GetType")
 
+_NOTHING = object()
+
 
 class env_base(Generic[SetType, GetType]):
 
-    def __init__(self, key: str, default: Union[SetType, Callable[[], SetType]]) -> None:
+    def __init__(self, key: str) -> None:
         self.key = key
-        self.default: Callable[[], SetType] = default if callable(default) else lambda: default
 
     def __set_name__(self, objclass: Type[object], name: str) -> None:
         self.name = name
 
     def __get__(self, obj: Optional[object], objclass: Optional[Type[object]]) -> GetType:
-        if obj is None:
-            raise AttributeError(f"Cannot access {type(self)} on non-instance")
-
-        if self.name in obj.__dict__:
-            return self.transform(obj.__dict__[self.name])
-        else:
+        py_val = obj.__dict__.get(self.name, _NOTHING)
+        if py_val is _NOTHING:
             return self.get()
-
-    @property
-    def env_val(self) -> str | None:
-        return getenv(self.key)
+        return self.transform(py_val)
 
     def get(self) -> GetType:
-        env = self.env_val
-        return self.transform(self.default() if env is None else self.from_env(env))
+        raise NotImplementedError()
 
     def __set__(self, obj: object, value: Union[SetType, Env]) -> None:
         if isinstance(value, Env):
@@ -105,54 +97,70 @@ class env_base(Generic[SetType, GetType]):
         # if GetType != SetType.
         return cast(GetType, val)
 
-    def from_env(self, val: str) -> SetType:
-        raise NotImplementedError()
-
 
 
 class env_str(env_base[str, str]):
 
-    def from_env(self, val: str) -> str:
-        return val
+    def __init__(self, key: str, default: str):
+        super().__init__(key)
+        self.default = default
+
+    def get(self) -> str:
+        return getenv(self.key, self.default)
+
+
+class env_str_callable_default(env_base[str, str]):
+
+    def __init__(self, key: str, default_factory: Callable[[], str]):
+        super().__init__(key)
+        self.default_factory = default_factory
+
+    def get(self) -> str:
+        env_val = getenv(self.key)
+        if env_val is None:
+            return self.default_factory()
+        return env_val
 
 
 class env_bool(env_base[bool, bool]):
 
-    def __init__(self, key: str, default: Union[bool, Callable[[], bool]] = False) -> None:
-        super().__init__(key, default)
+    def __init__(self, key: str, default: bool = False) -> None:
+        super().__init__(key)
+        self.default = default
 
-    def from_env(self, val: str) -> bool:
-        return val.lower() in ("1", "true", "yes", "on", "y")
+    def get(self) -> bool:
+        return getenv_bool(self.key, self.default)
 
 
 class env_int(env_base[int, int]):
 
-    def __init__(self, key: str, default: Union[int, Callable[[], int]] = 0) -> None:
-        super().__init__(key, default)
+    def __init__(self, key: str, default: int = 0) -> None:
+        super().__init__(key)
+        self.default = default
 
-    def from_env(self, val: str) -> int:
+    def get(self) -> int:
+        val = getenv(self.key)
+        if val is None:
+            return self.default
         try:
             return int(val)
         except ValueError as exc:
             raise RuntimeError(f"Unable to use {self.key}={val}: expected int") from exc
 
 
-class env_opt_base(Generic[GetType, SetType], env_base[Optional[GetType], Optional[SetType]]):
-
-    def __init__(self, key: str) -> None:
-        super().__init__(key, None)
-
-
 ClassType = TypeVar("ClassType")
 
 
-class env_class(Generic[ClassType], env_opt_base[Type[ClassType], Type[ClassType]]):
+class env_class(Generic[ClassType], env_base[Optional[Type[ClassType]], Optional[Type[ClassType]]]):
 
     def __init__(self, key: str, type: str) -> None:
         super().__init__(key)
         # We can't pass the type directly to avoid import cycles
         self.type = type
 
-    def from_env(self, val: str) -> Type[ClassType]:
+    def get(self) -> Optional[Type[ClassType]]:
+        val = getenv(self.key)
+        if val is None:
+            return None
         comps = val.split(":", 1)
         if len(comps) != 2:
             raise RuntimeError(f"Unable to read {self.key}: '{val}' isn't of the form MODULE:CLASS")
@@ -170,16 +178,15 @@ class NvidiaTool:
     version: str
 
     @staticmethod
+    @functools.lru_cache
     def from_path(path: str) -> Optional[NvidiaTool]:
         try:
             result = subprocess.check_output([path, "--version"], stderr=subprocess.STDOUT)
-            if result is None:
-                return None
             version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE)
             if version is None:
                 return None
             return NvidiaTool(path, version.group(1))
-        except subprocess.CalledProcessError:
+        except (subprocess.CalledProcessError, FileNotFoundError):
             return None
 
 
@@ -202,6 +209,7 @@ def find_nvidia_tool(binary: str) -> str:
         if os.access(path, os.X_OK):
             return path
 
+    warnings.warn(f"Failed to find {binary}")
    return ""
 
 
@@ -210,34 +218,38 @@ class env_nvidia_tool(env_base[str, NvidiaTool]):
     def __init__(self, binary: str) -> None:
         binary += sysconfig.get_config_var("EXE")
         self.binary = binary
-        super().__init__(f"TRITON_{binary.upper()}_PATH", lambda: find_nvidia_tool(self.binary))
+        self.default_path = find_nvidia_tool(binary)
+        super().__init__(f"TRITON_{binary.upper()}_PATH")
+
+    def get(self) -> NvidiaTool:
+        return self.transform(getenv(self.key))
 
     def transform(self, path: str) -> NvidiaTool:
-        paths = [
-            path,
-            # We still add default as fallback in case the pointed binary isn't
-            # accessible.
-            self.default(),
-        ]
+        # We still add default as fallback in case the pointed binary isn't
+        # accessible.
+        if path is not None:
+            paths = [path, self.default_path]
+        else:
+            paths = [self.default_path]
+
         for path in paths:
-            if not path or not os.access(path, os.X_OK):
-                continue
             if tool := NvidiaTool.from_path(path):
                 return tool
 
         raise RuntimeError(f"Cannot find {self.binary}")
 
-    def from_env(self, val: str) -> str:
-        return val
-
 
 # Separate classes so that types are correct
-class env_opt_str(env_opt_base[str, str], env_str):
-    pass
+class env_opt_str(env_base[Optional[str], Optional[str]]):
+
+    def get(self) -> Optional[str]:
+        return getenv(self.key)
 
 
-class env_opt_bool(env_opt_base[bool, bool], env_bool):
-    pass
+class env_opt_bool(env_base):
+
+    def get(self) -> Optional[str]:
+        return getenv_bool(self.key, None)
 
 
 @dataclass(frozen=True)
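
A short sketch of the precedence these env_* descriptors implement: an explicit assignment on a knobs object overrides the environment variable, assigning the module-level Env sentinel reverts to environment lookup, and scope() restores both. The knob names used below are real, but the behavior shown is inferred from this diff rather than tested against the released wheel:

import triton.knobs as knobs

knobs.amd.use_buffer_ops = True        # explicit value wins over AMDGCN_USE_BUFFER_OPS
assert knobs.amd.use_buffer_ops

knobs.amd.use_buffer_ops = knobs.env   # assigning the Env sentinel falls back to the env var / default

with knobs.amd.scope():                # scope() snapshots env vars and overrides, restoring them on exit
    knobs.amd.use_buffer_ops = False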
@@ -305,7 +317,7 @@ class base_knobs:
     @contextmanager
     def scope(self) -> Generator[None, None, None]:
         try:
-            initial_env = {knob.key: knob.env_val for knob in self.knob_descriptors.values()}
+            initial_env = {knob.key: getenv(knob.key) for knob in self.knob_descriptors.values()}
             orig = dict(self.__dict__)
             yield
         finally:
@@ -350,11 +362,11 @@ cache: cache_knobs
 
 
 class cache_knobs(base_knobs):
-    home_dir: env_str = env_str("TRITON_HOME", lambda: os.path.expanduser("~/"))
+    home_dir: env_str = env_str("TRITON_HOME", os.path.expanduser("~/"))
 
-    dump_dir: env_str = env_str("TRITON_DUMP_DIR", lambda: cache.get_triton_dir("dump"))
-    override_dir: env_str = env_str("TRITON_OVERRIDE_DIR", lambda: cache.get_triton_dir("override"))
-    dir: env_str = env_str("TRITON_CACHE_DIR", lambda: cache.get_triton_dir("cache"))
+    dump_dir = env_str_callable_default("TRITON_DUMP_DIR", lambda: cache.get_triton_dir("dump"))
+    override_dir = env_str_callable_default("TRITON_OVERRIDE_DIR", lambda: cache.get_triton_dir("override"))
+    dir = env_str_callable_default("TRITON_CACHE_DIR", lambda: cache.get_triton_dir("cache"))
 
     manager_class: env_class[CacheManager] = env_class("TRITON_CACHE_MANAGER", "CacheManager")
     remote_manager_class: env_class[RemoteCacheBackend] = env_class("TRITON_REMOTE_CACHE_BACKEND", "RemoteCacheBackend")
@@ -374,6 +386,7 @@ class compilation_knobs(base_knobs):
     disable_line_info: env_bool = env_bool("TRITON_DISABLE_LINE_INFO")
     front_end_debugging: env_bool = env_bool("TRITON_FRONT_END_DEBUGGING")
     allow_non_constexpr_globals: env_bool = env_bool("TRITON_ALLOW_NON_CONSTEXPR_GLOBALS")
+    enable_experimental_consan: env_bool = env_bool("TRITON_ENABLE_EXPERIMENTAL_CONSAN")
     listener: Union[CompilationListener, None] = None
 
 
@@ -383,11 +396,53 @@ class autotuning_knobs(base_knobs):
 
 
 class LaunchHook(Protocol):
+    """Hook invoked before and after kernel launching
+    """
 
     def __call__(self, metadata: LazyDict) -> None:
         ...
 
 
+class InitHandleHook(Protocol):
+    """Hook invoked around kernel binary/module loading.
+    module/function can be None for the *start* hook (before loading).
+    """
+
+    def __call__(
+        self,
+        module: Optional[object],
+        function: Optional[Callable],
+        name: str,
+        metadata_group: dict[str, str],
+        hash: str,
+    ) -> None:
+        ...
+
+
+F = TypeVar("F", bound=Callable)
+
+
+class HookChain(Generic[F]):
+    """A chain of hooks of the same type F to be called in order.
+    """
+
+    def __init__(self, reversed: bool = False):
+        self.calls: list[F] = []
+        self.reversed = reversed
+
+    def add(self, func: F) -> None:
+        if func not in self.calls:
+            self.calls.append(func)
+
+    def remove(self, func: F) -> None:
+        if func in self.calls:
+            self.calls.remove(func)
+
+    def __call__(self, *args, **kwargs):
+        for call in self.calls if not self.reversed else reversed(self.calls):
+            call(*args, **kwargs)
+
+
 # This is of the form [attr_name, attr_val]
 # TODO: Use tuple instead of list for better typing.
 KernelAttr = list[Union[str, int]]
@@ -418,11 +473,15 @@ class JITHook(Protocol):
 
 
 class runtime_knobs(base_knobs):
     interpret: env_bool = env_bool("TRITON_INTERPRET")
-    debug: env_bool = env_bool("TRITON_DEBUG")
+    # debug is on critical path for kernel launches
+    # avoid repeated reads from env-var by calling get directly
+    debug: bool = env_bool("TRITON_DEBUG").get()
     override_arch: env_opt_str = env_opt_str("TRITON_OVERRIDE_ARCH")
 
-    launch_enter_hook: Optional[LaunchHook] = None
-    launch_exit_hook: Optional[LaunchHook] = None
+    launch_enter_hook: HookChain[LaunchHook] = HookChain()
+    launch_exit_hook: HookChain[LaunchHook] = HookChain(reversed=True)
+    kernel_load_start_hook: HookChain[InitHandleHook] = HookChain()
+    kernel_load_end_hook: HookChain[InitHandleHook] = HookChain(reversed=True)
 
     # Hook for inspecting compiled functions and modules
@@ -444,6 +503,7 @@ class nvidia_knobs(base_knobs):
     dump_nvptx: env_bool = env_bool("NVPTX_ENABLE_DUMP")
     disable_ptxas_opt: env_bool = env_bool("DISABLE_PTXAS_OPT")
     mock_ptx_version: env_opt_str = env_opt_str("TRITON_MOCK_PTX_VERSION")
+    dump_ptxas_log: env_bool = env_bool("TRITON_DUMP_PTXAS_LOG")
 
     libdevice_path: env_opt_str = env_opt_str("TRITON_LIBDEVICE_PATH")
     libcuda_path: env_opt_str = env_opt_str("TRITON_LIBCUDA_PATH")
@@ -451,9 +511,10 @@
 
 class amd_knobs(base_knobs):
     use_buffer_ops: env_bool = env_bool("AMDGCN_USE_BUFFER_OPS")
+    # Note: This requires use_buffer_ops be true to have any effect
+    use_buffer_atomics: env_bool = env_bool("AMDGCN_USE_BUFFER_ATOMICS", True)
     dump_amdgcn: env_bool = env_bool("AMDGCN_ENABLE_DUMP")
     libhip_path: env_opt_str = env_opt_str("TRITON_LIBHIP_PATH")
-    lld_path: env_opt_str = env_opt_str("TRITON_HIP_LLD_PATH")
 
     # We use strs so that we can have a default value based on other runtime info
     use_block_pingpong: env_opt_bool = env_opt_bool("TRITON_HIP_USE_BLOCK_PINGPONG")
@@ -479,3 +540,7 @@ language = language_knobs()
 nvidia = nvidia_knobs()
 amd = amd_knobs()
 proton = proton_knobs()
+
+
+def refresh_knobs():
+    runtime.debug = env_bool("TRITON_DEBUG").get()
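
Since launch hooks are now HookChain instances rather than a single Optional callable, multiple tools can register side by side. A hedged registration sketch follows; the hook body is illustrative, and `metadata` is the LazyDict every LaunchHook receives:

import triton.knobs as knobs

def log_launch(metadata):
    # Called around each kernel launch once registered on a chain.
    print("kernel launch:", metadata)

knobs.runtime.launch_enter_hook.add(log_launch)
knobs.runtime.launch_exit_hook.add(log_launch)    # exit chains run their hooks in reverse order
# kernel_load_start_hook / kernel_load_end_hook accept InitHandleHook callables the same way.

knobs.runtime.launch_enter_hook.remove(log_launch)

# runtime.debug is now snapshotted at import time; pick up a changed TRITON_DEBUG with:
knobs.refresh_knobs()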