triton-windows 3.3.1.post19__cp39-cp39-win_amd64.whl → 3.4.0.post20__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of triton-windows might be problematic. Click here for more details.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +4 -1
- triton/_filecheck.py +87 -0
- triton/_internal_testing.py +26 -15
- triton/_utils.py +110 -21
- triton/backends/__init__.py +20 -23
- triton/backends/amd/__init__.py +0 -0
- triton/backends/amd/compiler.py +112 -78
- triton/backends/amd/driver.c +5 -2
- triton/backends/amd/driver.py +149 -47
- triton/backends/compiler.py +7 -21
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +92 -93
- triton/backends/nvidia/driver.c +90 -98
- triton/backends/nvidia/driver.py +303 -128
- triton/compiler/code_generator.py +212 -111
- triton/compiler/compiler.py +110 -25
- triton/experimental/__init__.py +0 -0
- triton/experimental/gluon/__init__.py +4 -0
- triton/experimental/gluon/_compiler.py +0 -0
- triton/experimental/gluon/_runtime.py +99 -0
- triton/experimental/gluon/language/__init__.py +18 -0
- triton/experimental/gluon/language/_core.py +312 -0
- triton/experimental/gluon/language/_layouts.py +230 -0
- triton/experimental/gluon/language/_math.py +12 -0
- triton/experimental/gluon/language/_semantic.py +287 -0
- triton/experimental/gluon/language/_standard.py +47 -0
- triton/experimental/gluon/language/nvidia/__init__.py +4 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +202 -0
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +32 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +11 -0
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +51 -0
- triton/experimental/gluon/language/nvidia/hopper/tma.py +96 -0
- triton/experimental/gluon/nvidia/__init__.py +4 -0
- triton/experimental/gluon/nvidia/blackwell.py +3 -0
- triton/experimental/gluon/nvidia/hopper.py +40 -0
- triton/knobs.py +481 -0
- triton/language/__init__.py +39 -14
- triton/language/core.py +794 -537
- triton/language/extra/cuda/__init__.py +10 -7
- triton/language/extra/cuda/gdc.py +42 -0
- triton/language/extra/cuda/libdevice.py +394 -394
- triton/language/extra/cuda/utils.py +21 -21
- triton/language/extra/hip/libdevice.py +113 -104
- triton/language/math.py +65 -66
- triton/language/random.py +12 -2
- triton/language/semantic.py +1706 -1770
- triton/language/standard.py +116 -51
- triton/runtime/autotuner.py +117 -59
- triton/runtime/build.py +76 -12
- triton/runtime/cache.py +18 -47
- triton/runtime/driver.py +32 -29
- triton/runtime/interpreter.py +72 -35
- triton/runtime/jit.py +146 -110
- triton/testing.py +16 -12
- triton/tools/disasm.py +3 -4
- triton/tools/tensor_descriptor.py +36 -0
- triton/windows_utils.py +14 -6
- {triton_windows-3.3.1.post19.dist-info → triton_windows-3.4.0.post20.dist-info}/METADATA +7 -2
- triton_windows-3.4.0.post20.dist-info/RECORD +186 -0
- triton_windows-3.4.0.post20.dist-info/entry_points.txt +3 -0
- triton_windows-3.4.0.post20.dist-info/licenses/LICENSE +23 -0
- triton_windows-3.4.0.post20.dist-info/top_level.txt +1 -0
- triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
- triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
- triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
- triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
- triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
- triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
- triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
- triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
- triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
- triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
- triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
- triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
- triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
- triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
- triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
- triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
- triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
- triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
- triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
- triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
- triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
- triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
- triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
- triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
- triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
- triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
- triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
- triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
- triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
- triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
- triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
- triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
- triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
- triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
- triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
- triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
- triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
- triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
- triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
- triton/backends/amd/include/hip/channel_descriptor.h +0 -39
- triton/backends/amd/include/hip/device_functions.h +0 -38
- triton/backends/amd/include/hip/driver_types.h +0 -468
- triton/backends/amd/include/hip/hip_bf16.h +0 -36
- triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
- triton/backends/amd/include/hip/hip_common.h +0 -100
- triton/backends/amd/include/hip/hip_complex.h +0 -38
- triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
- triton/backends/amd/include/hip/hip_deprecated.h +0 -95
- triton/backends/amd/include/hip/hip_ext.h +0 -161
- triton/backends/amd/include/hip/hip_fp16.h +0 -36
- triton/backends/amd/include/hip/hip_fp8.h +0 -33
- triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
- triton/backends/amd/include/hip/hip_hcc.h +0 -24
- triton/backends/amd/include/hip/hip_math_constants.h +0 -36
- triton/backends/amd/include/hip/hip_profile.h +0 -27
- triton/backends/amd/include/hip/hip_runtime.h +0 -75
- triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
- triton/backends/amd/include/hip/hip_texture_types.h +0 -29
- triton/backends/amd/include/hip/hip_vector_types.h +0 -41
- triton/backends/amd/include/hip/hip_version.h +0 -17
- triton/backends/amd/include/hip/hiprtc.h +0 -421
- triton/backends/amd/include/hip/library_types.h +0 -78
- triton/backends/amd/include/hip/math_functions.h +0 -42
- triton/backends/amd/include/hip/surface_types.h +0 -63
- triton/backends/amd/include/hip/texture_types.h +0 -194
- triton/backends/amd/include/hsa/Brig.h +0 -1131
- triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
- triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
- triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
- triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
- triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
- triton/backends/amd/include/hsa/hsa.h +0 -5738
- triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
- triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
- triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
- triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
- triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
- triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
- triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
- triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
- triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
- triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
- triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
- triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
- triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
- triton/backends/amd/include/roctracer/roctracer.h +0 -779
- triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
- triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
- triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
- triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
- triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
- triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
- triton/backends/amd/include/roctracer/roctx.h +0 -229
- triton/language/_utils.py +0 -21
- triton/language/extra/cuda/_experimental_tma.py +0 -106
- triton/tools/experimental_descriptor.py +0 -32
- triton_windows-3.3.1.post19.dist-info/RECORD +0 -260
- triton_windows-3.3.1.post19.dist-info/top_level.txt +0 -14
- {triton_windows-3.3.1.post19.dist-info → triton_windows-3.4.0.post20.dist-info}/WHEEL +0 -0
triton/language/core.py
CHANGED
|
@@ -6,14 +6,14 @@ from enum import Enum
|
|
|
6
6
|
from functools import partial, wraps
|
|
7
7
|
import typing
|
|
8
8
|
from typing import Union, Callable, List, Sequence, TypeVar, Optional, Tuple
|
|
9
|
+
from dataclasses import dataclass
|
|
9
10
|
import builtins
|
|
10
|
-
from ..
|
|
11
|
+
from .. import knobs
|
|
12
|
+
from ..runtime.jit import jit, JITFunction
|
|
11
13
|
import inspect
|
|
12
|
-
import os
|
|
13
14
|
|
|
14
15
|
from .._C.libtriton import ir
|
|
15
|
-
from
|
|
16
|
-
from ._utils import TRITON_MAX_TENSOR_NUMEL, validate_block_shape
|
|
16
|
+
from .._utils import TRITON_MAX_TENSOR_NUMEL, validate_block_shape, get_primitive_bitwidth
|
|
17
17
|
|
|
18
18
|
T = TypeVar('T')
|
|
19
19
|
|
|
@@ -22,15 +22,23 @@ TRITON_BUILTIN = "__triton_builtin__"
|
|
|
22
22
|
PropagateNan = ir.PROPAGATE_NAN
|
|
23
23
|
|
|
24
24
|
|
|
25
|
+
def must_use_result(x, s=True):
|
|
26
|
+
"""If the result of this function is unused, throw an error."""
|
|
27
|
+
if isinstance(x, str):
|
|
28
|
+
return (lambda fn: must_use_result(fn, x))
|
|
29
|
+
x._must_use_result = s
|
|
30
|
+
return x
|
|
31
|
+
|
|
32
|
+
|
|
25
33
|
def builtin(fn: T) -> T:
|
|
26
34
|
"""Mark a function as a builtin."""
|
|
27
35
|
assert callable(fn)
|
|
28
36
|
|
|
29
37
|
@wraps(fn)
|
|
30
38
|
def wrapper(*args, **kwargs):
|
|
31
|
-
if "
|
|
39
|
+
if "_semantic" not in kwargs or kwargs["_semantic"] is None:
|
|
32
40
|
raise ValueError("Did you forget to add @triton.jit ? "
|
|
33
|
-
"(`
|
|
41
|
+
"(`_semantic` argument must be provided outside of JIT functions.)")
|
|
34
42
|
return fn(*args, **kwargs)
|
|
35
43
|
|
|
36
44
|
setattr(wrapper, TRITON_BUILTIN, True)
|
|
@@ -53,8 +61,8 @@ def _tensor_member_fn(fn: T) -> T:
|
|
|
53
61
|
"""
|
|
54
62
|
assert callable(fn)
|
|
55
63
|
orig_sig = inspect.signature(fn)
|
|
56
|
-
# Does fn take args other than
|
|
57
|
-
has_args = len(orig_sig.parameters.keys() - {"
|
|
64
|
+
# Does fn take args other than _semantic, _generator, and the tensor itself?
|
|
65
|
+
has_args = len(orig_sig.parameters.keys() - {"_semantic", "_generator"}) > 1
|
|
58
66
|
|
|
59
67
|
if not fn.__doc__:
|
|
60
68
|
fn.__doc__ = ""
|
|
@@ -78,7 +86,7 @@ def _tensor_member_fn(fn: T) -> T:
|
|
|
78
86
|
if is_builtin(fn):
|
|
79
87
|
setattr(wrapper, TRITON_BUILTIN, True)
|
|
80
88
|
|
|
81
|
-
setattr(tensor, fn.__name__, wrapper)
|
|
89
|
+
setattr(tensor, fn.__name__, fn if isinstance(fn, JITFunction) else wrapper)
|
|
82
90
|
return fn
|
|
83
91
|
|
|
84
92
|
|
|
@@ -110,8 +118,8 @@ def is_builtin(fn) -> bool:
|
|
|
110
118
|
|
|
111
119
|
|
|
112
120
|
@builtin
|
|
113
|
-
def to_tensor(x,
|
|
114
|
-
return
|
|
121
|
+
def to_tensor(x, _semantic=None):
|
|
122
|
+
return _semantic.to_tensor(x)
|
|
115
123
|
|
|
116
124
|
|
|
117
125
|
# -----------------------
|
|
@@ -130,7 +138,62 @@ class const:
|
|
|
130
138
|
pass
|
|
131
139
|
|
|
132
140
|
|
|
133
|
-
class
|
|
141
|
+
class base_value:
|
|
142
|
+
"""Base class of values that exist in the triton IR (i.e. not constexprs).
|
|
143
|
+
"""
|
|
144
|
+
type: base_type
|
|
145
|
+
|
|
146
|
+
def _flatten_ir(self, handles: List[ir.value]) -> None:
|
|
147
|
+
"""Flatten frontend value into a sequence of mlir handles, which are appended
|
|
148
|
+
to the output list
|
|
149
|
+
"""
|
|
150
|
+
raise NotImplementedError
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
class base_type:
|
|
154
|
+
|
|
155
|
+
def __eq__(self, other):
|
|
156
|
+
raise NotImplementedError("Types must implement __eq__")
|
|
157
|
+
|
|
158
|
+
def __ne__(self, other):
|
|
159
|
+
return not (self == other)
|
|
160
|
+
|
|
161
|
+
def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[base_value, int]:
|
|
162
|
+
"""Build a frontend value with the current dtype, wrapping a list of existing handles.
|
|
163
|
+
cursor is the index of the first handle relevant to this value, and the function
|
|
164
|
+
should return the updated cursor position after any handles consumed by the created value.
|
|
165
|
+
"""
|
|
166
|
+
raise NotImplementedError
|
|
167
|
+
|
|
168
|
+
def mangle(self) -> str:
|
|
169
|
+
raise NotImplementedError(f"NYI: Type mangling for type {self.__class__}")
|
|
170
|
+
|
|
171
|
+
def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
|
|
172
|
+
raise NotImplementedError
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
class constexpr_type(base_type):
|
|
176
|
+
|
|
177
|
+
def __init__(self, value):
|
|
178
|
+
self.value = value
|
|
179
|
+
|
|
180
|
+
def __eq__(self, other):
|
|
181
|
+
return self.value == other.value
|
|
182
|
+
|
|
183
|
+
def __repr__(self) -> str:
|
|
184
|
+
return f"constexpr[{self.value}]"
|
|
185
|
+
|
|
186
|
+
def mangle(self) -> str:
|
|
187
|
+
return repr(self)
|
|
188
|
+
|
|
189
|
+
def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
|
|
190
|
+
return
|
|
191
|
+
|
|
192
|
+
def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[base_value, int]:
|
|
193
|
+
return constexpr(self.value), cursor
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
class constexpr(base_value):
|
|
134
197
|
"""
|
|
135
198
|
This class is used to store a value that is known at compile-time.
|
|
136
199
|
"""
|
|
@@ -140,80 +203,83 @@ class constexpr:
|
|
|
140
203
|
self.value = value.value
|
|
141
204
|
else:
|
|
142
205
|
self.value = value
|
|
143
|
-
self.type =
|
|
206
|
+
self.type = constexpr_type(value)
|
|
144
207
|
|
|
145
208
|
def __repr__(self) -> str:
|
|
146
209
|
return f"constexpr[{self.value}]"
|
|
147
210
|
|
|
211
|
+
def _flatten_ir(self, handles: List[ir.value]) -> None:
|
|
212
|
+
return
|
|
213
|
+
|
|
148
214
|
def __index__(self):
|
|
149
215
|
return self.value
|
|
150
216
|
|
|
151
217
|
# In interpreter mode, constant values are not wrapped in constexpr,
|
|
152
218
|
# and therefore do not have a .value attribute.
|
|
153
|
-
# As a result, from here and below, we need to call the
|
|
219
|
+
# As a result, from here and below, we need to call the _unwrap_if_constexpr
|
|
154
220
|
# function to obtain either constexpr.value or the value itself.
|
|
155
221
|
def __add__(self, other):
|
|
156
|
-
return constexpr(self.value +
|
|
222
|
+
return constexpr(self.value + _unwrap_if_constexpr(other))
|
|
157
223
|
|
|
158
224
|
def __radd__(self, other):
|
|
159
|
-
return constexpr(
|
|
225
|
+
return constexpr(_unwrap_if_constexpr(other) + self.value)
|
|
160
226
|
|
|
161
227
|
def __sub__(self, other):
|
|
162
|
-
return constexpr(self.value -
|
|
228
|
+
return constexpr(self.value - _unwrap_if_constexpr(other))
|
|
163
229
|
|
|
164
230
|
def __rsub__(self, other):
|
|
165
|
-
return constexpr(
|
|
231
|
+
return constexpr(_unwrap_if_constexpr(other) - self.value)
|
|
166
232
|
|
|
167
233
|
def __mul__(self, other):
|
|
168
|
-
return constexpr(self.value *
|
|
234
|
+
return constexpr(self.value * _unwrap_if_constexpr(other))
|
|
169
235
|
|
|
170
236
|
def __mod__(self, other):
|
|
171
|
-
return constexpr(self.value %
|
|
237
|
+
return constexpr(self.value % _unwrap_if_constexpr(other))
|
|
172
238
|
|
|
173
239
|
def __rmul__(self, other):
|
|
174
|
-
return constexpr(
|
|
240
|
+
return constexpr(_unwrap_if_constexpr(other) * self.value)
|
|
175
241
|
|
|
176
242
|
def __truediv__(self, other):
|
|
177
|
-
return constexpr(self.value /
|
|
243
|
+
return constexpr(self.value / _unwrap_if_constexpr(other))
|
|
178
244
|
|
|
179
245
|
def __rtruediv__(self, other):
|
|
180
|
-
return constexpr(
|
|
246
|
+
return constexpr(_unwrap_if_constexpr(other) / self.value)
|
|
181
247
|
|
|
182
248
|
def __floordiv__(self, other):
|
|
183
|
-
return constexpr(self.value //
|
|
249
|
+
return constexpr(self.value // _unwrap_if_constexpr(other))
|
|
184
250
|
|
|
185
251
|
def __rfloordiv__(self, other):
|
|
186
|
-
return constexpr(
|
|
252
|
+
return constexpr(_unwrap_if_constexpr(other) // self.value)
|
|
187
253
|
|
|
188
254
|
def __gt__(self, other):
|
|
189
|
-
return constexpr(self.value >
|
|
255
|
+
return constexpr(self.value > _unwrap_if_constexpr(other))
|
|
190
256
|
|
|
191
257
|
def __rgt__(self, other):
|
|
192
|
-
return constexpr(
|
|
258
|
+
return constexpr(_unwrap_if_constexpr(other) > self.value)
|
|
193
259
|
|
|
194
260
|
def __ge__(self, other):
|
|
195
|
-
return constexpr(self.value >=
|
|
261
|
+
return constexpr(self.value >= _unwrap_if_constexpr(other))
|
|
196
262
|
|
|
197
263
|
def __rge__(self, other):
|
|
198
|
-
return constexpr(
|
|
264
|
+
return constexpr(_unwrap_if_constexpr(other) >= self.value)
|
|
199
265
|
|
|
200
266
|
def __lt__(self, other):
|
|
201
|
-
return constexpr(self.value <
|
|
267
|
+
return constexpr(self.value < _unwrap_if_constexpr(other))
|
|
202
268
|
|
|
203
269
|
def __rlt__(self, other):
|
|
204
|
-
return constexpr(
|
|
270
|
+
return constexpr(_unwrap_if_constexpr(other) < self.value)
|
|
205
271
|
|
|
206
272
|
def __le__(self, other):
|
|
207
|
-
return constexpr(self.value <=
|
|
273
|
+
return constexpr(self.value <= _unwrap_if_constexpr(other))
|
|
208
274
|
|
|
209
275
|
def __rle__(self, other):
|
|
210
|
-
return constexpr(
|
|
276
|
+
return constexpr(_unwrap_if_constexpr(other) <= self.value)
|
|
211
277
|
|
|
212
278
|
def __eq__(self, other):
|
|
213
|
-
return constexpr(self.value ==
|
|
279
|
+
return constexpr(self.value == _unwrap_if_constexpr(other))
|
|
214
280
|
|
|
215
281
|
def __ne__(self, other):
|
|
216
|
-
return constexpr(self.value !=
|
|
282
|
+
return constexpr(self.value != _unwrap_if_constexpr(other))
|
|
217
283
|
|
|
218
284
|
def __bool__(self):
|
|
219
285
|
return bool(self.value)
|
|
@@ -222,19 +288,19 @@ class constexpr:
|
|
|
222
288
|
return constexpr(-self.value)
|
|
223
289
|
|
|
224
290
|
def __and__(self, other):
|
|
225
|
-
return constexpr(self.value &
|
|
291
|
+
return constexpr(self.value & _unwrap_if_constexpr(other))
|
|
226
292
|
|
|
227
293
|
def logical_and(self, other):
|
|
228
|
-
return constexpr(self.value and
|
|
294
|
+
return constexpr(self.value and _unwrap_if_constexpr(other))
|
|
229
295
|
|
|
230
296
|
def __or__(self, other):
|
|
231
|
-
return constexpr(self.value |
|
|
297
|
+
return constexpr(self.value | _unwrap_if_constexpr(other))
|
|
232
298
|
|
|
233
299
|
def __xor__(self, other):
|
|
234
|
-
return constexpr(self.value ^
|
|
300
|
+
return constexpr(self.value ^ _unwrap_if_constexpr(other))
|
|
235
301
|
|
|
236
302
|
def logical_or(self, other):
|
|
237
|
-
return constexpr(self.value or
|
|
303
|
+
return constexpr(self.value or _unwrap_if_constexpr(other))
|
|
238
304
|
|
|
239
305
|
def __pos__(self):
|
|
240
306
|
return constexpr(+self.value)
|
|
@@ -243,16 +309,16 @@ class constexpr:
|
|
|
243
309
|
return constexpr(~self.value)
|
|
244
310
|
|
|
245
311
|
def __pow__(self, other):
|
|
246
|
-
return constexpr(self.value**
|
|
312
|
+
return constexpr(self.value**_unwrap_if_constexpr(other))
|
|
247
313
|
|
|
248
314
|
def __rpow__(self, other):
|
|
249
|
-
return constexpr(
|
|
315
|
+
return constexpr(_unwrap_if_constexpr(other)**self.value)
|
|
250
316
|
|
|
251
317
|
def __rshift__(self, other):
|
|
252
|
-
return constexpr(self.value >>
|
|
318
|
+
return constexpr(self.value >> _unwrap_if_constexpr(other))
|
|
253
319
|
|
|
254
320
|
def __lshift__(self, other):
|
|
255
|
-
return constexpr(self.value <<
|
|
321
|
+
return constexpr(self.value << _unwrap_if_constexpr(other))
|
|
256
322
|
|
|
257
323
|
def __not__(self):
|
|
258
324
|
return constexpr(not self.value)
|
|
@@ -263,14 +329,57 @@ class constexpr:
|
|
|
263
329
|
def __call__(self, *args, **kwds):
|
|
264
330
|
return self.value(*args, **kwds)
|
|
265
331
|
|
|
332
|
+
def __getitem__(self, *args):
|
|
333
|
+
args = (_unwrap_if_constexpr(x) for x in _normalize_tuple(args))
|
|
334
|
+
return self.value.__getitem__(*args)
|
|
335
|
+
|
|
336
|
+
|
|
337
|
+
def constexpr_function(f):
|
|
338
|
+
"""
|
|
339
|
+
Wraps an arbitrary Python function so that it can be called at
|
|
340
|
+
compile-time on constexpr arguments in a Triton function and
|
|
341
|
+
returns a constexpr result.
|
|
342
|
+
"""
|
|
343
|
+
|
|
344
|
+
@wraps(f)
|
|
345
|
+
def wrapper(*args, _semantic=None, **kwargs):
|
|
346
|
+
# de-constexpr arguments and discard the _semantic keyword argument:
|
|
347
|
+
args = [_unwrap_if_constexpr(x) for x in args]
|
|
348
|
+
kwargs = {k: _unwrap_if_constexpr(v) for (k, v) in kwargs.items()}
|
|
349
|
+
|
|
350
|
+
# call the raw Python function f:
|
|
351
|
+
res = f(*args, **kwargs)
|
|
352
|
+
|
|
353
|
+
# convert result back to a Triton constexpr:
|
|
354
|
+
return constexpr(res)
|
|
355
|
+
|
|
356
|
+
# disguise the function as a Triton builtin to avoid raising an error
|
|
357
|
+
# that we're calling a non-JIT function from within a Triton kernel:
|
|
358
|
+
wrapper.__triton_builtin__ = True
|
|
359
|
+
wrapper.__module__ = constexpr_function.__module__
|
|
360
|
+
return wrapper
|
|
361
|
+
|
|
266
362
|
|
|
267
363
|
CONSTEXPR_0 = constexpr(0)
|
|
268
364
|
|
|
269
365
|
|
|
270
366
|
def _unwrap_if_constexpr(o):
|
|
367
|
+
if isinstance(o, list):
|
|
368
|
+
return [_unwrap_if_constexpr(x) for x in o]
|
|
369
|
+
if isinstance(o, builtins.tuple):
|
|
370
|
+
return builtins.tuple(_unwrap_if_constexpr(x) for x in o)
|
|
371
|
+
if isinstance(o, tuple):
|
|
372
|
+
return tuple(_unwrap_if_constexpr(x) for x in o)
|
|
271
373
|
return o.value if isinstance(o, constexpr) else o
|
|
272
374
|
|
|
273
375
|
|
|
376
|
+
def _normalize_tuple(t):
|
|
377
|
+
normalized_tuple = _unwrap_if_constexpr(t)
|
|
378
|
+
if isinstance(normalized_tuple, (list, builtins.tuple)):
|
|
379
|
+
normalized_tuple = tuple(normalized_tuple)
|
|
380
|
+
return normalized_tuple
|
|
381
|
+
|
|
382
|
+
|
|
274
383
|
def check_bit_width(value, shift_value):
|
|
275
384
|
if isinstance(value, tensor) and isinstance(shift_value, constexpr):
|
|
276
385
|
bitwidth = value.type.scalar.primitive_bitwidth
|
|
@@ -280,34 +389,6 @@ def check_bit_width(value, shift_value):
|
|
|
280
389
|
)
|
|
281
390
|
|
|
282
391
|
|
|
283
|
-
class base_value:
|
|
284
|
-
"""Base class of values that exist in the triton IR (i.e. not constexprs).
|
|
285
|
-
"""
|
|
286
|
-
type: base_type
|
|
287
|
-
|
|
288
|
-
def _flatten_ir(self, handles: List[ir.value]) -> None:
|
|
289
|
-
"""Flatten frontend value into a sequence of mlir handles, which are appended
|
|
290
|
-
to the output list
|
|
291
|
-
"""
|
|
292
|
-
raise NotImplementedError
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
class base_type:
|
|
296
|
-
|
|
297
|
-
def __eq__(self, other):
|
|
298
|
-
raise NotImplementedError("Types must implement __eq__")
|
|
299
|
-
|
|
300
|
-
def __ne__(self, other):
|
|
301
|
-
return not (self == other)
|
|
302
|
-
|
|
303
|
-
def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[base_value, int]:
|
|
304
|
-
"""Build a frontend value with the current dtype, wrapping a list of existing handles.
|
|
305
|
-
cursor is the index of the first handle relevant to this value, and the function
|
|
306
|
-
should return the updated cursor position after any handles consumed by the created value.
|
|
307
|
-
"""
|
|
308
|
-
raise NotImplementedError
|
|
309
|
-
|
|
310
|
-
|
|
311
392
|
# -----------------------
|
|
312
393
|
# dtype
|
|
313
394
|
# -----------------------
|
|
@@ -333,55 +414,44 @@ class dtype(base_type):
|
|
|
333
414
|
name = _unwrap_if_constexpr(name)
|
|
334
415
|
self.name = name
|
|
335
416
|
assert name in dtype.SINT_TYPES + dtype.UINT_TYPES + dtype.FP_TYPES + dtype.OTHER_TYPES, name
|
|
417
|
+
self.primitive_bitwidth = get_primitive_bitwidth(name)
|
|
418
|
+
self.itemsize = self.primitive_bitwidth // 8
|
|
336
419
|
if name in dtype.SINT_TYPES:
|
|
337
420
|
self.int_signedness = dtype.SIGNEDNESS.SIGNED
|
|
338
|
-
self.int_bitwidth =
|
|
339
|
-
self.primitive_bitwidth = self.int_bitwidth
|
|
421
|
+
self.int_bitwidth = self.primitive_bitwidth
|
|
340
422
|
elif name in dtype.UINT_TYPES:
|
|
341
423
|
self.int_signedness = dtype.SIGNEDNESS.UNSIGNED
|
|
342
|
-
self.int_bitwidth =
|
|
343
|
-
self.primitive_bitwidth = self.int_bitwidth
|
|
424
|
+
self.int_bitwidth = self.primitive_bitwidth
|
|
344
425
|
elif name in dtype.FP_TYPES:
|
|
345
426
|
if name == 'fp8e4b15':
|
|
346
427
|
self.fp_mantissa_width = 3
|
|
347
|
-
self.primitive_bitwidth = 8
|
|
348
428
|
self.exponent_bias = 15
|
|
349
429
|
elif name == 'fp8e4nv':
|
|
350
430
|
self.fp_mantissa_width = 3
|
|
351
|
-
self.primitive_bitwidth = 8
|
|
352
431
|
self.exponent_bias = 7
|
|
353
432
|
elif name == 'fp8e4b8':
|
|
354
433
|
self.fp_mantissa_width = 3
|
|
355
|
-
self.primitive_bitwidth = 8
|
|
356
434
|
self.exponent_bias = 8
|
|
357
435
|
elif name == 'fp8e5':
|
|
358
436
|
self.fp_mantissa_width = 2
|
|
359
|
-
self.primitive_bitwidth = 8
|
|
360
437
|
self.exponent_bias = 15
|
|
361
438
|
elif name == 'fp8e5b16':
|
|
362
439
|
self.fp_mantissa_width = 2
|
|
363
|
-
self.primitive_bitwidth = 8
|
|
364
440
|
self.exponent_bias = 16
|
|
365
441
|
elif name == 'fp16':
|
|
366
442
|
self.fp_mantissa_width = 10
|
|
367
|
-
self.primitive_bitwidth = 16
|
|
368
443
|
self.exponent_bias = 15
|
|
369
444
|
elif name == 'bf16':
|
|
370
445
|
self.fp_mantissa_width = 7
|
|
371
|
-
self.primitive_bitwidth = 16
|
|
372
446
|
self.exponent_bias = 127
|
|
373
447
|
elif name == 'fp32':
|
|
374
448
|
self.fp_mantissa_width = 23
|
|
375
|
-
self.primitive_bitwidth = 32
|
|
376
449
|
self.exponent_bias = 127
|
|
377
450
|
elif name == 'fp64':
|
|
378
451
|
self.fp_mantissa_width = 52
|
|
379
|
-
self.primitive_bitwidth = 64
|
|
380
452
|
self.exponent_bias = 1023
|
|
381
453
|
else:
|
|
382
454
|
raise RuntimeError(f'Unsupported floating-point type {name}')
|
|
383
|
-
elif name == 'void':
|
|
384
|
-
self.primitive_bitwidth = 0
|
|
385
455
|
|
|
386
456
|
def is_fp8(self):
|
|
387
457
|
return 'fp8' in self.name
|
|
@@ -502,10 +572,6 @@ class dtype(base_type):
|
|
|
502
572
|
def is_const():
|
|
503
573
|
return False
|
|
504
574
|
|
|
505
|
-
@staticmethod
|
|
506
|
-
def is_tuple():
|
|
507
|
-
return False
|
|
508
|
-
|
|
509
575
|
def __eq__(self, other: dtype):
|
|
510
576
|
if not isinstance(other, dtype):
|
|
511
577
|
return False
|
|
@@ -518,13 +584,14 @@ class dtype(base_type):
|
|
|
518
584
|
def scalar(self):
|
|
519
585
|
return self
|
|
520
586
|
|
|
587
|
+
def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
|
|
588
|
+
out.append(self.to_ir(builder))
|
|
589
|
+
|
|
521
590
|
def to_ir(self, builder: ir.builder) -> ir.type:
|
|
522
591
|
if self.name.startswith("fp8"):
|
|
523
592
|
if self.name not in builder.options.supported_fp8_dtypes:
|
|
524
593
|
raise ValueError(f'type {self} not supported in this architecture. '
|
|
525
594
|
f'The supported fp8 dtypes are {builder.options.supported_fp8_dtypes}')
|
|
526
|
-
if self.name in builder.options.deprecated_fp8_dtypes:
|
|
527
|
-
warn(f"{self.name} is deprecated in this architecture and will be removed in a future triton release")
|
|
528
595
|
|
|
529
596
|
if self.name == 'void':
|
|
530
597
|
return builder.get_void_ty()
|
|
@@ -581,6 +648,21 @@ class dtype(base_type):
|
|
|
581
648
|
def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[base_value, int]:
|
|
582
649
|
return tensor(handles[cursor], self), cursor + 1
|
|
583
650
|
|
|
651
|
+
def mangle(self) -> str:
|
|
652
|
+
if self.is_int():
|
|
653
|
+
SIGNED = dtype.SIGNEDNESS.SIGNED
|
|
654
|
+
prefix = 'i' if self.int_signedness == SIGNED else 'u'
|
|
655
|
+
return prefix + str(self.int_bitwidth)
|
|
656
|
+
if self.is_floating():
|
|
657
|
+
return str(self)
|
|
658
|
+
if self.is_void():
|
|
659
|
+
return 'V'
|
|
660
|
+
return super().mangle()
|
|
661
|
+
|
|
662
|
+
def with_element_ty(self, element_ty: dtype):
|
|
663
|
+
assert not self.is_block()
|
|
664
|
+
return element_ty
|
|
665
|
+
|
|
584
666
|
|
|
585
667
|
# Some functions have a param named `dtype`, which shadows the `dtype` class.
|
|
586
668
|
# We can't change the param name because it is part of function's public API.
|
|
@@ -623,12 +705,8 @@ class pointer_type(dtype):
|
|
|
623
705
|
def scalar(self):
|
|
624
706
|
return self
|
|
625
707
|
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
def __init__(self, const=True, address_space=0):
|
|
630
|
-
super().__init__(uint8, const=const, address_space=address_space)
|
|
631
|
-
self.name = 'nv_tma_desc_type'
|
|
708
|
+
def mangle(self) -> str:
|
|
709
|
+
return f"P{self.element_ty.mangle()}"
|
|
632
710
|
|
|
633
711
|
|
|
634
712
|
class block_type(dtype):
|
|
@@ -660,9 +738,12 @@ class block_type(dtype):
|
|
|
660
738
|
def is_block(self):
|
|
661
739
|
return True
|
|
662
740
|
|
|
663
|
-
def get_block_shapes(self) ->
|
|
741
|
+
def get_block_shapes(self) -> Tuple[int]:
|
|
664
742
|
return self.shape
|
|
665
743
|
|
|
744
|
+
def with_element_ty(self, scalar_ty: dtype) -> block_type:
|
|
745
|
+
return block_type(scalar_ty, self.shape)
|
|
746
|
+
|
|
666
747
|
def __eq__(self, other) -> bool:
|
|
667
748
|
if not isinstance(other, block_type):
|
|
668
749
|
return False
|
|
@@ -672,6 +753,11 @@ class block_type(dtype):
|
|
|
672
753
|
def scalar(self):
|
|
673
754
|
return self.element_ty
|
|
674
755
|
|
|
756
|
+
def mangle(self) -> str:
|
|
757
|
+
elt = self.scalar.mangle()
|
|
758
|
+
shape = '_'.join(map(str, self.shape))
|
|
759
|
+
return f'{elt}S{shape}S'
|
|
760
|
+
|
|
675
761
|
|
|
676
762
|
class tuple_type(base_type):
|
|
677
763
|
|
|
@@ -686,15 +772,14 @@ class tuple_type(base_type):
|
|
|
686
772
|
def __iter__(self):
|
|
687
773
|
return iter(self.types)
|
|
688
774
|
|
|
689
|
-
def
|
|
690
|
-
|
|
775
|
+
def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]):
|
|
776
|
+
for ty in self.types:
|
|
777
|
+
if not isinstance(ty, constexpr):
|
|
778
|
+
ty._flatten_ir_types(builder, out)
|
|
691
779
|
|
|
692
780
|
def __getitem__(self, index: int) -> dtype:
|
|
693
781
|
return self.types[index]
|
|
694
782
|
|
|
695
|
-
def is_tuple(self):
|
|
696
|
-
return True
|
|
697
|
-
|
|
698
783
|
def __eq__(self, other):
|
|
699
784
|
return type(self) is type(other) and self.types == other.types and self.fields == other.fields
|
|
700
785
|
|
|
@@ -705,6 +790,9 @@ class tuple_type(base_type):
|
|
|
705
790
|
values.append(value)
|
|
706
791
|
return tuple(values, self), cursor
|
|
707
792
|
|
|
793
|
+
def mangle(self):
|
|
794
|
+
return 'T' + '_'.join(ty.mangle for ty in self.types) + 'T'
|
|
795
|
+
|
|
708
796
|
|
|
709
797
|
class slice_type(dtype):
|
|
710
798
|
|
|
@@ -808,224 +896,224 @@ class tensor(base_value):
|
|
|
808
896
|
return str(self.dtype) + '[' + ', '.join(str(s) for s in self.shape) + ']'
|
|
809
897
|
|
|
810
898
|
@builtin
|
|
811
|
-
def __add__(self, other,
|
|
812
|
-
return add(self, other, sanitize_overflow=True,
|
|
899
|
+
def __add__(self, other, _semantic=None):
|
|
900
|
+
return add(self, other, sanitize_overflow=True, _semantic=_semantic)
|
|
813
901
|
|
|
814
902
|
@builtin
|
|
815
|
-
def __radd__(self, other,
|
|
816
|
-
return add(other, self, sanitize_overflow=True,
|
|
903
|
+
def __radd__(self, other, _semantic=None):
|
|
904
|
+
return add(other, self, sanitize_overflow=True, _semantic=_semantic)
|
|
817
905
|
|
|
818
906
|
@builtin
|
|
819
|
-
def __sub__(self, other,
|
|
820
|
-
return sub(self, other, sanitize_overflow=True,
|
|
907
|
+
def __sub__(self, other, _semantic=None):
|
|
908
|
+
return sub(self, other, sanitize_overflow=True, _semantic=_semantic)
|
|
821
909
|
|
|
822
910
|
@builtin
|
|
823
|
-
def __rsub__(self, other,
|
|
824
|
-
return sub(other, self, sanitize_overflow=True,
|
|
911
|
+
def __rsub__(self, other, _semantic=None):
|
|
912
|
+
return sub(other, self, sanitize_overflow=True, _semantic=_semantic)
|
|
825
913
|
|
|
826
914
|
@builtin
|
|
827
|
-
def __mul__(self, other,
|
|
828
|
-
return mul(self, other, sanitize_overflow=True,
|
|
915
|
+
def __mul__(self, other, _semantic=None):
|
|
916
|
+
return mul(self, other, sanitize_overflow=True, _semantic=_semantic)
|
|
829
917
|
|
|
830
918
|
@builtin
|
|
831
|
-
def __rmul__(self, other,
|
|
832
|
-
return mul(other, self, sanitize_overflow=True,
|
|
919
|
+
def __rmul__(self, other, _semantic=None):
|
|
920
|
+
return mul(other, self, sanitize_overflow=True, _semantic=_semantic)
|
|
833
921
|
|
|
834
922
|
@builtin
|
|
835
|
-
def __truediv__(self, other,
|
|
923
|
+
def __truediv__(self, other, _semantic=None):
|
|
836
924
|
other = _unwrap_if_constexpr(other)
|
|
837
|
-
return
|
|
925
|
+
return _semantic.truediv(self, other)
|
|
838
926
|
|
|
839
927
|
@builtin
|
|
840
|
-
def __rtruediv__(self, other,
|
|
928
|
+
def __rtruediv__(self, other, _semantic=None):
|
|
841
929
|
other = _unwrap_if_constexpr(other)
|
|
842
|
-
return
|
|
930
|
+
return _semantic.truediv(other, self)
|
|
843
931
|
|
|
844
932
|
@builtin
|
|
845
|
-
def __floordiv__(self, other,
|
|
933
|
+
def __floordiv__(self, other, _semantic=None):
|
|
846
934
|
other = _unwrap_if_constexpr(other)
|
|
847
|
-
return
|
|
935
|
+
return _semantic.floordiv(self, other)
|
|
848
936
|
|
|
849
937
|
@builtin
|
|
850
|
-
def __rfloordiv__(self, other,
|
|
938
|
+
def __rfloordiv__(self, other, _semantic=None):
|
|
851
939
|
other = _unwrap_if_constexpr(other)
|
|
852
|
-
return
|
|
940
|
+
return _semantic.floordiv(other, self)
|
|
853
941
|
|
|
854
942
|
@builtin
|
|
855
|
-
def __mod__(self, other,
|
|
943
|
+
def __mod__(self, other, _semantic=None):
|
|
856
944
|
other = _unwrap_if_constexpr(other)
|
|
857
|
-
return
|
|
945
|
+
return _semantic.mod(self, other)
|
|
858
946
|
|
|
859
947
|
@builtin
|
|
860
|
-
def __rmod__(self, other,
|
|
948
|
+
def __rmod__(self, other, _semantic=None):
|
|
861
949
|
other = _unwrap_if_constexpr(other)
|
|
862
|
-
return
|
|
950
|
+
return _semantic.mod(other, self)
|
|
863
951
|
|
|
864
952
|
# unary operators
|
|
865
953
|
@builtin
|
|
866
|
-
def __neg__(self,
|
|
867
|
-
return
|
|
954
|
+
def __neg__(self, _semantic=None):
|
|
955
|
+
return _semantic.minus(self)
|
|
868
956
|
|
|
869
957
|
@builtin
|
|
870
|
-
def __invert__(self,
|
|
871
|
-
return
|
|
958
|
+
def __invert__(self, _semantic=None):
|
|
959
|
+
return _semantic.invert(self)
|
|
872
960
|
|
|
873
961
|
# bitwise operators
|
|
874
962
|
|
|
875
963
|
@builtin
|
|
876
|
-
def __and__(self, other,
|
|
964
|
+
def __and__(self, other, _semantic=None):
|
|
877
965
|
other = _unwrap_if_constexpr(other)
|
|
878
|
-
return
|
|
966
|
+
return _semantic.and_(self, other)
|
|
879
967
|
|
|
880
968
|
@builtin
|
|
881
|
-
def __rand__(self, other,
|
|
969
|
+
def __rand__(self, other, _semantic=None):
|
|
882
970
|
other = _unwrap_if_constexpr(other)
|
|
883
|
-
return
|
|
971
|
+
return _semantic.and_(other, self)
|
|
884
972
|
|
|
885
973
|
@builtin
|
|
886
|
-
def __or__(self, other,
|
|
974
|
+
def __or__(self, other, _semantic=None):
|
|
887
975
|
other = _unwrap_if_constexpr(other)
|
|
888
|
-
return
|
|
976
|
+
return _semantic.or_(self, other)
|
|
889
977
|
|
|
890
978
|
@builtin
|
|
891
|
-
def __ror__(self, other,
|
|
979
|
+
def __ror__(self, other, _semantic=None):
|
|
892
980
|
other = _unwrap_if_constexpr(other)
|
|
893
|
-
return
|
|
981
|
+
return _semantic.or_(other, self)
|
|
894
982
|
|
|
895
983
|
@builtin
|
|
896
|
-
def __xor__(self, other,
|
|
984
|
+
def __xor__(self, other, _semantic=None):
|
|
897
985
|
other = _unwrap_if_constexpr(other)
|
|
898
|
-
return
|
|
986
|
+
return _semantic.xor_(self, other)
|
|
899
987
|
|
|
900
988
|
@builtin
|
|
901
|
-
def __rxor__(self, other,
|
|
989
|
+
def __rxor__(self, other, _semantic=None):
|
|
902
990
|
other = _unwrap_if_constexpr(other)
|
|
903
|
-
return
|
|
991
|
+
return _semantic.xor_(other, self)
|
|
904
992
|
|
|
905
993
|
@builtin
|
|
906
|
-
def __lshift__(self, other,
|
|
994
|
+
def __lshift__(self, other, _semantic=None):
|
|
907
995
|
check_bit_width(self, other)
|
|
908
996
|
other = _unwrap_if_constexpr(other)
|
|
909
|
-
return
|
|
997
|
+
return _semantic.shl(self, other)
|
|
910
998
|
|
|
911
999
|
@builtin
|
|
912
|
-
def __rlshift__(self, other,
|
|
1000
|
+
def __rlshift__(self, other, _semantic=None):
|
|
913
1001
|
check_bit_width(other, self)
|
|
914
1002
|
other = _unwrap_if_constexpr(other)
|
|
915
|
-
return
|
|
1003
|
+
return _semantic.shl(other, self)
|
|
916
1004
|
|
|
917
1005
|
@builtin
|
|
918
|
-
def __rshift__(self, other,
|
|
1006
|
+
def __rshift__(self, other, _semantic=None):
|
|
919
1007
|
check_bit_width(self, other)
|
|
920
1008
|
other = _unwrap_if_constexpr(other)
|
|
921
1009
|
if self.dtype.is_int_signed():
|
|
922
|
-
return
|
|
1010
|
+
return _semantic.ashr(self, other)
|
|
923
1011
|
else:
|
|
924
|
-
return
|
|
1012
|
+
return _semantic.lshr(self, other)
|
|
925
1013
|
|
|
926
1014
|
@builtin
|
|
927
|
-
def __rrshift__(self, other,
|
|
1015
|
+
def __rrshift__(self, other, _semantic=None):
|
|
928
1016
|
check_bit_width(other, self)
|
|
929
1017
|
other = _unwrap_if_constexpr(other)
|
|
930
1018
|
if self.dtype.is_int_signed():
|
|
931
|
-
return
|
|
1019
|
+
return _semantic.ashr(other, self)
|
|
932
1020
|
else:
|
|
933
|
-
return
|
|
1021
|
+
return _semantic.lshr(other, self)
|
|
934
1022
|
|
|
935
1023
|
# >
|
|
936
1024
|
@builtin
|
|
937
|
-
def __gt__(self, other,
|
|
938
|
-
other =
|
|
939
|
-
return
|
|
1025
|
+
def __gt__(self, other, _semantic=None):
|
|
1026
|
+
other = _semantic.to_tensor(other)
|
|
1027
|
+
return _semantic.greater_than(self, other)
|
|
940
1028
|
|
|
941
1029
|
@builtin
|
|
942
|
-
def __rgt__(self, other,
|
|
943
|
-
other =
|
|
944
|
-
return
|
|
1030
|
+
def __rgt__(self, other, _semantic=None):
|
|
1031
|
+
other = _semantic.to_tensor(other)
|
|
1032
|
+
return _semantic.greater_than(other, self)
|
|
945
1033
|
|
|
946
1034
|
# >=
|
|
947
1035
|
@builtin
|
|
948
|
-
def __ge__(self, other,
|
|
949
|
-
other =
|
|
950
|
-
return
|
|
1036
|
+
def __ge__(self, other, _semantic=None):
|
|
1037
|
+
other = _semantic.to_tensor(other)
|
|
1038
|
+
return _semantic.greater_equal(self, other)
|
|
951
1039
|
|
|
952
1040
|
@builtin
|
|
953
|
-
def __rge__(self, other,
|
|
954
|
-
other =
|
|
955
|
-
return
|
|
1041
|
+
def __rge__(self, other, _semantic=None):
|
|
1042
|
+
other = _semantic.to_tensor(other)
|
|
1043
|
+
return _semantic.greater_equal(other, self)
|
|
956
1044
|
|
|
957
1045
|
# <
|
|
958
1046
|
@builtin
|
|
959
|
-
def __lt__(self, other,
|
|
960
|
-
other =
|
|
961
|
-
return
|
|
1047
|
+
def __lt__(self, other, _semantic=None):
|
|
1048
|
+
other = _semantic.to_tensor(other)
|
|
1049
|
+
return _semantic.less_than(self, other)
|
|
962
1050
|
|
|
963
1051
|
@builtin
|
|
964
|
-
def __rlt__(self, other,
|
|
965
|
-
other =
|
|
966
|
-
return
|
|
1052
|
+
def __rlt__(self, other, _semantic=None):
|
|
1053
|
+
other = _semantic.to_tensor(other)
|
|
1054
|
+
return _semantic.less_than(other, self)
|
|
967
1055
|
|
|
968
1056
|
# <=
|
|
969
1057
|
@builtin
|
|
970
|
-
def __le__(self, other,
|
|
971
|
-
other =
|
|
972
|
-
return
|
|
1058
|
+
def __le__(self, other, _semantic=None):
|
|
1059
|
+
other = _semantic.to_tensor(other)
|
|
1060
|
+
return _semantic.less_equal(self, other)
|
|
973
1061
|
|
|
974
1062
|
@builtin
|
|
975
|
-
def __rle__(self, other,
|
|
976
|
-
other =
|
|
977
|
-
return
|
|
1063
|
+
def __rle__(self, other, _semantic=None):
|
|
1064
|
+
other = _semantic.to_tensor(other)
|
|
1065
|
+
return _semantic.less_equal(other, self)
|
|
978
1066
|
|
|
979
1067
|
# ==
|
|
980
1068
|
@builtin
|
|
981
|
-
def __eq__(self, other,
|
|
982
|
-
other =
|
|
983
|
-
return
|
|
1069
|
+
def __eq__(self, other, _semantic=None):
|
|
1070
|
+
other = _semantic.to_tensor(other)
|
|
1071
|
+
return _semantic.equal(self, other)
|
|
984
1072
|
|
|
985
1073
|
@builtin
|
|
986
|
-
def __req__(self, other,
|
|
987
|
-
other =
|
|
988
|
-
return
|
|
1074
|
+
def __req__(self, other, _semantic=None):
|
|
1075
|
+
other = _semantic.to_tensor(other)
|
|
1076
|
+
return _semantic.equal(other, self)
|
|
989
1077
|
|
|
990
1078
|
@builtin
|
|
991
|
-
def __ne__(self, other,
|
|
992
|
-
other =
|
|
993
|
-
return
|
|
1079
|
+
def __ne__(self, other, _semantic=None):
|
|
1080
|
+
other = _semantic.to_tensor(other)
|
|
1081
|
+
return _semantic.not_equal(self, other)
|
|
994
1082
|
|
|
995
1083
|
@builtin
|
|
996
|
-
def __rne__(self, other,
|
|
997
|
-
other =
|
|
998
|
-
return
|
|
1084
|
+
def __rne__(self, other, _semantic=None):
|
|
1085
|
+
other = _semantic.to_tensor(other)
|
|
1086
|
+
return _semantic.not_equal(other, self)
|
|
999
1087
|
|
|
1000
1088
|
@builtin
|
|
1001
|
-
def logical_and(self, other,
|
|
1002
|
-
other =
|
|
1003
|
-
return
|
|
1089
|
+
def logical_and(self, other, _semantic=None):
|
|
1090
|
+
other = _semantic.to_tensor(other)
|
|
1091
|
+
return _semantic.logical_and(self, other)
|
|
1004
1092
|
|
|
1005
1093
|
@builtin
|
|
1006
|
-
def logical_or(self, other,
|
|
1007
|
-
other =
|
|
1008
|
-
return
|
|
1094
|
+
def logical_or(self, other, _semantic=None):
|
|
1095
|
+
other = _semantic.to_tensor(other)
|
|
1096
|
+
return _semantic.logical_or(self, other)
|
|
1009
1097
|
|
|
1010
1098
|
# note: __not__ isn't actually a magic method in python
|
|
1011
1099
|
# but it's ok because our ASTVisitor handles it
|
|
1012
1100
|
@builtin
|
|
1013
|
-
def __not__(self,
|
|
1014
|
-
return
|
|
1101
|
+
def __not__(self, _semantic=None):
|
|
1102
|
+
return _semantic.not_(self)
|
|
1015
1103
|
|
|
1016
1104
|
@builtin
|
|
1017
|
-
def __getitem__(self, slices,
|
|
1018
|
-
import builtins
|
|
1105
|
+
def __getitem__(self, slices, _semantic=None):
|
|
1019
1106
|
if isinstance(slices, (builtins.slice, slice, constexpr)) or slices is None:
|
|
1020
1107
|
slices = [slices]
|
|
1021
1108
|
if isinstance(slices, tuple):
|
|
1022
1109
|
slices = slices.values
|
|
1023
1110
|
ret = self
|
|
1024
1111
|
for dim, sl in enumerate(slices):
|
|
1025
|
-
if
|
|
1026
|
-
ret =
|
|
1027
|
-
elif isinstance(sl, (builtins.slice, slice)) and
|
|
1028
|
-
|
|
1112
|
+
if _unwrap_if_constexpr(sl) is None:
|
|
1113
|
+
ret = _semantic.expand_dims(ret, dim)
|
|
1114
|
+
elif isinstance(sl, (builtins.slice, slice)) and all(
|
|
1115
|
+
_unwrap_if_constexpr(arg) is None for arg in (sl.start, sl.stop, sl.step)):
|
|
1116
|
+
pass # an unsqueeze
|
|
1029
1117
|
else:
|
|
1030
1118
|
raise ValueError(f"unsupported tensor index: {sl}")
|
|
1031
1119
|
return ret
|
|
@@ -1036,11 +1124,11 @@ class tensor(base_value):
|
|
|
1036
1124
|
assert False, "Transposition must be created by the AST Visitor"
|
|
1037
1125
|
|
|
1038
1126
|
@builtin
|
|
1039
|
-
def to(self, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: bool = False,
|
|
1127
|
+
def to(self, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: bool = False, _semantic=None):
|
|
1040
1128
|
"""
|
|
1041
1129
|
Alias for :py:func:`tensor.cast`.
|
|
1042
1130
|
"""
|
|
1043
|
-
return cast(self, dtype, fp_downcast_rounding, bitcast,
|
|
1131
|
+
return cast(self, dtype, fp_downcast_rounding, bitcast, _semantic=_semantic)
|
|
1044
1132
|
|
|
1045
1133
|
# Type stubs for functions added by the _tensor_member_fn decorator.
|
|
1046
1134
|
# (Unfortunately these can't be created automatically.)
|
|
@@ -1140,7 +1228,7 @@ class tensor(base_value):
|
|
|
1140
1228
|
def sigmoid(self) -> tensor:
|
|
1141
1229
|
...
|
|
1142
1230
|
|
|
1143
|
-
def softmax(self, ieee_rounding=False) -> tensor:
|
|
1231
|
+
def softmax(self, dim=None, keep_dims=False, ieee_rounding=False) -> tensor:
|
|
1144
1232
|
...
|
|
1145
1233
|
|
|
1146
1234
|
def ravel(self) -> tensor:
|
|
@@ -1164,6 +1252,9 @@ class tensor(base_value):
|
|
|
1164
1252
|
def xor_sum(self, axis=None, keep_dims=False) -> tensor:
|
|
1165
1253
|
...
|
|
1166
1254
|
|
|
1255
|
+
def reduce_or(self, axis=None, keep_dims=False) -> tensor:
|
|
1256
|
+
...
|
|
1257
|
+
|
|
1167
1258
|
def cumsum(self, axis=0, reverse=False) -> tensor:
|
|
1168
1259
|
...
|
|
1169
1260
|
|
|
@@ -1179,13 +1270,13 @@ class tensor(base_value):
|
|
|
1179
1270
|
|
|
1180
1271
|
class tuple(base_value):
|
|
1181
1272
|
|
|
1182
|
-
def __init__(self, args:
|
|
1273
|
+
def __init__(self, args: Sequence, type: tuple_type = None):
|
|
1183
1274
|
self.values = [i for i in args]
|
|
1184
1275
|
|
|
1185
1276
|
def get_type(x):
|
|
1186
1277
|
if isinstance(x, dtype):
|
|
1187
1278
|
return dtype
|
|
1188
|
-
if isinstance(x, int):
|
|
1279
|
+
if isinstance(x, (int, float)):
|
|
1189
1280
|
return constexpr
|
|
1190
1281
|
return x.type
|
|
1191
1282
|
|
|
@@ -1197,7 +1288,6 @@ class tuple(base_value):
|
|
|
1197
1288
|
if isinstance(idx, constexpr):
|
|
1198
1289
|
return self.values[idx]
|
|
1199
1290
|
else:
|
|
1200
|
-
import builtins
|
|
1201
1291
|
assert isinstance(idx, (slice, builtins.slice))
|
|
1202
1292
|
return tuple(self.values[idx.start:idx.stop:idx.step])
|
|
1203
1293
|
|
|
@@ -1212,8 +1302,7 @@ class tuple(base_value):
|
|
|
1212
1302
|
self.values[idx] = value
|
|
1213
1303
|
|
|
1214
1304
|
def __add__(self, other):
|
|
1215
|
-
|
|
1216
|
-
other = tuple(other)
|
|
1305
|
+
other = _normalize_tuple(other)
|
|
1217
1306
|
return tuple(self.values + other.values)
|
|
1218
1307
|
# return tuple(a + b for a, b in zip(self.values, other.values))
|
|
1219
1308
|
|
|
@@ -1222,13 +1311,10 @@ class tuple(base_value):
|
|
|
1222
1311
|
return tuple(self.values * other.value)
|
|
1223
1312
|
|
|
1224
1313
|
def __eq__(self, other):
|
|
1225
|
-
|
|
1226
|
-
if isinstance(other, (list, builtins.tuple)):
|
|
1227
|
-
other = tuple(other)
|
|
1314
|
+
other = _normalize_tuple(other)
|
|
1228
1315
|
return constexpr(self.values == other.values)
|
|
1229
1316
|
|
|
1230
1317
|
def __hash__(self):
|
|
1231
|
-
import builtins
|
|
1232
1318
|
return hash(builtins.tuple(self.values))
|
|
1233
1319
|
|
|
1234
1320
|
def __str__(self):
|
|
@@ -1244,6 +1330,9 @@ class tuple(base_value):
|
|
|
1244
1330
|
for v in self.values:
|
|
1245
1331
|
v._flatten_ir(handles)
|
|
1246
1332
|
|
|
1333
|
+
def __repr__(self):
|
|
1334
|
+
return f"({' ,'.join(repr(x) for x in self.values)})"
|
|
1335
|
+
|
|
1247
1336
|
|
|
1248
1337
|
class slice:
|
|
1249
1338
|
|
|
@@ -1259,12 +1348,13 @@ class tensor_descriptor_base_type(base_type):
|
|
|
1259
1348
|
def __init__(self, block_type: block_type):
|
|
1260
1349
|
self.block_type = block_type
|
|
1261
1350
|
|
|
1262
|
-
def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[
|
|
1263
|
-
value =
|
|
1351
|
+
def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tensor_descriptor_base, int]:
|
|
1352
|
+
value = tensor_descriptor_base(handles[cursor], self.block_type)
|
|
1264
1353
|
return value, cursor + 1
|
|
1265
1354
|
|
|
1266
|
-
def
|
|
1267
|
-
|
|
1355
|
+
def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
|
|
1356
|
+
is_signed = self.block_type.element_ty.is_int_signed()
|
|
1357
|
+
out.append(builder.create_tensor_descriptor_type(self.block_type.to_ir(builder), is_signed))
|
|
1268
1358
|
|
|
1269
1359
|
def __str__(self) -> str:
|
|
1270
1360
|
# ex. "tensor_descriptor<float32[16, 32]>"
|
|
@@ -1278,8 +1368,11 @@ class tensor_descriptor_base_type(base_type):
|
|
|
1278
1368
|
def __neq__(self, other) -> bool:
|
|
1279
1369
|
return not (self == other)
|
|
1280
1370
|
|
|
1371
|
+
def mangle(self) -> str:
|
|
1372
|
+
return f"TD{self.block_type.mangle()}"
|
|
1373
|
+
|
|
1281
1374
|
|
|
1282
|
-
class
|
|
1375
|
+
class tensor_descriptor_base(base_value):
|
|
1283
1376
|
""""
|
|
1284
1377
|
A tensor descriptor with unknown shape and strides
|
|
1285
1378
|
"""
|
|
@@ -1310,40 +1403,64 @@ class _experimental_tensor_descriptor_base(base_value):
|
|
|
1310
1403
|
return str(self.type)
|
|
1311
1404
|
|
|
1312
1405
|
@builtin
|
|
1313
|
-
def load(self, offsets: Sequence[constexpr | tensor],
|
|
1406
|
+
def load(self, offsets: Sequence[constexpr | tensor], _semantic=None) -> tensor:
|
|
1314
1407
|
"""Load a block from the descriptor starting at the given element offsets.
|
|
1315
1408
|
|
|
1316
1409
|
Values outside of the tensor bounds will be filled with zeros.
|
|
1317
1410
|
|
|
1318
1411
|
:note: Offset must be a multiple of 16-bytes
|
|
1319
1412
|
"""
|
|
1320
|
-
return
|
|
1413
|
+
return _semantic.descriptor_load(self, offsets, "", "")
|
|
1321
1414
|
|
|
1322
1415
|
@builtin
|
|
1323
|
-
def store(self, offsets: Sequence[constexpr | tensor], value: tensor,
|
|
1416
|
+
def store(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor:
|
|
1324
1417
|
"""Store a block from the descriptor starting at the given element offsets.
|
|
1325
1418
|
|
|
1326
1419
|
Values outside of the tensor bounds will be ignored.
|
|
1327
1420
|
|
|
1328
1421
|
:note: Offset must be a multiple of 16-bytes
|
|
1329
1422
|
"""
|
|
1330
|
-
return
|
|
1423
|
+
return _semantic.descriptor_store(self, value, offsets)
|
|
1424
|
+
|
|
1425
|
+
@builtin
|
|
1426
|
+
def atomic_add(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor:
|
|
1427
|
+
return _semantic.descriptor_atomic_add(self, value, offsets)
|
|
1428
|
+
|
|
1429
|
+
@builtin
|
|
1430
|
+
def atomic_min(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor:
|
|
1431
|
+
return _semantic.descriptor_atomic_min(self, value, offsets)
|
|
1331
1432
|
|
|
1332
1433
|
@builtin
|
|
1333
|
-
def
|
|
1434
|
+
def atomic_max(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor:
|
|
1435
|
+
return _semantic.descriptor_atomic_max(self, value, offsets)
|
|
1436
|
+
|
|
1437
|
+
@builtin
|
|
1438
|
+
def atomic_and(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor:
|
|
1439
|
+
return _semantic.descriptor_atomic_and(self, value, offsets)
|
|
1440
|
+
|
|
1441
|
+
@builtin
|
|
1442
|
+
def atomic_or(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor:
|
|
1443
|
+
return _semantic.descriptor_atomic_or(self, value, offsets)
|
|
1444
|
+
|
|
1445
|
+
@builtin
|
|
1446
|
+
def atomic_xor(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor:
|
|
1447
|
+
return _semantic.descriptor_atomic_xor(self, value, offsets)
|
|
1448
|
+
|
|
1449
|
+
@builtin
|
|
1450
|
+
def gather(self, *args, _semantic=None) -> tensor:
|
|
1334
1451
|
"""Gather multiple descriptors worth of data"""
|
|
1335
1452
|
assert len(args) == 2, f"descriptor gather only supports 2D indexing, but got {len(args)}"
|
|
1336
1453
|
x_offsets = args[0]
|
|
1337
1454
|
y_offset = args[1]
|
|
1338
|
-
return
|
|
1455
|
+
return _semantic.descriptor_gather(self, x_offsets, y_offset, "", "")
|
|
1339
1456
|
|
|
1340
1457
|
@builtin
|
|
1341
|
-
def scatter(self, value, *args,
|
|
1458
|
+
def scatter(self, value, *args, _semantic=None) -> tensor:
|
|
1342
1459
|
"""Scatter multiple descriptors worth of data"""
|
|
1343
1460
|
assert len(args) == 2, f"descriptor scatter only supports 2D indexing, but got {len(args)}"
|
|
1344
1461
|
x_offsets = args[0]
|
|
1345
1462
|
y_offset = args[1]
|
|
1346
|
-
return
|
|
1463
|
+
return _semantic.descriptor_scatter(self, value, x_offsets, y_offset)
|
|
1347
1464
|
|
|
1348
1465
|
|
|
1349
1466
|
class tensor_descriptor_type(tensor_descriptor_base_type):
|
|
@@ -1353,25 +1470,27 @@ class tensor_descriptor_type(tensor_descriptor_base_type):
|
|
|
1353
1470
|
self.shape_type = shape_type
|
|
1354
1471
|
self.strides_type = strides_type
|
|
1355
1472
|
|
|
1356
|
-
def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[
|
|
1473
|
+
def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tensor_descriptor_base, int]:
|
|
1357
1474
|
handle = handles[cursor]
|
|
1358
1475
|
cursor += 1
|
|
1359
1476
|
shape, cursor = self.shape_type._unflatten_ir(handles, cursor)
|
|
1360
1477
|
strides, cursor = self.strides_type._unflatten_ir(handles, cursor)
|
|
1361
1478
|
shape = shape.values
|
|
1362
1479
|
strides = strides.values
|
|
1363
|
-
value =
|
|
1480
|
+
value = tensor_descriptor(handle, shape, strides, self.block_type)
|
|
1364
1481
|
return value, cursor
|
|
1365
1482
|
|
|
1366
|
-
def
|
|
1367
|
-
|
|
1483
|
+
def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
|
|
1484
|
+
super()._flatten_ir_types(builder, out)
|
|
1485
|
+
self.shape_type._flatten_ir_types(builder, out)
|
|
1486
|
+
self.strides_type._flatten_ir_types(builder, out)
|
|
1368
1487
|
|
|
1369
1488
|
def __eq__(self, other):
|
|
1370
1489
|
return super().__eq__(other) and (self.shape_type == other.shape_type) and (self.strides_type
|
|
1371
1490
|
== other.strides_type)
|
|
1372
1491
|
|
|
1373
1492
|
|
|
1374
|
-
class
|
|
1493
|
+
class tensor_descriptor(tensor_descriptor_base):
|
|
1375
1494
|
"""A descriptor representing a tensor in global memory.
|
|
1376
1495
|
"""
|
|
1377
1496
|
|
|
@@ -1379,37 +1498,121 @@ class _experimental_tensor_descriptor(_experimental_tensor_descriptor_base):
|
|
|
1379
1498
|
"""Not called by user code."""
|
|
1380
1499
|
# IR handle
|
|
1381
1500
|
super().__init__(handle, block_type)
|
|
1501
|
+
# Global shape
|
|
1502
|
+
self.shape = tuple(shape)
|
|
1503
|
+
self.strides = tuple(strides)
|
|
1382
1504
|
self.type = tensor_descriptor_type(
|
|
1383
1505
|
block_type,
|
|
1384
|
-
shape_type=
|
|
1385
|
-
strides_type=
|
|
1506
|
+
shape_type=self.shape.type,
|
|
1507
|
+
strides_type=self.strides.type,
|
|
1386
1508
|
)
|
|
1387
|
-
# Global shape
|
|
1388
|
-
self.shape = shape
|
|
1389
|
-
self.strides = strides
|
|
1390
1509
|
|
|
1391
1510
|
def _flatten_ir(self, handles: List[ir.value]) -> None:
|
|
1392
1511
|
handles.append(self.handle)
|
|
1393
|
-
|
|
1394
|
-
|
|
1512
|
+
self.shape._flatten_ir(handles)
|
|
1513
|
+
self.strides._flatten_ir(handles)
|
|
1514
|
+
|
|
1515
|
+
|
|
1516
|
+
# -----------------------
|
|
1517
|
+
# aggregate
|
|
1518
|
+
# -----------------------
|
|
1519
|
+
|
|
1395
1520
|
|
|
1521
|
+
@dataclass(frozen=True)
|
|
1522
|
+
class _aggregate_type(base_type):
|
|
1523
|
+
"""A generic base type for all Triton aggregate types.
|
|
1396
1524
|
|
|
1397
|
-
|
|
1398
|
-
|
|
1399
|
-
|
|
1525
|
+
This class contains a reference to the original user-defined Python class
|
|
1526
|
+
and a list of class fields with their Triton types.
|
|
1527
|
+
"""
|
|
1528
|
+
|
|
1529
|
+
base_cls: type
|
|
1530
|
+
fields: List[Tuple[str, base_type]]
|
|
1531
|
+
|
|
1532
|
+
def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[ir.value, int]:
|
|
1533
|
+
instance = self.base_cls._get_instance()
|
|
1534
|
+
for name, ty in self.fields:
|
|
1535
|
+
value, cursor = ty._unflatten_ir(handles, cursor)
|
|
1536
|
+
setattr(instance, name, value)
|
|
1537
|
+
return instance, cursor
|
|
1538
|
+
|
|
1539
|
+
def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
|
|
1540
|
+
for name, ty in self.fields:
|
|
1541
|
+
ty._flatten_ir_types(builder, out)
|
|
1542
|
+
|
|
1543
|
+
def mangle(self) -> str:
|
|
1544
|
+
name = f"{self.base_cls.__module__}.{self.base_cls.__qualname__}"
|
|
1545
|
+
fields = [ty.mangle() for (name, ty) in self.fields]
|
|
1546
|
+
return f"{name}<{', '.join(fields)}>"
|
|
1547
|
+
|
|
1548
|
+
|
|
1549
|
+
def _aggregate(cls):
|
|
1550
|
+
|
|
1551
|
+
# Define the wrapped Triton value type.
|
|
1552
|
+
class aggregate_value(base_value):
|
|
1553
|
+
__triton_builtin__ = True
|
|
1554
|
+
__triton_aggregate__ = True
|
|
1555
|
+
|
|
1556
|
+
@classmethod
|
|
1557
|
+
def _get_instance(this_cls):
|
|
1558
|
+
return super().__new__(this_cls)
|
|
1559
|
+
|
|
1560
|
+
def __new__(this_cls, *args, _semantic=None, _generator=None, **kwargs):
|
|
1561
|
+
# Call into the user-defined constructor.
|
|
1562
|
+
instance = this_cls._get_instance()
|
|
1563
|
+
if isinstance(cls.__init__, JITFunction):
|
|
1564
|
+
raise ValueError(f"{cls.__name__}.__init__ cannot be a @triton.jit function")
|
|
1565
|
+
extra_kwargs = {}
|
|
1566
|
+
if "_semantic" in inspect.signature(cls.__init__).parameters:
|
|
1567
|
+
extra_kwargs["_semantic"] = _semantic
|
|
1568
|
+
if "_generator" in inspect.signature(cls.__init__).parameters:
|
|
1569
|
+
extra_kwargs["_generator"] = _generator
|
|
1570
|
+
cls.__init__(instance, *args, **extra_kwargs, **kwargs)
|
|
1571
|
+
|
|
1572
|
+
# Require that the user-defined constructor initialized all fields.
|
|
1573
|
+
for name in cls.__annotations__.keys():
|
|
1574
|
+
if not hasattr(instance, name):
|
|
1575
|
+
raise AttributeError(f"constructor for {cls.__name__} did not initialize attribute '{name}'")
|
|
1576
|
+
|
|
1577
|
+
return instance
|
|
1578
|
+
|
|
1579
|
+
# Only allow setting attributes defined in the class annotations.
|
|
1580
|
+
def __setattr__(self, name, value):
|
|
1581
|
+
if name not in cls.__annotations__:
|
|
1582
|
+
raise AttributeError(f"{cls.__name__} has no attribute '{name}'")
|
|
1583
|
+
if not isinstance(value, cls.__annotations__[name]):
|
|
1584
|
+
raise TypeError(f"Expected {cls.__annotations__[name]} for attribute '{name}', got {type(value)}")
|
|
1585
|
+
super().__setattr__(name, value)
|
|
1586
|
+
|
|
1587
|
+
def _flatten_ir(self, handles: List[ir.value]) -> None:
|
|
1588
|
+
for name in cls.__annotations__.keys():
|
|
1589
|
+
getattr(self, name)._flatten_ir(handles)
|
|
1590
|
+
|
|
1591
|
+
@property
|
|
1592
|
+
def type(self):
|
|
1593
|
+
return _aggregate_type(aggregate_value,
|
|
1594
|
+
[(name, getattr(self, name).type) for name in cls.__annotations__.keys()])
|
|
1595
|
+
|
|
1596
|
+
for (name, member) in inspect.getmembers(cls):
|
|
1597
|
+
if inspect.isfunction(member) or inspect.ismethod(member) or isinstance(member, JITFunction):
|
|
1598
|
+
if name != "__init__":
|
|
1599
|
+
setattr(aggregate_value, name, member)
|
|
1600
|
+
|
|
1601
|
+
aggregate_value.__name__ = cls.__name__
|
|
1602
|
+
aggregate_value.__module__ = cls.__module__
|
|
1603
|
+
aggregate_value.__qualname__ = cls.__qualname__
|
|
1604
|
+
aggregate_value.__doc__ = cls.__doc__
|
|
1605
|
+
|
|
1606
|
+
return aggregate_value
|
|
1400
1607
|
|
|
1401
1608
|
|
|
1402
1609
|
# -----------------------
|
|
1403
1610
|
# SPMD Programming Model
|
|
1404
1611
|
# -----------------------
|
|
1405
|
-
def _constexpr_to_value(v):
|
|
1406
|
-
if isinstance(v, constexpr):
|
|
1407
|
-
return v.value
|
|
1408
|
-
return v
|
|
1409
1612
|
|
|
1410
1613
|
|
|
1411
1614
|
@builtin
|
|
1412
|
-
def program_id(axis,
|
|
1615
|
+
def program_id(axis, _semantic=None):
|
|
1413
1616
|
"""
|
|
1414
1617
|
Returns the id of the current program instance along the given :code:`axis`.
|
|
1415
1618
|
|
|
@@ -1417,26 +1620,26 @@ def program_id(axis, _builder=None):
|
|
|
1417
1620
|
:type axis: int
|
|
1418
1621
|
"""
|
|
1419
1622
|
# if axis == -1:
|
|
1420
|
-
# pid0 = program_id(0
|
|
1421
|
-
# pid1 = program_id(1
|
|
1422
|
-
# pid2 = program_id(2
|
|
1423
|
-
# npg0 = num_programs(0
|
|
1424
|
-
# npg1 = num_programs(1
|
|
1623
|
+
# pid0 = _semantic.program_id(0)
|
|
1624
|
+
# pid1 = _semantic.program_id(1)
|
|
1625
|
+
# pid2 = _semantic.program_id(2)
|
|
1626
|
+
# npg0 = _semantic.num_programs(0)
|
|
1627
|
+
# npg1 = _semantic.num_programs(1)
|
|
1425
1628
|
# return pid0 + pid1*npg0 + pid2*npg0*npg1
|
|
1426
|
-
axis =
|
|
1427
|
-
return
|
|
1629
|
+
axis = _unwrap_if_constexpr(axis)
|
|
1630
|
+
return _semantic.program_id(axis)
|
|
1428
1631
|
|
|
1429
1632
|
|
|
1430
1633
|
@builtin
|
|
1431
|
-
def num_programs(axis,
|
|
1634
|
+
def num_programs(axis, _semantic=None):
|
|
1432
1635
|
"""
|
|
1433
1636
|
Returns the number of program instances launched along the given :code:`axis`.
|
|
1434
1637
|
|
|
1435
1638
|
:param axis: The axis of the 3D launch grid. Must be 0, 1 or 2.
|
|
1436
1639
|
:type axis: int
|
|
1437
1640
|
"""
|
|
1438
|
-
axis =
|
|
1439
|
-
return
|
|
1641
|
+
axis = _unwrap_if_constexpr(axis)
|
|
1642
|
+
return _semantic.num_programs(axis)
|
|
1440
1643
|
|
|
1441
1644
|
|
|
1442
1645
|
# -----------------------
|
|
@@ -1445,10 +1648,10 @@ def num_programs(axis, _builder=None):
|
|
|
1445
1648
|
|
|
1446
1649
|
|
|
1447
1650
|
@builtin
|
|
1448
|
-
def arange(start, end,
|
|
1449
|
-
start =
|
|
1450
|
-
end =
|
|
1451
|
-
return
|
|
1651
|
+
def arange(start, end, _semantic=None):
|
|
1652
|
+
start = _unwrap_if_constexpr(start)
|
|
1653
|
+
end = _unwrap_if_constexpr(end)
|
|
1654
|
+
return _semantic.arange(start, end)
|
|
1452
1655
|
|
|
1453
1656
|
|
|
1454
1657
|
arange.__doc__ = f"""
|
|
@@ -1465,8 +1668,8 @@ arange.__doc__ = f"""
|
|
|
1465
1668
|
|
|
1466
1669
|
|
|
1467
1670
|
def _unwrap_shape(shape):
|
|
1468
|
-
shape =
|
|
1469
|
-
return [
|
|
1671
|
+
shape = _unwrap_if_constexpr(shape)
|
|
1672
|
+
return [_unwrap_if_constexpr(s) for s in shape]
|
|
1470
1673
|
|
|
1471
1674
|
|
|
1472
1675
|
def _shape_check_impl(shape):
|
|
@@ -1476,7 +1679,7 @@ def _shape_check_impl(shape):
|
|
|
1476
1679
|
|
|
1477
1680
|
|
|
1478
1681
|
@builtin
|
|
1479
|
-
def full(shape, value, dtype,
|
|
1682
|
+
def full(shape, value, dtype, _semantic=None):
|
|
1480
1683
|
"""
|
|
1481
1684
|
Returns a tensor filled with the scalar value for the given :code:`shape` and :code:`dtype`.
|
|
1482
1685
|
|
|
@@ -1488,9 +1691,9 @@ def full(shape, value, dtype, _builder=None):
|
|
|
1488
1691
|
:type dtype: tl.dtype
|
|
1489
1692
|
"""
|
|
1490
1693
|
shape = _shape_check_impl(shape)
|
|
1491
|
-
value =
|
|
1492
|
-
dtype =
|
|
1493
|
-
return
|
|
1694
|
+
value = _unwrap_if_constexpr(value)
|
|
1695
|
+
dtype = _unwrap_if_constexpr(dtype)
|
|
1696
|
+
return _semantic.full(shape, value, dtype)
|
|
1494
1697
|
|
|
1495
1698
|
|
|
1496
1699
|
# -----------------------
|
|
@@ -1499,7 +1702,7 @@ def full(shape, value, dtype, _builder=None):
|
|
|
1499
1702
|
|
|
1500
1703
|
|
|
1501
1704
|
@builtin
|
|
1502
|
-
def broadcast(input, other,
|
|
1705
|
+
def broadcast(input, other, _semantic=None):
|
|
1503
1706
|
"""
|
|
1504
1707
|
Tries to broadcast the two given blocks to a common compatible shape.
|
|
1505
1708
|
|
|
@@ -1508,12 +1711,12 @@ def broadcast(input, other, _builder=None):
|
|
|
1508
1711
|
:param other: The second input tensor.
|
|
1509
1712
|
:type other: Block
|
|
1510
1713
|
"""
|
|
1511
|
-
return
|
|
1714
|
+
return _semantic.broadcast_impl_value(input, other)
|
|
1512
1715
|
|
|
1513
1716
|
|
|
1514
1717
|
@_tensor_member_fn
|
|
1515
1718
|
@builtin
|
|
1516
|
-
def broadcast_to(input, *shape,
|
|
1719
|
+
def broadcast_to(input, *shape, _semantic=None):
|
|
1517
1720
|
"""
|
|
1518
1721
|
Tries to broadcast the given tensor to a new :code:`shape`.
|
|
1519
1722
|
|
|
@@ -1529,12 +1732,12 @@ def broadcast_to(input, *shape, _builder=None):
|
|
|
1529
1732
|
broadcast_to(x, 32, 32)
|
|
1530
1733
|
"""
|
|
1531
1734
|
shape = _shape_check_impl(_unwrap_iterable(shape))
|
|
1532
|
-
return
|
|
1735
|
+
return _semantic.broadcast_impl_shape(input, shape)
|
|
1533
1736
|
|
|
1534
1737
|
|
|
1535
1738
|
@_tensor_member_fn
|
|
1536
1739
|
@builtin
|
|
1537
|
-
def trans(input: tensor, *dims,
|
|
1740
|
+
def trans(input: tensor, *dims, _semantic=None):
|
|
1538
1741
|
"""
|
|
1539
1742
|
Permutes the dimensions of a tensor.
|
|
1540
1743
|
|
|
@@ -1543,7 +1746,7 @@ def trans(input: tensor, *dims, _builder=None):
|
|
|
1543
1746
|
|
|
1544
1747
|
:param input: The input tensor.
|
|
1545
1748
|
:param dims: The desired ordering of dimensions. For example,
|
|
1546
|
-
:code:`(2, 1, 0)` reverses the order dims in a
|
|
1749
|
+
:code:`(2, 1, 0)` reverses the order dims in a 3D tensor.
|
|
1547
1750
|
|
|
1548
1751
|
:code:`dims` can be passed as a tuple or as individual parameters: ::
|
|
1549
1752
|
|
|
@@ -1557,19 +1760,19 @@ def trans(input: tensor, *dims, _builder=None):
|
|
|
1557
1760
|
dims = _unwrap_iterable(dims)
|
|
1558
1761
|
if not dims:
|
|
1559
1762
|
dims = (1, 0)
|
|
1560
|
-
return
|
|
1763
|
+
return _semantic.permute(input, dims)
|
|
1561
1764
|
|
|
1562
1765
|
|
|
1563
1766
|
@_tensor_member_fn
|
|
1564
1767
|
@builtin
|
|
1565
|
-
def permute(input, *dims,
|
|
1768
|
+
def permute(input, *dims, _semantic=None):
|
|
1566
1769
|
"""
|
|
1567
1770
|
Permutes the dimensions of a tensor.
|
|
1568
1771
|
|
|
1569
1772
|
:param input: The input tensor.
|
|
1570
1773
|
:type input: Block
|
|
1571
1774
|
:param dims: The desired ordering of dimensions. For example,
|
|
1572
|
-
:code:`(2, 1, 0)` reverses the order dims in a
|
|
1775
|
+
:code:`(2, 1, 0)` reverses the order dims in a 3D tensor.
|
|
1573
1776
|
|
|
1574
1777
|
:code:`dims` can be passed as a tuple or as individual parameters: ::
|
|
1575
1778
|
|
|
@@ -1581,11 +1784,11 @@ def permute(input, *dims, _builder=None):
|
|
|
1581
1784
|
:code:`dims` is empty, it tries to do a (1,0) permutation.
|
|
1582
1785
|
"""
|
|
1583
1786
|
dims = _unwrap_iterable(dims)
|
|
1584
|
-
return
|
|
1787
|
+
return _semantic.permute(input, dims)
|
|
1585
1788
|
|
|
1586
1789
|
|
|
1587
1790
|
@builtin
|
|
1588
|
-
def cat(input, other, can_reorder=False,
|
|
1791
|
+
def cat(input, other, can_reorder=False, _semantic=None):
|
|
1589
1792
|
"""
|
|
1590
1793
|
Concatenate the given blocks
|
|
1591
1794
|
|
|
@@ -1598,11 +1801,11 @@ def cat(input, other, can_reorder=False, _builder=None):
|
|
|
1598
1801
|
order does not matter (e.g., result is only used in reduction ops).
|
|
1599
1802
|
Current implementation of `cat` supports only can_reorder=True.
|
|
1600
1803
|
"""
|
|
1601
|
-
return
|
|
1804
|
+
return _semantic.cat(input, other, can_reorder)
|
|
1602
1805
|
|
|
1603
1806
|
|
|
1604
1807
|
@builtin
|
|
1605
|
-
def join(a, b,
|
|
1808
|
+
def join(a, b, _semantic=None):
|
|
1606
1809
|
"""
|
|
1607
1810
|
Join the given tensors in a new, minor dimension.
|
|
1608
1811
|
|
|
@@ -1622,7 +1825,7 @@ def join(a, b, _builder=None):
|
|
|
1622
1825
|
:param b: The second input tensor.
|
|
1623
1826
|
:type b: Tensor
|
|
1624
1827
|
"""
|
|
1625
|
-
return
|
|
1828
|
+
return _semantic.join(a, b)
|
|
1626
1829
|
|
|
1627
1830
|
|
|
1628
1831
|
@jit
|
|
@@ -1630,9 +1833,25 @@ def _take_first(a, b):
|
|
|
1630
1833
|
return a
|
|
1631
1834
|
|
|
1632
1835
|
|
|
1836
|
+
def _unsplat(x, _semantic=None, _generator=None):
|
|
1837
|
+
"""
|
|
1838
|
+
Convert a single-element tensor to a scalar.
|
|
1839
|
+
"""
|
|
1840
|
+
if len(x.shape) == 0:
|
|
1841
|
+
return x
|
|
1842
|
+
numel = 1
|
|
1843
|
+
for d in x.shape:
|
|
1844
|
+
numel *= d
|
|
1845
|
+
assert numel == 1, "can only unsplat single-element tensors"
|
|
1846
|
+
if len(x.shape) >= 2:
|
|
1847
|
+
x = _semantic.reshape(x, [1])
|
|
1848
|
+
x = typing.cast(tensor, reduce(x, 0, _take_first, _semantic=_semantic, _generator=_generator))
|
|
1849
|
+
return x
|
|
1850
|
+
|
|
1851
|
+
|
|
1633
1852
|
@_tensor_member_fn
|
|
1634
1853
|
@builtin
|
|
1635
|
-
def split(a,
|
|
1854
|
+
def split(a, _semantic=None, _generator=None) -> tuple[tensor, tensor]:
|
|
1636
1855
|
"""
|
|
1637
1856
|
Split a tensor in two along its last dim, which must have size 2.
|
|
1638
1857
|
|
|
@@ -1649,25 +1868,25 @@ def split(a, _builder=None, _generator=None) -> tuple[tensor, tensor]:
|
|
|
1649
1868
|
:type a: Tensor
|
|
1650
1869
|
"""
|
|
1651
1870
|
# If len(a.shape) == 1, i.e. a.shape == [2], we should return two scalars.
|
|
1652
|
-
# But
|
|
1871
|
+
# But _semantic.split can only handle returning tensors. Work around this by
|
|
1653
1872
|
# expanding the input to shape [1,2] and then reducing the result.
|
|
1654
1873
|
was_rank_1 = len(a.shape) == 1
|
|
1655
1874
|
if was_rank_1:
|
|
1656
|
-
a =
|
|
1875
|
+
a = _semantic.expand_dims(a, 0)
|
|
1657
1876
|
|
|
1658
|
-
out_lhs, out_rhs =
|
|
1877
|
+
out_lhs, out_rhs = _semantic.split(a)
|
|
1659
1878
|
|
|
1660
1879
|
if was_rank_1:
|
|
1661
1880
|
# Currently `reduce` is the best way to convert a tensor of shape [1] to a scalar.
|
|
1662
|
-
out_lhs =
|
|
1663
|
-
out_rhs =
|
|
1881
|
+
out_lhs = _unsplat(out_lhs, _semantic=_semantic, _generator=_generator)
|
|
1882
|
+
out_rhs = _unsplat(out_rhs, _semantic=_semantic, _generator=_generator)
|
|
1664
1883
|
|
|
1665
1884
|
return out_lhs, out_rhs
|
|
1666
1885
|
|
|
1667
1886
|
|
|
1668
1887
|
@_tensor_member_fn
|
|
1669
1888
|
@builtin
|
|
1670
|
-
def view(input, *shape,
|
|
1889
|
+
def view(input, *shape, _semantic=None):
|
|
1671
1890
|
"""
|
|
1672
1891
|
Returns a tensor with the same elements as `input` but a different shape.
|
|
1673
1892
|
The order of the elements may not be preserved.
|
|
@@ -1684,12 +1903,21 @@ def view(input, *shape, _builder=None):
|
|
|
1684
1903
|
"""
|
|
1685
1904
|
warn("view is deprecated, please use reshape with can_reorder being true.")
|
|
1686
1905
|
shape = _shape_check_impl(_unwrap_iterable(shape))
|
|
1687
|
-
return
|
|
1906
|
+
return _semantic.reshape(input, shape, can_reorder=True)
|
|
1688
1907
|
|
|
1689
1908
|
|
|
1690
1909
|
@_tensor_member_fn
|
|
1691
1910
|
@builtin
|
|
1692
|
-
def
|
|
1911
|
+
def item(input, _semantic=None, _generator=None):
|
|
1912
|
+
"""
|
|
1913
|
+
Converts a single-element tensor into a scalar.
|
|
1914
|
+
"""
|
|
1915
|
+
return _unsplat(input, _semantic=_semantic, _generator=_generator)
|
|
1916
|
+
|
|
1917
|
+
|
|
1918
|
+
@_tensor_member_fn
|
|
1919
|
+
@builtin
|
|
1920
|
+
def reshape(input, *shape, can_reorder=False, _semantic=None, _generator=None):
|
|
1693
1921
|
"""
|
|
1694
1922
|
Returns a tensor with the same number of elements as input but with the
|
|
1695
1923
|
provided shape.
|
|
@@ -1705,7 +1933,9 @@ def reshape(input, *shape, can_reorder=False, _builder=None):
|
|
|
1705
1933
|
reshape(x, 32, 32)
|
|
1706
1934
|
"""
|
|
1707
1935
|
shape = _shape_check_impl(_unwrap_iterable(shape))
|
|
1708
|
-
|
|
1936
|
+
if len(shape) == 0:
|
|
1937
|
+
return _unsplat(input, _semantic=_semantic, _generator=_generator)
|
|
1938
|
+
return _semantic.reshape(input, shape, can_reorder)
|
|
1709
1939
|
|
|
1710
1940
|
|
|
1711
1941
|
def _wrap_axis(axis, ndim):
|
|
@@ -1717,7 +1947,7 @@ def _wrap_axis(axis, ndim):
|
|
|
1717
1947
|
|
|
1718
1948
|
@_tensor_member_fn
|
|
1719
1949
|
@builtin
|
|
1720
|
-
def expand_dims(input, axis,
|
|
1950
|
+
def expand_dims(input, axis, _semantic=None):
|
|
1721
1951
|
"""
|
|
1722
1952
|
Expand the shape of a tensor, by inserting new length-1 dimensions.
|
|
1723
1953
|
|
|
@@ -1730,24 +1960,24 @@ def expand_dims(input, axis, _builder=None):
|
|
|
1730
1960
|
:type axis: int | Sequence[int]
|
|
1731
1961
|
|
|
1732
1962
|
"""
|
|
1733
|
-
input =
|
|
1734
|
-
axis =
|
|
1963
|
+
input = _semantic.to_tensor(input)
|
|
1964
|
+
axis = _unwrap_if_constexpr(axis)
|
|
1735
1965
|
axes = list(axis) if isinstance(axis, (Sequence, tuple)) else [axis]
|
|
1736
1966
|
new_ndim = len(input.shape) + len(axes)
|
|
1737
|
-
axes = [_wrap_axis(
|
|
1967
|
+
axes = [_wrap_axis(_unwrap_if_constexpr(d), new_ndim) for d in axes]
|
|
1738
1968
|
|
|
1739
1969
|
if len(set(axes)) != len(axes):
|
|
1740
1970
|
raise ValueError(f"expand_dims received duplicate axes, normalized axes = {axes}")
|
|
1741
1971
|
|
|
1742
1972
|
ret = input
|
|
1743
1973
|
for a in sorted(axes):
|
|
1744
|
-
ret =
|
|
1974
|
+
ret = _semantic.expand_dims(ret, a)
|
|
1745
1975
|
return ret
|
|
1746
1976
|
|
|
1747
1977
|
|
|
1748
1978
|
@_tensor_member_fn
|
|
1749
1979
|
@builtin
|
|
1750
|
-
def cast(input, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: bool = False,
|
|
1980
|
+
def cast(input, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: bool = False, _semantic=None):
|
|
1751
1981
|
"""
|
|
1752
1982
|
Casts a tensor to the given :code:`dtype`.
|
|
1753
1983
|
|
|
@@ -1763,13 +1993,13 @@ def cast(input, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcas
|
|
|
1763
1993
|
:code:`dtype`, instead of being numerically casted.
|
|
1764
1994
|
:type bitcast: bool, optional
|
|
1765
1995
|
"""
|
|
1766
|
-
input =
|
|
1767
|
-
dtype =
|
|
1768
|
-
fp_downcast_rounding =
|
|
1769
|
-
bitcast =
|
|
1996
|
+
input = _semantic.to_tensor(input)
|
|
1997
|
+
dtype = _unwrap_if_constexpr(dtype)
|
|
1998
|
+
fp_downcast_rounding = _unwrap_if_constexpr(fp_downcast_rounding)
|
|
1999
|
+
bitcast = _unwrap_if_constexpr(bitcast)
|
|
1770
2000
|
if bitcast:
|
|
1771
|
-
return
|
|
1772
|
-
return
|
|
2001
|
+
return _semantic.bitcast(input, dtype)
|
|
2002
|
+
return _semantic.cast(input, dtype, fp_downcast_rounding)
|
|
1773
2003
|
|
|
1774
2004
|
|
|
1775
2005
|
# -----------------------
|
|
@@ -1779,7 +2009,7 @@ def cast(input, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcas
|
|
|
1779
2009
|
|
|
1780
2010
|
@builtin
|
|
1781
2011
|
def dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_imprecise_acc=None, out_dtype=float32,
|
|
1782
|
-
|
|
2012
|
+
_semantic=None):
|
|
1783
2013
|
"""
|
|
1784
2014
|
Returns the matrix product of two blocks.
|
|
1785
2015
|
|
|
@@ -1804,19 +2034,20 @@ def dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_i
|
|
|
1804
2034
|
"""
|
|
1805
2035
|
assert input_precision is None or allow_tf32 is None, "Only one of input_precision and allow_tf32 can be specified"
|
|
1806
2036
|
if input_precision is None:
|
|
1807
|
-
supports_tf32 =
|
|
1808
|
-
|
|
1809
|
-
|
|
2037
|
+
supports_tf32 = "tf32" in _semantic.builder.options.allowed_dot_input_precisions
|
|
2038
|
+
input_precision = knobs.language.fp32_default or ("tf32" if (supports_tf32 and
|
|
2039
|
+
(allow_tf32 or allow_tf32 is None)) else "ieee")
|
|
1810
2040
|
|
|
1811
|
-
input_precision =
|
|
1812
|
-
out_dtype =
|
|
1813
|
-
max_num_imprecise_acc =
|
|
1814
|
-
|
|
2041
|
+
input_precision = _unwrap_if_constexpr(input_precision)
|
|
2042
|
+
out_dtype = _unwrap_if_constexpr(out_dtype)
|
|
2043
|
+
max_num_imprecise_acc = _unwrap_if_constexpr(max_num_imprecise_acc)
|
|
2044
|
+
acc = _unwrap_if_constexpr(acc)
|
|
2045
|
+
return _semantic.dot(input, other, acc, input_precision, max_num_imprecise_acc, out_dtype)
|
|
1815
2046
|
|
|
1816
2047
|
|
|
1817
2048
|
@builtin
|
|
1818
|
-
def dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None, fast_math=False,
|
|
1819
|
-
|
|
2049
|
+
def dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None, fast_math=False, lhs_k_pack=True,
|
|
2050
|
+
rhs_k_pack=True, out_dtype=float32, _semantic=None):
|
|
1820
2051
|
"""
|
|
1821
2052
|
Returns the matrix product of two blocks in microscaling format.
|
|
1822
2053
|
|
|
@@ -1843,11 +2074,15 @@ def dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None,
|
|
|
1843
2074
|
:param rhs_format: format of the rhs tensor. Available formats: {:code:`e2m1`, :code:`e4m3`, :code:`e5m2`, :code:`bf16`, :code:`fp16`}.
|
|
1844
2075
|
:type rhs_format: str
|
|
1845
2076
|
:param acc: The accumulator tensor. If not None, the result is added to this tensor.
|
|
2077
|
+
:param lhs_k_pack: If false, the lhs tensor is packed into uint8 along M dimension.
|
|
2078
|
+
:type lhs_k_pack: bool, optional
|
|
2079
|
+
:param rhs_k_pack: If false, the rhs tensor is packed into uint8 along N dimension.
|
|
2080
|
+
:type rhs_k_pack: bool, optional
|
|
1846
2081
|
"""
|
|
1847
|
-
out_dtype =
|
|
2082
|
+
out_dtype = _unwrap_if_constexpr(out_dtype)
|
|
1848
2083
|
assert out_dtype == float32, "Only float32 is supported for out_dtype at the moment"
|
|
1849
|
-
return
|
|
1850
|
-
|
|
2084
|
+
return _semantic.dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc, fast_math, lhs_k_pack,
|
|
2085
|
+
rhs_k_pack, out_dtype)
|
|
1851
2086
|
|
|
1852
2087
|
|
|
1853
2088
|
# -----------------------
|
|
@@ -1857,7 +2092,7 @@ def dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None,
|
|
|
1857
2092
|
|
|
1858
2093
|
@builtin
|
|
1859
2094
|
def load(pointer, mask=None, other=None, boundary_check=(), padding_option="", cache_modifier="", eviction_policy="",
|
|
1860
|
-
volatile=False,
|
|
2095
|
+
volatile=False, _semantic=None):
|
|
1861
2096
|
"""
|
|
1862
2097
|
Return a tensor of data whose values are loaded from memory at location defined by `pointer`:
|
|
1863
2098
|
|
|
@@ -1892,8 +2127,9 @@ def load(pointer, mask=None, other=None, boundary_check=(), padding_option="", c
|
|
|
1892
2127
|
:type boundary_check: tuple of ints, optional
|
|
1893
2128
|
:param padding_option: should be one of {"", "zero", "nan"}, the padding value to use while out of bounds. "" means an undefined value.
|
|
1894
2129
|
:param cache_modifier: changes cache option in NVIDIA PTX
|
|
1895
|
-
:type cache_modifier: str, optional, should be one of {"", "ca", "cg"}, where "ca" stands for
|
|
1896
|
-
cache at all levels
|
|
2130
|
+
:type cache_modifier: str, optional, should be one of {"", ".ca", ".cg", ".cv"}, where ".ca" stands for
|
|
2131
|
+
cache at all levels, ".cg" stands for cache at global level (cache in L2 and below, not L1),
|
|
2132
|
+
and ".cv" means don’t cache and fetch again. see
|
|
1897
2133
|
`cache operator <https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#cache-operators>`_ for more details.
|
|
1898
2134
|
:param eviction_policy: changes eviction policy in NVIDIA PTX
|
|
1899
2135
|
:type eviction_policy: str, optional
|
|
@@ -1901,57 +2137,37 @@ def load(pointer, mask=None, other=None, boundary_check=(), padding_option="", c
|
|
|
1901
2137
|
:type volatile: bool, optional
|
|
1902
2138
|
"""
|
|
1903
2139
|
# `mask` and `other` can be constexpr
|
|
1904
|
-
mask =
|
|
1905
|
-
other =
|
|
2140
|
+
mask = _unwrap_if_constexpr(mask)
|
|
2141
|
+
other = _unwrap_if_constexpr(other)
|
|
1906
2142
|
if mask is not None:
|
|
1907
|
-
mask =
|
|
2143
|
+
mask = _semantic.to_tensor(mask)
|
|
1908
2144
|
if other is not None:
|
|
1909
|
-
other =
|
|
1910
|
-
padding_option =
|
|
1911
|
-
cache_modifier =
|
|
1912
|
-
eviction_policy =
|
|
1913
|
-
volatile =
|
|
1914
|
-
return
|
|
1915
|
-
|
|
1916
|
-
|
|
1917
|
-
|
|
1918
|
-
@builtin
|
|
1919
|
-
def _experimental_reinterpret_tensor_descriptor(desc_ptr, block_shape, dtype,
|
|
1920
|
-
_builder=None) -> _experimental_tensor_descriptor_base:
|
|
1921
|
-
"""
|
|
1922
|
-
Reinterpret a generic pointer as a TMA-backed tensor descriptor object.
|
|
1923
|
-
"""
|
|
1924
|
-
block_ty = block_type(_constexpr_to_value(dtype), block_shape)
|
|
1925
|
-
return semantic.reinterpret_tensor_descriptor(desc_ptr, block_ty, _builder)
|
|
2145
|
+
other = _semantic.to_tensor(other)
|
|
2146
|
+
padding_option = _unwrap_if_constexpr(padding_option)
|
|
2147
|
+
cache_modifier = _unwrap_if_constexpr(cache_modifier)
|
|
2148
|
+
eviction_policy = _unwrap_if_constexpr(eviction_policy)
|
|
2149
|
+
volatile = _unwrap_if_constexpr(volatile)
|
|
2150
|
+
return _semantic.load(pointer, mask, other, boundary_check, padding_option, cache_modifier, eviction_policy,
|
|
2151
|
+
volatile)
|
|
1926
2152
|
|
|
1927
2153
|
|
|
1928
2154
|
@builtin
|
|
1929
|
-
def
|
|
1930
|
-
|
|
1931
|
-
|
|
1932
|
-
|
|
1933
|
-
|
|
1934
|
-
This loads a tensor of data based on the descriptor and offsets.
|
|
1935
|
-
"""
|
|
1936
|
-
desc = _experimental_reinterpret_tensor_descriptor(desc_pointer, shape, dtype, _builder=_builder)
|
|
1937
|
-
return desc.load(offsets, _builder=_builder)
|
|
2155
|
+
def load_tensor_descriptor(desc: tensor_descriptor_base, offsets: Sequence[constexpr | tensor],
|
|
2156
|
+
_semantic=None) -> tensor:
|
|
2157
|
+
"""Load a block of data from a tensor descriptor."""
|
|
2158
|
+
return desc.load(offsets, _semantic=_semantic)
|
|
1938
2159
|
|
|
1939
2160
|
|
|
1940
2161
|
@builtin
|
|
1941
|
-
def
|
|
1942
|
-
|
|
1943
|
-
|
|
1944
|
-
|
|
1945
|
-
|
|
1946
|
-
This stores a tensor of data based on the descriptor and offsets.
|
|
1947
|
-
"""
|
|
1948
|
-
desc = _experimental_reinterpret_tensor_descriptor(desc_pointer, value.shape, value.dtype, _builder=_builder)
|
|
1949
|
-
return desc.store(offsets, value, _builder=_builder)
|
|
2162
|
+
def store_tensor_descriptor(desc: tensor_descriptor_base, offsets: Sequence[constexpr | tensor], value: tensor,
|
|
2163
|
+
_semantic=None) -> tensor:
|
|
2164
|
+
"""Store a block of data to a tensor descriptor."""
|
|
2165
|
+
return desc.store(offsets, value, _semantic=_semantic)
|
|
1950
2166
|
|
|
1951
2167
|
|
|
1952
2168
|
@_tensor_member_fn
|
|
1953
2169
|
@builtin
|
|
1954
|
-
def store(pointer, value, mask=None, boundary_check=(), cache_modifier="", eviction_policy="",
|
|
2170
|
+
def store(pointer, value, mask=None, boundary_check=(), cache_modifier="", eviction_policy="", _semantic=None):
|
|
1955
2171
|
"""
|
|
1956
2172
|
Store a tensor of data into memory locations defined by `pointer`.
|
|
1957
2173
|
|
|
@@ -1991,17 +2207,17 @@ def store(pointer, value, mask=None, boundary_check=(), cache_modifier="", evict
|
|
|
1991
2207
|
:type eviction_policy: str, optional, should be one of {"", "evict_first", "evict_last"}
|
|
1992
2208
|
"""
|
|
1993
2209
|
# `value` can be constexpr
|
|
1994
|
-
value =
|
|
1995
|
-
mask =
|
|
2210
|
+
value = _semantic.to_tensor(value)
|
|
2211
|
+
mask = _unwrap_if_constexpr(mask)
|
|
1996
2212
|
if mask is not None:
|
|
1997
|
-
mask =
|
|
1998
|
-
cache_modifier =
|
|
1999
|
-
eviction_policy =
|
|
2000
|
-
return
|
|
2213
|
+
mask = _semantic.to_tensor(mask)
|
|
2214
|
+
cache_modifier = _unwrap_if_constexpr(cache_modifier)
|
|
2215
|
+
eviction_policy = _unwrap_if_constexpr(eviction_policy)
|
|
2216
|
+
return _semantic.store(pointer, value, mask, boundary_check, cache_modifier, eviction_policy)
|
|
2001
2217
|
|
|
2002
2218
|
|
|
2003
2219
|
@builtin
|
|
2004
|
-
def make_block_ptr(base: tensor, shape, strides, offsets, block_shape, order,
|
|
2220
|
+
def make_block_ptr(base: tensor, shape, strides, offsets, block_shape, order, _semantic=None):
|
|
2005
2221
|
"""
|
|
2006
2222
|
Returns a pointer to a block in a parent tensor
|
|
2007
2223
|
|
|
@@ -2012,30 +2228,33 @@ def make_block_ptr(base: tensor, shape, strides, offsets, block_shape, order, _b
|
|
|
2012
2228
|
:param block_shape: The shape of the block
|
|
2013
2229
|
:param order: The order of the original data format
|
|
2014
2230
|
"""
|
|
2015
|
-
return
|
|
2231
|
+
return _semantic.make_block_ptr(base, shape, strides, offsets, block_shape, order)
|
|
2016
2232
|
|
|
2017
2233
|
|
|
2234
|
+
@must_use_result(
|
|
2235
|
+
"Note that tl.advance does not have any side effects. To move the block pointer, you need to assign the result of tl.advance to a variable."
|
|
2236
|
+
)
|
|
2018
2237
|
@_tensor_member_fn
|
|
2019
2238
|
@builtin
|
|
2020
|
-
def advance(base, offsets,
|
|
2239
|
+
def advance(base, offsets, _semantic=None):
|
|
2021
2240
|
"""
|
|
2022
2241
|
Advance a block pointer
|
|
2023
2242
|
|
|
2024
2243
|
:param base: the block pointer to advance
|
|
2025
2244
|
:param offsets: the offsets to advance, a tuple by dimension
|
|
2026
2245
|
"""
|
|
2027
|
-
return
|
|
2246
|
+
return _semantic.advance(base, offsets)
|
|
2028
2247
|
|
|
2029
2248
|
|
|
2030
2249
|
@builtin
|
|
2031
|
-
def
|
|
2250
|
+
def make_tensor_descriptor(
|
|
2032
2251
|
base: tensor,
|
|
2033
2252
|
shape: List[tensor],
|
|
2034
2253
|
strides: List[tensor],
|
|
2035
2254
|
block_shape: List[constexpr],
|
|
2036
|
-
|
|
2037
|
-
) ->
|
|
2038
|
-
"""Make
|
|
2255
|
+
_semantic=None,
|
|
2256
|
+
) -> tensor_descriptor:
|
|
2257
|
+
"""Make a tensor descriptor object
|
|
2039
2258
|
|
|
2040
2259
|
:param base: the base pointer of the tensor, must be 16-byte aligned
|
|
2041
2260
|
:param shape: A list of non-negative integers representing the tensor shape
|
|
@@ -2056,7 +2275,7 @@ def _experimental_make_tensor_descriptor(
|
|
|
2056
2275
|
|
|
2057
2276
|
@triton.jit
|
|
2058
2277
|
def inplace_abs(in_out_ptr, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr):
|
|
2059
|
-
desc = tl.
|
|
2278
|
+
desc = tl.make_tensor_descriptor(
|
|
2060
2279
|
in_out_ptr,
|
|
2061
2280
|
shape=[M, N],
|
|
2062
2281
|
strides=[N, 1],
|
|
@@ -2082,7 +2301,7 @@ def _experimental_make_tensor_descriptor(
|
|
|
2082
2301
|
inplace_abs[grid](x, M, N, M_BLOCK, N_BLOCK)
|
|
2083
2302
|
|
|
2084
2303
|
"""
|
|
2085
|
-
return
|
|
2304
|
+
return _semantic.make_tensor_descriptor(base, shape, strides, block_shape)
|
|
2086
2305
|
|
|
2087
2306
|
|
|
2088
2307
|
# -----------------------
|
|
@@ -2124,89 +2343,89 @@ def _add_atomic_docstr(name: str, has_cmp: bool = False) -> Callable[[T], T]:
|
|
|
2124
2343
|
@_tensor_member_fn
|
|
2125
2344
|
@builtin
|
|
2126
2345
|
@_add_atomic_docstr("compare-and-swap", has_cmp=True)
|
|
2127
|
-
def atomic_cas(pointer, cmp, val, sem=None, scope=None,
|
|
2128
|
-
cmp =
|
|
2129
|
-
val =
|
|
2130
|
-
sem =
|
|
2131
|
-
scope =
|
|
2132
|
-
return
|
|
2346
|
+
def atomic_cas(pointer, cmp, val, sem=None, scope=None, _semantic=None):
|
|
2347
|
+
cmp = _semantic.to_tensor(cmp)
|
|
2348
|
+
val = _semantic.to_tensor(val)
|
|
2349
|
+
sem = _unwrap_if_constexpr(sem)
|
|
2350
|
+
scope = _unwrap_if_constexpr(scope)
|
|
2351
|
+
return _semantic.atomic_cas(pointer, cmp, val, sem, scope)
|
|
2133
2352
|
|
|
2134
2353
|
|
|
2135
2354
|
@_tensor_member_fn
|
|
2136
2355
|
@builtin
|
|
2137
2356
|
@_add_atomic_docstr("exchange")
|
|
2138
|
-
def atomic_xchg(pointer, val, mask=None, sem=None, scope=None,
|
|
2139
|
-
val =
|
|
2140
|
-
sem =
|
|
2141
|
-
scope =
|
|
2142
|
-
mask =
|
|
2143
|
-
return
|
|
2357
|
+
def atomic_xchg(pointer, val, mask=None, sem=None, scope=None, _semantic=None):
|
|
2358
|
+
val = _semantic.to_tensor(val)
|
|
2359
|
+
sem = _unwrap_if_constexpr(sem)
|
|
2360
|
+
scope = _unwrap_if_constexpr(scope)
|
|
2361
|
+
mask = _unwrap_if_constexpr(mask)
|
|
2362
|
+
return _semantic.atomic_xchg(pointer, val, mask, sem, scope)
|
|
2144
2363
|
|
|
2145
2364
|
|
|
2146
2365
|
@_tensor_member_fn
|
|
2147
2366
|
@builtin
|
|
2148
2367
|
@_add_atomic_docstr("add")
|
|
2149
|
-
def atomic_add(pointer, val, mask=None, sem=None, scope=None,
|
|
2150
|
-
val =
|
|
2151
|
-
sem =
|
|
2152
|
-
scope =
|
|
2153
|
-
mask =
|
|
2154
|
-
return
|
|
2368
|
+
def atomic_add(pointer, val, mask=None, sem=None, scope=None, _semantic=None):
|
|
2369
|
+
val = _semantic.to_tensor(val)
|
|
2370
|
+
sem = _unwrap_if_constexpr(sem)
|
|
2371
|
+
scope = _unwrap_if_constexpr(scope)
|
|
2372
|
+
mask = _unwrap_if_constexpr(mask)
|
|
2373
|
+
return _semantic.atomic_add(pointer, val, mask, sem, scope)
|
|
2155
2374
|
|
|
2156
2375
|
|
|
2157
2376
|
@_tensor_member_fn
|
|
2158
2377
|
@builtin
|
|
2159
2378
|
@_add_atomic_docstr("max")
|
|
2160
|
-
def atomic_max(pointer, val, mask=None, sem=None, scope=None,
|
|
2161
|
-
val =
|
|
2162
|
-
sem =
|
|
2163
|
-
scope =
|
|
2164
|
-
mask =
|
|
2165
|
-
return
|
|
2379
|
+
def atomic_max(pointer, val, mask=None, sem=None, scope=None, _semantic=None):
|
|
2380
|
+
val = _semantic.to_tensor(val)
|
|
2381
|
+
sem = _unwrap_if_constexpr(sem)
|
|
2382
|
+
scope = _unwrap_if_constexpr(scope)
|
|
2383
|
+
mask = _unwrap_if_constexpr(mask)
|
|
2384
|
+
return _semantic.atomic_max(pointer, val, mask, sem, scope)
|
|
2166
2385
|
|
|
2167
2386
|
|
|
2168
2387
|
@_tensor_member_fn
|
|
2169
2388
|
@builtin
|
|
2170
2389
|
@_add_atomic_docstr("min")
|
|
2171
|
-
def atomic_min(pointer, val, mask=None, sem=None, scope=None,
|
|
2172
|
-
val =
|
|
2173
|
-
sem =
|
|
2174
|
-
scope =
|
|
2175
|
-
mask =
|
|
2176
|
-
return
|
|
2390
|
+
def atomic_min(pointer, val, mask=None, sem=None, scope=None, _semantic=None):
|
|
2391
|
+
val = _semantic.to_tensor(val)
|
|
2392
|
+
sem = _unwrap_if_constexpr(sem)
|
|
2393
|
+
scope = _unwrap_if_constexpr(scope)
|
|
2394
|
+
mask = _unwrap_if_constexpr(mask)
|
|
2395
|
+
return _semantic.atomic_min(pointer, val, mask, sem, scope)
|
|
2177
2396
|
|
|
2178
2397
|
|
|
2179
2398
|
@_tensor_member_fn
|
|
2180
2399
|
@builtin
|
|
2181
2400
|
@_add_atomic_docstr("logical and")
|
|
2182
|
-
def atomic_and(pointer, val, mask=None, sem=None, scope=None,
|
|
2183
|
-
val =
|
|
2184
|
-
sem =
|
|
2185
|
-
scope =
|
|
2186
|
-
mask =
|
|
2187
|
-
return
|
|
2401
|
+
def atomic_and(pointer, val, mask=None, sem=None, scope=None, _semantic=None):
|
|
2402
|
+
val = _semantic.to_tensor(val)
|
|
2403
|
+
sem = _unwrap_if_constexpr(sem)
|
|
2404
|
+
scope = _unwrap_if_constexpr(scope)
|
|
2405
|
+
mask = _unwrap_if_constexpr(mask)
|
|
2406
|
+
return _semantic.atomic_and(pointer, val, mask, sem, scope)
|
|
2188
2407
|
|
|
2189
2408
|
|
|
2190
2409
|
@_tensor_member_fn
|
|
2191
2410
|
@builtin
|
|
2192
2411
|
@_add_atomic_docstr("logical or")
|
|
2193
|
-
def atomic_or(pointer, val, mask=None, sem=None, scope=None,
|
|
2194
|
-
val =
|
|
2195
|
-
sem =
|
|
2196
|
-
scope =
|
|
2197
|
-
mask =
|
|
2198
|
-
return
|
|
2412
|
+
def atomic_or(pointer, val, mask=None, sem=None, scope=None, _semantic=None):
|
|
2413
|
+
val = _semantic.to_tensor(val)
|
|
2414
|
+
sem = _unwrap_if_constexpr(sem)
|
|
2415
|
+
scope = _unwrap_if_constexpr(scope)
|
|
2416
|
+
mask = _unwrap_if_constexpr(mask)
|
|
2417
|
+
return _semantic.atomic_or(pointer, val, mask, sem, scope)
|
|
2199
2418
|
|
|
2200
2419
|
|
|
2201
2420
|
@_tensor_member_fn
|
|
2202
2421
|
@builtin
|
|
2203
2422
|
@_add_atomic_docstr("logical xor")
|
|
2204
|
-
def atomic_xor(pointer, val, mask=None, sem=None, scope=None,
|
|
2205
|
-
val =
|
|
2206
|
-
sem =
|
|
2207
|
-
scope =
|
|
2208
|
-
mask =
|
|
2209
|
-
return
|
|
2423
|
+
def atomic_xor(pointer, val, mask=None, sem=None, scope=None, _semantic=None):
|
|
2424
|
+
val = _semantic.to_tensor(val)
|
|
2425
|
+
sem = _unwrap_if_constexpr(sem)
|
|
2426
|
+
scope = _unwrap_if_constexpr(scope)
|
|
2427
|
+
mask = _unwrap_if_constexpr(mask)
|
|
2428
|
+
return _semantic.atomic_xor(pointer, val, mask, sem, scope)
|
|
2210
2429
|
|
|
2211
2430
|
|
|
2212
2431
|
# -----------------------
|
|
@@ -2215,7 +2434,7 @@ def atomic_xor(pointer, val, mask=None, sem=None, scope=None, _builder=None):
|
|
|
2215
2434
|
|
|
2216
2435
|
|
|
2217
2436
|
@builtin
|
|
2218
|
-
def where(condition, x, y,
|
|
2437
|
+
def where(condition, x, y, _semantic=None):
|
|
2219
2438
|
"""
|
|
2220
2439
|
Returns a tensor of elements from either :code:`x` or :code:`y`, depending on :code:`condition`.
|
|
2221
2440
|
|
|
@@ -2231,10 +2450,10 @@ def where(condition, x, y, _builder=None):
|
|
|
2231
2450
|
:param x: values selected at indices where condition is True.
|
|
2232
2451
|
:param y: values selected at indices where condition is False.
|
|
2233
2452
|
"""
|
|
2234
|
-
condition =
|
|
2453
|
+
condition = _semantic.to_tensor(condition)
|
|
2235
2454
|
x = _unwrap_if_constexpr(x)
|
|
2236
2455
|
y = _unwrap_if_constexpr(y)
|
|
2237
|
-
return
|
|
2456
|
+
return _semantic.where(condition, x, y)
|
|
2238
2457
|
|
|
2239
2458
|
|
|
2240
2459
|
# -----------------------
|
|
@@ -2243,28 +2462,28 @@ def where(condition, x, y, _builder=None):
|
|
|
2243
2462
|
|
|
2244
2463
|
|
|
2245
2464
|
@builtin
|
|
2246
|
-
def add(x, y, sanitize_overflow: constexpr = True,
|
|
2465
|
+
def add(x, y, sanitize_overflow: constexpr = True, _semantic=None):
|
|
2247
2466
|
x = _unwrap_if_constexpr(x)
|
|
2248
2467
|
y = _unwrap_if_constexpr(y)
|
|
2249
|
-
return
|
|
2468
|
+
return _semantic.add(x, y, sanitize_overflow)
|
|
2250
2469
|
|
|
2251
2470
|
|
|
2252
2471
|
@builtin
|
|
2253
|
-
def sub(x, y, sanitize_overflow: constexpr = True,
|
|
2472
|
+
def sub(x, y, sanitize_overflow: constexpr = True, _semantic=None):
|
|
2254
2473
|
x = _unwrap_if_constexpr(x)
|
|
2255
2474
|
y = _unwrap_if_constexpr(y)
|
|
2256
|
-
return
|
|
2475
|
+
return _semantic.sub(x, y, sanitize_overflow)
|
|
2257
2476
|
|
|
2258
2477
|
|
|
2259
2478
|
@builtin
|
|
2260
|
-
def mul(x, y, sanitize_overflow: constexpr = True,
|
|
2479
|
+
def mul(x, y, sanitize_overflow: constexpr = True, _semantic=None):
|
|
2261
2480
|
x = _unwrap_if_constexpr(x)
|
|
2262
2481
|
y = _unwrap_if_constexpr(y)
|
|
2263
|
-
return
|
|
2482
|
+
return _semantic.mul(x, y, sanitize_overflow)
|
|
2264
2483
|
|
|
2265
2484
|
|
|
2266
2485
|
@builtin
|
|
2267
|
-
def minimum(x, y, propagate_nan: constexpr = PropagateNan.NONE,
|
|
2486
|
+
def minimum(x, y, propagate_nan: constexpr = PropagateNan.NONE, _semantic=None):
|
|
2268
2487
|
"""
|
|
2269
2488
|
Computes the element-wise minimum of :code:`x` and :code:`y`.
|
|
2270
2489
|
|
|
@@ -2277,16 +2496,16 @@ def minimum(x, y, propagate_nan: constexpr = PropagateNan.NONE, _builder=None):
|
|
|
2277
2496
|
|
|
2278
2497
|
.. seealso:: :class:`tl.PropagateNan`
|
|
2279
2498
|
"""
|
|
2280
|
-
x =
|
|
2281
|
-
y =
|
|
2282
|
-
x = _promote_bfloat16_to_float32(x,
|
|
2283
|
-
y = _promote_bfloat16_to_float32(y,
|
|
2284
|
-
propagate_nan =
|
|
2285
|
-
return
|
|
2499
|
+
x = _semantic.to_tensor(x)
|
|
2500
|
+
y = _semantic.to_tensor(y)
|
|
2501
|
+
x = _promote_bfloat16_to_float32(x, _semantic=_semantic)
|
|
2502
|
+
y = _promote_bfloat16_to_float32(y, _semantic=_semantic)
|
|
2503
|
+
propagate_nan = _unwrap_if_constexpr(propagate_nan)
|
|
2504
|
+
return _semantic.minimum(x, y, propagate_nan)
|
|
2286
2505
|
|
|
2287
2506
|
|
|
2288
2507
|
@builtin
|
|
2289
|
-
def maximum(x, y, propagate_nan: constexpr = PropagateNan.NONE,
|
|
2508
|
+
def maximum(x, y, propagate_nan: constexpr = PropagateNan.NONE, _semantic=None):
|
|
2290
2509
|
"""
|
|
2291
2510
|
Computes the element-wise maximum of :code:`x` and :code:`y`.
|
|
2292
2511
|
|
|
@@ -2299,16 +2518,16 @@ def maximum(x, y, propagate_nan: constexpr = PropagateNan.NONE, _builder=None):
|
|
|
2299
2518
|
|
|
2300
2519
|
.. seealso:: :class:`tl.PropagateNan`
|
|
2301
2520
|
"""
|
|
2302
|
-
x =
|
|
2303
|
-
y =
|
|
2304
|
-
x = _promote_bfloat16_to_float32(x,
|
|
2305
|
-
y = _promote_bfloat16_to_float32(y,
|
|
2306
|
-
propagate_nan =
|
|
2307
|
-
return
|
|
2521
|
+
x = _semantic.to_tensor(x)
|
|
2522
|
+
y = _semantic.to_tensor(y)
|
|
2523
|
+
x = _promote_bfloat16_to_float32(x, _semantic=_semantic)
|
|
2524
|
+
y = _promote_bfloat16_to_float32(y, _semantic=_semantic)
|
|
2525
|
+
propagate_nan = _unwrap_if_constexpr(propagate_nan)
|
|
2526
|
+
return _semantic.maximum(x, y, propagate_nan)
|
|
2308
2527
|
|
|
2309
2528
|
|
|
2310
2529
|
@builtin
|
|
2311
|
-
def clamp(x, min, max, propagate_nan: constexpr = PropagateNan.NONE,
|
|
2530
|
+
def clamp(x, min, max, propagate_nan: constexpr = PropagateNan.NONE, _semantic=None):
|
|
2312
2531
|
"""
|
|
2313
2532
|
Clamps the input tensor :code:`x` within the range [min, max].
|
|
2314
2533
|
Behavior when :code:`min` > :code:`max` is undefined.
|
|
@@ -2325,16 +2544,16 @@ def clamp(x, min, max, propagate_nan: constexpr = PropagateNan.NONE, _builder=No
|
|
|
2325
2544
|
|
|
2326
2545
|
.. seealso:: :class:`tl.PropagateNan`
|
|
2327
2546
|
"""
|
|
2328
|
-
x =
|
|
2329
|
-
min =
|
|
2330
|
-
max =
|
|
2331
|
-
x = _promote_bfloat16_to_float32(x,
|
|
2332
|
-
min = _promote_bfloat16_to_float32(min,
|
|
2333
|
-
max = _promote_bfloat16_to_float32(max,
|
|
2547
|
+
x = _semantic.to_tensor(x)
|
|
2548
|
+
min = _semantic.to_tensor(min)
|
|
2549
|
+
max = _semantic.to_tensor(max)
|
|
2550
|
+
x = _promote_bfloat16_to_float32(x, _semantic=_semantic)
|
|
2551
|
+
min = _promote_bfloat16_to_float32(min, _semantic=_semantic)
|
|
2552
|
+
max = _promote_bfloat16_to_float32(max, _semantic=_semantic)
|
|
2334
2553
|
|
|
2335
|
-
propagate_nan =
|
|
2554
|
+
propagate_nan = _unwrap_if_constexpr(propagate_nan)
|
|
2336
2555
|
|
|
2337
|
-
return
|
|
2556
|
+
return _semantic.clamp(x, min, max, propagate_nan)
|
|
2338
2557
|
|
|
2339
2558
|
|
|
2340
2559
|
# -----------------------
|
|
@@ -2383,7 +2602,7 @@ def _insertion_guard(builder):
|
|
|
2383
2602
|
|
|
2384
2603
|
@_tensor_member_fn
|
|
2385
2604
|
@builtin
|
|
2386
|
-
def reduce(input, axis, combine_fn, keep_dims=False,
|
|
2605
|
+
def reduce(input, axis, combine_fn, keep_dims=False, _semantic=None, _generator=None):
|
|
2387
2606
|
"""Applies the combine_fn to all elements in :code:`input` tensors along the provided :code:`axis`
|
|
2388
2607
|
|
|
2389
2608
|
:param input: the input tensor, or tuple of tensors
|
|
@@ -2397,64 +2616,65 @@ def reduce(input, axis, combine_fn, keep_dims=False, _builder=None, _generator=N
|
|
|
2397
2616
|
|
|
2398
2617
|
"""
|
|
2399
2618
|
if isinstance(input, tensor):
|
|
2400
|
-
return reduce((input, ), axis, combine_fn, keep_dims=keep_dims,
|
|
2619
|
+
return reduce((input, ), axis, combine_fn, keep_dims=keep_dims, _semantic=_semantic, _generator=_generator)[0]
|
|
2401
2620
|
|
|
2402
2621
|
def make_combine_region(reduce_op):
|
|
2403
2622
|
param_types = [t.type.scalar for t in input] * 2
|
|
2404
2623
|
region = reduce_op.get_region(0)
|
|
2405
|
-
|
|
2406
|
-
|
|
2407
|
-
|
|
2624
|
+
builder = _semantic.builder
|
|
2625
|
+
with _insertion_guard(builder):
|
|
2626
|
+
to_ir = lambda T: T.to_ir(builder)
|
|
2627
|
+
block = builder.create_block_with_parent(region, list(map(to_ir, param_types)))
|
|
2408
2628
|
args = [tensor(block.arg(i), ty) for i, ty in enumerate(param_types)]
|
|
2409
2629
|
results = _generator.call_JitFunction(combine_fn, args, kwargs={})
|
|
2410
2630
|
if isinstance(results, tensor):
|
|
2411
2631
|
handles = [results.handle]
|
|
2412
2632
|
else:
|
|
2413
2633
|
handles = [r.handle for r in results]
|
|
2414
|
-
|
|
2634
|
+
builder.create_reduce_ret(*handles)
|
|
2415
2635
|
|
|
2416
2636
|
def expand_ndims(t, ndims):
|
|
2417
2637
|
for _ in builtins.range(ndims):
|
|
2418
|
-
t = expand_dims(t, 0,
|
|
2638
|
+
t = expand_dims(t, 0, _semantic=_semantic)
|
|
2419
2639
|
return t
|
|
2420
2640
|
|
|
2421
|
-
axis =
|
|
2422
|
-
keep_dims =
|
|
2641
|
+
axis = _unwrap_if_constexpr(axis)
|
|
2642
|
+
keep_dims = _unwrap_if_constexpr(keep_dims)
|
|
2423
2643
|
if axis is not None:
|
|
2424
2644
|
axis = _wrap_axis(axis, len(input[0].shape))
|
|
2425
|
-
ret =
|
|
2645
|
+
ret = _semantic.reduction(input, axis, make_combine_region)
|
|
2426
2646
|
if keep_dims:
|
|
2427
2647
|
if axis is not None:
|
|
2428
|
-
ret = tuple(expand_dims(t, axis,
|
|
2648
|
+
ret = tuple(expand_dims(t, axis, _semantic=_semantic) for t in ret)
|
|
2429
2649
|
else:
|
|
2430
2650
|
ret = tuple(expand_ndims(t, len(input[0].shape)) for t in ret)
|
|
2431
2651
|
return ret
|
|
2432
2652
|
|
|
2433
2653
|
|
|
2434
2654
|
@builtin
|
|
2435
|
-
def _promote_bfloat16_to_float32(t,
|
|
2655
|
+
def _promote_bfloat16_to_float32(t, _semantic=None):
|
|
2436
2656
|
scalar_ty = t.type.scalar
|
|
2437
2657
|
|
|
2438
2658
|
# hardware doesn't support FMAX, FMIN, CMP for bfloat16
|
|
2439
2659
|
if scalar_ty is bfloat16:
|
|
2440
|
-
return t.to(float32,
|
|
2660
|
+
return t.to(float32, _semantic=_semantic)
|
|
2441
2661
|
return t
|
|
2442
2662
|
|
|
2443
2663
|
|
|
2444
2664
|
@builtin
|
|
2445
|
-
def _reduce_with_indices(input, axis, combine_fn, keep_dims=False,
|
|
2446
|
-
axis =
|
|
2665
|
+
def _reduce_with_indices(input, axis, combine_fn, keep_dims=False, _semantic=None, _generator=None):
|
|
2666
|
+
axis = _unwrap_if_constexpr(axis)
|
|
2447
2667
|
n = input.shape[axis]
|
|
2448
|
-
index = arange(0, n,
|
|
2668
|
+
index = arange(0, n, _semantic=_semantic)
|
|
2449
2669
|
|
|
2450
2670
|
if len(input.shape) > 1:
|
|
2451
2671
|
# Broadcast index across the non-reduced axes
|
|
2452
2672
|
axes_to_expand = [constexpr(d) for d in builtins.range(len(input.shape))]
|
|
2453
2673
|
del axes_to_expand[axis]
|
|
2454
|
-
index = expand_dims(index, axes_to_expand,
|
|
2455
|
-
index = broadcast_to(index, input.shape,
|
|
2674
|
+
index = expand_dims(index, axes_to_expand, _semantic=_semantic)
|
|
2675
|
+
index = broadcast_to(index, input.shape, _semantic=_semantic)
|
|
2456
2676
|
|
|
2457
|
-
rvalue, rindices = reduce((input, index), axis, combine_fn, keep_dims=keep_dims,
|
|
2677
|
+
rvalue, rindices = reduce((input, index), axis, combine_fn, keep_dims=keep_dims, _semantic=_semantic,
|
|
2458
2678
|
_generator=_generator)
|
|
2459
2679
|
return rvalue, rindices
|
|
2460
2680
|
|
|
@@ -2464,7 +2684,7 @@ def _reduce_with_indices(input, axis, combine_fn, keep_dims=False, _builder=None
|
|
|
2464
2684
|
# -----------------------
|
|
2465
2685
|
|
|
2466
2686
|
|
|
2467
|
-
def _add_scan_docstr(name: str) -> Callable[[T], T]:
|
|
2687
|
+
def _add_scan_docstr(name: str, dtype_arg: str = None) -> Callable[[T], T]:
|
|
2468
2688
|
|
|
2469
2689
|
def _decorator(func: T) -> T:
|
|
2470
2690
|
docstr = """
|
|
@@ -2473,7 +2693,15 @@ def _add_scan_docstr(name: str) -> Callable[[T], T]:
|
|
|
2473
2693
|
:param input: the input values
|
|
2474
2694
|
:type input: Tensor
|
|
2475
2695
|
:param axis: the dimension along which the scan should be done
|
|
2476
|
-
:type axis: int
|
|
2696
|
+
:type axis: int
|
|
2697
|
+
:param reverse: if true, the scan is performed in the reverse direction
|
|
2698
|
+
:type reverse: bool"""
|
|
2699
|
+
|
|
2700
|
+
if dtype_arg is not None:
|
|
2701
|
+
docstr += f"""
|
|
2702
|
+
:param {dtype_arg}: the desired data type of the returned tensor. If specified, the input tensor is casted to :code:`{dtype_arg}` before the operation is performed. If not specified, small integer types (< 32 bits) are upcasted to prevent overflow. Note that :code:`tl.bfloat16` inputs are automatically promoted to :code:`tl.float32`.
|
|
2703
|
+
:type {dtype_arg}: tl.dtype"""
|
|
2704
|
+
|
|
2477
2705
|
func.__doc__ = docstr.format(name=name)
|
|
2478
2706
|
return func
|
|
2479
2707
|
|
|
@@ -2482,7 +2710,7 @@ def _add_scan_docstr(name: str) -> Callable[[T], T]:
|
|
|
2482
2710
|
|
|
2483
2711
|
@_tensor_member_fn
|
|
2484
2712
|
@builtin
|
|
2485
|
-
def associative_scan(input, axis, combine_fn, reverse=False,
|
|
2713
|
+
def associative_scan(input, axis, combine_fn, reverse=False, _semantic=None, _generator=None):
|
|
2486
2714
|
"""Applies the combine_fn to each elements with a carry in :code:`input` tensors along the provided :code:`axis` and update the carry
|
|
2487
2715
|
|
|
2488
2716
|
:param input: the input tensor, or tuple of tensors
|
|
@@ -2496,46 +2724,52 @@ def associative_scan(input, axis, combine_fn, reverse=False, _builder=None, _gen
|
|
|
2496
2724
|
|
|
2497
2725
|
"""
|
|
2498
2726
|
if isinstance(input, tensor):
|
|
2499
|
-
return associative_scan((input, ), axis, combine_fn, reverse,
|
|
2727
|
+
return associative_scan((input, ), axis, combine_fn, reverse, _semantic=_semantic, _generator=_generator)[0]
|
|
2500
2728
|
|
|
2501
2729
|
def make_combine_region(scan_op):
|
|
2502
2730
|
param_types = [t.type.scalar for t in input] * 2
|
|
2503
2731
|
region = scan_op.get_region(0)
|
|
2504
|
-
|
|
2505
|
-
|
|
2506
|
-
|
|
2732
|
+
builder = _semantic.builder
|
|
2733
|
+
with _insertion_guard(builder):
|
|
2734
|
+
to_ir = lambda T: T.to_ir(builder)
|
|
2735
|
+
block = builder.create_block_with_parent(region, list(map(to_ir, param_types)))
|
|
2507
2736
|
args = [tensor(block.arg(i), ty) for i, ty in enumerate(param_types)]
|
|
2508
2737
|
results = _generator.call_JitFunction(combine_fn, args, kwargs={})
|
|
2509
2738
|
if isinstance(results, tensor):
|
|
2510
2739
|
handles = [results.handle]
|
|
2511
2740
|
else:
|
|
2512
2741
|
handles = [r.handle for r in results]
|
|
2513
|
-
|
|
2742
|
+
builder.create_scan_ret(*handles)
|
|
2514
2743
|
|
|
2515
|
-
axis =
|
|
2744
|
+
axis = _unwrap_if_constexpr(axis)
|
|
2516
2745
|
if axis is not None:
|
|
2517
2746
|
axis = _wrap_axis(axis, len(input[0].shape))
|
|
2518
|
-
return
|
|
2747
|
+
return _semantic.associative_scan(input, axis, make_combine_region, reverse)
|
|
2519
2748
|
|
|
2520
2749
|
|
|
2521
2750
|
@_tensor_member_fn
|
|
2522
2751
|
@builtin
|
|
2523
|
-
def histogram(input, num_bins,
|
|
2752
|
+
def histogram(input, num_bins, mask=None, _semantic=None, _generator=None):
|
|
2524
2753
|
"""computes an histogram based on input tensor with num_bins bins, the bins have a width of 1 and start at 0.
|
|
2525
2754
|
|
|
2526
2755
|
:param input: the input tensor
|
|
2527
2756
|
:type input: Tensor
|
|
2528
2757
|
:param num_bins: number of histogram bins
|
|
2529
2758
|
:type num_bins: int
|
|
2759
|
+
:param mask: if `mask[idx]` is false, exclude `input[idx]` from histogram
|
|
2760
|
+
:type mask: Block of `triton.int1`, optional
|
|
2530
2761
|
|
|
2531
2762
|
"""
|
|
2532
|
-
num_bins =
|
|
2533
|
-
|
|
2763
|
+
num_bins = _unwrap_if_constexpr(num_bins)
|
|
2764
|
+
mask = _unwrap_if_constexpr(mask)
|
|
2765
|
+
if mask is not None:
|
|
2766
|
+
mask = _semantic.to_tensor(mask)
|
|
2767
|
+
return _semantic.histogram(input, num_bins, mask)
|
|
2534
2768
|
|
|
2535
2769
|
|
|
2536
2770
|
@_tensor_member_fn
|
|
2537
2771
|
@builtin
|
|
2538
|
-
def gather(src, index, axis,
|
|
2772
|
+
def gather(src, index, axis, _semantic=None):
|
|
2539
2773
|
"""Gather from a tensor along a given dimension.
|
|
2540
2774
|
|
|
2541
2775
|
:param src: the source tensor
|
|
@@ -2546,8 +2780,8 @@ def gather(src, index, axis, _builder=None):
|
|
|
2546
2780
|
:type axis: int
|
|
2547
2781
|
|
|
2548
2782
|
"""
|
|
2549
|
-
axis =
|
|
2550
|
-
return
|
|
2783
|
+
axis = _unwrap_if_constexpr(axis)
|
|
2784
|
+
return _semantic.gather(src, index, axis)
|
|
2551
2785
|
|
|
2552
2786
|
|
|
2553
2787
|
# -----------------------
|
|
@@ -2556,15 +2790,15 @@ def gather(src, index, axis, _builder=None):
|
|
|
2556
2790
|
|
|
2557
2791
|
|
|
2558
2792
|
@builtin
|
|
2559
|
-
def debug_barrier(
|
|
2793
|
+
def debug_barrier(_semantic=None):
|
|
2560
2794
|
'''
|
|
2561
2795
|
Insert a barrier to synchronize all threads in a block.
|
|
2562
2796
|
'''
|
|
2563
|
-
return
|
|
2797
|
+
return _semantic.debug_barrier()
|
|
2564
2798
|
|
|
2565
2799
|
|
|
2566
2800
|
@builtin
|
|
2567
|
-
def multiple_of(input, values,
|
|
2801
|
+
def multiple_of(input, values, _semantic=None):
|
|
2568
2802
|
"""
|
|
2569
2803
|
Let the compiler know that the values in :code:`input` are all multiples of :code:`value`.
|
|
2570
2804
|
"""
|
|
@@ -2576,11 +2810,11 @@ def multiple_of(input, values, _builder=None):
|
|
|
2576
2810
|
if not isinstance(d.value, int):
|
|
2577
2811
|
raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]")
|
|
2578
2812
|
values = [x.value for x in values]
|
|
2579
|
-
return
|
|
2813
|
+
return _semantic.multiple_of(input, values)
|
|
2580
2814
|
|
|
2581
2815
|
|
|
2582
2816
|
@builtin
|
|
2583
|
-
def max_contiguous(input, values,
|
|
2817
|
+
def max_contiguous(input, values, _semantic=None):
|
|
2584
2818
|
"""
|
|
2585
2819
|
Let the compiler know that the `value` first values in :code:`input` are contiguous.
|
|
2586
2820
|
"""
|
|
@@ -2592,11 +2826,11 @@ def max_contiguous(input, values, _builder=None):
|
|
|
2592
2826
|
if not isinstance(d.value, int):
|
|
2593
2827
|
raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]")
|
|
2594
2828
|
values = [x.value for x in values]
|
|
2595
|
-
return
|
|
2829
|
+
return _semantic.max_contiguous(input, values)
|
|
2596
2830
|
|
|
2597
2831
|
|
|
2598
2832
|
@builtin
|
|
2599
|
-
def max_constancy(input, values,
|
|
2833
|
+
def max_constancy(input, values, _semantic=None):
|
|
2600
2834
|
"""
|
|
2601
2835
|
Let the compiler know that the `value` first values in :code:`input` are constant.
|
|
2602
2836
|
|
|
@@ -2611,15 +2845,15 @@ def max_constancy(input, values, _builder=None):
|
|
|
2611
2845
|
if not isinstance(d.value, int):
|
|
2612
2846
|
raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]")
|
|
2613
2847
|
values = [x.value for x in values]
|
|
2614
|
-
return
|
|
2848
|
+
return _semantic.max_constancy(input, values)
|
|
2615
2849
|
|
|
2616
2850
|
|
|
2617
2851
|
@builtin
|
|
2618
|
-
def assume(cond,
|
|
2852
|
+
def assume(cond, _semantic=None):
|
|
2619
2853
|
'''
|
|
2620
2854
|
Allow compiler to assume the :code:`cond` is True.
|
|
2621
2855
|
'''
|
|
2622
|
-
return
|
|
2856
|
+
return _semantic.assume(_semantic.to_tensor(cond))
|
|
2623
2857
|
|
|
2624
2858
|
|
|
2625
2859
|
# -----------------------
|
|
@@ -2628,7 +2862,7 @@ def assume(cond, _builder=None):
|
|
|
2628
2862
|
|
|
2629
2863
|
|
|
2630
2864
|
@builtin
|
|
2631
|
-
def static_print(*values, sep: str = " ", end: str = "\n", file=None, flush=False,
|
|
2865
|
+
def static_print(*values, sep: str = " ", end: str = "\n", file=None, flush=False, _semantic=None):
|
|
2632
2866
|
'''
|
|
2633
2867
|
Print the values at compile time. The parameters are the same as the builtin :code:`print`.
|
|
2634
2868
|
|
|
@@ -2644,7 +2878,7 @@ def static_print(*values, sep: str = " ", end: str = "\n", file=None, flush=Fals
|
|
|
2644
2878
|
|
|
2645
2879
|
|
|
2646
2880
|
@builtin
|
|
2647
|
-
def static_assert(cond, msg="",
|
|
2881
|
+
def static_assert(cond, msg="", _semantic=None):
|
|
2648
2882
|
'''
|
|
2649
2883
|
Assert the condition at compile time. Does not require that the :code:`TRITON_DEBUG` environment variable
|
|
2650
2884
|
is set.
|
|
@@ -2658,7 +2892,7 @@ def static_assert(cond, msg="", _builder=None):
|
|
|
2658
2892
|
|
|
2659
2893
|
|
|
2660
2894
|
@builtin
|
|
2661
|
-
def device_print(prefix, *args, hex=False,
|
|
2895
|
+
def device_print(prefix, *args, hex=False, _semantic=None):
|
|
2662
2896
|
'''
|
|
2663
2897
|
Print the values at runtime from the device. String formatting does not work for runtime values, so you should
|
|
2664
2898
|
provide the values you want to print as arguments. The first value must be a string, all following values must
|
|
@@ -2692,7 +2926,7 @@ def device_print(prefix, *args, hex=False, _builder=None):
|
|
|
2692
2926
|
:param hex: print all values as hex instead of decimal
|
|
2693
2927
|
'''
|
|
2694
2928
|
import string
|
|
2695
|
-
prefix =
|
|
2929
|
+
prefix = _unwrap_if_constexpr(prefix)
|
|
2696
2930
|
assert isinstance(prefix, str), f"{prefix} is not string"
|
|
2697
2931
|
b_ascii = True
|
|
2698
2932
|
for ch in prefix:
|
|
@@ -2702,12 +2936,12 @@ def device_print(prefix, *args, hex=False, _builder=None):
|
|
|
2702
2936
|
assert b_ascii, f"{prefix} is not an ascii string"
|
|
2703
2937
|
new_args = []
|
|
2704
2938
|
for arg in args:
|
|
2705
|
-
new_args.append(
|
|
2706
|
-
return
|
|
2939
|
+
new_args.append(_semantic.to_tensor(arg))
|
|
2940
|
+
return _semantic.device_print(prefix, new_args, hex)
|
|
2707
2941
|
|
|
2708
2942
|
|
|
2709
2943
|
@builtin
|
|
2710
|
-
def device_assert(cond, msg="",
|
|
2944
|
+
def device_assert(cond, msg="", _semantic=None):
|
|
2711
2945
|
'''
|
|
2712
2946
|
Assert the condition at runtime from the device. Requires that the environment variable :code:`TRITON_DEBUG`
|
|
2713
2947
|
is set to a value besides :code:`0` in order for this to have any effect.
|
|
@@ -2725,13 +2959,13 @@ def device_assert(cond, msg="", _builder=None):
|
|
|
2725
2959
|
:param cond: the condition to assert. This is required to be a boolean tensor.
|
|
2726
2960
|
:param msg: the message to print if the assertion fails. This is required to be a string literal.
|
|
2727
2961
|
'''
|
|
2728
|
-
msg =
|
|
2729
|
-
return
|
|
2962
|
+
msg = _unwrap_if_constexpr(msg)
|
|
2963
|
+
return _semantic.device_assert(_semantic.to_tensor(cond), msg)
|
|
2730
2964
|
|
|
2731
2965
|
|
|
2732
2966
|
@builtin
|
|
2733
2967
|
def inline_asm_elementwise(asm: str, constraints: str, args: Sequence, dtype: Union[dtype, Sequence[dtype]],
|
|
2734
|
-
is_pure: bool, pack: int,
|
|
2968
|
+
is_pure: bool, pack: int, _semantic=None):
|
|
2735
2969
|
'''
|
|
2736
2970
|
Execute inline assembly over a tensor. Essentially, this is :code:`map`
|
|
2737
2971
|
where the function is inline assembly.
|
|
@@ -2816,13 +3050,12 @@ def inline_asm_elementwise(asm: str, constraints: str, args: Sequence, dtype: Un
|
|
|
2816
3050
|
:param dtype: the element type(s) of the returned tensor(s)
|
|
2817
3051
|
:param is_pure: if true, the compiler assumes the asm block has no side-effects
|
|
2818
3052
|
:param pack: the number of elements to be processed by one instance of inline assembly
|
|
2819
|
-
:param _builder: the builder
|
|
2820
3053
|
:return: one tensor or a tuple of tensors of the given dtypes
|
|
2821
3054
|
'''
|
|
2822
|
-
asm =
|
|
2823
|
-
constraints =
|
|
2824
|
-
pack =
|
|
2825
|
-
is_pure =
|
|
3055
|
+
asm = _unwrap_if_constexpr(asm)
|
|
3056
|
+
constraints = _unwrap_if_constexpr(constraints)
|
|
3057
|
+
pack = _unwrap_if_constexpr(pack)
|
|
3058
|
+
is_pure = _unwrap_if_constexpr(is_pure)
|
|
2826
3059
|
|
|
2827
3060
|
# Wrap `dtype` in a tuple if it's not already.
|
|
2828
3061
|
try:
|
|
@@ -2835,10 +3068,9 @@ def inline_asm_elementwise(asm: str, constraints: str, args: Sequence, dtype: Un
|
|
|
2835
3068
|
dtype = typing.cast(Sequence[_DtypeClass], dtype)
|
|
2836
3069
|
|
|
2837
3070
|
res_tys = dtype
|
|
2838
|
-
if dispatch_args := [
|
|
3071
|
+
if dispatch_args := [_semantic.to_tensor(arg) for arg in args]:
|
|
2839
3072
|
bin_op_type_checking = partial(
|
|
2840
|
-
|
|
2841
|
-
builder=_builder,
|
|
3073
|
+
_semantic.binary_op_type_checking_impl,
|
|
2842
3074
|
arithmetic_check=False,
|
|
2843
3075
|
allow_lhs_ptr=True,
|
|
2844
3076
|
allow_rhs_ptr=True,
|
|
@@ -2851,9 +3083,10 @@ def inline_asm_elementwise(asm: str, constraints: str, args: Sequence, dtype: Un
|
|
|
2851
3083
|
# Change the shape of each argument based on the broadcast shape
|
|
2852
3084
|
for i, item in enumerate(dispatch_args):
|
|
2853
3085
|
dispatch_args[i], _ = bin_op_type_checking(item, broadcast_arg)
|
|
2854
|
-
res_tys = [
|
|
3086
|
+
res_tys = [broadcast_arg.type.with_element_ty(dt) for dt in dtype]
|
|
2855
3087
|
handles = [t.handle for t in dispatch_args]
|
|
2856
|
-
|
|
3088
|
+
builder = _semantic.builder
|
|
3089
|
+
call = builder.create_inline_asm(asm, constraints, handles, [ty.to_ir(builder) for ty in res_tys], is_pure, pack)
|
|
2857
3090
|
|
|
2858
3091
|
if not has_multiple_outputs:
|
|
2859
3092
|
return tensor(call.get_result(0), res_tys[0])
|
|
@@ -2905,6 +3138,22 @@ class static_range:
|
|
|
2905
3138
|
raise RuntimeError("static_range can only be used in @triton.jit'd functions")
|
|
2906
3139
|
|
|
2907
3140
|
|
|
3141
|
+
class async_task:
|
|
3142
|
+
"""
|
|
3143
|
+
Context manager to run code fragments asynchronously.
|
|
3144
|
+
"""
|
|
3145
|
+
|
|
3146
|
+
def __init__(self, task_ids, _builder=None):
|
|
3147
|
+
self.task_ids = list({_unwrap_if_constexpr(tid) for tid in task_ids})
|
|
3148
|
+
self.builder = _builder
|
|
3149
|
+
|
|
3150
|
+
def __enter__(self):
|
|
3151
|
+
self.builder.set_async_task_ids(self.task_ids)
|
|
3152
|
+
|
|
3153
|
+
def __exit__(self, exc_type, exc_value, traceback):
|
|
3154
|
+
self.builder.unset_async_task_ids()
|
|
3155
|
+
|
|
3156
|
+
|
|
2908
3157
|
class range:
|
|
2909
3158
|
"""
|
|
2910
3159
|
Iterator that counts upward forever.
|
|
@@ -2936,10 +3185,18 @@ class range:
|
|
|
2936
3185
|
:param flatten: automatically flatten the loop nest starting at this loop to
|
|
2937
3186
|
create a single flattened loop. The compiler will try to pipeline the
|
|
2938
3187
|
flattened loop which can avoid stage stalling.
|
|
3188
|
+
:param warp_specialize: Enable automatic warp specialization on the loop.
|
|
3189
|
+
The compiler will attempt to partition memory, MMA, and vector
|
|
3190
|
+
operations in the loop into separate async partitions. This will
|
|
3191
|
+
increase the total number of warps required by the kernel.
|
|
3192
|
+
|
|
3193
|
+
Note that warp specialization is only supported on Blackwell GPUs and
|
|
3194
|
+
only works on simple matmul loops. Support for arbitrary loops will be
|
|
3195
|
+
expanded over time.
|
|
2939
3196
|
"""
|
|
2940
3197
|
|
|
2941
3198
|
def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_factor=None,
|
|
2942
|
-
disallow_acc_multi_buffer=False, flatten=False):
|
|
3199
|
+
disallow_acc_multi_buffer=False, flatten=False, warp_specialize=False):
|
|
2943
3200
|
if step is None:
|
|
2944
3201
|
self.step = constexpr(1)
|
|
2945
3202
|
else:
|
|
@@ -2954,6 +3211,7 @@ class range:
|
|
|
2954
3211
|
self.loop_unroll_factor = loop_unroll_factor
|
|
2955
3212
|
self.disallow_acc_multi_buffer = disallow_acc_multi_buffer
|
|
2956
3213
|
self.flatten = flatten
|
|
3214
|
+
self.warp_specialize = warp_specialize
|
|
2957
3215
|
|
|
2958
3216
|
def __iter__(self):
|
|
2959
3217
|
raise RuntimeError("tl.range can only be used in @triton.jit'd functions")
|
|
@@ -2968,7 +3226,7 @@ class range:
|
|
|
2968
3226
|
|
|
2969
3227
|
|
|
2970
3228
|
def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, ret_shape: tuple,
|
|
2971
|
-
is_pure: bool,
|
|
3229
|
+
is_pure: bool, _semantic):
|
|
2972
3230
|
'''
|
|
2973
3231
|
Dispatch a function to a library
|
|
2974
3232
|
:param func: the function to dispatch
|
|
@@ -2977,7 +3235,6 @@ def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dic
|
|
|
2977
3235
|
:param args: the arguments of the function
|
|
2978
3236
|
:param arg_type_symbol_dict: the type of the arguments
|
|
2979
3237
|
:param ret_shape: the shape of the return value
|
|
2980
|
-
:param _builder: the builder
|
|
2981
3238
|
:return: the return value of the function
|
|
2982
3239
|
'''
|
|
2983
3240
|
if len(arg_type_symbol_dict) == 0:
|
|
@@ -3007,12 +3264,13 @@ def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dic
|
|
|
3007
3264
|
ret_type = arg_type_symbol_dict[arg_types][1]
|
|
3008
3265
|
if ret_shape:
|
|
3009
3266
|
ret_type = block_type(ret_type, ret_shape)
|
|
3010
|
-
|
|
3267
|
+
builder = _semantic.builder
|
|
3268
|
+
return tensor(func(lib_name, lib_path, symbol, arg_list, ret_type.to_ir(builder), is_pure), ret_type)
|
|
3011
3269
|
|
|
3012
3270
|
|
|
3013
3271
|
@builtin
|
|
3014
3272
|
def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, is_pure: bool,
|
|
3015
|
-
|
|
3273
|
+
_semantic=None):
|
|
3016
3274
|
'''
|
|
3017
3275
|
Dispatch an elementwise function to a library
|
|
3018
3276
|
:param lib_name: the name of the library
|
|
@@ -3020,7 +3278,6 @@ def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol
|
|
|
3020
3278
|
:param args: the arguments of the function
|
|
3021
3279
|
:param arg_type_symbol_dict: the type of the arguments
|
|
3022
3280
|
:param is_pure: whether the function is pure
|
|
3023
|
-
:param _builder: the builder
|
|
3024
3281
|
:return: the return value of the function
|
|
3025
3282
|
'''
|
|
3026
3283
|
dispatch_args = args.copy()
|
|
@@ -3028,7 +3285,7 @@ def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol
|
|
|
3028
3285
|
ret_shape = None
|
|
3029
3286
|
arg_types = []
|
|
3030
3287
|
for i in builtins.range(len(dispatch_args)):
|
|
3031
|
-
dispatch_args[i] =
|
|
3288
|
+
dispatch_args[i] = _semantic.to_tensor(dispatch_args[i])
|
|
3032
3289
|
arg_types.append(dispatch_args[i].dtype)
|
|
3033
3290
|
if dispatch_args[i].type.is_block():
|
|
3034
3291
|
all_scalar = False
|
|
@@ -3041,26 +3298,26 @@ def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol
|
|
|
3041
3298
|
broadcast_arg = dispatch_args[0]
|
|
3042
3299
|
# Get the broadcast shape over all the arguments
|
|
3043
3300
|
for item in dispatch_args:
|
|
3044
|
-
_, broadcast_arg =
|
|
3045
|
-
|
|
3301
|
+
_, broadcast_arg = _semantic.binary_op_type_checking_impl(item, broadcast_arg,
|
|
3302
|
+
arithmetic_check=arithmetic_check)
|
|
3046
3303
|
# Change the shape of each argument based on the broadcast shape
|
|
3047
3304
|
for i in builtins.range(len(dispatch_args)):
|
|
3048
|
-
dispatch_args[i], _ =
|
|
3049
|
-
|
|
3305
|
+
dispatch_args[i], _ = _semantic.binary_op_type_checking_impl(dispatch_args[i], broadcast_arg,
|
|
3306
|
+
arithmetic_check=arithmetic_check)
|
|
3050
3307
|
if not all_scalar:
|
|
3051
3308
|
ret_shape = broadcast_arg.shape
|
|
3052
|
-
func =
|
|
3053
|
-
return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_shape, is_pure,
|
|
3309
|
+
func = _semantic.builder.create_extern_elementwise
|
|
3310
|
+
return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_shape, is_pure, _semantic)
|
|
3054
3311
|
|
|
3055
3312
|
|
|
3056
|
-
def binary_op_type_legalization(lhs, rhs,
|
|
3313
|
+
def binary_op_type_legalization(lhs, rhs, semantic):
|
|
3057
3314
|
'''
|
|
3058
3315
|
Convert both operands to a single common type
|
|
3059
3316
|
:param lhs: the left operand
|
|
3060
3317
|
:param rhs: the right operand
|
|
3061
3318
|
:param builder: the builder
|
|
3062
3319
|
'''
|
|
3063
|
-
return semantic.binary_op_type_checking_impl(lhs, rhs
|
|
3320
|
+
return semantic.binary_op_type_checking_impl(lhs, rhs)
|
|
3064
3321
|
|
|
3065
3322
|
|
|
3066
3323
|
def extern(fn):
|