triton-windows 3.1.0.post17__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic. Click here for more details.

Files changed (248) hide show
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +73 -0
  3. triton/backends/__init__.py +50 -0
  4. triton/backends/amd/compiler.py +262 -0
  5. triton/backends/amd/driver.c +211 -0
  6. triton/backends/amd/driver.py +497 -0
  7. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +358 -0
  8. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +1031 -0
  9. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +1612 -0
  10. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +1337 -0
  11. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +293 -0
  12. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +32 -0
  13. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +174 -0
  14. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +829 -0
  15. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +1809 -0
  16. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +108 -0
  17. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +124 -0
  18. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +405 -0
  19. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +196 -0
  20. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +565 -0
  21. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +2226 -0
  22. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +104 -0
  23. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +244 -0
  24. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +494 -0
  25. triton/backends/amd/include/hip/amd_detail/concepts.hpp +30 -0
  26. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +133 -0
  27. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +218 -0
  28. triton/backends/amd/include/hip/amd_detail/grid_launch.h +67 -0
  29. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +50 -0
  30. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +26 -0
  31. triton/backends/amd/include/hip/amd_detail/helpers.hpp +137 -0
  32. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +1350 -0
  33. triton/backends/amd/include/hip/amd_detail/hip_assert.h +101 -0
  34. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +242 -0
  35. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +254 -0
  36. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +96 -0
  37. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +100 -0
  38. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +10169 -0
  39. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +77 -0
  40. triton/backends/amd/include/hip/amd_detail/host_defines.h +180 -0
  41. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +102 -0
  42. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +798 -0
  43. triton/backends/amd/include/hip/amd_detail/math_fwd.h +698 -0
  44. triton/backends/amd/include/hip/amd_detail/ockl_image.h +177 -0
  45. triton/backends/amd/include/hip/amd_detail/program_state.hpp +107 -0
  46. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +491 -0
  47. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +478 -0
  48. triton/backends/amd/include/hip/channel_descriptor.h +39 -0
  49. triton/backends/amd/include/hip/device_functions.h +38 -0
  50. triton/backends/amd/include/hip/driver_types.h +468 -0
  51. triton/backends/amd/include/hip/hip_bf16.h +36 -0
  52. triton/backends/amd/include/hip/hip_bfloat16.h +44 -0
  53. triton/backends/amd/include/hip/hip_common.h +100 -0
  54. triton/backends/amd/include/hip/hip_complex.h +38 -0
  55. triton/backends/amd/include/hip/hip_cooperative_groups.h +46 -0
  56. triton/backends/amd/include/hip/hip_deprecated.h +95 -0
  57. triton/backends/amd/include/hip/hip_ext.h +159 -0
  58. triton/backends/amd/include/hip/hip_fp16.h +36 -0
  59. triton/backends/amd/include/hip/hip_gl_interop.h +32 -0
  60. triton/backends/amd/include/hip/hip_hcc.h +24 -0
  61. triton/backends/amd/include/hip/hip_math_constants.h +36 -0
  62. triton/backends/amd/include/hip/hip_profile.h +27 -0
  63. triton/backends/amd/include/hip/hip_runtime.h +75 -0
  64. triton/backends/amd/include/hip/hip_runtime_api.h +8919 -0
  65. triton/backends/amd/include/hip/hip_texture_types.h +29 -0
  66. triton/backends/amd/include/hip/hip_vector_types.h +41 -0
  67. triton/backends/amd/include/hip/hip_version.h +17 -0
  68. triton/backends/amd/include/hip/hiprtc.h +421 -0
  69. triton/backends/amd/include/hip/library_types.h +78 -0
  70. triton/backends/amd/include/hip/math_functions.h +42 -0
  71. triton/backends/amd/include/hip/surface_types.h +63 -0
  72. triton/backends/amd/include/hip/texture_types.h +194 -0
  73. triton/backends/amd/include/hsa/Brig.h +1131 -0
  74. triton/backends/amd/include/hsa/amd_hsa_common.h +91 -0
  75. triton/backends/amd/include/hsa/amd_hsa_elf.h +435 -0
  76. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +269 -0
  77. triton/backends/amd/include/hsa/amd_hsa_queue.h +109 -0
  78. triton/backends/amd/include/hsa/amd_hsa_signal.h +80 -0
  79. triton/backends/amd/include/hsa/hsa.h +5729 -0
  80. triton/backends/amd/include/hsa/hsa_amd_tool.h +91 -0
  81. triton/backends/amd/include/hsa/hsa_api_trace.h +566 -0
  82. triton/backends/amd/include/hsa/hsa_ext_amd.h +3090 -0
  83. triton/backends/amd/include/hsa/hsa_ext_finalize.h +531 -0
  84. triton/backends/amd/include/hsa/hsa_ext_image.h +1454 -0
  85. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +488 -0
  86. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +667 -0
  87. triton/backends/amd/include/roctracer/ext/prof_protocol.h +107 -0
  88. triton/backends/amd/include/roctracer/hip_ostream_ops.h +4435 -0
  89. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +1467 -0
  90. triton/backends/amd/include/roctracer/hsa_prof_str.h +3027 -0
  91. triton/backends/amd/include/roctracer/roctracer.h +779 -0
  92. triton/backends/amd/include/roctracer/roctracer_ext.h +81 -0
  93. triton/backends/amd/include/roctracer/roctracer_hcc.h +24 -0
  94. triton/backends/amd/include/roctracer/roctracer_hip.h +37 -0
  95. triton/backends/amd/include/roctracer/roctracer_hsa.h +112 -0
  96. triton/backends/amd/include/roctracer/roctracer_plugin.h +137 -0
  97. triton/backends/amd/include/roctracer/roctracer_roctx.h +67 -0
  98. triton/backends/amd/include/roctracer/roctx.h +229 -0
  99. triton/backends/amd/lib/ockl.bc +0 -0
  100. triton/backends/amd/lib/ocml.bc +0 -0
  101. triton/backends/compiler.py +76 -0
  102. triton/backends/driver.py +34 -0
  103. triton/backends/nvidia/__init__.py +0 -0
  104. triton/backends/nvidia/bin/ptxas.exe +0 -0
  105. triton/backends/nvidia/compiler.py +347 -0
  106. triton/backends/nvidia/driver.c +451 -0
  107. triton/backends/nvidia/driver.py +430 -0
  108. triton/backends/nvidia/include/cuda.h +24359 -0
  109. triton/backends/nvidia/lib/libdevice.10.bc +0 -0
  110. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  111. triton/compiler/__init__.py +4 -0
  112. triton/compiler/code_generator.py +1302 -0
  113. triton/compiler/compiler.py +416 -0
  114. triton/compiler/errors.py +51 -0
  115. triton/compiler/make_launcher.py +0 -0
  116. triton/errors.py +5 -0
  117. triton/language/__init__.py +284 -0
  118. triton/language/core.py +2621 -0
  119. triton/language/extra/__init__.py +4 -0
  120. triton/language/extra/cuda/__init__.py +8 -0
  121. triton/language/extra/cuda/libdevice.py +1629 -0
  122. triton/language/extra/cuda/utils.py +109 -0
  123. triton/language/extra/hip/__init__.py +3 -0
  124. triton/language/extra/hip/libdevice.py +468 -0
  125. triton/language/extra/libdevice.py +1213 -0
  126. triton/language/math.py +250 -0
  127. triton/language/random.py +207 -0
  128. triton/language/semantic.py +1621 -0
  129. triton/language/standard.py +441 -0
  130. triton/ops/__init__.py +7 -0
  131. triton/ops/blocksparse/__init__.py +7 -0
  132. triton/ops/blocksparse/matmul.py +432 -0
  133. triton/ops/blocksparse/softmax.py +228 -0
  134. triton/ops/cross_entropy.py +96 -0
  135. triton/ops/flash_attention.py +466 -0
  136. triton/ops/matmul.py +219 -0
  137. triton/ops/matmul_perf_model.py +171 -0
  138. triton/runtime/__init__.py +23 -0
  139. triton/runtime/autotuner.py +361 -0
  140. triton/runtime/build.py +129 -0
  141. triton/runtime/cache.py +289 -0
  142. triton/runtime/driver.py +60 -0
  143. triton/runtime/errors.py +26 -0
  144. triton/runtime/interpreter.py +1127 -0
  145. triton/runtime/jit.py +956 -0
  146. triton/runtime/tcc/include/_mingw.h +170 -0
  147. triton/runtime/tcc/include/assert.h +57 -0
  148. triton/runtime/tcc/include/conio.h +409 -0
  149. triton/runtime/tcc/include/ctype.h +281 -0
  150. triton/runtime/tcc/include/dir.h +31 -0
  151. triton/runtime/tcc/include/direct.h +68 -0
  152. triton/runtime/tcc/include/dirent.h +135 -0
  153. triton/runtime/tcc/include/dos.h +55 -0
  154. triton/runtime/tcc/include/errno.h +75 -0
  155. triton/runtime/tcc/include/excpt.h +123 -0
  156. triton/runtime/tcc/include/fcntl.h +52 -0
  157. triton/runtime/tcc/include/fenv.h +108 -0
  158. triton/runtime/tcc/include/float.h +57 -0
  159. triton/runtime/tcc/include/inttypes.h +297 -0
  160. triton/runtime/tcc/include/io.h +418 -0
  161. triton/runtime/tcc/include/limits.h +111 -0
  162. triton/runtime/tcc/include/locale.h +91 -0
  163. triton/runtime/tcc/include/malloc.h +181 -0
  164. triton/runtime/tcc/include/math.h +737 -0
  165. triton/runtime/tcc/include/mem.h +13 -0
  166. triton/runtime/tcc/include/memory.h +40 -0
  167. triton/runtime/tcc/include/process.h +176 -0
  168. triton/runtime/tcc/include/sec_api/conio_s.h +42 -0
  169. triton/runtime/tcc/include/sec_api/crtdbg_s.h +19 -0
  170. triton/runtime/tcc/include/sec_api/io_s.h +33 -0
  171. triton/runtime/tcc/include/sec_api/mbstring_s.h +52 -0
  172. triton/runtime/tcc/include/sec_api/search_s.h +25 -0
  173. triton/runtime/tcc/include/sec_api/stdio_s.h +145 -0
  174. triton/runtime/tcc/include/sec_api/stdlib_s.h +67 -0
  175. triton/runtime/tcc/include/sec_api/stralign_s.h +30 -0
  176. triton/runtime/tcc/include/sec_api/string_s.h +41 -0
  177. triton/runtime/tcc/include/sec_api/sys/timeb_s.h +34 -0
  178. triton/runtime/tcc/include/sec_api/tchar_s.h +266 -0
  179. triton/runtime/tcc/include/sec_api/time_s.h +61 -0
  180. triton/runtime/tcc/include/sec_api/wchar_s.h +128 -0
  181. triton/runtime/tcc/include/setjmp.h +160 -0
  182. triton/runtime/tcc/include/share.h +28 -0
  183. triton/runtime/tcc/include/signal.h +63 -0
  184. triton/runtime/tcc/include/stdarg.h +79 -0
  185. triton/runtime/tcc/include/stdbool.h +11 -0
  186. triton/runtime/tcc/include/stddef.h +54 -0
  187. triton/runtime/tcc/include/stdint.h +212 -0
  188. triton/runtime/tcc/include/stdio.h +429 -0
  189. triton/runtime/tcc/include/stdlib.h +580 -0
  190. triton/runtime/tcc/include/string.h +164 -0
  191. triton/runtime/tcc/include/sys/fcntl.h +13 -0
  192. triton/runtime/tcc/include/sys/file.h +14 -0
  193. triton/runtime/tcc/include/sys/locking.h +30 -0
  194. triton/runtime/tcc/include/sys/stat.h +290 -0
  195. triton/runtime/tcc/include/sys/time.h +69 -0
  196. triton/runtime/tcc/include/sys/timeb.h +133 -0
  197. triton/runtime/tcc/include/sys/types.h +118 -0
  198. triton/runtime/tcc/include/sys/unistd.h +14 -0
  199. triton/runtime/tcc/include/sys/utime.h +146 -0
  200. triton/runtime/tcc/include/tcc/tcc_libm.h +201 -0
  201. triton/runtime/tcc/include/tcclib.h +80 -0
  202. triton/runtime/tcc/include/tchar.h +1102 -0
  203. triton/runtime/tcc/include/time.h +287 -0
  204. triton/runtime/tcc/include/vadefs.h +11 -0
  205. triton/runtime/tcc/include/values.h +4 -0
  206. triton/runtime/tcc/include/varargs.h +12 -0
  207. triton/runtime/tcc/include/wchar.h +873 -0
  208. triton/runtime/tcc/include/wctype.h +172 -0
  209. triton/runtime/tcc/include/winapi/basetsd.h +149 -0
  210. triton/runtime/tcc/include/winapi/basetyps.h +85 -0
  211. triton/runtime/tcc/include/winapi/guiddef.h +156 -0
  212. triton/runtime/tcc/include/winapi/poppack.h +8 -0
  213. triton/runtime/tcc/include/winapi/pshpack1.h +8 -0
  214. triton/runtime/tcc/include/winapi/pshpack2.h +8 -0
  215. triton/runtime/tcc/include/winapi/pshpack4.h +8 -0
  216. triton/runtime/tcc/include/winapi/pshpack8.h +8 -0
  217. triton/runtime/tcc/include/winapi/winbase.h +2951 -0
  218. triton/runtime/tcc/include/winapi/wincon.h +301 -0
  219. triton/runtime/tcc/include/winapi/windef.h +293 -0
  220. triton/runtime/tcc/include/winapi/windows.h +127 -0
  221. triton/runtime/tcc/include/winapi/winerror.h +3166 -0
  222. triton/runtime/tcc/include/winapi/wingdi.h +4080 -0
  223. triton/runtime/tcc/include/winapi/winnt.h +5835 -0
  224. triton/runtime/tcc/include/winapi/winreg.h +272 -0
  225. triton/runtime/tcc/include/winapi/winuser.h +5651 -0
  226. triton/runtime/tcc/include/winapi/winver.h +160 -0
  227. triton/runtime/tcc/lib/cuda.def +697 -0
  228. triton/runtime/tcc/lib/gdi32.def +337 -0
  229. triton/runtime/tcc/lib/kernel32.def +770 -0
  230. triton/runtime/tcc/lib/libtcc1-64.a +0 -0
  231. triton/runtime/tcc/lib/msvcrt.def +1399 -0
  232. triton/runtime/tcc/lib/python3.def +810 -0
  233. triton/runtime/tcc/lib/user32.def +658 -0
  234. triton/runtime/tcc/libtcc.dll +0 -0
  235. triton/runtime/tcc/tcc.exe +0 -0
  236. triton/testing.py +496 -0
  237. triton/tools/__init__.py +0 -0
  238. triton/tools/build_extern.py +365 -0
  239. triton/tools/compile.c +67 -0
  240. triton/tools/compile.h +14 -0
  241. triton/tools/compile.py +145 -0
  242. triton/tools/disasm.py +142 -0
  243. triton/tools/link.py +322 -0
  244. triton/windows_utils.py +373 -0
  245. triton_windows-3.1.0.post17.dist-info/METADATA +41 -0
  246. triton_windows-3.1.0.post17.dist-info/RECORD +248 -0
  247. triton_windows-3.1.0.post17.dist-info/WHEEL +5 -0
  248. triton_windows-3.1.0.post17.dist-info/top_level.txt +14 -0
