mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mslk/__init__.py +56 -0
- mslk/attention/__init__.py +7 -0
- mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
- mslk/attention/flash_attn/__init__.py +22 -0
- mslk/attention/flash_attn/ampere_helpers.py +104 -0
- mslk/attention/flash_attn/barrier.py +72 -0
- mslk/attention/flash_attn/benchmark.py +269 -0
- mslk/attention/flash_attn/blackwell_helpers.py +754 -0
- mslk/attention/flash_attn/block_info.py +109 -0
- mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
- mslk/attention/flash_attn/block_sparsity.py +219 -0
- mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
- mslk/attention/flash_attn/copy_utils.py +341 -0
- mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
- mslk/attention/flash_attn/fast_math.py +22 -0
- mslk/attention/flash_attn/flash_bwd.py +1262 -0
- mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
- mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
- mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
- mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
- mslk/attention/flash_attn/flash_fwd.py +2471 -0
- mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
- mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
- mslk/attention/flash_attn/hopper_helpers.py +102 -0
- mslk/attention/flash_attn/interface.py +1771 -0
- mslk/attention/flash_attn/mask.py +610 -0
- mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
- mslk/attention/flash_attn/named_barrier.py +32 -0
- mslk/attention/flash_attn/pack_gqa.py +165 -0
- mslk/attention/flash_attn/paged_kv.py +176 -0
- mslk/attention/flash_attn/pipeline.py +273 -0
- mslk/attention/flash_attn/seqlen_info.py +139 -0
- mslk/attention/flash_attn/softmax.py +583 -0
- mslk/attention/flash_attn/testing.py +424 -0
- mslk/attention/flash_attn/tile_scheduler.py +720 -0
- mslk/attention/flash_attn/utils.py +860 -0
- mslk/attention/fmha/__init__.py +967 -0
- mslk/attention/fmha/_triton/__init__.py +6 -0
- mslk/attention/fmha/_triton/available.py +50 -0
- mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
- mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
- mslk/attention/fmha/attn_bias.py +2186 -0
- mslk/attention/fmha/attn_bias_utils.py +536 -0
- mslk/attention/fmha/ck.py +508 -0
- mslk/attention/fmha/ck_decoder.py +141 -0
- mslk/attention/fmha/ck_splitk.py +204 -0
- mslk/attention/fmha/common.py +598 -0
- mslk/attention/fmha/cutlass.py +461 -0
- mslk/attention/fmha/cutlass_blackwell.py +560 -0
- mslk/attention/fmha/dispatch.py +224 -0
- mslk/attention/fmha/flash.py +862 -0
- mslk/attention/fmha/flash3.py +858 -0
- mslk/attention/fmha/flash_mtia.py +245 -0
- mslk/attention/fmha/merge_training.py +192 -0
- mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
- mslk/attention/fmha/torch_attention_compat.py +154 -0
- mslk/attention/fmha/tree_attention.py +718 -0
- mslk/attention/fmha/triton_splitk.py +1378 -0
- mslk/attention/fmha/unbind.py +130 -0
- mslk/attention/fmha/utils/__init__.py +6 -0
- mslk/attention/fmha/utils/bench.py +74 -0
- mslk/attention/fmha/utils/cpp_lib.py +148 -0
- mslk/attention/fmha/utils/op_common.py +65 -0
- mslk/attention/gqa_attn_splitk/__init__.py +11 -0
- mslk/bench/comm/__init__.py +7 -0
- mslk/bench/comm/comm_bench.py +255 -0
- mslk/bench/common/__init__.py +5 -0
- mslk/bench/common/utils.py +148 -0
- mslk/bench/conv/__init__.py +7 -0
- mslk/bench/conv/conv_bench.py +551 -0
- mslk/bench/conv/conv_ops.py +213 -0
- mslk/bench/gemm/__init__.py +7 -0
- mslk/bench/gemm/gemm_bench.py +859 -0
- mslk/bench/gemm/gemm_ops.py +3342 -0
- mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
- mslk/bench/moe/__init__.py +7 -0
- mslk/bench/moe/gather_scatter_bench.py +356 -0
- mslk/bench/quantize/quantize_bench.py +345 -0
- mslk/bench/quantize/quantize_ops.py +266 -0
- mslk/comm/__init__.py +11 -0
- mslk/conv/__init__.py +11 -0
- mslk/gemm/__init__.py +18 -0
- mslk/gemm/triton/__init__.py +7 -0
- mslk/gemm/triton/fp8_gemm.py +2702 -0
- mslk/gemm/triton/grouped_gemm.py +1132 -0
- mslk/gemm/triton/matmul_perf_model.py +237 -0
- mslk/gemm/triton/utils.py +128 -0
- mslk/kv_cache/__init__.py +11 -0
- mslk/moe/__init__.py +26 -0
- mslk/moe/activation.py +291 -0
- mslk/moe/gather_scatter.py +739 -0
- mslk/moe/layers.py +1240 -0
- mslk/moe/shuffling.py +421 -0
- mslk/mslk.so +0 -0
- mslk/quantize/__init__.py +11 -0
- mslk/quantize/shuffle.py +306 -0
- mslk/quantize/triton/__init__.py +7 -0
- mslk/quantize/triton/fp4_quantize.py +5942 -0
- mslk/quantize/triton/fp8_quantize.py +1902 -0
- mslk/testing/__init__.py +7 -0
- mslk/testing/attributes.py +60 -0
- mslk/testing/rocm.py +91 -0
- mslk/utils/__init__.py +7 -0
- mslk/utils/torch/__init__.py +7 -0
- mslk/utils/torch/library.py +150 -0
- mslk/utils/triton/__init__.py +7 -0
- mslk/utils/triton/fp8_utils.py +72 -0
- mslk/utils/triton/utils.py +128 -0
- mslk/version.py +11 -0
- mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
- mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
- mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
- mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
- mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,3342 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
# Keep a registry of all quantize operators.
|
|
8
|
+
import abc
|
|
9
|
+
import functools
|
|
10
|
+
from enum import auto, Enum
|
|
11
|
+
|
|
12
|
+
import numpy as np
|
|
13
|
+
import torch
|
|
14
|
+
from mslk.bench.common.utils import BenchOptions, do_bench
|
|
15
|
+
from mslk.gemm.triton.fp8_gemm import matmul_fp8_block, matmul_fp8_row, to_mxfp8
|
|
16
|
+
from mslk.gemm.triton.grouped_gemm import grouped_gemm, grouped_gemm_fp8_rowwise
|
|
17
|
+
from mslk.quantize.shuffle import ck_preshuffle, quantize_int4_preshuffle
|
|
18
|
+
from mslk.quantize.triton.fp4_quantize import (
|
|
19
|
+
_to_blocked,
|
|
20
|
+
calculate_group_max,
|
|
21
|
+
mega_fp4_pack,
|
|
22
|
+
mega_fp4_quantize_kernel,
|
|
23
|
+
mega_fp4_unpack,
|
|
24
|
+
triton_quantize_mx4_unpack,
|
|
25
|
+
triton_quantize_nvfp4,
|
|
26
|
+
triton_scale_nvfp4_quant_rms,
|
|
27
|
+
triton_scale_nvfp4_quant_silu,
|
|
28
|
+
)
|
|
29
|
+
from mslk.quantize.triton.fp8_quantize import (
|
|
30
|
+
quantize_fp8_block,
|
|
31
|
+
quantize_fp8_group,
|
|
32
|
+
quantize_fp8_row,
|
|
33
|
+
scale_fp8_row,
|
|
34
|
+
triton_quantize_fp8_row,
|
|
35
|
+
)
|
|
36
|
+
from mslk.utils.triton.fp8_utils import get_fp8_constants
|
|
37
|
+
|
|
38
|
+
try:
|
|
39
|
+
from gen_ai.llm_inference.fb.llm.kernel.rms_norm import rms_norm
|
|
40
|
+
from gen_ai.llm_inference.fb.llm.kernel.silu_mul import silu_mul
|
|
41
|
+
except ImportError:
|
|
42
|
+
# Above is used for some experiments, but the quantize is not relying on them. Okay to just skip.
|
|
43
|
+
pass
|
|
44
|
+
|
|
45
|
+
try:
|
|
46
|
+
from tinygemm.utils import group_quantize_tensor
|
|
47
|
+
|
|
48
|
+
if torch.cuda.is_available() and torch.version.cuda:
|
|
49
|
+
torch.ops.load_library("//tinygemm:tinygemm")
|
|
50
|
+
TINYGEMM_ENABLED = True
|
|
51
|
+
except ImportError:
|
|
52
|
+
TINYGEMM_ENABLED = False
|
|
53
|
+
|
|
54
|
+
# Marlin is currently supported only internally at Meta.
|
|
55
|
+
try:
|
|
56
|
+
from marlin.quantize import marlin_quantize
|
|
57
|
+
|
|
58
|
+
torch.ops.load_library("//ai_codesign/gen_ai/marlin:marlin_ops")
|
|
59
|
+
MARLIN_ENABLED = True
|
|
60
|
+
except ImportError:
|
|
61
|
+
MARLIN_ENABLED = False
|
|
62
|
+
|
|
63
|
+
try:
|
|
64
|
+
from deep_gemm import (
|
|
65
|
+
gemm_fp8_fp8_bf16_nt,
|
|
66
|
+
get_col_major_tma_aligned_tensor,
|
|
67
|
+
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
|
|
68
|
+
m_grouped_gemm_fp8_fp8_bf16_nt_masked,
|
|
69
|
+
)
|
|
70
|
+
|
|
71
|
+
DEEPGEMM_ENABLED = True
|
|
72
|
+
except ImportError:
|
|
73
|
+
DEEPGEMM_ENABLED = False
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
# Machete is also only supported internally at Meta for now.
|
|
77
|
+
try:
|
|
78
|
+
from machete.machete import machete_gemm
|
|
79
|
+
from machete.quantize import machete_quantize_and_pack
|
|
80
|
+
|
|
81
|
+
MACHETE_ENABLED = True
|
|
82
|
+
except ImportError:
|
|
83
|
+
MACHETE_ENABLED = False
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class Accelerator(Enum):
    """Hardware targets that a GEMM op can declare support for.

    ``get_current_accelerator`` maps the running device onto one of these
    members.
    """

    # NVIDIA devices, keyed by CUDA compute capability.
    NVIDIA_SM90 = auto()  # capability (9, 0)
    NVIDIA_SM100 = auto()  # capability (10, 0)
    NVIDIA_SM103 = auto()  # capability (10, 3)
    # AMD devices, matched on the reported device name.
    AMD_MI300X = auto()
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
class GemmType(Enum):
    """Problem layouts that a GEMM op can declare support for."""

    # A single matmul on tensor inputs.
    REGULAR = auto()
    # Several matmuls benchmarked together as a group.
    GROUPED = auto()
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
class ComputeDtype(Enum):
    """Numeric formats an op can report as its compute dtype."""

    FP32 = auto()
    TF32 = auto()
    BF16 = auto()
    FP8 = auto()
    FP4 = auto()
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
@functools.cache
def get_current_accelerator() -> Accelerator | None:
    """Identify the accelerator the current process is running on.

    The result is memoized by ``functools.cache``, so the device is only
    queried on the first successful call.

    Returns:
        The ``Accelerator`` member matching the current device.

    Raises:
        Exception: If no CUDA/HIP device is available, or the device is not
            one of the recognized accelerators.
    """
    if not torch.cuda.is_available():
        raise Exception("Cannot run gemm_bench without accelerator.")

    detected = None
    if torch.version.hip is not None:
        # ROCm build: recognize the device by its reported name.
        if "MI300X" in torch.cuda.get_device_name().upper():
            detected = Accelerator.AMD_MI300X
    elif torch.version.cuda is not None:
        # CUDA build: recognize the device by its (major, minor) capability.
        by_capability = {
            (9, 0): Accelerator.NVIDIA_SM90,
            (10, 0): Accelerator.NVIDIA_SM100,
            (10, 3): Accelerator.NVIDIA_SM103,
        }
        detected = by_capability.get(tuple(torch.cuda.get_device_capability()))

    if detected is None:
        raise Exception("Cannot detect hardware that is supported by gemm_bench.")
    return detected
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
# Module-level list of instantiated GEMM ops; register_gemm_op appends to it
# and get_gemm_ops returns it.
gemm_op_registry = []
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
class GemmOpBase(metaclass=abc.ABCMeta):
    """Abstract interface that every benchmarked GEMM operator implements.

    Subclasses supply the input preparation step (``quantize``), the kernel
    invocation (``compute``) and their combination (``quantize_and_compute``),
    together with metadata properties describing the op and where it can run.
    """

    @abc.abstractmethod
    def quantize(self, *args):
        """Function which quantizes inputs."""

    @abc.abstractmethod
    def compute(self, *args):
        """Function which performs main compute operation."""

    @abc.abstractmethod
    def quantize_and_compute(self, *args):
        """Function which quantizes inputs and performs main compute operation."""

    def preprocess(self, *args):
        """Preprocess inputs before benchmarking. These outputs will be passed to quantize."""
        return args

    def benchmark(
        self,
        *args,
        opts: BenchOptions,
        bench_quantize: bool,
    ) -> float:
        """Benchmark runtime of this operator.

        Args:
            *args: Positional inputs forwarded to the benchmarked callable.
            opts: Benchmark options forwarded to ``do_bench``.
            bench_quantize: When true, ``quantize_and_compute`` is timed;
                otherwise only ``compute`` is timed.

        Returns:
            Whatever ``do_bench`` reports for the selected callable.
        """
        # Pick the target once so the branch is not re-evaluated inside the
        # callable that is being timed.
        target = self.quantize_and_compute if bench_quantize else self.compute
        return do_bench(lambda *a: target(*a), args, opts)

    # NOTE: ``abc.abstractproperty`` is deprecated; stacking ``property`` on
    # ``abc.abstractmethod`` is the supported equivalent.
    @property
    @abc.abstractmethod
    def name(self) -> str:
        """Name of the operator."""

    @property
    @abc.abstractmethod
    def supported_accelerators(self) -> set[Accelerator]:
        """Accelerators on which this op can run."""

    @property
    def supported(self) -> bool:
        """Whether this op will run on the current device."""
        return get_current_accelerator() in self.supported_accelerators

    @property
    @abc.abstractmethod
    def supported_gemm_types(self) -> set[GemmType]:
        """GEMM problem types this op can handle."""

    @property
    @abc.abstractmethod
    def compute_dtype(self) -> ComputeDtype:
        """The dtype used by tensor cores for the compute."""
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def register_gemm_op(op):
    """Class decorator that instantiates ``op`` and records it in the registry.

    Args:
        op: A class callable with no arguments.

    Returns:
        ``op`` itself, unchanged, so the decorated name still refers to the class.
    """
    instance = op()
    gemm_op_registry.append(instance)
    return op
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
def get_gemm_ops() -> list[GemmOpBase]:
    """Return the registry of op instances added via ``register_gemm_op``.

    The registry list itself is returned, not a copy.
    """
    registered_ops = gemm_op_registry
    return registered_ops
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
@register_gemm_op
class FP32Baseline(GemmOpBase):
    """
    FP32 matmul baseline.
    """

    def quantize(self, x, w):
        """Cast inputs to float32 and transpose the last two dims of the weights.

        Accepts either single tensors or, for grouped problems, lists of
        tensors; the return value mirrors the input form.
        """
        if not isinstance(x, list):
            return x.float(), torch.transpose(w, -2, -1).float()
        activations = [t.float() for t in x]
        weights = [torch.transpose(t, -2, -1).float() for t in w]
        return activations, weights

    def compute(self, x, w):
        """Run ``torch.matmul`` on the inputs (one call per group for lists)."""
        if not isinstance(x, list):
            return torch.matmul(x, w)
        # Grouped gemm: pair each activation with the weight at the same index.
        return [torch.matmul(x_i, w[idx]) for idx, x_i in enumerate(x)]

    def quantize_and_compute(self, x, w):
        """Cast the inputs and then multiply them."""
        prepared = self.quantize(x, w)
        return self.compute(*prepared)

    @property
    def name(self) -> str:
        """Identifier of this op."""
        return "fp32_baseline"

    @property
    def supported_accelerators(self) -> set[Accelerator]:
        """Every member of ``Accelerator``."""
        return set(Accelerator)

    @property
    def supported_gemm_types(self) -> set[GemmType]:
        """Both regular and grouped gemm."""
        return {GemmType.REGULAR, GemmType.GROUPED}

    @property
    def compute_dtype(self) -> ComputeDtype:
        """Reported compute dtype."""
        return ComputeDtype.FP32
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
@register_gemm_op
class TF32Baseline(GemmOpBase):
    """
    TF32 matmul baseline.
    """

    def quantize(self, x, w):
        """Cast inputs to float32 and transpose the last two dims of the weights.

        Accepts either single tensors or, for grouped problems, lists of
        tensors; the return value mirrors the input form.
        """
        if isinstance(x, list):
            x = [i.float() for i in x]
            w = [torch.transpose(i, -2, -1).float() for i in w]
        else:
            x = x.float()
            w = torch.transpose(w, -2, -1).float()
        return x, w

    def compute(self, x, w):
        """Run ``torch.matmul`` with float32 matmul precision set to "high".

        The process-wide precision setting is changed only for the duration
        of the call and is always restored, including on the grouped (list)
        path and when the matmul raises.
        """
        original_precision = torch.get_float32_matmul_precision()
        torch.set_float32_matmul_precision("high")
        try:
            # Handle both grouped and standard gemm.
            if isinstance(x, list):
                return [torch.matmul(x_i, w[idx]) for idx, x_i in enumerate(x)]
            return torch.matmul(x, w)
        finally:
            # Previously skipped on the grouped path, which leaked the "high"
            # setting into every op benchmarked afterwards.
            torch.set_float32_matmul_precision(original_precision)

    def quantize_and_compute(self, x, w):
        """Cast the inputs and then multiply them."""
        return self.compute(*self.quantize(x, w))

    @property
    def name(self) -> str:
        """Identifier of this op."""
        return "tf32_baseline"

    @property
    def supported_accelerators(self) -> set[Accelerator]:
        """Every member of ``Accelerator``."""
        return set(Accelerator)

    @property
    def supported_gemm_types(self) -> set[GemmType]:
        """Both regular and grouped gemm."""
        return {GemmType.REGULAR, GemmType.GROUPED}

    @property
    def compute_dtype(self) -> ComputeDtype:
        """Reported compute dtype."""
        return ComputeDtype.TF32
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
@register_gemm_op
class BF16Baseline(GemmOpBase):
    """
    Baseline BF16 matmul.
    """

    def quantize(self, x, w):
        """Cast inputs to bfloat16 and transpose the last two dims of the weights.

        Accepts either single tensors or, for grouped problems, lists of
        tensors; the return value mirrors the input form.
        """
        if not isinstance(x, list):
            return x.bfloat16(), torch.transpose(w, -2, -1).bfloat16()
        activations = [t.bfloat16() for t in x]
        weights = [torch.transpose(t, -2, -1).bfloat16() for t in w]
        return activations, weights

    def compute(self, x, w):
        """Run ``torch.matmul`` on the inputs (one call per group for lists)."""
        if not isinstance(x, list):
            return torch.matmul(x, w)
        # Grouped gemm: pair each activation with the weight at the same index.
        return [torch.matmul(x_i, w[idx]) for idx, x_i in enumerate(x)]

    def quantize_and_compute(self, x, w):
        """Cast the inputs and then multiply them."""
        prepared = self.quantize(x, w)
        return self.compute(*prepared)

    @property
    def name(self) -> str:
        """Identifier of this op."""
        return "bf16_baseline"

    @property
    def supported_accelerators(self) -> set[Accelerator]:
        """Every member of ``Accelerator``."""
        return set(Accelerator)

    @property
    def supported_gemm_types(self) -> set[GemmType]:
        """Both regular and grouped gemm."""
        return {GemmType.REGULAR, GemmType.GROUPED}

    @property
    def compute_dtype(self) -> ComputeDtype:
        """Reported compute dtype."""
        return ComputeDtype.BF16
