triton-windows 3.3.1.post19__cp311-cp311-win_amd64.whl → 3.4.0.post20__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of triton-windows might be problematic. Click here for more details.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +4 -1
- triton/_filecheck.py +87 -0
- triton/_internal_testing.py +26 -15
- triton/_utils.py +110 -21
- triton/backends/__init__.py +20 -23
- triton/backends/amd/__init__.py +0 -0
- triton/backends/amd/compiler.py +112 -78
- triton/backends/amd/driver.c +5 -2
- triton/backends/amd/driver.py +149 -47
- triton/backends/compiler.py +7 -21
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +92 -93
- triton/backends/nvidia/driver.c +90 -98
- triton/backends/nvidia/driver.py +303 -128
- triton/compiler/code_generator.py +212 -111
- triton/compiler/compiler.py +110 -25
- triton/experimental/__init__.py +0 -0
- triton/experimental/gluon/__init__.py +4 -0
- triton/experimental/gluon/_compiler.py +0 -0
- triton/experimental/gluon/_runtime.py +99 -0
- triton/experimental/gluon/language/__init__.py +18 -0
- triton/experimental/gluon/language/_core.py +312 -0
- triton/experimental/gluon/language/_layouts.py +230 -0
- triton/experimental/gluon/language/_math.py +12 -0
- triton/experimental/gluon/language/_semantic.py +287 -0
- triton/experimental/gluon/language/_standard.py +47 -0
- triton/experimental/gluon/language/nvidia/__init__.py +4 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +202 -0
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +32 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +11 -0
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +51 -0
- triton/experimental/gluon/language/nvidia/hopper/tma.py +96 -0
- triton/experimental/gluon/nvidia/__init__.py +4 -0
- triton/experimental/gluon/nvidia/blackwell.py +3 -0
- triton/experimental/gluon/nvidia/hopper.py +40 -0
- triton/knobs.py +481 -0
- triton/language/__init__.py +39 -14
- triton/language/core.py +794 -537
- triton/language/extra/cuda/__init__.py +10 -7
- triton/language/extra/cuda/gdc.py +42 -0
- triton/language/extra/cuda/libdevice.py +394 -394
- triton/language/extra/cuda/utils.py +21 -21
- triton/language/extra/hip/libdevice.py +113 -104
- triton/language/math.py +65 -66
- triton/language/random.py +12 -2
- triton/language/semantic.py +1706 -1770
- triton/language/standard.py +116 -51
- triton/runtime/autotuner.py +117 -59
- triton/runtime/build.py +76 -12
- triton/runtime/cache.py +18 -47
- triton/runtime/driver.py +32 -29
- triton/runtime/interpreter.py +72 -35
- triton/runtime/jit.py +146 -110
- triton/testing.py +16 -12
- triton/tools/disasm.py +3 -4
- triton/tools/tensor_descriptor.py +36 -0
- triton/windows_utils.py +14 -6
- {triton_windows-3.3.1.post19.dist-info → triton_windows-3.4.0.post20.dist-info}/METADATA +7 -2
- triton_windows-3.4.0.post20.dist-info/RECORD +186 -0
- triton_windows-3.4.0.post20.dist-info/entry_points.txt +3 -0
- triton_windows-3.4.0.post20.dist-info/licenses/LICENSE +23 -0
- triton_windows-3.4.0.post20.dist-info/top_level.txt +1 -0
- triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
- triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
- triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
- triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
- triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
- triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
- triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
- triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
- triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
- triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
- triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
- triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
- triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
- triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
- triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
- triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
- triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
- triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
- triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
- triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
- triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
- triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
- triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
- triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
- triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
- triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
- triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
- triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
- triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
- triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
- triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
- triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
- triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
- triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
- triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
- triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
- triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
- triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
- triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
- triton/backends/amd/include/hip/channel_descriptor.h +0 -39
- triton/backends/amd/include/hip/device_functions.h +0 -38
- triton/backends/amd/include/hip/driver_types.h +0 -468
- triton/backends/amd/include/hip/hip_bf16.h +0 -36
- triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
- triton/backends/amd/include/hip/hip_common.h +0 -100
- triton/backends/amd/include/hip/hip_complex.h +0 -38
- triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
- triton/backends/amd/include/hip/hip_deprecated.h +0 -95
- triton/backends/amd/include/hip/hip_ext.h +0 -161
- triton/backends/amd/include/hip/hip_fp16.h +0 -36
- triton/backends/amd/include/hip/hip_fp8.h +0 -33
- triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
- triton/backends/amd/include/hip/hip_hcc.h +0 -24
- triton/backends/amd/include/hip/hip_math_constants.h +0 -36
- triton/backends/amd/include/hip/hip_profile.h +0 -27
- triton/backends/amd/include/hip/hip_runtime.h +0 -75
- triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
- triton/backends/amd/include/hip/hip_texture_types.h +0 -29
- triton/backends/amd/include/hip/hip_vector_types.h +0 -41
- triton/backends/amd/include/hip/hip_version.h +0 -17
- triton/backends/amd/include/hip/hiprtc.h +0 -421
- triton/backends/amd/include/hip/library_types.h +0 -78
- triton/backends/amd/include/hip/math_functions.h +0 -42
- triton/backends/amd/include/hip/surface_types.h +0 -63
- triton/backends/amd/include/hip/texture_types.h +0 -194
- triton/backends/amd/include/hsa/Brig.h +0 -1131
- triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
- triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
- triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
- triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
- triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
- triton/backends/amd/include/hsa/hsa.h +0 -5738
- triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
- triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
- triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
- triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
- triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
- triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
- triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
- triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
- triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
- triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
- triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
- triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
- triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
- triton/backends/amd/include/roctracer/roctracer.h +0 -779
- triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
- triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
- triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
- triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
- triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
- triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
- triton/backends/amd/include/roctracer/roctx.h +0 -229
- triton/language/_utils.py +0 -21
- triton/language/extra/cuda/_experimental_tma.py +0 -106
- triton/tools/experimental_descriptor.py +0 -32
- triton_windows-3.3.1.post19.dist-info/RECORD +0 -260
- triton_windows-3.3.1.post19.dist-info/top_level.txt +0 -14
- {triton_windows-3.3.1.post19.dist-info → triton_windows-3.4.0.post20.dist-info}/WHEEL +0 -0
triton/language/semantic.py
CHANGED
|
@@ -1,13 +1,16 @@
|
|
|
1
1
|
from __future__ import annotations # remove after python 3.11
|
|
2
2
|
import warnings
|
|
3
3
|
|
|
4
|
-
from typing import List, Optional, Sequence, Tuple, TypeVar
|
|
4
|
+
from typing import List, Optional, Sequence, Tuple, TypeVar, Generic, Type
|
|
5
5
|
import numbers
|
|
6
6
|
|
|
7
|
+
from triton.runtime import driver
|
|
8
|
+
|
|
7
9
|
from .._C.libtriton import ir
|
|
8
10
|
from . import core as tl
|
|
9
11
|
|
|
10
12
|
T = TypeVar('T')
|
|
13
|
+
TensorTy = TypeVar('TensorTy')
|
|
11
14
|
|
|
12
15
|
|
|
13
16
|
class IncompatibleTypeErrorImpl(Exception):
|
|
@@ -19,1932 +22,1865 @@ class IncompatibleTypeErrorImpl(Exception):
|
|
|
19
22
|
super(IncompatibleTypeErrorImpl, self).__init__(self.message)
|
|
20
23
|
|
|
21
24
|
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
+
class TritonSemantic(Generic[TensorTy]):
|
|
26
|
+
tensor: Type[TensorTy] = tl.tensor
|
|
27
|
+
lang = tl
|
|
25
28
|
|
|
29
|
+
builder: ir.builder
|
|
26
30
|
|
|
27
|
-
def
|
|
28
|
-
|
|
29
|
-
raise ValueError(f"program_id axis must be 0, 1, or 2 but got {axis}")
|
|
30
|
-
return tl.tensor(builder.create_get_program_id(axis), tl.int32)
|
|
31
|
+
def __init__(self, builder):
|
|
32
|
+
self.builder = builder
|
|
31
33
|
|
|
34
|
+
# ===----------------------------------------------------------------------===##
|
|
35
|
+
# Programming Model
|
|
36
|
+
# ===----------------------------------------------------------------------===##
|
|
32
37
|
|
|
33
|
-
def
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
38
|
+
def program_id(self, axis: int) -> TensorTy:
|
|
39
|
+
if axis not in (0, 1, 2):
|
|
40
|
+
raise ValueError(f"program_id axis must be 0, 1, or 2 but got {axis}")
|
|
41
|
+
return self.tensor(self.builder.create_get_program_id(axis), tl.int32)
|
|
37
42
|
|
|
43
|
+
def num_programs(self, axis: int) -> TensorTy:
|
|
44
|
+
if axis not in (0, 1, 2):
|
|
45
|
+
raise ValueError(f"num_programs axis must be 0, 1, or 2 but got {axis}")
|
|
46
|
+
return self.tensor(self.builder.create_get_num_programs(axis), tl.int32)
|
|
38
47
|
|
|
39
48
|
# ===----------------------------------------------------------------------===//
|
|
40
49
|
# Implicit Casting Utilities
|
|
41
50
|
# ===----------------------------------------------------------------------===//
|
|
42
51
|
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
# converted to float
|
|
79
|
-
if a_ty.is_fp32() or b_ty.is_fp32():
|
|
80
|
-
return tl.float32
|
|
81
|
-
# 3 ) if one operand is half, the other is implicitly converted to half
|
|
82
|
-
# unless we're doing / or %, which do not exist natively in PTX for fp16.
|
|
83
|
-
# Supported PTX op: add, sub, mul, fma, neg, abs, min, max, tanh, ex2, setp
|
|
84
|
-
if a_ty.is_fp16() or b_ty.is_fp16():
|
|
85
|
-
if div_or_mod:
|
|
52
|
+
def integer_promote_impl(self, a_ty: tl.dtype, b_ty: tl.dtype) -> tl.dtype:
|
|
53
|
+
a_rank = a_ty.int_bitwidth
|
|
54
|
+
b_rank = b_ty.int_bitwidth
|
|
55
|
+
a_sn = a_ty.int_signedness
|
|
56
|
+
b_sn = b_ty.int_signedness
|
|
57
|
+
# Rules for signedness taken from "Usual arithmetic conversions" on
|
|
58
|
+
# https://en.cppreference.com/w/c/language/conversion.
|
|
59
|
+
if a_sn == b_sn:
|
|
60
|
+
return a_ty if a_rank > b_rank else b_ty
|
|
61
|
+
elif a_sn == tl.dtype.SIGNEDNESS.UNSIGNED:
|
|
62
|
+
return a_ty if a_rank >= b_rank else b_ty
|
|
63
|
+
elif b_sn == tl.dtype.SIGNEDNESS.UNSIGNED:
|
|
64
|
+
return b_ty if b_rank >= a_rank else a_ty
|
|
65
|
+
raise TypeError(f"unexpected signedness {a_sn} and {b_sn}")
|
|
66
|
+
|
|
67
|
+
def computation_type_impl(self, a_ty: tl.dtype, a_is_scalar: bool, b_ty: tl.dtype, b_is_scalar: bool,
|
|
68
|
+
div_or_mod: bool) -> tl.dtype:
|
|
69
|
+
# 0) For scalars we follow semantics similar to PyTorch, namely:
|
|
70
|
+
# - If the scalar is of a lower or equal kind (bool < uint < int < fp),
|
|
71
|
+
# it doesn't participate in the promotion
|
|
72
|
+
if a_is_scalar != b_is_scalar:
|
|
73
|
+
scalar_ty, tensor_ty = (a_ty, b_ty) if a_is_scalar else (b_ty, a_ty)
|
|
74
|
+
if scalar_ty.kind().value <= tensor_ty.kind().value:
|
|
75
|
+
# Upcast because of 3) and 4) below!
|
|
76
|
+
if div_or_mod and (tensor_ty in (tl.float16, tl.bfloat16)):
|
|
77
|
+
return tl.float32
|
|
78
|
+
return tensor_ty
|
|
79
|
+
|
|
80
|
+
# 1) if one operand is double, the other is implicitly
|
|
81
|
+
# converted to double
|
|
82
|
+
if a_ty.is_fp64() or b_ty.is_fp64():
|
|
83
|
+
return tl.float64
|
|
84
|
+
# 2) if one operand is float, the other is implicitly
|
|
85
|
+
# converted to float
|
|
86
|
+
if a_ty.is_fp32() or b_ty.is_fp32():
|
|
86
87
|
return tl.float32
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
88
|
+
# 3 ) if one operand is half, the other is implicitly converted to half
|
|
89
|
+
# unless we're doing / or %, which do not exist natively in PTX for fp16.
|
|
90
|
+
# Supported PTX op: add, sub, mul, fma, neg, abs, min, max, tanh, ex2, setp
|
|
91
|
+
if a_ty.is_fp16() or b_ty.is_fp16():
|
|
92
|
+
if div_or_mod:
|
|
93
|
+
return tl.float32
|
|
94
|
+
else:
|
|
95
|
+
return tl.float16
|
|
96
|
+
# 4) return bf16 only if both operands are of bf16
|
|
97
|
+
if a_ty.is_bf16() and b_ty.is_bf16():
|
|
98
|
+
if div_or_mod:
|
|
99
|
+
return tl.float32
|
|
100
|
+
else:
|
|
101
|
+
return tl.bfloat16
|
|
102
|
+
if a_ty.is_bf16() or b_ty.is_bf16():
|
|
92
103
|
return tl.float32
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
dtype
|
|
122
|
-
elif
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
return to_tensor(x.value, builder)
|
|
142
|
-
elif isinstance(x, tl.tensor):
|
|
104
|
+
# 5) return fp16 if operands are different fp8
|
|
105
|
+
if a_ty.is_fp8() and b_ty.is_fp8():
|
|
106
|
+
return a_ty if a_ty == b_ty else tl.float16
|
|
107
|
+
if not a_ty.is_int() or not b_ty.is_int():
|
|
108
|
+
raise TypeError(f"unexpected type {a_ty} and {b_ty}")
|
|
109
|
+
# 6 ) both operands are integer and undergo
|
|
110
|
+
# integer promotion
|
|
111
|
+
if div_or_mod and a_ty.int_signedness != b_ty.int_signedness:
|
|
112
|
+
raise TypeError("Cannot use /, #, or % with " + a_ty.__repr__() + " and " + b_ty.__repr__() +
|
|
113
|
+
" because they have different signedness;"
|
|
114
|
+
"this is unlikely to result in a useful answer. Cast them to the same signedness.")
|
|
115
|
+
return self.integer_promote_impl(a_ty, b_ty)
|
|
116
|
+
|
|
117
|
+
def to_tensor(self, x, check_type: bool = True):
|
|
118
|
+
if isinstance(x, bool):
|
|
119
|
+
return self.tensor(self.builder.get_int1(x), tl.int1)
|
|
120
|
+
# Note: compile-time const integers are represented by unsigned values
|
|
121
|
+
elif isinstance(x, int):
|
|
122
|
+
if -2**31 <= x < 2**31:
|
|
123
|
+
dtype = tl.int32
|
|
124
|
+
elif 2**31 <= x < 2**32:
|
|
125
|
+
dtype = tl.uint32
|
|
126
|
+
elif -2**63 <= x < 2**63:
|
|
127
|
+
dtype = tl.int64
|
|
128
|
+
elif 2**63 <= x < 2**64:
|
|
129
|
+
dtype = tl.uint64
|
|
130
|
+
else:
|
|
131
|
+
raise ValueError(f'Nonrepresentable integer {x}.')
|
|
132
|
+
return self.scalar_constant(x, dtype=dtype)
|
|
133
|
+
elif isinstance(x, float):
|
|
134
|
+
min_float32 = 2**-126
|
|
135
|
+
max_float32 = (2 - 2**-23) * 2**127
|
|
136
|
+
abs_x = __builtins__['abs'](x)
|
|
137
|
+
if abs_x == float("inf") or\
|
|
138
|
+
abs_x == 0.0 or \
|
|
139
|
+
x != x or \
|
|
140
|
+
min_float32 <= abs_x <= max_float32:
|
|
141
|
+
dtype = tl.float32
|
|
142
|
+
else:
|
|
143
|
+
dtype = tl.float64
|
|
144
|
+
return self.scalar_constant(x, dtype=dtype)
|
|
145
|
+
|
|
146
|
+
elif isinstance(x, tl.constexpr):
|
|
147
|
+
return self.to_tensor(x.value)
|
|
148
|
+
elif isinstance(x, self.tensor):
|
|
149
|
+
return x
|
|
150
|
+
if check_type:
|
|
151
|
+
raise TypeError(f"cannot convert {x} of type {type(x)} to tensor")
|
|
143
152
|
return x
|
|
144
|
-
if check_type:
|
|
145
|
-
raise TypeError(f"cannot convert {x} of type {type(x)} to tensor")
|
|
146
|
-
return x
|
|
147
|
-
|
|
148
153
|
|
|
149
154
|
# ===----------------------------------------------------------------------===//
|
|
150
155
|
# Binary Operators
|
|
151
156
|
# ===----------------------------------------------------------------------===//
|
|
152
157
|
|
|
158
|
+
def check_ptr_type_impl(self, type_a: tl.dtype, type_b: tl.dtype, allow_ptr_a: bool) -> None:
|
|
159
|
+
if type_a.is_ptr():
|
|
160
|
+
if not allow_ptr_a:
|
|
161
|
+
raise IncompatibleTypeErrorImpl(type_a, type_b)
|
|
162
|
+
# T* + U* with T != U
|
|
163
|
+
if type_b.is_ptr() and (type_a != type_b):
|
|
164
|
+
raise IncompatibleTypeErrorImpl(type_a, type_b)
|
|
165
|
+
# T* + float
|
|
166
|
+
if type_b.is_floating():
|
|
167
|
+
raise IncompatibleTypeErrorImpl(type_a, type_b)
|
|
168
|
+
|
|
169
|
+
def binary_op_type_checking_impl(self, lhs: TensorTy | numbers.Number, rhs: TensorTy | numbers.Number,
|
|
170
|
+
allow_lhs_ptr=False, allow_rhs_ptr=False, arithmetic_check=True,
|
|
171
|
+
div_or_mod=False) -> Tuple[TensorTy, TensorTy]:
|
|
172
|
+
lhs_is_scalar = isinstance(lhs, numbers.Number)
|
|
173
|
+
rhs_is_scalar = isinstance(rhs, numbers.Number)
|
|
174
|
+
if lhs_is_scalar:
|
|
175
|
+
lhs_scalar = lhs
|
|
176
|
+
lhs = self.to_tensor(lhs)
|
|
177
|
+
if rhs_is_scalar:
|
|
178
|
+
rhs_scalar = rhs
|
|
179
|
+
rhs = self.to_tensor(rhs)
|
|
180
|
+
|
|
181
|
+
# implicit typecasting
|
|
182
|
+
lhs_sca_ty = lhs.type.scalar
|
|
183
|
+
rhs_sca_ty = rhs.type.scalar
|
|
184
|
+
self.check_ptr_type_impl(lhs_sca_ty, rhs_sca_ty, allow_lhs_ptr)
|
|
185
|
+
self.check_ptr_type_impl(rhs_sca_ty, lhs_sca_ty, allow_rhs_ptr)
|
|
186
|
+
if arithmetic_check and not lhs_sca_ty.is_ptr() and not rhs_sca_ty.is_ptr():
|
|
187
|
+
ret_sca_ty = self.computation_type_impl(lhs_sca_ty, lhs_is_scalar, rhs_sca_ty, rhs_is_scalar, div_or_mod)
|
|
188
|
+
if (lhs_is_scalar and lhs_scalar < 0 and ret_sca_ty.is_int_unsigned()
|
|
189
|
+
or rhs_is_scalar and rhs_scalar < 0 and ret_sca_ty.is_int_unsigned()):
|
|
190
|
+
raise ValueError("Cannot perform a binary operation between an unsigned tensor and a negative scalar. "
|
|
191
|
+
"Perform a explicit cast on one of them.")
|
|
192
|
+
if ret_sca_ty.is_int():
|
|
193
|
+
if lhs_is_scalar and not (ret_sca_ty.get_int_min_value() <= lhs_scalar <=
|
|
194
|
+
ret_sca_ty.get_int_max_value()):
|
|
195
|
+
raise ValueError(f"Scalar {lhs_scalar} is out of range for type {ret_sca_ty}")
|
|
196
|
+
if rhs_is_scalar and not (ret_sca_ty.get_int_min_value() <= rhs_scalar <=
|
|
197
|
+
ret_sca_ty.get_int_max_value()):
|
|
198
|
+
raise ValueError(f"Scalar {rhs_scalar} is out of range for type {ret_sca_ty}")
|
|
199
|
+
lhs = self.scalar_constant(lhs_scalar, dtype=ret_sca_ty) if lhs_is_scalar else self.cast(lhs, ret_sca_ty)
|
|
200
|
+
rhs = self.scalar_constant(rhs_scalar, dtype=ret_sca_ty) if rhs_is_scalar else self.cast(rhs, ret_sca_ty)
|
|
201
|
+
|
|
202
|
+
# implicit broadcasting
|
|
203
|
+
lhs, rhs = self.broadcast_impl_value(lhs, rhs)
|
|
204
|
+
return lhs, rhs
|
|
205
|
+
|
|
206
|
+
def binary_op_sanitize_overflow_impl(self, lhs: TensorTy, rhs: TensorTy, binary_op: callable):
|
|
207
|
+
if lhs.type.scalar.int_bitwidth >= 64 or not self.builder.options.sanitize_overflow:
|
|
208
|
+
return
|
|
209
|
+
lhs_sca_ty = lhs.type.scalar
|
|
210
|
+
rhs_sca_ty = rhs.type.scalar
|
|
211
|
+
assert lhs_sca_ty == rhs_sca_ty
|
|
212
|
+
assert lhs_sca_ty.is_int()
|
|
213
|
+
lhs = self.cast(lhs, tl.int64)
|
|
214
|
+
rhs = self.cast(rhs, tl.int64)
|
|
215
|
+
ret = binary_op(lhs, rhs, False)
|
|
216
|
+
max_value = lhs_sca_ty.get_int_max_value()
|
|
217
|
+
max_value = self.scalar_constant(max_value, tl.int64)
|
|
218
|
+
min_value = lhs_sca_ty.get_int_min_value()
|
|
219
|
+
min_value = self.scalar_constant(min_value, tl.int64)
|
|
220
|
+
cond = self.and_(self.less_equal(ret, max_value), self.greater_equal(ret, min_value))
|
|
221
|
+
msg = f"int{lhs_sca_ty.int_bitwidth} overflow detected for operation {binary_op.__name__}"
|
|
222
|
+
self.device_assert(cond, msg)
|
|
223
|
+
|
|
224
|
+
def add(self, input: TensorTy | numbers.Number, other: TensorTy | numbers.Number,
|
|
225
|
+
sanitize_overflow: bool) -> TensorTy:
|
|
226
|
+
input, other = self.binary_op_type_checking_impl(input, other, True, True)
|
|
227
|
+
input_scalar_ty = input.type.scalar
|
|
228
|
+
other_scalar_ty = other.type.scalar
|
|
229
|
+
if input_scalar_ty.is_ptr() and other_scalar_ty.is_ptr():
|
|
230
|
+
raise TypeError("cannot add pointers together")
|
|
231
|
+
|
|
232
|
+
# offset + ptr
|
|
233
|
+
# ptr + offset
|
|
234
|
+
if other_scalar_ty.is_ptr() and not input_scalar_ty.is_ptr():
|
|
235
|
+
input, other = other, input
|
|
236
|
+
input_scalar_ty = input.type.scalar
|
|
237
|
+
other_scalar_ty = other.type.scalar
|
|
238
|
+
if input_scalar_ty.is_ptr():
|
|
239
|
+
other_handle = other.handle
|
|
240
|
+
if other.dtype.is_int_unsigned() and other.dtype.int_bitwidth < 64:
|
|
241
|
+
# addptr treats offset as signed. Zero-extend unsigned offsets to ensure they're positive
|
|
242
|
+
i64_ty = other.type.with_element_ty(tl.int64).to_ir(self.builder)
|
|
243
|
+
other_handle = self.builder.create_int_cast(other.handle, i64_ty, False)
|
|
244
|
+
return self.tensor(self.builder.create_addptr(input.handle, other_handle), input.type)
|
|
245
|
+
# float + float
|
|
246
|
+
elif input_scalar_ty.is_floating():
|
|
247
|
+
return self.tensor(self.builder.create_fadd(input.handle, other.handle), input.type)
|
|
248
|
+
# int + int
|
|
249
|
+
elif input_scalar_ty.is_int():
|
|
250
|
+
if sanitize_overflow:
|
|
251
|
+
self.binary_op_sanitize_overflow_impl(input, other, self.add)
|
|
252
|
+
return self.tensor(self.builder.create_add(input.handle, other.handle), input.type)
|
|
253
|
+
raise TypeError(f"unexpected type {input_scalar_ty}")
|
|
153
254
|
|
|
154
|
-
def
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
#
|
|
159
|
-
if
|
|
160
|
-
|
|
161
|
-
#
|
|
162
|
-
if
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
raise ValueError("Cannot perform a binary operation between an unsigned tensor and a negative scalar. "
|
|
188
|
-
"Perform a explicit cast on one of them.")
|
|
189
|
-
if ret_sca_ty.is_int():
|
|
190
|
-
if lhs_is_scalar and not (ret_sca_ty.get_int_min_value() <= lhs_scalar <= ret_sca_ty.get_int_max_value()):
|
|
191
|
-
raise ValueError(f"Scalar {lhs_scalar} is out of range for type {ret_sca_ty}")
|
|
192
|
-
if rhs_is_scalar and not (ret_sca_ty.get_int_min_value() <= rhs_scalar <= ret_sca_ty.get_int_max_value()):
|
|
193
|
-
raise ValueError(f"Scalar {rhs_scalar} is out of range for type {ret_sca_ty}")
|
|
194
|
-
lhs = full(
|
|
195
|
-
(), lhs_scalar, dtype=ret_sca_ty, builder=builder) if lhs_is_scalar else cast(lhs, ret_sca_ty, builder)
|
|
196
|
-
rhs = full(
|
|
197
|
-
(), rhs_scalar, dtype=ret_sca_ty, builder=builder) if rhs_is_scalar else cast(rhs, ret_sca_ty, builder)
|
|
198
|
-
|
|
199
|
-
# implicit broadcasting
|
|
200
|
-
lhs, rhs = broadcast_impl_value(lhs, rhs, builder)
|
|
201
|
-
return lhs, rhs
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
def binary_op_sanitize_overflow_impl(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder, binary_op: callable):
|
|
205
|
-
if lhs.type.scalar.int_bitwidth >= 64 or not builder.options.sanitize_overflow:
|
|
206
|
-
return
|
|
207
|
-
lhs_sca_ty = lhs.type.scalar
|
|
208
|
-
rhs_sca_ty = rhs.type.scalar
|
|
209
|
-
assert lhs_sca_ty == rhs_sca_ty
|
|
210
|
-
assert lhs_sca_ty.is_int()
|
|
211
|
-
lhs = cast(lhs, tl.int64, builder)
|
|
212
|
-
rhs = cast(rhs, tl.int64, builder)
|
|
213
|
-
ret = binary_op(lhs, rhs, False, builder)
|
|
214
|
-
max_value = lhs_sca_ty.get_int_max_value()
|
|
215
|
-
max_value = tl.tensor(builder.get_int64(max_value), tl.int64)
|
|
216
|
-
min_value = lhs_sca_ty.get_int_min_value()
|
|
217
|
-
min_value = tl.tensor(builder.get_int64(min_value), tl.int64)
|
|
218
|
-
cond = and_(less_equal(ret, max_value, builder), greater_equal(ret, min_value, builder), builder)
|
|
219
|
-
msg = f"int{lhs_sca_ty.int_bitwidth} overflow detected for operation {binary_op.__name__}"
|
|
220
|
-
device_assert(cond, msg, builder)
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
def add(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, sanitize_overflow: bool,
|
|
224
|
-
builder: ir.builder) -> tl.tensor:
|
|
225
|
-
input, other = binary_op_type_checking_impl(input, other, builder, True, True)
|
|
226
|
-
input_scalar_ty = input.type.scalar
|
|
227
|
-
other_scalar_ty = other.type.scalar
|
|
228
|
-
if input_scalar_ty.is_ptr() and other_scalar_ty.is_ptr():
|
|
229
|
-
raise TypeError("cannot add pointers together")
|
|
230
|
-
|
|
231
|
-
# offset + ptr
|
|
232
|
-
# ptr + offset
|
|
233
|
-
if other_scalar_ty.is_ptr() and not input_scalar_ty.is_ptr():
|
|
234
|
-
input, other = other, input
|
|
255
|
+
def sub(self, input: TensorTy | numbers.Number, other: TensorTy | numbers.Number,
|
|
256
|
+
sanitize_overflow: bool) -> TensorTy:
|
|
257
|
+
input, other = self.binary_op_type_checking_impl(input, other, True, False)
|
|
258
|
+
scalar_ty = input.type.scalar
|
|
259
|
+
# ptr - offset
|
|
260
|
+
if scalar_ty.is_ptr():
|
|
261
|
+
return self.add(input, self.minus(other), sanitize_overflow=False)
|
|
262
|
+
# float - float
|
|
263
|
+
if scalar_ty.is_floating():
|
|
264
|
+
return self.tensor(self.builder.create_fsub(input.handle, other.handle), input.type)
|
|
265
|
+
# int - int
|
|
266
|
+
elif scalar_ty.is_int():
|
|
267
|
+
if sanitize_overflow:
|
|
268
|
+
self.binary_op_sanitize_overflow_impl(input, other, self.sub)
|
|
269
|
+
return self.tensor(self.builder.create_sub(input.handle, other.handle), input.type)
|
|
270
|
+
raise TypeError(f"unexpected type {scalar_ty}")
|
|
271
|
+
|
|
272
|
+
def mul(self, input: TensorTy | numbers.Number, other: TensorTy | numbers.Number,
|
|
273
|
+
sanitize_overflow: bool) -> TensorTy:
|
|
274
|
+
input, other = self.binary_op_type_checking_impl(input, other)
|
|
275
|
+
scalar_ty = input.type.scalar
|
|
276
|
+
# float * float
|
|
277
|
+
if scalar_ty.is_floating():
|
|
278
|
+
return self.tensor(self.builder.create_fmul(input.handle, other.handle), input.type)
|
|
279
|
+
# int * int
|
|
280
|
+
elif scalar_ty.is_int():
|
|
281
|
+
if sanitize_overflow:
|
|
282
|
+
self.binary_op_sanitize_overflow_impl(input, other, self.mul)
|
|
283
|
+
return self.tensor(self.builder.create_mul(input.handle, other.handle), input.type)
|
|
284
|
+
raise TypeError(f"unexpected type {scalar_ty}")
|
|
285
|
+
|
|
286
|
+
def truediv(self, input: TensorTy | numbers.Number, other: TensorTy | numbers.Number) -> TensorTy:
|
|
287
|
+
input, other = self.binary_op_type_checking_impl(input, other, False, False, True, True)
|
|
235
288
|
input_scalar_ty = input.type.scalar
|
|
236
289
|
other_scalar_ty = other.type.scalar
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
290
|
+
# float / int
|
|
291
|
+
if input_scalar_ty.is_floating() and other_scalar_ty.is_int():
|
|
292
|
+
other = self.cast(other, input_scalar_ty)
|
|
293
|
+
# int / float
|
|
294
|
+
elif input_scalar_ty.is_int() and other_scalar_ty.is_floating():
|
|
295
|
+
input = self.cast(input, other_scalar_ty)
|
|
296
|
+
# int / int (cast to tl.float32)
|
|
297
|
+
elif input_scalar_ty.is_int() and other_scalar_ty.is_int():
|
|
298
|
+
input = self.cast(input, tl.float32)
|
|
299
|
+
other = self.cast(other, tl.float32)
|
|
300
|
+
# float / float (cast to the highest exponent type)
|
|
301
|
+
elif input_scalar_ty.is_floating() and other_scalar_ty.is_floating():
|
|
302
|
+
if input_scalar_ty.fp_mantissa_width > other_scalar_ty.fp_mantissa_width:
|
|
303
|
+
other = self.cast(other, input_scalar_ty)
|
|
243
304
|
else:
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
return tl.tensor(builder.create_addptr(input.handle, other_handle), input.type)
|
|
247
|
-
# float + float
|
|
248
|
-
elif input_scalar_ty.is_floating():
|
|
249
|
-
return tl.tensor(builder.create_fadd(input.handle, other.handle), input.type)
|
|
250
|
-
# int + int
|
|
251
|
-
elif input_scalar_ty.is_int():
|
|
252
|
-
if sanitize_overflow:
|
|
253
|
-
binary_op_sanitize_overflow_impl(input, other, builder, add)
|
|
254
|
-
return tl.tensor(builder.create_add(input.handle, other.handle), input.type)
|
|
255
|
-
raise TypeError(f"unexpected type {input_scalar_ty}")
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
def sub(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, sanitize_overflow: bool,
|
|
259
|
-
builder: ir.builder) -> tl.tensor:
|
|
260
|
-
input, other = binary_op_type_checking_impl(input, other, builder, True, False)
|
|
261
|
-
scalar_ty = input.type.scalar
|
|
262
|
-
# ptr - offset
|
|
263
|
-
if scalar_ty.is_ptr():
|
|
264
|
-
return tl.tensor(builder.create_addptr(input.handle, minus(other, builder).handle), input.type)
|
|
265
|
-
# float - float
|
|
266
|
-
if scalar_ty.is_floating():
|
|
267
|
-
return tl.tensor(builder.create_fsub(input.handle, other.handle), input.type)
|
|
268
|
-
# int - int
|
|
269
|
-
elif scalar_ty.is_int():
|
|
270
|
-
if sanitize_overflow:
|
|
271
|
-
binary_op_sanitize_overflow_impl(input, other, builder, sub)
|
|
272
|
-
return tl.tensor(builder.create_sub(input.handle, other.handle), input.type)
|
|
273
|
-
raise TypeError(f"unexpected type {scalar_ty}")
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
def mul(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, sanitize_overflow: bool,
|
|
277
|
-
builder: ir.builder) -> tl.tensor:
|
|
278
|
-
input, other = binary_op_type_checking_impl(input, other, builder)
|
|
279
|
-
scalar_ty = input.type.scalar
|
|
280
|
-
# float * float
|
|
281
|
-
if scalar_ty.is_floating():
|
|
282
|
-
return tl.tensor(builder.create_fmul(input.handle, other.handle), input.type)
|
|
283
|
-
# int * int
|
|
284
|
-
elif scalar_ty.is_int():
|
|
285
|
-
if sanitize_overflow:
|
|
286
|
-
binary_op_sanitize_overflow_impl(input, other, builder, mul)
|
|
287
|
-
return tl.tensor(builder.create_mul(input.handle, other.handle), input.type)
|
|
288
|
-
raise TypeError(f"unexpected type {scalar_ty}")
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
def truediv(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, builder: ir.builder) -> tl.tensor:
|
|
292
|
-
input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True)
|
|
293
|
-
input_scalar_ty = input.type.scalar
|
|
294
|
-
other_scalar_ty = other.type.scalar
|
|
295
|
-
# float / int
|
|
296
|
-
if input_scalar_ty.is_floating() and other_scalar_ty.is_int():
|
|
297
|
-
other = cast(other, input_scalar_ty, builder)
|
|
298
|
-
# int / float
|
|
299
|
-
elif input_scalar_ty.is_int() and other_scalar_ty.is_floating():
|
|
300
|
-
input = cast(input, other_scalar_ty, builder)
|
|
301
|
-
# int / int (cast to tl.float32)
|
|
302
|
-
elif input_scalar_ty.is_int() and other_scalar_ty.is_int():
|
|
303
|
-
input = cast(input, tl.float32, builder)
|
|
304
|
-
other = cast(other, tl.float32, builder)
|
|
305
|
-
# float / float (cast to the highest exponent type)
|
|
306
|
-
elif input_scalar_ty.is_floating() and other_scalar_ty.is_floating():
|
|
307
|
-
if input_scalar_ty.fp_mantissa_width > other_scalar_ty.fp_mantissa_width:
|
|
308
|
-
other = cast(other, input_scalar_ty, builder)
|
|
305
|
+
input = self.cast(input, other_scalar_ty)
|
|
306
|
+
# unreachable
|
|
309
307
|
else:
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
308
|
+
raise TypeError(f"unexpected type {input_scalar_ty}")
|
|
309
|
+
return self.tensor(self.builder.create_fdiv(input.handle, other.handle), input.type)
|
|
310
|
+
|
|
311
|
+
def floordiv(self, input: TensorTy | numbers.Number, other: TensorTy | numbers.Number) -> TensorTy:
|
|
312
|
+
input, other = self.binary_op_type_checking_impl(input, other, False, False, True, True)
|
|
313
|
+
input_scalar_ty = input.type.scalar
|
|
314
|
+
other_scalar_ty = other.type.scalar
|
|
315
|
+
if input_scalar_ty.is_int() and other_scalar_ty.is_int():
|
|
316
|
+
ret_ty = self.integer_promote_impl(input_scalar_ty, other_scalar_ty)
|
|
317
|
+
input = self.cast(input, ret_ty)
|
|
318
|
+
other = self.cast(other, ret_ty)
|
|
319
|
+
if ret_ty.is_int_signed():
|
|
320
|
+
return self.tensor(self.builder.create_sdiv(input.handle, other.handle), input.type)
|
|
321
|
+
else:
|
|
322
|
+
return self.tensor(self.builder.create_udiv(input.handle, other.handle), input.type)
|
|
313
323
|
raise TypeError(f"unexpected type {input_scalar_ty}")
|
|
314
|
-
return tl.tensor(builder.create_fdiv(input.handle, other.handle), input.type)
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
def floordiv(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, builder: ir.builder) -> tl.tensor:
|
|
318
|
-
input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True)
|
|
319
|
-
input_scalar_ty = input.type.scalar
|
|
320
|
-
other_scalar_ty = other.type.scalar
|
|
321
|
-
if input_scalar_ty.is_int() and other_scalar_ty.is_int():
|
|
322
|
-
ret_ty = integer_promote_impl(input_scalar_ty, other_scalar_ty)
|
|
323
|
-
input = cast(input, ret_ty, builder)
|
|
324
|
-
other = cast(other, ret_ty, builder)
|
|
325
|
-
if ret_ty.is_int_signed():
|
|
326
|
-
return tl.tensor(builder.create_sdiv(input.handle, other.handle), input.type)
|
|
327
|
-
else:
|
|
328
|
-
return tl.tensor(builder.create_udiv(input.handle, other.handle), input.type)
|
|
329
|
-
raise TypeError(f"unexpected type {input_scalar_ty}")
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
def fdiv(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, ieee_rounding: bool,
|
|
333
|
-
builder: ir.builder) -> tl.tensor:
|
|
334
|
-
input_scalar_ty = input.type.scalar
|
|
335
|
-
other_scalar_ty = other.type.scalar
|
|
336
|
-
if not input_scalar_ty.is_floating() or not other_scalar_ty.is_floating():
|
|
337
|
-
raise TypeError("both operands of fdiv must have floating scalar type")
|
|
338
|
-
input, other = binary_op_type_checking_impl(input, other, builder, False, False, False, True)
|
|
339
|
-
ret = builder.create_fdiv(input.handle, other.handle)
|
|
340
|
-
return tl.tensor(ret, input.type)
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
def mod(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, builder: ir.builder) -> tl.tensor:
|
|
344
|
-
input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True)
|
|
345
|
-
scalar_ty = input.type.scalar
|
|
346
|
-
other_scalar_ty = other.type.scalar
|
|
347
|
-
# float % float
|
|
348
|
-
if scalar_ty.is_floating():
|
|
349
|
-
return tl.tensor(builder.create_frem(input.handle, other.handle), input.type)
|
|
350
|
-
# % int
|
|
351
|
-
elif scalar_ty.is_int():
|
|
352
|
-
if scalar_ty.int_signedness != other_scalar_ty.int_signedness:
|
|
353
|
-
raise TypeError("Cannot mod " + scalar_ty.__repr__() + " by " + other_scalar_ty.__repr__() + " "
|
|
354
|
-
"because they have different signedness;"
|
|
355
|
-
"this is unlikely to result in a useful answer. Cast them to the same signedness.")
|
|
356
|
-
if scalar_ty.is_int_signed():
|
|
357
|
-
return tl.tensor(builder.create_srem(input.handle, other.handle), input.type)
|
|
358
|
-
else:
|
|
359
|
-
return tl.tensor(builder.create_urem(input.handle, other.handle), input.type)
|
|
360
|
-
raise TypeError(f"unexpected type {scalar_ty}")
|
|
361
324
|
|
|
325
|
+
def fdiv(self, input: TensorTy | numbers.Number, other: TensorTy | numbers.Number, ieee_rounding: bool) -> TensorTy:
|
|
326
|
+
input_scalar_ty = input.type.scalar
|
|
327
|
+
other_scalar_ty = other.type.scalar
|
|
328
|
+
if not input_scalar_ty.is_floating() or not other_scalar_ty.is_floating():
|
|
329
|
+
raise TypeError("both operands of fdiv must have floating scalar type")
|
|
330
|
+
input, other = self.binary_op_type_checking_impl(input, other, False, False, False, True)
|
|
331
|
+
ret = self.builder.create_fdiv(input.handle, other.handle)
|
|
332
|
+
return self.tensor(ret, input.type)
|
|
333
|
+
|
|
334
|
+
def mod(self, input: TensorTy | numbers.Number, other: TensorTy | numbers.Number) -> TensorTy:
|
|
335
|
+
input, other = self.binary_op_type_checking_impl(input, other, False, False, True, True)
|
|
336
|
+
scalar_ty = input.type.scalar
|
|
337
|
+
other_scalar_ty = other.type.scalar
|
|
338
|
+
# float % float
|
|
339
|
+
if scalar_ty.is_floating():
|
|
340
|
+
return self.tensor(self.builder.create_frem(input.handle, other.handle), input.type)
|
|
341
|
+
# % int
|
|
342
|
+
elif scalar_ty.is_int():
|
|
343
|
+
if scalar_ty.int_signedness != other_scalar_ty.int_signedness:
|
|
344
|
+
raise TypeError("Cannot mod " + scalar_ty.__repr__() + " by " + other_scalar_ty.__repr__() + " "
|
|
345
|
+
"because they have different signedness;"
|
|
346
|
+
"this is unlikely to result in a useful answer. Cast them to the same signedness.")
|
|
347
|
+
if scalar_ty.is_int_signed():
|
|
348
|
+
return self.tensor(self.builder.create_srem(input.handle, other.handle), input.type)
|
|
349
|
+
else:
|
|
350
|
+
return self.tensor(self.builder.create_urem(input.handle, other.handle), input.type)
|
|
351
|
+
raise TypeError(f"unexpected type {scalar_ty}")
|
|
362
352
|
|
|
363
353
|
##############
|
|
364
354
|
# other arithmetic ops
|
|
365
355
|
##############
|
|
366
356
|
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
357
|
+
def minimum(self, x: TensorTy, y: TensorTy, propagate_nan: tl.PropagateNan):
|
|
358
|
+
x, y = self.binary_op_type_checking_impl(x, y)
|
|
359
|
+
dtype = x.dtype
|
|
360
|
+
if dtype.is_floating():
|
|
361
|
+
if propagate_nan == tl.PropagateNan.ALL:
|
|
362
|
+
return self.tensor(self.builder.create_minimumf(x.handle, y.handle), x.type)
|
|
363
|
+
elif propagate_nan == tl.PropagateNan.NONE:
|
|
364
|
+
return self.tensor(self.builder.create_minnumf(x.handle, y.handle), x.type)
|
|
365
|
+
else:
|
|
366
|
+
raise ValueError(f"Unexpected propagate_nan {propagate_nan}")
|
|
367
|
+
elif dtype.is_int_signed():
|
|
368
|
+
return self.tensor(self.builder.create_minsi(x.handle, y.handle), x.type)
|
|
369
|
+
elif dtype.is_int_unsigned():
|
|
370
|
+
return self.tensor(self.builder.create_minui(x.handle, y.handle), x.type)
|
|
376
371
|
else:
|
|
377
|
-
raise
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
return tl.tensor(builder.create_maxnumf(x.handle, y.handle), x.type)
|
|
372
|
+
raise TypeError(f"Unexpected dtype {dtype}")
|
|
373
|
+
|
|
374
|
+
def maximum(self, x: TensorTy, y: TensorTy, propagate_nan: tl.PropagateNan):
|
|
375
|
+
x, y = self.binary_op_type_checking_impl(x, y)
|
|
376
|
+
dtype = x.dtype
|
|
377
|
+
if dtype.is_floating():
|
|
378
|
+
if propagate_nan == tl.PropagateNan.ALL:
|
|
379
|
+
return self.tensor(self.builder.create_maximumf(x.handle, y.handle), x.type)
|
|
380
|
+
elif propagate_nan == tl.PropagateNan.NONE:
|
|
381
|
+
return self.tensor(self.builder.create_maxnumf(x.handle, y.handle), x.type)
|
|
382
|
+
else:
|
|
383
|
+
raise ValueError(f"Unexpected propagate_nan {propagate_nan}")
|
|
384
|
+
elif dtype.is_int_signed():
|
|
385
|
+
return self.tensor(self.builder.create_maxsi(x.handle, y.handle), x.type)
|
|
386
|
+
elif dtype.is_int_unsigned():
|
|
387
|
+
return self.tensor(self.builder.create_maxui(x.handle, y.handle), x.type)
|
|
394
388
|
else:
|
|
395
|
-
raise
|
|
396
|
-
elif dtype.is_int_signed():
|
|
397
|
-
return tl.tensor(builder.create_maxsi(x.handle, y.handle), x.type)
|
|
398
|
-
elif dtype.is_int_unsigned():
|
|
399
|
-
return tl.tensor(builder.create_maxui(x.handle, y.handle), x.type)
|
|
400
|
-
else:
|
|
401
|
-
raise TypeError(f"Unexpected dtype {dtype}")
|
|
389
|
+
raise TypeError(f"Unexpected dtype {dtype}")
|
|
402
390
|
|
|
391
|
+
def clamp(self, x: TensorTy, min: TensorTy, max: TensorTy, propagate_nan: tl.PropagateNan):
|
|
392
|
+
min, max = self.binary_op_type_checking_impl(min, max)
|
|
393
|
+
x, min = self.binary_op_type_checking_impl(x, min)
|
|
394
|
+
x, max = self.binary_op_type_checking_impl(x, max)
|
|
403
395
|
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
dtype = x.dtype
|
|
410
|
-
if dtype.is_floating():
|
|
411
|
-
return tl.tensor(builder.create_clampf(x.handle, min.handle, max.handle, propagate_nan), x.type)
|
|
412
|
-
else:
|
|
413
|
-
raise TypeError(f"Unexpected dtype {dtype}. Only floating point clamp is supported")
|
|
414
|
-
|
|
396
|
+
dtype = x.dtype
|
|
397
|
+
if dtype.is_floating():
|
|
398
|
+
return self.tensor(self.builder.create_clampf(x.handle, min.handle, max.handle, propagate_nan), x.type)
|
|
399
|
+
else:
|
|
400
|
+
raise TypeError(f"Unexpected dtype {dtype}. Only floating point clamp is supported")
|
|
415
401
|
|
|
416
402
|
##############
|
|
417
403
|
# bitwise ops
|
|
418
404
|
##############
|
|
419
405
|
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
def
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
def
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
input =
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
return tl.tensor(builder.create_lshr(input.handle, other.handle), input.type)
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
def ashr(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
|
|
479
|
-
input, other = bitwise_op_type_checking_impl(input, other, builder)
|
|
480
|
-
return tl.tensor(builder.create_ashr(input.handle, other.handle), input.type)
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
def shl(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
|
|
484
|
-
input, other = bitwise_op_type_checking_impl(input, other, builder)
|
|
485
|
-
return tl.tensor(builder.create_shl(input.handle, other.handle), input.type)
|
|
486
|
-
|
|
406
|
+
def bitwise_op_type_checking_impl(self, input: TensorTy, other: TensorTy) -> Tuple[TensorTy, TensorTy]:
|
|
407
|
+
input, other = self.binary_op_type_checking_impl(input, other)
|
|
408
|
+
input_sca_ty = input.type.scalar
|
|
409
|
+
other_sca_ty = other.type.scalar
|
|
410
|
+
if not input_sca_ty.is_int() or not other_sca_ty.is_int():
|
|
411
|
+
raise IncompatibleTypeErrorImpl(input_sca_ty, other_sca_ty)
|
|
412
|
+
ret_sca_ty = self.integer_promote_impl(input_sca_ty, other_sca_ty)
|
|
413
|
+
if ret_sca_ty != input_sca_ty:
|
|
414
|
+
input = self.cast(input, ret_sca_ty)
|
|
415
|
+
if ret_sca_ty != other_sca_ty:
|
|
416
|
+
other = self.cast(other, ret_sca_ty)
|
|
417
|
+
return input, other
|
|
418
|
+
|
|
419
|
+
def and_(self, input: TensorTy, other: TensorTy) -> TensorTy:
|
|
420
|
+
input, other = self.bitwise_op_type_checking_impl(input, other)
|
|
421
|
+
return self.tensor(self.builder.create_and(input.handle, other.handle), input.type)
|
|
422
|
+
|
|
423
|
+
def or_(self, input: TensorTy, other: TensorTy) -> TensorTy:
|
|
424
|
+
input, other = self.bitwise_op_type_checking_impl(input, other)
|
|
425
|
+
return self.tensor(self.builder.create_or(input.handle, other.handle), input.type)
|
|
426
|
+
|
|
427
|
+
def xor_(self, input: TensorTy, other: TensorTy) -> TensorTy:
|
|
428
|
+
input, other = self.bitwise_op_type_checking_impl(input, other)
|
|
429
|
+
return self.tensor(self.builder.create_xor(input.handle, other.handle), input.type)
|
|
430
|
+
|
|
431
|
+
def logical_and(self, input: TensorTy, other: TensorTy) -> TensorTy:
|
|
432
|
+
if not input.type.is_int1():
|
|
433
|
+
input = self.bitcast(input, tl.int1)
|
|
434
|
+
if not other.type.is_int1():
|
|
435
|
+
other = self.bitcast(other, tl.int1)
|
|
436
|
+
return self.and_(input, other)
|
|
437
|
+
|
|
438
|
+
def logical_or(self, input: TensorTy, other: TensorTy) -> TensorTy:
|
|
439
|
+
if not input.type.is_int1():
|
|
440
|
+
input = self.bitcast(input, tl.int1)
|
|
441
|
+
if not other.type.is_int1():
|
|
442
|
+
other = self.bitcast(other, tl.int1)
|
|
443
|
+
return self.or_(input, other)
|
|
444
|
+
|
|
445
|
+
def not_(self, input: TensorTy):
|
|
446
|
+
if not input.type.is_int1():
|
|
447
|
+
input = self.bitcast(input, tl.int1)
|
|
448
|
+
return self.invert(input)
|
|
449
|
+
|
|
450
|
+
def lshr(self, input: TensorTy, other: TensorTy) -> TensorTy:
|
|
451
|
+
input, other = self.bitwise_op_type_checking_impl(input, other)
|
|
452
|
+
return self.tensor(self.builder.create_lshr(input.handle, other.handle), input.type)
|
|
453
|
+
|
|
454
|
+
def ashr(self, input: TensorTy, other: TensorTy) -> TensorTy:
|
|
455
|
+
input, other = self.bitwise_op_type_checking_impl(input, other)
|
|
456
|
+
return self.tensor(self.builder.create_ashr(input.handle, other.handle), input.type)
|
|
457
|
+
|
|
458
|
+
def shl(self, input: TensorTy, other: TensorTy) -> TensorTy:
|
|
459
|
+
input, other = self.bitwise_op_type_checking_impl(input, other)
|
|
460
|
+
return self.tensor(self.builder.create_shl(input.handle, other.handle), input.type)
|
|
487
461
|
|
|
488
462
|
# ===----------------------------------------------------------------------===//
|
|
489
463
|
# Unary Operators
|
|
490
464
|
# ===----------------------------------------------------------------------===//
|
|
491
465
|
|
|
466
|
+
def plus(self, input: TensorTy) -> TensorTy:
|
|
467
|
+
return input
|
|
492
468
|
|
|
493
|
-
def
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
if input_sca_ty.is_ptr():
|
|
500
|
-
raise ValueError("wrong type argument to unary minus (" + input_sca_ty.__repr__() + ")")
|
|
501
|
-
_0 = tl.tensor(builder.get_null_value(input_sca_ty.to_ir(builder)), input_sca_ty)
|
|
502
|
-
return sub(_0, input, True, builder)
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
def invert(input: tl.tensor, builder: tl.tensor) -> tl.tensor:
|
|
506
|
-
input_sca_ty = input.type.scalar
|
|
507
|
-
if input_sca_ty.is_ptr() or input_sca_ty.is_floating():
|
|
508
|
-
raise ValueError("wrong type argument to unary invert (" + input_sca_ty.__repr__() + ")")
|
|
509
|
-
_1 = tl.tensor(builder.get_all_ones_value(input_sca_ty.to_ir(builder)), input_sca_ty)
|
|
510
|
-
return xor_(input, _1, builder)
|
|
469
|
+
def minus(self, input: TensorTy) -> TensorTy:
|
|
470
|
+
input_sca_ty = input.type.scalar
|
|
471
|
+
if input_sca_ty.is_ptr():
|
|
472
|
+
raise ValueError("wrong type argument to unary minus (" + input_sca_ty.__repr__() + ")")
|
|
473
|
+
_0 = self.tensor(self.builder.get_null_value(input_sca_ty.to_ir(self.builder)), input_sca_ty)
|
|
474
|
+
return self.sub(_0, input, True)
|
|
511
475
|
|
|
476
|
+
def invert(self, input: TensorTy) -> TensorTy:
|
|
477
|
+
input_sca_ty = input.type.scalar
|
|
478
|
+
if input_sca_ty.is_ptr() or input_sca_ty.is_floating():
|
|
479
|
+
raise ValueError("wrong type argument to unary invert (" + input_sca_ty.__repr__() + ")")
|
|
480
|
+
_1 = self.tensor(self.builder.get_all_ones_value(input_sca_ty.to_ir(self.builder)), input_sca_ty)
|
|
481
|
+
return self.xor_(input, _1)
|
|
512
482
|
|
|
513
483
|
# ===----------------------------------------------------------------------===//
|
|
514
484
|
# Comparison Operators
|
|
515
485
|
# ===----------------------------------------------------------------------===//
|
|
516
|
-
def _bool_like(v: tl.tensor) -> tl.block_type:
|
|
517
|
-
if not v.type.is_block():
|
|
518
|
-
return tl.int1
|
|
519
|
-
shape = v.type.shape
|
|
520
|
-
return tl.block_type(tl.int1, shape)
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
def greater_than(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
|
|
524
|
-
input, other = binary_op_type_checking_impl(input, other, builder)
|
|
525
|
-
scalar_ty = input.type.scalar
|
|
526
|
-
# float > float
|
|
527
|
-
if scalar_ty.is_floating():
|
|
528
|
-
return tl.tensor(builder.create_fcmpOGT(input.handle, other.handle), _bool_like(input))
|
|
529
|
-
# > int
|
|
530
|
-
elif scalar_ty.is_int():
|
|
531
|
-
if scalar_ty.is_int_signed():
|
|
532
|
-
return tl.tensor(builder.create_icmpSGT(input.handle, other.handle), _bool_like(input))
|
|
533
|
-
else:
|
|
534
|
-
return tl.tensor(builder.create_icmpUGT(input.handle, other.handle), _bool_like(input))
|
|
535
|
-
raise TypeError(f"unexpected type {scalar_ty}")
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
def greater_equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
|
|
539
|
-
input, other = binary_op_type_checking_impl(input, other, builder)
|
|
540
|
-
scalar_ty = input.type.scalar
|
|
541
|
-
# float >= float
|
|
542
|
-
if scalar_ty.is_floating():
|
|
543
|
-
return tl.tensor(builder.create_fcmpOGE(input.handle, other.handle), _bool_like(input))
|
|
544
|
-
# >= int
|
|
545
|
-
elif scalar_ty.is_int():
|
|
546
|
-
if scalar_ty.is_int_signed():
|
|
547
|
-
return tl.tensor(builder.create_icmpSGE(input.handle, other.handle), _bool_like(input))
|
|
548
|
-
else:
|
|
549
|
-
return tl.tensor(builder.create_icmpUGE(input.handle, other.handle), _bool_like(input))
|
|
550
|
-
raise TypeError(f"unexpected type {scalar_ty}")
|
|
551
|
-
|
|
552
|
-
|
|
553
|
-
def less_than(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
|
|
554
|
-
input, other = binary_op_type_checking_impl(input, other, builder)
|
|
555
|
-
scalar_ty = input.type.scalar
|
|
556
|
-
# float < float
|
|
557
|
-
if scalar_ty.is_floating():
|
|
558
|
-
return tl.tensor(builder.create_fcmpOLT(input.handle, other.handle), _bool_like(input))
|
|
559
|
-
# < int
|
|
560
|
-
elif scalar_ty.is_int():
|
|
561
|
-
if scalar_ty.is_int_signed():
|
|
562
|
-
return tl.tensor(builder.create_icmpSLT(input.handle, other.handle), _bool_like(input))
|
|
563
|
-
else:
|
|
564
|
-
return tl.tensor(builder.create_icmpULT(input.handle, other.handle), _bool_like(input))
|
|
565
|
-
raise TypeError(f"unexpected type {scalar_ty}")
|
|
566
|
-
|
|
567
|
-
|
|
568
|
-
def less_equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
|
|
569
|
-
input, other = binary_op_type_checking_impl(input, other, builder)
|
|
570
|
-
scalar_ty = input.type.scalar
|
|
571
|
-
# float < float
|
|
572
|
-
if scalar_ty.is_floating():
|
|
573
|
-
return tl.tensor(builder.create_fcmpOLE(input.handle, other.handle), _bool_like(input))
|
|
574
|
-
# < int
|
|
575
|
-
elif scalar_ty.is_int():
|
|
576
|
-
if scalar_ty.is_int_signed():
|
|
577
|
-
return tl.tensor(builder.create_icmpSLE(input.handle, other.handle), _bool_like(input))
|
|
578
|
-
else:
|
|
579
|
-
return tl.tensor(builder.create_icmpULE(input.handle, other.handle), _bool_like(input))
|
|
580
|
-
raise TypeError(f"unexpected type {scalar_ty}")
|
|
581
|
-
|
|
582
|
-
|
|
583
|
-
def equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
|
|
584
|
-
input, other = binary_op_type_checking_impl(input, other, builder)
|
|
585
|
-
scalar_ty = input.type.scalar
|
|
586
|
-
# float == float
|
|
587
|
-
if scalar_ty.is_floating():
|
|
588
|
-
return tl.tensor(builder.create_fcmpOEQ(input.handle, other.handle), _bool_like(input))
|
|
589
|
-
# == int
|
|
590
|
-
elif scalar_ty.is_int():
|
|
591
|
-
return tl.tensor(builder.create_icmpEQ(input.handle, other.handle), _bool_like(input))
|
|
592
|
-
raise TypeError(f"unexpected type {scalar_ty}")
|
|
593
|
-
|
|
594
|
-
|
|
595
|
-
def not_equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
|
|
596
|
-
input, other = binary_op_type_checking_impl(input, other, builder)
|
|
597
|
-
scalar_ty = input.type.scalar
|
|
598
|
-
# float == float
|
|
599
|
-
if scalar_ty.is_floating():
|
|
600
|
-
return tl.tensor(builder.create_fcmpUNE(input.handle, other.handle), _bool_like(input))
|
|
601
|
-
# == int
|
|
602
|
-
elif scalar_ty.is_int():
|
|
603
|
-
return tl.tensor(builder.create_icmpNE(input.handle, other.handle), _bool_like(input))
|
|
604
|
-
raise TypeError(f"unexpected type {scalar_ty}")
|
|
605
486
|
|
|
487
|
+
def _bool_like(self, v: TensorTy) -> tl.block_type:
|
|
488
|
+
return v.type.with_element_ty(tl.int1)
|
|
489
|
+
|
|
490
|
+
def greater_than(self, input: TensorTy, other: TensorTy) -> TensorTy:
|
|
491
|
+
input, other = self.binary_op_type_checking_impl(input, other)
|
|
492
|
+
scalar_ty = input.type.scalar
|
|
493
|
+
# float > float
|
|
494
|
+
if scalar_ty.is_floating():
|
|
495
|
+
return self.tensor(self.builder.create_fcmpOGT(input.handle, other.handle), self._bool_like(input))
|
|
496
|
+
# > int
|
|
497
|
+
elif scalar_ty.is_int():
|
|
498
|
+
if scalar_ty.is_int_signed():
|
|
499
|
+
return self.tensor(self.builder.create_icmpSGT(input.handle, other.handle), self._bool_like(input))
|
|
500
|
+
else:
|
|
501
|
+
return self.tensor(self.builder.create_icmpUGT(input.handle, other.handle), self._bool_like(input))
|
|
502
|
+
raise TypeError(f"unexpected type {scalar_ty}")
|
|
503
|
+
|
|
504
|
+
def greater_equal(self, input: TensorTy, other: TensorTy) -> TensorTy:
|
|
505
|
+
input, other = self.binary_op_type_checking_impl(input, other)
|
|
506
|
+
scalar_ty = input.type.scalar
|
|
507
|
+
# float >= float
|
|
508
|
+
if scalar_ty.is_floating():
|
|
509
|
+
return self.tensor(self.builder.create_fcmpOGE(input.handle, other.handle), self._bool_like(input))
|
|
510
|
+
# >= int
|
|
511
|
+
elif scalar_ty.is_int():
|
|
512
|
+
if scalar_ty.is_int_signed():
|
|
513
|
+
return self.tensor(self.builder.create_icmpSGE(input.handle, other.handle), self._bool_like(input))
|
|
514
|
+
else:
|
|
515
|
+
return self.tensor(self.builder.create_icmpUGE(input.handle, other.handle), self._bool_like(input))
|
|
516
|
+
raise TypeError(f"unexpected type {scalar_ty}")
|
|
517
|
+
|
|
518
|
+
def less_than(self, input: TensorTy, other: TensorTy) -> TensorTy:
|
|
519
|
+
input, other = self.binary_op_type_checking_impl(input, other)
|
|
520
|
+
scalar_ty = input.type.scalar
|
|
521
|
+
# float < float
|
|
522
|
+
if scalar_ty.is_floating():
|
|
523
|
+
return self.tensor(self.builder.create_fcmpOLT(input.handle, other.handle), self._bool_like(input))
|
|
524
|
+
# < int
|
|
525
|
+
elif scalar_ty.is_int():
|
|
526
|
+
if scalar_ty.is_int_signed():
|
|
527
|
+
return self.tensor(self.builder.create_icmpSLT(input.handle, other.handle), self._bool_like(input))
|
|
528
|
+
else:
|
|
529
|
+
return self.tensor(self.builder.create_icmpULT(input.handle, other.handle), self._bool_like(input))
|
|
530
|
+
raise TypeError(f"unexpected type {scalar_ty}")
|
|
531
|
+
|
|
532
|
+
def less_equal(self, input: TensorTy, other: TensorTy) -> TensorTy:
|
|
533
|
+
input, other = self.binary_op_type_checking_impl(input, other)
|
|
534
|
+
scalar_ty = input.type.scalar
|
|
535
|
+
# float < float
|
|
536
|
+
if scalar_ty.is_floating():
|
|
537
|
+
return self.tensor(self.builder.create_fcmpOLE(input.handle, other.handle), self._bool_like(input))
|
|
538
|
+
# < int
|
|
539
|
+
elif scalar_ty.is_int():
|
|
540
|
+
if scalar_ty.is_int_signed():
|
|
541
|
+
return self.tensor(self.builder.create_icmpSLE(input.handle, other.handle), self._bool_like(input))
|
|
542
|
+
else:
|
|
543
|
+
return self.tensor(self.builder.create_icmpULE(input.handle, other.handle), self._bool_like(input))
|
|
544
|
+
raise TypeError(f"unexpected type {scalar_ty}")
|
|
545
|
+
|
|
546
|
+
def equal(self, input: TensorTy, other: TensorTy) -> TensorTy:
|
|
547
|
+
input, other = self.binary_op_type_checking_impl(input, other)
|
|
548
|
+
scalar_ty = input.type.scalar
|
|
549
|
+
# float == float
|
|
550
|
+
if scalar_ty.is_floating():
|
|
551
|
+
return self.tensor(self.builder.create_fcmpOEQ(input.handle, other.handle), self._bool_like(input))
|
|
552
|
+
# == int
|
|
553
|
+
elif scalar_ty.is_int():
|
|
554
|
+
return self.tensor(self.builder.create_icmpEQ(input.handle, other.handle), self._bool_like(input))
|
|
555
|
+
raise TypeError(f"unexpected type {scalar_ty}")
|
|
556
|
+
|
|
557
|
+
def not_equal(self, input: TensorTy, other: TensorTy) -> TensorTy:
|
|
558
|
+
input, other = self.binary_op_type_checking_impl(input, other)
|
|
559
|
+
scalar_ty = input.type.scalar
|
|
560
|
+
# float == float
|
|
561
|
+
if scalar_ty.is_floating():
|
|
562
|
+
return self.tensor(self.builder.create_fcmpUNE(input.handle, other.handle), self._bool_like(input))
|
|
563
|
+
# == int
|
|
564
|
+
elif scalar_ty.is_int():
|
|
565
|
+
return self.tensor(self.builder.create_icmpNE(input.handle, other.handle), self._bool_like(input))
|
|
566
|
+
raise TypeError(f"unexpected type {scalar_ty}")
|
|
606
567
|
|
|
607
568
|
# ===----------------------------------------------------------------------===//
|
|
608
569
|
# Block Creation
|
|
609
570
|
# ===----------------------------------------------------------------------===//
|
|
610
571
|
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
|
|
615
|
-
|
|
616
|
-
|
|
617
|
-
|
|
618
|
-
|
|
619
|
-
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
def
|
|
630
|
-
if isinstance(value, tl.tensor):
|
|
631
|
-
assert value.numel.value == 1, "only accepts size-1 tensor"
|
|
632
|
-
value = cast(value, dtype, builder)
|
|
633
|
-
else:
|
|
572
|
+
def arange(self, start: int, end: int, *, ret_ty: tl.block_type = None) -> TensorTy:
|
|
573
|
+
if not isinstance(start, int) or not isinstance(end, int):
|
|
574
|
+
raise ValueError("arange's arguments must be of type tl.constexpr")
|
|
575
|
+
is_start_int64 = bool(start >> 32)
|
|
576
|
+
is_end_int64 = bool(end >> 32)
|
|
577
|
+
if is_start_int64 or is_end_int64:
|
|
578
|
+
raise ValueError("arange must fit in int32")
|
|
579
|
+
if end <= start:
|
|
580
|
+
raise ValueError("arange's end argument must be greater than the start argument")
|
|
581
|
+
range = end - start
|
|
582
|
+
if (range & (range - 1)) != 0:
|
|
583
|
+
raise ValueError("arange's range must be a power of 2")
|
|
584
|
+
shape = [range]
|
|
585
|
+
if ret_ty is None:
|
|
586
|
+
ret_ty = tl.block_type(tl.int32, shape)
|
|
587
|
+
ret_ty_ir = ret_ty.to_ir(self.builder)
|
|
588
|
+
return self.tensor(self.builder.create_make_range(ret_ty_ir, start, end), ret_ty)
|
|
589
|
+
|
|
590
|
+
def scalar_constant(self, value, dtype: tl.dtype) -> TensorTy:
|
|
634
591
|
# scalar
|
|
635
592
|
if dtype is None:
|
|
636
593
|
raise ValueError("dtype must be specified when value is not a tensor")
|
|
637
594
|
if value == 0:
|
|
638
|
-
value = builder.get_null_value(dtype.to_ir(builder))
|
|
595
|
+
value = self.builder.get_null_value(dtype.to_ir(self.builder))
|
|
639
596
|
else:
|
|
640
|
-
get_value_fn = getattr(builder, f"get_{dtype.name}")
|
|
597
|
+
get_value_fn = getattr(self.builder, f"get_{dtype.name}")
|
|
641
598
|
value = get_value_fn(value)
|
|
642
|
-
|
|
599
|
+
return self.tensor(value, dtype)
|
|
643
600
|
|
|
644
|
-
|
|
601
|
+
def make_scalar(self, value, dtype: tl.dtype) -> TensorTy:
|
|
602
|
+
if isinstance(value, tl.tensor):
|
|
603
|
+
assert value.numel.value == 1, "only accepts size-1 tensor"
|
|
604
|
+
return self.cast(value, dtype)
|
|
605
|
+
# scalar
|
|
606
|
+
return self.scalar_constant(value, dtype)
|
|
645
607
|
|
|
608
|
+
def full(self, shape: List[int], value, dtype: tl.dtype) -> TensorTy:
|
|
609
|
+
return self.splat(self.make_scalar(value, dtype), shape)
|
|
646
610
|
|
|
647
611
|
# ===----------------------------------------------------------------------===//
|
|
648
612
|
# Shape Manipulation
|
|
649
613
|
# ===----------------------------------------------------------------------===//
|
|
650
614
|
|
|
615
|
+
def splat(self, value: TensorTy, shape: List[int]) -> TensorTy:
|
|
616
|
+
assert not value.type.is_block(), "Cannot splat a block tensor"
|
|
617
|
+
if len(shape) == 0:
|
|
618
|
+
return value
|
|
619
|
+
ret_ty = tl.block_type(value.dtype, shape)
|
|
620
|
+
return self.tensor(self.builder.create_splat(ret_ty.to_ir(self.builder), value.handle), ret_ty)
|
|
621
|
+
|
|
622
|
+
def reshape(self, input: TensorTy, dst_shape: List[int], can_reorder: bool) -> TensorTy:
|
|
623
|
+
numel = 1
|
|
624
|
+
for s in dst_shape:
|
|
625
|
+
numel *= s
|
|
626
|
+
if input.type.numel != numel:
|
|
627
|
+
raise ValueError("reshape() cannot change total number of elements in tensor")
|
|
628
|
+
ret_ty = tl.block_type(input.type.scalar, dst_shape)
|
|
629
|
+
return self.tensor(self.builder.create_reshape(input.handle, dst_shape, can_reorder), ret_ty)
|
|
630
|
+
|
|
631
|
+
def expand_dims(self, input: TensorTy, axis: int) -> TensorTy:
|
|
632
|
+
dst_shape = [tl._unwrap_if_constexpr(x) for x in input.shape]
|
|
633
|
+
dst_shape.insert(axis, 1)
|
|
634
|
+
|
|
635
|
+
if not input.type.is_block():
|
|
636
|
+
return self.splat(input, shape=dst_shape)
|
|
637
|
+
|
|
638
|
+
ret_ty = tl.block_type(input.type.scalar, dst_shape)
|
|
639
|
+
return self.tensor(self.builder.create_expand_dims(input.handle, axis), ret_ty)
|
|
640
|
+
|
|
641
|
+
def cat(self, lhs: TensorTy, rhs: TensorTy, can_reorder: bool) -> TensorTy:
|
|
642
|
+
assert can_reorder, "current implementation of `cat` always may reorder elements"
|
|
643
|
+
assert len(lhs.shape) == 1
|
|
644
|
+
ret_type = tl.block_type(lhs.type.scalar, [lhs.shape[0] + rhs.shape[0]])
|
|
645
|
+
return self.tensor(self.builder.create_cat(lhs.handle, rhs.handle), ret_type)
|
|
646
|
+
|
|
647
|
+
def join(self, a: TensorTy, b: TensorTy) -> TensorTy:
|
|
648
|
+
a, b = self.broadcast_impl_value(a, b)
|
|
649
|
+
|
|
650
|
+
# The IR can't handle joining two scalars, so upcast them to 1D tensors,
|
|
651
|
+
# then downcast the result.
|
|
652
|
+
was_rank_1 = a.shape == []
|
|
653
|
+
if was_rank_1:
|
|
654
|
+
a = self.expand_dims(a, 0)
|
|
655
|
+
b = self.expand_dims(b, 0)
|
|
656
|
+
|
|
657
|
+
if isinstance(a.shape[-1], tl.constexpr):
|
|
658
|
+
two = tl.constexpr(2)
|
|
659
|
+
else:
|
|
660
|
+
two = 2
|
|
661
|
+
new_shape = a.shape + [two]
|
|
651
662
|
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
if len(shape) == 0:
|
|
655
|
-
return value
|
|
656
|
-
ret_ty = tl.block_type(value.dtype, shape)
|
|
657
|
-
return tl.tensor(builder.create_splat(value.handle, shape), ret_ty)
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
def reshape(input: tl.tensor, dst_shape: List[int], can_reorder: bool, builder: ir.builder) -> tl.tensor:
|
|
661
|
-
numel = 1
|
|
662
|
-
for s in dst_shape:
|
|
663
|
-
numel *= s
|
|
664
|
-
if input.type.numel != numel:
|
|
665
|
-
raise ValueError("reshape() cannot change total number of elements in tensor")
|
|
666
|
-
ret_ty = tl.block_type(input.type.scalar, dst_shape)
|
|
667
|
-
return tl.tensor(builder.create_reshape(input.handle, dst_shape, can_reorder), ret_ty)
|
|
668
|
-
|
|
669
|
-
|
|
670
|
-
def expand_dims(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
|
|
671
|
-
dst_shape = [tl._constexpr_to_value(x) for x in input.shape]
|
|
672
|
-
dst_shape.insert(axis, 1)
|
|
673
|
-
|
|
674
|
-
if not input.type.is_block():
|
|
675
|
-
return splat(input, shape=dst_shape, builder=builder)
|
|
676
|
-
|
|
677
|
-
ret_ty = tl.block_type(input.type.scalar, dst_shape)
|
|
678
|
-
return tl.tensor(builder.create_expand_dims(input.handle, axis), ret_ty)
|
|
679
|
-
|
|
680
|
-
|
|
681
|
-
def cat(lhs: tl.tensor, rhs: tl.tensor, can_reorder: bool, builder: ir.builder) -> tl.tensor:
|
|
682
|
-
assert can_reorder, "current implementation of `cat` always may reorder elements"
|
|
683
|
-
assert len(lhs.shape) == 1
|
|
684
|
-
ret_type = tl.block_type(lhs.type.scalar, [lhs.shape[0] + rhs.shape[0]])
|
|
685
|
-
return tl.tensor(builder.create_cat(lhs.handle, rhs.handle), ret_type)
|
|
686
|
-
|
|
687
|
-
|
|
688
|
-
def join(a: tl.tensor, b: tl.tensor, builder: ir.builder) -> tl.tensor:
|
|
689
|
-
a, b = broadcast_impl_value(a, b, builder)
|
|
690
|
-
|
|
691
|
-
# The IR can't handle joining two scalars, so upcast them to 1D tensors,
|
|
692
|
-
# then downcast the result.
|
|
693
|
-
was_rank_1 = a.shape == []
|
|
694
|
-
if was_rank_1:
|
|
695
|
-
a = expand_dims(a, 0, builder)
|
|
696
|
-
b = expand_dims(b, 0, builder)
|
|
697
|
-
|
|
698
|
-
if isinstance(a.shape[-1], tl.constexpr):
|
|
699
|
-
two = tl.constexpr(2)
|
|
700
|
-
else:
|
|
701
|
-
two = 2
|
|
702
|
-
new_shape = a.shape + [two]
|
|
703
|
-
|
|
704
|
-
ret_type = tl.block_type(a.type.scalar, new_shape)
|
|
705
|
-
ret = tl.tensor(builder.create_join(a.handle, b.handle), ret_type)
|
|
706
|
-
|
|
707
|
-
if was_rank_1:
|
|
708
|
-
ret = reshape(ret, [2], can_reorder=False, builder=builder)
|
|
709
|
-
|
|
710
|
-
return ret
|
|
711
|
-
|
|
712
|
-
|
|
713
|
-
def split(a: tl.tensor, builder: ir.builder) -> Tuple[tl.tensor, tl.tensor]:
|
|
714
|
-
assert (len(a.shape) > 0)
|
|
715
|
-
assert (tl._constexpr_to_value(a.shape[-1]) == 2)
|
|
716
|
-
|
|
717
|
-
new_shape = a.shape[:-1]
|
|
718
|
-
ret_type = tl.block_type(a.type.scalar, new_shape)
|
|
719
|
-
outLHS, outRHS = builder.create_split(a.handle)
|
|
720
|
-
return (
|
|
721
|
-
tl.tensor(outLHS, ret_type),
|
|
722
|
-
tl.tensor(outRHS, ret_type),
|
|
723
|
-
)
|
|
724
|
-
|
|
663
|
+
ret_type = tl.block_type(a.type.scalar, new_shape)
|
|
664
|
+
ret = self.tensor(self.builder.create_join(a.handle, b.handle), ret_type)
|
|
725
665
|
|
|
726
|
-
|
|
727
|
-
|
|
728
|
-
raise ValueError("permute dims must have the same length as input shape")
|
|
729
|
-
if sorted(tl._constexpr_to_value(d) for d in dims) != list(range(len(dims))):
|
|
730
|
-
raise ValueError(f"permute dims must be a permutation of 0, 1, ..., n-1, but were {dims}")
|
|
666
|
+
if was_rank_1:
|
|
667
|
+
ret = self.reshape(ret, [2], can_reorder=False)
|
|
731
668
|
|
|
732
|
-
|
|
733
|
-
return tl.tensor(builder.create_trans(input.handle, dims), ret_type)
|
|
669
|
+
return ret
|
|
734
670
|
|
|
671
|
+
def split(self, a: TensorTy) -> Tuple[TensorTy, TensorTy]:
|
|
672
|
+
assert (len(a.shape) > 0)
|
|
673
|
+
assert (tl._unwrap_if_constexpr(a.shape[-1]) == 2)
|
|
735
674
|
|
|
736
|
-
|
|
737
|
-
|
|
738
|
-
|
|
739
|
-
return
|
|
740
|
-
|
|
741
|
-
|
|
742
|
-
|
|
743
|
-
if shape == src_shape:
|
|
744
|
-
return input
|
|
745
|
-
for i, item in enumerate(src_shape):
|
|
746
|
-
if shape[i] != item and item != 1:
|
|
747
|
-
raise ValueError(f"Cannot broadcast, the expanded size of the tensor ({shape[i]})"
|
|
748
|
-
f" must match the existing size ({item}) at non-singleton dimension"
|
|
749
|
-
f" {i}: {src_shape}, {shape}")
|
|
750
|
-
ret_ty = tl.block_type(input.type.scalar, shape)
|
|
751
|
-
return tl.tensor(builder.create_broadcast(input.handle, shape), ret_ty)
|
|
752
|
-
|
|
753
|
-
|
|
754
|
-
def broadcast_impl_value(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder) -> tl.tensor:
|
|
755
|
-
lhs_ty = lhs.type
|
|
756
|
-
rhs_ty = rhs.type
|
|
757
|
-
|
|
758
|
-
# make_shape_compatible(block, scalar)
|
|
759
|
-
if lhs_ty.is_block() and not rhs_ty.is_block():
|
|
760
|
-
rhs_ty = tl.block_type(rhs_ty.scalar, lhs_ty.shape)
|
|
761
|
-
rhs = tl.tensor(builder.create_splat(rhs.handle, lhs_ty.get_block_shapes()), rhs_ty)
|
|
762
|
-
# make_shape_compatible(scalar, block)
|
|
763
|
-
elif not lhs_ty.is_block() and rhs_ty.is_block():
|
|
764
|
-
lhs_ty = tl.block_type(lhs_ty.scalar, rhs_ty.shape)
|
|
765
|
-
lhs = tl.tensor(builder.create_splat(lhs.handle, rhs_ty.get_block_shapes()), lhs_ty)
|
|
766
|
-
# make_shape_compatible(block, block)
|
|
767
|
-
elif lhs_ty.is_block() and rhs_ty.is_block():
|
|
768
|
-
lhs_shape = lhs_ty.get_block_shapes()
|
|
769
|
-
rhs_shape = rhs_ty.get_block_shapes()
|
|
770
|
-
|
|
771
|
-
if len(lhs_shape) < len(rhs_shape):
|
|
772
|
-
# Add new axes to lhs
|
|
773
|
-
for _ in range(len(lhs_shape), len(rhs_shape)):
|
|
774
|
-
lhs = tl.tensor(builder.create_expand_dims(lhs.handle, 0),
|
|
775
|
-
tl.block_type(lhs_ty.scalar, [1] + lhs_shape.values))
|
|
776
|
-
lhs_ty = lhs.type
|
|
777
|
-
lhs_shape = lhs_ty.get_block_shapes()
|
|
778
|
-
elif len(rhs_shape) < len(lhs_shape):
|
|
779
|
-
# Add new axes to rhs
|
|
780
|
-
for _ in range(len(rhs_shape), len(lhs_shape)):
|
|
781
|
-
rhs = tl.tensor(builder.create_expand_dims(rhs.handle, 0),
|
|
782
|
-
tl.block_type(rhs_ty.scalar, [1] + rhs_shape.values))
|
|
783
|
-
rhs_ty = rhs.type
|
|
784
|
-
rhs_shape = rhs_ty.get_block_shapes()
|
|
785
|
-
assert len(rhs_shape) == len(lhs_shape)
|
|
786
|
-
|
|
787
|
-
ret_shape = []
|
|
788
|
-
for i, left in enumerate(lhs_shape):
|
|
789
|
-
right = rhs_shape[i]
|
|
790
|
-
if left == 1:
|
|
791
|
-
ret_shape.append(right)
|
|
792
|
-
elif (right == 1) or (right == left):
|
|
793
|
-
ret_shape.append(left)
|
|
794
|
-
else:
|
|
795
|
-
raise ValueError("Cannot make_shape_compatible: incompatible dimensions "
|
|
796
|
-
"at index " + str(i) + ": " + str(left) + " and " + str(right))
|
|
797
|
-
if lhs_shape != ret_shape:
|
|
798
|
-
ret_ty = tl.block_type(lhs_ty.scalar, ret_shape)
|
|
799
|
-
lhs = tl.tensor(builder.create_broadcast(lhs.handle, ret_shape), ret_ty)
|
|
800
|
-
if rhs_shape != ret_shape:
|
|
801
|
-
ret_ty = tl.block_type(rhs_ty.scalar, ret_shape)
|
|
802
|
-
rhs = tl.tensor(builder.create_broadcast(rhs.handle, ret_shape), ret_ty)
|
|
803
|
-
# (scalar, scalar) => returns original blocks
|
|
804
|
-
return lhs, rhs
|
|
675
|
+
new_shape = a.shape[:-1]
|
|
676
|
+
ret_type = tl.block_type(a.type.scalar, new_shape)
|
|
677
|
+
outLHS, outRHS = self.builder.create_split(a.handle)
|
|
678
|
+
return (
|
|
679
|
+
self.tensor(outLHS, ret_type),
|
|
680
|
+
self.tensor(outRHS, ret_type),
|
|
681
|
+
)
|
|
805
682
|
|
|
683
|
+
def permute(self, input: TensorTy, dims: Tuple[int]) -> TensorTy:
|
|
684
|
+
if len(input.shape) != len(dims):
|
|
685
|
+
raise ValueError("permute dims must have the same length as input shape")
|
|
686
|
+
if sorted(tl._unwrap_if_constexpr(d) for d in dims) != list(range(len(dims))):
|
|
687
|
+
raise ValueError(f"permute dims must be a permutation of 0, 1, ..., n-1, but were {dims}")
|
|
688
|
+
|
|
689
|
+
ret_type = tl.block_type(input.type.scalar, [input.shape[d] for d in dims])
|
|
690
|
+
return self.tensor(self.builder.create_trans(input.handle, dims), ret_type)
|
|
691
|
+
|
|
692
|
+
def broadcast_impl_shape(self, input: TensorTy, shape: Tuple[int]) -> TensorTy:
|
|
693
|
+
if not input.type.is_block():
|
|
694
|
+
return self.splat(input, shape)
|
|
695
|
+
src_shape = input.type.get_block_shapes()
|
|
696
|
+
if len(src_shape) != len(shape):
|
|
697
|
+
raise ValueError(f"Cannot broadcast, rank mismatch: {src_shape}, {shape}")
|
|
698
|
+
if shape == src_shape:
|
|
699
|
+
return input
|
|
700
|
+
for i, item in enumerate(src_shape):
|
|
701
|
+
if shape[i] != item and item != 1:
|
|
702
|
+
raise ValueError(f"Cannot broadcast, the expanded size of the tensor ({shape[i]})"
|
|
703
|
+
f" must match the existing size ({item}) at non-singleton dimension"
|
|
704
|
+
f" {i}: {src_shape}, {shape}")
|
|
705
|
+
ret_ty = tl.block_type(input.type.scalar, shape)
|
|
706
|
+
return self.tensor(self.builder.create_broadcast(input.handle, shape), ret_ty)
|
|
707
|
+
|
|
708
|
+
def broadcast_impl_value(self, lhs: TensorTy, rhs: TensorTy) -> TensorTy:
|
|
709
|
+
lhs_ty = lhs.type
|
|
710
|
+
rhs_ty = rhs.type
|
|
711
|
+
|
|
712
|
+
# make_shape_compatible(block, scalar)
|
|
713
|
+
if lhs_ty.is_block() and not rhs_ty.is_block():
|
|
714
|
+
rhs_ty = lhs_ty.with_element_ty(rhs_ty.scalar)
|
|
715
|
+
rhs = self.tensor(self.builder.create_splat(rhs_ty.to_ir(self.builder), rhs.handle), rhs_ty)
|
|
716
|
+
# make_shape_compatible(scalar, block)
|
|
717
|
+
elif not lhs_ty.is_block() and rhs_ty.is_block():
|
|
718
|
+
lhs_ty = rhs_ty.with_element_ty(lhs_ty.scalar)
|
|
719
|
+
lhs = self.tensor(self.builder.create_splat(lhs_ty.to_ir(self.builder), lhs.handle), lhs_ty)
|
|
720
|
+
# make_shape_compatible(block, block)
|
|
721
|
+
elif lhs_ty.is_block() and rhs_ty.is_block():
|
|
722
|
+
lhs_shape = lhs_ty.get_block_shapes()
|
|
723
|
+
rhs_shape = rhs_ty.get_block_shapes()
|
|
724
|
+
|
|
725
|
+
if len(lhs_shape) < len(rhs_shape):
|
|
726
|
+
# Add new axes to lhs
|
|
727
|
+
for _ in range(len(lhs_shape), len(rhs_shape)):
|
|
728
|
+
lhs = self.tensor(self.builder.create_expand_dims(lhs.handle, 0),
|
|
729
|
+
tl.block_type(lhs_ty.scalar, [1] + lhs_shape.values))
|
|
730
|
+
lhs_ty = lhs.type
|
|
731
|
+
lhs_shape = lhs_ty.get_block_shapes()
|
|
732
|
+
elif len(rhs_shape) < len(lhs_shape):
|
|
733
|
+
# Add new axes to rhs
|
|
734
|
+
for _ in range(len(rhs_shape), len(lhs_shape)):
|
|
735
|
+
rhs = self.tensor(self.builder.create_expand_dims(rhs.handle, 0),
|
|
736
|
+
tl.block_type(rhs_ty.scalar, [1] + rhs_shape.values))
|
|
737
|
+
rhs_ty = rhs.type
|
|
738
|
+
rhs_shape = rhs_ty.get_block_shapes()
|
|
739
|
+
assert len(rhs_shape) == len(lhs_shape)
|
|
740
|
+
|
|
741
|
+
ret_shape = []
|
|
742
|
+
for i, left in enumerate(lhs_shape):
|
|
743
|
+
right = rhs_shape[i]
|
|
744
|
+
if left == 1:
|
|
745
|
+
ret_shape.append(right)
|
|
746
|
+
elif (right == 1) or (right == left):
|
|
747
|
+
ret_shape.append(left)
|
|
748
|
+
else:
|
|
749
|
+
raise ValueError("Cannot make_shape_compatible: incompatible dimensions "
|
|
750
|
+
"at index " + str(i) + ": " + str(left) + " and " + str(right))
|
|
751
|
+
if lhs_shape != ret_shape:
|
|
752
|
+
ret_ty = tl.block_type(lhs_ty.scalar, ret_shape)
|
|
753
|
+
lhs = self.tensor(self.builder.create_broadcast(lhs.handle, ret_shape), ret_ty)
|
|
754
|
+
if rhs_shape != ret_shape:
|
|
755
|
+
ret_ty = tl.block_type(rhs_ty.scalar, ret_shape)
|
|
756
|
+
rhs = self.tensor(self.builder.create_broadcast(rhs.handle, ret_shape), ret_ty)
|
|
757
|
+
# (scalar, scalar) => returns original blocks
|
|
758
|
+
return lhs, rhs
|
|
806
759
|
|
|
807
760
|
#######
|
|
808
761
|
# cast
|
|
809
762
|
#######
|
|
810
763
|
|
|
811
|
-
|
|
812
|
-
|
|
813
|
-
|
|
814
|
-
|
|
815
|
-
|
|
816
|
-
|
|
817
|
-
|
|
818
|
-
|
|
819
|
-
|
|
820
|
-
|
|
821
|
-
|
|
822
|
-
|
|
823
|
-
|
|
824
|
-
|
|
825
|
-
|
|
826
|
-
|
|
827
|
-
|
|
828
|
-
|
|
829
|
-
|
|
830
|
-
|
|
831
|
-
|
|
832
|
-
|
|
833
|
-
|
|
834
|
-
|
|
835
|
-
|
|
836
|
-
|
|
837
|
-
|
|
838
|
-
|
|
839
|
-
|
|
840
|
-
|
|
841
|
-
|
|
842
|
-
|
|
843
|
-
|
|
844
|
-
|
|
845
|
-
|
|
846
|
-
|
|
847
|
-
|
|
848
|
-
|
|
849
|
-
|
|
850
|
-
|
|
851
|
-
|
|
852
|
-
|
|
853
|
-
|
|
854
|
-
|
|
855
|
-
use_custom_rounding = False
|
|
856
|
-
if dst_sca_ty.is_floating() and src_sca_ty.is_floating(
|
|
857
|
-
) and dst_sca_ty.primitive_bitwidth < src_sca_ty.primitive_bitwidth:
|
|
858
|
-
if fp_downcast_rounding is None: fp_downcast_rounding = ir.ROUNDING_MODE.RTNE
|
|
859
|
-
elif fp_downcast_rounding != ir.ROUNDING_MODE.RTNE: use_custom_rounding = True
|
|
860
|
-
else:
|
|
861
|
-
if fp_downcast_rounding is not None:
|
|
862
|
-
raise ValueError("fp_downcast_rounding should be set only for truncating fp conversions. "
|
|
863
|
-
"Source scalar type is " + str(src_sca_ty) + " and destination type is " + str(dst_sca_ty))
|
|
864
|
-
|
|
865
|
-
if (src_sca_ty.is_fp8e4b15() or dst_sca_ty.is_fp8e4b15()):
|
|
866
|
-
assert builder.codegen_fns.get(
|
|
867
|
-
"convert_custom_types") is not None, "target doesn't provide conversion for this type."
|
|
868
|
-
return builder.codegen_fns["convert_custom_types"](input, dst_ty, fp_downcast_rounding, _builder=builder)
|
|
869
|
-
# Casting with customized floating types involved: fp8 <=> bf16, fp16, fp32, fp64
|
|
870
|
-
# and non-default rounding modes for downcasting
|
|
871
|
-
if (src_sca_ty.is_fp8() and dst_sca_ty.is_floating()) or \
|
|
872
|
-
(src_sca_ty.is_floating() and dst_sca_ty.is_fp8()) or \
|
|
873
|
-
use_custom_rounding:
|
|
874
|
-
return tl.tensor(builder.create_fp_to_fp(input.handle, dst_ty.to_ir(builder), fp_downcast_rounding), dst_ty)
|
|
875
|
-
|
|
876
|
-
# bf16 <=> (not fp32)
|
|
877
|
-
if (src_sca_ty.is_fp16() and not dst_sca_ty.is_fp32()) or \
|
|
878
|
-
(src_sca_ty.is_bf16() and not dst_sca_ty.is_fp32()):
|
|
879
|
-
return cast(cast(input, tl.float32, builder), dst_sca_ty, builder)
|
|
880
|
-
|
|
881
|
-
# Standard floating types' casting: truncation
|
|
882
|
-
# fp64 => fp32, fp16, bf16
|
|
883
|
-
# fp32 => fp16, bf16
|
|
884
|
-
truncate_fp = src_sca_ty.is_floating() and \
|
|
885
|
-
dst_sca_ty.is_floating() and \
|
|
886
|
-
src_sca_ty.primitive_bitwidth > dst_sca_ty.primitive_bitwidth
|
|
887
|
-
if truncate_fp:
|
|
888
|
-
return tl.tensor(builder.create_fp_trunc(input.handle, dst_ty.to_ir(builder)), dst_ty)
|
|
889
|
-
|
|
890
|
-
# Standard floating types' casting: extension
|
|
891
|
-
# fp32 => fp64
|
|
892
|
-
# fp16 => fp32, fp64
|
|
893
|
-
# bf16 => fp32, fp64
|
|
894
|
-
ext_fp = src_sca_ty.is_floating() and \
|
|
895
|
-
dst_sca_ty.is_floating() and \
|
|
896
|
-
src_sca_ty.primitive_bitwidth < dst_sca_ty.primitive_bitwidth
|
|
897
|
-
if ext_fp:
|
|
898
|
-
return tl.tensor(builder.create_fp_ext(input.handle, dst_ty.to_ir(builder)), dst_ty)
|
|
899
|
-
|
|
900
|
-
# Casting between integer types
|
|
901
|
-
if src_sca_ty.is_int() and dst_sca_ty.is_int() and \
|
|
902
|
-
(src_sca_ty.int_bitwidth != dst_sca_ty.int_bitwidth or src_sca_ty.int_signedness != dst_sca_ty.int_signedness):
|
|
903
|
-
sign_extend = src_sca_ty.is_int_signed() and not src_sca_ty.is_bool()
|
|
904
|
-
if dst_sca_ty.is_bool():
|
|
905
|
-
ty = input.dtype.to_ir(builder)
|
|
906
|
-
_0 = tl.tensor(builder.get_null_value(ty), input.dtype)
|
|
907
|
-
return not_equal(input, _0, builder)
|
|
764
|
+
def _str_to_rounding_mode(self, rounding_mode: Optional[str]):
|
|
765
|
+
if rounding_mode is None:
|
|
766
|
+
return None
|
|
767
|
+
if rounding_mode == 'rtne':
|
|
768
|
+
return ir.ROUNDING_MODE.RTNE
|
|
769
|
+
if rounding_mode == 'rtz':
|
|
770
|
+
return ir.ROUNDING_MODE.RTZ
|
|
771
|
+
raise ValueError(f"Invalid rounding mode: {rounding_mode}. Supported rounding modes are 'rtne' and 'rtz'.")
|
|
772
|
+
|
|
773
|
+
def bitcast(self, input: TensorTy, dst_ty: tl.dtype) -> TensorTy:
|
|
774
|
+
src_ty = input.type
|
|
775
|
+
if src_ty.is_block():
|
|
776
|
+
dst_ty = src_ty.with_element_ty(dst_ty.scalar)
|
|
777
|
+
if src_ty == dst_ty:
|
|
778
|
+
return input
|
|
779
|
+
src_sca_ty = src_ty.scalar
|
|
780
|
+
dst_sca_ty = dst_ty.scalar
|
|
781
|
+
if src_sca_ty.is_ptr() or dst_sca_ty.is_ptr():
|
|
782
|
+
return self.cast(input, dst_ty)
|
|
783
|
+
# Bitcast
|
|
784
|
+
src_bits = src_sca_ty.primitive_bitwidth
|
|
785
|
+
dst_bits = dst_sca_ty.primitive_bitwidth
|
|
786
|
+
if src_bits != dst_bits:
|
|
787
|
+
raise ValueError("Cannot bitcast data-type of size " + str(src_bits) + " to "
|
|
788
|
+
"data-type of size " + str(dst_bits))
|
|
789
|
+
return self.tensor(self.builder.create_bitcast(input.handle, dst_ty.to_ir(self.builder)), dst_ty)
|
|
790
|
+
|
|
791
|
+
def cast(self, input: TensorTy, dst_ty: tl.dtype, fp_downcast_rounding: Optional[str] = None) -> TensorTy:
|
|
792
|
+
src_ty = input.type
|
|
793
|
+
src_sca_ty = src_ty.scalar
|
|
794
|
+
dst_sca_ty = dst_ty.scalar
|
|
795
|
+
if src_sca_ty == dst_sca_ty:
|
|
796
|
+
return input
|
|
797
|
+
if src_ty.is_block():
|
|
798
|
+
dst_ty = src_ty.with_element_ty(dst_sca_ty)
|
|
799
|
+
|
|
800
|
+
# For fp downcasting default rounding mode should be RTNE, for all other conversions it should
|
|
801
|
+
# not be set
|
|
802
|
+
fp_downcast_rounding = self._str_to_rounding_mode(fp_downcast_rounding)
|
|
803
|
+
use_custom_rounding = False
|
|
804
|
+
if dst_sca_ty.is_floating() and src_sca_ty.is_floating(
|
|
805
|
+
) and dst_sca_ty.primitive_bitwidth < src_sca_ty.primitive_bitwidth:
|
|
806
|
+
if fp_downcast_rounding is None: fp_downcast_rounding = ir.ROUNDING_MODE.RTNE
|
|
807
|
+
elif fp_downcast_rounding != ir.ROUNDING_MODE.RTNE: use_custom_rounding = True
|
|
908
808
|
else:
|
|
909
|
-
|
|
910
|
-
|
|
911
|
-
|
|
912
|
-
|
|
913
|
-
|
|
914
|
-
|
|
915
|
-
|
|
916
|
-
|
|
917
|
-
|
|
918
|
-
|
|
919
|
-
|
|
920
|
-
|
|
921
|
-
|
|
922
|
-
|
|
923
|
-
|
|
924
|
-
|
|
925
|
-
|
|
926
|
-
|
|
927
|
-
|
|
809
|
+
if fp_downcast_rounding is not None:
|
|
810
|
+
raise ValueError("fp_downcast_rounding should be set only for truncating fp conversions. "
|
|
811
|
+
"Source scalar type is " + str(src_sca_ty) + " and destination type is " +
|
|
812
|
+
str(dst_sca_ty))
|
|
813
|
+
|
|
814
|
+
if (src_sca_ty.is_fp8e4b15() or dst_sca_ty.is_fp8e4b15()):
|
|
815
|
+
assert self.builder.codegen_fns.get(
|
|
816
|
+
"convert_custom_types") is not None, "target doesn't provide conversion for this type."
|
|
817
|
+
return self.builder.codegen_fns["convert_custom_types"](input, dst_ty, fp_downcast_rounding, _semantic=self)
|
|
818
|
+
# Casting with customized floating types involved: fp8 <=> bf16, fp16, fp32, fp64
|
|
819
|
+
# and non-default rounding modes for downcasting
|
|
820
|
+
if (src_sca_ty.is_fp8() and dst_sca_ty.is_floating()) or \
|
|
821
|
+
(src_sca_ty.is_floating() and dst_sca_ty.is_fp8()) or \
|
|
822
|
+
use_custom_rounding:
|
|
823
|
+
return self.tensor(
|
|
824
|
+
self.builder.create_fp_to_fp(input.handle, dst_ty.to_ir(self.builder), fp_downcast_rounding), dst_ty)
|
|
825
|
+
|
|
826
|
+
# bf16 <=> (not fp32)
|
|
827
|
+
if (src_sca_ty.is_fp16() and not dst_sca_ty.is_fp32()) or \
|
|
828
|
+
(src_sca_ty.is_bf16() and not dst_sca_ty.is_fp32()):
|
|
829
|
+
return self.cast(self.cast(input, tl.float32), dst_sca_ty)
|
|
830
|
+
|
|
831
|
+
# Standard floating types' casting: truncation
|
|
832
|
+
# fp64 => fp32, fp16, bf16
|
|
833
|
+
# fp32 => fp16, bf16
|
|
834
|
+
truncate_fp = src_sca_ty.is_floating() and \
|
|
835
|
+
dst_sca_ty.is_floating() and \
|
|
836
|
+
src_sca_ty.primitive_bitwidth > dst_sca_ty.primitive_bitwidth
|
|
837
|
+
if truncate_fp:
|
|
838
|
+
return self.tensor(self.builder.create_fp_trunc(input.handle, dst_ty.to_ir(self.builder)), dst_ty)
|
|
839
|
+
|
|
840
|
+
# Standard floating types' casting: extension
|
|
841
|
+
# fp32 => fp64
|
|
842
|
+
# fp16 => fp32, fp64
|
|
843
|
+
# bf16 => fp32, fp64
|
|
844
|
+
ext_fp = src_sca_ty.is_floating() and \
|
|
845
|
+
dst_sca_ty.is_floating() and \
|
|
846
|
+
src_sca_ty.primitive_bitwidth < dst_sca_ty.primitive_bitwidth
|
|
847
|
+
if ext_fp:
|
|
848
|
+
return self.tensor(self.builder.create_fp_ext(input.handle, dst_ty.to_ir(self.builder)), dst_ty)
|
|
849
|
+
|
|
850
|
+
# Casting between integer types
|
|
851
|
+
if src_sca_ty.is_int() and dst_sca_ty.is_int() and \
|
|
852
|
+
(src_sca_ty.int_bitwidth != dst_sca_ty.int_bitwidth or src_sca_ty.int_signedness != dst_sca_ty.int_signedness):
|
|
853
|
+
sign_extend = src_sca_ty.is_int_signed() and not src_sca_ty.is_bool()
|
|
854
|
+
if dst_sca_ty.is_bool():
|
|
855
|
+
ty = input.dtype.to_ir(self.builder)
|
|
856
|
+
_0 = self.tensor(self.builder.get_null_value(ty), input.dtype)
|
|
857
|
+
return self.not_equal(input, _0)
|
|
858
|
+
else:
|
|
859
|
+
return self.tensor(self.builder.create_int_cast(input.handle, dst_ty.to_ir(self.builder), sign_extend),
|
|
860
|
+
dst_ty)
|
|
861
|
+
|
|
862
|
+
# Casting standard floating types to integer types
|
|
863
|
+
if src_sca_ty.is_standard_floating() and dst_sca_ty.is_int():
|
|
864
|
+
if dst_sca_ty.is_bool():
|
|
865
|
+
ty = input.dtype.to_ir(self.builder)
|
|
866
|
+
_0 = self.tensor(self.builder.get_null_value(ty), input.dtype)
|
|
867
|
+
return self.not_equal(input, _0)
|
|
868
|
+
elif dst_sca_ty.is_int_signed():
|
|
869
|
+
return self.tensor(self.builder.create_fp_to_si(input.handle, dst_ty.to_ir(self.builder)), dst_ty)
|
|
870
|
+
else:
|
|
871
|
+
return self.tensor(self.builder.create_fp_to_ui(input.handle, dst_ty.to_ir(self.builder)), dst_ty)
|
|
928
872
|
|
|
929
|
-
|
|
930
|
-
|
|
931
|
-
|
|
932
|
-
|
|
933
|
-
|
|
934
|
-
|
|
935
|
-
return not_equal(cast(input, tl.int64, builder), tl.tensor(builder.get_int64(0), tl.int64), builder)
|
|
873
|
+
# Casting integer types to standard floating types
|
|
874
|
+
if src_sca_ty.is_int() and dst_sca_ty.is_standard_floating():
|
|
875
|
+
if src_sca_ty.is_bool() or not src_sca_ty.is_int_signed():
|
|
876
|
+
return self.tensor(self.builder.create_ui_to_fp(input.handle, dst_ty.to_ir(self.builder)), dst_ty)
|
|
877
|
+
else:
|
|
878
|
+
return self.tensor(self.builder.create_si_to_fp(input.handle, dst_ty.to_ir(self.builder)), dst_ty)
|
|
936
879
|
|
|
937
|
-
|
|
938
|
-
|
|
939
|
-
|
|
880
|
+
# Casting pointer types to integer types
|
|
881
|
+
if src_sca_ty.is_ptr() and dst_sca_ty.is_int():
|
|
882
|
+
bitwidth = dst_sca_ty.int_bitwidth
|
|
883
|
+
if bitwidth == 64:
|
|
884
|
+
return self.tensor(self.builder.create_ptr_to_int(input.handle, dst_ty.to_ir(self.builder)), dst_ty)
|
|
885
|
+
if bitwidth == 1:
|
|
886
|
+
return self.not_equal(self.cast(input, tl.int64), self.tensor(self.builder.get_int64(0), tl.int64))
|
|
940
887
|
|
|
941
|
-
|
|
942
|
-
|
|
943
|
-
|
|
888
|
+
# Casting integer types to pointer types
|
|
889
|
+
if src_sca_ty.is_int() and dst_sca_ty.is_ptr():
|
|
890
|
+
return self.tensor(self.builder.create_int_to_ptr(input.handle, dst_ty.to_ir(self.builder)), dst_ty)
|
|
944
891
|
|
|
945
|
-
|
|
892
|
+
# Casting pointer types to pointer types
|
|
893
|
+
if src_sca_ty.is_ptr() and dst_sca_ty.is_ptr():
|
|
894
|
+
return self.tensor(self.builder.create_bitcast(input.handle, dst_ty.to_ir(self.builder)), dst_ty)
|
|
946
895
|
|
|
896
|
+
assert False, f'cannot cast {input} to {dst_ty}'
|
|
947
897
|
|
|
948
898
|
# ===----------------------------------------------------------------------===//
|
|
949
899
|
# Memory Operators
|
|
950
900
|
# ===----------------------------------------------------------------------===//
|
|
951
901
|
|
|
902
|
+
def _str_to_load_cache_modifier(self, cache_modifier):
|
|
903
|
+
cache = ir.CACHE_MODIFIER.NONE # default
|
|
904
|
+
if cache_modifier:
|
|
905
|
+
if cache_modifier == ".ca":
|
|
906
|
+
cache = ir.CACHE_MODIFIER.CA
|
|
907
|
+
elif cache_modifier == ".cg":
|
|
908
|
+
cache = ir.CACHE_MODIFIER.CG
|
|
909
|
+
elif cache_modifier == ".cv":
|
|
910
|
+
cache = ir.CACHE_MODIFIER.CV
|
|
911
|
+
else:
|
|
912
|
+
raise ValueError(f"Cache modifier {cache_modifier} not supported")
|
|
913
|
+
return cache
|
|
914
|
+
|
|
915
|
+
def _str_to_store_cache_modifier(self, cache_modifier):
|
|
916
|
+
cache = ir.CACHE_MODIFIER.NONE # default
|
|
917
|
+
if cache_modifier:
|
|
918
|
+
if cache_modifier == ".wb":
|
|
919
|
+
cache = ir.CACHE_MODIFIER.WB
|
|
920
|
+
elif cache_modifier == ".cg":
|
|
921
|
+
cache = ir.CACHE_MODIFIER.CG
|
|
922
|
+
elif cache_modifier == ".cs":
|
|
923
|
+
cache = ir.CACHE_MODIFIER.CS
|
|
924
|
+
elif cache_modifier == ".wt":
|
|
925
|
+
cache = ir.CACHE_MODIFIER.WT
|
|
926
|
+
else:
|
|
927
|
+
raise ValueError(f"Cache modifier {cache_modifier} not supported")
|
|
928
|
+
return cache
|
|
929
|
+
|
|
930
|
+
def _str_to_eviction_policy(self, eviction_policy):
|
|
931
|
+
eviction = ir.EVICTION_POLICY.NORMAL # default
|
|
932
|
+
if eviction_policy:
|
|
933
|
+
if eviction_policy == "evict_last":
|
|
934
|
+
eviction = ir.EVICTION_POLICY.EVICT_LAST
|
|
935
|
+
elif eviction_policy == "evict_first":
|
|
936
|
+
eviction = ir.EVICTION_POLICY.EVICT_FIRST
|
|
937
|
+
else:
|
|
938
|
+
raise ValueError(f"Eviction policy {eviction_policy} not supported")
|
|
939
|
+
return eviction
|
|
940
|
+
|
|
941
|
+
def _str_to_padding_option(self, padding_option):
|
|
942
|
+
padding = None # default
|
|
943
|
+
if padding_option:
|
|
944
|
+
if padding_option == "zero":
|
|
945
|
+
padding = ir.PADDING_OPTION.PAD_ZERO
|
|
946
|
+
elif padding_option == "nan":
|
|
947
|
+
padding = ir.PADDING_OPTION.PAD_NAN
|
|
948
|
+
else:
|
|
949
|
+
raise ValueError(f"Padding option {padding_option} not supported")
|
|
950
|
+
return padding
|
|
951
|
+
|
|
952
|
+
def _str_to_sem(self, sem_option):
|
|
953
|
+
sem = ir.MEM_SEMANTIC.ACQUIRE_RELEASE
|
|
954
|
+
if sem_option:
|
|
955
|
+
if sem_option == "acquire":
|
|
956
|
+
sem = ir.MEM_SEMANTIC.ACQUIRE
|
|
957
|
+
elif sem_option == "release":
|
|
958
|
+
sem = ir.MEM_SEMANTIC.RELEASE
|
|
959
|
+
elif sem_option == "acq_rel":
|
|
960
|
+
sem = ir.MEM_SEMANTIC.ACQUIRE_RELEASE
|
|
961
|
+
elif sem_option == "relaxed":
|
|
962
|
+
sem = ir.MEM_SEMANTIC.RELAXED
|
|
963
|
+
else:
|
|
964
|
+
raise ValueError(f"Memory semantic {sem_option} not supported")
|
|
965
|
+
return sem
|
|
966
|
+
|
|
967
|
+
def _str_to_scope(self, scope_option):
|
|
968
|
+
scope = ir.MEM_SYNC_SCOPE.GPU
|
|
969
|
+
if scope_option:
|
|
970
|
+
if scope_option == "gpu":
|
|
971
|
+
scope = ir.MEM_SYNC_SCOPE.GPU
|
|
972
|
+
elif scope_option == "cta":
|
|
973
|
+
scope = ir.MEM_SYNC_SCOPE.CTA
|
|
974
|
+
elif scope_option == "sys":
|
|
975
|
+
scope = ir.MEM_SYNC_SCOPE.SYSTEM
|
|
976
|
+
else:
|
|
977
|
+
raise ValueError(f"Memory semantic {scope_option} not supported")
|
|
978
|
+
return scope
|
|
979
|
+
|
|
980
|
+
def _canonicalize_boundary_check(self, boundary_check, block_shape):
|
|
981
|
+
if boundary_check:
|
|
982
|
+
if not hasattr(boundary_check, "__iter__"):
|
|
983
|
+
boundary_check = [boundary_check]
|
|
984
|
+
boundary_check = [elem.value if isinstance(elem, tl.constexpr) else elem for elem in boundary_check]
|
|
985
|
+
for dim in boundary_check:
|
|
986
|
+
assert isinstance(dim, int) and 0 <= dim < len(block_shape)
|
|
987
|
+
assert len(boundary_check) > 0
|
|
988
|
+
assert len(boundary_check) == len(set(boundary_check)), "Duplicate dimension in `boundary_check`"
|
|
989
|
+
return sorted(boundary_check)
|
|
990
|
+
return ()
|
|
991
|
+
|
|
992
|
+
def _load_block_pointer(self, ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile):
|
|
993
|
+
# Load by a block pointer: `pointer_type<block_type<>>`
|
|
994
|
+
# Block pointer can not have `mask` and `other` arguments
|
|
995
|
+
if mask is not None or other is not None:
|
|
996
|
+
raise ValueError("`mask` and `other` arguments cannot be specified for loading block pointers")
|
|
952
997
|
|
|
953
|
-
|
|
954
|
-
|
|
955
|
-
|
|
956
|
-
|
|
957
|
-
cache = ir.CACHE_MODIFIER.CA
|
|
958
|
-
elif cache_modifier == ".cg":
|
|
959
|
-
cache = ir.CACHE_MODIFIER.CG
|
|
960
|
-
elif cache_modifier == ".cv":
|
|
961
|
-
cache = ir.CACHE_MODIFIER.CV
|
|
962
|
-
else:
|
|
963
|
-
raise ValueError(f"Cache modifier {cache_modifier} not supported")
|
|
964
|
-
return cache
|
|
965
|
-
|
|
966
|
-
|
|
967
|
-
def _str_to_store_cache_modifier(cache_modifier):
|
|
968
|
-
cache = ir.CACHE_MODIFIER.NONE # default
|
|
969
|
-
if cache_modifier:
|
|
970
|
-
if cache_modifier == ".wb":
|
|
971
|
-
cache = ir.CACHE_MODIFIER.WB
|
|
972
|
-
elif cache_modifier == ".cg":
|
|
973
|
-
cache = ir.CACHE_MODIFIER.CG
|
|
974
|
-
elif cache_modifier == ".cs":
|
|
975
|
-
cache = ir.CACHE_MODIFIER.CS
|
|
976
|
-
elif cache_modifier == ".wt":
|
|
977
|
-
cache = ir.CACHE_MODIFIER.WT
|
|
978
|
-
else:
|
|
979
|
-
raise ValueError(f"Cache modifier {cache_modifier} not supported")
|
|
980
|
-
return cache
|
|
998
|
+
elt_ty = ptr.type.element_ty.element_ty
|
|
999
|
+
assert elt_ty != tl.int1, "`tl.int1` should be rewritten in `tl.make_block_ptr`"
|
|
1000
|
+
if elt_ty.is_int() and padding == ir.PADDING_OPTION.PAD_NAN:
|
|
1001
|
+
raise ValueError("Padding option `nan` is not supported for integer block pointers")
|
|
981
1002
|
|
|
1003
|
+
# `dst_ty` is de-referenced type of the pointer type
|
|
1004
|
+
dst_ty = ptr.type.element_ty
|
|
982
1005
|
|
|
983
|
-
|
|
984
|
-
|
|
985
|
-
|
|
986
|
-
|
|
987
|
-
|
|
988
|
-
|
|
989
|
-
|
|
990
|
-
else:
|
|
991
|
-
raise ValueError(f"Eviction policy {eviction_policy} not supported")
|
|
992
|
-
return eviction
|
|
1006
|
+
# Check `boundary_check` argument
|
|
1007
|
+
boundary_check = self._canonicalize_boundary_check(boundary_check, dst_ty.get_block_shapes())
|
|
1008
|
+
|
|
1009
|
+
# Build IR
|
|
1010
|
+
return self.tensor(
|
|
1011
|
+
self.builder.create_tensor_pointer_load(ptr.handle, boundary_check, padding, cache, eviction, is_volatile),
|
|
1012
|
+
dst_ty)
|
|
993
1013
|
|
|
1014
|
+
def _load_legacy(self, ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile):
|
|
1015
|
+
# Load by a tensor of pointers or a pointer of scalar: `block_type<pointer_type<>>` or `pointer_type<>`
|
|
1016
|
+
if not ptr.type.scalar.is_ptr():
|
|
1017
|
+
raise ValueError(f"Unsupported ptr type {ptr.type.__repr__()} in `tl.load`")
|
|
1018
|
+
|
|
1019
|
+
# Check `mask`, `other`, `boundary_check`, and `padding` arguments
|
|
1020
|
+
if mask is None and other is not None:
|
|
1021
|
+
raise ValueError("`other` cannot be provided without `mask`")
|
|
1022
|
+
if padding or boundary_check:
|
|
1023
|
+
raise ValueError("`padding_option` or `boundary_check` argument is not supported for loading a tensor of"
|
|
1024
|
+
"pointers or loading a scalar. Because the compiler does not know the boundary; please "
|
|
1025
|
+
"use block pointers (defined by `make_block_ptr`) instead")
|
|
1026
|
+
|
|
1027
|
+
# For a pointer of scalar, check the type of `mask` and `other`
|
|
1028
|
+
if not ptr.type.is_block():
|
|
1029
|
+
if mask and mask.type.is_block():
|
|
1030
|
+
raise ValueError("Mask argument cannot be block type if pointer argument is not a block")
|
|
1031
|
+
if other and other.type.is_block():
|
|
1032
|
+
raise ValueError("Other argument cannot be block type if pointer argument is not a block")
|
|
1033
|
+
|
|
1034
|
+
# Make `mask` and `other` into the same shape as `ptr`
|
|
1035
|
+
if ptr.type.is_block():
|
|
1036
|
+
if mask is not None:
|
|
1037
|
+
mask = self.broadcast_impl_shape(mask, ptr.type.get_block_shapes())
|
|
1038
|
+
if other is not None:
|
|
1039
|
+
other = self.broadcast_impl_shape(other, ptr.type.get_block_shapes())
|
|
1040
|
+
|
|
1041
|
+
# Get `pointer_type<elt_ty>` and `elt_ty`
|
|
1042
|
+
ptr_ty = ptr.type.scalar
|
|
1043
|
+
elt_ty = ptr_ty.element_ty
|
|
1044
|
+
|
|
1045
|
+
# Treat `pointer_type<tl.int1>` as `pointer_type<tl.int8>`
|
|
1046
|
+
is_bool = elt_ty == tl.int1
|
|
1047
|
+
if is_bool:
|
|
1048
|
+
elt_ty = tl.int8
|
|
1049
|
+
ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space)
|
|
1050
|
+
ptr = self.cast(ptr, ptr_ty)
|
|
1051
|
+
|
|
1052
|
+
# Cast `other` into `elt_ty` type
|
|
1053
|
+
if other is not None:
|
|
1054
|
+
other = self.cast(other, elt_ty)
|
|
994
1055
|
|
|
995
|
-
|
|
996
|
-
|
|
997
|
-
|
|
998
|
-
if padding_option == "zero":
|
|
999
|
-
padding = ir.PADDING_OPTION.PAD_ZERO
|
|
1000
|
-
elif padding_option == "nan":
|
|
1001
|
-
padding = ir.PADDING_OPTION.PAD_NAN
|
|
1056
|
+
# Create loaded result type `dst_ty`
|
|
1057
|
+
if ptr.type.is_block():
|
|
1058
|
+
dst_ty = ptr.type.with_element_ty(elt_ty)
|
|
1002
1059
|
else:
|
|
1003
|
-
|
|
1004
|
-
|
|
1005
|
-
|
|
1006
|
-
|
|
1007
|
-
|
|
1008
|
-
|
|
1009
|
-
if sem_option:
|
|
1010
|
-
if sem_option == "acquire":
|
|
1011
|
-
sem = ir.MEM_SEMANTIC.ACQUIRE
|
|
1012
|
-
elif sem_option == "release":
|
|
1013
|
-
sem = ir.MEM_SEMANTIC.RELEASE
|
|
1014
|
-
elif sem_option == "acq_rel":
|
|
1015
|
-
sem = ir.MEM_SEMANTIC.ACQUIRE_RELEASE
|
|
1016
|
-
elif sem_option == "relaxed":
|
|
1017
|
-
sem = ir.MEM_SEMANTIC.RELAXED
|
|
1060
|
+
# Load by de-referencing the pointer of scalar
|
|
1061
|
+
dst_ty = elt_ty
|
|
1062
|
+
|
|
1063
|
+
# Build IR
|
|
1064
|
+
if mask is None:
|
|
1065
|
+
ret = self.tensor(self.builder.create_load(ptr.handle, cache, eviction, is_volatile), dst_ty)
|
|
1018
1066
|
else:
|
|
1019
|
-
|
|
1020
|
-
|
|
1021
|
-
|
|
1022
|
-
|
|
1023
|
-
|
|
1024
|
-
|
|
1025
|
-
|
|
1026
|
-
|
|
1027
|
-
|
|
1028
|
-
|
|
1029
|
-
|
|
1030
|
-
|
|
1031
|
-
|
|
1067
|
+
ret = self.tensor(
|
|
1068
|
+
self.builder.create_masked_load(ptr.handle, mask.handle, other.handle if other else None, cache,
|
|
1069
|
+
eviction, is_volatile), dst_ty)
|
|
1070
|
+
if is_bool:
|
|
1071
|
+
ret = self.cast(ret, tl.int1)
|
|
1072
|
+
return ret
|
|
1073
|
+
|
|
1074
|
+
def load(self, ptr: TensorTy, mask: Optional[TensorTy], other: Optional[TensorTy], boundary_check: Tuple,
|
|
1075
|
+
padding_option: str, cache_modifier: str, eviction_policy: str, is_volatile: bool) -> TensorTy:
|
|
1076
|
+
# Cache, eviction and padding options
|
|
1077
|
+
cache = self._str_to_load_cache_modifier(cache_modifier)
|
|
1078
|
+
eviction = self._str_to_eviction_policy(eviction_policy)
|
|
1079
|
+
padding = self._str_to_padding_option(padding_option)
|
|
1080
|
+
|
|
1081
|
+
if ptr.type.is_ptr() and ptr.type.element_ty.is_block():
|
|
1082
|
+
# Load by a block pointer: `pointer_type<block_type<>>`
|
|
1083
|
+
return self._load_block_pointer(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile)
|
|
1032
1084
|
else:
|
|
1033
|
-
|
|
1034
|
-
|
|
1035
|
-
|
|
1036
|
-
|
|
1037
|
-
|
|
1038
|
-
|
|
1039
|
-
|
|
1040
|
-
|
|
1041
|
-
|
|
1042
|
-
|
|
1043
|
-
|
|
1044
|
-
|
|
1045
|
-
|
|
1046
|
-
|
|
1047
|
-
|
|
1048
|
-
|
|
1049
|
-
|
|
1050
|
-
|
|
1051
|
-
|
|
1052
|
-
|
|
1053
|
-
|
|
1054
|
-
|
|
1055
|
-
|
|
1056
|
-
|
|
1057
|
-
|
|
1058
|
-
|
|
1059
|
-
|
|
1060
|
-
|
|
1061
|
-
|
|
1062
|
-
|
|
1063
|
-
|
|
1064
|
-
|
|
1065
|
-
|
|
1066
|
-
|
|
1067
|
-
|
|
1068
|
-
|
|
1069
|
-
|
|
1070
|
-
|
|
1071
|
-
|
|
1072
|
-
|
|
1073
|
-
|
|
1074
|
-
|
|
1075
|
-
|
|
1076
|
-
|
|
1077
|
-
|
|
1078
|
-
|
|
1079
|
-
|
|
1080
|
-
|
|
1081
|
-
|
|
1082
|
-
|
|
1083
|
-
|
|
1084
|
-
|
|
1085
|
-
|
|
1086
|
-
|
|
1087
|
-
|
|
1088
|
-
|
|
1089
|
-
|
|
1090
|
-
|
|
1091
|
-
|
|
1092
|
-
|
|
1093
|
-
|
|
1094
|
-
|
|
1095
|
-
|
|
1096
|
-
|
|
1097
|
-
|
|
1098
|
-
|
|
1099
|
-
|
|
1100
|
-
|
|
1101
|
-
|
|
1102
|
-
|
|
1103
|
-
|
|
1104
|
-
|
|
1105
|
-
|
|
1106
|
-
|
|
1107
|
-
|
|
1108
|
-
|
|
1109
|
-
|
|
1110
|
-
|
|
1111
|
-
|
|
1112
|
-
|
|
1113
|
-
|
|
1114
|
-
|
|
1115
|
-
|
|
1116
|
-
|
|
1117
|
-
|
|
1118
|
-
|
|
1119
|
-
|
|
1120
|
-
|
|
1121
|
-
|
|
1122
|
-
|
|
1123
|
-
|
|
1124
|
-
|
|
1125
|
-
|
|
1126
|
-
|
|
1127
|
-
|
|
1128
|
-
|
|
1129
|
-
|
|
1130
|
-
|
|
1131
|
-
|
|
1132
|
-
|
|
1133
|
-
|
|
1134
|
-
|
|
1135
|
-
|
|
1136
|
-
|
|
1137
|
-
|
|
1138
|
-
|
|
1139
|
-
|
|
1140
|
-
|
|
1141
|
-
|
|
1142
|
-
|
|
1143
|
-
#
|
|
1144
|
-
|
|
1145
|
-
|
|
1146
|
-
|
|
1147
|
-
|
|
1148
|
-
|
|
1149
|
-
|
|
1150
|
-
|
|
1151
|
-
|
|
1152
|
-
|
|
1153
|
-
|
|
1154
|
-
|
|
1155
|
-
|
|
1156
|
-
|
|
1157
|
-
return
|
|
1158
|
-
# Due to limitations of the shared memory encoding, the TMA bounding box has
|
|
1159
|
-
# to be at least as big as the swizzle tile.
|
|
1160
|
-
assert shape[0] >= 8, f"tensor descriptor block shape must have at least 8 rows, but got {shape[0]}"
|
|
1161
|
-
min_cols = 32 // dtype.primitive_bitwidth * 8
|
|
1162
|
-
assert shape[
|
|
1163
|
-
1] >= min_cols, f"{dtype} tensor descriptor block shape must have at least {min_cols} columns, but got {shape[1]}"
|
|
1164
|
-
|
|
1165
|
-
|
|
1166
|
-
def descriptor_load(desc: tl._experimental_tensor_desciptor_base, offsets, cache_modifier: str, eviction_policy: str,
|
|
1167
|
-
builder: ir.builder) -> tl.tensor:
|
|
1168
|
-
assert isinstance(desc, tl._experimental_tensor_descriptor_base)
|
|
1169
|
-
validate_descriptor_block(desc.block_shape, desc.dtype)
|
|
1170
|
-
ndim = len(desc.block_shape)
|
|
1171
|
-
assert len(offsets) == ndim, f"expected {ndim} offsets, but got {len(offsets)}"
|
|
1172
|
-
|
|
1173
|
-
offsets = _convert_to_ir_values(builder, offsets, require_i64=False)
|
|
1174
|
-
x = builder.create_descriptor_load(desc.handle, offsets, _str_to_load_cache_modifier(cache_modifier),
|
|
1175
|
-
_str_to_eviction_policy(eviction_policy))
|
|
1176
|
-
return tl.tensor(x, desc.block_type)
|
|
1177
|
-
|
|
1178
|
-
|
|
1179
|
-
def descriptor_store(desc: tl._experimental_tensor_descriptor_base, value: tl.tensor, offsets,
|
|
1180
|
-
builder: ir.builder) -> tl.tensor:
|
|
1181
|
-
assert isinstance(desc, tl._experimental_tensor_descriptor_base)
|
|
1182
|
-
validate_descriptor_block(desc.block_shape, desc.dtype)
|
|
1183
|
-
ndim = len(desc.block_shape)
|
|
1184
|
-
assert len(offsets) == ndim, f"expected {ndim} offsets, but got {len(offsets)}"
|
|
1185
|
-
assert value.shape == desc.block_shape
|
|
1186
|
-
|
|
1187
|
-
offsets = _convert_to_ir_values(builder, offsets, require_i64=False)
|
|
1188
|
-
return tl.tensor(builder.create_descriptor_store(desc.handle, value.handle, offsets), tl.void)
|
|
1189
|
-
|
|
1190
|
-
|
|
1191
|
-
def descriptor_gather(desc, x_offsets, y_offset, cache_modifier: str, eviction_policy: str,
|
|
1192
|
-
builder: ir.builder) -> tl.tensor:
|
|
1193
|
-
assert isinstance(desc, tl._experimental_tensor_descriptor_base)
|
|
1194
|
-
assert cache_modifier == "", "cache modifier is not supported yet"
|
|
1195
|
-
assert eviction_policy == "", "eviction policy is not supported yet"
|
|
1196
|
-
|
|
1197
|
-
# Validate descriptor.
|
|
1198
|
-
assert len(desc.block_shape) == 2, f"descriptor must be 2D, but got {desc.block_shape}"
|
|
1199
|
-
assert desc.block_shape[0] == 1, f"descriptor block must have 1 row, but got {desc.block_shape}"
|
|
1200
|
-
|
|
1201
|
-
# Validate offsets.
|
|
1202
|
-
assert len(x_offsets.shape) == 1, f"x offsets must be 1D, but got {x_offsets.shape}"
|
|
1203
|
-
|
|
1204
|
-
# Validate minimum block size.
|
|
1205
|
-
assert x_offsets.shape[0] >= 8, f"descriptor gather must have at least 8 rows, but got {x_offsets.shape}"
|
|
1206
|
-
dtype = desc.dtype
|
|
1207
|
-
min_cols = 32 // dtype.primitive_bitwidth * 8
|
|
1208
|
-
assert desc.block_shape[
|
|
1209
|
-
1] >= min_cols, f"descriptor gather of {dtype} must have at least {min_cols} columns, but got {desc.block_shape[1]}"
|
|
1210
|
-
|
|
1211
|
-
type = tl.block_type(desc.dtype, [x_offsets.shape[0], desc.block_shape[1]])
|
|
1212
|
-
y_offset = _convert_to_ir_values(builder, (y_offset, ), require_i64=False)[0]
|
|
1213
|
-
x = builder.create_descriptor_gather(desc.handle, x_offsets.handle, y_offset, type.to_ir(builder))
|
|
1214
|
-
return tl.tensor(x, type)
|
|
1215
|
-
|
|
1216
|
-
|
|
1217
|
-
def descriptor_scatter(desc, value: tl.tensor, x_offsets, y_offset, builder: ir.builder) -> tl.tensor:
|
|
1218
|
-
assert isinstance(desc, tl._experimental_tensor_descriptor_base)
|
|
1219
|
-
|
|
1220
|
-
# Validate descriptor.
|
|
1221
|
-
assert len(desc.block_shape) == 2, f"descriptor must be 2D, but got {desc.block_shape}"
|
|
1222
|
-
assert desc.block_shape[0] == 1, f"descriptor block must have 1 row, but got {desc.block_shape}"
|
|
1223
|
-
|
|
1224
|
-
# Validate offsets.
|
|
1225
|
-
assert len(x_offsets.shape) == 1, f"x offsets must be 1D, but got {x_offsets.shapae}"
|
|
1226
|
-
|
|
1227
|
-
# Validate minimum block size.
|
|
1228
|
-
assert x_offsets.shape[0] >= 8, f"descriptor scatter must have at least 8 rows, but got {x_offsets.shape}"
|
|
1229
|
-
dtype = desc.dtype
|
|
1230
|
-
min_cols = 32 // dtype.primitive_bitwidth * 8
|
|
1231
|
-
assert desc.block_shape[
|
|
1232
|
-
1] >= min_cols, f"descriptor scatter of {dtype} must have at least {min_cols} columns, but got {desc.block_shape[1]}"
|
|
1233
|
-
|
|
1234
|
-
y_offset = _convert_to_ir_values(builder, (y_offset, ), require_i64=False)[0]
|
|
1235
|
-
builder.create_descriptor_scatter(desc.handle, value.handle, x_offsets.handle, y_offset)
|
|
1236
|
-
return tl.tensor(None, tl.void)
|
|
1237
|
-
|
|
1238
|
-
|
|
1239
|
-
def tensormap_create(
|
|
1240
|
-
desc_ptr: tl.tensor,
|
|
1241
|
-
global_address: tl.tensor,
|
|
1242
|
-
box_dim: List[tl.tensor],
|
|
1243
|
-
global_dim: List[tl.tensor],
|
|
1244
|
-
global_stride: List[tl.tensor],
|
|
1245
|
-
element_stride: List[tl.tensor],
|
|
1246
|
-
elem_type: int,
|
|
1247
|
-
interleave_layout: int,
|
|
1248
|
-
swizzle_mode: int,
|
|
1249
|
-
fill_mode: int,
|
|
1250
|
-
builder: ir.builder,
|
|
1251
|
-
) -> tl.tensor:
|
|
1252
|
-
assert not global_stride or global_stride[0].dtype == tl.int64
|
|
1253
|
-
return tl.tensor(
|
|
1254
|
-
builder.create_tensormap_create(
|
|
1255
|
-
desc_ptr.handle,
|
|
1256
|
-
global_address.handle,
|
|
1257
|
-
[x.handle for x in box_dim],
|
|
1258
|
-
[x.handle for x in global_dim],
|
|
1259
|
-
[x.handle for x in global_stride],
|
|
1260
|
-
[x.handle for x in element_stride],
|
|
1261
|
-
elem_type,
|
|
1262
|
-
interleave_layout,
|
|
1263
|
-
swizzle_mode,
|
|
1264
|
-
fill_mode,
|
|
1265
|
-
),
|
|
1266
|
-
tl.void,
|
|
1267
|
-
)
|
|
1268
|
-
|
|
1269
|
-
|
|
1270
|
-
def tensormap_fenceproxy_acquire(desc_ptr: tl.tensor, builder: ir.builder) -> tl.tensor:
|
|
1271
|
-
return tl.tensor(builder.create_tensormap_fenceproxy_acquire(desc_ptr.handle), tl.void)
|
|
1272
|
-
|
|
1273
|
-
|
|
1274
|
-
def _store_block_pointer(ptr, val, mask, boundary_check, cache, eviction, builder):
|
|
1275
|
-
# Store by a block pointer: `pointer_type<block_type<>>`
|
|
1276
|
-
# Block pointers can not have the `mask` argument
|
|
1277
|
-
if mask is not None:
|
|
1278
|
-
raise ValueError("`mask` and `other` arguments cannot be specified for loading block pointers")
|
|
1279
|
-
|
|
1280
|
-
# Check same shape and element type
|
|
1281
|
-
block_shape = ptr.type.element_ty.get_block_shapes()
|
|
1282
|
-
if not val.type.is_block():
|
|
1283
|
-
val = broadcast_impl_shape(val, block_shape, builder)
|
|
1284
|
-
assert val.type.is_block(), "Value argument must be block type or a scalar"
|
|
1285
|
-
assert block_shape == val.type.get_block_shapes(
|
|
1286
|
-
), f"Block shape({block_shape}) and value shape({val.type.get_block_shapes()}) mismatch"
|
|
1287
|
-
assert ptr.type.element_ty.element_ty == val.type.element_ty, f"Block element type({ptr.type.element_ty.element_ty}) and value element type({val.type.element_ty}) mismatch"
|
|
1288
|
-
|
|
1289
|
-
elt_ty = ptr.type.element_ty.element_ty
|
|
1290
|
-
assert elt_ty != tl.int1, "`tl.int1` should be rewritten in `tl.make_block_ptr`"
|
|
1291
|
-
|
|
1292
|
-
# Check `boundary_check` argument
|
|
1293
|
-
boundary_check = _canonicalize_boundary_check(boundary_check, block_shape)
|
|
1294
|
-
|
|
1295
|
-
# Cast to target data type
|
|
1296
|
-
val = cast(val, elt_ty, builder)
|
|
1297
|
-
|
|
1298
|
-
# Build IR
|
|
1299
|
-
return tl.tensor(builder.create_tensor_pointer_store(ptr.handle, val.handle, boundary_check, cache, eviction),
|
|
1300
|
-
tl.void)
|
|
1301
|
-
|
|
1302
|
-
|
|
1303
|
-
def _store_legacy(ptr, val, mask, boundary_check, cache, eviction, builder):
|
|
1304
|
-
# Store by a tensor of pointers or a pointer of scalar: `block_type<pointer_type<>>` or `pointer_type<>`
|
|
1305
|
-
if not ptr.type.scalar.is_ptr():
|
|
1306
|
-
raise ValueError(f"Unsupported ptr type {ptr.type.__repr__()} in `tl.store`")
|
|
1307
|
-
|
|
1308
|
-
# Check `boundary_check` argument
|
|
1309
|
-
if boundary_check:
|
|
1310
|
-
raise ValueError("`boundary_check` argument is not supported for storing a tensor of pointers or storing a "
|
|
1311
|
-
"scalar. Because the compiler does not know the boundary; please use block pointers "
|
|
1312
|
-
"(defined by `make_block_ptr`) instead")
|
|
1313
|
-
|
|
1314
|
-
# For a pointer of scalar, check the type of `val` and `mask`
|
|
1315
|
-
if not ptr.type.is_block():
|
|
1316
|
-
if val.type.is_block():
|
|
1317
|
-
raise ValueError("Value argument cannot be block type if pointer argument is not a block")
|
|
1318
|
-
if mask and mask.type.is_block():
|
|
1319
|
-
raise ValueError("Mask argument cannot be block type if pointer argument is not a block")
|
|
1320
|
-
|
|
1321
|
-
# Make `mask` and `val` into the same shape as `ptr`
|
|
1322
|
-
if ptr.type.is_block():
|
|
1323
|
-
val = broadcast_impl_shape(val, ptr.type.get_block_shapes(), builder)
|
|
1085
|
+
# Load by a tensor of pointers or a pointer of scalar: `block_type<pointer_type<>>` or `pointer_type<>`
|
|
1086
|
+
return self._load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile)
|
|
1087
|
+
|
|
1088
|
+
def descriptor_load(self, desc: tl.tensor_descriptor_base, offsets, cache_modifier: str,
|
|
1089
|
+
eviction_policy: str) -> TensorTy:
|
|
1090
|
+
assert isinstance(desc, tl.tensor_descriptor_base)
|
|
1091
|
+
ndim = len(desc.block_shape)
|
|
1092
|
+
assert len(offsets) == ndim, f"expected {ndim} offsets, but got {len(offsets)}"
|
|
1093
|
+
|
|
1094
|
+
offsets = self._convert_to_ir_values(offsets, require_i64=False)
|
|
1095
|
+
x = self.builder.create_descriptor_load(desc.handle, offsets, self._str_to_load_cache_modifier(cache_modifier),
|
|
1096
|
+
self._str_to_eviction_policy(eviction_policy))
|
|
1097
|
+
return self.tensor(x, desc.block_type)
|
|
1098
|
+
|
|
1099
|
+
def validate_store_like(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> None:
|
|
1100
|
+
assert isinstance(desc, tl.tensor_descriptor_base)
|
|
1101
|
+
ndim = len(desc.block_shape)
|
|
1102
|
+
assert len(offsets) == ndim, f"expected {ndim} offsets, but got {len(offsets)}"
|
|
1103
|
+
assert value.shape == desc.block_shape
|
|
1104
|
+
|
|
1105
|
+
def descriptor_store(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> TensorTy:
|
|
1106
|
+
self.validate_store_like(desc, value, offsets)
|
|
1107
|
+
offsets = self._convert_to_ir_values(offsets, require_i64=False)
|
|
1108
|
+
return self.tensor(self.builder.create_descriptor_store(desc.handle, value.handle, offsets), tl.void)
|
|
1109
|
+
|
|
1110
|
+
def descriptor_atomic_add(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> TensorTy:
|
|
1111
|
+
self.validate_store_like(desc, value, offsets)
|
|
1112
|
+
assert desc.dtype in {tl.uint32, tl.int32, tl.uint64, tl.float32, tl.float16, tl.bfloat16}, "Unsupported dtype"
|
|
1113
|
+
offsets = self._convert_to_ir_values(offsets, require_i64=False)
|
|
1114
|
+
kind = ir.DESCRIPTOR_REDUCE_KIND.ADD
|
|
1115
|
+
return self.tensor(self.builder.create_descriptor_reduce(kind, desc.handle, value.handle, offsets), tl.void)
|
|
1116
|
+
|
|
1117
|
+
def _has_native_tma(self, ):
|
|
1118
|
+
target = driver.active.get_current_target()
|
|
1119
|
+
return (target.backend == "cuda" and target.arch >= 90)
|
|
1120
|
+
|
|
1121
|
+
def _descriptor_atomic_min_max_supported(self, dtype):
|
|
1122
|
+
assert dtype in {tl.uint32, tl.int32, tl.uint64, tl.int64, tl.float16, tl.bfloat16}, "Unsupported dtype"
|
|
1123
|
+
if dtype in {tl.float16, tl.bfloat16}:
|
|
1124
|
+
assert self._has_native_tma(), "16-bit float types require native tma support"
|
|
1125
|
+
|
|
1126
|
+
def descriptor_atomic_min(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> TensorTy:
|
|
1127
|
+
self.validate_store_like(desc, value, offsets)
|
|
1128
|
+
self._descriptor_atomic_min_max_supported(desc.dtype)
|
|
1129
|
+
offsets = self._convert_to_ir_values(offsets, require_i64=False)
|
|
1130
|
+
kind = ir.DESCRIPTOR_REDUCE_KIND.MIN
|
|
1131
|
+
return self.tensor(self.builder.create_descriptor_reduce(kind, desc.handle, value.handle, offsets), tl.void)
|
|
1132
|
+
|
|
1133
|
+
def descriptor_atomic_max(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> TensorTy:
|
|
1134
|
+
self.validate_store_like(desc, value, offsets)
|
|
1135
|
+
self._descriptor_atomic_min_max_supported(desc.dtype)
|
|
1136
|
+
offsets = self._convert_to_ir_values(offsets, require_i64=False)
|
|
1137
|
+
kind = ir.DESCRIPTOR_REDUCE_KIND.MAX
|
|
1138
|
+
return self.tensor(self.builder.create_descriptor_reduce(kind, desc.handle, value.handle, offsets), tl.void)
|
|
1139
|
+
|
|
1140
|
+
def descriptor_atomic_and(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> TensorTy:
|
|
1141
|
+
self.validate_store_like(desc, value, offsets)
|
|
1142
|
+
assert desc.dtype in {tl.uint32, tl.int32, tl.uint64, tl.int64}, "Unsupported dtype"
|
|
1143
|
+
offsets = self._convert_to_ir_values(offsets, require_i64=False)
|
|
1144
|
+
kind = ir.DESCRIPTOR_REDUCE_KIND.AND
|
|
1145
|
+
return self.tensor(self.builder.create_descriptor_reduce(kind, desc.handle, value.handle, offsets), tl.void)
|
|
1146
|
+
|
|
1147
|
+
def descriptor_atomic_or(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> TensorTy:
|
|
1148
|
+
self.validate_store_like(desc, value, offsets)
|
|
1149
|
+
assert desc.dtype in {tl.uint32, tl.int32, tl.uint64, tl.int64}, "Unsupported dtype"
|
|
1150
|
+
offsets = self._convert_to_ir_values(offsets, require_i64=False)
|
|
1151
|
+
kind = ir.DESCRIPTOR_REDUCE_KIND.OR
|
|
1152
|
+
return self.tensor(self.builder.create_descriptor_reduce(kind, desc.handle, value.handle, offsets), tl.void)
|
|
1153
|
+
|
|
1154
|
+
def descriptor_atomic_xor(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> TensorTy:
|
|
1155
|
+
self.validate_store_like(desc, value, offsets)
|
|
1156
|
+
assert desc.dtype in {tl.uint32, tl.int32, tl.uint64, tl.int64}, "Unsupported dtype"
|
|
1157
|
+
offsets = self._convert_to_ir_values(offsets, require_i64=False)
|
|
1158
|
+
kind = ir.DESCRIPTOR_REDUCE_KIND.XOR
|
|
1159
|
+
return self.tensor(self.builder.create_descriptor_reduce(kind, desc.handle, value.handle, offsets), tl.void)
|
|
1160
|
+
|
|
1161
|
+
def descriptor_gather(self, desc, x_offsets, y_offset, cache_modifier: str, eviction_policy: str) -> TensorTy:
|
|
1162
|
+
assert isinstance(desc, tl.tensor_descriptor_base)
|
|
1163
|
+
assert cache_modifier == "", "cache modifier is not supported yet"
|
|
1164
|
+
assert eviction_policy == "", "eviction policy is not supported yet"
|
|
1165
|
+
|
|
1166
|
+
# Validate descriptor.
|
|
1167
|
+
assert len(desc.block_shape) == 2, f"descriptor must be 2D, but got {desc.block_shape}"
|
|
1168
|
+
assert desc.block_shape[0] == 1, f"descriptor block must have 1 row, but got {desc.block_shape}"
|
|
1169
|
+
|
|
1170
|
+
# Validate offsets.
|
|
1171
|
+
assert len(x_offsets.shape) == 1, f"x offsets must be 1D, but got {x_offsets.shape}"
|
|
1172
|
+
|
|
1173
|
+
# Validate minimum block size.
|
|
1174
|
+
assert x_offsets.shape[0] >= 8, f"descriptor gather must have at least 8 rows, but got {x_offsets.shape}"
|
|
1175
|
+
dtype = desc.dtype
|
|
1176
|
+
min_cols = 32 // dtype.primitive_bitwidth * 8
|
|
1177
|
+
assert desc.block_shape[
|
|
1178
|
+
1] >= min_cols, f"descriptor gather of {dtype} must have at least {min_cols} columns, but got {desc.block_shape[1]}"
|
|
1179
|
+
|
|
1180
|
+
type = tl.block_type(desc.dtype, [x_offsets.shape[0], desc.block_shape[1]])
|
|
1181
|
+
y_offset = self._convert_to_ir_values((y_offset, ), require_i64=False)[0]
|
|
1182
|
+
x = self.builder.create_descriptor_gather(desc.handle, x_offsets.handle, y_offset, type.to_ir(self.builder))
|
|
1183
|
+
return self.tensor(x, type)
|
|
1184
|
+
|
|
1185
|
+
def descriptor_scatter(self, desc, value: TensorTy, x_offsets, y_offset) -> TensorTy:
|
|
1186
|
+
assert isinstance(desc, tl.tensor_descriptor_base)
|
|
1187
|
+
|
|
1188
|
+
# Validate descriptor.
|
|
1189
|
+
assert len(desc.block_shape) == 2, f"descriptor must be 2D, but got {desc.block_shape}"
|
|
1190
|
+
assert desc.block_shape[0] == 1, f"descriptor block must have 1 row, but got {desc.block_shape}"
|
|
1191
|
+
|
|
1192
|
+
# Validate offsets.
|
|
1193
|
+
assert len(x_offsets.shape) == 1, f"x offsets must be 1D, but got {x_offsets.shapae}"
|
|
1194
|
+
|
|
1195
|
+
# Validate minimum block size.
|
|
1196
|
+
assert x_offsets.shape[0] >= 8, f"descriptor scatter must have at least 8 rows, but got {x_offsets.shape}"
|
|
1197
|
+
dtype = desc.dtype
|
|
1198
|
+
min_cols = 32 // dtype.primitive_bitwidth * 8
|
|
1199
|
+
assert desc.block_shape[
|
|
1200
|
+
1] >= min_cols, f"descriptor scatter of {dtype} must have at least {min_cols} columns, but got {desc.block_shape[1]}"
|
|
1201
|
+
|
|
1202
|
+
y_offset = self._convert_to_ir_values((y_offset, ), require_i64=False)[0]
|
|
1203
|
+
self.builder.create_descriptor_scatter(desc.handle, value.handle, x_offsets.handle, y_offset)
|
|
1204
|
+
return self.tensor(None, tl.void)
|
|
1205
|
+
|
|
1206
|
+
def _store_block_pointer(self, ptr, val, mask, boundary_check, cache, eviction):
|
|
1207
|
+
# Store by a block pointer: `pointer_type<block_type<>>`
|
|
1208
|
+
# Block pointers can not have the `mask` argument
|
|
1324
1209
|
if mask is not None:
|
|
1325
|
-
|
|
1210
|
+
raise ValueError("`mask` and `other` arguments cannot be specified for loading block pointers")
|
|
1326
1211
|
|
|
1327
|
-
|
|
1328
|
-
|
|
1212
|
+
# Check same shape and element type
|
|
1213
|
+
block_shape = ptr.type.element_ty.get_block_shapes()
|
|
1214
|
+
if not val.type.is_block():
|
|
1215
|
+
val = self.broadcast_impl_shape(val, block_shape)
|
|
1216
|
+
assert val.type.is_block(), "Value argument must be block type or a scalar"
|
|
1217
|
+
assert block_shape == val.type.get_block_shapes(
|
|
1218
|
+
), f"Block shape({block_shape}) and value shape({val.type.get_block_shapes()}) mismatch"
|
|
1219
|
+
assert ptr.type.element_ty.element_ty == val.type.element_ty, f"Block element type({ptr.type.element_ty.element_ty}) and value element type({val.type.element_ty}) mismatch"
|
|
1329
1220
|
|
|
1330
|
-
|
|
1331
|
-
|
|
1332
|
-
elt_ty = tl.int8
|
|
1333
|
-
ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space)
|
|
1334
|
-
ptr = cast(ptr, ptr_ty, builder)
|
|
1221
|
+
elt_ty = ptr.type.element_ty.element_ty
|
|
1222
|
+
assert elt_ty != tl.int1, "`tl.int1` should be rewritten in `tl.make_block_ptr`"
|
|
1335
1223
|
|
|
1336
|
-
|
|
1337
|
-
|
|
1224
|
+
# Check `boundary_check` argument
|
|
1225
|
+
boundary_check = self._canonicalize_boundary_check(boundary_check, block_shape)
|
|
1338
1226
|
|
|
1339
|
-
|
|
1340
|
-
|
|
1341
|
-
return tl.tensor(builder.create_store(ptr.handle, val.handle, cache, eviction), tl.void)
|
|
1342
|
-
if not mask.type.scalar.is_bool():
|
|
1343
|
-
raise ValueError("Mask must have boolean scalar type")
|
|
1344
|
-
return tl.tensor(builder.create_masked_store(ptr.handle, val.handle, mask.handle, cache, eviction), tl.void)
|
|
1227
|
+
# Cast to target data type
|
|
1228
|
+
val = self.cast(val, elt_ty)
|
|
1345
1229
|
|
|
1230
|
+
# Build IR
|
|
1231
|
+
return self.tensor(
|
|
1232
|
+
self.builder.create_tensor_pointer_store(ptr.handle, val.handle, boundary_check, cache, eviction), tl.void)
|
|
1346
1233
|
|
|
1347
|
-
def
|
|
1348
|
-
eviction_policy: str, builder: ir.builder) -> tl.tensor:
|
|
1349
|
-
# Cache and eviction options
|
|
1350
|
-
cache = _str_to_store_cache_modifier(cache_modifier)
|
|
1351
|
-
eviction = _str_to_eviction_policy(eviction_policy)
|
|
1352
|
-
|
|
1353
|
-
if ptr.type.is_const() or ptr.type.scalar.is_const():
|
|
1354
|
-
raise ValueError("Cannot store to a constant pointer")
|
|
1355
|
-
|
|
1356
|
-
if ptr.type.is_ptr() and ptr.type.element_ty.is_block():
|
|
1357
|
-
# Store by a block pointer: `pointer_type<block_type<>>`
|
|
1358
|
-
return _store_block_pointer(ptr, val, mask, boundary_check, cache, eviction, builder)
|
|
1359
|
-
else:
|
|
1234
|
+
def _store_legacy(self, ptr, val, mask, boundary_check, cache, eviction):
|
|
1360
1235
|
# Store by a tensor of pointers or a pointer of scalar: `block_type<pointer_type<>>` or `pointer_type<>`
|
|
1361
|
-
|
|
1362
|
-
|
|
1236
|
+
if not ptr.type.scalar.is_ptr():
|
|
1237
|
+
raise ValueError(f"Unsupported ptr type {ptr.type.__repr__()} in `tl.store`")
|
|
1238
|
+
|
|
1239
|
+
# Check `boundary_check` argument
|
|
1240
|
+
if boundary_check:
|
|
1241
|
+
raise ValueError("`boundary_check` argument is not supported for storing a tensor of pointers or storing a "
|
|
1242
|
+
"scalar. Because the compiler does not know the boundary; please use block pointers "
|
|
1243
|
+
"(defined by `make_block_ptr`) instead")
|
|
1244
|
+
|
|
1245
|
+
# For a pointer of scalar, check the type of `val` and `mask`
|
|
1246
|
+
if not ptr.type.is_block():
|
|
1247
|
+
if val.type.is_block():
|
|
1248
|
+
raise ValueError("Value argument cannot be block type if pointer argument is not a block")
|
|
1249
|
+
if mask and mask.type.is_block():
|
|
1250
|
+
raise ValueError("Mask argument cannot be block type if pointer argument is not a block")
|
|
1251
|
+
|
|
1252
|
+
# Make `mask` and `val` into the same shape as `ptr`
|
|
1253
|
+
if ptr.type.is_block():
|
|
1254
|
+
val = self.broadcast_impl_shape(val, ptr.type.get_block_shapes())
|
|
1255
|
+
if mask is not None:
|
|
1256
|
+
mask = self.broadcast_impl_shape(mask, ptr.type.get_block_shapes())
|
|
1257
|
+
|
|
1258
|
+
ptr_ty = ptr.type.scalar
|
|
1259
|
+
elt_ty = ptr_ty.element_ty
|
|
1260
|
+
|
|
1261
|
+
# Treat `pointer_type<tl.int1>` as `pointer_type<tl.int8>`
|
|
1262
|
+
if elt_ty == tl.int1:
|
|
1263
|
+
elt_ty = tl.int8
|
|
1264
|
+
ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space)
|
|
1265
|
+
ptr = self.cast(ptr, ptr_ty)
|
|
1266
|
+
|
|
1267
|
+
# Cast to target data type
|
|
1268
|
+
val = self.cast(val, elt_ty)
|
|
1269
|
+
|
|
1270
|
+
# Build IR
|
|
1271
|
+
if mask is None:
|
|
1272
|
+
return self.tensor(self.builder.create_store(ptr.handle, val.handle, cache, eviction), tl.void)
|
|
1273
|
+
if not mask.type.scalar.is_bool():
|
|
1274
|
+
raise ValueError("Mask must have boolean scalar type")
|
|
1275
|
+
return self.tensor(self.builder.create_masked_store(ptr.handle, val.handle, mask.handle, cache, eviction),
|
|
1276
|
+
tl.void)
|
|
1277
|
+
|
|
1278
|
+
def store(self, ptr: TensorTy, val: TensorTy, mask: Optional[TensorTy], boundary_check, cache_modifier: str,
|
|
1279
|
+
eviction_policy: str) -> TensorTy:
|
|
1280
|
+
# Cache and eviction options
|
|
1281
|
+
cache = self._str_to_store_cache_modifier(cache_modifier)
|
|
1282
|
+
eviction = self._str_to_eviction_policy(eviction_policy)
|
|
1283
|
+
|
|
1284
|
+
if ptr.type.is_const() or ptr.type.scalar.is_const():
|
|
1285
|
+
raise ValueError("Cannot store to a constant pointer")
|
|
1286
|
+
|
|
1287
|
+
if ptr.type.is_ptr() and ptr.type.element_ty.is_block():
|
|
1288
|
+
# Store by a block pointer: `pointer_type<block_type<>>`
|
|
1289
|
+
return self._store_block_pointer(ptr, val, mask, boundary_check, cache, eviction)
|
|
1290
|
+
else:
|
|
1291
|
+
# Store by a tensor of pointers or a pointer of scalar: `block_type<pointer_type<>>` or `pointer_type<>`
|
|
1292
|
+
return self._store_legacy(ptr, val, mask, boundary_check, cache, eviction)
|
|
1363
1293
|
|
|
1364
1294
|
#########
|
|
1365
1295
|
# atomic
|
|
1366
1296
|
#########
|
|
1367
1297
|
|
|
1368
|
-
|
|
1369
|
-
|
|
1370
|
-
|
|
1371
|
-
|
|
1372
|
-
|
|
1373
|
-
|
|
1374
|
-
|
|
1375
|
-
|
|
1376
|
-
|
|
1377
|
-
|
|
1378
|
-
|
|
1379
|
-
|
|
1380
|
-
|
|
1381
|
-
|
|
1382
|
-
|
|
1383
|
-
|
|
1384
|
-
|
|
1385
|
-
|
|
1386
|
-
|
|
1387
|
-
|
|
1388
|
-
|
|
1389
|
-
if ptr.type.is_block():
|
|
1390
|
-
if mask is not None:
|
|
1391
|
-
mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder)
|
|
1392
|
-
if val is not None:
|
|
1393
|
-
val = broadcast_impl_shape(val, ptr.type.get_block_shapes(), builder)
|
|
1394
|
-
val = cast(val, ptr.type.scalar.element_ty, builder)
|
|
1395
|
-
if mask is None:
|
|
1396
|
-
mask_ir = builder.get_int1(True)
|
|
1397
|
-
mask_ty = tl.int1
|
|
1298
|
+
def atomic_cas(self, ptr: TensorTy, cmp: TensorTy, val: TensorTy, sem: str, scope: str) -> TensorTy:
|
|
1299
|
+
sem = self._str_to_sem(sem)
|
|
1300
|
+
scope = self._str_to_scope(scope)
|
|
1301
|
+
element_ty = ptr.type.scalar.element_ty
|
|
1302
|
+
if element_ty.primitive_bitwidth not in [16, 32, 64]:
|
|
1303
|
+
raise ValueError("atomic_cas only supports elements with width {16, 32, 64}")
|
|
1304
|
+
return self.tensor(self.builder.create_atomic_cas(ptr.handle, cmp.handle, val.handle, sem, scope), val.type)
|
|
1305
|
+
|
|
1306
|
+
def atom_red_typechecking_impl(self, ptr: TensorTy, val: TensorTy, mask: TensorTy,
|
|
1307
|
+
op: str) -> Tuple[TensorTy, TensorTy, TensorTy]:
|
|
1308
|
+
if not ptr.type.scalar.is_ptr():
|
|
1309
|
+
raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__())
|
|
1310
|
+
if ptr.type.is_const() or ptr.type.element_ty.is_const():
|
|
1311
|
+
raise ValueError("Cannot store to a constant pointer")
|
|
1312
|
+
element_ty = ptr.type.scalar.element_ty
|
|
1313
|
+
if element_ty is tl.float16 and op != 'add':
|
|
1314
|
+
raise ValueError("atomic_" + op + " does not support fp16")
|
|
1315
|
+
if element_ty is tl.bfloat16 and op != 'add':
|
|
1316
|
+
raise ValueError("atomic_" + op + " does not support bf16")
|
|
1317
|
+
if element_ty in [tl.int16, tl.uint16] or element_ty.primitive_bitwidth < 16:
|
|
1318
|
+
raise ValueError("atomic_" + op + " does not support " + str(element_ty))
|
|
1398
1319
|
if ptr.type.is_block():
|
|
1399
|
-
|
|
1400
|
-
|
|
1401
|
-
|
|
1402
|
-
|
|
1403
|
-
|
|
1404
|
-
|
|
1405
|
-
|
|
1406
|
-
|
|
1407
|
-
|
|
1408
|
-
|
|
1409
|
-
|
|
1410
|
-
|
|
1411
|
-
|
|
1412
|
-
|
|
1413
|
-
|
|
1414
|
-
|
|
1415
|
-
|
|
1416
|
-
|
|
1417
|
-
|
|
1418
|
-
|
|
1419
|
-
|
|
1420
|
-
|
|
1421
|
-
|
|
1422
|
-
|
|
1423
|
-
|
|
1424
|
-
|
|
1425
|
-
|
|
1426
|
-
|
|
1427
|
-
|
|
1428
|
-
|
|
1429
|
-
|
|
1430
|
-
|
|
1431
|
-
|
|
1432
|
-
|
|
1433
|
-
|
|
1434
|
-
|
|
1435
|
-
|
|
1436
|
-
|
|
1437
|
-
|
|
1438
|
-
|
|
1439
|
-
|
|
1440
|
-
|
|
1441
|
-
|
|
1442
|
-
|
|
1443
|
-
|
|
1444
|
-
|
|
1445
|
-
|
|
1446
|
-
|
|
1447
|
-
|
|
1448
|
-
|
|
1449
|
-
|
|
1450
|
-
|
|
1451
|
-
|
|
1452
|
-
|
|
1453
|
-
|
|
1454
|
-
|
|
1455
|
-
|
|
1456
|
-
|
|
1457
|
-
|
|
1458
|
-
|
|
1459
|
-
|
|
1460
|
-
|
|
1461
|
-
|
|
1462
|
-
|
|
1463
|
-
|
|
1464
|
-
|
|
1465
|
-
|
|
1466
|
-
|
|
1467
|
-
|
|
1468
|
-
|
|
1469
|
-
|
|
1470
|
-
|
|
1471
|
-
|
|
1472
|
-
|
|
1473
|
-
|
|
1474
|
-
|
|
1475
|
-
|
|
1476
|
-
|
|
1477
|
-
|
|
1478
|
-
|
|
1479
|
-
|
|
1480
|
-
|
|
1481
|
-
|
|
1482
|
-
|
|
1483
|
-
|
|
1484
|
-
|
|
1485
|
-
|
|
1486
|
-
|
|
1487
|
-
|
|
1488
|
-
|
|
1489
|
-
|
|
1490
|
-
|
|
1491
|
-
|
|
1492
|
-
|
|
1493
|
-
|
|
1494
|
-
|
|
1495
|
-
|
|
1496
|
-
|
|
1497
|
-
|
|
1498
|
-
|
|
1499
|
-
|
|
1500
|
-
|
|
1501
|
-
|
|
1502
|
-
|
|
1503
|
-
|
|
1504
|
-
|
|
1505
|
-
|
|
1506
|
-
|
|
1507
|
-
|
|
1508
|
-
|
|
1509
|
-
|
|
1510
|
-
|
|
1511
|
-
|
|
1512
|
-
|
|
1513
|
-
|
|
1514
|
-
|
|
1515
|
-
|
|
1516
|
-
|
|
1517
|
-
|
|
1518
|
-
|
|
1519
|
-
sem
|
|
1520
|
-
|
|
1521
|
-
|
|
1522
|
-
|
|
1523
|
-
|
|
1320
|
+
if mask is not None:
|
|
1321
|
+
mask = self.broadcast_impl_shape(mask, ptr.type.get_block_shapes())
|
|
1322
|
+
if val is not None:
|
|
1323
|
+
val = self.broadcast_impl_shape(val, ptr.type.get_block_shapes())
|
|
1324
|
+
val = self.cast(val, ptr.type.scalar.element_ty)
|
|
1325
|
+
if mask is None:
|
|
1326
|
+
mask_ir = self.builder.get_int1(True)
|
|
1327
|
+
mask_ty = tl.int1
|
|
1328
|
+
if ptr.type.is_block():
|
|
1329
|
+
mask_ty = ptr.type.with_element_ty(tl.int1)
|
|
1330
|
+
mask_ir = self.builder.create_splat(mask_ty.to_ir(self.builder), mask_ir)
|
|
1331
|
+
mask = self.tensor(mask_ir, mask_ty)
|
|
1332
|
+
return ptr, val, mask
|
|
1333
|
+
|
|
1334
|
+
def _signbit(self, x: TensorTy) -> TensorTy:
|
|
1335
|
+
bitwidth = x.dtype.primitive_bitwidth
|
|
1336
|
+
idtype = tl.get_int_dtype(bitwidth=bitwidth, signed=False)
|
|
1337
|
+
ix = self.bitcast(x, idtype)
|
|
1338
|
+
signbit = self.lshr(ix, bitwidth - 1)
|
|
1339
|
+
return self.cast(signbit, tl.int1)
|
|
1340
|
+
|
|
1341
|
+
def atomic_max(self, ptr: TensorTy, val: TensorTy, mask: TensorTy, sem: str, scope: str) -> TensorTy:
|
|
1342
|
+
ptr, val, mask = self.atom_red_typechecking_impl(ptr, val, mask, 'max')
|
|
1343
|
+
sem = self._str_to_sem(sem)
|
|
1344
|
+
scope = self._str_to_scope(scope)
|
|
1345
|
+
sca_ty = val.type.scalar
|
|
1346
|
+
# direct call to atomic_max for integers
|
|
1347
|
+
if sca_ty.is_int():
|
|
1348
|
+
if sca_ty.is_int_signed():
|
|
1349
|
+
return self.tensor(
|
|
1350
|
+
self.builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, ptr.handle, val.handle, mask.handle, sem, scope),
|
|
1351
|
+
val.type)
|
|
1352
|
+
else:
|
|
1353
|
+
return self.tensor(
|
|
1354
|
+
self.builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, ptr.handle, val.handle, mask.handle, sem, scope),
|
|
1355
|
+
val.type)
|
|
1356
|
+
# for float
|
|
1357
|
+
# return atomic_smax(i_ptr, i_val) if val >= 0
|
|
1358
|
+
# return atomic_umin(i_ptr, i_val) if val < 0
|
|
1359
|
+
if sca_ty not in {tl.float32, tl.float64}:
|
|
1360
|
+
raise TypeError(f"atomic_max not supported for dtype {sca_ty}")
|
|
1361
|
+
|
|
1362
|
+
i_type = tl.int32 if sca_ty == tl.float32 else tl.int64
|
|
1363
|
+
i_val = self.bitcast(val, i_type)
|
|
1364
|
+
i_ptr = self.bitcast(ptr, tl.pointer_type(i_type, 1))
|
|
1365
|
+
ui_type = tl.uint32 if sca_ty == tl.float32 else tl.uint64
|
|
1366
|
+
ui_val = self.bitcast(val, ui_type)
|
|
1367
|
+
ui_ptr = self.bitcast(ptr, tl.pointer_type(ui_type, 1))
|
|
1368
|
+
neg = self._signbit(val)
|
|
1369
|
+
pos = self.not_(neg)
|
|
1370
|
+
pos_ret = self.tensor(
|
|
1371
|
+
self.builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, i_ptr.handle, i_val.handle,
|
|
1372
|
+
self.and_(mask, pos).handle, sem, scope), i_val.type)
|
|
1373
|
+
neg_ret = self.tensor(
|
|
1374
|
+
self.builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, ui_ptr.handle, ui_val.handle,
|
|
1375
|
+
self.and_(mask, neg).handle, sem, scope), ui_val.type)
|
|
1376
|
+
ret = self.where(pos, pos_ret, neg_ret)
|
|
1377
|
+
return self.bitcast(ret, sca_ty)
|
|
1378
|
+
|
|
1379
|
+
def atomic_min(self, ptr: TensorTy, val: TensorTy, mask: TensorTy, sem: str, scope: str) -> TensorTy:
|
|
1380
|
+
ptr, val, mask = self.atom_red_typechecking_impl(ptr, val, mask, 'min')
|
|
1381
|
+
sem = self._str_to_sem(sem)
|
|
1382
|
+
scope = self._str_to_scope(scope)
|
|
1383
|
+
sca_ty = val.type.scalar
|
|
1384
|
+
# direct call to atomic_min for integers
|
|
1385
|
+
if sca_ty.is_int():
|
|
1386
|
+
if sca_ty.is_int_signed():
|
|
1387
|
+
return self.tensor(
|
|
1388
|
+
self.builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, ptr.handle, val.handle, mask.handle, sem, scope),
|
|
1389
|
+
val.type)
|
|
1390
|
+
else:
|
|
1391
|
+
return self.tensor(
|
|
1392
|
+
self.builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, ptr.handle, val.handle, mask.handle, sem, scope),
|
|
1393
|
+
val.type)
|
|
1394
|
+
# for float
|
|
1395
|
+
# return atomic_smin(i_ptr, i_val) if val >= 0
|
|
1396
|
+
# return atomic_umax(i_ptr, i_val) if val < 0
|
|
1397
|
+
if sca_ty not in {tl.float32, tl.float64}:
|
|
1398
|
+
raise TypeError(f"atomic_min not supported for dtype {sca_ty}")
|
|
1399
|
+
|
|
1400
|
+
i_type = tl.int32 if sca_ty == tl.float32 else tl.int64
|
|
1401
|
+
i_val = self.bitcast(val, i_type)
|
|
1402
|
+
i_ptr = self.bitcast(ptr, tl.pointer_type(i_type, 1))
|
|
1403
|
+
ui_type = tl.uint32 if sca_ty == tl.float32 else tl.uint64
|
|
1404
|
+
ui_val = self.bitcast(val, ui_type)
|
|
1405
|
+
ui_ptr = self.bitcast(ptr, tl.pointer_type(ui_type, 1))
|
|
1406
|
+
neg = self._signbit(val)
|
|
1407
|
+
pos = self.not_(neg)
|
|
1408
|
+
pos_ret = self.tensor(
|
|
1409
|
+
self.builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, i_ptr.handle, i_val.handle,
|
|
1410
|
+
self.and_(mask, pos).handle, sem, scope), i_val.type)
|
|
1411
|
+
neg_ret = self.tensor(
|
|
1412
|
+
self.builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, ui_ptr.handle, ui_val.handle,
|
|
1413
|
+
self.and_(mask, neg).handle, sem, scope), ui_ptr.type)
|
|
1414
|
+
ret = self.where(pos, pos_ret, neg_ret)
|
|
1415
|
+
return self.bitcast(ret, sca_ty)
|
|
1416
|
+
|
|
1417
|
+
def atomic_add(self, ptr: TensorTy, val: TensorTy, mask: TensorTy, sem: str, scope: str) -> TensorTy:
|
|
1418
|
+
ptr, val, mask = self.atom_red_typechecking_impl(ptr, val, mask, 'add')
|
|
1419
|
+
sem = self._str_to_sem(sem)
|
|
1420
|
+
scope = self._str_to_scope(scope)
|
|
1421
|
+
sca_ty = val.type.scalar
|
|
1422
|
+
op = ir.ATOMIC_OP.FADD if sca_ty.is_floating() else ir.ATOMIC_OP.ADD
|
|
1423
|
+
return self.tensor(self.builder.create_atomic_rmw(op, ptr.handle, val.handle, mask.handle, sem, scope),
|
|
1424
|
+
val.type)
|
|
1425
|
+
|
|
1426
|
+
def atomic_and(self, ptr: TensorTy, val: TensorTy, mask: TensorTy, sem: str, scope: str) -> TensorTy:
|
|
1427
|
+
ptr, val, mask = self.atom_red_typechecking_impl(ptr, val, mask, 'and')
|
|
1428
|
+
sem = self._str_to_sem(sem)
|
|
1429
|
+
scope = self._str_to_scope(scope)
|
|
1430
|
+
return self.tensor(
|
|
1431
|
+
self.builder.create_atomic_rmw(ir.ATOMIC_OP.AND, ptr.handle, val.handle, mask.handle, sem, scope), val.type)
|
|
1432
|
+
|
|
1433
|
+
def atomic_or(self, ptr: TensorTy, val: TensorTy, mask: TensorTy, sem: str, scope: str) -> TensorTy:
|
|
1434
|
+
ptr, val, mask = self.atom_red_typechecking_impl(ptr, val, mask, 'or')
|
|
1435
|
+
sem = self._str_to_sem(sem)
|
|
1436
|
+
scope = self._str_to_scope(scope)
|
|
1437
|
+
return self.tensor(
|
|
1438
|
+
self.builder.create_atomic_rmw(ir.ATOMIC_OP.OR, ptr.handle, val.handle, mask.handle, sem, scope), val.type)
|
|
1439
|
+
|
|
1440
|
+
def atomic_xor(self, ptr: TensorTy, val: TensorTy, mask: TensorTy, sem: str, scope: str) -> TensorTy:
|
|
1441
|
+
ptr, val, mask = self.atom_red_typechecking_impl(ptr, val, mask, 'xor')
|
|
1442
|
+
sem = self._str_to_sem(sem)
|
|
1443
|
+
scope = self._str_to_scope(scope)
|
|
1444
|
+
return self.tensor(
|
|
1445
|
+
self.builder.create_atomic_rmw(ir.ATOMIC_OP.XOR, ptr.handle, val.handle, mask.handle, sem, scope), val.type)
|
|
1446
|
+
|
|
1447
|
+
def atomic_xchg(self, ptr: TensorTy, val: TensorTy, mask: TensorTy, sem: str, scope: str) -> TensorTy:
|
|
1448
|
+
ptr, val, mask = self.atom_red_typechecking_impl(ptr, val, mask, 'xchg')
|
|
1449
|
+
sem = self._str_to_sem(sem)
|
|
1450
|
+
scope = self._str_to_scope(scope)
|
|
1451
|
+
return self.tensor(
|
|
1452
|
+
self.builder.create_atomic_rmw(ir.ATOMIC_OP.XCHG, ptr.handle, val.handle, mask.handle, sem, scope),
|
|
1453
|
+
val.type)
|
|
1524
1454
|
|
|
1525
1455
|
# ===----------------------------------------------------------------------===//
|
|
1526
1456
|
# Linear Algebra
|
|
1527
1457
|
# ===----------------------------------------------------------------------===//
|
|
1528
1458
|
|
|
1459
|
+
def _str_to_dot_input_precision(self, input_precision):
|
|
1460
|
+
assert input_precision.lower() in self.builder.options.allowed_dot_input_precisions, \
|
|
1461
|
+
f"input_precision must be one of {self.builder.options.allowed_dot_input_precisions}. Got {input_precision}"
|
|
1462
|
+
input_precision = input_precision.upper()
|
|
1463
|
+
if input_precision == "TF32X3":
|
|
1464
|
+
input_precision = "TF32x3"
|
|
1465
|
+
return getattr(ir.INPUT_PRECISION, input_precision)
|
|
1466
|
+
|
|
1467
|
+
def dot(self, lhs: TensorTy, rhs: TensorTy, acc: TensorTy, input_precision: Optional[str],
|
|
1468
|
+
max_num_imprecise_acc: int, out_dtype: tl.dtype) -> TensorTy:
|
|
1469
|
+
assert lhs.type.is_block() and rhs.type.is_block()
|
|
1529
1470
|
|
|
1530
|
-
def _str_to_dot_input_precision(input_precision, builder):
|
|
1531
|
-
assert input_precision.lower() in builder.options.allowed_dot_input_precisions, \
|
|
1532
|
-
f"input_precision must be one of {builder.options.allowed_dot_input_precisions}. Got {input_precision}"
|
|
1533
|
-
input_precision = input_precision.upper()
|
|
1534
|
-
if input_precision == "TF32X3":
|
|
1535
|
-
input_precision = "TF32x3"
|
|
1536
|
-
return getattr(ir.INPUT_PRECISION, input_precision)
|
|
1537
|
-
|
|
1538
|
-
|
|
1539
|
-
def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, input_precision: Optional[str], max_num_imprecise_acc: int,
|
|
1540
|
-
out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor:
|
|
1541
|
-
assert lhs.type.is_block() and rhs.type.is_block()
|
|
1542
|
-
|
|
1543
|
-
if lhs.dtype.is_fp8() and rhs.dtype.is_fp8():
|
|
1544
|
-
# All combinations of supported fp8 x fp8 are permitted
|
|
1545
|
-
pass
|
|
1546
|
-
else:
|
|
1547
|
-
assert lhs.dtype in (tl.int8, tl.uint8, tl.float16, tl.bfloat16,
|
|
1548
|
-
tl.float32), f"Unsupported lhs dtype {lhs.dtype}"
|
|
1549
|
-
assert rhs.dtype in (tl.int8, tl.uint8, tl.float16, tl.bfloat16,
|
|
1550
|
-
tl.float32), f"Unsupported rhs dtype {rhs.dtype}"
|
|
1551
|
-
assert lhs.dtype == rhs.dtype, f"Both operands must be same dtype. Got {lhs.dtype} and {rhs.dtype}"
|
|
1552
|
-
|
|
1553
|
-
if lhs.dtype.is_fp8e4b15() or rhs.dtype.is_fp8e4b15():
|
|
1554
|
-
# We upcast because there's no fp8e4b15 type in MLIR
|
|
1555
|
-
lhs = cast(lhs, tl.float16, builder)
|
|
1556
|
-
rhs = cast(rhs, tl.float16, builder)
|
|
1557
|
-
|
|
1558
|
-
if input_precision is None:
|
|
1559
|
-
input_precision = builder.options.default_dot_input_precision
|
|
1560
|
-
|
|
1561
|
-
input_precision = _str_to_dot_input_precision(input_precision, builder)
|
|
1562
|
-
|
|
1563
|
-
lhs_rank = len(lhs.shape)
|
|
1564
|
-
rhs_rank = len(rhs.shape)
|
|
1565
|
-
assert lhs_rank == rhs_rank == 2 or lhs_rank == rhs_rank == 3, f"Both inputs must be either 2D or 3D; (lhs: {lhs.shape} vs rhs: {rhs.shape})"
|
|
1566
|
-
assert lhs.shape[-1].value == rhs.shape[
|
|
1567
|
-
-2].value, f"First input shape ({lhs.shape}) and second input shape {rhs.shape} are not compatible for matmul (second index of first shape ({lhs.shape[-1].value}) must be equal to first index of second shape ({rhs.shape[-2].value})"
|
|
1568
|
-
assert builder.codegen_fns.get("min_dot_size") is not None, "target doesn't provide lower shape bounds for dot."
|
|
1569
|
-
min_dot_size = builder.codegen_fns["min_dot_size"](lhs.type, rhs.type)
|
|
1570
|
-
assert lhs.shape[-2].value >= min_dot_size[0] and lhs.shape[-1].value >= min_dot_size[2] \
|
|
1571
|
-
and rhs.shape[-1].value >= min_dot_size[1], \
|
|
1572
|
-
f"Input shapes should have M >= {min_dot_size[0]}, N >= {min_dot_size[1]} and K >= {min_dot_size[2]}"
|
|
1573
|
-
if lhs.type.scalar.is_int():
|
|
1574
|
-
assert lhs.type.scalar == tl.int8, "only int8 supported!"
|
|
1575
|
-
_0 = builder.get_int32(0)
|
|
1576
|
-
ret_scalar_ty = tl.int32
|
|
1577
|
-
elif out_dtype.is_bf16():
|
|
1578
|
-
raise ValueError(
|
|
1579
|
-
"out_dtype=bfloat16 is unsupported. Please use out_dtype=float32/float16 and cast with `.to(tl.bfloat16)`")
|
|
1580
|
-
elif lhs.type.scalar.is_fp32() or lhs.type.scalar.is_bf16():
|
|
1581
|
-
_0 = builder.get_fp32(0)
|
|
1582
|
-
ret_scalar_ty = tl.float32
|
|
1583
|
-
else:
|
|
1584
|
-
_0 = builder.get_fp16(0) if out_dtype.is_fp16() else builder.get_fp32(0)
|
|
1585
|
-
ret_scalar_ty = out_dtype
|
|
1586
|
-
|
|
1587
|
-
M = lhs.type.shape[-2]
|
|
1588
|
-
N = rhs.type.shape[-1]
|
|
1589
|
-
K = lhs.type.shape[-1]
|
|
1590
|
-
B = lhs.type.shape[0] if lhs_rank == 3 else None
|
|
1591
|
-
ret_ty = tl.block_type(ret_scalar_ty, [B, M, N] if B else [M, N])
|
|
1592
|
-
if acc is None:
|
|
1593
|
-
acc_handle = builder.create_splat(_0, [B, M, N] if B else [M, N])
|
|
1594
|
-
else:
|
|
1595
|
-
acc_handle = acc.handle
|
|
1596
|
-
assert acc.type == ret_ty
|
|
1597
|
-
|
|
1598
|
-
# max_num_imprecise_acc only applies to fp8 -> fp32 dot on sm_90
|
|
1599
|
-
if max_num_imprecise_acc is None:
|
|
1600
1471
|
if lhs.dtype.is_fp8() and rhs.dtype.is_fp8():
|
|
1601
|
-
|
|
1472
|
+
# All combinations of supported fp8 x fp8 are permitted
|
|
1473
|
+
pass
|
|
1602
1474
|
else:
|
|
1603
|
-
|
|
1604
|
-
|
|
1605
|
-
|
|
1606
|
-
|
|
1607
|
-
|
|
1608
|
-
|
|
1609
|
-
|
|
1610
|
-
|
|
1611
|
-
|
|
1612
|
-
|
|
1613
|
-
|
|
1614
|
-
|
|
1615
|
-
|
|
1616
|
-
|
|
1617
|
-
|
|
1618
|
-
|
|
1619
|
-
|
|
1620
|
-
|
|
1621
|
-
|
|
1622
|
-
|
|
1623
|
-
|
|
1624
|
-
|
|
1625
|
-
|
|
1626
|
-
assert
|
|
1627
|
-
|
|
1628
|
-
|
|
1629
|
-
|
|
1630
|
-
|
|
1631
|
-
|
|
1632
|
-
|
|
1633
|
-
|
|
1634
|
-
|
|
1635
|
-
|
|
1636
|
-
|
|
1637
|
-
|
|
1638
|
-
|
|
1639
|
-
|
|
1640
|
-
|
|
1641
|
-
|
|
1642
|
-
|
|
1643
|
-
|
|
1644
|
-
|
|
1645
|
-
|
|
1646
|
-
|
|
1647
|
-
|
|
1648
|
-
|
|
1649
|
-
|
|
1650
|
-
|
|
1651
|
-
|
|
1652
|
-
|
|
1653
|
-
|
|
1654
|
-
|
|
1655
|
-
|
|
1656
|
-
|
|
1657
|
-
|
|
1658
|
-
|
|
1659
|
-
PACKED_A = 2 if lhs_format == "e2m1" else 1
|
|
1660
|
-
PACKED_B = 2 if rhs_format == "e2m1" else 1
|
|
1661
|
-
assert K * PACKED_B == PACKED_A * lhs.type.shape[
|
|
1662
|
-
-1], f"Reduction dimension should pack the same number of elements; (lhs: {lhs.shape} vs rhs: {rhs.shape})"
|
|
1663
|
-
#assert K * PACKED_B >= 64, f"scaled_dot NYI for K < 64. Got {K=}"
|
|
1664
|
-
B = lhs.type.shape[0] if lhs_rank == 3 else None
|
|
1665
|
-
|
|
1666
|
-
ret_ty = tl.block_type(out_dtype, [B, M, N] if B else [M, N])
|
|
1667
|
-
_0 = builder.get_fp32(0)
|
|
1668
|
-
if acc is None:
|
|
1669
|
-
acc_handle = builder.create_splat(_0, [B, M, N] if B else [M, N])
|
|
1670
|
-
else:
|
|
1671
|
-
acc_handle = acc.handle
|
|
1672
|
-
assert acc.type == ret_ty
|
|
1673
|
-
rhs_scale_handle = None if rhs_scale_is_none else rhs_scale.handle
|
|
1674
|
-
lhs_scale_handle = None if lhs_scale_is_none else lhs_scale.handle
|
|
1675
|
-
return tl.tensor(
|
|
1676
|
-
builder.create_dot_scaled(lhs.handle, lhs_scale_handle, lhs_format_enum, rhs.handle, rhs_scale_handle,
|
|
1677
|
-
rhs_format_enum, fast_math, acc_handle), ret_ty)
|
|
1475
|
+
assert lhs.dtype in (tl.int8, tl.uint8, tl.float16, tl.bfloat16,
|
|
1476
|
+
tl.float32), f"Unsupported lhs dtype {lhs.dtype}"
|
|
1477
|
+
assert rhs.dtype in (tl.int8, tl.uint8, tl.float16, tl.bfloat16,
|
|
1478
|
+
tl.float32), f"Unsupported rhs dtype {rhs.dtype}"
|
|
1479
|
+
assert lhs.dtype == rhs.dtype, f"Both operands must be same dtype. Got {lhs.dtype} and {rhs.dtype}"
|
|
1480
|
+
|
|
1481
|
+
if lhs.dtype.is_fp8e4b15() or rhs.dtype.is_fp8e4b15():
|
|
1482
|
+
if "fp8e4b15" in self.builder.options.deprecated_fp8_dot_operand_dtypes:
|
|
1483
|
+
warnings.warn(
|
|
1484
|
+
"the use of fp8e4b15 is deprecated on Hopper and later architectures and can cause significant slow down. It will be removed in a future triton release"
|
|
1485
|
+
)
|
|
1486
|
+
# We upcast because there's no fp8e4b15 type in MLIR
|
|
1487
|
+
lhs = self.cast(lhs, tl.float16)
|
|
1488
|
+
rhs = self.cast(rhs, tl.float16)
|
|
1489
|
+
|
|
1490
|
+
if input_precision is None:
|
|
1491
|
+
input_precision = self.builder.options.default_dot_input_precision
|
|
1492
|
+
|
|
1493
|
+
input_precision = self._str_to_dot_input_precision(input_precision)
|
|
1494
|
+
|
|
1495
|
+
lhs_rank = len(lhs.shape)
|
|
1496
|
+
rhs_rank = len(rhs.shape)
|
|
1497
|
+
assert lhs_rank == rhs_rank == 2 or lhs_rank == rhs_rank == 3, f"Both inputs must be either 2D or 3D; (lhs: {lhs.shape} vs rhs: {rhs.shape})"
|
|
1498
|
+
assert lhs.shape[-1].value == rhs.shape[
|
|
1499
|
+
-2].value, f"First input shape ({lhs.shape}) and second input shape {rhs.shape} are not compatible for matmul (second index of first shape ({lhs.shape[-1].value}) must be equal to first index of second shape ({rhs.shape[-2].value})"
|
|
1500
|
+
assert self.builder.codegen_fns.get(
|
|
1501
|
+
"min_dot_size") is not None, "target doesn't provide lower shape bounds for dot."
|
|
1502
|
+
min_dot_size = self.builder.codegen_fns["min_dot_size"](lhs.type, rhs.type)
|
|
1503
|
+
assert lhs.shape[-2].value >= min_dot_size[0] and lhs.shape[-1].value >= min_dot_size[2] \
|
|
1504
|
+
and rhs.shape[-1].value >= min_dot_size[1], \
|
|
1505
|
+
f"Input shapes should have M >= {min_dot_size[0]}, N >= {min_dot_size[1]} and K >= {min_dot_size[2]}"
|
|
1506
|
+
if lhs.type.scalar.is_int():
|
|
1507
|
+
assert lhs.type.scalar == tl.int8, "only int8 supported!"
|
|
1508
|
+
_0 = self.builder.get_int32(0)
|
|
1509
|
+
ret_scalar_ty = tl.int32
|
|
1510
|
+
elif out_dtype.is_bf16():
|
|
1511
|
+
raise ValueError(
|
|
1512
|
+
"out_dtype=bfloat16 is unsupported. Please use out_dtype=float32/float16 and cast with `.to(tl.bfloat16)`"
|
|
1513
|
+
)
|
|
1514
|
+
elif lhs.type.scalar.is_fp32() or lhs.type.scalar.is_bf16():
|
|
1515
|
+
_0 = self.builder.get_fp32(0)
|
|
1516
|
+
ret_scalar_ty = tl.float32
|
|
1517
|
+
else:
|
|
1518
|
+
_0 = self.builder.get_fp16(0) if out_dtype.is_fp16() else self.builder.get_fp32(0)
|
|
1519
|
+
ret_scalar_ty = out_dtype
|
|
1520
|
+
|
|
1521
|
+
M = lhs.type.shape[-2]
|
|
1522
|
+
N = rhs.type.shape[-1]
|
|
1523
|
+
K = lhs.type.shape[-1]
|
|
1524
|
+
B = lhs.type.shape[0] if lhs_rank == 3 else None
|
|
1525
|
+
ret_ty = tl.block_type(ret_scalar_ty, [B, M, N] if B else [M, N])
|
|
1526
|
+
if acc is None:
|
|
1527
|
+
acc_handle = self.builder.create_splat(ret_ty.to_ir(self.builder), _0)
|
|
1528
|
+
else:
|
|
1529
|
+
acc_handle = acc.handle
|
|
1530
|
+
assert acc.type == ret_ty
|
|
1678
1531
|
|
|
1532
|
+
# max_num_imprecise_acc only applies to fp8 -> fp32 dot on sm_90
|
|
1533
|
+
if max_num_imprecise_acc is None:
|
|
1534
|
+
if lhs.dtype.is_fp8() and rhs.dtype.is_fp8():
|
|
1535
|
+
max_num_imprecise_acc = self.builder.options.max_num_imprecise_acc_default
|
|
1536
|
+
else:
|
|
1537
|
+
max_num_imprecise_acc = 0
|
|
1538
|
+
else:
|
|
1539
|
+
if lhs.dtype.is_fp8() and rhs.dtype.is_fp8() and max_num_imprecise_acc > K:
|
|
1540
|
+
raise ValueError(f"max_num_imprecise_acc ({max_num_imprecise_acc}) must be <= K ({K})")
|
|
1541
|
+
|
|
1542
|
+
return self.tensor(
|
|
1543
|
+
self.builder.create_dot(lhs.handle, rhs.handle, acc_handle, input_precision, max_num_imprecise_acc), ret_ty)
|
|
1544
|
+
|
|
1545
|
+
def _str_to_fp_type(self, float_format: str):
|
|
1546
|
+
ty_enum = getattr(ir.ScaleDotElemTypeTY, float_format.upper(), None)
|
|
1547
|
+
if ty_enum is None:
|
|
1548
|
+
raise ValueError(f"Invalid float format: {float_format}.")
|
|
1549
|
+
return ty_enum
|
|
1550
|
+
|
|
1551
|
+
def _bitcast_to_fp_type(self, val: TensorTy, float_format: str):
|
|
1552
|
+
"""
|
|
1553
|
+
If float_format is subbyte, make sure it's packed as uint8 and return it.
|
|
1554
|
+
Otherwise, return a tensor (perhaps bitcasting) of the specified float format.
|
|
1555
|
+
"""
|
|
1556
|
+
triton_ty = {"e5m2": tl.float8e5, "e4m3": tl.float8e4nv, "bf16": tl.bfloat16, "fp16":
|
|
1557
|
+
tl.float16}.get(float_format)
|
|
1558
|
+
if triton_ty is None:
|
|
1559
|
+
assert float_format == "e2m1", f"Internal Error: Unexpected float format: {float_format}"
|
|
1560
|
+
assert val.dtype == tl.uint8, f"e2m1 format must be packed as uint8. Got {val.dtype}"
|
|
1561
|
+
return val
|
|
1562
|
+
if val.dtype == triton_ty:
|
|
1563
|
+
return val
|
|
1564
|
+
else:
|
|
1565
|
+
unsigned_ty = {"e5m2": tl.uint8, "e4m3": tl.uint8, "bf16": tl.uint16, "fp16": tl.uint16}[float_format]
|
|
1566
|
+
assert val.dtype == unsigned_ty, f"Unexpected dtype for {float_format}. Got {val.dtype}"
|
|
1567
|
+
return self.bitcast(val, triton_ty)
|
|
1568
|
+
|
|
1569
|
+
def dot_scaled(self, lhs: TensorTy, lhs_scale: TensorTy, lhs_format: str, rhs: TensorTy,
|
|
1570
|
+
rhs_scale: Optional[TensorTy], rhs_format: str, acc: TensorTy | None, fast_math: bool,
|
|
1571
|
+
lhs_k_pack: bool, rhs_k_pack: bool, out_dtype: tl.dtype) -> TensorTy:
|
|
1572
|
+
assert lhs.type.is_block() and rhs.type.is_block()
|
|
1573
|
+
#TODO: validate types.
|
|
1574
|
+
lhs_rank = len(lhs.shape)
|
|
1575
|
+
rhs_rank = len(rhs.shape)
|
|
1576
|
+
assert lhs_rank == rhs_rank == 2 or lhs_rank == rhs_rank == 3, f"Both inputs must be either 2D or 3D; (lhs: {lhs.shape} vs rhs: {rhs.shape})"
|
|
1577
|
+
lhs_format: str = lhs_format.value
|
|
1578
|
+
rhs_format: str = rhs_format.value
|
|
1579
|
+
lhs_format_enum = self._str_to_fp_type(lhs_format)
|
|
1580
|
+
rhs_format_enum = self._str_to_fp_type(rhs_format)
|
|
1581
|
+
allowed_formats = {"e2m1", "e4m3", "e5m2", "bf16", "fp16"}
|
|
1582
|
+
assert lhs_format in allowed_formats, f"NYI: lhs_format {lhs_format}"
|
|
1583
|
+
assert rhs_format in allowed_formats, f"NYI: rhs_format {rhs_format}"
|
|
1584
|
+
rhs_scale_is_none = rhs_scale is None or (isinstance(rhs_scale, tl.constexpr) and rhs_scale.value is None)
|
|
1585
|
+
lhs_scale_is_none = lhs_scale is None or (isinstance(lhs_scale, tl.constexpr) and lhs_scale.value is None)
|
|
1586
|
+
lhs = self._bitcast_to_fp_type(lhs, lhs_format)
|
|
1587
|
+
rhs = self._bitcast_to_fp_type(rhs, rhs_format)
|
|
1588
|
+
|
|
1589
|
+
assert lhs_k_pack or lhs_format == "e2m1", "only mxfp4 inputs can be packed along a dimension different than K"
|
|
1590
|
+
assert rhs_k_pack or rhs_format == "e2m1", "only mxfp4 inputs can be packed along a dimension different than K"
|
|
1591
|
+
M, K_LHS = lhs.type.shape[-2:]
|
|
1592
|
+
K_RHS, N = rhs.type.shape[-2:]
|
|
1593
|
+
PACKED_A = 2 if lhs_format == "e2m1" else 1
|
|
1594
|
+
PACKED_B = 2 if rhs_format == "e2m1" else 1
|
|
1595
|
+
PACKED_A_DIM = PACKED_A * K_LHS if lhs_k_pack else K_LHS
|
|
1596
|
+
PACKED_B_DIM = PACKED_B * K_RHS if rhs_k_pack else K_RHS
|
|
1597
|
+
assert PACKED_B_DIM == PACKED_A_DIM, f"Reduction dimension should pack the same number of elements; (lhs: {lhs.shape} vs rhs: {rhs.shape})"
|
|
1598
|
+
#assert K * PACKED_B >= 64, f"scaled_dot NYI for K < 64. Got {K=}"
|
|
1599
|
+
B = lhs.type.shape[0] if lhs_rank == 3 else None
|
|
1600
|
+
if not lhs_k_pack:
|
|
1601
|
+
M = M * PACKED_A
|
|
1602
|
+
if not rhs_k_pack:
|
|
1603
|
+
N = N * PACKED_B
|
|
1604
|
+
ret_ty = tl.block_type(out_dtype, [B, M, N] if B else [M, N])
|
|
1605
|
+
_0 = self.builder.get_fp32(0)
|
|
1606
|
+
if acc is None:
|
|
1607
|
+
acc_handle = self.builder.create_splat(ret_ty.to_ir(self.builder), _0)
|
|
1608
|
+
else:
|
|
1609
|
+
acc_handle = acc.handle
|
|
1610
|
+
assert acc.type == ret_ty
|
|
1611
|
+
rhs_scale_handle = None if rhs_scale_is_none else rhs_scale.handle
|
|
1612
|
+
lhs_scale_handle = None if lhs_scale_is_none else lhs_scale.handle
|
|
1613
|
+
return self.tensor(
|
|
1614
|
+
self.builder.create_dot_scaled(lhs.handle, lhs_scale_handle, lhs_format_enum, rhs.handle, rhs_scale_handle,
|
|
1615
|
+
rhs_format_enum, fast_math, lhs_k_pack, rhs_k_pack, acc_handle), ret_ty)
|
|
1679
1616
|
|
|
1680
1617
|
# ===----------------------------------------------------------------------===//
|
|
1681
1618
|
# Indexing
|
|
1682
1619
|
# ===----------------------------------------------------------------------===//
|
|
1683
1620
|
|
|
1684
|
-
|
|
1685
|
-
|
|
1686
|
-
|
|
1687
|
-
|
|
1688
|
-
|
|
1689
|
-
)
|
|
1690
|
-
|
|
1691
|
-
|
|
1692
|
-
|
|
1693
|
-
|
|
1694
|
-
|
|
1695
|
-
|
|
1696
|
-
|
|
1697
|
-
|
|
1698
|
-
|
|
1699
|
-
return tl.tensor(builder.create_select(condition.handle, x.handle, y.handle), ret_ty)
|
|
1700
|
-
|
|
1621
|
+
def where(self, condition: TensorTy, x: TensorTy, y: TensorTy) -> TensorTy:
|
|
1622
|
+
if condition.dtype != tl.int1:
|
|
1623
|
+
warnings.warn(
|
|
1624
|
+
f"tl.where with a non-boolean condition is deprecated and will error out in a future triton release. Got {condition.dtype}"
|
|
1625
|
+
)
|
|
1626
|
+
condition = self.cast(condition, tl.int1)
|
|
1627
|
+
x, y = self.binary_op_type_checking_impl(x, y, True, True)
|
|
1628
|
+
# x, y are broadcasted
|
|
1629
|
+
if condition.type.is_block():
|
|
1630
|
+
condition, x = self.broadcast_impl_value(condition, x)
|
|
1631
|
+
x, y = self.broadcast_impl_value(x, y)
|
|
1632
|
+
else:
|
|
1633
|
+
condition, _ = self.broadcast_impl_value(condition, x)
|
|
1634
|
+
ret_ty = x.type
|
|
1635
|
+
return self.tensor(self.builder.create_select(condition.handle, x.handle, y.handle), ret_ty)
|
|
1701
1636
|
|
|
1702
1637
|
# ===----------------------------------------------------------------------===//
|
|
1703
1638
|
# Reduction
|
|
1704
1639
|
# ===----------------------------------------------------------------------===
|
|
1705
1640
|
|
|
1706
|
-
|
|
1707
|
-
|
|
1708
|
-
|
|
1709
|
-
|
|
1710
|
-
|
|
1711
|
-
|
|
1712
|
-
|
|
1713
|
-
|
|
1714
|
-
|
|
1715
|
-
|
|
1716
|
-
|
|
1717
|
-
|
|
1718
|
-
|
|
1719
|
-
|
|
1720
|
-
|
|
1721
|
-
|
|
1722
|
-
|
|
1723
|
-
|
|
1724
|
-
|
|
1725
|
-
|
|
1726
|
-
|
|
1727
|
-
|
|
1728
|
-
|
|
1729
|
-
|
|
1730
|
-
|
|
1731
|
-
return tuple(wrap_tensor(reduce_op.get_result(i), inputs[i].type.scalar, ret_shape) for i in range(len(inputs)))
|
|
1732
|
-
|
|
1641
|
+
def wrap_tensor(self, x, scalar_ty, ret_shape):
|
|
1642
|
+
if ret_shape:
|
|
1643
|
+
res_ty = tl.block_type(scalar_ty, ret_shape)
|
|
1644
|
+
else:
|
|
1645
|
+
# 0d-tensor -> scalar
|
|
1646
|
+
res_ty = scalar_ty
|
|
1647
|
+
return self.tensor(x, res_ty)
|
|
1648
|
+
|
|
1649
|
+
def reduction(self, inputs: Sequence[TensorTy], axis: int, region_builder_fn) -> Tuple[TensorTy, ...]:
|
|
1650
|
+
if axis is None:
|
|
1651
|
+
inputs = tuple(self.reshape(t, [t.numel.value], can_reorder=True) for t in inputs)
|
|
1652
|
+
axis = 0
|
|
1653
|
+
# get result shape
|
|
1654
|
+
shape = inputs[0].type.shape
|
|
1655
|
+
rank = len(shape)
|
|
1656
|
+
assert axis < rank, f"reduction axis must be < inputs rank ({rank})"
|
|
1657
|
+
ret_shape = [s for i, s in enumerate(shape) if i != axis]
|
|
1658
|
+
assert all(t.type.shape == shape for t in inputs), "all reduction inputs must have the same shape"
|
|
1659
|
+
|
|
1660
|
+
reduce_op = self.builder.create_reduce([t.handle for t in inputs], axis)
|
|
1661
|
+
region_builder_fn(reduce_op)
|
|
1662
|
+
assert reduce_op.verify()
|
|
1663
|
+
|
|
1664
|
+
return tuple(
|
|
1665
|
+
self.wrap_tensor(reduce_op.get_result(i), inputs[i].type.scalar, ret_shape) for i in range(len(inputs)))
|
|
1733
1666
|
|
|
1734
1667
|
# ===----------------------------------------------------------------------===
|
|
1735
1668
|
# Associative Scan
|
|
1736
1669
|
# ===----------------------------------------------------------------------===
|
|
1737
1670
|
|
|
1671
|
+
def associative_scan(self, inputs: Sequence[TensorTy], axis: int, region_builder_fn,
|
|
1672
|
+
reverse: bool) -> Tuple[TensorTy, ...]:
|
|
1673
|
+
shape = inputs[0].type.shape
|
|
1674
|
+
rank = len(shape)
|
|
1738
1675
|
|
|
1739
|
-
|
|
1740
|
-
builder: ir.builder) -> Tuple[tl.tensor, ...]:
|
|
1741
|
-
shape = inputs[0].type.shape
|
|
1742
|
-
rank = len(shape)
|
|
1743
|
-
|
|
1744
|
-
assert -rank <= axis < rank, f"scan axis {axis} must be < inputs rank ({rank})"
|
|
1676
|
+
assert -rank <= axis < rank, f"scan axis {axis} must be < inputs rank ({rank})"
|
|
1745
1677
|
|
|
1746
|
-
|
|
1747
|
-
|
|
1678
|
+
if axis < 0:
|
|
1679
|
+
axis += rank
|
|
1748
1680
|
|
|
1749
|
-
|
|
1750
|
-
|
|
1681
|
+
for t in inputs:
|
|
1682
|
+
assert t.type.shape == shape, "all scan inputs must have the same shape"
|
|
1751
1683
|
|
|
1752
|
-
|
|
1753
|
-
|
|
1754
|
-
|
|
1755
|
-
|
|
1756
|
-
return tuple(wrap_tensor(scan_op.get_result(i), inputs[i].type.scalar, shape) for i in range(len(inputs)))
|
|
1684
|
+
scan_op = self.builder.create_scan([t.handle for t in inputs], axis, reverse)
|
|
1685
|
+
region_builder_fn(scan_op)
|
|
1686
|
+
assert scan_op.verify()
|
|
1757
1687
|
|
|
1688
|
+
return tuple(self.wrap_tensor(scan_op.get_result(i), inputs[i].type.scalar, shape) for i in range(len(inputs)))
|
|
1758
1689
|
|
|
1759
1690
|
# ===----------------------------------------------------------------------===
|
|
1760
1691
|
# Gather
|
|
1761
1692
|
# ===----------------------------------------------------------------------===
|
|
1762
1693
|
|
|
1694
|
+
def gather(self, src: TensorTy, index: TensorTy, axis: int) -> TensorTy:
|
|
1695
|
+
assert index.dtype.is_int(), "index must be an integer tensor"
|
|
1763
1696
|
|
|
1764
|
-
|
|
1765
|
-
|
|
1766
|
-
|
|
1767
|
-
rank = len(src.type.shape)
|
|
1768
|
-
assert len(index.type.shape) == rank, "source and index tensors must have the same rank"
|
|
1697
|
+
rank = len(src.type.shape)
|
|
1698
|
+
assert len(index.type.shape) == rank, "source and index tensors must have the same rank"
|
|
1769
1699
|
|
|
1770
|
-
|
|
1771
|
-
|
|
1772
|
-
|
|
1700
|
+
assert -rank <= axis < rank, f"gather axis {axis} must be < source rank ({rank})"
|
|
1701
|
+
if axis < 0:
|
|
1702
|
+
axis += rank
|
|
1773
1703
|
|
|
1774
|
-
|
|
1775
|
-
|
|
1776
|
-
|
|
1777
|
-
|
|
1704
|
+
for d in range(rank):
|
|
1705
|
+
if d == axis:
|
|
1706
|
+
continue
|
|
1707
|
+
assert index.type.shape[d] == src.type.shape[d], f"index dim {axis} must match the corresponding source dim"
|
|
1778
1708
|
|
|
1779
|
-
|
|
1780
|
-
|
|
1709
|
+
gather = self.builder.create_gather(src.handle, index.handle, axis)
|
|
1710
|
+
return self.wrap_tensor(gather, src.type.scalar, index.type.shape)
|
|
1781
1711
|
|
|
1782
1712
|
|
|
1783
1713
|
# ===----------------------------------------------------------------------===
|
|
1784
1714
|
# Histogram
|
|
1785
1715
|
# ===----------------------------------------------------------------------===
|
|
1786
1716
|
|
|
1717
|
+
def histogram(self, input: TensorTy, num_bins: int, mask: Optional[TensorTy]) -> TensorTy:
|
|
1718
|
+
assert len(input.shape) == 1, "histogram only supports 1D input"
|
|
1719
|
+
assert input.dtype.is_int(), "histogram only supports integer input"
|
|
1720
|
+
if mask is not None:
|
|
1721
|
+
mask = self.broadcast_impl_shape(mask, input.shape)
|
|
1722
|
+
if not mask.type.scalar.is_bool():
|
|
1723
|
+
raise ValueError("Mask must have boolean scalar type")
|
|
1724
|
+
mask = mask.handle
|
|
1725
|
+
return self.tensor(self.builder.create_histogram(input.handle, num_bins, mask),
|
|
1726
|
+
tl.block_type(tl.int32, [num_bins]))
|
|
1727
|
+
|
|
1728
|
+
def multiple_of(self, x: TensorTy, values: List[int]) -> TensorTy:
|
|
1729
|
+
if max(1, len(x.shape)) != len(values):
|
|
1730
|
+
raise ValueError("Shape of input to multiple_of does not match the length of values")
|
|
1731
|
+
x.handle.set_attr("tt.divisibility", ir.make_attr(values, x.handle.get_context()))
|
|
1732
|
+
return x
|
|
1787
1733
|
|
|
1788
|
-
def
|
|
1789
|
-
|
|
1790
|
-
|
|
1791
|
-
|
|
1792
|
-
|
|
1793
|
-
|
|
1794
|
-
def multiple_of(x: tl.tensor, values: List[int]) -> tl.tensor:
|
|
1795
|
-
if max(1, len(x.shape)) != len(values):
|
|
1796
|
-
raise ValueError("Shape of input to multiple_of does not match the length of values")
|
|
1797
|
-
x.handle.set_attr("tt.divisibility", ir.make_attr(values, x.handle.get_context()))
|
|
1798
|
-
return x
|
|
1799
|
-
|
|
1800
|
-
|
|
1801
|
-
def max_contiguous(x: tl.tensor, values: List[int]) -> tl.tensor:
|
|
1802
|
-
if len(x.shape) != len(values):
|
|
1803
|
-
raise ValueError("Shape of input to max_contiguous does not match the length of values")
|
|
1804
|
-
x.handle.set_attr("tt.contiguity", ir.make_attr(values, x.handle.get_context()))
|
|
1805
|
-
return x
|
|
1806
|
-
|
|
1807
|
-
|
|
1808
|
-
def max_constancy(x: tl.tensor, values: List[int]) -> tl.tensor:
|
|
1809
|
-
if len(x.shape) != len(values):
|
|
1810
|
-
raise ValueError("Shape of input to max_constancy does not match the length of values")
|
|
1811
|
-
x.handle.set_attr("tt.constancy", ir.make_attr(values, x.handle.get_context()))
|
|
1812
|
-
return x
|
|
1813
|
-
|
|
1814
|
-
|
|
1815
|
-
def debug_barrier(builder: ir.builder) -> tl.tensor:
|
|
1816
|
-
return tl.tensor(builder.create_barrier(), tl.void)
|
|
1817
|
-
|
|
1818
|
-
|
|
1819
|
-
def device_print(prefix: str, args: List[tl.tensor], hex: bool, builder: ir.builder) -> tl.tensor:
|
|
1820
|
-
# It makes sense visually for prefix to end in ": "; make it so. Also,
|
|
1821
|
-
# non-empty prefixes should start with " ".
|
|
1822
|
-
if not prefix.endswith(" ") and args:
|
|
1823
|
-
prefix += " "
|
|
1824
|
-
if not prefix.endswith(": ") and args:
|
|
1825
|
-
prefix = prefix[:-1] + ": "
|
|
1826
|
-
if len(prefix) > 2 and not prefix.startswith(" "):
|
|
1827
|
-
prefix = " " + prefix
|
|
1828
|
-
|
|
1829
|
-
new_args = [arg.handle for arg in args]
|
|
1830
|
-
is_signed = [arg.dtype in (tl.int1, tl.int8, tl.int16, tl.int32, tl.int64) for arg in args]
|
|
1831
|
-
return tl.tensor(builder.create_print(prefix, hex, new_args, is_signed), tl.void)
|
|
1832
|
-
|
|
1833
|
-
|
|
1834
|
-
def device_assert(cond: tl.tensor, msg: str, builder: ir.builder) -> tl.tensor:
|
|
1835
|
-
if not builder.options.debug:
|
|
1836
|
-
return
|
|
1837
|
-
return tl.tensor(builder.create_assert(cond.handle, msg), tl.void)
|
|
1838
|
-
|
|
1839
|
-
|
|
1840
|
-
def assume(cond, builder: ir.builder) -> tl.tensor:
|
|
1841
|
-
return tl.tensor(builder.create_assume(cond.handle), tl.void)
|
|
1734
|
+
def max_contiguous(self, x: TensorTy, values: List[int]) -> TensorTy:
|
|
1735
|
+
if len(x.shape) != len(values):
|
|
1736
|
+
raise ValueError("Shape of input to max_contiguous does not match the length of values")
|
|
1737
|
+
x.handle.set_attr("tt.contiguity", ir.make_attr(values, x.handle.get_context()))
|
|
1738
|
+
return x
|
|
1842
1739
|
|
|
1740
|
+
def max_constancy(self, x: TensorTy, values: List[int]) -> TensorTy:
|
|
1741
|
+
if len(x.shape) != len(values):
|
|
1742
|
+
raise ValueError("Shape of input to max_constancy does not match the length of values")
|
|
1743
|
+
x.handle.set_attr("tt.constancy", ir.make_attr(values, x.handle.get_context()))
|
|
1744
|
+
return x
|
|
1843
1745
|
|
|
1844
|
-
def
|
|
1845
|
-
|
|
1846
|
-
|
|
1847
|
-
|
|
1848
|
-
|
|
1849
|
-
|
|
1850
|
-
|
|
1851
|
-
|
|
1852
|
-
|
|
1853
|
-
|
|
1854
|
-
|
|
1855
|
-
|
|
1856
|
-
|
|
1857
|
-
|
|
1858
|
-
|
|
1859
|
-
|
|
1860
|
-
|
|
1861
|
-
|
|
1862
|
-
|
|
1863
|
-
|
|
1864
|
-
return
|
|
1865
|
-
|
|
1866
|
-
|
|
1867
|
-
|
|
1868
|
-
|
|
1869
|
-
|
|
1870
|
-
|
|
1871
|
-
|
|
1872
|
-
|
|
1873
|
-
|
|
1874
|
-
|
|
1875
|
-
|
|
1876
|
-
|
|
1877
|
-
|
|
1878
|
-
|
|
1879
|
-
|
|
1880
|
-
|
|
1881
|
-
|
|
1882
|
-
|
|
1883
|
-
|
|
1884
|
-
|
|
1885
|
-
|
|
1886
|
-
|
|
1887
|
-
|
|
1888
|
-
|
|
1889
|
-
|
|
1890
|
-
|
|
1891
|
-
|
|
1892
|
-
|
|
1893
|
-
|
|
1894
|
-
|
|
1895
|
-
|
|
1896
|
-
|
|
1897
|
-
|
|
1898
|
-
|
|
1899
|
-
|
|
1900
|
-
|
|
1901
|
-
|
|
1902
|
-
|
|
1903
|
-
|
|
1904
|
-
|
|
1905
|
-
|
|
1906
|
-
|
|
1907
|
-
|
|
1908
|
-
|
|
1909
|
-
|
|
1910
|
-
|
|
1911
|
-
|
|
1912
|
-
|
|
1913
|
-
|
|
1914
|
-
|
|
1915
|
-
|
|
1916
|
-
|
|
1917
|
-
|
|
1918
|
-
|
|
1919
|
-
|
|
1920
|
-
|
|
1921
|
-
|
|
1922
|
-
|
|
1923
|
-
|
|
1924
|
-
|
|
1925
|
-
|
|
1926
|
-
|
|
1927
|
-
|
|
1928
|
-
|
|
1929
|
-
|
|
1930
|
-
|
|
1931
|
-
|
|
1932
|
-
|
|
1933
|
-
|
|
1934
|
-
|
|
1935
|
-
|
|
1936
|
-
|
|
1937
|
-
|
|
1938
|
-
|
|
1939
|
-
|
|
1940
|
-
|
|
1941
|
-
|
|
1942
|
-
|
|
1943
|
-
|
|
1944
|
-
|
|
1945
|
-
|
|
1946
|
-
|
|
1947
|
-
|
|
1948
|
-
|
|
1949
|
-
|
|
1950
|
-
|
|
1746
|
+
def debug_barrier(self) -> TensorTy:
|
|
1747
|
+
return self.tensor(self.builder.create_barrier(), tl.void)
|
|
1748
|
+
|
|
1749
|
+
def device_print(self, prefix: str, args: List[TensorTy], hex: bool) -> TensorTy:
|
|
1750
|
+
# It makes sense visually for prefix to end in ": "; make it so. Also,
|
|
1751
|
+
# non-empty prefixes should start with " ".
|
|
1752
|
+
if not prefix.endswith(" ") and args:
|
|
1753
|
+
prefix += " "
|
|
1754
|
+
if not prefix.endswith(": ") and args:
|
|
1755
|
+
prefix = prefix[:-1] + ": "
|
|
1756
|
+
if len(prefix) > 2 and not prefix.startswith(" "):
|
|
1757
|
+
prefix = " " + prefix
|
|
1758
|
+
|
|
1759
|
+
new_args = [arg.handle for arg in args]
|
|
1760
|
+
is_signed = [arg.dtype.is_int_signed() for arg in args]
|
|
1761
|
+
return self.tensor(self.builder.create_print(prefix, hex, new_args, is_signed), tl.void)
|
|
1762
|
+
|
|
1763
|
+
def device_assert(self, cond: TensorTy, msg: str) -> TensorTy:
|
|
1764
|
+
if not self.builder.options.debug:
|
|
1765
|
+
return
|
|
1766
|
+
return self.tensor(self.builder.create_assert(cond.handle, msg), tl.void)
|
|
1767
|
+
|
|
1768
|
+
def assume(self, cond) -> TensorTy:
|
|
1769
|
+
return self.tensor(self.builder.create_assume(cond.handle), tl.void)
|
|
1770
|
+
|
|
1771
|
+
def _convert_elem_to_ir_value(self, elem, require_i64):
|
|
1772
|
+
if isinstance(elem, int):
|
|
1773
|
+
elem = tl.constexpr(elem)
|
|
1774
|
+
if isinstance(elem, tl.constexpr):
|
|
1775
|
+
if isinstance(elem.value, bool):
|
|
1776
|
+
return self.builder.get_int1(elem.value)
|
|
1777
|
+
if require_i64:
|
|
1778
|
+
assert -2**63 <= elem.value < 2**63, f"Block pointers only support 64 bit `shape/strides`, " \
|
|
1779
|
+
f"got a value {elem.value} which is out of the range"
|
|
1780
|
+
return self.builder.get_int64(elem.value)
|
|
1781
|
+
else:
|
|
1782
|
+
assert -2**31 <= elem.value < 2**31, f"Block pointers only support 32 bit `offsets/block_shape`, " \
|
|
1783
|
+
f"got a value {elem.value} which is out of the range"
|
|
1784
|
+
return self.builder.get_int32(elem.value)
|
|
1785
|
+
elif isinstance(elem, tl.tensor):
|
|
1786
|
+
assert elem.numel.value == 1, "Expected a scalar in shape/strides/offsets"
|
|
1787
|
+
assert elem.dtype.is_int(), "Expected an integer scalar type in shape/strides/offsets"
|
|
1788
|
+
if elem.dtype != tl.int64 and require_i64:
|
|
1789
|
+
return self.builder.create_int_cast(elem.handle, self.builder.get_int64_ty(),
|
|
1790
|
+
elem.dtype.is_int_signed())
|
|
1791
|
+
elif elem.dtype != tl.int32 and not require_i64:
|
|
1792
|
+
assert False, "Block pointers only support 32 bit `offsets/block_shape`, " \
|
|
1793
|
+
"add a `.to(tl.int32)` or use regular indexing for 64 bit support"
|
|
1794
|
+
return elem.handle
|
|
1795
|
+
assert False, f"Unsupported element type in shape/strides/offsets: {type(elem)}"
|
|
1796
|
+
|
|
1797
|
+
def _convert_to_ir_values(self, list_like, require_i64=True):
|
|
1798
|
+
if hasattr(list_like, "__iter__"):
|
|
1799
|
+
return [self._convert_elem_to_ir_value(elem, require_i64) for elem in list_like]
|
|
1800
|
+
return [self._convert_elem_to_ir_value(list_like, require_i64)]
|
|
1801
|
+
|
|
1802
|
+
def make_block_ptr(self, base: TensorTy, shape, strides, offsets, block_shape, order) -> TensorTy:
|
|
1803
|
+
# Convert dynamic arguments to IR values
|
|
1804
|
+
# NOTES(Chenggang): current `shape/strides` are `int64_t`, while `offsets/block_shape` are `int32_t`
|
|
1805
|
+
shape = self._convert_to_ir_values(shape)
|
|
1806
|
+
strides = self._convert_to_ir_values(strides)
|
|
1807
|
+
offsets = self._convert_to_ir_values(offsets, require_i64=False)
|
|
1808
|
+
|
|
1809
|
+
# Check `base` type
|
|
1810
|
+
if not base.type.is_ptr() or base.type.element_ty.is_block():
|
|
1811
|
+
raise ValueError("Expected `base` to be a pointer type (but not a block pointer type or others)")
|
|
1812
|
+
|
|
1813
|
+
# Treat `pointer_type<tl.int1>` as `pointer_type<tl.int8>`
|
|
1814
|
+
if base.type.element_ty == tl.int1:
|
|
1815
|
+
base = self.cast(base, tl.pointer_type(tl.int8, base.type.address_space))
|
|
1816
|
+
|
|
1817
|
+
# Check whether `block_shape` is static
|
|
1818
|
+
if not hasattr(block_shape, "__iter__"):
|
|
1819
|
+
block_shape = [block_shape]
|
|
1820
|
+
block_shape = [elem.value if isinstance(elem, tl.constexpr) else elem for elem in block_shape]
|
|
1821
|
+
assert all(isinstance(elem, int) and -2**31 <= elem < 2**31 for elem in block_shape), \
|
|
1822
|
+
"Expected a list of constant integers (`int32_t` range) in `block_shape`"
|
|
1823
|
+
|
|
1824
|
+
# Check `order`
|
|
1825
|
+
if not hasattr(order, "__iter__"):
|
|
1826
|
+
order = [order]
|
|
1827
|
+
order = [elem.value if isinstance(elem, tl.constexpr) else elem for elem in order]
|
|
1828
|
+
assert sorted(order) == list(range(len(order))), "Expected a permutation of (0, 1, ..., len(order)-1) in order"
|
|
1829
|
+
|
|
1830
|
+
# Must have same length
|
|
1831
|
+
assert all(len(block_shape) == len(list_like) for list_like in [shape, strides, offsets, order]), \
|
|
1832
|
+
"Expected shape/strides/offsets/block_shape to have the same length"
|
|
1833
|
+
|
|
1834
|
+
# Build value, the type is:
|
|
1835
|
+
# `pointer_type<blocked<shape, element_type>>` in Python
|
|
1836
|
+
# `tt.ptr<tensor<shape, element_type>>` in MLIR
|
|
1837
|
+
handle = self.builder.create_make_block_ptr(base.handle, shape, strides, offsets, block_shape, order)
|
|
1838
|
+
return self.tensor(handle, tl.pointer_type(tl.block_type(base.type.element_ty, block_shape)))
|
|
1839
|
+
|
|
1840
|
+
def advance(self, base: TensorTy, offsets) -> TensorTy:
|
|
1841
|
+
# Convert dynamic offsets to IR values
|
|
1842
|
+
offsets = self._convert_to_ir_values(offsets, require_i64=False)
|
|
1843
|
+
|
|
1844
|
+
# Advanced block pointer type is the same as before
|
|
1845
|
+
return self.tensor(self.builder.create_advance(base.handle, offsets), base.type)
|
|
1846
|
+
|
|
1847
|
+
def make_tensor_descriptor(
|
|
1848
|
+
self,
|
|
1849
|
+
base: TensorTy,
|
|
1850
|
+
shape: List[TensorTy],
|
|
1851
|
+
strides: List[TensorTy],
|
|
1852
|
+
block_shape: List[tl.constexpr],
|
|
1853
|
+
) -> tl.tensor_descriptor:
|
|
1854
|
+
ndim = len(shape)
|
|
1855
|
+
if not (1 <= ndim <= 5):
|
|
1856
|
+
raise ValueError(f"Expected 1 <= ndim <= 5 but got {ndim} dimensions")
|
|
1857
|
+
if len(strides) != ndim:
|
|
1858
|
+
raise ValueError(f"Expected {ndim} strides but got {len(strides)}")
|
|
1859
|
+
if len(block_shape) != ndim:
|
|
1860
|
+
raise ValueError(f"Expected block_shape to have {ndim} dimensions but got {len(strides)}")
|
|
1861
|
+
assert isinstance(base.dtype, tl.pointer_type)
|
|
1862
|
+
elem_size = base.dtype.element_ty.primitive_bitwidth // 8
|
|
1863
|
+
contig_dim_size = tl._unwrap_if_constexpr(block_shape[-1])
|
|
1864
|
+
if contig_dim_size * elem_size < 16:
|
|
1865
|
+
raise ValueError(
|
|
1866
|
+
f"Descriptor block shape must have at least 16 bytes in the last dimension, but got {contig_dim_size} * {elem_size} = {contig_dim_size * elem_size} bytes"
|
|
1867
|
+
)
|
|
1868
|
+
|
|
1869
|
+
strides[-1] = tl._unwrap_if_constexpr(strides[-1])
|
|
1870
|
+
if strides[-1] != 1:
|
|
1871
|
+
raise ValueError(f"Tensor descriptor last dim must be 1 but got {strides[-1]}")
|
|
1872
|
+
|
|
1873
|
+
shape = [self.make_scalar(x, tl.int32) for x in shape]
|
|
1874
|
+
strides = [self.make_scalar(x, tl.int64) for x in strides]
|
|
1875
|
+
|
|
1876
|
+
# Check whether `block_shape` is static
|
|
1877
|
+
block_shape = tl._unwrap_shape(block_shape)
|
|
1878
|
+
|
|
1879
|
+
assert isinstance(base.type, tl.pointer_type)
|
|
1880
|
+
type = tl.block_type(base.type.element_ty, block_shape)
|
|
1881
|
+
base_handle = base.handle
|
|
1882
|
+
is_signed_int = base.type.element_ty.is_int_signed()
|
|
1883
|
+
|
|
1884
|
+
handle = self.builder.create_make_tensor_descriptor(base_handle, [s.handle for s in shape],
|
|
1885
|
+
[s.handle for s in strides], block_shape, is_signed_int)
|
|
1886
|
+
return tl.tensor_descriptor(handle, shape, strides, type)
|