triton-windows 3.3.0.post19__cp311-cp311-win_amd64.whl → 3.4.0.post20__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic; see the package registry page for more details.

Files changed (173)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +4 -1
  3. triton/_filecheck.py +87 -0
  4. triton/_internal_testing.py +26 -15
  5. triton/_utils.py +110 -21
  6. triton/backends/__init__.py +20 -23
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +112 -78
  9. triton/backends/amd/driver.c +5 -2
  10. triton/backends/amd/driver.py +149 -47
  11. triton/backends/compiler.py +7 -21
  12. triton/backends/nvidia/bin/ptxas.exe +0 -0
  13. triton/backends/nvidia/compiler.py +92 -93
  14. triton/backends/nvidia/driver.c +90 -98
  15. triton/backends/nvidia/driver.py +303 -128
  16. triton/compiler/code_generator.py +212 -111
  17. triton/compiler/compiler.py +110 -25
  18. triton/experimental/__init__.py +0 -0
  19. triton/experimental/gluon/__init__.py +4 -0
  20. triton/experimental/gluon/_compiler.py +0 -0
  21. triton/experimental/gluon/_runtime.py +99 -0
  22. triton/experimental/gluon/language/__init__.py +18 -0
  23. triton/experimental/gluon/language/_core.py +312 -0
  24. triton/experimental/gluon/language/_layouts.py +230 -0
  25. triton/experimental/gluon/language/_math.py +12 -0
  26. triton/experimental/gluon/language/_semantic.py +287 -0
  27. triton/experimental/gluon/language/_standard.py +47 -0
  28. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  29. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +202 -0
  30. triton/experimental/gluon/language/nvidia/blackwell/tma.py +32 -0
  31. triton/experimental/gluon/language/nvidia/hopper/__init__.py +11 -0
  32. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +51 -0
  33. triton/experimental/gluon/language/nvidia/hopper/tma.py +96 -0
  34. triton/experimental/gluon/nvidia/__init__.py +4 -0
  35. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  36. triton/experimental/gluon/nvidia/hopper.py +40 -0
  37. triton/knobs.py +481 -0
  38. triton/language/__init__.py +39 -14
  39. triton/language/core.py +794 -537
  40. triton/language/extra/cuda/__init__.py +10 -7
  41. triton/language/extra/cuda/gdc.py +42 -0
  42. triton/language/extra/cuda/libdevice.py +394 -394
  43. triton/language/extra/cuda/utils.py +21 -21
  44. triton/language/extra/hip/libdevice.py +113 -104
  45. triton/language/math.py +65 -66
  46. triton/language/random.py +12 -2
  47. triton/language/semantic.py +1706 -1770
  48. triton/language/standard.py +116 -51
  49. triton/runtime/autotuner.py +117 -59
  50. triton/runtime/build.py +76 -12
  51. triton/runtime/cache.py +18 -47
  52. triton/runtime/driver.py +32 -29
  53. triton/runtime/interpreter.py +72 -35
  54. triton/runtime/jit.py +146 -110
  55. triton/runtime/tcc/lib/python310.def +1610 -0
  56. triton/runtime/tcc/lib/python311.def +1633 -0
  57. triton/runtime/tcc/lib/python312.def +1703 -0
  58. triton/runtime/tcc/lib/python313.def +1651 -0
  59. triton/runtime/tcc/lib/python313t.def +1656 -0
  60. triton/runtime/tcc/lib/python39.def +1644 -0
  61. triton/runtime/tcc/lib/python3t.def +905 -0
  62. triton/testing.py +16 -12
  63. triton/tools/disasm.py +3 -4
  64. triton/tools/tensor_descriptor.py +36 -0
  65. triton/windows_utils.py +14 -6
  66. {triton_windows-3.3.0.post19.dist-info → triton_windows-3.4.0.post20.dist-info}/METADATA +7 -2
  67. triton_windows-3.4.0.post20.dist-info/RECORD +186 -0
  68. {triton_windows-3.3.0.post19.dist-info → triton_windows-3.4.0.post20.dist-info}/WHEEL +1 -1
  69. triton_windows-3.4.0.post20.dist-info/entry_points.txt +3 -0
  70. triton_windows-3.4.0.post20.dist-info/licenses/LICENSE +23 -0
  71. triton_windows-3.4.0.post20.dist-info/top_level.txt +1 -0
  72. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
  73. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
  74. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
  75. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
  76. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
  77. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
  78. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
  79. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
  80. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
  81. triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
  82. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
  83. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
  84. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
  85. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
  86. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
  87. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
  88. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
  89. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
  90. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
  91. triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
  92. triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
  93. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
  94. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
  95. triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
  96. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
  97. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
  98. triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
  99. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
  100. triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
  101. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
  102. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
  103. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
  104. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
  105. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
  106. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
  107. triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
  108. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
  109. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
  110. triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
  111. triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
  112. triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
  113. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
  114. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
  115. triton/backends/amd/include/hip/channel_descriptor.h +0 -39
  116. triton/backends/amd/include/hip/device_functions.h +0 -38
  117. triton/backends/amd/include/hip/driver_types.h +0 -468
  118. triton/backends/amd/include/hip/hip_bf16.h +0 -36
  119. triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
  120. triton/backends/amd/include/hip/hip_common.h +0 -100
  121. triton/backends/amd/include/hip/hip_complex.h +0 -38
  122. triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
  123. triton/backends/amd/include/hip/hip_deprecated.h +0 -95
  124. triton/backends/amd/include/hip/hip_ext.h +0 -161
  125. triton/backends/amd/include/hip/hip_fp16.h +0 -36
  126. triton/backends/amd/include/hip/hip_fp8.h +0 -33
  127. triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
  128. triton/backends/amd/include/hip/hip_hcc.h +0 -24
  129. triton/backends/amd/include/hip/hip_math_constants.h +0 -36
  130. triton/backends/amd/include/hip/hip_profile.h +0 -27
  131. triton/backends/amd/include/hip/hip_runtime.h +0 -75
  132. triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
  133. triton/backends/amd/include/hip/hip_texture_types.h +0 -29
  134. triton/backends/amd/include/hip/hip_vector_types.h +0 -41
  135. triton/backends/amd/include/hip/hip_version.h +0 -17
  136. triton/backends/amd/include/hip/hiprtc.h +0 -421
  137. triton/backends/amd/include/hip/library_types.h +0 -78
  138. triton/backends/amd/include/hip/math_functions.h +0 -42
  139. triton/backends/amd/include/hip/surface_types.h +0 -63
  140. triton/backends/amd/include/hip/texture_types.h +0 -194
  141. triton/backends/amd/include/hsa/Brig.h +0 -1131
  142. triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
  143. triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
  144. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
  145. triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
  146. triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
  147. triton/backends/amd/include/hsa/hsa.h +0 -5738
  148. triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
  149. triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
  150. triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
  151. triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
  152. triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
  153. triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
  154. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
  155. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
  156. triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
  157. triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
  158. triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
  159. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
  160. triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
  161. triton/backends/amd/include/roctracer/roctracer.h +0 -779
  162. triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
  163. triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
  164. triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
  165. triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
  166. triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
  167. triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
  168. triton/backends/amd/include/roctracer/roctx.h +0 -229
  169. triton/language/_utils.py +0 -21
  170. triton/language/extra/cuda/_experimental_tma.py +0 -106
  171. triton/tools/experimental_descriptor.py +0 -32
  172. triton_windows-3.3.0.post19.dist-info/RECORD +0 -253
  173. triton_windows-3.3.0.post19.dist-info/top_level.txt +0 -14
triton/language/core.py CHANGED
@@ -6,14 +6,14 @@ from enum import Enum
6
6
  from functools import partial, wraps
7
7
  import typing
8
8
  from typing import Union, Callable, List, Sequence, TypeVar, Optional, Tuple
9
+ from dataclasses import dataclass
9
10
  import builtins
10
- from ..runtime.jit import jit
11
+ from .. import knobs
12
+ from ..runtime.jit import jit, JITFunction
11
13
  import inspect
12
- import os
13
14
 
14
15
  from .._C.libtriton import ir
15
- from . import semantic
16
- from ._utils import TRITON_MAX_TENSOR_NUMEL, validate_block_shape
16
+ from .._utils import TRITON_MAX_TENSOR_NUMEL, validate_block_shape, get_primitive_bitwidth
17
17
 
18
18
  T = TypeVar('T')
19
19
 
@@ -22,15 +22,23 @@ TRITON_BUILTIN = "__triton_builtin__"
22
22
  PropagateNan = ir.PROPAGATE_NAN
23
23
 
24
24
 
25
+ def must_use_result(x, s=True):
26
+ """If the result of this function is unused, throw an error."""
27
+ if isinstance(x, str):
28
+ return (lambda fn: must_use_result(fn, x))
29
+ x._must_use_result = s
30
+ return x
31
+
32
+
25
33
  def builtin(fn: T) -> T:
26
34
  """Mark a function as a builtin."""
27
35
  assert callable(fn)
28
36
 
29
37
  @wraps(fn)
30
38
  def wrapper(*args, **kwargs):
31
- if "_builder" not in kwargs or kwargs["_builder"] is None:
39
+ if "_semantic" not in kwargs or kwargs["_semantic"] is None:
32
40
  raise ValueError("Did you forget to add @triton.jit ? "
33
- "(`_builder` argument must be provided outside of JIT functions.)")
41
+ "(`_semantic` argument must be provided outside of JIT functions.)")
34
42
  return fn(*args, **kwargs)
35
43
 
36
44
  setattr(wrapper, TRITON_BUILTIN, True)
@@ -53,8 +61,8 @@ def _tensor_member_fn(fn: T) -> T:
53
61
  """
54
62
  assert callable(fn)
55
63
  orig_sig = inspect.signature(fn)
56
- # Does fn take args other than _builder, _generator, and the tensor itself?
57
- has_args = len(orig_sig.parameters.keys() - {"_builder", "_generator"}) > 1
64
+ # Does fn take args other than _semantic, _generator, and the tensor itself?
65
+ has_args = len(orig_sig.parameters.keys() - {"_semantic", "_generator"}) > 1
58
66
 
59
67
  if not fn.__doc__:
60
68
  fn.__doc__ = ""
@@ -78,7 +86,7 @@ def _tensor_member_fn(fn: T) -> T:
78
86
  if is_builtin(fn):
79
87
  setattr(wrapper, TRITON_BUILTIN, True)
80
88
 
81
- setattr(tensor, fn.__name__, wrapper)
89
+ setattr(tensor, fn.__name__, fn if isinstance(fn, JITFunction) else wrapper)
82
90
  return fn
83
91
 
84
92
 
@@ -110,8 +118,8 @@ def is_builtin(fn) -> bool:
110
118
 
111
119
 
112
120
  @builtin
113
- def to_tensor(x, _builder=None):
114
- return semantic.to_tensor(x, _builder)
121
+ def to_tensor(x, _semantic=None):
122
+ return _semantic.to_tensor(x)
115
123
 
116
124
 
117
125
  # -----------------------
@@ -130,7 +138,62 @@ class const:
130
138
  pass
131
139
 
132
140
 
133
- class constexpr:
141
+ class base_value:
142
+ """Base class of values that exist in the triton IR (i.e. not constexprs).
143
+ """
144
+ type: base_type
145
+
146
+ def _flatten_ir(self, handles: List[ir.value]) -> None:
147
+ """Flatten frontend value into a sequence of mlir handles, which are appended
148
+ to the output list
149
+ """
150
+ raise NotImplementedError
151
+
152
+
153
+ class base_type:
154
+
155
+ def __eq__(self, other):
156
+ raise NotImplementedError("Types must implement __eq__")
157
+
158
+ def __ne__(self, other):
159
+ return not (self == other)
160
+
161
+ def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[base_value, int]:
162
+ """Build a frontend value with the current dtype, wrapping a list of existing handles.
163
+ cursor is the index of the first handle relevant to this value, and the function
164
+ should return the updated cursor position after any handles consumed by the created value.
165
+ """
166
+ raise NotImplementedError
167
+
168
+ def mangle(self) -> str:
169
+ raise NotImplementedError(f"NYI: Type mangling for type {self.__class__}")
170
+
171
+ def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
172
+ raise NotImplementedError
173
+
174
+
175
+ class constexpr_type(base_type):
176
+
177
+ def __init__(self, value):
178
+ self.value = value
179
+
180
+ def __eq__(self, other):
181
+ return self.value == other.value
182
+
183
+ def __repr__(self) -> str:
184
+ return f"constexpr[{self.value}]"
185
+
186
+ def mangle(self) -> str:
187
+ return repr(self)
188
+
189
+ def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
190
+ return
191
+
192
+ def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[base_value, int]:
193
+ return constexpr(self.value), cursor
194
+
195
+
196
+ class constexpr(base_value):
134
197
  """
135
198
  This class is used to store a value that is known at compile-time.
