triton-windows 3.1.0.post17__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic. Click here for more details.

Files changed (248) hide show
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +73 -0
  3. triton/backends/__init__.py +50 -0
  4. triton/backends/amd/compiler.py +262 -0
  5. triton/backends/amd/driver.c +211 -0
  6. triton/backends/amd/driver.py +497 -0
  7. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +358 -0
  8. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +1031 -0
  9. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +1612 -0
  10. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +1337 -0
  11. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +293 -0
  12. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +32 -0
  13. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +174 -0
  14. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +829 -0
  15. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +1809 -0
  16. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +108 -0
  17. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +124 -0
  18. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +405 -0
  19. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +196 -0
  20. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +565 -0
  21. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +2226 -0
  22. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +104 -0
  23. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +244 -0
  24. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +494 -0
  25. triton/backends/amd/include/hip/amd_detail/concepts.hpp +30 -0
  26. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +133 -0
  27. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +218 -0
  28. triton/backends/amd/include/hip/amd_detail/grid_launch.h +67 -0
  29. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +50 -0
  30. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +26 -0
  31. triton/backends/amd/include/hip/amd_detail/helpers.hpp +137 -0
  32. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +1350 -0
  33. triton/backends/amd/include/hip/amd_detail/hip_assert.h +101 -0
  34. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +242 -0
  35. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +254 -0
  36. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +96 -0
  37. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +100 -0
  38. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +10169 -0
  39. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +77 -0
  40. triton/backends/amd/include/hip/amd_detail/host_defines.h +180 -0
  41. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +102 -0
  42. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +798 -0
  43. triton/backends/amd/include/hip/amd_detail/math_fwd.h +698 -0
  44. triton/backends/amd/include/hip/amd_detail/ockl_image.h +177 -0
  45. triton/backends/amd/include/hip/amd_detail/program_state.hpp +107 -0
  46. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +491 -0
  47. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +478 -0
  48. triton/backends/amd/include/hip/channel_descriptor.h +39 -0
  49. triton/backends/amd/include/hip/device_functions.h +38 -0
  50. triton/backends/amd/include/hip/driver_types.h +468 -0
  51. triton/backends/amd/include/hip/hip_bf16.h +36 -0
  52. triton/backends/amd/include/hip/hip_bfloat16.h +44 -0
  53. triton/backends/amd/include/hip/hip_common.h +100 -0
  54. triton/backends/amd/include/hip/hip_complex.h +38 -0
  55. triton/backends/amd/include/hip/hip_cooperative_groups.h +46 -0
  56. triton/backends/amd/include/hip/hip_deprecated.h +95 -0
  57. triton/backends/amd/include/hip/hip_ext.h +159 -0
  58. triton/backends/amd/include/hip/hip_fp16.h +36 -0
  59. triton/backends/amd/include/hip/hip_gl_interop.h +32 -0
  60. triton/backends/amd/include/hip/hip_hcc.h +24 -0
  61. triton/backends/amd/include/hip/hip_math_constants.h +36 -0
  62. triton/backends/amd/include/hip/hip_profile.h +27 -0
  63. triton/backends/amd/include/hip/hip_runtime.h +75 -0
  64. triton/backends/amd/include/hip/hip_runtime_api.h +8919 -0
  65. triton/backends/amd/include/hip/hip_texture_types.h +29 -0
  66. triton/backends/amd/include/hip/hip_vector_types.h +41 -0
  67. triton/backends/amd/include/hip/hip_version.h +17 -0
  68. triton/backends/amd/include/hip/hiprtc.h +421 -0
  69. triton/backends/amd/include/hip/library_types.h +78 -0
  70. triton/backends/amd/include/hip/math_functions.h +42 -0
  71. triton/backends/amd/include/hip/surface_types.h +63 -0
  72. triton/backends/amd/include/hip/texture_types.h +194 -0
  73. triton/backends/amd/include/hsa/Brig.h +1131 -0
  74. triton/backends/amd/include/hsa/amd_hsa_common.h +91 -0
  75. triton/backends/amd/include/hsa/amd_hsa_elf.h +435 -0
  76. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +269 -0
  77. triton/backends/amd/include/hsa/amd_hsa_queue.h +109 -0
  78. triton/backends/amd/include/hsa/amd_hsa_signal.h +80 -0
  79. triton/backends/amd/include/hsa/hsa.h +5729 -0
  80. triton/backends/amd/include/hsa/hsa_amd_tool.h +91 -0
  81. triton/backends/amd/include/hsa/hsa_api_trace.h +566 -0
  82. triton/backends/amd/include/hsa/hsa_ext_amd.h +3090 -0
  83. triton/backends/amd/include/hsa/hsa_ext_finalize.h +531 -0
  84. triton/backends/amd/include/hsa/hsa_ext_image.h +1454 -0
  85. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +488 -0
  86. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +667 -0
  87. triton/backends/amd/include/roctracer/ext/prof_protocol.h +107 -0
  88. triton/backends/amd/include/roctracer/hip_ostream_ops.h +4435 -0
  89. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +1467 -0
  90. triton/backends/amd/include/roctracer/hsa_prof_str.h +3027 -0
  91. triton/backends/amd/include/roctracer/roctracer.h +779 -0
  92. triton/backends/amd/include/roctracer/roctracer_ext.h +81 -0
  93. triton/backends/amd/include/roctracer/roctracer_hcc.h +24 -0
  94. triton/backends/amd/include/roctracer/roctracer_hip.h +37 -0
  95. triton/backends/amd/include/roctracer/roctracer_hsa.h +112 -0
  96. triton/backends/amd/include/roctracer/roctracer_plugin.h +137 -0
  97. triton/backends/amd/include/roctracer/roctracer_roctx.h +67 -0
  98. triton/backends/amd/include/roctracer/roctx.h +229 -0
  99. triton/backends/amd/lib/ockl.bc +0 -0
  100. triton/backends/amd/lib/ocml.bc +0 -0
  101. triton/backends/compiler.py +76 -0
  102. triton/backends/driver.py +34 -0
  103. triton/backends/nvidia/__init__.py +0 -0
  104. triton/backends/nvidia/bin/ptxas.exe +0 -0
  105. triton/backends/nvidia/compiler.py +347 -0
  106. triton/backends/nvidia/driver.c +451 -0
  107. triton/backends/nvidia/driver.py +430 -0
  108. triton/backends/nvidia/include/cuda.h +24359 -0
  109. triton/backends/nvidia/lib/libdevice.10.bc +0 -0
  110. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  111. triton/compiler/__init__.py +4 -0
  112. triton/compiler/code_generator.py +1302 -0
  113. triton/compiler/compiler.py +416 -0
  114. triton/compiler/errors.py +51 -0
  115. triton/compiler/make_launcher.py +0 -0
  116. triton/errors.py +5 -0
  117. triton/language/__init__.py +284 -0
  118. triton/language/core.py +2621 -0
  119. triton/language/extra/__init__.py +4 -0
  120. triton/language/extra/cuda/__init__.py +8 -0
  121. triton/language/extra/cuda/libdevice.py +1629 -0
  122. triton/language/extra/cuda/utils.py +109 -0
  123. triton/language/extra/hip/__init__.py +3 -0
  124. triton/language/extra/hip/libdevice.py +468 -0
  125. triton/language/extra/libdevice.py +1213 -0
  126. triton/language/math.py +250 -0
  127. triton/language/random.py +207 -0
  128. triton/language/semantic.py +1621 -0
  129. triton/language/standard.py +441 -0
  130. triton/ops/__init__.py +7 -0
  131. triton/ops/blocksparse/__init__.py +7 -0
  132. triton/ops/blocksparse/matmul.py +432 -0
  133. triton/ops/blocksparse/softmax.py +228 -0
  134. triton/ops/cross_entropy.py +96 -0
  135. triton/ops/flash_attention.py +466 -0
  136. triton/ops/matmul.py +219 -0
  137. triton/ops/matmul_perf_model.py +171 -0
  138. triton/runtime/__init__.py +23 -0
  139. triton/runtime/autotuner.py +361 -0
  140. triton/runtime/build.py +129 -0
  141. triton/runtime/cache.py +289 -0
  142. triton/runtime/driver.py +60 -0
  143. triton/runtime/errors.py +26 -0
  144. triton/runtime/interpreter.py +1127 -0
  145. triton/runtime/jit.py +956 -0
  146. triton/runtime/tcc/include/_mingw.h +170 -0
  147. triton/runtime/tcc/include/assert.h +57 -0
  148. triton/runtime/tcc/include/conio.h +409 -0
  149. triton/runtime/tcc/include/ctype.h +281 -0
  150. triton/runtime/tcc/include/dir.h +31 -0
  151. triton/runtime/tcc/include/direct.h +68 -0
  152. triton/runtime/tcc/include/dirent.h +135 -0
  153. triton/runtime/tcc/include/dos.h +55 -0
  154. triton/runtime/tcc/include/errno.h +75 -0
  155. triton/runtime/tcc/include/excpt.h +123 -0
  156. triton/runtime/tcc/include/fcntl.h +52 -0
  157. triton/runtime/tcc/include/fenv.h +108 -0
  158. triton/runtime/tcc/include/float.h +57 -0
  159. triton/runtime/tcc/include/inttypes.h +297 -0
  160. triton/runtime/tcc/include/io.h +418 -0
  161. triton/runtime/tcc/include/limits.h +111 -0
  162. triton/runtime/tcc/include/locale.h +91 -0
  163. triton/runtime/tcc/include/malloc.h +181 -0
  164. triton/runtime/tcc/include/math.h +737 -0
  165. triton/runtime/tcc/include/mem.h +13 -0
  166. triton/runtime/tcc/include/memory.h +40 -0
  167. triton/runtime/tcc/include/process.h +176 -0
  168. triton/runtime/tcc/include/sec_api/conio_s.h +42 -0
  169. triton/runtime/tcc/include/sec_api/crtdbg_s.h +19 -0
  170. triton/runtime/tcc/include/sec_api/io_s.h +33 -0
  171. triton/runtime/tcc/include/sec_api/mbstring_s.h +52 -0
  172. triton/runtime/tcc/include/sec_api/search_s.h +25 -0
  173. triton/runtime/tcc/include/sec_api/stdio_s.h +145 -0
  174. triton/runtime/tcc/include/sec_api/stdlib_s.h +67 -0
  175. triton/runtime/tcc/include/sec_api/stralign_s.h +30 -0
  176. triton/runtime/tcc/include/sec_api/string_s.h +41 -0
  177. triton/runtime/tcc/include/sec_api/sys/timeb_s.h +34 -0
  178. triton/runtime/tcc/include/sec_api/tchar_s.h +266 -0
  179. triton/runtime/tcc/include/sec_api/time_s.h +61 -0
  180. triton/runtime/tcc/include/sec_api/wchar_s.h +128 -0
  181. triton/runtime/tcc/include/setjmp.h +160 -0
  182. triton/runtime/tcc/include/share.h +28 -0
  183. triton/runtime/tcc/include/signal.h +63 -0
  184. triton/runtime/tcc/include/stdarg.h +79 -0
  185. triton/runtime/tcc/include/stdbool.h +11 -0
  186. triton/runtime/tcc/include/stddef.h +54 -0
  187. triton/runtime/tcc/include/stdint.h +212 -0
  188. triton/runtime/tcc/include/stdio.h +429 -0
  189. triton/runtime/tcc/include/stdlib.h +580 -0
  190. triton/runtime/tcc/include/string.h +164 -0
  191. triton/runtime/tcc/include/sys/fcntl.h +13 -0
  192. triton/runtime/tcc/include/sys/file.h +14 -0
  193. triton/runtime/tcc/include/sys/locking.h +30 -0
  194. triton/runtime/tcc/include/sys/stat.h +290 -0
  195. triton/runtime/tcc/include/sys/time.h +69 -0
  196. triton/runtime/tcc/include/sys/timeb.h +133 -0
  197. triton/runtime/tcc/include/sys/types.h +118 -0
  198. triton/runtime/tcc/include/sys/unistd.h +14 -0
  199. triton/runtime/tcc/include/sys/utime.h +146 -0
  200. triton/runtime/tcc/include/tcc/tcc_libm.h +201 -0
  201. triton/runtime/tcc/include/tcclib.h +80 -0
  202. triton/runtime/tcc/include/tchar.h +1102 -0
  203. triton/runtime/tcc/include/time.h +287 -0
  204. triton/runtime/tcc/include/vadefs.h +11 -0
  205. triton/runtime/tcc/include/values.h +4 -0
  206. triton/runtime/tcc/include/varargs.h +12 -0
  207. triton/runtime/tcc/include/wchar.h +873 -0
  208. triton/runtime/tcc/include/wctype.h +172 -0
  209. triton/runtime/tcc/include/winapi/basetsd.h +149 -0
  210. triton/runtime/tcc/include/winapi/basetyps.h +85 -0
  211. triton/runtime/tcc/include/winapi/guiddef.h +156 -0
  212. triton/runtime/tcc/include/winapi/poppack.h +8 -0
  213. triton/runtime/tcc/include/winapi/pshpack1.h +8 -0
  214. triton/runtime/tcc/include/winapi/pshpack2.h +8 -0
  215. triton/runtime/tcc/include/winapi/pshpack4.h +8 -0
  216. triton/runtime/tcc/include/winapi/pshpack8.h +8 -0
  217. triton/runtime/tcc/include/winapi/winbase.h +2951 -0
  218. triton/runtime/tcc/include/winapi/wincon.h +301 -0
  219. triton/runtime/tcc/include/winapi/windef.h +293 -0
  220. triton/runtime/tcc/include/winapi/windows.h +127 -0
  221. triton/runtime/tcc/include/winapi/winerror.h +3166 -0
  222. triton/runtime/tcc/include/winapi/wingdi.h +4080 -0
  223. triton/runtime/tcc/include/winapi/winnt.h +5835 -0
  224. triton/runtime/tcc/include/winapi/winreg.h +272 -0
  225. triton/runtime/tcc/include/winapi/winuser.h +5651 -0
  226. triton/runtime/tcc/include/winapi/winver.h +160 -0
  227. triton/runtime/tcc/lib/cuda.def +697 -0
  228. triton/runtime/tcc/lib/gdi32.def +337 -0
  229. triton/runtime/tcc/lib/kernel32.def +770 -0
  230. triton/runtime/tcc/lib/libtcc1-64.a +0 -0
  231. triton/runtime/tcc/lib/msvcrt.def +1399 -0
  232. triton/runtime/tcc/lib/python3.def +810 -0
  233. triton/runtime/tcc/lib/user32.def +658 -0
  234. triton/runtime/tcc/libtcc.dll +0 -0
  235. triton/runtime/tcc/tcc.exe +0 -0
  236. triton/testing.py +496 -0
  237. triton/tools/__init__.py +0 -0
  238. triton/tools/build_extern.py +365 -0
  239. triton/tools/compile.c +67 -0
  240. triton/tools/compile.h +14 -0
  241. triton/tools/compile.py +145 -0
  242. triton/tools/disasm.py +142 -0
  243. triton/tools/link.py +322 -0
  244. triton/windows_utils.py +373 -0
  245. triton_windows-3.1.0.post17.dist-info/METADATA +41 -0
  246. triton_windows-3.1.0.post17.dist-info/RECORD +248 -0
  247. triton_windows-3.1.0.post17.dist-info/WHEEL +5 -0
  248. triton_windows-3.1.0.post17.dist-info/top_level.txt +14 -0
