triton-windows 3.3.1.post19__cp312-cp312-win_amd64.whl → 3.5.0.post21__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic. Click here for more details.

Files changed (225) hide show
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +11 -2
  3. triton/_filecheck.py +97 -0
  4. triton/_internal_testing.py +95 -18
  5. triton/_utils.py +112 -21
  6. triton/backends/__init__.py +20 -23
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +161 -119
  9. triton/backends/amd/driver.c +118 -46
  10. triton/backends/amd/driver.py +274 -96
  11. triton/backends/compiler.py +7 -21
  12. triton/backends/driver.py +13 -0
  13. triton/backends/nvidia/bin/ptxas.exe +0 -0
  14. triton/backends/nvidia/compiler.py +163 -106
  15. triton/backends/nvidia/driver.c +166 -101
  16. triton/backends/nvidia/driver.py +384 -202
  17. triton/compiler/__init__.py +5 -2
  18. triton/compiler/code_generator.py +439 -231
  19. triton/compiler/compiler.py +152 -84
  20. triton/experimental/__init__.py +0 -0
  21. triton/experimental/gluon/__init__.py +5 -0
  22. triton/experimental/gluon/_compiler.py +0 -0
  23. triton/experimental/gluon/_runtime.py +102 -0
  24. triton/experimental/gluon/language/__init__.py +119 -0
  25. triton/experimental/gluon/language/_core.py +490 -0
  26. triton/experimental/gluon/language/_layouts.py +583 -0
  27. triton/experimental/gluon/language/_math.py +20 -0
  28. triton/experimental/gluon/language/_semantic.py +380 -0
  29. triton/experimental/gluon/language/_standard.py +80 -0
  30. triton/experimental/gluon/language/amd/__init__.py +4 -0
  31. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  32. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  33. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  34. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  35. triton/experimental/gluon/language/extra/__init__.py +3 -0
  36. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  37. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  38. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  39. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  40. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
  41. triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
  42. triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
  43. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
  44. triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
  45. triton/experimental/gluon/nvidia/__init__.py +4 -0
  46. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  47. triton/experimental/gluon/nvidia/hopper.py +45 -0
  48. triton/knobs.py +546 -0
  49. triton/language/__init__.py +50 -19
  50. triton/language/core.py +909 -572
  51. triton/language/extra/cuda/__init__.py +10 -7
  52. triton/language/extra/cuda/gdc.py +42 -0
  53. triton/language/extra/cuda/libdevice.py +394 -394
  54. triton/language/extra/cuda/utils.py +21 -21
  55. triton/language/extra/hip/__init__.py +3 -1
  56. triton/language/extra/hip/libdevice.py +120 -104
  57. triton/language/extra/hip/utils.py +35 -0
  58. triton/language/extra/libdevice.py +4 -0
  59. triton/language/math.py +65 -66
  60. triton/language/random.py +12 -2
  61. triton/language/semantic.py +1757 -1768
  62. triton/language/standard.py +127 -62
  63. triton/language/target_info.py +54 -0
  64. triton/runtime/_allocation.py +15 -3
  65. triton/runtime/_async_compile.py +55 -0
  66. triton/runtime/autotuner.py +117 -60
  67. triton/runtime/build.py +83 -17
  68. triton/runtime/cache.py +61 -47
  69. triton/runtime/driver.py +25 -47
  70. triton/runtime/interpreter.py +95 -50
  71. triton/runtime/jit.py +445 -248
  72. triton/runtime/tcc/include/_mingw.h +8 -10
  73. triton/runtime/tcc/include/assert.h +5 -0
  74. triton/runtime/tcc/include/errno.h +1 -1
  75. triton/runtime/tcc/include/float.h +21 -3
  76. triton/runtime/tcc/include/iso646.h +36 -0
  77. triton/runtime/tcc/include/limits.h +5 -0
  78. triton/runtime/tcc/include/malloc.h +2 -2
  79. triton/runtime/tcc/include/math.h +21 -261
  80. triton/runtime/tcc/include/stdalign.h +16 -0
  81. triton/runtime/tcc/include/stdarg.h +5 -70
  82. triton/runtime/tcc/include/stdatomic.h +171 -0
  83. triton/runtime/tcc/include/stddef.h +7 -19
  84. triton/runtime/tcc/include/stdlib.h +15 -4
  85. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  86. triton/runtime/tcc/include/sys/stat.h +2 -2
  87. triton/runtime/tcc/include/sys/types.h +5 -0
  88. triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
  89. triton/runtime/tcc/include/tccdefs.h +342 -0
  90. triton/runtime/tcc/include/tgmath.h +89 -0
  91. triton/runtime/tcc/include/uchar.h +33 -0
  92. triton/runtime/tcc/include/unistd.h +1 -0
  93. triton/runtime/tcc/include/winapi/qos.h +72 -0
  94. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  95. triton/runtime/tcc/include/winapi/winbase.h +9 -2
  96. triton/runtime/tcc/include/winapi/wincon.h +8 -0
  97. triton/runtime/tcc/include/winapi/windows.h +1 -1
  98. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  99. triton/runtime/tcc/include/winapi/winnt.h +9 -7
  100. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  101. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  102. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  103. triton/runtime/tcc/lib/libtcc1.a +0 -0
  104. triton/runtime/tcc/lib/python314.def +1800 -0
  105. triton/runtime/tcc/lib/python314t.def +1809 -0
  106. triton/runtime/tcc/libtcc.dll +0 -0
  107. triton/runtime/tcc/tcc.exe +0 -0
  108. triton/testing.py +16 -12
  109. triton/tools/compile.py +62 -14
  110. triton/tools/disasm.py +3 -4
  111. triton/tools/extra/cuda/compile.c +1 -0
  112. triton/tools/extra/hip/compile.cpp +66 -0
  113. triton/tools/extra/hip/compile.h +13 -0
  114. triton/tools/ragged_tma.py +92 -0
  115. triton/tools/tensor_descriptor.py +34 -0
  116. triton/windows_utils.py +52 -81
  117. {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +8 -4
  118. triton_windows-3.5.0.post21.dist-info/RECORD +217 -0
  119. triton_windows-3.5.0.post21.dist-info/entry_points.txt +3 -0
  120. triton_windows-3.5.0.post21.dist-info/licenses/LICENSE +23 -0
  121. triton_windows-3.5.0.post21.dist-info/top_level.txt +1 -0
  122. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
  123. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
  124. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
  125. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
  126. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
  127. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
  128. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
  129. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
  130. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
  131. triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
  132. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
  133. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
  134. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
  135. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
  136. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
  137. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
  138. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
  139. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
  140. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
  141. triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
  142. triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
  143. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
  144. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
  145. triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
  146. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
  147. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
  148. triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
  149. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
  150. triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
  151. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
  152. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
  153. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
  154. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
  155. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
  156. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
  157. triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
  158. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
  159. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
  160. triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
  161. triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
  162. triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
  163. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
  164. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
  165. triton/backends/amd/include/hip/channel_descriptor.h +0 -39
  166. triton/backends/amd/include/hip/device_functions.h +0 -38
  167. triton/backends/amd/include/hip/driver_types.h +0 -468
  168. triton/backends/amd/include/hip/hip_bf16.h +0 -36
  169. triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
  170. triton/backends/amd/include/hip/hip_common.h +0 -100
  171. triton/backends/amd/include/hip/hip_complex.h +0 -38
  172. triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
  173. triton/backends/amd/include/hip/hip_deprecated.h +0 -95
  174. triton/backends/amd/include/hip/hip_ext.h +0 -161
  175. triton/backends/amd/include/hip/hip_fp16.h +0 -36
  176. triton/backends/amd/include/hip/hip_fp8.h +0 -33
  177. triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
  178. triton/backends/amd/include/hip/hip_hcc.h +0 -24
  179. triton/backends/amd/include/hip/hip_math_constants.h +0 -36
  180. triton/backends/amd/include/hip/hip_profile.h +0 -27
  181. triton/backends/amd/include/hip/hip_runtime.h +0 -75
  182. triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
  183. triton/backends/amd/include/hip/hip_texture_types.h +0 -29
  184. triton/backends/amd/include/hip/hip_vector_types.h +0 -41
  185. triton/backends/amd/include/hip/hip_version.h +0 -17
  186. triton/backends/amd/include/hip/hiprtc.h +0 -421
  187. triton/backends/amd/include/hip/library_types.h +0 -78
  188. triton/backends/amd/include/hip/math_functions.h +0 -42
  189. triton/backends/amd/include/hip/surface_types.h +0 -63
  190. triton/backends/amd/include/hip/texture_types.h +0 -194
  191. triton/backends/amd/include/hsa/Brig.h +0 -1131
  192. triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
  193. triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
  194. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
  195. triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
  196. triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
  197. triton/backends/amd/include/hsa/hsa.h +0 -5738
  198. triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
  199. triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
  200. triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
  201. triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
  202. triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
  203. triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
  204. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
  205. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
  206. triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
  207. triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
  208. triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
  209. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
  210. triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
  211. triton/backends/amd/include/roctracer/roctracer.h +0 -779
  212. triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
  213. triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
  214. triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
  215. triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
  216. triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
  217. triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
  218. triton/backends/amd/include/roctracer/roctx.h +0 -229
  219. triton/language/_utils.py +0 -21
  220. triton/language/extra/cuda/_experimental_tma.py +0 -106
  221. triton/runtime/tcc/lib/libtcc1-64.a +0 -0
  222. triton/tools/experimental_descriptor.py +0 -32
  223. triton_windows-3.3.1.post19.dist-info/RECORD +0 -260
  224. triton_windows-3.3.1.post19.dist-info/top_level.txt +0 -14
  225. {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
@@ -1,37 +1,37 @@
1
1
  import functools
2
+ import operator
2
3
  import os
3
- import sysconfig
4
- import hashlib
5
4
  import subprocess
6
- import tempfile
5
+ import triton
6
+ import re
7
7
  from pathlib import Path
8
- from triton.runtime.build import _build
9
- from triton.runtime.cache import get_cache_manager
8
+ from triton import knobs
9
+ from triton.runtime.build import compile_module_from_src
10
10
  from triton.runtime import _allocation
11
11
  from triton.backends.compiler import GPUTarget
12
12
  from triton.backends.driver import GPUDriver
13
13
 
14
14
  dirname = os.path.dirname(os.path.realpath(__file__))
15
- include_dir = [os.path.join(dirname, "include")]
15
+ include_dirs = [os.path.join(dirname, "include")]
16
16
  if os.name == "nt":
17
17
  from triton.windows_utils import find_cuda
18
18
  _, cuda_inc_dirs, _ = find_cuda()
19
- include_dir += cuda_inc_dirs
19
+ include_dirs += cuda_inc_dirs
20
20
  libdevice_dir = os.path.join(dirname, "lib")
21
21
  libraries = ['cuda']
22
+ PyCUtensorMap = None
22
23
 
23
24
 
24
25
  @functools.lru_cache()
25
26
  def libcuda_dirs():
26
- env_libcuda_path = os.getenv("TRITON_LIBCUDA_PATH")
27
- if env_libcuda_path:
27
+ if env_libcuda_path := knobs.nvidia.libcuda_path:
28
28
  return [env_libcuda_path]
29
29
 
30
30
  if os.name == "nt":
31
31
  _, _, cuda_lib_dirs = find_cuda()
32
32
  return cuda_lib_dirs
33
33
 
34
- libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode()
34
+ libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode(errors="ignore")
35
35
  # each line looks like the following:
36
36
  # libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1
37
37
  locs = [line.split()[-1] for line in libs.splitlines() if "libcuda.so.1" in line]
@@ -55,36 +55,6 @@ def library_dirs():
55
55
  return [libdevice_dir, *libcuda_dirs()]
56
56
 
57
57
 
58
- @functools.lru_cache()
59
- def platform_key():
60
- from platform import machine, system, architecture
61
- return ",".join([machine(), system(), *architecture()])
62
-
63
-
64
- def compile_module_from_src(src, name):
65
- key = hashlib.sha256((src + platform_key()).encode("utf-8")).hexdigest()
66
- cache = get_cache_manager(key)
67
- ext = sysconfig.get_config_var("EXT_SUFFIX").split(".")[-1]
68
- cache_path = cache.get_file(f"{name}.{ext}")
69
- if cache_path is None:
70
- with tempfile.TemporaryDirectory() as tmpdir:
71
- src_path = os.path.join(tmpdir, f"{name}.c")
72
- with open(src_path, "w") as f:
73
- f.write(src)
74
- so = _build(name, src_path, tmpdir, library_dirs(), include_dir, libraries)
75
- with open(so, "rb") as f:
76
- cache_path = cache.put(f.read(), f"{name}.{ext}", binary=True)
77
-
78
- # Loading module with relative path may cause error
79
- cache_path = os.path.abspath(cache_path)
80
-
81
- import importlib.util
82
- spec = importlib.util.spec_from_file_location(name, cache_path)
83
- mod = importlib.util.module_from_spec(spec)
84
- spec.loader.exec_module(mod)
85
- return mod
86
-
87
-
88
58
  # ------------------------
89
59
  # Utils
90
60
  # ------------------------
@@ -98,13 +68,20 @@ class CudaUtils(object):
98
68
  return cls.instance
99
69
 
100
70
  def __init__(self):
101
- mod = compile_module_from_src(Path(os.path.join(dirname, "driver.c")).read_text(), "cuda_utils")
71
+ mod = compile_module_from_src(
72
+ src=Path(os.path.join(dirname, "driver.c")).read_text(),
73
+ name="cuda_utils",
74
+ library_dirs=library_dirs(),
75
+ include_dirs=include_dirs,
76
+ libraries=libraries,
77
+ )
78
+ global PyCUtensorMap
79
+ PyCUtensorMap = mod.PyCUtensorMap
102
80
  self.load_binary = mod.load_binary
103
81
  self.get_device_properties = mod.get_device_properties
104
82
  self.cuOccupancyMaxActiveClusters = mod.cuOccupancyMaxActiveClusters
105
83
  self.set_printf_fifo_size = mod.set_printf_fifo_size
106
- self.fill_1d_tma_descriptor = mod.fill_1d_tma_descriptor
107
- self.fill_2d_tma_descriptor = mod.fill_2d_tma_descriptor
84
+ self.fill_tma_descriptor = mod.fill_tma_descriptor
108
85
 
109
86
 
110
87
  # ------------------------
@@ -115,32 +92,95 @@ class CudaUtils(object):
115
92
  def ty_to_cpp(ty):
116
93
  if ty[0] == '*':
117
94
  return "CUdeviceptr"
95
+ if ty.startswith("tensordesc"):
96
+ return "CUtensorMap"
118
97
  return {
119
- "i1": "int32_t",
98
+ "i1": "int8_t",
120
99
  "i8": "int8_t",
121
100
  "i16": "int16_t",
122
101
  "i32": "int32_t",
123
102
  "i64": "int64_t",
124
- "u1": "uint32_t",
103
+ "u1": "uint8_t",
125
104
  "u8": "uint8_t",
126
105
  "u16": "uint16_t",
127
106
  "u32": "uint32_t",
128
107
  "u64": "uint64_t",
129
- "fp16": "float",
130
- "bf16": "float",
131
- "fp32": "float",
132
- "f32": "float",
108
+ "fp16": "double",
109
+ "bf16": "double",
110
+ "fp32": "double",
111
+ "f32": "double",
133
112
  "fp64": "double",
134
113
  "nvTmaDesc": "CUtensorMap",
135
114
  }[ty]
136
115
 
137
116
 
138
- def make_launcher(constants, signature):
139
-
140
- def _serialize_signature(sig):
117
+ FLOAT_STORAGE_TYPE = {
118
+ "fp16": "uint16_t",
119
+ "bf16": "uint16_t",
120
+ "fp32": "uint32_t",
121
+ "f32": "uint32_t",
122
+ "fp64": "uint64_t",
123
+ }
124
+ FLOAT_PACK_FUNCTION = {
125
+ "fp16": "pack_fp16",
126
+ "bf16": "pack_bf16",
127
+ "fp32": "pack_fp32",
128
+ "f32": "pack_fp32",
129
+ "fp64": "pack_fp64",
130
+ }
131
+
132
+ _BASE_ARGS_FORMAT = "iiiKKppOOOOOO"
133
+ _BASE_ARGS_FORMAT_LEN = len(_BASE_ARGS_FORMAT)
134
+
135
+
136
+ def make_launcher(constants, signature, tensordesc_meta):
137
+
138
+ def _expand_signature(signature):
139
+ output = []
140
+ tensordesc_idx = 0
141
+ # Expand tensor descriptor arguments into either nvTmaDesc, shape and
142
+ # strides, or base pointer, shape and strides depending on whether the
143
+ # kernel was lowered to use the nvTmaDesc or not.
144
+ for sig in signature:
145
+ if isinstance(sig, str) and sig.startswith("tensordesc"):
146
+ meta = tensordesc_meta[tensordesc_idx] if tensordesc_meta else None
147
+ tensordesc_idx += 1
148
+
149
+ match = re.match("tensordesc<([^[>]*)\\[([^]]*)\\]", sig)
150
+ dtype = match.group(1)
151
+ shape = match.group(2)
152
+ ndim = shape.count(",") + 1
153
+
154
+ if meta is None:
155
+ output.append("*" + dtype)
156
+ # Currently the host side tensor descriptors get passed in as a
157
+ # tensor desc, shape, and strides. We have no way to use these
158
+ # shape and strides when processing tensor descriptors which is
159
+ # why we provide our own decomposition above. Sadly this means
160
+ # we have to pass the shape and strides twice.
161
+ for _ in range(2 * ndim):
162
+ output.append("i64")
163
+ output.append("i1")
164
+ else:
165
+ output.append("nvTmaDesc")
166
+
167
+ for _ in range(ndim):
168
+ output.append("i32")
169
+ for _ in range(ndim):
170
+ output.append("i64")
171
+ else:
172
+ output.append(sig)
173
+
174
+ assert not tensordesc_meta or tensordesc_idx == len(tensordesc_meta)
175
+ return output
176
+
177
+ def _flatten_signature(sig, output):
178
+ # Flatten tuples
141
179
  if isinstance(sig, tuple):
142
- return ','.join(map(_serialize_signature, sig))
143
- return sig
180
+ for x in sig:
181
+ _flatten_signature(x, output)
182
+ else:
183
+ output.append(sig)
144
184
 
145
185
  def _extracted_type(ty):
146
186
  if isinstance(ty, tuple):
@@ -160,8 +200,9 @@ def make_launcher(constants, signature):
160
200
  return "O"
161
201
  if ty in ("constexpr", "nvTmaDesc"):
162
202
  return "O"
203
+ if ty.startswith("tensordesc"):
204
+ return "O"
163
205
  return {
164
- "float": "f",
165
206
  "double": "d",
166
207
  "long": "l",
167
208
  "int8_t": "b",
@@ -174,19 +215,34 @@ def make_launcher(constants, signature):
174
215
  "uint64_t": "K",
175
216
  }[ty_to_cpp(ty)]
176
217
 
218
+ expand_signature = _expand_signature(signature.values())
219
+ signature = {i: s for i, s in enumerate(expand_signature)}
220
+
177
221
  args_format = ''.join([format_of(ty) for ty in signature.values()])
178
- format = "iiiKKpOOOOO" + args_format
179
- signature = ','.join(map(_serialize_signature, signature.values()))
180
- signature = list(filter(bool, signature.split(',')))
181
- signature = {i: s for i, s in enumerate(signature)}
222
+ format = _BASE_ARGS_FORMAT + args_format
223
+
224
+ flat_signature = []
225
+ for sig in signature.values():
226
+ _flatten_signature(sig, flat_signature)
227
+ signature = {i: s for i, s in enumerate(flat_signature)}
182
228
  args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''
183
229
  # Record the end of regular arguments;
184
230
  # subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA.
185
- arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items() if ty != "constexpr")
231
+ arg_decl_list = []
232
+ for i, ty in signature.items():
233
+ if ty == "constexpr":
234
+ continue
235
+ if ty in FLOAT_STORAGE_TYPE:
236
+ arg_decl_list.append(f"{FLOAT_STORAGE_TYPE[ty]} arg{i}")
237
+ else:
238
+ arg_decl_list.append(f"{ty_to_cpp(ty)} arg{i}")
239
+ arg_decls = ', '.join(arg_decl_list)
186
240
  internal_args_list = []
187
241
  for i, ty in signature.items():
188
242
  if ty[0] == "*":
189
243
  internal_args_list.append(f"ptr_info{i}.dev_ptr")
244
+ elif ty in FLOAT_STORAGE_TYPE:
245
+ internal_args_list.append(f"_arg{i}_storage")
190
246
  elif ty == "nvTmaDesc":
191
247
  # Note: we have to dereference the pointer
192
248
  internal_args_list.append(f"*tma_ptr{i}")
@@ -205,15 +261,17 @@ def make_launcher(constants, signature):
205
261
  f"CUtensorMap* tma_ptr{i} = getTmaDesc(_arg{i}); if (!tma_ptr{i}) return NULL;" for i, ty in signature.items()
206
262
  if ty == "nvTmaDesc"
207
263
  ]
