triton-windows 3.1.0.post17__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of triton-windows might be problematic. Click here for more details.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +73 -0
- triton/backends/__init__.py +50 -0
- triton/backends/amd/compiler.py +262 -0
- triton/backends/amd/driver.c +211 -0
- triton/backends/amd/driver.py +497 -0
- triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +358 -0
- triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +1031 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +1612 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +1337 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +293 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +32 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +174 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +829 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +1809 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +108 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +124 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +405 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +196 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +565 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +2226 -0
- triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +104 -0
- triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +244 -0
- triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +494 -0
- triton/backends/amd/include/hip/amd_detail/concepts.hpp +30 -0
- triton/backends/amd/include/hip/amd_detail/device_library_decls.h +133 -0
- triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +218 -0
- triton/backends/amd/include/hip/amd_detail/grid_launch.h +67 -0
- triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +50 -0
- triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +26 -0
- triton/backends/amd/include/hip/amd_detail/helpers.hpp +137 -0
- triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +1350 -0
- triton/backends/amd/include/hip/amd_detail/hip_assert.h +101 -0
- triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +242 -0
- triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +254 -0
- triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +96 -0
- triton/backends/amd/include/hip/amd_detail/hip_ldg.h +100 -0
- triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +10169 -0
- triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +77 -0
- triton/backends/amd/include/hip/amd_detail/host_defines.h +180 -0
- triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +102 -0
- triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +798 -0
- triton/backends/amd/include/hip/amd_detail/math_fwd.h +698 -0
- triton/backends/amd/include/hip/amd_detail/ockl_image.h +177 -0
- triton/backends/amd/include/hip/amd_detail/program_state.hpp +107 -0
- triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +491 -0
- triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +478 -0
- triton/backends/amd/include/hip/channel_descriptor.h +39 -0
- triton/backends/amd/include/hip/device_functions.h +38 -0
- triton/backends/amd/include/hip/driver_types.h +468 -0
- triton/backends/amd/include/hip/hip_bf16.h +36 -0
- triton/backends/amd/include/hip/hip_bfloat16.h +44 -0
- triton/backends/amd/include/hip/hip_common.h +100 -0
- triton/backends/amd/include/hip/hip_complex.h +38 -0
- triton/backends/amd/include/hip/hip_cooperative_groups.h +46 -0
- triton/backends/amd/include/hip/hip_deprecated.h +95 -0
- triton/backends/amd/include/hip/hip_ext.h +159 -0
- triton/backends/amd/include/hip/hip_fp16.h +36 -0
- triton/backends/amd/include/hip/hip_gl_interop.h +32 -0
- triton/backends/amd/include/hip/hip_hcc.h +24 -0
- triton/backends/amd/include/hip/hip_math_constants.h +36 -0
- triton/backends/amd/include/hip/hip_profile.h +27 -0
- triton/backends/amd/include/hip/hip_runtime.h +75 -0
- triton/backends/amd/include/hip/hip_runtime_api.h +8919 -0
- triton/backends/amd/include/hip/hip_texture_types.h +29 -0
- triton/backends/amd/include/hip/hip_vector_types.h +41 -0
- triton/backends/amd/include/hip/hip_version.h +17 -0
- triton/backends/amd/include/hip/hiprtc.h +421 -0
- triton/backends/amd/include/hip/library_types.h +78 -0
- triton/backends/amd/include/hip/math_functions.h +42 -0
- triton/backends/amd/include/hip/surface_types.h +63 -0
- triton/backends/amd/include/hip/texture_types.h +194 -0
- triton/backends/amd/include/hsa/Brig.h +1131 -0
- triton/backends/amd/include/hsa/amd_hsa_common.h +91 -0
- triton/backends/amd/include/hsa/amd_hsa_elf.h +435 -0
- triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +269 -0
- triton/backends/amd/include/hsa/amd_hsa_queue.h +109 -0
- triton/backends/amd/include/hsa/amd_hsa_signal.h +80 -0
- triton/backends/amd/include/hsa/hsa.h +5729 -0
- triton/backends/amd/include/hsa/hsa_amd_tool.h +91 -0
- triton/backends/amd/include/hsa/hsa_api_trace.h +566 -0
- triton/backends/amd/include/hsa/hsa_ext_amd.h +3090 -0
- triton/backends/amd/include/hsa/hsa_ext_finalize.h +531 -0
- triton/backends/amd/include/hsa/hsa_ext_image.h +1454 -0
- triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +488 -0
- triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +667 -0
- triton/backends/amd/include/roctracer/ext/prof_protocol.h +107 -0
- triton/backends/amd/include/roctracer/hip_ostream_ops.h +4435 -0
- triton/backends/amd/include/roctracer/hsa_ostream_ops.h +1467 -0
- triton/backends/amd/include/roctracer/hsa_prof_str.h +3027 -0
- triton/backends/amd/include/roctracer/roctracer.h +779 -0
- triton/backends/amd/include/roctracer/roctracer_ext.h +81 -0
- triton/backends/amd/include/roctracer/roctracer_hcc.h +24 -0
- triton/backends/amd/include/roctracer/roctracer_hip.h +37 -0
- triton/backends/amd/include/roctracer/roctracer_hsa.h +112 -0
- triton/backends/amd/include/roctracer/roctracer_plugin.h +137 -0
- triton/backends/amd/include/roctracer/roctracer_roctx.h +67 -0
- triton/backends/amd/include/roctracer/roctx.h +229 -0
- triton/backends/amd/lib/ockl.bc +0 -0
- triton/backends/amd/lib/ocml.bc +0 -0
- triton/backends/compiler.py +76 -0
- triton/backends/driver.py +34 -0
- triton/backends/nvidia/__init__.py +0 -0
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +347 -0
- triton/backends/nvidia/driver.c +451 -0
- triton/backends/nvidia/driver.py +430 -0
- triton/backends/nvidia/include/cuda.h +24359 -0
- triton/backends/nvidia/lib/libdevice.10.bc +0 -0
- triton/backends/nvidia/lib/x64/cuda.lib +0 -0
- triton/compiler/__init__.py +4 -0
- triton/compiler/code_generator.py +1302 -0
- triton/compiler/compiler.py +416 -0
- triton/compiler/errors.py +51 -0
- triton/compiler/make_launcher.py +0 -0
- triton/errors.py +5 -0
- triton/language/__init__.py +284 -0
- triton/language/core.py +2621 -0
- triton/language/extra/__init__.py +4 -0
- triton/language/extra/cuda/__init__.py +8 -0
- triton/language/extra/cuda/libdevice.py +1629 -0
- triton/language/extra/cuda/utils.py +109 -0
- triton/language/extra/hip/__init__.py +3 -0
- triton/language/extra/hip/libdevice.py +468 -0
- triton/language/extra/libdevice.py +1213 -0
- triton/language/math.py +250 -0
- triton/language/random.py +207 -0
- triton/language/semantic.py +1621 -0
- triton/language/standard.py +441 -0
- triton/ops/__init__.py +7 -0
- triton/ops/blocksparse/__init__.py +7 -0
- triton/ops/blocksparse/matmul.py +432 -0
- triton/ops/blocksparse/softmax.py +228 -0
- triton/ops/cross_entropy.py +96 -0
- triton/ops/flash_attention.py +466 -0
- triton/ops/matmul.py +219 -0
- triton/ops/matmul_perf_model.py +171 -0
- triton/runtime/__init__.py +23 -0
- triton/runtime/autotuner.py +361 -0
- triton/runtime/build.py +129 -0
- triton/runtime/cache.py +289 -0
- triton/runtime/driver.py +60 -0
- triton/runtime/errors.py +26 -0
- triton/runtime/interpreter.py +1127 -0
- triton/runtime/jit.py +956 -0
- triton/runtime/tcc/include/_mingw.h +170 -0
- triton/runtime/tcc/include/assert.h +57 -0
- triton/runtime/tcc/include/conio.h +409 -0
- triton/runtime/tcc/include/ctype.h +281 -0
- triton/runtime/tcc/include/dir.h +31 -0
- triton/runtime/tcc/include/direct.h +68 -0
- triton/runtime/tcc/include/dirent.h +135 -0
- triton/runtime/tcc/include/dos.h +55 -0
- triton/runtime/tcc/include/errno.h +75 -0
- triton/runtime/tcc/include/excpt.h +123 -0
- triton/runtime/tcc/include/fcntl.h +52 -0
- triton/runtime/tcc/include/fenv.h +108 -0
- triton/runtime/tcc/include/float.h +57 -0
- triton/runtime/tcc/include/inttypes.h +297 -0
- triton/runtime/tcc/include/io.h +418 -0
- triton/runtime/tcc/include/limits.h +111 -0
- triton/runtime/tcc/include/locale.h +91 -0
- triton/runtime/tcc/include/malloc.h +181 -0
- triton/runtime/tcc/include/math.h +737 -0
- triton/runtime/tcc/include/mem.h +13 -0
- triton/runtime/tcc/include/memory.h +40 -0
- triton/runtime/tcc/include/process.h +176 -0
- triton/runtime/tcc/include/sec_api/conio_s.h +42 -0
- triton/runtime/tcc/include/sec_api/crtdbg_s.h +19 -0
- triton/runtime/tcc/include/sec_api/io_s.h +33 -0
- triton/runtime/tcc/include/sec_api/mbstring_s.h +52 -0
- triton/runtime/tcc/include/sec_api/search_s.h +25 -0
- triton/runtime/tcc/include/sec_api/stdio_s.h +145 -0
- triton/runtime/tcc/include/sec_api/stdlib_s.h +67 -0
- triton/runtime/tcc/include/sec_api/stralign_s.h +30 -0
- triton/runtime/tcc/include/sec_api/string_s.h +41 -0
- triton/runtime/tcc/include/sec_api/sys/timeb_s.h +34 -0
- triton/runtime/tcc/include/sec_api/tchar_s.h +266 -0
- triton/runtime/tcc/include/sec_api/time_s.h +61 -0
- triton/runtime/tcc/include/sec_api/wchar_s.h +128 -0
- triton/runtime/tcc/include/setjmp.h +160 -0
- triton/runtime/tcc/include/share.h +28 -0
- triton/runtime/tcc/include/signal.h +63 -0
- triton/runtime/tcc/include/stdarg.h +79 -0
- triton/runtime/tcc/include/stdbool.h +11 -0
- triton/runtime/tcc/include/stddef.h +54 -0
- triton/runtime/tcc/include/stdint.h +212 -0
- triton/runtime/tcc/include/stdio.h +429 -0
- triton/runtime/tcc/include/stdlib.h +580 -0
- triton/runtime/tcc/include/string.h +164 -0
- triton/runtime/tcc/include/sys/fcntl.h +13 -0
- triton/runtime/tcc/include/sys/file.h +14 -0
- triton/runtime/tcc/include/sys/locking.h +30 -0
- triton/runtime/tcc/include/sys/stat.h +290 -0
- triton/runtime/tcc/include/sys/time.h +69 -0
- triton/runtime/tcc/include/sys/timeb.h +133 -0
- triton/runtime/tcc/include/sys/types.h +118 -0
- triton/runtime/tcc/include/sys/unistd.h +14 -0
- triton/runtime/tcc/include/sys/utime.h +146 -0
- triton/runtime/tcc/include/tcc/tcc_libm.h +201 -0
- triton/runtime/tcc/include/tcclib.h +80 -0
- triton/runtime/tcc/include/tchar.h +1102 -0
- triton/runtime/tcc/include/time.h +287 -0
- triton/runtime/tcc/include/vadefs.h +11 -0
- triton/runtime/tcc/include/values.h +4 -0
- triton/runtime/tcc/include/varargs.h +12 -0
- triton/runtime/tcc/include/wchar.h +873 -0
- triton/runtime/tcc/include/wctype.h +172 -0
- triton/runtime/tcc/include/winapi/basetsd.h +149 -0
- triton/runtime/tcc/include/winapi/basetyps.h +85 -0
- triton/runtime/tcc/include/winapi/guiddef.h +156 -0
- triton/runtime/tcc/include/winapi/poppack.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack1.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack2.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack4.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack8.h +8 -0
- triton/runtime/tcc/include/winapi/winbase.h +2951 -0
- triton/runtime/tcc/include/winapi/wincon.h +301 -0
- triton/runtime/tcc/include/winapi/windef.h +293 -0
- triton/runtime/tcc/include/winapi/windows.h +127 -0
- triton/runtime/tcc/include/winapi/winerror.h +3166 -0
- triton/runtime/tcc/include/winapi/wingdi.h +4080 -0
- triton/runtime/tcc/include/winapi/winnt.h +5835 -0
- triton/runtime/tcc/include/winapi/winreg.h +272 -0
- triton/runtime/tcc/include/winapi/winuser.h +5651 -0
- triton/runtime/tcc/include/winapi/winver.h +160 -0
- triton/runtime/tcc/lib/cuda.def +697 -0
- triton/runtime/tcc/lib/gdi32.def +337 -0
- triton/runtime/tcc/lib/kernel32.def +770 -0
- triton/runtime/tcc/lib/libtcc1-64.a +0 -0
- triton/runtime/tcc/lib/msvcrt.def +1399 -0
- triton/runtime/tcc/lib/python3.def +810 -0
- triton/runtime/tcc/lib/user32.def +658 -0
- triton/runtime/tcc/libtcc.dll +0 -0
- triton/runtime/tcc/tcc.exe +0 -0
- triton/testing.py +496 -0
- triton/tools/__init__.py +0 -0
- triton/tools/build_extern.py +365 -0
- triton/tools/compile.c +67 -0
- triton/tools/compile.h +14 -0
- triton/tools/compile.py +145 -0
- triton/tools/disasm.py +142 -0
- triton/tools/link.py +322 -0
- triton/windows_utils.py +373 -0
- triton_windows-3.1.0.post17.dist-info/METADATA +41 -0
- triton_windows-3.1.0.post17.dist-info/RECORD +248 -0
- triton_windows-3.1.0.post17.dist-info/WHEEL +5 -0
- triton_windows-3.1.0.post17.dist-info/top_level.txt +14 -0
|
@@ -0,0 +1,441 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from ..runtime.jit import jit
|
|
4
|
+
from . import core
|
|
5
|
+
from . import math
|
|
6
|
+
|
|
7
|
+
# constexpr utilities (compile-time helpers for Triton metaprogramming)
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def _unwrap_if_constexpr(o):
    """Return the Python value held by a ``core.constexpr``; pass any other object through unchanged."""
    if isinstance(o, core.constexpr):
        return o.value
    return o
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _log2(i: core.constexpr):
    """Return ``floor(log2(i.value))`` wrapped in a ``core.constexpr``.

    The value is halved (right-shifted) until it is no longer greater than 1,
    counting the shifts; inputs <= 1 therefore yield 0.
    """
    remaining = i.value
    exponent = 0
    while remaining > 1:
        remaining = remaining >> 1
        exponent = exponent + 1
    return core.constexpr(exponent)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def _is_power_of_two(i: core.constexpr):
    """Return a ``core.constexpr`` that is truthy iff ``i.value`` is a non-zero power of two."""
    n = i.value
    # Clearing the lowest set bit (n & (n - 1)) leaves zero only for 0 and powers of two.
    lowest_bit_only = (n & (n - 1)) == 0
    return core.constexpr(lowest_bit_only and n != 0)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
