triton-windows 3.3.1.post19__cp310-cp310-win_amd64.whl → 3.5.0.post21__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic; see the original diff page for more details.

Files changed (225)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +11 -2
  3. triton/_filecheck.py +97 -0
  4. triton/_internal_testing.py +95 -18
  5. triton/_utils.py +112 -21
  6. triton/backends/__init__.py +20 -23
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +161 -119
  9. triton/backends/amd/driver.c +118 -46
  10. triton/backends/amd/driver.py +274 -96
  11. triton/backends/compiler.py +7 -21
  12. triton/backends/driver.py +13 -0
  13. triton/backends/nvidia/bin/ptxas.exe +0 -0
  14. triton/backends/nvidia/compiler.py +163 -106
  15. triton/backends/nvidia/driver.c +166 -101
  16. triton/backends/nvidia/driver.py +384 -202
  17. triton/compiler/__init__.py +5 -2
  18. triton/compiler/code_generator.py +439 -231
  19. triton/compiler/compiler.py +152 -84
  20. triton/experimental/__init__.py +0 -0
  21. triton/experimental/gluon/__init__.py +5 -0
  22. triton/experimental/gluon/_compiler.py +0 -0
  23. triton/experimental/gluon/_runtime.py +102 -0
  24. triton/experimental/gluon/language/__init__.py +119 -0
  25. triton/experimental/gluon/language/_core.py +490 -0
  26. triton/experimental/gluon/language/_layouts.py +583 -0
  27. triton/experimental/gluon/language/_math.py +20 -0
  28. triton/experimental/gluon/language/_semantic.py +380 -0
  29. triton/experimental/gluon/language/_standard.py +80 -0
  30. triton/experimental/gluon/language/amd/__init__.py +4 -0
  31. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  32. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  33. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  34. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  35. triton/experimental/gluon/language/extra/__init__.py +3 -0
  36. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  37. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  38. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  39. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  40. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
  41. triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
  42. triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
  43. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
  44. triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
  45. triton/experimental/gluon/nvidia/__init__.py +4 -0
  46. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  47. triton/experimental/gluon/nvidia/hopper.py +45 -0
  48. triton/knobs.py +546 -0
  49. triton/language/__init__.py +50 -19
  50. triton/language/core.py +909 -572
  51. triton/language/extra/cuda/__init__.py +10 -7
  52. triton/language/extra/cuda/gdc.py +42 -0
  53. triton/language/extra/cuda/libdevice.py +394 -394
  54. triton/language/extra/cuda/utils.py +21 -21
  55. triton/language/extra/hip/__init__.py +3 -1
  56. triton/language/extra/hip/libdevice.py +120 -104
  57. triton/language/extra/hip/utils.py +35 -0
  58. triton/language/extra/libdevice.py +4 -0
  59. triton/language/math.py +65 -66
  60. triton/language/random.py +12 -2
  61. triton/language/semantic.py +1757 -1768
  62. triton/language/standard.py +127 -62
  63. triton/language/target_info.py +54 -0
  64. triton/runtime/_allocation.py +15 -3
  65. triton/runtime/_async_compile.py +55 -0
  66. triton/runtime/autotuner.py +117 -60
  67. triton/runtime/build.py +83 -17
  68. triton/runtime/cache.py +61 -47
  69. triton/runtime/driver.py +25 -47
  70. triton/runtime/interpreter.py +95 -50
  71. triton/runtime/jit.py +445 -248
  72. triton/runtime/tcc/include/_mingw.h +8 -10
  73. triton/runtime/tcc/include/assert.h +5 -0
  74. triton/runtime/tcc/include/errno.h +1 -1
  75. triton/runtime/tcc/include/float.h +21 -3
  76. triton/runtime/tcc/include/iso646.h +36 -0
  77. triton/runtime/tcc/include/limits.h +5 -0
  78. triton/runtime/tcc/include/malloc.h +2 -2
  79. triton/runtime/tcc/include/math.h +21 -261
  80. triton/runtime/tcc/include/stdalign.h +16 -0
  81. triton/runtime/tcc/include/stdarg.h +5 -70
  82. triton/runtime/tcc/include/stdatomic.h +171 -0
  83. triton/runtime/tcc/include/stddef.h +7 -19
  84. triton/runtime/tcc/include/stdlib.h +15 -4
  85. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  86. triton/runtime/tcc/include/sys/stat.h +2 -2
  87. triton/runtime/tcc/include/sys/types.h +5 -0
  88. triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
  89. triton/runtime/tcc/include/tccdefs.h +342 -0
  90. triton/runtime/tcc/include/tgmath.h +89 -0
  91. triton/runtime/tcc/include/uchar.h +33 -0
  92. triton/runtime/tcc/include/unistd.h +1 -0
  93. triton/runtime/tcc/include/winapi/qos.h +72 -0
  94. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  95. triton/runtime/tcc/include/winapi/winbase.h +9 -2
  96. triton/runtime/tcc/include/winapi/wincon.h +8 -0
  97. triton/runtime/tcc/include/winapi/windows.h +1 -1
  98. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  99. triton/runtime/tcc/include/winapi/winnt.h +9 -7
  100. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  101. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  102. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  103. triton/runtime/tcc/lib/libtcc1.a +0 -0
  104. triton/runtime/tcc/lib/python314.def +1800 -0
  105. triton/runtime/tcc/lib/python314t.def +1809 -0
  106. triton/runtime/tcc/libtcc.dll +0 -0
  107. triton/runtime/tcc/tcc.exe +0 -0
  108. triton/testing.py +16 -12
  109. triton/tools/compile.py +62 -14
  110. triton/tools/disasm.py +3 -4
  111. triton/tools/extra/cuda/compile.c +1 -0
  112. triton/tools/extra/hip/compile.cpp +66 -0
  113. triton/tools/extra/hip/compile.h +13 -0
  114. triton/tools/ragged_tma.py +92 -0
  115. triton/tools/tensor_descriptor.py +34 -0
  116. triton/windows_utils.py +52 -81
  117. {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +8 -4
  118. triton_windows-3.5.0.post21.dist-info/RECORD +217 -0
  119. triton_windows-3.5.0.post21.dist-info/entry_points.txt +3 -0
  120. triton_windows-3.5.0.post21.dist-info/licenses/LICENSE +23 -0
  121. triton_windows-3.5.0.post21.dist-info/top_level.txt +1 -0
  122. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
  123. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
  124. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
  125. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
  126. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
  127. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
  128. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
  129. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
  130. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
  131. triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
  132. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
  133. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
  134. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
  135. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
  136. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
  137. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
  138. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
  139. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
  140. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
  141. triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
  142. triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
  143. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
  144. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
  145. triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
  146. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
  147. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
  148. triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
  149. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
  150. triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
  151. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
  152. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
  153. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
  154. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
  155. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
  156. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
  157. triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
  158. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
  159. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
  160. triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
  161. triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
  162. triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
  163. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
  164. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
  165. triton/backends/amd/include/hip/channel_descriptor.h +0 -39
  166. triton/backends/amd/include/hip/device_functions.h +0 -38
  167. triton/backends/amd/include/hip/driver_types.h +0 -468
  168. triton/backends/amd/include/hip/hip_bf16.h +0 -36
  169. triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
  170. triton/backends/amd/include/hip/hip_common.h +0 -100
  171. triton/backends/amd/include/hip/hip_complex.h +0 -38
  172. triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
  173. triton/backends/amd/include/hip/hip_deprecated.h +0 -95
  174. triton/backends/amd/include/hip/hip_ext.h +0 -161
  175. triton/backends/amd/include/hip/hip_fp16.h +0 -36
  176. triton/backends/amd/include/hip/hip_fp8.h +0 -33
  177. triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
  178. triton/backends/amd/include/hip/hip_hcc.h +0 -24
  179. triton/backends/amd/include/hip/hip_math_constants.h +0 -36
  180. triton/backends/amd/include/hip/hip_profile.h +0 -27
  181. triton/backends/amd/include/hip/hip_runtime.h +0 -75
  182. triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
  183. triton/backends/amd/include/hip/hip_texture_types.h +0 -29
  184. triton/backends/amd/include/hip/hip_vector_types.h +0 -41
  185. triton/backends/amd/include/hip/hip_version.h +0 -17
  186. triton/backends/amd/include/hip/hiprtc.h +0 -421
  187. triton/backends/amd/include/hip/library_types.h +0 -78
  188. triton/backends/amd/include/hip/math_functions.h +0 -42
  189. triton/backends/amd/include/hip/surface_types.h +0 -63
  190. triton/backends/amd/include/hip/texture_types.h +0 -194
  191. triton/backends/amd/include/hsa/Brig.h +0 -1131
  192. triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
  193. triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
  194. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
  195. triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
  196. triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
  197. triton/backends/amd/include/hsa/hsa.h +0 -5738
  198. triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
  199. triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
  200. triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
  201. triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
  202. triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
  203. triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
  204. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
  205. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
  206. triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
  207. triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
  208. triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
  209. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
  210. triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
  211. triton/backends/amd/include/roctracer/roctracer.h +0 -779
  212. triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
  213. triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
  214. triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
  215. triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
  216. triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
  217. triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
  218. triton/backends/amd/include/roctracer/roctx.h +0 -229
  219. triton/language/_utils.py +0 -21
  220. triton/language/extra/cuda/_experimental_tma.py +0 -106
  221. triton/runtime/tcc/lib/libtcc1-64.a +0 -0
  222. triton/tools/experimental_descriptor.py +0 -32
  223. triton_windows-3.3.1.post19.dist-info/RECORD +0 -260
  224. triton_windows-3.3.1.post19.dist-info/top_level.txt +0 -14
  225. {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
triton/language/core.py CHANGED
@@ -1,19 +1,20 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import math
3
4
  from warnings import warn
4
5
  from contextlib import contextmanager
5
6
  from enum import Enum
6
7
  from functools import partial, wraps
7
8
  import typing
8
9
  from typing import Union, Callable, List, Sequence, TypeVar, Optional, Tuple
10
+ from dataclasses import dataclass
9
11
  import builtins
10
- from ..runtime.jit import jit
12
+ from .. import knobs
13
+ from ..runtime.jit import JITCallable
11
14
  import inspect
12
- import os
13
15
 
14
16
  from .._C.libtriton import ir
15
- from . import semantic
16
- from ._utils import TRITON_MAX_TENSOR_NUMEL, validate_block_shape
17
+ from .._utils import TRITON_MAX_TENSOR_NUMEL, validate_block_shape, get_primitive_bitwidth
17
18
 
18
19
  T = TypeVar('T')
19
20
 
@@ -22,15 +23,23 @@ TRITON_BUILTIN = "__triton_builtin__"
22
23
  PropagateNan = ir.PROPAGATE_NAN
23
24
 
24
25
 
26
+ def must_use_result(x, s=True):
27
+ """If the result of this function is unused, throw an error."""
28
+ if isinstance(x, str):
29
+ return (lambda fn: must_use_result(fn, x))
30
+ x._must_use_result = s
31
+ return x
32
+
33
+
25
34
  def builtin(fn: T) -> T:
26
35
  """Mark a function as a builtin."""
27
36
  assert callable(fn)
28
37
 
29
38
  @wraps(fn)
30
39
  def wrapper(*args, **kwargs):
31
- if "_builder" not in kwargs or kwargs["_builder"] is None:
40
+ if "_semantic" not in kwargs or kwargs["_semantic"] is None:
32
41
  raise ValueError("Did you forget to add @triton.jit ? "
33
- "(`_builder` argument must be provided outside of JIT functions.)")
42
+ "(`_semantic` argument must be provided outside of JIT functions.)")
34
43
  return fn(*args, **kwargs)
35
44
 
36
45
  setattr(wrapper, TRITON_BUILTIN, True)
@@ -53,8 +62,8 @@ def _tensor_member_fn(fn: T) -> T:
53
62
  """
54
63
  assert callable(fn)
55
64
  orig_sig = inspect.signature(fn)
56
- # Does fn take args other than _builder, _generator, and the tensor itself?
57
- has_args = len(orig_sig.parameters.keys() - {"_builder", "_generator"}) > 1
65
+ # Does fn take args other than _semantic, _generator, and the tensor itself?
66
+ has_args = len(orig_sig.parameters.keys() - {"_semantic", "_generator"}) > 1
58
67
 
59
68
  if not fn.__doc__:
60
69
  fn.__doc__ = ""
@@ -78,7 +87,7 @@ def _tensor_member_fn(fn: T) -> T:
78
87
  if is_builtin(fn):
79
88
  setattr(wrapper, TRITON_BUILTIN, True)
80
89
 
81
- setattr(tensor, fn.__name__, wrapper)
90
+ setattr(tensor, fn.__name__, fn if isinstance(fn, JITCallable) else wrapper)
82
91
  return fn
83
92
 
84
93
 
@@ -110,8 +119,8 @@ def is_builtin(fn) -> bool:
110
119
 
111
120
 
112
121
  @builtin
113
- def to_tensor(x, _builder=None):
114
- return semantic.to_tensor(x, _builder)
122
+ def to_tensor(x, _semantic=None):
123
+ return _semantic.to_tensor(x)
115
124
 
116
125
 
117
126
  # -----------------------
@@ -130,90 +139,153 @@ class const:
130
139
  pass
131
140
 
132
141
 
133
- class constexpr:
142
+ class base_value:
143
+ """Base class of values that exist in the triton IR (i.e. not constexprs).
144
+ """
145
+ type: base_type
146
+
147
+ def _flatten_ir(self, handles: List[ir.value]) -> None:
148
+ """Flatten frontend value into a sequence of mlir handles, which are appended
149
+ to the output list
150
+ """
151
+ raise NotImplementedError
152
+
153
+
154
+ class base_type:
155
+
156
+ def __eq__(self, other) -> bool:
157
+ raise NotImplementedError("Types must implement __eq__")
158
+
159
+ def __ne__(self, other) -> bool:
160
+ return not (self == other)
161
+
162
+ def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[base_value, int]:
163
+ """Build a frontend value with the current dtype, wrapping a list of existing handles.
164
+ cursor is the index of the first handle relevant to this value, and the function
165
+ should return the updated cursor position after any handles consumed by the created value.
166
+ """
167
+ raise NotImplementedError
168
+
169
+ def mangle(self) -> str:
170
+ raise NotImplementedError(f"NYI: Type mangling for type {self.__class__}")
171
+
172
+ def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
173
+ raise NotImplementedError
174
+
175
+
176
+ class constexpr_type(base_type):
177
+
178
+ def __init__(self, value):
179
+ self.value = value
180
+
181
+ def __eq__(self, other):
182
+ return isinstance(other, constexpr_type) and self.value == other.value
183
+
184
+ def __repr__(self) -> str:
185
+ return f"constexpr_type[{self.value}]"
186
+
187
+ def __hash__(self):
188
+ return hash(self.value)
189
+
190
+ def mangle(self) -> str:
191
+ return repr(self)
192
+
193
+ def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
194
+ return
195
+
196
+ def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[base_value, int]:
197
+ return constexpr(self.value), cursor
198
+
199
+
200
+ class constexpr(base_value):
134
201
  """
135
202
  This class is used to store a value that is known at compile-time.
136
203
  """
137
204
 
138
205
  def __init__(self, value):
139
- if isinstance(value, constexpr):
140
- self.value = value.value
141
- else:
142
- self.value = value
143
- self.type = constexpr
206
+ while isinstance(value, constexpr):
207
+ value = value.value
208
+ self.value = value
209
+ self.type = constexpr_type(value)
144
210
 
145
211
  def __repr__(self) -> str:
146
212
  return f"constexpr[{self.value}]"
147
213
 
214
+ def __hash__(self):
215
+ return hash((self.value, self.type))
216
+
217
+ def _flatten_ir(self, handles: List[ir.value]) -> None:
218
+ return
219
+
148
220
  def __index__(self):
149
221
  return self.value
150
222
 
151
223
  # In interpreter mode, constant values are not wrapped in constexpr,
152
224
  # and therefore do not have a .value attribute.
153
- # As a result, from here and below, we need to call the _constexpr_to_value
225
+ # As a result, from here and below, we need to call the _unwrap_if_constexpr
154
226
  # function to obtain either constexpr.value or the value itself.
155
227
  def __add__(self, other):
156
- return constexpr(self.value + _constexpr_to_value(other))
228
+ return constexpr(self.value + _unwrap_if_constexpr(other))
157
229
 
158
230
  def __radd__(self, other):
159
- return constexpr(_constexpr_to_value(other) + self.value)
231
+ return constexpr(_unwrap_if_constexpr(other) + self.value)
160
232
 
161
233
  def __sub__(self, other):
162
- return constexpr(self.value - _constexpr_to_value(other))
234
+ return constexpr(self.value - _unwrap_if_constexpr(other))
163
235
 
164
236
  def __rsub__(self, other):
165
- return constexpr(_constexpr_to_value(other) - self.value)
237
+ return constexpr(_unwrap_if_constexpr(other) - self.value)
166
238
 
167
239
  def __mul__(self, other):
168
- return constexpr(self.value * _constexpr_to_value(other))
240
+ return constexpr(self.value * _unwrap_if_constexpr(other))
169
241
 
170
242
  def __mod__(self, other):
171
- return constexpr(self.value % _constexpr_to_value(other))
243
+ return constexpr(self.value % _unwrap_if_constexpr(other))
172
244
 
173
245
  def __rmul__(self, other):
174
- return constexpr(_constexpr_to_value(other) * self.value)
246
+ return constexpr(_unwrap_if_constexpr(other) * self.value)
175
247
 
176
248
  def __truediv__(self, other):
177
- return constexpr(self.value / _constexpr_to_value(other))
249
+ return constexpr(self.value / _unwrap_if_constexpr(other))
178
250
 
179
251
  def __rtruediv__(self, other):
180
- return constexpr(_constexpr_to_value(other) / self.value)
252
+ return constexpr(_unwrap_if_constexpr(other) / self.value)
181
253
 
182
254
  def __floordiv__(self, other):
183
- return constexpr(self.value // _constexpr_to_value(other))
255
+ return constexpr(self.value // _unwrap_if_constexpr(other))
184
256
 
185
257
  def __rfloordiv__(self, other):
186
- return constexpr(_constexpr_to_value(other) // self.value)
258
+ return constexpr(_unwrap_if_constexpr(other) // self.value)
187
259
 
188
260
  def __gt__(self, other):
189
- return constexpr(self.value > _constexpr_to_value(other))
261
+ return constexpr(self.value > _unwrap_if_constexpr(other))
190
262
 
191
263
  def __rgt__(self, other):
192
- return constexpr(_constexpr_to_value(other) > self.value)
264
+ return constexpr(_unwrap_if_constexpr(other) > self.value)
193
265
 
194
266
  def __ge__(self, other):
195
- return constexpr(self.value >= _constexpr_to_value(other))
267
+ return constexpr(self.value >= _unwrap_if_constexpr(other))
196
268
 
197
269
  def __rge__(self, other):
198
- return constexpr(_constexpr_to_value(other) >= self.value)
270
+ return constexpr(_unwrap_if_constexpr(other) >= self.value)
199
271
 
200
272
  def __lt__(self, other):
201
- return constexpr(self.value < _constexpr_to_value(other))
273
+ return constexpr(self.value < _unwrap_if_constexpr(other))
202
274
 
203
275
  def __rlt__(self, other):
204
- return constexpr(_constexpr_to_value(other) < self.value)
276
+ return constexpr(_unwrap_if_constexpr(other) < self.value)
205
277
 
206
278
  def __le__(self, other):
207
- return constexpr(self.value <= _constexpr_to_value(other))
279
+ return constexpr(self.value <= _unwrap_if_constexpr(other))
208
280
 
209
281
  def __rle__(self, other):
210
- return constexpr(_constexpr_to_value(other) <= self.value)
282
+ return constexpr(_unwrap_if_constexpr(other) <= self.value)
211
283
 
212
284
  def __eq__(self, other):
213
- return constexpr(self.value == _constexpr_to_value(other))
285
+ return constexpr(self.value == _unwrap_if_constexpr(other))
214
286
 
215
287
  def __ne__(self, other):
216
- return constexpr(self.value != _constexpr_to_value(other))
288
+ return constexpr(self.value != _unwrap_if_constexpr(other))
217
289
 
218
290
  def __bool__(self):
219
291
  return bool(self.value)
@@ -222,19 +294,19 @@ class constexpr:
222
294
  return constexpr(-self.value)
223
295
 
224
296
  def __and__(self, other):
225
- return constexpr(self.value & _constexpr_to_value(other))
297
+ return constexpr(self.value & _unwrap_if_constexpr(other))
226
298
 
227
299
  def logical_and(self, other):
228
- return constexpr(self.value and _constexpr_to_value(other))
300
+ return constexpr(self.value and _unwrap_if_constexpr(other))
229
301
 
230
302
  def __or__(self, other):
231
- return constexpr(self.value | _constexpr_to_value(other))
303
+ return constexpr(self.value | _unwrap_if_constexpr(other))
232
304
 
233
305
  def __xor__(self, other):
234
- return constexpr(self.value ^ _constexpr_to_value(other))
306
+ return constexpr(self.value ^ _unwrap_if_constexpr(other))
235
307
 
236
308
  def logical_or(self, other):
237
- return constexpr(self.value or _constexpr_to_value(other))
309
+ return constexpr(self.value or _unwrap_if_constexpr(other))
238
310
 
239
311
  def __pos__(self):
240
312
  return constexpr(+self.value)
@@ -243,16 +315,16 @@ class constexpr:
243
315
  return constexpr(~self.value)
244
316
 
245
317
  def __pow__(self, other):
246
- return constexpr(self.value**_constexpr_to_value(other))
318
+ return constexpr(self.value**_unwrap_if_constexpr(other))
247
319
 
248
320
  def __rpow__(self, other):
249
- return constexpr(_constexpr_to_value(other)**self.value)
321
+ return constexpr(_unwrap_if_constexpr(other)**self.value)
250
322
 
251
323
  def __rshift__(self, other):
252
- return constexpr(self.value >> _constexpr_to_value(other))
324
+ return constexpr(self.value >> _unwrap_if_constexpr(other))
253
325
 
254
326
  def __lshift__(self, other):
255
- return constexpr(self.value << _constexpr_to_value(other))
327
+ return constexpr(self.value << _unwrap_if_constexpr(other))
256
328
 
257
329
  def __not__(self):
258
330
  return constexpr(not self.value)
@@ -263,14 +335,31 @@ class constexpr:
263
335
  def __call__(self, *args, **kwds):
264
336
  return self.value(*args, **kwds)
265
337
 
338
+ def __getitem__(self, *args):
339
+ args = (_unwrap_if_constexpr(x) for x in _normalize_tuple(args))
340
+ return self.value.__getitem__(*args)
341
+
266
342
 
267
343
  CONSTEXPR_0 = constexpr(0)
268
344
 
269
345
 
270
346
  def _unwrap_if_constexpr(o):
347
+ if isinstance(o, list):
348
+ return [_unwrap_if_constexpr(x) for x in o]
349
+ if isinstance(o, builtins.tuple):
350
+ return builtins.tuple(_unwrap_if_constexpr(x) for x in o)
351
+ if isinstance(o, tuple):
352
+ return tuple(_unwrap_if_constexpr(x) for x in o)
271
353
  return o.value if isinstance(o, constexpr) else o
272
354
 
273
355
 
356
+ def _normalize_tuple(t):
357
+ normalized_tuple = _unwrap_if_constexpr(t)
358
+ if isinstance(normalized_tuple, (list, builtins.tuple)):
359
+ normalized_tuple = tuple(normalized_tuple)
360
+ return normalized_tuple
361
+
362
+
274
363
  def check_bit_width(value, shift_value):
275
364
  if isinstance(value, tensor) and isinstance(shift_value, constexpr):
276
365
  bitwidth = value.type.scalar.primitive_bitwidth
@@ -280,34 +369,6 @@ def check_bit_width(value, shift_value):
280
369
  )
281
370
 
282
371
 
283
- class base_value:
284
- """Base class of values that exist in the triton IR (i.e. not constexprs).
285
- """
286
- type: base_type
287
-
288
- def _flatten_ir(self, handles: List[ir.value]) -> None:
289
- """Flatten frontend value into a sequence of mlir handles, which are appended
290
- to the output list
291
- """
292
- raise NotImplementedError
293
-
294
-
295
- class base_type:
296
-
297
- def __eq__(self, other):
298
- raise NotImplementedError("Types must implement __eq__")
299
-
300
- def __ne__(self, other):
301
- return not (self == other)
302
-
303
- def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[base_value, int]:
304
- """Build a frontend value with the current dtype, wrapping a list of existing handles.
305
- cursor is the index of the first handle relevant to this value, and the function
306
- should return the updated cursor position after any handles consumed by the created value.
307
- """
308
- raise NotImplementedError
309
-
310
-
311
372
  # -----------------------
312
373
  # dtype
313
374
  # -----------------------
@@ -333,55 +394,44 @@ class dtype(base_type):
333
394
  name = _unwrap_if_constexpr(name)
334
395
  self.name = name
335
396
  assert name in dtype.SINT_TYPES + dtype.UINT_TYPES + dtype.FP_TYPES + dtype.OTHER_TYPES, name
397
+ self.primitive_bitwidth = get_primitive_bitwidth(name)
398
+ self.itemsize = self.primitive_bitwidth // 8
336
399
  if name in dtype.SINT_TYPES:
337
400
  self.int_signedness = dtype.SIGNEDNESS.SIGNED
338
- self.int_bitwidth = int(name.split('int')[-1])
339
- self.primitive_bitwidth = self.int_bitwidth
401
+ self.int_bitwidth = self.primitive_bitwidth
340
402
  elif name in dtype.UINT_TYPES:
341
403
  self.int_signedness = dtype.SIGNEDNESS.UNSIGNED
342
- self.int_bitwidth = int(name.split('int')[-1])
343
- self.primitive_bitwidth = self.int_bitwidth
404
+ self.int_bitwidth = self.primitive_bitwidth
344
405
  elif name in dtype.FP_TYPES:
345
406
  if name == 'fp8e4b15':
346
407
  self.fp_mantissa_width = 3
347
- self.primitive_bitwidth = 8
348
408
  self.exponent_bias = 15
349
409
  elif name == 'fp8e4nv':
350
410
  self.fp_mantissa_width = 3
351
- self.primitive_bitwidth = 8
352
411
  self.exponent_bias = 7
353
412
  elif name == 'fp8e4b8':
354
413
  self.fp_mantissa_width = 3
355
- self.primitive_bitwidth = 8
356
414
  self.exponent_bias = 8
357
415
  elif name == 'fp8e5':
358
416
  self.fp_mantissa_width = 2
359
- self.primitive_bitwidth = 8
360
417
  self.exponent_bias = 15
361
418
  elif name == 'fp8e5b16':
362
419
  self.fp_mantissa_width = 2
363
- self.primitive_bitwidth = 8
364
420
  self.exponent_bias = 16
365
421
  elif name == 'fp16':
366
422
  self.fp_mantissa_width = 10
367
- self.primitive_bitwidth = 16
368
423
  self.exponent_bias = 15
369
424
  elif name == 'bf16':
370
425
  self.fp_mantissa_width = 7
371
- self.primitive_bitwidth = 16
372
426
  self.exponent_bias = 127
373
427
  elif name == 'fp32':
374
428
  self.fp_mantissa_width = 23
375
- self.primitive_bitwidth = 32
376
429
  self.exponent_bias = 127
377
430
  elif name == 'fp64':
378
431
  self.fp_mantissa_width = 52
379
- self.primitive_bitwidth = 64
380
432
  self.exponent_bias = 1023
381
433
  else:
382
434
  raise RuntimeError(f'Unsupported floating-point type {name}')
383
- elif name == 'void':
384
- self.primitive_bitwidth = 0
385
435
 
386
436
  def is_fp8(self):
387
437
  return 'fp8' in self.name
@@ -502,11 +552,8 @@ class dtype(base_type):
502
552
  def is_const():
503
553
  return False
504
554
 
505
- @staticmethod
506
- def is_tuple():
507
- return False
508
-
509
- def __eq__(self, other: dtype):
555
+ def __eq__(self, other) -> bool:
556
+ other = _unwrap_if_constexpr(other)
510
557
  if not isinstance(other, dtype):
511
558
  return False
512
559
  return self.name == other.name
@@ -518,13 +565,14 @@ class dtype(base_type):
518
565
  def scalar(self):
519
566
  return self
520
567
 
568
+ def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
569
+ out.append(self.to_ir(builder))
570
+
521
571
  def to_ir(self, builder: ir.builder) -> ir.type:
522
572
  if self.name.startswith("fp8"):
523
573
  if self.name not in builder.options.supported_fp8_dtypes:
524
574
  raise ValueError(f'type {self} not supported in this architecture. '
525
575
  f'The supported fp8 dtypes are {builder.options.supported_fp8_dtypes}')
526
- if self.name in builder.options.deprecated_fp8_dtypes:
527
- warn(f"{self.name} is deprecated in this architecture and will be removed in a future triton release")
528
576
 
529
577
  if self.name == 'void':
530
578
  return builder.get_void_ty()
@@ -581,6 +629,21 @@ class dtype(base_type):
581
629
  def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[base_value, int]:
582
630
  return tensor(handles[cursor], self), cursor + 1
583
631
 
632
+ def mangle(self) -> str:
633
+ if self.is_int():
634
+ SIGNED = dtype.SIGNEDNESS.SIGNED
635
+ prefix = 'i' if self.int_signedness == SIGNED else 'u'
636
+ return prefix + str(self.int_bitwidth)
637
+ if self.is_floating():
638
+ return str(self)
639
+ if self.is_void():
640
+ return 'V'
641
+ return super().mangle()
642
+
643
+ def with_element_ty(self, element_ty: dtype):
644
+ assert not self.is_block()
645
+ return element_ty
646
+
584
647
 
585
648
  # Some functions have a param named `dtype`, which shadows the `dtype` class.
586
649
  # We can't change the param name because it is part of function's public API.
@@ -614,7 +677,8 @@ class pointer_type(dtype):
614
677
  def is_const(self):
615
678
  return self.const
616
679
 
617
- def __eq__(self, other: pointer_type) -> bool:
680
+ def __eq__(self, other) -> bool:
681
+ other = _unwrap_if_constexpr(other)
618
682
  if not isinstance(other, pointer_type):
619
683
  return False
620
684
  return self.element_ty == other.element_ty and self.address_space == other.address_space and self.const == other.const
@@ -623,12 +687,8 @@ class pointer_type(dtype):
623
687
  def scalar(self):
624
688
  return self
625
689
 
626
-
627
- class nv_tma_desc_type(pointer_type):
628
-
629
- def __init__(self, const=True, address_space=0):
630
- super().__init__(uint8, const=const, address_space=address_space)
631
- self.name = 'nv_tma_desc_type'
690
+ def mangle(self) -> str:
691
+ return f"P{self.element_ty.mangle()}"
632
692
 
633
693
 
634
694
  class block_type(dtype):
@@ -660,9 +720,12 @@ class block_type(dtype):
660
720
  def is_block(self):
661
721
  return True
662
722
 
663
- def get_block_shapes(self) -> List[int]:
723
+ def get_block_shapes(self) -> Tuple[int]:
664
724
  return self.shape
665
725
 
726
+ def with_element_ty(self, scalar_ty: dtype) -> block_type:
727
+ return block_type(scalar_ty, self.shape)
728
+
666
729
  def __eq__(self, other) -> bool:
667
730
  if not isinstance(other, block_type):
668
731
  return False
@@ -672,6 +735,15 @@ class block_type(dtype):
672
735
  def scalar(self):
673
736
  return self.element_ty
674
737
 
738
+ @property
739
+ def nbytes(self):
740
+ return self.numel * (self.element_ty.primitive_bitwidth // 8)
741
+
742
+ def mangle(self) -> str:
743
+ elt = self.scalar.mangle()
744
+ shape = '_'.join(map(str, self.shape))
745
+ return f'{elt}S{shape}S'
746
+
675
747
 
676
748
  class tuple_type(base_type):
677
749
 
@@ -686,15 +758,14 @@ class tuple_type(base_type):
686
758
  def __iter__(self):
687
759
  return iter(self.types)
688
760
 
689
- def to_ir(self, builder: ir.builder):
690
- return [ty.to_ir(builder) for ty in self.types]
761
+ def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]):
762
+ for ty in self.types:
763
+ if not isinstance(ty, constexpr):
764
+ ty._flatten_ir_types(builder, out)
691
765
 
692
766
  def __getitem__(self, index: int) -> dtype:
693
767
  return self.types[index]
694
768
 
695
- def is_tuple(self):
696
- return True
697
-
698
769
  def __eq__(self, other):
699
770
  return type(self) is type(other) and self.types == other.types and self.fields == other.fields
700
771
 
@@ -705,6 +776,9 @@ class tuple_type(base_type):
705
776
  values.append(value)
706
777
  return tuple(values, self), cursor
707
778
 
779
+ def mangle(self):
780
+ return 'T' + '_'.join(ty.mangle for ty in self.types) + 'T'
781
+
708
782
 
709
783
  class slice_type(dtype):
710
784
 
@@ -791,10 +865,7 @@ class tensor(base_value):
791
865
  self.handle = handle
792
866
  # Block shape
793
867
  self.shape = type.shape if type.is_block() else ()
794
- self.numel = 1
795
- for s in self.shape:
796
- self.numel *= s
797
- self.numel = constexpr(self.numel)
868
+ self.numel = constexpr(math.prod(self.shape))
798
869
  self.type = type # Tensor type (can be block_type)
799
870
  # Following the practice in pytorch, dtype is scalar type
800
871
  self.dtype = type.scalar
@@ -808,224 +879,224 @@ class tensor(base_value):
808
879
  return str(self.dtype) + '[' + ', '.join(str(s) for s in self.shape) + ']'
809
880
 
810
881
  @builtin
811
- def __add__(self, other, _builder=None):
812
- return add(self, other, sanitize_overflow=True, _builder=_builder)
882
+ def __add__(self, other, _semantic=None):
883
+ return add(self, other, sanitize_overflow=True, _semantic=_semantic)
813
884
 
814
885
  @builtin
815
- def __radd__(self, other, _builder=None):
816
- return add(other, self, sanitize_overflow=True, _builder=_builder)
886
+ def __radd__(self, other, _semantic=None):
887
+ return add(other, self, sanitize_overflow=True, _semantic=_semantic)
817
888
 
818
889
  @builtin
819
- def __sub__(self, other, _builder=None):
820
- return sub(self, other, sanitize_overflow=True, _builder=_builder)
890
+ def __sub__(self, other, _semantic=None):
891
+ return sub(self, other, sanitize_overflow=True, _semantic=_semantic)
821
892
 
822
893
  @builtin
823
- def __rsub__(self, other, _builder=None):
824
- return sub(other, self, sanitize_overflow=True, _builder=_builder)
894
+ def __rsub__(self, other, _semantic=None):
895
+ return sub(other, self, sanitize_overflow=True, _semantic=_semantic)
825
896
 
826
897
  @builtin
827
- def __mul__(self, other, _builder=None):
828
- return mul(self, other, sanitize_overflow=True, _builder=_builder)
898
+ def __mul__(self, other, _semantic=None):
899
+ return mul(self, other, sanitize_overflow=True, _semantic=_semantic)
829
900
 
830
901
  @builtin
831
- def __rmul__(self, other, _builder=None):
832
- return mul(other, self, sanitize_overflow=True, _builder=_builder)
902
+ def __rmul__(self, other, _semantic=None):
903
+ return mul(other, self, sanitize_overflow=True, _semantic=_semantic)
833
904
 
834
905
  @builtin
835
- def __truediv__(self, other, _builder=None):
906
+ def __truediv__(self, other, _semantic=None):
836
907
  other = _unwrap_if_constexpr(other)
837
- return semantic.truediv(self, other, _builder)
908
+ return _semantic.truediv(self, other)
838
909
 
839
910
  @builtin
840
- def __rtruediv__(self, other, _builder=None):
911
+ def __rtruediv__(self, other, _semantic=None):
841
912
  other = _unwrap_if_constexpr(other)
842
- return semantic.truediv(other, self, _builder)
913
+ return _semantic.truediv(other, self)
843
914
 
844
915
  @builtin
845
- def __floordiv__(self, other, _builder=None):
916
+ def __floordiv__(self, other, _semantic=None):
846
917
  other = _unwrap_if_constexpr(other)
847
- return semantic.floordiv(self, other, _builder)
918
+ return _semantic.floordiv(self, other)
848
919
 
849
920
  @builtin
850
- def __rfloordiv__(self, other, _builder=None):
921
+ def __rfloordiv__(self, other, _semantic=None):
851
922
  other = _unwrap_if_constexpr(other)
852
- return semantic.floordiv(other, self, _builder)
923
+ return _semantic.floordiv(other, self)
853
924
 
854
925
  @builtin
855
- def __mod__(self, other, _builder=None):
926
+ def __mod__(self, other, _semantic=None):
856
927
  other = _unwrap_if_constexpr(other)
857
- return semantic.mod(self, other, _builder)
928
+ return _semantic.mod(self, other)
858
929
 
859
930
  @builtin
860
- def __rmod__(self, other, _builder=None):
931
+ def __rmod__(self, other, _semantic=None):
861
932
  other = _unwrap_if_constexpr(other)
862
- return semantic.mod(other, self, _builder)
933
+ return _semantic.mod(other, self)
863
934
 
864
935
  # unary operators
865
936
  @builtin
866
- def __neg__(self, _builder=None):
867
- return semantic.minus(self, _builder)
937
+ def __neg__(self, _semantic=None):
938
+ return _semantic.minus(self)
868
939
 
869
940
  @builtin
870
- def __invert__(self, _builder=None):
871
- return semantic.invert(self, _builder)
941
+ def __invert__(self, _semantic=None):
942
+ return _semantic.invert(self)
872
943
 
873
944
  # bitwise operators
874
945
 
875
946
  @builtin
876
- def __and__(self, other, _builder=None):
947
+ def __and__(self, other, _semantic=None):
877
948
  other = _unwrap_if_constexpr(other)
878
- return semantic.and_(self, other, _builder)
949
+ return _semantic.and_(self, other)
879
950
 
880
951
  @builtin
881
- def __rand__(self, other, _builder=None):
952
+ def __rand__(self, other, _semantic=None):
882
953
  other = _unwrap_if_constexpr(other)
883
- return semantic.and_(other, self, _builder)
954
+ return _semantic.and_(other, self)
884
955
 
885
956
  @builtin
886
- def __or__(self, other, _builder=None):
957
+ def __or__(self, other, _semantic=None):
887
958
  other = _unwrap_if_constexpr(other)
888
- return semantic.or_(self, other, _builder)
959
+ return _semantic.or_(self, other)
889
960
 
890
961
  @builtin
891
- def __ror__(self, other, _builder=None):
962
+ def __ror__(self, other, _semantic=None):
892
963
  other = _unwrap_if_constexpr(other)
893
- return semantic.or_(other, self, _builder)
964
+ return _semantic.or_(other, self)
894
965
 
895
966
  @builtin
896
- def __xor__(self, other, _builder=None):
967
+ def __xor__(self, other, _semantic=None):
897
968
  other = _unwrap_if_constexpr(other)
898
- return semantic.xor_(self, other, _builder)
969
+ return _semantic.xor_(self, other)
899
970
 
900
971
  @builtin
901
- def __rxor__(self, other, _builder=None):
972
+ def __rxor__(self, other, _semantic=None):
902
973
  other = _unwrap_if_constexpr(other)
903
- return semantic.xor_(other, self, _builder)
974
+ return _semantic.xor_(other, self)
904
975
 
905
976
  @builtin
906
- def __lshift__(self, other, _builder=None):
977
+ def __lshift__(self, other, _semantic=None):
907
978
  check_bit_width(self, other)
908
979
  other = _unwrap_if_constexpr(other)
909
- return semantic.shl(self, other, _builder)
980
+ return _semantic.shl(self, other)
910
981
 
911
982
  @builtin
912
- def __rlshift__(self, other, _builder=None):
983
+ def __rlshift__(self, other, _semantic=None):
913
984
  check_bit_width(other, self)
914
985
  other = _unwrap_if_constexpr(other)
915
- return semantic.shl(other, self, _builder)
986
+ return _semantic.shl(other, self)
916
987
 
917
988
  @builtin
918
- def __rshift__(self, other, _builder=None):
989
+ def __rshift__(self, other, _semantic=None):
919
990
  check_bit_width(self, other)
920
991
  other = _unwrap_if_constexpr(other)
921
992
  if self.dtype.is_int_signed():
922
- return semantic.ashr(self, other, _builder)
993
+ return _semantic.ashr(self, other)
923
994
  else:
924
- return semantic.lshr(self, other, _builder)
995
+ return _semantic.lshr(self, other)
925
996
 
926
997
  @builtin
927
- def __rrshift__(self, other, _builder=None):
998
+ def __rrshift__(self, other, _semantic=None):
928
999
  check_bit_width(other, self)
929
1000
  other = _unwrap_if_constexpr(other)
930
1001
  if self.dtype.is_int_signed():
931
- return semantic.ashr(other, self, _builder)
1002
+ return _semantic.ashr(other, self)
932
1003
  else:
933
- return semantic.lshr(other, self, _builder)
1004
+ return _semantic.lshr(other, self)
934
1005
 
935
1006
  # >
936
1007
  @builtin
937
- def __gt__(self, other, _builder=None):
938
- other = semantic.to_tensor(other, _builder)
939
- return semantic.greater_than(self, other, _builder)
1008
+ def __gt__(self, other, _semantic=None):
1009
+ other = _semantic.to_tensor(other)
1010
+ return _semantic.greater_than(self, other)
940
1011
 
941
1012
  @builtin
942
- def __rgt__(self, other, _builder=None):
943
- other = semantic.to_tensor(other, _builder)
944
- return semantic.greater_than(other, self, _builder)
1013
+ def __rgt__(self, other, _semantic=None):
1014
+ other = _semantic.to_tensor(other)
1015
+ return _semantic.greater_than(other, self)
945
1016
 
946
1017
  # >=
947
1018
  @builtin
948
- def __ge__(self, other, _builder=None):
949
- other = semantic.to_tensor(other, _builder)
950
- return semantic.greater_equal(self, other, _builder)
1019
+ def __ge__(self, other, _semantic=None):
1020
+ other = _semantic.to_tensor(other)
1021
+ return _semantic.greater_equal(self, other)
951
1022
 
952
1023
  @builtin
953
- def __rge__(self, other, _builder=None):
954
- other = semantic.to_tensor(other, _builder)
955
- return semantic.greater_equal(other, self, _builder)
1024
+ def __rge__(self, other, _semantic=None):
1025
+ other = _semantic.to_tensor(other)
1026
+ return _semantic.greater_equal(other, self)
956
1027
 
957
1028
  # <
958
1029
  @builtin
959
- def __lt__(self, other, _builder=None):
960
- other = semantic.to_tensor(other, _builder)
961
- return semantic.less_than(self, other, _builder)
1030
+ def __lt__(self, other, _semantic=None):
1031
+ other = _semantic.to_tensor(other)
1032
+ return _semantic.less_than(self, other)
962
1033
 
963
1034
  @builtin
964
- def __rlt__(self, other, _builder=None):
965
- other = semantic.to_tensor(other, _builder)
966
- return semantic.less_than(other, self, _builder)
1035
+ def __rlt__(self, other, _semantic=None):
1036
+ other = _semantic.to_tensor(other)
1037
+ return _semantic.less_than(other, self)
967
1038
 
968
1039
  # <=
969
1040
  @builtin
970
- def __le__(self, other, _builder=None):
971
- other = semantic.to_tensor(other, _builder)
972
- return semantic.less_equal(self, other, _builder)
1041
+ def __le__(self, other, _semantic=None):
1042
+ other = _semantic.to_tensor(other)
1043
+ return _semantic.less_equal(self, other)
973
1044
 
974
1045
  @builtin
975
- def __rle__(self, other, _builder=None):
976
- other = semantic.to_tensor(other, _builder)
977
- return semantic.less_equal(other, self, _builder)
1046
+ def __rle__(self, other, _semantic=None):
1047
+ other = _semantic.to_tensor(other)
1048
+ return _semantic.less_equal(other, self)
978
1049
 
979
1050
  # ==
980
1051
  @builtin
981
- def __eq__(self, other, _builder=None):
982
- other = semantic.to_tensor(other, _builder)
983
- return semantic.equal(self, other, _builder)
1052
+ def __eq__(self, other, _semantic=None):
1053
+ other = _semantic.to_tensor(other)
1054
+ return _semantic.equal(self, other)
984
1055
 
985
1056
  @builtin
986
- def __req__(self, other, _builder=None):
987
- other = semantic.to_tensor(other, _builder)
988
- return semantic.equal(other, self, _builder)
1057
+ def __req__(self, other, _semantic=None):
1058
+ other = _semantic.to_tensor(other)
1059
+ return _semantic.equal(other, self)
989
1060
 
990
1061
  @builtin
991
- def __ne__(self, other, _builder=None):
992
- other = semantic.to_tensor(other, _builder)
993
- return semantic.not_equal(self, other, _builder)
1062
+ def __ne__(self, other, _semantic=None):
1063
+ other = _semantic.to_tensor(other)
1064
+ return _semantic.not_equal(self, other)
994
1065
 
995
1066
  @builtin
996
- def __rne__(self, other, _builder=None):
997
- other = semantic.to_tensor(other, _builder)
998
- return semantic.not_equal(other, self, _builder)
1067
+ def __rne__(self, other, _semantic=None):
1068
+ other = _semantic.to_tensor(other)
1069
+ return _semantic.not_equal(other, self)
999
1070
 
1000
1071
  @builtin
1001
- def logical_and(self, other, _builder=None):
1002
- other = semantic.to_tensor(other, _builder)
1003
- return semantic.logical_and(self, other, _builder)
1072
+ def logical_and(self, other, _semantic=None):
1073
+ other = _semantic.to_tensor(other)
1074
+ return _semantic.logical_and(self, other)
1004
1075
 
1005
1076
  @builtin
1006
- def logical_or(self, other, _builder=None):
1007
- other = semantic.to_tensor(other, _builder)
1008
- return semantic.logical_or(self, other, _builder)
1077
+ def logical_or(self, other, _semantic=None):
1078
+ other = _semantic.to_tensor(other)
1079
+ return _semantic.logical_or(self, other)
1009
1080
 
1010
1081
  # note: __not__ isn't actually a magic method in python
1011
1082
  # but it's ok because our ASTVisitor handles it
1012
1083
  @builtin
1013
- def __not__(self, _builder=None):
1014
- return semantic.not_(self, _builder)
1084
+ def __not__(self, _semantic=None):
1085
+ return _semantic.not_(self)
1015
1086
 
1016
1087
  @builtin
1017
- def __getitem__(self, slices, _builder=None):
1018
- import builtins
1088
+ def __getitem__(self, slices, _semantic=None):
1019
1089
  if isinstance(slices, (builtins.slice, slice, constexpr)) or slices is None:
1020
1090
  slices = [slices]
1021
1091
  if isinstance(slices, tuple):
1022
1092
  slices = slices.values
1023
1093
  ret = self
1024
1094
  for dim, sl in enumerate(slices):
1025
- if sl is None or isinstance(sl, constexpr) and sl.value is None:
1026
- ret = semantic.expand_dims(ret, dim, _builder)
1027
- elif isinstance(sl, (builtins.slice, slice)) and sl.start is None and sl.stop is None and sl.step is None:
1028
- pass
1095
+ if _unwrap_if_constexpr(sl) is None:
1096
+ ret = _semantic.expand_dims(ret, dim)
1097
+ elif isinstance(sl, (builtins.slice, slice)) and all(
1098
+ _unwrap_if_constexpr(arg) is None for arg in (sl.start, sl.stop, sl.step)):
1099
+ pass # an unsqueeze
1029
1100
  else:
1030
1101
  raise ValueError(f"unsupported tensor index: {sl}")
1031
1102
  return ret
@@ -1036,11 +1107,11 @@ class tensor(base_value):
1036
1107
  assert False, "Transposition must be created by the AST Visitor"
1037
1108
 
1038
1109
  @builtin
1039
- def to(self, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: bool = False, _builder=None):
1110
+ def to(self, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: bool = False, _semantic=None):
1040
1111
  """
1041
1112
  Alias for :py:func:`tensor.cast`.
1042
1113
  """
1043
- return cast(self, dtype, fp_downcast_rounding, bitcast, _builder=_builder)
1114
+ return cast(self, dtype, fp_downcast_rounding, bitcast, _semantic=_semantic)
1044
1115
 
1045
1116
  # Type stubs for functions added by the _tensor_member_fn decorator.
1046
1117
  # (Unfortunately these can't be created automatically.)
@@ -1140,7 +1211,7 @@ class tensor(base_value):
1140
1211
  def sigmoid(self) -> tensor:
1141
1212
  ...
1142
1213
 
1143
- def softmax(self, ieee_rounding=False) -> tensor:
1214
+ def softmax(self, dim=None, keep_dims=False, ieee_rounding=False) -> tensor:
1144
1215
  ...
1145
1216
 
1146
1217
  def ravel(self) -> tensor:
@@ -1164,6 +1235,9 @@ class tensor(base_value):
1164
1235
  def xor_sum(self, axis=None, keep_dims=False) -> tensor:
1165
1236
  ...
1166
1237
 
1238
+ def reduce_or(self, axis=None, keep_dims=False) -> tensor:
1239
+ ...
1240
+
1167
1241
  def cumsum(self, axis=0, reverse=False) -> tensor:
1168
1242
  ...
1169
1243
 
@@ -1177,19 +1251,20 @@ class tensor(base_value):
1177
1251
  ...
1178
1252
 
1179
1253
 
1180
- class tuple(base_value):
1254
+ def _type_for_tuple_values(values, fields=None):
1255
+ return tuple_type([constexpr_type(x) if isinstance(x, (int, float, dtype)) else x.type for x in values], fields)
1181
1256
 
1182
- def __init__(self, args: list, type: tuple_type = None):
1183
- self.values = [i for i in args]
1184
1257
 
1185
- def get_type(x):
1186
- if isinstance(x, dtype):
1187
- return dtype
1188
- if isinstance(x, int):
1189
- return constexpr
1190
- return x.type
1258
+ class tuple(base_value):
1191
1259
 
1192
- self.type = type or tuple_type([get_type(x) for x in self.values])
1260
+ def __init__(self, args: Sequence, type: Optional[tuple_type] = None):
1261
+ self.values = [i for i in args]
1262
+ if isinstance(type, tuple_type):
1263
+ self.type = type
1264
+ elif type is not None: # make_template in ASTFunction.deserialize may pass us a list/tuple
1265
+ self.type = tuple_type(type)
1266
+ else:
1267
+ self.type = _type_for_tuple_values(self.values)
1193
1268
 
1194
1269
  def __getitem__(self, idx: constexpr):
1195
1270
  if isinstance(idx, int):
@@ -1197,7 +1272,6 @@ class tuple(base_value):
1197
1272
  if isinstance(idx, constexpr):
1198
1273
  return self.values[idx]
1199
1274
  else:
1200
- import builtins
1201
1275
  assert isinstance(idx, (slice, builtins.slice))
1202
1276
  return tuple(self.values[idx.start:idx.stop:idx.step])
1203
1277
 
@@ -1205,15 +1279,14 @@ class tuple(base_value):
1205
1279
  return self.values[self.type.fields.index(name)]
1206
1280
 
1207
1281
  # TODO: remove
1208
- def __setitem__(self, idx: constexpr, value):
1209
- if isinstance(idx, int):
1210
- idx = constexpr(idx)
1211
- assert isinstance(idx, constexpr)
1282
+ def _setitem(self, idx, value):
1283
+ idx = _unwrap_if_constexpr(idx)
1284
+ assert isinstance(idx, int)
1212
1285
  self.values[idx] = value
1286
+ self.type = _type_for_tuple_values(self.values, self.type.fields)
1213
1287
 
1214
1288
  def __add__(self, other):
1215
- if isinstance(other, list):
1216
- other = tuple(other)
1289
+ other = _normalize_tuple(other)
1217
1290
  return tuple(self.values + other.values)
1218
1291
  # return tuple(a + b for a, b in zip(self.values, other.values))
1219
1292
 
@@ -1222,13 +1295,10 @@ class tuple(base_value):
1222
1295
  return tuple(self.values * other.value)
1223
1296
 
1224
1297
  def __eq__(self, other):
1225
- import builtins
1226
- if isinstance(other, (list, builtins.tuple)):
1227
- other = tuple(other)
1298
+ other = _normalize_tuple(other)
1228
1299
  return constexpr(self.values == other.values)
1229
1300
 
1230
1301
  def __hash__(self):
1231
- import builtins
1232
1302
  return hash(builtins.tuple(self.values))
1233
1303
 
1234
1304
  def __str__(self):
@@ -1244,6 +1314,9 @@ class tuple(base_value):
1244
1314
  for v in self.values:
1245
1315
  v._flatten_ir(handles)
1246
1316
 
1317
+ def __repr__(self):
1318
+ return f"({' ,'.join(repr(x) for x in self.values)})"
1319
+
1247
1320
 
1248
1321
  class slice:
1249
1322
 
@@ -1259,12 +1332,13 @@ class tensor_descriptor_base_type(base_type):
1259
1332
  def __init__(self, block_type: block_type):
1260
1333
  self.block_type = block_type
1261
1334
 
1262
- def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[_experimental_tensor_descriptor_base, int]:
1263
- value = _experimental_tensor_descriptor_base(handles[cursor], self.block_type)
1335
+ def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tensor_descriptor_base, int]:
1336
+ value = tensor_descriptor_base(handles[cursor], self.block_type)
1264
1337
  return value, cursor + 1
1265
1338
 
1266
- def to_ir(self, builder: ir.builder):
1267
- return builder.create_tensor_descriptor_type(self.block_type.to_ir(builder))
1339
+ def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
1340
+ is_signed = self.block_type.element_ty.is_int_signed()
1341
+ out.append(builder.create_tensor_descriptor_type(self.block_type.to_ir(builder), is_signed))
1268
1342
 
1269
1343
  def __str__(self) -> str:
1270
1344
  # ex. "tensor_descriptor<float32[16, 32]>"
@@ -1278,8 +1352,11 @@ class tensor_descriptor_base_type(base_type):
1278
1352
  def __neq__(self, other) -> bool:
1279
1353
  return not (self == other)
1280
1354
 
1355
+ def mangle(self) -> str:
1356
+ return f"TD{self.block_type.mangle()}"
1357
+
1281
1358
 
1282
- class _experimental_tensor_descriptor_base(base_value):
1359
+ class tensor_descriptor_base(base_value):
1283
1360
  """"
1284
1361
  A tensor descriptor with unknown shape and strides
1285
1362
  """
@@ -1310,40 +1387,64 @@ class _experimental_tensor_descriptor_base(base_value):
1310
1387
  return str(self.type)
1311
1388
 
1312
1389
  @builtin
1313
- def load(self, offsets: Sequence[constexpr | tensor], _builder=None) -> tensor:
1390
+ def load(self, offsets: Sequence[constexpr | tensor], _semantic=None) -> tensor:
1314
1391
  """Load a block from the descriptor starting at the given element offsets.
1315
1392
 
1316
1393
  Values outside of the tensor bounds will be filled with zeros.
1317
1394
 
1318
1395
  :note: Offset must be a multiple of 16-bytes
1319
1396
  """
1320
- return semantic.descriptor_load(self, offsets, "", "", _builder)
1397
+ return _semantic.descriptor_load(self, offsets, "", "")
1321
1398
 
1322
1399
  @builtin
1323
- def store(self, offsets: Sequence[constexpr | tensor], value: tensor, _builder=None) -> tensor:
1400
+ def store(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor:
1324
1401
  """Store a block from the descriptor starting at the given element offsets.
1325
1402
 
1326
1403
  Values outside of the tensor bounds will be ignored.
1327
1404
 
1328
1405
  :note: Offset must be a multiple of 16-bytes
1329
1406
  """
1330
- return semantic.descriptor_store(self, value, offsets, _builder)
1407
+ return _semantic.descriptor_store(self, value, offsets)
1408
+
1409
+ @builtin
1410
+ def atomic_add(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor:
1411
+ return _semantic.descriptor_atomic_add(self, value, offsets)
1412
+
1413
+ @builtin
1414
+ def atomic_min(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor:
1415
+ return _semantic.descriptor_atomic_min(self, value, offsets)
1416
+
1417
+ @builtin
1418
+ def atomic_max(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor:
1419
+ return _semantic.descriptor_atomic_max(self, value, offsets)
1420
+
1421
+ @builtin
1422
+ def atomic_and(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor:
1423
+ return _semantic.descriptor_atomic_and(self, value, offsets)
1331
1424
 
1332
1425
  @builtin
1333
- def gather(self, *args, _builder=None) -> tensor:
1426
+ def atomic_or(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor:
1427
+ return _semantic.descriptor_atomic_or(self, value, offsets)
1428
+
1429
+ @builtin
1430
+ def atomic_xor(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor:
1431
+ return _semantic.descriptor_atomic_xor(self, value, offsets)
1432
+
1433
+ @builtin
1434
+ def gather(self, *args, _semantic=None) -> tensor:
1334
1435
  """Gather multiple descriptors worth of data"""
1335
1436
  assert len(args) == 2, f"descriptor gather only supports 2D indexing, but got {len(args)}"
1336
1437
  x_offsets = args[0]
1337
1438
  y_offset = args[1]
1338
- return semantic.descriptor_gather(self, x_offsets, y_offset, "", "", _builder)
1439
+ return _semantic.descriptor_gather(self, x_offsets, y_offset, "", "")
1339
1440
 
1340
1441
  @builtin
1341
- def scatter(self, value, *args, _builder=None) -> tensor:
1442
+ def scatter(self, value, *args, _semantic=None) -> tensor:
1342
1443
  """Scatter multiple descriptors worth of data"""
1343
1444
  assert len(args) == 2, f"descriptor scatter only supports 2D indexing, but got {len(args)}"
1344
1445
  x_offsets = args[0]
1345
1446
  y_offset = args[1]
1346
- return semantic.descriptor_scatter(self, value, x_offsets, y_offset, _builder)
1447
+ return _semantic.descriptor_scatter(self, value, x_offsets, y_offset)
1347
1448
 
1348
1449
 
1349
1450
  class tensor_descriptor_type(tensor_descriptor_base_type):
@@ -1353,25 +1454,27 @@ class tensor_descriptor_type(tensor_descriptor_base_type):
1353
1454
  self.shape_type = shape_type
1354
1455
  self.strides_type = strides_type
1355
1456
 
1356
- def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[_experimental_tensor_descriptor_base, int]:
1457
+ def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tensor_descriptor_base, int]:
1357
1458
  handle = handles[cursor]
1358
1459
  cursor += 1
1359
1460
  shape, cursor = self.shape_type._unflatten_ir(handles, cursor)
1360
1461
  strides, cursor = self.strides_type._unflatten_ir(handles, cursor)
1361
1462
  shape = shape.values
1362
1463
  strides = strides.values
1363
- value = _experimental_tensor_descriptor(handle, shape, strides, self.block_type)
1464
+ value = tensor_descriptor(handle, shape, strides, self.block_type)
1364
1465
  return value, cursor
1365
1466
 
1366
- def to_ir(self, builder: ir.builder):
1367
- return [super().to_ir(builder), *self.shape_type.to_ir(builder), *self.strides_type.to_ir(builder)]
1467
+ def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
1468
+ super()._flatten_ir_types(builder, out)
1469
+ self.shape_type._flatten_ir_types(builder, out)
1470
+ self.strides_type._flatten_ir_types(builder, out)
1368
1471
 
1369
1472
  def __eq__(self, other):
1370
1473
  return super().__eq__(other) and (self.shape_type == other.shape_type) and (self.strides_type
1371
1474
  == other.strides_type)
1372
1475
 
1373
1476
 
1374
- class _experimental_tensor_descriptor(_experimental_tensor_descriptor_base):
1477
+ class tensor_descriptor(tensor_descriptor_base):
1375
1478
  """A descriptor representing a tensor in global memory.
1376
1479
  """
1377
1480
 
@@ -1379,37 +1482,121 @@ class _experimental_tensor_descriptor(_experimental_tensor_descriptor_base):
1379
1482
  """Not called by user code."""
1380
1483
  # IR handle
1381
1484
  super().__init__(handle, block_type)
1485
+ # Global shape
1486
+ self.shape = tuple(shape)
1487
+ self.strides = tuple(strides)
1382
1488
  self.type = tensor_descriptor_type(
1383
1489
  block_type,
1384
- shape_type=tuple_type([s.type for s in shape]),
1385
- strides_type=tuple_type([s.type for s in strides]),
1490
+ shape_type=self.shape.type,
1491
+ strides_type=self.strides.type,
1386
1492
  )
1387
- # Global shape
1388
- self.shape = shape
1389
- self.strides = strides
1390
1493
 
1391
1494
  def _flatten_ir(self, handles: List[ir.value]) -> None:
1392
1495
  handles.append(self.handle)
1393
- handles.extend(s.handle for s in self.shape)
1394
- handles.extend(s.handle for s in self.strides)
1496
+ self.shape._flatten_ir(handles)
1497
+ self.strides._flatten_ir(handles)
1395
1498
 
1396
1499
 
1397
- def get_bool_env_var(var_name):
1398
- v = os.getenv(var_name, "0")
1399
- return v == "1" or v == "true" or v == "on"
1500
+ # -----------------------
1501
+ # aggregate
1502
+ # -----------------------
1503
+
1504
+
1505
+ @dataclass(frozen=True)
1506
+ class _aggregate_type(base_type):
1507
+ """A generic base type for all Triton aggregate types.
1508
+
1509
+ This class contains a reference to the original user-defined Python class
1510
+ and a list of class fields with their Triton types.
1511
+ """
1512
+
1513
+ base_cls: type
1514
+ fields: List[Tuple[str, base_type]]
1515
+
1516
+ def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[ir.value, int]:
1517
+ instance = self.base_cls._get_instance()
1518
+ for name, ty in self.fields:
1519
+ value, cursor = ty._unflatten_ir(handles, cursor)
1520
+ setattr(instance, name, value)
1521
+ return instance, cursor
1522
+
1523
+ def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
1524
+ for name, ty in self.fields:
1525
+ ty._flatten_ir_types(builder, out)
1526
+
1527
+ def mangle(self) -> str:
1528
+ name = f"{self.base_cls.__module__}.{self.base_cls.__qualname__}"
1529
+ fields = [ty.mangle() for (name, ty) in self.fields]
1530
+ return f"{name}<{', '.join(fields)}>"
1531
+
1532
+
1533
+ def _aggregate(cls):
1534
+
1535
+ # Define the wrapped Triton value type.
1536
+ class aggregate_value(base_value):
1537
+ __triton_builtin__ = True
1538
+ __triton_aggregate__ = True
1539
+
1540
+ @classmethod
1541
+ def _get_instance(this_cls):
1542
+ return super().__new__(this_cls)
1543
+
1544
+ def __new__(this_cls, *args, _semantic=None, _generator=None, **kwargs):
1545
+ # Call into the user-defined constructor.
1546
+ instance = this_cls._get_instance()
1547
+ if isinstance(cls.__init__, JITCallable):
1548
+ raise ValueError(f"{cls.__name__}.__init__ cannot be a @triton.jit function")
1549
+ extra_kwargs = {}
1550
+ if "_semantic" in inspect.signature(cls.__init__).parameters:
1551
+ extra_kwargs["_semantic"] = _semantic
1552
+ if "_generator" in inspect.signature(cls.__init__).parameters:
1553
+ extra_kwargs["_generator"] = _generator
1554
+ cls.__init__(instance, *args, **extra_kwargs, **kwargs)
1555
+
1556
+ # Require that the user-defined constructor initialized all fields.
1557
+ for name in cls.__annotations__.keys():
1558
+ if not hasattr(instance, name):
1559
+ raise AttributeError(f"constructor for {cls.__name__} did not initialize attribute '{name}'")
1560
+
1561
+ return instance
1562
+
1563
+ # Only allow setting attributes defined in the class annotations.
1564
+ def __setattr__(self, name, value):
1565
+ if name not in cls.__annotations__:
1566
+ raise AttributeError(f"{cls.__name__} has no attribute '{name}'")
1567
+ if not isinstance(value, cls.__annotations__[name]):
1568
+ raise TypeError(f"Expected {cls.__annotations__[name]} for attribute '{name}', got {type(value)}")
1569
+ super().__setattr__(name, value)
1570
+
1571
+ def _flatten_ir(self, handles: List[ir.value]) -> None:
1572
+ for name in cls.__annotations__.keys():
1573
+ getattr(self, name)._flatten_ir(handles)
1574
+
1575
+ @property
1576
+ def type(self):
1577
+ return _aggregate_type(aggregate_value,
1578
+ [(name, getattr(self, name).type) for name in cls.__annotations__.keys()])
1579
+
1580
+ for (name, member) in inspect.getmembers(cls):
1581
+ if inspect.isfunction(member) or inspect.ismethod(member) or isinstance(member, JITCallable):
1582
+ if name != "__init__":
1583
+ setattr(aggregate_value, name, member)
1584
+
1585
+ aggregate_value.__name__ = cls.__name__
1586
+ aggregate_value.__module__ = cls.__module__
1587
+ aggregate_value.__qualname__ = cls.__qualname__
1588
+ aggregate_value.__doc__ = cls.__doc__
1589
+
1590
+ return aggregate_value
1400
1591
 
1401
1592
 
1402
1593
  # -----------------------
1403
1594
  # SPMD Programming Model
1404
1595
  # -----------------------
1405
- def _constexpr_to_value(v):
1406
- if isinstance(v, constexpr):
1407
- return v.value
1408
- return v
1409
1596
 
1410
1597
 
1411
1598
  @builtin
1412
- def program_id(axis, _builder=None):
1599
+ def program_id(axis, _semantic=None):
1413
1600
  """
1414
1601
  Returns the id of the current program instance along the given :code:`axis`.
1415
1602
 
@@ -1417,26 +1604,26 @@ def program_id(axis, _builder=None):
1417
1604
  :type axis: int
1418
1605
  """
1419
1606
  # if axis == -1:
1420
- # pid0 = program_id(0, _builder)
1421
- # pid1 = program_id(1, _builder)
1422
- # pid2 = program_id(2, _builder)
1423
- # npg0 = num_programs(0, _builder)
1424
- # npg1 = num_programs(1, _builder)
1607
+ # pid0 = _semantic.program_id(0)
1608
+ # pid1 = _semantic.program_id(1)
1609
+ # pid2 = _semantic.program_id(2)
1610
+ # npg0 = _semantic.num_programs(0)
1611
+ # npg1 = _semantic.num_programs(1)
1425
1612
  # return pid0 + pid1*npg0 + pid2*npg0*npg1
1426
- axis = _constexpr_to_value(axis)
1427
- return semantic.program_id(axis, _builder)
1613
+ axis = _unwrap_if_constexpr(axis)
1614
+ return _semantic.program_id(axis)
1428
1615
 
1429
1616
 
1430
1617
  @builtin
1431
- def num_programs(axis, _builder=None):
1618
+ def num_programs(axis, _semantic=None):
1432
1619
  """
1433
1620
  Returns the number of program instances launched along the given :code:`axis`.
1434
1621
 
1435
1622
  :param axis: The axis of the 3D launch grid. Must be 0, 1 or 2.
1436
1623
  :type axis: int
1437
1624
  """
1438
- axis = _constexpr_to_value(axis)
1439
- return semantic.num_programs(axis, _builder)
1625
+ axis = _unwrap_if_constexpr(axis)
1626
+ return _semantic.num_programs(axis)
1440
1627
 
1441
1628
 
1442
1629
  # -----------------------
@@ -1445,10 +1632,10 @@ def num_programs(axis, _builder=None):
1445
1632
 
1446
1633
 
1447
1634
  @builtin
1448
- def arange(start, end, _builder=None):
1449
- start = _constexpr_to_value(start)
1450
- end = _constexpr_to_value(end)
1451
- return semantic.arange(start, end, _builder)
1635
+ def arange(start, end, _semantic=None):
1636
+ start = _unwrap_if_constexpr(start)
1637
+ end = _unwrap_if_constexpr(end)
1638
+ return _semantic.arange(start, end)
1452
1639
 
1453
1640
 
1454
1641
  arange.__doc__ = f"""
@@ -1465,8 +1652,8 @@ arange.__doc__ = f"""
1465
1652
 
1466
1653
 
1467
1654
  def _unwrap_shape(shape):
1468
- shape = _constexpr_to_value(shape)
1469
- return [_constexpr_to_value(s) for s in shape]
1655
+ shape = _unwrap_if_constexpr(shape)
1656
+ return [_unwrap_if_constexpr(s) for s in shape]
1470
1657
 
1471
1658
 
1472
1659
  def _shape_check_impl(shape):
@@ -1476,7 +1663,7 @@ def _shape_check_impl(shape):
1476
1663
 
1477
1664
 
1478
1665
  @builtin
1479
- def full(shape, value, dtype, _builder=None):
1666
+ def full(shape, value, dtype, _semantic=None):
1480
1667
  """
1481
1668
  Returns a tensor filled with the scalar value for the given :code:`shape` and :code:`dtype`.
1482
1669
 
@@ -1488,9 +1675,9 @@ def full(shape, value, dtype, _builder=None):
1488
1675
  :type dtype: tl.dtype
1489
1676
  """
1490
1677
  shape = _shape_check_impl(shape)
1491
- value = _constexpr_to_value(value)
1492
- dtype = _constexpr_to_value(dtype)
1493
- return semantic.full(shape, value, dtype, _builder)
1678
+ value = _unwrap_if_constexpr(value)
1679
+ dtype = _unwrap_if_constexpr(dtype)
1680
+ return _semantic.full(shape, value, dtype)
1494
1681
 
1495
1682
 
1496
1683
  # -----------------------
@@ -1499,7 +1686,7 @@ def full(shape, value, dtype, _builder=None):
1499
1686
 
1500
1687
 
1501
1688
  @builtin
1502
- def broadcast(input, other, _builder=None):
1689
+ def broadcast(input, other, _semantic=None):
1503
1690
  """
1504
1691
  Tries to broadcast the two given blocks to a common compatible shape.
1505
1692
 
@@ -1508,12 +1695,12 @@ def broadcast(input, other, _builder=None):
1508
1695
  :param other: The second input tensor.
1509
1696
  :type other: Block
1510
1697
  """
1511
- return semantic.broadcast_impl_value(input, other, _builder)
1698
+ return _semantic.broadcast_impl_value(input, other)
1512
1699
 
1513
1700
 
1514
1701
  @_tensor_member_fn
1515
1702
  @builtin
1516
- def broadcast_to(input, *shape, _builder=None):
1703
+ def broadcast_to(input, *shape, _semantic=None):
1517
1704
  """
1518
1705
  Tries to broadcast the given tensor to a new :code:`shape`.
1519
1706
 
@@ -1529,12 +1716,12 @@ def broadcast_to(input, *shape, _builder=None):
1529
1716
  broadcast_to(x, 32, 32)
1530
1717
  """
1531
1718
  shape = _shape_check_impl(_unwrap_iterable(shape))
1532
- return semantic.broadcast_impl_shape(input, shape, _builder)
1719
+ return _semantic.broadcast_impl_shape(input, shape)
1533
1720
 
1534
1721
 
1535
1722
  @_tensor_member_fn
1536
1723
  @builtin
1537
- def trans(input: tensor, *dims, _builder=None):
1724
+ def trans(input: tensor, *dims, _semantic=None):
1538
1725
  """
1539
1726
  Permutes the dimensions of a tensor.
1540
1727
 
@@ -1543,7 +1730,7 @@ def trans(input: tensor, *dims, _builder=None):
1543
1730
 
1544
1731
  :param input: The input tensor.
1545
1732
  :param dims: The desired ordering of dimensions. For example,
1546
- :code:`(2, 1, 0)` reverses the order dims in a a 3D tensor.
1733
+ :code:`(2, 1, 0)` reverses the order dims in a 3D tensor.
1547
1734
 
1548
1735
  :code:`dims` can be passed as a tuple or as individual parameters: ::
1549
1736
 
@@ -1557,19 +1744,19 @@ def trans(input: tensor, *dims, _builder=None):
1557
1744
  dims = _unwrap_iterable(dims)
1558
1745
  if not dims:
1559
1746
  dims = (1, 0)
1560
- return semantic.permute(input, dims, _builder)
1747
+ return _semantic.permute(input, dims)
1561
1748
 
1562
1749
 
1563
1750
  @_tensor_member_fn
1564
1751
  @builtin
1565
- def permute(input, *dims, _builder=None):
1752
+ def permute(input, *dims, _semantic=None):
1566
1753
  """
1567
1754
  Permutes the dimensions of a tensor.
1568
1755
 
1569
1756
  :param input: The input tensor.
1570
1757
  :type input: Block
1571
1758
  :param dims: The desired ordering of dimensions. For example,
1572
- :code:`(2, 1, 0)` reverses the order dims in a a 3D tensor.
1759
+ :code:`(2, 1, 0)` reverses the order dims in a 3D tensor.
1573
1760
 
1574
1761
  :code:`dims` can be passed as a tuple or as individual parameters: ::
1575
1762
 
@@ -1581,11 +1768,11 @@ def permute(input, *dims, _builder=None):
1581
1768
  :code:`dims` is empty, it tries to do a (1,0) permutation.
1582
1769
  """
1583
1770
  dims = _unwrap_iterable(dims)
1584
- return semantic.permute(input, dims, _builder)
1771
+ return _semantic.permute(input, dims)
1585
1772
 
1586
1773
 
1587
1774
  @builtin
1588
- def cat(input, other, can_reorder=False, _builder=None):
1775
+ def cat(input, other, can_reorder=False, _semantic=None):
1589
1776
  """
1590
1777
  Concatenate the given blocks
1591
1778
 
@@ -1598,11 +1785,11 @@ def cat(input, other, can_reorder=False, _builder=None):
1598
1785
  order does not matter (e.g., result is only used in reduction ops).
1599
1786
  Current implementation of `cat` supports only can_reorder=True.
1600
1787
  """
1601
- return semantic.cat(input, other, can_reorder, _builder)
1788
+ return _semantic.cat(input, other, can_reorder)
1602
1789
 
1603
1790
 
1604
1791
  @builtin
1605
- def join(a, b, _builder=None):
1792
+ def join(a, b, _semantic=None):
1606
1793
  """
1607
1794
  Join the given tensors in a new, minor dimension.
1608
1795
 
@@ -1622,17 +1809,25 @@ def join(a, b, _builder=None):
1622
1809
  :param b: The second input tensor.
1623
1810
  :type b: Tensor
1624
1811
  """
1625
- return semantic.join(a, b, _builder)
1812
+ return _semantic.join(a, b)
1626
1813
 
1627
1814
 
1628
- @jit
1629
- def _take_first(a, b):
1630
- return a
1815
+ def _unsplat(x, _semantic=None, _generator=None):
1816
+ """
1817
+ Convert a single-element tensor to a scalar.
1818
+ """
1819
+ if len(x.shape) == 0:
1820
+ return x
1821
+ numel = 1
1822
+ for d in x.shape:
1823
+ numel *= d
1824
+ assert numel == 1, "can only unsplat single-element tensors"
1825
+ return _semantic.unsplat(x)
1631
1826
 
1632
1827
 
1633
1828
  @_tensor_member_fn
1634
1829
  @builtin
1635
- def split(a, _builder=None, _generator=None) -> tuple[tensor, tensor]:
1830
+ def split(a, _semantic=None, _generator=None) -> tuple[tensor, tensor]:
1636
1831
  """
1637
1832
  Split a tensor in two along its last dim, which must have size 2.
1638
1833
 
@@ -1649,25 +1844,25 @@ def split(a, _builder=None, _generator=None) -> tuple[tensor, tensor]:
1649
1844
  :type a: Tensor
1650
1845
  """
1651
1846
  # If len(a.shape) == 1, i.e. a.shape == [2], we should return two scalars.
1652
- # But semantic.split can only handle returning tensors. Work around this by
1847
+ # But _semantic.split can only handle returning tensors. Work around this by
1653
1848
  # expanding the input to shape [1,2] and then reducing the result.
1654
1849
  was_rank_1 = len(a.shape) == 1
1655
1850
  if was_rank_1:
1656
- a = semantic.expand_dims(a, 0, _builder)
1851
+ a = _semantic.expand_dims(a, 0)
1657
1852
 
1658
- out_lhs, out_rhs = semantic.split(a, _builder)
1853
+ out_lhs, out_rhs = _semantic.split(a)
1659
1854
 
1660
1855
  if was_rank_1:
1661
1856
  # Currently `reduce` is the best way to convert a tensor of shape [1] to a scalar.
1662
- out_lhs = typing.cast(tensor, reduce(out_lhs, None, _take_first, _builder=_builder, _generator=_generator))
1663
- out_rhs = typing.cast(tensor, reduce(out_rhs, None, _take_first, _builder=_builder, _generator=_generator))
1857
+ out_lhs = _unsplat(out_lhs, _semantic=_semantic, _generator=_generator)
1858
+ out_rhs = _unsplat(out_rhs, _semantic=_semantic, _generator=_generator)
1664
1859
 
1665
1860
  return out_lhs, out_rhs
1666
1861
 
1667
1862
 
1668
1863
  @_tensor_member_fn
1669
1864
  @builtin
1670
- def view(input, *shape, _builder=None):
1865
+ def view(input, *shape, _semantic=None):
1671
1866
  """
1672
1867
  Returns a tensor with the same elements as `input` but a different shape.
1673
1868
  The order of the elements may not be preserved.
@@ -1684,12 +1879,21 @@ def view(input, *shape, _builder=None):
1684
1879
  """
1685
1880
  warn("view is deprecated, please use reshape with can_reorder being true.")
1686
1881
  shape = _shape_check_impl(_unwrap_iterable(shape))
1687
- return semantic.reshape(input, shape, can_reorder=True, builder=_builder)
1882
+ return _semantic.reshape(input, shape, can_reorder=True)
1688
1883
 
1689
1884
 
1690
1885
  @_tensor_member_fn
1691
1886
  @builtin
1692
- def reshape(input, *shape, can_reorder=False, _builder=None):
1887
+ def item(input, _semantic=None, _generator=None):
1888
+ """
1889
+ Converts a single-element tensor into a scalar.
1890
+ """
1891
+ return _unsplat(input, _semantic=_semantic, _generator=_generator)
1892
+
1893
+
1894
+ @_tensor_member_fn
1895
+ @builtin
1896
+ def reshape(input, *shape, can_reorder=False, _semantic=None, _generator=None):
1693
1897
  """
1694
1898
  Returns a tensor with the same number of elements as input but with the
1695
1899
  provided shape.
@@ -1705,7 +1909,9 @@ def reshape(input, *shape, can_reorder=False, _builder=None):
1705
1909
  reshape(x, 32, 32)
1706
1910
  """
1707
1911
  shape = _shape_check_impl(_unwrap_iterable(shape))
1708
- return semantic.reshape(input, shape, can_reorder, _builder)
1912
+ if len(shape) == 0:
1913
+ return _unsplat(input, _semantic=_semantic, _generator=_generator)
1914
+ return _semantic.reshape(input, shape, can_reorder)
1709
1915
 
1710
1916
 
1711
1917
  def _wrap_axis(axis, ndim):
@@ -1717,7 +1923,7 @@ def _wrap_axis(axis, ndim):
1717
1923
 
1718
1924
  @_tensor_member_fn
1719
1925
  @builtin
1720
- def expand_dims(input, axis, _builder=None):
1926
+ def expand_dims(input, axis, _semantic=None):
1721
1927
  """
1722
1928
  Expand the shape of a tensor, by inserting new length-1 dimensions.
1723
1929
 
@@ -1730,24 +1936,24 @@ def expand_dims(input, axis, _builder=None):
1730
1936
  :type axis: int | Sequence[int]
1731
1937
 
1732
1938
  """
1733
- input = semantic.to_tensor(input, _builder)
1734
- axis = _constexpr_to_value(axis)
1939
+ input = _semantic.to_tensor(input)
1940
+ axis = _unwrap_if_constexpr(axis)
1735
1941
  axes = list(axis) if isinstance(axis, (Sequence, tuple)) else [axis]
1736
1942
  new_ndim = len(input.shape) + len(axes)
1737
- axes = [_wrap_axis(_constexpr_to_value(d), new_ndim) for d in axes]
1943
+ axes = [_wrap_axis(_unwrap_if_constexpr(d), new_ndim) for d in axes]
1738
1944
 
1739
1945
  if len(set(axes)) != len(axes):
1740
1946
  raise ValueError(f"expand_dims received duplicate axes, normalized axes = {axes}")
1741
1947
 
1742
1948
  ret = input
1743
1949
  for a in sorted(axes):
1744
- ret = semantic.expand_dims(ret, a, _builder)
1950
+ ret = _semantic.expand_dims(ret, a)
1745
1951
  return ret
1746
1952
 
1747
1953
 
1748
1954
  @_tensor_member_fn
1749
1955
  @builtin
1750
- def cast(input, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: bool = False, _builder=None):
1956
+ def cast(input, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: bool = False, _semantic=None):
1751
1957
  """
1752
1958
  Casts a tensor to the given :code:`dtype`.
1753
1959
 
@@ -1763,13 +1969,13 @@ def cast(input, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcas
1763
1969
  :code:`dtype`, instead of being numerically casted.
1764
1970
  :type bitcast: bool, optional
1765
1971
  """
1766
- input = semantic.to_tensor(input, _builder)
1767
- dtype = _constexpr_to_value(dtype)
1768
- fp_downcast_rounding = _constexpr_to_value(fp_downcast_rounding)
1769
- bitcast = _constexpr_to_value(bitcast)
1972
+ input = _semantic.to_tensor(input)
1973
+ dtype = _unwrap_if_constexpr(dtype)
1974
+ fp_downcast_rounding = _unwrap_if_constexpr(fp_downcast_rounding)
1975
+ bitcast = _unwrap_if_constexpr(bitcast)
1770
1976
  if bitcast:
1771
- return semantic.bitcast(input, dtype, _builder)
1772
- return semantic.cast(input, dtype, _builder, fp_downcast_rounding)
1977
+ return _semantic.bitcast(input, dtype)
1978
+ return _semantic.cast(input, dtype, fp_downcast_rounding)
1773
1979
 
1774
1980
 
1775
1981
  # -----------------------
@@ -1779,7 +1985,7 @@ def cast(input, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcas
1779
1985
 
1780
1986
  @builtin
1781
1987
  def dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_imprecise_acc=None, out_dtype=float32,
1782
- _builder=None):
1988
+ _semantic=None):
1783
1989
  """
1784
1990
  Returns the matrix product of two blocks.
1785
1991
 
@@ -1804,19 +2010,20 @@ def dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_i
1804
2010
  """