Binary file
triton/__init__.py ADDED
@@ -0,0 +1,73 @@
1
+ """isort:skip_file"""
2
+ __version__ = '3.1.0'
3
+
4
+ # ---------------------------------------
5
+ # Note: import order is significant here.
6
+
7
+ # submodules
8
+ from .runtime import (
9
+ autotune,
10
+ Config,
11
+ heuristics,
12
+ JITFunction,
13
+ KernelInterface,
14
+ reinterpret,
15
+ TensorWrapper,
16
+ OutOfResources,
17
+ InterpreterError,
18
+ MockTensor,
19
+ )
20
+ from .runtime.jit import jit
21
+ from .compiler import compile, CompilationError
22
+ from .errors import TritonError
23
+
24
+ from . import language
25
+ from . import testing
26
+ from . import tools
27
+
28
# Public names exported by ``from triton import *``.
#
# Every entry must be either an attribute bound in this module or an importable
# submodule of the package; otherwise the star-import raises AttributeError.
# The former "impl" entry was dropped because this package ships no
# ``triton/impl`` module and nothing in this file binds that name.
# "ops" and "runtime" are submodules of the package rather than names imported
# above; "cdiv" and "next_power_of_2" are defined further down in this file.
__all__ = [
    "autotune",
    "cdiv",
    "CompilationError",
    "compile",
    "Config",
    "heuristics",
    "InterpreterError",
    "jit",
    "JITFunction",
    "KernelInterface",
    "language",
    "MockTensor",
    "next_power_of_2",
    "ops",
    "OutOfResources",
    "reinterpret",
    "runtime",
    "TensorWrapper",
    "TritonError",
    "testing",
    "tools",
]
52
+
53
+ # -------------------------------------
54
+ # misc. utilities that don't fit well
55
+ # into any specific module
56
+ # -------------------------------------
57
+
58
+
59
def cdiv(x: int, y: int):
    """Divide ``x`` by ``y`` using floor division after biasing ``x`` by ``y - 1``.

    For a positive divisor this is the ceiling of ``x / y``.
    """
    biased = x + y - 1
    return biased // y
