triton-windows 3.3.0.post19__cp310-cp310-win_amd64.whl → 3.4.0.post20__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic. Click here for more details.

Files changed (173) hide show
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +4 -1
  3. triton/_filecheck.py +87 -0
  4. triton/_internal_testing.py +26 -15
  5. triton/_utils.py +110 -21
  6. triton/backends/__init__.py +20 -23
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +112 -78
  9. triton/backends/amd/driver.c +5 -2
  10. triton/backends/amd/driver.py +149 -47
  11. triton/backends/compiler.py +7 -21
  12. triton/backends/nvidia/bin/ptxas.exe +0 -0
  13. triton/backends/nvidia/compiler.py +92 -93
  14. triton/backends/nvidia/driver.c +90 -98
  15. triton/backends/nvidia/driver.py +303 -128
  16. triton/compiler/code_generator.py +212 -111
  17. triton/compiler/compiler.py +110 -25
  18. triton/experimental/__init__.py +0 -0
  19. triton/experimental/gluon/__init__.py +4 -0
  20. triton/experimental/gluon/_compiler.py +0 -0
  21. triton/experimental/gluon/_runtime.py +99 -0
  22. triton/experimental/gluon/language/__init__.py +18 -0
  23. triton/experimental/gluon/language/_core.py +312 -0
  24. triton/experimental/gluon/language/_layouts.py +230 -0
  25. triton/experimental/gluon/language/_math.py +12 -0
  26. triton/experimental/gluon/language/_semantic.py +287 -0
  27. triton/experimental/gluon/language/_standard.py +47 -0
  28. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  29. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +202 -0
  30. triton/experimental/gluon/language/nvidia/blackwell/tma.py +32 -0
  31. triton/experimental/gluon/language/nvidia/hopper/__init__.py +11 -0
  32. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +51 -0
  33. triton/experimental/gluon/language/nvidia/hopper/tma.py +96 -0
  34. triton/experimental/gluon/nvidia/__init__.py +4 -0
  35. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  36. triton/experimental/gluon/nvidia/hopper.py +40 -0
  37. triton/knobs.py +481 -0
  38. triton/language/__init__.py +39 -14
  39. triton/language/core.py +794 -537
  40. triton/language/extra/cuda/__init__.py +10 -7
  41. triton/language/extra/cuda/gdc.py +42 -0
  42. triton/language/extra/cuda/libdevice.py +394 -394
  43. triton/language/extra/cuda/utils.py +21 -21
  44. triton/language/extra/hip/libdevice.py +113 -104
  45. triton/language/math.py +65 -66
  46. triton/language/random.py +12 -2
  47. triton/language/semantic.py +1706 -1770
  48. triton/language/standard.py +116 -51
  49. triton/runtime/autotuner.py +117 -59
  50. triton/runtime/build.py +76 -12
  51. triton/runtime/cache.py +18 -47
  52. triton/runtime/driver.py +32 -29
  53. triton/runtime/interpreter.py +72 -35
  54. triton/runtime/jit.py +146 -110
  55. triton/runtime/tcc/lib/python310.def +1610 -0
  56. triton/runtime/tcc/lib/python311.def +1633 -0
  57. triton/runtime/tcc/lib/python312.def +1703 -0
  58. triton/runtime/tcc/lib/python313.def +1651 -0
  59. triton/runtime/tcc/lib/python313t.def +1656 -0
  60. triton/runtime/tcc/lib/python39.def +1644 -0
  61. triton/runtime/tcc/lib/python3t.def +905 -0
  62. triton/testing.py +16 -12
  63. triton/tools/disasm.py +3 -4
  64. triton/tools/tensor_descriptor.py +36 -0
  65. triton/windows_utils.py +14 -6
  66. {triton_windows-3.3.0.post19.dist-info → triton_windows-3.4.0.post20.dist-info}/METADATA +7 -2
  67. triton_windows-3.4.0.post20.dist-info/RECORD +186 -0
  68. {triton_windows-3.3.0.post19.dist-info → triton_windows-3.4.0.post20.dist-info}/WHEEL +1 -1
  69. triton_windows-3.4.0.post20.dist-info/entry_points.txt +3 -0
  70. triton_windows-3.4.0.post20.dist-info/licenses/LICENSE +23 -0
  71. triton_windows-3.4.0.post20.dist-info/top_level.txt +1 -0
  72. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
  73. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
  74. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
  75. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
  76. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
  77. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
  78. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
  79. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
  80. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
  81. triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
  82. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
  83. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
  84. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
  85. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
  86. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
  87. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
  88. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
  89. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
  90. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
  91. triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
  92. triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
  93. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
  94. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
  95. triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
  96. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
  97. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
  98. triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
  99. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
  100. triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
  101. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
  102. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
  103. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
  104. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
  105. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
  106. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
  107. triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
  108. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
  109. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
  110. triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
  111. triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
  112. triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
  113. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
  114. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
  115. triton/backends/amd/include/hip/channel_descriptor.h +0 -39
  116. triton/backends/amd/include/hip/device_functions.h +0 -38
  117. triton/backends/amd/include/hip/driver_types.h +0 -468
  118. triton/backends/amd/include/hip/hip_bf16.h +0 -36
  119. triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
  120. triton/backends/amd/include/hip/hip_common.h +0 -100
  121. triton/backends/amd/include/hip/hip_complex.h +0 -38
  122. triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
  123. triton/backends/amd/include/hip/hip_deprecated.h +0 -95
  124. triton/backends/amd/include/hip/hip_ext.h +0 -161
  125. triton/backends/amd/include/hip/hip_fp16.h +0 -36
  126. triton/backends/amd/include/hip/hip_fp8.h +0 -33
  127. triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
  128. triton/backends/amd/include/hip/hip_hcc.h +0 -24
  129. triton/backends/amd/include/hip/hip_math_constants.h +0 -36
  130. triton/backends/amd/include/hip/hip_profile.h +0 -27
  131. triton/backends/amd/include/hip/hip_runtime.h +0 -75
  132. triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
  133. triton/backends/amd/include/hip/hip_texture_types.h +0 -29
  134. triton/backends/amd/include/hip/hip_vector_types.h +0 -41
  135. triton/backends/amd/include/hip/hip_version.h +0 -17
  136. triton/backends/amd/include/hip/hiprtc.h +0 -421
  137. triton/backends/amd/include/hip/library_types.h +0 -78
  138. triton/backends/amd/include/hip/math_functions.h +0 -42
  139. triton/backends/amd/include/hip/surface_types.h +0 -63
  140. triton/backends/amd/include/hip/texture_types.h +0 -194
  141. triton/backends/amd/include/hsa/Brig.h +0 -1131
  142. triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
  143. triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
  144. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
  145. triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
  146. triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
  147. triton/backends/amd/include/hsa/hsa.h +0 -5738
  148. triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
  149. triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
  150. triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
  151. triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
  152. triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
  153. triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
  154. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
  155. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
  156. triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
  157. triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
  158. triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
  159. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
  160. triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
  161. triton/backends/amd/include/roctracer/roctracer.h +0 -779
  162. triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
  163. triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
  164. triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
  165. triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
  166. triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
  167. triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
  168. triton/backends/amd/include/roctracer/roctx.h +0 -229
  169. triton/language/_utils.py +0 -21
  170. triton/language/extra/cuda/_experimental_tma.py +0 -106
  171. triton/tools/experimental_descriptor.py +0 -32
  172. triton_windows-3.3.0.post19.dist-info/RECORD +0 -253
  173. triton_windows-3.3.0.post19.dist-info/top_level.txt +0 -14