136
199
  """
@@ -140,80 +203,83 @@ class constexpr:
140
203
  self.value = value.value
141
204
  else:
142
205
  self.value = value
143
- self.type = constexpr
206
+ self.type = constexpr_type(value)
144
207
 
145
208
  def __repr__(self) -> str:
146
209
  return f"constexpr[{self.value}]"
147
210
 
211
+ def _flatten_ir(self, handles: List[ir.value]) -> None:
212
+ return
213
+
148
214
  def __index__(self):
149
215
  return self.value
150
216
 
151
217
  # In interpreter mode, constant values are not wrapped in constexpr,
152
218
  # and therefore do not have a .value attribute.
153
- # As a result, from here and below, we need to call the _constexpr_to_value
219
+ # As a result, from here and below, we need to call the _unwrap_if_constexpr
154
220
  # function to obtain either constexpr.value or the value itself.
155
221
  def __add__(self, other):
156
- return constexpr(self.value + _constexpr_to_value(other))
222
+ return constexpr(self.value + _unwrap_if_constexpr(other))
157
223
 
158
224
  def __radd__(self, other):
159
- return constexpr(_constexpr_to_value(other) + self.value)
225
+ return constexpr(_unwrap_if_constexpr(other) + self.value)
160
226
 
161
227
  def __sub__(self, other):
162
- return constexpr(self.value - _constexpr_to_value(other))
228
+ return constexpr(self.value - _unwrap_if_constexpr(other))
163
229
 
164
230
  def __rsub__(self, other):
165
- return constexpr(_constexpr_to_value(other) - self.value)
231
+ return constexpr(_unwrap_if_constexpr(other) - self.value)
166
232
 
167
233
  def __mul__(self, other):
168
- return constexpr(self.value * _constexpr_to_value(other))
234
+ return constexpr(self.value * _unwrap_if_constexpr(other))
169
235
 
170
236
  def __mod__(self, other):
171
- return constexpr(self.value % _constexpr_to_value(other))
237
+ return constexpr(self.value % _unwrap_if_constexpr(other))
172
238
 
173
239
  def __rmul__(self, other):
174
- return constexpr(_constexpr_to_value(other) * self.value)
240
+ return constexpr(_unwrap_if_constexpr(other) * self.value)
175
241
 
176
242
  def __truediv__(self, other):
177
- return constexpr(self.value / _constexpr_to_value(other))
243
+ return constexpr(self.value / _unwrap_if_constexpr(other))
178
244
 
179
245
  def __rtruediv__(self, other):
180
- return constexpr(_constexpr_to_value(other) / self.value)
246
+ return constexpr(_unwrap_if_constexpr(other) / self.value)
181
247
 
182
248
  def __floordiv__(self, other):
183
- return constexpr(self.value // _constexpr_to_value(other))
249
+ return constexpr(self.value // _unwrap_if_constexpr(other))
184
250
 
185
251
  def __rfloordiv__(self, other):
186
- return constexpr(_constexpr_to_value(other) // self.value)
252
+ return constexpr(_unwrap_if_constexpr(other) // self.value)
187
253
 
188
254
  def __gt__(self, other):
189
- return constexpr(self.value > _constexpr_to_value(other))
255
+ return constexpr(self.value > _unwrap_if_constexpr(other))
190
256
 
191
257
  def __rgt__(self, other):
192
- return constexpr(_constexpr_to_value(other) > self.value)
258
+ return constexpr(_unwrap_if_constexpr(other) > self.value)
193
259
 
194
260
  def __ge__(self, other):
195
- return constexpr(self.value >= _constexpr_to_value(other))
261
+ return constexpr(self.value >= _unwrap_if_constexpr(other))
196
262
 
197
263
  def __rge__(self, other):
198
- return constexpr(_constexpr_to_value(other) >= self.value)
264
+ return constexpr(_unwrap_if_constexpr(other) >= self.value)
199
265
 
200
266
  def __lt__(self, other):
201
- return constexpr(self.value < _constexpr_to_value(other))
267
+ return constexpr(self.value < _unwrap_if_constexpr(other))
202
268
 
203
269
  def __rlt__(self, other):
204
- return constexpr(_constexpr_to_value(other) < self.value)
270
+ return constexpr(_unwrap_if_constexpr(other) < self.value)
205
271
 
206
272
  def __le__(self, other):
207
- return constexpr(self.value <= _constexpr_to_value(other))
273
+ return constexpr(self.value <= _unwrap_if_constexpr(other))
208
274
 
209
275
  def __rle__(self, other):
210
- return constexpr(_constexpr_to_value(other) <= self.value)
276
+ return constexpr(_unwrap_if_constexpr(other) <= self.value)
211
277
 
212
278
  def __eq__(self, other):
213
- return constexpr(self.value == _constexpr_to_value(other))
279
+ return constexpr(self.value == _unwrap_if_constexpr(other))
214
280
 
215
281
  def __ne__(self, other):
216
- return constexpr(self.value != _constexpr_to_value(other))
282
+ return constexpr(self.value != _unwrap_if_constexpr(other))
217
283
 
218
284
  def __bool__(self):
219
285
  return bool(self.value)
@@ -222,19 +288,19 @@ class constexpr:
222
288
  return constexpr(-self.value)
223
289
 
224
290
  def __and__(self, other):
225
- return constexpr(self.value & _constexpr_to_value(other))
291
+ return constexpr(self.value & _unwrap_if_constexpr(other))
226
292
 
227
293
  def logical_and(self, other):
228
- return constexpr(self.value and _constexpr_to_value(other))
294
+ return constexpr(self.value and _unwrap_if_constexpr(other))
229
295
 
230
296
  def __or__(self, other):
231
- return constexpr(self.value | _constexpr_to_value(other))
297
+ return constexpr(self.value | _unwrap_if_constexpr(other))
232
298
 
233
299
  def __xor__(self, other):
234
- return constexpr(self.value ^ _constexpr_to_value(other))
300
+ return constexpr(self.value ^ _unwrap_if_constexpr(other))
235
301
 
236
302
  def logical_or(self, other):
237
- return constexpr(self.value or _constexpr_to_value(other))
303
+ return constexpr(self.value or _unwrap_if_constexpr(other))
238
304
 
239
305
  def __pos__(self):
240
306
  return constexpr(+self.value)
@@ -243,16 +309,16 @@ class constexpr:
243
309
  return constexpr(~self.value)
244
310
 
245
311
  def __pow__(self, other):
246
- return constexpr(self.value**_constexpr_to_value(other))
312
+ return constexpr(self.value**_unwrap_if_constexpr(other))
247
313
 
248
314
  def __rpow__(self, other):
249
- return constexpr(_constexpr_to_value(other)**self.value)
315
+ return constexpr(_unwrap_if_constexpr(other)**self.value)
250
316
 
251
317
  def __rshift__(self, other):
252
- return constexpr(self.value >> _constexpr_to_value(other))
318
+ return constexpr(self.value >> _unwrap_if_constexpr(other))
253
319
 
254
320
  def __lshift__(self, other):
255
- return constexpr(self.value << _constexpr_to_value(other))
321
+ return constexpr(self.value << _unwrap_if_constexpr(other))
256
322
 
257
323
  def __not__(self):
258
324
  return constexpr(not self.value)
@@ -263,14 +329,57 @@ class constexpr:
263
329
  def __call__(self, *args, **kwds):
264
330
  return self.value(*args, **kwds)
265
331
 
332
+ def __getitem__(self, *args):
333
+ args = (_unwrap_if_constexpr(x) for x in _normalize_tuple(args))
334
+ return self.value.__getitem__(*args)
335
+
336
+
337
+ def constexpr_function(f):
338
+ """
339
+ Wraps an arbitrary Python function so that it can be called at
340
+ compile-time on constexpr arguments in a Triton function and
341
+ returns a constexpr result.
342
+ """
343
+
344
+ @wraps(f)
345
+ def wrapper(*args, _semantic=None, **kwargs):
346
+ # de-constexpr arguments and discard the _semantic keyword argument:
347
+ args = [_unwrap_if_constexpr(x) for x in args]
348
+ kwargs = {k: _unwrap_if_constexpr(v) for (k, v) in kwargs.items()}
349
+
350
+ # call the raw Python function f:
351
+ res = f(*args, **kwargs)
352
+
353
+ # convert result back to a Triton constexpr:
354
+ return constexpr(res)
355
+
356
+ # disguise the function as a Triton builtin to avoid raising an error
357
+ # that we're calling a non-JIT function from within a Triton kernel:
358
+ wrapper.__triton_builtin__ = True
359
+ wrapper.__module__ = constexpr_function.__module__
360
+ return wrapper
361
+
266
362
 
267
363
  CONSTEXPR_0 = constexpr(0)
268
364
 
269
365
 
270
366
  def _unwrap_if_constexpr(o):
367
+ if isinstance(o, list):
368
+ return [_unwrap_if_constexpr(x) for x in o]
369
+ if isinstance(o, builtins.tuple):
370
+ return builtins.tuple(_unwrap_if_constexpr(x) for x in o)
371
+ if isinstance(o, tuple):
372
+ return tuple(_unwrap_if_constexpr(x) for x in o)
271
373
  return o.value if isinstance(o, constexpr) else o
272
374
 
273
375
 
376
+ def _normalize_tuple(t):
377
+ normalized_tuple = _unwrap_if_constexpr(t)
378
+ if isinstance(normalized_tuple, (list, builtins.tuple)):
379
+ normalized_tuple = tuple(normalized_tuple)
380
+ return normalized_tuple
381
+
382
+
274
383
  def check_bit_width(value, shift_value):
275
384
  if isinstance(value, tensor) and isinstance(shift_value, constexpr):
276
385
  bitwidth = value.type.scalar.primitive_bitwidth
@@ -280,34 +389,6 @@ def check_bit_width(value, shift_value):
280
389
  )
281
390
 
282
391
 
283
- class base_value:
284
- """Base class of values that exist in the triton IR (i.e. not constexprs).
285
- """
286
- type: base_type
287
-
288
- def _flatten_ir(self, handles: List[ir.value]) -> None:
289
- """Flatten frontend value into a sequence of mlir handles, which are appended
290
- to the output list
291
- """
292
- raise NotImplementedError
293
-
294
-
295
- class base_type:
296
-
297
- def __eq__(self, other):
298
- raise NotImplementedError("Types must implement __eq__")
299
-
300
- def __ne__(self, other):
301
- return not (self == other)
302
-
303
- def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[base_value, int]:
304
- """Build a frontend value with the current dtype, wrapping a list of existing handles.
305
- cursor is the index of the first handle relevant to this value, and the function
306
- should return the updated cursor position after any handles consumed by the created value.
307
- """
308
- raise NotImplementedError
309
-
310
-
311
392
  # -----------------------
312
393
  # dtype
313
394
  # -----------------------
@@ -333,55 +414,44 @@ class dtype(base_type):
333
414
  name = _unwrap_if_constexpr(name)
334
415
  self.name = name
335
416
  assert name in dtype.SINT_TYPES + dtype.UINT_TYPES + dtype.FP_TYPES + dtype.OTHER_TYPES, name
417
+ self.primitive_bitwidth = get_primitive_bitwidth(name)
418
+ self.itemsize = self.primitive_bitwidth // 8
336
419
  if name in dtype.SINT_TYPES:
337
420
  self.int_signedness = dtype.SIGNEDNESS.SIGNED
338
- self.int_bitwidth = int(name.split('int')[-1])
339
- self.primitive_bitwidth = self.int_bitwidth
421
+ self.int_bitwidth = self.primitive_bitwidth
340
422
  elif name in dtype.UINT_TYPES:
341
423
  self.int_signedness = dtype.SIGNEDNESS.UNSIGNED
342
- self.int_bitwidth = int(name.split('int')[-1])
343
- self.primitive_bitwidth = self.int_bitwidth
424
+ self.int_bitwidth = self.primitive_bitwidth
344
425
  elif name in dtype.FP_TYPES:
345
426
  if name == 'fp8e4b15':
346
427
  self.fp_mantissa_width = 3
347
- self.primitive_bitwidth = 8
348
428
  self.exponent_bias = 15
349
429
  elif name == 'fp8e4nv':
350
430
  self.fp_mantissa_width = 3
351
- self.primitive_bitwidth = 8
352
431
  self.exponent_bias = 7
353
432
  elif name == 'fp8e4b8':
354
433
  self.fp_mantissa_width = 3
355
- self.primitive_bitwidth = 8
356
434
  self.exponent_bias = 8
357
435
  elif name == 'fp8e5':
358
436
  self.fp_mantissa_width = 2
359
- self.primitive_bitwidth = 8
360
437
  self.exponent_bias = 15
361
438
  elif name == 'fp8e5b16':
362
439
  self.fp_mantissa_width = 2
363
- self.primitive_bitwidth = 8
364
440
  self.exponent_bias = 16
365
441
  elif name == 'fp16':
366
442
  self.fp_mantissa_width = 10
367
- self.primitive_bitwidth = 16
368
443
  self.exponent_bias = 15
369
444
  elif name == 'bf16':
370
445
  self.fp_mantissa_width = 7
371
- self.primitive_bitwidth = 16
372
446
  self.exponent_bias = 127
373
447
  elif name == 'fp32':
374
448
  self.fp_mantissa_width = 23
375
- self.primitive_bitwidth = 32
376
449
  self.exponent_bias = 127
377
450
  elif name == 'fp64':
378
451
  self.fp_mantissa_width = 52
379
- self.primitive_bitwidth = 64
380
452
  self.exponent_bias = 1023
381
453
  else:
382
454
  raise RuntimeError(f'Unsupported floating-point type {name}')
383
- elif name == 'void':
384
- self.primitive_bitwidth = 0
385
455
 
386
456
  def is_fp8(self):
387
457
  return 'fp8' in self.name
@@ -502,10 +572,6 @@ class dtype(base_type):
502
572
  def is_const():
503
573
  return False
504
574
 
505
- @staticmethod
506
- def is_tuple():
507
- return False
508
-
509
575
  def __eq__(self, other: dtype):
510
576
  if not isinstance(other, dtype):
511
577
  return False
@@ -518,13 +584,14 @@ class dtype(base_type):
518
584
  def scalar(self):
519
585
  return self
520
586
 
587
+ def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
588
+ out.append(self.to_ir(builder))
589
+
521
590
  def to_ir(self, builder: ir.builder) -> ir.type:
522
591
  if self.name.startswith("fp8"):
523
592
  if self.name not in builder.options.supported_fp8_dtypes:
524
593
  raise ValueError(f'type {self} not supported in this architecture. '
525
594
  f'The supported fp8 dtypes are {builder.options.supported_fp8_dtypes}')
526
- if self.name in builder.options.deprecated_fp8_dtypes:
527
- warn(f"{self.name} is deprecated in this architecture and will be removed in a future triton release")
528
595
 
529
596
  if self.name == 'void':
530
597
  return builder.get_void_ty()
@@ -581,6 +648,21 @@ class dtype(base_type):
581
648
  def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[base_value, int]:
582
649
  return tensor(handles[cursor], self), cursor + 1
583
650
 
651
+ def mangle(self) -> str:
652
+ if self.is_int():
653
+ SIGNED = dtype.SIGNEDNESS.SIGNED
654
+ prefix = 'i' if self.int_signedness == SIGNED else 'u'
655
+ return prefix + str(self.int_bitwidth)
656
+ if self.is_floating():
657
+ return str(self)
658
+ if self.is_void():
659
+ return 'V'
660
+ return super().mangle()
661
+
662
+ def with_element_ty(self, element_ty: dtype):
663
+ assert not self.is_block()
664
+ return element_ty
665
+
584
666
 
585
667
  # Some functions have a param named `dtype`, which shadows the `dtype` class.
586
668
  # We can't change the param name because it is part of function's public API.
@@ -623,12 +705,8 @@ class pointer_type(dtype):
623
705
  def scalar(self):
624
706
  return self
625
707
 
626
-
627
- class nv_tma_desc_type(pointer_type):
628
-
629
- def __init__(self, const=True, address_space=0):
630
- super().__init__(uint8, const=const, address_space=address_space)
631
- self.name = 'nv_tma_desc_type'
708
+ def mangle(self) -> str:
709
+ return f"P{self.element_ty.mangle()}"
632
710
 
633
711
 
634
712
  class block_type(dtype):
@@ -660,9 +738,12 @@ class block_type(dtype):
660
738
  def is_block(self):
661
739
  return True
662
740
 
663
- def get_block_shapes(self) -> List[int]:
741
+ def get_block_shapes(self) -> Tuple[int]:
664
742
  return self.shape
665
743
 
744
+ def with_element_ty(self, scalar_ty: dtype) -> block_type:
745
+ return block_type(scalar_ty, self.shape)
746
+
666
747
  def __eq__(self, other) -> bool:
667
748
  if not isinstance(other, block_type):
668
749
  return False
@@ -672,6 +753,11 @@ class block_type(dtype):
672
753
  def scalar(self):
673
754
  return self.element_ty
674
755
 
756
+ def mangle(self) -> str:
757
+ elt = self.scalar.mangle()
758
+ shape = '_'.join(map(str, self.shape))
759
+ return f'{elt}S{shape}S'
760
+
675
761
 
676
762
  class tuple_type(base_type):
677
763
 
@@ -686,15 +772,14 @@ class tuple_type(base_type):
686
772
  def __iter__(self):
687
773
  return iter(self.types)
688
774
 
689
- def to_ir(self, builder: ir.builder):
690
- return [ty.to_ir(builder) for ty in self.types]
775
+ def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]):
776
+ for ty in self.types:
777
+ if not isinstance(ty, constexpr):
778
+ ty._flatten_ir_types(builder, out)
691
779
 
692
780
  def __getitem__(self, index: int) -> dtype:
693
781
  return self.types[index]
694
782
 
695
- def is_tuple(self):
696
- return True
697
-
698
783
  def __eq__(self, other):
699
784
  return type(self) is type(other) and self.types == other.types and self.fields == other.fields
700
785
 
@@ -705,6 +790,9 @@ class tuple_type(base_type):
705
790
  values.append(value)
706
791
  return tuple(values, self), cursor
707
792
 
793
+ def mangle(self):
794
+ return 'T' + '_'.join(ty.mangle for ty in self.types) + 'T'
795
+
708
796
 
709
797
  class slice_type(dtype):
710
798
 
@@ -808,224 +896,224 @@ class tensor(base_value):
808
896
  return str(self.dtype) + '[' + ', '.join(str(s) for s in self.shape) + ']'
809
897
 
810
898
  @builtin
811
- def __add__(self, other, _builder=None):
812
- return add(self, other, sanitize_overflow=True, _builder=_builder)
899
+ def __add__(self, other, _semantic=None):
900
+ return add(self, other, sanitize_overflow=True, _semantic=_semantic)
813
901
 
814
902
  @builtin
815
- def __radd__(self, other, _builder=None):
816
- return add(other, self, sanitize_overflow=True, _builder=_builder)
903
+ def __radd__(self, other, _semantic=None):
904
+ return add(other, self, sanitize_overflow=True, _semantic=_semantic)
817
905
 
818
906
  @builtin
819
- def __sub__(self, other, _builder=None):
820
- return sub(self, other, sanitize_overflow=True, _builder=_builder)
907
+ def __sub__(self, other, _semantic=None):
908
+ return sub(self, other, sanitize_overflow=True, _semantic=_semantic)
821
909
 
822
910
  @builtin
823
- def __rsub__(self, other, _builder=None):
824
- return sub(other, self, sanitize_overflow=True, _builder=_builder)
911
+ def __rsub__(self, other, _semantic=None):
912
+ return sub(other, self, sanitize_overflow=True, _semantic=_semantic)
825
913
 
826
914
  @builtin
827
- def __mul__(self, other, _builder=None):
828
- return mul(self, other, sanitize_overflow=True, _builder=_builder)
915
+ def __mul__(self, other, _semantic=None):
916
+ return mul(self, other, sanitize_overflow=True, _semantic=_semantic)
829
917
 
830
918
  @builtin
831
- def __rmul__(self, other, _builder=None):
832
- return mul(other, self, sanitize_overflow=True, _builder=_builder)
919
+ def __rmul__(self, other, _semantic=None):
920
+ return mul(other, self, sanitize_overflow=True, _semantic=_semantic)
833
921
 
834
922
  @builtin
835
- def __truediv__(self, other, _builder=None):
923
+ def __truediv__(self, other, _semantic=None):
836
924
  other = _unwrap_if_constexpr(other)
837
- return semantic.truediv(self, other, _builder)
925
+ return _semantic.truediv(self, other)
838
926
 
839
927
  @builtin
840
- def __rtruediv__(self, other, _builder=None):
928
+ def __rtruediv__(self, other, _semantic=None):
841
929
  other = _unwrap_if_constexpr(other)
842
- return semantic.truediv(other, self, _builder)
930
+ return _semantic.truediv(other, self)
843
931
 
844
932
  @builtin
845
- def __floordiv__(self, other, _builder=None):
933
+ def __floordiv__(self, other, _semantic=None):
846
934
  other = _unwrap_if_constexpr(other)
847
- return semantic.floordiv(self, other, _builder)
935
+ return _semantic.floordiv(self, other)
848
936
 
849
937
  @builtin
850
- def __rfloordiv__(self, other, _builder=None):
938
+ def __rfloordiv__(self, other, _semantic=None):
851
939
  other = _unwrap_if_constexpr(other)
852
- return semantic.floordiv(other, self, _builder)
940
+ return _semantic.floordiv(other, self)
853
941
 
854
942
  @builtin
855
- def __mod__(self, other, _builder=None):
943
+ def __mod__(self, other, _semantic=None):
856
944
  other = _unwrap_if_constexpr(other)
857
- return semantic.mod(self, other, _builder)
945
+ return _semantic.mod(self, other)
858
946
 
859
947
  @builtin
860
- def __rmod__(self, other, _builder=None):
948
+ def __rmod__(self, other, _semantic=None):
861
949
  other = _unwrap_if_constexpr(other)
862
- return semantic.mod(other, self, _builder)
950
+ return _semantic.mod(other, self)
863
951
 
864
952
  # unary operators
865
953
  @builtin
866
- def __neg__(self, _builder=None):
867
- return semantic.minus(self, _builder)
954
+ def __neg__(self, _semantic=None):
955
+ return _semantic.minus(self)
868
956
 
869
957
  @builtin
870
- def __invert__(self, _builder=None):
871
- return semantic.invert(self, _builder)
958
+ def __invert__(self, _semantic=None):
959
+ return _semantic.invert(self)
872
960
 
873
961
  # bitwise operators
874
962
 
875
963
  @builtin
876
- def __and__(self, other, _builder=None):
964
+ def __and__(self, other, _semantic=None):
877
965
  other = _unwrap_if_constexpr(other)
878
- return semantic.and_(self, other, _builder)
966
+ return _semantic.and_(self, other)
879
967
 
880
968
  @builtin
881
- def __rand__(self, other, _builder=None):
969
+ def __rand__(self, other, _semantic=None):
882
970
  other = _unwrap_if_constexpr(other)
883
- return semantic.and_(other, self, _builder)
971
+ return _semantic.and_(other, self)
884
972
 
885
973
  @builtin
886
- def __or__(self, other, _builder=None):
974
+ def __or__(self, other, _semantic=None):
887
975
  other = _unwrap_if_constexpr(other)
888
- return semantic.or_(self, other, _builder)
976
+ return _semantic.or_(self, other)
889
977
 
890
978
  @builtin
891
- def __ror__(self, other, _builder=None):
979
+ def __ror__(self, other, _semantic=None):
892
980
  other = _unwrap_if_constexpr(other)
893
- return semantic.or_(other, self, _builder)
981
+ return _semantic.or_(other, self)
894
982
 
895
983
  @builtin
896
- def __xor__(self, other, _builder=None):
984
+ def __xor__(self, other, _semantic=None):
897
985
  other = _unwrap_if_constexpr(other)
898
- return semantic.xor_(self, other, _builder)
986
+ return _semantic.xor_(self, other)
899
987
 
900
988
  @builtin
901
- def __rxor__(self, other, _builder=None):
989
+ def __rxor__(self, other, _semantic=None):
902
990
  other = _unwrap_if_constexpr(other)
903
- return semantic.xor_(other, self, _builder)
991
+ return _semantic.xor_(other, self)
904
992
 
905
993
  @builtin
906
- def __lshift__(self, other, _builder=None):
994
+ def __lshift__(self, other, _semantic=None):
907
995
  check_bit_width(self, other)
908
996
  other = _unwrap_if_constexpr(other)
909
- return semantic.shl(self, other, _builder)
997
+ return _semantic.shl(self, other)
910
998
 
911
999
  @builtin
912
- def __rlshift__(self, other, _builder=None):
1000
+ def __rlshift__(self, other, _semantic=None):
913
1001
  check_bit_width(other, self)
914
1002
  other = _unwrap_if_constexpr(other)
915
- return semantic.shl(other, self, _builder)
1003
+ return _semantic.shl(other, self)
916
1004
 
917
1005
  @builtin
918
- def __rshift__(self, other, _builder=None):
1006
+ def __rshift__(self, other, _semantic=None):
919
1007
  check_bit_width(self, other)
920
1008
  other = _unwrap_if_constexpr(other)
921
1009
  if self.dtype.is_int_signed():
922
- return semantic.ashr(self, other, _builder)
1010
+ return _semantic.ashr(self, other)
923
1011
  else:
924
- return semantic.lshr(self, other, _builder)
1012
+ return _semantic.lshr(self, other)
925
1013
 
926
1014
  @builtin
927
- def __rrshift__(self, other, _builder=None):
1015
+ def __rrshift__(self, other, _semantic=None):
928
1016
  check_bit_width(other, self)
929
1017
  other = _unwrap_if_constexpr(other)
930
1018
  if self.dtype.is_int_signed():
931
- return semantic.ashr(other, self, _builder)
1019
+ return _semantic.ashr(other, self)
932
1020
  else:
933
- return semantic.lshr(other, self, _builder)
1021
+ return _semantic.lshr(other, self)
934
1022
 
935
1023
  # >
936
1024
  @builtin
937
- def __gt__(self, other, _builder=None):
938
- other = semantic.to_tensor(other, _builder)
939
- return semantic.greater_than(self, other, _builder)
1025
+ def __gt__(self, other, _semantic=None):
1026
+ other = _semantic.to_tensor(other)
1027
+ return _semantic.greater_than(self, other)
940
1028
 
941
1029
  @builtin
942
- def __rgt__(self, other, _builder=None):
943
- other = semantic.to_tensor(other, _builder)
944
- return semantic.greater_than(other, self, _builder)
1030
+ def __rgt__(self, other, _semantic=None):
1031
+ other = _semantic.to_tensor(other)
1032
+ return _semantic.greater_than(other, self)
945
1033
 
946
1034
  # >=
947
1035
  @builtin
948
- def __ge__(self, other, _builder=None):
949
- other = semantic.to_tensor(other, _builder)
950
- return semantic.greater_equal(self, other, _builder)
1036
+ def __ge__(self, other, _semantic=None):
1037
+ other = _semantic.to_tensor(other)
1038
+ return _semantic.greater_equal(self, other)
951
1039
 
952
1040
  @builtin
953
- def __rge__(self, other, _builder=None):
954
- other = semantic.to_tensor(other, _builder)
955
- return semantic.greater_equal(other, self, _builder)
1041
+ def __rge__(self, other, _semantic=None):
1042
+ other = _semantic.to_tensor(other)
1043
+ return _semantic.greater_equal(other, self)
956
1044
 
957
1045
  # <
958
1046
  @builtin
959
- def __lt__(self, other, _builder=None):
960
- other = semantic.to_tensor(other, _builder)
961
- return semantic.less_than(self, other, _builder)
1047
+ def __lt__(self, other, _semantic=None):
1048
+ other = _semantic.to_tensor(other)
1049
+ return _semantic.less_than(self, other)
962
1050
 
963
1051
  @builtin
964
- def __rlt__(self, other, _builder=None):
965
- other = semantic.to_tensor(other, _builder)
966
- return semantic.less_than(other, self, _builder)
1052
+ def __rlt__(self, other, _semantic=None):
1053
+ other = _semantic.to_tensor(other)
1054
+ return _semantic.less_than(other, self)
967
1055
 
968
1056
  # <=
969
1057
  @builtin
970
- def __le__(self, other, _builder=None):
971
- other = semantic.to_tensor(other, _builder)
972
- return semantic.less_equal(self, other, _builder)
1058
+ def __le__(self, other, _semantic=None):
1059
+ other = _semantic.to_tensor(other)
1060
+ return _semantic.less_equal(self, other)
973
1061
 
974
1062
  @builtin
975
- def __rle__(self, other, _builder=None):
976
- other = semantic.to_tensor(other, _builder)
977
- return semantic.less_equal(other, self, _builder)
1063
+ def __rle__(self, other, _semantic=None):
1064
+ other = _semantic.to_tensor(other)
1065
+ return _semantic.less_equal(other, self)
978
1066
 
979
1067
  # ==
980
1068
  @builtin
981
- def __eq__(self, other, _builder=None):
982
- other = semantic.to_tensor(other, _builder)
983
- return semantic.equal(self, other, _builder)
1069
+ def __eq__(self, other, _semantic=None):
1070
+ other = _semantic.to_tensor(other)
1071
+ return _semantic.equal(self, other)
984
1072
 
985
1073
  @builtin
986
- def __req__(self, other, _builder=None):
987
- other = semantic.to_tensor(other, _builder)
988
- return semantic.equal(other, self, _builder)
1074
+ def __req__(self, other, _semantic=None):
1075
+ other = _semantic.to_tensor(other)
1076
+ return _semantic.equal(other, self)
989
1077
 
990
1078
  @builtin
991
- def __ne__(self, other, _builder=None):
992
- other = semantic.to_tensor(other, _builder)
993
- return semantic.not_equal(self, other, _builder)
1079
+ def __ne__(self, other, _semantic=None):
1080
+ other = _semantic.to_tensor(other)
1081
+ return _semantic.not_equal(self, other)
994
1082
 
995
1083
  @builtin
996
- def __rne__(self, other, _builder=None):
997
- other = semantic.to_tensor(other, _builder)
998
- return semantic.not_equal(other, self, _builder)
1084
+ def __rne__(self, other, _semantic=None):
1085
+ other = _semantic.to_tensor(other)
1086
+ return _semantic.not_equal(other, self)
999
1087
 
1000
1088
  @builtin
1001
- def logical_and(self, other, _builder=None):
1002
- other = semantic.to_tensor(other, _builder)
1003
- return semantic.logical_and(self, other, _builder)
1089
+ def logical_and(self, other, _semantic=None):
1090
+ other = _semantic.to_tensor(other)
1091
+ return _semantic.logical_and(self, other)
1004
1092
 
1005
1093
  @builtin
1006
- def logical_or(self, other, _builder=None):
1007
- other = semantic.to_tensor(other, _builder)
1008
- return semantic.logical_or(self, other, _builder)
1094
+ def logical_or(self, other, _semantic=None):
1095
+ other = _semantic.to_tensor(other)
1096
+ return _semantic.logical_or(self, other)
1009
1097
 
1010
1098
  # note: __not__ isn't actually a magic method in python
1011
1099
  # but it's ok because our ASTVisitor handles it
1012
1100
  @builtin
1013
- def __not__(self, _builder=None):
1014
- return semantic.not_(self, _builder)
1101
+ def __not__(self, _semantic=None):
1102
+ return _semantic.not_(self)
1015
1103
 
1016
1104
  @builtin
1017
- def __getitem__(self, slices, _builder=None):
1018
- import builtins
1105
+ def __getitem__(self, slices, _semantic=None):
1019
1106
  if isinstance(slices, (builtins.slice, slice, constexpr)) or slices is None:
1020
1107
  slices = [slices]
1021
1108
  if isinstance(slices, tuple):
1022
1109
  slices = slices.values
1023
1110
  ret = self
1024
1111
  for dim, sl in enumerate(slices):
1025
- if sl is None or isinstance(sl, constexpr) and sl.value is None:
1026
- ret = semantic.expand_dims(ret, dim, _builder)
1027
- elif isinstance(sl, (builtins.slice, slice)) and sl.start is None and sl.stop is None and sl.step is None:
1028
- pass
1112
+ if _unwrap_if_constexpr(sl) is None:
1113
+ ret = _semantic.expand_dims(ret, dim)
1114
+ elif isinstance(sl, (builtins.slice, slice)) and all(
1115
+ _unwrap_if_constexpr(arg) is None for arg in (sl.start, sl.stop, sl.step)):
1116
+ pass # an unsqueeze
1029
1117
  else:
1030
1118
  raise ValueError(f"unsupported tensor index: {sl}")
1031
1119
  return ret
@@ -1036,11 +1124,11 @@ class tensor(base_value):
1036
1124
  assert False, "Transposition must be created by the AST Visitor"
1037
1125
 
1038
1126
  @builtin
1039
- def to(self, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: bool = False, _builder=None):
1127
+ def to(self, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: bool = False, _semantic=None):
1040
1128
  """