1805
2011
  assert input_precision is None or allow_tf32 is None, "Only one of input_precision and allow_tf32 can be specified"
1806
2012
  if input_precision is None:
1807
- supports_tf32 = _builder and "tf32" in _builder.options.allowed_dot_input_precisions
1808
- default_precision = "tf32" if (supports_tf32 and (allow_tf32 or allow_tf32 is None)) else "ieee"
1809
- input_precision = os.getenv("TRITON_F32_DEFAULT", default_precision)
2013
+ supports_tf32 = "tf32" in _semantic.builder.options.allowed_dot_input_precisions
2014
+ input_precision = knobs.language.fp32_default or ("tf32" if (supports_tf32 and
2015
+ (allow_tf32 or allow_tf32 is None)) else "ieee")
1810
2016
 
1811
- input_precision = _constexpr_to_value(input_precision)
1812
- out_dtype = _constexpr_to_value(out_dtype)
1813
- max_num_imprecise_acc = _constexpr_to_value(max_num_imprecise_acc)
1814
- return semantic.dot(input, other, acc, input_precision, max_num_imprecise_acc, out_dtype, _builder)
2017
+ input_precision = _unwrap_if_constexpr(input_precision)
2018
+ out_dtype = _unwrap_if_constexpr(out_dtype)
2019
+ max_num_imprecise_acc = _unwrap_if_constexpr(max_num_imprecise_acc)
2020
+ acc = _unwrap_if_constexpr(acc)
2021
+ return _semantic.dot(input, other, acc, input_precision, max_num_imprecise_acc, out_dtype)
1815
2022
 
