triton-windows 3.4.0.post20__cp312-cp312-win_amd64.whl → 3.5.0.post21__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic. Click here for more details.

Files changed (107) hide show
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +8 -2
  3. triton/_filecheck.py +24 -14
  4. triton/_internal_testing.py +70 -4
  5. triton/_utils.py +3 -1
  6. triton/backends/amd/compiler.py +68 -60
  7. triton/backends/amd/driver.c +113 -44
  8. triton/backends/amd/driver.py +133 -57
  9. triton/backends/driver.py +13 -0
  10. triton/backends/nvidia/compiler.py +80 -22
  11. triton/backends/nvidia/driver.c +88 -15
  12. triton/backends/nvidia/driver.py +130 -123
  13. triton/compiler/__init__.py +5 -2
  14. triton/compiler/code_generator.py +270 -163
  15. triton/compiler/compiler.py +45 -62
  16. triton/experimental/gluon/__init__.py +3 -2
  17. triton/experimental/gluon/_runtime.py +9 -6
  18. triton/experimental/gluon/language/__init__.py +117 -16
  19. triton/experimental/gluon/language/_core.py +246 -68
  20. triton/experimental/gluon/language/_layouts.py +398 -45
  21. triton/experimental/gluon/language/_math.py +17 -9
  22. triton/experimental/gluon/language/_semantic.py +130 -37
  23. triton/experimental/gluon/language/_standard.py +55 -22
  24. triton/experimental/gluon/language/amd/__init__.py +4 -0
  25. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  26. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  27. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  28. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  29. triton/experimental/gluon/language/extra/__init__.py +3 -0
  30. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  31. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  32. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  33. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +192 -7
  34. triton/experimental/gluon/language/nvidia/blackwell/tma.py +20 -0
  35. triton/experimental/gluon/language/nvidia/hopper/__init__.py +124 -3
  36. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +20 -37
  37. triton/experimental/gluon/language/nvidia/hopper/tma.py +4 -3
  38. triton/experimental/gluon/nvidia/hopper.py +6 -1
  39. triton/knobs.py +132 -67
  40. triton/language/__init__.py +16 -10
  41. triton/language/core.py +163 -83
  42. triton/language/extra/cuda/gdc.py +6 -6
  43. triton/language/extra/hip/__init__.py +3 -1
  44. triton/language/extra/hip/libdevice.py +7 -0
  45. triton/language/extra/hip/utils.py +35 -0
  46. triton/language/extra/libdevice.py +4 -0
  47. triton/language/semantic.py +76 -23
  48. triton/language/standard.py +14 -14
  49. triton/language/target_info.py +54 -0
  50. triton/runtime/_allocation.py +15 -3
  51. triton/runtime/_async_compile.py +55 -0
  52. triton/runtime/autotuner.py +4 -5
  53. triton/runtime/build.py +11 -9
  54. triton/runtime/cache.py +44 -1
  55. triton/runtime/driver.py +16 -41
  56. triton/runtime/interpreter.py +31 -23
  57. triton/runtime/jit.py +318 -157
  58. triton/runtime/tcc/include/_mingw.h +8 -10
  59. triton/runtime/tcc/include/assert.h +5 -0
  60. triton/runtime/tcc/include/errno.h +1 -1
  61. triton/runtime/tcc/include/float.h +21 -3
  62. triton/runtime/tcc/include/iso646.h +36 -0
  63. triton/runtime/tcc/include/limits.h +5 -0
  64. triton/runtime/tcc/include/malloc.h +2 -2
  65. triton/runtime/tcc/include/math.h +21 -261
  66. triton/runtime/tcc/include/stdalign.h +16 -0
  67. triton/runtime/tcc/include/stdarg.h +5 -70
  68. triton/runtime/tcc/include/stdatomic.h +171 -0
  69. triton/runtime/tcc/include/stddef.h +7 -19
  70. triton/runtime/tcc/include/stdlib.h +15 -4
  71. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  72. triton/runtime/tcc/include/sys/stat.h +2 -2
  73. triton/runtime/tcc/include/sys/types.h +5 -0
  74. triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
  75. triton/runtime/tcc/include/tccdefs.h +342 -0
  76. triton/runtime/tcc/include/tgmath.h +89 -0
  77. triton/runtime/tcc/include/uchar.h +33 -0
  78. triton/runtime/tcc/include/unistd.h +1 -0
  79. triton/runtime/tcc/include/winapi/qos.h +72 -0
  80. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  81. triton/runtime/tcc/include/winapi/winbase.h +9 -2
  82. triton/runtime/tcc/include/winapi/wincon.h +8 -0
  83. triton/runtime/tcc/include/winapi/windows.h +1 -1
  84. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  85. triton/runtime/tcc/include/winapi/winnt.h +9 -7
  86. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  87. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  88. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  89. triton/runtime/tcc/lib/libtcc1.a +0 -0
  90. triton/runtime/tcc/lib/python314.def +1800 -0
  91. triton/runtime/tcc/lib/python314t.def +1809 -0
  92. triton/runtime/tcc/libtcc.dll +0 -0
  93. triton/runtime/tcc/tcc.exe +0 -0
  94. triton/tools/compile.py +62 -14
  95. triton/tools/extra/cuda/compile.c +1 -0
  96. triton/tools/extra/hip/compile.cpp +66 -0
  97. triton/tools/extra/hip/compile.h +13 -0
  98. triton/tools/ragged_tma.py +92 -0
  99. triton/tools/tensor_descriptor.py +7 -9
  100. triton/windows_utils.py +42 -79
  101. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +3 -4
  102. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/RECORD +106 -75
  103. triton/runtime/tcc/lib/libtcc1-64.a +0 -0
  104. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
  105. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/entry_points.txt +0 -0
  106. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/licenses/LICENSE +0 -0
  107. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/top_level.txt +0 -0