1041
1129
  Alias for :py:func:`tensor.cast`.
1042
1130
  """
1043
- return cast(self, dtype, fp_downcast_rounding, bitcast, _builder=_builder)
1131
+ return cast(self, dtype, fp_downcast_rounding, bitcast, _semantic=_semantic)
1044
1132
 
1045
1133
  # Type stubs for functions added by the _tensor_member_fn decorator.
1046
1134
  # (Unfortunately these can't be created automatically.)
@@ -1140,7 +1228,7 @@ class tensor(base_value):
1140
1228
  def sigmoid(self) -> tensor:
1141
1229
  ...
1142
1230
 
1143
- def softmax(self, ieee_rounding=False) -> tensor:
1231
+ def softmax(self, dim=None, keep_dims=False, ieee_rounding=False) -> tensor:
1144
1232
  ...
1145
1233
 
1146
1234
  def ravel(self) -> tensor:
@@ -1164,6 +1252,9 @@ class tensor(base_value):
1164
1252
  def xor_sum(self, axis=None, keep_dims=False) -> tensor:
1165
1253
  ...
1166
1254
 
1255
+ def reduce_or(self, axis=None, keep_dims=False) -> tensor:
1256
+ ...
1257
+
1167
1258
  def cumsum(self, axis=0, reverse=False) -> tensor:
1168
1259
  ...
1169
1260
 
@@ -1179,13 +1270,13 @@ class tensor(base_value):
1179
1270
 
1180
1271
  class tuple(base_value):
1181
1272
 
1182
- def __init__(self, args: list, type: tuple_type = None):
1273
+ def __init__(self, args: Sequence, type: tuple_type = None):
1183
1274
  self.values = [i for i in args]
1184
1275
 
1185
1276
  def get_type(x):
1186
1277
  if isinstance(x, dtype):
1187
1278
  return dtype
1188
- if isinstance(x, int):
1279
+ if isinstance(x, (int, float)):
1189
1280
  return constexpr
1190
1281
  return x.type
1191
1282
 
@@ -1197,7 +1288,6 @@ class tuple(base_value):
1197
1288
  if isinstance(idx, constexpr):
1198
1289
  return self.values[idx]
1199
1290
  else:
1200
- import builtins
1201
1291
  assert isinstance(idx, (slice, builtins.slice))
1202
1292
  return tuple(self.values[idx.start:idx.stop:idx.step])
1203
1293
 
@@ -1212,8 +1302,7 @@ class tuple(base_value):
1212
1302
  self.values[idx] = value
1213
1303
 
1214
1304
  def __add__(self, other):
1215
- if isinstance(other, list):
1216
- other = tuple(other)
1305
+ other = _normalize_tuple(other)
1217
1306
  return tuple(self.values + other.values)
1218
1307
  # return tuple(a + b for a, b in zip(self.values, other.values))
1219
1308
 
@@ -1222,13 +1311,10 @@ class tuple(base_value):
1222
1311
  return tuple(self.values * other.value)
1223
1312
 
1224
1313
  def __eq__(self, other):
1225
- import builtins
1226
- if isinstance(other, (list, builtins.tuple)):
1227
- other = tuple(other)
1314
+ other = _normalize_tuple(other)
1228
1315
  return constexpr(self.values == other.values)
1229
1316
 
1230
1317
  def __hash__(self):
1231
- import builtins
1232
1318
  return hash(builtins.tuple(self.values))
1233
1319
 
1234
1320
  def __str__(self):
@@ -1244,6 +1330,9 @@ class tuple(base_value):
1244
1330
  for v in self.values:
1245
1331
  v._flatten_ir(handles)
1246
1332
 
1333
+ def __repr__(self):
1334
+ return f"({' ,'.join(repr(x) for x in self.values)})"
1335
+
1247
1336
 
1248
1337
  class slice:
1249
1338
 
@@ -1259,12 +1348,13 @@ class tensor_descriptor_base_type(base_type):
1259
1348
  def __init__(self, block_type: block_type):
1260
1349
  self.block_type = block_type
1261
1350
 
1262
- def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[_experimental_tensor_descriptor_base, int]:
1263
- value = _experimental_tensor_descriptor_base(handles[cursor], self.block_type)
1351
+ def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tensor_descriptor_base, int]:
1352
+ value = tensor_descriptor_base(handles[cursor], self.block_type)
1264
1353
  return value, cursor + 1
1265
1354
 
1266
- def to_ir(self, builder: ir.builder):
1267
- return builder.create_tensor_descriptor_type(self.block_type.to_ir(builder))
1355
+ def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
1356
+ is_signed = self.block_type.element_ty.is_int_signed()
1357
+ out.append(builder.create_tensor_descriptor_type(self.block_type.to_ir(builder), is_signed))
1268
1358
 
1269
1359
  def __str__(self) -> str:
1270
1360
  # ex. "tensor_descriptor<float32[16, 32]>"
@@ -1278,8 +1368,11 @@ class tensor_descriptor_base_type(base_type):
1278
1368
  def __neq__(self, other) -> bool:
1279
1369
  return not (self == other)
1280
1370
 
1371
+ def mangle(self) -> str:
1372
+ return f"TD{self.block_type.mangle()}"
1373
+
1281
1374
 
1282
- class _experimental_tensor_descriptor_base(base_value):
1375
+ class tensor_descriptor_base(base_value):
1283
1376
  """"
1284
1377
  A tensor descriptor with unknown shape and strides