264
+ float_storage_decls = [
265
+ f"{FLOAT_STORAGE_TYPE[ty]} _arg{i}_storage = {FLOAT_PACK_FUNCTION[ty]}(_arg{i});"
266
+ for i, ty in signature.items()
267
+ if ty in FLOAT_STORAGE_TYPE
268
+ ]
208
269
  params = [f"&arg{i}" for i, ty in signature.items() if ty != "constexpr"]
209
270
  params.append("&global_scratch")
271
+ params.append("&profile_scratch")
210
272
  src = f"""
211
273
  #define _CRT_SECURE_NO_WARNINGS
212
274
  #include \"cuda.h\"
213
- #include <stdbool.h>
214
- #define PY_SSIZE_T_CLEAN
215
- #define Py_LIMITED_API 0x03090000
216
- #include <Python.h>
217
275
 
218
276
  #ifndef _WIN32
219
277
  #include <dlfcn.h>
@@ -222,6 +280,16 @@ def make_launcher(constants, signature):
222
280
  #include <windows.h>
223
281
  #endif
224
282
 
283
+ #include <stdbool.h>
284
+ #include <stdlib.h>
285
+ #define PY_SSIZE_T_CLEAN
286
+ #include <Python.h>
287
+
288
+ typedef struct {{
289
+ PyObject_HEAD
290
+ _Alignas(128) CUtensorMap tensorMap;
291
+ }} PyCUtensorMapObject;
292
+
225
293
  static inline void gpuAssert(CUresult code, const char *file, int line)
226
294
  {{
227
295
  if (code != CUDA_SUCCESS)
@@ -282,67 +350,65 @@ static cuLaunchKernelEx_t getLaunchKernelExHandle() {{
282
350
  }}
283
351
  #endif
284
352
 
285
- static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
353
+ static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int launch_pdl, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratch, CUdeviceptr profile_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
286
354
  void *params[] = {{ {', '.join(params)} }};
287
355
  if (gridX*gridY*gridZ > 0) {{
288
- if ((num_ctas == 1) && (0 == launch_cooperative_grid)) {{
289
- CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0));
290
- }} else if ((num_ctas == 1) && (0 != launch_cooperative_grid)) {{
291
- CUlaunchAttribute launchAttr[1];
356
+ // 4 attributes that we can currently pass maximum
357
+ CUlaunchAttribute launchAttr[4];
358
+ static cuLaunchKernelEx_t cuLaunchKernelExHandle = NULL;
359
+ if (cuLaunchKernelExHandle == NULL) {{
360
+ cuLaunchKernelExHandle = getLaunchKernelExHandle();
361
+ }}
362
+ CUlaunchConfig config;
363
+ config.gridDimX = gridX;
364
+ config.gridDimY = gridY;
365
+ config.gridDimZ = gridZ;
366
+
367
+ if (num_ctas != 1) {{
368
+ config.gridDimX *= clusterDimX;
369
+ config.gridDimY *= clusterDimY;
370
+ config.gridDimZ *= clusterDimZ;
371
+ }}
372
+
373
+ config.blockDimX = 32 * num_warps;
374
+ config.blockDimY = 1;
375
+ config.blockDimZ = 1;
376
+ config.sharedMemBytes = shared_memory;
377
+ config.hStream = stream;
378
+ config.attrs = launchAttr;
379
+ int num_attrs = 0;
380
+
381
+ if (launch_pdl != 0) {{
382
+ CUlaunchAttribute pdlAttr = {{ .id = CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_STREAM_SERIALIZATION, .value = 1}};
383
+ launchAttr[num_attrs] = pdlAttr;
384
+ ++num_attrs;
385
+ }}
386
+
387
+ if (launch_cooperative_grid != 0) {{
292
388
  CUlaunchAttribute coopAttr = {{ .id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE, .value = 1}};
293
- launchAttr[0] = coopAttr;
294
-
295
- CUlaunchConfig config;
296
- config.gridDimX = gridX;
297
- config.gridDimY = gridY;
298
- config.gridDimZ = gridZ;
299
- config.blockDimX = 32 * num_warps;
300
- config.blockDimY = 1;
301
- config.blockDimZ = 1;
302
- config.sharedMemBytes = shared_memory;
303
- config.hStream = stream;
304
- config.attrs = launchAttr;
305
- config.numAttrs = 1;
306
-
307
- static cuLaunchKernelEx_t cuLaunchKernelExHandle = NULL;
308
- if (cuLaunchKernelExHandle == NULL) {{
309
- cuLaunchKernelExHandle = getLaunchKernelExHandle();
310
- }}
311
- CUDA_CHECK(cuLaunchKernelExHandle(&config, function, params, 0));
312
-
313
- }} else {{
314
- CUlaunchAttribute launchAttr[3];
315
- launchAttr[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
316
- launchAttr[0].value.clusterDim.x = clusterDimX;
317
- launchAttr[0].value.clusterDim.y = clusterDimY;
318
- launchAttr[0].value.clusterDim.z = clusterDimZ;
319
- launchAttr[1].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE;
320
- launchAttr[1].value.clusterSchedulingPolicyPreference = CU_CLUSTER_SCHEDULING_POLICY_SPREAD;
321
-
322
- unsigned numAttrs = 2;
323
- if (0 != launch_cooperative_grid) {{
324
- CUlaunchAttribute coopAttr = {{ .id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE, .value = 1}};
325
- launchAttr[2] = coopAttr;
326
- numAttrs = 3;
327
- }}
328
-
329
- CUlaunchConfig config;
330
- config.gridDimX = gridX * clusterDimX;
331
- config.gridDimY = gridY * clusterDimY;
332
- config.gridDimZ = gridZ * clusterDimZ;
333
- config.blockDimX = 32 * num_warps;
334
- config.blockDimY = 1;
335
- config.blockDimZ = 1;
336
- config.sharedMemBytes = shared_memory;
337
- config.hStream = stream;
338
- config.attrs = launchAttr;
339
- config.numAttrs = numAttrs;
340
- static cuLaunchKernelEx_t cuLaunchKernelExHandle = NULL;
341
- if (cuLaunchKernelExHandle == NULL) {{
342
- cuLaunchKernelExHandle = getLaunchKernelExHandle();
343
- }}
344
- CUDA_CHECK(cuLaunchKernelExHandle(&config, function, params, 0));
389
+ launchAttr[num_attrs] = coopAttr;
390
+ ++num_attrs;
345
391
  }}
392
+
393
+ if (num_ctas != 1) {{
394
+ CUlaunchAttribute clusterAttr = {{}};
395
+ clusterAttr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
396
+ clusterAttr.value.clusterDim.x = clusterDimX;
397
+ clusterAttr.value.clusterDim.y = clusterDimY;
398
+ clusterAttr.value.clusterDim.z = clusterDimZ;
399
+ launchAttr[num_attrs] = clusterAttr;
400
+ ++num_attrs;
401
+
402
+ CUlaunchAttribute clusterSchedulingAttr = {{}};
403
+ clusterSchedulingAttr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE;
404
+ clusterSchedulingAttr.value.clusterSchedulingPolicyPreference = CU_CLUSTER_SCHEDULING_POLICY_SPREAD;
405
+ launchAttr[num_attrs] = clusterSchedulingAttr;
406
+ ++num_attrs;
407
+ }}
408
+
409
+ config.numAttrs = num_attrs;
410
+
411
+ CUDA_CHECK(cuLaunchKernelExHandle(&config, function, params, 0));
346
412
  }}
347
413
  }}
348
414
 
@@ -351,6 +417,9 @@ typedef struct _DevicePtrInfo {{
351
417
  bool valid;
352
418
  }} DevicePtrInfo;
353
419
 
420
+ static PyObject* data_ptr_str = NULL;
421
+ static PyObject* py_tensor_map_type = NULL;
422
+
354
423
  static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
355
424
  DevicePtrInfo ptr_info;
356
425
  ptr_info.dev_ptr = 0;
@@ -363,37 +432,35 @@ static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
363
432
  // valid nullptr
364
433
  return ptr_info;
365
434
  }}
366
- PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr");
367
- if(ptr){{
368
- PyObject *empty_tuple = PyTuple_New(0);
369
- PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL);
370
- Py_DECREF(empty_tuple);
371
- Py_DECREF(ptr);
372
- if (!PyLong_Check(ret)) {{
373
- PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
374
- ptr_info.valid = false;
375
- return ptr_info;
376
- }}
377
- ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(ret);
378
- if(!ptr_info.dev_ptr)
379
- return ptr_info;
380
- uint64_t dev_ptr;
381
- int status = cuPointerGetAttribute(&dev_ptr, CU_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
382
- if (status == CUDA_ERROR_INVALID_VALUE) {{
383
- PyErr_Format(PyExc_ValueError,
384
- "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
385
- ptr_info.valid = false;
386
- }} else if (status != CUDA_SUCCESS) {{
387
- CUDA_CHECK(status); // Catch any other cuda API errors
388
- ptr_info.valid = false;
389
- }}
390
- ptr_info.dev_ptr = dev_ptr;
391
- Py_DECREF(ret); // Thanks ChatGPT!
435
+ PyObject *ret = PyObject_CallMethodNoArgs(obj, data_ptr_str);
436
+ if (!ret) {{
437
+ PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
438
+ ptr_info.valid = false;
439
+ goto cleanup;
440
+ }}
441
+ if (!PyLong_Check(ret)) {{
442
+ PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
443
+ ptr_info.valid = false;
444
+ goto cleanup;
445
+ }}
446
+ ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(ret);
447
+ if(!ptr_info.dev_ptr)
392
448
  return ptr_info;
449
+ uint64_t dev_ptr;
450
+ int status = cuPointerGetAttribute(&dev_ptr, CU_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
451
+ if (status == CUDA_ERROR_INVALID_VALUE) {{
452
+ PyErr_Format(PyExc_ValueError,
453
+ "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
454
+ ptr_info.valid = false;
455
+ }} else if (status != CUDA_SUCCESS) {{
456
+ CUDA_CHECK(status); // Catch any other cuda API errors
457
+ ptr_info.valid = false;
393
458
  }}
394
- PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
395
- ptr_info.valid = false;
459
+ ptr_info.dev_ptr = dev_ptr;
460
+ cleanup:
461
+ Py_XDECREF(ret);
396
462
  return ptr_info;
463
+
397
464
  }}
398
465
 
399
466
  static inline CUtensorMap* getTmaDesc(PyObject *obj) {{
@@ -402,44 +469,18 @@ static inline CUtensorMap* getTmaDesc(PyObject *obj) {{
402
469
  return NULL;
403
470
  }}
404
471
 
405
- PyObject *method_handle = PyObject_GetAttrString(obj, "tma_desc_cpu_ptr");
406
- if (!method_handle) {{
407
- PyErr_SetString(PyExc_TypeError, "tma_desc_cpu_ptr() method does not exist");
408
- return NULL;
409
- }}
410
-
411
- PyObject *empty_tuple = PyTuple_New(0);
412
- if (!empty_tuple) {{
413
- Py_DECREF(method_handle);
414
- PyErr_SetString(PyExc_SystemError, "Internal Python error!");
415
- return NULL;
416
- }}
417
- PyObject *method_ret = PyObject_Call(method_handle, empty_tuple, NULL);
418
- Py_DECREF(empty_tuple);
419
- Py_DECREF(method_handle);
420
- if (!method_ret) {{
421
- PyErr_SetString(PyExc_SystemError, "Internal Python error!");
422
- return NULL;
423
- }}
424
-
425
- if (!PyLong_Check(method_ret)) {{
426
- PyErr_SetString(PyExc_TypeError, "tma_desc_cpu_ptr() must return 64-bit int");
427
- Py_DECREF(method_ret);
472
+ if (Py_TYPE(obj) != (PyTypeObject*)py_tensor_map_type) {{
473
+ PyErr_Format(PyExc_TypeError, "object must be of type PyCUtensorMap, got %s", Py_TYPE(obj)->tp_name);
428
474
  return NULL;
429
- }}
475
+ }}
430
476
 
431
- uint64_t ptr_as_uint = PyLong_AsUnsignedLongLong(method_ret);
432
- Py_DECREF(method_ret);
433
- if (!ptr_as_uint) {{
434
- PyErr_SetString(PyExc_ValueError, "received NULL ptr from tma_desc_cpu_ptr()");
435
- return NULL;
436
- }}
437
- if (ptr_as_uint % 64 != 0) {{
438
- PyErr_SetString(PyExc_ValueError, "tma_desc_cpu_ptr() must be 64-byte aligned");
477
+ CUtensorMap* map = &((PyCUtensorMapObject*)obj)->tensorMap;
478
+ uintptr_t align_128 = (uintptr_t)map & (128 - 1);
479
+ if (align_128 != 0) {{
480
+ PyErr_Format(PyExc_ValueError, "CUtensorMap must be aligned to 128B, but got (&map) mod 128 = %ld", align_128);
439
481
  return NULL;
440
482
  }}
441
-
442
- return (CUtensorMap*)(ptr_as_uint);
483
+ return map;
443
484
  }}
444
485
 
445
486
  static void ensureCudaContext() {{
@@ -454,6 +495,32 @@ static void ensureCudaContext() {{
454
495
  }}
455
496
  }}
456
497
 
498
+ static uint16_t pack_fp16(double f) {{
499
+ uint16_t result;
500
+ // from https://github.com/python/pythoncapi-compat
501
+ #if 0x030600B1 <= PY_VERSION_HEX && PY_VERSION_HEX <= 0x030B00A1 && !defined(PYPY_VERSION)
502
+ _PyFloat_Pack2(f, (unsigned char*)&result, 1);
503
+ #else
504
+ PyFloat_Pack2(f, (unsigned char*)&result, 1);
505
+ #endif
506
+ return result;
507
+ }}
508
+
509
+ static uint16_t pack_bf16(double f) {{
510
+ float f32 = (float)f;
511
+ uint32_t u32 = *(uint32_t*)&f32;
512
+ return (uint16_t)(u32 >> 16);
513
+ }}
514
+
515
+ static uint32_t pack_fp32(double f) {{
516
+ float f32 = (float)f;
517
+ return *(uint32_t*)&f32;
518
+ }}
519
+
520
+ static uint64_t pack_fp64(double f) {{
521
+ return *(uint64_t*)&f;
522
+ }}
523
+
457
524
  static PyObject* launch(PyObject* self, PyObject* args) {{
458
525
  // ensure cuda context is valid before calling any CUDA APIs, e.g. before getPointer calls cuPointerGetAttributes
459
526
  ensureCudaContext();
@@ -462,14 +529,16 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
462
529
  uint64_t _stream;
463
530
  uint64_t _function;
464
531
  int launch_cooperative_grid;
532
+ int launch_pdl;
465
533
  PyObject *launch_enter_hook = NULL;
466
534
  PyObject *launch_exit_hook = NULL;
467
535
  PyObject *kernel_metadata = NULL;
468
536
  PyObject *launch_metadata = NULL;
469
537
  PyObject *global_scratch_obj = NULL;
538
+ PyObject *profile_scratch_obj = NULL;
470
539
  {newline.join([f"{_extracted_type(ty)} _arg{i};" for i, ty in signature.items()])}
471
540
  if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ,
472
- &_stream, &_function, &launch_cooperative_grid, &global_scratch_obj,
541
+ &_stream, &_function, &launch_cooperative_grid, &launch_pdl, &global_scratch_obj, &profile_scratch_obj,
473
542
  &kernel_metadata, &launch_metadata,
474
543
  &launch_enter_hook, &launch_exit_hook{args_list})) {{
475
544
  return NULL;
@@ -483,11 +552,10 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
483
552
 
484
553
  // extract launch metadata
485
554
  if (launch_enter_hook != Py_None){{
486
- PyObject* args = Py_BuildValue("(O)", launch_metadata);
487
- PyObject* ret = PyObject_CallObject(launch_enter_hook, args);
488
- Py_DECREF(args);
555
+ PyObject* ret = PyObject_CallOneArg(launch_enter_hook, launch_metadata);
489
556
  if (!ret)
490
557
  return NULL;
558
+ Py_DECREF(ret);
491
559
  }}
492
560
 
493
561
  CUdeviceptr global_scratch = 0;
@@ -499,23 +567,31 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
499
567
  global_scratch = global_scratch_info.dev_ptr;
500
568
  }}
501
569
 
570
+ CUdeviceptr profile_scratch = 0;
571
+ if (profile_scratch_obj != Py_None) {{
572
+ DevicePtrInfo profile_scratch_info = getPointer(profile_scratch_obj, -1);
573
+ if (!profile_scratch_info.valid) {{
574
+ return NULL;
575
+ }}
576
+ profile_scratch = profile_scratch_info.dev_ptr;
577
+ }}
578
+
502
579
  // raise exception asap
503
580
  {newline.join(ptr_decls)}
504
581
  {newline.join(tma_decls)}
582
+ {newline.join(float_storage_decls)}
505
583
  Py_BEGIN_ALLOW_THREADS;
506
- _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function, global_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
584
+ _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, launch_pdl, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function, global_scratch, profile_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
507
585
  Py_END_ALLOW_THREADS;
508
586
  if (PyErr_Occurred()) {{
509
587
  return NULL;
510
588
  }}
511
589
 
512
590
  if(launch_exit_hook != Py_None){{
513
- PyObject* args = Py_BuildValue("(O)", launch_metadata);
514
- PyObject* ret = PyObject_CallObject(launch_exit_hook, args);
515
- Py_DECREF(args);
591
+ PyObject* ret = PyObject_CallOneArg(launch_exit_hook, launch_metadata);
516
592
  if (!ret)
517
593
  return NULL;
518
-
594
+ Py_DECREF(ret);
519
595
  }}
520
596
 
521
597
  Py_RETURN_NONE;
@@ -535,6 +611,19 @@ static struct PyModuleDef ModuleDef = {{
535
611
  }};
536
612
 
537
613
  PyMODINIT_FUNC PyInit___triton_launcher(void) {{
614
+ data_ptr_str = PyUnicode_InternFromString("data_ptr");
615
+ if(data_ptr_str == NULL) {{
616
+ return NULL;
617
+ }}
618
+ PyObject* driver_mod = PyImport_ImportModule("triton.backends.nvidia.driver");
619
+ if (driver_mod == NULL) {{
620
+ return NULL;
621
+ }}
622
+ py_tensor_map_type = PyObject_GetAttrString(driver_mod, "PyCUtensorMap");
623
+ if (py_tensor_map_type == NULL) {{
624
+ return NULL;
625
+ }}
626
+
538
627
  PyObject *m = PyModule_Create(&ModuleDef);
539
628
  if(m == NULL) {{
540
629
  return NULL;
@@ -546,6 +635,77 @@ PyMODINIT_FUNC PyInit___triton_launcher(void) {{
546
635
  return src
547
636
 
548
637
 
638
# The TMA dtype enum values are slightly different on host vs device.
# Keys are the device-side codes and values the matching host-side codes:
# the mapping is the identity for every code in 0..15 except that codes
# 8, 9 and 10 are permuted (8 -> 10, 9 -> 8, 10 -> 9).
TMA_DTYPE_DEVICE_TO_HOST = {
    **{code: code for code in range(16)},
    8: 10,
    9: 8,
    10: 9,
}
643
+
644
+
645
def make_tensordesc_arg(arg, metadata):
    """Expand one host-side tensor descriptor into flat launcher arguments.

    Args:
        arg: Tensor descriptor object exposing ``base``, ``shape``,
            ``strides`` and ``padding`` attributes.
        metadata: Mapping with the keys ``swizzle``, ``elem_size``,
            ``elem_type``, ``block_size`` and ``fp4_padded`` describing how
            the descriptor must be encoded, or ``None`` when no such
            metadata is available for this descriptor.

    Returns:
        A list of launcher arguments. Without metadata this is
        ``[base, *shape, *strides, padding == "nan", *shape, *strides]``.
        With metadata it is ``[cu_tensor_map, *shape, *strides]`` where
        ``cu_tensor_map`` is the value returned by ``fill_tma_descriptor``
        and ``shape`` has its last dimension doubled when ``fp4_padded`` is
        set.

    Raises:
        AssertionError: If ``metadata`` is given and the innermost stride of
            ``arg`` is not 1.
    """
    if metadata is None:
        # Currently the host side tensor descriptors get decomposed in
        # the frontend to tensor desc, shape, and strides. We have no
        # way to use these shape and strides when processing tensor
        # descriptors which is why we provide our own decomposition
        # above. Sadly this means we have to pass the shape and strides
        # twice.
        return [arg.base, *arg.shape, *arg.strides, arg.padding == "nan", *arg.shape, *arg.strides]

    swizzle = metadata["swizzle"]
    elem_size = metadata["elem_size"]
    elem_type = metadata["elem_type"]
    block_size = metadata["block_size"]
    fp4_padded = metadata["fp4_padded"]

    shape = arg.shape
    strides = arg.strides
    if strides[-1] != 1:
        # Raised explicitly instead of via ``assert`` so the check is not
        # stripped under ``python -O``; the exception type is kept as
        # AssertionError so existing callers observe the same failure.
        raise AssertionError(
            f"tensor descriptor must have a unit innermost stride, got strides={strides}")
    padding = 1 if arg.padding == "nan" else 0

    if fp4_padded:
        # Work on a copy so the caller's descriptor shape is left untouched.
        shape = list(shape)
        shape[-1] *= 2

    cu_tensor_map = triton.runtime.driver.active.utils.fill_tma_descriptor(
        arg.base.data_ptr(),
        swizzle,
        elem_size,
        TMA_DTYPE_DEVICE_TO_HOST[elem_type],
        block_size,
        shape,
        strides,
        padding,
    )

    return [cu_tensor_map, *shape, *strides]
682
+
683
+
684
def wrap_handle_tensordesc(launcher, signature, tensordesc_meta):
    """Wrap ``launcher`` so tensor-descriptor arguments are expanded first.

    Args:
        launcher: Callable that receives the final, flattened argument list.
        signature: Mapping whose values are the kernel argument types, in
            argument order. String values starting with ``"tensordesc"``
            mark tensor-descriptor arguments.
        tensordesc_meta: Sequence with one metadata entry per
            tensor-descriptor argument, in order of appearance, or a falsy
            value when no metadata is available (every descriptor is then
            expanded with ``None`` metadata).

    Returns:
        ``launcher`` itself when the signature has no tensor-descriptor
        argument. Otherwise a function that copies the first
        ``_BASE_ARGS_FORMAT_LEN`` positional arguments unchanged, replaces
        each tensor-descriptor argument after them with the result of
        ``make_tensordesc_arg`` and forwards everything to ``launcher``.
    """
    # Single pass over the signature: positions of the tensordesc arguments.
    tensordesc_indices = {
        i
        for i, sig in enumerate(signature.values())
        if isinstance(sig, str) and sig.startswith("tensordesc")
    }
    if not tensordesc_indices:
        return launcher

    assert not tensordesc_meta or len(tensordesc_meta) == len(tensordesc_indices), (
        f"expected {len(tensordesc_indices)} tensordesc metadata entries, "
        f"got {len(tensordesc_meta)}")
    if not tensordesc_meta:
        tensordesc_meta = [None] * len(tensordesc_indices)

    def inner(*args):
        final_args = list(args[:_BASE_ARGS_FORMAT_LEN])
        # Metadata entries are consumed in order of descriptor appearance.
        tensordesc_idx = 0
        for i, arg in enumerate(args[_BASE_ARGS_FORMAT_LEN:]):
            if i in tensordesc_indices:
                final_args.extend(make_tensordesc_arg(arg, tensordesc_meta[tensordesc_idx]))
                tensordesc_idx += 1
            else:
                final_args.append(arg)
        return launcher(*final_args)

    return inner
707
+
708
+
549
709
  class CudaLauncher(object):
550
710
 
551
711
  def __init__(self, src, metadata):
@@ -553,21 +713,40 @@ class CudaLauncher(object):
553
713
  arg_idx = lambda x: (src.fn.arg_names.index(x), ) if isinstance(x, str) else x
554
714
  constants = {arg_idx(idx): value for idx, value in constants.items()}
555
715
  signature = {idx: value for idx, value in src.signature.items()}
556
- src = make_launcher(constants, signature)
557
- mod = compile_module_from_src(src, "__triton_launcher")
558
- self.launch = mod.launch
716
+ tensordesc_meta = getattr(metadata, "tensordesc_meta", None)
717
+ src = make_launcher(constants, signature, tensordesc_meta)
718
+ mod = compile_module_from_src(
719
+ src=src,
720
+ name="__triton_launcher",
721
+ library_dirs=library_dirs(),
722
+ include_dirs=include_dirs,
723
+ libraries=libraries,
724
+ )
725
+
726
+ self.num_ctas = functools.reduce(operator.mul, metadata.cluster_dims, 1)
727
+ self.launch = wrap_handle_tensordesc(mod.launch, signature, tensordesc_meta)
559
728
  self.global_scratch_size = metadata.global_scratch_size
560
729
  self.global_scratch_align = metadata.global_scratch_align
730
+ self.profile_scratch_size = metadata.profile_scratch_size
731
+ self.profile_scratch_align = metadata.profile_scratch_align
561
732
  self.launch_cooperative_grid = metadata.launch_cooperative_grid
733
+ self.launch_pdl = metadata.launch_pdl
562
734
 
563
735
  def __call__(self, gridX, gridY, gridZ, stream, function, *args):
564
- if self.global_scratch_size > 0:
565
- grid_size = gridX * gridY * gridZ
566
- alloc_size = grid_size * self.global_scratch_size
567
- global_scratch = _allocation._allocator(alloc_size, self.global_scratch_align, stream)
568
- else:
569
- global_scratch = None
570
- self.launch(gridX, gridY, gridZ, stream, function, self.launch_cooperative_grid, global_scratch, *args)
736
+
737
+ def allocate_scratch(size, align, allocator):
738
+ if size > 0:
739
+ grid_size = gridX * gridY * gridZ
740
+ alloc_size = grid_size * self.num_ctas * size
741
+ alloc_fn = allocator.get()
742
+ return alloc_fn(alloc_size, align, stream)
743
+ return None
744
+
745
+ global_scratch = allocate_scratch(self.global_scratch_size, self.global_scratch_align, _allocation._allocator)
746
+ profile_scratch = allocate_scratch(self.profile_scratch_size, self.profile_scratch_align,
747
+ _allocation._profile_allocator)
748
+ self.launch(gridX, gridY, gridZ, stream, function, self.launch_cooperative_grid, self.launch_pdl,
749
+ global_scratch, profile_scratch, *args)
571
750
 
572
751
 
573
752
  class CudaDriver(GPUDriver):
@@ -600,6 +779,9 @@ class CudaDriver(GPUDriver):
600
779
  except ImportError:
601
780
  return False
602
781
 
782
+ def map_python_to_cpp_type(self, ty: str) -> str:
783
+ return ty_to_cpp(ty)
784
+
603
785
  def get_benchmarker(self):
604
786
  from triton.testing import do_bench
605
787
  return do_bench