1816
2023
 
1817
2024
  @builtin
1818
- def dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None, fast_math=False, out_dtype=float32,
1819
- _builder=None):
2025
+ def dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None, fast_math=False, lhs_k_pack=True,
2026
+ rhs_k_pack=True, out_dtype=float32, _semantic=None):
1820
2027
  """
1821
2028
  Returns the matrix product of two blocks in microscaling format.
1822
2029
 
@@ -1843,11 +2050,15 @@ def dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None,
1843
2050
  :param rhs_format: format of the rhs tensor. Available formats: {:code:`e2m1`, :code:`e4m3`, :code:`e5m2`, :code:`bf16`, :code:`fp16`}.
1844
2051
  :type rhs_format: str
1845
2052
  :param acc: The accumulator tensor. If not None, the result is added to this tensor.
2053
+ :param lhs_k_pack: If false, the lhs tensor is packed into uint8 along M dimension.
2054
+ :type lhs_k_pack: bool, optional
2055
+ :param rhs_k_pack: If false, the rhs tensor is packed into uint8 along N dimension.
2056
+ :type rhs_k_pack: bool, optional
1846
2057
  """
1847
- out_dtype = _constexpr_to_value(out_dtype)
2058
+ out_dtype = _unwrap_if_constexpr(out_dtype)
1848
2059
  assert out_dtype == float32, "Only float32 is supported for out_dtype at the moment"
