triton-windows 3.5.1.post21__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (217) hide show
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +82 -0
  3. triton/_filecheck.py +97 -0
  4. triton/_internal_testing.py +255 -0
  5. triton/_utils.py +126 -0
  6. triton/backends/__init__.py +47 -0
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +461 -0
  9. triton/backends/amd/driver.c +283 -0
  10. triton/backends/amd/driver.py +724 -0
  11. triton/backends/amd/lib/asanrtl.bc +0 -0
  12. triton/backends/amd/lib/ockl.bc +0 -0
  13. triton/backends/amd/lib/ocml.bc +0 -0
  14. triton/backends/compiler.py +90 -0
  15. triton/backends/driver.py +66 -0
  16. triton/backends/nvidia/__init__.py +0 -0
  17. triton/backends/nvidia/bin/ptxas.exe +0 -0
  18. triton/backends/nvidia/compiler.py +533 -0
  19. triton/backends/nvidia/driver.c +517 -0
  20. triton/backends/nvidia/driver.py +799 -0
  21. triton/backends/nvidia/include/cuda.h +26280 -0
  22. triton/backends/nvidia/lib/libdevice.10.bc +0 -0
  23. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  24. triton/compiler/__init__.py +7 -0
  25. triton/compiler/code_generator.py +1614 -0
  26. triton/compiler/compiler.py +509 -0
  27. triton/compiler/errors.py +51 -0
  28. triton/compiler/make_launcher.py +0 -0
  29. triton/errors.py +5 -0
  30. triton/experimental/__init__.py +0 -0
  31. triton/experimental/gluon/__init__.py +5 -0
  32. triton/experimental/gluon/_compiler.py +0 -0
  33. triton/experimental/gluon/_runtime.py +102 -0
  34. triton/experimental/gluon/language/__init__.py +119 -0
  35. triton/experimental/gluon/language/_core.py +490 -0
  36. triton/experimental/gluon/language/_layouts.py +583 -0
  37. triton/experimental/gluon/language/_math.py +20 -0
  38. triton/experimental/gluon/language/_semantic.py +380 -0
  39. triton/experimental/gluon/language/_standard.py +80 -0
  40. triton/experimental/gluon/language/amd/__init__.py +4 -0
  41. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  42. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  43. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  44. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  45. triton/experimental/gluon/language/extra/__init__.py +3 -0
  46. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  47. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  48. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  49. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  50. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
  51. triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
  52. triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
  53. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
  54. triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
  55. triton/experimental/gluon/nvidia/__init__.py +4 -0
  56. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  57. triton/experimental/gluon/nvidia/hopper.py +45 -0
  58. triton/knobs.py +546 -0
  59. triton/language/__init__.py +342 -0
  60. triton/language/core.py +3405 -0
  61. triton/language/extra/__init__.py +26 -0
  62. triton/language/extra/cuda/__init__.py +16 -0
  63. triton/language/extra/cuda/gdc.py +42 -0
  64. triton/language/extra/cuda/libdevice.py +1629 -0
  65. triton/language/extra/cuda/utils.py +109 -0
  66. triton/language/extra/hip/__init__.py +5 -0
  67. triton/language/extra/hip/libdevice.py +491 -0
  68. triton/language/extra/hip/utils.py +35 -0
  69. triton/language/extra/libdevice.py +790 -0
  70. triton/language/math.py +249 -0
  71. triton/language/random.py +218 -0
  72. triton/language/semantic.py +1939 -0
  73. triton/language/standard.py +534 -0
  74. triton/language/target_info.py +54 -0
  75. triton/runtime/__init__.py +23 -0
  76. triton/runtime/_allocation.py +44 -0
  77. triton/runtime/_async_compile.py +55 -0
  78. triton/runtime/autotuner.py +476 -0
  79. triton/runtime/build.py +168 -0
  80. triton/runtime/cache.py +317 -0
  81. triton/runtime/driver.py +38 -0
  82. triton/runtime/errors.py +36 -0
  83. triton/runtime/interpreter.py +1414 -0
  84. triton/runtime/jit.py +1107 -0
  85. triton/runtime/tcc/include/_mingw.h +168 -0
  86. triton/runtime/tcc/include/assert.h +62 -0
  87. triton/runtime/tcc/include/conio.h +409 -0
  88. triton/runtime/tcc/include/ctype.h +281 -0
  89. triton/runtime/tcc/include/dir.h +31 -0
  90. triton/runtime/tcc/include/direct.h +68 -0
  91. triton/runtime/tcc/include/dirent.h +135 -0
  92. triton/runtime/tcc/include/dos.h +55 -0
  93. triton/runtime/tcc/include/errno.h +75 -0
  94. triton/runtime/tcc/include/excpt.h +123 -0
  95. triton/runtime/tcc/include/fcntl.h +52 -0
  96. triton/runtime/tcc/include/fenv.h +108 -0
  97. triton/runtime/tcc/include/float.h +75 -0
  98. triton/runtime/tcc/include/inttypes.h +297 -0
  99. triton/runtime/tcc/include/io.h +418 -0
  100. triton/runtime/tcc/include/iso646.h +36 -0
  101. triton/runtime/tcc/include/limits.h +116 -0
  102. triton/runtime/tcc/include/locale.h +91 -0
  103. triton/runtime/tcc/include/malloc.h +181 -0
  104. triton/runtime/tcc/include/math.h +497 -0
  105. triton/runtime/tcc/include/mem.h +13 -0
  106. triton/runtime/tcc/include/memory.h +40 -0
  107. triton/runtime/tcc/include/process.h +176 -0
  108. triton/runtime/tcc/include/sec_api/conio_s.h +42 -0
  109. triton/runtime/tcc/include/sec_api/crtdbg_s.h +19 -0
  110. triton/runtime/tcc/include/sec_api/io_s.h +33 -0
  111. triton/runtime/tcc/include/sec_api/mbstring_s.h +52 -0
  112. triton/runtime/tcc/include/sec_api/search_s.h +25 -0
  113. triton/runtime/tcc/include/sec_api/stdio_s.h +145 -0
  114. triton/runtime/tcc/include/sec_api/stdlib_s.h +67 -0
  115. triton/runtime/tcc/include/sec_api/stralign_s.h +30 -0
  116. triton/runtime/tcc/include/sec_api/string_s.h +41 -0
  117. triton/runtime/tcc/include/sec_api/sys/timeb_s.h +34 -0
  118. triton/runtime/tcc/include/sec_api/tchar_s.h +266 -0
  119. triton/runtime/tcc/include/sec_api/time_s.h +61 -0
  120. triton/runtime/tcc/include/sec_api/wchar_s.h +128 -0
  121. triton/runtime/tcc/include/setjmp.h +160 -0
  122. triton/runtime/tcc/include/share.h +28 -0
  123. triton/runtime/tcc/include/signal.h +63 -0
  124. triton/runtime/tcc/include/stdalign.h +16 -0
  125. triton/runtime/tcc/include/stdarg.h +14 -0
  126. triton/runtime/tcc/include/stdatomic.h +171 -0
  127. triton/runtime/tcc/include/stdbool.h +11 -0
  128. triton/runtime/tcc/include/stddef.h +42 -0
  129. triton/runtime/tcc/include/stdint.h +212 -0
  130. triton/runtime/tcc/include/stdio.h +429 -0
  131. triton/runtime/tcc/include/stdlib.h +591 -0
  132. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  133. triton/runtime/tcc/include/string.h +164 -0
  134. triton/runtime/tcc/include/sys/fcntl.h +13 -0
  135. triton/runtime/tcc/include/sys/file.h +14 -0
  136. triton/runtime/tcc/include/sys/locking.h +30 -0
  137. triton/runtime/tcc/include/sys/stat.h +290 -0
  138. triton/runtime/tcc/include/sys/time.h +69 -0
  139. triton/runtime/tcc/include/sys/timeb.h +133 -0
  140. triton/runtime/tcc/include/sys/types.h +123 -0
  141. triton/runtime/tcc/include/sys/unistd.h +14 -0
  142. triton/runtime/tcc/include/sys/utime.h +146 -0
  143. triton/runtime/tcc/include/tcc/tcc_libm.h +618 -0
  144. triton/runtime/tcc/include/tccdefs.h +342 -0
  145. triton/runtime/tcc/include/tcclib.h +80 -0
  146. triton/runtime/tcc/include/tchar.h +1102 -0
  147. triton/runtime/tcc/include/tgmath.h +89 -0
  148. triton/runtime/tcc/include/time.h +287 -0
  149. triton/runtime/tcc/include/uchar.h +33 -0
  150. triton/runtime/tcc/include/unistd.h +1 -0
  151. triton/runtime/tcc/include/vadefs.h +11 -0
  152. triton/runtime/tcc/include/values.h +4 -0
  153. triton/runtime/tcc/include/varargs.h +12 -0
  154. triton/runtime/tcc/include/wchar.h +873 -0
  155. triton/runtime/tcc/include/wctype.h +172 -0
  156. triton/runtime/tcc/include/winapi/basetsd.h +149 -0
  157. triton/runtime/tcc/include/winapi/basetyps.h +85 -0
  158. triton/runtime/tcc/include/winapi/guiddef.h +156 -0
  159. triton/runtime/tcc/include/winapi/poppack.h +8 -0
  160. triton/runtime/tcc/include/winapi/pshpack1.h +8 -0
  161. triton/runtime/tcc/include/winapi/pshpack2.h +8 -0
  162. triton/runtime/tcc/include/winapi/pshpack4.h +8 -0
  163. triton/runtime/tcc/include/winapi/pshpack8.h +8 -0
  164. triton/runtime/tcc/include/winapi/qos.h +72 -0
  165. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  166. triton/runtime/tcc/include/winapi/winbase.h +2958 -0
  167. triton/runtime/tcc/include/winapi/wincon.h +309 -0
  168. triton/runtime/tcc/include/winapi/windef.h +293 -0
  169. triton/runtime/tcc/include/winapi/windows.h +127 -0
  170. triton/runtime/tcc/include/winapi/winerror.h +3166 -0
  171. triton/runtime/tcc/include/winapi/wingdi.h +4080 -0
  172. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  173. triton/runtime/tcc/include/winapi/winnt.h +5837 -0
  174. triton/runtime/tcc/include/winapi/winreg.h +272 -0
  175. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  176. triton/runtime/tcc/include/winapi/winuser.h +5651 -0
  177. triton/runtime/tcc/include/winapi/winver.h +160 -0
  178. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  179. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  180. triton/runtime/tcc/lib/cuda.def +697 -0
  181. triton/runtime/tcc/lib/gdi32.def +337 -0
  182. triton/runtime/tcc/lib/kernel32.def +770 -0
  183. triton/runtime/tcc/lib/libtcc1.a +0 -0
  184. triton/runtime/tcc/lib/msvcrt.def +1399 -0
  185. triton/runtime/tcc/lib/python3.def +810 -0
  186. triton/runtime/tcc/lib/python310.def +1610 -0
  187. triton/runtime/tcc/lib/python311.def +1633 -0
  188. triton/runtime/tcc/lib/python312.def +1703 -0
  189. triton/runtime/tcc/lib/python313.def +1651 -0
  190. triton/runtime/tcc/lib/python313t.def +1656 -0
  191. triton/runtime/tcc/lib/python314.def +1800 -0
  192. triton/runtime/tcc/lib/python314t.def +1809 -0
  193. triton/runtime/tcc/lib/python39.def +1644 -0
  194. triton/runtime/tcc/lib/python3t.def +905 -0
  195. triton/runtime/tcc/lib/user32.def +658 -0
  196. triton/runtime/tcc/libtcc.dll +0 -0
  197. triton/runtime/tcc/tcc.exe +0 -0
  198. triton/testing.py +543 -0
  199. triton/tools/__init__.py +0 -0
  200. triton/tools/build_extern.py +365 -0
  201. triton/tools/compile.py +210 -0
  202. triton/tools/disasm.py +143 -0
  203. triton/tools/extra/cuda/compile.c +70 -0
  204. triton/tools/extra/cuda/compile.h +14 -0
  205. triton/tools/extra/hip/compile.cpp +66 -0
  206. triton/tools/extra/hip/compile.h +13 -0
  207. triton/tools/link.py +322 -0
  208. triton/tools/mxfp.py +301 -0
  209. triton/tools/ragged_tma.py +92 -0
  210. triton/tools/tensor_descriptor.py +34 -0
  211. triton/windows_utils.py +405 -0
  212. triton_windows-3.5.1.post21.dist-info/METADATA +46 -0
  213. triton_windows-3.5.1.post21.dist-info/RECORD +217 -0
  214. triton_windows-3.5.1.post21.dist-info/WHEEL +5 -0
  215. triton_windows-3.5.1.post21.dist-info/entry_points.txt +3 -0
  216. triton_windows-3.5.1.post21.dist-info/licenses/LICENSE +23 -0
  217. triton_windows-3.5.1.post21.dist-info/top_level.txt +1 -0