@@ -0,0 +1,96 @@
1
+ from __future__ import annotations
2
+ from typing import List, Tuple, TYPE_CHECKING
3
+ from dataclasses import dataclass
4
+ import triton.experimental.gluon.language._core as ttgl
5
+ from triton.experimental.gluon.language._layouts import NVMMASharedLayout
6
+ from triton.experimental.gluon.language._core import builtin, _unwrap_if_constexpr
7
+
8
+ if TYPE_CHECKING:
9
+ from triton._C import ir
10
+
11
+ __all__ = ["async_copy_global_to_shared", "async_copy_shared_to_global", "store_wait"]
12
+
13
+
14
@dataclass(eq=True)
class tensor_descriptor_type:
    """Type object for a Gluon tensor descriptor.

    Holds the block type, the types of the shape and strides tuples, and the
    shared-memory layout associated with the descriptor.
    """
    block_type: ttgl.block_type
    shape_type: ttgl.tuple_type
    strides_type: ttgl.tuple_type
    layout: NVMMASharedLayout

    def __str__(self) -> str:
        return f"tensor_descriptor<{self.block_type}, {self.layout}>"

    def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tensor_descriptor, int]:
        """Rebuild a ``tensor_descriptor`` from ``handles`` starting at ``cursor``.

        The descriptor handle comes first, followed by the handles consumed by
        the shape tuple and then by the strides tuple. Returns the rebuilt
        value together with the cursor advanced past everything consumed.
        """
        handle = handles[cursor]
        cursor += 1
        shape, cursor = self.shape_type._unflatten_ir(handles, cursor)
        strides, cursor = self.strides_type._unflatten_ir(handles, cursor)
        value = tensor_descriptor(handle, shape, strides, self.block_type, layout=self.layout)
        return value, cursor

    def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
        """Append this type's IR types to ``out``.

        The descriptor type is appended first, then the shape types, then the
        strides types -- the same order ``_unflatten_ir`` reads handles in.
        """
        is_signed = self.block_type.element_ty.is_int_signed()
        ty = builder.get_tensor_descriptor_layout_type(
            self.block_type.to_ir(builder),
            is_signed,
            self.layout._to_ir(builder),
        )
        out.append(ty)
        self.shape_type._flatten_ir_types(builder, out)
        self.strides_type._flatten_ir_types(builder, out)

    def mangle(self) -> str:
        """Return the mangled name fragment for this descriptor type."""
        # ``mangle`` must be *called* on the block type: interpolating the bound
        # method itself would embed its repr (including a memory address) in
        # the mangled name.
        return f"TD{self.block_type.mangle()}_{self.layout.mangle()}TD"