1285
1378
  """
@@ -1310,40 +1403,64 @@ class _experimental_tensor_descriptor_base(base_value):
1310
1403
  return str(self.type)
1311
1404
 
1312
1405
  @builtin
1313
- def load(self, offsets: Sequence[constexpr | tensor], _builder=None) -> tensor:
1406
+ def load(self, offsets: Sequence[constexpr | tensor], _semantic=None) -> tensor:
1314
1407
  """Load a block from the descriptor starting at the given element offsets.
1315
1408
 
1316
1409
  Values outside of the tensor bounds will be filled with zeros.
1317
1410
 
1318
1411
  :note: Offset must be a multiple of 16-bytes
1319
1412
  """
1320
- return semantic.descriptor_load(self, offsets, "", "", _builder)
1413
+ return _semantic.descriptor_load(self, offsets, "", "")
1321
1414
 
1322
1415
  @builtin
1323
- def store(self, offsets: Sequence[constexpr | tensor], value: tensor, _builder=None) -> tensor:
1416
+ def store(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor:
1324
1417
  """Store a block from the descriptor starting at the given element offsets.
1325
1418
 
1326
1419
  Values outside of the tensor bounds will be ignored.
1327
1420
 
1328
1421
  :note: Offset must be a multiple of 16-bytes
1329
1422
  """
1330
- return semantic.descriptor_store(self, value, offsets, _builder)
1423
+ return _semantic.descriptor_store(self, value, offsets)
1424
+
1425
+ @builtin
1426
+ def atomic_add(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor:
1427
+ return _semantic.descriptor_atomic_add(self, value, offsets)
1428
+
1429
+ @builtin
1430
+ def atomic_min(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor:
1431
+ return _semantic.descriptor_atomic_min(self, value, offsets)
1331
1432
 
1332
1433
  @builtin
1333
- def gather(self, *args, _builder=None) -> tensor:
1434
+ def atomic_max(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor:
1435
+ return _semantic.descriptor_atomic_max(self, value, offsets)
1436
+
1437
+ @builtin
1438
+ def atomic_and(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor:
1439
+ return _semantic.descriptor_atomic_and(self, value, offsets)
1440
+
1441
+ @builtin
1442
+ def atomic_or(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor:
1443
+ return _semantic.descriptor_atomic_or(self, value, offsets)
1444
+
1445
+ @builtin
1446
+ def atomic_xor(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor:
1447
+ return _semantic.descriptor_atomic_xor(self, value, offsets)
1448
+
1449
+ @builtin
1450
+ def gather(self, *args, _semantic=None) -> tensor:
1334
1451
  """Gather multiple descriptors worth of data"""
1335
1452
  assert len(args) == 2, f"descriptor gather only supports 2D indexing, but got {len(args)}"
1336
1453
  x_offsets = args[0]
1337
1454
  y_offset = args[1]
1338
- return semantic.descriptor_gather(self, x_offsets, y_offset, "", "", _builder)
1455
+ return _semantic.descriptor_gather(self, x_offsets, y_offset, "", "")
1339
1456
 
1340
1457
  @builtin
1341
- def scatter(self, value, *args, _builder=None) -> tensor:
1458
+ def scatter(self, value, *args, _semantic=None) -> tensor:
1342
1459
  """Scatter multiple descriptors worth of data"""
1343
1460
  assert len(args) == 2, f"descriptor scatter only supports 2D indexing, but got {len(args)}"
1344
1461
  x_offsets = args[0]
1345
1462
  y_offset = args[1]
1346
- return semantic.descriptor_scatter(self, value, x_offsets, y_offset, _builder)
1463
+ return _semantic.descriptor_scatter(self, value, x_offsets, y_offset)
1347
1464
 
1348
1465
 
1349
1466
  class tensor_descriptor_type(tensor_descriptor_base_type):
@@ -1353,25 +1470,27 @@ class tensor_descriptor_type(tensor_descriptor_base_type):
1353
1470
  self.shape_type = shape_type
1354
1471
  self.strides_type = strides_type
1355
1472
 
1356
- def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[_experimental_tensor_descriptor_base, int]:
1473
+ def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tensor_descriptor_base, int]:
1357
1474
  handle = handles[cursor]
1358
1475
  cursor += 1
1359
1476
  shape, cursor = self.shape_type._unflatten_ir(handles, cursor)
1360
1477
  strides, cursor = self.strides_type._unflatten_ir(handles, cursor)
1361
1478
  shape = shape.values
1362
1479
  strides = strides.values
1363
- value = _experimental_tensor_descriptor(handle, shape, strides, self.block_type)
1480
+ value = tensor_descriptor(handle, shape, strides, self.block_type)
1364
1481
  return value, cursor
1365
1482
 
1366
- def to_ir(self, builder: ir.builder):
1367
- return [super().to_ir(builder), *self.shape_type.to_ir(builder), *self.strides_type.to_ir(builder)]
1483
+ def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
1484
+ super()._flatten_ir_types(builder, out)
1485
+ self.shape_type._flatten_ir_types(builder, out)
1486
+ self.strides_type._flatten_ir_types(builder, out)
1368
1487
 
1369
1488
  def __eq__(self, other):
1370
1489
  return super().__eq__(other) and (self.shape_type == other.shape_type) and (self.strides_type
1371
1490
  == other.strides_type)
1372
1491
 
1373
1492
 
1374
- class _experimental_tensor_descriptor(_experimental_tensor_descriptor_base):
1493
+ class tensor_descriptor(tensor_descriptor_base):
1375
1494
  """A descriptor representing a tensor in global memory.
1376
1495
  """
1377
1496
 
@@ -1379,37 +1498,121 @@ class _experimental_tensor_descriptor(_experimental_tensor_descriptor_base):
1379
1498
  """Not called by user code."""
1380
1499
  # IR handle
1381
1500
  super().__init__(handle, block_type)
1501
+ # Global shape
1502
+ self.shape = tuple(shape)
1503
+ self.strides = tuple(strides)
1382
1504
  self.type = tensor_descriptor_type(
1383
1505
  block_type,
1384
- shape_type=tuple_type([s.type for s in shape]),
1385
- strides_type=tuple_type([s.type for s in strides]),
1506
+ shape_type=self.shape.type,
1507
+ strides_type=self.strides.type,
1386
1508
  )
1387
- # Global shape
1388
- self.shape = shape
1389
- self.strides = strides
1390
1509
 
1391
1510
  def _flatten_ir(self, handles: List[ir.value]) -> None:
1392
1511
  handles.append(self.handle)
1393
- handles.extend(s.handle for s in self.shape)
1394
- handles.extend(s.handle for s in self.strides)
1512
+ self.shape._flatten_ir(handles)
1513
+ self.strides._flatten_ir(handles)
1514
+
1515
+
1516
+ # -----------------------
1517
+ # aggregate
1518
+ # -----------------------
1519
+
1395
1520
 
1521
+ @dataclass(frozen=True)
1522
+ class _aggregate_type(base_type):
1523
+ """A generic base type for all Triton aggregate types.
1396
1524
 
1397
- def get_bool_env_var(var_name):
1398
- v = os.getenv(var_name, "0")
1399
- return v == "1" or v == "true" or v == "on"
1525
+ This class contains a reference to the original user-defined Python class
1526
+ and a list of class fields with their Triton types.
1527
+ """
1528
+
1529
+ base_cls: type
1530
+ fields: List[Tuple[str, base_type]]
1531
+
1532
+ def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[ir.value, int]:
1533
+ instance = self.base_cls._get_instance()
1534
+ for name, ty in self.fields:
1535
+ value, cursor = ty._unflatten_ir(handles, cursor)
1536
+ setattr(instance, name, value)
1537
+ return instance, cursor
1538
+
1539
+ def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
1540
+ for name, ty in self.fields:
1541
+ ty._flatten_ir_types(builder, out)
1542
+
1543
+ def mangle(self) -> str:
1544
+ name = f"{self.base_cls.__module__}.{self.base_cls.__qualname__}"
1545
+ fields = [ty.mangle() for (name, ty) in self.fields]
1546
+ return f"{name}<{', '.join(fields)}>"
1547
+
1548
+
1549
+ def _aggregate(cls):
1550
+
1551
+ # Define the wrapped Triton value type.
1552
+ class aggregate_value(base_value):
1553
+ __triton_builtin__ = True
1554
+ __triton_aggregate__ = True
1555
+
1556
+ @classmethod
1557
+ def _get_instance(this_cls):
1558
+ return super().__new__(this_cls)
1559
+
1560
+ def __new__(this_cls, *args, _semantic=None, _generator=None, **kwargs):
1561
+ # Call into the user-defined constructor.
1562
+ instance = this_cls._get_instance()
1563
+ if isinstance(cls.__init__, JITFunction):
1564
+ raise ValueError(f"{cls.__name__}.__init__ cannot be a @triton.jit function")
1565
+ extra_kwargs = {}
1566
+ if "_semantic" in inspect.signature(cls.__init__).parameters:
1567
+ extra_kwargs["_semantic"] = _semantic
1568
+ if "_generator" in inspect.signature(cls.__init__).parameters:
1569
+ extra_kwargs["_generator"] = _generator
1570
+ cls.__init__(instance, *args, **extra_kwargs, **kwargs)
1571
+
1572
+ # Require that the user-defined constructor initialized all fields.
1573
+ for name in cls.__annotations__.keys():
1574
+ if not hasattr(instance, name):
1575
+ raise AttributeError(f"constructor for {cls.__name__} did not initialize attribute '{name}'")
1576
+
1577
+ return instance
1578
+
1579
+ # Only allow setting attributes defined in the class annotations.
1580
+ def __setattr__(self, name, value):
1581
+ if name not in cls.__annotations__:
1582
+ raise AttributeError(f"{cls.__name__} has no attribute '{name}'")
1583
+ if not isinstance(value, cls.__annotations__[name]):
1584
+ raise TypeError(f"Expected {cls.__annotations__[name]} for attribute '{name}', got {type(value)}")
1585
+ super().__setattr__(name, value)
1586
+
1587
+ def _flatten_ir(self, handles: List[ir.value]) -> None:
1588
+ for name in cls.__annotations__.keys():
1589
+ getattr(self, name)._flatten_ir(handles)
1590
+
1591
+ @property
1592
+ def type(self):
1593
+ return _aggregate_type(aggregate_value,
1594
+ [(name, getattr(self, name).type) for name in cls.__annotations__.keys()])
1595
+
1596
+ for (name, member) in inspect.getmembers(cls):
1597
+ if inspect.isfunction(member) or inspect.ismethod(member) or isinstance(member, JITFunction):
1598
+ if name != "__init__":
1599
+ setattr(aggregate_value, name, member)
1600
+
1601
+ aggregate_value.__name__ = cls.__name__
1602
+ aggregate_value.__module__ = cls.__module__
1603
+ aggregate_value.__qualname__ = cls.__qualname__
1604
+ aggregate_value.__doc__ = cls.__doc__
1605
+
1606
+ return aggregate_value
1400
1607
 
1401
1608
 
1402
1609
  # -----------------------
1403
1610
  # SPMD Programming Model
1404
1611
  # -----------------------
1405
- def _constexpr_to_value(v):
1406
- if isinstance(v, constexpr):
1407
- return v.value
1408
- return v
1409
1612
 
1410
1613
 
1411
1614
  @builtin
1412
- def program_id(axis, _builder=None):
1615
+ def program_id(axis, _semantic=None):
1413
1616
  """
1414
1617
  Returns the id of the current program instance along the given :code:`axis`.
1415
1618
 
@@ -1417,26 +1620,26 @@ def program_id(axis, _builder=None):
1417
1620
  :type axis: int
1418
1621
  """
1419
1622
  # if axis == -1:
1420
- # pid0 = program_id(0, _builder)
1421
- # pid1 = program_id(1, _builder)
1422
- # pid2 = program_id(2, _builder)
1423
- # npg0 = num_programs(0, _builder)
1424
- # npg1 = num_programs(1, _builder)
1623
+ # pid0 = _semantic.program_id(0)
1624
+ # pid1 = _semantic.program_id(1)
1625
+ # pid2 = _semantic.program_id(2)
1626
+ # npg0 = _semantic.num_programs(0)
1627
+ # npg1 = _semantic.num_programs(1)
1425
1628
  # return pid0 + pid1*npg0 + pid2*npg0*npg1
1426
- axis = _constexpr_to_value(axis)
1427
- return semantic.program_id(axis, _builder)
1629
+ axis = _unwrap_if_constexpr(axis)
1630
+ return _semantic.program_id(axis)
1428
1631
 
1429
1632
 
1430
1633
  @builtin
1431
- def num_programs(axis, _builder=None):
1634
+ def num_programs(axis, _semantic=None):
1432
1635
  """
1433
1636
  Returns the number of program instances launched along the given :code:`axis`.
1434
1637
 
1435
1638
  :param axis: The axis of the 3D launch grid. Must be 0, 1 or 2.
1436
1639
  :type axis: int
1437
1640
  """
1438
- axis = _constexpr_to_value(axis)
1439
- return semantic.num_programs(axis, _builder)
1641
+ axis = _unwrap_if_constexpr(axis)
1642
+ return _semantic.num_programs(axis)
1440
1643
 
1441
1644
 
1442
1645
  # -----------------------
@@ -1445,10 +1648,10 @@ def num_programs(axis, _builder=None):
1445
1648
 
1446
1649
 
1447
1650
  @builtin
1448
- def arange(start, end, _builder=None):
1449
- start = _constexpr_to_value(start)
1450
- end = _constexpr_to_value(end)
1451
- return semantic.arange(start, end, _builder)
1651
+ def arange(start, end, _semantic=None):
1652
+ start = _unwrap_if_constexpr(start)
1653
+ end = _unwrap_if_constexpr(end)
1654
+ return _semantic.arange(start, end)
1452
1655
 
1453
1656
 
1454
1657
  arange.__doc__ = f"""
@@ -1465,8 +1668,8 @@ arange.__doc__ = f"""
1465
1668
 
1466
1669
 
1467
1670
  def _unwrap_shape(shape):
1468
- shape = _constexpr_to_value(shape)
1469
- return [_constexpr_to_value(s) for s in shape]
1671
+ shape = _unwrap_if_constexpr(shape)
1672
+ return [_unwrap_if_constexpr(s) for s in shape]
1470
1673
 
1471
1674
 
1472
1675
  def _shape_check_impl(shape):
@@ -1476,7 +1679,7 @@ def _shape_check_impl(shape):
1476
1679
 
1477
1680
 
1478
1681
  @builtin
1479
- def full(shape, value, dtype, _builder=None):
1682
+ def full(shape, value, dtype, _semantic=None):
1480
1683
  """
1481
1684
  Returns a tensor filled with the scalar value for the given :code:`shape` and :code:`dtype`.
1482
1685
 
@@ -1488,9 +1691,9 @@ def full(shape, value, dtype, _builder=None):
1488
1691
  :type dtype: tl.dtype
1489
1692
  """
1490
1693
  shape = _shape_check_impl(shape)
1491
- value = _constexpr_to_value(value)
1492
- dtype = _constexpr_to_value(dtype)
1493
- return semantic.full(shape, value, dtype, _builder)
1694
+ value = _unwrap_if_constexpr(value)
1695
+ dtype = _unwrap_if_constexpr(dtype)
1696
+ return _semantic.full(shape, value, dtype)
1494
1697
 
1495
1698
 
1496
1699
  # -----------------------
@@ -1499,7 +1702,7 @@ def full(shape, value, dtype, _builder=None):
1499
1702
 
1500
1703
 
1501
1704
  @builtin
1502
- def broadcast(input, other, _builder=None):
1705
+ def broadcast(input, other, _semantic=None):
1503
1706
  """
1504
1707
  Tries to broadcast the two given blocks to a common compatible shape.
1505
1708
 
@@ -1508,12 +1711,12 @@ def broadcast(input, other, _builder=None):
1508
1711
  :param other: The second input tensor.
1509
1712
  :type other: Block
1510
1713
  """
1511
- return semantic.broadcast_impl_value(input, other, _builder)
1714
+ return _semantic.broadcast_impl_value(input, other)
1512
1715
 
1513
1716
 
1514
1717
  @_tensor_member_fn
1515
1718
  @builtin
1516
- def broadcast_to(input, *shape, _builder=None):
1719
+ def broadcast_to(input, *shape, _semantic=None):
1517
1720
  """
1518
1721
  Tries to broadcast the given tensor to a new :code:`shape`.
1519
1722
 
@@ -1529,12 +1732,12 @@ def broadcast_to(input, *shape, _builder=None):
1529
1732
  broadcast_to(x, 32, 32)
1530
1733
  """
1531
1734
  shape = _shape_check_impl(_unwrap_iterable(shape))
1532
- return semantic.broadcast_impl_shape(input, shape, _builder)
1735
+ return _semantic.broadcast_impl_shape(input, shape)
1533
1736
 
1534
1737
 
1535
1738
  @_tensor_member_fn
1536
1739
  @builtin
1537
- def trans(input: tensor, *dims, _builder=None):
1740
+ def trans(input: tensor, *dims, _semantic=None):
1538
1741
  """
1539
1742
  Permutes the dimensions of a tensor.
1540
1743
 
@@ -1543,7 +1746,7 @@ def trans(input: tensor, *dims, _builder=None):
1543
1746
 
1544
1747
  :param input: The input tensor.
1545
1748
  :param dims: The desired ordering of dimensions. For example,