@@ -0,0 +1,405 @@
1
+ import functools
2
+ import os
3
+ import re
4
+ import subprocess
5
+ import sys
6
+ import sysconfig
7
+ import warnings
8
+ import winreg
9
+ from collections.abc import Iterable
10
+ from functools import partial
11
+ from glob import glob
12
+ from pathlib import Path
13
+ from typing import Callable, Optional
14
+
15
+
16
def find_in_program_files(rel_path: str) -> Optional[Path]:
    """Locate ``rel_path`` under one of the Windows "Program Files" roots.

    The root named by ``ProgramFiles(x86)`` is probed first, then the one
    named by ``ProgramW6432``; each falls back to its conventional location
    on drive C when the environment variable is unset.

    Returns:
        The first existing path, or ``None`` if neither root contains it.
    """
    roots = (
        ("ProgramFiles(x86)", r"C:\Program Files (x86)"),
        ("ProgramW6432", r"C:\Program Files"),
    )
    for env_name, fallback in roots:
        candidate = Path(os.getenv(env_name, fallback)) / rel_path
        if candidate.exists():
            return candidate
    return None
28
+
29
+
30
def parse_version(s: str, prefix: str = "") -> Optional[tuple[int, ...]]:
    """Parse a dotted version string into a tuple of integers.

    Args:
        s: Version text such as ``"14.29.30133"``.
        prefix: Optional leading text stripped from ``s`` before parsing.

    Returns:
        The numeric components, or ``None`` when any component is not an
        integer.
    """
    fields = s.removeprefix(prefix).split(".")
    try:
        return tuple(map(int, fields))
    except ValueError:
        return None
