triton-windows 3.5.0.post21__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic. Click here for more details.

Files changed (217)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +82 -0
  3. triton/_filecheck.py +97 -0
  4. triton/_internal_testing.py +255 -0
  5. triton/_utils.py +126 -0
  6. triton/backends/__init__.py +47 -0
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +461 -0
  9. triton/backends/amd/driver.c +283 -0
  10. triton/backends/amd/driver.py +724 -0
  11. triton/backends/amd/lib/asanrtl.bc +0 -0
  12. triton/backends/amd/lib/ockl.bc +0 -0
  13. triton/backends/amd/lib/ocml.bc +0 -0
  14. triton/backends/compiler.py +90 -0
  15. triton/backends/driver.py +66 -0
  16. triton/backends/nvidia/__init__.py +0 -0
  17. triton/backends/nvidia/bin/ptxas.exe +0 -0
  18. triton/backends/nvidia/compiler.py +533 -0
  19. triton/backends/nvidia/driver.c +517 -0
  20. triton/backends/nvidia/driver.py +799 -0
  21. triton/backends/nvidia/include/cuda.h +26280 -0
  22. triton/backends/nvidia/lib/libdevice.10.bc +0 -0
  23. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  24. triton/compiler/__init__.py +7 -0
  25. triton/compiler/code_generator.py +1614 -0
  26. triton/compiler/compiler.py +509 -0
  27. triton/compiler/errors.py +51 -0
  28. triton/compiler/make_launcher.py +0 -0
  29. triton/errors.py +5 -0
  30. triton/experimental/__init__.py +0 -0
  31. triton/experimental/gluon/__init__.py +5 -0
  32. triton/experimental/gluon/_compiler.py +0 -0
  33. triton/experimental/gluon/_runtime.py +102 -0
  34. triton/experimental/gluon/language/__init__.py +119 -0
  35. triton/experimental/gluon/language/_core.py +490 -0
  36. triton/experimental/gluon/language/_layouts.py +583 -0
  37. triton/experimental/gluon/language/_math.py +20 -0
  38. triton/experimental/gluon/language/_semantic.py +380 -0
  39. triton/experimental/gluon/language/_standard.py +80 -0
  40. triton/experimental/gluon/language/amd/__init__.py +4 -0
  41. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  42. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  43. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  44. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  45. triton/experimental/gluon/language/extra/__init__.py +3 -0
  46. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  47. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  48. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  49. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  50. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
  51. triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
  52. triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
  53. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
  54. triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
  55. triton/experimental/gluon/nvidia/__init__.py +4 -0
  56. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  57. triton/experimental/gluon/nvidia/hopper.py +45 -0
  58. triton/knobs.py +546 -0
  59. triton/language/__init__.py +342 -0
  60. triton/language/core.py +3405 -0
  61. triton/language/extra/__init__.py +26 -0
  62. triton/language/extra/cuda/__init__.py +16 -0
  63. triton/language/extra/cuda/gdc.py +42 -0
  64. triton/language/extra/cuda/libdevice.py +1629 -0
  65. triton/language/extra/cuda/utils.py +109 -0
  66. triton/language/extra/hip/__init__.py +5 -0
  67. triton/language/extra/hip/libdevice.py +491 -0
  68. triton/language/extra/hip/utils.py +35 -0
  69. triton/language/extra/libdevice.py +790 -0
  70. triton/language/math.py +249 -0
  71. triton/language/random.py +218 -0
  72. triton/language/semantic.py +1939 -0
  73. triton/language/standard.py +534 -0
  74. triton/language/target_info.py +54 -0
  75. triton/runtime/__init__.py +23 -0
  76. triton/runtime/_allocation.py +44 -0
  77. triton/runtime/_async_compile.py +55 -0
  78. triton/runtime/autotuner.py +476 -0
  79. triton/runtime/build.py +168 -0
  80. triton/runtime/cache.py +317 -0
  81. triton/runtime/driver.py +38 -0
  82. triton/runtime/errors.py +36 -0
  83. triton/runtime/interpreter.py +1414 -0
  84. triton/runtime/jit.py +1107 -0
  85. triton/runtime/tcc/include/_mingw.h +168 -0
  86. triton/runtime/tcc/include/assert.h +62 -0
  87. triton/runtime/tcc/include/conio.h +409 -0
  88. triton/runtime/tcc/include/ctype.h +281 -0
  89. triton/runtime/tcc/include/dir.h +31 -0
  90. triton/runtime/tcc/include/direct.h +68 -0
  91. triton/runtime/tcc/include/dirent.h +135 -0
  92. triton/runtime/tcc/include/dos.h +55 -0
  93. triton/runtime/tcc/include/errno.h +75 -0
  94. triton/runtime/tcc/include/excpt.h +123 -0
  95. triton/runtime/tcc/include/fcntl.h +52 -0
  96. triton/runtime/tcc/include/fenv.h +108 -0
  97. triton/runtime/tcc/include/float.h +75 -0
  98. triton/runtime/tcc/include/inttypes.h +297 -0
  99. triton/runtime/tcc/include/io.h +418 -0
  100. triton/runtime/tcc/include/iso646.h +36 -0
  101. triton/runtime/tcc/include/limits.h +116 -0
  102. triton/runtime/tcc/include/locale.h +91 -0
  103. triton/runtime/tcc/include/malloc.h +181 -0
  104. triton/runtime/tcc/include/math.h +497 -0
  105. triton/runtime/tcc/include/mem.h +13 -0
  106. triton/runtime/tcc/include/memory.h +40 -0
  107. triton/runtime/tcc/include/process.h +176 -0
  108. triton/runtime/tcc/include/sec_api/conio_s.h +42 -0
  109. triton/runtime/tcc/include/sec_api/crtdbg_s.h +19 -0
  110. triton/runtime/tcc/include/sec_api/io_s.h +33 -0
  111. triton/runtime/tcc/include/sec_api/mbstring_s.h +52 -0
  112. triton/runtime/tcc/include/sec_api/search_s.h +25 -0
  113. triton/runtime/tcc/include/sec_api/stdio_s.h +145 -0
  114. triton/runtime/tcc/include/sec_api/stdlib_s.h +67 -0
  115. triton/runtime/tcc/include/sec_api/stralign_s.h +30 -0
  116. triton/runtime/tcc/include/sec_api/string_s.h +41 -0
  117. triton/runtime/tcc/include/sec_api/sys/timeb_s.h +34 -0
  118. triton/runtime/tcc/include/sec_api/tchar_s.h +266 -0
  119. triton/runtime/tcc/include/sec_api/time_s.h +61 -0
  120. triton/runtime/tcc/include/sec_api/wchar_s.h +128 -0
  121. triton/runtime/tcc/include/setjmp.h +160 -0
  122. triton/runtime/tcc/include/share.h +28 -0
  123. triton/runtime/tcc/include/signal.h +63 -0
  124. triton/runtime/tcc/include/stdalign.h +16 -0
  125. triton/runtime/tcc/include/stdarg.h +14 -0
  126. triton/runtime/tcc/include/stdatomic.h +171 -0
  127. triton/runtime/tcc/include/stdbool.h +11 -0
  128. triton/runtime/tcc/include/stddef.h +42 -0
  129. triton/runtime/tcc/include/stdint.h +212 -0
  130. triton/runtime/tcc/include/stdio.h +429 -0
  131. triton/runtime/tcc/include/stdlib.h +591 -0
  132. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  133. triton/runtime/tcc/include/string.h +164 -0
  134. triton/runtime/tcc/include/sys/fcntl.h +13 -0
  135. triton/runtime/tcc/include/sys/file.h +14 -0
  136. triton/runtime/tcc/include/sys/locking.h +30 -0
  137. triton/runtime/tcc/include/sys/stat.h +290 -0
  138. triton/runtime/tcc/include/sys/time.h +69 -0
  139. triton/runtime/tcc/include/sys/timeb.h +133 -0
  140. triton/runtime/tcc/include/sys/types.h +123 -0
  141. triton/runtime/tcc/include/sys/unistd.h +14 -0
  142. triton/runtime/tcc/include/sys/utime.h +146 -0
  143. triton/runtime/tcc/include/tcc/tcc_libm.h +618 -0
  144. triton/runtime/tcc/include/tccdefs.h +342 -0
  145. triton/runtime/tcc/include/tcclib.h +80 -0
  146. triton/runtime/tcc/include/tchar.h +1102 -0
  147. triton/runtime/tcc/include/tgmath.h +89 -0
  148. triton/runtime/tcc/include/time.h +287 -0
  149. triton/runtime/tcc/include/uchar.h +33 -0
  150. triton/runtime/tcc/include/unistd.h +1 -0
  151. triton/runtime/tcc/include/vadefs.h +11 -0
  152. triton/runtime/tcc/include/values.h +4 -0
  153. triton/runtime/tcc/include/varargs.h +12 -0
  154. triton/runtime/tcc/include/wchar.h +873 -0
  155. triton/runtime/tcc/include/wctype.h +172 -0
  156. triton/runtime/tcc/include/winapi/basetsd.h +149 -0
  157. triton/runtime/tcc/include/winapi/basetyps.h +85 -0
  158. triton/runtime/tcc/include/winapi/guiddef.h +156 -0
  159. triton/runtime/tcc/include/winapi/poppack.h +8 -0
  160. triton/runtime/tcc/include/winapi/pshpack1.h +8 -0
  161. triton/runtime/tcc/include/winapi/pshpack2.h +8 -0
  162. triton/runtime/tcc/include/winapi/pshpack4.h +8 -0
  163. triton/runtime/tcc/include/winapi/pshpack8.h +8 -0
  164. triton/runtime/tcc/include/winapi/qos.h +72 -0
  165. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  166. triton/runtime/tcc/include/winapi/winbase.h +2958 -0
  167. triton/runtime/tcc/include/winapi/wincon.h +309 -0
  168. triton/runtime/tcc/include/winapi/windef.h +293 -0
  169. triton/runtime/tcc/include/winapi/windows.h +127 -0
  170. triton/runtime/tcc/include/winapi/winerror.h +3166 -0
  171. triton/runtime/tcc/include/winapi/wingdi.h +4080 -0
  172. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  173. triton/runtime/tcc/include/winapi/winnt.h +5837 -0
  174. triton/runtime/tcc/include/winapi/winreg.h +272 -0
  175. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  176. triton/runtime/tcc/include/winapi/winuser.h +5651 -0
  177. triton/runtime/tcc/include/winapi/winver.h +160 -0
  178. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  179. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  180. triton/runtime/tcc/lib/cuda.def +697 -0
  181. triton/runtime/tcc/lib/gdi32.def +337 -0
  182. triton/runtime/tcc/lib/kernel32.def +770 -0
  183. triton/runtime/tcc/lib/libtcc1.a +0 -0
  184. triton/runtime/tcc/lib/msvcrt.def +1399 -0
  185. triton/runtime/tcc/lib/python3.def +810 -0
  186. triton/runtime/tcc/lib/python310.def +1610 -0
  187. triton/runtime/tcc/lib/python311.def +1633 -0
  188. triton/runtime/tcc/lib/python312.def +1703 -0
  189. triton/runtime/tcc/lib/python313.def +1651 -0
  190. triton/runtime/tcc/lib/python313t.def +1656 -0
  191. triton/runtime/tcc/lib/python314.def +1800 -0
  192. triton/runtime/tcc/lib/python314t.def +1809 -0
  193. triton/runtime/tcc/lib/python39.def +1644 -0
  194. triton/runtime/tcc/lib/python3t.def +905 -0
  195. triton/runtime/tcc/lib/user32.def +658 -0
  196. triton/runtime/tcc/libtcc.dll +0 -0
  197. triton/runtime/tcc/tcc.exe +0 -0
  198. triton/testing.py +543 -0
  199. triton/tools/__init__.py +0 -0
  200. triton/tools/build_extern.py +365 -0
  201. triton/tools/compile.py +210 -0
  202. triton/tools/disasm.py +143 -0
  203. triton/tools/extra/cuda/compile.c +70 -0
  204. triton/tools/extra/cuda/compile.h +14 -0
  205. triton/tools/extra/hip/compile.cpp +66 -0
  206. triton/tools/extra/hip/compile.h +13 -0
  207. triton/tools/link.py +322 -0
  208. triton/tools/mxfp.py +301 -0
  209. triton/tools/ragged_tma.py +92 -0
  210. triton/tools/tensor_descriptor.py +34 -0
  211. triton/windows_utils.py +405 -0
  212. triton_windows-3.5.0.post21.dist-info/METADATA +46 -0
  213. triton_windows-3.5.0.post21.dist-info/RECORD +217 -0
  214. triton_windows-3.5.0.post21.dist-info/WHEEL +5 -0
  215. triton_windows-3.5.0.post21.dist-info/entry_points.txt +3 -0
  216. triton_windows-3.5.0.post21.dist-info/licenses/LICENSE +23 -0
  217. triton_windows-3.5.0.post21.dist-info/top_level.txt +1 -0