|
|
338
|
+
|
|
339
|
+
|
|
340
|
+
@register_gemm_op
class ScaledMMBaseline(GemmOpBase):
    """
    Reference FP8 matmul implemented in native torch with cublas or hipblas.
    """

    def __init__(self):
        # The platform's FP8 dtype is the first value of get_fp8_constants();
        # the remaining values are not needed here.
        self.fp8_dtype, _, _, _ = get_fp8_constants()
        # Despite the name, this is the max of whichever dtype self.fp8_dtype is.
        self.E4M3_MAX_POS: float = torch.finfo(self.fp8_dtype).max
        self.E5M2_MAX_POS: float = torch.finfo(torch.float8_e5m2).max
        self.FP16_MAX_POS: float = torch.finfo(torch.float16).max
        # Lower bound on amax so the scale computation never divides by zero.
        self.EPS: float = 1e-12
        self.fast_accum = True

    def _amax_to_scale(
        self, amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
    ) -> torch.Tensor:
        """Convert an absolute-max value into a multiplicative quantization scale.

        Args:
            amax: Absolute maximum of the tensor being quantized.
            float8_dtype: Target FP8 dtype; selects which max value is used.
            orig_dtype: Unused; kept for interface compatibility.

        Returns:
            ``max_representable / clamp(amax, min=EPS)`` as a float32 tensor.
        """
        # To make scale dtype to be fp32 for accuracy
        amax = amax.float()
        if float8_dtype == self.fp8_dtype:
            # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
            res = self.E4M3_MAX_POS / torch.clamp(amax, min=self.EPS)
        else:  # e5m2
            # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
            res = self.E5M2_MAX_POS / torch.clamp(amax, min=self.EPS)

        # pyre-fixme[7]: Expected `Tensor` but got `Union[float, Tensor]`.
        return res

    def _to_fp8_saturated(
        self, x: torch.Tensor, float8_dtype: torch.dtype
    ) -> torch.Tensor:
        """Clamp ``x`` to the representable range and cast it to ``float8_dtype``.

        The dtype check matches ``_amax_to_scale`` by comparing against
        ``self.fp8_dtype``. Comparing only against ``torch.float8_e4m3fn``
        would send a different platform FP8 dtype down the E5M2 branch.
        ``torch.float8_e4m3fn`` is still accepted so existing callers keep
        the same clamp bound.
        """
        if float8_dtype in (self.fp8_dtype, torch.float8_e4m3fn):
            x = x.clamp(min=-1 * self.E4M3_MAX_POS, max=self.E4M3_MAX_POS)
        else:
            x = x.clamp(min=-1 * self.E5M2_MAX_POS, max=self.E5M2_MAX_POS)
        return x.to(float8_dtype)

    def _quantize_tensor(self, x):
        """Quantize ``x`` to FP8 with a single per-tensor scale.

        Returns:
            ``(x_fp8, inverse_scale)`` where ``inverse_scale`` is the
            reciprocal of the scale applied before the cast.
        """
        x_amax = torch.max(torch.abs(x))
        scale = self._amax_to_scale(x_amax, self.fp8_dtype, x.dtype)
        scaled_x = self._to_fp8_saturated(x * scale, self.fp8_dtype)
        x_inverse_scale = scale.reciprocal()
        return scaled_x, x_inverse_scale

    def quantize(self, x, w):
        """Quantize ``x`` and the transpose of ``w`` with per-tensor scales."""
        xq, x_scale = self._quantize_tensor(x)
        wq, w_scale = self._quantize_tensor(w.t())
        return xq, wq, x_scale, w_scale

    def compute(self, xq, wq, x_scale, w_scale):
        """Multiply the quantized operands with ``torch._scaled_mm`` (bf16 output)."""
        output = torch._scaled_mm(
            xq,
            wq,
            bias=None,
            out_dtype=torch.bfloat16,
            scale_a=x_scale,
            scale_b=w_scale,
            scale_result=None,
            use_fast_accum=self.fast_accum,
        )
        return output

    def quantize_and_compute(self, x, w):
        """Quantize the inputs and then multiply them."""
        return self.compute(*self.quantize(x, w))

    @property
    def name(self) -> str:
        """Identifier of this op."""
        return "scaled_mm"

    @property
    def supported_accelerators(self) -> set[Accelerator]:
        """Every member of ``Accelerator``."""
        return set(Accelerator)

    @property
    def supported_gemm_types(self) -> set[GemmType]:
        """Regular gemm only."""
        return {GemmType.REGULAR}

    @property
    def compute_dtype(self) -> ComputeDtype:
        """Reported compute dtype."""
        return ComputeDtype.FP8
|
|
421
|
+
|
|
422
|
+
|
|
423
|
+
@register_gemm_op
class BF16X9Baseline(GemmOpBase):
    """
    FP32 matmul implemented with BF16X9 emulation.
    """

    def quantize(self, x, w):
        """Cast inputs to float32; unlike the torch baselines, ``w`` is not transposed.

        Accepts either single tensors or, for grouped problems, lists of
        tensors; the return value mirrors the input form.
        """
        if not isinstance(x, list):
            return x.float(), w.float()
        activations = [t.float() for t in x]
        weights = [t.float() for t in w]
        return activations, weights

    def compute(self, x, w):
        """Call ``torch.ops.mslk.bf16x9_gemm`` (one call per group for lists)."""
        gemm = torch.ops.mslk.bf16x9_gemm
        if not isinstance(x, list):
            return gemm(x, w)
        # Grouped gemm: pair each activation with the weight at the same index.
        return [gemm(x_i, w[idx]) for idx, x_i in enumerate(x)]

    def quantize_and_compute(self, x, w):
        """Cast the inputs and then multiply them."""
        prepared = self.quantize(x, w)
        return self.compute(*prepared)

    @property
    def name(self) -> str:
        """Identifier of this op."""
        return "bf16x9_gemm"

    @property
    def supported_accelerators(self) -> set[Accelerator]:
        """The three NVIDIA accelerators."""
        return {
            Accelerator.NVIDIA_SM90,
            Accelerator.NVIDIA_SM100,
            Accelerator.NVIDIA_SM103,
        }

    @property
    def supported_gemm_types(self) -> set[GemmType]:
        """Both regular and grouped gemm."""
        return {GemmType.REGULAR, GemmType.GROUPED}

    @property
    def compute_dtype(self) -> ComputeDtype:
        """Reported compute dtype."""
        return ComputeDtype.BF16
|
|
469
|
+
|
|
470
|
+
|
|
471
|
+
@register_gemm_op
class ScaledMMRowwise(GemmOpBase):
    """``torch._scaled_mm`` on inputs quantized with ``quantize_fp8_row``."""

    def __init__(self):
        self.fast_accum = True
        # When true, compute() wraps torch._scaled_mm in torch.compile.
        self.torch_compile = False

    def quantize(self, x, w):
        """Quantize both operands with ``quantize_fp8_row``.

        Returns:
            ``(xq, wq.t(), x_scale.unsqueeze(1), w_scale.unsqueeze(0))``.
        """
        xq, x_scale = quantize_fp8_row(x)
        wq, w_scale = quantize_fp8_row(w)
        wq_t = wq.t()
        x_scale_col = x_scale.unsqueeze(1)
        w_scale_row = w_scale.unsqueeze(0)
        return xq, wq_t, x_scale_col, w_scale_row

    def compute(self, xq, wq, x_scale, w_scale):
        """Multiply the quantized operands with ``torch._scaled_mm`` (bf16 output)."""
        scaled_mm = torch._scaled_mm
        if self.torch_compile:
            scaled_mm = torch.compile(
                torch._scaled_mm,
                options={
                    "max_autotune": True,
                    "max_autotune_gemm_backends": "TRITON,CK,CUTLASS,ATEN",
                },
            )

        return scaled_mm(
            xq,
            wq,
            bias=None,
            out_dtype=torch.bfloat16,
            scale_a=x_scale,
            scale_b=w_scale,
            scale_result=None,
            use_fast_accum=self.fast_accum,
        )

    def quantize_and_compute(self, x, w):
        """Quantize the inputs and then multiply them."""
        quantized = self.quantize(x, w)
        return self.compute(*quantized)

    @property
    def name(self) -> str:
        """Identifier of this op."""
        return "scaled_mm_rowwise"

    @property
    def supported_accelerators(self) -> set[Accelerator]:
        """Every member of ``Accelerator``."""
        return set(Accelerator)

    @property
    def supported_gemm_types(self) -> set[GemmType]:
        """Regular gemm only."""
        return {GemmType.REGULAR}

    @property
    def compute_dtype(self) -> ComputeDtype:
        """Reported compute dtype."""
        return ComputeDtype.FP8
|
|
523
|
+
|
|
524
|
+
|
|
525
|
+
@register_gemm_op
class ScaledMMMXFP8(GemmOpBase):
    """``torch._scaled_mm`` on inputs quantized with ``to_mxfp8``."""

    def __init__(self):
        # When true, compute() wraps torch._scaled_mm in torch.compile.
        self.torch_compile = False

    def quantize(self, x, w):
        """Quantize both operands with ``to_mxfp8`` and pass their scales through ``_to_blocked``.

        Returns:
            ``(xq, wq.t(), x_scale, w_scale)``.
        """
        raw_x_scale, xq = to_mxfp8(x)
        x_scale = _to_blocked(raw_x_scale)
        raw_w_scale, wq = to_mxfp8(w)
        w_scale = _to_blocked(raw_w_scale)
        return xq, wq.t(), x_scale, w_scale

    def compute(self, xq, wq, x_scale, w_scale):
        """Multiply the quantized operands with ``torch._scaled_mm`` (bf16 output)."""
        scaled_mm = torch._scaled_mm
        if self.torch_compile:
            scaled_mm = torch.compile(
                torch._scaled_mm,
                options={
                    "max_autotune": True,
                    "max_autotune_gemm_backends": "TRITON,CK,CUTLASS,ATEN",
                },
            )

        return scaled_mm(
            xq,
            wq,
            bias=None,
            out_dtype=torch.bfloat16,
            scale_a=x_scale,
            scale_b=w_scale,
        )

    def quantize_and_compute(self, x, w):
        """Quantize the inputs and then multiply them."""
        quantized = self.quantize(x, w)
        return self.compute(*quantized)

    @property
    def name(self) -> str:
        """Identifier of this op."""
        return "scaled_mm_mxfp8"

    @property
    def supported_accelerators(self) -> set[Accelerator]:
        """NVIDIA SM100 and SM103."""
        return {Accelerator.NVIDIA_SM100, Accelerator.NVIDIA_SM103}

    @property
    def supported_gemm_types(self) -> set[GemmType]:
        """Regular gemm only."""
        return {GemmType.REGULAR}

    @property
    def compute_dtype(self) -> ComputeDtype:
        """Reported compute dtype."""
        return ComputeDtype.FP8
|
|
576
|
+
|
|
577
|
+
|
|
578
|
+
@register_gemm_op
class ScaledMMNVFP4(GemmOpBase):
    """``torch._scaled_mm`` on inputs quantized with ``triton_quantize_nvfp4``."""

    def __init__(self):
        # When true, compute() wraps torch._scaled_mm in torch.compile.
        self.torch_compile = False

    def quantize(self, x, w):
        """Quantize both operands with ``triton_quantize_nvfp4`` using a global scale of 1.0.

        Returns:
            ``(xq, wq, x_scale, w_scale)`` with the data reinterpreted as
            ``torch.float4_e2m1fn_x2`` and the scales as ``torch.float8_e4m3fn``.
        """
        unit_scale_x = torch.tensor([1.0], device=x.device, dtype=torch.float32)
        unit_scale_w = torch.tensor([1.0], device=w.device, dtype=torch.float32)

        xq, x_scale = triton_quantize_nvfp4(x, unit_scale_x)
        wq, w_scale = triton_quantize_nvfp4(w, unit_scale_w)

        fp4_pair = torch.float4_e2m1fn_x2
        fp8 = torch.float8_e4m3fn
        return xq.view(fp4_pair), wq.view(fp4_pair), x_scale.view(fp8), w_scale.view(fp8)

    def compute(self, xq, wq, x_scale, w_scale):
        """Multiply ``xq`` by ``wq.t()`` with ``torch._scaled_mm`` (bf16 output)."""
        scaled_mm = torch._scaled_mm
        if self.torch_compile:
            scaled_mm = torch.compile(
                torch._scaled_mm,
                options={
                    "max_autotune": True,
                    "max_autotune_gemm_backends": "TRITON,CK,CUTLASS,ATEN",
                },
            )

        return scaled_mm(
            xq,
            wq.t(),
            bias=None,
            out_dtype=torch.bfloat16,
            scale_a=x_scale,
            scale_b=w_scale,
        )

    def quantize_and_compute(self, x, w):
        """Quantize the inputs and then multiply them."""
        quantized = self.quantize(x, w)
        return self.compute(*quantized)

    @property
    def name(self) -> str:
        """Identifier of this op."""
        return "scaled_mm_nvfp4"

    @property
    def supported_accelerators(self) -> set[Accelerator]:
        """NVIDIA SM100 and SM103."""
        return {Accelerator.NVIDIA_SM100, Accelerator.NVIDIA_SM103}

    @property
    def supported_gemm_types(self) -> set[GemmType]:
        """Regular gemm only."""
        return {GemmType.REGULAR}

    @property
    def compute_dtype(self) -> ComputeDtype:
        """Reported compute dtype."""
        return ComputeDtype.FP4
|
|
636
|
+
|
|
637
|
+
|
|
638
|
+
@register_gemm_op
class FP8TensorwiseGemm(GemmOpBase):
    """
    FP8 matmul with tensorwise scaling.
    """

    def quantize(self, x, w):
        """Quantize both operands with ``torch.ops.mslk.quantize_fp8_per_tensor``."""
        per_tensor = torch.ops.mslk.quantize_fp8_per_tensor
        xq, x_scale = per_tensor(x)
        wq, w_scale = per_tensor(w)
        return xq, wq, x_scale, w_scale

    def compute(self, xq, wq, x_scale, w_scale):
        """Call ``torch.ops.mslk.f8f8bf16`` with the product of the two scales."""
        combined_scale = x_scale * w_scale
        return torch.ops.mslk.f8f8bf16(xq, wq, combined_scale)

    def quantize_and_compute(self, x, w):
        """Quantize the inputs and then multiply them."""
        return self.compute(*self.quantize(x, w))

    @property
    def name(self) -> str:
        """Identifier of this op."""
        return "cutlass_tensorwise"

    @property
    def supported_accelerators(self) -> set[Accelerator]:
        """NVIDIA SM90 only."""
        return {Accelerator.NVIDIA_SM90}

    @property
    def supported_gemm_types(self) -> set[GemmType]:
        """Regular gemm only."""
        return {GemmType.REGULAR}

    @property
    def compute_dtype(self) -> ComputeDtype:
        """Reported compute dtype."""
        return ComputeDtype.FP8
|
|
672
|
+
|
|
673
|
+
|
|
674
|
+
@register_gemm_op
class FP8CublasRowwiseGemm(GemmOpBase):
    """
    FP8 cublas matmul with rowwise scaling.
    """

    def quantize(self, x, w):
        """Quantize both operands with ``torch.ops.mslk.quantize_fp8_per_row``."""
        per_row = torch.ops.mslk.quantize_fp8_per_row
        xq, x_scale = per_row(x)
        wq, w_scale = per_row(w)
        return xq, wq, x_scale, w_scale

    def compute(self, xq, wq, x_scale, w_scale):
        """Run the unscaled cublas gemm, then apply the scales with ``scale_fp8_row``."""
        unscaled = torch.ops.mslk.f8f8bf16_cublas(xq, wq)
        return scale_fp8_row(unscaled, x_scale, w_scale)

    def quantize_and_compute(self, x, w):
        """Quantize the inputs and then multiply them."""
        return self.compute(*self.quantize(x, w))

    @property
    def name(self) -> str:
        """Identifier of this op."""
        return "cublas_rowwise"

    @property
    def supported_accelerators(self) -> set[Accelerator]:
        """The three NVIDIA accelerators."""
        return {
            Accelerator.NVIDIA_SM90,
            Accelerator.NVIDIA_SM100,
            Accelerator.NVIDIA_SM103,
        }

    @property
    def supported_gemm_types(self) -> set[GemmType]:
        """Regular gemm only."""
        return {GemmType.REGULAR}

    @property
    def compute_dtype(self) -> ComputeDtype:
        """Reported compute dtype."""
        return ComputeDtype.FP8
|
|
714
|
+
|
|
715
|
+
|
|
716
|
+
@register_gemm_op
class FP8CublasTensorwiseGemm(GemmOpBase):
    """
    FP8 cublas matmul with tensorwise scaling.
    """

    def quantize(self, x, w):
        """Quantize both operands with ``torch.ops.mslk.quantize_fp8_per_tensor``."""
        # Quantize both input tensors.
        xq, x_scale = torch.ops.mslk.quantize_fp8_per_tensor(x)
        wq, w_scale = torch.ops.mslk.quantize_fp8_per_tensor(w)
        return xq, wq, x_scale, w_scale

    def compute(self, xq, wq, x_scale, w_scale):
        """Call ``torch.ops.mslk.f8f8bf16_cublas`` with the product of the two scales."""
        return torch.ops.mslk.f8f8bf16_cublas(xq, wq, x_scale * w_scale)

    def quantize_and_compute(self, x, w):
        """Quantize the inputs and then multiply them.

        ``compute`` multiplies the scales itself, so both scales are passed
        through separately. Passing the pre-multiplied product as a single
        argument left ``compute`` one argument short and raised ``TypeError``.
        """
        xq, wq, x_scale, w_scale = self.quantize(x, w)
        return self.compute(xq, wq, x_scale, w_scale)

    @property
    def name(self) -> str:
        """Identifier of this op."""
        return "cublas_tensorwise"

    @property
    def supported_accelerators(self) -> set[Accelerator]:
        """The three NVIDIA accelerators."""
        return {
            Accelerator.NVIDIA_SM90,
            Accelerator.NVIDIA_SM100,
            Accelerator.NVIDIA_SM103,
        }

    @property
    def supported_gemm_types(self) -> set[GemmType]:
        """Regular gemm only."""
        return {GemmType.REGULAR}

    @property
    def compute_dtype(self) -> ComputeDtype:
        """Reported compute dtype."""
        return ComputeDtype.FP8
|
|
754
|
+
|
|
755
|
+
|
|
756
|
+
@register_gemm_op
class FP8RowwiseGemm(GemmOpBase):
    """
    FP8 matmul with rowwise scaling.

    Weights are quantized once in ``preprocess`` and activations in
    ``quantize``. Inputs may be single tensors or lists/tuples of tensors; in
    the latter case ``self.gemm_op`` is invoked once per entry.
    """

    def __init__(self):
        # Kernel entry point and the accumulation flag forwarded to it.
        self.gemm_op = torch.ops.mslk.f8f8bf16_rowwise
        self.fast_accum = True

    def preprocess(self, x, w):
        """Quantize the weights ahead of time; ``x`` is passed through as is."""
        if not isinstance(w, (list, tuple)):
            wq, w_scale = quantize_fp8_row(w)
            # For 3-D quantized weights the scale is viewed as (wq.size(0), -1).
            if wq.dim() == 3:
                w_scale = w_scale.view(wq.size(0), -1)
            return x, wq, w_scale
        # A list/tuple of weights is quantized entry by entry.
        per_weight = [quantize_fp8_row(weight) for weight in w]
        wq, w_scale = zip(*per_weight)
        return x, wq, w_scale

    def quantize(self, x, wq, w_scale):
        """Quantize the activations rowwise; weights and scales pass through."""
        # Grouped inputs: quantize each activation separately.
        if isinstance(x, (list, tuple)):
            per_input = [quantize_fp8_row(activation) for activation in x]
            xq, x_scale = zip(*per_input)
            return xq, wq, x_scale, w_scale
        # Single tensor input.
        xq, x_scale = quantize_fp8_row(x)
        # For 3-D quantized inputs the scale is viewed as (xq.size(0), -1).
        if xq.dim() == 3:
            x_scale = x_scale.view(xq.size(0), -1)
        return xq, wq, x_scale, w_scale

    def compute(self, xq, wq, x_scale, w_scale):
        """Run ``self.gemm_op`` on grouped, batched (3-D) or plain inputs."""
        # Grouped inputs produce a list with one result per group.
        if isinstance(xq, (list, tuple)):
            return [
                self.gemm_op(
                    xq[g],
                    wq[g],
                    x_scale[g],
                    w_scale[g],
                    use_fast_accum=self.fast_accum,
                )
                for g in range(len(xq))
            ]
        # Batched inputs are unrolled along the leading dimension.
        if xq.dim() == 3 and wq.dim() == 3:
            batch, rows, _ = xq.shape
            _, cols, _ = wq.shape
            y = torch.empty(
                (batch, rows, cols), device=xq.device, dtype=torch.bfloat16
            )
            for b in range(batch):
                y[b] = self.gemm_op(
                    xq[b],
                    wq[b],
                    x_scale[b],
                    w_scale[b],
                    use_fast_accum=self.fast_accum,
                )
            return y
        # Plain case: a single kernel call.
        return self.gemm_op(
            xq, wq, x_scale, w_scale, use_fast_accum=self.fast_accum
        )

    def quantize_and_compute(self, x, wq, w_scale):
        """Quantize the activations and run ``compute`` on the result."""
        quantized = self.quantize(x, wq, w_scale)
        xq, wq, x_scale, w_scale = quantized
        return self.compute(xq, wq, x_scale, w_scale)

    @property
    def name(self) -> str:
        """Op identifier; depends on whether ``torch.version.cuda`` is set."""
        return "cutlass_rowwise" if torch.version.cuda else "ck_rowwise"

    @property
    def supported_accelerators(self) -> set[Accelerator]:
        """Accelerators this op declares support for."""
        return {
            Accelerator.NVIDIA_SM90,
            Accelerator.NVIDIA_SM100,
            Accelerator.NVIDIA_SM103,
            Accelerator.AMD_MI300X,
        }

    @property
    def supported_gemm_types(self) -> set[GemmType]:
        """GEMM types this op declares support for."""
        return {GemmType.REGULAR}

    @property
    def compute_dtype(self) -> ComputeDtype:
        """Compute dtype category of this op."""
        return ComputeDtype.FP8