# -----------------------
|
|
29
|
+
# Standard library
|
|
30
|
+
# -----------------------
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@core._tensor_member_fn
@jit
def cdiv(x, div):
    """
    Computes the ceiling division of :code:`x` by :code:`div`

    :param x: the input number
    :type x: Block
    :param div: the divisor
    :type div: Block
    """
    # (x + div - 1) // div equals ceil(x / div) when x >= 0 and div > 0.
    return (x + div - 1) // div
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@core._tensor_member_fn
@jit
@math._add_math_1arg_docstr("sigmoid")
def sigmoid(x):
    # Logistic function: 1 / (1 + e^(-x)).
    denominator = 1 + math.exp(-x)
    return 1 / denominator
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
@core._tensor_member_fn
@jit
@math._add_math_1arg_docstr("softmax")
def softmax(x, ieee_rounding=False):
    # `max` and `sum` below resolve to the Triton reductions defined in this
    # module (they shadow the Python builtins), both reducing along axis 0.
    # Subtracting the maximum first keeps every exponent <= 0, so exp() cannot overflow.
    shifted = x - max(x, 0)
    exps = math.exp(shifted)
    total = sum(exps, 0)
    # `ieee_rounding` is forwarded to math.fdiv to select its rounding behaviour.
    return math.fdiv(exps, total, ieee_rounding)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