61
+
62
+
63
def next_power_of_2(n: int):
    """Return the smallest power of 2 greater than or equal to n.

    Works for integers of any magnitude. The previous implementation OR-ed
    ``n`` with itself shifted by 1, 2, 4, 8, 16 and 32 bits, which only fills
    63 bit positions below the top bit and therefore produced values that are
    not powers of two for inputs above ``2**64``.

    Args:
        n: An integer (anything implementing ``__index__``).

    Returns:
        The smallest power of two ``>= n`` as a Python ``int``. For ``n <= 0``
        the result is ``0``, matching what the shift cascade returned for
        non-positive inputs in the 64-bit range.

    Raises:
        TypeError: If ``n`` is not an integer-like object.
    """
    import operator  # function-scope import keeps the module's import block untouched

    n = operator.index(n)
    if n <= 0:
        return 0
    # (n - 1).bit_length() is the exponent of the first power of two >= n.
    return 1 << (n - 1).bit_length()
@@ -0,0 +1,50 @@
1
+ import os
2
+ import importlib.util
3
+ import inspect
4
+ from dataclasses import dataclass
5
+ from .driver import DriverBase
6
+ from .compiler import BaseBackend
7
+
8
+
9
+ def _load_module(name, path):
10
+ spec = importlib.util.spec_from_file_location(name[:-3], path)
11
+ module = importlib.util.module_from_spec(spec)
12
+ spec.loader.exec_module(module)
13
+ return module
14
+
15
+
16
+ def _find_concrete_subclasses(module, base_class):
17
+ ret = []
18
+ for attr_name in dir(module):
19
+ attr = getattr(module, attr_name)
20
+ if isinstance(attr, type) and issubclass(attr, base_class) and not inspect.isabstract(attr):
21
+ ret.append(attr)
22
+ if len(ret) == 0:
23
+ raise RuntimeError(f"Found 0 concrete subclasses of {base_class} in {module}: {ret}")
24
+ if len(ret) > 1:
25
+ raise RuntimeError(f"Found >1 concrete subclasses of {base_class} in {module}: {ret}")
26
+ return ret[0]
27
+
28
+
29
@dataclass(frozen=True)
class Backend:
    """Immutable pairing of the compiler and driver classes found for one backend."""

    # Concrete BaseBackend subclass taken from the backend's compiler.py.
    compiler: BaseBackend = None
    # Concrete DriverBase subclass taken from the backend's driver.py.
    driver: DriverBase = None