|
|
843
|
+
|
|
844
|
+
|
|
845
|
+
@register_gemm_op
class FP8RowwisePreshuffleGemm(FP8RowwiseGemm):
    """
    FP8 matmul with rowwise scaling and preshuffling of input B.
    """

    def __init__(self):
        self.fast_accum = True
        # The preshuffle kernel is only looked up when this op reports itself
        # as supported; otherwise ``gemm_op`` is left unset.
        if not self.supported:
            return
        self.gemm_op = torch.ops.mslk.f8f8bf16_rowwise_preshuffle

    def preprocess(self, x, w):
        """Quantize the weights via the parent class, then preshuffle them."""
        prepared_x, row_wq, row_scale = super().preprocess(x, w)
        # NOTE(review): the meaning of the literal 16 is defined by
        # ck_preshuffle, which is not visible here.
        shuffled_wq = ck_preshuffle(row_wq, 16)
        return prepared_x, shuffled_wq, row_scale

    @property
    def name(self) -> str:
        """Identifier string for this op."""
        return "ck_rowwise_preshuffle"

    @property
    def supported_accelerators(self) -> set[Accelerator]:
        """Accelerators this op declares support for."""
        return {Accelerator.AMD_MI300X}

    @property
    def supported_gemm_types(self) -> set[GemmType]:
        """GEMM types this op declares support for."""
        return {GemmType.REGULAR}

    @property
    def compute_dtype(self) -> ComputeDtype:
        """Compute dtype category of this op."""
        return ComputeDtype.FP8
|
|
875
|
+
|
|
876
|
+
|
|
877
|
+
@register_gemm_op
class FP8RowwiseGroupedGemm(GemmOpBase):
    """
    FP8 grouped matmul with rowwise scaling.

    When every weight shares the same first two dimensions, ``preprocess``
    pads and stacks the inputs and the ``*_grouped_dynamic`` kernel is used
    with the per-group row counts. Otherwise inputs stay as lists and the
    plain grouped kernel is used.
    """

    def preprocess(self, x, w):
        """Prepare grouped inputs and pre-quantize the weights.

        Returns:
            ``(x, wq, w_scale, m_values)`` where ``m_values`` is an int64
            tensor of per-group row counts in the stacked case and ``None``
            otherwise.
        """
        # Apply sparsity to inputs if appropriate.
        # First check if N and K are fixed.
        m_values = [i.shape[0] for i in x]
        n_values = [i.shape[0] for i in w]
        k_values = [i.shape[1] for i in w]
        # If so, do specialized version of initialization.
        if len(np.unique(n_values)) == 1 and len(np.unique(k_values)) == 1:
            # Inputs for fixed nk mode must be contiguous, however in the
            # benchmark script they typically are not. Do a little special
            # processing to make them work. In practice this won't be needed.
            # Start by padding along m dimension with zeros.
            max_m = max(m_values)
            x = [
                torch.nn.functional.pad(i, (0, 0, 0, max_m - i.shape[0]), value=0)
                for i in x
            ]
            # Stack inputs into groups.
            x = torch.stack(x).contiguous()
            w = torch.stack(w).contiguous()

            # Preapply weight quantization.
            wq, w_scale = quantize_fp8_row(w)
            # Return processed tensors.
            return (
                x,
                wq,
                w_scale,
                torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device),
            )
        # Otherwise run without sparsity.
        wq, w_scale = zip(*[quantize_fp8_row(i) for i in w])
        return x, wq, w_scale, None

    def quantize(self, x, wq, w_scale, m_values=None):
        """Quantize the activations rowwise; other arguments pass through."""
        # Handle case where inputs are explicitly grouped and non-sparse.
        if isinstance(x, (tuple, list)):
            xq, x_scale = zip(*[triton_quantize_fp8_row(i) for i in x])
            return xq, wq, x_scale, w_scale, m_values
        # Otherwise inputs are unified tensors and sparse.
        B = x.shape[0]
        xq, x_scale = triton_quantize_fp8_row(x, zero_start_index_M=m_values)
        x_scale = x_scale.view(B, -1)
        return xq, wq, x_scale, w_scale, m_values

    def compute(self, xq, wq, x_scale, w_scale, m_values):
        """Dispatch to the grouped or grouped-dynamic kernel based on ``m_values``."""
        if m_values is None:
            return torch.ops.mslk.f8f8bf16_rowwise_grouped(
                xq,
                wq,
                x_scale,
                w_scale,
            )
        # Break tensor into groups, simulates what is done e2e.
        return torch.ops.mslk.f8f8bf16_rowwise_grouped_dynamic(
            xq,
            wq,
            x_scale,
            w_scale,
            zero_start_index_M=m_values,
        )

    def quantize_and_compute(self, x, wq, w_scale, m_values=None):
        """Quantize the activations and run ``compute`` on the result."""
        xq, wq, x_scale, w_scale, m_values = self.quantize(x, wq, w_scale, m_values)
        return self.compute(xq, wq, x_scale, w_scale, m_values)

    @property
    def name(self) -> str:
        """Op identifier; depends on whether ``torch.version.cuda`` is set."""
        if torch.version.cuda:
            return "cutlass_rowwise_grouped"
        else:
            return "ck_rowwise_grouped"

    @property
    def supported_accelerators(self) -> set[Accelerator]:
        """Accelerators this op declares support for."""
        return {
            Accelerator.NVIDIA_SM90,
            Accelerator.NVIDIA_SM100,
            Accelerator.NVIDIA_SM103,
            Accelerator.AMD_MI300X,
        }

    @property
    def supported_gemm_types(self) -> set[GemmType]:
        """GEMM types this op declares support for."""
        return {GemmType.GROUPED}

    @property
    def compute_dtype(self) -> ComputeDtype:
        """Compute dtype category of this op."""
        return ComputeDtype.FP8
|
|
975
|
+
|
|
976
|
+
|
|
977
|
+
@register_gemm_op
class BF16TritonStackedGroupedGemm(GemmOpBase):
    """
    BF16 grouped matmul with stacked inputs implemented with triton.
    """

    def preprocess(self, x, w):
        """Concatenate the groups along dim 0 and record per-group row counts.

        Returns:
            ``(x, w, m_sizes)`` with ``m_sizes`` an int32 tensor placed on the
            device of the first input.
        """
        target_device = x[0].device
        rows_per_group = [group.shape[0] for group in x]
        m_sizes = torch.tensor(rows_per_group).to(
            dtype=torch.int32, device=target_device
        )
        # Flatten weights and inputs into single contiguous tensors.
        stacked_w = torch.concat(w, dim=0).contiguous()
        stacked_x = torch.concat(x, dim=0).contiguous()
        return stacked_x, stacked_w, m_sizes

    def quantize(self, x, w, m_sizes):
        """No quantization is applied; the arguments are returned unchanged."""
        return x, w, m_sizes

    def compute(self, x, w, m_sizes):
        """Call ``grouped_gemm`` with warp specialization enabled."""
        return grouped_gemm(x, w, m_sizes, _use_warp_specialization=True)

    def quantize_and_compute(self, x, w, m_sizes):
        """Run ``quantize`` followed by ``compute``."""
        prepared = self.quantize(x, w, m_sizes)
        x, w, m_sizes = prepared
        return self.compute(x, w, m_sizes)

    @property
    def name(self) -> str:
        """Identifier string for this op."""
        return "triton_bf16_grouped_stacked"

    @property
    def supported_accelerators(self) -> set[Accelerator]:
        """Every member of ``Accelerator`` is declared supported."""
        return set(Accelerator)

    @property
    def supported_gemm_types(self) -> set[GemmType]:
        """GEMM types this op declares support for."""
        return {GemmType.GROUPED}

    @property
    def compute_dtype(self) -> ComputeDtype:
        """Compute dtype category of this op."""
        return ComputeDtype.BF16
|
|
1018
|
+
|
|
1019
|
+
|
|
1020
|
+
@register_gemm_op
class BF16TritonStackedGroupedGemmFuseScatterAdd(BF16TritonStackedGroupedGemm):
    """
    BF16 grouped matmul with stacked inputs implemented with triton. Fused with ScatterAdd.
    """

    def preprocess(self, x, w):
        """Extend the parent preprocessing with an output buffer and indices.

        Returns:
            ``(x, w, m_sizes, output, indices)`` where ``output`` is a zeroed
            bfloat16 ``(M, N)`` tensor and ``indices`` is a random int32
            permutation of ``range(M)``.
        """
        x, w, m_sizes = super().preprocess(x, w)
        total_rows = x.shape[0]
        # The stacked weight's leading dimension divided by the group count.
        cols_per_group = w.shape[0] // m_sizes.shape[0]
        output = torch.zeros(
            total_rows, cols_per_group, dtype=torch.bfloat16, device=x.device
        )
        indices = torch.randperm(total_rows, dtype=torch.int32, device=x.device)
        return x, w, m_sizes, output, indices

    def quantize(self, x, w, m_sizes, *args):
        """Apply the parent ``quantize`` and append the extra arguments."""
        base_result = super().quantize(x, w, m_sizes)
        return (*base_result, *args)

    def compute(self, x, w, m_sizes, output, indices):
        """Call ``grouped_gemm`` with the output tensor and scatter-add indices."""
        fused_kwargs = {
            "_use_warp_specialization": True,
            "_output_tensor": output,
            "_scatter_add_indices": indices,
        }
        return grouped_gemm(x, w, m_sizes, **fused_kwargs)

    def quantize_and_compute(self, x, w, m_sizes, *args):
        """Run ``quantize`` followed by ``compute``."""
        x, w, m_sizes, *extras = self.quantize(x, w, m_sizes, *args)
        return self.compute(x, w, m_sizes, *extras)

    @property
    def name(self) -> str:
        """Identifier string for this op."""
        return "triton_bf16_grouped_stacked_fuse_scatter_add"
|
|
1054
|
+
|
|
1055
|
+
|
|
1056
|
+
@register_gemm_op
class FP8TritonStackedGroupedGemm(GemmOpBase):
    """
    FP8 grouped matmul with rowwise scaling and stacked inputs implemented with triton.
    """

    def preprocess(self, x, w):
        """Quantize the weights per group and concatenate everything along dim 0.

        Returns:
            ``(x, wq, w_scale, m_sizes)`` with ``m_sizes`` an int32 tensor of
            per-group row counts on the device of the first input.
        """
        rows_per_group = [group.shape[0] for group in x]
        m_sizes = torch.tensor(rows_per_group).to(
            dtype=torch.int32, device=x[0].device
        )
        # Quantize each weight, then merge values and scales into single tensors.
        quantized_weights = [quantize_fp8_row(weight) for weight in w]
        wq_parts, scale_parts = zip(*quantized_weights)
        wq = torch.concat(wq_parts, dim=0).contiguous()
        w_scale = torch.concat(scale_parts, dim=0).contiguous()
        # Flatten the inputs the same way.
        stacked_x = torch.concat(x, dim=0).contiguous()
        return stacked_x, wq, w_scale, m_sizes

    def quantize(self, x, wq, w_scale, m_sizes):
        """Quantize the stacked activations rowwise."""
        leading = x.shape[0]
        xq, raw_scale = triton_quantize_fp8_row(x)
        # Scale is viewed as (x.shape[0], -1).
        x_scale = raw_scale.view(leading, -1)
        return xq, wq, x_scale, w_scale, m_sizes

    def compute(self, xq, wq, x_scale, w_scale, m_sizes):
        """Call ``grouped_gemm_fp8_rowwise`` with warp specialization enabled."""
        return grouped_gemm_fp8_rowwise(
            xq,
            wq,
            m_sizes,
            x_scale,
            w_scale,
            _use_warp_specialization=True,
        )

    def quantize_and_compute(self, x, wq, w_scale, m_sizes):
        """Run ``quantize`` followed by ``compute``."""
        quantized = self.quantize(x, wq, w_scale, m_sizes)
        xq, wq, x_scale, w_scale, m_sizes = quantized
        return self.compute(xq, wq, x_scale, w_scale, m_sizes)

    @property
    def name(self) -> str:
        """Identifier string for this op."""
        return "triton_grouped_stacked"

    @property
    def supported_accelerators(self) -> set[Accelerator]:
        """Every member of ``Accelerator`` is declared supported."""
        return set(Accelerator)

    @property
    def supported_gemm_types(self) -> set[GemmType]:
        """GEMM types this op declares support for."""
        return {GemmType.GROUPED}

    @property
    def compute_dtype(self) -> ComputeDtype:
        """Compute dtype category of this op."""
        return ComputeDtype.FP8
|
|
1106
|
+
|
|
1107
|
+
|
|
1108
|
+
@register_gemm_op
class FP8TritonStackedGroupedGemmFuseScatterAdd(FP8TritonStackedGroupedGemm):
    """
    FP8 grouped matmul with stacked inputs implemented with triton. Fused with ScatterAdd.
    """

    def preprocess(self, x, w):
        """Extend the parent preprocessing with an output buffer and indices.

        Returns:
            ``(x, wq, w_scale, m_sizes, output, indices)`` where ``output`` is
            a zeroed bfloat16 ``(M, N)`` tensor and ``indices`` is a random
            int32 permutation of ``range(M)``.
        """
        x, wq, w_scale, m_sizes = super().preprocess(x, w)
        total_rows = x.shape[0]
        # The stacked weight's leading dimension divided by the group count.
        cols_per_group = wq.shape[0] // m_sizes.shape[0]
        output = torch.zeros(
            total_rows, cols_per_group, dtype=torch.bfloat16, device=x.device
        )
        indices = torch.randperm(total_rows, dtype=torch.int32, device=x.device)
        return x, wq, w_scale, m_sizes, output, indices

    def quantize(self, x, wq, w_scale, m_sizes, *args):
        """Apply the parent ``quantize`` and append the extra arguments."""
        base_result = super().quantize(x, wq, w_scale, m_sizes)
        return (*base_result, *args)

    def compute(self, xq, wq, x_scale, w_scale, m_sizes, output, indices):
        """Call ``grouped_gemm_fp8_rowwise`` with the output tensor and indices."""
        fused_kwargs = {
            "_use_warp_specialization": True,
            "_output_tensor": output,
            "_scatter_add_indices": indices,
        }
        return grouped_gemm_fp8_rowwise(
            xq, wq, m_sizes, x_scale, w_scale, **fused_kwargs
        )

    def quantize_and_compute(self, x, wq, w_scale, m_sizes, *args):
        """Run ``quantize`` followed by ``compute``."""
        quantized = self.quantize(x, wq, w_scale, m_sizes, *args)
        xq, wq, x_scale, w_scale, m_sizes, *extras = quantized
        return self.compute(xq, wq, x_scale, w_scale, m_sizes, *extras)

    @property
    def name(self) -> str:
        """Identifier string for this op."""
        return "triton_grouped_stacked_fuse_scatter_add"
|
|
1146
|
+
|
|
1147
|
+
|
|
1148
|
+
@register_gemm_op
class DeepGemmStacked(GemmOpBase):
    """
    FP8 grouped matmul with blockwise scaling implemented with DeepGemm.
    """

    def preprocess(self, x, w):
        """Block-quantize the weights, stack them, and flatten the inputs.

        Returns:
            ``(x, wq, w_scale, m_indices)`` where ``m_indices`` repeats each
            group's index once per row of that group.
        """
        rows_per_group = [group.shape[0] for group in x]
        # One entry per input row holding the index of the group it belongs to.
        group_ids = torch.arange(len(rows_per_group))
        repeats = torch.tensor(rows_per_group)
        m_indices = group_ids.repeat_interleave(repeats).to(
            device=x[0].device, dtype=torch.int
        )
        # Quantize each weight in 128x128 blocks, then stack along a new dim 0.
        quantized_weights = [
            quantize_fp8_block(weight, block_k=128, block_m=128) for weight in w
        ]
        wq_parts, scale_parts = zip(*quantized_weights)
        wq = torch.stack(wq_parts, dim=0).contiguous()
        w_scale = torch.stack(scale_parts, dim=0).contiguous()
        # Flatten the inputs into one contiguous tensor.
        stacked_x = torch.concat(x, dim=0).contiguous()
        return stacked_x, wq, w_scale, m_indices

    def quantize(self, x, wq, w_scale, m_indices):
        """Block-quantize the activations and convert their scales."""
        xq, raw_scale = quantize_fp8_block(x, block_m=1, block_k=128)
        # Pretranspose scales to deepgemm format.
        x_scale = get_col_major_tma_aligned_tensor(raw_scale)
        return xq, wq, x_scale, w_scale, m_indices

    def compute(self, xq, wq, x_scale, w_scale, m_indices):
        """Run the contiguous grouped DeepGemm kernel into a fresh output."""
        # The kernel writes into ``out``; it is allocated here per call.
        out_shape = [xq.shape[0], wq.shape[1]]
        out = torch.empty(out_shape, device=xq.device, dtype=torch.bfloat16)
        lhs = (xq, x_scale)
        rhs = (wq, w_scale)
        m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs, rhs, out, m_indices)
        return out

    def quantize_and_compute(self, x, wq, w_scale, m_indices):
        """Run ``quantize`` followed by ``compute``."""
        quantized = self.quantize(x, wq, w_scale, m_indices)
        xq, wq, x_scale, w_scale, m_indices = quantized
        return self.compute(xq, wq, x_scale, w_scale, m_indices)

    @property
    def name(self) -> str:
        """Identifier string for this op."""
        return "deepgemm_stacked"

    @property
    def supported_accelerators(self) -> set[Accelerator]:
        """NVIDIA SM90 when ``DEEPGEMM_ENABLED`` is truthy, otherwise nothing."""
        if not DEEPGEMM_ENABLED:
            return set()
        return {Accelerator.NVIDIA_SM90}

    @property
    def supported_gemm_types(self) -> set[GemmType]:
        """GEMM types this op declares support for."""
        return {GemmType.GROUPED}

    @property
    def compute_dtype(self) -> ComputeDtype:
        """Compute dtype category of this op."""
        return ComputeDtype.FP8
