triton-windows 3.3.1.post19__cp311-cp311-win_amd64.whl → 3.5.0.post21__cp311-cp311-win_amd64.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Potentially problematic release: this version of triton-windows might be problematic.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +11 -2
- triton/_filecheck.py +97 -0
- triton/_internal_testing.py +95 -18
- triton/_utils.py +112 -21
- triton/backends/__init__.py +20 -23
- triton/backends/amd/__init__.py +0 -0
- triton/backends/amd/compiler.py +161 -119
- triton/backends/amd/driver.c +118 -46
- triton/backends/amd/driver.py +274 -96
- triton/backends/compiler.py +7 -21
- triton/backends/driver.py +13 -0
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +163 -106
- triton/backends/nvidia/driver.c +166 -101
- triton/backends/nvidia/driver.py +384 -202
- triton/compiler/__init__.py +5 -2
- triton/compiler/code_generator.py +439 -231
- triton/compiler/compiler.py +152 -84
- triton/experimental/__init__.py +0 -0
- triton/experimental/gluon/__init__.py +5 -0
- triton/experimental/gluon/_compiler.py +0 -0
- triton/experimental/gluon/_runtime.py +102 -0
- triton/experimental/gluon/language/__init__.py +119 -0
- triton/experimental/gluon/language/_core.py +490 -0
- triton/experimental/gluon/language/_layouts.py +583 -0
- triton/experimental/gluon/language/_math.py +20 -0
- triton/experimental/gluon/language/_semantic.py +380 -0
- triton/experimental/gluon/language/_standard.py +80 -0
- triton/experimental/gluon/language/amd/__init__.py +4 -0
- triton/experimental/gluon/language/amd/_layouts.py +96 -0
- triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
- triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
- triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
- triton/experimental/gluon/language/extra/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/__init__.py +4 -0
- triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
- triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
- triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
- triton/experimental/gluon/nvidia/__init__.py +4 -0
- triton/experimental/gluon/nvidia/blackwell.py +3 -0
- triton/experimental/gluon/nvidia/hopper.py +45 -0
- triton/knobs.py +546 -0
- triton/language/__init__.py +50 -19
- triton/language/core.py +909 -572
- triton/language/extra/cuda/__init__.py +10 -7
- triton/language/extra/cuda/gdc.py +42 -0
- triton/language/extra/cuda/libdevice.py +394 -394
- triton/language/extra/cuda/utils.py +21 -21
- triton/language/extra/hip/__init__.py +3 -1
- triton/language/extra/hip/libdevice.py +120 -104
- triton/language/extra/hip/utils.py +35 -0
- triton/language/extra/libdevice.py +4 -0
- triton/language/math.py +65 -66
- triton/language/random.py +12 -2
- triton/language/semantic.py +1757 -1768
- triton/language/standard.py +127 -62
- triton/language/target_info.py +54 -0
- triton/runtime/_allocation.py +15 -3
- triton/runtime/_async_compile.py +55 -0
- triton/runtime/autotuner.py +117 -60
- triton/runtime/build.py +83 -17
- triton/runtime/cache.py +61 -47
- triton/runtime/driver.py +25 -47
- triton/runtime/interpreter.py +95 -50
- triton/runtime/jit.py +445 -248
- triton/runtime/tcc/include/_mingw.h +8 -10
- triton/runtime/tcc/include/assert.h +5 -0
- triton/runtime/tcc/include/errno.h +1 -1
- triton/runtime/tcc/include/float.h +21 -3
- triton/runtime/tcc/include/iso646.h +36 -0
- triton/runtime/tcc/include/limits.h +5 -0
- triton/runtime/tcc/include/malloc.h +2 -2
- triton/runtime/tcc/include/math.h +21 -261
- triton/runtime/tcc/include/stdalign.h +16 -0
- triton/runtime/tcc/include/stdarg.h +5 -70
- triton/runtime/tcc/include/stdatomic.h +171 -0
- triton/runtime/tcc/include/stddef.h +7 -19
- triton/runtime/tcc/include/stdlib.h +15 -4
- triton/runtime/tcc/include/stdnoreturn.h +7 -0
- triton/runtime/tcc/include/sys/stat.h +2 -2
- triton/runtime/tcc/include/sys/types.h +5 -0
- triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
- triton/runtime/tcc/include/tccdefs.h +342 -0
- triton/runtime/tcc/include/tgmath.h +89 -0
- triton/runtime/tcc/include/uchar.h +33 -0
- triton/runtime/tcc/include/unistd.h +1 -0
- triton/runtime/tcc/include/winapi/qos.h +72 -0
- triton/runtime/tcc/include/winapi/shellapi.h +59 -0
- triton/runtime/tcc/include/winapi/winbase.h +9 -2
- triton/runtime/tcc/include/winapi/wincon.h +8 -0
- triton/runtime/tcc/include/winapi/windows.h +1 -1
- triton/runtime/tcc/include/winapi/winnls.h +778 -0
- triton/runtime/tcc/include/winapi/winnt.h +9 -7
- triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
- triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
- triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
- triton/runtime/tcc/lib/libtcc1.a +0 -0
- triton/runtime/tcc/lib/python314.def +1800 -0
- triton/runtime/tcc/lib/python314t.def +1809 -0
- triton/runtime/tcc/libtcc.dll +0 -0
- triton/runtime/tcc/tcc.exe +0 -0
- triton/testing.py +16 -12
- triton/tools/compile.py +62 -14
- triton/tools/disasm.py +3 -4
- triton/tools/extra/cuda/compile.c +1 -0
- triton/tools/extra/hip/compile.cpp +66 -0
- triton/tools/extra/hip/compile.h +13 -0
- triton/tools/ragged_tma.py +92 -0
- triton/tools/tensor_descriptor.py +34 -0
- triton/windows_utils.py +52 -81
- {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +8 -4
- triton_windows-3.5.0.post21.dist-info/RECORD +217 -0
- triton_windows-3.5.0.post21.dist-info/entry_points.txt +3 -0
- triton_windows-3.5.0.post21.dist-info/licenses/LICENSE +23 -0
- triton_windows-3.5.0.post21.dist-info/top_level.txt +1 -0
- triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
- triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
- triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
- triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
- triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
- triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
- triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
- triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
- triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
- triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
- triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
- triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
- triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
- triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
- triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
- triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
- triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
- triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
- triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
- triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
- triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
- triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
- triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
- triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
- triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
- triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
- triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
- triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
- triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
- triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
- triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
- triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
- triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
- triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
- triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
- triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
- triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
- triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
- triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
- triton/backends/amd/include/hip/channel_descriptor.h +0 -39
- triton/backends/amd/include/hip/device_functions.h +0 -38
- triton/backends/amd/include/hip/driver_types.h +0 -468
- triton/backends/amd/include/hip/hip_bf16.h +0 -36
- triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
- triton/backends/amd/include/hip/hip_common.h +0 -100
- triton/backends/amd/include/hip/hip_complex.h +0 -38
- triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
- triton/backends/amd/include/hip/hip_deprecated.h +0 -95
- triton/backends/amd/include/hip/hip_ext.h +0 -161
- triton/backends/amd/include/hip/hip_fp16.h +0 -36
- triton/backends/amd/include/hip/hip_fp8.h +0 -33
- triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
- triton/backends/amd/include/hip/hip_hcc.h +0 -24
- triton/backends/amd/include/hip/hip_math_constants.h +0 -36
- triton/backends/amd/include/hip/hip_profile.h +0 -27
- triton/backends/amd/include/hip/hip_runtime.h +0 -75
- triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
- triton/backends/amd/include/hip/hip_texture_types.h +0 -29
- triton/backends/amd/include/hip/hip_vector_types.h +0 -41
- triton/backends/amd/include/hip/hip_version.h +0 -17
- triton/backends/amd/include/hip/hiprtc.h +0 -421
- triton/backends/amd/include/hip/library_types.h +0 -78
- triton/backends/amd/include/hip/math_functions.h +0 -42
- triton/backends/amd/include/hip/surface_types.h +0 -63
- triton/backends/amd/include/hip/texture_types.h +0 -194
- triton/backends/amd/include/hsa/Brig.h +0 -1131
- triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
- triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
- triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
- triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
- triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
- triton/backends/amd/include/hsa/hsa.h +0 -5738
- triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
- triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
- triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
- triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
- triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
- triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
- triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
- triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
- triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
- triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
- triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
- triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
- triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
- triton/backends/amd/include/roctracer/roctracer.h +0 -779
- triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
- triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
- triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
- triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
- triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
- triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
- triton/backends/amd/include/roctracer/roctx.h +0 -229
- triton/language/_utils.py +0 -21
- triton/language/extra/cuda/_experimental_tma.py +0 -106
- triton/runtime/tcc/lib/libtcc1-64.a +0 -0
- triton/tools/experimental_descriptor.py +0 -32
- triton_windows-3.3.1.post19.dist-info/RECORD +0 -260
- triton_windows-3.3.1.post19.dist-info/top_level.txt +0 -14
- {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
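
The headline change in this diff is triton/language/core.py: the private `_builder` keyword that Triton injects into every `@builtin` function is renamed to `_semantic`, `constexpr` gains a dedicated `constexpr_type`, the `nv_tma_desc_type` pointer subclass is removed, and the tensor-descriptor classes (formerly `_experimental_tensor_descriptor*`) gain atomic read-modify-write methods. A minimal sketch of the guard the renamed wrapper still enforces, assuming nothing beyond what the diff below shows:

import triton.language as tl

# tl.arange is wrapped by @builtin; outside @triton.jit no _semantic
# object is injected, so the wrapper raises ValueError with the message
# visible in the diff ("Did you forget to add @triton.jit ? ...").
try:
    tl.arange(0, 16)
except ValueError as exc:
    print(exc)
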
triton/language/core.py
CHANGED
|
@@ -1,19 +1,20 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import math
|
|
3
4
|
from warnings import warn
|
|
4
5
|
from contextlib import contextmanager
|
|
5
6
|
from enum import Enum
|
|
6
7
|
from functools import partial, wraps
|
|
7
8
|
import typing
|
|
8
9
|
from typing import Union, Callable, List, Sequence, TypeVar, Optional, Tuple
|
|
10
|
+
from dataclasses import dataclass
|
|
9
11
|
import builtins
|
|
10
|
-
from ..
|
|
12
|
+
from .. import knobs
|
|
13
|
+
from ..runtime.jit import JITCallable
|
|
11
14
|
import inspect
|
|
12
|
-
import os
|
|
13
15
|
|
|
14
16
|
from .._C.libtriton import ir
|
|
15
|
-
from
|
|
16
|
-
from ._utils import TRITON_MAX_TENSOR_NUMEL, validate_block_shape
|
|
17
|
+
from .._utils import TRITON_MAX_TENSOR_NUMEL, validate_block_shape, get_primitive_bitwidth
|
|
17
18
|
|
|
18
19
|
T = TypeVar('T')
|
|
19
20
|
|
|
@@ -22,15 +23,23 @@ TRITON_BUILTIN = "__triton_builtin__"
|
|
|
22
23
|
PropagateNan = ir.PROPAGATE_NAN
|
|
23
24
|
|
|
24
25
|
|
|
26
|
+
def must_use_result(x, s=True):
|
|
27
|
+
"""If the result of this function is unused, throw an error."""
|
|
28
|
+
if isinstance(x, str):
|
|
29
|
+
return (lambda fn: must_use_result(fn, x))
|
|
30
|
+
x._must_use_result = s
|
|
31
|
+
return x
|
|
32
|
+
|
|
33
|
+
|
|
25
34
|
def builtin(fn: T) -> T:
|
|
26
35
|
"""Mark a function as a builtin."""
|
|
27
36
|
assert callable(fn)
|
|
28
37
|
|
|
29
38
|
@wraps(fn)
|
|
30
39
|
def wrapper(*args, **kwargs):
|
|
31
|
-
if "
|
|
40
|
+
if "_semantic" not in kwargs or kwargs["_semantic"] is None:
|
|
32
41
|
raise ValueError("Did you forget to add @triton.jit ? "
|
|
33
|
-
"(`
|
|
42
|
+
"(`_semantic` argument must be provided outside of JIT functions.)")
|
|
34
43
|
return fn(*args, **kwargs)
|
|
35
44
|
|
|
36
45
|
setattr(wrapper, TRITON_BUILTIN, True)
|
|
@@ -53,8 +62,8 @@ def _tensor_member_fn(fn: T) -> T:
|
|
|
53
62
|
"""
|
|
54
63
|
assert callable(fn)
|
|
55
64
|
orig_sig = inspect.signature(fn)
|
|
56
|
-
# Does fn take args other than
|
|
57
|
-
has_args = len(orig_sig.parameters.keys() - {"
|
|
65
|
+
# Does fn take args other than _semantic, _generator, and the tensor itself?
|
|
66
|
+
has_args = len(orig_sig.parameters.keys() - {"_semantic", "_generator"}) > 1
|
|
58
67
|
|
|
59
68
|
if not fn.__doc__:
|
|
60
69
|
fn.__doc__ = ""
|
|
@@ -78,7 +87,7 @@ def _tensor_member_fn(fn: T) -> T:
|
|
|
78
87
|
if is_builtin(fn):
|
|
79
88
|
setattr(wrapper, TRITON_BUILTIN, True)
|
|
80
89
|
|
|
81
|
-
setattr(tensor, fn.__name__, wrapper)
|
|
90
|
+
setattr(tensor, fn.__name__, fn if isinstance(fn, JITCallable) else wrapper)
|
|
82
91
|
return fn
|
|
83
92
|
|
|
84
93
|
|
|
@@ -110,8 +119,8 @@ def is_builtin(fn) -> bool:
|
|
|
110
119
|
|
|
111
120
|
|
|
112
121
|
@builtin
|
|
113
|
-
def to_tensor(x,
|
|
114
|
-
return
|
|
122
|
+
def to_tensor(x, _semantic=None):
|
|
123
|
+
return _semantic.to_tensor(x)
|
|
115
124
|
|
|
116
125
|
|
|
117
126
|
# -----------------------
|
|
@@ -130,90 +139,153 @@ class const:
|
|
|
130
139
|
pass
|
|
131
140
|
|
|
132
141
|
|
|
133
|
-
class
|
|
142
|
+
class base_value:
|
|
143
|
+
"""Base class of values that exist in the triton IR (i.e. not constexprs).
|
|
144
|
+
"""
|
|
145
|
+
type: base_type
|
|
146
|
+
|
|
147
|
+
def _flatten_ir(self, handles: List[ir.value]) -> None:
|
|
148
|
+
"""Flatten frontend value into a sequence of mlir handles, which are appended
|
|
149
|
+
to the output list
|
|
150
|
+
"""
|
|
151
|
+
raise NotImplementedError
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
class base_type:
|
|
155
|
+
|
|
156
|
+
def __eq__(self, other) -> bool:
|
|
157
|
+
raise NotImplementedError("Types must implement __eq__")
|
|
158
|
+
|
|
159
|
+
def __ne__(self, other) -> bool:
|
|
160
|
+
return not (self == other)
|
|
161
|
+
|
|
162
|
+
def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[base_value, int]:
|
|
163
|
+
"""Build a frontend value with the current dtype, wrapping a list of existing handles.
|
|
164
|
+
cursor is the index of the first handle relevant to this value, and the function
|
|
165
|
+
should return the updated cursor position after any handles consumed by the created value.
|
|
166
|
+
"""
|
|
167
|
+
raise NotImplementedError
|
|
168
|
+
|
|
169
|
+
def mangle(self) -> str:
|
|
170
|
+
raise NotImplementedError(f"NYI: Type mangling for type {self.__class__}")
|
|
171
|
+
|
|
172
|
+
def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
|
|
173
|
+
raise NotImplementedError
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
class constexpr_type(base_type):
|
|
177
|
+
|
|
178
|
+
def __init__(self, value):
|
|
179
|
+
self.value = value
|
|
180
|
+
|
|
181
|
+
def __eq__(self, other):
|
|
182
|
+
return isinstance(other, constexpr_type) and self.value == other.value
|
|
183
|
+
|
|
184
|
+
def __repr__(self) -> str:
|
|
185
|
+
return f"constexpr_type[{self.value}]"
|
|
186
|
+
|
|
187
|
+
def __hash__(self):
|
|
188
|
+
return hash(self.value)
|
|
189
|
+
|
|
190
|
+
def mangle(self) -> str:
|
|
191
|
+
return repr(self)
|
|
192
|
+
|
|
193
|
+
def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
|
|
194
|
+
return
|
|
195
|
+
|
|
196
|
+
def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[base_value, int]:
|
|
197
|
+
return constexpr(self.value), cursor
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
class constexpr(base_value):
|
|
134
201
|
"""
|
|
135
202
|
This class is used to store a value that is known at compile-time.
|
|
136
203
|
"""
|
|
137
204
|
|
|
138
205
|
def __init__(self, value):
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
self.type = constexpr
|
|
206
|
+
while isinstance(value, constexpr):
|
|
207
|
+
value = value.value
|
|
208
|
+
self.value = value
|
|
209
|
+
self.type = constexpr_type(value)
|
|
144
210
|
|
|
145
211
|
def __repr__(self) -> str:
|
|
146
212
|
return f"constexpr[{self.value}]"
|
|
147
213
|
|
|
214
|
+
def __hash__(self):
|
|
215
|
+
return hash((self.value, self.type))
|
|
216
|
+
|
|
217
|
+
def _flatten_ir(self, handles: List[ir.value]) -> None:
|
|
218
|
+
return
|
|
219
|
+
|
|
148
220
|
def __index__(self):
|
|
149
221
|
return self.value
|
|
150
222
|
|
|
151
223
|
# In interpreter mode, constant values are not wrapped in constexpr,
|
|
152
224
|
# and therefore do not have a .value attribute.
|
|
153
|
-
# As a result, from here and below, we need to call the
|
|
225
|
+
# As a result, from here and below, we need to call the _unwrap_if_constexpr
|
|
154
226
|
# function to obtain either constexpr.value or the value itself.
|
|
155
227
|
def __add__(self, other):
|
|
156
|
-
return constexpr(self.value +
|
|
228
|
+
return constexpr(self.value + _unwrap_if_constexpr(other))
|
|
157
229
|
|
|
158
230
|
def __radd__(self, other):
|
|
159
|
-
return constexpr(
|
|
231
|
+
return constexpr(_unwrap_if_constexpr(other) + self.value)
|
|
160
232
|
|
|
161
233
|
def __sub__(self, other):
|
|
162
|
-
return constexpr(self.value -
|
|
234
|
+
return constexpr(self.value - _unwrap_if_constexpr(other))
|
|
163
235
|
|
|
164
236
|
def __rsub__(self, other):
|
|
165
|
-
return constexpr(
|
|
237
|
+
return constexpr(_unwrap_if_constexpr(other) - self.value)
|
|
166
238
|
|
|
167
239
|
def __mul__(self, other):
|
|
168
|
-
return constexpr(self.value *
|
|
240
|
+
return constexpr(self.value * _unwrap_if_constexpr(other))
|
|
169
241
|
|
|
170
242
|
def __mod__(self, other):
|
|
171
|
-
return constexpr(self.value %
|
|
243
|
+
return constexpr(self.value % _unwrap_if_constexpr(other))
|
|
172
244
|
|
|
173
245
|
def __rmul__(self, other):
|
|
174
|
-
return constexpr(
|
|
246
|
+
return constexpr(_unwrap_if_constexpr(other) * self.value)
|
|
175
247
|
|
|
176
248
|
def __truediv__(self, other):
|
|
177
|
-
return constexpr(self.value /
|
|
249
|
+
return constexpr(self.value / _unwrap_if_constexpr(other))
|
|
178
250
|
|
|
179
251
|
def __rtruediv__(self, other):
|
|
180
|
-
return constexpr(
|
|
252
|
+
return constexpr(_unwrap_if_constexpr(other) / self.value)
|
|
181
253
|
|
|
182
254
|
def __floordiv__(self, other):
|
|
183
|
-
return constexpr(self.value //
|
|
255
|
+
return constexpr(self.value // _unwrap_if_constexpr(other))
|
|
184
256
|
|
|
185
257
|
def __rfloordiv__(self, other):
|
|
186
|
-
return constexpr(
|
|
258
|
+
return constexpr(_unwrap_if_constexpr(other) // self.value)
|
|
187
259
|
|
|
188
260
|
def __gt__(self, other):
|
|
189
|
-
return constexpr(self.value >
|
|
261
|
+
return constexpr(self.value > _unwrap_if_constexpr(other))
|
|
190
262
|
|
|
191
263
|
def __rgt__(self, other):
|
|
192
|
-
return constexpr(
|
|
264
|
+
return constexpr(_unwrap_if_constexpr(other) > self.value)
|
|
193
265
|
|
|
194
266
|
def __ge__(self, other):
|
|
195
|
-
return constexpr(self.value >=
|
|
267
|
+
return constexpr(self.value >= _unwrap_if_constexpr(other))
|
|
196
268
|
|
|
197
269
|
def __rge__(self, other):
|
|
198
|
-
return constexpr(
|
|
270
|
+
return constexpr(_unwrap_if_constexpr(other) >= self.value)
|
|
199
271
|
|
|
200
272
|
def __lt__(self, other):
|
|
201
|
-
return constexpr(self.value <
|
|
273
|
+
return constexpr(self.value < _unwrap_if_constexpr(other))
|
|
202
274
|
|
|
203
275
|
def __rlt__(self, other):
|
|
204
|
-
return constexpr(
|
|
276
|
+
return constexpr(_unwrap_if_constexpr(other) < self.value)
|
|
205
277
|
|
|
206
278
|
def __le__(self, other):
|
|
207
|
-
return constexpr(self.value <=
|
|
279
|
+
return constexpr(self.value <= _unwrap_if_constexpr(other))
|
|
208
280
|
|
|
209
281
|
def __rle__(self, other):
|
|
210
|
-
return constexpr(
|
|
282
|
+
return constexpr(_unwrap_if_constexpr(other) <= self.value)
|
|
211
283
|
|
|
212
284
|
def __eq__(self, other):
|
|
213
|
-
return constexpr(self.value ==
|
|
285
|
+
return constexpr(self.value == _unwrap_if_constexpr(other))
|
|
214
286
|
|
|
215
287
|
def __ne__(self, other):
|
|
216
|
-
return constexpr(self.value !=
|
|
288
|
+
return constexpr(self.value != _unwrap_if_constexpr(other))
|
|
217
289
|
|
|
218
290
|
def __bool__(self):
|
|
219
291
|
return bool(self.value)
|
|
@@ -222,19 +294,19 @@ class constexpr:
|
|
|
222
294
|
return constexpr(-self.value)
|
|
223
295
|
|
|
224
296
|
def __and__(self, other):
|
|
225
|
-
return constexpr(self.value &
|
|
297
|
+
return constexpr(self.value & _unwrap_if_constexpr(other))
|
|
226
298
|
|
|
227
299
|
def logical_and(self, other):
|
|
228
|
-
return constexpr(self.value and
|
|
300
|
+
return constexpr(self.value and _unwrap_if_constexpr(other))
|
|
229
301
|
|
|
230
302
|
def __or__(self, other):
|
|
231
|
-
return constexpr(self.value |
|
|
303
|
+
return constexpr(self.value | _unwrap_if_constexpr(other))
|
|
232
304
|
|
|
233
305
|
def __xor__(self, other):
|
|
234
|
-
return constexpr(self.value ^
|
|
306
|
+
return constexpr(self.value ^ _unwrap_if_constexpr(other))
|
|
235
307
|
|
|
236
308
|
def logical_or(self, other):
|
|
237
|
-
return constexpr(self.value or
|
|
309
|
+
return constexpr(self.value or _unwrap_if_constexpr(other))
|
|
238
310
|
|
|
239
311
|
def __pos__(self):
|
|
240
312
|
return constexpr(+self.value)
|
|
@@ -243,16 +315,16 @@ class constexpr:
|
|
|
243
315
|
return constexpr(~self.value)
|
|
244
316
|
|
|
245
317
|
def __pow__(self, other):
|
|
246
|
-
return constexpr(self.value**
|
|
318
|
+
return constexpr(self.value**_unwrap_if_constexpr(other))
|
|
247
319
|
|
|
248
320
|
def __rpow__(self, other):
|
|
249
|
-
return constexpr(
|
|
321
|
+
return constexpr(_unwrap_if_constexpr(other)**self.value)
|
|
250
322
|
|
|
251
323
|
def __rshift__(self, other):
|
|
252
|
-
return constexpr(self.value >>
|
|
324
|
+
return constexpr(self.value >> _unwrap_if_constexpr(other))
|
|
253
325
|
|
|
254
326
|
def __lshift__(self, other):
|
|
255
|
-
return constexpr(self.value <<
|
|
327
|
+
return constexpr(self.value << _unwrap_if_constexpr(other))
|
|
256
328
|
|
|
257
329
|
def __not__(self):
|
|
258
330
|
return constexpr(not self.value)
|
|
@@ -263,14 +335,31 @@ class constexpr:
|
|
|
263
335
|
def __call__(self, *args, **kwds):
|
|
264
336
|
return self.value(*args, **kwds)
|
|
265
337
|
|
|
338
|
+
def __getitem__(self, *args):
|
|
339
|
+
args = (_unwrap_if_constexpr(x) for x in _normalize_tuple(args))
|
|
340
|
+
return self.value.__getitem__(*args)
|
|
341
|
+
|
|
266
342
|
|
|
267
343
|
CONSTEXPR_0 = constexpr(0)
|
|
268
344
|
|
|
269
345
|
|
|
270
346
|
def _unwrap_if_constexpr(o):
|
|
347
|
+
if isinstance(o, list):
|
|
348
|
+
return [_unwrap_if_constexpr(x) for x in o]
|
|
349
|
+
if isinstance(o, builtins.tuple):
|
|
350
|
+
return builtins.tuple(_unwrap_if_constexpr(x) for x in o)
|
|
351
|
+
if isinstance(o, tuple):
|
|
352
|
+
return tuple(_unwrap_if_constexpr(x) for x in o)
|
|
271
353
|
return o.value if isinstance(o, constexpr) else o
|
|
272
354
|
|
|
273
355
|
|
|
356
|
+
def _normalize_tuple(t):
|
|
357
|
+
normalized_tuple = _unwrap_if_constexpr(t)
|
|
358
|
+
if isinstance(normalized_tuple, (list, builtins.tuple)):
|
|
359
|
+
normalized_tuple = tuple(normalized_tuple)
|
|
360
|
+
return normalized_tuple
|
|
361
|
+
|
|
362
|
+
|
|
274
363
|
def check_bit_width(value, shift_value):
|
|
275
364
|
if isinstance(value, tensor) and isinstance(shift_value, constexpr):
|
|
276
365
|
bitwidth = value.type.scalar.primitive_bitwidth
|
|
@@ -280,34 +369,6 @@ def check_bit_width(value, shift_value):
|
|
|
280
369
|
)
|
|
281
370
|
|
|
282
371
|
|
|
283
|
-
class base_value:
|
|
284
|
-
"""Base class of values that exist in the triton IR (i.e. not constexprs).
|
|
285
|
-
"""
|
|
286
|
-
type: base_type
|
|
287
|
-
|
|
288
|
-
def _flatten_ir(self, handles: List[ir.value]) -> None:
|
|
289
|
-
"""Flatten frontend value into a sequence of mlir handles, which are appended
|
|
290
|
-
to the output list
|
|
291
|
-
"""
|
|
292
|
-
raise NotImplementedError
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
class base_type:
|
|
296
|
-
|
|
297
|
-
def __eq__(self, other):
|
|
298
|
-
raise NotImplementedError("Types must implement __eq__")
|
|
299
|
-
|
|
300
|
-
def __ne__(self, other):
|
|
301
|
-
return not (self == other)
|
|
302
|
-
|
|
303
|
-
def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[base_value, int]:
|
|
304
|
-
"""Build a frontend value with the current dtype, wrapping a list of existing handles.
|
|
305
|
-
cursor is the index of the first handle relevant to this value, and the function
|
|
306
|
-
should return the updated cursor position after any handles consumed by the created value.
|
|
307
|
-
"""
|
|
308
|
-
raise NotImplementedError
|
|
309
|
-
|
|
310
|
-
|
|
311
372
|
# -----------------------
|
|
312
373
|
# dtype
|
|
313
374
|
# -----------------------
|
|
@@ -333,55 +394,44 @@ class dtype(base_type):
|
|
|
333
394
|
name = _unwrap_if_constexpr(name)
|
|
334
395
|
self.name = name
|
|
335
396
|
assert name in dtype.SINT_TYPES + dtype.UINT_TYPES + dtype.FP_TYPES + dtype.OTHER_TYPES, name
|
|
397
|
+
self.primitive_bitwidth = get_primitive_bitwidth(name)
|
|
398
|
+
self.itemsize = self.primitive_bitwidth // 8
|
|
336
399
|
if name in dtype.SINT_TYPES:
|
|
337
400
|
self.int_signedness = dtype.SIGNEDNESS.SIGNED
|
|
338
|
-
self.int_bitwidth =
|
|
339
|
-
self.primitive_bitwidth = self.int_bitwidth
|
|
401
|
+
self.int_bitwidth = self.primitive_bitwidth
|
|
340
402
|
elif name in dtype.UINT_TYPES:
|
|
341
403
|
self.int_signedness = dtype.SIGNEDNESS.UNSIGNED
|
|
342
|
-
self.int_bitwidth =
|
|
343
|
-
self.primitive_bitwidth = self.int_bitwidth
|
|
404
|
+
self.int_bitwidth = self.primitive_bitwidth
|
|
344
405
|
elif name in dtype.FP_TYPES:
|
|
345
406
|
if name == 'fp8e4b15':
|
|
346
407
|
self.fp_mantissa_width = 3
|
|
347
|
-
self.primitive_bitwidth = 8
|
|
348
408
|
self.exponent_bias = 15
|
|
349
409
|
elif name == 'fp8e4nv':
|
|
350
410
|
self.fp_mantissa_width = 3
|
|
351
|
-
self.primitive_bitwidth = 8
|
|
352
411
|
self.exponent_bias = 7
|
|
353
412
|
elif name == 'fp8e4b8':
|
|
354
413
|
self.fp_mantissa_width = 3
|
|
355
|
-
self.primitive_bitwidth = 8
|
|
356
414
|
self.exponent_bias = 8
|
|
357
415
|
elif name == 'fp8e5':
|
|
358
416
|
self.fp_mantissa_width = 2
|
|
359
|
-
self.primitive_bitwidth = 8
|
|
360
417
|
self.exponent_bias = 15
|
|
361
418
|
elif name == 'fp8e5b16':
|
|
362
419
|
self.fp_mantissa_width = 2
|
|
363
|
-
self.primitive_bitwidth = 8
|
|
364
420
|
self.exponent_bias = 16
|
|
365
421
|
elif name == 'fp16':
|
|
366
422
|
self.fp_mantissa_width = 10
|
|
367
|
-
self.primitive_bitwidth = 16
|
|
368
423
|
self.exponent_bias = 15
|
|
369
424
|
elif name == 'bf16':
|
|
370
425
|
self.fp_mantissa_width = 7
|
|
371
|
-
self.primitive_bitwidth = 16
|
|
372
426
|
self.exponent_bias = 127
|
|
373
427
|
elif name == 'fp32':
|
|
374
428
|
self.fp_mantissa_width = 23
|
|
375
|
-
self.primitive_bitwidth = 32
|
|
376
429
|
self.exponent_bias = 127
|
|
377
430
|
elif name == 'fp64':
|
|
378
431
|
self.fp_mantissa_width = 52
|
|
379
|
-
self.primitive_bitwidth = 64
|
|
380
432
|
self.exponent_bias = 1023
|
|
381
433
|
else:
|
|
382
434
|
raise RuntimeError(f'Unsupported floating-point type {name}')
|
|
383
|
-
elif name == 'void':
|
|
384
|
-
self.primitive_bitwidth = 0
|
|
385
435
|
|
|
386
436
|
def is_fp8(self):
|
|
387
437
|
return 'fp8' in self.name
|
|
@@ -502,11 +552,8 @@ class dtype(base_type):
|
|
|
502
552
|
def is_const():
|
|
503
553
|
return False
|
|
504
554
|
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
return False
|
|
508
|
-
|
|
509
|
-
def __eq__(self, other: dtype):
|
|
555
|
+
def __eq__(self, other) -> bool:
|
|
556
|
+
other = _unwrap_if_constexpr(other)
|
|
510
557
|
if not isinstance(other, dtype):
|
|
511
558
|
return False
|
|
512
559
|
return self.name == other.name
|
|
@@ -518,13 +565,14 @@ class dtype(base_type):
|
|
|
518
565
|
def scalar(self):
|
|
519
566
|
return self
|
|
520
567
|
|
|
568
|
+
def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
|
|
569
|
+
out.append(self.to_ir(builder))
|
|
570
|
+
|
|
521
571
|
def to_ir(self, builder: ir.builder) -> ir.type:
|
|
522
572
|
if self.name.startswith("fp8"):
|
|
523
573
|
if self.name not in builder.options.supported_fp8_dtypes:
|
|
524
574
|
raise ValueError(f'type {self} not supported in this architecture. '
|
|
525
575
|
f'The supported fp8 dtypes are {builder.options.supported_fp8_dtypes}')
|
|
526
|
-
if self.name in builder.options.deprecated_fp8_dtypes:
|
|
527
|
-
warn(f"{self.name} is deprecated in this architecture and will be removed in a future triton release")
|
|
528
576
|
|
|
529
577
|
if self.name == 'void':
|
|
530
578
|
return builder.get_void_ty()
|
|
@@ -581,6 +629,21 @@ class dtype(base_type):
|
|
|
581
629
|
def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[base_value, int]:
|
|
582
630
|
return tensor(handles[cursor], self), cursor + 1
|
|
583
631
|
|
|
632
|
+
def mangle(self) -> str:
|
|
633
|
+
if self.is_int():
|
|
634
|
+
SIGNED = dtype.SIGNEDNESS.SIGNED
|
|
635
|
+
prefix = 'i' if self.int_signedness == SIGNED else 'u'
|
|
636
|
+
return prefix + str(self.int_bitwidth)
|
|
637
|
+
if self.is_floating():
|
|
638
|
+
return str(self)
|
|
639
|
+
if self.is_void():
|
|
640
|
+
return 'V'
|
|
641
|
+
return super().mangle()
|
|
642
|
+
|
|
643
|
+
def with_element_ty(self, element_ty: dtype):
|
|
644
|
+
assert not self.is_block()
|
|
645
|
+
return element_ty
|
|
646
|
+
|
|
584
647
|
|
|
585
648
|
# Some functions have a param named `dtype`, which shadows the `dtype` class.
|
|
586
649
|
# We can't change the param name because it is part of function's public API.
|
|
@@ -614,7 +677,8 @@ class pointer_type(dtype):
|
|
|
614
677
|
def is_const(self):
|
|
615
678
|
return self.const
|
|
616
679
|
|
|
617
|
-
def __eq__(self, other
|
|
680
|
+
def __eq__(self, other) -> bool:
|
|
681
|
+
other = _unwrap_if_constexpr(other)
|
|
618
682
|
if not isinstance(other, pointer_type):
|
|
619
683
|
return False
|
|
620
684
|
return self.element_ty == other.element_ty and self.address_space == other.address_space and self.const == other.const
|
|
@@ -623,12 +687,8 @@ class pointer_type(dtype):
|
|
|
623
687
|
def scalar(self):
|
|
624
688
|
return self
|
|
625
689
|
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
def __init__(self, const=True, address_space=0):
|
|
630
|
-
super().__init__(uint8, const=const, address_space=address_space)
|
|
631
|
-
self.name = 'nv_tma_desc_type'
|
|
690
|
+
def mangle(self) -> str:
|
|
691
|
+
return f"P{self.element_ty.mangle()}"
|
|
632
692
|
|
|
633
693
|
|
|
634
694
|
class block_type(dtype):
|
|
@@ -660,9 +720,12 @@ class block_type(dtype):
|
|
|
660
720
|
def is_block(self):
|
|
661
721
|
return True
|
|
662
722
|
|
|
663
|
-
def get_block_shapes(self) ->
|
|
723
|
+
def get_block_shapes(self) -> Tuple[int]:
|
|
664
724
|
return self.shape
|
|
665
725
|
|
|
726
|
+
def with_element_ty(self, scalar_ty: dtype) -> block_type:
|
|
727
|
+
return block_type(scalar_ty, self.shape)
|
|
728
|
+
|
|
666
729
|
def __eq__(self, other) -> bool:
|
|
667
730
|
if not isinstance(other, block_type):
|
|
668
731
|
return False
|
|
@@ -672,6 +735,15 @@ class block_type(dtype):
|
|
|
672
735
|
def scalar(self):
|
|
673
736
|
return self.element_ty
|
|
674
737
|
|
|
738
|
+
@property
|
|
739
|
+
def nbytes(self):
|
|
740
|
+
return self.numel * (self.element_ty.primitive_bitwidth // 8)
|
|
741
|
+
|
|
742
|
+
def mangle(self) -> str:
|
|
743
|
+
elt = self.scalar.mangle()
|
|
744
|
+
shape = '_'.join(map(str, self.shape))
|
|
745
|
+
return f'{elt}S{shape}S'
|
|
746
|
+
|
|
675
747
|
|
|
676
748
|
class tuple_type(base_type):
|
|
677
749
|
|
|
@@ -686,15 +758,14 @@ class tuple_type(base_type):
|
|
|
686
758
|
def __iter__(self):
|
|
687
759
|
return iter(self.types)
|
|
688
760
|
|
|
689
|
-
def
|
|
690
|
-
|
|
761
|
+
def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]):
|
|
762
|
+
for ty in self.types:
|
|
763
|
+
if not isinstance(ty, constexpr):
|
|
764
|
+
ty._flatten_ir_types(builder, out)
|
|
691
765
|
|
|
692
766
|
def __getitem__(self, index: int) -> dtype:
|
|
693
767
|
return self.types[index]
|
|
694
768
|
|
|
695
|
-
def is_tuple(self):
|
|
696
|
-
return True
|
|
697
|
-
|
|
698
769
|
def __eq__(self, other):
|
|
699
770
|
return type(self) is type(other) and self.types == other.types and self.fields == other.fields
|
|
700
771
|
|
|
@@ -705,6 +776,9 @@ class tuple_type(base_type):
|
|
|
705
776
|
values.append(value)
|
|
706
777
|
return tuple(values, self), cursor
|
|
707
778
|
|
|
779
|
+
def mangle(self):
|
|
780
|
+
return 'T' + '_'.join(ty.mangle for ty in self.types) + 'T'
|
|
781
|
+
|
|
708
782
|
|
|
709
783
|
class slice_type(dtype):
|
|
710
784
|
|
|
@@ -791,10 +865,7 @@ class tensor(base_value):
|
|
|
791
865
|
self.handle = handle
|
|
792
866
|
# Block shape
|
|
793
867
|
self.shape = type.shape if type.is_block() else ()
|
|
794
|
-
self.numel =
|
|
795
|
-
for s in self.shape:
|
|
796
|
-
self.numel *= s
|
|
797
|
-
self.numel = constexpr(self.numel)
|
|
868
|
+
self.numel = constexpr(math.prod(self.shape))
|
|
798
869
|
self.type = type # Tensor type (can be block_type)
|
|
799
870
|
# Following the practice in pytorch, dtype is scalar type
|
|
800
871
|
self.dtype = type.scalar
|
|
@@ -808,224 +879,224 @@ class tensor(base_value):
|
|
|
808
879
|
return str(self.dtype) + '[' + ', '.join(str(s) for s in self.shape) + ']'
|
|
809
880
|
|
|
810
881
|
@builtin
|
|
811
|
-
def __add__(self, other,
|
|
812
|
-
return add(self, other, sanitize_overflow=True,
|
|
882
|
+
def __add__(self, other, _semantic=None):
|
|
883
|
+
return add(self, other, sanitize_overflow=True, _semantic=_semantic)
|
|
813
884
|
|
|
814
885
|
@builtin
|
|
815
|
-
def __radd__(self, other,
|
|
816
|
-
return add(other, self, sanitize_overflow=True,
|
|
886
|
+
def __radd__(self, other, _semantic=None):
|
|
887
|
+
return add(other, self, sanitize_overflow=True, _semantic=_semantic)
|
|
817
888
|
|
|
818
889
|
@builtin
|
|
819
|
-
def __sub__(self, other,
|
|
820
|
-
return sub(self, other, sanitize_overflow=True,
|
|
890
|
+
def __sub__(self, other, _semantic=None):
|
|
891
|
+
return sub(self, other, sanitize_overflow=True, _semantic=_semantic)
|
|
821
892
|
|
|
822
893
|
@builtin
|
|
823
|
-
def __rsub__(self, other,
|
|
824
|
-
return sub(other, self, sanitize_overflow=True,
|
|
894
|
+
def __rsub__(self, other, _semantic=None):
|
|
895
|
+
return sub(other, self, sanitize_overflow=True, _semantic=_semantic)
|
|
825
896
|
|
|
826
897
|
@builtin
|
|
827
|
-
def __mul__(self, other,
|
|
828
|
-
return mul(self, other, sanitize_overflow=True,
|
|
898
|
+
def __mul__(self, other, _semantic=None):
|
|
899
|
+
return mul(self, other, sanitize_overflow=True, _semantic=_semantic)
|
|
829
900
|
|
|
830
901
|
@builtin
|
|
831
|
-
def __rmul__(self, other,
|
|
832
|
-
return mul(other, self, sanitize_overflow=True,
|
|
902
|
+
def __rmul__(self, other, _semantic=None):
|
|
903
|
+
return mul(other, self, sanitize_overflow=True, _semantic=_semantic)
|
|
833
904
|
|
|
834
905
|
@builtin
|
|
835
|
-
def __truediv__(self, other,
|
|
906
|
+
def __truediv__(self, other, _semantic=None):
|
|
836
907
|
other = _unwrap_if_constexpr(other)
|
|
837
|
-
return
|
|
908
|
+
return _semantic.truediv(self, other)
|
|
838
909
|
|
|
839
910
|
@builtin
|
|
840
|
-
def __rtruediv__(self, other,
|
|
911
|
+
def __rtruediv__(self, other, _semantic=None):
|
|
841
912
|
other = _unwrap_if_constexpr(other)
|
|
842
|
-
return
|
|
913
|
+
return _semantic.truediv(other, self)
|
|
843
914
|
|
|
844
915
|
@builtin
|
|
845
|
-
def __floordiv__(self, other,
|
|
916
|
+
def __floordiv__(self, other, _semantic=None):
|
|
846
917
|
other = _unwrap_if_constexpr(other)
|
|
847
|
-
return
|
|
918
|
+
return _semantic.floordiv(self, other)
|
|
848
919
|
|
|
849
920
|
@builtin
|
|
850
|
-
def __rfloordiv__(self, other,
|
|
921
|
+
def __rfloordiv__(self, other, _semantic=None):
|
|
851
922
|
other = _unwrap_if_constexpr(other)
|
|
852
|
-
return
|
|
923
|
+
return _semantic.floordiv(other, self)
|
|
853
924
|
|
|
854
925
|
@builtin
|
|
855
|
-
def __mod__(self, other,
|
|
926
|
+
def __mod__(self, other, _semantic=None):
|
|
856
927
|
other = _unwrap_if_constexpr(other)
|
|
857
|
-
return
|
|
928
|
+
return _semantic.mod(self, other)
|
|
858
929
|
|
|
859
930
|
@builtin
|
|
860
|
-
def __rmod__(self, other,
|
|
931
|
+
def __rmod__(self, other, _semantic=None):
|
|
861
932
|
other = _unwrap_if_constexpr(other)
|
|
862
|
-
return
|
|
933
|
+
return _semantic.mod(other, self)
|
|
863
934
|
|
|
864
935
|
# unary operators
|
|
865
936
|
@builtin
|
|
866
|
-
def __neg__(self,
|
|
867
|
-
return
|
|
937
|
+
def __neg__(self, _semantic=None):
|
|
938
|
+
return _semantic.minus(self)
|
|
868
939
|
|
|
869
940
|
@builtin
|
|
870
|
-
def __invert__(self,
|
|
871
|
-
return
|
|
941
|
+
def __invert__(self, _semantic=None):
|
|
942
|
+
return _semantic.invert(self)
|
|
872
943
|
|
|
873
944
|
# bitwise operators
|
|
874
945
|
|
|
875
946
|
@builtin
|
|
876
|
-
def __and__(self, other,
|
|
947
|
+
def __and__(self, other, _semantic=None):
|
|
877
948
|
other = _unwrap_if_constexpr(other)
|
|
878
|
-
return
|
|
949
|
+
return _semantic.and_(self, other)
|
|
879
950
|
|
|
880
951
|
@builtin
|
|
881
|
-
def __rand__(self, other,
|
|
952
|
+
def __rand__(self, other, _semantic=None):
|
|
882
953
|
other = _unwrap_if_constexpr(other)
|
|
883
|
-
return
|
|
954
|
+
return _semantic.and_(other, self)
|
|
884
955
|
|
|
885
956
|
@builtin
|
|
886
|
-
def __or__(self, other,
|
|
957
|
+
def __or__(self, other, _semantic=None):
|
|
887
958
|
other = _unwrap_if_constexpr(other)
|
|
888
|
-
return
|
|
959
|
+
return _semantic.or_(self, other)
|
|
889
960
|
|
|
890
961
|
@builtin
|
|
891
|
-
def __ror__(self, other,
|
|
962
|
+
def __ror__(self, other, _semantic=None):
|
|
892
963
|
other = _unwrap_if_constexpr(other)
|
|
893
|
-
return
|
|
964
|
+
return _semantic.or_(other, self)
|
|
894
965
|
|
|
895
966
|
@builtin
|
|
896
|
-
def __xor__(self, other,
|
|
967
|
+
def __xor__(self, other, _semantic=None):
|
|
897
968
|
other = _unwrap_if_constexpr(other)
|
|
898
|
-
return
|
|
969
|
+
return _semantic.xor_(self, other)
|
|
899
970
|
|
|
900
971
|
@builtin
|
|
901
|
-
def __rxor__(self, other,
|
|
972
|
+
def __rxor__(self, other, _semantic=None):
|
|
902
973
|
other = _unwrap_if_constexpr(other)
|
|
903
|
-
return
|
|
974
|
+
return _semantic.xor_(other, self)
|
|
904
975
|
|
|
905
976
|
@builtin
|
|
906
|
-
def __lshift__(self, other,
|
|
977
|
+
def __lshift__(self, other, _semantic=None):
|
|
907
978
|
check_bit_width(self, other)
|
|
908
979
|
other = _unwrap_if_constexpr(other)
|
|
909
|
-
return
|
|
980
|
+
return _semantic.shl(self, other)
|
|
910
981
|
|
|
911
982
|
@builtin
|
|
912
|
-
def __rlshift__(self, other,
|
|
983
|
+
def __rlshift__(self, other, _semantic=None):
|
|
913
984
|
check_bit_width(other, self)
|
|
914
985
|
other = _unwrap_if_constexpr(other)
|
|
915
|
-
return
|
|
986
|
+
return _semantic.shl(other, self)
|
|
916
987
|
|
|
917
988
|
@builtin
|
|
918
|
-
def __rshift__(self, other,
|
|
989
|
+
def __rshift__(self, other, _semantic=None):
|
|
919
990
|
check_bit_width(self, other)
|
|
920
991
|
other = _unwrap_if_constexpr(other)
|
|
921
992
|
if self.dtype.is_int_signed():
|
|
922
|
-
return
|
|
993
|
+
return _semantic.ashr(self, other)
|
|
923
994
|
else:
|
|
924
|
-
return
|
|
995
|
+
return _semantic.lshr(self, other)
|
|
925
996
|
|
|
926
997
|
@builtin
|
|
927
|
-
def __rrshift__(self, other,
|
|
998
|
+
def __rrshift__(self, other, _semantic=None):
|
|
928
999
|
check_bit_width(other, self)
|
|
929
1000
|
other = _unwrap_if_constexpr(other)
|
|
930
1001
|
if self.dtype.is_int_signed():
|
|
931
|
-
return
|
|
1002
|
+
return _semantic.ashr(other, self)
|
|
932
1003
|
else:
|
|
933
|
-
return
|
|
1004
|
+
return _semantic.lshr(other, self)
|
|
934
1005
|
|
|
935
1006
|
# >
|
|
936
1007
|
@builtin
|
|
937
|
-
def __gt__(self, other,
|
|
938
|
-
other =
|
|
939
|
-
return
|
|
1008
|
+
def __gt__(self, other, _semantic=None):
|
|
1009
|
+
other = _semantic.to_tensor(other)
|
|
1010
|
+
return _semantic.greater_than(self, other)
|
|
940
1011
|
|
|
941
1012
|
@builtin
|
|
942
|
-
def __rgt__(self, other,
|
|
943
|
-
other =
|
|
944
|
-
return
|
|
1013
|
+
def __rgt__(self, other, _semantic=None):
|
|
1014
|
+
other = _semantic.to_tensor(other)
|
|
1015
|
+
return _semantic.greater_than(other, self)
|
|
945
1016
|
|
|
946
1017
|
# >=
|
|
947
1018
|
@builtin
|
|
948
|
-
def __ge__(self, other,
|
|
949
|
-
other =
|
|
950
|
-
return
|
|
1019
|
+
def __ge__(self, other, _semantic=None):
|
|
1020
|
+
other = _semantic.to_tensor(other)
|
|
1021
|
+
return _semantic.greater_equal(self, other)
|
|
951
1022
|
|
|
952
1023
|
@builtin
|
|
953
|
-
def __rge__(self, other,
|
|
954
|
-
other =
|
|
955
|
-
return
|
|
1024
|
+
def __rge__(self, other, _semantic=None):
|
|
1025
|
+
other = _semantic.to_tensor(other)
|
|
1026
|
+
return _semantic.greater_equal(other, self)
|
|
956
1027
|
|
|
957
1028
|
# <
|
|
958
1029
|
@builtin
|
|
959
|
-
def __lt__(self, other,
|
|
960
|
-
other =
|
|
961
|
-
return
|
|
1030
|
+
def __lt__(self, other, _semantic=None):
|
|
1031
|
+
other = _semantic.to_tensor(other)
|
|
1032
|
+
return _semantic.less_than(self, other)
|
|
962
1033
|
|
|
963
1034
|
@builtin
|
|
964
|
-
def __rlt__(self, other,
|
|
965
|
-
other =
|
|
966
|
-
return
|
|
1035
|
+
def __rlt__(self, other, _semantic=None):
|
|
1036
|
+
other = _semantic.to_tensor(other)
|
|
1037
|
+
return _semantic.less_than(other, self)
|
|
967
1038
|
|
|
968
1039
|
# <=
|
|
969
1040
|
@builtin
|
|
970
|
-
def __le__(self, other,
|
|
971
|
-
other =
|
|
972
|
-
return
|
|
1041
|
+
def __le__(self, other, _semantic=None):
|
|
1042
|
+
other = _semantic.to_tensor(other)
|
|
1043
|
+
return _semantic.less_equal(self, other)
|
|
973
1044
|
|
|
974
1045
|
@builtin
|
|
975
|
-
def __rle__(self, other,
|
|
976
|
-
other =
|
|
977
|
-
return
|
|
1046
|
+
def __rle__(self, other, _semantic=None):
|
|
1047
|
+
other = _semantic.to_tensor(other)
|
|
1048
|
+
return _semantic.less_equal(other, self)
|
|
978
1049
|
|
|
979
1050
|
# ==
|
|
980
1051
|
@builtin
|
|
981
|
-
def __eq__(self, other,
|
|
982
|
-
other =
|
|
983
|
-
return
|
|
1052
|
+
def __eq__(self, other, _semantic=None):
|
|
1053
|
+
other = _semantic.to_tensor(other)
|
|
1054
|
+
return _semantic.equal(self, other)
|
|
984
1055
|
|
|
985
1056
|
@builtin
|
|
986
|
-
def __req__(self, other,
|
|
987
|
-
other =
|
|
988
|
-
return
|
|
1057
|
+
def __req__(self, other, _semantic=None):
|
|
1058
|
+
other = _semantic.to_tensor(other)
|
|
1059
|
+
return _semantic.equal(other, self)
|
|
989
1060
|
|
|
990
1061
|
@builtin
|
|
991
|
-
def __ne__(self, other,
|
|
992
|
-
other =
|
|
993
|
-
return
|
|
1062
|
+
def __ne__(self, other, _semantic=None):
|
|
1063
|
+
other = _semantic.to_tensor(other)
|
|
1064
|
+
return _semantic.not_equal(self, other)
|
|
994
1065
|
|
|
995
1066
|
@builtin
|
|
996
|
-
def __rne__(self, other,
|
|
997
|
-
other =
|
|
998
|
-
return
|
|
1067
|
+
def __rne__(self, other, _semantic=None):
|
|
1068
|
+
other = _semantic.to_tensor(other)
|
|
1069
|
+
return _semantic.not_equal(other, self)
|
|
999
1070
|
|
|
1000
1071
|
@builtin
|
|
1001
|
-
def logical_and(self, other,
|
|
1002
|
-
other =
|
|
1003
|
-
return
|
|
1072
|
+
def logical_and(self, other, _semantic=None):
|
|
1073
|
+
other = _semantic.to_tensor(other)
|
|
1074
|
+
return _semantic.logical_and(self, other)
|
|
1004
1075
|
|
|
1005
1076
|
@builtin
|
|
1006
|
-
def logical_or(self, other,
|
|
1007
|
-
other =
|
|
1008
|
-
return
|
|
1077
|
+
def logical_or(self, other, _semantic=None):
|
|
1078
|
+
other = _semantic.to_tensor(other)
|
|
1079
|
+
return _semantic.logical_or(self, other)
|
|
1009
1080
|
|
|
1010
1081
|
# note: __not__ isn't actually a magic method in python
|
|
1011
1082
|
# but it's ok because our ASTVisitor handles it
|
|
1012
1083
|
@builtin
|
|
1013
|
-
def __not__(self,
|
|
1014
|
-
return
|
|
1084
|
+
def __not__(self, _semantic=None):
|
|
1085
|
+
return _semantic.not_(self)
|
|
1015
1086
|
|
|
1016
1087
|
@builtin
|
|
1017
|
-
def __getitem__(self, slices,
|
|
1018
|
-
import builtins
|
|
1088
|
+
def __getitem__(self, slices, _semantic=None):
|
|
1019
1089
|
if isinstance(slices, (builtins.slice, slice, constexpr)) or slices is None:
|
|
1020
1090
|
slices = [slices]
|
|
1021
1091
|
if isinstance(slices, tuple):
|
|
1022
1092
|
slices = slices.values
|
|
1023
1093
|
ret = self
|
|
1024
1094
|
for dim, sl in enumerate(slices):
|
|
1025
|
-
if
|
|
1026
|
-
ret =
|
|
1027
|
-
elif isinstance(sl, (builtins.slice, slice)) and
|
|
1028
|
-
|
|
1095
|
+
if _unwrap_if_constexpr(sl) is None:
|
|
1096
|
+
ret = _semantic.expand_dims(ret, dim)
|
|
1097
|
+
elif isinstance(sl, (builtins.slice, slice)) and all(
|
|
1098
|
+
_unwrap_if_constexpr(arg) is None for arg in (sl.start, sl.stop, sl.step)):
|
|
1099
|
+
pass # an unsqueeze
|
|
1029
1100
|
else:
|
|
1030
1101
|
raise ValueError(f"unsupported tensor index: {sl}")
|
|
1031
1102
|
return ret
|
|
@@ -1036,11 +1107,11 @@ class tensor(base_value):
|
|
|
1036
1107
|
assert False, "Transposition must be created by the AST Visitor"
|
|
1037
1108
|
|
|
1038
1109
|
@builtin
|
|
1039
|
-
def to(self, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: bool = False,
|
|
1110
|
+
def to(self, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: bool = False, _semantic=None):
|
|
1040
1111
|
"""
|
|
1041
1112
|
Alias for :py:func:`tensor.cast`.
|
|
1042
1113
|
"""
|
|
1043
|
-
return cast(self, dtype, fp_downcast_rounding, bitcast,
|
|
1114
|
+
return cast(self, dtype, fp_downcast_rounding, bitcast, _semantic=_semantic)
|
|
1044
1115
|
|
|
1045
1116
|
# Type stubs for functions added by the _tensor_member_fn decorator.
|
|
1046
1117
|
# (Unfortunately these can't be created automatically.)
|
|
@@ -1140,7 +1211,7 @@ class tensor(base_value):
|
|
|
1140
1211
|
def sigmoid(self) -> tensor:
|
|
1141
1212
|
...
|
|
1142
1213
|
|
|
1143
|
-
def softmax(self, ieee_rounding=False) -> tensor:
|
|
1214
|
+
def softmax(self, dim=None, keep_dims=False, ieee_rounding=False) -> tensor:
|
|
1144
1215
|
...
|
|
1145
1216
|
|
|
1146
1217
|
def ravel(self) -> tensor:
|
|
@@ -1164,6 +1235,9 @@ class tensor(base_value):
|
|
|
1164
1235
|
def xor_sum(self, axis=None, keep_dims=False) -> tensor:
|
|
1165
1236
|
...
|
|
1166
1237
|
|
|
1238
|
+
def reduce_or(self, axis=None, keep_dims=False) -> tensor:
|
|
1239
|
+
...
|
|
1240
|
+
|
|
1167
1241
|
def cumsum(self, axis=0, reverse=False) -> tensor:
|
|
1168
1242
|
...
|
|
1169
1243
|
|
|
@@ -1177,19 +1251,20 @@ class tensor(base_value):
|
|
|
1177
1251
|
...
|
|
1178
1252
|
|
|
1179
1253
|
|
|
1180
|
-
|
|
1254
|
+
def _type_for_tuple_values(values, fields=None):
|
|
1255
|
+
return tuple_type([constexpr_type(x) if isinstance(x, (int, float, dtype)) else x.type for x in values], fields)
|
|
1181
1256
|
|
|
1182
|
-
def __init__(self, args: list, type: tuple_type = None):
|
|
1183
|
-
self.values = [i for i in args]
|
|
1184
1257
|
|
|
1185
|
-
|
|
1186
|
-
if isinstance(x, dtype):
|
|
1187
|
-
return dtype
|
|
1188
|
-
if isinstance(x, int):
|
|
1189
|
-
return constexpr
|
|
1190
|
-
return x.type
|
|
1258
|
+
class tuple(base_value):
|
|
1191
1259
|
|
|
1192
|
-
|
|
1260
|
+
def __init__(self, args: Sequence, type: Optional[tuple_type] = None):
|
|
1261
|
+
self.values = [i for i in args]
|
|
1262
|
+
if isinstance(type, tuple_type):
|
|
1263
|
+
self.type = type
|
|
1264
|
+
elif type is not None: # make_template in ASTFunction.deserialize may pass us a list/tuple
|
|
1265
|
+
self.type = tuple_type(type)
|
|
1266
|
+
else:
|
|
1267
|
+
self.type = _type_for_tuple_values(self.values)
|
|
1193
1268
|
|
|
1194
1269
|
def __getitem__(self, idx: constexpr):
|
|
1195
1270
|
if isinstance(idx, int):
|
|
@@ -1197,7 +1272,6 @@ class tuple(base_value):
|
|
|
1197
1272
|
if isinstance(idx, constexpr):
|
|
1198
1273
|
return self.values[idx]
|
|
1199
1274
|
else:
|
|
1200
|
-
import builtins
|
|
1201
1275
|
assert isinstance(idx, (slice, builtins.slice))
|
|
1202
1276
|
return tuple(self.values[idx.start:idx.stop:idx.step])
|
|
1203
1277
|
|
|
@@ -1205,15 +1279,14 @@ class tuple(base_value):
|
|
|
1205
1279
|
return self.values[self.type.fields.index(name)]
|
|
1206
1280
|
|
|
1207
1281
|
# TODO: remove
|
|
1208
|
-
def
|
|
1209
|
-
|
|
1210
|
-
|
|
1211
|
-
assert isinstance(idx, constexpr)
|
|
1282
|
+
def _setitem(self, idx, value):
|
|
1283
|
+
idx = _unwrap_if_constexpr(idx)
|
|
1284
|
+
assert isinstance(idx, int)
|
|
1212
1285
|
self.values[idx] = value
|
|
1286
|
+
self.type = _type_for_tuple_values(self.values, self.type.fields)
|
|
1213
1287
|
|
|
1214
1288
|
def __add__(self, other):
|
|
1215
|
-
|
|
1216
|
-
other = tuple(other)
|
|
1289
|
+
other = _normalize_tuple(other)
|
|
1217
1290
|
return tuple(self.values + other.values)
|
|
1218
1291
|
# return tuple(a + b for a, b in zip(self.values, other.values))
|
|
1219
1292
|
|
|
@@ -1222,13 +1295,10 @@ class tuple(base_value):
|
|
|
1222
1295
|
return tuple(self.values * other.value)
|
|
1223
1296
|
|
|
1224
1297
|
def __eq__(self, other):
|
|
1225
|
-
|
|
1226
|
-
if isinstance(other, (list, builtins.tuple)):
|
|
1227
|
-
other = tuple(other)
|
|
1298
|
+
other = _normalize_tuple(other)
|
|
1228
1299
|
return constexpr(self.values == other.values)
|
|
1229
1300
|
|
|
1230
1301
|
def __hash__(self):
|
|
1231
|
-
import builtins
|
|
1232
1302
|
return hash(builtins.tuple(self.values))
|
|
1233
1303
|
|
|
1234
1304
|
def __str__(self):
|
|
@@ -1244,6 +1314,9 @@ class tuple(base_value):
|
|
|
1244
1314
|
for v in self.values:
|
|
1245
1315
|
v._flatten_ir(handles)
|
|
1246
1316
|
|
|
1317
|
+
def __repr__(self):
|
|
1318
|
+
return f"({' ,'.join(repr(x) for x in self.values)})"
|
|
1319
|
+
|
|
1247
1320
|
|
|
1248
1321
|
class slice:
|
|
1249
1322
|
|
|
@@ -1259,12 +1332,13 @@ class tensor_descriptor_base_type(base_type):
|
|
|
1259
1332
|
def __init__(self, block_type: block_type):
|
|
1260
1333
|
self.block_type = block_type
|
|
1261
1334
|
|
|
1262
|
-
def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[
|
|
1263
|
-
value =
|
|
1335
|
+
def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tensor_descriptor_base, int]:
|
|
1336
|
+
value = tensor_descriptor_base(handles[cursor], self.block_type)
|
|
1264
1337
|
return value, cursor + 1
|
|
1265
1338
|
|
|
1266
|
-
def
|
|
1267
|
-
|
|
1339
|
+
def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
|
|
1340
|
+
is_signed = self.block_type.element_ty.is_int_signed()
|
|
1341
|
+
out.append(builder.create_tensor_descriptor_type(self.block_type.to_ir(builder), is_signed))
|
|
1268
1342
|
|
|
1269
1343
|
def __str__(self) -> str:
|
|
1270
1344
|
# ex. "tensor_descriptor<float32[16, 32]>"
|
|
@@ -1278,8 +1352,11 @@ class tensor_descriptor_base_type(base_type):
|
|
|
1278
1352
|
def __neq__(self, other) -> bool:
|
|
1279
1353
|
return not (self == other)
|
|
1280
1354
|
|
|
1355
|
+
def mangle(self) -> str:
|
|
1356
|
+
return f"TD{self.block_type.mangle()}"
|
|
1357
|
+
|
|
1281
1358
|
|
|
1282
|
-
class
|
|
1359
|
+
class tensor_descriptor_base(base_value):
|
|
1283
1360
|
""""
|
|
1284
1361
|
A tensor descriptor with unknown shape and strides
|
|
1285
1362
|
"""
|
|
@@ -1310,40 +1387,64 @@ class _experimental_tensor_descriptor_base(base_value):
         return str(self.type)

     @builtin
-    def load(self, offsets: Sequence[constexpr | tensor], _builder=None) -> tensor:
+    def load(self, offsets: Sequence[constexpr | tensor], _semantic=None) -> tensor:
         """Load a block from the descriptor starting at the given element offsets.

         Values outside of the tensor bounds will be filled with zeros.

         :note: Offset must be a multiple of 16 bytes
         """
-        return
+        return _semantic.descriptor_load(self, offsets, "", "")

     @builtin
-    def store(self, offsets: Sequence[constexpr | tensor], value: tensor, _builder=None) -> tensor:
+    def store(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor:
         """Store a block from the descriptor starting at the given element offsets.

         Values outside of the tensor bounds will be ignored.

         :note: Offset must be a multiple of 16 bytes
         """
-        return
+        return _semantic.descriptor_store(self, value, offsets)
+
+    @builtin
+    def atomic_add(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor:
+        return _semantic.descriptor_atomic_add(self, value, offsets)
+
+    @builtin
+    def atomic_min(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor:
+        return _semantic.descriptor_atomic_min(self, value, offsets)
+
+    @builtin
+    def atomic_max(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor:
+        return _semantic.descriptor_atomic_max(self, value, offsets)
+
+    @builtin
+    def atomic_and(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor:
+        return _semantic.descriptor_atomic_and(self, value, offsets)

     @builtin
-    def gather(self, *args, _builder=None) -> tensor:
+    def atomic_or(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor:
+        return _semantic.descriptor_atomic_or(self, value, offsets)
+
+    @builtin
+    def atomic_xor(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor:
+        return _semantic.descriptor_atomic_xor(self, value, offsets)
+
+    @builtin
+    def gather(self, *args, _semantic=None) -> tensor:
         """Gather multiple descriptors worth of data"""
         assert len(args) == 2, f"descriptor gather only supports 2D indexing, but got {len(args)}"
         x_offsets = args[0]
         y_offset = args[1]
-        return
+        return _semantic.descriptor_gather(self, x_offsets, y_offset, "", "")

     @builtin
-    def scatter(self, value, *args, _builder=None) -> tensor:
+    def scatter(self, value, *args, _semantic=None) -> tensor:
         """Scatter multiple descriptors worth of data"""
         assert len(args) == 2, f"descriptor scatter only supports 2D indexing, but got {len(args)}"
         x_offsets = args[0]
         y_offset = args[1]
-        return
+        return _semantic.descriptor_scatter(self, value, x_offsets, y_offset)


 class tensor_descriptor_type(tensor_descriptor_base_type):
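The descriptor methods above are easiest to see in a kernel. Below is a minimal usage sketch, not part of this diff: `scale_kernel` and its arguments are invented for illustration, and on NVIDIA GPUs `tl.make_tensor_descriptor` additionally expects a host-side allocator registered with `triton.set_allocator`.

import triton
import triton.language as tl

@triton.jit
def scale_kernel(ptr, M, N, alpha, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr):
    desc = tl.make_tensor_descriptor(ptr, shape=[M, N], strides=[N, 1],
                                     block_shape=[M_BLOCK, N_BLOCK])
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # load/store take element offsets; out-of-bounds loads read zeros and
    # out-of-bounds stores are dropped, per the docstrings above.
    block = desc.load([pid_m * M_BLOCK, pid_n * N_BLOCK])
    desc.store([pid_m * M_BLOCK, pid_n * N_BLOCK], block * alpha)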
@@ -1353,25 +1454,27 @@ class tensor_descriptor_type(tensor_descriptor_base_type):
         self.shape_type = shape_type
         self.strides_type = strides_type

-    def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[
+    def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tensor_descriptor_base, int]:
         handle = handles[cursor]
         cursor += 1
         shape, cursor = self.shape_type._unflatten_ir(handles, cursor)
         strides, cursor = self.strides_type._unflatten_ir(handles, cursor)
         shape = shape.values
         strides = strides.values
-        value =
+        value = tensor_descriptor(handle, shape, strides, self.block_type)
         return value, cursor

-    def
-
+    def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
+        super()._flatten_ir_types(builder, out)
+        self.shape_type._flatten_ir_types(builder, out)
+        self.strides_type._flatten_ir_types(builder, out)

     def __eq__(self, other):
         return super().__eq__(other) and (self.shape_type == other.shape_type) and (self.strides_type
                                                                                     == other.strides_type)


-class _experimental_tensor_descriptor(_experimental_tensor_descriptor_base):
+class tensor_descriptor(tensor_descriptor_base):
     """A descriptor representing a tensor in global memory.
     """

@@ -1379,37 +1482,121 @@ class _experimental_tensor_descriptor(_experimental_tensor_descriptor_base):
         """Not called by user code."""
         # IR handle
         super().__init__(handle, block_type)
+        # Global shape
+        self.shape = tuple(shape)
+        self.strides = tuple(strides)
         self.type = tensor_descriptor_type(
             block_type,
-            shape_type=
-            strides_type=
+            shape_type=self.shape.type,
+            strides_type=self.strides.type,
         )
-        # Global shape
-        self.shape = shape
-        self.strides = strides

     def _flatten_ir(self, handles: List[ir.value]) -> None:
         handles.append(self.handle)
-
-
+        self.shape._flatten_ir(handles)
+        self.strides._flatten_ir(handles)


-
-
-
+# -----------------------
+# aggregate
+# -----------------------
+
+
+@dataclass(frozen=True)
+class _aggregate_type(base_type):
+    """A generic base type for all Triton aggregate types.
+
+    This class contains a reference to the original user-defined Python class
+    and a list of class fields with their Triton types.
+    """
+
+    base_cls: type
+    fields: List[Tuple[str, base_type]]
+
+    def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[ir.value, int]:
+        instance = self.base_cls._get_instance()
+        for name, ty in self.fields:
+            value, cursor = ty._unflatten_ir(handles, cursor)
+            setattr(instance, name, value)
+        return instance, cursor
+
+    def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
+        for name, ty in self.fields:
+            ty._flatten_ir_types(builder, out)
+
+    def mangle(self) -> str:
+        name = f"{self.base_cls.__module__}.{self.base_cls.__qualname__}"
+        fields = [ty.mangle() for (name, ty) in self.fields]
+        return f"{name}<{', '.join(fields)}>"
+
+
+def _aggregate(cls):
+
+    # Define the wrapped Triton value type.
+    class aggregate_value(base_value):
+        __triton_builtin__ = True
+        __triton_aggregate__ = True
+
+        @classmethod
+        def _get_instance(this_cls):
+            return super().__new__(this_cls)
+
+        def __new__(this_cls, *args, _semantic=None, _generator=None, **kwargs):
+            # Call into the user-defined constructor.
+            instance = this_cls._get_instance()
+            if isinstance(cls.__init__, JITCallable):
+                raise ValueError(f"{cls.__name__}.__init__ cannot be a @triton.jit function")
+            extra_kwargs = {}
+            if "_semantic" in inspect.signature(cls.__init__).parameters:
+                extra_kwargs["_semantic"] = _semantic
+            if "_generator" in inspect.signature(cls.__init__).parameters:
+                extra_kwargs["_generator"] = _generator
+            cls.__init__(instance, *args, **extra_kwargs, **kwargs)
+
+            # Require that the user-defined constructor initialized all fields.
+            for name in cls.__annotations__.keys():
+                if not hasattr(instance, name):
+                    raise AttributeError(f"constructor for {cls.__name__} did not initialize attribute '{name}'")
+
+            return instance
+
+        # Only allow setting attributes defined in the class annotations.
+        def __setattr__(self, name, value):
+            if name not in cls.__annotations__:
+                raise AttributeError(f"{cls.__name__} has no attribute '{name}'")
+            if not isinstance(value, cls.__annotations__[name]):
+                raise TypeError(f"Expected {cls.__annotations__[name]} for attribute '{name}', got {type(value)}")
+            super().__setattr__(name, value)
+
+        def _flatten_ir(self, handles: List[ir.value]) -> None:
+            for name in cls.__annotations__.keys():
+                getattr(self, name)._flatten_ir(handles)
+
+        @property
+        def type(self):
+            return _aggregate_type(aggregate_value,
+                                   [(name, getattr(self, name).type) for name in cls.__annotations__.keys()])
+
+    for (name, member) in inspect.getmembers(cls):
+        if inspect.isfunction(member) or inspect.ismethod(member) or isinstance(member, JITCallable):
+            if name != "__init__":
+                setattr(aggregate_value, name, member)
+
+    aggregate_value.__name__ = cls.__name__
+    aggregate_value.__module__ = cls.__module__
+    aggregate_value.__qualname__ = cls.__qualname__
+    aggregate_value.__doc__ = cls.__doc__
+
+    return aggregate_value
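To make the decorator concrete, here is a small sketch of how `_aggregate` appears intended to be used. It is not from this package, and since `_aggregate` is underscore-prefixed it is presumably internal API; `Pair` and its methods are invented for illustration.

import triton
import triton.language as tl
from triton.language.core import _aggregate

@_aggregate
class Pair:
    # Fields must be declared as annotations; __setattr__ above rejects others.
    first: tl.tensor
    second: tl.tensor

    # __init__ must be a plain Python function, not @triton.jit.
    def __init__(self, first, second):
        self.first = first
        self.second = second

    # Other methods may be @triton.jit; they are copied onto the wrapper class.
    @triton.jit
    def swapped(self):
        return Pair(self.second, self.first)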


 # -----------------------
 # SPMD Programming Model
 # -----------------------
-def _constexpr_to_value(v):
-    if isinstance(v, constexpr):
-        return v.value
-    return v


 @builtin
-def program_id(axis, _builder=None):
+def program_id(axis, _semantic=None):
     """
     Returns the id of the current program instance along the given :code:`axis`.

@@ -1417,26 +1604,26 @@ def program_id(axis, _builder=None):
     :type axis: int
     """
     # if axis == -1:
-    #     pid0 = program_id(0
-    #     pid1 = program_id(1
-    #     pid2 = program_id(2
-    #     npg0 = num_programs(0
-    #     npg1 = num_programs(1
+    #     pid0 = _semantic.program_id(0)
+    #     pid1 = _semantic.program_id(1)
+    #     pid2 = _semantic.program_id(2)
+    #     npg0 = _semantic.num_programs(0)
+    #     npg1 = _semantic.num_programs(1)
     #     return pid0 + pid1*npg0 + pid2*npg0*npg1
-    axis =
-    return
+    axis = _unwrap_if_constexpr(axis)
+    return _semantic.program_id(axis)


 @builtin
-def num_programs(axis, _builder=None):
+def num_programs(axis, _semantic=None):
     """
     Returns the number of program instances launched along the given :code:`axis`.

     :param axis: The axis of the 3D launch grid. Must be 0, 1 or 2.
     :type axis: int
     """
-    axis =
-    return
+    axis = _unwrap_if_constexpr(axis)
+    return _semantic.num_programs(axis)


 # -----------------------
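The commented-out recipe in `program_id` can be written directly in kernel code. A sketch, not from this package:

import triton
import triton.language as tl

@triton.jit
def _flat_pid():
    # Linearize a 3D launch grid into a single program index.
    pid0 = tl.program_id(0)
    pid1 = tl.program_id(1)
    pid2 = tl.program_id(2)
    npg0 = tl.num_programs(0)
    npg1 = tl.num_programs(1)
    return pid0 + pid1 * npg0 + pid2 * npg0 * npg1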
@@ -1445,10 +1632,10 @@ def num_programs(axis, _builder=None):


 @builtin
-def arange(start, end, _builder=None):
-    start =
-    end =
-    return
+def arange(start, end, _semantic=None):
+    start = _unwrap_if_constexpr(start)
+    end = _unwrap_if_constexpr(end)
+    return _semantic.arange(start, end)


 arange.__doc__ = f"""
@@ -1465,8 +1652,8 @@ arange.__doc__ = f"""


 def _unwrap_shape(shape):
-    shape =
-    return [
+    shape = _unwrap_if_constexpr(shape)
+    return [_unwrap_if_constexpr(s) for s in shape]


 def _shape_check_impl(shape):
@@ -1476,7 +1663,7 @@ def _shape_check_impl(shape):


 @builtin
-def full(shape, value, dtype, _builder=None):
+def full(shape, value, dtype, _semantic=None):
     """
     Returns a tensor filled with the scalar value for the given :code:`shape` and :code:`dtype`.

@@ -1488,9 +1675,9 @@ def full(shape, value, dtype, _builder=None):
     :type dtype: tl.dtype
     """
     shape = _shape_check_impl(shape)
-    value =
-    dtype =
-    return
+    value = _unwrap_if_constexpr(value)
+    dtype = _unwrap_if_constexpr(dtype)
+    return _semantic.full(shape, value, dtype)


 # -----------------------
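A quick sketch of `arange` and `full` together (illustrative, not from the diff); note that `arange` requires a power-of-two extent:

import triton
import triton.language as tl

@triton.jit
def iota_plus_one(out_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)            # [0, 1, ..., BLOCK-1]
    ones = tl.full([BLOCK], 1, tl.int32)  # a BLOCK-sized tensor of ones
    tl.store(out_ptr + offs, offs + ones)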
@@ -1499,7 +1686,7 @@ def full(shape, value, dtype, _builder=None):


 @builtin
-def broadcast(input, other, _builder=None):
+def broadcast(input, other, _semantic=None):
     """
     Tries to broadcast the two given blocks to a common compatible shape.

@@ -1508,12 +1695,12 @@ def broadcast(input, other, _builder=None):
     :param other: The second input tensor.
     :type other: Block
     """
-    return
+    return _semantic.broadcast_impl_value(input, other)


 @_tensor_member_fn
 @builtin
-def broadcast_to(input, *shape, _builder=None):
+def broadcast_to(input, *shape, _semantic=None):
     """
     Tries to broadcast the given tensor to a new :code:`shape`.

@@ -1529,12 +1716,12 @@ def broadcast_to(input, *shape, _builder=None):
         broadcast_to(x, 32, 32)
     """
     shape = _shape_check_impl(_unwrap_iterable(shape))
-    return
+    return _semantic.broadcast_impl_shape(input, shape)


 @_tensor_member_fn
 @builtin
-def trans(input: tensor, *dims, _builder=None):
+def trans(input: tensor, *dims, _semantic=None):
     """
     Permutes the dimensions of a tensor.

@@ -1543,7 +1730,7 @@ def trans(input: tensor, *dims, _builder=None):

     :param input: The input tensor.
     :param dims: The desired ordering of dimensions. For example,
-        :code:`(2, 1, 0)` reverses the order dims in a
+        :code:`(2, 1, 0)` reverses the order of dims in a 3D tensor.

     :code:`dims` can be passed as a tuple or as individual parameters: ::

@@ -1557,19 +1744,19 @@ def trans(input: tensor, *dims, _builder=None):
     dims = _unwrap_iterable(dims)
     if not dims:
         dims = (1, 0)
-    return
+    return _semantic.permute(input, dims)


 @_tensor_member_fn
 @builtin
-def permute(input, *dims, _builder=None):
+def permute(input, *dims, _semantic=None):
     """
     Permutes the dimensions of a tensor.

     :param input: The input tensor.
     :type input: Block
     :param dims: The desired ordering of dimensions. For example,
-        :code:`(2, 1, 0)` reverses the order dims in a
+        :code:`(2, 1, 0)` reverses the order of dims in a 3D tensor.

     :code:`dims` can be passed as a tuple or as individual parameters: ::

@@ -1581,11 +1768,11 @@ def permute(input, *dims, _builder=None):
     :code:`dims` is empty, it tries to do a (1,0) permutation.
     """
     dims = _unwrap_iterable(dims)
-    return
+    return _semantic.permute(input, dims)


 @builtin
-def cat(input, other, can_reorder=False, _builder=None):
+def cat(input, other, can_reorder=False, _semantic=None):
     """
     Concatenate the given blocks

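`trans` and `permute` share the same semantics in the 2D case. An illustrative sketch, with invented names:

import triton
import triton.language as tl

@triton.jit
def transpose_tile(in_ptr, out_ptr, M: tl.constexpr, N: tl.constexpr):
    rows = tl.arange(0, M)[:, None]
    cols = tl.arange(0, N)[None, :]
    x = tl.trans(tl.load(in_ptr + rows * N + cols))   # (M, N) -> (N, M)
    out_rows = tl.arange(0, N)[:, None]
    out_cols = tl.arange(0, M)[None, :]
    tl.store(out_ptr + out_rows * M + out_cols, x)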
@@ -1598,11 +1785,11 @@ def cat(input, other, can_reorder=False, _builder=None):
         order does not matter (e.g., result is only used in reduction ops).
         Current implementation of `cat` supports only can_reorder=True.
     """
-    return
+    return _semantic.cat(input, other, can_reorder)


 @builtin
-def join(a, b, _builder=None):
+def join(a, b, _semantic=None):
     """
     Join the given tensors in a new, minor dimension.

@@ -1622,17 +1809,25 @@ def join(a, b, _builder=None):
     :param b: The second input tensor.
     :type b: Tensor
     """
-    return
+    return _semantic.join(a, b)


-
-
-
+def _unsplat(x, _semantic=None, _generator=None):
+    """
+    Convert a single-element tensor to a scalar.
+    """
+    if len(x.shape) == 0:
+        return x
+    numel = 1
+    for d in x.shape:
+        numel *= d
+    assert numel == 1, "can only unsplat single-element tensors"
+    return _semantic.unsplat(x)


 @_tensor_member_fn
 @builtin
-def split(a, _builder=None, _generator=None) -> tuple[tensor, tensor]:
+def split(a, _semantic=None, _generator=None) -> tuple[tensor, tensor]:
     """
     Split a tensor in two along its last dim, which must have size 2.

@@ -1649,25 +1844,25 @@ def split(a, _builder=None, _generator=None) -> tuple[tensor, tensor]:
     :type a: Tensor
     """
     # If len(a.shape) == 1, i.e. a.shape == [2], we should return two scalars.
-    # But
+    # But _semantic.split can only handle returning tensors. Work around this by
     # expanding the input to shape [1,2] and then reducing the result.
     was_rank_1 = len(a.shape) == 1
     if was_rank_1:
-        a =
+        a = _semantic.expand_dims(a, 0)

-    out_lhs, out_rhs =
+    out_lhs, out_rhs = _semantic.split(a)

     if was_rank_1:
         # Currently `reduce` is the best way to convert a tensor of shape [1] to a scalar.
-        out_lhs =
-        out_rhs =
+        out_lhs = _unsplat(out_lhs, _semantic=_semantic, _generator=_generator)
+        out_rhs = _unsplat(out_rhs, _semantic=_semantic, _generator=_generator)

     return out_lhs, out_rhs

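`join` and `split` are inverses, which the following sketch (not from the diff) uses to swap interleaved pairs in place:

import triton
import triton.language as tl

@triton.jit
def swap_pairs(ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(ptr + 2 * offs)        # even elements
    y = tl.load(ptr + 2 * offs + 1)    # odd elements
    xy = tl.join(x, y)                 # shape (BLOCK, 2)
    a, b = tl.split(xy)                # back to two (BLOCK,) tensors
    tl.store(ptr + 2 * offs, b)
    tl.store(ptr + 2 * offs + 1, a)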
 @_tensor_member_fn
 @builtin
-def view(input, *shape, _builder=None):
+def view(input, *shape, _semantic=None):
     """
     Returns a tensor with the same elements as `input` but a different shape.
     The order of the elements may not be preserved.
@@ -1684,12 +1879,21 @@ def view(input, *shape, _builder=None):
     """
     warn("view is deprecated, please use reshape with can_reorder being true.")
     shape = _shape_check_impl(_unwrap_iterable(shape))
-    return
+    return _semantic.reshape(input, shape, can_reorder=True)


 @_tensor_member_fn
 @builtin
-def reshape(input, *shape, can_reorder=False, _builder=None):
+def item(input, _semantic=None, _generator=None):
+    """
+    Converts a single-element tensor into a scalar.
+    """
+    return _unsplat(input, _semantic=_semantic, _generator=_generator)
+
+
+@_tensor_member_fn
+@builtin
+def reshape(input, *shape, can_reorder=False, _semantic=None, _generator=None):
     """
     Returns a tensor with the same number of elements as input but with the
     provided shape.
@@ -1705,7 +1909,9 @@ def reshape(input, *shape, can_reorder=False, _builder=None):
         reshape(x, 32, 32)
     """
     shape = _shape_check_impl(_unwrap_iterable(shape))
-
+    if len(shape) == 0:
+        return _unsplat(input, _semantic=_semantic, _generator=_generator)
+    return _semantic.reshape(input, shape, can_reorder)


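`item` (equivalently, `reshape` to an empty shape) turns a single-element tensor into a scalar. An illustrative sketch, not part of the package:

import triton
import triton.language as tl

@triton.jit
def add_first(in_ptr, out_ptr, BLOCK: tl.constexpr):
    first = tl.load(in_ptr + tl.arange(0, 1)).item()   # (1,) tensor -> scalar
    x = tl.load(in_ptr + tl.arange(0, BLOCK))
    tl.store(out_ptr + tl.arange(0, BLOCK), x + first)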
 def _wrap_axis(axis, ndim):
@@ -1717,7 +1923,7 @@ def _wrap_axis(axis, ndim):

 @_tensor_member_fn
 @builtin
-def expand_dims(input, axis, _builder=None):
+def expand_dims(input, axis, _semantic=None):
     """
     Expand the shape of a tensor, by inserting new length-1 dimensions.

@@ -1730,24 +1936,24 @@ def expand_dims(input, axis, _builder=None):
     :type axis: int | Sequence[int]

     """
-    input =
-    axis =
+    input = _semantic.to_tensor(input)
+    axis = _unwrap_if_constexpr(axis)
     axes = list(axis) if isinstance(axis, (Sequence, tuple)) else [axis]
     new_ndim = len(input.shape) + len(axes)
-    axes = [_wrap_axis(
+    axes = [_wrap_axis(_unwrap_if_constexpr(d), new_ndim) for d in axes]

     if len(set(axes)) != len(axes):
         raise ValueError(f"expand_dims received duplicate axes, normalized axes = {axes}")

     ret = input
     for a in sorted(axes):
-        ret =
+        ret = _semantic.expand_dims(ret, a)
     return ret


 @_tensor_member_fn
 @builtin
-def cast(input, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: bool = False, _builder=None):
+def cast(input, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: bool = False, _semantic=None):
     """
     Casts a tensor to the given :code:`dtype`.

@@ -1763,13 +1969,13 @@ def cast(input, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcas
         :code:`dtype`, instead of being numerically casted.
     :type bitcast: bool, optional
     """
-    input =
-    dtype =
-    fp_downcast_rounding =
-    bitcast =
+    input = _semantic.to_tensor(input)
+    dtype = _unwrap_if_constexpr(dtype)
+    fp_downcast_rounding = _unwrap_if_constexpr(fp_downcast_rounding)
+    bitcast = _unwrap_if_constexpr(bitcast)
     if bitcast:
-        return
-    return
+        return _semantic.bitcast(input, dtype)
+    return _semantic.cast(input, dtype, fp_downcast_rounding)


 # -----------------------
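The `bitcast` path reinterprets bits instead of converting values. A hypothetical example:

import triton
import triton.language as tl

@triton.jit
def float_bits(f_ptr, i_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(f_ptr + offs)              # float32 values
    bits = x.to(tl.int32, bitcast=True)    # raw IEEE-754 bit patterns, no conversion
    tl.store(i_ptr + offs, bits)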
@@ -1779,7 +1985,7 @@ def cast(input, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcas

 @builtin
 def dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_imprecise_acc=None, out_dtype=float32,
-        _builder=None):
+        _semantic=None):
     """
     Returns the matrix product of two blocks.

@@ -1804,19 +2010,20 @@ def dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_i
     """
     assert input_precision is None or allow_tf32 is None, "Only one of input_precision and allow_tf32 can be specified"
     if input_precision is None:
-        supports_tf32 =
-
-
+        supports_tf32 = "tf32" in _semantic.builder.options.allowed_dot_input_precisions
+        input_precision = knobs.language.fp32_default or ("tf32" if (supports_tf32 and
+                                                                     (allow_tf32 or allow_tf32 is None)) else "ieee")

-    input_precision =
-    out_dtype =
-    max_num_imprecise_acc =
-
+    input_precision = _unwrap_if_constexpr(input_precision)
+    out_dtype = _unwrap_if_constexpr(out_dtype)
+    max_num_imprecise_acc = _unwrap_if_constexpr(max_num_imprecise_acc)
+    acc = _unwrap_if_constexpr(acc)
+    return _semantic.dot(input, other, acc, input_precision, max_num_imprecise_acc, out_dtype)


 @builtin
-def dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None, fast_math=False,
-
+def dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None, fast_math=False, lhs_k_pack=True,
+               rhs_k_pack=True, out_dtype=float32, _semantic=None):
     """
     Returns the matrix product of two blocks in microscaling format.

@@ -1843,11 +2050,15 @@ def dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None,
     :param rhs_format: format of the rhs tensor. Available formats: {:code:`e2m1`, :code:`e4m3`, :code:`e5m2`, :code:`bf16`, :code:`fp16`}.
     :type rhs_format: str
     :param acc: The accumulator tensor. If not None, the result is added to this tensor.
+    :param lhs_k_pack: If false, the lhs tensor is packed into uint8 along the M dimension.
+    :type lhs_k_pack: bool, optional
+    :param rhs_k_pack: If false, the rhs tensor is packed into uint8 along the N dimension.
+    :type rhs_k_pack: bool, optional
     """
-    out_dtype =
+    out_dtype = _unwrap_if_constexpr(out_dtype)
     assert out_dtype == float32, "Only float32 is supported for out_dtype at the moment"
-    return
-
+    return _semantic.dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc, fast_math, lhs_k_pack,
+                                rhs_k_pack, out_dtype)


 # -----------------------
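As a usage sketch for `dot` (an invented kernel; `tl.dot` requires block dimensions of at least 16):

import triton
import triton.language as tl

@triton.jit
def matmul_tile(a_ptr, b_ptr, c_ptr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
    rm = tl.arange(0, M)[:, None]
    rn = tl.arange(0, N)[None, :]
    rk = tl.arange(0, K)
    a = tl.load(a_ptr + rm * K + rk[None, :])        # (M, K) tile
    b = tl.load(b_ptr + rk[:, None] * N + rn)        # (K, N) tile
    acc = tl.zeros((M, N), dtype=tl.float32)
    acc = tl.dot(a, b, acc, input_precision="ieee")  # accumulate in fp32
    tl.store(c_ptr + rm * N + rn, acc)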
@@ -1857,7 +2068,7 @@ def dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None,

 @builtin
 def load(pointer, mask=None, other=None, boundary_check=(), padding_option="", cache_modifier="", eviction_policy="",
-         volatile=False, _builder=None):
+         volatile=False, _semantic=None):
     """
     Return a tensor of data whose values are loaded from memory at location defined by `pointer`:

@@ -1892,8 +2103,9 @@ def load(pointer, mask=None, other=None, boundary_check=(), padding_option="", c
     :type boundary_check: tuple of ints, optional
     :param padding_option: should be one of {"", "zero", "nan"}, the padding value to use while out of bounds. "" means an undefined value.
     :param cache_modifier: changes cache option in NVIDIA PTX
-    :type cache_modifier: str, optional, should be one of {"", "ca", "cg"}, where "ca" stands for
-        cache at all levels
+    :type cache_modifier: str, optional, should be one of {"", ".ca", ".cg", ".cv"}, where ".ca" stands for
+        cache at all levels, ".cg" stands for cache at global level (cache in L2 and below, not L1),
+        and ".cv" means don’t cache and fetch again. See the
         `cache operator <https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#cache-operators>`_ for more details.
     :param eviction_policy: changes eviction policy in NVIDIA PTX
     :type eviction_policy: str, optional
@@ -1901,57 +2113,37 @@ def load(pointer, mask=None, other=None, boundary_check=(), padding_option="", c
     :type volatile: bool, optional
     """
     # `mask` and `other` can be constexpr
-    mask =
-    other =
+    mask = _unwrap_if_constexpr(mask)
+    other = _unwrap_if_constexpr(other)
     if mask is not None:
-        mask =
+        mask = _semantic.to_tensor(mask)
     if other is not None:
-        other =
-    padding_option =
-    cache_modifier =
-    eviction_policy =
-    volatile =
-    return
-
-
-
-@builtin
-def _experimental_reinterpret_tensor_descriptor(desc_ptr, block_shape, dtype,
-                                                _builder=None) -> _experimental_tensor_descriptor_base:
-    """
-    Reinterpret a generic pointer as a TMA-backed tensor descriptor object.
-    """
-    block_ty = block_type(_constexpr_to_value(dtype), block_shape)
-    return semantic.reinterpret_tensor_descriptor(desc_ptr, block_ty, _builder)
+        other = _semantic.to_tensor(other)
+    padding_option = _unwrap_if_constexpr(padding_option)
+    cache_modifier = _unwrap_if_constexpr(cache_modifier)
+    eviction_policy = _unwrap_if_constexpr(eviction_policy)
+    volatile = _unwrap_if_constexpr(volatile)
+    return _semantic.load(pointer, mask, other, boundary_check, padding_option, cache_modifier, eviction_policy,
+                          volatile)


 @builtin
-def
-
-
-
-
-    This loads a tensor of data based on the descriptor and offsets.
-    """
-    desc = _experimental_reinterpret_tensor_descriptor(desc_pointer, shape, dtype, _builder=_builder)
-    return desc.load(offsets, _builder=_builder)
+def load_tensor_descriptor(desc: tensor_descriptor_base, offsets: Sequence[constexpr | tensor],
+                           _semantic=None) -> tensor:
+    """Load a block of data from a tensor descriptor."""
+    return desc.load(offsets, _semantic=_semantic)


 @builtin
-def
-
-
-
-
-    This stores a tensor of data based on the descriptor and offsets.
-    """
-    desc = _experimental_reinterpret_tensor_descriptor(desc_pointer, value.shape, value.dtype, _builder=_builder)
-    return desc.store(offsets, value, _builder=_builder)
+def store_tensor_descriptor(desc: tensor_descriptor_base, offsets: Sequence[constexpr | tensor], value: tensor,
+                            _semantic=None) -> tensor:
+    """Store a block of data to a tensor descriptor."""
+    return desc.store(offsets, value, _semantic=_semantic)


 @_tensor_member_fn
 @builtin
-def store(pointer, value, mask=None, boundary_check=(), cache_modifier="", eviction_policy="", _builder=None):
+def store(pointer, value, mask=None, boundary_check=(), cache_modifier="", eviction_policy="", _semantic=None):
     """
     Store a tensor of data into memory locations defined by `pointer`.

@@ -1991,17 +2183,17 @@ def store(pointer, value, mask=None, boundary_check=(), cache_modifier="", evict
     :type eviction_policy: str, optional, should be one of {"", "evict_first", "evict_last"}
     """
     # `value` can be constexpr
-    value =
-    mask =
+    value = _semantic.to_tensor(value)
+    mask = _unwrap_if_constexpr(mask)
     if mask is not None:
-        mask =
-    cache_modifier =
-    eviction_policy =
-    return
+        mask = _semantic.to_tensor(mask)
+    cache_modifier = _unwrap_if_constexpr(cache_modifier)
+    eviction_policy = _unwrap_if_constexpr(eviction_policy)
+    return _semantic.store(pointer, value, mask, boundary_check, cache_modifier, eviction_policy)


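The cache modifiers documented above matter most for streaming access patterns. An illustrative copy kernel, not from the diff:

import triton
import triton.language as tl

@triton.jit
def streaming_copy(src_ptr, dst_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    # ".cg" caches in L2 only, keeping streamed data out of L1.
    x = tl.load(src_ptr + offs, mask=mask, other=0.0, cache_modifier=".cg")
    tl.store(dst_ptr + offs, x, mask=mask)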
 @builtin
-def make_block_ptr(base: tensor, shape, strides, offsets, block_shape, order, _builder=None):
+def make_block_ptr(base: tensor, shape, strides, offsets, block_shape, order, _semantic=None):
     """
     Returns a pointer to a block in a parent tensor

@@ -2012,30 +2204,34 @@ def make_block_ptr(base: tensor, shape, strides, offsets, block_shape, order, _builder=None):
     :param block_shape: The shape of the block
     :param order: The order of the original data format
     """
-    return
+    return _semantic.make_block_ptr(base, shape, strides, offsets, block_shape, order)


+@must_use_result(
+    "Note that tl.advance does not have any side effects. To move the block pointer, you need to assign the result of tl.advance to a variable."
+)
 @_tensor_member_fn
 @builtin
-def advance(base, offsets, _builder=None):
+def advance(base, offsets, _semantic=None):
     """
     Advance a block pointer

     :param base: the block pointer to advance
     :param offsets: the offsets to advance, a tuple by dimension
     """
-    return
+    return _semantic.advance(base, offsets)


 @builtin
-def _experimental_make_tensor_descriptor(
+def make_tensor_descriptor(
     base: tensor,
     shape: List[tensor],
     strides: List[tensor],
     block_shape: List[constexpr],
-
-
-
+    padding_option="zero",
+    _semantic=None,
+) -> tensor_descriptor:
+    """Make a tensor descriptor object

     :param base: the base pointer of the tensor, must be 16-byte aligned
     :param shape: A list of non-negative integers representing the tensor shape
@@ -2056,7 +2252,7 @@ def _experimental_make_tensor_descriptor(

         @triton.jit
         def inplace_abs(in_out_ptr, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr):
-            desc = tl.
+            desc = tl.make_tensor_descriptor(
                 in_out_ptr,
                 shape=[M, N],
                 strides=[N, 1],
@@ -2082,7 +2278,9 @@ def _experimental_make_tensor_descriptor(
         inplace_abs[grid](x, M, N, M_BLOCK, N_BLOCK)

     """
-
+
+    padding_option = _unwrap_if_constexpr(padding_option)
+    return _semantic.make_tensor_descriptor(base, shape, strides, block_shape, padding_option)


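`advance` returning a fresh pointer is exactly why `must_use_result` was added above. A block-pointer loop sketch, illustrative only:

import triton
import triton.language as tl

@triton.jit
def row_sum(a_ptr, out_ptr, M, N, BLOCK_N: tl.constexpr):
    pid = tl.program_id(0)
    bp = tl.make_block_ptr(base=a_ptr, shape=(M, N), strides=(N, 1),
                           offsets=(pid, 0), block_shape=(1, BLOCK_N), order=(1, 0))
    acc = tl.zeros((1, BLOCK_N), dtype=tl.float32)
    for _ in range(0, N, BLOCK_N):
        acc += tl.load(bp, boundary_check=(1, ))
        bp = tl.advance(bp, (0, BLOCK_N))   # reassign: advance has no side effects
    tl.store(out_ptr + pid, tl.sum(acc))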
2088
2286
|
# -----------------------
|
|
@@ -2124,89 +2322,89 @@ def _add_atomic_docstr(name: str, has_cmp: bool = False) -> Callable[[T], T]:
|
|
|
2124
2322
|
@_tensor_member_fn
|
|
2125
2323
|
@builtin
|
|
2126
2324
|
@_add_atomic_docstr("compare-and-swap", has_cmp=True)
|
|
2127
|
-
def atomic_cas(pointer, cmp, val, sem=None, scope=None,
|
|
2128
|
-
cmp =
|
|
2129
|
-
val =
|
|
2130
|
-
sem =
|
|
2131
|
-
scope =
|
|
2132
|
-
return
|
|
2325
|
+
def atomic_cas(pointer, cmp, val, sem=None, scope=None, _semantic=None):
|
|
2326
|
+
cmp = _semantic.to_tensor(cmp)
|
|
2327
|
+
val = _semantic.to_tensor(val)
|
|
2328
|
+
sem = _unwrap_if_constexpr(sem)
|
|
2329
|
+
scope = _unwrap_if_constexpr(scope)
|
|
2330
|
+
return _semantic.atomic_cas(pointer, cmp, val, sem, scope)
|
|
2133
2331
|
|
|
2134
2332
|
|
|
2135
2333
|
@_tensor_member_fn
|
|
2136
2334
|
@builtin
|
|
2137
2335
|
@_add_atomic_docstr("exchange")
|
|
2138
|
-
def atomic_xchg(pointer, val, mask=None, sem=None, scope=None,
|
|
2139
|
-
val =
|
|
2140
|
-
sem =
|
|
2141
|
-
scope =
|
|
2142
|
-
mask =
|
|
2143
|
-
return
|
|
2336
|
+
def atomic_xchg(pointer, val, mask=None, sem=None, scope=None, _semantic=None):
|
|
2337
|
+
val = _semantic.to_tensor(val)
|
|
2338
|
+
sem = _unwrap_if_constexpr(sem)
|
|
2339
|
+
scope = _unwrap_if_constexpr(scope)
|
|
2340
|
+
mask = _unwrap_if_constexpr(mask)
|
|
2341
|
+
return _semantic.atomic_xchg(pointer, val, mask, sem, scope)
|
|
2144
2342
|
|
|
2145
2343
|
|
|
2146
2344
|
@_tensor_member_fn
|
|
2147
2345
|
@builtin
|
|
2148
2346
|
@_add_atomic_docstr("add")
|
|
2149
|
-
def atomic_add(pointer, val, mask=None, sem=None, scope=None,
|
|
2150
|
-
val =
|
|
2151
|
-
sem =
|
|
2152
|
-
scope =
|
|
2153
|
-
mask =
|
|
2154
|
-
return
|
|
2347
|
+
def atomic_add(pointer, val, mask=None, sem=None, scope=None, _semantic=None):
|
|
2348
|
+
val = _semantic.to_tensor(val)
|
|
2349
|
+
sem = _unwrap_if_constexpr(sem)
|
|
2350
|
+
scope = _unwrap_if_constexpr(scope)
|
|
2351
|
+
mask = _unwrap_if_constexpr(mask)
|
|
2352
|
+
return _semantic.atomic_add(pointer, val, mask, sem, scope)
|
|
2155
2353
|
|
|
2156
2354
|
|
|
2157
2355
|
@_tensor_member_fn
|
|
2158
2356
|
@builtin
|
|
2159
2357
|
@_add_atomic_docstr("max")
|
|
2160
|
-
def atomic_max(pointer, val, mask=None, sem=None, scope=None,
|
|
2161
|
-
val =
|
|
2162
|
-
sem =
|
|
2163
|
-
scope =
|
|
2164
|
-
mask =
|
|
2165
|
-
return
|
|
2358
|
+
def atomic_max(pointer, val, mask=None, sem=None, scope=None, _semantic=None):
|
|
2359
|
+
val = _semantic.to_tensor(val)
|
|
2360
|
+
sem = _unwrap_if_constexpr(sem)
|
|
2361
|
+
scope = _unwrap_if_constexpr(scope)
|
|
2362
|
+
mask = _unwrap_if_constexpr(mask)
|
|
2363
|
+
return _semantic.atomic_max(pointer, val, mask, sem, scope)
|
|
2166
2364
|
|
|
2167
2365
|
|
|
2168
2366
|
@_tensor_member_fn
|
|
2169
2367
|
@builtin
|
|
2170
2368
|
@_add_atomic_docstr("min")
|
|
2171
|
-
def atomic_min(pointer, val, mask=None, sem=None, scope=None,
|
|
2172
|
-
val =
|
|
2173
|
-
sem =
|
|
2174
|
-
scope =
|
|
2175
|
-
mask =
|
|
2176
|
-
return
|
|
2369
|
+
def atomic_min(pointer, val, mask=None, sem=None, scope=None, _semantic=None):
|
|
2370
|
+
val = _semantic.to_tensor(val)
|
|
2371
|
+
sem = _unwrap_if_constexpr(sem)
|
|
2372
|
+
scope = _unwrap_if_constexpr(scope)
|
|
2373
|
+
mask = _unwrap_if_constexpr(mask)
|
|
2374
|
+
return _semantic.atomic_min(pointer, val, mask, sem, scope)
|
|
2177
2375
|
|
|
2178
2376
|
|
|
2179
2377
|
@_tensor_member_fn
|
|
2180
2378
|
@builtin
|
|
2181
2379
|
@_add_atomic_docstr("logical and")
|
|
2182
|
-
def atomic_and(pointer, val, mask=None, sem=None, scope=None,
|
|
2183
|
-
val =
|
|
2184
|
-
sem =
|
|
2185
|
-
scope =
|
|
2186
|
-
mask =
|
|
2187
|
-
return
|
|
2380
|
+
def atomic_and(pointer, val, mask=None, sem=None, scope=None, _semantic=None):
|
|
2381
|
+
val = _semantic.to_tensor(val)
|
|
2382
|
+
sem = _unwrap_if_constexpr(sem)
|
|
2383
|
+
scope = _unwrap_if_constexpr(scope)
|
|
2384
|
+
mask = _unwrap_if_constexpr(mask)
|
|
2385
|
+
return _semantic.atomic_and(pointer, val, mask, sem, scope)
|
|
2188
2386
|
|
|
2189
2387
|
|
|
2190
2388
|
@_tensor_member_fn
|
|
2191
2389
|
@builtin
|
|
2192
2390
|
@_add_atomic_docstr("logical or")
|
|
2193
|
-
def atomic_or(pointer, val, mask=None, sem=None, scope=None,
|
|
2194
|
-
val =
|
|
2195
|
-
sem =
|
|
2196
|
-
scope =
|
|
2197
|
-
mask =
|
|
2198
|
-
return
|
|
2391
|
+
def atomic_or(pointer, val, mask=None, sem=None, scope=None, _semantic=None):
|
|
2392
|
+
val = _semantic.to_tensor(val)
|
|
2393
|
+
sem = _unwrap_if_constexpr(sem)
|
|
2394
|
+
scope = _unwrap_if_constexpr(scope)
|
|
2395
|
+
mask = _unwrap_if_constexpr(mask)
|
|
2396
|
+
return _semantic.atomic_or(pointer, val, mask, sem, scope)
|
|
2199
2397
|
|
|
2200
2398
|
|
|
2201
2399
|
@_tensor_member_fn
|
|
2202
2400
|
@builtin
|
|
2203
2401
|
@_add_atomic_docstr("logical xor")
|
|
2204
|
-
def atomic_xor(pointer, val, mask=None, sem=None, scope=None,
|
|
2205
|
-
val =
|
|
2206
|
-
sem =
|
|
2207
|
-
scope =
|
|
2208
|
-
mask =
|
|
2209
|
-
return
|
|
2402
|
+
def atomic_xor(pointer, val, mask=None, sem=None, scope=None, _semantic=None):
|
|
2403
|
+
val = _semantic.to_tensor(val)
|
|
2404
|
+
sem = _unwrap_if_constexpr(sem)
|
|
2405
|
+
scope = _unwrap_if_constexpr(scope)
|
|
2406
|
+
mask = _unwrap_if_constexpr(mask)
|
|
2407
|
+
return _semantic.atomic_xor(pointer, val, mask, sem, scope)
|
|
2210
2408
|
|
|
2211
2409
|
|
|
2212
2410
|
# -----------------------
|
|
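A typical pattern for these atomics is a block-local reduction followed by one atomic per program. A sketch, with invented names:

import triton
import triton.language as tl

@triton.jit
def count_positive(x_ptr, count_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask, other=0.0)
    local = tl.sum((x > 0).to(tl.int32))
    tl.atomic_add(count_ptr, local, sem="relaxed")  # one atomic per program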
@@ -2215,7 +2413,7 @@ def atomic_xor(pointer, val, mask=None, sem=None, scope=None, _builder=None):


 @builtin
-def where(condition, x, y, _builder=None):
+def where(condition, x, y, _semantic=None):
     """
     Returns a tensor of elements from either :code:`x` or :code:`y`, depending on :code:`condition`.

@@ -2231,10 +2429,10 @@ def where(condition, x, y, _builder=None):
     :param x: values selected at indices where condition is True.
     :param y: values selected at indices where condition is False.
     """
-    condition =
+    condition = _semantic.to_tensor(condition)
     x = _unwrap_if_constexpr(x)
     y = _unwrap_if_constexpr(y)
-    return
+    return _semantic.where(condition, x, y)


 # -----------------------
@@ -2243,28 +2441,28 @@ def where(condition, x, y, _builder=None):


 @builtin
-def add(x, y, sanitize_overflow: constexpr = True, _builder=None):
+def add(x, y, sanitize_overflow: constexpr = True, _semantic=None):
     x = _unwrap_if_constexpr(x)
     y = _unwrap_if_constexpr(y)
-    return
+    return _semantic.add(x, y, sanitize_overflow)


 @builtin
-def sub(x, y, sanitize_overflow: constexpr = True, _builder=None):
+def sub(x, y, sanitize_overflow: constexpr = True, _semantic=None):
     x = _unwrap_if_constexpr(x)
     y = _unwrap_if_constexpr(y)
-    return
+    return _semantic.sub(x, y, sanitize_overflow)


 @builtin
-def mul(x, y, sanitize_overflow: constexpr = True, _builder=None):
+def mul(x, y, sanitize_overflow: constexpr = True, _semantic=None):
     x = _unwrap_if_constexpr(x)
     y = _unwrap_if_constexpr(y)
-    return
+    return _semantic.mul(x, y, sanitize_overflow)


 @builtin
-def minimum(x, y, propagate_nan: constexpr = PropagateNan.NONE, _builder=None):
+def minimum(x, y, propagate_nan: constexpr = PropagateNan.NONE, _semantic=None):
     """
     Computes the element-wise minimum of :code:`x` and :code:`y`.

@@ -2277,16 +2475,16 @@ def minimum(x, y, propagate_nan: constexpr = PropagateNan.NONE, _builder=None):

     .. seealso:: :class:`tl.PropagateNan`
     """
-    x =
-    y =
-    x = _promote_bfloat16_to_float32(x,
-    y = _promote_bfloat16_to_float32(y,
-    propagate_nan =
-    return
+    x = _semantic.to_tensor(x)
+    y = _semantic.to_tensor(y)
+    x = _promote_bfloat16_to_float32(x, _semantic=_semantic)
+    y = _promote_bfloat16_to_float32(y, _semantic=_semantic)
+    propagate_nan = _unwrap_if_constexpr(propagate_nan)
+    return _semantic.minimum(x, y, propagate_nan)


 @builtin
-def maximum(x, y, propagate_nan: constexpr = PropagateNan.NONE, _builder=None):
+def maximum(x, y, propagate_nan: constexpr = PropagateNan.NONE, _semantic=None):
     """
     Computes the element-wise maximum of :code:`x` and :code:`y`.

@@ -2299,16 +2497,16 @@ def maximum(x, y, propagate_nan: constexpr = PropagateNan.NONE, _builder=None):

     .. seealso:: :class:`tl.PropagateNan`
     """
-    x =
-    y =
-    x = _promote_bfloat16_to_float32(x,
-    y = _promote_bfloat16_to_float32(y,
-    propagate_nan =
-    return
+    x = _semantic.to_tensor(x)
+    y = _semantic.to_tensor(y)
+    x = _promote_bfloat16_to_float32(x, _semantic=_semantic)
+    y = _promote_bfloat16_to_float32(y, _semantic=_semantic)
+    propagate_nan = _unwrap_if_constexpr(propagate_nan)
+    return _semantic.maximum(x, y, propagate_nan)


 @builtin
-def clamp(x, min, max, propagate_nan: constexpr = PropagateNan.NONE, _builder=None):
+def clamp(x, min, max, propagate_nan: constexpr = PropagateNan.NONE, _semantic=None):
     """
     Clamps the input tensor :code:`x` within the range [min, max].
     Behavior when :code:`min` > :code:`max` is undefined.
@@ -2325,16 +2523,16 @@ def clamp(x, min, max, propagate_nan: constexpr = PropagateNan.NONE, _builder=No

     .. seealso:: :class:`tl.PropagateNan`
     """
-    x =
-    min =
-    max =
-    x = _promote_bfloat16_to_float32(x,
-    min = _promote_bfloat16_to_float32(min,
-    max = _promote_bfloat16_to_float32(max,
+    x = _semantic.to_tensor(x)
+    min = _semantic.to_tensor(min)
+    max = _semantic.to_tensor(max)
+    x = _promote_bfloat16_to_float32(x, _semantic=_semantic)
+    min = _promote_bfloat16_to_float32(min, _semantic=_semantic)
+    max = _promote_bfloat16_to_float32(max, _semantic=_semantic)

-    propagate_nan =
+    propagate_nan = _unwrap_if_constexpr(propagate_nan)

-    return
+    return _semantic.clamp(x, min, max, propagate_nan)


 # -----------------------
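`where` and `clamp` compose naturally; the sketch below (not from the package) saturates values and zeroes NaNs:

import triton
import triton.language as tl

@triton.jit
def saturate(x_ptr, y_ptr, lo, hi, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs)
    y = tl.clamp(x, lo, hi)          # undefined if lo > hi, per the docstring
    y = tl.where(x != x, 0.0, y)     # NaN != NaN, so this zeroes NaNs
    tl.store(y_ptr + offs, y)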
@@ -2383,7 +2581,7 @@ def _insertion_guard(builder):

 @_tensor_member_fn
 @builtin
-def reduce(input, axis, combine_fn, keep_dims=False, _builder=None, _generator=None):
+def reduce(input, axis, combine_fn, keep_dims=False, _semantic=None, _generator=None):
     """Applies the combine_fn to all elements in :code:`input` tensors along the provided :code:`axis`

     :param input: the input tensor, or tuple of tensors
@@ -2397,64 +2595,65 @@ def reduce(input, axis, combine_fn, keep_dims=False, _builder=None, _generator=N

     """
     if isinstance(input, tensor):
-        return reduce((input, ), axis, combine_fn, keep_dims=keep_dims,
+        return reduce((input, ), axis, combine_fn, keep_dims=keep_dims, _semantic=_semantic, _generator=_generator)[0]

     def make_combine_region(reduce_op):
         param_types = [t.type.scalar for t in input] * 2
         region = reduce_op.get_region(0)
-
-
-
+        builder = _semantic.builder
+        with _insertion_guard(builder):
+            to_ir = lambda T: T.to_ir(builder)
+            block = builder.create_block_with_parent(region, list(map(to_ir, param_types)))
             args = [tensor(block.arg(i), ty) for i, ty in enumerate(param_types)]
             results = _generator.call_JitFunction(combine_fn, args, kwargs={})
             if isinstance(results, tensor):
                 handles = [results.handle]
             else:
                 handles = [r.handle for r in results]
-
+            builder.create_reduce_ret(*handles)

     def expand_ndims(t, ndims):
         for _ in builtins.range(ndims):
-            t = expand_dims(t, 0,
+            t = expand_dims(t, 0, _semantic=_semantic)
         return t

-    axis =
-    keep_dims =
+    axis = _unwrap_if_constexpr(axis)
+    keep_dims = _unwrap_if_constexpr(keep_dims)
     if axis is not None:
         axis = _wrap_axis(axis, len(input[0].shape))
-    ret =
+    ret = _semantic.reduction(input, axis, make_combine_region)
     if keep_dims:
         if axis is not None:
-            ret = tuple(expand_dims(t, axis,
+            ret = tuple(expand_dims(t, axis, _semantic=_semantic) for t in ret)
         else:
             ret = tuple(expand_ndims(t, len(input[0].shape)) for t in ret)
     return ret


 @builtin
-def _promote_bfloat16_to_float32(t, _builder=None):
+def _promote_bfloat16_to_float32(t, _semantic=None):
     scalar_ty = t.type.scalar

     # hardware doesn't support FMAX, FMIN, CMP for bfloat16
     if scalar_ty is bfloat16:
-        return t.to(float32,
+        return t.to(float32, _semantic=_semantic)
     return t


 @builtin
-def _reduce_with_indices(input, axis, combine_fn, keep_dims=False, _builder=None, _generator=None):
-    axis =
+def _reduce_with_indices(input, axis, combine_fn, keep_dims=False, _semantic=None, _generator=None):
+    axis = _unwrap_if_constexpr(axis)
     n = input.shape[axis]
-    index = arange(0, n,
+    index = arange(0, n, _semantic=_semantic)

     if len(input.shape) > 1:
         # Broadcast index across the non-reduced axes
         axes_to_expand = [constexpr(d) for d in builtins.range(len(input.shape))]
         del axes_to_expand[axis]
-        index = expand_dims(index, axes_to_expand,
-        index = broadcast_to(index, input.shape,
+        index = expand_dims(index, axes_to_expand, _semantic=_semantic)
+        index = broadcast_to(index, input.shape, _semantic=_semantic)

-    rvalue, rindices = reduce((input, index), axis, combine_fn, keep_dims=keep_dims,
+    rvalue, rindices = reduce((input, index), axis, combine_fn, keep_dims=keep_dims, _semantic=_semantic,
                               _generator=_generator)
     return rvalue, rindices

@@ -2464,7 +2663,7 @@ def _reduce_with_indices(input, axis, combine_fn, keep_dims=False, _builder=None
 # -----------------------


-def _add_scan_docstr(name: str) -> Callable[[T], T]:
+def _add_scan_docstr(name: str, dtype_arg: str = None) -> Callable[[T], T]:

     def _decorator(func: T) -> T:
         docstr = """
@@ -2473,7 +2672,15 @@ def _add_scan_docstr(name: str) -> Callable[[T], T]:
     :param input: the input values
     :type input: Tensor
     :param axis: the dimension along which the scan should be done
-    :type axis: int
+    :type axis: int
+    :param reverse: if true, the scan is performed in the reverse direction
+    :type reverse: bool"""
+
+        if dtype_arg is not None:
+            docstr += f"""
+    :param {dtype_arg}: the desired data type of the returned tensor. If specified, the input tensor is casted to :code:`{dtype_arg}` before the operation is performed. If not specified, small integer types (< 32 bits) are upcasted to prevent overflow. Note that :code:`tl.bfloat16` inputs are automatically promoted to :code:`tl.float32`.
+    :type {dtype_arg}: tl.dtype"""
+
         func.__doc__ = docstr.format(name=name)
         return func

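`reduce` accepts any @triton.jit combine function of two scalars. An illustrative absolute-max reduction, not from the diff:

import triton
import triton.language as tl

@triton.jit
def _absmax_combine(a, b):
    return tl.maximum(tl.abs(a), tl.abs(b))

@triton.jit
def absmax(x_ptr, out_ptr, BLOCK: tl.constexpr):
    x = tl.load(x_ptr + tl.arange(0, BLOCK))
    m = tl.reduce(x, axis=0, combine_fn=_absmax_combine)
    tl.store(out_ptr, m)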
@@ -2482,7 +2689,7 @@ def _add_scan_docstr(name: str) -> Callable[[T], T]:

 @_tensor_member_fn
 @builtin
-def associative_scan(input, axis, combine_fn, reverse=False, _builder=None, _generator=None):
+def associative_scan(input, axis, combine_fn, reverse=False, _semantic=None, _generator=None):
     """Applies the combine_fn to each elements with a carry in :code:`input` tensors along the provided :code:`axis` and update the carry

     :param input: the input tensor, or tuple of tensors
@@ -2496,46 +2703,52 @@ def associative_scan(input, axis, combine_fn, reverse=False, _builder=None, _gen

     """
     if isinstance(input, tensor):
-        return associative_scan((input, ), axis, combine_fn, reverse,
+        return associative_scan((input, ), axis, combine_fn, reverse, _semantic=_semantic, _generator=_generator)[0]

     def make_combine_region(scan_op):
         param_types = [t.type.scalar for t in input] * 2
         region = scan_op.get_region(0)
-
-
-
+        builder = _semantic.builder
+        with _insertion_guard(builder):
+            to_ir = lambda T: T.to_ir(builder)
+            block = builder.create_block_with_parent(region, list(map(to_ir, param_types)))
             args = [tensor(block.arg(i), ty) for i, ty in enumerate(param_types)]
             results = _generator.call_JitFunction(combine_fn, args, kwargs={})
             if isinstance(results, tensor):
                 handles = [results.handle]
             else:
                 handles = [r.handle for r in results]
-
+            builder.create_scan_ret(*handles)

-    axis =
+    axis = _unwrap_if_constexpr(axis)
     if axis is not None:
         axis = _wrap_axis(axis, len(input[0].shape))
-    return
+    return _semantic.associative_scan(input, axis, make_combine_region, reverse)


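`associative_scan` with addition gives an inclusive prefix sum. A sketch, not from the diff:

import triton
import triton.language as tl

@triton.jit
def _add_combine(a, b):
    return a + b

@triton.jit
def cumsum(x_ptr, y_ptr, BLOCK: tl.constexpr):
    x = tl.load(x_ptr + tl.arange(0, BLOCK))
    y = tl.associative_scan(x, axis=0, combine_fn=_add_combine)  # inclusive scan
    tl.store(y_ptr + tl.arange(0, BLOCK), y)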
2521
2729
|
@_tensor_member_fn
|
|
2522
2730
|
@builtin
|
|
2523
|
-
def histogram(input, num_bins,
|
|
2731
|
+
def histogram(input, num_bins, mask=None, _semantic=None, _generator=None):
|
|
2524
2732
|
"""computes an histogram based on input tensor with num_bins bins, the bins have a width of 1 and start at 0.
|
|
2525
2733
|
|
|
2526
2734
|
:param input: the input tensor
|
|
2527
2735
|
:type input: Tensor
|
|
2528
2736
|
:param num_bins: number of histogram bins
|
|
2529
2737
|
:type num_bins: int
|
|
2738
|
+
:param mask: if `mask[idx]` is false, exclude `input[idx]` from histogram
|
|
2739
|
+
:type mask: Block of `triton.int1`, optional
|
|
2530
2740
|
|
|
2531
2741
|
"""
|
|
2532
|
-
num_bins =
|
|
2533
|
-
|
|
2742
|
+
num_bins = _unwrap_if_constexpr(num_bins)
|
|
2743
|
+
mask = _unwrap_if_constexpr(mask)
|
|
2744
|
+
if mask is not None:
|
|
2745
|
+
mask = _semantic.to_tensor(mask)
|
|
2746
|
+
return _semantic.histogram(input, num_bins, mask)
|
|
2534
2747
|
|
|
2535
2748
|
|
|
2536
2749
|
@_tensor_member_fn
|
|
2537
2750
|
@builtin
|
|
2538
|
-
def gather(src, index, axis,
|
|
2751
|
+
def gather(src, index, axis, _semantic=None):
|
|
2539
2752
|
"""Gather from a tensor along a given dimension.
|
|
2540
2753
|
|
|
2541
2754
|
:param src: the source tensor
|
|
@@ -2546,8 +2759,81 @@ def gather(src, index, axis, _builder=None):
     :type axis: int

     """
-    axis = _constexpr_to_value(axis)
-    return semantic.gather(src, index, axis, _builder)
+    axis = _unwrap_if_constexpr(axis)
+    return _semantic.gather(src, index, axis)
+
+
+@builtin
+def map_elementwise(
+    scalar_fn: Callable[..., Tuple[tensor, ...]],
+    *args: tensor,
+    pack=1,
+    _semantic=None,
+    _generator=None,
+):
+    '''
+    Map a scalar function over a tensor.
+
+    The input tensors :code:`args` are implicitly broadcasted to the same shape.
+
+    This may be useful in allowing control flow over single elements in a tensor,
+    for example a multi-branch function where one branch is more expensive. With
+    :code:`tl.where` you are forced to calculate both sides of the branch, but
+    with an if we only execute one side.
+
+    .. highlight:: python
+    .. code-block:: python
+
+        @triton.jit
+        def selu_scalar(x, alpha):
+            if x > 0:
+                return x
+            else:
+                return alpha * (tl.exp(x) - 1)
+
+        @triton.jit
+        def selu(x, alpha):
+            return tl.map_elementwise(selu_scalar, x, alpha)
+
+    :param scalar_fn: the function to map over.
+    :param pack: the number of elements to be processed by one function call.
+    :return: one tensor or a tuple of tensors, depending on the mapped function.
+    '''
+    # Build the block for the nested region first to discover the return types
+    assert pack >= 1
+    in_scalar_tys = [t.type.scalar for t in args]
+    builder = _semantic.builder
+    block = builder.new_block()
+    scalar_args = []
+    for i, ty in enumerate(in_scalar_tys):
+        for j in builtins.range(pack):
+            block.add_argument(ty.to_ir(builder))
+            scalar_args.append(tensor(block.arg(i * pack + j), ty))
+
+    with _insertion_guard(builder):
+        builder.set_insertion_point_to_start(block)
+        scalar_results = _generator.call_JitFunction(scalar_fn, scalar_args, kwargs={})
+
+        is_single = isinstance(scalar_results, tensor)
+        if is_single:
+            scalar_results = scalar_results,
+
+        handles = [r.handle for r in scalar_results]
+        builder.create_map_elementwise_ret(handles)
+
+    fn_result_types = [x.type for x in scalar_results]
+    scalar_result_types = fn_result_types
+    if pack > 1:
+        scalar_result_types = fn_result_types[::pack]
+        for offset in builtins.range(1, pack):
+            assert scalar_result_types == fn_result_types[offset::pack], "type mismatch in unpacked results"
+
+    def make_elementwise_region(elementwise_op):
+        region = elementwise_op.get_region(0)
+        region.push_back(block)
+
+    result = _semantic.map_elementwise(args, scalar_result_types, pack, make_elementwise_region)
+    return result[0] if is_single else result


 # -----------------------
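
A hedged sketch of calling the new tl.map_elementwise end to end (the kernel name and the host-side alpha handling are illustrative; alpha is materialized with tl.full because the builtin expects tensor arguments):

    import triton
    import triton.language as tl

    @triton.jit
    def selu_scalar(x, alpha):
        if x > 0:
            return x
        else:
            return alpha * (tl.exp(x) - 1)

    @triton.jit
    def selu_kernel(x_ptr, y_ptr, n, BLOCK: tl.constexpr):
        offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        m = offs < n
        x = tl.load(x_ptr + offs, mask=m)
        alpha = tl.full((BLOCK,), 1.6733, tl.float32)
        y = tl.map_elementwise(selu_scalar, x, alpha)  # one scalar call per element
        tl.store(y_ptr + offs, y, mask=m)
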
@@ -2556,15 +2842,15 @@ def gather(src, index, axis, _builder=None):


 @builtin
-def debug_barrier(_builder=None):
+def debug_barrier(_semantic=None):
     '''
     Insert a barrier to synchronize all threads in a block.
     '''
-    return semantic.debug_barrier(_builder)
+    return _semantic.debug_barrier()


 @builtin
-def multiple_of(input, values, _builder=None):
+def multiple_of(input, values, _semantic=None):
     """
     Let the compiler know that the values in :code:`input` are all multiples of :code:`value`.
     """
@@ -2576,11 +2862,11 @@ def multiple_of(input, values, _builder=None):
         if not isinstance(d.value, int):
             raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]")
     values = [x.value for x in values]
-    return semantic.multiple_of(input, values)
+    return _semantic.multiple_of(input, values)


 @builtin
-def max_contiguous(input, values, _builder=None):
+def max_contiguous(input, values, _semantic=None):
     """
     Let the compiler know that the `value` first values in :code:`input` are contiguous.
     """
@@ -2592,11 +2878,11 @@ def max_contiguous(input, values, _builder=None):
         if not isinstance(d.value, int):
             raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]")
     values = [x.value for x in values]
-    return semantic.max_contiguous(input, values)
+    return _semantic.max_contiguous(input, values)


 @builtin
-def max_constancy(input, values, _builder=None):
+def max_constancy(input, values, _semantic=None):
     """
     Let the compiler know that the `value` first values in :code:`input` are constant.

@@ -2611,15 +2897,15 @@ def max_constancy(input, values, _builder=None):
         if not isinstance(d.value, int):
             raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]")
     values = [x.value for x in values]
-    return semantic.max_constancy(input, values)
+    return _semantic.max_constancy(input, values)


 @builtin
-def assume(cond, _builder=None):
+def assume(cond, _semantic=None):
     '''
     Allow compiler to assume the :code:`cond` is True.
     '''
-    return semantic.assume(_to_tensor(cond, _builder), _builder)
+    return _semantic.assume(_semantic.to_tensor(cond))


 # -----------------------
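
The hint builtins above keep their user-facing semantics; only the plumbing moved from _builder to _semantic. An illustrative use, following the usual pointer-offset idiom (kernel name is hypothetical):

    import triton
    import triton.language as tl

    @triton.jit
    def copy_kernel(src_ptr, dst_ptr, n, BLOCK: tl.constexpr):
        pid = tl.program_id(0)
        offs = pid * BLOCK + tl.arange(0, BLOCK)
        # hint: the block of offsets starts at a multiple of BLOCK and is contiguous
        offs = tl.max_contiguous(tl.multiple_of(offs, BLOCK), BLOCK)
        tl.assume(n > 0)  # hint: a fact the optimizer may rely on
        m = offs < n
        tl.store(dst_ptr + offs, tl.load(src_ptr + offs, mask=m), mask=m)
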
@@ -2628,7 +2914,7 @@ def assume(cond, _builder=None):


 @builtin
-def static_print(*values, sep: str = " ", end: str = "\n", file=None, flush=False, _builder=None):
+def static_print(*values, sep: str = " ", end: str = "\n", file=None, flush=False, _semantic=None):
     '''
     Print the values at compile time. The parameters are the same as the builtin :code:`print`.

@@ -2644,7 +2930,7 @@ def static_print(*values, sep: str = " ", end: str = "\n", file=None, flush=False, _builder=None):


 @builtin
-def static_assert(cond, msg="", _builder=None):
+def static_assert(cond, msg="", _semantic=None):
     '''
     Assert the condition at compile time. Does not require that the :code:`TRITON_DEBUG` environment variable
     is set.
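
Both compile-time builtins are used as before; a short illustrative kernel (names are hypothetical):

    import triton
    import triton.language as tl

    @triton.jit
    def check_kernel(x_ptr, y_ptr, BLOCK: tl.constexpr):
        tl.static_assert(BLOCK % 16 == 0, "BLOCK must be a multiple of 16")
        tl.static_print("compiling with BLOCK =", BLOCK)  # printed once, at compile time
        offs = tl.arange(0, BLOCK)
        tl.store(y_ptr + offs, tl.load(x_ptr + offs))
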
@@ -2658,7 +2944,7 @@ def static_assert(cond, msg="", _builder=None):


 @builtin
-def device_print(prefix, *args, hex=False, _builder=None):
+def device_print(prefix, *args, hex=False, _semantic=None):
     '''
     Print the values at runtime from the device. String formatting does not work for runtime values, so you should
     provide the values you want to print as arguments. The first value must be a string, all following values must
@@ -2692,7 +2978,7 @@ def device_print(prefix, *args, hex=False, _builder=None):
     :param hex: print all values as hex instead of decimal
     '''
     import string
-    prefix = _constexpr_to_value(prefix)
+    prefix = _unwrap_if_constexpr(prefix)
     assert isinstance(prefix, str), f"{prefix} is not string"
     b_ascii = True
     for ch in prefix:
@@ -2702,12 +2988,12 @@ def device_print(prefix, *args, hex=False, _builder=None):
     assert b_ascii, f"{prefix} is not an ascii string"
     new_args = []
     for arg in args:
-        new_args.append(_to_tensor(arg, _builder))
-    return semantic.device_print(prefix, new_args, hex, _builder)
+        new_args.append(_semantic.to_tensor(arg))
+    return _semantic.device_print(prefix, new_args, hex)


 @builtin
-def device_assert(cond, msg="", _builder=None):
+def device_assert(cond, msg="", mask=None, _semantic=None):
     '''
     Assert the condition at runtime from the device. Requires that the environment variable :code:`TRITON_DEBUG`
     is set to a value besides :code:`0` in order for this to have any effect.
@@ -2725,13 +3011,16 @@ def device_assert(cond, msg="", _builder=None):
     :param cond: the condition to assert. This is required to be a boolean tensor.
     :param msg: the message to print if the assertion fails. This is required to be a string literal.
     '''
-    msg = _constexpr_to_value(msg)
-    return semantic.device_assert(_to_tensor(cond, _builder), msg, _builder)
+    msg = _unwrap_if_constexpr(msg)
+    mask = _unwrap_if_constexpr(mask)
+    if mask is not None:
+        mask = _semantic.to_tensor(mask)
+    return _semantic.device_assert(_semantic.to_tensor(cond), msg, mask)


 @builtin
 def inline_asm_elementwise(asm: str, constraints: str, args: Sequence, dtype: Union[dtype, Sequence[dtype]],
-                           is_pure: bool, pack: int, _builder=None):
+                           is_pure: bool, pack: int, _semantic=None):
     '''
     Execute inline assembly over a tensor. Essentially, this is :code:`map`
     where the function is inline assembly.
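
A sketch of the new mask argument on tl.device_assert (assumption: masked-off lanes are excluded from the check, mirroring the histogram change above; TRITON_DEBUG must still be set for device asserts to fire):

    import triton
    import triton.language as tl

    @triton.jit
    def checked_kernel(x_ptr, n, BLOCK: tl.constexpr):
        offs = tl.arange(0, BLOCK)
        m = offs < n
        x = tl.load(x_ptr + offs, mask=m, other=1)
        tl.device_assert(x > 0, "x must be positive", mask=m)  # only in-bounds lanes are checked
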
@@ -2816,13 +3105,12 @@ def inline_asm_elementwise(asm: str, constraints: str, args: Sequence, dtype: Union[dtype, Sequence[dtype]],
     :param dtype: the element type(s) of the returned tensor(s)
     :param is_pure: if true, the compiler assumes the asm block has no side-effects
     :param pack: the number of elements to be processed by one instance of inline assembly
-    :param _builder: the builder
     :return: one tensor or a tuple of tensors of the given dtypes
     '''
-    asm = _constexpr_to_value(asm)
-    constraints = _constexpr_to_value(constraints)
-    pack = _constexpr_to_value(pack)
-    is_pure = _constexpr_to_value(is_pure)
+    asm = _unwrap_if_constexpr(asm)
+    constraints = _unwrap_if_constexpr(constraints)
+    pack = _unwrap_if_constexpr(pack)
+    is_pure = _unwrap_if_constexpr(is_pure)

     # Wrap `dtype` in a tuple if it's not already.
     try:
@@ -2835,10 +3123,9 @@ def inline_asm_elementwise(asm: str, constraints: str, args: Sequence, dtype: Union[dtype, Sequence[dtype]],
     dtype = typing.cast(Sequence[_DtypeClass], dtype)

     res_tys = dtype
-    if dispatch_args := [_to_tensor(arg, _builder) for arg in args]:
+    if dispatch_args := [_semantic.to_tensor(arg) for arg in args]:
         bin_op_type_checking = partial(
-            semantic.binary_op_type_checking_impl,
-            builder=_builder,
+            _semantic.binary_op_type_checking_impl,
             arithmetic_check=False,
             allow_lhs_ptr=True,
             allow_rhs_ptr=True,
@@ -2851,9 +3138,10 @@ def inline_asm_elementwise(asm: str, constraints: str, args: Sequence, dtype: Union[dtype, Sequence[dtype]],
         # Change the shape of each argument based on the broadcast shape
         for i, item in enumerate(dispatch_args):
             dispatch_args[i], _ = bin_op_type_checking(item, broadcast_arg)
-        res_tys = [block_type(dt, broadcast_arg.shape) for dt in dtype]
+        res_tys = [broadcast_arg.type.with_element_ty(dt) for dt in dtype]
     handles = [t.handle for t in dispatch_args]
-    call = _builder.create_inline_asm(asm, constraints, handles, [ty.to_ir(_builder) for ty in res_tys], is_pure, pack)
+    builder = _semantic.builder
+    call = builder.create_inline_asm(asm, constraints, handles, [ty.to_ir(builder) for ty in res_tys], is_pure, pack)

     if not has_multiple_outputs:
         return tensor(call.get_result(0), res_tys[0])
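
An illustrative call of tl.inline_asm_elementwise (CUDA-only; the PTX below is a trivial register move chosen only to show the plumbing, and the kernel name is hypothetical):

    import triton
    import triton.language as tl

    @triton.jit
    def asm_kernel(x_ptr, y_ptr, BLOCK: tl.constexpr):
        x = tl.load(x_ptr + tl.arange(0, BLOCK))
        y = tl.inline_asm_elementwise(
            asm="mov.u32 $0, $1;",   # copy the input register to the output register
            constraints="=r,r",
            args=[x],
            dtype=tl.int32,
            is_pure=True,
            pack=1,
        )
        tl.store(y_ptr + tl.arange(0, BLOCK), y)
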
@@ -2865,7 +3153,7 @@ def inline_asm_elementwise(asm: str, constraints: str, args: Sequence, dtype: Union[dtype, Sequence[dtype]],
 # -----------------------


-class static_range:
+class static_range(base_value):
     """
     Iterator that counts upward forever.

@@ -2905,7 +3193,23 @@ class static_range:
         raise RuntimeError("static_range can only be used in @triton.jit'd functions")


-class range:
+class async_task:
+    """
+    Context manager to run code fragments asynchronously.
+    """
+
+    def __init__(self, task_ids, _builder=None):
+        self.task_ids = list({_unwrap_if_constexpr(tid) for tid in task_ids})
+        self.builder = _builder
+
+    def __enter__(self):
+        self.builder.set_async_task_ids(self.task_ids)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.builder.unset_async_task_ids()
+
+
+class range(base_value):
     """
     Iterator that counts upward forever.

@@ -2936,10 +3240,21 @@ class range:
     :param flatten: automatically flatten the loop nest starting at this loop to
         create a single flattened loop. The compiler will try to pipeline the
        flattened loop which can avoid stage stalling.
+    :param warp_specialize: Enable automatic warp specialization on the loop.
+        The compiler will attempt to partition memory, MMA, and vector
+        operations in the loop into separate async partitions. This will
+        increase the total number of warps required by the kernel.
+    :param disable_licm: Tells the compiler it shouldn't hoist loop-invariant
+        code out of the loop. This is often useful to avoid creating long live
+        ranges within a loop.
+
+    Note that warp specialization is only supported on Blackwell GPUs and
+    only works on simple matmul loops. Support for arbitrary loops will be
+    expanded over time.
     """

     def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_factor=None,
-                 disallow_acc_multi_buffer=False, flatten=False):
+                 disallow_acc_multi_buffer=False, flatten=False, warp_specialize=False, disable_licm=False):
         if step is None:
             self.step = constexpr(1)
         else:
@@ -2954,6 +3269,8 @@ class range:
         self.loop_unroll_factor = loop_unroll_factor
         self.disallow_acc_multi_buffer = disallow_acc_multi_buffer
         self.flatten = flatten
+        self.warp_specialize = warp_specialize
+        self.disable_licm = disable_licm

     def __iter__(self):
         raise RuntimeError("tl.range can only be used in @triton.jit'd functions")
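
A sketch of the new loop attributes (per the docstring above, warp_specialize currently targets simple matmul loops on Blackwell GPUs only; the single-tile fp16 addressing below is illustrative, and BLOCK sizes must meet tl.dot minimums):

    import triton
    import triton.language as tl

    @triton.jit
    def dot_loop(a_ptr, b_ptr, c_ptr, K,
                 BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
        offs_m = tl.arange(0, BLOCK_M)
        offs_n = tl.arange(0, BLOCK_N)
        offs_k = tl.arange(0, BLOCK_K)
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        for k in tl.range(0, K, BLOCK_K, num_stages=3, warp_specialize=True, disable_licm=True):
            a = tl.load(a_ptr + offs_m[:, None] * K + (k + offs_k)[None, :])
            b = tl.load(b_ptr + (k + offs_k)[:, None] * BLOCK_N + offs_n[None, :])
            acc = tl.dot(a, b, acc)
        tl.store(c_ptr + offs_m[:, None] * BLOCK_N + offs_n[None, :], acc)
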
@@ -2962,13 +3279,36 @@ class range:
         raise RuntimeError("tl.range can only be used in @triton.jit'd functions")


+class condition(base_value):
+    """
+    While loop condition wrapper.
+
+    .. highlight:: python
+    .. code-block:: python
+
+        @triton.jit
+        def kernel(...):
+            while tl.condition(c, disable_licm=True):
+                ...
+
+    :note: This is a special wrapper used to annotate while loops in the context of
+        :code:`triton.jit` functions. It allows the user to pass extra attributes to the compiler.
+    :param disable_licm: Tells the compiler it shouldn't hoist loop-invariant
+        code out of the loop. This is often useful to avoid creating long live
+        ranges within a loop.
+    """
+
+    def __init__(self, arg1, disable_licm=False):
+        self.condition = arg1
+        self.disable_licm = disable_licm
+
+
 # -----------------------
 # Extern functions
 # -----------------------


-def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, ret_shape: tuple,
-             is_pure: bool, _builder=None):
+def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, ret_type: dtype, is_pure: bool,
+             _semantic):
     '''
     Dispatch a function to a library
     :param func: the function to dispatch
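
A minimal sketch of the new while-loop wrapper in kernel code (illustrative; disable_licm is a keyword argument of the wrapper):

    import triton
    import triton.language as tl

    @triton.jit
    def countdown(x_ptr):
        x = tl.load(x_ptr)
        while tl.condition(x > 0, disable_licm=True):
            x = x - 1
        tl.store(x_ptr, x)
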
@@ -2976,8 +3316,7 @@ def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict,
     :param lib_path: the path of the library
     :param args: the arguments of the function
     :param arg_type_symbol_dict: the type of the arguments
-    :param ret_shape: the shape of the return value
-    :param _builder: the builder
+    :param ret_type: the type of the return value
     :return: the return value of the function
     '''
     if len(arg_type_symbol_dict) == 0:
@@ -3004,15 +3343,13 @@ def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict,
                          f"Expect one of {arg_type_symbol_dict.keys()}, got {arg_types}")
     else:
         symbol = arg_type_symbol_dict[arg_types][0]
-        ret_type = arg_type_symbol_dict[arg_types][1]
-        if ret_shape:
-            ret_type = block_type(ret_type, ret_shape)
-        return tensor(func(lib_name, lib_path, symbol, arg_list, ret_type.to_ir(_builder), is_pure), ret_type)
+        builder = _semantic.builder
+        return tensor(func(lib_name, lib_path, symbol, arg_list, ret_type.to_ir(builder), is_pure), ret_type)


 @builtin
 def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, is_pure: bool,
-                       _builder=None):
+                       _semantic=None):
     '''
     Dispatch an elementwise function to a library
     :param lib_name: the name of the library
@@ -3020,20 +3357,20 @@ def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, is_pure: bool,
     :param args: the arguments of the function
     :param arg_type_symbol_dict: the type of the arguments
     :param is_pure: whether the function is pure
-    :param _builder: the builder
     :return: the return value of the function
     '''
     dispatch_args = args.copy()
     all_scalar = True
-    ret_shape = None
     arg_types = []
     for i in builtins.range(len(dispatch_args)):
-        dispatch_args[i] = _to_tensor(dispatch_args[i], _builder)
+        dispatch_args[i] = _semantic.to_tensor(dispatch_args[i])
         arg_types.append(dispatch_args[i].dtype)
         if dispatch_args[i].type.is_block():
             all_scalar = False
+
+    arg_types = tuple(arg_types)
+    ret_type = arg_type_symbol_dict[arg_types][1]
     if len(arg_types) > 0:
-        arg_types = tuple(arg_types)
         arithmetic_check = True
         # If there's a type tuple that is not supported by the library, we will do arithmetic check
         if arg_types in arg_type_symbol_dict:
@@ -3041,26 +3378,26 @@ def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, is_pure: bool,
         broadcast_arg = dispatch_args[0]
         # Get the broadcast shape over all the arguments
         for item in dispatch_args:
-            _, broadcast_arg = semantic.binary_op_type_checking_impl(item, broadcast_arg, _builder,
-                                                                     arithmetic_check=arithmetic_check)
+            _, broadcast_arg = _semantic.binary_op_type_checking_impl(item, broadcast_arg,
                                                                       arithmetic_check=arithmetic_check)
         # Change the shape of each argument based on the broadcast shape
         for i in builtins.range(len(dispatch_args)):
-            dispatch_args[i], _ = semantic.binary_op_type_checking_impl(dispatch_args[i], broadcast_arg,
-                                                                        arithmetic_check=arithmetic_check)
+            dispatch_args[i], _ = _semantic.binary_op_type_checking_impl(dispatch_args[i], broadcast_arg,
                                                                          arithmetic_check=arithmetic_check)
         if not all_scalar:
-            ret_shape = broadcast_arg.shape
-    func = getattr(_builder, "create_extern_elementwise")
-    return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_shape, is_pure, _builder)
+            ret_type = broadcast_arg.type.with_element_ty(ret_type)
+    func = _semantic.builder.create_extern_elementwise
+    return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_type, is_pure, _semantic)


-def binary_op_type_legalization(lhs, rhs, builder):
+def binary_op_type_legalization(lhs, rhs, semantic):
     '''
     Convert both operands to a single common type
     :param lhs: the left operand
     :param rhs: the right operand
     :param builder: the builder
     '''
-    return semantic.binary_op_type_checking_impl(lhs, rhs, builder)
+    return semantic.binary_op_type_checking_impl(lhs, rhs)


 def extern(fn):
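
A hedged sketch of dispatching to an external library from kernel code (the libdevice path and kernel name below are placeholders; in practice the backend resolves the real bitcode path, and __nv_sinf is a standard libdevice symbol):

    import triton
    import triton.language as tl

    @triton.jit
    def sin_kernel(x_ptr, y_ptr, BLOCK: tl.constexpr):
        offs = tl.arange(0, BLOCK)
        x = tl.load(x_ptr + offs)
        y = tl.extern_elementwise(
            "libdevice", "/path/to/libdevice.10.bc",     # placeholder path
            [x],
            {(tl.float32,): ("__nv_sinf", tl.float32)},  # symbol and return type per argument-type tuple
            is_pure=True,
        )
        tl.store(y_ptr + offs, y)
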