1849
- return semantic.dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc, fast_math, out_dtype,
1850
- _builder)
2060
+ return _semantic.dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc, fast_math, lhs_k_pack,
2061
+ rhs_k_pack, out_dtype)
1851
2062
 
1852
2063
 
1853
2064
  # -----------------------
@@ -1857,7 +2068,7 @@ def dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None,
1857
2068
 
1858
2069
  @builtin
1859
2070
  def load(pointer, mask=None, other=None, boundary_check=(), padding_option="", cache_modifier="", eviction_policy="",
1860
- volatile=False, _builder=None):
2071
+ volatile=False, _semantic=None):
1861
2072
  """
1862
2073
  Return a tensor of data whose values are loaded from memory at location defined by `pointer`:
1863
2074
 
@@ -1892,8 +2103,9 @@ def load(pointer, mask=None, other=None, boundary_check=(), padding_option="", c
1892
2103
  :type boundary_check: tuple of ints, optional
1893
2104
  :param padding_option: should be one of {"", "zero", "nan"}, the padding value to use while out of bounds. "" means an undefined value.
1894
2105
  :param cache_modifier: changes cache option in NVIDIA PTX
1895
- :type cache_modifier: str, optional, should be one of {"", "ca", "cg"}, where "ca" stands for
1896
- cache at all levels and "cg" stands for cache at global level (cache in L2 and below, not L1), see
2106
+ :type cache_modifier: str, optional, should be one of {"", ".ca", ".cg", ".cv"}, where ".ca" stands for
2107
+ cache at all levels, ".cg" stands for cache at global level (cache in L2 and below, not L1),
2108
+ and ".cv" means don’t cache and fetch again. see
1897
2109
  `cache operator <https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#cache-operators>`_ for more details.
1898
2110
  :param eviction_policy: changes eviction policy in NVIDIA PTX
1899
2111
  :type eviction_policy: str, optional
@@ -1901,57 +2113,37 @@ def load(pointer, mask=None, other=None, boundary_check=(), padding_option="", c
1901
2113
  :type volatile: bool, optional
1902
2114
  """