@core._tensor_member_fn
@jit
def ravel(x):
    """
    Returns a flattened (1-D) view of :code:`x` containing :code:`x.numel` elements.

    The reshape is issued with :code:`can_reorder=True`, which permits the
    compiler to produce the elements in an order different from the
    row-major order of :code:`x`. Use :code:`reshape(x, [x.numel])` when the
    element order matters.

    :param x: the input tensor
    :type x: Block
    """
    return core.reshape(x, [x.numel], can_reorder=True)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
@jit
def swizzle2d(i, j, size_i, size_j, size_g):
    """
    Transforms indices of a row-major :code:`size_i * size_j` matrix into those
    of one where the indices are col-major for each group of :code:`size_g`
    rows.

    For example, for :code:`size_i = size_j = 4` and :code:`size_g = 2`, it will
    transform ::

        [[0 , 1 , 2 , 3 ],
         [4 , 5 , 6 , 7 ],
         [8 , 9 , 10, 11],
         [12, 13, 14, 15]]

    into ::

        [[0, 2, 4 , 6 ],
         [1, 3, 5 , 7 ],
         [8, 10, 12, 14],
         [9, 11, 13, 15]]

    Returns the pair :code:`(new_i, new_j)`.
    """
    # "unrolled index in array": linear row-major position of (i, j)
    ij = i * size_j + j
    # number of elements in `size_g` groups
    # of `size_j` columns
    size_gj = size_g * size_j
    # index of the group in which (i,j) is
    group_id = ij // size_gj
    # row-index of the first element of this group
    off_i = group_id * size_g
    # last group may have fewer rows; `size_g` is rebound to this group's actual row count
    size_g = core.minimum(size_i - off_i, size_g)
    # new row and column indices: walk the group column by column (col-major)
    new_i = off_i + (ij % size_g)
    new_j = (ij % size_gj) // size_g
    return new_i, new_j
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
@jit
def zeros(shape, dtype):
    """
    Returns a tensor filled with the scalar value 0 for the given :code:`shape` and :code:`dtype`.

    :param shape: Shape of the new array, e.g., (8, 16) or (8, )
    :type shape: tuple of ints
    :param dtype: Data-type of the new array, e.g., :code:`tl.float16`
    :type dtype: DType
    """
    # Thin wrapper over core.full with a fill value of 0.
    return core.full(shape, 0, dtype)
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
@jit
def zeros_like(input):
    """
    Creates a tensor of zeros with the same shape and type as a given tensor.
    """
    # Delegate to `zeros`, forwarding the input's shape and dtype by keyword.
    return zeros(shape=input.shape, dtype=input.dtype)
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
# max and argmax
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
@jit
def _argmax_combine(value1, index1, value2, index2, tie_break_left):
    # Combine step of an argmax reduction over (value, index) pairs.
    # Keeps the pair whose value is larger. With `tie_break_left`, equal values
    # keep the pair with the smaller index; otherwise equal values resolve to
    # the second pair (value2, index2).
    # NOTE(review): comparisons involving NaN are false, so a NaN in value1
    # never wins here; confirm this is the intended NaN policy.
    if tie_break_left:
        tie = value1 == value2 and index1 < index2
    else:
        tie = False
    gt = value1 > value2 or tie
    v_ret = core.where(gt, value1, value2)
    i_ret = core.where(gt, index1, index2)
    return v_ret, i_ret
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
@jit
def _argmax_combine_tie_break_left(value1, index1, value2, index2):
    # argmax combine fn: equal maxima resolve to the smaller index.
    return _argmax_combine(value1, index1, value2, index2, tie_break_left=True)
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
@jit
def _argmax_combine_tie_break_fast(value1, index1, value2, index2):
    # argmax combine fn without index comparison: equal maxima resolve to the second pair.
    return _argmax_combine(value1, index1, value2, index2, tie_break_left=False)
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
@jit
def _elementwise_max(a, b):
    # Combine fn for `max` reductions.
    larger = core.maximum(a, b)
    return larger
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
@core._tensor_member_fn
@jit
@core._add_reduction_docstr("maximum", return_indices_arg="return_indices",
                            tie_break_arg="return_indices_tie_break_left")
