triton-windows 3.4.0.post20__cp310-cp310-win_amd64.whl → 3.5.0.post21__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic. Click here for more details.

Files changed (107) hide show
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +8 -2
  3. triton/_filecheck.py +24 -14
  4. triton/_internal_testing.py +70 -4
  5. triton/_utils.py +3 -1
  6. triton/backends/amd/compiler.py +68 -60
  7. triton/backends/amd/driver.c +113 -44
  8. triton/backends/amd/driver.py +133 -57
  9. triton/backends/driver.py +13 -0
  10. triton/backends/nvidia/compiler.py +80 -22
  11. triton/backends/nvidia/driver.c +88 -15
  12. triton/backends/nvidia/driver.py +130 -123
  13. triton/compiler/__init__.py +5 -2
  14. triton/compiler/code_generator.py +270 -163
  15. triton/compiler/compiler.py +45 -62
  16. triton/experimental/gluon/__init__.py +3 -2
  17. triton/experimental/gluon/_runtime.py +9 -6
  18. triton/experimental/gluon/language/__init__.py +117 -16
  19. triton/experimental/gluon/language/_core.py +246 -68
  20. triton/experimental/gluon/language/_layouts.py +398 -45
  21. triton/experimental/gluon/language/_math.py +17 -9
  22. triton/experimental/gluon/language/_semantic.py +130 -37
  23. triton/experimental/gluon/language/_standard.py +55 -22
  24. triton/experimental/gluon/language/amd/__init__.py +4 -0
  25. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  26. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  27. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  28. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  29. triton/experimental/gluon/language/extra/__init__.py +3 -0
  30. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  31. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  32. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  33. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +192 -7
  34. triton/experimental/gluon/language/nvidia/blackwell/tma.py +20 -0
  35. triton/experimental/gluon/language/nvidia/hopper/__init__.py +124 -3
  36. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +20 -37
  37. triton/experimental/gluon/language/nvidia/hopper/tma.py +4 -3
  38. triton/experimental/gluon/nvidia/hopper.py +6 -1
  39. triton/knobs.py +132 -67
  40. triton/language/__init__.py +16 -10
  41. triton/language/core.py +163 -83
  42. triton/language/extra/cuda/gdc.py +6 -6
  43. triton/language/extra/hip/__init__.py +3 -1
  44. triton/language/extra/hip/libdevice.py +7 -0
  45. triton/language/extra/hip/utils.py +35 -0
  46. triton/language/extra/libdevice.py +4 -0
  47. triton/language/semantic.py +76 -23
  48. triton/language/standard.py +14 -14
  49. triton/language/target_info.py +54 -0
  50. triton/runtime/_allocation.py +15 -3
  51. triton/runtime/_async_compile.py +55 -0
  52. triton/runtime/autotuner.py +4 -5
  53. triton/runtime/build.py +11 -9
  54. triton/runtime/cache.py +44 -1
  55. triton/runtime/driver.py +16 -41
  56. triton/runtime/interpreter.py +31 -23
  57. triton/runtime/jit.py +318 -157
  58. triton/runtime/tcc/include/_mingw.h +8 -10
  59. triton/runtime/tcc/include/assert.h +5 -0
  60. triton/runtime/tcc/include/errno.h +1 -1
  61. triton/runtime/tcc/include/float.h +21 -3
  62. triton/runtime/tcc/include/iso646.h +36 -0
  63. triton/runtime/tcc/include/limits.h +5 -0
  64. triton/runtime/tcc/include/malloc.h +2 -2
  65. triton/runtime/tcc/include/math.h +21 -261
  66. triton/runtime/tcc/include/stdalign.h +16 -0
  67. triton/runtime/tcc/include/stdarg.h +5 -70
  68. triton/runtime/tcc/include/stdatomic.h +171 -0
  69. triton/runtime/tcc/include/stddef.h +7 -19
  70. triton/runtime/tcc/include/stdlib.h +15 -4
  71. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  72. triton/runtime/tcc/include/sys/stat.h +2 -2
  73. triton/runtime/tcc/include/sys/types.h +5 -0
  74. triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
  75. triton/runtime/tcc/include/tccdefs.h +342 -0
  76. triton/runtime/tcc/include/tgmath.h +89 -0
  77. triton/runtime/tcc/include/uchar.h +33 -0
  78. triton/runtime/tcc/include/unistd.h +1 -0
  79. triton/runtime/tcc/include/winapi/qos.h +72 -0
  80. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  81. triton/runtime/tcc/include/winapi/winbase.h +9 -2
  82. triton/runtime/tcc/include/winapi/wincon.h +8 -0
  83. triton/runtime/tcc/include/winapi/windows.h +1 -1
  84. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  85. triton/runtime/tcc/include/winapi/winnt.h +9 -7
  86. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  87. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  88. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  89. triton/runtime/tcc/lib/libtcc1.a +0 -0
  90. triton/runtime/tcc/lib/python314.def +1800 -0
  91. triton/runtime/tcc/lib/python314t.def +1809 -0
  92. triton/runtime/tcc/libtcc.dll +0 -0
  93. triton/runtime/tcc/tcc.exe +0 -0
  94. triton/tools/compile.py +62 -14
  95. triton/tools/extra/cuda/compile.c +1 -0
  96. triton/tools/extra/hip/compile.cpp +66 -0
  97. triton/tools/extra/hip/compile.h +13 -0
  98. triton/tools/ragged_tma.py +92 -0
  99. triton/tools/tensor_descriptor.py +7 -9
  100. triton/windows_utils.py +42 -79
  101. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +3 -4
  102. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/RECORD +106 -75
  103. triton/runtime/tcc/lib/libtcc1-64.a +0 -0
  104. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
  105. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/entry_points.txt +0 -0
  106. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/licenses/LICENSE +0 -0
  107. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/top_level.txt +0 -0