1546
- :code:`(2, 1, 0)` reverses the order dims in a a 3D tensor.
1749
+ :code:`(2, 1, 0)` reverses the order dims in a 3D tensor.
1547
1750
 
1548
1751
  :code:`dims` can be passed as a tuple or as individual parameters: ::
1549
1752
 
@@ -1557,19 +1760,19 @@ def trans(input: tensor, *dims, _builder=None):
1557
1760
  dims = _unwrap_iterable(dims)
1558
1761
  if not dims:
1559
1762
  dims = (1, 0)
1560
- return semantic.permute(input, dims, _builder)
1763
+ return _semantic.permute(input, dims)
1561
1764
 
1562
1765
 
1563
1766
  @_tensor_member_fn
1564
1767
  @builtin
1565
- def permute(input, *dims, _builder=None):
1768
+ def permute(input, *dims, _semantic=None):
1566
1769
  """
1567
1770
  Permutes the dimensions of a tensor.
1568
1771
 
1569
1772
  :param input: The input tensor.
1570
1773
  :type input: Block
1571
1774
  :param dims: The desired ordering of dimensions. For example,
1572
- :code:`(2, 1, 0)` reverses the order dims in a a 3D tensor.
1775
+ :code:`(2, 1, 0)` reverses the order dims in a 3D tensor.
1573
1776
 
1574
1777
  :code:`dims` can be passed as a tuple or as individual parameters: ::
1575
1778
 
@@ -1581,11 +1784,11 @@ def permute(input, *dims, _builder=None):
1581
1784
  :code:`dims` is empty, it tries to do a (1,0) permutation.
1582
1785
  """
1583
1786
  dims = _unwrap_iterable(dims)
1584
- return semantic.permute(input, dims, _builder)
1787
+ return _semantic.permute(input, dims)
1585
1788
 
1586
1789
 
1587
1790
  @builtin
1588
- def cat(input, other, can_reorder=False, _builder=None):
1791
+ def cat(input, other, can_reorder=False, _semantic=None):
1589
1792
  """
1590
1793
  Concatenate the given blocks
1591
1794
 
@@ -1598,11 +1801,11 @@ def cat(input, other, can_reorder=False, _builder=None):
1598
1801
  order does not matter (e.g., result is only used in reduction ops).
1599
1802
  Current implementation of `cat` supports only can_reorder=True.
1600
1803
  """
1601
- return semantic.cat(input, other, can_reorder, _builder)
1804
+ return _semantic.cat(input, other, can_reorder)
1602
1805
 
1603
1806
 
1604
1807
  @builtin
1605
- def join(a, b, _builder=None):
1808
+ def join(a, b, _semantic=None):
1606
1809
  """
1607
1810
  Join the given tensors in a new, minor dimension.
1608
1811
 
@@ -1622,7 +1825,7 @@ def join(a, b, _builder=None):
1622
1825
  :param b: The second input tensor.
1623
1826
  :type b: Tensor
1624
1827
  """
1625
- return semantic.join(a, b, _builder)
1828
+ return _semantic.join(a, b)
1626
1829
 
1627
1830
 
1628
1831
  @jit
@@ -1630,9 +1833,25 @@ def _take_first(a, b):
1630
1833
  return a
1631
1834
 
1632
1835
 
1836
+ def _unsplat(x, _semantic=None, _generator=None):
1837
+ """
1838
+ Convert a single-element tensor to a scalar.
1839
+ """
1840
+ if len(x.shape) == 0:
1841
+ return x
1842
+ numel = 1
1843
+ for d in x.shape:
1844
+ numel *= d
1845
+ assert numel == 1, "can only unsplat single-element tensors"
1846
+ if len(x.shape) >= 2:
1847
+ x = _semantic.reshape(x, [1])
1848
+ x = typing.cast(tensor, reduce(x, 0, _take_first, _semantic=_semantic, _generator=_generator))
1849
+ return x
1850
+
1851
+
1633
1852
  @_tensor_member_fn
1634
1853
  @builtin
1635
- def split(a, _builder=None, _generator=None) -> tuple[tensor, tensor]:
1854
+ def split(a, _semantic=None, _generator=None) -> tuple[tensor, tensor]:
1636
1855
  """
1637
1856
  Split a tensor in two along its last dim, which must have size 2.
1638
1857
 
@@ -1649,25 +1868,25 @@ def split(a, _builder=None, _generator=None) -> tuple[tensor, tensor]:
1649
1868
  :type a: Tensor
1650
1869
  """
1651
1870
  # If len(a.shape) == 1, i.e. a.shape == [2], we should return two scalars.
1652
- # But semantic.split can only handle returning tensors. Work around this by
1871
+ # But _semantic.split can only handle returning tensors. Work around this by
1653
1872
  # expanding the input to shape [1,2] and then reducing the result.
1654
1873
  was_rank_1 = len(a.shape) == 1
1655
1874
  if was_rank_1:
1656
- a = semantic.expand_dims(a, 0, _builder)
1875
+ a = _semantic.expand_dims(a, 0)
1657
1876
 
1658
- out_lhs, out_rhs = semantic.split(a, _builder)
1877
+ out_lhs, out_rhs = _semantic.split(a)
1659
1878
 
1660
1879
  if was_rank_1:
1661
1880
  # Currently `reduce` is the best way to convert a tensor of shape [1] to a scalar.
1662
- out_lhs = typing.cast(tensor, reduce(out_lhs, None, _take_first, _builder=_builder, _generator=_generator))
1663
- out_rhs = typing.cast(tensor, reduce(out_rhs, None, _take_first, _builder=_builder, _generator=_generator))
1881
+ out_lhs = _unsplat(out_lhs, _semantic=_semantic, _generator=_generator)
1882
+ out_rhs = _unsplat(out_rhs, _semantic=_semantic, _generator=_generator)
1664
1883
 
1665
1884
  return out_lhs, out_rhs
1666
1885
 
1667
1886
 
1668
1887
  @_tensor_member_fn
1669
1888
  @builtin
1670
- def view(input, *shape, _builder=None):
1889
+ def view(input, *shape, _semantic=None):
1671
1890
  """
1672
1891
  Returns a tensor with the same elements as `input` but a different shape.
1673
1892
  The order of the elements may not be preserved.
@@ -1684,12 +1903,21 @@ def view(input, *shape, _builder=None):
1684
1903
  """
1685
1904
  warn("view is deprecated, please use reshape with can_reorder being true.")
1686
1905
  shape = _shape_check_impl(_unwrap_iterable(shape))
1687
- return semantic.reshape(input, shape, can_reorder=True, builder=_builder)
1906
+ return _semantic.reshape(input, shape, can_reorder=True)
1688
1907
 
1689
1908
 
1690
1909
  @_tensor_member_fn
1691
1910
  @builtin
1692
- def reshape(input, *shape, can_reorder=False, _builder=None):
1911
+ def item(input, _semantic=None, _generator=None):
1912
+ """
1913
+ Converts a single-element tensor into a scalar.
1914
+ """
1915
+ return _unsplat(input, _semantic=_semantic, _generator=_generator)
1916
+
1917
+
1918
+ @_tensor_member_fn
1919
+ @builtin
1920
+ def reshape(input, *shape, can_reorder=False, _semantic=None, _generator=None):
1693
1921
  """
1694
1922
  Returns a tensor with the same number of elements as input but with the
1695
1923
  provided shape.
@@ -1705,7 +1933,9 @@ def reshape(input, *shape, can_reorder=False, _builder=None):
1705
1933
  reshape(x, 32, 32)
1706
1934
  """
1707
1935
  shape = _shape_check_impl(_unwrap_iterable(shape))
1708
- return semantic.reshape(input, shape, can_reorder, _builder)
1936
+ if len(shape) == 0:
1937
+ return _unsplat(input, _semantic=_semantic, _generator=_generator)
1938
+ return _semantic.reshape(input, shape, can_reorder)
1709
1939
 
1710
1940
 
1711
1941
  def _wrap_axis(axis, ndim):
@@ -1717,7 +1947,7 @@ def _wrap_axis(axis, ndim):
1717
1947
 
1718
1948
  @_tensor_member_fn
1719
1949
  @builtin
1720
- def expand_dims(input, axis, _builder=None):
1950
+ def expand_dims(input, axis, _semantic=None):
1721
1951
  """
1722
1952
  Expand the shape of a tensor, by inserting new length-1 dimensions.
1723
1953
 
@@ -1730,24 +1960,24 @@ def expand_dims(input, axis, _builder=None):
1730
1960
  :type axis: int | Sequence[int]
1731
1961
 
1732
1962
  """
1733
- input = semantic.to_tensor(input, _builder)
1734
- axis = _constexpr_to_value(axis)
1963
+ input = _semantic.to_tensor(input)
1964
+ axis = _unwrap_if_constexpr(axis)
1735
1965
  axes = list(axis) if isinstance(axis, (Sequence, tuple)) else [axis]
1736
1966
  new_ndim = len(input.shape) + len(axes)
1737
- axes = [_wrap_axis(_constexpr_to_value(d), new_ndim) for d in axes]
1967
+ axes = [_wrap_axis(_unwrap_if_constexpr(d), new_ndim) for d in axes]
1738
1968
 
1739
1969
  if len(set(axes)) != len(axes):
1740
1970
  raise ValueError(f"expand_dims received duplicate axes, normalized axes = {axes}")
1741
1971
 
1742
1972
  ret = input
1743
1973
  for a in sorted(axes):
1744
- ret = semantic.expand_dims(ret, a, _builder)
1974
+ ret = _semantic.expand_dims(ret, a)
1745
1975
  return ret
1746
1976
 
1747
1977
 
1748
1978
  @_tensor_member_fn
1749
1979
  @builtin
1750
- def cast(input, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: bool = False, _builder=None):
1980
+ def cast(input, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: bool = False, _semantic=None):
1751
1981
  """
1752
1982
  Casts a tensor to the given :code:`dtype`.
1753
1983
 
@@ -1763,13 +1993,13 @@ def cast(input, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcas
1763
1993
  :code:`dtype`, instead of being numerically casted.
1764
1994
  :type bitcast: bool, optional
1765
1995
  """
1766
- input = semantic.to_tensor(input, _builder)
1767
- dtype = _constexpr_to_value(dtype)
1768
- fp_downcast_rounding = _constexpr_to_value(fp_downcast_rounding)
1769
- bitcast = _constexpr_to_value(bitcast)
1996
+ input = _semantic.to_tensor(input)
1997
+ dtype = _unwrap_if_constexpr(dtype)
1998
+ fp_downcast_rounding = _unwrap_if_constexpr(fp_downcast_rounding)
1999
+ bitcast = _unwrap_if_constexpr(bitcast)
1770
2000
  if bitcast:
1771
- return semantic.bitcast(input, dtype, _builder)
1772
- return semantic.cast(input, dtype, _builder, fp_downcast_rounding)
2001
+ return _semantic.bitcast(input, dtype)
2002
+ return _semantic.cast(input, dtype, fp_downcast_rounding)
1773
2003
 
1774
2004
 
1775
2005
  # -----------------------
@@ -1779,7 +2009,7 @@ def cast(input, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcas
1779
2009
 
1780
2010
  @builtin
1781
2011
  def dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_imprecise_acc=None, out_dtype=float32,
1782
- _builder=None):
2012
+ _semantic=None):
1783
2013
  """
1784
2014
  Returns the matrix product of two blocks.
1785
2015
 
@@ -1804,19 +2034,20 @@ def dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_i
1804
2034
  """
1805
2035
  assert input_precision is None or allow_tf32 is None, "Only one of input_precision and allow_tf32 can be specified"
1806
2036
  if input_precision is None:
1807
- supports_tf32 = _builder and "tf32" in _builder.options.allowed_dot_input_precisions
1808
- default_precision = "tf32" if (supports_tf32 and (allow_tf32 or allow_tf32 is None)) else "ieee"
1809
- input_precision = os.getenv("TRITON_F32_DEFAULT", default_precision)
2037
+ supports_tf32 = "tf32" in _semantic.builder.options.allowed_dot_input_precisions
2038
+ input_precision = knobs.language.fp32_default or ("tf32" if (supports_tf32 and
2039
+ (allow_tf32 or allow_tf32 is None)) else "ieee")
1810
2040
 
1811
- input_precision = _constexpr_to_value(input_precision)
1812
- out_dtype = _constexpr_to_value(out_dtype)
1813
- max_num_imprecise_acc = _constexpr_to_value(max_num_imprecise_acc)
1814
- return semantic.dot(input, other, acc, input_precision, max_num_imprecise_acc, out_dtype, _builder)
2041
+ input_precision = _unwrap_if_constexpr(input_precision)
2042
+ out_dtype = _unwrap_if_constexpr(out_dtype)
2043
+ max_num_imprecise_acc = _unwrap_if_constexpr(max_num_imprecise_acc)
2044
+ acc = _unwrap_if_constexpr(acc)
2045
+ return _semantic.dot(input, other, acc, input_precision, max_num_imprecise_acc, out_dtype)
1815
2046
 
1816
2047
 
1817
2048
  @builtin
1818
- def dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None, fast_math=False, out_dtype=float32,
1819
- _builder=None):
2049
+ def dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None, fast_math=False, lhs_k_pack=True,
2050
+ rhs_k_pack=True, out_dtype=float32, _semantic=None):
1820
2051
  """
1821
2052
  Returns the matrix product of two blocks in microscaling format.
1822
2053
 
@@ -1843,11 +2074,15 @@ def dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None,
1843
2074
  :param rhs_format: format of the rhs tensor. Available formats: {:code:`e2m1`, :code:`e4m3`, :code:`e5m2`, :code:`bf16`, :code:`fp16`}.
1844
2075
  :type rhs_format: str
1845
2076
  :param acc: The accumulator tensor. If not None, the result is added to this tensor.
2077
+ :param lhs_k_pack: If false, the lhs tensor is packed into uint8 along M dimension.
2078
+ :type lhs_k_pack: bool, optional
2079
+ :param rhs_k_pack: If false, the rhs tensor is packed into uint8 along N dimension.
2080
+ :type rhs_k_pack: bool, optional
1846
2081
  """
1847
- out_dtype = _constexpr_to_value(out_dtype)
2082
+ out_dtype = _unwrap_if_constexpr(out_dtype)
1848
2083
  assert out_dtype == float32, "Only float32 is supported for out_dtype at the moment"
1849
- return semantic.dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc, fast_math, out_dtype,
1850
- _builder)
2084
+ return _semantic.dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc, fast_math, lhs_k_pack,
2085
+ rhs_k_pack, out_dtype)
1851
2086
 
1852
2087
 
1853
2088
  # -----------------------
@@ -1857,7 +2092,7 @@ def dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None,
1857
2092
 
1858
2093
  @builtin
1859
2094
  def load(pointer, mask=None, other=None, boundary_check=(), padding_option="", cache_modifier="", eviction_policy="",
1860
- volatile=False, _builder=None):
2095
+ volatile=False, _semantic=None):
1861
2096
  """
1862
2097
  Return a tensor of data whose values are loaded from memory at location defined by `pointer`:
1863
2098
 
@@ -1892,8 +2127,9 @@ def load(pointer, mask=None, other=None, boundary_check=(), padding_option="", c
1892
2127
  :type boundary_check: tuple of ints, optional
1893
2128
  :param padding_option: should be one of {"", "zero", "nan"}, the padding value to use while out of bounds. "" means an undefined value.
1894
2129
  :param cache_modifier: changes cache option in NVIDIA PTX
1895
- :type cache_modifier: str, optional, should be one of {"", "ca", "cg"}, where "ca" stands for
1896
- cache at all levels and "cg" stands for cache at global level (cache in L2 and below, not L1), see
2130
+ :type cache_modifier: str, optional, should be one of {"", ".ca", ".cg", ".cv"}, where ".ca" stands for
2131
+ cache at all levels, ".cg" stands for cache at global level (cache in L2 and below, not L1),
2132
+ and ".cv" means don’t cache and fetch again. see
1897
2133
  `cache operator <https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#cache-operators>`_ for more details.
1898
2134
  :param eviction_policy: changes eviction policy in NVIDIA PTX
1899
2135
  :type eviction_policy: str, optional
@@ -1901,57 +2137,37 @@ def load(pointer, mask=None, other=None, boundary_check=(), padding_option="", c
1901
2137
  :type volatile: bool, optional
1902
2138
  """
1903
2139
  # `mask` and `other` can be constexpr
1904
- mask = _constexpr_to_value(mask)
1905
- other = _constexpr_to_value(other)
2140
+ mask = _unwrap_if_constexpr(mask)
2141
+ other = _unwrap_if_constexpr(other)
1906
2142
  if mask is not None:
1907
- mask = semantic.to_tensor(mask, _builder)
2143
+ mask = _semantic.to_tensor(mask)
1908
2144
  if other is not None:
