triton-windows 3.1.0.post17__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic. Click here for more details.

Files changed (248)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +73 -0
  3. triton/backends/__init__.py +50 -0
  4. triton/backends/amd/compiler.py +262 -0
  5. triton/backends/amd/driver.c +211 -0
  6. triton/backends/amd/driver.py +497 -0
  7. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +358 -0
  8. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +1031 -0
  9. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +1612 -0
  10. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +1337 -0
  11. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +293 -0
  12. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +32 -0
  13. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +174 -0
  14. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +829 -0
  15. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +1809 -0
  16. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +108 -0
  17. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +124 -0
  18. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +405 -0
  19. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +196 -0
  20. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +565 -0
  21. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +2226 -0
  22. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +104 -0
  23. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +244 -0
  24. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +494 -0
  25. triton/backends/amd/include/hip/amd_detail/concepts.hpp +30 -0
  26. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +133 -0
  27. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +218 -0
  28. triton/backends/amd/include/hip/amd_detail/grid_launch.h +67 -0
  29. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +50 -0
  30. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +26 -0
  31. triton/backends/amd/include/hip/amd_detail/helpers.hpp +137 -0
  32. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +1350 -0
  33. triton/backends/amd/include/hip/amd_detail/hip_assert.h +101 -0
  34. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +242 -0
  35. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +254 -0
  36. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +96 -0
  37. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +100 -0
  38. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +10169 -0
  39. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +77 -0
  40. triton/backends/amd/include/hip/amd_detail/host_defines.h +180 -0
  41. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +102 -0
  42. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +798 -0
  43. triton/backends/amd/include/hip/amd_detail/math_fwd.h +698 -0
  44. triton/backends/amd/include/hip/amd_detail/ockl_image.h +177 -0
  45. triton/backends/amd/include/hip/amd_detail/program_state.hpp +107 -0
  46. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +491 -0
  47. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +478 -0
  48. triton/backends/amd/include/hip/channel_descriptor.h +39 -0
  49. triton/backends/amd/include/hip/device_functions.h +38 -0
  50. triton/backends/amd/include/hip/driver_types.h +468 -0
  51. triton/backends/amd/include/hip/hip_bf16.h +36 -0
  52. triton/backends/amd/include/hip/hip_bfloat16.h +44 -0
  53. triton/backends/amd/include/hip/hip_common.h +100 -0
  54. triton/backends/amd/include/hip/hip_complex.h +38 -0
  55. triton/backends/amd/include/hip/hip_cooperative_groups.h +46 -0
  56. triton/backends/amd/include/hip/hip_deprecated.h +95 -0
  57. triton/backends/amd/include/hip/hip_ext.h +159 -0
  58. triton/backends/amd/include/hip/hip_fp16.h +36 -0
  59. triton/backends/amd/include/hip/hip_gl_interop.h +32 -0
  60. triton/backends/amd/include/hip/hip_hcc.h +24 -0
  61. triton/backends/amd/include/hip/hip_math_constants.h +36 -0
  62. triton/backends/amd/include/hip/hip_profile.h +27 -0
  63. triton/backends/amd/include/hip/hip_runtime.h +75 -0
  64. triton/backends/amd/include/hip/hip_runtime_api.h +8919 -0
  65. triton/backends/amd/include/hip/hip_texture_types.h +29 -0
  66. triton/backends/amd/include/hip/hip_vector_types.h +41 -0
  67. triton/backends/amd/include/hip/hip_version.h +17 -0
  68. triton/backends/amd/include/hip/hiprtc.h +421 -0
  69. triton/backends/amd/include/hip/library_types.h +78 -0
  70. triton/backends/amd/include/hip/math_functions.h +42 -0
  71. triton/backends/amd/include/hip/surface_types.h +63 -0
  72. triton/backends/amd/include/hip/texture_types.h +194 -0
  73. triton/backends/amd/include/hsa/Brig.h +1131 -0
  74. triton/backends/amd/include/hsa/amd_hsa_common.h +91 -0
  75. triton/backends/amd/include/hsa/amd_hsa_elf.h +435 -0
  76. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +269 -0
  77. triton/backends/amd/include/hsa/amd_hsa_queue.h +109 -0
  78. triton/backends/amd/include/hsa/amd_hsa_signal.h +80 -0
  79. triton/backends/amd/include/hsa/hsa.h +5729 -0
  80. triton/backends/amd/include/hsa/hsa_amd_tool.h +91 -0
  81. triton/backends/amd/include/hsa/hsa_api_trace.h +566 -0
  82. triton/backends/amd/include/hsa/hsa_ext_amd.h +3090 -0
  83. triton/backends/amd/include/hsa/hsa_ext_finalize.h +531 -0
  84. triton/backends/amd/include/hsa/hsa_ext_image.h +1454 -0
  85. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +488 -0
  86. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +667 -0
  87. triton/backends/amd/include/roctracer/ext/prof_protocol.h +107 -0
  88. triton/backends/amd/include/roctracer/hip_ostream_ops.h +4435 -0
  89. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +1467 -0
  90. triton/backends/amd/include/roctracer/hsa_prof_str.h +3027 -0
  91. triton/backends/amd/include/roctracer/roctracer.h +779 -0
  92. triton/backends/amd/include/roctracer/roctracer_ext.h +81 -0
  93. triton/backends/amd/include/roctracer/roctracer_hcc.h +24 -0
  94. triton/backends/amd/include/roctracer/roctracer_hip.h +37 -0
  95. triton/backends/amd/include/roctracer/roctracer_hsa.h +112 -0
  96. triton/backends/amd/include/roctracer/roctracer_plugin.h +137 -0
  97. triton/backends/amd/include/roctracer/roctracer_roctx.h +67 -0
  98. triton/backends/amd/include/roctracer/roctx.h +229 -0
  99. triton/backends/amd/lib/ockl.bc +0 -0
  100. triton/backends/amd/lib/ocml.bc +0 -0
  101. triton/backends/compiler.py +76 -0
  102. triton/backends/driver.py +34 -0
  103. triton/backends/nvidia/__init__.py +0 -0
  104. triton/backends/nvidia/bin/ptxas.exe +0 -0
  105. triton/backends/nvidia/compiler.py +347 -0
  106. triton/backends/nvidia/driver.c +451 -0
  107. triton/backends/nvidia/driver.py +430 -0
  108. triton/backends/nvidia/include/cuda.h +24359 -0
  109. triton/backends/nvidia/lib/libdevice.10.bc +0 -0
  110. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  111. triton/compiler/__init__.py +4 -0
  112. triton/compiler/code_generator.py +1302 -0
  113. triton/compiler/compiler.py +416 -0
  114. triton/compiler/errors.py +51 -0
  115. triton/compiler/make_launcher.py +0 -0
  116. triton/errors.py +5 -0
  117. triton/language/__init__.py +284 -0
  118. triton/language/core.py +2621 -0
  119. triton/language/extra/__init__.py +4 -0
  120. triton/language/extra/cuda/__init__.py +8 -0
  121. triton/language/extra/cuda/libdevice.py +1629 -0
  122. triton/language/extra/cuda/utils.py +109 -0
  123. triton/language/extra/hip/__init__.py +3 -0
  124. triton/language/extra/hip/libdevice.py +468 -0
  125. triton/language/extra/libdevice.py +1213 -0
  126. triton/language/math.py +250 -0
  127. triton/language/random.py +207 -0
  128. triton/language/semantic.py +1621 -0
  129. triton/language/standard.py +441 -0
  130. triton/ops/__init__.py +7 -0
  131. triton/ops/blocksparse/__init__.py +7 -0
  132. triton/ops/blocksparse/matmul.py +432 -0
  133. triton/ops/blocksparse/softmax.py +228 -0
  134. triton/ops/cross_entropy.py +96 -0
  135. triton/ops/flash_attention.py +466 -0
  136. triton/ops/matmul.py +219 -0
  137. triton/ops/matmul_perf_model.py +171 -0
  138. triton/runtime/__init__.py +23 -0
  139. triton/runtime/autotuner.py +361 -0
  140. triton/runtime/build.py +129 -0
  141. triton/runtime/cache.py +289 -0
  142. triton/runtime/driver.py +60 -0
  143. triton/runtime/errors.py +26 -0
  144. triton/runtime/interpreter.py +1127 -0
  145. triton/runtime/jit.py +956 -0
  146. triton/runtime/tcc/include/_mingw.h +170 -0
  147. triton/runtime/tcc/include/assert.h +57 -0
  148. triton/runtime/tcc/include/conio.h +409 -0
  149. triton/runtime/tcc/include/ctype.h +281 -0
  150. triton/runtime/tcc/include/dir.h +31 -0
  151. triton/runtime/tcc/include/direct.h +68 -0
  152. triton/runtime/tcc/include/dirent.h +135 -0
  153. triton/runtime/tcc/include/dos.h +55 -0
  154. triton/runtime/tcc/include/errno.h +75 -0
  155. triton/runtime/tcc/include/excpt.h +123 -0
  156. triton/runtime/tcc/include/fcntl.h +52 -0
  157. triton/runtime/tcc/include/fenv.h +108 -0
  158. triton/runtime/tcc/include/float.h +57 -0
  159. triton/runtime/tcc/include/inttypes.h +297 -0
  160. triton/runtime/tcc/include/io.h +418 -0
  161. triton/runtime/tcc/include/limits.h +111 -0
  162. triton/runtime/tcc/include/locale.h +91 -0
  163. triton/runtime/tcc/include/malloc.h +181 -0
  164. triton/runtime/tcc/include/math.h +737 -0
  165. triton/runtime/tcc/include/mem.h +13 -0
  166. triton/runtime/tcc/include/memory.h +40 -0
  167. triton/runtime/tcc/include/process.h +176 -0
  168. triton/runtime/tcc/include/sec_api/conio_s.h +42 -0
  169. triton/runtime/tcc/include/sec_api/crtdbg_s.h +19 -0
  170. triton/runtime/tcc/include/sec_api/io_s.h +33 -0
  171. triton/runtime/tcc/include/sec_api/mbstring_s.h +52 -0
  172. triton/runtime/tcc/include/sec_api/search_s.h +25 -0
  173. triton/runtime/tcc/include/sec_api/stdio_s.h +145 -0
  174. triton/runtime/tcc/include/sec_api/stdlib_s.h +67 -0
  175. triton/runtime/tcc/include/sec_api/stralign_s.h +30 -0
  176. triton/runtime/tcc/include/sec_api/string_s.h +41 -0
  177. triton/runtime/tcc/include/sec_api/sys/timeb_s.h +34 -0
  178. triton/runtime/tcc/include/sec_api/tchar_s.h +266 -0
  179. triton/runtime/tcc/include/sec_api/time_s.h +61 -0
  180. triton/runtime/tcc/include/sec_api/wchar_s.h +128 -0
  181. triton/runtime/tcc/include/setjmp.h +160 -0
  182. triton/runtime/tcc/include/share.h +28 -0
  183. triton/runtime/tcc/include/signal.h +63 -0
  184. triton/runtime/tcc/include/stdarg.h +79 -0
  185. triton/runtime/tcc/include/stdbool.h +11 -0
  186. triton/runtime/tcc/include/stddef.h +54 -0
  187. triton/runtime/tcc/include/stdint.h +212 -0
  188. triton/runtime/tcc/include/stdio.h +429 -0
  189. triton/runtime/tcc/include/stdlib.h +580 -0
  190. triton/runtime/tcc/include/string.h +164 -0
  191. triton/runtime/tcc/include/sys/fcntl.h +13 -0
  192. triton/runtime/tcc/include/sys/file.h +14 -0
  193. triton/runtime/tcc/include/sys/locking.h +30 -0
  194. triton/runtime/tcc/include/sys/stat.h +290 -0
  195. triton/runtime/tcc/include/sys/time.h +69 -0
  196. triton/runtime/tcc/include/sys/timeb.h +133 -0
  197. triton/runtime/tcc/include/sys/types.h +118 -0
  198. triton/runtime/tcc/include/sys/unistd.h +14 -0
  199. triton/runtime/tcc/include/sys/utime.h +146 -0
  200. triton/runtime/tcc/include/tcc/tcc_libm.h +201 -0
  201. triton/runtime/tcc/include/tcclib.h +80 -0
  202. triton/runtime/tcc/include/tchar.h +1102 -0
  203. triton/runtime/tcc/include/time.h +287 -0
  204. triton/runtime/tcc/include/vadefs.h +11 -0
  205. triton/runtime/tcc/include/values.h +4 -0
  206. triton/runtime/tcc/include/varargs.h +12 -0
  207. triton/runtime/tcc/include/wchar.h +873 -0
  208. triton/runtime/tcc/include/wctype.h +172 -0
  209. triton/runtime/tcc/include/winapi/basetsd.h +149 -0
  210. triton/runtime/tcc/include/winapi/basetyps.h +85 -0
  211. triton/runtime/tcc/include/winapi/guiddef.h +156 -0
  212. triton/runtime/tcc/include/winapi/poppack.h +8 -0
  213. triton/runtime/tcc/include/winapi/pshpack1.h +8 -0
  214. triton/runtime/tcc/include/winapi/pshpack2.h +8 -0
  215. triton/runtime/tcc/include/winapi/pshpack4.h +8 -0
  216. triton/runtime/tcc/include/winapi/pshpack8.h +8 -0
  217. triton/runtime/tcc/include/winapi/winbase.h +2951 -0
  218. triton/runtime/tcc/include/winapi/wincon.h +301 -0
  219. triton/runtime/tcc/include/winapi/windef.h +293 -0
  220. triton/runtime/tcc/include/winapi/windows.h +127 -0
  221. triton/runtime/tcc/include/winapi/winerror.h +3166 -0
  222. triton/runtime/tcc/include/winapi/wingdi.h +4080 -0
  223. triton/runtime/tcc/include/winapi/winnt.h +5835 -0
  224. triton/runtime/tcc/include/winapi/winreg.h +272 -0
  225. triton/runtime/tcc/include/winapi/winuser.h +5651 -0
  226. triton/runtime/tcc/include/winapi/winver.h +160 -0
  227. triton/runtime/tcc/lib/cuda.def +697 -0
  228. triton/runtime/tcc/lib/gdi32.def +337 -0
  229. triton/runtime/tcc/lib/kernel32.def +770 -0
  230. triton/runtime/tcc/lib/libtcc1-64.a +0 -0
  231. triton/runtime/tcc/lib/msvcrt.def +1399 -0
  232. triton/runtime/tcc/lib/python3.def +810 -0
  233. triton/runtime/tcc/lib/user32.def +658 -0
  234. triton/runtime/tcc/libtcc.dll +0 -0
  235. triton/runtime/tcc/tcc.exe +0 -0
  236. triton/testing.py +496 -0
  237. triton/tools/__init__.py +0 -0
  238. triton/tools/build_extern.py +365 -0
  239. triton/tools/compile.c +67 -0
  240. triton/tools/compile.h +14 -0
  241. triton/tools/compile.py +145 -0
  242. triton/tools/disasm.py +142 -0
  243. triton/tools/link.py +322 -0
  244. triton/windows_utils.py +373 -0
  245. triton_windows-3.1.0.post17.dist-info/METADATA +41 -0
  246. triton_windows-3.1.0.post17.dist-info/RECORD +248 -0
  247. triton_windows-3.1.0.post17.dist-info/WHEEL +5 -0
  248. triton_windows-3.1.0.post17.dist-info/top_level.txt +14 -0