@@ -19,16 +19,19 @@ if os.name == "nt":
19
19
  include_dirs += cuda_inc_dirs
20
20
  libdevice_dir = os.path.join(dirname, "lib")
21
21
  libraries = ['cuda']
22
+ PyCUtensorMap = None
22
23
 
23
24
 
24
25
  @functools.lru_cache()
25
26
  def libcuda_dirs():
26
27
  if env_libcuda_path := knobs.nvidia.libcuda_path:
27
28
  return [env_libcuda_path]
29
+
28
30
  if os.name == "nt":
29
31
  _, _, cuda_lib_dirs = find_cuda()
30
32
  return cuda_lib_dirs
31
- libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode()
33
+
34
+ libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode(errors="ignore")
32
35
  # each line looks like the following:
33
36
  # libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1
34
37
  locs = [line.split()[-1] for line in libs.splitlines() if "libcuda.so.1" in line]
@@ -72,6 +75,8 @@ class CudaUtils(object):
72
75
  include_dirs=include_dirs,
73
76
  libraries=libraries,
74
77
  )
78
+ global PyCUtensorMap
79
+ PyCUtensorMap = mod.PyCUtensorMap
75
80
  self.load_binary = mod.load_binary
76
81
  self.get_device_properties = mod.get_device_properties
77
82
  self.cuOccupancyMaxActiveClusters = mod.cuOccupancyMaxActiveClusters
@@ -90,12 +95,12 @@ def ty_to_cpp(ty):
90
95
  if ty.startswith("tensordesc"):
91
96
  return "CUtensorMap"
92
97
  return {
93
- "i1": "int32_t",
98
+ "i1": "int8_t",
94
99
  "i8": "int8_t",
95
100
  "i16": "int16_t",
96
101
  "i32": "int32_t",
97
102
  "i64": "int64_t",
98
- "u1": "uint32_t",
103
+ "u1": "uint8_t",
99
104
  "u8": "uint8_t",
100
105
  "u16": "uint16_t",
101
106
  "u32": "uint32_t",
@@ -124,7 +129,8 @@ FLOAT_PACK_FUNCTION = {
124
129
  "fp64": "pack_fp64",
125
130
  }
126
131
 
127
- _BASE_ARGS_FORMAT = "iiiKKppOOOOO"
132
+ _BASE_ARGS_FORMAT = "iiiKKppOOOOOO"
133
+ _BASE_ARGS_FORMAT_LEN = len(_BASE_ARGS_FORMAT)
128
134
 
129
135
 
130
136
  def make_launcher(constants, signature, tensordesc_meta):
@@ -154,6 +160,7 @@ def make_launcher(constants, signature, tensordesc_meta):
154
160
  # we have to pass the shape and strides twice.
155
161
  for _ in range(2 * ndim):
156
162
  output.append("i64")
163
+ output.append("i1")
157
164
  else:
158
165
  output.append("nvTmaDesc")
159
166
 
@@ -261,11 +268,10 @@ def make_launcher(constants, signature, tensordesc_meta):
261
268
  ]