@@ -0,0 +1,342 @@
1
+ """isort:skip_file"""
2
+ # Import order is significant here.
3
+
4
+ from . import math
5
+ from . import extra
6
+ from .standard import (
7
+ argmax,
8
+ argmin,
9
+ bitonic_merge,
10
+ cdiv,
11
+ cumprod,
12
+ cumsum,
13
+ flip,
14
+ interleave,
15
+ max,
16
+ min,
17
+ ravel,
18
+ reduce_or,
19
+ sigmoid,
20
+ softmax,
21
+ sort,
22
+ sum,
23
+ swizzle2d,
24
+ topk,
25
+ xor_sum,
26
+ zeros,
27
+ zeros_like,
28
+ )
29
+ from .core import (
30
+ PropagateNan,
31
+ TRITON_MAX_TENSOR_NUMEL,
32
+ load_tensor_descriptor,
33
+ store_tensor_descriptor,
34
+ make_tensor_descriptor,
35
+ tensor_descriptor,
36
+ tensor_descriptor_type,
37
+ add,
38
+ advance,
39
+ arange,
40
+ associative_scan,
41
+ assume,
42
+ async_task,
43
+ atomic_add,
44
+ atomic_and,
45
+ atomic_cas,
46
+ atomic_max,
47
+ atomic_min,
48
+ atomic_or,
49
+ atomic_xchg,
50
+ atomic_xor,
51
+ bfloat16,
52
+ block_type,
53
+ broadcast,
54
+ broadcast_to,
55
+ cat,
56
+ cast,
57
+ clamp,
58
+ condition,
59
+ const,
60
+ constexpr,
61
+ constexpr_type,
62
+ debug_barrier,
63
+ device_assert,
64
+ device_print,
65
+ dot,
66
+ dot_scaled,
67
+ dtype,
68
+ expand_dims,
69
+ float16,
70
+ float32,
71
+ float64,
72
+ float8e4b15,
73
+ float8e4nv,
74
+ float8e4b8,
75
+ float8e5,
76
+ float8e5b16,
77
+ full,
78
+ gather,
79
+ histogram,
80
+ inline_asm_elementwise,
81
+ int1,
82
+ int16,
83
+ int32,
84
+ int64,
85
+ int8,
86
+ join,
87
+ load,
88
+ make_block_ptr,
89
+ map_elementwise,
90
+ max_constancy,
91
+ max_contiguous,
92
+ maximum,
93
+ minimum,
94
+ multiple_of,
95
+ num_programs,
96
+ permute,
97
+ pi32_t,
98
+ pointer_type,
99
+ program_id,
100
+ range,
101
+ reduce,
102
+ reshape,
103
+ slice,
104
+ split,
105
+ static_assert,
106
+ static_print,
107
+ static_range,
108
+ store,
109
+ tensor,
110
+ trans,
111
+ tuple,
112
+ tuple_type,
113
+ uint16,
114
+ uint32,
115
+ uint64,
116
+ uint8,
117
+ view,
118
+ void,
119
+ where,
120
+ )
121
+ from .math import (umulhi, exp, exp2, fma, log, log2, cos, rsqrt, sin, sqrt, sqrt_rn, abs, fdiv, div_rn, erf, floor,
122
+ ceil)
123
+ from .random import (
124
+ pair_uniform_to_normal,
125
+ philox,
126
+ philox_impl,
127
+ rand,
128
+ rand4x,
129
+ randint,
130
+ randint4x,
131
+ randn,
132
+ randn4x,
133
+ uint_to_uniform_float,
134
+ )
135
+ from . import target_info
136
+
137
# Names exported by ``from triton.language import *``.  The list is mostly
# alphabetical ("slice" precedes "sigmoid"); entries are grouped by leading
# letter purely for readability -- the sequence itself is significant only as
# the set of public names.
# NOTE(review): ``tensor_descriptor_type`` and ``tuple_type`` are imported from
# ``.core`` in this module but are not listed here -- confirm that is intended.
__all__ = [
    "PropagateNan", "TRITON_MAX_TENSOR_NUMEL",
    "load_tensor_descriptor", "store_tensor_descriptor", "make_tensor_descriptor", "tensor_descriptor",
    "abs", "add", "advance", "arange", "argmax", "argmin", "associative_scan", "assume", "async_task",
    "atomic_add", "atomic_and", "atomic_cas", "atomic_max", "atomic_min", "atomic_or", "atomic_xchg",
    "atomic_xor",
    "bfloat16", "bitonic_merge", "block_type", "broadcast", "broadcast_to",
    "cat", "cast", "cdiv", "ceil", "clamp", "condition", "const", "constexpr", "constexpr_type", "cos",
    "cumprod", "cumsum",
    "debug_barrier", "device_assert", "device_print", "div_rn", "dot", "dot_scaled", "dtype",
    "erf", "exp", "exp2", "expand_dims", "extra",
    "fdiv", "flip", "float16", "float32", "float64", "float8e4b15", "float8e4nv", "float8e4b8", "float8e5",
    "float8e5b16", "floor", "fma", "full",
    "gather", "histogram",
    "inline_asm_elementwise", "interleave", "int1", "int16", "int32", "int64", "int8",
    "join", "load", "log", "log2",
    "make_block_ptr", "map_elementwise", "math", "max", "max_constancy", "max_contiguous", "maximum", "min",
    "minimum", "multiple_of",
    "num_programs",
    "pair_uniform_to_normal", "permute", "philox", "philox_impl", "pi32_t", "pointer_type", "program_id",
    "rand", "rand4x", "randint", "randint4x", "randn", "randn4x", "range", "ravel", "reduce", "reduce_or",
    "reshape", "rsqrt",
    "slice", "sigmoid", "sin", "softmax", "sort", "split", "sqrt", "sqrt_rn", "static_assert", "static_print",
    "static_range", "store", "sum", "swizzle2d",
    "target_info", "tensor", "topk", "trans", "tuple",
    "uint16", "uint32", "uint64", "uint8", "uint_to_uniform_float", "umulhi",
    "view", "void", "where", "xor_sum", "zeros", "zeros_like",
]
278
+
279
+
280
def str_to_ty(name, c):
    """Translate a kernel-signature type descriptor into a Triton type.

    Accepted forms of ``name``:

    * a builtin ``tuple`` (including namedtuples) of descriptors, which becomes
      a ``tuple_type`` of the converted elements (namedtuple field names are
      preserved);
    * ``"*<ty>"`` / ``"*k<ty>"`` -- a pointer to ``<ty>``; the ``k`` marks the
      pointer as const;
    * ``"tensordesc<dtype[d0, d1, ...]>"`` or
      ``"tensordesc<dtype[d0, ...],<layout>>"`` -- a tensor descriptor; a
      non-empty ``<layout>`` selects the Gluon descriptor type and is evaluated
      as an ``NVMMASharedLayout(...)`` expression;
    * any string starting with ``"constexpr"`` -- ``constexpr_type(c)``;
    * a scalar name such as ``"fp16"``, ``"i32"`` or ``"u64"``.

    Args:
        name: The type descriptor described above.
        c: Forwarded to ``constexpr_type`` for ``constexpr`` descriptors and
            passed unchanged to recursive calls (the tensor-descriptor dtype is
            converted with ``None`` instead).

    Returns:
        The corresponding Triton type object.

    Raises:
        KeyError: If the descriptor is not a recognised scalar name (this
            includes the empty string).
    """
    # At module level ``tuple`` is Triton's tuple class (imported from .core),
    # so the builtin is re-imported locally to recognise Python tuples.
    from builtins import tuple

    if isinstance(name, tuple):
        # namedtuple classes define ``_fields``; plain tuples yield None here.
        fields = type(name).__dict__.get("_fields", None)
        return tuple_type([str_to_ty(x, c) for x in name], fields)

    # Pointer: "*<ty>", or "*k<ty>" for a const pointer.  startswith() (rather
    # than name[0]) lets an empty descriptor fall through to the KeyError below.
    if name.startswith("*"):
        name = name[1:]
        const = name.startswith("k")
        if const:
            name = name[1:]
        ty = str_to_ty(name, c)
        return pointer_type(element_ty=ty, const=const)

    if name.startswith("tensordesc"):
        # Split on the first "<" only and drop exactly one closing ">" so that
        # angle brackets inside a layout repr are not lost.
        inner = name.split("<", 1)[1]
        if inner.endswith(">"):
            inner = inner[:-1]
        dtype, rest = inner.split("[", maxsplit=1)
        block_shape, rest = rest.split("]", maxsplit=1)
        block_shape = [int(s.strip()) for s in block_shape.split(",")]
        # Anything after "]," is the (optional) Gluon layout expression; a
        # whitespace-only remainder is treated as "no layout".
        layout = rest.lstrip(",").strip()
        is_gluon = bool(layout)
        dtype = str_to_ty(dtype, None)
        ndim = len(block_shape)
        shape_type = tuple_type([int32] * ndim)
        # FIXME: Last dim stride should be constexpr(1)
        stride_type = tuple_type([int64] * ndim)
        block = block_type(dtype, block_shape)
        if is_gluon:
            from triton.experimental.gluon.language._layouts import NVMMASharedLayout
            from triton.experimental.gluon.language.nvidia.hopper.tma import tensor_descriptor_type as gluon_tensor_descriptor_type
            # NOTE(review): eval() executes the layout text with Python builtins
            # available; this is only safe while signature strings come from
            # trusted code -- confirm no caller forwards external input here.
            layout = eval(layout, dict(NVMMASharedLayout=NVMMASharedLayout))
            assert isinstance(layout, NVMMASharedLayout)
            return gluon_tensor_descriptor_type(block, shape_type, stride_type, layout)
        return tensor_descriptor_type(block, shape_type, stride_type)

    if name.startswith("constexpr"):
        return constexpr_type(c)

    # Scalar descriptors; note "u1" and "B" both map to int1 alongside "i1".
    tys = {
        "fp8e4nv": float8e4nv,
        "fp8e4b8": float8e4b8,
        "fp8e5": float8e5,
        "fp8e5b16": float8e5b16,
        "fp8e4b15": float8e4b15,
        "fp16": float16,
        "bf16": bfloat16,
        "fp32": float32,
        "fp64": float64,
        "i1": int1,
        "i8": int8,
        "i16": int16,
        "i32": int32,
        "i64": int64,
        "u1": int1,
        "u8": uint8,
        "u16": uint16,
        "u32": uint32,
        "u64": uint64,
        "B": int1,
    }
    return tys[name]