@@ -1,12 +1,10 @@
1
1
  #define __HIP_PLATFORM_AMD__
2
- // clang-format off
3
- // hip_depreated.h needs definitions from hip_runtime.h.
4
2
  #include <hip/hip_runtime.h>
5
- #include <hip/hip_deprecated.h>
6
- // clang-format on
3
+ #include <hip/hip_runtime_api.h>
7
4
  #define PY_SSIZE_T_CLEAN
8
5
  #include <Python.h>
9
6
  #include <dlfcn.h>
7
+ #include <stdbool.h>
10
8
  #include <stdio.h>
11
9
  #include <stdlib.h>
12
10
 
@@ -18,24 +16,9 @@ static const char *hipLibSearchPaths[] = {"/*py_libhip_search_path*/"};
18
16
  // in this file.
19
17
  // |FOR_EACH_ERR_FN| is a macro to process APIs that return hipError_t;
20
18
  // |FOR_EACH_STR_FN| is a macro to process APIs that return const char *.
21
- //
22
- // HIP 6.0 introduced an updated hipGetDeviceProperties API under a new symbol,
23
- // hipGetDevicePropertiesR0600. However, the associated hipDeviceProp_t was
24
- // directly updated with breaking changes to match hipGetDevicePropertiesR0600
25
- // in the header file. We include the header file from HIP 6.0. So here if we
26
- // use hipGetDeviceProperties together with hipDeviceProp_t we will use the
27
- // old API with a new struct definition and mess up the interpretation.
28
- //
29
- // This is a known issue: https://github.com/ROCm/ROCm/issues/2728.
30
- //
31
- // For now explicitly defer to the old hipDeviceProp_t struct. This should work
32
- // for both 5.x and 6.x. In the long term we need to switch to use
33
- // hipGetProcAddress once available:
34
- // https://github.com/ROCm/clr/commit/0479cdb3dd30ef58718cad44e424bd793c394cc0
35
19
  #define HIP_SYMBOL_LIST(FOR_EACH_ERR_FN, FOR_EACH_STR_FN) \
36
20
  FOR_EACH_STR_FN(hipGetErrorString, hipError_t hipError) \
37
- FOR_EACH_ERR_FN(hipGetDeviceProperties, hipDeviceProp_tR0000 *prop, \
38
- int deviceId) \
21
+ FOR_EACH_ERR_FN(hipGetDeviceProperties, hipDeviceProp_t *prop, int deviceId) \
39
22
  FOR_EACH_ERR_FN(hipModuleLoadDataEx, hipModule_t *module, const void *image, \
40
23
  unsigned int numOptions, hipJitOption *options, \
41
24
  void **optionValues) \
@@ -44,6 +27,34 @@ static const char *hipLibSearchPaths[] = {"/*py_libhip_search_path*/"};
44
27
  FOR_EACH_ERR_FN(hipFuncGetAttribute, int *, hipFunction_attribute attr, \
45
28
  hipFunction_t function)
46
29
 
