triton-windows 3.2.0.post11__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of triton-windows might be problematic.

Files changed (154)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +85 -0
  3. triton/_internal_testing.py +123 -0
  4. triton/backends/__init__.py +50 -0
  5. triton/backends/amd/compiler.py +368 -0
  6. triton/backends/amd/driver.c +211 -0
  7. triton/backends/amd/driver.py +512 -0
  8. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +358 -0
  9. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +1031 -0
  10. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +1612 -0
  11. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +1337 -0
  12. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +293 -0
  13. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +32 -0
  14. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +174 -0
  15. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +829 -0
  16. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +1809 -0
  17. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +108 -0
  18. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +124 -0
  19. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +405 -0
  20. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +196 -0
  21. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +565 -0
  22. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +2226 -0
  23. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +104 -0
  24. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +244 -0
  25. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +494 -0
  26. triton/backends/amd/include/hip/amd_detail/concepts.hpp +30 -0
  27. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +133 -0
  28. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +218 -0
  29. triton/backends/amd/include/hip/amd_detail/grid_launch.h +67 -0
  30. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +50 -0
  31. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +26 -0
  32. triton/backends/amd/include/hip/amd_detail/helpers.hpp +137 -0
  33. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +1350 -0
  34. triton/backends/amd/include/hip/amd_detail/hip_assert.h +101 -0
  35. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +242 -0
  36. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +254 -0
  37. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +96 -0
  38. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +100 -0
  39. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +10169 -0
  40. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +77 -0
  41. triton/backends/amd/include/hip/amd_detail/host_defines.h +180 -0
  42. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +102 -0
  43. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +798 -0
  44. triton/backends/amd/include/hip/amd_detail/math_fwd.h +698 -0
  45. triton/backends/amd/include/hip/amd_detail/ockl_image.h +177 -0
  46. triton/backends/amd/include/hip/amd_detail/program_state.hpp +107 -0
  47. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +491 -0
  48. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +478 -0
  49. triton/backends/amd/include/hip/channel_descriptor.h +39 -0
  50. triton/backends/amd/include/hip/device_functions.h +38 -0
  51. triton/backends/amd/include/hip/driver_types.h +468 -0
  52. triton/backends/amd/include/hip/hip_bf16.h +36 -0
  53. triton/backends/amd/include/hip/hip_bfloat16.h +44 -0
  54. triton/backends/amd/include/hip/hip_common.h +100 -0
  55. triton/backends/amd/include/hip/hip_complex.h +38 -0
  56. triton/backends/amd/include/hip/hip_cooperative_groups.h +46 -0
  57. triton/backends/amd/include/hip/hip_deprecated.h +95 -0
  58. triton/backends/amd/include/hip/hip_ext.h +159 -0
  59. triton/backends/amd/include/hip/hip_fp16.h +36 -0
  60. triton/backends/amd/include/hip/hip_gl_interop.h +32 -0
  61. triton/backends/amd/include/hip/hip_hcc.h +24 -0
  62. triton/backends/amd/include/hip/hip_math_constants.h +36 -0
  63. triton/backends/amd/include/hip/hip_profile.h +27 -0
  64. triton/backends/amd/include/hip/hip_runtime.h +75 -0
  65. triton/backends/amd/include/hip/hip_runtime_api.h +8919 -0
  66. triton/backends/amd/include/hip/hip_texture_types.h +29 -0
  67. triton/backends/amd/include/hip/hip_vector_types.h +41 -0
  68. triton/backends/amd/include/hip/hip_version.h +17 -0
  69. triton/backends/amd/include/hip/hiprtc.h +421 -0
  70. triton/backends/amd/include/hip/library_types.h +78 -0
  71. triton/backends/amd/include/hip/math_functions.h +42 -0
  72. triton/backends/amd/include/hip/surface_types.h +63 -0
  73. triton/backends/amd/include/hip/texture_types.h +194 -0
  74. triton/backends/amd/include/hsa/Brig.h +1131 -0
  75. triton/backends/amd/include/hsa/amd_hsa_common.h +91 -0
  76. triton/backends/amd/include/hsa/amd_hsa_elf.h +436 -0
  77. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +269 -0
  78. triton/backends/amd/include/hsa/amd_hsa_queue.h +109 -0
  79. triton/backends/amd/include/hsa/amd_hsa_signal.h +80 -0
  80. triton/backends/amd/include/hsa/hsa.h +5729 -0
  81. triton/backends/amd/include/hsa/hsa_amd_tool.h +91 -0
  82. triton/backends/amd/include/hsa/hsa_api_trace.h +566 -0
  83. triton/backends/amd/include/hsa/hsa_ext_amd.h +3090 -0
  84. triton/backends/amd/include/hsa/hsa_ext_finalize.h +531 -0
  85. triton/backends/amd/include/hsa/hsa_ext_image.h +1454 -0
  86. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +488 -0
  87. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +667 -0
  88. triton/backends/amd/include/roctracer/ext/prof_protocol.h +107 -0
  89. triton/backends/amd/include/roctracer/hip_ostream_ops.h +4435 -0
  90. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +1467 -0
  91. triton/backends/amd/include/roctracer/hsa_prof_str.h +3027 -0
  92. triton/backends/amd/include/roctracer/roctracer.h +779 -0
  93. triton/backends/amd/include/roctracer/roctracer_ext.h +81 -0
  94. triton/backends/amd/include/roctracer/roctracer_hcc.h +24 -0
  95. triton/backends/amd/include/roctracer/roctracer_hip.h +37 -0
  96. triton/backends/amd/include/roctracer/roctracer_hsa.h +112 -0
  97. triton/backends/amd/include/roctracer/roctracer_plugin.h +137 -0
  98. triton/backends/amd/include/roctracer/roctracer_roctx.h +67 -0
  99. triton/backends/amd/include/roctracer/roctx.h +229 -0
  100. triton/backends/amd/lib/ockl.bc +0 -0
  101. triton/backends/amd/lib/ocml.bc +0 -0
  102. triton/backends/compiler.py +304 -0
  103. triton/backends/driver.py +48 -0
  104. triton/backends/nvidia/__init__.py +0 -0
  105. triton/backends/nvidia/bin/ptxas.exe +0 -0
  106. triton/backends/nvidia/compiler.py +410 -0
  107. triton/backends/nvidia/driver.c +451 -0
  108. triton/backends/nvidia/driver.py +524 -0
  109. triton/backends/nvidia/include/cuda.h +24359 -0
  110. triton/backends/nvidia/lib/libdevice.10.bc +0 -0
  111. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  112. triton/compiler/__init__.py +4 -0
  113. triton/compiler/code_generator.py +1303 -0
  114. triton/compiler/compiler.py +430 -0
  115. triton/compiler/errors.py +51 -0
  116. triton/compiler/make_launcher.py +0 -0
  117. triton/errors.py +5 -0
  118. triton/language/__init__.py +294 -0
  119. triton/language/_utils.py +21 -0
  120. triton/language/core.py +2694 -0
  121. triton/language/extra/__init__.py +26 -0
  122. triton/language/extra/cuda/__init__.py +13 -0
  123. triton/language/extra/cuda/_experimental_tma.py +108 -0
  124. triton/language/extra/cuda/libdevice.py +1629 -0
  125. triton/language/extra/cuda/utils.py +109 -0
  126. triton/language/extra/hip/__init__.py +3 -0
  127. triton/language/extra/hip/libdevice.py +475 -0
  128. triton/language/extra/libdevice.py +786 -0
  129. triton/language/math.py +250 -0
  130. triton/language/random.py +207 -0
  131. triton/language/semantic.py +1796 -0
  132. triton/language/standard.py +452 -0
  133. triton/runtime/__init__.py +23 -0
  134. triton/runtime/autotuner.py +408 -0
  135. triton/runtime/build.py +111 -0
  136. triton/runtime/cache.py +295 -0
  137. triton/runtime/driver.py +60 -0
  138. triton/runtime/errors.py +26 -0
  139. triton/runtime/interpreter.py +1235 -0
  140. triton/runtime/jit.py +951 -0
  141. triton/testing.py +511 -0
  142. triton/tools/__init__.py +0 -0
  143. triton/tools/build_extern.py +365 -0
  144. triton/tools/compile.c +67 -0
  145. triton/tools/compile.h +14 -0
  146. triton/tools/compile.py +155 -0
  147. triton/tools/disasm.py +144 -0
  148. triton/tools/experimental_descriptor.py +32 -0
  149. triton/tools/link.py +322 -0
  150. triton/windows_utils.py +375 -0
  151. triton_windows-3.2.0.post11.dist-info/METADATA +39 -0
  152. triton_windows-3.2.0.post11.dist-info/RECORD +154 -0
  153. triton_windows-3.2.0.post11.dist-info/WHEEL +5 -0
  154. triton_windows-3.2.0.post11.dist-info/top_level.txt +12 -0
triton/language/core.py
@@ -0,0 +1,2694 @@
+ from __future__ import annotations
+
+ from warnings import warn
+ from contextlib import contextmanager
+ from enum import Enum
+ from functools import partial, wraps
+ import typing
+ from typing import Union, Callable, List, Sequence, TypeVar, Optional
+ import builtins
+ from ..runtime.jit import jit
+ import inspect
+ import os
+
+ from .._C.libtriton import ir
+ from . import semantic
+ from ._utils import TRITON_MAX_TENSOR_NUMEL, validate_block_shape
+
+ T = TypeVar('T')
+
+ TRITON_BUILTIN = "__triton_builtin__"
+
+ PropagateNan = ir.PROPAGATE_NAN
+
+
+ def builtin(fn: T) -> T:
+     """Mark a function as a builtin."""
+     assert callable(fn)
+
+     @wraps(fn)
+     def wrapper(*args, **kwargs):
+         if "_builder" not in kwargs or kwargs["_builder"] is None:
+             raise ValueError("Did you forget to add @triton.jit ? "
+                              "(`_builder` argument must be provided outside of JIT functions.)")
+         return fn(*args, **kwargs)
+
+     setattr(wrapper, TRITON_BUILTIN, True)
+
+     return wrapper
+
+
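The `_builder` guard above is what produces this error whenever a builtin such as tl.arange is called from plain Python rather than from inside a @triton.jit function, where the code generator injects `_builder` automatically. A minimal sketch of both sides of the guard (hypothetical kernel name; assumes a working triton install):

import triton
import triton.language as tl

# Outside a JIT function no _builder is supplied, so the wrapper raises:
try:
    tl.arange(0, 16)
except ValueError as e:
    print(e)  # "Did you forget to add @triton.jit ? ..."

@triton.jit
def ok_kernel(out_ptr):
    # Inside a JIT function the compiler passes _builder for us.
    offs = tl.arange(0, 16)
    tl.store(out_ptr + offs, offs)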
+ def _tensor_member_fn(fn: T) -> T:
+     """Decorator that adds this free function as a member fn on class tensor.
+
+     When called as a member function on class tensor, the first argument to `fn`
+     is `self`, i.e. the tensor object.
+
+     If there are multiple decorators on a function, you probably want this one
+     to be the highest one (i.e. furthest from the function's `def`), so it's
+     applied last.
+
+     Unfortunately you still need to add a type stub to the body of class tensor
+     in order for pytype to know about it.
+     """
+     assert callable(fn)
+     orig_sig = inspect.signature(fn)
+     # Does fn take args other than _builder, _generator, and the tensor itself?
+     has_args = len(orig_sig.parameters.keys() - {"_builder", "_generator"}) > 1
+
+     if not fn.__doc__:
+         fn.__doc__ = ""
+     fn.__doc__ += f"""
+     This function can also be called as a member function on :py:class:`tensor`,
+     as :code:`x.{fn.__name__}({"..." if has_args else ""})` instead of
+     :code:`{fn.__name__}(x{", ..." if has_args else ""})`.
+     """
+
+     def wrapper(*args, **kwargs):
+         return fn(*args, **kwargs)
+
+     # Match the signature of `fn`, but change the first arg to `self` so the
+     # docs are a little less weird.
+     new_params = list(orig_sig.parameters.values())
+     new_params[0] = new_params[0].replace(name='self')
+     new_sig = orig_sig.replace(parameters=new_params)
+     wrapper.__signature__ = new_sig
+     wrapper.__doc__ = f"Forwards to :py:func:`{fn.__name__}` free function"
+     # If fn is a builtin, mark the wrapper as a builtin too.
+     if is_builtin(fn):
+         setattr(wrapper, TRITON_BUILTIN, True)
+
+     setattr(tensor, fn.__name__, wrapper)
+     return fn
+
+
+ def _unwrap_iterable(x):
+     """Returns x[0] if x has one element and x[0] is iterable; otherwise returns x unchanged."""
+     if len(x) == 1:
+         # Determine whether x[0] is iterable.
+         #
+         # You might want to use collections.abc.Iterable instead of this
+         # try/except block.  Unfortunately, this doesn't work with constexpr.
+         #
+         # The problem is that abc.Iterable checks for __iter__ on the *class*.
+         # But we want constexpr to expose an __iter__ method if and only if the
+         # wrapped *object* (i.e. self.value) is iterable.  Therefore there's no
+         # right answer for whether the class constexpr defines __iter__, and
+         # abc.Iterable doesn't work (at least not without some metaclass magic).
+         try:
+             iter(x[0])
+             return x[0]
+         except TypeError:
+             pass
+
+     return x
+
+
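This helper is what lets shape-taking builtins like broadcast_to and reshape accept a shape either as one tuple or as individual arguments. A small host-side illustration of its behavior on a captured *args tuple (it is a private helper, imported from triton.language.core purely for demonstration):

from triton.language.core import _unwrap_iterable

assert _unwrap_iterable(((32, 64),)) == (32, 64)  # single iterable arg: unwrap it
assert _unwrap_iterable((32, 64)) == (32, 64)     # two scalar args: already a shape
assert _unwrap_iterable((32,)) == (32,)           # single non-iterable arg: left alone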
+ def is_builtin(fn) -> bool:
+     """Is this a registered triton builtin function?"""
+     return getattr(fn, TRITON_BUILTIN, False)
+
+
+ @builtin
+ def to_tensor(x, _builder=None):
+     return semantic.to_tensor(x, _builder)
+
+
+ # -----------------------
+ # constexpr
+ # -----------------------
+
+
+ class const:
+     """
+     This class is used as a type annotation to mark pointers to constant data.
+     The `store` function cannot be called with a pointer to const.  Constness
+     is part of the pointer type, and the usual Triton type consistency rules
+     apply.  For example, you cannot have a function that returns a constant
+     pointer in one return statement and a non-constant pointer in another.
+     """
+     pass
+
+
+ class constexpr:
+     """
+     This class is used to store a value that is known at compile-time.
+     """
+
+     def __init__(self, value):
+         if isinstance(value, constexpr):
+             self.value = value.value
+         else:
+             self.value = value
+
+     def __repr__(self) -> str:
+         return f"constexpr[{self.value}]"
+
+     def __index__(self):
+         return self.value
+
+     # In interpreter mode, constant values are not wrapped in constexpr,
+     # and therefore do not have a .value attribute.
+     # As a result, from here and below, we need to call the _constexpr_to_value
+     # function to obtain either constexpr.value or the value itself.
+     def __add__(self, other):
+         return constexpr(self.value + _constexpr_to_value(other))
+
+     def __radd__(self, other):
+         return constexpr(_constexpr_to_value(other) + self.value)
+
+     def __sub__(self, other):
+         return constexpr(self.value - _constexpr_to_value(other))
+
+     def __rsub__(self, other):
+         return constexpr(_constexpr_to_value(other) - self.value)
+
+     def __mul__(self, other):
+         return constexpr(self.value * _constexpr_to_value(other))
+
+     def __mod__(self, other):
+         return constexpr(self.value % _constexpr_to_value(other))
+
+     def __rmul__(self, other):
+         return constexpr(_constexpr_to_value(other) * self.value)
+
+     def __truediv__(self, other):
+         return constexpr(self.value / _constexpr_to_value(other))
+
+     def __rtruediv__(self, other):
+         return constexpr(_constexpr_to_value(other) / self.value)
+
+     def __floordiv__(self, other):
+         return constexpr(self.value // _constexpr_to_value(other))
+
+     def __rfloordiv__(self, other):
+         return constexpr(_constexpr_to_value(other) // self.value)
+
+     def __gt__(self, other):
+         return constexpr(self.value > _constexpr_to_value(other))
+
+     def __rgt__(self, other):
+         return constexpr(_constexpr_to_value(other) > self.value)
+
+     def __ge__(self, other):
+         return constexpr(self.value >= _constexpr_to_value(other))
+
+     def __rge__(self, other):
+         return constexpr(_constexpr_to_value(other) >= self.value)
+
+     def __lt__(self, other):
+         return constexpr(self.value < _constexpr_to_value(other))
+
+     def __rlt__(self, other):
+         return constexpr(_constexpr_to_value(other) < self.value)
+
+     def __le__(self, other):
+         return constexpr(self.value <= _constexpr_to_value(other))
+
+     def __rle__(self, other):
+         return constexpr(_constexpr_to_value(other) <= self.value)
+
+     def __eq__(self, other):
+         return constexpr(self.value == _constexpr_to_value(other))
+
+     def __ne__(self, other):
+         return constexpr(self.value != _constexpr_to_value(other))
+
+     def __bool__(self):
+         return bool(self.value)
+
+     def __neg__(self):
+         return constexpr(-self.value)
+
+     def __and__(self, other):
+         return constexpr(self.value & _constexpr_to_value(other))
+
+     def logical_and(self, other):
+         return constexpr(self.value and _constexpr_to_value(other))
+
+     def __or__(self, other):
+         return constexpr(self.value | _constexpr_to_value(other))
+
+     def __xor__(self, other):
+         return constexpr(self.value ^ _constexpr_to_value(other))
+
+     def logical_or(self, other):
+         return constexpr(self.value or _constexpr_to_value(other))
+
+     def __pos__(self):
+         return constexpr(+self.value)
+
+     def __invert__(self):
+         return constexpr(~self.value)
+
+     def __pow__(self, other):
+         return constexpr(self.value**_constexpr_to_value(other))
+
+     def __rpow__(self, other):
+         return constexpr(_constexpr_to_value(other)**self.value)
+
+     def __rshift__(self, other):
+         return constexpr(self.value >> _constexpr_to_value(other))
+
+     def __lshift__(self, other):
+         return constexpr(self.value << _constexpr_to_value(other))
+
+     def __not__(self):
+         return constexpr(not self.value)
+
+     def __iter__(self):
+         return iter(self.value)
+
+     def __call__(self, *args, **kwds):
+         return self.value(*args, **kwds)
+
+
+ CONSTEXPR_0 = constexpr(0)
+
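Because every operator on constexpr immediately computes on the wrapped Python value, chains of compile-time arithmetic fold eagerly into new constexpr objects. A quick host-side sketch:

from triton.language import constexpr

a = constexpr(6)
b = a * 7 + 2                             # __mul__/__add__ fold eagerly
assert isinstance(b, constexpr) and b.value == 44
assert repr(b) == "constexpr[44]"
assert not constexpr(0)                   # __bool__ unwraps to the wrapped value
assert list(constexpr((1, 2))) == [1, 2]  # __iter__ delegates to the value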
+
+ def _unwrap_if_constexpr(o):
+     return o.value if isinstance(o, constexpr) else o
+
+
+ def check_bit_width(value, shift_value):
+     if isinstance(value, tensor) and isinstance(shift_value, constexpr):
+         bitwidth = value.type.scalar.primitive_bitwidth
+         if shift_value.value >= bitwidth:
+             warn(
+                 f"Shift value {shift_value.value} is greater than or equal to the bitwidth ({bitwidth}) of type '{value.dtype}'. This may result in undefined behavior."
+             )
+
+
+ # -----------------------
+ # dtype
+ # -----------------------
+
+
+ class dtype:
+     SINT_TYPES = ['int8', 'int16', 'int32', 'int64']
+     UINT_TYPES = ['int1', 'uint8', 'uint16', 'uint32', 'uint64']
+     FP_TYPES = ['fp8e4b15', 'fp8e4nv', 'fp8e4b8', 'fp8e5', 'fp8e5b16', 'fp16', 'bf16', 'fp32', 'fp64']
+     STANDARD_FP_TYPES = ['fp16', 'bf16', 'fp32', 'fp64']
+     OTHER_TYPES = ['void']
+
+     class SIGNEDNESS(Enum):
+         SIGNED = 0
+         UNSIGNED = 1
+
+     class KIND(Enum):
+         BOOLEAN = 0
+         INTEGRAL = 1
+         FLOATING = 2
+
+     def __init__(self, name):
+         name = _unwrap_if_constexpr(name)
+         self.name = name
+         assert name in dtype.SINT_TYPES + dtype.UINT_TYPES + dtype.FP_TYPES + dtype.OTHER_TYPES, name
+         if name in dtype.SINT_TYPES:
+             self.int_signedness = dtype.SIGNEDNESS.SIGNED
+             self.int_bitwidth = int(name.split('int')[-1])
+             self.primitive_bitwidth = self.int_bitwidth
+         elif name in dtype.UINT_TYPES:
+             self.int_signedness = dtype.SIGNEDNESS.UNSIGNED
+             self.int_bitwidth = int(name.split('int')[-1])
+             self.primitive_bitwidth = self.int_bitwidth
+         elif name in dtype.FP_TYPES:
+             if name == 'fp8e4b15':
+                 self.fp_mantissa_width = 3
+                 self.primitive_bitwidth = 8
+                 self.exponent_bias = 15
+             elif name == 'fp8e4nv':
+                 self.fp_mantissa_width = 3
+                 self.primitive_bitwidth = 8
+                 self.exponent_bias = 7
+             elif name == 'fp8e4b8':
+                 self.fp_mantissa_width = 3
+                 self.primitive_bitwidth = 8
+                 self.exponent_bias = 8
+             elif name == 'fp8e5':
+                 self.fp_mantissa_width = 2
+                 self.primitive_bitwidth = 8
+                 self.exponent_bias = 15
+             elif name == 'fp8e5b16':
+                 self.fp_mantissa_width = 2
+                 self.primitive_bitwidth = 8
+                 self.exponent_bias = 16
+             elif name == 'fp16':
+                 self.fp_mantissa_width = 10
+                 self.primitive_bitwidth = 16
+                 self.exponent_bias = 15
+             elif name == 'bf16':
+                 self.fp_mantissa_width = 7
+                 self.primitive_bitwidth = 16
+                 self.exponent_bias = 127
+             elif name == 'fp32':
+                 self.fp_mantissa_width = 23
+                 self.primitive_bitwidth = 32
+                 self.exponent_bias = 127
+             elif name == 'fp64':
+                 self.fp_mantissa_width = 52
+                 self.primitive_bitwidth = 64
+                 self.exponent_bias = 1023
+             else:
+                 raise RuntimeError(f'Unsupported floating-point type {name}')
+         elif name == 'void':
+             self.primitive_bitwidth = 0
+
+     def is_fp8(self):
+         return 'fp8' in self.name
+
+     def is_fp8e4nv(self):
+         return self.name == 'fp8e4nv'
+
+     def is_fp8e4b8(self):
+         return self.name == 'fp8e4b8'
+
+     def is_fp8e4b15(self):
+         return self.name == 'fp8e4b15'
+
+     def is_fp8e5(self):
+         return self.name == 'fp8e5'
+
+     def is_fp8e5b16(self):
+         return self.name == 'fp8e5b16'
+
+     def is_fp16(self):
+         return self.name == 'fp16'
+
+     def is_bf16(self):
+         return self.name == 'bf16'
+
+     def is_fp32(self):
+         return self.name == 'fp32'
+
+     def is_fp64(self):
+         return self.name == 'fp64'
+
+     def is_int1(self):
+         return self.name == 'int1'
+
+     def is_int8(self):
+         return self.name == 'int8'
+
+     def is_int16(self):
+         return self.name == 'int16'
+
+     def is_int32(self):
+         return self.name == 'int32'
+
+     def is_int64(self):
+         return self.name == 'int64'
+
+     def is_uint8(self):
+         return self.name == 'uint8'
+
+     def is_uint16(self):
+         return self.name == 'uint16'
+
+     def is_uint32(self):
+         return self.name == 'uint32'
+
+     def is_uint64(self):
+         return self.name == 'uint64'
+
+     def is_floating(self):
+         return self.name in dtype.FP_TYPES
+
+     def is_standard_floating(self):
+         return self.name in dtype.STANDARD_FP_TYPES
+
+     def is_int_signed(self):
+         return self.name in dtype.SINT_TYPES
+
+     def is_int_unsigned(self):
+         return self.name in dtype.UINT_TYPES
+
+     def is_int(self):
+         return self.name in dtype.SINT_TYPES + dtype.UINT_TYPES
+
+     def is_bool(self):
+         return self.is_int1()
+
+     def kind(self):
+         # Return the KIND enum member, following the type ordering bool < integer < fp.
+         if self.is_bool():
+             return dtype.KIND.BOOLEAN
+         elif self.is_int():
+             return dtype.KIND.INTEGRAL
+         else:
+             assert self.is_floating()
+             return dtype.KIND.FLOATING
+
+     def get_int_max_value(self):
+         if self.is_int_signed():
+             return 2**(self.int_bitwidth - 1) - 1
+         if self.is_int_unsigned():
+             return 2**self.int_bitwidth - 1
+         assert False
+
+     def get_int_min_value(self):
+         if self.is_int_signed():
+             return -2**(self.int_bitwidth - 1)
+         if self.is_int_unsigned():
+             return 0
+         assert False
+
+     @staticmethod
+     def is_dtype(type_str):
+         return type_str in dtype.SINT_TYPES + dtype.UINT_TYPES + dtype.FP_TYPES + dtype.OTHER_TYPES
+
+     @staticmethod
+     def is_void():
+         raise RuntimeError("Not implemented")
+
+     @staticmethod
+     def is_block():
+         return False
+
+     @staticmethod
+     def is_ptr():
+         return False
+
+     @staticmethod
+     def is_const():
+         return False
+
+     def __eq__(self, other: dtype):
+         if not isinstance(other, dtype):
+             return False
+         return self.name == other.name
+
+     def __ne__(self, other: dtype):
+         return not self.__eq__(other)
+
+     def __hash__(self):
+         return hash((self.name, ))
+
+     @property
+     def scalar(self):
+         return self
+
+     def to_ir(self, builder: ir.builder) -> ir.type:
+         if self.name.startswith("fp8"):
+             if self.name not in builder.options.supported_fp8_dtypes:
+                 raise ValueError(f'type {self} not supported in this architecture. '
+                                  f'The supported fp8 dtypes are {builder.options.supported_fp8_dtypes}')
+             if self.name in builder.options.deprecated_fp8_dtypes:
+                 warn(f"{self.name} is deprecated in this architecture and will be removed in a future triton release")
+
+         if self.name == 'void':
+             return builder.get_void_ty()
+         elif self.name == 'int1':
+             return builder.get_int1_ty()
+         elif self.name in ('int8', 'uint8'):
+             return builder.get_int8_ty()
+         elif self.name in ('int16', 'uint16'):
+             return builder.get_int16_ty()
+         elif self.name in ('int32', 'uint32'):
+             return builder.get_int32_ty()
+         elif self.name in ('int64', 'uint64'):
+             return builder.get_int64_ty()
+         elif self.name == 'fp8e5':
+             return builder.get_fp8e5_ty()
+         elif self.name == 'fp8e5b16':
+             return builder.get_fp8e5b16_ty()
+         elif self.name == 'fp8e4nv':
+             return builder.get_fp8e4nv_ty()
+         elif self.name == 'fp8e4b8':
+             return builder.get_fp8e4b8_ty()
+         elif self.name == 'fp8e4b15':
+             return builder.get_fp8e4b15_ty()
+         elif self.name == 'fp16':
+             return builder.get_half_ty()
+         elif self.name == 'bf16':
+             return builder.get_bf16_ty()
+         elif self.name == 'fp32':
+             return builder.get_float_ty()
+         elif self.name == 'fp64':
+             return builder.get_double_ty()
+         raise ValueError(f'failed to convert {self} to an ir type')
+
+     def __str__(self):
+         return self.name
+
+     def codegen_name(self):
+         if self.name.startswith("fp"):
+             return "float" + self.name[2:]
+         elif self.name.startswith("bf"):
+             return "bfloat" + self.name[2:]
+         else:
+             return self.name
+
+     @property
+     def cache_key_part(self) -> str:
+         """See cache_key_part() in triton.cc."""
+         return self.name
+
+     def __repr__(self):
+         """Output of repr needs to be an evaluatable expression"""
+         return f'triton.language.{self.codegen_name()}'
+
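The dtype metadata above (bitwidths, mantissa widths, exponent biases, integer ranges) is queryable on the module-level singletons defined later in this file. A few host-side spot checks:

import triton.language as tl

assert tl.float16.primitive_bitwidth == 16 and tl.float16.fp_mantissa_width == 10
assert tl.bfloat16.exponent_bias == 127          # bf16 shares fp32's exponent range
assert tl.int8.get_int_min_value() == -128 and tl.int8.get_int_max_value() == 127
assert tl.uint32.is_int_unsigned() and not tl.uint32.is_floating()
assert repr(tl.float32) == "triton.language.float32"  # repr is an evaluatable expression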
+
+ # Some functions have a param named `dtype`, which shadows the `dtype` class.
+ # We can't change the param name because it is part of the function's public API.
+ # Declare an alias so those functions can still reference the dtype class.
+ _DtypeClass = dtype
+
+
+ class pointer_type(dtype):
+
+     def __init__(self, element_ty: dtype, address_space: int = 1, const: bool = False):
+         element_ty = _unwrap_if_constexpr(element_ty)
+         if not isinstance(element_ty, dtype):
+             raise TypeError(f'element_ty has type `{type(element_ty).__name__}`; expected `dtype`.')
+         self.element_ty = element_ty
+         self.address_space = address_space
+         self.const = const
+         self.name = f'pointer<{element_ty}>' if not const else f'const_pointer<{element_ty}>'
+
+     def to_ir(self, builder: ir.builder) -> ir.pointer_type:
+         return builder.get_ptr_ty(self.element_ty.to_ir(builder), self.address_space)
+
+     def __str__(self):
+         return self.name
+
+     def __repr__(self):
+         return self.__str__()
+
+     def is_ptr(self):
+         return True
+
+     def is_const(self):
+         return self.const
+
+     def __eq__(self, other: pointer_type) -> bool:
+         if not isinstance(other, pointer_type):
+             return False
+         return self.element_ty == other.element_ty and self.address_space == other.address_space and self.const == other.const
+
+     def __ne__(self, other: pointer_type) -> bool:
+         return not self.__eq__(other)
+
+     @property
+     def scalar(self):
+         return self
+
+
+ class nv_tma_desc_type(pointer_type):
+
+     def __init__(self, const=True, address_space=0):
+         super().__init__(uint8, const=const, address_space=address_space)
+         self.name = 'nv_tma_desc_type'
+
+
+ class block_type(dtype):
+
+     def __init__(self, element_ty: dtype, shape: List):
+         self.element_ty = element_ty
+
+         # Note that block_type's shape is a list of int
+         # while tensor's shape is a list of constexpr.
+
+         # shape can be empty ([]) when an input is a 0D tensor.
+         self.shape = _unwrap_shape(shape)
+         if not self.shape:
+             raise TypeError('0d block_type is forbidden')
+
+         self.numel = validate_block_shape(self.shape)
+         self.name = f'<{self.shape}, {self.element_ty}>'
+
+     def to_ir(self, builder: ir.builder) -> ir.block_type:
+         return builder.get_block_ty(self.element_ty.to_ir(builder), self.shape)
+
+     def __str__(self):
+         return self.name
+
+     def __repr__(self):
+         return self.__str__()
+
+     def is_block(self):
+         return True
+
+     def get_block_shapes(self) -> List[int]:
+         return self.shape
+
+     def __eq__(self, other: block_type) -> bool:
+         if not isinstance(other, block_type):
+             return False
+         return self.element_ty == other.element_ty and self.shape == other.shape
+
+     def __ne__(self, other: block_type) -> bool:
+         return not self.__eq__(other)
+
+     @property
+     def scalar(self):
+         return self.element_ty
+
+
+ class function_type(dtype):
+
+     def __init__(self, ret_types: List[dtype], param_types: List[dtype]) -> None:
+         self.ret_types = ret_types
+         self.param_types = param_types
+
+     def __str__(self):
+         return f'fn ({self.param_types}) -> {self.ret_types}'
+
+     def to_ir(self, builder: ir.builder):
+         ir_param_types = [ty.to_ir(builder) for ty in self.param_types]
+         ret_types = [ret_type.to_ir(builder) for ret_type in self.ret_types]
+         return builder.get_function_ty(ir_param_types, ret_types)
+
+
+ # scalar types
+ void = dtype('void')
+ int1 = dtype('int1')
+ int8 = dtype('int8')
+ int16 = dtype('int16')
+ int32 = dtype('int32')
+ int64 = dtype('int64')
+ uint8 = dtype('uint8')
+ uint16 = dtype('uint16')
+ uint32 = dtype('uint32')
+ uint64 = dtype('uint64')
+ float8e5 = dtype('fp8e5')
+ float8e5b16 = dtype('fp8e5b16')
+ float8e4nv = dtype('fp8e4nv')
+ float8e4b8 = dtype('fp8e4b8')
+ float8e4b15 = dtype('fp8e4b15')
+ float16 = dtype('fp16')
+ bfloat16 = dtype('bf16')
+ float32 = dtype('fp32')
+ float64 = dtype('fp64')
+ # pointer types
+ pi32_t = pointer_type(int32)
+
+
+ def get_int_dtype(bitwidth: int, signed: bool) -> dtype:
+     if bitwidth == 1:
+         return int1
+     elif bitwidth == 8 and signed:
+         return int8
+     elif bitwidth == 8 and not signed:
+         return uint8
+     elif bitwidth == 16 and signed:
+         return int16
+     elif bitwidth == 16 and not signed:
+         return uint16
+     elif bitwidth == 32 and signed:
+         return int32
+     elif bitwidth == 32 and not signed:
+         return uint32
+     elif bitwidth == 64 and signed:
+         return int64
+     elif bitwidth == 64 and not signed:
+         return uint64
+     else:
+         raise ValueError(f'Unsupported bitwidth {bitwidth} and signedness {signed}')
+
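A short sketch of how these type constructors compose. It assumes pointer_type and block_type are re-exported under triton.language (as they are in mainline Triton) and that get_int_dtype is importable from triton.language.core, where it is defined above:

import triton.language as tl
from triton.language.core import get_int_dtype

assert get_int_dtype(bitwidth=32, signed=False) == tl.uint32
p = tl.pointer_type(tl.float16)            # name: pointer<fp16>
blk = tl.block_type(tl.float32, [16, 16])  # name: <[16, 16], fp32>
assert p.is_ptr() and not p.is_const()
assert blk.is_block() and blk.numel == 256 and blk.scalar == tl.float32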
+
+ class _value:
+     """Base class of values that exist in the triton IR (i.e. not constexprs).
+     """
+
+     def __init__(self, handle):
+         self.handle = handle
+
+
+ # -----------------------
+ # tensor
+ # -----------------------
+
+
+ class tensor(_value):
+     """Represents an N-dimensional array of values or pointers.
+
+     :code:`tensor` is the fundamental data structure in Triton programs.  Most
+     functions in :py:mod:`triton.language` operate on and return tensors.
+
+     Most of the named member functions here are duplicates of the free functions
+     in :code:`triton.language`.  For example, :code:`triton.language.sqrt(x)` is
+     equivalent to :code:`x.sqrt()`.
+
+     :code:`tensor` also defines most of the magic/dunder methods, so you can
+     write :code:`x+y`, :code:`x << 2`, etc.
+
+     .. rubric:: Constructors
+     ..
+        For some reason Sphinx includes __init__ before printing the full table
+        of methods.  Not what I want, but I can't figure out how to fix it.  Give
+        it its own section so it looks intentional. :)
+     """
+
+     def __init__(self, handle, type: dtype):
+         """Not called by user code."""
+         # IR handle
+         super().__init__(handle)
+         # Block shape
+         self.shape = type.shape if type.is_block() else ()
+         self.numel = 1
+         for s in self.shape:
+             self.numel *= s
+         self.numel = constexpr(self.numel)
+         self.type = type  # Tensor type (can be block_type)
+         # Following the practice in pytorch, dtype is scalar type
+         self.dtype = type.scalar
+         self.shape = [constexpr(s) for s in self.shape]
+
+     def __str__(self) -> str:
+         # ex. "float32[16, 32]"
+         return str(self.dtype) + '[' + ', '.join(str(s) for s in self.shape) + ']'
+
+     @builtin
+     def __add__(self, other, _builder=None):
+         return add(self, other, sanitize_overflow=True, _builder=_builder)
+
+     @builtin
+     def __radd__(self, other, _builder=None):
+         return add(other, self, sanitize_overflow=True, _builder=_builder)
+
+     @builtin
+     def __sub__(self, other, _builder=None):
+         return sub(self, other, sanitize_overflow=True, _builder=_builder)
+
+     @builtin
+     def __rsub__(self, other, _builder=None):
+         return sub(other, self, sanitize_overflow=True, _builder=_builder)
+
+     @builtin
+     def __mul__(self, other, _builder=None):
+         return mul(self, other, sanitize_overflow=True, _builder=_builder)
+
+     @builtin
+     def __rmul__(self, other, _builder=None):
+         return mul(other, self, sanitize_overflow=True, _builder=_builder)
+
+     @builtin
+     def __truediv__(self, other, _builder=None):
+         other = _unwrap_if_constexpr(other)
+         return semantic.truediv(self, other, _builder)
+
+     @builtin
+     def __rtruediv__(self, other, _builder=None):
+         other = _unwrap_if_constexpr(other)
+         return semantic.truediv(other, self, _builder)
+
+     @builtin
+     def __floordiv__(self, other, _builder=None):
+         other = _unwrap_if_constexpr(other)
+         return semantic.floordiv(self, other, _builder)
+
+     @builtin
+     def __rfloordiv__(self, other, _builder=None):
+         other = _unwrap_if_constexpr(other)
+         return semantic.floordiv(other, self, _builder)
+
+     @builtin
+     def __mod__(self, other, _builder=None):
+         other = _unwrap_if_constexpr(other)
+         return semantic.mod(self, other, _builder)
+
+     @builtin
+     def __rmod__(self, other, _builder=None):
+         other = _unwrap_if_constexpr(other)
+         return semantic.mod(other, self, _builder)
+
+     # unary operators
+     @builtin
+     def __neg__(self, _builder=None):
+         return semantic.minus(self, _builder)
+
+     @builtin
+     def __invert__(self, _builder=None):
+         return semantic.invert(self, _builder)
+
+     # bitwise operators
+
+     @builtin
+     def __and__(self, other, _builder=None):
+         other = _unwrap_if_constexpr(other)
+         return semantic.and_(self, other, _builder)
+
+     @builtin
+     def __rand__(self, other, _builder=None):
+         other = _unwrap_if_constexpr(other)
+         return semantic.and_(other, self, _builder)
+
+     @builtin
+     def __or__(self, other, _builder=None):
+         other = _unwrap_if_constexpr(other)
+         return semantic.or_(self, other, _builder)
+
+     @builtin
+     def __ror__(self, other, _builder=None):
+         other = _unwrap_if_constexpr(other)
+         return semantic.or_(other, self, _builder)
+
+     @builtin
+     def __xor__(self, other, _builder=None):
+         other = _unwrap_if_constexpr(other)
+         return semantic.xor_(self, other, _builder)
+
+     @builtin
+     def __rxor__(self, other, _builder=None):
+         other = _unwrap_if_constexpr(other)
+         return semantic.xor_(other, self, _builder)
+
+     @builtin
+     def __lshift__(self, other, _builder=None):
+         check_bit_width(self, other)
+         other = _unwrap_if_constexpr(other)
+         return semantic.shl(self, other, _builder)
+
+     @builtin
+     def __rlshift__(self, other, _builder=None):
+         check_bit_width(other, self)
+         other = _unwrap_if_constexpr(other)
+         return semantic.shl(other, self, _builder)
+
+     @builtin
+     def __rshift__(self, other, _builder=None):
+         check_bit_width(self, other)
+         other = _unwrap_if_constexpr(other)
+         if self.dtype.is_int_signed():
+             return semantic.ashr(self, other, _builder)
+         else:
+             return semantic.lshr(self, other, _builder)
+
+     @builtin
+     def __rrshift__(self, other, _builder=None):
+         check_bit_width(other, self)
+         other = _unwrap_if_constexpr(other)
+         if self.dtype.is_int_signed():
+             return semantic.ashr(other, self, _builder)
+         else:
+             return semantic.lshr(other, self, _builder)
+
+     # >
+     @builtin
+     def __gt__(self, other, _builder=None):
+         other = semantic.to_tensor(other, _builder)
+         return semantic.greater_than(self, other, _builder)
+
+     @builtin
+     def __rgt__(self, other, _builder=None):
+         other = semantic.to_tensor(other, _builder)
+         return semantic.greater_than(other, self, _builder)
+
+     # >=
+     @builtin
+     def __ge__(self, other, _builder=None):
+         other = semantic.to_tensor(other, _builder)
+         return semantic.greater_equal(self, other, _builder)
+
+     @builtin
+     def __rge__(self, other, _builder=None):
+         other = semantic.to_tensor(other, _builder)
+         return semantic.greater_equal(other, self, _builder)
+
+     # <
+     @builtin
+     def __lt__(self, other, _builder=None):
+         other = semantic.to_tensor(other, _builder)
+         return semantic.less_than(self, other, _builder)
+
+     @builtin
+     def __rlt__(self, other, _builder=None):
+         other = semantic.to_tensor(other, _builder)
+         return semantic.less_than(other, self, _builder)
+
+     # <=
+     @builtin
+     def __le__(self, other, _builder=None):
+         other = semantic.to_tensor(other, _builder)
+         return semantic.less_equal(self, other, _builder)
+
+     @builtin
+     def __rle__(self, other, _builder=None):
+         other = semantic.to_tensor(other, _builder)
+         return semantic.less_equal(other, self, _builder)
+
+     # ==
+     @builtin
+     def __eq__(self, other, _builder=None):
+         other = semantic.to_tensor(other, _builder)
+         return semantic.equal(self, other, _builder)
+
+     @builtin
+     def __req__(self, other, _builder=None):
+         other = semantic.to_tensor(other, _builder)
+         return semantic.equal(other, self, _builder)
+
+     @builtin
+     def __ne__(self, other, _builder=None):
+         other = semantic.to_tensor(other, _builder)
+         return semantic.not_equal(self, other, _builder)
+
+     @builtin
+     def __rne__(self, other, _builder=None):
+         other = semantic.to_tensor(other, _builder)
+         return semantic.not_equal(other, self, _builder)
+
+     @builtin
+     def logical_and(self, other, _builder=None):
+         other = semantic.to_tensor(other, _builder)
+         return semantic.logical_and(self, other, _builder)
+
+     @builtin
+     def logical_or(self, other, _builder=None):
+         other = semantic.to_tensor(other, _builder)
+         return semantic.logical_or(self, other, _builder)
+
+     # note: __not__ isn't actually a magic method in python
+     # but it's ok because our ASTVisitor handles it
+     @builtin
+     def __not__(self, _builder=None):
+         return semantic.not_(self, _builder)
+
+     @builtin
+     def __getitem__(self, slices, _builder=None):
+         if isinstance(slices, (slice, constexpr)) or slices is None:
+             slices = [slices]
+         ret = self
+         for dim, sl in enumerate(slices):
+             if sl is None or isinstance(sl, constexpr) and sl.value is None:
+                 ret = semantic.expand_dims(ret, dim, _builder)
+             elif isinstance(sl, slice) and sl.start is None and sl.stop is None and sl.step is None:
+                 pass
+             else:
+                 raise ValueError(f"unsupported tensor index: {sl}")
+         return ret
+
+     @property
+     def T(self):
+         """Transposes a 2D tensor."""
+         assert False, "Transposition must be created by the AST Visitor"
+
+     @builtin
+     def to(self, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: bool = False, _builder=None):
+         """
+         Alias for :py:func:`tensor.cast`.
+         """
+         # Triton doesn't like core functions calling other core functions, so we
+         # just copy-paste the implementation of cast here.  It's not too bad.
+         dtype = _unwrap_if_constexpr(dtype)
+         bitcast = _unwrap_if_constexpr(bitcast)
+         if bitcast:
+             return semantic.bitcast(self, dtype, _builder)
+         return semantic.cast(self, dtype, _builder, fp_downcast_rounding)
+
+     # Type stubs for functions added by the _tensor_member_fn decorator.
+     # (Unfortunately these can't be created automatically.)
+     #
+     # We couldn't write these definitions out even if we wanted to, because some
+     # of these functions are defined in standard.py.
+     def broadcast_to(self, *shape) -> tensor:
+         ...
+
+     def trans(self, *dims) -> tensor:
+         ...
+
+     def permute(self, *dims) -> tensor:
+         ...
+
+     def split(self) -> tuple[tensor, tensor]:
+         ...
+
+     def view(self, *shape) -> tensor:
+         ...
+
+     def reshape(self, *shape) -> tensor:
+         ...
+
+     def expand_dims(self, axis) -> tensor:
+         ...
+
+     def cast(self, dtype, fp_downcast_rounding=None, bitcast=False) -> tensor:
+         ...
+
+     def store(self, value, mask=None, boundary_check=(), cache_modifier="", eviction_policy="") -> tensor:
+         ...
+
+     def advance(self, offsets) -> tensor:
+         ...
+
+     def atomic_cas(self, cmp, val, sem=None, scope=None) -> tensor:
+         ...
+
+     def atomic_xchg(self, val, mask=None, sem=None, scope=None) -> tensor:
+         ...
+
+     def atomic_add(self, val, mask=None, sem=None, scope=None) -> tensor:
+         ...
+
+     def atomic_max(self, val, mask=None, sem=None, scope=None) -> tensor:
+         ...
+
+     def atomic_min(self, val, mask=None, sem=None, scope=None) -> tensor:
+         ...
+
+     def atomic_and(self, val, mask=None, sem=None, scope=None) -> tensor:
+         ...
+
+     def atomic_or(self, val, mask=None, sem=None, scope=None) -> tensor:
+         ...
+
+     def atomic_xor(self, val, mask=None, sem=None, scope=None) -> tensor:
+         ...
+
+     def exp(self) -> tensor:
+         ...
+
+     def log(self) -> tensor:
+         ...
+
+     def cos(self) -> tensor:
+         ...
+
+     def sin(self) -> tensor:
+         ...
+
+     def sqrt(self) -> tensor:
+         ...
+
+     def rsqrt(self) -> tensor:
+         ...
+
+     def abs(self) -> tensor:
+         ...
+
+     def reduce(self, axis, combine_fn, keep_dims=False) -> tensor:
+         ...
+
+     def associative_scan(self, axis, combine_fn, reverse=False) -> tensor:
+         ...
+
+     def histogram(self, num_bins) -> tensor:
+         ...
+
+     def cdiv(self, div) -> tensor:
+         ...
+
+     def sigmoid(self) -> tensor:
+         ...
+
+     def softmax(self, ieee_rounding=False) -> tensor:
+         ...
+
+     def ravel(self) -> tensor:
+         ...
+
+     def max(self, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False) -> tensor:
+         ...
+
+     def argmax(self, axis, tie_break_left=True, keep_dims=False) -> tensor:
+         ...
+
+     def min(self, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False) -> tensor:
+         ...
+
+     def argmin(self, axis, tie_break_left=True, keep_dims=False) -> tensor:
+         ...
+
+     def sum(self, axis=None, keep_dims=False) -> tensor:
+         ...
+
+     def xor_sum(self, axis=None, keep_dims=False) -> tensor:
+         ...
+
+     def cumsum(self, axis=0, reverse=False) -> tensor:
+         ...
+
+     def cumprod(self, axis=0, reverse=False) -> tensor:
+         ...
+
+     def sort(self, dim: constexpr = None, descending: constexpr = CONSTEXPR_0) -> tensor:
+         ...
+
+     def flip(self, dim=None) -> tensor:
+         ...
+
+
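The practical consequence of all these dunders: ordinary Python operators inside a kernel lower to the semantic.* builders. `x - y` calls tensor.__sub__, `x < y` calls __lt__, and indexing with None calls __getitem__, which routes to expand_dims. A hedged sketch (hypothetical kernel name, standard tl API):

import triton
import triton.language as tl

@triton.jit
def outer_diff_kernel(x_ptr, out_ptr, N: tl.constexpr):
    r = tl.arange(0, N)                  # shape (N,)
    x = tl.load(x_ptr + r)
    # None-indexing goes through tensor.__getitem__ -> expand_dims:
    d = x[:, None] - x[None, :]          # __sub__ on broadcast (N, N) tiles
    m = d < 0                            # __lt__ -> int1 mask tensor
    d = tl.where(m, -d, d)               # |x_i - x_j|
    tl.store(out_ptr + r[:, None] * N + r[None, :], d)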
+ def get_bool_env_var(var_name):
+     v = os.getenv(var_name, "0")
+     return v == "1" or v == "true" or v == "on"
+
+
+ # -----------------------
+ # SPMD Programming Model
+ # -----------------------
+ def _constexpr_to_value(v):
+     if isinstance(v, constexpr):
+         return v.value
+     return v
+
+
+ @builtin
+ def program_id(axis, _builder=None):
+     """
+     Returns the id of the current program instance along the given :code:`axis`.
+
+     :param axis: The axis of the 3D launch grid. Must be 0, 1 or 2.
+     :type axis: int
+     """
+     # if axis == -1:
+     #     pid0 = program_id(0, _builder)
+     #     pid1 = program_id(1, _builder)
+     #     pid2 = program_id(2, _builder)
+     #     npg0 = num_programs(0, _builder)
+     #     npg1 = num_programs(1, _builder)
+     #     return pid0 + pid1*npg0 + pid2*npg0*npg1
+     axis = _constexpr_to_value(axis)
+     return semantic.program_id(axis, _builder)
+
+
+ @builtin
+ def num_programs(axis, _builder=None):
+     """
+     Returns the number of program instances launched along the given :code:`axis`.
+
+     :param axis: The axis of the 3D launch grid. Must be 0, 1 or 2.
+     :type axis: int
+     """
+     axis = _constexpr_to_value(axis)
+     return semantic.num_programs(axis, _builder)
+
+
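program_id and num_programs are the heart of the SPMD model: each program instance derives its own slice of the data from its pid. The canonical vector-add sketch, including a host-side launch (assumes torch and a CUDA-capable device):

import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)                 # which block of work am I?
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements                 # guard the ragged tail
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

x = torch.randn(4097, device="cuda")
y = torch.randn(4097, device="cuda")
out = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 1024),)          # num_programs(0) == grid[0]
add_kernel[grid](x, y, out, x.numel(), BLOCK_SIZE=1024)
assert torch.allclose(out, x + y)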
+ # -----------------------
+ # Block Initialization
+ # -----------------------
+
+
+ @builtin
+ def arange(start, end, _builder=None):
+     start = _constexpr_to_value(start)
+     end = _constexpr_to_value(end)
+     return semantic.arange(start, end, _builder)
+
+
+ arange.__doc__ = f"""
+ Returns contiguous values within the half-open interval :code:`[start,
+ end)`.  :code:`end - start` must be less than or equal to
+ :code:`TRITON_MAX_TENSOR_NUMEL = {TRITON_MAX_TENSOR_NUMEL}`
+
+ :param start: Start of the interval. Must be a power of two.
+ :type start: int32
+ :param end: End of the interval. Must be a power of two greater than
+     :code:`start`.
+ :type end: int32
+ """
+
+
+ def _unwrap_shape(shape):
+     shape = _constexpr_to_value(shape)
+     return [_constexpr_to_value(s) for s in shape]
+
+
+ def _shape_check_impl(shape):
+     shape = _unwrap_shape(shape)
+     validate_block_shape(shape)
+     return shape
+
+
+ @builtin
+ def full(shape, value, dtype, _builder=None):
+     """
+     Returns a tensor filled with the scalar value for the given :code:`shape` and :code:`dtype`.
+
+     :param shape: Shape of the new array, e.g., (8, 16) or (8, )
+     :type shape: tuple of ints
+     :param value: A scalar value to fill the array with
+     :type value: scalar
+     :param dtype: Data type of the new array, e.g., :code:`tl.float16`
+     :type dtype: tl.dtype
+     """
+     shape = _shape_check_impl(shape)
+     value = _constexpr_to_value(value)
+     dtype = _constexpr_to_value(dtype)
+     return semantic.full(shape, value, dtype, _builder)
+
+
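A small sketch of the block-initialization builtins together; note arange's power-of-two constraint mentioned in its docstring above (hypothetical kernel name):

import triton
import triton.language as tl

@triton.jit
def init_kernel(out_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)                 # BLOCK must be a power of two
    ones = tl.full((BLOCK,), 1.0, tl.float32)  # constant-filled block
    tl.store(out_ptr + offs, ones + offs)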
+ # -----------------------
+ # Shape Manipulation
+ # -----------------------
+
+
+ @builtin
+ def broadcast(input, other, _builder=None):
+     """
+     Tries to broadcast the two given blocks to a common compatible shape.
+
+     :param input: The first input tensor.
+     :type input: Block
+     :param other: The second input tensor.
+     :type other: Block
+     """
+     return semantic.broadcast_impl_value(input, other, _builder)
+
+
+ @_tensor_member_fn
+ @builtin
+ def broadcast_to(input, *shape, _builder=None):
+     """
+     Tries to broadcast the given tensor to a new :code:`shape`.
+
+     :param input: The input tensor.
+     :type input: Block
+     :param shape: The desired shape.
+     :type shape:
+
+     :code:`shape` can be passed as a tuple or as individual parameters: ::
+
+         # These are equivalent
+         broadcast_to(x, (32, 32))
+         broadcast_to(x, 32, 32)
+     """
+     shape = _shape_check_impl(_unwrap_iterable(shape))
+     return semantic.broadcast_impl_shape(input, shape, _builder)
+
+
+ @_tensor_member_fn
+ @builtin
+ def trans(input: tensor, *dims, _builder=None):
+     """
+     Permutes the dimensions of a tensor.
+
+     If the parameter :code:`dims` is not specified, the function defaults to a (1,0) permutation,
+     effectively transposing a 2D tensor.
+
+     :param input: The input tensor.
+     :param dims: The desired ordering of dimensions.  For example,
+         :code:`(2, 1, 0)` reverses the order of the dims in a 3D tensor.
+
+     :code:`dims` can be passed as a tuple or as individual parameters: ::
+
+         # These are equivalent
+         trans(x, (2, 1, 0))
+         trans(x, 2, 1, 0)
+
+     :py:func:`permute` is equivalent to this function, except it doesn't
+     have the special case when no permutation is specified.
+     """
+     if not dims:
+         dims = (1, 0)
+     return semantic.permute(input, dims, _builder)
+
+
+ @_tensor_member_fn
+ @builtin
+ def permute(input, *dims, _builder=None):
+     """
+     Permutes the dimensions of a tensor.
+
+     :param input: The input tensor.
+     :type input: Block
+     :param dims: The desired ordering of dimensions.  For example,
+         :code:`(2, 1, 0)` reverses the order of the dims in a 3D tensor.
+
+     :code:`dims` can be passed as a tuple or as individual parameters: ::
+
+         # These are equivalent
+         permute(x, (2, 1, 0))
+         permute(x, 2, 1, 0)
+
+     :py:func:`trans` is equivalent to this function, except when
+     :code:`dims` is empty, it tries to do a (1,0) permutation.
+     """
+     dims = _unwrap_iterable(dims)
+     return semantic.permute(input, dims, _builder)
+
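trans and permute differ only in the default when dims is omitted. A hedged sketch of the equivalence on a 2D tile (hypothetical kernel name and shapes):

import triton
import triton.language as tl

@triton.jit
def transpose_kernel(x_ptr, out_ptr, M: tl.constexpr, N: tl.constexpr):
    rm = tl.arange(0, M)
    rn = tl.arange(0, N)
    x = tl.load(x_ptr + rm[:, None] * N + rn[None, :])  # (M, N) tile
    xt = tl.trans(x)               # dims omitted -> defaults to (1, 0)
    # tl.permute(x, (1, 0)) and x.trans() would produce the same (N, M) tile
    tl.store(out_ptr + rn[:, None] * M + rm[None, :], xt)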
+
+ @builtin
+ def cat(input, other, can_reorder=False, _builder=None):
+     """
+     Concatenate the given blocks.
+
+     :param input: The first input tensor.
+     :type input: Tensor
+     :param other: The second input tensor.
+     :type other: Tensor
+     :param can_reorder: Compiler hint.  If true, the compiler is
+         allowed to reorder elements while concatenating inputs.  Only use if the
+         order does not matter (e.g., result is only used in reduction ops).
+         The current implementation of `cat` supports only can_reorder=True.
+     """
+     return semantic.cat(input, other, can_reorder, _builder)
+
+
+ @builtin
+ def join(a, b, _builder=None):
+     """
+     Join the given tensors in a new, minor dimension.
+
+     For example, given two tensors of shape (4,8), produces a new tensor of
+     shape (4,8,2).  Given two scalars, returns a tensor of shape (2).
+
+     The two inputs are broadcasted to be the same shape.
+
+     If you want to join more than two elements, you can use multiple calls to
+     this function.  This reflects the constraint in Triton that tensors must
+     have power-of-two sizes.
+
+     join is the inverse of split.
+
+     :param a: The first input tensor.
+     :type a: Tensor
+     :param b: The second input tensor.
+     :type b: Tensor
+     """
+     return semantic.join(a, b, _builder)
+
+
+ @jit
+ def _take_first(a, b):
+     return a
+
+
+ @_tensor_member_fn
+ @builtin
+ def split(a, _builder=None, _generator=None) -> tuple[tensor, tensor]:
+     """
+     Split a tensor in two along its last dim, which must have size 2.
+
+     For example, given a tensor of shape (4,8,2), produces two tensors of shape
+     (4,8).  Given a tensor of shape (2), returns two scalars.
+
+     If you want to split into more than two pieces, you can use multiple calls
+     to this function (probably plus calling reshape).  This reflects the
+     constraint in Triton that tensors must have power-of-two sizes.
+
+     split is the inverse of join.
+
+     :param a: The tensor to split.
+     :type a: Tensor
+     """
+     # If len(a.shape) == 1, i.e. a.shape == [2], we should return two scalars.
+     # But semantic.split can only handle returning tensors.  Work around this by
+     # expanding the input to shape [1,2] and then reducing the result.
+     was_rank_1 = len(a.shape) == 1
+     if was_rank_1:
+         a = semantic.expand_dims(a, 0, _builder)
+
+     out_lhs, out_rhs = semantic.split(a, _builder)
+
+     if was_rank_1:
+         # Currently `reduce` is the best way to convert a tensor of shape [1] to a scalar.
+         out_lhs = typing.cast(tensor, reduce(out_lhs, None, _take_first, _builder=_builder, _generator=_generator))
+         out_rhs = typing.cast(tensor, reduce(out_rhs, None, _take_first, _builder=_builder, _generator=_generator))
+
+     return out_lhs, out_rhs
+
+
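join and split are inverses: join packs two same-shape tensors along a new minor axis of size 2, and split unpacks it. A hedged sketch (hypothetical kernel name; the interleaving comment assumes reshape's default order-preserving behavior):

import triton
import triton.language as tl

@triton.jit
def join_split_kernel(a_ptr, b_ptr, out_ptr, N: tl.constexpr):
    r = tl.arange(0, N)
    a = tl.load(a_ptr + r)
    b = tl.load(b_ptr + r)
    z = tl.join(a, b)                 # shape (N, 2): pairs (a_i, b_i)
    lo, hi = tl.split(z)              # round-trip: lo == a, hi == b
    # An order-preserving flatten interleaves the inputs: a_0, b_0, a_1, b_1, ...
    flat = tl.reshape(z, (2 * N,))
    tl.store(out_ptr + tl.arange(0, 2 * N), flat)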
+ @_tensor_member_fn
+ @builtin
+ def view(input, *shape, _builder=None):
+     """
+     Returns a tensor with the same elements as `input` but a different shape.
+     The order of the elements may not be preserved.
+
+     :param input: The input tensor.
+     :type input: Block
+     :param shape: The desired shape.
+
+     :code:`shape` can be passed as a tuple or as individual parameters: ::
+
+         # These are equivalent
+         view(x, (32, 32))
+         view(x, 32, 32)
+     """
+     warn("view is deprecated; use reshape with can_reorder=True instead.")
+     shape = _shape_check_impl(_unwrap_iterable(shape))
+     return semantic.reshape(input, shape, can_reorder=True, builder=_builder)
+
+
+ @_tensor_member_fn
+ @builtin
+ def reshape(input, *shape, can_reorder=False, _builder=None):
+     """
+     Returns a tensor with the same number of elements as input but with the
+     provided shape.
+
+     :param input: The input tensor.
+     :type input: Block
+     :param shape: The new shape.
+
+     :code:`shape` can be passed as a tuple or as individual parameters: ::
+
+         # These are equivalent
+         reshape(x, (32, 32))
+         reshape(x, 32, 32)
+     """
+     shape = _shape_check_impl(_unwrap_iterable(shape))
+     return semantic.reshape(input, shape, can_reorder, _builder)
+
+
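The practical difference: reshape preserves element order unless you opt in with can_reorder, while the deprecated view always allowed reordering. A hedged sketch (hypothetical kernel name; assumes can_reorder can be passed as a keyword inside a JIT function):

import triton
import triton.language as tl

@triton.jit
def reshape_kernel(x_ptr, out_ptr, N: tl.constexpr):
    x = tl.load(x_ptr + tl.arange(0, N))
    y = tl.reshape(x, (2, N // 2))                    # order-preserving
    # tl.reshape(x, (2, N // 2), can_reorder=True) would let the compiler
    # pick a cheaper layout when element order doesn't matter.
    tl.store(out_ptr + tl.arange(0, N), tl.reshape(y, (N,)))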
+ def _wrap_axis(axis, ndim):
+     if not (-ndim <= axis < ndim):
+         raise ValueError(f"invalid axis {axis}. Expected {-ndim} <= axis < {ndim}")
+
+     return axis if axis >= 0 else axis + ndim
+
+
+ @_tensor_member_fn
+ @builtin
+ def expand_dims(input, axis, _builder=None):
+     """
+     Expand the shape of a tensor by inserting new length-1 dimensions.
+
+     Axis indices are with respect to the resulting tensor, so
+     ``result.shape[axis]`` will be 1 for each axis.
+
+     :param input: The input tensor.
+     :type input: tl.tensor
+     :param axis: The indices to add new axes
+     :type axis: int | Sequence[int]
+
+     """
+     input = semantic.to_tensor(input, _builder)
+     axis = _constexpr_to_value(axis)
+     axes = list(axis) if isinstance(axis, Sequence) else [axis]
+     new_ndim = len(input.shape) + len(axes)
+     axes = [_wrap_axis(_constexpr_to_value(d), new_ndim) for d in axes]
+
+     if len(set(axes)) != len(axes):
+         raise ValueError(f"expand_dims received duplicate axes, normalized axes = {axes}")
+
+     ret = input
+     for a in sorted(axes):
+         ret = semantic.expand_dims(ret, a, _builder)
+     return ret
+
+
1483
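`expand_dims` is the usual way to set up broadcasting; a sketch (illustrative, with hypothetical pointer names):

.. code-block:: python

    import triton
    import triton.language as tl

    @triton.jit
    def outer_sum(out_ptr, N: tl.constexpr):
        x = tl.arange(0, N)            # shape [N]
        row = tl.expand_dims(x, 0)     # shape [1, N]
        col = tl.expand_dims(x, -1)    # shape [N, 1]
        grid = row + col               # broadcasts to [N, N]
        idx = col * N + row            # flat output offsets, also [N, N]
        tl.store(out_ptr + idx, grid)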
+ @_tensor_member_fn
+ @builtin
+ def cast(input, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: bool = False, _builder=None):
+     """
+     Casts a tensor to the given :code:`dtype`.
+
+     :param dtype: The target data type.
+     :type dtype: tl.dtype
+     :param fp_downcast_rounding: The rounding mode for downcasting
+         floating-point values. This parameter is only used when self is a
+         floating-point tensor and dtype is a floating-point type with a
+         smaller bitwidth. Supported values are :code:`"rtne"` (round to
+         nearest, ties to even) and :code:`"rtz"` (round towards zero).
+     :type fp_downcast_rounding: str, optional
+     :param bitcast: If true, the tensor is bitcast to the given
+         :code:`dtype` instead of being numerically cast.
+     :type bitcast: bool, optional
+     """
+     input = semantic.to_tensor(input, _builder)
+     if isinstance(bitcast, constexpr):
+         bitcast = bitcast.value
+     if bitcast:
+         return semantic.bitcast(input, dtype, _builder)
+     return semantic.cast(input, dtype, _builder, fp_downcast_rounding)
+
+
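The difference between a numeric cast and a bitcast, sketched (illustrative; `x_ptr` is assumed to point at float32, `y_ptr` at float16, and `bits_ptr` at int32):

.. code-block:: python

    import triton
    import triton.language as tl

    @triton.jit
    def cast_demo(x_ptr, y_ptr, bits_ptr, BLOCK: tl.constexpr):
        offs = tl.arange(0, BLOCK)
        x = tl.load(x_ptr + offs)                                 # float32 input
        y = tl.cast(x, tl.float16, fp_downcast_rounding="rtne")   # numeric downcast
        bits = tl.cast(x, tl.int32, bitcast=True)                 # reinterpret raw bits
        tl.store(y_ptr + offs, y)
        tl.store(bits_ptr + offs, bits)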
+ # -----------------------
+ # Linear Algebra
+ # -----------------------
+
+
+ @builtin
+ def dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_imprecise_acc=None, out_dtype=float32,
+         _builder=None):
+     """
+     Returns the matrix product of two blocks.
+
+     The two blocks must both be two-dimensional or three-dimensional and have compatible inner dimensions.
+     For three-dimensional blocks, `tl.dot` performs the batched matrix product,
+     where the first dimension of each block represents the batch dimension.
+
+     :param input: The first tensor to be multiplied.
+     :type input: 2D or 3D tensor of scalar-type in {:code:`int8`, :code:`float8_e5m2`, :code:`float16`, :code:`bfloat16`, :code:`float32`}
+     :param other: The second tensor to be multiplied.
+     :type other: 2D or 3D tensor of scalar-type in {:code:`int8`, :code:`float8_e5m2`, :code:`float16`, :code:`bfloat16`, :code:`float32`}
+     :param acc: The accumulator tensor. If not None, the result is added to this tensor.
+     :type acc: 2D or 3D tensor of scalar-type in {:code:`float16`, :code:`float32`, :code:`int32`}
+     :param input_precision: How to exercise the Tensor Cores for f32 x f32. If
+         the device does not have Tensor Cores or the inputs are not of dtype f32,
+         this option is ignored. For devices that do have Tensor Cores, the
+         default precision is tf32.
+     :type input_precision: string. Available options for NVIDIA: :code:`"tf32"`, :code:`"tf32x3"`, :code:`"ieee"`. Default: :code:`"tf32"`. Available options for AMD: :code:`"ieee"`.
+     :param allow_tf32: *Deprecated.* If true, input_precision is set to "tf32".
+         Only one of :code:`input_precision` and :code:`allow_tf32` can be
+         specified (i.e. at least one must be :code:`None`).
+     """
+     assert input_precision is None or allow_tf32 is None, "Only one of input_precision and allow_tf32 can be specified"
+     if input_precision is None:
+         supports_tf32 = _builder and "tf32" in _builder.options.allowed_dot_input_precisions
+         default_precision = "tf32" if (supports_tf32 and (allow_tf32 or allow_tf32 is None)) else "ieee"
+         input_precision = os.getenv("TRITON_F32_DEFAULT", default_precision)
+
+     input_precision = _constexpr_to_value(input_precision)
+     out_dtype = _constexpr_to_value(out_dtype)
+     max_num_imprecise_acc = _constexpr_to_value(max_num_imprecise_acc)
+     return semantic.dot(input, other, acc, input_precision, max_num_imprecise_acc, out_dtype, _builder)
+
+
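A single-tile matmul sketch using `dot` (illustrative; the tile sizes are assumed to satisfy `tl.dot`'s minimum dimension, typically 16):

.. code-block:: python

    import triton
    import triton.language as tl

    @triton.jit
    def matmul_tile(a_ptr, b_ptr, c_ptr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
        rm = tl.arange(0, M)
        rn = tl.arange(0, N)
        rk = tl.arange(0, K)
        a = tl.load(a_ptr + rm[:, None] * K + rk[None, :])  # [M, K]
        b = tl.load(b_ptr + rk[:, None] * N + rn[None, :])  # [K, N]
        acc = tl.zeros((M, N), dtype=tl.float32)
        acc = tl.dot(a, b, acc)                             # acc += a @ b
        tl.store(c_ptr + rm[:, None] * N + rn[None, :], acc)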
+ @builtin
+ def dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None, out_dtype=float32, _builder=None):
+     """
+     Returns the matrix product of two blocks in microscaling format.
+
+     lhs and rhs use microscaling formats described here:
+     https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+
+     :param lhs: The first tensor to be multiplied.
+     :type lhs: 2D tensor of f8, f6 or f4 format packed in int32 format.
+     :param lhs_scale: Scale factor for lhs tensor.
+     :type lhs_scale: ue8m0 float8 type (currently represented as an int8 tensor).
+     :param lhs_format: format of the lhs tensor, available formats: {:code:`e4m3`, :code:`e5m2`, :code:`e2m3`, :code:`e3m2`, :code:`e2m1`}.
+     :param rhs: The second tensor to be multiplied.
+     :type rhs: 2D tensor of f8, f6 or f4 format packed in int32 format.
+     :param rhs_scale: Scale factor for rhs tensor.
+     :type rhs_scale: ue8m0 float8 type (currently represented as an int8 tensor).
+     :param rhs_format: format of the rhs tensor, available formats: {:code:`e4m3`, :code:`e5m2`, :code:`e2m3`, :code:`e3m2`, :code:`e2m1`}.
+     :param acc: The accumulator tensor. If not None, the result is added to this tensor.
+     """
+     out_dtype = _constexpr_to_value(out_dtype)
+     assert out_dtype == float32, "Only float32 is supported for out_dtype at the moment"
+     return semantic.dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc, out_dtype, _builder)
+
+
+ # -----------------------
+ # Non-Atomic Memory Operations
+ # -----------------------
+
+
+ @builtin
+ def load(pointer, mask=None, other=None, boundary_check=(), padding_option="", cache_modifier="", eviction_policy="",
+          volatile=False, _builder=None):
+     """
+     Return a tensor of data whose values are loaded from memory at location defined by `pointer`:
+
+     (1) If `pointer` is a single element pointer, a scalar is loaded. In
+         this case:
+
+         - `mask` and `other` must also be scalars,
+         - `other` is implicitly typecast to `pointer.dtype.element_ty`, and
+         - `boundary_check` and `padding_option` must be empty.
+
+     (2) If `pointer` is an N-dimensional tensor of pointers, an
+         N-dimensional tensor is loaded. In this case:
+
+         - `mask` and `other` are implicitly broadcast to `pointer.shape`,
+         - `other` is implicitly typecast to `pointer.dtype.element_ty`, and
+         - `boundary_check` and `padding_option` must be empty.
+
+     (3) If `pointer` is a block pointer defined by `make_block_ptr`, a
+         tensor is loaded. In this case:
+
+         - `mask` and `other` must be `None`, and
+         - `boundary_check` and `padding_option` can be specified to control the behavior of out-of-bound access.
+
+     :param pointer: Pointer to the data to be loaded
+     :type pointer: `triton.PointerType`, or block of `dtype=triton.PointerType`
+     :param mask: if `mask[idx]` is false, do not load the data at address `pointer[idx]`
+         (must be `None` with block pointers)
+     :type mask: Block of `triton.int1`, optional
+     :param other: if `mask[idx]` is false, return `other[idx]`
+     :type other: Block, optional
+     :param boundary_check: tuple of integers, indicating the dimensions which should do the boundary check
+     :type boundary_check: tuple of ints, optional
+     :param padding_option: should be one of {"", "zero", "nan"}, the padding value to use for out-of-bounds accesses. "" means an undefined value.
+     :param cache_modifier: changes cache option in NVIDIA PTX
+     :type cache_modifier: str, optional, should be one of {"", "ca", "cg"}, where "ca" stands for
+         cache at all levels and "cg" stands for cache at global level (cache in L2 and below, not L1), see
+         `cache operator <https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#cache-operators>`_ for more details.
+     :param eviction_policy: changes eviction policy in NVIDIA PTX
+     :type eviction_policy: str, optional
+     :param volatile: changes volatile option in NVIDIA PTX
+     :type volatile: bool, optional
+     """
+     # `mask` and `other` can be constexpr
+     mask = _constexpr_to_value(mask)
+     other = _constexpr_to_value(other)
+     if mask is not None:
+         mask = semantic.to_tensor(mask, _builder)
+     if other is not None:
+         other = semantic.to_tensor(other, _builder)
+     padding_option = _constexpr_to_value(padding_option)
+     cache_modifier = _constexpr_to_value(cache_modifier)
+     eviction_policy = _constexpr_to_value(eviction_policy)
+     volatile = _constexpr_to_value(volatile)
+     return semantic.load(pointer, mask, other, boundary_check, padding_option, cache_modifier, eviction_policy,
+                          volatile, _builder)
+
+
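The canonical masked copy, for illustration (hypothetical names; `n` need not be a multiple of `BLOCK`):

.. code-block:: python

    import triton
    import triton.language as tl

    @triton.jit
    def copy(x_ptr, y_ptr, n, BLOCK: tl.constexpr):
        pid = tl.program_id(0)
        offs = pid * BLOCK + tl.arange(0, BLOCK)
        m = offs < n
        x = tl.load(x_ptr + offs, mask=m, other=0.0)  # out-of-range lanes read 0.0
        tl.store(y_ptr + offs, x, mask=m)             # ... and are not written back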
+ @builtin
+ def _experimental_descriptor_load(desc_pointer, offsets, shape, dtype, _builder=None):
+     """
+     Experimental feature to access TMA descriptor loads. This is an escape hatch to easily exercise TTGIR operations.
+     This will be removed in the future and shouldn't be used in production code.
+
+     This loads a tensor of data based on the descriptor and offsets.
+     """
+     type = block_type(_constexpr_to_value(dtype), shape)
+     return semantic.descriptor_load(desc_pointer, offsets, "", "", type, _builder)
+
+
+ @builtin
+ def _experimental_descriptor_store(desc_pointer, value, offsets, _builder=None):
+     """
+     Experimental feature to access TMA descriptor stores. This is an escape hatch to easily exercise TTGIR operations.
+     This will be removed in the future and shouldn't be used in production code.
+
+     This stores a tensor of data based on the descriptor and offsets.
+     """
+     return semantic.descriptor_store(desc_pointer, value, offsets, _builder)
+
+
+ @_tensor_member_fn
+ @builtin
+ def store(pointer, value, mask=None, boundary_check=(), cache_modifier="", eviction_policy="", _builder=None):
+     """
+     Store a tensor of data into memory locations defined by `pointer`.
+
+     (1) If `pointer` is a single element pointer, a scalar is stored. In
+         this case:
+
+         - `mask` must also be scalar, and
+         - `boundary_check` must be empty.
+
+     (2) If `pointer` is an N-dimensional tensor of pointers, an
+         N-dimensional block is stored. In this case:
+
+         - `mask` is implicitly broadcast to `pointer.shape`, and
+         - `boundary_check` must be empty.
+
+     (3) If `pointer` is a block pointer defined by `make_block_ptr`, a block
+         of data is stored. In this case:
+
+         - `mask` must be None, and
+         - `boundary_check` can be specified to control the behavior of out-of-bound access.
+
+     `value` is implicitly broadcast to `pointer.shape` and typecast to `pointer.dtype.element_ty`.
+
+     :param pointer: The memory location where the elements of `value` are stored
+     :type pointer: `triton.PointerType`, or block of `dtype=triton.PointerType`
+     :param value: The tensor of elements to be stored
+     :type value: Block
+     :param mask: If `mask[idx]` is false, do not store `value[idx]` at `pointer[idx]`
+     :type mask: Block of triton.int1, optional
+     :param boundary_check: tuple of integers, indicating the dimensions which should do the boundary check
+     :type boundary_check: tuple of ints, optional
+     :param cache_modifier: changes cache option in NVIDIA PTX
+     :type cache_modifier: str, optional, should be one of {"", ".wb", ".cg", ".cs", ".wt"}, where ".wb" stands for
+         cache write-back all coherent levels, ".cg" stands for cache global, ".cs" stands for cache streaming, ".wt"
+         stands for cache write-through, see `cache operator <https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#cache-operators>`_ for more details.
+     :param eviction_policy: changes eviction policy in NVIDIA PTX
+     :type eviction_policy: str, optional, should be one of {"", "evict_first", "evict_last"}
+     """
+     # `value` can be constexpr
+     value = semantic.to_tensor(value, _builder)
+     mask = _constexpr_to_value(mask)
+     if mask is not None:
+         mask = semantic.to_tensor(mask, _builder)
+     cache_modifier = _constexpr_to_value(cache_modifier)
+     eviction_policy = _constexpr_to_value(eviction_policy)
+     return semantic.store(pointer, value, mask, boundary_check, cache_modifier, eviction_policy, _builder)
+
+
+ @builtin
+ def make_block_ptr(base: tensor, shape, strides, offsets, block_shape, order, _builder=None):
+     """
+     Returns a pointer to a block in a parent tensor.
+
+     :param base: The base pointer to the parent tensor
+     :param shape: The shape of the parent tensor
+     :param strides: The strides of the parent tensor
+     :param offsets: The offsets to the block
+     :param block_shape: The shape of the block
+     :param order: The order of the original data format
+     """
+     return semantic.make_block_ptr(base, shape, strides, offsets, block_shape, order, _builder)
+
+
+ @_tensor_member_fn
+ @builtin
+ def advance(base, offsets, _builder=None):
+     """
+     Advance a block pointer.
+
+     :param base: the block pointer to advance
+     :param offsets: the offsets to advance by, one per dimension
+     """
+     return semantic.advance(base, offsets, _builder)
+
+
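A sketch combining the two (illustrative; assumes a row-major `M x K` parent tensor and that the first `BLOCK_M` rows are in range):

.. code-block:: python

    import triton
    import triton.language as tl

    @triton.jit
    def sum_k_tiles(a_ptr, out_ptr, M, K, BLOCK_M: tl.constexpr, BLOCK_K: tl.constexpr):
        a_bp = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(K, 1),
                                 offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
        acc = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.float32)
        for _ in range(0, tl.cdiv(K, BLOCK_K)):
            # Out-of-bound columns are padded with zero instead of faulting.
            acc += tl.load(a_bp, boundary_check=(1,), padding_option="zero")
            a_bp = tl.advance(a_bp, (0, BLOCK_K))
        rm = tl.arange(0, BLOCK_M)
        rk = tl.arange(0, BLOCK_K)
        tl.store(out_ptr + rm[:, None] * BLOCK_K + rk[None, :], acc)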
+ # -----------------------
+ # Atomic Memory Operations
+ # -----------------------
+
+
+ def _add_atomic_docstr(name: str, has_cmp: bool = False) -> Callable[[T], T]:
+
+     def _decorator(func: T) -> T:
+         docstr = f"""
+     Performs an atomic {name} at the memory location specified by :code:`pointer`.
+
+     Return the data stored at :code:`pointer` before the atomic operation.
+
+     :param pointer: The memory locations to operate on
+     :type pointer: Block of dtype=triton.PointerDType"""
+         if has_cmp:
+             docstr += """
+     :param cmp: The values expected to be found in the atomic object
+     :type cmp: Block of dtype=pointer.dtype.element_ty"""
+         docstr += """
+     :param val: The values with which to perform the atomic operation
+     :type val: Block of dtype=pointer.dtype.element_ty
+     :param sem: Specifies the memory semantics for the operation. Acceptable values are "acquire",
+         "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". If not provided,
+         the function defaults to using "acq_rel" semantics.
+     :type sem: str, optional
+     :param scope: Defines the scope of threads that observe the synchronizing effect of the atomic operation.
+         Acceptable values are "gpu", "cta" (cooperative thread array, thread block), and "sys" (stands for "SYSTEM").
+         The default value is "gpu".
+     :type scope: str, optional
+     """
+         func.__doc__ = docstr
+         return func
+
+     return _decorator
+
+
+ @_tensor_member_fn
+ @builtin
+ @_add_atomic_docstr("compare-and-swap", has_cmp=True)
+ def atomic_cas(pointer, cmp, val, sem=None, scope=None, _builder=None):
+     cmp = semantic.to_tensor(cmp, _builder)
+     val = semantic.to_tensor(val, _builder)
+     sem = _constexpr_to_value(sem)
+     scope = _constexpr_to_value(scope)
+     return semantic.atomic_cas(pointer, cmp, val, sem, scope, _builder)
+
+
+ @_tensor_member_fn
+ @builtin
+ @_add_atomic_docstr("exchange")
+ def atomic_xchg(pointer, val, mask=None, sem=None, scope=None, _builder=None):
+     val = semantic.to_tensor(val, _builder)
+     sem = _constexpr_to_value(sem)
+     scope = _constexpr_to_value(scope)
+     mask = _constexpr_to_value(mask)
+     return semantic.atomic_xchg(pointer, val, mask, sem, scope, _builder)
+
+
+ @_tensor_member_fn
+ @builtin
+ @_add_atomic_docstr("add")
+ def atomic_add(pointer, val, mask=None, sem=None, scope=None, _builder=None):
+     val = semantic.to_tensor(val, _builder)
+     sem = _constexpr_to_value(sem)
+     scope = _constexpr_to_value(scope)
+     mask = _constexpr_to_value(mask)
+     return semantic.atomic_add(pointer, val, mask, sem, scope, _builder)
+
+
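A per-bucket counter sketch (illustrative; `counters_ptr` is assumed to point at int32 storage that has been zeroed beforehand):

.. code-block:: python

    import triton
    import triton.language as tl

    @triton.jit
    def bump_counters(idx_ptr, counters_ptr, n, BLOCK: tl.constexpr):
        offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        m = offs < n
        bucket = tl.load(idx_ptr + offs, mask=m, other=0)
        # Each in-range lane adds 1 to its bucket; the return value is the old count.
        old = tl.atomic_add(counters_ptr + bucket, 1, mask=m, sem="relaxed")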
+ @_tensor_member_fn
+ @builtin
+ @_add_atomic_docstr("max")
+ def atomic_max(pointer, val, mask=None, sem=None, scope=None, _builder=None):
+     val = semantic.to_tensor(val, _builder)
+     sem = _constexpr_to_value(sem)
+     scope = _constexpr_to_value(scope)
+     mask = _constexpr_to_value(mask)
+     return semantic.atomic_max(pointer, val, mask, sem, scope, _builder)
+
+
+ @_tensor_member_fn
+ @builtin
+ @_add_atomic_docstr("min")
+ def atomic_min(pointer, val, mask=None, sem=None, scope=None, _builder=None):
+     val = semantic.to_tensor(val, _builder)
+     sem = _constexpr_to_value(sem)
+     scope = _constexpr_to_value(scope)
+     mask = _constexpr_to_value(mask)
+     return semantic.atomic_min(pointer, val, mask, sem, scope, _builder)
+
+
+ @_tensor_member_fn
+ @builtin
+ @_add_atomic_docstr("bitwise and")
+ def atomic_and(pointer, val, mask=None, sem=None, scope=None, _builder=None):
+     val = semantic.to_tensor(val, _builder)
+     sem = _constexpr_to_value(sem)
+     scope = _constexpr_to_value(scope)
+     mask = _constexpr_to_value(mask)
+     return semantic.atomic_and(pointer, val, mask, sem, scope, _builder)
+
+
+ @_tensor_member_fn
+ @builtin
+ @_add_atomic_docstr("bitwise or")
+ def atomic_or(pointer, val, mask=None, sem=None, scope=None, _builder=None):
+     val = semantic.to_tensor(val, _builder)
+     sem = _constexpr_to_value(sem)
+     scope = _constexpr_to_value(scope)
+     mask = _constexpr_to_value(mask)
+     return semantic.atomic_or(pointer, val, mask, sem, scope, _builder)
+
+
+ @_tensor_member_fn
+ @builtin
+ @_add_atomic_docstr("bitwise xor")
+ def atomic_xor(pointer, val, mask=None, sem=None, scope=None, _builder=None):
+     val = semantic.to_tensor(val, _builder)
+     sem = _constexpr_to_value(sem)
+     scope = _constexpr_to_value(scope)
+     mask = _constexpr_to_value(mask)
+     return semantic.atomic_xor(pointer, val, mask, sem, scope, _builder)
+
+
+ # -----------------------
+ # Conditioning
+ # -----------------------
+
+
+ @builtin
+ def where(condition, x, y, _builder=None):
+     """
+     Returns a tensor of elements from either :code:`x` or :code:`y`, depending on :code:`condition`.
+
+     Note that :code:`x` and :code:`y` are always evaluated regardless of the value of :code:`condition`.
+
+     If you want to avoid unintended memory operations, use the :code:`mask` arguments in `triton.load` and `triton.store` instead.
+
+     The shapes of :code:`x` and :code:`y` are both broadcast to the shape of :code:`condition`.
+     :code:`x` and :code:`y` must have the same data type.
+
+     :param condition: When True (nonzero), yield x, otherwise yield y.
+     :type condition: Block of triton.bool
+     :param x: values selected at indices where condition is True.
+     :param y: values selected at indices where condition is False.
+     """
+     condition = semantic.to_tensor(condition, _builder)
+     x = _unwrap_if_constexpr(x)
+     y = _unwrap_if_constexpr(y)
+     return semantic.where(condition, x, y, _builder)
+
+
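For example (sketch), an elementwise ReLU:

.. code-block:: python

    import triton
    import triton.language as tl

    @triton.jit
    def relu(x_ptr, y_ptr, BLOCK: tl.constexpr):
        offs = tl.arange(0, BLOCK)
        x = tl.load(x_ptr + offs)
        y = tl.where(x > 0, x, 0.0)  # elementwise select; both branches are evaluated
        tl.store(y_ptr + offs, y)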
+ # -----------------------
+ # Math
+ # -----------------------
+
+
+ @builtin
+ def add(x, y, sanitize_overflow: constexpr = True, _builder=None):
+     x = _unwrap_if_constexpr(x)
+     y = _unwrap_if_constexpr(y)
+     return semantic.add(x, y, sanitize_overflow, _builder)
+
+
+ @builtin
+ def sub(x, y, sanitize_overflow: constexpr = True, _builder=None):
+     x = _unwrap_if_constexpr(x)
+     y = _unwrap_if_constexpr(y)
+     return semantic.sub(x, y, sanitize_overflow, _builder)
+
+
+ @builtin
+ def mul(x, y, sanitize_overflow: constexpr = True, _builder=None):
+     x = _unwrap_if_constexpr(x)
+     y = _unwrap_if_constexpr(y)
+     return semantic.mul(x, y, sanitize_overflow, _builder)
+
+
+ @builtin
+ def minimum(x, y, propagate_nan: constexpr = PropagateNan.NONE, _builder=None):
+     """
+     Computes the element-wise minimum of :code:`x` and :code:`y`.
+
+     :param x: the first input tensor
+     :type x: Block
+     :param y: the second input tensor
+     :type y: Block
+     :param propagate_nan: whether to propagate NaN values.
+     :type propagate_nan: tl.PropagateNan
+
+     .. seealso:: :class:`tl.PropagateNan`
+     """
+     x = semantic.to_tensor(x, _builder)
+     y = semantic.to_tensor(y, _builder)
+     x = _promote_bfloat16_to_float32(x, _builder=_builder)
+     y = _promote_bfloat16_to_float32(y, _builder=_builder)
+     propagate_nan = _constexpr_to_value(propagate_nan)
+     return semantic.minimum(x, y, propagate_nan, _builder)
+
+
+ @builtin
+ def maximum(x, y, propagate_nan: constexpr = PropagateNan.NONE, _builder=None):
+     """
+     Computes the element-wise maximum of :code:`x` and :code:`y`.
+
+     :param x: the first input tensor
+     :type x: Block
+     :param y: the second input tensor
+     :type y: Block
+     :param propagate_nan: whether to propagate NaN values.
+     :type propagate_nan: tl.PropagateNan
+
+     .. seealso:: :class:`tl.PropagateNan`
+     """
+     x = semantic.to_tensor(x, _builder)
+     y = semantic.to_tensor(y, _builder)
+     x = _promote_bfloat16_to_float32(x, _builder=_builder)
+     y = _promote_bfloat16_to_float32(y, _builder=_builder)
+     propagate_nan = _constexpr_to_value(propagate_nan)
+     return semantic.maximum(x, y, propagate_nan, _builder)
+
+
+ @builtin
+ def clamp(x, min, max, propagate_nan: constexpr = PropagateNan.NONE, _builder=None):
+     """
+     Clamps the input tensor :code:`x` within the range [min, max].
+     Behavior when :code:`min` > :code:`max` is undefined.
+
+     :param x: the input tensor
+     :type x: Block
+     :param min: the lower bound for clamping
+     :type min: Block
+     :param max: the upper bound for clamping
+     :type max: Block
+     :param propagate_nan: whether to propagate NaN values. Applies only to the :code:`x` tensor.
+         If either :code:`min` or :code:`max` is NaN, the result is undefined.
+     :type propagate_nan: tl.PropagateNan
+
+     .. seealso:: :class:`tl.PropagateNan`
+     """
+     x = semantic.to_tensor(x, _builder)
+     min = semantic.to_tensor(min, _builder)
+     max = semantic.to_tensor(max, _builder)
+     x = _promote_bfloat16_to_float32(x, _builder=_builder)
+     min = _promote_bfloat16_to_float32(min, _builder=_builder)
+     max = _promote_bfloat16_to_float32(max, _builder=_builder)
+
+     propagate_nan = _constexpr_to_value(propagate_nan)
+
+     return semantic.clamp(x, min, max, propagate_nan, _builder)
+
+
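A clamping sketch showing the NaN-propagation knob (illustrative):

.. code-block:: python

    import triton
    import triton.language as tl

    @triton.jit
    def clip(x_ptr, y_ptr, BLOCK: tl.constexpr):
        offs = tl.arange(0, BLOCK)
        x = tl.load(x_ptr + offs)
        # With PropagateNan.ALL, a NaN in x stays NaN instead of being clamped.
        y = tl.clamp(x, -1.0, 1.0, propagate_nan=tl.PropagateNan.ALL)
        tl.store(y_ptr + offs, y)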
+ # -----------------------
+ # Reductions
+ # -----------------------
+
+
+ def _add_reduction_docstr(name: str, return_indices_arg: str = None, tie_break_arg: str = None) -> Callable[[T], T]:
+
+     def _decorator(func: T) -> T:
+         docstr = """
+     Returns the {name} of all elements in the :code:`input` tensor along the provided :code:`axis`
+
+     :param input: the input values
+     :type input: Tensor
+     :param axis: the dimension along which the reduction should be done. If None, reduce all dimensions
+     :type axis: int
+     :param keep_dims: if true, keep the reduced dimensions with length 1
+     :type keep_dims: bool"""
+         if return_indices_arg is not None:
+             docstr += f"""
+     :param {return_indices_arg}: if true, return index corresponding to the {name} value
+     :type {return_indices_arg}: bool"""
+         if tie_break_arg is not None:
+             docstr += f"""
+     :param {tie_break_arg}: if true, in case of a tie (i.e., multiple elements have the same {name} value), return the left-most index for values that aren't NaN
+     :type {tie_break_arg}: bool"""
+
+         func.__doc__ = docstr.format(name=name)
+         return func
+
+     return _decorator
+
+
+ @contextmanager
+ def _insertion_guard(builder):
+     ip = builder.get_insertion_point()
+     yield
+     builder.restore_insertion_point(ip)
+
+
+ @_tensor_member_fn
+ @builtin
+ def reduce(input, axis, combine_fn, keep_dims=False, _builder=None, _generator=None):
+     """Applies the combine_fn to all elements in :code:`input` tensors along the provided :code:`axis`
+
+     :param input: the input tensor, or tuple of tensors
+     :type input: Tensor
+     :param axis: the dimension along which the reduction should be done. If None, reduce all dimensions
+     :type axis: int | None
+     :param combine_fn: a function to combine two groups of scalar tensors (must be marked with @triton.jit)
+     :type combine_fn: Callable
+     :param keep_dims: if true, keep the reduced dimensions with length 1
+     :type keep_dims: bool
+
+     """
+     if isinstance(input, tensor):
+         return reduce((input, ), axis, combine_fn, keep_dims=keep_dims, _builder=_builder, _generator=_generator)[0]
+
+     def make_combine_region(reduce_op):
+         in_scalar_tys = [t.type.scalar for t in input]
+         prototype = function_type(in_scalar_tys, in_scalar_tys * 2)
+
+         region = reduce_op.get_region(0)
+         with _insertion_guard(_builder):
+             param_types = [ty.to_ir(_builder) for ty in prototype.param_types]
+             block = _builder.create_block_with_parent(region, param_types)
+             args = [tensor(block.arg(i), ty) for i, ty in enumerate(prototype.param_types)]
+             results = _generator.call_JitFunction(combine_fn, args, kwargs={})
+             if isinstance(results, tensor):
+                 handles = [results.handle]
+             else:
+                 handles = [r.handle for r in results]
+             _builder.create_reduce_ret(*handles)
+
+     def expand_ndims(t, ndims):
+         for _ in builtins.range(ndims):
+             t = expand_dims(t, 0, _builder=_builder)
+         return t
+
+     axis = _constexpr_to_value(axis)
+     keep_dims = _constexpr_to_value(keep_dims)
+     if axis is not None:
+         axis = _wrap_axis(axis, len(input[0].shape))
+     ret = semantic.reduction(input, axis, make_combine_region, _builder)
+     if keep_dims:
+         if axis is not None:
+             ret = tuple(expand_dims(t, axis, _builder=_builder) for t in ret)
+         else:
+             ret = tuple(expand_ndims(t, len(input[0].shape)) for t in ret)
+     return ret
+
+
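Because `combine_fn` is itself a jitted function, a row-sum can be written directly against `reduce` (sketch; convenience wrappers such as :code:`tl.sum` are built on this same machinery):

.. code-block:: python

    import triton
    import triton.language as tl

    @triton.jit
    def _add_combine(a, b):
        return a + b

    @triton.jit
    def row_sums(x_ptr, out_ptr, M: tl.constexpr, N: tl.constexpr):
        rm = tl.arange(0, M)
        rn = tl.arange(0, N)
        x = tl.load(x_ptr + rm[:, None] * N + rn[None, :])  # [M, N]
        s = tl.reduce(x, 1, _add_combine)                   # [M]
        tl.store(out_ptr + rm, s)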
+ @builtin
+ def _promote_bfloat16_to_float32(t, _builder=None):
+     scalar_ty = t.type.scalar
+
+     # hardware doesn't support FMAX, FMIN, CMP for bfloat16
+     if scalar_ty is bfloat16:
+         return t.to(float32, _builder=_builder)
+     return t
+
+
+ @builtin
+ def _reduce_with_indices(input, axis, combine_fn, keep_dims=False, _builder=None, _generator=None):
+     axis = _constexpr_to_value(axis)
+     n = input.shape[axis]
+     index = arange(0, n, _builder=_builder)
+
+     if len(input.shape) > 1:
+         # Broadcast index across the non-reduced axes
+         axes_to_expand = [constexpr(d) for d in builtins.range(len(input.shape))]
+         del axes_to_expand[axis]
+         index = expand_dims(index, axes_to_expand, _builder=_builder)
+         index = broadcast_to(index, input.shape, _builder=_builder)
+
+     rvalue, rindices = reduce((input, index), axis, combine_fn, keep_dims=keep_dims, _builder=_builder,
+                               _generator=_generator)
+     return rvalue, rindices
+
+
+ # -----------------------
+ # Scans
+ # -----------------------
+
+
+ def _add_scan_docstr(name: str) -> Callable[[T], T]:
+
+     def _decorator(func: T) -> T:
+         docstr = """
+     Returns the {name} of all elements in the :code:`input` tensor along the provided :code:`axis`
+
+     :param input: the input values
+     :type input: Tensor
+     :param axis: the dimension along which the scan should be done
+     :type axis: int"""
+         func.__doc__ = docstr.format(name=name)
+         return func
+
+     return _decorator
+
+
+ @_tensor_member_fn
+ @builtin
+ def associative_scan(input, axis, combine_fn, reverse=False, _builder=None, _generator=None):
+     """Applies combine_fn to each element of the :code:`input` tensors with a carry along the provided :code:`axis`, updating the carry at each step
+
+     :param input: the input tensor, or tuple of tensors
+     :type input: Tensor
+     :param axis: the dimension along which the reduction should be done
+     :type axis: int
+     :param combine_fn: a function to combine two groups of scalar tensors (must be marked with @triton.jit)
+     :type combine_fn: Callable
+     :param reverse: whether to apply the associative scan in the reverse direction along axis
+     :type reverse: bool
+
+     """
+     if isinstance(input, tensor):
+         return associative_scan((input, ), axis, combine_fn, reverse, _builder=_builder, _generator=_generator)[0]
+
+     def make_combine_region(scan_op):
+         in_scalar_tys = [t.type.scalar for t in input]
+         prototype = function_type(in_scalar_tys, in_scalar_tys * 2)
+
+         region = scan_op.get_region(0)
+         with _insertion_guard(_builder):
+             param_types = [ty.to_ir(_builder) for ty in prototype.param_types]
+             block = _builder.create_block_with_parent(region, param_types)
+             args = [tensor(block.arg(i), ty) for i, ty in enumerate(prototype.param_types)]
+             results = _generator.call_JitFunction(combine_fn, args, kwargs={})
+             if isinstance(results, tensor):
+                 handles = [results.handle]
+             else:
+                 handles = [r.handle for r in results]
+             _builder.create_scan_ret(*handles)
+
+     axis = _constexpr_to_value(axis)
+     if axis is not None:
+         axis = _wrap_axis(axis, len(input[0].shape))
+     return semantic.associative_scan(input, axis, make_combine_region, reverse, _builder)
+
+
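An inclusive prefix sum falls out directly (sketch; :code:`tl.cumsum` is the convenience wrapper for this pattern):

.. code-block:: python

    import triton
    import triton.language as tl

    @triton.jit
    def _add_combine(a, b):
        return a + b

    @triton.jit
    def prefix_sum(x_ptr, y_ptr, BLOCK: tl.constexpr):
        offs = tl.arange(0, BLOCK)
        x = tl.load(x_ptr + offs)
        y = tl.associative_scan(x, 0, _add_combine)  # inclusive cumulative sum
        tl.store(y_ptr + offs, y)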
+ @_tensor_member_fn
+ @builtin
+ def histogram(input, num_bins, _builder=None, _generator=None):
+     """Computes a histogram of the input tensor with num_bins bins; the bins have a width of 1 and start at 0.
+
+     :param input: the input tensor
+     :type input: Tensor
+     :param num_bins: number of histogram bins
+     :type num_bins: int
+
+     """
+     num_bins = _constexpr_to_value(num_bins)
+     return semantic.histogram(input, num_bins, _builder)
+
+
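So a value `v` with `0 <= v < num_bins` lands in bin `v`. A sketch (illustrative; input values are assumed to be int32 in `[0, BINS)`):

.. code-block:: python

    import triton
    import triton.language as tl

    @triton.jit
    def hist_kernel(x_ptr, h_ptr, BLOCK: tl.constexpr, BINS: tl.constexpr):
        x = tl.load(x_ptr + tl.arange(0, BLOCK))  # int32 values in [0, BINS)
        h = tl.histogram(x, BINS)                 # shape [BINS], integer counts
        tl.store(h_ptr + tl.arange(0, BINS), h)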
+ # -----------------------
+ # Compiler Hint Ops
+ # -----------------------
+
+
+ @builtin
+ def debug_barrier(_builder=None):
+     '''
+     Insert a barrier to synchronize all threads in a block.
+     '''
+     return semantic.debug_barrier(_builder)
+
+
+ @builtin
+ def multiple_of(input, values, _builder=None):
+     """
+     Let the compiler know that the values in :code:`input` are all multiples of :code:`values`.
+     """
+     if isinstance(values, constexpr):
+         values = [values]
+     for i, d in enumerate(values):
+         if not isinstance(d, constexpr):
+             raise TypeError(f"values element {i} must have type `constexpr`")
+         if not isinstance(d.value, int):
+             raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]`")
+     values = [x.value for x in values]
+     return semantic.multiple_of(input, values)
+
+
+ @builtin
+ def max_contiguous(input, values, _builder=None):
+     """
+     Let the compiler know that the first :code:`values` values in :code:`input` are contiguous.
+     """
+     if isinstance(values, constexpr):
+         values = [values]
+     for i, d in enumerate(values):
+         if not isinstance(d, constexpr):
+             raise TypeError(f"values element {i} must have type `constexpr`")
+         if not isinstance(d.value, int):
+             raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]`")
+     values = [x.value for x in values]
+     return semantic.max_contiguous(input, values)
+
+
+ @builtin
+ def max_constancy(input, values, _builder=None):
+     """
+     Let the compiler know that the first :code:`values` values in :code:`input` are constant.
+
+     e.g. if :code:`values` is [4], then each group of 4 values in :code:`input` should all be equal,
+     for example [0, 0, 0, 0, 1, 1, 1, 1].
+     """
+     if isinstance(values, constexpr):
+         values = [values]
+     for i, d in enumerate(values):
+         if not isinstance(d, constexpr):
+             raise TypeError(f"values element {i} must have type `constexpr`")
+         if not isinstance(d.value, int):
+             raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]`")
+     values = [x.value for x in values]
+     return semantic.max_constancy(input, values)
+
+
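These hints usually annotate offset tensors so the backend can vectorize loads and stores; a sketch (illustrative, and the stated properties are promises about the launch configuration, not checked facts):

.. code-block:: python

    import triton
    import triton.language as tl

    @triton.jit
    def aligned_copy(x_ptr, y_ptr, BLOCK: tl.constexpr):
        offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        offs = tl.multiple_of(offs, 16)     # each 16-element run starts at a multiple of 16 ...
        offs = tl.max_contiguous(offs, 16)  # ... and the runs are contiguous
        tl.store(y_ptr + offs, tl.load(x_ptr + offs))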
+ @builtin
+ def assume(cond, _builder=None):
+     '''
+     Allows the compiler to assume that :code:`cond` is True.
+     '''
+     return semantic.assume(semantic.to_tensor(cond, _builder), _builder)
+
+
+ # -----------------------
+ # Debugging functions
+ # -----------------------
+
+
+ @builtin
+ def static_print(*values, sep: str = " ", end: str = "\n", file=None, flush=False, _builder=None):
+     '''
+     Print the values at compile time. The parameters are the same as the builtin :code:`print`.
+
+     NOTE: Calling the Python builtin :code:`print` is not the same as calling this function; it instead maps to
+     :code:`device_print`, which has special requirements for the arguments.
+
+     .. highlight:: python
+     .. code-block:: python
+
+         tl.static_print(f"BLOCK_SIZE={BLOCK_SIZE}")
+     '''
+     pass
+
+
+ @builtin
+ def static_assert(cond, msg="", _builder=None):
+     '''
+     Assert the condition at compile time. Does not require that the :code:`TRITON_DEBUG` environment variable
+     is set.
+
+     .. highlight:: python
+     .. code-block:: python
+
+         tl.static_assert(BLOCK_SIZE == 1024)
+     '''
+     pass
+
+
+ @builtin
+ def device_print(prefix, *args, hex=False, _builder=None):
+     '''
+     Print the values at runtime from the device. String formatting does not work for runtime values, so you should
+     provide the values you want to print as arguments. The first value must be a string, all following values must
+     be scalars or tensors.
+
+     Calling the Python builtin :code:`print` is the same as calling this function, and the requirements for the arguments will match
+     this function (not the normal requirements for :code:`print`).
+
+     .. highlight:: python
+     .. code-block:: python
+
+         tl.device_print("pid", pid)
+         print("pid", pid)
+
+     On CUDA, printfs are streamed through a buffer of limited size (on one host,
+     we measured the default as 6912 KiB, but this may not be consistent across
+     GPUs and CUDA versions). If you notice some printfs are being dropped, you
+     can increase the buffer size by calling
+
+     .. highlight:: python
+     .. code-block:: python
+
+         triton.runtime.driver.active.utils.set_printf_fifo_size(size_bytes)
+
+     CUDA may raise an error if you try to change this value after running a
+     kernel that uses printfs. The value set here may only affect the current
+     device (so if you have multiple GPUs, you'd need to call it multiple times).
+
+     :param prefix: a prefix to print before the values. This is required to be a string literal.
+     :param args: the values to print. They can be any tensor or scalar.
+     :param hex: print all values as hex instead of decimal
+     '''
+     import string
+     prefix = _constexpr_to_value(prefix)
+     assert isinstance(prefix, str), f"{prefix} is not string"
+     b_ascii = True
+     for ch in prefix:
+         if ch not in string.printable:
+             b_ascii = False
+             break
+     assert b_ascii, f"{prefix} is not an ascii string"
+     new_args = []
+     for arg in args:
+         new_args.append(semantic.to_tensor(arg, _builder))
+     return semantic.device_print(prefix, new_args, hex, _builder)
+
+
+ @builtin
+ def device_assert(cond, msg="", _builder=None):
+     '''
+     Assert the condition at runtime from the device. Requires that the environment variable :code:`TRITON_DEBUG`
+     is set to a value besides :code:`0` in order for this to have any effect.
+
+     Using the Python :code:`assert` statement is the same as calling this function, except that the second argument
+     must be provided and must be a string, e.g. :code:`assert pid == 0, "pid != 0"`. The environment variable must
+     be set for this :code:`assert` statement to have any effect.
+
+     .. highlight:: python
+     .. code-block:: python
+
+         tl.device_assert(pid == 0)
+         assert pid == 0, f"pid != 0"
+
+     :param cond: the condition to assert. This is required to be a boolean tensor.
+     :param msg: the message to print if the assertion fails. This is required to be a string literal.
+     '''
+     msg = _constexpr_to_value(msg)
+     return semantic.device_assert(semantic.to_tensor(cond, _builder), msg, _builder)
+
+
+ @builtin
+ def inline_asm_elementwise(asm: str, constraints: str, args: Sequence, dtype: Union[dtype, Sequence[dtype]],
+                            is_pure: bool, pack: int, _builder=None):
+     '''
+     Execute inline assembly over a tensor. Essentially, this is :code:`map`
+     where the function is inline assembly.
+
+     The input tensors :code:`args` are implicitly broadcast to the same shape.
+
+     :code:`dtype` can be a tuple of types, in which case the output is a
+     tuple of tensors.
+
+     Each invocation of the inline asm processes :code:`pack` elements at a
+     time. Exactly which set of inputs a block receives is unspecified.
+     Input elements of size less than 4 bytes are packed into 4-byte
+     registers.
+
+     This op does not support empty :code:`dtype` -- the inline asm must
+     return at least one tensor, even if you don't need it. You can work
+     around this by returning a dummy tensor of arbitrary type; it shouldn't
+     cost you anything if you don't use it.
+
+     Example using
+     `PTX <https://docs.nvidia.com/cuda/parallel-thread-execution/index.html>`_
+     assembly:
+
+     .. highlight:: python
+     .. code-block:: python
+
+         @triton.jit
+         def kernel(A, B, C, D, BLOCK: tl.constexpr):
+             a = tl.load(A + tl.arange(0, BLOCK))  # uint8 tensor
+             b = tl.load(B + tl.arange(0, BLOCK))  # float32 tensor
+
+             # For each (a,b) in zip(a,b), perform the following:
+             # - Let ai be `a` converted to int32.
+             # - Let af be `a` converted to float.
+             # - Let m be the max of af and b.
+             # - Return ai and m.
+             # Do the above 4 elements at a time.
+             (c, d) = tl.inline_asm_elementwise(
+                 asm="""
+                 {
+                     // Unpack `a` into `ai`.
+                     .reg .b8 tmp<4>;
+                     mov.b32 {tmp0, tmp1, tmp2, tmp3}, $8;
+                     cvt.u32.u8 $0, tmp0;
+                     cvt.u32.u8 $1, tmp1;
+                     cvt.u32.u8 $2, tmp2;
+                     cvt.u32.u8 $3, tmp3;
+                 }
+                 // Convert `ai` to float.
+                 cvt.rn.f32.s32 $4, $0;
+                 cvt.rn.f32.s32 $5, $1;
+                 cvt.rn.f32.s32 $6, $2;
+                 cvt.rn.f32.s32 $7, $3;
+                 // Take max of `af` and `b`.
+                 max.f32 $4, $4, $9;
+                 max.f32 $5, $5, $10;
+                 max.f32 $6, $6, $11;
+                 max.f32 $7, $7, $12;
+                 """,
+                 constraints=(
+                     # 8 output registers, namely
+                     #   $0=ai0, $1=ai1, $2=ai2, $3=ai3,
+                     #   $4=m0,  $5=m1,  $6=m2,  $7=m3.
+                     "=r,=r,=r,=r,=r,=r,=r,=r,"
+                     # 5 input registers, namely
+                     #   $8=a,
+                     #   $9=b0, $10=b1, $11=b2, $12=b3.
+                     # The four elements from `a` are all packed into one register.
+                     "r,r,r,r,r"),
+                 args=[a, b],
+                 dtype=(tl.int32, tl.float32),
+                 is_pure=True,
+                 pack=4,
+             )
+             tl.store(C + tl.arange(0, BLOCK), c)
+             tl.store(D + tl.arange(0, BLOCK), d)
+
+     :param asm: assembly to run. Must match target's assembly format.
+     :param constraints: asm constraints in
+         `LLVM format <https://llvm.org/docs/LangRef.html#inline-asm-constraint-string>`_
+     :param args: the input tensors, whose values are passed to the asm block
+     :param dtype: the element type(s) of the returned tensor(s)
+     :param is_pure: if true, the compiler assumes the asm block has no side-effects
+     :param pack: the number of elements to be processed by one instance of inline assembly
+     :param _builder: the builder
+     :return: one tensor or a tuple of tensors of the given dtypes
+     '''
+     asm = _constexpr_to_value(asm)
+     constraints = _constexpr_to_value(constraints)
+     pack = _constexpr_to_value(pack)
+     is_pure = _constexpr_to_value(is_pure)
+
+     # Wrap `dtype` in a tuple if it's not already.
+     try:
+         iter(dtype)  # type: ignore
+         has_multiple_outputs = True
+     except TypeError:
+         has_multiple_outputs = False
+         dtype = (dtype, )  # type: ignore
+
+     dtype = typing.cast(Sequence[_DtypeClass], dtype)
+
+     res_tys = dtype
+     if dispatch_args := [semantic.to_tensor(arg, _builder) for arg in args]:
+         bin_op_type_checking = partial(
+             semantic.binary_op_type_checking_impl,
+             builder=_builder,
+             arithmetic_check=False,
+             allow_lhs_ptr=True,
+             allow_rhs_ptr=True,
+         )
+         broadcast_arg = dispatch_args[0]
+         # Get the broadcast shape over all the arguments
+         for item in dispatch_args:
+             _, broadcast_arg = bin_op_type_checking(item, broadcast_arg)
+         if broadcast_arg.shape:
+             # Change the shape of each argument based on the broadcast shape
+             for i, item in enumerate(dispatch_args):
+                 dispatch_args[i], _ = bin_op_type_checking(item, broadcast_arg)
+             res_tys = [block_type(dt, broadcast_arg.shape) for dt in dtype]
+     handles = [t.handle for t in dispatch_args]
+     call = _builder.create_inline_asm(asm, constraints, handles, [ty.to_ir(_builder) for ty in res_tys], is_pure, pack)
+
+     if not has_multiple_outputs:
+         return tensor(call.get_result(0), res_tys[0])
+     return tuple(tensor(call.get_result(i), ty) for i, ty in enumerate(res_tys))
+
+
+ # -----------------------
+ # Iterators
+ # -----------------------
+
+
+ class static_range:
+     """
+     Iterator analogous to Python's :code:`range`, for use inside :code:`triton.jit` functions.
+
+     .. highlight:: python
+     .. code-block:: python
+
+         @triton.jit
+         def kernel(...):
+             for i in tl.static_range(10):
+                 ...
+
+     :note: This is a special iterator used to implement similar semantics to Python's :code:`range` in the context of
+         :code:`triton.jit` functions. In addition, it also guides the compiler to unroll the loop aggressively.
+     :param arg1: the start value.
+     :param arg2: the end value.
+     :param step: the step value.
+     """
+
+     def __init__(self, arg1, arg2=None, step=None):
+         assert isinstance(arg1, constexpr), f"{arg1} used as tl.static_range start value is not a constexpr"
+         if step is None:
+             self.step = constexpr(1)
+         else:
+             assert isinstance(step, constexpr), f"{step} used as tl.static_range step value is not a constexpr"
+             self.step = step
+         if arg2 is None:
+             self.start = constexpr(0)
+             self.end = arg1
+         else:
+             assert isinstance(arg2, constexpr), f"{arg2} used as tl.static_range end value is not a constexpr"
+             self.start = arg1
+             self.end = arg2
+
+     def __iter__(self):
+         raise RuntimeError("static_range can only be used in @triton.jit'd functions")
+
+     def __next__(self):
+         raise RuntimeError("static_range can only be used in @triton.jit'd functions")
+
+
+ class range:
+     """
+     Iterator analogous to Python's :code:`range`, for use inside :code:`triton.jit` functions.
+
+     .. highlight:: python
+     .. code-block:: python
+
+         @triton.jit
+         def kernel(...):
+             for i in tl.range(10, num_stages=3):
+                 ...
+
+     :note: This is a special iterator used to implement similar semantics to Python's :code:`range` in the context of
+         :code:`triton.jit` functions. In addition, it allows the user to pass extra attributes to the compiler.
+     :param arg1: the start value.
+     :param arg2: the end value.
+     :param step: the step value.
+     :param num_stages: pipeline the loop into this many stages (so there are
+         :code:`num_stages` iterations of the loop in flight at once).
+
+         Note this is subtly different than passing :code:`num_stages` as a
+         kernel argument. The kernel argument only pipelines loads that feed
+         into :code:`dot` operations, while this attribute tries to pipeline most
+         (though not all) loads in this loop.
+     :param loop_unroll_factor: Tells the Triton IR level loop unroller how many
+         times to unroll a for loop that this range is used with. A value of less
+         than 2 implies no unrolling.
+     """
+
+     def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_factor=None):
+         if step is None:
+             self.step = constexpr(1)
+         else:
+             self.step = step
+         if arg2 is None:
+             self.start = constexpr(0)
+             self.end = arg1
+         else:
+             self.start = arg1
+             self.end = arg2
+         self.num_stages = num_stages
+         self.loop_unroll_factor = loop_unroll_factor
+
+     def __iter__(self):
+         raise RuntimeError("tl.range can only be used in @triton.jit'd functions")
+
+     def __next__(self):
+         raise RuntimeError("tl.range can only be used in @triton.jit'd functions")
+
+
+ # -----------------------
+ # Extern functions
+ # -----------------------
+
+
+ def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, ret_shape: tuple,
+              is_pure: bool, _builder=None):
+     '''
+     Dispatch a function to a library.
+
+     :param func: the function to dispatch
+     :param lib_name: the name of the library
+     :param lib_path: the path of the library
+     :param args: the arguments of the function
+     :param arg_type_symbol_dict: a mapping from tuples of argument types to (symbol, return type) pairs
+     :param ret_shape: the shape of the return value
+     :param is_pure: whether the function is pure
+     :param _builder: the builder
+     :return: the return value of the function
+     '''
+     if len(arg_type_symbol_dict) == 0:
+         raise ValueError("arg_type_symbol_dict is empty")
+
+     num_args = len(list(arg_type_symbol_dict.keys())[0])
+     if len(args) != num_args:
+         raise ValueError(f"length of input args does not match. "
+                          f"Expect {num_args}, got {len(args)}")
+
+     arg_types = []
+     arg_list = []
+     for arg in args:
+         if isinstance(arg, tensor):
+             arg_types.append(arg.dtype)
+             arg_list.append(arg.handle)
+         else:
+             arg_types.append(type(arg))
+             arg_list.append(arg)
+     arg_types = tuple(arg_types)
+
+     if arg_types not in arg_type_symbol_dict:
+         raise ValueError(f"input arg type does not match. "
+                          f"Expect one of {arg_type_symbol_dict.keys()}, got {arg_types}")
+     else:
+         symbol = arg_type_symbol_dict[arg_types][0]
+         ret_type = arg_type_symbol_dict[arg_types][1]
+         if ret_shape:
+             ret_type = block_type(ret_type, ret_shape)
+         return tensor(func(lib_name, lib_path, symbol, arg_list, ret_type.to_ir(_builder), is_pure), ret_type)
+
+
+ @builtin
+ def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, is_pure: bool,
+                        _builder=None):
+     '''
+     Dispatch an elementwise function to a library.
+
+     :param lib_name: the name of the library
+     :param lib_path: the path of the library
+     :param args: the arguments of the function
+     :param arg_type_symbol_dict: a mapping from tuples of argument types to (symbol, return type) pairs
+     :param is_pure: whether the function is pure
+     :param _builder: the builder
+     :return: the return value of the function
+     '''
+     dispatch_args = args.copy()
+     all_scalar = True
+     ret_shape = None
+     arg_types = []
+     for i in builtins.range(len(dispatch_args)):
+         dispatch_args[i] = semantic.to_tensor(dispatch_args[i], _builder)
+         arg_types.append(dispatch_args[i].dtype)
+         if dispatch_args[i].type.is_block():
+             all_scalar = False
+     if len(arg_types) > 0:
+         arg_types = tuple(arg_types)
+         arithmetic_check = True
+         # If there's a type tuple that is not supported by the library, we will do arithmetic check
+         if arg_types in arg_type_symbol_dict:
+             arithmetic_check = False
+         broadcast_arg = dispatch_args[0]
+         # Get the broadcast shape over all the arguments
+         for item in dispatch_args:
+             _, broadcast_arg = semantic.binary_op_type_checking_impl(item, broadcast_arg, _builder,
+                                                                      arithmetic_check=arithmetic_check)
+         # Change the shape of each argument based on the broadcast shape
+         for i in builtins.range(len(dispatch_args)):
+             dispatch_args[i], _ = semantic.binary_op_type_checking_impl(dispatch_args[i], broadcast_arg, _builder,
+                                                                         arithmetic_check=arithmetic_check)
+         if not all_scalar:
+             ret_shape = broadcast_arg.shape
+     func = _builder.create_extern_elementwise
+     return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_shape, is_pure, _builder)
+
+
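For a concrete picture of `arg_type_symbol_dict`, here is a hedged sketch in the style of Triton's libdevice wrappers, using the :code:`extern` decorator defined below (the symbol name `__nv_sqrtf` and the empty library name/path are assumptions about the NVIDIA backend; real wrappers may differ):

.. code-block:: python

    import triton.language as tl

    @tl.extern
    def sqrt_rn(arg0, _builder=None):
        # Maps the (float32,) argument-type tuple to the (symbol, return type) pair.
        return tl.extern_elementwise(
            "", "", [arg0],
            {(tl.float32,): ("__nv_sqrtf", tl.float32)},
            is_pure=True, _builder=_builder)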
+ def binary_op_type_legalization(lhs, rhs, builder):
+     '''
+     Convert both operands to a single common type.
+
+     :param lhs: the left operand
+     :param rhs: the right operand
+     :param builder: the builder
+     '''
+     return semantic.binary_op_type_checking_impl(lhs, rhs, builder)
+
+
+ def extern(fn):
+     """A decorator for external functions."""
+     return builtin(fn)