36
+
37
+
38
def unparse_version(t: Iterable[int], prefix: str = "") -> str:
    """Render version components as ``prefix`` followed by a dotted string."""
    dotted = ".".join(str(component) for component in t)
    return prefix + dotted
40
+
41
+
42
def max_version(
    versions: Iterable[str],
    prefix: str = "",
    check: Callable[[str], bool] = lambda x: True,
) -> Optional[str]:
    """Return the highest version among ``versions`` that passes ``check``.

    Args:
        versions: Candidate version strings.
        prefix: Text stripped before parsing and re-added to the result.
        check: Predicate applied to each raw candidate; rejected candidates
            are ignored.

    Returns:
        The greatest accepted version, re-rendered from its numeric
        components, or ``None`` when no candidate is both accepted and
        parseable.
    """
    accepted = [name for name in versions if check(name)]
    parsed = [
        numbers
        for numbers in (parse_version(name, prefix) for name in accepted)
        if numbers is not None
    ]
    if not parsed:
        return None
    return unparse_version(max(parsed), prefix)
54
+
55
+
56
def check_msvc(msvc_base_path: Path, version: str) -> bool:
    """Report whether ``version`` under ``msvc_base_path`` has the x64 tools.

    The toolset is considered usable only if the compiler executable, the
    ``vcruntime.h`` header and the ``vcruntime.lib`` import library all exist.
    """
    toolset = msvc_base_path / version
    required = (
        toolset / "bin" / "Hostx64" / "x64" / "cl.exe",
        toolset / "include" / "vcruntime.h",
        toolset / "lib" / "x64" / "vcruntime.lib",
    )
    for path in required:
        if not path.exists():
            return False
    return True
