triton-windows 3.3.1.post19__cp312-cp312-win_amd64.whl → 3.5.0.post21__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (225)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +11 -2
  3. triton/_filecheck.py +97 -0
  4. triton/_internal_testing.py +95 -18
  5. triton/_utils.py +112 -21
  6. triton/backends/__init__.py +20 -23
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +161 -119
  9. triton/backends/amd/driver.c +118 -46
  10. triton/backends/amd/driver.py +274 -96
  11. triton/backends/compiler.py +7 -21
  12. triton/backends/driver.py +13 -0
  13. triton/backends/nvidia/bin/ptxas.exe +0 -0
  14. triton/backends/nvidia/compiler.py +163 -106
  15. triton/backends/nvidia/driver.c +166 -101
  16. triton/backends/nvidia/driver.py +384 -202
  17. triton/compiler/__init__.py +5 -2
  18. triton/compiler/code_generator.py +439 -231
  19. triton/compiler/compiler.py +152 -84
  20. triton/experimental/__init__.py +0 -0
  21. triton/experimental/gluon/__init__.py +5 -0
  22. triton/experimental/gluon/_compiler.py +0 -0
  23. triton/experimental/gluon/_runtime.py +102 -0
  24. triton/experimental/gluon/language/__init__.py +119 -0
  25. triton/experimental/gluon/language/_core.py +490 -0
  26. triton/experimental/gluon/language/_layouts.py +583 -0
  27. triton/experimental/gluon/language/_math.py +20 -0
  28. triton/experimental/gluon/language/_semantic.py +380 -0
  29. triton/experimental/gluon/language/_standard.py +80 -0
  30. triton/experimental/gluon/language/amd/__init__.py +4 -0
  31. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  32. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  33. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  34. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  35. triton/experimental/gluon/language/extra/__init__.py +3 -0
  36. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  37. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  38. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  39. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  40. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
  41. triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
  42. triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
  43. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
  44. triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
  45. triton/experimental/gluon/nvidia/__init__.py +4 -0
  46. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  47. triton/experimental/gluon/nvidia/hopper.py +45 -0
  48. triton/knobs.py +546 -0
  49. triton/language/__init__.py +50 -19
  50. triton/language/core.py +909 -572
  51. triton/language/extra/cuda/__init__.py +10 -7
  52. triton/language/extra/cuda/gdc.py +42 -0
  53. triton/language/extra/cuda/libdevice.py +394 -394
  54. triton/language/extra/cuda/utils.py +21 -21
  55. triton/language/extra/hip/__init__.py +3 -1
  56. triton/language/extra/hip/libdevice.py +120 -104
  57. triton/language/extra/hip/utils.py +35 -0
  58. triton/language/extra/libdevice.py +4 -0
  59. triton/language/math.py +65 -66
  60. triton/language/random.py +12 -2
  61. triton/language/semantic.py +1757 -1768
  62. triton/language/standard.py +127 -62
  63. triton/language/target_info.py +54 -0
  64. triton/runtime/_allocation.py +15 -3
  65. triton/runtime/_async_compile.py +55 -0
  66. triton/runtime/autotuner.py +117 -60
  67. triton/runtime/build.py +83 -17
  68. triton/runtime/cache.py +61 -47
  69. triton/runtime/driver.py +25 -47
  70. triton/runtime/interpreter.py +95 -50
  71. triton/runtime/jit.py +445 -248
  72. triton/runtime/tcc/include/_mingw.h +8 -10
  73. triton/runtime/tcc/include/assert.h +5 -0
  74. triton/runtime/tcc/include/errno.h +1 -1
  75. triton/runtime/tcc/include/float.h +21 -3
  76. triton/runtime/tcc/include/iso646.h +36 -0
  77. triton/runtime/tcc/include/limits.h +5 -0
  78. triton/runtime/tcc/include/malloc.h +2 -2
  79. triton/runtime/tcc/include/math.h +21 -261
  80. triton/runtime/tcc/include/stdalign.h +16 -0
  81. triton/runtime/tcc/include/stdarg.h +5 -70
  82. triton/runtime/tcc/include/stdatomic.h +171 -0
  83. triton/runtime/tcc/include/stddef.h +7 -19
  84. triton/runtime/tcc/include/stdlib.h +15 -4
  85. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  86. triton/runtime/tcc/include/sys/stat.h +2 -2
  87. triton/runtime/tcc/include/sys/types.h +5 -0
  88. triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
  89. triton/runtime/tcc/include/tccdefs.h +342 -0
  90. triton/runtime/tcc/include/tgmath.h +89 -0
  91. triton/runtime/tcc/include/uchar.h +33 -0
  92. triton/runtime/tcc/include/unistd.h +1 -0
  93. triton/runtime/tcc/include/winapi/qos.h +72 -0
  94. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  95. triton/runtime/tcc/include/winapi/winbase.h +9 -2
  96. triton/runtime/tcc/include/winapi/wincon.h +8 -0
  97. triton/runtime/tcc/include/winapi/windows.h +1 -1
  98. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  99. triton/runtime/tcc/include/winapi/winnt.h +9 -7
  100. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  101. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  102. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  103. triton/runtime/tcc/lib/libtcc1.a +0 -0
  104. triton/runtime/tcc/lib/python314.def +1800 -0
  105. triton/runtime/tcc/lib/python314t.def +1809 -0
  106. triton/runtime/tcc/libtcc.dll +0 -0
  107. triton/runtime/tcc/tcc.exe +0 -0
  108. triton/testing.py +16 -12
  109. triton/tools/compile.py +62 -14
  110. triton/tools/disasm.py +3 -4
  111. triton/tools/extra/cuda/compile.c +1 -0
  112. triton/tools/extra/hip/compile.cpp +66 -0
  113. triton/tools/extra/hip/compile.h +13 -0
  114. triton/tools/ragged_tma.py +92 -0
  115. triton/tools/tensor_descriptor.py +34 -0
  116. triton/windows_utils.py +52 -81
  117. {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +8 -4
  118. triton_windows-3.5.0.post21.dist-info/RECORD +217 -0
  119. triton_windows-3.5.0.post21.dist-info/entry_points.txt +3 -0
  120. triton_windows-3.5.0.post21.dist-info/licenses/LICENSE +23 -0
  121. triton_windows-3.5.0.post21.dist-info/top_level.txt +1 -0
  122. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
  123. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
  124. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
  125. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
  126. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
  127. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
  128. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
  129. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
  130. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
  131. triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
  132. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
  133. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
  134. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
  135. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
  136. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
  137. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
  138. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
  139. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
  140. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
  141. triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
  142. triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
  143. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
  144. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
  145. triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
  146. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
  147. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
  148. triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
  149. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
  150. triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
  151. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
  152. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
  153. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
  154. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
  155. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
  156. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
  157. triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
  158. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
  159. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
  160. triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
  161. triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
  162. triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
  163. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
  164. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
  165. triton/backends/amd/include/hip/channel_descriptor.h +0 -39
  166. triton/backends/amd/include/hip/device_functions.h +0 -38
  167. triton/backends/amd/include/hip/driver_types.h +0 -468
  168. triton/backends/amd/include/hip/hip_bf16.h +0 -36
  169. triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
  170. triton/backends/amd/include/hip/hip_common.h +0 -100
  171. triton/backends/amd/include/hip/hip_complex.h +0 -38
  172. triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
  173. triton/backends/amd/include/hip/hip_deprecated.h +0 -95
  174. triton/backends/amd/include/hip/hip_ext.h +0 -161
  175. triton/backends/amd/include/hip/hip_fp16.h +0 -36
  176. triton/backends/amd/include/hip/hip_fp8.h +0 -33
  177. triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
  178. triton/backends/amd/include/hip/hip_hcc.h +0 -24
  179. triton/backends/amd/include/hip/hip_math_constants.h +0 -36
  180. triton/backends/amd/include/hip/hip_profile.h +0 -27
  181. triton/backends/amd/include/hip/hip_runtime.h +0 -75
  182. triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
  183. triton/backends/amd/include/hip/hip_texture_types.h +0 -29
  184. triton/backends/amd/include/hip/hip_vector_types.h +0 -41
  185. triton/backends/amd/include/hip/hip_version.h +0 -17
  186. triton/backends/amd/include/hip/hiprtc.h +0 -421
  187. triton/backends/amd/include/hip/library_types.h +0 -78
  188. triton/backends/amd/include/hip/math_functions.h +0 -42
  189. triton/backends/amd/include/hip/surface_types.h +0 -63
  190. triton/backends/amd/include/hip/texture_types.h +0 -194
  191. triton/backends/amd/include/hsa/Brig.h +0 -1131
  192. triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
  193. triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
  194. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
  195. triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
  196. triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
  197. triton/backends/amd/include/hsa/hsa.h +0 -5738
  198. triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
  199. triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
  200. triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
  201. triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
  202. triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
  203. triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
  204. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
  205. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
  206. triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
  207. triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
  208. triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
  209. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
  210. triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
  211. triton/backends/amd/include/roctracer/roctracer.h +0 -779
  212. triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
  213. triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
  214. triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
  215. triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
  216. triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
  217. triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
  218. triton/backends/amd/include/roctracer/roctx.h +0 -229
  219. triton/language/_utils.py +0 -21
  220. triton/language/extra/cuda/_experimental_tma.py +0 -106
  221. triton/runtime/tcc/lib/libtcc1-64.a +0 -0
  222. triton/tools/experimental_descriptor.py +0 -32
  223. triton_windows-3.3.1.post19.dist-info/RECORD +0 -260
  224. triton_windows-3.3.1.post19.dist-info/top_level.txt +0 -14
  225. {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
triton/experimental/gluon/language/nvidia/hopper/tma.py ADDED
@@ -0,0 +1,97 @@
+ from __future__ import annotations
+ from typing import List, Tuple, TYPE_CHECKING
+ from dataclasses import dataclass
+ from triton.language.core import base_type, base_value
+ import triton.experimental.gluon.language._core as ttgl
+ from triton.experimental.gluon.language._layouts import NVMMASharedLayout
+ from triton.experimental.gluon.language._core import builtin, _unwrap_if_constexpr
+
+ if TYPE_CHECKING:
+     from triton._C import ir
+
+ __all__ = ["async_copy_global_to_shared", "async_copy_shared_to_global", "store_wait"]
+
+
+ @dataclass(eq=True)
+ class tensor_descriptor_type(base_type):
+     block_type: ttgl.block_type
+     shape_type: ttgl.tuple_type
+     strides_type: ttgl.tuple_type
+     layout: NVMMASharedLayout
+
+     def __str__(self) -> str:
+         return f"tensor_descriptor<{self.block_type}, {self.layout}>"
+
+     def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tensor_descriptor, int]:
+         handle = handles[cursor]
+         cursor += 1
+         shape, cursor = self.shape_type._unflatten_ir(handles, cursor)
+         strides, cursor = self.strides_type._unflatten_ir(handles, cursor)
+         value = tensor_descriptor(handle, shape, strides, self.block_type, layout=self.layout)
+         return value, cursor
+
+     def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
+         is_signed = self.block_type.element_ty.is_int_signed()
+         ty = builder.get_tensor_descriptor_layout_type(
+             self.block_type.to_ir(builder),
+             is_signed,
+             self.layout._to_ir(builder),
+         )
+         out.append(ty)
+         self.shape_type._flatten_ir_types(builder, out)
+         self.strides_type._flatten_ir_types(builder, out)
+
+     def mangle(self) -> str:
+         return f"TD{self.block_type.mangle()}_{self.layout.mangle()}TD"
+
+
+ class tensor_descriptor(base_value):
+
+     def __init__(self, handle, shape: List[ttgl.tensor], strides: List[ttgl.tensor], block_type: ttgl.block_type,
+                  layout: NVMMASharedLayout):
+         self.handle = handle
+         self.shape = ttgl.tuple(shape)
+         self.strides = ttgl.tuple(strides)
+         self.type = tensor_descriptor_type(block_type, shape_type=self.shape.type, strides_type=self.strides.type,
+                                            layout=layout)
+
+     def _flatten_ir(self, handles: List[ir.value]) -> None:
+         handles.append(self.handle)
+         self.shape._flatten_ir(handles)
+         self.strides._flatten_ir(handles)
+
+     @property
+     def block_type(self):
+         return self.type.block_type
+
+     @property
+     def block_shape(self):
+         return self.type.block_type.shape
+
+     @property
+     def dtype(self):
+         return self.type.block_type.element_ty
+
+     @property
+     def layout(self):
+         return self.type.layout
+
+
+ @builtin
+ def async_copy_global_to_shared(tensor_desc, coord, barrier, result, pred=True, _semantic=None):
+     coord = _semantic._convert_to_ir_values(coord, require_i64=False)
+     pred = _semantic.to_tensor(pred)
+     _semantic.builder.create_async_tma_copy_global_to_local(tensor_desc.handle, coord, barrier.handle, result.handle,
+                                                             pred.handle)
+
+
+ @builtin
+ def async_copy_shared_to_global(tensor_desc, coord, src, _semantic=None):
+     coord = _semantic._convert_to_ir_values(coord, require_i64=False)
+     _semantic.builder.create_async_tma_copy_local_to_global(tensor_desc.handle, coord, src.handle)
+
+
+ @builtin
+ def store_wait(pendings, _semantic=None):
+     pendings = _unwrap_if_constexpr(pendings)
+     _semantic.builder.create_async_tma_store_wait(pendings)
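
For orientation: a tensor_descriptor value flattens into a single handle list (the descriptor handle, then the shape handles, then the stride handles), and tensor_descriptor_type._unflatten_ir rebuilds the value by consuming that list at a cursor. Below is a minimal plain-Python sketch of that cursor-based contract; it is illustrative only, using strings as stand-ins for ir.value handles, and is not part of the package code.

    from typing import List, Tuple

    def flatten(desc_handle: str, shape: List[str], strides: List[str]) -> List[str]:
        # Descriptor handle first, then shape, then strides, mirroring _flatten_ir above.
        return [desc_handle, *shape, *strides]

    def unflatten(handles: List[str], cursor: int, rank: int) -> Tuple[tuple, int]:
        # Consume handles starting at the cursor and return the advanced cursor,
        # the same shape of contract as _unflatten_ir above.
        desc = handles[cursor]
        cursor += 1
        shape = handles[cursor:cursor + rank]
        cursor += rank
        strides = handles[cursor:cursor + rank]
        cursor += rank
        return (desc, shape, strides), cursor

    flat = flatten("%desc", ["%s0", "%s1"], ["%d0", "%d1"])
    value, cursor = unflatten(flat, 0, rank=2)
    assert cursor == len(flat)
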
triton/experimental/gluon/nvidia/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from . import hopper
+ from . import blackwell
+
+ __all__ = ["hopper", "blackwell"]
triton/experimental/gluon/nvidia/blackwell.py ADDED
@@ -0,0 +1,3 @@
+ from .hopper import TensorDescriptor
+
+ __all__ = ["TensorDescriptor"]
triton/experimental/gluon/nvidia/hopper.py ADDED
@@ -0,0 +1,45 @@
+ from dataclasses import dataclass
+ from typing import List, Any
+ from triton._utils import validate_block_shape, canonicalize_dtype, get_primitive_bitwidth
+ from triton.experimental.gluon.language._layouts import NVMMASharedLayout
+
+ __all__ = ["TensorDescriptor"]
+
+
+ @dataclass
+ class TensorDescriptor:
+     base: Any
+     shape: List[int]
+     strides: List[int]
+     block_shape: List[int]
+     layout: NVMMASharedLayout
+     padding: str = "zero"
+
+     def __post_init__(self):
+         rank = len(self.shape)
+         assert len(self.strides) == rank, f"rank mismatch: {self}"
+         assert len(self.block_shape) == rank, f"rank mismatch: {self}"
+         assert rank > 0, "rank must not be zero"
+         assert rank <= 5, "rank cannot be more than 5"
+         assert self.base.data_ptr() % 16 == 0, "base must be 16-byte aligned"
+         validate_block_shape(self.block_shape)
+         dtype_str = canonicalize_dtype(self.base.dtype)
+         elem_bytes = get_primitive_bitwidth(dtype_str) // 8
+         for stride in self.strides[:-1]:
+             assert (stride * elem_bytes) % 16 == 0, "strides must be 16-byte aligned"
+         assert self.strides[-1] == 1, "Last dimension must be contiguous"
+         assert isinstance(self.layout, NVMMASharedLayout), "Layout must be NVMMASharedLayout"
+         assert self.padding == "zero" or self.padding == "nan", "Illegal value for padding"
+         if self.padding == "nan":
+             assert self.base.dtype.is_floating_point, "Padding option `nan` is only supported for floating point tensors"
+
+     @staticmethod
+     def from_tensor(tensor: Any, block_shape: List[int], layout: NVMMASharedLayout, padding="zero"):
+         return TensorDescriptor(
+             tensor,
+             tensor.shape,
+             tensor.stride(),
+             block_shape,
+             layout,
+             padding,
+         )
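
The assertions in __post_init__ above pin down what a valid descriptor base looks like on the host side. The sketch below restates those checks against a PyTorch tensor (PyTorch is assumed here; the final from_tensor call is left commented and hypothetical, because the NVMMASharedLayout constructor lives in _layouts.py and is not shown in this diff).

    import torch

    x = torch.empty(128, 256, dtype=torch.float16)

    # The same invariants TensorDescriptor.__post_init__ asserts:
    elem_bytes = x.element_size()
    assert x.data_ptr() % 16 == 0, "base must be 16-byte aligned"
    assert x.stride(-1) == 1, "last dimension must be contiguous"
    assert all((s * elem_bytes) % 16 == 0 for s in x.stride()[:-1]), "leading strides must be 16-byte aligned"

    # Hypothetical usage; constructing an NVMMASharedLayout is not covered by this hunk:
    # desc = TensorDescriptor.from_tensor(x, block_shape=[64, 64], layout=nvmma_layout)
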
triton/knobs.py ADDED
@@ -0,0 +1,546 @@
+ from __future__ import annotations
+
+ import functools
+ import importlib
+ import os
+ import re
+ import subprocess
+ import sysconfig
+ import warnings
+
+ from dataclasses import dataclass
+ from contextlib import contextmanager
+ from typing import cast, Any, Callable, Generator, Generic, Optional, Protocol, Type, TypeVar, TypedDict, TYPE_CHECKING, Union
+
+ from triton._C.libtriton import getenv, getenv_bool  # type: ignore
+
+ if TYPE_CHECKING:
+     from .runtime.cache import CacheManager, RemoteCacheBackend
+     from .runtime.jit import JitFunctionInfo, KernelParam
+     from .compiler.compiler import ASTSource, LazyDict, IRSource
+
+
+ class Env:
+     pass
+
+
+ env = Env()
+
+ propagate_env: bool = True
+
+
+ def setenv(key: str, value: Optional[str]) -> None:
+     if not propagate_env:
+         return
+
+     if value is not None:
+         os.environ[key] = value
+     elif key in os.environ:
+         del os.environ[key]
+
+
+ def toenv(val: Any) -> Union[None, tuple[Optional[str]]]:
+     if val is None:
+         return (None, )
+
+     t = type(val)
+     if t is bool:
+         return ("1" if val else "0", )
+
+     if t is str:
+         return (val, )
+
+     if t is int:
+         return (str(val), )
+
+     return None
+
+
+ # There's an asymmetry here so that e.g. env_nvidia_tool can be specified with a
+ # a string but return an NvidiaTool.
+ SetType = TypeVar("SetType")
+ GetType = TypeVar("GetType")
+
+ _NOTHING = object()
+
+
+ class env_base(Generic[SetType, GetType]):
+
+     def __init__(self, key: str) -> None:
+         self.key = key
+
+     def __set_name__(self, objclass: Type[object], name: str) -> None:
+         self.name = name
+
+     def __get__(self, obj: Optional[object], objclass: Optional[Type[object]]) -> GetType:
+         py_val = obj.__dict__.get(self.name, _NOTHING)
+         if py_val is _NOTHING:
+             return self.get()
+         return self.transform(py_val)
+
+     def get(self) -> GetType:
+         raise NotImplementedError()
+
+     def __set__(self, obj: object, value: Union[SetType, Env]) -> None:
+         if isinstance(value, Env):
+             obj.__dict__.pop(self.name, None)
+         else:
+             obj.__dict__[self.name] = value
+         if env_val := toenv(value):
+             setenv(self.key, env_val[0])
+
+     def __delete__(self, obj: object) -> None:
+         obj.__dict__.pop(self.name, None)
+
+     def transform(self, val: SetType) -> GetType:
+         # See comment about GetType/SetType in their definition above. Only needed
+         # if GetType != SetType.
+         return cast(GetType, val)
+
+
+ class env_str(env_base[str, str]):
+
+     def __init__(self, key: str, default: str):
+         super().__init__(key)
+         self.default = default
+
+     def get(self) -> str:
+         return getenv(self.key, self.default)
+
+
+ class env_str_callable_default(env_base[str, str]):
+
+     def __init__(self, key: str, default_factory: Callable[[], str]):
+         super().__init__(key)
+         self.default_factory = default_factory
+
+     def get(self) -> str:
+         env_val = getenv(self.key)
+         if env_val is None:
+             return self.default_factory()
+         return env_val
+
+
+ class env_bool(env_base[bool, bool]):
+
+     def __init__(self, key: str, default: bool = False) -> None:
+         super().__init__(key)
+         self.default = default
+
+     def get(self) -> bool:
+         return getenv_bool(self.key, self.default)
+
+
+ class env_int(env_base[int, int]):
+
+     def __init__(self, key: str, default: int = 0) -> None:
+         super().__init__(key)
+         self.default = default
+
+     def get(self) -> int:
+         val = getenv(self.key)
+         if val is None:
+             return self.default
+         try:
+             return int(val)
+         except ValueError as exc:
+             raise RuntimeError(f"Unable to use {self.key}={val}: expected int") from exc
+
+
+ ClassType = TypeVar("ClassType")
+
+
+ class env_class(Generic[ClassType], env_base[Optional[Type[ClassType]], Optional[Type[ClassType]]]):
+
+     def __init__(self, key: str, type: str) -> None:
+         super().__init__(key)
+         # We can't pass the type directly to avoid import cycles
+         self.type = type
+
+     def get(self) -> Optional[Type[ClassType]]:
+         val = getenv(self.key)
+         if val is None:
+             return None
+         comps = val.split(":", 1)
+         if len(comps) != 2:
+             raise RuntimeError(f"Unable to read {self.key}: '{val}' isn't of the form MODULE:CLASS")
+         cls = getattr(importlib.import_module(comps[0]), comps[1])
+
+         if not any((c.__name__ == self.type for c in cls.mro())):
+             raise RuntimeError(f"Unable to use '{val}' from {self.key}: not of type '{self.type}'")
+
+         return cast(Type[ClassType], cls)
+
+
+ @dataclass
+ class NvidiaTool:
+     path: str
+     version: str
+
+     @staticmethod
+     @functools.lru_cache
+     def from_path(path: str) -> Optional[NvidiaTool]:
+         try:
+             result = subprocess.check_output([path, "--version"], stderr=subprocess.STDOUT)
+             version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE)
+             if version is None:
+                 return None
+             return NvidiaTool(path, version.group(1))
+         except (subprocess.CalledProcessError, FileNotFoundError):
+             return None
+
+
+ def find_nvidia_tool(binary: str) -> str:
+     path = os.path.join(
+         os.path.dirname(__file__),
+         "backends",
+         "nvidia",
+         "bin",
+         binary,
+     )
+     if os.access(path, os.X_OK):
+         return path
+
+     if os.name == "nt":
+         from triton.windows_utils import find_cuda
+         cuda_bin_path, _, _ = find_cuda()
+         if cuda_bin_path:
+             path = os.path.join(cuda_bin_path, binary)
+             if os.access(path, os.X_OK):
+                 return path
+
+     warnings.warn(f"Failed to find {binary}")
+     return ""
+
+
+ class env_nvidia_tool(env_base[str, NvidiaTool]):
+
+     def __init__(self, binary: str) -> None:
+         binary += sysconfig.get_config_var("EXE")
+         self.binary = binary
+         self.default_path = find_nvidia_tool(binary)
+         super().__init__(f"TRITON_{binary.upper()}_PATH")
+
+     def get(self) -> NvidiaTool:
+         return self.transform(getenv(self.key))
+
+     def transform(self, path: str) -> NvidiaTool:
+         # We still add default as fallback in case the pointed binary isn't
+         # accessible.
+         if path is not None:
+             paths = [path, self.default_path]
+         else:
+             paths = [self.default_path]
+
+         for path in paths:
+             if tool := NvidiaTool.from_path(path):
+                 return tool
+
+         raise RuntimeError(f"Cannot find {self.binary}")
+
+
+ # Separate classes so that types are correct
+ class env_opt_str(env_base[Optional[str], Optional[str]]):
+
+     def get(self) -> Optional[str]:
+         return getenv(self.key)
+
+
+ class env_opt_bool(env_base):
+
+     def get(self) -> Optional[str]:
+         return getenv_bool(self.key, None)
+
+
+ @dataclass(frozen=True)
+ class CompileTimes:
+     """
+     Model holding timing information for an invocation of the compiler.
+
+     All times in microseconds.
+     """
+
+     # Duration of make_ir
+     ir_initialization: int
+
+     # Ordered mapping from lowering stage to duration spent in that stage.
+     # Keyed by stage extension, e.g. ttir, ttgir
+     lowering_stages: list[tuple[str, int]]
+
+     # Duration of saving artifacts/metadata to cache
+     store_results: int
+
+     @property
+     def total_lowering(self) -> int:
+         return sum((stage[1] for stage in self.lowering_stages))
+
+     @property
+     def total(self) -> int:
+         return self.ir_initialization + self.total_lowering + self.store_results
+
+
+ class CompilationListener(Protocol):
+
+     def __call__(self, *, src: Union[ASTSource, IRSource], metadata: dict[str, Any], metadata_group: dict[str, str],
+                  times: CompileTimes, cache_hit: bool) -> None:
+         ...
+
+
+ knobs_type = TypeVar("knobs_type", bound='base_knobs')
+
+
+ class base_knobs:
+
+     @property
+     def knob_descriptors(self) -> dict[str, env_base]:
+         return {
+             k: v
+             # data descriptors live on the class object
+             for k, v in type(self).__dict__.items()
+             if isinstance(v, env_base)
+         }
+
+     @property
+     def knobs(self) -> dict[str, Any]:
+         return {k: getattr(self, k) for k in self.knob_descriptors.keys()}
+
+     def copy(self: knobs_type) -> knobs_type:
+         res = type(self)()
+         res.__dict__.update(self.__dict__)
+         return res
+
+     def reset(self: knobs_type) -> knobs_type:
+         for knob in self.knob_descriptors.keys():
+             delattr(self, knob)
+         return self
+
+     @contextmanager
+     def scope(self) -> Generator[None, None, None]:
+         try:
+             initial_env = {knob.key: getenv(knob.key) for knob in self.knob_descriptors.values()}
+             orig = dict(self.__dict__)
+             yield
+         finally:
+             self.__dict__.clear()
+             self.__dict__.update(orig)
+
+             for k, v in initial_env.items():
+                 if v is not None:
+                     os.environ[k] = v
+                 elif k in os.environ:
+                     del os.environ[k]
+
+
+ class BuildImpl(Protocol):
+
+     def __call__(self, name: str, src: str, srcdir: str, library_dirs: list[str], include_dirs: list[str],
+                  libraries: list[str], /) -> str:
+         ...
+
+
+ class build_knobs(base_knobs):
+     """Configuration controlling how the native compiler is invoked"""
+     cc: env_opt_str = env_opt_str("CC")
+
+     cudacrt_path: env_opt_str = env_opt_str("TRITON_CUDACRT_PATH")
+     cudart_path: env_opt_str = env_opt_str("TRITON_CUDART_PATH")
+
+     impl: Optional[BuildImpl] = None
+
+     @property
+     def backend_dirs(self) -> set[str]:
+         return {path for path in (self.cudacrt_path, self.cudart_path) if path is not None}
+
+
+ class redis_knobs(base_knobs):
+     key_format: env_str = env_str("TRITON_REDIS_KEY_FORMAT", "triton:{key}:{filename}")
+     host: env_str = env_str("TRITON_REDIS_HOST", "localhost")
+     port: env_int = env_int("TRITON_REDIS_PORT", 6379)
+
+
+ cache: cache_knobs
+
+
+ class cache_knobs(base_knobs):
+     home_dir: env_str = env_str("TRITON_HOME", os.path.expanduser("~/"))
+
+     dump_dir = env_str_callable_default("TRITON_DUMP_DIR", lambda: cache.get_triton_dir("dump"))
+     override_dir = env_str_callable_default("TRITON_OVERRIDE_DIR", lambda: cache.get_triton_dir("override"))
+     dir = env_str_callable_default("TRITON_CACHE_DIR", lambda: cache.get_triton_dir("cache"))
+
+     manager_class: env_class[CacheManager] = env_class("TRITON_CACHE_MANAGER", "CacheManager")
+     remote_manager_class: env_class[RemoteCacheBackend] = env_class("TRITON_REMOTE_CACHE_BACKEND", "RemoteCacheBackend")
+
+     def get_triton_dir(self, dirname: str) -> str:
+         return os.path.join(self.home_dir, ".triton", dirname)
+
+
+ class compilation_knobs(base_knobs):
+     override: env_bool = env_bool("TRITON_KERNEL_OVERRIDE")
+     dump_ir: env_bool = env_bool("TRITON_KERNEL_DUMP")
+     store_binary_only: env_bool = env_bool("TRITON_STORE_BINARY_ONLY")
+     always_compile: env_bool = env_bool("TRITON_ALWAYS_COMPILE")
+     # TODO: Use enum to constrain / 'typecheck' the values
+     use_ir_loc: env_opt_str = env_opt_str("USE_IR_LOC")
+     enable_asan: env_bool = env_bool("TRITON_ENABLE_ASAN")
+     disable_line_info: env_bool = env_bool("TRITON_DISABLE_LINE_INFO")
+     front_end_debugging: env_bool = env_bool("TRITON_FRONT_END_DEBUGGING")
+     allow_non_constexpr_globals: env_bool = env_bool("TRITON_ALLOW_NON_CONSTEXPR_GLOBALS")
+     enable_experimental_consan: env_bool = env_bool("TRITON_ENABLE_EXPERIMENTAL_CONSAN")
+     listener: Union[CompilationListener, None] = None
+
+
+ class autotuning_knobs(base_knobs):
+     cache: env_bool = env_bool("TRITON_CACHE_AUTOTUNING")
+     print: env_bool = env_bool("TRITON_PRINT_AUTOTUNING")
+
+
+ class LaunchHook(Protocol):
+     """Hook invoked before and after kernel launching
+     """
+
+     def __call__(self, metadata: LazyDict) -> None:
+         ...
+
+
+ class InitHandleHook(Protocol):
+     """Hook invoked around kernel binary/module loading.
+     module/function can be None for the *start* hook (before loading).
+     """
+
+     def __call__(
+         self,
+         module: Optional[object],
+         function: Optional[Callable],
+         name: str,
+         metadata_group: dict[str, str],
+         hash: str,
+     ) -> None:
+         ...
+
+
+ F = TypeVar("F", bound=Callable)
+
+
+ class HookChain(Generic[F]):
+     """A chain of hooks of the same type F to be called in order.
+     """
+
+     def __init__(self, reversed: bool = False):
+         self.calls: list[F] = []
+         self.reversed = reversed
+
+     def add(self, func: F) -> None:
+         if func not in self.calls:
+             self.calls.append(func)
+
+     def remove(self, func: F) -> None:
+         if func in self.calls:
+             self.calls.remove(func)
+
+     def __call__(self, *args, **kwargs):
+         for call in self.calls if not self.reversed else reversed(self.calls):
+             call(*args, **kwargs)
+
+
+ # This is of the form [attr_name, attr_val]
+ # TODO: Use tuple instead of list for better typing.
+ KernelAttr = list[Union[str, int]]
+
+
+ class JITHookCompileInfo(TypedDict):
+     key: str
+     signature: dict[KernelParam, str]
+     device: int
+     constants: None
+     num_warps: int
+     num_ctas: int
+     num_stages: int
+     enable_fp_fusion: bool
+     launch_cooperative_grid: bool
+     extern_libs: tuple[tuple[str, str], ...]
+     configs: list[dict[tuple[int, ...], list[KernelAttr]]]
+     specialization_data: str
+     is_warmup: bool
+
+
+ class JITHook(Protocol):
+
+     def __call__(self, *, key: str, repr: str, fn: JitFunctionInfo, compile: JITHookCompileInfo, is_manual_warmup: bool,
+                  already_compiled: bool) -> Optional[bool]:
+         ...
+
+
+ class runtime_knobs(base_knobs):
+     interpret: env_bool = env_bool("TRITON_INTERPRET")
+     # debug is on critical path for kernel launches
+     # avoid repeated reads from env-var by calling get directly
+     debug: bool = env_bool("TRITON_DEBUG").get()
+     override_arch: env_opt_str = env_opt_str("TRITON_OVERRIDE_ARCH")
+
+     launch_enter_hook: HookChain[LaunchHook] = HookChain()
+     launch_exit_hook: HookChain[LaunchHook] = HookChain(reversed=True)
+     kernel_load_start_hook: HookChain[InitHandleHook] = HookChain()
+     kernel_load_end_hook: HookChain[InitHandleHook] = HookChain(reversed=True)
+
+     # Hook for inspecting compiled functions and modules
+     jit_cache_hook: Optional[JITHook] = None
+     # Hook to signal that a kernel is done compiling and inspect compiled function.
+     # jit_cache_hook will always be called before compilation and jit_post_compile_hook after.
+     jit_post_compile_hook: Optional[JITHook] = None
+
+
+ class language_knobs(base_knobs):
+     fp32_default: env_opt_str = env_opt_str("TRITON_F32_DEFAULT")
+     default_fp_fusion: env_bool = env_bool("TRITON_DEFAULT_FP_FUSION", True)
+
+
+ class nvidia_knobs(base_knobs):
+     cuobjdump: env_nvidia_tool = env_nvidia_tool("cuobjdump")
+     nvdisasm: env_nvidia_tool = env_nvidia_tool("nvdisasm")
+     ptxas: env_nvidia_tool = env_nvidia_tool("ptxas")
+
+     dump_nvptx: env_bool = env_bool("NVPTX_ENABLE_DUMP")
+     disable_ptxas_opt: env_bool = env_bool("DISABLE_PTXAS_OPT")
+     mock_ptx_version: env_opt_str = env_opt_str("TRITON_MOCK_PTX_VERSION")
+     dump_ptxas_log: env_bool = env_bool("TRITON_DUMP_PTXAS_LOG")
+
+     libdevice_path: env_opt_str = env_opt_str("TRITON_LIBDEVICE_PATH")
+     libcuda_path: env_opt_str = env_opt_str("TRITON_LIBCUDA_PATH")
+
+
+ class amd_knobs(base_knobs):
+     use_buffer_ops: env_bool = env_bool("AMDGCN_USE_BUFFER_OPS")
+     # Note: This requires use_buffer_ops be true to have any effect
+     use_buffer_atomics: env_bool = env_bool("AMDGCN_USE_BUFFER_ATOMICS", True)
+     dump_amdgcn: env_bool = env_bool("AMDGCN_ENABLE_DUMP")
+     libhip_path: env_opt_str = env_opt_str("TRITON_LIBHIP_PATH")
+
+     # We use strs so that we can have a default value based on other runtime info
+     use_block_pingpong: env_opt_bool = env_opt_bool("TRITON_HIP_USE_BLOCK_PINGPONG")
+     use_in_thread_transpose: env_opt_bool = env_opt_bool("TRITON_HIP_USE_IN_THREAD_TRANSPOSE")
+
+     global_prefetch: env_int = env_int("TRITON_HIP_GLOBAL_PREFETCH")
+     local_prefetch: env_int = env_int("TRITON_HIP_LOCAL_PREFETCH")
+     use_async_copy: env_bool = env_bool("TRITON_HIP_USE_ASYNC_COPY")
+     scalarize_packed_fops: env_bool = env_bool("AMDGCN_SCALARIZE_PACKED_FOPS")
+
+
+ class proton_knobs(base_knobs):
+     cupti_dir: env_opt_str = env_opt_str("TRITON_CUPTI_LIB_PATH")
+
+
+ build = build_knobs()
+ redis = redis_knobs()
+ cache = cache_knobs()
+ compilation = compilation_knobs()
+ autotuning = autotuning_knobs()
+ runtime = runtime_knobs()
+ language = language_knobs()
+ nvidia = nvidia_knobs()
+ amd = amd_knobs()
+ proton = proton_knobs()
+
+
+ def refresh_knobs():
+     runtime.debug = env_bool("TRITON_DEBUG").get()
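
Taken together, these descriptors give every knob three sources of truth: an in-process override, the environment variable named in its constructor, and a coded default. A short usage sketch follows (assuming the installed wheel, which ships this file as the triton.knobs module): reads fall through to the environment variable and then the default, assignments override in-process and propagate to os.environ while propagate_env is true, assigning the env sentinel restores environment-driven behaviour, and scope() rolls both the attributes and the environment back on exit.

    import os
    import triton.knobs as knobs

    print(knobs.redis.port)  # 6379 unless TRITON_REDIS_PORT is set in the environment

    with knobs.compilation.scope():              # snapshots knob overrides and their env vars
        knobs.compilation.always_compile = True  # also sets TRITON_ALWAYS_COMPILE=1
        assert os.environ["TRITON_ALWAYS_COMPILE"] == "1"
    # leaving the scope restores both the attribute and TRITON_ALWAYS_COMPILE

    knobs.compilation.always_compile = knobs.env  # explicit reset back to env var / default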