triton-windows 3.5.1.post21__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (217)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +82 -0
  3. triton/_filecheck.py +97 -0
  4. triton/_internal_testing.py +255 -0
  5. triton/_utils.py +126 -0
  6. triton/backends/__init__.py +47 -0
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +461 -0
  9. triton/backends/amd/driver.c +283 -0
  10. triton/backends/amd/driver.py +724 -0
  11. triton/backends/amd/lib/asanrtl.bc +0 -0
  12. triton/backends/amd/lib/ockl.bc +0 -0
  13. triton/backends/amd/lib/ocml.bc +0 -0
  14. triton/backends/compiler.py +90 -0
  15. triton/backends/driver.py +66 -0
  16. triton/backends/nvidia/__init__.py +0 -0
  17. triton/backends/nvidia/bin/ptxas.exe +0 -0
  18. triton/backends/nvidia/compiler.py +533 -0
  19. triton/backends/nvidia/driver.c +517 -0
  20. triton/backends/nvidia/driver.py +799 -0
  21. triton/backends/nvidia/include/cuda.h +26280 -0
  22. triton/backends/nvidia/lib/libdevice.10.bc +0 -0
  23. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  24. triton/compiler/__init__.py +7 -0
  25. triton/compiler/code_generator.py +1614 -0
  26. triton/compiler/compiler.py +509 -0
  27. triton/compiler/errors.py +51 -0
  28. triton/compiler/make_launcher.py +0 -0
  29. triton/errors.py +5 -0
  30. triton/experimental/__init__.py +0 -0
  31. triton/experimental/gluon/__init__.py +5 -0
  32. triton/experimental/gluon/_compiler.py +0 -0
  33. triton/experimental/gluon/_runtime.py +102 -0
  34. triton/experimental/gluon/language/__init__.py +119 -0
  35. triton/experimental/gluon/language/_core.py +490 -0
  36. triton/experimental/gluon/language/_layouts.py +583 -0
  37. triton/experimental/gluon/language/_math.py +20 -0
  38. triton/experimental/gluon/language/_semantic.py +380 -0
  39. triton/experimental/gluon/language/_standard.py +80 -0
  40. triton/experimental/gluon/language/amd/__init__.py +4 -0
  41. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  42. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  43. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  44. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  45. triton/experimental/gluon/language/extra/__init__.py +3 -0
  46. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  47. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  48. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  49. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  50. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
  51. triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
  52. triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
  53. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
  54. triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
  55. triton/experimental/gluon/nvidia/__init__.py +4 -0
  56. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  57. triton/experimental/gluon/nvidia/hopper.py +45 -0
  58. triton/knobs.py +546 -0
  59. triton/language/__init__.py +342 -0
  60. triton/language/core.py +3405 -0
  61. triton/language/extra/__init__.py +26 -0
  62. triton/language/extra/cuda/__init__.py +16 -0
  63. triton/language/extra/cuda/gdc.py +42 -0
  64. triton/language/extra/cuda/libdevice.py +1629 -0
  65. triton/language/extra/cuda/utils.py +109 -0
  66. triton/language/extra/hip/__init__.py +5 -0
  67. triton/language/extra/hip/libdevice.py +491 -0
  68. triton/language/extra/hip/utils.py +35 -0
  69. triton/language/extra/libdevice.py +790 -0
  70. triton/language/math.py +249 -0
  71. triton/language/random.py +218 -0
  72. triton/language/semantic.py +1939 -0
  73. triton/language/standard.py +534 -0
  74. triton/language/target_info.py +54 -0
  75. triton/runtime/__init__.py +23 -0
  76. triton/runtime/_allocation.py +44 -0
  77. triton/runtime/_async_compile.py +55 -0
  78. triton/runtime/autotuner.py +476 -0
  79. triton/runtime/build.py +168 -0
  80. triton/runtime/cache.py +317 -0
  81. triton/runtime/driver.py +38 -0
  82. triton/runtime/errors.py +36 -0
  83. triton/runtime/interpreter.py +1414 -0
  84. triton/runtime/jit.py +1107 -0
  85. triton/runtime/tcc/include/_mingw.h +168 -0
  86. triton/runtime/tcc/include/assert.h +62 -0
  87. triton/runtime/tcc/include/conio.h +409 -0
  88. triton/runtime/tcc/include/ctype.h +281 -0
  89. triton/runtime/tcc/include/dir.h +31 -0
  90. triton/runtime/tcc/include/direct.h +68 -0
  91. triton/runtime/tcc/include/dirent.h +135 -0
  92. triton/runtime/tcc/include/dos.h +55 -0
  93. triton/runtime/tcc/include/errno.h +75 -0
  94. triton/runtime/tcc/include/excpt.h +123 -0
  95. triton/runtime/tcc/include/fcntl.h +52 -0
  96. triton/runtime/tcc/include/fenv.h +108 -0
  97. triton/runtime/tcc/include/float.h +75 -0
  98. triton/runtime/tcc/include/inttypes.h +297 -0
  99. triton/runtime/tcc/include/io.h +418 -0
  100. triton/runtime/tcc/include/iso646.h +36 -0
  101. triton/runtime/tcc/include/limits.h +116 -0
  102. triton/runtime/tcc/include/locale.h +91 -0
  103. triton/runtime/tcc/include/malloc.h +181 -0
  104. triton/runtime/tcc/include/math.h +497 -0
  105. triton/runtime/tcc/include/mem.h +13 -0
  106. triton/runtime/tcc/include/memory.h +40 -0
  107. triton/runtime/tcc/include/process.h +176 -0
  108. triton/runtime/tcc/include/sec_api/conio_s.h +42 -0
  109. triton/runtime/tcc/include/sec_api/crtdbg_s.h +19 -0
  110. triton/runtime/tcc/include/sec_api/io_s.h +33 -0
  111. triton/runtime/tcc/include/sec_api/mbstring_s.h +52 -0
  112. triton/runtime/tcc/include/sec_api/search_s.h +25 -0
  113. triton/runtime/tcc/include/sec_api/stdio_s.h +145 -0
  114. triton/runtime/tcc/include/sec_api/stdlib_s.h +67 -0
  115. triton/runtime/tcc/include/sec_api/stralign_s.h +30 -0
  116. triton/runtime/tcc/include/sec_api/string_s.h +41 -0
  117. triton/runtime/tcc/include/sec_api/sys/timeb_s.h +34 -0
  118. triton/runtime/tcc/include/sec_api/tchar_s.h +266 -0
  119. triton/runtime/tcc/include/sec_api/time_s.h +61 -0
  120. triton/runtime/tcc/include/sec_api/wchar_s.h +128 -0
  121. triton/runtime/tcc/include/setjmp.h +160 -0
  122. triton/runtime/tcc/include/share.h +28 -0
  123. triton/runtime/tcc/include/signal.h +63 -0
  124. triton/runtime/tcc/include/stdalign.h +16 -0
  125. triton/runtime/tcc/include/stdarg.h +14 -0
  126. triton/runtime/tcc/include/stdatomic.h +171 -0
  127. triton/runtime/tcc/include/stdbool.h +11 -0
  128. triton/runtime/tcc/include/stddef.h +42 -0
  129. triton/runtime/tcc/include/stdint.h +212 -0
  130. triton/runtime/tcc/include/stdio.h +429 -0
  131. triton/runtime/tcc/include/stdlib.h +591 -0
  132. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  133. triton/runtime/tcc/include/string.h +164 -0
  134. triton/runtime/tcc/include/sys/fcntl.h +13 -0
  135. triton/runtime/tcc/include/sys/file.h +14 -0
  136. triton/runtime/tcc/include/sys/locking.h +30 -0
  137. triton/runtime/tcc/include/sys/stat.h +290 -0
  138. triton/runtime/tcc/include/sys/time.h +69 -0
  139. triton/runtime/tcc/include/sys/timeb.h +133 -0
  140. triton/runtime/tcc/include/sys/types.h +123 -0
  141. triton/runtime/tcc/include/sys/unistd.h +14 -0
  142. triton/runtime/tcc/include/sys/utime.h +146 -0
  143. triton/runtime/tcc/include/tcc/tcc_libm.h +618 -0
  144. triton/runtime/tcc/include/tccdefs.h +342 -0
  145. triton/runtime/tcc/include/tcclib.h +80 -0
  146. triton/runtime/tcc/include/tchar.h +1102 -0
  147. triton/runtime/tcc/include/tgmath.h +89 -0
  148. triton/runtime/tcc/include/time.h +287 -0
  149. triton/runtime/tcc/include/uchar.h +33 -0
  150. triton/runtime/tcc/include/unistd.h +1 -0
  151. triton/runtime/tcc/include/vadefs.h +11 -0
  152. triton/runtime/tcc/include/values.h +4 -0
  153. triton/runtime/tcc/include/varargs.h +12 -0
  154. triton/runtime/tcc/include/wchar.h +873 -0
  155. triton/runtime/tcc/include/wctype.h +172 -0
  156. triton/runtime/tcc/include/winapi/basetsd.h +149 -0
  157. triton/runtime/tcc/include/winapi/basetyps.h +85 -0
  158. triton/runtime/tcc/include/winapi/guiddef.h +156 -0
  159. triton/runtime/tcc/include/winapi/poppack.h +8 -0
  160. triton/runtime/tcc/include/winapi/pshpack1.h +8 -0
  161. triton/runtime/tcc/include/winapi/pshpack2.h +8 -0
  162. triton/runtime/tcc/include/winapi/pshpack4.h +8 -0
  163. triton/runtime/tcc/include/winapi/pshpack8.h +8 -0
  164. triton/runtime/tcc/include/winapi/qos.h +72 -0
  165. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  166. triton/runtime/tcc/include/winapi/winbase.h +2958 -0
  167. triton/runtime/tcc/include/winapi/wincon.h +309 -0
  168. triton/runtime/tcc/include/winapi/windef.h +293 -0
  169. triton/runtime/tcc/include/winapi/windows.h +127 -0
  170. triton/runtime/tcc/include/winapi/winerror.h +3166 -0
  171. triton/runtime/tcc/include/winapi/wingdi.h +4080 -0
  172. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  173. triton/runtime/tcc/include/winapi/winnt.h +5837 -0
  174. triton/runtime/tcc/include/winapi/winreg.h +272 -0
  175. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  176. triton/runtime/tcc/include/winapi/winuser.h +5651 -0
  177. triton/runtime/tcc/include/winapi/winver.h +160 -0
  178. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  179. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  180. triton/runtime/tcc/lib/cuda.def +697 -0
  181. triton/runtime/tcc/lib/gdi32.def +337 -0
  182. triton/runtime/tcc/lib/kernel32.def +770 -0
  183. triton/runtime/tcc/lib/libtcc1.a +0 -0
  184. triton/runtime/tcc/lib/msvcrt.def +1399 -0
  185. triton/runtime/tcc/lib/python3.def +810 -0
  186. triton/runtime/tcc/lib/python310.def +1610 -0
  187. triton/runtime/tcc/lib/python311.def +1633 -0
  188. triton/runtime/tcc/lib/python312.def +1703 -0
  189. triton/runtime/tcc/lib/python313.def +1651 -0
  190. triton/runtime/tcc/lib/python313t.def +1656 -0
  191. triton/runtime/tcc/lib/python314.def +1800 -0
  192. triton/runtime/tcc/lib/python314t.def +1809 -0
  193. triton/runtime/tcc/lib/python39.def +1644 -0
  194. triton/runtime/tcc/lib/python3t.def +905 -0
  195. triton/runtime/tcc/lib/user32.def +658 -0
  196. triton/runtime/tcc/libtcc.dll +0 -0
  197. triton/runtime/tcc/tcc.exe +0 -0
  198. triton/testing.py +543 -0
  199. triton/tools/__init__.py +0 -0
  200. triton/tools/build_extern.py +365 -0
  201. triton/tools/compile.py +210 -0
  202. triton/tools/disasm.py +143 -0
  203. triton/tools/extra/cuda/compile.c +70 -0
  204. triton/tools/extra/cuda/compile.h +14 -0
  205. triton/tools/extra/hip/compile.cpp +66 -0
  206. triton/tools/extra/hip/compile.h +13 -0
  207. triton/tools/link.py +322 -0
  208. triton/tools/mxfp.py +301 -0
  209. triton/tools/ragged_tma.py +92 -0
  210. triton/tools/tensor_descriptor.py +34 -0
  211. triton/windows_utils.py +405 -0
  212. triton_windows-3.5.1.post21.dist-info/METADATA +46 -0
  213. triton_windows-3.5.1.post21.dist-info/RECORD +217 -0
  214. triton_windows-3.5.1.post21.dist-info/WHEEL +5 -0
  215. triton_windows-3.5.1.post21.dist-info/entry_points.txt +3 -0
  216. triton_windows-3.5.1.post21.dist-info/licenses/LICENSE +23 -0
  217. triton_windows-3.5.1.post21.dist-info/top_level.txt +1 -0
triton/backends/nvidia/driver.c (new file, +517 lines):
#define _CRT_SECURE_NO_WARNINGS
#include "cuda.h"

#ifndef _WIN32
#include <dlfcn.h>
#else
#define WIN32_LEAN_AND_MEAN
#include <windows.h>
#endif

#include <stdbool.h>
#include <stdlib.h>
#define PY_SSIZE_T_CLEAN
#include <Python.h>

typedef struct {
  PyObject_HEAD
  _Alignas(128) CUtensorMap tensorMap;
} PyCUtensorMapObject;

// Raises a Python exception and returns false if code is not CUDA_SUCCESS.
static bool gpuAssert(CUresult code, const char *file, int line) {
  if (code == CUDA_SUCCESS)
    return true;

  const char *prefix = "Triton Error [CUDA]: ";
  const char *str = NULL;
  cuGetErrorString(code, &str);
  char err[1024] = {0};
  strcat(err, prefix);
  strcat(err, str ? str : "unknown error");
  PyGILState_STATE gil_state;
  gil_state = PyGILState_Ensure();
  PyErr_SetString(PyExc_RuntimeError, err);
  PyGILState_Release(gil_state);
  return false;
}

// To be used only *outside* a Py_{BEGIN,END}_ALLOW_THREADS block.
#define CUDA_CHECK_AND_RETURN_NULL(ans)                                        \
  do {                                                                         \
    if (!gpuAssert((ans), __FILE__, __LINE__))                                 \
      goto cleanup;                                                            \
  } while (0)

// To be used inside a Py_{BEGIN,END}_ALLOW_THREADS block.
#define CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(ans)                          \
  do {                                                                         \
    if (!gpuAssert((ans), __FILE__, __LINE__)) {                               \
      PyEval_RestoreThread(_save);                                             \
      return NULL;                                                             \
    }                                                                          \
  } while (0)

// Used to lazily resolve functions that may not exist in old CUDA driver
// versions.
#define INITIALIZE_FUNCTION_POINTER_IF_NULL(funcPointer, initializerFunction)  \
  do {                                                                         \
    if ((funcPointer) == NULL) {                                               \
      (funcPointer) = (initializerFunction)();                                 \
      if ((funcPointer) == NULL) {                                             \
        goto cleanup;                                                          \
      }                                                                        \
    }                                                                          \
  } while (0)
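Two details of these macros are easy to miss: CUDA_CHECK_AND_RETURN_NULL exits via goto cleanup, so every caller must define a cleanup: label, and the ALLOW_THREADS variant references _save, the hidden thread-state variable declared by Py_BEGIN_ALLOW_THREADS, so it only compiles between Py_BEGIN_ALLOW_THREADS and Py_END_ALLOW_THREADS. A minimal sketch of the calling contract (the function below is illustrative, not part of driver.c):

// Illustrative only: the calling contract for CUDA_CHECK_AND_RETURN_NULL.
// On failure, gpuAssert raises the Python exception and the macro jumps to
// the function's cleanup: label, which returns NULL to signal the error.
static PyObject *exampleWarpSize(PyObject *self, PyObject *args) {
  int warp_size;
  CUdevice device;
  CUDA_CHECK_AND_RETURN_NULL(cuDeviceGet(&device, 0));
  CUDA_CHECK_AND_RETURN_NULL(
      cuDeviceGetAttribute(&warp_size, CU_DEVICE_ATTRIBUTE_WARP_SIZE, device));
  return PyLong_FromLong(warp_size);

cleanup:
  return NULL;
}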

static PyObject *getDeviceProperties(PyObject *self, PyObject *args) {
  int device_id;
  if (!PyArg_ParseTuple(args, "i", &device_id))
    return NULL;
  // Get device handle
  CUdevice device;
  CUDA_CHECK_AND_RETURN_NULL(cuDeviceGet(&device, device_id));

  // Query the device attributes we expose to Python.
  int max_shared_mem;
  int max_num_regs;
  int multiprocessor_count;
  int warp_size;
  int sm_clock_rate;
  int mem_clock_rate;
  int mem_bus_width;
  CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
      &max_shared_mem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
      device));
  CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
      &max_num_regs, CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK, device));
  CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
      &multiprocessor_count, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device));
  CUDA_CHECK_AND_RETURN_NULL(
      cuDeviceGetAttribute(&warp_size, CU_DEVICE_ATTRIBUTE_WARP_SIZE, device));
  CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
      &sm_clock_rate, CU_DEVICE_ATTRIBUTE_CLOCK_RATE, device));
  CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
      &mem_clock_rate, CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE, device));
  CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
      &mem_bus_width, CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH, device));

  return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i, s:i, s:i}", "max_shared_mem",
                       max_shared_mem, "max_num_regs", max_num_regs,
                       "multiprocessor_count", multiprocessor_count, "warpSize",
                       warp_size, "sm_clock_rate", sm_clock_rate,
                       "mem_clock_rate", mem_clock_rate, "mem_bus_width",
                       mem_bus_width);

cleanup:
  return NULL;
}

static PyObject *loadBinary(PyObject *self, PyObject *args) {
  const char *name;
  const char *data;
  Py_ssize_t data_size;
  int shared;
  int device;
  if (!PyArg_ParseTuple(args, "ss#ii", &name, &data, &data_size, &shared,
                        &device)) {
    return NULL;
  }
  CUfunction fun;
  CUmodule mod;
  int32_t n_regs = 0;
  int32_t n_spills = 0;
  int32_t n_max_threads = 0;
  // create driver handles
  CUcontext pctx = 0;

  Py_BEGIN_ALLOW_THREADS;
  CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxGetCurrent(&pctx));
  if (!pctx) {
    CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
        cuDevicePrimaryCtxRetain(&pctx, device));
    CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxSetCurrent(pctx));
  }

  CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuModuleLoadData(&mod, data));
  CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
      cuModuleGetFunction(&fun, mod, name));
  // get allocated registers and spilled registers from the function
  CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
      cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun));
  CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
      cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun));
  n_spills /= 4;
  CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncGetAttribute(
      &n_max_threads, CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK, fun));
  // set dynamic shared memory if necessary
  int shared_optin;
  CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuDeviceGetAttribute(
      &shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
      device));
  if (shared > 49152 && shared_optin > 49152) {
    CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
        cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED));
    int shared_total, shared_static;
    CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuDeviceGetAttribute(
        &shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR,
        device));
    CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncGetAttribute(
        &shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun));
    CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
        cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
                           shared_optin - shared_static));
  }
  Py_END_ALLOW_THREADS;

  if (PyErr_Occurred()) {
    return NULL;
  }
  return Py_BuildValue("(KKiii)", (uint64_t)mod, (uint64_t)fun, n_regs,
                       n_spills, n_max_threads);
}
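loadBinary hands the CUmodule and CUfunction back to Python as plain unsigned 64-bit integers (the "K" codes in Py_BuildValue), so the opaque handles survive the round trip through Python untouched. A hypothetical helper showing the reverse cast a launcher would perform (not part of this file):

#include <stdint.h>
// Hypothetical: recover the opaque driver handle from the integer that
// load_binary returned to Python.
static CUfunction functionFromHandle(uint64_t handle) {
  return (CUfunction)(uintptr_t)handle;
}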

typedef CUresult (*cuOccupancyMaxActiveClusters_t)(
    int *numClusters, CUfunction func, const CUlaunchConfig *config);

typedef CUresult (*cuTensorMapEncodeTiled_t)(
    CUtensorMap *tensorMap, CUtensorMapDataType tensorDataType,
    cuuint32_t tensorRank, void *globalAddress, const cuuint64_t *globalDim,
    const cuuint64_t *globalStrides, const cuuint32_t *boxDim,
    const cuuint32_t *elementStrides, CUtensorMapInterleave interleave,
    CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion,
    CUtensorMapFloatOOBfill oobFill);

#ifndef _WIN32
#define defineGetFunctionHandle(name, symbolName)                              \
  static symbolName##_t name() {                                               \
    /* Open the shared library */                                              \
    void *libHandle = dlopen("libcuda.so.1", RTLD_LAZY);                       \
    if (!libHandle) {                                                          \
      PyErr_SetString(PyExc_RuntimeError, "Failed to open libcuda.so.1");      \
      return NULL;                                                             \
    }                                                                          \
    /* Clear any existing error */                                             \
    dlerror();                                                                 \
    symbolName##_t funcHandle = (symbolName##_t)dlsym(libHandle, #symbolName); \
    /* Check for errors */                                                     \
    const char *err = dlerror();                                               \
    if (err) {                                                                 \
      PyErr_SetString(PyExc_RuntimeError,                                      \
                      "Failed to retrieve " #symbolName " from libcuda.so.1"); \
      dlclose(libHandle);                                                      \
      return NULL;                                                             \
    }                                                                          \
    return funcHandle;                                                         \
  }
#else
#define defineGetFunctionHandle(name, symbolName)                              \
  static symbolName##_t name() {                                               \
    /* Open the shared library */                                              \
    HMODULE handle = LoadLibraryA("nvcuda.dll");                               \
    if (!handle) {                                                             \
      PyErr_SetString(PyExc_RuntimeError, "Failed to open nvcuda.dll");        \
      return NULL;                                                             \
    }                                                                          \
    symbolName##_t funcHandle =                                                \
        (symbolName##_t)GetProcAddress((HMODULE)handle, #symbolName);          \
    /* Check for errors */                                                     \
    long err = GetLastError();                                                 \
    if (err) {                                                                 \
      PyErr_SetString(PyExc_RuntimeError,                                      \
                      "Failed to retrieve " #symbolName " from nvcuda.dll");   \
      return NULL;                                                             \
    }                                                                          \
    return funcHandle;                                                         \
  }
#endif

defineGetFunctionHandle(getCuOccupancyMaxActiveClustersHandle,
                        cuOccupancyMaxActiveClusters);

defineGetFunctionHandle(getCuTensorMapEncodeTiledHandle,
                        cuTensorMapEncodeTiled);
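For readers who prefer not to expand the macro mentally: on Windows, the second instantiation above produces roughly the following function (illustrative expansion with the GetLastError check condensed):

static cuTensorMapEncodeTiled_t getCuTensorMapEncodeTiledHandle(void) {
  // Open the driver DLL and resolve the symbol by name; the result is NULL
  // if this driver predates tensor-map support.
  HMODULE handle = LoadLibraryA("nvcuda.dll");
  if (!handle) {
    PyErr_SetString(PyExc_RuntimeError, "Failed to open nvcuda.dll");
    return NULL;
  }
  return (cuTensorMapEncodeTiled_t)GetProcAddress(handle,
                                                  "cuTensorMapEncodeTiled");
}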

static PyObject *occupancyMaxActiveClusters(PyObject *self, PyObject *args) {
  int clusterDimX = -1, clusterDimY = -1, clusterDimZ = -1,
      maxActiveClusters = -1;
  int shared = 0;
  CUfunction func;

  if (!PyArg_ParseTuple(args, "Kiiii", &func, &shared, &clusterDimX,
                        &clusterDimY, &clusterDimZ)) {
    return NULL;
  }

  // Let each SM have one block
  int maxActiveBlocks = 1;
  Py_BEGIN_ALLOW_THREADS;
  CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncSetAttribute(
      func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared));
  Py_END_ALLOW_THREADS;

  CUlaunchAttribute launchAttr[1];
  launchAttr[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
  launchAttr[0].value.clusterDim.x = clusterDimX;
  launchAttr[0].value.clusterDim.y = clusterDimY;
  launchAttr[0].value.clusterDim.z = clusterDimZ;
  CUlaunchConfig config;
  config.gridDimX = clusterDimX;
  config.gridDimY = maxActiveBlocks * clusterDimY;
  config.gridDimZ = clusterDimZ;
  config.blockDimX = 128;
  config.blockDimY = 1;
  config.blockDimZ = 1;
  config.sharedMemBytes = shared;
  config.hStream = 0;
  config.numAttrs = 1;
  config.attrs = launchAttr;

  static cuOccupancyMaxActiveClusters_t cuOccupancyMaxActiveClusters = NULL;
  INITIALIZE_FUNCTION_POINTER_IF_NULL(cuOccupancyMaxActiveClusters,
                                      getCuOccupancyMaxActiveClustersHandle);

  Py_BEGIN_ALLOW_THREADS;
  CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncSetAttribute(
      func, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1));
  CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
      cuOccupancyMaxActiveClusters(&maxActiveClusters, func, &config));
  Py_END_ALLOW_THREADS;
  return PyLong_FromLong(maxActiveClusters);

cleanup:
  return NULL;
}

static PyObject *setPrintfFifoSize(PyObject *self, PyObject *args) {
  long size;
  if (!PyArg_ParseTuple(args, "l", &size)) {
    return NULL;
  }
  if (size < 0) {
    PyErr_SetString(PyExc_ValueError, "fifo size must be non-negative");
    return NULL;
  }

  Py_BEGIN_ALLOW_THREADS;

  // Ensure we have an active context.
  CUcontext ctx = NULL;
  CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxGetCurrent(&ctx));
  if (!ctx) {
    CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
        cuDevicePrimaryCtxRetain(&ctx, /*device=*/0));
    CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxSetCurrent(ctx));
  }

  // We can't set the fifo size after running a kernel that calls printf. This
  // is true even if the set() call is a nop and the new size is the same as
  // the old size.
  //
  // This is unfriendly, so check if the old size matches the new size, and
  // skip the set() call if so.
  size_t oldSize = 0;
  CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
      cuCtxGetLimit(&oldSize, CU_LIMIT_PRINTF_FIFO_SIZE));
  if (oldSize != (size_t)size) {
    CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
        cuCtxSetLimit(CU_LIMIT_PRINTF_FIFO_SIZE, size));
  }

  Py_END_ALLOW_THREADS;
  Py_RETURN_NONE;
}

static PyObject *PyCUtensorMap_alloc(PyTypeObject *type, Py_ssize_t n_items) {
  PyCUtensorMapObject *self = NULL;
  void *mem = NULL;
  size_t size = type->tp_basicsize;

#ifdef _WIN32
  mem = _aligned_malloc(size, 128);
  if (mem == NULL) {
#else
  if (posix_memalign(&mem, 128, size) != 0) {
#endif
    PyErr_NoMemory();
    return NULL;
  }

  self = (PyCUtensorMapObject *)mem;
  PyObject_INIT(self, type);
  return (PyObject *)self;
}

static void PyCUtensorMap_dealloc(PyObject *self) {
  Py_TYPE(self)->tp_free(self);
}

static void PyCUtensorMap_free(void *ptr) {
#ifdef _WIN32
  _aligned_free(ptr);
#else
  free(ptr);
#endif
}

// clang-format off
static PyTypeObject PyCUtensorMapType = {
    PyVarObject_HEAD_INIT(NULL, 0)
    .tp_name = "triton.backends.nvidia.PyCUtensorMap",
    .tp_basicsize = sizeof(PyCUtensorMapObject),
    .tp_itemsize = 0,
    .tp_flags = Py_TPFLAGS_DEFAULT,
    .tp_doc = "<PyCUtensorMap object>",
    .tp_new = PyType_GenericNew,
    .tp_alloc = PyCUtensorMap_alloc,
    .tp_dealloc = (destructor)PyCUtensorMap_dealloc,
    .tp_free = PyCUtensorMap_free,
};
// clang-format on
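The custom tp_alloc/tp_free pair exists because CPython's default object allocator gives no alignment guarantee beyond the platform default, while the struct declares its CUtensorMap member _Alignas(128). A hypothetical debug check (not in the file) makes the invariant explicit:

#include <assert.h>
#include <stdint.h>
// Hypothetical: verify the aligned allocator placed the embedded tensor map
// on the 128-byte boundary its declaration requires.
static void assertTensorMapAligned(const PyCUtensorMapObject *obj) {
  assert(((uintptr_t)&obj->tensorMap) % 128 == 0);
}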

static PyObject *fillTMADescriptor(PyObject *self, PyObject *args) {
  unsigned long long global_address;
  int swizzle;
  int elemSize;
  int elemType;
  PyObject *blockSize;
  PyObject *shape;
  PyObject *strides;
  int padding;

  if (!PyArg_ParseTuple(args, "KiiiOOOi", &global_address, &swizzle, &elemSize,
                        &elemType, &blockSize, &shape, &strides, &padding)) {
    return NULL;
  }

  PyCUtensorMapObject *desc = (PyCUtensorMapObject *)PyObject_CallObject(
      (PyObject *)&PyCUtensorMapType, NULL);
  if (!desc) {
    return NULL;
  }

  PyObject *blockSizeFast = NULL;
  PyObject *shapeFast = NULL;
  PyObject *stridesFast = NULL;

  uint32_t blockSizeInt[5];
  uint64_t shapeInt[5];
  uint64_t stridesLL[5];

  blockSizeFast = PySequence_Fast(blockSize, "blockSize must be a sequence");
  if (!blockSizeFast)
    goto cleanup;
  int rank = PySequence_Fast_GET_SIZE(blockSizeFast);

  for (int i = 0; i < rank; ++i) {
    PyObject *item = PySequence_Fast_GET_ITEM(blockSizeFast, i);
    if (!PyLong_Check(item)) {
      PyErr_SetString(PyExc_TypeError, "block size must be an int");
      goto cleanup;
    }
    blockSizeInt[rank - i - 1] = PyLong_AsLongLong(item);
  }

  shapeFast = PySequence_Fast(shape, "shape must be a sequence");
  if (!shapeFast)
    goto cleanup;

  if (rank != PySequence_Fast_GET_SIZE(shapeFast)) {
    PyErr_SetString(PyExc_RuntimeError, "Rank mismatch");
    goto cleanup;
  }
  for (int i = 0; i < rank; ++i) {
    PyObject *item = PySequence_Fast_GET_ITEM(shapeFast, i);
    if (!PyLong_Check(item)) {
      PyErr_SetString(PyExc_TypeError, "shape must be an int");
      goto cleanup;
    }
    shapeInt[rank - i - 1] = PyLong_AsLong(item);
  }

  stridesFast = PySequence_Fast(strides, "strides must be a sequence");
  if (!stridesFast)
    goto cleanup;

  if (rank != PySequence_Fast_GET_SIZE(stridesFast)) {
    PyErr_SetString(PyExc_RuntimeError, "Rank mismatch");
    goto cleanup;
  }
  for (int i = 0; i + 1 < rank; ++i) {
    PyObject *item = PySequence_Fast_GET_ITEM(stridesFast, i);
    if (!PyLong_Check(item)) {
      PyErr_SetString(PyExc_TypeError, "strides must be an int");
      goto cleanup;
    }
    stridesLL[rank - i - 2] = elemSize * PyLong_AsLongLong(item);
  }
  stridesLL[rank - 1] =
      shapeInt[rank - 1] * (rank == 1 ? elemSize : stridesLL[rank - 2]);
  Py_DECREF(blockSizeFast);
  blockSizeFast = NULL;
  Py_DECREF(shapeFast);
  shapeFast = NULL;
  Py_DECREF(stridesFast);
  stridesFast = NULL;

  CUtensorMapFloatOOBfill fill =
      (padding == 1) ? CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA
                     : CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE;

  uint32_t elementStrides[5] = {1, 1, 1, 1, 1};
  static cuTensorMapEncodeTiled_t cuTensorMapEncodeTiled = NULL;
  INITIALIZE_FUNCTION_POINTER_IF_NULL(cuTensorMapEncodeTiled,
                                      getCuTensorMapEncodeTiledHandle);
  CUDA_CHECK_AND_RETURN_NULL(cuTensorMapEncodeTiled(
      &desc->tensorMap, elemType, rank, (void *)global_address, shapeInt,
      stridesLL, blockSizeInt, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE,
      swizzle, CU_TENSOR_MAP_L2_PROMOTION_L2_128B, fill));

  return (PyObject *)desc;

cleanup:
  Py_XDECREF(blockSizeFast);
  Py_XDECREF(shapeFast);
  Py_XDECREF(stridesFast);
  Py_XDECREF(desc);
  return NULL;
}
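Note the rank - i - 1 indexing throughout fillTMADescriptor: Python passes shape, strides, and block sizes outermost-first (row-major), while cuTensorMapEncodeTiled expects the fastest-varying dimension first, so each array is stored reversed. A standalone sketch of that reordering (hypothetical helper, for illustration only):

#include <stdint.h>
// Hypothetical: reorder a row-major (outermost-first) dimension list into
// the fastest-dimension-first order the tensor-map encoder expects.
static void reverseForTensorMap(const uint64_t *rowMajor, uint64_t *tmaOrder,
                                int rank) {
  for (int i = 0; i < rank; ++i)
    tmaOrder[rank - 1 - i] = rowMajor[i];
}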

static PyMethodDef ModuleMethods[] = {
    {"load_binary", loadBinary, METH_VARARGS,
     "Load provided cubin into CUDA driver"},
    {"get_device_properties", getDeviceProperties, METH_VARARGS,
     "Get the properties for a given device"},
    {"cuOccupancyMaxActiveClusters", occupancyMaxActiveClusters, METH_VARARGS,
     "Python interface for cuOccupancyMaxActiveClusters function"},
    {"set_printf_fifo_size", setPrintfFifoSize, METH_VARARGS,
     "Python interface for cuCtxSetLimit(CU_LIMIT_PRINTF_FIFO_SIZE, x), which "
     "controls how many bytes can be streamed from kernels before data starts "
     "being dropped. This inherits all the limitations of this call; in "
     "particular it's an error to change this value after launching any kernel "
     "that calls printf()."},
    {"fill_tma_descriptor", fillTMADescriptor, METH_VARARGS,
     "Fill a CUtensorMap descriptor via cuTensorMapEncodeTiled"},

    {NULL, NULL, 0, NULL} // sentinel
};

static struct PyModuleDef ModuleDef = {PyModuleDef_HEAD_INIT, "cuda_utils",
                                       NULL, // documentation
                                       -1,   // size
                                       ModuleMethods};

PyMODINIT_FUNC PyInit_cuda_utils(void) {
  if (PyType_Ready(&PyCUtensorMapType) < 0) {
    return NULL;
  }

  PyObject *m = PyModule_Create(&ModuleDef);
  if (m == NULL) {
    return NULL;
  }

  PyModule_AddFunctions(m, ModuleMethods);
  Py_INCREF(&PyCUtensorMapType);
  PyModule_AddObject(m, "PyCUtensorMap", (PyObject *)&PyCUtensorMapType);

  return m;
}