@@ -0,0 +1,373 @@
1
+ import functools
2
+ import os
3
+ import re
4
+ import subprocess
5
+ import sys
6
+ import sysconfig
7
+ import warnings
8
+ import winreg
9
+ from collections.abc import Iterable
10
+ from functools import partial
11
+ from glob import glob
12
+ from pathlib import Path
13
+ from typing import Callable, Optional
14
+
15
+
16
def find_in_program_files(rel_path: str) -> Optional[Path]:
    """Locate ``rel_path`` beneath one of the "Program Files" directories.

    The directory named by the ``ProgramFiles(x86)`` environment variable is
    probed first, then the one named by ``ProgramW6432``. When a variable is
    unset, the conventional location on drive C: is used in its place.

    Returns:
        The first candidate path that exists, or ``None`` if neither does.
    """
    roots = (
        ("ProgramFiles(x86)", r"C:\Program Files (x86)"),
        ("ProgramW6432", r"C:\Program Files"),
    )
    for env_name, fallback in roots:
        candidate = Path(os.getenv(env_name, fallback)) / rel_path
        if candidate.exists():
            return candidate
    return None
28
+
29
+
30
def parse_version(s: str, prefix: str = "") -> Optional[tuple[int, ...]]:
    """Turn a dotted version string into a tuple of integers.

    Args:
        s: The version text, e.g. ``"14.40.33807"``.
        prefix: Leading text removed from ``s`` (when present) before parsing.

    Returns:
        One ``int`` per dot-separated component, or ``None`` when any
        component is rejected by ``int``.
    """
    stripped = s.removeprefix(prefix)
    try:
        return tuple(map(int, stripped.split(".")))
    except ValueError:
        return None