@@ -0,0 +1,2621 @@
1
+ from __future__ import annotations
2
+
3
+ from warnings import warn
4
+ from contextlib import contextmanager
5
+ from enum import Enum
6
+ from functools import partial, wraps
7
+ import typing
8
+ from typing import Union, Callable, List, Sequence, TypeVar, Optional
9
+ import builtins
10
+ from ..runtime.jit import jit
11
+ import inspect
12
+ import os
13
+
14
+ from .._C.libtriton import ir
15
+ from . import semantic
16
+
17
# Generic type variable so the decorators below preserve the decorated
# function's static type.
T = TypeVar('T')

# Hard upper bound on the number of elements a Triton tensor may hold;
# block_type.__init__ rejects shapes whose product exceeds this.
TRITON_MAX_TENSOR_NUMEL = 1048576

# Attribute name used to tag functions registered via @builtin;
# queried by is_builtin().
TRITON_BUILTIN = "__triton_builtin__"

# Re-export of the IR enum that controls NaN propagation.
PropagateNan = ir.PROPAGATE_NAN
24
+
25
+
26
def builtin(fn: T) -> T:
    """Register *fn* as a Triton builtin.

    The returned wrapper refuses to run unless a non-None ``_builder``
    keyword argument is supplied (i.e. the call happens inside a JIT
    context), and is tagged with the TRITON_BUILTIN attribute so that
    :py:func:`is_builtin` recognizes it.
    """
    assert callable(fn)

    @wraps(fn)
    def wrapper(*args, **kwargs):
        # Missing key and an explicit None are treated identically.
        if kwargs.get("_builder") is None:
            raise ValueError("Did you forget to add @triton.jit ? "
                             "(`_builder` argument must be provided outside of JIT functions.)")
        return fn(*args, **kwargs)

    setattr(wrapper, TRITON_BUILTIN, True)

    return wrapper
40
+
41
+
42
def _tensor_member_fn(fn: T) -> T:
    """Decorator that adds this free function as a member fn on class tensor.

    When called as a member function on class tensor, the first argument to `fn`
    is `self`, i.e. the tensor object.

    If there are multiple decorators on a function, you probably want this one
    to be the highest one (i.e. furthest from the function's `def`), so it's
    applied last.

    Unfortunately you still need to add a type stub to the body of class tensor
    in order for pytype to know about it.
    """
    assert callable(fn)
    sig = inspect.signature(fn)
    # True when fn takes arguments besides _builder, _generator and the
    # tensor itself.
    has_args = len(sig.parameters.keys() - {"_builder", "_generator"}) > 1

    fn.__doc__ = (fn.__doc__ or "") + f"""
    This function can also be called as a member function on :py:class:`tensor`,
    as :code:`x.{fn.__name__}({"..." if has_args else ""})` instead of
    :code:`{fn.__name__}(x{", ..." if has_args else ""})`.
    """

    def member(*args, **kwargs):
        return fn(*args, **kwargs)

    # Present the first parameter as `self` in the rendered signature so
    # the generated docs read naturally for the method form.
    params = list(sig.parameters.values())
    params[0] = params[0].replace(name='self')
    member.__signature__ = sig.replace(parameters=params)
    member.__doc__ = f"Forwards to :py:func:`{fn.__name__}` free function"
    # Propagate builtin-ness onto the member wrapper.
    if is_builtin(fn):
        setattr(member, TRITON_BUILTIN, True)

    setattr(tensor, fn.__name__, member)
    return fn
84
+
85
+
86
+ def _unwrap_iterable(x):
87
+ """Returns x[0] if x has one element and x[0] is iterable."""
88
+ if len(x) == 1:
89
+ # Determine whether x[0] is iterable.
90
+ #
91
+ # You might want to use collections.abc.Iterable instead of this
92
+ # try/except block. Unfortunately, this doesn't work with constexpr.
93
+ #
94
+ # The problem is that abc.Iterable checks for __iter__ on the *class*.
95
+ # But we want constexpr to expose an __iter__ method if and only if the
96
+ # wrapped *object* (i.e. self.value) is iterable. Therefore there's no
97
+ # right answer for whether the class constexpr defines __iter__, and
98
+ # abc.Iterable doesn't work (at least not without some metaclass magic).
99
+ try:
100
+ iter(x[0])
101
+ return x[0]
102
+ except TypeError:
103
+ pass
104
+
105
+ return x
106
+
107
+
108
def is_builtin(fn) -> bool:
    """Is this a registered triton builtin function?"""
    try:
        return getattr(fn, TRITON_BUILTIN)
    except AttributeError:
        return False
111
+
112
+
113
@builtin
def to_tensor(x, _builder=None):
    """Convert *x* (a Python scalar, constexpr, or tensor) to a tensor."""
    converted = _to_tensor(x, _builder)
    return converted
116
+
117
+
118
def _to_tensor(x, builder):
    """Materialize a Python value as a Triton tensor.

    Args:
        x: a bool, int, float, constexpr, or tensor.
        builder: the active ir.builder used to create constant handles.

    Returns:
        A tensor wrapping the IR constant, or ``x`` itself when it is
        already a tensor.

    Raises:
        RuntimeError: if an int does not fit in 64 bits (signed or unsigned).
    """
    if isinstance(x, bool):
        return tensor(builder.get_int1(x), int1)
    # Note: compile-time const integers are represented by unsigned values
    elif isinstance(x, int):
        if -2**31 <= x < 2**31:
            return tensor(builder.get_int32(x), int32)
        elif 2**31 <= x < 2**32:
            return tensor(builder.get_uint32(x), uint32)
        elif -2**63 <= x < 2**63:
            return tensor(builder.get_int64(x), int64)
        elif 2**63 <= x < 2**64:
            return tensor(builder.get_uint64(x), uint64)
        else:
            raise RuntimeError(f'Nonrepresentable integer {x}.')
    elif isinstance(x, float):
        min_float32 = 2**-126
        max_float32 = (2 - 2**-23) * 2**127
        # Use the imported `builtins` module explicitly: `abs` is shadowed
        # later in this module by the Triton builtin of the same name, and
        # subscripting `__builtins__` (the previous approach) relies on a
        # CPython implementation detail — it is bound to the builtins
        # *module*, not a dict, in `__main__`.
        abs_x = builtins.abs(x)
        # inf, zero, NaN (x != x), and anything in normal fp32 range fit
        # in fp32; everything else (fp32 subnormals / overflow) needs fp64.
        if abs_x == float("inf") or\
           abs_x == 0.0 or \
           x != x or \
           min_float32 <= abs_x <= max_float32:
            return tensor(builder.get_fp32(x), float32)
        else:
            return tensor(builder.get_fp64(x), float64)

    elif isinstance(x, constexpr):
        # Unwrap compile-time constants and retry.
        return _to_tensor(x.value, builder)
    elif isinstance(x, tensor):
        return x
    assert False, f"cannot convert {x} of type {type(x)} to tensor"
150
+
151
+
152
class dtype:
    """Descriptor for a Triton scalar element type.

    Instances are identified purely by ``name`` (see ``__eq__``/``__hash__``).
    Integer types carry ``int_signedness``/``int_bitwidth``; floating-point
    types carry ``fp_mantissa_width``, ``primitive_bitwidth`` and
    ``exponent_bias``.  Subclasses (pointer_type, block_type, function_type)
    add structure on top of this.
    """
    SINT_TYPES = ['int8', 'int16', 'int32', 'int64']
    # NOTE: 'int1' (the boolean type) is grouped with the unsigned types.
    UINT_TYPES = ['int1', 'uint8', 'uint16', 'uint32', 'uint64']
    FP_TYPES = ['fp8e4b15', 'fp8e4nv', 'fp8e4b8', 'fp8e5', 'fp8e5b16', 'fp16', 'bf16', 'fp32', 'fp64']
    STANDARD_FP_TYPES = ['fp16', 'bf16', 'fp32', 'fp64']
    OTHER_TYPES = ['void']

    class SIGNEDNESS(Enum):
        # Signedness of the integer interpretation of a type.
        SIGNED = 0
        UNSIGNED = 1

    def __init__(self, name):
        # ``name`` may be an enum-like object carrying the real name in
        # its ``value`` attribute.
        if hasattr(name, 'value'):
            name = name.value
        self.name = name
        assert name in dtype.SINT_TYPES + dtype.UINT_TYPES + dtype.FP_TYPES + dtype.OTHER_TYPES, name
        if name in dtype.SINT_TYPES:
            self.int_signedness = dtype.SIGNEDNESS.SIGNED
            self.int_bitwidth = int(name.split('int')[-1])
            self.primitive_bitwidth = self.int_bitwidth
        elif name in dtype.UINT_TYPES:
            self.int_signedness = dtype.SIGNEDNESS.UNSIGNED
            self.int_bitwidth = int(name.split('int')[-1])
            self.primitive_bitwidth = self.int_bitwidth
        elif name in dtype.FP_TYPES:
            # Per-format table: (stored mantissa bits, total bits, exponent bias).
            if name == 'fp8e4b15':
                self.fp_mantissa_width = 3
                self.primitive_bitwidth = 8
                self.exponent_bias = 15
            elif name == 'fp8e4nv':
                self.fp_mantissa_width = 3
                self.primitive_bitwidth = 8
                self.exponent_bias = 7
            elif name == 'fp8e4b8':
                self.fp_mantissa_width = 3
                self.primitive_bitwidth = 8
                self.exponent_bias = 8
            elif name == 'fp8e5':
                self.fp_mantissa_width = 2
                self.primitive_bitwidth = 8
                self.exponent_bias = 15
            elif name == 'fp8e5b16':
                self.fp_mantissa_width = 2
                self.primitive_bitwidth = 8
                self.exponent_bias = 16
            elif name == 'fp16':
                self.fp_mantissa_width = 10
                self.primitive_bitwidth = 16
                self.exponent_bias = 15
            elif name == 'bf16':
                self.fp_mantissa_width = 7
                self.primitive_bitwidth = 16
                self.exponent_bias = 127
            elif name == 'fp32':
                self.fp_mantissa_width = 23
                self.primitive_bitwidth = 32
                self.exponent_bias = 127
            elif name == 'fp64':
                # NOTE(review): 53 counts the implicit leading bit, whereas
                # fp16/bf16/fp32 above use the *stored* mantissa width
                # (10/7/23) — confirm this asymmetry is intended.
                self.fp_mantissa_width = 53
                self.primitive_bitwidth = 64
                self.exponent_bias = 1023
            else:
                raise RuntimeError(f'Unsupported floating-point type {name}')
        elif name == 'void':
            self.primitive_bitwidth = 0

    # --- name-based predicates -------------------------------------------
    def is_fp8(self):
        return 'fp8' in self.name

    def is_fp8e4nv(self):
        return self.name == 'fp8e4nv'

    def is_fp8e4b8(self):
        return self.name == 'fp8e4b8'

    def is_fp8e4b15(self):
        return self.name == 'fp8e4b15'

    def is_fp8e5(self):
        return self.name == 'fp8e5'

    def is_fp8e5b16(self):
        return self.name == 'fp8e5b16'

    def is_fp16(self):
        return self.name == 'fp16'

    def is_bf16(self):
        return self.name == 'bf16'

    def is_fp32(self):
        return self.name == 'fp32'

    def is_fp64(self):
        return self.name == 'fp64'

    def is_int1(self):
        return self.name == 'int1'

    def is_int8(self):
        return self.name == 'int8'

    def is_int16(self):
        return self.name == 'int16'

    def is_int32(self):
        return self.name == 'int32'

    def is_int64(self):
        return self.name == 'int64'

    def is_uint8(self):
        return self.name == 'uint8'

    def is_uint16(self):
        return self.name == 'uint16'

    def is_uint32(self):
        return self.name == 'uint32'

    def is_uint64(self):
        return self.name == 'uint64'

    def is_floating(self):
        return self.name in dtype.FP_TYPES

    def is_standard_floating(self):
        # fp8 variants are excluded; only fp16/bf16/fp32/fp64 qualify.
        return self.name in dtype.STANDARD_FP_TYPES

    def is_int_signed(self):
        return self.name in dtype.SINT_TYPES

    def is_int_unsigned(self):
        # NOTE: includes 'int1' (see UINT_TYPES above).
        return self.name in dtype.UINT_TYPES

    def is_int(self):
        return self.name in dtype.SINT_TYPES + dtype.UINT_TYPES

    def is_bool(self):
        return self.is_int1()

    @staticmethod
    def is_dtype(type_str):
        """Is ``type_str`` the canonical name of any scalar dtype?"""
        return type_str in dtype.SINT_TYPES + dtype.UINT_TYPES + dtype.FP_TYPES + dtype.OTHER_TYPES

    @staticmethod
    def is_void():
        # NOTE(review): unlike the other predicates this unconditionally
        # raises; callers are apparently not expected to reach it.
        raise RuntimeError("Not implemented")

    @staticmethod
    def is_block():
        # Overridden to True in block_type.
        return False

    @staticmethod
    def is_ptr():
        # Overridden to True in pointer_type.
        return False

    @staticmethod
    def is_const():
        # Overridden to True in const_pointer_type.
        return False

    def __eq__(self, other: dtype):
        # Equality is purely name-based for scalar dtypes.
        if not isinstance(other, dtype):
            return False
        return self.name == other.name

    def __ne__(self, other: dtype):
        return not self.__eq__(other)

    def __hash__(self):
        return hash((self.name, ))

    @property
    def scalar(self):
        # A scalar dtype is its own scalar; block_type overrides this.
        return self

    def to_ir(self, builder: ir.builder) -> ir.type:
        """Lower this dtype to the corresponding IR type via ``builder``.

        Signed and unsigned integers of the same width map to the same
        IR integer type; signedness only exists at the frontend level.
        """
        if self.name == 'void':
            return builder.get_void_ty()
        elif self.name == 'int1':
            return builder.get_int1_ty()
        elif self.name in ('int8', 'uint8'):
            return builder.get_int8_ty()
        elif self.name in ('int16', 'uint16'):
            return builder.get_int16_ty()
        elif self.name in ('int32', 'uint32'):
            return builder.get_int32_ty()
        elif self.name in ('int64', 'uint64'):
            return builder.get_int64_ty()
        elif self.name == 'fp8e5':
            return builder.get_fp8e5_ty()
        elif self.name == 'fp8e5b16':
            return builder.get_fp8e5b16_ty()
        elif self.name == 'fp8e4nv':
            return builder.get_fp8e4nv_ty()
        elif self.name == 'fp8e4b8':
            return builder.get_fp8e4b8_ty()
        elif self.name == 'fp8e4b15':
            return builder.get_fp8e4b15_ty()
        elif self.name == 'fp16':
            return builder.get_half_ty()
        elif self.name == 'bf16':
            return builder.get_bf16_ty()
        elif self.name == 'fp32':
            return builder.get_float_ty()
        elif self.name == 'fp64':
            return builder.get_double_ty()
        raise ValueError(f'fail to convert {self} to ir type')

    def __str__(self):
        return self.name

    def codegen_name(self):
        """Name used in generated Python code: 'fp32' -> 'float32', etc."""
        if self.name.startswith("fp"):
            return "float" + self.name[2:]
        elif self.name.startswith("bf"):
            return "bfloat" + self.name[2:]
        else:
            return self.name

    @property
    def cache_key_part(self) -> str:
        """See cache_key_part() in triton.cc."""
        return self.name

    def __repr__(self):
        """Output of repr needs to be an evaluatable expression"""
        return f'triton.language.{self.codegen_name()}'
