triton-windows 3.3.1.post19__cp313-cp313-win_amd64.whl → 3.5.0.post21__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic. Click here for more details.

Files changed (225)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +11 -2
  3. triton/_filecheck.py +97 -0
  4. triton/_internal_testing.py +95 -18
  5. triton/_utils.py +112 -21
  6. triton/backends/__init__.py +20 -23
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +161 -119
  9. triton/backends/amd/driver.c +118 -46
  10. triton/backends/amd/driver.py +274 -96
  11. triton/backends/compiler.py +7 -21
  12. triton/backends/driver.py +13 -0
  13. triton/backends/nvidia/bin/ptxas.exe +0 -0
  14. triton/backends/nvidia/compiler.py +163 -106
  15. triton/backends/nvidia/driver.c +166 -101
  16. triton/backends/nvidia/driver.py +384 -202
  17. triton/compiler/__init__.py +5 -2
  18. triton/compiler/code_generator.py +439 -231
  19. triton/compiler/compiler.py +152 -84
  20. triton/experimental/__init__.py +0 -0
  21. triton/experimental/gluon/__init__.py +5 -0
  22. triton/experimental/gluon/_compiler.py +0 -0
  23. triton/experimental/gluon/_runtime.py +102 -0
  24. triton/experimental/gluon/language/__init__.py +119 -0
  25. triton/experimental/gluon/language/_core.py +490 -0
  26. triton/experimental/gluon/language/_layouts.py +583 -0
  27. triton/experimental/gluon/language/_math.py +20 -0
  28. triton/experimental/gluon/language/_semantic.py +380 -0
  29. triton/experimental/gluon/language/_standard.py +80 -0
  30. triton/experimental/gluon/language/amd/__init__.py +4 -0
  31. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  32. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  33. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  34. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  35. triton/experimental/gluon/language/extra/__init__.py +3 -0
  36. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  37. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  38. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  39. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  40. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
  41. triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
  42. triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
  43. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
  44. triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
  45. triton/experimental/gluon/nvidia/__init__.py +4 -0
  46. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  47. triton/experimental/gluon/nvidia/hopper.py +45 -0
  48. triton/knobs.py +546 -0
  49. triton/language/__init__.py +50 -19
  50. triton/language/core.py +909 -572
  51. triton/language/extra/cuda/__init__.py +10 -7
  52. triton/language/extra/cuda/gdc.py +42 -0
  53. triton/language/extra/cuda/libdevice.py +394 -394
  54. triton/language/extra/cuda/utils.py +21 -21
  55. triton/language/extra/hip/__init__.py +3 -1
  56. triton/language/extra/hip/libdevice.py +120 -104
  57. triton/language/extra/hip/utils.py +35 -0
  58. triton/language/extra/libdevice.py +4 -0
  59. triton/language/math.py +65 -66
  60. triton/language/random.py +12 -2
  61. triton/language/semantic.py +1757 -1768
  62. triton/language/standard.py +127 -62
  63. triton/language/target_info.py +54 -0
  64. triton/runtime/_allocation.py +15 -3
  65. triton/runtime/_async_compile.py +55 -0
  66. triton/runtime/autotuner.py +117 -60
  67. triton/runtime/build.py +83 -17
  68. triton/runtime/cache.py +61 -47
  69. triton/runtime/driver.py +25 -47
  70. triton/runtime/interpreter.py +95 -50
  71. triton/runtime/jit.py +445 -248
  72. triton/runtime/tcc/include/_mingw.h +8 -10
  73. triton/runtime/tcc/include/assert.h +5 -0
  74. triton/runtime/tcc/include/errno.h +1 -1
  75. triton/runtime/tcc/include/float.h +21 -3
  76. triton/runtime/tcc/include/iso646.h +36 -0
  77. triton/runtime/tcc/include/limits.h +5 -0
  78. triton/runtime/tcc/include/malloc.h +2 -2
  79. triton/runtime/tcc/include/math.h +21 -261
  80. triton/runtime/tcc/include/stdalign.h +16 -0
  81. triton/runtime/tcc/include/stdarg.h +5 -70
  82. triton/runtime/tcc/include/stdatomic.h +171 -0
  83. triton/runtime/tcc/include/stddef.h +7 -19
  84. triton/runtime/tcc/include/stdlib.h +15 -4
  85. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  86. triton/runtime/tcc/include/sys/stat.h +2 -2
  87. triton/runtime/tcc/include/sys/types.h +5 -0
  88. triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
  89. triton/runtime/tcc/include/tccdefs.h +342 -0
  90. triton/runtime/tcc/include/tgmath.h +89 -0
  91. triton/runtime/tcc/include/uchar.h +33 -0
  92. triton/runtime/tcc/include/unistd.h +1 -0
  93. triton/runtime/tcc/include/winapi/qos.h +72 -0
  94. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  95. triton/runtime/tcc/include/winapi/winbase.h +9 -2
  96. triton/runtime/tcc/include/winapi/wincon.h +8 -0
  97. triton/runtime/tcc/include/winapi/windows.h +1 -1
  98. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  99. triton/runtime/tcc/include/winapi/winnt.h +9 -7
  100. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  101. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  102. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  103. triton/runtime/tcc/lib/libtcc1.a +0 -0
  104. triton/runtime/tcc/lib/python314.def +1800 -0
  105. triton/runtime/tcc/lib/python314t.def +1809 -0
  106. triton/runtime/tcc/libtcc.dll +0 -0
  107. triton/runtime/tcc/tcc.exe +0 -0
  108. triton/testing.py +16 -12
  109. triton/tools/compile.py +62 -14
  110. triton/tools/disasm.py +3 -4
  111. triton/tools/extra/cuda/compile.c +1 -0
  112. triton/tools/extra/hip/compile.cpp +66 -0
  113. triton/tools/extra/hip/compile.h +13 -0
  114. triton/tools/ragged_tma.py +92 -0
  115. triton/tools/tensor_descriptor.py +34 -0
  116. triton/windows_utils.py +52 -81
  117. {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +8 -4
  118. triton_windows-3.5.0.post21.dist-info/RECORD +217 -0
  119. triton_windows-3.5.0.post21.dist-info/entry_points.txt +3 -0
  120. triton_windows-3.5.0.post21.dist-info/licenses/LICENSE +23 -0
  121. triton_windows-3.5.0.post21.dist-info/top_level.txt +1 -0
  122. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
  123. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
  124. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
  125. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
  126. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
  127. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
  128. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
  129. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
  130. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
  131. triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
  132. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
  133. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
  134. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
  135. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
  136. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
  137. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
  138. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
  139. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
  140. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
  141. triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
  142. triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
  143. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
  144. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
  145. triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
  146. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
  147. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
  148. triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
  149. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
  150. triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
  151. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
  152. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
  153. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
  154. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
  155. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
  156. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
  157. triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
  158. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
  159. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
  160. triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
  161. triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
  162. triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
  163. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
  164. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
  165. triton/backends/amd/include/hip/channel_descriptor.h +0 -39
  166. triton/backends/amd/include/hip/device_functions.h +0 -38
  167. triton/backends/amd/include/hip/driver_types.h +0 -468
  168. triton/backends/amd/include/hip/hip_bf16.h +0 -36
  169. triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
  170. triton/backends/amd/include/hip/hip_common.h +0 -100
  171. triton/backends/amd/include/hip/hip_complex.h +0 -38
  172. triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
  173. triton/backends/amd/include/hip/hip_deprecated.h +0 -95
  174. triton/backends/amd/include/hip/hip_ext.h +0 -161
  175. triton/backends/amd/include/hip/hip_fp16.h +0 -36
  176. triton/backends/amd/include/hip/hip_fp8.h +0 -33
  177. triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
  178. triton/backends/amd/include/hip/hip_hcc.h +0 -24
  179. triton/backends/amd/include/hip/hip_math_constants.h +0 -36
  180. triton/backends/amd/include/hip/hip_profile.h +0 -27
  181. triton/backends/amd/include/hip/hip_runtime.h +0 -75
  182. triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
  183. triton/backends/amd/include/hip/hip_texture_types.h +0 -29
  184. triton/backends/amd/include/hip/hip_vector_types.h +0 -41
  185. triton/backends/amd/include/hip/hip_version.h +0 -17
  186. triton/backends/amd/include/hip/hiprtc.h +0 -421
  187. triton/backends/amd/include/hip/library_types.h +0 -78
  188. triton/backends/amd/include/hip/math_functions.h +0 -42
  189. triton/backends/amd/include/hip/surface_types.h +0 -63
  190. triton/backends/amd/include/hip/texture_types.h +0 -194
  191. triton/backends/amd/include/hsa/Brig.h +0 -1131
  192. triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
  193. triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
  194. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
  195. triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
  196. triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
  197. triton/backends/amd/include/hsa/hsa.h +0 -5738
  198. triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
  199. triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
  200. triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
  201. triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
  202. triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
  203. triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
  204. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
  205. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
  206. triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
  207. triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
  208. triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
  209. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
  210. triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
  211. triton/backends/amd/include/roctracer/roctracer.h +0 -779
  212. triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
  213. triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
  214. triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
  215. triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
  216. triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
  217. triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
  218. triton/backends/amd/include/roctracer/roctx.h +0 -229
  219. triton/language/_utils.py +0 -21
  220. triton/language/extra/cuda/_experimental_tma.py +0 -106
  221. triton/runtime/tcc/lib/libtcc1-64.a +0 -0
  222. triton/tools/experimental_descriptor.py +0 -32
  223. triton_windows-3.3.1.post19.dist-info/RECORD +0 -260
  224. triton_windows-3.3.1.post19.dist-info/top_level.txt +0 -14
  225. {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
@@ -1,12 +1,10 @@
1
1
  #define __HIP_PLATFORM_AMD__
2
- // clang-format off
3
- // hip_depreated.h needs definitions from hip_runtime.h.
4
2
  #include <hip/hip_runtime.h>
5
- #include <hip/hip_deprecated.h>
6
- // clang-format on
3
+ #include <hip/hip_runtime_api.h>
7
4
  #define PY_SSIZE_T_CLEAN
8
5
  #include <Python.h>
9
6
  #include <dlfcn.h>
7
+ #include <stdbool.h>
10
8
  #include <stdio.h>
11
9
  #include <stdlib.h>
12
10
 
@@ -18,24 +16,9 @@ static const char *hipLibSearchPaths[] = {"/*py_libhip_search_path*/"};
18
16
  // in this file.
19
17
  // |FOR_EACH_ERR_FN| is a macro to process APIs that return hipError_t;
20
18
  // |FOR_EACH_STR_FN| is a macro to process APIs that return const char *.
21
- //
22
- // HIP 6.0 introduced an updated hipGetDeviceProperties API under a new symbol,
23
- // hipGetDevicePropertiesR0600. However, the associated hipDeviceProp_t was
24
- // directly updated with breaking changes to match hipGetDevicePropertiesR0600
25
- // in the header file. We include the header file from HIP 6.0. So here if we
26
- // use hipGetDeviceProperties together with hipDeviceProp_t we will use the
27
- // old API with a new struct definition and mess up the interpretation.
28
- //
29
- // This is a known issue: https://github.com/ROCm/ROCm/issues/2728.
30
- //
31
- // For now explicitly defer to the old hipDeviceProp_t struct. This should work
32
- // for both 5.x and 6.x. In the long term we need to switch to use
33
- // hipGetProcAddress once available:
34
- // https://github.com/ROCm/clr/commit/0479cdb3dd30ef58718cad44e424bd793c394cc0
35
19
  #define HIP_SYMBOL_LIST(FOR_EACH_ERR_FN, FOR_EACH_STR_FN) \
36
20
  FOR_EACH_STR_FN(hipGetErrorString, hipError_t hipError) \
37
- FOR_EACH_ERR_FN(hipGetDeviceProperties, hipDeviceProp_tR0000 *prop, \
38
- int deviceId) \
21
+ FOR_EACH_ERR_FN(hipGetDeviceProperties, hipDeviceProp_t *prop, int deviceId) \
39
22
  FOR_EACH_ERR_FN(hipModuleLoadDataEx, hipModule_t *module, const void *image, \
40
23
  unsigned int numOptions, hipJitOption *options, \
41
24
  void **optionValues) \
@@ -44,6 +27,34 @@ static const char *hipLibSearchPaths[] = {"/*py_libhip_search_path*/"};
44
27
  FOR_EACH_ERR_FN(hipFuncGetAttribute, int *, hipFunction_attribute attr, \
45
28
  hipFunction_t function)
46
29
 
30
+ // HIP driver version format: HIP_VERSION_MAJOR * 10000000 + HIP_VERSION_MINOR *
31
+ // 100000 + HIP_VERSION_PATCH.
32
+ #define TRITON_HIP_DRIVER_EXTRACT_MAJOR_VERSION(version) ((version) / 10000000)
33
+ #define TRITON_HIP_DRIVER_EXTRACT_MINOR_VERSION(version) \
34
+ (((version) % 10000000) / 100000)
35
+ #define TRITON_HIP_DRIVER_EXTRACT_PATCH_VERSION(version) ((version) % 100000)
36
+ #define TRITON_HIP_DRIVER_REQ_MAJOR_VERSION (HIP_VERSION_MAJOR)
37
+
38
+ // #define TRITON_HIP_DRIVER_DBG_VERSION
39
+ #ifdef TRITON_HIP_DRIVER_DBG_VERSION
40
+ #define TRITON_HIP_DRIVER_LOG_VERSION(version, msgBuff) \
41
+ do { \
42
+ snprintf(msgBuff, sizeof(msgBuff), "libamdhip64 version is: %d.%d.%d", \
43
+ TRITON_HIP_DRIVER_EXTRACT_MAJOR_VERSION(version), \
44
+ TRITON_HIP_DRIVER_EXTRACT_MINOR_VERSION(version), \
45
+ TRITON_HIP_DRIVER_EXTRACT_PATCH_VERSION(version)); \
46
+ printf("%s\n", msgBuff); \
47
+ } while (0);
48
+ #else
49
+ #define TRITON_HIP_DRIVER_LOG_VERSION(version, msgBuff) \
50
+ do { \
51
+ (void)msgBuff; \
52
+ (void)(version); \
53
+ } while (0);
54
+ #endif
55
+
56
+ #define TRITON_HIP_MSG_BUFF_SIZE (1024U)
57
+
47
58
  // The HIP symbol table for holding resolved dynamic library symbols.
48
59
  struct HIPSymbolTable {
49
60
  #define DEFINE_EACH_ERR_FIELD(hipSymbolName, ...) \
@@ -56,39 +67,96 @@ struct HIPSymbolTable {
56
67
 
57
68
  static struct HIPSymbolTable hipSymbolTable;
58
69
 
59
- bool initSymbolTable() {
60
- // Use the HIP runtime library loaded into the existing process if it exits.
61
- void *lib = dlopen("libamdhip64.so", RTLD_NOLOAD);
62
- if (lib) {
63
- // printf("[triton] chosen loaded libamdhip64.so in the process\n");
70
+ static int checkDriverVersion(void *lib) {
71
+ int hipVersion = -1;
72
+ const char *error = NULL;
73
+ typedef hipError_t (*hipDriverGetVersion_fn)(int *driverVersion);
74
+ hipDriverGetVersion_fn hipDriverGetVersion;
75
+ dlerror(); // Clear existing errors
76
+ hipDriverGetVersion =
77
+ (hipDriverGetVersion_fn)dlsym(lib, "hipDriverGetVersion");
78
+ error = dlerror();
79
+ if (error) {
80
+ PyErr_SetString(PyExc_RuntimeError,
81
+ "cannot query 'hipDriverGetVersion' from libamdhip64.so");
82
+ dlclose(lib);
83
+ return -1;
64
84
  }
65
85
 
66
- // Otherwise, go through the list of search paths to dlopen the first HIP
67
- // driver library.
68
- if (!lib) {
69
- int n = sizeof(hipLibSearchPaths) / sizeof(hipLibSearchPaths[0]);
70
- for (int i = 0; i < n; ++i) {
71
- void *handle = dlopen(hipLibSearchPaths[i], RTLD_LAZY | RTLD_LOCAL);
72
- if (handle) {
73
- lib = handle;
74
- // printf("[triton] chosen %s\n", hipLibSearchPaths[i]);
75
- }
86
+ (void)hipDriverGetVersion(&hipVersion);
87
+ char msgBuff[TRITON_HIP_MSG_BUFF_SIZE] = {0};
88
+
89
+ const int hipMajVersion = TRITON_HIP_DRIVER_EXTRACT_MAJOR_VERSION(hipVersion);
90
+ if (hipMajVersion < TRITON_HIP_DRIVER_REQ_MAJOR_VERSION) {
91
+ const int hipMinVersion =
92
+ TRITON_HIP_DRIVER_EXTRACT_MINOR_VERSION(hipVersion);
93
+ const int hipPatchVersion =
94
+ TRITON_HIP_DRIVER_EXTRACT_PATCH_VERSION(hipVersion);
95
+ snprintf(msgBuff, sizeof(msgBuff),
96
+ "libamdhip64 version %d.%d.%d is not supported! Required major "
97
+ "version is >=%d.",
98
+ hipMajVersion, hipMinVersion, hipPatchVersion,
99
+ TRITON_HIP_DRIVER_REQ_MAJOR_VERSION);
100
+ PyErr_SetString(PyExc_RuntimeError, msgBuff);
101
+ dlclose(lib);
102
+ return -1;
103
+ }
104
+
105
+ TRITON_HIP_DRIVER_LOG_VERSION(hipVersion, msgBuff);
106
+
107
+ return hipVersion;
108
+ }
109
+
110
+ bool initSymbolTable() {
111
+ void *lib;
112
+
113
+ // Go through the list of search paths to dlopen the first HIP driver library.
114
+ int n = sizeof(hipLibSearchPaths) / sizeof(hipLibSearchPaths[0]);
115
+ for (int i = 0; i < n; ++i) {
116
+ void *handle = dlopen(hipLibSearchPaths[i], RTLD_LAZY | RTLD_LOCAL);
117
+ if (handle) {
118
+ lib = handle;
119
+ // printf("[triton] chosen %s\n", hipLibSearchPaths[i]);
76
120
  }
77
121
  }
122
+
78
123
  if (!lib) {
79
124
  PyErr_SetString(PyExc_RuntimeError, "cannot open libamdhip64.so");
80
125
  return false;
81
126
  }
82
127
 
83
- // Resolve all symbols we are interested in.
84
- dlerror(); // Clear existing errors
128
+ int hipVersion = checkDriverVersion(lib);
129
+ if (hipVersion == -1)
130
+ return false;
131
+
85
132
  const char *error = NULL;
133
+ typedef hipError_t (*hipGetProcAddress_fn)(
134
+ const char *symbol, void **pfn, int hipVersion, uint64_t hipFlags,
135
+ hipDriverProcAddressQueryResult *symbolStatus);
136
+ hipGetProcAddress_fn hipGetProcAddress;
137
+ dlerror(); // Clear existing errors
138
+
139
+ *(void **)&hipGetProcAddress = dlsym(lib, "hipGetProcAddress");
140
+ error = dlerror();
141
+ if (error) {
142
+ PyErr_SetString(PyExc_RuntimeError,
143
+ "cannot query 'hipGetProcAddress' from libamdhip64.so");
144
+ dlclose(lib);
145
+ return false;
146
+ }
147
+
148
+ // Resolve all symbols we are interested in.
149
+ uint64_t hipFlags = 0;
150
+ hipDriverProcAddressQueryResult symbolStatus;
151
+ hipError_t status = hipSuccess;
86
152
  #define QUERY_EACH_FN(hipSymbolName, ...) \
87
- *(void **)&hipSymbolTable.hipSymbolName = dlsym(lib, #hipSymbolName); \
88
- error = dlerror(); \
89
- if (error) { \
153
+ status = hipGetProcAddress(#hipSymbolName, \
154
+ (void **)&hipSymbolTable.hipSymbolName, \
155
+ hipVersion, hipFlags, &symbolStatus); \
156
+ if (status != hipSuccess) { \
90
157
  PyErr_SetString(PyExc_RuntimeError, \
91
- "cannot query " #hipSymbolName " from libamdhip64.so"); \
158
+ "cannot get address for '" #hipSymbolName \
159
+ "' from libamdhip64.so"); \
92
160
  dlclose(lib); \
93
161
  return false; \
94
162
  }
@@ -104,8 +172,9 @@ static inline void gpuAssert(hipError_t code, const char *file, int line) {
104
172
  {
105
173
  const char *prefix = "Triton Error [HIP]: ";
106
174
  const char *str = hipSymbolTable.hipGetErrorString(code);
107
- char err[1024] = {0};
108
- snprintf(err, 1024, "%s Code: %d, Messsage: %s", prefix, code, str);
175
+ char err[TRITON_HIP_MSG_BUFF_SIZE] = {0};
176
+ snprintf(err, sizeof(err), "%s Code: %d, Messsage: %s", prefix, code,
177
+ str);
109
178
  PyGILState_STATE gil_state;
110
179
  gil_state = PyGILState_Ensure();
111
180
  PyErr_SetString(PyExc_RuntimeError, err);
@@ -127,7 +196,7 @@ static PyObject *getDeviceProperties(PyObject *self, PyObject *args) {
127
196
  if (!PyArg_ParseTuple(args, "i", &device_id))
128
197
  return NULL;
129
198
 
130
- hipDeviceProp_tR0000 props;
199
+ hipDeviceProp_t props;
131
200
  HIP_CHECK(hipSymbolTable.hipGetDeviceProperties(&props, device_id));
132
201
 
133
202
  // create a struct to hold device properties
@@ -172,15 +241,18 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) {
172
241
  // get allocated registers and spilled registers from the function
173
242
  int n_regs = 0;
174
243
  int n_spills = 0;
244
+ int32_t n_max_threads = 0;
175
245
  hipSymbolTable.hipFuncGetAttribute(&n_regs, HIP_FUNC_ATTRIBUTE_NUM_REGS, fun);
176
246
  hipSymbolTable.hipFuncGetAttribute(&n_spills,
177
247
  HIP_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun);
248
+ hipSymbolTable.hipFuncGetAttribute(
249
+ &n_max_threads, HIP_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK, fun);
178
250
  n_spills /= 4;
179
251
  if (PyErr_Occurred()) {
180
252
  return NULL;
181
253
  }
182
- return Py_BuildValue("(KKii)", (uint64_t)mod, (uint64_t)fun, n_regs,
183
- n_spills);
254
+ return Py_BuildValue("(KKiii)", (uint64_t)mod, (uint64_t)fun, n_regs,
255
+ n_spills, n_max_threads);
184
256
  }
185
257
 
186
258
  static PyMethodDef ModuleMethods[] = {