triton-windows 3.5.1.post21__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (217)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +82 -0
  3. triton/_filecheck.py +97 -0
  4. triton/_internal_testing.py +255 -0
  5. triton/_utils.py +126 -0
  6. triton/backends/__init__.py +47 -0
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +461 -0
  9. triton/backends/amd/driver.c +283 -0
  10. triton/backends/amd/driver.py +724 -0
  11. triton/backends/amd/lib/asanrtl.bc +0 -0
  12. triton/backends/amd/lib/ockl.bc +0 -0
  13. triton/backends/amd/lib/ocml.bc +0 -0
  14. triton/backends/compiler.py +90 -0
  15. triton/backends/driver.py +66 -0
  16. triton/backends/nvidia/__init__.py +0 -0
  17. triton/backends/nvidia/bin/ptxas.exe +0 -0
  18. triton/backends/nvidia/compiler.py +533 -0
  19. triton/backends/nvidia/driver.c +517 -0
  20. triton/backends/nvidia/driver.py +799 -0
  21. triton/backends/nvidia/include/cuda.h +26280 -0
  22. triton/backends/nvidia/lib/libdevice.10.bc +0 -0
  23. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  24. triton/compiler/__init__.py +7 -0
  25. triton/compiler/code_generator.py +1614 -0
  26. triton/compiler/compiler.py +509 -0
  27. triton/compiler/errors.py +51 -0
  28. triton/compiler/make_launcher.py +0 -0
  29. triton/errors.py +5 -0
  30. triton/experimental/__init__.py +0 -0
  31. triton/experimental/gluon/__init__.py +5 -0
  32. triton/experimental/gluon/_compiler.py +0 -0
  33. triton/experimental/gluon/_runtime.py +102 -0
  34. triton/experimental/gluon/language/__init__.py +119 -0
  35. triton/experimental/gluon/language/_core.py +490 -0
  36. triton/experimental/gluon/language/_layouts.py +583 -0
  37. triton/experimental/gluon/language/_math.py +20 -0
  38. triton/experimental/gluon/language/_semantic.py +380 -0
  39. triton/experimental/gluon/language/_standard.py +80 -0
  40. triton/experimental/gluon/language/amd/__init__.py +4 -0
  41. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  42. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  43. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  44. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  45. triton/experimental/gluon/language/extra/__init__.py +3 -0
  46. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  47. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  48. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  49. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  50. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
  51. triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
  52. triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
  53. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
  54. triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
  55. triton/experimental/gluon/nvidia/__init__.py +4 -0
  56. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  57. triton/experimental/gluon/nvidia/hopper.py +45 -0
  58. triton/knobs.py +546 -0
  59. triton/language/__init__.py +342 -0
  60. triton/language/core.py +3405 -0
  61. triton/language/extra/__init__.py +26 -0
  62. triton/language/extra/cuda/__init__.py +16 -0
  63. triton/language/extra/cuda/gdc.py +42 -0
  64. triton/language/extra/cuda/libdevice.py +1629 -0
  65. triton/language/extra/cuda/utils.py +109 -0
  66. triton/language/extra/hip/__init__.py +5 -0
  67. triton/language/extra/hip/libdevice.py +491 -0
  68. triton/language/extra/hip/utils.py +35 -0
  69. triton/language/extra/libdevice.py +790 -0
  70. triton/language/math.py +249 -0
  71. triton/language/random.py +218 -0
  72. triton/language/semantic.py +1939 -0
  73. triton/language/standard.py +534 -0
  74. triton/language/target_info.py +54 -0
  75. triton/runtime/__init__.py +23 -0
  76. triton/runtime/_allocation.py +44 -0
  77. triton/runtime/_async_compile.py +55 -0
  78. triton/runtime/autotuner.py +476 -0
  79. triton/runtime/build.py +168 -0
  80. triton/runtime/cache.py +317 -0
  81. triton/runtime/driver.py +38 -0
  82. triton/runtime/errors.py +36 -0
  83. triton/runtime/interpreter.py +1414 -0
  84. triton/runtime/jit.py +1107 -0
  85. triton/runtime/tcc/include/_mingw.h +168 -0
  86. triton/runtime/tcc/include/assert.h +62 -0
  87. triton/runtime/tcc/include/conio.h +409 -0
  88. triton/runtime/tcc/include/ctype.h +281 -0
  89. triton/runtime/tcc/include/dir.h +31 -0
  90. triton/runtime/tcc/include/direct.h +68 -0
  91. triton/runtime/tcc/include/dirent.h +135 -0
  92. triton/runtime/tcc/include/dos.h +55 -0
  93. triton/runtime/tcc/include/errno.h +75 -0
  94. triton/runtime/tcc/include/excpt.h +123 -0
  95. triton/runtime/tcc/include/fcntl.h +52 -0
  96. triton/runtime/tcc/include/fenv.h +108 -0
  97. triton/runtime/tcc/include/float.h +75 -0
  98. triton/runtime/tcc/include/inttypes.h +297 -0
  99. triton/runtime/tcc/include/io.h +418 -0
  100. triton/runtime/tcc/include/iso646.h +36 -0
  101. triton/runtime/tcc/include/limits.h +116 -0
  102. triton/runtime/tcc/include/locale.h +91 -0
  103. triton/runtime/tcc/include/malloc.h +181 -0
  104. triton/runtime/tcc/include/math.h +497 -0
  105. triton/runtime/tcc/include/mem.h +13 -0
  106. triton/runtime/tcc/include/memory.h +40 -0
  107. triton/runtime/tcc/include/process.h +176 -0
  108. triton/runtime/tcc/include/sec_api/conio_s.h +42 -0
  109. triton/runtime/tcc/include/sec_api/crtdbg_s.h +19 -0
  110. triton/runtime/tcc/include/sec_api/io_s.h +33 -0
  111. triton/runtime/tcc/include/sec_api/mbstring_s.h +52 -0
  112. triton/runtime/tcc/include/sec_api/search_s.h +25 -0
  113. triton/runtime/tcc/include/sec_api/stdio_s.h +145 -0
  114. triton/runtime/tcc/include/sec_api/stdlib_s.h +67 -0
  115. triton/runtime/tcc/include/sec_api/stralign_s.h +30 -0
  116. triton/runtime/tcc/include/sec_api/string_s.h +41 -0
  117. triton/runtime/tcc/include/sec_api/sys/timeb_s.h +34 -0
  118. triton/runtime/tcc/include/sec_api/tchar_s.h +266 -0
  119. triton/runtime/tcc/include/sec_api/time_s.h +61 -0
  120. triton/runtime/tcc/include/sec_api/wchar_s.h +128 -0
  121. triton/runtime/tcc/include/setjmp.h +160 -0
  122. triton/runtime/tcc/include/share.h +28 -0
  123. triton/runtime/tcc/include/signal.h +63 -0
  124. triton/runtime/tcc/include/stdalign.h +16 -0
  125. triton/runtime/tcc/include/stdarg.h +14 -0
  126. triton/runtime/tcc/include/stdatomic.h +171 -0
  127. triton/runtime/tcc/include/stdbool.h +11 -0
  128. triton/runtime/tcc/include/stddef.h +42 -0
  129. triton/runtime/tcc/include/stdint.h +212 -0
  130. triton/runtime/tcc/include/stdio.h +429 -0
  131. triton/runtime/tcc/include/stdlib.h +591 -0
  132. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  133. triton/runtime/tcc/include/string.h +164 -0
  134. triton/runtime/tcc/include/sys/fcntl.h +13 -0
  135. triton/runtime/tcc/include/sys/file.h +14 -0
  136. triton/runtime/tcc/include/sys/locking.h +30 -0
  137. triton/runtime/tcc/include/sys/stat.h +290 -0
  138. triton/runtime/tcc/include/sys/time.h +69 -0
  139. triton/runtime/tcc/include/sys/timeb.h +133 -0
  140. triton/runtime/tcc/include/sys/types.h +123 -0
  141. triton/runtime/tcc/include/sys/unistd.h +14 -0
  142. triton/runtime/tcc/include/sys/utime.h +146 -0
  143. triton/runtime/tcc/include/tcc/tcc_libm.h +618 -0
  144. triton/runtime/tcc/include/tccdefs.h +342 -0
  145. triton/runtime/tcc/include/tcclib.h +80 -0
  146. triton/runtime/tcc/include/tchar.h +1102 -0
  147. triton/runtime/tcc/include/tgmath.h +89 -0
  148. triton/runtime/tcc/include/time.h +287 -0
  149. triton/runtime/tcc/include/uchar.h +33 -0
  150. triton/runtime/tcc/include/unistd.h +1 -0
  151. triton/runtime/tcc/include/vadefs.h +11 -0
  152. triton/runtime/tcc/include/values.h +4 -0
  153. triton/runtime/tcc/include/varargs.h +12 -0
  154. triton/runtime/tcc/include/wchar.h +873 -0
  155. triton/runtime/tcc/include/wctype.h +172 -0
  156. triton/runtime/tcc/include/winapi/basetsd.h +149 -0
  157. triton/runtime/tcc/include/winapi/basetyps.h +85 -0
  158. triton/runtime/tcc/include/winapi/guiddef.h +156 -0
  159. triton/runtime/tcc/include/winapi/poppack.h +8 -0
  160. triton/runtime/tcc/include/winapi/pshpack1.h +8 -0
  161. triton/runtime/tcc/include/winapi/pshpack2.h +8 -0
  162. triton/runtime/tcc/include/winapi/pshpack4.h +8 -0
  163. triton/runtime/tcc/include/winapi/pshpack8.h +8 -0
  164. triton/runtime/tcc/include/winapi/qos.h +72 -0
  165. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  166. triton/runtime/tcc/include/winapi/winbase.h +2958 -0
  167. triton/runtime/tcc/include/winapi/wincon.h +309 -0
  168. triton/runtime/tcc/include/winapi/windef.h +293 -0
  169. triton/runtime/tcc/include/winapi/windows.h +127 -0
  170. triton/runtime/tcc/include/winapi/winerror.h +3166 -0
  171. triton/runtime/tcc/include/winapi/wingdi.h +4080 -0
  172. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  173. triton/runtime/tcc/include/winapi/winnt.h +5837 -0
  174. triton/runtime/tcc/include/winapi/winreg.h +272 -0
  175. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  176. triton/runtime/tcc/include/winapi/winuser.h +5651 -0
  177. triton/runtime/tcc/include/winapi/winver.h +160 -0
  178. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  179. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  180. triton/runtime/tcc/lib/cuda.def +697 -0
  181. triton/runtime/tcc/lib/gdi32.def +337 -0
  182. triton/runtime/tcc/lib/kernel32.def +770 -0
  183. triton/runtime/tcc/lib/libtcc1.a +0 -0
  184. triton/runtime/tcc/lib/msvcrt.def +1399 -0
  185. triton/runtime/tcc/lib/python3.def +810 -0
  186. triton/runtime/tcc/lib/python310.def +1610 -0
  187. triton/runtime/tcc/lib/python311.def +1633 -0
  188. triton/runtime/tcc/lib/python312.def +1703 -0
  189. triton/runtime/tcc/lib/python313.def +1651 -0
  190. triton/runtime/tcc/lib/python313t.def +1656 -0
  191. triton/runtime/tcc/lib/python314.def +1800 -0
  192. triton/runtime/tcc/lib/python314t.def +1809 -0
  193. triton/runtime/tcc/lib/python39.def +1644 -0
  194. triton/runtime/tcc/lib/python3t.def +905 -0
  195. triton/runtime/tcc/lib/user32.def +658 -0
  196. triton/runtime/tcc/libtcc.dll +0 -0
  197. triton/runtime/tcc/tcc.exe +0 -0
  198. triton/testing.py +543 -0
  199. triton/tools/__init__.py +0 -0
  200. triton/tools/build_extern.py +365 -0
  201. triton/tools/compile.py +210 -0
  202. triton/tools/disasm.py +143 -0
  203. triton/tools/extra/cuda/compile.c +70 -0
  204. triton/tools/extra/cuda/compile.h +14 -0
  205. triton/tools/extra/hip/compile.cpp +66 -0
  206. triton/tools/extra/hip/compile.h +13 -0
  207. triton/tools/link.py +322 -0
  208. triton/tools/mxfp.py +301 -0
  209. triton/tools/ragged_tma.py +92 -0
  210. triton/tools/tensor_descriptor.py +34 -0
  211. triton/windows_utils.py +405 -0
  212. triton_windows-3.5.1.post21.dist-info/METADATA +46 -0
  213. triton_windows-3.5.1.post21.dist-info/RECORD +217 -0
  214. triton_windows-3.5.1.post21.dist-info/WHEEL +5 -0
  215. triton_windows-3.5.1.post21.dist-info/entry_points.txt +3 -0
  216. triton_windows-3.5.1.post21.dist-info/licenses/LICENSE +23 -0
  217. triton_windows-3.5.1.post21.dist-info/top_level.txt +1 -0
triton/language/standard.py
@@ -0,0 +1,534 @@
+from __future__ import annotations
+
+from ..runtime.jit import jit, constexpr_function
+from . import core
+from . import math
+
+# constexpr utilities
+
+
+@constexpr_function
+def _log2(i):
+    log2 = 0
+    n = i
+    while n > 1:
+        n >>= 1
+        log2 += 1
+    return log2
+
+
+@constexpr_function
+def _is_power_of_two(i):
+    return (i & (i - 1)) == 0 and i != 0
+
+
+# -----------------------
+# Standard library
+# -----------------------
+
+
+@core._tensor_member_fn
+@jit
+def cdiv(x, div):
+    """
+    Computes the ceiling division of :code:`x` by :code:`div`
+
+    :param x: the input number
+    :type x: Block
+    :param div: the divisor
+    :type div: Block
+    """
+    return (x + div - 1) // div
+
+
+@core._tensor_member_fn
+@jit
+@math._add_math_1arg_docstr("sigmoid")
+def sigmoid(x):
+    return 1 / (1 + math.exp(-x))
+
+
+@core._tensor_member_fn
+@jit
+@math._add_math_1arg_docstr("softmax")
+def softmax(x, dim=None, keep_dims=False, ieee_rounding=False):
+    if dim is None:
+        _dim: core.constexpr = 0
+    else:
+        _dim: core.constexpr = dim
+    z = x - max(x, _dim, keep_dims=keep_dims)
+    num = math.exp(z)
+    den = sum(num, _dim, keep_dims=keep_dims)
+    return math.fdiv(num, den, ieee_rounding)
+
+
+@core._tensor_member_fn
+@jit
+def ravel(x, can_reorder=False):
+    """
+    Returns a contiguous flattened view of :code:`x`.
+
+    :param x: the input tensor
+    :type x: Block
+    """
+    return core.reshape(x, [x.numel], can_reorder=can_reorder)
+
+
+@jit
+def swizzle2d(i, j, size_i, size_j, size_g):
+    """
+    Transforms the indices of a row-major `size_i * size_j` matrix into
+    the indices of a column-major matrix for each group of `size_g` rows.
+
+    For example, for :code:`size_i = size_j = 4` and :code:`size_g = 2`, it will
+    transform ::
+
+        [[0 , 1 , 2 , 3 ],
+         [4 , 5 , 6 , 7 ],
+         [8 , 9 , 10, 11],
+         [12, 13, 14, 15]]
+
+    into ::
+
+        [[0, 2, 4 , 6 ],
+         [1, 3, 5 , 7 ],
+         [8, 10, 12, 14],
+         [9, 11, 13, 15]]
+    """
+    # "unrolled index in array"
+    ij = i * size_j + j
+    # number of elements in `size_g` groups
+    # of `size_j` columns
+    size_gj = size_g * size_j
+    # index of the group in which (i,j) is
+    group_id = ij // size_gj
+    # row-index of the first element of this group
+    off_i = group_id * size_g
+    # last group may have fewer rows
+    size_g = core.minimum(size_i - off_i, size_g)
+    # linear index with respect to the first element in this group
+    ij = ij % size_gj
+    # new row and column indices
+    new_i = off_i + ij % size_g
+    new_j = ij // size_g
+    return new_i, new_j
+
+
+@jit
+def zeros(shape, dtype):
+    """
+    Returns a tensor filled with the scalar value 0 for the given :code:`shape` and :code:`dtype`.
+
+    :param shape: Shape of the new array, e.g., (8, 16) or (8, )
+    :type shape: tuple of ints
+    :param dtype: Data-type of the new array, e.g., :code:`tl.float16`
+    :type dtype: DType
+    """
+    return core.full(shape, 0, dtype)
+
+
+@jit
+def zeros_like(input):
+    """
+    Returns a tensor of zeros with the same shape and type as a given tensor.
+
+    :param input: input tensor
+    :type input: Tensor
+    """
+    return zeros(input.shape, input.dtype)
+
+
+# max and argmax
+
+
+@jit
+def _argmax_combine(value1, index1, value2, index2, tie_break_left):
+    if tie_break_left:
+        tie = value1 == value2 and index1 < index2
+    else:
+        tie = False
+    gt = value1 > value2 or tie
+    v_ret = core.where(gt, value1, value2)
+    i_ret = core.where(gt, index1, index2)
+    return v_ret, i_ret
+
+
+@jit
+def _argmax_combine_tie_break_left(value1, index1, value2, index2):
+    return _argmax_combine(value1, index1, value2, index2, True)
+
+
+@jit
+def _argmax_combine_tie_break_fast(value1, index1, value2, index2):
+    return _argmax_combine(value1, index1, value2, index2, False)
+
+
+@jit
+def _elementwise_max(a, b):
+    return core.maximum(a, b)
+
+
+@core._tensor_member_fn
+@jit
+@core._add_reduction_docstr("maximum", return_indices_arg="return_indices",
+                            tie_break_arg="return_indices_tie_break_left")
+def max(input, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False):
+    input = core._promote_bfloat16_to_float32(input)
+    if return_indices:
+        if return_indices_tie_break_left:
+            return core._reduce_with_indices(input, axis, _argmax_combine_tie_break_left, keep_dims=keep_dims)
+        else:
+            return core._reduce_with_indices(input, axis, _argmax_combine_tie_break_fast, keep_dims=keep_dims)
+    else:
+        if core.constexpr(input.dtype.primitive_bitwidth) < core.constexpr(32):
+            if core.constexpr(input.dtype.is_floating()):
+                input = input.to(core.float32)
+            else:
+                assert input.dtype.is_int(), "Expecting input to be integer type"
+                input = input.to(core.int32)
+        return core.reduce(input, axis, _elementwise_max, keep_dims=keep_dims)
+
+
+@core._tensor_member_fn
+@jit
+@core._add_reduction_docstr("maximum index", tie_break_arg="tie_break_left")
+def argmax(input, axis, tie_break_left=True, keep_dims=False):
+    (_, ret) = max(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left, keep_dims=keep_dims)
+    return ret
+
+
+# min and argmin
+
+
+@jit
+def _argmin_combine(value1, index1, value2, index2, tie_break_left):
+    if tie_break_left:
+        tie = value1 == value2 and index1 < index2
+    else:
+        tie = False
+    lt = value1 < value2 or tie
+    value_ret = core.where(lt, value1, value2)
+    index_ret = core.where(lt, index1, index2)
+    return value_ret, index_ret
+
+
+@jit
+def _argmin_combine_tie_break_left(value1, index1, value2, index2):
+    return _argmin_combine(value1, index1, value2, index2, True)
+
+
+@jit
+def _argmin_combine_tie_break_fast(value1, index1, value2, index2):
+    return _argmin_combine(value1, index1, value2, index2, False)
+
+
+@jit
+def _elementwise_min(a, b):
+    return core.minimum(a, b)
+
+
+@core._tensor_member_fn
+@jit
+@core._add_reduction_docstr("minimum", return_indices_arg="return_indices",
+                            tie_break_arg="return_indices_tie_break_left")
+def min(input, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False):
+    input = core._promote_bfloat16_to_float32(input)
+    if return_indices:
+        if return_indices_tie_break_left:
+            return core._reduce_with_indices(input, axis, _argmin_combine_tie_break_left, keep_dims=keep_dims)
+        else:
+            return core._reduce_with_indices(input, axis, _argmin_combine_tie_break_fast, keep_dims=keep_dims)
+    else:
+        if core.constexpr(input.dtype.primitive_bitwidth) < 32:
+            if core.constexpr(input.dtype.is_floating()):
+                input = input.to(core.float32)
+            else:
+                assert input.dtype.is_int(), "Expecting input to be integer type"
+                input = input.to(core.int32)
+        return core.reduce(input, axis, _elementwise_min, keep_dims=keep_dims)
+
+
+@core._tensor_member_fn
+@jit
+@core._add_reduction_docstr("minimum index", tie_break_arg="tie_break_left")
+def argmin(input, axis, tie_break_left=True, keep_dims=False):
+    _, ret = min(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left, keep_dims=keep_dims)
+    return ret
+
+
+@jit
+def _sum_combine(a, b):
+    return a + b
+
+
+# sum
+
+
+@constexpr_function
+def _pick_sum_dtype(in_dtype, dtype):
+    if dtype is not None:
+        return dtype
+
+    # For integer bitwidths less than 32, pick int32 with the same sign to
+    # avoid overflow.
+    out_dtype = None
+    if in_dtype.is_int_signed():
+        out_dtype = core.int32 if in_dtype.int_bitwidth < 32 else None
+    elif in_dtype.is_int_unsigned():
+        out_dtype = core.uint32 if in_dtype.int_bitwidth < 32 else None
+    return out_dtype
+
+
+@core._tensor_member_fn
+@jit
+@core._add_reduction_docstr("sum", dtype_arg="dtype")
+def sum(input, axis=None, keep_dims=False, dtype: core.constexpr = None):
+    # Pick a default dtype for the reduction if one was not specified.
+    out_dtype: core.constexpr = _pick_sum_dtype(input.dtype, dtype)
+
+    if out_dtype is not None:
+        input = input.to(out_dtype)
+    return core.reduce(input, axis, _sum_combine, keep_dims=keep_dims)
+
+
+@jit
+def _xor_combine(a, b):
+    return a ^ b
+
+
+# xor sum
+
+
+@core._tensor_member_fn
+@jit
+@core._add_reduction_docstr("xor sum")
+def xor_sum(input, axis=None, keep_dims=False):
+    core.static_assert(input.type.scalar.is_int(), "xor_sum only supported for integers")
+    return core.reduce(input, axis, _xor_combine, keep_dims=keep_dims)
+
+
+# or reduction
+
+
+@jit
+def _or_combine(x, y):
+    return x | y
+
+
+@core._tensor_member_fn
+@jit
+@core._add_reduction_docstr("reduce_or")
+def reduce_or(input, axis, keep_dims=False):
+    core.static_assert(input.type.scalar.is_int(), "reduce_or only supported for integers")
+    return core.reduce(input, axis, _or_combine, keep_dims=keep_dims)
+
+
+# cumsum
+
+
+@core._tensor_member_fn
+@jit
+@core._add_scan_docstr("cumsum", dtype_arg="dtype")
+def cumsum(input, axis=0, reverse=False, dtype: core.constexpr = None):
+    # todo rename this to a generic function name
+
+    input = core._promote_bfloat16_to_float32(input)
+    out_dtype: core.constexpr = _pick_sum_dtype(input.dtype, dtype)
+
+    if out_dtype is not None:
+        input = input.to(out_dtype)
+
+    return core.associative_scan(input, axis, _sum_combine, reverse)
+
+
+# cumprod
+
+
+@jit
+def _prod_combine(a, b):
+    return a * b
+
+
+@core._tensor_member_fn
+@jit
+@core._add_scan_docstr("cumprod")
+def cumprod(input, axis=0, reverse=False):
+    # todo rename this to a generic function name
+    input = core._promote_bfloat16_to_float32(input)
+    return core.associative_scan(input, axis, _prod_combine, reverse)
+
+
+# sort
+
+
+@jit
+def _indicator(n_dims: core.constexpr, j: core.constexpr):
+    ar = core.arange(0, 2)
+    ar = core.reshape(ar, [1] * (n_dims - j - 1) + [2] + [1] * j)
+    return ar
+
+
+@jit
+def _compare_and_swap(x, flip, i: core.constexpr):
+    # compare-and-swap on the ith *innermost* dimension
+    n_dims: core.constexpr = _log2(x.numel)
+
+    # flip along middle dimension (the bitwise XORs will be optimised away):
+    idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
+    ix = x.to(idtype, bitcast=True)
+    iy = ix ^ xor_sum(ix, n_dims - 1 - i, True)
+    y = iy.to(x.dtype, bitcast=True)
+
+    # determines whether we are in the right (rather than left) position along the axis:
+    is_right = _indicator(n_dims, i)
+
+    # conditional swap:
+    ret = core.where((x > y) != (flip ^ is_right), y, x)
+    return ret
+
+
+@jit
+def _bitonic_merge_hypercube(x, stage: core.constexpr, order: core.constexpr):
+    '''
+    order_type 0 == ascending
+    order_type 1 == descending
+    order_type 2 == alternating
+    '''
+    # flip denotes whether to re-arrange sub-sequences of elements in ascending or
+    # descending order.
+    # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage
+    # if flip = 00110011... then all the elements will be re-arranged alternatingly (with
+    # a stride of 2) at this stage
+    if order == 2:
+        flip = _indicator(_log2(x.numel), stage)
+    else:
+        flip = order
+    # perform `stage` rounds of `compare-and-swap`
+    for i in core.static_range(stage):
+        x = _compare_and_swap(x, flip, stage - 1 - i)
+    return x
+
+
+@jit
+def _bitonic_merge(x, stage: core.constexpr, order: core.constexpr, n_dims: core.constexpr):
+    h = core.reshape(x, [2] * _log2(x.numel))
+    h = _bitonic_merge_hypercube(h, stage, order)
+    x = core.reshape(h, x.shape)
+    return x
+
+
+@jit
+def sort_impl(x, k: core.constexpr = None, dim: core.constexpr = None, descending: core.constexpr = core.CONSTEXPR_0):
+    """
+    Sorts a tensor along a specified dimension.
+
+    :param x: The input tensor to be sorted.
+    :type x: Tensor
+    :param dim: The dimension along which to sort the tensor. If None, the tensor is sorted along the last dimension. Currently, only sorting along the last dimension is supported.
+    :type dim: int, optional
+    :param k: the number of top elements to select. If none, assume k = x.shape[dim]
+    :type k: int, optional
+    :param descending: If set to True, the tensor is sorted in descending order. If set to False, the tensor is sorted in ascending order.
+    :type descending: bool, optional
+    """
+    # handle default dimension or check that it is the most minor dim
+    _dim: core.constexpr = len(x.shape) - 1 if dim is None else dim
+    core.static_assert(_dim == len(x.shape) - 1, "only minor dimension is currently supported")
+
+    log_n: core.constexpr = _log2(x.shape[_dim])
+    log_k: core.constexpr = log_n if k is None else _log2(k)
+
+    n_dims: core.constexpr = _log2(x.numel)
+
+    # reshape to hypercube:
+    h = core.reshape(x, [2] * n_dims)
+
+    # run first log_k bitonic sort iterations:
+    for i in core.static_range(1, log_k + 1):
+        h = _bitonic_merge_hypercube(h, i, 2 if i < log_n else descending)
+
+    # select top k elements using bitonic top-k
+    # https://www.doc.ic.ac.uk/~hlgr/pdfs/MassivelyParallelTopK.pdf
+    for i in core.static_range(log_k + 1, log_n + 1):
+        h = max(h, axis=(_log2(h.numel) - 1 - log_k)) if descending else min(h, axis=(_log2(h.numel) - 1 - log_k))
+        h = _bitonic_merge_hypercube(h, log_k, 2 if i < log_n else descending)
+
+    # reshape back:
+    x = core.reshape(h, x.shape[:-1] + [2**log_k])
+    return x
+
+
+@jit
+def sort(x, dim: core.constexpr = None, descending: core.constexpr = core.CONSTEXPR_0):
+    return sort_impl(x, dim=dim, descending=descending)
+
+
+@jit
+def topk(x, k: core.constexpr, dim: core.constexpr = None):
+    return sort_impl(x, k=k, dim=dim, descending=True)
+
+
+@jit
+def bitonic_merge(x, dim: core.constexpr = None, descending: core.constexpr = core.CONSTEXPR_0):
+    # handle default dimension or check that it is the most minor dim
+    _dim: core.constexpr = len(x.shape) - 1 if dim is None else dim
+    core.static_assert(_dim == len(x.shape) - 1, "only minor dimension is currently supported")
+    n_dims: core.constexpr = _log2(x.shape[-1])
+    return _bitonic_merge(x, n_dims, descending, n_dims)
+
+
+@constexpr_function
+def _get_flip_dim(dim, shape):
+    if dim is None:
+        dim = len(shape) - 1
+    if dim < 0:  # flip doesn't work if dim < 0 because the xor-swap for loop will start/end at the wrong index
+        dim += len(shape)
+    return dim
+
+
+@core._tensor_member_fn
+@jit
+def flip(x, dim=None):
+    """
+    Flips a tensor `x` along the dimension `dim`.
+
+    :param x: the first input tensor
+    :type x: Block
+    :param dim: the dimension to flip along
+    :type dim: int
+    """
+    core.static_assert(-len(x.shape) <= dim and dim < len(x.shape))
+    _dim: core.constexpr = _get_flip_dim(dim, x.shape)
+    core.static_assert(_is_power_of_two(x.shape[_dim]))
+    steps: core.constexpr = _log2(x.shape[_dim])
+
+    # reshape the swap dimension to (2, 2, ..., 2)
+    idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
+    y = core.reshape(x.to(idtype, bitcast=True), x.shape[:_dim] + [2] * steps + x.shape[_dim + 1:])
+    for i in core.static_range(steps):
+        y = y ^ xor_sum(y, _dim + i, True)
+    x = core.reshape(y, x.shape).to(x.dtype, bitcast=True)
+    return x
+
+
+@jit
+def interleave(a, b):
+    """
+    Interleaves the values of two tensors along their last dimension. The two tensors must have the same shape.
+    Equivalent to `tl.join(a, b).reshape(a.shape[:-1] + [2 * a.shape[-1]])`
+
+    :param a: The first input tensor.
+    :type a: Tensor
+    :param b: The second input tensor.
+    :type b: Tensor
+    """
+    c = core.join(a, b)
+
+    if len(c.shape) == 1:
+        # We must have interleaved two scalars.
+        return c
+    else:
+        # This `else` is necessary because Triton's AST parser doesn't
+        # understand that if we take the `if` above we definitely don't run this
+        # `else`.
+        return core.reshape(c, c.shape[:-2] + [2 * c.shape[-2]])
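The file above is Triton's device-side standard library: cdiv, sigmoid, softmax, the reductions and scans, and the bitonic sort/topk/flip helpers built on the hypercube reshape trick. A minimal sketch of how these helpers are typically called from a JIT kernel follows; it assumes softmax is re-exported through triton.language as in upstream Triton, and the kernel name, shapes, and launch parameters are illustrative rather than part of the wheel.

# Hypothetical row-wise softmax kernel exercising the standard library above.
import torch
import triton
import triton.language as tl


@triton.jit
def row_softmax_kernel(x_ptr, y_ptr, n_cols, BLOCK: tl.constexpr):
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK)
    mask = cols < n_cols
    # Masked-out lanes load -inf so they contribute exp(-inf) = 0 to the sum.
    x = tl.load(x_ptr + row * n_cols + cols, mask=mask, other=-float("inf"))
    y = tl.softmax(x)  # max-subtraction, exp, and sum come from standard.py
    tl.store(y_ptr + row * n_cols + cols, y, mask=mask)


x = torch.randn(64, 1000, device="cuda")
y = torch.empty_like(x)
BLOCK = triton.next_power_of_2(x.shape[1])  # one block must cover a full row
row_softmax_kernel[(x.shape[0],)](x, y, x.shape[1], BLOCK=BLOCK)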
triton/language/target_info.py
@@ -0,0 +1,54 @@
+from triton.runtime import driver
+from triton.runtime.jit import constexpr_function
+
+__all__ = ["current_target"]
+
+
+def current_target():
+    try:
+        active_driver = driver.active
+    except RuntimeError:
+        # If there is no active driver, return None
+        return None
+    return active_driver.get_current_target()
+
+
+current_target.__triton_builtin__ = True
+
+
+@constexpr_function
+def is_cuda():
+    target = current_target()
+    return target is not None and target.backend == "cuda"
+
+
+@constexpr_function
+def cuda_capability_geq(major, minor=0):
+    """
+    Determines whether we have compute capability >= (major, minor) and
+    returns this as a constexpr boolean. This can be used for guarding
+    inline asm implementations that require a certain compute capability.
+    """
+    target = current_target()
+    if target is None or target.backend != "cuda":
+        return False
+    assert isinstance(target.arch, int)
+    return target.arch >= major * 10 + minor
+
+
+@constexpr_function
+def is_hip():
+    target = current_target()
+    return target is not None and target.backend == "hip"
+
+
+@constexpr_function
+def is_hip_cdna3():
+    target = current_target()
+    return target is not None and target.arch == "gfx942"
+
+
+@constexpr_function
+def is_hip_cdna4():
+    target = current_target()
+    return target is not None and target.arch == "gfx950"
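These constexpr helpers let kernel code specialize on the active backend and architecture while the kernel is being traced, so the untaken branch never reaches the device. The sketch below assumes the helpers can be imported from triton.language.target_info and branched on inside a @triton.jit function; the multipliers stand in for real architecture-specific code paths.

# Compile-time specialization sketch (placeholder math, not a real fast path).
import triton
import triton.language as tl
from triton.language.target_info import cuda_capability_geq


@triton.jit
def scale_kernel(x_ptr, y_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    if cuda_capability_geq(9, 0):
        # The condition is a compile-time constexpr, so only one of these
        # branches is actually generated for the current target.
        x = x * 2.0
    else:
        x = x * 3.0
    tl.store(y_ptr + offs, x, mask=mask)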
triton/runtime/__init__.py
@@ -0,0 +1,23 @@
+from .autotuner import (Autotuner, Config, Heuristics, autotune, heuristics)
+from .cache import RedisRemoteCacheBackend, RemoteCacheBackend
+from .driver import driver
+from .jit import JITFunction, KernelInterface, MockTensor, TensorWrapper, reinterpret
+from .errors import OutOfResources, InterpreterError
+
+__all__ = [
+    "autotune",
+    "Autotuner",
+    "Config",
+    "driver",
+    "Heuristics",
+    "heuristics",
+    "InterpreterError",
+    "JITFunction",
+    "KernelInterface",
+    "MockTensor",
+    "OutOfResources",
+    "RedisRemoteCacheBackend",
+    "reinterpret",
+    "RemoteCacheBackend",
+    "TensorWrapper",
+]
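runtime/__init__.py only re-exports the public runtime API; the autotuner entry points (autotune, Config, Heuristics) are the pieces most users touch directly. A hedged example of the usual decorator stack follows; the block sizes, warp counts, and tuning key are arbitrary illustrative choices.

# Illustrative autotuned vector add using the re-exported runtime API.
import torch
import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config({"BLOCK": 128}, num_warps=4),
        triton.Config({"BLOCK": 1024}, num_warps=8),
    ],
    key=["n"],  # re-benchmark whenever the problem size changes
)
@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(out_ptr + offs,
             tl.load(x_ptr + offs, mask=mask) + tl.load(y_ptr + offs, mask=mask),
             mask=mask)


x = torch.randn(10_000, device="cuda")
y = torch.randn_like(x)
out = torch.empty_like(x)
# BLOCK is supplied by whichever Config wins, so the grid reads it from meta.
add_kernel[lambda meta: (triton.cdiv(x.numel(), meta["BLOCK"]),)](x, y, out, x.numel())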
triton/runtime/_allocation.py
@@ -0,0 +1,44 @@
+from typing import Optional, Protocol
+from contextvars import ContextVar
+
+
+class Buffer(Protocol):
+
+    def data_ptr(self) -> int:
+        ...
+
+
+class Allocator(Protocol):
+
+    def __call__(self, size: int, alignment: int, stream: Optional[int]) -> Buffer:
+        ...
+
+
+class NullAllocator:
+
+    def __call__(self, size: int, alignment: int, stream: Optional[int]) -> Buffer:
+        raise RuntimeError("Kernel requires a runtime memory allocation, but no allocator was set. " +
+                           "Use triton.set_allocator to specify an allocator.")
+
+
+_allocator: ContextVar[Allocator] = ContextVar("_allocator", default=NullAllocator())
+
+
+def set_allocator(allocator: Allocator):
+    """
+    The allocator function is called during kernel launch for kernels that
+    require additional global memory workspace.
+    """
+    _allocator.set(allocator)
+
+
+_profile_allocator: Allocator = ContextVar("_allocator", default=NullAllocator())
+
+
+def set_profile_allocator(allocator: Optional[Allocator]):
+    """
+    The profile allocator function is called before kernel launch for kernels
+    that require additional global memory workspace.
+    """
+    global _profile_allocator
+    _profile_allocator.set(allocator)
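The Allocator protocol above is what triton.set_allocator (the function named in NullAllocator's error message) expects: a callable taking size, alignment, and stream and returning any object with data_ptr(). A minimal torch-backed sketch follows; relying on the CUDA caching allocator's default alignment is an assumption here, not something this module guarantees.

# Minimal allocator satisfying the Buffer/Allocator protocols above.
from typing import Optional

import torch
import triton


def torch_workspace_allocator(size: int, alignment: int, stream: Optional[int]) -> torch.Tensor:
    # torch.Tensor exposes data_ptr(), so it satisfies the Buffer protocol.
    # The caching allocator's default alignment is assumed to be sufficient;
    # a stricter implementation would over-allocate and offset the pointer.
    return torch.empty(size, dtype=torch.int8, device="cuda")


triton.set_allocator(torch_workspace_allocator)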
triton/runtime/_async_compile.py
@@ -0,0 +1,55 @@
+from __future__ import annotations
+from typing import Callable, Optional
+from concurrent.futures import Executor, as_completed, Future
+from contextvars import ContextVar
+
+active_mode: ContextVar[Optional[AsyncCompileMode]] = ContextVar("async_compile_active_mode", default=None)
+
+
+class FutureKernel:
+
+    def __init__(self, finalize_compile: Callable, future: Future):
+        self.finalize_compile = finalize_compile
+        self.kernel = None
+        self.future = future
+
+    def result(self):
+        if self.kernel is not None:
+            return self.kernel
+
+        kernel = self.future.result()
+        self.finalize_compile(kernel)
+        self.kernel = kernel
+        return kernel
+
+
+class AsyncCompileMode:
+
+    def __init__(self, executor: Executor):
+        self.executor = executor
+        self.raw_futures = []
+        self.future_kernels = {}
+
+    def submit(self, key, compile_fn, finalize_fn):
+        future = self.future_kernels.get(key)
+        if future is not None:
+            return future
+
+        future = self.executor.submit(compile_fn)
+        future._key = key
+        self.raw_futures.append(future)
+        future_kernel = FutureKernel(finalize_fn, future)
+        self.future_kernels[key] = future_kernel
+        return future_kernel
+
+    def __enter__(self):
+        if active_mode.get() is not None:
+            raise RuntimeError("Another AsyncCompileMode is already active")
+        active_mode.set(self)
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        # Finalize any outstanding compiles
+        for future in as_completed(self.raw_futures):
+            self.future_kernels[future._key].result()
+        active_mode.set(None)
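AsyncCompileMode is private machinery: inside the context, compilations submitted under a key are farmed out to the given executor, deduplicated, and drained on exit. The sketch below drives it directly with stand-in lambdas instead of a real compile function, purely to show the submit/finalize/result flow; how the JIT runtime routes real compilations through active_mode is not shown here.

# Driving the internal AsyncCompileMode with placeholder compile/finalize steps.
from concurrent.futures import ThreadPoolExecutor

from triton.runtime._async_compile import AsyncCompileMode

with ThreadPoolExecutor(max_workers=4) as pool:
    with AsyncCompileMode(pool) as mode:
        # Each unique key is compiled at most once; a duplicate submit for the
        # same key returns the already-registered FutureKernel.
        fk = mode.submit("demo-key",
                         compile_fn=lambda: "fake-kernel",
                         finalize_fn=lambda kernel: None)
    # __exit__ has already drained every future, so result() is immediate.
    print(fk.result())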