triton-windows 3.3.1.post19__cp313-cp313-win_amd64.whl → 3.5.0.post21__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic; see the registry's release details for more information.

Files changed (225)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +11 -2
  3. triton/_filecheck.py +97 -0
  4. triton/_internal_testing.py +95 -18
  5. triton/_utils.py +112 -21
  6. triton/backends/__init__.py +20 -23
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +161 -119
  9. triton/backends/amd/driver.c +118 -46
  10. triton/backends/amd/driver.py +274 -96
  11. triton/backends/compiler.py +7 -21
  12. triton/backends/driver.py +13 -0
  13. triton/backends/nvidia/bin/ptxas.exe +0 -0
  14. triton/backends/nvidia/compiler.py +163 -106
  15. triton/backends/nvidia/driver.c +166 -101
  16. triton/backends/nvidia/driver.py +384 -202
  17. triton/compiler/__init__.py +5 -2
  18. triton/compiler/code_generator.py +439 -231
  19. triton/compiler/compiler.py +152 -84
  20. triton/experimental/__init__.py +0 -0
  21. triton/experimental/gluon/__init__.py +5 -0
  22. triton/experimental/gluon/_compiler.py +0 -0
  23. triton/experimental/gluon/_runtime.py +102 -0
  24. triton/experimental/gluon/language/__init__.py +119 -0
  25. triton/experimental/gluon/language/_core.py +490 -0
  26. triton/experimental/gluon/language/_layouts.py +583 -0
  27. triton/experimental/gluon/language/_math.py +20 -0
  28. triton/experimental/gluon/language/_semantic.py +380 -0
  29. triton/experimental/gluon/language/_standard.py +80 -0
  30. triton/experimental/gluon/language/amd/__init__.py +4 -0
  31. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  32. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  33. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  34. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  35. triton/experimental/gluon/language/extra/__init__.py +3 -0
  36. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  37. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  38. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  39. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  40. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
  41. triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
  42. triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
  43. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
  44. triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
  45. triton/experimental/gluon/nvidia/__init__.py +4 -0
  46. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  47. triton/experimental/gluon/nvidia/hopper.py +45 -0
  48. triton/knobs.py +546 -0
  49. triton/language/__init__.py +50 -19
  50. triton/language/core.py +909 -572
  51. triton/language/extra/cuda/__init__.py +10 -7
  52. triton/language/extra/cuda/gdc.py +42 -0
  53. triton/language/extra/cuda/libdevice.py +394 -394
  54. triton/language/extra/cuda/utils.py +21 -21
  55. triton/language/extra/hip/__init__.py +3 -1
  56. triton/language/extra/hip/libdevice.py +120 -104
  57. triton/language/extra/hip/utils.py +35 -0
  58. triton/language/extra/libdevice.py +4 -0
  59. triton/language/math.py +65 -66
  60. triton/language/random.py +12 -2
  61. triton/language/semantic.py +1757 -1768
  62. triton/language/standard.py +127 -62
  63. triton/language/target_info.py +54 -0
  64. triton/runtime/_allocation.py +15 -3
  65. triton/runtime/_async_compile.py +55 -0
  66. triton/runtime/autotuner.py +117 -60
  67. triton/runtime/build.py +83 -17
  68. triton/runtime/cache.py +61 -47
  69. triton/runtime/driver.py +25 -47
  70. triton/runtime/interpreter.py +95 -50
  71. triton/runtime/jit.py +445 -248
  72. triton/runtime/tcc/include/_mingw.h +8 -10
  73. triton/runtime/tcc/include/assert.h +5 -0
  74. triton/runtime/tcc/include/errno.h +1 -1
  75. triton/runtime/tcc/include/float.h +21 -3
  76. triton/runtime/tcc/include/iso646.h +36 -0
  77. triton/runtime/tcc/include/limits.h +5 -0
  78. triton/runtime/tcc/include/malloc.h +2 -2
  79. triton/runtime/tcc/include/math.h +21 -261
  80. triton/runtime/tcc/include/stdalign.h +16 -0
  81. triton/runtime/tcc/include/stdarg.h +5 -70
  82. triton/runtime/tcc/include/stdatomic.h +171 -0
  83. triton/runtime/tcc/include/stddef.h +7 -19
  84. triton/runtime/tcc/include/stdlib.h +15 -4
  85. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  86. triton/runtime/tcc/include/sys/stat.h +2 -2
  87. triton/runtime/tcc/include/sys/types.h +5 -0
  88. triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
  89. triton/runtime/tcc/include/tccdefs.h +342 -0
  90. triton/runtime/tcc/include/tgmath.h +89 -0
  91. triton/runtime/tcc/include/uchar.h +33 -0
  92. triton/runtime/tcc/include/unistd.h +1 -0
  93. triton/runtime/tcc/include/winapi/qos.h +72 -0
  94. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  95. triton/runtime/tcc/include/winapi/winbase.h +9 -2
  96. triton/runtime/tcc/include/winapi/wincon.h +8 -0
  97. triton/runtime/tcc/include/winapi/windows.h +1 -1
  98. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  99. triton/runtime/tcc/include/winapi/winnt.h +9 -7
  100. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  101. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  102. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  103. triton/runtime/tcc/lib/libtcc1.a +0 -0
  104. triton/runtime/tcc/lib/python314.def +1800 -0
  105. triton/runtime/tcc/lib/python314t.def +1809 -0
  106. triton/runtime/tcc/libtcc.dll +0 -0
  107. triton/runtime/tcc/tcc.exe +0 -0
  108. triton/testing.py +16 -12
  109. triton/tools/compile.py +62 -14
  110. triton/tools/disasm.py +3 -4
  111. triton/tools/extra/cuda/compile.c +1 -0
  112. triton/tools/extra/hip/compile.cpp +66 -0
  113. triton/tools/extra/hip/compile.h +13 -0
  114. triton/tools/ragged_tma.py +92 -0
  115. triton/tools/tensor_descriptor.py +34 -0
  116. triton/windows_utils.py +52 -81
  117. {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +8 -4
  118. triton_windows-3.5.0.post21.dist-info/RECORD +217 -0
  119. triton_windows-3.5.0.post21.dist-info/entry_points.txt +3 -0
  120. triton_windows-3.5.0.post21.dist-info/licenses/LICENSE +23 -0
  121. triton_windows-3.5.0.post21.dist-info/top_level.txt +1 -0
  122. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
  123. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
  124. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
  125. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
  126. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
  127. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
  128. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
  129. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
  130. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
  131. triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
  132. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
  133. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
  134. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
  135. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
  136. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
  137. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
  138. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
  139. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
  140. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
  141. triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
  142. triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
  143. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
  144. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
  145. triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
  146. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
  147. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
  148. triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
  149. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
  150. triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
  151. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
  152. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
  153. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
  154. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
  155. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
  156. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
  157. triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
  158. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
  159. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
  160. triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
  161. triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
  162. triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
  163. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
  164. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
  165. triton/backends/amd/include/hip/channel_descriptor.h +0 -39
  166. triton/backends/amd/include/hip/device_functions.h +0 -38
  167. triton/backends/amd/include/hip/driver_types.h +0 -468
  168. triton/backends/amd/include/hip/hip_bf16.h +0 -36
  169. triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
  170. triton/backends/amd/include/hip/hip_common.h +0 -100
  171. triton/backends/amd/include/hip/hip_complex.h +0 -38
  172. triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
  173. triton/backends/amd/include/hip/hip_deprecated.h +0 -95
  174. triton/backends/amd/include/hip/hip_ext.h +0 -161
  175. triton/backends/amd/include/hip/hip_fp16.h +0 -36
  176. triton/backends/amd/include/hip/hip_fp8.h +0 -33
  177. triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
  178. triton/backends/amd/include/hip/hip_hcc.h +0 -24
  179. triton/backends/amd/include/hip/hip_math_constants.h +0 -36
  180. triton/backends/amd/include/hip/hip_profile.h +0 -27
  181. triton/backends/amd/include/hip/hip_runtime.h +0 -75
  182. triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
  183. triton/backends/amd/include/hip/hip_texture_types.h +0 -29
  184. triton/backends/amd/include/hip/hip_vector_types.h +0 -41
  185. triton/backends/amd/include/hip/hip_version.h +0 -17
  186. triton/backends/amd/include/hip/hiprtc.h +0 -421
  187. triton/backends/amd/include/hip/library_types.h +0 -78
  188. triton/backends/amd/include/hip/math_functions.h +0 -42
  189. triton/backends/amd/include/hip/surface_types.h +0 -63
  190. triton/backends/amd/include/hip/texture_types.h +0 -194
  191. triton/backends/amd/include/hsa/Brig.h +0 -1131
  192. triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
  193. triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
  194. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
  195. triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
  196. triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
  197. triton/backends/amd/include/hsa/hsa.h +0 -5738
  198. triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
  199. triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
  200. triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
  201. triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
  202. triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
  203. triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
  204. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
  205. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
  206. triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
  207. triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
  208. triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
  209. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
  210. triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
  211. triton/backends/amd/include/roctracer/roctracer.h +0 -779
  212. triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
  213. triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
  214. triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
  215. triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
  216. triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
  217. triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
  218. triton/backends/amd/include/roctracer/roctx.h +0 -229
  219. triton/language/_utils.py +0 -21
  220. triton/language/extra/cuda/_experimental_tma.py +0 -106
  221. triton/runtime/tcc/lib/libtcc1-64.a +0 -0
  222. triton/tools/experimental_descriptor.py +0 -32
  223. triton_windows-3.3.1.post19.dist-info/RECORD +0 -260
  224. triton_windows-3.3.1.post19.dist-info/top_level.txt +0 -14
  225. {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
@@ -9,10 +9,15 @@
9
9
  #endif
10
10
 
11
11
  #include <stdbool.h>
12
+ #include <stdlib.h>
12
13
  #define PY_SSIZE_T_CLEAN
13
- #define Py_LIMITED_API 0x03090000
14
14
  #include <Python.h>
15
15
 
16
+ typedef struct {
17
+ PyObject_HEAD
18
+ _Alignas(128) CUtensorMap tensorMap;
19
+ } PyCUtensorMapObject;
20
+
16
21
  // Raises a Python exception and returns false if code is not CUDA_SUCCESS.
17
22
  static bool gpuAssert(CUresult code, const char *file, int line) {
18
23
  if (code == CUDA_SUCCESS)
@@ -35,7 +40,7 @@ static bool gpuAssert(CUresult code, const char *file, int line) {
35
40
  #define CUDA_CHECK_AND_RETURN_NULL(ans) \
36
41
  do { \
37
42
  if (!gpuAssert((ans), __FILE__, __LINE__)) \
38
- return NULL; \
43
+ goto cleanup; \
39
44
  } while (0)
40
45
 
41
46
  // To be used inside a Py_{BEGIN,END}_ALLOW_THREADS block.
@@ -53,7 +58,7 @@ static bool gpuAssert(CUresult code, const char *file, int line) {
53
58
  if ((funcPointer) == NULL) { \
54
59
  (funcPointer) = (initializerFunction)(); \
55
60
  if ((funcPointer) == NULL) { \
56
- return NULL; \
61
+ goto cleanup; \
57
62
  } \
58
63
  } \
59
64
  } while (0)
@@ -96,6 +101,9 @@ static PyObject *getDeviceProperties(PyObject *self, PyObject *args) {
96
101
  warp_size, "sm_clock_rate", sm_clock_rate,
97
102
  "mem_clock_rate", mem_clock_rate, "mem_bus_width",
98
103
  mem_bus_width);
104
+
105
+ cleanup:
106
+ return NULL;
99
107
  }
100
108
 
101
109
  static PyObject *loadBinary(PyObject *self, PyObject *args) {
@@ -112,6 +120,7 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) {
112
120
  CUmodule mod;
113
121
  int32_t n_regs = 0;
114
122
  int32_t n_spills = 0;
123
+ int32_t n_max_threads = 0;
115
124
  // create driver handles
116
125
  CUcontext pctx = 0;
117
126
 
@@ -132,6 +141,8 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) {
132
141
  CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
133
142
  cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun));
134
143
  n_spills /= 4;
144
+ CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncGetAttribute(
145
+ &n_max_threads, CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK, fun));
135
146
  // set dynamic shared memory if necessary
136
147
  int shared_optin;
137
148
  CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuDeviceGetAttribute(
@@ -155,8 +166,8 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) {
155
166
  if (PyErr_Occurred()) {
156
167
  return NULL;
157
168
  }
158
- return Py_BuildValue("(KKii)", (uint64_t)mod, (uint64_t)fun, n_regs,
159
- n_spills);
169
+ return Py_BuildValue("(KKiii)", (uint64_t)mod, (uint64_t)fun, n_regs,
170
+ n_spills, n_max_threads);
160
171
  }
161
172
 
162
173
  typedef CUresult (*cuOccupancyMaxActiveClusters_t)(
@@ -266,6 +277,9 @@ static PyObject *occupancyMaxActiveClusters(PyObject *self, PyObject *args) {
266
277
  cuOccupancyMaxActiveClusters(&maxActiveClusters, func, &config));
267
278
  Py_END_ALLOW_THREADS;
268
279
  return PyLong_FromLong(maxActiveClusters);
280
+
281
+ cleanup:
282
+ return NULL;
269
283
  }
270
284
 
271
285
  static PyObject *setPrintfFifoSize(PyObject *self, PyObject *args) {
@@ -304,116 +318,162 @@ static PyObject *setPrintfFifoSize(PyObject *self, PyObject *args) {
304
318
  }
305
319
 
306
320
  Py_END_ALLOW_THREADS;
307
- Py_INCREF(Py_None);
308
- return Py_None;
321
+ Py_RETURN_NONE;
309
322
  }
310
323
 
311
- // Simple helper to experiment creating TMA descriptors on the host.
312
- // This is a useful to test TMA operations independently.
313
- static PyObject *fill1DTMADescriptor(PyObject *self, PyObject *args) {
314
- unsigned long long global_address;
315
- uint64_t dim;
316
- uint32_t tensorDim;
317
- int elementSize;
318
- unsigned long long desc_address;
319
- if (!PyArg_ParseTuple(args, "KKiiK", &global_address, &dim, &tensorDim,
320
- &elementSize, &desc_address)) {
321
- return NULL;
322
- }
323
- uint64_t dims[1] = {dim};
324
- uint64_t globalStrides[1] = {dim * elementSize};
325
- uint32_t boxDim[1] = {tensorDim};
326
- uint32_t elementStrides[1] = {1};
327
- CUtensorMapDataType type;
328
- switch (elementSize) {
329
- case 1:
330
- type = CU_TENSOR_MAP_DATA_TYPE_UINT8;
331
- break;
332
- case 2:
333
- type = CU_TENSOR_MAP_DATA_TYPE_UINT16;
334
- break;
335
- case 4:
336
- type = CU_TENSOR_MAP_DATA_TYPE_UINT32;
337
- break;
338
- default:
339
- PyErr_SetString(PyExc_ValueError, "elementSize must be 1, 2, or 4");
324
+ static PyObject *PyCUtensorMap_alloc(PyTypeObject *type, Py_ssize_t n_items) {
325
+ PyCUtensorMapObject *self = NULL;
326
+ void *mem = NULL;
327
+ size_t size = type->tp_basicsize;
328
+
329
+ #ifdef _WIN32
330
+ mem = _aligned_malloc(size, 128);
331
+ if (mem == NULL) {
332
+ #else
333
+ if (posix_memalign(&mem, 128, size) != 0) {
334
+ #endif
335
+ PyErr_NoMemory();
340
336
  return NULL;
341
337
  }
342
- assert((elementSize * tensorDim) >= 32 && "block size too small.");
343
- int rank = 1;
344
- static cuTensorMapEncodeTiled_t cuTensorMapEncodeTiled = NULL;
345
- INITIALIZE_FUNCTION_POINTER_IF_NULL(cuTensorMapEncodeTiled,
346
- getCuTensorMapEncodeTiledHandle);
347
- CUDA_CHECK_AND_RETURN_NULL(cuTensorMapEncodeTiled(
348
- (CUtensorMap *)desc_address, type, rank, (void *)global_address, dims,
349
- globalStrides, boxDim, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE,
350
- CU_TENSOR_MAP_SWIZZLE_NONE, CU_TENSOR_MAP_L2_PROMOTION_NONE,
351
- CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
352
- Py_INCREF(Py_None);
353
- return Py_None;
338
+
339
+ self = (PyCUtensorMapObject *)mem;
340
+ PyObject_INIT(self, type);
341
+ return (PyObject *)self;
342
+ }
343
+
344
+ static void PyCUtensorMap_dealloc(PyObject *self) {
345
+ Py_TYPE(self)->tp_free(self);
346
+ }
347
+
348
+ static void PyCUtensorMap_free(void *ptr) {
349
+ #ifdef _WIN32
350
+ _aligned_free(ptr);
351
+ #else
352
+ free(ptr);
353
+ #endif
354
354
  }
355
355
 
356
- // Simple helper to experiment creating TMA descriptors on the host.
357
- // This is a useful to test TMA operations independently.
358
- static PyObject *fill2DTMADescriptor(PyObject *self, PyObject *args) {
356
+ // clang-format off
357
+ static PyTypeObject PyCUtensorMapType = {
358
+ PyVarObject_HEAD_INIT(NULL, 0)
359
+ .tp_name = "triton.backends.nvidia.PyCUtensorMap",
360
+ .tp_basicsize = sizeof(PyCUtensorMapObject),
361
+ .tp_itemsize = 0,
362
+ .tp_flags = Py_TPFLAGS_DEFAULT,
363
+ .tp_doc = "<PyCUtensorMap object>",
364
+ .tp_new = PyType_GenericNew,
365
+ .tp_alloc = PyCUtensorMap_alloc,
366
+ .tp_dealloc = (destructor)PyCUtensorMap_dealloc,
367
+ .tp_free = PyCUtensorMap_free,
368
+ };
369
+ // clang-format on
370
+
371
+ static PyObject *fillTMADescriptor(PyObject *self, PyObject *args) {
359
372
  unsigned long long global_address;
360
- uint64_t dims[2];
361
- uint32_t tensorDims[2];
362
- int elementSize;
363
- unsigned long long desc_address;
364
- if (!PyArg_ParseTuple(args, "KKKiiiK", &global_address, &dims[1], &dims[0],
365
- &tensorDims[1], &tensorDims[0], &elementSize,
366
- &desc_address)) {
373
+ int swizzle;
374
+ int elemSize;
375
+ int elemType;
376
+ PyObject *blockSize;
377
+ PyObject *shape;
378
+ PyObject *strides;
379
+ int padding;
380
+
381
+ if (!PyArg_ParseTuple(args, "KiiiOOOi", &global_address, &swizzle, &elemSize,
382
+ &elemType, &blockSize, &shape, &strides, &padding)) {
367
383
  return NULL;
368
384
  }
369
- uint64_t globalStrides[2] = {dims[0] * elementSize,
370
- dims[0] * dims[1] * elementSize};
371
- uint32_t elementStrides[2] = {1, 1};
372
- CUtensorMapDataType type;
373
- switch (elementSize) {
374
- case 1:
375
- type = CU_TENSOR_MAP_DATA_TYPE_UINT8;
376
- break;
377
- case 2:
378
- type = CU_TENSOR_MAP_DATA_TYPE_UINT16;
379
- break;
380
- case 4:
381
- type = CU_TENSOR_MAP_DATA_TYPE_UINT32;
382
- break;
383
- default:
384
- PyErr_SetString(PyExc_ValueError, "elementSize must be 1, 2, or 4");
385
+
386
+ PyCUtensorMapObject *desc = (PyCUtensorMapObject *)PyObject_CallObject(
387
+ (PyObject *)&PyCUtensorMapType, NULL);
388
+ if (!desc) {
389
+ return NULL;
390
+ }
391
+
392
+ PyObject *blockSizeFast = NULL;
393
+ PyObject *shapeFast = NULL;
394
+ PyObject *stridesFast = NULL;
395
+
396
+ uint32_t blockSizeInt[5];
397
+ uint64_t shapeInt[5];
398
+ uint64_t stridesLL[5];
399
+
400
+ blockSizeFast = PySequence_Fast(blockSize, "blockSize must be a sequence");
401
+ if (!blockSizeFast)
402
+ goto cleanup;
403
+ int rank = PySequence_Fast_GET_SIZE(blockSizeFast);
404
+
405
+ for (int i = 0; i < rank; ++i) {
406
+ PyObject *item = PySequence_Fast_GET_ITEM(blockSizeFast, i);
407
+ if (!PyLong_Check(item)) {
408
+ PyErr_SetString(PyExc_TypeError, "block size must be an int");
409
+ goto cleanup;
410
+ }
411
+ blockSizeInt[rank - i - 1] = PyLong_AsLongLong(item);
412
+ }
413
+
414
+ shapeFast = PySequence_Fast(shape, "shape must be a sequence");
415
+ if (!shapeFast)
416
+ goto cleanup;
417
+
418
+ if (rank != PySequence_Fast_GET_SIZE(shapeFast)) {
419
+ PyErr_SetString(PyExc_RuntimeError, "Rank mismatch");
420
+ goto cleanup;
385
421
  }
386
- int rank = 2;
387
- // Swizzling should be picked in codegen but since we need to set it on the
388
- // descriptor we rely on a convention between this function and codegen.
389
- CUtensorMapSwizzle swizzle = CU_TENSOR_MAP_SWIZZLE_128B;
390
- uint32_t contigDimSizeInByte = elementSize * tensorDims[0];
391
- if (contigDimSizeInByte >= 128) {
392
- swizzle = CU_TENSOR_MAP_SWIZZLE_128B;
393
- } else if (contigDimSizeInByte >= 64) {
394
- swizzle = CU_TENSOR_MAP_SWIZZLE_64B;
395
- } else if (contigDimSizeInByte >= 32) {
396
- swizzle = CU_TENSOR_MAP_SWIZZLE_32B;
397
- } else {
398
- assert(false && "block size too small.");
422
+ for (int i = 0; i < rank; ++i) {
423
+ PyObject *item = PySequence_Fast_GET_ITEM(shapeFast, i);
424
+ if (!PyLong_Check(item)) {
425
+ PyErr_SetString(PyExc_TypeError, "shape must be an int");
426
+ goto cleanup;
427
+ }
428
+ shapeInt[rank - i - 1] = PyLong_AsLong(item);
399
429
  }
400
- // The bounding box inner dimension must be less than or equal to the swizzle
401
- // size.
402
- // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7
403
- // We clamp the block size and the codegen will emit multiple copy operations.
404
- if (contigDimSizeInByte > 128) {
405
- tensorDims[0] = 128 / elementSize;
430
+
431
+ stridesFast = PySequence_Fast(strides, "strides must be a sequence");
432
+ if (!stridesFast)
433
+ goto cleanup;
434
+
435
+ if (rank != PySequence_Fast_GET_SIZE(stridesFast)) {
436
+ PyErr_SetString(PyExc_RuntimeError, "Rank mismatch");
437
+ goto cleanup;
406
438
  }
439
+ for (int i = 0; i + 1 < rank; ++i) {
440
+ PyObject *item = PySequence_Fast_GET_ITEM(stridesFast, i);
441
+ if (!PyLong_Check(item)) {
442
+ PyErr_SetString(PyExc_TypeError, "shape must be an int");
443
+ goto cleanup;
444
+ }
445
+ stridesLL[rank - i - 2] = elemSize * PyLong_AsLongLong(item);
446
+ }
447
+ stridesLL[rank - 1] =
448
+ shapeInt[rank - 1] * (rank == 1 ? elemSize : stridesLL[rank - 2]);
449
+ Py_DECREF(blockSizeFast);
450
+ blockSizeFast = NULL;
451
+ Py_DECREF(shapeFast);
452
+ shapeFast = NULL;
453
+ Py_DECREF(stridesFast);
454
+ stridesFast = NULL;
455
+
456
+ CUtensorMapFloatOOBfill fill =
457
+ (padding == 1) ? CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA
458
+ : CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE;
459
+
460
+ uint32_t elementStrides[5] = {1, 1, 1, 1, 1};
407
461
  static cuTensorMapEncodeTiled_t cuTensorMapEncodeTiled = NULL;
408
462
  INITIALIZE_FUNCTION_POINTER_IF_NULL(cuTensorMapEncodeTiled,
409
463
  getCuTensorMapEncodeTiledHandle);
410
464
  CUDA_CHECK_AND_RETURN_NULL(cuTensorMapEncodeTiled(
411
- (CUtensorMap *)desc_address, type, rank, (void *)global_address, dims,
412
- globalStrides, tensorDims, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE,
413
- swizzle, CU_TENSOR_MAP_L2_PROMOTION_L2_128B,
414
- CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
415
- Py_INCREF(Py_None);
416
- return Py_None;
465
+ &desc->tensorMap, elemType, rank, (void *)global_address, shapeInt,
466
+ stridesLL, blockSizeInt, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE,
467
+ swizzle, CU_TENSOR_MAP_L2_PROMOTION_L2_128B, fill));
468
+
469
+ return (PyObject *)desc;
470
+
471
+ cleanup:
472
+ Py_XDECREF(blockSizeFast);
473
+ Py_XDECREF(shapeFast);
474
+ Py_XDECREF(stridesFast);
475
+ Py_XDECREF(desc);
476
+ return NULL;
417
477
  }
418
478
 
419
479
  static PyMethodDef ModuleMethods[] = {
@@ -429,8 +489,7 @@ static PyMethodDef ModuleMethods[] = {
429
489
  "being dropped. This inherits all the limitations of this call; in "
430
490
  "particular it's an error to change this value after launching any kernel "
431
491
  "that calls printf()."},
432
- {"fill_1d_tma_descriptor", fill1DTMADescriptor, METH_VARARGS, "doc"},
433
- {"fill_2d_tma_descriptor", fill2DTMADescriptor, METH_VARARGS, "doc"},
492
+ {"fill_tma_descriptor", fillTMADescriptor, METH_VARARGS, "doc"},
434
493
 
435
494
  {NULL, NULL, 0, NULL} // sentinel
436
495
  };
@@ -441,12 +500,18 @@ static struct PyModuleDef ModuleDef = {PyModuleDef_HEAD_INIT, "cuda_utils",
441
500
  ModuleMethods};
442
501
 
443
502
  PyMODINIT_FUNC PyInit_cuda_utils(void) {
503
+ if (PyType_Ready(&PyCUtensorMapType) < 0) {
504
+ return NULL;
505
+ }
506
+
444
507
  PyObject *m = PyModule_Create(&ModuleDef);
445
508
  if (m == NULL) {
446
509
  return NULL;
447
510
  }
448
511
 
449
512
  PyModule_AddFunctions(m, ModuleMethods);
513
+ Py_INCREF(&PyCUtensorMapType);
514
+ PyModule_AddObject(m, "PyCUtensorMap", (PyObject *)&PyCUtensorMapType);
450
515
 
451
516
  return m;
452
517
  }