36
+
37
+
38
def unparse_version(t: Iterable[int], prefix: str = "") -> str:
    """Format version components as ``prefix`` followed by a dotted string.

    Args:
        t: The version components, in order.
        prefix: Text placed in front of the dotted components.

    Returns:
        For example ``unparse_version((12, 4), "v")`` gives ``"v12.4"``.
    """
    dotted = ".".join(map(str, t))
    return prefix + dotted
40
+
41
+
42
def max_version(
    versions: Iterable[str],
    prefix: str = "",
    check: Callable[[str], bool] = lambda x: True,
) -> Optional[str]:
    """Select the highest version among ``versions``.

    Entries rejected by ``check`` and entries that ``parse_version`` cannot
    parse (after removing ``prefix``) are ignored. The remaining entries are
    ranked by their parsed integer tuples.

    The winning entry is returned exactly as it appeared in ``versions``.
    Callers use the result as a directory name, so it must not be rebuilt
    from the parsed integers: ``int`` accepts leading zeros, surrounding
    whitespace and underscores, none of which survive a round trip, and a
    rebuilt string would also gain ``prefix`` even when the entry lacked it.

    Args:
        versions: Candidate version strings (typically directory names).
        prefix: Text stripped from each candidate before parsing.
        check: Predicate a candidate must satisfy to be considered.

    Returns:
        The highest acceptable candidate, or ``None`` when there is none.
        When several candidates parse to the same tuple, the first one
        encountered is returned.
    """
    best_name: Optional[str] = None
    best_key: Optional[tuple[int, ...]] = None
    for name in versions:
        if not check(name):
            continue
        key = parse_version(name, prefix)
        if key is None:
            continue
        if best_key is None or key > best_key:
            best_name, best_key = name, key
    return best_name
