triton-windows 3.3.1.post19__cp311-cp311-win_amd64.whl → 3.4.0.post20__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (166)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +4 -1
  3. triton/_filecheck.py +87 -0
  4. triton/_internal_testing.py +26 -15
  5. triton/_utils.py +110 -21
  6. triton/backends/__init__.py +20 -23
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +112 -78
  9. triton/backends/amd/driver.c +5 -2
  10. triton/backends/amd/driver.py +149 -47
  11. triton/backends/compiler.py +7 -21
  12. triton/backends/nvidia/bin/ptxas.exe +0 -0
  13. triton/backends/nvidia/compiler.py +92 -93
  14. triton/backends/nvidia/driver.c +90 -98
  15. triton/backends/nvidia/driver.py +303 -128
  16. triton/compiler/code_generator.py +212 -111
  17. triton/compiler/compiler.py +110 -25
  18. triton/experimental/__init__.py +0 -0
  19. triton/experimental/gluon/__init__.py +4 -0
  20. triton/experimental/gluon/_compiler.py +0 -0
  21. triton/experimental/gluon/_runtime.py +99 -0
  22. triton/experimental/gluon/language/__init__.py +18 -0
  23. triton/experimental/gluon/language/_core.py +312 -0
  24. triton/experimental/gluon/language/_layouts.py +230 -0
  25. triton/experimental/gluon/language/_math.py +12 -0
  26. triton/experimental/gluon/language/_semantic.py +287 -0
  27. triton/experimental/gluon/language/_standard.py +47 -0
  28. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  29. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +202 -0
  30. triton/experimental/gluon/language/nvidia/blackwell/tma.py +32 -0
  31. triton/experimental/gluon/language/nvidia/hopper/__init__.py +11 -0
  32. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +51 -0
  33. triton/experimental/gluon/language/nvidia/hopper/tma.py +96 -0
  34. triton/experimental/gluon/nvidia/__init__.py +4 -0
  35. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  36. triton/experimental/gluon/nvidia/hopper.py +40 -0
  37. triton/knobs.py +481 -0
  38. triton/language/__init__.py +39 -14
  39. triton/language/core.py +794 -537
  40. triton/language/extra/cuda/__init__.py +10 -7
  41. triton/language/extra/cuda/gdc.py +42 -0
  42. triton/language/extra/cuda/libdevice.py +394 -394
  43. triton/language/extra/cuda/utils.py +21 -21
  44. triton/language/extra/hip/libdevice.py +113 -104
  45. triton/language/math.py +65 -66
  46. triton/language/random.py +12 -2
  47. triton/language/semantic.py +1706 -1770
  48. triton/language/standard.py +116 -51
  49. triton/runtime/autotuner.py +117 -59
  50. triton/runtime/build.py +76 -12
  51. triton/runtime/cache.py +18 -47
  52. triton/runtime/driver.py +32 -29
  53. triton/runtime/interpreter.py +72 -35
  54. triton/runtime/jit.py +146 -110
  55. triton/testing.py +16 -12
  56. triton/tools/disasm.py +3 -4
  57. triton/tools/tensor_descriptor.py +36 -0
  58. triton/windows_utils.py +14 -6
  59. {triton_windows-3.3.1.post19.dist-info → triton_windows-3.4.0.post20.dist-info}/METADATA +7 -2
  60. triton_windows-3.4.0.post20.dist-info/RECORD +186 -0
  61. triton_windows-3.4.0.post20.dist-info/entry_points.txt +3 -0
  62. triton_windows-3.4.0.post20.dist-info/licenses/LICENSE +23 -0
  63. triton_windows-3.4.0.post20.dist-info/top_level.txt +1 -0
  64. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
  65. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
  66. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
  67. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
  68. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
  69. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
  70. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
  71. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
  72. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
  73. triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
  74. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
  75. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
  76. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
  77. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
  78. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
  79. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
  80. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
  81. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
  82. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
  83. triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
  84. triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
  85. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
  86. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
  87. triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
  88. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
  89. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
  90. triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
  91. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
  92. triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
  93. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
  94. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
  95. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
  96. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
  97. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
  98. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
  99. triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
  100. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
  101. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
  102. triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
  103. triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
  104. triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
  105. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
  106. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
  107. triton/backends/amd/include/hip/channel_descriptor.h +0 -39
  108. triton/backends/amd/include/hip/device_functions.h +0 -38
  109. triton/backends/amd/include/hip/driver_types.h +0 -468
  110. triton/backends/amd/include/hip/hip_bf16.h +0 -36
  111. triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
  112. triton/backends/amd/include/hip/hip_common.h +0 -100
  113. triton/backends/amd/include/hip/hip_complex.h +0 -38
  114. triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
  115. triton/backends/amd/include/hip/hip_deprecated.h +0 -95
  116. triton/backends/amd/include/hip/hip_ext.h +0 -161
  117. triton/backends/amd/include/hip/hip_fp16.h +0 -36
  118. triton/backends/amd/include/hip/hip_fp8.h +0 -33
  119. triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
  120. triton/backends/amd/include/hip/hip_hcc.h +0 -24
  121. triton/backends/amd/include/hip/hip_math_constants.h +0 -36
  122. triton/backends/amd/include/hip/hip_profile.h +0 -27
  123. triton/backends/amd/include/hip/hip_runtime.h +0 -75
  124. triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
  125. triton/backends/amd/include/hip/hip_texture_types.h +0 -29
  126. triton/backends/amd/include/hip/hip_vector_types.h +0 -41
  127. triton/backends/amd/include/hip/hip_version.h +0 -17
  128. triton/backends/amd/include/hip/hiprtc.h +0 -421
  129. triton/backends/amd/include/hip/library_types.h +0 -78
  130. triton/backends/amd/include/hip/math_functions.h +0 -42
  131. triton/backends/amd/include/hip/surface_types.h +0 -63
  132. triton/backends/amd/include/hip/texture_types.h +0 -194
  133. triton/backends/amd/include/hsa/Brig.h +0 -1131
  134. triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
  135. triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
  136. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
  137. triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
  138. triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
  139. triton/backends/amd/include/hsa/hsa.h +0 -5738
  140. triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
  141. triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
  142. triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
  143. triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
  144. triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
  145. triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
  146. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
  147. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
  148. triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
  149. triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
  150. triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
  151. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
  152. triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
  153. triton/backends/amd/include/roctracer/roctracer.h +0 -779
  154. triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
  155. triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
  156. triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
  157. triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
  158. triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
  159. triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
  160. triton/backends/amd/include/roctracer/roctx.h +0 -229
  161. triton/language/_utils.py +0 -21
  162. triton/language/extra/cuda/_experimental_tma.py +0 -106
  163. triton/tools/experimental_descriptor.py +0 -32
  164. triton_windows-3.3.1.post19.dist-info/RECORD +0 -260
  165. triton_windows-3.3.1.post19.dist-info/top_level.txt +0 -14
  166. {triton_windows-3.3.1.post19.dist-info → triton_windows-3.4.0.post20.dist-info}/WHEEL +0 -0
triton/backends/amd/compiler.py
@@ -1,25 +1,29 @@
-from triton.backends.compiler import BaseBackend, GPUTarget
+from triton.backends.compiler import BaseBackend, GPUTarget, Language
 from triton._C.libtriton import ir, passes, llvm, amd
+from triton import knobs
 from dataclasses import dataclass
 from typing import Any, Dict, Tuple
 from types import ModuleType
 import hashlib
 import tempfile
-import os
 import re
 import subprocess
 import functools
 from pathlib import Path
 
 
-def min_dot_size(target: GPUTarget):
-    # If some given configuration is not supported in hardware we fallback to FMA and cast arguments
-    return lambda lhsType, rhsType: (1, 1, 1)
+def get_min_dot_size(target: GPUTarget):
+    # We fallback to use FMA and cast arguments if certain configurations is
+    # not supported natively by matrix core units.
+    return lambda lhs_type, rhs_type: (1, 1, 1)
 
 
-def is_pingpong_enabled(arch):
-    default = "1" if arch == "gfx942" else "0"
-    return os.getenv("TRITON_HIP_USE_BLOCK_PINGPONG", default) == "1"
+def is_pingpong_schedule_enabled(arch):
+    return (arch == "gfx942") if knobs.amd.use_block_pingpong is None else knobs.amd.use_block_pingpong
+
+
+def is_in_thread_transpose_enabled(arch):
+    return (arch == "gfx942") if knobs.amd.use_in_thread_transpose is None else knobs.amd.use_in_thread_transpose
 
 
 @dataclass(frozen=True)
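The two helpers above replace the old TRITON_HIP_USE_BLOCK_PINGPONG environment check with tri-state knobs from the new triton/knobs.py module: None means "use the per-arch default (gfx942 only)", while an explicit True/False wins. A minimal sketch, assuming the attributes on knobs.amd are writable from user code (the diff only shows them being read):

from triton import knobs

knobs.amd.use_block_pingpong = True        # force the ping-pong schedule regardless of arch
knobs.amd.use_in_thread_transpose = False  # opt out even on gfx942
knobs.amd.use_block_pingpong = None        # back to the per-arch default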
@@ -28,17 +32,13 @@ class HIPOptions:
     waves_per_eu: int = 1
     num_stages: int = 2
     num_ctas: int = 1
-    num_buffers_warp_spec: int = 0
-    num_consumer_groups: int = 0
-    reg_dec_producer: int = 0
-    reg_inc_consumer: int = 0
     extern_libs: dict = None
     cluster_dims: tuple = (1, 1, 1)
     debug: bool = False
     sanitize_overflow: bool = True
     arch: str = None
     supported_fp8_dtypes: Tuple[str] = ("fp8e5", )
-    deprecated_fp8_dtypes: Tuple[str] = ()
+    deprecated_fp8_dot_operand_dtypes: Tuple[str] = ()
     default_dot_input_precision: str = "ieee"
     allowed_dot_input_precisions: Tuple[str] = ("ieee", )
     enable_fp_fusion: bool = True
@@ -57,32 +57,30 @@ class HIPOptions:
     #
     # Current experimental scheduling variants:
     #
-    # llvm-iglp-0: injects `llvm.amdgcn.iglp_opt` intrinsic call with value `0` to the GEMM's
-    #              k-loop; i.e., "interleave DS and MFMA instructions for small GEMM kernels".
-    # llvm-iglp-1: injects `llvm.amdgcn.iglp_opt` intrinsic call with value `1` to the GEMM's
-    #              k-loop; i.e., "interleave DS and MFMA instructions for single wave small
-    #              GEMM kernels.".
     # local-prefetch: implements instruction scheduling similar to the one from the ROCm Composable
     #                 Kernel library. Note, this variant requires the use of buffer load/store ops
     #                 and a special software pipelining style - i.e., 1x LDS and 1x register
     #                 prefetch buffers for each GEMM tile.
-    instruction_sched_variant: str = 'none'
+    # attention: enables a bunch of optimizations for attention kernels, including:
+    #            - iglp 2 and sched.barrier around it
+    #            - sink-insts-to-avoid-spills flag to avoid register spills
+    schedule_hint: str = 'none'
 
     def __post_init__(self):
+        gfx_major = int(self.arch[3:-2])  # Drop "gfx" prefix and minor/patch number
+        warp_size = 32 if gfx_major >= 10 else 64
+        object.__setattr__(self, 'warp_size', warp_size)
+        assert self.num_warps > 0 and (self.num_warps & (self.num_warps - 1)) == 0, \
+            "num_warps must be a power of 2"
+
+        if self.arch == 'gfx950':
+            assert self.kpack == 1, "gfx950 only accepts kpack == 1"
+
         default_libdir = Path(__file__).parent / 'lib'
         extern_libs = {} if self.extern_libs is None else dict(self.extern_libs)
-        # Ignore user-defined warp size for gfx9
-        warp_size = 32 if 'gfx10' in self.arch or 'gfx11' in self.arch or 'gfx12' in self.arch else 64
-        object.__setattr__(self, 'warp_size', warp_size)
-        # Only kpack=1 is supported on gfx950
-        kpack = 1 if self.arch == 'gfx950' else self.kpack
-        object.__setattr__(self, 'kpack', kpack)
-        libs = ["ocml", "ockl"]
-        for lib in libs:
+        for lib in ["ocml", "ockl"]:
             extern_libs[lib] = str(default_libdir / f'{lib}.bc')
         object.__setattr__(self, 'extern_libs', tuple(extern_libs.items()))
-        assert self.num_warps > 0 and (self.num_warps & (self.num_warps - 1)) == 0, \
-            "num_warps must be a power of 2"
 
     def hash(self):
         key = '_'.join([f'{name}-{val}' for name, val in self.__dict__.items()])
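Worked example of the new warp-size derivation in __post_init__: arch[3:-2] drops the "gfx" prefix and the last two characters (minor/stepping), leaving the major generation number, so gfx10 and newer get wave32 while older CDNA parts keep wave64.

for arch in ("gfx90a", "gfx942", "gfx1100", "gfx1201"):
    gfx_major = int(arch[3:-2])
    warp_size = 32 if gfx_major >= 10 else 64
    print(arch, gfx_major, warp_size)  # gfx90a/gfx942 -> 64, gfx1100/gfx1201 -> 32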
@@ -100,26 +98,32 @@ class HIPBackend(BaseBackend):
         assert isinstance(target.arch, str)
         self.binary_ext = "hsaco"
 
+    def get_target_name(self, options) -> str:
+        return f"hip:{options.arch}"
+
     def parse_options(self, opts) -> Any:
-        args = {'arch': os.getenv("TRITON_OVERRIDE_ARCH", self.target.arch)}
+        args = {'arch': knobs.runtime.override_arch or self.target.arch}
 
         # Enable XF32 (TF32) for CDNA3 GPUs
-        if self.target.arch in ('gfx940', 'gfx941', 'gfx942'):
+        if self.target.arch == 'gfx942':
             allowed_dot_input_precisions = set(HIPOptions.allowed_dot_input_precisions)
             allowed_dot_input_precisions.update({'tf32'})
             args["allowed_dot_input_precisions"] = tuple(sorted(allowed_dot_input_precisions))
 
         if "supported_fp8_dtypes" not in opts:
             supported_fp8_dtypes = set(HIPOptions.supported_fp8_dtypes)
-            if self.target.arch in ('gfx940', 'gfx941', 'gfx942'):
+            if self.target.arch == 'gfx942':
                 supported_fp8_dtypes.update({'fp8e4nv', 'fp8e4b8', 'fp8e5b16'})
-            elif self.target.arch in ('gfx950'):
+            elif self.target.arch == 'gfx950':
+                supported_fp8_dtypes.update({'fp8e4nv', 'fp8e5'})
+            elif 'gfx12' in self.target.arch:
                 supported_fp8_dtypes.update({'fp8e4nv', 'fp8e5'})
             args["supported_fp8_dtypes"] = tuple(sorted(supported_fp8_dtypes))
 
         if "enable_fp_fusion" not in opts:
-            args["enable_fp_fusion"] = os.getenv("TRITON_DEFAULT_FP_FUSION", "1") == "1"
-        args.update({k: opts[k] for k in HIPOptions.__dataclass_fields__.keys() if k in opts and opts[k] is not None})
+            args["enable_fp_fusion"] = knobs.language.default_fp_fusion
+        args.update({k: opts[k] for k in HIPOptions.__dataclass_fields__.keys() \
+                     if k in opts and opts[k] is not None})
         return HIPOptions(**args)
 
     def pack_metadata(self, metadata):
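parse_options now reads the arch override from a knob rather than the removed TRITON_OVERRIDE_ARCH environment lookup. A minimal sketch, assuming knobs.runtime.override_arch is writable from user code (only the read appears in the diff):

from triton import knobs

knobs.runtime.override_arch = "gfx90a"  # parse_options() will use this instead of the detected target arch
knobs.runtime.override_arch = None      # fall back to self.target.arch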
@@ -133,8 +137,7 @@ class HIPBackend(BaseBackend):
         )
 
     def get_codegen_implementation(self, options):
-        codegen_fns = {"min_dot_size": min_dot_size(self.target)}
-        return codegen_fns
+        return {"min_dot_size": get_min_dot_size(self.target)}
 
     def get_module_map(self) -> Dict[str, ModuleType]:
         from triton.language.extra.hip import libdevice
@@ -144,11 +147,6 @@ class HIPBackend(BaseBackend):
     def load_dialects(self, ctx):
         amd.load_dialects(ctx)
 
-    @staticmethod
-    @functools.lru_cache()
-    def use_buffer_ops():
-        return os.environ.get("AMDGCN_USE_BUFFER_OPS", "0") == "1"
-
     @staticmethod
     def is_within_2gb(arg):
         import torch
@@ -172,14 +170,14 @@ class HIPBackend(BaseBackend):
         ret = BaseBackend.get_arg_specialization(arg, ty, **kwargs)
         # Only attempt to do buffer ops specialization if buffer ops are enabled.
         # Otherwise the is_within_2gb check is unnecessary overhead.
-        if HIPBackend.use_buffer_ops() and ty == "tensor" and HIPBackend.is_within_2gb(arg):
+        if knobs.amd.use_buffer_ops and ty == "tensor" and HIPBackend.is_within_2gb(arg):
            ret += "S"
         return ret
 
     @staticmethod
     def path_to_rocm_lld():
         # Check env path for ld.lld
-        lld_env_path = os.getenv("TRITON_HIP_LLD_PATH")
+        lld_env_path = knobs.amd.lld_path
         if lld_env_path is not None:
             lld = Path(lld_env_path)
             if lld.is_file():
@@ -202,11 +200,12 @@ class HIPBackend(BaseBackend):
         pm.enable_debug()
         passes.common.add_inliner(pm)
         passes.ttir.add_rewrite_tensor_pointer(pm)
+        passes.ttir.add_rewrite_tensor_descriptor_to_pointer(pm)
         passes.common.add_canonicalizer(pm)
         passes.ttir.add_combine(pm)
         passes.ttir.add_reorder_broadcast(pm)
         passes.common.add_cse(pm)
-        passes.common.add_licm(pm)
+        passes.ttir.add_triton_licm(pm)
         passes.common.add_symbol_dce(pm)
         passes.ttir.add_loop_unroll(pm)
         pm.run(mod)
@@ -230,39 +229,62 @@ class HIPBackend(BaseBackend):
         passes.ttgpuir.add_optimize_dot_operands(pm, True)
         amd.passes.ttgpuir.add_hoist_layout_conversions(pm)
 
-        global_prefetch = int(os.getenv("TRITON_HIP_GLOBAL_PREFETCH", "0"))
-        local_prefetch = int(os.getenv("TRITON_HIP_LOCAL_PREFETCH", "0"))
+        passes.ttgpuir.add_fuse_nested_loops(pm)
+        passes.common.add_canonicalizer(pm)
+        passes.ttir.add_triton_licm(pm)
+        passes.common.add_canonicalizer(pm)
+
+        global_prefetch = knobs.amd.global_prefetch
+        local_prefetch = knobs.amd.local_prefetch
+        use_async_copy = knobs.amd.use_async_copy
 
         # The `local-prefetch` scheduling variant requires turning on buffer ops.
-        if options.instruction_sched_variant == "local-prefetch":
+        if options.schedule_hint == "local-prefetch":
            global_prefetch = local_prefetch = 1
 
-        if amd.has_matrix_core_feature(options.arch):
-            assert options.num_stages != 0, ("Triton AMD backend pipeliner has been updated. "
-                                             "We used to trigger software pipelining with "
-                                             "num_stages == 0. Now it will not happen anymore; "
-                                             "please update to use num_stages == 2 for "
-                                             "equivalent behavior in the past.")
-            amd.passes.ttgpuir.add_stream_pipeline(pm, options.num_stages, global_prefetch, local_prefetch)
-            passes.common.add_canonicalizer(pm)
-            if options.instruction_sched_variant.lower() != "none":
-                amd.passes.ttgpuir.insert_instruction_sched_hints(pm, options.instruction_sched_variant)
+        amd.passes.ttgpuir.add_stream_pipeline(pm, options.num_stages, global_prefetch, local_prefetch, use_async_copy)
+        if use_async_copy:
+            amd.passes.ttgpuir.add_coalesce_async_copy(pm, options.arch)
+        passes.common.add_canonicalizer(pm)
+        if options.schedule_hint.lower() != "none":
+            amd.passes.ttgpuir.insert_instruction_sched_hints(pm, options.schedule_hint)
         passes.ttgpuir.add_optimize_dot_operands(pm, True)
         passes.ttgpuir.add_remove_layout_conversions(pm)
         passes.ttgpuir.add_reduce_data_duplication(pm)
-        if amd.has_matrix_core_feature(options.arch):
-            amd.passes.ttgpuir.add_reorder_instructions(pm)
-            use_block_pingpong = is_pingpong_enabled(options.arch)
-            if use_block_pingpong and options.num_stages == 2:
-                amd.passes.ttgpuir.add_block_pingpong(pm)
-
-        if HIPBackend.use_buffer_ops():
+        if is_in_thread_transpose_enabled(options.arch):
+            amd.passes.ttgpuir.add_in_thread_transpose(pm)
+            passes.ttgpuir.add_remove_layout_conversions(pm)
+        amd.passes.ttgpuir.add_reorder_instructions(pm)
+        use_block_pingpong = is_pingpong_schedule_enabled(options.arch)
+        if use_block_pingpong and options.num_stages == 2:
+            amd.passes.ttgpuir.add_block_pingpong(pm, options.num_stages)
+
+        if knobs.amd.use_buffer_ops:
            amd.passes.ttgpuir.add_canonicalize_pointers(pm)
            passes.common.add_canonicalizer(pm)
            amd.passes.ttgpuir.add_convert_to_buffer_ops(pm, options.arch)
+
+        amd.passes.ttgpuir.add_fold_true_cmpi(pm)
         passes.common.add_canonicalizer(pm)
         passes.common.add_cse(pm)
         passes.common.add_symbol_dce(pm)
+        if use_async_copy:
+            amd.passes.ttgpuir.add_update_async_wait_count(pm, options.arch)
+        pm.run(mod)
+        return mod
+
+    @staticmethod
+    def ttgir_opt(src, metadata, options):
+        mod = src
+        pm = ir.pass_manager(mod.context)
+        pm.enable_debug()
+
+        passes.ttgpuir.add_inliner(pm)
+        passes.common.add_sccp(pm)
+        passes.ttir.add_loop_aware_cse(pm)
+        passes.ttgpuir.add_canonicalizer(pm)
+        passes.ttgpuir.add_combine_tensor_select_and_if(pm)
+
         pm.run(mod)
         return mod
 
@@ -272,7 +294,6 @@ class HIPBackend(BaseBackend):
         # TritonGPU -> LLVM-IR (MLIR)
         pm = ir.pass_manager(mod.context)
         pm.enable_debug()
-        amd.passes.ttgpuir.add_decompose_unsupported_conversions(pm, options.arch)
         # custom_lds_size is an experimental parameter that defines amount of LDS available
         # for one thread block. Measured in bytes.
         #
@@ -301,9 +322,9 @@ class HIPBackend(BaseBackend):
         passes.common.add_canonicalizer(pm)
         passes.common.add_cse(pm)
         passes.common.add_symbol_dce(pm)
-        if options.instruction_sched_variant.lower() != "none":
+        if options.schedule_hint.lower() != "none":
            amd.passes.ttgpuir.lower_instruction_sched_hints(pm, options.arch, options.num_stages)
-        if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0":
+        if not knobs.compilation.disable_line_info:
            passes.llvmir.add_di_scope(pm)
         amd.passes.ttgpuir.add_builtin_func_to_llvmir(pm, __HIP_FTZ)
         pm.run(mod)
@@ -314,7 +335,7 @@ class HIPBackend(BaseBackend):
         llvm_mod = llvm.to_module(mod, context)
         amd.attach_target_triple(llvm_mod)
         target_features = ''
-        if os.environ.get("TRITON_ENABLE_ASAN", "0") == "1":
+        if knobs.compilation.enable_asan:
            target_features = '+xnack'
         llvm.attach_datalayout(llvm_mod, amd.TARGET_TRIPLE, options.arch, target_features)
 
@@ -342,7 +363,7 @@ class HIPBackend(BaseBackend):
         fns[0].add_fn_attr("amdgpu-waves-per-eu", f"{options.waves_per_eu}")
         denormal_mode = "preserve-sign" if options.allow_flush_denorm else "ieee"
         fns[0].add_fn_attr("denormal-fp-math-f32", denormal_mode)
-        if os.environ.get("TRITON_ENABLE_ASAN", "0") == "1":
+        if knobs.compilation.enable_asan:
            fns[0].add_fn_target_feature("+xnack")
            fns[0].add_fn_asan_attr()
 
@@ -351,7 +372,7 @@ class HIPBackend(BaseBackend):
         # from memory.
         amd.set_all_fn_arg_inreg(fns[0])
 
-        if os.environ.get("TRITON_ENABLE_ASAN", "0") == "1":
+        if knobs.compilation.enable_asan:
            default_libdir = Path(__file__).parent / 'lib'
            paths = [
                str(default_libdir / 'asanrtl.bc'),
@@ -365,6 +386,9 @@ class HIPBackend(BaseBackend):
 
         llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3, options.arch, '', [], options.enable_fp_fusion)
 
+        if knobs.amd.scalarize_packed_fops:
+            amd.add_scalarize_packed_fops_llvm_pass(fns[0])
+
         # Get some metadata
         metadata["shared"] = src.get_int_attr("ttg.shared")
 
@@ -377,14 +401,21 @@ class HIPBackend(BaseBackend):
     @staticmethod
     def make_amdgcn(src, metadata, options):
         # Find kernel names (there should only be one)
-        # We get the name at the last possible step to accomodate `triton.compile`
+        # We get the name at the last possible step to accommodate `triton.compile`
         # on user-provided LLVM
         names = re.findall(r"define amdgpu_kernel void @([a-zA-Z_][a-zA-Z0-9_]*)", src)
         assert len(names) == 1
         metadata["name"] = names[0]
         # llvm -> hsaco
-        amdgcn = llvm.translate_to_asm(src, amd.TARGET_TRIPLE, options.arch, '', [], options.enable_fp_fusion, False)
-        if os.environ.get("AMDGCN_ENABLE_DUMP", "0") == "1":
+        flags = []
+        # The sink-insts-to-avoid-spills flag asks LLVM backend to sink instructions
+        # into loops to avoid register spills in the MachineSinking pass, while it
+        # can also lead to regression in some cases. But from current observation,
+        # the regression is not significant. It would be better to have some heuristics.
+        if options.schedule_hint == 'attention':
+            flags.append('sink-insts-to-avoid-spills')
+        amdgcn = llvm.translate_to_asm(src, amd.TARGET_TRIPLE, options.arch, '', flags, options.enable_fp_fusion, False)
+        if knobs.amd.dump_amdgcn:
            print("// -----// AMDGCN Dump //----- //")
            print(amdgcn)
         return amdgcn
@@ -392,7 +423,7 @@ class HIPBackend(BaseBackend):
     @staticmethod
     def make_hsaco(src, metadata, options):
         target_features = ''
-        if os.environ.get("TRITON_ENABLE_ASAN", "0") == "1":
+        if knobs.compilation.enable_asan:
            target_features = '+xnack'
         hsaco = amd.assemble_amdgcn(src, options.arch, target_features)
 
@@ -406,9 +437,12 @@ class HIPBackend(BaseBackend):
            ret = fd_out.read()
         return ret
 
-    def add_stages(self, stages, options):
-        stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options)
-        stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options)
+    def add_stages(self, stages, options, language):
+        if language == Language.TRITON:
+            stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options)
+            stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options)
+        elif language == Language.GLUON:
+            stages["ttgir"] = lambda src, metadata: self.ttgir_opt(src, metadata, options)
         stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options)
         stages["amdgcn"] = lambda src, metadata: self.make_amdgcn(src, metadata, options)
         stages["hsaco"] = lambda src, metadata: self.make_hsaco(src, metadata, options)
triton/backends/amd/driver.c
@@ -172,15 +172,18 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) {
   // get allocated registers and spilled registers from the function
   int n_regs = 0;
   int n_spills = 0;
+  int32_t n_max_threads = 0;
   hipSymbolTable.hipFuncGetAttribute(&n_regs, HIP_FUNC_ATTRIBUTE_NUM_REGS, fun);
   hipSymbolTable.hipFuncGetAttribute(&n_spills,
                                      HIP_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun);
+  hipSymbolTable.hipFuncGetAttribute(
+      &n_max_threads, HIP_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK, fun);
   n_spills /= 4;
   if (PyErr_Occurred()) {
     return NULL;
   }
-  return Py_BuildValue("(KKii)", (uint64_t)mod, (uint64_t)fun, n_regs,
-                       n_spills);
+  return Py_BuildValue("(KKiii)", (uint64_t)mod, (uint64_t)fun, n_regs,
+                       n_spills, n_max_threads);
 }
 
 static PyMethodDef ModuleMethods[] = {
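The loadBinary change above means the C extension now hands back a 5-tuple instead of a 4-tuple, so the Python side (presumably triton/backends/amd/driver.py, also touched in this release) has to unpack one extra field. Illustration with hypothetical values:

old_result = (0x7f0000000000, 0x7f0000001000, 40, 0)        # (mod, fun, n_regs, n_spills)
new_result = (0x7f0000000000, 0x7f0000001000, 40, 0, 1024)  # (..., n_max_threads)
mod, fun, n_regs, n_spills, n_max_threads = new_result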