|
|
1208
|
+
|
|
1209
|
+
|
|
1210
|
+
@register_gemm_op
class DeepGemmMaskedStacked(DeepGemmStacked):
    """
    FP8 grouped matmul with blockwise scaling using the masked DeepGemm kernel.

    Inputs are zero-padded into a single ``(groups, padded_m, k)`` tensor and
    the per-group valid row counts are passed to the kernel as ``masked_m``.
    """

    def preprocess(self, x, w):
        """Block-quantize the weights and pad the inputs into one tensor.

        Returns:
            ``(x_padded, wq, w_scale, masked_m, expected_m, m_values)``.
        """
        # Quantize weights.
        wq, w_scale = zip(*[quantize_fp8_block(i, block_k=128, block_m=128) for i in w])
        # Group weights as single tensor.
        wq = torch.stack(wq, dim=0).contiguous()
        w_scale = torch.stack(w_scale, dim=0).contiguous()

        # Per-group row counts; the largest one is rounded up to a multiple
        # of 128 to size the padded tensor.
        m_values = [i.shape[0] for i in x]
        expected_m = max(m_values)
        padded_m_max = ((expected_m + 127) // 128) * 128
        masked_m = torch.tensor(m_values).to(dtype=torch.int32, device=x[0].device)

        num_groups = len(m_values)
        k = x[0].shape[1]
        x_padded = torch.zeros(
            [num_groups, padded_m_max, k], device=x[0].device, dtype=x[0].dtype
        )
        # Copy each group's rows into the top of its padded slot.
        for g, (rows, group_x) in enumerate(zip(m_values, x)):
            x_padded[g, :rows, :] = group_x

        # Return processed tensors.
        return x_padded, wq, w_scale, masked_m, expected_m, m_values

    def quantize(self, x, wq, w_scale, masked_m, expected_m, m_values):
        """Block-quantize the padded activations and restore the group layout.

        Returns:
            ``(xq, wq, x_scale, w_scale, masked_m, expected_m, m_values)``.
        """
        g, m_max, k = x.shape
        # Quantization is done on the 2-D view; results are viewed back to 3-D.
        xq, x_scale = quantize_fp8_block(x.view(-1, k), block_m=1, block_k=128)
        # Pretranspose scales to deepgemm format.
        x_scale = get_col_major_tma_aligned_tensor(x_scale)
        return (
            xq.view(g, m_max, -1),
            wq,
            x_scale.view(g, m_max, -1),
            w_scale,
            masked_m,
            expected_m,
            m_values,
        )

    def compute(self, xq, wq, x_scale, w_scale, masked_m, expected_m, m_values):
        """Run the masked grouped kernel and return one output slice per group."""
        # Preallocate output.
        out = torch.empty(
            [xq.shape[0], xq.shape[1], wq.shape[1]],
            device=xq.device,
            dtype=torch.bfloat16,
        )
        m_grouped_gemm_fp8_fp8_bf16_nt_masked(
            (xq, x_scale), (wq, w_scale), out, masked_m, expected_m
        )
        # Keep only the first m_values[g] rows of each group's padded output.
        num_groups = xq.shape[0]
        return [out[g, : m_values[g], :] for g in range(num_groups)]

    def quantize_and_compute(self, x, wq, w_scale, masked_m, expected_m, m_values):
        """Run ``quantize`` followed by ``compute``."""
        # quantize() returns seven values, including m_values; all of them
        # must be unpacked.
        xq, wq, x_scale, w_scale, masked_m, expected_m, m_values = self.quantize(
            x, wq, w_scale, masked_m, expected_m, m_values
        )
        return self.compute(xq, wq, x_scale, w_scale, masked_m, expected_m, m_values)

    @property
    def name(self) -> str:
        """Identifier string for this op."""
        return "deepgemm_masked_stacked"
|
|
1274
|
+
|
|
1275
|
+
|
|
1276
|
+
@register_gemm_op
class DeepGemmBlockwise(GemmOpBase):
    """
    FP8 matmul with blockwise scaling implemented with DeepGemm.
    """

    def preprocess(self, x, w):
        """Block-quantize the weight and allocate the output buffer.

        Returns:
            ``(x, wq, w_scale, out)`` with ``out`` an uninitialized bfloat16
            tensor of shape ``(x.shape[0], wq.shape[0])``.
        """
        wq, w_scale = quantize_fp8_block(w, block_m=128, block_k=128)
        # The output is allocated once here and reused by compute().
        out_rows = x.shape[0]
        out_cols = wq.shape[0]
        out = torch.empty(out_rows, out_cols, device=x.device, dtype=torch.bfloat16)
        return x, wq, w_scale, out

    def quantize(self, x, wq, w_scale, out):
        """Group-quantize the activations with ``group_size=128``."""
        quantized_x = quantize_fp8_group(x, group_size=128)
        xq, x_scale = quantized_x
        return xq, wq, x_scale, w_scale, out

    def compute(self, xq, wq, x_scale, w_scale, out):
        """Run ``gemm_fp8_fp8_bf16_nt`` with ``out`` as its output argument."""
        lhs = (xq, x_scale)
        rhs = (wq, w_scale)
        gemm_fp8_fp8_bf16_nt(lhs, rhs, out)
        return out

    def quantize_and_compute(self, x, wq, w_scale, out):
        """Run ``quantize`` followed by ``compute``."""
        quantized = self.quantize(x, wq, w_scale, out)
        xq, wq, x_scale, w_scale, out = quantized
        return self.compute(xq, wq, x_scale, w_scale, out)

    @property
    def name(self) -> str:
        """Identifier string for this op."""
        return "deepgemm_blockwise"

    @property
    def supported_accelerators(self) -> set[Accelerator]:
        """NVIDIA SM90 when ``DEEPGEMM_ENABLED`` is truthy, otherwise nothing."""
        if not DEEPGEMM_ENABLED:
            return set()
        return {Accelerator.NVIDIA_SM90}

    @property
    def supported_gemm_types(self) -> set[GemmType]:
        """GEMM types this op declares support for."""
        return {GemmType.REGULAR}

    @property
    def compute_dtype(self) -> ComputeDtype:
        """Compute dtype category of this op."""
        return ComputeDtype.FP8
|
|
1321
|
+
|
|
1322
|
+
|
|
1323
|
+
@register_gemm_op
class DeepGemmRowwise(GemmOpBase):
    """
    FP8 matmul with rowwise scaling implemented with DeepGemm.
    """

    def preprocess(self, x, w):
        """Row-quantize the weight and allocate the output buffer.

        Returns:
            ``(x, wq, w_scale, out)`` with ``out`` an uninitialized bfloat16
            tensor of shape ``(x.shape[0], wq.shape[0])``.
        """
        wq, w_scale = quantize_fp8_row(w)
        # The output is allocated once here and reused by compute().
        out_rows = x.shape[0]
        out_cols = wq.shape[0]
        out = torch.empty(out_rows, out_cols, device=x.device, dtype=torch.bfloat16)
        return x, wq, w_scale, out

    def quantize(self, x, wq, w_scale, out):
        """Row-quantize the activations and convert their scales."""
        xq, raw_scale = quantize_fp8_row(x)
        # Pretranspose scales to deepgemm format.
        x_scale = get_col_major_tma_aligned_tensor(raw_scale, rowwise_scaling=True)
        return xq, wq, x_scale, w_scale, out

    def compute(self, xq, wq, x_scale, w_scale, out):
        """Run ``gemm_fp8_fp8_bf16_nt`` with ``out`` as its output argument."""
        lhs = (xq, x_scale)
        rhs = (wq, w_scale)
        gemm_fp8_fp8_bf16_nt(lhs, rhs, out)
        return out

    def quantize_and_compute(self, x, wq, w_scale, out):
        """Run ``quantize`` followed by ``compute``."""
        quantized = self.quantize(x, wq, w_scale, out)
        xq, wq, x_scale, w_scale, out = quantized
        return self.compute(xq, wq, x_scale, w_scale, out)

    @property
    def name(self) -> str:
        """Identifier string for this op."""
        return "deepgemm_rowwise"

    @property
    def supported_accelerators(self) -> set[Accelerator]:
        """NVIDIA SM90 when ``DEEPGEMM_ENABLED`` is truthy, otherwise nothing."""
        if not DEEPGEMM_ENABLED:
            return set()
        return {Accelerator.NVIDIA_SM90}

    @property
    def supported_gemm_types(self) -> set[GemmType]:
        """GEMM types this op declares support for."""
        return {GemmType.REGULAR}

    @property
    def compute_dtype(self) -> ComputeDtype:
        """Compute dtype category of this op."""
        return ComputeDtype.FP8
|
|
1370
|
+
|
|
1371
|
+
|
|
1372
|
+
@register_gemm_op
class FP8StackedGroupedGemm(GemmOpBase):
    """
    FP8 grouped matmul with rowwise scaling and stacked inputs.
    """

    def preprocess(self, x, w):
        """Row-quantize and stack the weights; flatten the inputs.

        Returns:
            ``(x, wq, w_scale, m_sizes)`` with ``m_sizes`` an int64 tensor of
            per-group row counts on the device of the first input.
        """
        rows_per_group = [group.shape[0] for group in x]
        m_sizes = torch.tensor(rows_per_group).to(
            dtype=torch.int64, device=x[0].device
        )
        # Quantize each weight, then stack values and scales along a new dim 0.
        quantized_weights = [quantize_fp8_row(weight) for weight in w]
        wq_parts, scale_parts = zip(*quantized_weights)
        wq = torch.stack(wq_parts, dim=0).contiguous()
        w_scale = torch.stack(scale_parts, dim=0).contiguous()
        # Flatten the inputs into one contiguous tensor.
        stacked_x = torch.concat(x, dim=0).contiguous()
        return stacked_x, wq, w_scale, m_sizes

    def quantize(self, x, wq, w_scale, m_sizes):
        """Quantize the stacked activations rowwise."""
        leading = x.shape[0]
        xq, raw_scale = triton_quantize_fp8_row(x)
        # Scale is viewed as (x.shape[0], -1).
        x_scale = raw_scale.view(leading, -1)
        return xq, wq, x_scale, w_scale, m_sizes

    def compute(self, xq, wq, x_scale, w_scale, m_sizes):
        """Call the ``f8f8bf16_rowwise_grouped_stacked`` kernel."""
        kernel = torch.ops.mslk.f8f8bf16_rowwise_grouped_stacked
        return kernel(xq, wq, x_scale, w_scale, m_sizes)

    def quantize_and_compute(self, x, wq, w_scale, m_sizes):
        """Run ``quantize`` followed by ``compute``."""
        quantized = self.quantize(x, wq, w_scale, m_sizes)
        xq, wq, x_scale, w_scale, m_sizes = quantized
        return self.compute(xq, wq, x_scale, w_scale, m_sizes)

    @property
    def name(self) -> str:
        """Op identifier; depends on whether ``torch.version.cuda`` is set."""
        return "cutlass_grouped_stacked" if torch.version.cuda else "ck_grouped_stacked"

    @property
    def supported_accelerators(self) -> set[Accelerator]:
        """Accelerators this op declares support for."""
        return {
            Accelerator.NVIDIA_SM90,
            Accelerator.NVIDIA_SM100,
            Accelerator.NVIDIA_SM103,
            Accelerator.AMD_MI300X,
        }

    @property
    def supported_gemm_types(self) -> set[GemmType]:
        """GEMM types this op declares support for."""
        return {GemmType.GROUPED}

    @property
    def compute_dtype(self) -> ComputeDtype:
        """Compute dtype category of this op."""
        return ComputeDtype.FP8
|
|
1429
|
+
|
|
1430
|
+
|
|
1431
|
+
@register_gemm_op
class FP8StackedGroupedGemmTorch(FP8StackedGroupedGemm):
    """
    Variant of the stacked FP8 grouped matmul that calls
    ``f8f8bf16_rowwise_grouped_mm`` with cumulative offsets and a
    preallocated output tensor.
    """

    def quantize(self, x, wq, w_scale, m_sizes):
        """Quantize via the parent, then derive offsets and an output buffer.

        Returns:
            ``(xq, wq, x_scale, w_scale, offsets, out)``; ``offsets`` is the
            int32 cumulative sum of ``m_sizes`` and ``x_scale`` is viewed as
            one-dimensional.
        """
        parent_result = super().quantize(x, wq, w_scale, m_sizes)
        xq, wq, scale_2d, w_scale, m_sizes = parent_result
        offsets = torch.cumsum(m_sizes, dim=0, dtype=torch.int32)
        out_shape = (xq.shape[0], wq.shape[1])
        out = torch.empty(out_shape, dtype=torch.bfloat16, device=xq.device)
        # View the scale as 1-D with length scale.shape[0].
        x_scale = scale_2d.view(scale_2d.shape[0])
        return xq, wq, x_scale, w_scale, offsets, out

    def compute(self, xq, wq, x_scale, w_scale, offsets, out):
        """Call the ``f8f8bf16_rowwise_grouped_mm`` kernel."""
        kernel = torch.ops.mslk.f8f8bf16_rowwise_grouped_mm
        return kernel(xq, wq, x_scale, w_scale, offsets, out)

    def quantize_and_compute(self, x, wq, w_scale, m_sizes):
        """Run ``quantize`` followed by ``compute``."""
        quantized = self.quantize(x, wq, w_scale, m_sizes)
        xq, wq, x_scale, w_scale, offsets, out = quantized
        return self.compute(xq, wq, x_scale, w_scale, offsets, out)

    @property
    def name(self) -> str:
        """Identifier string for this op."""
        return "ck_grouped_stacked_torch_2d3d"

    @property
    def supported_accelerators(self) -> set[Accelerator]:
        """Accelerators this op declares support for."""
        return {Accelerator.AMD_MI300X}

    @property
    def supported_gemm_types(self) -> set[GemmType]:
        """GEMM types this op declares support for."""
        return {GemmType.GROUPED}

    @property
    def compute_dtype(self) -> ComputeDtype:
        """Compute dtype category of this op."""
        return ComputeDtype.FP8
|
|
1466
|
+
|
|
1467
|
+
|
|
1468
|
+
@register_gemm_op
class ScaledGroupedMMRowwise(FP8StackedGroupedGemmTorch):
    """
    Rowwise-scaled FP8 grouped matmul routed through
    ``torch._scaled_grouped_mm``, optionally wrapped in ``torch.compile``.
    """

    def __init__(self):
        # torch_compile selects the compiled wrapper in compute(); fast_accum
        # is forwarded to the kernel as use_fast_accum.
        self.torch_compile = False
        self.fast_accum = True

    def compute(self, xq, wq, x_scale, w_scale, offsets, _):
        """Call ``torch._scaled_grouped_mm``; the last argument is unused."""
        grouped_mm = torch._scaled_grouped_mm
        if self.torch_compile:
            autotune_options = {
                "max_autotune": True,
                "max_autotune_gemm_backends": "TRITON,CK,CUTLASS,ATEN",
            }
            grouped_mm = torch.compile(grouped_mm, options=autotune_options)

        # The weight is passed with its last two dimensions swapped.
        transposed_wq = wq.transpose(-2, -1)
        return grouped_mm(
            xq,
            transposed_wq,
            offs=offsets,
            out_dtype=torch.bfloat16,
            scale_a=x_scale,
            scale_b=w_scale,
            scale_result=None,
            use_fast_accum=self.fast_accum,
        )

    @property
    def name(self) -> str:
        """Identifier string for this op."""
        return "scaled_grouped_mm_rowwise"

    @property
    def supported_accelerators(self) -> set[Accelerator]:
        """Accelerators this op declares support for."""
        return {
            Accelerator.NVIDIA_SM90,
            Accelerator.AMD_MI300X,
        }

    @property
    def supported_gemm_types(self) -> set[GemmType]:
        """GEMM types this op declares support for."""
        return {GemmType.GROUPED}

    @property
    def compute_dtype(self) -> ComputeDtype:
        """Compute dtype category of this op."""
        return ComputeDtype.FP8
|
|
1515
|
+
|
|
1516
|
+
|
|
1517
|
+
@register_gemm_op
class FP8StackedGroupwiseGroupedGemm(GemmOpBase):
    """
    FP8 grouped matmul with groupwise scaling and stacked inputs.
    """

    def preprocess(self, x, w):
        """Block-quantize and stack the weights; flatten the inputs.

        Returns:
            ``(x, wq, w_scale, m_sizes)`` with ``m_sizes`` an int64 tensor of
            per-group row counts on the device of the first input.
        """
        rows_per_group = [group.shape[0] for group in x]
        m_sizes = torch.tensor(rows_per_group).to(
            dtype=torch.int64, device=x[0].device
        )
        # Quantize each weight in 128x128 blocks with k_major=False.
        quantized_weights = [
            quantize_fp8_block(weight, block_m=128, block_k=128, k_major=False)
            for weight in w
        ]
        wq_parts, scale_parts = zip(*quantized_weights)
        # Stack values and scales along a new dim 0.
        wq = torch.stack(wq_parts, dim=0).contiguous()
        w_scale = torch.stack(scale_parts, dim=0).contiguous()
        # Flatten the inputs into one contiguous tensor.
        stacked_x = torch.concat(x, dim=0).contiguous()
        return stacked_x, wq, w_scale, m_sizes

    def quantize(self, x, wq, w_scale, m_sizes):
        """Group-quantize the activations, passing ``m_sizes`` to the quantizer."""
        quantized_x = quantize_fp8_group(x, m_sizes=m_sizes)
        xq, x_scale = quantized_x
        return xq, wq, x_scale, w_scale, m_sizes

    def compute(self, xq, wq, x_scale, w_scale, m_sizes):
        """Call the ``f8f8bf16_groupwise_grouped`` kernel."""
        kernel = torch.ops.mslk.f8f8bf16_groupwise_grouped
        return kernel(xq, wq, x_scale, w_scale, m_sizes)

    def quantize_and_compute(self, x, wq, w_scale, m_sizes):
        """Run ``quantize`` followed by ``compute``."""
        quantized = self.quantize(x, wq, w_scale, m_sizes)
        xq, wq, x_scale, w_scale, m_sizes = quantized
        return self.compute(xq, wq, x_scale, w_scale, m_sizes)

    @property
    def name(self) -> str:
        """Identifier string for this op."""
        return "cutlass_groupwise_grouped"

    @property
    def supported_accelerators(self) -> set[Accelerator]:
        """Accelerators this op declares support for."""
        return {Accelerator.NVIDIA_SM90}

    @property
    def supported_gemm_types(self) -> set[GemmType]:
        """GEMM types this op declares support for."""
        return {GemmType.GROUPED}

    @property
    def compute_dtype(self) -> ComputeDtype:
        """Compute dtype category of this op."""
        return ComputeDtype.FP8
|
|
1566
|
+
|
|
1567
|
+
|
|
1568
|
+
@register_gemm_op
class BF16GroupedGemm(GemmOpBase):
    """
    BF16 grouped matmul implemented with CK or Cutlass.

    Two kernels are used: ``bf16bf16bf16_grouped`` on the raw per-group
    lists, and ``bf16bf16bf16_grouped_dynamic`` on stacked tensors when every
    weight shares the same first two dimensions.
    """

    def preprocess(self, x, w):
        """
        Prepare the per-group inputs for ``compute``.

        If all weights in ``w`` agree on ``shape[0]`` and on ``shape[1]``,
        the activations are zero-padded along their first dimension to the
        largest group, stacked into a single tensor, and returned together
        with the stacked weights and an int64 tensor of the original
        per-group row counts. Otherwise the inputs are returned unchanged
        with ``None`` as the third element.
        """
        m_values = [xi.shape[0] for xi in x]
        distinct_n = np.unique([wi.shape[0] for wi in w])
        distinct_k = np.unique([wi.shape[1] for wi in w])

        # Heterogeneous weight shapes: leave the lists for the list kernel.
        if len(distinct_n) != 1 or len(distinct_k) != 1:
            return x, w, None

        # Inputs for fixed nk mode must be contiguous, however in the
        # benchmark script they typically are not, so pad every group up to
        # the tallest one and stack. In practice this won't be needed.
        tallest = max(m_values)
        padded_groups = [
            torch.nn.functional.pad(xi, (0, 0, 0, tallest - xi.shape[0]), value=0)
            for xi in x
        ]
        x_stacked = torch.stack(padded_groups).contiguous()
        w_stacked = torch.stack(w).contiguous()
        m_sizes = torch.tensor(m_values).to(
            dtype=torch.int64, device=x_stacked.device
        )
        return x_stacked, w_stacked, m_sizes

    def quantize(self, x, w, m_values=None):
        """Return the inputs unchanged; BF16 needs no quantization step."""
        return x, w, m_values

    def compute(self, x, w, m_values):
        """
        Dispatch to the dynamic kernel when ``m_values`` is provided,
        otherwise to the plain grouped kernel.
        """
        if m_values is not None:
            return torch.ops.mslk.bf16bf16bf16_grouped_dynamic(x, w, m_values)
        return torch.ops.mslk.bf16bf16bf16_grouped(x, w)

    def quantize_and_compute(self, x, w, m_values):
        """Same as ``compute``; there is no quantization work to include."""
        return self.compute(x, w, m_values)

    @property
    def name(self) -> str:
        """Op identifier; depends on whether ``torch.version.cuda`` is set."""
        return "cutlass_bf16_grouped" if torch.version.cuda else "ck_bf16_grouped"

    @property
    def supported_accelerators(self) -> set[Accelerator]:
        """Accelerators this op declares support for."""
        return {
            Accelerator.NVIDIA_SM90,
            Accelerator.NVIDIA_SM100,
            Accelerator.NVIDIA_SM103,
            Accelerator.AMD_MI300X,
        }

    @property
    def supported_gemm_types(self) -> set[GemmType]:
        """GEMM categories this op declares support for (grouped only)."""
        return {GemmType.GROUPED}

    @property
    def compute_dtype(self) -> ComputeDtype:
        """Compute precision this op declares (BF16)."""
        return ComputeDtype.BF16
|
|
1638
|
+
|
|
1639
|
+
|
|
1640
|
+
@register_gemm_op
class FP8RowwiseBatchedGemm(GemmOpBase):
    """
    FP8 batched matmul with rowwise scaling.
    """

    def quantize(self, x, w):
        """
        Stack the per-batch lists and quantize both operands rowwise.

        ``x`` and ``w`` must be Python lists; each is stacked along a new
        leading dimension before being passed to ``quantize_fp8_row``.

        Returns:
            Tuple ``(xq, wq, x_scale, w_scale)``.
        """
        assert isinstance(x, list) and isinstance(w, list)
        x_batched = torch.stack(x, dim=0)
        w_batched = torch.stack(w, dim=0)
        xq, x_scale = quantize_fp8_row(x_batched)
        wq, w_scale = quantize_fp8_row(w_batched)
        return xq, wq, x_scale, w_scale

    def compute(self, xq, wq, x_scale, w_scale):
        """Run the ``f8f8bf16_rowwise_batched`` kernel on quantized inputs."""
        batched_kernel = torch.ops.mslk.f8f8bf16_rowwise_batched
        return batched_kernel(xq, wq, x_scale, w_scale)

    def quantize_and_compute(self, x, w):
        """Quantize both operands, then run the batched matmul."""
        quantized = self.quantize(x, w)
        xq, wq, x_scale, w_scale = quantized
        return self.compute(xq, wq, x_scale, w_scale)

    @property
    def name(self) -> str:
        """Op identifier; depends on whether ``torch.version.cuda`` is set."""
        return (
            "cutlass_rowwise_batched"
            if torch.version.cuda
            else "ck_rowwise_batched"
        )

    @property
    def supported_accelerators(self) -> set[Accelerator]:
        """Accelerators this op declares support for."""
        return {
            Accelerator.NVIDIA_SM90,
            Accelerator.NVIDIA_SM100,
            Accelerator.NVIDIA_SM103,
            Accelerator.AMD_MI300X,
        }

    @property
    def supported_gemm_types(self) -> set[GemmType]:
        """GEMM categories this op declares support for (grouped only)."""
        return {GemmType.GROUPED}

    @property
    def compute_dtype(self) -> ComputeDtype:
        """Compute precision this op declares (FP8)."""
        return ComputeDtype.FP8
|
|
1685
|
+
|
|
1686
|
+
|
|
1687
|
+
# This kernel is broken and causes GPU to lock up, needs some investigation
# @register_gemm_op
class TritonFP8RowwiseGemm(GemmOpBase):
    """
    FP8 matmul with rowwise scaling implemented with Triton.

    The class is intentionally left unregistered (see the note above).
    """

    def __init__(self):
        # Forwarded to matmul_fp8_row as fp8_fast_accum.
        self.fast_accum = True

    def quantize(self, x, w):
        """
        Quantize both operands rowwise and create a random float32 bias.

        The bias has ``w.shape[0]`` elements and lives on ``x``'s device.

        Returns:
            Tuple ``(xq, wq, x_scale, w_scale, bias)``.
        """
        # Quantize both input tensors.
        xq, x_scale = quantize_fp8_row(x)
        wq, w_scale = quantize_fp8_row(w)
        bias = torch.randn(w.shape[0], device=x.device, dtype=torch.float32)
        return xq, wq, x_scale, w_scale, bias

    def compute(self, xq, wq, x_scale, w_scale, bias):
        """Run ``matmul_fp8_row`` with bias and warp specialization enabled."""
        return matmul_fp8_row(
            xq,
            wq,
            x_scale,
            w_scale,
            bias=bias,
            fp8_fast_accum=self.fast_accum,
            use_warp_specialization=True,
        )

    def quantize_and_compute(self, x, w):
        """
        Quantize both operands, then run the Triton matmul.

        ``quantize`` returns five values (the last one is the bias) and
        ``compute`` requires all five, so the bias must be unpacked and
        forwarded here as well.
        """
        xq, wq, x_scale, w_scale, bias = self.quantize(x, w)
        return self.compute(xq, wq, x_scale, w_scale, bias)

    @property
    def name(self) -> str:
        """Identifier string under which this op is reported."""
        return "triton_rowwise"

    @property
    def supported_accelerators(self) -> set[Accelerator]:
        """Every member of ``Accelerator`` is declared as supported."""
        return set(Accelerator)

    @property
    def supported_gemm_types(self) -> set[GemmType]:
        """GEMM categories this op declares support for (regular only)."""
        return {GemmType.REGULAR}

    @property
    def compute_dtype(self) -> ComputeDtype:
        """Compute precision this op declares (FP8)."""
        return ComputeDtype.FP8
|
|
1734
|
+
|
|
1735
|
+
|
|
1736
|
+
@register_gemm_op
class FP8TritonBlockwiseGemm(GemmOpBase):
    """
    FP8 matmul with block scaling.

    Uses ``matmul_fp8_block`` for the matmul itself.
    """

    def quantize(self, x, w):
        """
        Block-quantize both operands with ``quantize_fp8_block(t, 128, 128)``.

        Returns:
            Tuple ``(xq, wq, x_scale, w_scale)``.
        """
        x_quantized, x_scale = quantize_fp8_block(x, 128, 128)
        w_quantized, w_scale = quantize_fp8_block(w, 128, 128)
        return x_quantized, w_quantized, x_scale, w_scale

    def compute(self, xq, wq, x_scale, w_scale):
        """Run ``matmul_fp8_block`` with the fixed 128/128/128 arguments."""
        result = matmul_fp8_block(xq, wq, x_scale, w_scale, 128, 128, 128)
        return result

    def quantize_and_compute(self, x, w):
        """Quantize both operands, then run the blockwise matmul."""
        quantized = self.quantize(x, w)
        xq, wq, x_scale, w_scale = quantized
        return self.compute(xq, wq, x_scale, w_scale)

    @property
    def name(self) -> str:
        """Identifier string under which this op is reported."""
        return "triton_blockwise"

    @property
    def supported_accelerators(self) -> set[Accelerator]:
        """Every member of ``Accelerator`` is declared as supported."""
        return set(Accelerator)

    @property
    def supported_gemm_types(self) -> set[GemmType]:
        """GEMM categories this op declares support for (regular only)."""
        return {GemmType.REGULAR}

    @property
    def compute_dtype(self) -> ComputeDtype:
        """Compute precision this op declares (FP8)."""
        return ComputeDtype.FP8
|
|
1770
|
+
|
|
1771
|
+
|
|
1772
|
+
@register_gemm_op
class FP8CutlassBlockwiseGemm(GemmOpBase):
    """
    FP8 matmul with block scaling.

    Uses the ``torch.ops.mslk.f8f8bf16_blockwise`` kernel for the matmul.
    """

    def quantize(self, x, w):
        """
        Block-quantize both operands with ``quantize_fp8_block(t, 128, 128)``.

        Returns:
            Tuple ``(xq, wq, x_scale, w_scale)``.
        """
        x_quantized, x_scale = quantize_fp8_block(x, 128, 128)
        w_quantized, w_scale = quantize_fp8_block(w, 128, 128)
        return x_quantized, w_quantized, x_scale, w_scale

    def compute(self, xq, wq, x_scale, w_scale):
        """Run ``f8f8bf16_blockwise`` with the fixed 128/128/128 arguments."""
        blockwise_kernel = torch.ops.mslk.f8f8bf16_blockwise
        return blockwise_kernel(xq, wq, x_scale, w_scale, 128, 128, 128)

    def quantize_and_compute(self, x, w):
        """Quantize both operands, then run the blockwise matmul."""
        quantized = self.quantize(x, w)
        xq, wq, x_scale, w_scale = quantized
        return self.compute(xq, wq, x_scale, w_scale)

    @property
    def name(self) -> str:
        """Op identifier; depends on whether ``torch.version.cuda`` is set."""
        return "cutlass_blockwise" if torch.version.cuda else "ck_blockwise"

    @property
    def supported_accelerators(self) -> set[Accelerator]:
        """Accelerators this op declares support for."""
        return {
            Accelerator.NVIDIA_SM90,
            Accelerator.AMD_MI300X,
        }

    @property
    def supported_gemm_types(self) -> set[GemmType]:
        """GEMM categories this op declares support for (regular only)."""
        return {GemmType.REGULAR}

    @property
    def compute_dtype(self) -> ComputeDtype:
        """Compute precision this op declares (FP8)."""
        return ComputeDtype.FP8
|
|
1814
|
+
|
|
1815
|
+
|
|
1816
|
+
@register_gemm_op
class FP8CutlassGroupwiseGemm(GemmOpBase):
    """
    FP8 matmul with group / block scaling.

    Weights are block-quantized once in ``preprocess``; activations are
    group-quantized in ``quantize``.
    """

    def preprocess(self, x, w):
        """
        Block-quantize the weights ahead of time.

        Returns:
            Tuple ``(x, wq, w_scale)`` with ``x`` untouched.
        """
        # The weight scale is expected in [K, N] layout (N major), hence
        # k_major=False.
        weights_q, weights_scale = quantize_fp8_block(
            w, block_m=128, block_k=128, k_major=False
        )
        return x, weights_q, weights_scale

    def quantize(self, x, wq, w_scale):
        """
        Group-quantize the activations; weights pass through unchanged.

        Returns:
            Tuple ``(xq, wq, x_scale, w_scale)``.
        """
        # The activation scale is expected in [K, M] layout (M major), hence
        # k_major=False.
        activations_q, activations_scale = quantize_fp8_group(x, k_major=False)
        return activations_q, wq, activations_scale, w_scale

    def compute(self, xq, wq, x_scale, w_scale):
        """Run the ``f8f8bf16_groupwise`` kernel on quantized inputs."""
        groupwise_kernel = torch.ops.mslk.f8f8bf16_groupwise
        return groupwise_kernel(xq, wq, x_scale, w_scale)

    def quantize_and_compute(self, x, wq, w_scale):
        """Quantize the activations, then run the groupwise matmul."""
        quantized = self.quantize(x, wq, w_scale)
        xq, wq, x_scale, w_scale = quantized
        return self.compute(xq, wq, x_scale, w_scale)

    @property
    def name(self) -> str:
        """Identifier string under which this op is reported."""
        return "cutlass_groupwise"

    @property
    def supported_accelerators(self) -> set[Accelerator]:
        """Accelerators this op declares support for (NVIDIA SM90 only)."""
        return {Accelerator.NVIDIA_SM90}

    @property
    def supported_gemm_types(self) -> set[GemmType]:
        """GEMM categories this op declares support for (regular only)."""
        return {GemmType.REGULAR}

    @property
    def compute_dtype(self) -> ComputeDtype:
        """Compute precision this op declares (FP8)."""
        return ComputeDtype.FP8
|
|
1857
|
+
|
|
1858
|
+
|
|
1859
|
+
@register_gemm_op
class F8I4RowwiseGemm(GemmOpBase):
    """
    Mixed Precision FP8 Activations with Int4 Weights.
    """

    def _int4_row_quantize(
        self,
        x: torch.Tensor,
        group_size: int = 128,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Asymmetric 4-bit quantization over groups of ``group_size`` elements.

        ``x`` is flattened into rows of ``group_size`` values. Each row gets
        a scale of ``(max - min) / 15`` (the range is clamped to at least
        1e-6) and a zero point of ``min + scale * 8``. Values are mapped to
        integer levels 0..15 and then shifted to -8..7.

        Returns:
            Tuple ``(out, scales, zeros)``: ``out`` is int8 with the shape of
            ``x``; ``scales`` and ``zeros`` are viewed as
            ``(x.shape[0], -1)``, transposed and made contiguous.
        """
        bits = 4  # Number of target bits.
        half_range = 2 ** (bits - 1)
        top_level = 2**bits - 1

        grouped = x.reshape(-1, group_size).to(torch.float)
        group_max = grouped.amax(dim=1, keepdim=True)
        group_min = grouped.amin(dim=1, keepdim=True)

        scales = (group_max - group_min).clamp(min=1e-6) / top_level
        zeros = group_min + scales * half_range

        levels = ((grouped - group_min) / scales).round().clamp_(0, top_level)

        # Recenter around zero and store as int8 in the caller's shape.
        out = (levels - half_range).to(dtype=torch.int8).reshape(x.shape)

        # Cutlass expects column major layout for scale and zero point,
        # so both are transposed and made contiguous.
        rows = x.shape[0]
        scales = scales.view(rows, -1).t().contiguous()
        zeros = zeros.view(rows, -1).t().contiguous()

        return out, scales, zeros

    def _pack_int4(self, x: torch.Tensor) -> torch.Tensor:
        """
        Pack adjacent int4 values of an int8 tensor into single int8 values.

        Even-indexed columns supply the low nibble and odd-indexed columns
        the high nibble of each output element.
        """
        # Masking with 0xF strips the sign-extension bits of the low nibble.
        low_nibble = torch.bitwise_and(x[:, ::2], 0xF)
        # Shifting left by four moves the value into the high nibble and
        # discards whatever was above it.
        high_nibble = torch.bitwise_left_shift(x[:, 1::2], 4)
        return torch.bitwise_or(low_nibble, high_nibble).contiguous()

    def quantize(self, x, w):
        """
        Quantize activations to FP8 rowwise and weights to packed int4.

        Returns:
            Tuple ``(xq, wq, x_scale, w_scale, w_zp)``.
        """
        xq, x_scale = quantize_fp8_row(x)
        w_int4, w_scale, w_zp = self._int4_row_quantize(w)
        w_packed = self._pack_int4(w_int4)
        return xq, w_packed, x_scale, w_scale, w_zp

    def compute(self, xq, wq, x_scale, w_scale, w_zp):
        """Run the ``f8i4bf16_rowwise`` kernel on quantized inputs."""
        rowwise_kernel = torch.ops.mslk.f8i4bf16_rowwise
        return rowwise_kernel(xq, wq, x_scale, w_scale, w_zp)

    def quantize_and_compute(self, x, w):
        """Quantize both operands, then run the mixed-precision matmul."""
        quantized = self.quantize(x, w)
        xq, wq, x_scale, w_scale, w_zp = quantized
        return self.compute(xq, wq, x_scale, w_scale, w_zp)

    @property
    def name(self) -> str:
        """Identifier string under which this op is reported."""
        return "cutlass_f8i4_rowwise"

    @property
    def supported_accelerators(self) -> set[Accelerator]:
        """Accelerators this op declares support for (NVIDIA SM90 only)."""
        return {Accelerator.NVIDIA_SM90}

    @property
    def supported_gemm_types(self) -> set[GemmType]:
        """GEMM categories this op declares support for (regular only)."""
        return {GemmType.REGULAR}

    @property
    def compute_dtype(self) -> ComputeDtype:
        """Compute precision this op declares (FP8)."""
        return ComputeDtype.FP8
|
|
1936
|
+
|
|
1937
|
+
|
|
1938
|
+
@register_gemm_op
class F8I4ShuffledGemm(GemmOpBase):
    """
    FP8 activations with preshuffled int4 weights.

    Weights are quantized once by ``quantize_int4_preshuffle``; the matmul
    runs through ``torch.ops.mslk.f8i4bf16_shuffled``.
    """

    def preprocess(self, x, w):
        """
        Prequantize and pack the weights.

        Returns:
            Tuple ``(x, wq, row_scale, group_scale)`` with ``x`` untouched.
        """
        wq, scale_pair = quantize_int4_preshuffle(w)
        group_scale, row_scale = scale_pair
        return x, wq, row_scale, group_scale

    def quantize(self, x, wq, row_scale, group_scale):
        """
        Quantize the activations rowwise; weights pass through unchanged.

        Returns:
            Tuple ``(xq, wq, x_scale, row_scale, group_scale)``.
        """
        activations_q, activations_scale = quantize_fp8_row(x)
        return activations_q, wq, activations_scale, row_scale, group_scale

    def compute(self, xq, wq, x_scale, row_scale, group_scale):
        """
        Run ``f8i4bf16_shuffled``, looping over the leading dimension when
        ``xq`` is three-dimensional.

        In the batched case the per-batch results are written into a
        preallocated bfloat16 tensor of shape ``(B, M, N)``.
        """
        shuffled_kernel = torch.ops.mslk.f8i4bf16_shuffled

        # Plain 2D case: a single kernel call.
        if xq.dim() != 3:
            return shuffled_kernel(xq, wq, x_scale, row_scale, group_scale)

        # Batched case: the kernel is invoked once per batch entry.
        B, M, _ = xq.shape
        _, N, _ = wq.shape
        y = torch.empty((B, M, N), device=xq.device, dtype=torch.bfloat16)
        for batch in range(B):
            y[batch] = shuffled_kernel(
                xq[batch],
                wq[batch],
                x_scale[batch],
                row_scale[batch],
                group_scale[batch],
            )
        return y

    def quantize_and_compute(self, x, wq, row_scale, group_scale):
        """Quantize the activations, then run the shuffled matmul."""
        quantized = self.quantize(x, wq, row_scale, group_scale)
        xq, wq, x_scale, row_scale, group_scale = quantized
        return self.compute(xq, wq, x_scale, row_scale, group_scale)

    @property
    def name(self) -> str:
        """Identifier string under which this op is reported."""
        return "cutlass_f8i4_preshuffle"

    @property
    def supported_accelerators(self) -> set[Accelerator]:
        """Accelerators this op declares support for (NVIDIA SM90 only)."""
        return {Accelerator.NVIDIA_SM90}

    @property
    def supported_gemm_types(self) -> set[GemmType]:
        """GEMM categories this op declares support for (regular only)."""
        return {GemmType.REGULAR}

    @property
    def compute_dtype(self) -> ComputeDtype:
        """Compute precision this op declares (FP8)."""
        return ComputeDtype.FP8
|
|
1985
|
+
|
|
1986
|
+
|
|
1987
|
+
@register_gemm_op
class BF16I4ShuffledGemm(GemmOpBase):
    """
    BF16 activations with preshuffled int4 weights.

    Weights are quantized once by ``quantize_int4_preshuffle`` with
    ``dtype="bf16"``; the matmul runs through
    ``torch.ops.mslk.bf16i4bf16_shuffled``.
    """

    def preprocess(self, x, w):
        """
        Prequantize and pack the weights.

        Returns:
            Tuple ``(x, wq, group_scale, group_zero)`` with ``x`` untouched.
        """
        wq, scale_pair = quantize_int4_preshuffle(w, dtype="bf16")
        group_scale, group_zero = scale_pair
        return x, wq, group_scale, group_zero

    def quantize(self, x, wq, group_scale, group_zero):
        """Return the inputs unchanged; activations stay as given."""
        return x, wq, group_scale, group_zero

    def compute(self, x, wq, group_scale, group_zero):
        """
        Run ``bf16i4bf16_shuffled``, looping over the leading dimension when
        ``x`` is three-dimensional.

        In the batched case the per-batch results are written into a
        preallocated bfloat16 tensor of shape ``(B, M, N)``.
        """
        shuffled_kernel = torch.ops.mslk.bf16i4bf16_shuffled

        # Plain 2D case: a single kernel call.
        if x.dim() != 3:
            return shuffled_kernel(x, wq, group_scale, group_zero)

        # Batched case: the kernel is invoked once per batch entry.
        B, M, _ = x.shape
        _, N, _ = wq.shape
        y = torch.empty((B, M, N), device=x.device, dtype=torch.bfloat16)
        for batch in range(B):
            y[batch] = shuffled_kernel(
                x[batch], wq[batch], group_scale[batch], group_zero[batch]
            )
        return y

    def quantize_and_compute(self, x, wq, group_scale, group_zero):
        """Run ``quantize`` (a pass-through) followed by ``compute``."""
        prepared = self.quantize(x, wq, group_scale, group_zero)
        x, wq, group_scale, group_zero = prepared
        return self.compute(x, wq, group_scale, group_zero)

    @property
    def name(self) -> str:
        """Identifier string under which this op is reported."""
        return "cutlass_bf16i4_preshuffle"

    @property
    def supported_accelerators(self) -> set[Accelerator]:
        """Accelerators this op declares support for (NVIDIA SM90 only)."""
        return {Accelerator.NVIDIA_SM90}

    @property
    def supported_gemm_types(self) -> set[GemmType]:
        """GEMM categories this op declares support for (regular only)."""
        return {GemmType.REGULAR}

    @property
    def compute_dtype(self) -> ComputeDtype:
        """Compute precision this op declares (BF16)."""
        return ComputeDtype.BF16
|
|
2031
|
+
|
|
2032
|
+
|
|
2033
|
+
@register_gemm_op
class BF16I4ShuffledBatchedGemm(GemmOpBase):
    """
    BF16 x INT4 mixed dtype batched gemm with preshuffling.
    """

    def preprocess(self, x, w):
        """
        Stack the per-batch lists and prequantize the stacked weights.

        ``x`` and ``w`` must be Python lists; each is stacked along a new
        leading dimension.

        Returns:
            Tuple ``(x, wq, group_scale, group_zero)``.
        """
        assert isinstance(x, list) and isinstance(w, list)
        x_batched = torch.stack(x, dim=0)
        w_batched = torch.stack(w, dim=0)
        # Prequantize and pack weights.
        wq, scale_pair = quantize_int4_preshuffle(w_batched, dtype="bf16")
        group_scale, group_zero = scale_pair
        return x_batched, wq, group_scale, group_zero

    def quantize(self, x, wq, group_scale, group_zero):
        """Return the inputs unchanged; activations stay as given."""
        return x, wq, group_scale, group_zero

    def compute(self, x, wq, group_scale, group_zero):
        """Run the ``bf16i4bf16_shuffled_batched`` kernel."""
        batched_kernel = torch.ops.mslk.bf16i4bf16_shuffled_batched
        return batched_kernel(x, wq, group_scale, group_zero)

    def quantize_and_compute(self, x, wq, group_scale, group_zero):
        """Run ``quantize`` (a pass-through) followed by ``compute``."""
        prepared = self.quantize(x, wq, group_scale, group_zero)
        x, wq, group_scale, group_zero = prepared
        return self.compute(x, wq, group_scale, group_zero)

    @property
    def name(self) -> str:
        """Identifier string under which this op is reported."""
        return "cutlass_bf16i4_preshuffle_batched"

    @property
    def supported_accelerators(self) -> set[Accelerator]:
        """Accelerators this op declares support for (NVIDIA SM90 only)."""
        return {Accelerator.NVIDIA_SM90}

    @property
    def supported_gemm_types(self) -> set[GemmType]:
        """GEMM categories this op declares support for (grouped only)."""
        return {GemmType.GROUPED}

    @property
    def compute_dtype(self) -> ComputeDtype:
        """Compute precision this op declares (BF16)."""
        return ComputeDtype.BF16
|
|
2075
|
+
|
|
2076
|
+
|
|
2077
|
+
@register_gemm_op
class F8I4ShuffledGroupedGemm(GemmOpBase):
    """
    FP8 x Int4 mixed dtype grouped gemm with preshuffling.
    """

    def preprocess(self, x, w):
        """
        Quantize every group's weights and flatten the grouped inputs.

        ``x`` and ``w`` must be Python lists. The first dimension of each
        activation is recorded in an int32 tensor, each weight is quantized
        with ``quantize_int4_preshuffle``, the per-group results are stacked,
        and the activations are concatenated along dimension 0.

        Returns:
            Tuple ``(x, wq, row_scale, group_scale, m_sizes)``.
        """
        assert isinstance(x, list) and isinstance(w, list), (
            "Only supported for grouped inputs."
        )
        # Per-group row counts, placed on the device of the first input.
        rows_per_group = [xi.shape[0] for xi in x]
        m_sizes = torch.tensor(rows_per_group).to(
            dtype=torch.int32, device=x[0].device
        )

        # Quantize each weight, then regroup the results by kind.
        quantized = [quantize_int4_preshuffle(wi) for wi in w]
        wq_parts, scale_pairs = zip(*quantized)
        group_parts, row_parts = zip(*scale_pairs)

        # Stack per-group pieces into single contiguous tensors.
        wq = torch.stack(wq_parts, dim=0).contiguous()
        row_scale = torch.stack(row_parts, dim=0).contiguous()
        group_scale = torch.stack(group_parts, dim=0).contiguous()

        # Flatten the activations into one tensor.
        x_flat = torch.concat(x, dim=0).contiguous()
        return x_flat, wq, row_scale, group_scale, m_sizes

    def quantize(self, x, wq, row_scale, group_scale, m_sizes):
        """
        Quantize the activations rowwise with the Triton quantizer.

        The resulting scale is viewed as ``(x.shape[0], -1)``; every other
        argument passes through unchanged.

        Returns:
            Tuple ``(xq, wq, x_scale, row_scale, group_scale, m_sizes)``.
        """
        leading = x.shape[0]
        xq, raw_scale = triton_quantize_fp8_row(x)
        x_scale = raw_scale.view(leading, -1)
        return xq, wq, x_scale, row_scale, group_scale, m_sizes

    def compute(self, xq, wq, x_scale, row_scale, group_scale, m_sizes):
        """Run the ``f8i4bf16_shuffled_grouped`` kernel."""
        grouped_kernel = torch.ops.mslk.f8i4bf16_shuffled_grouped
        return grouped_kernel(xq, wq, x_scale, row_scale, group_scale, m_sizes)

    def quantize_and_compute(self, x, wq, row_scale, group_scale, m_sizes):
        """Quantize the activations, then run the grouped matmul."""
        quantized = self.quantize(x, wq, row_scale, group_scale, m_sizes)
        xq, wq, x_scale, row_scale, group_scale, m_sizes = quantized
        return self.compute(xq, wq, x_scale, row_scale, group_scale, m_sizes)

    @property
    def name(self) -> str:
        """Op identifier; depends on whether ``torch.version.cuda`` is set."""
        return (
            "cutlass_f8i4_grouped_preshuffle"
            if torch.version.cuda
            else "ck_f8i4_grouped_preshuffle"
        )

    @property
    def supported_accelerators(self) -> set[Accelerator]:
        """Accelerators this op declares support for (NVIDIA SM90 only)."""
        return {Accelerator.NVIDIA_SM90}

    @property
    def supported_gemm_types(self) -> set[GemmType]:
        """GEMM categories this op declares support for (grouped only)."""
        return {GemmType.GROUPED}

    @property
    def compute_dtype(self) -> ComputeDtype:
        """Compute precision this op declares (FP8)."""
        return ComputeDtype.FP8
|
|
2138
|
+
|
|
2139
|
+
|
|
2140
|
+
@register_gemm_op
class BF16I4ShuffledGroupedGemm(GemmOpBase):
    """
    BF16 x Int4 mixed dtype grouped gemm with preshuffling.
    """

    def preprocess(self, x, w):
        """
        Quantize every group's weights and flatten the grouped inputs.

        ``x`` and ``w`` must be Python lists. The first dimension of each
        activation is recorded in an int32 tensor, each weight is quantized
        with ``quantize_int4_preshuffle(dtype="bf16", use_zp=False)``, the
        per-group results are stacked, and the activations are concatenated
        along dimension 0.

        Returns:
            Tuple ``(x, wq, group_scale, group_zero, m_sizes)``.
        """
        assert isinstance(x, list) and isinstance(w, list), (
            "Only supported for grouped inputs."
        )
        # Per-group row counts, placed on the device of the first input.
        rows_per_group = [xi.shape[0] for xi in x]
        m_sizes = torch.tensor(rows_per_group).to(
            dtype=torch.int32, device=x[0].device
        )

        # Quantize each weight, then regroup the results by kind.
        quantized = [
            quantize_int4_preshuffle(wi, dtype="bf16", use_zp=False) for wi in w
        ]
        wq_parts, scale_pairs = zip(*quantized)
        scale_parts, zero_parts = zip(*scale_pairs)

        # Stack per-group pieces into single contiguous tensors.
        wq = torch.stack(wq_parts, dim=0).contiguous()
        group_scale = torch.stack(scale_parts, dim=0).contiguous()
        group_zero = torch.stack(zero_parts, dim=0).contiguous()

        # Flatten the activations into one tensor.
        x_flat = torch.concat(x, dim=0).contiguous()
        return x_flat, wq, group_scale, group_zero, m_sizes

    def quantize(self, x, wq, group_scale, group_zero, m_sizes):
        """Return the inputs unchanged; activations stay as given."""
        return x, wq, group_scale, group_zero, m_sizes

    def compute(self, x, wq, group_scale, group_zero, m_sizes):
        """Run the ``bf16i4bf16_shuffled_grouped`` kernel."""
        # TODO Zero points aren't currently supported in grouped gemm.
        # They stay in the signature for future compatibility but are ignored.
        grouped_kernel = torch.ops.mslk.bf16i4bf16_shuffled_grouped
        return grouped_kernel(x, wq, group_scale, group_zero, m_sizes)

    def quantize_and_compute(self, x, wq, group_scale, group_zero, m_sizes):
        """Run ``quantize`` (a pass-through) followed by ``compute``."""
        prepared = self.quantize(x, wq, group_scale, group_zero, m_sizes)
        x, wq, group_scale, group_zero, m_sizes = prepared
        return self.compute(x, wq, group_scale, group_zero, m_sizes)

    @property
    def name(self) -> str:
        """Identifier string under which this op is reported."""
        return "cutlass_bf16i4_grouped_preshuffle"

    @property
    def supported_accelerators(self) -> set[Accelerator]:
        """Accelerators this op declares support for (NVIDIA SM90 only)."""
        return {Accelerator.NVIDIA_SM90}

    @property
    def supported_gemm_types(self) -> set[GemmType]:
        """GEMM categories this op declares support for (grouped only)."""
        return {GemmType.GROUPED}

    @property
    def compute_dtype(self) -> ComputeDtype:
        """Compute precision this op declares (BF16)."""
        return ComputeDtype.BF16
|
|
2198
|
+
|
|
2199
|
+
|
|
2200
|
+
@register_gemm_op
class BF16GroupedGrad(GemmOpBase):
    """
    BF16 grouped matmul with dgrad inputs in pretraining backed by cutlass
    """

    def preprocess(self, x, w):
        """
        Build the dgrad-style operands from per-group lists.

        The first dimension of each activation is recorded in an int64
        tensor, the weights are stacked, and the activations are
        concatenated along dimension 0. The stacked weight is returned as a
        ``permute(0, 2, 1)`` view of its contiguous ``permute(0, 2, 1)``
        copy, so it keeps the stacked shape but is not contiguous.

        Returns:
            Tuple ``(x, w, m_sizes)``.
        """
        # Per-group row counts, placed on the device of the first input.
        rows_per_group = [xi.shape[0] for xi in x]
        m_sizes = torch.tensor(rows_per_group).to(
            dtype=torch.int64, device=x[0].device
        )

        # Group weights as a single tensor.
        w_stacked = torch.stack(w, dim=0).contiguous()
        # Prepare online dgrad during pretraining backward: materialize the
        # permuted layout, then hand the kernel a permuted view of it.
        # w.contiguous() is very expensive, so the layout is handled inside
        # the gmm kernel instead.
        w_transposed_storage = w_stacked.permute(0, 2, 1).contiguous()
        w_view = w_transposed_storage.permute(0, 2, 1)

        # Flatten the activations into one tensor.
        x_flat = torch.concat(x, dim=0).contiguous()
        return x_flat, w_view, m_sizes

    def quantize(self, x, w, m_sizes):
        """Return the inputs unchanged; BF16 needs no quantization step."""
        return x, w, m_sizes

    def compute(self, x, w, m_sizes):
        """Run the ``bf16bf16bf16_grouped_grad`` kernel."""
        grad_kernel = torch.ops.mslk.bf16bf16bf16_grouped_grad
        return grad_kernel(x, w, m_sizes)

    def quantize_and_compute(self, x, w, m_sizes):
        """Run ``quantize`` (a pass-through) followed by ``compute``."""
        prepared = self.quantize(x, w, m_sizes)
        x, w, m_sizes = prepared
        return self.compute(x, w, m_sizes)

    @property
    def name(self) -> str:
        """Identifier string under which this op is reported."""
        return "bf16_grouped_grad"

    @property
    def supported_accelerators(self) -> set[Accelerator]:
        """Accelerators this op declares support for."""
        return {
            Accelerator.NVIDIA_SM90,
            Accelerator.NVIDIA_SM100,
            Accelerator.NVIDIA_SM103,
        }

    @property
    def supported_gemm_types(self) -> set[GemmType]:
        """GEMM categories this op declares support for (grouped only)."""
        return {GemmType.GROUPED}

    @property
    def compute_dtype(self) -> ComputeDtype:
        """Compute precision this op declares (BF16)."""
        return ComputeDtype.BF16
|
|
2251
|
+
|
|
2252
|
+
|
|
2253
|
+
@register_gemm_op
class BF16GroupedWGrad(GemmOpBase):
    """
    BF16 grouped matmul with wgrad inputs in pretraining backed by cutlass
    """

    def preprocess(self, x, w):
        """
        Build the wgrad-style operands from per-group lists.

        The second dimension of each activation is recorded in an int64
        tensor. Activations and weights are each concatenated along
        dimension 1 and then transposed into contiguous tensors.

        Returns:
            Tuple ``(x, w, k_sizes)``.
        """
        # K dimension of every group, placed on the device of the first input.
        k_per_group = [xi.shape[1] for xi in x]
        k_sizes = torch.tensor(k_per_group).to(
            dtype=torch.int64, device=x[0].device
        )

        # Join the groups side by side; per the original notes these are
        # (M, G*K) and (N, G*K) respectively.
        x_joined = torch.concat(x, dim=1).contiguous()
        w_joined = torch.concat(w, dim=1).contiguous()

        # Transpose both to simulate wgrad shapes: (G*K, M) and (G*K, N).
        x_wgrad = x_joined.t().contiguous()
        w_wgrad = w_joined.t().contiguous()

        return x_wgrad, w_wgrad, k_sizes

    def quantize(self, x, w, k_sizes):
        """Return the inputs unchanged; BF16 needs no quantization step."""
        return x, w, k_sizes

    def compute(self, x, w, k_sizes):
        """Run the ``bf16bf16bf16_grouped_wgrad`` kernel."""
        wgrad_kernel = torch.ops.mslk.bf16bf16bf16_grouped_wgrad
        return wgrad_kernel(x, w, k_sizes)

    def quantize_and_compute(self, x, w, k_sizes):
        """Run ``quantize`` (a pass-through) followed by ``compute``."""
        prepared = self.quantize(x, w, k_sizes)
        x, w, k_sizes = prepared
        return self.compute(x, w, k_sizes)

    @property
    def name(self) -> str:
        """Identifier string under which this op is reported."""
        return "bf16_grouped_wgrad"

    @property
    def supported_accelerators(self) -> set[Accelerator]:
        """Accelerators this op declares support for."""
        return {
            Accelerator.NVIDIA_SM90,
            Accelerator.NVIDIA_SM100,
            Accelerator.NVIDIA_SM103,
        }

    @property
    def supported_gemm_types(self) -> set[GemmType]:
        """GEMM categories this op declares support for (grouped only)."""
        return {GemmType.GROUPED}

    @property
    def compute_dtype(self) -> ComputeDtype:
        """Compute precision this op declares (BF16)."""
        return ComputeDtype.BF16
|
|
2305
|
+
|
|
2306
|
+
|
|
2307
|
+
@register_gemm_op
class BF16GroupedStacked(GemmOpBase):
    """
    BF16 grouped matmul with stacked inputs backed by cutlass or ck.
    """

    def preprocess(self, x, w):
        """
        Stack the weights and flatten the grouped activations.

        The first dimension of each activation is recorded in an int64
        tensor, the weights are stacked along a new leading dimension, and
        the activations are concatenated along dimension 0.

        Returns:
            Tuple ``(x, w, m_sizes)``.
        """
        # Per-group row counts, placed on the device of the first input.
        rows_per_group = [xi.shape[0] for xi in x]
        m_sizes = torch.tensor(rows_per_group).to(
            dtype=torch.int64, device=x[0].device
        )
        w_stacked = torch.stack(w, dim=0).contiguous()
        x_flat = torch.concat(x, dim=0).contiguous()
        return x_flat, w_stacked, m_sizes

    def quantize(self, x, w, m_sizes):
        """Return the inputs unchanged; BF16 needs no quantization step."""
        return x, w, m_sizes

    def compute(self, x, w, m_sizes):
        """Run the ``bf16bf16bf16_grouped_stacked`` kernel."""
        stacked_kernel = torch.ops.mslk.bf16bf16bf16_grouped_stacked
        return stacked_kernel(x, w, m_sizes)

    def quantize_and_compute(self, x, w, m_sizes):
        """Run ``quantize`` (a pass-through) followed by ``compute``."""
        prepared = self.quantize(x, w, m_sizes)
        x, w, m_sizes = prepared
        return self.compute(x, w, m_sizes)

    @property
    def name(self) -> str:
        """Identifier string under which this op is reported."""
        return "bf16_grouped_stacked"

    @property
    def supported_accelerators(self) -> set[Accelerator]:
        """Accelerators this op declares support for."""
        return {
            Accelerator.NVIDIA_SM90,
            Accelerator.NVIDIA_SM100,
            Accelerator.NVIDIA_SM103,
            Accelerator.AMD_MI300X,
        }

    @property
    def supported_gemm_types(self) -> set[GemmType]:
        """GEMM categories this op declares support for (grouped only)."""
        return {GemmType.GROUPED}

    @property
    def compute_dtype(self) -> ComputeDtype:
        """Compute precision this op declares (BF16)."""
        return ComputeDtype.BF16
|
|
2354
|
+
|
|
2355
|
+
|
|
2356
|
+
@register_gemm_op
class BF16I4RowwiseGemm(F8I4RowwiseGemm):
    """
    Mixed Precision BF16 Activations with Int4 Weights.

    Reuses the int4 quantization and packing helpers inherited from
    ``F8I4RowwiseGemm``.
    """

    def quantize(self, x, w):
        """
        Quantize the weights to packed int4 and cast activations to bfloat16.

        Returns:
            Tuple ``(x, wq, w_scale, w_zp)``.
        """
        w_int4, w_scale, w_zp = self._int4_row_quantize(w)
        # Pack int4 values together.
        w_packed = self._pack_int4(w_int4)
        x_bf16 = x.to(torch.bfloat16)
        return x_bf16, w_packed, w_scale, w_zp

    def compute(self, x, wq, w_scale, w_zp):
        """Run the ``bf16i4bf16_rowwise`` kernel."""
        rowwise_kernel = torch.ops.mslk.bf16i4bf16_rowwise
        return rowwise_kernel(x, wq, w_scale, w_zp)

    def quantize_and_compute(self, x, w):
        """Quantize the weights, then run the mixed-precision matmul."""
        quantized = self.quantize(x, w)
        x, wq, w_scale, w_zp = quantized
        return self.compute(x, wq, w_scale, w_zp)

    @property
    def name(self) -> str:
        """Identifier string under which this op is reported."""
        return "cutlass_bf16i4_rowwise"

    @property
    def supported_accelerators(self) -> set[Accelerator]:
        """Accelerators this op declares support for (NVIDIA SM90 only)."""
        return {Accelerator.NVIDIA_SM90}

    @property
    def supported_gemm_types(self) -> set[GemmType]:
        """GEMM categories this op declares support for (regular only)."""
        return {GemmType.REGULAR}

    @property
    def compute_dtype(self) -> ComputeDtype:
        """Compute precision this op declares (BF16)."""
        return ComputeDtype.BF16
|
|
2396
|
+
|
|
2397
|
+
|
|
2398
|
+
@register_gemm_op
class TinyGemmBF16I4(GemmOpBase):
    """
    Mixed-precision GEMM (BF16 activations, int4 weights) backed by tinygemm.
    """

    def quantize(self, x, w):
        """Group-quantize ``w`` to 4 bits and convert it to the tinygemm layout.

        ``x`` is passed through untouched.

        Returns:
            ``(x, wq, w_scales_and_zeros)``.
        """
        # Group size 128 here must match the value handed to the kernel in
        # ``compute``.
        quantized = group_quantize_tensor(w, n_bit=4, q_group_size=128)
        w_int32, w_scales_and_zeros = quantized
        to_layout = torch.ops.tinygemm.convert_matrix_to_m16n8k16_Aint4_layout
        return x, to_layout(w_int32, 4), w_scales_and_zeros

    def compute(self, x, wq, scale):
        """Invoke the tinygemm int4 kernel; note the weights are passed first."""
        kernel = torch.ops.tinygemm.tinygemm_y_f16RM_x_f16RM_w_int4TC
        return kernel(wq, x, 128, scale, False)

    def quantize_and_compute(self, x, w):
        """Quantize the weights and immediately run the GEMM."""
        return self.compute(*self.quantize(x, w))

    @property
    def name(self) -> str:
        """Identifier of this op."""
        return "tinygemm_bf16i4"

    @property
    def supported_accelerators(self) -> set[Accelerator]:
        """SM90 when ``TINYGEMM_ENABLED`` is truthy, otherwise nothing."""
        return {Accelerator.NVIDIA_SM90} if TINYGEMM_ENABLED else set()

    @property
    def supported_gemm_types(self) -> set[GemmType]:
        """GEMM flavors this op reports support for."""
        return {GemmType.REGULAR}

    @property
    def compute_dtype(self) -> ComputeDtype:
        """Compute dtype category reported for this op."""
        return ComputeDtype.BF16
|
|
2438
|
+
|
|
2439
|
+
|
|
2440
|
+
@register_gemm_op
class MarlinBF16I4(GemmOpBase):
    """
    Mixed-precision GEMM (BF16 activations, int4 weights) backed by Marlin.
    """

    def quantize(self, x, w):
        """Quantize ``w`` with ``marlin_quantize``; ``x`` is passed through.

        Returns:
            ``(x, wq, scale)``.
        """
        # marlin_quantize is fed the transposed weights ([K, N] layout, per
        # the original note) as a contiguous tensor, with group size 128.
        w_kn = w.t().contiguous()
        # The first element of the result is not needed here.
        _, wq, scale = marlin_quantize(w_kn, 128)
        return x, wq, scale

    def compute(self, x, wq, scale):
        """Run the Marlin GEMM op on pre-quantized weights."""
        marlin_op = torch.ops.marlin.marlin_gemm
        return marlin_op(x, wq, scale)

    def quantize_and_compute(self, x, w):
        """Quantize the weights and immediately run the GEMM."""
        return self.compute(*self.quantize(x, w))

    @property
    def name(self) -> str:
        """Identifier of this op."""
        return "marlin_bf16i4"

    @property
    def supported_accelerators(self) -> set[Accelerator]:
        """SM90/SM100/SM103 when ``MARLIN_ENABLED`` is truthy, else nothing."""
        if not MARLIN_ENABLED:
            return set()
        return {
            Accelerator.NVIDIA_SM90,
            Accelerator.NVIDIA_SM100,
            Accelerator.NVIDIA_SM103,
        }

    @property
    def supported_gemm_types(self) -> set[GemmType]:
        """GEMM flavors this op reports support for."""
        return {GemmType.REGULAR}

    @property
    def compute_dtype(self) -> ComputeDtype:
        """Compute dtype category reported for this op."""
        return ComputeDtype.BF16
|
|
2479
|
+
|
|
2480
|
+
|
|
2481
|
+
@register_gemm_op
class MacheteBF16I4(GemmOpBase):
    """
    Mixed-precision GEMM (BF16 activations, int4 weights) backed by Machete.
    """

    def quantize(self, x, w):
        """Quantize and pack ``w`` for Machete; ``x`` is passed through.

        Returns:
            ``(x, wq, scale)``.
        """
        # The Machete helper is given the transposed weights as a contiguous
        # tensor (presumably [K, N], mirroring the Marlin path - confirm).
        w_transposed = w.t().contiguous()
        # Only the packed weights and scales (2nd and 3rd results) are kept.
        _, wq, scale, _ = machete_quantize_and_pack(
            w_transposed, bits=4, groupsize=128
        )
        return x, wq, scale

    def compute(self, x, wq, scale):
        """Run ``machete_gemm`` with the same bit width and group size."""
        return machete_gemm(x, wq, bits=4, groupsize=128, scales=scale)

    def quantize_and_compute(self, x, w):
        """Quantize the weights and immediately run the GEMM."""
        return self.compute(*self.quantize(x, w))

    @property
    def name(self) -> str:
        """Identifier of this op."""
        return "machete_bf16i4"

    @property
    def supported_accelerators(self) -> set[Accelerator]:
        """SM90 when ``MACHETE_ENABLED`` is truthy, otherwise nothing."""
        return {Accelerator.NVIDIA_SM90} if MACHETE_ENABLED else set()

    @property
    def supported_gemm_types(self) -> set[GemmType]:
        """GEMM flavors this op reports support for."""
        return {GemmType.REGULAR}

    @property
    def compute_dtype(self) -> ComputeDtype:
        """Compute dtype category reported for this op."""
        return ComputeDtype.BF16
|
|
2518
|
+
|
|
2519
|
+
|
|
2520
|
+
@register_gemm_op
class NVFP4Gemm(GemmOpBase):
    """
    NVFP4 matmul with block-wise scaling plus a per-tensor global scale.
    """

    def quantize(self, x, w):
        """Quantize both operands to NVFP4.

        Each operand gets a global scale of ``448 * 6`` divided by its
        absolute maximum; the reciprocal of the product of both scales is
        returned so the kernel can undo them.

        Returns:
            ``(xq, wq, x_scale, w_scale, global_scale)``.
        """
        # NOTE(review): 448 and 6 are presumably the FP8-e4m3 and FP4-e2m1
        # maxima - confirm.
        x_absmax = torch.amax(torch.abs(x.flatten()), dim=-1)
        x_global_scale = (448.0 * 6.0) / x_absmax.to(torch.float32)
        w_absmax = torch.amax(torch.abs(w.flatten()), dim=-1)
        w_global_scale = (448.0 * 6.0) / w_absmax.to(torch.float32)
        global_scale = 1 / (x_global_scale * w_global_scale)

        xq, x_scale = triton_quantize_nvfp4(x, x_global_scale)
        wq, w_scale = triton_quantize_nvfp4(w, w_global_scale)

        return xq, wq, x_scale, w_scale, global_scale

    def compute(self, xq, wq, x_scale, w_scale, global_scale):
        """Run ``f4f4bf16`` with the combined global scale."""
        fp4_op = torch.ops.mslk.f4f4bf16
        return fp4_op(xq, wq, x_scale, w_scale, global_scale=global_scale)

    def quantize_and_compute(self, x, w):
        """Quantize both operands and immediately run the GEMM."""
        quantized = self.quantize(x, w)
        xq, wq, x_scale, w_scale, global_scale = quantized
        return self.compute(xq, wq, x_scale, w_scale, global_scale=global_scale)

    @property
    def name(self) -> str:
        """Identifier of this op."""
        return "cutlass_nv_f4f4bf16"

    @property
    def supported_accelerators(self) -> set[Accelerator]:
        """Accelerators this op reports support for."""
        return {Accelerator.NVIDIA_SM100, Accelerator.NVIDIA_SM103}

    @property
    def supported_gemm_types(self) -> set[GemmType]:
        """GEMM flavors this op reports support for."""
        return {GemmType.REGULAR}

    @property
    def compute_dtype(self) -> ComputeDtype:
        """Compute dtype category reported for this op."""
        return ComputeDtype.FP4
|
|
2564
|
+
|
|
2565
|
+
|
|
2566
|
+
@register_gemm_op
class NVFP4Quantize(GemmOpBase):
    """
    NVFP4 quantization with block-wise scaling.

    ``quantize_and_compute`` only runs the quantization step, so this op
    exercises quantization rather than the matmul.
    """

    def quantize_rms(self, x, w):
        """Run the fused RMS-norm + NVFP4 quantize kernel and an unfused path.

        ``w`` only supplies a shape and a device; its values are replaced by
        a random bfloat16 vector of ``group_size`` elements.

        NOTE(review): both results are computed and then discarded; the
        method returns ``None``. Confirm whether a comparison was intended.
        """
        rows, cols = w.shape
        group_size = 16
        norm_weight = torch.randn(group_size, dtype=torch.bfloat16, device=w.device)
        numerator = torch.tensor([448.0 * 6.0]).to(
            device=x.device, dtype=torch.float32
        )
        x_absmax = torch.amax(torch.abs(x.flatten()), dim=-1)
        x_global_scale = numerator / x_absmax.to(torch.float32)

        # Fused path: the weight vector is tiled out to rows * cols elements.
        tiled_weight = norm_weight.repeat(rows * cols // group_size)
        _xq_fused, _x_scale_fused = triton_scale_nvfp4_quant_rms(
            x,
            tiled_weight,
            x_global_scale,
            group_size=group_size,
            EPS=1e-5,
        )

        # Unfused path: RMS-normalize 16-element rows, then quantize.
        normed = rms_norm(x.reshape(-1, 16), norm_weight, eps=1e-5)
        normed = normed.to(torch.bfloat16).reshape(rows, cols)
        _xq_unfused, _x_scale_unfused = triton_quantize_nvfp4(
            normed,
            x_global_scale,
            group_size=group_size,
        )

    def quantize_silu(self, x, w):
        """Run the fused SiLU-mul + NVFP4 quantize kernel and an unfused path.

        NOTE(review): both results are computed and then discarded; the
        method returns ``None``. Confirm whether a comparison was intended.
        """
        rows, cols = x.shape
        group_size = 16
        numerator = torch.tensor([448.0 * 6.0]).to(
            device=x.device, dtype=torch.float32
        )
        x_absmax = torch.amax(torch.abs(x.flatten()), dim=-1)
        x_global_scale = numerator / x_absmax.to(torch.float32)

        # Fused path.
        _xq_fused, _x_scale_fused = triton_scale_nvfp4_quant_silu(
            x,
            w,
            x_global_scale,
            group_size=group_size,
        )

        # Unfused path: silu_mul on 16-element rows, then quantize.
        activated = silu_mul(x.reshape(-1, 16), w.reshape(-1, 16))
        activated = activated.to(torch.bfloat16).reshape(rows, cols)
        _xq_unfused, _x_scale_unfused = triton_quantize_nvfp4(
            activated,
            x_global_scale,
            group_size=group_size,
        )

    def quantize(self, x, w):
        """Quantize both operands to NVFP4.

        Returns:
            ``(xq, wq, x_scale, w_scale, global_scale)`` where
            ``global_scale`` is the reciprocal of the product of the two
            per-tensor scales.
        """
        x_absmax = torch.amax(torch.abs(x.flatten()), dim=-1)
        x_global_scale = (448.0 * 6.0) / x_absmax.to(torch.float32)
        w_absmax = torch.amax(torch.abs(w.flatten()), dim=-1)
        w_global_scale = (448.0 * 6.0) / w_absmax.to(torch.float32)
        global_scale = 1 / (x_global_scale * w_global_scale)

        xq, x_scale = triton_quantize_nvfp4(x, x_global_scale)
        wq, w_scale = triton_quantize_nvfp4(w, w_global_scale)
        return xq, wq, x_scale, w_scale, global_scale

    def compute(self, xq, wq, x_scale, w_scale, global_scale):
        """Run ``f4f4bf16`` with the combined global scale."""
        fp4_op = torch.ops.mslk.f4f4bf16
        return fp4_op(xq, wq, x_scale, w_scale, global_scale=global_scale)

    def quantize_and_compute(self, x, w):
        """Run quantization only; no matmul is performed here."""
        quantized = self.quantize(x, w)
        return quantized

    @property
    def name(self) -> str:
        """Identifier of this op."""
        return "nvfp4_quantize"

    @property
    def supported_accelerators(self) -> set[Accelerator]:
        """Accelerators this op reports support for."""
        return {Accelerator.NVIDIA_SM100, Accelerator.NVIDIA_SM103}

    @property
    def supported_gemm_types(self) -> set[GemmType]:
        """GEMM flavors this op reports support for."""
        return {GemmType.REGULAR}

    @property
    def compute_dtype(self) -> ComputeDtype:
        """Compute dtype category reported for this op."""
        return ComputeDtype.FP4
|
|
2652
|
+
|
|
2653
|
+
|
|
2654
|
+
@register_gemm_op
class MXFP4Gemm(GemmOpBase):
    """
    MXFP4 matmul with block-wise scaling (no global scale).
    """

    def quantize(self, x, w):
        """Quantize both operands with ``triton_quantize_mx4_unpack``.

        Returns:
            ``(xq, wq, x_scale, w_scale)``.
        """
        x_pair = triton_quantize_mx4_unpack(x)
        w_pair = triton_quantize_mx4_unpack(w)
        (xq, x_scale), (wq, w_scale) = x_pair, w_pair
        return xq, wq, x_scale, w_scale

    def compute(self, xq, wq, x_scale, w_scale):
        """Run ``f4f4bf16`` on the quantized operands and their scales."""
        fp4_op = torch.ops.mslk.f4f4bf16
        return fp4_op(xq, wq, x_scale, w_scale)

    def quantize_and_compute(self, x, w):
        """Quantize both operands and immediately run the GEMM."""
        return self.compute(*self.quantize(x, w))

    @property
    def name(self) -> str:
        """Identifier of this op."""
        return "cutlass_f4f4bf16"

    @property
    def supported_accelerators(self) -> set[Accelerator]:
        """Accelerators this op reports support for."""
        return {Accelerator.NVIDIA_SM100, Accelerator.NVIDIA_SM103}

    @property
    def supported_gemm_types(self) -> set[GemmType]:
        """GEMM flavors this op reports support for."""
        return {GemmType.REGULAR}

    @property
    def compute_dtype(self) -> ComputeDtype:
        """Compute dtype category reported for this op."""
        return ComputeDtype.FP4
|
|
2687
|
+
|
|
2688
|
+
|
|
2689
|
+
@register_gemm_op
class MXFP4StackedGroupedGemm(GemmOpBase):
    """
    MXFP4 grouped matmul with blockwise scaling and stacked inputs.

    ``preprocess`` quantizes the weights once; ``quantize`` and
    ``quantize_and_compute`` both consume the tuple ``preprocess`` returns.
    """

    def preprocess(self, x, w):
        """Quantize and stack the per-group weights and record group sizes.

        Args:
            x: Sequence of per-group activation tensors.
            w: Sequence of per-group weight tensors.

        Returns:
            ``(x, wq, w_scale, m_sizes)`` where ``x`` is returned unchanged
            and ``m_sizes`` holds the leading dimension of each ``x[i]`` as
            an int64 tensor on the device of ``x[0]``.
        """
        m_values = [i.shape[0] for i in x]
        m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device)
        wq, w_scale = zip(*[triton_quantize_mx4_unpack(i) for i in w])
        wq = torch.stack(wq, dim=0).contiguous()
        w_scale = torch.stack(w_scale, dim=0).contiguous()
        return x, wq, w_scale, m_sizes

    def quantize(self, x, wq, w_scale, m_sizes):
        """Quantize every non-empty activation group and stack the results.

        Returns:
            ``(xq, wq, x_scale, w_scale, m_sizes, starting_rows)`` where
            ``starting_rows`` has one entry per group plus a leading 0 and
            accumulates the number of scale rows each group produced.
        """
        starting_row_after_padding_list = [0]
        xq_list = []
        x_scale_list = []
        # Fetch all group sizes in a single transfer rather than one
        # ``.item()`` call per group.
        for i, group_rows in enumerate(m_sizes.tolist()):
            next_start = starting_row_after_padding_list[i]
            if group_rows != 0:
                xq, x_scale = triton_quantize_mx4_unpack(x[i])
                xq_list.append(xq)
                x_scale_list.append(x_scale)
                # Scale rows for this group = scale elements divided by the
                # number of 32-wide blocks per row.
                next_start += x_scale.numel() // (x[0].shape[1] // 32)
            # Empty groups contribute no rows, so their start is repeated.
            starting_row_after_padding_list.append(next_start)
        xq = torch.cat(xq_list, dim=0).contiguous()
        x_scale = torch.cat(x_scale_list, dim=0).contiguous()
        x_scale = x_scale.reshape(-1, x[0].shape[-1] // 32)
        xq = xq.view(-1, xq.shape[-1])
        return (
            xq,
            wq,
            x_scale,
            w_scale,
            m_sizes,
            torch.tensor(starting_row_after_padding_list, device=xq.device),
        )

    def compute(self, xq, wq, x_scale, w_scale, m_sizes, starting_row_after_padding):
        """Run the stacked grouped FP4 GEMM op on quantized operands."""
        return torch.ops.mslk.f4f4bf16_grouped_stacked(
            xq,
            wq,
            x_scale,
            w_scale,
            m_sizes,
            starting_row_after_padding=starting_row_after_padding,
        )

    def quantize_and_compute(self, x, wq, w_scale, m_sizes):
        """Quantize the activations and immediately run the grouped GEMM.

        Takes the same arguments as ``quantize`` (i.e. the tuple returned by
        ``preprocess``). The previous ``(x, w)`` signature forwarded only two
        arguments to ``quantize`` and therefore always raised ``TypeError``.
        """
        return self.compute(*self.quantize(x, wq, w_scale, m_sizes))

    @property
    def name(self) -> str:
        """Identifier of this op."""
        return "cutlass_f4f4bf16_grouped_stacked"

    @property
    def supported_accelerators(self) -> set[Accelerator]:
        """Accelerators this op reports support for."""
        return {Accelerator.NVIDIA_SM100, Accelerator.NVIDIA_SM103}

    @property
    def supported_gemm_types(self) -> set[GemmType]:
        """GEMM flavors this op reports support for."""
        return {GemmType.GROUPED}

    @property
    def compute_dtype(self) -> ComputeDtype:
        """Compute dtype category reported for this op."""
        return ComputeDtype.FP4
|
|
2767
|
+
|
|
2768
|
+
|
|
2769
|
+
@register_gemm_op
class NVFP4StackedGroupedGemm(GemmOpBase):
    """
    NVFP4 grouped matmul with blockwise scaling and stacked inputs.

    ``preprocess`` concatenates the activations, quantizes the weights and
    derives per-group global scales; ``quantize`` and
    ``quantize_and_compute`` consume the tuple ``preprocess`` returns.
    """

    def preprocess(self, x, w):
        """Stack inputs, quantize weights and compute per-group global scales.

        Args:
            x: Sequence of per-group activation tensors.
            w: Sequence of per-group weight tensors.

        Returns:
            ``(x, wq, w_scale, x_global_scale, global_scale, m_sizes,
            tensor_idx)`` with ``x`` concatenated along dim 0.
        """
        m_values = [i.shape[0] for i in x]
        m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device)
        x = torch.concat(x, dim=0).contiguous()

        def get_global_scale(x, w, m_sizes):
            """Return activation scales, weight scales, combined scales, tensor_idx."""
            # Activation scales and ``tensor_idx`` come from the helper as-is.
            x_global_scale, tensor_idx = calculate_group_max(x, m_sizes=m_sizes)

            w_global_scale = []
            global_scale = []
            for i in range(len(w)):
                w_global_scale_ = (448.0 * 6.0) / torch.amax(
                    torch.abs(w[i].flatten()), dim=-1
                ).to(torch.float32)
                # Reciprocal of both scales, later passed to the GEMM op.
                global_scale_ = 1 / (x_global_scale[i] * w_global_scale_)

                w_global_scale.append(w_global_scale_)
                global_scale.append(global_scale_)

            return x_global_scale, w_global_scale, global_scale, tensor_idx

        # Compute global scale for each group
        G = m_sizes.numel()
        x_global_scale, w_global_scale, global_scale, tensor_idx = get_global_scale(
            x, w, m_sizes
        )
        global_scale = torch.stack(global_scale, dim=0).contiguous()

        wq, w_scale = zip(
            *[triton_quantize_nvfp4(w[i], w_global_scale[i]) for i in range(G)]
        )
        wq = torch.stack(wq, dim=0).contiguous()
        w_scale = torch.stack(w_scale, dim=0).contiguous()

        return x, wq, w_scale, x_global_scale, global_scale, m_sizes, tensor_idx

    def quantize(
        self, x, wq, w_scale, x_global_scale, global_scale, m_sizes, tensor_idx
    ):
        """Quantize the stacked activations with ``mega_fp4_quantize_kernel``.

        Returns:
            ``(xq, wq, x_scale, w_scale, m_sizes, global_scale,
            starting_row_after_padding)``.
        """
        # Alternative method, may be useful in some scenarios:
        #
        #   starting_row_after_padding, belong_indices, row_within_tensor = (
        #       nvfp4_fused_padding_cumsum_and_segmented_arange(m_sizes, x.shape[0])
        #       # fused_single_block_cumsum_and_segmented_arange(m_sizes, x.shape[0])
        #   )
        #   xq, x_scale = triton_nvfp4_quant_stacked(
        #       x,
        #       x_global_scale[0],
        #       belong_indices,
        #       starting_row_after_padding,
        #       row_within_tensor,
        #   )

        # we can optionally set optional_tensor_idx to None to run the alternative method
        xq, x_scale, starting_row_after_padding = mega_fp4_quantize_kernel(
            m_sizes, x, x_global_scale, optional_tensor_idx=tensor_idx
        )

        # One scale per 16 elements along the last dimension of ``x``.
        x_scale = x_scale.reshape(-1, x.shape[1] // 16)
        return (
            xq,
            wq,
            x_scale,
            w_scale,
            m_sizes,
            global_scale,
            starting_row_after_padding,
        )

    def compute(
        self,
        xq,
        wq,
        x_scale,
        w_scale,
        m_sizes,
        global_scale,
        starting_row_after_padding,
    ):
        """Run the stacked grouped FP4 GEMM op with ``use_mx=False``."""
        gemm_result = torch.ops.mslk.f4f4bf16_grouped_stacked(
            xq,
            wq,
            x_scale,
            w_scale,
            m_sizes,
            global_scale,
            starting_row_after_padding,
            use_mx=False,
        )
        return gemm_result

    def quantize_and_compute(
        self, x, wq, w_scale, x_global_scale, global_scale, m_sizes, tensor_idx
    ):
        """Quantize the activations and immediately run the grouped GEMM."""
        quantized = self.quantize(
            x, wq, w_scale, x_global_scale, global_scale, m_sizes, tensor_idx
        )
        return self.compute(*quantized)

    @property
    def name(self) -> str:
        """Identifier of this op."""
        return "cutlass_nv_f4f4bf16_grouped_stacked"

    @property
    def supported_accelerators(self) -> set[Accelerator]:
        """Accelerators this op reports support for."""
        return {Accelerator.NVIDIA_SM100, Accelerator.NVIDIA_SM103}

    @property
    def supported_gemm_types(self) -> set[GemmType]:
        """GEMM flavors this op reports support for."""
        return {GemmType.GROUPED}

    @property
    def compute_dtype(self) -> ComputeDtype:
        """Compute dtype category reported for this op."""
        return ComputeDtype.FP4
|
|
2915
|
+
|
|
2916
|
+
|
|
2917
|
+
# Broken with cuda graph
|
|
2918
|
+
# @register_gemm_op
|
|
2919
|
+
class NVFP4StackedGroupedGemmPackUnpack(GemmOpBase):
    """
    NVFP4 stacked grouped matmul that quantizes via a pack/unpack round trip.

    ``quantize`` also runs ``mega_fp4_quantize_kernel`` as a reference, and
    ``compute`` asserts that both quantization paths give matching results.
    """

    def preprocess(self, x, w):
        """Stack inputs, quantize weights and compute global scales.

        Returns:
            ``(x, wq, w_scale, x_global_scale, global_scale, m_sizes)``.
        """
        group_rows = [t.shape[0] for t in x]
        m_sizes = torch.tensor(group_rows).to(dtype=torch.int64, device=x[0].device)
        x = torch.concat(x, dim=0).contiguous()

        def get_global_scale(x, w):
            """Return per-group activation, weight and combined scales."""
            # A single activation scale is derived from all of ``x`` and
            # repeated for every group.
            x_gs = (448.0 * 6.0) / torch.amax(torch.abs(x.flatten()), dim=-1).to(
                torch.float32
            )
            w_scales = [
                (448.0 * 6.0)
                / torch.amax(torch.abs(w_i.flatten()), dim=-1).to(torch.float32)
                for w_i in w
            ]
            combined = [1 / (x_gs * w_gs) for w_gs in w_scales]
            x_scales = [x_gs for _ in w_scales]
            return x_scales, w_scales, combined

        # Compute global scale for each group
        G = m_sizes.numel()
        x_global_scale, w_global_scale, global_scale = get_global_scale(x, w)

        global_scale = torch.stack(global_scale, dim=0).contiguous()

        quantized_weights = [
            triton_quantize_nvfp4(w[i], w_global_scale[i]) for i in range(G)
        ]
        wq, w_scale = zip(*quantized_weights)
        wq = torch.stack(wq, dim=0).contiguous()
        w_scale = torch.stack(w_scale, dim=0).contiguous()
        x_global_scale = torch.tensor(x_global_scale, device=m_sizes.device)
        return x, wq, w_scale, x_global_scale, global_scale, m_sizes

    def quantize(self, x, wq, w_scale, x_global_scale, global_scale, m_sizes):
        """Quantize ``x`` via pack/unpack and via the reference kernel.

        Returns:
            The pack/unpack results followed by the reference ("other")
            results, in the order ``compute`` expects.
        """
        # An alternative packing that only uses the overall global scale
        # rather than per tensor: mega_fp4_pack(x, x_global_scale[0])
        packed = mega_fp4_pack(
            x,
            x_global_scale,
            per_tensor=True,
            m_sizes=m_sizes,
        )
        xq, x_scale, starting_row_after_padding = mega_fp4_unpack(m_sizes, packed)

        # Reference path used for the consistency check in ``compute``.
        reference = mega_fp4_quantize_kernel(
            m_sizes,
            x,
            x_global_scale,
        )
        xq_other, x_scale_other, starting_row_after_padding_other = reference

        scale_cols = x.shape[1] // 16
        x_scale = x_scale.reshape(-1, scale_cols)
        x_scale_other = x_scale_other.reshape(-1, scale_cols)
        return (
            xq,
            wq,
            x_scale,
            w_scale,
            m_sizes,
            global_scale,
            starting_row_after_padding,
            xq_other,
            x_scale_other,
            starting_row_after_padding_other,
        )

    def compute(
        self,
        xq,
        wq,
        x_scale,
        w_scale,
        m_sizes,
        global_scale,
        starting_row_after_padding,
        xq_other,
        x_scale_other,
        starting_row_after_padding_other,
    ):
        """Run the grouped GEMM for both quantization paths and compare them."""
        grouped_op = torch.ops.mslk.f4f4bf16_grouped_stacked
        ref_solution = grouped_op(
            xq_other,
            wq,
            x_scale_other,
            w_scale,
            m_sizes,
            global_scale,
            starting_row_after_padding_other,
            use_mx=False,
        )
        gemm_result = grouped_op(
            xq,
            wq,
            x_scale,
            w_scale,
            m_sizes,
            global_scale,
            starting_row_after_padding,
            use_mx=False,
        )
        # Both paths are expected to agree.
        assert torch.allclose(ref_solution, gemm_result)

        return gemm_result

    def quantize_and_compute(
        self, x, wq, w_scale, x_global_scale, global_scale, m_sizes
    ):
        """Quantize the activations and immediately run the checked GEMM."""
        quantized = self.quantize(
            x, wq, w_scale, x_global_scale, global_scale, m_sizes
        )
        return self.compute(*quantized)

    @property
    def name(self) -> str:
        """Identifier of this op."""
        return "cutlass_nv_f4f4bf16_grouped_stacked_pack_unpack"

    @property
    def supported_accelerators(self) -> set[Accelerator]:
        """Accelerators this op reports support for."""
        return {Accelerator.NVIDIA_SM100, Accelerator.NVIDIA_SM103}

    @property
    def supported_gemm_types(self) -> set[GemmType]:
        """GEMM flavors this op reports support for."""
        return {GemmType.GROUPED}

    @property
    def compute_dtype(self) -> ComputeDtype:
        """Compute dtype category reported for this op."""
        return ComputeDtype.FP4
|
|
3088
|
+
|
|
3089
|
+
|
|
3090
|
+
@register_gemm_op
class BF16GroupedGemm2d3d(GemmOpBase):
    """
    Torch BF16 grouped GEMM with 2D inputs and 3D weights.

    Serves as the unquantized baseline: ``quantize`` is a pass-through.
    """

    def preprocess(self, x, w):
        """Concatenate activations, stack weights and build group end offsets.

        Args:
            x: List of per-group activation tensors.
            w: List of per-group weight tensors.

        Returns:
            ``(x, w, offs)`` where ``offs`` is the int32 cumulative sum of
            each group's leading dimension.
        """
        assert isinstance(x, list)
        assert isinstance(w, list)
        offs = torch.tensor(
            [i.shape[0] for i in x], dtype=torch.int32, device=x[0].device
        )
        offs = torch.cumsum(offs, dim=0).to(torch.int32)
        x = torch.cat(x, dim=0).contiguous()  # (G * M, K)
        w = torch.stack(w, dim=0).contiguous()  # (G, N, K)
        return x, w, offs

    def quantize(self, x, w, offs):
        """Return the inputs unchanged; no quantization is performed."""
        return x, w, offs

    def compute(self, x, w, offs):
        """Run ``torch._grouped_mm`` with the weights' last two dims swapped."""
        return torch._grouped_mm(
            x,
            w.transpose(-2, -1),
            offs=offs,
        )

    def quantize_and_compute(self, x, w, offs):
        """Run the (no-op) quantize step followed by the grouped GEMM."""
        # ``offs`` must be forwarded: ``quantize`` requires it, and omitting
        # it made this method raise ``TypeError`` on every call.
        x, w, offs = self.quantize(x, w, offs)
        return self.compute(x, w, offs)

    @property
    def name(self) -> str:
        """Identifier of this op."""
        return "bf16_baseline_grouped_2d_3d"

    @property
    def supported_accelerators(self) -> set[Accelerator]:
        """Every member of ``Accelerator``."""
        return set(Accelerator)

    @property
    def supported_gemm_types(self) -> set[GemmType]:
        """GEMM flavors this op reports support for."""
        return {GemmType.GROUPED}

    @property
    def compute_dtype(self) -> ComputeDtype:
        """Compute dtype category reported for this op."""
        return ComputeDtype.BF16
|
|
3136
|
+
|
|
3137
|
+
|
|
3138
|
+
@register_gemm_op
class MXFP8GroupedGemm2d3d(GemmOpBase):
    """
    MXFP8 grouped GEMM with 2D inputs and 3D weights.
    """

    def preprocess(self, x, w):
        """Concatenate the activation list and stack the weight list."""
        assert isinstance(x, list)
        assert isinstance(w, list)
        stacked_x = torch.cat(x, dim=0).contiguous()  # (G * M, K)
        stacked_w = torch.stack(w, dim=0).contiguous()  # (G, N, K)
        return stacked_x, stacked_w

    def quantize(self, x, w):
        """Quantize each weight matrix and each activation group to MXFP8.

        Group boundaries are rebuilt here by splitting the rows of ``x``
        into ``G`` equally sized groups (``total_M // G`` rows each).

        Returns:
            ``(xq, wq, x_scale, w_scale, input_group_end_offsets)``.
        """
        block_size = 32
        G, _, K = w.shape
        total_M = x.shape[0]
        rows_per_group = total_M // G
        input_group_end_offsets = torch.arange(
            rows_per_group,
            total_M + 1,
            rows_per_group,
            dtype=torch.int32,
            device=x.device,
        )

        # Each 2D slice of the 3D weights feeds an independent GEMM, so it is
        # quantized and its scale converted to blocked format on its own.
        wq_parts = []
        w_scale_parts = []
        for g in range(G):
            scale_g, wq_g = to_mxfp8(w[g])
            wq_parts.append(wq_g)
            w_scale_parts.append(_to_blocked(scale_g))
        wq = torch.stack(wq_parts, dim=0).contiguous()
        w_scale = torch.stack(w_scale_parts, dim=0).contiguous()

        # Likewise, each row group of the 2D activations is quantized and
        # converted separately.
        xq_parts = []
        x_scale_parts = []
        for g in range(G):
            start = 0 if g == 0 else input_group_end_offsets[g - 1]
            stop = input_group_end_offsets[g]
            if stop - start > 0:
                scale_g, xq_g = to_mxfp8(x[start:stop, :])
                xq_parts.append(xq_g)
                x_scale_parts.append(_to_blocked(scale_g))
        xq = torch.cat(xq_parts, dim=0).contiguous()
        x_scale = torch.cat(x_scale_parts, dim=0).contiguous()
        x_scale = x_scale.reshape(-1, K // block_size)
        xq = xq.view(-1, xq.shape[-1])
        return xq, wq, x_scale, w_scale, input_group_end_offsets

    def compute(self, xq, wq, x_scale, w_scale, input_group_end_offsets):
        """Run the MXFP8 grouped GEMM op with the weights' last two dims swapped."""
        grouped_op = torch.ops.mslk.mx8mx8bf16_grouped_mm
        return grouped_op(
            xq,
            wq.transpose(-2, -1),
            x_scale,
            w_scale,
            input_group_end_offsets,
        )

    def quantize_and_compute(self, x, w):
        """Quantize both operands and immediately run the grouped GEMM."""
        return self.compute(*self.quantize(x, w))

    @property
    def name(self) -> str:
        """Identifier of this op."""
        return "cutlass_mx8mx8bf16_grouped_mm_2d_3d"

    @property
    def supported_accelerators(self) -> set[Accelerator]:
        """Accelerators this op reports support for."""
        return {Accelerator.NVIDIA_SM100, Accelerator.NVIDIA_SM103}

    @property
    def supported_gemm_types(self) -> set[GemmType]:
        """GEMM flavors this op reports support for."""
        return {GemmType.GROUPED}

    @property
    def compute_dtype(self) -> ComputeDtype:
        """Compute dtype category reported for this op."""
        return ComputeDtype.FP8
|
|
3226
|
+
|
|
3227
|
+
|
|
3228
|
+
@register_gemm_op
|
|
3229
|
+
class MXFP8GroupedGemm2d2d(GemmOpBase):
|
|
3230
|
+
"""
|
|
3231
|
+
MXFP8 grouped GEMM with 2D inputs and 3D weights.
|
|
3232
|
+
"""
|
|
3233
|
+
|
|
3234
|
+
def preprocess(self, x, w):
|
|
3235
|
+
assert isinstance(x, list)
|
|
3236
|
+
assert isinstance(w, list)
|
|
3237
|
+
G = len(x)
|
|
3238
|
+
x = torch.cat(x, dim=1).contiguous() # (M, total_K)
|
|
3239
|
+
w = torch.cat(w, dim=1).contiguous() # (N, total_K)
|
|
3240
|
+
return x, w, G
|
|
3241
|
+
|
|
3242
|
+
def quantize(self, x, w, G):
    """Quantize both operands to MXFP8, one K-group at a time.

    Simulates the 2d-2d grouped GEMM of a backward pass,
    ``grad_weight = grad_output_t @ input``, where "K" is the contracting
    dimension and is split into ``G`` groups.

    Args:
        x: 2D tensor of shape (M, total_K).
        w: 2D tensor whose first dimension is N; it is sliced with the same
            K offsets as ``x``.
        G: number of groups along K. Each group spans ``total_K // G``
            columns.
            NOTE(review): presumably total_K is expected to be a multiple
            of G -- confirm against callers.

    Returns:
        ``(xq, wq, x_blocked_scales, w_blocked_scales,
        input_group_end_offsets)``: the quantized operands (groups
        concatenated along dim 1), the per-group blocked scales reshaped to
        ``round_up(M, 128)`` / ``round_up(N, 128)`` rows, and the int32
        tensor of group end offsets along K (on ``x.device``).
    """
    M, total_K = x.shape
    N, _ = w.shape
    cols_per_group = total_K // G
    input_group_end_offsets = torch.arange(
        cols_per_group, total_K + 1, cols_per_group, dtype=torch.int32, device=x.device
    )

    # Copy the offsets to the host once. Indexing the device tensor inside
    # the loop would force a host/device synchronization for every slice
    # bound and for every size check.
    group_ends = input_group_end_offsets.tolist()[:G]
    group_starts = [0] + group_ends[:-1]

    def round_up(value: int, multiple: int) -> int:
        return ((value + multiple - 1) // multiple) * multiple

    xq_groups = []
    wq_groups = []
    x_blocked_scale_groups = []
    w_blocked_scale_groups = []

    for start, end in zip(group_starts, group_ends):
        if end <= start:
            # Empty group: nothing to quantize.
            continue

        # to_mxfp8 per group
        x_slice = x[:, start:end].contiguous()  # (M, K_group)
        w_slice = w[:, start:end].contiguous()  # (N, K_group)
        # to_mxfp8 returns (scale, quantized data), in that order.
        x_scale_slice, xq_slice = to_mxfp8(
            x_slice
        )  # scale shape -> (M, K_group // 32)
        w_scale_slice, wq_slice = to_mxfp8(
            w_slice
        )  # scale shape -> (N, K_group // 32)
        xq_groups.append(xq_slice)
        wq_groups.append(wq_slice)

        # Convert scales to blocked format.
        x_blocked_scale_groups.append(
            _to_blocked(x_scale_slice)
        )  # (round_up(M, 128), round_up(K_group//32, 4))
        w_blocked_scale_groups.append(
            _to_blocked(w_scale_slice)
        )  # (round_up(N, 128), round_up(K_group//32, 4))

    # Assemble the full XQ and WQ
    xq = torch.cat(xq_groups, dim=1).contiguous()
    wq = torch.cat(wq_groups, dim=1).contiguous()

    # Combine all XQ groups blocked scales into one tensor.
    x_blocked_scales = torch.cat(x_blocked_scale_groups, dim=0).reshape(
        round_up(M, 128), -1
    )

    # Combine all WQ groups blocked scales into one tensor.
    w_blocked_scales = torch.cat(w_blocked_scale_groups, dim=0).reshape(
        round_up(N, 128), -1
    )
    return xq, wq, x_blocked_scales, w_blocked_scales, input_group_end_offsets
|
|
3308
|
+
|
|
3309
|
+
def compute(self, xq, wq, x_scale, w_scale, input_group_end_offsets):
    """Invoke the ``mslk.mx8mx8bf16_grouped_mm`` op on quantized operands.

    Args:
        xq: quantized left operand, passed through unchanged.
        wq: quantized right operand; its last two dims are swapped before
            the call.
        x_scale: scales for ``xq``, passed through unchanged.
        w_scale: scales for ``wq``, passed through unchanged.
        input_group_end_offsets: group end offsets, passed through
            unchanged.

    Returns:
        Whatever ``torch.ops.mslk.mx8mx8bf16_grouped_mm`` returns.
    """
    grouped_mm = torch.ops.mslk.mx8mx8bf16_grouped_mm
    wq_t = wq.transpose(-2, -1)
    return grouped_mm(xq, wq_t, x_scale, w_scale, input_group_end_offsets)
|
|
3317
|
+
|
|
3318
|
+
def quantize_and_compute(self, x, w, G=None):
    """Quantize the operands and immediately run the grouped GEMM.

    Args:
        x: concatenated left operand, as returned by ``preprocess``.
        w: concatenated right operand, as returned by ``preprocess``.
        G: number of groups, as returned by ``preprocess``. It is required
            by ``quantize``; the ``None`` default only exists so that the
            two-argument call form fails with a clear message.

    Returns:
        The result of ``compute`` on the outputs of ``quantize``.

    Raises:
        TypeError: if ``G`` is not supplied.
    """
    if G is None:
        # Previously this method called ``self.quantize(x, w)`` and always
        # failed with a TypeError for the missing ``G``; keep the exception
        # type but say what is actually wrong.
        raise TypeError(
            "quantize_and_compute() requires the group count G returned by "
            "preprocess()"
        )
    xq, wq, x_scale, w_scale, input_group_end_offsets = self.quantize(x, w, G)
    return self.compute(
        xq,
        wq,
        x_scale,
        w_scale,
        input_group_end_offsets,
    )
|
|
3327
|
+
|
|
3328
|
+
@property
def name(self) -> str:
    """Return the name string of this op."""
    op_name = "cutlass_mx8mx8bf16_grouped_mm_2d_2d"
    return op_name
|
|
3331
|
+
|
|
3332
|
+
@property
def supported_accelerators(self) -> set[Accelerator]:
    """Return the set of accelerators this op declares support for."""
    accelerators = {Accelerator.NVIDIA_SM100, Accelerator.NVIDIA_SM103}
    return accelerators
|
|
3335
|
+
|
|
3336
|
+
@property
def supported_gemm_types(self) -> set[GemmType]:
    """Return the set of GEMM types this op declares support for."""
    gemm_types = {GemmType.GROUPED}
    return gemm_types
|
|
3339
|
+
|
|
3340
|
+
@property
def compute_dtype(self) -> ComputeDtype:
    """Return the compute dtype this op reports (FP8)."""
    dtype = ComputeDtype.FP8
    return dtype
|