54
+
55
+
56
def check_msvc(msvc_base_path: Path, version: str) -> bool:
    """Report whether an MSVC toolset directory has the files that are needed.

    Args:
        msvc_base_path: Directory containing one sub-directory per version.
        version: Name of the version sub-directory to inspect.

    Returns:
        ``True`` only when both ``include/vcruntime.h`` and
        ``lib/x64/vcruntime.lib`` exist under ``msvc_base_path / version``.
    """
    toolset_dir = msvc_base_path / version
    header = toolset_dir / "include" / "vcruntime.h"
    import_lib = toolset_dir / "lib" / "x64" / "vcruntime.lib"
    return header.exists() and import_lib.exists()
64
+
65
+
66
def find_msvc_vswhere() -> tuple[Optional[Path], Optional[str]]:
    """Find an MSVC toolset by querying ``vswhere.exe``.

    ``vswhere.exe`` is looked up under the Visual Studio Installer directory
    in Program Files and asked for the ``installationPath`` property of the
    latest product (pre-releases included) that has both the VC x86/x64 tools
    component and the Windows 10 SDK component. The highest toolset version
    under ``<installationPath>/VC/Tools/MSVC`` that passes ``check_msvc`` is
    selected.

    Returns:
        ``(msvc_base_path, version)`` on success, otherwise ``(None, None)``.
        Every failure is reported as ``(None, None)`` so the caller can fall
        back to another discovery method.
    """
    vswhere_path = find_in_program_files(
        r"Microsoft Visual Studio\Installer\vswhere.exe"
    )
    if vswhere_path is None:
        return None, None

    command = [
        str(vswhere_path),
        "-prerelease",
        "-products",
        "*",
        "-requires",
        "Microsoft.VisualStudio.Component.VC.Tools.x86.x64",
        "-requires",
        "Microsoft.VisualStudio.Component.Windows10SDK",
        "-latest",
        "-property",
        "installationPath",
    ]
    try:
        output = subprocess.check_output(command, text=True).strip()
    except (subprocess.CalledProcessError, OSError):
        # CalledProcessError: vswhere exited with a non-zero status.
        # OSError: vswhere.exe exists but could not be launched.
        return None, None
    if not output:
        # Nothing was printed. Path("") is the current directory, so going on
        # would probe "./VC/Tools/MSVC" instead of a Visual Studio install.
        return None, None

    msvc_base_path = Path(output) / "VC" / "Tools" / "MSVC"
    if not msvc_base_path.exists():
        return None, None

    version = max_version(
        os.listdir(msvc_base_path), check=partial(check_msvc, msvc_base_path)
    )
    if version is None:
        return None, None

    return msvc_base_path, version