def max(input, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False):
    # Shared entry point for max and (via return_indices) argmax.
    # bfloat16 input is promoted to float32 before any reduction.
    input = core._promote_bfloat16_to_float32(input)
    if return_indices:
        # Reduce (value, index) pairs; the combine fn chosen here decides how equal maxima are resolved.
        if return_indices_tie_break_left:
            return core._reduce_with_indices(input, axis, _argmax_combine_tie_break_left, keep_dims=keep_dims)
        else:
            return core._reduce_with_indices(input, axis, _argmax_combine_tie_break_fast, keep_dims=keep_dims)
    else:
        # Value-only reduction: widen element types narrower than 32 bits
        # (floating -> float32, integer -> int32) before reducing.
        if core.constexpr(input.dtype.primitive_bitwidth) < core.constexpr(32):
            if core.constexpr(input.dtype.is_floating()):
                input = input.to(core.float32)
            else:
                assert input.dtype.is_int(), "Expecting input to be integer type"
                input = input.to(core.int32)
        return core.reduce(input, axis, _elementwise_max, keep_dims=keep_dims)
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
@core._tensor_member_fn
@jit
@core._add_reduction_docstr("maximum index", tie_break_arg="tie_break_left")
def argmax(input, axis, tie_break_left=True, keep_dims=False):
    # Run `max` with index tracking and keep only the index half of the result.
    _values, indices = max(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left,
                           keep_dims=keep_dims)
    return indices
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
# min and argmin
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
@jit
def _argmin_combine(value1, index1, value2, index2, tie_break_left):
    # Combine step of an argmin reduction over (value, index) pairs.
    # Keeps the pair whose value is smaller. With `tie_break_left`, equal values
    # keep the pair with the smaller index; otherwise equal values resolve to
    # the second pair (value2, index2).
    # NOTE(review): comparisons involving NaN are false, so a NaN in value1
    # never wins here; confirm this is the intended NaN policy.
    if tie_break_left:
        tie = value1 == value2 and index1 < index2
    else:
        tie = False
    lt = value1 < value2 or tie
    value_ret = core.where(lt, value1, value2)
    index_ret = core.where(lt, index1, index2)
    return value_ret, index_ret
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
@jit
def _argmin_combine_tie_break_left(value1, index1, value2, index2):
    # argmin combine fn: equal minima resolve to the smaller index.
    return _argmin_combine(value1, index1, value2, index2, tie_break_left=True)
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
@jit
def _argmin_combine_tie_break_fast(value1, index1, value2, index2):
    # argmin combine fn without index comparison: equal minima resolve to the second pair.
    return _argmin_combine(value1, index1, value2, index2, tie_break_left=False)
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
@jit
def _elementwise_min(a, b):
    # Combine fn for `min` reductions.
    smaller = core.minimum(a, b)
    return smaller
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
@core._tensor_member_fn
@jit
@core._add_reduction_docstr("minimum", return_indices_arg="return_indices",
                            tie_break_arg="return_indices_tie_break_left")
