triton-windows 3.4.0.post20__cp313-cp313-win_amd64.whl → 3.5.0.post21__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of triton-windows might be problematic. Click here for more details.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +8 -2
- triton/_filecheck.py +24 -14
- triton/_internal_testing.py +70 -4
- triton/_utils.py +3 -1
- triton/backends/amd/compiler.py +68 -60
- triton/backends/amd/driver.c +113 -44
- triton/backends/amd/driver.py +133 -57
- triton/backends/driver.py +13 -0
- triton/backends/nvidia/compiler.py +80 -22
- triton/backends/nvidia/driver.c +88 -15
- triton/backends/nvidia/driver.py +130 -123
- triton/compiler/__init__.py +5 -2
- triton/compiler/code_generator.py +270 -163
- triton/compiler/compiler.py +45 -62
- triton/experimental/gluon/__init__.py +3 -2
- triton/experimental/gluon/_runtime.py +9 -6
- triton/experimental/gluon/language/__init__.py +117 -16
- triton/experimental/gluon/language/_core.py +246 -68
- triton/experimental/gluon/language/_layouts.py +398 -45
- triton/experimental/gluon/language/_math.py +17 -9
- triton/experimental/gluon/language/_semantic.py +130 -37
- triton/experimental/gluon/language/_standard.py +55 -22
- triton/experimental/gluon/language/amd/__init__.py +4 -0
- triton/experimental/gluon/language/amd/_layouts.py +96 -0
- triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
- triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
- triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
- triton/experimental/gluon/language/extra/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
- triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +192 -7
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +20 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +124 -3
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +20 -37
- triton/experimental/gluon/language/nvidia/hopper/tma.py +4 -3
- triton/experimental/gluon/nvidia/hopper.py +6 -1
- triton/knobs.py +132 -67
- triton/language/__init__.py +16 -10
- triton/language/core.py +163 -83
- triton/language/extra/cuda/gdc.py +6 -6
- triton/language/extra/hip/__init__.py +3 -1
- triton/language/extra/hip/libdevice.py +7 -0
- triton/language/extra/hip/utils.py +35 -0
- triton/language/extra/libdevice.py +4 -0
- triton/language/semantic.py +76 -23
- triton/language/standard.py +14 -14
- triton/language/target_info.py +54 -0
- triton/runtime/_allocation.py +15 -3
- triton/runtime/_async_compile.py +55 -0
- triton/runtime/autotuner.py +4 -5
- triton/runtime/build.py +11 -9
- triton/runtime/cache.py +44 -1
- triton/runtime/driver.py +16 -41
- triton/runtime/interpreter.py +31 -23
- triton/runtime/jit.py +318 -157
- triton/runtime/tcc/include/_mingw.h +8 -10
- triton/runtime/tcc/include/assert.h +5 -0
- triton/runtime/tcc/include/errno.h +1 -1
- triton/runtime/tcc/include/float.h +21 -3
- triton/runtime/tcc/include/iso646.h +36 -0
- triton/runtime/tcc/include/limits.h +5 -0
- triton/runtime/tcc/include/malloc.h +2 -2
- triton/runtime/tcc/include/math.h +21 -261
- triton/runtime/tcc/include/stdalign.h +16 -0
- triton/runtime/tcc/include/stdarg.h +5 -70
- triton/runtime/tcc/include/stdatomic.h +171 -0
- triton/runtime/tcc/include/stddef.h +7 -19
- triton/runtime/tcc/include/stdlib.h +15 -4
- triton/runtime/tcc/include/stdnoreturn.h +7 -0
- triton/runtime/tcc/include/sys/stat.h +2 -2
- triton/runtime/tcc/include/sys/types.h +5 -0
- triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
- triton/runtime/tcc/include/tccdefs.h +342 -0
- triton/runtime/tcc/include/tgmath.h +89 -0
- triton/runtime/tcc/include/uchar.h +33 -0
- triton/runtime/tcc/include/unistd.h +1 -0
- triton/runtime/tcc/include/winapi/qos.h +72 -0
- triton/runtime/tcc/include/winapi/shellapi.h +59 -0
- triton/runtime/tcc/include/winapi/winbase.h +9 -2
- triton/runtime/tcc/include/winapi/wincon.h +8 -0
- triton/runtime/tcc/include/winapi/windows.h +1 -1
- triton/runtime/tcc/include/winapi/winnls.h +778 -0
- triton/runtime/tcc/include/winapi/winnt.h +9 -7
- triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
- triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
- triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
- triton/runtime/tcc/lib/libtcc1.a +0 -0
- triton/runtime/tcc/lib/python314.def +1800 -0
- triton/runtime/tcc/lib/python314t.def +1809 -0
- triton/runtime/tcc/libtcc.dll +0 -0
- triton/runtime/tcc/tcc.exe +0 -0
- triton/tools/compile.py +62 -14
- triton/tools/extra/cuda/compile.c +1 -0
- triton/tools/extra/hip/compile.cpp +66 -0
- triton/tools/extra/hip/compile.h +13 -0
- triton/tools/ragged_tma.py +92 -0
- triton/tools/tensor_descriptor.py +7 -9
- triton/windows_utils.py +42 -79
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +3 -4
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/RECORD +106 -75
- triton/runtime/tcc/lib/libtcc1-64.a +0 -0
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/entry_points.txt +0 -0
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/licenses/LICENSE +0 -0
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/top_level.txt +0 -0
triton/backends/amd/driver.c
CHANGED
|
@@ -1,12 +1,10 @@
|
|
|
1
1
|
#define __HIP_PLATFORM_AMD__
|
|
2
|
-
// clang-format off
|
|
3
|
-
// hip_depreated.h needs definitions from hip_runtime.h.
|
|
4
2
|
#include <hip/hip_runtime.h>
|
|
5
|
-
#include <hip/
|
|
6
|
-
// clang-format on
|
|
3
|
+
#include <hip/hip_runtime_api.h>
|
|
7
4
|
#define PY_SSIZE_T_CLEAN
|
|
8
5
|
#include <Python.h>
|
|
9
6
|
#include <dlfcn.h>
|
|
7
|
+
#include <stdbool.h>
|
|
10
8
|
#include <stdio.h>
|
|
11
9
|
#include <stdlib.h>
|
|
12
10
|
|
|
@@ -18,24 +16,9 @@ static const char *hipLibSearchPaths[] = {"/*py_libhip_search_path*/"};
|
|
|
18
16
|
// in this file.
|
|
19
17
|
// |FOR_EACH_ERR_FN| is a macro to process APIs that return hipError_t;
|
|
20
18
|
// |FOR_EACH_STR_FN| is a macro to process APIs that return const char *.
|
|
21
|
-
//
|
|
22
|
-
// HIP 6.0 introduced an updated hipGetDeviceProperties API under a new symbol,
|
|
23
|
-
// hipGetDevicePropertiesR0600. However, the associated hipDeviceProp_t was
|
|
24
|
-
// directly updated with breaking changes to match hipGetDevicePropertiesR0600
|
|
25
|
-
// in the header file. We include the header file from HIP 6.0. So here if we
|
|
26
|
-
// use hipGetDeviceProperties together with hipDeviceProp_t we will use the
|
|
27
|
-
// old API with a new struct definition and mess up the interpretation.
|
|
28
|
-
//
|
|
29
|
-
// This is a known issue: https://github.com/ROCm/ROCm/issues/2728.
|
|
30
|
-
//
|
|
31
|
-
// For now explicitly defer to the old hipDeviceProp_t struct. This should work
|
|
32
|
-
// for both 5.x and 6.x. In the long term we need to switch to use
|
|
33
|
-
// hipGetProcAddress once available:
|
|
34
|
-
// https://github.com/ROCm/clr/commit/0479cdb3dd30ef58718cad44e424bd793c394cc0
|
|
35
19
|
#define HIP_SYMBOL_LIST(FOR_EACH_ERR_FN, FOR_EACH_STR_FN) \
|
|
36
20
|
FOR_EACH_STR_FN(hipGetErrorString, hipError_t hipError) \
|
|
37
|
-
FOR_EACH_ERR_FN(hipGetDeviceProperties,
|
|
38
|
-
int deviceId) \
|
|
21
|
+
FOR_EACH_ERR_FN(hipGetDeviceProperties, hipDeviceProp_t *prop, int deviceId) \
|
|
39
22
|
FOR_EACH_ERR_FN(hipModuleLoadDataEx, hipModule_t *module, const void *image, \
|
|
40
23
|
unsigned int numOptions, hipJitOption *options, \
|
|
41
24
|
void **optionValues) \
|
|
@@ -44,6 +27,34 @@ static const char *hipLibSearchPaths[] = {"/*py_libhip_search_path*/"};
|
|
|
44
27
|
FOR_EACH_ERR_FN(hipFuncGetAttribute, int *, hipFunction_attribute attr, \
|
|
45
28
|
hipFunction_t function)
|
|
46
29
|
|
|
30
|
+
// HIP driver version format: HIP_VERSION_MAJOR * 10000000 + HIP_VERSION_MINOR *
|
|
31
|
+
// 100000 + HIP_VERSION_PATCH.
|
|
32
|
+
#define TRITON_HIP_DRIVER_EXTRACT_MAJOR_VERSION(version) ((version) / 10000000)
|
|
33
|
+
#define TRITON_HIP_DRIVER_EXTRACT_MINOR_VERSION(version) \
|
|
34
|
+
(((version) % 10000000) / 100000)
|
|
35
|
+
#define TRITON_HIP_DRIVER_EXTRACT_PATCH_VERSION(version) ((version) % 100000)
|
|
36
|
+
#define TRITON_HIP_DRIVER_REQ_MAJOR_VERSION (HIP_VERSION_MAJOR)
|
|
37
|
+
|
|
38
|
+
// #define TRITON_HIP_DRIVER_DBG_VERSION
|
|
39
|
+
#ifdef TRITON_HIP_DRIVER_DBG_VERSION
|
|
40
|
+
#define TRITON_HIP_DRIVER_LOG_VERSION(version, msgBuff) \
|
|
41
|
+
do { \
|
|
42
|
+
snprintf(msgBuff, sizeof(msgBuff), "libamdhip64 version is: %d.%d.%d", \
|
|
43
|
+
TRITON_HIP_DRIVER_EXTRACT_MAJOR_VERSION(version), \
|
|
44
|
+
TRITON_HIP_DRIVER_EXTRACT_MINOR_VERSION(version), \
|
|
45
|
+
TRITON_HIP_DRIVER_EXTRACT_PATCH_VERSION(version)); \
|
|
46
|
+
printf("%s\n", msgBuff); \
|
|
47
|
+
} while (0);
|
|
48
|
+
#else
|
|
49
|
+
#define TRITON_HIP_DRIVER_LOG_VERSION(version, msgBuff) \
|
|
50
|
+
do { \
|
|
51
|
+
(void)msgBuff; \
|
|
52
|
+
(void)(version); \
|
|
53
|
+
} while (0);
|
|
54
|
+
#endif
|
|
55
|
+
|
|
56
|
+
#define TRITON_HIP_MSG_BUFF_SIZE (1024U)
|
|
57
|
+
|
|
47
58
|
// The HIP symbol table for holding resolved dynamic library symbols.
|
|
48
59
|
struct HIPSymbolTable {
|
|
49
60
|
#define DEFINE_EACH_ERR_FIELD(hipSymbolName, ...) \
|
|
@@ -56,39 +67,96 @@ struct HIPSymbolTable {
|
|
|
56
67
|
|
|
57
68
|
static struct HIPSymbolTable hipSymbolTable;
|
|
58
69
|
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
70
|
+
static int checkDriverVersion(void *lib) {
|
|
71
|
+
int hipVersion = -1;
|
|
72
|
+
const char *error = NULL;
|
|
73
|
+
typedef hipError_t (*hipDriverGetVersion_fn)(int *driverVersion);
|
|
74
|
+
hipDriverGetVersion_fn hipDriverGetVersion;
|
|
75
|
+
dlerror(); // Clear existing errors
|
|
76
|
+
hipDriverGetVersion =
|
|
77
|
+
(hipDriverGetVersion_fn)dlsym(lib, "hipDriverGetVersion");
|
|
78
|
+
error = dlerror();
|
|
79
|
+
if (error) {
|
|
80
|
+
PyErr_SetString(PyExc_RuntimeError,
|
|
81
|
+
"cannot query 'hipDriverGetVersion' from libamdhip64.so");
|
|
82
|
+
dlclose(lib);
|
|
83
|
+
return -1;
|
|
64
84
|
}
|
|
65
85
|
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
86
|
+
(void)hipDriverGetVersion(&hipVersion);
|
|
87
|
+
char msgBuff[TRITON_HIP_MSG_BUFF_SIZE] = {0};
|
|
88
|
+
|
|
89
|
+
const int hipMajVersion = TRITON_HIP_DRIVER_EXTRACT_MAJOR_VERSION(hipVersion);
|
|
90
|
+
if (hipMajVersion < TRITON_HIP_DRIVER_REQ_MAJOR_VERSION) {
|
|
91
|
+
const int hipMinVersion =
|
|
92
|
+
TRITON_HIP_DRIVER_EXTRACT_MINOR_VERSION(hipVersion);
|
|
93
|
+
const int hipPatchVersion =
|
|
94
|
+
TRITON_HIP_DRIVER_EXTRACT_PATCH_VERSION(hipVersion);
|
|
95
|
+
snprintf(msgBuff, sizeof(msgBuff),
|
|
96
|
+
"libamdhip64 version %d.%d.%d is not supported! Required major "
|
|
97
|
+
"version is >=%d.",
|
|
98
|
+
hipMajVersion, hipMinVersion, hipPatchVersion,
|
|
99
|
+
TRITON_HIP_DRIVER_REQ_MAJOR_VERSION);
|
|
100
|
+
PyErr_SetString(PyExc_RuntimeError, msgBuff);
|
|
101
|
+
dlclose(lib);
|
|
102
|
+
return -1;
|
|
103
|
+
}
|
|
104
|
+
|
|
105
|
+
TRITON_HIP_DRIVER_LOG_VERSION(hipVersion, msgBuff);
|
|
106
|
+
|
|
107
|
+
return hipVersion;
|
|
108
|
+
}
|
|
109
|
+
|
|
110
|
+
bool initSymbolTable() {
|
|
111
|
+
void *lib;
|
|
112
|
+
|
|
113
|
+
// Go through the list of search paths to dlopen the first HIP driver library.
|
|
114
|
+
int n = sizeof(hipLibSearchPaths) / sizeof(hipLibSearchPaths[0]);
|
|
115
|
+
for (int i = 0; i < n; ++i) {
|
|
116
|
+
void *handle = dlopen(hipLibSearchPaths[i], RTLD_LAZY | RTLD_LOCAL);
|
|
117
|
+
if (handle) {
|
|
118
|
+
lib = handle;
|
|
119
|
+
// printf("[triton] chosen %s\n", hipLibSearchPaths[i]);
|
|
76
120
|
}
|
|
77
121
|
}
|
|
122
|
+
|
|
78
123
|
if (!lib) {
|
|
79
124
|
PyErr_SetString(PyExc_RuntimeError, "cannot open libamdhip64.so");
|
|
80
125
|
return false;
|
|
81
126
|
}
|
|
82
127
|
|
|
83
|
-
|
|
84
|
-
|
|
128
|
+
int hipVersion = checkDriverVersion(lib);
|
|
129
|
+
if (hipVersion == -1)
|
|
130
|
+
return false;
|
|
131
|
+
|
|
85
132
|
const char *error = NULL;
|
|
133
|
+
typedef hipError_t (*hipGetProcAddress_fn)(
|
|
134
|
+
const char *symbol, void **pfn, int hipVersion, uint64_t hipFlags,
|
|
135
|
+
hipDriverProcAddressQueryResult *symbolStatus);
|
|
136
|
+
hipGetProcAddress_fn hipGetProcAddress;
|
|
137
|
+
dlerror(); // Clear existing errors
|
|
138
|
+
|
|
139
|
+
*(void **)&hipGetProcAddress = dlsym(lib, "hipGetProcAddress");
|
|
140
|
+
error = dlerror();
|
|
141
|
+
if (error) {
|
|
142
|
+
PyErr_SetString(PyExc_RuntimeError,
|
|
143
|
+
"cannot query 'hipGetProcAddress' from libamdhip64.so");
|
|
144
|
+
dlclose(lib);
|
|
145
|
+
return false;
|
|
146
|
+
}
|
|
147
|
+
|
|
148
|
+
// Resolve all symbols we are interested in.
|
|
149
|
+
uint64_t hipFlags = 0;
|
|
150
|
+
hipDriverProcAddressQueryResult symbolStatus;
|
|
151
|
+
hipError_t status = hipSuccess;
|
|
86
152
|
#define QUERY_EACH_FN(hipSymbolName, ...) \
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
153
|
+
status = hipGetProcAddress(#hipSymbolName, \
|
|
154
|
+
(void **)&hipSymbolTable.hipSymbolName, \
|
|
155
|
+
hipVersion, hipFlags, &symbolStatus); \
|
|
156
|
+
if (status != hipSuccess) { \
|
|
90
157
|
PyErr_SetString(PyExc_RuntimeError, \
|
|
91
|
-
"cannot
|
|
158
|
+
"cannot get address for '" #hipSymbolName \
|
|
159
|
+
"' from libamdhip64.so"); \
|
|
92
160
|
dlclose(lib); \
|
|
93
161
|
return false; \
|
|
94
162
|
}
|
|
@@ -104,8 +172,9 @@ static inline void gpuAssert(hipError_t code, const char *file, int line) {
|
|
|
104
172
|
{
|
|
105
173
|
const char *prefix = "Triton Error [HIP]: ";
|
|
106
174
|
const char *str = hipSymbolTable.hipGetErrorString(code);
|
|
107
|
-
char err[
|
|
108
|
-
snprintf(err,
|
|
175
|
+
char err[TRITON_HIP_MSG_BUFF_SIZE] = {0};
|
|
176
|
+
snprintf(err, sizeof(err), "%s Code: %d, Messsage: %s", prefix, code,
|
|
177
|
+
str);
|
|
109
178
|
PyGILState_STATE gil_state;
|
|
110
179
|
gil_state = PyGILState_Ensure();
|
|
111
180
|
PyErr_SetString(PyExc_RuntimeError, err);
|
|
@@ -127,7 +196,7 @@ static PyObject *getDeviceProperties(PyObject *self, PyObject *args) {
|
|
|
127
196
|
if (!PyArg_ParseTuple(args, "i", &device_id))
|
|
128
197
|
return NULL;
|
|
129
198
|
|
|
130
|
-
|
|
199
|
+
hipDeviceProp_t props;
|
|
131
200
|
HIP_CHECK(hipSymbolTable.hipGetDeviceProperties(&props, device_id));
|
|
132
201
|
|
|
133
202
|
// create a struct to hold device properties
|
triton/backends/amd/driver.py
CHANGED
|
@@ -6,6 +6,7 @@ from pathlib import Path
|
|
|
6
6
|
from triton import knobs
|
|
7
7
|
from triton.backends.compiler import GPUTarget
|
|
8
8
|
from triton.backends.driver import GPUDriver
|
|
9
|
+
from triton.runtime import _allocation
|
|
9
10
|
from triton.runtime.build import compile_module_from_src
|
|
10
11
|
from triton.tools.tensor_descriptor import TensorDescriptor
|
|
11
12
|
|
|
@@ -109,8 +110,36 @@ def _get_path_to_hip_runtime_dylib():
|
|
|
109
110
|
return f
|
|
110
111
|
paths.append(f)
|
|
111
112
|
|
|
113
|
+
# HIP_PATH should point to HIP SDK root if set
|
|
114
|
+
env_hip_path = os.getenv("HIP_PATH")
|
|
115
|
+
if env_hip_path:
|
|
116
|
+
hip_lib_path = os.path.join(env_hip_path, "lib", lib_name)
|
|
117
|
+
if os.path.exists(hip_lib_path):
|
|
118
|
+
return hip_lib_path
|
|
119
|
+
paths.append(hip_lib_path)
|
|
120
|
+
|
|
121
|
+
# if available, `hipconfig --path` prints the HIP SDK root
|
|
122
|
+
try:
|
|
123
|
+
hip_root = subprocess.check_output(["hipconfig", "--path"]).decode().strip()
|
|
124
|
+
if hip_root:
|
|
125
|
+
hip_lib_path = os.path.join(hip_root, "lib", lib_name)
|
|
126
|
+
if os.path.exists(hip_lib_path):
|
|
127
|
+
return hip_lib_path
|
|
128
|
+
paths.append(hip_lib_path)
|
|
129
|
+
except (subprocess.CalledProcessError, FileNotFoundError):
|
|
130
|
+
# hipconfig may not be available
|
|
131
|
+
pass
|
|
132
|
+
|
|
133
|
+
# ROCm lib dir based on env var
|
|
134
|
+
env_rocm_path = os.getenv("ROCM_PATH")
|
|
135
|
+
if env_rocm_path:
|
|
136
|
+
rocm_lib_path = os.path.join(env_rocm_path, "lib", lib_name)
|
|
137
|
+
if os.path.exists(rocm_lib_path):
|
|
138
|
+
return rocm_lib_path
|
|
139
|
+
paths.append(rocm_lib_path)
|
|
140
|
+
|
|
112
141
|
# Afterwards try to search the loader dynamic library resolution paths.
|
|
113
|
-
libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode()
|
|
142
|
+
libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode(errors="ignore")
|
|
114
143
|
# each line looks like the following:
|
|
115
144
|
# libamdhip64.so.6 (libc6,x86-64) => /opt/rocm-6.0.2/lib/libamdhip64.so.6
|
|
116
145
|
# libamdhip64.so (libc6,x86-64) => /opt/rocm-6.0.2/lib/libamdhip64.so
|
|
@@ -153,12 +182,12 @@ def ty_to_cpp(ty):
|
|
|
153
182
|
if ty[0] == '*':
|
|
154
183
|
return "hipDeviceptr_t"
|
|
155
184
|
return {
|
|
156
|
-
"i1": "
|
|
185
|
+
"i1": "int8_t",
|
|
157
186
|
"i8": "int8_t",
|
|
158
187
|
"i16": "int16_t",
|
|
159
188
|
"i32": "int32_t",
|
|
160
189
|
"i64": "int64_t",
|
|
161
|
-
"u1": "
|
|
190
|
+
"u1": "uint8_t",
|
|
162
191
|
"u8": "uint8_t",
|
|
163
192
|
"u16": "uint16_t",
|
|
164
193
|
"u32": "uint32_t",
|
|
@@ -186,7 +215,7 @@ FLOAT_PACK_FUNCTION = {
|
|
|
186
215
|
"fp64": "pack_fp64",
|
|
187
216
|
}
|
|
188
217
|
|
|
189
|
-
_BASE_ARGS_FORMAT = "
|
|
218
|
+
_BASE_ARGS_FORMAT = "piiiKKOOOOO"
|
|
190
219
|
|
|
191
220
|
|
|
192
221
|
def make_launcher(constants, signature, warp_size):
|
|
@@ -203,6 +232,7 @@ def make_launcher(constants, signature, warp_size):
|
|
|
203
232
|
output.append("*" + dtype)
|
|
204
233
|
for _ in range(2 * ndim):
|
|
205
234
|
output.append("i64")
|
|
235
|
+
output.append("i1")
|
|
206
236
|
# Currently the host side tensor descriptors get passed in as a
|
|
207
237
|
# tensor desc, shape, and strides. We have no way to use these
|
|
208
238
|
# shape and strides when processing tensor descriptors which is
|
|
@@ -293,9 +323,11 @@ def make_launcher(constants, signature, warp_size):
|
|
|
293
323
|
params = list(range(len(signature)))
|
|
294
324
|
params = [f"&arg{i}" for i, ty in signature.items() if ty != "constexpr"]
|
|
295
325
|
params.append("&global_scratch")
|
|
326
|
+
params.append("&profile_scratch")
|
|
296
327
|
src = f"""
|
|
297
328
|
#define __HIP_PLATFORM_AMD__
|
|
298
329
|
#include <hip/hip_runtime.h>
|
|
330
|
+
#include <hip/hip_runtime_api.h>
|
|
299
331
|
#include <Python.h>
|
|
300
332
|
#include <dlfcn.h>
|
|
301
333
|
#include <stdbool.h>
|
|
@@ -308,6 +340,7 @@ static const char *hipLibSearchPaths[] = {{"{libhip_path}"}};
|
|
|
308
340
|
// The list of HIP dynamic library symbols and their signature we are interested
|
|
309
341
|
// in this file.
|
|
310
342
|
#define HIP_SYMBOL_LIST(FOR_EACH_ERR_FN, FOR_EACH_STR_FN) \\
|
|
343
|
+
FOR_EACH_STR_FN(hipGetLastError) \\
|
|
311
344
|
FOR_EACH_STR_FN(hipGetErrorString, hipError_t hipError) \\
|
|
312
345
|
FOR_EACH_ERR_FN(hipModuleLaunchKernel, hipFunction_t f, \\
|
|
313
346
|
unsigned int gridDimX, unsigned int gridDimY, \\
|
|
@@ -356,17 +389,36 @@ bool initSymbolTable() {{
|
|
|
356
389
|
return false;
|
|
357
390
|
}}
|
|
358
391
|
|
|
359
|
-
|
|
392
|
+
typedef hipError_t (*hipGetProcAddress_fn)(
|
|
393
|
+
const char *symbol, void **pfn, int hipVersion, uint64_t hipFlags,
|
|
394
|
+
hipDriverProcAddressQueryResult *symbolStatus);
|
|
395
|
+
hipGetProcAddress_fn hipGetProcAddress;
|
|
360
396
|
dlerror(); // Clear existing errors
|
|
361
397
|
const char *error = NULL;
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
398
|
+
*(void **)&hipGetProcAddress = dlsym(lib, "hipGetProcAddress");
|
|
399
|
+
error = dlerror();
|
|
400
|
+
if (error) {{
|
|
401
|
+
PyErr_SetString(PyExc_RuntimeError,
|
|
402
|
+
"cannot query 'hipGetProcAddress' from libamdhip64.so");
|
|
403
|
+
dlclose(lib);
|
|
404
|
+
return false;
|
|
405
|
+
}}
|
|
406
|
+
|
|
407
|
+
// Resolve all symbols we are interested in.
|
|
408
|
+
int hipVersion = HIP_VERSION;
|
|
409
|
+
uint64_t hipFlags = 0;
|
|
410
|
+
hipDriverProcAddressQueryResult symbolStatus;
|
|
411
|
+
hipError_t status = hipSuccess;
|
|
412
|
+
#define QUERY_EACH_FN(hipSymbolName, ...) \
|
|
413
|
+
status = hipGetProcAddress(#hipSymbolName, \
|
|
414
|
+
(void **)&hipSymbolTable.hipSymbolName, \
|
|
415
|
+
hipVersion, hipFlags, &symbolStatus); \
|
|
416
|
+
if (status != hipSuccess) {{ \
|
|
417
|
+
PyErr_SetString(PyExc_RuntimeError, \
|
|
418
|
+
"cannot get address for '" #hipSymbolName \
|
|
419
|
+
"' from libamdhip64.so"); \
|
|
420
|
+
dlclose(lib); \
|
|
421
|
+
return false; \
|
|
370
422
|
}}
|
|
371
423
|
|
|
372
424
|
HIP_SYMBOL_LIST(QUERY_EACH_FN, QUERY_EACH_FN)
|
|
@@ -388,7 +440,7 @@ static inline void gpuAssert(hipError_t code, const char *file, int line)
|
|
|
388
440
|
|
|
389
441
|
#define HIP_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}
|
|
390
442
|
|
|
391
|
-
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, hipStream_t stream, hipFunction_t function{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
|
|
443
|
+
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, hipStream_t stream, hipFunction_t function, hipDeviceptr_t profile_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
|
|
392
444
|
hipDeviceptr_t global_scratch = 0;
|
|
393
445
|
void *params[] = {{ {', '.join(params)} }};
|
|
394
446
|
if (gridX*gridY*gridZ > 0 && launch_cooperative_grid) {{
|
|
@@ -405,8 +457,11 @@ typedef struct _DevicePtrInfo {{
|
|
|
405
457
|
bool valid;
|
|
406
458
|
}} DevicePtrInfo;
|
|
407
459
|
|
|
460
|
+
static PyObject* data_ptr_str = NULL;
|
|
461
|
+
|
|
408
462
|
static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
|
|
409
463
|
DevicePtrInfo ptr_info;
|
|
464
|
+
hipError_t status = hipSuccess;
|
|
410
465
|
ptr_info.dev_ptr = 0;
|
|
411
466
|
ptr_info.valid = true;
|
|
412
467
|
if (PyLong_Check(obj)) {{
|
|
@@ -417,45 +472,42 @@ static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
|
|
|
417
472
|
// valid nullptr
|
|
418
473
|
return ptr_info;
|
|
419
474
|
}}
|
|
420
|
-
PyObject *
|
|
421
|
-
if(
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
475
|
+
PyObject *ret = PyObject_CallMethodNoArgs(obj, data_ptr_str);
|
|
476
|
+
if (!ret) {{
|
|
477
|
+
PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
|
|
478
|
+
ptr_info.valid = false;
|
|
479
|
+
goto cleanup;
|
|
480
|
+
}}
|
|
481
|
+
if (!PyLong_Check(ret)) {{
|
|
482
|
+
PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
|
|
483
|
+
ptr_info.valid = false;
|
|
484
|
+
goto cleanup;
|
|
485
|
+
}}
|
|
486
|
+
ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(ret);
|
|
487
|
+
if (!ptr_info.dev_ptr)
|
|
488
|
+
goto cleanup;
|
|
489
|
+
uint64_t dev_ptr;
|
|
490
|
+
status = hipSymbolTable.hipPointerGetAttribute(&dev_ptr, HIP_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
|
|
491
|
+
if (status == hipErrorInvalidValue) {{
|
|
492
|
+
PyErr_Format(PyExc_ValueError,
|
|
493
|
+
"Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
|
|
428
494
|
ptr_info.valid = false;
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
}}
|
|
432
|
-
ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(ret);
|
|
433
|
-
if(!ptr_info.dev_ptr) {{
|
|
434
|
-
Py_DECREF(ret);
|
|
435
|
-
return ptr_info;
|
|
436
|
-
}}
|
|
437
|
-
uint64_t dev_ptr;
|
|
438
|
-
hipError_t status = hipSymbolTable.hipPointerGetAttribute(&dev_ptr, HIP_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
|
|
439
|
-
if (status == hipErrorInvalidValue) {{
|
|
440
|
-
PyErr_Format(PyExc_ValueError,
|
|
441
|
-
"Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
|
|
442
|
-
ptr_info.valid = false;
|
|
443
|
-
}}
|
|
444
|
-
ptr_info.dev_ptr = (hipDeviceptr_t)dev_ptr;
|
|
445
|
-
Py_DECREF(ret);
|
|
446
|
-
return ptr_info;
|
|
495
|
+
// Clear and ignore HIP error
|
|
496
|
+
(void)hipSymbolTable.hipGetLastError();
|
|
447
497
|
}}
|
|
448
|
-
|
|
498
|
+
ptr_info.dev_ptr = (hipDeviceptr_t)dev_ptr;
|
|
499
|
+
cleanup:
|
|
500
|
+
Py_DECREF(ret);
|
|
449
501
|
return ptr_info;
|
|
450
502
|
}}
|
|
451
503
|
|
|
452
504
|
static uint16_t pack_fp16(double f) {{
|
|
453
505
|
uint16_t result;
|
|
454
|
-
// from https://github.com/python/pythoncapi-compat
|
|
506
|
+
// from https://github.com/python/pythoncapi-compat/blob/5e317108f872c904eb726cb8d560dcadbdf88a72/pythoncapi_compat.h#L482-L492
|
|
455
507
|
#if 0x030600B1 <= PY_VERSION_HEX && PY_VERSION_HEX <= 0x030B00A1 && !defined(PYPY_VERSION)
|
|
456
508
|
_PyFloat_Pack2(f, (unsigned char*)&result, 1);
|
|
457
509
|
#else
|
|
458
|
-
PyFloat_Pack2(f, (
|
|
510
|
+
PyFloat_Pack2(f, (char*)&result, 1);
|
|
459
511
|
#endif
|
|
460
512
|
return result;
|
|
461
513
|
}}
|
|
@@ -480,13 +532,14 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
|
|
|
480
532
|
uint64_t _stream;
|
|
481
533
|
uint64_t _function;
|
|
482
534
|
int launch_cooperative_grid;
|
|
535
|
+
PyObject *profile_scratch_obj = NULL;
|
|
483
536
|
PyObject *launch_enter_hook = NULL;
|
|
484
537
|
PyObject *launch_exit_hook = NULL;
|
|
485
538
|
PyObject *kernel_metadata = NULL;
|
|
486
539
|
PyObject *launch_metadata = NULL;
|
|
487
540
|
{' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
|
|
488
541
|
if(!PyArg_ParseTuple(args, \"{format}\", &launch_cooperative_grid,
|
|
489
|
-
&gridX, &gridY, &gridZ, &_stream, &_function,
|
|
542
|
+
&gridX, &gridY, &gridZ, &_stream, &_function, &profile_scratch_obj,
|
|
490
543
|
&kernel_metadata, &launch_metadata,
|
|
491
544
|
&launch_enter_hook, &launch_exit_hook {args_list})) {{
|
|
492
545
|
return NULL;
|
|
@@ -501,23 +554,27 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
|
|
|
501
554
|
}}
|
|
502
555
|
// extract launch metadata
|
|
503
556
|
if (launch_enter_hook != Py_None){{
|
|
504
|
-
PyObject*
|
|
505
|
-
PyObject* ret = PyObject_CallObject(launch_enter_hook, args);
|
|
506
|
-
Py_DECREF(args);
|
|
557
|
+
PyObject* ret = PyObject_CallOneArg(launch_enter_hook, launch_metadata);
|
|
507
558
|
if (!ret)
|
|
508
559
|
return NULL;
|
|
509
560
|
Py_DECREF(ret);
|
|
510
561
|
}}
|
|
511
562
|
|
|
563
|
+
hipDeviceptr_t profile_scratch = 0;
|
|
564
|
+
if (profile_scratch_obj != Py_None) {{
|
|
565
|
+
DevicePtrInfo profile_scratch_info = getPointer(profile_scratch_obj, -1);
|
|
566
|
+
if (!profile_scratch_info.valid) {{
|
|
567
|
+
return NULL;
|
|
568
|
+
}}
|
|
569
|
+
profile_scratch = profile_scratch_info.dev_ptr;
|
|
570
|
+
}}
|
|
512
571
|
|
|
513
572
|
// raise exception asap
|
|
514
573
|
{"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])};
|
|
515
|
-
_launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
|
|
574
|
+
_launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function, (hipDeviceptr_t)profile_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
|
|
516
575
|
|
|
517
576
|
if(launch_exit_hook != Py_None){{
|
|
518
|
-
PyObject*
|
|
519
|
-
PyObject* ret = PyObject_CallObject(launch_exit_hook, args);
|
|
520
|
-
Py_DECREF(args);
|
|
577
|
+
PyObject* ret = PyObject_CallOneArg(launch_exit_hook, launch_metadata);
|
|
521
578
|
if (!ret)
|
|
522
579
|
return NULL;
|
|
523
580
|
Py_DECREF(ret);
|
|
@@ -526,9 +583,7 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
|
|
|
526
583
|
if(PyErr_Occurred()) {{
|
|
527
584
|
return NULL;
|
|
528
585
|
}}
|
|
529
|
-
|
|
530
|
-
Py_INCREF(Py_None);
|
|
531
|
-
return Py_None;
|
|
586
|
+
Py_RETURN_NONE;
|
|
532
587
|
}}
|
|
533
588
|
|
|
534
589
|
static PyMethodDef ModuleMethods[] = {{
|
|
@@ -552,6 +607,10 @@ PyMODINIT_FUNC PyInit___triton_launcher(void) {{
|
|
|
552
607
|
if(m == NULL) {{
|
|
553
608
|
return NULL;
|
|
554
609
|
}}
|
|
610
|
+
data_ptr_str = PyUnicode_InternFromString("data_ptr");
|
|
611
|
+
if(data_ptr_str == NULL) {{
|
|
612
|
+
return NULL;
|
|
613
|
+
}}
|
|
555
614
|
PyModule_AddFunctions(m, ModuleMethods);
|
|
556
615
|
return m;
|
|
557
616
|
}}
|
|
@@ -576,7 +635,7 @@ def wrap_handle_tensor_descriptor(launcher):
|
|
|
576
635
|
# descriptors which is why we provide our own decomposition
|
|
577
636
|
# above. Sadly this means we have to pass the shape and strides
|
|
578
637
|
# twice.
|
|
579
|
-
final_args.extend([arg.base, *arg.shape, *arg.strides, *arg.shape, *arg.strides])
|
|
638
|
+
final_args.extend([arg.base, *arg.shape, *arg.strides, arg.padding == "nan", *arg.shape, *arg.strides])
|
|
580
639
|
else:
|
|
581
640
|
final_args.append(arg)
|
|
582
641
|
return launcher(*meta_args, *final_args)
|
|
@@ -597,9 +656,23 @@ class HIPLauncher(object):
|
|
|
597
656
|
|
|
598
657
|
self.launch = wrap_handle_tensor_descriptor(mod.launch) if has_tensor_desc_arg else mod.launch
|
|
599
658
|
self.launch_cooperative_grid = metadata.launch_cooperative_grid
|
|
659
|
+
self.profile_scratch_size = metadata.profile_scratch_size
|
|
660
|
+
self.profile_scratch_align = metadata.profile_scratch_align
|
|
661
|
+
|
|
662
|
+
def __call__(self, gridX, gridY, gridZ, stream, function, *args):
|
|
663
|
+
|
|
664
|
+
def allocate_scratch(size, align, allocator):
|
|
665
|
+
if size > 0:
|
|
666
|
+
grid_size = gridX * gridY * gridZ
|
|
667
|
+
alloc_size = grid_size * size
|
|
668
|
+
alloc_fn = allocator.get()
|
|
669
|
+
return alloc_fn(alloc_size, align, stream)
|
|
670
|
+
return None
|
|
600
671
|
|
|
601
|
-
|
|
602
|
-
|
|
672
|
+
profile_scratch = allocate_scratch(self.profile_scratch_size, self.profile_scratch_align,
|
|
673
|
+
_allocation._profile_allocator)
|
|
674
|
+
|
|
675
|
+
self.launch(self.launch_cooperative_grid, gridX, gridY, gridZ, stream, function, profile_scratch, *args)
|
|
603
676
|
|
|
604
677
|
|
|
605
678
|
class HIPDriver(GPUDriver):
|
|
@@ -621,6 +694,9 @@ class HIPDriver(GPUDriver):
|
|
|
621
694
|
except ImportError:
|
|
622
695
|
return False
|
|
623
696
|
|
|
697
|
+
def map_python_to_cpp_type(self, ty: str) -> str:
|
|
698
|
+
return ty_to_cpp(ty)
|
|
699
|
+
|
|
624
700
|
def get_current_target(self):
|
|
625
701
|
device = self.get_current_device()
|
|
626
702
|
device_properties = self.utils.get_device_properties(device)
|
triton/backends/driver.py
CHANGED
|
@@ -15,6 +15,19 @@ class DriverBase(metaclass=ABCMeta):
|
|
|
15
15
|
def is_active(self):
|
|
16
16
|
pass
|
|
17
17
|
|
|
18
|
+
@abstractmethod
|
|
19
|
+
def map_python_to_cpp_type(self, ty: str) -> str:
|
|
20
|
+
"""
|
|
21
|
+
Converts a Triton type string to its corresponding C++ type string for this backend.
|
|
22
|
+
|
|
23
|
+
Args:
|
|
24
|
+
ty (str): The Triton type string. e.g., 'i32', '*fp16', 'fp32'.
|
|
25
|
+
|
|
26
|
+
Returns:
|
|
27
|
+
str: The C++ type string.
|
|
28
|
+
"""
|
|
29
|
+
pass
|
|
30
|
+
|
|
18
31
|
@abstractmethod
|
|
19
32
|
def get_current_target(self):
|
|
20
33
|
pass
|