33
+
34
+
35
def _discover_backends():
    """Build a mapping from backend directory name to its ``Backend`` entry.

    Each sub-directory next to this file whose name does not start with ``__``
    is treated as a backend: its ``compiler.py`` and ``driver.py`` are loaded
    and the single concrete ``BaseBackend`` / ``DriverBase`` subclass in each
    is recorded.
    """
    package_dir = os.path.dirname(__file__)
    discovered = {}
    for entry in os.listdir(package_dir):
        backend_dir = os.path.join(package_dir, entry)
        # Skip plain files and dunder directories such as __pycache__.
        if not os.path.isdir(backend_dir) or entry.startswith('__'):
            continue
        compiler_module = _load_module(entry, os.path.join(backend_dir, 'compiler.py'))
        driver_module = _load_module(entry, os.path.join(backend_dir, 'driver.py'))
        discovered[entry] = Backend(_find_concrete_subclasses(compiler_module, BaseBackend),
                                    _find_concrete_subclasses(driver_module, DriverBase))
    return discovered


# Populated once, when this package is imported.
backends = _discover_backends()
@@ -0,0 +1,262 @@
1
+ from triton.backends.compiler import BaseBackend, GPUTarget
2
+ from triton._C.libtriton import ir, passes, llvm, amd
3
+ from dataclasses import dataclass
4
+ from typing import Any, Tuple
5
+ import hashlib
6
+ import tempfile
7
+ import os
8
+ import re
9
+ import subprocess
10
+ import functools
11
+ from pathlib import Path
12
+
13
+
14
@dataclass(frozen=True)
class HIPOptions:
    """Frozen bag of compilation options for the HIP backend.

    ``__post_init__`` derives ``warp_size`` from ``arch`` and appends the
    bundled ``ocml`` / ``ockl`` bitcode libraries to ``extern_libs``, which is
    stored as a tuple of ``(name, path)`` pairs.

    NOTE(review): ``arch`` defaults to ``None`` but ``__post_init__`` performs
    substring checks on it, so an architecture string has to be supplied.
    """
    num_warps: int = 4
    waves_per_eu: int = 1
    num_stages: int = 0
    num_ctas: int = 1
    extern_libs: dict = None
    cluster_dims: tuple = (1, 1, 1)
    debug: bool = False
    arch: str = None
    allow_fp8e4nv: bool = False
    allow_fp8e4b15: bool = False
    default_dot_input_precision: str = "ieee"
    allowed_dot_input_precisions: Tuple[str] = ("ieee", )
    enable_fp_fusion: bool = True
    matrix_instr_nonkdim: int = 0
    kpack: int = 1
    allow_flush_denorm: bool = False
    max_num_imprecise_acc_default: int = 0
    backend_name: str = 'hip'

    def __post_init__(self):
        # The dataclass is frozen, so derived values are stored through
        # object.__setattr__ instead of plain attribute assignment.
        bundled_lib_dir = Path(__file__).parent / 'lib'
        merged_libs = dict(self.extern_libs) if self.extern_libs is not None else {}
        # Ignore user-defined warp size for gfx9
        uses_wave32 = 'gfx10' in self.arch or 'gfx11' in self.arch
        object.__setattr__(self, 'warp_size', 32 if uses_wave32 else 64)
        # The bundled device libraries override any same-named user entries.
        for lib_name in ("ocml", "ockl"):
            merged_libs[lib_name] = str(bundled_lib_dir / f'{lib_name}.bc')
        object.__setattr__(self, 'extern_libs', tuple(merged_libs.items()))
        is_power_of_two = self.num_warps > 0 and (self.num_warps & (self.num_warps - 1)) == 0
        assert is_power_of_two, \
            "num_warps must be a power of 2"

    def hash(self):
        """Return the SHA-256 hex digest of all ``name-value`` pairs held by this instance."""
        parts = (f'{name}-{val}' for name, val in self.__dict__.items())
        key = '_'.join(parts)
        return hashlib.sha256(key.encode("utf-8")).hexdigest()