62
+
63
+
64
def find_msvc_env() -> tuple[Optional[Path], Optional[str]]:
    """Locate MSVC from the ``VCINSTALLDIR`` / ``VCToolsVersion`` variables.

    Returns:
        ``(msvc_base_path, version)`` where ``msvc_base_path`` is the
        ``Tools/MSVC`` directory under ``VCINSTALLDIR``, or ``(None, None)``
        when ``VCINSTALLDIR`` is unset, ``VCToolsVersion`` is unset, or the
        installation they point at is incomplete. The last two cases also
        emit a warning.
    """
    vc_install_dir = os.getenv("VCINSTALLDIR")
    if vc_install_dir is None:
        return None, None
    msvc_base_path = Path(vc_install_dir) / "Tools" / "MSVC"

    version = os.getenv("VCToolsVersion")
    if version is None:
        # Without this guard ``None`` would reach the path join inside
        # check_msvc and raise TypeError instead of reporting the problem.
        warnings.warn(f"Environment variable VCINSTALLDIR = {vc_install_dir}, "
                      "but VCToolsVersion is not set.")
        return None, None

    if not check_msvc(msvc_base_path, version):
        warnings.warn(f"Environment variables VCINSTALLDIR = {os.getenv('VCINSTALLDIR')}, "
                      f"VCToolsVersion = {os.getenv('VCToolsVersion')} are set, "
                      "but this MSVC installation is incomplete.")
        return None, None

    return msvc_base_path, version
78
+
79
+
80
def find_msvc_vswhere() -> tuple[Optional[Path], Optional[str]]:
    """Locate MSVC by asking ``vswhere.exe`` for the latest installation.

    Only installations (including prereleases) that provide both the x86/x64
    VC tools component and the Windows 10 SDK component are requested.

    Returns:
        ``(msvc_base_path, version)`` for the highest complete toolset, or
        ``(None, None)`` when vswhere is missing, cannot be run, reports no
        installation, or no complete toolset is found.
    """
    vswhere_path = find_in_program_files(r"Microsoft Visual Studio\Installer\vswhere.exe")
    if vswhere_path is None:
        return None, None

    command = [
        str(vswhere_path),
        "-prerelease",
        "-products",
        "*",
        "-requires",
        "Microsoft.VisualStudio.Component.VC.Tools.x86.x64",
        "-requires",
        "Microsoft.VisualStudio.Component.Windows10SDK",
        "-latest",
        "-property",
        "installationPath",
    ]
    try:
        output = subprocess.check_output(command, text=True).strip()
    except (subprocess.CalledProcessError, OSError):
        # OSError covers a vswhere.exe that exists but cannot be launched.
        return None, None

    if not output:
        # No matching installation. Path("") would otherwise make the probe
        # below relative to the current working directory.
        return None, None

    msvc_base_path = Path(output) / "VC" / "Tools" / "MSVC"
    if not msvc_base_path.exists():
        return None, None

    version = max_version(os.listdir(msvc_base_path), check=partial(check_msvc, msvc_base_path))
    if version is None:
        return None, None

    return msvc_base_path, version
