triton-windows 3.3.1.post21__cp39-cp39-win_amd64.whl → 3.4.0.post21__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of triton-windows might be problematic.

Files changed (68)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +4 -1
  3. triton/_filecheck.py +87 -0
  4. triton/_internal_testing.py +26 -15
  5. triton/_utils.py +110 -21
  6. triton/backends/__init__.py +20 -23
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +112 -78
  9. triton/backends/amd/driver.c +5 -2
  10. triton/backends/amd/driver.py +143 -46
  11. triton/backends/compiler.py +7 -21
  12. triton/backends/nvidia/bin/ptxas.exe +0 -0
  13. triton/backends/nvidia/compiler.py +94 -94
  14. triton/backends/nvidia/driver.c +90 -98
  15. triton/backends/nvidia/driver.py +296 -125
  16. triton/compiler/code_generator.py +212 -111
  17. triton/compiler/compiler.py +110 -25
  18. triton/experimental/__init__.py +0 -0
  19. triton/experimental/gluon/__init__.py +4 -0
  20. triton/experimental/gluon/_compiler.py +0 -0
  21. triton/experimental/gluon/_runtime.py +99 -0
  22. triton/experimental/gluon/language/__init__.py +18 -0
  23. triton/experimental/gluon/language/_core.py +312 -0
  24. triton/experimental/gluon/language/_layouts.py +230 -0
  25. triton/experimental/gluon/language/_math.py +12 -0
  26. triton/experimental/gluon/language/_semantic.py +287 -0
  27. triton/experimental/gluon/language/_standard.py +47 -0
  28. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  29. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +202 -0
  30. triton/experimental/gluon/language/nvidia/blackwell/tma.py +32 -0
  31. triton/experimental/gluon/language/nvidia/hopper/__init__.py +11 -0
  32. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +51 -0
  33. triton/experimental/gluon/language/nvidia/hopper/tma.py +96 -0
  34. triton/experimental/gluon/nvidia/__init__.py +4 -0
  35. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  36. triton/experimental/gluon/nvidia/hopper.py +40 -0
  37. triton/knobs.py +481 -0
  38. triton/language/__init__.py +39 -14
  39. triton/language/core.py +794 -537
  40. triton/language/extra/cuda/__init__.py +10 -7
  41. triton/language/extra/cuda/gdc.py +42 -0
  42. triton/language/extra/cuda/libdevice.py +394 -394
  43. triton/language/extra/cuda/utils.py +21 -21
  44. triton/language/extra/hip/libdevice.py +113 -104
  45. triton/language/math.py +65 -66
  46. triton/language/random.py +12 -2
  47. triton/language/semantic.py +1706 -1770
  48. triton/language/standard.py +116 -51
  49. triton/runtime/autotuner.py +117 -59
  50. triton/runtime/build.py +73 -9
  51. triton/runtime/cache.py +18 -47
  52. triton/runtime/driver.py +32 -29
  53. triton/runtime/interpreter.py +72 -35
  54. triton/runtime/jit.py +146 -110
  55. triton/testing.py +16 -12
  56. triton/tools/disasm.py +3 -4
  57. triton/tools/tensor_descriptor.py +36 -0
  58. triton/windows_utils.py +47 -83
  59. {triton_windows-3.3.1.post21.dist-info → triton_windows-3.4.0.post21.dist-info}/METADATA +7 -2
  60. {triton_windows-3.3.1.post21.dist-info → triton_windows-3.4.0.post21.dist-info}/RECORD +64 -41
  61. triton_windows-3.4.0.post21.dist-info/entry_points.txt +3 -0
  62. triton_windows-3.4.0.post21.dist-info/licenses/LICENSE +23 -0
  63. triton_windows-3.4.0.post21.dist-info/top_level.txt +1 -0
  64. triton/language/_utils.py +0 -21
  65. triton/language/extra/cuda/_experimental_tma.py +0 -106
  66. triton/tools/experimental_descriptor.py +0 -32
  67. triton_windows-3.3.1.post21.dist-info/top_level.txt +0 -14
  68. {triton_windows-3.3.1.post21.dist-info → triton_windows-3.4.0.post21.dist-info}/WHEEL +0 -0
@@ -1,36 +1,33 @@
  import functools
+ import operator
  import os
- import sysconfig
- import hashlib
  import subprocess
- import tempfile
+ import triton
+ import re
  from pathlib import Path
- from triton.runtime.build import _build
- from triton.runtime.cache import get_cache_manager
+ from triton import knobs
+ from triton.runtime.build import compile_module_from_src
  from triton.runtime import _allocation
  from triton.backends.compiler import GPUTarget
  from triton.backends.driver import GPUDriver

  dirname = os.path.dirname(os.path.realpath(__file__))
- include_dir = [os.path.join(dirname, "include")]
+ include_dirs = [os.path.join(dirname, "include")]
  if os.name == "nt":
      from triton.windows_utils import find_cuda
      _, cuda_inc_dirs, _ = find_cuda()
-     include_dir += cuda_inc_dirs
+     include_dirs += cuda_inc_dirs
  libdevice_dir = os.path.join(dirname, "lib")
  libraries = ['cuda']


  @functools.lru_cache()
  def libcuda_dirs():
-     env_libcuda_path = os.getenv("TRITON_LIBCUDA_PATH")
-     if env_libcuda_path:
+     if env_libcuda_path := knobs.nvidia.libcuda_path:
          return [env_libcuda_path]
-
      if os.name == "nt":
          _, _, cuda_lib_dirs = find_cuda()
          return cuda_lib_dirs
-
      libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode()
      # each line looks like the following:
      # libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1
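The libcuda lookup above now goes through the central triton.knobs registry (triton/knobs.py is new in this release, see the file list) instead of reading TRITON_LIBCUDA_PATH directly. A minimal sketch of overriding the search path from Python, assuming the knob stays writable at runtime and still falls back to the TRITON_LIBCUDA_PATH environment variable:

    from triton import knobs

    # Illustrative override; roughly equivalent to exporting TRITON_LIBCUDA_PATH before 3.4.
    knobs.nvidia.libcuda_path = "/usr/lib/x86_64-linux-gnu"

    # libcuda_dirs() above then short-circuits to [knobs.nvidia.libcuda_path]
    # instead of probing find_cuda() on Windows or /sbin/ldconfig on Linux.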
@@ -55,36 +52,6 @@ def library_dirs():
      return [libdevice_dir, *libcuda_dirs()]


- @functools.lru_cache()
- def platform_key():
-     from platform import machine, system, architecture
-     return ",".join([machine(), system(), *architecture()])
-
-
- def compile_module_from_src(src, name):
-     key = hashlib.sha256((src + platform_key()).encode("utf-8")).hexdigest()
-     cache = get_cache_manager(key)
-     ext = sysconfig.get_config_var("EXT_SUFFIX").split(".")[-1]
-     cache_path = cache.get_file(f"{name}.{ext}")
-     if cache_path is None:
-         with tempfile.TemporaryDirectory() as tmpdir:
-             src_path = os.path.join(tmpdir, f"{name}.c")
-             with open(src_path, "w") as f:
-                 f.write(src)
-             so = _build(name, src_path, tmpdir, library_dirs(), include_dir, libraries)
-             with open(so, "rb") as f:
-                 cache_path = cache.put(f.read(), f"{name}.{ext}", binary=True)
-
-     # Loading module with relative path may cause error
-     cache_path = os.path.abspath(cache_path)
-
-     import importlib.util
-     spec = importlib.util.spec_from_file_location(name, cache_path)
-     mod = importlib.util.module_from_spec(spec)
-     spec.loader.exec_module(mod)
-     return mod
-
-
  # ------------------------
  # Utils
  # ------------------------
@@ -98,13 +65,18 @@ class CudaUtils(object):
          return cls.instance

      def __init__(self):
-         mod = compile_module_from_src(Path(os.path.join(dirname, "driver.c")).read_text(), "cuda_utils")
+         mod = compile_module_from_src(
+             src=Path(os.path.join(dirname, "driver.c")).read_text(),
+             name="cuda_utils",
+             library_dirs=library_dirs(),
+             include_dirs=include_dirs,
+             libraries=libraries,
+         )
          self.load_binary = mod.load_binary
          self.get_device_properties = mod.get_device_properties
          self.cuOccupancyMaxActiveClusters = mod.cuOccupancyMaxActiveClusters
          self.set_printf_fifo_size = mod.set_printf_fifo_size
-         self.fill_1d_tma_descriptor = mod.fill_1d_tma_descriptor
-         self.fill_2d_tma_descriptor = mod.fill_2d_tma_descriptor
+         self.fill_tma_descriptor = mod.fill_tma_descriptor


  # ------------------------
@@ -115,6 +87,8 @@ class CudaUtils(object):
  def ty_to_cpp(ty):
      if ty[0] == '*':
          return "CUdeviceptr"
+     if ty.startswith("tensordesc"):
+         return "CUtensorMap"
      return {
          "i1": "int32_t",
          "i8": "int8_t",
@@ -126,21 +100,80 @@ def ty_to_cpp(ty):
          "u16": "uint16_t",
          "u32": "uint32_t",
          "u64": "uint64_t",
-         "fp16": "float",
-         "bf16": "float",
-         "fp32": "float",
-         "f32": "float",
+         "fp16": "double",
+         "bf16": "double",
+         "fp32": "double",
+         "f32": "double",
          "fp64": "double",
          "nvTmaDesc": "CUtensorMap",
      }[ty]


- def make_launcher(constants, signature):
-
-     def _serialize_signature(sig):
+ FLOAT_STORAGE_TYPE = {
+     "fp16": "uint16_t",
+     "bf16": "uint16_t",
+     "fp32": "uint32_t",
+     "f32": "uint32_t",
+     "fp64": "uint64_t",
+ }
+ FLOAT_PACK_FUNCTION = {
+     "fp16": "pack_fp16",
+     "bf16": "pack_bf16",
+     "fp32": "pack_fp32",
+     "f32": "pack_fp32",
+     "fp64": "pack_fp64",
+ }
+
+ _BASE_ARGS_FORMAT = "iiiKKppOOOOO"
+
+
+ def make_launcher(constants, signature, tensordesc_meta):
+
+     def _expand_signature(signature):
+         output = []
+         tensordesc_idx = 0
+         # Expand tensor descriptor arguments into either nvTmaDesc, shape and
+         # strides, or base pointer, shape and strides depending on whether the
+         # kernel was lowered to use the nvTmaDesc or not.
+         for sig in signature:
+             if isinstance(sig, str) and sig.startswith("tensordesc"):
+                 meta = tensordesc_meta[tensordesc_idx] if tensordesc_meta else None
+                 tensordesc_idx += 1
+
+                 match = re.match("tensordesc<([^[>]*)\\[([^]]*)\\]", sig)
+                 dtype = match.group(1)
+                 shape = match.group(2)
+                 ndim = shape.count(",") + 1
+
+                 if meta is None:
+                     output.append("*" + dtype)
+                     # Currently the host side tensor descriptors get passed in as a
+                     # tensor desc, shape, and strides. We have no way to use these
+                     # shape and strides when processing tensor descriptors which is
+                     # why we provide our own decomposition above. Sadly this means
+                     # we have to pass the shape and strides twice.
+                     for _ in range(2 * ndim):
+                         output.append("i64")
+                 else:
+                     output.append("nvTmaDesc")
+
+                     for _ in range(ndim):
+                         output.append("i32")
+                     for _ in range(ndim):
+                         output.append("i64")
+             else:
+                 output.append(sig)
+
+         assert not tensordesc_meta or tensordesc_idx == len(tensordesc_meta)
+         return output
+
+     def _flatten_signature(sig, output):
+         # Flatten tuples
          if isinstance(sig, tuple):
-             return ','.join(map(_serialize_signature, sig))
-         return sig
+             for x in sig:
+                 _flatten_signature(x, output)
+         else:
+             output.append(sig)

      def _extracted_type(ty):
          if isinstance(ty, tuple):
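To make the _expand_signature behaviour above concrete, here is a small standalone sketch (hypothetical signature value, same regex as the code above) showing how one tensordesc argument expands in the two lowering cases:

    import re

    sig = "tensordesc<fp16[64, 64]>"
    match = re.match("tensordesc<([^[>]*)\\[([^]]*)\\]", sig)
    dtype, shape = match.group(1), match.group(2)
    ndim = shape.count(",") + 1  # 2

    # Host-side descriptor (no tensordesc metadata): base pointer plus shape and
    # strides passed twice as i64, i.e. 1 + 2 * ndim entries.
    host = ["*" + dtype] + ["i64"] * (2 * ndim)   # ['*fp16', 'i64', 'i64', 'i64', 'i64']

    # Kernel lowered to a device TMA descriptor: one nvTmaDesc plus i32 shape
    # and i64 strides.
    tma = ["nvTmaDesc"] + ["i32"] * ndim + ["i64"] * ndim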
@@ -160,8 +193,9 @@ def make_launcher(constants, signature):
              return "O"
          if ty in ("constexpr", "nvTmaDesc"):
              return "O"
+         if ty.startswith("tensordesc"):
+             return "O"
          return {
-             "float": "f",
              "double": "d",
              "long": "l",
              "int8_t": "b",
@@ -174,19 +208,34 @@ def make_launcher(constants, signature):
              "uint64_t": "K",
          }[ty_to_cpp(ty)]

+     expand_signature = _expand_signature(signature.values())
+     signature = {i: s for i, s in enumerate(expand_signature)}
+
      args_format = ''.join([format_of(ty) for ty in signature.values()])
-     format = "iiiKKpOOOOO" + args_format
-     signature = ','.join(map(_serialize_signature, signature.values()))
-     signature = list(filter(bool, signature.split(',')))
-     signature = {i: s for i, s in enumerate(signature)}
+     format = _BASE_ARGS_FORMAT + args_format
+
+     flat_signature = []
+     for sig in signature.values():
+         _flatten_signature(sig, flat_signature)
+     signature = {i: s for i, s in enumerate(flat_signature)}
      args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''
      # Record the end of regular arguments;
      # subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA.
-     arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items() if ty != "constexpr")
+     arg_decl_list = []
+     for i, ty in signature.items():
+         if ty == "constexpr":
+             continue
+         if ty in FLOAT_STORAGE_TYPE:
+             arg_decl_list.append(f"{FLOAT_STORAGE_TYPE[ty]} arg{i}")
+         else:
+             arg_decl_list.append(f"{ty_to_cpp(ty)} arg{i}")
+     arg_decls = ', '.join(arg_decl_list)
      internal_args_list = []
      for i, ty in signature.items():
          if ty[0] == "*":
              internal_args_list.append(f"ptr_info{i}.dev_ptr")
+         elif ty in FLOAT_STORAGE_TYPE:
+             internal_args_list.append(f"_arg{i}_storage")
          elif ty == "nvTmaDesc":
              # Note: we have to dereference the pointer
              internal_args_list.append(f"*tma_ptr{i}")
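The FLOAT_STORAGE_TYPE path above changes how scalar float kernel arguments travel through the launcher: they are parsed from Python as doubles (ty_to_cpp now maps fp16/bf16/fp32 to double, so format_of yields "d") and then re-packed into an integer of the kernel's real width. A hypothetical signature makes the generated pieces easier to see; the "i" format char for int32_t is assumed from the elided part of the format table:

    # Assumed example signature: a float32 pointer, an fp32 scalar and an i32 scalar.
    signature = {0: "*fp32", 1: "fp32", 2: "i32"}

    # format_of produces "O" for the pointer, "d" for the double-parsed float and
    # "i" for the int32, appended to _BASE_ARGS_FORMAT ("iiiKKppOOOOO").
    args_format = "Odi"

    # arg_decls declares the fp32 scalar with its storage type, not double:
    #   "CUdeviceptr arg0, uint32_t arg1, int32_t arg2"
    # internal_args_list passes the packed value, produced by the generated
    # float_storage_decls line in the next hunk:
    #   "uint32_t _arg1_storage = pack_fp32(_arg1);"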
@@ -205,14 +254,17 @@ def make_launcher(constants, signature):
          f"CUtensorMap* tma_ptr{i} = getTmaDesc(_arg{i}); if (!tma_ptr{i}) return NULL;" for i, ty in signature.items()
          if ty == "nvTmaDesc"
      ]
+     float_storage_decls = [
+         f"{FLOAT_STORAGE_TYPE[ty]} _arg{i}_storage = {FLOAT_PACK_FUNCTION[ty]}(_arg{i});"
+         for i, ty in signature.items()
+         if ty in FLOAT_STORAGE_TYPE
+     ]
      params = [f"&arg{i}" for i, ty in signature.items() if ty != "constexpr"]
      params.append("&global_scratch")
      src = f"""
  #define _CRT_SECURE_NO_WARNINGS
  #include \"cuda.h\"
  #include <stdbool.h>
- #define PY_SSIZE_T_CLEAN
- #define Py_LIMITED_API 0x03090000
  #include <Python.h>

  #ifndef _WIN32
@@ -282,67 +334,65 @@ static cuLaunchKernelEx_t getLaunchKernelExHandle() {{
  }}
  #endif

- static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
+ static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int launch_pdl, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
    void *params[] = {{ {', '.join(params)} }};
    if (gridX*gridY*gridZ > 0) {{
-     if ((num_ctas == 1) && (0 == launch_cooperative_grid)) {{
-       CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0));
-     }} else if ((num_ctas == 1) && (0 != launch_cooperative_grid)) {{
-       CUlaunchAttribute launchAttr[1];
+     // 4 attributes that we can currently pass maxmimum
+     CUlaunchAttribute launchAttr[4];
+     static cuLaunchKernelEx_t cuLaunchKernelExHandle = NULL;
+     if (cuLaunchKernelExHandle == NULL) {{
+       cuLaunchKernelExHandle = getLaunchKernelExHandle();
+     }}
+     CUlaunchConfig config;
+     config.gridDimX = gridX;
+     config.gridDimY = gridY;
+     config.gridDimZ = gridZ;
+
+     if (num_ctas != 1) {{
+       config.gridDimX *= clusterDimX;
+       config.gridDimY *= clusterDimY;
+       config.gridDimZ *= clusterDimZ;
+     }}
+
+     config.blockDimX = 32 * num_warps;
+     config.blockDimY = 1;
+     config.blockDimZ = 1;
+     config.sharedMemBytes = shared_memory;
+     config.hStream = stream;
+     config.attrs = launchAttr;
+     int num_attrs = 0;
+
+     if (launch_pdl != 0) {{
+       CUlaunchAttribute pdlAttr = {{ .id = CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_STREAM_SERIALIZATION, .value = 1}};
+       launchAttr[num_attrs] = pdlAttr;
+       ++num_attrs;
+     }}
+
+     if (launch_cooperative_grid != 0) {{
        CUlaunchAttribute coopAttr = {{ .id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE, .value = 1}};
-       launchAttr[0] = coopAttr;
-
-       CUlaunchConfig config;
-       config.gridDimX = gridX;
-       config.gridDimY = gridY;
-       config.gridDimZ = gridZ;
-       config.blockDimX = 32 * num_warps;
-       config.blockDimY = 1;
-       config.blockDimZ = 1;
-       config.sharedMemBytes = shared_memory;
-       config.hStream = stream;
-       config.attrs = launchAttr;
-       config.numAttrs = 1;
-
-       static cuLaunchKernelEx_t cuLaunchKernelExHandle = NULL;
-       if (cuLaunchKernelExHandle == NULL) {{
-         cuLaunchKernelExHandle = getLaunchKernelExHandle();
-       }}
-       CUDA_CHECK(cuLaunchKernelExHandle(&config, function, params, 0));
-
-     }} else {{
-       CUlaunchAttribute launchAttr[3];
-       launchAttr[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
-       launchAttr[0].value.clusterDim.x = clusterDimX;
-       launchAttr[0].value.clusterDim.y = clusterDimY;
-       launchAttr[0].value.clusterDim.z = clusterDimZ;
-       launchAttr[1].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE;
-       launchAttr[1].value.clusterSchedulingPolicyPreference = CU_CLUSTER_SCHEDULING_POLICY_SPREAD;
-
-       unsigned numAttrs = 2;
-       if (0 != launch_cooperative_grid) {{
-         CUlaunchAttribute coopAttr = {{ .id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE, .value = 1}};
-         launchAttr[2] = coopAttr;
-         numAttrs = 3;
-       }}
-
-       CUlaunchConfig config;
-       config.gridDimX = gridX * clusterDimX;
-       config.gridDimY = gridY * clusterDimY;
-       config.gridDimZ = gridZ * clusterDimZ;
-       config.blockDimX = 32 * num_warps;
-       config.blockDimY = 1;
-       config.blockDimZ = 1;
-       config.sharedMemBytes = shared_memory;
-       config.hStream = stream;
-       config.attrs = launchAttr;
-       config.numAttrs = numAttrs;
-       static cuLaunchKernelEx_t cuLaunchKernelExHandle = NULL;
-       if (cuLaunchKernelExHandle == NULL) {{
-         cuLaunchKernelExHandle = getLaunchKernelExHandle();
-       }}
-       CUDA_CHECK(cuLaunchKernelExHandle(&config, function, params, 0));
+       launchAttr[num_attrs] = coopAttr;
+       ++num_attrs;
      }}
+
+     if (num_ctas != 1) {{
+       CUlaunchAttribute clusterAttr = {{}};
+       clusterAttr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
+       clusterAttr.value.clusterDim.x = clusterDimX;
+       clusterAttr.value.clusterDim.y = clusterDimY;
+       clusterAttr.value.clusterDim.z = clusterDimZ;
+       launchAttr[num_attrs] = clusterAttr;
+       ++num_attrs;
+
+       CUlaunchAttribute clusterSchedulingAttr = {{}};
+       clusterSchedulingAttr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE;
+       clusterSchedulingAttr.value.clusterSchedulingPolicyPreference = CU_CLUSTER_SCHEDULING_POLICY_SPREAD;
+       launchAttr[num_attrs] = clusterSchedulingAttr;
+       ++num_attrs;
+     }}
+
+     config.numAttrs = num_attrs;
+
+     CUDA_CHECK(cuLaunchKernelExHandle(&config, function, params, 0));
    }}
  }}

@@ -457,6 +507,32 @@ static void ensureCudaContext() {{
    }}
  }}

+ static uint16_t pack_fp16(double f) {{
+   uint16_t result;
+   // from https://github.com/python/pythoncapi-compat
+ #if 0x030600B1 <= PY_VERSION_HEX && PY_VERSION_HEX <= 0x030B00A1 && !defined(PYPY_VERSION)
+   _PyFloat_Pack2(f, (unsigned char*)&result, 1);
+ #else
+   PyFloat_Pack2(f, (unsigned char*)&result, 1);
+ #endif
+   return result;
+ }}
+
+ static uint16_t pack_bf16(double f) {{
+   float f32 = (float)f;
+   uint32_t u32 = *(uint32_t*)&f32;
+   return (uint16_t)(u32 >> 16);
+ }}
+
+ static uint32_t pack_fp32(double f) {{
+   float f32 = (float)f;
+   return *(uint32_t*)&f32;
+ }}
+
+ static uint64_t pack_fp64(double f) {{
+   return *(uint64_t*)&f;
+ }}
+
  static PyObject* launch(PyObject* self, PyObject* args) {{
    // ensure cuda context is valid before calling any CUDA APIs, e.g. before getPointer calls cuPointerGetAttributes
    ensureCudaContext();
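A rough Python equivalent of the pack_fp32 / pack_bf16 helpers above (illustration only; the real packing is done by the generated C launcher):

    import struct

    def pack_fp32(f: float) -> int:
        # Reinterpret the float32 encoding as a uint32, like *(uint32_t*)&f32 above.
        return struct.unpack("<I", struct.pack("<f", f))[0]

    def pack_bf16(f: float) -> int:
        # bfloat16 keeps the top 16 bits of the float32 encoding.
        return pack_fp32(f) >> 16

    assert pack_fp32(1.0) == 0x3F800000
    assert pack_bf16(1.0) == 0x3F80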
@@ -465,6 +541,7 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
    uint64_t _stream;
    uint64_t _function;
    int launch_cooperative_grid;
+   int launch_pdl;
    PyObject *launch_enter_hook = NULL;
    PyObject *launch_exit_hook = NULL;
    PyObject *kernel_metadata = NULL;
@@ -472,7 +549,7 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
    PyObject *global_scratch_obj = NULL;
    {newline.join([f"{_extracted_type(ty)} _arg{i};" for i, ty in signature.items()])}
    if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ,
-                          &_stream, &_function, &launch_cooperative_grid, &global_scratch_obj,
+                          &_stream, &_function, &launch_cooperative_grid, &launch_pdl, &global_scratch_obj,
                           &kernel_metadata, &launch_metadata,
                           &launch_enter_hook, &launch_exit_hook{args_list})) {{
      return NULL;
@@ -506,8 +583,9 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
    // raise exception asap
    {newline.join(ptr_decls)}
    {newline.join(tma_decls)}
+   {newline.join(float_storage_decls)}
    Py_BEGIN_ALLOW_THREADS;
-   _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function, global_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
+   _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, launch_pdl, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function, global_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
    Py_END_ALLOW_THREADS;
    if (PyErr_Occurred()) {{
      return NULL;
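The launch_cooperative_grid and launch_pdl flags parsed above are forwarded into _launch, where they toggle the optional CUlaunchAttribute entries shown earlier in this file. A small sketch of that selection logic (attribute names only; the real code fills CUlaunchAttribute structs in the generated C):

    def launch_attributes(launch_pdl, launch_cooperative_grid, num_ctas):
        attrs = []
        if launch_pdl:
            attrs.append("PROGRAMMATIC_STREAM_SERIALIZATION")   # programmatic dependent launch
        if launch_cooperative_grid:
            attrs.append("COOPERATIVE")
        if num_ctas != 1:
            attrs.append("CLUSTER_DIMENSION")
            attrs.append("CLUSTER_SCHEDULING_POLICY_PREFERENCE")
        return attrs   # at most four entries, hence CUlaunchAttribute launchAttr[4]

    assert len(launch_attributes(True, True, 2)) == 4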
@@ -550,6 +628,87 @@ PyMODINIT_FUNC PyInit___triton_launcher(void) {{
      return src


+ class TmaDescKernelParam:
+     TMA_DESC_SIZE = 128
+
+     def __init__(self):
+         import torch
+         self.desc = torch.empty(self.TMA_DESC_SIZE, dtype=torch.uint8, device="cpu")
+
+     # Return a CUtensorMap* pointer in host memory
+     def tma_desc_cpu_ptr(self):
+         return self.desc.data_ptr()
+
+
+ # The TMA dtype enum values are slightly different on host vs device...
+ TMA_DTYPE_DEVICE_TO_HOST = dict((i, i) for i in range(16))
+ TMA_DTYPE_DEVICE_TO_HOST[8] = 10
+ TMA_DTYPE_DEVICE_TO_HOST[9] = 8
+ TMA_DTYPE_DEVICE_TO_HOST[10] = 9
+
+
+ def make_tensordesc_arg(arg, metadata):
+     if metadata is None:
+         # Currently the host side tensor descriptors get decomposed in
+         # the frontend to tensor desc, shape, and strides. We have no
+         # way to use these shape and strides when processing tensor
+         # descriptors which is why we provide our own decomposition
+         # above. Sadly this means we have to pass the shape and strides
+         # twice.
+         return [arg.base, *arg.shape, *arg.strides, *arg.shape, *arg.strides]
+
+     swizzle = metadata["swizzle"]
+     elem_size = metadata["elem_size"]
+     elem_type = metadata["elem_type"]
+     block_size = metadata["block_size"]
+     fp4_padded = metadata["fp4_padded"]
+
+     data_ptr = arg.base.data_ptr()
+     shape = arg.shape
+     strides = arg.strides
+     assert strides[-1] == 1
+
+     desc = TmaDescKernelParam()
+     result = [desc, *shape, *strides]
+
+     if fp4_padded:
+         shape = list(shape)
+         shape[-1] *= 2
+     triton.runtime.driver.active.utils.fill_tma_descriptor(
+         desc.tma_desc_cpu_ptr(),
+         data_ptr,
+         swizzle,
+         elem_size,
+         TMA_DTYPE_DEVICE_TO_HOST[elem_type],
+         block_size,
+         shape,
+         strides,
+     )
+     return result
+
+
+ def wrap_handle_tensordesc(launcher, tensordesc_meta):
+     from triton.tools.tensor_descriptor import TensorDescriptor
+     from triton.experimental.gluon.nvidia.hopper import TensorDescriptor as GluonTensorDescriptor
+
+     def inner(*args):
+         meta_args = args[:len(_BASE_ARGS_FORMAT)]
+         raw_kernel_args = args[len(_BASE_ARGS_FORMAT):]
+         tensordesc_idx = 0
+         final_args = []
+         for i, arg in enumerate(raw_kernel_args):
+             if isinstance(arg, (TensorDescriptor, GluonTensorDescriptor)):
+                 meta = tensordesc_meta[tensordesc_idx] if tensordesc_meta else None
+                 tensordesc_idx += 1
+                 final_args.extend(make_tensordesc_arg(arg, meta))
+             else:
+                 final_args.append(arg)
+         assert not tensordesc_meta or tensordesc_idx == len(tensordesc_meta)
+         return launcher(*meta_args, *final_args)
+
+     return inner
+
+
  class CudaLauncher(object):

      def __init__(self, src, metadata):
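wrap_handle_tensordesc above sits between the JIT'ed call site and the compiled launcher: every TensorDescriptor argument is replaced by the flat scalar list that _expand_signature promised the generated C code. A sketch of the host-side case (metadata is None), using a made-up 2D descriptor:

    class FakeDescriptor:   # stand-in for triton.tools.tensor_descriptor.TensorDescriptor
        def __init__(self, base, shape, strides):
            self.base, self.shape, self.strides = base, shape, strides

    arg = FakeDescriptor(base="<device tensor>", shape=[64, 64], strides=[64, 1])

    # make_tensordesc_arg(arg, None): base pointer, then shape and strides twice,
    # matching the "*dtype" + 2 * ndim * "i64" signature expansion.
    flat = [arg.base, *arg.shape, *arg.strides, *arg.shape, *arg.strides]
    assert len(flat) == 1 + 4 * len(arg.shape)   # 9 scalar launcher arguments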
@@ -557,21 +716,33 @@ class CudaLauncher(object):
          arg_idx = lambda x: (src.fn.arg_names.index(x), ) if isinstance(x, str) else x
          constants = {arg_idx(idx): value for idx, value in constants.items()}
          signature = {idx: value for idx, value in src.signature.items()}
-         src = make_launcher(constants, signature)
-         mod = compile_module_from_src(src, "__triton_launcher")
-         self.launch = mod.launch
+         tensordesc_meta = getattr(metadata, "tensordesc_meta", None)
+         src = make_launcher(constants, signature, tensordesc_meta)
+         mod = compile_module_from_src(
+             src=src,
+             name="__triton_launcher",
+             library_dirs=library_dirs(),
+             include_dirs=include_dirs,
+             libraries=libraries,
+         )
+         has_tensor_desc_arg = any(isinstance(sig, str) and sig.startswith("tensordesc") for sig in signature.values())
+
+         self.num_ctas = functools.reduce(operator.mul, metadata.cluster_dims, 1)
+         self.launch = wrap_handle_tensordesc(mod.launch, tensordesc_meta) if has_tensor_desc_arg else mod.launch
          self.global_scratch_size = metadata.global_scratch_size
          self.global_scratch_align = metadata.global_scratch_align
          self.launch_cooperative_grid = metadata.launch_cooperative_grid
+         self.launch_pdl = metadata.launch_pdl

      def __call__(self, gridX, gridY, gridZ, stream, function, *args):
          if self.global_scratch_size > 0:
              grid_size = gridX * gridY * gridZ
-             alloc_size = grid_size * self.global_scratch_size
+             alloc_size = grid_size * self.num_ctas * self.global_scratch_size
              global_scratch = _allocation._allocator(alloc_size, self.global_scratch_align, stream)
          else:
              global_scratch = None
-         self.launch(gridX, gridY, gridZ, stream, function, self.launch_cooperative_grid, global_scratch, *args)
+         self.launch(gridX, gridY, gridZ, stream, function, self.launch_cooperative_grid, self.launch_pdl,
+                     global_scratch, *args)


  class CudaDriver(GPUDriver):
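One behavioural consequence of the CudaLauncher changes above: global scratch is now sized per CTA, so kernels launched with non-trivial cluster dimensions get proportionally more scratch than in 3.3.x. A quick arithmetic sketch with made-up numbers:

    import functools, operator

    cluster_dims = (2, 1, 1)                 # example metadata.cluster_dims
    num_ctas = functools.reduce(operator.mul, cluster_dims, 1)    # 2

    gridX, gridY, gridZ = 8, 4, 1
    global_scratch_size = 256                # example bytes per program

    grid_size = gridX * gridY * gridZ        # 32
    alloc_size = grid_size * num_ctas * global_scratch_size       # 16384
    # 3.3.x allocated only grid_size * global_scratch_size = 8192 for the same launch.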