fbgemm-gpu-genai-nightly 2025.12.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of fbgemm-gpu-genai-nightly might be problematic.

Files changed (127)
  1. fbgemm_gpu/__init__.py +186 -0
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
  4. fbgemm_gpu/config/__init__.py +9 -0
  5. fbgemm_gpu/config/feature_list.py +88 -0
  6. fbgemm_gpu/docs/__init__.py +18 -0
  7. fbgemm_gpu/docs/common.py +9 -0
  8. fbgemm_gpu/docs/examples.py +73 -0
  9. fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
  10. fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
  11. fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
  12. fbgemm_gpu/docs/quantize_ops.py +41 -0
  13. fbgemm_gpu/docs/sparse_ops.py +616 -0
  14. fbgemm_gpu/docs/target.genai.json.py +6 -0
  15. fbgemm_gpu/enums.py +24 -0
  16. fbgemm_gpu/experimental/example/__init__.py +29 -0
  17. fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
  18. fbgemm_gpu/experimental/example/utils.py +20 -0
  19. fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
  20. fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
  21. fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
  22. fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
  23. fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
  24. fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
  25. fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
  26. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
  27. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
  28. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
  29. fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
  30. fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
  31. fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
  32. fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
  33. fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
  34. fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
  35. fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
  36. fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
  37. fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
  38. fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
  39. fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
  40. fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
  41. fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
  42. fbgemm_gpu/fbgemm.so +0 -0
  43. fbgemm_gpu/metrics.py +160 -0
  44. fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
  45. fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
  46. fbgemm_gpu/quantize/__init__.py +43 -0
  47. fbgemm_gpu/quantize/quantize_ops.py +64 -0
  48. fbgemm_gpu/quantize_comm.py +315 -0
  49. fbgemm_gpu/quantize_utils.py +246 -0
  50. fbgemm_gpu/runtime_monitor.py +237 -0
  51. fbgemm_gpu/sll/__init__.py +189 -0
  52. fbgemm_gpu/sll/cpu/__init__.py +80 -0
  53. fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
  54. fbgemm_gpu/sll/meta/__init__.py +35 -0
  55. fbgemm_gpu/sll/meta/meta_sll.py +337 -0
  56. fbgemm_gpu/sll/triton/__init__.py +127 -0
  57. fbgemm_gpu/sll/triton/common.py +38 -0
  58. fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
  59. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
  60. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
  61. fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
  62. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
  63. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
  64. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
  65. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
  66. fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
  67. fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
  68. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
  69. fbgemm_gpu/sparse_ops.py +1455 -0
  70. fbgemm_gpu/split_embedding_configs.py +452 -0
  71. fbgemm_gpu/split_embedding_inference_converter.py +175 -0
  72. fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
  73. fbgemm_gpu/split_embedding_utils.py +29 -0
  74. fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
  75. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
  76. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
  77. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
  78. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
  79. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
  80. fbgemm_gpu/tbe/__init__.py +6 -0
  81. fbgemm_gpu/tbe/bench/__init__.py +55 -0
  82. fbgemm_gpu/tbe/bench/bench_config.py +156 -0
  83. fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
  84. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
  85. fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
  86. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
  87. fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
  88. fbgemm_gpu/tbe/bench/reporter.py +35 -0
  89. fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
  90. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
  91. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
  92. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
  93. fbgemm_gpu/tbe/bench/utils.py +48 -0
  94. fbgemm_gpu/tbe/cache/__init__.py +11 -0
  95. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  96. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
  97. fbgemm_gpu/tbe/ssd/__init__.py +15 -0
  98. fbgemm_gpu/tbe/ssd/common.py +46 -0
  99. fbgemm_gpu/tbe/ssd/inference.py +586 -0
  100. fbgemm_gpu/tbe/ssd/training.py +4908 -0
  101. fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
  102. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
  103. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  104. fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
  105. fbgemm_gpu/tbe/utils/__init__.py +13 -0
  106. fbgemm_gpu/tbe/utils/common.py +42 -0
  107. fbgemm_gpu/tbe/utils/offsets.py +65 -0
  108. fbgemm_gpu/tbe/utils/quantize.py +251 -0
  109. fbgemm_gpu/tbe/utils/requests.py +556 -0
  110. fbgemm_gpu/tbe_input_multiplexer.py +108 -0
  111. fbgemm_gpu/triton/__init__.py +22 -0
  112. fbgemm_gpu/triton/common.py +77 -0
  113. fbgemm_gpu/triton/jagged/__init__.py +8 -0
  114. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
  115. fbgemm_gpu/triton/quantize.py +647 -0
  116. fbgemm_gpu/triton/quantize_ref.py +286 -0
  117. fbgemm_gpu/utils/__init__.py +11 -0
  118. fbgemm_gpu/utils/filestore.py +211 -0
  119. fbgemm_gpu/utils/loader.py +36 -0
  120. fbgemm_gpu/utils/torch_library.py +132 -0
  121. fbgemm_gpu/uvm.py +40 -0
  122. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
  123. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
  124. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
  125. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
  126. list_versions/__init__.py +12 -0
  127. list_versions/cli_run.py +163 -0
fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
@@ -0,0 +1,3483 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Keep a registry of all quantize operators.
+import abc
+
+import fbgemm_gpu.experimental.gen_ai  # noqa: F401
+import numpy as np
+
+import torch
+import triton  # @manual=//triton:triton
+
+from fbgemm_gpu.experimental.gemm.triton_gemm.fp4_quantize import (
+    _to_blocked,
+    calculate_group_max,
+    get_nvfp4_global_scales_naive,
+    mega_fp4_pack,
+    mega_fp4_quantize_kernel,
+    mega_fp4_unpack,
+    quantize_nvfp4_naive,
+    triton_quantize_mx4_unpack,
+    triton_scale_nvfp4_quant,
+    triton_scale_nvfp4_quant_rms,
+    triton_scale_nvfp4_quant_silu,
+)
+
+from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
+    get_fp8_constants,
+    matmul_fp8_block,
+    matmul_fp8_row,
+    quantize_fp8_block,
+    quantize_fp8_group,
+    quantize_fp8_row,
+    scale_fp8_row,
+    to_mxfp8,
+    triton_quantize_fp8_row,
+)
+from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import (
+    grouped_gemm,
+    grouped_gemm_fp8_rowwise,
+)
+from fbgemm_gpu.experimental.gen_ai.quantize import (
+    ck_preshuffle,
+    quantize_int4_preshuffle,
+)
+
+try:
+    from gen_ai.llm_inference.fb.llm.kernel.rms_norm import rms_norm
+    from gen_ai.llm_inference.fb.llm.kernel.silu_mul import silu_mul
+except ImportError:
+    # These are only used for some experiments; the quantize ops do not rely on them, so it is fine to skip.
+    pass
+
+try:
+    from tinygemm.utils import group_quantize_tensor
+
+    if torch.cuda.is_available() and torch.version.cuda:
+        torch.ops.load_library("//tinygemm:tinygemm")
+    TINYGEMM_ENABLED = True
+except ImportError:
+    TINYGEMM_ENABLED = False
+
+# Marlin is currently only supported internally at Meta.
+try:
+    from marlin.quantize import marlin_quantize
+
+    torch.ops.load_library("//ai_codesign/gen_ai/marlin:marlin_ops")
+    MARLIN_ENABLED = True
+except ImportError:
+    MARLIN_ENABLED = False
+
+try:
+    from deep_gemm import (
+        gemm_fp8_fp8_bf16_nt,
+        get_col_major_tma_aligned_tensor,
+        m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
+        m_grouped_gemm_fp8_fp8_bf16_nt_masked,
+    )
+
+    DEEPGEMM_ENABLED = True
+except ImportError:
+    DEEPGEMM_ENABLED = False
+
+
+# Machete is also only supported internally at Meta for now.
+try:
+    from machete.machete import machete_gemm
+    from machete.quantize import machete_quantize_and_pack
+
+    MACHETE_ENABLED = True
+except ImportError:
+    MACHETE_ENABLED = False
+
+
+quantize_op_registry = []
+
+
+def round_up(x: int, y: int) -> int:
+    return ((x + y - 1) // y) * y
+
+
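As a quick aside, round_up is the usual ceiling-to-a-multiple helper, used later in this file to pad sizes to tile boundaries; a tiny illustrative check (values invented, not from the source):

    assert round_up(300, 128) == 384  # pad 300 rows up to the next 128-row tile boundary
    assert round_up(256, 128) == 256  # already aligned, unchanged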
+class QuantizeOpBase(metaclass=abc.ABCMeta):
+    """Helper abstract class to define expected methods of quantize ops."""
+
+    @abc.abstractmethod
+    def quantize(self, *args):
+        """Function which quantizes inputs."""
+        pass
+
+    @abc.abstractmethod
+    def compute(self, *args, **kwargs):
+        """Function which performs main compute operation."""
+        pass
+
+    @abc.abstractmethod
+    def quantize_and_compute(self, *args, **kwargs):
+        """Function which quantizes inputs and performs main compute operation."""
+        pass
+
+    def preprocess(self, *args):
+        """Preprocess inputs before benchmarking. These outputs will be passed to quantize."""
+        return args
+
+    def bench_with_rotating_buffer(self, fn, args, use_cuda_graph: bool = True):
+        import copy
+        import pickle
+
+        # torch.cuda.get_device_properties does not have L2/L3 cache size,
+        # so hard code an overapproximation of L2/L3 cache size to ensure L2/L3 cache flush
+        total_buffer_size = 512 * 1024 * 1024
+
+        # Use pickle to serialize the model inputs and estimate their total size
+        input_sizes = len(pickle.dumps(args))
+
+        # Make at least one copy of the inputs
+        copy_cnt = total_buffer_size // input_sizes
+        if copy_cnt == 0:
+            copy_cnt = 1
+
+        args_list = [args]
+        for _ in range(copy_cnt):
+            args_list.append(copy.deepcopy(args))
+
+        torch.cuda.synchronize()
+
+        def rotating_buffer_fn(fn, args_list, copy_cnt):
+            for i in range(copy_cnt):
+                fn(*(args_list[i]))
+
+        if use_cuda_graph:
+            with torch.cuda.stream(torch.cuda.Stream()):
+                # A rotating_buffer_fn contains multiple runs of the fn to benchmark,
+                # so divide time accordingly
+                return triton.testing.do_bench_cudagraph(
+                    lambda: rotating_buffer_fn(self.compute, args_list, copy_cnt + 1),
+                    rep=200,
+                ) / (copy_cnt + 1)
+        else:
+            return triton.testing.do_bench(
+                lambda: rotating_buffer_fn(self.compute, args_list, copy_cnt + 1),
+                rep=200,
+            ) / (copy_cnt + 1)
+
+    def benchmark(
+        self,
+        *args,
+        bench_quantize: bool = False,
+        use_rotating_buffer_bench: bool = False,
+        use_cuda_graph: bool = True,
+        **kwargs,
+    ) -> float:
+        """Benchmark runtime of this operator."""
+        if bench_quantize:
+            if use_cuda_graph:
+                with torch.cuda.stream(torch.cuda.Stream()):
+                    t = triton.testing.do_bench_cudagraph(
+                        lambda: self.quantize_and_compute(*args, **kwargs), rep=200
+                    )
+            else:
+                t = triton.testing.do_bench(
+                    lambda: self.quantize_and_compute(*args, **kwargs)
+                )
+        else:
+            if use_rotating_buffer_bench:
+                t = self.bench_with_rotating_buffer(self.compute, args, use_cuda_graph)
+            else:
+                if use_cuda_graph:
+                    with torch.cuda.stream(torch.cuda.Stream()):
+                        t = triton.testing.do_bench_cudagraph(
+                            lambda: self.compute(*args, **kwargs), rep=200
+                        )
+                else:
+                    t = triton.testing.do_bench(lambda: self.compute(*args, **kwargs))
+        return t
+
+    @abc.abstractproperty
+    def name(self) -> str:
+        """Name of the operator."""
+        pass
+
+    @abc.abstractproperty
+    def hip(self) -> bool:
+        """Whether this operator supports AMD or not."""
+        pass
+
+    @abc.abstractproperty
+    def cuda(self) -> bool:
+        """Whether this operator supports Nvidia or not."""
+        pass
+
+    @property
+    def supported(self) -> bool:
+        """Whether this op will run on the current device."""
+        if torch.version.hip is not None:
+            return self.hip
+        elif torch.version.cuda is not None:
+            return self.cuda
+        else:
+            return False
+
+
+def register_quantize_op(op):
+    """Decorator function for assembling all quantize ops."""
+    quantize_op_registry.append(op())
+    return op
+
+
+def get_quantize_ops() -> list[QuantizeOpBase]:
+    """Get all registered quantize ops."""
+    return quantize_op_registry
+
+
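To show how the pieces above fit together, here is a minimal, hypothetical sketch of defining, registering, and benchmarking an op. MyBF16Gemm is invented for illustration and is not part of this file; the companion quantize_bench.py presumably drives the registered ops in a similar way.

    @register_quantize_op
    class MyBF16Gemm(QuantizeOpBase):
        """Hypothetical example op: no real quantization, just a bf16 matmul."""

        def quantize(self, x, w):
            return x.bfloat16(), w.bfloat16().t()

        def compute(self, x, wt):
            return torch.matmul(x, wt)

        def quantize_and_compute(self, x, w):
            return self.compute(*self.quantize(x, w))

        @property
        def name(self) -> str:
            return "my_bf16_gemm"

        @property
        def hip(self) -> bool:
            return True

        @property
        def cuda(self) -> bool:
            return True


    # Typical driver loop over the registry.
    x = torch.randn(128, 256, device="cuda")
    w = torch.randn(512, 256, device="cuda")  # weights are (N, K) by convention here
    for op in get_quantize_ops():
        if op.supported and op.name == "my_bf16_gemm":
            args = op.preprocess(x, w)
            quantized = op.quantize(*args)
            ms = op.benchmark(*quantized, use_cuda_graph=False)
            print(f"{op.name}: {ms:.3f} ms")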
+@register_quantize_op
+class BF16Baseline(QuantizeOpBase):
+    """
+    Baseline BF16 matmul.
+    """
+
+    def quantize(self, x, w):
+        if isinstance(x, list):
+            x = [i.bfloat16() for i in x]
+            w = [torch.transpose(i, -2, -1).bfloat16() for i in w]
+        else:
+            x = x.bfloat16()
+            w = torch.transpose(w, -2, -1).bfloat16()
+        return x, w
+
+    def compute(self, x, w):
+        # Handle both grouped and standard gemm.
+        if isinstance(x, list):
+            output = []
+            for i in range(len(x)):
+                output.append(torch.matmul(x[i], w[i]))
+            return output
+        return torch.matmul(x, w)
+
+    def quantize_and_compute(self, x, w):
+        return self.compute(*self.quantize(x, w))
+
+    @property
+    def name(self) -> str:
+        return "bf16_baseline"
+
+    @property
+    def hip(self) -> bool:
+        return True
+
+    @property
+    def cuda(self) -> bool:
+        return True
+
+
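Note the convention used by this baseline (and echoed by most ops below): activations are (M, K), weights are handed in as (N, K), and quantize() transposes the weight so compute() is a plain matmul. A small shape check under that assumption:

    op = BF16Baseline()
    x = torch.randn(64, 128, device="cuda")   # (M, K)
    w = torch.randn(256, 128, device="cuda")  # (N, K)
    xb, wb = op.quantize(x, w)                # wb is (K, N) after the transpose
    out = op.compute(xb, wb)
    assert out.shape == (64, 256) and out.dtype == torch.bfloat16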
+@register_quantize_op
+class ScaledMMBaseline(QuantizeOpBase):
+    """
+    Reference FP8 matmul implemented in native torch with cublas or hipblas.
+    """
+
+    def __init__(self):
+        self.fp8_dtype, _, _, _ = get_fp8_constants()
+        self.E4M3_MAX_POS: float = torch.finfo(self.fp8_dtype).max
+        self.E5M2_MAX_POS: float = torch.finfo(torch.float8_e5m2).max
+        self.FP16_MAX_POS: float = torch.finfo(torch.float16).max
+        self.EPS: float = 1e-12
+        self.fast_accum = True
+
+    def _amax_to_scale(
+        self, amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
+    ) -> torch.Tensor:
+        # Cast to fp32 so the scale is computed at full precision.
+        amax = amax.float()
+        if float8_dtype == self.fp8_dtype:
+            # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
+            res = self.E4M3_MAX_POS / torch.clamp(amax, min=self.EPS)
+        else:  # e5m2
+            # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
+            res = self.E5M2_MAX_POS / torch.clamp(amax, min=self.EPS)
+
+        # pyre-fixme[7]: Expected `Tensor` but got `Union[float, Tensor]`.
+        return res
+
+    def _to_fp8_saturated(
+        self, x: torch.Tensor, float8_dtype: torch.dtype
+    ) -> torch.Tensor:
+        if float8_dtype == torch.float8_e4m3fn:
+            x = x.clamp(min=-1 * self.E4M3_MAX_POS, max=self.E4M3_MAX_POS)
+        else:
+            x = x.clamp(min=-1 * self.E5M2_MAX_POS, max=self.E5M2_MAX_POS)
+        return x.to(float8_dtype)
+
+    def _quantize_tensor(self, x):
+        x_amax = torch.max(torch.abs(x))
+        scale = self._amax_to_scale(x_amax, self.fp8_dtype, x.dtype)
+        scaled_x = self._to_fp8_saturated(x * scale, self.fp8_dtype)
+        x_inverse_scale = scale.reciprocal()
+        return scaled_x, x_inverse_scale
+
+    def quantize(self, x, w):
+        xq, x_scale = self._quantize_tensor(x)
+        wq, w_scale = self._quantize_tensor(w.t())
+        return xq, wq, x_scale, w_scale
+
+    def compute(self, xq, wq, x_scale, w_scale):
+        output = torch._scaled_mm(
+            xq,
+            wq,
+            bias=None,
+            out_dtype=torch.bfloat16,
+            scale_a=x_scale,
+            scale_b=w_scale,
+            scale_result=None,
+            use_fast_accum=self.fast_accum,
+        )
+        return output
+
+    def quantize_and_compute(self, x, w):
+        return self.compute(*self.quantize(x, w))
+
+    @property
+    def name(self) -> str:
+        return "scaled_mm"
+
+    @property
+    def hip(self) -> bool:
+        return True
+
+    @property
+    def cuda(self) -> bool:
+        return True
+
+
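The per-tensor recipe above is standard amax scaling: scale = MAX_POS / amax, quantize as clamp(x * scale) cast to FP8, and pass the reciprocal to torch._scaled_mm as the dequantization scale. A rough standalone sketch of the same math (assumes an FP8-capable GPU; written independently of the class above):

    x = torch.randn(32, 64, device="cuda")
    max_pos = torch.finfo(torch.float8_e4m3fn).max
    scale = max_pos / torch.clamp(x.abs().max().float(), min=1e-12)
    xq = (x * scale).clamp(-max_pos, max_pos).to(torch.float8_e4m3fn)
    x_inv_scale = scale.reciprocal()                # what _scaled_mm consumes as scale_a
    x_approx = xq.to(torch.float32) * x_inv_scale   # dequantized, close to x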
+@register_quantize_op
+class ScaledMMRowwise(QuantizeOpBase):
+    def __init__(self):
+        self.fast_accum = True
+        self.torch_compile = False
+
+    def quantize(self, x, w):
+        xq, x_scale = quantize_fp8_row(x)
+        wq, w_scale = quantize_fp8_row(w)
+        return xq, wq.t(), x_scale.unsqueeze(1), w_scale.unsqueeze(0)
+
+    def compute(self, xq, wq, x_scale, w_scale):
+        if self.torch_compile:
+            f = torch.compile(
+                torch._scaled_mm,
+                options={
+                    "max_autotune": True,
+                    "max_autotune_gemm_backends": "TRITON,CK,CUTLASS,ATEN",
+                },
+            )
+        else:
+            f = torch._scaled_mm
+
+        return f(
+            xq,
+            wq,
+            bias=None,
+            out_dtype=torch.bfloat16,
+            scale_a=x_scale,
+            scale_b=w_scale,
+            scale_result=None,
+            use_fast_accum=self.fast_accum,
+        )
+
+    def quantize_and_compute(self, x, w):
+        return self.compute(*self.quantize(x, w))
+
+    @property
+    def name(self) -> str:
+        return "scaled_mm_rowwise"
+
+    @property
+    def hip(self) -> bool:
+        return True
+
+    @property
+    def cuda(self) -> bool:
+        return True
+
+
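The unsqueeze calls in quantize() are doing real work: for rowwise scaling, torch._scaled_mm expects scale_a shaped (M, 1) for the activation and scale_b shaped (1, N) for the already-transposed weight. A shape-only sketch:

    x = torch.randn(64, 128, device="cuda", dtype=torch.bfloat16)   # (M, K)
    w = torch.randn(256, 128, device="cuda", dtype=torch.bfloat16)  # (N, K)
    xq, x_scale = quantize_fp8_row(x)  # x_scale: (M,)
    wq, w_scale = quantize_fp8_row(w)  # w_scale: (N,)
    # Passed to torch._scaled_mm as: xq (M, K), wq.t() (K, N),
    # scale_a = x_scale.unsqueeze(1) -> (M, 1), scale_b = w_scale.unsqueeze(0) -> (1, N).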
+@register_quantize_op
+class FP8TensorwiseGemm(QuantizeOpBase):
+    """
+    FP8 matmul with tensorwise scaling.
+    """
+
+    def quantize(self, x, w):
+        # Quantize both input tensors.
+        xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
+        wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
+        return xq, wq, x_scale, w_scale
+
+    def compute(self, xq, wq, x_scale, w_scale):
+        return torch.ops.fbgemm.f8f8bf16(xq, wq, x_scale * w_scale)
+
+    def quantize_and_compute(self, x, w):
+        xq, wq, x_scale, w_scale = self.quantize(x, w)
+        return self.compute(xq, wq, x_scale, w_scale)
+
+    @property
+    def name(self) -> str:
+        return "cutlass_tensorwise"
+
+    @property
+    def hip(self) -> bool:
+        # Need to add support for better quantize kernel.
+        # Also may have an issue with cuda graphs.
+        return False
+
+    @property
+    def cuda(self) -> bool:
+        return True
+
+
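Because both scales here are per-tensor scalars, they fold into a single output scale, which is why compute() passes x_scale * w_scale to f8f8bf16. A hedged usage sketch of that pair of ops (availability depends on the installed fbgemm_gpu_genai build):

    x = torch.randn(64, 128, device="cuda", dtype=torch.bfloat16)
    w = torch.randn(256, 128, device="cuda", dtype=torch.bfloat16)
    xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
    wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
    y = torch.ops.fbgemm.f8f8bf16(xq, wq, x_scale * w_scale)  # (64, 256) bf16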
439
+ @register_quantize_op
440
+ class BF16OSSFastGemv(QuantizeOpBase):
441
+ """
442
+ BF16 OSS fast gemv kernel.
443
+ """
444
+
445
+ def quantize(self, x, w):
446
+ # dummy quantize
447
+ return x, w
448
+
449
+ def compute(self, x, w):
450
+ out = torch.ops.fbgemm.bf16_fast_gemv(x, w)
451
+ return out
452
+
453
+ def quantize_and_compute(self, x, w):
454
+ x, w = self.quantize(x, w)
455
+ return self.compute(x, w)
456
+
457
+ @property
458
+ def name(self) -> str:
459
+ return "bf16_oss_fast_gemv"
460
+
461
+ @property
462
+ def hip(self) -> bool:
463
+ # This implementation is specific to cublas.
464
+ return False
465
+
466
+ @property
467
+ def cuda(self) -> bool:
468
+ return True
469
+
470
+
471
+ @register_quantize_op
472
+ class BF16Fp8OSSFastGemv(QuantizeOpBase):
473
+ """
474
+ BF16FP8 OSS fast gemv kernel.
475
+ """
476
+
477
+ def quantize(self, x, w):
478
+ wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
479
+ return x, wq, w_scale
480
+
481
+ def compute(self, x, wq, w_scale):
482
+ out = torch.ops.fbgemm.bf16fp8bf16_fast_gemv(x, wq, w_scale)
483
+ return out
484
+
485
+ def quantize_and_compute(self, x, w):
486
+ x, wq, w_scale = self.quantize(x, w)
487
+ return self.compute(x, wq, w_scale)
488
+
489
+ @property
490
+ def name(self) -> str:
491
+ return "bf16fp8_oss_fast_gemv"
492
+
493
+ @property
494
+ def hip(self) -> bool:
495
+ # This implementation is specific to cublas.
496
+ return False
497
+
498
+ @property
499
+ def cuda(self) -> bool:
500
+ return True
501
+
502
+
503
+ @register_quantize_op
504
+ class Fp8Fp8OSSFastGemv(QuantizeOpBase):
505
+ """
506
+ FP8FP8 OSS fast gemv kernel.
507
+ """
508
+
509
+ def quantize(self, x, w):
510
+ # rowwise quantize
511
+ xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x)
512
+ wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
513
+ return xq, wq, x_scale, w_scale
514
+
515
+ def compute(self, xq, wq, x_scale, w_scale):
516
+ out = torch.ops.fbgemm.fp8fp8bf16_fast_gemv(
517
+ xq, wq, x_scale, w_scale, is_batched=False
518
+ )
519
+ return out
520
+
521
+ def quantize_and_compute(self, x, w):
522
+ xq, wq, x_scale, w_scale = self.quantize(x, w)
523
+ return self.compute(xq, wq, x_scale, w_scale)
524
+
525
+ @property
526
+ def name(self) -> str:
527
+ return "fp8fp8_oss_fast_gemv"
528
+
529
+ @property
530
+ def hip(self) -> bool:
531
+ # This implementation is specific to cublas.
532
+ return False
533
+
534
+ @property
535
+ def cuda(self) -> bool:
536
+ return True
537
+
538
+
539
+ @register_quantize_op
540
+ class Fp8OSSFastGemvBatched(QuantizeOpBase):
541
+ """
542
+ Batched fp8 fast gemv kernel
543
+ """
544
+
545
+ def quantize(self, x, w):
546
+ # rowwise quantize
547
+ xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x)
548
+ wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
549
+ return xq, wq, x_scale, w_scale
550
+
551
+ def compute(self, xq, wq, x_scale, w_scale):
552
+ out = torch.ops.fbgemm.fp8fp8bf16_fast_gemv(
553
+ xq, wq, x_scale, w_scale, is_batched=True
554
+ )
555
+ return out
556
+
557
+ def quantize_and_compute(self, x, w):
558
+ xq, wq, x_scale, w_scale = self.quantize(x, w)
559
+ return self.compute(xq, wq, x_scale, w_scale)
560
+
561
+ @property
562
+ def name(self) -> str:
563
+ return "fp8fp8_oss_fast_gemv_batched"
564
+
565
+ @property
566
+ def hip(self) -> bool:
567
+ # This implementation is specific to cublas.
568
+ return False
569
+
570
+ @property
571
+ def cuda(self) -> bool:
572
+ return True
573
+
574
+
575
+ @register_quantize_op
576
+ class FP8CublasRowwiseGemm(QuantizeOpBase):
577
+ """
578
+ FP8 cublas matmul with rowwise scaling.
579
+ """
580
+
581
+ def quantize(self, x, w):
582
+ # Quantize both input tensors.
583
+ xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x)
584
+ wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
585
+ return xq, wq, x_scale, w_scale
586
+
587
+ def compute(self, xq, wq, x_scale, w_scale):
588
+ out = torch.ops.fbgemm.f8f8bf16_cublas(xq, wq)
589
+ scaled_out = scale_fp8_row(out, x_scale, w_scale)
590
+ return scaled_out
591
+
592
+ def quantize_and_compute(self, x, w):
593
+ xq, wq, x_scale, w_scale = self.quantize(x, w)
594
+ return self.compute(xq, wq, x_scale, w_scale)
595
+
596
+ @property
597
+ def name(self) -> str:
598
+ return "cublas_rowwise"
599
+
600
+ @property
601
+ def hip(self) -> bool:
602
+ # This implementation is specific to cublas.
603
+ return False
604
+
605
+ @property
606
+ def cuda(self) -> bool:
607
+ return True
608
+
609
+
610
+ @register_quantize_op
611
+ class FP8CublasTensorwiseGemm(QuantizeOpBase):
612
+ """
613
+ FP8 cublas matmul with tensorwise scaling.
614
+ """
615
+
616
+ def quantize(self, x, w):
617
+ # Quantize both input tensors.
618
+ xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
619
+ wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
620
+ return xq, wq, x_scale, w_scale
621
+
622
+ def compute(self, xq, wq, x_scale, w_scale):
623
+ return torch.ops.fbgemm.f8f8bf16_cublas(xq, wq, x_scale * w_scale)
624
+
625
+ def quantize_and_compute(self, x, w):
626
+ xq, wq, x_scale, w_scale = self.quantize(x, w)
627
+ return self.compute(xq, wq, x_scale, w_scale)
628
+
629
+ @property
630
+ def name(self) -> str:
631
+ return "cublas_tensorwise"
632
+
633
+ @property
634
+ def hip(self) -> bool:
635
+ # This implementation is specific to cublas.
636
+ return False
637
+
638
+ @property
639
+ def cuda(self) -> bool:
640
+ return True
641
+
642
+
643
+ @register_quantize_op
644
+ class FP8RowwiseGemm(QuantizeOpBase):
645
+ """
646
+ FP8 matmul with rowwise scaling.
647
+ """
648
+
649
+ def __init__(self):
650
+ self.fast_accum = True
651
+ self.gemm_op = torch.ops.fbgemm.f8f8bf16_rowwise
652
+ self.quantize_op = quantize_fp8_row
653
+
654
+ def preprocess(self, x, w):
655
+ # Prequantize weights.
656
+ if isinstance(w, (list, tuple)):
657
+ wq, w_scale = zip(*[self.quantize_op(i) for i in w])
658
+ else:
659
+ wq, w_scale = self.quantize_op(w)
660
+ if wq.dim() == 3:
661
+ w_scale = w_scale.view(wq.size(0), -1)
662
+ return x, wq, w_scale
663
+
664
+ def quantize(self, x, wq, w_scale):
665
+ # Quantize both input tensors.
666
+ # Handle both grouped and standard gemm.
667
+ if isinstance(x, (list, tuple)):
668
+ xq, x_scale = zip(*[self.quantize_op(i) for i in x])
669
+ else:
670
+ xq, x_scale = self.quantize_op(x)
671
+ # Set proper batch dimension shapes.
672
+ if xq.dim() == 3:
673
+ x_scale = x_scale.view(xq.size(0), -1)
674
+ return xq, wq, x_scale, w_scale
675
+
676
+ def compute(self, xq, wq, x_scale, w_scale):
677
+ # Handle group gemm if inputs are grouped.
678
+ if isinstance(xq, (list, tuple)):
679
+ output = []
680
+ for i in range(len(xq)):
681
+ output.append(
682
+ self.gemm_op(
683
+ xq[i],
684
+ wq[i],
685
+ x_scale[i],
686
+ w_scale[i],
687
+ use_fast_accum=self.fast_accum,
688
+ )
689
+ )
690
+ return output
691
+ # Unroll batched gemm if needed.
692
+ elif xq.dim() == 3 and wq.dim() == 3:
693
+ B, M, _ = xq.shape
694
+ _, N, _ = wq.shape
695
+ y = torch.empty((B, M, N), device=xq.device, dtype=torch.bfloat16)
696
+ for i in range(B):
697
+ y[i] = self.gemm_op(
698
+ xq[i], wq[i], x_scale[i], w_scale[i], use_fast_accum=self.fast_accum
699
+ )
700
+ return y
701
+ # Otherwise return normal gemm result.
702
+ return self.gemm_op(xq, wq, x_scale, w_scale, use_fast_accum=self.fast_accum)
703
+
704
+ def quantize_and_compute(self, x, wq, w_scale):
705
+ xq, wq, x_scale, w_scale = self.quantize(x, wq, w_scale)
706
+ return self.compute(xq, wq, x_scale, w_scale)
707
+
708
+ @property
709
+ def name(self) -> str:
710
+ if torch.version.cuda:
711
+ return "cutlass_rowwise"
712
+ else:
713
+ return "ck_rowwise"
714
+
715
+ @property
716
+ def hip(self) -> bool:
717
+ return True
718
+
719
+ @property
720
+ def cuda(self) -> bool:
721
+ return True
722
+
723
+
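For the rowwise paths such as FP8RowwiseGemm above, quantize_fp8_row returns one scale per row of its input, and the kernel rescales the FP8 product by x_scale[i] * w_scale[j]. A small sketch of that building block (shapes illustrative; op availability depends on the build):

    x = torch.randn(64, 128, device="cuda", dtype=torch.bfloat16)   # (M, K)
    w = torch.randn(256, 128, device="cuda", dtype=torch.bfloat16)  # (N, K)
    xq, x_scale = quantize_fp8_row(x)  # x_scale: (M,)
    wq, w_scale = quantize_fp8_row(w)  # w_scale: (N,)
    y = torch.ops.fbgemm.f8f8bf16_rowwise(xq, wq, x_scale, w_scale, use_fast_accum=True)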
724
+ @register_quantize_op
725
+ class FP8RowwisePreshuffleGemm(FP8RowwiseGemm):
726
+ """
727
+ FP8 matmul with rowwise scaling and preshuffling of input B.
728
+ """
729
+
730
+ def __init__(self):
731
+ self.fast_accum = True
732
+ if self.supported:
733
+ self.gemm_op = torch.ops.fbgemm.f8f8bf16_rowwise_preshuffle
734
+
735
+ def preprocess(self, x, w):
736
+ x, wq, w_scale = super().preprocess(x, w)
737
+ return x, ck_preshuffle(wq, 16), w_scale
738
+
739
+ @property
740
+ def name(self) -> str:
741
+ if torch.version.cuda:
742
+ return "cutlass_rowwise_preshuffle"
743
+ else:
744
+ return "ck_rowwise_preshuffle"
745
+
746
+ @property
747
+ def hip(self) -> bool:
748
+ return True
749
+
750
+ @property
751
+ def cuda(self) -> bool:
752
+ # Not yet supported on nvidia.
753
+ return False
754
+
755
+
756
+ @register_quantize_op
757
+ class FP8RowwiseGroupedGemm(QuantizeOpBase):
758
+ """
759
+ FP8 grouped matmul with rowwise scaling.
760
+ """
761
+
762
+ def preprocess(self, x, w):
763
+ # Apply sparsity to inputs if appropriate.
764
+ # First check if N and K are fixed.
765
+ m_values = [i.shape[0] for i in x]
766
+ n_values = [i.shape[0] for i in w]
767
+ k_values = [i.shape[1] for i in w]
768
+ # If so, do specialized version of initialization.
769
+ if len(np.unique(n_values)) == 1 and len(np.unique(k_values)) == 1:
770
+ m_values = [i.shape[0] for i in x]
771
+ # Inputs for fixed nk mode must be contiguous, however in the benchmark
772
+ # script they typically are not. Do a little special processing to make them
773
+ # work. In practice this wont be needed.
774
+ # Start by padding along m dimension with zeros.
775
+ max_m = max(m_values)
776
+ x = [
777
+ torch.nn.functional.pad(i, (0, 0, 0, max_m - i.shape[0]), value=0)
778
+ for i in x
779
+ ]
780
+ # Stack inputs into groups.
781
+ x = torch.stack(x).contiguous()
782
+ w = torch.stack(w).contiguous()
783
+
784
+ # Preapply weight quantization.
785
+ wq, w_scale = quantize_fp8_row(w)
786
+ # Return processed tensors.
787
+ return (
788
+ x,
789
+ wq,
790
+ w_scale,
791
+ torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device),
792
+ )
793
+ # Otherwise run without sparsity.
794
+ wq, w_scale = zip(*[quantize_fp8_row(i) for i in w])
795
+ return x, wq, w_scale, None
796
+
797
+ def quantize(self, x, wq, w_scale, m_values=None):
798
+ # Handle case where inputs are explicitly grouped and non-sparse.
799
+ if isinstance(x, (tuple, list)):
800
+ xq, x_scale = zip(*[triton_quantize_fp8_row(i) for i in x])
801
+ return xq, wq, x_scale, w_scale, m_values
802
+ # Otherwise inputs are unified tensors and sparse.
803
+ else:
804
+ B = x.shape[0]
805
+ xq, x_scale = triton_quantize_fp8_row(x, zero_start_index_M=m_values)
806
+ x_scale = x_scale.view(B, -1)
807
+ return xq, wq, x_scale, w_scale, m_values
808
+
809
+ def compute(self, xq, wq, x_scale, w_scale, m_values):
810
+ if m_values is None:
811
+ return torch.ops.fbgemm.f8f8bf16_rowwise_grouped(
812
+ xq,
813
+ wq,
814
+ x_scale,
815
+ w_scale,
816
+ )
817
+ else:
818
+ # Break tensor into groups, simulates what is done e2e.
819
+ return torch.ops.fbgemm.f8f8bf16_rowwise_grouped_dynamic(
820
+ xq,
821
+ wq,
822
+ x_scale,
823
+ w_scale,
824
+ zero_start_index_M=m_values,
825
+ )
826
+
827
+ def quantize_and_compute(self, x, wq, w_scale, m_values=None):
828
+ xq, wq, x_scale, w_scale, m_values = self.quantize(x, wq, w_scale, m_values)
829
+ return self.compute(xq, wq, x_scale, w_scale, m_values)
830
+
831
+ @property
832
+ def name(self) -> str:
833
+ if torch.version.cuda:
834
+ return "cutlass_rowwise_grouped"
835
+ else:
836
+ return "ck_rowwise_grouped"
837
+
838
+ @property
839
+ def hip(self) -> bool:
840
+ return True
841
+
842
+ @property
843
+ def cuda(self) -> bool:
844
+ return True
845
+
846
+
847
+ @register_quantize_op
848
+ class BF16TritonStackedGroupedGemm(QuantizeOpBase):
849
+ """
850
+ BF16 grouped matmul with stacked inputs implemented with triton.
851
+ """
852
+
853
+ def preprocess(self, x, w):
854
+ m_values = [i.shape[0] for i in x]
855
+ # Convert m_values into offsets into grouped tensor.
856
+ m_sizes = torch.tensor(m_values).to(dtype=torch.int32, device=x[0].device)
857
+ w = torch.concat(w, dim=0).contiguous()
858
+ # Also view input as flattened.
859
+ x = torch.concat(x, dim=0).contiguous()
860
+ # Return processed tensors.
861
+ return x, w, m_sizes
862
+
863
+ def quantize(self, x, w, m_sizes):
864
+ return x, w, m_sizes
865
+
866
+ def compute(self, x, w, m_sizes):
867
+ return grouped_gemm(x, w, m_sizes, _use_warp_specialization=True)
868
+
869
+ def quantize_and_compute(self, x, w, m_sizes):
870
+ x, w, m_sizes = self.quantize(x, w, m_sizes)
871
+ return self.compute(x, w, m_sizes)
872
+
873
+ @property
874
+ def name(self) -> str:
875
+ return "triton_bf16_grouped_stacked"
876
+
877
+ @property
878
+ def hip(self) -> bool:
879
+ return True
880
+
881
+ @property
882
+ def cuda(self) -> bool:
883
+ return True
884
+
885
+
886
+ @register_quantize_op
887
+ class BF16TritonStackedGroupedGemmFuseScatterAdd(BF16TritonStackedGroupedGemm):
888
+ """
889
+ BF16 grouped matmul with stacked inputs implemented with triton. Fused with ScatterAdd.
890
+ """
891
+
892
+ def preprocess(self, x, w):
893
+ x, w, m_sizes = super().preprocess(x, w)
894
+ M = x.shape[0]
895
+ N = w.shape[0] // m_sizes.shape[0]
896
+ output = torch.zeros(M, N, dtype=torch.bfloat16, device=x.device)
897
+ indices = torch.randperm(M, dtype=torch.int32, device=x.device)
898
+ return x, w, m_sizes, output, indices
899
+
900
+ def quantize(self, x, w, m_sizes, *args):
901
+ return *super().quantize(x, w, m_sizes), *args
902
+
903
+ def compute(self, x, w, m_sizes, output, indices):
904
+ return grouped_gemm(
905
+ x,
906
+ w,
907
+ m_sizes,
908
+ _use_warp_specialization=True,
909
+ _output_tensor=output,
910
+ _scatter_add_indices=indices,
911
+ )
912
+
913
+ def quantize_and_compute(self, x, w, m_sizes, *args):
914
+ x, w, m_sizes, *ret = self.quantize(x, w, m_sizes, *args)
915
+ return self.compute(x, w, m_sizes, *ret)
916
+
917
+ @property
918
+ def name(self) -> str:
919
+ return "triton_bf16_grouped_stacked_fuse_scatter_add"
920
+
921
+
922
+ @register_quantize_op
923
+ class FP8TritonStackedGroupedGemm(QuantizeOpBase):
924
+ """
925
+ FP8 grouped matmul with rowwise scaling and stacked inputs implemented with triton.
926
+ """
927
+
928
+ def preprocess(self, x, w):
929
+ m_values = [i.shape[0] for i in x]
930
+ # Convert m_values into offsets into grouped tensor.
931
+ m_sizes = torch.tensor(m_values).to(dtype=torch.int32, device=x[0].device)
932
+ # Quantize weights.
933
+ wq, w_scale = zip(*[quantize_fp8_row(i) for i in w])
934
+ # Group weights as single tensor.
935
+ wq = torch.concat(wq, dim=0).contiguous()
936
+ w_scale = torch.concat(w_scale, dim=0).contiguous()
937
+ # Also view input as flattened.
938
+ x = torch.concat(x, dim=0).contiguous()
939
+ # Return processed tensors.
940
+ return x, wq, w_scale, m_sizes
941
+
942
+ def quantize(self, x, wq, w_scale, m_sizes):
943
+ B = x.shape[0]
944
+ xq, x_scale = triton_quantize_fp8_row(x)
945
+ x_scale = x_scale.view(B, -1)
946
+ return xq, wq, x_scale, w_scale, m_sizes
947
+
948
+ def compute(self, xq, wq, x_scale, w_scale, m_sizes):
949
+ return grouped_gemm_fp8_rowwise(
950
+ xq, wq, m_sizes, x_scale, w_scale, _use_warp_specialization=True
951
+ )
952
+
953
+ def quantize_and_compute(self, x, wq, w_scale, m_sizes):
954
+ xq, wq, x_scale, w_scale, m_sizes = self.quantize(x, wq, w_scale, m_sizes)
955
+ return self.compute(xq, wq, x_scale, w_scale, m_sizes)
956
+
957
+ @property
958
+ def name(self) -> str:
959
+ return "triton_grouped_stacked"
960
+
961
+ @property
962
+ def hip(self) -> bool:
963
+ return True
964
+
965
+ @property
966
+ def cuda(self) -> bool:
967
+ return True
968
+
969
+
970
+ @register_quantize_op
971
+ class FP8TritonStackedGroupedGemmFuseScatterAdd(FP8TritonStackedGroupedGemm):
972
+ """
973
+ FP8 grouped matmul with stacked inputs implemented with triton. Fused with ScatterAdd.
974
+ """
975
+
976
+ def preprocess(self, x, w):
977
+ x, wq, w_scale, m_sizes = super().preprocess(x, w)
978
+ M = x.shape[0]
979
+ N = wq.shape[0] // m_sizes.shape[0]
980
+ output = torch.zeros(M, N, dtype=torch.bfloat16, device=x.device)
981
+ indices = torch.randperm(M, dtype=torch.int32, device=x.device)
982
+ return x, wq, w_scale, m_sizes, output, indices
983
+
984
+ def quantize(self, x, wq, w_scale, m_sizes, *args):
985
+ return *super().quantize(x, wq, w_scale, m_sizes), *args
986
+
987
+ def compute(self, xq, wq, x_scale, w_scale, m_sizes, output, indices):
988
+ return grouped_gemm_fp8_rowwise(
989
+ xq,
990
+ wq,
991
+ m_sizes,
992
+ x_scale,
993
+ w_scale,
994
+ _use_warp_specialization=True,
995
+ _output_tensor=output,
996
+ _scatter_add_indices=indices,
997
+ )
998
+
999
+ def quantize_and_compute(self, x, wq, w_scale, m_sizes, *args):
1000
+ xq, wq, x_scale, w_scale, m_sizes, *ret = self.quantize(
1001
+ x, wq, w_scale, m_sizes, *args
1002
+ )
1003
+ return self.compute(xq, wq, x_scale, w_scale, m_sizes, *ret)
1004
+
1005
+ @property
1006
+ def name(self) -> str:
1007
+ return "triton_grouped_stacked_fuse_scatter_add"
1008
+
1009
+
1010
+ @register_quantize_op
1011
+ class DeepGemmStacked(QuantizeOpBase):
1012
+ """
1013
+ FP8 grouped matmul with blockwise scaling implemented with DeepGemm.
1014
+ """
1015
+
1016
+ def preprocess(self, x, w):
1017
+ m_values = [i.shape[0] for i in x]
1018
+ # Convert m_values into offsets into grouped tensor.
1019
+ indices = torch.arange(len(m_values))
1020
+ m_indices = indices.repeat_interleave(torch.tensor(m_values)).to(
1021
+ device=x[0].device, dtype=torch.int
1022
+ )
1023
+ # Quantize weights.
1024
+ wq, w_scale = zip(*[quantize_fp8_block(i, block_k=128, block_m=128) for i in w])
1025
+ # Group weights as single tensor.
1026
+ wq = torch.stack(wq, dim=0).contiguous()
1027
+ w_scale = torch.stack(w_scale, dim=0).contiguous()
1028
+ # Also view input as flattened.
1029
+ x = torch.concat(x, dim=0).contiguous()
1030
+ # Return processed tensors.
1031
+ return x, wq, w_scale, m_indices
1032
+
1033
+ def quantize(self, x, wq, w_scale, m_indices):
1034
+ xq, x_scale = quantize_fp8_block(x, block_m=1, block_k=128)
1035
+ # Pretranspose scales to deepgemm format.
1036
+ x_scale = get_col_major_tma_aligned_tensor(x_scale)
1037
+ return xq, wq, x_scale, w_scale, m_indices
1038
+
1039
+ def compute(self, xq, wq, x_scale, w_scale, m_indices):
1040
+ # Preallocate output.
1041
+ out = torch.empty(
1042
+ [xq.shape[0], wq.shape[1]], device=xq.device, dtype=torch.bfloat16
1043
+ )
1044
+ m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
1045
+ (xq, x_scale), (wq, w_scale), out, m_indices
1046
+ )
1047
+ return out
1048
+
1049
+ def quantize_and_compute(self, x, wq, w_scale, m_indices):
1050
+ xq, wq, x_scale, w_scale, m_indices = self.quantize(x, wq, w_scale, m_indices)
1051
+ return self.compute(xq, wq, x_scale, w_scale, m_indices)
1052
+
1053
+ @property
1054
+ def name(self) -> str:
1055
+ return "deepgemm_stacked"
1056
+
1057
+ @property
1058
+ def hip(self) -> bool:
1059
+ return False
1060
+
1061
+ @property
1062
+ def cuda(self) -> bool:
1063
+ return DEEPGEMM_ENABLED
1064
+
1065
+
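The m_indices built in preprocess() map each row of the stacked activation back to its group, which is the layout the contiguous grouped DeepGemm entry point expects. A tiny illustration of that bookkeeping (group sizes invented):

    m_values = [3, 1, 2]  # rows contributed by each group
    indices = torch.arange(len(m_values))
    m_indices = indices.repeat_interleave(torch.tensor(m_values))
    # tensor([0, 0, 0, 1, 2, 2]): row i of the stacked input belongs to group m_indices[i]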
1066
+ @register_quantize_op
1067
+ class DeepGemmMaskedStacked(DeepGemmStacked):
1068
+ def preprocess(self, x, w):
1069
+ # Quantize weights.
1070
+ wq, w_scale = zip(*[quantize_fp8_block(i, block_k=128, block_m=128) for i in w])
1071
+ # Group weights as single tensor.
1072
+ wq = torch.stack(wq, dim=0).contiguous()
1073
+ w_scale = torch.stack(w_scale, dim=0).contiguous()
1074
+
1075
+ # Also view input as flattened.
1076
+ m_values = [i.shape[0] for i in x]
1077
+ expected_m = max(m_values)
1078
+ padded_m_max = ((max(m_values) + 127) // 128) * 128
1079
+ masked_m = torch.tensor(m_values).to(dtype=torch.int32, device=x[0].device)
1080
+
1081
+ num_groups = len(m_values)
1082
+ k = x[0].shape[1]
1083
+ x_padded = torch.zeros(
1084
+ [num_groups, padded_m_max, k], device=x[0].device, dtype=x[0].dtype
1085
+ )
1086
+ for g in range(num_groups):
1087
+ x_padded[g, : m_values[g], :] = x[g]
1088
+
1089
+ # Return processed tensors.
1090
+ return x_padded, wq, w_scale, masked_m, expected_m, m_values
1091
+
1092
+ def quantize(self, x, wq, w_scale, masked_m, expected_m, m_values):
1093
+ g, m_max, k = x.shape
1094
+ xq, x_scale = quantize_fp8_block(x.view(-1, k), block_m=1, block_k=128)
1095
+ # Pretranspose scales to deepgemm format.
1096
+ x_scale = get_col_major_tma_aligned_tensor(x_scale)
1097
+ return (
1098
+ xq.view(g, m_max, -1),
1099
+ wq,
1100
+ x_scale.view(g, m_max, -1),
1101
+ w_scale,
1102
+ masked_m,
1103
+ expected_m,
1104
+ m_values,
1105
+ )
1106
+
1107
+ def compute(self, xq, wq, x_scale, w_scale, masked_m, expected_m, m_values):
1108
+ # Preallocate output.
1109
+ out = torch.empty(
1110
+ [xq.shape[0], xq.shape[1], wq.shape[1]],
1111
+ device=xq.device,
1112
+ dtype=torch.bfloat16,
1113
+ )
1114
+ m_grouped_gemm_fp8_fp8_bf16_nt_masked(
1115
+ (xq, x_scale), (wq, w_scale), out, masked_m, expected_m
1116
+ )
1117
+ num_groups = xq.shape[0]
1118
+ out_list = [out[g, : m_values[g], :] for g in range(num_groups)]
1119
+ return out_list
1120
+
1121
+ def quantize_and_compute(self, x, wq, w_scale, masked_m, expected_m, m_values):
1122
+ xq, wq, x_scale, w_scale, masked_m, expected_m, m_values = self.quantize(
1123
+ x, wq, w_scale, masked_m, expected_m, m_values
1124
+ )
1125
+ return self.compute(xq, wq, x_scale, w_scale, masked_m, expected_m, m_values)
1126
+
1127
+ @property
1128
+ def name(self) -> str:
1129
+ return "deepgemm_masked_stacked"
1130
+
1131
+
1132
+ @register_quantize_op
1133
+ class DeepGemmBlockwise(QuantizeOpBase):
1134
+ """
1135
+ FP8 matmul with blockwise scaling implemented with DeepGemm.
1136
+ """
1137
+
1138
+ def preprocess(self, x, w):
1139
+ # Quantize weights.
1140
+ wq, w_scale = quantize_fp8_block(w, block_m=128, block_k=128)
1141
+ # allocate output.
1142
+ out = torch.empty(
1143
+ x.shape[0], wq.shape[0], device=x.device, dtype=torch.bfloat16
1144
+ )
1145
+ # Return processed tensors.
1146
+ return x, wq, w_scale, out
1147
+
1148
+ def quantize(self, x, wq, w_scale, out):
1149
+ xq, x_scale = quantize_fp8_group(x, group_size=128)
1150
+ return xq, wq, x_scale, w_scale, out
1151
+
1152
+ def compute(self, xq, wq, x_scale, w_scale, out):
1153
+ gemm_fp8_fp8_bf16_nt((xq, x_scale), (wq, w_scale), out)
1154
+ return out
1155
+
1156
+ def quantize_and_compute(self, x, wq, w_scale, out):
1157
+ xq, wq, x_scale, w_scale, out = self.quantize(x, wq, w_scale, out)
1158
+ return self.compute(xq, wq, x_scale, w_scale, out)
1159
+
1160
+ @property
1161
+ def name(self) -> str:
1162
+ return "deepgemm_blockwise"
1163
+
1164
+ @property
1165
+ def hip(self) -> bool:
1166
+ return False
1167
+
1168
+ @property
1169
+ def cuda(self) -> bool:
1170
+ return True
1171
+
1172
+
1173
+ @register_quantize_op
1174
+ class DeepGemmRowwise(QuantizeOpBase):
1175
+ """
1176
+ FP8 matmul with rowwise scaling implemented with DeepGemm.
1177
+ """
1178
+
1179
+ def preprocess(self, x, w):
1180
+ # Quantize weights.
1181
+ wq, w_scale = quantize_fp8_row(w)
1182
+ # allocate output.
1183
+ out = torch.empty(
1184
+ x.shape[0], wq.shape[0], device=x.device, dtype=torch.bfloat16
1185
+ )
1186
+ # Return processed tensors.
1187
+ return x, wq, w_scale, out
1188
+
1189
+ def quantize(self, x, wq, w_scale, out):
1190
+ xq, x_scale = quantize_fp8_row(x)
1191
+ # Pretranspose scales to deepgemm format.
1192
+ x_scale = get_col_major_tma_aligned_tensor(x_scale, rowwise_scaling=True)
1193
+ return xq, wq, x_scale, w_scale, out
1194
+
1195
+ def compute(self, xq, wq, x_scale, w_scale, out):
1196
+ gemm_fp8_fp8_bf16_nt((xq, x_scale), (wq, w_scale), out)
1197
+ return out
1198
+
1199
+ def quantize_and_compute(self, x, wq, w_scale, out):
1200
+ xq, wq, x_scale, w_scale, out = self.quantize(x, wq, w_scale, out)
1201
+ return self.compute(xq, wq, x_scale, w_scale, out)
1202
+
1203
+ @property
1204
+ def name(self) -> str:
1205
+ return "deepgemm_rowwise"
1206
+
1207
+ @property
1208
+ def hip(self) -> bool:
1209
+ return False
1210
+
1211
+ @property
1212
+ def cuda(self) -> bool:
1213
+ return DEEPGEMM_ENABLED
1214
+
1215
+
1216
+ @register_quantize_op
1217
+ class FP8StackedGroupedGemm(QuantizeOpBase):
1218
+ """
1219
+ FP8 grouped matmul with rowwise scaling and stacked inputs.
1220
+ """
1221
+
1222
+ def preprocess(self, x, w):
1223
+ m_values = [i.shape[0] for i in x]
1224
+ m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device)
1225
+ # Quantize weights.
1226
+ wq, w_scale = zip(*[quantize_fp8_row(i) for i in w])
1227
+ # Group weights as single tensor.
1228
+ wq = torch.stack(wq, dim=0).contiguous()
1229
+ w_scale = torch.stack(w_scale, dim=0).contiguous()
1230
+ # Also view input as flattened.
1231
+ x = torch.concat(x, dim=0).contiguous()
1232
+ # Return processed tensors.
1233
+ return x, wq, w_scale, m_sizes
1234
+
1235
+ def quantize(self, x, wq, w_scale, m_sizes):
1236
+ B = x.shape[0]
1237
+ xq, x_scale = triton_quantize_fp8_row(x)
1238
+ x_scale = x_scale.view(B, -1)
1239
+ return xq, wq, x_scale, w_scale, m_sizes
1240
+
1241
+ def compute(self, xq, wq, x_scale, w_scale, m_sizes):
1242
+ return torch.ops.fbgemm.f8f8bf16_rowwise_grouped_stacked(
1243
+ xq, wq, x_scale, w_scale, m_sizes
1244
+ )
1245
+
1246
+ def quantize_and_compute(self, x, wq, w_scale, m_sizes):
1247
+ xq, wq, x_scale, w_scale, m_sizes = self.quantize(x, wq, w_scale, m_sizes)
1248
+ return self.compute(xq, wq, x_scale, w_scale, m_sizes)
1249
+
1250
+ @property
1251
+ def name(self) -> str:
1252
+ if torch.version.cuda:
1253
+ return "cutlass_grouped_stacked"
1254
+ else:
1255
+ return "ck_grouped_stacked"
1256
+
1257
+ @property
1258
+ def hip(self) -> bool:
1259
+ return True
1260
+
1261
+ @property
1262
+ def cuda(self) -> bool:
1263
+ return True
1264
+
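In the stacked layout, per-group activations are concatenated along the row dimension and m_sizes simply records how many rows each group contributed, while weights stay one (N, K) matrix per group stacked into (num_groups, N, K). A shape sketch under that convention (sizes invented):

    xs = [torch.randn(3, 128, device="cuda"), torch.randn(5, 128, device="cuda")]
    m_sizes = torch.tensor([t.shape[0] for t in xs], dtype=torch.int64, device="cuda")
    x = torch.cat(xs, dim=0)  # (3 + 5, 128): groups stacked along the row dimension
    # f8f8bf16_rowwise_grouped_stacked later uses m_sizes to slice x back into groups.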
1265
+
1266
+ @register_quantize_op
1267
+ class FP8StackedGroupedGemmTorch(FP8StackedGroupedGemm):
1268
+ def quantize(self, x, wq, w_scale, m_sizes):
1269
+ xq, wq, x_scale, w_scale, m_sizes = super().quantize(x, wq, w_scale, m_sizes)
1270
+ offsets = torch.cumsum(m_sizes, dim=0, dtype=torch.int32)
1271
+ out = torch.empty(
1272
+ (xq.shape[0], wq.shape[1]), dtype=torch.bfloat16, device=xq.device
1273
+ )
1274
+ x_scale = x_scale.view(x_scale.shape[0])
1275
+ return xq, wq, x_scale, w_scale, offsets, out
1276
+
1277
+ def compute(self, xq, wq, x_scale, w_scale, offsets, out):
1278
+ return torch.ops.fbgemm.f8f8bf16_rowwise_grouped_mm(
1279
+ xq, wq, x_scale, w_scale, offsets, out
1280
+ )
1281
+
1282
+ def quantize_and_compute(self, x, wq, w_scale, m_sizes):
1283
+ xq, wq, x_scale, w_scale, offsets, out = self.quantize(x, wq, w_scale, m_sizes)
1284
+ return self.compute(xq, wq, x_scale, w_scale, offsets, out)
1285
+
1286
+ @property
1287
+ def name(self) -> str:
1288
+ return "ck_grouped_stacked_torch_2d3d"
1289
+
1290
+ @property
1291
+ def hip(self) -> bool:
1292
+ return True
1293
+
1294
+ @property
1295
+ def cuda(self) -> bool:
1296
+ return False
1297
+
1298
+
1299
+ @register_quantize_op
1300
+ class ScaledGroupedMMRowwise(FP8StackedGroupedGemmTorch):
1301
+ def __init__(self):
1302
+ self.fast_accum = True
1303
+ self.torch_compile = False
1304
+
1305
+ def compute(self, xq, wq, x_scale, w_scale, offsets, _):
1306
+ if self.torch_compile:
1307
+ f = torch.compile(
1308
+ torch._scaled_grouped_mm,
1309
+ options={
1310
+ "max_autotune": True,
1311
+ "max_autotune_gemm_backends": "TRITON,CK,CUTLASS,ATEN",
1312
+ },
1313
+ )
1314
+ else:
1315
+ f = torch._scaled_grouped_mm
1316
+
1317
+ return f(
1318
+ xq,
1319
+ wq.transpose(-2, -1),
1320
+ offs=offsets,
1321
+ out_dtype=torch.bfloat16,
1322
+ scale_a=x_scale,
1323
+ scale_b=w_scale,
1324
+ scale_result=None,
1325
+ use_fast_accum=self.fast_accum,
1326
+ )
1327
+
1328
+ @property
1329
+ def name(self) -> str:
1330
+ return "scaled_grouped_mm_rowwise"
1331
+
1332
+ @property
1333
+ def hip(self) -> bool:
1334
+ return True
1335
+
1336
+ @property
1337
+ def cuda(self) -> bool:
1338
+ return True
1339
+
1340
+
1341
+ @register_quantize_op
1342
+ class FP8StackedGroupwiseGroupedGemm(QuantizeOpBase):
1343
+ """
1344
+ FP8 grouped matmul with groupwise scaling and stacked inputs.
1345
+ """
1346
+
1347
+ def preprocess(self, x, w):
1348
+ m_values = [i.shape[0] for i in x]
1349
+ m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device)
1350
+ # Quantize weights.
1351
+ wq, w_scale = zip(
1352
+ *[quantize_fp8_block(i, block_m=128, block_k=128, k_major=False) for i in w]
1353
+ )
1354
+ # Group weights as single tensor.
1355
+ wq = torch.stack(wq, dim=0).contiguous()
1356
+ w_scale = torch.stack(w_scale, dim=0).contiguous()
1357
+ # Also view input as flattened.
1358
+ x = torch.concat(x, dim=0).contiguous()
1359
+ # Return processed tensors.
1360
+ return x, wq, w_scale, m_sizes
1361
+
1362
+ def quantize(self, x, wq, w_scale, m_sizes):
1363
+ xq, x_scale = quantize_fp8_group(x, m_sizes=m_sizes)
1364
+ return xq, wq, x_scale, w_scale, m_sizes
1365
+
1366
+ def compute(self, xq, wq, x_scale, w_scale, m_sizes):
1367
+ return torch.ops.fbgemm.f8f8bf16_groupwise_grouped(
1368
+ xq, wq, x_scale, w_scale, m_sizes
1369
+ )
1370
+
1371
+ def quantize_and_compute(self, x, wq, w_scale, m_sizes):
1372
+ xq, wq, x_scale, w_scale, m_sizes = self.quantize(x, wq, w_scale, m_sizes)
1373
+ return self.compute(xq, wq, x_scale, w_scale, m_sizes)
1374
+
1375
+ @property
1376
+ def name(self) -> str:
1377
+ if torch.version.cuda:
1378
+ return "cutlass_groupwise_grouped"
1379
+ else:
1380
+ return "ck_groupwise_grouped"
1381
+
1382
+ @property
1383
+ def hip(self) -> bool:
1384
+ return False
1385
+
1386
+ @property
1387
+ def cuda(self) -> bool:
1388
+ return True
1389
+
1390
+
1391
+ @register_quantize_op
1392
+ class BF16GroupedGemm(QuantizeOpBase):
1393
+ """
1394
+ BF16 grouped matmul implemented with CK or Cutlass.
1395
+ """
1396
+
1397
+ def preprocess(self, x, w):
1398
+ # Apply sparsity to inputs if appropriate.
1399
+ # First check if N and K are fixed.
1400
+ m_values = [i.shape[0] for i in x]
1401
+ n_values = [i.shape[0] for i in w]
1402
+ k_values = [i.shape[1] for i in w]
1403
+ # If so, do specialized version of initialization.
1404
+ if len(np.unique(n_values)) == 1 and len(np.unique(k_values)) == 1:
1405
+ m_values = [i.shape[0] for i in x]
1406
+ # Inputs for fixed nk mode must be contiguous, however in the benchmark
1407
+ # script they typically are not. Do a little special processing to make them
1408
+ # work. In practice this wont be needed.
1409
+ # Start by padding along m dimension with zeros.
1410
+ max_m = max(m_values)
1411
+ x = [
1412
+ torch.nn.functional.pad(i, (0, 0, 0, max_m - i.shape[0]), value=0)
1413
+ for i in x
1414
+ ]
1415
+ # Stack inputs into groups.
1416
+ x = torch.stack(x).contiguous()
1417
+ w = torch.stack(w).contiguous()
1418
+ return (
1419
+ x,
1420
+ w,
1421
+ torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device),
1422
+ )
1423
+ return x, w, None
1424
+
1425
+ def quantize(self, x, w, m_values=None):
1426
+ # No action required.
1427
+ return x, w, m_values
1428
+
1429
+ def compute(self, x, w, m_values):
1430
+ if m_values is None:
1431
+ return torch.ops.fbgemm.bf16bf16bf16_grouped(x, w)
1432
+ else:
1433
+ return torch.ops.fbgemm.bf16bf16bf16_grouped_dynamic(x, w, m_values)
1434
+
1435
+ def quantize_and_compute(self, x, w, m_values):
1436
+ return self.compute(x, w, m_values)
1437
+
1438
+ @property
1439
+ def name(self) -> str:
1440
+ if torch.version.cuda:
1441
+ return "cutlass_bf16_grouped"
1442
+ else:
1443
+ return "ck_bf16_grouped"
1444
+
1445
+ @property
1446
+ def hip(self) -> bool:
1447
+ return True
1448
+
1449
+ @property
1450
+ def cuda(self) -> bool:
1451
+ return True
1452
+
1453
+
1454
+ @register_quantize_op
1455
+ class FP8RowwiseBatchedGemm(QuantizeOpBase):
1456
+ """
1457
+ FP8 batched matmul with rowwise scaling.
1458
+ """
1459
+
1460
+ def quantize(self, x, w):
1461
+ # Quantize both input tensors.
1462
+ xq, x_scale = quantize_fp8_row(x)
1463
+ wq, w_scale = quantize_fp8_row(w)
1464
+ return xq, wq, x_scale, w_scale
1465
+
1466
+ def compute(self, xq, wq, x_scale, w_scale):
1467
+ return torch.ops.fbgemm.f8f8bf16_rowwise_batched(xq, wq, x_scale, w_scale)
1468
+
1469
+ def quantize_and_compute(self, x, w):
1470
+ xq, wq, x_scale, w_scale = self.quantize(x, w)
1471
+ return self.compute(xq, wq, x_scale, w_scale)
1472
+
1473
+ @property
1474
+ def name(self) -> str:
1475
+ if torch.version.cuda:
1476
+ return "cutlass_rowwise_batched"
1477
+ else:
1478
+ return "ck_rowwise_batched"
1479
+
1480
+ @property
1481
+ def hip(self) -> bool:
1482
+ return True
1483
+
1484
+ @property
1485
+ def cuda(self) -> bool:
1486
+ return True
1487
+
1488
+
1489
+ @register_quantize_op
1490
+ class FP8LiteGemm(QuantizeOpBase):
1491
+ """
1492
+ FP8 lite matmul for memory bound.
1493
+ """
1494
+
1495
+ def quantize(self, x, w):
1496
+ # Quantize both input tensors.
1497
+ xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
1498
+ wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
1499
+ return xq, wq, x_scale, w_scale
1500
+
1501
+ def compute(self, xq, wq, x_scale, w_scale):
1502
+ return torch.ops.fbgemm.f8f8bf16_lite(xq, wq, x_scale * w_scale)
1503
+
1504
+ def quantize_and_compute(self, x, w):
1505
+ xq, wq, x_scale, w_scale = self.quantize(x, w)
1506
+ return self.compute(xq, wq, x_scale, w_scale)
1507
+
1508
+ @property
1509
+ def name(self) -> str:
1510
+ return "cuda_lite"
1511
+
1512
+ @property
1513
+ def hip(self) -> bool:
1514
+ # Need to add support for better quantize kernel.
1515
+ # Also may have an issue with cuda graphs.
1516
+ return False
1517
+
1518
+ @property
1519
+ def cuda(self) -> bool:
1520
+ return True
1521
+
1522
+
1523
+ @register_quantize_op
1524
+ class TritonFP8RowwiseGemm(QuantizeOpBase):
1525
+ """
1526
+ FP8 matmul with rowwise scaling implemented with Triton.
1527
+ """
1528
+
1529
+ def __init__(self):
1530
+ self.fast_accum = True
1531
+
1532
+ def quantize(self, x, w):
1533
+ # Quantize both input tensors.
1534
+ xq, x_scale = quantize_fp8_row(x)
1535
+ wq, w_scale = quantize_fp8_row(w)
1536
+ bias = torch.randn(w.shape[0], device=x.device, dtype=torch.float32)
1537
+ return xq, wq, x_scale, w_scale, bias
1538
+
1539
+ def compute(self, xq, wq, x_scale, w_scale, bias):
1540
+ return matmul_fp8_row(
1541
+ xq,
1542
+ wq,
1543
+ x_scale,
1544
+ w_scale,
1545
+ bias=bias,
1546
+ fp8_fast_accum=self.fast_accum,
1547
+ use_warp_specialization=True,
1548
+ )
1549
+
1550
+ def quantize_and_compute(self, x, w):
1551
+ xq, wq, x_scale, w_scale, bias = self.quantize(x, w)
1552
+ return self.compute(xq, wq, x_scale, w_scale, bias)
1553
+
1554
+ @property
1555
+ def name(self) -> str:
1556
+ return "triton_rowwise"
1557
+
1558
+ @property
1559
+ def hip(self) -> bool:
1560
+ # triton FP8 matmuls do not currently compile on AMD.
1561
+ return False
1562
+
1563
+ @property
1564
+ def cuda(self) -> bool:
1565
+ return True
1566
+
1567
+
1568
+ @register_quantize_op
1569
+ class FP8TritonBlockwiseGemm(QuantizeOpBase):
1570
+ """
1571
+ FP8 matmul with block scaling.
1572
+ """
1573
+
1574
+ def quantize(self, x, w):
1575
+ # Quantize both input tensors.
1576
+ xq, x_scale = quantize_fp8_block(x, 128, 128)
1577
+ wq, w_scale = quantize_fp8_block(w, 128, 128)
1578
+ return xq, wq, x_scale, w_scale
1579
+
1580
+ def compute(self, xq, wq, x_scale, w_scale):
1581
+ return matmul_fp8_block(xq, wq, x_scale, w_scale, 128, 128, 128)
1582
+
1583
+ def quantize_and_compute(self, x, w):
1584
+ xq, wq, x_scale, w_scale = self.quantize(x, w)
1585
+ return self.compute(xq, wq, x_scale, w_scale)
1586
+
1587
+ @property
1588
+ def name(self) -> str:
1589
+ return "triton_blockwise"
1590
+
1591
+ @property
1592
+ def hip(self) -> bool:
1593
+ # Currently has some issues.
1594
+ return False
1595
+
1596
+ @property
1597
+ def cuda(self) -> bool:
1598
+ return True
1599
+
1600
+
1601
+ @register_quantize_op
1602
+ class FP8CutlassBlockwiseGemm(QuantizeOpBase):
1603
+ """
1604
+ FP8 matmul with block scaling.
1605
+ """
1606
+
1607
+ def quantize(self, x, w):
1608
+ # Quantize both input tensors.
1609
+ xq, x_scale = quantize_fp8_block(x, 128, 128)
1610
+ wq, w_scale = quantize_fp8_block(w, 128, 128)
1611
+ return xq, wq, x_scale, w_scale
1612
+
1613
+ def compute(self, xq, wq, x_scale, w_scale):
1614
+ return torch.ops.fbgemm.f8f8bf16_blockwise(
1615
+ xq, wq, x_scale, w_scale, 128, 128, 128
1616
+ )
1617
+
1618
+ def quantize_and_compute(self, x, w):
1619
+ xq, wq, x_scale, w_scale = self.quantize(x, w)
1620
+ return self.compute(xq, wq, x_scale, w_scale)
1621
+
1622
+ @property
1623
+ def name(self) -> str:
1624
+ if torch.version.cuda:
1625
+ return "cutlass_blockwise"
1626
+ else:
1627
+ return "ck_blockwise"
1628
+
1629
+ @property
1630
+ def hip(self) -> bool:
1631
+ return True
1632
+
1633
+ @property
1634
+ def cuda(self) -> bool:
1635
+ return True
1636
+
1637
+
1638
+ @register_quantize_op
1639
+ class FP8CutlassGroupwiseGemm(QuantizeOpBase):
1640
+ """
1641
+ FP8 matmul with group / block scaling.
1642
+ """
1643
+
1644
+ def preprocess(self, x, w):
1645
+ # Quantize weights.
1646
+ # Scale is expected to be in [K, N] layout (N Major).
1647
+ wq, w_scale = quantize_fp8_block(w, block_m=128, block_k=128, k_major=False)
1648
+ # Return processed tensors.
1649
+ return x, wq, w_scale
1650
+
1651
+ def quantize(self, x, wq, w_scale):
1652
+ # Scale is expected to be in [K, M] layout (M Major).
1653
+ xq, x_scale = quantize_fp8_group(x, k_major=False)
1654
+ # Pretranspose scales to deepgemm format.
1655
+ return xq, wq, x_scale, w_scale
1656
+
1657
+ def compute(self, xq, wq, x_scale, w_scale):
1658
+ return torch.ops.fbgemm.f8f8bf16_groupwise(xq, wq, x_scale, w_scale)
1659
+
1660
+ def quantize_and_compute(self, x, wq, w_scale):
1661
+ xq, wq, x_scale, w_scale = self.quantize(x, wq, w_scale)
1662
+ return self.compute(xq, wq, x_scale, w_scale)
1663
+
1664
+ @property
1665
+ def name(self) -> str:
1666
+ if torch.version.cuda:
1667
+ return "cutlass_groupwise"
1668
+ else:
1669
+ return "ck_groupwise"
1670
+
1671
+ @property
1672
+ def hip(self) -> bool:
1673
+ return False
1674
+
1675
+ @property
1676
+ def cuda(self) -> bool:
1677
+ return True
1678
+
1679
+
1680
+ ####################################################################################################
1681
+ # CUTLASS kernel v2
1682
+ ####################################################################################################
1683
+
1684
+
1685
+ @register_quantize_op
1686
+ class CutlassFP8TensorwiseGemm_v2(QuantizeOpBase):
1687
+ """
1688
+ FP8 matmul with tensorwise scaling.
1689
+ """
1690
+
1691
+ def quantize(self, x, w):
1692
+ # Quantize both input tensors.
1693
+ xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
1694
+ wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
1695
+ return xq, wq, x_scale, w_scale
1696
+
1697
+ def compute(self, xq, wq, x_scale, w_scale):
1698
+ if hasattr(torch.ops.cutlass_extensions, "f8f8bf16"):
1699
+ return torch.ops.cutlass_extensions.f8f8bf16(xq, wq, x_scale * w_scale)
1700
+ else:
1701
+ raise RuntimeError(
1702
+ "Skipping cutlass_extensions_v2 runs as it is not supported"
1703
+ )
1704
+
1705
+ def quantize_and_compute(self, x, w):
1706
+ xq, wq, x_scale, w_scale = self.quantize(x, w)
1707
+ return self.compute(xq, wq, x_scale, w_scale)
1708
+
1709
+ @property
1710
+ def name(self) -> str:
1711
+ return "cutlass_tensorwise_v2"
1712
+
1713
+ @property
1714
+ def hip(self) -> bool:
1715
+ # Need to add support for better quantize kernel.
1716
+ # Also may have an issue with cuda graphs.
1717
+ return False
1718
+
1719
+ @property
1720
+ def cuda(self) -> bool:
1721
+ return True
1722
+
1723
+
1724
+ # CUTLASS kernel v2
1725
+ @register_quantize_op
1726
+ class CutlassFP8RowwiseGemm_v2(QuantizeOpBase):
1727
+ """
1728
+ FP8 matmul with rowwise scaling.
1729
+ """
1730
+
1731
+ def quantize(self, x, w):
1732
+ # Quantize both input tensors.
1733
+ xq, x_scale = quantize_fp8_row(x)
1734
+ wq, w_scale = quantize_fp8_row(w)
1735
+ return xq, wq, x_scale, w_scale
1736
+
1737
+ def compute(self, xq, wq, x_scale, w_scale):
1738
+ if hasattr(torch.ops.cutlass_extensions, "f8f8bf16_rowwise"):
1739
+ return torch.ops.cutlass_extensions.f8f8bf16_rowwise(
1740
+ xq, wq, x_scale, w_scale
1741
+ )
1742
+ else:
1743
+ raise RuntimeError(
1744
+ "Skipping cutlass_extensions_v2 runs as it is not supported"
1745
+ )
1746
+
1747
+ def quantize_and_compute(self, x, w):
1748
+ xq, wq, x_scale, w_scale = self.quantize(x, w)
1749
+ return self.compute(xq, wq, x_scale, w_scale)
1750
+
1751
+ @property
1752
+ def name(self) -> str:
1753
+ return "cutlass_rowwise_v2"
1754
+
1755
+ @property
1756
+ def hip(self) -> bool:
1757
+ # Need to add support for better quantize kernel.
1758
+ # Also may have an issue with cuda graphs.
1759
+ return False
1760
+
1761
+ @property
1762
+ def cuda(self) -> bool:
1763
+ return True
1764
+
1765
+
1766
+ ####################################################################################################
1767
+
1768
+
1769
+ @register_quantize_op
1770
+ class F8I4RowwiseGemm(QuantizeOpBase):
1771
+ """
1772
+ Mixed Precision FP8 Activations with Int4 Weights.
1773
+ """
1774
+
1775
+ def _int4_row_quantize(
1776
+ self,
1777
+ x: torch.Tensor,
1778
+ group_size: int = 128,
1779
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
1780
+ n_bit = 4 # Number of target bits.
1781
+ to_quant = x.reshape(-1, group_size).to(torch.float)
1782
+
1783
+ max_val = to_quant.amax(dim=1, keepdim=True)
1784
+ min_val = to_quant.amin(dim=1, keepdim=True)
1785
+ max_int = 2**n_bit - 1
1786
+ min_int = 0
1787
+ scales = (max_val - min_val).clamp(min=1e-6) / max_int
1788
+
1789
+ zeros = min_val + scales * (2 ** (n_bit - 1))
1790
+
1791
+ out = to_quant.sub(min_val).div(scales).round().clamp_(min_int, max_int)
1792
+
1793
+ # Recenter output and move to int8.
1794
+ out = (out - 2 ** (n_bit - 1)).to(dtype=torch.int8).reshape(x.shape)
1795
+
1796
+ # Cutlass expects column major layout for scale and zero point,
1797
+ # so we transpose here and make them contiguous.
1798
+ scales = scales.view(x.shape[0], -1).t().contiguous()
1799
+ zeros = zeros.view(x.shape[0], -1).t().contiguous()
1800
+
1801
+ return out, scales, zeros
1802
+
1803
+ def _pack_int4(self, x: torch.Tensor) -> torch.Tensor:
1804
+ # Given int8 x, pack adjacent int4 values into a single int8.
1805
+ low_x = x[:, ::2]
1806
+ high_x = x[:, 1::2]
1807
+
1808
+ # High bits need to left shift, this also masks off extra bits.
1809
+ high_x = torch.bitwise_left_shift(high_x, 4)
1810
+ # Low bits need to have sign bits removed.
1811
+ low_x = torch.bitwise_and(low_x, 0xF)
1812
+
1813
+ # Recombine into a single value with bitwise or.
1814
+ return torch.bitwise_or(low_x, high_x).contiguous()
1815
+
1816
+ def quantize(self, x, w):
1817
+ # Quantize both input tensors.
1818
+ xq, x_scale = quantize_fp8_row(x)
1819
+ wq, w_scale, w_zp = self._int4_row_quantize(w)
1820
+ # Pack int4 values together.
1821
+ wq = self._pack_int4(wq)
1822
+ return xq, wq, x_scale, w_scale, w_zp
1823
+
1824
+ def compute(self, xq, wq, x_scale, w_scale, w_zp):
1825
+ return torch.ops.fbgemm.f8i4bf16_rowwise(xq, wq, x_scale, w_scale, w_zp)
1826
+
1827
+ def quantize_and_compute(self, x, w):
1828
+ xq, wq, x_scale, w_scale, w_zp = self.quantize(x, w)
1829
+ return self.compute(xq, wq, x_scale, w_scale, w_zp)
1830
+
1831
+ @property
1832
+ def name(self) -> str:
1833
+ return "cutlass_f8i4_rowwise"
1834
+
1835
+ @property
1836
+ def hip(self) -> bool:
1837
+ # Not yet supported on AMD.
1838
+ return False
1839
+
1840
+ @property
1841
+ def cuda(self) -> bool:
1842
+ return True
1843
+
1844
+
1845
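# Illustrative sketch (not part of the packaged module): inverse of the `_pack_int4`
# helper above, recovering two signed int4 values per packed int8 byte (low nibble
# first). The helper name is hypothetical.
import torch

def unpack_int4(packed: torch.Tensor) -> torch.Tensor:
    low = torch.bitwise_and(packed, 0xF)
    # Arithmetic shift preserves the sign of the high nibble for signed int8 inputs.
    high = torch.bitwise_right_shift(packed, 4)
    # Low nibbles >= 8 encode negative values in two's complement.
    low = torch.where(low >= 8, low - 16, low)
    out = torch.stack((low, high), dim=-1)
    return out.reshape(*packed.shape[:-1], packed.shape[-1] * 2).to(torch.int8)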
+ @register_quantize_op
1846
+ class F8I4ShuffledGemm(QuantizeOpBase):
1847
+ def preprocess(self, x, w):
1848
+ # Prequantize and pack weights.
1849
+ wq, (group_scale, row_scale) = quantize_int4_preshuffle(w)
1850
+ return x, wq, row_scale, group_scale
1851
+
1852
+ def quantize(self, x, wq, row_scale, group_scale):
1853
+ # Quantize both input tensors.
1854
+ xq, x_scale = quantize_fp8_row(x)
1855
+ return xq, wq, x_scale, row_scale, group_scale
1856
+
1857
+ def compute(self, xq, wq, x_scale, row_scale, group_scale):
1858
+ # Handle batched cases by looping over each batch.
1859
+ if xq.dim() == 3:
1860
+ B, M, _ = xq.shape
1861
+ _, N, _ = wq.shape
1862
+ y = torch.empty((B, M, N), device=xq.device, dtype=torch.bfloat16)
1863
+ for i in range(B):
1864
+ y[i] = torch.ops.fbgemm.f8i4bf16_shuffled(
1865
+ xq[i], wq[i], x_scale[i], row_scale[i], group_scale[i]
1866
+ )
1867
+ return y
1868
+ # Otherwise run gemm normally.
1869
+ return torch.ops.fbgemm.f8i4bf16_shuffled(
1870
+ xq, wq, x_scale, row_scale, group_scale
1871
+ )
1872
+
1873
+ def quantize_and_compute(self, x, wq, row_scale, group_scale):
1874
+ xq, wq, x_scale, row_scale, group_scale = self.quantize(
1875
+ x, wq, row_scale, group_scale
1876
+ )
1877
+ return self.compute(xq, wq, x_scale, row_scale, group_scale)
1878
+
1879
+ @property
1880
+ def name(self) -> str:
1881
+ return "cutlass_f8i4_preshuffle"
1882
+
1883
+ @property
1884
+ def hip(self) -> bool:
1885
+ # Not yet supported on AMD.
1886
+ return False
1887
+
1888
+ @property
1889
+ def cuda(self) -> bool:
1890
+ return True
1891
+
1892
+
1893
+ @register_quantize_op
1894
+ class BF16I4ShuffledGemm(QuantizeOpBase):
1895
+ def preprocess(self, x, w):
1896
+ # Prequantize and pack weights.
1897
+ wq, (group_scale, group_zero) = quantize_int4_preshuffle(w, dtype="bf16")
1898
+ return x, wq, group_scale, group_zero
1899
+
1900
+ def quantize(self, x, wq, group_scale, group_zero):
1901
+ # No extra action required.
1902
+ return x, wq, group_scale, group_zero
1903
+
1904
+ def compute(self, x, wq, group_scale, group_zero):
1905
+ # Handle batched cases by looping over each batch.
1906
+ if x.dim() == 3:
1907
+ B, M, _ = x.shape
1908
+ _, N, _ = wq.shape
1909
+ y = torch.empty((B, M, N), device=x.device, dtype=torch.bfloat16)
1910
+ for i in range(B):
1911
+ y[i] = torch.ops.fbgemm.bf16i4bf16_shuffled(
1912
+ x[i], wq[i], group_scale[i], group_zero[i]
1913
+ )
1914
+ return y
1915
+ # Otherwise run gemm normally.
1916
+ return torch.ops.fbgemm.bf16i4bf16_shuffled(x, wq, group_scale, group_zero)
1917
+
1918
+ def quantize_and_compute(self, x, wq, group_scale, group_zero):
1919
+ x, wq, group_scale, group_zero = self.quantize(x, wq, group_scale, group_zero)
1920
+ return self.compute(x, wq, group_scale, group_zero)
1921
+
1922
+ @property
1923
+ def name(self) -> str:
1924
+ return "cutlass_bf16i4_preshuffle"
1925
+
1926
+ @property
1927
+ def hip(self) -> bool:
1928
+ # Not yet supported on AMD.
1929
+ return False
1930
+
1931
+ @property
1932
+ def cuda(self) -> bool:
1933
+ return True
1934
+
1935
+
1936
+ @register_quantize_op
1937
+ class BF16I4ShuffledBatchedGemm(QuantizeOpBase):
1938
+ """
1939
+ BF16 x INT4 mixed dtype batched gemm with preshuffling.
1940
+ """
1941
+
1942
+ def preprocess(self, x, w):
1943
+ # Prequantize and pack weights.
1944
+ wq, (group_scale, group_zero) = quantize_int4_preshuffle(w, dtype="bf16")
1945
+ return x, wq, group_scale, group_zero
1946
+
1947
+ def quantize(self, x, wq, group_scale, group_zero):
1948
+ # No extra action required.
1949
+ return x, wq, group_scale, group_zero
1950
+
1951
+ def compute(self, x, wq, group_scale, group_zero):
1952
+ return torch.ops.fbgemm.bf16i4bf16_shuffled_batched(
1953
+ x, wq, group_scale, group_zero
1954
+ )
1955
+
1956
+ def quantize_and_compute(self, x, wq, group_scale, group_zero):
1957
+ x, wq, group_scale, group_zero = self.quantize(x, wq, group_scale, group_zero)
1958
+ return self.compute(x, wq, group_scale, group_zero)
1959
+
1960
+ @property
1961
+ def name(self) -> str:
1962
+ return "cutlass_bf16i4_preshuffle_batched"
1963
+
1964
+ @property
1965
+ def hip(self) -> bool:
1966
+ # Not yet supported on AMD.
1967
+ return False
1968
+
1969
+ @property
1970
+ def cuda(self) -> bool:
1971
+ return True
1972
+
1973
+
1974
+ @register_quantize_op
1975
+ class F8I4ShuffledGroupedGemm(QuantizeOpBase):
1976
+ """
1977
+ FP8 x Int4 mixed dtype grouped gemm with preshuffling.
1978
+ """
1979
+
1980
+ def preprocess(self, x, w):
1981
+ assert isinstance(x, list) and isinstance(
1982
+ w, list
1983
+ ), "Only supported for grouped inputs."
1984
+ m_values = [i.shape[0] for i in x]
1985
+ # Convert m_values into offsets into grouped tensor.
1986
+ m_sizes = torch.tensor(m_values).to(dtype=torch.int32, device=x[0].device)
1987
+ # Quantize weights.
1988
+ wq, scales = zip(*[quantize_int4_preshuffle(i) for i in w])
1989
+ group_scale, row_scale = zip(*scales)
1990
+ # Group weights as single tensor.
1991
+ wq = torch.stack(wq, dim=0).contiguous()
1992
+ row_scale = torch.stack(row_scale, dim=0).contiguous()
1993
+ group_scale = torch.stack(group_scale, dim=0).contiguous()
1994
+ # Also view input as flattened.
1995
+ x = torch.concat(x, dim=0).contiguous()
1996
+ # Return processed tensors.
1997
+ return x, wq, row_scale, group_scale, m_sizes
1998
+
1999
+ def quantize(self, x, wq, row_scale, group_scale, m_sizes):
2000
+ B = x.shape[0]
2001
+ xq, x_scale = triton_quantize_fp8_row(x)
2002
+ x_scale = x_scale.view(B, -1)
2003
+ return xq, wq, x_scale, row_scale, group_scale, m_sizes
2004
+
2005
+ def compute(self, xq, wq, x_scale, row_scale, group_scale, m_sizes):
2006
+ out = torch.ops.fbgemm.f8i4bf16_shuffled_grouped(
2007
+ xq, wq, x_scale, row_scale, group_scale, m_sizes
2008
+ )
2009
+ return out
2010
+
2011
+ def quantize_and_compute(self, x, wq, row_scale, group_scale, m_sizes):
2012
+ xq, wq, x_scale, row_scale, group_scale, m_sizes = self.quantize(
2013
+ x, wq, row_scale, group_scale, m_sizes
2014
+ )
2015
+ return self.compute(xq, wq, x_scale, row_scale, group_scale, m_sizes)
2016
+
2017
+ @property
2018
+ def name(self) -> str:
2019
+ if torch.version.cuda:
2020
+ return "cutlass_f8i4_grouped_preshuffle"
2021
+ else:
2022
+ return "ck_f8i4_grouped_preshuffle"
2023
+
2024
+ @property
2025
+ def hip(self) -> bool:
2026
+ return False
2027
+
2028
+ @property
2029
+ def cuda(self) -> bool:
2030
+ return True
2031
+
2032
+
2033
+ @register_quantize_op
2034
+ class BF16I4ShuffledGroupedGemm(QuantizeOpBase):
2035
+ """
2036
+ BF16 x Int4 mixed dtype grouped gemm with preshuffling.
2037
+ """
2038
+
2039
+ def preprocess(self, x, w):
2040
+ assert isinstance(x, list) and isinstance(
2041
+ w, list
2042
+ ), "Only supported for grouped inputs."
2043
+ m_values = [i.shape[0] for i in x]
2044
+ # Convert m_values into offsets into grouped tensor.
2045
+ m_sizes = torch.tensor(m_values).to(dtype=torch.int32, device=x[0].device)
2046
+ # Quantize weights.
2047
+ wq, scales = zip(
2048
+ *[quantize_int4_preshuffle(i, dtype="bf16", use_zp=False) for i in w]
2049
+ )
2050
+ # Group weights as single tensor.
2051
+ group_scale, group_zero = zip(*scales)
2052
+ wq = torch.stack(wq, dim=0).contiguous()
2053
+ group_scale = torch.stack(group_scale, dim=0).contiguous()
2054
+ group_zero = torch.stack(group_zero, dim=0).contiguous()
2055
+ # Also view input as flattened.
2056
+ x = torch.concat(x, dim=0).contiguous()
2057
+ # Return processed tensors.
2058
+ return x, wq, group_scale, group_zero, m_sizes
2059
+
2060
+ def quantize(self, x, wq, group_scale, group_zero, m_sizes):
2061
+ return x, wq, group_scale, group_zero, m_sizes
2062
+
2063
+ def compute(self, x, wq, group_scale, group_zero, m_sizes):
2064
+ # TODO: Zero points aren't currently supported in grouped gemm.
2065
+ # We leave them as inputs for future compatibility but they are ignored.
2066
+ return torch.ops.fbgemm.bf16i4bf16_shuffled_grouped(
2067
+ x, wq, group_scale, group_zero, m_sizes
2068
+ )
2069
+
2070
+ def quantize_and_compute(self, x, wq, group_scale, group_zero, m_sizes):
2071
+ x, wq, group_scale, group_zero, m_sizes = self.quantize(
2072
+ x, wq, group_scale, group_zero, m_sizes
2073
+ )
2074
+ return self.compute(x, wq, group_scale, group_zero, m_sizes)
2075
+
2076
+ @property
2077
+ def name(self) -> str:
2078
+ if torch.version.cuda:
2079
+ return "cutlass_bf16i4_grouped_preshuffle"
2080
+ else:
2081
+ return "ck_bf16i4_grouped_preshuffle"
2082
+
2083
+ @property
2084
+ def hip(self) -> bool:
2085
+ return False
2086
+
2087
+ @property
2088
+ def cuda(self) -> bool:
2089
+ return True
2090
+
2091
+
2092
+ @register_quantize_op
2093
+ class BF16GroupedGrad(QuantizeOpBase):
2094
+ """
2095
+ BF16 grouped matmul with dgrad inputs in pretraining backed by cutlass
2096
+ """
2097
+
2098
+ def preprocess(self, x, w):
2099
+ m_values = [i.shape[0] for i in x]
2100
+ # Convert m_values into offsets into grouped tensor.
2101
+ m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device)
2102
+ # Group weights as single tensor.
2103
+ w = torch.stack(w, dim=0).contiguous()
2104
+ # Prepare online dgrad during pretraining backward.
2105
+ w_perm = w.permute(0, 2, 1).contiguous()
2106
+ # w.contiguous() is very expensive, so the non-contiguous layout is handled inside the grouped GEMM kernel instead.
2107
+ w = w_perm.permute(0, 2, 1)
2108
+
2109
+ # Also view input as flattened.
2110
+ x = torch.concat(x, dim=0).contiguous()
2111
+ # Return processed tensors.
2112
+ return x, w, m_sizes
2113
+
2114
+ def quantize(self, x, w, m_sizes):
2115
+ return x, w, m_sizes
2116
+
2117
+ def compute(self, x, w, m_sizes):
2118
+ return torch.ops.fbgemm.bf16bf16bf16_grouped_grad(x, w, m_sizes)
2119
+
2120
+ def quantize_and_compute(self, x, w, m_sizes):
2121
+ x, w, m_sizes = self.quantize(x, w, m_sizes)
2122
+ return self.compute(x, w, m_sizes)
2123
+
2124
+ @property
2125
+ def name(self) -> str:
2126
+ return "bf16_grouped_grad"
2127
+
2128
+ @property
2129
+ def hip(self) -> bool:
2130
+ return False
2131
+
2132
+ @property
2133
+ def cuda(self) -> bool:
2134
+ return True
2135
+
2136
+
2137
+ @register_quantize_op
2138
+ class BF16GroupedWGrad(QuantizeOpBase):
2139
+ """
2140
+ BF16 grouped matmul with wgrad inputs in pretraining backed by cutlass
2141
+ """
2142
+
2143
+ def preprocess(self, x, w):
2144
+ # Get K values for each group
2145
+ k_values = [xi.shape[1] for xi in x] # K dimension for each group
2146
+
2147
+ # Convert k_values into sizes tensor
2148
+ k_sizes = torch.tensor(k_values).to(dtype=torch.int64, device=x[0].device)
2149
+
2150
+ x = torch.concat(x, dim=1).contiguous() # shape: (M, G*K)
2151
+ w = torch.concat(w, dim=1).contiguous() # shape: (N, G*K)
2152
+
2153
+ # Transpose the tensors to simulate wgrad shapes.
2154
+ x = x.t().contiguous() # shape: (G*K, M)
2155
+ w = w.t().contiguous() # shape: (G*K, N)
2156
+
2157
+ # Return processed tensors
2158
+ return x, w, k_sizes
2159
+
2160
+ def quantize(self, x, w, k_sizes):
2161
+ return x, w, k_sizes
2162
+
2163
+ def compute(self, x, w, k_sizes):
2164
+ return torch.ops.fbgemm.bf16bf16bf16_grouped_wgrad(x, w, k_sizes)
2165
+
2166
+ def quantize_and_compute(self, x, w, k_sizes):
2167
+ x, w, k_sizes = self.quantize(x, w, k_sizes)
2168
+ return self.compute(x, w, k_sizes)
2169
+
2170
+ @property
2171
+ def name(self) -> str:
2172
+ return "bf16_grouped_wgrad"
2173
+
2174
+ @property
2175
+ def hip(self) -> bool:
2176
+ return False
2177
+
2178
+ @property
2179
+ def cuda(self) -> bool:
2180
+ return True
2181
+
2182
+
2183
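# Illustrative sketch (not part of the packaged module): a plain PyTorch reference for
# the wgrad-style grouped layout above, where x is (G*K, M), w is (G*K, N), and
# k_sizes holds the contracting size K of each group. The per-group output is assumed
# to be (M, N); the helper name is hypothetical.
import torch

def grouped_wgrad_reference(x: torch.Tensor, w: torch.Tensor, k_sizes: torch.Tensor) -> torch.Tensor:
    outs = []
    start = 0
    for k in k_sizes.tolist():
        outs.append(x[start : start + k].t() @ w[start : start + k])  # (M, N)
        start += k
    return torch.stack(outs, dim=0)  # (G, M, N)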
+ @register_quantize_op
2184
+ class BF16GroupedStacked(QuantizeOpBase):
2185
+ """
2186
+ BF16 grouped matmul with stacked inputs backed by cutlass or ck.
2187
+ """
2188
+
2189
+ def preprocess(self, x, w):
2190
+ m_values = [i.shape[0] for i in x]
2191
+ # Convert m_values into offsets into grouped tensor.
2192
+ m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device)
2193
+ # Group weights as single tensor.
2194
+ w = torch.stack(w, dim=0).contiguous()
2195
+ # Also view input as flattened.
2196
+ x = torch.concat(x, dim=0).contiguous()
2197
+ # Return processed tensors.
2198
+ return x, w, m_sizes
2199
+
2200
+ def quantize(self, x, w, m_sizes):
2201
+ return x, w, m_sizes
2202
+
2203
+ def compute(self, x, w, m_sizes):
2204
+ return torch.ops.fbgemm.bf16bf16bf16_grouped_stacked(x, w, m_sizes)
2205
+
2206
+ def quantize_and_compute(self, x, w, m_sizes):
2207
+ x, w, m_sizes = self.quantize(x, w, m_sizes)
2208
+ return self.compute(x, w, m_sizes)
2209
+
2210
+ @property
2211
+ def name(self) -> str:
2212
+ return "bf16_grouped_stacked"
2213
+
2214
+ @property
2215
+ def hip(self) -> bool:
2216
+ return True
2217
+
2218
+ @property
2219
+ def cuda(self) -> bool:
2220
+ return True
2221
+
2222
+
2223
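# Illustrative sketch (not part of the packaged module): a plain PyTorch reference for
# the stacked grouped layout used above, where x holds all groups concatenated along
# dim 0, w is (G, N, K), and m_sizes[i] is the row count of group i. The helper name is
# hypothetical.
import torch

def grouped_stacked_reference(x: torch.Tensor, w: torch.Tensor, m_sizes: torch.Tensor) -> torch.Tensor:
    outs = []
    start = 0
    for g, m in enumerate(m_sizes.tolist()):
        outs.append(x[start : start + m] @ w[g].t())  # (m, N)
        start += m
    return torch.cat(outs, dim=0)  # (sum(m_sizes), N)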
+ @register_quantize_op
2224
+ class BF16I4RowwiseGemm(F8I4RowwiseGemm):
2225
+ """
2226
+ Mixed Precision BF16 Activations with Int4 Weights.
2227
+ """
2228
+
2229
+ def quantize(self, x, w):
2230
+ # Quantize both input tensors.
2231
+ wq, w_scale, w_zp = self._int4_row_quantize(w)
2232
+ # Pack int4 values together.
2233
+ wq = self._pack_int4(wq)
2234
+ return (
2235
+ x.to(torch.bfloat16),
2236
+ wq,
2237
+ w_scale,
2238
+ w_zp,
2239
+ )
2240
+
2241
+ def compute(self, x, wq, w_scale, w_zp):
2242
+ return torch.ops.fbgemm.bf16i4bf16_rowwise(x, wq, w_scale, w_zp)
2243
+
2244
+ def quantize_and_compute(self, x, w):
2245
+ x, wq, w_scale, w_zp = self.quantize(x, w)
2246
+ return self.compute(x, wq, w_scale, w_zp)
2247
+
2248
+ @property
2249
+ def name(self) -> str:
2250
+ return "cutlass_bf16i4_rowwise"
2251
+
2252
+ @property
2253
+ def hip(self) -> bool:
2254
+ # Not yet supported on AMD.
2255
+ return False
2256
+
2257
+ @property
2258
+ def cuda(self) -> bool:
2259
+ return True
2260
+
2261
+
2262
+ @register_quantize_op
2263
+ class TinyGemmBF16I4(QuantizeOpBase):
2264
+ """
2265
+ Mixed Precision BF16 Activations with Int4 Weights using tinygemm.
2266
+ """
2267
+
2268
+ def quantize(self, x, w):
2269
+ # Quantize and pack weights to int4 using tinygemm utils.
2270
+ w_int32, w_scales_and_zeros = group_quantize_tensor(
2271
+ w, n_bit=4, q_group_size=128
2272
+ )
2273
+ wq = torch.ops.tinygemm.convert_matrix_to_m16n8k16_Aint4_layout(w_int32, 4)
2274
+ return x, wq, w_scales_and_zeros
2275
+
2276
+ def compute(self, x, wq, scale):
2277
+ return torch.ops.tinygemm.tinygemm_y_f16RM_x_f16RM_w_int4TC(
2278
+ wq, x, 128, scale, False
2279
+ )
2280
+
2281
+ def quantize_and_compute(self, x, w):
2282
+ x, wq, scale = self.quantize(x, w)
2283
+ return self.compute(x, wq, scale)
2284
+
2285
+ @property
2286
+ def name(self) -> str:
2287
+ return "tinygemm_bf16i4"
2288
+
2289
+ @property
2290
+ def hip(self) -> bool:
2291
+ # Tinygemm only supported for cuda.
2292
+ return False
2293
+
2294
+ @property
2295
+ def cuda(self) -> bool:
2296
+ # Only enabled if import works.
2297
+ return TINYGEMM_ENABLED
2298
+
2299
+
2300
+ @register_quantize_op
2301
+ class MarlinBF16I4(QuantizeOpBase):
2302
+ """
2303
+ Mixed Precision BF16 Activations with Int4 Weights using Marlin.
2304
+ """
2305
+
2306
+ def quantize(self, x, w):
2307
+ # Marlin quantize expects weights in [K, N] layout.
2308
+ _, wq, scale = marlin_quantize(w.t().contiguous(), 128)
2309
+ return x, wq, scale
2310
+
2311
+ def compute(self, x, wq, scale):
2312
+ return torch.ops.marlin.marlin_gemm(x, wq, scale)
2313
+
2314
+ def quantize_and_compute(self, x, w):
2315
+ x, wq, scale = self.quantize(x, w)
2316
+ return self.compute(x, wq, scale)
2317
+
2318
+ @property
2319
+ def name(self) -> str:
2320
+ return "marlin_bf16i4"
2321
+
2322
+ @property
2323
+ def hip(self) -> bool:
2324
+ # Marlin only supported for cuda.
2325
+ return False
2326
+
2327
+ @property
2328
+ def cuda(self) -> bool:
2329
+ # This op is not always supported.
2330
+ return MARLIN_ENABLED
2331
+
2332
+
2333
+ @register_quantize_op
2334
+ class MacheteBF16I4(QuantizeOpBase):
2335
+ """
2336
+ Mixed Precision BF16 Activations with Int4 Weights using Machete.
2337
+ """
2338
+
2339
+ def quantize(self, x, w):
2340
+ # Machete quantize expects weights in [K, N] layout.
2341
+ _, wq, scale, _ = machete_quantize_and_pack(
2342
+ w.t().contiguous(), bits=4, groupsize=128
2343
+ )
2344
+ return x, wq, scale
2345
+
2346
+ def compute(self, x, wq, scale):
2347
+ return machete_gemm(x, wq, bits=4, groupsize=128, scales=scale)
2348
+
2349
+ def quantize_and_compute(self, x, w):
2350
+ x, wq, scale = self.quantize(x, w)
2351
+ return self.compute(x, wq, scale)
2352
+
2353
+ @property
2354
+ def name(self) -> str:
2355
+ return "machete_bf16i4"
2356
+
2357
+ @property
2358
+ def hip(self) -> bool:
2359
+ # Machete only supported for cuda.
2360
+ return False
2361
+
2362
+ @property
2363
+ def cuda(self) -> bool:
2364
+ # This op is not always supported.
2365
+ return MACHETE_ENABLED
2366
+
2367
+
2368
+ @register_quantize_op
2369
+ class NVFP4Gemm(QuantizeOpBase):
2370
+ """
2371
+ NVFP4 matmul with block-wise scaling.
2372
+ """
2373
+
2374
+ def quantize(self, x, w):
2375
+ x_global_scale = (448.0 * 6.0) / torch.amax(torch.abs(x.flatten()), dim=-1).to(
2376
+ torch.float32
2377
+ )
2378
+ w_global_scale = (448.0 * 6.0) / torch.amax(torch.abs(w.flatten()), dim=-1).to(
2379
+ torch.float32
2380
+ )
2381
+ global_scale = 1 / (x_global_scale * w_global_scale)
2382
+
2383
+ xq, x_scale = triton_scale_nvfp4_quant(x, x_global_scale)
2384
+ wq, w_scale = triton_scale_nvfp4_quant(w, w_global_scale)
2385
+
2386
+ return xq, wq, x_scale, w_scale, global_scale
2387
+
2388
+ def compute(self, xq, wq, x_scale, w_scale, global_scale):
2389
+ return torch.ops.fbgemm.f4f4bf16(
2390
+ xq, wq, x_scale, w_scale, global_scale=global_scale
2391
+ )
2392
+
2393
+ def quantize_and_compute(self, x, w):
2394
+ xq, wq, x_scale, w_scale, global_scale = self.quantize(x, w)
2395
+ return self.compute(xq, wq, x_scale, w_scale, global_scale=global_scale)
2396
+
2397
+ @property
2398
+ def name(self) -> str:
2399
+ return "cutlass_nv_f4f4bf16"
2400
+
2401
+ @property
2402
+ def hip(self) -> bool:
2403
+ # F4F4BF16 only supported for cuda.
2404
+ return False
2405
+
2406
+ @property
2407
+ def cuda(self) -> bool:
2408
+ return True
2409
+
2410
+
2411
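# Illustrative sketch (not part of the packaged module): where the 448.0 * 6.0 constant
# used above comes from. 448 is the max of float8_e4m3fn (the dtype of the per-block
# scales) and 6 is the largest magnitude representable in FP4 (e2m1), so
# (448 * 6) / amax(|x|) maps the largest element onto the top of the representable
# range. Variable names are illustrative only.
import torch

x = torch.randn(128, 256, dtype=torch.bfloat16)
fp8_e4m3_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0
fp4_e2m1_max = 6.0
x_global_scale = (fp8_e4m3_max * fp4_e2m1_max) / x.abs().amax().float()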
+ @register_quantize_op
2412
+ class NVFP4Quantize(QuantizeOpBase):
2413
+ """
2414
+ NVFP4 quantization with block-wise scaling.
2415
+ """
2416
+
2417
+ def quantize_rms(self, x, w):
2418
+ M, N = w.shape
2419
+ group_size = 16
2420
+ w = torch.randn(group_size, dtype=torch.bfloat16, device=w.device)
2421
+ x_global_scale = torch.tensor([448.0 * 6.0]).to(
2422
+ device=x.device, dtype=torch.float32
2423
+ ) / torch.amax(torch.abs(x.flatten()), dim=-1).to(torch.float32)
2424
+ xq_ref, x_scale_ref = triton_scale_nvfp4_quant_rms(
2425
+ x,
2426
+ w.repeat(M * N // group_size),
2427
+ x_global_scale,
2428
+ group_size=group_size,
2429
+ EPS=1e-5,
2430
+ )
2431
+
2432
+ intermediate = rms_norm(x.reshape(-1, 16), w, eps=1e-5)
2433
+ intermediate = intermediate.to(torch.bfloat16).reshape(M, N)
2434
+ xq, x_scale = triton_scale_nvfp4_quant(
2435
+ intermediate,
2436
+ x_global_scale,
2437
+ group_size=group_size,
2438
+ )
+ return xq, x_scale
2439
+
2440
+ def quantize_silu(self, x, w):
2441
+ M, N = x.shape
2442
+ group_size = 16
2443
+ x_global_scale = torch.tensor([448.0 * 6.0]).to(
2444
+ device=x.device, dtype=torch.float32
2445
+ ) / torch.amax(torch.abs(x.flatten()), dim=-1).to(torch.float32)
2446
+ xq_ref, x_scale_ref = triton_scale_nvfp4_quant_silu(
2447
+ x,
2448
+ w,
2449
+ x_global_scale,
2450
+ group_size=group_size,
2451
+ )
2452
+
2453
+ intermediate = silu_mul(x.reshape(-1, 16), w.reshape(-1, 16))
2454
+ intermediate = intermediate.to(torch.bfloat16).reshape(M, N)
2455
+ xq, x_scale = triton_scale_nvfp4_quant(
2456
+ intermediate,
2457
+ x_global_scale,
2458
+ group_size=group_size,
2459
+ )
+ return xq, x_scale
2460
+
2461
+ def quantize(self, x, w):
2462
+ x_global_scale = (448.0 * 6.0) / torch.amax(torch.abs(x.flatten()), dim=-1).to(
2463
+ torch.float32
2464
+ )
2465
+ w_global_scale = (448.0 * 6.0) / torch.amax(torch.abs(w.flatten()), dim=-1).to(
2466
+ torch.float32
2467
+ )
2468
+ global_scale = 1 / (x_global_scale * w_global_scale)
2469
+
2470
+ xq, x_scale = triton_scale_nvfp4_quant(x, x_global_scale)
2471
+ wq, w_scale = triton_scale_nvfp4_quant(w, w_global_scale)
2472
+ return xq, wq, x_scale, w_scale, global_scale
2473
+
2474
+ def compute(self, xq, wq, x_scale, w_scale, global_scale):
2475
+ return torch.ops.fbgemm.f4f4bf16(
2476
+ xq, wq, x_scale, w_scale, global_scale=global_scale
2477
+ )
2478
+
2479
+ def quantize_and_compute(self, x, w):
2480
+ return self.quantize(x, w)
2481
+
2482
+ @property
2483
+ def name(self) -> str:
2484
+ return "nvfp4_quantize"
2485
+
2486
+ @property
2487
+ def hip(self) -> bool:
2488
+ # F4F4BF16 only supported for cuda.
2489
+ return False
2490
+
2491
+ @property
2492
+ def cuda(self) -> bool:
2493
+ return True
2494
+
2495
+
2496
+ @register_quantize_op
2497
+ class MXFP4Gemm(QuantizeOpBase):
2498
+ """
2499
+ MXFP4 matmul with block-wise scaling.
2500
+ """
2501
+
2502
+ def quantize(self, x, w):
2503
+ xq, x_scale = triton_quantize_mx4_unpack(x)
2504
+ wq, w_scale = triton_quantize_mx4_unpack(w)
2505
+ return xq, wq, x_scale, w_scale
2506
+
2507
+ def compute(self, xq, wq, x_scale, w_scale):
2508
+ return torch.ops.fbgemm.f4f4bf16(xq, wq, x_scale, w_scale)
2509
+
2510
+ def quantize_and_compute(self, x, w):
2511
+ xq, wq, x_scale, w_scale = self.quantize(x, w)
2512
+ return self.compute(xq, wq, x_scale, w_scale)
2513
+
2514
+ @property
2515
+ def name(self) -> str:
2516
+ return "cutlass_f4f4bf16"
2517
+
2518
+ @property
2519
+ def hip(self) -> bool:
2520
+ # F4F4BF16 only supported for cuda.
2521
+ return False
2522
+
2523
+ @property
2524
+ def cuda(self) -> bool:
2525
+ return True
2526
+
2527
+
2528
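# Illustrative sketch (not part of the packaged module): MX-style shared-exponent
# scales over 32-element blocks, as consumed by the MXFP4 ops above. Assumes
# power-of-two (e8m0) scales chosen so each block's amax fits in FP4 (e2m1, emax = 2),
# following the OCP MX convention; the helper name is hypothetical.
import torch

def mx_block_scales(x: torch.Tensor, block: int = 32) -> torch.Tensor:
    amax = x.float().abs().reshape(-1, block).amax(dim=1).clamp(min=2.0 ** -126)
    shared_exp = torch.floor(torch.log2(amax)) - 2  # subtract emax of e2m1
    return torch.exp2(shared_exp)                   # one power-of-two scale per block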
+ @register_quantize_op
2529
+ class MXFP4StackedGroupedGemm(QuantizeOpBase):
2530
+ """
2531
+ MXFP4 grouped matmul with blockwise scaling and stacked inputs.
2532
+ """
2533
+
2534
+ def preprocess(self, x, w):
2535
+ m_values = [i.shape[0] for i in x]
2536
+ m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device)
2537
+ wq, w_scale = zip(*[triton_quantize_mx4_unpack(i) for i in w])
2538
+ wq = torch.stack(wq, dim=0).contiguous()
2539
+ w_scale = torch.stack(w_scale, dim=0).contiguous()
2540
+ return x, wq, w_scale, m_sizes
2541
+
2542
+ def quantize(self, x, wq, w_scale, m_sizes):
2543
+ starting_row_after_padding_list = [0]
2544
+ xq_list = []
2545
+ x_scale_list = []
2546
+ for i in range(m_sizes.shape[0]):
2547
+ scale_slice = x[i]
2548
+ if m_sizes[i].item() != 0:
2549
+ xq, x_scale = triton_quantize_mx4_unpack(scale_slice)
2550
+ xq_list.append(xq)
2551
+ x_scale_list.append(x_scale)
2552
+ starting_row_after_padding_list.append(
2553
+ starting_row_after_padding_list[i]
2554
+ + x_scale.numel() // (x[0].shape[1] // 32)
2555
+ )
2556
+ else:
2557
+ starting_row_after_padding_list.append(
2558
+ starting_row_after_padding_list[i]
2559
+ )
2560
+ xq = torch.cat(xq_list, dim=0).contiguous()
2561
+ x_scale = torch.cat(x_scale_list, dim=0).contiguous()
2562
+ x_scale = x_scale.reshape(-1, x[0].shape[-1] // 32)
2563
+ xq = xq.view(-1, xq.shape[-1])
2564
+ return (
2565
+ xq,
2566
+ wq,
2567
+ x_scale,
2568
+ w_scale,
2569
+ m_sizes,
2570
+ torch.tensor(starting_row_after_padding_list, device=xq.device),
2571
+ )
2572
+
2573
+ def compute(self, xq, wq, x_scale, w_scale, m_sizes, starting_row_after_padding):
2574
+ return torch.ops.fbgemm.f4f4bf16_grouped_stacked(
2575
+ xq,
2576
+ wq,
2577
+ x_scale,
2578
+ w_scale,
2579
+ m_sizes,
2580
+ starting_row_after_padding=starting_row_after_padding,
2581
+ )
2582
+
2583
+ def quantize_and_compute(self, x, wq, w_scale, m_sizes):
2584
+ xq, wq, x_scale, w_scale, m_sizes, starting_row_after_padding = self.quantize(
2585
+ x, wq, w_scale, m_sizes
2586
+ )
2587
+ return self.compute(
2588
+ xq, wq, x_scale, w_scale, m_sizes, starting_row_after_padding
2589
+ )
2590
+
2591
+ @property
2592
+ def name(self) -> str:
2593
+ return "cutlass_f4f4bf16_grouped_stacked"
2594
+
2595
+ @property
2596
+ def hip(self) -> bool:
2597
+ return False
2598
+
2599
+ @property
2600
+ def cuda(self) -> bool:
2601
+ return True
2602
+
2603
+
2604
+ @register_quantize_op
2605
+ class MXFP4GroupedGemm2D3D(QuantizeOpBase):
2606
+ """
2607
+ MXFP4 grouped GEMM with blockwise scaling and Torch 2D3D API.
2608
+ """
2609
+
2610
+ def preprocess(self, xs, ws):
2611
+ m_sizes = [x.shape[0] for x in xs]
2612
+ m_offsets = torch.cumsum(torch.tensor(m_sizes), dim=0).to(
2613
+ dtype=torch.int32, device=xs[0].device
2614
+ )
2615
+
2616
+ wqs = []
2617
+ w_scales = []
2618
+ for w in ws:
2619
+ wq, w_scale = triton_quantize_mx4_unpack(w)
2620
+ wqs.append(wq)
2621
+ w_scales.append(w_scale)
2622
+
2623
+ wq = torch.stack(wqs, dim=0)
2624
+ w_scale = torch.stack(w_scales, dim=0)
2625
+
2626
+ return xs, wq, w_scale, m_offsets
2627
+
2628
+ def quantize(self, xs, wq, w_scale, m_offsets):
2629
+ xqs = []
2630
+ x_scales = []
2631
+ for x in xs:
2632
+ xq, x_scale = triton_quantize_mx4_unpack(x)
2633
+ xqs.append(xq)
2634
+ x_scales.append(x_scale)
2635
+
2636
+ xq = torch.cat(xqs, dim=0)
2637
+ x_scale = torch.stack(x_scales, dim=0)
2638
+
2639
+ xq = xq.view(torch.float4_e2m1fn_x2)
2640
+ wq = wq.view(torch.float4_e2m1fn_x2)
2641
+ x_scale = x_scale.view(torch.float8_e8m0fnu)
2642
+ w_scale = w_scale.view(torch.float8_e8m0fnu)
2643
+
2644
+ return xq, wq, x_scale, w_scale, m_offsets
2645
+
2646
+ def compute(
2647
+ self,
2648
+ xq,
2649
+ wq,
2650
+ x_scale,
2651
+ w_scale,
2652
+ m_offsets,
2653
+ ):
2654
+ return torch.ops.fbgemm.f4f4bf16_grouped_mm(
2655
+ xq,
2656
+ wq.transpose(-2, -1),
2657
+ x_scale,
2658
+ w_scale,
2659
+ m_offsets,
2660
+ )
2661
+
2662
+ def quantize_and_compute(self, xs, wq, w_scale, m_offsets):
2663
+ args = self.quantize(xs, wq, w_scale, m_offsets)
2664
+ return self.compute(*args)
2665
+
2666
+ @property
2667
+ def name(self) -> str:
2668
+ return "cutlass_mx_f4f4bf16_grouped_mm_2d_3d"
2669
+
2670
+ @property
2671
+ def cuda(self) -> bool:
2672
+ return True
2673
+
2674
+ @property
2675
+ def hip(self) -> bool:
2676
+ return False
2677
+
2678
+
2679
+ @register_quantize_op
2680
+ class MXFP4GroupedGemm2D2D(QuantizeOpBase):
2681
+ """
2682
+ MXFP4 grouped GEMM with blockwise scaling and Torch 2D2D API.
2683
+ """
2684
+
2685
+ def preprocess(self, xs, ws):
2686
+ k_sizes = [x.shape[1] for x in xs]
2687
+ k_offsets = torch.cumsum(torch.tensor(k_sizes), dim=0).to(
2688
+ dtype=torch.int32, device=xs[0].device
2689
+ )
2690
+
2691
+ wqs = []
2692
+ w_scales = []
2693
+ for w in ws:
2694
+ wq, w_scale = triton_quantize_mx4_unpack(w)
2695
+ wqs.append(wq)
2696
+ w_scales.append(w_scale)
2697
+
2698
+ wq = torch.cat(wqs, dim=1)
2699
+ w_scale = torch.stack(w_scales, dim=0)
2700
+
2701
+ return xs, wq, w_scale, k_offsets
2702
+
2703
+ def quantize(self, xs, wq, w_scale, k_offsets):
2704
+ xqs = []
2705
+ x_scales = []
2706
+ for x in xs:
2707
+ xq, x_scale = triton_quantize_mx4_unpack(x)
2708
+ xqs.append(xq)
2709
+ x_scales.append(x_scale)
2710
+
2711
+ xq = torch.cat(xqs, dim=1)
2712
+ x_scale = torch.stack(x_scales, dim=0)
2713
+
2714
+ xq = xq.view(torch.float4_e2m1fn_x2)
2715
+ wq = wq.view(torch.float4_e2m1fn_x2)
2716
+ x_scale = x_scale.view(torch.float8_e8m0fnu)
2717
+ w_scale = w_scale.view(torch.float8_e8m0fnu)
2718
+
2719
+ return xq, wq, x_scale, w_scale, k_offsets
2720
+
2721
+ def compute(
2722
+ self,
2723
+ xq,
2724
+ wq,
2725
+ x_scale,
2726
+ w_scale,
2727
+ k_offsets,
2728
+ ):
2729
+ return torch.ops.fbgemm.f4f4bf16_grouped_mm(
2730
+ xq,
2731
+ wq.transpose(-2, -1),
2732
+ x_scale,
2733
+ w_scale,
2734
+ k_offsets,
2735
+ )
2736
+
2737
+ def quantize_and_compute(self, xs, wq, w_scale, k_offsets):
2738
+ args = self.quantize(xs, wq, w_scale, k_offsets)
2739
+ return self.compute(*args)
2740
+
2741
+ @property
2742
+ def name(self) -> str:
2743
+ return "cutlass_mx_f4f4bf16_grouped_mm_2d_2d"
2744
+
2745
+ @property
2746
+ def cuda(self) -> bool:
2747
+ return True
2748
+
2749
+ @property
2750
+ def hip(self) -> bool:
2751
+ return False
2752
+
2753
+
2754
+ @register_quantize_op
2755
+ class NVFP4GroupedGemm2D3D(QuantizeOpBase):
2756
+ """
2757
+ NVFP4 grouped GEMM with blockwise scaling and Torch 2D3D API.
2758
+ """
2759
+
2760
+ def preprocess(self, x, w):
2761
+ m_values = [i.shape[0] for i in x]
2762
+ m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device)
2763
+ x = torch.concat(x, dim=0).contiguous()
2764
+
2765
+ def get_global_scale(x, w, m_sizes):
2766
+ G = len(w)
2767
+ w_global_scale = []
2768
+ global_scale = []
2769
+
2770
+ x_global_scale, tensor_idx = calculate_group_max(x, m_sizes=m_sizes)
2771
+
2772
+ for i in range(G):
2773
+ w_global_scale_ = (448.0 * 6.0) / torch.amax(
2774
+ torch.abs(w[i].flatten()), dim=-1
2775
+ ).to(torch.float32)
2776
+
2777
+ global_scale_ = 1 / (x_global_scale[i] * w_global_scale_)
2778
+
2779
+ w_global_scale.append(w_global_scale_)
2780
+ global_scale.append(global_scale_)
2781
+
2782
+ return x_global_scale, w_global_scale, global_scale, tensor_idx
2783
+
2784
+ # Compute global scale for each group
2785
+ G = m_sizes.numel()
2786
+ x_global_scale, w_global_scale, global_scale, tensor_idx = get_global_scale(
2787
+ x, w, m_sizes
2788
+ )
2789
+ global_scale = torch.stack(global_scale, dim=0).contiguous()
2790
+
2791
+ wq, w_scale = zip(
2792
+ *[triton_scale_nvfp4_quant(w[i], w_global_scale[i]) for i in range(G)]
2793
+ )
2794
+ wq = torch.stack(wq, dim=0).contiguous()
2795
+ w_scale = torch.stack(w_scale, dim=0).contiguous()
2796
+
2797
+ return x, wq, w_scale, x_global_scale, global_scale, m_sizes, tensor_idx
2798
+
2799
+ def quantize(
2800
+ self, x, wq, w_scale, x_global_scale, global_scale, m_sizes, tensor_idx
2801
+ ):
2802
+ xq, x_scale, _ = mega_fp4_quantize_kernel(
2803
+ m_sizes, x, x_global_scale, optional_tensor_idx=tensor_idx
2804
+ )
2805
+
2806
+ x_scale = x_scale.reshape(-1, x.shape[1] // 16)
2807
+ offsets = torch.cumsum(m_sizes, dim=0).to(torch.int32)
2808
+
2809
+ xq = xq.view(torch.float4_e2m1fn_x2)
2810
+ wq = wq.view(torch.float4_e2m1fn_x2)
2811
+ x_scale = x_scale.view(torch.float8_e4m3fn)
2812
+ w_scale = w_scale.view(torch.float8_e4m3fn)
2813
+
2814
+ return (
2815
+ xq,
2816
+ wq.transpose(-2, -1),
2817
+ x_scale,
2818
+ w_scale,
2819
+ offsets,
2820
+ None,
2821
+ global_scale,
2822
+ )
2823
+
2824
+ def compute(
2825
+ self,
2826
+ xq,
2827
+ wq,
2828
+ x_scale,
2829
+ w_scale,
2830
+ offsets,
2831
+ output,
2832
+ global_scale,
2833
+ ):
2834
+ return torch.ops.fbgemm.f4f4bf16_grouped_mm(
2835
+ xq,
2836
+ wq,
2837
+ x_scale,
2838
+ w_scale,
2839
+ offsets,
2840
+ output,
2841
+ global_scale,
2842
+ )
2843
+
2844
+ def quantize_and_compute(self, x, wq, w_scale, x_global_scale, global_scale, m_sizes, tensor_idx):
2845
+ args = self.quantize(x, wq, w_scale, x_global_scale, global_scale, m_sizes, tensor_idx)
2846
+ return self.compute(*args)
2847
+
2848
+ @property
2849
+ def name(self) -> str:
2850
+ return "cutlass_nv_f4f4bf16_grouped_mm_2d_3d"
2851
+
2852
+ @property
2853
+ def hip(self) -> bool:
2854
+ return False
2855
+
2856
+ @property
2857
+ def cuda(self) -> bool:
2858
+ return True
2859
+
2860
+
2861
+ @register_quantize_op
2862
+ class NVFP4GroupedGemm2D2D(QuantizeOpBase):
2863
+ """
2864
+ NVFP4 grouped GEMM with blockwise scaling and Torch 2D2D API.
2865
+ """
2866
+
2867
+ def preprocess(self, xs, ws):
2868
+ k_sizes = [x.shape[1] for x in xs]
2869
+ k_offsets = torch.cumsum(torch.tensor(k_sizes), dim=0).to(
2870
+ dtype=torch.int32, device=xs[0].device
2871
+ )
2872
+
2873
+ global_scales, x_global_scales, w_global_scales = get_nvfp4_global_scales_naive(
2874
+ xs, ws
2875
+ )
2876
+ wqs, w_scales = quantize_nvfp4_naive(ws, w_global_scales)
2877
+ wq = torch.cat(wqs, dim=1).view(torch.float4_e2m1fn_x2)
2878
+ w_scale = (
2879
+ torch.stack(w_scales, dim=0)
2880
+ .reshape(round_up(wq.size(0), 128), -1)
2881
+ .view(torch.float8_e4m3fn)
2882
+ )
2883
+ global_scale = torch.stack(global_scales, dim=0)
2884
+
2885
+ return xs, wq, w_scale, global_scale, x_global_scales, k_offsets
2886
+
2887
+ def quantize(self, xs, wq, w_scale, global_scale, x_global_scales, k_offsets):
2888
+ xqs, x_scales = quantize_nvfp4_naive(xs, x_global_scales)
2889
+ xq = torch.cat(xqs, dim=1).view(torch.float4_e2m1fn_x2)
2890
+ x_scale = (
2891
+ torch.stack(x_scales, dim=0)
2892
+ .reshape(round_up(xq.size(0), 128), -1)
2893
+ .view(torch.float8_e4m3fn)
2894
+ )
2895
+
2896
+ return xq, wq, x_scale, w_scale, k_offsets, global_scale
2897
+
2898
+ def compute(
2899
+ self,
2900
+ xq,
2901
+ wq,
2902
+ x_scale,
2903
+ w_scale,
2904
+ k_offsets,
2905
+ global_scale,
2906
+ ):
2907
+ return torch.ops.fbgemm.f4f4bf16_grouped_mm(
2908
+ xq,
2909
+ wq.transpose(-2, -1),
2910
+ x_scale,
2911
+ w_scale,
2912
+ k_offsets,
2913
+ None,
2914
+ global_scale,
2915
+ )
2916
+
2917
+ def quantize_and_compute(self, xs, wq, w_scale, global_scale, x_global_scales, k_offsets):
2918
+ args = self.quantize(xs, wq, w_scale, global_scale, x_global_scales, k_offsets)
2919
+ return self.compute(*args)
2920
+
2921
+ @property
2922
+ def name(self) -> str:
2923
+ return "cutlass_nv_f4f4bf16_grouped_mm_2d_2d"
2924
+
2925
+ @property
2926
+ def hip(self) -> bool:
2927
+ return False
2928
+
2929
+ @property
2930
+ def cuda(self) -> bool:
2931
+ return True
2932
+
2933
+
2934
+ @register_quantize_op
2935
+ class NVFP4StackedGroupedGemm(QuantizeOpBase):
2936
+ """
2937
+ NVFP4 grouped matmul with blockwise scaling and stacked inputs.
2938
+ """
2939
+
2940
+ def preprocess(self, x, w):
2941
+ m_values = [i.shape[0] for i in x]
2942
+ m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device)
2943
+ x = torch.concat(x, dim=0).contiguous()
2944
+
2945
+ def get_global_scale(x, w, m_sizes):
2946
+ G = len(w)
2947
+ w_global_scale = []
2948
+ global_scale = []
2949
+
2950
+ cumulative_sum = torch.zeros(
2951
+ m_sizes.shape[0] + 1, dtype=torch.int64, device=m_sizes.device
2952
+ )
2953
+ cumulative_sum[1:] = torch.cumsum(m_sizes, dim=0)
2954
+
2955
+ x_global_scale, tensor_idx = calculate_group_max(x, m_sizes=m_sizes)
2956
+
2957
+ for i in range(G):
2958
+ w_global_scale_ = (448.0 * 6.0) / torch.amax(
2959
+ torch.abs(w[i].flatten()), dim=-1
2960
+ ).to(torch.float32)
2961
+
2962
+ global_scale_ = 1 / (x_global_scale[i] * w_global_scale_)
2963
+
2964
+ w_global_scale.append(w_global_scale_)
2965
+ global_scale.append(global_scale_)
2966
+
2967
+ return x_global_scale, w_global_scale, global_scale, tensor_idx
2968
+
2969
+ # Compute global scale for each group
2970
+ G = m_sizes.numel()
2971
+ x_global_scale, w_global_scale, global_scale, tensor_idx = get_global_scale(
2972
+ x, w, m_sizes
2973
+ )
2974
+ global_scale = torch.stack(global_scale, dim=0).contiguous()
2975
+
2976
+ wq, w_scale = zip(
2977
+ *[triton_scale_nvfp4_quant(w[i], w_global_scale[i]) for i in range(G)]
2978
+ )
2979
+ wq = torch.stack(wq, dim=0).contiguous()
2980
+ w_scale = torch.stack(w_scale, dim=0).contiguous()
2981
+
2982
+ return x, wq, w_scale, x_global_scale, global_scale, m_sizes, tensor_idx
2983
+
2984
+ def quantize(
2985
+ self, x, wq, w_scale, x_global_scale, global_scale, m_sizes, tensor_idx
2986
+ ):
2987
+ # Alternative quantization path that may be useful in some scenarios:
2988
+ """
2989
+ starting_row_after_padding, belong_indices, row_within_tensor = (
2990
+ nvfp4_fused_padding_cumsum_and_segmented_arange(m_sizes, x.shape[0])
2991
+ # fused_single_block_cumsum_and_segmented_arange(m_sizes, x.shape[0])
2992
+ )
2993
+
2994
+ xq, x_scale = triton_nvfp4_quant_stacked(
2995
+ x,
2996
+ x_global_scale[0],
2997
+ belong_indices,
2998
+ starting_row_after_padding,
2999
+ row_within_tensor,
3000
+ )
3001
+ """
3002
+
3003
+ # we can optionally set optional_tensor_idx to None to run the alternative method
3004
+ xq, x_scale, starting_row_after_padding = mega_fp4_quantize_kernel(
3005
+ m_sizes, x, x_global_scale, optional_tensor_idx=tensor_idx
3006
+ )
3007
+
3008
+ x_scale = x_scale.reshape(-1, x.shape[1] // 16)
3009
+ return (
3010
+ xq,
3011
+ wq,
3012
+ x_scale,
3013
+ w_scale,
3014
+ m_sizes,
3015
+ global_scale,
3016
+ starting_row_after_padding,
3017
+ )
3018
+
3019
+ def compute(
3020
+ self,
3021
+ xq,
3022
+ wq,
3023
+ x_scale,
3024
+ w_scale,
3025
+ m_sizes,
3026
+ global_scale,
3027
+ starting_row_after_padding,
3028
+ ):
3029
+ gemm_result = torch.ops.fbgemm.f4f4bf16_grouped_stacked(
3030
+ xq,
3031
+ wq,
3032
+ x_scale,
3033
+ w_scale,
3034
+ m_sizes,
3035
+ global_scale,
3036
+ starting_row_after_padding,
3037
+ use_mx=False,
3038
+ )
3039
+ return gemm_result
3040
+
3041
+ def quantize_and_compute(
3042
+ self, x, wq, w_scale, x_global_scale, global_scale, m_sizes, tensor_idx
3043
+ ):
3044
+ (
3045
+ xq,
3046
+ wq,
3047
+ x_scale,
3048
+ w_scale,
3049
+ m_sizes,
3050
+ global_scale,
3051
+ starting_row_after_padding,
3052
+ ) = self.quantize(
3053
+ x, wq, w_scale, x_global_scale, global_scale, m_sizes, tensor_idx
3054
+ )
3055
+ return self.compute(
3056
+ xq,
3057
+ wq,
3058
+ x_scale,
3059
+ w_scale,
3060
+ m_sizes,
3061
+ global_scale,
3062
+ starting_row_after_padding,
3063
+ )
3064
+
3065
+ @property
3066
+ def name(self) -> str:
3067
+ return "cutlass_nv_f4f4bf16_grouped_stacked"
3068
+
3069
+ @property
3070
+ def hip(self) -> bool:
3071
+ return False
3072
+
3073
+ @property
3074
+ def cuda(self) -> bool:
3075
+ return True
3076
+
3077
+
3078
+ @register_quantize_op
3079
+ class NVFP4StackedGroupedGemmPackUnpack(QuantizeOpBase):
3080
+ """
3081
+ NVFP4 grouped matmul with blockwise scaling and stacked inputs.
3082
+ """
3083
+
3084
+ def preprocess(self, x, w):
3085
+ m_values = [i.shape[0] for i in x]
3086
+ m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device)
3087
+ x = torch.concat(x, dim=0).contiguous()
3088
+
3089
+ def get_global_scale(x, w):
3090
+ G = len(w)
3091
+ x_global_scale = []
3092
+ w_global_scale = []
3093
+ global_scale = []
3094
+
3095
+ x_global_scale_ = (448.0 * 6.0) / torch.amax(
3096
+ torch.abs(x.flatten()), dim=-1
3097
+ ).to(torch.float32)
3098
+
3099
+ for i in range(G):
3100
+ w_global_scale_ = (448.0 * 6.0) / torch.amax(
3101
+ torch.abs(w[i].flatten()), dim=-1
3102
+ ).to(torch.float32)
3103
+
3104
+ global_scale_ = 1 / (x_global_scale_ * w_global_scale_)
3105
+
3106
+ x_global_scale.append(x_global_scale_)
3107
+ w_global_scale.append(w_global_scale_)
3108
+ global_scale.append(global_scale_)
3109
+
3110
+ return x_global_scale, w_global_scale, global_scale
3111
+
3112
+ # Compute global scale for each group
3113
+ G = m_sizes.numel()
3114
+ x_global_scale, w_global_scale, global_scale = get_global_scale(x, w)
3115
+
3116
+ global_scale = torch.stack(global_scale, dim=0).contiguous()
3117
+
3118
+ wq, w_scale = zip(
3119
+ *[triton_scale_nvfp4_quant(w[i], w_global_scale[i]) for i in range(G)]
3120
+ )
3121
+ wq = torch.stack(wq, dim=0).contiguous()
3122
+ w_scale = torch.stack(w_scale, dim=0).contiguous()
3123
+ x_global_scale = torch.tensor(x_global_scale, device=m_sizes.device)
3124
+ return (
3125
+ x,
3126
+ wq,
3127
+ w_scale,
3128
+ x_global_scale,
3129
+ global_scale,
3130
+ m_sizes,
3131
+ )
3132
+
3133
+ def quantize(self, x, wq, w_scale, x_global_scale, global_scale, m_sizes):
3134
+ # Alternative packing method that uses only the overall global scale rather than per-tensor scales:
3135
+ """
3136
+ packed = mega_fp4_pack(x, x_global_scale[0])
3137
+ """
3138
+ packed = mega_fp4_pack(
3139
+ x,
3140
+ x_global_scale,
3141
+ per_tensor=True,
3142
+ m_sizes=m_sizes,
3143
+ )
3144
+ xq, x_scale, starting_row_after_padding = mega_fp4_unpack(m_sizes, packed)
3145
+ xq_other, x_scale_other, starting_row_after_padding_other = (
3146
+ mega_fp4_quantize_kernel(
3147
+ m_sizes,
3148
+ x,
3149
+ x_global_scale,
3150
+ )
3151
+ )
3152
+
3153
+ x_scale = x_scale.reshape(-1, x.shape[1] // 16)
3154
+ x_scale_other = x_scale_other.reshape(-1, x.shape[1] // 16)
3155
+ return (
3156
+ xq,
3157
+ wq,
3158
+ x_scale,
3159
+ w_scale,
3160
+ m_sizes,
3161
+ global_scale,
3162
+ starting_row_after_padding,
3163
+ xq_other,
3164
+ x_scale_other,
3165
+ starting_row_after_padding_other,
3166
+ )
3167
+
3168
+ def compute(
3169
+ self,
3170
+ xq,
3171
+ wq,
3172
+ x_scale,
3173
+ w_scale,
3174
+ m_sizes,
3175
+ global_scale,
3176
+ starting_row_after_padding,
3177
+ xq_other,
3178
+ x_scale_other,
3179
+ starting_row_after_padding_other,
3180
+ ):
3181
+ ref_solution = torch.ops.fbgemm.f4f4bf16_grouped_stacked(
3182
+ xq_other,
3183
+ wq,
3184
+ x_scale_other,
3185
+ w_scale,
3186
+ m_sizes,
3187
+ global_scale,
3188
+ starting_row_after_padding_other,
3189
+ use_mx=False,
3190
+ )
3191
+ gemm_result = torch.ops.fbgemm.f4f4bf16_grouped_stacked(
3192
+ xq,
3193
+ wq,
3194
+ x_scale,
3195
+ w_scale,
3196
+ m_sizes,
3197
+ global_scale,
3198
+ starting_row_after_padding,
3199
+ use_mx=False,
3200
+ )
3201
+ assert torch.allclose(ref_solution, gemm_result)
3202
+
3203
+ return gemm_result
3204
+
3205
+ def quantize_and_compute(
3206
+ self, x, wq, w_scale, x_global_scale, global_scale, m_sizes
3207
+ ):
3208
+ (
3209
+ xq,
3210
+ wq,
3211
+ x_scale,
3212
+ w_scale,
3213
+ m_sizes,
3214
+ global_scale,
3215
+ starting_row_after_padding,
3216
+ xq_other,
3217
+ x_scale_other,
3218
+ starting_row_after_padding_other,
3219
+ ) = self.quantize(x, wq, w_scale, x_global_scale, global_scale, m_sizes)
3220
+ return self.compute(
3221
+ xq,
3222
+ wq,
3223
+ x_scale,
3224
+ w_scale,
3225
+ m_sizes,
3226
+ global_scale,
3227
+ starting_row_after_padding,
3228
+ xq_other,
3229
+ x_scale_other,
3230
+ starting_row_after_padding_other,
3231
+ )
3232
+
3233
+ @property
3234
+ def name(self) -> str:
3235
+ return "cutlass_nv_f4f4bf16_grouped_stacked_pack_unpack"
3236
+
3237
+ @property
3238
+ def hip(self) -> bool:
3239
+ return False
3240
+
3241
+ @property
3242
+ def cuda(self) -> bool:
3243
+ return True
3244
+
3245
+
3246
+ @register_quantize_op
3247
+ class BF16GroupedGemm2d3d(QuantizeOpBase):
3248
+ """
3249
+ Torch BF16 grouped GEMM with 2D inputs and 3D weights.
3250
+ """
3251
+
3252
+ def preprocess(self, x, w):
3253
+ assert isinstance(x, list)
3254
+ assert isinstance(w, list)
3255
+ offs = torch.tensor(
3256
+ [i.shape[0] for i in x], dtype=torch.int32, device=x[0].device
3257
+ )
3258
+ offs = torch.cumsum(offs, dim=0).to(torch.int32)
3259
+ x = torch.cat(x, dim=0).contiguous() # (G * M, K)
3260
+ w = torch.stack(w, dim=0).contiguous() # (G, N, K)
3261
+ return x, w, offs
3262
+
3263
+ def quantize(self, x, w, offs):
3264
+ return x, w, offs
3265
+
3266
+ def compute(self, x, w, offs):
3267
+ return torch._grouped_mm(
3268
+ x,
3269
+ w.transpose(-2, -1),
3270
+ offs=offs,
3271
+ )
3272
+
3273
+ def quantize_and_compute(self, x, w, offs):
3274
+ x, w, offs = self.quantize(x, w, offs)
3275
+ return self.compute(x, w, offs)
3276
+
3277
+ @property
3278
+ def name(self) -> str:
3279
+ return "bf16_baseline_grouped_2d_3d"
3280
+
3281
+ @property
3282
+ def hip(self) -> bool:
3283
+ return False
3284
+
3285
+ @property
3286
+ def cuda(self) -> bool:
3287
+ return True
3288
+
3289
+
3290
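# Illustrative sketch (not part of the packaged module): the cumulative row offsets
# built in preprocess() above, e.g. per-group M values [3, 5, 2] become offs = [3, 8, 10],
# marking where each group ends in the concatenated 2D input.
import torch

m_per_group = torch.tensor([3, 5, 2], dtype=torch.int32)
offs = torch.cumsum(m_per_group, dim=0).to(torch.int32)
print(offs)  # tensor([ 3,  8, 10], dtype=torch.int32)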
+ @register_quantize_op
3291
+ class MXFP8GroupedGemm2d3d(QuantizeOpBase):
3292
+ """
3293
+ MXFP8 grouped GEMM with 2D inputs and 3D weights.
3294
+ """
3295
+
3296
+ def preprocess(self, x, w):
3297
+ assert isinstance(x, list)
3298
+ assert isinstance(w, list)
3299
+ x = torch.cat(x, dim=0).contiguous() # (G * M, K)
3300
+ w = torch.stack(w, dim=0).contiguous() # (G, N, K)
3301
+ return x, w
3302
+
3303
+ def quantize(self, x, w):
3304
+ block_size = 32
3305
+ G, N, K = w.shape
3306
+ total_M = x.shape[0]
3307
+ group_size = total_M // G
3308
+ input_group_end_offsets = torch.arange(
3309
+ group_size, total_M + 1, group_size, dtype=torch.int32, device=x.device
3310
+ )
3311
+
3312
+ # For each constituent 2d subtensor in the 3d weights, quantize and convert scale to blocked format separately,
3313
+ # as each one is used for an independent GEMM within the grouped GEMM.
3314
+ wq_list = []
3315
+ w_scale_list = []
3316
+ for i in range(G):
3317
+ w_scale, wq = to_mxfp8(w[i])
3318
+ w_scale = _to_blocked(w_scale)
3319
+ wq_list.append(wq)
3320
+ w_scale_list.append(w_scale)
3321
+ wq = torch.stack(wq_list, dim=0).contiguous()
3322
+ w_scale = torch.stack(w_scale_list, dim=0).contiguous()
3323
+
3324
+ # For each group along `total_M` in the 2D tensor, quantize and convert scale to blocked format separately,
3325
+ # as each one is used for an independent GEMM within the grouped GEMM.
3326
+ xq_list = []
3327
+ x_scale_list = []
3328
+ for i in range(G):
3329
+ prev_group_end = 0 if i == 0 else input_group_end_offsets[i - 1]
3330
+ curr_group_end = input_group_end_offsets[i]
3331
+ group_size = curr_group_end - prev_group_end
3332
+ if group_size > 0:
3333
+ x_slice = x[prev_group_end:curr_group_end, :]
3334
+ x_scale, xq = to_mxfp8(x_slice)
3335
+ x_scale = _to_blocked(x_scale)
3336
+ xq_list.append(xq)
3337
+ x_scale_list.append(x_scale)
3338
+ xq = torch.cat(xq_list, dim=0).contiguous()
3339
+ x_scale = torch.cat(x_scale_list, dim=0).contiguous()
3340
+ x_scale = x_scale.reshape(-1, K // block_size)
3341
+ xq = xq.view(-1, xq.shape[-1])
3342
+ return xq, wq, x_scale, w_scale, input_group_end_offsets
3343
+
3344
+ def compute(self, xq, wq, x_scale, w_scale, input_group_end_offsets):
3345
+ return torch.ops.fbgemm.mx8mx8bf16_grouped_mm(
3346
+ xq,
3347
+ wq.transpose(-2, -1),
3348
+ x_scale,
3349
+ w_scale,
3350
+ input_group_end_offsets,
3351
+ )
3352
+
3353
+ def quantize_and_compute(self, x, w):
3354
+ xq, wq, x_scale, w_scale, input_group_end_offsets = self.quantize(x, w)
3355
+ return self.compute(
3356
+ xq,
3357
+ wq,
3358
+ x_scale,
3359
+ w_scale,
3360
+ input_group_end_offsets,
3361
+ )
3362
+
3363
+ @property
3364
+ def name(self) -> str:
3365
+ return "cutlass_mx8mx8bf16_grouped_mm_2d_3d"
3366
+
3367
+ @property
3368
+ def hip(self) -> bool:
3369
+ return False
3370
+
3371
+ @property
3372
+ def cuda(self) -> bool:
3373
+ return True
3374
+
3375
+
3376
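# Illustrative sketch (not part of the packaged module): the `input_group_end_offsets`
# tensor built in quantize() above for G equally sized groups along total_M, e.g.
# total_M = 12 and G = 3 give end offsets [4, 8, 12].
import torch

total_M, G = 12, 3
group_size = total_M // G
input_group_end_offsets = torch.arange(
    group_size, total_M + 1, group_size, dtype=torch.int32
)
print(input_group_end_offsets)  # tensor([ 4,  8, 12], dtype=torch.int32)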
+ @register_quantize_op
3377
+ class MXFP8GroupedGemm2d2d(QuantizeOpBase):
3378
+ """
3379
+ MXFP8 grouped GEMM with 2D inputs and 2D weights.
3380
+ """
3381
+
3382
+ def preprocess(self, x, w):
3383
+ assert isinstance(x, list)
3384
+ assert isinstance(w, list)
3385
+ G = len(x)
3386
+ x = torch.cat(x, dim=1).contiguous() # (M, total_K)
3387
+ w = torch.cat(w, dim=1).contiguous() # (N, total_K)
3388
+ return x, w, G
3389
+
3390
+ def quantize(self, x, w, G):
3391
+ # Simulate 2d-2d grouped gemm in backward pass `grad_weight = grad_output_t @ input`,
3392
+ # where we use "K" as the contracting dim which has "G" groups.
3393
+ M, total_K = x.shape
3394
+ N, _ = w.shape
3395
+ group_size = total_K // G
3396
+ input_group_end_offsets = torch.arange(
3397
+ group_size, total_K + 1, group_size, dtype=torch.int32, device=x.device
3398
+ )
3399
+
3400
+ # Convert scales to blocked format.
3401
+ x_list = []
3402
+ w_list = []
3403
+ x_blocked_scale_list = []
3404
+ w_blocked_scale_list = []
3405
+
3406
+ for group_idx in range(G):
3407
+ # to_mxfp8 per group
3408
+ prev_group_end_offset = (
3409
+ 0 if group_idx == 0 else input_group_end_offsets[group_idx - 1]
3410
+ )
3411
+ curr_group_end_offset = input_group_end_offsets[group_idx]
3412
+ group_size = curr_group_end_offset - prev_group_end_offset
3413
+ if group_size > 0:
3414
+ x_slice = x[
3415
+ :, prev_group_end_offset:curr_group_end_offset
3416
+ ].contiguous() # (M, K_group)
3417
+ w_slice = w[
3418
+ :, prev_group_end_offset:curr_group_end_offset
3419
+ ].contiguous() # (N, K_group)
3420
+ x_scale_slice, xq_slice = to_mxfp8(
3421
+ x_slice
3422
+ ) # scale shape -> (M, K_group // 32)
3423
+ w_scale_slice, wq_slice = to_mxfp8(
3424
+ w_slice
3425
+ ) # scale shape -> (N, K_group // 32)
3426
+ x_list.append(xq_slice)
3427
+ w_list.append(wq_slice)
3428
+
3429
+ # Convert scales to blocked format.
3430
+ x_scale_slice_blocked = _to_blocked(
3431
+ x_scale_slice
3432
+ ) # (round_up(M, 128), round_up(K_group//32, 4))
3433
+ w_scale_slice_blocked = _to_blocked(
3434
+ w_scale_slice
3435
+ ) # (round_up(N, 128), round_up(K_group//32, 4))
3436
+ x_blocked_scale_list.append(x_scale_slice_blocked)
3437
+ w_blocked_scale_list.append(w_scale_slice_blocked)
3438
+
3439
+ # Assemble the full XQ and WQ
3440
+ xq = torch.cat(x_list, dim=1).contiguous()
3441
+ wq = torch.cat(w_list, dim=1).contiguous()
3442
+
3443
+ # Combine all XQ groups blocked scales into one tensor.
3444
+ x_blocked_scales = torch.cat(x_blocked_scale_list, dim=0)
3445
+ M_rounded = round_up(M, 128)
3446
+ x_blocked_scales = x_blocked_scales.reshape(M_rounded, -1)
3447
+
3448
+ # Combine all WQ groups blocked scales into one tensor.
3449
+ w_blocked_scales = torch.cat(w_blocked_scale_list, dim=0)
3450
+ N_rounded = round_up(N, 128)
3451
+ w_blocked_scales = w_blocked_scales.reshape(N_rounded, -1)
3452
+ return xq, wq, x_blocked_scales, w_blocked_scales, input_group_end_offsets
3453
+
3454
+ def compute(self, xq, wq, x_scale, w_scale, input_group_end_offsets):
3455
+ return torch.ops.fbgemm.mx8mx8bf16_grouped_mm(
3456
+ xq,
3457
+ wq.transpose(-2, -1),
3458
+ x_scale,
3459
+ w_scale,
3460
+ input_group_end_offsets,
3461
+ )
3462
+
3463
+ def quantize_and_compute(self, x, w, G):
3464
+ xq, wq, x_scale, w_scale, input_group_end_offsets = self.quantize(x, w, G)
3465
+ return self.compute(
3466
+ xq,
3467
+ wq,
3468
+ x_scale,
3469
+ w_scale,
3470
+ input_group_end_offsets,
3471
+ )
3472
+
3473
+ @property
3474
+ def name(self) -> str:
3475
+ return "cutlass_mx8mx8bf16_grouped_mm_2d_2d"
3476
+
3477
+ @property
3478
+ def hip(self) -> bool:
3479
+ return False
3480
+
3481
+ @property
3482
+ def cuda(self) -> bool:
3483
+ return True