def min(input, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False):
    # Shared entry point for min and (via return_indices) argmin.
    # bfloat16 input is promoted to float32 before any reduction.
    input = core._promote_bfloat16_to_float32(input)
    if return_indices:
        # Reduce (value, index) pairs; the combine fn chosen here decides how equal minima are resolved.
        if return_indices_tie_break_left:
            return core._reduce_with_indices(input, axis, _argmin_combine_tie_break_left, keep_dims=keep_dims)
        else:
            return core._reduce_with_indices(input, axis, _argmin_combine_tie_break_fast, keep_dims=keep_dims)
    else:
        # Value-only reduction: widen element types narrower than 32 bits
        # (floating -> float32, integer -> int32) before reducing. The width
        # check is written as a constexpr-vs-constexpr comparison, exactly as in `max`.
        if core.constexpr(input.dtype.primitive_bitwidth) < core.constexpr(32):
            if core.constexpr(input.dtype.is_floating()):
                input = input.to(core.float32)
            else:
                assert input.dtype.is_int(), "Expecting input to be integer type"
                input = input.to(core.int32)
        return core.reduce(input, axis, _elementwise_min, keep_dims=keep_dims)
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
@core._tensor_member_fn
@jit
@core._add_reduction_docstr("minimum index", tie_break_arg="tie_break_left")
def argmin(input, axis, tie_break_left=True, keep_dims=False):
    # Run `min` with index tracking and keep only the index half of the result.
    _values, indices = min(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left,
                           keep_dims=keep_dims)
    return indices
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
@jit
def _sum_combine(a, b):
    # Combine fn shared by `sum` (reduce) and `cumsum` (scan).
    total = a + b
    return total
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
# sum
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
@core._tensor_member_fn
@jit
@core._add_reduction_docstr("sum")
def sum(input, axis=None, keep_dims=False):
    # bfloat16 input is promoted to float32, then reduced with `_sum_combine`.
    promoted = core._promote_bfloat16_to_float32(input)
    return core.reduce(promoted, axis, _sum_combine, keep_dims=keep_dims)
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
@jit
def _xor_combine(a, b):
    # Combine fn for `xor_sum`: bitwise exclusive-or.
    mixed = a ^ b
    return mixed
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
# xor sum
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
@core._tensor_member_fn
@core.builtin
@core._add_reduction_docstr("xor sum")
def xor_sum(input, axis=None, keep_dims=False, _builder=None, _generator=None):
    # Builtin (not @jit): runs as plain Python at compile time and threads the
    # builder/generator through to core.reduce explicitly.
    if not input.type.scalar.is_int():
        raise ValueError("xor_sum only supported for integers")

    # NOTE(review): input is already known to be an integer type here, so this
    # bfloat16 promotion is presumably a no-op; kept for parity with other reductions.
    promoted = core._promote_bfloat16_to_float32(input, _builder=_builder)
    return core.reduce(promoted, axis, _xor_combine, keep_dims=keep_dims, _builder=_builder, _generator=_generator)
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
# cumsum
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
@core._tensor_member_fn
@jit
@core._add_scan_docstr("cumsum")
def cumsum(input, axis=0, reverse=False):
    # todo rename this to a generic function name
    # Running sum: associative scan with `_sum_combine` after bfloat16 -> float32 promotion.
    promoted = core._promote_bfloat16_to_float32(input)
    return core.associative_scan(promoted, axis, _sum_combine, reverse)
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
# cumprod
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
@jit
def _prod_combine(a, b):
    # Combine fn for `cumprod`: elementwise product.
    product = a * b
    return product
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
@core._tensor_member_fn
@jit
@core._add_scan_docstr("cumprod")
def cumprod(input, axis=0, reverse=False):
    # todo rename this to a generic function name
    # Running product: associative scan with `_prod_combine` after bfloat16 -> float32 promotion.
    promoted = core._promote_bfloat16_to_float32(input)
    return core.associative_scan(promoted, axis, _prod_combine, reverse)
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
# sort
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
@jit
def _compare_and_swap(x, flip, i: core.constexpr, n_dims: core.constexpr):
    # One compare-and-swap round of the bitonic network used by `sort`.
    # `x` is viewed as `x.numel >> n_dims` rows of 2**n_dims elements; each
    # element is paired with the one 2**(n_dims - i - 1) positions away, and a
    # pair is swapped where (left > right) ^ flip is true.
    n_outer: core.constexpr = x.numel >> n_dims
    shape: core.constexpr = [n_outer * 2**i, 2, 2**(n_dims - i - 1)]
    y = core.reshape(x, shape)
    # slice left/right with 'stride' 2**(n_dims - i - 1)
    mask = core.arange(0, 2)[None, :, None]
    # The multiply-by-mask + `sum` trick can widen the element type: `sum`
    # promotes bfloat16 to float32, and arithmetic with the integer mask can
    # widen narrow integers. Cast back to y.dtype so the bitcasts below (which
    # use x's bit width) remain valid; the values themselves are unchanged.
    # NOTE(review): because the partner value is multiplied by 0, inf/NaN
    # partners (inf * 0 = NaN) and signed zeros may not round-trip exactly.
    left = core.broadcast_to(sum(y * (1 - mask), 1)[:, None, :], shape).to(y.dtype)
    right = core.broadcast_to(sum(y * mask, 1)[:, None, :], shape).to(y.dtype)
    left = core.reshape(left, x.shape)
    right = core.reshape(right, x.shape)
    # actual compare-and-swap: on an integer view, x ^ (left ^ right) turns
    # each element of a pair into its partner, i.e. swaps the pair.
    idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
    ileft = left.to(idtype, bitcast=True)
    iright = right.to(idtype, bitcast=True)
    ix = x.to(idtype, bitcast=True)
    ret = ix ^ core.where((left > right) ^ flip, ileft ^ iright, zeros_like(ix))
    return ret.to(x.dtype, bitcast=True)
