triton-windows 3.5.0.post21__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of triton-windows might be problematic.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +82 -0
- triton/_filecheck.py +97 -0
- triton/_internal_testing.py +255 -0
- triton/_utils.py +126 -0
- triton/backends/__init__.py +47 -0
- triton/backends/amd/__init__.py +0 -0
- triton/backends/amd/compiler.py +461 -0
- triton/backends/amd/driver.c +283 -0
- triton/backends/amd/driver.py +724 -0
- triton/backends/amd/lib/asanrtl.bc +0 -0
- triton/backends/amd/lib/ockl.bc +0 -0
- triton/backends/amd/lib/ocml.bc +0 -0
- triton/backends/compiler.py +90 -0
- triton/backends/driver.py +66 -0
- triton/backends/nvidia/__init__.py +0 -0
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +533 -0
- triton/backends/nvidia/driver.c +517 -0
- triton/backends/nvidia/driver.py +799 -0
- triton/backends/nvidia/include/cuda.h +26280 -0
- triton/backends/nvidia/lib/libdevice.10.bc +0 -0
- triton/backends/nvidia/lib/x64/cuda.lib +0 -0
- triton/compiler/__init__.py +7 -0
- triton/compiler/code_generator.py +1614 -0
- triton/compiler/compiler.py +509 -0
- triton/compiler/errors.py +51 -0
- triton/compiler/make_launcher.py +0 -0
- triton/errors.py +5 -0
- triton/experimental/__init__.py +0 -0
- triton/experimental/gluon/__init__.py +5 -0
- triton/experimental/gluon/_compiler.py +0 -0
- triton/experimental/gluon/_runtime.py +102 -0
- triton/experimental/gluon/language/__init__.py +119 -0
- triton/experimental/gluon/language/_core.py +490 -0
- triton/experimental/gluon/language/_layouts.py +583 -0
- triton/experimental/gluon/language/_math.py +20 -0
- triton/experimental/gluon/language/_semantic.py +380 -0
- triton/experimental/gluon/language/_standard.py +80 -0
- triton/experimental/gluon/language/amd/__init__.py +4 -0
- triton/experimental/gluon/language/amd/_layouts.py +96 -0
- triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
- triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
- triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
- triton/experimental/gluon/language/extra/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/__init__.py +4 -0
- triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
- triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
- triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
- triton/experimental/gluon/nvidia/__init__.py +4 -0
- triton/experimental/gluon/nvidia/blackwell.py +3 -0
- triton/experimental/gluon/nvidia/hopper.py +45 -0
- triton/knobs.py +546 -0
- triton/language/__init__.py +342 -0
- triton/language/core.py +3405 -0
- triton/language/extra/__init__.py +26 -0
- triton/language/extra/cuda/__init__.py +16 -0
- triton/language/extra/cuda/gdc.py +42 -0
- triton/language/extra/cuda/libdevice.py +1629 -0
- triton/language/extra/cuda/utils.py +109 -0
- triton/language/extra/hip/__init__.py +5 -0
- triton/language/extra/hip/libdevice.py +491 -0
- triton/language/extra/hip/utils.py +35 -0
- triton/language/extra/libdevice.py +790 -0
- triton/language/math.py +249 -0
- triton/language/random.py +218 -0
- triton/language/semantic.py +1939 -0
- triton/language/standard.py +534 -0
- triton/language/target_info.py +54 -0
- triton/runtime/__init__.py +23 -0
- triton/runtime/_allocation.py +44 -0
- triton/runtime/_async_compile.py +55 -0
- triton/runtime/autotuner.py +476 -0
- triton/runtime/build.py +168 -0
- triton/runtime/cache.py +317 -0
- triton/runtime/driver.py +38 -0
- triton/runtime/errors.py +36 -0
- triton/runtime/interpreter.py +1414 -0
- triton/runtime/jit.py +1107 -0
- triton/runtime/tcc/include/_mingw.h +168 -0
- triton/runtime/tcc/include/assert.h +62 -0
- triton/runtime/tcc/include/conio.h +409 -0
- triton/runtime/tcc/include/ctype.h +281 -0
- triton/runtime/tcc/include/dir.h +31 -0
- triton/runtime/tcc/include/direct.h +68 -0
- triton/runtime/tcc/include/dirent.h +135 -0
- triton/runtime/tcc/include/dos.h +55 -0
- triton/runtime/tcc/include/errno.h +75 -0
- triton/runtime/tcc/include/excpt.h +123 -0
- triton/runtime/tcc/include/fcntl.h +52 -0
- triton/runtime/tcc/include/fenv.h +108 -0
- triton/runtime/tcc/include/float.h +75 -0
- triton/runtime/tcc/include/inttypes.h +297 -0
- triton/runtime/tcc/include/io.h +418 -0
- triton/runtime/tcc/include/iso646.h +36 -0
- triton/runtime/tcc/include/limits.h +116 -0
- triton/runtime/tcc/include/locale.h +91 -0
- triton/runtime/tcc/include/malloc.h +181 -0
- triton/runtime/tcc/include/math.h +497 -0
- triton/runtime/tcc/include/mem.h +13 -0
- triton/runtime/tcc/include/memory.h +40 -0
- triton/runtime/tcc/include/process.h +176 -0
- triton/runtime/tcc/include/sec_api/conio_s.h +42 -0
- triton/runtime/tcc/include/sec_api/crtdbg_s.h +19 -0
- triton/runtime/tcc/include/sec_api/io_s.h +33 -0
- triton/runtime/tcc/include/sec_api/mbstring_s.h +52 -0
- triton/runtime/tcc/include/sec_api/search_s.h +25 -0
- triton/runtime/tcc/include/sec_api/stdio_s.h +145 -0
- triton/runtime/tcc/include/sec_api/stdlib_s.h +67 -0
- triton/runtime/tcc/include/sec_api/stralign_s.h +30 -0
- triton/runtime/tcc/include/sec_api/string_s.h +41 -0
- triton/runtime/tcc/include/sec_api/sys/timeb_s.h +34 -0
- triton/runtime/tcc/include/sec_api/tchar_s.h +266 -0
- triton/runtime/tcc/include/sec_api/time_s.h +61 -0
- triton/runtime/tcc/include/sec_api/wchar_s.h +128 -0
- triton/runtime/tcc/include/setjmp.h +160 -0
- triton/runtime/tcc/include/share.h +28 -0
- triton/runtime/tcc/include/signal.h +63 -0
- triton/runtime/tcc/include/stdalign.h +16 -0
- triton/runtime/tcc/include/stdarg.h +14 -0
- triton/runtime/tcc/include/stdatomic.h +171 -0
- triton/runtime/tcc/include/stdbool.h +11 -0
- triton/runtime/tcc/include/stddef.h +42 -0
- triton/runtime/tcc/include/stdint.h +212 -0
- triton/runtime/tcc/include/stdio.h +429 -0
- triton/runtime/tcc/include/stdlib.h +591 -0
- triton/runtime/tcc/include/stdnoreturn.h +7 -0
- triton/runtime/tcc/include/string.h +164 -0
- triton/runtime/tcc/include/sys/fcntl.h +13 -0
- triton/runtime/tcc/include/sys/file.h +14 -0
- triton/runtime/tcc/include/sys/locking.h +30 -0
- triton/runtime/tcc/include/sys/stat.h +290 -0
- triton/runtime/tcc/include/sys/time.h +69 -0
- triton/runtime/tcc/include/sys/timeb.h +133 -0
- triton/runtime/tcc/include/sys/types.h +123 -0
- triton/runtime/tcc/include/sys/unistd.h +14 -0
- triton/runtime/tcc/include/sys/utime.h +146 -0
- triton/runtime/tcc/include/tcc/tcc_libm.h +618 -0
- triton/runtime/tcc/include/tccdefs.h +342 -0
- triton/runtime/tcc/include/tcclib.h +80 -0
- triton/runtime/tcc/include/tchar.h +1102 -0
- triton/runtime/tcc/include/tgmath.h +89 -0
- triton/runtime/tcc/include/time.h +287 -0
- triton/runtime/tcc/include/uchar.h +33 -0
- triton/runtime/tcc/include/unistd.h +1 -0
- triton/runtime/tcc/include/vadefs.h +11 -0
- triton/runtime/tcc/include/values.h +4 -0
- triton/runtime/tcc/include/varargs.h +12 -0
- triton/runtime/tcc/include/wchar.h +873 -0
- triton/runtime/tcc/include/wctype.h +172 -0
- triton/runtime/tcc/include/winapi/basetsd.h +149 -0
- triton/runtime/tcc/include/winapi/basetyps.h +85 -0
- triton/runtime/tcc/include/winapi/guiddef.h +156 -0
- triton/runtime/tcc/include/winapi/poppack.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack1.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack2.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack4.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack8.h +8 -0
- triton/runtime/tcc/include/winapi/qos.h +72 -0
- triton/runtime/tcc/include/winapi/shellapi.h +59 -0
- triton/runtime/tcc/include/winapi/winbase.h +2958 -0
- triton/runtime/tcc/include/winapi/wincon.h +309 -0
- triton/runtime/tcc/include/winapi/windef.h +293 -0
- triton/runtime/tcc/include/winapi/windows.h +127 -0
- triton/runtime/tcc/include/winapi/winerror.h +3166 -0
- triton/runtime/tcc/include/winapi/wingdi.h +4080 -0
- triton/runtime/tcc/include/winapi/winnls.h +778 -0
- triton/runtime/tcc/include/winapi/winnt.h +5837 -0
- triton/runtime/tcc/include/winapi/winreg.h +272 -0
- triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
- triton/runtime/tcc/include/winapi/winuser.h +5651 -0
- triton/runtime/tcc/include/winapi/winver.h +160 -0
- triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
- triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
- triton/runtime/tcc/lib/cuda.def +697 -0
- triton/runtime/tcc/lib/gdi32.def +337 -0
- triton/runtime/tcc/lib/kernel32.def +770 -0
- triton/runtime/tcc/lib/libtcc1.a +0 -0
- triton/runtime/tcc/lib/msvcrt.def +1399 -0
- triton/runtime/tcc/lib/python3.def +810 -0
- triton/runtime/tcc/lib/python310.def +1610 -0
- triton/runtime/tcc/lib/python311.def +1633 -0
- triton/runtime/tcc/lib/python312.def +1703 -0
- triton/runtime/tcc/lib/python313.def +1651 -0
- triton/runtime/tcc/lib/python313t.def +1656 -0
- triton/runtime/tcc/lib/python314.def +1800 -0
- triton/runtime/tcc/lib/python314t.def +1809 -0
- triton/runtime/tcc/lib/python39.def +1644 -0
- triton/runtime/tcc/lib/python3t.def +905 -0
- triton/runtime/tcc/lib/user32.def +658 -0
- triton/runtime/tcc/libtcc.dll +0 -0
- triton/runtime/tcc/tcc.exe +0 -0
- triton/testing.py +543 -0
- triton/tools/__init__.py +0 -0
- triton/tools/build_extern.py +365 -0
- triton/tools/compile.py +210 -0
- triton/tools/disasm.py +143 -0
- triton/tools/extra/cuda/compile.c +70 -0
- triton/tools/extra/cuda/compile.h +14 -0
- triton/tools/extra/hip/compile.cpp +66 -0
- triton/tools/extra/hip/compile.h +13 -0
- triton/tools/link.py +322 -0
- triton/tools/mxfp.py +301 -0
- triton/tools/ragged_tma.py +92 -0
- triton/tools/tensor_descriptor.py +34 -0
- triton/windows_utils.py +405 -0
- triton_windows-3.5.0.post21.dist-info/METADATA +46 -0
- triton_windows-3.5.0.post21.dist-info/RECORD +217 -0
- triton_windows-3.5.0.post21.dist-info/WHEEL +5 -0
- triton_windows-3.5.0.post21.dist-info/entry_points.txt +3 -0
- triton_windows-3.5.0.post21.dist-info/licenses/LICENSE +23 -0
- triton_windows-3.5.0.post21.dist-info/top_level.txt +1 -0
--- /dev/null
+++ triton/language/standard.py
@@ -0,0 +1,534 @@
+from __future__ import annotations
+
+from ..runtime.jit import jit, constexpr_function
+from . import core
+from . import math
+
+# constexpr utilities
+
+
+@constexpr_function
+def _log2(i):
+    log2 = 0
+    n = i
+    while n > 1:
+        n >>= 1
+        log2 += 1
+    return log2
+
+
+@constexpr_function
+def _is_power_of_two(i):
+    return (i & (i - 1)) == 0 and i != 0
+
+
+# -----------------------
+# Standard library
+# -----------------------
+
+
+@core._tensor_member_fn
+@jit
+def cdiv(x, div):
+    """
+    Computes the ceiling division of :code:`x` by :code:`div`
+
+    :param x: the input number
+    :type x: Block
+    :param div: the divisor
+    :type div: Block
+    """
+    return (x + div - 1) // div
+
+
+@core._tensor_member_fn
+@jit
+@math._add_math_1arg_docstr("sigmoid")
+def sigmoid(x):
+    return 1 / (1 + math.exp(-x))
+
+
+@core._tensor_member_fn
+@jit
+@math._add_math_1arg_docstr("softmax")
+def softmax(x, dim=None, keep_dims=False, ieee_rounding=False):
+    if dim is None:
+        _dim: core.constexpr = 0
+    else:
+        _dim: core.constexpr = dim
+    z = x - max(x, _dim, keep_dims=keep_dims)
+    num = math.exp(z)
+    den = sum(num, _dim, keep_dims=keep_dims)
+    return math.fdiv(num, den, ieee_rounding)
+
+
+@core._tensor_member_fn
+@jit
+def ravel(x, can_reorder=False):
+    """
+    Returns a contiguous flattened view of :code:`x`.
+
+    :param x: the input tensor
+    :type x: Block
+    """
+    return core.reshape(x, [x.numel], can_reorder=can_reorder)
+
+
+@jit
+def swizzle2d(i, j, size_i, size_j, size_g):
+    """
+    Transforms the indices of a row-major `size_i * size_j` matrix into
+    the indices of a column-major matrix for each group of `size_g` rows.
+
+    For example, for :code:`size_i = size_j = 4` and :code:`size_g = 2`, it will
+    transform ::
+
+        [[0 , 1 , 2 , 3 ],
+         [4 , 5 , 6 , 7 ],
+         [8 , 9 , 10, 11],
+         [12, 13, 14, 15]]
+
+    into ::
+
+        [[0, 2, 4 , 6 ],
+         [1, 3, 5 , 7 ],
+         [8, 10, 12, 14],
+         [9, 11, 13, 15]]
+    """
+    # "unrolled index in array"
+    ij = i * size_j + j
+    # number of elements in `size_g` groups
+    # of `size_j` columns
+    size_gj = size_g * size_j
+    # index of the group in which (i,j) is
+    group_id = ij // size_gj
+    # row-index of the first element of this group
+    off_i = group_id * size_g
+    # last group may have fewer rows
+    size_g = core.minimum(size_i - off_i, size_g)
+    # linear index with respect to the first element in this group
+    ij = ij % size_gj
+    # new row and column indices
+    new_i = off_i + ij % size_g
+    new_j = ij // size_g
+    return new_i, new_j
+
+
+@jit
+def zeros(shape, dtype):
+    """
+    Returns a tensor filled with the scalar value 0 for the given :code:`shape` and :code:`dtype`.
+
+    :param shape: Shape of the new array, e.g., (8, 16) or (8, )
+    :type shape: tuple of ints
+    :param dtype: Data-type of the new array, e.g., :code:`tl.float16`
+    :type dtype: DType
+    """
+    return core.full(shape, 0, dtype)
+
+
+@jit
+def zeros_like(input):
+    """
+    Returns a tensor of zeros with the same shape and type as a given tensor.
+
+    :param input: input tensor
+    :type input: Tensor
+    """
+    return zeros(input.shape, input.dtype)
+
+
+# max and argmax
+
+
+@jit
+def _argmax_combine(value1, index1, value2, index2, tie_break_left):
+    if tie_break_left:
+        tie = value1 == value2 and index1 < index2
+    else:
+        tie = False
+    gt = value1 > value2 or tie
+    v_ret = core.where(gt, value1, value2)
+    i_ret = core.where(gt, index1, index2)
+    return v_ret, i_ret
+
+
+@jit
+def _argmax_combine_tie_break_left(value1, index1, value2, index2):
+    return _argmax_combine(value1, index1, value2, index2, True)
+
+
+@jit
+def _argmax_combine_tie_break_fast(value1, index1, value2, index2):
+    return _argmax_combine(value1, index1, value2, index2, False)
+
+
+@jit
+def _elementwise_max(a, b):
+    return core.maximum(a, b)
+
+
+@core._tensor_member_fn
+@jit
+@core._add_reduction_docstr("maximum", return_indices_arg="return_indices",
+                            tie_break_arg="return_indices_tie_break_left")
+def max(input, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False):
+    input = core._promote_bfloat16_to_float32(input)
+    if return_indices:
+        if return_indices_tie_break_left:
+            return core._reduce_with_indices(input, axis, _argmax_combine_tie_break_left, keep_dims=keep_dims)
+        else:
+            return core._reduce_with_indices(input, axis, _argmax_combine_tie_break_fast, keep_dims=keep_dims)
+    else:
+        if core.constexpr(input.dtype.primitive_bitwidth) < core.constexpr(32):
+            if core.constexpr(input.dtype.is_floating()):
+                input = input.to(core.float32)
+            else:
+                assert input.dtype.is_int(), "Expecting input to be integer type"
+                input = input.to(core.int32)
+        return core.reduce(input, axis, _elementwise_max, keep_dims=keep_dims)
+
+
+@core._tensor_member_fn
+@jit
+@core._add_reduction_docstr("maximum index", tie_break_arg="tie_break_left")
+def argmax(input, axis, tie_break_left=True, keep_dims=False):
+    (_, ret) = max(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left, keep_dims=keep_dims)
+    return ret
+
+
+# min and argmin
+
+
+@jit
+def _argmin_combine(value1, index1, value2, index2, tie_break_left):
+    if tie_break_left:
+        tie = value1 == value2 and index1 < index2
+    else:
+        tie = False
+    lt = value1 < value2 or tie
+    value_ret = core.where(lt, value1, value2)
+    index_ret = core.where(lt, index1, index2)
+    return value_ret, index_ret
+
+
+@jit
+def _argmin_combine_tie_break_left(value1, index1, value2, index2):
+    return _argmin_combine(value1, index1, value2, index2, True)
+
+
+@jit
+def _argmin_combine_tie_break_fast(value1, index1, value2, index2):
+    return _argmin_combine(value1, index1, value2, index2, False)
+
+
+@jit
+def _elementwise_min(a, b):
+    return core.minimum(a, b)
+
+
+@core._tensor_member_fn
+@jit
+@core._add_reduction_docstr("minimum", return_indices_arg="return_indices",
+                            tie_break_arg="return_indices_tie_break_left")
+def min(input, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False):
+    input = core._promote_bfloat16_to_float32(input)
+    if return_indices:
+        if return_indices_tie_break_left:
+            return core._reduce_with_indices(input, axis, _argmin_combine_tie_break_left, keep_dims=keep_dims)
+        else:
+            return core._reduce_with_indices(input, axis, _argmin_combine_tie_break_fast, keep_dims=keep_dims)
+    else:
+        if core.constexpr(input.dtype.primitive_bitwidth) < 32:
+            if core.constexpr(input.dtype.is_floating()):
+                input = input.to(core.float32)
+            else:
+                assert input.dtype.is_int(), "Expecting input to be integer type"
+                input = input.to(core.int32)
+        return core.reduce(input, axis, _elementwise_min, keep_dims=keep_dims)
+
+
+@core._tensor_member_fn
+@jit
+@core._add_reduction_docstr("minimum index", tie_break_arg="tie_break_left")
+def argmin(input, axis, tie_break_left=True, keep_dims=False):
+    _, ret = min(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left, keep_dims=keep_dims)
+    return ret
+
+
+@jit
+def _sum_combine(a, b):
+    return a + b
+
+
+# sum
+
+
+@constexpr_function
+def _pick_sum_dtype(in_dtype, dtype):
+    if dtype is not None:
+        return dtype
+
+    # For integer bitwidths less than 32, pick int32 with the same sign to
+    # avoid overflow.
+    out_dtype = None
+    if in_dtype.is_int_signed():
+        out_dtype = core.int32 if in_dtype.int_bitwidth < 32 else None
+    elif in_dtype.is_int_unsigned():
+        out_dtype = core.uint32 if in_dtype.int_bitwidth < 32 else None
+    return out_dtype
+
+
+@core._tensor_member_fn
+@jit
+@core._add_reduction_docstr("sum", dtype_arg="dtype")
+def sum(input, axis=None, keep_dims=False, dtype: core.constexpr = None):
+    # Pick a default dtype for the reduction if one was not specified.
+    out_dtype: core.constexpr = _pick_sum_dtype(input.dtype, dtype)
+
+    if out_dtype is not None:
+        input = input.to(out_dtype)
+    return core.reduce(input, axis, _sum_combine, keep_dims=keep_dims)
+
+
+@jit
+def _xor_combine(a, b):
+    return a ^ b
+
+
+# xor sum
+
+
+@core._tensor_member_fn
+@jit
+@core._add_reduction_docstr("xor sum")
+def xor_sum(input, axis=None, keep_dims=False):
+    core.static_assert(input.type.scalar.is_int(), "xor_sum only supported for integers")
+    return core.reduce(input, axis, _xor_combine, keep_dims=keep_dims)
+
+
+# or reduction
+
+
+@jit
+def _or_combine(x, y):
+    return x | y
+
+
+@core._tensor_member_fn
+@jit
+@core._add_reduction_docstr("reduce_or")
+def reduce_or(input, axis, keep_dims=False):
+    core.static_assert(input.type.scalar.is_int(), "reduce_or only supported for integers")
+    return core.reduce(input, axis, _or_combine, keep_dims=keep_dims)
+
+
+# cumsum
+
+
+@core._tensor_member_fn
+@jit
+@core._add_scan_docstr("cumsum", dtype_arg="dtype")
+def cumsum(input, axis=0, reverse=False, dtype: core.constexpr = None):
+    # todo rename this to a generic function name
+
+    input = core._promote_bfloat16_to_float32(input)
+    out_dtype: core.constexpr = _pick_sum_dtype(input.dtype, dtype)
+
+    if out_dtype is not None:
+        input = input.to(out_dtype)
+
+    return core.associative_scan(input, axis, _sum_combine, reverse)
+
+
+# cumprod
+
+
+@jit
+def _prod_combine(a, b):
+    return a * b
+
+
+@core._tensor_member_fn
+@jit
+@core._add_scan_docstr("cumprod")
+def cumprod(input, axis=0, reverse=False):
+    # todo rename this to a generic function name
+    input = core._promote_bfloat16_to_float32(input)
+    return core.associative_scan(input, axis, _prod_combine, reverse)
+
+
+# sort
+
+
+@jit
+def _indicator(n_dims: core.constexpr, j: core.constexpr):
+    ar = core.arange(0, 2)
+    ar = core.reshape(ar, [1] * (n_dims - j - 1) + [2] + [1] * j)
+    return ar
+
+
+@jit
+def _compare_and_swap(x, flip, i: core.constexpr):
+    # compare-and-swap on the ith *innermost* dimension
+    n_dims: core.constexpr = _log2(x.numel)
+
+    # flip along middle dimension (the bitwise XORs will be optimised away):
+    idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
+    ix = x.to(idtype, bitcast=True)
+    iy = ix ^ xor_sum(ix, n_dims - 1 - i, True)
+    y = iy.to(x.dtype, bitcast=True)
+
+    # determines whether we are in the right (rather than left) position along the axis:
+    is_right = _indicator(n_dims, i)
+
+    # conditional swap:
+    ret = core.where((x > y) != (flip ^ is_right), y, x)
+    return ret
+
+
+@jit
+def _bitonic_merge_hypercube(x, stage: core.constexpr, order: core.constexpr):
+    '''
+    order_type 0 == ascending
+    order_type 1 == descending
+    order_type 2 == alternating
+    '''
+    # flip denotes whether to re-arrange sub-sequences of elements in ascending or
+    # descending order.
+    # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage
+    # if flip = 00110011... then all the elements will be re-arranged alternatingly (with
+    # a stride of 2) at this stage
+    if order == 2:
+        flip = _indicator(_log2(x.numel), stage)
+    else:
+        flip = order
+    # perform `stage` rounds of `compare-and-swap`
+    for i in core.static_range(stage):
+        x = _compare_and_swap(x, flip, stage - 1 - i)
+    return x
+
+
+@jit
+def _bitonic_merge(x, stage: core.constexpr, order: core.constexpr, n_dims: core.constexpr):
+    h = core.reshape(x, [2] * _log2(x.numel))
+    h = _bitonic_merge_hypercube(h, stage, order)
+    x = core.reshape(h, x.shape)
+    return x
+
+
+@jit
+def sort_impl(x, k: core.constexpr = None, dim: core.constexpr = None, descending: core.constexpr = core.CONSTEXPR_0):
+    """
+    Sorts a tensor along a specified dimension.
+
+    :param x: The input tensor to be sorted.
+    :type x: Tensor
+    :param dim: The dimension along which to sort the tensor. If None, the tensor is sorted along the last dimension. Currently, only sorting along the last dimension is supported.
+    :type dim: int, optional
+    :param k: the number of top elements to select. If none, assume k = x.shape[dim]
+    :type k: int, optional
+    :param descending: If set to True, the tensor is sorted in descending order. If set to False, the tensor is sorted in ascending order.
+    :type descending: bool, optional
+    """
+    # handle default dimension or check that it is the most minor dim
+    _dim: core.constexpr = len(x.shape) - 1 if dim is None else dim
+    core.static_assert(_dim == len(x.shape) - 1, "only minor dimension is currently supported")
+
+    log_n: core.constexpr = _log2(x.shape[_dim])
+    log_k: core.constexpr = log_n if k is None else _log2(k)
+
+    n_dims: core.constexpr = _log2(x.numel)
+
+    # reshape to hypercube:
+    h = core.reshape(x, [2] * n_dims)
+
+    # run first log_k bitonic sort iterations:
+    for i in core.static_range(1, log_k + 1):
+        h = _bitonic_merge_hypercube(h, i, 2 if i < log_n else descending)
+
+    # select top k elements using bitonic top-k
+    # https://www.doc.ic.ac.uk/~hlgr/pdfs/MassivelyParallelTopK.pdf
+    for i in core.static_range(log_k + 1, log_n + 1):
+        h = max(h, axis=(_log2(h.numel) - 1 - log_k)) if descending else min(h, axis=(_log2(h.numel) - 1 - log_k))
+        h = _bitonic_merge_hypercube(h, log_k, 2 if i < log_n else descending)
+
+    # reshape back:
+    x = core.reshape(h, x.shape[:-1] + [2**log_k])
+    return x
+
+
+@jit
+def sort(x, dim: core.constexpr = None, descending: core.constexpr = core.CONSTEXPR_0):
+    return sort_impl(x, dim=dim, descending=descending)
+
+
+@jit
+def topk(x, k: core.constexpr, dim: core.constexpr = None):
+    return sort_impl(x, k=k, dim=dim, descending=True)
+
+
+@jit
+def bitonic_merge(x, dim: core.constexpr = None, descending: core.constexpr = core.CONSTEXPR_0):
+    # handle default dimension or check that it is the most minor dim
+    _dim: core.constexpr = len(x.shape) - 1 if dim is None else dim
+    core.static_assert(_dim == len(x.shape) - 1, "only minor dimension is currently supported")
+    n_dims: core.constexpr = _log2(x.shape[-1])
+    return _bitonic_merge(x, n_dims, descending, n_dims)
+
+
+@constexpr_function
+def _get_flip_dim(dim, shape):
+    if dim is None:
+        dim = len(shape) - 1
+    if dim < 0:  # flip doesn't work if dim < 0 because the xor-swap for loop will start/end at the wrong index
+        dim += len(shape)
+    return dim
+
+
+@core._tensor_member_fn
+@jit
+def flip(x, dim=None):
+    """
+    Flips a tensor `x` along the dimension `dim`.
+
+    :param x: the first input tensor
+    :type x: Block
+    :param dim: the dimension to flip along
+    :type dim: int
+    """
+    core.static_assert(-len(x.shape) <= dim and dim < len(x.shape))
+    _dim: core.constexpr = _get_flip_dim(dim, x.shape)
+    core.static_assert(_is_power_of_two(x.shape[_dim]))
+    steps: core.constexpr = _log2(x.shape[_dim])
+
+    # reshape the swap dimension to (2, 2, ..., 2)
+    idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
+    y = core.reshape(x.to(idtype, bitcast=True), x.shape[:_dim] + [2] * steps + x.shape[_dim + 1:])
+    for i in core.static_range(steps):
+        y = y ^ xor_sum(y, _dim + i, True)
+    x = core.reshape(y, x.shape).to(x.dtype, bitcast=True)
+    return x
+
+
+@jit
+def interleave(a, b):
+    """
+    Interleaves the values of two tensors along their last dimension. The two tensors must have the same shape.
+    Equivalent to `tl.join(a, b).reshape(a.shape[:-1] + [2 * a.shape[-1]])`
+
+    :param a: The first input tensor.
+    :type a: Tensor
+    :param b: The second input tensor.
+    :type b: Tensor
+    """
+    c = core.join(a, b)
+
+    if len(c.shape) == 1:
+        # We must have interleaved two scalars.
+        return c
+    else:
+        # This `else` is necessary because Triton's AST parser doesn't
+        # understand that if we take the `if` above we definitely don't run this
+        # `else`.
+        return core.reshape(c, c.shape[:-2] + [2 * c.shape[-2]])
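For orientation, the sketch below is not part of the packaged diff; the kernel and tensor names are illustrative. It shows how the sort helper added above is typically reached through `tl.sort` from user code, assuming a CUDA device and a power-of-two row length.

import torch
import triton
import triton.language as tl


@triton.jit
def _row_sort_kernel(x_ptr, out_ptr, N: tl.constexpr):
    # one program per row; N must be a power of two for the bitonic network above
    row = tl.program_id(0)
    offs = row * N + tl.arange(0, N)
    x = tl.load(x_ptr + offs)
    y = tl.sort(x, descending=True)  # reaches sort()/sort_impl() defined in this file
    tl.store(out_ptr + offs, y)


x = torch.randn(4, 128, device="cuda")
out = torch.empty_like(x)
_row_sort_kernel[(4, )](x, out, N=128)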
--- /dev/null
+++ triton/language/target_info.py
@@ -0,0 +1,54 @@
+from triton.runtime import driver
+from triton.runtime.jit import constexpr_function
+
+__all__ = ["current_target"]
+
+
+def current_target():
+    try:
+        active_driver = driver.active
+    except RuntimeError:
+        # If there is no active driver, return None
+        return None
+    return active_driver.get_current_target()
+
+
+current_target.__triton_builtin__ = True
+
+
+@constexpr_function
+def is_cuda():
+    target = current_target()
+    return target is not None and target.backend == "cuda"
+
+
+@constexpr_function
+def cuda_capability_geq(major, minor=0):
+    """
+    Determines whether we have compute capability >= (major, minor) and
+    returns this as a constexpr boolean. This can be used for guarding
+    inline asm implementations that require a certain compute capability.
+    """
+    target = current_target()
+    if target is None or target.backend != "cuda":
+        return False
+    assert isinstance(target.arch, int)
+    return target.arch >= major * 10 + minor
+
+
+@constexpr_function
+def is_hip():
+    target = current_target()
+    return target is not None and target.backend == "hip"
+
+
+@constexpr_function
+def is_hip_cdna3():
+    target = current_target()
+    return target is not None and target.arch == "gfx942"
+
+
+@constexpr_function
+def is_hip_cdna4():
+    target = current_target()
+    return target is not None and target.arch == "gfx950"
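As a hedged illustration, again not part of the diff: these constexpr helpers are meant to be folded away at compile time inside jitted code. The hypothetical kernel below assumes the module is importable as listed above and that a CUDA or HIP backend is active.

import triton
import triton.language as tl
from triton.language.target_info import cuda_capability_geq, is_hip


@triton.jit
def _device_probe_kernel(out_ptr):
    # branches are resolved during compilation because the helpers are constexpr functions
    if cuda_capability_geq(9, 0):
        tl.store(out_ptr, 90)  # Hopper or newer
    elif is_hip():
        tl.store(out_ptr, 1)   # AMD backend
    else:
        tl.store(out_ptr, 0)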
--- /dev/null
+++ triton/runtime/__init__.py
@@ -0,0 +1,23 @@
+from .autotuner import (Autotuner, Config, Heuristics, autotune, heuristics)
+from .cache import RedisRemoteCacheBackend, RemoteCacheBackend
+from .driver import driver
+from .jit import JITFunction, KernelInterface, MockTensor, TensorWrapper, reinterpret
+from .errors import OutOfResources, InterpreterError
+
+__all__ = [
+    "autotune",
+    "Autotuner",
+    "Config",
+    "driver",
+    "Heuristics",
+    "heuristics",
+    "InterpreterError",
+    "JITFunction",
+    "KernelInterface",
+    "MockTensor",
+    "OutOfResources",
+    "RedisRemoteCacheBackend",
+    "reinterpret",
+    "RemoteCacheBackend",
+    "TensorWrapper",
+]
--- /dev/null
+++ triton/runtime/_allocation.py
@@ -0,0 +1,44 @@
+from typing import Optional, Protocol
+from contextvars import ContextVar
+
+
+class Buffer(Protocol):
+
+    def data_ptr(self) -> int:
+        ...
+
+
+class Allocator(Protocol):
+
+    def __call__(self, size: int, alignment: int, stream: Optional[int]) -> Buffer:
+        ...
+
+
+class NullAllocator:
+
+    def __call__(self, size: int, alignment: int, stream: Optional[int]) -> Buffer:
+        raise RuntimeError("Kernel requires a runtime memory allocation, but no allocator was set. " +
+                           "Use triton.set_allocator to specify an allocator.")
+
+
+_allocator: ContextVar[Allocator] = ContextVar("_allocator", default=NullAllocator())
+
+
+def set_allocator(allocator: Allocator):
+    """
+    The allocator function is called during kernel launch for kernels that
+    require additional global memory workspace.
+    """
+    _allocator.set(allocator)
+
+
+_profile_allocator: Allocator = ContextVar("_allocator", default=NullAllocator())
+
+
+def set_profile_allocator(allocator: Optional[Allocator]):
+    """
+    The profile allocator function is called before kernel launch for kernels
+    that require additional global memory workspace.
+    """
+    global _profile_allocator
+    _profile_allocator.set(allocator)
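A brief usage sketch, not from the package and with a made-up helper name, of the `triton.set_allocator` entry point that the NullAllocator error message above refers to. The registered callable is stored in the `_allocator` context variable so kernels that need scratch global memory can obtain it.

import torch
import triton


def _torch_alloc(size: int, alignment: int, stream):
    # any object exposing data_ptr() satisfies the Buffer protocol above
    return torch.empty(size, dtype=torch.int8, device="cuda")


triton.set_allocator(_torch_alloc)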
--- /dev/null
+++ triton/runtime/_async_compile.py
@@ -0,0 +1,55 @@
+from __future__ import annotations
+from typing import Callable, Optional
+from concurrent.futures import Executor, as_completed, Future
+from contextvars import ContextVar
+
+active_mode: ContextVar[Optional[AsyncCompileMode]] = ContextVar("async_compile_active_mode", default=None)
+
+
+class FutureKernel:
+
+    def __init__(self, finalize_compile: Callable, future: Future):
+        self.finalize_compile = finalize_compile
+        self.kernel = None
+        self.future = future
+
+    def result(self):
+        if self.kernel is not None:
+            return self.kernel
+
+        kernel = self.future.result()
+        self.finalize_compile(kernel)
+        self.kernel = kernel
+        return kernel
+
+
+class AsyncCompileMode:
+
+    def __init__(self, executor: Executor):
+        self.executor = executor
+        self.raw_futures = []
+        self.future_kernels = {}
+
+    def submit(self, key, compile_fn, finalize_fn):
+        future = self.future_kernels.get(key)
+        if future is not None:
+            return future
+
+        future = self.executor.submit(compile_fn)
+        future._key = key
+        self.raw_futures.append(future)
+        future_kernel = FutureKernel(finalize_fn, future)
+        self.future_kernels[key] = future_kernel
+        return future_kernel
+
+    def __enter__(self):
+        if active_mode.get() is not None:
+            raise RuntimeError("Another AsyncCompileMode is already active")
+        active_mode.set(self)
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        # Finalize any outstanding compiles
+        for future in as_completed(self.raw_futures):
+            self.future_kernels[future._key].result()
+        active_mode.set(None)
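Finally, a hypothetical sketch of driving AsyncCompileMode by hand; in normal use it is exercised by Triton's own kernel-launch machinery, and the compile and finalize callables below are stand-ins.

from concurrent.futures import ThreadPoolExecutor
from triton.runtime._async_compile import AsyncCompileMode


def _compile():
    return "fake-kernel"  # stand-in for an expensive kernel compilation


def _finalize(kernel):
    print("finalized:", kernel)  # e.g. insert the compiled kernel into a cache


with AsyncCompileMode(ThreadPoolExecutor(max_workers=2)) as mode:
    fut = mode.submit(key="demo", compile_fn=_compile, finalize_fn=_finalize)
    # other work can overlap with the background compilation here
print(fut.result())  # returns the cached kernel; finalize ran exactly once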