45
+
46
+
47
class tensor_descriptor:
    """Value wrapper pairing a descriptor handle with its shape, strides and type."""

    def __init__(self, handle, shape: List[ttgl.tensor], strides: List[ttgl.tensor], block_type: ttgl.block_type,
                 layout: NVMMASharedLayout):
        self.handle = handle
        # Shape and strides are stored as Gluon tuples so their types can be
        # recorded on the descriptor type below.
        self.shape = ttgl.tuple(shape)
        self.strides = ttgl.tuple(strides)
        self.type = tensor_descriptor_type(
            block_type,
            shape_type=self.shape.type,
            strides_type=self.strides.type,
            layout=layout,
        )

    def _flatten_ir(self, handles: List[ir.value]) -> None:
        """Append the descriptor handle, then the shape and strides handles, to ``handles``."""
        handles.append(self.handle)
        for component in (self.shape, self.strides):
            component._flatten_ir(handles)

    @property
    def block_type(self):
        """Block type recorded on the descriptor type."""
        return self.type.block_type

    @property
    def block_shape(self):
        """Shape of the block type."""
        return self.block_type.shape

    @property
    def dtype(self):
        """Element type of the block type."""
        return self.block_type.element_ty

    @property
    def layout(self):
        """Layout recorded on the descriptor type."""
        return self.type.layout
77
+
78
+
79
@builtin
def async_copy_global_to_shared(tensor_desc, coord, barrier, result, pred=True, _semantic=None):
    """Emit an async TMA global-to-local copy through the IR builder.

    ``coord`` is converted to IR values (without requiring i64) and ``pred`` is
    converted to a tensor before the handles of ``tensor_desc``, ``barrier``,
    ``result`` and the predicate are passed to the builder.
    """
    ir_coord = _semantic._convert_to_ir_values(coord, require_i64=False)
    pred_tensor = _semantic.to_tensor(pred)
    builder = _semantic.builder
    builder.create_async_tma_copy_global_to_local(
        tensor_desc.handle,
        ir_coord,
        barrier.handle,
        result.handle,
        pred_tensor.handle,
    )
85
+
86
+
87
@builtin
def async_copy_shared_to_global(tensor_desc, coord, src, _semantic=None):
    """Emit an async TMA local-to-global copy of ``src`` through the IR builder.

    ``coord`` is converted to IR values (without requiring i64) first.
    """
    ir_coord = _semantic._convert_to_ir_values(coord, require_i64=False)
    builder = _semantic.builder
    builder.create_async_tma_copy_local_to_global(tensor_desc.handle, ir_coord, src.handle)
91
+
92
+
93
@builtin
def store_wait(pendings, _semantic=None):
    """Emit an async TMA store-wait op for ``pendings`` (unwrapped if it is a constexpr)."""
    pending_count = _unwrap_if_constexpr(pendings)
    _semantic.builder.create_async_tma_store_wait(pending_count)
@@ -0,0 +1,4 @@
1
+ from . import hopper
2
+ from . import blackwell
3
+
4
+ __all__ = ["hopper", "blackwell"]
@@ -0,0 +1,3 @@
1
+ from .hopper import TensorDescriptor
2
+
3
+ __all__ = ["TensorDescriptor"]
@@ -0,0 +1,40 @@
1
+ from dataclasses import dataclass
2
+ from typing import List, Any
3
+ from triton._utils import validate_block_shape, canonicalize_dtype, get_primitive_bitwidth
4
+ from triton.experimental.gluon.language._layouts import NVMMASharedLayout
5
+
6
+ __all__ = ["TensorDescriptor"]
7
+
8
+
9
@dataclass
class TensorDescriptor:
    """A base tensor plus shape, strides, block shape and shared layout.

    All fields are validated on construction; see ``__post_init__``.
    """
    base: Any
    shape: List[int]
    strides: List[int]
    block_shape: List[int]
    layout: NVMMASharedLayout

    def __post_init__(self):
        """Check rank, alignment, contiguity and layout type of the fields."""
        ndim = len(self.shape)
        assert len(self.strides) == ndim, f"rank mismatch: {self}"
        assert len(self.block_shape) == ndim, f"rank mismatch: {self}"
        assert ndim > 0, "rank must not be zero"
        assert ndim <= 5, "rank cannot be more than 5"
        assert self.base.data_ptr() % 16 == 0, "base must be 16-byte aligned"
        validate_block_shape(self.block_shape)
        # Element size in bytes, derived from the base tensor's dtype.
        elem_bytes = get_primitive_bitwidth(canonicalize_dtype(self.base.dtype)) // 8
        # Every stride except the innermost one must be a multiple of 16 bytes.
        outer_strides = self.strides[:-1]
        assert all((step * elem_bytes) % 16 == 0 for step in outer_strides), "strides must be 16-byte aligned"
        assert self.strides[-1] == 1, "Last dimension must be contiguous"
        assert isinstance(self.layout, NVMMASharedLayout), "Layout must be NVMMASharedLayout"

    @staticmethod
    def from_tensor(tensor: Any, block_shape: List[int], layout: NVMMASharedLayout):
        """Build a descriptor from ``tensor``'s own ``shape`` and ``stride()``."""
        return TensorDescriptor(
            base=tensor,
            shape=tensor.shape,
            strides=tensor.stride(),
            block_shape=block_shape,
            layout=layout,
        )
