triton-windows 3.3.1.post19__cp311-cp311-win_amd64.whl → 3.5.0.post21__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic. Click here for more details.

Files changed (225)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +11 -2
  3. triton/_filecheck.py +97 -0
  4. triton/_internal_testing.py +95 -18
  5. triton/_utils.py +112 -21
  6. triton/backends/__init__.py +20 -23
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +161 -119
  9. triton/backends/amd/driver.c +118 -46
  10. triton/backends/amd/driver.py +274 -96
  11. triton/backends/compiler.py +7 -21
  12. triton/backends/driver.py +13 -0
  13. triton/backends/nvidia/bin/ptxas.exe +0 -0
  14. triton/backends/nvidia/compiler.py +163 -106
  15. triton/backends/nvidia/driver.c +166 -101
  16. triton/backends/nvidia/driver.py +384 -202
  17. triton/compiler/__init__.py +5 -2
  18. triton/compiler/code_generator.py +439 -231
  19. triton/compiler/compiler.py +152 -84
  20. triton/experimental/__init__.py +0 -0
  21. triton/experimental/gluon/__init__.py +5 -0
  22. triton/experimental/gluon/_compiler.py +0 -0
  23. triton/experimental/gluon/_runtime.py +102 -0
  24. triton/experimental/gluon/language/__init__.py +119 -0
  25. triton/experimental/gluon/language/_core.py +490 -0
  26. triton/experimental/gluon/language/_layouts.py +583 -0
  27. triton/experimental/gluon/language/_math.py +20 -0
  28. triton/experimental/gluon/language/_semantic.py +380 -0
  29. triton/experimental/gluon/language/_standard.py +80 -0
  30. triton/experimental/gluon/language/amd/__init__.py +4 -0
  31. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  32. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  33. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  34. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  35. triton/experimental/gluon/language/extra/__init__.py +3 -0
  36. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  37. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  38. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  39. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  40. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
  41. triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
  42. triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
  43. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
  44. triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
  45. triton/experimental/gluon/nvidia/__init__.py +4 -0
  46. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  47. triton/experimental/gluon/nvidia/hopper.py +45 -0
  48. triton/knobs.py +546 -0
  49. triton/language/__init__.py +50 -19
  50. triton/language/core.py +909 -572
  51. triton/language/extra/cuda/__init__.py +10 -7
  52. triton/language/extra/cuda/gdc.py +42 -0
  53. triton/language/extra/cuda/libdevice.py +394 -394
  54. triton/language/extra/cuda/utils.py +21 -21
  55. triton/language/extra/hip/__init__.py +3 -1
  56. triton/language/extra/hip/libdevice.py +120 -104
  57. triton/language/extra/hip/utils.py +35 -0
  58. triton/language/extra/libdevice.py +4 -0
  59. triton/language/math.py +65 -66
  60. triton/language/random.py +12 -2
  61. triton/language/semantic.py +1757 -1768
  62. triton/language/standard.py +127 -62
  63. triton/language/target_info.py +54 -0
  64. triton/runtime/_allocation.py +15 -3
  65. triton/runtime/_async_compile.py +55 -0
  66. triton/runtime/autotuner.py +117 -60
  67. triton/runtime/build.py +83 -17
  68. triton/runtime/cache.py +61 -47
  69. triton/runtime/driver.py +25 -47
  70. triton/runtime/interpreter.py +95 -50
  71. triton/runtime/jit.py +445 -248
  72. triton/runtime/tcc/include/_mingw.h +8 -10
  73. triton/runtime/tcc/include/assert.h +5 -0
  74. triton/runtime/tcc/include/errno.h +1 -1
  75. triton/runtime/tcc/include/float.h +21 -3
  76. triton/runtime/tcc/include/iso646.h +36 -0
  77. triton/runtime/tcc/include/limits.h +5 -0
  78. triton/runtime/tcc/include/malloc.h +2 -2
  79. triton/runtime/tcc/include/math.h +21 -261
  80. triton/runtime/tcc/include/stdalign.h +16 -0
  81. triton/runtime/tcc/include/stdarg.h +5 -70
  82. triton/runtime/tcc/include/stdatomic.h +171 -0
  83. triton/runtime/tcc/include/stddef.h +7 -19
  84. triton/runtime/tcc/include/stdlib.h +15 -4
  85. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  86. triton/runtime/tcc/include/sys/stat.h +2 -2
  87. triton/runtime/tcc/include/sys/types.h +5 -0
  88. triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
  89. triton/runtime/tcc/include/tccdefs.h +342 -0
  90. triton/runtime/tcc/include/tgmath.h +89 -0
  91. triton/runtime/tcc/include/uchar.h +33 -0
  92. triton/runtime/tcc/include/unistd.h +1 -0
  93. triton/runtime/tcc/include/winapi/qos.h +72 -0
  94. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  95. triton/runtime/tcc/include/winapi/winbase.h +9 -2
  96. triton/runtime/tcc/include/winapi/wincon.h +8 -0
  97. triton/runtime/tcc/include/winapi/windows.h +1 -1
  98. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  99. triton/runtime/tcc/include/winapi/winnt.h +9 -7
  100. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  101. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  102. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  103. triton/runtime/tcc/lib/libtcc1.a +0 -0
  104. triton/runtime/tcc/lib/python314.def +1800 -0
  105. triton/runtime/tcc/lib/python314t.def +1809 -0
  106. triton/runtime/tcc/libtcc.dll +0 -0
  107. triton/runtime/tcc/tcc.exe +0 -0
  108. triton/testing.py +16 -12
  109. triton/tools/compile.py +62 -14
  110. triton/tools/disasm.py +3 -4
  111. triton/tools/extra/cuda/compile.c +1 -0
  112. triton/tools/extra/hip/compile.cpp +66 -0
  113. triton/tools/extra/hip/compile.h +13 -0
  114. triton/tools/ragged_tma.py +92 -0
  115. triton/tools/tensor_descriptor.py +34 -0
  116. triton/windows_utils.py +52 -81
  117. {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +8 -4
  118. triton_windows-3.5.0.post21.dist-info/RECORD +217 -0
  119. triton_windows-3.5.0.post21.dist-info/entry_points.txt +3 -0
  120. triton_windows-3.5.0.post21.dist-info/licenses/LICENSE +23 -0
  121. triton_windows-3.5.0.post21.dist-info/top_level.txt +1 -0
  122. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
  123. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
  124. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
  125. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
  126. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
  127. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
  128. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
  129. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
  130. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
  131. triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
  132. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
  133. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
  134. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
  135. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
  136. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
  137. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
  138. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
  139. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
  140. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
  141. triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
  142. triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
  143. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
  144. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
  145. triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
  146. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
  147. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
  148. triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
  149. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
  150. triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
  151. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
  152. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
  153. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
  154. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
  155. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
  156. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
  157. triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
  158. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
  159. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
  160. triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
  161. triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
  162. triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
  163. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
  164. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
  165. triton/backends/amd/include/hip/channel_descriptor.h +0 -39
  166. triton/backends/amd/include/hip/device_functions.h +0 -38
  167. triton/backends/amd/include/hip/driver_types.h +0 -468
  168. triton/backends/amd/include/hip/hip_bf16.h +0 -36
  169. triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
  170. triton/backends/amd/include/hip/hip_common.h +0 -100
  171. triton/backends/amd/include/hip/hip_complex.h +0 -38
  172. triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
  173. triton/backends/amd/include/hip/hip_deprecated.h +0 -95
  174. triton/backends/amd/include/hip/hip_ext.h +0 -161
  175. triton/backends/amd/include/hip/hip_fp16.h +0 -36
  176. triton/backends/amd/include/hip/hip_fp8.h +0 -33
  177. triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
  178. triton/backends/amd/include/hip/hip_hcc.h +0 -24
  179. triton/backends/amd/include/hip/hip_math_constants.h +0 -36
  180. triton/backends/amd/include/hip/hip_profile.h +0 -27
  181. triton/backends/amd/include/hip/hip_runtime.h +0 -75
  182. triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
  183. triton/backends/amd/include/hip/hip_texture_types.h +0 -29
  184. triton/backends/amd/include/hip/hip_vector_types.h +0 -41
  185. triton/backends/amd/include/hip/hip_version.h +0 -17
  186. triton/backends/amd/include/hip/hiprtc.h +0 -421
  187. triton/backends/amd/include/hip/library_types.h +0 -78
  188. triton/backends/amd/include/hip/math_functions.h +0 -42
  189. triton/backends/amd/include/hip/surface_types.h +0 -63
  190. triton/backends/amd/include/hip/texture_types.h +0 -194
  191. triton/backends/amd/include/hsa/Brig.h +0 -1131
  192. triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
  193. triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
  194. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
  195. triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
  196. triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
  197. triton/backends/amd/include/hsa/hsa.h +0 -5738
  198. triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
  199. triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
  200. triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
  201. triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
  202. triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
  203. triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
  204. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
  205. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
  206. triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
  207. triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
  208. triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
  209. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
  210. triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
  211. triton/backends/amd/include/roctracer/roctracer.h +0 -779
  212. triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
  213. triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
  214. triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
  215. triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
  216. triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
  217. triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
  218. triton/backends/amd/include/roctracer/roctx.h +0 -229
  219. triton/language/_utils.py +0 -21
  220. triton/language/extra/cuda/_experimental_tma.py +0 -106
  221. triton/runtime/tcc/lib/libtcc1-64.a +0 -0
  222. triton/tools/experimental_descriptor.py +0 -32
  223. triton_windows-3.3.1.post19.dist-info/RECORD +0 -260
  224. triton_windows-3.3.1.post19.dist-info/top_level.txt +0 -14
  225. {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
Binary file
Binary file
triton/testing.py CHANGED
@@ -95,7 +95,11 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mod
95
95
  end_event.record()
96
96
  torch.cuda.synchronize()
97
97
  estimate_ms = start_event.elapsed_time(end_event) / 5
98
- n_repeat = max(1, int(rep / estimate_ms))
98
+ # Rewrite to avoid possible division by 0 issues with fast benchmarks
99
+ if estimate_ms == 0:
100
+ n_repeat = 1000
101
+ else:
102
+ n_repeat = max(1, int(rep / estimate_ms))
99
103
  # step 2 - construct a cuda graph with `n_repeat` unrolled function calls to minimize
100
104
  # host overhead
101
105
  g = torch.cuda.CUDAGraph()
@@ -383,18 +387,18 @@ class Mark:
383
387
  has_single_bench = isinstance(self.benchmarks, Benchmark)
384
388
  benchmarks = [self.benchmarks] if has_single_bench else self.benchmarks
385
389
  result_dfs = []
386
- if save_path:
387
- # Create directory if it doesn't exist
388
- os.makedirs(save_path, exist_ok=True)
389
- html = open(os.path.join(save_path, "results.html"), "w")
390
- html.write("<html><body>\n")
391
- for bench in benchmarks:
392
- result_dfs.append(self._run(bench, save_path, show_plots, print_data, **kwargs))
390
+ try:
391
+ for bench in benchmarks:
392
+ result_dfs.append(self._run(bench, save_path, show_plots, print_data, **kwargs))
393
+ finally:
393
394
  if save_path:
394
- html.write(f"<image src=\"{bench.plot_name}.png\"/>\n")
395
- if save_path:
396
- html.write("</body></html>\n")
397
- html.close()
395
+ # Create directory if it doesn't exist
396
+ os.makedirs(save_path, exist_ok=True)
397
+ with open(os.path.join(save_path, "results.html"), "w") as html:
398
+ html.write("<html><body>\n")
399
+ for bench in benchmarks[:len(result_dfs)]:
400
+ html.write(f"<image src=\"{bench.plot_name}.png\"/>\n")
401
+ html.write("</body></html>\n")
398
402
  if return_df:
399
403
  if has_single_bench:
400
404
  return result_dfs[0]
triton/tools/compile.py CHANGED
@@ -3,12 +3,29 @@ import hashlib
3
3
  import importlib.util
4
4
  import sys
5
5
  from argparse import ArgumentParser
6
+ from dataclasses import dataclass
6
7
  from pathlib import Path
7
8
  from typing import List
8
9
 
9
10
  import triton
10
11
  import triton.backends
11
- from triton.backends.nvidia.driver import ty_to_cpp
12
+
13
+
14
+ @dataclass
15
+ class CompileArgs:
16
+ '''
17
+ A class to contain arguments from command-line parser.
18
+ '''
19
+ path: str = ''
20
+ kernel_name: str = ''
21
+ signature: str = ''
22
+ grid: str = ''
23
+ target: str | None = None
24
+ num_warps: int = 1
25
+ num_stages: int = 3
26
+ out_name: str | None = None
27
+ out_path: Path | None = None
28
+
12
29
 
13
30
  desc = """
14
31
  Triton ahead-of-time compiler:
@@ -36,14 +53,18 @@ NOTE: when resolving the scope of /path/to/kernel.py, the file will be executed
36
53
  used to run this `compile.py` script
37
54
  """
38
55
 
39
- if __name__ == "__main__":
40
56
 
57
+ def main():
41
58
  # command-line arguments
42
59
  parser = ArgumentParser(description=desc)
43
60
  parser.add_argument("path",
44
61
  help="Path to Python source containing desired kernel in its scope. File will be executed.")
45
62
  parser.add_argument("--kernel-name", "-n", type=str, default="", help="Name of the kernel to compile",
46
63
  required=True)
64
+ parser.add_argument(
65
+ "--target", "-t", type=str, default=None,
66
+ help="The target to compile towards, in format of '<backend>:<arch>:<warp-size>'; "
67
+ "e.g., 'cuda:80:32', 'hip:gfx942:64'. Default to None, which means using current machine's GPU target")
47
68
  parser.add_argument("--num-warps", "-w", type=int, default=1, help="Number of warps to launch the kernel")
48
69
  parser.add_argument("--num-stages", "-ns", type=int, default=3,
49
70
  help="Number of stages (meta-parameter of the kernel)")
@@ -51,8 +72,12 @@ if __name__ == "__main__":
51
72
  parser.add_argument("--out-path", "-o", type=Path, default=None, help="Out filename")
52
73
  parser.add_argument("--signature", "-s", type=str, help="Signature of the kernel", required=True)
53
74
  parser.add_argument("--grid", "-g", type=str, help="Launch grid of the kernel", required=True)
54
- args = parser.parse_args()
75
+ cli_args = parser.parse_args()
76
+ args = CompileArgs(**vars(cli_args)) # A sanity check to ensure class CompileArgs is updated as well.
77
+ compile_kernel(args)
55
78
 
79
+
80
+ def compile_kernel(args: CompileArgs):
56
81
  out_name = args.out_name if args.out_name else args.kernel_name
57
82
  out_path = args.out_path if args.out_path else Path(out_name)
58
83
 
@@ -108,10 +133,18 @@ if __name__ == "__main__":
108
133
  assert h in [1, 16], f"Only 1 and 16 are valid hints, got {h}"
109
134
  attrs = {k: [["tt.divisibility", 16]] for k, v in hints.items() if v == 16}
110
135
  src = triton.compiler.ASTSource(fn=kernel, constexprs=constants, signature=signature, attrs=attrs)
111
- opts = {"num_warps": args.num_warps, "num_stages": args.num_stages}
112
- ccinfo = triton.compile(src, options=opts)
113
- if ccinfo.metadata.global_scratch_size > 0:
136
+
137
+ target = triton.backends.compiler.GPUTarget(*args.target.split(":")) \
138
+ if args.target else triton.runtime.driver.active.get_current_target()
139
+ backend = triton.compiler.make_backend(target)
140
+ kwargs = {"num_warps": args.num_warps, "num_stages": args.num_stages}
141
+ options = backend.parse_options(kwargs)
142
+ ccinfo = triton.compile(src, target=target, options=options.__dict__)
143
+
144
+ if getattr(ccinfo.metadata, "global_scratch_size", 0) > 0:
114
145
  raise RuntimeError("AOT compiling kernels with global scratch requirements is not yet implemented")
146
+ if ccinfo.metadata.profile_scratch_size > 0:
147
+ raise RuntimeError("AOT compiling kernels with profile scratch requirements is not yet implemented")
115
148
 
116
149
  arg_names = []
117
150
  arg_types = []
@@ -136,8 +169,12 @@ if __name__ == "__main__":
136
169
  if hints.get((i, ), None) == 16:
137
170
  suffix += 'd'
138
171
  func_name = '_'.join([out_name, sig_hash, suffix])
139
- asm = ccinfo.asm["cubin"] # store binary data once
172
+ asm = ccinfo.asm[backend.binary_ext] # store binary data once
173
+
140
174
  hex_ = str(binascii.hexlify(asm))[2:-1]
175
+
176
+ ty_to_cpp = triton.runtime.driver.active.map_python_to_cpp_type
177
+
141
178
  params = {
142
179
  "kernel_name": func_name,
143
180
  "triton_kernel_name": args.kernel_name,
@@ -145,18 +182,29 @@ if __name__ == "__main__":
145
182
  "bin_data": ", ".join([f"0x{x}{y}" for x, y in zip(hex_[::2], hex_[1::2])]),
146
183
  "signature": ", ".join([f"{ty_to_cpp(ty)} {name}" for name, ty in zip(arg_names_not_1, arg_types_not_1)]),
147
184
  "full_signature": ", ".join([f"{ty_to_cpp(ty)} {name}" for name, ty in zip(arg_names, arg_types)]),
148
- "arg_pointers": ", ".join([f"&{arg}" for arg in arg_names_not_1] + ["&global_scratch"]),
149
- "num_args": len(arg_names_not_1) + 1,
185
+ "arg_pointers": ", ".join([f"&{arg}" for arg in arg_names_not_1] + ["&global_scratch"] + ["&profile_scratch"]),
186
+ "num_args": len(arg_names_not_1) + 2, # +2 for global and profile scratch
150
187
  "kernel_docstring": doc_string,
151
188
  "shared": ccinfo.metadata.shared,
152
189
  "num_warps": args.num_warps,
153
- "algo_info": '_'.join([const_sig, meta_sig]),
190
+ "algo_info": "_".join([const_sig, meta_sig]),
154
191
  "gridX": grid[0],
155
192
  "gridY": grid[1],
156
193
  "gridZ": grid[2],
157
194
  "_placeholder": "",
158
195
  }
159
- for ext in ['h', 'c']:
160
- template_path = Path(__file__).parent / "extra" / "cuda" / f"compile.{ext}"
161
- with out_path.with_suffix(f".{sig_hash}_{suffix}.{ext}").open("w") as fp:
162
- fp.write(Path(template_path).read_text().format(**params))
196
+ output_files = []
197
+ backend_name = target.backend
198
+ template_dir = Path(__file__).parent / "extra" / backend_name
199
+ for template_path in template_dir.glob('compile.*'):
200
+ ext = template_path.suffix
201
+ output_file = out_path.with_suffix(f".{sig_hash}_{suffix}{ext}")
202
+ with output_file.open("w") as fp:
203
+ fp.write(template_path.read_text().format(**params))
204
+ output_files.append(output_file)
205
+
206
+ return func_name, output_files
207
+
208
+
209
+ if __name__ == "__main__":
210
+ main()
triton/tools/disasm.py CHANGED
@@ -75,14 +75,13 @@ def get_sass(cubin_asm, fun=None):
75
75
  return sass
76
76
 
77
77
 
78
- @functools.lru_cache()
79
78
  def path_to_cuobjdump():
80
- from triton.backends.nvidia.compiler import _path_to_binary
81
- return _path_to_binary("cuobjdump")
79
+ from triton import knobs
80
+ return knobs.nvidia.cuobjdump.path
82
81
 
83
82
 
84
83
  def extract(file_path, fun):
85
- cuobjdump, _ = path_to_cuobjdump()
84
+ cuobjdump = path_to_cuobjdump()
86
85
  if fun is None:
87
86
  sass_str = subprocess.check_output([cuobjdump, "-sass", file_path])
88
87
  else:
@@ -61,6 +61,7 @@ CUresult {kernel_name}(CUstream stream, {signature}) {{
61
61
  unsigned int gY = {gridY};
62
62
  unsigned int gZ = {gridZ};
63
63
  CUdeviceptr global_scratch = 0;
64
+ CUdeviceptr profile_scratch = 0;
64
65
  void *args[{num_args}] = {{ {arg_pointers} }};
65
66
  // TODO: shared memory
66
67
  if(gX * gY * gZ > 0)
@@ -0,0 +1,66 @@
1
+ // SPDX-License-Identifier: MIT
2
+ // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
+
4
+ /* clang-format off */
5
+ #include <stdio.h>
6
+ #include <stdint.h>
7
+ #include <inttypes.h>
8
+ #include <string.h>
9
+ #include <hip/hip_runtime.h>
10
+
11
+ // helpers to check for hip errors
12
+ #define HIP_CHECK(ans) {{\
13
+ gpuAssert((ans), __FILE__, __LINE__);\
14
+ }}\
15
+
16
+ static inline void gpuAssert(hipError_t code, const char *file, int line) {{
17
+ if (code != hipSuccess) {{
18
+ const char *prefix = "Triton Error [HIP]: ";
19
+ const char *str;
20
+ hipDrvGetErrorString(code, &str);
21
+ char err[1024] = {{0}};
22
+ strcat(err, prefix);
23
+ strcat(err, str);
24
+ printf("%s\\n", err);
25
+ exit(code);
26
+ }}
27
+ }}
28
+
29
+ // globals
30
+ #define HSACO_NAME {kernel_name}_hsaco
31
+ hipModule_t {kernel_name}_mod = nullptr;
32
+ hipFunction_t {kernel_name}_func = nullptr;
33
+ unsigned char HSACO_NAME[{bin_size}] = {{ {bin_data} }};
34
+
35
+
36
+ void unload_{kernel_name}(void) {{
37
+ HIP_CHECK(hipModuleUnload({kernel_name}_mod));
38
+ }}
39
+
40
+
41
+ void load_{kernel_name}() {{
42
+ int dev = 0;
43
+ void *bin = (void *)&HSACO_NAME;
44
+ int shared = {shared};
45
+ HIP_CHECK(hipModuleLoadData(&{kernel_name}_mod, bin));
46
+ HIP_CHECK(hipModuleGetFunction(&{kernel_name}_func, {kernel_name}_mod, "{triton_kernel_name}"));
47
+ }}
48
+
49
+ /*
50
+ {kernel_docstring}
51
+ */
52
+ hipError_t {kernel_name}(hipStream_t stream, {signature}) {{
53
+ if ({kernel_name}_func == nullptr)
54
+ load_{kernel_name}();
55
+ unsigned int gX = {gridX};
56
+ unsigned int gY = {gridY};
57
+ unsigned int gZ = {gridZ};
58
+ hipDeviceptr_t global_scratch = 0;
59
+ hipDeviceptr_t profile_scratch = 0;
60
+ void *args[{num_args}] = {{ {arg_pointers} }};
61
+ // TODO: shared memory
62
+ if(gX * gY * gZ > 0)
63
+ return hipModuleLaunchKernel({kernel_name}_func, gX, gY, gZ, {num_warps} * warpSize, 1, 1, {shared}, stream, args, nullptr);
64
+ else
65
+ return hipErrorInvalidValue;
66
+ }}
@@ -0,0 +1,13 @@
1
+ // SPDX-License-Identifier: MIT
2
+ // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
+
4
+ #pragma once
5
+
6
+ #include <hip/hip_runtime.h>
7
+ #include <inttypes.h>
8
+ #include <stdint.h>
9
+ #include <stdio.h>
10
+
11
+ void unload_{kernel_name}(void);
12
+ void load_{kernel_name}(void);
13
+ hipError_t{_placeholder} {kernel_name}(hipStream_t stream, {signature});
@@ -0,0 +1,92 @@
1
+ import triton
2
+ import triton.language as tl
3
+ from triton.tools.tensor_descriptor import TensorDescriptor
4
+
5
+ # fmt: off
6
+
7
+
8
+ def create_ragged_descriptor(T, block_shape, ragged_dim=0):
9
+ """
10
+ Given a 2- or 3-dimensional tensor T, this creates a 'ragged descriptor'
11
+ which behaves like a concatenation (along the first axis) of subarrays
12
+ of potentially unequal size.
13
+
14
+ The load_ragged and store_ragged device functions can be used to read
15
+ and write from subarrays T[batch_offset : batch_offset + batch_size]
16
+ with hardware bounds-checking preventing any sort of leakage outside
17
+ the subarray.
18
+ """
19
+
20
+ block_shape = list(block_shape)
21
+ tensor_shape = list(T.shape)
22
+ rank = len(tensor_shape)
23
+
24
+ if ragged_dim < 0:
25
+ ragged_dim += rank
26
+
27
+ assert 0 <= ragged_dim < rank - 1, "last dimension cannot be ragged"
28
+ assert rank <= 3, "read-write ragged descriptors must have at most 3 dimensions"
29
+
30
+ assert len(block_shape) == rank, "block shape must have same length as tensor shape"
31
+
32
+ max_int = 0x7fff0000
33
+ billion = 0x40000000 # == 2**30
34
+
35
+ assert tensor_shape[ragged_dim] <= billion, "number of rows may not exceed 2**30"
36
+ tensor_shape[ragged_dim] = billion
37
+ ragged_stride = T.stride(ragged_dim)
38
+
39
+ # we prepend an extra two dimensions and rely on the fact that pointers
40
+ # have 64-bit wraparound semantics:
41
+ tma_stride = [2**34 - ragged_stride, ragged_stride] + [T.stride(i) for i in range(rank)]
42
+ tma_shape = [max_int, max_int] + tensor_shape
43
+ box_shape = [1, 1] + block_shape
44
+
45
+ return TensorDescriptor(T, tma_shape, tma_stride, box_shape)
46
+
47
+
48
+ @triton.jit
49
+ def to_ragged_indices(batch_offset, batch_size, row):
50
+ """
51
+ Helper function for load_ragged and store_ragged.
52
+ """
53
+
54
+ billion = 0x40000000 # == 2**30
55
+ x = billion - batch_size + row
56
+ y = batch_offset + batch_size
57
+
58
+ return billion, y, x
59
+
60
+
61
+ @triton.jit
62
+ def load_ragged(TMA, batch_offset, batch_size, coords, ragged_dim: tl.constexpr = 0):
63
+ """
64
+ Read from a subarray T[batch_offset : batch_offset + batch_size] with
65
+ hardware bounds-checking, where reading outside the subarray gives zeros.
66
+
67
+ Coords should be an appropriately-sized list of integers, just like in
68
+ TMA.load().
69
+ """
70
+
71
+ tl.static_assert(len(TMA.shape) == len(coords) + 2, "TMA must be a read-write ragged descriptor")
72
+
73
+ c0, c1, c2 = to_ragged_indices(batch_offset, batch_size, coords[ragged_dim])
74
+ data = TMA.load([c0, c1] + coords[:ragged_dim] + [c2] + coords[ragged_dim + 1:])
75
+ data = tl.reshape(data, data.shape[2:])
76
+ return data
77
+
78
+
79
+ @triton.jit
80
+ def store_ragged(TMA, batch_offset, batch_size, coords, data, ragged_dim: tl.constexpr = 0):
81
+ """
82
+ Write to a subarray T[batch_offset : batch_offset + batch_size] with
83
+ hardware bounds-checking, where writes outside the subarray are masked
84
+ correctly.
85
+
86
+ Coords should be an appropriately-sized list of integers, just like in
87
+ TMA.store().
88
+ """
89
+
90
+ c0, c1, c2 = to_ragged_indices(batch_offset, batch_size, coords[ragged_dim])
91
+ data = tl.reshape(data, [1, 1] + data.shape)
92
+ TMA.store([c0, c1] + coords[:ragged_dim] + [c2] + coords[ragged_dim + 1:], data)
@@ -0,0 +1,34 @@
1
+ from dataclasses import dataclass
2
+ from typing import List, Any
3
+ from triton._utils import validate_block_shape
4
+
5
+
6
+ @dataclass
7
+ class TensorDescriptor:
8
+ base: Any
9
+ shape: List[int]
10
+ strides: List[int]
11
+ block_shape: List[int]
12
+ padding: str = "zero"
13
+
14
+ def __post_init__(self):
15
+ rank = len(self.shape)
16
+ assert len(self.strides) == rank, f"rank mismatch: {self}"
17
+ assert len(self.block_shape) == rank, f"rank mismatch: {self}"
18
+ assert rank > 0, "rank must not be zero"
19
+ assert rank <= 5, "rank cannot be more than 5"
20
+ ty = type(self.base)
21
+ if ty.__name__ not in ("FakeTensor", "FunctionalTensor"):
22
+ assert self.base.data_ptr() % 16 == 0, "base must be 16-byte aligned"
23
+ validate_block_shape(self.block_shape)
24
+ elem_bytes = self.base.dtype.itemsize
25
+ for stride in self.strides[:-1]:
26
+ assert (stride * elem_bytes) % 16 == 0, "strides must be 16-byte aligned"
27
+ assert self.strides[-1] == 1, "Last dimension must be contiguous"
28
+ assert self.padding == "zero" or self.padding == "nan", "Illegal value for padding"
29
+ if self.padding == "nan":
30
+ assert self.base.dtype.is_floating_point, "Padding option `nan` is only supported for floating point tensors"
31
+
32
+ @staticmethod
33
+ def from_tensor(tensor: Any, block_shape: List[int], padding="zero"):
34
+ return TensorDescriptor(tensor, tensor.shape, tensor.stride(), block_shape, padding)