380
+
381
+
382
# Some functions have a param named `dtype`, which shadows the `dtype` class.
# We can't change the param name because it is part of function's public API.
# Declare an alias so those functions can still reference the dtype class.
_DtypeClass = dtype
386
+
387
+
388
class pointer_type(dtype):
    """Type of a pointer to values of ``element_ty``.

    Args:
        element_ty: the pointee dtype.
        address_space: target address space number (defaults to 1).
    """

    def __init__(self, element_ty: dtype, address_space: int = 1):
        if not isinstance(element_ty, dtype):
            raise TypeError(f'element_ty is a {type(element_ty).__name__}.')
        self.element_ty = element_ty
        self.address_space = address_space

        # NOTE: the name (and hence __str__/__eq__ on dtype) does not
        # encode the address space; __eq__ below compares it explicitly.
        self.name = f'pointer<{element_ty}>'

    def to_ir(self, builder: ir.builder) -> ir.pointer_type:
        # Fix: honor the stored address space instead of hard-coding 1.
        # Behavior is unchanged for the default address_space == 1.
        return builder.get_ptr_ty(self.element_ty.to_ir(builder), self.address_space)

    def __str__(self):
        return self.name

    def __repr__(self):
        return self.__str__()

    def is_ptr(self):
        return True

    def __eq__(self, other: pointer_type) -> bool:
        if not isinstance(other, pointer_type):
            return False
        return self.element_ty == other.element_ty and self.address_space == other.address_space

    def __ne__(self, other: pointer_type) -> bool:
        return not self.__eq__(other)

    @property
    def scalar(self):
        # A pointer is a scalar (non-block) type.
        return self
421
+
422
+
423
class const_pointer_type(pointer_type):
    """Pointer type whose pointee is read-only (`store` rejects it)."""

    def __init__(self, element_ty: dtype, address_space: int = 1):
        super().__init__(element_ty, address_space)

    def __str__(self):
        return f'const_pointer<{self.element_ty}>'

    def is_const(self):
        return True

    def __eq__(self, other) -> bool:
        # A const pointer never equals a plain pointer_type, even with
        # the same pointee and address space.
        return (isinstance(other, const_pointer_type)
                and self.element_ty == other.element_ty
                and self.address_space == other.address_space)
438
+
439
+
440
class block_type(dtype):
    """An N-dimensional tensor type: ``shape`` elements of ``element_ty``."""

    def __init__(self, element_ty: dtype, shape: List):
        self.element_ty = element_ty

        # Note that block_type's shape is a list of int
        # while tensor's shape is a list of constexpr.

        # shape can be empty ([]) when an input is a 0D tensor.
        if not shape:
            raise TypeError('0d block_type is forbidden')
        # Normalize a constexpr-valued shape into plain ints.
        if isinstance(shape[0], constexpr):
            shape = [dim.value for dim in shape]

        self.shape = shape
        numel = 1
        for dim in shape:
            numel *= dim
        self.numel = numel
        if numel > TRITON_MAX_TENSOR_NUMEL:
            raise ValueError(f"numel ({self.numel}) exceeds triton maximum tensor numel ({TRITON_MAX_TENSOR_NUMEL})")

        self.name = f'<{self.shape}, {self.element_ty}>'

    def to_ir(self, builder: ir.builder) -> ir.block_type:
        return builder.get_block_ty(self.element_ty.to_ir(builder), self.shape)

    def __str__(self):
        return self.name

    def __repr__(self):
        return self.__str__()

    def is_block(self):
        return True

    def get_block_shapes(self) -> List[int]:
        return self.shape

    def __eq__(self, other: block_type) -> bool:
        # Equal iff both element type and shape match.
        return (isinstance(other, block_type)
                and self.element_ty == other.element_ty
                and self.shape == other.shape)

    def __ne__(self, other: block_type) -> bool:
        return not self.__eq__(other)

    @property
    def scalar(self):
        # The scalar of a block is its element type.
        return self.element_ty
489
+
490
+
491
class function_type(dtype):
    """Signature type of a Triton function: parameter and return dtypes."""

    def __init__(self, ret_types: List[dtype], param_types: List[dtype]) -> None:
        self.ret_types = ret_types
        self.param_types = param_types

    def __str__(self):
        return f'fn ({self.param_types}) -> {self.ret_types}'

    def to_ir(self, builder: ir.builder):
        # Lower every parameter and return dtype, then build the IR
        # function type from the two lists.
        params = [ty.to_ir(builder) for ty in self.param_types]
        returns = [ret.to_ir(builder) for ret in self.ret_types]
        return builder.get_function_ty(params, returns)
504
+
505
+
506
# scalar types
void = dtype('void')
int1 = dtype('int1')  # 1-bit integer; doubles as the boolean type
int8 = dtype('int8')
int16 = dtype('int16')
int32 = dtype('int32')
int64 = dtype('int64')
uint8 = dtype('uint8')
uint16 = dtype('uint16')
uint32 = dtype('uint32')
uint64 = dtype('uint64')
# 8-bit floating-point variants (different mantissa/exponent splits and biases;
# see the table in dtype.__init__)
float8e5 = dtype('fp8e5')
float8e5b16 = dtype('fp8e5b16')
float8e4nv = dtype('fp8e4nv')
float8e4b8 = dtype('fp8e4b8')
float8e4b15 = dtype('fp8e4b15')
float16 = dtype('fp16')
bfloat16 = dtype('bf16')
float32 = dtype('fp32')
float64 = dtype('fp64')
# pointer types
pi32_t = pointer_type(int32)  # convenience alias: pointer-to-int32
528
+
529
+
530
def get_int_dtype(bitwidth: int, signed: bool) -> dtype:
    """Return the integer dtype with the given width and signedness.

    A bitwidth of 1 always maps to int1 (signedness is irrelevant for
    the boolean type).

    Raises:
        ValueError: if no integer type with that width exists.
    """
    if bitwidth == 1:
        return int1
    table = {
        (8, True): int8, (8, False): uint8,
        (16, True): int16, (16, False): uint16,
        (32, True): int32, (32, False): uint32,
        (64, True): int64, (64, False): uint64,
    }
    try:
        return table[(bitwidth, bool(signed))]
    except KeyError:
        raise ValueError(f'Unsupported bitwidth {bitwidth} and signedness {signed}') from None
551
+
552
+
553
+ # -----------------------
554
+ # constexpr
555
+ # -----------------------
556
+
557
+
558
class const:
    """Annotation marker for pointers to read-only data.

    Annotating a pointer argument with ``const`` makes its pointee
    immutable: the `store` function rejects such pointers. Constness is
    part of the pointer type and follows the usual Triton type
    consistency rules — for example, a function must not return a
    constant pointer from one return statement and a non-constant
    pointer from another.
    """