triton/knobs.py ADDED
@@ -0,0 +1,481 @@
1
+ from __future__ import annotations
2
+
3
+ import importlib
4
+ import os
5
+ import re
6
+ import subprocess
7
+ import sysconfig
8
+
9
+ from dataclasses import dataclass
10
+ from contextlib import contextmanager
11
+ from typing import cast, Any, Callable, Generator, Generic, Optional, Protocol, Type, TypeVar, TypedDict, TYPE_CHECKING, Union
12
+
13
+ if TYPE_CHECKING:
14
+ from .runtime.cache import CacheManager, RemoteCacheBackend
15
+ from .runtime.jit import JitFunctionInfo, KernelParam
16
+ from .compiler.compiler import ASTSource, LazyDict, IRSource
17
+
18
+
19
class Env:
    """Marker type: assigning an instance to a knob drops the knob's stored override."""


# Shared marker instance of ``Env``.
env = Env()

# When False, ``setenv`` returns without touching ``os.environ``.
propagate_env: bool = True
26
+
27
+
28
def getenv(key: str) -> Optional[str]:
    """Return the environment variable ``key`` with surrounding whitespace stripped, or None if unset."""
    raw = os.getenv(key)
    if raw is None:
        return None
    return raw.strip()
31
+
32
+
33
def setenv(key: str, value: Optional[str]) -> None:
    """Set ``key`` in ``os.environ``, or remove it when ``value`` is None.

    Does nothing when the module-level ``propagate_env`` flag is false.
    """
    if propagate_env:
        if value is None:
            # Remove the variable if present; a missing key is not an error.
            os.environ.pop(key, None)
        else:
            os.environ[key] = value
41
+
42
+
43
def toenv(val: Any) -> Union[None, tuple[Optional[str]]]:
    """Encode ``val`` for storage in an environment variable.

    Returns a 1-tuple holding the encoded string (or ``None`` for a ``None``
    input, meaning "unset"), or plain ``None`` when ``val`` is not exactly a
    ``bool``, ``str`` or ``int`` and therefore cannot be encoded.
    """
    if val is None:
        return (None, )

    # Exact type matches only: subclasses are deliberately not encoded.
    kind = type(val)
    if kind is bool:
        return ("1", ) if val else ("0", )
    if kind is str:
        return (val, )
    return (str(val), ) if kind is int else None
58
+
59
+
60
+ # There's an asymmetry here so that e.g. env_nvidia_tool can be specified with
61
+ # a string but return an NvidiaTool.
62
SetType = TypeVar("SetType")
GetType = TypeVar("GetType")


class env_base(Generic[SetType, GetType]):
    """Data descriptor for a knob backed by an environment variable.

    A value stored on the owning instance takes precedence; otherwise the
    environment variable ``key`` is consulted, and finally ``default``.
    Subclasses implement ``from_env`` to parse the raw string and may override
    ``transform`` when the value handed back differs from the value stored.
    """

    def __init__(self, key: str, default: Union[SetType, Callable[[], SetType]]) -> None:
        self.key = key
        # Normalise the default into a zero-argument factory.
        self.default: Callable[[], SetType]
        if callable(default):
            self.default = default
        else:
            self.default = lambda: default

    def __set_name__(self, objclass: Type[object], name: str) -> None:
        self.name = name

    def __get__(self, obj: Optional[object], objclass: Optional[Type[object]]) -> GetType:
        if obj is None:
            raise AttributeError(f"Cannot access {type(self)} on non-instance")

        overrides = obj.__dict__
        if self.name not in overrides:
            return self.get()
        return self.transform(overrides[self.name])

    @property
    def env_val(self) -> str | None:
        """Current (stripped) value of the environment variable, or None if unset."""
        return getenv(self.key)

    def get(self) -> GetType:
        """Resolve the knob from the environment, falling back to the default."""
        raw = self.env_val
        value = self.default() if raw is None else self.from_env(raw)
        return self.transform(value)

    def __set__(self, obj: object, value: Union[SetType, Env]) -> None:
        if isinstance(value, Env):
            # An ``Env`` marker drops the stored override.
            obj.__dict__.pop(self.name, None)
            return

        obj.__dict__[self.name] = value
        # Mirror encodable values into the environment.
        encoded = toenv(value)
        if encoded:
            setenv(self.key, encoded[0])

    def __delete__(self, obj: object) -> None:
        obj.__dict__.pop(self.name, None)

    def transform(self, val: SetType) -> GetType:
        # Identity by default; only needs overriding when GetType differs from
        # SetType.
        return cast(GetType, val)

    def from_env(self, val: str) -> SetType:
        """Parse the raw environment string; must be provided by subclasses."""
        raise NotImplementedError()
