triton-windows 3.3.0.post19__cp310-cp310-win_amd64.whl → 3.4.0.post20__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of triton-windows might be problematic.

Files changed (173)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +4 -1
  3. triton/_filecheck.py +87 -0
  4. triton/_internal_testing.py +26 -15
  5. triton/_utils.py +110 -21
  6. triton/backends/__init__.py +20 -23
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +112 -78
  9. triton/backends/amd/driver.c +5 -2
  10. triton/backends/amd/driver.py +149 -47
  11. triton/backends/compiler.py +7 -21
  12. triton/backends/nvidia/bin/ptxas.exe +0 -0
  13. triton/backends/nvidia/compiler.py +92 -93
  14. triton/backends/nvidia/driver.c +90 -98
  15. triton/backends/nvidia/driver.py +303 -128
  16. triton/compiler/code_generator.py +212 -111
  17. triton/compiler/compiler.py +110 -25
  18. triton/experimental/__init__.py +0 -0
  19. triton/experimental/gluon/__init__.py +4 -0
  20. triton/experimental/gluon/_compiler.py +0 -0
  21. triton/experimental/gluon/_runtime.py +99 -0
  22. triton/experimental/gluon/language/__init__.py +18 -0
  23. triton/experimental/gluon/language/_core.py +312 -0
  24. triton/experimental/gluon/language/_layouts.py +230 -0
  25. triton/experimental/gluon/language/_math.py +12 -0
  26. triton/experimental/gluon/language/_semantic.py +287 -0
  27. triton/experimental/gluon/language/_standard.py +47 -0
  28. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  29. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +202 -0
  30. triton/experimental/gluon/language/nvidia/blackwell/tma.py +32 -0
  31. triton/experimental/gluon/language/nvidia/hopper/__init__.py +11 -0
  32. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +51 -0
  33. triton/experimental/gluon/language/nvidia/hopper/tma.py +96 -0
  34. triton/experimental/gluon/nvidia/__init__.py +4 -0
  35. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  36. triton/experimental/gluon/nvidia/hopper.py +40 -0
  37. triton/knobs.py +481 -0
  38. triton/language/__init__.py +39 -14
  39. triton/language/core.py +794 -537
  40. triton/language/extra/cuda/__init__.py +10 -7
  41. triton/language/extra/cuda/gdc.py +42 -0
  42. triton/language/extra/cuda/libdevice.py +394 -394
  43. triton/language/extra/cuda/utils.py +21 -21
  44. triton/language/extra/hip/libdevice.py +113 -104
  45. triton/language/math.py +65 -66
  46. triton/language/random.py +12 -2
  47. triton/language/semantic.py +1706 -1770
  48. triton/language/standard.py +116 -51
  49. triton/runtime/autotuner.py +117 -59
  50. triton/runtime/build.py +76 -12
  51. triton/runtime/cache.py +18 -47
  52. triton/runtime/driver.py +32 -29
  53. triton/runtime/interpreter.py +72 -35
  54. triton/runtime/jit.py +146 -110
  55. triton/runtime/tcc/lib/python310.def +1610 -0
  56. triton/runtime/tcc/lib/python311.def +1633 -0
  57. triton/runtime/tcc/lib/python312.def +1703 -0
  58. triton/runtime/tcc/lib/python313.def +1651 -0
  59. triton/runtime/tcc/lib/python313t.def +1656 -0
  60. triton/runtime/tcc/lib/python39.def +1644 -0
  61. triton/runtime/tcc/lib/python3t.def +905 -0
  62. triton/testing.py +16 -12
  63. triton/tools/disasm.py +3 -4
  64. triton/tools/tensor_descriptor.py +36 -0
  65. triton/windows_utils.py +14 -6
  66. {triton_windows-3.3.0.post19.dist-info → triton_windows-3.4.0.post20.dist-info}/METADATA +7 -2
  67. triton_windows-3.4.0.post20.dist-info/RECORD +186 -0
  68. {triton_windows-3.3.0.post19.dist-info → triton_windows-3.4.0.post20.dist-info}/WHEEL +1 -1
  69. triton_windows-3.4.0.post20.dist-info/entry_points.txt +3 -0
  70. triton_windows-3.4.0.post20.dist-info/licenses/LICENSE +23 -0
  71. triton_windows-3.4.0.post20.dist-info/top_level.txt +1 -0
  72. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
  73. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
  74. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
  75. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
  76. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
  77. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
  78. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
  79. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
  80. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
  81. triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
  82. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
  83. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
  84. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
  85. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
  86. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
  87. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
  88. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
  89. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
  90. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
  91. triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
  92. triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
  93. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
  94. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
  95. triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
  96. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
  97. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
  98. triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
  99. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
  100. triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
  101. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
  102. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
  103. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
  104. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
  105. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
  106. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
  107. triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
  108. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
  109. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
  110. triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
  111. triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
  112. triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
  113. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
  114. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
  115. triton/backends/amd/include/hip/channel_descriptor.h +0 -39
  116. triton/backends/amd/include/hip/device_functions.h +0 -38
  117. triton/backends/amd/include/hip/driver_types.h +0 -468
  118. triton/backends/amd/include/hip/hip_bf16.h +0 -36
  119. triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
  120. triton/backends/amd/include/hip/hip_common.h +0 -100
  121. triton/backends/amd/include/hip/hip_complex.h +0 -38
  122. triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
  123. triton/backends/amd/include/hip/hip_deprecated.h +0 -95
  124. triton/backends/amd/include/hip/hip_ext.h +0 -161
  125. triton/backends/amd/include/hip/hip_fp16.h +0 -36
  126. triton/backends/amd/include/hip/hip_fp8.h +0 -33
  127. triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
  128. triton/backends/amd/include/hip/hip_hcc.h +0 -24
  129. triton/backends/amd/include/hip/hip_math_constants.h +0 -36
  130. triton/backends/amd/include/hip/hip_profile.h +0 -27
  131. triton/backends/amd/include/hip/hip_runtime.h +0 -75
  132. triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
  133. triton/backends/amd/include/hip/hip_texture_types.h +0 -29
  134. triton/backends/amd/include/hip/hip_vector_types.h +0 -41
  135. triton/backends/amd/include/hip/hip_version.h +0 -17
  136. triton/backends/amd/include/hip/hiprtc.h +0 -421
  137. triton/backends/amd/include/hip/library_types.h +0 -78
  138. triton/backends/amd/include/hip/math_functions.h +0 -42
  139. triton/backends/amd/include/hip/surface_types.h +0 -63
  140. triton/backends/amd/include/hip/texture_types.h +0 -194
  141. triton/backends/amd/include/hsa/Brig.h +0 -1131
  142. triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
  143. triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
  144. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
  145. triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
  146. triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
  147. triton/backends/amd/include/hsa/hsa.h +0 -5738
  148. triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
  149. triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
  150. triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
  151. triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
  152. triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
  153. triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
  154. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
  155. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
  156. triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
  157. triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
  158. triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
  159. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
  160. triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
  161. triton/backends/amd/include/roctracer/roctracer.h +0 -779
  162. triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
  163. triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
  164. triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
  165. triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
  166. triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
  167. triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
  168. triton/backends/amd/include/roctracer/roctx.h +0 -229
  169. triton/language/_utils.py +0 -21
  170. triton/language/extra/cuda/_experimental_tma.py +0 -106
  171. triton/tools/experimental_descriptor.py +0 -32
  172. triton_windows-3.3.0.post19.dist-info/RECORD +0 -253
  173. triton_windows-3.3.0.post19.dist-info/top_level.txt +0 -14
triton/backends/amd/driver.py:

@@ -1,16 +1,16 @@
 import functools
 import os
-import hashlib
 import subprocess
-import tempfile
+import re
 from pathlib import Path
-from triton.runtime.build import _build
-from triton.runtime.cache import get_cache_manager
+from triton import knobs
 from triton.backends.compiler import GPUTarget
 from triton.backends.driver import GPUDriver
+from triton.runtime.build import compile_module_from_src
+from triton.tools.tensor_descriptor import TensorDescriptor

 dirname = os.path.dirname(os.path.realpath(__file__))
-include_dir = [os.path.join(dirname, "include")]
+include_dirs = [os.path.join(dirname, "include")]


 def _find_already_mmapped_dylib_on_linux(lib_name):
@@ -66,8 +66,7 @@ def _get_path_to_hip_runtime_dylib():
     lib_name = "libamdhip64.so"

     # If we are told explicitly what HIP runtime dynamic library to use, obey that.
-    env_libhip_path = os.getenv("TRITON_LIBHIP_PATH")
-    if env_libhip_path:
+    if env_libhip_path := knobs.amd.libhip_path:
         if env_libhip_path.endswith(lib_name) and os.path.exists(env_libhip_path):
             return env_libhip_path
         raise RuntimeError(f"TRITON_LIBHIP_PATH '{env_libhip_path}' does not point to a valid {lib_name}")
@@ -81,6 +80,12 @@ def _get_path_to_hip_runtime_dylib():

     paths = []

+    # Check backend
+    local_lib = os.path.join(os.path.dirname(__file__), "lib", lib_name)
+    if os.path.exists(local_lib):
+        return local_lib
+    paths.append(local_lib)
+
     import site
     # First search the HIP runtime dynamic library packaged with PyTorch. It's very likely
     # that we run Triton together with PyTorch. This makes sure we use the same dynamic
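Net effect of the two hunks above: the HIP runtime lookup consults the knobs override first, then a copy of the library shipped inside the backend itself, and only then falls back to the copies bundled with PyTorch or a system ROCm install. A condensed sketch of that order (illustrative only; the real function also accumulates the attempted paths for its error message):

    import os
    from triton import knobs

    def locate_libamdhip64_sketch():
        lib_name = "libamdhip64.so"
        # 1. explicit override: knobs.amd.libhip_path (TRITON_LIBHIP_PATH)
        if path := knobs.amd.libhip_path:
            return path
        # 2. copy bundled with this backend: <backend dir>/lib/libamdhip64.so
        local_lib = os.path.join(os.path.dirname(__file__), "lib", lib_name)
        if os.path.exists(local_lib):
            return local_lib
        # 3. otherwise search the PyTorch wheel and the ROCm install (omitted here)
        raise RuntimeError(f"cannot locate {lib_name}")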
@@ -124,25 +129,6 @@ def _get_path_to_hip_runtime_dylib():
     raise RuntimeError(f"cannot locate {lib_name} after attempted paths {paths}")


-def compile_module_from_src(src, name):
-    key = hashlib.sha256(src.encode("utf-8")).hexdigest()
-    cache = get_cache_manager(key)
-    cache_path = cache.get_file(f"{name}.so")
-    if cache_path is None:
-        with tempfile.TemporaryDirectory() as tmpdir:
-            src_path = os.path.join(tmpdir, f"{name}.c")
-            with open(src_path, "w") as f:
-                f.write(src)
-            so = _build(name, src_path, tmpdir, [], include_dir, [])
-            with open(so, "rb") as f:
-                cache_path = cache.put(f.read(), f"{name}.so", binary=True)
-    import importlib.util
-    spec = importlib.util.spec_from_file_location(name, cache_path)
-    mod = importlib.util.module_from_spec(spec)
-    spec.loader.exec_module(mod)
-    return mod
-
-
 class HIPUtils(object):

     def __new__(cls):
@@ -157,7 +143,7 @@ class HIPUtils(object):
         # This way we don't need to escape-quote C code curly brackets and we can replace
         # exactly once.
         src = src.replace('/*py_libhip_search_path*/', libhip_path, 1)
-        mod = compile_module_from_src(src, "hip_utils")
+        mod = compile_module_from_src(src=src, name="hip_utils", include_dirs=include_dirs)
         self.load_binary = mod.load_binary
         self.get_device_properties = mod.get_device_properties

@@ -177,16 +163,60 @@ def ty_to_cpp(ty):
         "u16": "uint16_t",
         "u32": "uint32_t",
         "u64": "uint64_t",
-        "fp16": "float",
-        "bf16": "float",
-        "fp32": "float",
-        "f32": "float",
+        "fp16": "double",
+        "bf16": "double",
+        "fp32": "double",
+        "f32": "double",
         "fp64": "double",
     }[ty]


+FLOAT_STORAGE_TYPE = {
+    "fp16": "uint16_t",
+    "bf16": "uint16_t",
+    "fp32": "uint32_t",
+    "f32": "uint32_t",
+    "fp64": "uint64_t",
+}
+FLOAT_PACK_FUNCTION = {
+    "fp16": "pack_fp16",
+    "bf16": "pack_bf16",
+    "fp32": "pack_fp32",
+    "f32": "pack_fp32",
+    "fp64": "pack_fp64",
+}
+
+_BASE_ARGS_FORMAT = "piiiKKOOOO"
+
+
 def make_launcher(constants, signature, warp_size):

+    def _expand_signature(signature):
+        output = []
+        # Expand tensor descriptor arguments into base pointer, shape, and
+        # strides
+        for sig in signature:
+            if isinstance(sig, str) and sig.startswith("tensordesc"):
+                ndim = sig.count(",") + 1
+                dtype = re.match("tensordesc<([^[>]*)", sig).group(1)
+
+                output.append("*" + dtype)
+                for _ in range(2 * ndim):
+                    output.append("i64")
+                # Currently the host side tensor descriptors get passed in as a
+                # tensor desc, shape, and strides. We have no way to use these
+                # shape and strides when processing tensor descriptors which is
+                # why we provide our own decomposition above. Sadly this means
+                # we have to pass the shape and strides twice.
+                for _ in range(ndim):
+                    output.append("i32")
+                for _ in range(ndim):
+                    output.append("i64")
+            else:
+                output.append(sig)
+
+        return output
+
     def _serialize_signature(sig):
         if isinstance(sig, tuple):
             return ','.join(map(_serialize_signature, sig))
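To see what the new _expand_signature produces, take a hypothetical kernel with a plain pointer, a 2-D tensor descriptor and an i32. The tensordesc spelling below is illustrative (dtype followed by the block shape in brackets, which is what the regex above extracts):

    # assumes _expand_signature from the hunk above is in scope
    sig = ["*fp32", "tensordesc<fp16[64,64]>", "i32"]
    print(_expand_signature(sig))
    # ['*fp32',
    #  '*fp16', 'i64', 'i64', 'i64', 'i64',   # descriptor base pointer + 2*ndim i64
    #  'i32', 'i32',                          # host-side shape   (ndim x i32)
    #  'i64', 'i64',                          # host-side strides (ndim x i64)
    #  'i32']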
@@ -198,7 +228,7 @@ def make_launcher(constants, signature, warp_size):
             return f"[{val}]"
         if ty[0] == '*':
             return "PyObject*"
-        if ty in ("constexpr"):
+        if ty == "constexpr":
             return "PyObject*"
         return ty_to_cpp(ty)

@@ -208,10 +238,9 @@ def make_launcher(constants, signature, warp_size):
             return f"({val})"
         if ty[0] == '*':
             return "O"
-        if ty in ("constexpr"):
+        if ty == "constexpr":
             return "O"
         return {
-            "float": "f",
             "double": "d",
             "long": "l",
             "int8_t": "b",
@@ -224,21 +253,40 @@ def make_launcher(constants, signature, warp_size):
             "uint64_t": "K",
         }[ty_to_cpp(ty)]

+    signature = {idx: s for idx, s in enumerate(_expand_signature(signature.values()))}
+
     args_format = ''.join([format_of(ty) for ty in signature.values()])
-    format = "piiiKKOOOO" + args_format
+    format = _BASE_ARGS_FORMAT + args_format
     signature = ','.join(map(_serialize_signature, signature.values()))
     signature = list(filter(bool, signature.split(',')))
     signature = {i: s for i, s in enumerate(signature)}
     args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''
     # Record the end of regular arguments;
     # subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA.
-    arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items() if ty != "constexpr")
+    arg_decl_list = []
+    for i, ty in signature.items():
+        if ty == "constexpr":
+            continue
+        if ty in FLOAT_STORAGE_TYPE:
+            arg_decl_list.append(f"{FLOAT_STORAGE_TYPE[ty]} arg{i}")
+        else:
+            arg_decl_list.append(f"{ty_to_cpp(ty)} arg{i}")
+    arg_decls = ', '.join(arg_decl_list)
     internal_args_list = []
     for i, ty in signature.items():
         if ty[0] == "*":
             internal_args_list.append(f"ptr_info{i}.dev_ptr")
+        elif ty in FLOAT_STORAGE_TYPE:
+            internal_args_list.append(f"_arg{i}_storage")
         elif ty != "constexpr":
             internal_args_list.append(f"_arg{i}")
+
+    float_storage_decls = [
+        f"{FLOAT_STORAGE_TYPE[ty]} _arg{i}_storage = {FLOAT_PACK_FUNCTION[ty]}(_arg{i});"
+        for i, ty in signature.items()
+        if ty in FLOAT_STORAGE_TYPE
+    ]
+
     libhip_path = _get_path_to_hip_runtime_dylib()

     # generate glue code
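Concretely, for a hypothetical signature of {0: "*fp32", 1: "i32", 2: "fp32"}, the generated launcher now parses the scalar float as a Python float (C double, format "d") and re-packs it into its storage type before the kernel call; assuming ty_to_cpp maps pointers to hipDeviceptr_t as elsewhere in this file, the generated pieces would look like:

    # format              = "piiiKKOOOO" + "Oid"
    # arg_decls           = "hipDeviceptr_t arg0, int32_t arg1, uint32_t arg2"
    # float_storage_decls = ['uint32_t _arg2_storage = pack_fp32(_arg2);']
    # internal_args_list  = ['ptr_info0.dev_ptr', '_arg1', '_arg2_storage']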
@@ -291,9 +339,6 @@ static struct HIPSymbolTable hipSymbolTable;
 bool initSymbolTable() {{
   // Use the HIP runtime library loaded into the existing process if it exits.
   void *lib = dlopen("libamdhip64.so", RTLD_NOLOAD);
-  if (lib) {{
-    // printf("[triton] chosen loaded libamdhip64.so in the process\\n");
-  }}

   // Otherwise, go through the list of search paths to dlopen the first HIP
   // driver library.
@@ -303,7 +348,6 @@ bool initSymbolTable() {{
       void *handle = dlopen(hipLibSearchPaths[i], RTLD_LAZY | RTLD_LOCAL);
       if (handle) {{
         lib = handle;
-        // printf("[triton] chosen %s\\n", hipLibSearchPaths[i]);
       }}
     }}
   }}
@@ -345,7 +389,6 @@ static inline void gpuAssert(hipError_t code, const char *file, int line)
 #define HIP_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}

 static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, hipStream_t stream, hipFunction_t function{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
-  // printf("_launch hip kernel\\n");
   hipDeviceptr_t global_scratch = 0;
   void *params[] = {{ {', '.join(params)} }};
   if (gridX*gridY*gridZ > 0 && launch_cooperative_grid) {{
@@ -383,11 +426,14 @@ static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
     if (!PyLong_Check(ret)) {{
       PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
       ptr_info.valid = false;
+      Py_DECREF(ret);
       return ptr_info;
     }}
     ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(ret);
-    if(!ptr_info.dev_ptr)
+    if(!ptr_info.dev_ptr) {{
+      Py_DECREF(ret);
       return ptr_info;
+    }}
     uint64_t dev_ptr;
     hipError_t status = hipSymbolTable.hipPointerGetAttribute(&dev_ptr, HIP_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
     if (status == hipErrorInvalidValue) {{
@@ -403,8 +449,33 @@ static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
   return ptr_info;
 }}

+static uint16_t pack_fp16(double f) {{
+  uint16_t result;
+  // from https://github.com/python/pythoncapi-compat
+#if 0x030600B1 <= PY_VERSION_HEX && PY_VERSION_HEX <= 0x030B00A1 && !defined(PYPY_VERSION)
+  _PyFloat_Pack2(f, (unsigned char*)&result, 1);
+#else
+  PyFloat_Pack2(f, (unsigned char*)&result, 1);
+#endif
+  return result;
+}}
+
+static uint16_t pack_bf16(double f) {{
+  float f32 = (float)f;
+  uint32_t u32 = *(uint32_t*)&f32;
+  return (uint16_t)(u32 >> 16);
+}}
+
+static uint32_t pack_fp32(double f) {{
+  float f32 = (float)f;
+  return *(uint32_t*)&f32;
+}}
+
+static uint64_t pack_fp64(double f) {{
+  return *(uint64_t*)&f;
+}}
+
 static PyObject* launch(PyObject* self, PyObject* args) {{
-  // printf("launch\\n");
   int gridX, gridY, gridZ;
   uint64_t _stream;
   uint64_t _function;
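For reference, the bit patterns produced by the new pack_* helpers can be reproduced in plain Python with struct; note that pack_bf16 truncates the float32 pattern rather than rounding, matching the C code above:

    import struct

    def pack_fp32(x):
        # bit pattern of a float32, like *(uint32_t*)&f32
        return struct.unpack("<I", struct.pack("<f", x))[0]

    def pack_bf16(x):
        # bfloat16 keeps the top 16 bits of the float32 pattern
        return pack_fp32(x) >> 16

    def pack_fp16(x):
        # IEEE half precision, like PyFloat_Pack2 / _PyFloat_Pack2
        return struct.unpack("<H", struct.pack("<e", x))[0]

    assert pack_fp32(1.0) == 0x3F800000
    assert pack_bf16(1.0) == 0x3F80
    assert pack_fp16(1.0) == 0x3C00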
@@ -421,6 +492,8 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
     return NULL;
   }}

+  {' '.join(float_storage_decls)}
+
   // extract kernel metadata
   int num_warps, num_ctas, shared_memory, clusterDimX, clusterDimY, clusterDimZ;
   if (!PyArg_ParseTuple(kernel_metadata, \"iiiiii\", &num_warps, &num_ctas, &shared_memory, &clusterDimX, &clusterDimY, &clusterDimZ)) {{
@@ -433,6 +506,7 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
     Py_DECREF(args);
     if (!ret)
       return NULL;
+    Py_DECREF(ret);
   }}


@@ -446,6 +520,7 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
     Py_DECREF(args);
     if (!ret)
       return NULL;
+    Py_DECREF(ret);
   }}

   if(PyErr_Occurred()) {{
@@ -484,6 +559,31 @@ PyMODINIT_FUNC PyInit___triton_launcher(void) {{
     return src


+def wrap_handle_tensor_descriptor(launcher):
+    """
+    Replace all tensor descriptors with the base ptr, shape, and strides
+    """
+
+    def inner(*args):
+        meta_args = args[:len(_BASE_ARGS_FORMAT)]
+        raw_kernel_args = args[len(_BASE_ARGS_FORMAT):]
+        final_args = []
+        for arg in raw_kernel_args:
+            if isinstance(arg, TensorDescriptor):
+                # Currently the host side tensor descriptors get decomposed in
+                # the frontend to tensor desc, shape, and strides. We have no
+                # way to use these shape and strides when processing tensor
+                # descriptors which is why we provide our own decomposition
+                # above. Sadly this means we have to pass the shape and strides
+                # twice.
+                final_args.extend([arg.base, *arg.shape, *arg.strides, *arg.shape, *arg.strides])
+            else:
+                final_args.append(arg)
+        return launcher(*meta_args, *final_args)
+
+    return inner
+
+
 class HIPLauncher(object):

     def __init__(self, src, metadata):
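At call time the wrapper leaves ordinary arguments untouched and flattens each TensorDescriptor into the duplicated base/shape/stride slots that _expand_signature reserved. A sketch, assuming the TensorDescriptor dataclass exposes base, shape, strides and block_shape (buf is a hypothetical device buffer):

    from triton.tools.tensor_descriptor import TensorDescriptor

    desc = TensorDescriptor(base=buf, shape=[64, 64], strides=[64, 1], block_shape=[32, 32])

    # for kernel args (2, desc), inner() forwards
    #   2, buf, 64, 64, 64, 1, 64, 64, 64, 1
    # i.e. base pointer, shape, strides, then the same shape and strides again.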
@@ -492,8 +592,10 @@ class HIPLauncher(object):
         constants = {arg_idx(idx): value for idx, value in constants.items()}
         signature = {idx: value for idx, value in src.signature.items()}
         src = make_launcher(constants, signature, metadata.warp_size)
-        mod = compile_module_from_src(src, "__triton_launcher")
-        self.launch = mod.launch
+        mod = compile_module_from_src(src=src, name="__triton_launcher", include_dirs=include_dirs)
+        has_tensor_desc_arg = any(isinstance(sig, str) and sig.startswith("tensordesc") for sig in signature.values())
+
+        self.launch = wrap_handle_tensor_descriptor(mod.launch) if has_tensor_desc_arg else mod.launch
         self.launch_cooperative_grid = metadata.launch_cooperative_grid

     def __call__(self, *args):
@@ -515,14 +617,14 @@ class HIPDriver(GPUDriver):
     def is_active():
         try:
             import torch
-            return torch.version.hip is not None
+            return torch.cuda.is_available() and (torch.version.hip is not None)
         except ImportError:
             return False

     def get_current_target(self):
         device = self.get_current_device()
         device_properties = self.utils.get_device_properties(device)
-        arch = device_properties['arch']
+        arch = knobs.runtime.override_arch or device_properties['arch']
         warp_size = device_properties['warpSize']
         return GPUTarget("hip", arch.split(':')[0], warp_size)

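Both replacements route configuration through the knobs module introduced in this release instead of ad-hoc os.getenv calls. A sketch of overriding them programmatically; the attribute names come from the hunks above, while the sample values and the assumption that the knobs accept assignment are mine:

    from triton import knobs

    knobs.amd.libhip_path = "/opt/rocm/lib/libamdhip64.so"  # replaces the TRITON_LIBHIP_PATH lookup
    knobs.runtime.override_arch = "gfx942"                  # arch reported via GPUTarget instead of the device's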
triton/backends/compiler.py:

@@ -1,9 +1,6 @@
-import os
-import re
-import subprocess
-import sysconfig
 from abc import ABCMeta, abstractmethod
 from dataclasses import dataclass
+from enum import Enum
 from typing import Dict, Union
 from types import ModuleType

@@ -17,6 +14,12 @@ class GPUTarget(object):
     warp_size: int


+class Language(Enum):
+    """The input language being compiled by the backend."""
+    TRITON = 0
+    GLUON = 1
+
+
 class BaseBackend(metaclass=ABCMeta):

     def __init__(self, target: GPUTarget) -> None:
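The new module-level Language enum lets a backend tell Triton kernels apart from Gluon kernels during compilation. A hypothetical dispatch; the stage names and the way the language value reaches the backend are assumptions, not part of this diff:

    from triton.backends.compiler import Language

    def first_ir_stage(language):
        # assumption: Gluon sources start at the TritonGPU dialect rather than Triton IR
        if language == Language.TRITON:
            return "ttir"
        if language == Language.GLUON:
            return "ttgir"
        raise ValueError(f"unsupported language: {language}")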
@@ -24,23 +27,6 @@ class BaseBackend(metaclass=ABCMeta):
         assert self.supports_target(target)

     @staticmethod
-    def _path_to_binary(binary: str):
-        binary += sysconfig.get_config_var("EXE")
-        base_dir = os.path.join(os.path.dirname(__file__), os.pardir)
-        paths = [
-            os.environ.get(f"TRITON_{binary.upper()}_PATH", ""),
-            os.path.join(base_dir, "third_party", "cuda", "bin", binary),
-        ]
-        for path in paths:
-            if os.path.exists(path) and os.path.isfile(path):
-                result = subprocess.check_output([path, "--version"], stderr=subprocess.STDOUT)
-                if result is not None:
-                    version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE)
-                    if version is not None:
-                        return path, version.group(1)
-        raise RuntimeError(f"Cannot find {binary}")
-
-    @classmethod
     @abstractmethod
     def supports_target(target: GPUTarget):
         raise NotImplementedError