567
+
568
+
569
class constexpr:
    """
    This class is used to store a value that is known at compile-time.

    All arithmetic/comparison/bitwise operators are forwarded to the wrapped
    Python value and the result is re-wrapped in a new `constexpr`, so
    compile-time expressions stay compile-time.
    """

    def __init__(self, value):
        # Unwrap nested constexpr so `.value` always holds the raw value.
        if isinstance(value, constexpr):
            self.value = value.value
        else:
            self.value = value

    def __repr__(self) -> str:
        return f"constexpr[{self.value}]"

    def __index__(self):
        # Lets a constexpr be used where Python requires an integer
        # (sequence indexing, range(), etc.).
        return self.value

    # In interpreter mode, constant values are not wrapped in constexpr,
    # and therefore do not have a .value attribute.
    # As a result, from here and below, we need to call the _constexpr_to_value
    # function to obtain either constexpr.value or the value itself.
    def __add__(self, other):
        return constexpr(self.value + _constexpr_to_value(other))

    def __radd__(self, other):
        return constexpr(_constexpr_to_value(other) + self.value)

    def __sub__(self, other):
        return constexpr(self.value - _constexpr_to_value(other))

    def __rsub__(self, other):
        return constexpr(_constexpr_to_value(other) - self.value)

    def __mul__(self, other):
        return constexpr(self.value * _constexpr_to_value(other))

    def __mod__(self, other):
        return constexpr(self.value % _constexpr_to_value(other))

    def __rmul__(self, other):
        return constexpr(_constexpr_to_value(other) * self.value)

    def __truediv__(self, other):
        return constexpr(self.value / _constexpr_to_value(other))

    def __rtruediv__(self, other):
        return constexpr(_constexpr_to_value(other) / self.value)

    def __floordiv__(self, other):
        return constexpr(self.value // _constexpr_to_value(other))

    def __rfloordiv__(self, other):
        return constexpr(_constexpr_to_value(other) // self.value)

    def __gt__(self, other):
        return constexpr(self.value > _constexpr_to_value(other))

    def __rgt__(self, other):
        return constexpr(_constexpr_to_value(other) > self.value)

    def __ge__(self, other):
        return constexpr(self.value >= _constexpr_to_value(other))

    def __rge__(self, other):
        return constexpr(_constexpr_to_value(other) >= self.value)

    def __lt__(self, other):
        return constexpr(self.value < _constexpr_to_value(other))

    def __rlt__(self, other):
        return constexpr(_constexpr_to_value(other) < self.value)

    def __le__(self, other):
        return constexpr(self.value <= _constexpr_to_value(other))

    def __rle__(self, other):
        return constexpr(_constexpr_to_value(other) <= self.value)

    # NOTE(review): defining __eq__ without __hash__ makes constexpr
    # instances unhashable (Python sets __hash__ to None) — confirm intended.
    def __eq__(self, other):
        return constexpr(self.value == _constexpr_to_value(other))

    def __ne__(self, other):
        return constexpr(self.value != _constexpr_to_value(other))

    def __bool__(self):
        # Truthiness of the wrapped value, so constexpr works in `if`/`while`.
        return bool(self.value)

    def __neg__(self):
        return constexpr(-self.value)

    def __and__(self, other):
        return constexpr(self.value & _constexpr_to_value(other))

    def logical_and(self, other):
        # Python `and`, not bitwise: returns one of the operands, wrapped.
        return constexpr(self.value and _constexpr_to_value(other))

    def __or__(self, other):
        return constexpr(self.value | _constexpr_to_value(other))

    def __xor__(self, other):
        return constexpr(self.value ^ _constexpr_to_value(other))

    def logical_or(self, other):
        # Python `or`, not bitwise: returns one of the operands, wrapped.
        return constexpr(self.value or _constexpr_to_value(other))

    def __pos__(self):
        return constexpr(+self.value)

    def __invert__(self):
        return constexpr(~self.value)

    def __pow__(self, other):
        return constexpr(self.value**_constexpr_to_value(other))

    def __rpow__(self, other):
        return constexpr(_constexpr_to_value(other)**self.value)

    def __rshift__(self, other):
        return constexpr(self.value >> _constexpr_to_value(other))

    def __lshift__(self, other):
        return constexpr(self.value << _constexpr_to_value(other))

    def __not__(self):
        # Not a real Python dunder; invoked explicitly by the AST visitor.
        return constexpr(not self.value)

    def __iter__(self):
        # Iterate the wrapped value directly (e.g. constexpr tuples/lists).
        return iter(self.value)

    def __call__(self, *args, **kwds):
        # Allow calling a constexpr-wrapped callable transparently.
        return self.value(*args, **kwds)
700
+
701
+
702
CONSTEXPR_0 = constexpr(0)  # shared compile-time zero (used e.g. as `sort`'s default `descending`)
703
+
704
+
705
def check_bit_width(value, shift_value):
    """Warn when a constexpr shift amount is >= the operand's bit width.

    Shifting by at least the operand width is undefined behavior; only the
    (tensor, constexpr) combination can be checked at compile time, so any
    other argument combination is silently accepted.
    """
    if not (isinstance(value, tensor) and isinstance(shift_value, constexpr)):
        return
    width = value.type.scalar.primitive_bitwidth
    if shift_value.value >= width:
        warn(
            f"Value {shift_value.value} exceeds the maximum bitwidth ({width}) for type '{value.dtype}'. This may result in undefined behavior."
        )
712
+
713
+
714
class tensor:
    """Represents an N-dimensional array of values or pointers.

    :code:`tensor` is the fundamental data structure in Triton programs. Most
    functions in :py:mod:`triton.language` operate on and return tensors.

    Most of the named member functions here are duplicates of the free functions
    in :code:`triton.language`. For example, :code:`triton.language.sqrt(x)` is
    equivalent to :code:`x.sqrt()`.

    :code:`tensor` also defines most of the magic/dunder methods, so you can
    write :code:`x+y`, :code:`x << 2`, etc.

    .. rubric:: Constructors
    ..
       For some reason Sphinx includes __init__ before printing the full table
       of methods. Not what I want, but I can't figure out how to fix it. Give
       it its own section so it looks intentional. :)
    """

    def __init__(self, handle, type: dtype):
        """Not called by user code."""
        # IR handle
        self.handle = handle
        # Block shape
        self.shape = type.shape if type.is_block() else ()
        # numel = product of all block dimensions, wrapped in constexpr.
        self.numel = 1
        for s in self.shape:
            self.numel *= s
        self.numel = constexpr(self.numel)
        self.type = type  # Tensor type (can be block_type)
        # Following the practice in pytorch, dtype is scalar type
        self.dtype = type.scalar
        self.shape = [constexpr(s) for s in self.shape]

    def __str__(self) -> str:
        # ex. "float32[16, 32]"
        return str(self.dtype) + '[' + ', '.join(str(s) for s in self.shape) + ']'

    # Arithmetic operators: the non-tensor operand is first promoted via
    # _to_tensor, then the operation is delegated to the semantic module.
    @builtin
    def __add__(self, other, _builder=None):
        other = _to_tensor(other, _builder)
        return semantic.add(self, other, _builder)

    @builtin
    def __radd__(self, other, _builder=None):
        return self.__add__(other, _builder=_builder)

    @builtin
    def __sub__(self, other, _builder=None):
        other = _to_tensor(other, _builder)
        return semantic.sub(self, other, _builder)

    @builtin
    def __rsub__(self, other, _builder=None):
        other = _to_tensor(other, _builder)
        return semantic.sub(other, self, _builder)

    @builtin
    def __mul__(self, other, _builder=None):
        other = _to_tensor(other, _builder)
        return semantic.mul(self, other, _builder)

    @builtin
    def __rmul__(self, other, _builder=None):
        return self.__mul__(other, _builder=_builder)

    @builtin
    def __truediv__(self, other, _builder=None):
        other = _to_tensor(other, _builder)
        return semantic.truediv(self, other, _builder)

    @builtin
    def __rtruediv__(self, other, _builder=None):
        other = _to_tensor(other, _builder)
        return semantic.truediv(other, self, _builder)

    @builtin
    def __floordiv__(self, other, _builder=None):
        other = _to_tensor(other, _builder)
        return semantic.floordiv(self, other, _builder)

    @builtin
    def __rfloordiv__(self, other, _builder=None):
        other = _to_tensor(other, _builder)
        return semantic.floordiv(other, self, _builder)

    @builtin
    def __mod__(self, other, _builder=None):
        other = _to_tensor(other, _builder)
        return semantic.mod(self, other, _builder)

    @builtin
    def __rmod__(self, other, _builder=None):
        other = _to_tensor(other, _builder)
        return semantic.mod(other, self, _builder)

    # unary operators
    @builtin
    def __neg__(self, _builder=None):
        return semantic.minus(self, _builder)

    @builtin
    def __invert__(self, _builder=None):
        return semantic.invert(self, _builder)

    # bitwise operators

    @builtin
    def __and__(self, other, _builder=None):
        other = _to_tensor(other, _builder)
        return semantic.and_(self, other, _builder)

    @builtin
    def __rand__(self, other, _builder=None):
        other = _to_tensor(other, _builder)
        return semantic.and_(other, self, _builder)

    @builtin
    def __or__(self, other, _builder=None):
        other = _to_tensor(other, _builder)
        return semantic.or_(self, other, _builder)

    @builtin
    def __ror__(self, other, _builder=None):
        other = _to_tensor(other, _builder)
        return semantic.or_(other, self, _builder)

    @builtin
    def __xor__(self, other, _builder=None):
        other = _to_tensor(other, _builder)
        return semantic.xor_(self, other, _builder)

    @builtin
    def __rxor__(self, other, _builder=None):
        other = _to_tensor(other, _builder)
        return semantic.xor_(other, self, _builder)

    # Shifts first warn (via check_bit_width) when a constexpr shift amount
    # reaches the operand bit width, which would be undefined behavior.
    @builtin
    def __lshift__(self, other, _builder=None):
        check_bit_width(self, other)
        other = _to_tensor(other, _builder)
        return semantic.shl(self, other, _builder)

    @builtin
    def __rlshift__(self, other, _builder=None):
        check_bit_width(other, self)
        other = _to_tensor(other, _builder)
        return semantic.shl(other, self, _builder)

    @builtin
    def __rshift__(self, other, _builder=None):
        check_bit_width(self, other)
        other = _to_tensor(other, _builder)
        # Arithmetic shift for signed ints, logical shift for unsigned.
        if self.dtype.is_int_signed():
            return semantic.ashr(self, other, _builder)
        else:
            return semantic.lshr(self, other, _builder)

    @builtin
    def __rrshift__(self, other, _builder=None):
        check_bit_width(other, self)
        other = _to_tensor(other, _builder)
        if self.dtype.is_int_signed():
            return semantic.ashr(other, self, _builder)
        else:
            return semantic.lshr(other, self, _builder)

    # >
    @builtin
    def __gt__(self, other, _builder=None):
        other = _to_tensor(other, _builder)
        return semantic.greater_than(self, other, _builder)

    @builtin
    def __rgt__(self, other, _builder=None):
        other = _to_tensor(other, _builder)
        return semantic.greater_than(other, self, _builder)

    # >=
    @builtin
    def __ge__(self, other, _builder=None):
        other = _to_tensor(other, _builder)
        return semantic.greater_equal(self, other, _builder)

    @builtin
    def __rge__(self, other, _builder=None):
        other = _to_tensor(other, _builder)
        return semantic.greater_equal(other, self, _builder)

    # <
    @builtin
    def __lt__(self, other, _builder=None):
        other = _to_tensor(other, _builder)
        return semantic.less_than(self, other, _builder)

    @builtin
    def __rlt__(self, other, _builder=None):
        other = _to_tensor(other, _builder)
        return semantic.less_than(other, self, _builder)

    # <=
    @builtin
    def __le__(self, other, _builder=None):
        other = _to_tensor(other, _builder)
        return semantic.less_equal(self, other, _builder)

    @builtin
    def __rle__(self, other, _builder=None):
        other = _to_tensor(other, _builder)
        return semantic.less_equal(other, self, _builder)

    # ==
    @builtin
    def __eq__(self, other, _builder=None):
        other = _to_tensor(other, _builder)
        return semantic.equal(self, other, _builder)

    @builtin
    def __req__(self, other, _builder=None):
        other = _to_tensor(other, _builder)
        return semantic.equal(other, self, _builder)

    @builtin
    def __ne__(self, other, _builder=None):
        other = _to_tensor(other, _builder)
        return semantic.not_equal(self, other, _builder)

    @builtin
    def __rne__(self, other, _builder=None):
        other = _to_tensor(other, _builder)
        return semantic.not_equal(other, self, _builder)

    @builtin
    def logical_and(self, other, _builder=None):
        other = _to_tensor(other, _builder)
        return semantic.logical_and(self, other, _builder)

    @builtin
    def logical_or(self, other, _builder=None):
        other = _to_tensor(other, _builder)
        return semantic.logical_or(self, other, _builder)

    # note: __not__ isn't actually a magic method in python
    # but it's ok because our ASTVisitor handles it
    @builtin
    def __not__(self, _builder=None):
        return semantic.not_(self, _builder)

    @builtin
    def __getitem__(self, slices, _builder=None):
        """Supports only `None` (new axis) and full `:` slices per dimension."""
        if isinstance(slices, (slice, constexpr)) or slices is None:
            slices = [slices]
        ret = self
        for dim, sl in enumerate(slices):
            if sl is None or isinstance(sl, constexpr) and sl.value is None:
                # None inserts a new length-1 dimension at this position.
                ret = semantic.expand_dims(ret, dim, _builder)
            elif isinstance(sl, slice) and sl.start is None and sl.stop is None and sl.step is None:
                # A bare `:` keeps the dimension as-is.
                pass
            else:
                raise ValueError(f"unsupported tensor index: {sl}")
        return ret

    @property
    def T(self):
        """Transposes a 2D tensor."""
        # Never executed at runtime; the AST visitor rewrites `.T` accesses.
        assert False, "Transposition must be created by the AST Visitor"

    @builtin
    def to(self, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: bool = False, _builder=None):
        """
        Alias for :py:func:`tensor.cast`.
        """
        # Triton doesn't like core functions calling other core functions, so we
        # just copy-paste the implementation of cast here.  It's not too bad.
        if isinstance(bitcast, constexpr):
            bitcast = bitcast.value
        if bitcast:
            return semantic.bitcast(self, dtype, _builder)
        return semantic.cast(self, dtype, _builder, fp_downcast_rounding)

    # Type stubs for functions added by the _tensor_member_fn decorator.
    # (Unfortunately these can't be created automatically.)
    #
    # We couldn't write these definitions out even if we wanted to, because some
    # of these functions are defined in standard.py.
    def broadcast_to(self, *shape) -> tensor:
        ...

    def trans(self, *dims) -> tensor:
        ...

    def permute(self, *dims) -> tensor:
        ...

    def split(self) -> tuple[tensor, tensor]:
        ...

    def view(self, *shape) -> tensor:
        ...

    def reshape(self, *shape) -> tensor:
        ...

    def expand_dims(self, axis) -> tensor:
        ...

    def cast(self, dtype, fp_downcast_rounding=None, bitcast=False) -> tensor:
        ...

    def store(self, value, mask=None, boundary_check=(), cache_modifier="", eviction_policy="") -> tensor:
        ...

    def advance(self, offsets) -> tensor:
        ...

    def atomic_cas(self, cmp, val, sem=None, scope=None) -> tensor:
        ...

    def atomic_xchg(self, val, mask=None, sem=None, scope=None) -> tensor:
        ...

    def atomic_add(self, val, mask=None, sem=None, scope=None) -> tensor:
        ...

    def atomic_max(self, val, mask=None, sem=None, scope=None) -> tensor:
        ...

    def atomic_min(self, val, mask=None, sem=None, scope=None) -> tensor:
        ...

    def atomic_and(self, val, mask=None, sem=None, scope=None) -> tensor:
        ...

    def atomic_or(self, val, mask=None, sem=None, scope=None) -> tensor:
        ...

    def atomic_xor(self, val, mask=None, sem=None, scope=None) -> tensor:
        ...

    def exp(self) -> tensor:
        ...

    def log(self) -> tensor:
        ...

    def cos(self) -> tensor:
        ...

    def sin(self) -> tensor:
        ...

    def sqrt(self) -> tensor:
        ...

    def rsqrt(self) -> tensor:
        ...

    def abs(self) -> tensor:
        ...

    def reduce(self, axis, combine_fn, keep_dims=False) -> tensor:
        ...

    def associative_scan(self, axis, combine_fn, reverse=False) -> tensor:
        ...

    def histogram(self, num_bins) -> tensor:
        ...

    def cdiv(self, div) -> tensor:
        ...

    def sigmoid(self) -> tensor:
        ...

    def softmax(self, ieee_rounding=False) -> tensor:
        ...

    def ravel(self) -> tensor:
        ...

    def max(self, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False) -> tensor:
        ...

    def argmax(self, axis, tie_break_left=True, keep_dims=False) -> tensor:
        ...

    def min(self, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False) -> tensor:
        ...

    def argmin(self, axis, tie_break_left=True, keep_dims=False) -> tensor:
        ...

    def sum(self, axis=None, keep_dims=False) -> tensor:
        ...

    def xor_sum(self, axis=None, keep_dims=False) -> tensor:
        ...

    def cumsum(self, axis=0, reverse=False) -> tensor:
        ...

    def cumprod(self, axis=0, reverse=False) -> tensor:
        ...

    def sort(self, dim: constexpr = None, descending: constexpr = CONSTEXPR_0) -> tensor:
        ...

    def flip(self, dim=None) -> tensor:
        ...
1124
+ ...
1125
+
1126
+
1127
def get_bool_env_var(var_name):
    """Return True iff environment variable *var_name* is "1", "true" or "on".

    Matching is case-sensitive; an unset variable defaults to "0" (False).
    """
    return os.getenv(var_name, "0") in ("1", "true", "on")
1130
+
1131
+
1132
+ # -----------------------
1133
+ # SPMD Programming Model
1134
+ # -----------------------
1135
def _constexpr_to_value(v):
    # Interpreter mode passes raw Python values instead of constexpr
    # wrappers, so unwrap only when a wrapper is actually present.
    return v.value if isinstance(v, constexpr) else v
1139
+
1140
+
1141
@builtin
def program_id(axis, _builder=None):
    """
    Returns the id of the current program instance along the given :code:`axis`.

    :param axis: The axis of the 3D launch grid. Must be 0, 1 or 2.
    :type axis: int
    """
    return semantic.program_id(_constexpr_to_value(axis), _builder)
1158
+
1159
+
1160
@builtin
def num_programs(axis, _builder=None):
    """
    Returns the number of program instances launched along the given :code:`axis`.

    :param axis: The axis of the 3D launch grid. Must be 0, 1 or 2.
    :type axis: int
    """
    return semantic.num_programs(_constexpr_to_value(axis), _builder)
1170
+
1171
+
1172
+ # -----------------------
1173
+ # Block Initialization
1174
+ # -----------------------
1175
+
1176
+
1177
@builtin
def arange(start, end, _builder=None):
    """
    Returns contiguous values within the half-open interval :code:`[start,
    end)`.  :code:`end - start` must be less than or equal to
    :code:`TRITON_MAX_TENSOR_NUMEL = 131072`

    :param start: Start of the interval. Must be a power of two.
    :type start: int32
    :param end: End of the interval. Must be a power of two greater than
        :code:`start`.
    :type end: int32
    """
    lo = _constexpr_to_value(start)
    hi = _constexpr_to_value(end)
    return semantic.arange(lo, hi, _builder)
1193
+
1194
+
1195
def _shape_check_impl(shape):
    """Validate a block shape and return it as a list of plain ints.

    Every dimension must be a (possibly `constexpr`-wrapped) int whose value
    is a power of 2.

    :raises TypeError: if an element is not a `constexpr` or wraps a non-int.
    :raises ValueError: if an element's value is not a power of 2.
    """
    shape = _constexpr_to_value(shape)
    for i, d in enumerate(shape):
        if isinstance(d, int):
            d = constexpr(d)
        if not isinstance(d, constexpr):
            raise TypeError(f"Shape element {i} must have type `constexpr`")
        if not isinstance(d.value, int):
            # Fixed: the closing backtick was missing from this message.
            raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]`")
        # value & (value - 1) == 0 iff value is a power of 2 (or 0).
        if d.value & (d.value - 1) != 0:
            raise ValueError(f"Shape element {i} must be a power of 2")
    return [_constexpr_to_value(x) for x in shape]
1207
+
1208
+
1209
@builtin
def full(shape, value, dtype, _builder=None):
    """
    Returns a tensor filled with the scalar value for the given :code:`shape` and :code:`dtype`.

    :param shape: Shape of the new array, e.g., (8, 16) or (8, )
    :type shape: tuple of ints
    :param value: A scalar value to fill the array with
    :param dtype: Data-type of the new array, e.g., :code:`tl.float16`
    :type dtype: DType
    """
    checked_shape = _shape_check_impl(shape)
    fill_value = _constexpr_to_value(value)
    element_ty = _constexpr_to_value(dtype)
    return semantic.full(checked_shape, fill_value, element_ty, _builder)
1224
+
1225
+
1226
+ # -----------------------
1227
+ # Shape Manipulation
1228
+ # -----------------------
1229
+
1230
+
1231
@builtin
def broadcast(input, other, _builder=None):
    """
    Tries to broadcast the two given blocks to a common compatible shape.

    :param input: The first input tensor.
    :type input: Block
    :param other: The second input tensor.
    :type other: Block
    """
    # All shape-compatibility logic lives in semantic.broadcast_impl_value.
    return semantic.broadcast_impl_value(input, other, _builder)
1242
+
1243
+
1244
@_tensor_member_fn
@builtin
def broadcast_to(input, *shape, _builder=None):
    """
    Tries to broadcast the given tensor to a new :code:`shape`.

    :param input: The input tensor.
    :type input: Block
    :param shape: The desired shape.

    :code:`shape` can be passed as a tuple or as individual parameters: ::

        # These are equivalent
        broadcast_to(x, (32, 32))
        broadcast_to(x, 32, 32)
    """
    target_shape = _shape_check_impl(_unwrap_iterable(shape))
    return semantic.broadcast_impl_shape(input, target_shape, _builder)
1263
+
1264
+
1265
@_tensor_member_fn
@builtin
def trans(input: tensor, *dims, _builder=None):
    """
    Permutes the dimensions of a tensor.

    If no permutation is specified, tries to do a (1,0) permutation, i.e. tries
    to transpose a 2D tensor.

    :param input: The input tensor.
    :param dims: The desired ordering of dimensions.  For example,
        :code:`(2, 1, 0)` reverses the order dims in a 3D tensor.

    :code:`dims` can be passed as a tuple or as individual parameters: ::

        # These are equivalent
        trans(x, (2, 1, 0))
        trans(x, 2, 1, 0)

    :py:func:`permute` is equivalent to this function, except it doesn't
    have the special case when no permutation is specified.
    """
    # Unwrap the single-tuple calling form so `trans(x, (2, 1, 0))` behaves
    # as documented, matching what `permute` does. Varargs and empty calls
    # pass through unchanged.
    dims = _unwrap_iterable(dims)
    if not dims:
        dims = (1, 0)
    return semantic.permute(input, dims, _builder)
1290
+
1291
+
1292
@_tensor_member_fn
@builtin
def permute(input, *dims, _builder=None):
    """
    Permutes the dimensions of a tensor.

    :param input: The input tensor.
    :type input: Block
    :param dims: The desired ordering of dimensions.  For example,
        :code:`(2, 1, 0)` reverses the order dims in a 3D tensor.

    :code:`dims` can be passed as a tuple or as individual parameters: ::

        # These are equivalent
        permute(x, (2, 1, 0))
        permute(x, 2, 1, 0)

    :py:func:`trans` is equivalent to this function, except when
    :code:`dims` is empty, it tries to do a (1,0) permutation.
    """
    order = _unwrap_iterable(dims)
    return semantic.permute(input, order, _builder)
1314
+
1315
+
1316
@builtin
def cat(input, other, can_reorder=False, _builder=None):
    """
    Concatenate the given blocks

    :param input: The first input tensor.
    :type input:
    :param other: The second input tensor.
    :type other:
    :param can_reorder: Compiler hint. If true, the compiler is
        allowed to reorder elements while concatenating inputs.  Only use if the
        order does not matter (e.g., result is only used in reduction ops)
    """
    return semantic.cat(input, other, can_reorder, _builder)
1330
+
1331
+
1332
@builtin
def join(a, b, _builder=None):
    """
    Join the given tensors in a new, minor dimension.

    For example, given two tensors of shape (4,8), produces a new tensor of
    shape (4,8,2).  Given two scalars, returns a tensor of shape (2).

    The two inputs are broadcasted to be the same shape.

    If you want to join more than two elements, you can use multiple calls to
    this function.  This reflects the constraint in Triton that tensors must
    have power-of-two sizes.

    join is the inverse of split.

    :param a: The first input tensor.
    :type a: Tensor
    :param b: The second input tensor.
    :type b: Tensor
    """
    # Broadcasting and the actual stacking happen inside semantic.join.
    return semantic.join(a, b, _builder)
1354
+
1355
+
1356
@jit
def _take_first(a, b):
    # Combine function that keeps the first operand; used by `split` (via
    # `reduce`) to collapse a shape-[1] tensor into a scalar.
    return a
1359
+
1360
+
1361
@_tensor_member_fn
@builtin
def split(a, _builder=None, _generator=None) -> tuple[tensor, tensor]:
    """
    Split a tensor in two along its last dim, which must have size 2.

    For example, given a tensor of shape (4,8,2), produces two tensors of shape
    (4,8).  Given a tensor of shape (2), returns two scalars.

    If you want to split into more than two pieces, you can use multiple calls
    to this function (probably plus calling reshape).  This reflects the
    constraint in Triton that tensors must have power-of-two sizes.

    split is the inverse of join.

    :param a: The tensor to split.
    :type a: Tensor
    :returns: the two halves as a (lhs, rhs) pair.
    """
    # If len(a.shape) == 1, i.e. a.shape == [2], we should return two scalars.
    # But semantic.split can only handle returning tensors.  Work around this by
    # expanding the input to shape [1,2] and then reducing the result.
    was_rank_1 = len(a.shape) == 1
    if was_rank_1:
        a = semantic.expand_dims(a, 0, _builder)

    out_lhs, out_rhs = semantic.split(a, _builder)

    if was_rank_1:
        # Currently `reduce` is the best way to convert a tensor of shape [1] to a scalar.
        out_lhs = typing.cast(tensor, reduce(out_lhs, None, _take_first, _builder=_builder, _generator=_generator))
        out_rhs = typing.cast(tensor, reduce(out_rhs, None, _take_first, _builder=_builder, _generator=_generator))

    return out_lhs, out_rhs
1394
+
1395
+
1396
@_tensor_member_fn
@builtin
def view(input, *shape, _builder=None):
    """
    Returns a tensor with the same elements as `input` but a different shape.
    The order of the elements may not be preserved.

    .. deprecated::
        Use :py:func:`reshape` with :code:`can_reorder=True` instead.

    :param input: The input tensor.
    :type input: Block
    :param shape: The desired shape.

    :code:`shape` can be passed as a tuple or as individual parameters: ::

        # These are equivalent
        view(x, (32, 32))
        view(x, 32, 32)
    """
    warn("view is deprecated, please use reshape with can_reorder being true.")
    new_shape = _shape_check_impl(_unwrap_iterable(shape))
    return semantic.reshape(input, new_shape, can_reorder=True, builder=_builder)
1416
+
1417
+
1418
@_tensor_member_fn
@builtin
def reshape(input, *shape, can_reorder=False, _builder=None):
    """
    Returns a tensor with the same number of elements as input but with the
    provided shape.

    :param input: The input tensor.
    :type input: Block
    :param shape: The new shape.
    :param can_reorder: Compiler hint; if true, the elements' order need not
        be preserved.

    :code:`shape` can be passed as a tuple or as individual parameters: ::

        # These are equivalent
        reshape(x, (32, 32))
        reshape(x, 32, 32)
    """
    shape = _shape_check_impl(_unwrap_iterable(shape))
    return semantic.reshape(input, shape, can_reorder, _builder)
1437
+
1438
+
1439
+ def _wrap_axis(axis, ndim):
1440
+ if not (-ndim <= axis < ndim):
1441
+ raise ValueError(f"invalid axis {axis}. Expected {-ndim} <= axis < {ndim}")
1442
+
1443
+ return axis if axis >= 0 else axis + ndim
1444
+
1445
+
1446
@_tensor_member_fn
@builtin
def expand_dims(input, axis, _builder=None):
    """
    Expand the shape of a tensor, by inserting new length-1 dimensions.

    Axis indices are with respect to the resulting tensor, so
    ``result.shape[axis]`` will be 1 for each axis.

    :param input: The input tensor.
    :type input: tl.tensor
    :param axis: The indices to add new axes
    :type axis: int | Sequence[int]

    """
    input = _to_tensor(input, _builder)
    axis = _constexpr_to_value(axis)
    axes = list(axis) if isinstance(axis, Sequence) else [axis]
    # Axes are interpreted relative to the *result* rank.
    new_ndim = len(input.shape) + len(axes)
    axes = [_wrap_axis(_constexpr_to_value(d), new_ndim) for d in axes]

    if len(set(axes)) != len(axes):
        raise ValueError(f"expand_dims received duplicate axes, normalized axes = {axes}")

    # Insert in ascending order so each later index is still valid after the
    # preceding insertions.
    ret = input
    for a in sorted(axes):
        ret = semantic.expand_dims(ret, a, _builder)
    return ret
1474
+
1475
+
1476
@_tensor_member_fn
@builtin
def cast(input, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: bool = False, _builder=None):
    """
    Casts a tensor to the given :code:`dtype`.

    :param dtype: The target data type.
    :param fp_downcast_rounding: The rounding mode for downcasting
        floating-point values.  This parameter is only used when self is a
        floating-point tensor and dtype is a floating-point type with a
        smaller bitwidth.  Supported values are :code:`"rtne"` (round to
        nearest, ties to even) and :code:`"rtz"` (round towards zero).
    :param bitcast: If true, the tensor is bitcasted to the given
        :code:`dtype`, instead of being numerically casted.
    """
    src = _to_tensor(input, _builder)
    do_bitcast = bitcast.value if isinstance(bitcast, constexpr) else bitcast
    if do_bitcast:
        return semantic.bitcast(src, dtype, _builder)
    return semantic.cast(src, dtype, _builder, fp_downcast_rounding)
1497
+
1498
+
1499
+ # -----------------------
1500
+ # Linear Algebra
1501
+ # -----------------------
1502
+
1503
+
1504
@builtin
def dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_imprecise_acc=None, out_dtype=float32,
        _builder=None):
    """
    Returns the matrix product of two blocks.

    The two blocks must be two-dimensional and have compatible inner dimensions.

    :param input: The first tensor to be multiplied.
    :type input: 2D tensor of scalar-type in {:code:`int8`, :code: `float8_e5m2`, :code:`float16`, :code:`bfloat16`, :code:`float32`}
    :param other: The second tensor to be multiplied.
    :type other: 2D tensor of scalar-type in {:code:`int8`, :code: `float8_e5m2`, :code:`float16`, :code:`bfloat16`, :code:`float32`}
    :param input_precision: How to exercise the Tensor Cores for f32 x f32.  If
        the device does not have Tensor Cores or the inputs are not of dtype f32,
        this option is ignored.  For devices that do have tensor cores, the
        default precision is tf32.
    :type input_precision: string. Available options for nvidia: :code:`"tf32"`, :code:`"tf32x3"`, :code:`"ieee"`. Default: :code:`"tf32"`. Avaliable options for amd: :code:`"ieee"`.
    :param allow_tf32: *Deprecated.*  If true, input_precision is set to "tf32".
        Only one of :code:`input_precision` and :code:`allow_tf32` can be
        specified (i.e. at least one must be :code:`None`).
    """
    assert input_precision is None or allow_tf32 is None, "Only one of input_precision and allow_tf32 can be specified"
    if input_precision is None:
        # Default to tf32 only when the backend supports it and the deprecated
        # allow_tf32 flag doesn't forbid it; the TRITON_F32_DEFAULT environment
        # variable overrides the computed default.
        supports_tf32 = _builder and "tf32" in _builder.options.allowed_dot_input_precisions
        default_precision = "tf32" if (supports_tf32 and (allow_tf32 or allow_tf32 is None)) else "ieee"
        input_precision = os.getenv("TRITON_F32_DEFAULT", default_precision)

    input_precision = _constexpr_to_value(input_precision)
    out_dtype = _constexpr_to_value(out_dtype)
    max_num_imprecise_acc = _constexpr_to_value(max_num_imprecise_acc)
    return semantic.dot(input, other, acc, input_precision, max_num_imprecise_acc, out_dtype, _builder)
1535
+
1536
+
1537
+ # -----------------------
1538
+ # Non-Atomic Memory Operations
1539
+ # -----------------------
1540
+
1541
+
1542
@builtin
def load(pointer, mask=None, other=None, boundary_check=(), padding_option="", cache_modifier="", eviction_policy="",
         volatile=False, _builder=None):
    """
    Return a tensor of data loaded from memory at the locations described by `pointer`:

    (1) If `pointer` is a single element pointer, a scalar is loaded. In
        this case:

        - `mask` and `other` must also be scalars,
        - `other` is implicitly typecast to `pointer.dtype.element_ty`, and
        - `boundary_check` and `padding_option` must be empty.

    (2) If `pointer` is an N-dimensional tensor of pointers, an
        N-dimensional tensor is loaded. In this case:

        - `mask` and `other` are implicitly broadcast to `pointer.shape`,
        - `other` is implicitly typecast to `pointer.dtype.element_ty`, and
        - `boundary_check` and `padding_option` must be empty.

    (3) If `pointer` is a block pointer defined by `make_block_ptr`, a
        tensor is loaded. In this case:

        - `mask` and `other` must be None, and
        - `boundary_check` and `padding_option` control out-of-bound
          behavior.

    :param pointer: Pointer to the data to be loaded
    :type pointer: `triton.PointerType`, or block of `dtype=triton.PointerType`
    :param mask: if `mask[idx]` is false, do not load the data at address `pointer[idx]`
        (must be `None` with block pointers)
    :type mask: Block of `triton.int1`, optional
    :param other: if `mask[idx]` is false, return `other[idx]`
    :type other: Block, optional
    :param boundary_check: tuple of integers, indicating the dimensions which should do the boundary check
    :type boundary_check: tuple of ints, optional
    :param padding_option: should be one of {"", "zero", "nan"}, do padding while out of bound
    :param cache_modifier: changes cache option in NVIDIA PTX
    :type cache_modifier: str, optional
    :param eviction_policy: changes eviction policy in NVIDIA PTX
    :type eviction_policy: str, optional
    :param volatile: changes volatile option in NVIDIA PTX
    :type volatile: bool, optional
    """
    # `mask` and `other` may arrive as constexpr wrappers; unwrap first, then
    # lift any non-None value into a tensor.
    mask = _constexpr_to_value(mask)
    other = _constexpr_to_value(other)
    mask = _to_tensor(mask, _builder) if mask is not None else None
    other = _to_tensor(other, _builder) if other is not None else None
    padding_option = _constexpr_to_value(padding_option)
    cache_modifier = _constexpr_to_value(cache_modifier)
    eviction_policy = _constexpr_to_value(eviction_policy)
    volatile = _constexpr_to_value(volatile)
    return semantic.load(pointer, mask, other, boundary_check, padding_option, cache_modifier, eviction_policy,
                         volatile, _builder)
1599
+
1600
+
1601
@builtin
def _experimental_descriptor_load(desc_pointer, offsets, shape, dtype, _builder=None):
    """
    Experimental feature to access TMA descriptors loads. This is an escape hatch to easily exercise TTGIR operations.
    This will be removed in the future and shouldn't be used in production code.

    Loads a tensor described by the descriptor at the given offsets.
    """
    # Renamed from `type` to avoid shadowing the builtin.
    ret_type = block_type(dtype, shape)
    return semantic.descriptor_load(desc_pointer, offsets, "", "", ret_type, _builder)
1611
+
1612
+
1613
@builtin
def _experimental_descriptor_store(desc_pointer, value, offsets, _builder=None):
    """
    Experimental feature to access TMA descriptors stores. This is an escape hatch to easily exercise TTGIR operations.
    This will be removed in the future and shouldn't be used in production code.

    Stores the given tensor at the location described by the descriptor and offsets.
    """
    return semantic.descriptor_store(desc_pointer, value, offsets, _builder)
1622
+
1623
+
1624
@_tensor_member_fn
@builtin
def store(pointer, value, mask=None, boundary_check=(), cache_modifier="", eviction_policy="", _builder=None):
    """
    Store a tensor of data into the memory locations described by `pointer`.

    (1) If `pointer` is a single element pointer, a scalar is stored. In
        this case:

        - `mask` must also be scalar, and
        - `boundary_check` and `padding_option` must be empty.

    (2) If `pointer` is an N-dimensional tensor of pointers, an
        N-dimensional block is stored. In this case:

        - `mask` is implicitly broadcast to `pointer.shape`, and
        - `boundary_check` must be empty.

    (3) If `pointer` is a block pointer defined by `make_block_ptr`, a block
        of data is stored. In this case:

        - `mask` must be None, and
        - `boundary_check` controls out-of-bound behavior.

    `value` is implicitly broadcast to `pointer.shape` and typecast to `pointer.dtype.element_ty`.

    :param pointer: The memory location where the elements of `value` are stored
    :type pointer: `triton.PointerType`, or block of `dtype=triton.PointerType`
    :param value: The tensor of elements to be stored
    :type value: Block
    :param mask: If `mask[idx]` is false, do not store `value[idx]` at `pointer[idx]`
    :type mask: Block of triton.int1, optional
    :param boundary_check: tuple of integers, indicating the dimensions which should do the boundary check
    :type boundary_check: tuple of ints, optional
    :param cache_modifier: changes cache option in NVIDIA PTX
    :type cache_modifier: str, optional
    :param eviction_policy: changes eviction policy in NVIDIA PTX
    :type eviction_policy: str, optional
    """
    # `value` can be constexpr; lift it into a tensor first.
    value = _to_tensor(value, _builder)
    mask = _constexpr_to_value(mask)
    mask = _to_tensor(mask, _builder) if mask is not None else None
    cache_modifier = _constexpr_to_value(cache_modifier)
    eviction_policy = _constexpr_to_value(eviction_policy)
    return semantic.store(pointer, value, mask, boundary_check, cache_modifier, eviction_policy, _builder)
1671
+
1672
+
1673
@builtin
def make_block_ptr(base: tensor, shape, strides, offsets, block_shape, order, _builder=None):
    """
    Create a pointer to a block located inside a parent tensor.

    :param base: The base pointer to the parent tensor
    :param shape: The shape of the parent tensor
    :param strides: The strides of the parent tensor
    :param offsets: The offsets to the block
    :param block_shape: The shape of the block
    :param order: The order of the original data format
    """
    return semantic.make_block_ptr(base, shape, strides, offsets, block_shape, order, _builder)
1686
+
1687
+
1688
@_tensor_member_fn
@builtin
def advance(base, offsets, _builder=None):
    """
    Advance a block pointer by the given per-dimension offsets.

    :param base: the block pointer to advance
    :param offsets: the offsets to advance, a tuple by dimension
    """
    return semantic.advance(base, offsets, _builder)
1698
+
1699
+
1700
+ # -----------------------
1701
+ # Atomic Memory Operations
1702
+ # -----------------------
1703
+
1704
+
1705
+ def _add_atomic_docstr(name: str, has_cmp: bool = False) -> Callable[[T], T]:
1706
+
1707
+ def _decorator(func: T) -> T:
1708
+ docstr = f"""
1709
+ Performs an atomic {name} at the memory location specified by :code:`pointer`.
1710
+
1711
+ Return the data stored at :code:`pointer` before the atomic operation.
1712
+
1713
+ :param pointer: The memory locations to operate on
1714
+ :type pointer: Block of dtype=triton.PointerDType"""
1715
+ if has_cmp:
1716
+ docstr += """
1717
+ :param cmp: The values expected to be found in the atomic object
1718
+ :type cmp: Block of dtype=pointer.dtype.element_ty"""
1719
+ docstr += """
1720
+ :param val: The values with which to perform the atomic operation
1721
+ :type val: Block of dtype=pointer.dtype.element_ty
1722
+ :param sem: Memory semantics to use ("ACQUIRE_RELEASE" (default),
1723
+ "ACQUIRE", "RELEASE", or "RELAXED")
1724
+ :type sem: str
1725
+ :param scope: Scope of threads that observe synchronizing effect of the
1726
+ atomic operation ("GPU" (default), "CTA", or "SYSTEM")
1727
+ :type scope: str
1728
+ """
1729
+ func.__doc__ = docstr
1730
+ return func
1731
+
1732
+ return _decorator
1733
+
1734
+
1735
@_tensor_member_fn
@builtin
@_add_atomic_docstr("compare-and-swap", has_cmp=True)
def atomic_cas(pointer, cmp, val, sem=None, scope=None, _builder=None):
    # Unwrap constexpr knobs first (pure), then lift operands into tensors.
    sem = _constexpr_to_value(sem)
    scope = _constexpr_to_value(scope)
    cmp = _to_tensor(cmp, _builder)
    val = _to_tensor(val, _builder)
    return semantic.atomic_cas(pointer, cmp, val, sem, scope, _builder)
1744
+
1745
+
1746
@_tensor_member_fn
@builtin
@_add_atomic_docstr("exchange")
def atomic_xchg(pointer, val, mask=None, sem=None, scope=None, _builder=None):
    # Unwrap the constexpr arguments, then lower to the semantic layer.
    sem = _constexpr_to_value(sem)
    scope = _constexpr_to_value(scope)
    mask = _constexpr_to_value(mask)
    val = _to_tensor(val, _builder)
    return semantic.atomic_xchg(pointer, val, mask, sem, scope, _builder)
1755
+
1756
+
1757
@_tensor_member_fn
@builtin
@_add_atomic_docstr("add")
def atomic_add(pointer, val, mask=None, sem=None, scope=None, _builder=None):
    # Normalize constexpr wrappers before dispatching.
    sem = _constexpr_to_value(sem)
    scope = _constexpr_to_value(scope)
    mask = _constexpr_to_value(mask)
    val = _to_tensor(val, _builder)
    return semantic.atomic_add(pointer, val, mask, sem, scope, _builder)
1766
+
1767
+
1768
@_tensor_member_fn
@builtin
@_add_atomic_docstr("max")
def atomic_max(pointer, val, mask=None, sem=None, scope=None, _builder=None):
    # Unwrap compile-time wrappers, then hand off to the semantic layer.
    sem = _constexpr_to_value(sem)
    scope = _constexpr_to_value(scope)
    mask = _constexpr_to_value(mask)
    val = _to_tensor(val, _builder)
    return semantic.atomic_max(pointer, val, mask, sem, scope, _builder)
1777
+
1778
+
1779
@_tensor_member_fn
@builtin
@_add_atomic_docstr("min")
def atomic_min(pointer, val, mask=None, sem=None, scope=None, _builder=None):
    # Unwrap compile-time wrappers, then hand off to the semantic layer.
    sem = _constexpr_to_value(sem)
    scope = _constexpr_to_value(scope)
    mask = _constexpr_to_value(mask)
    val = _to_tensor(val, _builder)
    return semantic.atomic_min(pointer, val, mask, sem, scope, _builder)
1788
+
1789
+
1790
@_tensor_member_fn
@builtin
@_add_atomic_docstr("logical and")
def atomic_and(pointer, val, mask=None, sem=None, scope=None, _builder=None):
    # Normalize constexpr wrappers before dispatching.
    sem = _constexpr_to_value(sem)
    scope = _constexpr_to_value(scope)
    mask = _constexpr_to_value(mask)
    val = _to_tensor(val, _builder)
    return semantic.atomic_and(pointer, val, mask, sem, scope, _builder)
1799
+
1800
+
1801
@_tensor_member_fn
@builtin
@_add_atomic_docstr("logical or")
def atomic_or(pointer, val, mask=None, sem=None, scope=None, _builder=None):
    # Normalize constexpr wrappers before dispatching.
    sem = _constexpr_to_value(sem)
    scope = _constexpr_to_value(scope)
    mask = _constexpr_to_value(mask)
    val = _to_tensor(val, _builder)
    return semantic.atomic_or(pointer, val, mask, sem, scope, _builder)
1810
+
1811
+
1812
@_tensor_member_fn
@builtin
@_add_atomic_docstr("logical xor")
def atomic_xor(pointer, val, mask=None, sem=None, scope=None, _builder=None):
    # Normalize constexpr wrappers before dispatching.
    sem = _constexpr_to_value(sem)
    scope = _constexpr_to_value(scope)
    mask = _constexpr_to_value(mask)
    val = _to_tensor(val, _builder)
    return semantic.atomic_xor(pointer, val, mask, sem, scope, _builder)
1821
+
1822
+
1823
+ # -----------------------
1824
+ # Conditioning
1825
+ # -----------------------
1826
+
1827
+
1828
@builtin
def where(condition, x, y, _builder=None):
    """
    Selects elements from :code:`x` where :code:`condition` is True (nonzero)
    and from :code:`y` elsewhere.

    Note that :code:`x` and :code:`y` are always evaluated regardless of
    :code:`condition`; to avoid unintended memory operations, use the
    :code:`mask` arguments of `triton.load` and `triton.store` instead.

    Both :code:`x` and :code:`y` are broadcast to the shape of
    :code:`condition` and must share the same data type.

    :param condition: When True (nonzero), yield x, otherwise yield y.
    :type condition: Block of triton.bool
    :param x: values selected at indices where condition is True.
    :param y: values selected at indices where condition is False.
    """
    cond_t = _to_tensor(condition, _builder)
    x_t = _to_tensor(x, _builder)
    y_t = _to_tensor(y, _builder)
    return semantic.where(cond_t, x_t, y_t, _builder)
1849
+
1850
+
1851
+ # -----------------------
1852
+ # Math
1853
+ # -----------------------
1854
+
1855
+
1856
@builtin
def minimum(x, y, propagate_nan: constexpr = PropagateNan.NONE, _builder=None):
    """
    Computes the element-wise minimum of :code:`x` and :code:`y`.

    :param x: the first input tensor
    :type x: Block
    :param y: the second input tensor
    :type y: Block
    :param propagate_nan: whether to propagate NaN values.
    :type propagate_nan: tl.PropagateNan

    .. seealso:: :class:`tl.PropagateNan`
    """
    x = _to_tensor(x, _builder)
    y = _to_tensor(y, _builder)
    # bfloat16 operands are widened to float32 (no hardware min for bf16).
    x = _promote_bfloat16_to_float32(x, _builder=_builder)
    y = _promote_bfloat16_to_float32(y, _builder=_builder)
    nan_mode = _constexpr_to_value(propagate_nan)
    return semantic.minimum(x, y, nan_mode, _builder)
1876
+
1877
+
1878
@builtin
def maximum(x, y, propagate_nan: constexpr = PropagateNan.NONE, _builder=None):
    """
    Computes the element-wise maximum of :code:`x` and :code:`y`.

    :param x: the first input tensor
    :type x: Block
    :param y: the second input tensor
    :type y: Block
    :param propagate_nan: whether to propagate NaN values.
    :type propagate_nan: tl.PropagateNan

    .. seealso:: :class:`tl.PropagateNan`
    """
    x = _to_tensor(x, _builder)
    y = _to_tensor(y, _builder)
    # bfloat16 operands are widened to float32 (no hardware max for bf16).
    x = _promote_bfloat16_to_float32(x, _builder=_builder)
    y = _promote_bfloat16_to_float32(y, _builder=_builder)
    nan_mode = _constexpr_to_value(propagate_nan)
    return semantic.maximum(x, y, nan_mode, _builder)
1898
+
1899
+
1900
@builtin
def clamp(x, min, max, propagate_nan: constexpr = PropagateNan.NONE, _builder=None):
    """
    Clamps the input tensor :code:`x` to the range [min, max].
    If :code:`min` > :code:`max`, the behavior is undefined.

    :param x: the input tensor
    :type x: Block
    :param min: the lower bound for clamping
    :type min: Block
    :param max: the upper bound for clamping
    :type max: Block
    :param propagate_nan: whether to propagate NaN values; applies only to
        :code:`x`. If either :code:`min` or :code:`max` is NaN, the result is
        undefined.
    :type propagate_nan: tl.PropagateNan

    .. seealso:: :class:`tl.PropagateNan`
    """
    x = _to_tensor(x, _builder)
    min = _to_tensor(min, _builder)
    max = _to_tensor(max, _builder)
    # Widen bfloat16 operands to float32 (see _promote_bfloat16_to_float32).
    x = _promote_bfloat16_to_float32(x, _builder=_builder)
    min = _promote_bfloat16_to_float32(min, _builder=_builder)
    max = _promote_bfloat16_to_float32(max, _builder=_builder)
    nan_mode = _constexpr_to_value(propagate_nan)
    return semantic.clamp(x, min, max, nan_mode, _builder)
1928
+
1929
+
1930
+ # -----------------------
1931
+ # Reductions
1932
+ # -----------------------
1933
+
1934
+
1935
+ def _add_reduction_docstr(name: str, return_indices_arg: str = None, tie_break_arg: str = None) -> Callable[[T], T]:
1936
+
1937
+ def _decorator(func: T) -> T:
1938
+ docstr = """
1939
+ Returns the {name} of all elements in the :code:`input` tensor along the provided :code:`axis`
1940
+
1941
+ :param input: the input values
1942
+ :param axis: the dimension along which the reduction should be done
1943
+ :param keep_dims: if true, keep the reduced dimensions with length 1"""
1944
+ if return_indices_arg is not None:
1945
+ docstr += f"""
1946
+ :param {return_indices_arg}: if true, return index corresponding to the {name} value"""
1947
+ if tie_break_arg is not None:
1948
+ docstr += f"""
1949
+ :param {tie_break_arg}: if true, return the left-most indices in case of ties for values that aren't NaN"""
1950
+
1951
+ func.__doc__ = docstr.format(name=name)
1952
+ return func
1953
+
1954
+ return _decorator
1955
+
1956
+
1957
+ @contextmanager
1958
+ def _insertion_guard(builder):
1959
+ ip = builder.get_insertion_point()
1960
+ yield
1961
+ builder.restore_insertion_point(ip)
1962
+
1963
+
1964
@_tensor_member_fn
@builtin
def reduce(input, axis, combine_fn, keep_dims=False, _builder=None, _generator=None):
    """Applies the combine_fn to all elements in :code:`input` tensors along the provided :code:`axis`

    :param input: the input tensor, or tuple of tensors
    :param axis: the dimension along which the reduction should be done. If None, reduce all dimensions
    :param combine_fn: a function to combine two groups of scalar tensors (must be marked with @triton.jit)
    :param keep_dims: if true, keep the reduced dimensions with length 1

    """
    # Normalize the single-tensor case to the tuple form and unwrap the
    # single result on the way out.
    if isinstance(input, tensor):
        return reduce((input, ), axis, combine_fn, keep_dims=keep_dims, _builder=_builder, _generator=_generator)[0]

    def make_combine_region(reduce_op):
        # Emit the combine region of the reduce op: a block taking two groups
        # of scalar arguments (hence `in_scalar_tys * 2`) and returning one
        # combined group.
        in_scalar_tys = [t.type.scalar for t in input]
        prototype = function_type(in_scalar_tys, in_scalar_tys * 2)

        region = reduce_op.get_region(0)
        # Restore the builder's insertion point after emitting the region body.
        with _insertion_guard(_builder):
            param_types = [ty.to_ir(_builder) for ty in prototype.param_types]
            block = _builder.create_block_with_parent(region, param_types)
            args = [tensor(block.arg(i), ty) for i, ty in enumerate(prototype.param_types)]
            results = _generator.call_JitFunction(combine_fn, args, kwargs={})
            # combine_fn may return either a single tensor or a tuple of them.
            if isinstance(results, tensor):
                handles = [results.handle]
            else:
                handles = [r.handle for r in results]
            _builder.create_reduce_ret(*handles)

    def expand_ndims(t, ndims):
        # Prepend `ndims` unit dimensions to `t` (used when axis is None and
        # keep_dims asks for the full original rank back).
        for _ in builtins.range(ndims):
            t = expand_dims(t, 0, _builder=_builder)
        return t

    axis = _constexpr_to_value(axis)
    keep_dims = _constexpr_to_value(keep_dims)
    if axis is not None:
        # Map negative axes onto the valid range.
        axis = _wrap_axis(axis, len(input[0].shape))
    ret = semantic.reduction(input, axis, make_combine_region, _builder)
    if keep_dims:
        # Re-insert the reduced dimension(s) with length 1.
        if axis is not None:
            ret = tuple(expand_dims(t, axis, _builder=_builder) for t in ret)
        else:
            ret = tuple(expand_ndims(t, len(input[0].shape)) for t in ret)
    return ret
2010
+
2011
+
2012
@builtin
def _promote_bfloat16_to_float32(t, _builder=None):
    """Cast `t` to float32 when its scalar type is bfloat16; otherwise return it unchanged."""
    # hardware doesn't support FMAX, FMIN, CMP for bfloat16
    if t.type.scalar is bfloat16:
        return t.to(float32, _builder=_builder)
    return t
2020
+
2021
+
2022
@builtin
def _reduce_with_indices(input, axis, combine_fn, keep_dims=False, _builder=None, _generator=None):
    """Reduce `input` along `axis` with `combine_fn`, also returning the
    index of the element selected by the reduction."""
    axis = _constexpr_to_value(axis)
    axis_len = input.shape[axis]
    idx = arange(0, axis_len, _builder=_builder)

    if len(input.shape) > 1:
        # Expand the 1-D index tensor over every non-reduced axis, then
        # broadcast it to the input's full shape.
        expand_axes = [constexpr(d) for d in builtins.range(len(input.shape))]
        del expand_axes[axis]
        idx = expand_dims(idx, expand_axes, _builder=_builder)
        idx = broadcast_to(idx, input.shape, _builder=_builder)

    values, indices = reduce((input, idx), axis, combine_fn, keep_dims=keep_dims, _builder=_builder,
                             _generator=_generator)
    return values, indices
2038
+
2039
+
2040
+ # -----------------------
2041
+ # Scans
2042
+ # -----------------------
2043
+
2044
+
2045
+ def _add_scan_docstr(name: str) -> Callable[[T], T]:
2046
+
2047
+ def _decorator(func: T) -> T:
2048
+ docstr = """
2049
+ Returns the {name} of all elements in the :code:`input` tensor along the provided :code:`axis`
2050
+
2051
+ :param input: the input values
2052
+ :param axis: the dimension along which the scan should be done"""
2053
+ func.__doc__ = docstr.format(name=name)
2054
+ return func
2055
+
2056
+ return _decorator
2057
+
2058
+
2059
@_tensor_member_fn
@builtin
def associative_scan(input, axis, combine_fn, reverse=False, _builder=None, _generator=None):
    """Applies the combine_fn to each element with a carry in :code:`input` tensors along the provided :code:`axis` and updates the carry

    :param input: the input tensor, or tuple of tensors
    :param axis: the dimension along which the reduction should be done
    :param combine_fn: a function to combine two groups of scalar tensors (must be marked with @triton.jit)
    :param reverse: apply the associative scan in the reverse direction along axis.

    """
    # Normalize the single-tensor case to the tuple form and unwrap the
    # single result on the way out.
    if isinstance(input, tensor):
        return associative_scan((input, ), axis, combine_fn, reverse, _builder=_builder, _generator=_generator)[0]

    def make_combine_region(scan_op):
        # Emit the combine region of the scan op: a block taking two groups
        # of scalar arguments (hence `in_scalar_tys * 2`) and returning one
        # combined group.
        in_scalar_tys = [t.type.scalar for t in input]
        prototype = function_type(in_scalar_tys, in_scalar_tys * 2)

        region = scan_op.get_region(0)
        # Restore the builder's insertion point after emitting the region body.
        with _insertion_guard(_builder):
            param_types = [ty.to_ir(_builder) for ty in prototype.param_types]
            block = _builder.create_block_with_parent(region, param_types)
            args = [tensor(block.arg(i), ty) for i, ty in enumerate(prototype.param_types)]
            results = _generator.call_JitFunction(combine_fn, args, kwargs={})
            # combine_fn may return either a single tensor or a tuple of them.
            if isinstance(results, tensor):
                handles = [results.handle]
            else:
                handles = [r.handle for r in results]
            _builder.create_scan_ret(*handles)

    axis = _constexpr_to_value(axis)
    if axis is not None:
        # Map negative axes onto the valid range.
        axis = _wrap_axis(axis, len(input[0].shape))
    return semantic.associative_scan(input, axis, make_combine_region, reverse, _builder)
2093
+
2094
+
2095
@_tensor_member_fn
@builtin
def histogram(input, num_bins, _builder=None, _generator=None):
    """Computes a histogram of the input tensor with `num_bins` bins; each bin
    has a width of 1 and the first bin starts at 0.

    :param input: the input tensor
    :param num_bins: number of histogram bins
    """
    bins = _constexpr_to_value(num_bins)
    return semantic.histogram(input, bins, _builder)
2106
+
2107
+
2108
+ # -----------------------
2109
+ # Compiler Hint Ops
2110
+ # -----------------------
2111
+
2112
+
2113
@builtin
def debug_barrier(_builder=None):
    '''
    Insert a barrier synchronizing every thread in the current block.
    '''
    return semantic.debug_barrier(_builder)
2119
+
2120
+
2121
@builtin
def multiple_of(input, values, _builder=None):
    """
    Let the compiler know that the values in :code:`input` are all multiples of :code:`values`.

    :param input: the tensor the hint applies to
    :param values: one constexpr int per dimension (a single constexpr is
        treated as a one-element list)
    :raises TypeError: if any element of :code:`values` is not a constexpr int
    """
    if isinstance(values, constexpr):
        values = [values]
    for i, d in enumerate(values):
        if not isinstance(d, constexpr):
            raise TypeError(f"values element {i} must have type `constexpr`")
        if not isinstance(d.value, int):
            # Fixed: the message was missing the closing backtick.
            raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]`")
    values = [x.value for x in values]
    return semantic.multiple_of(input, values)
2135
+
2136
+
2137
@builtin
def max_contiguous(input, values, _builder=None):
    """
    Let the compiler know that the `values` first values in :code:`input` are contiguous.

    :param input: the tensor the hint applies to
    :param values: one constexpr int per dimension (a single constexpr is
        treated as a one-element list)
    :raises TypeError: if any element of :code:`values` is not a constexpr int
    """
    if isinstance(values, constexpr):
        values = [values]
    for i, d in enumerate(values):
        if not isinstance(d, constexpr):
            raise TypeError(f"values element {i} must have type `constexpr`")
        if not isinstance(d.value, int):
            # Fixed: the message was missing the closing backtick.
            raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]`")
    values = [x.value for x in values]
    return semantic.max_contiguous(input, values)
2151
+
2152
+
2153
@builtin
def max_constancy(input, values, _builder=None):
    """
    Let the compiler know that the `values` first values in :code:`input` are constant.

    e.g. if :code:`values` is [4], then each group of 4 values in :code:`input` should all be equal,
    for example [0, 0, 0, 0, 1, 1, 1, 1].

    :param input: the tensor the hint applies to
    :param values: one constexpr int per dimension (a single constexpr is
        treated as a one-element list)
    :raises TypeError: if any element of :code:`values` is not a constexpr int
    """
    if isinstance(values, constexpr):
        values = [values]
    for i, d in enumerate(values):
        if not isinstance(d, constexpr):
            raise TypeError(f"values element {i} must have type `constexpr`")
        if not isinstance(d.value, int):
            # Fixed: the message was missing the closing backtick.
            raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]`")
    values = [x.value for x in values]
    return semantic.max_constancy(input, values)