|
|
340
|
+
|
|
341
|
+
|
|
342
|
+
@jit
def _bitonic_merge(x, stage: core.constexpr, order: core.constexpr, n_dims: core.constexpr):
    '''
    Runs `stage` compare-and-swap rounds of a bitonic merge on `x`, viewed as
    `x.numel >> n_dims` independent rows of 2**n_dims elements.

    `order` selects the direction:
    order 0 == ascending
    order 1 == descending
    order 2 == alternating (runs of 2**stage elements alternate ascending/descending)
    '''
    n_outer: core.constexpr = x.numel >> n_dims
    core.static_assert(stage <= n_dims)
    # flip denotes whether to re-arrange sub-sequences of elements in ascending or
    # descending order.
    # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage
    # if flip = 00110011... then all the elements will be re-arranged alternatingly (with
    # a stride of 2) at this stage
    if order == 2:
        # Build a tensor shaped like x holding 2**stage zeros, then 2**stage ones, repeating.
        shape: core.constexpr = [n_outer * 2**(n_dims - 1 - stage), 2, 2**stage]
        flip = core.reshape(core.broadcast_to(core.arange(0, 2)[None, :, None], shape), x.shape)
    else:
        # Uniform direction: a constexpr 0 (ascending) or 1 (descending).
        flip = order
    # perform `stage` rounds of `compare-and-swap`
    for i in core.static_range(stage):
        x = _compare_and_swap(x, flip, i + (n_dims - stage), n_dims)
    return x
