triton-windows 3.3.1.post19__cp312-cp312-win_amd64.whl → 3.5.0.post21__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic. Click here for more details.

Files changed (225)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +11 -2
  3. triton/_filecheck.py +97 -0
  4. triton/_internal_testing.py +95 -18
  5. triton/_utils.py +112 -21
  6. triton/backends/__init__.py +20 -23
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +161 -119
  9. triton/backends/amd/driver.c +118 -46
  10. triton/backends/amd/driver.py +274 -96
  11. triton/backends/compiler.py +7 -21
  12. triton/backends/driver.py +13 -0
  13. triton/backends/nvidia/bin/ptxas.exe +0 -0
  14. triton/backends/nvidia/compiler.py +163 -106
  15. triton/backends/nvidia/driver.c +166 -101
  16. triton/backends/nvidia/driver.py +384 -202
  17. triton/compiler/__init__.py +5 -2
  18. triton/compiler/code_generator.py +439 -231
  19. triton/compiler/compiler.py +152 -84
  20. triton/experimental/__init__.py +0 -0
  21. triton/experimental/gluon/__init__.py +5 -0
  22. triton/experimental/gluon/_compiler.py +0 -0
  23. triton/experimental/gluon/_runtime.py +102 -0
  24. triton/experimental/gluon/language/__init__.py +119 -0
  25. triton/experimental/gluon/language/_core.py +490 -0
  26. triton/experimental/gluon/language/_layouts.py +583 -0
  27. triton/experimental/gluon/language/_math.py +20 -0
  28. triton/experimental/gluon/language/_semantic.py +380 -0
  29. triton/experimental/gluon/language/_standard.py +80 -0
  30. triton/experimental/gluon/language/amd/__init__.py +4 -0
  31. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  32. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  33. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  34. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  35. triton/experimental/gluon/language/extra/__init__.py +3 -0
  36. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  37. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  38. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  39. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  40. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
  41. triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
  42. triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
  43. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
  44. triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
  45. triton/experimental/gluon/nvidia/__init__.py +4 -0
  46. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  47. triton/experimental/gluon/nvidia/hopper.py +45 -0
  48. triton/knobs.py +546 -0
  49. triton/language/__init__.py +50 -19
  50. triton/language/core.py +909 -572
  51. triton/language/extra/cuda/__init__.py +10 -7
  52. triton/language/extra/cuda/gdc.py +42 -0
  53. triton/language/extra/cuda/libdevice.py +394 -394
  54. triton/language/extra/cuda/utils.py +21 -21
  55. triton/language/extra/hip/__init__.py +3 -1
  56. triton/language/extra/hip/libdevice.py +120 -104
  57. triton/language/extra/hip/utils.py +35 -0
  58. triton/language/extra/libdevice.py +4 -0
  59. triton/language/math.py +65 -66
  60. triton/language/random.py +12 -2
  61. triton/language/semantic.py +1757 -1768
  62. triton/language/standard.py +127 -62
  63. triton/language/target_info.py +54 -0
  64. triton/runtime/_allocation.py +15 -3
  65. triton/runtime/_async_compile.py +55 -0
  66. triton/runtime/autotuner.py +117 -60
  67. triton/runtime/build.py +83 -17
  68. triton/runtime/cache.py +61 -47
  69. triton/runtime/driver.py +25 -47
  70. triton/runtime/interpreter.py +95 -50
  71. triton/runtime/jit.py +445 -248
  72. triton/runtime/tcc/include/_mingw.h +8 -10
  73. triton/runtime/tcc/include/assert.h +5 -0
  74. triton/runtime/tcc/include/errno.h +1 -1
  75. triton/runtime/tcc/include/float.h +21 -3
  76. triton/runtime/tcc/include/iso646.h +36 -0
  77. triton/runtime/tcc/include/limits.h +5 -0
  78. triton/runtime/tcc/include/malloc.h +2 -2
  79. triton/runtime/tcc/include/math.h +21 -261
  80. triton/runtime/tcc/include/stdalign.h +16 -0
  81. triton/runtime/tcc/include/stdarg.h +5 -70
  82. triton/runtime/tcc/include/stdatomic.h +171 -0
  83. triton/runtime/tcc/include/stddef.h +7 -19
  84. triton/runtime/tcc/include/stdlib.h +15 -4
  85. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  86. triton/runtime/tcc/include/sys/stat.h +2 -2
  87. triton/runtime/tcc/include/sys/types.h +5 -0
  88. triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
  89. triton/runtime/tcc/include/tccdefs.h +342 -0
  90. triton/runtime/tcc/include/tgmath.h +89 -0
  91. triton/runtime/tcc/include/uchar.h +33 -0
  92. triton/runtime/tcc/include/unistd.h +1 -0
  93. triton/runtime/tcc/include/winapi/qos.h +72 -0
  94. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  95. triton/runtime/tcc/include/winapi/winbase.h +9 -2
  96. triton/runtime/tcc/include/winapi/wincon.h +8 -0
  97. triton/runtime/tcc/include/winapi/windows.h +1 -1
  98. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  99. triton/runtime/tcc/include/winapi/winnt.h +9 -7
  100. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  101. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  102. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  103. triton/runtime/tcc/lib/libtcc1.a +0 -0
  104. triton/runtime/tcc/lib/python314.def +1800 -0
  105. triton/runtime/tcc/lib/python314t.def +1809 -0
  106. triton/runtime/tcc/libtcc.dll +0 -0
  107. triton/runtime/tcc/tcc.exe +0 -0
  108. triton/testing.py +16 -12
  109. triton/tools/compile.py +62 -14
  110. triton/tools/disasm.py +3 -4
  111. triton/tools/extra/cuda/compile.c +1 -0
  112. triton/tools/extra/hip/compile.cpp +66 -0
  113. triton/tools/extra/hip/compile.h +13 -0
  114. triton/tools/ragged_tma.py +92 -0
  115. triton/tools/tensor_descriptor.py +34 -0
  116. triton/windows_utils.py +52 -81
  117. {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +8 -4
  118. triton_windows-3.5.0.post21.dist-info/RECORD +217 -0
  119. triton_windows-3.5.0.post21.dist-info/entry_points.txt +3 -0
  120. triton_windows-3.5.0.post21.dist-info/licenses/LICENSE +23 -0
  121. triton_windows-3.5.0.post21.dist-info/top_level.txt +1 -0
  122. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
  123. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
  124. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
  125. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
  126. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
  127. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
  128. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
  129. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
  130. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
  131. triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
  132. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
  133. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
  134. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
  135. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
  136. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
  137. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
  138. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
  139. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
  140. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
  141. triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
  142. triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
  143. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
  144. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
  145. triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
  146. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
  147. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
  148. triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
  149. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
  150. triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
  151. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
  152. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
  153. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
  154. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
  155. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
  156. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
  157. triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
  158. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
  159. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
  160. triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
  161. triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
  162. triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
  163. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
  164. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
  165. triton/backends/amd/include/hip/channel_descriptor.h +0 -39
  166. triton/backends/amd/include/hip/device_functions.h +0 -38
  167. triton/backends/amd/include/hip/driver_types.h +0 -468
  168. triton/backends/amd/include/hip/hip_bf16.h +0 -36
  169. triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
  170. triton/backends/amd/include/hip/hip_common.h +0 -100
  171. triton/backends/amd/include/hip/hip_complex.h +0 -38
  172. triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
  173. triton/backends/amd/include/hip/hip_deprecated.h +0 -95
  174. triton/backends/amd/include/hip/hip_ext.h +0 -161
  175. triton/backends/amd/include/hip/hip_fp16.h +0 -36
  176. triton/backends/amd/include/hip/hip_fp8.h +0 -33
  177. triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
  178. triton/backends/amd/include/hip/hip_hcc.h +0 -24
  179. triton/backends/amd/include/hip/hip_math_constants.h +0 -36
  180. triton/backends/amd/include/hip/hip_profile.h +0 -27
  181. triton/backends/amd/include/hip/hip_runtime.h +0 -75
  182. triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
  183. triton/backends/amd/include/hip/hip_texture_types.h +0 -29
  184. triton/backends/amd/include/hip/hip_vector_types.h +0 -41
  185. triton/backends/amd/include/hip/hip_version.h +0 -17
  186. triton/backends/amd/include/hip/hiprtc.h +0 -421
  187. triton/backends/amd/include/hip/library_types.h +0 -78
  188. triton/backends/amd/include/hip/math_functions.h +0 -42
  189. triton/backends/amd/include/hip/surface_types.h +0 -63
  190. triton/backends/amd/include/hip/texture_types.h +0 -194
  191. triton/backends/amd/include/hsa/Brig.h +0 -1131
  192. triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
  193. triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
  194. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
  195. triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
  196. triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
  197. triton/backends/amd/include/hsa/hsa.h +0 -5738
  198. triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
  199. triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
  200. triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
  201. triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
  202. triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
  203. triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
  204. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
  205. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
  206. triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
  207. triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
  208. triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
  209. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
  210. triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
  211. triton/backends/amd/include/roctracer/roctracer.h +0 -779
  212. triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
  213. triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
  214. triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
  215. triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
  216. triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
  217. triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
  218. triton/backends/amd/include/roctracer/roctx.h +0 -229
  219. triton/language/_utils.py +0 -21
  220. triton/language/extra/cuda/_experimental_tma.py +0 -106
  221. triton/runtime/tcc/lib/libtcc1-64.a +0 -0
  222. triton/tools/experimental_descriptor.py +0 -32
  223. triton_windows-3.3.1.post19.dist-info/RECORD +0 -260
  224. triton_windows-3.3.1.post19.dist-info/top_level.txt +0 -14
  225. {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
@@ -1,16 +1,17 @@
1
1
  import functools
2
2
  import os
3
- import hashlib
4
3
  import subprocess
5
- import tempfile
4
+ import re
6
5
  from pathlib import Path
7
- from triton.runtime.build import _build
8
- from triton.runtime.cache import get_cache_manager
6
+ from triton import knobs
9
7
  from triton.backends.compiler import GPUTarget
10
8
  from triton.backends.driver import GPUDriver
9
+ from triton.runtime import _allocation
10
+ from triton.runtime.build import compile_module_from_src
11
+ from triton.tools.tensor_descriptor import TensorDescriptor
11
12
 
12
13
  dirname = os.path.dirname(os.path.realpath(__file__))
13
- include_dir = [os.path.join(dirname, "include")]
14
+ include_dirs = [os.path.join(dirname, "include")]
14
15
 
15
16
 
16
17
  def _find_already_mmapped_dylib_on_linux(lib_name):
@@ -66,8 +67,7 @@ def _get_path_to_hip_runtime_dylib():
66
67
  lib_name = "libamdhip64.so"
67
68
 
68
69
  # If we are told explicitly what HIP runtime dynamic library to use, obey that.
69
- env_libhip_path = os.getenv("TRITON_LIBHIP_PATH")
70
- if env_libhip_path:
70
+ if env_libhip_path := knobs.amd.libhip_path:
71
71
  if env_libhip_path.endswith(lib_name) and os.path.exists(env_libhip_path):
72
72
  return env_libhip_path
73
73
  raise RuntimeError(f"TRITON_LIBHIP_PATH '{env_libhip_path}' does not point to a valid {lib_name}")
@@ -81,6 +81,12 @@ def _get_path_to_hip_runtime_dylib():
81
81
 
82
82
  paths = []
83
83
 
84
+ # Check backend
85
+ local_lib = os.path.join(os.path.dirname(__file__), "lib", lib_name)
86
+ if os.path.exists(local_lib):
87
+ return local_lib
88
+ paths.append(local_lib)
89
+
84
90
  import site
85
91
  # First search the HIP runtime dynamic library packaged with PyTorch. It's very likely
86
92
  # that we run Triton together with PyTorch. This makes sure we use the same dynamic
@@ -104,8 +110,36 @@ def _get_path_to_hip_runtime_dylib():
104
110
  return f
105
111
  paths.append(f)
106
112
 
113
+ # HIP_PATH should point to HIP SDK root if set
114
+ env_hip_path = os.getenv("HIP_PATH")
115
+ if env_hip_path:
116
+ hip_lib_path = os.path.join(env_hip_path, "lib", lib_name)
117
+ if os.path.exists(hip_lib_path):
118
+ return hip_lib_path
119
+ paths.append(hip_lib_path)
120
+
121
+ # if available, `hipconfig --path` prints the HIP SDK root
122
+ try:
123
+ hip_root = subprocess.check_output(["hipconfig", "--path"]).decode().strip()
124
+ if hip_root:
125
+ hip_lib_path = os.path.join(hip_root, "lib", lib_name)
126
+ if os.path.exists(hip_lib_path):
127
+ return hip_lib_path
128
+ paths.append(hip_lib_path)
129
+ except (subprocess.CalledProcessError, FileNotFoundError):
130
+ # hipconfig may not be available
131
+ pass
132
+
133
+ # ROCm lib dir based on env var
134
+ env_rocm_path = os.getenv("ROCM_PATH")
135
+ if env_rocm_path:
136
+ rocm_lib_path = os.path.join(env_rocm_path, "lib", lib_name)
137
+ if os.path.exists(rocm_lib_path):
138
+ return rocm_lib_path
139
+ paths.append(rocm_lib_path)
140
+
107
141
  # Afterwards try to search the loader dynamic library resolution paths.
108
- libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode()
142
+ libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode(errors="ignore")
109
143
  # each line looks like the following:
110
144
  # libamdhip64.so.6 (libc6,x86-64) => /opt/rocm-6.0.2/lib/libamdhip64.so.6
111
145
  # libamdhip64.so (libc6,x86-64) => /opt/rocm-6.0.2/lib/libamdhip64.so
@@ -124,25 +158,6 @@ def _get_path_to_hip_runtime_dylib():
124
158
  raise RuntimeError(f"cannot locate {lib_name} after attempted paths {paths}")
125
159
 
126
160
 
127
- def compile_module_from_src(src, name):
128
- key = hashlib.sha256(src.encode("utf-8")).hexdigest()
129
- cache = get_cache_manager(key)
130
- cache_path = cache.get_file(f"{name}.so")
131
- if cache_path is None:
132
- with tempfile.TemporaryDirectory() as tmpdir:
133
- src_path = os.path.join(tmpdir, f"{name}.c")
134
- with open(src_path, "w") as f:
135
- f.write(src)
136
- so = _build(name, src_path, tmpdir, [], include_dir, [])
137
- with open(so, "rb") as f:
138
- cache_path = cache.put(f.read(), f"{name}.so", binary=True)
139
- import importlib.util
140
- spec = importlib.util.spec_from_file_location(name, cache_path)
141
- mod = importlib.util.module_from_spec(spec)
142
- spec.loader.exec_module(mod)
143
- return mod
144
-
145
-
146
161
  class HIPUtils(object):
147
162
 
148
163
  def __new__(cls):
@@ -157,7 +172,7 @@ class HIPUtils(object):
157
172
  # This way we don't need to escape-quote C code curly brackets and we can replace
158
173
  # exactly once.
159
174
  src = src.replace('/*py_libhip_search_path*/', libhip_path, 1)
160
- mod = compile_module_from_src(src, "hip_utils")
175
+ mod = compile_module_from_src(src=src, name="hip_utils", include_dirs=include_dirs)
161
176
  self.load_binary = mod.load_binary
162
177
  self.get_device_properties = mod.get_device_properties
163
178
 
@@ -167,26 +182,71 @@ def ty_to_cpp(ty):
167
182
  if ty[0] == '*':
168
183
  return "hipDeviceptr_t"
169
184
  return {
170
- "i1": "int32_t",
185
+ "i1": "int8_t",
171
186
  "i8": "int8_t",
172
187
  "i16": "int16_t",
173
188
  "i32": "int32_t",
174
189
  "i64": "int64_t",
175
- "u1": "uint32_t",
190
+ "u1": "uint8_t",
176
191
  "u8": "uint8_t",
177
192
  "u16": "uint16_t",
178
193
  "u32": "uint32_t",
179
194
  "u64": "uint64_t",
180
- "fp16": "float",
181
- "bf16": "float",
182
- "fp32": "float",
183
- "f32": "float",
195
+ "fp16": "double",
196
+ "bf16": "double",
197
+ "fp32": "double",
198
+ "f32": "double",
184
199
  "fp64": "double",
185
200
  }[ty]
186
201
 
187
202
 
203
+ FLOAT_STORAGE_TYPE = {
204
+ "fp16": "uint16_t",
205
+ "bf16": "uint16_t",
206
+ "fp32": "uint32_t",
207
+ "f32": "uint32_t",
208
+ "fp64": "uint64_t",
209
+ }
210
+ FLOAT_PACK_FUNCTION = {
211
+ "fp16": "pack_fp16",
212
+ "bf16": "pack_bf16",
213
+ "fp32": "pack_fp32",
214
+ "f32": "pack_fp32",
215
+ "fp64": "pack_fp64",
216
+ }
217
+
218
+ _BASE_ARGS_FORMAT = "piiiKKOOOOO"
219
+
220
+
188
221
  def make_launcher(constants, signature, warp_size):
189
222
 
223
+ def _expand_signature(signature):
224
+ output = []
225
+ # Expand tensor descriptor arguments into base pointer, shape, and
226
+ # strides
227
+ for sig in signature:
228
+ if isinstance(sig, str) and sig.startswith("tensordesc"):
229
+ ndim = sig.count(",") + 1
230
+ dtype = re.match("tensordesc<([^[>]*)", sig).group()
231
+
232
+ output.append("*" + dtype)
233
+ for _ in range(2 * ndim):
234
+ output.append("i64")
235
+ output.append("i1")
236
+ # Currently the host side tensor descriptors get passed in as a
237
+ # tensor desc, shape, and strides. We have no way to use these
238
+ # shape and strides when processing tensor descriptors which is
239
+ # why we provide our own decomposition above. Sadly this means
240
+ # we have to pass the shape and strides twice.
241
+ for _ in range(ndim):
242
+ output.append("i32")
243
+ for _ in range(ndim):
244
+ output.append("i64")
245
+ else:
246
+ output.append(sig)
247
+
248
+ return output
249
+
190
250
  def _serialize_signature(sig):
191
251
  if isinstance(sig, tuple):
192
252
  return ','.join(map(_serialize_signature, sig))
@@ -198,7 +258,7 @@ def make_launcher(constants, signature, warp_size):
198
258
  return f"[{val}]"
199
259
  if ty[0] == '*':
200
260
  return "PyObject*"
201
- if ty in ("constexpr"):
261
+ if ty == "constexpr":
202
262
  return "PyObject*"
203
263
  return ty_to_cpp(ty)
204
264
 
@@ -208,10 +268,9 @@ def make_launcher(constants, signature, warp_size):
208
268
  return f"({val})"
209
269
  if ty[0] == '*':
210
270
  return "O"
211
- if ty in ("constexpr"):
271
+ if ty == "constexpr":
212
272
  return "O"
213
273
  return {
214
- "float": "f",
215
274
  "double": "d",
216
275
  "long": "l",
217
276
  "int8_t": "b",
@@ -224,30 +283,51 @@ def make_launcher(constants, signature, warp_size):
224
283
  "uint64_t": "K",
225
284
  }[ty_to_cpp(ty)]
226
285
 
286
+ signature = {idx: s for idx, s in enumerate(_expand_signature(signature.values()))}
287
+
227
288
  args_format = ''.join([format_of(ty) for ty in signature.values()])
228
- format = "piiiKKOOOO" + args_format
289
+ format = _BASE_ARGS_FORMAT + args_format
229
290
  signature = ','.join(map(_serialize_signature, signature.values()))
230
291
  signature = list(filter(bool, signature.split(',')))
231
292
  signature = {i: s for i, s in enumerate(signature)}
232
293
  args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''
233
294
  # Record the end of regular arguments;
234
295
  # subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA.
235
- arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items() if ty != "constexpr")
296
+ arg_decl_list = []
297
+ for i, ty in signature.items():
298
+ if ty == "constexpr":
299
+ continue
300
+ if ty in FLOAT_STORAGE_TYPE:
301
+ arg_decl_list.append(f"{FLOAT_STORAGE_TYPE[ty]} arg{i}")
302
+ else:
303
+ arg_decl_list.append(f"{ty_to_cpp(ty)} arg{i}")
304
+ arg_decls = ', '.join(arg_decl_list)
236
305
  internal_args_list = []
237
306
  for i, ty in signature.items():
238
307
  if ty[0] == "*":
239
308
  internal_args_list.append(f"ptr_info{i}.dev_ptr")
309
+ elif ty in FLOAT_STORAGE_TYPE:
310
+ internal_args_list.append(f"_arg{i}_storage")
240
311
  elif ty != "constexpr":
241
312
  internal_args_list.append(f"_arg{i}")
313
+
314
+ float_storage_decls = [
315
+ f"{FLOAT_STORAGE_TYPE[ty]} _arg{i}_storage = {FLOAT_PACK_FUNCTION[ty]}(_arg{i});"
316
+ for i, ty in signature.items()
317
+ if ty in FLOAT_STORAGE_TYPE
318
+ ]
319
+
242
320
  libhip_path = _get_path_to_hip_runtime_dylib()
243
321
 
244
322
  # generate glue code
245
323
  params = list(range(len(signature)))
246
324
  params = [f"&arg{i}" for i, ty in signature.items() if ty != "constexpr"]
247
325
  params.append("&global_scratch")
326
+ params.append("&profile_scratch")
248
327
  src = f"""
249
328
  #define __HIP_PLATFORM_AMD__
250
329
  #include <hip/hip_runtime.h>
330
+ #include <hip/hip_runtime_api.h>
251
331
  #include <Python.h>
252
332
  #include <dlfcn.h>
253
333
  #include <stdbool.h>
@@ -260,6 +340,7 @@ static const char *hipLibSearchPaths[] = {{"{libhip_path}"}};
260
340
  // The list of HIP dynamic library symbols and their signature we are interested
261
341
  // in this file.
262
342
  #define HIP_SYMBOL_LIST(FOR_EACH_ERR_FN, FOR_EACH_STR_FN) \\
343
+ FOR_EACH_STR_FN(hipGetLastError) \\
263
344
  FOR_EACH_STR_FN(hipGetErrorString, hipError_t hipError) \\
264
345
  FOR_EACH_ERR_FN(hipModuleLaunchKernel, hipFunction_t f, \\
265
346
  unsigned int gridDimX, unsigned int gridDimY, \\
@@ -291,9 +372,6 @@ static struct HIPSymbolTable hipSymbolTable;
291
372
  bool initSymbolTable() {{
292
373
  // Use the HIP runtime library loaded into the existing process if it exits.
293
374
  void *lib = dlopen("libamdhip64.so", RTLD_NOLOAD);
294
- if (lib) {{
295
- // printf("[triton] chosen loaded libamdhip64.so in the process\\n");
296
- }}
297
375
 
298
376
  // Otherwise, go through the list of search paths to dlopen the first HIP
299
377
  // driver library.
@@ -303,7 +381,6 @@ bool initSymbolTable() {{
303
381
  void *handle = dlopen(hipLibSearchPaths[i], RTLD_LAZY | RTLD_LOCAL);
304
382
  if (handle) {{
305
383
  lib = handle;
306
- // printf("[triton] chosen %s\\n", hipLibSearchPaths[i]);
307
384
  }}
308
385
  }}
309
386
  }}
@@ -312,17 +389,36 @@ bool initSymbolTable() {{
312
389
  return false;
313
390
  }}
314
391
 
315
- // Resolve all symbols we are interested in.
392
+ typedef hipError_t (*hipGetProcAddress_fn)(
393
+ const char *symbol, void **pfn, int hipVersion, uint64_t hipFlags,
394
+ hipDriverProcAddressQueryResult *symbolStatus);
395
+ hipGetProcAddress_fn hipGetProcAddress;
316
396
  dlerror(); // Clear existing errors
317
397
  const char *error = NULL;
318
- #define QUERY_EACH_FN(hipSymbolName, ...) \\
319
- *(void **)&hipSymbolTable.hipSymbolName = dlsym(lib, #hipSymbolName); \\
320
- error = dlerror(); \\
321
- if (error) {{ \\
322
- PyErr_SetString(PyExc_RuntimeError, \\
323
- "cannot query " #hipSymbolName " from libamdhip64.so"); \\
324
- dlclose(lib); \\
325
- return false; \\
398
+ *(void **)&hipGetProcAddress = dlsym(lib, "hipGetProcAddress");
399
+ error = dlerror();
400
+ if (error) {{
401
+ PyErr_SetString(PyExc_RuntimeError,
402
+ "cannot query 'hipGetProcAddress' from libamdhip64.so");
403
+ dlclose(lib);
404
+ return false;
405
+ }}
406
+
407
+ // Resolve all symbols we are interested in.
408
+ int hipVersion = HIP_VERSION;
409
+ uint64_t hipFlags = 0;
410
+ hipDriverProcAddressQueryResult symbolStatus;
411
+ hipError_t status = hipSuccess;
412
+ #define QUERY_EACH_FN(hipSymbolName, ...) \
413
+ status = hipGetProcAddress(#hipSymbolName, \
414
+ (void **)&hipSymbolTable.hipSymbolName, \
415
+ hipVersion, hipFlags, &symbolStatus); \
416
+ if (status != hipSuccess) {{ \
417
+ PyErr_SetString(PyExc_RuntimeError, \
418
+ "cannot get address for '" #hipSymbolName \
419
+ "' from libamdhip64.so"); \
420
+ dlclose(lib); \
421
+ return false; \
326
422
  }}
327
423
 
328
424
  HIP_SYMBOL_LIST(QUERY_EACH_FN, QUERY_EACH_FN)
@@ -344,8 +440,7 @@ static inline void gpuAssert(hipError_t code, const char *file, int line)
344
440
 
345
441
  #define HIP_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}
346
442
 
347
- static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, hipStream_t stream, hipFunction_t function{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
348
- // printf("_launch hip kernel\\n");
443
+ static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, hipStream_t stream, hipFunction_t function, hipDeviceptr_t profile_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
349
444
  hipDeviceptr_t global_scratch = 0;
350
445
  void *params[] = {{ {', '.join(params)} }};
351
446
  if (gridX*gridY*gridZ > 0 && launch_cooperative_grid) {{
@@ -362,8 +457,11 @@ typedef struct _DevicePtrInfo {{
362
457
  bool valid;
363
458
  }} DevicePtrInfo;
364
459
 
460
+ static PyObject* data_ptr_str = NULL;
461
+
365
462
  static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
366
463
  DevicePtrInfo ptr_info;
464
+ hipError_t status = hipSuccess;
367
465
  ptr_info.dev_ptr = 0;
368
466
  ptr_info.valid = true;
369
467
  if (PyLong_Check(obj)) {{
@@ -374,53 +472,81 @@ static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
374
472
  // valid nullptr
375
473
  return ptr_info;
376
474
  }}
377
- PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr");
378
- if(ptr){{
379
- PyObject *empty_tuple = PyTuple_New(0);
380
- PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL);
381
- Py_DECREF(empty_tuple);
382
- Py_DECREF(ptr);
383
- if (!PyLong_Check(ret)) {{
384
- PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
475
+ PyObject *ret = PyObject_CallMethodNoArgs(obj, data_ptr_str);
476
+ if (!ret) {{
477
+ PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
478
+ ptr_info.valid = false;
479
+ goto cleanup;
480
+ }}
481
+ if (!PyLong_Check(ret)) {{
482
+ PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
483
+ ptr_info.valid = false;
484
+ goto cleanup;
485
+ }}
486
+ ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(ret);
487
+ if (!ptr_info.dev_ptr)
488
+ goto cleanup;
489
+ uint64_t dev_ptr;
490
+ status = hipSymbolTable.hipPointerGetAttribute(&dev_ptr, HIP_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
491
+ if (status == hipErrorInvalidValue) {{
492
+ PyErr_Format(PyExc_ValueError,
493
+ "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
385
494
  ptr_info.valid = false;
386
- return ptr_info;
387
- }}
388
- ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(ret);
389
- if(!ptr_info.dev_ptr)
390
- return ptr_info;
391
- uint64_t dev_ptr;
392
- hipError_t status = hipSymbolTable.hipPointerGetAttribute(&dev_ptr, HIP_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
393
- if (status == hipErrorInvalidValue) {{
394
- PyErr_Format(PyExc_ValueError,
395
- "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
396
- ptr_info.valid = false;
397
- }}
398
- ptr_info.dev_ptr = (hipDeviceptr_t)dev_ptr;
399
- Py_DECREF(ret);
400
- return ptr_info;
495
+ // Clear and ignore HIP error
496
+ (void)hipSymbolTable.hipGetLastError();
401
497
  }}
402
- PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
498
+ ptr_info.dev_ptr = (hipDeviceptr_t)dev_ptr;
499
+ cleanup:
500
+ Py_DECREF(ret);
403
501
  return ptr_info;
404
502
  }}
405
503
 
504
+ static uint16_t pack_fp16(double f) {{
505
+ uint16_t result;
506
+ // from https://github.com/python/pythoncapi-compat/blob/5e317108f872c904eb726cb8d560dcadbdf88a72/pythoncapi_compat.h#L482-L492
507
+ #if 0x030600B1 <= PY_VERSION_HEX && PY_VERSION_HEX <= 0x030B00A1 && !defined(PYPY_VERSION)
508
+ _PyFloat_Pack2(f, (unsigned char*)&result, 1);
509
+ #else
510
+ PyFloat_Pack2(f, (char*)&result, 1);
511
+ #endif
512
+ return result;
513
+ }}
514
+
515
+ static uint16_t pack_bf16(double f) {{
516
+ float f32 = (float)f;
517
+ uint32_t u32 = *(uint32_t*)&f32;
518
+ return (uint16_t)(u32 >> 16);
519
+ }}
520
+
521
+ static uint32_t pack_fp32(double f) {{
522
+ float f32 = (float)f;
523
+ return *(uint32_t*)&f32;
524
+ }}
525
+
526
+ static uint64_t pack_fp64(double f) {{
527
+ return *(uint64_t*)&f;
528
+ }}
529
+
406
530
  static PyObject* launch(PyObject* self, PyObject* args) {{
407
- // printf("launch\\n");
408
531
  int gridX, gridY, gridZ;
409
532
  uint64_t _stream;
410
533
  uint64_t _function;
411
534
  int launch_cooperative_grid;
535
+ PyObject *profile_scratch_obj = NULL;
412
536
  PyObject *launch_enter_hook = NULL;
413
537
  PyObject *launch_exit_hook = NULL;
414
538
  PyObject *kernel_metadata = NULL;
415
539
  PyObject *launch_metadata = NULL;
416
540
  {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
417
541
  if(!PyArg_ParseTuple(args, \"{format}\", &launch_cooperative_grid,
418
- &gridX, &gridY, &gridZ, &_stream, &_function,
542
+ &gridX, &gridY, &gridZ, &_stream, &_function, &profile_scratch_obj,
419
543
  &kernel_metadata, &launch_metadata,
420
544
  &launch_enter_hook, &launch_exit_hook {args_list})) {{
421
545
  return NULL;
422
546
  }}
423
547
 
548
+ {' '.join(float_storage_decls)}
549
+
424
550
  // extract kernel metadata
425
551
  int num_warps, num_ctas, shared_memory, clusterDimX, clusterDimY, clusterDimZ;
426
552
  if (!PyArg_ParseTuple(kernel_metadata, \"iiiiii\", &num_warps, &num_ctas, &shared_memory, &clusterDimX, &clusterDimY, &clusterDimZ)) {{
@@ -428,32 +554,36 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
428
554
  }}
429
555
  // extract launch metadata
430
556
  if (launch_enter_hook != Py_None){{
431
- PyObject* args = Py_BuildValue("(O)", launch_metadata);
432
- PyObject* ret = PyObject_CallObject(launch_enter_hook, args);
433
- Py_DECREF(args);
557
+ PyObject* ret = PyObject_CallOneArg(launch_enter_hook, launch_metadata);
434
558
  if (!ret)
435
559
  return NULL;
560
+ Py_DECREF(ret);
436
561
  }}
437
562
 
563
+ hipDeviceptr_t profile_scratch = 0;
564
+ if (profile_scratch_obj != Py_None) {{
565
+ DevicePtrInfo profile_scratch_info = getPointer(profile_scratch_obj, -1);
566
+ if (!profile_scratch_info.valid) {{
567
+ return NULL;
568
+ }}
569
+ profile_scratch = profile_scratch_info.dev_ptr;
570
+ }}
438
571
 
439
572
  // raise exception asap
440
573
  {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])};
441
- _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
574
+ _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function, (hipDeviceptr_t)profile_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
442
575
 
443
576
  if(launch_exit_hook != Py_None){{
444
- PyObject* args = Py_BuildValue("(O)", launch_metadata);
445
- PyObject* ret = PyObject_CallObject(launch_exit_hook, args);
446
- Py_DECREF(args);
577
+ PyObject* ret = PyObject_CallOneArg(launch_exit_hook, launch_metadata);
447
578
  if (!ret)
448
579
  return NULL;
580
+ Py_DECREF(ret);
449
581
  }}
450
582
 
451
583
  if(PyErr_Occurred()) {{
452
584
  return NULL;
453
585
  }}
454
- // return None
455
- Py_INCREF(Py_None);
456
- return Py_None;
586
+ Py_RETURN_NONE;
457
587
  }}
458
588
 
459
589
  static PyMethodDef ModuleMethods[] = {{
@@ -477,6 +607,10 @@ PyMODINIT_FUNC PyInit___triton_launcher(void) {{
477
607
  if(m == NULL) {{
478
608
  return NULL;
479
609
  }}
610
+ data_ptr_str = PyUnicode_InternFromString("data_ptr");
611
+ if(data_ptr_str == NULL) {{
612
+ return NULL;
613
+ }}
480
614
  PyModule_AddFunctions(m, ModuleMethods);
481
615
  return m;
482
616
  }}
@@ -484,6 +618,31 @@ PyMODINIT_FUNC PyInit___triton_launcher(void) {{
484
618
  return src
485
619
 
486
620
 
621
+ def wrap_handle_tensor_descriptor(launcher):
622
+ """
623
+ Replace all tensor descriptors with the base ptr, shape, and strides
624
+ """
625
+
626
+ def inner(*args):
627
+ meta_args = args[:len(_BASE_ARGS_FORMAT)]
628
+ raw_kernel_args = args[len(_BASE_ARGS_FORMAT):]
629
+ final_args = []
630
+ for arg in raw_kernel_args:
631
+ if isinstance(arg, TensorDescriptor):
632
+ # Currently the host side tensor descriptors get decomposed in
633
+ # the frontend to tensor desc, shape, and strides. We have no
634
+ # way to use these shape and strides when processing tensor
635
+ # descriptors which is why we provide our own decomposition
636
+ # above. Sadly this means we have to pass the shape and strides
637
+ # twice.
638
+ final_args.extend([arg.base, *arg.shape, *arg.strides, arg.padding == "nan", *arg.shape, *arg.strides])
639
+ else:
640
+ final_args.append(arg)
641
+ return launcher(*meta_args, *final_args)
642
+
643
+ return inner
644
+
645
+
487
646
  class HIPLauncher(object):
488
647
 
489
648
  def __init__(self, src, metadata):
@@ -492,12 +651,28 @@ class HIPLauncher(object):
492
651
  constants = {arg_idx(idx): value for idx, value in constants.items()}
493
652
  signature = {idx: value for idx, value in src.signature.items()}
494
653
  src = make_launcher(constants, signature, metadata.warp_size)
495
- mod = compile_module_from_src(src, "__triton_launcher")
496
- self.launch = mod.launch
654
+ mod = compile_module_from_src(src=src, name="__triton_launcher", include_dirs=include_dirs)
655
+ has_tensor_desc_arg = any(isinstance(sig, str) and sig.startswith("tensordesc") for sig in signature.values())
656
+
657
+ self.launch = wrap_handle_tensor_descriptor(mod.launch) if has_tensor_desc_arg else mod.launch
497
658
  self.launch_cooperative_grid = metadata.launch_cooperative_grid
659
+ self.profile_scratch_size = metadata.profile_scratch_size
660
+ self.profile_scratch_align = metadata.profile_scratch_align
661
+
662
+ def __call__(self, gridX, gridY, gridZ, stream, function, *args):
663
+
664
+ def allocate_scratch(size, align, allocator):
665
+ if size > 0:
666
+ grid_size = gridX * gridY * gridZ
667
+ alloc_size = grid_size * size
668
+ alloc_fn = allocator.get()
669
+ return alloc_fn(alloc_size, align, stream)
670
+ return None
498
671
 
499
- def __call__(self, *args):
500
- self.launch(self.launch_cooperative_grid, *args)
672
+ profile_scratch = allocate_scratch(self.profile_scratch_size, self.profile_scratch_align,
673
+ _allocation._profile_allocator)
674
+
675
+ self.launch(self.launch_cooperative_grid, gridX, gridY, gridZ, stream, function, profile_scratch, *args)
501
676
 
502
677
 
503
678
  class HIPDriver(GPUDriver):
@@ -515,14 +690,17 @@ class HIPDriver(GPUDriver):
515
690
  def is_active():
516
691
  try:
517
692
  import torch
518
- return torch.version.hip is not None
693
+ return torch.cuda.is_available() and (torch.version.hip is not None)
519
694
  except ImportError:
520
695
  return False
521
696
 
697
+ def map_python_to_cpp_type(self, ty: str) -> str:
698
+ return ty_to_cpp(ty)
699
+
522
700
  def get_current_target(self):
523
701
  device = self.get_current_device()
524
702
  device_properties = self.utils.get_device_properties(device)
525
- arch = device_properties['arch']
703
+ arch = knobs.runtime.override_arch or device_properties['arch']
526
704
  warp_size = device_properties['warpSize']
527
705
  return GPUTarget("hip", arch.split(':')[0], warp_size)
528
706
 
@@ -1,9 +1,6 @@
1
- import os
2
- import re
3
- import subprocess
4
- import sysconfig
5
1
  from abc import ABCMeta, abstractmethod
6
2
  from dataclasses import dataclass
3
+ from enum import Enum
7
4
  from typing import Dict, Union
8
5
  from types import ModuleType
9
6
 
@@ -17,6 +14,12 @@ class GPUTarget(object):
17
14
  warp_size: int
18
15
 
19
16
 
17
+ class Language(Enum):
18
+ """The input language being compiled by the backend."""
19
+ TRITON = 0
20
+ GLUON = 1
21
+
22
+
20
23
  class BaseBackend(metaclass=ABCMeta):
21
24
 
22
25
  def __init__(self, target: GPUTarget) -> None:
@@ -24,23 +27,6 @@ class BaseBackend(metaclass=ABCMeta):
24
27
  assert self.supports_target(target)
25
28
 
26
29
  @staticmethod
27
- def _path_to_binary(binary: str):
28
- binary += sysconfig.get_config_var("EXE")
29
- base_dir = os.path.join(os.path.dirname(__file__), os.pardir)
30
- paths = [
31
- os.environ.get(f"TRITON_{binary.upper()}_PATH", ""),
32
- os.path.join(base_dir, "third_party", "cuda", "bin", binary),
33
- ]
34
- for path in paths:
35
- if os.path.exists(path) and os.path.isfile(path):
36
- result = subprocess.check_output([path, "--version"], stderr=subprocess.STDOUT)
37
- if result is not None:
38
- version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE)
39
- if version is not None:
40
- return path, version.group(1)
41
- raise RuntimeError(f"Cannot find {binary}")
42
-
43
- @classmethod
44
30
  @abstractmethod
45
31
  def supports_target(target: GPUTarget):
46
32
  raise NotImplementedError
triton/backends/driver.py CHANGED
@@ -15,6 +15,19 @@ class DriverBase(metaclass=ABCMeta):
15
15
  def is_active(self):
16
16
  pass
17
17
 
18
+ @abstractmethod
19
+ def map_python_to_cpp_type(self, ty: str) -> str:
20
+ """
21
+ Converts a Triton type string to its corresponding C++ type string for this backend.
22
+
23
+ Args:
24
+ ty (str): The Triton type string. e.g., 'i32', '*fp16', 'fp32'.
25
+
26
+ Returns:
27
+ str: The C++ type string.
28
+ """
29
+ pass
30
+
18
31
  @abstractmethod
19
32
  def get_current_target(self):
20
33
  pass
Binary file