112
+
113
+
114
def find_msvc_envpath() -> tuple[Optional[Path], Optional[str]]:
    """Locate MSVC by scanning ``PATH`` for a ``VC/Tools/MSVC`` directory.

    Each ``PATH`` entry is normalised to backslashes and matched against the
    MSVC toolset layout; the first entry whose toolset root exists and holds
    a complete version wins.

    Returns:
        ``(msvc_base_path, version)``, or ``(None, None)`` if no entry
        qualifies.
    """
    toolset_root = re.compile(r".*\\VC\\Tools\\MSVC\\")
    for entry in os.getenv("PATH", "").split(os.pathsep):
        found = toolset_root.match(entry.replace("/", "\\"))
        if found is None:
            continue

        msvc_base_path = Path(found.group(0))
        if not msvc_base_path.exists():
            continue

        version = max_version(os.listdir(msvc_base_path), check=partial(check_msvc, msvc_base_path))
        if version is not None:
            return msvc_base_path, version

    return None, None
133
+
134
+
135
def find_msvc_hardcoded() -> tuple[Optional[Path], Optional[str]]:
    """Locate MSVC under the default "Microsoft Visual Studio" directory.

    Candidate toolset roots are matched two directory levels below the
    Visual Studio folder and tried in descending lexical order.

    Returns:
        ``(msvc_base_path, version)`` for the first root with a complete
        toolset, or ``(None, None)`` if there is none.
    """
    vs_root = find_in_program_files("Microsoft Visual Studio")
    if vs_root is None:
        return None, None

    pattern = str(vs_root / "*" / "*" / "VC" / "Tools" / "MSVC")
    # Descending order so the highest-sorting installation is tried first.
    for candidate in sorted(glob(pattern), reverse=True):
        msvc_base_path = Path(candidate)
        version = max_version(os.listdir(msvc_base_path), check=partial(check_msvc, msvc_base_path))
        if version is not None:
            return msvc_base_path, version

    return None, None
151
+
152
+
153
def find_msvc(env_only: bool) -> tuple[Optional[str], list[str], list[str]]:
    """Find the MSVC compiler and its include / library directories.

    Args:
        env_only: When true, consult only the environment-variable strategy
            and stay silent on failure; otherwise try every strategy in
            order and warn if all of them fail.

    Returns:
        ``(cl_exe_path, include_dirs, lib_dirs)``; on failure
        ``(None, [], [])``.
    """
    if env_only:
        finders = [find_msvc_env]
    else:
        finders = [find_msvc_env, find_msvc_vswhere, find_msvc_envpath, find_msvc_hardcoded]

    for finder in finders:
        msvc_base_path, version = finder()
        if not msvc_base_path:
            continue
        toolset = msvc_base_path / version
        cl_exe = toolset / "bin" / "Hostx64" / "x64" / "cl.exe"
        return str(cl_exe), [str(toolset / "include")], [str(toolset / "lib" / "x64")]

    if not env_only:
        warnings.warn("Failed to find MSVC.")
    return None, [], []
175
+
176
+
177
def check_winsdk(winsdk_base_path: Path, version: str) -> bool:
    """Report whether ``version`` of the Windows SDK has the x64 UCRT files.

    Both the ``stdlib.h`` UCRT header and the x64 ``ucrt.lib`` library must
    exist under ``winsdk_base_path``.
    """
    header = winsdk_base_path / "Include" / version / "ucrt" / "stdlib.h"
    library = winsdk_base_path / "Lib" / version / "ucrt" / "x64" / "ucrt.lib"
    return header.exists() and library.exists()
