triton-windows 3.3.1.post19__cp311-cp311-win_amd64.whl → 3.5.0.post21__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic.

Files changed (225)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +11 -2
  3. triton/_filecheck.py +97 -0
  4. triton/_internal_testing.py +95 -18
  5. triton/_utils.py +112 -21
  6. triton/backends/__init__.py +20 -23
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +161 -119
  9. triton/backends/amd/driver.c +118 -46
  10. triton/backends/amd/driver.py +274 -96
  11. triton/backends/compiler.py +7 -21
  12. triton/backends/driver.py +13 -0
  13. triton/backends/nvidia/bin/ptxas.exe +0 -0
  14. triton/backends/nvidia/compiler.py +163 -106
  15. triton/backends/nvidia/driver.c +166 -101
  16. triton/backends/nvidia/driver.py +384 -202
  17. triton/compiler/__init__.py +5 -2
  18. triton/compiler/code_generator.py +439 -231
  19. triton/compiler/compiler.py +152 -84
  20. triton/experimental/__init__.py +0 -0
  21. triton/experimental/gluon/__init__.py +5 -0
  22. triton/experimental/gluon/_compiler.py +0 -0
  23. triton/experimental/gluon/_runtime.py +102 -0
  24. triton/experimental/gluon/language/__init__.py +119 -0
  25. triton/experimental/gluon/language/_core.py +490 -0
  26. triton/experimental/gluon/language/_layouts.py +583 -0
  27. triton/experimental/gluon/language/_math.py +20 -0
  28. triton/experimental/gluon/language/_semantic.py +380 -0
  29. triton/experimental/gluon/language/_standard.py +80 -0
  30. triton/experimental/gluon/language/amd/__init__.py +4 -0
  31. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  32. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  33. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  34. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  35. triton/experimental/gluon/language/extra/__init__.py +3 -0
  36. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  37. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  38. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  39. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  40. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
  41. triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
  42. triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
  43. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
  44. triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
  45. triton/experimental/gluon/nvidia/__init__.py +4 -0
  46. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  47. triton/experimental/gluon/nvidia/hopper.py +45 -0
  48. triton/knobs.py +546 -0
  49. triton/language/__init__.py +50 -19
  50. triton/language/core.py +909 -572
  51. triton/language/extra/cuda/__init__.py +10 -7
  52. triton/language/extra/cuda/gdc.py +42 -0
  53. triton/language/extra/cuda/libdevice.py +394 -394
  54. triton/language/extra/cuda/utils.py +21 -21
  55. triton/language/extra/hip/__init__.py +3 -1
  56. triton/language/extra/hip/libdevice.py +120 -104
  57. triton/language/extra/hip/utils.py +35 -0
  58. triton/language/extra/libdevice.py +4 -0
  59. triton/language/math.py +65 -66
  60. triton/language/random.py +12 -2
  61. triton/language/semantic.py +1757 -1768
  62. triton/language/standard.py +127 -62
  63. triton/language/target_info.py +54 -0
  64. triton/runtime/_allocation.py +15 -3
  65. triton/runtime/_async_compile.py +55 -0
  66. triton/runtime/autotuner.py +117 -60
  67. triton/runtime/build.py +83 -17
  68. triton/runtime/cache.py +61 -47
  69. triton/runtime/driver.py +25 -47
  70. triton/runtime/interpreter.py +95 -50
  71. triton/runtime/jit.py +445 -248
  72. triton/runtime/tcc/include/_mingw.h +8 -10
  73. triton/runtime/tcc/include/assert.h +5 -0
  74. triton/runtime/tcc/include/errno.h +1 -1
  75. triton/runtime/tcc/include/float.h +21 -3
  76. triton/runtime/tcc/include/iso646.h +36 -0
  77. triton/runtime/tcc/include/limits.h +5 -0
  78. triton/runtime/tcc/include/malloc.h +2 -2
  79. triton/runtime/tcc/include/math.h +21 -261
  80. triton/runtime/tcc/include/stdalign.h +16 -0
  81. triton/runtime/tcc/include/stdarg.h +5 -70
  82. triton/runtime/tcc/include/stdatomic.h +171 -0
  83. triton/runtime/tcc/include/stddef.h +7 -19
  84. triton/runtime/tcc/include/stdlib.h +15 -4
  85. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  86. triton/runtime/tcc/include/sys/stat.h +2 -2
  87. triton/runtime/tcc/include/sys/types.h +5 -0
  88. triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
  89. triton/runtime/tcc/include/tccdefs.h +342 -0
  90. triton/runtime/tcc/include/tgmath.h +89 -0
  91. triton/runtime/tcc/include/uchar.h +33 -0
  92. triton/runtime/tcc/include/unistd.h +1 -0
  93. triton/runtime/tcc/include/winapi/qos.h +72 -0
  94. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  95. triton/runtime/tcc/include/winapi/winbase.h +9 -2
  96. triton/runtime/tcc/include/winapi/wincon.h +8 -0
  97. triton/runtime/tcc/include/winapi/windows.h +1 -1
  98. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  99. triton/runtime/tcc/include/winapi/winnt.h +9 -7
  100. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  101. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  102. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  103. triton/runtime/tcc/lib/libtcc1.a +0 -0
  104. triton/runtime/tcc/lib/python314.def +1800 -0
  105. triton/runtime/tcc/lib/python314t.def +1809 -0
  106. triton/runtime/tcc/libtcc.dll +0 -0
  107. triton/runtime/tcc/tcc.exe +0 -0
  108. triton/testing.py +16 -12
  109. triton/tools/compile.py +62 -14
  110. triton/tools/disasm.py +3 -4
  111. triton/tools/extra/cuda/compile.c +1 -0
  112. triton/tools/extra/hip/compile.cpp +66 -0
  113. triton/tools/extra/hip/compile.h +13 -0
  114. triton/tools/ragged_tma.py +92 -0
  115. triton/tools/tensor_descriptor.py +34 -0
  116. triton/windows_utils.py +52 -81
  117. {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +8 -4
  118. triton_windows-3.5.0.post21.dist-info/RECORD +217 -0
  119. triton_windows-3.5.0.post21.dist-info/entry_points.txt +3 -0
  120. triton_windows-3.5.0.post21.dist-info/licenses/LICENSE +23 -0
  121. triton_windows-3.5.0.post21.dist-info/top_level.txt +1 -0
  122. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
  123. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
  124. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
  125. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
  126. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
  127. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
  128. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
  129. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
  130. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
  131. triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
  132. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
  133. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
  134. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
  135. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
  136. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
  137. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
  138. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
  139. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
  140. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
  141. triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
  142. triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
  143. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
  144. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
  145. triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
  146. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
  147. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
  148. triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
  149. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
  150. triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
  151. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
  152. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
  153. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
  154. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
  155. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
  156. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
  157. triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
  158. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
  159. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
  160. triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
  161. triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
  162. triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
  163. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
  164. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
  165. triton/backends/amd/include/hip/channel_descriptor.h +0 -39
  166. triton/backends/amd/include/hip/device_functions.h +0 -38
  167. triton/backends/amd/include/hip/driver_types.h +0 -468
  168. triton/backends/amd/include/hip/hip_bf16.h +0 -36
  169. triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
  170. triton/backends/amd/include/hip/hip_common.h +0 -100
  171. triton/backends/amd/include/hip/hip_complex.h +0 -38
  172. triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
  173. triton/backends/amd/include/hip/hip_deprecated.h +0 -95
  174. triton/backends/amd/include/hip/hip_ext.h +0 -161
  175. triton/backends/amd/include/hip/hip_fp16.h +0 -36
  176. triton/backends/amd/include/hip/hip_fp8.h +0 -33
  177. triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
  178. triton/backends/amd/include/hip/hip_hcc.h +0 -24
  179. triton/backends/amd/include/hip/hip_math_constants.h +0 -36
  180. triton/backends/amd/include/hip/hip_profile.h +0 -27
  181. triton/backends/amd/include/hip/hip_runtime.h +0 -75
  182. triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
  183. triton/backends/amd/include/hip/hip_texture_types.h +0 -29
  184. triton/backends/amd/include/hip/hip_vector_types.h +0 -41
  185. triton/backends/amd/include/hip/hip_version.h +0 -17
  186. triton/backends/amd/include/hip/hiprtc.h +0 -421
  187. triton/backends/amd/include/hip/library_types.h +0 -78
  188. triton/backends/amd/include/hip/math_functions.h +0 -42
  189. triton/backends/amd/include/hip/surface_types.h +0 -63
  190. triton/backends/amd/include/hip/texture_types.h +0 -194
  191. triton/backends/amd/include/hsa/Brig.h +0 -1131
  192. triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
  193. triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
  194. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
  195. triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
  196. triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
  197. triton/backends/amd/include/hsa/hsa.h +0 -5738
  198. triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
  199. triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
  200. triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
  201. triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
  202. triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
  203. triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
  204. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
  205. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
  206. triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
  207. triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
  208. triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
  209. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
  210. triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
  211. triton/backends/amd/include/roctracer/roctracer.h +0 -779
  212. triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
  213. triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
  214. triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
  215. triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
  216. triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
  217. triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
  218. triton/backends/amd/include/roctracer/roctx.h +0 -229
  219. triton/language/_utils.py +0 -21
  220. triton/language/extra/cuda/_experimental_tma.py +0 -106
  221. triton/runtime/tcc/lib/libtcc1-64.a +0 -0
  222. triton/tools/experimental_descriptor.py +0 -32
  223. triton_windows-3.3.1.post19.dist-info/RECORD +0 -260
  224. triton_windows-3.3.1.post19.dist-info/top_level.txt +0 -14
  225. {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
triton/_C/libtriton.pyd CHANGED
Binary file
triton/__init__.py CHANGED
@@ -1,5 +1,5 @@
  """isort:skip_file"""
- __version__ = '3.3.1'
+ __version__ = '3.5.0'

  # ---------------------------------------
  # Note: import order is significant here.
@@ -17,7 +17,8 @@ from .runtime import (
      InterpreterError,
      MockTensor,
  )
- from .runtime.jit import jit
+ from .runtime.jit import constexpr_function, jit
+ from .runtime._async_compile import AsyncCompileMode, FutureKernel
  from .compiler import compile, CompilationError
  from .errors import TritonError
  from .runtime._allocation import set_allocator
@@ -26,12 +27,17 @@ from . import language
  from . import testing
  from . import tools

+ must_use_result = language.core.must_use_result
+
  __all__ = [
+     "AsyncCompileMode",
      "autotune",
      "cdiv",
      "CompilationError",
      "compile",
      "Config",
+     "constexpr_function",
+     "FutureKernel",
      "heuristics",
      "InterpreterError",
      "jit",
@@ -39,6 +45,7 @@ __all__ = [
      "KernelInterface",
      "language",
      "MockTensor",
+     "must_use_result",
      "next_power_of_2",
      "OutOfResources",
      "reinterpret",
@@ -56,10 +63,12 @@ __all__ = [
  # -------------------------------------


+ @constexpr_function
  def cdiv(x: int, y: int):
      return (x + y - 1) // y


+ @constexpr_function
  def next_power_of_2(n: int):
      """Return the smallest power of 2 greater than or equal to n"""
      n -= 1
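Note (editor's illustration, not part of the diff): cdiv and next_power_of_2 are now wrapped in the new constexpr_function decorator, so they stay usable on the host exactly as before and can additionally be called at compile time from jitted code. A minimal host-side sketch, using only the helpers shown in the hunk above:

import triton

# Hedged example: compute a 1-D launch grid on the host.
n_elements, BLOCK = 1000, 128
grid = (triton.cdiv(n_elements, BLOCK), )   # -> (8,)
assert triton.next_power_of_2(100) == 128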
triton/_filecheck.py ADDED
@@ -0,0 +1,97 @@
+ import functools
+ import os
+ import inspect
+ import subprocess
+ import tempfile
+
+ import triton
+ from triton.compiler import ASTSource, make_backend
+ from triton.backends.compiler import GPUTarget
+ from triton.experimental.gluon._runtime import GluonASTSource
+ from triton.runtime.jit import create_function_from_signature
+ from triton._C.libtriton import ir
+
+ # ===-----------------------------------------------------------------------===#
+ # filecheck_test
+ # ===-----------------------------------------------------------------------===#
+
+ # Stub target for testing the frontend.
+ stub_target = GPUTarget("cuda", 100, 32)
+
+ triton_dir = os.path.dirname(__file__)
+ filecheck_path = os.path.join(triton_dir, "FileCheck")
+
+
+ class MatchError(ValueError):
+
+     def __init__(self, message, module_str):
+         super().__init__(message)
+         self.module_str = module_str
+
+     def __str__(self):
+         return f"{super().__str__()}\n{self.module_str}"
+
+
+ def run_filecheck(name, module_str, check_template):
+     with tempfile.TemporaryDirectory() as tempdir:
+         temp_module = os.path.join(tempdir, "module")
+         with open(temp_module, "w") as temp:
+             temp.write(module_str)
+
+         temp_expected = os.path.join(tempdir, "expected")
+         with open(temp_expected, "w") as temp:
+             temp.write(check_template)
+
+         try:
+             subprocess.check_output(
+                 [filecheck_path, temp_expected, "--input-file", temp_module, "--dump-input-context=50"],
+                 stderr=subprocess.STDOUT)
+         except subprocess.CalledProcessError as error:
+             decoded = error.output.decode('unicode_escape')
+             raise ValueError(decoded)
+
+
+ def run_parser(kernel_fn, args=(), kwargs={}, target=stub_target):
+     if "sanitize_overflow" not in kwargs:
+         kwargs = dict(kwargs)
+         kwargs["sanitize_overflow"] = False
+     backend = make_backend(target)
+     binder = create_function_from_signature(
+         kernel_fn.signature,
+         kernel_fn.params,
+         backend,
+     )
+
+     bound_args, specialization, options = binder(*args, **kwargs)
+     options, signature, constexprs, attrs = kernel_fn._pack_args(backend, kwargs, bound_args, specialization, options)
+     source_cls = GluonASTSource if kernel_fn.is_gluon() else ASTSource
+     src = source_cls(kernel_fn, signature, constexprs, attrs)
+
+     context = ir.context()
+     ir.load_dialects(context)
+     backend.load_dialects(context)
+
+     codegen_fns = backend.get_codegen_implementation(options)
+     module_map = backend.get_module_map()
+     module = src.make_ir(target, options, codegen_fns, module_map, context)
+     assert module.verify()
+     return module
+
+
+ def run_filecheck_test(kernel_fn):
+     assert isinstance(kernel_fn, triton.runtime.JITFunction)
+     check_template = inspect.getsource(kernel_fn.fn)
+     if check_template is None:
+         raise ValueError("kernel function must have a docstring with FileCheck template")
+     mlir_module = run_parser(kernel_fn)
+
+     run_filecheck("placeholder", mlir_module.str_nodebug(), check_template)
+
+
+ def filecheck_test(fn):
+
+     @functools.wraps(fn)
+     def test_fn():
+         run_filecheck_test(fn)
+
+     return test_fn
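Note (editor's illustration): filecheck_test lowers the decorated @triton.jit function to MLIR via run_parser and then runs FileCheck over the module, using the function's own source as the check file, so CHECK directives live in comments inside the kernel. A hedged sketch, assuming a FileCheck binary sits next to this module; the kernel body and CHECK line are illustrative, not taken from this release:

import triton
import triton.language as tl
from triton._filecheck import filecheck_test


@filecheck_test
@triton.jit
def test_zeros_kernel():
    # CHECK: tt.func
    x = tl.full([128], 0, tl.int32)  # noqa: F841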
triton/_internal_testing.py CHANGED
@@ -4,11 +4,11 @@ import numpy as np
  import torch
  import triton
  import triton.language as tl
- from triton.backends.nvidia.compiler import _path_to_binary
+ from triton import knobs
+ from typing import Optional, Set, Union
  import pytest

  from numpy.random import RandomState
- from typing import Optional, Union
  from triton.runtime.jit import TensorWrapper, reinterpret, type_canonicalisation_dict

  int_dtypes = ['int8', 'int16', 'int32', 'int64']
@@ -20,6 +20,7 @@ dtypes = integral_dtypes + float_dtypes
  dtypes_with_bfloat16 = dtypes + ['bfloat16']
  torch_float8_dtypes = ['float8_e4m3fn', 'float8_e5m2']
  torch_dtypes = ['bool'] + int_dtypes + ['uint8'] + float_dtypes + ['bfloat16']
+ tma_dtypes = sorted(set(dtypes_with_bfloat16) - {"int64", "uint64", "float64"})


  def is_interpreter():
@@ -37,38 +38,58 @@ def is_cuda():
      return False if target is None else target.backend == "cuda"


- def is_hopper():
+ def is_ampere_or_newer():
+     return is_cuda() and torch.cuda.get_device_capability()[0] >= 8
+
+
+ def is_blackwell():
+     return is_cuda() and torch.cuda.get_device_capability()[0] == 10
+
+
+ def is_hopper_or_newer():
      return is_cuda() and torch.cuda.get_device_capability()[0] >= 9


+ def is_hopper():
+     return is_cuda() and torch.cuda.get_device_capability()[0] == 9
+
+
  def is_hip():
      target = get_current_target()
      return False if target is None else target.backend == "hip"


- def is_hip_mi200():
+ def is_hip_cdna2():
      target = get_current_target()
-     if target is None or target.backend != 'hip':
-         return False
-     return target.arch == 'gfx90a'
+     return target is not None and target.backend == 'hip' and target.arch == 'gfx90a'


- def is_hip_mi300():
+ def is_hip_cdna3():
      target = get_current_target()
-     if target is None or target.backend != 'hip':
-         return False
-     return target.arch in ('gfx940', 'gfx941', 'gfx942')
+     return target is not None and target.backend == 'hip' and target.arch == 'gfx942'


- def is_hip_mi350():
+ def is_hip_cdna4():
      target = get_current_target()
-     if target is None or target.backend != 'hip':
-         return False
-     return target.arch in ('gfx950')
+     return target is not None and target.backend == 'hip' and target.arch == 'gfx950'
+
+
+ def is_hip_gfx11():
+     target = get_current_target()
+     return target is not None and target.backend == 'hip' and 'gfx11' in target.arch
+
+
+ def is_hip_gfx12():
+     target = get_current_target()
+     return target is not None and target.backend == 'hip' and 'gfx12' in target.arch


  def is_hip_cdna():
-     return is_hip_mi200() or is_hip_mi300() or is_hip_mi350()
+     return is_hip_cdna2() or is_hip_cdna3() or is_hip_cdna4()
+
+
+ def get_hip_lds_size():
+     return 163840 if is_hip_cdna4() else 65536


  def is_xpu():
@@ -131,7 +152,7 @@ def to_triton(x: np.ndarray, device, dst_type=None) -> Union[TensorWrapper, torc


  def str_to_triton_dtype(x: str) -> tl.dtype:
-     return tl.str_to_ty(type_canonicalisation_dict[x])
+     return tl.str_to_ty(type_canonicalisation_dict[x], None)


  def torch_dtype_name(dtype) -> str:
@@ -161,7 +182,7 @@ def supports_tma(byval_only=False):
          return True
      if not is_cuda():
          return False
-     _, cuda_version = _path_to_binary("ptxas")
+     cuda_version = knobs.nvidia.ptxas.version
      min_cuda_version = (12, 0) if byval_only else (12, 3)
      cuda_version_tuple = tuple(map(int, cuda_version.split(".")))
      assert len(cuda_version_tuple) == 2, cuda_version_tuple
@@ -176,3 +197,59 @@ def tma_skip_msg(byval_only=False):


  requires_tma = pytest.mark.skipif(not supports_tma(), reason=tma_skip_msg())
+
+
+ def default_alloc_fn(size: int, align: int, _):
+     return torch.empty(size, dtype=torch.int8, device="cuda")
+
+
+ def unwrap_tensor(t: Union[torch.Tensor, triton.runtime.jit.TensorWrapper]) -> torch.Tensor:
+     if isinstance(t, triton.runtime.jit.TensorWrapper):
+         return t.base
+     return t
+
+
+ def _fresh_knobs_impl(skipped_attr: Optional[Set[str]] = None):
+     from triton import knobs
+
+     if skipped_attr is None:
+         skipped_attr = set()
+
+     monkeypatch = pytest.MonkeyPatch()
+
+     knobs_map = {
+         name: knobset
+         for name, knobset in knobs.__dict__.items()
+         if isinstance(knobset, knobs.base_knobs) and knobset != knobs.base_knobs and name not in skipped_attr
+     }
+
+     # We store which variables we need to unset below in finally because
+     # monkeypatch doesn't appear to reset variables that were never set
+     # before the monkeypatch.delenv call below.
+     env_to_unset = []
+     prev_propagate_env = knobs.propagate_env
+
+     def fresh_function():
+         nonlocal env_to_unset
+         for name, knobset in knobs_map.items():
+             setattr(knobs, name, knobset.copy().reset())
+             for knob in knobset.knob_descriptors.values():
+                 if knob.key in os.environ:
+                     monkeypatch.delenv(knob.key, raising=False)
+                 else:
+                     env_to_unset.append(knob.key)
+         knobs.propagate_env = True
+         return knobs
+
+     def reset_function():
+         for name, knobset in knobs_map.items():
+             setattr(knobs, name, knobset)
+         # `undo` should be placed before `del os.environ`
+         # Otherwise, it may restore environment variables that monkeypatch deleted
+         monkeypatch.undo()
+         for k in env_to_unset:
+             if k in os.environ:
+                 del os.environ[k]
+         knobs.propagate_env = prev_propagate_env
+
+     return fresh_function, reset_function
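Note (editor's illustration): _fresh_knobs_impl returns a setup/teardown pair; a pytest fixture wired on top of it (the fixture name below is hypothetical) might look like:

import pytest
from triton._internal_testing import _fresh_knobs_impl


@pytest.fixture
def fresh_knobs():
    # Give the test a clean copy of triton.knobs, then restore the
    # original knob objects and environment variables afterwards.
    fresh_function, reset_function = _fresh_knobs_impl()
    try:
        yield fresh_function()
    finally:
        reset_function()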
triton/_utils.py CHANGED
@@ -1,35 +1,126 @@
+ from __future__ import annotations
+
  from functools import reduce
+ from typing import Any, Callable, TYPE_CHECKING, Union, List, Dict
+
+ if TYPE_CHECKING:
+     from .language import core
+     IterableType = Union[list[Any], tuple[Any, ...], core.tuple, core.tuple_type]
+     ObjPath = tuple[int, ...]

+ TRITON_MAX_TENSOR_NUMEL = 1048576

- def get_iterable_path(iterable, path):
-     return reduce(lambda a, idx: a[idx], path, iterable)

+ def get_iterable_path(iterable: IterableType, path: ObjPath) -> Any:
+     return reduce(lambda a, idx: a[idx], path, iterable)  # type: ignore[index]

- def set_iterable_path(iterable, path, val):
+
+ def set_iterable_path(iterable: IterableType, path: tuple[int, ...], val: Any):
+     from .language import core
+     assert len(path) != 0
      prev = iterable if len(path) == 1 else get_iterable_path(iterable, path[:-1])
-     prev[path[-1]] = val
+     assert isinstance(prev, core.tuple)
+     prev._setitem(path[-1], val)


- def find_paths_if(iterable, pred):
+ def find_paths_if(iterable: Union[IterableType, Any], pred: Callable[[ObjPath, Any], bool]) -> list[ObjPath]:
      from .language import core
-     is_iterable = lambda x: isinstance(x, (list, tuple, core.tuple, core.tuple_type))
-     ret = dict()
+     is_iterable: Callable[[Any], bool] = lambda x: isinstance(x, (list, tuple, core.tuple, core.tuple_type))
+     # We need to use dict so that ordering is maintained, while set doesn't guarantee order
+     ret: dict[ObjPath, None] = {}

-     def _impl(current, path):
-         path = (path[0], ) if len(path) == 1 else tuple(path)
+     def _impl(path: tuple[int, ...], current: Any):
          if is_iterable(current):
              for idx, item in enumerate(current):
-                 _impl(item, path + (idx, ))
+                 _impl((*path, idx), item)
          elif pred(path, current):
-             if len(path) == 1:
-                 ret[(path[0], )] = None
-             else:
-                 ret[tuple(path)] = None
-
-     if is_iterable(iterable):
-         _impl(iterable, [])
-     elif pred(list(), iterable):
-         ret = {tuple(): None}
-     else:
-         ret = dict()
+             ret[path] = None
+
+     _impl((), iterable)
+
      return list(ret.keys())
+
+
+ def is_power_of_two(x):
+     return (x & (x - 1)) == 0
+
+
+ def validate_block_shape(shape: List[int]):
+     numel = 1
+     for i, d in enumerate(shape):
+         if not isinstance(d, int):
+             raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d)}]")
+         if not is_power_of_two(d):
+             raise ValueError(f"Shape element {i} must be a power of 2")
+         numel *= d
+
+     if numel > TRITON_MAX_TENSOR_NUMEL:
+         raise ValueError(f"numel ({numel}) exceeds triton maximum tensor numel ({TRITON_MAX_TENSOR_NUMEL})")
+     return numel
+
+
+ type_canonicalisation_dict = {
+     # we canonicalise all bools to be unsigned:
+     "bool": "u1",
+     "int1": "u1",
+     "uint1": "u1",
+     "i1": "u1",
+     # floating-point dtypes:
+     "float8e4nv": "fp8e4nv",
+     "float8e5": "fp8e5",
+     "float8e4b15": "fp8e4b15",
+     "float8_e4m3fn": "fp8e4nv",
+     "float8e4b8": "fp8e4b8",
+     "float8_e4m3fnuz": "fp8e4b8",
+     "float8_e5m2": "fp8e5",
+     "float8e5b16": "fp8e5b16",
+     "float8_e5m2fnuz": "fp8e5b16",
+     "half": "fp16",
+     "float16": "fp16",
+     "bfloat16": "bf16",
+     "float": "fp32",
+     "float32": "fp32",
+     "double": "fp64",
+     "float64": "fp64",
+     # signed integers:
+     "int8": "i8",
+     "int16": "i16",
+     "int": "i32",
+     "int32": "i32",
+     "int64": "i64",
+     # unsigned integers:
+     "uint8": "u8",
+     "uint16": "u16",
+     "uint32": "u32",
+     "uint64": "u64",
+     "void": "void",
+ }
+
+ for v in list(type_canonicalisation_dict.values()):
+     type_canonicalisation_dict[v] = v
+
+
+ def canonicalize_dtype(dtype):
+     dtype_str = str(dtype).split(".")[-1]
+     return type_canonicalisation_dict[dtype_str]
+
+
+ BITWIDTH_DICT: Dict[str, int] = {
+     **{f"u{n}": n
+        for n in (1, 8, 16, 32, 64)},
+     **{f"i{n}": n
+        for n in (1, 8, 16, 32, 64)},
+     **{f"fp{n}": n
+        for n in (16, 32, 64)},
+     **{f"fp8{suffix}": 8
+        for suffix in ("e4nv", "e4b15", "e4b8", "e5", "e5b16")},
+     "bf16": 16,
+     "void": 0,
+ }
+
+ for k, v in type_canonicalisation_dict.items():
+     BITWIDTH_DICT[k] = BITWIDTH_DICT[v]
+
+
+ def get_primitive_bitwidth(dtype: str) -> int:
+     return BITWIDTH_DICT[dtype]
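Note (editor's illustration): the dtype tables above drive the new helpers; the expected values below follow directly from those tables:

from triton._utils import canonicalize_dtype, get_primitive_bitwidth, validate_block_shape

assert canonicalize_dtype("float16") == "fp16"    # via type_canonicalisation_dict
assert get_primitive_bitwidth("bfloat16") == 16   # "bfloat16" -> "bf16" -> 16
assert validate_block_shape([64, 32]) == 2048     # power-of-2 dims within TRITON_MAX_TENSOR_NUMEL
# validate_block_shape([100]) raises ValueError: 100 is not a power of 2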
triton/backends/__init__.py CHANGED
@@ -1,20 +1,22 @@
- import os
- import importlib.util
+ import importlib
  import inspect
+ import sys
  from dataclasses import dataclass
+ from typing import Type, TypeVar, Union
+ from types import ModuleType
  from .driver import DriverBase
  from .compiler import BaseBackend

+ if sys.version_info >= (3, 10):
+     from importlib.metadata import entry_points
+ else:
+     from importlib_metadata import entry_points

- def _load_module(name, path):
-     spec = importlib.util.spec_from_file_location(name, path)
-     module = importlib.util.module_from_spec(spec)
-     spec.loader.exec_module(module)
-     return module
+ T = TypeVar("T", bound=Union[BaseBackend, DriverBase])


- def _find_concrete_subclasses(module, base_class):
-     ret = []
+ def _find_concrete_subclasses(module: ModuleType, base_class: Type[T]) -> Type[T]:
+     ret: list[Type[T]] = []
      for attr_name in dir(module):
          attr = getattr(module, attr_name)
          if isinstance(attr, type) and issubclass(attr, base_class) and not inspect.isabstract(attr):
@@ -28,23 +30,18 @@ def _find_concrete_subclasses(module, base_class):

  @dataclass(frozen=True)
  class Backend:
-     compiler: BaseBackend = None
-     driver: DriverBase = None
+     compiler: Type[BaseBackend]
+     driver: Type[DriverBase]


- def _discover_backends():
+ def _discover_backends() -> dict[str, Backend]:
      backends = dict()
-     root = os.path.dirname(__file__)
-     for name in os.listdir(root):
-         if not os.path.isdir(os.path.join(root, name)):
-             continue
-         if name.startswith('__'):
-             continue
-         compiler = _load_module(name, os.path.join(root, name, 'compiler.py'))
-         driver = _load_module(name, os.path.join(root, name, 'driver.py'))
-         backends[name] = Backend(_find_concrete_subclasses(compiler, BaseBackend),
-                                  _find_concrete_subclasses(driver, DriverBase))
+     for ep in entry_points().select(group="triton.backends"):
+         compiler = importlib.import_module(f"{ep.value}.compiler")
+         driver = importlib.import_module(f"{ep.value}.driver")
+         backends[ep.name] = Backend(_find_concrete_subclasses(compiler, BaseBackend),  # type: ignore
+                                     _find_concrete_subclasses(driver, DriverBase))  # type: ignore
      return backends


- backends = _discover_backends()
+ backends: dict[str, Backend] = _discover_backends()
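Note (editor's illustration): backend discovery now goes through the "triton.backends" entry-point group instead of scanning the package directory. An out-of-tree backend would register itself roughly as sketched below; the package name my_triton_backend and the group entry shown are hypothetical:

# In the backend package's pyproject.toml:
#   [project.entry-points."triton.backends"]
#   mygpu = "my_triton_backend"
# _discover_backends() then imports my_triton_backend.compiler and
# my_triton_backend.driver and exposes them alongside the built-in backends:
from triton.backends import backends

print(sorted(backends))  # e.g. ['amd', 'nvidia'] for this wheel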
triton/backends/amd/__init__.py CHANGED
File without changes