1909
- other = semantic.to_tensor(other, _builder)
1910
- padding_option = _constexpr_to_value(padding_option)
1911
- cache_modifier = _constexpr_to_value(cache_modifier)
1912
- eviction_policy = _constexpr_to_value(eviction_policy)
1913
- volatile = _constexpr_to_value(volatile)
1914
- return semantic.load(pointer, mask, other, boundary_check, padding_option, cache_modifier, eviction_policy,
1915
- volatile, _builder)
1916
-
1917
-
1918
- @builtin
1919
- def _experimental_reinterpret_tensor_descriptor(desc_ptr, block_shape, dtype,
1920
- _builder=None) -> _experimental_tensor_descriptor_base:
1921
- """
1922
- Reinterpret a generic pointer as a TMA-backed tensor descriptor object.
1923
- """
1924
- block_ty = block_type(_constexpr_to_value(dtype), block_shape)
1925
- return semantic.reinterpret_tensor_descriptor(desc_ptr, block_ty, _builder)
2145
+ other = _semantic.to_tensor(other)
2146
+ padding_option = _unwrap_if_constexpr(padding_option)
2147
+ cache_modifier = _unwrap_if_constexpr(cache_modifier)
2148
+ eviction_policy = _unwrap_if_constexpr(eviction_policy)
2149
+ volatile = _unwrap_if_constexpr(volatile)
2150
+ return _semantic.load(pointer, mask, other, boundary_check, padding_option, cache_modifier, eviction_policy,
2151
+ volatile)
1926
2152
 
1927
2153
 
1928
2154
  @builtin
1929
- def _experimental_descriptor_load(desc_pointer, offsets, shape, dtype, _builder=None):
1930
- """
1931
- Experimental feature to access TMA descriptors loads. This is an escape hatch to easily exercise TTGIR operations.
1932
- This will be removed in the future and shouldn't be used in production code.
1933
-
1934
- This loads a tensor of data based on the descriptor and offsets.
1935
- """
1936
- desc = _experimental_reinterpret_tensor_descriptor(desc_pointer, shape, dtype, _builder=_builder)
1937
- return desc.load(offsets, _builder=_builder)
2155
+ def load_tensor_descriptor(desc: tensor_descriptor_base, offsets: Sequence[constexpr | tensor],
2156
+ _semantic=None) -> tensor:
2157
+ """Load a block of data from a tensor descriptor."""
2158
+ return desc.load(offsets, _semantic=_semantic)
1938
2159
 
1939
2160
 
1940
2161
  @builtin
1941
- def _experimental_descriptor_store(desc_pointer, value, offsets, _builder=None):
1942
- """
1943
- Experimental feature to access TMA descriptors stores. This is an escape hatch to easily exercise TTGIR operations.
1944
- This will be removed in the future and shouldn't be used in production code.
1945
-
1946
- This stores a tensor of data based on the descriptor and offsets.
1947
- """
1948
- desc = _experimental_reinterpret_tensor_descriptor(desc_pointer, value.shape, value.dtype, _builder=_builder)
1949
- return desc.store(offsets, value, _builder=_builder)
2162
+ def store_tensor_descriptor(desc: tensor_descriptor_base, offsets: Sequence[constexpr | tensor], value: tensor,
2163
+ _semantic=None) -> tensor:
2164
+ """Store a block of data to a tensor descriptor."""
2165
+ return desc.store(offsets, value, _semantic=_semantic)
1950
2166
 
1951
2167
 
1952
2168
  @_tensor_member_fn
1953
2169
  @builtin
1954
- def store(pointer, value, mask=None, boundary_check=(), cache_modifier="", eviction_policy="", _builder=None):
2170
+ def store(pointer, value, mask=None, boundary_check=(), cache_modifier="", eviction_policy="", _semantic=None):
1955
2171
  """
1956
2172
  Store a tensor of data into memory locations defined by `pointer`.
1957
2173
 
@@ -1991,17 +2207,17 @@ def store(pointer, value, mask=None, boundary_check=(), cache_modifier="", evict
1991
2207
  :type eviction_policy: str, optional, should be one of {"", "evict_first", "evict_last"}
1992
2208
  """
1993
2209
  # `value` can be constexpr
1994
- value = semantic.to_tensor(value, _builder)
1995
- mask = _constexpr_to_value(mask)
2210
+ value = _semantic.to_tensor(value)
2211
+ mask = _unwrap_if_constexpr(mask)
1996
2212
  if mask is not None:
1997
- mask = semantic.to_tensor(mask, _builder)
1998
- cache_modifier = _constexpr_to_value(cache_modifier)
1999
- eviction_policy = _constexpr_to_value(eviction_policy)
2000
- return semantic.store(pointer, value, mask, boundary_check, cache_modifier, eviction_policy, _builder)
2213
+ mask = _semantic.to_tensor(mask)
2214
+ cache_modifier = _unwrap_if_constexpr(cache_modifier)
2215
+ eviction_policy = _unwrap_if_constexpr(eviction_policy)
2216
+ return _semantic.store(pointer, value, mask, boundary_check, cache_modifier, eviction_policy)
2001
2217
 
2002
2218
 
2003
2219
  @builtin
2004
- def make_block_ptr(base: tensor, shape, strides, offsets, block_shape, order, _builder=None):
2220
+ def make_block_ptr(base: tensor, shape, strides, offsets, block_shape, order, _semantic=None):
2005
2221
  """
2006
2222
  Returns a pointer to a block in a parent tensor
2007
2223
 
@@ -2012,30 +2228,33 @@ def make_block_ptr(base: tensor, shape, strides, offsets, block_shape, order, _b
2012
2228
  :param block_shape: The shape of the block
2013
2229
  :param order: The order of the original data format
2014
2230
  """
2015
- return semantic.make_block_ptr(base, shape, strides, offsets, block_shape, order, _builder)
2231
+ return _semantic.make_block_ptr(base, shape, strides, offsets, block_shape, order)
2016
2232
 
2017
2233
 
2234
+ @must_use_result(
2235
+ "Note that tl.advance does not have any side effects. To move the block pointer, you need to assign the result of tl.advance to a variable."
2236
+ )
2018
2237
  @_tensor_member_fn
2019
2238
  @builtin
2020
- def advance(base, offsets, _builder=None):
2239
+ def advance(base, offsets, _semantic=None):
2021
2240
  """
2022
2241
  Advance a block pointer
2023
2242
 
2024
2243
  :param base: the block pointer to advance
2025
2244
  :param offsets: the offsets to advance, a tuple by dimension
2026
2245
  """
2027
- return semantic.advance(base, offsets, _builder)
2246
+ return _semantic.advance(base, offsets)
2028
2247
 
2029
2248
 
2030
2249
  @builtin