102
+
103
+
104
def find_msvc_envpath() -> tuple[Optional[Path], Optional[str]]:
    """Find an MSVC toolset from the entries of the ``PATH`` variable.

    Each ``PATH`` entry has forward slashes converted to backslashes and is
    then matched against a pattern ending in ``VC\\Tools\\MSVC\\``; the
    matched leading part is taken as the toolset base directory. The first
    base directory that exists and holds a version accepted by ``check_msvc``
    wins.

    Returns:
        ``(msvc_base_path, version)`` on success, otherwise ``(None, None)``.
    """
    toolset_root = re.compile(r".*\\VC\\Tools\\MSVC\\")
    for entry in os.getenv("PATH", "").split(os.pathsep):
        hit = toolset_root.match(entry.replace("/", "\\"))
        if hit is None:
            continue

        msvc_base_path = Path(hit.group(0))
        if msvc_base_path.exists():
            version = max_version(
                os.listdir(msvc_base_path),
                check=partial(check_msvc, msvc_base_path),
            )
            if version is not None:
                return msvc_base_path, version

    return None, None
125
+
126
+
127
def find_msvc_hardcoded() -> tuple[Optional[Path], Optional[str]]:
    """Find an MSVC toolset by globbing the Visual Studio install directory.

    Looks for ``<Program Files>/Microsoft Visual Studio/*/*/VC/Tools/MSVC``
    and walks the matches in reverse lexicographic order, returning the first
    one that contains a version accepted by ``check_msvc``.

    Returns:
        ``(msvc_base_path, version)`` on success, otherwise ``(None, None)``.
    """
    vs_root = find_in_program_files("Microsoft Visual Studio")
    if vs_root is None:
        return None, None

    pattern = str(vs_root / "*" / "*" / "VC" / "Tools" / "MSVC")
    # Reverse order so that the highest-sorting directory is tried first.
    for candidate in sorted(glob(pattern), reverse=True):
        msvc_base_path = Path(candidate)
        version = max_version(
            os.listdir(msvc_base_path),
            check=partial(check_msvc, msvc_base_path),
        )
        if version is not None:
            return msvc_base_path, version

    return None, None