110
+
111
+
112
class env_str(env_base[str, str]):
    """String knob: the environment value is used as-is."""

    def from_env(self, val: str) -> str:
        return val
116
+
117
+
118
class env_bool(env_base[bool, bool]):
    """Boolean knob; defaults to False unless another default is supplied."""

    def __init__(self, key: str, default: Union[bool, Callable[[], bool]] = False) -> None:
        super().__init__(key, default)

    def from_env(self, val: str) -> bool:
        # Case-insensitive match against the accepted "true" spellings;
        # anything else is False.
        truthy = ("1", "true", "yes", "on", "y")
        return val.lower() in truthy
125
+
126
+
127
class env_int(env_base[int, int]):
    """Integer knob; defaults to 0 unless another default is supplied."""

    def __init__(self, key: str, default: Union[int, Callable[[], int]] = 0) -> None:
        super().__init__(key, default)

    def from_env(self, val: str) -> int:
        """Parse ``val`` as an int; raise RuntimeError if it is not one."""
        try:
            parsed = int(val)
        except ValueError as exc:
            raise RuntimeError(f"Unable to use {self.key}={val}: expected int") from exc
        return parsed
137
+
138
+
139
class env_opt_base(Generic[GetType, SetType], env_base[Optional[GetType], Optional[SetType]]):
    """Base for optional knobs: the default is always None."""

    # NOTE(review): the type parameters are listed as (GetType, SetType) here
    # while env_base takes (SetType, GetType); this only affects static typing
    # -- confirm the ordering is intended.
    def __init__(self, key: str) -> None:
        super().__init__(key, None)
143
+
144
+
145
ClassType = TypeVar("ClassType")


class env_class(Generic[ClassType], env_opt_base[Type[ClassType], Type[ClassType]]):
    """Optional knob naming a class as ``MODULE:CLASS`` in the environment."""

    def __init__(self, key: str, type: str) -> None:
        super().__init__(key)
        # The expected base is stored by *name* rather than as a class object,
        # to avoid import cycles.
        self.type = type

    def from_env(self, val: str) -> Type[ClassType]:
        """Import and return the class named by ``val``.

        Raises RuntimeError when ``val`` has no ``:`` separator or when no
        class in the resolved class's MRO is named ``self.type``.
        """
        module_name, separator, attr_name = val.partition(":")
        if not separator:
            raise RuntimeError(f"Unable to read {self.key}: '{val}' isn't of the form MODULE:CLASS")
        cls = getattr(importlib.import_module(module_name), attr_name)

        ancestor_names = {base.__name__ for base in cls.mro()}
        if self.type not in ancestor_names:
            raise RuntimeError(f"Unable to use '{val}' from {self.key}: not of type '{self.type}'")

        return cast(Type[ClassType], cls)
165
+
166
+
167
@dataclass
class NvidiaTool:
    """An executable path together with the release version it reports."""
    path: str
    version: str

    @staticmethod
    def from_path(path: str) -> Optional[NvidiaTool]:
        """Probe ``path`` with ``--version`` and return the tool it describes.

        Returns None when the binary cannot be launched, exits with a non-zero
        status, or prints no ``release X.Y`` string.
        """
        try:
            output = subprocess.check_output([path, "--version"], stderr=subprocess.STDOUT)
        except (subprocess.CalledProcessError, OSError):
            # CalledProcessError: the tool ran but failed. OSError: it could
            # not be started at all (missing file, bad format, permissions).
            # Either way there is no usable tool at this path.
            return None

        # Decode leniently so unexpected bytes in the tool's output cannot
        # abort the probe; the version pattern itself is plain ASCII.
        text = output.decode("utf-8", errors="replace")
        match = re.search(r".*release (\d+\.\d+).*", text, flags=re.MULTILINE)
        if match is None:
            return None
        return NvidiaTool(path, match.group(1))
184
+
185
+
186
def find_nvidia_tool(binary: str) -> str:
    """Locate ``binary`` and return its path, or ``""`` when it is not found.

    The copy under ``backends/nvidia/bin`` next to this module is preferred.
    On Windows the CUDA bin directory reported by ``find_cuda`` is tried next.
    """
    bundled = os.path.join(os.path.dirname(__file__), "backends", "nvidia", "bin", binary)
    if os.access(bundled, os.X_OK):
        return bundled

    if os.name != "nt":
        return ""

    from triton.windows_utils import find_cuda
    cuda_bin_path, _, _ = find_cuda()
    if not cuda_bin_path:
        return ""

    candidate = os.path.join(cuda_bin_path, binary)
    return candidate if os.access(candidate, os.X_OK) else ""
