mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mslk/__init__.py +56 -0
- mslk/attention/__init__.py +7 -0
- mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
- mslk/attention/flash_attn/__init__.py +22 -0
- mslk/attention/flash_attn/ampere_helpers.py +104 -0
- mslk/attention/flash_attn/barrier.py +72 -0
- mslk/attention/flash_attn/benchmark.py +269 -0
- mslk/attention/flash_attn/blackwell_helpers.py +754 -0
- mslk/attention/flash_attn/block_info.py +109 -0
- mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
- mslk/attention/flash_attn/block_sparsity.py +219 -0
- mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
- mslk/attention/flash_attn/copy_utils.py +341 -0
- mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
- mslk/attention/flash_attn/fast_math.py +22 -0
- mslk/attention/flash_attn/flash_bwd.py +1262 -0
- mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
- mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
- mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
- mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
- mslk/attention/flash_attn/flash_fwd.py +2471 -0
- mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
- mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
- mslk/attention/flash_attn/hopper_helpers.py +102 -0
- mslk/attention/flash_attn/interface.py +1771 -0
- mslk/attention/flash_attn/mask.py +610 -0
- mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
- mslk/attention/flash_attn/named_barrier.py +32 -0
- mslk/attention/flash_attn/pack_gqa.py +165 -0
- mslk/attention/flash_attn/paged_kv.py +176 -0
- mslk/attention/flash_attn/pipeline.py +273 -0
- mslk/attention/flash_attn/seqlen_info.py +139 -0
- mslk/attention/flash_attn/softmax.py +583 -0
- mslk/attention/flash_attn/testing.py +424 -0
- mslk/attention/flash_attn/tile_scheduler.py +720 -0
- mslk/attention/flash_attn/utils.py +860 -0
- mslk/attention/fmha/__init__.py +967 -0
- mslk/attention/fmha/_triton/__init__.py +6 -0
- mslk/attention/fmha/_triton/available.py +50 -0
- mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
- mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
- mslk/attention/fmha/attn_bias.py +2186 -0
- mslk/attention/fmha/attn_bias_utils.py +536 -0
- mslk/attention/fmha/ck.py +508 -0
- mslk/attention/fmha/ck_decoder.py +141 -0
- mslk/attention/fmha/ck_splitk.py +204 -0
- mslk/attention/fmha/common.py +598 -0
- mslk/attention/fmha/cutlass.py +461 -0
- mslk/attention/fmha/cutlass_blackwell.py +560 -0
- mslk/attention/fmha/dispatch.py +224 -0
- mslk/attention/fmha/flash.py +862 -0
- mslk/attention/fmha/flash3.py +858 -0
- mslk/attention/fmha/flash_mtia.py +245 -0
- mslk/attention/fmha/merge_training.py +192 -0
- mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
- mslk/attention/fmha/torch_attention_compat.py +154 -0
- mslk/attention/fmha/tree_attention.py +718 -0
- mslk/attention/fmha/triton_splitk.py +1378 -0
- mslk/attention/fmha/unbind.py +130 -0
- mslk/attention/fmha/utils/__init__.py +6 -0
- mslk/attention/fmha/utils/bench.py +74 -0
- mslk/attention/fmha/utils/cpp_lib.py +148 -0
- mslk/attention/fmha/utils/op_common.py +65 -0
- mslk/attention/gqa_attn_splitk/__init__.py +11 -0
- mslk/bench/comm/__init__.py +7 -0
- mslk/bench/comm/comm_bench.py +255 -0
- mslk/bench/common/__init__.py +5 -0
- mslk/bench/common/utils.py +148 -0
- mslk/bench/conv/__init__.py +7 -0
- mslk/bench/conv/conv_bench.py +551 -0
- mslk/bench/conv/conv_ops.py +213 -0
- mslk/bench/gemm/__init__.py +7 -0
- mslk/bench/gemm/gemm_bench.py +859 -0
- mslk/bench/gemm/gemm_ops.py +3342 -0
- mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
- mslk/bench/moe/__init__.py +7 -0
- mslk/bench/moe/gather_scatter_bench.py +356 -0
- mslk/bench/quantize/quantize_bench.py +345 -0
- mslk/bench/quantize/quantize_ops.py +266 -0
- mslk/comm/__init__.py +11 -0
- mslk/conv/__init__.py +11 -0
- mslk/gemm/__init__.py +18 -0
- mslk/gemm/triton/__init__.py +7 -0
- mslk/gemm/triton/fp8_gemm.py +2702 -0
- mslk/gemm/triton/grouped_gemm.py +1132 -0
- mslk/gemm/triton/matmul_perf_model.py +237 -0
- mslk/gemm/triton/utils.py +128 -0
- mslk/kv_cache/__init__.py +11 -0
- mslk/moe/__init__.py +26 -0
- mslk/moe/activation.py +291 -0
- mslk/moe/gather_scatter.py +739 -0
- mslk/moe/layers.py +1240 -0
- mslk/moe/shuffling.py +421 -0
- mslk/mslk.so +0 -0
- mslk/quantize/__init__.py +11 -0
- mslk/quantize/shuffle.py +306 -0
- mslk/quantize/triton/__init__.py +7 -0
- mslk/quantize/triton/fp4_quantize.py +5942 -0
- mslk/quantize/triton/fp8_quantize.py +1902 -0
- mslk/testing/__init__.py +7 -0
- mslk/testing/attributes.py +60 -0
- mslk/testing/rocm.py +91 -0
- mslk/utils/__init__.py +7 -0
- mslk/utils/torch/__init__.py +7 -0
- mslk/utils/torch/library.py +150 -0
- mslk/utils/triton/__init__.py +7 -0
- mslk/utils/triton/fp8_utils.py +72 -0
- mslk/utils/triton/utils.py +128 -0
- mslk/version.py +11 -0
- mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
- mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
- mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
- mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
- mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,860 @@
|
|
|
1
|
+
# @nolint # fbcode
|
|
2
|
+
# Copyright (c) 2025, Tri Dao.
|
|
3
|
+
|
|
4
|
+
import math
|
|
5
|
+
import hashlib
|
|
6
|
+
import inspect
|
|
7
|
+
import re
|
|
8
|
+
from typing import Type, Callable, Optional, Tuple, overload
|
|
9
|
+
from functools import partial
|
|
10
|
+
|
|
11
|
+
import cutlass
|
|
12
|
+
import cutlass.cute as cute
|
|
13
|
+
|
|
14
|
+
from cutlass import Float32, const_expr
|
|
15
|
+
from cutlass.cutlass_dsl import T, dsl_user_op
|
|
16
|
+
from cutlass._mlir.dialects import nvvm, llvm
|
|
17
|
+
from cutlass.cute.runtime import from_dlpack
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
# cute.arch.{fma,mul,add}_packed_f32x2 default to the RZ rounding mode, so every packed
# f32x2 helper exported from this module is pre-bound to round-to-nearest (RN) instead.
fma_packed_f32x2 = partial(cute.arch.fma_packed_f32x2, rnd=nvvm.RoundingModeKind.RN)
mul_packed_f32x2 = partial(cute.arch.mul_packed_f32x2, rnd=nvvm.RoundingModeKind.RN)
add_packed_f32x2 = partial(cute.arch.add_packed_f32x2, rnd=nvvm.RoundingModeKind.RN)
# Subtraction goes through the generic packed-op helper: it is handed the NVVM
# subtraction builder and no third source operand.
sub_packed_f32x2 = partial(
    cute.arch.calc_packed_f32x2_op,
    calc_func=nvvm.sub_packed_f32x2,
    src_c=None,
    rnd=nvvm.RoundingModeKind.RN,
)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def hash_callable(func: Callable, set_cute_hash=True) -> str:
    """Hash a callable based on its source code (or bytecode) and closure values.

    Fast-path: if the callable (or its ``__wrapped__`` base) has a ``__cute_hash__``
    attribute, that value is returned immediately. Code-generation backends such
    as Inductor can set this attribute to avoid expensive runtime hashing.

    Args:
        func: The callable to hash. One level of ``__wrapped__`` is unwrapped, and
            the hash is computed from (and cached on) the unwrapped callable.
        set_cute_hash: Whether to cache the digest on the callable as
            ``__cute_hash__`` when it is not already present. Caching is best-effort:
            callables that reject attribute assignment (builtins, bound methods, ...)
            are simply not cached.

    Returns:
        The hex SHA-256 digest, or the pre-existing ``__cute_hash__`` value.
    """
    if hasattr(func, "__cute_hash__"):
        return func.__cute_hash__

    # Unwrap decorated functions (e.g., cute.jit wrappers).
    if hasattr(func, "__wrapped__"):
        base_func = func.__wrapped__
        if hasattr(base_func, "__cute_hash__"):
            return base_func.__cute_hash__
        func = base_func

    try:
        data = inspect.getsource(func).encode()
    except (OSError, TypeError):
        # No retrievable source: fall back to raw bytecode, then to repr().
        code = getattr(func, "__code__", None)
        data = code.co_code if code is not None else repr(func).encode()

    hasher = hashlib.sha256(data)

    # Two closures built from the same source but capturing different values
    # must hash differently, so fold the captured values in as well.
    closure = getattr(func, "__closure__", None)
    if closure is not None:
        for cell in closure:
            hasher.update(repr(cell.cell_contents).encode())

    digest = hasher.hexdigest()

    if set_cute_hash:
        try:
            func.__cute_hash__ = digest
        except (AttributeError, TypeError):
            # The cache is only an optimisation; the digest is still returned.
            pass

    return digest
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def create_softcap_scoremod(softcap_val):
    """Build a score-mod callback that applies tanh soft-capping to attention scores.

    Soft-capping maps a score ``s`` to ``softcap_val * tanh(s / softcap_val)``: close
    to the identity for small ``|s|`` and saturating at ``+/-softcap_val``.

    Args:
        softcap_val: The cap value; must be non-zero because its reciprocal is taken.

    Returns:
        A ``cute.jit`` function taking
        ``(acc_S_SSA, batch_idx, head_idx, q_idx, kv_idx, aux_tensors)``. Only
        ``acc_S_SSA`` is used; the other parameters are accepted and ignored.
    """
    inv_softcap = 1.0 / softcap_val

    @cute.jit
    def scoremod_premask_fn(acc_S_SSA, batch_idx, head_idx, q_idx, kv_idx, aux_tensors):
        scores = acc_S_SSA * inv_softcap
        # Rescale the tanh output by the cap itself. Multiplying it by the already
        # divided scores instead yields (s/c) * tanh(s/c), which is neither bounded
        # nor sign-preserving.
        return cute.math.tanh(scores, fastmath=True) * softcap_val

    return scoremod_premask_fn
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Tensor:
    """Wrap ``x`` as a CuTe tensor with a dynamic layout around ``leading_dim``.

    Args:
        x: Object accepted by ``from_dlpack``; it must also provide ``dim_order()``.
        leading_dim: Mode passed to both ``mark_layout_dynamic`` and
            ``mark_compact_shape_dynamic``.
        alignment: Forwarded to ``from_dlpack`` as ``assumed_align``.
        divisibility: Forwarded to ``mark_compact_shape_dynamic``.

    Returns:
        The tensor produced by chaining the two ``mark_*_dynamic`` calls.
    """
    tensor = from_dlpack(x, assumed_align=alignment)
    tensor = tensor.mark_layout_dynamic(leading_dim=leading_dim)
    return tensor.mark_compact_shape_dynamic(
        mode=leading_dim,
        stride_order=x.dim_order(),
        divisibility=divisibility,
    )
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def convert_from_dlpack_leading_static(
    x, leading_dim, alignment=16, static_modes=None, stride_order=None
) -> cute.Tensor:
    """Wrap ``x`` as a CuTe tensor, keeping the leading dim (and chosen modes) static.

    Every mode other than ``leading_dim`` and those listed in ``static_modes`` is
    marked dynamic via ``mark_compact_shape_dynamic``.

    Args:
        x: Object accepted by ``from_dlpack``; must expose ``ndim`` and ``dim_order()``.
        leading_dim: Mode index that is never marked dynamic.
        alignment: Forwarded to ``from_dlpack`` as ``assumed_align``.
        static_modes: Optional container of additional mode indices to leave static.
        stride_order: Stride order to use; defaults to ``x.dim_order()``.
    """
    order = x.dim_order() if stride_order is None else stride_order
    tensor = from_dlpack(x, assumed_align=alignment)
    keep_static = () if static_modes is None else static_modes
    dynamic_modes = [
        mode for mode in range(x.ndim) if mode != leading_dim and mode not in keep_static
    ]
    for mode in dynamic_modes:
        tensor = tensor.mark_compact_shape_dynamic(mode=mode, stride_order=order)
    return tensor
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def make_tiled_copy_A(
    copy_atom: cute.CopyAtom, tiled_mma: cute.TiledMma, swapAB: cutlass.Constexpr[bool] = False
) -> cute.TiledCopy:
    """Build the tiled copy for operand A, or for operand B when ``swapAB`` is set."""
    builder = cute.make_tiled_copy_B if const_expr(swapAB) else cute.make_tiled_copy_A
    return builder(copy_atom, tiled_mma)
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def make_tiled_copy_B(
    copy_atom: cute.CopyAtom, tiled_mma: cute.TiledMma, swapAB: cutlass.Constexpr[bool] = False
) -> cute.TiledCopy:
    """Build the tiled copy for operand B, or for operand A when ``swapAB`` is set."""
    builder = cute.make_tiled_copy_A if const_expr(swapAB) else cute.make_tiled_copy_B
    return builder(copy_atom, tiled_mma)
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def mma_make_fragment_A(
    smem: cute.Tensor, thr_mma: cute.core.ThrMma, swapAB: cutlass.Constexpr[bool] = False
) -> cute.Tensor:
    """Partition ``smem`` as operand A and make its MMA fragment.

    With ``swapAB`` set, the work is delegated to the (unswapped) B variant instead.
    """
    if const_expr(swapAB):
        return mma_make_fragment_B(smem, thr_mma)
    partitioned = thr_mma.partition_A(smem)
    return thr_mma.make_fragment_A(partitioned)
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def mma_make_fragment_B(
    smem: cute.Tensor, thr_mma: cute.core.ThrMma, swapAB: cutlass.Constexpr[bool] = False
) -> cute.Tensor:
    """Partition ``smem`` as operand B and make its MMA fragment.

    With ``swapAB`` set, the work is delegated to the (unswapped) A variant instead.
    """
    if const_expr(swapAB):
        return mma_make_fragment_A(smem, thr_mma)
    partitioned = thr_mma.partition_B(smem)
    return thr_mma.make_fragment_B(partitioned)
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def get_smem_store_atom(
    arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric], transpose: bool = False
) -> cute.CopyAtom:
    """Pick the copy atom used to store ``element_type`` values.

    For ``arch >= 90`` with 16-bit elements an ``StMatrix8x8x16bOp`` atom
    (4 matrices, optionally transposed) is returned. Every other combination gets
    a universal copy atom that moves two elements per copy; ``transpose`` is
    ignored in that case.
    """
    if const_expr(arch >= 90 and element_type.width == 16):
        stmatrix_op = cute.nvgpu.warp.StMatrix8x8x16bOp(transpose=transpose, num_matrices=4)
        return cute.make_copy_atom(stmatrix_op, element_type)
    else:
        bits_per_copy = 2 * element_type.width
        return cute.make_copy_atom(
            cute.nvgpu.CopyUniversalOp(),
            element_type,
            num_bits_per_copy=bits_per_copy,
        )
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
@cute.jit
def warp_reduce(
    val: cute.TensorSSA | cute.Numeric,
    op: Callable,
    width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE,
) -> cute.TensorSSA | cute.Numeric:
    """Combine ``val`` across lanes with ``op`` using butterfly shuffles.

    A scalar goes through ``log2(width)`` rounds; round ``k`` combines the value
    with the one obtained from ``shuffle_sync_bfly`` at offset ``2**k``. A
    ``TensorSSA`` is reduced element by element through the scalar path.

    Args:
        val: Scalar or tensor value to reduce.
        op: Binary callable applied as ``op(current, shuffled)``.
        width: Compile-time group width; presumably a power of two, since the
            round count is ``int(log2(width))`` -- TODO confirm.
    """
    if const_expr(isinstance(val, cute.TensorSSA)):
        frag = cute.make_fragment(val.shape, val.dtype)
        frag.store(val)
        for elem in cutlass.range_constexpr(cute.size(val.shape)):
            frag[elem] = warp_reduce(frag[elem], op, width)
        return frag.load()
    else:
        for step in cutlass.range_constexpr(int(math.log2(width))):
            shuffled = cute.arch.shuffle_sync_bfly(val, offset=1 << step)
            val = op(val, shuffled)
        return val
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
def convert_layout_acc_mn(acc_layout: cute.Layout, transpose: bool = False) -> cute.Layout:
    """Regroup an accumulator layout into (row, column) modes.

    For Sm80, convert ((2, 2), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, MMA_N), ...).
    For Sm90, convert ((2, 2, V), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, V, MMA_N), ...).

    With ``transpose`` set, the two leading modes of the result are swapped.
    """
    col_major = cute.make_layout(acc_layout.shape)

    def regroup(modes):
        # modes[0] is the per-instruction fragment: its entry 1 pairs with MMA_M,
        # while entry 0 and any trailing entries pair with MMA_N.
        frag = modes[0]
        return (
            (frag[1], modes[1]),  # MMA_M
            (frag[0], *frag[2:], modes[2]),  # MMA_N
            *modes[3:],
        )

    shape = regroup(col_major.shape)
    stride = regroup(col_major.stride)
    if const_expr(transpose):
        shape = (shape[1], shape[0], *shape[2:])
        stride = (stride[1], stride[0], *stride[2:])
    mn_layout = cute.make_layout(shape, stride=stride)
    return cute.composition(acc_layout, mn_layout)
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def make_acc_tensor_mn_view(acc: cute.Tensor, transpose: bool = False) -> cute.Tensor:
    """Return a tensor over ``acc``'s iterator using the ``convert_layout_acc_mn`` layout."""
    mn_layout = convert_layout_acc_mn(acc.layout, transpose=transpose)
    return cute.make_tensor(acc.iterator, mn_layout)
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
@cute.jit
def convert_layout_acc_frgA(acc_layout: cute.Layout) -> cute.Layout:
    """Convert the accumulator layout of one GEMM into the A-operand layout of the next.

    Used for back-to-back GEMMs.

    * Sm80: the mma instruction shape is 16x8x16, so (4, MMA_M, MMA_N) becomes
      ((4, 2), MMA_M, MMA_N / 2).
    * Sm90, FP16/BF16: ((2, 2, N / 8), MMA_M, MMA_N) becomes
      ((2, 2, 2), MMA_M, (N / 16, MMA_N)).

    The Sm90 path is selected when mode 0 of the layout has rank 3.
    """
    # TODO: Sm90 FP8
    if const_expr(cute.rank(acc_layout.shape[0]) == 3):  # Sm90
        # ((2, 2, (2, N / 16)), MMA_M, MMA_N)
        split = cute.logical_divide(acc_layout, ((None, None, 2), None, None))
        frag_shape, frag_stride = split.shape[0], split.stride[0]
        rA_mma_view = cute.make_layout(
            (
                (frag_shape[0], frag_shape[1], frag_shape[2][0]),
                split.shape[1],
                (frag_shape[2][1], split.shape[2]),
            ),
            stride=(
                (frag_stride[0], frag_stride[1], frag_stride[2][0]),
                split.stride[1],
                (frag_stride[2][1], split.stride[2]),
            ),
        )
    else:  # Sm80
        # (4, MMA_M, MMA_N) -> (4, MMA_M, (2, MMA_N / 2))
        split = cute.logical_divide(acc_layout, (None, None, 2))
        n_shape, n_stride = split.shape[2], split.stride[2]
        rA_mma_view = cute.make_layout(
            (
                (split.shape[0], n_shape[0]),
                split.shape[1],
                n_shape[1],
            ),
            stride=(
                (split.stride[0], n_stride[0]),
                split.stride[1],
                n_stride[1],
            ),
        )
    return rA_mma_view
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
def make_acc_tensor_frgA_view(acc: cute.Tensor) -> cute.Tensor:
    """Return a tensor over ``acc``'s iterator using the ``convert_layout_acc_frgA`` layout."""
    frg_a_layout = convert_layout_acc_frgA(acc.layout)
    return cute.make_tensor(acc.iterator, frg_a_layout)
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
def select(a: cute.Tensor, mode: list[int]) -> cute.Tensor:
    """Return a tensor over ``a``'s iterator whose layout is ``cute.select(a.layout, mode)``."""
    picked_layout = cute.select(a.layout, mode)
    return cute.make_tensor(a.iterator, picked_layout)
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
def transpose_view(a: cute.Tensor) -> cute.Tensor:
    """Transpose the first two dimensions of a tensor on smem.

    The view is built by composing ``a`` with an ordered layout whose first two
    modes are exchanged; any further modes keep their position.
    """
    swapped_shape = (a.shape[1], a.shape[0], *a.shape[2:])
    mode_order = (1, 0, *range(2, cute.rank(a)))
    swap_layout = cute.make_ordered_layout(swapped_shape, order=mode_order)
    return cute.composition(a, swap_layout)
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
def parse_swizzle_from_pointer(ptr: cute.Pointer) -> cute.Swizzle:
    """Recover the swizzle of a pointer from the text of its ``swizzle_type``.

    The swizzle_type string has the form '!cute.swizzle<"S<b,m,s>">' where
    b, m, s are the swizzle parameters (bits, base, shift). There is no direct
    accessor being used here, so the three integers are parsed out of that string.

    Returns:
        A cute.Swizzle object constructed from the extracted parameters

    Raises:
        ValueError: If the swizzle_type string cannot be parsed
    """
    swizzle_str = str(ptr.type.swizzle_type)
    # Look for the inner "S<b,m,s>" part anywhere in the string.
    found = re.search(r"S<(\d+),(\d+),(\d+)>", swizzle_str)
    if found is None:
        raise ValueError(f"Could not parse swizzle_type: {swizzle_str}")
    bits, base, shift = (int(group) for group in found.groups())
    return cute.make_swizzle(bits, base, shift)
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
@cute.jit
def exp2f(x: cute.TensorSSA | Float32) -> cute.TensorSSA | Float32:
    """Compute 2**x for a scalar or, element by element, for a vector.

    :param x: input value
    :type x: cute.TensorSSA or Float32
    :return: exp2 value
    :rtype: cute.TensorSSA or Float32
    """
    if const_expr(isinstance(x, cute.TensorSSA)):
        # Spill the vector into a Float32 fragment so each element can be fed
        # to the scalar cute.arch.exp2 individually.
        frag = cute.make_fragment(x.shape, Float32)
        frag.store(x)
        for elem in cutlass.range_constexpr(cute.size(x.shape)):
            frag[elem] = cute.arch.exp2(frag[elem])
        return frag.load()
    else:
        return cute.arch.exp2(x)
|
|
309
|
+
|
|
310
|
+
|
|
311
|
+
@dsl_user_op
def log2f(a: float | Float32, *, loc=None, ip=None) -> Float32:
    """Return log2(a) computed by the PTX ``lg2.approx.ftz.f32`` instruction.

    ``a`` is converted to Float32 first; the instruction is emitted as
    side-effect-free inline assembly.
    """
    operand = Float32(a).ir_value(loc=loc, ip=ip)
    result = llvm.inline_asm(
        T.f32(),
        [operand],
        "lg2.approx.ftz.f32 $0, $1;",
        "=f,f",
        has_side_effects=False,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
    return Float32(result)
|
|
324
|
+
|
|
325
|
+
|
|
326
|
+
@dsl_user_op
def logf(a: float | Float32, *, loc=None, ip=None) -> Float32:
    """Return the natural log of ``a``, computed as ``log2f(a) * ln(2)``."""
    base2_log = log2f(a, loc=loc, ip=ip)
    return base2_log * math.log(2.0)
|
|
329
|
+
|
|
330
|
+
|
|
331
|
+
@dsl_user_op
def fmax(
    a: float | Float32, b: float | Float32, c: float | Float32 | None = None, *, loc=None, ip=None
) -> Float32:
    """Return the maximum of two or three Float32 values through ``nvvm.fmax``.

    Args:
        a: First operand (converted to Float32).
        b: Second operand (converted to Float32).
        c: Optional third operand; when omitted, ``None`` is forwarded to
            ``nvvm.fmax`` in its place.
    """
    lhs = Float32(a).ir_value(loc=loc, ip=ip)
    rhs = Float32(b).ir_value(loc=loc, ip=ip)
    third = Float32(c).ir_value(loc=loc, ip=ip) if c is not None else None
    return Float32(nvvm.fmax(T.f32(), lhs, rhs, c=third, loc=loc, ip=ip))
|
|
345
|
+
|
|
346
|
+
|
|
347
|
+
@cute.jit
def fmax_reduce(
    x: cute.TensorSSA, init_val: float | Float32 | None = None, arch: cutlass.Constexpr[int] = 80
) -> Float32:
    """Reduce ``x`` to its maximum, optionally folding in ``init_val``.

    Two compile-time strategies are used:

    * ``arch < 100`` or an element count that is not a multiple of 8: four
      running maxima are updated with 2-input ``fmax`` over groups of four
      elements. NOTE(review): elements 0..3 are read unconditionally and the loop
      strides by 4, so this path appears to expect a count that is a non-zero
      multiple of 4 -- confirm against callers.
    * otherwise: four running maxima are updated with 3-input ``fmax`` over groups
      of eight elements.
    """
    if const_expr(arch < 100 or cute.size(x.shape) % 8 != 0):
        vals = cute.make_fragment(x.shape, Float32)
        vals.store(x)
        running = [vals[0], vals[1], vals[2], vals[3]]
        for base in cutlass.range_constexpr(4, cute.size(x.shape), 4):
            running[0] = fmax(running[0], vals[base + 0])
            running[1] = fmax(running[1], vals[base + 1])
            running[2] = fmax(running[2], vals[base + 2])
            running[3] = fmax(running[3], vals[base + 3])
        # Combine the four independent chains pairwise.
        running[0] = fmax(running[0], running[1])
        running[2] = fmax(running[2], running[3])
        running[0] = fmax(running[0], running[2])
        return running[0] if const_expr(init_val is None) else fmax(running[0], init_val)
    else:
        # [2025-06-15] x.reduce only seems to use 50% 3-input max and 50% 2-input max
        # We instead force the 3-input max.
        vals = cute.make_fragment(x.shape, Float32)
        vals.store(x)
        first = (
            fmax(init_val, vals[0], vals[1])
            if const_expr(init_val is not None)
            else fmax(vals[0], vals[1])
        )
        running = [
            first,
            fmax(vals[2], vals[3]),
            fmax(vals[4], vals[5]),
            fmax(vals[6], vals[7]),
        ]
        for base in cutlass.range_constexpr(8, cute.size(x.shape), 8):
            running[0] = fmax(running[0], vals[base], vals[base + 1])
            running[1] = fmax(running[1], vals[base + 2], vals[base + 3])
            running[2] = fmax(running[2], vals[base + 4], vals[base + 5])
            running[3] = fmax(running[3], vals[base + 6], vals[base + 7])
        running[0] = fmax(running[0], running[1])
        return fmax(running[0], running[2], running[3])
|
|
396
|
+
|
|
397
|
+
|
|
398
|
+
@cute.jit
def fadd_reduce(
    x: cute.TensorSSA, init_val: float | Float32 | None = None, arch: cutlass.Constexpr[int] = 80
) -> Float32:
    """Sum the elements of ``x``, optionally starting from ``init_val``.

    * ``arch < 100`` or an element count that is not a multiple of 8: defer to
      ``x.reduce`` with ``ReductionOp.ADD`` (``init_val`` defaults to
      ``Float32.zero``).
    * otherwise: keep four running pairs and accumulate eight elements per step
      with ``add_packed_f32x2``, then fold the pairs together and add the two
      halves of the final pair.
    """
    if const_expr(arch < 100 or cute.size(x.shape) % 8 != 0):
        if const_expr(init_val is None):
            init_val = Float32.zero
        return x.reduce(cute.ReductionOp.ADD, init_val, 0)
    else:
        vals = cute.make_fragment(x.shape, Float32)
        vals.store(x)
        # init_val, when given, is added into the first half of pair 0 only.
        first_pair = (
            add_packed_f32x2((init_val, 0.0), (vals[0], vals[1]))
            if const_expr(init_val is not None)
            else (vals[0], vals[1])
        )
        pairs = [first_pair, (vals[2], vals[3]), (vals[4], vals[5]), (vals[6], vals[7])]
        for base in cutlass.range_constexpr(8, cute.size(x.shape), 8):
            pairs[0] = add_packed_f32x2(pairs[0], (vals[base + 0], vals[base + 1]))
            pairs[1] = add_packed_f32x2(pairs[1], (vals[base + 2], vals[base + 3]))
            pairs[2] = add_packed_f32x2(pairs[2], (vals[base + 4], vals[base + 5]))
            pairs[3] = add_packed_f32x2(pairs[3], (vals[base + 6], vals[base + 7]))
        pairs[0] = add_packed_f32x2(pairs[0], pairs[1])
        pairs[2] = add_packed_f32x2(pairs[2], pairs[3])
        pairs[0] = add_packed_f32x2(pairs[0], pairs[2])
        return pairs[0][0] + pairs[0][1]
|
|
437
|
+
|
|
438
|
+
|
|
439
|
+
@dsl_user_op
def atomic_add_fp32(a: float | Float32, gmem_ptr: cute.Pointer, *, loc=None, ip=None) -> None:
    """Atomically add ``a`` to the float32 value that ``gmem_ptr`` points to.

    The addition is emitted as an ``nvvm.atomicrmw`` with ``AtomicOpKind.FADD``;
    the value produced by the op is discarded.

    Args:
        a: Addend, converted to Float32.
        gmem_ptr: Destination pointer; its ``llvm_ptr`` is handed to the op.
        loc: Source location forwarded to the emitted IR.
        ip: Insertion point forwarded to the emitted IR.
    """
    # loc/ip are forwarded like in the other dsl_user_op helpers of this module, so
    # the emitted ops carry the caller's location and land at the requested
    # insertion point.
    nvvm.atomicrmw(
        res=T.f32(),
        op=nvvm.AtomicOpKind.FADD,
        ptr=gmem_ptr.llvm_ptr,
        a=Float32(a).ir_value(loc=loc, ip=ip),
        loc=loc,
        ip=ip,
    )
|
|
459
|
+
|
|
460
|
+
|
|
461
|
+
@dsl_user_op
def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cute.Pointer:
    """Return a pointer to the element of ``x`` at ``coord``.

    The element offset is ``cute.crd2idx(coord, x.layout)``, added to ``x.iterator``.
    """
    elem_offset = cute.crd2idx(coord, x.layout, loc=loc, ip=ip)
    return x.iterator + elem_offset
|
|
464
|
+
|
|
465
|
+
|
|
466
|
+
@dsl_user_op
def elem_pointer_i64(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cute.Pointer:
    """Return a pointer to the element of ``x`` at ``coord`` using Int64 offset math.

    Each flattened coordinate is widened to ``Int64`` before being multiplied by
    its stride. The byte offset is added to the integer value of ``x.iterator``
    and a new pointer is made in the same memory space.
    """
    coords_i64 = tuple(cutlass.Int64(c) for c in cute.flatten(coord))
    strides = cute.flatten_to_tuple(x.stride)
    assert len(coords_i64) == len(strides), (
        "Coordinate and stride must have the same length"
    )
    elem_offset = sum(c * s for c, s in zip(coords_i64, strides))
    # HACK: we assume that applying the offset does not change the pointer alignment
    byte_offset = elem_offset * x.element_type.width // 8
    target_addr = x.iterator.toint() + byte_offset
    return cute.make_ptr(
        x.element_type,
        target_addr,
        x.memspace,
        assumed_align=x.iterator.alignment,
    )
|
|
482
|
+
|
|
483
|
+
|
|
484
|
+
@cute.jit
def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor:
    """Build a Boolean predicate fragment for the "k" dimension of ``tAcA``.

    Only the "k" dimension is predicated; the mn dimension is expected to be
    handled with an "if" by the caller. Mode 1 of the result has stride 0, so a
    single predicate is shared by every index along that mode.

    Args:
        tAcA: Coordinate tensor; component ``[1]`` of each coordinate is compared
            against ``limit``.
        limit: Exclusive upper bound for that coordinate component.
    """
    num_k = cute.size(tAcA, mode=[2])
    pred_layout = cute.make_layout(
        (cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), num_k),
        stride=(num_k, 0, 1),
    )
    tApA = cute.make_fragment(pred_layout, cutlass.Boolean)
    for v in cutlass.range_constexpr(tApA.shape[0]):
        for k in cutlass.range_constexpr(tApA.shape[2]):
            tApA[v, 0, k] = cute.elem_less(tAcA[(0, v), 0, k][1], limit)
    return tApA
|
|
498
|
+
|
|
499
|
+
|
|
500
|
+
def canonical_warp_group_idx(sync: bool = True) -> cutlass.Int32:
    """Return the x thread index divided by 128.

    With ``sync`` set (the default) the value is additionally passed through
    ``cute.arch.make_warp_uniform``.
    """
    group_idx = cute.arch.thread_idx()[0] // 128
    return cute.arch.make_warp_uniform(group_idx) if const_expr(sync) else group_idx
|
|
505
|
+
|
|
506
|
+
|
|
507
|
+
# @dsl_user_op
|
|
508
|
+
# def warp_vote_any_lt(a: float | Float32, b: float | Float32, *, loc=None, ip=None) -> cutlass.Boolean:
|
|
509
|
+
# mask = cutlass.Int32(-1)
|
|
510
|
+
# return cutlass.Boolean(
|
|
511
|
+
# llvm.inline_asm(
|
|
512
|
+
# T.i32(),
|
|
513
|
+
# [Float32(a).ir_value(loc=loc, ip=ip), Float32(b).ir_value(loc=loc, ip=ip), mask.ir_value(loc=loc, ip=ip)],
|
|
514
|
+
# ".pred p1, p2;\n"
|
|
515
|
+
# "setp.lt.f32 p1, $1, $2;\n"
|
|
516
|
+
# "vote.sync.any.pred p2, p1, $3;\n"
|
|
517
|
+
# "selp.u32 $0, 1, 0, p2;",
|
|
518
|
+
# # "selp.u32 $0, 1, 0, p1;",
|
|
519
|
+
# "=r,f,f,r",
|
|
520
|
+
# has_side_effects=False,
|
|
521
|
+
# is_align_stack=False,
|
|
522
|
+
# asm_dialect=llvm.AsmDialect.AD_ATT,
|
|
523
|
+
# )
|
|
524
|
+
# )
|
|
525
|
+
|
|
526
|
+
|
|
527
|
+
@cute.jit
def shuffle_sync(
    value: cute.Numeric,
    offset: cute.typing.Int,
    width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE,
) -> cute.Numeric:
    """Shuffle a value of any 32-bit-multiple width, one 32-bit word at a time.

    The value is placed in a one-element register tensor, reinterpreted as
    ``Int32`` words, and every word goes through ``cute.arch.shuffle_sync`` with
    the same ``offset``.

    Args:
        value: Value to exchange; its type width must be a multiple of 32 bits.
        offset: Passed through to ``cute.arch.shuffle_sync``.
        width: Compile-time group width used to derive the shuffle mask.
    """
    assert value.width % 32 == 0, "value type must be a multiple of 32 bits"
    # width -> mask: 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000,
    # 16 -> 0b10000, 32 -> 0b00000
    segment_mask = cute.arch.WARP_SIZE - width
    lane_clamp = cute.arch.WARP_SIZE - 1
    mask_and_clamp = (segment_mask << 8) | lane_clamp
    # important: need stride 1 and not 0 for recast_tensor to work
    holder = cute.make_rmem_tensor(cute.make_layout((1,), stride=(1,)), type(value))
    holder[0] = value
    words = cute.recast_tensor(holder, cutlass.Int32)
    for w in cutlass.range_constexpr(cute.size(words)):
        words[w] = cute.arch.shuffle_sync(words[w], offset, mask_and_clamp=mask_and_clamp)
    return holder[0]
|
|
545
|
+
|
|
546
|
+
|
|
547
|
+
@dsl_user_op
def shr_u32(val: cutlass.Uint32, shift: cutlass.Uint32, *, loc=None, ip=None) -> cutlass.Uint32:
    """Logically shift the unsigned 32-bit ``val`` right by ``shift`` bits.

    Emitted as side-effect-free PTX inline assembly.

    Args:
        val: Value to shift, converted to Uint32.
        shift: Shift amount, converted to Uint32.

    Returns:
        The shifted value as a Uint32.
    """
    return cutlass.Uint32(
        llvm.inline_asm(
            T.i32(),
            [
                cutlass.Uint32(val).ir_value(loc=loc, ip=ip),
                cutlass.Uint32(shift).ir_value(loc=loc, ip=ip),
            ],
            # Must be the unsigned form: shr.s32 fills the vacated bits with the
            # sign bit, which corrupts any operand whose top bit is set.
            "shr.u32 $0, $1, $2;",
            "=r,r,r",
            has_side_effects=False,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    )
|
|
563
|
+
|
|
564
|
+
|
|
565
|
+
@cute.jit
def warp_prefix_sum(val: cutlass.Int32, lane: Optional[cutlass.Int32] = None) -> cutlass.Int32:
    """Accumulate ``val`` across lower-numbered lanes with shuffle-up steps.

    For offsets 1, 2, 4, ... (``log2(WARP_SIZE)`` rounds), each lane fetches the
    running value from ``offset`` lanes below via ``shuffle_sync_up`` and adds it
    when ``lane >= offset``.

    Args:
        val: This lane's contribution.
        lane: This lane's index; defaults to ``cute.arch.lane_idx()``.

    Returns:
        This lane's accumulated value after all rounds.
    """
    if const_expr(lane is None):
        lane = cute.arch.lane_idx()
    for step in cutlass.range_constexpr(int(math.log2(cute.arch.WARP_SIZE))):
        offset = 1 << step
        # Very important that we set mask_and_clamp to 0
        from_below = cute.arch.shuffle_sync_up(val, offset=offset, mask_and_clamp=0)
        if lane >= offset:
            val += from_below
    return val
|
|
578
|
+
|
|
579
|
+
|
|
580
|
+
@dsl_user_op
def cvt_f16x2_f32(
    a: float | Float32, b: float | Float32, to_dtype: Type, *, loc=None, ip=None
) -> cutlass.Int32:
    """Convert two Float32 values to a packed 16-bit pair held in an ``Int32``.

    A single ``cvt.rn.{bf16x2,f16x2}.f32`` PTX instruction is emitted with
    round-to-nearest. The sources are given to the instruction in the order
    ``b, a`` (``$2, $1``). NOTE(review): per the PTX ``cvt`` packing rules this
    should leave ``a`` in the low half of the result -- confirm against the ISA.

    Args:
        a: First value (asm operand ``$1``).
        b: Second value (asm operand ``$2``).
        to_dtype: ``cutlass.BFloat16`` or ``cutlass.Float16``.
    """
    assert to_dtype in [cutlass.BFloat16, cutlass.Float16], "to_dtype must be BFloat16 or Float16"
    operands = [Float32(a).ir_value(loc=loc, ip=ip), Float32(b).ir_value(loc=loc, ip=ip)]
    packed = llvm.inline_asm(
        T.i32(),
        operands,
        f"cvt.rn.{'bf16x2' if to_dtype is cutlass.BFloat16 else 'f16x2'}.f32 $0, $2, $1;",
        "=r,f,f",
        has_side_effects=False,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
    return cutlass.Int32(packed)
|
|
596
|
+
|
|
597
|
+
|
|
598
|
+
@overload
def cvt_f16(src: cute.Tensor, dst: cute.Tensor) -> None: ...


@overload
def cvt_f16(src: cute.Tensor, dtype: Type[cute.Numeric]) -> cute.Tensor: ...


@cute.jit
def cvt_f16(src: cute.Tensor, dst_or_dtype):
    """Convert a Float32 tensor to Float16/BFloat16, two elements per instruction.

    Args:
        src: Source tensor with Float32 element type and an even element count.
        dst_or_dtype: Either a destination tensor (Float16/BFloat16, same size as
            ``src``) that is written in place, or a dtype for which a new fragment
            of ``src``'s shape is allocated.

    Returns:
        None if dst is a tensor, or a new tensor if dtype is provided
    """
    if const_expr(isinstance(dst_or_dtype, type)):
        # dtype variant: allocate the output, then reuse the tensor variant.
        out = cute.make_fragment(src.shape, dst_or_dtype)
        cvt_f16(src, out)
        return out
    else:
        # tensor variant: validate, then write into the given destination.
        out = dst_or_dtype
        assert cute.size(out.shape) == cute.size(src.shape), "dst and src must have the same size"
        assert cute.size(src.shape) % 2 == 0, "src must have an even number of elements"
        assert out.element_type in [cutlass.BFloat16, cutlass.Float16], (
            "dst must be BFloat16 or Float16"
        )
        assert src.element_type is Float32, "src must be Float32"
        # View the 16-bit destination as Int32 words; each word holds one
        # converted pair.
        out_words = cute.recast_tensor(out, cutlass.Int32)
        assert cute.size(out_words.shape) * 2 == cute.size(src.shape)
        for w in cutlass.range_constexpr(cute.size(out_words)):
            out_words[w] = cvt_f16x2_f32(src[2 * w], src[2 * w + 1], out.element_type)
|
|
636
|
+
|
|
637
|
+
|
|
638
|
+
@dsl_user_op
@cute.jit
def evaluate_polynomial(x: Float32, poly: Tuple[Float32, ...], *, loc=None, ip=None) -> Float32:
    """Evaluate a polynomial at ``x`` with Horner's scheme.

    Args:
        x: Evaluation point.
        poly: Coefficients ordered from the constant term upwards, so
            ``poly[i]`` multiplies ``x**i``.

    Returns:
        ``poly[0] + poly[1] * x + ... + poly[-1] * x**(len(poly) - 1)``.
    """
    highest = len(poly) - 1
    acc = poly[highest]
    # Fold in the remaining coefficients from the second-highest down to poly[0].
    for k in cutlass.range_constexpr(highest - 1, -1, -1):
        acc = acc * x + poly[k]
    return acc
|
|
646
|
+
|
|
647
|
+
|
|
648
|
+
@dsl_user_op
@cute.jit
def evaluate_polynomial_2(
    x: Float32, y: Float32, poly: Tuple[Float32, ...], *, loc=None, ip=None
) -> Tuple[Float32, Float32]:
    """Evaluate the same polynomial at ``x`` and ``y`` using packed f32x2 FMAs.

    Args:
        x: First evaluation point.
        y: Second evaluation point.
        poly: Coefficients ordered from the constant term upwards, so
            ``poly[i]`` multiplies the ``i``-th power.

    Returns:
        The pair ``(p(x), p(y))``.
    """
    highest = len(poly) - 1
    points = (x, y)
    acc = (poly[highest], poly[highest])
    # Horner step on both lanes at once: acc = acc * (x, y) + (c, c).
    for k in cutlass.range_constexpr(highest - 1, -1, -1):
        acc = fma_packed_f32x2(acc, points, (poly[k], poly[k]))
    return acc
|
|
658
|
+
|
|
659
|
+
|
|
660
|
+
@dsl_user_op
def add_round_down(x: float | Float32, y: float | Float32, *, loc=None, ip=None) -> Float32:
    """Return ``x + y`` computed by the PTX ``add.rm.ftz.f32`` instruction.

    The rounding mode is spelled out in the instruction (``.rm``) rather than
    left to the default used by a plain fp32 add.
    """
    # There's probably a way to call llvm or nvvm to do this instead of ptx
    lhs = Float32(x).ir_value(loc=loc, ip=ip)
    rhs = Float32(y).ir_value(loc=loc, ip=ip)
    total = llvm.inline_asm(
        T.f32(),
        [lhs, rhs],
        "add.rm.ftz.f32 $0, $1, $2;",
        "=f,f,f",
        has_side_effects=False,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
    return cutlass.Float32(total)
|
|
674
|
+
|
|
675
|
+
|
|
676
|
+
@dsl_user_op
def combine_int_frac_ex2(x_rounded: Float32, frac_ex2: Float32, *, loc=None, ip=None) -> Float32:
    """Add the bits of ``x_rounded``, shifted left by 23, to the bits of ``frac_ex2``.

    Both inputs are reinterpreted as 32-bit integers. The shifted bits of
    ``x_rounded`` are added to the raw bits of ``frac_ex2`` with an integer
    add, and the sum is reinterpreted as fp32.

    Args:
        x_rounded: fp32 value whose raw bits are shifted left by 23.
        frac_ex2: fp32 value whose raw bits receive that shifted addend.

    Returns:
        The fp32 value whose bit pattern is the integer sum.
    """
    ptx = (
        "{\n\t"
        ".reg .s32 x_rounded_i, frac_ex_i, x_rounded_e, out_i;\n\t"
        "mov.b32 x_rounded_i, $1;\n\t"
        "mov.b32 frac_ex_i, $2;\n\t"
        "shl.b32 x_rounded_e, x_rounded_i, 23;\n\t"
        # add.u32 generates IMAD instruction and add.s32 generates LEA instruction
        # IMAD uses the FMA pipeline and LEA uses the ALU pipeline, afaik
        "add.s32 out_i, x_rounded_e, frac_ex_i;\n\t"
        "mov.b32 $0, out_i;\n\t"
        "}\n"
    )
    operands = [
        Float32(x_rounded).ir_value(loc=loc, ip=ip),
        Float32(frac_ex2).ir_value(loc=loc, ip=ip),
    ]
    combined = llvm.inline_asm(
        T.f32(),
        operands,
        ptx,
        "=f,f,f",
        has_side_effects=False,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
    return cutlass.Float32(combined)
|
|
701
|
+
|
|
702
|
+
|
|
703
|
+
@dsl_user_op
def ex2_emulation(x: Float32, *, loc=None, ip=None) -> Float32:
    """Approximate ``2**x`` without the hardware ex2 unit.

    ``x`` is clamped from below at -127 and split into ``floor(x)`` plus a
    fractional part in ``[0, 1)``. ``2**frac`` comes from a degree-3 polynomial,
    and ``floor(x)`` is then merged in by ``combine_int_frac_ex2``.

    Args:
        x: Exponent; assumed to satisfy ``x <= 127.0``.

    Returns:
        The approximation of ``2**x``.
    """
    # We assume x <= 127.0
    # Coefficients of the cubic for 2**f, constant term first.
    coeffs = (
        1.0,
        0.695146143436431884765625,
        0.227564394474029541015625,
        0.077119089663028717041015625,
    )
    # Adding 2**23 + 2**22 leaves no fractional bits in an fp32 result.
    magic = float(2**23 + 2**22)
    clamped = cute.arch.fmax(x, -127.0)
    # We want to round down here, so that the fractional part is in [0, 1)
    biased_floor = add_round_down(clamped, magic, loc=loc, ip=ip)
    # The integer floor of x is now in the last 8 bits of biased_floor
    # We assume the next 2 ops round to nearest even. The rounding mode is important.
    floor_value = biased_floor - magic
    frac = clamped - floor_value
    frac_ex2 = evaluate_polynomial(frac, coeffs, loc=loc, ip=ip)
    return combine_int_frac_ex2(biased_floor, frac_ex2, loc=loc, ip=ip)
|
|
722
|
+
|
|
723
|
+
|
|
724
|
+
# TODO: check that the ex2_emulation_2 produces the same SASS as the ptx version
|
|
725
|
+
@dsl_user_op
def ex2_emulation_2(x: Float32, y: Float32, *, loc=None, ip=None) -> Tuple[Float32, Float32]:
    """Approximate ``2**x`` and ``2**y`` together using packed f32x2 arithmetic.

    Each input is clamped from below at -127 and split into its floor plus a
    fractional part in ``[0, 1)``. ``2**frac`` comes from a degree-3 polynomial,
    and each floor is then merged in by ``combine_int_frac_ex2``.

    Args:
        x: First exponent; assumed to satisfy ``x <= 127.0``.
        y: Second exponent; assumed to satisfy ``y <= 127.0``.

    Returns:
        The pair of approximations ``(2**x, 2**y)``.
    """
    # We assume x <= 127.0 and y <= 127.0
    # Coefficients of the cubic for 2**f, constant term first.
    coeffs = (
        1.0,
        0.695146143436431884765625,
        0.227564394474029541015625,
        0.077119089663028717041015625,
    )
    magic = float(2**23 + 2**22)
    magic_pair = (magic, magic)
    clamped = (cute.arch.fmax(x, -127.0), cute.arch.fmax(y, -127.0))
    # We want to round down here, so that the fractional part is in [0, 1)
    biased_floor = cute.arch.add_packed_f32x2(clamped, magic_pair, rnd=nvvm.RoundingModeKind.RM)
    # The integer floor of x & y are now in the last 8 bits of biased_floor
    # We want the next 2 ops to round to nearest even. The rounding mode is important.
    floor_values = sub_packed_f32x2(biased_floor, magic_pair)
    frac = sub_packed_f32x2(clamped, floor_values)
    frac_ex2 = evaluate_polynomial_2(*frac, coeffs, loc=loc, ip=ip)
    return (
        combine_int_frac_ex2(biased_floor[0], frac_ex2[0], loc=loc, ip=ip),
        combine_int_frac_ex2(biased_floor[1], frac_ex2[1], loc=loc, ip=ip),
    )
|
|
748
|
+
|
|
749
|
+
|
|
750
|
+
@dsl_user_op
def e2e_asm2(x: Float32, y: Float32, *, loc=None, ip=None) -> Tuple[Float32, Float32]:
    """Approximate ``2**x`` and ``2**y`` in a single hand-written PTX block.

    The assembly performs, on both values at once with f32x2 instructions:

    1. clamp each input from below at -127.0 (``0fC2FE0000``);
    2. add 2**23 + 2**22 (``0f4B400000``) with round-down (``add.rm``), then
       subtract it again to obtain the fractional part of each input;
    3. evaluate a degree-3 polynomial in that fractional part with three
       packed FMAs (coefficients ~0.0771191, ~0.2275644, ~0.6951461, 1.0);
    4. shift the bits of the round-down sum left by 23 and integer-add them
       to the bits of the polynomial result.

    Args:
        x: First exponent (asm operand ``$2``).
        y: Second exponent (asm operand ``$3``).

    Returns:
        The two results ``($0, $1)`` as a pair of ``Float32``.
    """
    # Lower both inputs the same way, forwarding loc/ip to ir_value for each.
    operands = [Float32(x).ir_value(loc=loc, ip=ip), Float32(y).ir_value(loc=loc, ip=ip)]
    out_f32x2 = llvm.inline_asm(
        llvm.StructType.get_literal([T.f32(), T.f32()]),
        operands,
        "{\n\t"
        ".reg .f32 f1, f2, f3, f4, f5, f6, f7;\n\t"
        ".reg .b64 l1, l2, l3, l4, l5, l6, l7, l8, l9, l10;\n\t"
        ".reg .s32 r1, r2, r3, r4, r5, r6, r7, r8;\n\t"
        # Clamp the inputs at -127.0 and pack them into l1.
        "max.ftz.f32 f1, $2, 0fC2FE0000;\n\t"
        "max.ftz.f32 f2, $3, 0fC2FE0000;\n\t"
        "mov.b64 l1, {f1, f2};\n\t"
        # l2 = (2**23 + 2**22) in both halves.
        "mov.f32 f3, 0f4B400000;\n\t"
        "mov.b64 l2, {f3, f3};\n\t"
        # l7 = round-down sum, l8 = l7 - l2, l9 = fractional part.
        "add.rm.ftz.f32x2 l7, l1, l2;\n\t"
        "sub.rn.ftz.f32x2 l8, l7, l2;\n\t"
        "sub.rn.ftz.f32x2 l9, l1, l8;\n\t"
        # Polynomial coefficients, highest degree (l6) down to the constant (l3).
        "mov.f32 f7, 0f3D9DF09D;\n\t"
        "mov.b64 l6, {f7, f7};\n\t"
        "mov.f32 f6, 0f3E6906A4;\n\t"
        "mov.b64 l5, {f6, f6};\n\t"
        "mov.f32 f5, 0f3F31F519;\n\t"
        "mov.b64 l4, {f5, f5};\n\t"
        "mov.f32 f4, 0f3F800000;\n\t"
        "mov.b64 l3, {f4, f4};\n\t"
        # Horner evaluation in the fractional part l9.
        "fma.rn.ftz.f32x2 l10, l9, l6, l5;\n\t"
        "fma.rn.ftz.f32x2 l10, l10, l9, l4;\n\t"
        "fma.rn.ftz.f32x2 l10, l10, l9, l3;\n\t"
        # Unpack, shift the round-down sums by 23 and add the polynomial bits.
        "mov.b64 {r1, r2}, l7;\n\t"
        "mov.b64 {r3, r4}, l10;\n\t"
        "shl.b32 r5, r1, 23;\n\t"
        "add.s32 r7, r5, r3;\n\t"
        "shl.b32 r6, r2, 23;\n\t"
        "add.s32 r8, r6, r4;\n\t"
        "mov.b32 $0, r7;\n\t"
        "mov.b32 $1, r8;\n\t"
        "}\n",
        "=r,=r,f,f",
        has_side_effects=False,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
    out0 = Float32(llvm.extractvalue(T.f32(), out_f32x2, [0], loc=loc, ip=ip))
    out1 = Float32(llvm.extractvalue(T.f32(), out_f32x2, [1], loc=loc, ip=ip))
    return out0, out1
|
|
795
|
+
|
|
796
|
+
|
|
797
|
+
@dsl_user_op
def domain_offset_aligned(
    coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None
) -> cute.Tensor:
    """Return ``tensor`` re-based at ``coord`` while keeping its pointer alignment.

    The new base address is that of the element at ``coord``. The layout is
    reused unchanged, and the alignment of the original iterator is assumed to
    still hold for the new pointer.

    Args:
        coord: Coordinate of the element that becomes the new origin.
        tensor: Tensor whose iterator is a ``cute.Pointer``.

    Returns:
        A tensor with the same layout as ``tensor`` and the shifted base pointer.
    """
    assert isinstance(tensor.iterator, cute.Pointer)
    origin_address = elem_pointer(tensor, coord).toint()
    # We assume that applying the offset does not change the pointer alignment
    shifted_ptr = cute.make_ptr(
        tensor.element_type,
        origin_address,
        tensor.memspace,
        assumed_align=tensor.iterator.alignment,
    )
    return cute.make_tensor(shifted_ptr, tensor.layout)
|
|
810
|
+
|
|
811
|
+
|
|
812
|
+
@dsl_user_op
def domain_offset_i64(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
    """Return ``tensor`` re-based at ``coord``, computing the offset in Int64.

    Every flattened coordinate is widened to ``cutlass.Int64`` before being
    multiplied by the matching flattened stride. The element offset is scaled
    by the element width to bytes and added to the integer value of the base
    pointer.

    Args:
        coord: Coordinate whose flattened length matches the flattened stride.
        tensor: Tensor whose iterator is a ``cute.Pointer``.

    Returns:
        A tensor with the same layout as ``tensor`` and the shifted base pointer.
    """
    coords_i64 = tuple(cutlass.Int64(c) for c in cute.flatten(coord))
    strides = cute.flatten_to_tuple(tensor.stride)
    assert len(coords_i64) == len(strides), "Coordinate and stride must have the same length"
    elem_offset = sum(c * s for c, s in zip(coords_i64, strides))
    assert isinstance(tensor.iterator, cute.Pointer)
    # Element offset -> byte offset (element width is in bits).
    shifted_address = tensor.iterator.toint() + elem_offset * tensor.element_type.width // 8
    # HACK: we assume that applying the offset does not change the pointer alignment
    shifted_ptr = cute.make_ptr(
        tensor.element_type,
        shifted_address,
        tensor.memspace,
        assumed_align=tensor.iterator.max_alignment,
    )
    return cute.make_tensor(shifted_ptr, tensor.layout)
|
|
829
|
+
|
|
830
|
+
|
|
831
|
+
@dsl_user_op
def coord_offset_i64(
    tensor: cute.Tensor, idx: cute.typing.Int, dim: int, *, loc=None, ip=None
) -> cute.Tensor:
    """Select index ``idx`` along mode ``dim``, computing the offset in Int64.

    The base pointer is advanced by ``idx * size(stride[dim])`` elements
    (computed as ``cutlass.Int64``), and mode ``dim`` is dropped from the
    layout by slicing it at 0.

    Args:
        tensor: Tensor whose iterator is a ``cute.Pointer``.
        idx: Index to select along ``dim``.
        dim: Mode of ``tensor`` to slice away.

    Returns:
        A tensor over the remaining modes, based at the advanced pointer.
    """
    elem_offset = cutlass.Int64(idx) * cute.size(tensor.stride[dim])
    assert isinstance(tensor.iterator, cute.Pointer)
    # Element offset -> byte offset (element width is in bits).
    shifted_address = tensor.iterator.toint() + elem_offset * tensor.element_type.width // 8
    # HACK: we assume that applying the offset does not change the pointer alignment
    shifted_ptr = cute.make_ptr(
        tensor.element_type,
        shifted_address,
        tensor.memspace,
        assumed_align=tensor.iterator.max_alignment,
    )
    # Keep every mode except `dim`, which is fixed at 0 (the pointer already carries the offset).
    trailing = cute.rank(tensor) - dim - 1
    slicer = (None,) * dim + (0,) + (None,) * trailing
    sliced_layout = cute.slice_(tensor.layout, slicer)
    return cute.make_tensor(shifted_ptr, sliced_layout)
|
|
848
|
+
|
|
849
|
+
|
|
850
|
+
@cute.jit
def scalar_to_ssa(a: cute.Numeric, dtype) -> cute.TensorSSA:
    """Wrap scalar ``a`` in a one-element fragment of ``dtype`` and return its loaded value."""
    holder = cute.make_fragment(1, dtype)
    holder[0] = a
    return holder.load()
|
|
856
|
+
|
|
857
|
+
|
|
858
|
+
def ssa_to_scalar(val):
    """Return element 0 of ``val``; the counterpart of ``scalar_to_ssa``."""
    first = val[0]
    return first
|