30
+ // HIP driver version format: HIP_VERSION_MAJOR * 10000000 + HIP_VERSION_MINOR *
31
+ // 100000 + HIP_VERSION_PATCH.
32
+ #define TRITON_HIP_DRIVER_EXTRACT_MAJOR_VERSION(version) ((version) / 10000000)
33
+ #define TRITON_HIP_DRIVER_EXTRACT_MINOR_VERSION(version) \
34
+ (((version) % 10000000) / 100000)
35
+ #define TRITON_HIP_DRIVER_EXTRACT_PATCH_VERSION(version) ((version) % 100000)
36
+ #define TRITON_HIP_DRIVER_REQ_MAJOR_VERSION (HIP_VERSION_MAJOR)
37
+
38
+ // #define TRITON_HIP_DRIVER_DBG_VERSION
39
+ #ifdef TRITON_HIP_DRIVER_DBG_VERSION
40
+ #define TRITON_HIP_DRIVER_LOG_VERSION(version, msgBuff) \
41
+ do { \
42
+ snprintf(msgBuff, sizeof(msgBuff), "libamdhip64 version is: %d.%d.%d", \
43
+ TRITON_HIP_DRIVER_EXTRACT_MAJOR_VERSION(version), \
44
+ TRITON_HIP_DRIVER_EXTRACT_MINOR_VERSION(version), \
45
+ TRITON_HIP_DRIVER_EXTRACT_PATCH_VERSION(version)); \
46
+ printf("%s\n", msgBuff); \
47
+ } while (0);
48
+ #else
49
+ #define TRITON_HIP_DRIVER_LOG_VERSION(version, msgBuff) \
50
+ do { \
51
+ (void)msgBuff; \
52
+ (void)(version); \
53
+ } while (0);
54
+ #endif
55
+
56
+ #define TRITON_HIP_MSG_BUFF_SIZE (1024U)
57
+
47
58
  // The HIP symbol table for holding resolved dynamic library symbols.