145
+
146
+
147
def find_msvc() -> tuple[list[str], list[str]]:
    """Return the MSVC include and library directories.

    Discovery methods are tried in order: vswhere, the ``PATH`` variable, and
    finally a glob of the Visual Studio install directory. The first one that
    yields a base path is used.

    Returns:
        ``(include_dirs, lib_dirs)``, each a one-element list. If no method
        succeeds, a warning is issued and two empty lists are returned.
    """
    for locate in (find_msvc_vswhere, find_msvc_envpath, find_msvc_hardcoded):
        msvc_base_path, version = locate()
        if msvc_base_path is not None:
            break
    else:
        warnings.warn("Failed to find MSVC.")
        return [], []

    toolset_dir = msvc_base_path / version
    include_dirs = [str(toolset_dir / "include")]
    lib_dirs = [str(toolset_dir / "lib" / "x64")]
    return include_dirs, lib_dirs
161
+
162
+
163
def check_winsdk(winsdk_base_path: Path, version: str) -> bool:
    """Report whether a Windows SDK version has the UCRT files that are needed.

    Args:
        winsdk_base_path: SDK root containing ``Include`` and ``Lib``.
        version: Name of the version sub-directory to inspect.

    Returns:
        ``True`` only when both ``Include/<version>/ucrt/stdlib.h`` and
        ``Lib/<version>/ucrt/x64/ucrt.lib`` exist under the SDK root.
    """
    header = winsdk_base_path / "Include" / version / "ucrt" / "stdlib.h"
    import_lib = winsdk_base_path / "Lib" / version / "ucrt" / "x64" / "ucrt.lib"
    return header.exists() and import_lib.exists()
171
+
172
+
173
def find_winsdk_registry() -> tuple[Optional[Path], Optional[str]]:
    """Find the Windows SDK through its registry entry.

    Reads the ``InstallationFolder`` value of the
    ``Microsoft SDKs\\Windows\\v10.0`` key under ``HKEY_LOCAL_MACHINE``
    (``WOW6432Node`` view), then selects the highest version below
    ``<folder>/Include`` that passes ``check_winsdk``.

    Returns:
        ``(winsdk_base_path, version)`` on success, otherwise
        ``(None, None)``; this includes the case where the key or value
        cannot be read.
    """
    try:
        # Both handles are context managers, so they are closed even when a
        # later registry call raises.
        with winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE) as reg:
            with winreg.OpenKeyEx(
                reg, r"SOFTWARE\WOW6432Node\Microsoft\Microsoft SDKs\Windows\v10.0"
            ) as key:
                folder = winreg.QueryValueEx(key, "InstallationFolder")[0]
    except OSError:
        return None, None

    winsdk_base_path = Path(folder)
    if not (winsdk_base_path / "Include").exists():
        return None, None

    version = max_version(
        os.listdir(winsdk_base_path / "Include"),
        check=partial(check_winsdk, winsdk_base_path),
    )
    if version is None:
        return None, None

    return winsdk_base_path, version
196
+
197
+
198
def find_winsdk_hardcoded() -> tuple[Optional[Path], Optional[str]]:
    """Find the Windows SDK at its conventional Program Files location.

    Uses ``<Program Files>/Windows Kits/10`` as the SDK root and selects the
    highest version below its ``Include`` directory that passes
    ``check_winsdk``.

    Returns:
        ``(winsdk_base_path, version)`` on success, otherwise ``(None, None)``.
    """
    sdk_root = find_in_program_files(r"Windows Kits\10")
    if sdk_root is None:
        return None, None

    include_root = sdk_root / "Include"
    if not include_root.exists():
        return None, None

    version = max_version(
        os.listdir(include_root),
        check=partial(check_winsdk, sdk_root),
    )
    return (None, None) if version is None else (sdk_root, version)
213
+
214
+
215
def find_winsdk() -> tuple[list[str], list[str]]:
    """Return the Windows SDK include and library directories.

    The registry lookup is tried first, then the conventional Program Files
    location.

    Returns:
        ``(include_dirs, lib_dirs)``: the ``shared``, ``ucrt`` and ``um``
        include directories and the ``ucrt`` and ``um`` x64 library
        directories of the selected SDK version. If the SDK is not found, a
        warning is issued and two empty lists are returned.
    """
    for locate in (find_winsdk_registry, find_winsdk_hardcoded):
        winsdk_base_path, version = locate()
        if winsdk_base_path is not None:
            break
    else:
        warnings.warn("Failed to find Windows SDK.")
        return [], []

    include_root = winsdk_base_path / "Include" / version
    lib_root = winsdk_base_path / "Lib" / version
    return (
        [str(include_root / sub) for sub in ("shared", "ucrt", "um")],
        [str(lib_root / sub / "x64") for sub in ("ucrt", "um")],
    )