2170
+
2171
+
2172
+ # -----------------------
2173
+ # Debugging functions
2174
+ # -----------------------
2175
+
2176
+
2177
@builtin
def static_print(*values, sep: str = " ", end: str = "\n", file=None, flush=False, _builder=None):
    '''
    Print the given values at compile time. The parameters mirror the Python
    builtin :code:`print`.

    NOTE: This is not what a plain :code:`print` call inside a kernel does —
    that maps to :code:`device_print`, which has special argument requirements.

    .. highlight:: python
    .. code-block:: python

        tl.static_print(f"{BLOCK_SIZE=}")
    '''
    # Marker only: the compiler intercepts this call; nothing runs here.
    pass
2191
+
2192
+
2193
@builtin
def static_assert(cond, msg="", _builder=None):
    '''
    Assert the condition at compile time. Unlike :code:`device_assert`, this
    does not require the :code:`TRITON_DEBUG` environment variable to be set.

    .. highlight:: python
    .. code-block:: python

        tl.static_assert(BLOCK_SIZE == 1024)
    '''
    # Marker only: the compiler intercepts this call; nothing runs here.
    pass
2205
+
2206
+
2207
@builtin
def device_print(prefix, *args, hex=False, _builder=None):
    '''
    Print values at runtime from the device. String formatting does not work
    for runtime values, so pass the values to print as arguments: the first
    argument must be a string literal, and every following argument must be a
    scalar or a tensor.

    A plain Python :code:`print` inside a kernel maps to this function, with
    the same argument requirements (not the usual :code:`print` ones).

    .. highlight:: python
    .. code-block:: python

        tl.device_print("pid", pid)
        print("pid", pid)

    On CUDA, printfs are streamed through a buffer of limited size (on one host,
    we measured the default as 6912 KiB, but this may not be consistent across
    GPUs and CUDA versions). If you notice some printfs are being dropped, you
    can increase the buffer size by calling

    .. highlight:: python
    .. code-block:: python

        triton.runtime.driver.active.utils.set_printf_fifo_size(size_bytes)

    CUDA may raise an error if you try to change this value after running a
    kernel that uses printfs. The value set here may only affect the current
    device (so if you have multiple GPUs, you'd need to call it multiple times).

    :param prefix: a prefix to print before the values. This is required to be a string literal.
    :param args: the values to print. They can be any tensor or scalar.
    :param hex: print all values as hex instead of decimal
    '''
    import string
    prefix = _constexpr_to_value(prefix)
    assert isinstance(prefix, str), f"{prefix} is not string"
    # The prefix must only contain printable ASCII characters.
    b_ascii = all(ch in string.printable for ch in prefix)
    assert b_ascii, f"{prefix} is not an ascii string"
    new_args = [_to_tensor(arg, _builder) for arg in args]
    return semantic.device_print(prefix, new_args, hex, _builder)