1903
2115
  # `mask` and `other` can be constexpr
1904
- mask = _constexpr_to_value(mask)
1905
- other = _constexpr_to_value(other)
2116
+ mask = _unwrap_if_constexpr(mask)
2117
+ other = _unwrap_if_constexpr(other)
1906
2118
  if mask is not None:
1907
- mask = semantic.to_tensor(mask, _builder)
2119
+ mask = _semantic.to_tensor(mask)
1908
2120
  if other is not None:
1909
- other = semantic.to_tensor(other, _builder)
1910
- padding_option = _constexpr_to_value(padding_option)
1911
- cache_modifier = _constexpr_to_value(cache_modifier)
1912
- eviction_policy = _constexpr_to_value(eviction_policy)
1913
- volatile = _constexpr_to_value(volatile)
1914
- return semantic.load(pointer, mask, other, boundary_check, padding_option, cache_modifier, eviction_policy,
1915
- volatile, _builder)
1916
-
1917
-
1918
- @builtin
1919
- def _experimental_reinterpret_tensor_descriptor(desc_ptr, block_shape, dtype,
1920
- _builder=None) -> _experimental_tensor_descriptor_base:
1921
- """
1922
- Reinterpret a generic pointer as a TMA-backed tensor descriptor object.
1923
- """
1924
- block_ty = block_type(_constexpr_to_value(dtype), block_shape)
1925
- return semantic.reinterpret_tensor_descriptor(desc_ptr, block_ty, _builder)
2121
+ other = _semantic.to_tensor(other)
2122
+ padding_option = _unwrap_if_constexpr(padding_option)
2123
+ cache_modifier = _unwrap_if_constexpr(cache_modifier)
2124
+ eviction_policy = _unwrap_if_constexpr(eviction_policy)
2125
+ volatile = _unwrap_if_constexpr(volatile)
2126
+ return _semantic.load(pointer, mask, other, boundary_check, padding_option, cache_modifier, eviction_policy,
2127
+ volatile)
1926
2128
 
1927
2129
 
1928
2130
  @builtin
1929
- def _experimental_descriptor_load(desc_pointer, offsets, shape, dtype, _builder=None):
1930
- """
1931
- Experimental feature to access TMA descriptors loads. This is an escape hatch to easily exercise TTGIR operations.
1932
- This will be removed in the future and shouldn't be used in production code.
1933
-
1934
- This loads a tensor of data based on the descriptor and offsets.
1935
- """
1936
- desc = _experimental_reinterpret_tensor_descriptor(desc_pointer, shape, dtype, _builder=_builder)
1937
- return desc.load(offsets, _builder=_builder)
2131
+ def load_tensor_descriptor(desc: tensor_descriptor_base, offsets: Sequence[constexpr | tensor],
2132
+ _semantic=None) -> tensor:
2133
+ """Load a block of data from a tensor descriptor."""
2134
+ return desc.load(offsets, _semantic=_semantic)
1938
2135
 
1939
2136
 
1940
2137
  @builtin