182
+
183
+
184
def find_winsdk_env() -> tuple[Optional[Path], Optional[str]]:
    """Locate the Windows SDK from ``WindowsSdkDir`` and its version variable.

    The version is read from ``WindowsSDKVersion``, falling back to
    ``WindowsSDKVer``; trailing backslashes are stripped from it.

    Returns:
        ``(winsdk_base_path, version)``, or ``(None, None)`` when
        ``WindowsSdkDir`` is unset, no version variable is set, or the SDK
        is incomplete. The last two cases also emit a warning.
    """
    sdk_dir = os.getenv("WindowsSdkDir")
    if sdk_dir is None:
        return None, None
    winsdk_base_path = Path(sdk_dir)

    for variable in ("WindowsSDKVersion", "WindowsSDKVer"):
        version = os.getenv(variable)
        if version is not None:
            break
    else:
        warnings.warn(f"Environment variable WindowsSdkDir = {winsdk_base_path}, "
                      "but WindowsSDKVersion (or WindowsSDKVer) is not set.")
        return None, None

    version = version.rstrip("\\")
    if check_winsdk(winsdk_base_path, version):
        return winsdk_base_path, version

    warnings.warn(f"Environment variables WindowsSdkDir = {winsdk_base_path}, "
                  f"WindowsSDKVersion (or WindowsSDKVer) = {version} are set, "
                  "but this Windows SDK installation is incomplete.")
    return None, None
205
+
206
+
207
def find_winsdk_registry() -> tuple[Optional[Path], Optional[str]]:
    """Locate the Windows SDK from its ``InstallationFolder`` registry value.

    Reads the ``v10.0`` SDK key under ``HKEY_LOCAL_MACHINE`` (WOW6432Node
    view) and picks the highest complete version found in the ``Include``
    directory of the folder it names.

    Returns:
        ``(winsdk_base_path, version)``, or ``(None, None)`` when the
        registry lookup fails, the ``Include`` directory is missing, or no
        complete SDK version exists.
    """
    try:
        # Context managers close both handles even if the value query fails.
        with winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE) as reg:
            with winreg.OpenKeyEx(reg, r"SOFTWARE\WOW6432Node\Microsoft\Microsoft SDKs\Windows\v10.0") as key:
                folder = winreg.QueryValueEx(key, "InstallationFolder")[0]
    except OSError:
        return None, None

    winsdk_base_path = Path(folder)
    if not (winsdk_base_path / "Include").exists():
        return None, None

    version = max_version(
        os.listdir(winsdk_base_path / "Include"),
        check=partial(check_winsdk, winsdk_base_path),
    )
    if version is None:
        return None, None

    return winsdk_base_path, version
228
+
229
+
230
def find_winsdk_hardcoded() -> tuple[Optional[Path], Optional[str]]:
    """Locate the Windows SDK in the default "Windows Kits/10" directory.

    Returns:
        ``(winsdk_base_path, version)`` for the highest complete SDK version
        listed under ``Include``, or ``(None, None)`` if the directory or a
        complete version is missing.
    """
    kit_root = find_in_program_files(r"Windows Kits\10")
    if kit_root is None:
        return None, None

    include_dir = kit_root / "Include"
    if not include_dir.exists():
        return None, None

    version = max_version(os.listdir(include_dir), check=partial(check_winsdk, kit_root))
    return (None, None) if version is None else (kit_root, version)
245
+
246
+
247
def find_winsdk(env_only: bool) -> tuple[list[str], list[str]]:
    """Find the Windows SDK include and library directories.

    Args:
        env_only: When true, consult only the environment-variable strategy
            and stay silent on failure; otherwise try every strategy in
            order and warn if all of them fail.

    Returns:
        ``(include_dirs, lib_dirs)`` covering the ``shared``, ``ucrt`` and
        ``um`` headers and the x64 ``ucrt`` and ``um`` libraries; on failure
        ``([], [])``.
    """
    if env_only:
        finders = [find_winsdk_env]
    else:
        finders = [find_winsdk_env, find_winsdk_registry, find_winsdk_hardcoded]

    for finder in finders:
        winsdk_base_path, version = finder()
        if not winsdk_base_path:
            continue
        include_root = winsdk_base_path / "Include" / version
        lib_root = winsdk_base_path / "Lib" / version
        include_dirs = [str(include_root / part) for part in ("shared", "ucrt", "um")]
        lib_dirs = [str(lib_root / part / "x64") for part in ("ucrt", "um")]
        return include_dirs, lib_dirs

    if not env_only:
        warnings.warn("Failed to find Windows SDK.")
    return [], []
274
+
275
+
276
@functools.lru_cache
def find_msvc_winsdk(env_only: bool = False) -> tuple[Optional[str], list[str], list[str]]:
    """Find MSVC together with the Windows SDK, caching the result.

    Args:
        env_only: Passed through to both ``find_msvc`` and ``find_winsdk``.

    Returns:
        ``(cl_exe_path, include_dirs, lib_dirs)`` where each list holds the
        MSVC directories followed by the Windows SDK directories.
    """
    cl_exe, toolset_includes, toolset_libs = find_msvc(env_only)
    sdk_includes, sdk_libs = find_winsdk(env_only)
    include_dirs = toolset_includes + sdk_includes
    lib_dirs = toolset_libs + sdk_libs
    return cl_exe, include_dirs, lib_dirs