|
|
365
|
+
|
|
366
|
+
|
|
367
|
+
@core._tensor_member_fn
@jit
def sort(x, dim: core.constexpr = None, descending: core.constexpr = core.CONSTEXPR_0):
    """
    Sorts :code:`x` along its last dimension using a bitonic sorting network.

    :param x: the input tensor
    :param dim: the dimension to sort along; only the last (minor) dimension is
        supported, and None selects it
    :param descending: if true, sort in descending order; ascending otherwise (default)
    """
    # handle default dimension or check that it is the most minor dim
    _dim: core.constexpr = len(x.shape) - 1 if dim is None else dim
    core.static_assert(_dim == len(x.shape) - 1, "only minor dimension is currently supported")
    # iteratively run bitonic merge-sort steps
    # NOTE(review): _log2 floors, and nothing here asserts x.shape[_dim] is a power
    # of two; confirm callers guarantee it.
    n_dims: core.constexpr = _log2(x.shape[_dim])
    for i in core.static_range(1, n_dims + 1):
        # Intermediate stages use alternating order (2); the last stage applies the requested direction.
        x = _bitonic_merge(x, i, 2 if i < n_dims else descending, n_dims)
    return x
|
|
378
|
+
|
|
379
|
+
|
|
380
|
+
# flip
|
|
381
|
+
|
|
382
|
+
|
|
383
|
+
def _get_flip_dim(dim, shape):
    """Resolve the dimension `flip` operates on, returned as a ``core.constexpr``.

    ``None`` selects the last dimension; any other value must already name the
    last dimension, otherwise an AssertionError is raised.
    """
    dim = _unwrap_if_constexpr(dim)
    shape = _unwrap_if_constexpr(shape)
    last_dim = len(shape) - 1
    if dim is None:
        dim = last_dim
    assert dim == last_dim, "Currently only support flipping the last dimension"
    return core.constexpr(dim)
