mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. mslk/__init__.py +56 -0
  2. mslk/attention/__init__.py +7 -0
  3. mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
  4. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
  5. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
  6. mslk/attention/flash_attn/__init__.py +22 -0
  7. mslk/attention/flash_attn/ampere_helpers.py +104 -0
  8. mslk/attention/flash_attn/barrier.py +72 -0
  9. mslk/attention/flash_attn/benchmark.py +269 -0
  10. mslk/attention/flash_attn/blackwell_helpers.py +754 -0
  11. mslk/attention/flash_attn/block_info.py +109 -0
  12. mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
  13. mslk/attention/flash_attn/block_sparsity.py +219 -0
  14. mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
  15. mslk/attention/flash_attn/copy_utils.py +341 -0
  16. mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
  17. mslk/attention/flash_attn/fast_math.py +22 -0
  18. mslk/attention/flash_attn/flash_bwd.py +1262 -0
  19. mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
  20. mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
  21. mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
  22. mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
  23. mslk/attention/flash_attn/flash_fwd.py +2471 -0
  24. mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
  25. mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
  26. mslk/attention/flash_attn/hopper_helpers.py +102 -0
  27. mslk/attention/flash_attn/interface.py +1771 -0
  28. mslk/attention/flash_attn/mask.py +610 -0
  29. mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
  30. mslk/attention/flash_attn/named_barrier.py +32 -0
  31. mslk/attention/flash_attn/pack_gqa.py +165 -0
  32. mslk/attention/flash_attn/paged_kv.py +176 -0
  33. mslk/attention/flash_attn/pipeline.py +273 -0
  34. mslk/attention/flash_attn/seqlen_info.py +139 -0
  35. mslk/attention/flash_attn/softmax.py +583 -0
  36. mslk/attention/flash_attn/testing.py +424 -0
  37. mslk/attention/flash_attn/tile_scheduler.py +720 -0
  38. mslk/attention/flash_attn/utils.py +860 -0
  39. mslk/attention/fmha/__init__.py +967 -0
  40. mslk/attention/fmha/_triton/__init__.py +6 -0
  41. mslk/attention/fmha/_triton/available.py +50 -0
  42. mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
  43. mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
  44. mslk/attention/fmha/attn_bias.py +2186 -0
  45. mslk/attention/fmha/attn_bias_utils.py +536 -0
  46. mslk/attention/fmha/ck.py +508 -0
  47. mslk/attention/fmha/ck_decoder.py +141 -0
  48. mslk/attention/fmha/ck_splitk.py +204 -0
  49. mslk/attention/fmha/common.py +598 -0
  50. mslk/attention/fmha/cutlass.py +461 -0
  51. mslk/attention/fmha/cutlass_blackwell.py +560 -0
  52. mslk/attention/fmha/dispatch.py +224 -0
  53. mslk/attention/fmha/flash.py +862 -0
  54. mslk/attention/fmha/flash3.py +858 -0
  55. mslk/attention/fmha/flash_mtia.py +245 -0
  56. mslk/attention/fmha/merge_training.py +192 -0
  57. mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
  58. mslk/attention/fmha/torch_attention_compat.py +154 -0
  59. mslk/attention/fmha/tree_attention.py +718 -0
  60. mslk/attention/fmha/triton_splitk.py +1378 -0
  61. mslk/attention/fmha/unbind.py +130 -0
  62. mslk/attention/fmha/utils/__init__.py +6 -0
  63. mslk/attention/fmha/utils/bench.py +74 -0
  64. mslk/attention/fmha/utils/cpp_lib.py +148 -0
  65. mslk/attention/fmha/utils/op_common.py +65 -0
  66. mslk/attention/gqa_attn_splitk/__init__.py +11 -0
  67. mslk/bench/comm/__init__.py +7 -0
  68. mslk/bench/comm/comm_bench.py +255 -0
  69. mslk/bench/common/__init__.py +5 -0
  70. mslk/bench/common/utils.py +148 -0
  71. mslk/bench/conv/__init__.py +7 -0
  72. mslk/bench/conv/conv_bench.py +551 -0
  73. mslk/bench/conv/conv_ops.py +213 -0
  74. mslk/bench/gemm/__init__.py +7 -0
  75. mslk/bench/gemm/gemm_bench.py +859 -0
  76. mslk/bench/gemm/gemm_ops.py +3342 -0
  77. mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
  78. mslk/bench/moe/__init__.py +7 -0
  79. mslk/bench/moe/gather_scatter_bench.py +356 -0
  80. mslk/bench/quantize/quantize_bench.py +345 -0
  81. mslk/bench/quantize/quantize_ops.py +266 -0
  82. mslk/comm/__init__.py +11 -0
  83. mslk/conv/__init__.py +11 -0
  84. mslk/gemm/__init__.py +18 -0
  85. mslk/gemm/triton/__init__.py +7 -0
  86. mslk/gemm/triton/fp8_gemm.py +2702 -0
  87. mslk/gemm/triton/grouped_gemm.py +1132 -0
  88. mslk/gemm/triton/matmul_perf_model.py +237 -0
  89. mslk/gemm/triton/utils.py +128 -0
  90. mslk/kv_cache/__init__.py +11 -0
  91. mslk/moe/__init__.py +26 -0
  92. mslk/moe/activation.py +291 -0
  93. mslk/moe/gather_scatter.py +739 -0
  94. mslk/moe/layers.py +1240 -0
  95. mslk/moe/shuffling.py +421 -0
  96. mslk/mslk.so +0 -0
  97. mslk/quantize/__init__.py +11 -0
  98. mslk/quantize/shuffle.py +306 -0
  99. mslk/quantize/triton/__init__.py +7 -0
  100. mslk/quantize/triton/fp4_quantize.py +5942 -0
  101. mslk/quantize/triton/fp8_quantize.py +1902 -0
  102. mslk/testing/__init__.py +7 -0
  103. mslk/testing/attributes.py +60 -0
  104. mslk/testing/rocm.py +91 -0
  105. mslk/utils/__init__.py +7 -0
  106. mslk/utils/torch/__init__.py +7 -0
  107. mslk/utils/torch/library.py +150 -0
  108. mslk/utils/triton/__init__.py +7 -0
  109. mslk/utils/triton/fp8_utils.py +72 -0
  110. mslk/utils/triton/utils.py +128 -0
  111. mslk/version.py +11 -0
  112. mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
  113. mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
  114. mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
  115. mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
  116. mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
@@ -0,0 +1,3342 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Keep a registry of all gemm operators.
8
+ import abc
9
+ import functools
10
+ from enum import auto, Enum
11
+
12
+ import numpy as np
13
+ import torch
14
+ from mslk.bench.common.utils import BenchOptions, do_bench
15
+ from mslk.gemm.triton.fp8_gemm import matmul_fp8_block, matmul_fp8_row, to_mxfp8
16
+ from mslk.gemm.triton.grouped_gemm import grouped_gemm, grouped_gemm_fp8_rowwise
17
+ from mslk.quantize.shuffle import ck_preshuffle, quantize_int4_preshuffle
18
+ from mslk.quantize.triton.fp4_quantize import (
19
+ _to_blocked,
20
+ calculate_group_max,
21
+ mega_fp4_pack,
22
+ mega_fp4_quantize_kernel,
23
+ mega_fp4_unpack,
24
+ triton_quantize_mx4_unpack,
25
+ triton_quantize_nvfp4,
26
+ triton_scale_nvfp4_quant_rms,
27
+ triton_scale_nvfp4_quant_silu,
28
+ )
29
+ from mslk.quantize.triton.fp8_quantize import (
30
+ quantize_fp8_block,
31
+ quantize_fp8_group,
32
+ quantize_fp8_row,
33
+ scale_fp8_row,
34
+ triton_quantize_fp8_row,
35
+ )
36
+ from mslk.utils.triton.fp8_utils import get_fp8_constants
37
+
38
+ try:
39
+ from gen_ai.llm_inference.fb.llm.kernel.rms_norm import rms_norm
40
+ from gen_ai.llm_inference.fb.llm.kernel.silu_mul import silu_mul
41
+ except ImportError:
42
+ # The imports above are only used for some experiments; the ops here do not depend on them, so it is safe to skip them.
43
+ pass
44
+
45
+ try:
46
+ from tinygemm.utils import group_quantize_tensor
47
+
48
+ if torch.cuda.is_available() and torch.version.cuda:
49
+ torch.ops.load_library("//tinygemm:tinygemm")
50
+ TINYGEMM_ENABLED = True
51
+ except ImportError:
52
+ TINYGEMM_ENABLED = False
53
+
54
+ # Marlin is currently only supported internally at Meta.
55
+ try:
56
+ from marlin.quantize import marlin_quantize
57
+
58
+ torch.ops.load_library("//ai_codesign/gen_ai/marlin:marlin_ops")
59
+ MARLIN_ENABLED = True
60
+ except ImportError:
61
+ MARLIN_ENABLED = False
62
+
63
+ try:
64
+ from deep_gemm import (
65
+ gemm_fp8_fp8_bf16_nt,
66
+ get_col_major_tma_aligned_tensor,
67
+ m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
68
+ m_grouped_gemm_fp8_fp8_bf16_nt_masked,
69
+ )
70
+
71
+ DEEPGEMM_ENABLED = True
72
+ except ImportError:
73
+ DEEPGEMM_ENABLED = False
74
+
75
+
76
+ # Machete is also only supported internally at Meta for now.
77
+ try:
78
+ from machete.machete import machete_gemm
79
+ from machete.quantize import machete_quantize_and_pack
80
+
81
+ MACHETE_ENABLED = True
82
+ except ImportError:
83
+ MACHETE_ENABLED = False
84
+
85
+
86
+ class Accelerator(Enum):
87
+ NVIDIA_SM90 = auto()
88
+ NVIDIA_SM100 = auto()
89
+ NVIDIA_SM103 = auto()
90
+ AMD_MI300X = auto()
91
+
92
+
93
+ class GemmType(Enum):
94
+ REGULAR = auto()
95
+ GROUPED = auto()
96
+
97
+
98
+ class ComputeDtype(Enum):
99
+ FP32 = auto()
100
+ TF32 = auto()
101
+ BF16 = auto()
102
+ FP8 = auto()
103
+ FP4 = auto()
104
+
105
+
106
+ @functools.cache
107
+ def get_current_accelerator() -> Accelerator | None:
108
+ if not torch.cuda.is_available():
109
+ raise Exception("Cannot run gemm_bench without accelerator.")
110
+
111
+ if torch.version.hip is not None:
112
+ device_name = torch.cuda.get_device_name()
113
+ if "MI300X" in device_name.upper():
114
+ return Accelerator.AMD_MI300X
115
+ elif torch.version.cuda is not None:
116
+ major, minor = torch.cuda.get_device_capability()
117
+ if major == 9 and minor == 0:
118
+ return Accelerator.NVIDIA_SM90
119
+ elif major == 10 and minor == 0:
120
+ return Accelerator.NVIDIA_SM100
121
+ elif major == 10 and minor == 3:
122
+ return Accelerator.NVIDIA_SM103
123
+
124
+ raise Exception("Cannot detect hardware that is supported by gemm_bench.")
125
+
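For orientation, a minimal usage sketch of the detector above (illustrative only, not part of the packaged file; it assumes a supported NVIDIA or AMD device is visible to PyTorch):

    import torch

    accelerator = get_current_accelerator()  # cached after the first call
    if accelerator in (Accelerator.NVIDIA_SM100, Accelerator.NVIDIA_SM103):
        family = "Blackwell-class"
    elif accelerator is Accelerator.NVIDIA_SM90:
        family = "Hopper-class"
    else:  # Accelerator.AMD_MI300X
        family = "MI300X"
    print(family, torch.cuda.get_device_name())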
126
+
127
+ gemm_op_registry = []
128
+
129
+
130
+ class GemmOpBase(metaclass=abc.ABCMeta):
131
+ """Helper abstract class to define expected methods of quantize ops."""
132
+
133
+ @abc.abstractmethod
134
+ def quantize(self, *args):
135
+ """Function which quantizes inputs."""
136
+ pass
137
+
138
+ @abc.abstractmethod
139
+ def compute(self, *args):
140
+ """Function which performs main compute operation."""
141
+ pass
142
+
143
+ @abc.abstractmethod
144
+ def quantize_and_compute(self, *args):
145
+ """Function which quantizes inputs and performs main compute operation."""
146
+ pass
147
+
148
+ def preprocess(self, *args):
149
+ """Preprocess inputs before benchmarking. These outputs will be passed to quantize."""
150
+ return args
151
+
152
+ def benchmark(
153
+ self,
154
+ *args,
155
+ opts: BenchOptions,
156
+ bench_quantize: bool,
157
+ ) -> float:
158
+ """Benchmark runtime of this operator."""
159
+ t = do_bench(
160
+ lambda *a: self.quantize_and_compute(*a)
161
+ if bench_quantize
162
+ else self.compute(*a),
163
+ args,
164
+ opts,
165
+ )
166
+ return t
167
+
168
+ @abc.abstractproperty
169
+ def name(self) -> str:
170
+ """Name of the operator."""
171
+ pass
172
+
173
+ @abc.abstractproperty
174
+ def supported_accelerators(self) -> set[Accelerator]:
175
+ pass
176
+
177
+ @property
178
+ def supported(self) -> bool:
179
+ """Whether this op will run on the current device."""
180
+ accelerator = get_current_accelerator()
181
+ return accelerator in self.supported_accelerators
182
+
183
+ @abc.abstractproperty
184
+ def supported_gemm_types(self) -> set[GemmType]:
185
+ pass
186
+
187
+ @abc.abstractproperty
188
+ def compute_dtype(self) -> ComputeDtype:
189
+ """The dtype used by tensor cores for the compute."""
190
+ pass
191
+
192
+
193
+ def register_gemm_op(op):
194
+ """Decorator function for assembling all quantize ops."""
195
+ gemm_op_registry.append(op())
196
+ return op
197
+
198
+
199
+ def get_gemm_ops() -> list[GemmOpBase]:
200
+ """Get all registered quantize ops."""
201
+ return gemm_op_registry
202
+
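As an illustration of the registry contract (hypothetical code, not shipped in this wheel): a new op only needs to implement the abstract methods and properties of GemmOpBase and apply the decorator, which instantiates and registers it at import time.

    import torch

    @register_gemm_op
    class FP32TransposedBaseline(GemmOpBase):
        """Toy example: plain FP32 matmul with a pre-transposed weight."""

        def quantize(self, x, w):
            return x.float(), w.float().t()

        def compute(self, x, w):
            return torch.matmul(x, w)

        def quantize_and_compute(self, x, w):
            return self.compute(*self.quantize(x, w))

        @property
        def name(self) -> str:
            return "fp32_transposed_baseline"

        @property
        def supported_accelerators(self) -> set[Accelerator]:
            return set(Accelerator)

        @property
        def supported_gemm_types(self) -> set[GemmType]:
            return {GemmType.REGULAR}

        @property
        def compute_dtype(self) -> ComputeDtype:
            return ComputeDtype.FP32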
203
+
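And a sketch of how the registry is typically consumed (illustrative; it assumes a CUDA/ROCm device and that every selected op accepts plain 2-D bf16 inputs, which some ops further constrain):

    import torch

    regular_ops = [
        op
        for op in get_gemm_ops()
        if op.supported and GemmType.REGULAR in op.supported_gemm_types
    ]
    x = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)
    w = torch.randn(512, 256, device="cuda", dtype=torch.bfloat16)
    for op in regular_ops:
        args = op.preprocess(x, w)            # e.g. prequantize weights
        out = op.quantize_and_compute(*args)  # quantize activations + GEMM
        print(op.name, op.compute_dtype.name, tuple(out.shape))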
204
+ @register_gemm_op
205
+ class FP32Baseline(GemmOpBase):
206
+ """
207
+ FP32 matmul baseline.
208
+ """
209
+
210
+ def quantize(self, x, w):
211
+ if isinstance(x, list):
212
+ x = [i.float() for i in x]
213
+ w = [torch.transpose(i, -2, -1).float() for i in w]
214
+ else:
215
+ x = x.float()
216
+ w = torch.transpose(w, -2, -1).float()
217
+ return x, w
218
+
219
+ def compute(self, x, w):
220
+ # Handle both grouped and standard gemm.
221
+ if isinstance(x, list):
222
+ output = []
223
+ for i in range(len(x)):
224
+ output.append(torch.matmul(x[i], w[i]))
225
+ return output
226
+ return torch.matmul(x, w)
227
+
228
+ def quantize_and_compute(self, x, w):
229
+ return self.compute(*self.quantize(x, w))
230
+
231
+ @property
232
+ def name(self) -> str:
233
+ return "fp32_baseline"
234
+
235
+ @property
236
+ def supported_accelerators(self) -> set[Accelerator]:
237
+ return set(Accelerator)
238
+
239
+ @property
240
+ def supported_gemm_types(self) -> set[GemmType]:
241
+ return {GemmType.REGULAR, GemmType.GROUPED}
242
+
243
+ @property
244
+ def compute_dtype(self) -> ComputeDtype:
245
+ return ComputeDtype.FP32
246
+
247
+
248
+ @register_gemm_op
249
+ class TF32Baseline(GemmOpBase):
250
+ """
251
+ TF32 matmul baseline.
252
+ """
253
+
254
+ def quantize(self, x, w):
255
+ if isinstance(x, list):
256
+ x = [i.float() for i in x]
257
+ w = [torch.transpose(i, -2, -1).float() for i in w]
258
+ else:
259
+ x = x.float()
260
+ w = torch.transpose(w, -2, -1).float()
261
+ return x, w
262
+
263
+ def compute(self, x, w):
264
+ # Handle both grouped and standard gemm.
265
+ original_precision = torch.get_float32_matmul_precision()
266
+ torch.set_float32_matmul_precision("high")
267
+ if isinstance(x, list):
268
+ output = []
269
+ for i in range(len(x)):
270
+ output.append(torch.matmul(x[i], w[i]))
271
+ # Restore the original precision before returning the grouped results.
+ torch.set_float32_matmul_precision(original_precision)
+ return output
272
+ out = torch.matmul(x, w)
273
+ torch.set_float32_matmul_precision(original_precision)
274
+ return out
275
+
276
+ def quantize_and_compute(self, x, w):
277
+ return self.compute(*self.quantize(x, w))
278
+
279
+ @property
280
+ def name(self) -> str:
281
+ return "tf32_baseline"
282
+
283
+ @property
284
+ def supported_accelerators(self) -> set[Accelerator]:
285
+ return set(Accelerator)
286
+
287
+ @property
288
+ def supported_gemm_types(self) -> set[GemmType]:
289
+ return {GemmType.REGULAR, GemmType.GROUPED}
290
+
291
+ @property
292
+ def compute_dtype(self) -> ComputeDtype:
293
+ return ComputeDtype.TF32
294
+
295
+
296
+ @register_gemm_op
297
+ class BF16Baseline(GemmOpBase):
298
+ """
299
+ Baseline BF16 matmul.
300
+ """
301
+
302
+ def quantize(self, x, w):
303
+ if isinstance(x, list):
304
+ x = [i.bfloat16() for i in x]
305
+ w = [torch.transpose(i, -2, -1).bfloat16() for i in w]
306
+ else:
307
+ x = x.bfloat16()
308
+ w = torch.transpose(w, -2, -1).bfloat16()
309
+ return x, w
310
+
311
+ def compute(self, x, w):
312
+ # Handle both grouped and standard gemm.
313
+ if isinstance(x, list):
314
+ output = []
315
+ for i in range(len(x)):
316
+ output.append(torch.matmul(x[i], w[i]))
317
+ return output
318
+ return torch.matmul(x, w)
319
+
320
+ def quantize_and_compute(self, x, w):
321
+ return self.compute(*self.quantize(x, w))
322
+
323
+ @property
324
+ def name(self) -> str:
325
+ return "bf16_baseline"
326
+
327
+ @property
328
+ def supported_accelerators(self) -> set[Accelerator]:
329
+ return set(Accelerator)
330
+
331
+ @property
332
+ def supported_gemm_types(self) -> set[GemmType]:
333
+ return {GemmType.REGULAR, GemmType.GROUPED}
334
+
335
+ @property
336
+ def compute_dtype(self) -> ComputeDtype:
337
+ return ComputeDtype.BF16
338
+
339
+
340
+ @register_gemm_op
341
+ class ScaledMMBaseline(GemmOpBase):
342
+ """
343
+ Reference FP8 matmul implemented in native torch with cublas or hipblas.
344
+ """
345
+
346
+ def __init__(self):
347
+ self.fp8_dtype, _, _, _ = get_fp8_constants()
348
+ self.E4M3_MAX_POS: float = torch.finfo(self.fp8_dtype).max
349
+ self.E5M2_MAX_POS: float = torch.finfo(torch.float8_e5m2).max
350
+ self.FP16_MAX_POS: float = torch.finfo(torch.float16).max
351
+ self.EPS: float = 1e-12
352
+ self.fast_accum = True
353
+
354
+ def _amax_to_scale(
355
+ self, amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
356
+ ) -> torch.Tensor:
357
+ # To make scale dtype to be fp32 for accuracy
358
+ amax = amax.float()
359
+ if float8_dtype == self.fp8_dtype:
360
+ # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
361
+ res = self.E4M3_MAX_POS / torch.clamp(amax, min=self.EPS)
362
+ else: # e5m2
363
+ # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
364
+ res = self.E5M2_MAX_POS / torch.clamp(amax, min=self.EPS)
365
+
366
+ # pyre-fixme[7]: Expected `Tensor` but got `Union[float, Tensor]`.
367
+ return res
368
+
369
+ def _to_fp8_saturated(
370
+ self, x: torch.Tensor, float8_dtype: torch.dtype
371
+ ) -> torch.Tensor:
372
+ if float8_dtype == torch.float8_e4m3fn:
373
+ x = x.clamp(min=-1 * self.E4M3_MAX_POS, max=self.E4M3_MAX_POS)
374
+ else:
375
+ x = x.clamp(min=-1 * self.E5M2_MAX_POS, max=self.E5M2_MAX_POS)
376
+ return x.to(float8_dtype)
377
+
378
+ def _quantize_tensor(self, x):
379
+ x_amax = torch.max(torch.abs(x))
380
+ scale = self._amax_to_scale(x_amax, self.fp8_dtype, x.dtype)
381
+ scaled_x = self._to_fp8_saturated(x * scale, self.fp8_dtype)
382
+ x_inverse_scale = scale.reciprocal()
383
+ return scaled_x, x_inverse_scale
384
+
385
+ def quantize(self, x, w):
386
+ xq, x_scale = self._quantize_tensor(x)
387
+ wq, w_scale = self._quantize_tensor(w.t())
388
+ return xq, wq, x_scale, w_scale
389
+
390
+ def compute(self, xq, wq, x_scale, w_scale):
391
+ output = torch._scaled_mm(
392
+ xq,
393
+ wq,
394
+ bias=None,
395
+ out_dtype=torch.bfloat16,
396
+ scale_a=x_scale,
397
+ scale_b=w_scale,
398
+ scale_result=None,
399
+ use_fast_accum=self.fast_accum,
400
+ )
401
+ return output
402
+
403
+ def quantize_and_compute(self, x, w):
404
+ return self.compute(*self.quantize(x, w))
405
+
406
+ @property
407
+ def name(self) -> str:
408
+ return "scaled_mm"
409
+
410
+ @property
411
+ def supported_accelerators(self) -> set[Accelerator]:
412
+ return set(Accelerator)
413
+
414
+ @property
415
+ def supported_gemm_types(self) -> set[GemmType]:
416
+ return {GemmType.REGULAR}
417
+
418
+ @property
419
+ def compute_dtype(self) -> ComputeDtype:
420
+ return ComputeDtype.FP8
421
+
422
+
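A small self-contained check of the tensorwise scaling used above (illustrative; it reimplements the amax-to-scale arithmetic inline rather than calling the class):

    import torch

    x = torch.randn(64, 64)
    e4m3_max = torch.finfo(torch.float8_e4m3fn).max
    scale = e4m3_max / torch.clamp(x.abs().max(), min=1e-12)  # amax -> scale
    xq = (x * scale).clamp(min=-e4m3_max, max=e4m3_max).to(torch.float8_e4m3fn)
    x_hat = xq.float() * scale.reciprocal()                   # dequantize
    print((x - x_hat).abs().max())                            # small rounding error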
423
+ @register_gemm_op
424
+ class BF16X9Baseline(GemmOpBase):
425
+ """
426
+ FP32 matmul implemented with BF16X9 emulation.
427
+ """
428
+
429
+ def quantize(self, x, w):
430
+ if isinstance(x, list):
431
+ x = [i.float() for i in x]
432
+ w = [i.float() for i in w]
433
+ else:
434
+ x = x.float()
435
+ w = w.float()
436
+ return x, w
437
+
438
+ def compute(self, x, w):
439
+ # Handle both grouped and standard gemm.
440
+ if isinstance(x, list):
441
+ output = []
442
+ for i in range(len(x)):
443
+ output.append(torch.ops.mslk.bf16x9_gemm(x[i], w[i]))
444
+ return output
445
+ return torch.ops.mslk.bf16x9_gemm(x, w)
446
+
447
+ def quantize_and_compute(self, x, w):
448
+ return self.compute(*self.quantize(x, w))
449
+
450
+ @property
451
+ def name(self) -> str:
452
+ return "bf16x9_gemm"
453
+
454
+ @property
455
+ def supported_accelerators(self) -> set[Accelerator]:
456
+ return {
457
+ Accelerator.NVIDIA_SM90,
458
+ Accelerator.NVIDIA_SM100,
459
+ Accelerator.NVIDIA_SM103,
460
+ }
461
+
462
+ @property
463
+ def supported_gemm_types(self) -> set[GemmType]:
464
+ return {GemmType.REGULAR, GemmType.GROUPED}
465
+
466
+ @property
467
+ def compute_dtype(self) -> ComputeDtype:
468
+ return ComputeDtype.BF16
469
+
470
+
471
+ @register_gemm_op
472
+ class ScaledMMRowwise(GemmOpBase):
473
+ def __init__(self):
474
+ self.fast_accum = True
475
+ self.torch_compile = False
476
+
477
+ def quantize(self, x, w):
478
+ xq, x_scale = quantize_fp8_row(x)
479
+ wq, w_scale = quantize_fp8_row(w)
480
+ return xq, wq.t(), x_scale.unsqueeze(1), w_scale.unsqueeze(0)
481
+
482
+ def compute(self, xq, wq, x_scale, w_scale):
483
+ if self.torch_compile:
484
+ f = torch.compile(
485
+ torch._scaled_mm,
486
+ options={
487
+ "max_autotune": True,
488
+ "max_autotune_gemm_backends": "TRITON,CK,CUTLASS,ATEN",
489
+ },
490
+ )
491
+ else:
492
+ f = torch._scaled_mm
493
+
494
+ return f(
495
+ xq,
496
+ wq,
497
+ bias=None,
498
+ out_dtype=torch.bfloat16,
499
+ scale_a=x_scale,
500
+ scale_b=w_scale,
501
+ scale_result=None,
502
+ use_fast_accum=self.fast_accum,
503
+ )
504
+
505
+ def quantize_and_compute(self, x, w):
506
+ return self.compute(*self.quantize(x, w))
507
+
508
+ @property
509
+ def name(self) -> str:
510
+ return "scaled_mm_rowwise"
511
+
512
+ @property
513
+ def supported_accelerators(self) -> set[Accelerator]:
514
+ return set(Accelerator)
515
+
516
+ @property
517
+ def supported_gemm_types(self) -> set[GemmType]:
518
+ return {GemmType.REGULAR}
519
+
520
+ @property
521
+ def compute_dtype(self) -> ComputeDtype:
522
+ return ComputeDtype.FP8
523
+
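A shape sketch for the rowwise scaling above (illustrative, assuming a GPU is available and using the quantize_fp8_row imported at the top of this module): the per-row scales are unsqueezed so they broadcast over the (M, N) output of torch._scaled_mm.

    import torch

    M, N, K = 16, 32, 64
    x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
    w = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)
    xq, x_scale = quantize_fp8_row(x)            # x_scale: one scale per row of x
    wq, w_scale = quantize_fp8_row(w)            # w_scale: one scale per row of w
    assert x_scale.unsqueeze(1).shape == (M, 1)  # broadcasts down the M rows
    assert w_scale.unsqueeze(0).shape == (1, N)  # broadcasts across the N columns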
524
+
525
+ @register_gemm_op
526
+ class ScaledMMMXFP8(GemmOpBase):
527
+ def __init__(self):
528
+ self.torch_compile = False
529
+
530
+ def quantize(self, x, w):
531
+ x_scale, xq = to_mxfp8(x)
532
+ x_scale = _to_blocked(x_scale)
533
+ w_scale, wq = to_mxfp8(w)
534
+ w_scale = _to_blocked(w_scale)
535
+ return xq, wq.t(), x_scale, w_scale
536
+
537
+ def compute(self, xq, wq, x_scale, w_scale):
538
+ if self.torch_compile:
539
+ f = torch.compile(
540
+ torch._scaled_mm,
541
+ options={
542
+ "max_autotune": True,
543
+ "max_autotune_gemm_backends": "TRITON,CK,CUTLASS,ATEN",
544
+ },
545
+ )
546
+ else:
547
+ f = torch._scaled_mm
548
+
549
+ return f(
550
+ xq,
551
+ wq,
552
+ bias=None,
553
+ out_dtype=torch.bfloat16,
554
+ scale_a=x_scale,
555
+ scale_b=w_scale,
556
+ )
557
+
558
+ def quantize_and_compute(self, x, w):
559
+ return self.compute(*self.quantize(x, w))
560
+
561
+ @property
562
+ def name(self) -> str:
563
+ return "scaled_mm_mxfp8"
564
+
565
+ @property
566
+ def supported_accelerators(self) -> set[Accelerator]:
567
+ return {Accelerator.NVIDIA_SM100, Accelerator.NVIDIA_SM103}
568
+
569
+ @property
570
+ def supported_gemm_types(self) -> set[GemmType]:
571
+ return {GemmType.REGULAR}
572
+
573
+ @property
574
+ def compute_dtype(self) -> ComputeDtype:
575
+ return ComputeDtype.FP8
576
+
577
+
578
+ @register_gemm_op
579
+ class ScaledMMNVFP4(GemmOpBase):
580
+ def __init__(self):
581
+ self.torch_compile = False
582
+
583
+ def quantize(self, x, w):
584
+ x_global_scale = torch.tensor([1.0], device=x.device, dtype=torch.float32)
585
+ w_global_scale = torch.tensor([1.0], device=w.device, dtype=torch.float32)
586
+
587
+ xq, x_scale = triton_quantize_nvfp4(x, x_global_scale)
588
+ wq, w_scale = triton_quantize_nvfp4(w, w_global_scale)
589
+
590
+ return (
591
+ xq.view(torch.float4_e2m1fn_x2),
592
+ wq.view(torch.float4_e2m1fn_x2),
593
+ x_scale.view(torch.float8_e4m3fn),
594
+ w_scale.view(torch.float8_e4m3fn),
595
+ )
596
+
597
+ def compute(self, xq, wq, x_scale, w_scale):
598
+ if self.torch_compile:
599
+ f = torch.compile(
600
+ torch._scaled_mm,
601
+ options={
602
+ "max_autotune": True,
603
+ "max_autotune_gemm_backends": "TRITON,CK,CUTLASS,ATEN",
604
+ },
605
+ )
606
+ else:
607
+ f = torch._scaled_mm
608
+
609
+ return f(
610
+ xq,
611
+ wq.t(),
612
+ bias=None,
613
+ out_dtype=torch.bfloat16,
614
+ scale_a=x_scale,
615
+ scale_b=w_scale,
616
+ )
617
+
618
+ def quantize_and_compute(self, x, w):
619
+ return self.compute(*self.quantize(x, w))
620
+
621
+ @property
622
+ def name(self) -> str:
623
+ return "scaled_mm_nvfp4"
624
+
625
+ @property
626
+ def supported_accelerators(self) -> set[Accelerator]:
627
+ return {Accelerator.NVIDIA_SM100, Accelerator.NVIDIA_SM103}
628
+
629
+ @property
630
+ def supported_gemm_types(self) -> set[GemmType]:
631
+ return {GemmType.REGULAR}
632
+
633
+ @property
634
+ def compute_dtype(self) -> ComputeDtype:
635
+ return ComputeDtype.FP4
636
+
637
+
638
+ @register_gemm_op
639
+ class FP8TensorwiseGemm(GemmOpBase):
640
+ """
641
+ FP8 matmul with tensorwise scaling.
642
+ """
643
+
644
+ def quantize(self, x, w):
645
+ # Quantize both input tensors.
646
+ xq, x_scale = torch.ops.mslk.quantize_fp8_per_tensor(x)
647
+ wq, w_scale = torch.ops.mslk.quantize_fp8_per_tensor(w)
648
+ return xq, wq, x_scale, w_scale
649
+
650
+ def compute(self, xq, wq, x_scale, w_scale):
651
+ return torch.ops.mslk.f8f8bf16(xq, wq, x_scale * w_scale)
652
+
653
+ def quantize_and_compute(self, x, w):
654
+ xq, wq, x_scale, w_scale = self.quantize(x, w)
655
+ return self.compute(xq, wq, x_scale, w_scale)
656
+
657
+ @property
658
+ def name(self) -> str:
659
+ return "cutlass_tensorwise"
660
+
661
+ @property
662
+ def supported_accelerators(self) -> set[Accelerator]:
663
+ return {Accelerator.NVIDIA_SM90}
664
+
665
+ @property
666
+ def supported_gemm_types(self) -> set[GemmType]:
667
+ return {GemmType.REGULAR}
668
+
669
+ @property
670
+ def compute_dtype(self) -> ComputeDtype:
671
+ return ComputeDtype.FP8
672
+
673
+
674
+ @register_gemm_op
675
+ class FP8CublasRowwiseGemm(GemmOpBase):
676
+ """
677
+ FP8 cublas matmul with rowwise scaling.
678
+ """
679
+
680
+ def quantize(self, x, w):
681
+ # Quantize both input tensors.
682
+ xq, x_scale = torch.ops.mslk.quantize_fp8_per_row(x)
683
+ wq, w_scale = torch.ops.mslk.quantize_fp8_per_row(w)
684
+ return xq, wq, x_scale, w_scale
685
+
686
+ def compute(self, xq, wq, x_scale, w_scale):
687
+ out = torch.ops.mslk.f8f8bf16_cublas(xq, wq)
688
+ scaled_out = scale_fp8_row(out, x_scale, w_scale)
689
+ return scaled_out
690
+
691
+ def quantize_and_compute(self, x, w):
692
+ xq, wq, x_scale, w_scale = self.quantize(x, w)
693
+ return self.compute(xq, wq, x_scale, w_scale)
694
+
695
+ @property
696
+ def name(self) -> str:
697
+ return "cublas_rowwise"
698
+
699
+ @property
700
+ def supported_accelerators(self) -> set[Accelerator]:
701
+ return {
702
+ Accelerator.NVIDIA_SM90,
703
+ Accelerator.NVIDIA_SM100,
704
+ Accelerator.NVIDIA_SM103,
705
+ }
706
+
707
+ @property
708
+ def supported_gemm_types(self) -> set[GemmType]:
709
+ return {GemmType.REGULAR}
710
+
711
+ @property
712
+ def compute_dtype(self) -> ComputeDtype:
713
+ return ComputeDtype.FP8
714
+
715
+
716
+ @register_gemm_op
717
+ class FP8CublasTensorwiseGemm(GemmOpBase):
718
+ """
719
+ FP8 cublas matmul with tensorwise scaling.
720
+ """
721
+
722
+ def quantize(self, x, w):
723
+ # Quantize both input tensors.
724
+ xq, x_scale = torch.ops.mslk.quantize_fp8_per_tensor(x)
725
+ wq, w_scale = torch.ops.mslk.quantize_fp8_per_tensor(w)
726
+ return xq, wq, x_scale, w_scale
727
+
728
+ def compute(self, xq, wq, x_scale, w_scale):
729
+ return torch.ops.mslk.f8f8bf16_cublas(xq, wq, x_scale * w_scale)
730
+
731
+ def quantize_and_compute(self, x, w):
732
+ xq, wq, x_scale, w_scale = self.quantize(x, w)
733
+ return self.compute(xq, wq, x_scale, w_scale)
734
+
735
+ @property
736
+ def name(self) -> str:
737
+ return "cublas_tensorwise"
738
+
739
+ @property
740
+ def supported_accelerators(self) -> set[Accelerator]:
741
+ return {
742
+ Accelerator.NVIDIA_SM90,
743
+ Accelerator.NVIDIA_SM100,
744
+ Accelerator.NVIDIA_SM103,
745
+ }
746
+
747
+ @property
748
+ def supported_gemm_types(self) -> set[GemmType]:
749
+ return {GemmType.REGULAR}
750
+
751
+ @property
752
+ def compute_dtype(self) -> ComputeDtype:
753
+ return ComputeDtype.FP8
754
+
755
+
756
+ @register_gemm_op
757
+ class FP8RowwiseGemm(GemmOpBase):
758
+ """
759
+ FP8 matmul with rowwise scaling.
760
+ """
761
+
762
+ def __init__(self):
763
+ self.fast_accum = True
764
+ self.gemm_op = torch.ops.mslk.f8f8bf16_rowwise
765
+
766
+ def preprocess(self, x, w):
767
+ # Prequantize weights.
768
+ if isinstance(w, (list, tuple)):
769
+ wq, w_scale = zip(*[quantize_fp8_row(i) for i in w])
770
+ else:
771
+ wq, w_scale = quantize_fp8_row(w)
772
+ if wq.dim() == 3:
773
+ w_scale = w_scale.view(wq.size(0), -1)
774
+ return x, wq, w_scale
775
+
776
+ def quantize(self, x, wq, w_scale):
777
+ # Quantize both input tensors.
778
+ # Handle both grouped and standard gemm.
779
+ if isinstance(x, (list, tuple)):
780
+ xq, x_scale = zip(*[quantize_fp8_row(i) for i in x])
781
+ else:
782
+ xq, x_scale = quantize_fp8_row(x)
783
+ # Set proper batch dimension shapes.
784
+ if xq.dim() == 3:
785
+ x_scale = x_scale.view(xq.size(0), -1)
786
+ return xq, wq, x_scale, w_scale
787
+
788
+ def compute(self, xq, wq, x_scale, w_scale):
789
+ # Handle group gemm if inputs are grouped.
790
+ if isinstance(xq, (list, tuple)):
791
+ output = []
792
+ for i in range(len(xq)):
793
+ output.append(
794
+ self.gemm_op(
795
+ xq[i],
796
+ wq[i],
797
+ x_scale[i],
798
+ w_scale[i],
799
+ use_fast_accum=self.fast_accum,
800
+ )
801
+ )
802
+ return output
803
+ # Unroll batched gemm if needed.
804
+ elif xq.dim() == 3 and wq.dim() == 3:
805
+ B, M, _ = xq.shape
806
+ _, N, _ = wq.shape
807
+ y = torch.empty((B, M, N), device=xq.device, dtype=torch.bfloat16)
808
+ for i in range(B):
809
+ y[i] = self.gemm_op(
810
+ xq[i], wq[i], x_scale[i], w_scale[i], use_fast_accum=self.fast_accum
811
+ )
812
+ return y
813
+ # Otherwise return normal gemm result.
814
+ return self.gemm_op(xq, wq, x_scale, w_scale, use_fast_accum=self.fast_accum)
815
+
816
+ def quantize_and_compute(self, x, wq, w_scale):
817
+ xq, wq, x_scale, w_scale = self.quantize(x, wq, w_scale)
818
+ return self.compute(xq, wq, x_scale, w_scale)
819
+
820
+ @property
821
+ def name(self) -> str:
822
+ if torch.version.cuda:
823
+ return "cutlass_rowwise"
824
+ else:
825
+ return "ck_rowwise"
826
+
827
+ @property
828
+ def supported_accelerators(self) -> set[Accelerator]:
829
+ return {
830
+ Accelerator.NVIDIA_SM90,
831
+ Accelerator.NVIDIA_SM100,
832
+ Accelerator.NVIDIA_SM103,
833
+ Accelerator.AMD_MI300X,
834
+ }
835
+
836
+ @property
837
+ def supported_gemm_types(self) -> set[GemmType]:
838
+ return {GemmType.REGULAR}
839
+
840
+ @property
841
+ def compute_dtype(self) -> ComputeDtype:
842
+ return ComputeDtype.FP8
843
+
844
+
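A brief sketch of how the preprocess/quantize/compute split above is exercised (illustrative; it assumes the compiled mslk extension ops are loadable): weights are quantized once in preprocess, so the timed region covers only activation quantization and the GEMM.

    import torch

    op = FP8RowwiseGemm()
    x = torch.randn(128, 512, device="cuda", dtype=torch.bfloat16)
    w = torch.randn(256, 512, device="cuda", dtype=torch.bfloat16)
    x, wq, w_scale = op.preprocess(x, w)           # weight quantization happens here
    out = op.quantize_and_compute(x, wq, w_scale)  # activation quantization + GEMM
    print(out.shape)                               # torch.Size([128, 256])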
845
+ @register_gemm_op
846
+ class FP8RowwisePreshuffleGemm(FP8RowwiseGemm):
847
+ """
848
+ FP8 matmul with rowwise scaling and preshuffling of input B.
849
+ """
850
+
851
+ def __init__(self):
852
+ self.fast_accum = True
853
+ if self.supported:
854
+ self.gemm_op = torch.ops.mslk.f8f8bf16_rowwise_preshuffle
855
+
856
+ def preprocess(self, x, w):
857
+ x, wq, w_scale = super().preprocess(x, w)
858
+ return x, ck_preshuffle(wq, 16), w_scale
859
+
860
+ @property
861
+ def name(self) -> str:
862
+ return "ck_rowwise_preshuffle"
863
+
864
+ @property
865
+ def supported_accelerators(self) -> set[Accelerator]:
866
+ return {Accelerator.AMD_MI300X}
867
+
868
+ @property
869
+ def supported_gemm_types(self) -> set[GemmType]:
870
+ return {GemmType.REGULAR}
871
+
872
+ @property
873
+ def compute_dtype(self) -> ComputeDtype:
874
+ return ComputeDtype.FP8
875
+
876
+
877
+ @register_gemm_op
878
+ class FP8RowwiseGroupedGemm(GemmOpBase):
879
+ """
880
+ FP8 grouped matmul with rowwise scaling.
881
+ """
882
+
883
+ def preprocess(self, x, w):
884
+ # Apply sparsity to inputs if appropriate.
885
+ # First check if N and K are fixed.
886
+ m_values = [i.shape[0] for i in x]
887
+ n_values = [i.shape[0] for i in w]
888
+ k_values = [i.shape[1] for i in w]
889
+ # If so, do specialized version of initialization.
890
+ if len(np.unique(n_values)) == 1 and len(np.unique(k_values)) == 1:
891
+ m_values = [i.shape[0] for i in x]
892
+ # Inputs for fixed-NK mode must be contiguous, but in this benchmark
893
+ # script they typically are not, so do a little extra processing to make
894
+ # them work. In practice this won't be needed.
895
+ # Start by padding along m dimension with zeros.
896
+ max_m = max(m_values)
897
+ x = [
898
+ torch.nn.functional.pad(i, (0, 0, 0, max_m - i.shape[0]), value=0)
899
+ for i in x
900
+ ]
901
+ # Stack inputs into groups.
902
+ x = torch.stack(x).contiguous()
903
+ w = torch.stack(w).contiguous()
904
+
905
+ # Preapply weight quantization.
906
+ wq, w_scale = quantize_fp8_row(w)
907
+ # Return processed tensors.
908
+ return (
909
+ x,
910
+ wq,
911
+ w_scale,
912
+ torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device),
913
+ )
914
+ # Otherwise run without sparsity.
915
+ wq, w_scale = zip(*[quantize_fp8_row(i) for i in w])
916
+ return x, wq, w_scale, None
917
+
918
+ def quantize(self, x, wq, w_scale, m_values=None):
919
+ # Handle case where inputs are explicitly grouped and non-sparse.
920
+ if isinstance(x, (tuple, list)):
921
+ xq, x_scale = zip(*[triton_quantize_fp8_row(i) for i in x])
922
+ return xq, wq, x_scale, w_scale, m_values
923
+ # Otherwise inputs are unified tensors and sparse.
924
+ else:
925
+ B = x.shape[0]
926
+ xq, x_scale = triton_quantize_fp8_row(x, zero_start_index_M=m_values)
927
+ x_scale = x_scale.view(B, -1)
928
+ return xq, wq, x_scale, w_scale, m_values
929
+
930
+ def compute(self, xq, wq, x_scale, w_scale, m_values):
931
+ if m_values is None:
932
+ return torch.ops.mslk.f8f8bf16_rowwise_grouped(
933
+ xq,
934
+ wq,
935
+ x_scale,
936
+ w_scale,
937
+ )
938
+ else:
939
+ # Break tensor into groups, simulates what is done e2e.
940
+ return torch.ops.mslk.f8f8bf16_rowwise_grouped_dynamic(
941
+ xq,
942
+ wq,
943
+ x_scale,
944
+ w_scale,
945
+ zero_start_index_M=m_values,
946
+ )
947
+
948
+ def quantize_and_compute(self, x, wq, w_scale, m_values=None):
949
+ xq, wq, x_scale, w_scale, m_values = self.quantize(x, wq, w_scale, m_values)
950
+ return self.compute(xq, wq, x_scale, w_scale, m_values)
951
+
952
+ @property
953
+ def name(self) -> str:
954
+ if torch.version.cuda:
955
+ return "cutlass_rowwise_grouped"
956
+ else:
957
+ return "ck_rowwise_grouped"
958
+
959
+ @property
960
+ def supported_accelerators(self) -> set[Accelerator]:
961
+ return {
962
+ Accelerator.NVIDIA_SM90,
963
+ Accelerator.NVIDIA_SM100,
964
+ Accelerator.NVIDIA_SM103,
965
+ Accelerator.AMD_MI300X,
966
+ }
967
+
968
+ @property
969
+ def supported_gemm_types(self) -> set[GemmType]:
970
+ return {GemmType.GROUPED}
971
+
972
+ @property
973
+ def compute_dtype(self) -> ComputeDtype:
974
+ return ComputeDtype.FP8
975
+
976
+
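To make the fixed-NK path above concrete, a small illustrative example (hypothetical shapes) of the zero-padding and stacking, with m_values keeping each group's true row count for zero_start_index_M:

    import torch
    import torch.nn.functional as F

    x = [torch.randn(3, 8), torch.randn(5, 8)]  # two groups sharing K (and N in w)
    m_values = [t.shape[0] for t in x]          # [3, 5]
    max_m = max(m_values)
    x_padded = torch.stack(
        [F.pad(t, (0, 0, 0, max_m - t.shape[0]), value=0) for t in x]
    )
    print(x_padded.shape)                       # torch.Size([2, 5, 8])
    print(torch.tensor(m_values))               # tensor([3, 5])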
977
+ @register_gemm_op
978
+ class BF16TritonStackedGroupedGemm(GemmOpBase):
979
+ """
980
+ BF16 grouped matmul with stacked inputs implemented with triton.
981
+ """
982
+
983
+ def preprocess(self, x, w):
984
+ m_values = [i.shape[0] for i in x]
985
+ # Convert m_values into a tensor of per-group sizes for the stacked layout.
986
+ m_sizes = torch.tensor(m_values).to(dtype=torch.int32, device=x[0].device)
987
+ w = torch.concat(w, dim=0).contiguous()
988
+ # Also view input as flattened.
989
+ x = torch.concat(x, dim=0).contiguous()
990
+ # Return processed tensors.
991
+ return x, w, m_sizes
992
+
993
+ def quantize(self, x, w, m_sizes):
994
+ return x, w, m_sizes
995
+
996
+ def compute(self, x, w, m_sizes):
997
+ return grouped_gemm(x, w, m_sizes, _use_warp_specialization=True)
998
+
999
+ def quantize_and_compute(self, x, w, m_sizes):
1000
+ x, w, m_sizes = self.quantize(x, w, m_sizes)
1001
+ return self.compute(x, w, m_sizes)
1002
+
1003
+ @property
1004
+ def name(self) -> str:
1005
+ return "triton_bf16_grouped_stacked"
1006
+
1007
+ @property
1008
+ def supported_accelerators(self) -> set[Accelerator]:
1009
+ return set(Accelerator)
1010
+
1011
+ @property
1012
+ def supported_gemm_types(self) -> set[GemmType]:
1013
+ return {GemmType.GROUPED}
1014
+
1015
+ @property
1016
+ def compute_dtype(self) -> ComputeDtype:
1017
+ return ComputeDtype.BF16
1018
+
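For the stacked layout above, a tiny illustrative example of how m_sizes assigns rows of the concatenated activation tensor to groups:

    import torch

    x = [torch.randn(2, 8), torch.randn(3, 8)]
    m_sizes = torch.tensor([t.shape[0] for t in x], dtype=torch.int32)  # tensor([2, 3])
    x_stacked = torch.cat(x, dim=0)  # rows 0-1 belong to group 0, rows 2-4 to group 1
    print(x_stacked.shape)           # torch.Size([5, 8])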
1019
+
1020
+ @register_gemm_op
1021
+ class BF16TritonStackedGroupedGemmFuseScatterAdd(BF16TritonStackedGroupedGemm):
1022
+ """
1023
+ BF16 grouped matmul with stacked inputs implemented with triton. Fused with ScatterAdd.
1024
+ """
1025
+
1026
+ def preprocess(self, x, w):
1027
+ x, w, m_sizes = super().preprocess(x, w)
1028
+ M = x.shape[0]
1029
+ N = w.shape[0] // m_sizes.shape[0]
1030
+ output = torch.zeros(M, N, dtype=torch.bfloat16, device=x.device)
1031
+ indices = torch.randperm(M, dtype=torch.int32, device=x.device)
1032
+ return x, w, m_sizes, output, indices
1033
+
1034
+ def quantize(self, x, w, m_sizes, *args):
1035
+ return *super().quantize(x, w, m_sizes), *args
1036
+
1037
+ def compute(self, x, w, m_sizes, output, indices):
1038
+ return grouped_gemm(
1039
+ x,
1040
+ w,
1041
+ m_sizes,
1042
+ _use_warp_specialization=True,
1043
+ _output_tensor=output,
1044
+ _scatter_add_indices=indices,
1045
+ )
1046
+
1047
+ def quantize_and_compute(self, x, w, m_sizes, *args):
1048
+ x, w, m_sizes, *ret = self.quantize(x, w, m_sizes, *args)
1049
+ return self.compute(x, w, m_sizes, *ret)
1050
+
1051
+ @property
1052
+ def name(self) -> str:
1053
+ return "triton_bf16_grouped_stacked_fuse_scatter_add"
1054
+
1055
+
1056
+ @register_gemm_op
1057
+ class FP8TritonStackedGroupedGemm(GemmOpBase):
1058
+ """
1059
+ FP8 grouped matmul with rowwise scaling and stacked inputs implemented with triton.
1060
+ """
1061
+
1062
+ def preprocess(self, x, w):
1063
+ m_values = [i.shape[0] for i in x]
1064
+ # Convert m_values into a tensor of per-group sizes for the stacked layout.
1065
+ m_sizes = torch.tensor(m_values).to(dtype=torch.int32, device=x[0].device)
1066
+ # Quantize weights.
1067
+ wq, w_scale = zip(*[quantize_fp8_row(i) for i in w])
1068
+ # Group weights as single tensor.
1069
+ wq = torch.concat(wq, dim=0).contiguous()
1070
+ w_scale = torch.concat(w_scale, dim=0).contiguous()
1071
+ # Also view input as flattened.
1072
+ x = torch.concat(x, dim=0).contiguous()
1073
+ # Return processed tensors.
1074
+ return x, wq, w_scale, m_sizes
1075
+
1076
+ def quantize(self, x, wq, w_scale, m_sizes):
1077
+ B = x.shape[0]
1078
+ xq, x_scale = triton_quantize_fp8_row(x)
1079
+ x_scale = x_scale.view(B, -1)
1080
+ return xq, wq, x_scale, w_scale, m_sizes
1081
+
1082
+ def compute(self, xq, wq, x_scale, w_scale, m_sizes):
1083
+ return grouped_gemm_fp8_rowwise(
1084
+ xq, wq, m_sizes, x_scale, w_scale, _use_warp_specialization=True
1085
+ )
1086
+
1087
+ def quantize_and_compute(self, x, wq, w_scale, m_sizes):
1088
+ xq, wq, x_scale, w_scale, m_sizes = self.quantize(x, wq, w_scale, m_sizes)
1089
+ return self.compute(xq, wq, x_scale, w_scale, m_sizes)
1090
+
1091
+ @property
1092
+ def name(self) -> str:
1093
+ return "triton_grouped_stacked"
1094
+
1095
+ @property
1096
+ def supported_accelerators(self) -> set[Accelerator]:
1097
+ return set(Accelerator)
1098
+
1099
+ @property
1100
+ def supported_gemm_types(self) -> set[GemmType]:
1101
+ return {GemmType.GROUPED}
1102
+
1103
+ @property
1104
+ def compute_dtype(self) -> ComputeDtype:
1105
+ return ComputeDtype.FP8
1106
+
1107
+
1108
+ @register_gemm_op
1109
+ class FP8TritonStackedGroupedGemmFuseScatterAdd(FP8TritonStackedGroupedGemm):
1110
+ """
1111
+ FP8 grouped matmul with stacked inputs implemented with triton. Fused with ScatterAdd.
1112
+ """
1113
+
1114
+ def preprocess(self, x, w):
1115
+ x, wq, w_scale, m_sizes = super().preprocess(x, w)
1116
+ M = x.shape[0]
1117
+ N = wq.shape[0] // m_sizes.shape[0]
1118
+ output = torch.zeros(M, N, dtype=torch.bfloat16, device=x.device)
1119
+ indices = torch.randperm(M, dtype=torch.int32, device=x.device)
1120
+ return x, wq, w_scale, m_sizes, output, indices
1121
+
1122
+ def quantize(self, x, wq, w_scale, m_sizes, *args):
1123
+ return *super().quantize(x, wq, w_scale, m_sizes), *args
1124
+
1125
+ def compute(self, xq, wq, x_scale, w_scale, m_sizes, output, indices):
1126
+ return grouped_gemm_fp8_rowwise(
1127
+ xq,
1128
+ wq,
1129
+ m_sizes,
1130
+ x_scale,
1131
+ w_scale,
1132
+ _use_warp_specialization=True,
1133
+ _output_tensor=output,
1134
+ _scatter_add_indices=indices,
1135
+ )
1136
+
1137
+ def quantize_and_compute(self, x, wq, w_scale, m_sizes, *args):
1138
+ xq, wq, x_scale, w_scale, m_sizes, *ret = self.quantize(
1139
+ x, wq, w_scale, m_sizes, *args
1140
+ )
1141
+ return self.compute(xq, wq, x_scale, w_scale, m_sizes, *ret)
1142
+
1143
+ @property
1144
+ def name(self) -> str:
1145
+ return "triton_grouped_stacked_fuse_scatter_add"
1146
+
1147
+
1148
+ @register_gemm_op
1149
+ class DeepGemmStacked(GemmOpBase):
1150
+ """
1151
+ FP8 grouped matmul with blockwise scaling implemented with DeepGemm.
1152
+ """
1153
+
1154
+ def preprocess(self, x, w):
1155
+ m_values = [i.shape[0] for i in x]
1156
+ # Expand m_values into a per-row group index for the stacked tensor.
1157
+ indices = torch.arange(len(m_values))
1158
+ m_indices = indices.repeat_interleave(torch.tensor(m_values)).to(
1159
+ device=x[0].device, dtype=torch.int
1160
+ )
1161
+ # Quantize weights.
1162
+ wq, w_scale = zip(*[quantize_fp8_block(i, block_k=128, block_m=128) for i in w])
1163
+ # Group weights as single tensor.
1164
+ wq = torch.stack(wq, dim=0).contiguous()
1165
+ w_scale = torch.stack(w_scale, dim=0).contiguous()
1166
+ # Also view input as flattened.
1167
+ x = torch.concat(x, dim=0).contiguous()
1168
+ # Return processed tensors.
1169
+ return x, wq, w_scale, m_indices
1170
+
1171
+ def quantize(self, x, wq, w_scale, m_indices):
1172
+ xq, x_scale = quantize_fp8_block(x, block_m=1, block_k=128)
1173
+ # Pretranspose scales to deepgemm format.
1174
+ x_scale = get_col_major_tma_aligned_tensor(x_scale)
1175
+ return xq, wq, x_scale, w_scale, m_indices
1176
+
1177
+ def compute(self, xq, wq, x_scale, w_scale, m_indices):
1178
+ # Preallocate output.
1179
+ out = torch.empty(
1180
+ [xq.shape[0], wq.shape[1]], device=xq.device, dtype=torch.bfloat16
1181
+ )
1182
+ m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
1183
+ (xq, x_scale), (wq, w_scale), out, m_indices
1184
+ )
1185
+ return out
1186
+
1187
+ def quantize_and_compute(self, x, wq, w_scale, m_indices):
1188
+ xq, wq, x_scale, w_scale, m_indices = self.quantize(x, wq, w_scale, m_indices)
1189
+ return self.compute(xq, wq, x_scale, w_scale, m_indices)
1190
+
1191
+ @property
1192
+ def name(self) -> str:
1193
+ return "deepgemm_stacked"
1194
+
1195
+ @property
1196
+ def supported_accelerators(self) -> set[Accelerator]:
1197
+ if DEEPGEMM_ENABLED:
1198
+ return {Accelerator.NVIDIA_SM90}
1199
+ return set()
1200
+
1201
+ @property
1202
+ def supported_gemm_types(self) -> set[GemmType]:
1203
+ return {GemmType.GROUPED}
1204
+
1205
+ @property
1206
+ def compute_dtype(self) -> ComputeDtype:
1207
+ return ComputeDtype.FP8
1208
+
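A worked example (illustrative) of the m_indices construction used in preprocess above: repeat_interleave expands per-group row counts into a per-row group id, which is what the contiguous DeepGemm grouped kernel consumes.

    import torch

    m_values = [2, 3]
    indices = torch.arange(len(m_values))                          # tensor([0, 1])
    m_indices = indices.repeat_interleave(torch.tensor(m_values))
    print(m_indices)                                               # tensor([0, 0, 1, 1, 1])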
1209
+
1210
+ @register_gemm_op
1211
+ class DeepGemmMaskedStacked(DeepGemmStacked):
1212
+ def preprocess(self, x, w):
1213
+ # Quantize weights.
1214
+ wq, w_scale = zip(*[quantize_fp8_block(i, block_k=128, block_m=128) for i in w])
1215
+ # Group weights as single tensor.
1216
+ wq = torch.stack(wq, dim=0).contiguous()
1217
+ w_scale = torch.stack(w_scale, dim=0).contiguous()
1218
+
1219
+ # Also view input as flattened.
1220
+ m_values = [i.shape[0] for i in x]
1221
+ expected_m = max(m_values)
1222
+ padded_m_max = ((max(m_values) + 127) // 128) * 128
1223
+ masked_m = torch.tensor(m_values).to(dtype=torch.int32, device=x[0].device)
1224
+
1225
+ num_groups = len(m_values)
1226
+ k = x[0].shape[1]
1227
+ x_padded = torch.zeros(
1228
+ [num_groups, padded_m_max, k], device=x[0].device, dtype=x[0].dtype
1229
+ )
1230
+ for g in range(num_groups):
1231
+ x_padded[g, : m_values[g], :] = x[g]
1232
+
1233
+ # Return processed tensors.
1234
+ return x_padded, wq, w_scale, masked_m, expected_m, m_values
1235
+
1236
+ def quantize(self, x, wq, w_scale, masked_m, expected_m, m_values):
1237
+ g, m_max, k = x.shape
1238
+ xq, x_scale = quantize_fp8_block(x.view(-1, k), block_m=1, block_k=128)
1239
+ # Pretranspose scales to deepgemm format.
1240
+ x_scale = get_col_major_tma_aligned_tensor(x_scale)
1241
+ return (
1242
+ xq.view(g, m_max, -1),
1243
+ wq,
1244
+ x_scale.view(g, m_max, -1),
1245
+ w_scale,
1246
+ masked_m,
1247
+ expected_m,
1248
+ m_values,
1249
+ )
1250
+
1251
+ def compute(self, xq, wq, x_scale, w_scale, masked_m, expected_m, m_values):
1252
+ # Preallocate output.
1253
+ out = torch.empty(
1254
+ [xq.shape[0], xq.shape[1], wq.shape[1]],
1255
+ device=xq.device,
1256
+ dtype=torch.bfloat16,
1257
+ )
1258
+ m_grouped_gemm_fp8_fp8_bf16_nt_masked(
1259
+ (xq, x_scale), (wq, w_scale), out, masked_m, expected_m
1260
+ )
1261
+ num_groups = xq.shape[0]
1262
+ out_list = [out[g, : m_values[g], :] for g in range(num_groups)]
1263
+ return out_list
1264
+
1265
+ def quantize_and_compute(self, x, wq, w_scale, masked_m, expected_m, m_values):
1266
+ xq, wq, x_scale, w_scale, masked_m, expected_m, m_values = self.quantize(
1267
+ x, wq, w_scale, masked_m, expected_m, m_values
1268
+ )
1269
+ return self.compute(xq, wq, x_scale, w_scale, masked_m, expected_m, m_values)
1270
+
1271
+ @property
1272
+ def name(self) -> str:
1273
+ return "deepgemm_masked_stacked"
1274
+
1275
+
1276
+ @register_gemm_op
1277
+ class DeepGemmBlockwise(GemmOpBase):
1278
+ """
1279
+ FP8 matmul with blockwise scaling implemented with DeepGemm.
1280
+ """
1281
+
1282
+ def preprocess(self, x, w):
1283
+ # Quantize weights.
1284
+ wq, w_scale = quantize_fp8_block(w, block_m=128, block_k=128)
1285
+ # allocate output.
1286
+ out = torch.empty(
1287
+ x.shape[0], wq.shape[0], device=x.device, dtype=torch.bfloat16
1288
+ )
1289
+ # Return processed tensors.
1290
+ return x, wq, w_scale, out
1291
+
1292
+ def quantize(self, x, wq, w_scale, out):
1293
+ xq, x_scale = quantize_fp8_group(x, group_size=128)
1294
+ return xq, wq, x_scale, w_scale, out
1295
+
1296
+ def compute(self, xq, wq, x_scale, w_scale, out):
1297
+ gemm_fp8_fp8_bf16_nt((xq, x_scale), (wq, w_scale), out)
1298
+ return out
1299
+
1300
+ def quantize_and_compute(self, x, wq, w_scale, out):
1301
+ xq, wq, x_scale, w_scale, out = self.quantize(x, wq, w_scale, out)
1302
+ return self.compute(xq, wq, x_scale, w_scale, out)
1303
+
1304
+ @property
1305
+ def name(self) -> str:
1306
+ return "deepgemm_blockwise"
1307
+
1308
+ @property
1309
+ def supported_accelerators(self) -> set[Accelerator]:
1310
+ if DEEPGEMM_ENABLED:
1311
+ return {Accelerator.NVIDIA_SM90}
1312
+ return set()
1313
+
1314
+ @property
1315
+ def supported_gemm_types(self) -> set[GemmType]:
1316
+ return {GemmType.REGULAR}
1317
+
1318
+ @property
1319
+ def compute_dtype(self) -> ComputeDtype:
1320
+ return ComputeDtype.FP8
1321
+
1322
+
1323
+ @register_gemm_op
1324
+ class DeepGemmRowwise(GemmOpBase):
1325
+ """
1326
+ FP8 matmul with rowwise scaling implemented with DeepGemm.
1327
+ """
1328
+
1329
+ def preprocess(self, x, w):
1330
+ # Quantize weights.
1331
+ wq, w_scale = quantize_fp8_row(w)
1332
+ # allocate output.
1333
+ out = torch.empty(
1334
+ x.shape[0], wq.shape[0], device=x.device, dtype=torch.bfloat16
1335
+ )
1336
+ # Return processed tensors.
1337
+ return x, wq, w_scale, out
1338
+
1339
+ def quantize(self, x, wq, w_scale, out):
1340
+ xq, x_scale = quantize_fp8_row(x)
1341
+ # Pretranspose scales to deepgemm format.
1342
+ x_scale = get_col_major_tma_aligned_tensor(x_scale, rowwise_scaling=True)
1343
+ return xq, wq, x_scale, w_scale, out
1344
+
1345
+ def compute(self, xq, wq, x_scale, w_scale, out):
1346
+ gemm_fp8_fp8_bf16_nt((xq, x_scale), (wq, w_scale), out)
1347
+ return out
1348
+
1349
+ def quantize_and_compute(self, x, wq, w_scale, out):
1350
+ xq, wq, x_scale, w_scale, out = self.quantize(x, wq, w_scale, out)
1351
+ return self.compute(xq, wq, x_scale, w_scale, out)
1352
+
1353
+ @property
1354
+ def name(self) -> str:
1355
+ return "deepgemm_rowwise"
1356
+
1357
+ @property
1358
+ def supported_accelerators(self) -> set[Accelerator]:
1359
+ if DEEPGEMM_ENABLED:
1360
+ return {Accelerator.NVIDIA_SM90}
1361
+ return set()
1362
+
1363
+ @property
1364
+ def supported_gemm_types(self) -> set[GemmType]:
1365
+ return {GemmType.REGULAR}
1366
+
1367
+ @property
1368
+ def compute_dtype(self) -> ComputeDtype:
1369
+ return ComputeDtype.FP8
1370
+
1371
+
1372
+ @register_gemm_op
1373
+ class FP8StackedGroupedGemm(GemmOpBase):
1374
+ """
1375
+ FP8 grouped matmul with rowwise scaling and stacked inputs.
1376
+ """
1377
+
1378
+ def preprocess(self, x, w):
1379
+ m_values = [i.shape[0] for i in x]
1380
+ m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device)
1381
+ # Quantize weights.
1382
+ wq, w_scale = zip(*[quantize_fp8_row(i) for i in w])
1383
+ # Group weights as single tensor.
1384
+ wq = torch.stack(wq, dim=0).contiguous()
1385
+ w_scale = torch.stack(w_scale, dim=0).contiguous()
1386
+ # Also view input as flattened.
1387
+ x = torch.concat(x, dim=0).contiguous()
1388
+ # Return processed tensors.
1389
+ return x, wq, w_scale, m_sizes
1390
+
1391
+ def quantize(self, x, wq, w_scale, m_sizes):
1392
+ B = x.shape[0]
1393
+ xq, x_scale = triton_quantize_fp8_row(x)
1394
+ x_scale = x_scale.view(B, -1)
1395
+ return xq, wq, x_scale, w_scale, m_sizes
1396
+
1397
+ def compute(self, xq, wq, x_scale, w_scale, m_sizes):
1398
+ return torch.ops.mslk.f8f8bf16_rowwise_grouped_stacked(
1399
+ xq, wq, x_scale, w_scale, m_sizes
1400
+ )
1401
+
1402
+ def quantize_and_compute(self, x, wq, w_scale, m_sizes):
1403
+ xq, wq, x_scale, w_scale, m_sizes = self.quantize(x, wq, w_scale, m_sizes)
1404
+ return self.compute(xq, wq, x_scale, w_scale, m_sizes)
1405
+
1406
+ @property
1407
+ def name(self) -> str:
1408
+ if torch.version.cuda:
1409
+ return "cutlass_grouped_stacked"
1410
+ else:
1411
+ return "ck_grouped_stacked"
1412
+
1413
+ @property
1414
+ def supported_accelerators(self) -> set[Accelerator]:
1415
+ return {
1416
+ Accelerator.NVIDIA_SM90,
1417
+ Accelerator.NVIDIA_SM100,
1418
+ Accelerator.NVIDIA_SM103,
1419
+ Accelerator.AMD_MI300X,
1420
+ }
1421
+
1422
+ @property
1423
+ def supported_gemm_types(self) -> set[GemmType]:
1424
+ return {GemmType.GROUPED}
1425
+
1426
+ @property
1427
+ def compute_dtype(self) -> ComputeDtype:
1428
+ return ComputeDtype.FP8
1429
+
1430
+
1431
+ @register_gemm_op
1432
+ class FP8StackedGroupedGemmTorch(FP8StackedGroupedGemm):
1433
+ def quantize(self, x, wq, w_scale, m_sizes):
1434
+ xq, wq, x_scale, w_scale, m_sizes = super().quantize(x, wq, w_scale, m_sizes)
1435
+ offsets = torch.cumsum(m_sizes, dim=0, dtype=torch.int32)
1436
+ out = torch.empty(
1437
+ (xq.shape[0], wq.shape[1]), dtype=torch.bfloat16, device=xq.device
1438
+ )
1439
+ x_scale = x_scale.view(x_scale.shape[0])
1440
+ return xq, wq, x_scale, w_scale, offsets, out
1441
+
1442
+ def compute(self, xq, wq, x_scale, w_scale, offsets, out):
1443
+ return torch.ops.mslk.f8f8bf16_rowwise_grouped_mm(
1444
+ xq, wq, x_scale, w_scale, offsets, out
1445
+ )
1446
+
1447
+ def quantize_and_compute(self, x, wq, w_scale, m_sizes):
1448
+ xq, wq, x_scale, w_scale, offsets, out = self.quantize(x, wq, w_scale, m_sizes)
1449
+ return self.compute(xq, wq, x_scale, w_scale, offsets, out)
1450
+
1451
+ @property
1452
+ def name(self) -> str:
1453
+ return "ck_grouped_stacked_torch_2d3d"
1454
+
1455
+ @property
1456
+ def supported_accelerators(self) -> set[Accelerator]:
1457
+ return {Accelerator.AMD_MI300X}
1458
+
1459
+ @property
1460
+ def supported_gemm_types(self) -> set[GemmType]:
1461
+ return {GemmType.GROUPED}
1462
+
1463
+ @property
1464
+ def compute_dtype(self) -> ComputeDtype:
1465
+ return ComputeDtype.FP8
1466
+
1467
+
1468
+ @register_gemm_op
1469
+ class ScaledGroupedMMRowwise(FP8StackedGroupedGemmTorch):
1470
+ def __init__(self):
1471
+ self.fast_accum = True
1472
+ self.torch_compile = False
1473
+
1474
+ def compute(self, xq, wq, x_scale, w_scale, offsets, _):
1475
+ if self.torch_compile:
1476
+ f = torch.compile(
1477
+ torch._scaled_grouped_mm,
1478
+ options={
1479
+ "max_autotune": True,
1480
+ "max_autotune_gemm_backends": "TRITON,CK,CUTLASS,ATEN",
1481
+ },
1482
+ )
1483
+ else:
1484
+ f = torch._scaled_grouped_mm
1485
+
1486
+ return f(
1487
+ xq,
1488
+ wq.transpose(-2, -1),
1489
+ offs=offsets,
1490
+ out_dtype=torch.bfloat16,
1491
+ scale_a=x_scale,
1492
+ scale_b=w_scale,
1493
+ scale_result=None,
1494
+ use_fast_accum=self.fast_accum,
1495
+ )
1496
+
1497
+ @property
1498
+ def name(self) -> str:
1499
+ return "scaled_grouped_mm_rowwise"
1500
+
1501
+ @property
1502
+ def supported_accelerators(self) -> set[Accelerator]:
1503
+ return {
1504
+ Accelerator.NVIDIA_SM90,
1505
+ Accelerator.AMD_MI300X,
1506
+ }
1507
+
1508
+ @property
1509
+ def supported_gemm_types(self) -> set[GemmType]:
1510
+ return {GemmType.GROUPED}
1511
+
1512
+ @property
1513
+ def compute_dtype(self) -> ComputeDtype:
1514
+ return ComputeDtype.FP8
1515
+
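A brief illustrative example of the offsets built in the inherited quantize above (cumsum over m_sizes): the cumulative row counts mark where each group ends in the stacked activation tensor, and they are what compute passes as the offs argument.

    import torch

    m_sizes = torch.tensor([2, 3, 4], dtype=torch.int64)
    offsets = torch.cumsum(m_sizes, dim=0, dtype=torch.int32)
    print(offsets)  # tensor([2, 5, 9], dtype=torch.int32)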
1516
+
1517
+ @register_gemm_op
1518
+ class FP8StackedGroupwiseGroupedGemm(GemmOpBase):
1519
+ """
1520
+ FP8 grouped matmul with groupwise scaling and stacked inputs.
1521
+ """
1522
+
1523
+ def preprocess(self, x, w):
1524
+ m_values = [i.shape[0] for i in x]
1525
+ m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device)
1526
+ # Quantize weights.
1527
+ wq, w_scale = zip(
1528
+ *[quantize_fp8_block(i, block_m=128, block_k=128, k_major=False) for i in w]
1529
+ )
1530
+ # Group weights as single tensor.
1531
+ wq = torch.stack(wq, dim=0).contiguous()
1532
+ w_scale = torch.stack(w_scale, dim=0).contiguous()
1533
+ # Also view input as flattened.
1534
+ x = torch.concat(x, dim=0).contiguous()
1535
+ # Return processed tensors.
1536
+ return x, wq, w_scale, m_sizes
1537
+
1538
+ def quantize(self, x, wq, w_scale, m_sizes):
1539
+ xq, x_scale = quantize_fp8_group(x, m_sizes=m_sizes)
1540
+ return xq, wq, x_scale, w_scale, m_sizes
1541
+
1542
+ def compute(self, xq, wq, x_scale, w_scale, m_sizes):
1543
+ return torch.ops.mslk.f8f8bf16_groupwise_grouped(
1544
+ xq, wq, x_scale, w_scale, m_sizes
1545
+ )
1546
+
1547
+ def quantize_and_compute(self, x, wq, w_scale, m_sizes):
1548
+ xq, wq, x_scale, w_scale, m_sizes = self.quantize(x, wq, w_scale, m_sizes)
1549
+ return self.compute(xq, wq, x_scale, w_scale, m_sizes)
1550
+
1551
+ @property
1552
+ def name(self) -> str:
1553
+ return "cutlass_groupwise_grouped"
1554
+
1555
+ @property
1556
+ def supported_accelerators(self) -> set[Accelerator]:
1557
+ return {Accelerator.NVIDIA_SM90}
1558
+
1559
+ @property
1560
+ def supported_gemm_types(self) -> set[GemmType]:
1561
+ return {GemmType.GROUPED}
1562
+
1563
+ @property
1564
+ def compute_dtype(self) -> ComputeDtype:
1565
+ return ComputeDtype.FP8
1566
+
1567
+
1568
+ @register_gemm_op
1569
+ class BF16GroupedGemm(GemmOpBase):
1570
+ """
1571
+ BF16 grouped matmul implemented with CK or Cutlass.
1572
+ """
1573
+
1574
+ def preprocess(self, x, w):
1575
+ # Apply sparsity to inputs if appropriate.
1576
+ # First check if N and K are fixed.
1577
+ m_values = [i.shape[0] for i in x]
1578
+ n_values = [i.shape[0] for i in w]
1579
+ k_values = [i.shape[1] for i in w]
1580
+ # If so, do specialized version of initialization.
1581
+ if len(np.unique(n_values)) == 1 and len(np.unique(k_values)) == 1:
1582
+ m_values = [i.shape[0] for i in x]
1583
+ # Inputs for fixed-NK mode must be contiguous, but in this benchmark
1584
+ # script they typically are not, so do a little extra processing to make
1585
+ # them work. In practice this won't be needed.
1586
+ # Start by padding along m dimension with zeros.
1587
+ max_m = max(m_values)
1588
+ x = [
1589
+ torch.nn.functional.pad(i, (0, 0, 0, max_m - i.shape[0]), value=0)
1590
+ for i in x
1591
+ ]
1592
+ # Stack inputs into groups.
1593
+ x = torch.stack(x).contiguous()
1594
+ w = torch.stack(w).contiguous()
1595
+ return (
1596
+ x,
1597
+ w,
1598
+ torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device),
1599
+ )
1600
+ return x, w, None
1601
+
1602
+ def quantize(self, x, w, m_values=None):
1603
+ # No action required.
1604
+ return x, w, m_values
1605
+
1606
+ def compute(self, x, w, m_values):
1607
+ if m_values is None:
1608
+ return torch.ops.mslk.bf16bf16bf16_grouped(x, w)
1609
+ else:
1610
+ return torch.ops.mslk.bf16bf16bf16_grouped_dynamic(x, w, m_values)
1611
+
1612
+ def quantize_and_compute(self, x, w, m_values):
1613
+ return self.compute(x, w, m_values)
1614
+
1615
+ @property
1616
+ def name(self) -> str:
1617
+ if torch.version.cuda:
1618
+ return "cutlass_bf16_grouped"
1619
+ else:
1620
+ return "ck_bf16_grouped"
1621
+
1622
+ @property
1623
+ def supported_accelerators(self) -> set[Accelerator]:
1624
+ return {
1625
+ Accelerator.NVIDIA_SM90,
1626
+ Accelerator.NVIDIA_SM100,
1627
+ Accelerator.NVIDIA_SM103,
1628
+ Accelerator.AMD_MI300X,
1629
+ }
1630
+
1631
+ @property
1632
+ def supported_gemm_types(self) -> set[GemmType]:
1633
+ return {GemmType.GROUPED}
1634
+
1635
+ @property
1636
+ def compute_dtype(self) -> ComputeDtype:
1637
+ return ComputeDtype.BF16
1638
+
1639
+
1640
+ @register_gemm_op
1641
+ class FP8RowwiseBatchedGemm(GemmOpBase):
1642
+ """
1643
+ FP8 batched matmul with rowwise scaling.
1644
+ """
1645
+
1646
+ def quantize(self, x, w):
1647
+ assert isinstance(x, list) and isinstance(w, list)
1648
+ x = torch.stack(x, dim=0)
1649
+ w = torch.stack(w, dim=0)
1650
+ # Quantize both input tensors.
1651
+ xq, x_scale = quantize_fp8_row(x)
1652
+ wq, w_scale = quantize_fp8_row(w)
1653
+ return xq, wq, x_scale, w_scale
1654
+
1655
+ def compute(self, xq, wq, x_scale, w_scale):
1656
+ return torch.ops.mslk.f8f8bf16_rowwise_batched(xq, wq, x_scale, w_scale)
1657
+
1658
+ def quantize_and_compute(self, x, w):
1659
+ xq, wq, x_scale, w_scale = self.quantize(x, w)
1660
+ return self.compute(xq, wq, x_scale, w_scale)
1661
+
1662
+ @property
1663
+ def name(self) -> str:
1664
+ if torch.version.cuda:
1665
+ return "cutlass_rowwise_batched"
1666
+ else:
1667
+ return "ck_rowwise_batched"
1668
+
1669
+ @property
1670
+ def supported_accelerators(self) -> set[Accelerator]:
1671
+ return {
1672
+ Accelerator.NVIDIA_SM90,
1673
+ Accelerator.NVIDIA_SM100,
1674
+ Accelerator.NVIDIA_SM103,
1675
+ Accelerator.AMD_MI300X,
1676
+ }
1677
+
1678
+ @property
1679
+ def supported_gemm_types(self) -> set[GemmType]:
1680
+ return {GemmType.GROUPED}
1681
+
1682
+ @property
1683
+ def compute_dtype(self) -> ComputeDtype:
1684
+ return ComputeDtype.FP8
1685
+
1686
+
1687
+ # This kernel is broken and causes the GPU to lock up; needs some investigation.
1688
+ # @register_gemm_op
1689
+ class TritonFP8RowwiseGemm(GemmOpBase):
1690
+ """
1691
+ FP8 matmul with rowwise scaling implemented with Triton.
1692
+ """
1693
+
1694
+ def __init__(self):
1695
+ self.fast_accum = True
1696
+
1697
+ def quantize(self, x, w):
1698
+ # Quantize both input tensors.
1699
+ xq, x_scale = quantize_fp8_row(x)
1700
+ wq, w_scale = quantize_fp8_row(w)
1701
+ bias = torch.randn(w.shape[0], device=x.device, dtype=torch.float32)
1702
+ return xq, wq, x_scale, w_scale, bias
1703
+
1704
+ def compute(self, xq, wq, x_scale, w_scale, bias):
1705
+ return matmul_fp8_row(
1706
+ xq,
1707
+ wq,
1708
+ x_scale,
1709
+ w_scale,
1710
+ bias=bias,
1711
+ fp8_fast_accum=self.fast_accum,
1712
+ use_warp_specialization=True,
1713
+ )
1714
+
1715
+ def quantize_and_compute(self, x, w):
1716
+ xq, wq, x_scale, w_scale, bias = self.quantize(x, w)
1717
+ return self.compute(xq, wq, x_scale, w_scale, bias)
1718
+
1719
+ @property
1720
+ def name(self) -> str:
1721
+ return "triton_rowwise"
1722
+
1723
+ @property
1724
+ def supported_accelerators(self) -> set[Accelerator]:
1725
+ return set(Accelerator)
1726
+
1727
+ @property
1728
+ def supported_gemm_types(self) -> set[GemmType]:
1729
+ return {GemmType.REGULAR}
1730
+
1731
+ @property
1732
+ def compute_dtype(self) -> ComputeDtype:
1733
+ return ComputeDtype.FP8
1734
+
1735
+
1736
+ @register_gemm_op
1737
+ class FP8TritonBlockwiseGemm(GemmOpBase):
1738
+ """
1739
+ FP8 matmul with block scaling.
1740
+ """
1741
+
1742
+ def quantize(self, x, w):
1743
+ # Quantize both input tensors.
1744
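+ # quantize_fp8_block uses 128x128 scaling blocks, matching the block sizes passed to matmul_fp8_block.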
+ xq, x_scale = quantize_fp8_block(x, 128, 128)
1745
+ wq, w_scale = quantize_fp8_block(w, 128, 128)
1746
+ return xq, wq, x_scale, w_scale
1747
+
1748
+ def compute(self, xq, wq, x_scale, w_scale):
1749
+ return matmul_fp8_block(xq, wq, x_scale, w_scale, 128, 128, 128)
1750
+
1751
+ def quantize_and_compute(self, x, w):
1752
+ xq, wq, x_scale, w_scale = self.quantize(x, w)
1753
+ return self.compute(xq, wq, x_scale, w_scale)
1754
+
1755
+ @property
1756
+ def name(self) -> str:
1757
+ return "triton_blockwise"
1758
+
1759
+ @property
1760
+ def supported_accelerators(self) -> set[Accelerator]:
1761
+ return set(Accelerator)
1762
+
1763
+ @property
1764
+ def supported_gemm_types(self) -> set[GemmType]:
1765
+ return {GemmType.REGULAR}
1766
+
1767
+ @property
1768
+ def compute_dtype(self) -> ComputeDtype:
1769
+ return ComputeDtype.FP8
1770
+
1771
+
1772
+ @register_gemm_op
1773
+ class FP8CutlassBlockwiseGemm(GemmOpBase):
1774
+ """
1775
+ FP8 matmul with block scaling.
1776
+ """
1777
+
1778
+ def quantize(self, x, w):
1779
+ # Quantize both input tensors.
1780
+ xq, x_scale = quantize_fp8_block(x, 128, 128)
1781
+ wq, w_scale = quantize_fp8_block(w, 128, 128)
1782
+ return xq, wq, x_scale, w_scale
1783
+
1784
+ def compute(self, xq, wq, x_scale, w_scale):
1785
+ return torch.ops.mslk.f8f8bf16_blockwise(
1786
+ xq, wq, x_scale, w_scale, 128, 128, 128
1787
+ )
1788
+
1789
+ def quantize_and_compute(self, x, w):
1790
+ xq, wq, x_scale, w_scale = self.quantize(x, w)
1791
+ return self.compute(xq, wq, x_scale, w_scale)
1792
+
1793
+ @property
1794
+ def name(self) -> str:
1795
+ if torch.version.cuda:
1796
+ return "cutlass_blockwise"
1797
+ else:
1798
+ return "ck_blockwise"
1799
+
1800
+ @property
1801
+ def supported_accelerators(self) -> set[Accelerator]:
1802
+ return {
1803
+ Accelerator.NVIDIA_SM90,
1804
+ Accelerator.AMD_MI300X,
1805
+ }
1806
+
1807
+ @property
1808
+ def supported_gemm_types(self) -> set[GemmType]:
1809
+ return {GemmType.REGULAR}
1810
+
1811
+ @property
1812
+ def compute_dtype(self) -> ComputeDtype:
1813
+ return ComputeDtype.FP8
1814
+
1815
+
1816
+ @register_gemm_op
1817
+ class FP8CutlassGroupwiseGemm(GemmOpBase):
1818
+ """
1819
+ FP8 matmul with group / block scaling.
1820
+ """
1821
+
1822
+ def preprocess(self, x, w):
1823
+ # Quantize weights.
1824
+ # Scale is expected to be in [K, N] layout (N Major).
1825
+ wq, w_scale = quantize_fp8_block(w, block_m=128, block_k=128, k_major=False)
1826
+ # Return processed tensors.
1827
+ return x, wq, w_scale
1828
+
1829
+ def quantize(self, x, wq, w_scale):
1830
+ # Scale is expected to be in [K, M] layout (M Major).
1831
+ xq, x_scale = quantize_fp8_group(x, k_major=False)
1832
+ # Pretranspose scales to deepgemm format.
1833
+ return xq, wq, x_scale, w_scale
1834
+
1835
+ def compute(self, xq, wq, x_scale, w_scale):
1836
+ return torch.ops.mslk.f8f8bf16_groupwise(xq, wq, x_scale, w_scale)
1837
+
1838
+ def quantize_and_compute(self, x, wq, w_scale):
1839
+ xq, wq, x_scale, w_scale = self.quantize(x, wq, w_scale)
1840
+ return self.compute(xq, wq, x_scale, w_scale)
1841
+
1842
+ @property
1843
+ def name(self) -> str:
1844
+ return "cutlass_groupwise"
1845
+
1846
+ @property
1847
+ def supported_accelerators(self) -> set[Accelerator]:
1848
+ return {Accelerator.NVIDIA_SM90}
1849
+
1850
+ @property
1851
+ def supported_gemm_types(self) -> set[GemmType]:
1852
+ return {GemmType.REGULAR}
1853
+
1854
+ @property
1855
+ def compute_dtype(self) -> ComputeDtype:
1856
+ return ComputeDtype.FP8
1857
+
1858
+
1859
+ @register_gemm_op
1860
+ class F8I4RowwiseGemm(GemmOpBase):
1861
+ """
1862
+ Mixed Precision FP8 Activations with Int4 Weights.
1863
+ """
1864
+
1865
+ def _int4_row_quantize(
1866
+ self,
1867
+ x: torch.Tensor,
1868
+ group_size: int = 128,
1869
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
1870
+ n_bit = 4 # Number of target bits.
1871
+ to_quant = x.reshape(-1, group_size).to(torch.float)
1872
+
1873
+ max_val = to_quant.amax(dim=1, keepdim=True)
1874
+ min_val = to_quant.amin(dim=1, keepdim=True)
1875
+ max_int = 2**n_bit - 1
1876
+ min_int = 0
1877
+ scales = (max_val - min_val).clamp(min=1e-6) / max_int
1878
+
1879
+ zeros = min_val + scales * (2 ** (n_bit - 1))
1880
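+ # Dequantization follows x ~ q * scales + zeros, where q is the recentered signed int4 value.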
+
1881
+ out = to_quant.sub(min_val).div(scales).round().clamp_(min_int, max_int)
1882
+
1883
+ # Recenter output and move to int8.
1884
+ out = (out - 2 ** (n_bit - 1)).to(dtype=torch.int8).reshape(x.shape)
1885
+
1886
+ # Cutlass expects column major layout for scale and zero point,
1887
+ # so we transpose here and make them contiguous.
1888
+ scales = scales.view(x.shape[0], -1).t().contiguous()
1889
+ zeros = zeros.view(x.shape[0], -1).t().contiguous()
1890
+
1891
+ return out, scales, zeros
1892
+
1893
+ def _pack_int4(self, x: torch.Tensor) -> torch.Tensor:
1894
+ # Given int8 x, pack adjacent int4 values into a single int8.
1895
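+ # Example: low nibble 0x3 and high nibble 0x5 pack into the single byte 0x53.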
+ low_x = x[:, ::2]
1896
+ high_x = x[:, 1::2]
1897
+
1898
+ # High bits need to left shift, this also masks off extra bits.
1899
+ high_x = torch.bitwise_left_shift(high_x, 4)
1900
+ # Low bits need to have sign bits removed.
1901
+ low_x = torch.bitwise_and(low_x, 0xF)
1902
+
1903
+ # Recombine into a single value with bitwise or.
1904
+ return torch.bitwise_or(low_x, high_x).contiguous()
1905
+
1906
+ def quantize(self, x, w):
1907
+ # Quantize both input tensors.
1908
+ xq, x_scale = quantize_fp8_row(x)
1909
+ wq, w_scale, w_zp = self._int4_row_quantize(w)
1910
+ # Pack int4 values together.
1911
+ wq = self._pack_int4(wq)
1912
+ return xq, wq, x_scale, w_scale, w_zp
1913
+
1914
+ def compute(self, xq, wq, x_scale, w_scale, w_zp):
1915
+ return torch.ops.mslk.f8i4bf16_rowwise(xq, wq, x_scale, w_scale, w_zp)
1916
+
1917
+ def quantize_and_compute(self, x, w):
1918
+ xq, wq, x_scale, w_scale, w_zp = self.quantize(x, w)
1919
+ return self.compute(xq, wq, x_scale, w_scale, w_zp)
1920
+
1921
+ @property
1922
+ def name(self) -> str:
1923
+ return "cutlass_f8i4_rowwise"
1924
+
1925
+ @property
1926
+ def supported_accelerators(self) -> set[Accelerator]:
1927
+ return {Accelerator.NVIDIA_SM90}
1928
+
1929
+ @property
1930
+ def supported_gemm_types(self) -> set[GemmType]:
1931
+ return {GemmType.REGULAR}
1932
+
1933
+ @property
1934
+ def compute_dtype(self) -> ComputeDtype:
1935
+ return ComputeDtype.FP8
1936
+
1937
+
1938
+ @register_gemm_op
1939
+ class F8I4ShuffledGemm(GemmOpBase):
1940
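+ """
+ Mixed precision FP8 activations with preshuffled Int4 weights (rowwise and group scales).
+ """
+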
+ def preprocess(self, x, w):
1941
+ # Prequantize and pack weights.
1942
+ wq, (group_scale, row_scale) = quantize_int4_preshuffle(w)
1943
+ return x, wq, row_scale, group_scale
1944
+
1945
+ def quantize(self, x, wq, row_scale, group_scale):
1946
+ # Quantize both input tensors.
1947
+ xq, x_scale = quantize_fp8_row(x)
1948
+ return xq, wq, x_scale, row_scale, group_scale
1949
+
1950
+ def compute(self, xq, wq, x_scale, row_scale, group_scale):
1951
+ # Handle batched cases by looping over each batch.
1952
+ if xq.dim() == 3:
1953
+ B, M, _ = xq.shape
1954
+ _, N, _ = wq.shape
1955
+ y = torch.empty((B, M, N), device=xq.device, dtype=torch.bfloat16)
1956
+ for i in range(B):
1957
+ y[i] = torch.ops.mslk.f8i4bf16_shuffled(
1958
+ xq[i], wq[i], x_scale[i], row_scale[i], group_scale[i]
1959
+ )
1960
+ return y
1961
+ # Otherwise run gemm normally.
1962
+ return torch.ops.mslk.f8i4bf16_shuffled(xq, wq, x_scale, row_scale, group_scale)
1963
+
1964
+ def quantize_and_compute(self, x, wq, row_scale, group_scale):
1965
+ xq, wq, x_scale, row_scale, group_scale = self.quantize(
1966
+ x, wq, row_scale, group_scale
1967
+ )
1968
+ return self.compute(xq, wq, x_scale, row_scale, group_scale)
1969
+
1970
+ @property
1971
+ def name(self) -> str:
1972
+ return "cutlass_f8i4_preshuffle"
1973
+
1974
+ @property
1975
+ def supported_accelerators(self) -> set[Accelerator]:
1976
+ return {Accelerator.NVIDIA_SM90}
1977
+
1978
+ @property
1979
+ def supported_gemm_types(self) -> set[GemmType]:
1980
+ return {GemmType.REGULAR}
1981
+
1982
+ @property
1983
+ def compute_dtype(self) -> ComputeDtype:
1984
+ return ComputeDtype.FP8
1985
+
1986
+
1987
+ @register_gemm_op
1988
+ class BF16I4ShuffledGemm(GemmOpBase):
1989
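+ """
+ Mixed precision BF16 activations with preshuffled Int4 weights (group scale and zero point).
+ """
+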
+ def preprocess(self, x, w):
1990
+ # Prequantize and pack weights.
1991
+ wq, (group_scale, group_zero) = quantize_int4_preshuffle(w, dtype="bf16")
1992
+ return x, wq, group_scale, group_zero
1993
+
1994
+ def quantize(self, x, wq, group_scale, group_zero):
1995
+ # No extra action required.
1996
+ return x, wq, group_scale, group_zero
1997
+
1998
+ def compute(self, x, wq, group_scale, group_zero):
1999
+ # Handle batched cases by looping over each batch.
2000
+ if x.dim() == 3:
2001
+ B, M, _ = x.shape
2002
+ _, N, _ = wq.shape
2003
+ y = torch.empty((B, M, N), device=x.device, dtype=torch.bfloat16)
2004
+ for i in range(B):
2005
+ y[i] = torch.ops.mslk.bf16i4bf16_shuffled(
2006
+ x[i], wq[i], group_scale[i], group_zero[i]
2007
+ )
2008
+ return y
2009
+ # Otherwise run gemm normally.
2010
+ return torch.ops.mslk.bf16i4bf16_shuffled(x, wq, group_scale, group_zero)
2011
+
2012
+ def quantize_and_compute(self, x, wq, group_scale, group_zero):
2013
+ x, wq, group_scale, group_zero = self.quantize(x, wq, group_scale, group_zero)
2014
+ return self.compute(x, wq, group_scale, group_zero)
2015
+
2016
+ @property
2017
+ def name(self) -> str:
2018
+ return "cutlass_bf16i4_preshuffle"
2019
+
2020
+ @property
2021
+ def supported_accelerators(self) -> set[Accelerator]:
2022
+ return {Accelerator.NVIDIA_SM90}
2023
+
2024
+ @property
2025
+ def supported_gemm_types(self) -> set[GemmType]:
2026
+ return {GemmType.REGULAR}
2027
+
2028
+ @property
2029
+ def compute_dtype(self) -> ComputeDtype:
2030
+ return ComputeDtype.BF16
2031
+
2032
+
2033
+ @register_gemm_op
2034
+ class BF16I4ShuffledBatchedGemm(GemmOpBase):
2035
+ """
2036
+ BF16 x INT4 mixed dtype batched gemm with preshuffling.
2037
+ """
2038
+
2039
+ def preprocess(self, x, w):
2040
+ assert isinstance(x, list) and isinstance(w, list)
2041
+ x = torch.stack(x, dim=0)
2042
+ w = torch.stack(w, dim=0)
2043
+ # Prequantize and pack weights.
2044
+ wq, (group_scale, group_zero) = quantize_int4_preshuffle(w, dtype="bf16")
2045
+ return x, wq, group_scale, group_zero
2046
+
2047
+ def quantize(self, x, wq, group_scale, group_zero):
2048
+ # No extra action required.
2049
+ return x, wq, group_scale, group_zero
2050
+
2051
+ def compute(self, x, wq, group_scale, group_zero):
2052
+ return torch.ops.mslk.bf16i4bf16_shuffled_batched(
2053
+ x, wq, group_scale, group_zero
2054
+ )
2055
+
2056
+ def quantize_and_compute(self, x, wq, group_scale, group_zero):
2057
+ x, wq, group_scale, group_zero = self.quantize(x, wq, group_scale, group_zero)
2058
+ return self.compute(x, wq, group_scale, group_zero)
2059
+
2060
+ @property
2061
+ def name(self) -> str:
2062
+ return "cutlass_bf16i4_preshuffle_batched"
2063
+
2064
+ @property
2065
+ def supported_accelerators(self) -> set[Accelerator]:
2066
+ return {Accelerator.NVIDIA_SM90}
2067
+
2068
+ @property
2069
+ def supported_gemm_types(self) -> set[GemmType]:
2070
+ return {GemmType.GROUPED}
2071
+
2072
+ @property
2073
+ def compute_dtype(self) -> ComputeDtype:
2074
+ return ComputeDtype.BF16
2075
+
2076
+
2077
+ @register_gemm_op
2078
+ class F8I4ShuffledGroupedGemm(GemmOpBase):
2079
+ """
2080
+ FP8 x Int4 mixed dtype grouped gemm with preshuffling.
2081
+ """
2082
+
2083
+ def preprocess(self, x, w):
2084
+ assert isinstance(x, list) and isinstance(w, list), (
2085
+ "Only supported for grouped inputs."
2086
+ )
2087
+ m_values = [i.shape[0] for i in x]
2088
+ # Convert m_values into offsets into grouped tensor.
2089
+ m_sizes = torch.tensor(m_values).to(dtype=torch.int32, device=x[0].device)
2090
+ # Quantize weights.
2091
+ wq, scales = zip(*[quantize_int4_preshuffle(i) for i in w])
2092
+ group_scale, row_scale = zip(*scales)
2093
+ # Group weights as single tensor.
2094
+ wq = torch.stack(wq, dim=0).contiguous()
2095
+ row_scale = torch.stack(row_scale, dim=0).contiguous()
2096
+ group_scale = torch.stack(group_scale, dim=0).contiguous()
2097
+ # Also view input as flattened.
2098
+ x = torch.concat(x, dim=0).contiguous()
2099
+ # Return processed tensors.
2100
+ return x, wq, row_scale, group_scale, m_sizes
2101
+
2102
+ def quantize(self, x, wq, row_scale, group_scale, m_sizes):
2103
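+ # Rowwise FP8 quantization of the stacked activations; reshape so there is one scale row per stacked input row.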
+ B = x.shape[0]
2104
+ xq, x_scale = triton_quantize_fp8_row(x)
2105
+ x_scale = x_scale.view(B, -1)
2106
+ return xq, wq, x_scale, row_scale, group_scale, m_sizes
2107
+
2108
+ def compute(self, xq, wq, x_scale, row_scale, group_scale, m_sizes):
2109
+ out = torch.ops.mslk.f8i4bf16_shuffled_grouped(
2110
+ xq, wq, x_scale, row_scale, group_scale, m_sizes
2111
+ )
2112
+ return out
2113
+
2114
+ def quantize_and_compute(self, x, wq, row_scale, group_scale, m_sizes):
2115
+ xq, wq, x_scale, row_scale, group_scale, m_sizes = self.quantize(
2116
+ x, wq, row_scale, group_scale, m_sizes
2117
+ )
2118
+ return self.compute(xq, wq, x_scale, row_scale, group_scale, m_sizes)
2119
+
2120
+ @property
2121
+ def name(self) -> str:
2122
+ if torch.version.cuda:
2123
+ return "cutlass_f8i4_grouped_preshuffle"
2124
+ else:
2125
+ return "ck_f8i4_grouped_preshuffle"
2126
+
2127
+ @property
2128
+ def supported_accelerators(self) -> set[Accelerator]:
2129
+ return {Accelerator.NVIDIA_SM90}
2130
+
2131
+ @property
2132
+ def supported_gemm_types(self) -> set[GemmType]:
2133
+ return {GemmType.GROUPED}
2134
+
2135
+ @property
2136
+ def compute_dtype(self) -> ComputeDtype:
2137
+ return ComputeDtype.FP8
2138
+
2139
+
2140
+ @register_gemm_op
2141
+ class BF16I4ShuffledGroupedGemm(GemmOpBase):
2142
+ """
2143
+ BF16 x Int4 mixed dtype grouped gemm with preshuffling.
2144
+ """
2145
+
2146
+ def preprocess(self, x, w):
2147
+ assert isinstance(x, list) and isinstance(w, list), (
2148
+ "Only supported for grouped inputs."
2149
+ )
2150
+ m_values = [i.shape[0] for i in x]
2151
+ # Convert m_values into offsets into grouped tensor.
2152
+ m_sizes = torch.tensor(m_values).to(dtype=torch.int32, device=x[0].device)
2153
+ # Quantize weights.
2154
+ wq, scales = zip(
2155
+ *[quantize_int4_preshuffle(i, dtype="bf16", use_zp=False) for i in w]
2156
+ )
2157
+ # Group weights as single tensor.
2158
+ group_scale, group_zero = zip(*scales)
2159
+ wq = torch.stack(wq, dim=0).contiguous()
2160
+ group_scale = torch.stack(group_scale, dim=0).contiguous()
2161
+ group_zero = torch.stack(group_zero, dim=0).contiguous()
2162
+ # Also view input as flattened.
2163
+ x = torch.concat(x, dim=0).contiguous()
2164
+ # Return processed tensors.
2165
+ return x, wq, group_scale, group_zero, m_sizes
2166
+
2167
+ def quantize(self, x, wq, group_scale, group_zero, m_sizes):
2168
+ return x, wq, group_scale, group_zero, m_sizes
2169
+
2170
+ def compute(self, x, wq, group_scale, group_zero, m_sizes):
2171
+ # TODO: Zero points aren't currently supported in grouped gemm.
2172
+ # We leave them as inputs for future compatibility but they are ignored.
2173
+ return torch.ops.mslk.bf16i4bf16_shuffled_grouped(
2174
+ x, wq, group_scale, group_zero, m_sizes
2175
+ )
2176
+
2177
+ def quantize_and_compute(self, x, wq, group_scale, group_zero, m_sizes):
2178
+ x, wq, group_scale, group_zero, m_sizes = self.quantize(
2179
+ x, wq, group_scale, group_zero, m_sizes
2180
+ )
2181
+ return self.compute(x, wq, group_scale, group_zero, m_sizes)
2182
+
2183
+ @property
2184
+ def name(self) -> str:
2185
+ return "cutlass_bf16i4_grouped_preshuffle"
2186
+
2187
+ @property
2188
+ def supported_accelerators(self) -> set[Accelerator]:
2189
+ return {Accelerator.NVIDIA_SM90}
2190
+
2191
+ @property
2192
+ def supported_gemm_types(self) -> set[GemmType]:
2193
+ return {GemmType.GROUPED}
2194
+
2195
+ @property
2196
+ def compute_dtype(self) -> ComputeDtype:
2197
+ return ComputeDtype.BF16
2198
+
2199
+
2200
+ @register_gemm_op
2201
+ class BF16GroupedGrad(GemmOpBase):
2202
+ """
2203
+ BF16 grouped matmul with dgrad inputs in pretraining, backed by Cutlass.
2204
+ """
2205
+
2206
+ def preprocess(self, x, w):
2207
+ m_values = [i.shape[0] for i in x]
2208
+ # Convert m_values into offsets into grouped tensor.
2209
+ m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device)
2210
+ # Group weights as single tensor.
2211
+ w = torch.stack(w, dim=0).contiguous()
2212
+ # Prepare online dgrad during pretraining backward.
2213
+ w_perm = w.permute(0, 2, 1).contiguous()
2214
+ # w.contiguous() is very expensive, so it is handled inside the grouped gemm kernel for free.
2215
+ w = w_perm.permute(0, 2, 1)
2216
+
2217
+ # Also view input as flattened.
2218
+ x = torch.concat(x, dim=0).contiguous()
2219
+ # Return processed tensors.
2220
+ return x, w, m_sizes
2221
+
2222
+ def quantize(self, x, w, m_sizes):
2223
+ return x, w, m_sizes
2224
+
2225
+ def compute(self, x, w, m_sizes):
2226
+ return torch.ops.mslk.bf16bf16bf16_grouped_grad(x, w, m_sizes)
2227
+
2228
+ def quantize_and_compute(self, x, w, m_sizes):
2229
+ x, w, m_sizes = self.quantize(x, w, m_sizes)
2230
+ return self.compute(x, w, m_sizes)
2231
+
2232
+ @property
2233
+ def name(self) -> str:
2234
+ return "bf16_grouped_grad"
2235
+
2236
+ @property
2237
+ def supported_accelerators(self) -> set[Accelerator]:
2238
+ return {
2239
+ Accelerator.NVIDIA_SM90,
2240
+ Accelerator.NVIDIA_SM100,
2241
+ Accelerator.NVIDIA_SM103,
2242
+ }
2243
+
2244
+ @property
2245
+ def supported_gemm_types(self) -> set[GemmType]:
2246
+ return {GemmType.GROUPED}
2247
+
2248
+ @property
2249
+ def compute_dtype(self) -> ComputeDtype:
2250
+ return ComputeDtype.BF16
2251
+
2252
+
2253
+ @register_gemm_op
2254
+ class BF16GroupedWGrad(GemmOpBase):
2255
+ """
2256
+ BF16 grouped matmul with wgrad inputs in pretraining, backed by Cutlass.
2257
+ """
2258
+
2259
+ def preprocess(self, x, w):
2260
+ # Get K values for each group
2261
+ k_values = [xi.shape[1] for xi in x] # K dimension for each group
2262
+
2263
+ # Convert k_values into sizes tensor
2264
+ k_sizes = torch.tensor(k_values).to(dtype=torch.int64, device=x[0].device)
2265
+
2266
+ x = torch.concat(x, dim=1).contiguous() # shape: (M, G*K)
2267
+ w = torch.concat(w, dim=1).contiguous() # shape: (N, G*K)
2268
+
2269
+ # Transpose both to simulate wgrad shapes.
2270
+ x = x.t().contiguous() # shape: (G*K, M)
2271
+ w = w.t().contiguous() # shape: (G*K, N)
2272
+
2273
+ # Return processed tensors
2274
+ return x, w, k_sizes
2275
+
2276
+ def quantize(self, x, w, k_sizes):
2277
+ return x, w, k_sizes
2278
+
2279
+ def compute(self, x, w, k_sizes):
2280
+ return torch.ops.mslk.bf16bf16bf16_grouped_wgrad(x, w, k_sizes)
2281
+
2282
+ def quantize_and_compute(self, x, w, k_sizes):
2283
+ x, w, k_sizes = self.quantize(x, w, k_sizes)
2284
+ return self.compute(x, w, k_sizes)
2285
+
2286
+ @property
2287
+ def name(self) -> str:
2288
+ return "bf16_grouped_wgrad"
2289
+
2290
+ @property
2291
+ def supported_accelerators(self) -> set[Accelerator]:
2292
+ return {
2293
+ Accelerator.NVIDIA_SM90,
2294
+ Accelerator.NVIDIA_SM100,
2295
+ Accelerator.NVIDIA_SM103,
2296
+ }
2297
+
2298
+ @property
2299
+ def supported_gemm_types(self) -> set[GemmType]:
2300
+ return {GemmType.GROUPED}
2301
+
2302
+ @property
2303
+ def compute_dtype(self) -> ComputeDtype:
2304
+ return ComputeDtype.BF16
2305
+
2306
+
2307
+ @register_gemm_op
2308
+ class BF16GroupedStacked(GemmOpBase):
2309
+ """
2310
+ BF16 grouped matmul with stacked inputs, backed by Cutlass or CK.
2311
+ """
2312
+
2313
+ def preprocess(self, x, w):
2314
+ m_values = [i.shape[0] for i in x]
2315
+ # Convert m_values into offsets into grouped tensor.
2316
+ m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device)
2317
+ # Group weights as single tensor.
2318
+ w = torch.stack(w, dim=0).contiguous()
2319
+ # Also view input as flattened.
2320
+ x = torch.concat(x, dim=0).contiguous()
2321
+ # Return processed tensors.
2322
+ return x, w, m_sizes
2323
+
2324
+ def quantize(self, x, w, m_sizes):
2325
+ return x, w, m_sizes
2326
+
2327
+ def compute(self, x, w, m_sizes):
2328
+ return torch.ops.mslk.bf16bf16bf16_grouped_stacked(x, w, m_sizes)
2329
+
2330
+ def quantize_and_compute(self, x, w, m_sizes):
2331
+ x, w, m_sizes = self.quantize(x, w, m_sizes)
2332
+ return self.compute(x, w, m_sizes)
2333
+
2334
+ @property
2335
+ def name(self) -> str:
2336
+ return "bf16_grouped_stacked"
2337
+
2338
+ @property
2339
+ def supported_accelerators(self) -> set[Accelerator]:
2340
+ return {
2341
+ Accelerator.NVIDIA_SM90,
2342
+ Accelerator.NVIDIA_SM100,
2343
+ Accelerator.NVIDIA_SM103,
2344
+ Accelerator.AMD_MI300X,
2345
+ }
2346
+
2347
+ @property
2348
+ def supported_gemm_types(self) -> set[GemmType]:
2349
+ return {GemmType.GROUPED}
2350
+
2351
+ @property
2352
+ def compute_dtype(self) -> ComputeDtype:
2353
+ return ComputeDtype.BF16
2354
+
2355
+
2356
+ @register_gemm_op
2357
+ class BF16I4RowwiseGemm(F8I4RowwiseGemm):
2358
+ """
2359
+ Mixed Precision BF16 Activations with Int4 Weights.
2360
+ """
2361
+
2362
+ def quantize(self, x, w):
2363
+ # Quantize both input tensors.
2364
+ wq, w_scale, w_zp = self._int4_row_quantize(w)
2365
+ # Pack int4 values together.
2366
+ wq = self._pack_int4(wq)
2367
+ return (
2368
+ x.to(torch.bfloat16),
2369
+ wq,
2370
+ w_scale,
2371
+ w_zp,
2372
+ )
2373
+
2374
+ def compute(self, x, wq, w_scale, w_zp):
2375
+ return torch.ops.mslk.bf16i4bf16_rowwise(x, wq, w_scale, w_zp)
2376
+
2377
+ def quantize_and_compute(self, x, w):
2378
+ x, wq, w_scale, w_zp = self.quantize(x, w)
2379
+ return self.compute(x, wq, w_scale, w_zp)
2380
+
2381
+ @property
2382
+ def name(self) -> str:
2383
+ return "cutlass_bf16i4_rowwise"
2384
+
2385
+ @property
2386
+ def supported_accelerators(self) -> set[Accelerator]:
2387
+ return {Accelerator.NVIDIA_SM90}
2388
+
2389
+ @property
2390
+ def supported_gemm_types(self) -> set[GemmType]:
2391
+ return {GemmType.REGULAR}
2392
+
2393
+ @property
2394
+ def compute_dtype(self) -> ComputeDtype:
2395
+ return ComputeDtype.BF16
2396
+
2397
+
2398
+ @register_gemm_op
2399
+ class TinyGemmBF16I4(GemmOpBase):
2400
+ """
2401
+ Mixed Precision BF16 Activations with Int4 Weights using tinygemm.
2402
+ """
2403
+
2404
+ def quantize(self, x, w):
2405
+ # Quantize and pack weights to int4 using tinygemm utils.
2406
+ w_int32, w_scales_and_zeros = group_quantize_tensor(
2407
+ w, n_bit=4, q_group_size=128
2408
+ )
2409
+ wq = torch.ops.tinygemm.convert_matrix_to_m16n8k16_Aint4_layout(w_int32, 4)
2410
+ return x, wq, w_scales_and_zeros
2411
+
2412
+ def compute(self, x, wq, scale):
2413
+ return torch.ops.tinygemm.tinygemm_y_f16RM_x_f16RM_w_int4TC(
2414
+ wq, x, 128, scale, False
2415
+ )
2416
+
2417
+ def quantize_and_compute(self, x, w):
2418
+ x, wq, scale = self.quantize(x, w)
2419
+ return self.compute(x, wq, scale)
2420
+
2421
+ @property
2422
+ def name(self) -> str:
2423
+ return "tinygemm_bf16i4"
2424
+
2425
+ @property
2426
+ def supported_accelerators(self) -> set[Accelerator]:
2427
+ if TINYGEMM_ENABLED:
2428
+ return {Accelerator.NVIDIA_SM90}
2429
+ return set()
2430
+
2431
+ @property
2432
+ def supported_gemm_types(self) -> set[GemmType]:
2433
+ return {GemmType.REGULAR}
2434
+
2435
+ @property
2436
+ def compute_dtype(self) -> ComputeDtype:
2437
+ return ComputeDtype.BF16
2438
+
2439
+
2440
+ @register_gemm_op
2441
+ class MarlinBF16I4(GemmOpBase):
2442
+ """
2443
+ Mixed Precision BF16 Activations with Int4 Weights using Marlin.
2444
+ """
2445
+
2446
+ def quantize(self, x, w):
2447
+ # Marlin quantize expects weights in [K, N] layout.
2448
+ _, wq, scale = marlin_quantize(w.t().contiguous(), 128)
2449
+ return x, wq, scale
2450
+
2451
+ def compute(self, x, wq, scale):
2452
+ return torch.ops.marlin.marlin_gemm(x, wq, scale)
2453
+
2454
+ def quantize_and_compute(self, x, w):
2455
+ x, wq, scale = self.quantize(x, w)
2456
+ return self.compute(x, wq, scale)
2457
+
2458
+ @property
2459
+ def name(self) -> str:
2460
+ return "marlin_bf16i4"
2461
+
2462
+ @property
2463
+ def supported_accelerators(self) -> set[Accelerator]:
2464
+ if MARLIN_ENABLED:
2465
+ return {
2466
+ Accelerator.NVIDIA_SM90,
2467
+ Accelerator.NVIDIA_SM100,
2468
+ Accelerator.NVIDIA_SM103,
2469
+ }
2470
+ return set()
2471
+
2472
+ @property
2473
+ def supported_gemm_types(self) -> set[GemmType]:
2474
+ return {GemmType.REGULAR}
2475
+
2476
+ @property
2477
+ def compute_dtype(self) -> ComputeDtype:
2478
+ return ComputeDtype.BF16
2479
+
2480
+
2481
+ @register_gemm_op
2482
+ class MacheteBF16I4(GemmOpBase):
2483
+ """
2484
+ Mixed Precision BF16 Activations with Int4 Weights using Machete.
2485
+ """
2486
+
2487
+ def quantize(self, x, w):
2488
+ # Machete quantize expects weights in [K, N] layout.
2489
+ _, wq, scale, _ = machete_quantize_and_pack(
2490
+ w.t().contiguous(), bits=4, groupsize=128
2491
+ )
2492
+ return x, wq, scale
2493
+
2494
+ def compute(self, x, wq, scale):
2495
+ return machete_gemm(x, wq, bits=4, groupsize=128, scales=scale)
2496
+
2497
+ def quantize_and_compute(self, x, w):
2498
+ x, wq, scale = self.quantize(x, w)
2499
+ return self.compute(x, wq, scale)
2500
+
2501
+ @property
2502
+ def name(self) -> str:
2503
+ return "machete_bf16i4"
2504
+
2505
+ @property
2506
+ def supported_accelerators(self) -> set[Accelerator]:
2507
+ if MACHETE_ENABLED:
2508
+ return {Accelerator.NVIDIA_SM90}
2509
+ return set()
2510
+
2511
+ @property
2512
+ def supported_gemm_types(self) -> set[GemmType]:
2513
+ return {GemmType.REGULAR}
2514
+
2515
+ @property
2516
+ def compute_dtype(self) -> ComputeDtype:
2517
+ return ComputeDtype.BF16
2518
+
2519
+
2520
+ @register_gemm_op
2521
+ class NVFP4Gemm(GemmOpBase):
2522
+ """
2523
+ NVFP4 matmul with block-wise scaling.
2524
+ """
2525
+
2526
+ def quantize(self, x, w):
2527
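+ # Per-tensor global scales: 448.0 is the FP8 (e4m3) max and 6.0 the FP4 (e2m1) max,
+ # so (448.0 * 6.0) / amax maps each tensor into the representable NVFP4 range;
+ # global_scale is the combined inverse factor applied to the output.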
+ x_global_scale = (448.0 * 6.0) / torch.amax(torch.abs(x.flatten()), dim=-1).to(
2528
+ torch.float32
2529
+ )
2530
+ w_global_scale = (448.0 * 6.0) / torch.amax(torch.abs(w.flatten()), dim=-1).to(
2531
+ torch.float32
2532
+ )
2533
+ global_scale = 1 / (x_global_scale * w_global_scale)
2534
+
2535
+ xq, x_scale = triton_quantize_nvfp4(x, x_global_scale)
2536
+ wq, w_scale = triton_quantize_nvfp4(w, w_global_scale)
2537
+
2538
+ return xq, wq, x_scale, w_scale, global_scale
2539
+
2540
+ def compute(self, xq, wq, x_scale, w_scale, global_scale):
2541
+ return torch.ops.mslk.f4f4bf16(
2542
+ xq, wq, x_scale, w_scale, global_scale=global_scale
2543
+ )
2544
+
2545
+ def quantize_and_compute(self, x, w):
2546
+ xq, wq, x_scale, w_scale, global_scale = self.quantize(x, w)
2547
+ return self.compute(xq, wq, x_scale, w_scale, global_scale=global_scale)
2548
+
2549
+ @property
2550
+ def name(self) -> str:
2551
+ return "cutlass_nv_f4f4bf16"
2552
+
2553
+ @property
2554
+ def supported_accelerators(self) -> set[Accelerator]:
2555
+ return {Accelerator.NVIDIA_SM100, Accelerator.NVIDIA_SM103}
2556
+
2557
+ @property
2558
+ def supported_gemm_types(self) -> set[GemmType]:
2559
+ return {GemmType.REGULAR}
2560
+
2561
+ @property
2562
+ def compute_dtype(self) -> ComputeDtype:
2563
+ return ComputeDtype.FP4
2564
+
2565
+
2566
+ @register_gemm_op
2567
+ class NVFP4Quantize(GemmOpBase):
2568
+ """
2569
+ NVFP4 quantization with block-wise scaling.
2570
+ """
2571
+
2572
+ def quantize_rms(self, x, w):
2573
+ M, N = w.shape
2574
+ group_size = 16
2575
+ w = torch.randn(group_size, dtype=torch.bfloat16, device=w.device)
2576
+ x_global_scale = torch.tensor([448.0 * 6.0]).to(
2577
+ device=x.device, dtype=torch.float32
2578
+ ) / torch.amax(torch.abs(x.flatten()), dim=-1).to(torch.float32)
2579
+ xq_ref, x_scale_ref = triton_scale_nvfp4_quant_rms(
2580
+ x,
2581
+ w.repeat(M * N // group_size),
2582
+ x_global_scale,
2583
+ group_size=group_size,
2584
+ EPS=1e-5,
2585
+ )
2586
+
2587
+ intermediate = rms_norm(x.reshape(-1, 16), w, eps=1e-5)
2588
+ intermediate = intermediate.to(torch.bfloat16).reshape(M, N)
2589
+ xq, x_scale = triton_quantize_nvfp4(
2590
+ intermediate,
2591
+ x_global_scale,
2592
+ group_size=group_size,
2593
+ )
2594
+
2595
+ def quantize_silu(self, x, w):
2596
+ M, N = x.shape
2597
+ group_size = 16
2598
+ x_global_scale = torch.tensor([448.0 * 6.0]).to(
2599
+ device=x.device, dtype=torch.float32
2600
+ ) / torch.amax(torch.abs(x.flatten()), dim=-1).to(torch.float32)
2601
+ xq_ref, x_scale_ref = triton_scale_nvfp4_quant_silu(
2602
+ x,
2603
+ w,
2604
+ x_global_scale,
2605
+ group_size=group_size,
2606
+ )
2607
+
2608
+ intermediate = silu_mul(x.reshape(-1, 16), w.reshape(-1, 16))
2609
+ intermediate = intermediate.to(torch.bfloat16).reshape(M, N)
2610
+ xq, x_scale = triton_quantize_nvfp4(
2611
+ intermediate,
2612
+ x_global_scale,
2613
+ group_size=group_size,
2614
+ )
2615
+
2616
+ def quantize(self, x, w):
2617
+ x_global_scale = (448.0 * 6.0) / torch.amax(torch.abs(x.flatten()), dim=-1).to(
2618
+ torch.float32
2619
+ )
2620
+ w_global_scale = (448.0 * 6.0) / torch.amax(torch.abs(w.flatten()), dim=-1).to(
2621
+ torch.float32
2622
+ )
2623
+ global_scale = 1 / (x_global_scale * w_global_scale)
2624
+
2625
+ xq, x_scale = triton_quantize_nvfp4(x, x_global_scale)
2626
+ wq, w_scale = triton_quantize_nvfp4(w, w_global_scale)
2627
+ return xq, wq, x_scale, w_scale, global_scale
2628
+
2629
+ def compute(self, xq, wq, x_scale, w_scale, global_scale):
2630
+ return torch.ops.mslk.f4f4bf16(
2631
+ xq, wq, x_scale, w_scale, global_scale=global_scale
2632
+ )
2633
+
2634
+ def quantize_and_compute(self, x, w):
2635
+ return self.quantize(x, w)
2636
+
2637
+ @property
2638
+ def name(self) -> str:
2639
+ return "nvfp4_quantize"
2640
+
2641
+ @property
2642
+ def supported_accelerators(self) -> set[Accelerator]:
2643
+ return {Accelerator.NVIDIA_SM100, Accelerator.NVIDIA_SM103}
2644
+
2645
+ @property
2646
+ def supported_gemm_types(self) -> set[GemmType]:
2647
+ return {GemmType.REGULAR}
2648
+
2649
+ @property
2650
+ def compute_dtype(self) -> ComputeDtype:
2651
+ return ComputeDtype.FP4
2652
+
2653
+
2654
+ @register_gemm_op
2655
+ class MXFP4Gemm(GemmOpBase):
2656
+ """
2657
+ MXFP4 matmul with block-wise scaling.
2658
+ """
2659
+
2660
+ def quantize(self, x, w):
2661
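+ # MX4 quantization shares one exponent scale per 32-element block; scales are returned unpacked.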
+ xq, x_scale = triton_quantize_mx4_unpack(x)
2662
+ wq, w_scale = triton_quantize_mx4_unpack(w)
2663
+ return xq, wq, x_scale, w_scale
2664
+
2665
+ def compute(self, xq, wq, x_scale, w_scale):
2666
+ return torch.ops.mslk.f4f4bf16(xq, wq, x_scale, w_scale)
2667
+
2668
+ def quantize_and_compute(self, x, w):
2669
+ xq, wq, x_scale, w_scale = self.quantize(x, w)
2670
+ return self.compute(xq, wq, x_scale, w_scale)
2671
+
2672
+ @property
2673
+ def name(self) -> str:
2674
+ return "cutlass_f4f4bf16"
2675
+
2676
+ @property
2677
+ def supported_accelerators(self) -> set[Accelerator]:
2678
+ return {Accelerator.NVIDIA_SM100, Accelerator.NVIDIA_SM103}
2679
+
2680
+ @property
2681
+ def supported_gemm_types(self) -> set[GemmType]:
2682
+ return {GemmType.REGULAR}
2683
+
2684
+ @property
2685
+ def compute_dtype(self) -> ComputeDtype:
2686
+ return ComputeDtype.FP4
2687
+
2688
+
2689
+ @register_gemm_op
2690
+ class MXFP4StackedGroupedGemm(GemmOpBase):
2691
+ """
2692
+ MXFP4 grouped matmul with blockwise scaling and stacked inputs.
2693
+ """
2694
+
2695
+ def preprocess(self, x, w):
2696
+ m_values = [i.shape[0] for i in x]
2697
+ m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device)
2698
+ wq, w_scale = zip(*[triton_quantize_mx4_unpack(i) for i in w])
2699
+ wq = torch.stack(wq, dim=0).contiguous()
2700
+ w_scale = torch.stack(w_scale, dim=0).contiguous()
2701
+ return x, wq, w_scale, m_sizes
2702
+
2703
+ def quantize(self, x, wq, w_scale, m_sizes):
2704
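+ # Quantize each non-empty group separately and record where each group's (padded)
+ # scale rows start, since the grouped kernel indexes the stacked scale tensor per group.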
+ starting_row_after_padding_list = [0]
2705
+ xq_list = []
2706
+ x_scale_list = []
2707
+ for i in range(m_sizes.shape[0]):
2708
+ scale_slice = x[i]
2709
+ if m_sizes[i].item() != 0:
2710
+ xq, x_scale = triton_quantize_mx4_unpack(scale_slice)
2711
+ xq_list.append(xq)
2712
+ x_scale_list.append(x_scale)
2713
+ starting_row_after_padding_list.append(
2714
+ starting_row_after_padding_list[i]
2715
+ + x_scale.numel() // (x[0].shape[1] // 32)
2716
+ )
2717
+ else:
2718
+ starting_row_after_padding_list.append(
2719
+ starting_row_after_padding_list[i]
2720
+ )
2721
+ xq = torch.cat(xq_list, dim=0).contiguous()
2722
+ x_scale = torch.cat(x_scale_list, dim=0).contiguous()
2723
+ x_scale = x_scale.reshape(-1, x[0].shape[-1] // 32)
2724
+ xq = xq.view(-1, xq.shape[-1])
2725
+ return (
2726
+ xq,
2727
+ wq,
2728
+ x_scale,
2729
+ w_scale,
2730
+ m_sizes,
2731
+ torch.tensor(starting_row_after_padding_list, device=xq.device),
2732
+ )
2733
+
2734
+ def compute(self, xq, wq, x_scale, w_scale, m_sizes, starting_row_after_padding):
2735
+ return torch.ops.mslk.f4f4bf16_grouped_stacked(
2736
+ xq,
2737
+ wq,
2738
+ x_scale,
2739
+ w_scale,
2740
+ m_sizes,
2741
+ starting_row_after_padding=starting_row_after_padding,
2742
+ )
2743
+
2744
+ def quantize_and_compute(self, x, w):
2745
+ xq, wq, x_scale, w_scale, m_sizes, starting_row_after_padding = self.quantize(
2746
+ x, w
2747
+ )
2748
+ return self.compute(
2749
+ xq, wq, x_scale, w_scale, m_sizes, starting_row_after_padding
2750
+ )
2751
+
2752
+ @property
2753
+ def name(self) -> str:
2754
+ return "cutlass_f4f4bf16_grouped_stacked"
2755
+
2756
+ @property
2757
+ def supported_accelerators(self) -> set[Accelerator]:
2758
+ return {Accelerator.NVIDIA_SM100, Accelerator.NVIDIA_SM103}
2759
+
2760
+ @property
2761
+ def supported_gemm_types(self) -> set[GemmType]:
2762
+ return {GemmType.GROUPED}
2763
+
2764
+ @property
2765
+ def compute_dtype(self) -> ComputeDtype:
2766
+ return ComputeDtype.FP4
2767
+
2768
+
2769
+ @register_gemm_op
2770
+ class NVFP4StackedGroupedGemm(GemmOpBase):
2771
+ """
2772
+ NVFP4 grouped matmul with blockwise scaling and stacked inputs.
2773
+ """
2774
+
2775
+ def preprocess(self, x, w):
2776
+ m_values = [i.shape[0] for i in x]
2777
+ m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device)
2778
+ x = torch.concat(x, dim=0).contiguous()
2779
+
2780
+ def get_global_scale(x, w, m_sizes):
2781
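+ # Per-group global scales: the activation scale comes from a grouped max over the
+ # stacked x, the weight scale from each weight's max; global_scale combines both.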
+ G = len(w)
2782
+ w_global_scale = []
2783
+ global_scale = []
2784
+
2785
+ cumulative_sum = torch.zeros(
2786
+ m_sizes.shape[0] + 1, dtype=torch.int64, device=m_sizes.device
2787
+ )
2788
+ cumulative_sum[1:] = torch.cumsum(m_sizes, dim=0)
2789
+
2790
+ x_global_scale, tensor_idx = calculate_group_max(x, m_sizes=m_sizes)
2791
+
2792
+ for i in range(G):
2793
+ w_global_scale_ = (448.0 * 6.0) / torch.amax(
2794
+ torch.abs(w[i].flatten()), dim=-1
2795
+ ).to(torch.float32)
2796
+
2797
+ global_scale_ = 1 / (x_global_scale[i] * w_global_scale_)
2798
+
2799
+ w_global_scale.append(w_global_scale_)
2800
+ global_scale.append(global_scale_)
2801
+
2802
+ return x_global_scale, w_global_scale, global_scale, tensor_idx
2803
+
2804
+ # Compute global scale for each group
2805
+ G = m_sizes.numel()
2806
+ x_global_scale, w_global_scale, global_scale, tensor_idx = get_global_scale(
2807
+ x, w, m_sizes
2808
+ )
2809
+ global_scale = torch.stack(global_scale, dim=0).contiguous()
2810
+
2811
+ wq, w_scale = zip(
2812
+ *[triton_quantize_nvfp4(w[i], w_global_scale[i]) for i in range(G)]
2813
+ )
2814
+ wq = torch.stack(wq, dim=0).contiguous()
2815
+ w_scale = torch.stack(w_scale, dim=0).contiguous()
2816
+
2817
+ return x, wq, w_scale, x_global_scale, global_scale, m_sizes, tensor_idx
2818
+
2819
+ def quantize(
2820
+ self, x, wq, w_scale, x_global_scale, global_scale, m_sizes, tensor_idx
2821
+ ):
2822
+ # Alternative methods that may be useful in some scenarios:
2823
+ """
2824
+ starting_row_after_padding, belong_indices, row_within_tensor = (
2825
+ nvfp4_fused_padding_cumsum_and_segmented_arange(m_sizes, x.shape[0])
2826
+ # fused_single_block_cumsum_and_segmented_arange(m_sizes, x.shape[0])
2827
+ )
2828
+
2829
+ xq, x_scale = triton_nvfp4_quant_stacked(
2830
+ x,
2831
+ x_global_scale[0],
2832
+ belong_indices,
2833
+ starting_row_after_padding,
2834
+ row_within_tensor,
2835
+ )
2836
+ """
2837
+
2838
+ # we can optionally set optional_tensor_idx to None to run the alternative method
2839
+ xq, x_scale, starting_row_after_padding = mega_fp4_quantize_kernel(
2840
+ m_sizes, x, x_global_scale, optional_tensor_idx=tensor_idx
2841
+ )
2842
+
2843
+ x_scale = x_scale.reshape(-1, x.shape[1] // 16)
2844
+ return (
2845
+ xq,
2846
+ wq,
2847
+ x_scale,
2848
+ w_scale,
2849
+ m_sizes,
2850
+ global_scale,
2851
+ starting_row_after_padding,
2852
+ )
2853
+
2854
+ def compute(
2855
+ self,
2856
+ xq,
2857
+ wq,
2858
+ x_scale,
2859
+ w_scale,
2860
+ m_sizes,
2861
+ global_scale,
2862
+ starting_row_after_padding,
2863
+ ):
2864
+ gemm_result = torch.ops.mslk.f4f4bf16_grouped_stacked(
2865
+ xq,
2866
+ wq,
2867
+ x_scale,
2868
+ w_scale,
2869
+ m_sizes,
2870
+ global_scale,
2871
+ starting_row_after_padding,
2872
+ use_mx=False,
2873
+ )
2874
+ return gemm_result
2875
+
2876
+ def quantize_and_compute(
2877
+ self, x, wq, w_scale, x_global_scale, global_scale, m_sizes, tensor_idx
2878
+ ):
2879
+ (
2880
+ xq,
2881
+ wq,
2882
+ x_scale,
2883
+ w_scale,
2884
+ m_sizes,
2885
+ global_scale,
2886
+ starting_row_after_padding,
2887
+ ) = self.quantize(
2888
+ x, wq, w_scale, x_global_scale, global_scale, m_sizes, tensor_idx
2889
+ )
2890
+ return self.compute(
2891
+ xq,
2892
+ wq,
2893
+ x_scale,
2894
+ w_scale,
2895
+ m_sizes,
2896
+ global_scale,
2897
+ starting_row_after_padding,
2898
+ )
2899
+
2900
+ @property
2901
+ def name(self) -> str:
2902
+ return "cutlass_nv_f4f4bf16_grouped_stacked"
2903
+
2904
+ @property
2905
+ def supported_accelerators(self) -> set[Accelerator]:
2906
+ return {Accelerator.NVIDIA_SM100, Accelerator.NVIDIA_SM103}
2907
+
2908
+ @property
2909
+ def supported_gemm_types(self) -> set[GemmType]:
2910
+ return {GemmType.GROUPED}
2911
+
2912
+ @property
2913
+ def compute_dtype(self) -> ComputeDtype:
2914
+ return ComputeDtype.FP4
2915
+
2916
+
2917
+ # Broken with cuda graph
2918
+ # @register_gemm_op
2919
+ class NVFP4StackedGroupedGemmPackUnpack(GemmOpBase):
2920
+ """
2921
+ NVFP4 grouped matmul with blockwise scaling and stacked inputs, exercising the pack/unpack path.
2922
+ """
2923
+
2924
+ def preprocess(self, x, w):
2925
+ m_values = [i.shape[0] for i in x]
2926
+ m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device)
2927
+ x = torch.concat(x, dim=0).contiguous()
2928
+
2929
+ def get_global_scale(x, w):
2930
+ G = len(w)
2931
+ x_global_scale = []
2932
+ w_global_scale = []
2933
+ global_scale = []
2934
+
2935
+ x_global_scale_ = (448.0 * 6.0) / torch.amax(
2936
+ torch.abs(x.flatten()), dim=-1
2937
+ ).to(torch.float32)
2938
+
2939
+ for i in range(G):
2940
+ w_global_scale_ = (448.0 * 6.0) / torch.amax(
2941
+ torch.abs(w[i].flatten()), dim=-1
2942
+ ).to(torch.float32)
2943
+
2944
+ global_scale_ = 1 / (x_global_scale_ * w_global_scale_)
2945
+
2946
+ x_global_scale.append(x_global_scale_)
2947
+ w_global_scale.append(w_global_scale_)
2948
+ global_scale.append(global_scale_)
2949
+
2950
+ return x_global_scale, w_global_scale, global_scale
2951
+
2952
+ # Compute global scale for each group
2953
+ G = m_sizes.numel()
2954
+ x_global_scale, w_global_scale, global_scale = get_global_scale(x, w)
2955
+
2956
+ global_scale = torch.stack(global_scale, dim=0).contiguous()
2957
+
2958
+ wq, w_scale = zip(
2959
+ *[triton_quantize_nvfp4(w[i], w_global_scale[i]) for i in range(G)]
2960
+ )
2961
+ wq = torch.stack(wq, dim=0).contiguous()
2962
+ w_scale = torch.stack(w_scale, dim=0).contiguous()
2963
+ x_global_scale = torch.tensor(x_global_scale, device=m_sizes.device)
2964
+ return (
2965
+ x,
2966
+ wq,
2967
+ w_scale,
2968
+ x_global_scale,
2969
+ global_scale,
2970
+ m_sizes,
2971
+ )
2972
+
2973
+ def quantize(self, x, wq, w_scale, x_global_scale, global_scale, m_sizes):
2974
+ # An alternative packing method that uses only the overall global scale rather than per-tensor scales:
2975
+ """
2976
+ packed = mega_fp4_pack(x, x_global_scale[0])
2977
+ """
2978
+ packed = mega_fp4_pack(
2979
+ x,
2980
+ x_global_scale,
2981
+ per_tensor=True,
2982
+ m_sizes=m_sizes,
2983
+ )
2984
+ xq, x_scale, starting_row_after_padding = mega_fp4_unpack(m_sizes, packed)
2985
+ xq_other, x_scale_other, starting_row_after_padding_other = (
2986
+ mega_fp4_quantize_kernel(
2987
+ m_sizes,
2988
+ x,
2989
+ x_global_scale,
2990
+ )
2991
+ )
2992
+
2993
+ x_scale = x_scale.reshape(-1, x.shape[1] // 16)
2994
+ x_scale_other = x_scale_other.reshape(-1, x.shape[1] // 16)
2995
+ return (
2996
+ xq,
2997
+ wq,
2998
+ x_scale,
2999
+ w_scale,
3000
+ m_sizes,
3001
+ global_scale,
3002
+ starting_row_after_padding,
3003
+ xq_other,
3004
+ x_scale_other,
3005
+ starting_row_after_padding_other,
3006
+ )
3007
+
3008
+ def compute(
3009
+ self,
3010
+ xq,
3011
+ wq,
3012
+ x_scale,
3013
+ w_scale,
3014
+ m_sizes,
3015
+ global_scale,
3016
+ starting_row_after_padding,
3017
+ xq_other,
3018
+ x_scale_other,
3019
+ starting_row_after_padding_other,
3020
+ ):
3021
+ ref_solution = torch.ops.mslk.f4f4bf16_grouped_stacked(
3022
+ xq_other,
3023
+ wq,
3024
+ x_scale_other,
3025
+ w_scale,
3026
+ m_sizes,
3027
+ global_scale,
3028
+ starting_row_after_padding_other,
3029
+ use_mx=False,
3030
+ )
3031
+ gemm_result = torch.ops.mslk.f4f4bf16_grouped_stacked(
3032
+ xq,
3033
+ wq,
3034
+ x_scale,
3035
+ w_scale,
3036
+ m_sizes,
3037
+ global_scale,
3038
+ starting_row_after_padding,
3039
+ use_mx=False,
3040
+ )
3041
+ assert torch.allclose(ref_solution, gemm_result)
3042
+
3043
+ return gemm_result
3044
+
3045
+ def quantize_and_compute(
3046
+ self, x, wq, w_scale, x_global_scale, global_scale, m_sizes
3047
+ ):
3048
+ (
3049
+ xq,
3050
+ wq,
3051
+ x_scale,
3052
+ w_scale,
3053
+ m_sizes,
3054
+ global_scale,
3055
+ starting_row_after_padding,
3056
+ xq_other,
3057
+ x_scale_other,
3058
+ starting_row_after_padding_other,
3059
+ ) = self.quantize(x, wq, w_scale, x_global_scale, global_scale, m_sizes)
3060
+ return self.compute(
3061
+ xq,
3062
+ wq,
3063
+ x_scale,
3064
+ w_scale,
3065
+ m_sizes,
3066
+ global_scale,
3067
+ starting_row_after_padding,
3068
+ xq_other,
3069
+ x_scale_other,
3070
+ starting_row_after_padding_other,
3071
+ )
3072
+
3073
+ @property
3074
+ def name(self) -> str:
3075
+ return "cutlass_nv_f4f4bf16_grouped_stacked_pack_unpack"
3076
+
3077
+ @property
3078
+ def supported_accelerators(self) -> set[Accelerator]:
3079
+ return {Accelerator.NVIDIA_SM100, Accelerator.NVIDIA_SM103}
3080
+
3081
+ @property
3082
+ def supported_gemm_types(self) -> set[GemmType]:
3083
+ return {GemmType.GROUPED}
3084
+
3085
+ @property
3086
+ def compute_dtype(self) -> ComputeDtype:
3087
+ return ComputeDtype.FP4
3088
+
3089
+
3090
+ @register_gemm_op
3091
+ class BF16GroupedGemm2d3d(GemmOpBase):
3092
+ """
3093
+ Torch BF16 grouped GEMM with 2D inputs and 3D weights.
3094
+ """
3095
+
3096
+ def preprocess(self, x, w):
3097
+ assert isinstance(x, list)
3098
+ assert isinstance(w, list)
3099
+ offs = torch.tensor(
3100
+ [i.shape[0] for i in x], dtype=torch.int32, device=x[0].device
3101
+ )
3102
+ offs = torch.cumsum(offs, dim=0).to(torch.int32)
3103
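+ # offs holds the cumulative end-row offset of each group along the stacked M dimension.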
+ x = torch.cat(x, dim=0).contiguous() # (G * M, K)
3104
+ w = torch.stack(w, dim=0).contiguous() # (G, N, K)
3105
+ return x, w, offs
3106
+
3107
+ def quantize(self, x, w, offs):
3108
+ return x, w, offs
3109
+
3110
+ def compute(self, x, w, offs):
3111
+ return torch._grouped_mm(
3112
+ x,
3113
+ w.transpose(-2, -1),
3114
+ offs=offs,
3115
+ )
3116
+
3117
+ def quantize_and_compute(self, x, w, offs):
3118
+ x, w, offs = self.quantize(x, w, offs)
3119
+ return self.compute(x, w, offs)
3120
+
3121
+ @property
3122
+ def name(self) -> str:
3123
+ return "bf16_baseline_grouped_2d_3d"
3124
+
3125
+ @property
3126
+ def supported_accelerators(self) -> set[Accelerator]:
3127
+ return set(Accelerator)
3128
+
3129
+ @property
3130
+ def supported_gemm_types(self) -> set[GemmType]:
3131
+ return {GemmType.GROUPED}
3132
+
3133
+ @property
3134
+ def compute_dtype(self) -> ComputeDtype:
3135
+ return ComputeDtype.BF16
3136
+
3137
+
3138
+ @register_gemm_op
3139
+ class MXFP8GroupedGemm2d3d(GemmOpBase):
3140
+ """
3141
+ MXFP8 grouped GEMM with 2D inputs and 3D weights.
3142
+ """
3143
+
3144
+ def preprocess(self, x, w):
3145
+ assert isinstance(x, list)
3146
+ assert isinstance(w, list)
3147
+ x = torch.cat(x, dim=0).contiguous() # (G * M, K)
3148
+ w = torch.stack(w, dim=0).contiguous() # (G, N, K)
3149
+ return x, w
3150
+
3151
+ def quantize(self, x, w):
3152
+ block_size = 32
3153
+ G, N, K = w.shape
3154
+ total_M = x.shape[0]
3155
+ group_size = total_M // G
3156
+ input_group_end_offsets = torch.arange(
3157
+ group_size, total_M + 1, group_size, dtype=torch.int32, device=x.device
3158
+ )
3159
+
3160
+ # For each constituent 2d subtensor in the 3d weights, quantize and convert scale to blocked format separately,
3161
+ # as each is used for an independent gemm within the grouped gemm.
3162
+ wq_list = []
3163
+ w_scale_list = []
3164
+ for i in range(G):
3165
+ w_scale, wq = to_mxfp8(w[i])
3166
+ w_scale = _to_blocked(w_scale)
3167
+ wq_list.append(wq)
3168
+ w_scale_list.append(w_scale)
3169
+ wq = torch.stack(wq_list, dim=0).contiguous()
3170
+ w_scale = torch.stack(w_scale_list, dim=0).contiguous()
3171
+
3172
+ # For each group along `total_M` in the 2D tensor, quantize and convert scale to blocked format separately,
3173
+ # as each is used for an independent gemm within the grouped gemm.
3174
+ xq_list = []
3175
+ x_scale_list = []
3176
+ for i in range(G):
3177
+ prev_group_end = 0 if i == 0 else input_group_end_offsets[i - 1]
3178
+ curr_group_end = input_group_end_offsets[i]
3179
+ group_size = curr_group_end - prev_group_end
3180
+ if group_size > 0:
3181
+ x_slice = x[prev_group_end:curr_group_end, :]
3182
+ x_scale, xq = to_mxfp8(x_slice)
3183
+ x_scale = _to_blocked(x_scale)
3184
+ xq_list.append(xq)
3185
+ x_scale_list.append(x_scale)
3186
+ xq = torch.cat(xq_list, dim=0).contiguous()
3187
+ x_scale = torch.cat(x_scale_list, dim=0).contiguous()
3188
+ x_scale = x_scale.reshape(-1, K // block_size)
3189
+ xq = xq.view(-1, xq.shape[-1])
3190
+ return xq, wq, x_scale, w_scale, input_group_end_offsets
3191
+
3192
+ def compute(self, xq, wq, x_scale, w_scale, input_group_end_offsets):
3193
+ return torch.ops.mslk.mx8mx8bf16_grouped_mm(
3194
+ xq,
3195
+ wq.transpose(-2, -1),
3196
+ x_scale,
3197
+ w_scale,
3198
+ input_group_end_offsets,
3199
+ )
3200
+
3201
+ def quantize_and_compute(self, x, w):
3202
+ xq, wq, x_scale, w_scale, input_group_end_offsets = self.quantize(x, w)
3203
+ return self.compute(
3204
+ xq,
3205
+ wq,
3206
+ x_scale,
3207
+ w_scale,
3208
+ input_group_end_offsets,
3209
+ )
3210
+
3211
+ @property
3212
+ def name(self) -> str:
3213
+ return "cutlass_mx8mx8bf16_grouped_mm_2d_3d"
3214
+
3215
+ @property
3216
+ def supported_accelerators(self) -> set[Accelerator]:
3217
+ return {Accelerator.NVIDIA_SM100, Accelerator.NVIDIA_SM103}
3218
+
3219
+ @property
3220
+ def supported_gemm_types(self) -> set[GemmType]:
3221
+ return {GemmType.GROUPED}
3222
+
3223
+ @property
3224
+ def compute_dtype(self) -> ComputeDtype:
3225
+ return ComputeDtype.FP8
3226
+
3227
+
3228
+ @register_gemm_op
3229
+ class MXFP8GroupedGemm2d2d(GemmOpBase):
3230
+ """
3231
+ MXFP8 grouped GEMM with 2D inputs and 2D weights.
3232
+ """
3233
+
3234
+ def preprocess(self, x, w):
3235
+ assert isinstance(x, list)
3236
+ assert isinstance(w, list)
3237
+ G = len(x)
3238
+ x = torch.cat(x, dim=1).contiguous() # (M, total_K)
3239
+ w = torch.cat(w, dim=1).contiguous() # (N, total_K)
3240
+ return x, w, G
3241
+
3242
+ def quantize(self, x, w, G):
3243
+ # Simulate 2d-2d grouped gemm in backward pass `grad_weight = grad_output_t @ input`,
3244
+ # where we use "K" as the contracting dim which has "G" groups.
3245
+ M, total_K = x.shape
3246
+ N, _ = w.shape
3247
+ group_size = total_K // G
3248
+ input_group_end_offsets = torch.arange(
3249
+ group_size, total_K + 1, group_size, dtype=torch.int32, device=x.device
3250
+ )
3251
+
3252
+ # Convert scales to blocked format.
3253
+ x_list = []
3254
+ w_list = []
3255
+ x_blocked_scale_list = []
3256
+ w_blocked_scale_list = []
3257
+
3258
+ def round_up(x: int, y: int) -> int:
3259
+ return ((x + y - 1) // y) * y
3260
+
3261
+ for group_idx in range(G):
3262
+ # to_mxfp8 per group
3263
+ prev_group_end_offset = (
3264
+ 0 if group_idx == 0 else input_group_end_offsets[group_idx - 1]
3265
+ )
3266
+ curr_group_end_offset = input_group_end_offsets[group_idx]
3267
+ group_size = curr_group_end_offset - prev_group_end_offset
3268
+ if group_size > 0:
3269
+ x_slice = x[
3270
+ :, prev_group_end_offset:curr_group_end_offset
3271
+ ].contiguous() # (M, K_group)
3272
+ w_slice = w[
3273
+ :, prev_group_end_offset:curr_group_end_offset
3274
+ ].contiguous() # (N, K_group)
3275
+ x_scale_slice, xq_slice = to_mxfp8(
3276
+ x_slice
3277
+ ) # scale shape -> (M, K_group // 32)
3278
+ w_scale_slice, wq_slice = to_mxfp8(
3279
+ w_slice
3280
+ ) # scale shape -> (N, K_group // 32)
3281
+ x_list.append(xq_slice)
3282
+ w_list.append(wq_slice)
3283
+
3284
+ # Convert scales to blocked format.
3285
+ x_scale_slice_blocked = _to_blocked(
3286
+ x_scale_slice
3287
+ ) # (round_up(M, 128), round_up(K_group//32, 4))
3288
+ w_scale_slice_blocked = _to_blocked(
3289
+ w_scale_slice
3290
+ ) # (round_up(N, 128), round_up(K_group//32, 4))
3291
+ x_blocked_scale_list.append(x_scale_slice_blocked)
3292
+ w_blocked_scale_list.append(w_scale_slice_blocked)
3293
+
3294
+ # Assemble the full XQ and WQ
3295
+ xq = torch.cat(x_list, dim=1).contiguous()
3296
+ wq = torch.cat(w_list, dim=1).contiguous()
3297
+
3298
+ # Combine all XQ groups blocked scales into one tensor.
3299
+ x_blocked_scales = torch.cat(x_blocked_scale_list, dim=0)
3300
+ M_rounded = round_up(M, 128)
3301
+ x_blocked_scales = x_blocked_scales.reshape(M_rounded, -1)
3302
+
3303
+ # Combine all WQ groups blocked scales into one tensor.
3304
+ w_blocked_scales = torch.cat(w_blocked_scale_list, dim=0)
3305
+ N_rounded = round_up(N, 128)
3306
+ w_blocked_scales = w_blocked_scales.reshape(N_rounded, -1)
3307
+ return xq, wq, x_blocked_scales, w_blocked_scales, input_group_end_offsets
3308
+
3309
+ def compute(self, xq, wq, x_scale, w_scale, input_group_end_offsets):
3310
+ return torch.ops.mslk.mx8mx8bf16_grouped_mm(
3311
+ xq,
3312
+ wq.transpose(-2, -1),
3313
+ x_scale,
3314
+ w_scale,
3315
+ input_group_end_offsets,
3316
+ )
3317
+
3318
+ def quantize_and_compute(self, x, w):
3319
+ xq, wq, x_scale, w_scale, input_group_end_offsets = self.quantize(x, w)
3320
+ return self.compute(
3321
+ xq,
3322
+ wq,
3323
+ x_scale,
3324
+ w_scale,
3325
+ input_group_end_offsets,
3326
+ )
3327
+
3328
+ @property
3329
+ def name(self) -> str:
3330
+ return "cutlass_mx8mx8bf16_grouped_mm_2d_2d"
3331
+
3332
+ @property
3333
+ def supported_accelerators(self) -> set[Accelerator]:
3334
+ return {Accelerator.NVIDIA_SM100, Accelerator.NVIDIA_SM103}
3335
+
3336
+ @property
3337
+ def supported_gemm_types(self) -> set[GemmType]:
3338
+ return {GemmType.GROUPED}
3339
+
3340
+ @property
3341
+ def compute_dtype(self) -> ComputeDtype:
3342
+ return ComputeDtype.FP8