234
+
235
+
236
@functools.cache
def find_msvc_winsdk() -> tuple[list[str], list[str]]:
    """Return the combined MSVC and Windows SDK search directories.

    The result is cached, so discovery runs at most once per process.

    Returns:
        ``(include_dirs, lib_dirs)`` with the MSVC entries first, followed by
        the Windows SDK entries.
    """
    compiler_inc, compiler_lib = find_msvc()
    sdk_inc, sdk_lib = find_winsdk()
    return [*compiler_inc, *sdk_inc], [*compiler_lib, *sdk_lib]
241
+
242
+
243
@functools.cache
def find_python() -> list[str]:
    """Return the directory holding ``python3.lib``.

    Probes a ``libs`` sub-directory of ``sys.exec_prefix``,
    ``sys.base_exec_prefix`` and the directory of ``sys.executable``, in that
    order. The result is cached.

    Returns:
        A one-element list with the first ``libs`` directory that contains
        ``python3.lib``. If none does, a warning is issued and an empty list
        is returned.
    """
    roots = (
        sys.exec_prefix,
        sys.base_exec_prefix,
        os.path.dirname(sys.executable),
    )
    for root in roots:
        libs_dir = Path(root) / "libs"
        if libs_dir.joinpath("python3.lib").exists():
            return [str(libs_dir)]

    warnings.warn("Failed to find Python libs.")
    return []
256
+
257
+
258
def check_and_find_cuda(base_path: Path) -> tuple[Optional[str], list[str], list[str]]:
    """Detect which known CUDA directory layout, if any, ``base_path`` has.

    Three layouts are recognised and tested in this order: pip packages
    (``cuda_nvcc`` / ``cuda_runtime``), conda (``lib`` without ``x64``), and
    bundled or system-wide (``lib/x64``). A layout matches when its
    ``ptxas.exe``, ``cuda.h`` and ``cuda.lib`` all exist.

    Returns:
        ``(bin_dir, [include_dir], [lib_dir])`` for the first matching layout,
        or ``(None, [], [])`` when none matches.
    """
    # (bin, include, lib) directories relative to base_path, in priority order.
    layouts = (
        # pip
        (
            Path("cuda_nvcc") / "bin",
            Path("cuda_runtime") / "include",
            Path("cuda_runtime") / "lib" / "x64",
        ),
        # conda
        (Path("bin"), Path("include"), Path("lib")),
        # bundled or system-wide
        (Path("bin"), Path("include"), Path("lib") / "x64"),
    )
    for bin_rel, inc_rel, lib_rel in layouts:
        bin_dir = base_path / bin_rel
        inc_dir = base_path / inc_rel
        lib_dir = base_path / lib_rel
        if (
            (bin_dir / "ptxas.exe").exists()
            and (inc_dir / "cuda.h").exists()
            and (lib_dir / "cuda.lib").exists()
        ):
            return str(bin_dir), [str(inc_dir)], [str(lib_dir)]

    return None, [], []
305
+
306
+
307
def find_cuda_env() -> tuple[Optional[str], list[str], list[str]]:
    """Find CUDA through the ``CUDA_PATH`` or ``CUDA_HOME`` variable.

    The variables are consulted in that order; unset ones are skipped, and
    the first whose directory matches a known layout is used.

    Returns:
        The ``check_and_find_cuda`` result for the first usable directory, or
        ``(None, [], [])`` when neither variable leads to one.
    """
    for var_name in ("CUDA_PATH", "CUDA_HOME"):
        location = os.getenv(var_name)
        if location is None:
            continue

        bin_dir, inc_dirs, lib_dirs = check_and_find_cuda(Path(location))
        if bin_dir:
            return bin_dir, inc_dirs, lib_dirs

    return None, [], []
321
+
322
+
323
def find_cuda_bundled() -> tuple[Optional[str], list[str], list[str]]:
    """Find the CUDA files under ``triton/backends/nvidia`` in ``platlib``.

    Returns:
        The ``check_and_find_cuda`` result for that directory.
    """
    site_packages = Path(sysconfig.get_paths()["platlib"])
    bundled_root = site_packages.joinpath("triton", "backends", "nvidia")
    return check_and_find_cuda(bundled_root)
328
+
329
+
330
def find_cuda_pip() -> tuple[Optional[str], list[str], list[str]]:
    """Find CUDA under the ``nvidia`` directory in ``platlib``.

    Returns:
        The ``check_and_find_cuda`` result for that directory.
    """
    site_packages = Path(sysconfig.get_paths()["platlib"])
    return check_and_find_cuda(site_packages / "nvidia")