2254
+
2255
+
2256
@builtin
def device_assert(cond, msg="", _builder=None):
    '''
    Assert the condition at runtime from the device. Requires that the environment variable :code:`TRITON_DEBUG`
    is set to a value besides :code:`0` in order for this to have any effect.

    Using the Python :code:`assert` statement is the same as calling this function, except that the second argument
    must be provided and must be a string, e.g. :code:`assert pid == 0, "pid != 0"`. The environment variable must
    be set for this :code:`assert` statement to have any effect.

    .. highlight:: python
    .. code-block:: python

        tl.device_assert(pid == 0)
        assert pid == 0, f"pid != 0"

    :param cond: the condition to assert. This is required to be a boolean tensor.
    :param msg: the message to print if the assertion fails. This is required to be a string literal.
    '''
    msg = _constexpr_to_value(msg)
    import inspect
    frame = inspect.currentframe()
    module = inspect.getmodule(frame)
    # The triton function module doesn't have the name attribute.
    # We use this trick to find the caller: walk up the stack until we hit a
    # frame whose module lacks __name__ (the synthesized triton frame).
    # NOTE(review): this frame walk presumably relies on how the JIT generates
    # code — confirm before changing.
    while hasattr(module, "__name__"):
        frame = frame.f_back
        module = inspect.getmodule(frame)
    # Fallbacks in case the caller frame cannot be located.
    lineno = 0
    func_name = 'unknown'
    file_name = 'unknown'
    if frame is not None and frame.f_back is not None:
        func_name = frame.f_code.co_name
        file_name = frame.f_back.f_code.co_filename
        # TODO: The line number currently indicates the line
        # where the triton function is called but not where the
        # device_assert is called. Need to enhance this.
        lineno = frame.f_back.f_lineno
    return semantic.device_assert(_to_tensor(cond, _builder), msg, file_name, func_name, lineno, _builder)