|
|
390
|
+
|
|
391
|
+
|
|
392
|
+
@core._tensor_member_fn
@jit
def flip(x, dim=None):
    """
    Flips a tensor `x` along the dimension `dim`.

    :param x: the input tensor
    :type x: Block
    :param dim: the dimension to flip along (currently only final dimension supported)
    :type dim: int
    """
    # Both the flipped dimension and the total element count must be powers of two.
    core.static_assert(_is_power_of_two(x.shape[_get_flip_dim(dim, x.shape)]))
    core.static_assert(_is_power_of_two(x.numel))
    # # reshape the tensor to have all dimensions be 2.
    # # TODO: We shouldn't have to change the dimensions not sorted.
    steps: core.constexpr = _log2(x.numel)
    # First size-2 axis belonging to the flipped (last) dimension: in the
    # [2] * steps view, that dimension's bits are the trailing axes.
    start: core.constexpr = _log2(x.numel) - _log2(x.shape[_get_flip_dim(dim, x.shape)])
    y = core.reshape(x, [2] * steps)
    # Insert a size-1 placeholder axis at `start`; the loop walks it rightwards.
    y = core.expand_dims(y, start)
    # 2x2 anti-identity: flip[a, b] is true iff a == 1 - b.
    flip = (core.arange(0, 2)[:, None] == 1 - core.arange(0, 2))
    for i in core.static_range(start, steps):
        # Broadcast the 2x2 matrix over every axis except i and i + 1.
        flip2 = flip
        for j in core.static_range(0, steps + 1):
            if j != i and j != i + 1:
                flip2 = core.expand_dims(flip2, j)
        # Contracting axis i + 1 against the anti-identity writes the reversed
        # size-2 axis into axis i and leaves the placeholder at i + 1.
        # NOTE(review): this multiplies values by 0, so inf/NaN (inf * 0 = NaN)
        # and signed zeros may not survive unchanged; verify if such inputs matter.
        y = sum(y * flip2, i + 1, keep_dims=True)
    x = core.reshape(y, x.shape)
    return x
|
|
420
|
+
|
|
421
|
+
|
|
422
|
+
@jit
def interleave(a, b):
    """
    Interleaves the values of two tensors along their last dimension.

    The two tensors must have the same shape.

    Equivalent to `tl.join(a, b).reshape(a.shape[:-1] + [2 * a.shape[-1]])`
    """
    # join() stacks a and b along a new trailing axis of size 2.
    c = core.join(a, b)

    assert isinstance(c.shape, list)
    if len(c.shape) == 1:
        # We must have interleaved two scalars.
        return c
    else:
        # This `else` is necessary because Triton's AST parser doesn't
        # understand that if we take the `if` above we definitely don't run this
        # `else`.
        # Merge the last two axes so a[..., k] and b[..., k] become neighbours.
        return core.reshape(c, c.shape[:-2] + [2 * c.shape[-2]])
|
triton/ops/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
# from .conv import _conv, conv
|
|
2
|
+
from . import blocksparse
|
|
3
|
+
from .cross_entropy import _cross_entropy, cross_entropy
|
|
4
|
+
from .flash_attention import attention
|
|
5
|
+
from .matmul import _matmul, get_higher_dtype, matmul
|
|
6
|
+
|
|
7
|
+
# Names exported by `from triton.ops import *`.
__all__ = ["blocksparse", "_cross_entropy", "cross_entropy", "_matmul", "matmul", "attention", "get_higher_dtype"]
|