1941
- def _experimental_descriptor_store(desc_pointer, value, offsets, _builder=None):
1942
- """
1943
- Experimental feature to access TMA descriptors stores. This is an escape hatch to easily exercise TTGIR operations.
1944
- This will be removed in the future and shouldn't be used in production code.
1945
-
1946
- This stores a tensor of data based on the descriptor and offsets.
1947
- """
1948
- desc = _experimental_reinterpret_tensor_descriptor(desc_pointer, value.shape, value.dtype, _builder=_builder)
1949
- return desc.store(offsets, value, _builder=_builder)
2138
+ def store_tensor_descriptor(desc: tensor_descriptor_base, offsets: Sequence[constexpr | tensor], value: tensor,
2139
+ _semantic=None) -> tensor:
2140
+ """Store a block of data to a tensor descriptor."""
2141
+ return desc.store(offsets, value, _semantic=_semantic)
1950
2142
 
1951
2143
 
1952
2144
  @_tensor_member_fn
1953
2145
  @builtin
1954
- def store(pointer, value, mask=None, boundary_check=(), cache_modifier="", eviction_policy="", _builder=None):
2146
+ def store(pointer, value, mask=None, boundary_check=(), cache_modifier="", eviction_policy="", _semantic=None):
1955
2147
  """
1956
2148
  Store a tensor of data into memory locations defined by `pointer`.
1957
2149
 
@@ -1991,17 +2183,17 @@ def store(pointer, value, mask=None, boundary_check=(), cache_modifier="", evict
1991
2183
  :type eviction_policy: str, optional, should be one of {"", "evict_first", "evict_last"}
1992
2184
  """
1993
2185
  # `value` can be constexpr
1994
- value = semantic.to_tensor(value, _builder)
1995
- mask = _constexpr_to_value(mask)
2186
+ value = _semantic.to_tensor(value)
2187
+ mask = _unwrap_if_constexpr(mask)
1996
2188
  if mask is not None:
1997
- mask = semantic.to_tensor(mask, _builder)
1998
- cache_modifier = _constexpr_to_value(cache_modifier)
1999
- eviction_policy = _constexpr_to_value(eviction_policy)
2000
- return semantic.store(pointer, value, mask, boundary_check, cache_modifier, eviction_policy, _builder)
2189
+ mask = _semantic.to_tensor(mask)
2190
+ cache_modifier = _unwrap_if_constexpr(cache_modifier)
2191
+ eviction_policy = _unwrap_if_constexpr(eviction_policy)
2192
+ return _semantic.store(pointer, value, mask, boundary_check, cache_modifier, eviction_policy)
2001
2193
 
2002
2194
 
2003
2195
  @builtin
2004
- def make_block_ptr(base: tensor, shape, strides, offsets, block_shape, order, _builder=None):
2196
+ def make_block_ptr(base: tensor, shape, strides, offsets, block_shape, order, _semantic=None):
2005
2197
  """
2006
2198
  Returns a pointer to a block in a parent tensor
2007
2199
 
@@ -2012,30 +2204,34 @@ def make_block_ptr(base: tensor, shape, strides, offsets, block_shape, order, _b
2012
2204
  :param block_shape: The shape of the block
2013
2205
  :param order: The order of the original data format
2014
2206
  """
2015
- return semantic.make_block_ptr(base, shape, strides, offsets, block_shape, order, _builder)
2207
+ return _semantic.make_block_ptr(base, shape, strides, offsets, block_shape, order)
2016
2208
 
2017
2209
 
2210
+ @must_use_result(
2211
+ "Note that tl.advance does not have any side effects. To move the block pointer, you need to assign the result of tl.advance to a variable."
2212
+ )
2018
2213
  @_tensor_member_fn
2019
2214
  @builtin
2020
- def advance(base, offsets, _builder=None):
2215
+ def advance(base, offsets, _semantic=None):
2021
2216
  """
2022
2217
  Advance a block pointer
2023
2218
 
2024
2219
  :param base: the block pointer to advance
2025
2220
  :param offsets: the offsets to advance, a tuple by dimension
2026
2221
  """
2027
- return semantic.advance(base, offsets, _builder)
2222
+ return _semantic.advance(base, offsets)
2028
2223
 
2029
2224
 
2030
2225
  @builtin