2295
+
2296
+
2297
@builtin
def inline_asm_elementwise(asm: str, constraints: str, args: Sequence, dtype: Union[dtype, Sequence[dtype]],
                           is_pure: bool, pack: int, _builder=None):
    '''
    Execute inline assembly over a tensor. Essentially, this is :code:`map`
    where the function is inline assembly.

    The input tensors :code:`args` are implicitly broadcasted to the same shape.

    :code:`dtype` can be a tuple of types, in which case the output is a
    tuple of tensors.

    Each invocation of the inline asm processes :code:`pack` elements at a
    time. Exactly which set of inputs a block receives is unspecified.
    Input elements of size less than 4 bytes are packed into 4-byte
    registers.

    This op does not support empty :code:`dtype` -- the inline asm must
    return at least one tensor, even if you don't need it. You can work
    around this by returning a dummy tensor of arbitrary type; it shouldn't
    cost you anything if you don't use it.

    Example using
    [PTX](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html)
    assembly:

    .. highlight:: python
    .. code-block:: python

        @triton.jit
        def kernel(A, B, C, D, BLOCK: tl.constexpr):
            a = tl.load(A + tl.arange(0, BLOCK)) # uint8 tensor
            b = tl.load(B + tl.arange(0, BLOCK)) # float32 tensor

            # For each (a,b) in zip(a,b), perform the following:
            # - Let ai be `a` converted to int32.
            # - Let af be `a` converted to float.
            # - Let m be the max of ai and b.
            # - Return ai and mi.
            # Do the above 4 elements at a time.
            (c, d) = tl.inline_asm_elementwise(
                asm="""
                {
                    // Unpack `a` into `ai`.
                    .reg .b8 tmp<4>;
                    mov.b32 {tmp0, tmp1, tmp2, tmp3}, $8;
                    cvt.u32.u8 $0, tmp0;
                    cvt.u32.u8 $1, tmp1;
                    cvt.u32.u8 $2, tmp2;
                    cvt.u32.u8 $3, tmp3;
                }
                // Convert `ai` to float.
                cvt.rn.f32.s32 $4, $0;
                cvt.rn.f32.s32 $5, $1;
                cvt.rn.f32.s32 $6, $2;
                cvt.rn.f32.s32 $7, $3;
                // Take max of `ai` and `b`.
                max.f32 $4, $4, $9;
                max.f32 $5, $5, $10;
                max.f32 $6, $6, $11;
                max.f32 $7, $7, $12;
                """,
                constraints=(
                    # 8 output registers, namely
                    # $0=ai0, $1=ai1, $2=ai2, $3=ai3,
                    # $4=m0, $5=m1, $6=m2, $7=m3.
                    "=r,=r,=r,=r,=r,=r,=r,=r,"
                    # 5 input registers, namely
                    # $8=ai,
                    # $9=b0, $10=b1, $11=b2, $12=b3.
                    # The four elements from `a` are all packed into one register.
                    "r,r,r,r,r"),
                args=[a, b],
                dtype=(tl.int32, tl.float32),
                is_pure=True,
                pack=4,
            )
            tl.store(C + tl.arange(0, BLOCK), c)
            tl.store(D + tl.arange(0, BLOCK), d)

    :param asm: assembly to run. Must match target's assembly format.
    :param constraints: asm constraints in
        [LLVM format](https://llvm.org/docs/LangRef.html#inline-asm-constraint-string)
    :param args: the input tensors, whose values are passed to the asm block
    :param dtype: the element type(s) of the returned tensor(s)
    :param is_pure: if true, the compiler assumes the asm block has no side-effects
    :param pack: the number of elements to be processed by one instance of inline assembly
    :param _builder: the builder
    :return: one tensor or a tuple of tensors of the given dtypes
    '''
    # Resolve any constexpr-wrapped arguments to plain Python values.
    asm = _constexpr_to_value(asm)
    constraints = _constexpr_to_value(constraints)
    pack = _constexpr_to_value(pack)
    is_pure = _constexpr_to_value(is_pure)

    # Normalize `dtype` to a sequence, remembering whether the caller asked
    # for a single tensor or a tuple of tensors.
    has_multiple_outputs = True
    try:
        iter(dtype)  # type: ignore
    except TypeError:
        has_multiple_outputs = False
        dtype = (dtype, )  # type: ignore

    dtype = typing.cast(Sequence[_DtypeClass], dtype)

    result_types = dtype
    tensors = [_to_tensor(arg, _builder) for arg in args]
    if tensors:

        def type_check(lhs, rhs):
            # No arithmetic promotion; pointers are allowed on either side.
            return semantic.binary_op_type_checking_impl(lhs, rhs, builder=_builder, arithmetic_check=False,
                                                         allow_lhs_ptr=True, allow_rhs_ptr=True)

        # First pass: fold all arguments into a single broadcast reference.
        broadcast_ref = tensors[0]
        for item in tensors:
            _, broadcast_ref = type_check(item, broadcast_ref)
        if broadcast_ref.shape:
            # Second pass: broadcast every argument to the common shape.
            for i, item in enumerate(tensors):
                tensors[i], _ = type_check(item, broadcast_ref)
            result_types = [block_type(dt, broadcast_ref.shape) for dt in dtype]
    handles = [t.handle for t in tensors]
    call = _builder.create_inline_asm(asm, constraints, handles, [ty.to_ir(_builder) for ty in result_types],
                                      is_pure, pack)

    if has_multiple_outputs:
        return tuple(tensor(call.get_result(i), ty) for i, ty in enumerate(result_types))
    return tensor(call.get_result(0), result_types[0])