206
+
207
+
208
class env_nvidia_tool(env_base[str, NvidiaTool]):
    """Knob that is set with a path string and read back as an ``NvidiaTool``."""

    def __init__(self, binary: str) -> None:
        # Append the platform's executable suffix to the tool name.
        # NOTE(review): the suffix also becomes part of the environment key
        # (``TRITON_<BINARY><EXE>_PATH``) -- confirm that is intended on
        # platforms where the suffix is non-empty.
        self.binary = binary + sysconfig.get_config_var("EXE")
        super().__init__(f"TRITON_{self.binary.upper()}_PATH", lambda: find_nvidia_tool(self.binary))

    def transform(self, path: str) -> NvidiaTool:
        """Return the first usable tool among ``path`` and the default location.

        The default is kept as a fallback in case the configured binary is not
        accessible. Raises RuntimeError when neither candidate works.
        """
        candidates = (path, self.default())
        for candidate in candidates:
            if candidate and os.access(candidate, os.X_OK):
                tool = NvidiaTool.from_path(candidate)
                if tool:
                    return tool

        raise RuntimeError(f"Cannot find {self.binary}")

    def from_env(self, val: str) -> str:
        return val
232
+
233
+
234
+ # Separate classes so that types are correct
235
class env_opt_str(env_opt_base[str, str], env_str):
    """String knob whose default is None."""
237
+
238
+
239
class env_opt_bool(env_opt_base[bool, bool], env_bool):
    """Boolean knob whose default is None."""
241
+
242
+
243
@dataclass(frozen=True)
class CompileTimes:
    """
    Timing information for one invocation of the compiler.

    Every duration is expressed in microseconds.
    """

    # Time spent in make_ir.
    ir_initialization: int

    # (stage extension, duration) pairs in lowering order, e.g. ttir, ttgir.
    lowering_stages: list[tuple[str, int]]

    # Time spent saving artifacts/metadata to the cache.
    store_results: int

    @property
    def total_lowering(self) -> int:
        """Sum of the durations of all lowering stages."""
        return sum(entry[1] for entry in self.lowering_stages)

    @property
    def total(self) -> int:
        """IR initialization plus all lowering plus result storage."""
        return self.ir_initialization + self.total_lowering + self.store_results
268
+
269
+
270
class CompilationListener(Protocol):
    """Callable protocol for objects notified about a compilation."""

    def __call__(
        self,
        *,
        src: Union[ASTSource, IRSource],
        metadata: dict[str, Any],
        metadata_group: dict[str, str],
        times: CompileTimes,
        cache_hit: bool,
    ) -> None:
        """Receive the source, metadata, timings and cache-hit flag (keyword-only)."""
275
+
276
+
277
knobs_type = TypeVar("knobs_type", bound='base_knobs')


class base_knobs:
    """Common behaviour for groups of ``env_base`` knob descriptors."""

    @property
    def knob_descriptors(self) -> dict[str, env_base]:
        """Descriptors declared directly on this instance's class, keyed by attribute name."""
        found = {}
        # Data descriptors live on the class object, not on the instance.
        for attr, member in type(self).__dict__.items():
            if isinstance(member, env_base):
                found[attr] = member
        return found

    @property
    def knobs(self) -> dict[str, Any]:
        """Current value of every knob, keyed by attribute name."""
        return {attr: getattr(self, attr) for attr in self.knob_descriptors}

    def copy(self: knobs_type) -> knobs_type:
        """Return a new instance of the same class carrying the same instance state."""
        duplicate = type(self)()
        duplicate.__dict__.update(self.__dict__)
        return duplicate

    def reset(self: knobs_type) -> knobs_type:
        """Delete every knob attribute on this instance and return ``self``."""
        for attr in self.knob_descriptors:
            delattr(self, attr)
        return self

    @contextmanager
    def scope(self) -> Generator[None, None, None]:
        """Context manager restoring instance state and knob env vars on exit."""
        try:
            env_snapshot = {desc.key: desc.env_val for desc in self.knob_descriptors.values()}
            saved_state = dict(self.__dict__)
            yield
        finally:
            self.__dict__.clear()
            self.__dict__.update(saved_state)

            # Put each knob's environment variable back the way it was found,
            # removing those that were unset on entry.
            for key, previous in env_snapshot.items():
                if previous is None:
                    os.environ.pop(key, None)
                else:
                    os.environ[key] = previous
320
+
321
+
322
class BuildImpl(Protocol):
    """Callable protocol for a custom build implementation (positional-only arguments)."""

    def __call__(
        self,
        name: str,
        src: str,
        srcdir: str,
        library_dirs: list[str],
        include_dirs: list[str],
        libraries: list[str],
        /,
    ) -> str:
        """Build from the given inputs and return a string result."""
