triton-windows 3.5.0.post21__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of triton-windows might be problematic. Click here for more details.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +82 -0
- triton/_filecheck.py +97 -0
- triton/_internal_testing.py +255 -0
- triton/_utils.py +126 -0
- triton/backends/__init__.py +47 -0
- triton/backends/amd/__init__.py +0 -0
- triton/backends/amd/compiler.py +461 -0
- triton/backends/amd/driver.c +283 -0
- triton/backends/amd/driver.py +724 -0
- triton/backends/amd/lib/asanrtl.bc +0 -0
- triton/backends/amd/lib/ockl.bc +0 -0
- triton/backends/amd/lib/ocml.bc +0 -0
- triton/backends/compiler.py +90 -0
- triton/backends/driver.py +66 -0
- triton/backends/nvidia/__init__.py +0 -0
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +533 -0
- triton/backends/nvidia/driver.c +517 -0
- triton/backends/nvidia/driver.py +799 -0
- triton/backends/nvidia/include/cuda.h +26280 -0
- triton/backends/nvidia/lib/libdevice.10.bc +0 -0
- triton/backends/nvidia/lib/x64/cuda.lib +0 -0
- triton/compiler/__init__.py +7 -0
- triton/compiler/code_generator.py +1614 -0
- triton/compiler/compiler.py +509 -0
- triton/compiler/errors.py +51 -0
- triton/compiler/make_launcher.py +0 -0
- triton/errors.py +5 -0
- triton/experimental/__init__.py +0 -0
- triton/experimental/gluon/__init__.py +5 -0
- triton/experimental/gluon/_compiler.py +0 -0
- triton/experimental/gluon/_runtime.py +102 -0
- triton/experimental/gluon/language/__init__.py +119 -0
- triton/experimental/gluon/language/_core.py +490 -0
- triton/experimental/gluon/language/_layouts.py +583 -0
- triton/experimental/gluon/language/_math.py +20 -0
- triton/experimental/gluon/language/_semantic.py +380 -0
- triton/experimental/gluon/language/_standard.py +80 -0
- triton/experimental/gluon/language/amd/__init__.py +4 -0
- triton/experimental/gluon/language/amd/_layouts.py +96 -0
- triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
- triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
- triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
- triton/experimental/gluon/language/extra/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/__init__.py +4 -0
- triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
- triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
- triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
- triton/experimental/gluon/nvidia/__init__.py +4 -0
- triton/experimental/gluon/nvidia/blackwell.py +3 -0
- triton/experimental/gluon/nvidia/hopper.py +45 -0
- triton/knobs.py +546 -0
- triton/language/__init__.py +342 -0
- triton/language/core.py +3405 -0
- triton/language/extra/__init__.py +26 -0
- triton/language/extra/cuda/__init__.py +16 -0
- triton/language/extra/cuda/gdc.py +42 -0
- triton/language/extra/cuda/libdevice.py +1629 -0
- triton/language/extra/cuda/utils.py +109 -0
- triton/language/extra/hip/__init__.py +5 -0
- triton/language/extra/hip/libdevice.py +491 -0
- triton/language/extra/hip/utils.py +35 -0
- triton/language/extra/libdevice.py +790 -0
- triton/language/math.py +249 -0
- triton/language/random.py +218 -0
- triton/language/semantic.py +1939 -0
- triton/language/standard.py +534 -0
- triton/language/target_info.py +54 -0
- triton/runtime/__init__.py +23 -0
- triton/runtime/_allocation.py +44 -0
- triton/runtime/_async_compile.py +55 -0
- triton/runtime/autotuner.py +476 -0
- triton/runtime/build.py +168 -0
- triton/runtime/cache.py +317 -0
- triton/runtime/driver.py +38 -0
- triton/runtime/errors.py +36 -0
- triton/runtime/interpreter.py +1414 -0
- triton/runtime/jit.py +1107 -0
- triton/runtime/tcc/include/_mingw.h +168 -0
- triton/runtime/tcc/include/assert.h +62 -0
- triton/runtime/tcc/include/conio.h +409 -0
- triton/runtime/tcc/include/ctype.h +281 -0
- triton/runtime/tcc/include/dir.h +31 -0
- triton/runtime/tcc/include/direct.h +68 -0
- triton/runtime/tcc/include/dirent.h +135 -0
- triton/runtime/tcc/include/dos.h +55 -0
- triton/runtime/tcc/include/errno.h +75 -0
- triton/runtime/tcc/include/excpt.h +123 -0
- triton/runtime/tcc/include/fcntl.h +52 -0
- triton/runtime/tcc/include/fenv.h +108 -0
- triton/runtime/tcc/include/float.h +75 -0
- triton/runtime/tcc/include/inttypes.h +297 -0
- triton/runtime/tcc/include/io.h +418 -0
- triton/runtime/tcc/include/iso646.h +36 -0
- triton/runtime/tcc/include/limits.h +116 -0
- triton/runtime/tcc/include/locale.h +91 -0
- triton/runtime/tcc/include/malloc.h +181 -0
- triton/runtime/tcc/include/math.h +497 -0
- triton/runtime/tcc/include/mem.h +13 -0
- triton/runtime/tcc/include/memory.h +40 -0
- triton/runtime/tcc/include/process.h +176 -0
- triton/runtime/tcc/include/sec_api/conio_s.h +42 -0
- triton/runtime/tcc/include/sec_api/crtdbg_s.h +19 -0
- triton/runtime/tcc/include/sec_api/io_s.h +33 -0
- triton/runtime/tcc/include/sec_api/mbstring_s.h +52 -0
- triton/runtime/tcc/include/sec_api/search_s.h +25 -0
- triton/runtime/tcc/include/sec_api/stdio_s.h +145 -0
- triton/runtime/tcc/include/sec_api/stdlib_s.h +67 -0
- triton/runtime/tcc/include/sec_api/stralign_s.h +30 -0
- triton/runtime/tcc/include/sec_api/string_s.h +41 -0
- triton/runtime/tcc/include/sec_api/sys/timeb_s.h +34 -0
- triton/runtime/tcc/include/sec_api/tchar_s.h +266 -0
- triton/runtime/tcc/include/sec_api/time_s.h +61 -0
- triton/runtime/tcc/include/sec_api/wchar_s.h +128 -0
- triton/runtime/tcc/include/setjmp.h +160 -0
- triton/runtime/tcc/include/share.h +28 -0
- triton/runtime/tcc/include/signal.h +63 -0
- triton/runtime/tcc/include/stdalign.h +16 -0
- triton/runtime/tcc/include/stdarg.h +14 -0
- triton/runtime/tcc/include/stdatomic.h +171 -0
- triton/runtime/tcc/include/stdbool.h +11 -0
- triton/runtime/tcc/include/stddef.h +42 -0
- triton/runtime/tcc/include/stdint.h +212 -0
- triton/runtime/tcc/include/stdio.h +429 -0
- triton/runtime/tcc/include/stdlib.h +591 -0
- triton/runtime/tcc/include/stdnoreturn.h +7 -0
- triton/runtime/tcc/include/string.h +164 -0
- triton/runtime/tcc/include/sys/fcntl.h +13 -0
- triton/runtime/tcc/include/sys/file.h +14 -0
- triton/runtime/tcc/include/sys/locking.h +30 -0
- triton/runtime/tcc/include/sys/stat.h +290 -0
- triton/runtime/tcc/include/sys/time.h +69 -0
- triton/runtime/tcc/include/sys/timeb.h +133 -0
- triton/runtime/tcc/include/sys/types.h +123 -0
- triton/runtime/tcc/include/sys/unistd.h +14 -0
- triton/runtime/tcc/include/sys/utime.h +146 -0
- triton/runtime/tcc/include/tcc/tcc_libm.h +618 -0
- triton/runtime/tcc/include/tccdefs.h +342 -0
- triton/runtime/tcc/include/tcclib.h +80 -0
- triton/runtime/tcc/include/tchar.h +1102 -0
- triton/runtime/tcc/include/tgmath.h +89 -0
- triton/runtime/tcc/include/time.h +287 -0
- triton/runtime/tcc/include/uchar.h +33 -0
- triton/runtime/tcc/include/unistd.h +1 -0
- triton/runtime/tcc/include/vadefs.h +11 -0
- triton/runtime/tcc/include/values.h +4 -0
- triton/runtime/tcc/include/varargs.h +12 -0
- triton/runtime/tcc/include/wchar.h +873 -0
- triton/runtime/tcc/include/wctype.h +172 -0
- triton/runtime/tcc/include/winapi/basetsd.h +149 -0
- triton/runtime/tcc/include/winapi/basetyps.h +85 -0
- triton/runtime/tcc/include/winapi/guiddef.h +156 -0
- triton/runtime/tcc/include/winapi/poppack.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack1.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack2.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack4.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack8.h +8 -0
- triton/runtime/tcc/include/winapi/qos.h +72 -0
- triton/runtime/tcc/include/winapi/shellapi.h +59 -0
- triton/runtime/tcc/include/winapi/winbase.h +2958 -0
- triton/runtime/tcc/include/winapi/wincon.h +309 -0
- triton/runtime/tcc/include/winapi/windef.h +293 -0
- triton/runtime/tcc/include/winapi/windows.h +127 -0
- triton/runtime/tcc/include/winapi/winerror.h +3166 -0
- triton/runtime/tcc/include/winapi/wingdi.h +4080 -0
- triton/runtime/tcc/include/winapi/winnls.h +778 -0
- triton/runtime/tcc/include/winapi/winnt.h +5837 -0
- triton/runtime/tcc/include/winapi/winreg.h +272 -0
- triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
- triton/runtime/tcc/include/winapi/winuser.h +5651 -0
- triton/runtime/tcc/include/winapi/winver.h +160 -0
- triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
- triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
- triton/runtime/tcc/lib/cuda.def +697 -0
- triton/runtime/tcc/lib/gdi32.def +337 -0
- triton/runtime/tcc/lib/kernel32.def +770 -0
- triton/runtime/tcc/lib/libtcc1.a +0 -0
- triton/runtime/tcc/lib/msvcrt.def +1399 -0
- triton/runtime/tcc/lib/python3.def +810 -0
- triton/runtime/tcc/lib/python310.def +1610 -0
- triton/runtime/tcc/lib/python311.def +1633 -0
- triton/runtime/tcc/lib/python312.def +1703 -0
- triton/runtime/tcc/lib/python313.def +1651 -0
- triton/runtime/tcc/lib/python313t.def +1656 -0
- triton/runtime/tcc/lib/python314.def +1800 -0
- triton/runtime/tcc/lib/python314t.def +1809 -0
- triton/runtime/tcc/lib/python39.def +1644 -0
- triton/runtime/tcc/lib/python3t.def +905 -0
- triton/runtime/tcc/lib/user32.def +658 -0
- triton/runtime/tcc/libtcc.dll +0 -0
- triton/runtime/tcc/tcc.exe +0 -0
- triton/testing.py +543 -0
- triton/tools/__init__.py +0 -0
- triton/tools/build_extern.py +365 -0
- triton/tools/compile.py +210 -0
- triton/tools/disasm.py +143 -0
- triton/tools/extra/cuda/compile.c +70 -0
- triton/tools/extra/cuda/compile.h +14 -0
- triton/tools/extra/hip/compile.cpp +66 -0
- triton/tools/extra/hip/compile.h +13 -0
- triton/tools/link.py +322 -0
- triton/tools/mxfp.py +301 -0
- triton/tools/ragged_tma.py +92 -0
- triton/tools/tensor_descriptor.py +34 -0
- triton/windows_utils.py +405 -0
- triton_windows-3.5.0.post21.dist-info/METADATA +46 -0
- triton_windows-3.5.0.post21.dist-info/RECORD +217 -0
- triton_windows-3.5.0.post21.dist-info/WHEEL +5 -0
- triton_windows-3.5.0.post21.dist-info/entry_points.txt +3 -0
- triton_windows-3.5.0.post21.dist-info/licenses/LICENSE +23 -0
- triton_windows-3.5.0.post21.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,283 @@
|
|
|
1
|
+
#define __HIP_PLATFORM_AMD__
|
|
2
|
+
#include <hip/hip_runtime.h>
|
|
3
|
+
#include <hip/hip_runtime_api.h>
|
|
4
|
+
#define PY_SSIZE_T_CLEAN
|
|
5
|
+
#include <Python.h>
|
|
6
|
+
#include <dlfcn.h>
|
|
7
|
+
#include <stdbool.h>
|
|
8
|
+
#include <stdio.h>
|
|
9
|
+
#include <stdlib.h>
|
|
10
|
+
|
|
11
|
+
// Search paths for the HIP runtime library, walked by initSymbolTable().
// The string below is a placeholder: the calling Python code is expected to
// substitute the real search path for it.
static const char *hipLibSearchPaths[] = {
    "/*py_libhip_search_path*/",
};
|
|
14
|
+
|
|
15
|
+
// X-macro listing the HIP dynamic library symbols this file uses, each with
// its parameter list. The two arguments are macros applied per entry:
//   FOR_EACH_ERR_FN - invoked for APIs that return hipError_t;
//   FOR_EACH_STR_FN - invoked for APIs that return const char *.
#define HIP_SYMBOL_LIST(FOR_EACH_ERR_FN, FOR_EACH_STR_FN)                      \
  FOR_EACH_STR_FN(hipGetErrorString, hipError_t hipError)                      \
  FOR_EACH_ERR_FN(hipGetDeviceProperties, hipDeviceProp_t *prop, int deviceId) \
  FOR_EACH_ERR_FN(hipModuleLoadDataEx, hipModule_t *module,                    \
                  const void *image, unsigned int numOptions,                  \
                  hipJitOption *options, void **optionValues)                  \
  FOR_EACH_ERR_FN(hipModuleGetFunction, hipFunction_t *function,               \
                  hipModule_t module, const char *kname)                       \
  FOR_EACH_ERR_FN(hipFuncGetAttribute, int *value,                             \
                  hipFunction_attribute attr, hipFunction_t function)
