triton-windows 3.3.1.post19__cp312-cp312-win_amd64.whl → 3.4.0.post20__cp312-cp312-win_amd64.whl

This diff reflects the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.

Potentially problematic release.

Files changed (166)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +4 -1
  3. triton/_filecheck.py +87 -0
  4. triton/_internal_testing.py +26 -15
  5. triton/_utils.py +110 -21
  6. triton/backends/__init__.py +20 -23
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +112 -78
  9. triton/backends/amd/driver.c +5 -2
  10. triton/backends/amd/driver.py +149 -47
  11. triton/backends/compiler.py +7 -21
  12. triton/backends/nvidia/bin/ptxas.exe +0 -0
  13. triton/backends/nvidia/compiler.py +92 -93
  14. triton/backends/nvidia/driver.c +90 -98
  15. triton/backends/nvidia/driver.py +303 -128
  16. triton/compiler/code_generator.py +212 -111
  17. triton/compiler/compiler.py +110 -25
  18. triton/experimental/__init__.py +0 -0
  19. triton/experimental/gluon/__init__.py +4 -0
  20. triton/experimental/gluon/_compiler.py +0 -0
  21. triton/experimental/gluon/_runtime.py +99 -0
  22. triton/experimental/gluon/language/__init__.py +18 -0
  23. triton/experimental/gluon/language/_core.py +312 -0
  24. triton/experimental/gluon/language/_layouts.py +230 -0
  25. triton/experimental/gluon/language/_math.py +12 -0
  26. triton/experimental/gluon/language/_semantic.py +287 -0
  27. triton/experimental/gluon/language/_standard.py +47 -0
  28. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  29. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +202 -0
  30. triton/experimental/gluon/language/nvidia/blackwell/tma.py +32 -0
  31. triton/experimental/gluon/language/nvidia/hopper/__init__.py +11 -0
  32. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +51 -0
  33. triton/experimental/gluon/language/nvidia/hopper/tma.py +96 -0
  34. triton/experimental/gluon/nvidia/__init__.py +4 -0
  35. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  36. triton/experimental/gluon/nvidia/hopper.py +40 -0
  37. triton/knobs.py +481 -0
  38. triton/language/__init__.py +39 -14
  39. triton/language/core.py +794 -537
  40. triton/language/extra/cuda/__init__.py +10 -7
  41. triton/language/extra/cuda/gdc.py +42 -0
  42. triton/language/extra/cuda/libdevice.py +394 -394
  43. triton/language/extra/cuda/utils.py +21 -21
  44. triton/language/extra/hip/libdevice.py +113 -104
  45. triton/language/math.py +65 -66
  46. triton/language/random.py +12 -2
  47. triton/language/semantic.py +1706 -1770
  48. triton/language/standard.py +116 -51
  49. triton/runtime/autotuner.py +117 -59
  50. triton/runtime/build.py +76 -12
  51. triton/runtime/cache.py +18 -47
  52. triton/runtime/driver.py +32 -29
  53. triton/runtime/interpreter.py +72 -35
  54. triton/runtime/jit.py +146 -110
  55. triton/testing.py +16 -12
  56. triton/tools/disasm.py +3 -4
  57. triton/tools/tensor_descriptor.py +36 -0
  58. triton/windows_utils.py +14 -6
  59. {triton_windows-3.3.1.post19.dist-info → triton_windows-3.4.0.post20.dist-info}/METADATA +7 -2
  60. triton_windows-3.4.0.post20.dist-info/RECORD +186 -0
  61. triton_windows-3.4.0.post20.dist-info/entry_points.txt +3 -0
  62. triton_windows-3.4.0.post20.dist-info/licenses/LICENSE +23 -0
  63. triton_windows-3.4.0.post20.dist-info/top_level.txt +1 -0
  64. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
  65. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
  66. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
  67. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
  68. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
  69. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
  70. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
  71. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
  72. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
  73. triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
  74. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
  75. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
  76. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
  77. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
  78. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
  79. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
  80. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
  81. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
  82. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
  83. triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
  84. triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
  85. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
  86. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
  87. triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
  88. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
  89. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
  90. triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
  91. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
  92. triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
  93. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
  94. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
  95. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
  96. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
  97. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
  98. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
  99. triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
  100. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
  101. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
  102. triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
  103. triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
  104. triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
  105. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
  106. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
  107. triton/backends/amd/include/hip/channel_descriptor.h +0 -39
  108. triton/backends/amd/include/hip/device_functions.h +0 -38
  109. triton/backends/amd/include/hip/driver_types.h +0 -468
  110. triton/backends/amd/include/hip/hip_bf16.h +0 -36
  111. triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
  112. triton/backends/amd/include/hip/hip_common.h +0 -100
  113. triton/backends/amd/include/hip/hip_complex.h +0 -38
  114. triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
  115. triton/backends/amd/include/hip/hip_deprecated.h +0 -95
  116. triton/backends/amd/include/hip/hip_ext.h +0 -161
  117. triton/backends/amd/include/hip/hip_fp16.h +0 -36
  118. triton/backends/amd/include/hip/hip_fp8.h +0 -33
  119. triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
  120. triton/backends/amd/include/hip/hip_hcc.h +0 -24
  121. triton/backends/amd/include/hip/hip_math_constants.h +0 -36
  122. triton/backends/amd/include/hip/hip_profile.h +0 -27
  123. triton/backends/amd/include/hip/hip_runtime.h +0 -75
  124. triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
  125. triton/backends/amd/include/hip/hip_texture_types.h +0 -29
  126. triton/backends/amd/include/hip/hip_vector_types.h +0 -41
  127. triton/backends/amd/include/hip/hip_version.h +0 -17
  128. triton/backends/amd/include/hip/hiprtc.h +0 -421
  129. triton/backends/amd/include/hip/library_types.h +0 -78
  130. triton/backends/amd/include/hip/math_functions.h +0 -42
  131. triton/backends/amd/include/hip/surface_types.h +0 -63
  132. triton/backends/amd/include/hip/texture_types.h +0 -194
  133. triton/backends/amd/include/hsa/Brig.h +0 -1131
  134. triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
  135. triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
  136. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
  137. triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
  138. triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
  139. triton/backends/amd/include/hsa/hsa.h +0 -5738
  140. triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
  141. triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
  142. triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
  143. triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
  144. triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
  145. triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
  146. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
  147. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
  148. triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
  149. triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
  150. triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
  151. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
  152. triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
  153. triton/backends/amd/include/roctracer/roctracer.h +0 -779
  154. triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
  155. triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
  156. triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
  157. triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
  158. triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
  159. triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
  160. triton/backends/amd/include/roctracer/roctx.h +0 -229
  161. triton/language/_utils.py +0 -21
  162. triton/language/extra/cuda/_experimental_tma.py +0 -106
  163. triton/tools/experimental_descriptor.py +0 -32
  164. triton_windows-3.3.1.post19.dist-info/RECORD +0 -260
  165. triton_windows-3.3.1.post19.dist-info/top_level.txt +0 -14
  166. {triton_windows-3.3.1.post19.dist-info → triton_windows-3.4.0.post20.dist-info}/WHEEL +0 -0

--- a/triton/backends/nvidia/driver.py
+++ b/triton/backends/nvidia/driver.py
@@ -1,36 +1,33 @@
 import functools
+import operator
 import os
-import sysconfig
-import hashlib
 import subprocess
-import tempfile
+import triton
+import re
 from pathlib import Path
-from triton.runtime.build import _build
-from triton.runtime.cache import get_cache_manager
+from triton import knobs
+from triton.runtime.build import compile_module_from_src
 from triton.runtime import _allocation
 from triton.backends.compiler import GPUTarget
 from triton.backends.driver import GPUDriver

 dirname = os.path.dirname(os.path.realpath(__file__))
-include_dir = [os.path.join(dirname, "include")]
+include_dirs = [os.path.join(dirname, "include")]
 if os.name == "nt":
     from triton.windows_utils import find_cuda
     _, cuda_inc_dirs, _ = find_cuda()
-    include_dir += cuda_inc_dirs
+    include_dirs += cuda_inc_dirs
 libdevice_dir = os.path.join(dirname, "lib")
 libraries = ['cuda']


 @functools.lru_cache()
 def libcuda_dirs():
-    env_libcuda_path = os.getenv("TRITON_LIBCUDA_PATH")
-    if env_libcuda_path:
+    if env_libcuda_path := knobs.nvidia.libcuda_path:
         return [env_libcuda_path]
-
     if os.name == "nt":
         _, _, cuda_lib_dirs = find_cuda()
         return cuda_lib_dirs
-
     libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode()
     # each line looks like the following:
     # libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1
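
Note: the TRITON_LIBCUDA_PATH lookup now goes through the new centralized knobs module (triton/knobs.py, added in this release) rather than a raw os.getenv call. A minimal sketch of the env-var-backed property pattern this implies; the class and names below are illustrative, not triton's actual implementation:

import os

class NvidiaKnobsSketch:
    """Hypothetical stand-in for knobs.nvidia; illustrates the pattern only."""

    @property
    def libcuda_path(self):
        # Returns None when the variable is unset, so the walrus test
        # `if path := knobs.nvidia.libcuda_path:` falls through to the
        # platform-specific search that follows it.
        return os.getenv("TRITON_LIBCUDA_PATH")
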
@@ -55,36 +52,6 @@ def library_dirs():
     return [libdevice_dir, *libcuda_dirs()]


-@functools.lru_cache()
-def platform_key():
-    from platform import machine, system, architecture
-    return ",".join([machine(), system(), *architecture()])
-
-
-def compile_module_from_src(src, name):
-    key = hashlib.sha256((src + platform_key()).encode("utf-8")).hexdigest()
-    cache = get_cache_manager(key)
-    ext = sysconfig.get_config_var("EXT_SUFFIX").split(".")[-1]
-    cache_path = cache.get_file(f"{name}.{ext}")
-    if cache_path is None:
-        with tempfile.TemporaryDirectory() as tmpdir:
-            src_path = os.path.join(tmpdir, f"{name}.c")
-            with open(src_path, "w") as f:
-                f.write(src)
-            so = _build(name, src_path, tmpdir, library_dirs(), include_dir, libraries)
-            with open(so, "rb") as f:
-                cache_path = cache.put(f.read(), f"{name}.{ext}", binary=True)
-
-    # Loading module with relative path may cause error
-    cache_path = os.path.abspath(cache_path)
-
-    import importlib.util
-    spec = importlib.util.spec_from_file_location(name, cache_path)
-    mod = importlib.util.module_from_spec(spec)
-    spec.loader.exec_module(mod)
-    return mod
-
-
 # ------------------------
 # Utils
 # ------------------------
@@ -98,13 +65,18 @@ class CudaUtils(object):
         return cls.instance

     def __init__(self):
-        mod = compile_module_from_src(Path(os.path.join(dirname, "driver.c")).read_text(), "cuda_utils")
+        mod = compile_module_from_src(
+            src=Path(os.path.join(dirname, "driver.c")).read_text(),
+            name="cuda_utils",
+            library_dirs=library_dirs(),
+            include_dirs=include_dirs,
+            libraries=libraries,
+        )
         self.load_binary = mod.load_binary
         self.get_device_properties = mod.get_device_properties
         self.cuOccupancyMaxActiveClusters = mod.cuOccupancyMaxActiveClusters
         self.set_printf_fifo_size = mod.set_printf_fifo_size
-        self.fill_1d_tma_descriptor = mod.fill_1d_tma_descriptor
-        self.fill_2d_tma_descriptor = mod.fill_2d_tma_descriptor
+        self.fill_tma_descriptor = mod.fill_tma_descriptor


 # ------------------------
@@ -115,6 +87,8 @@ class CudaUtils(object):
 def ty_to_cpp(ty):
     if ty[0] == '*':
         return "CUdeviceptr"
+    if ty.startswith("tensordesc"):
+        return "CUtensorMap"
     return {
         "i1": "int32_t",
         "i8": "int8_t",
@@ -126,21 +100,80 @@ def ty_to_cpp(ty):
         "u16": "uint16_t",
         "u32": "uint32_t",
         "u64": "uint64_t",
-        "fp16": "float",
-        "bf16": "float",
-        "fp32": "float",
-        "f32": "float",
+        "fp16": "double",
+        "bf16": "double",
+        "fp32": "double",
+        "f32": "double",
         "fp64": "double",
         "nvTmaDesc": "CUtensorMap",
     }[ty]


-def make_launcher(constants, signature):
-
-    def _serialize_signature(sig):
+FLOAT_STORAGE_TYPE = {
+    "fp16": "uint16_t",
+    "bf16": "uint16_t",
+    "fp32": "uint32_t",
+    "f32": "uint32_t",
+    "fp64": "uint64_t",
+}
+FLOAT_PACK_FUNCTION = {
+    "fp16": "pack_fp16",
+    "bf16": "pack_bf16",
+    "fp32": "pack_fp32",
+    "f32": "pack_fp32",
+    "fp64": "pack_fp64",
+}
+
+_BASE_ARGS_FORMAT = "iiiKKppOOOOO"
+
+
+def make_launcher(constants, signature, tensordesc_meta):
+
+    def _expand_signature(signature):
+        output = []
+        tensordesc_idx = 0
+        # Expand tensor descriptor arguments into either nvTmaDesc, shape and
+        # strides, or base pointer, shape and strides depending on whether the
+        # kernel was lowered to use the nvTmaDesc or not.
+        for sig in signature:
+            if isinstance(sig, str) and sig.startswith("tensordesc"):
+                meta = tensordesc_meta[tensordesc_idx] if tensordesc_meta else None
+                tensordesc_idx += 1
+
+                match = re.match("tensordesc<([^[>]*)\\[([^]]*)\\]", sig)
+                dtype = match.group(1)
+                shape = match.group(2)
+                ndim = shape.count(",") + 1
+
+                if meta is None:
+                    output.append("*" + dtype)
+                    # Currently the host side tensor descriptors get passed in as a
+                    # tensor desc, shape, and strides. We have no way to use these
+                    # shape and strides when processing tensor descriptors which is
+                    # why we provide our own decomposition above. Sadly this means
+                    # we have to pass the shape and strides twice.
+                    for _ in range(2 * ndim):
+                        output.append("i64")
+                else:
+                    output.append("nvTmaDesc")
+
+                    for _ in range(ndim):
+                        output.append("i32")
+                    for _ in range(ndim):
+                        output.append("i64")
+            else:
+                output.append(sig)
+
+        assert not tensordesc_meta or tensordesc_idx == len(tensordesc_meta)
+        return output
+
+    def _flatten_signature(sig, output):
+        # Flatten tuples
         if isinstance(sig, tuple):
-            return ','.join(map(_serialize_signature, sig))
-        return sig
+            for x in sig:
+                _flatten_signature(x, output)
+        else:
+            output.append(sig)

     def _extracted_type(ty):
         if isinstance(ty, tuple):
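
Note: a worked example of the expansion above for one 2D descriptor, using the same regex (standalone sketch; values are hypothetical). Without tensordesc metadata the argument becomes a plain pointer plus shape and strides passed twice as i64; with metadata it becomes an nvTmaDesc plus ndim i32 shape values and ndim i64 strides:

import re

sig = "tensordesc<fp16[64,64]>"
match = re.match("tensordesc<([^[>]*)\\[([^]]*)\\]", sig)
dtype, dims = match.group(1), match.group(2)
ndim = dims.count(",") + 1                       # 2

no_meta = ["*" + dtype] + ["i64"] * (2 * ndim)   # pointer + shape/strides twice
with_meta = ["nvTmaDesc"] + ["i32"] * ndim + ["i64"] * ndim

assert no_meta == ["*fp16", "i64", "i64", "i64", "i64"]
assert with_meta == ["nvTmaDesc", "i32", "i32", "i64", "i64"]
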
@@ -160,8 +193,9 @@ def make_launcher(constants, signature):
             return "O"
         if ty in ("constexpr", "nvTmaDesc"):
             return "O"
+        if ty.startswith("tensordesc"):
+            return "O"
         return {
-            "float": "f",
             "double": "d",
             "long": "l",
             "int8_t": "b",
@@ -174,19 +208,34 @@ def make_launcher(constants, signature):
             "uint64_t": "K",
         }[ty_to_cpp(ty)]

+    expand_signature = _expand_signature(signature.values())
+    signature = {i: s for i, s in enumerate(expand_signature)}
+
     args_format = ''.join([format_of(ty) for ty in signature.values()])
-    format = "iiiKKpOOOOO" + args_format
-    signature = ','.join(map(_serialize_signature, signature.values()))
-    signature = list(filter(bool, signature.split(',')))
-    signature = {i: s for i, s in enumerate(signature)}
+    format = _BASE_ARGS_FORMAT + args_format
+
+    flat_signature = []
+    for sig in signature.values():
+        _flatten_signature(sig, flat_signature)
+    signature = {i: s for i, s in enumerate(flat_signature)}
     args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''
     # Record the end of regular arguments;
     # subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA.
-    arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items() if ty != "constexpr")
+    arg_decl_list = []
+    for i, ty in signature.items():
+        if ty == "constexpr":
+            continue
+        if ty in FLOAT_STORAGE_TYPE:
+            arg_decl_list.append(f"{FLOAT_STORAGE_TYPE[ty]} arg{i}")
+        else:
+            arg_decl_list.append(f"{ty_to_cpp(ty)} arg{i}")
+    arg_decls = ', '.join(arg_decl_list)
     internal_args_list = []
     for i, ty in signature.items():
         if ty[0] == "*":
             internal_args_list.append(f"ptr_info{i}.dev_ptr")
+        elif ty in FLOAT_STORAGE_TYPE:
+            internal_args_list.append(f"_arg{i}_storage")
         elif ty == "nvTmaDesc":
             # Note: we have to dereference the pointer
             internal_args_list.append(f"*tma_ptr{i}")
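
Note: putting these hunks together, the PyArg_ParseTuple format string is now _BASE_ARGS_FORMAT (note the second "p", for the new launch_pdl flag) followed by one code per flattened kernel argument. A sketch for a hypothetical kernel with signature {0: "*fp32", 1: "i32", 2: "fp16"}:

# "*fp32" -> "O" (object with data_ptr), "i32" -> "i", and "fp16" now parses
# as a C double ("d") before being packed to its 16-bit storage type.
args_format = "O" + "i" + "d"
fmt = "iiiKKppOOOOO" + args_format  # _BASE_ARGS_FORMAT + per-argument codes
assert fmt == "iiiKKppOOOOOOid"
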
@@ -205,14 +254,17 @@ def make_launcher(constants, signature):
         f"CUtensorMap* tma_ptr{i} = getTmaDesc(_arg{i}); if (!tma_ptr{i}) return NULL;" for i, ty in signature.items()
         if ty == "nvTmaDesc"
     ]
+    float_storage_decls = [
+        f"{FLOAT_STORAGE_TYPE[ty]} _arg{i}_storage = {FLOAT_PACK_FUNCTION[ty]}(_arg{i});"
+        for i, ty in signature.items()
+        if ty in FLOAT_STORAGE_TYPE
+    ]
     params = [f"&arg{i}" for i, ty in signature.items() if ty != "constexpr"]
     params.append("&global_scratch")
     src = f"""
 #define _CRT_SECURE_NO_WARNINGS
 #include \"cuda.h\"
 #include <stdbool.h>
-#define PY_SSIZE_T_CLEAN
-#define Py_LIMITED_API 0x03090000
 #include <Python.h>

 #ifndef _WIN32
@@ -282,67 +334,65 @@ static cuLaunchKernelEx_t getLaunchKernelExHandle() {{
 }}
 #endif

-static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
+static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int launch_pdl, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
   void *params[] = {{ {', '.join(params)} }};
   if (gridX*gridY*gridZ > 0) {{
-    if ((num_ctas == 1) && (0 == launch_cooperative_grid)) {{
-      CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0));
-    }} else if ((num_ctas == 1) && (0 != launch_cooperative_grid)) {{
-      CUlaunchAttribute launchAttr[1];
+    // 4 attributes that we can currently pass maxmimum
+    CUlaunchAttribute launchAttr[4];
+    static cuLaunchKernelEx_t cuLaunchKernelExHandle = NULL;
+    if (cuLaunchKernelExHandle == NULL) {{
+      cuLaunchKernelExHandle = getLaunchKernelExHandle();
+    }}
+    CUlaunchConfig config;
+    config.gridDimX = gridX;
+    config.gridDimY = gridY;
+    config.gridDimZ = gridZ;
+
+    if (num_ctas != 1) {{
+      config.gridDimX *= clusterDimX;
+      config.gridDimY *= clusterDimY;
+      config.gridDimZ *= clusterDimZ;
+    }}
+
+    config.blockDimX = 32 * num_warps;
+    config.blockDimY = 1;
+    config.blockDimZ = 1;
+    config.sharedMemBytes = shared_memory;
+    config.hStream = stream;
+    config.attrs = launchAttr;
+    int num_attrs = 0;
+
+    if (launch_pdl != 0) {{
+      CUlaunchAttribute pdlAttr = {{ .id = CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_STREAM_SERIALIZATION, .value = 1}};
+      launchAttr[num_attrs] = pdlAttr;
+      ++num_attrs;
+    }}
+
+    if (launch_cooperative_grid != 0) {{
       CUlaunchAttribute coopAttr = {{ .id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE, .value = 1}};
-      launchAttr[0] = coopAttr;
-
-      CUlaunchConfig config;
-      config.gridDimX = gridX;
-      config.gridDimY = gridY;
-      config.gridDimZ = gridZ;
-      config.blockDimX = 32 * num_warps;
-      config.blockDimY = 1;
-      config.blockDimZ = 1;
-      config.sharedMemBytes = shared_memory;
-      config.hStream = stream;
-      config.attrs = launchAttr;
-      config.numAttrs = 1;
-
-      static cuLaunchKernelEx_t cuLaunchKernelExHandle = NULL;
-      if (cuLaunchKernelExHandle == NULL) {{
-        cuLaunchKernelExHandle = getLaunchKernelExHandle();
-      }}
-      CUDA_CHECK(cuLaunchKernelExHandle(&config, function, params, 0));
-
-    }} else {{
-      CUlaunchAttribute launchAttr[3];
-      launchAttr[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
-      launchAttr[0].value.clusterDim.x = clusterDimX;
-      launchAttr[0].value.clusterDim.y = clusterDimY;
-      launchAttr[0].value.clusterDim.z = clusterDimZ;
-      launchAttr[1].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE;
-      launchAttr[1].value.clusterSchedulingPolicyPreference = CU_CLUSTER_SCHEDULING_POLICY_SPREAD;
-
-      unsigned numAttrs = 2;
-      if (0 != launch_cooperative_grid) {{
-        CUlaunchAttribute coopAttr = {{ .id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE, .value = 1}};
-        launchAttr[2] = coopAttr;
-        numAttrs = 3;
-      }}
-
-      CUlaunchConfig config;
-      config.gridDimX = gridX * clusterDimX;
-      config.gridDimY = gridY * clusterDimY;
-      config.gridDimZ = gridZ * clusterDimZ;
-      config.blockDimX = 32 * num_warps;
-      config.blockDimY = 1;
-      config.blockDimZ = 1;
-      config.sharedMemBytes = shared_memory;
-      config.hStream = stream;
-      config.attrs = launchAttr;
-      config.numAttrs = numAttrs;
-      static cuLaunchKernelEx_t cuLaunchKernelExHandle = NULL;
-      if (cuLaunchKernelExHandle == NULL) {{
-        cuLaunchKernelExHandle = getLaunchKernelExHandle();
-      }}
-      CUDA_CHECK(cuLaunchKernelExHandle(&config, function, params, 0));
+      launchAttr[num_attrs] = coopAttr;
+      ++num_attrs;
+    }}
+
+    if (num_ctas != 1) {{
+      CUlaunchAttribute clusterAttr = {{}};
+      clusterAttr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
+      clusterAttr.value.clusterDim.x = clusterDimX;
+      clusterAttr.value.clusterDim.y = clusterDimY;
+      clusterAttr.value.clusterDim.z = clusterDimZ;
+      launchAttr[num_attrs] = clusterAttr;
+      ++num_attrs;
+
+      CUlaunchAttribute clusterSchedulingAttr = {{}};
+      clusterSchedulingAttr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE;
+      clusterSchedulingAttr.value.clusterSchedulingPolicyPreference = CU_CLUSTER_SCHEDULING_POLICY_SPREAD;
+      launchAttr[num_attrs] = clusterSchedulingAttr;
+      ++num_attrs;
     }}
+
+    config.numAttrs = num_attrs;
+
+    CUDA_CHECK(cuLaunchKernelExHandle(&config, function, params, 0));
   }}
 }}

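Note: the rewrite collapses the three old launch paths (plain, cooperative, clustered) into a single cuLaunchKernelEx call that appends attributes as needed; the launchAttr[4] bound matches the worst case. A sketch of that bookkeeping:

def count_launch_attrs(launch_pdl, launch_cooperative_grid, num_ctas):
    # Mirrors the generated C: PDL and cooperative launch add one attribute
    # each; a multi-CTA cluster adds the dimension and scheduling attributes.
    n = 0
    if launch_pdl:
        n += 1  # CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_STREAM_SERIALIZATION
    if launch_cooperative_grid:
        n += 1  # CU_LAUNCH_ATTRIBUTE_COOPERATIVE
    if num_ctas != 1:
        n += 2  # cluster dimension + scheduling policy preference
    return n

assert count_launch_attrs(True, True, 2) == 4  # the launchAttr[4] worst case
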
@@ -372,11 +422,14 @@ static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
     if (!PyLong_Check(ret)) {{
       PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
       ptr_info.valid = false;
+      Py_DECREF(ret);
       return ptr_info;
     }}
     ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(ret);
-    if(!ptr_info.dev_ptr)
+    if(!ptr_info.dev_ptr) {{
+      Py_DECREF(ret);
       return ptr_info;
+    }}
     uint64_t dev_ptr;
     int status = cuPointerGetAttribute(&dev_ptr, CU_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
     if (status == CUDA_ERROR_INVALID_VALUE) {{
@@ -388,7 +441,7 @@ static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
       ptr_info.valid = false;
     }}
     ptr_info.dev_ptr = dev_ptr;
-    Py_DECREF(ret); // Thanks ChatGPT!
+    Py_DECREF(ret);
     return ptr_info;
   }}
   PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
@@ -454,6 +507,32 @@ static void ensureCudaContext() {{
   }}
 }}

+static uint16_t pack_fp16(double f) {{
+  uint16_t result;
+  // from https://github.com/python/pythoncapi-compat
+#if 0x030600B1 <= PY_VERSION_HEX && PY_VERSION_HEX <= 0x030B00A1 && !defined(PYPY_VERSION)
+  _PyFloat_Pack2(f, (unsigned char*)&result, 1);
+#else
+  PyFloat_Pack2(f, (unsigned char*)&result, 1);
+#endif
+  return result;
+}}
+
+static uint16_t pack_bf16(double f) {{
+  float f32 = (float)f;
+  uint32_t u32 = *(uint32_t*)&f32;
+  return (uint16_t)(u32 >> 16);
+}}
+
+static uint32_t pack_fp32(double f) {{
+  float f32 = (float)f;
+  return *(uint32_t*)&f32;
+}}
+
+static uint64_t pack_fp64(double f) {{
+  return *(uint64_t*)&f;
+}}
+
 static PyObject* launch(PyObject* self, PyObject* args) {{
   // ensure cuda context is valid before calling any CUDA APIs, e.g. before getPointer calls cuPointerGetAttributes
   ensureCudaContext();
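
Note: scalar float arguments now arrive as Python floats (C doubles) and are bit-packed to their storage width before being handed to the kernel. Host-side equivalents of the pack_* helpers using struct, as a sketch for illustration (the launcher does this in C):

import struct

def pack_fp32(x: float) -> int:
    return struct.unpack("<I", struct.pack("<f", x))[0]

def pack_bf16(x: float) -> int:
    return pack_fp32(x) >> 16   # bf16 keeps the top 16 bits of the fp32 pattern

def pack_fp16(x: float) -> int:
    return struct.unpack("<H", struct.pack("<e", x))[0]

def pack_fp64(x: float) -> int:
    return struct.unpack("<Q", struct.pack("<d", x))[0]

assert pack_fp32(1.0) == 0x3F800000
assert pack_bf16(1.0) == 0x3F80
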
@@ -462,6 +541,7 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
   uint64_t _stream;
   uint64_t _function;
   int launch_cooperative_grid;
+  int launch_pdl;
   PyObject *launch_enter_hook = NULL;
   PyObject *launch_exit_hook = NULL;
   PyObject *kernel_metadata = NULL;
@@ -469,7 +549,7 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
   PyObject *global_scratch_obj = NULL;
   {newline.join([f"{_extracted_type(ty)} _arg{i};" for i, ty in signature.items()])}
   if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ,
-                       &_stream, &_function, &launch_cooperative_grid, &global_scratch_obj,
+                       &_stream, &_function, &launch_cooperative_grid, &launch_pdl, &global_scratch_obj,
                        &kernel_metadata, &launch_metadata,
                        &launch_enter_hook, &launch_exit_hook{args_list})) {{
     return NULL;
@@ -488,6 +568,7 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
     Py_DECREF(args);
     if (!ret)
       return NULL;
+    Py_DECREF(ret);
   }}

   CUdeviceptr global_scratch = 0;
@@ -502,8 +583,9 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
   // raise exception asap
   {newline.join(ptr_decls)}
   {newline.join(tma_decls)}
+  {newline.join(float_storage_decls)}
   Py_BEGIN_ALLOW_THREADS;
-  _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function, global_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
+  _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, launch_pdl, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function, global_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
   Py_END_ALLOW_THREADS;
   if (PyErr_Occurred()) {{
     return NULL;
@@ -515,7 +597,7 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
     Py_DECREF(args);
     if (!ret)
       return NULL;
-
+    Py_DECREF(ret);
   }}

   Py_RETURN_NONE;
@@ -546,6 +628,87 @@ PyMODINIT_FUNC PyInit___triton_launcher(void) {{
     return src


+class TmaDescKernelParam:
+    TMA_DESC_SIZE = 128
+
+    def __init__(self):
+        import torch
+        self.desc = torch.empty(self.TMA_DESC_SIZE, dtype=torch.uint8, device="cpu")
+
+    # Return a CUtensorMap* pointer in host memory
+    def tma_desc_cpu_ptr(self):
+        return self.desc.data_ptr()
+
+
+# The TMA dtype enum values are slightly different on host vs device...
+TMA_DTYPE_DEVICE_TO_HOST = dict((i, i) for i in range(16))
+TMA_DTYPE_DEVICE_TO_HOST[8] = 10
+TMA_DTYPE_DEVICE_TO_HOST[9] = 8
+TMA_DTYPE_DEVICE_TO_HOST[10] = 9
+
+
+def make_tensordesc_arg(arg, metadata):
+    if metadata is None:
+        # Currently the host side tensor descriptors get decomposed in
+        # the frontend to tensor desc, shape, and strides. We have no
+        # way to use these shape and strides when processing tensor
+        # descriptors which is why we provide our own decomposition
+        # above. Sadly this means we have to pass the shape and strides
+        # twice.
+        return [arg.base, *arg.shape, *arg.strides, *arg.shape, *arg.strides]
+
+    swizzle = metadata["swizzle"]
+    elem_size = metadata["elem_size"]
+    elem_type = metadata["elem_type"]
+    block_size = metadata["block_size"]
+    fp4_padded = metadata["fp4_padded"]
+
+    data_ptr = arg.base.data_ptr()
+    shape = arg.shape
+    strides = arg.strides
+    assert strides[-1] == 1
+
+    desc = TmaDescKernelParam()
+    result = [desc, *shape, *strides]
+
+    if fp4_padded:
+        shape = list(shape)
+        shape[-1] *= 2
+    triton.runtime.driver.active.utils.fill_tma_descriptor(
+        desc.tma_desc_cpu_ptr(),
+        data_ptr,
+        swizzle,
+        elem_size,
+        TMA_DTYPE_DEVICE_TO_HOST[elem_type],
+        block_size,
+        shape,
+        strides,
+    )
+    return result
+
+
+def wrap_handle_tensordesc(launcher, tensordesc_meta):
+    from triton.tools.tensor_descriptor import TensorDescriptor
+    from triton.experimental.gluon.nvidia.hopper import TensorDescriptor as GluonTensorDescriptor
+
+    def inner(*args):
+        meta_args = args[:len(_BASE_ARGS_FORMAT)]
+        raw_kernel_args = args[len(_BASE_ARGS_FORMAT):]
+        tensordesc_idx = 0
+        final_args = []
+        for i, arg in enumerate(raw_kernel_args):
+            if isinstance(arg, (TensorDescriptor, GluonTensorDescriptor)):
+                meta = tensordesc_meta[tensordesc_idx] if tensordesc_meta else None
+                tensordesc_idx += 1
+                final_args.extend(make_tensordesc_arg(arg, meta))
+            else:
+                final_args.append(arg)
+        assert not tensordesc_meta or tensordesc_idx == len(tensordesc_meta)
+        return launcher(*meta_args, *final_args)
+
+    return inner
+
+
 class CudaLauncher(object):

     def __init__(self, src, metadata):
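
Note: a worked example (pure-Python mock with hypothetical values) of what wrap_handle_tensordesc produces for one 2D descriptor. Without TMA metadata the descriptor decomposes to base pointer plus shape and strides twice, matching the _expand_signature fallback; with metadata it becomes a 128-byte CUtensorMap filled via fill_tma_descriptor, followed by shape and strides once:

from types import SimpleNamespace

desc = SimpleNamespace(base="<tensor>", shape=[64, 64], strides=[64, 1])

# No-metadata path of make_tensordesc_arg:
flat = [desc.base, *desc.shape, *desc.strides, *desc.shape, *desc.strides]
assert flat == ["<tensor>", 64, 64, 64, 1, 64, 64, 64, 1]
# Metadata path: [TmaDescKernelParam(), 64, 64, 64, 1] once the descriptor
# bytes have been filled in on the host.
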
@@ -553,21 +716,33 @@ class CudaLauncher(object):
         arg_idx = lambda x: (src.fn.arg_names.index(x), ) if isinstance(x, str) else x
         constants = {arg_idx(idx): value for idx, value in constants.items()}
         signature = {idx: value for idx, value in src.signature.items()}
-        src = make_launcher(constants, signature)
-        mod = compile_module_from_src(src, "__triton_launcher")
-        self.launch = mod.launch
+        tensordesc_meta = getattr(metadata, "tensordesc_meta", None)
+        src = make_launcher(constants, signature, tensordesc_meta)
+        mod = compile_module_from_src(
+            src=src,
+            name="__triton_launcher",
+            library_dirs=library_dirs(),
+            include_dirs=include_dirs,
+            libraries=libraries,
+        )
+        has_tensor_desc_arg = any(isinstance(sig, str) and sig.startswith("tensordesc") for sig in signature.values())
+
+        self.num_ctas = functools.reduce(operator.mul, metadata.cluster_dims, 1)
+        self.launch = wrap_handle_tensordesc(mod.launch, tensordesc_meta) if has_tensor_desc_arg else mod.launch
         self.global_scratch_size = metadata.global_scratch_size
         self.global_scratch_align = metadata.global_scratch_align
         self.launch_cooperative_grid = metadata.launch_cooperative_grid
+        self.launch_pdl = metadata.launch_pdl

     def __call__(self, gridX, gridY, gridZ, stream, function, *args):
         if self.global_scratch_size > 0:
             grid_size = gridX * gridY * gridZ
-            alloc_size = grid_size * self.global_scratch_size
+            alloc_size = grid_size * self.num_ctas * self.global_scratch_size
             global_scratch = _allocation._allocator(alloc_size, self.global_scratch_align, stream)
         else:
             global_scratch = None
-        self.launch(gridX, gridY, gridZ, stream, function, self.launch_cooperative_grid, global_scratch, *args)
+        self.launch(gridX, gridY, gridZ, stream, function, self.launch_cooperative_grid, self.launch_pdl,
+                    global_scratch, *args)


 class CudaDriver(GPUDriver):
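
Note: the scratch allocation now multiplies in the cluster size: metadata.cluster_dims describes CTAs per cluster, and each CTA gets its own global_scratch_size bytes. A worked example with hypothetical numbers:

import functools, operator

cluster_dims = (2, 1, 1)                 # CTAs per cluster
num_ctas = functools.reduce(operator.mul, cluster_dims, 1)   # 2
grid_size = 4 * 2 * 1                    # gridX * gridY * gridZ
global_scratch_size = 256                # bytes per CTA
alloc_size = grid_size * num_ctas * global_scratch_size
assert alloc_size == 4096                # previously grid_size * 256 == 2048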