327
+
328
+
329
class build_knobs(base_knobs):
    """Configuration controlling how the native compiler is invoked"""
    # Read from the CC environment variable.
    cc: env_opt_str = env_opt_str("CC")

    cudacrt_path: env_opt_str = env_opt_str("TRITON_CUDACRT_PATH")
    cudart_path: env_opt_str = env_opt_str("TRITON_CUDART_PATH")

    # Plain attribute (not an environment knob): optional build callable.
    impl: Optional[BuildImpl] = None

    @property
    def backend_dirs(self) -> set[str]:
        """Set of the configured CUDA path knobs, leaving out any that are None."""
        dirs = {self.cudacrt_path, self.cudart_path}
        dirs.discard(None)
        return dirs
341
+
342
+
343
class redis_knobs(base_knobs):
    """Knobs read from the ``TRITON_REDIS_*`` environment variables."""
    # Defaults apply when the corresponding variable is unset.
    key_format: env_str = env_str("TRITON_REDIS_KEY_FORMAT", "triton:{key}:{filename}")
    host: env_str = env_str("TRITON_REDIS_HOST", "localhost")
    port: env_int = env_int("TRITON_REDIS_PORT", 6379)
347
+
348
+
349
# Forward declaration: the instance is assigned at the bottom of the module,
# but the default factories below refer to it.
cache: cache_knobs


class cache_knobs(base_knobs):
    """Knobs for the Triton home, cache, dump and override directories and cache manager classes."""
    home_dir: env_str = env_str("TRITON_HOME", lambda: os.path.expanduser("~/"))

    # These defaults are resolved lazily through the module-level ``cache``
    # instance, so they follow ``home_dir``.
    dump_dir: env_str = env_str("TRITON_DUMP_DIR", lambda: cache.get_triton_dir("dump"))
    override_dir: env_str = env_str("TRITON_OVERRIDE_DIR", lambda: cache.get_triton_dir("override"))
    dir: env_str = env_str("TRITON_CACHE_DIR", lambda: cache.get_triton_dir("cache"))

    manager_class: env_class[CacheManager] = env_class("TRITON_CACHE_MANAGER", "CacheManager")
    remote_manager_class: env_class[RemoteCacheBackend] = env_class("TRITON_REMOTE_CACHE_BACKEND", "RemoteCacheBackend")

    def get_triton_dir(self, dirname: str) -> str:
        """Return ``<home_dir>/.triton/<dirname>``."""
        triton_root = os.path.join(self.home_dir, ".triton")
        return os.path.join(triton_root, dirname)
364
+
365
+
366
class compilation_knobs(base_knobs):
    """Knobs that configure kernel compilation."""
    override: env_bool = env_bool("TRITON_KERNEL_OVERRIDE")
    dump_ir: env_bool = env_bool("TRITON_KERNEL_DUMP")
    store_binary_only: env_bool = env_bool("TRITON_STORE_BINARY_ONLY")
    always_compile: env_bool = env_bool("TRITON_ALWAYS_COMPILE")
    # TODO: Use enum to constrain / 'typecheck' the values
    use_ir_loc: env_opt_str = env_opt_str("USE_IR_LOC")
    enable_asan: env_bool = env_bool("TRITON_ENABLE_ASAN")
    disable_line_info: env_bool = env_bool("TRITON_DISABLE_LINE_INFO")
    front_end_debugging: env_bool = env_bool("TRITON_FRONT_END_DEBUGGING")
    allow_non_constexpr_globals: env_bool = env_bool("TRITON_ALLOW_NON_CONSTEXPR_GLOBALS")
    # Plain attribute (not an environment knob): optional compilation listener.
    listener: Union[CompilationListener, None] = None
378
+
379
+
380
class autotuning_knobs(base_knobs):
    """Boolean knobs read from ``TRITON_CACHE_AUTOTUNING`` and ``TRITON_PRINT_AUTOTUNING``."""
    cache: env_bool = env_bool("TRITON_CACHE_AUTOTUNING")
    print: env_bool = env_bool("TRITON_PRINT_AUTOTUNING")
383
+
384
+
385
class LaunchHook(Protocol):
    """Callable protocol for launch hooks; receives the kernel metadata."""

    def __call__(self, metadata: LazyDict) -> None:
        """Handle ``metadata``; the return value is ignored."""
389
+
390
+
391
# A kernel attribute is a two-item list: [attr_name, attr_val].
# TODO: Use tuple instead of list for better typing.
KernelAttr = list[Union[str, int]]


