triton-windows 3.5.1.post21__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (217)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +82 -0
  3. triton/_filecheck.py +97 -0
  4. triton/_internal_testing.py +255 -0
  5. triton/_utils.py +126 -0
  6. triton/backends/__init__.py +47 -0
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +461 -0
  9. triton/backends/amd/driver.c +283 -0
  10. triton/backends/amd/driver.py +724 -0
  11. triton/backends/amd/lib/asanrtl.bc +0 -0
  12. triton/backends/amd/lib/ockl.bc +0 -0
  13. triton/backends/amd/lib/ocml.bc +0 -0
  14. triton/backends/compiler.py +90 -0
  15. triton/backends/driver.py +66 -0
  16. triton/backends/nvidia/__init__.py +0 -0
  17. triton/backends/nvidia/bin/ptxas.exe +0 -0
  18. triton/backends/nvidia/compiler.py +533 -0
  19. triton/backends/nvidia/driver.c +517 -0
  20. triton/backends/nvidia/driver.py +799 -0
  21. triton/backends/nvidia/include/cuda.h +26280 -0
  22. triton/backends/nvidia/lib/libdevice.10.bc +0 -0
  23. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  24. triton/compiler/__init__.py +7 -0
  25. triton/compiler/code_generator.py +1614 -0
  26. triton/compiler/compiler.py +509 -0
  27. triton/compiler/errors.py +51 -0
  28. triton/compiler/make_launcher.py +0 -0
  29. triton/errors.py +5 -0
  30. triton/experimental/__init__.py +0 -0
  31. triton/experimental/gluon/__init__.py +5 -0
  32. triton/experimental/gluon/_compiler.py +0 -0
  33. triton/experimental/gluon/_runtime.py +102 -0
  34. triton/experimental/gluon/language/__init__.py +119 -0
  35. triton/experimental/gluon/language/_core.py +490 -0
  36. triton/experimental/gluon/language/_layouts.py +583 -0
  37. triton/experimental/gluon/language/_math.py +20 -0
  38. triton/experimental/gluon/language/_semantic.py +380 -0
  39. triton/experimental/gluon/language/_standard.py +80 -0
  40. triton/experimental/gluon/language/amd/__init__.py +4 -0
  41. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  42. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  43. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  44. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  45. triton/experimental/gluon/language/extra/__init__.py +3 -0
  46. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  47. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  48. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  49. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  50. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
  51. triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
  52. triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
  53. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
  54. triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
  55. triton/experimental/gluon/nvidia/__init__.py +4 -0
  56. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  57. triton/experimental/gluon/nvidia/hopper.py +45 -0
  58. triton/knobs.py +546 -0
  59. triton/language/__init__.py +342 -0
  60. triton/language/core.py +3405 -0
  61. triton/language/extra/__init__.py +26 -0
  62. triton/language/extra/cuda/__init__.py +16 -0
  63. triton/language/extra/cuda/gdc.py +42 -0
  64. triton/language/extra/cuda/libdevice.py +1629 -0
  65. triton/language/extra/cuda/utils.py +109 -0
  66. triton/language/extra/hip/__init__.py +5 -0
  67. triton/language/extra/hip/libdevice.py +491 -0
  68. triton/language/extra/hip/utils.py +35 -0
  69. triton/language/extra/libdevice.py +790 -0
  70. triton/language/math.py +249 -0
  71. triton/language/random.py +218 -0
  72. triton/language/semantic.py +1939 -0
  73. triton/language/standard.py +534 -0
  74. triton/language/target_info.py +54 -0
  75. triton/runtime/__init__.py +23 -0
  76. triton/runtime/_allocation.py +44 -0
  77. triton/runtime/_async_compile.py +55 -0
  78. triton/runtime/autotuner.py +476 -0
  79. triton/runtime/build.py +168 -0
  80. triton/runtime/cache.py +317 -0
  81. triton/runtime/driver.py +38 -0
  82. triton/runtime/errors.py +36 -0
  83. triton/runtime/interpreter.py +1414 -0
  84. triton/runtime/jit.py +1107 -0
  85. triton/runtime/tcc/include/_mingw.h +168 -0
  86. triton/runtime/tcc/include/assert.h +62 -0
  87. triton/runtime/tcc/include/conio.h +409 -0
  88. triton/runtime/tcc/include/ctype.h +281 -0
  89. triton/runtime/tcc/include/dir.h +31 -0
  90. triton/runtime/tcc/include/direct.h +68 -0
  91. triton/runtime/tcc/include/dirent.h +135 -0
  92. triton/runtime/tcc/include/dos.h +55 -0
  93. triton/runtime/tcc/include/errno.h +75 -0
  94. triton/runtime/tcc/include/excpt.h +123 -0
  95. triton/runtime/tcc/include/fcntl.h +52 -0
  96. triton/runtime/tcc/include/fenv.h +108 -0
  97. triton/runtime/tcc/include/float.h +75 -0
  98. triton/runtime/tcc/include/inttypes.h +297 -0
  99. triton/runtime/tcc/include/io.h +418 -0
  100. triton/runtime/tcc/include/iso646.h +36 -0
  101. triton/runtime/tcc/include/limits.h +116 -0
  102. triton/runtime/tcc/include/locale.h +91 -0
  103. triton/runtime/tcc/include/malloc.h +181 -0
  104. triton/runtime/tcc/include/math.h +497 -0
  105. triton/runtime/tcc/include/mem.h +13 -0
  106. triton/runtime/tcc/include/memory.h +40 -0
  107. triton/runtime/tcc/include/process.h +176 -0
  108. triton/runtime/tcc/include/sec_api/conio_s.h +42 -0
  109. triton/runtime/tcc/include/sec_api/crtdbg_s.h +19 -0
  110. triton/runtime/tcc/include/sec_api/io_s.h +33 -0
  111. triton/runtime/tcc/include/sec_api/mbstring_s.h +52 -0
  112. triton/runtime/tcc/include/sec_api/search_s.h +25 -0
  113. triton/runtime/tcc/include/sec_api/stdio_s.h +145 -0
  114. triton/runtime/tcc/include/sec_api/stdlib_s.h +67 -0
  115. triton/runtime/tcc/include/sec_api/stralign_s.h +30 -0
  116. triton/runtime/tcc/include/sec_api/string_s.h +41 -0
  117. triton/runtime/tcc/include/sec_api/sys/timeb_s.h +34 -0
  118. triton/runtime/tcc/include/sec_api/tchar_s.h +266 -0
  119. triton/runtime/tcc/include/sec_api/time_s.h +61 -0
  120. triton/runtime/tcc/include/sec_api/wchar_s.h +128 -0
  121. triton/runtime/tcc/include/setjmp.h +160 -0
  122. triton/runtime/tcc/include/share.h +28 -0
  123. triton/runtime/tcc/include/signal.h +63 -0
  124. triton/runtime/tcc/include/stdalign.h +16 -0
  125. triton/runtime/tcc/include/stdarg.h +14 -0
  126. triton/runtime/tcc/include/stdatomic.h +171 -0
  127. triton/runtime/tcc/include/stdbool.h +11 -0
  128. triton/runtime/tcc/include/stddef.h +42 -0
  129. triton/runtime/tcc/include/stdint.h +212 -0
  130. triton/runtime/tcc/include/stdio.h +429 -0
  131. triton/runtime/tcc/include/stdlib.h +591 -0
  132. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  133. triton/runtime/tcc/include/string.h +164 -0
  134. triton/runtime/tcc/include/sys/fcntl.h +13 -0
  135. triton/runtime/tcc/include/sys/file.h +14 -0
  136. triton/runtime/tcc/include/sys/locking.h +30 -0
  137. triton/runtime/tcc/include/sys/stat.h +290 -0
  138. triton/runtime/tcc/include/sys/time.h +69 -0
  139. triton/runtime/tcc/include/sys/timeb.h +133 -0
  140. triton/runtime/tcc/include/sys/types.h +123 -0
  141. triton/runtime/tcc/include/sys/unistd.h +14 -0
  142. triton/runtime/tcc/include/sys/utime.h +146 -0
  143. triton/runtime/tcc/include/tcc/tcc_libm.h +618 -0
  144. triton/runtime/tcc/include/tccdefs.h +342 -0
  145. triton/runtime/tcc/include/tcclib.h +80 -0
  146. triton/runtime/tcc/include/tchar.h +1102 -0
  147. triton/runtime/tcc/include/tgmath.h +89 -0
  148. triton/runtime/tcc/include/time.h +287 -0
  149. triton/runtime/tcc/include/uchar.h +33 -0
  150. triton/runtime/tcc/include/unistd.h +1 -0
  151. triton/runtime/tcc/include/vadefs.h +11 -0
  152. triton/runtime/tcc/include/values.h +4 -0
  153. triton/runtime/tcc/include/varargs.h +12 -0
  154. triton/runtime/tcc/include/wchar.h +873 -0
  155. triton/runtime/tcc/include/wctype.h +172 -0
  156. triton/runtime/tcc/include/winapi/basetsd.h +149 -0
  157. triton/runtime/tcc/include/winapi/basetyps.h +85 -0
  158. triton/runtime/tcc/include/winapi/guiddef.h +156 -0
  159. triton/runtime/tcc/include/winapi/poppack.h +8 -0
  160. triton/runtime/tcc/include/winapi/pshpack1.h +8 -0
  161. triton/runtime/tcc/include/winapi/pshpack2.h +8 -0
  162. triton/runtime/tcc/include/winapi/pshpack4.h +8 -0
  163. triton/runtime/tcc/include/winapi/pshpack8.h +8 -0
  164. triton/runtime/tcc/include/winapi/qos.h +72 -0
  165. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  166. triton/runtime/tcc/include/winapi/winbase.h +2958 -0
  167. triton/runtime/tcc/include/winapi/wincon.h +309 -0
  168. triton/runtime/tcc/include/winapi/windef.h +293 -0
  169. triton/runtime/tcc/include/winapi/windows.h +127 -0
  170. triton/runtime/tcc/include/winapi/winerror.h +3166 -0
  171. triton/runtime/tcc/include/winapi/wingdi.h +4080 -0
  172. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  173. triton/runtime/tcc/include/winapi/winnt.h +5837 -0
  174. triton/runtime/tcc/include/winapi/winreg.h +272 -0
  175. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  176. triton/runtime/tcc/include/winapi/winuser.h +5651 -0
  177. triton/runtime/tcc/include/winapi/winver.h +160 -0
  178. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  179. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  180. triton/runtime/tcc/lib/cuda.def +697 -0
  181. triton/runtime/tcc/lib/gdi32.def +337 -0
  182. triton/runtime/tcc/lib/kernel32.def +770 -0
  183. triton/runtime/tcc/lib/libtcc1.a +0 -0
  184. triton/runtime/tcc/lib/msvcrt.def +1399 -0
  185. triton/runtime/tcc/lib/python3.def +810 -0
  186. triton/runtime/tcc/lib/python310.def +1610 -0
  187. triton/runtime/tcc/lib/python311.def +1633 -0
  188. triton/runtime/tcc/lib/python312.def +1703 -0
  189. triton/runtime/tcc/lib/python313.def +1651 -0
  190. triton/runtime/tcc/lib/python313t.def +1656 -0
  191. triton/runtime/tcc/lib/python314.def +1800 -0
  192. triton/runtime/tcc/lib/python314t.def +1809 -0
  193. triton/runtime/tcc/lib/python39.def +1644 -0
  194. triton/runtime/tcc/lib/python3t.def +905 -0
  195. triton/runtime/tcc/lib/user32.def +658 -0
  196. triton/runtime/tcc/libtcc.dll +0 -0
  197. triton/runtime/tcc/tcc.exe +0 -0
  198. triton/testing.py +543 -0
  199. triton/tools/__init__.py +0 -0
  200. triton/tools/build_extern.py +365 -0
  201. triton/tools/compile.py +210 -0
  202. triton/tools/disasm.py +143 -0
  203. triton/tools/extra/cuda/compile.c +70 -0
  204. triton/tools/extra/cuda/compile.h +14 -0
  205. triton/tools/extra/hip/compile.cpp +66 -0
  206. triton/tools/extra/hip/compile.h +13 -0
  207. triton/tools/link.py +322 -0
  208. triton/tools/mxfp.py +301 -0
  209. triton/tools/ragged_tma.py +92 -0
  210. triton/tools/tensor_descriptor.py +34 -0
  211. triton/windows_utils.py +405 -0
  212. triton_windows-3.5.1.post21.dist-info/METADATA +46 -0
  213. triton_windows-3.5.1.post21.dist-info/RECORD +217 -0
  214. triton_windows-3.5.1.post21.dist-info/WHEEL +5 -0
  215. triton_windows-3.5.1.post21.dist-info/entry_points.txt +3 -0
  216. triton_windows-3.5.1.post21.dist-info/licenses/LICENSE +23 -0
  217. triton_windows-3.5.1.post21.dist-info/top_level.txt +1 -0
triton/language/core.py
@@ -0,0 +1,3405 @@
+from __future__ import annotations
+
+import math
+from warnings import warn
+from contextlib import contextmanager
+from enum import Enum
+from functools import partial, wraps
+import typing
+from typing import Union, Callable, List, Sequence, TypeVar, Optional, Tuple
+from dataclasses import dataclass
+import builtins
+from .. import knobs
+from ..runtime.jit import JITCallable
+import inspect
+
+from .._C.libtriton import ir
+from .._utils import TRITON_MAX_TENSOR_NUMEL, validate_block_shape, get_primitive_bitwidth
+
+T = TypeVar('T')
+
+TRITON_BUILTIN = "__triton_builtin__"
+
+PropagateNan = ir.PROPAGATE_NAN
+
+
+def must_use_result(x, s=True):
+    """If the result of this function is unused, throw an error."""
+    if isinstance(x, str):
+        return (lambda fn: must_use_result(fn, x))
+    x._must_use_result = s
+    return x
+
+
+def builtin(fn: T) -> T:
+    """Mark a function as a builtin."""
+    assert callable(fn)
+
+    @wraps(fn)
+    def wrapper(*args, **kwargs):
+        if "_semantic" not in kwargs or kwargs["_semantic"] is None:
+            raise ValueError("Did you forget to add @triton.jit ? "
+                             "(`_semantic` argument must be provided outside of JIT functions.)")
+        return fn(*args, **kwargs)
+
+    setattr(wrapper, TRITON_BUILTIN, True)
+
+    return wrapper
+
+
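The `_semantic` check above is what turns a misplaced builtin call into a readable error: inside a @triton.jit function the code generator injects `_semantic`, while a plain Python call leaves it as None. A minimal sketch of both sides, assuming the standard `triton`/`triton.language` entry points:

    import triton
    import triton.language as tl

    @triton.jit
    def iota_kernel(out_ptr):
        offs = tl.arange(0, 8)          # OK: the JIT supplies _semantic internally
        tl.store(out_ptr + offs, offs)

    try:
        tl.arange(0, 8)                 # no _semantic -> ValueError from the wrapper
    except ValueError as e:
        print(e)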
+def _tensor_member_fn(fn: T) -> T:
+    """Decorator that adds this free function as a member fn on class tensor.
+
+    When called as a member function on class tensor, the first argument to `fn`
+    is `self`, i.e. the tensor object.
+
+    If there are multiple decorators on a function, you probably want this one
+    to be the highest one (i.e. furthest from the function's `def`), so it's
+    applied last.
+
+    Unfortunately you still need to add a type stub to the body of class tensor
+    in order for pytype to know about it.
+    """
+    assert callable(fn)
+    orig_sig = inspect.signature(fn)
+    # Does fn take args other than _semantic, _generator, and the tensor itself?
+    has_args = len(orig_sig.parameters.keys() - {"_semantic", "_generator"}) > 1
+
+    if not fn.__doc__:
+        fn.__doc__ = ""
+    fn.__doc__ += f"""
+    This function can also be called as a member function on :py:class:`tensor`,
+    as :code:`x.{fn.__name__}({"..." if has_args else ""})` instead of
+    :code:`{fn.__name__}(x{", ..." if has_args else ""})`.
+    """
+
+    def wrapper(*args, **kwargs):
+        return fn(*args, **kwargs)
+
+    # Match the signature of `fn`, but change the first arg to `self` so the
+    # docs are a little less weird.
+    new_params = list(orig_sig.parameters.values())
+    new_params[0] = new_params[0].replace(name='self')
+    new_sig = orig_sig.replace(parameters=new_params)
+    wrapper.__signature__ = new_sig
+    wrapper.__doc__ = f"Forwards to :py:func:`{fn.__name__}` free function"
+    # If fn is a builtin, mark the wrapper as a builtin too.
+    if is_builtin(fn):
+        setattr(wrapper, TRITON_BUILTIN, True)
+
+    setattr(tensor, fn.__name__, fn if isinstance(fn, JITCallable) else wrapper)
+    return fn
+
+
+def _unwrap_iterable(x):
+    """Returns x[0] if x has one element and x[0] is iterable."""
+    if len(x) == 1:
+        # Determine whether x[0] is iterable.
+        #
+        # You might want to use collections.abc.Iterable instead of this
+        # try/except block. Unfortunately, this doesn't work with constexpr.
+        #
+        # The problem is that abc.Iterable checks for __iter__ on the *class*.
+        # But we want constexpr to expose an __iter__ method if and only if the
+        # wrapped *object* (i.e. self.value) is iterable. Therefore there's no
+        # right answer for whether the class constexpr defines __iter__, and
+        # abc.Iterable doesn't work (at least not without some metaclass magic).
+        try:
+            iter(x[0])
+            return x[0]
+        except TypeError:
+            pass
+
+    return x
+
+
+def is_builtin(fn) -> bool:
+    """Is this a registered triton builtin function?"""
+    return getattr(fn, TRITON_BUILTIN, False)
+
+
+@builtin
+def to_tensor(x, _semantic=None):
+    return _semantic.to_tensor(x)
+
+
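Because `_tensor_member_fn` installs each decorated free function as a method on `tensor`, the member and free-function spellings are interchangeable inside a kernel. An illustrative sketch (not taken from this package):

    import triton
    import triton.language as tl

    @triton.jit
    def sqrt_both(x_ptr, out_ptr, BLOCK: tl.constexpr):
        offs = tl.arange(0, BLOCK)
        x = tl.load(x_ptr + offs)
        a = tl.sqrt(x)   # free function
        b = x.sqrt()     # member form installed by _tensor_member_fn
        tl.store(out_ptr + offs, a + b)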
+# -----------------------
+# constexpr
+# -----------------------
+
+
+class const:
+    """
+    This class is used as a type annotation to mark pointers to constant data.
+    The `store` function cannot be called with a pointer to const. Constness
+    is part of the pointer type and the usual Triton type consistency rules
+    apply. For example you cannot have a function that returns constant pointer
+    in one return statement and non-constant pointer in another.
+    """
+    pass
+
+
+class base_value:
+    """Base class of values that exist in the triton IR (i.e. not constexprs).
+    """
+    type: base_type
+
+    def _flatten_ir(self, handles: List[ir.value]) -> None:
+        """Flatten frontend value into a sequence of mlir handles, which are appended
+        to the output list
+        """
+        raise NotImplementedError
+
+
+class base_type:
+
+    def __eq__(self, other) -> bool:
+        raise NotImplementedError("Types must implement __eq__")
+
+    def __ne__(self, other) -> bool:
+        return not (self == other)
+
+    def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[base_value, int]:
+        """Build a frontend value with the current dtype, wrapping a list of existing handles.
+        cursor is the index of the first handle relevant to this value, and the function
+        should return the updated cursor position after any handles consumed by the created value.
+        """
+        raise NotImplementedError
+
+    def mangle(self) -> str:
+        raise NotImplementedError(f"NYI: Type mangling for type {self.__class__}")
+
+    def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
+        raise NotImplementedError
+
+
+class constexpr_type(base_type):
+
+    def __init__(self, value):
+        self.value = value
+
+    def __eq__(self, other):
+        return isinstance(other, constexpr_type) and self.value == other.value
+
+    def __repr__(self) -> str:
+        return f"constexpr_type[{self.value}]"
+
+    def __hash__(self):
+        return hash(self.value)
+
+    def mangle(self) -> str:
+        return repr(self)
+
+    def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
+        return
+
+    def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[base_value, int]:
+        return constexpr(self.value), cursor
+
+
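A hedged sketch of how `const` is used as a kernel-parameter annotation; the `tl.const` spelling below assumes the name is re-exported by `triton.language`, and the signature is illustrative only:

    import triton
    import triton.language as tl

    @triton.jit
    def copy_kernel(src_ptr: tl.const, dst_ptr, BLOCK: tl.constexpr):
        offs = tl.arange(0, BLOCK)
        x = tl.load(src_ptr + offs)   # loading through a const pointer is allowed
        tl.store(dst_ptr + offs, x)   # tl.store(src_ptr + offs, ...) would be rejected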
+class constexpr(base_value):
+    """
+    This class is used to store a value that is known at compile-time.
+    """
+
+    def __init__(self, value):
+        while isinstance(value, constexpr):
+            value = value.value
+        self.value = value
+        self.type = constexpr_type(value)
+
+    def __repr__(self) -> str:
+        return f"constexpr[{self.value}]"
+
+    def __hash__(self):
+        return hash((self.value, self.type))
+
+    def _flatten_ir(self, handles: List[ir.value]) -> None:
+        return
+
+    def __index__(self):
+        return self.value
+
+    # In interpreter mode, constant values are not wrapped in constexpr,
+    # and therefore do not have a .value attribute.
+    # As a result, from here and below, we need to call the _unwrap_if_constexpr
+    # function to obtain either constexpr.value or the value itself.
+    def __add__(self, other):
+        return constexpr(self.value + _unwrap_if_constexpr(other))
+
+    def __radd__(self, other):
+        return constexpr(_unwrap_if_constexpr(other) + self.value)
+
+    def __sub__(self, other):
+        return constexpr(self.value - _unwrap_if_constexpr(other))
+
+    def __rsub__(self, other):
+        return constexpr(_unwrap_if_constexpr(other) - self.value)
+
+    def __mul__(self, other):
+        return constexpr(self.value * _unwrap_if_constexpr(other))
+
+    def __mod__(self, other):
+        return constexpr(self.value % _unwrap_if_constexpr(other))
+
+    def __rmul__(self, other):
+        return constexpr(_unwrap_if_constexpr(other) * self.value)
+
+    def __truediv__(self, other):
+        return constexpr(self.value / _unwrap_if_constexpr(other))
+
+    def __rtruediv__(self, other):
+        return constexpr(_unwrap_if_constexpr(other) / self.value)
+
+    def __floordiv__(self, other):
+        return constexpr(self.value // _unwrap_if_constexpr(other))
+
+    def __rfloordiv__(self, other):
+        return constexpr(_unwrap_if_constexpr(other) // self.value)
+
+    def __gt__(self, other):
+        return constexpr(self.value > _unwrap_if_constexpr(other))
+
+    def __rgt__(self, other):
+        return constexpr(_unwrap_if_constexpr(other) > self.value)
+
+    def __ge__(self, other):
+        return constexpr(self.value >= _unwrap_if_constexpr(other))
+
+    def __rge__(self, other):
+        return constexpr(_unwrap_if_constexpr(other) >= self.value)
+
+    def __lt__(self, other):
+        return constexpr(self.value < _unwrap_if_constexpr(other))
+
+    def __rlt__(self, other):
+        return constexpr(_unwrap_if_constexpr(other) < self.value)
+
+    def __le__(self, other):
+        return constexpr(self.value <= _unwrap_if_constexpr(other))
+
+    def __rle__(self, other):
+        return constexpr(_unwrap_if_constexpr(other) <= self.value)
+
+    def __eq__(self, other):
+        return constexpr(self.value == _unwrap_if_constexpr(other))
+
+    def __ne__(self, other):
+        return constexpr(self.value != _unwrap_if_constexpr(other))
+
+    def __bool__(self):
+        return bool(self.value)
+
+    def __neg__(self):
+        return constexpr(-self.value)
+
+    def __and__(self, other):
+        return constexpr(self.value & _unwrap_if_constexpr(other))
+
+    def logical_and(self, other):
+        return constexpr(self.value and _unwrap_if_constexpr(other))
+
+    def __or__(self, other):
+        return constexpr(self.value | _unwrap_if_constexpr(other))
+
+    def __xor__(self, other):
+        return constexpr(self.value ^ _unwrap_if_constexpr(other))
+
+    def logical_or(self, other):
+        return constexpr(self.value or _unwrap_if_constexpr(other))
+
+    def __pos__(self):
+        return constexpr(+self.value)
+
+    def __invert__(self):
+        return constexpr(~self.value)
+
+    def __pow__(self, other):
+        return constexpr(self.value**_unwrap_if_constexpr(other))
+
+    def __rpow__(self, other):
+        return constexpr(_unwrap_if_constexpr(other)**self.value)
+
+    def __rshift__(self, other):
+        return constexpr(self.value >> _unwrap_if_constexpr(other))
+
+    def __lshift__(self, other):
+        return constexpr(self.value << _unwrap_if_constexpr(other))
+
+    def __not__(self):
+        return constexpr(not self.value)
+
+    def __iter__(self):
+        return iter(self.value)
+
+    def __call__(self, *args, **kwds):
+        return self.value(*args, **kwds)
+
+    def __getitem__(self, *args):
+        args = (_unwrap_if_constexpr(x) for x in _normalize_tuple(args))
+        return self.value.__getitem__(*args)
+
+
+CONSTEXPR_0 = constexpr(0)
+
+
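Every operator on `constexpr` re-wraps the plain Python result, so compile-time arithmetic folds eagerly. A small sketch using only the class defined above:

    import triton.language as tl

    n = tl.constexpr(128)
    half = n // 2              # constexpr[64], computed in Python at compile time
    divisible = (n % 32 == 0)  # constexpr[True]
    print(repr(half), bool(divisible))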
+def _unwrap_if_constexpr(o):
+    if isinstance(o, list):
+        return [_unwrap_if_constexpr(x) for x in o]
+    if isinstance(o, builtins.tuple):
+        return builtins.tuple(_unwrap_if_constexpr(x) for x in o)
+    if isinstance(o, tuple):
+        return tuple(_unwrap_if_constexpr(x) for x in o)
+    return o.value if isinstance(o, constexpr) else o
+
+
+def _normalize_tuple(t):
+    normalized_tuple = _unwrap_if_constexpr(t)
+    if isinstance(normalized_tuple, (list, builtins.tuple)):
+        normalized_tuple = tuple(normalized_tuple)
+    return normalized_tuple
+
+
+def check_bit_width(value, shift_value):
+    if isinstance(value, tensor) and isinstance(shift_value, constexpr):
+        bitwidth = value.type.scalar.primitive_bitwidth
+        if shift_value.value >= bitwidth:
+            warn(
+                f"Value {shift_value.value} exceeds the maximum bitwidth ({bitwidth}) for type '{value.dtype}'. This may result in undefined behavior."
+            )
+
+
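A sketch of the case `check_bit_width` guards against: a constexpr shift amount of at least the operand's primitive bitwidth (32 for int32 below) triggers the warning. Hypothetical kernel for illustration:

    import triton
    import triton.language as tl

    @triton.jit
    def shift_kernel(x_ptr, BLOCK: tl.constexpr):
        offs = tl.arange(0, BLOCK)
        x = tl.load(x_ptr + offs)   # int32 elements
        y = x << 32                 # shift >= bitwidth: check_bit_width warns
        tl.store(x_ptr + offs, y)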
+# -----------------------
+# dtype
+# -----------------------
+
+
+class dtype(base_type):
+    SINT_TYPES = ['int8', 'int16', 'int32', 'int64']
+    UINT_TYPES = ['int1', 'uint8', 'uint16', 'uint32', 'uint64']
+    FP_TYPES = ['fp8e4b15', 'fp8e4nv', 'fp8e4b8', 'fp8e5', 'fp8e5b16', 'fp16', 'bf16', 'fp32', 'fp64']
+    STANDARD_FP_TYPES = ['fp16', 'bf16', 'fp32', 'fp64']
+    OTHER_TYPES = ['void']
+
+    class SIGNEDNESS(Enum):
+        SIGNED = 0
+        UNSIGNED = 1
+
+    class KIND(Enum):
+        BOOLEAN = 0
+        INTEGRAL = 1
+        FLOATING = 2
+
+    def __init__(self, name):
+        name = _unwrap_if_constexpr(name)
+        self.name = name
+        assert name in dtype.SINT_TYPES + dtype.UINT_TYPES + dtype.FP_TYPES + dtype.OTHER_TYPES, name
+        self.primitive_bitwidth = get_primitive_bitwidth(name)
+        self.itemsize = self.primitive_bitwidth // 8
+        if name in dtype.SINT_TYPES:
+            self.int_signedness = dtype.SIGNEDNESS.SIGNED
+            self.int_bitwidth = self.primitive_bitwidth
+        elif name in dtype.UINT_TYPES:
+            self.int_signedness = dtype.SIGNEDNESS.UNSIGNED
+            self.int_bitwidth = self.primitive_bitwidth
+        elif name in dtype.FP_TYPES:
+            if name == 'fp8e4b15':
+                self.fp_mantissa_width = 3
+                self.exponent_bias = 15
+            elif name == 'fp8e4nv':
+                self.fp_mantissa_width = 3
+                self.exponent_bias = 7
+            elif name == 'fp8e4b8':
+                self.fp_mantissa_width = 3
+                self.exponent_bias = 8
+            elif name == 'fp8e5':
+                self.fp_mantissa_width = 2
+                self.exponent_bias = 15
+            elif name == 'fp8e5b16':
+                self.fp_mantissa_width = 2
+                self.exponent_bias = 16
+            elif name == 'fp16':
+                self.fp_mantissa_width = 10
+                self.exponent_bias = 15
+            elif name == 'bf16':
+                self.fp_mantissa_width = 7
+                self.exponent_bias = 127
+            elif name == 'fp32':
+                self.fp_mantissa_width = 23
+                self.exponent_bias = 127
+            elif name == 'fp64':
+                self.fp_mantissa_width = 52
+                self.exponent_bias = 1023
+            else:
+                raise RuntimeError(f'Unsupported floating-point type {name}')
+
+    def is_fp8(self):
+        return 'fp8' in self.name
+
+    def is_fp8e4nv(self):
+        return self.name == 'fp8e4nv'
+
+    def is_fp8e4b8(self):
+        return self.name == 'fp8e4b8'
+
+    def is_fp8e4b15(self):
+        return self.name == 'fp8e4b15'
+
+    def is_fp8e5(self):
+        return self.name == 'fp8e5'
+
+    def is_fp8e5b16(self):
+        return self.name == 'fp8e5b16'
+
+    def is_fp16(self):
+        return self.name == 'fp16'
+
+    def is_bf16(self):
+        return self.name == 'bf16'
+
+    def is_fp32(self):
+        return self.name == 'fp32'
+
+    def is_fp64(self):
+        return self.name == 'fp64'
+
+    def is_int1(self):
+        return self.name == 'int1'
+
+    def is_int8(self):
+        return self.name == 'int8'
+
+    def is_int16(self):
+        return self.name == 'int16'
+
+    def is_int32(self):
+        return self.name == 'int32'
+
+    def is_int64(self):
+        return self.name == 'int64'
+
+    def is_uint8(self):
+        return self.name == 'uint8'
+
+    def is_uint16(self):
+        return self.name == 'uint16'
+
+    def is_uint32(self):
+        return self.name == 'uint32'
+
+    def is_uint64(self):
+        return self.name == 'uint64'
+
+    def is_floating(self):
+        return self.name in dtype.FP_TYPES
+
+    def is_standard_floating(self):
+        return self.name in dtype.STANDARD_FP_TYPES
+
+    def is_int_signed(self):
+        return self.name in dtype.SINT_TYPES
+
+    def is_int_unsigned(self):
+        return self.name in dtype.UINT_TYPES
+
+    def is_int(self):
+        return self.name in dtype.SINT_TYPES + dtype.UINT_TYPES
+
+    def is_bool(self):
+        return self.is_int1()
+
+    def kind(self):
+        # Return int value following the type ordering bool < integer < fp
+        if self.is_bool():
+            return dtype.KIND.BOOLEAN
+        elif self.is_int():
+            return dtype.KIND.INTEGRAL
+        else:
+            assert self.is_floating()
+            return dtype.KIND.FLOATING
+
+    def get_int_max_value(self):
+        if self.is_int_signed():
+            return 2**(self.int_bitwidth - 1) - 1
+        if self.is_int_unsigned():
+            return 2**self.int_bitwidth - 1
+        assert False
+
+    def get_int_min_value(self):
+        if self.is_int_signed():
+            return -2**(self.int_bitwidth - 1)
+        if self.is_int_unsigned():
+            return 0
+        assert False
+
+    @staticmethod
+    def is_dtype(type_str):
+        return type_str in dtype.SINT_TYPES + dtype.UINT_TYPES + dtype.FP_TYPES + dtype.OTHER_TYPES
+
+    @staticmethod
+    def is_void():
+        raise RuntimeError("Not implemented")
+
+    @staticmethod
+    def is_block():
+        return False
+
+    @staticmethod
+    def is_ptr():
+        return False
+
+    @staticmethod
+    def is_const():
+        return False
+
+    def __eq__(self, other) -> bool:
+        other = _unwrap_if_constexpr(other)
+        if not isinstance(other, dtype):
+            return False
+        return self.name == other.name
+
+    def __hash__(self):
+        return hash((self.name, ))
+
+    @property
+    def scalar(self):
+        return self
+
+    def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
+        out.append(self.to_ir(builder))
+
+    def to_ir(self, builder: ir.builder) -> ir.type:
+        if self.name.startswith("fp8"):
+            if self.name not in builder.options.supported_fp8_dtypes:
+                raise ValueError(f'type {self} not supported in this architecture. '
+                                 f'The supported fp8 dtypes are {builder.options.supported_fp8_dtypes}')
+
+        if self.name == 'void':
+            return builder.get_void_ty()
+        elif self.name == 'int1':
+            return builder.get_int1_ty()
+        elif self.name in ('int8', 'uint8'):
+            return builder.get_int8_ty()
+        elif self.name in ('int16', 'uint16'):
+            return builder.get_int16_ty()
+        elif self.name in ('int32', 'uint32'):
+            return builder.get_int32_ty()
+        elif self.name in ('int64', 'uint64'):
+            return builder.get_int64_ty()
+        elif self.name == 'fp8e5':
+            return builder.get_fp8e5_ty()
+        elif self.name == 'fp8e5b16':
+            return builder.get_fp8e5b16_ty()
+        elif self.name == 'fp8e4nv':
+            return builder.get_fp8e4nv_ty()
+        elif self.name == 'fp8e4b8':
+            return builder.get_fp8e4b8_ty()
+        elif self.name == 'fp8e4b15':
+            return builder.get_fp8e4b15_ty()
+        elif self.name == 'fp16':
+            return builder.get_half_ty()
+        elif self.name == 'bf16':
+            return builder.get_bf16_ty()
+        elif self.name == 'fp32':
+            return builder.get_float_ty()
+        elif self.name == 'fp64':
+            return builder.get_double_ty()
+        raise ValueError(f'fail to convert {self} to ir type')
+
+    def __str__(self):
+        return self.name
+
+    def codegen_name(self):
+        if self.name.startswith("fp"):
+            return "float" + self.name[2:]
+        elif self.name.startswith("bf"):
+            return "bfloat" + self.name[2:]
+        else:
+            return self.name
+
+    @property
+    def cache_key_part(self) -> str:
+        """See cache_key_part() in triton.cc."""
+        return self.name
+
+    def __repr__(self):
+        """Output of repr needs to be an evaluatable expression"""
+        return f'triton.language.{self.codegen_name()}'
+
+    def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[base_value, int]:
+        return tensor(handles[cursor], self), cursor + 1
+
+    def mangle(self) -> str:
+        if self.is_int():
+            SIGNED = dtype.SIGNEDNESS.SIGNED
+            prefix = 'i' if self.int_signedness == SIGNED else 'u'
+            return prefix + str(self.int_bitwidth)
+        if self.is_floating():
+            return str(self)
+        if self.is_void():
+            return 'V'
+        return super().mangle()
+
+    def with_element_ty(self, element_ty: dtype):
+        assert not self.is_block()
+        return element_ty
+
+
+# Some functions have a param named `dtype`, which shadows the `dtype` class.
+# We can't change the param name because it is part of function's public API.
+# Declare an alias so those functions can still reference the dtype class.
+_DtypeClass = dtype
+
+
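A short sketch of the introspection surface these `dtype` methods expose; the printed values follow directly from the definitions above:

    import triton.language as tl

    print(tl.int32.primitive_bitwidth)    # 32
    print(tl.uint16.get_int_max_value())  # 65535
    print(tl.float16.fp_mantissa_width)   # 10
    print(tl.int64.mangle())              # 'i64'
    print(tl.uint32.mangle())             # 'u32'
    print(tl.float32.kind())              # KIND.FLOATING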
+class pointer_type(dtype):
+
+    def __init__(self, element_ty: dtype, address_space: int = 1, const: bool = False):
+        element_ty = _unwrap_if_constexpr(element_ty)
+        if not isinstance(element_ty, dtype):
+            raise TypeError(f'element_ty has type `{type(element_ty).__name__}`; expected `dtype`.')
+        self.element_ty = element_ty
+        self.address_space = address_space
+        self.const = const
+        self.name = f'pointer<{element_ty}>' if not const else f'const_pointer<{element_ty}>'
+
+    def to_ir(self, builder: ir.builder) -> ir.pointer_type:
+        return builder.get_ptr_ty(self.element_ty.to_ir(builder), self.address_space)
+
+    def __str__(self):
+        return self.name
+
+    def __repr__(self):
+        return self.__str__()
+
+    def is_ptr(self):
+        return True
+
+    def is_const(self):
+        return self.const
+
+    def __eq__(self, other) -> bool:
+        other = _unwrap_if_constexpr(other)
+        if not isinstance(other, pointer_type):
+            return False
+        return self.element_ty == other.element_ty and self.address_space == other.address_space and self.const == other.const
+
+    @property
+    def scalar(self):
+        return self
+
+    def mangle(self) -> str:
+        return f"P{self.element_ty.mangle()}"
+
+
+class block_type(dtype):
+
+    def __init__(self, element_ty: dtype, shape: List):
+        self.element_ty = element_ty
+
+        # Note that block_type's shape is a list of int
+        # while tensor's shape is a list of constexpr.
+        assert (isinstance(shape, (list, tuple)))
+
+        # shape can be empty ([]) when an input is a 0D tensor.
+        self.shape = tuple(_unwrap_shape(shape))
+        if not self.shape:
+            raise TypeError('0d block_type is forbidden')
+
+        self.numel = validate_block_shape(self.shape)
+        self.name = f'<{self.shape}, {self.element_ty}>'
+
+    def to_ir(self, builder: ir.builder) -> ir.block_type:
+        return builder.get_block_ty(self.element_ty.to_ir(builder), self.shape)
+
+    def __str__(self):
+        return self.name
+
+    def __repr__(self):
+        return self.__str__()
+
+    def is_block(self):
+        return True
+
+    def get_block_shapes(self) -> Tuple[int]:
+        return self.shape
+
+    def with_element_ty(self, scalar_ty: dtype) -> block_type:
+        return block_type(scalar_ty, self.shape)
+
+    def __eq__(self, other) -> bool:
+        if not isinstance(other, block_type):
+            return False
+        return self.element_ty == other.element_ty and self.shape == other.shape
+
+    @property
+    def scalar(self):
+        return self.element_ty
+
+    @property
+    def nbytes(self):
+        return self.numel * (self.element_ty.primitive_bitwidth // 8)
+
+    def mangle(self) -> str:
+        elt = self.scalar.mangle()
+        shape = '_'.join(map(str, self.shape))
+        return f'{elt}S{shape}S'
+
+
+class tuple_type(base_type):
+
+    def __init__(self, types, fields=None):
+        self.types = types
+        self.fields = fields or [''] * len(types)
+        self.name = '[' + ','.join([f"{k}:{v}" for k, v in zip(self.fields, self.types)]) + ']'
+
+    def __str__(self):
+        return self.name
+
+    def __iter__(self):
+        return iter(self.types)
+
+    def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]):
+        for ty in self.types:
+            if not isinstance(ty, constexpr):
+                ty._flatten_ir_types(builder, out)
+
+    def __getitem__(self, index: int) -> dtype:
+        return self.types[index]
+
+    def __eq__(self, other):
+        return type(self) is type(other) and self.types == other.types and self.fields == other.fields
+
+    def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tuple, int]:
+        values = []
+        for ty in self.types:
+            value, cursor = ty._unflatten_ir(handles, cursor)
+            values.append(value)
+        return tuple(values, self), cursor
+
+    def mangle(self):
+        return 'T' + '_'.join(ty.mangle() for ty in self.types) + 'T'
+
+
+class slice_type(dtype):
+
+    def __init__(self):
+        self.name = 'slice_type'
+
+
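An illustrative sketch of `block_type` bookkeeping, assuming the class is re-exported by `triton.language` as in upstream Triton:

    import triton.language as tl

    blk = tl.block_type(tl.float32, [16, 32])
    print(blk.numel)     # 512
    print(blk.nbytes)    # 2048 = 512 elements * 4 bytes
    print(blk.mangle())  # 'fp32S16_32S'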
+# scalar types
+void = dtype('void')
+int1 = dtype('int1')
+int8 = dtype('int8')
+int16 = dtype('int16')
+int32 = dtype('int32')
+int64 = dtype('int64')
+uint8 = dtype('uint8')
+uint16 = dtype('uint16')
+uint32 = dtype('uint32')
+uint64 = dtype('uint64')
+float8e5 = dtype('fp8e5')
+float8e5b16 = dtype('fp8e5b16')
+float8e4nv = dtype('fp8e4nv')
+float8e4b8 = dtype('fp8e4b8')
+float8e4b15 = dtype('fp8e4b15')
+float16 = dtype('fp16')
+bfloat16 = dtype('bf16')
+float32 = dtype('fp32')
+float64 = dtype('fp64')
+# pointer types
+pi32_t = pointer_type(int32)
+
+
+def get_int_dtype(bitwidth: int, signed: bool) -> dtype:
+    if bitwidth == 1:
+        return int1
+    elif bitwidth == 8 and signed:
+        return int8
+    elif bitwidth == 8 and not signed:
+        return uint8
+    elif bitwidth == 16 and signed:
+        return int16
+    elif bitwidth == 16 and not signed:
+        return uint16
+    elif bitwidth == 32 and signed:
+        return int32
+    elif bitwidth == 32 and not signed:
+        return uint32
+    elif bitwidth == 64 and signed:
+        return int64
+    elif bitwidth == 64 and not signed:
+        return uint64
+    else:
+        raise ValueError(f'Unsupported bitwidth {bitwidth} and signedness {signed}')
+
+
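`get_int_dtype` maps a (bitwidth, signedness) pair back to the module-level singletons defined above, so identity comparison holds. A quick sketch:

    import triton.language as tl
    from triton.language.core import get_int_dtype

    assert get_int_dtype(32, signed=True) is tl.int32
    assert get_int_dtype(8, signed=False) is tl.uint8
    assert get_int_dtype(1, signed=False) is tl.int1   # int1 regardless of signedness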
+# -----------------------
+# tensor
+# -----------------------
+
+
+class tensor(base_value):
+    """Represents an N-dimensional array of values or pointers.
+
+    :code:`tensor` is the fundamental data structure in Triton programs. Most
+    functions in :py:mod:`triton.language` operate on and return tensors.
+
+    Most of the named member functions here are duplicates of the free functions
+    in :code:`triton.language`. For example, :code:`triton.language.sqrt(x)` is
+    equivalent to :code:`x.sqrt()`.
+
+    :code:`tensor` also defines most of the magic/dunder methods, so you can
+    write :code:`x+y`, :code:`x << 2`, etc.
+
+    .. rubric:: Constructors
+    ..
+       For some reason Sphinx includes __init__ before printing the full table
+       of methods. Not what I want, but I can't figure out how to fix it. Give
+       it its own section so it looks intentional. :)
+    """
+
+    def __init__(self, handle, type: dtype):
+        """Not called by user code."""
+        super().__init__()
+        # IR handle
+        self.handle = handle
+        # Block shape
+        self.shape = type.shape if type.is_block() else ()
+        self.numel = constexpr(math.prod(self.shape))
+        self.type = type  # Tensor type (can be block_type)
+        # Following the practice in pytorch, dtype is scalar type
+        self.dtype = type.scalar
+        self.shape = tuple([constexpr(s) for s in self.shape])
+
+    def _flatten_ir(self, handles: List[ir.value]) -> None:
+        handles.append(self.handle)
+
+    def __str__(self) -> str:
+        # ex. "float32[16, 32]"
+        return str(self.dtype) + '[' + ', '.join(str(s) for s in self.shape) + ']'
+
+    @builtin
+    def __add__(self, other, _semantic=None):
+        return add(self, other, sanitize_overflow=True, _semantic=_semantic)
+
+    @builtin
+    def __radd__(self, other, _semantic=None):
+        return add(other, self, sanitize_overflow=True, _semantic=_semantic)
+
+    @builtin
+    def __sub__(self, other, _semantic=None):
+        return sub(self, other, sanitize_overflow=True, _semantic=_semantic)
+
+    @builtin
+    def __rsub__(self, other, _semantic=None):
+        return sub(other, self, sanitize_overflow=True, _semantic=_semantic)
+
+    @builtin
+    def __mul__(self, other, _semantic=None):
+        return mul(self, other, sanitize_overflow=True, _semantic=_semantic)
+
+    @builtin
+    def __rmul__(self, other, _semantic=None):
+        return mul(other, self, sanitize_overflow=True, _semantic=_semantic)
+
+    @builtin
+    def __truediv__(self, other, _semantic=None):
+        other = _unwrap_if_constexpr(other)
+        return _semantic.truediv(self, other)
+
+    @builtin
+    def __rtruediv__(self, other, _semantic=None):
+        other = _unwrap_if_constexpr(other)
+        return _semantic.truediv(other, self)
+
+    @builtin
+    def __floordiv__(self, other, _semantic=None):
+        other = _unwrap_if_constexpr(other)
+        return _semantic.floordiv(self, other)
+
+    @builtin
+    def __rfloordiv__(self, other, _semantic=None):
+        other = _unwrap_if_constexpr(other)
+        return _semantic.floordiv(other, self)
+
+    @builtin
+    def __mod__(self, other, _semantic=None):
+        other = _unwrap_if_constexpr(other)
+        return _semantic.mod(self, other)
+
+    @builtin
+    def __rmod__(self, other, _semantic=None):
+        other = _unwrap_if_constexpr(other)
+        return _semantic.mod(other, self)
+
+    # unary operators
+    @builtin
+    def __neg__(self, _semantic=None):
+        return _semantic.minus(self)
+
+    @builtin
+    def __invert__(self, _semantic=None):
+        return _semantic.invert(self)
+
+    # bitwise operators
+
+    @builtin
+    def __and__(self, other, _semantic=None):
+        other = _unwrap_if_constexpr(other)
+        return _semantic.and_(self, other)
+
+    @builtin
+    def __rand__(self, other, _semantic=None):
+        other = _unwrap_if_constexpr(other)
+        return _semantic.and_(other, self)
+
+    @builtin
+    def __or__(self, other, _semantic=None):
+        other = _unwrap_if_constexpr(other)
+        return _semantic.or_(self, other)
+
+    @builtin
+    def __ror__(self, other, _semantic=None):
+        other = _unwrap_if_constexpr(other)
+        return _semantic.or_(other, self)
+
+    @builtin
+    def __xor__(self, other, _semantic=None):
+        other = _unwrap_if_constexpr(other)
+        return _semantic.xor_(self, other)
+
+    @builtin
+    def __rxor__(self, other, _semantic=None):
+        other = _unwrap_if_constexpr(other)
+        return _semantic.xor_(other, self)
+
+    @builtin
+    def __lshift__(self, other, _semantic=None):
+        check_bit_width(self, other)
+        other = _unwrap_if_constexpr(other)
+        return _semantic.shl(self, other)
+
+    @builtin
+    def __rlshift__(self, other, _semantic=None):
+        check_bit_width(other, self)
+        other = _unwrap_if_constexpr(other)
+        return _semantic.shl(other, self)
+
+    @builtin
+    def __rshift__(self, other, _semantic=None):
+        check_bit_width(self, other)
+        other = _unwrap_if_constexpr(other)
+        if self.dtype.is_int_signed():
+            return _semantic.ashr(self, other)
+        else:
+            return _semantic.lshr(self, other)
+
+    @builtin
+    def __rrshift__(self, other, _semantic=None):
+        check_bit_width(other, self)
+        other = _unwrap_if_constexpr(other)
+        if self.dtype.is_int_signed():
+            return _semantic.ashr(other, self)
+        else:
+            return _semantic.lshr(other, self)
+
+    # >
+    @builtin
+    def __gt__(self, other, _semantic=None):
+        other = _semantic.to_tensor(other)
+        return _semantic.greater_than(self, other)
+
+    @builtin
+    def __rgt__(self, other, _semantic=None):
+        other = _semantic.to_tensor(other)
+        return _semantic.greater_than(other, self)
+
+    # >=
+    @builtin
+    def __ge__(self, other, _semantic=None):
+        other = _semantic.to_tensor(other)
+        return _semantic.greater_equal(self, other)
+
+    @builtin
+    def __rge__(self, other, _semantic=None):
+        other = _semantic.to_tensor(other)
+        return _semantic.greater_equal(other, self)
+
+    # <
+    @builtin
+    def __lt__(self, other, _semantic=None):
+        other = _semantic.to_tensor(other)
+        return _semantic.less_than(self, other)
+
+    @builtin
+    def __rlt__(self, other, _semantic=None):
+        other = _semantic.to_tensor(other)
+        return _semantic.less_than(other, self)
+
+    # <=
+    @builtin
+    def __le__(self, other, _semantic=None):
+        other = _semantic.to_tensor(other)
+        return _semantic.less_equal(self, other)
+
+    @builtin
+    def __rle__(self, other, _semantic=None):
+        other = _semantic.to_tensor(other)
+        return _semantic.less_equal(other, self)
+
+    # ==
+    @builtin
+    def __eq__(self, other, _semantic=None):
+        other = _semantic.to_tensor(other)
+        return _semantic.equal(self, other)
+
+    @builtin
+    def __req__(self, other, _semantic=None):
+        other = _semantic.to_tensor(other)
+        return _semantic.equal(other, self)
+
+    @builtin
+    def __ne__(self, other, _semantic=None):
+        other = _semantic.to_tensor(other)
+        return _semantic.not_equal(self, other)
+
+    @builtin
+    def __rne__(self, other, _semantic=None):
+        other = _semantic.to_tensor(other)
+        return _semantic.not_equal(other, self)
+
+    @builtin
+    def logical_and(self, other, _semantic=None):
+        other = _semantic.to_tensor(other)
+        return _semantic.logical_and(self, other)
+
+    @builtin
+    def logical_or(self, other, _semantic=None):
+        other = _semantic.to_tensor(other)
+        return _semantic.logical_or(self, other)
+
+    # note: __not__ isn't actually a magic method in python
+    # but it's ok because our ASTVisitor handles it
+    @builtin
+    def __not__(self, _semantic=None):
+        return _semantic.not_(self)
+
+    @builtin
+    def __getitem__(self, slices, _semantic=None):
+        if isinstance(slices, (builtins.slice, slice, constexpr)) or slices is None:
+            slices = [slices]
+        if isinstance(slices, tuple):
+            slices = slices.values
+        ret = self
+        for dim, sl in enumerate(slices):
+            if _unwrap_if_constexpr(sl) is None:
+                ret = _semantic.expand_dims(ret, dim)
+            elif isinstance(sl, (builtins.slice, slice)) and all(
+                    _unwrap_if_constexpr(arg) is None for arg in (sl.start, sl.stop, sl.step)):
+                pass  # an unsqueeze
+            else:
+                raise ValueError(f"unsupported tensor index: {sl}")
+        return ret
+
+    @property
+    def T(self):
+        """Transposes a 2D tensor."""
+        assert False, "Transposition must be created by the AST Visitor"
+
+    @builtin
+    def to(self, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: bool = False, _semantic=None):
+        """
+        Alias for :py:func:`tensor.cast`.
+        """
+        return cast(self, dtype, fp_downcast_rounding, bitcast, _semantic=_semantic)
+
+    # Type stubs for functions added by the _tensor_member_fn decorator.
+    # (Unfortunately these can't be created automatically.)
+    #
+    # We couldn't write these definitions out even if we wanted to, because some
+    # of these functions are defined in standard.py.
+    def broadcast_to(self, *shape) -> tensor:
+        ...
+
+    def trans(self, *dims) -> tensor:
+        ...
+
+    def permute(self, *dims) -> tensor:
+        ...
+
+    def split(self) -> tuple[tensor, tensor]:
+        ...
+
+    def view(self, *shape) -> tensor:
+        ...
+
+    def reshape(self, *shape) -> tensor:
+        ...
+
+    def expand_dims(self, axis) -> tensor:
+        ...
+
+    def cast(self, dtype, fp_downcast_rounding=None, bitcast=False) -> tensor:
+        ...
+
+    def store(self, value, mask=None, boundary_check=(), cache_modifier="", eviction_policy="") -> tensor:
+        ...
+
+    def advance(self, offsets) -> tensor:
+        ...
+
+    def atomic_cas(self, cmp, val, sem=None, scope=None) -> tensor:
+        ...
+
+    def atomic_xchg(self, val, mask=None, sem=None, scope=None) -> tensor:
+        ...
+
+    def atomic_add(self, val, mask=None, sem=None, scope=None) -> tensor:
+        ...
+
+    def atomic_max(self, val, mask=None, sem=None, scope=None) -> tensor:
+        ...
+
+    def atomic_min(self, val, mask=None, sem=None, scope=None) -> tensor:
+        ...
+
+    def atomic_and(self, val, mask=None, sem=None, scope=None) -> tensor:
+        ...
+
+    def atomic_or(self, val, mask=None, sem=None, scope=None) -> tensor:
+        ...
+
+    def atomic_xor(self, val, mask=None, sem=None, scope=None) -> tensor:
+        ...
+
+    def exp(self) -> tensor:
+        ...
+
+    def log(self) -> tensor:
+        ...
+
+    def cos(self) -> tensor:
+        ...
+
+    def sin(self) -> tensor:
+        ...
+
+    def sqrt(self) -> tensor:
+        ...
+
+    def rsqrt(self) -> tensor:
+        ...
+
+    def abs(self) -> tensor:
+        ...
+
+    def reduce(self, axis, combine_fn, keep_dims=False) -> tensor:
+        ...
+
+    def associative_scan(self, axis, combine_fn, reverse=False) -> tensor:
+        ...
+
+    def gather(self, indices, axis) -> tensor:
+        ...
+
+    def histogram(self, num_bins) -> tensor:
+        ...
+
+    def cdiv(self, div) -> tensor:
+        ...
+
+    def sigmoid(self) -> tensor:
+        ...
+
+    def softmax(self, dim=None, keep_dims=False, ieee_rounding=False) -> tensor:
+        ...
+
+    def ravel(self) -> tensor:
+        ...
+
+    def max(self, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False) -> tensor:
+        ...
+
+    def argmax(self, axis, tie_break_left=True, keep_dims=False) -> tensor:
+        ...
+
+    def min(self, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False) -> tensor:
+        ...
+
+    def argmin(self, axis, tie_break_left=True, keep_dims=False) -> tensor:
+        ...
+
+    def sum(self, axis=None, keep_dims=False, dtype=None) -> tensor:
+        ...
+
+    def xor_sum(self, axis=None, keep_dims=False) -> tensor:
+        ...
+
+    def reduce_or(self, axis=None, keep_dims=False) -> tensor:
+        ...
+
+    def cumsum(self, axis=0, reverse=False) -> tensor:
+        ...
+
+    def cumprod(self, axis=0, reverse=False) -> tensor:
+        ...
+
+    def sort(self, dim: constexpr = None, descending: constexpr = CONSTEXPR_0) -> tensor:
+        ...
+
+    def flip(self, dim=None) -> tensor:
+        ...
+
+
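The dunder methods above are what let ordinary Python operators appear in kernels: pointer arithmetic, arithmetic with overflow sanitization, and comparisons that yield int1 masks. A hypothetical kernel as a sketch:

    import triton
    import triton.language as tl

    @triton.jit
    def axpy_abs_kernel(x_ptr, y_ptr, out_ptr, alpha, BLOCK: tl.constexpr):
        offs = tl.arange(0, BLOCK)           # pointer arithmetic uses tensor.__add__
        x = tl.load(x_ptr + offs)
        y = tl.load(y_ptr + offs)
        z = alpha * x + y                    # __rmul__/__add__ with sanitize_overflow=True
        mask = z > 0                         # __gt__ produces an int1 tensor
        tl.store(out_ptr + offs, tl.where(mask, z, -z))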
+def _type_for_tuple_values(values, fields=None):
+    return tuple_type([constexpr_type(x) if isinstance(x, (int, float, dtype)) else x.type for x in values], fields)
+
+
+class tuple(base_value):
+
+    def __init__(self, args: Sequence, type: Optional[tuple_type] = None):
+        self.values = [i for i in args]
+        if isinstance(type, tuple_type):
+            self.type = type
+        elif type is not None:  # make_template in ASTFunction.deserialize may pass us a list/tuple
+            self.type = tuple_type(type)
+        else:
+            self.type = _type_for_tuple_values(self.values)
+
+    def __getitem__(self, idx: constexpr):
+        if isinstance(idx, int):
+            idx = constexpr(idx)
+        if isinstance(idx, constexpr):
+            return self.values[idx]
+        else:
+            assert isinstance(idx, (slice, builtins.slice))
+            return tuple(self.values[idx.start:idx.stop:idx.step])
+
+    def __getattr__(self, name):
+        return self.values[self.type.fields.index(name)]
+
+    # TODO: remove
+    def _setitem(self, idx, value):
+        idx = _unwrap_if_constexpr(idx)
+        assert isinstance(idx, int)
+        self.values[idx] = value
+        self.type = _type_for_tuple_values(self.values, self.type.fields)
+
+    def __add__(self, other):
+        other = _normalize_tuple(other)
+        return tuple(self.values + other.values)
+        # return tuple(a + b for a, b in zip(self.values, other.values))
+
+    def __mul__(self, other):
+        assert isinstance(other, constexpr)
+        return tuple(self.values * other.value)
+
+    def __eq__(self, other):
+        other = _normalize_tuple(other)
+        return constexpr(self.values == other.values)
+
+    def __hash__(self):
+        return hash(builtins.tuple(self.values))
+
+    def __str__(self):
+        return str([str(x) for x in self.values])
+
+    def __iter__(self):
+        return iter(self.values)
+
+    def __len__(self):
+        return len(self.values)
+
+    def _flatten_ir(self, handles: List[ir.value]):
+        for v in self.values:
+            v._flatten_ir(handles)
+
+    def __repr__(self):
+        return f"({', '.join(repr(x) for x in self.values)})"
+
+
+class slice:
+
+    def __init__(self, start, stop, step):
+        self.start = start
+        self.stop = stop
+        self.step = step
+        self.type = slice_type()
+
+
+class tensor_descriptor_base_type(base_type):
+
+    def __init__(self, block_type: block_type):
+        self.block_type = block_type
+
+    def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tensor_descriptor_base, int]:
+        value = tensor_descriptor_base(handles[cursor], self.block_type)
+        return value, cursor + 1
+
+    def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
+        is_signed = self.block_type.element_ty.is_int_signed()
+        out.append(builder.create_tensor_descriptor_type(self.block_type.to_ir(builder), is_signed))
+
+    def __str__(self) -> str:
+        # ex. "tensor_descriptor<float32[16, 32]>"
+        return f"tensor_descriptor<{self.block_type}>"
+
+    def __eq__(self, other) -> bool:
+        if type(other) is not type(self):
+            return False
+        return self.block_type == other.block_type
+
+    def __ne__(self, other) -> bool:
+        return not (self == other)
+
+    def mangle(self) -> str:
+        return f"TD{self.block_type.mangle()}"
+
+
+class tensor_descriptor_base(base_value):
+    """
+    A tensor descriptor with unknown shape and strides
+    """
+
+    def __init__(self, handle, block_type: block_type):
+        """Not called by user code."""
+        super().__init__()
+
+        self.handle = handle  # IR handle
+        self.type = tensor_descriptor_base_type(block_type)  # Tensor type (block_type)
+
+    def _flatten_ir(self, handles: List[ir.value]) -> None:
+        handles.append(self.handle)
+
+    @property
+    def block_type(self):
+        return self.type.block_type
+
+    @property
+    def block_shape(self):
+        return self.type.block_type.shape
+
+    @property
+    def dtype(self):
+        return self.type.block_type.element_ty
+
+    def __str__(self) -> str:
+        return str(self.type)
+
+    @builtin
+    def load(self, offsets: Sequence[constexpr | tensor], _semantic=None) -> tensor:
+        """Load a block from the descriptor starting at the given element offsets.
+
+        Values outside of the tensor bounds will be filled with zeros.
+
+        :note: Offset must be a multiple of 16 bytes
+        """
+        return _semantic.descriptor_load(self, offsets, "", "")
+
+    @builtin
+    def store(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor:
+        """Store a block to the descriptor starting at the given element offsets.
+
+        Values outside of the tensor bounds will be ignored.
+
+        :note: Offset must be a multiple of 16 bytes
+        """
+        return _semantic.descriptor_store(self, value, offsets)
+
+    @builtin
+    def atomic_add(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor:
+        return _semantic.descriptor_atomic_add(self, value, offsets)
+
+    @builtin
+    def atomic_min(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor:
+        return _semantic.descriptor_atomic_min(self, value, offsets)
+
+    @builtin
+    def atomic_max(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor:
+        return _semantic.descriptor_atomic_max(self, value, offsets)
+
+    @builtin
+    def atomic_and(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor:
+        return _semantic.descriptor_atomic_and(self, value, offsets)
+
+    @builtin
+    def atomic_or(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor:
+        return _semantic.descriptor_atomic_or(self, value, offsets)
+
+    @builtin
+    def atomic_xor(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor:
+        return _semantic.descriptor_atomic_xor(self, value, offsets)
+
+    @builtin
+    def gather(self, *args, _semantic=None) -> tensor:
+        """Gather multiple descriptors worth of data"""
+        assert len(args) == 2, f"descriptor gather only supports 2D indexing, but got {len(args)}"
+        x_offsets = args[0]
+        y_offset = args[1]
+        return _semantic.descriptor_gather(self, x_offsets, y_offset, "", "")
+
+    @builtin
+    def scatter(self, value, *args, _semantic=None) -> tensor:
+        """Scatter multiple descriptors worth of data"""
+        assert len(args) == 2, f"descriptor scatter only supports 2D indexing, but got {len(args)}"
+        x_offsets = args[0]
+        y_offset = args[1]
+        return _semantic.descriptor_scatter(self, value, x_offsets, y_offset)
+
+
1450
+ class tensor_descriptor_type(tensor_descriptor_base_type):
1451
+
1452
+ def __init__(self, block_type: block_type, shape_type: tuple_type, strides_type: tuple_type):
1453
+ self.block_type = block_type
1454
+ self.shape_type = shape_type
1455
+ self.strides_type = strides_type
1456
+
1457
+ def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tensor_descriptor_base, int]:
1458
+ handle = handles[cursor]
1459
+ cursor += 1
1460
+ shape, cursor = self.shape_type._unflatten_ir(handles, cursor)
1461
+ strides, cursor = self.strides_type._unflatten_ir(handles, cursor)
1462
+ shape = shape.values
1463
+ strides = strides.values
1464
+ value = tensor_descriptor(handle, shape, strides, self.block_type)
1465
+ return value, cursor
1466
+
1467
+ def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
1468
+ super()._flatten_ir_types(builder, out)
1469
+ self.shape_type._flatten_ir_types(builder, out)
1470
+ self.strides_type._flatten_ir_types(builder, out)
1471
+
1472
+ def __eq__(self, other):
1473
+ return super().__eq__(other) and (self.shape_type == other.shape_type) and (self.strides_type
1474
+ == other.strides_type)
1475
+
1476
+
1477
+ class tensor_descriptor(tensor_descriptor_base):
1478
+ """A descriptor representing a tensor in global memory.
1479
+ """
1480
+
1481
+ def __init__(self, handle, shape: List[tensor], strides: List[tensor], block_type: block_type):
1482
+ """Not called by user code."""
1483
+ # IR handle
1484
+ super().__init__(handle, block_type)
1485
+ # Global shape
1486
+ self.shape = tuple(shape)
1487
+ self.strides = tuple(strides)
1488
+ self.type = tensor_descriptor_type(
1489
+ block_type,
1490
+ shape_type=self.shape.type,
1491
+ strides_type=self.strides.type,
1492
+ )
1493
+
1494
+ def _flatten_ir(self, handles: List[ir.value]) -> None:
1495
+ handles.append(self.handle)
1496
+ self.shape._flatten_ir(handles)
1497
+ self.strides._flatten_ir(handles)
1498
+
1499
+
1500
+ # -----------------------
1501
+ # aggregate
1502
+ # -----------------------
1503
+
1504
+
1505
+ @dataclass(frozen=True)
1506
+ class _aggregate_type(base_type):
1507
+ """A generic base type for all Triton aggregate types.
1508
+
1509
+ This class contains a reference to the original user-defined Python class
1510
+ and a list of class fields with their Triton types.
1511
+ """
1512
+
1513
+ base_cls: type
1514
+ fields: List[Tuple[str, base_type]]
1515
+
1516
+ def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[ir.value, int]:
1517
+ instance = self.base_cls._get_instance()
1518
+ for name, ty in self.fields:
1519
+ value, cursor = ty._unflatten_ir(handles, cursor)
1520
+ setattr(instance, name, value)
1521
+ return instance, cursor
1522
+
1523
+ def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
1524
+ for name, ty in self.fields:
1525
+ ty._flatten_ir_types(builder, out)
1526
+
1527
+ def mangle(self) -> str:
1528
+ name = f"{self.base_cls.__module__}.{self.base_cls.__qualname__}"
1529
+ fields = [ty.mangle() for (name, ty) in self.fields]
1530
+ return f"{name}<{', '.join(fields)}>"
1531
+
1532
+
1533
+ def _aggregate(cls):
1534
+
1535
+ # Define the wrapped Triton value type.
1536
+ class aggregate_value(base_value):
1537
+ __triton_builtin__ = True
1538
+ __triton_aggregate__ = True
1539
+
1540
+ @classmethod
1541
+ def _get_instance(this_cls):
1542
+ return super().__new__(this_cls)
1543
+
1544
+ def __new__(this_cls, *args, _semantic=None, _generator=None, **kwargs):
1545
+ # Call into the user-defined constructor.
1546
+ instance = this_cls._get_instance()
1547
+ if isinstance(cls.__init__, JITCallable):
1548
+ raise ValueError(f"{cls.__name__}.__init__ cannot be a @triton.jit function")
1549
+ extra_kwargs = {}
1550
+ if "_semantic" in inspect.signature(cls.__init__).parameters:
1551
+ extra_kwargs["_semantic"] = _semantic
1552
+ if "_generator" in inspect.signature(cls.__init__).parameters:
1553
+ extra_kwargs["_generator"] = _generator
1554
+ cls.__init__(instance, *args, **extra_kwargs, **kwargs)
1555
+
1556
+ # Require that the user-defined constructor initialized all fields.
1557
+ for name in cls.__annotations__.keys():
1558
+ if not hasattr(instance, name):
1559
+ raise AttributeError(f"constructor for {cls.__name__} did not initialize attribute '{name}'")
1560
+
1561
+ return instance
1562
+
1563
+ # Only allow setting attributes defined in the class annotations.
1564
+ def __setattr__(self, name, value):
1565
+ if name not in cls.__annotations__:
1566
+ raise AttributeError(f"{cls.__name__} has no attribute '{name}'")
1567
+ if not isinstance(value, cls.__annotations__[name]):
1568
+ raise TypeError(f"Expected {cls.__annotations__[name]} for attribute '{name}', got {type(value)}")
1569
+ super().__setattr__(name, value)
1570
+
1571
+ def _flatten_ir(self, handles: List[ir.value]) -> None:
1572
+ for name in cls.__annotations__.keys():
1573
+ getattr(self, name)._flatten_ir(handles)
1574
+
1575
+ @property
1576
+ def type(self):
1577
+ return _aggregate_type(aggregate_value,
1578
+ [(name, getattr(self, name).type) for name in cls.__annotations__.keys()])
1579
+
1580
+ for (name, member) in inspect.getmembers(cls):
1581
+ if inspect.isfunction(member) or inspect.ismethod(member) or isinstance(member, JITCallable):
1582
+ if name != "__init__":
1583
+ setattr(aggregate_value, name, member)
1584
+
1585
+ aggregate_value.__name__ = cls.__name__
1586
+ aggregate_value.__module__ = cls.__module__
1587
+ aggregate_value.__qualname__ = cls.__qualname__
1588
+ aggregate_value.__doc__ = cls.__doc__
1589
+
1590
+ return aggregate_value
1591
+
1592
+
1593
+ # -----------------------
1594
+ # SPMD Programming Model
1595
+ # -----------------------
1596
+
1597
+
1598
+ @builtin
1599
+ def program_id(axis, _semantic=None):
1600
+ """
1601
+ Returns the id of the current program instance along the given :code:`axis`.
1602
+
1603
+ :param axis: The axis of the 3D launch grid. Must be 0, 1 or 2.
1604
+ :type axis: int
1605
+ """
1606
+ # if axis == -1:
1607
+ # pid0 = _semantic.program_id(0)
1608
+ # pid1 = _semantic.program_id(1)
1609
+ # pid2 = _semantic.program_id(2)
1610
+ # npg0 = _semantic.num_programs(0)
1611
+ # npg1 = _semantic.num_programs(1)
1612
+ # return pid0 + pid1*npg0 + pid2*npg0*npg1
1613
+ axis = _unwrap_if_constexpr(axis)
1614
+ return _semantic.program_id(axis)
1615
+
1616
+
1617
+ @builtin
1618
+ def num_programs(axis, _semantic=None):
1619
+ """
1620
+ Returns the number of program instances launched along the given :code:`axis`.
1621
+
1622
+ :param axis: The axis of the 3D launch grid. Must be 0, 1 or 2.
1623
+ :type axis: int
1624
+ """
1625
+ axis = _unwrap_if_constexpr(axis)
1626
+ return _semantic.num_programs(axis)
1627
+
1628
+
1629
+ # -----------------------
1630
+ # Block Initialization
1631
+ # -----------------------
1632
+
1633
+
1634
+ @builtin
1635
+ def arange(start, end, _semantic=None):
1636
+ start = _unwrap_if_constexpr(start)
1637
+ end = _unwrap_if_constexpr(end)
1638
+ return _semantic.arange(start, end)
1639
+
1640
+
1641
+ arange.__doc__ = f"""
1642
+ Returns contiguous values within the half-open interval :code:`[start,
1643
+ end)`. :code:`end - start` must be less than or equal to
1644
+ :code:`TRITON_MAX_TENSOR_NUMEL = {TRITON_MAX_TENSOR_NUMEL}`
1645
+
1646
+ :param start: Start of the interval. Must be a power of two.
1647
+ :type start: int32
1648
+ :param end: End of the interval. Must be a power of two greater than
1649
+ :code:`start`.
1650
+ :type end: int32
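+
+ A minimal illustrative example: ::
+
+ # offs == [0, 1, 2, ..., 15]
+ offs = tl.arange(0, 16)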
1651
+ """
1652
+
1653
+
1654
+ def _unwrap_shape(shape):
1655
+ shape = _unwrap_if_constexpr(shape)
1656
+ return [_unwrap_if_constexpr(s) for s in shape]
1657
+
1658
+
1659
+ def _shape_check_impl(shape):
1660
+ shape = _unwrap_shape(shape)
1661
+ validate_block_shape(shape)
1662
+ return shape
1663
+
1664
+
1665
+ @builtin
1666
+ def full(shape, value, dtype, _semantic=None):
1667
+ """
1668
+ Returns a tensor filled with the scalar value for the given :code:`shape` and :code:`dtype`.
1669
+
1670
+ :param shape: Shape of the new array, e.g., (8, 16) or (8, )
1671
+ :type shape: tuple of ints
1672
+ :param value: A scalar value to fill the array with
1673
+ :type value: scalar
1674
+ :param dtype: Data type of the new array, e.g., :code:`tl.float16`
1675
+ :type dtype: tl.dtype
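+
+ A minimal illustrative example: ::
+
+ # an (8, 16) block of float16 ones
+ ones = tl.full((8, 16), 1.0, tl.float16)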
1676
+ """
1677
+ shape = _shape_check_impl(shape)
1678
+ value = _unwrap_if_constexpr(value)
1679
+ dtype = _unwrap_if_constexpr(dtype)
1680
+ return _semantic.full(shape, value, dtype)
1681
+
1682
+
1683
+ # -----------------------
1684
+ # Shape Manipulation
1685
+ # -----------------------
1686
+
1687
+
1688
+ @builtin
1689
+ def broadcast(input, other, _semantic=None):
1690
+ """
1691
+ Tries to broadcast the two given blocks to a common compatible shape.
1692
+
1693
+ :param input: The first input tensor.
1694
+ :type input: Block
1695
+ :param other: The second input tensor.
1696
+ :type other: Block
1697
+ """
1698
+ return _semantic.broadcast_impl_value(input, other)
1699
+
1700
+
1701
+ @_tensor_member_fn
1702
+ @builtin
1703
+ def broadcast_to(input, *shape, _semantic=None):
1704
+ """
1705
+ Tries to broadcast the given tensor to a new :code:`shape`.
1706
+
1707
+ :param input: The input tensor.
1708
+ :type input: Block
1709
+ :param shape: The desired shape.
1710
+ :type shape: tuple of ints
1711
+
1712
+ :code:`shape` can be passed as a tuple or as individual parameters: ::
1713
+
1714
+ # These are equivalent
1715
+ broadcast_to(x, (32, 32))
1716
+ broadcast_to(x, 32, 32)
1717
+ """
1718
+ shape = _shape_check_impl(_unwrap_iterable(shape))
1719
+ return _semantic.broadcast_impl_shape(input, shape)
1720
+
1721
+
1722
+ @_tensor_member_fn
1723
+ @builtin
1724
+ def trans(input: tensor, *dims, _semantic=None):
1725
+ """
1726
+ Permutes the dimensions of a tensor.
1727
+
1728
+ If the parameter :code:`dims` is not specified, the function defaults to a (1,0) permutation,
1729
+ effectively transposing a 2D tensor.
1730
+
1731
+ :param input: The input tensor.
1732
+ :param dims: The desired ordering of dimensions. For example,
1733
+ :code:`(2, 1, 0)` reverses the order of the dims in a 3D tensor.
1734
+
1735
+ :code:`dims` can be passed as a tuple or as individual parameters: ::
1736
+
1737
+ # These are equivalent
1738
+ trans(x, (2, 1, 0))
1739
+ trans(x, 2, 1, 0)
1740
+
1741
+ :py:func:`permute` is equivalent to this function, except it doesn't
1742
+ have the special case when no permutation is specified.
1743
+ """
1744
+ dims = _unwrap_iterable(dims)
1745
+ if not dims:
1746
+ dims = (1, 0)
1747
+ return _semantic.permute(input, dims)
1748
+
1749
+
1750
+ @_tensor_member_fn
1751
+ @builtin
1752
+ def permute(input, *dims, _semantic=None):
1753
+ """
1754
+ Permutes the dimensions of a tensor.
1755
+
1756
+ :param input: The input tensor.
1757
+ :type input: Block
1758
+ :param dims: The desired ordering of dimensions. For example,
1759
+ :code:`(2, 1, 0)` reverses the order of the dims in a 3D tensor.
1760
+
1761
+ :code:`dims` can be passed as a tuple or as individual parameters: ::
1762
+
1763
+ # These are equivalent
1764
+ permute(x, (2, 1, 0))
1765
+ permute(x, 2, 1, 0)
1766
+
1767
+ :py:func:`trans` is equivalent to this function, except when
1768
+ :code:`dims` is empty, it tries to do a (1,0) permutation.
1769
+ """
1770
+ dims = _unwrap_iterable(dims)
1771
+ return _semantic.permute(input, dims)
1772
+
1773
+
1774
+ @builtin
1775
+ def cat(input, other, can_reorder=False, _semantic=None):
1776
+ """
1777
+ Concatenate the given blocks
1778
+
1779
+ :param input: The first input tensor.
1780
+ :type input: Tensor
1781
+ :param other: The second input tensor.
1782
+ :type other: Tensor
1783
+ :param can_reorder: Compiler hint. If true, the compiler is
1784
+ allowed to reorder elements while concatenating inputs. Only use if the
1785
+ order does not matter (e.g., result is only used in reduction ops).
1786
+ The current implementation of `cat` only supports :code:`can_reorder=True`.
1787
+ """
1788
+ return _semantic.cat(input, other, can_reorder)
1789
+
1790
+
1791
+ @builtin
1792
+ def join(a, b, _semantic=None):
1793
+ """
1794
+ Join the given tensors in a new, minor dimension.
1795
+
1796
+ For example, given two tensors of shape (4,8), produces a new tensor of
1797
+ shape (4,8,2). Given two scalars, returns a tensor of shape (2).
1798
+
1799
+ The two inputs are broadcasted to be the same shape.
1800
+
1801
+ If you want to join more than two elements, you can use multiple calls to
1802
+ this function. This reflects the constraint in Triton that tensors must
1803
+ have power-of-two sizes.
1804
+
1805
+ join is the inverse of split.
1806
+
1807
+ :param a: The first input tensor.
1808
+ :type a: Tensor
1809
+ :param b: The second input tensor.
1810
+ :type b: Tensor
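+
+ A minimal illustrative example (assuming :code:`a` and :code:`b` are tensors of shape (4, 8)): ::
+
+ # c has shape (4, 8, 2), with a and b interleaved along the last dim
+ c = tl.join(a, b)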
1811
+ """
1812
+ return _semantic.join(a, b)
1813
+
1814
+
1815
+ def _unsplat(x, _semantic=None, _generator=None):
1816
+ """
1817
+ Convert a single-element tensor to a scalar.
1818
+ """
1819
+ if len(x.shape) == 0:
1820
+ return x
1821
+ numel = 1
1822
+ for d in x.shape:
1823
+ numel *= d
1824
+ assert numel == 1, "can only unsplat single-element tensors"
1825
+ return _semantic.unsplat(x)
1826
+
1827
+
1828
+ @_tensor_member_fn
1829
+ @builtin
1830
+ def split(a, _semantic=None, _generator=None) -> tuple[tensor, tensor]:
1831
+ """
1832
+ Split a tensor in two along its last dim, which must have size 2.
1833
+
1834
+ For example, given a tensor of shape (4,8,2), produces two tensors of shape
1835
+ (4,8). Given a tensor of shape (2), returns two scalars.
1836
+
1837
+ If you want to split into more than two pieces, you can use multiple calls
1838
+ to this function (probably plus calling reshape). This reflects the
1839
+ constraint in Triton that tensors must have power-of-two sizes.
1840
+
1841
+ split is the inverse of join.
1842
+
1843
+ :param a: The tensor to split.
1844
+ :type a: Tensor
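+
+ A minimal illustrative example (assuming :code:`pairs` has shape (4, 8, 2)): ::
+
+ # lhs and rhs each have shape (4, 8)
+ lhs, rhs = tl.split(pairs)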
1845
+ """
1846
+ # If len(a.shape) == 1, i.e. a.shape == [2], we should return two scalars.
1847
+ # But _semantic.split can only handle returning tensors. Work around this by
1848
+ # expanding the input to shape [1,2] and then reducing the result.
1849
+ was_rank_1 = len(a.shape) == 1
1850
+ if was_rank_1:
1851
+ a = _semantic.expand_dims(a, 0)
1852
+
1853
+ out_lhs, out_rhs = _semantic.split(a)
1854
+
1855
+ if was_rank_1:
1856
+ # _unsplat is currently the best way to convert a tensor of shape [1] to a scalar.
1857
+ out_lhs = _unsplat(out_lhs, _semantic=_semantic, _generator=_generator)
1858
+ out_rhs = _unsplat(out_rhs, _semantic=_semantic, _generator=_generator)
1859
+
1860
+ return out_lhs, out_rhs
1861
+
1862
+
1863
+ @_tensor_member_fn
1864
+ @builtin
1865
+ def view(input, *shape, _semantic=None):
1866
+ """
1867
+ Returns a tensor with the same elements as `input` but a different shape.
1868
+ The order of the elements may not be preserved.
1869
+
1870
+ :param input: The input tensor.
1871
+ :type input: Block
1872
+ :param shape: The desired shape.
1873
+
1874
+ :code:`shape` can be passed as a tuple or as individual parameters: ::
1875
+
1876
+ # These are equivalent
1877
+ view(x, (32, 32))
1878
+ view(x, 32, 32)
1879
+ """
1880
+ warn("view is deprecated, please use reshape with can_reorder being true.")
1881
+ shape = _shape_check_impl(_unwrap_iterable(shape))
1882
+ return _semantic.reshape(input, shape, can_reorder=True)
1883
+
1884
+
1885
+ @_tensor_member_fn
1886
+ @builtin
1887
+ def item(input, _semantic=None, _generator=None):
1888
+ """
1889
+ Converts a single-element tensor into a scalar.
1890
+ """
1891
+ return _unsplat(input, _semantic=_semantic, _generator=_generator)
1892
+
1893
+
1894
+ @_tensor_member_fn
1895
+ @builtin
1896
+ def reshape(input, *shape, can_reorder=False, _semantic=None, _generator=None):
1897
+ """
1898
+ Returns a tensor with the same number of elements as input but with the
1899
+ provided shape.
1900
+
1901
+ :param input: The input tensor.
1902
+ :type input: Block
1903
+ :param shape: The new shape.
1904
+
1905
+ :code:`shape` can be passed as a tuple or as individual parameters: ::
1906
+
1907
+ # These are equivalent
1908
+ reshape(x, (32, 32))
1909
+ reshape(x, 32, 32)
1910
+ """
1911
+ shape = _shape_check_impl(_unwrap_iterable(shape))
1912
+ if len(shape) == 0:
1913
+ return _unsplat(input, _semantic=_semantic, _generator=_generator)
1914
+ return _semantic.reshape(input, shape, can_reorder)
1915
+
1916
+
1917
+ def _wrap_axis(axis, ndim):
1918
+ if not (-ndim <= axis < ndim):
1919
+ raise ValueError(f"invalid axis {axis}. Expected {-ndim} <= axis < {ndim}")
1920
+
1921
+ return axis if axis >= 0 else axis + ndim
1922
+
1923
+
1924
+ @_tensor_member_fn
1925
+ @builtin
1926
+ def expand_dims(input, axis, _semantic=None):
1927
+ """
1928
+ Expand the shape of a tensor, by inserting new length-1 dimensions.
1929
+
1930
+ Axis indices are with respect to the resulting tensor, so
1931
+ ``result.shape[axis]`` will be 1 for each axis.
1932
+
1933
+ :param input: The input tensor.
1934
+ :type input: tl.tensor
1935
+ :param axis: The indices to add new axes
1936
+ :type axis: int | Sequence[int]
1937
+
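+ A minimal illustrative example (assuming :code:`x` has shape (8,)): ::
+
+ row = tl.expand_dims(x, 0)  # shape (1, 8)
+ col = tl.expand_dims(x, 1)  # shape (8, 1)
+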
1938
+ """
1939
+ input = _semantic.to_tensor(input)
1940
+ axis = _unwrap_if_constexpr(axis)
1941
+ axes = list(axis) if isinstance(axis, (Sequence, tuple)) else [axis]
1942
+ new_ndim = len(input.shape) + len(axes)
1943
+ axes = [_wrap_axis(_unwrap_if_constexpr(d), new_ndim) for d in axes]
1944
+
1945
+ if len(set(axes)) != len(axes):
1946
+ raise ValueError(f"expand_dims received duplicate axes, normalized axes = {axes}")
1947
+
1948
+ ret = input
1949
+ for a in sorted(axes):
1950
+ ret = _semantic.expand_dims(ret, a)
1951
+ return ret
1952
+
1953
+
1954
+ @_tensor_member_fn
1955
+ @builtin
1956
+ def cast(input, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: bool = False, _semantic=None):
1957
+ """
1958
+ Casts a tensor to the given :code:`dtype`.
1959
+
1960
+ :param dtype: The target data type.
1961
+ :type dtype: tl.dtype
1962
+ :param fp_downcast_rounding: The rounding mode for downcasting
1963
+ floating-point values. This parameter is only used when self is a
1964
+ floating-point tensor and dtype is a floating-point type with a
1965
+ smaller bitwidth. Supported values are :code:`"rtne"` (round to
1966
+ nearest, ties to even) and :code:`"rtz"` (round towards zero).
1967
+ :type fp_downcast_rounding: str, optional
1968
+ :param bitcast: If true, the tensor is bitcasted to the given
1969
+ :code:`dtype`, instead of being numerically casted.
1970
+ :type bitcast: bool, optional
1971
+ """
1972
+ input = _semantic.to_tensor(input)
1973
+ dtype = _unwrap_if_constexpr(dtype)
1974
+ fp_downcast_rounding = _unwrap_if_constexpr(fp_downcast_rounding)
1975
+ bitcast = _unwrap_if_constexpr(bitcast)
1976
+ if bitcast:
1977
+ return _semantic.bitcast(input, dtype)
1978
+ return _semantic.cast(input, dtype, fp_downcast_rounding)
1979
+
1980
+
1981
+ # -----------------------
1982
+ # Linear Algebra
1983
+ # -----------------------
1984
+
1985
+
1986
+ @builtin
1987
+ def dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_imprecise_acc=None, out_dtype=float32,
1988
+ _semantic=None):
1989
+ """
1990
+ Returns the matrix product of two blocks.
1991
+
1992
+ The two blocks must both be two-dimensional or three-dimensional and have compatible inner dimensions.
1993
+ For three-dimensional blocks, `tl.dot` performs the batched matrix product,
1994
+ where the first dimension of each block represents the batch dimension.
1995
+
1996
+ :param input: The first tensor to be multiplied.
1997
+ :type input: 2D or 3D tensor of scalar-type in {:code:`int8`, :code:`float8_e5m2`, :code:`float16`, :code:`bfloat16`, :code:`float32`}
1998
+ :param other: The second tensor to be multiplied.
1999
+ :type other: 2D or 3D tensor of scalar-type in {:code:`int8`, :code:`float8_e5m2`, :code:`float16`, :code:`bfloat16`, :code:`float32`}
2000
+ :param acc: The accumulator tensor. If not None, the result is added to this tensor.
2001
+ :type acc: 2D or 3D tensor of scalar-type in {:code:`float16`, :code:`float32`, :code:`int32`}
2002
+ :param input_precision: How to exercise the Tensor Cores for f32 x f32. If
2003
+ the device does not have Tensor Cores or the inputs are not of dtype f32,
2004
+ this option is ignored. For devices that do have tensor cores, the
2005
+ default precision is tf32.
2006
+ :type input_precision: string. Available options for nvidia: :code:`"tf32"`, :code:`"tf32x3"`, :code:`"ieee"`. Default: :code:`"tf32"`. Available options for amd: :code:`"ieee"`, (CDNA3 only) :code:`"tf32"`.
2007
+ :param allow_tf32: *Deprecated.* If true, input_precision is set to "tf32".
2008
+ Only one of :code:`input_precision` and :code:`allow_tf32` can be
2009
+ specified (i.e. at least one must be :code:`None`).
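+
+ A minimal illustrative sketch of the accumulating form, assuming :code:`a` and
+ :code:`b` are blocks with compatible inner dimensions and :code:`BLOCK_M` and
+ :code:`BLOCK_N` are constexprs in scope: ::
+
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ acc = tl.dot(a, b, acc=acc)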
2010
+ """
2011
+ assert input_precision is None or allow_tf32 is None, "Only one of input_precision and allow_tf32 can be specified"
2012
+ if input_precision is None:
2013
+ supports_tf32 = "tf32" in _semantic.builder.options.allowed_dot_input_precisions
2014
+ input_precision = knobs.language.fp32_default or ("tf32" if (supports_tf32 and
2015
+ (allow_tf32 or allow_tf32 is None)) else "ieee")
2016
+
2017
+ input_precision = _unwrap_if_constexpr(input_precision)
2018
+ out_dtype = _unwrap_if_constexpr(out_dtype)
2019
+ max_num_imprecise_acc = _unwrap_if_constexpr(max_num_imprecise_acc)
2020
+ acc = _unwrap_if_constexpr(acc)
2021
+ return _semantic.dot(input, other, acc, input_precision, max_num_imprecise_acc, out_dtype)
2022
+
2023
+
2024
+ @builtin
2025
+ def dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None, fast_math=False, lhs_k_pack=True,
2026
+ rhs_k_pack=True, out_dtype=float32, _semantic=None):
2027
+ """
2028
+ Returns the matrix product of two blocks in microscaling format.
2029
+
2030
+ lhs and rhs use microscaling formats described here:
2031
+ https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
2032
+
2033
+ Software emulation enables targeting hardware architectures without native microscaling
2034
+ operation support. Right now, in that case, microscaled lhs/rhs are upcasted to
2035
+ :code:`bf16` element type beforehand for dot computation, with one exception:
2036
+ for AMD CDNA3 specifically, if one of the inputs is of :code:`fp16` element type,
2037
+ the other input is also upcasted to :code:`fp16` element type instead.
2038
+ This behavior is experimental and may be subject to change in the future.
2039
+
2040
+ :param lhs: The first tensor to be multiplied.
2041
+ :type lhs: 2D tensor representing fp4, fp8 or bf16 elements. Fp4 elements are packed into uint8 inputs with the first element in lower bits. Fp8 are stored as uint8 or the corresponding fp8 type.
2042
+ :param lhs_scale: Scale factor for lhs tensor.
2043
+ :type lhs_scale: e8m0 type represented as an uint8 tensor.
2044
+ :param lhs_format: format of the lhs tensor. Available formats: {:code:`e2m1`, :code:`e4m3`, :code:`e5m2`, :code:`bf16`, :code:`fp16`}.
2045
+ :type lhs_format: str
2046
+ :param rhs: The second tensor to be multiplied.
2047
+ :type rhs: 2D tensor representing fp4, fp8 or bf16 elements. Fp4 elements are packed into uint8 inputs with the first element in lower bits. Fp8 are stored as uint8 or the corresponding fp8 type.
2048
+ :param rhs_scale: Scale factor for rhs tensor.
2049
+ :type rhs_scale: e8m0 type represented as an uint8 tensor.
2050
+ :param rhs_format: format of the rhs tensor. Available formats: {:code:`e2m1`, :code:`e4m3`, :code:`e5m2`, :code:`bf16`, :code:`fp16`}.
2051
+ :type rhs_format: str
2052
+ :param acc: The accumulator tensor. If not None, the result is added to this tensor.
2053
+ :param lhs_k_pack: If false, the lhs tensor is packed into uint8 along the M dimension.
2054
+ :type lhs_k_pack: bool, optional
2055
+ :param rhs_k_pack: If false, the rhs tensor is packed into uint8 along the N dimension.
2056
+ :type rhs_k_pack: bool, optional
2057
+ """
2058
+ out_dtype = _unwrap_if_constexpr(out_dtype)
2059
+ assert out_dtype == float32, "Only float32 is supported for out_dtype at the moment"
2060
+ return _semantic.dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc, fast_math, lhs_k_pack,
2061
+ rhs_k_pack, out_dtype)
2062
+
2063
+
2064
+ # -----------------------
2065
+ # Non-Atomic Memory Operations
2066
+ # -----------------------
2067
+
2068
+
2069
+ @builtin
2070
+ def load(pointer, mask=None, other=None, boundary_check=(), padding_option="", cache_modifier="", eviction_policy="",
2071
+ volatile=False, _semantic=None):
2072
+ """
2073
+ Return a tensor of data whose values are loaded from memory at location defined by `pointer`:
2074
+
2075
+ (1) If `pointer` is a single element pointer, a scalar is loaded. In
2076
+ this case:
2077
+
2078
+ - `mask` and `other` must also be scalars,
2079
+ - `other` is implicitly typecast to `pointer.dtype.element_ty`, and
2080
+ - `boundary_check` and `padding_option` must be empty.
2081
+
2082
+ (2) If `pointer` is an N-dimensional tensor of pointers, an
2083
+ N-dimensional tensor is loaded. In this case:
2084
+
2085
+ - `mask` and `other` are implicitly broadcast to `pointer.shape`,
2086
+ - `other` is implicitly typecast to `pointer.dtype.element_ty`, and
2087
+ - `boundary_check` and `padding_option` must be empty.
2088
+
2089
+ (3) If `pointer` is a block pointer defined by `make_block_ptr`, a
2090
+ tensor is loaded. In this case:
2091
+
2092
+ - `mask` and `other` must be `None`, and
2093
+ - `boundary_check` and `padding_option` can be specified to control the behavior of out-of-bound access.
2094
+
2095
+ :param pointer: Pointer to the data to be loaded
2096
+ :type pointer: `triton.PointerType`, or block of `dtype=triton.PointerType`
2097
+ :param mask: if `mask[idx]` is false, do not load the data at address `pointer[idx]`
2098
+ (must be `None` with block pointers)
2099
+ :type mask: Block of `triton.int1`, optional
2100
+ :param other: if `mask[idx]` is false, return `other[idx]`
2101
+ :type other: Block, optional
2102
+ :param boundary_check: tuple of integers, indicating the dimensions which should do the boundary check
2103
+ :type boundary_check: tuple of ints, optional
2104
+ :param padding_option: should be one of {"", "zero", "nan"}; the padding value to use for out-of-bounds accesses. "" means an undefined value.
2105
+ :param cache_modifier: changes cache option in NVIDIA PTX
2106
+ :type cache_modifier: str, optional, should be one of {"", ".ca", ".cg", ".cv"}, where ".ca" stands for
2107
+ cache at all levels, ".cg" stands for cache at global level (cache in L2 and below, not L1),
2108
+ and ".cv" means don’t cache and fetch again. see
2109
+ `cache operator <https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#cache-operators>`_ for more details.
2110
+ :param eviction_policy: changes eviction policy in NVIDIA PTX
2111
+ :type eviction_policy: str, optional
2112
+ :param volatile: changes volatile option in NVIDIA PTX
2113
+ :type volatile: bool, optional
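+
+ A minimal illustrative sketch of case (2), assuming :code:`ptr`, :code:`n` and the
+ constexpr :code:`BLOCK` are defined in the enclosing kernel: ::
+
+ offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
+ x = tl.load(ptr + offs, mask=offs < n, other=0.0)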
2114
+ """
2115
+ # `mask` and `other` can be constexpr
2116
+ mask = _unwrap_if_constexpr(mask)
2117
+ other = _unwrap_if_constexpr(other)
2118
+ if mask is not None:
2119
+ mask = _semantic.to_tensor(mask)
2120
+ if other is not None:
2121
+ other = _semantic.to_tensor(other)
2122
+ padding_option = _unwrap_if_constexpr(padding_option)
2123
+ cache_modifier = _unwrap_if_constexpr(cache_modifier)
2124
+ eviction_policy = _unwrap_if_constexpr(eviction_policy)
2125
+ volatile = _unwrap_if_constexpr(volatile)
2126
+ return _semantic.load(pointer, mask, other, boundary_check, padding_option, cache_modifier, eviction_policy,
2127
+ volatile)
2128
+
2129
+
2130
+ @builtin
2131
+ def load_tensor_descriptor(desc: tensor_descriptor_base, offsets: Sequence[constexpr | tensor],
2132
+ _semantic=None) -> tensor:
2133
+ """Load a block of data from a tensor descriptor."""
2134
+ return desc.load(offsets, _semantic=_semantic)
2135
+
2136
+
2137
+ @builtin
2138
+ def store_tensor_descriptor(desc: tensor_descriptor_base, offsets: Sequence[constexpr | tensor], value: tensor,
2139
+ _semantic=None) -> tensor:
2140
+ """Store a block of data to a tensor descriptor."""
2141
+ return desc.store(offsets, value, _semantic=_semantic)
2142
+
2143
+
2144
+ @_tensor_member_fn
2145
+ @builtin
2146
+ def store(pointer, value, mask=None, boundary_check=(), cache_modifier="", eviction_policy="", _semantic=None):
2147
+ """
2148
+ Store a tensor of data into memory locations defined by `pointer`.
2149
+
2150
+ (1) If `pointer` is a single element pointer, a scalar is stored. In
2151
+ this case:
2152
+
2153
+ - `mask` must also be scalar, and
2154
+ - `boundary_check` and `padding_option` must be empty.
2155
+
2156
+ (2) If `pointer` is an N-dimensional tensor of pointers, an
2157
+ N-dimensional block is stored. In this case:
2158
+
2159
+ - `mask` is implicitly broadcast to `pointer.shape`, and
2160
+ - `boundary_check` must be empty.
2161
+
2162
+ (3) If `pointer` is a block pointer defined by `make_block_ptr`, a block
2163
+ of data is stored. In this case:
2164
+
2165
+ - `mask` must be None, and
2166
+ - `boundary_check` can be specified to control the behavior of out-of-bound access.
2167
+
2168
+ `value` is implicitly broadcast to `pointer.shape` and typecast to `pointer.dtype.element_ty`.
2169
+
2170
+ :param pointer: The memory location where the elements of `value` are stored
2171
+ :type pointer: `triton.PointerType`, or block of `dtype=triton.PointerType`
2172
+ :param value: The tensor of elements to be stored
2173
+ :type value: Block
2174
+ :param mask: If `mask[idx]` is false, do not store `value[idx]` at `pointer[idx]`
2175
+ :type mask: Block of triton.int1, optional
2176
+ :param boundary_check: tuple of integers, indicating the dimensions which should do the boundary check
2177
+ :type boundary_check: tuple of ints, optional
2178
+ :param cache_modifier: changes cache option in NVIDIA PTX
2179
+ :type cache_modifier: str, optional, should be one of {"", ".wb", ".cg", ".cs", ".wt"}, where ".wb" stands for
2180
+ cache write-back all coherent levels, ".cg" stands for cache global, ".cs" stands for cache streaming, ".wt"
2181
+ stands for cache write-through, see `cache operator <https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#cache-operators>`_ for more details.
2182
+ :param eviction_policy: changes eviction policy in NVIDIA PTX
2183
+ :type eviction_policy: str, optional, should be one of {"", "evict_first", "evict_last"}
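+
+ A minimal illustrative sketch of case (2), assuming :code:`ptr`, :code:`offs`, :code:`n`
+ and :code:`x` are defined in the enclosing kernel: ::
+
+ tl.store(ptr + offs, x, mask=offs < n)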
2184
+ """
2185
+ # `value` can be constexpr
2186
+ value = _semantic.to_tensor(value)
2187
+ mask = _unwrap_if_constexpr(mask)
2188
+ if mask is not None:
2189
+ mask = _semantic.to_tensor(mask)
2190
+ cache_modifier = _unwrap_if_constexpr(cache_modifier)
2191
+ eviction_policy = _unwrap_if_constexpr(eviction_policy)
2192
+ return _semantic.store(pointer, value, mask, boundary_check, cache_modifier, eviction_policy)
2193
+
2194
+
2195
+ @builtin
2196
+ def make_block_ptr(base: tensor, shape, strides, offsets, block_shape, order, _semantic=None):
2197
+ """
2198
+ Returns a pointer to a block in a parent tensor
2199
+
2200
+ :param base: The base pointer to the parent tensor
2201
+ :param shape: The shape of the parent tensor
2202
+ :param strides: The strides of the parent tensor
2203
+ :param offsets: The offsets to the block
2204
+ :param block_shape: The shape of the block
2205
+ :param order: The order of the original data format
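+
+ A minimal illustrative sketch for a row-major (M, N) matrix, assuming :code:`base`,
+ :code:`M`, :code:`N` and the constexprs :code:`BLOCK_M` and :code:`BLOCK_N` are in scope: ::
+
+ block_ptr = tl.make_block_ptr(base, shape=(M, N), strides=(N, 1),
+ offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
+ tile = tl.load(block_ptr, boundary_check=(0, 1))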
2206
+ """
2207
+ return _semantic.make_block_ptr(base, shape, strides, offsets, block_shape, order)
2208
+
2209
+
2210
+ @must_use_result(
2211
+ "Note that tl.advance does not have any side effects. To move the block pointer, you need to assign the result of tl.advance to a variable."
2212
+ )
2213
+ @_tensor_member_fn
2214
+ @builtin
2215
+ def advance(base, offsets, _semantic=None):
2216
+ """
2217
+ Advance a block pointer
2218
+
2219
+ :param base: the block pointer to advance
2220
+ :param offsets: the offsets to advance, a tuple by dimension
2221
+ """
2222
+ return _semantic.advance(base, offsets)
2223
+
2224
+
2225
+ @builtin
2226
+ def make_tensor_descriptor(
2227
+ base: tensor,
2228
+ shape: List[tensor],
2229
+ strides: List[tensor],
2230
+ block_shape: List[constexpr],
2231
+ padding_option="zero",
2232
+ _semantic=None,
2233
+ ) -> tensor_descriptor:
2234
+ """Make a tensor descriptor object
2235
+
2236
+ :param base: the base pointer of the tensor, must be 16-byte aligned
2237
+ :param shape: A list of non-negative integers representing the tensor shape
2238
+ :param strides: A list of tensor strides. The strides of the leading dimensions
2239
+ must be multiples of 16 bytes, and the last dimension must be contiguous.
2240
+ :param block_shape: The shape of block to be loaded/stored from global memory
2241
+
2242
+ Notes
2243
+ *****
2244
+ On NVIDIA GPUs with TMA support, this will result in a TMA descriptor object
2245
+ and loads and stores from the descriptor will be backed by the TMA hardware.
2246
+
2247
+ Currently only 2-5 dimensional tensors are supported.
2248
+
2249
+ Example
2250
+ *******
2251
+ .. code-block:: python
2252
+
2253
+ @triton.jit
2254
+ def inplace_abs(in_out_ptr, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr):
2255
+ desc = tl.make_tensor_descriptor(
2256
+ in_out_ptr,
2257
+ shape=[M, N],
2258
+ strides=[N, 1],
2259
+ block_shape=[M_BLOCK, N_BLOCK],
2260
+ )
2261
+
2262
+ moffset = tl.program_id(0) * M_BLOCK
2263
+ noffset = tl.program_id(1) * N_BLOCK
2264
+
2265
+ value = desc.load([moffset, noffset])
2266
+ desc.store([moffset, noffset], tl.abs(value))
2267
+
2268
+ # TMA descriptors require a global memory allocation
2269
+ def alloc_fn(size: int, alignment: int, stream: Optional[int]):
2270
+ return torch.empty(size, device="cuda", dtype=torch.int8)
2271
+
2272
+ triton.set_allocator(alloc_fn)
2273
+
2274
+ M, N = 256, 256
2275
+ x = torch.randn(M, N, device="cuda")
2276
+ M_BLOCK, N_BLOCK = 32, 32
2277
+ grid = (M // M_BLOCK, N // N_BLOCK)
2278
+ inplace_abs[grid](x, M, N, M_BLOCK, N_BLOCK)
2279
+
2280
+ """
2281
+
2282
+ padding_option = _unwrap_if_constexpr(padding_option)
2283
+ return _semantic.make_tensor_descriptor(base, shape, strides, block_shape, padding_option)
2284
+
2285
+
2286
+ # -----------------------
2287
+ # Atomic Memory Operations
2288
+ # -----------------------
2289
+
2290
+
2291
+ def _add_atomic_docstr(name: str, has_cmp: bool = False) -> Callable[[T], T]:
2292
+
2293
+ def _decorator(func: T) -> T:
2294
+ docstr = f"""
2295
+ Performs an atomic {name} at the memory location specified by :code:`pointer`.
2296
+
2297
+ Return the data stored at :code:`pointer` before the atomic operation.
2298
+
2299
+ :param pointer: The memory locations to operate on
2300
+ :type pointer: Block of dtype=triton.PointerDType"""
2301
+ if has_cmp:
2302
+ docstr += """
2303
+ :param cmp: The values expected to be found in the atomic object
2304
+ :type cmp: Block of dtype=pointer.dtype.element_ty"""
2305
+ docstr += """
2306
+ :param val: The values with which to perform the atomic operation
2307
+ :type val: Block of dtype=pointer.dtype.element_ty
2308
+ :param sem: Specifies the memory semantics for the operation. Acceptable values are "acquire",
2309
+ "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". If not provided,
2310
+ the function defaults to using "acq_rel" semantics.
2311
+ :type sem: str, optional
2312
+ :param scope: Defines the scope of threads that observe the synchronizing effect of the atomic operation.
2313
+ Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). The default value is "gpu".
2314
+ :type scope: str, optional
2315
+ """
2316
+ func.__doc__ = docstr
2317
+ return func
2318
+
2319
+ return _decorator
2320
+
2321
+
2322
+ @_tensor_member_fn
2323
+ @builtin
2324
+ @_add_atomic_docstr("compare-and-swap", has_cmp=True)
2325
+ def atomic_cas(pointer, cmp, val, sem=None, scope=None, _semantic=None):
2326
+ cmp = _semantic.to_tensor(cmp)
2327
+ val = _semantic.to_tensor(val)
2328
+ sem = _unwrap_if_constexpr(sem)
2329
+ scope = _unwrap_if_constexpr(scope)
2330
+ return _semantic.atomic_cas(pointer, cmp, val, sem, scope)
2331
+
2332
+
2333
+ @_tensor_member_fn
2334
+ @builtin
2335
+ @_add_atomic_docstr("exchange")
2336
+ def atomic_xchg(pointer, val, mask=None, sem=None, scope=None, _semantic=None):
2337
+ val = _semantic.to_tensor(val)
2338
+ sem = _unwrap_if_constexpr(sem)
2339
+ scope = _unwrap_if_constexpr(scope)
2340
+ mask = _unwrap_if_constexpr(mask)
2341
+ return _semantic.atomic_xchg(pointer, val, mask, sem, scope)
2342
+
2343
+
2344
+ @_tensor_member_fn
2345
+ @builtin
2346
+ @_add_atomic_docstr("add")
2347
+ def atomic_add(pointer, val, mask=None, sem=None, scope=None, _semantic=None):
2348
+ val = _semantic.to_tensor(val)
2349
+ sem = _unwrap_if_constexpr(sem)
2350
+ scope = _unwrap_if_constexpr(scope)
2351
+ mask = _unwrap_if_constexpr(mask)
2352
+ return _semantic.atomic_add(pointer, val, mask, sem, scope)
2353
+
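+ # A minimal illustrative sketch (assuming `ptr`, `offs` and `n` are defined in an
+ # enclosing kernel): histogram-style accumulation into global memory:
+ #
+ #     tl.atomic_add(ptr + offs, 1.0, mask=offs < n)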
2354
+
2355
+ @_tensor_member_fn
2356
+ @builtin
2357
+ @_add_atomic_docstr("max")
2358
+ def atomic_max(pointer, val, mask=None, sem=None, scope=None, _semantic=None):
2359
+ val = _semantic.to_tensor(val)
2360
+ sem = _unwrap_if_constexpr(sem)
2361
+ scope = _unwrap_if_constexpr(scope)
2362
+ mask = _unwrap_if_constexpr(mask)
2363
+ return _semantic.atomic_max(pointer, val, mask, sem, scope)
2364
+
2365
+
2366
+ @_tensor_member_fn
2367
+ @builtin
2368
+ @_add_atomic_docstr("min")
2369
+ def atomic_min(pointer, val, mask=None, sem=None, scope=None, _semantic=None):
2370
+ val = _semantic.to_tensor(val)
2371
+ sem = _unwrap_if_constexpr(sem)
2372
+ scope = _unwrap_if_constexpr(scope)
2373
+ mask = _unwrap_if_constexpr(mask)
2374
+ return _semantic.atomic_min(pointer, val, mask, sem, scope)
2375
+
2376
+
2377
+ @_tensor_member_fn
2378
+ @builtin
2379
+ @_add_atomic_docstr("logical and")
2380
+ def atomic_and(pointer, val, mask=None, sem=None, scope=None, _semantic=None):
2381
+ val = _semantic.to_tensor(val)
2382
+ sem = _unwrap_if_constexpr(sem)
2383
+ scope = _unwrap_if_constexpr(scope)
2384
+ mask = _unwrap_if_constexpr(mask)
2385
+ return _semantic.atomic_and(pointer, val, mask, sem, scope)
2386
+
2387
+
2388
+ @_tensor_member_fn
2389
+ @builtin
2390
+ @_add_atomic_docstr("logical or")
2391
+ def atomic_or(pointer, val, mask=None, sem=None, scope=None, _semantic=None):
2392
+ val = _semantic.to_tensor(val)
2393
+ sem = _unwrap_if_constexpr(sem)
2394
+ scope = _unwrap_if_constexpr(scope)
2395
+ mask = _unwrap_if_constexpr(mask)
2396
+ return _semantic.atomic_or(pointer, val, mask, sem, scope)
2397
+
2398
+
2399
+ @_tensor_member_fn
2400
+ @builtin
2401
+ @_add_atomic_docstr("logical xor")
2402
+ def atomic_xor(pointer, val, mask=None, sem=None, scope=None, _semantic=None):
2403
+ val = _semantic.to_tensor(val)
2404
+ sem = _unwrap_if_constexpr(sem)
2405
+ scope = _unwrap_if_constexpr(scope)
2406
+ mask = _unwrap_if_constexpr(mask)
2407
+ return _semantic.atomic_xor(pointer, val, mask, sem, scope)
2408
+
2409
+
2410
+ # -----------------------
2411
+ # Conditioning
2412
+ # -----------------------
2413
+
2414
+
2415
+ @builtin
2416
+ def where(condition, x, y, _semantic=None):
2417
+ """
2418
+ Returns a tensor of elements from either :code:`x` or :code:`y`, depending on :code:`condition`.
2419
+
2420
+ Note that :code:`x` and :code:`y` are always evaluated regardless of the value of :code:`condition`.
2421
+
2422
+ If you want to avoid unintended memory operations, use the :code:`mask` arguments in `triton.load` and `triton.store` instead.
2423
+
2424
+ The shapes of :code:`x` and :code:`y` are both broadcast to the shape of :code:`condition`.
2425
+ :code:`x` and :code:`y` must have the same data type.
2426
+
2427
+ :param condition: When True (nonzero), yield x, otherwise yield y.
2428
+ :type condition: Block of triton.bool
2429
+ :param x: values selected at indices where condition is True.
2430
+ :param y: values selected at indices where condition is False.
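+
+ A minimal illustrative example: ::
+
+ # elementwise absolute value
+ y = tl.where(x >= 0, x, -x)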
2431
+ """
2432
+ condition = _semantic.to_tensor(condition)
2433
+ x = _unwrap_if_constexpr(x)
2434
+ y = _unwrap_if_constexpr(y)
2435
+ return _semantic.where(condition, x, y)
2436
+
2437
+
2438
+ # -----------------------
2439
+ # Math
2440
+ # -----------------------
2441
+
2442
+
2443
+ @builtin
2444
+ def add(x, y, sanitize_overflow: constexpr = True, _semantic=None):
2445
+ x = _unwrap_if_constexpr(x)
2446
+ y = _unwrap_if_constexpr(y)
2447
+ return _semantic.add(x, y, sanitize_overflow)
2448
+
2449
+
2450
+ @builtin
2451
+ def sub(x, y, sanitize_overflow: constexpr = True, _semantic=None):
2452
+ x = _unwrap_if_constexpr(x)
2453
+ y = _unwrap_if_constexpr(y)
2454
+ return _semantic.sub(x, y, sanitize_overflow)
2455
+
2456
+
2457
+ @builtin
2458
+ def mul(x, y, sanitize_overflow: constexpr = True, _semantic=None):
2459
+ x = _unwrap_if_constexpr(x)
2460
+ y = _unwrap_if_constexpr(y)
2461
+ return _semantic.mul(x, y, sanitize_overflow)
2462
+
2463
+
2464
+ @builtin
2465
+ def minimum(x, y, propagate_nan: constexpr = PropagateNan.NONE, _semantic=None):
2466
+ """
2467
+ Computes the element-wise minimum of :code:`x` and :code:`y`.
2468
+
2469
+ :param x: the first input tensor
2470
+ :type x: Block
2471
+ :param y: the second input tensor
2472
+ :type y: Block
2473
+ :param propagate_nan: whether to propagate NaN values.
2474
+ :type propagate_nan: tl.PropagateNan
2475
+
2476
+ .. seealso:: :class:`tl.PropagateNan`
2477
+ """
2478
+ x = _semantic.to_tensor(x)
2479
+ y = _semantic.to_tensor(y)
2480
+ x = _promote_bfloat16_to_float32(x, _semantic=_semantic)
2481
+ y = _promote_bfloat16_to_float32(y, _semantic=_semantic)
2482
+ propagate_nan = _unwrap_if_constexpr(propagate_nan)
2483
+ return _semantic.minimum(x, y, propagate_nan)
2484
+
2485
+
2486
+ @builtin
2487
+ def maximum(x, y, propagate_nan: constexpr = PropagateNan.NONE, _semantic=None):
2488
+ """
2489
+ Computes the element-wise maximum of :code:`x` and :code:`y`.
2490
+
2491
+ :param x: the first input tensor
2492
+ :type x: Block
2493
+ :param y: the second input tensor
2494
+ :type y: Block
2495
+ :param propagate_nan: whether to propagate NaN values.
2496
+ :type propagate_nan: tl.PropagateNan
2497
+
2498
+ .. seealso:: :class:`tl.PropagateNan`
2499
+ """
2500
+ x = _semantic.to_tensor(x)
2501
+ y = _semantic.to_tensor(y)
2502
+ x = _promote_bfloat16_to_float32(x, _semantic=_semantic)
2503
+ y = _promote_bfloat16_to_float32(y, _semantic=_semantic)
2504
+ propagate_nan = _unwrap_if_constexpr(propagate_nan)
2505
+ return _semantic.maximum(x, y, propagate_nan)
2506
+
2507
+
2508
+ @builtin
2509
+ def clamp(x, min, max, propagate_nan: constexpr = PropagateNan.NONE, _semantic=None):
2510
+ """
2511
+ Clamps the input tensor :code:`x` within the range [min, max].
2512
+ Behavior when :code:`min` > :code:`max` is undefined.
2513
+
2514
+ :param x: the input tensor
2515
+ :type x: Block
2516
+ :param min: the lower bound for clamping
2517
+ :type min: Block
2518
+ :param max: the upper bound for clamping
2519
+ :type max: Block
2520
+ :param propagate_nan: whether to propagate NaN values. Applies only to the :code:`x` tensor.
2521
+ If either :code:`min` or :code:`max` is NaN, the result is undefined.
2522
+ :type propagate_nan: tl.PropagateNan
2523
+
2524
+ .. seealso:: :class:`tl.PropagateNan`
2525
+ """
2526
+ x = _semantic.to_tensor(x)
2527
+ min = _semantic.to_tensor(min)
2528
+ max = _semantic.to_tensor(max)
2529
+ x = _promote_bfloat16_to_float32(x, _semantic=_semantic)
2530
+ min = _promote_bfloat16_to_float32(min, _semantic=_semantic)
2531
+ max = _promote_bfloat16_to_float32(max, _semantic=_semantic)
2532
+
2533
+ propagate_nan = _unwrap_if_constexpr(propagate_nan)
2534
+
2535
+ return _semantic.clamp(x, min, max, propagate_nan)
2536
+
2537
+
2538
+ # -----------------------
2539
+ # Reductions
2540
+ # -----------------------
2541
+
2542
+
2543
+ def _add_reduction_docstr(name: str, return_indices_arg: str = None, tie_break_arg: str = None,
2544
+ dtype_arg: str = None) -> Callable[[T], T]:
2545
+
2546
+ def _decorator(func: T) -> T:
2547
+ docstr = """
2548
+ Returns the {name} of all elements in the :code:`input` tensor along the provided :code:`axis`
2549
+
2550
+ :param input: the input values
2551
+ :type input: Tensor
2552
+ :param axis: the dimension along which the reduction should be done. If None, reduce all dimensions
2553
+ :type axis: int
2554
+ :param keep_dims: if true, keep the reduced dimensions with length 1
2555
+ :type keep_dims: bool"""
2556
+ if return_indices_arg is not None:
2557
+ docstr += f"""
2558
+ :param {return_indices_arg}: if true, return index corresponding to the {name} value
2559
+ :type {return_indices_arg}: bool"""
2560
+ if tie_break_arg is not None:
2561
+ docstr += f"""
2562
+ :param {tie_break_arg}: if true, in case of a tie (i.e., multiple elements have the same {name} value), return the left-most index for values that aren't NaN
2563
+ :type {tie_break_arg}: bool"""
2564
+ if dtype_arg is not None:
2565
+ docstr += f"""
2566
+ :param {dtype_arg}: the desired data type of the returned tensor. If specified, the input tensor is casted to :code:`{dtype_arg}` before the operation is performed. This is useful for preventing data overflows. If not specified, integer and bool dtypes are upcasted to :code:`tl.int32` and float dtypes are upcasted to at least :code:`tl.float32`.
2567
+ :type {dtype_arg}: tl.dtype"""
2568
+
2569
+ func.__doc__ = docstr.format(name=name)
2570
+ return func
2571
+
2572
+ return _decorator
2573
+
2574
+
2575
+ @contextmanager
2576
+ def _insertion_guard(builder):
2577
+ ip = builder.get_insertion_point()
2578
+ yield
2579
+ builder.restore_insertion_point(ip)
2580
+
2581
+
2582
+ @_tensor_member_fn
2583
+ @builtin
2584
+ def reduce(input, axis, combine_fn, keep_dims=False, _semantic=None, _generator=None):
2585
+ """Applies the combine_fn to all elements in :code:`input` tensors along the provided :code:`axis`
2586
+
2587
+ :param input: the input tensor, or tuple of tensors
2588
+ :type input: Tensor
2589
+ :param axis: the dimension along which the reduction should be done. If None, reduce all dimensions
2590
+ :type axis: int | None
2591
+ :param combine_fn: a function to combine two groups of scalar tensors (must be marked with @triton.jit)
2592
+ :type combine_fn: Callable
2593
+ :param keep_dims: if true, keep the reduced dimensions with length 1
2594
+ :type keep_dims: bool
2595
+
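+ A minimal illustrative sketch (the :code:`_combine` helper is hypothetical): ::
+
+ @triton.jit
+ def _combine(a, b):
+ return a + b
+
+ total = tl.reduce(x, 0, _combine)  # sum of x along axis 0
+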
2596
+ """
2597
+ if isinstance(input, tensor):
2598
+ return reduce((input, ), axis, combine_fn, keep_dims=keep_dims, _semantic=_semantic, _generator=_generator)[0]
2599
+
2600
+ def make_combine_region(reduce_op):
2601
+ param_types = [t.type.scalar for t in input] * 2
2602
+ region = reduce_op.get_region(0)
2603
+ builder = _semantic.builder
2604
+ with _insertion_guard(builder):
2605
+ to_ir = lambda T: T.to_ir(builder)
2606
+ block = builder.create_block_with_parent(region, list(map(to_ir, param_types)))
2607
+ args = [tensor(block.arg(i), ty) for i, ty in enumerate(param_types)]
2608
+ results = _generator.call_JitFunction(combine_fn, args, kwargs={})
2609
+ if isinstance(results, tensor):
2610
+ handles = [results.handle]
2611
+ else:
2612
+ handles = [r.handle for r in results]
2613
+ builder.create_reduce_ret(*handles)
2614
+
2615
+ def expand_ndims(t, ndims):
2616
+ for _ in builtins.range(ndims):
2617
+ t = expand_dims(t, 0, _semantic=_semantic)
2618
+ return t
2619
+
2620
+ axis = _unwrap_if_constexpr(axis)
2621
+ keep_dims = _unwrap_if_constexpr(keep_dims)
2622
+ if axis is not None:
2623
+ axis = _wrap_axis(axis, len(input[0].shape))
2624
+ ret = _semantic.reduction(input, axis, make_combine_region)
2625
+ if keep_dims:
2626
+ if axis is not None:
2627
+ ret = tuple(expand_dims(t, axis, _semantic=_semantic) for t in ret)
2628
+ else:
2629
+ ret = tuple(expand_ndims(t, len(input[0].shape)) for t in ret)
2630
+ return ret
2631
+
2632
+
2633
+ @builtin
2634
+ def _promote_bfloat16_to_float32(t, _semantic=None):
2635
+ scalar_ty = t.type.scalar
2636
+
2637
+ # hardware doesn't support FMAX, FMIN, CMP for bfloat16
2638
+ if scalar_ty is bfloat16:
2639
+ return t.to(float32, _semantic=_semantic)
2640
+ return t
2641
+
2642
+
2643
+ @builtin
2644
+ def _reduce_with_indices(input, axis, combine_fn, keep_dims=False, _semantic=None, _generator=None):
2645
+ axis = _unwrap_if_constexpr(axis)
2646
+ n = input.shape[axis]
2647
+ index = arange(0, n, _semantic=_semantic)
2648
+
2649
+ if len(input.shape) > 1:
2650
+ # Broadcast index across the non-reduced axes
2651
+ axes_to_expand = [constexpr(d) for d in builtins.range(len(input.shape))]
2652
+ del axes_to_expand[axis]
2653
+ index = expand_dims(index, axes_to_expand, _semantic=_semantic)
2654
+ index = broadcast_to(index, input.shape, _semantic=_semantic)
2655
+
2656
+ rvalue, rindices = reduce((input, index), axis, combine_fn, keep_dims=keep_dims, _semantic=_semantic,
2657
+ _generator=_generator)
2658
+ return rvalue, rindices
2659
+
2660
+
2661
+ # -----------------------
2662
+ # Scans
2663
+ # -----------------------
2664
+
2665
+
2666
+ def _add_scan_docstr(name: str, dtype_arg: str = None) -> Callable[[T], T]:
2667
+
2668
+ def _decorator(func: T) -> T:
2669
+ docstr = """
2670
+ Returns the {name} of all elements in the :code:`input` tensor along the provided :code:`axis`
2671
+
2672
+ :param input: the input values
2673
+ :type input: Tensor
2674
+ :param axis: the dimension along which the scan should be done
2675
+ :type axis: int
2676
+ :param reverse: if true, the scan is performed in the reverse direction
2677
+ :type reverse: bool"""
2678
+
2679
+ if dtype_arg is not None:
2680
+ docstr += f"""
2681
+ :param {dtype_arg}: the desired data type of the returned tensor. If specified, the input tensor is casted to :code:`{dtype_arg}` before the operation is performed. If not specified, small integer types (< 32 bits) are upcasted to prevent overflow. Note that :code:`tl.bfloat16` inputs are automatically promoted to :code:`tl.float32`.
2682
+ :type {dtype_arg}: tl.dtype"""
2683
+
2684
+ func.__doc__ = docstr.format(name=name)
2685
+ return func
2686
+
2687
+ return _decorator
2688
+
2689
+
2690
+ @_tensor_member_fn
2691
+ @builtin
2692
+ def associative_scan(input, axis, combine_fn, reverse=False, _semantic=None, _generator=None):
2693
+ """Applies the combine_fn to each elements with a carry in :code:`input` tensors along the provided :code:`axis` and update the carry
2694
+
2695
+ :param input: the input tensor, or tuple of tensors
2696
+ :type input: Tensor
2697
+ :param axis: the dimension along which the reduction should be done
2698
+ :type axis: int
2699
+ :param combine_fn: a function to combine two groups of scalar tensors (must be marked with @triton.jit)
2700
+ :type combine_fn: Callable
2701
+ :param reverse: whether to apply the associative scan in the reverse direction along axis
2702
+ :type reverse: bool
2703
+
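+ A minimal illustrative sketch (the :code:`_combine` helper is hypothetical): ::
+
+ @triton.jit
+ def _combine(a, b):
+ return a + b
+
+ csum = tl.associative_scan(x, 0, _combine)  # inclusive cumulative sum
+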
2704
+ """
2705
+ if isinstance(input, tensor):
2706
+ return associative_scan((input, ), axis, combine_fn, reverse, _semantic=_semantic, _generator=_generator)[0]
2707
+
2708
+ def make_combine_region(scan_op):
2709
+ param_types = [t.type.scalar for t in input] * 2
2710
+ region = scan_op.get_region(0)
2711
+ builder = _semantic.builder
2712
+ with _insertion_guard(builder):
2713
+ to_ir = lambda T: T.to_ir(builder)
2714
+ block = builder.create_block_with_parent(region, list(map(to_ir, param_types)))
2715
+ args = [tensor(block.arg(i), ty) for i, ty in enumerate(param_types)]
2716
+ results = _generator.call_JitFunction(combine_fn, args, kwargs={})
2717
+ if isinstance(results, tensor):
2718
+ handles = [results.handle]
2719
+ else:
2720
+ handles = [r.handle for r in results]
2721
+ builder.create_scan_ret(*handles)
2722
+
2723
+ axis = _unwrap_if_constexpr(axis)
2724
+ if axis is not None:
2725
+ axis = _wrap_axis(axis, len(input[0].shape))
2726
+ return _semantic.associative_scan(input, axis, make_combine_region, reverse)
2727
+
2728
+
2729
+ @_tensor_member_fn
2730
+ @builtin
2731
+ def histogram(input, num_bins, mask=None, _semantic=None, _generator=None):
2732
+ """computes an histogram based on input tensor with num_bins bins, the bins have a width of 1 and start at 0.
2733
+
2734
+ :param input: the input tensor
2735
+ :type input: Tensor
2736
+ :param num_bins: number of histogram bins
2737
+ :type num_bins: int
2738
+ :param mask: if `mask[idx]` is false, exclude `input[idx]` from histogram
2739
+ :type mask: Block of `triton.int1`, optional
2740
+
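+ A minimal illustrative example (assuming :code:`x` holds integers in [0, 8)): ::
+
+ h = tl.histogram(x, 8)  # h[i] counts occurrences of i in x
+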
2741
+ """
2742
+ num_bins = _unwrap_if_constexpr(num_bins)
2743
+ mask = _unwrap_if_constexpr(mask)
2744
+ if mask is not None:
2745
+ mask = _semantic.to_tensor(mask)
2746
+ return _semantic.histogram(input, num_bins, mask)
2747
+
2748
+
2749
+ @_tensor_member_fn
2750
+ @builtin
2751
+ def gather(src, index, axis, _semantic=None):
2752
+ """Gather from a tensor along a given dimension.
2753
+
2754
+ :param src: the source tensor
2755
+ :type src: Tensor
2756
+ :param index: the index tensor
2757
+ :type index: Tensor
2758
+ :param axis: the dimension to gather along
2759
+ :type axis: int
2760
+
2761
+ """
2762
+ axis = _unwrap_if_constexpr(axis)
2763
+ return _semantic.gather(src, index, axis)
2764
+
2765
+
2766
+ @builtin
2767
+ def map_elementwise(
2768
+ scalar_fn: Callable[..., Tuple[tensor, ...]],
2769
+ *args: tensor,
2770
+ pack=1,
2771
+ _semantic=None,
2772
+ _generator=None,
2773
+ ):
2774
+ '''
2775
+ Map a scalar function over a tensor.
2776
+
2777
+ The input tensors :code:`args` are implicitly broadcasted to the same shape.
2778
+
2779
+ This may be useful in allowing control flow over single elements in a tensor,
2780
+ for example a multi-branch function where one branch is more expensive. With
2781
+ :code:`tl.where` you are forced to calculate both sides of the branch, but
2782
+ with an :code:`if` only one side is executed.
2783
+
2784
+ .. highlight:: python
2785
+ .. code-block:: python
2786
+
2787
+ @triton.jit
2788
+ def selu_scalar(x, alpha):
2789
+ if x > 0:
2790
+ return x
2791
+ else:
2792
+ return alpha * (tl.exp(x) - 1)
2793
+
2794
+ @triton.jit
2795
+ def selu(x, alpha):
2796
+ return tl.map_elementwise(selu_scalar, x, alpha)
2797
+
2798
+ :param scalar_fn: the function to map over.
2799
+ :param pack: the number of elements to be processed by one function call.
2800
+ :return: one tensor or a tuple of tensors, depending on the mapped function.
+     '''
+     # Build the block for the nested region first to discover the return types
+     assert pack >= 1
+     in_scalar_tys = [t.type.scalar for t in args]
+     builder = _semantic.builder
+     block = builder.new_block()
+     scalar_args = []
+     for i, ty in enumerate(in_scalar_tys):
+         for j in builtins.range(pack):
+             block.add_argument(ty.to_ir(builder))
+             scalar_args.append(tensor(block.arg(i * pack + j), ty))
+
+     with _insertion_guard(builder):
+         builder.set_insertion_point_to_start(block)
+         scalar_results = _generator.call_JitFunction(scalar_fn, scalar_args, kwargs={})
+
+         is_single = isinstance(scalar_results, tensor)
+         if is_single:
+             scalar_results = scalar_results,
+
+         handles = [r.handle for r in scalar_results]
+         builder.create_map_elementwise_ret(handles)
+
+     fn_result_types = [x.type for x in scalar_results]
+     scalar_result_types = fn_result_types
+     if pack > 1:
+         scalar_result_types = fn_result_types[::pack]
+         for offset in builtins.range(1, pack):
+             assert scalar_result_types == fn_result_types[offset::pack], "type mismatch in unpacked results"
+
+     def make_elementwise_region(elementwise_op):
+         region = elementwise_op.get_region(0)
+         region.push_back(block)
+
+     result = _semantic.map_elementwise(args, scalar_result_types, pack, make_elementwise_region)
+     return result[0] if is_single else result
+
+
+ # -----------------------
+ # Compiler Hint Ops
+ # -----------------------
+
+
+ @builtin
+ def debug_barrier(_semantic=None):
+     '''
+     Insert a barrier to synchronize all threads in a block.
+     '''
+     return _semantic.debug_barrier()
+
+
+ @builtin
+ def multiple_of(input, values, _semantic=None):
+     """
+     Let the compiler know that the values in :code:`input` are all multiples of :code:`values`.
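+
+     An illustrative sketch of the common hint pattern (combined with
+     :code:`max_contiguous` below, as in the Triton tutorials; the names
+     :code:`pid` and :code:`BLOCK` are assumed context):
+
+     .. highlight:: python
+     .. code-block:: python
+
+         offs = pid * BLOCK + tl.arange(0, BLOCK)
+         # hint that offsets are BLOCK-aligned and contiguous in groups of
+         # BLOCK, which can enable vectorized memory access
+         offs = tl.max_contiguous(tl.multiple_of(offs, BLOCK), BLOCK)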
+     """
+     if isinstance(values, constexpr):
+         values = [values]
+     for i, d in enumerate(values):
+         if not isinstance(d, constexpr):
+             raise TypeError(f"values element {i} must have type `constexpr`")
+         if not isinstance(d.value, int):
+             raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]`")
+     values = [x.value for x in values]
+     return _semantic.multiple_of(input, values)
+
+
+ @builtin
+ def max_contiguous(input, values, _semantic=None):
+     """
+     Let the compiler know that the first :code:`values` values in :code:`input` are contiguous.
+     """
+     if isinstance(values, constexpr):
+         values = [values]
+     for i, d in enumerate(values):
+         if not isinstance(d, constexpr):
+             raise TypeError(f"values element {i} must have type `constexpr`")
+         if not isinstance(d.value, int):
+             raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]`")
+     values = [x.value for x in values]
+     return _semantic.max_contiguous(input, values)
+
+
+ @builtin
+ def max_constancy(input, values, _semantic=None):
+     """
+     Let the compiler know that the first :code:`values` values in :code:`input` are constant.
+
+     e.g. if :code:`values` is [4], then each group of 4 values in :code:`input` should all be equal,
+     for example [0, 0, 0, 0, 1, 1, 1, 1].
+     """
+     if isinstance(values, constexpr):
+         values = [values]
+     for i, d in enumerate(values):
+         if not isinstance(d, constexpr):
+             raise TypeError(f"values element {i} must have type `constexpr`")
+         if not isinstance(d.value, int):
+             raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]`")
+     values = [x.value for x in values]
+     return _semantic.max_constancy(input, values)
+
+
+ @builtin
+ def assume(cond, _semantic=None):
+     '''
+     Allow the compiler to assume that :code:`cond` is true.
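+
+     For example (illustrative):
+
+     .. highlight:: python
+     .. code-block:: python
+
+         pid = tl.program_id(0)
+         tl.assume(pid >= 0)  # e.g. lets the compiler simplify signed index math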
+     '''
+     return _semantic.assume(_semantic.to_tensor(cond))
+
+
+ # -----------------------
+ # Debugging functions
+ # -----------------------
+
+
+ @builtin
+ def static_print(*values, sep: str = " ", end: str = "\n", file=None, flush=False, _semantic=None):
+     '''
+     Print the values at compile time. The parameters are the same as the builtin :code:`print`.
+
+     NOTE: Calling the Python builtin :code:`print` is not the same as calling this; inside a kernel it
+     instead maps to :code:`device_print`, which has special requirements for its arguments.
+
+     .. highlight:: python
+     .. code-block:: python
+
+         tl.static_print(f"BLOCK_SIZE={BLOCK_SIZE}")
+     '''
+     pass
+
+
+ @builtin
+ def static_assert(cond, msg="", _semantic=None):
+     '''
+     Assert the condition at compile time. Does not require that the :code:`TRITON_DEBUG` environment variable
+     is set.
+
+     .. highlight:: python
+     .. code-block:: python
+
+         tl.static_assert(BLOCK_SIZE == 1024)
+     '''
+     pass
+
+
+ @builtin
+ def device_print(prefix, *args, hex=False, _semantic=None):
+     '''
+     Print the values at runtime from the device. String formatting does not work for runtime values,
+     so you should provide the values you want to print as arguments. The first value must be a string;
+     all following values must be scalars or tensors.
+
+     Calling the Python builtin :code:`print` is the same as calling this function, and the requirements
+     for the arguments match this function (not the normal requirements for :code:`print`).
+
+     .. highlight:: python
+     .. code-block:: python
+
+         tl.device_print("pid", pid)
+         print("pid", pid)
+
+     On CUDA, printfs are streamed through a buffer of limited size (on one host,
+     we measured the default as 6912 KiB, but this may not be consistent across
+     GPUs and CUDA versions). If you notice some printfs are being dropped, you
+     can increase the buffer size by calling
+
+     .. highlight:: python
+     .. code-block:: python
+
+         triton.runtime.driver.active.utils.set_printf_fifo_size(size_bytes)
+
+     CUDA may raise an error if you try to change this value after running a
+     kernel that uses printfs. The value set here may only affect the current
+     device (so if you have multiple GPUs, you'd need to call it multiple times).
+
+     :param prefix: a prefix to print before the values. This is required to be a string literal.
+     :param args: the values to print. They can be any tensor or scalar.
+     :param hex: print all values as hex instead of decimal
+     '''
+     import string
+     prefix = _unwrap_if_constexpr(prefix)
+     assert isinstance(prefix, str), f"{prefix} is not a string"
+     b_ascii = True
+     for ch in prefix:
+         if ch not in string.printable:
+             b_ascii = False
+             break
+     assert b_ascii, f"{prefix} is not an ASCII string"
+     new_args = []
+     for arg in args:
+         new_args.append(_semantic.to_tensor(arg))
+     return _semantic.device_print(prefix, new_args, hex)
+
+
+ @builtin
+ def device_assert(cond, msg="", mask=None, _semantic=None):
+     '''
+     Assert the condition at runtime from the device. Requires that the environment variable :code:`TRITON_DEBUG`
+     is set to a value besides :code:`0` in order for this to have any effect.
+
+     Using the Python :code:`assert` statement is the same as calling this function, except that the second argument
+     must be provided and must be a string, e.g. :code:`assert pid == 0, "pid != 0"`. The environment variable must
+     be set for this :code:`assert` statement to have any effect.
+
+     .. highlight:: python
+     .. code-block:: python
+
+         tl.device_assert(pid == 0)
+         assert pid == 0, "pid != 0"
+
+     :param cond: the condition to assert. This is required to be a boolean tensor.
+     :param msg: the message to print if the assertion fails. This is required to be a string literal.
+     :param mask: optional boolean tensor; where it is false, the corresponding element of :code:`cond` is not checked.
+     '''
+     msg = _unwrap_if_constexpr(msg)
+     mask = _unwrap_if_constexpr(mask)
+     if mask is not None:
+         mask = _semantic.to_tensor(mask)
+     return _semantic.device_assert(_semantic.to_tensor(cond), msg, mask)
+
+
+ @builtin
+ def inline_asm_elementwise(asm: str, constraints: str, args: Sequence, dtype: Union[dtype, Sequence[dtype]],
+                            is_pure: bool, pack: int, _semantic=None):
+     '''
+     Execute inline assembly over a tensor. Essentially, this is :code:`map`
+     where the function is inline assembly.
+
+     The input tensors :code:`args` are implicitly broadcast to the same shape.
+
+     :code:`dtype` can be a tuple of types, in which case the output is a
+     tuple of tensors.
+
+     Each invocation of the inline asm processes :code:`pack` elements at a
+     time. Exactly which set of inputs a block receives is unspecified.
+     Input elements of size less than 4 bytes are packed into 4-byte
+     registers.
+
+     This op does not support empty :code:`dtype` -- the inline asm must
+     return at least one tensor, even if you don't need it. You can work
+     around this by returning a dummy tensor of arbitrary type; it shouldn't
+     cost you anything if you don't use it.
+
+     Example using
+     `PTX <https://docs.nvidia.com/cuda/parallel-thread-execution/index.html>`_
+     assembly:
+
+     .. highlight:: python
+     .. code-block:: python
+
+         @triton.jit
+         def kernel(A, B, C, D, BLOCK: tl.constexpr):
+             a = tl.load(A + tl.arange(0, BLOCK))  # uint8 tensor
+             b = tl.load(B + tl.arange(0, BLOCK))  # float32 tensor
+
+             # For each (a, b) in zip(a, b), perform the following:
+             # - Let ai be `a` converted to int32.
+             # - Let af be `a` converted to float.
+             # - Let m be the max of af and b.
+             # - Return ai and m.
+             # Do the above 4 elements at a time.
+             (c, d) = tl.inline_asm_elementwise(
+                 asm="""
+                 {
+                     // Unpack `a` into `ai`.
+                     .reg .b8 tmp<4>;
+                     mov.b32 {tmp0, tmp1, tmp2, tmp3}, $8;
+                     cvt.u32.u8 $0, tmp0;
+                     cvt.u32.u8 $1, tmp1;
+                     cvt.u32.u8 $2, tmp2;
+                     cvt.u32.u8 $3, tmp3;
+                 }
+                 // Convert `ai` to float.
+                 cvt.rn.f32.s32 $4, $0;
+                 cvt.rn.f32.s32 $5, $1;
+                 cvt.rn.f32.s32 $6, $2;
+                 cvt.rn.f32.s32 $7, $3;
+                 // Take max of `af` and `b`.
+                 max.f32 $4, $4, $9;
+                 max.f32 $5, $5, $10;
+                 max.f32 $6, $6, $11;
+                 max.f32 $7, $7, $12;
+                 """,
+                 constraints=(
+                     # 8 output registers, namely
+                     # $0=ai0, $1=ai1, $2=ai2, $3=ai3,
+                     # $4=m0, $5=m1, $6=m2, $7=m3.
+                     "=r,=r,=r,=r,=r,=r,=r,=r,"
+                     # 5 input registers, namely
+                     # $8=a,
+                     # $9=b0, $10=b1, $11=b2, $12=b3.
+                     # The four elements from `a` are all packed into one register.
+                     "r,r,r,r,r"),
+                 args=[a, b],
+                 dtype=(tl.int32, tl.float32),
+                 is_pure=True,
+                 pack=4,
+             )
+             tl.store(C + tl.arange(0, BLOCK), c)
+             tl.store(D + tl.arange(0, BLOCK), d)
+
+     :param asm: assembly to run. Must match the target's assembly format.
+     :param constraints: asm constraints in
+         `LLVM format <https://llvm.org/docs/LangRef.html#inline-asm-constraint-string>`_
+     :param args: the input tensors, whose values are passed to the asm block
+     :param dtype: the element type(s) of the returned tensor(s)
+     :param is_pure: if true, the compiler assumes the asm block has no side-effects
+     :param pack: the number of elements to be processed by one instance of inline assembly
+     :return: one tensor or a tuple of tensors of the given dtypes
+     '''
+     asm = _unwrap_if_constexpr(asm)
+     constraints = _unwrap_if_constexpr(constraints)
+     pack = _unwrap_if_constexpr(pack)
+     is_pure = _unwrap_if_constexpr(is_pure)
+
+     # Wrap `dtype` in a tuple if it's not already.
+     try:
+         iter(dtype)  # type: ignore
+         has_multiple_outputs = True
+     except TypeError:
+         has_multiple_outputs = False
+         dtype = (dtype, )  # type: ignore
+
+     dtype = typing.cast(Sequence[_DtypeClass], dtype)
+
+     res_tys = dtype
+     if dispatch_args := [_semantic.to_tensor(arg) for arg in args]:
+         bin_op_type_checking = partial(
+             _semantic.binary_op_type_checking_impl,
+             arithmetic_check=False,
+             allow_lhs_ptr=True,
+             allow_rhs_ptr=True,
+         )
+         broadcast_arg = dispatch_args[0]
+         # Get the broadcast shape over all the arguments
+         for item in dispatch_args:
+             _, broadcast_arg = bin_op_type_checking(item, broadcast_arg)
+         if broadcast_arg.shape:
+             # Change the shape of each argument based on the broadcast shape
+             for i, item in enumerate(dispatch_args):
+                 dispatch_args[i], _ = bin_op_type_checking(item, broadcast_arg)
+             res_tys = [broadcast_arg.type.with_element_ty(dt) for dt in dtype]
+     handles = [t.handle for t in dispatch_args]
+     builder = _semantic.builder
+     call = builder.create_inline_asm(asm, constraints, handles, [ty.to_ir(builder) for ty in res_tys], is_pure, pack)
+
+     if not has_multiple_outputs:
+         return tensor(call.get_result(0), res_tys[0])
+     return tuple(tensor(call.get_result(i), ty) for i, ty in enumerate(res_tys))
+
+
+ # -----------------------
+ # Iterators
+ # -----------------------
+
+
+ class static_range(base_value):
+     """
+     Iterator with Python :code:`range` semantics, evaluated at compile time.
+
+     .. highlight:: python
+     .. code-block:: python
+
+         @triton.jit
+         def kernel(...):
+             for i in tl.static_range(10):
+                 ...
+
+     :note: This is a special iterator used to implement similar semantics to Python's :code:`range` in the context of
+         :code:`triton.jit` functions. In addition, it also guides the compiler to unroll the loop aggressively.
+     :param arg1: the start value.
+     :param arg2: the end value.
+     :param step: the step value.
+     """
+
+     def __init__(self, arg1, arg2=None, step=None):
+         assert isinstance(arg1, constexpr), f"{arg1} used as tl.static_range start value is not a constexpr"
+         if step is None:
+             self.step = constexpr(1)
+         else:
+             assert isinstance(step, constexpr), f"{step} used as tl.static_range step value is not a constexpr"
+             self.step = step
+         if arg2 is None:
+             self.start = constexpr(0)
+             self.end = arg1
+         else:
+             assert isinstance(arg2, constexpr), f"{arg2} used as tl.static_range end value is not a constexpr"
+             self.start = arg1
+             self.end = arg2
+
+     def __iter__(self):
+         raise RuntimeError("static_range can only be used in @triton.jit'd functions")
+
+     def __next__(self):
+         raise RuntimeError("static_range can only be used in @triton.jit'd functions")
+
+
+ class async_task:
+     """
+     Context manager to run code fragments asynchronously.
+     """
+
+     def __init__(self, task_ids, _builder=None):
+         self.task_ids = list({_unwrap_if_constexpr(tid) for tid in task_ids})
+         self.builder = _builder
+
+     def __enter__(self):
+         self.builder.set_async_task_ids(self.task_ids)
+
+     def __exit__(self, exc_type, exc_value, traceback):
+         self.builder.unset_async_task_ids()
+
+
+ class range(base_value):
+     """
+     Iterator with Python :code:`range` semantics that also accepts extra compiler hints.
+
+     .. highlight:: python
+     .. code-block:: python
+
+         @triton.jit
+         def kernel(...):
+             for i in tl.range(10, num_stages=3):
+                 ...
+
+     :note: This is a special iterator used to implement similar semantics to Python's :code:`range` in the context of
+         :code:`triton.jit` functions. In addition, it allows the user to pass extra attributes to the compiler.
+     :param arg1: the start value.
+     :param arg2: the end value.
+     :param step: the step value.
+     :param num_stages: pipeline the loop into this many stages (so there are
+         :code:`num_stages` iterations of the loop in flight at once).
+
+         Note this is subtly different than passing :code:`num_stages` as a
+         kernel argument. The kernel argument only pipelines loads that feed
+         into :code:`dot` operations, while this attribute tries to pipeline most
+         (though not all) loads in this loop.
+     :param loop_unroll_factor: Tells the Triton IR level loop unroller how many
+         times to unroll a for loop that this range is used with. A value below
+         2 disables unrolling.
+     :param disallow_acc_multi_buffer: If true, prevent the accumulator of the dot
+         operation in the loop from being multi-buffered, if applicable.
+     :param flatten: automatically flatten the loop nest starting at this loop to
+         create a single flattened loop. The compiler will try to pipeline the
+         flattened loop, which can avoid stage stalling.
+     :param warp_specialize: Enable automatic warp specialization on the loop.
+         The compiler will attempt to partition memory, MMA, and vector
+         operations in the loop into separate async partitions. This will
+         increase the total number of warps required by the kernel.
+
+         Note that warp specialization is only supported on Blackwell GPUs and
+         only works on simple matmul loops. Support for arbitrary loops will be
+         expanded over time.
+     :param disable_licm: Tells the compiler it shouldn't hoist loop-invariant
+         code out of the loop. This is often useful to avoid creating long live
+         ranges within a loop.
+     """
+
+     def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_factor=None,
+                  disallow_acc_multi_buffer=False, flatten=False, warp_specialize=False, disable_licm=False):
+         if step is None:
+             self.step = constexpr(1)
+         else:
+             self.step = step
+         if arg2 is None:
+             self.start = constexpr(0)
+             self.end = arg1
+         else:
+             self.start = arg1
+             self.end = arg2
+         self.num_stages = num_stages
+         self.loop_unroll_factor = loop_unroll_factor
+         self.disallow_acc_multi_buffer = disallow_acc_multi_buffer
+         self.flatten = flatten
+         self.warp_specialize = warp_specialize
+         self.disable_licm = disable_licm
+
+     def __iter__(self):
+         raise RuntimeError("tl.range can only be used in @triton.jit'd functions")
+
+     def __next__(self):
+         raise RuntimeError("tl.range can only be used in @triton.jit'd functions")
+
+
+ class condition(base_value):
+     """
+     While-loop condition wrapper.
+
+     .. highlight:: python
+     .. code-block:: python
+
+         @triton.jit
+         def kernel(...):
+             while tl.condition(c, disable_licm=True):
+                 ...
+
+     :note: This is a special wrapper used to annotate while loops in the context of
+         :code:`triton.jit` functions. It allows the user to pass extra attributes to the compiler.
+     :param disable_licm: Tells the compiler it shouldn't hoist loop-invariant
+         code out of the loop. This is often useful to avoid creating long live
+         ranges within a loop.
+     """
+
+     def __init__(self, arg1, disable_licm=False):
+         self.condition = arg1
+         self.disable_licm = disable_licm
+
+
+ # -----------------------
+ # Extern functions
+ # -----------------------
+
+
+ def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, ret_type: dtype, is_pure: bool,
+              _semantic):
+     '''
+     Dispatch a function to a library.
+
+     :param func: the function to dispatch
+     :param lib_name: the name of the library
+     :param lib_path: the path of the library
+     :param args: the arguments of the function
+     :param arg_type_symbol_dict: a dict mapping a tuple of argument types to a (symbol, return type) pair
+     :param ret_type: the type of the return value
+     :param is_pure: whether the function is pure
+     :return: the return value of the function
+     '''
+     if len(arg_type_symbol_dict) == 0:
+         raise ValueError("arg_type_symbol_dict is empty")
+
+     num_args = len(list(arg_type_symbol_dict.keys())[0])
+     if len(args) != num_args:
+         raise ValueError("length of input args does not match. "
+                          f"Expect {num_args}, got {len(args)}")
+
+     arg_types = []
+     arg_list = []
+     for arg in args:
+         if isinstance(arg, tensor):
+             arg_types.append(arg.dtype)
+             arg_list.append(arg.handle)
+         else:
+             arg_types.append(type(arg))
+             arg_list.append(arg)
+     arg_types = tuple(arg_types)
+
+     if arg_types not in arg_type_symbol_dict:
+         raise ValueError("input arg type does not match. "
+                          f"Expect one of {arg_type_symbol_dict.keys()}, got {arg_types}")
+     else:
+         symbol = arg_type_symbol_dict[arg_types][0]
+         builder = _semantic.builder
+         return tensor(func(lib_name, lib_path, symbol, arg_list, ret_type.to_ir(builder), is_pure), ret_type)
+
+
+ @builtin
+ def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, is_pure: bool,
+                        _semantic=None):
+     '''
+     Dispatch an elementwise function to a library.
+
+     :param lib_name: the name of the library
+     :param lib_path: the path of the library
+     :param args: the arguments of the function
+     :param arg_type_symbol_dict: a dict mapping a tuple of argument types to a (symbol, return type) pair
+     :param is_pure: whether the function is pure
+     :return: the return value of the function
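+
+     An illustrative sketch dispatching :code:`fabs` to libdevice (the path and
+     the symbol names below are assumptions to be adapted to the local install):
+
+     .. highlight:: python
+     .. code-block:: python
+
+         y = tl.extern_elementwise(
+             "libdevice", "/path/to/libdevice.10.bc", [x],
+             {(tl.float32, ): ("__nv_fabsf", tl.float32),
+              (tl.float64, ): ("__nv_fabs", tl.float64)},
+             is_pure=True)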
+     '''
+     dispatch_args = args.copy()
+     all_scalar = True
+     arg_types = []
+     for i in builtins.range(len(dispatch_args)):
+         dispatch_args[i] = _semantic.to_tensor(dispatch_args[i])
+         arg_types.append(dispatch_args[i].dtype)
+         if dispatch_args[i].type.is_block():
+             all_scalar = False
+
+     arg_types = tuple(arg_types)
+     ret_type = arg_type_symbol_dict[arg_types][1]
+     if len(arg_types) > 0:
+         arithmetic_check = True
+         # If the type tuple is not supported by the library, fall back to an arithmetic type check
+         if arg_types in arg_type_symbol_dict:
+             arithmetic_check = False
+         broadcast_arg = dispatch_args[0]
+         # Get the broadcast shape over all the arguments
+         for item in dispatch_args:
+             _, broadcast_arg = _semantic.binary_op_type_checking_impl(item, broadcast_arg,
+                                                                       arithmetic_check=arithmetic_check)
+         # Change the shape of each argument based on the broadcast shape
+         for i in builtins.range(len(dispatch_args)):
+             dispatch_args[i], _ = _semantic.binary_op_type_checking_impl(dispatch_args[i], broadcast_arg,
+                                                                          arithmetic_check=arithmetic_check)
+         if not all_scalar:
+             ret_type = broadcast_arg.type.with_element_ty(ret_type)
+     func = _semantic.builder.create_extern_elementwise
+     return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_type, is_pure, _semantic)
+
+
+ def binary_op_type_legalization(lhs, rhs, semantic):
+     '''
+     Convert both operands to a single common type.
+
+     :param lhs: the left operand
+     :param rhs: the right operand
+     :param semantic: the semantic object used for type checking
+     '''
+     return semantic.binary_op_type_checking_impl(lhs, rhs)
+
+
+ def extern(fn):
+     """A decorator for external functions."""
+     return builtin(fn)