48
59
  struct HIPSymbolTable {
49
60
  #define DEFINE_EACH_ERR_FIELD(hipSymbolName, ...) \
@@ -56,39 +67,96 @@ struct HIPSymbolTable {
56
67
 
57
68
  static struct HIPSymbolTable hipSymbolTable;
58
69
 
59
- bool initSymbolTable() {
60
- // Use the HIP runtime library loaded into the existing process if it exits.
61
- void *lib = dlopen("libamdhip64.so", RTLD_NOLOAD);
62
- if (lib) {
63
- // printf("[triton] chosen loaded libamdhip64.so in the process\n");
70
+ static int checkDriverVersion(void *lib) {
71
+ int hipVersion = -1;
72
+ const char *error = NULL;
73
+ typedef hipError_t (*hipDriverGetVersion_fn)(int *driverVersion);
74
+ hipDriverGetVersion_fn hipDriverGetVersion;
75
+ dlerror(); // Clear existing errors
76
+ hipDriverGetVersion =
77
+ (hipDriverGetVersion_fn)dlsym(lib, "hipDriverGetVersion");
78
+ error = dlerror();
79
+ if (error) {
80
+ PyErr_SetString(PyExc_RuntimeError,
81
+ "cannot query 'hipDriverGetVersion' from libamdhip64.so");
82
+ dlclose(lib);
83
+ return -1;
64
84
  }
65
85
 
66
- // Otherwise, go through the list of search paths to dlopen the first HIP
67
- // driver library.
68
- if (!lib) {
69
- int n = sizeof(hipLibSearchPaths) / sizeof(hipLibSearchPaths[0]);
70
- for (int i = 0; i < n; ++i) {
71
- void *handle = dlopen(hipLibSearchPaths[i], RTLD_LAZY | RTLD_LOCAL);
72
- if (handle) {
73
- lib = handle;
74
- // printf("[triton] chosen %s\n", hipLibSearchPaths[i]);
75
- }
86
+ (void)hipDriverGetVersion(&hipVersion);
87
+ char msgBuff[TRITON_HIP_MSG_BUFF_SIZE] = {0};
88
+
89
+ const int hipMajVersion = TRITON_HIP_DRIVER_EXTRACT_MAJOR_VERSION(hipVersion);
90
+ if (hipMajVersion < TRITON_HIP_DRIVER_REQ_MAJOR_VERSION) {
91
+ const int hipMinVersion =
92
+ TRITON_HIP_DRIVER_EXTRACT_MINOR_VERSION(hipVersion);
93
+ const int hipPatchVersion =
94
+ TRITON_HIP_DRIVER_EXTRACT_PATCH_VERSION(hipVersion);
95
+ snprintf(msgBuff, sizeof(msgBuff),
96
+ "libamdhip64 version %d.%d.%d is not supported! Required major "
97
+ "version is >=%d.",
98
+ hipMajVersion, hipMinVersion, hipPatchVersion,
99
+ TRITON_HIP_DRIVER_REQ_MAJOR_VERSION);
100
+ PyErr_SetString(PyExc_RuntimeError, msgBuff);
101
+ dlclose(lib);
102
+ return -1;
103
+ }
104
+
105
+ TRITON_HIP_DRIVER_LOG_VERSION(hipVersion, msgBuff);
106
+
107
+ return hipVersion;
108
+ }
109
+
110
+ bool initSymbolTable() {
111
+ void *lib;
112
+
113
+ // Go through the list of search paths to dlopen the first HIP driver library.
114
+ int n = sizeof(hipLibSearchPaths) / sizeof(hipLibSearchPaths[0]);
115
+ for (int i = 0; i < n; ++i) {
116
+ void *handle = dlopen(hipLibSearchPaths[i], RTLD_LAZY | RTLD_LOCAL);
117
+ if (handle) {
118
+ lib = handle;
119
+ // printf("[triton] chosen %s\n", hipLibSearchPaths[i]);
76
120
  }
77
121
  }
122
+
78
123
  if (!lib) {
79
124
  PyErr_SetString(PyExc_RuntimeError, "cannot open libamdhip64.so");
80
125
  return false;
81
126
  }
82
127
 
83
- // Resolve all symbols we are interested in.
84
- dlerror(); // Clear existing errors
128
+ int hipVersion = checkDriverVersion(lib);
129
+ if (hipVersion == -1)
130
+ return false;
131
+
85
132
  const char *error = NULL;
133
+ typedef hipError_t (*hipGetProcAddress_fn)(
134
+ const char *symbol, void **pfn, int hipVersion, uint64_t hipFlags,
135
+ hipDriverProcAddressQueryResult *symbolStatus);
136
+ hipGetProcAddress_fn hipGetProcAddress;
137
+ dlerror(); // Clear existing errors
138
+
139
+ *(void **)&hipGetProcAddress = dlsym(lib, "hipGetProcAddress");
140
+ error = dlerror();
141
+ if (error) {
142
+ PyErr_SetString(PyExc_RuntimeError,
143
+ "cannot query 'hipGetProcAddress' from libamdhip64.so");
144
+ dlclose(lib);
145
+ return false;
146
+ }
147
+
148
+ // Resolve all symbols we are interested in.
149
+ uint64_t hipFlags = 0;
150
+ hipDriverProcAddressQueryResult symbolStatus;
151
+ hipError_t status = hipSuccess;
86
152
  #define QUERY_EACH_FN(hipSymbolName, ...) \
87
- *(void **)&hipSymbolTable.hipSymbolName = dlsym(lib, #hipSymbolName); \
88
- error = dlerror(); \
89
- if (error) { \
153
+ status = hipGetProcAddress(#hipSymbolName, \
154
+ (void **)&hipSymbolTable.hipSymbolName, \
155
+ hipVersion, hipFlags, &symbolStatus); \
156
+ if (status != hipSuccess) { \
90
157
  PyErr_SetString(PyExc_RuntimeError, \
91
- "cannot query " #hipSymbolName " from libamdhip64.so"); \
158
+ "cannot get address for '" #hipSymbolName \
159
+ "' from libamdhip64.so"); \
92
160
  dlclose(lib); \
93
161
  return false; \
94
162
  }
@@ -104,8 +172,9 @@ static inline void gpuAssert(hipError_t code, const char *file, int line) {
104
172
  {
105
173
  const char *prefix = "Triton Error [HIP]: ";
106
174
  const char *str = hipSymbolTable.hipGetErrorString(code);
107
- char err[1024] = {0};
108
- snprintf(err, 1024, "%s Code: %d, Messsage: %s", prefix, code, str);
175
+ char err[TRITON_HIP_MSG_BUFF_SIZE] = {0};
176
+ snprintf(err, sizeof(err), "%s Code: %d, Messsage: %s", prefix, code,
177
+ str);
109
178
  PyGILState_STATE gil_state;
110
179
  gil_state = PyGILState_Ensure();
111
180
  PyErr_SetString(PyExc_RuntimeError, err);
@@ -127,7 +196,7 @@ static PyObject *getDeviceProperties(PyObject *self, PyObject *args) {
127
196
  if (!PyArg_ParseTuple(args, "i", &device_id))
128
197
  return NULL;
129
198
 
130
- hipDeviceProp_tR0000 props;
199
+ hipDeviceProp_t props;
131
200
  HIP_CHECK(hipSymbolTable.hipGetDeviceProperties(&props, device_id));
132
201
 
133
202
  // create a struct to hold device properties
@@ -6,6 +6,7 @@ from pathlib import Path
6
6
  from triton import knobs
7
7
  from triton.backends.compiler import GPUTarget
8
8
  from triton.backends.driver import GPUDriver
9
+ from triton.runtime import _allocation
9
10
  from triton.runtime.build import compile_module_from_src
10
11
  from triton.tools.tensor_descriptor import TensorDescriptor
11
12
 
@@ -109,8 +110,36 @@ def _get_path_to_hip_runtime_dylib():
109
110
  return f
110
111
  paths.append(f)
111
112
 
113
+ # HIP_PATH should point to HIP SDK root if set
114
+ env_hip_path = os.getenv("HIP_PATH")
115
+ if env_hip_path:
116
+ hip_lib_path = os.path.join(env_hip_path, "lib", lib_name)
117
+ if os.path.exists(hip_lib_path):
118
+ return hip_lib_path
119
+ paths.append(hip_lib_path)
120
+
121
+ # if available, `hipconfig --path` prints the HIP SDK root
122
+ try:
123
+ hip_root = subprocess.check_output(["hipconfig", "--path"]).decode().strip()
124
+ if hip_root:
125
+ hip_lib_path = os.path.join(hip_root, "lib", lib_name)
126
+ if os.path.exists(hip_lib_path):
127
+ return hip_lib_path
128
+ paths.append(hip_lib_path)
129
+ except (subprocess.CalledProcessError, FileNotFoundError):
130
+ # hipconfig may not be available
131
+ pass
132
+
133
+ # ROCm lib dir based on env var
134
+ env_rocm_path = os.getenv("ROCM_PATH")
135
+ if env_rocm_path:
136
+ rocm_lib_path = os.path.join(env_rocm_path, "lib", lib_name)
137
+ if os.path.exists(rocm_lib_path):
138
+ return rocm_lib_path
139
+ paths.append(rocm_lib_path)
140
+
112
141
  # Afterwards try to search the loader dynamic library resolution paths.
113
- libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode()
142
+ libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode(errors="ignore")
114
143
  # each line looks like the following:
115
144
  # libamdhip64.so.6 (libc6,x86-64) => /opt/rocm-6.0.2/lib/libamdhip64.so.6
116
145
  # libamdhip64.so (libc6,x86-64) => /opt/rocm-6.0.2/lib/libamdhip64.so
@@ -153,12 +182,12 @@ def ty_to_cpp(ty):
153
182
  if ty[0] == '*':
154
183
  return "hipDeviceptr_t"
155
184
  return {
156
- "i1": "int32_t",
185
+ "i1": "int8_t",
157
186
  "i8": "int8_t",
158
187
  "i16": "int16_t",
159
188
  "i32": "int32_t",
160
189
  "i64": "int64_t",
161
- "u1": "uint32_t",
190
+ "u1": "uint8_t",
162
191
  "u8": "uint8_t",
163
192
  "u16": "uint16_t",
164
193
  "u32": "uint32_t",
@@ -186,7 +215,7 @@ FLOAT_PACK_FUNCTION = {
186
215
  "fp64": "pack_fp64",
187
216
  }
188
217
 
189
- _BASE_ARGS_FORMAT = "piiiKKOOOO"
218
+ _BASE_ARGS_FORMAT = "piiiKKOOOOO"
190
219
 
191
220
 
192
221
  def make_launcher(constants, signature, warp_size):
@@ -203,6 +232,7 @@ def make_launcher(constants, signature, warp_size):
203
232
  output.append("*" + dtype)
204
233
  for _ in range(2 * ndim):
205
234
  output.append("i64")
235
+ output.append("i1")
206
236
  # Currently the host side tensor descriptors get passed in as a
207
237
  # tensor desc, shape, and strides. We have no way to use these
208
238
  # shape and strides when processing tensor descriptors which is
@@ -293,9 +323,11 @@ def make_launcher(constants, signature, warp_size):
293
323
  params = list(range(len(signature)))
294
324
  params = [f"&arg{i}" for i, ty in signature.items() if ty != "constexpr"]
295
325
  params.append("&global_scratch")
326
+ params.append("&profile_scratch")
296
327
  src = f"""
297
328
  #define __HIP_PLATFORM_AMD__
298
329
  #include <hip/hip_runtime.h>
330
+ #include <hip/hip_runtime_api.h>
299
331
  #include <Python.h>
300
332
  #include <dlfcn.h>
301
333
  #include <stdbool.h>
@@ -308,6 +340,7 @@ static const char *hipLibSearchPaths[] = {{"{libhip_path}"}};
308
340
  // The list of HIP dynamic library symbols and their signature we are interested
309
341
  // in this file.
310
342
  #define HIP_SYMBOL_LIST(FOR_EACH_ERR_FN, FOR_EACH_STR_FN) \\
343
+ FOR_EACH_STR_FN(hipGetLastError) \\
311
344
  FOR_EACH_STR_FN(hipGetErrorString, hipError_t hipError) \\
312
345
  FOR_EACH_ERR_FN(hipModuleLaunchKernel, hipFunction_t f, \\
313
346
  unsigned int gridDimX, unsigned int gridDimY, \\
@@ -356,17 +389,36 @@ bool initSymbolTable() {{
356
389
  return false;
357
390
  }}
358
391
 
359
- // Resolve all symbols we are interested in.
392
+ typedef hipError_t (*hipGetProcAddress_fn)(
393
+ const char *symbol, void **pfn, int hipVersion, uint64_t hipFlags,
394
+ hipDriverProcAddressQueryResult *symbolStatus);
395
+ hipGetProcAddress_fn hipGetProcAddress;
360
396
  dlerror(); // Clear existing errors
361
397
  const char *error = NULL;
362
- #define QUERY_EACH_FN(hipSymbolName, ...) \\
363
- *(void **)&hipSymbolTable.hipSymbolName = dlsym(lib, #hipSymbolName); \\
364
- error = dlerror(); \\
365
- if (error) {{ \\
366
- PyErr_SetString(PyExc_RuntimeError, \\
367
- "cannot query " #hipSymbolName " from libamdhip64.so"); \\
368
- dlclose(lib); \\
369
- return false; \\
398
+ *(void **)&hipGetProcAddress = dlsym(lib, "hipGetProcAddress");
399
+ error = dlerror();
400
+ if (error) {{
401
+ PyErr_SetString(PyExc_RuntimeError,
402
+ "cannot query 'hipGetProcAddress' from libamdhip64.so");
403
+ dlclose(lib);
404
+ return false;
405
+ }}
406
+
407
+ // Resolve all symbols we are interested in.
408
+ int hipVersion = HIP_VERSION;
409
+ uint64_t hipFlags = 0;
410
+ hipDriverProcAddressQueryResult symbolStatus;
411
+ hipError_t status = hipSuccess;
412
+ #define QUERY_EACH_FN(hipSymbolName, ...) \
413
+ status = hipGetProcAddress(#hipSymbolName, \
414
+ (void **)&hipSymbolTable.hipSymbolName, \
415
+ hipVersion, hipFlags, &symbolStatus); \
416
+ if (status != hipSuccess) {{ \
417
+ PyErr_SetString(PyExc_RuntimeError, \
418
+ "cannot get address for '" #hipSymbolName \
419
+ "' from libamdhip64.so"); \
420
+ dlclose(lib); \
421
+ return false; \
370
422
  }}
371
423
 
372
424
  HIP_SYMBOL_LIST(QUERY_EACH_FN, QUERY_EACH_FN)
@@ -388,7 +440,7 @@ static inline void gpuAssert(hipError_t code, const char *file, int line)
388
440
 
389
441
  #define HIP_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}
390
442
 
391
- static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, hipStream_t stream, hipFunction_t function{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
443
+ static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, hipStream_t stream, hipFunction_t function, hipDeviceptr_t profile_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
392
444
  hipDeviceptr_t global_scratch = 0;
393
445
  void *params[] = {{ {', '.join(params)} }};
394
446
  if (gridX*gridY*gridZ > 0 && launch_cooperative_grid) {{
@@ -405,8 +457,11 @@ typedef struct _DevicePtrInfo {{
405
457
  bool valid;
406
458
  }} DevicePtrInfo;
407
459
 
460
+ static PyObject* data_ptr_str = NULL;
461
+
408
462
  static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
409
463
  DevicePtrInfo ptr_info;
464
+ hipError_t status = hipSuccess;
410
465
  ptr_info.dev_ptr = 0;
411
466
  ptr_info.valid = true;
412
467
  if (PyLong_Check(obj)) {{
@@ -417,45 +472,42 @@ static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
417
472
  // valid nullptr
418
473
  return ptr_info;
419
474
  }}
420
- PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr");
421
- if(ptr){{
422
- PyObject *empty_tuple = PyTuple_New(0);
423
- PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL);
424
- Py_DECREF(empty_tuple);
425
- Py_DECREF(ptr);
426
- if (!PyLong_Check(ret)) {{
427
- PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
475
+ PyObject *ret = PyObject_CallMethodNoArgs(obj, data_ptr_str);
476
+ if (!ret) {{
477
+ PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
478
+ ptr_info.valid = false;
479
+ goto cleanup;
480
+ }}
481
+ if (!PyLong_Check(ret)) {{
482
+ PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
483
+ ptr_info.valid = false;
484
+ goto cleanup;
485
+ }}
486
+ ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(ret);
487
+ if (!ptr_info.dev_ptr)
488
+ goto cleanup;
489
+ uint64_t dev_ptr;
490
+ status = hipSymbolTable.hipPointerGetAttribute(&dev_ptr, HIP_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
491
+ if (status == hipErrorInvalidValue) {{
492
+ PyErr_Format(PyExc_ValueError,
493
+ "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
428
494
  ptr_info.valid = false;
429
- Py_DECREF(ret);
430
- return ptr_info;
431
- }}
432
- ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(ret);
433
- if(!ptr_info.dev_ptr) {{
434
- Py_DECREF(ret);
435
- return ptr_info;
436
- }}
437
- uint64_t dev_ptr;
438
- hipError_t status = hipSymbolTable.hipPointerGetAttribute(&dev_ptr, HIP_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
439
- if (status == hipErrorInvalidValue) {{
440
- PyErr_Format(PyExc_ValueError,
441
- "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
442
- ptr_info.valid = false;
443
- }}
444
- ptr_info.dev_ptr = (hipDeviceptr_t)dev_ptr;
445
- Py_DECREF(ret);
446
- return ptr_info;
495
+ // Clear and ignore HIP error
496
+ (void)hipSymbolTable.hipGetLastError();
447
497
  }}
448
- PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
498
+ ptr_info.dev_ptr = (hipDeviceptr_t)dev_ptr;
499
+ cleanup:
500
+ Py_DECREF(ret);
449
501
  return ptr_info;
450
502
  }}
451
503
 
452
504
  static uint16_t pack_fp16(double f) {{
453
505
  uint16_t result;
454
- // from https://github.com/python/pythoncapi-compat
506
+ // from https://github.com/python/pythoncapi-compat/blob/5e317108f872c904eb726cb8d560dcadbdf88a72/pythoncapi_compat.h#L482-L492
455
507
  #if 0x030600B1 <= PY_VERSION_HEX && PY_VERSION_HEX <= 0x030B00A1 && !defined(PYPY_VERSION)
456
508
  _PyFloat_Pack2(f, (unsigned char*)&result, 1);
457
509
  #else
458
- PyFloat_Pack2(f, (unsigned char*)&result, 1);
510
+ PyFloat_Pack2(f, (char*)&result, 1);
459
511
  #endif
460
512
  return result;
461
513
  }}
@@ -480,13 +532,14 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
480
532
  uint64_t _stream;
481
533
  uint64_t _function;
482
534
  int launch_cooperative_grid;
535
+ PyObject *profile_scratch_obj = NULL;
483
536
  PyObject *launch_enter_hook = NULL;
484
537
  PyObject *launch_exit_hook = NULL;
485
538
  PyObject *kernel_metadata = NULL;
486
539
  PyObject *launch_metadata = NULL;
487
540
  {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
488
541
  if(!PyArg_ParseTuple(args, \"{format}\", &launch_cooperative_grid,
489
- &gridX, &gridY, &gridZ, &_stream, &_function,
542
+ &gridX, &gridY, &gridZ, &_stream, &_function, &profile_scratch_obj,
490
543
  &kernel_metadata, &launch_metadata,
491
544
  &launch_enter_hook, &launch_exit_hook {args_list})) {{
492
545
  return NULL;
@@ -501,23 +554,27 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
501
554
  }}
502
555
  // extract launch metadata
503
556
  if (launch_enter_hook != Py_None){{
504
- PyObject* args = Py_BuildValue("(O)", launch_metadata);
505
- PyObject* ret = PyObject_CallObject(launch_enter_hook, args);
506
- Py_DECREF(args);
557
+ PyObject* ret = PyObject_CallOneArg(launch_enter_hook, launch_metadata);
507
558
  if (!ret)
508
559
  return NULL;
509
560
  Py_DECREF(ret);
510
561
  }}
511
562
 
563
+ hipDeviceptr_t profile_scratch = 0;
564
+ if (profile_scratch_obj != Py_None) {{
565
+ DevicePtrInfo profile_scratch_info = getPointer(profile_scratch_obj, -1);
566
+ if (!profile_scratch_info.valid) {{
567
+ return NULL;
568
+ }}
569
+ profile_scratch = profile_scratch_info.dev_ptr;
570
+ }}
512
571
 
513
572
  // raise exception asap
514
573
  {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])};
515
- _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
574
+ _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function, (hipDeviceptr_t)profile_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
516
575
 
517
576
  if(launch_exit_hook != Py_None){{
518
- PyObject* args = Py_BuildValue("(O)", launch_metadata);
519
- PyObject* ret = PyObject_CallObject(launch_exit_hook, args);
520
- Py_DECREF(args);
577
+ PyObject* ret = PyObject_CallOneArg(launch_exit_hook, launch_metadata);
521
578
  if (!ret)
522
579
  return NULL;
523
580
  Py_DECREF(ret);
@@ -526,9 +583,7 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
526
583
  if(PyErr_Occurred()) {{
527
584
  return NULL;
528
585
  }}
529
- // return None
530
- Py_INCREF(Py_None);
531
- return Py_None;
586
+ Py_RETURN_NONE;
532
587
  }}
533
588
 
534
589
  static PyMethodDef ModuleMethods[] = {{
@@ -552,6 +607,10 @@ PyMODINIT_FUNC PyInit___triton_launcher(void) {{
552
607
  if(m == NULL) {{
553
608
  return NULL;
554
609
  }}
610
+ data_ptr_str = PyUnicode_InternFromString("data_ptr");
611
+ if(data_ptr_str == NULL) {{
612
+ return NULL;
613
+ }}
555
614
  PyModule_AddFunctions(m, ModuleMethods);
556
615
  return m;
557
616
  }}
@@ -576,7 +635,7 @@ def wrap_handle_tensor_descriptor(launcher):
576
635
  # descriptors which is why we provide our own decomposition
577
636
  # above. Sadly this means we have to pass the shape and strides
578
637
  # twice.
579
- final_args.extend([arg.base, *arg.shape, *arg.strides, *arg.shape, *arg.strides])
638
+ final_args.extend([arg.base, *arg.shape, *arg.strides, arg.padding == "nan", *arg.shape, *arg.strides])
580
639
  else:
581
640
  final_args.append(arg)
582
641
  return launcher(*meta_args, *final_args)
@@ -597,9 +656,23 @@ class HIPLauncher(object):
597
656
 
598
657
  self.launch = wrap_handle_tensor_descriptor(mod.launch) if has_tensor_desc_arg else mod.launch
599
658
  self.launch_cooperative_grid = metadata.launch_cooperative_grid
659
+ self.profile_scratch_size = metadata.profile_scratch_size
660
+ self.profile_scratch_align = metadata.profile_scratch_align
661
+
662
+ def __call__(self, gridX, gridY, gridZ, stream, function, *args):
663
+
664
+ def allocate_scratch(size, align, allocator):
665
+ if size > 0:
666
+ grid_size = gridX * gridY * gridZ
667
+ alloc_size = grid_size * size
668
+ alloc_fn = allocator.get()
669
+ return alloc_fn(alloc_size, align, stream)
670
+ return None
600
671
 
601
- def __call__(self, *args):
602
- self.launch(self.launch_cooperative_grid, *args)
672
+ profile_scratch = allocate_scratch(self.profile_scratch_size, self.profile_scratch_align,
673
+ _allocation._profile_allocator)
674
+
675
+ self.launch(self.launch_cooperative_grid, gridX, gridY, gridZ, stream, function, profile_scratch, *args)
603
676
 
604
677
 
605
678
  class HIPDriver(GPUDriver):
@@ -621,6 +694,9 @@ class HIPDriver(GPUDriver):
621
694
  except ImportError:
622
695
  return False
623
696
 
697
+ def map_python_to_cpp_type(self, ty: str) -> str:
698
+ return ty_to_cpp(ty)
699
+
624
700
  def get_current_target(self):
625
701
  device = self.get_current_device()
626
702
  device_properties = self.utils.get_device_properties(device)
triton/backends/driver.py CHANGED
@@ -15,6 +15,19 @@ class DriverBase(metaclass=ABCMeta):
15
15
  def is_active(self):
16
16
  pass
17
17
 
18
+ @abstractmethod
19
+ def map_python_to_cpp_type(self, ty: str) -> str:
20
+ """
21
+ Converts a Triton type string to its corresponding C++ type string for this backend.
22
+
23
+ Args:
24
+ ty (str): The Triton type string. e.g., 'i32', '*fp16', 'fp32'.
25
+
26
+ Returns:
27
+ str: The C++ type string.
28
+ """
29
+ pass
30
+
18
31
  @abstractmethod
19
32
  def get_current_target(self):
20
33
  pass