class JITHookCompileInfo(TypedDict):
    """Shape of the ``compile`` mapping handed to a ``JITHook``."""
    key: str
    signature: dict[KernelParam, str]
    device: int
    constants: None
    num_warps: int
    num_ctas: int
    num_stages: int
    enable_fp_fusion: bool
    launch_cooperative_grid: bool
    extern_libs: tuple[tuple[str, str], ...]
    configs: list[dict[tuple[int, ...], list[KernelAttr]]]
    specialization_data: str
    is_warmup: bool
410
+
411
+
412
class JITHook(Protocol):
    """Callable protocol for JIT hooks; every argument is keyword-only."""

    def __call__(
        self,
        *,
        key: str,
        repr: str,
        fn: JitFunctionInfo,
        compile: JITHookCompileInfo,
        is_manual_warmup: bool,
        already_compiled: bool,
    ) -> Optional[bool]:
        """Inspect a kernel; may return a bool or None."""
417
+
418
+
419
class runtime_knobs(base_knobs):
    """Runtime knobs and hook slots."""
    interpret: env_bool = env_bool("TRITON_INTERPRET")
    debug: env_bool = env_bool("TRITON_DEBUG")
    override_arch: env_opt_str = env_opt_str("TRITON_OVERRIDE_ARCH")

    # Plain attributes (not environment knobs): optional launch hooks.
    launch_enter_hook: Optional[LaunchHook] = None
    launch_exit_hook: Optional[LaunchHook] = None

    # Hook for inspecting compiled functions and modules
    jit_cache_hook: Optional[JITHook] = None
    # Hook to signal that a kernel is done compiling and inspect compiled function.
    # jit_cache_hook will always be called before compilation and jit_post_compile_hook after.
    jit_post_compile_hook: Optional[JITHook] = None
432
+
433
+
434
class language_knobs(base_knobs):
    """Knobs read from ``TRITON_F32_DEFAULT`` and ``TRITON_DEFAULT_FP_FUSION``."""
    fp32_default: env_opt_str = env_opt_str("TRITON_F32_DEFAULT")
    # Defaults to True when the variable is unset.
    default_fp_fusion: env_bool = env_bool("TRITON_DEFAULT_FP_FUSION", True)
437
+
438
+
439
class nvidia_knobs(base_knobs):
    """Knobs for the NVIDIA backend: tool locations, dump/debug flags and library paths."""
    # Each of these reads back as an ``NvidiaTool``.
    cuobjdump: env_nvidia_tool = env_nvidia_tool("cuobjdump")
    nvdisasm: env_nvidia_tool = env_nvidia_tool("nvdisasm")
    ptxas: env_nvidia_tool = env_nvidia_tool("ptxas")

    dump_nvptx: env_bool = env_bool("NVPTX_ENABLE_DUMP")
    disable_ptxas_opt: env_bool = env_bool("DISABLE_PTXAS_OPT")
    mock_ptx_version: env_opt_str = env_opt_str("TRITON_MOCK_PTX_VERSION")

    libdevice_path: env_opt_str = env_opt_str("TRITON_LIBDEVICE_PATH")
    libcuda_path: env_opt_str = env_opt_str("TRITON_LIBCUDA_PATH")
450
+
451
+
452
class amd_knobs(base_knobs):
    """Knobs for the AMD backend."""
    use_buffer_ops: env_bool = env_bool("AMDGCN_USE_BUFFER_OPS")
    dump_amdgcn: env_bool = env_bool("AMDGCN_ENABLE_DUMP")
    libhip_path: env_opt_str = env_opt_str("TRITON_LIBHIP_PATH")
    lld_path: env_opt_str = env_opt_str("TRITON_HIP_LLD_PATH")

    # We use strs so that we can have a default value based on other runtime info
    use_block_pingpong: env_opt_bool = env_opt_bool("TRITON_HIP_USE_BLOCK_PINGPONG")
    use_in_thread_transpose: env_opt_bool = env_opt_bool("TRITON_HIP_USE_IN_THREAD_TRANSPOSE")

    # Integer knobs; both fall back to env_int's default of 0.
    global_prefetch: env_int = env_int("TRITON_HIP_GLOBAL_PREFETCH")
    local_prefetch: env_int = env_int("TRITON_HIP_LOCAL_PREFETCH")
    use_async_copy: env_bool = env_bool("TRITON_HIP_USE_ASYNC_COPY")
    scalarize_packed_fops: env_bool = env_bool("AMDGCN_SCALARIZE_PACKED_FOPS")
466
+
467
+
468
class proton_knobs(base_knobs):
    """Knob read from ``TRITON_CUPTI_LIB_PATH``."""
    cupti_dir: env_opt_str = env_opt_str("TRITON_CUPTI_LIB_PATH")
470
+
471
+
472
# Module-level knob groups, one instance per knob class defined above.
build = build_knobs()
redis = redis_knobs()
# ``cache`` is also the instance used by the cache_knobs default factories.
cache = cache_knobs()
compilation = compilation_knobs()
autotuning = autotuning_knobs()
runtime = runtime_knobs()
language = language_knobs()
nvidia = nvidia_knobs()
amd = amd_knobs()
proton = proton_knobs()