triton-windows 3.3.1.post19__cp313-cp313-win_amd64.whl → 3.5.0.post21__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of triton-windows might be problematic.

Files changed (225)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +11 -2
  3. triton/_filecheck.py +97 -0
  4. triton/_internal_testing.py +95 -18
  5. triton/_utils.py +112 -21
  6. triton/backends/__init__.py +20 -23
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +161 -119
  9. triton/backends/amd/driver.c +118 -46
  10. triton/backends/amd/driver.py +274 -96
  11. triton/backends/compiler.py +7 -21
  12. triton/backends/driver.py +13 -0
  13. triton/backends/nvidia/bin/ptxas.exe +0 -0
  14. triton/backends/nvidia/compiler.py +163 -106
  15. triton/backends/nvidia/driver.c +166 -101
  16. triton/backends/nvidia/driver.py +384 -202
  17. triton/compiler/__init__.py +5 -2
  18. triton/compiler/code_generator.py +439 -231
  19. triton/compiler/compiler.py +152 -84
  20. triton/experimental/__init__.py +0 -0
  21. triton/experimental/gluon/__init__.py +5 -0
  22. triton/experimental/gluon/_compiler.py +0 -0
  23. triton/experimental/gluon/_runtime.py +102 -0
  24. triton/experimental/gluon/language/__init__.py +119 -0
  25. triton/experimental/gluon/language/_core.py +490 -0
  26. triton/experimental/gluon/language/_layouts.py +583 -0
  27. triton/experimental/gluon/language/_math.py +20 -0
  28. triton/experimental/gluon/language/_semantic.py +380 -0
  29. triton/experimental/gluon/language/_standard.py +80 -0
  30. triton/experimental/gluon/language/amd/__init__.py +4 -0
  31. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  32. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  33. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  34. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  35. triton/experimental/gluon/language/extra/__init__.py +3 -0
  36. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  37. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  38. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  39. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  40. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
  41. triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
  42. triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
  43. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
  44. triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
  45. triton/experimental/gluon/nvidia/__init__.py +4 -0
  46. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  47. triton/experimental/gluon/nvidia/hopper.py +45 -0
  48. triton/knobs.py +546 -0
  49. triton/language/__init__.py +50 -19
  50. triton/language/core.py +909 -572
  51. triton/language/extra/cuda/__init__.py +10 -7
  52. triton/language/extra/cuda/gdc.py +42 -0
  53. triton/language/extra/cuda/libdevice.py +394 -394
  54. triton/language/extra/cuda/utils.py +21 -21
  55. triton/language/extra/hip/__init__.py +3 -1
  56. triton/language/extra/hip/libdevice.py +120 -104
  57. triton/language/extra/hip/utils.py +35 -0
  58. triton/language/extra/libdevice.py +4 -0
  59. triton/language/math.py +65 -66
  60. triton/language/random.py +12 -2
  61. triton/language/semantic.py +1757 -1768
  62. triton/language/standard.py +127 -62
  63. triton/language/target_info.py +54 -0
  64. triton/runtime/_allocation.py +15 -3
  65. triton/runtime/_async_compile.py +55 -0
  66. triton/runtime/autotuner.py +117 -60
  67. triton/runtime/build.py +83 -17
  68. triton/runtime/cache.py +61 -47
  69. triton/runtime/driver.py +25 -47
  70. triton/runtime/interpreter.py +95 -50
  71. triton/runtime/jit.py +445 -248
  72. triton/runtime/tcc/include/_mingw.h +8 -10
  73. triton/runtime/tcc/include/assert.h +5 -0
  74. triton/runtime/tcc/include/errno.h +1 -1
  75. triton/runtime/tcc/include/float.h +21 -3
  76. triton/runtime/tcc/include/iso646.h +36 -0
  77. triton/runtime/tcc/include/limits.h +5 -0
  78. triton/runtime/tcc/include/malloc.h +2 -2
  79. triton/runtime/tcc/include/math.h +21 -261
  80. triton/runtime/tcc/include/stdalign.h +16 -0
  81. triton/runtime/tcc/include/stdarg.h +5 -70
  82. triton/runtime/tcc/include/stdatomic.h +171 -0
  83. triton/runtime/tcc/include/stddef.h +7 -19
  84. triton/runtime/tcc/include/stdlib.h +15 -4
  85. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  86. triton/runtime/tcc/include/sys/stat.h +2 -2
  87. triton/runtime/tcc/include/sys/types.h +5 -0
  88. triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
  89. triton/runtime/tcc/include/tccdefs.h +342 -0
  90. triton/runtime/tcc/include/tgmath.h +89 -0
  91. triton/runtime/tcc/include/uchar.h +33 -0
  92. triton/runtime/tcc/include/unistd.h +1 -0
  93. triton/runtime/tcc/include/winapi/qos.h +72 -0
  94. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  95. triton/runtime/tcc/include/winapi/winbase.h +9 -2
  96. triton/runtime/tcc/include/winapi/wincon.h +8 -0
  97. triton/runtime/tcc/include/winapi/windows.h +1 -1
  98. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  99. triton/runtime/tcc/include/winapi/winnt.h +9 -7
  100. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  101. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  102. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  103. triton/runtime/tcc/lib/libtcc1.a +0 -0
  104. triton/runtime/tcc/lib/python314.def +1800 -0
  105. triton/runtime/tcc/lib/python314t.def +1809 -0
  106. triton/runtime/tcc/libtcc.dll +0 -0
  107. triton/runtime/tcc/tcc.exe +0 -0
  108. triton/testing.py +16 -12
  109. triton/tools/compile.py +62 -14
  110. triton/tools/disasm.py +3 -4
  111. triton/tools/extra/cuda/compile.c +1 -0
  112. triton/tools/extra/hip/compile.cpp +66 -0
  113. triton/tools/extra/hip/compile.h +13 -0
  114. triton/tools/ragged_tma.py +92 -0
  115. triton/tools/tensor_descriptor.py +34 -0
  116. triton/windows_utils.py +52 -81
  117. {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +8 -4
  118. triton_windows-3.5.0.post21.dist-info/RECORD +217 -0
  119. triton_windows-3.5.0.post21.dist-info/entry_points.txt +3 -0
  120. triton_windows-3.5.0.post21.dist-info/licenses/LICENSE +23 -0
  121. triton_windows-3.5.0.post21.dist-info/top_level.txt +1 -0
  122. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
  123. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
  124. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
  125. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
  126. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
  127. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
  128. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
  129. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
  130. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
  131. triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
  132. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
  133. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
  134. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
  135. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
  136. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
  137. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
  138. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
  139. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
  140. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
  141. triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
  142. triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
  143. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
  144. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
  145. triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
  146. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
  147. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
  148. triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
  149. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
  150. triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
  151. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
  152. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
  153. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
  154. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
  155. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
  156. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
  157. triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
  158. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
  159. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
  160. triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
  161. triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
  162. triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
  163. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
  164. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
  165. triton/backends/amd/include/hip/channel_descriptor.h +0 -39
  166. triton/backends/amd/include/hip/device_functions.h +0 -38
  167. triton/backends/amd/include/hip/driver_types.h +0 -468
  168. triton/backends/amd/include/hip/hip_bf16.h +0 -36
  169. triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
  170. triton/backends/amd/include/hip/hip_common.h +0 -100
  171. triton/backends/amd/include/hip/hip_complex.h +0 -38
  172. triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
  173. triton/backends/amd/include/hip/hip_deprecated.h +0 -95
  174. triton/backends/amd/include/hip/hip_ext.h +0 -161
  175. triton/backends/amd/include/hip/hip_fp16.h +0 -36
  176. triton/backends/amd/include/hip/hip_fp8.h +0 -33
  177. triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
  178. triton/backends/amd/include/hip/hip_hcc.h +0 -24
  179. triton/backends/amd/include/hip/hip_math_constants.h +0 -36
  180. triton/backends/amd/include/hip/hip_profile.h +0 -27
  181. triton/backends/amd/include/hip/hip_runtime.h +0 -75
  182. triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
  183. triton/backends/amd/include/hip/hip_texture_types.h +0 -29
  184. triton/backends/amd/include/hip/hip_vector_types.h +0 -41
  185. triton/backends/amd/include/hip/hip_version.h +0 -17
  186. triton/backends/amd/include/hip/hiprtc.h +0 -421
  187. triton/backends/amd/include/hip/library_types.h +0 -78
  188. triton/backends/amd/include/hip/math_functions.h +0 -42
  189. triton/backends/amd/include/hip/surface_types.h +0 -63
  190. triton/backends/amd/include/hip/texture_types.h +0 -194
  191. triton/backends/amd/include/hsa/Brig.h +0 -1131
  192. triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
  193. triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
  194. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
  195. triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
  196. triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
  197. triton/backends/amd/include/hsa/hsa.h +0 -5738
  198. triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
  199. triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
  200. triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
  201. triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
  202. triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
  203. triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
  204. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
  205. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
  206. triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
  207. triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
  208. triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
  209. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
  210. triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
  211. triton/backends/amd/include/roctracer/roctracer.h +0 -779
  212. triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
  213. triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
  214. triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
  215. triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
  216. triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
  217. triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
  218. triton/backends/amd/include/roctracer/roctx.h +0 -229
  219. triton/language/_utils.py +0 -21
  220. triton/language/extra/cuda/_experimental_tma.py +0 -106
  221. triton/runtime/tcc/lib/libtcc1-64.a +0 -0
  222. triton/tools/experimental_descriptor.py +0 -32
  223. triton_windows-3.3.1.post19.dist-info/RECORD +0 -260
  224. triton_windows-3.3.1.post19.dist-info/top_level.txt +0 -14
  225. {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
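
The inline diff below shows the changes to triton/backends/nvidia/compiler.py (entry 14 in the list above).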
@@ -1,5 +1,6 @@
-from triton.backends.compiler import BaseBackend, GPUTarget
+from triton.backends.compiler import BaseBackend, GPUTarget, Language
 from triton._C.libtriton import ir, passes, llvm, nvidia
+from triton import knobs
 from triton.runtime.errors import PTXASError
 
 from dataclasses import dataclass
@@ -13,7 +14,6 @@ import signal
 import os
 import subprocess
 from pathlib import Path
-import sysconfig
 
 
 def min_dot_size(target: GPUTarget):
@@ -22,54 +22,25 @@ def min_dot_size(target: GPUTarget):
         lhs_bitwidth = lhs_type.scalar.primitive_bitwidth
         rhs_bitwidth = rhs_type.scalar.primitive_bitwidth
         assert lhs_bitwidth == rhs_bitwidth, "lhs and rhs bitwidth must be the same"
+        # For small M/N the input we can still use tensorcores with padding.
         if lhs_bitwidth == 8:
-            return (16, 16, 32)
+            return (1, 1, 32)
         else:
-            return (16, 16, 16)
+            return (1, 1, 16)
 
     return check_dot_compatibility
 
 
-@functools.lru_cache()
-def _path_to_binary(binary: str):
-    paths = [
-        os.environ.get(f"TRITON_{binary.upper()}_PATH", ""),
-    ]
-    binary += sysconfig.get_config_var("EXE")
-    paths += [
-        os.path.join(os.path.dirname(__file__), "bin", binary),
-    ]
-    if os.name == "nt":
-        from triton.windows_utils import find_cuda
-        cuda_bin_path, _, _ = find_cuda()
-        if cuda_bin_path:
-            paths += [os.path.join(cuda_bin_path, binary)]
-
-    for path in paths:
-        if os.path.exists(path) and os.path.isfile(path):
-            result = subprocess.check_output([path, "--version"], stderr=subprocess.STDOUT)
-            if result is not None:
-                version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE)
-                if version is not None:
-                    return path, version.group(1)
-    raise RuntimeError(f"Cannot find {binary}")
-
-
-@functools.lru_cache()
-def get_ptxas(arch: int):
-    if os.name == "nt":
-        name = "ptxas"
-    else:
-        name = "ptxas-blackwell" if arch >= 100 else "ptxas"
-    return _path_to_binary(name)
+def get_ptxas() -> knobs.NvidiaTool:
+    return knobs.nvidia.ptxas
 
 
 @functools.lru_cache()
-def get_ptxas_version(arch: int):
-    mock_ver = os.environ.get('TRITON_MOCK_PTX_VERSION')
+def get_ptxas_version():
+    mock_ver = knobs.nvidia.mock_ptx_version
     if mock_ver is not None:
         return mock_ver  # This is not really a version of ptxas, but it is good enough for testing
-    version = subprocess.check_output([get_ptxas(arch)[0], "--version"]).decode("utf-8")
+    version = subprocess.check_output([get_ptxas().path, "--version"]).decode("utf-8")
     return version
 
 
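For orientation, a minimal sketch of how the new knobs-based ptxas lookup is consumed in the hunks that follow (the .path and .version attributes appear below in this diff; the printed values are purely illustrative):

    from triton import knobs

    tool = knobs.nvidia.ptxas       # replaces the removed _path_to_binary("ptxas") search
    print(tool.path, tool.version)  # e.g. "...\\bin\\ptxas.exe", "12.8" (illustrative values)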
@@ -89,13 +60,18 @@ def ptx_get_version(cuda_version) -> int:
         return 70 + minor
     if major == 10:
         return 63 + minor
+
+    if major >= 13:
+        base_ptx = 90
+        return base_ptx + (major - 13) * 10 + minor
+
     raise RuntimeError("Triton only support CUDA 10.0 or higher, but got CUDA version: " + cuda_version)
 
 
 def get_ptx_version_from_options(options, arch: int):
     ptx_version = options.ptx_version
     if ptx_version is None:
-        _, cuda_version = get_ptxas(arch)
+        cuda_version = get_ptxas().version
         ptx_version = ptx_get_version(cuda_version)
     return ptx_version
 
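A small worked sketch of the CUDA 13+ branch added above (the helper name is invented for illustration; the arithmetic mirrors the new lines):

    def ptx_isa_for_cuda13(major: int, minor: int) -> int:
        # CUDA 13.x and newer start at PTX ISA 9.0 (encoded as 90)
        assert major >= 13
        return 90 + (major - 13) * 10 + minor

    assert ptx_isa_for_cuda13(13, 0) == 90  # CUDA 13.0 -> PTX ISA 9.0
    assert ptx_isa_for_cuda13(13, 1) == 91  # CUDA 13.1 -> PTX ISA 9.1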
@@ -141,19 +117,19 @@ class CUDAOptions:
     num_warps: int = 4
     num_ctas: int = 1
     num_stages: int = 3
-    num_buffers_warp_spec: int = 0
-    num_consumer_groups: int = 0
-    reg_dec_producer: int = 0
-    reg_inc_consumer: int = 0
+    warp_size: int = 32
     # maxnreg corresponds to the ptx parameter .maxnreg, which controls the
     # maximum number of 32-bit registers used by one thread.
     maxnreg: Optional[int] = None
     cluster_dims: tuple = (1, 1, 1)
     ptx_version: int = None
+    ptx_options: str = None
+    ir_override: Optional[str] = None  # filename of a user-defined IR (*.{ttir|ttgir|llir|ptx})
     enable_fp_fusion: bool = True
     launch_cooperative_grid: bool = False
-    supported_fp8_dtypes: Tuple[str] = ("fp8e5", "fp8e4b15")
-    deprecated_fp8_dtypes: Tuple[str] = ()
+    launch_pdl: bool = False
+    supported_fp8_dtypes: Tuple[str] = ("fp8e4nv", "fp8e5", "fp8e4b15")
+    deprecated_fp8_dot_operand_dtypes: Tuple[str] = ()
     default_dot_input_precision: str = "tf32"
     allowed_dot_input_precisions: Tuple[str] = ("tf32", "tf32x3", "ieee")
     max_num_imprecise_acc_default: bool = None
@@ -162,12 +138,14 @@ class CUDAOptions:
     backend_name: str = 'cuda'
     sanitize_overflow: bool = True
     arch: str = None
+    instrumentation_mode: str = ""
 
     def __post_init__(self):
         default_libdir = Path(__file__).parent / 'lib'
         extern_libs = {} if self.extern_libs is None else dict(self.extern_libs)
         if not extern_libs.get('libdevice', None):
-            extern_libs['libdevice'] = os.getenv("TRITON_LIBDEVICE_PATH", str(default_libdir / 'libdevice.10.bc'))
+            extern_libs['libdevice'] = knobs.nvidia.libdevice_path or str(default_libdir / 'libdevice.10.bc')
+
         object.__setattr__(self, 'extern_libs', tuple(extern_libs.items()))
         assert self.num_warps > 0 and (self.num_warps & (self.num_warps - 1)) == 0, \
             "num_warps must be a power of 2"
@@ -180,6 +158,7 @@ class CUDAOptions:
 
 
 class CUDABackend(BaseBackend):
+    instrumentation = None
 
     @staticmethod
     def supports_target(target: GPUTarget):
@@ -192,27 +171,34 @@ class CUDABackend(BaseBackend):
             raise ValueError(f"TRITON_OVERRIDE_ARCH must have the form {pattern}")
         return int(match.group(1))
 
+    def get_target_name(self, options) -> str:
+        capability = self._parse_arch(options.arch)
+        return f"cuda:{capability}"
+
     def __init__(self, target: GPUTarget) -> None:
         super().__init__(target)
         self.binary_ext = "cubin"
 
     def parse_options(self, opts) -> Any:
-        args = {'arch': os.getenv("TRITON_OVERRIDE_ARCH", f"sm{self.target.arch}")}
+        args = {'arch': knobs.runtime.override_arch or f"sm{self.target.arch}"}
         args.update({k: opts[k] for k in CUDAOptions.__dataclass_fields__.keys() if k in opts if opts[k] is not None})
         capability = int(self._parse_arch(args["arch"]))
 
+        if args.get("num_ctas", 1) > 1 and capability < 90:
+            raise ValueError((f"num_ctas > 1 requires NVIDIA SM90+ (Hopper). "
+                              f"Current target is sm_{capability}. This configuration will fail. "
+                              f"Please set num_ctas=1 or target an SM90+ GPU."))
+
         if "supported_fp8_dtypes" not in args:
             supported_fp8_dtypes = set(CUDAOptions.supported_fp8_dtypes)
-            if capability >= 89:
-                supported_fp8_dtypes.add("fp8e4nv")
             args["supported_fp8_dtypes"] = tuple(sorted(supported_fp8_dtypes))
 
-        if "deprecated_fp8_dtypes" not in args:
+        if "deprecated_fp8_dot_operand_dtypes" not in args:
             if capability >= 90:
-                args["deprecated_fp8_dtypes"] = ("fp8e4b15", )
+                args["deprecated_fp8_dot_operand_dtypes"] = ("fp8e4b15", )
 
         if "enable_fp_fusion" not in args:
-            args["enable_fp_fusion"] = os.getenv("TRITON_DEFAULT_FP_FUSION", "1") == "1"
+            args["enable_fp_fusion"] = knobs.language.default_fp_fusion
 
         args["max_num_imprecise_acc_default"] = 2**30 if capability == 90 else 0
 
@@ -244,13 +230,17 @@ class CUDABackend(BaseBackend):
 
     def load_dialects(self, ctx):
         nvidia.load_dialects(ctx)
+        if CUDABackend.instrumentation:
+            CUDABackend.instrumentation.load_dialects(ctx)
 
     @staticmethod
-    def make_ttir(mod, metadata, opt):
+    def make_ttir(mod, metadata, opt, capability):
         pm = ir.pass_manager(mod.context)
         pm.enable_debug()
         passes.common.add_inliner(pm)
         passes.ttir.add_rewrite_tensor_pointer(pm)
+        if capability // 10 < 9:
+            passes.ttir.add_rewrite_tensor_descriptor_to_pointer(pm)
         passes.common.add_canonicalizer(pm)
         passes.ttir.add_combine(pm)
         passes.ttir.add_reorder_broadcast(pm)
@@ -262,6 +252,10 @@ class CUDABackend(BaseBackend):
262
252
 
263
253
  @staticmethod
264
254
  def make_ttgir(mod, metadata, opt, capability):
255
+ # Set maxnreg on all kernels, if it was provided.
256
+ if opt.maxnreg is not None:
257
+ mod.set_attr("ttg.maxnreg", ir.builder(mod.context).get_int32_attr(opt.maxnreg))
258
+
265
259
  cluster_info = nvidia.ClusterInfo()
266
260
  if opt.cluster_dims is not None:
267
261
  cluster_info.clusterDimX = opt.cluster_dims[0]
@@ -281,56 +275,75 @@ class CUDABackend(BaseBackend):
281
275
  passes.ttgpuir.add_accelerate_matmul(pm)
282
276
  passes.ttgpuir.add_remove_layout_conversions(pm)
283
277
  passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
284
- passes.common.add_cse(pm)
278
+ nvidia.passes.ttnvgpuir.add_optimize_descriptor_encoding(pm)
279
+ passes.ttir.add_loop_aware_cse(pm)
285
280
  if capability // 10 in [8, 9]:
286
281
  passes.ttgpuir.add_fuse_nested_loops(pm)
287
282
  passes.common.add_canonicalizer(pm)
288
- passes.common.add_licm(pm)
289
- passes.ttgpuir.add_optimize_accumulator_init(pm)
283
+ passes.ttir.add_triton_licm(pm)
290
284
  passes.common.add_canonicalizer(pm)
291
285
  passes.ttgpuir.add_combine_tensor_select_and_if(pm)
292
- passes.ttgpuir.add_ws_task_partition(pm, opt.num_consumer_groups)
293
- passes.ttgpuir.add_taskid_propagate(pm, opt.num_consumer_groups)
294
- passes.ttgpuir.add_ws_data_partition(pm, opt.num_consumer_groups)
295
- passes.ttgpuir.add_ws_code_partition(pm, opt.num_buffers_warp_spec, opt.num_consumer_groups,
296
- opt.reg_dec_producer, opt.reg_inc_consumer)
286
+ nvidia.passes.hopper.add_hopper_warpspec(pm, opt.num_stages, dump_enabled)
287
+ passes.ttgpuir.add_assign_latencies(pm, opt.num_stages)
288
+ passes.ttgpuir.add_schedule_loops(pm)
297
289
  passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled)
298
- passes.ttgpuir.add_ping_pong_sync(pm, opt.num_consumer_groups)
299
- passes.ttgpuir.add_ws_lowering(pm, opt.num_consumer_groups)
300
290
  elif capability // 10 >= 10:
301
291
  passes.ttgpuir.add_fuse_nested_loops(pm)
302
292
  passes.common.add_canonicalizer(pm)
303
- passes.common.add_licm(pm)
293
+ passes.ttir.add_triton_licm(pm)
304
294
  passes.ttgpuir.add_optimize_accumulator_init(pm)
305
- passes.ttgpuir.add_ws_task_partition(pm, opt.num_consumer_groups)
306
- passes.ttgpuir.add_taskid_propagate(pm, opt.num_consumer_groups)
307
- passes.ttgpuir.add_ws_data_partition(pm, opt.num_consumer_groups)
308
- passes.ttgpuir.add_ws_code_partition(pm, opt.num_buffers_warp_spec, opt.num_consumer_groups,
309
- opt.reg_dec_producer, opt.reg_inc_consumer)
295
+ passes.ttgpuir.add_hoist_tmem_alloc(pm, False)
296
+ nvidia.passes.ttnvgpuir.add_promote_lhs_to_tmem(pm)
297
+ passes.ttgpuir.add_assign_latencies(pm, opt.num_stages)
298
+ passes.ttgpuir.add_schedule_loops(pm)
299
+ passes.ttgpuir.add_warp_specialize(pm, opt.num_stages)
310
300
  passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled)
311
301
  passes.ttgpuir.add_combine_tensor_select_and_if(pm)
312
- nvidia.passes.ttnvgpuir.add_promote_lhs_to_tmem(pm)
313
- nvidia.passes.ttnvgpuir.add_keep_acc_in_tmem(pm)
314
- passes.ttgpuir.add_ws_lowering(pm, opt.num_consumer_groups)
315
- passes.common.add_canonicalizer(pm)
302
+ # hoist again and allow hoisting out of if statements
303
+ passes.ttgpuir.add_hoist_tmem_alloc(pm, True)
304
+ nvidia.passes.ttnvgpuir.add_remove_tmem_tokens(pm)
316
305
  else:
317
- passes.common.add_licm(pm)
306
+ passes.ttir.add_triton_licm(pm)
307
+ passes.common.add_canonicalizer(pm)
308
+ passes.ttir.add_loop_aware_cse(pm)
318
309
  passes.ttgpuir.add_prefetch(pm)
319
310
  passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
320
311
  passes.ttgpuir.add_coalesce_async_copy(pm)
312
+ nvidia.passes.ttnvgpuir.add_optimize_tmem_layouts(pm)
321
313
  passes.ttgpuir.add_remove_layout_conversions(pm)
314
+ nvidia.passes.ttnvgpuir.add_interleave_tmem(pm)
322
315
  passes.ttgpuir.add_reduce_data_duplication(pm)
323
316
  passes.ttgpuir.add_reorder_instructions(pm)
324
- passes.common.add_cse(pm)
317
+ passes.ttir.add_loop_aware_cse(pm)
325
318
  passes.common.add_symbol_dce(pm)
326
319
  if capability // 10 >= 9:
327
- nvidia.passes.ttnvgpuir.add_fence_insertion(pm)
328
320
  nvidia.passes.ttnvgpuir.add_tma_lowering(pm)
321
+ nvidia.passes.ttnvgpuir.add_fence_insertion(pm, capability)
322
+ nvidia.passes.ttnvgpuir.add_lower_mma(pm)
323
+ passes.common.add_sccp(pm)
324
+ passes.common.add_cse(pm)
329
325
  passes.common.add_canonicalizer(pm)
330
- if capability // 10 >= 9:
331
- passes.ttgpuir.add_ws_canonicalization(pm, opt.num_consumer_groups)
326
+
332
327
  pm.run(mod)
333
328
  metadata["cluster_dims"] = (cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ)
329
+ tensordesc_meta = mod.get_tensordesc_metadata()
330
+ metadata["tensordesc_meta"] = tensordesc_meta
331
+ return mod
332
+
333
+ def gluon_to_ttgir(self, src, metadata, options, capability):
334
+ mod = src
335
+ pm = ir.pass_manager(mod.context)
336
+ pm.enable_debug()
337
+
338
+ passes.gluon.add_inliner(pm)
339
+ passes.gluon.add_resolve_auto_encodings(pm)
340
+ passes.common.add_sccp(pm)
341
+ passes.ttir.add_loop_aware_cse(pm)
342
+ passes.gluon.add_canonicalizer(pm)
343
+ passes.ttgpuir.add_combine_tensor_select_and_if(pm)
344
+
345
+ pm.run(mod)
346
+ metadata["tensordesc_meta"] = mod.get_tensordesc_metadata()
334
347
  return mod
335
348
 
336
349
  def make_llir(self, src, metadata, options, capability):
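As a summary sketch (not code from the package), the two lowering entry points above pair with the stage lists wired up in add_stages at the end of this file; Gluon sources skip the ttir stage entirely:

    stages_by_language = {
        "TRITON": ["ttir", "ttgir", "llir", "ptx", "cubin"],
        "GLUON": ["ttgir", "llir", "ptx", "cubin"],
    }
    print(stages_by_language["GLUON"])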
@@ -341,13 +354,19 @@ class CUDABackend(BaseBackend):
         pm = ir.pass_manager(mod.context)
         pm.enable_debug()
 
-        nvidia.passes.ttnvgpuir.add_lower_mma(pm)
         passes.ttgpuir.add_combine_tensor_select_and_if(pm)
         passes.ttgpuir.add_allocate_warp_groups(pm)
         passes.convert.add_scf_to_cf(pm)
-        passes.ttgpuir.add_allocate_shared_memory(pm)
+        nvidia.passes.ttgpuir.add_allocate_shared_memory_nv(pm, capability, ptx_version)
         nvidia.passes.ttnvgpuir.add_allocate_tensor_memory(pm)
+        if knobs.compilation.enable_experimental_consan:
+            # Call ConcurrencySanitizerPass here, before allocating global scratch memory but after allocating tensor and shared
+            passes.ttgpuir.add_concurrency_sanitizer(pm)
         passes.ttgpuir.add_allocate_global_scratch_memory(pm)
+        nvidia.passes.ttnvgpuir.add_proxy_fence_insertion(pm, capability)
+        # instrumentation point here so we can override IRs above (e.g., ttir and ttgir)
+        if CUDABackend.instrumentation:
+            CUDABackend.instrumentation.patch("ttgpuir_to_llvmir", pm, mod.context)
         nvidia.passes.ttgpuir.add_to_llvmir(pm, capability, ptx_version)
         passes.common.add_canonicalizer(pm)
         passes.common.add_cse(pm)
@@ -356,29 +375,28 @@ class CUDABackend(BaseBackend):
         passes.common.add_canonicalizer(pm)
         passes.common.add_cse(pm)
         passes.common.add_symbol_dce(pm)
-        if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0":
+        passes.convert.add_nvvm_to_llvm(pm)
+        if not knobs.compilation.disable_line_info:
             passes.llvmir.add_di_scope(pm)
+        if CUDABackend.instrumentation:
+            CUDABackend.instrumentation.patch("llvmir_to_llvm", pm, mod.context)
+
         pm.run(mod)
         # LLVM-IR (MLIR) -> LLVM-IR (LLVM)
         llvm.init_targets()
         context = llvm.context()
-        if os.environ.get("TRITON_ENABLE_ASAN", "0") == "1":
+        if knobs.compilation.enable_asan:
             raise RuntimeError(
                 "Address Sanitizer Error: Address sanitizer is currently only supported on the AMD backend")
         llvm_mod = llvm.to_module(mod, context)
         proc = sm_arch_from_capability(capability)
         features = get_features(options, self.target.arch)
         triple = 'nvptx64-nvidia-cuda'
+        nvidia.set_short_ptr()
         llvm.attach_datalayout(llvm_mod, triple, proc, features)
         nvidia.set_nvvm_reflect_ftz(llvm_mod)
 
-        # Set maxnreg on all kernels, if it was provided.
-        if options.maxnreg is not None:
-            for k in llvm_mod.get_functions():
-                if not k.is_declaration() and k.is_external_linkage():
-                    k.set_nvvm_maxnreg(options.maxnreg)
-
-        if options.extern_libs:
+        if options.extern_libs and nvidia.has_extern_deps(llvm_mod):
             paths = [path for (name, path) in options.extern_libs]
             llvm.link_extern_libs(llvm_mod, paths)
 
@@ -393,6 +411,8 @@ class CUDABackend(BaseBackend):
         metadata["tmem_size"] = src.get_int_attr("ttg.tensor_memory_size")
         metadata["global_scratch_size"] = src.get_int_attr("ttg.global_scratch_memory_size")
         metadata["global_scratch_align"] = src.get_int_attr("ttg.global_scratch_memory_alignment")
+        metadata["profile_scratch_size"] = src.get_int_attr("ttg.profile_scratch_memory_size") or 0
+        metadata["profile_scratch_align"] = src.get_int_attr("ttg.profile_scratch_memory_alignment") or 1
         ret = str(llvm_mod)
         del llvm_mod
         del context
@@ -404,7 +424,7 @@ class CUDABackend(BaseBackend):
         triple = 'nvptx64-nvidia-cuda'
         proc = sm_arch_from_capability(capability)
         features = get_features(opt, self.target.arch)
-        ret = llvm.translate_to_asm(src, triple, proc, features, ['nvptx-short-ptr'], opt.enable_fp_fusion, False)
+        ret = llvm.translate_to_asm(src, triple, proc, features, [], opt.enable_fp_fusion, False)
         # Find kernel names (there should only be one)
         names = re.findall(r".visible .entry ([a-zA-Z_][a-zA-Z0-9_]*)", ret)
         assert len(names) == 1
@@ -415,29 +435,52 @@ class CUDABackend(BaseBackend):
         ret = re.sub(r'\.target sm_\d+', f'.target sm_{capability}', ret, flags=re.MULTILINE)
         # Remove the debug flag that prevents ptxas from optimizing the code
         ret = re.sub(r",\s*debug|debug,\s*", "", ret)
-        if os.environ.get("NVPTX_ENABLE_DUMP", "0") == "1":
+        if knobs.nvidia.dump_nvptx:
             print("// -----// NVPTX Dump //----- //")
             print(ret)
         return ret
 
     def make_cubin(self, src, metadata, opt, capability):
-        ptxas, _ = get_ptxas(self.target.arch)
+        ptxas = get_ptxas().path
         with tempfile.NamedTemporaryFile(delete=False, mode='w', suffix='.ptx') as fsrc, \
                 tempfile.NamedTemporaryFile(delete=False, mode='r', suffix='.log') as flog:
             fsrc.write(src)
             fsrc.flush()
             fbin = fsrc.name + '.o'
 
-            line_info = ["-lineinfo", "-suppress-debug-info"] if os.environ.get("TRITON_DISABLE_LINE_INFO",
-                                                                                "0") == "1" else ["-lineinfo"]
-            fmad = [] if opt.enable_fp_fusion else ['--fmad=false']
+            debug_info = []
+            if knobs.compilation.disable_line_info:
+                # This option is ignored if used without -lineinfo
+                debug_info += ["-lineinfo", "-suppress-debug-info"]
+            elif knobs.nvidia.disable_ptxas_opt:
+                # Synthesize complete debug info
+                debug_info += ["-g"]
+            else:
+                # Only emit line info
+                debug_info += ["-lineinfo"]
+
+            fmad = [] if opt.enable_fp_fusion else ["--fmad=false"]
             arch = sm_arch_from_capability(capability)
-            opt_level = ['--opt-level', '0'] if os.environ.get("DISABLE_PTXAS_OPT", "0") == "1" else []
-            ptxas_cmd = [ptxas, *line_info, *fmad, '-v', *opt_level, f'--gpu-name={arch}', fsrc.name, '-o', fbin]
+
+            # Disable ptxas optimizations if requested
+            disable_opt = ['--opt-level', '0'] if knobs.nvidia.disable_ptxas_opt else []
+
+            # Accept more ptxas options if provided
+            ptx_extra_options = opt.ptx_options.split(" ") if opt.ptx_options else []
+
+            ptxas_cmd = [
+                ptxas, *debug_info, *fmad, '-v', *disable_opt, *ptx_extra_options, f'--gpu-name={arch}', fsrc.name,
+                '-o', fbin
+            ]
             try:
                 # close_fds=True on Windows and False on Linux, see https://github.com/triton-lang/triton/pull/4357
                 # On Windows, both stdout and stderr need to be redirected to flog
-                subprocess.run(ptxas_cmd, check=True, close_fds=True if os.name == 'nt' else False, stdout=flog, stderr=flog)
+                subprocess.run(ptxas_cmd, check=True, close_fds=True if os.name == 'nt' else False, stdout=flog,
+                               stderr=flog)
+                if knobs.nvidia.dump_ptxas_log:
+                    with open(flog.name) as log_file:
+                        print(log_file.read())
+
             except subprocess.CalledProcessError as e:
                 with open(flog.name) as log_file:
                     log = log_file.read()
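Illustrative only: with line info enabled, ptxas optimizations left on, and opt.ptx_options set to "-O3" (an assumed value), the command list built above expands to roughly:

    ptxas_cmd = [
        "ptxas", "-lineinfo", "-v", "-O3",
        "--gpu-name=sm_90", "kernel.ptx", "-o", "kernel.ptx.o",  # arch and paths are placeholders
    ]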
@@ -449,9 +492,20 @@ class CUDABackend(BaseBackend):
                 else:
                     error = f'`ptxas` failed with error code {e.returncode}'
 
-                raise PTXASError(f"{error}\n"
-                                 f"`ptxas` stderr:\n{log}\n"
-                                 f'Repro command: {" ".join(ptxas_cmd)}\n')
+                error = (f"{error}\n"
+                         f"`ptxas` stderr:\n{log}\n"
+                         f'Repro command: {" ".join(ptxas_cmd)}\n')
+
+                print(f"""
+
+                ================================================================
+                {error}
+
+                {src}
+                ================================================================
+                please share the reproducer above with Triton project.
+                """)
+                raise PTXASError(error)
 
             with open(fbin, 'rb') as f:
                 cubin = f.read()
@@ -462,15 +516,18 @@ class CUDABackend(BaseBackend):
         try_remove(flog.name)
         return cubin
 
-    def add_stages(self, stages, options):
+    def add_stages(self, stages, options, language):
         capability = self._parse_arch(options.arch)
-        stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options)
-        stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, capability)
+        if language == Language.TRITON:
+            stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options, capability)
+            stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, capability)
+        elif language == Language.GLUON:
+            stages["ttgir"] = lambda src, metadata: self.gluon_to_ttgir(src, metadata, options, capability)
         stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, capability)
         stages["ptx"] = lambda src, metadata: self.make_ptx(src, metadata, options, self.target.arch)
         stages["cubin"] = lambda src, metadata: self.make_cubin(src, metadata, options, self.target.arch)
 
     @functools.lru_cache()
     def hash(self):
-        version = get_ptxas_version(self.target.arch)
+        version = get_ptxas_version()
         return f'{version}-{self.target.arch}'
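
A minimal sketch of the environment-variable-to-knobs migration that recurs throughout this file; the pairings are read off the removed/added lines above, and whether each knob still honours its old environment variable is not shown in this diff:

    from triton import knobs

    print(knobs.compilation.disable_line_info)  # was os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "1"
    print(knobs.nvidia.disable_ptxas_opt)       # was os.environ.get("DISABLE_PTXAS_OPT", "0") == "1"
    print(knobs.nvidia.dump_nvptx)              # was os.environ.get("NVPTX_ENABLE_DUMP", "0") == "1"
    print(knobs.language.default_fp_fusion)     # was os.environ.get("TRITON_DEFAULT_FP_FUSION", "1") == "1"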