mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. mslk/__init__.py +56 -0
  2. mslk/attention/__init__.py +7 -0
  3. mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
  4. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
  5. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
  6. mslk/attention/flash_attn/__init__.py +22 -0
  7. mslk/attention/flash_attn/ampere_helpers.py +104 -0
  8. mslk/attention/flash_attn/barrier.py +72 -0
  9. mslk/attention/flash_attn/benchmark.py +269 -0
  10. mslk/attention/flash_attn/blackwell_helpers.py +754 -0
  11. mslk/attention/flash_attn/block_info.py +109 -0
  12. mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
  13. mslk/attention/flash_attn/block_sparsity.py +219 -0
  14. mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
  15. mslk/attention/flash_attn/copy_utils.py +341 -0
  16. mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
  17. mslk/attention/flash_attn/fast_math.py +22 -0
  18. mslk/attention/flash_attn/flash_bwd.py +1262 -0
  19. mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
  20. mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
  21. mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
  22. mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
  23. mslk/attention/flash_attn/flash_fwd.py +2471 -0
  24. mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
  25. mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
  26. mslk/attention/flash_attn/hopper_helpers.py +102 -0
  27. mslk/attention/flash_attn/interface.py +1771 -0
  28. mslk/attention/flash_attn/mask.py +610 -0
  29. mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
  30. mslk/attention/flash_attn/named_barrier.py +32 -0
  31. mslk/attention/flash_attn/pack_gqa.py +165 -0
  32. mslk/attention/flash_attn/paged_kv.py +176 -0
  33. mslk/attention/flash_attn/pipeline.py +273 -0
  34. mslk/attention/flash_attn/seqlen_info.py +139 -0
  35. mslk/attention/flash_attn/softmax.py +583 -0
  36. mslk/attention/flash_attn/testing.py +424 -0
  37. mslk/attention/flash_attn/tile_scheduler.py +720 -0
  38. mslk/attention/flash_attn/utils.py +860 -0
  39. mslk/attention/fmha/__init__.py +967 -0
  40. mslk/attention/fmha/_triton/__init__.py +6 -0
  41. mslk/attention/fmha/_triton/available.py +50 -0
  42. mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
  43. mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
  44. mslk/attention/fmha/attn_bias.py +2186 -0
  45. mslk/attention/fmha/attn_bias_utils.py +536 -0
  46. mslk/attention/fmha/ck.py +508 -0
  47. mslk/attention/fmha/ck_decoder.py +141 -0
  48. mslk/attention/fmha/ck_splitk.py +204 -0
  49. mslk/attention/fmha/common.py +598 -0
  50. mslk/attention/fmha/cutlass.py +461 -0
  51. mslk/attention/fmha/cutlass_blackwell.py +560 -0
  52. mslk/attention/fmha/dispatch.py +224 -0
  53. mslk/attention/fmha/flash.py +862 -0
  54. mslk/attention/fmha/flash3.py +858 -0
  55. mslk/attention/fmha/flash_mtia.py +245 -0
  56. mslk/attention/fmha/merge_training.py +192 -0
  57. mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
  58. mslk/attention/fmha/torch_attention_compat.py +154 -0
  59. mslk/attention/fmha/tree_attention.py +718 -0
  60. mslk/attention/fmha/triton_splitk.py +1378 -0
  61. mslk/attention/fmha/unbind.py +130 -0
  62. mslk/attention/fmha/utils/__init__.py +6 -0
  63. mslk/attention/fmha/utils/bench.py +74 -0
  64. mslk/attention/fmha/utils/cpp_lib.py +148 -0
  65. mslk/attention/fmha/utils/op_common.py +65 -0
  66. mslk/attention/gqa_attn_splitk/__init__.py +11 -0
  67. mslk/bench/comm/__init__.py +7 -0
  68. mslk/bench/comm/comm_bench.py +255 -0
  69. mslk/bench/common/__init__.py +5 -0
  70. mslk/bench/common/utils.py +148 -0
  71. mslk/bench/conv/__init__.py +7 -0
  72. mslk/bench/conv/conv_bench.py +551 -0
  73. mslk/bench/conv/conv_ops.py +213 -0
  74. mslk/bench/gemm/__init__.py +7 -0
  75. mslk/bench/gemm/gemm_bench.py +859 -0
  76. mslk/bench/gemm/gemm_ops.py +3342 -0
  77. mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
  78. mslk/bench/moe/__init__.py +7 -0
  79. mslk/bench/moe/gather_scatter_bench.py +356 -0
  80. mslk/bench/quantize/quantize_bench.py +345 -0
  81. mslk/bench/quantize/quantize_ops.py +266 -0
  82. mslk/comm/__init__.py +11 -0
  83. mslk/conv/__init__.py +11 -0
  84. mslk/gemm/__init__.py +18 -0
  85. mslk/gemm/triton/__init__.py +7 -0
  86. mslk/gemm/triton/fp8_gemm.py +2702 -0
  87. mslk/gemm/triton/grouped_gemm.py +1132 -0
  88. mslk/gemm/triton/matmul_perf_model.py +237 -0
  89. mslk/gemm/triton/utils.py +128 -0
  90. mslk/kv_cache/__init__.py +11 -0
  91. mslk/moe/__init__.py +26 -0
  92. mslk/moe/activation.py +291 -0
  93. mslk/moe/gather_scatter.py +739 -0
  94. mslk/moe/layers.py +1240 -0
  95. mslk/moe/shuffling.py +421 -0
  96. mslk/mslk.so +0 -0
  97. mslk/quantize/__init__.py +11 -0
  98. mslk/quantize/shuffle.py +306 -0
  99. mslk/quantize/triton/__init__.py +7 -0
  100. mslk/quantize/triton/fp4_quantize.py +5942 -0
  101. mslk/quantize/triton/fp8_quantize.py +1902 -0
  102. mslk/testing/__init__.py +7 -0
  103. mslk/testing/attributes.py +60 -0
  104. mslk/testing/rocm.py +91 -0
  105. mslk/utils/__init__.py +7 -0
  106. mslk/utils/torch/__init__.py +7 -0
  107. mslk/utils/torch/library.py +150 -0
  108. mslk/utils/triton/__init__.py +7 -0
  109. mslk/utils/triton/fp8_utils.py +72 -0
  110. mslk/utils/triton/utils.py +128 -0
  111. mslk/version.py +11 -0
  112. mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
  113. mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
  114. mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
  115. mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
  116. mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
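Since a .whl is a plain zip archive, the listing above can be cross-checked against the wheel itself using nothing but the Python standard library. The sketch below is illustrative only; the local filename is reconstructed from the header and dist-info entries above and may need adjusting to match an actual download.

import zipfile

# Hypothetical local path; adjust to wherever the nightly wheel was downloaded.
wheel_path = "mslk_cuda_nightly-2026.1.19-cp310-cp310-manylinux_2_28_x86_64.whl"

with zipfile.ZipFile(wheel_path) as wheel:
    names = wheel.namelist()
    print(f"{len(names)} files in the wheel")  # compare with the 116 entries above
    for name in sorted(names):
        print(name)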
mslk/bench/gemm/gemm_bench.py
@@ -0,0 +1,859 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import itertools
+ import os
+ import sys
+ from dataclasses import dataclass
+ from datetime import datetime
+ from enum import Enum
+ from typing import Any, Optional
+
+ import click
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import pandas as pd
+ import seaborn as sns
+ import torch
+ import triton # @manual=//triton:triton
+ from mslk.bench.common.utils import BenchOptions, profiler
+ from mslk.bench.gemm.gemm_ops import ComputeDtype, GemmOpBase, GemmType, get_gemm_ops
+ from tabulate import tabulate
+
+
+ # Compute theoretical roofline values in TFLOPS for GPU and dtype combinations.
+ COMPUTE_ROOFLINE_TFLOPS: dict[str, dict[ComputeDtype, float]] = {
+     "NVIDIA H100": {
+         ComputeDtype.FP8: 1979.0,
+         ComputeDtype.BF16: 989.0,
+         ComputeDtype.TF32: 494.5,
+         ComputeDtype.FP32: 67.0, # non-tensorcore
+     },
+     "NVIDIA B200": {
+         ComputeDtype.FP4: 9000.0,
+         ComputeDtype.FP8: 4500.0,
+         ComputeDtype.BF16: 2250.0,
+         ComputeDtype.TF32: 1100.0,
+         ComputeDtype.FP32: 75.0, # non-tensorcore
+     },
+     "NVIDIA GB200": {
+         ComputeDtype.FP4: 10000.0,
+         ComputeDtype.FP8: 5000.0,
+         ComputeDtype.BF16: 2500.0,
+         ComputeDtype.TF32: 1250.0,
+         ComputeDtype.FP32: 80.0, # non-tensorcore
+     },
+ }
+
+
+ def get_compute_roofline_tflops(compute_dtype: ComputeDtype) -> float | None:
+     gpu_rooflines = COMPUTE_ROOFLINE_TFLOPS.get(torch.cuda.get_device_name())
+     if gpu_rooflines is None:
+         return None
+     return gpu_rooflines.get(compute_dtype)
+
+
+ shape_registry = {}
+
+
+ def register_shapes(name):
+     def decorator(op):
+         shape_registry[name] = op
+         return op
+
+     return decorator
+
+
+ def generate_group_tensor(G, M):
+     """
+     Generate a tensor with G elements whose integer elements sum to M.
+
+     Args:
+         G (int): Number of elements in the tensor.
+         M (int): Sum of the elements in the tensor.
+
+     Returns:
+         list[int]: A list of G integers that sum to M.
+     """
+
+     # First, we generate a random tensor with G elements
+     random_tensor = torch.rand(G)
+     # Then, we normalize this tensor so it sums up to 1
+     normalized_tensor = random_tensor / random_tensor.sum()
+     # Finally, we multiply this tensor by M and round to the nearest integer
+     output_tensor = torch.round(normalized_tensor * M).to(torch.int64)
+     # Adjust the last element to ensure the sum is exactly M
+     output_tensor[-1] += max(0, M - output_tensor.sum())
+     return output_tensor.tolist()
+
+
+ def set_amd_env_vars() -> None:
+     print("Setting environment variables for AMD GPU performance")
+     os.environ["DISABLE_ADDMM_HIP_LT"] = "0"
+     os.environ["HIP_FORCE_DEV_KERNARG"] = "1"
+     os.environ["PYTORCH_TUNABLEOP_VERBOSE"] = "0"
+     os.environ["PYTORCH_TUNABLEOP_ENABLED"] = "1"
+     os.environ["PYTORCH_TUNABLEOP_TUNING"] = "1"
+     os.environ["PYTORCH_TUNABLEOP_FILENAME"] = "hipblas_tuning_pt_llama.csv"
+     os.environ["PYTORCH_TUNABLEOP_MAX_TUNING_DURATION_MS"] = "30"
+     os.environ["PYTORCH_TUNABLEOP_MAX_WARMUP_DURATION_MS"] = "30"
+
+
+ @register_shapes("llama3_70b")
+ def llama3_70b_shapes() -> list[tuple[int, int, int]]:
+     shapes = []
+     for M in [1, 16, 32, 64, 96, 128]:
+         shapes += [
+             (M, 1280, 8192),
+             (M, 8192, 1024),
+             (M, 7168, 8192),
+             (M, 8192, 3584),
+         ]
+     return shapes
+
+
+ @register_shapes("autotune")
+ def autotune() -> list[tuple[int, int, int]]:
+     shapes = []
+     for M in [
+         1,
+         64,
+         128,
+         256,
+         512,
+         1024,
+         2048,
+         4096,
+         8192,
+         16384,
+     ]:
+         for N in range(1024, 16384 + 1, 1024):
+             for K in range(1024, 16384 + 1, 1024):
+                 shapes.append((M, N, K))
+     return shapes
+
+
+ @register_shapes("llama3_405b")
+ def llama3_405b_shapes() -> list[tuple[int, int, int]]:
+     shapes = []
+     for M in [1, 16, 32, 64, 96, 128]:
+         shapes += [
+             (M, 13312, 6656),
+             (M, 13312, 16384),
+             (M, 16384, 6656),
+             (M, 16384, 16384),
+         ]
+     return shapes
+
+
+ @register_shapes("llama4")
+ def llama4_shapes() -> list[tuple[int, int, int]]:
+     shapes = []
+     for M in [1, 16, 32, 64, 96, 128]:
+         shapes += [
+             (M, 896, 5120),
+             (M, 5120, 640),
+             (M, 2048, 5120),
+             (M, 5120, 1024),
+         ]
+     return shapes
+
+
+ @register_shapes("ldm")
+ def ldm_shapes() -> list[tuple[int, int, int]]:
+     return [
+         (1536, 3584, 3584),
+         (8192, 9728, 3584),
+         (8192, 3584, 9728),
+         (8192, 3584, 3584),
+         (4096, 3584, 3584),
+         (768, 3584, 3584),
+         (4096, 9728, 3584),
+         (4096, 3584, 9728),
+         (7200, 3584, 3584),
+         (7200, 9728, 3584),
+         (7200, 3584, 9728),
+         (3600, 3584, 3584),
+         (3600, 9728, 3584),
+         (3600, 3584, 9728),
+         (1536, 4096, 4096),
+         (3600, 4096, 4096),
+         (3600, 11008, 4096),
+         (3600, 4096, 11008),
+         (4096, 4096, 4096),
+         (4096, 11008, 4096),
+         (4096, 4096, 11008),
+         (32768, 128, 8192),
+         (32768, 8192, 1024),
+         (32768, 8192, 3072),
+         (32768, 3072, 8192),
+         (32768, 1024, 8192),
+     ]
+
+
+ class ShapeMode(Enum):
+     REGULAR = "regular" # (M, N, K)
+     GROUPED = "grouped" # G, (M, N, K)
+     GROUPED_TOTAL_M = "grouped_total_m" # G, (TotalM, N, K)
+     GROUPED_TOTAL_K = "grouped_total_k" # G, (M, N, TotalK)
+
+
+ @dataclass
+ class Metrics:
+     op: str
+     M: Any = 0
+     N: Any = 0
+     K: Any = 0
+     groups: Optional[int] = None
+     shape_mode: ShapeMode = ShapeMode.REGULAR
+
+     sim: float = 0.0
+     ms: float = 0.0
+     tflops: float = 0.0
+     gbps: float = 0.0
+     mem_bw_util: float = 0.0
+     compute_util: float = 0.0
+
+     @staticmethod
+     def header(shape_mode: ShapeMode = ShapeMode.REGULAR) -> str:
+         is_grouped = shape_mode in (
+             ShapeMode.GROUPED,
+             ShapeMode.GROUPED_TOTAL_M,
+             ShapeMode.GROUPED_TOTAL_K,
+         )
+         if shape_mode == ShapeMode.GROUPED_TOTAL_M:
+             shape_col = "(TotalM, N, K)"
+         elif shape_mode == ShapeMode.GROUPED_TOTAL_K:
+             shape_col = "(M, N, TotalK)"
+         else:
+             shape_col = "(M, N, K)"
+
+         group_col = f"{'G':<6}" if is_grouped else ""
+         header = (
+             f"{'OpName':<30} {group_col} {shape_col:<25} "
+             f"{'Sim':<10} {'Ms':<10} {'TFLOPS':<10} "
+             f"{'GB/s':<10} {'Mem BW Util %':<14} {'Compute Util %':<10}"
+         )
+         divider = "-" * len(header)
+         return f"GEMM Bench\n{divider}\n{header}\n{divider}"
+
+     def __str__(self) -> str:
+         is_grouped = self.shape_mode in (
+             ShapeMode.GROUPED,
+             ShapeMode.GROUPED_TOTAL_M,
+             ShapeMode.GROUPED_TOTAL_K,
+         )
+         if self.shape_mode == ShapeMode.GROUPED_TOTAL_M:
+             total_m = sum(self.M) if isinstance(self.M, list) else self.M
+             shape = f"({total_m}, {self.N}, {self.K})"
+         elif self.shape_mode == ShapeMode.GROUPED_TOTAL_K:
+             total_k = sum(self.K) if isinstance(self.K, list) else self.K
+             shape = f"({self.M}, {self.N}, {total_k})"
+         else:
+             shape = f"({self.M}, {self.N}, {self.K})"
+
+         group_col = f"{self.groups:<6}" if is_grouped else ""
+         compute_util_str = (
+             f"{self.compute_util:<10.2f}" if self.compute_util > 0 else "N/A"
+         )
+         return (
+             f"{self.op:<30} {group_col} {shape:<25} "
+             f"{self.sim:<10.3f} {self.ms:<10.3f} "
+             f"{self.tflops:<10.2f} {self.gbps:<10.2f} "
+             f"{self.mem_bw_util:<14.2f} {compute_util_str}"
+         )
+
+     def as_dict(self) -> dict[str, Any]:
+         result: dict[str, Any] = {
+             "M": self.M,
+             "N": self.N,
+             "K": self.K,
+             f"{self.op}_sim": self.sim,
+             f"{self.op}_ms": self.ms,
+             f"{self.op}_tflops": self.tflops,
+             f"{self.op}_gb/s": self.gbps,
+             f"{self.op}_mem_bw_util": self.mem_bw_util,
+             f"{self.op}_compute_util": self.compute_util,
+         }
+         if self.groups is not None:
+             result["groups"] = self.groups
+         return result
+
+
+ def benchmark_grouped(
+     gemm_ops: list[GemmOpBase],
+     m: list[int],
+     n: list[int],
+     k: list[int],
+     mem_bw_roofline_gbps: float,
+     opts: BenchOptions,
+     bench_quantize: bool = False,
+     shape_mode: ShapeMode = ShapeMode.GROUPED,
+ ) -> list[Metrics]:
+     num_groups = len(m)
+     # Create input tensors.
+     A = []
+     B = []
+     for i in range(num_groups):
+         A.append(torch.randn(m[i], k[i], device="cuda", dtype=torch.bfloat16))
+         B.append(torch.randn(n[i], k[i], device="cuda", dtype=torch.bfloat16))
+     # Compute baseline output for correctness checking.
+     out_ref = []
+     for i in range(num_groups):
+         out_ref.append(torch.matmul(A[i], B[i].t()))
+     # Keep track of results.
+     # Only log all shapes in a group if they are unique.
+     log_m = m[0] if len(np.unique(m)) == 1 else m
+     log_n = n[0] if len(np.unique(n)) == 1 else n
+     log_k = k[0] if len(np.unique(k)) == 1 else k
+     results: list[Metrics] = []
+     # Benchmark each operator.
+     for gemm_op in gemm_ops:
+         # Build progress message based on shape mode.
+         if shape_mode == ShapeMode.GROUPED_TOTAL_M:
+             total_m = sum(m)
+             shape_str = f"(G={num_groups}, TotalM={total_m}, N={log_n}, K={log_k})"
+         elif shape_mode == ShapeMode.GROUPED_TOTAL_K:
+             total_k = sum(k)
+             shape_str = f"(G={num_groups}, M={log_m}, N={log_n}, TotalK={total_k})"
+         else:
+             shape_str = f"(G={num_groups}, M={log_m}, N={log_n}, K={log_k})"
+         print(f"Benchmarking {gemm_op.name} with {shape_str}")
+         metrics = Metrics(
+             op=gemm_op.name,
+             M=log_m,
+             N=log_n,
+             K=log_k,
+             groups=num_groups,
+             shape_mode=shape_mode,
+         )
+         # Set fast accum mode if applicable.
+         if hasattr(gemm_op, "fast_accum"):
+             gemm_op.fast_accum = opts.fast_accum
+         if hasattr(gemm_op, "torch_compile"):
+             gemm_op.torch_compile = opts.torch_compile
+
+         # Get compute roofline for this op's compute dtype.
+         compute_roofline_tflops = get_compute_roofline_tflops(gemm_op.compute_dtype)
+
+         try:
+             # Get the quantized tensors for this operator.
+             preprocessed_args = gemm_op.preprocess(A, B)
+             quantized_vals = gemm_op.quantize(*preprocessed_args)
+             # Compute the output given quantized values.
+             output = gemm_op.compute(*quantized_vals)
+         except Exception as e:
+             print(f"GEMM op {gemm_op.name} failed to run due to error: {e}.")
+             continue
+         # Some kernels may pad output, just take the first m values of each row.
+         if isinstance(output, torch.Tensor) and output.ndim == 2:
+             # Output is stacked and needs to be split.
+             output = torch.split(output, m, dim=0)
+         else:
+             # Otherwise output may be padded or require unbinding.
+             output = [o[: m[i]] for i, o in enumerate(output)]
+         # Compare the quantize op output to reference as a sanity check.
+         for i in range(num_groups):
+             if m[i] > 0:
+                 metrics.sim += float(
+                     torch.mean(torch.pow(output[i] - out_ref[i], 2)).item()
+                 )
+         for _ in range(opts.num_iters):
+             # Now perform benchmark.
+             if bench_quantize:
+                 # Benchmark both quantize and compute.
+                 with profiler(enabled=opts.trace, with_stack=True):
+                     ms_runtime = gemm_op.benchmark(
+                         *preprocessed_args,
+                         opts=opts,
+                         bench_quantize=True,
+                     )
+             else:
+                 with profiler(enabled=opts.trace, with_stack=True):
+                     ms_runtime = gemm_op.benchmark(
+                         *quantized_vals,
+                         opts=opts,
+                         bench_quantize=False,
+                     )
+
+             for i in range(num_groups):
+                 output_multiplier = 2 if "fuse_scatter_add" in gemm_op.name else 1
+                 if m[i] > 0:
+                     tflops = 2 * m[i] * n[i] * k[i] / (ms_runtime / 1e3) / 1e12
+                     gbps = (
+                         (
+                             m[i] * k[i] * quantized_vals[0][0].element_size()
+                             + n[i] * k[i] * quantized_vals[1][0].element_size()
+                             + output_multiplier * m[i] * n[i] * output[0].element_size()
+                         )
+                         / (ms_runtime / 1e3)
+                         / 1e9
+                     )
+                     metrics.gbps += gbps
+                     metrics.tflops += tflops
+                     metrics.mem_bw_util += (gbps / mem_bw_roofline_gbps) * 100
+                     if compute_roofline_tflops is not None:
+                         metrics.compute_util += (tflops / compute_roofline_tflops) * 100
+             metrics.ms += ms_runtime
+         metrics.ms /= opts.num_iters
+         metrics.tflops /= opts.num_iters
+         metrics.gbps /= opts.num_iters
+         metrics.mem_bw_util /= opts.num_iters
+         metrics.compute_util /= opts.num_iters
+
+         results.append(metrics)
+
+     return results
+
+
+ def benchmark(
+     gemm_ops: list[GemmOpBase],
+     m: int,
+     n: int,
+     k: int,
+     mem_bw_roofline_gbps: float,
+     opts: BenchOptions,
+     bench_quantize: bool = False,
+     shape_mode: ShapeMode = ShapeMode.REGULAR,
+ ) -> list[Metrics]:
+     # Create input tensors.
+     A = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
+     B = torch.randn(n, k, device="cuda", dtype=torch.bfloat16)
+
+     # Compute baseline output for correctness checking.
+     out_ref = torch.matmul(A, torch.transpose(B, -2, -1))
+     # Keep track of results.
+     results: list[Metrics] = []
+     # Benchmark each operator.
+     for gemm_op in gemm_ops:
+         shape_str = f"(M={m}, N={n}, K={k})"
+         print(f"Benchmarking {gemm_op.name} with {shape_str}")
+         metrics = Metrics(op=gemm_op.name, M=m, N=n, K=k, shape_mode=shape_mode)
+         # Set fast accum mode if applicable.
+         if hasattr(gemm_op, "fast_accum"):
+             gemm_op.fast_accum = opts.fast_accum
+         if hasattr(gemm_op, "torch_compile"):
+             gemm_op.torch_compile = opts.torch_compile
+
+         # Get compute roofline for this op's compute dtype.
+         compute_roofline_tflops = get_compute_roofline_tflops(gemm_op.compute_dtype)
+
+         try:
+             # Preprocess data if needed.
+             preprocessed_args = gemm_op.preprocess(A, B)
+             # Get the quantized tensors for this operator.
+             quantized_vals = gemm_op.quantize(*preprocessed_args)
+             # Compute the output given quantized values.
+             output = gemm_op.compute(*quantized_vals)
+         except Exception as e:
+             print(f"GEMM op {gemm_op.name} failed to run due to error: {e}.")
+             continue
+         # Compare the quantize op output to reference as a sanity check.
+         # TODO(shikaili): This calculation is incorrect for scatter add fusion.
+         metrics.sim = torch.mean(torch.pow(output - out_ref, 2)).item()
+
+         for _ in range(opts.num_iters):
+             # Now perform benchmark.
+             if bench_quantize:
+                 # Benchmark both quantize and compute.
+                 with profiler(enabled=opts.trace, with_stack=True):
+                     ms_runtime = gemm_op.benchmark(
+                         *preprocessed_args,
+                         opts=opts,
+                         bench_quantize=True,
+                     )
+             else:
+                 with profiler(enabled=opts.trace, with_stack=True):
+                     ms_runtime = gemm_op.benchmark(
+                         *quantized_vals,
+                         opts=opts,
+                         bench_quantize=False,
+                     )
+
+             tflops = 2 * m * n * k / (ms_runtime / 1e3) / 1e12
+             metrics.tflops += tflops
+             gbps = (
+                 (
+                     quantized_vals[0].numel() * quantized_vals[0].element_size()
+                     + quantized_vals[1].numel() * quantized_vals[1].element_size()
+                     + output.numel() * output.element_size()
+                 )
+                 / (ms_runtime / 1e3)
+                 / 1e9
+             )
+             metrics.gbps += gbps
+             metrics.mem_bw_util += (gbps / mem_bw_roofline_gbps) * 100
+             if compute_roofline_tflops is not None:
+                 metrics.compute_util += (tflops / compute_roofline_tflops) * 100
+             metrics.ms += ms_runtime
+         metrics.ms /= opts.num_iters
+         metrics.tflops /= opts.num_iters
+         metrics.gbps /= opts.num_iters
+         metrics.mem_bw_util /= opts.num_iters
+         metrics.compute_util /= opts.num_iters
+
+         results.append(metrics)
+
+     return results
+
+
+ def plot_benchmark(results: list[Metrics], output_dir: str) -> None:
+     """Create a barplot visualizing the TFLOPS of each kernel."""
+     # Reprocess into new dataframe with proper graph format.
+     data = []
+     # Extract measurements for each shape.
+     for metric in results:
+         mnk = f"{metric.M}, {metric.N}, {metric.K}"
+         data.append({"MNK": mnk, "kernel": metric.op, "TFLOPS": metric.tflops})
+
+     # Create a barplot using seaborn.
+     df = pd.DataFrame(data)
+     plot = plt.figure()
+     plt.xticks(rotation=30)
+     plt.yscale("log")
+     ax = sns.barplot(x="MNK", y="TFLOPS", hue="kernel", data=df)
+     ax.tick_params(axis="x", labelsize=3)
+     img_fn = os.path.join(output_dir, "gemm_ops_benchmark.png")
+     plot.savefig(img_fn, dpi=300)
+     print(f"Plot saved to {img_fn}")
+
+
+ def collect_kernels_to_profile(
+     kernels: Optional[list[str]], is_grouped: bool
+ ) -> list[GemmOpBase]:
+     gemm_type = GemmType.GROUPED if is_grouped else GemmType.REGULAR
+     gemm_ops = [
+         op
+         for op in get_gemm_ops()
+         if op.supported and gemm_type in op.supported_gemm_types
+     ]
+     if kernels is None:
+         return gemm_ops
+     return [op for op in gemm_ops if op.name in kernels]
+
+
+ def print_kernels(kernels: Optional[list[str]]) -> None:
+     data = sorted(
+         (
+             op.name,
+             ",".join(accelerator.name for accelerator in op.supported_accelerators),
+         )
+         for op in get_gemm_ops()
+     )
+     print(tabulate(data, headers=["Name", "Accelerators"], tablefmt="orgtbl"))
+
+
+ @click.command()
+ @click.option(
+     "--output-dir",
+     default="/tmp",
+     help="Directory to save plots and csvs to",
+ )
+ @click.option(
+     "--num-iters",
+     default=1,
+     type=int,
+     help="Number of iterations to repeat each benchmark.",
+ )
+ @click.option(
+     "--export-csv",
+     is_flag=True,
+     help="Export results to a CSV file.",
+ )
+ @click.option(
+     "--plot",
+     is_flag=True,
+     help="Create a plot of the benchmark measurements.",
+ )
+ @click.option(
+     "--enable-amd-env-vars",
+     is_flag=True,
+     help="Enable a set of environment variables for AMD GPU performance",
+ )
+ @click.option(
+     "--bench-quantize",
+     is_flag=True,
+     help="If set, include quantization cost in benchmark.",
+ )
+ @click.option(
+     "--kernels",
+     default=None,
+     help="Comma separated list of kernels to benchmark. Defaults to all kernels.",
+ )
+ @click.option(
+     "--M",
+     default=None,
+     help="Comma separated list of M values to benchmark.",
+ )
+ @click.option(
+     "--N",
+     default=None,
+     help="Comma separated list of N values to benchmark.",
+ )
+ @click.option(
+     "--K",
+     default=None,
+     help="Comma separated list of K values to benchmark.",
+ )
+ @click.option(
+     "--pair-NK",
+     is_flag=True,
+     help="If set, instead of benchmarking cartesian product of N * K, benchmark consecutive NK pairs together.",
+ )
+ @click.option(
+     "--grouped",
+     is_flag=True,
+     help="If set, do grouped gemm. In this mode, M, N, and K are interpreted "
+     "as the size of groups. The length of each must be the same.",
+ )
+ @click.option(
+     "--groups",
+     default=None,
+     help="If set with grouped mode, repeat MNK shapes this many times. Comma separated list of groups to benchmark",
+ )
+ @click.option(
+     "--total-K",
+     default=None,
+     help="If set, adjusts the K values to sum to this number. "
+     "This can help simulate real grouped workloads in backward wgrad. "
+     "Comma separated list of total-K values to benchmark.",
+ )
+ @click.option(
+     "--total-M",
+     default=None,
+     help="If set, adjusts the M values to sum to this number. "
+     "This can help simulate real grouped workloads. "
+     "Comma separated list of total-M values to benchmark.",
+ )
+ @click.option(
+     "--no-cuda-graph",
+     is_flag=True,
+     help="If set, do not use cuda graph for benchmarking.",
+ )
+ @click.option(
+     "--use-rotating-buffer-bench",
+     is_flag=True,
+     help="If set, use rotating buffer to benchmark.",
+ )
+ @click.option(
+     "--shapes",
+     default=None,
+     help=f"Specific model shapes to use, options: {', '.join(shape_registry.keys())}.",
+ )
+ @click.option(
+     "--trace",
+     is_flag=True,
+     help="If set, produce a performance trace of the benchmark.",
+ )
+ @click.option(
+     "--disable-fast-accum",
+     is_flag=True,
+     help="If set, disable fast accumulation for FP8 implementations.",
+ )
+ @click.option(
+     "--torch-compile",
+     is_flag=True,
+     help="If set, torch.compile will be used for scaled_mm backed ops.",
+ )
+ @click.option(
+     "--rep",
+     default=200,
+     type=int,
+     help="Repetition time in ms (int) for triton.testing.do_bench",
+ )
+ def invoke_main(
+     output_dir: str,
+     num_iters: int,
+     export_csv: bool,
+     plot: bool,
+     enable_amd_env_vars: bool,
+     bench_quantize: bool,
+     kernels: Optional[str],
+     m: Optional[str],
+     n: Optional[str],
+     k: Optional[str],
+     pair_nk: bool,
+     grouped: bool,
+     groups: Optional[str],
+     total_k: Optional[str],
+     total_m: Optional[str],
+     no_cuda_graph: bool,
+     use_rotating_buffer_bench: bool,
+     shapes: Optional[str],
+     trace: bool,
+     disable_fast_accum: bool,
+     torch_compile: bool,
+     rep: int,
+ ):
+     if enable_amd_env_vars:
+         set_amd_env_vars()
+
+     # Validate that total_m and total_k are mutually exclusive
+     if total_m is not None and total_k is not None:
+         raise ValueError(
+             "total_m and total_k cannot be specified at the same time. "
+             "Please provide only one of them."
+         )
+
+     if groups:
+         grouped = True
+
+     # If kernel filter is provided, parse it. Else, benchmark all kernels.
+     all_kernels = kernels.strip().split(",") if kernels else None
+     gemm_ops = collect_kernels_to_profile(all_kernels, grouped)
+
+     if len(gemm_ops) == 0:
+         print("No valid kernels to benchmark. Available kernels:")
+         print_kernels(all_kernels)
+         sys.exit(1)
+
+     if num_iters < 1:
+         print("Warning: Number of iterations must be at least 1.")
+         num_iters = 1
+
+     # Enumerate shapes to benchmark.
+     if grouped and not groups:
+         # In grouped mode, M, N, and K represent the groups of a single gemm.
+         assert m is not None and n is not None and k is not None
+         M = [int(m_val) for m_val in m.strip().split(",")]
+         N = [int(n_val) for n_val in n.strip().split(",")]
+         K = [int(k_val) for k_val in k.strip().split(",")]
+         assert len(M) == len(N) == len(K), (
+             "M, N, and K must be the same length in grouped mode."
+         )
+
+         # Note this is a single grouped gemm.
+         MNK = [[M, N, K]]
+     else:
+         if shapes:
+             if shapes not in shape_registry:
+                 print(
+                     f"Shape {shapes} not found in shape registry. Valid shapes: {', '.join(shape_registry.keys())}."
+                 )
+                 sys.exit(1)
+             MNK = shape_registry[shapes]()
+         else:
+             if m is None:
+                 M = [1, 4, 8, 16, 32, 64, 128, 2048, 4096, 8192, 16384]
+             else:
+                 M = [int(m_val) for m_val in m.strip().split(",")]
+             if n is None:
+                 N = [1280, 2304, 7168, 8192, 16384]
+             else:
+                 N = [int(n_val) for n_val in n.strip().split(",")]
+             if k is None:
+                 K = [1024, 3584, 8192, 16384]
+             else:
+                 K = [int(k_val) for k_val in k.strip().split(",")]
+             # List all shapes for simplicity.
+             if pair_nk:
+                 if len(N) != len(K):
+                     raise Exception("N and K must be the same length in pair_NK mode.")
+                 NK = zip(N, K)
+                 MNK = [(M, N, K) for (M, (N, K)) in itertools.product(M, NK)]
+             else:
+                 MNK = list(itertools.product(M, N, K))
+     # When groups is provided transform shapes into grouped format.
+     if groups:
+         groups_list = [int(g) for g in groups.strip().split(",")]
+         if total_m:
+             total_m_list = [int(tm) for tm in total_m.strip().split(",")]
+             MNK = [
+                 [
+                     generate_group_tensor(g, tm),
+                     [n] * g,
+                     [k] * g,
+                 ]
+                 for g in groups_list
+                 for tm in total_m_list
+                 for _, n, k in MNK
+             ]
+             shape_mode = ShapeMode.GROUPED_TOTAL_M
+         elif total_k:
+             total_k_list = [int(tk) for tk in total_k.strip().split(",")]
+             MNK = [
+                 [
+                     [m] * g,
+                     [n] * g,
+                     generate_group_tensor(g, tk),
+                 ]
+                 for g in groups_list
+                 for tk in total_k_list
+                 for m, n, _ in MNK
+             ]
+             shape_mode = ShapeMode.GROUPED_TOTAL_K
+         else:
+             MNK = [[[m] * g, [n] * g, [k] * g] for g in groups_list for m, n, k in MNK]
+             shape_mode = ShapeMode.GROUPED
+     elif grouped:
+         shape_mode = ShapeMode.GROUPED
+     else:
+         shape_mode = ShapeMode.REGULAR
+
+     # Iterate over shapes and benchmark.
+     mem_bw_gbps = triton.testing.get_dram_gbps()
+     benchmark_results: list[Metrics] = []
+     csv: list[dict[str, Any]] = []
+     benchmark_func = benchmark_grouped if grouped else benchmark
+
+     opts = BenchOptions(
+         num_iters=num_iters,
+         cuda_graph=not no_cuda_graph,
+         rotating_buffer=use_rotating_buffer_bench,
+         rep_ms=rep,
+         trace=trace,
+         fast_accum=not disable_fast_accum,
+         torch_compile=torch_compile,
+     )
+
+     for m, n, k in MNK:
+         shape_measurements = benchmark_func(
+             gemm_ops,
+             m, # pyre-ignore[6]: Incompatible parameter type [6]
+             n, # pyre-ignore[6]: Incompatible parameter type [6]
+             k, # pyre-ignore[6]: Incompatible parameter type [6]
+             mem_bw_gbps,
+             opts,
+             bench_quantize,
+             shape_mode,
+         )
+         benchmark_results.extend(shape_measurements)
+         csv_row: dict[str, Any] = {}
+         for metric in shape_measurements:
+             csv_row.update(metric.as_dict())
+         csv.append(csv_row)
+
+     print("")
+     print(Metrics.header(shape_mode))
+     for metric in benchmark_results:
+         print(metric)
+
+     print("")
+     print(f"Hardware: {torch.cuda.get_device_name()}")
+     print(f" Memory BW: {mem_bw_gbps:.2f} GB/s")
+
+     print("")
+     print("Benchmark Settings:")
+     print(f" CUDA graph: {not no_cuda_graph}")
+     print(f" Buffer rotation: {use_rotating_buffer_bench}")
+     print(f" Fast accumulation: {not disable_fast_accum}")
+     print(f" Torch compile: {torch_compile}")
+
+     if export_csv or plot:
+         os.makedirs(output_dir, exist_ok=True)
+     if export_csv:
+         datetime_str = datetime.now().strftime("%Y%m%d_%H%M%S")
+         csv_file = os.path.join(output_dir, f"gemm_ops_benchmark_{datetime_str}.csv")
+         # Export results to a CSV file.
+         df = pd.DataFrame(csv)
+         df.to_csv(csv_file, na_rep="NaN", index=False)
+         print(f"CSV saved to {csv_file}")
+     if plot:
+         plot_benchmark(benchmark_results, output_dir)
+
+
+ if __name__ == "__main__":
+     invoke_main() # pragma: no cover
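
As a reading aid for the metrics this script prints: the TFLOPS and GB/s columns come straight from the formulas in benchmark() above (2*M*N*K floating point operations, plus the bytes moved for A, B, and the output), and utilization is reported against the COMPUTE_ROOFLINE_TFLOPS table and the DRAM bandwidth returned by triton.testing.get_dram_gbps(). The snippet below is a self-contained sketch with made-up numbers, assuming bf16 operands rather than whatever quantized dtypes a given op actually uses.

# Worked example of the reporting arithmetic in benchmark(); the shape and
# runtime below are hypothetical example values, not measurements.
m, n, k = 8192, 8192, 8192  # GEMM problem size
ms_runtime = 5.0  # assumed kernel time in milliseconds
elem_size = 2  # bytes per element, assuming bf16 operands and output

# 2 * M * N * K floating point operations, expressed as TFLOP/s.
tflops = 2 * m * n * k / (ms_runtime / 1e3) / 1e12

# Bytes moved for A (m x k), B (n x k), and the output (m x n), expressed as GB/s.
gbps = (m * k + n * k + m * n) * elem_size / (ms_runtime / 1e3) / 1e9

# Utilization against the H100 BF16 roofline from COMPUTE_ROOFLINE_TFLOPS.
compute_util = tflops / 989.0 * 100
print(f"{tflops:.1f} TFLOPS, {gbps:.1f} GB/s, {compute_util:.1f}% of BF16 roofline")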