|
|
29
|
+
|
|
30
|
+
// A HIP driver version is packed into one integer as
//   HIP_VERSION_MAJOR * 10000000 + HIP_VERSION_MINOR * 100000 +
//   HIP_VERSION_PATCH
// The three macros below unpack the individual components again.
#define TRITON_HIP_DRIVER_EXTRACT_MAJOR_VERSION(version)                       \
  ((version) / 10000000)
#define TRITON_HIP_DRIVER_EXTRACT_MINOR_VERSION(version)                       \
  (((version) % 10000000) / 100000)
#define TRITON_HIP_DRIVER_EXTRACT_PATCH_VERSION(version)                       \
  ((version) % 100000)

// Minimum accepted major version: the HIP major version of the headers this
// file is compiled against.
#define TRITON_HIP_DRIVER_REQ_MAJOR_VERSION (HIP_VERSION_MAJOR)
|
|
37
|
+
|
|
38
|
+
// Uncomment the next line to print the detected libamdhip64 version to stdout.
// #define TRITON_HIP_DRIVER_DBG_VERSION

// TRITON_HIP_DRIVER_LOG_VERSION(version, msgBuff)
//   version: packed HIP version integer.
//   msgBuff: a char ARRAY (not a pointer) used as scratch space; the debug
//            variant applies sizeof to it.
// Both variants are wrapped in do { ... } while (0) WITHOUT a trailing
// semicolon, so the macro behaves like a single statement and can be used as
// the body of an if/else; the caller supplies the semicolon.
#ifdef TRITON_HIP_DRIVER_DBG_VERSION
#define TRITON_HIP_DRIVER_LOG_VERSION(version, msgBuff)                        \
  do {                                                                         \
    snprintf(msgBuff, sizeof(msgBuff), "libamdhip64 version is: %d.%d.%d",     \
             TRITON_HIP_DRIVER_EXTRACT_MAJOR_VERSION(version),                 \
             TRITON_HIP_DRIVER_EXTRACT_MINOR_VERSION(version),                 \
             TRITON_HIP_DRIVER_EXTRACT_PATCH_VERSION(version));                \
    printf("%s\n", msgBuff);                                                   \
  } while (0)