285
+
286
+
287
@functools.lru_cache
def find_python() -> list[str]:
    """Find the directory holding this interpreter's import library.

    Looks for ``libs/python<XY>.lib`` (with a ``t`` suffix when
    ``Py_GIL_DISABLED`` is set) under ``sys.exec_prefix``,
    ``sys.base_exec_prefix`` and the executable's directory, in that order.
    The result is cached.

    Returns:
        A one-element list with the ``libs`` directory, or an empty list
        (after a warning) when the library is not found.
    """
    tag = sysconfig.get_python_version().replace(".", "")
    if sysconfig.get_config_var("Py_GIL_DISABLED"):
        tag += "t"
    library_name = f"python{tag}.lib"

    prefixes = (sys.exec_prefix, sys.base_exec_prefix, os.path.dirname(sys.executable))
    for prefix in prefixes:
        libs_dir = Path(prefix) / "libs"
        if (libs_dir / library_name).exists():
            return [str(libs_dir)]

    warnings.warn("Failed to find Python libs.")
    return []
303
+
304
+
305
def check_and_find_cuda(base_path: Path) -> tuple[Optional[str], list[str], list[str]]:
    """Probe ``base_path`` for a CUDA toolkit in one of three known layouts.

    A layout matches only if ``ptxas.exe``, ``cuda.h`` and ``cuda.lib`` are
    all present in its bin, include and lib directories respectively. The
    layouts are tried in order: pip, conda, then bundled / system-wide.

    Returns:
        ``(bin_dir, [include_dir], [lib_dir])`` for the first matching
        layout, or ``(None, [], [])`` when none matches.
    """
    layouts = (
        # pip
        (
            base_path / "cuda_nvcc" / "bin",
            base_path / "cuda_runtime" / "include",
            base_path / "cuda_runtime" / "lib" / "x64",
        ),
        # conda
        (base_path / "bin", base_path / "include", base_path / "lib"),
        # bundled or system-wide
        (base_path / "bin", base_path / "include", base_path / "lib" / "x64"),
    )
    for bin_dir, include_dir, lib_dir in layouts:
        probes = (bin_dir / "ptxas.exe", include_dir / "cuda.h", lib_dir / "cuda.lib")
        if all(probe.exists() for probe in probes):
            return str(bin_dir), [str(include_dir)], [str(lib_dir)]

    return None, [], []
343
+
344
+
345
def find_cuda_env() -> tuple[Optional[str], list[str], list[str]]:
    """Locate CUDA through the ``CUDA_PATH`` or ``CUDA_HOME`` variables.

    The variables are tried in that order; unset ones are skipped.

    Returns:
        ``(bin_dir, include_dirs, lib_dirs)`` for the first variable that
        points at a usable toolkit, or ``(None, [], [])``.
    """
    for variable in ("CUDA_PATH", "CUDA_HOME"):
        location = os.getenv(variable)
        if location is None:
            continue

        bin_dir, include_dirs, lib_dirs = check_and_find_cuda(Path(location))
        if bin_dir:
            return bin_dir, include_dirs, lib_dirs

    return None, [], []
357
+
358
+
359
def find_cuda_bundled() -> tuple[Optional[str], list[str], list[str]]:
    """Locate the CUDA toolchain bundled inside the ``triton`` package.

    The ``backends/nvidia`` directory next to this module is probed first,
    so the bundle is found wherever the package is installed (user site,
    ``--target`` directory, editable checkout). The ``platlib``
    site-packages location is kept as a fallback.

    Returns:
        ``(bin_dir, include_dirs, lib_dirs)``, or ``(None, [], [])`` when no
        bundled toolkit is found.
    """
    candidates = []
    try:
        # NOTE(review): assumes this module sits directly in the ``triton``
        # package directory — confirm if the module is ever relocated.
        candidates.append(Path(__file__).resolve().parent / "backends" / "nvidia")
    except NameError:
        # ``__file__`` can be undefined, e.g. in some frozen environments.
        pass
    candidates.append(Path(sysconfig.get_paths()["platlib"]) / "triton" / "backends" / "nvidia")

    for cuda_base_path in candidates:
        cuda_bin_path, cuda_inc_dirs, cuda_lib_dirs = check_and_find_cuda(cuda_base_path)
        if cuda_bin_path:
            return cuda_bin_path, cuda_inc_dirs, cuda_lib_dirs

    return None, [], []
362
+
363
+
364
def find_cuda_pip() -> tuple[Optional[str], list[str], list[str]]:
    """Locate CUDA under the ``nvidia`` directory of ``platlib`` site-packages."""
    site_packages = Path(sysconfig.get_paths()["platlib"])
    return check_and_find_cuda(site_packages / "nvidia")