262
269
  params = [f"&arg{i}" for i, ty in signature.items() if ty != "constexpr"]
263
270
  params.append("&global_scratch")
271
+ params.append("&profile_scratch")
264
272
  src = f"""
265
273
  #define _CRT_SECURE_NO_WARNINGS
266
274
  #include \"cuda.h\"
267
- #include <stdbool.h>
268
- #include <Python.h>
269
275
 
270
276
  #ifndef _WIN32
271
277
  #include <dlfcn.h>
@@ -274,6 +280,16 @@ def make_launcher(constants, signature, tensordesc_meta):
274
280
  #include <windows.h>
275
281
  #endif
276
282
 
283
+ #include <stdbool.h>
284
+ #include <stdlib.h>
285
+ #define PY_SSIZE_T_CLEAN
286
+ #include <Python.h>
287
+
288
+ typedef struct {{
289
+ PyObject_HEAD
290
+ _Alignas(128) CUtensorMap tensorMap;
291
+ }} PyCUtensorMapObject;
292
+
277
293
  static inline void gpuAssert(CUresult code, const char *file, int line)
278
294
  {{
279
295
  if (code != CUDA_SUCCESS)
@@ -334,10 +350,10 @@ static cuLaunchKernelEx_t getLaunchKernelExHandle() {{
334
350
  }}
335
351
  #endif
336
352
 
337
- static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int launch_pdl, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
353
+ static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int launch_pdl, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratch, CUdeviceptr profile_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
338
354
  void *params[] = {{ {', '.join(params)} }};
339
355
  if (gridX*gridY*gridZ > 0) {{
340
- // 4 attributes that we can currently pass maxmimum
356
+ // 4 attributes that we can currently pass maximum
341
357
  CUlaunchAttribute launchAttr[4];
342
358
  static cuLaunchKernelEx_t cuLaunchKernelExHandle = NULL;
343
359
  if (cuLaunchKernelExHandle == NULL) {{
@@ -401,6 +417,9 @@ typedef struct _DevicePtrInfo {{
401
417
  bool valid;
402
418
  }} DevicePtrInfo;
403
419
 
420
+ static PyObject* data_ptr_str = NULL;
421
+ static PyObject* py_tensor_map_type = NULL;
422
+
404
423
  static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
405
424
  DevicePtrInfo ptr_info;
406
425
  ptr_info.dev_ptr = 0;
@@ -413,40 +432,35 @@ static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
413
432
  // valid nullptr
414
433
  return ptr_info;
415
434
  }}
416
- PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr");
417
- if(ptr){{
418
- PyObject *empty_tuple = PyTuple_New(0);
419
- PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL);
420
- Py_DECREF(empty_tuple);
421
- Py_DECREF(ptr);
422
- if (!PyLong_Check(ret)) {{
423
- PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
424
- ptr_info.valid = false;
425
- Py_DECREF(ret);
426
- return ptr_info;
427
- }}
428
- ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(ret);
429
- if(!ptr_info.dev_ptr) {{
430
- Py_DECREF(ret);
431
- return ptr_info;
432
- }}
433
- uint64_t dev_ptr;
434
- int status = cuPointerGetAttribute(&dev_ptr, CU_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
435
- if (status == CUDA_ERROR_INVALID_VALUE) {{
436
- PyErr_Format(PyExc_ValueError,
437
- "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
438
- ptr_info.valid = false;
439
- }} else if (status != CUDA_SUCCESS) {{
440
- CUDA_CHECK(status); // Catch any other cuda API errors
441
- ptr_info.valid = false;
442
- }}
443
- ptr_info.dev_ptr = dev_ptr;
444
- Py_DECREF(ret);
435
+ PyObject *ret = PyObject_CallMethodNoArgs(obj, data_ptr_str);
436
+ if (!ret) {{
437
+ PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
438
+ ptr_info.valid = false;
439
+ goto cleanup;
440
+ }}
441
+ if (!PyLong_Check(ret)) {{
442
+ PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
443
+ ptr_info.valid = false;
444
+ goto cleanup;
445
+ }}
446
+ ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(ret);
447
+ if(!ptr_info.dev_ptr)
445
448
  return ptr_info;
449
+ uint64_t dev_ptr;
450
+ int status = cuPointerGetAttribute(&dev_ptr, CU_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
451
+ if (status == CUDA_ERROR_INVALID_VALUE) {{
452
+ PyErr_Format(PyExc_ValueError,
453
+ "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
454
+ ptr_info.valid = false;
455
+ }} else if (status != CUDA_SUCCESS) {{
456
+ CUDA_CHECK(status); // Catch any other cuda API errors
457
+ ptr_info.valid = false;
446
458
  }}
447
- PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
448
- ptr_info.valid = false;
459
+ ptr_info.dev_ptr = dev_ptr;
460
+ cleanup:
461
+ Py_XDECREF(ret);
449
462
  return ptr_info;
463
+
450
464
  }}
451
465
 
452
466
  static inline CUtensorMap* getTmaDesc(PyObject *obj) {{
@@ -455,44 +469,18 @@ static inline CUtensorMap* getTmaDesc(PyObject *obj) {{
455
469
  return NULL;
456
470
  }}
457
471
 
458
- PyObject *method_handle = PyObject_GetAttrString(obj, "tma_desc_cpu_ptr");
459
- if (!method_handle) {{
460
- PyErr_SetString(PyExc_TypeError, "tma_desc_cpu_ptr() method does not exist");
472
+ if (Py_TYPE(obj) != (PyTypeObject*)py_tensor_map_type) {{
473
+ PyErr_Format(PyExc_TypeError, "object must be of type PyCUtensorMap, got %s", Py_TYPE(obj)->tp_name);
461
474
  return NULL;
462
- }}
463
-
464
- PyObject *empty_tuple = PyTuple_New(0);
465
- if (!empty_tuple) {{
466
- Py_DECREF(method_handle);
467
- PyErr_SetString(PyExc_SystemError, "Internal Python error!");
468
- return NULL;
469
- }}
470
- PyObject *method_ret = PyObject_Call(method_handle, empty_tuple, NULL);
471
- Py_DECREF(empty_tuple);
472
- Py_DECREF(method_handle);
473
- if (!method_ret) {{
474
- PyErr_SetString(PyExc_SystemError, "Internal Python error!");
475
- return NULL;
476
- }}
477
-
478
- if (!PyLong_Check(method_ret)) {{
479
- PyErr_SetString(PyExc_TypeError, "tma_desc_cpu_ptr() must return 64-bit int");
480
- Py_DECREF(method_ret);
481
- return NULL;
482
- }}
475
+ }}
483
476
 
484
- uint64_t ptr_as_uint = PyLong_AsUnsignedLongLong(method_ret);
485
- Py_DECREF(method_ret);
486
- if (!ptr_as_uint) {{
487
- PyErr_SetString(PyExc_ValueError, "received NULL ptr from tma_desc_cpu_ptr()");
477
+ CUtensorMap* map = &((PyCUtensorMapObject*)obj)->tensorMap;
478
+ uintptr_t align_128 = (uintptr_t)map & (128 - 1);
479
+ if (align_128 != 0) {{
480
+ PyErr_Format(PyExc_ValueError, "CUtensorMap must be aligned to 128B, but got (&map) mod 128 = %ld", align_128);
488
481
  return NULL;
489
482
  }}
490
- if (ptr_as_uint % 64 != 0) {{
491
- PyErr_SetString(PyExc_ValueError, "tma_desc_cpu_ptr() must be 64-byte aligned");
492
- return NULL;
493
- }}
494
-
495
- return (CUtensorMap*)(ptr_as_uint);
483
+ return map;
496
484
  }}
497
485
 
498
486
  static void ensureCudaContext() {{
@@ -547,9 +535,10 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
547
535
  PyObject *kernel_metadata = NULL;
548
536
  PyObject *launch_metadata = NULL;
549
537
  PyObject *global_scratch_obj = NULL;
538
+ PyObject *profile_scratch_obj = NULL;
550
539
  {newline.join([f"{_extracted_type(ty)} _arg{i};" for i, ty in signature.items()])}
551
540
  if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ,
552
- &_stream, &_function, &launch_cooperative_grid, &launch_pdl, &global_scratch_obj,
541
+ &_stream, &_function, &launch_cooperative_grid, &launch_pdl, &global_scratch_obj, &profile_scratch_obj,
553
542
  &kernel_metadata, &launch_metadata,
554
543
  &launch_enter_hook, &launch_exit_hook{args_list})) {{
555
544
  return NULL;
@@ -563,9 +552,7 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
563
552
 
564
553
  // extract launch metadata
565
554
  if (launch_enter_hook != Py_None){{
566
- PyObject* args = Py_BuildValue("(O)", launch_metadata);
567
- PyObject* ret = PyObject_CallObject(launch_enter_hook, args);
568
- Py_DECREF(args);
555
+ PyObject* ret = PyObject_CallOneArg(launch_enter_hook, launch_metadata);
569
556
  if (!ret)
570
557
  return NULL;
571
558
  Py_DECREF(ret);
@@ -580,21 +567,28 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
580
567
  global_scratch = global_scratch_info.dev_ptr;
581
568
  }}
582
569
 
570
+ CUdeviceptr profile_scratch = 0;
571
+ if (profile_scratch_obj != Py_None) {{
572
+ DevicePtrInfo profile_scratch_info = getPointer(profile_scratch_obj, -1);
573
+ if (!profile_scratch_info.valid) {{
574
+ return NULL;
575
+ }}
576
+ profile_scratch = profile_scratch_info.dev_ptr;
577
+ }}
578
+
583
579
  // raise exception asap
584
580
  {newline.join(ptr_decls)}
585
581
  {newline.join(tma_decls)}
586
582
  {newline.join(float_storage_decls)}
587
583
  Py_BEGIN_ALLOW_THREADS;
588
- _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, launch_pdl, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function, global_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
584
+ _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, launch_pdl, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function, global_scratch, profile_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
589
585
  Py_END_ALLOW_THREADS;
590
586
  if (PyErr_Occurred()) {{
591
587
  return NULL;
592
588
  }}
593
589
 
594
590
  if(launch_exit_hook != Py_None){{
595
- PyObject* args = Py_BuildValue("(O)", launch_metadata);
596
- PyObject* ret = PyObject_CallObject(launch_exit_hook, args);
597
- Py_DECREF(args);
591
+ PyObject* ret = PyObject_CallOneArg(launch_exit_hook, launch_metadata);
598
592
  if (!ret)
599
593
  return NULL;
600
594
  Py_DECREF(ret);
@@ -617,6 +611,19 @@ static struct PyModuleDef ModuleDef = {{
617
611
  }};
618
612
 
619
613
  PyMODINIT_FUNC PyInit___triton_launcher(void) {{
614
+ data_ptr_str = PyUnicode_InternFromString("data_ptr");
615
+ if(data_ptr_str == NULL) {{
616
+ return NULL;
617
+ }}
618
+ PyObject* driver_mod = PyImport_ImportModule("triton.backends.nvidia.driver");
619
+ if (driver_mod == NULL) {{
620
+ return NULL;
621
+ }}
622
+ py_tensor_map_type = PyObject_GetAttrString(driver_mod, "PyCUtensorMap");
623
+ if (py_tensor_map_type == NULL) {{
624
+ return NULL;
625
+ }}
626
+
620
627
  PyObject *m = PyModule_Create(&ModuleDef);
621
628
  if(m == NULL) {{
622
629
  return NULL;
@@ -628,18 +635,6 @@ PyMODINIT_FUNC PyInit___triton_launcher(void) {{
628
635
  return src
629
636
 
630
637
 
631
- class TmaDescKernelParam:
632
- TMA_DESC_SIZE = 128
633
-
634
- def __init__(self):
635
- import torch
636
- self.desc = torch.empty(self.TMA_DESC_SIZE, dtype=torch.uint8, device="cpu")
637
-
638
- # Return a CUtensorMap* pointer in host memory
639
- def tma_desc_cpu_ptr(self):
640
- return self.desc.data_ptr()
641
-
642
-
643
638
  # The TMA dtype enum values are slightly different on host vs device...
644
639
  TMA_DTYPE_DEVICE_TO_HOST = dict((i, i) for i in range(16))
645
640
  TMA_DTYPE_DEVICE_TO_HOST[8] = 10
@@ -655,7 +650,7 @@ def make_tensordesc_arg(arg, metadata):
655
650
  # descriptors which is why we provide our own decomposition
656
651
  # above. Sadly this means we have to pass the shape and strides
657
652
  # twice.
658
- return [arg.base, *arg.shape, *arg.strides, *arg.shape, *arg.strides]
653
+ return [arg.base, *arg.shape, *arg.strides, arg.padding == "nan", *arg.shape, *arg.strides]
659
654
 
660
655
  swizzle = metadata["swizzle"]
661
656
  elem_size = metadata["elem_size"]
@@ -663,48 +658,50 @@ def make_tensordesc_arg(arg, metadata):
663
658
  block_size = metadata["block_size"]
664
659
  fp4_padded = metadata["fp4_padded"]
665
660
 
666
- data_ptr = arg.base.data_ptr()
667
661
  shape = arg.shape
668
662
  strides = arg.strides
669
663
  assert strides[-1] == 1
670
-
671
- desc = TmaDescKernelParam()
672
- result = [desc, *shape, *strides]
664
+ padding = 1 if arg.padding == "nan" else 0
673
665
 
674
666
  if fp4_padded:
675
667
  shape = list(shape)
676
668
  shape[-1] *= 2
677
- triton.runtime.driver.active.utils.fill_tma_descriptor(
678
- desc.tma_desc_cpu_ptr(),
679
- data_ptr,
669
+
670
+ cu_tensor_map = triton.runtime.driver.active.utils.fill_tma_descriptor(
671
+ arg.base.data_ptr(),
680
672
  swizzle,
681
673
  elem_size,
682
674
  TMA_DTYPE_DEVICE_TO_HOST[elem_type],
683
675
  block_size,
684
676
  shape,
685
677
  strides,
678
+ padding,
686
679
  )
687
- return result
680
+
681
+ return [cu_tensor_map, *shape, *strides]
688
682
 
689
683
 
690
- def wrap_handle_tensordesc(launcher, tensordesc_meta):
691
- from triton.tools.tensor_descriptor import TensorDescriptor
692
- from triton.experimental.gluon.nvidia.hopper import TensorDescriptor as GluonTensorDescriptor
684
+ def wrap_handle_tensordesc(launcher, signature, tensordesc_meta):
685
+ has_tensor_desc_arg = any(isinstance(sig, str) and sig.startswith("tensordesc") for sig in signature.values())
686
+ if not has_tensor_desc_arg:
687
+ return launcher
688
+
689
+ tensordesc_indices = set(
690
+ [i for i, sig in enumerate(signature.values()) if isinstance(sig, str) and sig.startswith("tensordesc")])
691
+ assert not tensordesc_meta or len(tensordesc_meta) == len(tensordesc_indices)
692
+ if not tensordesc_meta:
693
+ tensordesc_meta = [None] * len(tensordesc_indices)
693
694
 
694
695
  def inner(*args):
695
- meta_args = args[:len(_BASE_ARGS_FORMAT)]
696
- raw_kernel_args = args[len(_BASE_ARGS_FORMAT):]
696
+ final_args = list(args[:_BASE_ARGS_FORMAT_LEN])
697
697
  tensordesc_idx = 0
698
- final_args = []
699
- for i, arg in enumerate(raw_kernel_args):
700
- if isinstance(arg, (TensorDescriptor, GluonTensorDescriptor)):
701
- meta = tensordesc_meta[tensordesc_idx] if tensordesc_meta else None
698
+ for i, arg in enumerate(args[_BASE_ARGS_FORMAT_LEN:]):
699
+ if i in tensordesc_indices:
700
+ final_args.extend(make_tensordesc_arg(arg, tensordesc_meta[tensordesc_idx]))
702
701
  tensordesc_idx += 1
703
- final_args.extend(make_tensordesc_arg(arg, meta))
704
702
  else:
705
703
  final_args.append(arg)
706
- assert not tensordesc_meta or tensordesc_idx == len(tensordesc_meta)
707
- return launcher(*meta_args, *final_args)
704
+ return launcher(*final_args)
708
705
 
709
706
  return inner
710
707
 
@@ -725,24 +722,31 @@ class CudaLauncher(object):
725
722
  include_dirs=include_dirs,
726
723
  libraries=libraries,
727
724
  )
728
- has_tensor_desc_arg = any(isinstance(sig, str) and sig.startswith("tensordesc") for sig in signature.values())
729
725
 
730
726
  self.num_ctas = functools.reduce(operator.mul, metadata.cluster_dims, 1)
731
- self.launch = wrap_handle_tensordesc(mod.launch, tensordesc_meta) if has_tensor_desc_arg else mod.launch
727
+ self.launch = wrap_handle_tensordesc(mod.launch, signature, tensordesc_meta)
732
728
  self.global_scratch_size = metadata.global_scratch_size
733
729
  self.global_scratch_align = metadata.global_scratch_align
730
+ self.profile_scratch_size = metadata.profile_scratch_size
731
+ self.profile_scratch_align = metadata.profile_scratch_align
734
732
  self.launch_cooperative_grid = metadata.launch_cooperative_grid
735
733
  self.launch_pdl = metadata.launch_pdl
736
734
 
737
735
  def __call__(self, gridX, gridY, gridZ, stream, function, *args):
738
- if self.global_scratch_size > 0:
739
- grid_size = gridX * gridY * gridZ
740
- alloc_size = grid_size * self.num_ctas * self.global_scratch_size
741
- global_scratch = _allocation._allocator(alloc_size, self.global_scratch_align, stream)
742
- else:
743
- global_scratch = None
736
+
737
+ def allocate_scratch(size, align, allocator):
738
+ if size > 0:
739
+ grid_size = gridX * gridY * gridZ
740
+ alloc_size = grid_size * self.num_ctas * size
741
+ alloc_fn = allocator.get()
742
+ return alloc_fn(alloc_size, align, stream)
743
+ return None
744
+
745
+ global_scratch = allocate_scratch(self.global_scratch_size, self.global_scratch_align, _allocation._allocator)
746
+ profile_scratch = allocate_scratch(self.profile_scratch_size, self.profile_scratch_align,
747
+ _allocation._profile_allocator)
744
748
  self.launch(gridX, gridY, gridZ, stream, function, self.launch_cooperative_grid, self.launch_pdl,
745
- global_scratch, *args)
749
+ global_scratch, profile_scratch, *args)
746
750
 
747
751
 
748
752
  class CudaDriver(GPUDriver):
@@ -775,6 +779,9 @@ class CudaDriver(GPUDriver):
775
779
  except ImportError:
776
780
  return False
777
781
 
782
+ def map_python_to_cpp_type(self, ty: str) -> str:
783
+ return ty_to_cpp(ty)
784
+
778
785
  def get_benchmarker(self):
779
786
  from triton.testing import do_bench
780
787
  return do_bench
@@ -1,4 +1,7 @@
1
- from .compiler import CompiledKernel, ASTSource, IRSource, compile, make_backend, LazyDict
1
+ from .compiler import CompiledKernel, ASTSource, IRSource, compile, make_backend, LazyDict, get_cache_key
2
2
  from .errors import CompilationError
3
3
 
4
- __all__ = ["compile", "make_backend", "ASTSource", "IRSource", "CompiledKernel", "CompilationError", "LazyDict"]
4
+ __all__ = [
5
+ "compile", "make_backend", "ASTSource", "IRSource", "CompiledKernel", "CompilationError", "LazyDict",
6
+ "get_cache_key"
7
+ ]