fbgemm-gpu-genai-nightly 2025.12.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of fbgemm-gpu-genai-nightly might be problematic.
- fbgemm_gpu/__init__.py +186 -0
- fbgemm_gpu/asmjit.so +0 -0
- fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
- fbgemm_gpu/config/__init__.py +9 -0
- fbgemm_gpu/config/feature_list.py +88 -0
- fbgemm_gpu/docs/__init__.py +18 -0
- fbgemm_gpu/docs/common.py +9 -0
- fbgemm_gpu/docs/examples.py +73 -0
- fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
- fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
- fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
- fbgemm_gpu/docs/quantize_ops.py +41 -0
- fbgemm_gpu/docs/sparse_ops.py +616 -0
- fbgemm_gpu/docs/target.genai.json.py +6 -0
- fbgemm_gpu/enums.py +24 -0
- fbgemm_gpu/experimental/example/__init__.py +29 -0
- fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
- fbgemm_gpu/experimental/example/utils.py +20 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
- fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
- fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
- fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
- fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
- fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
- fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
- fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
- fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
- fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
- fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
- fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
- fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
- fbgemm_gpu/fbgemm.so +0 -0
- fbgemm_gpu/metrics.py +160 -0
- fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
- fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
- fbgemm_gpu/quantize/__init__.py +43 -0
- fbgemm_gpu/quantize/quantize_ops.py +64 -0
- fbgemm_gpu/quantize_comm.py +315 -0
- fbgemm_gpu/quantize_utils.py +246 -0
- fbgemm_gpu/runtime_monitor.py +237 -0
- fbgemm_gpu/sll/__init__.py +189 -0
- fbgemm_gpu/sll/cpu/__init__.py +80 -0
- fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
- fbgemm_gpu/sll/meta/__init__.py +35 -0
- fbgemm_gpu/sll/meta/meta_sll.py +337 -0
- fbgemm_gpu/sll/triton/__init__.py +127 -0
- fbgemm_gpu/sll/triton/common.py +38 -0
- fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
- fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
- fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
- fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
- fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
- fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
- fbgemm_gpu/sparse_ops.py +1455 -0
- fbgemm_gpu/split_embedding_configs.py +452 -0
- fbgemm_gpu/split_embedding_inference_converter.py +175 -0
- fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
- fbgemm_gpu/split_embedding_utils.py +29 -0
- fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
- fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
- fbgemm_gpu/tbe/__init__.py +6 -0
- fbgemm_gpu/tbe/bench/__init__.py +55 -0
- fbgemm_gpu/tbe/bench/bench_config.py +156 -0
- fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
- fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
- fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
- fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
- fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
- fbgemm_gpu/tbe/bench/reporter.py +35 -0
- fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
- fbgemm_gpu/tbe/bench/utils.py +48 -0
- fbgemm_gpu/tbe/cache/__init__.py +11 -0
- fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
- fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
- fbgemm_gpu/tbe/ssd/__init__.py +15 -0
- fbgemm_gpu/tbe/ssd/common.py +46 -0
- fbgemm_gpu/tbe/ssd/inference.py +586 -0
- fbgemm_gpu/tbe/ssd/training.py +4908 -0
- fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
- fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
- fbgemm_gpu/tbe/stats/__init__.py +10 -0
- fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
- fbgemm_gpu/tbe/utils/__init__.py +13 -0
- fbgemm_gpu/tbe/utils/common.py +42 -0
- fbgemm_gpu/tbe/utils/offsets.py +65 -0
- fbgemm_gpu/tbe/utils/quantize.py +251 -0
- fbgemm_gpu/tbe/utils/requests.py +556 -0
- fbgemm_gpu/tbe_input_multiplexer.py +108 -0
- fbgemm_gpu/triton/__init__.py +22 -0
- fbgemm_gpu/triton/common.py +77 -0
- fbgemm_gpu/triton/jagged/__init__.py +8 -0
- fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
- fbgemm_gpu/triton/quantize.py +647 -0
- fbgemm_gpu/triton/quantize_ref.py +286 -0
- fbgemm_gpu/utils/__init__.py +11 -0
- fbgemm_gpu/utils/filestore.py +211 -0
- fbgemm_gpu/utils/loader.py +36 -0
- fbgemm_gpu/utils/torch_library.py +132 -0
- fbgemm_gpu/uvm.py +40 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
- list_versions/__init__.py +12 -0
- list_versions/cli_run.py +163 -0
fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
@@ -0,0 +1,3483 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Keep a registry of all quantize operators.
import abc

import fbgemm_gpu.experimental.gen_ai  # noqa: F401
import numpy as np

import torch
import triton  # @manual=//triton:triton

from fbgemm_gpu.experimental.gemm.triton_gemm.fp4_quantize import (
    _to_blocked,
    calculate_group_max,
    get_nvfp4_global_scales_naive,
    mega_fp4_pack,
    mega_fp4_quantize_kernel,
    mega_fp4_unpack,
    quantize_nvfp4_naive,
    triton_quantize_mx4_unpack,
    triton_scale_nvfp4_quant,
    triton_scale_nvfp4_quant_rms,
    triton_scale_nvfp4_quant_silu,
)

from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
    get_fp8_constants,
    matmul_fp8_block,
    matmul_fp8_row,
    quantize_fp8_block,
    quantize_fp8_group,
    quantize_fp8_row,
    scale_fp8_row,
    to_mxfp8,
    triton_quantize_fp8_row,
)
from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import (
    grouped_gemm,
    grouped_gemm_fp8_rowwise,
)
from fbgemm_gpu.experimental.gen_ai.quantize import (
    ck_preshuffle,
    quantize_int4_preshuffle,
)

try:
    from gen_ai.llm_inference.fb.llm.kernel.rms_norm import rms_norm
    from gen_ai.llm_inference.fb.llm.kernel.silu_mul import silu_mul
except ImportError:
    # Above is used for some experiments, but the quantize is not relying on them. Okay to just skip.
    pass

try:
    from tinygemm.utils import group_quantize_tensor

    if torch.cuda.is_available() and torch.version.cuda:
        torch.ops.load_library("//tinygemm:tinygemm")
    TINYGEMM_ENABLED = True
except ImportError:
    TINYGEMM_ENABLED = False

# Marlin currently only is supported only internally at Meta.
try:
    from marlin.quantize import marlin_quantize

    torch.ops.load_library("//ai_codesign/gen_ai/marlin:marlin_ops")
    MARLIN_ENABLED = True
except ImportError:
    MARLIN_ENABLED = False

try:
    from deep_gemm import (
        gemm_fp8_fp8_bf16_nt,
        get_col_major_tma_aligned_tensor,
        m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
        m_grouped_gemm_fp8_fp8_bf16_nt_masked,
    )

    DEEPGEMM_ENABLED = True
except ImportError:
    DEEPGEMM_ENABLED = False


# Machete is also only supported internally at Meta for now.
try:
    from machete.machete import machete_gemm
    from machete.quantize import machete_quantize_and_pack

    MACHETE_ENABLED = True
except ImportError:
    MACHETE_ENABLED = False


quantize_op_registry = []


def round_up(x: int, y: int) -> int:
    return ((x + y - 1) // y) * y


class QuantizeOpBase(metaclass=abc.ABCMeta):
    """Helper abstract class to define expected methods of quantize ops."""

    @abc.abstractmethod
    def quantize(self, *args):
        """Function which quantizes inputs."""
        pass

    @abc.abstractmethod
    def compute(self, *args, **kwargs):
        """Function which performs main compute operation."""
        pass

    @abc.abstractmethod
    def quantize_and_compute(self, *args, **kwargs):
        """Function which quantizes inputs and performs main compute operation."""
        pass

    def preprocess(self, *args):
        """Preprocess inputs before benchmarking. These outputs will be passed to quantize."""
        return args

    def bench_with_rotating_buffer(self, fn, args, use_cuda_graph: bool = True):
        import copy
        import pickle

        # torch.cuda.get_device_properties does not have L2/L3 cache size,
        # so hard code an overapproximation of L2/L3 cache size to ensure L2/L3 cache flush
        total_buffer_size = 512 * 1024 * 1024

        # Use pickle to serialize model input to estimate total sizes of input
        input_sizes = len(pickle.dumps(args))

        # Make at least one copy of the inputs
        copy_cnt = total_buffer_size // input_sizes
        if copy_cnt == 0:
            copy_cnt = 1

        args_list = [args]
        for _ in range(copy_cnt):
            args_list.append(copy.deepcopy(args))

        torch.cuda.synchronize()

        def rotating_buffer_fn(fn, args_list, copy_cnt):
            for i in range(copy_cnt):
                fn(*(args_list[i]))

        if use_cuda_graph:
            with torch.cuda.stream(torch.cuda.Stream()):
                # A rotating_buffer_fn contains multiple runs of the fn to benchmark,
                # so divide time accordingly
                return triton.testing.do_bench_cudagraph(
                    lambda: rotating_buffer_fn(self.compute, args_list, copy_cnt + 1),
                    rep=200,
                ) / (copy_cnt + 1)
        else:
            return triton.testing.do_bench(
                lambda: rotating_buffer_fn(self.compute, args_list, copy_cnt + 1),
                rep=200,
            ) / (copy_cnt + 1)

    def benchmark(
        self,
        *args,
        bench_quantize: bool = False,
        use_rotating_buffer_bench: bool = False,
        use_cuda_graph: bool = True,
        **kwargs,
    ) -> float:
        """Benchmark runtime of this operator."""
        if bench_quantize:
            if use_cuda_graph:
                with torch.cuda.stream(torch.cuda.Stream()):
                    t = triton.testing.do_bench_cudagraph(
                        lambda: self.quantize_and_compute(*args, **kwargs), rep=200
                    )
            else:
                t = triton.testing.do_bench(
                    lambda: self.quantize_and_compute(*args, **kwargs)
                )
        else:
            if use_rotating_buffer_bench:
                t = self.bench_with_rotating_buffer(self.compute, args, use_cuda_graph)
            else:
                if use_cuda_graph:
                    with torch.cuda.stream(torch.cuda.Stream()):
                        t = triton.testing.do_bench_cudagraph(
                            lambda: self.compute(*args, **kwargs), rep=200
                        )
                else:
                    t = triton.testing.do_bench(lambda: self.compute(*args, **kwargs))
        return t

    @abc.abstractproperty
    def name(self) -> str:
        """Name of the operator."""
        pass

    @abc.abstractproperty
    def hip(self) -> bool:
        """Whether this operator supports AMD or not."""
        pass

    @abc.abstractproperty
    def cuda(self) -> bool:
        """Whether this operator supports Nvidia or not."""
        pass

    @property
    def supported(self) -> bool:
        """Whether this op will run on the current device."""
        if torch.version.hip is not None:
            return self.hip
        elif torch.version.cuda is not None:
            return self.cuda
        else:
            return False


def register_quantize_op(op):
    """Decorator function for assembling all quantize ops."""
    quantize_op_registry.append(op())
    return op


def get_quantize_ops() -> list[QuantizeOpBase]:
    """Get all registered quantize ops."""
    return quantize_op_registry


@register_quantize_op
class BF16Baseline(QuantizeOpBase):
    """
    Baseline BF16 matmul.
    """

    def quantize(self, x, w):
        if isinstance(x, list):
            x = [i.bfloat16() for i in x]
            w = [torch.transpose(i, -2, -1).bfloat16() for i in w]
        else:
            x = x.bfloat16()
            w = torch.transpose(w, -2, -1).bfloat16()
        return x, w

    def compute(self, x, w):
        # Handle both grouped and standard gemm.
        if isinstance(x, list):
            output = []
            for i in range(len(x)):
                output.append(torch.matmul(x[i], w[i]))
            return output
        return torch.matmul(x, w)

    def quantize_and_compute(self, x, w):
        return self.compute(*self.quantize(x, w))

    @property
    def name(self) -> str:
        return "bf16_baseline"

    @property
    def hip(self) -> bool:
        return True

    @property
    def cuda(self) -> bool:
        return True


@register_quantize_op
class ScaledMMBaseline(QuantizeOpBase):
    """
    Reference FP8 matmul implemented in native torch with cublas or hipblas.
    """

    def __init__(self):
        self.fp8_dtype, _, _, _ = get_fp8_constants()
        self.E4M3_MAX_POS: float = torch.finfo(self.fp8_dtype).max
        self.E5M2_MAX_POS: float = torch.finfo(torch.float8_e5m2).max
        self.FP16_MAX_POS: float = torch.finfo(torch.float16).max
        self.EPS: float = 1e-12
        self.fast_accum = True

    def _amax_to_scale(
        self, amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
    ) -> torch.Tensor:
        # To make scale dtype to be fp32 for accuracy
        amax = amax.float()
        if float8_dtype == self.fp8_dtype:
            # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
            res = self.E4M3_MAX_POS / torch.clamp(amax, min=self.EPS)
        else:  # e5m2
            # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
            res = self.E5M2_MAX_POS / torch.clamp(amax, min=self.EPS)

        # pyre-fixme[7]: Expected `Tensor` but got `Union[float, Tensor]`.
        return res

    def _to_fp8_saturated(
        self, x: torch.Tensor, float8_dtype: torch.dtype
    ) -> torch.Tensor:
        if float8_dtype == torch.float8_e4m3fn:
            x = x.clamp(min=-1 * self.E4M3_MAX_POS, max=self.E4M3_MAX_POS)
        else:
            x = x.clamp(min=-1 * self.E5M2_MAX_POS, max=self.E5M2_MAX_POS)
        return x.to(float8_dtype)

    def _quantize_tensor(self, x):
        x_amax = torch.max(torch.abs(x))
        scale = self._amax_to_scale(x_amax, self.fp8_dtype, x.dtype)
        scaled_x = self._to_fp8_saturated(x * scale, self.fp8_dtype)
        x_inverse_scale = scale.reciprocal()
        return scaled_x, x_inverse_scale

    def quantize(self, x, w):
        xq, x_scale = self._quantize_tensor(x)
        wq, w_scale = self._quantize_tensor(w.t())
        return xq, wq, x_scale, w_scale

    def compute(self, xq, wq, x_scale, w_scale):
        output = torch._scaled_mm(
            xq,
            wq,
            bias=None,
            out_dtype=torch.bfloat16,
            scale_a=x_scale,
            scale_b=w_scale,
            scale_result=None,
            use_fast_accum=self.fast_accum,
        )
        return output

    def quantize_and_compute(self, x, w):
        return self.compute(*self.quantize(x, w))

    @property
    def name(self) -> str:
        return "scaled_mm"

    @property
    def hip(self) -> bool:
        return True

    @property
    def cuda(self) -> bool:
        return True


@register_quantize_op
class ScaledMMRowwise(QuantizeOpBase):
    def __init__(self):
        self.fast_accum = True
        self.torch_compile = False

    def quantize(self, x, w):
        xq, x_scale = quantize_fp8_row(x)
        wq, w_scale = quantize_fp8_row(w)
        return xq, wq.t(), x_scale.unsqueeze(1), w_scale.unsqueeze(0)

    def compute(self, xq, wq, x_scale, w_scale):
        if self.torch_compile:
            f = torch.compile(
                torch._scaled_mm,
                options={
                    "max_autotune": True,
                    "max_autotune_gemm_backends": "TRITON,CK,CUTLASS,ATEN",
                },
            )
        else:
            f = torch._scaled_mm

        return f(
            xq,
            wq,
            bias=None,
            out_dtype=torch.bfloat16,
            scale_a=x_scale,
            scale_b=w_scale,
            scale_result=None,
            use_fast_accum=self.fast_accum,
        )

    def quantize_and_compute(self, x, w):
        return self.compute(*self.quantize(x, w))

    @property
    def name(self) -> str:
        return "scaled_mm_rowwise"

    @property
    def hip(self) -> bool:
        return True

    @property
    def cuda(self) -> bool:
        return True


@register_quantize_op
class FP8TensorwiseGemm(QuantizeOpBase):
    """
    FP8 matmul with tensorwise scaling.
    """

    def quantize(self, x, w):
        # Quantize both input tensors.
        xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
        wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
        return xq, wq, x_scale, w_scale

    def compute(self, xq, wq, x_scale, w_scale):
        return torch.ops.fbgemm.f8f8bf16(xq, wq, x_scale * w_scale)

    def quantize_and_compute(self, x, w):
        xq, wq, x_scale, w_scale = self.quantize(x, w)
        return self.compute(xq, wq, x_scale, w_scale)

    @property
    def name(self) -> str:
        return "cutlass_tensorwise"

    @property
    def hip(self) -> bool:
        # Need to add support for better quantize kernel.
        # Also may have an issue with cuda graphs.
        return False

    @property
    def cuda(self) -> bool:
        return True


@register_quantize_op
class BF16OSSFastGemv(QuantizeOpBase):
    """
    BF16 OSS fast gemv kernel.
    """

    def quantize(self, x, w):
        # dummy quantize
        return x, w

    def compute(self, x, w):
        out = torch.ops.fbgemm.bf16_fast_gemv(x, w)
        return out

    def quantize_and_compute(self, x, w):
        x, w = self.quantize(x, w)
        return self.compute(x, w)

    @property
    def name(self) -> str:
        return "bf16_oss_fast_gemv"

    @property
    def hip(self) -> bool:
        # This implementation is specific to cublas.
        return False

    @property
    def cuda(self) -> bool:
        return True


@register_quantize_op
class BF16Fp8OSSFastGemv(QuantizeOpBase):
    """
    BF16FP8 OSS fast gemv kernel.
    """

    def quantize(self, x, w):
        wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
        return x, wq, w_scale

    def compute(self, x, wq, w_scale):
        out = torch.ops.fbgemm.bf16fp8bf16_fast_gemv(x, wq, w_scale)
        return out

    def quantize_and_compute(self, x, w):
        x, wq, w_scale = self.quantize(x, w)
        return self.compute(x, wq, w_scale)

    @property
    def name(self) -> str:
        return "bf16fp8_oss_fast_gemv"

    @property
    def hip(self) -> bool:
        # This implementation is specific to cublas.
        return False

    @property
    def cuda(self) -> bool:
        return True


@register_quantize_op
class Fp8Fp8OSSFastGemv(QuantizeOpBase):
    """
    FP8FP8 OSS fast gemv kernel.
    """

    def quantize(self, x, w):
        # rowwise quantize
        xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x)
        wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
        return xq, wq, x_scale, w_scale

    def compute(self, xq, wq, x_scale, w_scale):
        out = torch.ops.fbgemm.fp8fp8bf16_fast_gemv(
            xq, wq, x_scale, w_scale, is_batched=False
        )
        return out

    def quantize_and_compute(self, x, w):
        xq, wq, x_scale, w_scale = self.quantize(x, w)
        return self.compute(xq, wq, x_scale, w_scale)

    @property
    def name(self) -> str:
        return "fp8fp8_oss_fast_gemv"

    @property
    def hip(self) -> bool:
        # This implementation is specific to cublas.
        return False

    @property
    def cuda(self) -> bool:
        return True


@register_quantize_op
class Fp8OSSFastGemvBatched(QuantizeOpBase):
    """
    Batched fp8 fast gemv kernel
    """

    def quantize(self, x, w):
        # rowwise quantize
        xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x)
        wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
        return xq, wq, x_scale, w_scale

    def compute(self, xq, wq, x_scale, w_scale):
        out = torch.ops.fbgemm.fp8fp8bf16_fast_gemv(
            xq, wq, x_scale, w_scale, is_batched=True
        )
        return out

    def quantize_and_compute(self, x, w):
        xq, wq, x_scale, w_scale = self.quantize(x, w)
        return self.compute(xq, wq, x_scale, w_scale)

    @property
    def name(self) -> str:
        return "fp8fp8_oss_fast_gemv_batched"

    @property
    def hip(self) -> bool:
        # This implementation is specific to cublas.
        return False

    @property
    def cuda(self) -> bool:
        return True


@register_quantize_op
class FP8CublasRowwiseGemm(QuantizeOpBase):
    """
    FP8 cublas matmul with rowwise scaling.
    """

    def quantize(self, x, w):
        # Quantize both input tensors.
        xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x)
        wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
        return xq, wq, x_scale, w_scale

    def compute(self, xq, wq, x_scale, w_scale):
        out = torch.ops.fbgemm.f8f8bf16_cublas(xq, wq)
        scaled_out = scale_fp8_row(out, x_scale, w_scale)
        return scaled_out

    def quantize_and_compute(self, x, w):
        xq, wq, x_scale, w_scale = self.quantize(x, w)
        return self.compute(xq, wq, x_scale, w_scale)

    @property
    def name(self) -> str:
        return "cublas_rowwise"

    @property
    def hip(self) -> bool:
        # This implementation is specific to cublas.
        return False

    @property
    def cuda(self) -> bool:
        return True


@register_quantize_op
class FP8CublasTensorwiseGemm(QuantizeOpBase):
    """
    FP8 cublas matmul with tensorwise scaling.
    """

    def quantize(self, x, w):
        # Quantize both input tensors.
        xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
        wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
        return xq, wq, x_scale, w_scale

    def compute(self, xq, wq, x_scale, w_scale):
        return torch.ops.fbgemm.f8f8bf16_cublas(xq, wq, x_scale * w_scale)

    def quantize_and_compute(self, x, w):
        xq, wq, x_scale, w_scale = self.quantize(x, w)
        return self.compute(xq, wq, x_scale * w_scale)

    @property
    def name(self) -> str:
        return "cublas_tensorwise"

    @property
    def hip(self) -> bool:
        # This implementation is specific to cublas.
        return False

    @property
    def cuda(self) -> bool:
        return True


@register_quantize_op
class FP8RowwiseGemm(QuantizeOpBase):
    """
    FP8 matmul with rowwise scaling.
    """

    def __init__(self):
        self.fast_accum = True
        self.gemm_op = torch.ops.fbgemm.f8f8bf16_rowwise
        self.quantize_op = quantize_fp8_row

    def preprocess(self, x, w):
        # Prequantize weights.
        if isinstance(w, (list, tuple)):
            wq, w_scale = zip(*[self.quantize_op(i) for i in w])
        else:
            wq, w_scale = self.quantize_op(w)
            if wq.dim() == 3:
                w_scale = w_scale.view(wq.size(0), -1)
        return x, wq, w_scale

    def quantize(self, x, wq, w_scale):
        # Quantize both input tensors.
        # Handle both grouped and standard gemm.
        if isinstance(x, (list, tuple)):
            xq, x_scale = zip(*[self.quantize_op(i) for i in x])
        else:
            xq, x_scale = self.quantize_op(x)
            # Set proper batch dimension shapes.
            if xq.dim() == 3:
                x_scale = x_scale.view(xq.size(0), -1)
        return xq, wq, x_scale, w_scale

    def compute(self, xq, wq, x_scale, w_scale):
        # Handle group gemm if inputs are grouped.
        if isinstance(xq, (list, tuple)):
            output = []
            for i in range(len(xq)):
                output.append(
                    self.gemm_op(
                        xq[i],
                        wq[i],
                        x_scale[i],
                        w_scale[i],
                        use_fast_accum=self.fast_accum,
                    )
                )
            return output
        # Unroll batched gemm if needed.
        elif xq.dim() == 3 and wq.dim() == 3:
            B, M, _ = xq.shape
            _, N, _ = wq.shape
            y = torch.empty((B, M, N), device=xq.device, dtype=torch.bfloat16)
            for i in range(B):
                y[i] = self.gemm_op(
                    xq[i], wq[i], x_scale[i], w_scale[i], use_fast_accum=self.fast_accum
                )
            return y
        # Otherwise return normal gemm result.
        return self.gemm_op(xq, wq, x_scale, w_scale, use_fast_accum=self.fast_accum)

    def quantize_and_compute(self, x, wq, w_scale):
        xq, wq, x_scale, w_scale = self.quantize(x, wq, w_scale)
        return self.compute(xq, wq, x_scale, w_scale)

    @property
    def name(self) -> str:
        if torch.version.cuda:
            return "cutlass_rowwise"
        else:
            return "ck_rowwise"

    @property
    def hip(self) -> bool:
        return True

    @property
    def cuda(self) -> bool:
        return True


@register_quantize_op
class FP8RowwisePreshuffleGemm(FP8RowwiseGemm):
    """
    FP8 matmul with rowwise scaling and preshuffling of input B.
    """

    def __init__(self):
        self.fast_accum = True
        if self.supported:
            self.gemm_op = torch.ops.fbgemm.f8f8bf16_rowwise_preshuffle

    def preprocess(self, x, w):
        x, wq, w_scale = super().preprocess(x, w)
        return x, ck_preshuffle(wq, 16), w_scale

    @property
    def name(self) -> str:
        if torch.version.cuda:
            return "cutlass_rowwise_preshuffle"
        else:
            return "ck_rowwise_preshuffle"

    @property
    def hip(self) -> bool:
        return True

    @property
    def cuda(self) -> bool:
        # Not yet supported on nvidia.
        return False


@register_quantize_op
class FP8RowwiseGroupedGemm(QuantizeOpBase):
    """
    FP8 grouped matmul with rowwise scaling.
    """

    def preprocess(self, x, w):
        # Apply sparsity to inputs if appropriate.
        # First check if N and K are fixed.
        m_values = [i.shape[0] for i in x]
        n_values = [i.shape[0] for i in w]
        k_values = [i.shape[1] for i in w]
        # If so, do specialized version of initialization.
        if len(np.unique(n_values)) == 1 and len(np.unique(k_values)) == 1:
            m_values = [i.shape[0] for i in x]
            # Inputs for fixed nk mode must be contiguous, however in the benchmark
            # script they typically are not. Do a little special processing to make them
            # work. In practice this wont be needed.
            # Start by padding along m dimension with zeros.
            max_m = max(m_values)
            x = [
                torch.nn.functional.pad(i, (0, 0, 0, max_m - i.shape[0]), value=0)
                for i in x
            ]
            # Stack inputs into groups.
            x = torch.stack(x).contiguous()
            w = torch.stack(w).contiguous()

            # Preapply weight quantization.
            wq, w_scale = quantize_fp8_row(w)
            # Return processed tensors.
            return (
                x,
                wq,
                w_scale,
                torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device),
            )
        # Otherwise run without sparsity.
        wq, w_scale = zip(*[quantize_fp8_row(i) for i in w])
        return x, wq, w_scale, None

    def quantize(self, x, wq, w_scale, m_values=None):
        # Handle case where inputs are explicitly grouped and non-sparse.
        if isinstance(x, (tuple, list)):
            xq, x_scale = zip(*[triton_quantize_fp8_row(i) for i in x])
            return xq, wq, x_scale, w_scale, m_values
        # Otherwise inputs are unified tensors and sparse.
        else:
            B = x.shape[0]
            xq, x_scale = triton_quantize_fp8_row(x, zero_start_index_M=m_values)
            x_scale = x_scale.view(B, -1)
            return xq, wq, x_scale, w_scale, m_values

    def compute(self, xq, wq, x_scale, w_scale, m_values):
        if m_values is None:
            return torch.ops.fbgemm.f8f8bf16_rowwise_grouped(
                xq,
                wq,
                x_scale,
                w_scale,
            )
        else:
            # Break tensor into groups, simulates what is done e2e.
            return torch.ops.fbgemm.f8f8bf16_rowwise_grouped_dynamic(
                xq,
                wq,
                x_scale,
                w_scale,
                zero_start_index_M=m_values,
            )

    def quantize_and_compute(self, x, wq, w_scale, m_values=None):
        xq, wq, x_scale, w_scale, m_values = self.quantize(x, wq, w_scale, m_values)
        return self.compute(xq, wq, x_scale, w_scale, m_values)

    @property
    def name(self) -> str:
        if torch.version.cuda:
            return "cutlass_rowwise_grouped"
        else:
            return "ck_rowwise_grouped"

    @property
    def hip(self) -> bool:
        return True

    @property
    def cuda(self) -> bool:
        return True


@register_quantize_op
class BF16TritonStackedGroupedGemm(QuantizeOpBase):
    """
    BF16 grouped matmul with stacked inputs implemented with triton.
    """

    def preprocess(self, x, w):
        m_values = [i.shape[0] for i in x]
        # Convert m_values into offsets into grouped tensor.
        m_sizes = torch.tensor(m_values).to(dtype=torch.int32, device=x[0].device)
        w = torch.concat(w, dim=0).contiguous()
        # Also view input as flattened.
        x = torch.concat(x, dim=0).contiguous()
        # Return processed tensors.
        return x, w, m_sizes

    def quantize(self, x, w, m_sizes):
        return x, w, m_sizes

    def compute(self, x, w, m_sizes):
        return grouped_gemm(x, w, m_sizes, _use_warp_specialization=True)

    def quantize_and_compute(self, x, w, m_sizes):
        x, w, m_sizes = self.quantize(x, w, m_sizes)
        return self.compute(x, w, m_sizes)

    @property
    def name(self) -> str:
        return "triton_bf16_grouped_stacked"

    @property
    def hip(self) -> bool:
        return True

    @property
    def cuda(self) -> bool:
        return True


@register_quantize_op
class BF16TritonStackedGroupedGemmFuseScatterAdd(BF16TritonStackedGroupedGemm):
    """
    BF16 grouped matmul with stacked inputs implemented with triton. Fused with ScatterAdd.
    """

    def preprocess(self, x, w):
        x, w, m_sizes = super().preprocess(x, w)
        M = x.shape[0]
        N = w.shape[0] // m_sizes.shape[0]
        output = torch.zeros(M, N, dtype=torch.bfloat16, device=x.device)
        indices = torch.randperm(M, dtype=torch.int32, device=x.device)
        return x, w, m_sizes, output, indices

    def quantize(self, x, w, m_sizes, *args):
        return *super().quantize(x, w, m_sizes), *args

    def compute(self, x, w, m_sizes, output, indices):
        return grouped_gemm(
            x,
            w,
            m_sizes,
            _use_warp_specialization=True,
            _output_tensor=output,
            _scatter_add_indices=indices,
        )

    def quantize_and_compute(self, x, w, m_sizes, *args):
        x, w, m_sizes, *ret = self.quantize(x, w, m_sizes, *args)
        return self.compute(x, w, m_sizes, *ret)

    @property
    def name(self) -> str:
        return "triton_bf16_grouped_stacked_fuse_scatter_add"


@register_quantize_op
class FP8TritonStackedGroupedGemm(QuantizeOpBase):
    """
    FP8 grouped matmul with rowwise scaling and stacked inputs implemented with triton.
    """

    def preprocess(self, x, w):
        m_values = [i.shape[0] for i in x]
        # Convert m_values into offsets into grouped tensor.
        m_sizes = torch.tensor(m_values).to(dtype=torch.int32, device=x[0].device)
        # Quantize weights.
        wq, w_scale = zip(*[quantize_fp8_row(i) for i in w])
        # Group weights as single tensor.
        wq = torch.concat(wq, dim=0).contiguous()
        w_scale = torch.concat(w_scale, dim=0).contiguous()
        # Also view input as flattened.
        x = torch.concat(x, dim=0).contiguous()
        # Return processed tensors.
        return x, wq, w_scale, m_sizes

    def quantize(self, x, wq, w_scale, m_sizes):
        B = x.shape[0]
        xq, x_scale = triton_quantize_fp8_row(x)
        x_scale = x_scale.view(B, -1)
        return xq, wq, x_scale, w_scale, m_sizes

    def compute(self, xq, wq, x_scale, w_scale, m_sizes):
        return grouped_gemm_fp8_rowwise(
            xq, wq, m_sizes, x_scale, w_scale, _use_warp_specialization=True
        )

    def quantize_and_compute(self, x, wq, w_scale, m_sizes):
        xq, wq, x_scale, w_scale, m_sizes = self.quantize(x, wq, w_scale, m_sizes)
        return self.compute(xq, wq, x_scale, w_scale, m_sizes)

    @property
    def name(self) -> str:
        return "triton_grouped_stacked"

    @property
    def hip(self) -> bool:
        return True

    @property
    def cuda(self) -> bool:
        return True


@register_quantize_op
class FP8TritonStackedGroupedGemmFuseScatterAdd(FP8TritonStackedGroupedGemm):
    """
    FP8 grouped matmul with stacked inputs implemented with triton. Fused with ScatterAdd.
    """

    def preprocess(self, x, w):
        x, wq, w_scale, m_sizes = super().preprocess(x, w)
        M = x.shape[0]
        N = wq.shape[0] // m_sizes.shape[0]
        output = torch.zeros(M, N, dtype=torch.bfloat16, device=x.device)
        indices = torch.randperm(M, dtype=torch.int32, device=x.device)
        return x, wq, w_scale, m_sizes, output, indices

    def quantize(self, x, wq, w_scale, m_sizes, *args):
        return *super().quantize(x, wq, w_scale, m_sizes), *args

    def compute(self, xq, wq, x_scale, w_scale, m_sizes, output, indices):
        return grouped_gemm_fp8_rowwise(
            xq,
            wq,
            m_sizes,
            x_scale,
            w_scale,
            _use_warp_specialization=True,
            _output_tensor=output,
            _scatter_add_indices=indices,
        )

    def quantize_and_compute(self, x, wq, w_scale, m_sizes, *args):
        xq, wq, x_scale, w_scale, m_sizes, *ret = self.quantize(
            x, wq, w_scale, m_sizes, *args
        )
        return self.compute(xq, wq, x_scale, w_scale, m_sizes, *ret)

    @property
    def name(self) -> str:
        return "triton_grouped_stacked_fuse_scatter_add"


@register_quantize_op
|
|
1011
|
+
class DeepGemmStacked(QuantizeOpBase):
|
|
1012
|
+
"""
|
|
1013
|
+
FP8 grouped matmul with blockwise scaling implemented with DeepGemm.
|
|
1014
|
+
"""
|
|
1015
|
+
|
|
1016
|
+
def preprocess(self, x, w):
|
|
1017
|
+
m_values = [i.shape[0] for i in x]
|
|
1018
|
+
# Convert m_values into offsets into grouped tensor.
|
|
1019
|
+
indices = torch.arange(len(m_values))
|
|
1020
|
+
m_indices = indices.repeat_interleave(torch.tensor(m_values)).to(
|
|
1021
|
+
device=x[0].device, dtype=torch.int
|
|
1022
|
+
)
|
|
1023
|
+
# Quantize weights.
|
|
1024
|
+
wq, w_scale = zip(*[quantize_fp8_block(i, block_k=128, block_m=128) for i in w])
|
|
1025
|
+
# Group weights as single tensor.
|
|
1026
|
+
wq = torch.stack(wq, dim=0).contiguous()
|
|
1027
|
+
w_scale = torch.stack(w_scale, dim=0).contiguous()
|
|
1028
|
+
# Also view input as flattened.
|
|
1029
|
+
x = torch.concat(x, dim=0).contiguous()
|
|
1030
|
+
# Return processed tensors.
|
|
1031
|
+
return x, wq, w_scale, m_indices
|
|
1032
|
+
|
|
1033
|
+
def quantize(self, x, wq, w_scale, m_indices):
|
|
1034
|
+
xq, x_scale = quantize_fp8_block(x, block_m=1, block_k=128)
|
|
1035
|
+
# Pretranspose scales to deepgemm format.
|
|
1036
|
+
x_scale = get_col_major_tma_aligned_tensor(x_scale)
|
|
1037
|
+
return xq, wq, x_scale, w_scale, m_indices
|
|
1038
|
+
|
|
1039
|
+
def compute(self, xq, wq, x_scale, w_scale, m_indices):
|
|
1040
|
+
# Preallocate output.
|
|
1041
|
+
out = torch.empty(
|
|
1042
|
+
[xq.shape[0], wq.shape[1]], device=xq.device, dtype=torch.bfloat16
|
|
1043
|
+
)
|
|
1044
|
+
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
|
|
1045
|
+
(xq, x_scale), (wq, w_scale), out, m_indices
|
|
1046
|
+
)
|
|
1047
|
+
return out
|
|
1048
|
+
|
|
1049
|
+
def quantize_and_compute(self, x, wq, w_scale, m_indices):
|
|
1050
|
+
xq, wq, x_scale, w_scale, m_indices = self.quantize(x, wq, w_scale, m_indices)
|
|
1051
|
+
return self.compute(xq, wq, x_scale, w_scale, m_indices)
|
|
1052
|
+
|
|
1053
|
+
@property
|
|
1054
|
+
def name(self) -> str:
|
|
1055
|
+
return "deepgemm_stacked"
|
|
1056
|
+
|
|
1057
|
+
@property
|
|
1058
|
+
def hip(self) -> bool:
|
|
1059
|
+
return False
|
|
1060
|
+
|
|
1061
|
+
@property
|
|
1062
|
+
def cuda(self) -> bool:
|
|
1063
|
+
return DEEPGEMM_ENABLED
|
|
1064
|
+
|
|
1065
|
+
|
|
1066
|
+
@register_quantize_op
|
|
1067
|
+
class DeepGemmMaskedStacked(DeepGemmStacked):
|
|
1068
|
+
def preprocess(self, x, w):
|
|
1069
|
+
# Quantize weights.
|
|
1070
|
+
wq, w_scale = zip(*[quantize_fp8_block(i, block_k=128, block_m=128) for i in w])
|
|
1071
|
+
# Group weights as single tensor.
|
|
1072
|
+
wq = torch.stack(wq, dim=0).contiguous()
|
|
1073
|
+
w_scale = torch.stack(w_scale, dim=0).contiguous()
|
|
1074
|
+
|
|
1075
|
+
# Also view input as flattened.
|
|
1076
|
+
m_values = [i.shape[0] for i in x]
|
|
1077
|
+
expected_m = max(m_values)
|
|
1078
|
+
padded_m_max = ((max(m_values) + 127) // 128) * 128
|
|
1079
|
+
masked_m = torch.tensor(m_values).to(dtype=torch.int32, device=x[0].device)
|
|
1080
|
+
|
|
1081
|
+
num_groups = len(m_values)
|
|
1082
|
+
k = x[0].shape[1]
|
|
1083
|
+
x_padded = torch.zeros(
|
|
1084
|
+
[num_groups, padded_m_max, k], device=x[0].device, dtype=x[0].dtype
|
|
1085
|
+
)
|
|
1086
|
+
for g in range(num_groups):
|
|
1087
|
+
x_padded[g, : m_values[g], :] = x[g]
|
|
1088
|
+
|
|
1089
|
+
# Return processed tensors.
|
|
1090
|
+
return x_padded, wq, w_scale, masked_m, expected_m, m_values
|
|
1091
|
+
|
|
1092
|
+
def quantize(self, x, wq, w_scale, masked_m, expected_m, m_values):
|
|
1093
|
+
g, m_max, k = x.shape
|
|
1094
|
+
xq, x_scale = quantize_fp8_block(x.view(-1, k), block_m=1, block_k=128)
|
|
1095
|
+
# Pretranspose scales to deepgemm format.
|
|
1096
|
+
x_scale = get_col_major_tma_aligned_tensor(x_scale)
|
|
1097
|
+
return (
|
|
1098
|
+
xq.view(g, m_max, -1),
|
|
1099
|
+
wq,
|
|
1100
|
+
x_scale.view(g, m_max, -1),
|
|
1101
|
+
w_scale,
|
|
1102
|
+
masked_m,
|
|
1103
|
+
expected_m,
|
|
1104
|
+
m_values,
|
|
1105
|
+
)
|
|
1106
|
+
|
|
1107
|
+
def compute(self, xq, wq, x_scale, w_scale, masked_m, expected_m, m_values):
|
|
1108
|
+
# Preallocate output.
|
|
1109
|
+
out = torch.empty(
|
|
1110
|
+
[xq.shape[0], xq.shape[1], wq.shape[1]],
|
|
1111
|
+
device=xq.device,
|
|
1112
|
+
dtype=torch.bfloat16,
|
|
1113
|
+
)
|
|
1114
|
+
m_grouped_gemm_fp8_fp8_bf16_nt_masked(
|
|
1115
|
+
(xq, x_scale), (wq, w_scale), out, masked_m, expected_m
|
|
1116
|
+
)
|
|
1117
|
+
num_groups = xq.shape[0]
|
|
1118
|
+
out_list = [out[g, : m_values[g], :] for g in range(num_groups)]
|
|
1119
|
+
return out_list
|
|
1120
|
+
|
|
1121
|
+
def quantize_and_compute(self, x, wq, w_scale, masked_m, expected_m, m_values):
|
|
1122
|
+
xq, wq, x_scale, w_scale, masked_m, expected_m = self.quantize(
|
|
1123
|
+
x, wq, w_scale, masked_m, expected_m, m_values
|
|
1124
|
+
)
|
|
1125
|
+
return self.compute(xq, wq, x_scale, w_scale, masked_m, expected_m, m_values)
|
|
1126
|
+
|
|
1127
|
+
@property
|
|
1128
|
+
def name(self) -> str:
|
|
1129
|
+
return "deepgemm_masked_stacked"
|
|
1130
|
+
|
|
1131
|
+
|
|
1132
|
+
@register_quantize_op
|
|
1133
|
+
class DeepGemmBlockwise(QuantizeOpBase):
|
|
1134
|
+
"""
|
|
1135
|
+
FP8 matmul with blockwise scaling implemented with DeepGemm.
|
|
1136
|
+
"""
|
|
1137
|
+
|
|
1138
|
+
def preprocess(self, x, w):
|
|
1139
|
+
# Quantize weights.
|
|
1140
|
+
wq, w_scale = quantize_fp8_block(w, block_m=128, block_k=128)
|
|
1141
|
+
# allocate output.
|
|
1142
|
+
out = torch.empty(
|
|
1143
|
+
x.shape[0], wq.shape[0], device=x.device, dtype=torch.bfloat16
|
|
1144
|
+
)
|
|
1145
|
+
# Return processed tensors.
|
|
1146
|
+
return x, wq, w_scale, out
|
|
1147
|
+
|
|
1148
|
+
def quantize(self, x, wq, w_scale, out):
|
|
1149
|
+
xq, x_scale = quantize_fp8_group(x, group_size=128)
|
|
1150
|
+
return xq, wq, x_scale, w_scale, out
|
|
1151
|
+
|
|
1152
|
+
def compute(self, xq, wq, x_scale, w_scale, out):
|
|
1153
|
+
gemm_fp8_fp8_bf16_nt((xq, x_scale), (wq, w_scale), out)
|
|
1154
|
+
return out
|
|
1155
|
+
|
|
1156
|
+
def quantize_and_compute(self, x, wq, w_scale, out):
|
|
1157
|
+
xq, wq, x_scale, w_scale, out = self.quantize(x, wq, w_scale, out)
|
|
1158
|
+
return self.compute(xq, wq, x_scale, w_scale, out)
|
|
1159
|
+
|
|
1160
|
+
@property
|
|
1161
|
+
def name(self) -> str:
|
|
1162
|
+
return "deepgemm_blockwise"
|
|
1163
|
+
|
|
1164
|
+
@property
|
|
1165
|
+
def hip(self) -> bool:
|
|
1166
|
+
return False
|
|
1167
|
+
|
|
1168
|
+
@property
|
|
1169
|
+
def cuda(self) -> bool:
|
|
1170
|
+
return True
|
|
1171
|
+
|
|
1172
|
+
|
|
1173
|
+
@register_quantize_op
|
|
1174
|
+
class DeepGemmRowwise(QuantizeOpBase):
|
|
1175
|
+
"""
|
|
1176
|
+
FP8 matmul with rowwise scaling implemented with DeepGemm.
|
|
1177
|
+
"""
|
|
1178
|
+
|
|
1179
|
+
def preprocess(self, x, w):
|
|
1180
|
+
# Quantize weights.
|
|
1181
|
+
wq, w_scale = quantize_fp8_row(w)
|
|
1182
|
+
# allocate output.
|
|
1183
|
+
out = torch.empty(
|
|
1184
|
+
x.shape[0], wq.shape[0], device=x.device, dtype=torch.bfloat16
|
|
1185
|
+
)
|
|
1186
|
+
# Return processed tensors.
|
|
1187
|
+
return x, wq, w_scale, out
|
|
1188
|
+
|
|
1189
|
+
def quantize(self, x, wq, w_scale, out):
|
|
1190
|
+
xq, x_scale = quantize_fp8_row(x)
|
|
1191
|
+
# Pretranspose scales to deepgemm format.
|
|
1192
|
+
x_scale = get_col_major_tma_aligned_tensor(x_scale, rowwise_scaling=True)
|
|
1193
|
+
return xq, wq, x_scale, w_scale, out
|
|
1194
|
+
|
|
1195
|
+
def compute(self, xq, wq, x_scale, w_scale, out):
|
|
1196
|
+
gemm_fp8_fp8_bf16_nt((xq, x_scale), (wq, w_scale), out)
|
|
1197
|
+
return out
|
|
1198
|
+
|
|
1199
|
+
def quantize_and_compute(self, x, wq, w_scale, out):
|
|
1200
|
+
xq, wq, x_scale, w_scale, out = self.quantize(x, wq, w_scale, out)
|
|
1201
|
+
return self.compute(xq, wq, x_scale, w_scale, out)
|
|
1202
|
+
|
|
1203
|
+
@property
|
|
1204
|
+
def name(self) -> str:
|
|
1205
|
+
return "deepgemm_rowwise"
|
|
1206
|
+
|
|
1207
|
+
@property
|
|
1208
|
+
def hip(self) -> bool:
|
|
1209
|
+
return False
|
|
1210
|
+
|
|
1211
|
+
@property
|
|
1212
|
+
def cuda(self) -> bool:
|
|
1213
|
+
return DEEPGEMM_ENABLED
|
|
1214
|
+
|
|
1215
|
+
|
|
1216
|
+
@register_quantize_op
|
|
1217
|
+
class FP8StackedGroupedGemm(QuantizeOpBase):
|
|
1218
|
+
"""
|
|
1219
|
+
FP8 grouped matmul with rowwise scaling and stacked inputs.
|
|
1220
|
+
"""
|
|
1221
|
+
|
|
1222
|
+
def preprocess(self, x, w):
|
|
1223
|
+
m_values = [i.shape[0] for i in x]
|
|
1224
|
+
m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device)
|
|
1225
|
+
# Quantize weights.
|
|
1226
|
+
wq, w_scale = zip(*[quantize_fp8_row(i) for i in w])
|
|
1227
|
+
# Group weights as single tensor.
|
|
1228
|
+
wq = torch.stack(wq, dim=0).contiguous()
|
|
1229
|
+
w_scale = torch.stack(w_scale, dim=0).contiguous()
|
|
1230
|
+
# Also view input as flattened.
|
|
1231
|
+
x = torch.concat(x, dim=0).contiguous()
|
|
1232
|
+
# Return processed tensors.
|
|
1233
|
+
return x, wq, w_scale, m_sizes
|
|
1234
|
+
|
|
1235
|
+
def quantize(self, x, wq, w_scale, m_sizes):
|
|
1236
|
+
B = x.shape[0]
|
|
1237
|
+
xq, x_scale = triton_quantize_fp8_row(x)
|
|
1238
|
+
x_scale = x_scale.view(B, -1)
|
|
1239
|
+
return xq, wq, x_scale, w_scale, m_sizes
|
|
1240
|
+
|
|
1241
|
+
def compute(self, xq, wq, x_scale, w_scale, m_sizes):
|
|
1242
|
+
return torch.ops.fbgemm.f8f8bf16_rowwise_grouped_stacked(
|
|
1243
|
+
xq, wq, x_scale, w_scale, m_sizes
|
|
1244
|
+
)
|
|
1245
|
+
|
|
1246
|
+
def quantize_and_compute(self, x, wq, w_scale, m_sizes):
|
|
1247
|
+
xq, wq, x_scale, w_scale, m_sizes = self.quantize(x, wq, w_scale, m_sizes)
|
|
1248
|
+
return self.compute(xq, wq, x_scale, w_scale, m_sizes)
|
|
1249
|
+
|
|
1250
|
+
@property
|
|
1251
|
+
def name(self) -> str:
|
|
1252
|
+
if torch.version.cuda:
|
|
1253
|
+
return "cutlass_grouped_stacked"
|
|
1254
|
+
else:
|
|
1255
|
+
return "ck_grouped_stacked"
|
|
1256
|
+
|
|
1257
|
+
@property
|
|
1258
|
+
def hip(self) -> bool:
|
|
1259
|
+
return True
|
|
1260
|
+
|
|
1261
|
+
@property
|
|
1262
|
+
def cuda(self) -> bool:
|
|
1263
|
+
return True
|
|
1264
|
+
|
|
1265
|
+
|
|
1266
|
+
@register_quantize_op
|
|
1267
|
+
class FP8StackedGroupedGemmTorch(FP8StackedGroupedGemm):
|
|
1268
|
+
def quantize(self, x, wq, w_scale, m_sizes):
|
|
1269
|
+
xq, wq, x_scale, w_scale, m_sizes = super().quantize(x, wq, w_scale, m_sizes)
|
|
1270
|
+
offsets = torch.cumsum(m_sizes, dim=0, dtype=torch.int32)
|
|
1271
|
+
out = torch.empty(
|
|
1272
|
+
(xq.shape[0], wq.shape[1]), dtype=torch.bfloat16, device=xq.device
|
|
1273
|
+
)
|
|
1274
|
+
x_scale = x_scale.view(x_scale.shape[0])
|
|
1275
|
+
return xq, wq, x_scale, w_scale, offsets, out
|
|
1276
|
+
|
|
1277
|
+
def compute(self, xq, wq, x_scale, w_scale, offsets, out):
|
|
1278
|
+
return torch.ops.fbgemm.f8f8bf16_rowwise_grouped_mm(
|
|
1279
|
+
xq, wq, x_scale, w_scale, offsets, out
|
|
1280
|
+
)
|
|
1281
|
+
|
|
1282
|
+
def quantize_and_compute(self, x, wq, w_scale, m_sizes):
|
|
1283
|
+
xq, wq, x_scale, w_scale, offsets, out = self.quantize(x, wq, w_scale, m_sizes)
|
|
1284
|
+
return self.compute(xq, wq, x_scale, w_scale, offsets, out)
|
|
1285
|
+
|
|
1286
|
+
@property
|
|
1287
|
+
def name(self) -> str:
|
|
1288
|
+
return "ck_grouped_stacked_torch_2d3d"
|
|
1289
|
+
|
|
1290
|
+
@property
|
|
1291
|
+
def hip(self) -> bool:
|
|
1292
|
+
return True
|
|
1293
|
+
|
|
1294
|
+
@property
|
|
1295
|
+
def cuda(self) -> bool:
|
|
1296
|
+
return False
|
|
1297
|
+
|
|
1298
|
+
|
|
1299
|
+
@register_quantize_op
|
|
1300
|
+
class ScaledGroupedMMRowwise(FP8StackedGroupedGemmTorch):
|
|
1301
|
+
def __init__(self):
|
|
1302
|
+
self.fast_accum = True
|
|
1303
|
+
self.torch_compile = False
|
|
1304
|
+
|
|
1305
|
+
def compute(self, xq, wq, x_scale, w_scale, offsets, _):
|
|
1306
|
+
if self.torch_compile:
|
|
1307
|
+
f = torch.compile(
|
|
1308
|
+
torch._scaled_grouped_mm,
|
|
1309
|
+
options={
|
|
1310
|
+
"max_autotune": True,
|
|
1311
|
+
"max_autotune_gemm_backends": "TRITON,CK,CUTLASS,ATEN",
|
|
1312
|
+
},
|
|
1313
|
+
)
|
|
1314
|
+
else:
|
|
1315
|
+
f = torch._scaled_grouped_mm
|
|
1316
|
+
|
|
1317
|
+
return f(
|
|
1318
|
+
xq,
|
|
1319
|
+
wq.transpose(-2, -1),
|
|
1320
|
+
offs=offsets,
|
|
1321
|
+
out_dtype=torch.bfloat16,
|
|
1322
|
+
scale_a=x_scale,
|
|
1323
|
+
scale_b=w_scale,
|
|
1324
|
+
scale_result=None,
|
|
1325
|
+
use_fast_accum=self.fast_accum,
|
|
1326
|
+
)
|
|
1327
|
+
|
|
1328
|
+
@property
|
|
1329
|
+
def name(self) -> str:
|
|
1330
|
+
return "scaled_grouped_mm_rowwise"
|
|
1331
|
+
|
|
1332
|
+
@property
|
|
1333
|
+
def hip(self) -> bool:
|
|
1334
|
+
return True
|
|
1335
|
+
|
|
1336
|
+
@property
|
|
1337
|
+
def cuda(self) -> bool:
|
|
1338
|
+
return True


@register_quantize_op
class FP8StackedGroupwiseGroupedGemm(QuantizeOpBase):
    """
    FP8 grouped matmul with groupwise scaling and stacked inputs.
    """

    def preprocess(self, x, w):
        m_values = [i.shape[0] for i in x]
        m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device)
        # Quantize weights.
        wq, w_scale = zip(
            *[quantize_fp8_block(i, block_m=128, block_k=128, k_major=False) for i in w]
        )
        # Group weights as single tensor.
        wq = torch.stack(wq, dim=0).contiguous()
        w_scale = torch.stack(w_scale, dim=0).contiguous()
        # Also view input as flattened.
        x = torch.concat(x, dim=0).contiguous()
        # Return processed tensors.
        return x, wq, w_scale, m_sizes

    def quantize(self, x, wq, w_scale, m_sizes):
        xq, x_scale = quantize_fp8_group(x, m_sizes=m_sizes)
        return xq, wq, x_scale, w_scale, m_sizes

    def compute(self, xq, wq, x_scale, w_scale, m_sizes):
        return torch.ops.fbgemm.f8f8bf16_groupwise_grouped(
            xq, wq, x_scale, w_scale, m_sizes
        )

    def quantize_and_compute(self, x, wq, w_scale, m_sizes):
        xq, wq, x_scale, w_scale, m_sizes = self.quantize(x, wq, w_scale, m_sizes)
        return self.compute(xq, wq, x_scale, w_scale, m_sizes)

    @property
    def name(self) -> str:
        if torch.version.cuda:
            return "cutlass_groupwise_grouped"
        else:
            return "ck_groupwise_grouped"

    @property
    def hip(self) -> bool:
        return False

    @property
    def cuda(self) -> bool:
        return True


@register_quantize_op
class BF16GroupedGemm(QuantizeOpBase):
    """
    BF16 grouped matmul implemented with CK or Cutlass.
    """

    def preprocess(self, x, w):
        # Apply sparsity to inputs if appropriate.
        # First check if N and K are fixed.
        m_values = [i.shape[0] for i in x]
        n_values = [i.shape[0] for i in w]
        k_values = [i.shape[1] for i in w]
        # If so, do specialized version of initialization.
        if len(np.unique(n_values)) == 1 and len(np.unique(k_values)) == 1:
            m_values = [i.shape[0] for i in x]
            # Inputs for fixed nk mode must be contiguous; however, in the benchmark
            # script they typically are not. Do a little special processing to make
            # them work. In practice this won't be needed.
            # Start by padding along m dimension with zeros.
            max_m = max(m_values)
            x = [
                torch.nn.functional.pad(i, (0, 0, 0, max_m - i.shape[0]), value=0)
                for i in x
            ]
            # Stack inputs into groups.
            x = torch.stack(x).contiguous()
            w = torch.stack(w).contiguous()
            return (
                x,
                w,
                torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device),
            )
        return x, w, None

    def quantize(self, x, w, m_values=None):
        # No action required.
        return x, w, m_values

    def compute(self, x, w, m_values):
        if m_values is None:
            return torch.ops.fbgemm.bf16bf16bf16_grouped(x, w)
        else:
            return torch.ops.fbgemm.bf16bf16bf16_grouped_dynamic(x, w, m_values)

    def quantize_and_compute(self, x, w, m_values):
        return self.compute(x, w, m_values)

    @property
    def name(self) -> str:
        if torch.version.cuda:
            return "cutlass_bf16_grouped"
        else:
            return "ck_bf16_grouped"

    @property
    def hip(self) -> bool:
        return True

    @property
    def cuda(self) -> bool:
        return True
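

# Illustrative sketch with assumed shapes, mirroring BF16GroupedGemm.preprocess
# above: variable-M group inputs are zero-padded to a common M so they can be
# stacked for the fixed-NK grouped path.
def _pad_and_stack_sketch():
    import torch

    xs = [torch.randn(3, 8), torch.randn(5, 8)]  # two groups sharing K = 8
    max_m = max(x.shape[0] for x in xs)
    # (0, 0, 0, n) pads only the trailing rows of the M dimension with zeros.
    padded = [
        torch.nn.functional.pad(x, (0, 0, 0, max_m - x.shape[0]), value=0) for x in xs
    ]
    stacked = torch.stack(padded).contiguous()  # [G, max_m, K]
    assert stacked.shape == (2, 5, 8)
    return stacked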


@register_quantize_op
class FP8RowwiseBatchedGemm(QuantizeOpBase):
    """
    FP8 batched matmul with rowwise scaling.
    """

    def quantize(self, x, w):
        # Quantize both input tensors.
        xq, x_scale = quantize_fp8_row(x)
        wq, w_scale = quantize_fp8_row(w)
        return xq, wq, x_scale, w_scale

    def compute(self, xq, wq, x_scale, w_scale):
        return torch.ops.fbgemm.f8f8bf16_rowwise_batched(xq, wq, x_scale, w_scale)

    def quantize_and_compute(self, x, w):
        xq, wq, x_scale, w_scale = self.quantize(x, w)
        return self.compute(xq, wq, x_scale, w_scale)

    @property
    def name(self) -> str:
        if torch.version.cuda:
            return "cutlass_rowwise_batched"
        else:
            return "ck_rowwise_batched"

    @property
    def hip(self) -> bool:
        return True

    @property
    def cuda(self) -> bool:
        return True


@register_quantize_op
class FP8LiteGemm(QuantizeOpBase):
    """
    FP8 lite matmul for memory-bound workloads.
    """

    def quantize(self, x, w):
        # Quantize both input tensors.
        xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
        wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
        return xq, wq, x_scale, w_scale

    def compute(self, xq, wq, x_scale, w_scale):
        return torch.ops.fbgemm.f8f8bf16_lite(xq, wq, x_scale * w_scale)

    def quantize_and_compute(self, x, w):
        xq, wq, x_scale, w_scale = self.quantize(x, w)
        return self.compute(xq, wq, x_scale, w_scale)

    @property
    def name(self) -> str:
        return "cuda_lite"

    @property
    def hip(self) -> bool:
        # Need to add support for better quantize kernel.
        # Also may have an issue with cuda graphs.
        return False

    @property
    def cuda(self) -> bool:
        return True


@register_quantize_op
class TritonFP8RowwiseGemm(QuantizeOpBase):
    """
    FP8 matmul with rowwise scaling implemented with Triton.
    """

    def __init__(self):
        self.fast_accum = True

    def quantize(self, x, w):
        # Quantize both input tensors.
        xq, x_scale = quantize_fp8_row(x)
        wq, w_scale = quantize_fp8_row(w)
        bias = torch.randn(w.shape[0], device=x.device, dtype=torch.float32)
        return xq, wq, x_scale, w_scale, bias

    def compute(self, xq, wq, x_scale, w_scale, bias):
        return matmul_fp8_row(
            xq,
            wq,
            x_scale,
            w_scale,
            bias=bias,
            fp8_fast_accum=self.fast_accum,
            use_warp_specialization=True,
        )

    def quantize_and_compute(self, x, w):
        xq, wq, x_scale, w_scale, bias = self.quantize(x, w)
        return self.compute(xq, wq, x_scale, w_scale, bias)

    @property
    def name(self) -> str:
        return "triton_rowwise"

    @property
    def hip(self) -> bool:
        # triton FP8 matmuls do not currently compile on AMD.
        return False

    @property
    def cuda(self) -> bool:
        return True
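

# Illustrative reference, not the Triton kernel itself: rowwise FP8 scaling as
# benchmarked above picks one scale per row so the row's max magnitude maps to
# the e4m3 finite max (448), then folds both scale vectors back in after the
# matmul. Shapes and the clamp epsilon are assumptions made for this sketch.
def _rowwise_fp8_reference(x, w):
    import torch

    fp8_max = 448.0  # finite max of torch.float8_e4m3fn
    x_scale = (x.abs().amax(dim=1) / fp8_max).clamp(min=1e-12)
    w_scale = (w.abs().amax(dim=1) / fp8_max).clamp(min=1e-12)
    xq = (x / x_scale[:, None]).to(torch.float8_e4m3fn)
    wq = (w / w_scale[:, None]).to(torch.float8_e4m3fn)
    # Emulate the GEMM in fp32 and rescale per output row and column.
    y = xq.to(torch.float32) @ wq.to(torch.float32).t()
    return (y * x_scale[:, None] * w_scale[None, :]).to(torch.bfloat16)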


@register_quantize_op
class FP8TritonBlockwiseGemm(QuantizeOpBase):
    """
    FP8 matmul with block scaling.
    """

    def quantize(self, x, w):
        # Quantize both input tensors.
        xq, x_scale = quantize_fp8_block(x, 128, 128)
        wq, w_scale = quantize_fp8_block(w, 128, 128)
        return xq, wq, x_scale, w_scale

    def compute(self, xq, wq, x_scale, w_scale):
        return matmul_fp8_block(xq, wq, x_scale, w_scale, 128, 128, 128)

    def quantize_and_compute(self, x, w):
        xq, wq, x_scale, w_scale = self.quantize(x, w)
        return self.compute(xq, wq, x_scale, w_scale)

    @property
    def name(self) -> str:
        return "triton_blockwise"

    @property
    def hip(self) -> bool:
        # Currently has some issues.
        return False

    @property
    def cuda(self) -> bool:
        return True


@register_quantize_op
class FP8CutlassBlockwiseGemm(QuantizeOpBase):
    """
    FP8 matmul with block scaling.
    """

    def quantize(self, x, w):
        # Quantize both input tensors.
        xq, x_scale = quantize_fp8_block(x, 128, 128)
        wq, w_scale = quantize_fp8_block(w, 128, 128)
        return xq, wq, x_scale, w_scale

    def compute(self, xq, wq, x_scale, w_scale):
        return torch.ops.fbgemm.f8f8bf16_blockwise(
            xq, wq, x_scale, w_scale, 128, 128, 128
        )

    def quantize_and_compute(self, x, w):
        xq, wq, x_scale, w_scale = self.quantize(x, w)
        return self.compute(xq, wq, x_scale, w_scale)

    @property
    def name(self) -> str:
        if torch.version.cuda:
            return "cutlass_blockwise"
        else:
            return "ck_blockwise"

    @property
    def hip(self) -> bool:
        return True

    @property
    def cuda(self) -> bool:
        return True
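

# Illustrative sketch with assumed sizes: the 128x128 blockwise ops above keep
# one scale per (block_m x block_k) tile, so the scale grids are ceil-divided
# versions of the operand shapes. The exact memory layout (row vs. column
# major) follows each kernel's convention and is not shown here.
def _blockwise_scale_grid_sketch(M=300, N=512, K=1000, block=128):
    def ceil_div(a, b):
        return (a + b - 1) // b

    x_scale_blocks = (ceil_div(M, block), ceil_div(K, block))
    w_scale_blocks = (ceil_div(N, block), ceil_div(K, block))
    return x_scale_blocks, w_scale_blocks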


@register_quantize_op
class FP8CutlassGroupwiseGemm(QuantizeOpBase):
    """
    FP8 matmul with group / block scaling.
    """

    def preprocess(self, x, w):
        # Quantize weights.
        # Scale is expected to be in [K, N] layout (N Major).
        wq, w_scale = quantize_fp8_block(w, block_m=128, block_k=128, k_major=False)
        # Return processed tensors.
        return x, wq, w_scale

    def quantize(self, x, wq, w_scale):
        # Scale is expected to be in [K, M] layout (M Major).
        xq, x_scale = quantize_fp8_group(x, k_major=False)
        # Pretranspose scales to deepgemm format.
        return xq, wq, x_scale, w_scale

    def compute(self, xq, wq, x_scale, w_scale):
        return torch.ops.fbgemm.f8f8bf16_groupwise(xq, wq, x_scale, w_scale)

    def quantize_and_compute(self, x, wq, w_scale):
        xq, wq, x_scale, w_scale = self.quantize(x, wq, w_scale)
        return self.compute(xq, wq, x_scale, w_scale)

    @property
    def name(self) -> str:
        if torch.version.cuda:
            return "cutlass_groupwise"
        else:
            return "ck_groupwise"

    @property
    def hip(self) -> bool:
        return False

    @property
    def cuda(self) -> bool:
        return True


####################################################################################################
# CUTLASS kernel v2
####################################################################################################


@register_quantize_op
class CutlassFP8TensorwiseGemm_v2(QuantizeOpBase):
    """
    FP8 matmul with tensorwise scaling.
    """

    def quantize(self, x, w):
        # Quantize both input tensors.
        xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
        wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
        return xq, wq, x_scale, w_scale

    def compute(self, xq, wq, x_scale, w_scale):
        if hasattr(torch.ops.cutlass_extensions, "f8f8bf16"):
            return torch.ops.cutlass_extensions.f8f8bf16(xq, wq, x_scale * w_scale)
        else:
            raise RuntimeError(
                "Skipping cutlass_extensions_v2 runs as it is not supported"
            )

    def quantize_and_compute(self, x, w):
        xq, wq, x_scale, w_scale = self.quantize(x, w)
        return self.compute(xq, wq, x_scale, w_scale)

    @property
    def name(self) -> str:
        return "cutlass_tensorwise_v2"

    @property
    def hip(self) -> bool:
        # Need to add support for better quantize kernel.
        # Also may have an issue with cuda graphs.
        return False

    @property
    def cuda(self) -> bool:
        return True


# CUTLASS kernel v2
@register_quantize_op
class CutlassFP8RowwiseGemm_v2(QuantizeOpBase):
    """
    FP8 matmul with rowwise scaling.
    """

    def quantize(self, x, w):
        # Quantize both input tensors.
        xq, x_scale = quantize_fp8_row(x)
        wq, w_scale = quantize_fp8_row(w)
        return xq, wq, x_scale, w_scale

    def compute(self, xq, wq, x_scale, w_scale):
        if hasattr(torch.ops.cutlass_extensions, "f8f8bf16_rowwise"):
            return torch.ops.cutlass_extensions.f8f8bf16_rowwise(
                xq, wq, x_scale, w_scale
            )
        else:
            raise RuntimeError(
                "Skipping cutlass_extensions_v2 runs as it is not supported"
            )

    def quantize_and_compute(self, x, w):
        xq, wq, x_scale, w_scale = self.quantize(x, w)
        return self.compute(xq, wq, x_scale, w_scale)

    @property
    def name(self) -> str:
        return "cutlass_rowwise_v2"

    @property
    def hip(self) -> bool:
        # Need to add support for better quantize kernel.
        # Also may have an issue with cuda graphs.
        return False

    @property
    def cuda(self) -> bool:
        return True


####################################################################################################


@register_quantize_op
class F8I4RowwiseGemm(QuantizeOpBase):
    """
    Mixed Precision FP8 Activations with Int4 Weights.
    """

    def _int4_row_quantize(
        self,
        x: torch.Tensor,
        group_size: int = 128,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        n_bit = 4  # Number of target bits.
        to_quant = x.reshape(-1, group_size).to(torch.float)

        max_val = to_quant.amax(dim=1, keepdim=True)
        min_val = to_quant.amin(dim=1, keepdim=True)
        max_int = 2**n_bit - 1
        min_int = 0
        scales = (max_val - min_val).clamp(min=1e-6) / max_int

        zeros = min_val + scales * (2 ** (n_bit - 1))

        out = to_quant.sub(min_val).div(scales).round().clamp_(min_int, max_int)

        # Recenter output and move to int8.
        out = (out - 2 ** (n_bit - 1)).to(dtype=torch.int8).reshape(x.shape)

        # Cutlass expects column major layout for scale and zero point,
        # so we transpose here and make them contiguous.
        scales = scales.view(x.shape[0], -1).t().contiguous()
        zeros = zeros.view(x.shape[0], -1).t().contiguous()

        return out, scales, zeros

    def _pack_int4(self, x: torch.Tensor) -> torch.Tensor:
        # Given int8 x, pack adjacent int4 values into a single int8.
        low_x = x[:, ::2]
        high_x = x[:, 1::2]

        # High bits need to left shift, this also masks off extra bits.
        high_x = torch.bitwise_left_shift(high_x, 4)
        # Low bits need to have sign bits removed.
        low_x = torch.bitwise_and(low_x, 0xF)

        # Recombine into a single value with bitwise or.
        return torch.bitwise_or(low_x, high_x).contiguous()

    def quantize(self, x, w):
        # Quantize both input tensors.
        xq, x_scale = quantize_fp8_row(x)
        wq, w_scale, w_zp = self._int4_row_quantize(w)
        # Pack int4 values together.
        wq = self._pack_int4(wq)
        return xq, wq, x_scale, w_scale, w_zp

    def compute(self, xq, wq, x_scale, w_scale, w_zp):
        return torch.ops.fbgemm.f8i4bf16_rowwise(xq, wq, x_scale, w_scale, w_zp)

    def quantize_and_compute(self, x, w):
        xq, wq, x_scale, w_scale, w_zp = self.quantize(x, w)
        return self.compute(xq, wq, x_scale, w_scale, w_zp)

    @property
    def name(self) -> str:
        return "cutlass_f8i4_rowwise"

    @property
    def hip(self) -> bool:
        # Not yet supported on AMD.
        return False

    @property
    def cuda(self) -> bool:
        return True
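

# Illustrative sketch mirroring F8I4RowwiseGemm._pack_int4 above, with a
# hypothetical unpack helper added only to demonstrate the round trip: two
# signed int4 values share one int8 byte, low nibble first.
def _int4_roundtrip_sketch():
    import torch

    vals = torch.tensor([[-8, 7, 3, -1]], dtype=torch.int8)  # already in [-8, 7]
    low, high = vals[:, ::2], vals[:, 1::2]
    packed = torch.bitwise_or(
        torch.bitwise_and(low, 0xF), torch.bitwise_left_shift(high, 4)
    )

    # Unpack: split the nibbles and sign-extend from 4 bits.
    def sign_extend(n):
        return torch.where(n >= 8, n - 16, n)

    low_u = sign_extend(torch.bitwise_and(packed, 0xF))
    high_u = sign_extend(torch.bitwise_and(torch.bitwise_right_shift(packed, 4), 0xF))
    unpacked = torch.stack([low_u, high_u], dim=-1).flatten(1)
    assert torch.equal(unpacked, vals)
    return packed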


@register_quantize_op
class F8I4ShuffledGemm(QuantizeOpBase):
    def preprocess(self, x, w):
        # Prequantize and pack weights.
        wq, (group_scale, row_scale) = quantize_int4_preshuffle(w)
        return x, wq, row_scale, group_scale

    def quantize(self, x, wq, row_scale, group_scale):
        # Quantize activations; weights were prequantized in preprocess.
        xq, x_scale = quantize_fp8_row(x)
        return xq, wq, x_scale, row_scale, group_scale

    def compute(self, xq, wq, x_scale, row_scale, group_scale):
        # Handle batched cases by looping over each batch.
        if xq.dim() == 3:
            B, M, _ = xq.shape
            _, N, _ = wq.shape
            y = torch.empty((B, M, N), device=xq.device, dtype=torch.bfloat16)
            for i in range(B):
                y[i] = torch.ops.fbgemm.f8i4bf16_shuffled(
                    xq[i], wq[i], x_scale[i], row_scale[i], group_scale[i]
                )
            return y
        # Otherwise run gemm normally.
        return torch.ops.fbgemm.f8i4bf16_shuffled(
            xq, wq, x_scale, row_scale, group_scale
        )

    def quantize_and_compute(self, x, wq, row_scale, group_scale):
        xq, wq, x_scale, row_scale, group_scale = self.quantize(
            x, wq, row_scale, group_scale
        )
        return self.compute(xq, wq, x_scale, row_scale, group_scale)

    @property
    def name(self) -> str:
        return "cutlass_f8i4_preshuffle"

    @property
    def hip(self) -> bool:
        # Not yet supported on AMD.
        return False

    @property
    def cuda(self) -> bool:
        return True


@register_quantize_op
class BF16I4ShuffledGemm(QuantizeOpBase):
    def preprocess(self, x, w):
        # Prequantize and pack weights.
        wq, (group_scale, group_zero) = quantize_int4_preshuffle(w, dtype="bf16")
        return x, wq, group_scale, group_zero

    def quantize(self, x, wq, group_scale, group_zero):
        # No extra action required.
        return x, wq, group_scale, group_zero

    def compute(self, x, wq, group_scale, group_zero):
        # Handle batched cases by looping over each batch.
        if x.dim() == 3:
            B, M, _ = x.shape
            _, N, _ = wq.shape
            y = torch.empty((B, M, N), device=x.device, dtype=torch.bfloat16)
            for i in range(B):
                y[i] = torch.ops.fbgemm.bf16i4bf16_shuffled(
                    x[i], wq[i], group_scale[i], group_zero[i]
                )
            return y
        # Otherwise run gemm normally.
        return torch.ops.fbgemm.bf16i4bf16_shuffled(x, wq, group_scale, group_zero)

    def quantize_and_compute(self, x, wq, group_scale, group_zero):
        x, wq, group_scale, group_zero = self.quantize(x, wq, group_scale, group_zero)
        return self.compute(x, wq, group_scale, group_zero)

    @property
    def name(self) -> str:
        return "cutlass_bf16i4_preshuffle"

    @property
    def hip(self) -> bool:
        # Not yet supported on AMD.
        return False

    @property
    def cuda(self) -> bool:
        return True


@register_quantize_op
class BF16I4ShuffledBatchedGemm(QuantizeOpBase):
    """
    BF16 x INT4 mixed dtype batched gemm with preshuffling.
    """

    def preprocess(self, x, w):
        # Prequantize and pack weights.
        wq, (group_scale, group_zero) = quantize_int4_preshuffle(w, dtype="bf16")
        return x, wq, group_scale, group_zero

    def quantize(self, x, wq, group_scale, group_zero):
        # No extra action required.
        return x, wq, group_scale, group_zero

    def compute(self, x, wq, group_scale, group_zero):
        return torch.ops.fbgemm.bf16i4bf16_shuffled_batched(
            x, wq, group_scale, group_zero
        )

    def quantize_and_compute(self, x, wq, group_scale, group_zero):
        x, wq, group_scale, group_zero = self.quantize(x, wq, group_scale, group_zero)
        return self.compute(x, wq, group_scale, group_zero)

    @property
    def name(self) -> str:
        return "cutlass_bf16i4_preshuffle_batched"

    @property
    def hip(self) -> bool:
        # Not yet supported on AMD.
        return False

    @property
    def cuda(self) -> bool:
        return True


@register_quantize_op
class F8I4ShuffledGroupedGemm(QuantizeOpBase):
    """
    FP8 x Int4 mixed dtype grouped gemm with preshuffling.
    """

    def preprocess(self, x, w):
        assert isinstance(x, list) and isinstance(
            w, list
        ), "Only supported for grouped inputs."
        m_values = [i.shape[0] for i in x]
        # Convert m_values into offsets into grouped tensor.
        m_sizes = torch.tensor(m_values).to(dtype=torch.int32, device=x[0].device)
        # Quantize weights.
        wq, scales = zip(*[quantize_int4_preshuffle(i) for i in w])
        group_scale, row_scale = zip(*scales)
        # Group weights as single tensor.
        wq = torch.stack(wq, dim=0).contiguous()
        row_scale = torch.stack(row_scale, dim=0).contiguous()
        group_scale = torch.stack(group_scale, dim=0).contiguous()
        # Also view input as flattened.
        x = torch.concat(x, dim=0).contiguous()
        # Return processed tensors.
        return x, wq, row_scale, group_scale, m_sizes

    def quantize(self, x, wq, row_scale, group_scale, m_sizes):
        B = x.shape[0]
        xq, x_scale = triton_quantize_fp8_row(x)
        x_scale = x_scale.view(B, -1)
        return xq, wq, x_scale, row_scale, group_scale, m_sizes

    def compute(self, xq, wq, x_scale, row_scale, group_scale, m_sizes):
        out = torch.ops.fbgemm.f8i4bf16_shuffled_grouped(
            xq, wq, x_scale, row_scale, group_scale, m_sizes
        )
        return out

    def quantize_and_compute(self, x, wq, row_scale, group_scale, m_sizes):
        xq, wq, x_scale, row_scale, group_scale, m_sizes = self.quantize(
            x, wq, row_scale, group_scale, m_sizes
        )
        return self.compute(xq, wq, x_scale, row_scale, group_scale, m_sizes)

    @property
    def name(self) -> str:
        if torch.version.cuda:
            return "cutlass_f8i4_grouped_preshuffle"
        else:
            return "ck_f8i4_grouped_preshuffle"

    @property
    def hip(self) -> bool:
        return False

    @property
    def cuda(self) -> bool:
        return True


@register_quantize_op
class BF16I4ShuffledGroupedGemm(QuantizeOpBase):
    """
    BF16 x Int4 mixed dtype grouped gemm with preshuffling.
    """

    def preprocess(self, x, w):
        assert isinstance(x, list) and isinstance(
            w, list
        ), "Only supported for grouped inputs."
        m_values = [i.shape[0] for i in x]
        # Convert m_values into offsets into grouped tensor.
        m_sizes = torch.tensor(m_values).to(dtype=torch.int32, device=x[0].device)
        # Quantize weights.
        wq, scales = zip(
            *[quantize_int4_preshuffle(i, dtype="bf16", use_zp=False) for i in w]
        )
        # Group weights as single tensor.
        group_scale, group_zero = zip(*scales)
        wq = torch.stack(wq, dim=0).contiguous()
        group_scale = torch.stack(group_scale, dim=0).contiguous()
        group_zero = torch.stack(group_zero, dim=0).contiguous()
        # Also view input as flattened.
        x = torch.concat(x, dim=0).contiguous()
        # Return processed tensors.
        return x, wq, group_scale, group_zero, m_sizes

    def quantize(self, x, wq, group_scale, group_zero, m_sizes):
        return x, wq, group_scale, group_zero, m_sizes

    def compute(self, x, wq, group_scale, group_zero, m_sizes):
        # TODO: Zero points aren't currently supported in grouped gemm.
        # We leave them as inputs for future compatibility but they are ignored.
        return torch.ops.fbgemm.bf16i4bf16_shuffled_grouped(
            x, wq, group_scale, group_zero, m_sizes
        )

    def quantize_and_compute(self, x, wq, group_scale, group_zero, m_sizes):
        x, wq, group_scale, group_zero, m_sizes = self.quantize(
            x, wq, group_scale, group_zero, m_sizes
        )
        return self.compute(x, wq, group_scale, group_zero, m_sizes)

    @property
    def name(self) -> str:
        if torch.version.cuda:
            return "cutlass_bf16i4_grouped_preshuffle"
        else:
            return "ck_bf16i4_grouped_preshuffle"

    @property
    def hip(self) -> bool:
        return False

    @property
    def cuda(self) -> bool:
        return True


@register_quantize_op
class BF16GroupedGrad(QuantizeOpBase):
    """
    BF16 grouped matmul with dgrad inputs in pretraining backed by cutlass
    """

    def preprocess(self, x, w):
        m_values = [i.shape[0] for i in x]
        # Convert m_values into offsets into grouped tensor.
        m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device)
        # Group weights as single tensor.
        w = torch.stack(w, dim=0).contiguous()
        # Prepare online dgrad during pretraining backward.
        w_perm = w.permute(0, 2, 1).contiguous()
        # w.contiguous() is very expensive, so it is handled inside the gmm kernel for free.
        w = w_perm.permute(0, 2, 1)

        # Also view input as flattened.
        x = torch.concat(x, dim=0).contiguous()
        # Return processed tensors.
        return x, w, m_sizes

    def quantize(self, x, w, m_sizes):
        return x, w, m_sizes

    def compute(self, x, w, m_sizes):
        return torch.ops.fbgemm.bf16bf16bf16_grouped_grad(x, w, m_sizes)

    def quantize_and_compute(self, x, w, m_sizes):
        x, w, m_sizes = self.quantize(x, w, m_sizes)
        return self.compute(x, w, m_sizes)

    @property
    def name(self) -> str:
        return "bf16_grouped_grad"

    @property
    def hip(self) -> bool:
        return False

    @property
    def cuda(self) -> bool:
        return True


@register_quantize_op
class BF16GroupedWGrad(QuantizeOpBase):
    """
    BF16 grouped matmul with wgrad inputs in pretraining backed by cutlass
    """

    def preprocess(self, x, w):
        # Get K values for each group
        k_values = [xi.shape[1] for xi in x]  # K dimension for each group

        # Convert k_values into sizes tensor
        k_sizes = torch.tensor(k_values).to(dtype=torch.int64, device=x[0].device)

        x = torch.concat(x, dim=1).contiguous()  # shape: (M, G*K)
        w = torch.concat(w, dim=1).contiguous()  # shape: (N, G*K)

        # Transpose both tensors to simulate wgrad shapes
        x = x.t().contiguous()  # shape: (G*K, M)
        w = w.t().contiguous()  # shape: (G*K, N)

        # Return processed tensors
        return x, w, k_sizes

    def quantize(self, x, w, k_sizes):
        return x, w, k_sizes

    def compute(self, x, w, k_sizes):
        return torch.ops.fbgemm.bf16bf16bf16_grouped_wgrad(x, w, k_sizes)

    def quantize_and_compute(self, x, w, k_sizes):
        x, w, k_sizes = self.quantize(x, w, k_sizes)
        return self.compute(x, w, k_sizes)

    @property
    def name(self) -> str:
        return "bf16_grouped_wgrad"

    @property
    def hip(self) -> bool:
        return False

    @property
    def cuda(self) -> bool:
        return True
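

# Illustrative sketch with made-up sizes: the wgrad benchmark above
# concatenates per-group tensors along K and transposes them, so each group
# contributes a [k_i, M] x [k_i, N] problem delimited by the k_sizes tensor.
def _wgrad_layout_sketch():
    import torch

    M, N = 6, 4
    ks = [8, 16]  # per-group K
    xs = [torch.randn(M, k) for k in ks]
    ws = [torch.randn(N, k) for k in ks]
    x = torch.cat(xs, dim=1).t().contiguous()  # [sum(K), M]
    w = torch.cat(ws, dim=1).t().contiguous()  # [sum(K), N]
    k_sizes = torch.tensor(ks, dtype=torch.int64)
    assert x.shape == (sum(ks), M) and w.shape == (sum(ks), N)
    return x, w, k_sizes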


@register_quantize_op
class BF16GroupedStacked(QuantizeOpBase):
    """
    BF16 grouped matmul with stacked inputs backed by cutlass or ck.
    """

    def preprocess(self, x, w):
        m_values = [i.shape[0] for i in x]
        # Convert m_values into offsets into grouped tensor.
        m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device)
        # Group weights as single tensor.
        w = torch.stack(w, dim=0).contiguous()
        # Also view input as flattened.
        x = torch.concat(x, dim=0).contiguous()
        # Return processed tensors.
        return x, w, m_sizes

    def quantize(self, x, w, m_sizes):
        return x, w, m_sizes

    def compute(self, x, w, m_sizes):
        return torch.ops.fbgemm.bf16bf16bf16_grouped_stacked(x, w, m_sizes)

    def quantize_and_compute(self, x, w, m_sizes):
        x, w, m_sizes = self.quantize(x, w, m_sizes)
        return self.compute(x, w, m_sizes)

    @property
    def name(self) -> str:
        return "bf16_grouped_stacked"

    @property
    def hip(self) -> bool:
        return True

    @property
    def cuda(self) -> bool:
        return True


@register_quantize_op
class BF16I4RowwiseGemm(F8I4RowwiseGemm):
    """
    Mixed Precision BF16 Activations with Int4 Weights.
    """

    def quantize(self, x, w):
        # Quantize weights to int4; activations are just cast to bf16.
        wq, w_scale, w_zp = self._int4_row_quantize(w)
        # Pack int4 values together.
        wq = self._pack_int4(wq)
        return (
            x.to(torch.bfloat16),
            wq,
            w_scale,
            w_zp,
        )

    def compute(self, x, wq, w_scale, w_zp):
        return torch.ops.fbgemm.bf16i4bf16_rowwise(x, wq, w_scale, w_zp)

    def quantize_and_compute(self, x, w):
        x, wq, w_scale, w_zp = self.quantize(x, w)
        return self.compute(x, wq, w_scale, w_zp)

    @property
    def name(self) -> str:
        return "cutlass_bf16i4_rowwise"

    @property
    def hip(self) -> bool:
        # Not yet supported on AMD.
        return False

    @property
    def cuda(self) -> bool:
        return True


@register_quantize_op
class TinyGemmBF16I4(QuantizeOpBase):
    """
    Mixed Precision BF16 Activations with Int4 Weights using tinygemm.
    """

    def quantize(self, x, w):
        # Quantize and pack weights to int4 using tinygemm utils.
        w_int32, w_scales_and_zeros = group_quantize_tensor(
            w, n_bit=4, q_group_size=128
        )
        wq = torch.ops.tinygemm.convert_matrix_to_m16n8k16_Aint4_layout(w_int32, 4)
        return x, wq, w_scales_and_zeros

    def compute(self, x, wq, scale):
        return torch.ops.tinygemm.tinygemm_y_f16RM_x_f16RM_w_int4TC(
            wq, x, 128, scale, False
        )

    def quantize_and_compute(self, x, w):
        x, wq, scale = self.quantize(x, w)
        return self.compute(x, wq, scale)

    @property
    def name(self) -> str:
        return "tinygemm_bf16i4"

    @property
    def hip(self) -> bool:
        # Tinygemm only supported for cuda.
        return False

    @property
    def cuda(self) -> bool:
        # Only enabled if import works.
        return TINYGEMM_ENABLED


@register_quantize_op
class MarlinBF16I4(QuantizeOpBase):
    """
    Mixed Precision BF16 Activations with Int4 Weights using Marlin.
    """

    def quantize(self, x, w):
        # Marlin quantize expects weights in [K, N] layout.
        _, wq, scale = marlin_quantize(w.t().contiguous(), 128)
        return x, wq, scale

    def compute(self, x, wq, scale):
        return torch.ops.marlin.marlin_gemm(x, wq, scale)

    def quantize_and_compute(self, x, w):
        x, wq, scale = self.quantize(x, w)
        return self.compute(x, wq, scale)

    @property
    def name(self) -> str:
        return "marlin_bf16i4"

    @property
    def hip(self) -> bool:
        # Marlin only supported for cuda.
        return False

    @property
    def cuda(self) -> bool:
        # This op is not always supported.
        return MARLIN_ENABLED


@register_quantize_op
class MacheteBF16I4(QuantizeOpBase):
    """
    Mixed Precision BF16 Activations with Int4 Weights using Machete.
    """

    def quantize(self, x, w):
        # Machete quantize expects weights in [K, N] layout.
        _, wq, scale, _ = machete_quantize_and_pack(
            w.t().contiguous(), bits=4, groupsize=128
        )
        return x, wq, scale

    def compute(self, x, wq, scale):
        return machete_gemm(x, wq, bits=4, groupsize=128, scales=scale)

    def quantize_and_compute(self, x, w):
        x, wq, scale = self.quantize(x, w)
        return self.compute(x, wq, scale)

    @property
    def name(self) -> str:
        return "machete_bf16i4"

    @property
    def hip(self) -> bool:
        # Machete only supported for cuda.
        return False

    @property
    def cuda(self) -> bool:
        # This op is not always supported.
        return MACHETE_ENABLED


@register_quantize_op
class NVFP4Gemm(QuantizeOpBase):
    """
    NVFP4 matmul with block-wise scaling.
    """

    def quantize(self, x, w):
        x_global_scale = (448.0 * 6.0) / torch.amax(torch.abs(x.flatten()), dim=-1).to(
            torch.float32
        )
        w_global_scale = (448.0 * 6.0) / torch.amax(torch.abs(w.flatten()), dim=-1).to(
            torch.float32
        )
        global_scale = 1 / (x_global_scale * w_global_scale)

        xq, x_scale = triton_scale_nvfp4_quant(x, x_global_scale)
        wq, w_scale = triton_scale_nvfp4_quant(w, w_global_scale)

        return xq, wq, x_scale, w_scale, global_scale

    def compute(self, xq, wq, x_scale, w_scale, global_scale):
        return torch.ops.fbgemm.f4f4bf16(
            xq, wq, x_scale, w_scale, global_scale=global_scale
        )

    def quantize_and_compute(self, x, w):
        xq, wq, x_scale, w_scale, global_scale = self.quantize(x, w)
        return self.compute(xq, wq, x_scale, w_scale, global_scale=global_scale)

    @property
    def name(self) -> str:
        return "cutlass_nv_f4f4bf16"

    @property
    def hip(self) -> bool:
        # F4F4BF16 only supported for cuda.
        return False

    @property
    def cuda(self) -> bool:
        return True
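

# Illustrative sketch of the NVFP4 global-scale arithmetic used above: 448 is
# the finite max of the FP8 e4m3 block scales and 6 is the max magnitude of an
# FP4 e2m1 element, so (448 * 6) / amax(|t|) maps a tensor's largest value onto
# the largest representable scale-times-element product. The GEMM then folds
# both per-tensor factors back in through global_scale = 1 / (sx * sw).
def _nvfp4_global_scale_sketch(x, w):
    import torch

    x_global_scale = (448.0 * 6.0) / x.abs().amax().to(torch.float32)
    w_global_scale = (448.0 * 6.0) / w.abs().amax().to(torch.float32)
    global_scale = 1.0 / (x_global_scale * w_global_scale)
    return x_global_scale, w_global_scale, global_scale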


@register_quantize_op
class NVFP4Quantize(QuantizeOpBase):
    """
    NVFP4 quantization with block-wise scaling.
    """

    def quantize_rms(self, x, w):
        M, N = w.shape
        group_size = 16
        w = torch.randn(group_size, dtype=torch.bfloat16, device=w.device)
        x_global_scale = torch.tensor([448.0 * 6.0]).to(
            device=x.device, dtype=torch.float32
        ) / torch.amax(torch.abs(x.flatten()), dim=-1).to(torch.float32)
        xq_ref, x_scale_ref = triton_scale_nvfp4_quant_rms(
            x,
            w.repeat(M * N // group_size),
            x_global_scale,
            group_size=group_size,
            EPS=1e-5,
        )

        intermediate = rms_norm(x.reshape(-1, 16), w, eps=1e-5)
        intermediate = intermediate.to(torch.bfloat16).reshape(M, N)
        xq, x_scale = triton_scale_nvfp4_quant(
            intermediate,
            x_global_scale,
            group_size=group_size,
        )

    def quantize_silu(self, x, w):
        M, N = x.shape
        group_size = 16
        x_global_scale = torch.tensor([448.0 * 6.0]).to(
            device=x.device, dtype=torch.float32
        ) / torch.amax(torch.abs(x.flatten()), dim=-1).to(torch.float32)
        xq_ref, x_scale_ref = triton_scale_nvfp4_quant_silu(
            x,
            w,
            x_global_scale,
            group_size=group_size,
        )

        intermediate = silu_mul(x.reshape(-1, 16), w.reshape(-1, 16))
        intermediate = intermediate.to(torch.bfloat16).reshape(M, N)
        xq, x_scale = triton_scale_nvfp4_quant(
            intermediate,
            x_global_scale,
            group_size=group_size,
        )

    def quantize(self, x, w):
        x_global_scale = (448.0 * 6.0) / torch.amax(torch.abs(x.flatten()), dim=-1).to(
            torch.float32
        )
        w_global_scale = (448.0 * 6.0) / torch.amax(torch.abs(w.flatten()), dim=-1).to(
            torch.float32
        )
        global_scale = 1 / (x_global_scale * w_global_scale)

        xq, x_scale = triton_scale_nvfp4_quant(x, x_global_scale)
        wq, w_scale = triton_scale_nvfp4_quant(w, w_global_scale)
        return xq, wq, x_scale, w_scale, global_scale

    def compute(self, xq, wq, x_scale, w_scale, global_scale):
        return torch.ops.fbgemm.f4f4bf16(
            xq, wq, x_scale, w_scale, global_scale=global_scale
        )

    def quantize_and_compute(self, x, w):
        return self.quantize(x, w)

    @property
    def name(self) -> str:
        return "nvfp4_quantize"

    @property
    def hip(self) -> bool:
        # F4F4BF16 only supported for cuda.
        return False

    @property
    def cuda(self) -> bool:
        return True


@register_quantize_op
class MXFP4Gemm(QuantizeOpBase):
    """
    MXFP4 matmul with block-wise scaling.
    """

    def quantize(self, x, w):
        xq, x_scale = triton_quantize_mx4_unpack(x)
        wq, w_scale = triton_quantize_mx4_unpack(w)
        return xq, wq, x_scale, w_scale

    def compute(self, xq, wq, x_scale, w_scale):
        return torch.ops.fbgemm.f4f4bf16(xq, wq, x_scale, w_scale)

    def quantize_and_compute(self, x, w):
        xq, wq, x_scale, w_scale = self.quantize(x, w)
        return self.compute(xq, wq, x_scale, w_scale)

    @property
    def name(self) -> str:
        return "cutlass_f4f4bf16"

    @property
    def hip(self) -> bool:
        # F4F4BF16 only supported for cuda.
        return False

    @property
    def cuda(self) -> bool:
        return True


@register_quantize_op
class MXFP4StackedGroupedGemm(QuantizeOpBase):
    """
    MXFP4 grouped matmul with blockwise scaling and stacked inputs.
    """

    def preprocess(self, x, w):
        m_values = [i.shape[0] for i in x]
        m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device)
        wq, w_scale = zip(*[triton_quantize_mx4_unpack(i) for i in w])
        wq = torch.stack(wq, dim=0).contiguous()
        w_scale = torch.stack(w_scale, dim=0).contiguous()
        return x, wq, w_scale, m_sizes

    def quantize(self, x, wq, w_scale, m_sizes):
        starting_row_after_padding_list = [0]
        xq_list = []
        x_scale_list = []
        for i in range(m_sizes.shape[0]):
            scale_slice = x[i]
            if m_sizes[i].item() != 0:
                xq, x_scale = triton_quantize_mx4_unpack(scale_slice)
                xq_list.append(xq)
                x_scale_list.append(x_scale)
                starting_row_after_padding_list.append(
                    starting_row_after_padding_list[i]
                    + x_scale.numel() // (x[0].shape[1] // 32)
                )
            else:
                starting_row_after_padding_list.append(
                    starting_row_after_padding_list[i]
                )
        xq = torch.cat(xq_list, dim=0).contiguous()
        x_scale = torch.cat(x_scale_list, dim=0).contiguous()
        x_scale = x_scale.reshape(-1, x[0].shape[-1] // 32)
        xq = xq.view(-1, xq.shape[-1])
        return (
            xq,
            wq,
            x_scale,
            w_scale,
            m_sizes,
            torch.tensor(starting_row_after_padding_list, device=xq.device),
        )

    def compute(self, xq, wq, x_scale, w_scale, m_sizes, starting_row_after_padding):
        return torch.ops.fbgemm.f4f4bf16_grouped_stacked(
            xq,
            wq,
            x_scale,
            w_scale,
            m_sizes,
            starting_row_after_padding=starting_row_after_padding,
        )

    def quantize_and_compute(self, x, w):
        xq, wq, x_scale, w_scale, m_sizes, starting_row_after_padding = self.quantize(
            x, w
        )
        return self.compute(
            xq, wq, x_scale, w_scale, m_sizes, starting_row_after_padding
        )

    @property
    def name(self) -> str:
        return "cutlass_f4f4bf16_grouped_stacked"

    @property
    def hip(self) -> bool:
        return False

    @property
    def cuda(self) -> bool:
        return True
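

# Illustrative sketch with made-up shapes: MX4 stores one shared e8m0 exponent
# per 32 elements, which is why the quantize step above reshapes the scale
# buffer to K // 32 columns and records a starting scale row per group (real
# kernels may additionally pad each group to an alignment boundary).
def _mx4_scale_rows_sketch():
    K = 256  # elements per row, assumed divisible by 32
    m_sizes = [5, 0, 9]  # rows per group, including an empty group
    scale_cols = K // 32  # one e8m0 scale per 32-element block
    starting_rows = [0]
    for m in m_sizes:
        # Empty groups contribute no scale rows; the rest add one row per input row.
        starting_rows.append(starting_rows[-1] + m)
    assert scale_cols == 8 and starting_rows == [0, 5, 5, 14]
    return scale_cols, starting_rows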


@register_quantize_op
class MXFP4GroupedGemm2D3D(QuantizeOpBase):
    """
    MXFP4 grouped GEMM with blockwise scaling and Torch 2D3D API.
    """

    def preprocess(self, xs, ws):
        m_sizes = [x.shape[0] for x in xs]
        m_offsets = torch.cumsum(torch.tensor(m_sizes), dim=0).to(
            dtype=torch.int32, device=xs[0].device
        )

        wqs = []
        w_scales = []
        for w in ws:
            wq, w_scale = triton_quantize_mx4_unpack(w)
            wqs.append(wq)
            w_scales.append(w_scale)

        wq = torch.stack(wqs, dim=0)
        w_scale = torch.stack(w_scales, dim=0)

        return xs, wq, w_scale, m_offsets

    def quantize(self, xs, wq, w_scale, m_offsets):
        xqs = []
        x_scales = []
        for x in xs:
            xq, x_scale = triton_quantize_mx4_unpack(x)
            xqs.append(xq)
            x_scales.append(x_scale)

        xq = torch.cat(xqs, dim=0)
        x_scale = torch.stack(x_scales, dim=0)

        xq = xq.view(torch.float4_e2m1fn_x2)
        wq = wq.view(torch.float4_e2m1fn_x2)
        x_scale = x_scale.view(torch.float8_e8m0fnu)
        w_scale = w_scale.view(torch.float8_e8m0fnu)

        return xq, wq, x_scale, w_scale, m_offsets

    def compute(
        self,
        xq,
        wq,
        x_scale,
        w_scale,
        m_offsets,
    ):
        return torch.ops.fbgemm.f4f4bf16_grouped_mm(
            xq,
            wq.transpose(-2, -1),
            x_scale,
            w_scale,
            m_offsets,
        )

    def quantize_and_compute(self, xs, wq, w_scale, m_offsets):
        args = self.quantize(xs, wq, w_scale, m_offsets)
        return self.compute(*args)

    @property
    def name(self) -> str:
        return "cutlass_mx_f4f4bf16_grouped_mm_2d_3d"

    @property
    def cuda(self) -> bool:
        return True

    @property
    def hip(self) -> bool:
        return False


@register_quantize_op
class MXFP4GroupedGemm2D2D(QuantizeOpBase):
    """
    MXFP4 grouped GEMM with blockwise scaling and Torch 2D2D API.
    """

    def preprocess(self, xs, ws):
        k_sizes = [x.shape[1] for x in xs]
        k_offsets = torch.cumsum(torch.tensor(k_sizes), dim=0).to(
            dtype=torch.int32, device=xs[0].device
        )

        wqs = []
        w_scales = []
        for w in ws:
            wq, w_scale = triton_quantize_mx4_unpack(w)
            wqs.append(wq)
            w_scales.append(w_scale)

        wq = torch.cat(wqs, dim=1)
        w_scale = torch.stack(w_scales, dim=0)

        return xs, wq, w_scale, k_offsets

    def quantize(self, xs, wq, w_scale, k_offsets):
        xqs = []
        x_scales = []
        for x in xs:
            xq, x_scale = triton_quantize_mx4_unpack(x)
            xqs.append(xq)
            x_scales.append(x_scale)

        xq = torch.cat(xqs, dim=1)
        x_scale = torch.stack(x_scales, dim=0)

        xq = xq.view(torch.float4_e2m1fn_x2)
        wq = wq.view(torch.float4_e2m1fn_x2)
        x_scale = x_scale.view(torch.float8_e8m0fnu)
        w_scale = w_scale.view(torch.float8_e8m0fnu)

        return xq, wq, x_scale, w_scale, k_offsets

    def compute(
        self,
        xq,
        wq,
        x_scale,
        w_scale,
        k_offsets,
    ):
        return torch.ops.fbgemm.f4f4bf16_grouped_mm(
            xq,
            wq.transpose(-2, -1),
            x_scale,
            w_scale,
            k_offsets,
        )

    def quantize_and_compute(self, xs, wq, w_scale, k_offsets):
        args = self.quantize(xs, wq, w_scale, k_offsets)
        return self.compute(*args)

    @property
    def name(self) -> str:
        return "cutlass_mx_f4f4bf16_grouped_mm_2d_2d"

    @property
    def cuda(self) -> bool:
        return True

    @property
    def hip(self) -> bool:
        return False
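

# Illustrative sketch with made-up sizes: in the 2D2D grouped layout used
# above, groups are concatenated along K rather than M, and the kernel receives
# cumulative per-group K boundaries instead of row counts.
def _grouped_2d2d_k_offsets_sketch():
    import torch

    M, N = 16, 32
    ks = [64, 128, 32]  # per-group K
    xs = [torch.randn(M, k) for k in ks]
    ws = [torch.randn(N, k) for k in ks]
    x = torch.cat(xs, dim=1)  # [M, sum(K)]
    w = torch.cat(ws, dim=1)  # [N, sum(K)]
    k_offsets = torch.cumsum(torch.tensor(ks), dim=0).to(torch.int32)
    assert k_offsets.tolist() == [64, 192, 224]
    return x, w, k_offsets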


@register_quantize_op
class NVFP4GroupedGemm2D3D(QuantizeOpBase):
    """
    NVFP4 grouped GEMM with blockwise scaling and Torch 2D3D API.
    """

    def preprocess(self, x, w):
        m_values = [i.shape[0] for i in x]
        m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device)
        x = torch.concat(x, dim=0).contiguous()

        def get_global_scale(x, w, m_sizes):
            G = len(w)
            w_global_scale = []
            global_scale = []

            x_global_scale, tensor_idx = calculate_group_max(x, m_sizes=m_sizes)

            for i in range(G):
                w_global_scale_ = (448.0 * 6.0) / torch.amax(
                    torch.abs(w[i].flatten()), dim=-1
                ).to(torch.float32)

                global_scale_ = 1 / (x_global_scale[i] * w_global_scale_)

                w_global_scale.append(w_global_scale_)
                global_scale.append(global_scale_)

            return x_global_scale, w_global_scale, global_scale, tensor_idx

        # Compute global scale for each group
        G = m_sizes.numel()
        x_global_scale, w_global_scale, global_scale, tensor_idx = get_global_scale(
            x, w, m_sizes
        )
        global_scale = torch.stack(global_scale, dim=0).contiguous()

        wq, w_scale = zip(
            *[triton_scale_nvfp4_quant(w[i], w_global_scale[i]) for i in range(G)]
        )
        wq = torch.stack(wq, dim=0).contiguous()
        w_scale = torch.stack(w_scale, dim=0).contiguous()

        return x, wq, w_scale, x_global_scale, global_scale, m_sizes, tensor_idx

    def quantize(
        self, x, wq, w_scale, x_global_scale, global_scale, m_sizes, tensor_idx
    ):
        xq, x_scale, _ = mega_fp4_quantize_kernel(
            m_sizes, x, x_global_scale, optional_tensor_idx=tensor_idx
        )

        x_scale = x_scale.reshape(-1, x.shape[1] // 16)
        offsets = torch.cumsum(m_sizes, dim=0).to(torch.int32)

        xq = xq.view(torch.float4_e2m1fn_x2)
        wq = wq.view(torch.float4_e2m1fn_x2)
        x_scale = x_scale.view(torch.float8_e4m3fn)
        w_scale = w_scale.view(torch.float8_e4m3fn)

        return (
            xq,
            wq.transpose(-2, -1),
            x_scale,
            w_scale,
            offsets,
            None,
            global_scale,
        )

    def compute(
        self,
        xq,
        wq,
        x_scale,
        w_scale,
        offsets,
        output,
        global_scale,
    ):
        return torch.ops.fbgemm.f4f4bf16_grouped_mm(
            xq,
            wq,
            x_scale,
            w_scale,
            offsets,
            output,
            global_scale,
        )

    def quantize_and_compute(
        self, x, wq, w_scale, x_global_scale, global_scale, m_sizes, tensor_idx
    ):
        args = self.quantize(
            x, wq, w_scale, x_global_scale, global_scale, m_sizes, tensor_idx
        )
        return self.compute(*args)

    @property
    def name(self) -> str:
        return "cutlass_nv_f4f4bf16_grouped_mm_2d_3d"

    @property
    def hip(self) -> bool:
        return False

    @property
    def cuda(self) -> bool:
        return True
|
|
2859
|
+
|
|
2860
|
+
|
|
2861
|
+
@register_quantize_op
class NVFP4GroupedGemm2D2D(QuantizeOpBase):
    """
    NVFP4 grouped GEMM with blockwise scaling and Torch 2D2D API.
    """

    def preprocess(self, xs, ws):
        k_sizes = [x.shape[1] for x in xs]
        k_offsets = torch.cumsum(torch.tensor(k_sizes), dim=0).to(
            dtype=torch.int32, device=xs[0].device
        )

        global_scales, x_global_scales, w_global_scales = get_nvfp4_global_scales_naive(
            xs, ws
        )
        wqs, w_scales = quantize_nvfp4_naive(ws, w_global_scales)
        wq = torch.cat(wqs, dim=1).view(torch.float4_e2m1fn_x2)
        w_scale = (
            torch.stack(w_scales, dim=0)
            .reshape(round_up(wq.size(0), 128), -1)
            .view(torch.float8_e4m3fn)
        )
        global_scale = torch.stack(global_scales, dim=0)

        return xs, wq, w_scale, global_scale, x_global_scales, k_offsets

    def quantize(self, xs, wq, w_scale, global_scale, x_global_scales, k_offsets):
        xqs, x_scales = quantize_nvfp4_naive(xs, x_global_scales)
        xq = torch.cat(xqs, dim=1).view(torch.float4_e2m1fn_x2)
        x_scale = (
            torch.stack(x_scales, dim=0)
            .reshape(round_up(xq.size(0), 128), -1)
            .view(torch.float8_e4m3fn)
        )

        return xq, wq, x_scale, w_scale, k_offsets, global_scale

    def compute(
        self,
        xq,
        wq,
        x_scale,
        w_scale,
        k_offsets,
        global_scale,
    ):
        return torch.ops.fbgemm.f4f4bf16_grouped_mm(
            xq,
            wq.transpose(-2, -1),
            x_scale,
            w_scale,
            k_offsets,
            None,
            global_scale,
        )

    def quantize_and_compute(
        self, xs, wq, w_scale, global_scale, x_global_scales, k_offsets
    ):
        args = self.quantize(
            xs, wq, w_scale, global_scale, x_global_scales, k_offsets
        )
        return self.compute(*args)

    @property
    def name(self) -> str:
        return "cutlass_nv_f4f4bf16_grouped_mm_2d_2d"

    @property
    def hip(self) -> bool:
        return False

    @property
    def cuda(self) -> bool:
        return True


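# Illustrative usage sketch for NVFP4GroupedGemm2D2D above, assuming a CUDA
# device with NVFP4 kernel support, the fbgemm_gpu gen_ai extension loaded,
# and that the op class can be constructed directly with no arguments (as the
# benchmark registry does). Group count, shapes, and dtype are made up.
def _example_nvfp4_grouped_2d_2d():
    op = NVFP4GroupedGemm2D2D()
    # Two groups sharing M and N but contributing different K slices.
    xs = [
        torch.randn(64, 256, device="cuda", dtype=torch.bfloat16) for _ in range(2)
    ]
    ws = [
        torch.randn(128, 256, device="cuda", dtype=torch.bfloat16) for _ in range(2)
    ]
    preprocessed = op.preprocess(xs, ws)
    args = op.quantize(*preprocessed)
    return op.compute(*args)

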
@register_quantize_op
class NVFP4StackedGroupedGemm(QuantizeOpBase):
    """
    NVFP4 grouped matmul with blockwise scaling and stacked inputs.
    """

    def preprocess(self, x, w):
        m_values = [i.shape[0] for i in x]
        m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device)
        x = torch.concat(x, dim=0).contiguous()

        def get_global_scale(x, w, m_sizes):
            G = len(w)
            w_global_scale = []
            global_scale = []

            cumulative_sum = torch.zeros(
                m_sizes.shape[0] + 1, dtype=torch.int64, device=m_sizes.device
            )
            cumulative_sum[1:] = torch.cumsum(m_sizes, dim=0)

            x_global_scale, tensor_idx = calculate_group_max(x, m_sizes=m_sizes)

            for i in range(G):
                w_global_scale_ = (448.0 * 6.0) / torch.amax(
                    torch.abs(w[i].flatten()), dim=-1
                ).to(torch.float32)

                global_scale_ = 1 / (x_global_scale[i] * w_global_scale_)

                w_global_scale.append(w_global_scale_)
                global_scale.append(global_scale_)

            return x_global_scale, w_global_scale, global_scale, tensor_idx

        # Compute global scale for each group
        G = m_sizes.numel()
        x_global_scale, w_global_scale, global_scale, tensor_idx = get_global_scale(
            x, w, m_sizes
        )
        global_scale = torch.stack(global_scale, dim=0).contiguous()

        wq, w_scale = zip(
            *[triton_scale_nvfp4_quant(w[i], w_global_scale[i]) for i in range(G)]
        )
        wq = torch.stack(wq, dim=0).contiguous()
        w_scale = torch.stack(w_scale, dim=0).contiguous()

        return x, wq, w_scale, x_global_scale, global_scale, m_sizes, tensor_idx

    def quantize(
        self, x, wq, w_scale, x_global_scale, global_scale, m_sizes, tensor_idx
    ):
        # Alternative methods; may be useful in some scenarios:
        """
        starting_row_after_padding, belong_indices, row_within_tensor = (
            nvfp4_fused_padding_cumsum_and_segmented_arange(m_sizes, x.shape[0])
            # fused_single_block_cumsum_and_segmented_arange(m_sizes, x.shape[0])
        )

        xq, x_scale = triton_nvfp4_quant_stacked(
            x,
            x_global_scale[0],
            belong_indices,
            starting_row_after_padding,
            row_within_tensor,
        )
        """

        # Optionally set optional_tensor_idx to None to run the alternative method.
        xq, x_scale, starting_row_after_padding = mega_fp4_quantize_kernel(
            m_sizes, x, x_global_scale, optional_tensor_idx=tensor_idx
        )

        x_scale = x_scale.reshape(-1, x.shape[1] // 16)
        return (
            xq,
            wq,
            x_scale,
            w_scale,
            m_sizes,
            global_scale,
            starting_row_after_padding,
        )

    def compute(
        self,
        xq,
        wq,
        x_scale,
        w_scale,
        m_sizes,
        global_scale,
        starting_row_after_padding,
    ):
        gemm_result = torch.ops.fbgemm.f4f4bf16_grouped_stacked(
            xq,
            wq,
            x_scale,
            w_scale,
            m_sizes,
            global_scale,
            starting_row_after_padding,
            use_mx=False,
        )
        return gemm_result

    def quantize_and_compute(
        self, x, wq, w_scale, x_global_scale, global_scale, m_sizes, tensor_idx
    ):
        (
            xq,
            wq,
            x_scale,
            w_scale,
            m_sizes,
            global_scale,
            starting_row_after_padding,
        ) = self.quantize(
            x, wq, w_scale, x_global_scale, global_scale, m_sizes, tensor_idx
        )
        return self.compute(
            xq,
            wq,
            x_scale,
            w_scale,
            m_sizes,
            global_scale,
            starting_row_after_padding,
        )

    @property
    def name(self) -> str:
        return "cutlass_nv_f4f4bf16_grouped_stacked"

    @property
    def hip(self) -> bool:
        return False

    @property
    def cuda(self) -> bool:
        return True


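# CPU-runnable sketch of the NVFP4 global-scale arithmetic used by the stacked
# op above: 448.0 is the float8_e4m3fn max and 6.0 the FP4 (E2M1) max, so each
# operand's global scale is (448 * 6) / amax(|t|), and the value passed as
# global_scale is the reciprocal of the product of the two operand scales. The
# op above derives the x scales per group via calculate_group_max; this sketch
# uses a single whole-tensor amax for brevity, and the shapes are illustrative.
def _example_nvfp4_global_scales():
    x = torch.randn(8, 32)
    w = [torch.randn(16, 32) for _ in range(2)]
    x_gs = (448.0 * 6.0) / torch.amax(torch.abs(x))
    scales = []
    for w_g in w:
        w_gs = (448.0 * 6.0) / torch.amax(torch.abs(w_g))
        # Reciprocal of the product of operand scales (the global_scale above).
        scales.append(1.0 / (x_gs * w_gs))
    return torch.stack(scales)

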
@register_quantize_op
class NVFP4StackedGroupedGemmPackUnpack(QuantizeOpBase):
    """
    NVFP4 grouped matmul with blockwise scaling and stacked inputs, exercising
    the mega_fp4 pack/unpack path.
    """

    def preprocess(self, x, w):
        m_values = [i.shape[0] for i in x]
        m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device)
        x = torch.concat(x, dim=0).contiguous()

        def get_global_scale(x, w):
            G = len(w)
            x_global_scale = []
            w_global_scale = []
            global_scale = []

            x_global_scale_ = (448.0 * 6.0) / torch.amax(
                torch.abs(x.flatten()), dim=-1
            ).to(torch.float32)

            for i in range(G):
                w_global_scale_ = (448.0 * 6.0) / torch.amax(
                    torch.abs(w[i].flatten()), dim=-1
                ).to(torch.float32)

                global_scale_ = 1 / (x_global_scale_ * w_global_scale_)

                x_global_scale.append(x_global_scale_)
                w_global_scale.append(w_global_scale_)
                global_scale.append(global_scale_)

            return x_global_scale, w_global_scale, global_scale

        # Compute global scale for each group
        G = m_sizes.numel()
        x_global_scale, w_global_scale, global_scale = get_global_scale(x, w)

        global_scale = torch.stack(global_scale, dim=0).contiguous()

        wq, w_scale = zip(
            *[triton_scale_nvfp4_quant(w[i], w_global_scale[i]) for i in range(G)]
        )
        wq = torch.stack(wq, dim=0).contiguous()
        w_scale = torch.stack(w_scale, dim=0).contiguous()
        x_global_scale = torch.tensor(x_global_scale, device=m_sizes.device)
        return (
            x,
            wq,
            w_scale,
            x_global_scale,
            global_scale,
            m_sizes,
        )

    def quantize(self, x, wq, w_scale, x_global_scale, global_scale, m_sizes):
        # Alternative packing method that only uses the overall global scale
        # rather than per-tensor scales:
        """
        packed = mega_fp4_pack(x, x_global_scale[0])
        """
        packed = mega_fp4_pack(
            x,
            x_global_scale,
            per_tensor=True,
            m_sizes=m_sizes,
        )
        xq, x_scale, starting_row_after_padding = mega_fp4_unpack(m_sizes, packed)
        xq_other, x_scale_other, starting_row_after_padding_other = (
            mega_fp4_quantize_kernel(
                m_sizes,
                x,
                x_global_scale,
            )
        )

        x_scale = x_scale.reshape(-1, x.shape[1] // 16)
        x_scale_other = x_scale_other.reshape(-1, x.shape[1] // 16)
        return (
            xq,
            wq,
            x_scale,
            w_scale,
            m_sizes,
            global_scale,
            starting_row_after_padding,
            xq_other,
            x_scale_other,
            starting_row_after_padding_other,
        )

    def compute(
        self,
        xq,
        wq,
        x_scale,
        w_scale,
        m_sizes,
        global_scale,
        starting_row_after_padding,
        xq_other,
        x_scale_other,
        starting_row_after_padding_other,
    ):
        ref_solution = torch.ops.fbgemm.f4f4bf16_grouped_stacked(
            xq_other,
            wq,
            x_scale_other,
            w_scale,
            m_sizes,
            global_scale,
            starting_row_after_padding_other,
            use_mx=False,
        )
        gemm_result = torch.ops.fbgemm.f4f4bf16_grouped_stacked(
            xq,
            wq,
            x_scale,
            w_scale,
            m_sizes,
            global_scale,
            starting_row_after_padding,
            use_mx=False,
        )
        assert torch.allclose(ref_solution, gemm_result)

        return gemm_result

    def quantize_and_compute(
        self, x, wq, w_scale, x_global_scale, global_scale, m_sizes
    ):
        (
            xq,
            wq,
            x_scale,
            w_scale,
            m_sizes,
            global_scale,
            starting_row_after_padding,
            xq_other,
            x_scale_other,
            starting_row_after_padding_other,
        ) = self.quantize(x, wq, w_scale, x_global_scale, global_scale, m_sizes)
        return self.compute(
            xq,
            wq,
            x_scale,
            w_scale,
            m_sizes,
            global_scale,
            starting_row_after_padding,
            xq_other,
            x_scale_other,
            starting_row_after_padding_other,
        )

    @property
    def name(self) -> str:
        return "cutlass_nv_f4f4bf16_grouped_stacked_pack_unpack"

    @property
    def hip(self) -> bool:
        return False

    @property
    def cuda(self) -> bool:
        return True


@register_quantize_op
class BF16GroupedGemm2d3d(QuantizeOpBase):
    """
    Torch BF16 grouped GEMM with 2D inputs and 3D weights.
    """

    def preprocess(self, x, w):
        assert isinstance(x, list)
        assert isinstance(w, list)
        offs = torch.tensor(
            [i.shape[0] for i in x], dtype=torch.int32, device=x[0].device
        )
        offs = torch.cumsum(offs, dim=0).to(torch.int32)
        x = torch.cat(x, dim=0).contiguous()  # (G * M, K)
        w = torch.stack(w, dim=0).contiguous()  # (G, N, K)
        return x, w, offs

    def quantize(self, x, w, offs):
        return x, w, offs

    def compute(self, x, w, offs):
        return torch._grouped_mm(
            x,
            w.transpose(-2, -1),
            offs=offs,
        )

    def quantize_and_compute(self, x, w, offs):
        x, w, offs = self.quantize(x, w, offs)
        return self.compute(x, w, offs)

    @property
    def name(self) -> str:
        return "bf16_baseline_grouped_2d_3d"

    @property
    def hip(self) -> bool:
        return False

    @property
    def cuda(self) -> bool:
        return True


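# Eager reference for the 2D/3D grouped layout benchmarked above: rows of the
# stacked (sum(M_g), K) input are split at the cumulative `offs` boundaries and
# each slice is multiplied by its own (N, K) weight. This is only a semantic
# sketch of the baseline (it assumes the output is stacked along the rows, like
# the input), not a replacement for torch._grouped_mm or the fused kernels.
def _example_grouped_2d_3d_reference(x, w, offs):
    # x: (sum(M_g), K), w: (G, N, K), offs: cumulative row end per group.
    outputs = []
    start = 0
    for g in range(w.shape[0]):
        end = int(offs[g])
        outputs.append(x[start:end] @ w[g].t())  # (M_g, N)
        start = end
    return torch.cat(outputs, dim=0)  # (sum(M_g), N)

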
@register_quantize_op
class MXFP8GroupedGemm2d3d(QuantizeOpBase):
    """
    MXFP8 grouped GEMM with 2D inputs and 3D weights.
    """

    def preprocess(self, x, w):
        assert isinstance(x, list)
        assert isinstance(w, list)
        x = torch.cat(x, dim=0).contiguous()  # (G * M, K)
        w = torch.stack(w, dim=0).contiguous()  # (G, N, K)
        return x, w

    def quantize(self, x, w):
        block_size = 32
        G, N, K = w.shape
        total_M = x.shape[0]
        group_size = total_M // G
        input_group_end_offsets = torch.arange(
            group_size, total_M + 1, group_size, dtype=torch.int32, device=x.device
        )

        # For each constituent 2D subtensor in the 3D weights, quantize and convert
        # the scale to blocked format separately, as each is used for an independent
        # GEMM within the grouped GEMM.
        wq_list = []
        w_scale_list = []
        for i in range(G):
            w_scale, wq = to_mxfp8(w[i])
            w_scale = _to_blocked(w_scale)
            wq_list.append(wq)
            w_scale_list.append(w_scale)
        wq = torch.stack(wq_list, dim=0).contiguous()
        w_scale = torch.stack(w_scale_list, dim=0).contiguous()

        # For each group along `total_M` in the 2D tensor, quantize and convert
        # the scale to blocked format separately, as each is used for an independent
        # GEMM within the grouped GEMM.
        xq_list = []
        x_scale_list = []
        for i in range(G):
            prev_group_end = 0 if i == 0 else input_group_end_offsets[i - 1]
            curr_group_end = input_group_end_offsets[i]
            group_size = curr_group_end - prev_group_end
            if group_size > 0:
                x_slice = x[prev_group_end:curr_group_end, :]
                x_scale, xq = to_mxfp8(x_slice)
                x_scale = _to_blocked(x_scale)
                xq_list.append(xq)
                x_scale_list.append(x_scale)
        xq = torch.cat(xq_list, dim=0).contiguous()
        x_scale = torch.cat(x_scale_list, dim=0).contiguous()
        x_scale = x_scale.reshape(-1, K // block_size)
        xq = xq.view(-1, xq.shape[-1])
        return xq, wq, x_scale, w_scale, input_group_end_offsets

    def compute(self, xq, wq, x_scale, w_scale, input_group_end_offsets):
        return torch.ops.fbgemm.mx8mx8bf16_grouped_mm(
            xq,
            wq.transpose(-2, -1),
            x_scale,
            w_scale,
            input_group_end_offsets,
        )

    def quantize_and_compute(self, x, w):
        xq, wq, x_scale, w_scale, input_group_end_offsets = self.quantize(x, w)
        return self.compute(
            xq,
            wq,
            x_scale,
            w_scale,
            input_group_end_offsets,
        )

    @property
    def name(self) -> str:
        return "cutlass_mx8mx8bf16_grouped_mm_2d_3d"

    @property
    def hip(self) -> bool:
        return False

    @property
    def cuda(self) -> bool:
        return True


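# Back-of-the-envelope check of the MXFP8 blocked-scale shapes for a K-grouped
# (2D/2D) case like the op defined below: to_mxfp8 yields one scale per
# 32-element block along K (block_size = 32 above), and _to_blocked pads each
# per-group scale matrix to (round_up(rows, 128), round_up(cols, 4)), per the
# shape comments in that op. The concrete M, N, K, G values are made up; pure
# shape arithmetic, no GPU required.
def _example_mxfp8_blocked_scale_shapes(M=96, N=256, K=512, G=4):
    k_group = K // G
    scale_cols = k_group // 32  # one scale per 32-wide block
    x_scale_blocked = (round_up(M, 128), round_up(scale_cols, 4))
    w_scale_blocked = (round_up(N, 128), round_up(scale_cols, 4))
    return x_scale_blocked, w_scale_blocked

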
@register_quantize_op
class MXFP8GroupedGemm2d2d(QuantizeOpBase):
    """
    MXFP8 grouped GEMM with 2D inputs and 2D weights.
    """

    def preprocess(self, x, w):
        assert isinstance(x, list)
        assert isinstance(w, list)
        G = len(x)
        x = torch.cat(x, dim=1).contiguous()  # (M, total_K)
        w = torch.cat(w, dim=1).contiguous()  # (N, total_K)
        return x, w, G

    def quantize(self, x, w, G):
        # Simulate 2d-2d grouped gemm in backward pass `grad_weight = grad_output_t @ input`,
        # where we use "K" as the contracting dim which has "G" groups.
        M, total_K = x.shape
        N, _ = w.shape
        group_size = total_K // G
        input_group_end_offsets = torch.arange(
            group_size, total_K + 1, group_size, dtype=torch.int32, device=x.device
        )

        # Quantize each K group and convert its scales to blocked format.
        x_list = []
        w_list = []
        x_blocked_scale_list = []
        w_blocked_scale_list = []

        for group_idx in range(G):
            # to_mxfp8 per group
            prev_group_end_offset = (
                0 if group_idx == 0 else input_group_end_offsets[group_idx - 1]
            )
            curr_group_end_offset = input_group_end_offsets[group_idx]
            group_size = curr_group_end_offset - prev_group_end_offset
            if group_size > 0:
                x_slice = x[
                    :, prev_group_end_offset:curr_group_end_offset
                ].contiguous()  # (M, K_group)
                w_slice = w[
                    :, prev_group_end_offset:curr_group_end_offset
                ].contiguous()  # (N, K_group)
                x_scale_slice, xq_slice = to_mxfp8(
                    x_slice
                )  # scale shape -> (M, K_group // 32)
                w_scale_slice, wq_slice = to_mxfp8(
                    w_slice
                )  # scale shape -> (N, K_group // 32)
                x_list.append(xq_slice)
                w_list.append(wq_slice)

                # Convert scales to blocked format.
                x_scale_slice_blocked = _to_blocked(
                    x_scale_slice
                )  # (round_up(M, 128), round_up(K_group//32, 4))
                w_scale_slice_blocked = _to_blocked(
                    w_scale_slice
                )  # (round_up(N, 128), round_up(K_group//32, 4))
                x_blocked_scale_list.append(x_scale_slice_blocked)
                w_blocked_scale_list.append(w_scale_slice_blocked)

        # Assemble the full XQ and WQ
        xq = torch.cat(x_list, dim=1).contiguous()
        wq = torch.cat(w_list, dim=1).contiguous()

        # Combine all XQ groups' blocked scales into one tensor.
        x_blocked_scales = torch.cat(x_blocked_scale_list, dim=0)
        M_rounded = round_up(M, 128)
        x_blocked_scales = x_blocked_scales.reshape(M_rounded, -1)

        # Combine all WQ groups' blocked scales into one tensor.
        w_blocked_scales = torch.cat(w_blocked_scale_list, dim=0)
        N_rounded = round_up(N, 128)
        w_blocked_scales = w_blocked_scales.reshape(N_rounded, -1)
        return xq, wq, x_blocked_scales, w_blocked_scales, input_group_end_offsets

    def compute(self, xq, wq, x_scale, w_scale, input_group_end_offsets):
        return torch.ops.fbgemm.mx8mx8bf16_grouped_mm(
            xq,
            wq.transpose(-2, -1),
            x_scale,
            w_scale,
            input_group_end_offsets,
        )

    def quantize_and_compute(self, x, w, G):
        xq, wq, x_scale, w_scale, input_group_end_offsets = self.quantize(x, w, G)
        return self.compute(
            xq,
            wq,
            x_scale,
            w_scale,
            input_group_end_offsets,
        )

    @property
    def name(self) -> str:
        return "cutlass_mx8mx8bf16_grouped_mm_2d_2d"

    @property
    def hip(self) -> bool:
        return False

    @property
    def cuda(self) -> bool:
        return True


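# Eager reference for the K-grouped (2D/2D) layout benchmarked above: the
# contracting dimension is split at the cumulative `input_group_end_offsets`,
# and each K slice yields its own (M, N) product, mirroring the
# grad_weight = grad_output_t @ input backward pass described in the comments.
# Stacking the per-group products into a (G, M, N) tensor is an assumption of
# this sketch, not a statement about the fused kernel's exact output layout.
def _example_grouped_2d_2d_reference(x, w, k_offsets):
    # x: (M, total_K), w: (N, total_K), k_offsets: cumulative K end per group.
    outputs = []
    start = 0
    for end in k_offsets.tolist():
        outputs.append(x[:, start:end] @ w[:, start:end].t())  # (M, N)
        start = end
    return torch.stack(outputs, dim=0)  # (G, M, N)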