2031
- def _experimental_make_tensor_descriptor(
2226
+ def make_tensor_descriptor(
2032
2227
  base: tensor,
2033
2228
  shape: List[tensor],
2034
2229
  strides: List[tensor],
2035
2230
  block_shape: List[constexpr],
2036
- _builder=None,
2037
- ) -> _experimental_tensor_descriptor:
2038
- """Make an experimental tensor descriptor object
2231
+ padding_option="zero",
2232
+ _semantic=None,
2233
+ ) -> tensor_descriptor:
2234
+ """Make a tensor descriptor object
2039
2235
 
2040
2236
  :param base: the base pointer of the tensor, must be 16-byte aligned
2041
2237
  :param shape: A list of non-negative integers representing the tensor shape
@@ -2056,7 +2252,7 @@ def _experimental_make_tensor_descriptor(
2056
2252
 
2057
2253
  @triton.jit
2058
2254
  def inplace_abs(in_out_ptr, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr):
2059
- desc = tl._experimental_make_tensor_descriptor(
2255
+ desc = tl.make_tensor_descriptor(
2060
2256
  in_out_ptr,
2061
2257
  shape=[M, N],
2062
2258
  strides=[N, 1],
@@ -2082,7 +2278,9 @@ def _experimental_make_tensor_descriptor(
2082
2278
  inplace_abs[grid](x, M, N, M_BLOCK, N_BLOCK)
2083
2279
 
2084
2280
  """
2085
- return semantic.make_tensor_descriptor(base, shape, strides, block_shape, _builder)
2281
+
2282
+ padding_option = _unwrap_if_constexpr(padding_option)
2283
+ return _semantic.make_tensor_descriptor(base, shape, strides, block_shape, padding_option)
2086
2284
 
2087
2285
 
2088
2286
  # -----------------------
@@ -2124,89 +2322,89 @@ def _add_atomic_docstr(name: str, has_cmp: bool = False) -> Callable[[T], T]:
2124
2322
  @_tensor_member_fn
2125
2323
  @builtin
2126
2324
  @_add_atomic_docstr("compare-and-swap", has_cmp=True)
2127
- def atomic_cas(pointer, cmp, val, sem=None, scope=None, _builder=None):
2128
- cmp = semantic.to_tensor(cmp, _builder)
2129
- val = semantic.to_tensor(val, _builder)
2130
- sem = _constexpr_to_value(sem)
2131
- scope = _constexpr_to_value(scope)
2132
- return semantic.atomic_cas(pointer, cmp, val, sem, scope, _builder)
2325
+ def atomic_cas(pointer, cmp, val, sem=None, scope=None, _semantic=None):
2326
+ cmp = _semantic.to_tensor(cmp)
2327
+ val = _semantic.to_tensor(val)
2328
+ sem = _unwrap_if_constexpr(sem)
2329
+ scope = _unwrap_if_constexpr(scope)
2330
+ return _semantic.atomic_cas(pointer, cmp, val, sem, scope)
2133
2331
 
2134
2332
 
2135
2333
  @_tensor_member_fn
2136
2334
  @builtin
2137
2335
  @_add_atomic_docstr("exchange")
2138
- def atomic_xchg(pointer, val, mask=None, sem=None, scope=None, _builder=None):
2139
- val = semantic.to_tensor(val, _builder)
2140
- sem = _constexpr_to_value(sem)
2141
- scope = _constexpr_to_value(scope)
2142
- mask = _constexpr_to_value(mask)
2143
- return semantic.atomic_xchg(pointer, val, mask, sem, scope, _builder)
2336
+ def atomic_xchg(pointer, val, mask=None, sem=None, scope=None, _semantic=None):
2337
+ val = _semantic.to_tensor(val)
2338
+ sem = _unwrap_if_constexpr(sem)
2339
+ scope = _unwrap_if_constexpr(scope)
2340
+ mask = _unwrap_if_constexpr(mask)
2341
+ return _semantic.atomic_xchg(pointer, val, mask, sem, scope)
2144
2342
 
2145
2343
 
2146
2344
  @_tensor_member_fn
2147
2345
  @builtin
2148
2346
  @_add_atomic_docstr("add")
2149
- def atomic_add(pointer, val, mask=None, sem=None, scope=None, _builder=None):
2150
- val = semantic.to_tensor(val, _builder)
2151
- sem = _constexpr_to_value(sem)
2152
- scope = _constexpr_to_value(scope)
2153
- mask = _constexpr_to_value(mask)
2154
- return semantic.atomic_add(pointer, val, mask, sem, scope, _builder)
2347
+ def atomic_add(pointer, val, mask=None, sem=None, scope=None, _semantic=None):
2348
+ val = _semantic.to_tensor(val)
2349
+ sem = _unwrap_if_constexpr(sem)
2350
+ scope = _unwrap_if_constexpr(scope)
2351
+ mask = _unwrap_if_constexpr(mask)
2352
+ return _semantic.atomic_add(pointer, val, mask, sem, scope)
2155
2353
 
2156
2354
 
2157
2355
  @_tensor_member_fn
2158
2356
  @builtin
2159
2357
  @_add_atomic_docstr("max")
2160
- def atomic_max(pointer, val, mask=None, sem=None, scope=None, _builder=None):
2161
- val = semantic.to_tensor(val, _builder)
2162
- sem = _constexpr_to_value(sem)
2163
- scope = _constexpr_to_value(scope)
2164
- mask = _constexpr_to_value(mask)
2165
- return semantic.atomic_max(pointer, val, mask, sem, scope, _builder)
2358
+ def atomic_max(pointer, val, mask=None, sem=None, scope=None, _semantic=None):
2359
+ val = _semantic.to_tensor(val)
2360
+ sem = _unwrap_if_constexpr(sem)
2361
+ scope = _unwrap_if_constexpr(scope)
2362
+ mask = _unwrap_if_constexpr(mask)
2363
+ return _semantic.atomic_max(pointer, val, mask, sem, scope)
2166
2364
 
2167
2365
 
2168
2366
  @_tensor_member_fn
2169
2367
  @builtin
2170
2368
  @_add_atomic_docstr("min")
2171
- def atomic_min(pointer, val, mask=None, sem=None, scope=None, _builder=None):
2172
- val = semantic.to_tensor(val, _builder)
2173
- sem = _constexpr_to_value(sem)
2174
- scope = _constexpr_to_value(scope)
2175
- mask = _constexpr_to_value(mask)
2176
- return semantic.atomic_min(pointer, val, mask, sem, scope, _builder)
2369
+ def atomic_min(pointer, val, mask=None, sem=None, scope=None, _semantic=None):
2370
+ val = _semantic.to_tensor(val)
2371
+ sem = _unwrap_if_constexpr(sem)
2372
+ scope = _unwrap_if_constexpr(scope)
2373
+ mask = _unwrap_if_constexpr(mask)
2374
+ return _semantic.atomic_min(pointer, val, mask, sem, scope)
2177
2375
 
2178
2376
 
2179
2377
  @_tensor_member_fn
2180
2378
  @builtin
2181
2379
  @_add_atomic_docstr("logical and")
2182
- def atomic_and(pointer, val, mask=None, sem=None, scope=None, _builder=None):
2183
- val = semantic.to_tensor(val, _builder)
2184
- sem = _constexpr_to_value(sem)
2185
- scope = _constexpr_to_value(scope)
2186
- mask = _constexpr_to_value(mask)
2187
- return semantic.atomic_and(pointer, val, mask, sem, scope, _builder)
2380
+ def atomic_and(pointer, val, mask=None, sem=None, scope=None, _semantic=None):
2381
+ val = _semantic.to_tensor(val)
2382
+ sem = _unwrap_if_constexpr(sem)
2383
+ scope = _unwrap_if_constexpr(scope)
2384
+ mask = _unwrap_if_constexpr(mask)
2385
+ return _semantic.atomic_and(pointer, val, mask, sem, scope)
2188
2386
 
2189
2387
 
2190
2388
  @_tensor_member_fn
2191
2389
  @builtin
2192
2390
  @_add_atomic_docstr("logical or")
2193
- def atomic_or(pointer, val, mask=None, sem=None, scope=None, _builder=None):
2194
- val = semantic.to_tensor(val, _builder)
2195
- sem = _constexpr_to_value(sem)
2196
- scope = _constexpr_to_value(scope)
2197
- mask = _constexpr_to_value(mask)
2198
- return semantic.atomic_or(pointer, val, mask, sem, scope, _builder)
2391
+ def atomic_or(pointer, val, mask=None, sem=None, scope=None, _semantic=None):
2392
+ val = _semantic.to_tensor(val)
2393
+ sem = _unwrap_if_constexpr(sem)
2394
+ scope = _unwrap_if_constexpr(scope)
2395
+ mask = _unwrap_if_constexpr(mask)
2396
+ return _semantic.atomic_or(pointer, val, mask, sem, scope)
2199
2397
 
2200
2398
 
2201
2399
  @_tensor_member_fn
2202
2400
  @builtin
2203
2401
  @_add_atomic_docstr("logical xor")
2204
- def atomic_xor(pointer, val, mask=None, sem=None, scope=None, _builder=None):
2205
- val = semantic.to_tensor(val, _builder)
2206
- sem = _constexpr_to_value(sem)
2207
- scope = _constexpr_to_value(scope)
2208
- mask = _constexpr_to_value(mask)
2209
- return semantic.atomic_xor(pointer, val, mask, sem, scope, _builder)
2402
+ def atomic_xor(pointer, val, mask=None, sem=None, scope=None, _semantic=None):
2403
+ val = _semantic.to_tensor(val)
2404
+ sem = _unwrap_if_constexpr(sem)
2405
+ scope = _unwrap_if_constexpr(scope)
2406
+ mask = _unwrap_if_constexpr(mask)
2407
+ return _semantic.atomic_xor(pointer, val, mask, sem, scope)
2210
2408
 
2211
2409
 
2212
2410
  # -----------------------
@@ -2215,7 +2413,7 @@ def atomic_xor(pointer, val, mask=None, sem=None, scope=None, _builder=None):
2215
2413
 
2216
2414
 
2217
2415
  @builtin
2218
- def where(condition, x, y, _builder=None):
2416
+ def where(condition, x, y, _semantic=None):
2219
2417
  """
2220
2418
  Returns a tensor of elements from either :code:`x` or :code:`y`, depending on :code:`condition`.
2221
2419
 
@@ -2231,10 +2429,10 @@ def where(condition, x, y, _builder=None):
2231
2429
  :param x: values selected at indices where condition is True.
2232
2430
  :param y: values selected at indices where condition is False.
2233
2431
  """
2234
- condition = semantic.to_tensor(condition, _builder)
2432
+ condition = _semantic.to_tensor(condition)
2235
2433
  x = _unwrap_if_constexpr(x)
2236
2434
  y = _unwrap_if_constexpr(y)
2237
- return semantic.where(condition, x, y, _builder)
2435
+ return _semantic.where(condition, x, y)
2238
2436
 
2239
2437
 
2240
2438
  # -----------------------
@@ -2243,28 +2441,28 @@ def where(condition, x, y, _builder=None):
2243
2441
 
2244
2442
 
2245
2443
  @builtin
2246
- def add(x, y, sanitize_overflow: constexpr = True, _builder=None):
2444
+ def add(x, y, sanitize_overflow: constexpr = True, _semantic=None):
2247
2445
  x = _unwrap_if_constexpr(x)
2248
2446
  y = _unwrap_if_constexpr(y)
2249
- return semantic.add(x, y, sanitize_overflow, _builder)
2447
+ return _semantic.add(x, y, sanitize_overflow)
2250
2448
 
2251
2449
 
2252
2450
  @builtin
2253
- def sub(x, y, sanitize_overflow: constexpr = True, _builder=None):
2451
+ def sub(x, y, sanitize_overflow: constexpr = True, _semantic=None):
2254
2452
  x = _unwrap_if_constexpr(x)
2255
2453
  y = _unwrap_if_constexpr(y)
2256
- return semantic.sub(x, y, sanitize_overflow, _builder)
2454
+ return _semantic.sub(x, y, sanitize_overflow)
2257
2455
 
2258
2456
 
2259
2457
  @builtin
2260
- def mul(x, y, sanitize_overflow: constexpr = True, _builder=None):
2458
+ def mul(x, y, sanitize_overflow: constexpr = True, _semantic=None):
2261
2459
  x = _unwrap_if_constexpr(x)
2262
2460
  y = _unwrap_if_constexpr(y)
2263
- return semantic.mul(x, y, sanitize_overflow, _builder)
2461
+ return _semantic.mul(x, y, sanitize_overflow)
2264
2462
 
2265
2463
 
2266
2464
  @builtin
2267
- def minimum(x, y, propagate_nan: constexpr = PropagateNan.NONE, _builder=None):
2465
+ def minimum(x, y, propagate_nan: constexpr = PropagateNan.NONE, _semantic=None):
2268
2466
  """
2269
2467
  Computes the element-wise minimum of :code:`x` and :code:`y`.
2270
2468
 
@@ -2277,16 +2475,16 @@ def minimum(x, y, propagate_nan: constexpr = PropagateNan.NONE, _builder=None):
2277
2475
 
2278
2476
  .. seealso:: :class:`tl.PropagateNan`
2279
2477
  """
2280
- x = semantic.to_tensor(x, _builder)
2281
- y = semantic.to_tensor(y, _builder)
2282
- x = _promote_bfloat16_to_float32(x, _builder=_builder)
2283
- y = _promote_bfloat16_to_float32(y, _builder=_builder)
2284
- propagate_nan = _constexpr_to_value(propagate_nan)
2285
- return semantic.minimum(x, y, propagate_nan, _builder)
2478
+ x = _semantic.to_tensor(x)
2479
+ y = _semantic.to_tensor(y)
2480
+ x = _promote_bfloat16_to_float32(x, _semantic=_semantic)
2481
+ y = _promote_bfloat16_to_float32(y, _semantic=_semantic)
2482
+ propagate_nan = _unwrap_if_constexpr(propagate_nan)
2483
+ return _semantic.minimum(x, y, propagate_nan)
2286
2484
 
2287
2485
 
2288
2486
  @builtin
2289
- def maximum(x, y, propagate_nan: constexpr = PropagateNan.NONE, _builder=None):
2487
+ def maximum(x, y, propagate_nan: constexpr = PropagateNan.NONE, _semantic=None):
2290
2488
  """
2291
2489
  Computes the element-wise maximum of :code:`x` and :code:`y`.
2292
2490
 
@@ -2299,16 +2497,16 @@ def maximum(x, y, propagate_nan: constexpr = PropagateNan.NONE, _builder=None):
2299
2497
 
2300
2498
  .. seealso:: :class:`tl.PropagateNan`
2301
2499
  """
2302
- x = semantic.to_tensor(x, _builder)
2303
- y = semantic.to_tensor(y, _builder)
2304
- x = _promote_bfloat16_to_float32(x, _builder=_builder)
2305
- y = _promote_bfloat16_to_float32(y, _builder=_builder)
2306
- propagate_nan = _constexpr_to_value(propagate_nan)
2307
- return semantic.maximum(x, y, propagate_nan, _builder)
2500
+ x = _semantic.to_tensor(x)
2501
+ y = _semantic.to_tensor(y)
2502
+ x = _promote_bfloat16_to_float32(x, _semantic=_semantic)
2503
+ y = _promote_bfloat16_to_float32(y, _semantic=_semantic)
2504
+ propagate_nan = _unwrap_if_constexpr(propagate_nan)
2505
+ return _semantic.maximum(x, y, propagate_nan)
2308
2506
 
2309
2507
 
2310
2508
  @builtin
2311
- def clamp(x, min, max, propagate_nan: constexpr = PropagateNan.NONE, _builder=None):
2509
+ def clamp(x, min, max, propagate_nan: constexpr = PropagateNan.NONE, _semantic=None):
2312
2510
  """
2313
2511
  Clamps the input tensor :code:`x` within the range [min, max].
2314
2512
  Behavior when :code:`min` > :code:`max` is undefined.
@@ -2325,16 +2523,16 @@ def clamp(x, min, max, propagate_nan: constexpr = PropagateNan.NONE, _builder=No
2325
2523
 
2326
2524
  .. seealso:: :class:`tl.PropagateNan`
2327
2525
  """
2328
- x = semantic.to_tensor(x, _builder)
2329
- min = semantic.to_tensor(min, _builder)
2330
- max = semantic.to_tensor(max, _builder)
2331
- x = _promote_bfloat16_to_float32(x, _builder=_builder)
2332
- min = _promote_bfloat16_to_float32(min, _builder=_builder)
2333
- max = _promote_bfloat16_to_float32(max, _builder=_builder)
2526
+ x = _semantic.to_tensor(x)
2527
+ min = _semantic.to_tensor(min)
2528
+ max = _semantic.to_tensor(max)
2529
+ x = _promote_bfloat16_to_float32(x, _semantic=_semantic)
2530
+ min = _promote_bfloat16_to_float32(min, _semantic=_semantic)
2531
+ max = _promote_bfloat16_to_float32(max, _semantic=_semantic)
2334
2532
 
2335
- propagate_nan = _constexpr_to_value(propagate_nan)
2533
+ propagate_nan = _unwrap_if_constexpr(propagate_nan)
2336
2534
 
2337
- return semantic.clamp(x, min, max, propagate_nan, _builder)
2535
+ return _semantic.clamp(x, min, max, propagate_nan)
2338
2536
 
2339
2537
 
2340
2538
  # -----------------------
@@ -2383,7 +2581,7 @@ def _insertion_guard(builder):
2383
2581
 
2384
2582
  @_tensor_member_fn
2385
2583
  @builtin
2386
- def reduce(input, axis, combine_fn, keep_dims=False, _builder=None, _generator=None):
2584
+ def reduce(input, axis, combine_fn, keep_dims=False, _semantic=None, _generator=None):
2387
2585
  """Applies the combine_fn to all elements in :code:`input` tensors along the provided :code:`axis`
2388
2586
 
2389
2587
  :param input: the input tensor, or tuple of tensors
@@ -2397,64 +2595,65 @@ def reduce(input, axis, combine_fn, keep_dims=False, _builder=None, _generator=N
2397
2595
 
2398
2596
  """
2399
2597
  if isinstance(input, tensor):
2400
- return reduce((input, ), axis, combine_fn, keep_dims=keep_dims, _builder=_builder, _generator=_generator)[0]
2598
+ return reduce((input, ), axis, combine_fn, keep_dims=keep_dims, _semantic=_semantic, _generator=_generator)[0]
2401
2599
 
2402
2600
  def make_combine_region(reduce_op):
2403
2601
  param_types = [t.type.scalar for t in input] * 2
2404
2602
  region = reduce_op.get_region(0)
2405
- with _insertion_guard(_builder):
2406
- to_ir = lambda T: T.to_ir(_builder)
2407
- block = _builder.create_block_with_parent(region, list(map(to_ir, param_types)))
2603
+ builder = _semantic.builder
2604
+ with _insertion_guard(builder):
2605
+ to_ir = lambda T: T.to_ir(builder)
2606
+ block = builder.create_block_with_parent(region, list(map(to_ir, param_types)))
2408
2607
  args = [tensor(block.arg(i), ty) for i, ty in enumerate(param_types)]
2409
2608
  results = _generator.call_JitFunction(combine_fn, args, kwargs={})
2410
2609
  if isinstance(results, tensor):
2411
2610
  handles = [results.handle]
2412
2611
  else:
2413
2612
  handles = [r.handle for r in results]
2414
- _builder.create_reduce_ret(*handles)
2613
+ builder.create_reduce_ret(*handles)
2415
2614
 
2416
2615
  def expand_ndims(t, ndims):
2417
2616
  for _ in builtins.range(ndims):
2418
- t = expand_dims(t, 0, _builder=_builder)
2617
+ t = expand_dims(t, 0, _semantic=_semantic)
2419
2618
  return t
2420
2619
 
2421
- axis = _constexpr_to_value(axis)
2422
- keep_dims = _constexpr_to_value(keep_dims)
2620
+ axis = _unwrap_if_constexpr(axis)
2621
+ keep_dims = _unwrap_if_constexpr(keep_dims)
2423
2622
  if axis is not None:
2424
2623
  axis = _wrap_axis(axis, len(input[0].shape))
2425
- ret = semantic.reduction(input, axis, make_combine_region, _builder)
2624
+ ret = _semantic.reduction(input, axis, make_combine_region)
2426
2625
  if keep_dims:
2427
2626
  if axis is not None:
2428
- ret = tuple(expand_dims(t, axis, _builder=_builder) for t in ret)
2627
+ ret = tuple(expand_dims(t, axis, _semantic=_semantic) for t in ret)
2429
2628
  else:
2430
2629
  ret = tuple(expand_ndims(t, len(input[0].shape)) for t in ret)
2431
2630
  return ret
2432
2631
 
2433
2632
 
2434
2633
  @builtin
2435
- def _promote_bfloat16_to_float32(t, _builder=None):
2634
+ def _promote_bfloat16_to_float32(t, _semantic=None):
2436
2635
  scalar_ty = t.type.scalar
2437
2636
 
2438
2637
  # hardware doesn't support FMAX, FMIN, CMP for bfloat16
2439
2638
  if scalar_ty is bfloat16:
2440
- return t.to(float32, _builder=_builder)
2639
+ return t.to(float32, _semantic=_semantic)
2441
2640
  return t
2442
2641
 
2443
2642
 
2444
2643
  @builtin
2445
- def _reduce_with_indices(input, axis, combine_fn, keep_dims=False, _builder=None, _generator=None):
2446
- axis = _constexpr_to_value(axis)
2644
+ def _reduce_with_indices(input, axis, combine_fn, keep_dims=False, _semantic=None, _generator=None):
2645
+ axis = _unwrap_if_constexpr(axis)
2447
2646
  n = input.shape[axis]
2448
- index = arange(0, n, _builder=_builder)
2647
+ index = arange(0, n, _semantic=_semantic)
2449
2648
 
2450
2649
  if len(input.shape) > 1:
2451
2650
  # Broadcast index across the non-reduced axes
2452
2651
  axes_to_expand = [constexpr(d) for d in builtins.range(len(input.shape))]
2453
2652
  del axes_to_expand[axis]
2454
- index = expand_dims(index, axes_to_expand, _builder=_builder)
2455
- index = broadcast_to(index, input.shape, _builder=_builder)
2653
+ index = expand_dims(index, axes_to_expand, _semantic=_semantic)
2654
+ index = broadcast_to(index, input.shape, _semantic=_semantic)
2456
2655
 
2457
- rvalue, rindices = reduce((input, index), axis, combine_fn, keep_dims=keep_dims, _builder=_builder,
2656
+ rvalue, rindices = reduce((input, index), axis, combine_fn, keep_dims=keep_dims, _semantic=_semantic,
2458
2657
  _generator=_generator)
2459
2658
  return rvalue, rindices
2460
2659
 
@@ -2464,7 +2663,7 @@ def _reduce_with_indices(input, axis, combine_fn, keep_dims=False, _builder=None
2464
2663
  # -----------------------
2465
2664
 
2466
2665
 
2467
- def _add_scan_docstr(name: str) -> Callable[[T], T]:
2666
+ def _add_scan_docstr(name: str, dtype_arg: str = None) -> Callable[[T], T]:
2468
2667
 
2469
2668
  def _decorator(func: T) -> T:
2470
2669
  docstr = """
@@ -2473,7 +2672,15 @@ def _add_scan_docstr(name: str) -> Callable[[T], T]:
2473
2672
  :param input: the input values
2474
2673
  :type input: Tensor
2475
2674
  :param axis: the dimension along which the scan should be done
2476
- :type axis: int"""
2675
+ :type axis: int
2676
+ :param reverse: if true, the scan is performed in the reverse direction
2677
+ :type reverse: bool"""
2678
+
2679
+ if dtype_arg is not None:
2680
+ docstr += f"""
2681
+ :param {dtype_arg}: the desired data type of the returned tensor. If specified, the input tensor is casted to :code:`{dtype_arg}` before the operation is performed. If not specified, small integer types (< 32 bits) are upcasted to prevent overflow. Note that :code:`tl.bfloat16` inputs are automatically promoted to :code:`tl.float32`.
2682
+ :type {dtype_arg}: tl.dtype"""
2683
+
2477
2684
  func.__doc__ = docstr.format(name=name)
2478
2685
  return func
2479
2686
 
@@ -2482,7 +2689,7 @@ def _add_scan_docstr(name: str) -> Callable[[T], T]:
2482
2689
 
2483
2690
  @_tensor_member_fn
2484
2691
  @builtin
2485
- def associative_scan(input, axis, combine_fn, reverse=False, _builder=None, _generator=None):
2692
+ def associative_scan(input, axis, combine_fn, reverse=False, _semantic=None, _generator=None):
2486
2693
  """Applies the combine_fn to each elements with a carry in :code:`input` tensors along the provided :code:`axis` and update the carry
2487
2694
 
2488
2695
  :param input: the input tensor, or tuple of tensors
@@ -2496,46 +2703,52 @@ def associative_scan(input, axis, combine_fn, reverse=False, _builder=None, _gen
2496
2703
 
2497
2704
  """
2498
2705
  if isinstance(input, tensor):
2499
- return associative_scan((input, ), axis, combine_fn, reverse, _builder=_builder, _generator=_generator)[0]
2706
+ return associative_scan((input, ), axis, combine_fn, reverse, _semantic=_semantic, _generator=_generator)[0]
2500
2707
 
2501
2708
  def make_combine_region(scan_op):
2502
2709
  param_types = [t.type.scalar for t in input] * 2
2503
2710
  region = scan_op.get_region(0)
2504
- with _insertion_guard(_builder):
2505
- to_ir = lambda T: T.to_ir(_builder)
2506
- block = _builder.create_block_with_parent(region, list(map(to_ir, param_types)))
2711
+ builder = _semantic.builder
2712
+ with _insertion_guard(builder):
2713
+ to_ir = lambda T: T.to_ir(builder)
2714
+ block = builder.create_block_with_parent(region, list(map(to_ir, param_types)))
2507
2715
  args = [tensor(block.arg(i), ty) for i, ty in enumerate(param_types)]
2508
2716
  results = _generator.call_JitFunction(combine_fn, args, kwargs={})
2509
2717
  if isinstance(results, tensor):
2510
2718
  handles = [results.handle]
2511
2719
  else:
2512
2720
  handles = [r.handle for r in results]
2513
- _builder.create_scan_ret(*handles)
2721
+ builder.create_scan_ret(*handles)
2514
2722
 
2515
- axis = _constexpr_to_value(axis)
2723
+ axis = _unwrap_if_constexpr(axis)
2516
2724
  if axis is not None:
2517
2725
  axis = _wrap_axis(axis, len(input[0].shape))
2518
- return semantic.associative_scan(input, axis, make_combine_region, reverse, _builder)
2726
+ return _semantic.associative_scan(input, axis, make_combine_region, reverse)
2519
2727
 
2520
2728
 
2521
2729
  @_tensor_member_fn
2522
2730
  @builtin
2523
- def histogram(input, num_bins, _builder=None, _generator=None):
2731
+ def histogram(input, num_bins, mask=None, _semantic=None, _generator=None):
2524
2732
  """computes an histogram based on input tensor with num_bins bins, the bins have a width of 1 and start at 0.
2525
2733
 
2526
2734
  :param input: the input tensor
2527
2735
  :type input: Tensor
2528
2736
  :param num_bins: number of histogram bins
2529
2737
  :type num_bins: int
2738
+ :param mask: if `mask[idx]` is false, exclude `input[idx]` from histogram
2739
+ :type mask: Block of `triton.int1`, optional
2530
2740
 
2531
2741
  """
2532
- num_bins = _constexpr_to_value(num_bins)
2533
- return semantic.histogram(input, num_bins, _builder)
2742
+ num_bins = _unwrap_if_constexpr(num_bins)
2743
+ mask = _unwrap_if_constexpr(mask)
2744
+ if mask is not None:
2745
+ mask = _semantic.to_tensor(mask)
2746
+ return _semantic.histogram(input, num_bins, mask)
2534
2747
 
2535
2748
 
2536
2749
  @_tensor_member_fn
2537
2750
  @builtin
2538
- def gather(src, index, axis, _builder=None):
2751
+ def gather(src, index, axis, _semantic=None):
2539
2752
  """Gather from a tensor along a given dimension.
2540
2753
 
2541
2754
  :param src: the source tensor
@@ -2546,8 +2759,81 @@ def gather(src, index, axis, _builder=None):
2546
2759
  :type axis: int
2547
2760
 
2548
2761
  """
2549
- axis = _constexpr_to_value(axis)
2550
- return semantic.gather(src, index, axis, _builder)
2762
+ axis = _unwrap_if_constexpr(axis)
2763
+ return _semantic.gather(src, index, axis)
2764
+
2765
+
2766
+ @builtin
2767
+ def map_elementwise(
2768
+ scalar_fn: Callable[..., Tuple[tensor, ...]],
2769
+ *args: tensor,
2770
+ pack=1,
2771
+ _semantic=None,
2772
+ _generator=None,
2773
+ ):
2774
+ '''
2775
+ Map a scalar function over a tensor.
2776
+
2777
+ The input tensors :code:`args` are implicitly broadcasted to the same shape.
2778
+
2779
+ This may be useful in allowing control flow over single elements in a tensor,
2780
+ for example a multi-branch function where one branch is more expensive. With
2781
+ :code:`tl.where` you are forced to calculate both sides of the branch, but
2782
+ with an if we only execute one side.
2783
+
2784
+ .. highlight:: python
2785
+ .. code-block:: python
2786
+
2787
+ @triton.jit
2788
+ def selu_scalar(x, alpha):
2789
+ if x > 0:
2790
+ return a
2791
+ else:
2792
+ return alpha * (tl.exp(x) - 1)
2793
+
2794
+ @triton.jit
2795
+ def selu(x, alpha):
2796
+ return tl.map_elementwise(selu_scalar, x, alpha)
2797
+
2798
+ :param scalar_fn: the function to map over.
2799
+ :param pack: the number of elements to be processed by one function call.
2800
+ :return: one tensor or a tuple of tensors, depending on the mapped function.
2801
+ '''
2802
+ # Build the block for the nested region first to discover the return types
2803
+ assert pack >= 1
2804
+ in_scalar_tys = [t.type.scalar for t in args]
2805
+ builder = _semantic.builder
2806
+ block = builder.new_block()
2807
+ scalar_args = []
2808
+ for i, ty in enumerate(in_scalar_tys):
2809
+ for j in builtins.range(pack):
2810
+ block.add_argument(ty.to_ir(builder))
2811
+ scalar_args.append(tensor(block.arg(i * pack + j), ty))
2812
+
2813
+ with _insertion_guard(builder):
2814
+ builder.set_insertion_point_to_start(block)
2815
+ scalar_results = _generator.call_JitFunction(scalar_fn, scalar_args, kwargs={})
2816
+
2817
+ is_single = isinstance(scalar_results, tensor)
2818
+ if is_single:
2819
+ scalar_results = scalar_results,
2820
+
2821
+ handles = [r.handle for r in scalar_results]
2822
+ builder.create_map_elementwise_ret(handles)
2823
+
2824
+ fn_result_types = [x.type for x in scalar_results]
2825
+ scalar_result_types = fn_result_types
2826
+ if pack > 1:
2827
+ scalar_result_types = fn_result_types[::pack]
2828
+ for offset in builtins.range(1, pack):
2829
+ assert scalar_result_types == fn_result_types[offset::pack], "type mismatch in unpacked results"
2830
+
2831
+ def make_elementwise_region(elementwise_op):
2832
+ region = elementwise_op.get_region(0)
2833
+ region.push_back(block)
2834
+
2835
+ result = _semantic.map_elementwise(args, scalar_result_types, pack, make_elementwise_region)
2836
+ return result[0] if is_single else result
2551
2837
 
2552
2838
 
2553
2839
  # -----------------------
@@ -2556,15 +2842,15 @@ def gather(src, index, axis, _builder=None):
2556
2842
 
2557
2843
 
2558
2844
  @builtin
2559
- def debug_barrier(_builder=None):
2845
+ def debug_barrier(_semantic=None):
2560
2846
  '''
2561
2847
  Insert a barrier to synchronize all threads in a block.
2562
2848
  '''
2563
- return semantic.debug_barrier(_builder)
2849
+ return _semantic.debug_barrier()
2564
2850
 
2565
2851
 
2566
2852
  @builtin
2567
- def multiple_of(input, values, _builder=None):
2853
+ def multiple_of(input, values, _semantic=None):
2568
2854
  """
2569
2855
  Let the compiler know that the values in :code:`input` are all multiples of :code:`value`.
2570
2856
  """
@@ -2576,11 +2862,11 @@ def multiple_of(input, values, _builder=None):
2576
2862
  if not isinstance(d.value, int):
2577
2863
  raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]")
2578
2864
  values = [x.value for x in values]
2579
- return semantic.multiple_of(input, values)
2865
+ return _semantic.multiple_of(input, values)
2580
2866
 
2581
2867
 
2582
2868
  @builtin
2583
- def max_contiguous(input, values, _builder=None):
2869
+ def max_contiguous(input, values, _semantic=None):
2584
2870
  """
2585
2871
  Let the compiler know that the `value` first values in :code:`input` are contiguous.
2586
2872
  """
@@ -2592,11 +2878,11 @@ def max_contiguous(input, values, _builder=None):
2592
2878
  if not isinstance(d.value, int):
2593
2879
  raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]")
2594
2880
  values = [x.value for x in values]
2595
- return semantic.max_contiguous(input, values)
2881
+ return _semantic.max_contiguous(input, values)
2596
2882
 
2597
2883
 
2598
2884
  @builtin
2599
- def max_constancy(input, values, _builder=None):
2885
+ def max_constancy(input, values, _semantic=None):
2600
2886
  """
2601
2887
  Let the compiler know that the `value` first values in :code:`input` are constant.
2602
2888
 
@@ -2611,15 +2897,15 @@ def max_constancy(input, values, _builder=None):
2611
2897
  if not isinstance(d.value, int):
2612
2898
  raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]")
2613
2899
  values = [x.value for x in values]
2614
- return semantic.max_constancy(input, values)
2900
+ return _semantic.max_constancy(input, values)
2615
2901
 
2616
2902
 
2617
2903
  @builtin
2618
- def assume(cond, _builder=None):
2904
+ def assume(cond, _semantic=None):
2619
2905
  '''
2620
2906
  Allow compiler to assume the :code:`cond` is True.
2621
2907
  '''
2622
- return semantic.assume(semantic.to_tensor(cond, _builder), _builder)
2908
+ return _semantic.assume(_semantic.to_tensor(cond))
2623
2909
 
2624
2910
 
2625
2911
  # -----------------------
@@ -2628,7 +2914,7 @@ def assume(cond, _builder=None):
2628
2914
 
2629
2915
 
2630
2916
  @builtin
2631
- def static_print(*values, sep: str = " ", end: str = "\n", file=None, flush=False, _builder=None):
2917
+ def static_print(*values, sep: str = " ", end: str = "\n", file=None, flush=False, _semantic=None):
2632
2918
  '''
2633
2919
  Print the values at compile time. The parameters are the same as the builtin :code:`print`.
2634
2920
 
@@ -2644,7 +2930,7 @@ def static_print(*values, sep: str = " ", end: str = "\n", file=None, flush=Fals
2644
2930
 
2645
2931
 
2646
2932
  @builtin
2647
- def static_assert(cond, msg="", _builder=None):
2933
+ def static_assert(cond, msg="", _semantic=None):
2648
2934
  '''
2649
2935
  Assert the condition at compile time. Does not require that the :code:`TRITON_DEBUG` environment variable
2650
2936
  is set.
@@ -2658,7 +2944,7 @@ def static_assert(cond, msg="", _builder=None):
2658
2944
 
2659
2945
 
2660
2946
  @builtin
2661
- def device_print(prefix, *args, hex=False, _builder=None):
2947
+ def device_print(prefix, *args, hex=False, _semantic=None):
2662
2948
  '''
2663
2949
  Print the values at runtime from the device. String formatting does not work for runtime values, so you should
2664
2950
  provide the values you want to print as arguments. The first value must be a string, all following values must
@@ -2692,7 +2978,7 @@ def device_print(prefix, *args, hex=False, _builder=None):
2692
2978
  :param hex: print all values as hex instead of decimal
2693
2979
  '''
2694
2980
  import string
2695
- prefix = _constexpr_to_value(prefix)
2981
+ prefix = _unwrap_if_constexpr(prefix)
2696
2982
  assert isinstance(prefix, str), f"{prefix} is not string"
2697
2983
  b_ascii = True
2698
2984
  for ch in prefix:
@@ -2702,12 +2988,12 @@ def device_print(prefix, *args, hex=False, _builder=None):
2702
2988
  assert b_ascii, f"{prefix} is not an ascii string"
2703
2989
  new_args = []
2704
2990
  for arg in args:
2705
- new_args.append(semantic.to_tensor(arg, _builder))
2706
- return semantic.device_print(prefix, new_args, hex, _builder)
2991
+ new_args.append(_semantic.to_tensor(arg))
2992
+ return _semantic.device_print(prefix, new_args, hex)
2707
2993
 
2708
2994
 
2709
2995
  @builtin
2710
- def device_assert(cond, msg="", _builder=None):
2996
+ def device_assert(cond, msg="", mask=None, _semantic=None):
2711
2997
  '''
2712
2998
  Assert the condition at runtime from the device. Requires that the environment variable :code:`TRITON_DEBUG`
2713
2999
  is set to a value besides :code:`0` in order for this to have any effect.
@@ -2725,13 +3011,16 @@ def device_assert(cond, msg="", _builder=None):
2725
3011
  :param cond: the condition to assert. This is required to be a boolean tensor.
2726
3012
  :param msg: the message to print if the assertion fails. This is required to be a string literal.
2727
3013
  '''
2728
- msg = _constexpr_to_value(msg)
2729
- return semantic.device_assert(semantic.to_tensor(cond, _builder), msg, _builder)
3014
+ msg = _unwrap_if_constexpr(msg)
3015
+ mask = _unwrap_if_constexpr(mask)
3016
+ if mask is not None:
3017
+ mask = _semantic.to_tensor(mask)
3018
+ return _semantic.device_assert(_semantic.to_tensor(cond), msg, mask)
2730
3019
 
2731
3020
 
2732
3021
  @builtin
2733
3022
  def inline_asm_elementwise(asm: str, constraints: str, args: Sequence, dtype: Union[dtype, Sequence[dtype]],
2734
- is_pure: bool, pack: int, _builder=None):
3023
+ is_pure: bool, pack: int, _semantic=None):
2735
3024
  '''
2736
3025
  Execute inline assembly over a tensor. Essentially, this is :code:`map`
2737
3026
  where the function is inline assembly.
@@ -2816,13 +3105,12 @@ def inline_asm_elementwise(asm: str, constraints: str, args: Sequence, dtype: Un
2816
3105
  :param dtype: the element type(s) of the returned tensor(s)
2817
3106
  :param is_pure: if true, the compiler assumes the asm block has no side-effects
2818
3107
  :param pack: the number of elements to be processed by one instance of inline assembly
2819
- :param _builder: the builder
2820
3108
  :return: one tensor or a tuple of tensors of the given dtypes
2821
3109
  '''
2822
- asm = _constexpr_to_value(asm)
2823
- constraints = _constexpr_to_value(constraints)
2824
- pack = _constexpr_to_value(pack)
2825
- is_pure = _constexpr_to_value(is_pure)
3110
+ asm = _unwrap_if_constexpr(asm)
3111
+ constraints = _unwrap_if_constexpr(constraints)
3112
+ pack = _unwrap_if_constexpr(pack)
3113
+ is_pure = _unwrap_if_constexpr(is_pure)
2826
3114
 
2827
3115
  # Wrap `dtype` in a tuple if it's not already.
2828
3116
  try:
@@ -2835,10 +3123,9 @@ def inline_asm_elementwise(asm: str, constraints: str, args: Sequence, dtype: Un
2835
3123
  dtype = typing.cast(Sequence[_DtypeClass], dtype)
2836
3124
 
2837
3125
  res_tys = dtype
2838
- if dispatch_args := [semantic.to_tensor(arg, _builder) for arg in args]:
3126
+ if dispatch_args := [_semantic.to_tensor(arg) for arg in args]:
2839
3127
  bin_op_type_checking = partial(
2840
- semantic.binary_op_type_checking_impl,
2841
- builder=_builder,
3128
+ _semantic.binary_op_type_checking_impl,
2842
3129
  arithmetic_check=False,
2843
3130
  allow_lhs_ptr=True,
2844
3131
  allow_rhs_ptr=True,
@@ -2851,9 +3138,10 @@ def inline_asm_elementwise(asm: str, constraints: str, args: Sequence, dtype: Un
2851
3138
  # Change the shape of each argument based on the broadcast shape
2852
3139
  for i, item in enumerate(dispatch_args):
2853
3140
  dispatch_args[i], _ = bin_op_type_checking(item, broadcast_arg)
2854
- res_tys = [block_type(dt, broadcast_arg.shape) for dt in dtype]
3141
+ res_tys = [broadcast_arg.type.with_element_ty(dt) for dt in dtype]
2855
3142
  handles = [t.handle for t in dispatch_args]
2856
- call = _builder.create_inline_asm(asm, constraints, handles, [ty.to_ir(_builder) for ty in res_tys], is_pure, pack)
3143
+ builder = _semantic.builder
3144
+ call = builder.create_inline_asm(asm, constraints, handles, [ty.to_ir(builder) for ty in res_tys], is_pure, pack)
2857
3145
 
2858
3146
  if not has_multiple_outputs:
2859
3147
  return tensor(call.get_result(0), res_tys[0])
@@ -2865,7 +3153,7 @@ def inline_asm_elementwise(asm: str, constraints: str, args: Sequence, dtype: Un
2865
3153
  # -----------------------
2866
3154
 
2867
3155
 
2868
- class static_range:
3156
+ class static_range(base_value):
2869
3157
  """
2870
3158
  Iterator that counts upward forever.
2871
3159
 
@@ -2905,7 +3193,23 @@ class static_range:
2905
3193
  raise RuntimeError("static_range can only be used in @triton.jit'd functions")
2906
3194
 
2907
3195
 
2908
- class range:
3196
+ class async_task:
3197
+ """
3198
+ Context manager to run code fragments asynchronously.
3199
+ """
3200
+
3201
+ def __init__(self, task_ids, _builder=None):
3202
+ self.task_ids = list({_unwrap_if_constexpr(tid) for tid in task_ids})
3203
+ self.builder = _builder
3204
+
3205
+ def __enter__(self):
3206
+ self.builder.set_async_task_ids(self.task_ids)
3207
+
3208
+ def __exit__(self, exc_type, exc_value, traceback):
3209
+ self.builder.unset_async_task_ids()
3210
+
3211
+
3212
+ class range(base_value):
2909
3213
  """
2910
3214
  Iterator that counts upward forever.
2911
3215
 
@@ -2936,10 +3240,21 @@ class range:
2936
3240
  :param flatten: automatically flatten the loop nest starting at this loop to
2937
3241
  create a single flattened loop. The compiler will try to pipeline the
2938
3242
  flattened loop which can avoid stage stalling.
3243
+ :param warp_specialize: Enable automatic warp specialization on the loop.
3244
+ The compiler will attempt to partition memory, MMA, and vector
3245
+ operations in the loop into separate async partitions. This will
3246
+ increase the total number of warps required by the kernel.
3247
+ :param disable_licm: Tells the compiler it shouldn't hoist loop invariant
3248
+ code outside the loop. This is often useful to avoid creating long liveranges
3249
+ within a loop.
3250
+
3251
+ Note that warp specialization is only supported on Blackwell GPUs and
3252
+ only works on simple matmul loops. Support for arbitrary loops will be
3253
+ expanded over time.
2939
3254
  """
2940
3255
 
2941
3256
  def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_factor=None,
2942
- disallow_acc_multi_buffer=False, flatten=False):
3257
+ disallow_acc_multi_buffer=False, flatten=False, warp_specialize=False, disable_licm=False):
2943
3258
  if step is None:
2944
3259
  self.step = constexpr(1)
2945
3260
  else:
@@ -2954,6 +3269,8 @@ class range:
2954
3269
  self.loop_unroll_factor = loop_unroll_factor
2955
3270
  self.disallow_acc_multi_buffer = disallow_acc_multi_buffer
2956
3271
  self.flatten = flatten
3272
+ self.warp_specialize = warp_specialize
3273
+ self.disable_licm = disable_licm
2957
3274
 
2958
3275
  def __iter__(self):
2959
3276
  raise RuntimeError("tl.range can only be used in @triton.jit'd functions")
@@ -2962,13 +3279,36 @@ class range:
2962
3279
  raise RuntimeError("tl.range can only be used in @triton.jit'd functions")
2963
3280
 
2964
3281
 
3282
+ class condition(base_value):
3283
+ """
3284
+ While loop condition wrapper.
3285
+
3286
+ .. highlight:: python
3287
+ .. code-block:: python
3288
+
3289
+ @triton.jit
3290
+ def kernel(...):
3291
+ while tl.condition(c, disable_licm)
3292
+ ...
3293
+ :note: This is a special wrapper used to annotate while loops in the context of
3294
+ :code:`triton.jit` functions. It allows user to pass extra attributes to the compiler.
3295
+ :param disable_licm: Tells the compiler it shouldn't hoist loop invariant
3296
+ code outside the loop. This is often useful to avoid creating long liveranges
3297
+ within a loop.
3298
+ """
3299
+
3300
+ def __init__(self, arg1, disable_licm=False):
3301
+ self.condition = arg1
3302
+ self.disable_licm = disable_licm
3303
+
3304
+
2965
3305
  # -----------------------
2966
3306
  # Extern functions
2967
3307
  # -----------------------
2968
3308
 
2969
3309
 
2970
- def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, ret_shape: tuple,
2971
- is_pure: bool, _builder=None):
3310
+ def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, ret_type: dtype, is_pure: bool,
3311
+ _semantic):
2972
3312
  '''
2973
3313
  Dispatch a function to a library
2974
3314
  :param func: the function to dispatch
@@ -2976,8 +3316,7 @@ def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dic
2976
3316
  :param lib_path: the path of the library
2977
3317
  :param args: the arguments of the function
2978
3318
  :param arg_type_symbol_dict: the type of the arguments
2979
- :param ret_shape: the shape of the return value
2980
- :param _builder: the builder
3319
+ :param ret_type: the type of the return value
2981
3320
  :return: the return value of the function
2982
3321
  '''
2983
3322
  if len(arg_type_symbol_dict) == 0:
@@ -3004,15 +3343,13 @@ def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dic
3004
3343
  f"Expect one of {arg_type_symbol_dict.keys()}, got {arg_types}")