#else
// Non-debug variant: only marks its arguments as used.
#define TRITON_HIP_DRIVER_LOG_VERSION(version, msgBuff)                        \
  do {                                                                         \
    (void)msgBuff;                                                             \
    (void)(version);                                                           \
  } while (0)
#endif

// Size in bytes of the stack buffers used to format messages in this file.
#define TRITON_HIP_MSG_BUFF_SIZE (1024U)
|
|
57
|
+
|
|
58
|
+
// The HIP symbol table for holding resolved dynamic library symbols.
|
|
59
|
+
struct HIPSymbolTable {
|
|
60
|
+
#define DEFINE_EACH_ERR_FIELD(hipSymbolName, ...) \
|
|
61
|
+
hipError_t (*hipSymbolName)(__VA_ARGS__);
|
|
62
|
+
#define DEFINE_EACH_STR_FIELD(hipSymbolName, ...) \
|
|
63
|
+
const char *(*hipSymbolName)(__VA_ARGS__);
|
|
64
|
+
|
|
65
|
+
HIP_SYMBOL_LIST(DEFINE_EACH_ERR_FIELD, DEFINE_EACH_STR_FIELD)
|
|
66
|
+
};
|
|
67
|
+
|
|
68
|
+
static struct HIPSymbolTable hipSymbolTable;
|
|
69
|
+
|
|
70
|
+
static int checkDriverVersion(void *lib) {
|
|
71
|
+
int hipVersion = -1;
|
|
72
|
+
const char *error = NULL;
|
|
73
|
+
typedef hipError_t (*hipDriverGetVersion_fn)(int *driverVersion);
|
|
74
|
+
hipDriverGetVersion_fn hipDriverGetVersion;
|
|
75
|
+
dlerror(); // Clear existing errors
|
|
76
|
+
hipDriverGetVersion =
|
|
77
|
+
(hipDriverGetVersion_fn)dlsym(lib, "hipDriverGetVersion");
|
|
78
|
+
error = dlerror();
|
|
79
|
+
if (error) {
|
|
80
|
+
PyErr_SetString(PyExc_RuntimeError,
|
|
81
|
+
"cannot query 'hipDriverGetVersion' from libamdhip64.so");
|
|
82
|
+
dlclose(lib);
|
|
83
|
+
return -1;
|
|
84
|
+
}
|
|
85
|
+
|
|
86
|
+
(void)hipDriverGetVersion(&hipVersion);
|
|
87
|
+
char msgBuff[TRITON_HIP_MSG_BUFF_SIZE] = {0};
|
|
88
|
+
|
|
89
|
+
const int hipMajVersion = TRITON_HIP_DRIVER_EXTRACT_MAJOR_VERSION(hipVersion);
|
|
90
|
+
if (hipMajVersion < TRITON_HIP_DRIVER_REQ_MAJOR_VERSION) {
|
|
91
|
+
const int hipMinVersion =
|
|
92
|
+
TRITON_HIP_DRIVER_EXTRACT_MINOR_VERSION(hipVersion);
|
|
93
|
+
const int hipPatchVersion =
|
|
94
|
+
TRITON_HIP_DRIVER_EXTRACT_PATCH_VERSION(hipVersion);
|
|
95
|
+
snprintf(msgBuff, sizeof(msgBuff),
|
|
96
|
+
"libamdhip64 version %d.%d.%d is not supported! Required major "
|
|
97
|
+
"version is >=%d.",
|
|
98
|
+
hipMajVersion, hipMinVersion, hipPatchVersion,
|
|
99
|
+
TRITON_HIP_DRIVER_REQ_MAJOR_VERSION);
|
|
100
|
+
PyErr_SetString(PyExc_RuntimeError, msgBuff);
|
|
101
|
+
dlclose(lib);
|
|
102
|
+
return -1;
|
|
103
|
+
}
|
|
104
|
+
|
|
105
|
+
TRITON_HIP_DRIVER_LOG_VERSION(hipVersion, msgBuff);
|
|
106
|
+
|
|
107
|
+
return hipVersion;
|
|
108
|
+
}
|
|
109
|
+
|
|
110
|
+
bool initSymbolTable() {
|
|
111
|
+
void *lib;
|
|
112
|
+
|
|
113
|
+
// Go through the list of search paths to dlopen the first HIP driver library.
|
|
114
|
+
int n = sizeof(hipLibSearchPaths) / sizeof(hipLibSearchPaths[0]);
|
|
115
|
+
for (int i = 0; i < n; ++i) {
|
|
116
|
+
void *handle = dlopen(hipLibSearchPaths[i], RTLD_LAZY | RTLD_LOCAL);
|
|
117
|
+
if (handle) {
|
|
118
|
+
lib = handle;
|
|
119
|
+
// printf("[triton] chosen %s\n", hipLibSearchPaths[i]);
|
|
120
|
+
}
|
|
121
|
+
}
|
|
122
|
+
|
|
123
|
+
if (!lib) {
|
|
124
|
+
PyErr_SetString(PyExc_RuntimeError, "cannot open libamdhip64.so");
|
|
125
|
+
return false;
|
|
126
|
+
}
|
|
127
|
+
|
|
128
|
+
int hipVersion = checkDriverVersion(lib);
|
|
129
|
+
if (hipVersion == -1)
|
|
130
|
+
return false;
|
|
131
|
+
|
|
132
|
+
const char *error = NULL;
|
|
133
|
+
typedef hipError_t (*hipGetProcAddress_fn)(
|
|
134
|
+
const char *symbol, void **pfn, int hipVersion, uint64_t hipFlags,
|
|
135
|
+
hipDriverProcAddressQueryResult *symbolStatus);
|
|
136
|
+
hipGetProcAddress_fn hipGetProcAddress;
|
|
137
|
+
dlerror(); // Clear existing errors
|
|
138
|
+
|
|
139
|
+
*(void **)&hipGetProcAddress = dlsym(lib, "hipGetProcAddress");
|
|
140
|
+
error = dlerror();
|
|
141
|
+
if (error) {
|
|
142
|
+
PyErr_SetString(PyExc_RuntimeError,
|
|
143
|
+
"cannot query 'hipGetProcAddress' from libamdhip64.so");
|
|
144
|
+
dlclose(lib);
|
|
145
|
+
return false;
|
|
146
|
+
}
|
|
147
|
+
|
|
148
|
+
// Resolve all symbols we are interested in.
|
|
149
|
+
uint64_t hipFlags = 0;
|
|
150
|
+
hipDriverProcAddressQueryResult symbolStatus;
|
|
151
|
+
hipError_t status = hipSuccess;
|
|
152
|
+
#define QUERY_EACH_FN(hipSymbolName, ...) \
|
|
153
|
+
status = hipGetProcAddress(#hipSymbolName, \
|
|
154
|
+
(void **)&hipSymbolTable.hipSymbolName, \
|
|
155
|
+
hipVersion, hipFlags, &symbolStatus); \
|
|
156
|
+
if (status != hipSuccess) { \
|
|
157
|
+
PyErr_SetString(PyExc_RuntimeError, \
|
|
158
|
+
"cannot get address for '" #hipSymbolName \
|
|
159
|
+
"' from libamdhip64.so"); \
|
|
160
|
+
dlclose(lib); \
|
|
161
|
+
return false; \
|
|
162
|
+
}
|
|
163
|
+
|
|
164
|
+
HIP_SYMBOL_LIST(QUERY_EACH_FN, QUERY_EACH_FN)
|
|
165
|
+
|
|
166
|
+
return true;
|
|
167
|
+
}
|
|
168
|
+
|
|
169
|
+
// Turns a failing HIP status code into a pending Python RuntimeError.
//
// code: status returned by a HIP call; HIP_SUCCESS makes this a no-op.
// file, line: call-site location supplied by HIP_CHECK; not used in the
//             message built here.
//
// The error text comes from hipSymbolTable.hipGetErrorString, and the GIL is
// acquired only around the PyErr_SetString call.
static inline void gpuAssert(hipError_t code, const char *file, int line) {
  if (code == HIP_SUCCESS)
    return;

  const char *prefix = "Triton Error [HIP]: ";
  const char *description = hipSymbolTable.hipGetErrorString(code);

  char message[TRITON_HIP_MSG_BUFF_SIZE] = {0};
  snprintf(message, sizeof(message), "%s Code: %d, Messsage: %s", prefix, code,
           description);

  PyGILState_STATE gil = PyGILState_Ensure();
  PyErr_SetString(PyExc_RuntimeError, message);
  PyGILState_Release(gil);
}
|
|
186
|
+
|
|
187
|
+
// Passes the result of a HIP call to gpuAssert() and, if a Python error is
// pending afterwards, makes the enclosing function return NULL. Only usable
// inside functions whose return type accepts NULL.
#define HIP_CHECK(call)                                                        \
  {                                                                            \
    gpuAssert((call), __FILE__, __LINE__);                                     \
    if (PyErr_Occurred())                                                      \
      return NULL;                                                             \
  }