51
+
52
+
53
class HIPBackend(BaseBackend):
    """Compiler backend for targets whose ``backend`` is ``'hip'``.

    Registers the stage pipeline ttir -> ttgir -> llir -> amdgcn -> hsaco
    (see ``add_stages``) and produces binaries with the ``hsaco`` extension.
    """

    @staticmethod
    def supports_target(target: GPUTarget):
        """Return True when ``target.backend`` is ``'hip'``."""
        return target.backend == 'hip'

    def __init__(self, target: GPUTarget) -> None:
        super().__init__(target)
        assert isinstance(target.arch, str)
        self.binary_ext = "hsaco"

    def parse_options(self, opts) -> Any:
        """Build ``HIPOptions`` from ``opts``, keeping only keys that are HIPOptions fields.

        ``arch`` is seeded from the backend's target and may be overridden by ``opts``.
        """
        args = {'arch': self.target.arch}
        args.update({k: opts[k] for k in HIPOptions.__dataclass_fields__.keys() if k in opts})
        return HIPOptions(**args)

    def pack_metadata(self, metadata):
        """Return (num_warps, num_ctas, shared, cluster_dims[0..2]) as a flat tuple."""
        return (
            metadata.num_warps,
            metadata.num_ctas,
            metadata.shared,
            metadata.cluster_dims[0],
            metadata.cluster_dims[1],
            metadata.cluster_dims[2],
        )

    def get_codegen_implementation(self):
        """Return the backend-specific codegen function table (empty for this backend)."""
        codegen_fns = dict()
        return codegen_fns

    def load_dialects(self, ctx):
        """Load the AMD dialects into ``ctx``."""
        amd.load_dialects(ctx)

    @staticmethod
    def path_to_rocm_lld():
        """Locate the ``ld.lld`` linker and return its path.

        Search order: the ``TRITON_HIP_LLD_PATH`` environment variable, an
        ``llvm/bin/ld.lld`` next to this file, ``/opt/rocm/llvm/bin/ld.lld``,
        then ``/usr/bin/ld.lld``.

        Raises:
            Exception: If none of the candidates is an existing file.
        """
        # Check env path for ld.lld
        lld_env_path = os.getenv("TRITON_HIP_LLD_PATH")
        if lld_env_path is not None:
            lld = Path(lld_env_path)
            if lld.is_file():
                return lld
        # Check backend for ld.lld (used for pytorch wheels)
        lld = Path(__file__).parent / "llvm/bin/ld.lld"
        if lld.is_file():
            return lld
        lld = Path("/opt/rocm/llvm/bin/ld.lld")
        if lld.is_file():
            return lld
        lld = Path("/usr/bin/ld.lld")
        if lld.is_file():
            return lld
        raise Exception("ROCm linker /opt/rocm/llvm/bin/ld.lld not found")

    @staticmethod
    def make_ttir(mod, metadata, options):
        """Run the Triton-IR pass pipeline on ``mod`` and return ``mod``."""
        pm = ir.pass_manager(mod.context)
        pm.enable_debug()
        passes.common.add_inliner(pm)
        passes.ttir.add_rewrite_tensor_pointer(pm)
        passes.ttir.add_combine(pm)
        passes.common.add_canonicalizer(pm)
        passes.ttir.add_reorder_broadcast(pm)
        passes.common.add_cse(pm)
        passes.common.add_licm(pm)
        passes.common.add_symbol_dce(pm)
        pm.run(mod)
        return mod

    @staticmethod
    def make_ttgir(mod, metadata, options):
        """Convert ``mod`` to TritonGPU IR, run the TritonGPU passes, and return ``mod``.

        Two pass managers are run back to back: the first performs only the
        conversion, the second the optimisation passes.
        """
        pm = ir.pass_manager(mod.context)
        pm.enable_debug()
        passes.ttir.add_convert_to_ttgpuir(pm, f"hip:{options.arch}", options.num_warps, options.warp_size,
                                           options.num_ctas)
        pm.run(mod)
        pm = ir.pass_manager(mod.context)
        pm.enable_debug()
        passes.ttgpuir.add_coalesce(pm)
        passes.ttgpuir.add_remove_layout_conversions(pm)
        passes.ttgpuir.add_optimize_thread_locality(pm)
        amd.passes.ttgpuir.add_accelerate_matmul(pm, options.arch, options.matrix_instr_nonkdim, options.kpack)
        passes.ttgpuir.add_remove_layout_conversions(pm)
        amd.passes.ttgpuir.add_optimize_epilogue(pm)
        passes.ttgpuir.add_optimize_dot_operands(pm, True)
        if options.num_stages == 0 and amd.has_matrix_core_feature(options.arch):
            amd.passes.ttgpuir.add_stream_pipeline(pm)
            passes.common.add_canonicalizer(pm)
        passes.ttgpuir.add_optimize_dot_operands(pm, True)
        passes.ttgpuir.add_remove_layout_conversions(pm)
        passes.ttgpuir.add_reduce_data_duplication(pm)
        if options.num_stages != 0:
            amd.passes.ttgpuir.add_reorder_instructions(pm)
        passes.common.add_cse(pm)
        passes.common.add_symbol_dce(pm)
        pm.run(mod)
        return mod

    @staticmethod
    def make_llir(src, metadata, options):
        """Lower TritonGPU IR to LLVM IR and return it as a string.

        Also stores the module's ``triton_gpu.shared`` integer attribute in
        ``metadata["shared"]``.
        """
        mod = src
        # TritonGPU -> LLVM-IR (MLIR)
        pm = ir.pass_manager(mod.context)
        pm.enable_debug()
        amd.passes.ttgpuir.add_decompose_unsupported_conversions(pm, options.arch)
        passes.convert.add_scf_to_cf(pm)
        passes.convert.add_index_to_llvmir(pm)

        passes.ttgpuir.add_allocate_shared_memory(pm)
        ## __HIP_FTZ is used to control the denorm flushing behavior of exp2 op as follows:
        ## 1. If __HIP_FTZ = 1, exp2 flushes denorms in input and output regardless
        ##    of the value of kernel arg `allow_flush_denorm`.
        ## 2. If __HIP_FTZ = 0, whether exp2 flushes denorms in input and output
        ##    depends on the value of kernel arg `allow_flush_denorm`.
        ## 3. __HIP_FTZ is default to 1 and not exposed as a kernel argument.
        ##    For now it is used as a controller for developers only.
        __HIP_FTZ = True
        amd.passes.ttgpuir.add_to_llvmir(pm, options.arch, __HIP_FTZ)
        passes.common.add_canonicalizer(pm)
        passes.common.add_cse(pm)

        passes.convert.add_cf_to_llvmir(pm)
        passes.convert.add_arith_to_llvmir(pm)
        passes.common.add_canonicalizer(pm)
        passes.common.add_cse(pm)
        passes.common.add_symbol_dce(pm)
        if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0":
            passes.llvmir.add_di_scope(pm)
        # This pass (`add_builtin_func_to_llvmir`) serves as a temporary workaround to address the issue of excessive basic block
        # count caused by predicated loads/stores. In certain kernels, the addition of these blocks can cause the MLIR
        # canonicalizer to never finish when attempting to merge blocks. The permanent solution under consideration
        # involves using MUBUF instructions that have built-in out-of-bounds checks, which would eliminate the need
        # for conditional branching around memory accesses.
        amd.passes.ttgpuir.add_builtin_func_to_llvmir(pm)
        pm.run(mod)

        # LLVM-IR (MLIR) -> LLVM-IR (LLVM)
        llvm.init_targets()
        context = llvm.context()
        llvm_mod = llvm.to_module(mod, context)

        # Set various control constants on the LLVM module so that device
        # libraries can resolve references to them.
        amd.set_isa_version(llvm_mod, options.arch)
        amd.set_abi_version(llvm_mod, 400)
        amd.set_bool_control_constant(llvm_mod, "__oclc_finite_only_opt", False)
        amd.set_bool_control_constant(llvm_mod, "__oclc_correctly_rounded_sqrt32", True)
        amd.set_bool_control_constant(llvm_mod, "__oclc_unsafe_math_opt", False)
        amd.set_bool_control_constant(llvm_mod, "__oclc_wavefrontsize64", options.warp_size == 64)

        # Set kernel attributes first given this may affect later optimizations.
        fns = [fn for fn in llvm_mod.get_functions() if not fn.is_declaration()]
        # The public kernel should be kernel 0.
        fns[0].set_calling_conv(amd.CALLING_CONV_AMDGPU_KERNEL)
        fns[0].add_fn_attr("amdgpu-flat-work-group-size", f"1,{options.num_warps*options.warp_size}")
        fns[0].add_fn_attr("amdgpu-waves-per-eu", f"{options.waves_per_eu}")
        denormal_mode = "preserve-sign" if options.allow_flush_denorm else "ieee"
        fns[0].add_fn_attr("denormal-fp-math-f32", denormal_mode)

        if options.extern_libs:
            paths = [path for (name, path) in options.extern_libs if amd.need_extern_lib(llvm_mod, name)]
            llvm.link_extern_libs(llvm_mod, paths)

        llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3, amd.TARGET_TRIPLE)

        # Get some metadata
        metadata["shared"] = src.get_int_attr("triton_gpu.shared")

        amd.cleanup_bitcode_metadata(llvm_mod)
        return str(llvm_mod)

    @staticmethod
    def make_amdgcn(src, metadata, options):
        """Translate LLVM IR text to AMDGCN assembly and record the kernel name.

        The single ``amdgpu_kernel`` definition found in ``src`` is stored in
        ``metadata["name"]``. Setting ``AMDGCN_ENABLE_DUMP=1`` prints the
        generated assembly.
        """
        # Find kernel names (there should only be one)
        # We get the name at the last possible step to accommodate `triton.compile`
        # on user-provided LLVM
        names = re.findall(r"define amdgpu_kernel void @([a-zA-Z_][a-zA-Z0-9_]*)", src)
        assert len(names) == 1
        metadata["name"] = names[0]
        # llvm -> hsaco
        amdgcn = llvm.translate_to_asm(src, amd.TARGET_TRIPLE, options.arch, '', [], options.enable_fp_fusion, False)
        if os.environ.get("AMDGCN_ENABLE_DUMP", "0") == "1":
            print("// -----// AMDGCN Dump //----- //")
            print(amdgcn)
        return amdgcn

    @staticmethod
    def make_hsaco(src, metadata, options):
        """Assemble AMDGCN text and link it with ``ld.lld``; return the linked bytes.

        The linker input and output are ordinary files inside a temporary
        directory rather than ``NamedTemporaryFile`` objects re-opened by name:
        re-opening a still-open named temporary file is not possible on
        Windows, so the previous approach could not work on that platform.

        Raises:
            subprocess.CalledProcessError: If the linker exits with a non-zero status.
        """
        hsaco = amd.assemble_amdgcn(src, options.arch, '')

        rocm_path = HIPBackend.path_to_rocm_lld()
        with tempfile.TemporaryDirectory() as tmp_dir:
            obj_path = os.path.join(tmp_dir, 'kernel.o')
            out_path = os.path.join(tmp_dir, 'kernel.hsaco')
            with open(obj_path, 'wb') as fd_in:
                fd_in.write(hsaco)
            subprocess.check_call([rocm_path, '-flavor', 'gnu', '-shared', obj_path, '-o', out_path])
            with open(out_path, 'rb') as fd_out:
                ret = fd_out.read()
        return ret

    def add_stages(self, stages, options):
        """Register the five compilation stages in ``stages``, each bound to ``options``."""
        stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options)
        stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options)
        stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options)
        stages["amdgcn"] = lambda src, metadata: self.make_amdgcn(src, metadata, options)
        stages["hsaco"] = lambda src, metadata: self.make_hsaco(src, metadata, options)

    @functools.lru_cache()
    def hash(self):
        """Return a string built from the ``ld.lld --version`` output and the target.

        The result is memoised per instance by ``functools.lru_cache``.
        """
        version = subprocess.check_output([HIPBackend.path_to_rocm_lld(), "--version"], encoding='utf-8')
        return f'{version}-{self.target}'