333
+
334
+
335
def find_cuda_conda() -> tuple[Optional[str], list[str], list[str]]:
    """Find CUDA under the ``Library`` directory of ``sys.exec_prefix``.

    Returns:
        The ``check_and_find_cuda`` result for that directory.
    """
    library_root = Path(sys.exec_prefix).joinpath("Library")
    return check_and_find_cuda(library_root)
338
+
339
+
340
def find_cuda_hardcoded() -> tuple[Optional[str], list[str], list[str]]:
    """Find a CUDA 12 toolkit at its conventional Program Files location.

    Globs ``<Program Files>/NVIDIA GPU Computing Toolkit/CUDA/v12*`` and
    walks the matches in reverse lexicographic order, returning the first one
    that matches a known layout.

    Returns:
        The ``check_and_find_cuda`` result for the first usable directory, or
        ``(None, [], [])`` when there is none.
    """
    toolkit_parent = find_in_program_files(r"NVIDIA GPU Computing Toolkit\CUDA")
    if toolkit_parent is None:
        return None, [], []

    # Reverse order so that the highest-sorting directory name is tried first.
    for candidate in sorted(glob(str(toolkit_parent / "v12*")), reverse=True):
        bin_dir, inc_dirs, lib_dirs = check_and_find_cuda(Path(candidate))
        if bin_dir:
            return bin_dir, inc_dirs, lib_dirs

    return None, [], []
357
+
358
+
359
@functools.cache
def find_cuda() -> tuple[Optional[str], list[str], list[str]]:
    """Return the CUDA binary, include and library directories.

    Discovery methods are tried in order: environment variables, the copy
    bundled with the package, pip packages, conda, and finally the
    Program Files location. The result is cached.

    Returns:
        ``(bin_dir, include_dirs, lib_dirs)`` from the first method that
        yields a binary directory. If none does, a warning is issued and
        ``(None, [], [])`` is returned.
    """
    strategies = (
        find_cuda_env,
        find_cuda_bundled,
        find_cuda_pip,
        find_cuda_conda,
        find_cuda_hardcoded,
    )
    for strategy in strategies:
        bin_dir, inc_dirs, lib_dirs = strategy()
        if bin_dir:
            return bin_dir, inc_dirs, lib_dirs

    warnings.warn("Failed to find CUDA.")
    return None, [], []
@@ -0,0 +1,41 @@
1
+ Metadata-Version: 2.4
2
+ Name: triton-windows
3
+ Version: 3.1.0.post17
4
+ Summary: A language and compiler for custom Deep Learning operations
5
+ Home-page: https://github.com/woct0rdho/triton-windows
6
+ Author: Philippe Tillet, Dian Wu
7
+ Author-email: phil@openai.com, woctordho@outlook.com
8
+ Keywords: Compiler,Deep Learning
9
+ Classifier: Development Status :: 4 - Beta
10
+ Classifier: Intended Audience :: Developers
11
+ Classifier: Topic :: Software Development :: Build Tools
12
+ Classifier: License :: OSI Approved :: MIT License
13
+ Classifier: Programming Language :: Python :: 3.8
14
+ Classifier: Programming Language :: Python :: 3.9
15
+ Classifier: Programming Language :: Python :: 3.10
16
+ Classifier: Programming Language :: Python :: 3.11
17
+ Classifier: Programming Language :: Python :: 3.12
18
+ Requires-Dist: filelock
19
+ Provides-Extra: build
20
+ Requires-Dist: cmake>=3.20; extra == "build"
21
+ Requires-Dist: lit; extra == "build"
22
+ Provides-Extra: tests
23
+ Requires-Dist: autopep8; extra == "tests"
24
+ Requires-Dist: flake8; extra == "tests"
25
+ Requires-Dist: isort; extra == "tests"
26
+ Requires-Dist: numpy; extra == "tests"
27
+ Requires-Dist: pytest; extra == "tests"
28
+ Requires-Dist: scipy>=1.7.1; extra == "tests"
29
+ Requires-Dist: llnl-hatchet; extra == "tests"
30
+ Provides-Extra: tutorials
31
+ Requires-Dist: matplotlib; extra == "tutorials"
32
+ Requires-Dist: pandas; extra == "tutorials"
33
+ Requires-Dist: tabulate; extra == "tutorials"
34
+ Dynamic: author
35
+ Dynamic: author-email
36
+ Dynamic: classifier
37
+ Dynamic: home-page
38
+ Dynamic: keywords
39
+ Dynamic: provides-extra
40
+ Dynamic: requires-dist
41
+ Dynamic: summary