triton-windows 3.4.0.post20__cp310-cp310-win_amd64.whl → 3.5.0.post21__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


Files changed (107)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +8 -2
  3. triton/_filecheck.py +24 -14
  4. triton/_internal_testing.py +70 -4
  5. triton/_utils.py +3 -1
  6. triton/backends/amd/compiler.py +68 -60
  7. triton/backends/amd/driver.c +113 -44
  8. triton/backends/amd/driver.py +133 -57
  9. triton/backends/driver.py +13 -0
  10. triton/backends/nvidia/compiler.py +80 -22
  11. triton/backends/nvidia/driver.c +88 -15
  12. triton/backends/nvidia/driver.py +130 -123
  13. triton/compiler/__init__.py +5 -2
  14. triton/compiler/code_generator.py +270 -163
  15. triton/compiler/compiler.py +45 -62
  16. triton/experimental/gluon/__init__.py +3 -2
  17. triton/experimental/gluon/_runtime.py +9 -6
  18. triton/experimental/gluon/language/__init__.py +117 -16
  19. triton/experimental/gluon/language/_core.py +246 -68
  20. triton/experimental/gluon/language/_layouts.py +398 -45
  21. triton/experimental/gluon/language/_math.py +17 -9
  22. triton/experimental/gluon/language/_semantic.py +130 -37
  23. triton/experimental/gluon/language/_standard.py +55 -22
  24. triton/experimental/gluon/language/amd/__init__.py +4 -0
  25. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  26. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  27. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  28. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  29. triton/experimental/gluon/language/extra/__init__.py +3 -0
  30. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  31. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  32. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  33. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +192 -7
  34. triton/experimental/gluon/language/nvidia/blackwell/tma.py +20 -0
  35. triton/experimental/gluon/language/nvidia/hopper/__init__.py +124 -3
  36. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +20 -37
  37. triton/experimental/gluon/language/nvidia/hopper/tma.py +4 -3
  38. triton/experimental/gluon/nvidia/hopper.py +6 -1
  39. triton/knobs.py +132 -67
  40. triton/language/__init__.py +16 -10
  41. triton/language/core.py +163 -83
  42. triton/language/extra/cuda/gdc.py +6 -6
  43. triton/language/extra/hip/__init__.py +3 -1
  44. triton/language/extra/hip/libdevice.py +7 -0
  45. triton/language/extra/hip/utils.py +35 -0
  46. triton/language/extra/libdevice.py +4 -0
  47. triton/language/semantic.py +76 -23
  48. triton/language/standard.py +14 -14
  49. triton/language/target_info.py +54 -0
  50. triton/runtime/_allocation.py +15 -3
  51. triton/runtime/_async_compile.py +55 -0
  52. triton/runtime/autotuner.py +4 -5
  53. triton/runtime/build.py +11 -9
  54. triton/runtime/cache.py +44 -1
  55. triton/runtime/driver.py +16 -41
  56. triton/runtime/interpreter.py +31 -23
  57. triton/runtime/jit.py +318 -157
  58. triton/runtime/tcc/include/_mingw.h +8 -10
  59. triton/runtime/tcc/include/assert.h +5 -0
  60. triton/runtime/tcc/include/errno.h +1 -1
  61. triton/runtime/tcc/include/float.h +21 -3
  62. triton/runtime/tcc/include/iso646.h +36 -0
  63. triton/runtime/tcc/include/limits.h +5 -0
  64. triton/runtime/tcc/include/malloc.h +2 -2
  65. triton/runtime/tcc/include/math.h +21 -261
  66. triton/runtime/tcc/include/stdalign.h +16 -0
  67. triton/runtime/tcc/include/stdarg.h +5 -70
  68. triton/runtime/tcc/include/stdatomic.h +171 -0
  69. triton/runtime/tcc/include/stddef.h +7 -19
  70. triton/runtime/tcc/include/stdlib.h +15 -4
  71. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  72. triton/runtime/tcc/include/sys/stat.h +2 -2
  73. triton/runtime/tcc/include/sys/types.h +5 -0
  74. triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
  75. triton/runtime/tcc/include/tccdefs.h +342 -0
  76. triton/runtime/tcc/include/tgmath.h +89 -0
  77. triton/runtime/tcc/include/uchar.h +33 -0
  78. triton/runtime/tcc/include/unistd.h +1 -0
  79. triton/runtime/tcc/include/winapi/qos.h +72 -0
  80. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  81. triton/runtime/tcc/include/winapi/winbase.h +9 -2
  82. triton/runtime/tcc/include/winapi/wincon.h +8 -0
  83. triton/runtime/tcc/include/winapi/windows.h +1 -1
  84. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  85. triton/runtime/tcc/include/winapi/winnt.h +9 -7
  86. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  87. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  88. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  89. triton/runtime/tcc/lib/libtcc1.a +0 -0
  90. triton/runtime/tcc/lib/python314.def +1800 -0
  91. triton/runtime/tcc/lib/python314t.def +1809 -0
  92. triton/runtime/tcc/libtcc.dll +0 -0
  93. triton/runtime/tcc/tcc.exe +0 -0
  94. triton/tools/compile.py +62 -14
  95. triton/tools/extra/cuda/compile.c +1 -0
  96. triton/tools/extra/hip/compile.cpp +66 -0
  97. triton/tools/extra/hip/compile.h +13 -0
  98. triton/tools/ragged_tma.py +92 -0
  99. triton/tools/tensor_descriptor.py +7 -9
  100. triton/windows_utils.py +42 -79
  101. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +3 -4
  102. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/RECORD +106 -75
  103. triton/runtime/tcc/lib/libtcc1-64.a +0 -0
  104. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
  105. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/entry_points.txt +0 -0
  106. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/licenses/LICENSE +0 -0
  107. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/top_level.txt +0 -0
triton/_C/libtriton.pyd CHANGED
Binary file
triton/__init__.py CHANGED
@@ -1,5 +1,5 @@
  """isort:skip_file"""
- __version__ = '3.4.0'
+ __version__ = '3.5.0'
 
  # ---------------------------------------
  # Note: import order is significant here.
@@ -17,7 +17,8 @@ from .runtime import (
  InterpreterError,
  MockTensor,
  )
- from .runtime.jit import jit
+ from .runtime.jit import constexpr_function, jit
+ from .runtime._async_compile import AsyncCompileMode, FutureKernel
  from .compiler import compile, CompilationError
  from .errors import TritonError
  from .runtime._allocation import set_allocator
@@ -29,11 +30,14 @@ from . import tools
  must_use_result = language.core.must_use_result
 
  __all__ = [
+ "AsyncCompileMode",
  "autotune",
  "cdiv",
  "CompilationError",
  "compile",
  "Config",
+ "constexpr_function",
+ "FutureKernel",
  "heuristics",
  "InterpreterError",
  "jit",
@@ -59,10 +63,12 @@ __all__ = [
  # -------------------------------------
 
 
+ @constexpr_function
  def cdiv(x: int, y: int):
  return (x + y - 1) // y
 
 
+ @constexpr_function
  def next_power_of_2(n: int):
  """Return the smallest power of 2 greater than or equal to n"""
  n -= 1
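The hunk above also marks `cdiv` and `next_power_of_2` as `constexpr_function`s. A minimal usage sketch (the kernel and sizes are illustrative, not taken from this package) of how these helpers are typically used for block-size and launch-grid math:

```python
import triton
import triton.language as tl

@triton.jit
def copy_kernel(src_ptr, dst_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    tl.store(dst_ptr + offsets, tl.load(src_ptr + offsets, mask=mask), mask=mask)

n = 10_000
BLOCK = triton.next_power_of_2(1000)   # smallest power of 2 >= 1000 -> 1024
grid = (triton.cdiv(n, BLOCK),)        # ceil(10_000 / 1024) -> 10 programs
# copy_kernel[grid](src, dst, n, BLOCK=BLOCK)   # launch, given CUDA tensors src and dst
```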
triton/_filecheck.py CHANGED
@@ -1,3 +1,4 @@
+ import functools
  import os
  import inspect
  import subprocess
@@ -7,6 +8,7 @@ import triton
  from triton.compiler import ASTSource, make_backend
  from triton.backends.compiler import GPUTarget
  from triton.experimental.gluon._runtime import GluonASTSource
+ from triton.runtime.jit import create_function_from_signature
  from triton._C.libtriton import ir
 
  # ===-----------------------------------------------------------------------===#
@@ -15,7 +17,6 @@ from triton._C.libtriton import ir
 
  # Stub target for testing the frontend.
  stub_target = GPUTarget("cuda", 100, 32)
- stub_backend = make_backend(stub_target)
 
  triton_dir = os.path.dirname(__file__)
  filecheck_path = os.path.join(triton_dir, "FileCheck")
@@ -42,29 +43,37 @@ def run_filecheck(name, module_str, check_template):
  temp.write(check_template)
 
  try:
- subprocess.check_output([filecheck_path, temp_expected, "--input-file", temp_module],
- stderr=subprocess.STDOUT)
+ subprocess.check_output(
+ [filecheck_path, temp_expected, "--input-file", temp_module, "--dump-input-context=50"],
+ stderr=subprocess.STDOUT)
  except subprocess.CalledProcessError as error:
  decoded = error.output.decode('unicode_escape')
  raise ValueError(decoded)
 
 
- def run_parser(kernel_fn):
- sigkeys = [x.name for x in kernel_fn.params]
- sigvals = [f"arg{i}" for i in range(len(sigkeys))]
- signature = {k: v for (k, v) in zip(sigkeys, sigvals)}
+ def run_parser(kernel_fn, args=(), kwargs={}, target=stub_target):
+ if "sanitize_overflow" not in kwargs:
+ kwargs = dict(kwargs)
+ kwargs["sanitize_overflow"] = False
+ backend = make_backend(target)
+ binder = create_function_from_signature(
+ kernel_fn.signature,
+ kernel_fn.params,
+ backend,
+ )
+
+ bound_args, specialization, options = binder(*args, **kwargs)
+ options, signature, constexprs, attrs = kernel_fn._pack_args(backend, kwargs, bound_args, specialization, options)
  source_cls = GluonASTSource if kernel_fn.is_gluon() else ASTSource
- src = source_cls(fn=kernel_fn, signature=signature)
+ src = source_cls(kernel_fn, signature, constexprs, attrs)
 
  context = ir.context()
  ir.load_dialects(context)
- stub_backend.load_dialects(context)
+ backend.load_dialects(context)
 
- extra_options = src.parse_options()
- options = stub_backend.parse_options(dict(**extra_options))
- codegen_fns = stub_backend.get_codegen_implementation(options)
- module_map = stub_backend.get_module_map()
- module = src.make_ir(options, codegen_fns, module_map, context)
+ codegen_fns = backend.get_codegen_implementation(options)
+ module_map = backend.get_module_map()
+ module = src.make_ir(target, options, codegen_fns, module_map, context)
  assert module.verify()
  return module
 
@@ -81,6 +90,7 @@ def run_filecheck_test(kernel_fn):
 
  def filecheck_test(fn):
 
+ @functools.wraps(fn)
  def test_fn():
  run_filecheck_test(fn)
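`run_parser` now binds real arguments through the JIT binder instead of a synthetic signature, and `filecheck_test` preserves the wrapped function's identity via `functools.wraps`. A hypothetical usage sketch; the assumption that the CHECK directives live as comments inside the decorated kernel is mine and is not stated in this diff:

```python
import triton
import triton.language as tl
from triton._filecheck import filecheck_test

@filecheck_test
@triton.jit
def test_add():
    # CHECK-LABEL: tt.func public @test_add
    # CHECK: arith.addi
    x = tl.full([16], 1, tl.int32)
    y = x + x  # noqa: F841
```

With `functools.wraps` applied, the generated `test_fn` now reports the kernel's own name to test collectors such as pytest.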
triton/_internal_testing.py CHANGED
@@ -5,10 +5,10 @@ import torch
  import triton
  import triton.language as tl
  from triton import knobs
+ from typing import Optional, Set, Union
  import pytest
 
  from numpy.random import RandomState
- from typing import Optional, Union
  from triton.runtime.jit import TensorWrapper, reinterpret, type_canonicalisation_dict
 
  int_dtypes = ['int8', 'int16', 'int32', 'int64']
@@ -38,10 +38,22 @@ def is_cuda():
  return False if target is None else target.backend == "cuda"
 
 
- def is_hopper():
+ def is_ampere_or_newer():
+ return is_cuda() and torch.cuda.get_device_capability()[0] >= 8
+
+
+ def is_blackwell():
+ return is_cuda() and torch.cuda.get_device_capability()[0] == 10
+
+
+ def is_hopper_or_newer():
  return is_cuda() and torch.cuda.get_device_capability()[0] >= 9
 
 
+ def is_hopper():
+ return is_cuda() and torch.cuda.get_device_capability()[0] == 9
+
+
  def is_hip():
  target = get_current_target()
  return False if target is None else target.backend == "hip"
@@ -62,9 +74,13 @@ def is_hip_cdna4():
  return target is not None and target.backend == 'hip' and target.arch == 'gfx950'
 
 
+ def is_hip_gfx11():
+ target = get_current_target()
+ return target is not None and target.backend == 'hip' and 'gfx11' in target.arch
+
+
  def is_hip_gfx12():
  target = get_current_target()
- print(target.arch)
  return target is not None and target.backend == 'hip' and 'gfx12' in target.arch
 
 
@@ -72,6 +88,10 @@ def is_hip_cdna():
  return is_hip_cdna2() or is_hip_cdna3() or is_hip_cdna4()
 
 
+ def get_hip_lds_size():
+ return 163840 if is_hip_cdna4() else 65536
+
+
  def is_xpu():
  target = get_current_target()
  return False if target is None else target.backend == "xpu"
@@ -132,7 +152,7 @@ def to_triton(x: np.ndarray, device, dst_type=None) -> Union[TensorWrapper, torc
 
 
  def str_to_triton_dtype(x: str) -> tl.dtype:
- return tl.str_to_ty(type_canonicalisation_dict[x])
+ return tl.str_to_ty(type_canonicalisation_dict[x], None)
 
 
  def torch_dtype_name(dtype) -> str:
@@ -187,3 +207,49 @@ def unwrap_tensor(t: Union[torch.Tensor, triton.runtime.jit.TensorWrapper]) -> t
  if isinstance(t, triton.runtime.jit.TensorWrapper):
  return t.base
  return t
+
+
+ def _fresh_knobs_impl(skipped_attr: Optional[Set[str]] = None):
+ from triton import knobs
+
+ if skipped_attr is None:
+ skipped_attr = set()
+
+ monkeypatch = pytest.MonkeyPatch()
+
+ knobs_map = {
+ name: knobset
+ for name, knobset in knobs.__dict__.items()
+ if isinstance(knobset, knobs.base_knobs) and knobset != knobs.base_knobs and name not in skipped_attr
+ }
+
+ # We store which variables we need to unset below in finally because
+ # monkeypatch doesn't appear to reset variables that were never set
+ # before the monkeypatch.delenv call below.
+ env_to_unset = []
+ prev_propagate_env = knobs.propagate_env
+
+ def fresh_function():
+ nonlocal env_to_unset
+ for name, knobset in knobs_map.items():
+ setattr(knobs, name, knobset.copy().reset())
+ for knob in knobset.knob_descriptors.values():
+ if knob.key in os.environ:
+ monkeypatch.delenv(knob.key, raising=False)
+ else:
+ env_to_unset.append(knob.key)
+ knobs.propagate_env = True
+ return knobs
+
+ def reset_function():
+ for name, knobset in knobs_map.items():
+ setattr(knobs, name, knobset)
+ # `undo` should be placed before `del os.environ`
+ # Otherwise, it may restore environment variables that monkeypatch deleted
+ monkeypatch.undo()
+ for k in env_to_unset:
+ if k in os.environ:
+ del os.environ[k]
+ knobs.propagate_env = prev_propagate_env
+
+ return fresh_function, reset_function
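`_fresh_knobs_impl` returns a pair of callables that swap in a pristine copy of every knob set and later restore both the knobs and any environment variables it touched. A minimal sketch of wiring it into a pytest fixture (the fixture name is illustrative, not part of this package):

```python
import pytest
from triton._internal_testing import _fresh_knobs_impl

@pytest.fixture
def fresh_knobs():
    fresh_function, reset_function = _fresh_knobs_impl()
    try:
        yield fresh_function()   # hand the test a reset copy of triton.knobs
    finally:
        reset_function()         # restore knobs and any touched env vars
```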
triton/_utils.py CHANGED
@@ -16,9 +16,11 @@ def get_iterable_path(iterable: IterableType, path: ObjPath) -> Any:
 
 
  def set_iterable_path(iterable: IterableType, path: tuple[int, ...], val: Any):
+ from .language import core
  assert len(path) != 0
  prev = iterable if len(path) == 1 else get_iterable_path(iterable, path[:-1])
- prev[path[-1]] = val  # type: ignore[index]
+ assert isinstance(prev, core.tuple)
+ prev._setitem(path[-1], val)
 
 
  def find_paths_if(iterable: Union[IterableType, Any], pred: Callable[[ObjPath, Any], bool]) -> list[ObjPath]:
triton/backends/amd/compiler.py CHANGED
@@ -7,8 +7,8 @@ from types import ModuleType
  import hashlib
  import tempfile
  import re
- import subprocess
  import functools
+ import warnings
  from pathlib import Path
 
 
@@ -18,8 +18,9 @@ def get_min_dot_size(target: GPUTarget):
  return lambda lhs_type, rhs_type: (1, 1, 1)
 
 
- def is_pingpong_schedule_enabled(arch):
- return (arch == "gfx942") if knobs.amd.use_block_pingpong is None else knobs.amd.use_block_pingpong
+ def is_pingpong_schedule_enabled(arch, use_async_copy):
+ return (arch == "gfx942" or (arch == "gfx950" and use_async_copy is True)
+ ) if knobs.amd.use_block_pingpong is None else knobs.amd.use_block_pingpong
 
 
  def is_in_thread_transpose_enabled(arch):
@@ -37,7 +38,11 @@ class HIPOptions:
  debug: bool = False
  sanitize_overflow: bool = True
  arch: str = None
- supported_fp8_dtypes: Tuple[str] = ("fp8e5", )
+ # We have native support for OCP fp8 variants since CDNA4/RDNA4. For earlier generations,
+ # we software emulate the support for them.
+ # UZ fp8 variants (fp8e4b8 and fp8e5b16) are natively supported for CDNA3. For other
+ # architectures they are software emulated.
+ supported_fp8_dtypes: Tuple[str] = ("fp8e4nv", "fp8e5", "fp8e5b16", "fp8e4b8")
  deprecated_fp8_dot_operand_dtypes: Tuple[str] = ()
  default_dot_input_precision: str = "ieee"
  allowed_dot_input_precisions: Tuple[str] = ("ieee", )
@@ -48,6 +53,7 @@ class HIPOptions:
  allow_flush_denorm: bool = False
  max_num_imprecise_acc_default: int = 0
  backend_name: str = 'hip'
+ instrumentation_mode: str = ""
 
  # The following option provides hints to the AMDGPU backend regarding instruction scheduling
  # for all `tt.dot` operations in a kernel. The "none" variant preserves the default
@@ -57,10 +63,6 @@ class HIPOptions:
  #
  # Current experimental scheduling variants:
  #
- # local-prefetch: implements instruction scheduling similar to the one from the ROCm Composable
- # Kernel library. Note, this variant requires the use of buffer load/store ops
- # and a special software pipelining style - i.e., 1x LDS and 1x register
- # prefetch buffers for each GEMM tile.
  # attention: enables a bunch of optimizations for attention kernels, including:
  # - iglp 2 and sched.barrier around it
  # - sink-insts-to-avoid-spills flag to avoid register spills
@@ -73,8 +75,11 @@ class HIPOptions:
  assert self.num_warps > 0 and (self.num_warps & (self.num_warps - 1)) == 0, \
  "num_warps must be a power of 2"
 
- if self.arch == 'gfx950':
- assert self.kpack == 1, "gfx950 only accepts kpack == 1"
+ if (self.arch == 'gfx950') and (self.kpack != 1):
+ warnings.warn(
+ f"kpack is deprecated starting from gfx950 and will be removed in later releases. So for now kpack = {self.kpack} will be overwritten to 1 to make transitioning easier."
+ )
+ object.__setattr__(self, 'kpack', 1)
 
  default_libdir = Path(__file__).parent / 'lib'
  extern_libs = {} if self.extern_libs is None else dict(self.extern_libs)
@@ -88,6 +93,7 @@ class HIPOptions:
 
 
  class HIPBackend(BaseBackend):
+ instrumentation = None
 
  @staticmethod
  def supports_target(target: GPUTarget):
@@ -104,6 +110,9 @@ class HIPBackend(BaseBackend):
  def parse_options(self, opts) -> Any:
  args = {'arch': knobs.runtime.override_arch or self.target.arch}
 
+ if opts.get("num_ctas", 1) > 1:
+ raise ValueError("num_ctas > 1 not supported for AMD GPUs")
+
  # Enable XF32 (TF32) for CDNA3 GPUs
  if self.target.arch == 'gfx942':
  allowed_dot_input_precisions = set(HIPOptions.allowed_dot_input_precisions)
@@ -111,14 +120,12 @@ class HIPBackend(BaseBackend):
  args["allowed_dot_input_precisions"] = tuple(sorted(allowed_dot_input_precisions))
 
  if "supported_fp8_dtypes" not in opts:
- supported_fp8_dtypes = set(HIPOptions.supported_fp8_dtypes)
- if self.target.arch == 'gfx942':
- supported_fp8_dtypes.update({'fp8e4nv', 'fp8e4b8', 'fp8e5b16'})
- elif self.target.arch == 'gfx950':
- supported_fp8_dtypes.update({'fp8e4nv', 'fp8e5'})
- elif 'gfx12' in self.target.arch:
- supported_fp8_dtypes.update({'fp8e4nv', 'fp8e5'})
- args["supported_fp8_dtypes"] = tuple(sorted(supported_fp8_dtypes))
+ args["supported_fp8_dtypes"] = tuple(sorted(HIPOptions.supported_fp8_dtypes))
+
+ if self.target.arch == 'gfx950':
+ deprecated_fp8_dot_operand_dtypes = set(HIPOptions.deprecated_fp8_dot_operand_dtypes)
+ deprecated_fp8_dot_operand_dtypes.update({"fp8e5b16", "fp8e4b8"})
+ args["deprecated_fp8_dot_operand_dtypes"] = tuple(sorted(deprecated_fp8_dot_operand_dtypes))
 
  if "enable_fp_fusion" not in opts:
  args["enable_fp_fusion"] = knobs.language.default_fp_fusion
@@ -146,6 +153,8 @@ class HIPBackend(BaseBackend):
 
  def load_dialects(self, ctx):
  amd.load_dialects(ctx)
+ if HIPBackend.instrumentation:
+ HIPBackend.instrumentation.load_dialects(ctx)
 
  @staticmethod
  def is_within_2gb(arg):
@@ -174,26 +183,6 @@ class HIPBackend(BaseBackend):
  ret += "S"
  return ret
 
- @staticmethod
- def path_to_rocm_lld():
- # Check env path for ld.lld
- lld_env_path = knobs.amd.lld_path
- if lld_env_path is not None:
- lld = Path(lld_env_path)
- if lld.is_file():
- return lld
- # Check backend for ld.lld (used for pytorch wheels)
- lld = Path(__file__).parent / "llvm/bin/ld.lld"
- if lld.is_file():
- return lld
- lld = Path("/opt/rocm/llvm/bin/ld.lld")
- if lld.is_file():
- return lld
- lld = Path("/usr/bin/ld.lld")
- if lld.is_file():
- return lld
- raise Exception("ROCm linker /opt/rocm/llvm/bin/ld.lld not found. Set 'TRITON_HIP_LLD_PATH' to its path.")
-
  @staticmethod
  def make_ttir(mod, metadata, options):
  pm = ir.pass_manager(mod.context)
@@ -237,12 +226,10 @@ class HIPBackend(BaseBackend):
  global_prefetch = knobs.amd.global_prefetch
  local_prefetch = knobs.amd.local_prefetch
  use_async_copy = knobs.amd.use_async_copy
+ use_block_pingpong = is_pingpong_schedule_enabled(options.arch, use_async_copy)
 
- # The `local-prefetch` scheduling variant requires turning on buffer ops.
- if options.schedule_hint == "local-prefetch":
- global_prefetch = local_prefetch = 1
-
- amd.passes.ttgpuir.add_stream_pipeline(pm, options.num_stages, global_prefetch, local_prefetch, use_async_copy)
+ amd.passes.ttgpuir.add_stream_pipeline(pm, options.num_stages, global_prefetch, local_prefetch, use_async_copy,
+ use_block_pingpong)
  if use_async_copy:
  amd.passes.ttgpuir.add_coalesce_async_copy(pm, options.arch)
  passes.common.add_canonicalizer(pm)
@@ -255,14 +242,13 @@ class HIPBackend(BaseBackend):
  amd.passes.ttgpuir.add_in_thread_transpose(pm)
  passes.ttgpuir.add_remove_layout_conversions(pm)
  amd.passes.ttgpuir.add_reorder_instructions(pm)
- use_block_pingpong = is_pingpong_schedule_enabled(options.arch)
- if use_block_pingpong and options.num_stages == 2:
+ if use_block_pingpong and options.num_stages > 1:
  amd.passes.ttgpuir.add_block_pingpong(pm, options.num_stages)
 
  if knobs.amd.use_buffer_ops:
  amd.passes.ttgpuir.add_canonicalize_pointers(pm)
  passes.common.add_canonicalizer(pm)
- amd.passes.ttgpuir.add_convert_to_buffer_ops(pm, options.arch)
+ amd.passes.ttgpuir.add_convert_to_buffer_ops(pm, options.arch, knobs.amd.use_buffer_atomics)
 
  amd.passes.ttgpuir.add_fold_true_cmpi(pm)
  passes.common.add_canonicalizer(pm)
@@ -274,15 +260,16 @@ class HIPBackend(BaseBackend):
  return mod
 
  @staticmethod
- def ttgir_opt(src, metadata, options):
+ def gluon_to_ttgir(src, metadata, options):
  mod = src
  pm = ir.pass_manager(mod.context)
  pm.enable_debug()
 
- passes.ttgpuir.add_inliner(pm)
+ passes.gluon.add_inliner(pm)
+ passes.gluon.add_resolve_auto_encodings(pm)
  passes.common.add_sccp(pm)
  passes.ttir.add_loop_aware_cse(pm)
- passes.ttgpuir.add_canonicalizer(pm)
+ passes.gluon.add_canonicalizer(pm)
  passes.ttgpuir.add_combine_tensor_select_and_if(pm)
 
  pm.run(mod)
@@ -304,7 +291,10 @@ class HIPBackend(BaseBackend):
  passes.convert.add_scf_to_cf(pm)
  passes.convert.add_index_to_llvmir(pm)
 
- passes.ttgpuir.add_allocate_shared_memory(pm)
+ amd.passes.ttgpuir.add_allocate_shared_memory(pm)
+ # instrumentation point here so we can override IRs above (e.g., ttir and ttgir)
+ if HIPBackend.instrumentation:
+ HIPBackend.instrumentation.patch("ttgpuir_to_llvmir", pm, mod.context)
  ## __HIP_FTZ is used to control the denorm flushing behavior of exp2 op as follows:
  ## 1. If __HIP_FTZ = 1, exp2 flushes denorms in input and output regardless
  ## of the value of kernel arg `allow_flush_denorm`.
@@ -322,10 +312,17 @@ class HIPBackend(BaseBackend):
  passes.common.add_canonicalizer(pm)
  passes.common.add_cse(pm)
  passes.common.add_symbol_dce(pm)
+
  if options.schedule_hint.lower() != "none":
  amd.passes.ttgpuir.lower_instruction_sched_hints(pm, options.arch, options.num_stages)
+
+ # This can not be moved below the di_scope pass
+ if HIPBackend.instrumentation:
+ HIPBackend.instrumentation.patch("llvmir_to_llvm", pm, mod.context)
+
  if not knobs.compilation.disable_line_info:
  passes.llvmir.add_di_scope(pm)
+
  amd.passes.ttgpuir.add_builtin_func_to_llvmir(pm, __HIP_FTZ)
  pm.run(mod)
 
@@ -382,15 +379,27 @@ class HIPBackend(BaseBackend):
  llvm.link_extern_libs(llvm_mod, paths)
  elif options.extern_libs:
  paths = [path for (name, path) in options.extern_libs if amd.need_extern_lib(llvm_mod, name)]
- llvm.link_extern_libs(llvm_mod, paths)
+ if len(paths) > 0:
+ llvm.link_extern_libs(llvm_mod, paths)
 
  llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3, options.arch, '', [], options.enable_fp_fusion)
 
+ # Architectures with architected SGPRs store the workgroup id in ttmp9 (X) and ttmp7 (Y[15:0], Z[31:16]).
+ # These attributes are used to determine if Z should be masked out when loading Y. They are inferred during
+ # optimize_module from calls to @llvm.amdgcn.workgroup.id.x/y/z(). We cannot rely on this because a
+ # dispatch dimensions might be used even if there is no program_id() call for it.
+ if amd.has_architected_sgprs(options.arch):
+ fns[0].remove_fn_attr("amdgpu-no-workgroup-id-x")
+ fns[0].remove_fn_attr("amdgpu-no-workgroup-id-y")
+ fns[0].remove_fn_attr("amdgpu-no-workgroup-id-z")
+
  if knobs.amd.scalarize_packed_fops:
  amd.add_scalarize_packed_fops_llvm_pass(fns[0])
 
  # Get some metadata
  metadata["shared"] = src.get_int_attr("ttg.shared")
+ metadata["profile_scratch_size"] = src.get_int_attr("ttg.profile_scratch_memory_size") or 0
+ metadata["profile_scratch_align"] = src.get_int_attr("ttg.profile_scratch_memory_alignment") or 1
 
  amd.cleanup_bitcode_metadata(llvm_mod)
  # Disable inlining of print related functions,
@@ -414,7 +423,9 @@ class HIPBackend(BaseBackend):
  # the regression is not significant. It would be better to have some heuristics.
  if options.schedule_hint == 'attention':
  flags.append('sink-insts-to-avoid-spills')
- amdgcn = llvm.translate_to_asm(src, amd.TARGET_TRIPLE, options.arch, '', flags, options.enable_fp_fusion, False)
+ features = '-real-true16' if 'gfx11' in options.arch else ''
+ amdgcn = llvm.translate_to_asm(src, amd.TARGET_TRIPLE, options.arch, features, flags, options.enable_fp_fusion,
+ False)
  if knobs.amd.dump_amdgcn:
  print("// -----// AMDGCN Dump //----- //")
  print(amdgcn)
@@ -426,14 +437,12 @@ class HIPBackend(BaseBackend):
  if knobs.compilation.enable_asan:
  target_features = '+xnack'
  hsaco = amd.assemble_amdgcn(src, options.arch, target_features)
-
- rocm_path = HIPBackend.path_to_rocm_lld()
  with tempfile.NamedTemporaryFile() as tmp_out:
  with tempfile.NamedTemporaryFile() as tmp_in:
- with open(tmp_in.name, 'wb') as fd_in:
+ with open(tmp_in.name, "wb") as fd_in:
  fd_in.write(hsaco)
- subprocess.check_call([rocm_path, '-flavor', 'gnu', '-shared', tmp_in.name, '-o', tmp_out.name])
- with open(tmp_out.name, 'rb') as fd_out:
+ amd.link_hsaco(tmp_in.name, tmp_out.name)
+ with open(tmp_out.name, "rb") as fd_out:
  ret = fd_out.read()
  return ret
 
@@ -442,12 +451,11 @@ class HIPBackend(BaseBackend):
  stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options)
  stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options)
  elif language == Language.GLUON:
- stages["ttgir"] = lambda src, metadata: self.ttgir_opt(src, metadata, options)
+ stages["ttgir"] = lambda src, metadata: self.gluon_to_ttgir(src, metadata, options)
  stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options)
  stages["amdgcn"] = lambda src, metadata: self.make_amdgcn(src, metadata, options)
  stages["hsaco"] = lambda src, metadata: self.make_hsaco(src, metadata, options)
 
  @functools.lru_cache()
  def hash(self):
- version = subprocess.check_output([HIPBackend.path_to_rocm_lld(), "--version"], encoding='utf-8')
- return f'{version}-{self.target}'
+ return f'{self.target}'
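Per the `__post_init__` hunk above, requesting `kpack != 1` on gfx950 now emits a deprecation warning and overrides the value instead of failing an assertion. A hedged sketch of the new behavior (assumes the other `HIPOptions` fields keep their defaults and that this AMD backend is importable on the host; illustrative only):

```python
import warnings
from triton.backends.amd.compiler import HIPOptions

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    opts = HIPOptions(arch="gfx950", kpack=2)   # previously: assertion error

print(opts.kpack)          # -> 1, overwritten after the deprecation warning
print(caught[0].message)   # -> "kpack is deprecated starting from gfx950 ..."
```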