triton-windows 3.3.1.post19__cp312-cp312-win_amd64.whl → 3.5.0.post21__cp312-cp312-win_amd64.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic. See the registry's advisory page for this release for more details.

Files changed (225)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +11 -2
  3. triton/_filecheck.py +97 -0
  4. triton/_internal_testing.py +95 -18
  5. triton/_utils.py +112 -21
  6. triton/backends/__init__.py +20 -23
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +161 -119
  9. triton/backends/amd/driver.c +118 -46
  10. triton/backends/amd/driver.py +274 -96
  11. triton/backends/compiler.py +7 -21
  12. triton/backends/driver.py +13 -0
  13. triton/backends/nvidia/bin/ptxas.exe +0 -0
  14. triton/backends/nvidia/compiler.py +163 -106
  15. triton/backends/nvidia/driver.c +166 -101
  16. triton/backends/nvidia/driver.py +384 -202
  17. triton/compiler/__init__.py +5 -2
  18. triton/compiler/code_generator.py +439 -231
  19. triton/compiler/compiler.py +152 -84
  20. triton/experimental/__init__.py +0 -0
  21. triton/experimental/gluon/__init__.py +5 -0
  22. triton/experimental/gluon/_compiler.py +0 -0
  23. triton/experimental/gluon/_runtime.py +102 -0
  24. triton/experimental/gluon/language/__init__.py +119 -0
  25. triton/experimental/gluon/language/_core.py +490 -0
  26. triton/experimental/gluon/language/_layouts.py +583 -0
  27. triton/experimental/gluon/language/_math.py +20 -0
  28. triton/experimental/gluon/language/_semantic.py +380 -0
  29. triton/experimental/gluon/language/_standard.py +80 -0
  30. triton/experimental/gluon/language/amd/__init__.py +4 -0
  31. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  32. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  33. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  34. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  35. triton/experimental/gluon/language/extra/__init__.py +3 -0
  36. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  37. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  38. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  39. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  40. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
  41. triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
  42. triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
  43. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
  44. triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
  45. triton/experimental/gluon/nvidia/__init__.py +4 -0
  46. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  47. triton/experimental/gluon/nvidia/hopper.py +45 -0
  48. triton/knobs.py +546 -0
  49. triton/language/__init__.py +50 -19
  50. triton/language/core.py +909 -572
  51. triton/language/extra/cuda/__init__.py +10 -7
  52. triton/language/extra/cuda/gdc.py +42 -0
  53. triton/language/extra/cuda/libdevice.py +394 -394
  54. triton/language/extra/cuda/utils.py +21 -21
  55. triton/language/extra/hip/__init__.py +3 -1
  56. triton/language/extra/hip/libdevice.py +120 -104
  57. triton/language/extra/hip/utils.py +35 -0
  58. triton/language/extra/libdevice.py +4 -0
  59. triton/language/math.py +65 -66
  60. triton/language/random.py +12 -2
  61. triton/language/semantic.py +1757 -1768
  62. triton/language/standard.py +127 -62
  63. triton/language/target_info.py +54 -0
  64. triton/runtime/_allocation.py +15 -3
  65. triton/runtime/_async_compile.py +55 -0
  66. triton/runtime/autotuner.py +117 -60
  67. triton/runtime/build.py +83 -17
  68. triton/runtime/cache.py +61 -47
  69. triton/runtime/driver.py +25 -47
  70. triton/runtime/interpreter.py +95 -50
  71. triton/runtime/jit.py +445 -248
  72. triton/runtime/tcc/include/_mingw.h +8 -10
  73. triton/runtime/tcc/include/assert.h +5 -0
  74. triton/runtime/tcc/include/errno.h +1 -1
  75. triton/runtime/tcc/include/float.h +21 -3
  76. triton/runtime/tcc/include/iso646.h +36 -0
  77. triton/runtime/tcc/include/limits.h +5 -0
  78. triton/runtime/tcc/include/malloc.h +2 -2
  79. triton/runtime/tcc/include/math.h +21 -261
  80. triton/runtime/tcc/include/stdalign.h +16 -0
  81. triton/runtime/tcc/include/stdarg.h +5 -70
  82. triton/runtime/tcc/include/stdatomic.h +171 -0
  83. triton/runtime/tcc/include/stddef.h +7 -19
  84. triton/runtime/tcc/include/stdlib.h +15 -4
  85. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  86. triton/runtime/tcc/include/sys/stat.h +2 -2
  87. triton/runtime/tcc/include/sys/types.h +5 -0
  88. triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
  89. triton/runtime/tcc/include/tccdefs.h +342 -0
  90. triton/runtime/tcc/include/tgmath.h +89 -0
  91. triton/runtime/tcc/include/uchar.h +33 -0
  92. triton/runtime/tcc/include/unistd.h +1 -0
  93. triton/runtime/tcc/include/winapi/qos.h +72 -0
  94. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  95. triton/runtime/tcc/include/winapi/winbase.h +9 -2
  96. triton/runtime/tcc/include/winapi/wincon.h +8 -0
  97. triton/runtime/tcc/include/winapi/windows.h +1 -1
  98. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  99. triton/runtime/tcc/include/winapi/winnt.h +9 -7
  100. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  101. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  102. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  103. triton/runtime/tcc/lib/libtcc1.a +0 -0
  104. triton/runtime/tcc/lib/python314.def +1800 -0
  105. triton/runtime/tcc/lib/python314t.def +1809 -0
  106. triton/runtime/tcc/libtcc.dll +0 -0
  107. triton/runtime/tcc/tcc.exe +0 -0
  108. triton/testing.py +16 -12
  109. triton/tools/compile.py +62 -14
  110. triton/tools/disasm.py +3 -4
  111. triton/tools/extra/cuda/compile.c +1 -0
  112. triton/tools/extra/hip/compile.cpp +66 -0
  113. triton/tools/extra/hip/compile.h +13 -0
  114. triton/tools/ragged_tma.py +92 -0
  115. triton/tools/tensor_descriptor.py +34 -0
  116. triton/windows_utils.py +52 -81
  117. {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +8 -4
  118. triton_windows-3.5.0.post21.dist-info/RECORD +217 -0
  119. triton_windows-3.5.0.post21.dist-info/entry_points.txt +3 -0
  120. triton_windows-3.5.0.post21.dist-info/licenses/LICENSE +23 -0
  121. triton_windows-3.5.0.post21.dist-info/top_level.txt +1 -0
  122. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
  123. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
  124. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
  125. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
  126. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
  127. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
  128. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
  129. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
  130. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
  131. triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
  132. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
  133. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
  134. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
  135. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
  136. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
  137. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
  138. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
  139. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
  140. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
  141. triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
  142. triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
  143. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
  144. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
  145. triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
  146. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
  147. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
  148. triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
  149. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
  150. triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
  151. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
  152. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
  153. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
  154. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
  155. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
  156. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
  157. triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
  158. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
  159. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
  160. triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
  161. triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
  162. triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
  163. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
  164. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
  165. triton/backends/amd/include/hip/channel_descriptor.h +0 -39
  166. triton/backends/amd/include/hip/device_functions.h +0 -38
  167. triton/backends/amd/include/hip/driver_types.h +0 -468
  168. triton/backends/amd/include/hip/hip_bf16.h +0 -36
  169. triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
  170. triton/backends/amd/include/hip/hip_common.h +0 -100
  171. triton/backends/amd/include/hip/hip_complex.h +0 -38
  172. triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
  173. triton/backends/amd/include/hip/hip_deprecated.h +0 -95
  174. triton/backends/amd/include/hip/hip_ext.h +0 -161
  175. triton/backends/amd/include/hip/hip_fp16.h +0 -36
  176. triton/backends/amd/include/hip/hip_fp8.h +0 -33
  177. triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
  178. triton/backends/amd/include/hip/hip_hcc.h +0 -24
  179. triton/backends/amd/include/hip/hip_math_constants.h +0 -36
  180. triton/backends/amd/include/hip/hip_profile.h +0 -27
  181. triton/backends/amd/include/hip/hip_runtime.h +0 -75
  182. triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
  183. triton/backends/amd/include/hip/hip_texture_types.h +0 -29
  184. triton/backends/amd/include/hip/hip_vector_types.h +0 -41
  185. triton/backends/amd/include/hip/hip_version.h +0 -17
  186. triton/backends/amd/include/hip/hiprtc.h +0 -421
  187. triton/backends/amd/include/hip/library_types.h +0 -78
  188. triton/backends/amd/include/hip/math_functions.h +0 -42
  189. triton/backends/amd/include/hip/surface_types.h +0 -63
  190. triton/backends/amd/include/hip/texture_types.h +0 -194
  191. triton/backends/amd/include/hsa/Brig.h +0 -1131
  192. triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
  193. triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
  194. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
  195. triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
  196. triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
  197. triton/backends/amd/include/hsa/hsa.h +0 -5738
  198. triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
  199. triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
  200. triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
  201. triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
  202. triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
  203. triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
  204. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
  205. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
  206. triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
  207. triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
  208. triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
  209. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
  210. triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
  211. triton/backends/amd/include/roctracer/roctracer.h +0 -779
  212. triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
  213. triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
  214. triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
  215. triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
  216. triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
  217. triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
  218. triton/backends/amd/include/roctracer/roctx.h +0 -229
  219. triton/language/_utils.py +0 -21
  220. triton/language/extra/cuda/_experimental_tma.py +0 -106
  221. triton/runtime/tcc/lib/libtcc1-64.a +0 -0
  222. triton/tools/experimental_descriptor.py +0 -32
  223. triton_windows-3.3.1.post19.dist-info/RECORD +0 -260
  224. triton_windows-3.3.1.post19.dist-info/top_level.txt +0 -14
  225. {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
@@ -1,25 +1,30 @@
1
- from triton.backends.compiler import BaseBackend, GPUTarget
1
+ from triton.backends.compiler import BaseBackend, GPUTarget, Language
2
2
  from triton._C.libtriton import ir, passes, llvm, amd
3
+ from triton import knobs
3
4
  from dataclasses import dataclass
4
5
  from typing import Any, Dict, Tuple
5
6
  from types import ModuleType
6
7
  import hashlib
7
8
  import tempfile
8
- import os
9
9
  import re
10
- import subprocess
11
10
  import functools
11
+ import warnings
12
12
  from pathlib import Path
13
13
 
14
14
 
15
- def min_dot_size(target: GPUTarget):
16
- # If some given configuration is not supported in hardware we fallback to FMA and cast arguments
17
- return lambda lhsType, rhsType: (1, 1, 1)
15
+ def get_min_dot_size(target: GPUTarget):
16
+ # We fallback to use FMA and cast arguments if certain configurations is
17
+ # not supported natively by matrix core units.
18
+ return lambda lhs_type, rhs_type: (1, 1, 1)
18
19
 
19
20
 
20
- def is_pingpong_enabled(arch):
21
- default = "1" if arch == "gfx942" else "0"
22
- return os.getenv("TRITON_HIP_USE_BLOCK_PINGPONG", default) == "1"
21
+ def is_pingpong_schedule_enabled(arch, use_async_copy):
22
+ return (arch == "gfx942" or (arch == "gfx950" and use_async_copy is True)
23
+ ) if knobs.amd.use_block_pingpong is None else knobs.amd.use_block_pingpong
24
+
25
+
26
+ def is_in_thread_transpose_enabled(arch):
27
+ return (arch == "gfx942") if knobs.amd.use_in_thread_transpose is None else knobs.amd.use_in_thread_transpose
23
28
 
24
29
 
25
30
  @dataclass(frozen=True)
@@ -28,17 +33,17 @@ class HIPOptions:
28
33
  waves_per_eu: int = 1
29
34
  num_stages: int = 2
30
35
  num_ctas: int = 1
31
- num_buffers_warp_spec: int = 0
32
- num_consumer_groups: int = 0
33
- reg_dec_producer: int = 0
34
- reg_inc_consumer: int = 0
35
36
  extern_libs: dict = None
36
37
  cluster_dims: tuple = (1, 1, 1)
37
38
  debug: bool = False
38
39
  sanitize_overflow: bool = True
39
40
  arch: str = None
40
- supported_fp8_dtypes: Tuple[str] = ("fp8e5", )
41
- deprecated_fp8_dtypes: Tuple[str] = ()
41
+ # We have native support for OCP fp8 variants since CDNA4/RDNA4. For earlier generations,
42
+ # we software emulate the support for them.
43
+ # UZ fp8 variants (fp8e4b8 and fp8e5b16) are natively supported for CDNA3. For other
44
+ # architectures they are software emulated.
45
+ supported_fp8_dtypes: Tuple[str] = ("fp8e4nv", "fp8e5", "fp8e5b16", "fp8e4b8")
46
+ deprecated_fp8_dot_operand_dtypes: Tuple[str] = ()
42
47
  default_dot_input_precision: str = "ieee"
43
48
  allowed_dot_input_precisions: Tuple[str] = ("ieee", )
44
49
  enable_fp_fusion: bool = True
@@ -48,6 +53,7 @@ class HIPOptions:
48
53
  allow_flush_denorm: bool = False
49
54
  max_num_imprecise_acc_default: int = 0
50
55
  backend_name: str = 'hip'
56
+ instrumentation_mode: str = ""
51
57
 
52
58
  # The following option provides hints to the AMDGPU backend regarding instruction scheduling
53
59
  # for all `tt.dot` operations in a kernel. The "none" variant preserves the default
@@ -57,32 +63,29 @@ class HIPOptions:
57
63
  #
58
64
  # Current experimental scheduling variants:
59
65
  #
60
- # llvm-iglp-0: injects `llvm.amdgcn.iglp_opt` intrinsic call with value `0` to the GEMM's
61
- # k-loop; i.e., "interleave DS and MFMA instructions for small GEMM kernels".
62
- # llvm-iglp-1: injects `llvm.amdgcn.iglp_opt` intrinsic call with value `1` to the GEMM's
63
- # k-loop; i.e., "interleave DS and MFMA instructions for single wave small
64
- # GEMM kernels.".
65
- # local-prefetch: implements instruction scheduling similar to the one from the ROCm Composable
66
- # Kernel library. Note, this variant requires the use of buffer load/store ops
67
- # and a special software pipelining style - i.e., 1x LDS and 1x register
68
- # prefetch buffers for each GEMM tile.
69
- instruction_sched_variant: str = 'none'
66
+ # attention: enables a bunch of optimizations for attention kernels, including:
67
+ # - iglp 2 and sched.barrier around it
68
+ # - sink-insts-to-avoid-spills flag to avoid register spills
69
+ schedule_hint: str = 'none'
70
70
 
71
71
  def __post_init__(self):
72
+ gfx_major = int(self.arch[3:-2]) # Drop "gfx" prefix and minor/patch number
73
+ warp_size = 32 if gfx_major >= 10 else 64
74
+ object.__setattr__(self, 'warp_size', warp_size)
75
+ assert self.num_warps > 0 and (self.num_warps & (self.num_warps - 1)) == 0, \
76
+ "num_warps must be a power of 2"
77
+
78
+ if (self.arch == 'gfx950') and (self.kpack != 1):
79
+ warnings.warn(
80
+ f"kpack is deprecated starting from gfx950 and will be removed in later releases. So for now kpack = {self.kpack} will be overwritten to 1 to make transitioning easier."
81
+ )
82
+ object.__setattr__(self, 'kpack', 1)
83
+
72
84
  default_libdir = Path(__file__).parent / 'lib'
73
85
  extern_libs = {} if self.extern_libs is None else dict(self.extern_libs)
74
- # Ignore user-defined warp size for gfx9
75
- warp_size = 32 if 'gfx10' in self.arch or 'gfx11' in self.arch or 'gfx12' in self.arch else 64
76
- object.__setattr__(self, 'warp_size', warp_size)
77
- # Only kpack=1 is supported on gfx950
78
- kpack = 1 if self.arch == 'gfx950' else self.kpack
79
- object.__setattr__(self, 'kpack', kpack)
80
- libs = ["ocml", "ockl"]
81
- for lib in libs:
86
+ for lib in ["ocml", "ockl"]:
82
87
  extern_libs[lib] = str(default_libdir / f'{lib}.bc')
83
88
  object.__setattr__(self, 'extern_libs', tuple(extern_libs.items()))
84
- assert self.num_warps > 0 and (self.num_warps & (self.num_warps - 1)) == 0, \
85
- "num_warps must be a power of 2"
86
89
 
87
90
  def hash(self):
88
91
  key = '_'.join([f'{name}-{val}' for name, val in self.__dict__.items()])
@@ -90,6 +93,7 @@ class HIPOptions:
90
93
 
91
94
 
92
95
  class HIPBackend(BaseBackend):
96
+ instrumentation = None
93
97
 
94
98
  @staticmethod
95
99
  def supports_target(target: GPUTarget):
@@ -100,26 +104,33 @@ class HIPBackend(BaseBackend):
100
104
  assert isinstance(target.arch, str)
101
105
  self.binary_ext = "hsaco"
102
106
 
107
+ def get_target_name(self, options) -> str:
108
+ return f"hip:{options.arch}"
109
+
103
110
  def parse_options(self, opts) -> Any:
104
- args = {'arch': os.getenv("TRITON_OVERRIDE_ARCH", self.target.arch)}
111
+ args = {'arch': knobs.runtime.override_arch or self.target.arch}
112
+
113
+ if opts.get("num_ctas", 1) > 1:
114
+ raise ValueError("num_ctas > 1 not supported for AMD GPUs")
105
115
 
106
116
  # Enable XF32 (TF32) for CDNA3 GPUs
107
- if self.target.arch in ('gfx940', 'gfx941', 'gfx942'):
117
+ if self.target.arch == 'gfx942':
108
118
  allowed_dot_input_precisions = set(HIPOptions.allowed_dot_input_precisions)
109
119
  allowed_dot_input_precisions.update({'tf32'})
110
120
  args["allowed_dot_input_precisions"] = tuple(sorted(allowed_dot_input_precisions))
111
121
 
112
122
  if "supported_fp8_dtypes" not in opts:
113
- supported_fp8_dtypes = set(HIPOptions.supported_fp8_dtypes)
114
- if self.target.arch in ('gfx940', 'gfx941', 'gfx942'):
115
- supported_fp8_dtypes.update({'fp8e4nv', 'fp8e4b8', 'fp8e5b16'})
116
- elif self.target.arch in ('gfx950'):
117
- supported_fp8_dtypes.update({'fp8e4nv', 'fp8e5'})
118
- args["supported_fp8_dtypes"] = tuple(sorted(supported_fp8_dtypes))
123
+ args["supported_fp8_dtypes"] = tuple(sorted(HIPOptions.supported_fp8_dtypes))
124
+
125
+ if self.target.arch == 'gfx950':
126
+ deprecated_fp8_dot_operand_dtypes = set(HIPOptions.deprecated_fp8_dot_operand_dtypes)
127
+ deprecated_fp8_dot_operand_dtypes.update({"fp8e5b16", "fp8e4b8"})
128
+ args["deprecated_fp8_dot_operand_dtypes"] = tuple(sorted(deprecated_fp8_dot_operand_dtypes))
119
129
 
120
130
  if "enable_fp_fusion" not in opts:
121
- args["enable_fp_fusion"] = os.getenv("TRITON_DEFAULT_FP_FUSION", "1") == "1"
122
- args.update({k: opts[k] for k in HIPOptions.__dataclass_fields__.keys() if k in opts and opts[k] is not None})
131
+ args["enable_fp_fusion"] = knobs.language.default_fp_fusion
132
+ args.update({k: opts[k] for k in HIPOptions.__dataclass_fields__.keys() \
133
+ if k in opts and opts[k] is not None})
123
134
  return HIPOptions(**args)
124
135
 
125
136
  def pack_metadata(self, metadata):
@@ -133,8 +144,7 @@ class HIPBackend(BaseBackend):
133
144
  )
134
145
 
135
146
  def get_codegen_implementation(self, options):
136
- codegen_fns = {"min_dot_size": min_dot_size(self.target)}
137
- return codegen_fns
147
+ return {"min_dot_size": get_min_dot_size(self.target)}
138
148
 
139
149
  def get_module_map(self) -> Dict[str, ModuleType]:
140
150
  from triton.language.extra.hip import libdevice
@@ -143,11 +153,8 @@ class HIPBackend(BaseBackend):
143
153
 
144
154
  def load_dialects(self, ctx):
145
155
  amd.load_dialects(ctx)
146
-
147
- @staticmethod
148
- @functools.lru_cache()
149
- def use_buffer_ops():
150
- return os.environ.get("AMDGCN_USE_BUFFER_OPS", "0") == "1"
156
+ if HIPBackend.instrumentation:
157
+ HIPBackend.instrumentation.load_dialects(ctx)
151
158
 
152
159
  @staticmethod
153
160
  def is_within_2gb(arg):
@@ -172,41 +179,22 @@ class HIPBackend(BaseBackend):
172
179
  ret = BaseBackend.get_arg_specialization(arg, ty, **kwargs)
173
180
  # Only attempt to do buffer ops specialization if buffer ops are enabled.
174
181
  # Otherwise the is_within_2gb check is unnecessary overhead.
175
- if HIPBackend.use_buffer_ops() and ty == "tensor" and HIPBackend.is_within_2gb(arg):
182
+ if knobs.amd.use_buffer_ops and ty == "tensor" and HIPBackend.is_within_2gb(arg):
176
183
  ret += "S"
177
184
  return ret
178
185
 
179
- @staticmethod
180
- def path_to_rocm_lld():
181
- # Check env path for ld.lld
182
- lld_env_path = os.getenv("TRITON_HIP_LLD_PATH")
183
- if lld_env_path is not None:
184
- lld = Path(lld_env_path)
185
- if lld.is_file():
186
- return lld
187
- # Check backend for ld.lld (used for pytorch wheels)
188
- lld = Path(__file__).parent / "llvm/bin/ld.lld"
189
- if lld.is_file():
190
- return lld
191
- lld = Path("/opt/rocm/llvm/bin/ld.lld")
192
- if lld.is_file():
193
- return lld
194
- lld = Path("/usr/bin/ld.lld")
195
- if lld.is_file():
196
- return lld
197
- raise Exception("ROCm linker /opt/rocm/llvm/bin/ld.lld not found. Set 'TRITON_HIP_LLD_PATH' to its path.")
198
-
199
186
  @staticmethod
200
187
  def make_ttir(mod, metadata, options):
201
188
  pm = ir.pass_manager(mod.context)
202
189
  pm.enable_debug()
203
190
  passes.common.add_inliner(pm)
204
191
  passes.ttir.add_rewrite_tensor_pointer(pm)
192
+ passes.ttir.add_rewrite_tensor_descriptor_to_pointer(pm)
205
193
  passes.common.add_canonicalizer(pm)
206
194
  passes.ttir.add_combine(pm)
207
195
  passes.ttir.add_reorder_broadcast(pm)
208
196
  passes.common.add_cse(pm)
209
- passes.common.add_licm(pm)
197
+ passes.ttir.add_triton_licm(pm)
210
198
  passes.common.add_symbol_dce(pm)
211
199
  passes.ttir.add_loop_unroll(pm)
212
200
  pm.run(mod)
@@ -230,39 +218,60 @@ class HIPBackend(BaseBackend):
230
218
  passes.ttgpuir.add_optimize_dot_operands(pm, True)
231
219
  amd.passes.ttgpuir.add_hoist_layout_conversions(pm)
232
220
 
233
- global_prefetch = int(os.getenv("TRITON_HIP_GLOBAL_PREFETCH", "0"))
234
- local_prefetch = int(os.getenv("TRITON_HIP_LOCAL_PREFETCH", "0"))
221
+ passes.ttgpuir.add_fuse_nested_loops(pm)
222
+ passes.common.add_canonicalizer(pm)
223
+ passes.ttir.add_triton_licm(pm)
224
+ passes.common.add_canonicalizer(pm)
235
225
 
236
- # The `local-prefetch` scheduling variant requires turning on buffer ops.
237
- if options.instruction_sched_variant == "local-prefetch":
238
- global_prefetch = local_prefetch = 1
226
+ global_prefetch = knobs.amd.global_prefetch
227
+ local_prefetch = knobs.amd.local_prefetch
228
+ use_async_copy = knobs.amd.use_async_copy
229
+ use_block_pingpong = is_pingpong_schedule_enabled(options.arch, use_async_copy)
239
230
 
240
- if amd.has_matrix_core_feature(options.arch):
241
- assert options.num_stages != 0, ("Triton AMD backend pipeliner has been updated. "
242
- "We used to trigger software pipelining with "
243
- "num_stages == 0. Now it will not happen anymore; "
244
- "please update to use num_stages == 2 for "
245
- "equivalent behavior in the past.")
246
- amd.passes.ttgpuir.add_stream_pipeline(pm, options.num_stages, global_prefetch, local_prefetch)
247
- passes.common.add_canonicalizer(pm)
248
- if options.instruction_sched_variant.lower() != "none":
249
- amd.passes.ttgpuir.insert_instruction_sched_hints(pm, options.instruction_sched_variant)
231
+ amd.passes.ttgpuir.add_stream_pipeline(pm, options.num_stages, global_prefetch, local_prefetch, use_async_copy,
232
+ use_block_pingpong)
233
+ if use_async_copy:
234
+ amd.passes.ttgpuir.add_coalesce_async_copy(pm, options.arch)
235
+ passes.common.add_canonicalizer(pm)
236
+ if options.schedule_hint.lower() != "none":
237
+ amd.passes.ttgpuir.insert_instruction_sched_hints(pm, options.schedule_hint)
250
238
  passes.ttgpuir.add_optimize_dot_operands(pm, True)
251
239
  passes.ttgpuir.add_remove_layout_conversions(pm)
252
240
  passes.ttgpuir.add_reduce_data_duplication(pm)
253
- if amd.has_matrix_core_feature(options.arch):
254
- amd.passes.ttgpuir.add_reorder_instructions(pm)
255
- use_block_pingpong = is_pingpong_enabled(options.arch)
256
- if use_block_pingpong and options.num_stages == 2:
257
- amd.passes.ttgpuir.add_block_pingpong(pm)
258
-
259
- if HIPBackend.use_buffer_ops():
241
+ if is_in_thread_transpose_enabled(options.arch):
242
+ amd.passes.ttgpuir.add_in_thread_transpose(pm)
243
+ passes.ttgpuir.add_remove_layout_conversions(pm)
244
+ amd.passes.ttgpuir.add_reorder_instructions(pm)
245
+ if use_block_pingpong and options.num_stages > 1:
246
+ amd.passes.ttgpuir.add_block_pingpong(pm, options.num_stages)
247
+
248
+ if knobs.amd.use_buffer_ops:
260
249
  amd.passes.ttgpuir.add_canonicalize_pointers(pm)
261
250
  passes.common.add_canonicalizer(pm)
262
- amd.passes.ttgpuir.add_convert_to_buffer_ops(pm, options.arch)
251
+ amd.passes.ttgpuir.add_convert_to_buffer_ops(pm, options.arch, knobs.amd.use_buffer_atomics)
252
+
253
+ amd.passes.ttgpuir.add_fold_true_cmpi(pm)
263
254
  passes.common.add_canonicalizer(pm)
264
255
  passes.common.add_cse(pm)
265
256
  passes.common.add_symbol_dce(pm)
257
+ if use_async_copy:
258
+ amd.passes.ttgpuir.add_update_async_wait_count(pm, options.arch)
259
+ pm.run(mod)
260
+ return mod
261
+
262
+ @staticmethod
263
+ def gluon_to_ttgir(src, metadata, options):
264
+ mod = src
265
+ pm = ir.pass_manager(mod.context)
266
+ pm.enable_debug()
267
+
268
+ passes.gluon.add_inliner(pm)
269
+ passes.gluon.add_resolve_auto_encodings(pm)
270
+ passes.common.add_sccp(pm)
271
+ passes.ttir.add_loop_aware_cse(pm)
272
+ passes.gluon.add_canonicalizer(pm)
273
+ passes.ttgpuir.add_combine_tensor_select_and_if(pm)
274
+
266
275
  pm.run(mod)
267
276
  return mod
268
277
 
@@ -272,7 +281,6 @@ class HIPBackend(BaseBackend):
272
281
  # TritonGPU -> LLVM-IR (MLIR)
273
282
  pm = ir.pass_manager(mod.context)
274
283
  pm.enable_debug()
275
- amd.passes.ttgpuir.add_decompose_unsupported_conversions(pm, options.arch)
276
284
  # custom_lds_size is an experimental parameter that defines amount of LDS available
277
285
  # for one thread block. Measured in bytes.
278
286
  #
@@ -283,7 +291,10 @@ class HIPBackend(BaseBackend):
283
291
  passes.convert.add_scf_to_cf(pm)
284
292
  passes.convert.add_index_to_llvmir(pm)
285
293
 
286
- passes.ttgpuir.add_allocate_shared_memory(pm)
294
+ amd.passes.ttgpuir.add_allocate_shared_memory(pm)
295
+ # instrumentation point here so we can override IRs above (e.g., ttir and ttgir)
296
+ if HIPBackend.instrumentation:
297
+ HIPBackend.instrumentation.patch("ttgpuir_to_llvmir", pm, mod.context)
287
298
  ## __HIP_FTZ is used to control the denorm flushing behavior of exp2 op as follows:
288
299
  ## 1. If __HIP_FTZ = 1, exp2 flushes denorms in input and output regardless
289
300
  ## of the value of kernel arg `allow_flush_denorm`.
@@ -301,10 +312,17 @@ class HIPBackend(BaseBackend):
301
312
  passes.common.add_canonicalizer(pm)
302
313
  passes.common.add_cse(pm)
303
314
  passes.common.add_symbol_dce(pm)
304
- if options.instruction_sched_variant.lower() != "none":
315
+
316
+ if options.schedule_hint.lower() != "none":
305
317
  amd.passes.ttgpuir.lower_instruction_sched_hints(pm, options.arch, options.num_stages)
306
- if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0":
318
+
319
+ # This can not be moved below the di_scope pass
320
+ if HIPBackend.instrumentation:
321
+ HIPBackend.instrumentation.patch("llvmir_to_llvm", pm, mod.context)
322
+
323
+ if not knobs.compilation.disable_line_info:
307
324
  passes.llvmir.add_di_scope(pm)
325
+
308
326
  amd.passes.ttgpuir.add_builtin_func_to_llvmir(pm, __HIP_FTZ)
309
327
  pm.run(mod)
310
328
 
@@ -314,7 +332,7 @@ class HIPBackend(BaseBackend):
314
332
  llvm_mod = llvm.to_module(mod, context)
315
333
  amd.attach_target_triple(llvm_mod)
316
334
  target_features = ''
317
- if os.environ.get("TRITON_ENABLE_ASAN", "0") == "1":
335
+ if knobs.compilation.enable_asan:
318
336
  target_features = '+xnack'
319
337
  llvm.attach_datalayout(llvm_mod, amd.TARGET_TRIPLE, options.arch, target_features)
320
338
 
@@ -342,7 +360,7 @@ class HIPBackend(BaseBackend):
342
360
  fns[0].add_fn_attr("amdgpu-waves-per-eu", f"{options.waves_per_eu}")
343
361
  denormal_mode = "preserve-sign" if options.allow_flush_denorm else "ieee"
344
362
  fns[0].add_fn_attr("denormal-fp-math-f32", denormal_mode)
345
- if os.environ.get("TRITON_ENABLE_ASAN", "0") == "1":
363
+ if knobs.compilation.enable_asan:
346
364
  fns[0].add_fn_target_feature("+xnack")
347
365
  fns[0].add_fn_asan_attr()
348
366
 
@@ -351,7 +369,7 @@ class HIPBackend(BaseBackend):
351
369
  # from memory.
352
370
  amd.set_all_fn_arg_inreg(fns[0])
353
371
 
354
- if os.environ.get("TRITON_ENABLE_ASAN", "0") == "1":
372
+ if knobs.compilation.enable_asan:
355
373
  default_libdir = Path(__file__).parent / 'lib'
356
374
  paths = [
357
375
  str(default_libdir / 'asanrtl.bc'),
@@ -361,12 +379,27 @@ class HIPBackend(BaseBackend):
361
379
  llvm.link_extern_libs(llvm_mod, paths)
362
380
  elif options.extern_libs:
363
381
  paths = [path for (name, path) in options.extern_libs if amd.need_extern_lib(llvm_mod, name)]
364
- llvm.link_extern_libs(llvm_mod, paths)
382
+ if len(paths) > 0:
383
+ llvm.link_extern_libs(llvm_mod, paths)
365
384
 
366
385
  llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3, options.arch, '', [], options.enable_fp_fusion)
367
386
 
387
+ # Architectures with architected SGPRs store the workgroup id in ttmp9 (X) and ttmp7 (Y[15:0], Z[31:16]).
388
+ # These attributes are used to determine if Z should be masked out when loading Y. They are inferred during
389
+ # optimize_module from calls to @llvm.amdgcn.workgroup.id.x/y/z(). We cannot rely on this because a
390
+ # dispatch dimensions might be used even if there is no program_id() call for it.
391
+ if amd.has_architected_sgprs(options.arch):
392
+ fns[0].remove_fn_attr("amdgpu-no-workgroup-id-x")
393
+ fns[0].remove_fn_attr("amdgpu-no-workgroup-id-y")
394
+ fns[0].remove_fn_attr("amdgpu-no-workgroup-id-z")
395
+
396
+ if knobs.amd.scalarize_packed_fops:
397
+ amd.add_scalarize_packed_fops_llvm_pass(fns[0])
398
+
368
399
  # Get some metadata
369
400
  metadata["shared"] = src.get_int_attr("ttg.shared")
401
+ metadata["profile_scratch_size"] = src.get_int_attr("ttg.profile_scratch_memory_size") or 0
402
+ metadata["profile_scratch_align"] = src.get_int_attr("ttg.profile_scratch_memory_alignment") or 1
370
403
 
371
404
  amd.cleanup_bitcode_metadata(llvm_mod)
372
405
  # Disable inlining of print related functions,
@@ -377,14 +410,23 @@ class HIPBackend(BaseBackend):
377
410
  @staticmethod
378
411
  def make_amdgcn(src, metadata, options):
379
412
  # Find kernel names (there should only be one)
380
- # We get the name at the last possible step to accomodate `triton.compile`
413
+ # We get the name at the last possible step to accommodate `triton.compile`
381
414
  # on user-provided LLVM
382
415
  names = re.findall(r"define amdgpu_kernel void @([a-zA-Z_][a-zA-Z0-9_]*)", src)
383
416
  assert len(names) == 1
384
417
  metadata["name"] = names[0]
385
418
  # llvm -> hsaco
386
- amdgcn = llvm.translate_to_asm(src, amd.TARGET_TRIPLE, options.arch, '', [], options.enable_fp_fusion, False)
387
- if os.environ.get("AMDGCN_ENABLE_DUMP", "0") == "1":
419
+ flags = []
420
+ # The sink-insts-to-avoid-spills flag asks LLVM backend to sink instructions
421
+ # into loops to avoid register spills in the MachineSinking pass, while it
422
+ # can also lead to regression in some cases. But from current observation,
423
+ # the regression is not significant. It would be better to have some heuristics.
424
+ if options.schedule_hint == 'attention':
425
+ flags.append('sink-insts-to-avoid-spills')
426
+ features = '-real-true16' if 'gfx11' in options.arch else ''
427
+ amdgcn = llvm.translate_to_asm(src, amd.TARGET_TRIPLE, options.arch, features, flags, options.enable_fp_fusion,
428
+ False)
429
+ if knobs.amd.dump_amdgcn:
388
430
  print("// -----// AMDGCN Dump //----- //")
389
431
  print(amdgcn)
390
432
  return amdgcn
@@ -392,28 +434,28 @@ class HIPBackend(BaseBackend):
392
434
  @staticmethod
393
435
  def make_hsaco(src, metadata, options):
394
436
  target_features = ''
395
- if os.environ.get("TRITON_ENABLE_ASAN", "0") == "1":
437
+ if knobs.compilation.enable_asan:
396
438
  target_features = '+xnack'
397
439
  hsaco = amd.assemble_amdgcn(src, options.arch, target_features)
398
-
399
- rocm_path = HIPBackend.path_to_rocm_lld()
400
440
  with tempfile.NamedTemporaryFile() as tmp_out:
401
441
  with tempfile.NamedTemporaryFile() as tmp_in:
402
- with open(tmp_in.name, 'wb') as fd_in:
442
+ with open(tmp_in.name, "wb") as fd_in:
403
443
  fd_in.write(hsaco)
404
- subprocess.check_call([rocm_path, '-flavor', 'gnu', '-shared', tmp_in.name, '-o', tmp_out.name])
405
- with open(tmp_out.name, 'rb') as fd_out:
444
+ amd.link_hsaco(tmp_in.name, tmp_out.name)
445
+ with open(tmp_out.name, "rb") as fd_out:
406
446
  ret = fd_out.read()
407
447
  return ret
408
448
 
409
- def add_stages(self, stages, options):
410
- stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options)
411
- stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options)
449
+ def add_stages(self, stages, options, language):
450
+ if language == Language.TRITON:
451
+ stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options)
452
+ stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options)
453
+ elif language == Language.GLUON:
454
+ stages["ttgir"] = lambda src, metadata: self.gluon_to_ttgir(src, metadata, options)
412
455
  stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options)
413
456
  stages["amdgcn"] = lambda src, metadata: self.make_amdgcn(src, metadata, options)
414
457
  stages["hsaco"] = lambda src, metadata: self.make_hsaco(src, metadata, options)
415
458
 
416
459
  @functools.lru_cache()
417
460
  def hash(self):
418
- version = subprocess.check_output([HIPBackend.path_to_rocm_lld(), "--version"], encoding='utf-8')
419
- return f'{version}-{self.target}'
461
+ return f'{self.target}'