|
|
193
|
+
|
|
194
|
+
// Python entry point: get_device_properties(device_id) -> dict.
//
// args: a tuple holding one int, the HIP device index.
//
// Returns a new dict with the keys max_shared_mem, max_num_regs,
// multiprocessor_count, sm_clock_rate, mem_clock_rate, mem_bus_width, arch,
// warpSize and max_threads_per_sm, or NULL with a Python exception set when
// argument parsing or the HIP query fails.
static PyObject *getDeviceProperties(PyObject *self, PyObject *args) {
  int device_id;
  if (!PyArg_ParseTuple(args, "i", &device_id))
    return NULL;

  hipDeviceProp_t props;
  HIP_CHECK(hipSymbolTable.hipGetDeviceProperties(&props, device_id));

  // sharedMemPerBlock is a size_t, so it is passed with the Py_ssize_t format
  // unit "n" and an explicit cast; handing a size_t to "i" (int) is a varargs
  // type mismatch. All other numeric fields are plain ints.
  return Py_BuildValue(
      "{s:n, s:i, s:i, s:i, s:i, s:i, s:s, s:i, s:i}",
      "max_shared_mem", (Py_ssize_t)props.sharedMemPerBlock,
      "max_num_regs", props.regsPerBlock,
      "multiprocessor_count", props.multiProcessorCount,
      "sm_clock_rate", props.clockRate,
      "mem_clock_rate", props.memoryClockRate,
      "mem_bus_width", props.memoryBusWidth,
      "arch", props.gcnArchName,
      "warpSize", props.warpSize,
      "max_threads_per_sm", props.maxThreadsPerMultiProcessor);
}
|
|
211
|
+
|
|
212
|
+
// Python entry point: load_binary(name, data, shared, device) -> tuple.
//
// args: (str kernel name, binary image, int shared, int device). Only the
//       name and the image are used below; the remaining values are parsed
//       so the call signature stays unchanged.
//
// Loads the image as a HIP module, looks up the kernel `name` in it and
// queries its register count, local (spill) size and max threads per block.
//
// Returns (module handle, function handle, n_regs, n_spills, n_max_threads),
// or NULL with a Python exception set if parsing or any HIP call fails.
static PyObject *loadBinary(PyObject *self, PyObject *args) {
  const char *name;
  const char *data;
  Py_ssize_t data_size;
  int shared;
  int device;
  if (!PyArg_ParseTuple(args, "ss#ii", &name, &data, &data_size, &shared,
                        &device)) {
    return NULL;
  }
  (void)data_size;
  (void)shared;
  (void)device;

  // set HIP options
  // A compile-time constant keeps the log buffers ordinary fixed-size arrays
  // instead of variable-length arrays.
  enum { JIT_LOG_BUF_SIZE = 8192 };
  hipJitOption opt[] = {hipJitOptionErrorLogBufferSizeBytes,
                        hipJitOptionErrorLogBuffer,
                        hipJitOptionInfoLogBufferSizeBytes,
                        hipJitOptionInfoLogBuffer, hipJitOptionLogVerbose};
  char errLog[JIT_LOG_BUF_SIZE];
  char infoLog[JIT_LOG_BUF_SIZE];
  // One value per entry of opt[], in the same order.
  void *optval[] = {(void *)(uintptr_t)JIT_LOG_BUF_SIZE, (void *)errLog,
                    (void *)(uintptr_t)JIT_LOG_BUF_SIZE, (void *)infoLog,
                    (void *)1};
  const unsigned int numOptions = sizeof(opt) / sizeof(opt[0]);

  // launch HIP Binary
  hipModule_t mod;
  hipFunction_t fun;
  HIP_CHECK(
      hipSymbolTable.hipModuleLoadDataEx(&mod, data, numOptions, opt, optval));
  HIP_CHECK(hipSymbolTable.hipModuleGetFunction(&fun, mod, name));

  // get allocated registers and spilled registers from the function
  // Each query goes through HIP_CHECK so a failing call raises instead of
  // silently reporting zeros.
  int n_regs = 0;
  int n_spills = 0;
  int n_max_threads = 0;
  HIP_CHECK(hipSymbolTable.hipFuncGetAttribute(
      &n_regs, HIP_FUNC_ATTRIBUTE_NUM_REGS, fun));
  HIP_CHECK(hipSymbolTable.hipFuncGetAttribute(
      &n_spills, HIP_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun));
  HIP_CHECK(hipSymbolTable.hipFuncGetAttribute(
      &n_max_threads, HIP_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK, fun));
  // The local size is queried in bytes; dividing by 4 reports it in 4-byte
  // units.
  n_spills /= 4;

  return Py_BuildValue("(KKiii)", (uint64_t)mod, (uint64_t)fun, n_regs,
                       n_spills, n_max_threads);
}
|
|
257
|
+
|
|
258
|
+
static PyMethodDef ModuleMethods[] = {
|
|
259
|
+
{"load_binary", loadBinary, METH_VARARGS,
|
|
260
|
+
"Load provided hsaco into HIP driver"},
|
|
261
|
+
{"get_device_properties", getDeviceProperties, METH_VARARGS,
|
|
262
|
+
"Get the properties for a given device"},
|
|
263
|
+
{NULL, NULL, 0, NULL} // sentinel
|
|
264
|
+
};
|
|
265
|
+
|
|
266
|
+
static struct PyModuleDef ModuleDef = {PyModuleDef_HEAD_INIT, "hip_utils",
|
|
267
|
+
NULL, // documentation
|
|
268
|
+
-1, // size
|
|
269
|
+
ModuleMethods};
|
|
270
|
+
|
|
271
|
+
PyMODINIT_FUNC PyInit_hip_utils(void) {
|
|
272
|
+
if (!initSymbolTable()) {
|
|
273
|
+
return NULL;
|
|
274
|
+
}
|
|
275
|
+
|
|
276
|
+
PyObject *m = PyModule_Create(&ModuleDef);
|
|
277
|
+
if (m == NULL) {
|
|
278
|
+
return NULL;
|
|
279
|
+
}
|
|
280
|
+
PyModule_AddFunctions(m, ModuleMethods);
|
|
281
|
+
|
|
282
|
+
return m;
|
|
283
|
+
}
|