2426
+
2427
+
2428
+ # -----------------------
2429
+ # Iterators
2430
+ # -----------------------
2431
+
2432
+
2433
class static_range:
    """
    Compile-time counterpart of Python's :code:`range` whose loop is fully
    unrolled by the compiler.

    .. highlight:: python
    .. code-block:: python

        @triton.jit
        def kernel(...):
            for i in tl.static_range(10):
                ...
    :note: This is a special iterator used to implement similar semantics to Python's :code:`range` in the context of
        :code:`triton.jit` functions. In addition, it also guides the compiler to unroll the loop aggressively.
    :param arg1: the start value.
    :param arg2: the end value.
    :param step: the step value.
    """

    def __init__(self, arg1, arg2=None, step=None):
        # Bounds and step must be compile-time constants (constexpr) so the
        # compiler can unroll the loop.
        assert isinstance(arg1, constexpr)
        if step is None:
            self.step = constexpr(1)
        else:
            assert isinstance(step, constexpr)
            self.step = step
        if arg2 is None:
            # Single-argument form: static_range(end) starts at 0,
            # matching Python's range(end).
            self.start = constexpr(0)
            self.end = arg1
        else:
            assert isinstance(arg2, constexpr)
            self.start = arg1
            self.end = arg2

    def __iter__(self):
        # Only meaningful inside compiled code; plain-Python iteration is an error.
        raise RuntimeError("static_range can only be used in @triton.jit'd functions")

    def __next__(self):
        raise RuntimeError("static_range can only be used in @triton.jit'd functions")
2471
+
2472
+
2473
class range:
    """
    Compile-time counterpart of Python's :code:`range` that also carries
    extra hints for the compiler.

    .. highlight:: python
    .. code-block:: python

        @triton.jit
        def kernel(...):
            for i in tl.range(10, num_stages=3):
                ...
    :note: This is a special iterator used to implement similar semantics to Python's :code:`range` in the context of
        :code:`triton.jit` functions. In addition, it allows user to pass extra attributes to the compiler.
    :param arg1: the start value.
    :param arg2: the end value.
    :param step: the step value.
    :param num_stages: pipeline the loop into this many stages (so there are
        :code:`num_stages` iterations of the loop in flight at once).

        Note this is subtly different than passing :code:`num_stages` as a
        kernel argument.  The kernel argument only pipelines loads that feed
        into :code:`dot` operations, while this attribute tries to pipeline most
        (though not all) loads in this loop.
    """

    def __init__(self, arg1, arg2=None, step=None, num_stages=None):
        # Mirrors Python's range(start[, end[, step]]) argument handling.
        if step is None:
            self.step = constexpr(1)
        else:
            self.step = step
        if arg2 is None:
            # Single-argument form: range(end) starts at 0.
            self.start = constexpr(0)
            self.end = arg1
        else:
            self.start = arg1
            self.end = arg2
        self.num_stages = num_stages

    def __iter__(self):
        # Only meaningful inside compiled code; plain-Python iteration is an error.
        raise RuntimeError("tl.range can only be used in @triton.jit'd functions")

    def __next__(self):
        raise RuntimeError("tl.range can only be used in @triton.jit'd functions")
2516
+
2517
+
2518
+ # -----------------------
2519
+ # Extern functions
2520
+ # -----------------------
2521
+
2522
+
2523
def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, ret_shape: tuple,
             is_pure: bool, _builder=None):
    '''
    Dispatch a function to a library.

    :param func: the builder function used to emit the call
    :param lib_name: the name of the library
    :param lib_path: the path of the library
    :param args: the arguments of the function
    :param arg_type_symbol_dict: mapping from tuples of argument types to
        :code:`(symbol, return_type)` pairs
    :param ret_shape: the shape of the return value
    :param is_pure: whether the function has no side-effects
    :param _builder: the builder
    :return: the return value of the function
    :raises ValueError: if :code:`arg_type_symbol_dict` is empty, or if the
        number or types of :code:`args` do not match any supported signature
    '''
    if len(arg_type_symbol_dict) == 0:
        raise ValueError("arg_type_symbol_dict is empty")

    # All signatures in the dict have the same arity; use the first as reference.
    num_args = len(list(arg_type_symbol_dict.keys())[0])
    if len(args) != num_args:
        # Bug fix: the expected/actual values were swapped, and the two
        # f-string fragments were joined without a separating space.
        raise ValueError(f"length of input args does not match. "
                         f"Expect {num_args}, got {len(args)}")

    arg_types = []
    arg_list = []
    for arg in args:
        if isinstance(arg, tensor):
            # Tensors dispatch on their dtype and are passed by IR handle.
            arg_types.append(arg.dtype)
            arg_list.append(arg.handle)
        else:
            # Plain Python values dispatch on their Python type.
            arg_types.append(type(arg))
            arg_list.append(arg)
    arg_types = tuple(arg_types)

    if arg_types not in arg_type_symbol_dict:
        raise ValueError(f"input arg type does not match. "
                         f"Expect one of {arg_type_symbol_dict.keys()}, got {arg_types}")
    symbol = arg_type_symbol_dict[arg_types][0]
    ret_type = arg_type_symbol_dict[arg_types][1]
    if ret_shape:
        # Lift the scalar return type to a block type of the broadcast shape.
        ret_type = block_type(ret_type, ret_shape)
    return tensor(func(lib_name, lib_path, symbol, arg_list, ret_type.to_ir(_builder), is_pure), ret_type)
2564
+
2565
+
2566
@builtin
def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, is_pure: bool,
                       _builder=None):
    '''
    Dispatch an elementwise function to a library
    :param lib_name: the name of the library
    :param lib_path: the path of the library
    :param args: the arguments of the function
    :param arg_type_symbol_dict: the type of the arguments
    :param is_pure: whether the function is pure
    :param _builder: the builder
    :return: the return value of the function
    '''
    # Promote every argument to a tensor; the original `args` list is untouched.
    tensors = [_to_tensor(arg, _builder) for arg in args]
    arg_types = tuple(t.dtype for t in tensors)
    all_scalar = all(not t.type.is_block() for t in tensors)
    ret_shape = None
    if len(arg_types) > 0:
        # Only run arithmetic checks when this exact signature is not
        # directly supported by the library.
        arithmetic_check = arg_types not in arg_type_symbol_dict
        # First pass: fold all arguments into a single broadcast reference.
        broadcast_ref = tensors[0]
        for item in tensors:
            _, broadcast_ref = semantic.binary_op_type_checking_impl(item, broadcast_ref, _builder,
                                                                     arithmetic_check=arithmetic_check)
        # Second pass: broadcast every argument to the common shape.
        for i, item in enumerate(tensors):
            tensors[i], _ = semantic.binary_op_type_checking_impl(item, broadcast_ref, _builder,
                                                                  arithmetic_check=arithmetic_check)
        if not all_scalar:
            ret_shape = broadcast_ref.shape
    return dispatch(_builder.create_extern_elementwise, lib_name, lib_path, tensors, arg_type_symbol_dict,
                    ret_shape, is_pure, _builder)
2607
+
2608
+
2609
def binary_op_type_legalization(lhs, rhs, builder):
    '''
    Convert both operands of a binary operation to a single common type.

    Thin wrapper around :code:`semantic.binary_op_type_checking_impl` with
    its default checking options.

    :param lhs: the left operand
    :param rhs: the right operand
    :param builder: the builder
    :return: the :code:`(lhs, rhs)` pair converted to the common type
    '''
    return semantic.binary_op_type_checking_impl(lhs, rhs, builder)
2617
+
2618
+
2619
def extern(fn):
    """A decorator for external functions.

    Currently equivalent to wrapping ``fn`` with :func:`builtin`.
    """
    return builtin(fn)