2031
- def _experimental_make_tensor_descriptor(
2250
+ def make_tensor_descriptor(
2032
2251
  base: tensor,
2033
2252
  shape: List[tensor],
2034
2253
  strides: List[tensor],
2035
2254
  block_shape: List[constexpr],
2036
- _builder=None,
2037
- ) -> _experimental_tensor_descriptor:
2038
- """Make an experimental tensor descriptor object
2255
+ _semantic=None,
2256
+ ) -> tensor_descriptor:
2257
+ """Make a tensor descriptor object
2039
2258
 
2040
2259
  :param base: the base pointer of the tensor, must be 16-byte aligned
2041
2260
  :param shape: A list of non-negative integers representing the tensor shape
@@ -2056,7 +2275,7 @@ def _experimental_make_tensor_descriptor(
2056
2275
 
2057
2276
  @triton.jit
2058
2277
  def inplace_abs(in_out_ptr, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr):
2059
- desc = tl._experimental_make_tensor_descriptor(
2278
+ desc = tl.make_tensor_descriptor(
2060
2279
  in_out_ptr,
2061
2280
  shape=[M, N],
2062
2281
  strides=[N, 1],
@@ -2082,7 +2301,7 @@ def _experimental_make_tensor_descriptor(
2082
2301
  inplace_abs[grid](x, M, N, M_BLOCK, N_BLOCK)
2083
2302
 
2084
2303
  """
2085
- return semantic.make_tensor_descriptor(base, shape, strides, block_shape, _builder)
2304
+ return _semantic.make_tensor_descriptor(base, shape, strides, block_shape)
2086
2305
 
2087
2306
 
2088
2307
  # -----------------------
@@ -2124,89 +2343,89 @@ def _add_atomic_docstr(name: str, has_cmp: bool = False) -> Callable[[T], T]:
2124
2343
  @_tensor_member_fn
2125
2344
  @builtin
2126
2345
  @_add_atomic_docstr("compare-and-swap", has_cmp=True)
2127
- def atomic_cas(pointer, cmp, val, sem=None, scope=None, _builder=None):
2128
- cmp = semantic.to_tensor(cmp, _builder)
2129
- val = semantic.to_tensor(val, _builder)
2130
- sem = _constexpr_to_value(sem)
2131
- scope = _constexpr_to_value(scope)
2132
- return semantic.atomic_cas(pointer, cmp, val, sem, scope, _builder)
2346
+ def atomic_cas(pointer, cmp, val, sem=None, scope=None, _semantic=None):
2347
+ cmp = _semantic.to_tensor(cmp)
2348
+ val = _semantic.to_tensor(val)
2349
+ sem = _unwrap_if_constexpr(sem)
2350
+ scope = _unwrap_if_constexpr(scope)
2351
+ return _semantic.atomic_cas(pointer, cmp, val, sem, scope)
2133
2352
 
2134
2353
 
2135
2354
  @_tensor_member_fn
2136
2355
  @builtin
2137
2356
  @_add_atomic_docstr("exchange")
2138
- def atomic_xchg(pointer, val, mask=None, sem=None, scope=None, _builder=None):
2139
- val = semantic.to_tensor(val, _builder)
2140
- sem = _constexpr_to_value(sem)
2141
- scope = _constexpr_to_value(scope)
2142
- mask = _constexpr_to_value(mask)
2143
- return semantic.atomic_xchg(pointer, val, mask, sem, scope, _builder)
2357
+ def atomic_xchg(pointer, val, mask=None, sem=None, scope=None, _semantic=None):
2358
+ val = _semantic.to_tensor(val)
2359
+ sem = _unwrap_if_constexpr(sem)
2360
+ scope = _unwrap_if_constexpr(scope)
2361
+ mask = _unwrap_if_constexpr(mask)
2362
+ return _semantic.atomic_xchg(pointer, val, mask, sem, scope)
2144
2363
 
2145
2364
 
2146
2365
  @_tensor_member_fn
2147
2366
  @builtin
2148
2367
  @_add_atomic_docstr("add")
2149
- def atomic_add(pointer, val, mask=None, sem=None, scope=None, _builder=None):
2150
- val = semantic.to_tensor(val, _builder)
2151
- sem = _constexpr_to_value(sem)
2152
- scope = _constexpr_to_value(scope)
2153
- mask = _constexpr_to_value(mask)
2154
- return semantic.atomic_add(pointer, val, mask, sem, scope, _builder)
2368
+ def atomic_add(pointer, val, mask=None, sem=None, scope=None, _semantic=None):
2369
+ val = _semantic.to_tensor(val)
2370
+ sem = _unwrap_if_constexpr(sem)
2371
+ scope = _unwrap_if_constexpr(scope)
2372
+ mask = _unwrap_if_constexpr(mask)
2373
+ return _semantic.atomic_add(pointer, val, mask, sem, scope)
2155
2374
 
2156
2375
 
2157
2376
  @_tensor_member_fn
2158
2377
  @builtin
2159
2378
  @_add_atomic_docstr("max")
2160
- def atomic_max(pointer, val, mask=None, sem=None, scope=None, _builder=None):
2161
- val = semantic.to_tensor(val, _builder)
2162
- sem = _constexpr_to_value(sem)
2163
- scope = _constexpr_to_value(scope)
2164
- mask = _constexpr_to_value(mask)
2165
- return semantic.atomic_max(pointer, val, mask, sem, scope, _builder)
2379
+ def atomic_max(pointer, val, mask=None, sem=None, scope=None, _semantic=None):
2380
+ val = _semantic.to_tensor(val)
2381
+ sem = _unwrap_if_constexpr(sem)
2382
+ scope = _unwrap_if_constexpr(scope)
2383
+ mask = _unwrap_if_constexpr(mask)
2384
+ return _semantic.atomic_max(pointer, val, mask, sem, scope)
2166
2385
 
2167
2386
 
2168
2387
  @_tensor_member_fn
2169
2388
  @builtin
2170
2389
  @_add_atomic_docstr("min")
2171
- def atomic_min(pointer, val, mask=None, sem=None, scope=None, _builder=None):
2172
- val = semantic.to_tensor(val, _builder)
2173
- sem = _constexpr_to_value(sem)
2174
- scope = _constexpr_to_value(scope)
2175
- mask = _constexpr_to_value(mask)
2176
- return semantic.atomic_min(pointer, val, mask, sem, scope, _builder)
2390
+ def atomic_min(pointer, val, mask=None, sem=None, scope=None, _semantic=None):
2391
+ val = _semantic.to_tensor(val)
2392
+ sem = _unwrap_if_constexpr(sem)
2393
+ scope = _unwrap_if_constexpr(scope)
2394
+ mask = _unwrap_if_constexpr(mask)
2395
+ return _semantic.atomic_min(pointer, val, mask, sem, scope)
2177
2396
 
2178
2397
 
2179
2398
  @_tensor_member_fn
2180
2399
  @builtin
2181
2400
  @_add_atomic_docstr("logical and")
2182
- def atomic_and(pointer, val, mask=None, sem=None, scope=None, _builder=None):
2183
- val = semantic.to_tensor(val, _builder)
2184
- sem = _constexpr_to_value(sem)
2185
- scope = _constexpr_to_value(scope)
2186
- mask = _constexpr_to_value(mask)
2187
- return semantic.atomic_and(pointer, val, mask, sem, scope, _builder)
2401
+ def atomic_and(pointer, val, mask=None, sem=None, scope=None, _semantic=None):
2402
+ val = _semantic.to_tensor(val)
2403
+ sem = _unwrap_if_constexpr(sem)
2404
+ scope = _unwrap_if_constexpr(scope)
2405
+ mask = _unwrap_if_constexpr(mask)
2406
+ return _semantic.atomic_and(pointer, val, mask, sem, scope)
2188
2407
 
2189
2408
 
2190
2409
  @_tensor_member_fn
2191
2410
  @builtin
2192
2411
  @_add_atomic_docstr("logical or")
2193
- def atomic_or(pointer, val, mask=None, sem=None, scope=None, _builder=None):
2194
- val = semantic.to_tensor(val, _builder)
2195
- sem = _constexpr_to_value(sem)
2196
- scope = _constexpr_to_value(scope)
2197
- mask = _constexpr_to_value(mask)
2198
- return semantic.atomic_or(pointer, val, mask, sem, scope, _builder)
2412
+ def atomic_or(pointer, val, mask=None, sem=None, scope=None, _semantic=None):
2413
+ val = _semantic.to_tensor(val)
2414
+ sem = _unwrap_if_constexpr(sem)
2415
+ scope = _unwrap_if_constexpr(scope)
2416
+ mask = _unwrap_if_constexpr(mask)
2417
+ return _semantic.atomic_or(pointer, val, mask, sem, scope)
2199
2418
 
2200
2419
 
2201
2420
  @_tensor_member_fn
2202
2421
  @builtin
2203
2422
  @_add_atomic_docstr("logical xor")
2204
- def atomic_xor(pointer, val, mask=None, sem=None, scope=None, _builder=None):
2205
- val = semantic.to_tensor(val, _builder)
2206
- sem = _constexpr_to_value(sem)
2207
- scope = _constexpr_to_value(scope)
2208
- mask = _constexpr_to_value(mask)
2209
- return semantic.atomic_xor(pointer, val, mask, sem, scope, _builder)
2423
+ def atomic_xor(pointer, val, mask=None, sem=None, scope=None, _semantic=None):
2424
+ val = _semantic.to_tensor(val)
2425
+ sem = _unwrap_if_constexpr(sem)
2426
+ scope = _unwrap_if_constexpr(scope)
2427
+ mask = _unwrap_if_constexpr(mask)
2428
+ return _semantic.atomic_xor(pointer, val, mask, sem, scope)
2210
2429
 
2211
2430
 
2212
2431
  # -----------------------
@@ -2215,7 +2434,7 @@ def atomic_xor(pointer, val, mask=None, sem=None, scope=None, _builder=None):
2215
2434
 
2216
2435
 
2217
2436
  @builtin
2218
- def where(condition, x, y, _builder=None):
2437
+ def where(condition, x, y, _semantic=None):
2219
2438
  """
2220
2439
  Returns a tensor of elements from either :code:`x` or :code:`y`, depending on :code:`condition`.
2221
2440
 
@@ -2231,10 +2450,10 @@ def where(condition, x, y, _builder=None):
2231
2450
  :param x: values selected at indices where condition is True.
2232
2451
  :param y: values selected at indices where condition is False.
2233
2452
  """
2234
- condition = semantic.to_tensor(condition, _builder)
2453
+ condition = _semantic.to_tensor(condition)
2235
2454
  x = _unwrap_if_constexpr(x)
2236
2455
  y = _unwrap_if_constexpr(y)
2237
- return semantic.where(condition, x, y, _builder)
2456
+ return _semantic.where(condition, x, y)
2238
2457
 
2239
2458
 
2240
2459
  # -----------------------
@@ -2243,28 +2462,28 @@ def where(condition, x, y, _builder=None):
2243
2462
 
2244
2463
 
2245
2464
  @builtin
2246
- def add(x, y, sanitize_overflow: constexpr = True, _builder=None):
2465
+ def add(x, y, sanitize_overflow: constexpr = True, _semantic=None):
2247
2466
  x = _unwrap_if_constexpr(x)
2248
2467
  y = _unwrap_if_constexpr(y)
2249
- return semantic.add(x, y, sanitize_overflow, _builder)
2468
+ return _semantic.add(x, y, sanitize_overflow)
2250
2469
 
2251
2470
 
2252
2471
  @builtin
2253
- def sub(x, y, sanitize_overflow: constexpr = True, _builder=None):
2472
+ def sub(x, y, sanitize_overflow: constexpr = True, _semantic=None):
2254
2473
  x = _unwrap_if_constexpr(x)
2255
2474
  y = _unwrap_if_constexpr(y)
2256
- return semantic.sub(x, y, sanitize_overflow, _builder)
2475
+ return _semantic.sub(x, y, sanitize_overflow)
2257
2476
 
2258
2477
 
2259
2478
  @builtin
2260
- def mul(x, y, sanitize_overflow: constexpr = True, _builder=None):
2479
+ def mul(x, y, sanitize_overflow: constexpr = True, _semantic=None):
2261
2480
  x = _unwrap_if_constexpr(x)
2262
2481
  y = _unwrap_if_constexpr(y)
2263
- return semantic.mul(x, y, sanitize_overflow, _builder)
2482
+ return _semantic.mul(x, y, sanitize_overflow)
2264
2483
 
2265
2484
 
2266
2485
  @builtin
2267
- def minimum(x, y, propagate_nan: constexpr = PropagateNan.NONE, _builder=None):
2486
+ def minimum(x, y, propagate_nan: constexpr = PropagateNan.NONE, _semantic=None):
2268
2487
  """
2269
2488
  Computes the element-wise minimum of :code:`x` and :code:`y`.
2270
2489
 
@@ -2277,16 +2496,16 @@ def minimum(x, y, propagate_nan: constexpr = PropagateNan.NONE, _builder=None):
2277
2496
 
2278
2497
  .. seealso:: :class:`tl.PropagateNan`
2279
2498
  """
2280
- x = semantic.to_tensor(x, _builder)
2281
- y = semantic.to_tensor(y, _builder)
2282
- x = _promote_bfloat16_to_float32(x, _builder=_builder)
2283
- y = _promote_bfloat16_to_float32(y, _builder=_builder)
2284
- propagate_nan = _constexpr_to_value(propagate_nan)
2285
- return semantic.minimum(x, y, propagate_nan, _builder)
2499
+ x = _semantic.to_tensor(x)
2500
+ y = _semantic.to_tensor(y)
2501
+ x = _promote_bfloat16_to_float32(x, _semantic=_semantic)
2502
+ y = _promote_bfloat16_to_float32(y, _semantic=_semantic)
2503
+ propagate_nan = _unwrap_if_constexpr(propagate_nan)
2504
+ return _semantic.minimum(x, y, propagate_nan)
2286
2505
 
2287
2506
 
2288
2507
  @builtin
2289
- def maximum(x, y, propagate_nan: constexpr = PropagateNan.NONE, _builder=None):
2508
+ def maximum(x, y, propagate_nan: constexpr = PropagateNan.NONE, _semantic=None):
2290
2509
  """
2291
2510
  Computes the element-wise maximum of :code:`x` and :code:`y`.
2292
2511
 
@@ -2299,16 +2518,16 @@ def maximum(x, y, propagate_nan: constexpr = PropagateNan.NONE, _builder=None):
2299
2518
 
2300
2519
  .. seealso:: :class:`tl.PropagateNan`
2301
2520
  """
2302
- x = semantic.to_tensor(x, _builder)
2303
- y = semantic.to_tensor(y, _builder)
2304
- x = _promote_bfloat16_to_float32(x, _builder=_builder)
2305
- y = _promote_bfloat16_to_float32(y, _builder=_builder)
2306
- propagate_nan = _constexpr_to_value(propagate_nan)
2307
- return semantic.maximum(x, y, propagate_nan, _builder)
2521
+ x = _semantic.to_tensor(x)
2522
+ y = _semantic.to_tensor(y)
2523
+ x = _promote_bfloat16_to_float32(x, _semantic=_semantic)
2524
+ y = _promote_bfloat16_to_float32(y, _semantic=_semantic)
2525
+ propagate_nan = _unwrap_if_constexpr(propagate_nan)
2526
+ return _semantic.maximum(x, y, propagate_nan)
2308
2527
 
2309
2528
 
2310
2529
  @builtin
2311
- def clamp(x, min, max, propagate_nan: constexpr = PropagateNan.NONE, _builder=None):
2530
+ def clamp(x, min, max, propagate_nan: constexpr = PropagateNan.NONE, _semantic=None):
2312
2531
  """
2313
2532
  Clamps the input tensor :code:`x` within the range [min, max].
2314
2533
  Behavior when :code:`min` > :code:`max` is undefined.
@@ -2325,16 +2544,16 @@ def clamp(x, min, max, propagate_nan: constexpr = PropagateNan.NONE, _builder=No
2325
2544
 
2326
2545
  .. seealso:: :class:`tl.PropagateNan`
2327
2546
  """
2328
- x = semantic.to_tensor(x, _builder)
2329
- min = semantic.to_tensor(min, _builder)
2330
- max = semantic.to_tensor(max, _builder)
2331
- x = _promote_bfloat16_to_float32(x, _builder=_builder)
2332
- min = _promote_bfloat16_to_float32(min, _builder=_builder)
2333
- max = _promote_bfloat16_to_float32(max, _builder=_builder)
2547
+ x = _semantic.to_tensor(x)
2548
+ min = _semantic.to_tensor(min)
2549
+ max = _semantic.to_tensor(max)
2550
+ x = _promote_bfloat16_to_float32(x, _semantic=_semantic)
2551
+ min = _promote_bfloat16_to_float32(min, _semantic=_semantic)
2552
+ max = _promote_bfloat16_to_float32(max, _semantic=_semantic)
2334
2553
 
2335
- propagate_nan = _constexpr_to_value(propagate_nan)
2554
+ propagate_nan = _unwrap_if_constexpr(propagate_nan)
2336
2555
 
2337
- return semantic.clamp(x, min, max, propagate_nan, _builder)
2556
+ return _semantic.clamp(x, min, max, propagate_nan)
2338
2557
 
2339
2558
 
2340
2559
  # -----------------------
@@ -2383,7 +2602,7 @@ def _insertion_guard(builder):
2383
2602
 
2384
2603
  @_tensor_member_fn
2385
2604
  @builtin
2386
- def reduce(input, axis, combine_fn, keep_dims=False, _builder=None, _generator=None):
2605
+ def reduce(input, axis, combine_fn, keep_dims=False, _semantic=None, _generator=None):
2387
2606
  """Applies the combine_fn to all elements in :code:`input` tensors along the provided :code:`axis`
2388
2607
 
2389
2608
  :param input: the input tensor, or tuple of tensors
@@ -2397,64 +2616,65 @@ def reduce(input, axis, combine_fn, keep_dims=False, _builder=None, _generator=N
2397
2616
 
2398
2617
  """
2399
2618
  if isinstance(input, tensor):
2400
- return reduce((input, ), axis, combine_fn, keep_dims=keep_dims, _builder=_builder, _generator=_generator)[0]
2619
+ return reduce((input, ), axis, combine_fn, keep_dims=keep_dims, _semantic=_semantic, _generator=_generator)[0]
2401
2620
 
2402
2621
  def make_combine_region(reduce_op):
2403
2622
  param_types = [t.type.scalar for t in input] * 2
2404
2623
  region = reduce_op.get_region(0)
2405
- with _insertion_guard(_builder):
2406
- to_ir = lambda T: T.to_ir(_builder)
2407
- block = _builder.create_block_with_parent(region, list(map(to_ir, param_types)))
2624
+ builder = _semantic.builder
2625
+ with _insertion_guard(builder):
2626
+ to_ir = lambda T: T.to_ir(builder)
2627
+ block = builder.create_block_with_parent(region, list(map(to_ir, param_types)))
2408
2628
  args = [tensor(block.arg(i), ty) for i, ty in enumerate(param_types)]
2409
2629
  results = _generator.call_JitFunction(combine_fn, args, kwargs={})
2410
2630
  if isinstance(results, tensor):
2411
2631
  handles = [results.handle]
2412
2632
  else:
2413
2633
  handles = [r.handle for r in results]
2414
- _builder.create_reduce_ret(*handles)
2634
+ builder.create_reduce_ret(*handles)
2415
2635
 
2416
2636
  def expand_ndims(t, ndims):
2417
2637
  for _ in builtins.range(ndims):
2418
- t = expand_dims(t, 0, _builder=_builder)
2638
+ t = expand_dims(t, 0, _semantic=_semantic)
2419
2639
  return t
2420
2640
 
2421
- axis = _constexpr_to_value(axis)
2422
- keep_dims = _constexpr_to_value(keep_dims)
2641
+ axis = _unwrap_if_constexpr(axis)
2642
+ keep_dims = _unwrap_if_constexpr(keep_dims)
2423
2643
  if axis is not None:
2424
2644
  axis = _wrap_axis(axis, len(input[0].shape))
2425
- ret = semantic.reduction(input, axis, make_combine_region, _builder)
2645
+ ret = _semantic.reduction(input, axis, make_combine_region)
2426
2646
  if keep_dims:
2427
2647
  if axis is not None:
2428
- ret = tuple(expand_dims(t, axis, _builder=_builder) for t in ret)
2648
+ ret = tuple(expand_dims(t, axis, _semantic=_semantic) for t in ret)
2429
2649
  else:
2430
2650
  ret = tuple(expand_ndims(t, len(input[0].shape)) for t in ret)
2431
2651
  return ret
2432
2652
 
2433
2653
 
2434
2654
  @builtin
2435
- def _promote_bfloat16_to_float32(t, _builder=None):
2655
+ def _promote_bfloat16_to_float32(t, _semantic=None):
2436
2656
  scalar_ty = t.type.scalar
2437
2657
 
2438
2658
  # hardware doesn't support FMAX, FMIN, CMP for bfloat16
2439
2659
  if scalar_ty is bfloat16:
2440
- return t.to(float32, _builder=_builder)
2660
+ return t.to(float32, _semantic=_semantic)
2441
2661
  return t
2442
2662
 
2443
2663
 
2444
2664
  @builtin
2445
- def _reduce_with_indices(input, axis, combine_fn, keep_dims=False, _builder=None, _generator=None):
2446
- axis = _constexpr_to_value(axis)
2665
+ def _reduce_with_indices(input, axis, combine_fn, keep_dims=False, _semantic=None, _generator=None):
2666
+ axis = _unwrap_if_constexpr(axis)
2447
2667
  n = input.shape[axis]
2448
- index = arange(0, n, _builder=_builder)
2668
+ index = arange(0, n, _semantic=_semantic)
2449
2669
 
2450
2670
  if len(input.shape) > 1:
2451
2671
  # Broadcast index across the non-reduced axes
2452
2672
  axes_to_expand = [constexpr(d) for d in builtins.range(len(input.shape))]
2453
2673
  del axes_to_expand[axis]
2454
- index = expand_dims(index, axes_to_expand, _builder=_builder)
2455
- index = broadcast_to(index, input.shape, _builder=_builder)
2674
+ index = expand_dims(index, axes_to_expand, _semantic=_semantic)
2675
+ index = broadcast_to(index, input.shape, _semantic=_semantic)
2456
2676
 
2457
- rvalue, rindices = reduce((input, index), axis, combine_fn, keep_dims=keep_dims, _builder=_builder,
2677
+ rvalue, rindices = reduce((input, index), axis, combine_fn, keep_dims=keep_dims, _semantic=_semantic,
2458
2678
  _generator=_generator)
2459
2679
  return rvalue, rindices
2460
2680
 
@@ -2464,7 +2684,7 @@ def _reduce_with_indices(input, axis, combine_fn, keep_dims=False, _builder=None
2464
2684
  # -----------------------
2465
2685
 
2466
2686
 
2467
- def _add_scan_docstr(name: str) -> Callable[[T], T]:
2687
+ def _add_scan_docstr(name: str, dtype_arg: str = None) -> Callable[[T], T]:
2468
2688
 
2469
2689
  def _decorator(func: T) -> T:
2470
2690
  docstr = """
@@ -2473,7 +2693,15 @@ def _add_scan_docstr(name: str) -> Callable[[T], T]:
2473
2693
  :param input: the input values
2474
2694
  :type input: Tensor
2475
2695
  :param axis: the dimension along which the scan should be done
2476
- :type axis: int"""
2696
+ :type axis: int
2697
+ :param reverse: if true, the scan is performed in the reverse direction
2698
+ :type reverse: bool"""
2699
+
2700
+ if dtype_arg is not None:
2701
+ docstr += f"""
2702
+ :param {dtype_arg}: the desired data type of the returned tensor. If specified, the input tensor is casted to :code:`{dtype_arg}` before the operation is performed. If not specified, small integer types (< 32 bits) are upcasted to prevent overflow. Note that :code:`tl.bfloat16` inputs are automatically promoted to :code:`tl.float32`.
2703
+ :type {dtype_arg}: tl.dtype"""
2704
+
2477
2705
  func.__doc__ = docstr.format(name=name)
2478
2706
  return func
2479
2707
 
@@ -2482,7 +2710,7 @@ def _add_scan_docstr(name: str) -> Callable[[T], T]:
2482
2710
 
2483
2711
  @_tensor_member_fn
2484
2712
  @builtin
2485
- def associative_scan(input, axis, combine_fn, reverse=False, _builder=None, _generator=None):
2713
+ def associative_scan(input, axis, combine_fn, reverse=False, _semantic=None, _generator=None):
2486
2714
  """Applies the combine_fn to each elements with a carry in :code:`input` tensors along the provided :code:`axis` and update the carry
2487
2715
 
2488
2716
  :param input: the input tensor, or tuple of tensors
@@ -2496,46 +2724,52 @@ def associative_scan(input, axis, combine_fn, reverse=False, _builder=None, _gen
2496
2724
 
2497
2725
  """
2498
2726
  if isinstance(input, tensor):
2499
- return associative_scan((input, ), axis, combine_fn, reverse, _builder=_builder, _generator=_generator)[0]
2727
+ return associative_scan((input, ), axis, combine_fn, reverse, _semantic=_semantic, _generator=_generator)[0]
2500
2728
 
2501
2729
  def make_combine_region(scan_op):
2502
2730
  param_types = [t.type.scalar for t in input] * 2
2503
2731
  region = scan_op.get_region(0)
2504
- with _insertion_guard(_builder):
2505
- to_ir = lambda T: T.to_ir(_builder)
2506
- block = _builder.create_block_with_parent(region, list(map(to_ir, param_types)))
2732
+ builder = _semantic.builder
2733
+ with _insertion_guard(builder):
2734
+ to_ir = lambda T: T.to_ir(builder)
2735
+ block = builder.create_block_with_parent(region, list(map(to_ir, param_types)))
2507
2736
  args = [tensor(block.arg(i), ty) for i, ty in enumerate(param_types)]
2508
2737
  results = _generator.call_JitFunction(combine_fn, args, kwargs={})
2509
2738
  if isinstance(results, tensor):
2510
2739
  handles = [results.handle]
2511
2740
  else:
2512
2741
  handles = [r.handle for r in results]
2513
- _builder.create_scan_ret(*handles)
2742
+ builder.create_scan_ret(*handles)
2514
2743
 
2515
- axis = _constexpr_to_value(axis)
2744
+ axis = _unwrap_if_constexpr(axis)
2516
2745
  if axis is not None:
2517
2746
  axis = _wrap_axis(axis, len(input[0].shape))
2518
- return semantic.associative_scan(input, axis, make_combine_region, reverse, _builder)
2747
+ return _semantic.associative_scan(input, axis, make_combine_region, reverse)
2519
2748
 
2520
2749
 
2521
2750
  @_tensor_member_fn
2522
2751
  @builtin
2523
- def histogram(input, num_bins, _builder=None, _generator=None):
2752
+ def histogram(input, num_bins, mask=None, _semantic=None, _generator=None):
2524
2753
  """computes an histogram based on input tensor with num_bins bins, the bins have a width of 1 and start at 0.
2525
2754
 
2526
2755
  :param input: the input tensor
2527
2756
  :type input: Tensor
2528
2757
  :param num_bins: number of histogram bins
2529
2758
  :type num_bins: int
2759
+ :param mask: if `mask[idx]` is false, exclude `input[idx]` from histogram
2760
+ :type mask: Block of `triton.int1`, optional
2530
2761
 
2531
2762
  """
2532
- num_bins = _constexpr_to_value(num_bins)
2533
- return semantic.histogram(input, num_bins, _builder)
2763
+ num_bins = _unwrap_if_constexpr(num_bins)
2764
+ mask = _unwrap_if_constexpr(mask)
2765
+ if mask is not None:
2766
+ mask = _semantic.to_tensor(mask)
2767
+ return _semantic.histogram(input, num_bins, mask)
2534
2768
 
2535
2769
 
2536
2770
  @_tensor_member_fn
2537
2771
  @builtin
2538
- def gather(src, index, axis, _builder=None):
2772
+ def gather(src, index, axis, _semantic=None):
2539
2773
  """Gather from a tensor along a given dimension.
2540
2774
 
2541
2775
  :param src: the source tensor
@@ -2546,8 +2780,8 @@ def gather(src, index, axis, _builder=None):
2546
2780
  :type axis: int
2547
2781
 
2548
2782
  """
2549
- axis = _constexpr_to_value(axis)
2550
- return semantic.gather(src, index, axis, _builder)
2783
+ axis = _unwrap_if_constexpr(axis)
2784
+ return _semantic.gather(src, index, axis)
2551
2785
 
2552
2786
 
2553
2787
  # -----------------------
@@ -2556,15 +2790,15 @@ def gather(src, index, axis, _builder=None):
2556
2790
 
2557
2791
 
2558
2792
  @builtin
2559
- def debug_barrier(_builder=None):
2793
+ def debug_barrier(_semantic=None):
2560
2794
  '''
2561
2795
  Insert a barrier to synchronize all threads in a block.
2562
2796
  '''
2563
- return semantic.debug_barrier(_builder)
2797
+ return _semantic.debug_barrier()
2564
2798
 
2565
2799
 
2566
2800
  @builtin
2567
- def multiple_of(input, values, _builder=None):
2801
+ def multiple_of(input, values, _semantic=None):
2568
2802
  """
2569
2803
  Let the compiler know that the values in :code:`input` are all multiples of :code:`value`.
2570
2804
  """
@@ -2576,11 +2810,11 @@ def multiple_of(input, values, _builder=None):
2576
2810
  if not isinstance(d.value, int):
2577
2811
  raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]")
2578
2812
  values = [x.value for x in values]
2579
- return semantic.multiple_of(input, values)
2813
+ return _semantic.multiple_of(input, values)
2580
2814
 
2581
2815
 
2582
2816
  @builtin
2583
- def max_contiguous(input, values, _builder=None):
2817
+ def max_contiguous(input, values, _semantic=None):
2584
2818
  """
2585
2819
  Let the compiler know that the `value` first values in :code:`input` are contiguous.
2586
2820
  """
@@ -2592,11 +2826,11 @@ def max_contiguous(input, values, _builder=None):
2592
2826
  if not isinstance(d.value, int):
2593
2827
  raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]")
2594
2828
  values = [x.value for x in values]
2595
- return semantic.max_contiguous(input, values)
2829
+ return _semantic.max_contiguous(input, values)
2596
2830
 
2597
2831
 
2598
2832
  @builtin
2599
- def max_constancy(input, values, _builder=None):
2833
+ def max_constancy(input, values, _semantic=None):
2600
2834
  """
2601
2835
  Let the compiler know that the `value` first values in :code:`input` are constant.
2602
2836
 
@@ -2611,15 +2845,15 @@ def max_constancy(input, values, _builder=None):
2611
2845
  if not isinstance(d.value, int):
2612
2846
  raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]")
2613
2847
  values = [x.value for x in values]
2614
- return semantic.max_constancy(input, values)
2848
+ return _semantic.max_constancy(input, values)
2615
2849
 
2616
2850
 
2617
2851
  @builtin
2618
- def assume(cond, _builder=None):
2852
+ def assume(cond, _semantic=None):
2619
2853
  '''
2620
2854
  Allow compiler to assume the :code:`cond` is True.
2621
2855
  '''
2622
- return semantic.assume(semantic.to_tensor(cond, _builder), _builder)
2856
+ return _semantic.assume(_semantic.to_tensor(cond))
2623
2857
 
2624
2858
 
2625
2859
  # -----------------------
@@ -2628,7 +2862,7 @@ def assume(cond, _builder=None):
2628
2862
 
2629
2863
 
2630
2864
  @builtin
2631
- def static_print(*values, sep: str = " ", end: str = "\n", file=None, flush=False, _builder=None):
2865
+ def static_print(*values, sep: str = " ", end: str = "\n", file=None, flush=False, _semantic=None):
2632
2866
  '''
2633
2867
  Print the values at compile time. The parameters are the same as the builtin :code:`print`.
2634
2868
 
@@ -2644,7 +2878,7 @@ def static_print(*values, sep: str = " ", end: str = "\n", file=None, flush=Fals
2644
2878
 
2645
2879
 
2646
2880
  @builtin
2647
- def static_assert(cond, msg="", _builder=None):
2881
+ def static_assert(cond, msg="", _semantic=None):
2648
2882
  '''
2649
2883
  Assert the condition at compile time. Does not require that the :code:`TRITON_DEBUG` environment variable
2650
2884
  is set.
@@ -2658,7 +2892,7 @@ def static_assert(cond, msg="", _builder=None):
2658
2892
 
2659
2893
 
2660
2894
  @builtin
2661
- def device_print(prefix, *args, hex=False, _builder=None):
2895
+ def device_print(prefix, *args, hex=False, _semantic=None):
2662
2896
  '''
2663
2897
  Print the values at runtime from the device. String formatting does not work for runtime values, so you should
2664
2898
  provide the values you want to print as arguments. The first value must be a string, all following values must
@@ -2692,7 +2926,7 @@ def device_print(prefix, *args, hex=False, _builder=None):
2692
2926
  :param hex: print all values as hex instead of decimal
2693
2927
  '''
2694
2928
  import string
2695
- prefix = _constexpr_to_value(prefix)
2929
+ prefix = _unwrap_if_constexpr(prefix)
2696
2930
  assert isinstance(prefix, str), f"{prefix} is not string"
2697
2931
  b_ascii = True
2698
2932
  for ch in prefix:
@@ -2702,12 +2936,12 @@ def device_print(prefix, *args, hex=False, _builder=None):
2702
2936
  assert b_ascii, f"{prefix} is not an ascii string"
2703
2937
  new_args = []
2704
2938
  for arg in args:
2705
- new_args.append(semantic.to_tensor(arg, _builder))
2706
- return semantic.device_print(prefix, new_args, hex, _builder)
2939
+ new_args.append(_semantic.to_tensor(arg))
2940
+ return _semantic.device_print(prefix, new_args, hex)
2707
2941
 
2708
2942
 
2709
2943
  @builtin
2710
- def device_assert(cond, msg="", _builder=None):
2944
+ def device_assert(cond, msg="", _semantic=None):
2711
2945
  '''
2712
2946
  Assert the condition at runtime from the device. Requires that the environment variable :code:`TRITON_DEBUG`
2713
2947
  is set to a value besides :code:`0` in order for this to have any effect.
@@ -2725,13 +2959,13 @@ def device_assert(cond, msg="", _builder=None):
2725
2959
  :param cond: the condition to assert. This is required to be a boolean tensor.
2726
2960
  :param msg: the message to print if the assertion fails. This is required to be a string literal.
2727
2961
  '''
2728
- msg = _constexpr_to_value(msg)
2729
- return semantic.device_assert(semantic.to_tensor(cond, _builder), msg, _builder)
2962
+ msg = _unwrap_if_constexpr(msg)
2963
+ return _semantic.device_assert(_semantic.to_tensor(cond), msg)
2730
2964
 
2731
2965
 
2732
2966
  @builtin
2733
2967
  def inline_asm_elementwise(asm: str, constraints: str, args: Sequence, dtype: Union[dtype, Sequence[dtype]],
2734
- is_pure: bool, pack: int, _builder=None):
2968
+ is_pure: bool, pack: int, _semantic=None):
2735
2969
  '''
2736
2970
  Execute inline assembly over a tensor. Essentially, this is :code:`map`
2737
2971
  where the function is inline assembly.
@@ -2816,13 +3050,12 @@ def inline_asm_elementwise(asm: str, constraints: str, args: Sequence, dtype: Un
2816
3050
  :param dtype: the element type(s) of the returned tensor(s)
2817
3051
  :param is_pure: if true, the compiler assumes the asm block has no side-effects
2818
3052
  :param pack: the number of elements to be processed by one instance of inline assembly
2819
- :param _builder: the builder
2820
3053
  :return: one tensor or a tuple of tensors of the given dtypes
2821
3054
  '''
2822
- asm = _constexpr_to_value(asm)
2823
- constraints = _constexpr_to_value(constraints)
2824
- pack = _constexpr_to_value(pack)
2825
- is_pure = _constexpr_to_value(is_pure)
3055
+ asm = _unwrap_if_constexpr(asm)
3056
+ constraints = _unwrap_if_constexpr(constraints)
3057
+ pack = _unwrap_if_constexpr(pack)
3058
+ is_pure = _unwrap_if_constexpr(is_pure)
2826
3059
 
2827
3060
  # Wrap `dtype` in a tuple if it's not already.
2828
3061
  try:
@@ -2835,10 +3068,9 @@ def inline_asm_elementwise(asm: str, constraints: str, args: Sequence, dtype: Un
2835
3068
  dtype = typing.cast(Sequence[_DtypeClass], dtype)
2836
3069
 
2837
3070
  res_tys = dtype
2838
- if dispatch_args := [semantic.to_tensor(arg, _builder) for arg in args]:
3071
+ if dispatch_args := [_semantic.to_tensor(arg) for arg in args]:
2839
3072
  bin_op_type_checking = partial(
2840
- semantic.binary_op_type_checking_impl,
2841
- builder=_builder,
3073
+ _semantic.binary_op_type_checking_impl,
2842
3074
  arithmetic_check=False,
2843
3075
  allow_lhs_ptr=True,
2844
3076
  allow_rhs_ptr=True,
@@ -2851,9 +3083,10 @@ def inline_asm_elementwise(asm: str, constraints: str, args: Sequence, dtype: Un
2851
3083
  # Change the shape of each argument based on the broadcast shape
2852
3084
  for i, item in enumerate(dispatch_args):
2853
3085
  dispatch_args[i], _ = bin_op_type_checking(item, broadcast_arg)
2854
- res_tys = [block_type(dt, broadcast_arg.shape) for dt in dtype]
3086
+ res_tys = [broadcast_arg.type.with_element_ty(dt) for dt in dtype]
2855
3087
  handles = [t.handle for t in dispatch_args]
2856
- call = _builder.create_inline_asm(asm, constraints, handles, [ty.to_ir(_builder) for ty in res_tys], is_pure, pack)
3088
+ builder = _semantic.builder
3089
+ call = builder.create_inline_asm(asm, constraints, handles, [ty.to_ir(builder) for ty in res_tys], is_pure, pack)
2857
3090
 
2858
3091
  if not has_multiple_outputs:
2859
3092
  return tensor(call.get_result(0), res_tys[0])
@@ -2905,6 +3138,22 @@ class static_range:
2905
3138
  raise RuntimeError("static_range can only be used in @triton.jit'd functions")
2906
3139
 
2907
3140
 
3141
+ class async_task:
3142
+ """
3143
+ Context manager to run code fragments asynchronously.
3144
+ """
3145
+
3146
+ def __init__(self, task_ids, _builder=None):
3147
+ self.task_ids = list({_unwrap_if_constexpr(tid) for tid in task_ids})
3148
+ self.builder = _builder
3149
+
3150
+ def __enter__(self):
3151
+ self.builder.set_async_task_ids(self.task_ids)
3152
+
3153
+ def __exit__(self, exc_type, exc_value, traceback):
3154
+ self.builder.unset_async_task_ids()
3155
+
3156
+
2908
3157
  class range:
2909
3158
  """
2910
3159
  Iterator that counts upward forever.
@@ -2936,10 +3185,18 @@ class range:
2936
3185
  :param flatten: automatically flatten the loop nest starting at this loop to
2937
3186
  create a single flattened loop. The compiler will try to pipeline the
2938
3187
  flattened loop which can avoid stage stalling.
3188
+ :param warp_specialize: Enable automatic warp specialization on the loop.
3189
+ The compiler will attempt to partition memory, MMA, and vector
3190
+ operations in the loop into separate async partitions. This will
3191
+ increase the total number of warps required by the kernel.
3192
+
3193
+ Note that warp specialization is only supported on Blackwell GPUs and
3194
+ only works on simple matmul loops. Support for arbitrary loops will be
3195
+ expanded over time.
2939
3196
  """
2940
3197
 
2941
3198
  def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_factor=None,
2942
- disallow_acc_multi_buffer=False, flatten=False):
3199
+ disallow_acc_multi_buffer=False, flatten=False, warp_specialize=False):
2943
3200
  if step is None:
2944
3201
  self.step = constexpr(1)
2945
3202
  else:
@@ -2954,6 +3211,7 @@ class range:
2954
3211
  self.loop_unroll_factor = loop_unroll_factor
2955
3212
  self.disallow_acc_multi_buffer = disallow_acc_multi_buffer
2956
3213
  self.flatten = flatten
3214
+ self.warp_specialize = warp_specialize
2957
3215
 
2958
3216
  def __iter__(self):
2959
3217
  raise RuntimeError("tl.range can only be used in @triton.jit'd functions")
@@ -2968,7 +3226,7 @@ class range:
2968
3226
 
2969
3227
 
2970
3228
  def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, ret_shape: tuple,
2971
- is_pure: bool, _builder=None):
3229
+ is_pure: bool, _semantic):
2972
3230
  '''
2973
3231
  Dispatch a function to a library
2974
3232
  :param func: the function to dispatch
@@ -2977,7 +3235,6 @@ def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dic
2977
3235
  :param args: the arguments of the function
2978
3236
  :param arg_type_symbol_dict: the type of the arguments
2979
3237
  :param ret_shape: the shape of the return value
2980
- :param _builder: the builder
2981
3238
  :return: the return value of the function
2982
3239
  '''
2983
3240
  if len(arg_type_symbol_dict) == 0:
@@ -3007,12 +3264,13 @@ def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dic
3007
3264
  ret_type = arg_type_symbol_dict[arg_types][1]
3008
3265
  if ret_shape:
3009
3266
  ret_type = block_type(ret_type, ret_shape)
3010
- return tensor(func(lib_name, lib_path, symbol, arg_list, ret_type.to_ir(_builder), is_pure), ret_type)
3267
+ builder = _semantic.builder
3268
+ return tensor(func(lib_name, lib_path, symbol, arg_list, ret_type.to_ir(builder), is_pure), ret_type)
3011
3269
 
3012
3270
 
3013
3271
  @builtin
3014
3272
  def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, is_pure: bool,
3015
- _builder=None):
3273
+ _semantic=None):
3016
3274
  '''
3017
3275
  Dispatch an elementwise function to a library
3018
3276
  :param lib_name: the name of the library
@@ -3020,7 +3278,6 @@ def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol
3020
3278
  :param args: the arguments of the function
3021
3279
  :param arg_type_symbol_dict: the type of the arguments
3022
3280
  :param is_pure: whether the function is pure
3023
- :param _builder: the builder
3024
3281
  :return: the return value of the function
3025
3282
  '''
3026
3283
  dispatch_args = args.copy()
@@ -3028,7 +3285,7 @@ def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol
3028
3285
  ret_shape = None
3029
3286
  arg_types = []
3030
3287
  for i in builtins.range(len(dispatch_args)):
3031
- dispatch_args[i] = semantic.to_tensor(dispatch_args[i], _builder)
3288
+ dispatch_args[i] = _semantic.to_tensor(dispatch_args[i])
3032
3289
  arg_types.append(dispatch_args[i].dtype)
3033
3290
  if dispatch_args[i].type.is_block():
3034
3291
  all_scalar = False
@@ -3041,26 +3298,26 @@ def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol
3041
3298
  broadcast_arg = dispatch_args[0]
3042
3299
  # Get the broadcast shape over all the arguments
3043
3300
  for item in dispatch_args:
3044
- _, broadcast_arg = semantic.binary_op_type_checking_impl(item, broadcast_arg, _builder,
3045
- arithmetic_check=arithmetic_check)
3301
+ _, broadcast_arg = _semantic.binary_op_type_checking_impl(item, broadcast_arg,
3302
+ arithmetic_check=arithmetic_check)
3046
3303
  # Change the shape of each argument based on the broadcast shape
3047
3304
  for i in builtins.range(len(dispatch_args)):
3048
- dispatch_args[i], _ = semantic.binary_op_type_checking_impl(dispatch_args[i], broadcast_arg, _builder,
3049
- arithmetic_check=arithmetic_check)
3305
+ dispatch_args[i], _ = _semantic.binary_op_type_checking_impl(dispatch_args[i], broadcast_arg,
3306
+ arithmetic_check=arithmetic_check)
3050
3307
  if not all_scalar:
3051
3308
  ret_shape = broadcast_arg.shape
3052
- func = _builder.create_extern_elementwise
3053
- return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_shape, is_pure, _builder)
3309
+ func = _semantic.builder.create_extern_elementwise
3310
+ return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_shape, is_pure, _semantic)
3054
3311
 
3055
3312
 
3056
- def binary_op_type_legalization(lhs, rhs, builder):
3313
+ def binary_op_type_legalization(lhs, rhs, semantic):
3057
3314
  '''
3058
3315
  Convert both operands to a single common type
3059
3316
  :param lhs: the left operand
3060
3317
  :param rhs: the right operand
3061
3318
  :param builder: the builder
3062
3319
  '''
3063
- return semantic.binary_op_type_checking_impl(lhs, rhs, builder)
3320
+ return semantic.binary_op_type_checking_impl(lhs, rhs)
3064
3321
 
3065
3322
 
3066
3323
  def extern(fn):