triton-windows 3.2.0.post12__cp313-cp313-win_amd64.whl → 3.3.0a0.post12__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of triton-windows might be problematic.
Files changed (68)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +3 -3
  3. triton/_internal_testing.py +59 -4
  4. triton/_utils.py +35 -0
  5. triton/backends/amd/compiler.py +121 -74
  6. triton/backends/amd/driver.py +77 -43
  7. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +28 -49
  8. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +35 -9
  9. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +761 -284
  10. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +9 -3
  11. triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +1391 -0
  12. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +3 -3
  13. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +44 -0
  14. triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +288 -0
  15. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +110 -14
  16. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +504 -103
  17. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +2 -1
  18. triton/backends/amd/include/hip/amd_detail/host_defines.h +4 -0
  19. triton/backends/amd/include/hip/hip_ext.h +4 -2
  20. triton/backends/amd/include/hip/hip_fp8.h +33 -0
  21. triton/backends/amd/include/hip/hip_runtime_api.h +375 -33
  22. triton/backends/amd/include/hip/hip_version.h +3 -3
  23. triton/backends/amd/include/hip/hiprtc.h +25 -25
  24. triton/backends/amd/include/hsa/amd_hsa_elf.h +40 -14
  25. triton/backends/amd/include/hsa/hsa.h +11 -2
  26. triton/backends/amd/include/hsa/hsa_api_trace.h +30 -17
  27. triton/backends/amd/include/hsa/hsa_api_trace_version.h +68 -0
  28. triton/backends/amd/include/hsa/hsa_ext_amd.h +83 -27
  29. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +46 -46
  30. triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +416 -0
  31. triton/backends/amd/include/roctracer/hip_ostream_ops.h +84 -4
  32. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +260 -0
  33. triton/backends/amd/include/roctracer/hsa_prof_str.h +51 -19
  34. triton/backends/amd/lib/asanrtl.bc +0 -0
  35. triton/backends/compiler.py +25 -225
  36. triton/backends/driver.py +7 -2
  37. triton/backends/nvidia/bin/ptxas.exe +0 -0
  38. triton/backends/nvidia/compiler.py +135 -90
  39. triton/backends/nvidia/driver.c +0 -1
  40. triton/backends/nvidia/driver.py +135 -49
  41. triton/backends/nvidia/include/cuda.h +2162 -241
  42. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  43. triton/compiler/__init__.py +2 -2
  44. triton/compiler/code_generator.py +334 -231
  45. triton/compiler/compiler.py +77 -66
  46. triton/language/__init__.py +22 -5
  47. triton/language/core.py +448 -74
  48. triton/language/extra/cuda/_experimental_tma.py +3 -5
  49. triton/language/math.py +1 -1
  50. triton/language/random.py +2 -1
  51. triton/language/semantic.py +206 -52
  52. triton/language/standard.py +35 -18
  53. triton/runtime/_allocation.py +32 -0
  54. triton/runtime/autotuner.py +27 -32
  55. triton/runtime/build.py +1 -48
  56. triton/runtime/cache.py +6 -6
  57. triton/runtime/errors.py +10 -0
  58. triton/runtime/interpreter.py +179 -45
  59. triton/runtime/jit.py +149 -190
  60. triton/testing.py +39 -11
  61. triton/tools/compile.py +27 -20
  62. triton/tools/{compile.c → extra/cuda/compile.c} +1 -0
  63. triton/tools/mxfp.py +301 -0
  64. {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/METADATA +5 -2
  65. {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/RECORD +68 -59
  66. {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/top_level.txt +2 -0
  67. /triton/tools/{compile.h → extra/cuda/compile.h} +0 -0
  68. {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/WHEEL +0 -0
triton/_C/libtriton.pyd CHANGED
Binary file
triton/__init__.py CHANGED
@@ -1,5 +1,5 @@
  """isort:skip_file"""
- __version__ = '3.2.0'
+ __version__ = '3.3.0'

  # Users may not know how to add cl and CUDA to PATH. Let's do it before loading anything
  import os
@@ -32,6 +32,7 @@ from .runtime import (
  from .runtime.jit import jit
  from .compiler import compile, CompilationError
  from .errors import TritonError
+ from .runtime._allocation import set_allocator

  from . import language
  from . import testing
@@ -44,7 +45,6 @@ __all__ = [
  "compile",
  "Config",
  "heuristics",
- "impl",
  "InterpreterError",
  "jit",
  "JITFunction",
@@ -52,10 +52,10 @@ __all__ = [
  "language",
  "MockTensor",
  "next_power_of_2",
- "ops",
  "OutOfResources",
  "reinterpret",
  "runtime",
+ "set_allocator",
  "TensorWrapper",
  "TritonError",
  "testing",
triton/_internal_testing.py CHANGED
@@ -4,16 +4,18 @@ import numpy as np
  import torch
  import triton
  import triton.language as tl
+ from triton.backends.nvidia.compiler import _path_to_binary
  import pytest

  from numpy.random import RandomState
  from typing import Optional, Union
- from triton.runtime.jit import TensorWrapper, reinterpret
+ from triton.runtime.jit import TensorWrapper, reinterpret, type_canonicalisation_dict

  int_dtypes = ['int8', 'int16', 'int32', 'int64']
  uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64']
  integral_dtypes = int_dtypes + uint_dtypes
  float_dtypes = ['float16', 'float32', 'float64']
+ float_dtypes_with_bfloat16 = float_dtypes + ['bfloat16']
  dtypes = integral_dtypes + float_dtypes
  dtypes_with_bfloat16 = dtypes + ['bfloat16']
  torch_float8_dtypes = ['float8_e4m3fn', 'float8_e5m2']
@@ -35,11 +37,45 @@ def is_cuda():
  return False if target is None else target.backend == "cuda"


+ def is_hopper():
+ return is_cuda() and torch.cuda.get_device_capability()[0] >= 9
+
+
  def is_hip():
  target = get_current_target()
  return False if target is None else target.backend == "hip"


+ def is_hip_mi200():
+ target = get_current_target()
+ if target is None or target.backend != 'hip':
+ return False
+ return target.arch == 'gfx90a'
+
+
+ def is_hip_mi300():
+ target = get_current_target()
+ if target is None or target.backend != 'hip':
+ return False
+ return target.arch in ('gfx940', 'gfx941', 'gfx942')
+
+
+ def is_hip_mi350():
+ target = get_current_target()
+ if target is None or target.backend != 'hip':
+ return False
+ return target.arch in ('gfx950')
+
+
+ def is_hip_cdna():
+ return is_hip_mi200() or is_hip_mi300() or is_hip_mi350()
+
+
+ def is_xpu():
+ target = get_current_target()
+ return False if target is None else target.backend == "xpu"
+
+
  def get_arch():
  target = get_current_target()
  return "" if target is None else str(target.arch)
@@ -94,6 +130,10 @@ def to_triton(x: np.ndarray, device, dst_type=None) -> Union[TensorWrapper, torc
  return torch.tensor(x, device=device)


+ def str_to_triton_dtype(x: str) -> tl.dtype:
+ return tl.str_to_ty(type_canonicalisation_dict[x])
+
+
  def torch_dtype_name(dtype) -> str:
  if isinstance(dtype, triton.language.dtype):
  return dtype.name
@@ -116,8 +156,23 @@ def to_numpy(x):
  raise ValueError(f"Not a triton-compatible tensor: {x}")


- def supports_tma():
- return is_cuda() and torch.cuda.get_device_capability()[0] >= 9
+ def supports_tma(byval_only=False):
+ if is_interpreter():
+ return True
+ if not is_cuda():
+ return False
+ _, cuda_version = _path_to_binary("ptxas")
+ min_cuda_version = (12, 0) if byval_only else (12, 3)
+ cuda_version_tuple = tuple(map(int, cuda_version.split(".")))
+ assert len(cuda_version_tuple) == 2, cuda_version_tuple
+ return torch.cuda.get_device_capability()[0] >= 9 and cuda_version_tuple >= min_cuda_version
+
+
+ def tma_skip_msg(byval_only=False):
+ if byval_only:
+ return "Requires __grid_constant__ TMA support (NVIDIA Hopper or higher, CUDA 12.0 or higher)"
+ else:
+ return "Requires advanced TMA support (NVIDIA Hopper or higher, CUDA 12.3 or higher)"


- requires_tma = pytest.mark.skipif(not supports_tma(), reason="Requires TMA support (NVIDIA Hopper or higher)")
+ requires_tma = pytest.mark.skipif(not supports_tma(), reason=tma_skip_msg())
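
For context, the reworked helpers above gate TMA tests on both device capability and the detected ptxas version. A sketch of how a test module might use the new byval_only flag (the test itself is hypothetical):

    import pytest
    from triton._internal_testing import supports_tma, tma_skip_msg

    # __grid_constant__ TMA needs CUDA >= 12.0; the stricter path checked by
    # supports_tma() without the flag needs CUDA >= 12.3.
    requires_tma_byval = pytest.mark.skipif(not supports_tma(byval_only=True),
                                            reason=tma_skip_msg(byval_only=True))

    @requires_tma_byval
    def test_grid_constant_tma():  # hypothetical test
        ...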
triton/_utils.py ADDED
@@ -0,0 +1,35 @@
+ from functools import reduce
+
+
+ def get_iterable_path(iterable, path):
+ return reduce(lambda a, idx: a[idx], path, iterable)
+
+
+ def set_iterable_path(iterable, path, val):
+ prev = iterable if len(path) == 1 else get_iterable_path(iterable, path[:-1])
+ prev[path[-1]] = val
+
+
+ def find_paths_if(iterable, pred):
+ from .language import core
+ is_iterable = lambda x: isinstance(x, (list, tuple, core.tuple, core.tuple_type))
+ ret = dict()
+
+ def _impl(current, path):
+ path = (path[0], ) if len(path) == 1 else tuple(path)
+ if is_iterable(current):
+ for idx, item in enumerate(current):
+ _impl(item, path + (idx, ))
+ elif pred(path, current):
+ if len(path) == 1:
+ ret[(path[0], )] = None
+ else:
+ ret[tuple(path)] = None
+
+ if is_iterable(iterable):
+ _impl(iterable, [])
+ elif pred(list(), iterable):
+ ret = {tuple(): None}
+ else:
+ ret = dict()
+ return list(ret.keys())
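
A small usage sketch of the new helpers (illustrative values): find_paths_if returns the index paths of leaves matching a predicate, and get_iterable_path/set_iterable_path read or write the leaf at such a path.

    from triton._utils import find_paths_if, get_iterable_path, set_iterable_path

    nested = [1, ["a", 2], (3.0, [4])]
    # Index paths of every int leaf, in traversal order.
    paths = find_paths_if(nested, lambda path, x: isinstance(x, int))
    assert paths == [(0,), (1, 1), (2, 1, 0)]
    assert get_iterable_path(nested, (1, 1)) == 2
    set_iterable_path(nested, (1, 1), 20)  # nested is now [1, ['a', 20], (3.0, [4])]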
triton/backends/amd/compiler.py CHANGED
@@ -1,4 +1,4 @@
- from triton.backends.compiler import BaseBackend, GPUTarget, AttrsDescriptor, register_descriptor
+ from triton.backends.compiler import BaseBackend, GPUTarget
  from triton._C.libtriton import ir, passes, llvm, amd
  from dataclasses import dataclass
  from typing import Any, Dict, Tuple
@@ -13,16 +13,13 @@ from pathlib import Path


  def min_dot_size(target: GPUTarget):
- arch_str = target.arch
- # CDNA 3.0 supports k==8 in all mfma variants except for int8
- # (where the smallest `k` supported is 16)
- if "gfx94" in arch_str:
- return lambda lhsType, rhsType: (16, 16, 16) if (lhsType.is_int8() or rhsType.is_int8()) else (16, 16, 8)
- # CDNA 2.0 always supports `k==8`
- if "gfx9" in arch_str:
- return lambda lhsType, rhsType: (16, 16, 8)
- # Other architectures will only support 16,16,16
- return lambda lhsType, rhsType: (16, 16, 16)
+ # If some given configuration is not supported in hardware we fallback to FMA and cast arguments
+ return lambda lhsType, rhsType: (1, 1, 1)
+
+
+ def is_pingpong_enabled(arch):
+ default = "1" if arch == "gfx942" else "0"
+ return os.getenv("TRITON_HIP_USE_BLOCK_PINGPONG", default) == "1"


  @dataclass(frozen=True)
@@ -31,10 +28,6 @@ class HIPOptions:
  waves_per_eu: int = 1
  num_stages: int = 2
  num_ctas: int = 1
- num_buffers_warp_spec: int = 0
- num_consumer_groups: int = 0
- reg_dec_producer: int = 0
- reg_inc_consumer: int = 0
  extern_libs: dict = None
  cluster_dims: tuple = (1, 1, 1)
  debug: bool = False
@@ -45,6 +38,7 @@ class HIPOptions:
  default_dot_input_precision: str = "ieee"
  allowed_dot_input_precisions: Tuple[str] = ("ieee", )
  enable_fp_fusion: bool = True
+ launch_cooperative_grid: bool = False
  matrix_instr_nonkdim: int = 0
  kpack: int = 1
  allow_flush_denorm: bool = False
@@ -52,11 +46,23 @@ class HIPOptions:
  backend_name: str = 'hip'

  # The following option provides hints to the AMDGPU backend regarding instruction scheduling
- # for all `tt.dot` operations in a kernel. The "default" variant preserves the default
+ # for all `tt.dot` operations in a kernel. The "none" variant preserves the default
  # instruction scheduling of the AMDGPU backend which aims at maximizing occupancy.
  # The option is experimental and may change at any time regarding its semantics and/or may
  # be gone entirely anytime.
- instruction_sched_variant: str = 'default'
+ #
+ # Current experimental scheduling variants:
+ #
+ # llvm-iglp-0: injects `llvm.amdgcn.iglp_opt` intrinsic call with value `0` to the GEMM's
+ # k-loop; i.e., "interleave DS and MFMA instructions for small GEMM kernels".
+ # llvm-iglp-1: injects `llvm.amdgcn.iglp_opt` intrinsic call with value `1` to the GEMM's
+ # k-loop; i.e., "interleave DS and MFMA instructions for single wave small
+ # GEMM kernels.".
+ # local-prefetch: implements instruction scheduling similar to the one from the ROCm Composable
+ # Kernel library. Note, this variant requires the use of buffer load/store ops
+ # and a special software pipelining style - i.e., 1x LDS and 1x register
+ # prefetch buffers for each GEMM tile.
+ instruction_sched_variant: str = 'none'

  def __post_init__(self):
  default_libdir = Path(__file__).parent / 'lib'
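
An aside on the option documented in the hunk above: backend options such as `instruction_sched_variant` can typically be passed as kernel launch keyword arguments alongside `num_warps` and `num_stages`. The sketch below is hypothetical (the kernel, grid, and pointer arguments are placeholders); only the option name and its values come from this diff.

    # Hypothetical launch opting into an experimental AMD scheduling variant;
    # the default for instruction_sched_variant is now 'none'.
    matmul_kernel[grid](
        a_ptr, b_ptr, c_ptr, M, N, K,
        num_warps=8, num_stages=2,
        instruction_sched_variant="local-prefetch",
    )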
@@ -64,6 +70,9 @@ class HIPOptions:
  # Ignore user-defined warp size for gfx9
  warp_size = 32 if 'gfx10' in self.arch or 'gfx11' in self.arch or 'gfx12' in self.arch else 64
  object.__setattr__(self, 'warp_size', warp_size)
+ # Only kpack=1 is supported on gfx950
+ kpack = 1 if self.arch == 'gfx950' else self.kpack
+ object.__setattr__(self, 'kpack', kpack)
  libs = ["ocml", "ockl"]
  for lib in libs:
  extern_libs[lib] = str(default_libdir / f'{lib}.bc')
@@ -76,44 +85,6 @@ class HIPOptions:
  return hashlib.sha256(key.encode("utf-8")).hexdigest()


- @register_descriptor
- class HIPAttrsDescriptor(AttrsDescriptor):
- # This property asserts if the underlying storage area of a given pointer
- # can be resepresented as a 32 bit integer. When this is true, we can be
- # sure that all indices into the tensor behind that pointer can use 32-bit
- # indexing. That opens the door for the AMD backend to use buffer load/store
- # instrinsics, which requires this property. Buffer load/store intrinsics
- # gives direct out-of-bound support and simplifies index calculation for
- # lower register pressure.
- __slots__ = ("pointer_range_32")
-
- def _add_backend_properties(self, params=None, values=None):
- self.property_values["tt.pointer_range"] = 32
- if params is None or values is None:
- return
-
- self.arg_properties["tt.pointer_range"] = [
- param.num for param, arg in zip(params, values) if HIPAttrsDescriptor.is_within2gb(arg)
- and not param.do_not_specialize and not param.do_not_specialize_on_alignment
- ]
-
- @staticmethod
- def is_within2gb(arg):
- if hasattr(arg, "ptr_range"):
- return arg.ptr_range() <= 2**31 - 1
- if "torch.Tensor" in str(type(arg)) and hasattr(arg, "untyped_storage"):
- # Please note that 2**31-1 is the max int32 positive limit
- return arg.untyped_storage().size() <= 2**31 - 1
- return False
-
- @staticmethod
- def get_property_key(val, align):
- generic_key = AttrsDescriptor.get_property_key(val, align)
- hip_key = "S" if HIPAttrsDescriptor.is_within2gb(val) else "N"
- key = (generic_key + hip_key).replace("N", "")
- return key if key else "N"
-
-
  class HIPBackend(BaseBackend):

  @staticmethod
@@ -126,17 +97,25 @@ class HIPBackend(BaseBackend):
  self.binary_ext = "hsaco"

  def parse_options(self, opts) -> Any:
- args = {'arch': self.target.arch}
+ args = {'arch': os.getenv("TRITON_OVERRIDE_ARCH", self.target.arch)}
+
+ # Enable XF32 (TF32) for CDNA3 GPUs
+ if self.target.arch in ('gfx940', 'gfx941', 'gfx942'):
+ allowed_dot_input_precisions = set(HIPOptions.allowed_dot_input_precisions)
+ allowed_dot_input_precisions.update({'tf32'})
+ args["allowed_dot_input_precisions"] = tuple(sorted(allowed_dot_input_precisions))

  if "supported_fp8_dtypes" not in opts:
  supported_fp8_dtypes = set(HIPOptions.supported_fp8_dtypes)
  if self.target.arch in ('gfx940', 'gfx941', 'gfx942'):
- supported_fp8_dtypes.update({'fp8e4b8', 'fp8e5b16'})
+ supported_fp8_dtypes.update({'fp8e4nv', 'fp8e4b8', 'fp8e5b16'})
+ elif self.target.arch in ('gfx950'):
+ supported_fp8_dtypes.update({'fp8e4nv', 'fp8e5'})
  args["supported_fp8_dtypes"] = tuple(sorted(supported_fp8_dtypes))

  if "enable_fp_fusion" not in opts:
  args["enable_fp_fusion"] = os.getenv("TRITON_DEFAULT_FP_FUSION", "1") == "1"
- args.update({k: opts[k] for k in HIPOptions.__dataclass_fields__.keys() if k in opts})
+ args.update({k: opts[k] for k in HIPOptions.__dataclass_fields__.keys() if k in opts and opts[k] is not None})
  return HIPOptions(**args)

  def pack_metadata(self, metadata):
@@ -149,23 +128,49 @@ class HIPBackend(BaseBackend):
  metadata.cluster_dims[2],
  )

- def get_codegen_implementation(self):
+ def get_codegen_implementation(self, options):
  codegen_fns = {"min_dot_size": min_dot_size(self.target)}
  return codegen_fns

  def get_module_map(self) -> Dict[str, ModuleType]:
  from triton.language.extra.hip import libdevice
+
  return {"triton.language.extra.libdevice": libdevice}

  def load_dialects(self, ctx):
  amd.load_dialects(ctx)

- def get_attrs_descriptor(self, params, args):
- return HIPAttrsDescriptor(params, args)
+ @staticmethod
+ @functools.lru_cache()
+ def use_buffer_ops():
+ return os.environ.get("AMDGCN_USE_BUFFER_OPS", "0") == "1"
+
+ @staticmethod
+ def is_within_2gb(arg):
+ import torch
+
+ MAX_INT_32 = 2**31 - 1
+ if hasattr(arg, "ptr_range"):
+ return arg.ptr_range() <= MAX_INT_32
+ if isinstance(arg, torch.Tensor) and hasattr(arg, "untyped_storage"):
+ return arg.untyped_storage().size() <= MAX_INT_32
+ return False
+
+ @staticmethod
+ def parse_attr(desc):
+ ret = BaseBackend.parse_attr(desc)
+ if "S" in desc:
+ ret += [["tt.pointer_range", 32]]
+ return ret

  @staticmethod
- def compute_spec_key(arg, align):
- return HIPAttrsDescriptor.get_property_key(arg, align)
+ def get_arg_specialization(arg, ty, **kwargs):
+ ret = BaseBackend.get_arg_specialization(arg, ty, **kwargs)
+ # Only attempt to do buffer ops specialization if buffer ops are enabled.
+ # Otherwise the is_within_2gb check is unnecessary overhead.
+ if HIPBackend.use_buffer_ops() and ty == "tensor" and HIPBackend.is_within_2gb(arg):
+ ret += "S"
+ return ret

  @staticmethod
  def path_to_rocm_lld():
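
A brief sketch of the gating above, using only names that appear in this diff (the tensor and device are illustrative): buffer-ops specialization is opt-in via the AMDGCN_USE_BUFFER_OPS environment variable, and is_within_2gb decides whether a tensor argument's specialization key receives the extra "S".

    import os
    import torch
    from triton.backends.amd.compiler import HIPBackend

    # Must be set before the first compilation; use_buffer_ops() is lru_cached.
    os.environ["AMDGCN_USE_BUFFER_OPS"] = "1"

    x = torch.empty(1024, device="cuda")  # storage well under 2**31 - 1 bytes
    assert HIPBackend.is_within_2gb(x)    # so get_arg_specialization() appends "S"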
@@ -193,8 +198,8 @@ class HIPBackend(BaseBackend):
  pm.enable_debug()
  passes.common.add_inliner(pm)
  passes.ttir.add_rewrite_tensor_pointer(pm)
- passes.ttir.add_combine(pm)
  passes.common.add_canonicalizer(pm)
+ passes.ttir.add_combine(pm)
  passes.ttir.add_reorder_broadcast(pm)
  passes.common.add_cse(pm)
  passes.common.add_licm(pm)
@@ -219,24 +224,38 @@ class HIPBackend(BaseBackend):
  passes.ttgpuir.add_remove_layout_conversions(pm)
  amd.passes.ttgpuir.add_optimize_epilogue(pm)
  passes.ttgpuir.add_optimize_dot_operands(pm, True)
+ amd.passes.ttgpuir.add_hoist_layout_conversions(pm)
+
+ global_prefetch = int(os.getenv("TRITON_HIP_GLOBAL_PREFETCH", "0"))
+ local_prefetch = int(os.getenv("TRITON_HIP_LOCAL_PREFETCH", "0"))
+
+ # The `local-prefetch` scheduling variant requires turning on buffer ops.
+ if options.instruction_sched_variant == "local-prefetch":
+ global_prefetch = local_prefetch = 1
+
  if amd.has_matrix_core_feature(options.arch):
  assert options.num_stages != 0, ("Triton AMD backend pipeliner has been updated. "
  "We used to trigger software pipelining with "
  "num_stages == 0. Now it will not happen anymore; "
  "please update to use num_stages == 2 for "
  "equivalent behavior in the past.")
- amd.passes.ttgpuir.add_stream_pipelinev2(pm, options.num_stages)
+ amd.passes.ttgpuir.add_stream_pipeline(pm, options.num_stages, global_prefetch, local_prefetch)
  passes.common.add_canonicalizer(pm)
- amd.passes.ttgpuir.insert_instruction_sched_hints(pm)
+ if options.instruction_sched_variant.lower() != "none":
+ amd.passes.ttgpuir.insert_instruction_sched_hints(pm, options.instruction_sched_variant)
  passes.ttgpuir.add_optimize_dot_operands(pm, True)
  passes.ttgpuir.add_remove_layout_conversions(pm)
  passes.ttgpuir.add_reduce_data_duplication(pm)
  if amd.has_matrix_core_feature(options.arch):
  amd.passes.ttgpuir.add_reorder_instructions(pm)
- if os.environ.get("AMDGCN_USE_BUFFER_OPS", "0") == "1":
+ use_block_pingpong = is_pingpong_enabled(options.arch)
+ if use_block_pingpong and options.num_stages == 2:
+ amd.passes.ttgpuir.add_block_pingpong(pm)
+
+ if HIPBackend.use_buffer_ops():
  amd.passes.ttgpuir.add_canonicalize_pointers(pm)
  passes.common.add_canonicalizer(pm)
- amd.passes.ttgpuir.add_convert_to_buffer_ops(pm)
+ amd.passes.ttgpuir.add_convert_to_buffer_ops(pm, options.arch)
  passes.common.add_canonicalizer(pm)
  passes.common.add_cse(pm)
  passes.common.add_symbol_dce(pm)
@@ -278,7 +297,8 @@ class HIPBackend(BaseBackend):
  passes.common.add_canonicalizer(pm)
  passes.common.add_cse(pm)
  passes.common.add_symbol_dce(pm)
- amd.passes.ttgpuir.lower_instruction_sched_hints(pm, options.instruction_sched_variant)
+ if options.instruction_sched_variant.lower() != "none":
+ amd.passes.ttgpuir.lower_instruction_sched_hints(pm, options.arch, options.num_stages)
  if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0":
  passes.llvmir.add_di_scope(pm)
  amd.passes.ttgpuir.add_builtin_func_to_llvmir(pm, __HIP_FTZ)
@@ -289,12 +309,15 @@ class HIPBackend(BaseBackend):
  context = llvm.context()
  llvm_mod = llvm.to_module(mod, context)
  amd.attach_target_triple(llvm_mod)
- llvm.attach_datalayout(llvm_mod, amd.TARGET_TRIPLE, options.arch, '')
+ target_features = ''
+ if os.environ.get("TRITON_ENABLE_ASAN", "0") == "1":
+ target_features = '+xnack'
+ llvm.attach_datalayout(llvm_mod, amd.TARGET_TRIPLE, options.arch, target_features)

  # Set various control constants on the LLVM module so that device
  # libraries can resolve references to them.
  amd.set_isa_version(llvm_mod, options.arch)
- amd.set_abi_version(llvm_mod, 400)
+ amd.set_abi_version(llvm_mod, 500)
  amd.set_bool_control_constant(llvm_mod, "__oclc_finite_only_opt", False)
  amd.set_bool_control_constant(llvm_mod, "__oclc_correctly_rounded_sqrt32", True)
  amd.set_bool_control_constant(llvm_mod, "__oclc_unsafe_math_opt", False)
@@ -305,25 +328,46 @@ class HIPBackend(BaseBackend):
  # The public kernel should be kernel 0.
  fns[0].set_calling_conv(amd.CALLING_CONV_AMDGPU_KERNEL)
  fns[0].add_fn_attr("amdgpu-flat-work-group-size", f"1,{options.num_warps*options.warp_size}")
+ # LLVM AMDGPU backend supports the attribute "amdgpu-waves-per-eu"="<min>[, <max>]".
+ # This attribute may be attached to a kernel function definition and is an optimization hint.
+ # <min> parameter specifies the requested minimum number of waves per EU, and optional <max> parameter
+ # specifies the requested maximum number of waves per EU (must be greater than <min> if specified).
+ # If <max> is omitted, then there is no restriction on the maximum number of waves per EU other than
+ # the one dictated by the hardware for which the kernel is compiled. Passing 0, 0 as <min>, <max>
+ # implies the default behavior (no limits).
  fns[0].add_fn_attr("amdgpu-waves-per-eu", f"{options.waves_per_eu}")
  denormal_mode = "preserve-sign" if options.allow_flush_denorm else "ieee"
  fns[0].add_fn_attr("denormal-fp-math-f32", denormal_mode)
+ if os.environ.get("TRITON_ENABLE_ASAN", "0") == "1":
+ fns[0].add_fn_target_feature("+xnack")
+ fns[0].add_fn_asan_attr()

  # Hint the compiler that we'd like the firmware to set the kernel arguments
  # to user SGPRs so that the kernel does not need to s_load its arguments
  # from memory.
  amd.set_all_fn_arg_inreg(fns[0])

- if options.extern_libs:
+ if os.environ.get("TRITON_ENABLE_ASAN", "0") == "1":
+ default_libdir = Path(__file__).parent / 'lib'
+ paths = [
+ str(default_libdir / 'asanrtl.bc'),
+ str(default_libdir / "ocml.bc"),
+ str(default_libdir / "ockl.bc")
+ ]
+ llvm.link_extern_libs(llvm_mod, paths)
+ elif options.extern_libs:
  paths = [path for (name, path) in options.extern_libs if amd.need_extern_lib(llvm_mod, name)]
  llvm.link_extern_libs(llvm_mod, paths)

  llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3, options.arch, '', [], options.enable_fp_fusion)

  # Get some metadata
- metadata["shared"] = src.get_int_attr("triton_gpu.shared")
+ metadata["shared"] = src.get_int_attr("ttg.shared")

  amd.cleanup_bitcode_metadata(llvm_mod)
+ # Disable inlining of print related functions,
+ # because inlining of these function could slow down compilation significantly
+ amd.disable_print_inline(llvm_mod)
  return str(llvm_mod)

  @staticmethod
@@ -343,7 +387,10 @@ class HIPBackend(BaseBackend):

  @staticmethod
  def make_hsaco(src, metadata, options):
- hsaco = amd.assemble_amdgcn(src, options.arch, '')
+ target_features = ''
+ if os.environ.get("TRITON_ENABLE_ASAN", "0") == "1":
+ target_features = '+xnack'
+ hsaco = amd.assemble_amdgcn(src, options.arch, target_features)

  rocm_path = HIPBackend.path_to_rocm_lld()
  with tempfile.NamedTemporaryFile() as tmp_out: