mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mslk/__init__.py +56 -0
- mslk/attention/__init__.py +7 -0
- mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
- mslk/attention/flash_attn/__init__.py +22 -0
- mslk/attention/flash_attn/ampere_helpers.py +104 -0
- mslk/attention/flash_attn/barrier.py +72 -0
- mslk/attention/flash_attn/benchmark.py +269 -0
- mslk/attention/flash_attn/blackwell_helpers.py +754 -0
- mslk/attention/flash_attn/block_info.py +109 -0
- mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
- mslk/attention/flash_attn/block_sparsity.py +219 -0
- mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
- mslk/attention/flash_attn/copy_utils.py +341 -0
- mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
- mslk/attention/flash_attn/fast_math.py +22 -0
- mslk/attention/flash_attn/flash_bwd.py +1262 -0
- mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
- mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
- mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
- mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
- mslk/attention/flash_attn/flash_fwd.py +2471 -0
- mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
- mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
- mslk/attention/flash_attn/hopper_helpers.py +102 -0
- mslk/attention/flash_attn/interface.py +1771 -0
- mslk/attention/flash_attn/mask.py +610 -0
- mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
- mslk/attention/flash_attn/named_barrier.py +32 -0
- mslk/attention/flash_attn/pack_gqa.py +165 -0
- mslk/attention/flash_attn/paged_kv.py +176 -0
- mslk/attention/flash_attn/pipeline.py +273 -0
- mslk/attention/flash_attn/seqlen_info.py +139 -0
- mslk/attention/flash_attn/softmax.py +583 -0
- mslk/attention/flash_attn/testing.py +424 -0
- mslk/attention/flash_attn/tile_scheduler.py +720 -0
- mslk/attention/flash_attn/utils.py +860 -0
- mslk/attention/fmha/__init__.py +967 -0
- mslk/attention/fmha/_triton/__init__.py +6 -0
- mslk/attention/fmha/_triton/available.py +50 -0
- mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
- mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
- mslk/attention/fmha/attn_bias.py +2186 -0
- mslk/attention/fmha/attn_bias_utils.py +536 -0
- mslk/attention/fmha/ck.py +508 -0
- mslk/attention/fmha/ck_decoder.py +141 -0
- mslk/attention/fmha/ck_splitk.py +204 -0
- mslk/attention/fmha/common.py +598 -0
- mslk/attention/fmha/cutlass.py +461 -0
- mslk/attention/fmha/cutlass_blackwell.py +560 -0
- mslk/attention/fmha/dispatch.py +224 -0
- mslk/attention/fmha/flash.py +862 -0
- mslk/attention/fmha/flash3.py +858 -0
- mslk/attention/fmha/flash_mtia.py +245 -0
- mslk/attention/fmha/merge_training.py +192 -0
- mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
- mslk/attention/fmha/torch_attention_compat.py +154 -0
- mslk/attention/fmha/tree_attention.py +718 -0
- mslk/attention/fmha/triton_splitk.py +1378 -0
- mslk/attention/fmha/unbind.py +130 -0
- mslk/attention/fmha/utils/__init__.py +6 -0
- mslk/attention/fmha/utils/bench.py +74 -0
- mslk/attention/fmha/utils/cpp_lib.py +148 -0
- mslk/attention/fmha/utils/op_common.py +65 -0
- mslk/attention/gqa_attn_splitk/__init__.py +11 -0
- mslk/bench/comm/__init__.py +7 -0
- mslk/bench/comm/comm_bench.py +255 -0
- mslk/bench/common/__init__.py +5 -0
- mslk/bench/common/utils.py +148 -0
- mslk/bench/conv/__init__.py +7 -0
- mslk/bench/conv/conv_bench.py +551 -0
- mslk/bench/conv/conv_ops.py +213 -0
- mslk/bench/gemm/__init__.py +7 -0
- mslk/bench/gemm/gemm_bench.py +859 -0
- mslk/bench/gemm/gemm_ops.py +3342 -0
- mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
- mslk/bench/moe/__init__.py +7 -0
- mslk/bench/moe/gather_scatter_bench.py +356 -0
- mslk/bench/quantize/quantize_bench.py +345 -0
- mslk/bench/quantize/quantize_ops.py +266 -0
- mslk/comm/__init__.py +11 -0
- mslk/conv/__init__.py +11 -0
- mslk/gemm/__init__.py +18 -0
- mslk/gemm/triton/__init__.py +7 -0
- mslk/gemm/triton/fp8_gemm.py +2702 -0
- mslk/gemm/triton/grouped_gemm.py +1132 -0
- mslk/gemm/triton/matmul_perf_model.py +237 -0
- mslk/gemm/triton/utils.py +128 -0
- mslk/kv_cache/__init__.py +11 -0
- mslk/moe/__init__.py +26 -0
- mslk/moe/activation.py +291 -0
- mslk/moe/gather_scatter.py +739 -0
- mslk/moe/layers.py +1240 -0
- mslk/moe/shuffling.py +421 -0
- mslk/mslk.so +0 -0
- mslk/quantize/__init__.py +11 -0
- mslk/quantize/shuffle.py +306 -0
- mslk/quantize/triton/__init__.py +7 -0
- mslk/quantize/triton/fp4_quantize.py +5942 -0
- mslk/quantize/triton/fp8_quantize.py +1902 -0
- mslk/testing/__init__.py +7 -0
- mslk/testing/attributes.py +60 -0
- mslk/testing/rocm.py +91 -0
- mslk/utils/__init__.py +7 -0
- mslk/utils/torch/__init__.py +7 -0
- mslk/utils/torch/library.py +150 -0
- mslk/utils/triton/__init__.py +7 -0
- mslk/utils/triton/fp8_utils.py +72 -0
- mslk/utils/triton/utils.py +128 -0
- mslk/version.py +11 -0
- mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
- mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
- mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
- mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
- mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
mslk/quantize/shuffle.py
ADDED
|
@@ -0,0 +1,306 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
# pyre-strict
|
|
8
|
+
|
|
9
|
+
# Helper functions for using MSLK quantized operators.
|
|
10
|
+
|
|
11
|
+
from typing import Tuple
|
|
12
|
+
|
|
13
|
+
import torch
|
|
14
|
+
from mslk.quantize.triton.fp8_quantize import quantize_fp8_row
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def pack_int4(x: torch.Tensor) -> torch.Tensor:
    """
    Pack pairs of int4 values (stored one per int8 element) into single int8 bytes.

    Elements at even positions along dim 1 go into the low nibble of the output
    byte. Elements at odd positions go into the high nibble. Only the low 4 bits
    of each input element survive packing.

    Args:
        x (Tensor): int8 tensor of unpacked int4 values. Dim 1 is expected to
            have an even length so every element has a partner.
    Returns:
        Tensor: Contiguous int8 tensor whose dim 1 is half as long as the input's.
    """
    evens, odds = x[:, 0::2], x[:, 1::2]
    # Shifting the odd elements up by four puts them in the high nibble; any
    # bits above bit 7 fall off the int8 result.
    high_nibble = odds.bitwise_left_shift(4)
    # Keep only the low 4 bits of the even elements, so their sign-extension
    # bits cannot clobber the high nibble when the halves are OR-ed together.
    low_nibble = evens.bitwise_and(0xF)
    return low_nibble.bitwise_or(high_nibble).contiguous()
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def int4_row_quantize_zp(
    x: torch.Tensor,
    group_size: int = 128,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Quantize a tensor to int4 with asymmetric groupwise scales and zero points.

    Each group of ``group_size`` consecutive elements along the last dimension is
    mapped onto the levels [0, 15] using that group's min/max range. The levels
    are then recentered by subtracting 8, so the stored values lie in [-8, 7].
    Values can be approximately recovered as ``wq * scale + zero``.

    Args:
        x (Tensor): [N, K] Higher precision tensor to quantize. K does not need to
            be divisible by group_size; the final group is simply shorter.
        group_size (int): Number of elements that share a scale and zero point.
    Returns:
        wq (Tensor): [N, K] int8 tensor holding unpacked int4 values in [-8, 7].
        group_scale (Tensor): [ceil(K / group_size), N] FP32 scale per group.
        group_zero (Tensor): [ceil(K / group_size), N] FP32 zero point per group.
    """
    n_bit = 4  # Number of target bits.
    max_int = 2**n_bit - 1
    min_int = 0
    # Offset that recenters [0, max_int] onto the signed int4 range.
    mid = 2 ** (n_bit - 1)

    # Split input into chunks of group_size along the last dim. This allows K
    # that isn't divisible by group_size.
    to_quant = torch.split(x.to(torch.float), group_size, dim=-1)

    # Reduce over the same (last) dimension that was split. This matches
    # int4_row_quantize and also makes 1-D inputs work.
    max_val = [chunk.amax(dim=-1, keepdim=True) for chunk in to_quant]
    min_val = [chunk.amin(dim=-1, keepdim=True) for chunk in to_quant]
    # Floor the range so a constant group does not divide by zero.
    scale_chunks = [
        (max_chunk - min_chunk).clamp(min=1e-6) / max_int
        for max_chunk, min_chunk in zip(max_val, min_val)
    ]

    # Shift the zero point by the recentering offset, so that
    # wq * scale + zero == level * scale + min.
    zero_chunks = [
        min_chunk + scale_chunk * mid
        for min_chunk, scale_chunk in zip(min_val, scale_chunks)
    ]

    q_chunks = [
        chunk.sub(min_chunk).div(scale_chunk).round().clamp_(min_int, max_int)
        for chunk, min_chunk, scale_chunk in zip(to_quant, min_val, scale_chunks)
    ]

    # Recenter the output, move it to int8, and recombine the chunks.
    wq = torch.cat([(chunk - mid).to(dtype=torch.int8) for chunk in q_chunks], dim=-1)

    # Cutlass expects column major layout for scale and zero point,
    # so we transpose here and make them contiguous.
    scales = torch.cat(scale_chunks, dim=-1).t().contiguous()
    zeros = torch.cat(zero_chunks, dim=-1).t().contiguous()

    return wq, scales, zeros
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def int4_row_quantize(
    x: torch.Tensor,
    group_size: int = 128,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Helper function to quantize a tensor to int4 with symmetric groupwise scales.

    Each group of ``group_size`` consecutive elements along the last dimension is
    divided by ``absmax / 8`` and rounded onto the signed int4 range [-8, 7].
    The scale divides by 8 but the largest positive level is 7, so a value equal
    to ``+absmax`` is clamped to 7 (i.e. ``7/8 * absmax``). ``-absmax`` maps
    exactly to -8. The output is NOT packed; ``pack_int4`` does that.

    Args:
        x (Tensor): [N, K] Higher precision weight tensor to quantize. K does not
            need to be divisible by group_size; the final group is simply shorter.
        group_size (int): Number of elements to calculate group scale for.
    Returns:
        wq (Tensor): [N, K] int8 tensor holding one unpacked int4 value (in
            [-8, 7]) per element.
        group_scale (Tensor): [ceil(K / group_size), N] FP32 Scale per group.
    """
    n_bit = 4  # Number of target bits.
    max_int = 2 ** (n_bit - 1)
    min_int = -(2 ** (n_bit - 1))

    # Split input into chunks of group_size along the last dim. This allows K
    # that isn't divisible by group_size.
    to_quant = torch.split(x.to(torch.float), group_size, dim=-1)

    max_val = [torch.abs(chunk).amax(dim=-1, keepdim=True) for chunk in to_quant]
    # Floor absmax so an all-zero group does not divide by zero.
    scale_chunks = [chunk.clamp(min=1e-6) / max_int for chunk in max_val]

    q_chunks = [
        chunk.div(chunk_scale).round().clamp_(min_int, max_int - 1)
        for chunk, chunk_scale in zip(to_quant, scale_chunks)
    ]
    # Recombine chunks and cast to int8. Values already lie in [-8, 7] and the
    # shape is unchanged from x.
    wq = torch.cat(q_chunks, dim=-1).to(dtype=torch.int8)

    # Scales should be in [num_groups, N] layout.
    scales = torch.cat(scale_chunks, dim=-1).t().contiguous()

    return wq, scales
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def quantize_int4_preshuffle(
    w: torch.Tensor, group_size: int = 128, dtype: str = "fp8", use_zp: bool = True
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Quantizes an input weight tensor to int4 using preshuffling and scale packing.

    This function is intended to be used with MSLK's mixed dtype kernels. It is
    expected to be applied to weights ahead of time, so it is not perfectly
    optimized.

    Args:
        w (Tensor): [N, K] Higher precision weight tensor to quantize. May have any
            number of leading batch dimensions ([..., N, K]); each [N, K] matrix
            is quantized independently.
        group_size (int): Number of elements to calculate group scale for. Expected
            to be at least 128 (not checked here).
        dtype (str): Type of corresponding activations. Must be "fp8" or "bf16".
        use_zp (bool): If true, uses zero points during weight quantization. Only
            consulted when dtype is "bf16".
    Returns:
        wq (Tensor): [..., N, K // 2] Quantized int4 weight tensor packed into int8
            elements and preshuffled.
        scales (Tuple[Tensor, Tensor]): Scale tensors for the specified activation
            type, carrying the same leading batch dimensions as ``w``.
            - "fp8": ``(group_scale, row_scale)`` with group_scale
              [K / group_size, 8, N] and row_scale [N].
            - "bf16": ``(group_scale, group_zero)``, each [K / group_size, N].
              group_zero is all zeros when use_zp is False.
    Raises:
        NotImplementedError: If ``dtype`` is not "fp8" or "bf16".
    """

    def _quantize(
        w: torch.Tensor, dtype: str = "fp8"
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        # Quantize a single 2D [N, K] weight matrix.
        if dtype == "fp8":
            # Start by lowering weights to FP8 and producing row scales.
            wq, row_scale = quantize_fp8_row(w)

            # Now reduce to INT4.
            wq, group_scale = int4_row_quantize(wq, group_size)
            # Reduce group scale to FP8.
            group_scale = group_scale.to(torch.float8_e4m3fn)
            # Take quantized weights and pack them efficiently.
            wq = pack_int4(wq)
            # Finally pack weights and scales into efficient preshuffled format.
            wq, group_scale = torch.ops.mslk.preshuffle_i4(wq, group_scale)
            return wq, (group_scale, row_scale)

        elif dtype == "bf16":
            if use_zp:
                wq, group_scale, group_zero = int4_row_quantize_zp(w, group_size)
            else:
                wq, group_scale = int4_row_quantize(w, group_size)
                group_zero = torch.zeros_like(group_scale)
            # Set scales to activation type.
            group_scale = group_scale.to(torch.bfloat16)
            group_zero = group_zero.to(torch.bfloat16)
            # Take quantized weights and pack them efficiently.
            wq = pack_int4(wq)
            # Finally pack weights and scales into efficient preshuffled format.
            wq, group_scale = torch.ops.mslk.preshuffle_i4(wq, group_scale)
            return wq, (group_scale, group_zero)
        else:
            raise NotImplementedError("Only fp8 and bf16 activations supported.")

    if w.ndim < 3:
        return _quantize(w, dtype=dtype)

    batch_shape = w.shape[:-2]
    # Flatten all leading batch dims so every iteration sees a single [N, K]
    # matrix. Iterating the unflattened tensor would pass >2D slices to
    # _quantize when w has more than one batch dimension, and the int4 helpers
    # (which call .t()) only accept 2D.
    flat_w = w.reshape(-1, *w.shape[-2:])
    wq_list, scale_list = zip(*[_quantize(w_i, dtype=dtype) for w_i in flat_w])
    wq = torch.stack(wq_list).view(*batch_shape, *wq_list[0].shape)
    # Decompose then stack scales back into a tuple with the batch dims restored.
    first_scales, second_scales = zip(*scale_list)
    scales = (
        torch.stack(first_scales).view(*batch_shape, *first_scales[0].shape),
        torch.stack(second_scales).view(*batch_shape, *second_scales[0].shape),
    )
    return wq, scales
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def shuffle_slice(
    x: torch.Tensor, dim: int, start: int, length: int, dtype: str = "fp8"
) -> torch.Tensor:
    """
    Helper function to slice a preshuffled int4 tensor.

    This is needed because the shuffling reorders rows based on the size of the
    input. Slicing a tensor shuffled for a larger input is no longer valid, so we
    must reorder the tensor to the appropriate size and then slice.

    Args:
        x (Tensor): [..., N, K // 2] Preshuffled int4 tensor (two values per element).
        dim (int): Dimension to slice. Must refer to N or K. Negative values such
            as -2 (N) and -1 (K) are accepted.
        start (int): Start of slice.
        length (int): Number of elements to slice in the original [N, K] dimension.
            Must be a multiple of 16.
        dtype (str): Type of corresponding activations. Must be fp8 or bf16.
    Returns:
        sliced (Tensor): [..., length, K // 2] when slicing along N, or
            [..., N, length // 2] when slicing along K.
    """
    # Normalize negative dims so callers can pass -2 / -1.
    if dim < 0:
        dim += x.ndim
    assert dim in [x.ndim - 2, x.ndim - 1], "Only slicing along N or K is supported."
    assert length % 16 == 0, "Slicing must be a multiple of 16."
    orig_shape = x.shape
    N = x.shape[-2]
    K = x.shape[-1]
    # Tile shape is based on the activation dtype.
    assert dtype in ("fp8", "bf16"), "Only fp8 and bf16 activations supported."
    # Handle slice along N.
    if dim == x.ndim - 2:
        tile_shape = 8 if dtype == "fp8" else 16
        block_size = N // length
        # View the shape in terms of shuffled tiles then permute to allow slicing.
        x_s = x.view(-1, tile_shape, block_size, length // tile_shape, K)
        x_s = x_s.permute(0, 2, 1, 3, 4).contiguous().view(-1, N, K)
        out_slice = x_s.narrow(1, start, length)
        # Reshape back to original shape.
        return out_slice.view(*orig_shape[:-2], length, K)
    # Handle slice along K.
    outer_dim = x.view(-1, N, K).shape[0]
    x_s = x.view(outer_dim, -1, length // 2)
    # Number of rows of width length // 2 that make up one slice (evaluates to N).
    row_factor = x_s.shape[1] * (length // 2) // K
    # NOTE(review): the first row is derived from K (start * 2 * K // length,
    # where K is the packed last dim x.shape[-1]), while each slice spans
    # row_factor == N rows. Confirm this matches preshuffle_i4's layout for
    # non-square weights.
    return x_s.narrow(1, start * 2 * K // length, row_factor).view(
        *orig_shape[:-2], N, length // 2
    )
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
def scale_nvfp4_quant(
    input: torch.Tensor, input_global_scale: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize ``input`` to FP4, returning packed values and swizzled block scales.

    The last dimension is quantized in blocks of 16 elements, and each block
    shares one dynamically computed scaling factor. That factor is itself
    quantized using ``input_global_scale`` and stored in the swizzled layout
    described at
    https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x.
    All leading dimensions are folded into a single row dimension.

    Args:
        input: fp16 or bf16 tensor to be quantized to FP4. Its last dimension
            must be a multiple of 16.
        input_global_scale: A scalar scaling factor for the entire tensor.
    Returns:
        Tuple[torch.Tensor, torch.Tensor]: ``(packed, scales)``.
        - ``packed`` is a uint8 tensor of shape [M, K // 2] holding two FP4
          values per byte.
        - ``scales`` holds the float8_e4m3fn scaling factors in the swizzled
          layout, padded to a multiple of 128 rows and a multiple of 4 scale
          columns.
    """
    assert input.ndim >= 1, f"input.ndim needs to be >= 1, but got {input.ndim}."
    # Fold all leading dims into one row dim; a 1-D input becomes a single row.
    leading = -1 if input.ndim != 1 else 1
    input = input.reshape(leading, input.shape[-1])
    m, n = input.shape
    sf_vec = 16  # Number of elements sharing one scaling factor.

    assert n % sf_vec == 0, f"last dim has to be multiple of 16, but got {n}."
    assert input.dtype in (
        torch.float16,
        torch.bfloat16,
    ), f"input.dtype needs to be fp16 or bf16 but got {input.dtype}."

    # Two fp4 values are packed into each uint8.
    packed = torch.empty((m, n // 2), device=input.device, dtype=torch.uint8)

    def pad_to(value: int, multiple: int) -> int:
        # Smallest multiple of `multiple` that is >= value.
        return ((value + multiple - 1) // multiple) * multiple

    # The Tensor Core's minimum scale tile is 128x4, so pad the rows to a
    # multiple of 128 and the scale columns to a multiple of 4. Every four
    # float8_e4m3fn scales are carried in one int32 word.
    scale_rows = pad_to(m, 128)
    scale_cols = pad_to(n // sf_vec, 4)
    scale_words = torch.empty(
        (scale_rows, scale_cols // 4), device=input.device, dtype=torch.int32
    )

    torch.ops.mslk.scaled_fp4_quant(packed, input, scale_words, input_global_scale)
    return packed, scale_words.view(torch.float8_e4m3fn)
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
def ck_preshuffle(src: torch.Tensor, NXdl: int = 16) -> torch.Tensor:
|
|
282
|
+
"""
|
|
283
|
+
Applies shuffling to make weights more efficient for use with CK kernels.
|
|
284
|
+
Args:
|
|
285
|
+
src (torch.Tensor): Input tensor with dtype float8_e4m3fnuz.
|
|
286
|
+
NXdl (int): Wave tile size along N.
|
|
287
|
+
Returns:
|
|
288
|
+
torch.Tensor: The shuffled tensor.
|
|
289
|
+
"""
|
|
290
|
+
# Check input datatype
|
|
291
|
+
if src.dtype != torch.float8_e4m3fnuz:
|
|
292
|
+
raise TypeError("Input must be type float8_e4m3fnuz.")
|
|
293
|
+
N, K = src.shape
|
|
294
|
+
KPack = 16
|
|
295
|
+
NLane = NXdl
|
|
296
|
+
KLane = 64 // NLane
|
|
297
|
+
K0 = K // (KLane * KPack)
|
|
298
|
+
# Reshape src to enable the required permutation
|
|
299
|
+
# Original shape: (N, K)
|
|
300
|
+
# Desired intermediate shape for permutation: (N0, NLane, K0, KLane, KPack)
|
|
301
|
+
src = src.reshape(N // NLane, NLane, K0, KLane, KPack)
|
|
302
|
+
# Apply permutation: (N0, NLane, K0, KLane, KPack) -> (N0, K0, KLane, NLane, KPack)
|
|
303
|
+
dst = src.permute(0, 2, 3, 1, 4).contiguous()
|
|
304
|
+
# Reshape to original input shape.
|
|
305
|
+
dst = dst.reshape(N, K)
|
|
306
|
+
return dst
|