@@ -0,0 +1,211 @@
1
+ #define __HIP_PLATFORM_AMD__
2
+ // clang-format off
3
+ // hip_deprecated.h needs definitions from hip_runtime.h.
4
+ #include <hip/hip_runtime.h>
5
+ #include <hip/hip_deprecated.h>
6
+ // clang-format on
7
+ #define PY_SSIZE_T_CLEAN
8
+ #include <Python.h>
9
+ #include <dlfcn.h>
10
+ #include <stdio.h>
11
+ #include <stdlib.h>
12
+
13
// The list of paths to search for the HIP runtime library. The caller Python
// code should substitute the search path placeholder below before this file
// is compiled, so the string literal must be kept exactly as written.
// initSymbolTable() passes each entry to dlopen().
static const char *hipLibSearchPaths[] = {"/*py_libhip_search_path*/"};
16
+
17
+ // The list of HIP dynamic library symbols and their signature we are interested
18
+ // in this file.
19
+ // |FOR_EACH_ERR_FN| is a macro to process APIs that return hipError_t;
20
+ // |FOR_EACH_STR_FN| is a macro to process APIs that return const char *.
21
+ //
22
+ // HIP 6.0 introduced an updated hipGetDeviceProperties API under a new symbol,
23
+ // hipGetDevicePropertiesR0600. However, the associated hipDeviceProp_t was
24
+ // directly updated with breaking changes to match hipGetDevicePropertiesR0600
25
+ // in the header file. We include the header file from HIP 6.0. So here if we
26
+ // use hipGetDeviceProperties together with hipDeviceProp_t we will use the
27
+ // old API with a new struct definition and mess up the interpretation.
28
+ //
29
+ // This is a known issue: https://github.com/ROCm/ROCm/issues/2728.
30
+ //
31
+ // For now explicitly defer to the old hipDeviceProp_t struct. This should work
32
+ // for both 5.x and 6.x. In the long term we need to switch to use
33
+ // hipGetProcAddress once available:
34
+ // https://github.com/ROCm/clr/commit/0479cdb3dd30ef58718cad44e424bd793c394cc0
35
// X-macro: invokes one of the two callbacks once per HIP entry point, passing
// the symbol name followed by that function's parameter list.
#define HIP_SYMBOL_LIST(FOR_EACH_ERR_FN, FOR_EACH_STR_FN)                      \
  FOR_EACH_STR_FN(hipGetErrorString, hipError_t hipError)                      \
  FOR_EACH_ERR_FN(hipGetDeviceProperties,                                      \
                  hipDeviceProp_tR0000 *prop, int deviceId)                    \
  FOR_EACH_ERR_FN(hipModuleLoadDataEx,                                         \
                  hipModule_t *module, const void *image,                      \
                  unsigned int numOptions, hipJitOption *options,              \
                  void **optionValues)                                         \
  FOR_EACH_ERR_FN(hipModuleGetFunction,                                        \
                  hipFunction_t *function, hipModule_t module,                 \
                  const char *kname)                                           \
  FOR_EACH_ERR_FN(hipFuncGetAttribute,                                         \
                  int *value, hipFunction_attribute attr,                      \
                  hipFunction_t function)