367
+
368
+
369
def find_cuda_conda() -> tuple[Optional[str], list[str], list[str]]:
    """Locate CUDA under the ``Library`` directory of ``sys.exec_prefix``."""
    library_root = Path(sys.exec_prefix).joinpath("Library")
    return check_and_find_cuda(library_root)
372
+
373
+
374
def find_cuda_hardcoded() -> tuple[Optional[str], list[str], list[str]]:
    """Locate a CUDA 12 toolkit in the default NVIDIA install directory.

    Directories matching ``v12*`` are tried in descending lexical order.

    Returns:
        ``(bin_dir, include_dirs, lib_dirs)`` for the first usable toolkit,
        or ``(None, [], [])``.
    """
    toolkit_root = find_in_program_files(r"NVIDIA GPU Computing Toolkit\CUDA")
    if toolkit_root is None:
        return None, [], []

    # Descending order so the highest-sorting version is tried first.
    for candidate in sorted(glob(str(toolkit_root / "v12*")), reverse=True):
        bin_dir, include_dirs, lib_dirs = check_and_find_cuda(Path(candidate))
        if bin_dir:
            return bin_dir, include_dirs, lib_dirs

    return None, [], []
389
+
390
+
391
@functools.lru_cache
def find_cuda() -> tuple[Optional[str], list[str], list[str]]:
    """Find a CUDA toolkit, trying each discovery strategy in priority order.

    Order: environment variables, bundled copy, pip packages, conda
    environment, default install directory. The result is cached.

    Returns:
        ``(bin_dir, include_dirs, lib_dirs)`` from the first strategy that
        succeeds, or ``(None, [], [])`` after a warning.
    """
    strategies = (
        find_cuda_env,
        find_cuda_bundled,
        find_cuda_pip,
        find_cuda_conda,
        find_cuda_hardcoded,
    )
    for strategy in strategies:
        bin_dir, include_dirs, lib_dirs = strategy()
        if bin_dir:
            return bin_dir, include_dirs, lib_dirs

    warnings.warn("Failed to find CUDA.")
    return None, [], []
@@ -0,0 +1,46 @@
1
+ Metadata-Version: 2.4
2
+ Name: triton-windows
3
+ Version: 3.5.1.post21
4
+ Summary: A language and compiler for custom Deep Learning operations
5
+ Home-page: https://github.com/woct0rdho/triton-windows
6
+ Author: Philippe Tillet, Dian Wu
7
+ Author-email: phil@openai.com, woctordho@outlook.com
8
+ Keywords: Compiler,Deep Learning
9
+ Classifier: Development Status :: 4 - Beta
10
+ Classifier: Intended Audience :: Developers
11
+ Classifier: Topic :: Software Development :: Build Tools
12
+ Classifier: License :: OSI Approved :: MIT License
13
+ Classifier: Programming Language :: Python :: 3.10
14
+ Classifier: Programming Language :: Python :: 3.11
15
+ Classifier: Programming Language :: Python :: 3.12
16
+ Classifier: Programming Language :: Python :: 3.13
17
+ Classifier: Programming Language :: Python :: 3.14
18
+ Requires-Python: >=3.10,<3.15
19
+ License-File: LICENSE
20
+ Requires-Dist: importlib-metadata; python_version < "3.10"
21
+ Provides-Extra: build
22
+ Requires-Dist: cmake<4.0,>=3.20; extra == "build"
23
+ Requires-Dist: lit; extra == "build"
24
+ Provides-Extra: tests
25
+ Requires-Dist: autopep8; extra == "tests"
26
+ Requires-Dist: isort; extra == "tests"
27
+ Requires-Dist: numpy; extra == "tests"
28
+ Requires-Dist: pytest; extra == "tests"
29
+ Requires-Dist: pytest-forked; extra == "tests"
30
+ Requires-Dist: pytest-xdist; extra == "tests"
31
+ Requires-Dist: scipy>=1.7.1; extra == "tests"
32
+ Requires-Dist: llnl-hatchet; extra == "tests"
33
+ Provides-Extra: tutorials
34
+ Requires-Dist: matplotlib; extra == "tutorials"
35
+ Requires-Dist: pandas; extra == "tutorials"
36
+ Requires-Dist: tabulate; extra == "tutorials"
37
+ Dynamic: author
38
+ Dynamic: author-email
39
+ Dynamic: classifier
40
+ Dynamic: home-page
41
+ Dynamic: keywords
42
+ Dynamic: license-file
43
+ Dynamic: provides-extra
44
+ Dynamic: requires-dist
45
+ Dynamic: requires-python
46
+ Dynamic: summary