3005
3344
  else:
3006
3345
  symbol = arg_type_symbol_dict[arg_types][0]
3007
- ret_type = arg_type_symbol_dict[arg_types][1]
3008
- if ret_shape:
3009
- ret_type = block_type(ret_type, ret_shape)
3010
- return tensor(func(lib_name, lib_path, symbol, arg_list, ret_type.to_ir(_builder), is_pure), ret_type)
3346
+ builder = _semantic.builder
3347
+ return tensor(func(lib_name, lib_path, symbol, arg_list, ret_type.to_ir(builder), is_pure), ret_type)
3011
3348
 
3012
3349
 
3013
3350
  @builtin
3014
3351
  def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, is_pure: bool,
3015
- _builder=None):
3352
+ _semantic=None):
3016
3353
  '''
3017
3354
  Dispatch an elementwise function to a library
3018
3355
  :param lib_name: the name of the library
@@ -3020,20 +3357,20 @@ def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol
3020
3357
  :param args: the arguments of the function
3021
3358
  :param arg_type_symbol_dict: the type of the arguments
3022
3359
  :param is_pure: whether the function is pure
3023
- :param _builder: the builder
3024
3360
  :return: the return value of the function
3025
3361
  '''
3026
3362
  dispatch_args = args.copy()
3027
3363
  all_scalar = True
3028
- ret_shape = None
3029
3364
  arg_types = []
3030
3365
  for i in builtins.range(len(dispatch_args)):
3031
- dispatch_args[i] = semantic.to_tensor(dispatch_args[i], _builder)
3366
+ dispatch_args[i] = _semantic.to_tensor(dispatch_args[i])
3032
3367
  arg_types.append(dispatch_args[i].dtype)
3033
3368
  if dispatch_args[i].type.is_block():
3034
3369
  all_scalar = False
3370
+
3371
+ arg_types = tuple(arg_types)
3372
+ ret_type = arg_type_symbol_dict[arg_types][1]
3035
3373
  if len(arg_types) > 0:
3036
- arg_types = tuple(arg_types)
3037
3374
  arithmetic_check = True
3038
3375
  # If there's a type tuple that is not supported by the library, we will do arithmetic check
3039
3376
  if arg_types in arg_type_symbol_dict:
@@ -3041,26 +3378,26 @@ def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol
3041
3378
  broadcast_arg = dispatch_args[0]
3042
3379
  # Get the broadcast shape over all the arguments
3043
3380
  for item in dispatch_args:
3044
- _, broadcast_arg = semantic.binary_op_type_checking_impl(item, broadcast_arg, _builder,
3045
- arithmetic_check=arithmetic_check)
3381
+ _, broadcast_arg = _semantic.binary_op_type_checking_impl(item, broadcast_arg,
3382
+ arithmetic_check=arithmetic_check)
3046
3383
  # Change the shape of each argument based on the broadcast shape
3047
3384
  for i in builtins.range(len(dispatch_args)):
3048
- dispatch_args[i], _ = semantic.binary_op_type_checking_impl(dispatch_args[i], broadcast_arg, _builder,
3049
- arithmetic_check=arithmetic_check)
3385
+ dispatch_args[i], _ = _semantic.binary_op_type_checking_impl(dispatch_args[i], broadcast_arg,
3386
+ arithmetic_check=arithmetic_check)
3050
3387
  if not all_scalar:
3051
- ret_shape = broadcast_arg.shape
3052
- func = _builder.create_extern_elementwise
3053
- return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_shape, is_pure, _builder)
3388
+ ret_type = broadcast_arg.type.with_element_ty(ret_type)
3389
+ func = _semantic.builder.create_extern_elementwise
3390
+ return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_type, is_pure, _semantic)
3054
3391
 
3055
3392
 
3056
- def binary_op_type_legalization(lhs, rhs, builder):
3393
+ def binary_op_type_legalization(lhs, rhs, semantic):
3057
3394
  '''
3058
3395
  Convert both operands to a single common type
3059
3396
  :param lhs: the left operand
3060
3397
  :param rhs: the right operand
3061
3398
  :param builder: the builder
3062
3399
  '''
3063
- return semantic.binary_op_type_checking_impl(lhs, rhs, builder)
3400
+ return semantic.binary_op_type_checking_impl(lhs, rhs)
3064
3401
 
3065
3402
 
3066
3403
  def extern(fn):