46
+
47
+ // The HIP symbol table for holding resolved dynamic library symbols.
48
+ struct HIPSymbolTable {
49
+ #define DEFINE_EACH_ERR_FIELD(hipSymbolName, ...) \
50
+ hipError_t (*hipSymbolName)(__VA_ARGS__);
51
+ #define DEFINE_EACH_STR_FIELD(hipSymbolName, ...) \
52
+ const char *(*hipSymbolName)(__VA_ARGS__);
53
+
54
+ HIP_SYMBOL_LIST(DEFINE_EACH_ERR_FIELD, DEFINE_EACH_STR_FIELD)
55
+ };
56
+
57
+ static struct HIPSymbolTable hipSymbolTable;
58
+
59
+ bool initSymbolTable() {
60
+ // Use the HIP runtime library loaded into the existing process if it exits.
61
+ void *lib = dlopen("libamdhip64.so", RTLD_NOLOAD);
62
+ if (lib) {
63
+ // printf("[triton] chosen loaded libamdhip64.so in the process\n");
64
+ }
65
+
66
+ // Otherwise, go through the list of search paths to dlopen the first HIP
67
+ // driver library.
68
+ if (!lib) {
69
+ int n = sizeof(hipLibSearchPaths) / sizeof(hipLibSearchPaths[0]);
70
+ for (int i = 0; i < n; ++i) {
71
+ void *handle = dlopen(hipLibSearchPaths[i], RTLD_LAZY | RTLD_LOCAL);
72
+ if (handle) {
73
+ lib = handle;
74
+ // printf("[triton] chosen %s\n", hipLibSearchPaths[i]);
75
+ }
76
+ }
77
+ }
78
+ if (!lib) {
79
+ PyErr_SetString(PyExc_RuntimeError, "cannot open libamdhip64.so");
80
+ return false;
81
+ }
82
+
83
+ // Resolve all symbols we are interested in.
84
+ dlerror(); // Clear existing errors
85
+ const char *error = NULL;
86
+ #define QUERY_EACH_FN(hipSymbolName, ...) \
87
+ *(void **)&hipSymbolTable.hipSymbolName = dlsym(lib, #hipSymbolName); \
88
+ error = dlerror(); \
89
+ if (error) { \
90
+ PyErr_SetString(PyExc_RuntimeError, \
91
+ "cannot query " #hipSymbolName " from libamdhip64.so"); \
92
+ dlclose(lib); \
93
+ return false; \
94
+ }
95
+
96
+ HIP_SYMBOL_LIST(QUERY_EACH_FN, QUERY_EACH_FN)
97
+
98
+ return true;
99
+ }
100
+
101
+ static inline void gpuAssert(hipError_t code, const char *file, int line) {
102
+ {
103
+ if (code != HIP_SUCCESS) {
104
+ {
105
+ const char *prefix = "Triton Error [HIP]: ";
106
+ const char *str = hipSymbolTable.hipGetErrorString(code);
107
+ char err[1024] = {0};
108
+ snprintf(err, 1024, "%s Code: %d, Messsage: %s", prefix, code, str);
109
+ PyGILState_STATE gil_state;
110
+ gil_state = PyGILState_Ensure();
111
+ PyErr_SetString(PyExc_RuntimeError, err);
112
+ PyGILState_Release(gil_state);
113
+ }
114
+ }
115
+ }
116
+ }
117
+
118
// Evaluates the HIP call |ans| through gpuAssert() and returns NULL from the
// enclosing function if a Python error is pending afterwards. Expands to a
// bare brace block, so it may be used with or without a trailing semicolon.
#define HIP_CHECK(ans)                                                         \
  {                                                                            \
    gpuAssert((ans), __FILE__, __LINE__);                                      \
    if (PyErr_Occurred()) {                                                    \
      return NULL;                                                             \
    }                                                                          \
  }
124
+
125
+ static PyObject *getDeviceProperties(PyObject *self, PyObject *args) {
126
+ int device_id;
127
+ if (!PyArg_ParseTuple(args, "i", &device_id))
128
+ return NULL;
129
+
130
+ hipDeviceProp_tR0000 props;
131
+ HIP_CHECK(hipSymbolTable.hipGetDeviceProperties(&props, device_id));
132
+
133
+ // create a struct to hold device properties
134
+ return Py_BuildValue(
135
+ "{s:i, s:i, s:i, s:i, s:i, s:i, s:s, s:i}", "max_shared_mem",
136
+ props.sharedMemPerBlock, "max_num_regs", props.regsPerBlock,
137
+ "multiprocessor_count", props.multiProcessorCount, "sm_clock_rate",
138
+ props.clockRate, "mem_clock_rate", props.memoryClockRate, "mem_bus_width",
139
+ props.memoryBusWidth, "arch", props.gcnArchName, "warpSize",
140
+ props.warpSize);
141
+ }
142
+
143
+ static PyObject *loadBinary(PyObject *self, PyObject *args) {
144
+ const char *name;
145
+ const char *data;
146
+ Py_ssize_t data_size;
147
+ int shared;
148
+ int device;
149
+ if (!PyArg_ParseTuple(args, "ss#ii", &name, &data, &data_size, &shared,
150
+ &device)) {
151
+ return NULL;
152
+ }
153
+
154
+ // set HIP options
155
+ hipJitOption opt[] = {hipJitOptionErrorLogBufferSizeBytes,
156
+ hipJitOptionErrorLogBuffer,
157
+ hipJitOptionInfoLogBufferSizeBytes,
158
+ hipJitOptionInfoLogBuffer, hipJitOptionLogVerbose};
159
+ const unsigned int errbufsize = 8192;
160
+ const unsigned int logbufsize = 8192;
161
+ char _err[errbufsize];
162
+ char _log[logbufsize];
163
+ void *optval[] = {(void *)(uintptr_t)errbufsize, (void *)_err,
164
+ (void *)(uintptr_t)logbufsize, (void *)_log, (void *)1};
165
+
166
+ // launch HIP Binary
167
+ hipModule_t mod;
168
+ hipFunction_t fun;
169
+ HIP_CHECK(hipSymbolTable.hipModuleLoadDataEx(&mod, data, 5, opt, optval))
170
+ HIP_CHECK(hipSymbolTable.hipModuleGetFunction(&fun, mod, name));
171
+
172
+ // get allocated registers and spilled registers from the function
173
+ int n_regs = 0;
174
+ int n_spills = 0;
175
+ hipSymbolTable.hipFuncGetAttribute(&n_regs, HIP_FUNC_ATTRIBUTE_NUM_REGS, fun);
176
+ hipSymbolTable.hipFuncGetAttribute(&n_spills,
177
+ HIP_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun);
178
+ n_spills /= 4;
179
+ if (PyErr_Occurred()) {
180
+ return NULL;
181
+ }
182
+ return Py_BuildValue("(KKii)", (uint64_t)mod, (uint64_t)fun, n_regs,
183
+ n_spills);
184
+ }
185
+
186
+ static PyMethodDef ModuleMethods[] = {
187
+ {"load_binary", loadBinary, METH_VARARGS,
188
+ "Load provided hsaco into HIP driver"},
189
+ {"get_device_properties", getDeviceProperties, METH_VARARGS,
190
+ "Get the properties for a given device"},
191
+ {NULL, NULL, 0, NULL} // sentinel
192
+ };
193
+
194
+ static struct PyModuleDef ModuleDef = {PyModuleDef_HEAD_INIT, "hip_utils",
195
+ NULL, // documentation
196
+ -1, // size
197
+ ModuleMethods};
198
+
199
+ PyMODINIT_FUNC PyInit_hip_utils(void) {
200
+ if (!initSymbolTable()) {
201
+ return NULL;
202
+ }
203
+
204
+ PyObject *m = PyModule_Create(&ModuleDef);
205
+ if (m == NULL) {
206
+ return NULL;
207
+ }
208
+ PyModule_AddFunctions(m, ModuleMethods);
209
+
210
+ return m;
211
+ }