mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mslk/__init__.py +56 -0
- mslk/attention/__init__.py +7 -0
- mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
- mslk/attention/flash_attn/__init__.py +22 -0
- mslk/attention/flash_attn/ampere_helpers.py +104 -0
- mslk/attention/flash_attn/barrier.py +72 -0
- mslk/attention/flash_attn/benchmark.py +269 -0
- mslk/attention/flash_attn/blackwell_helpers.py +754 -0
- mslk/attention/flash_attn/block_info.py +109 -0
- mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
- mslk/attention/flash_attn/block_sparsity.py +219 -0
- mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
- mslk/attention/flash_attn/copy_utils.py +341 -0
- mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
- mslk/attention/flash_attn/fast_math.py +22 -0
- mslk/attention/flash_attn/flash_bwd.py +1262 -0
- mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
- mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
- mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
- mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
- mslk/attention/flash_attn/flash_fwd.py +2471 -0
- mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
- mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
- mslk/attention/flash_attn/hopper_helpers.py +102 -0
- mslk/attention/flash_attn/interface.py +1771 -0
- mslk/attention/flash_attn/mask.py +610 -0
- mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
- mslk/attention/flash_attn/named_barrier.py +32 -0
- mslk/attention/flash_attn/pack_gqa.py +165 -0
- mslk/attention/flash_attn/paged_kv.py +176 -0
- mslk/attention/flash_attn/pipeline.py +273 -0
- mslk/attention/flash_attn/seqlen_info.py +139 -0
- mslk/attention/flash_attn/softmax.py +583 -0
- mslk/attention/flash_attn/testing.py +424 -0
- mslk/attention/flash_attn/tile_scheduler.py +720 -0
- mslk/attention/flash_attn/utils.py +860 -0
- mslk/attention/fmha/__init__.py +967 -0
- mslk/attention/fmha/_triton/__init__.py +6 -0
- mslk/attention/fmha/_triton/available.py +50 -0
- mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
- mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
- mslk/attention/fmha/attn_bias.py +2186 -0
- mslk/attention/fmha/attn_bias_utils.py +536 -0
- mslk/attention/fmha/ck.py +508 -0
- mslk/attention/fmha/ck_decoder.py +141 -0
- mslk/attention/fmha/ck_splitk.py +204 -0
- mslk/attention/fmha/common.py +598 -0
- mslk/attention/fmha/cutlass.py +461 -0
- mslk/attention/fmha/cutlass_blackwell.py +560 -0
- mslk/attention/fmha/dispatch.py +224 -0
- mslk/attention/fmha/flash.py +862 -0
- mslk/attention/fmha/flash3.py +858 -0
- mslk/attention/fmha/flash_mtia.py +245 -0
- mslk/attention/fmha/merge_training.py +192 -0
- mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
- mslk/attention/fmha/torch_attention_compat.py +154 -0
- mslk/attention/fmha/tree_attention.py +718 -0
- mslk/attention/fmha/triton_splitk.py +1378 -0
- mslk/attention/fmha/unbind.py +130 -0
- mslk/attention/fmha/utils/__init__.py +6 -0
- mslk/attention/fmha/utils/bench.py +74 -0
- mslk/attention/fmha/utils/cpp_lib.py +148 -0
- mslk/attention/fmha/utils/op_common.py +65 -0
- mslk/attention/gqa_attn_splitk/__init__.py +11 -0
- mslk/bench/comm/__init__.py +7 -0
- mslk/bench/comm/comm_bench.py +255 -0
- mslk/bench/common/__init__.py +5 -0
- mslk/bench/common/utils.py +148 -0
- mslk/bench/conv/__init__.py +7 -0
- mslk/bench/conv/conv_bench.py +551 -0
- mslk/bench/conv/conv_ops.py +213 -0
- mslk/bench/gemm/__init__.py +7 -0
- mslk/bench/gemm/gemm_bench.py +859 -0
- mslk/bench/gemm/gemm_ops.py +3342 -0
- mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
- mslk/bench/moe/__init__.py +7 -0
- mslk/bench/moe/gather_scatter_bench.py +356 -0
- mslk/bench/quantize/quantize_bench.py +345 -0
- mslk/bench/quantize/quantize_ops.py +266 -0
- mslk/comm/__init__.py +11 -0
- mslk/conv/__init__.py +11 -0
- mslk/gemm/__init__.py +18 -0
- mslk/gemm/triton/__init__.py +7 -0
- mslk/gemm/triton/fp8_gemm.py +2702 -0
- mslk/gemm/triton/grouped_gemm.py +1132 -0
- mslk/gemm/triton/matmul_perf_model.py +237 -0
- mslk/gemm/triton/utils.py +128 -0
- mslk/kv_cache/__init__.py +11 -0
- mslk/moe/__init__.py +26 -0
- mslk/moe/activation.py +291 -0
- mslk/moe/gather_scatter.py +739 -0
- mslk/moe/layers.py +1240 -0
- mslk/moe/shuffling.py +421 -0
- mslk/mslk.so +0 -0
- mslk/quantize/__init__.py +11 -0
- mslk/quantize/shuffle.py +306 -0
- mslk/quantize/triton/__init__.py +7 -0
- mslk/quantize/triton/fp4_quantize.py +5942 -0
- mslk/quantize/triton/fp8_quantize.py +1902 -0
- mslk/testing/__init__.py +7 -0
- mslk/testing/attributes.py +60 -0
- mslk/testing/rocm.py +91 -0
- mslk/utils/__init__.py +7 -0
- mslk/utils/torch/__init__.py +7 -0
- mslk/utils/torch/library.py +150 -0
- mslk/utils/triton/__init__.py +7 -0
- mslk/utils/triton/fp8_utils.py +72 -0
- mslk/utils/triton/utils.py +128 -0
- mslk/version.py +11 -0
- mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
- mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
- mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
- mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
- mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,859 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
import itertools
|
|
8
|
+
import os
|
|
9
|
+
import sys
|
|
10
|
+
from dataclasses import dataclass
|
|
11
|
+
from datetime import datetime
|
|
12
|
+
from enum import Enum
|
|
13
|
+
from typing import Any, Optional
|
|
14
|
+
|
|
15
|
+
import click
|
|
16
|
+
import matplotlib.pyplot as plt
|
|
17
|
+
import numpy as np
|
|
18
|
+
import pandas as pd
|
|
19
|
+
import seaborn as sns
|
|
20
|
+
import torch
|
|
21
|
+
import triton # @manual=//triton:triton
|
|
22
|
+
from mslk.bench.common.utils import BenchOptions, profiler
|
|
23
|
+
from mslk.bench.gemm.gemm_ops import ComputeDtype, GemmOpBase, GemmType, get_gemm_ops
|
|
24
|
+
from tabulate import tabulate
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
# Compute theoretical roofline values in TFLOPS for GPU and dtype combinations.
# Outer keys are GPU names looked up from torch.cuda.get_device_name() by
# get_compute_roofline_tflops(); inner dicts map a ComputeDtype to the peak
# throughput (TFLOPS) used as the denominator of "Compute Util %".
# NOTE(review): the figures appear to be vendor dense (non-sparse) datasheet
# peaks — confirm against the datasheets before relying on utilization numbers.
COMPUTE_ROOFLINE_TFLOPS: dict[str, dict[ComputeDtype, float]] = {
    "NVIDIA H100": {
        ComputeDtype.FP8: 1979.0,
        ComputeDtype.BF16: 989.0,
        ComputeDtype.TF32: 494.5,
        ComputeDtype.FP32: 67.0,  # non-tensorcore
    },
    "NVIDIA B200": {
        ComputeDtype.FP4: 9000.0,
        ComputeDtype.FP8: 4500.0,
        ComputeDtype.BF16: 2250.0,
        ComputeDtype.TF32: 1100.0,
        ComputeDtype.FP32: 75.0,  # non-tensorcore
    },
    "NVIDIA GB200": {
        ComputeDtype.FP4: 10000.0,
        ComputeDtype.FP8: 5000.0,
        ComputeDtype.BF16: 2500.0,
        ComputeDtype.TF32: 1250.0,
        ComputeDtype.FP32: 80.0,  # non-tensorcore
    },
}
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def get_compute_roofline_tflops(compute_dtype: ComputeDtype) -> float | None:
|
|
53
|
+
gpu_rooflines = COMPUTE_ROOFLINE_TFLOPS.get(torch.cuda.get_device_name())
|
|
54
|
+
if gpu_rooflines is None:
|
|
55
|
+
return None
|
|
56
|
+
return gpu_rooflines.get(compute_dtype)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
# Registry of named shape sets: name -> zero-argument function returning the
# list of shapes. Its keys are listed in the --shapes CLI help text.
shape_registry = {}


def register_shapes(name):
    """Build a decorator that stores the decorated function under ``name``.

    The function itself is handed back untouched, so it remains directly
    callable after registration. Registering the same name twice keeps the
    most recent function.
    """

    def _register(shape_fn):
        shape_registry.update({name: shape_fn})
        return shape_fn

    return _register
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def generate_group_tensor(G, M):
    """
    Randomly split the integer M into G non-negative integer parts.

    The split proportions come from ``torch.rand``, so the result depends on
    torch's global RNG state.

    Args:
        G (int): Number of parts; must be at least 1.
        M (int): Non-negative total that the parts must sum to.

    Returns:
        list[int]: G non-negative integers whose sum is exactly M.

    Raises:
        ValueError: If G is less than 1.
    """
    if G < 1:
        raise ValueError(f"G must be at least 1, got {G}.")

    # Random weights. Scale in float64 so the fractional parts used below stay
    # accurate even for large M.
    weights = torch.rand(G).to(torch.float64)
    total_weight = weights.sum()
    if total_weight <= 0:
        # torch.rand can (very rarely) return exact zeros; fall back to an even split.
        weights = torch.ones_like(weights)
        total_weight = weights.sum()
    scaled = weights / total_weight * M

    # Largest-remainder rounding. Floor every share, then give the leftover
    # units to the shares with the largest fractional parts. Rounding each
    # element independently can overshoot M; this cannot, and it keeps every
    # part non-negative.
    floors = torch.floor(scaled)
    parts = floors.to(torch.int64)
    leftover = M - int(parts.sum())
    # leftover is normally in [0, G). divmod keeps the total exact even if
    # float error pushes it to G or beyond.
    extra_each, extra_rest = divmod(leftover, G)
    parts += extra_each
    if extra_rest:
        largest_fraction = torch.argsort(scaled - floors, descending=True)[:extra_rest]
        parts[largest_fraction] += 1
    return parts.tolist()
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def set_amd_env_vars() -> None:
    """Export environment variables intended to tune PyTorch on AMD (HIP) GPUs.

    Enables PyTorch TunableOp with tuning on and writes its results to
    ``hipblas_tuning_pt_llama.csv``. Sets the TunableOp max tuning/warmup
    durations to 30 (ms), and sets the HIP kernel-argument and hipBLASLt addmm
    switches. Variables already present in the environment are overwritten.
    """
    print("Setting environment variables for AMD GPU performance")
    amd_settings = {
        "DISABLE_ADDMM_HIP_LT": "0",
        "HIP_FORCE_DEV_KERNARG": "1",
        "PYTORCH_TUNABLEOP_VERBOSE": "0",
        "PYTORCH_TUNABLEOP_ENABLED": "1",
        "PYTORCH_TUNABLEOP_TUNING": "1",
        "PYTORCH_TUNABLEOP_FILENAME": "hipblas_tuning_pt_llama.csv",
        "PYTORCH_TUNABLEOP_MAX_TUNING_DURATION_MS": "30",
        "PYTORCH_TUNABLEOP_MAX_WARMUP_DURATION_MS": "30",
    }
    for var_name, var_value in amd_settings.items():
        os.environ[var_name] = var_value
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
@register_shapes("llama3_70b")
def llama3_70b_shapes() -> list[tuple[int, int, int]]:
    """(M, N, K) shapes registered as "llama3_70b".

    Every listed M is paired with each of four fixed (N, K) sizes. Results are
    ordered by M first, then by the (N, K) pair order below.
    """
    nk_pairs = ((1280, 8192), (8192, 1024), (7168, 8192), (8192, 3584))
    return [(M, N, K) for M in (1, 16, 32, 64, 96, 128) for N, K in nk_pairs]
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
@register_shapes("autotune")
def autotune() -> list[tuple[int, int, int]]:
    """Dense (M, N, K) sweep registered as "autotune".

    M takes 1 and the powers of two from 64 to 16384. N and K each step from
    1024 to 16384 in increments of 1024. Order is M-major, then N, then K.
    """
    m_values = [1, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384]
    nk_values = range(1024, 16384 + 1, 1024)
    return list(itertools.product(m_values, nk_values, nk_values))
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
@register_shapes("llama3_405b")
def llama3_405b_shapes() -> list[tuple[int, int, int]]:
    """(M, N, K) shapes registered as "llama3_405b".

    Every listed M is paired with each of four fixed (N, K) sizes. Results are
    ordered by M first, then by the (N, K) pair order below.
    """
    nk_pairs = ((13312, 6656), (13312, 16384), (16384, 6656), (16384, 16384))
    return [(M, N, K) for M in (1, 16, 32, 64, 96, 128) for N, K in nk_pairs]
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
@register_shapes("llama4")
def llama4_shapes() -> list[tuple[int, int, int]]:
    """(M, N, K) shapes registered as "llama4".

    Every listed M is paired with each of four fixed (N, K) sizes. Results are
    ordered by M first, then by the (N, K) pair order below.
    """
    nk_pairs = ((896, 5120), (5120, 640), (2048, 5120), (5120, 1024))
    return [(M, N, K) for M in (1, 16, 32, 64, 96, 128) for N, K in nk_pairs]
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
@register_shapes("ldm")
def ldm_shapes() -> list[tuple[int, int, int]]:
    """Fixed list of (M, N, K) shapes registered as "ldm".

    A fresh list is returned on every call, so callers may mutate it.
    """
    shapes = (
        (1536, 3584, 3584),
        (8192, 9728, 3584),
        (8192, 3584, 9728),
        (8192, 3584, 3584),
        (4096, 3584, 3584),
        (768, 3584, 3584),
        (4096, 9728, 3584),
        (4096, 3584, 9728),
        (7200, 3584, 3584),
        (7200, 9728, 3584),
        (7200, 3584, 9728),
        (3600, 3584, 3584),
        (3600, 9728, 3584),
        (3600, 3584, 9728),
        (1536, 4096, 4096),
        (3600, 4096, 4096),
        (3600, 11008, 4096),
        (3600, 4096, 11008),
        (4096, 4096, 4096),
        (4096, 11008, 4096),
        (4096, 4096, 11008),
        (32768, 128, 8192),
        (32768, 8192, 1024),
        (32768, 8192, 3072),
        (32768, 3072, 8192),
        (32768, 1024, 8192),
    )
    return list(shapes)
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
class ShapeMode(Enum):
    """How the shape of a benchmark row is described and printed."""

    REGULAR = "regular"  # (M, N, K)
    GROUPED = "grouped"  # G, (M, N, K)
    GROUPED_TOTAL_M = "grouped_total_m"  # G, (TotalM, N, K)
    GROUPED_TOTAL_K = "grouped_total_k"  # G, (M, N, TotalK)

    @property
    def is_grouped(self) -> bool:
        """True for every grouped-GEMM mode (rows get a "G" column)."""
        return self in (
            ShapeMode.GROUPED,
            ShapeMode.GROUPED_TOTAL_M,
            ShapeMode.GROUPED_TOTAL_K,
        )


@dataclass
class Metrics:
    """Benchmark results for one GEMM op on one problem shape.

    For grouped runs, ``M``/``N``/``K`` are either a per-group list or a
    single int. ``benchmark_grouped`` collapses a dimension to one int when
    every group has the same size along it. The measurement fields start at
    zero and are filled in by the benchmark drivers.
    """

    op: str
    M: Any = 0
    N: Any = 0
    K: Any = 0
    groups: Optional[int] = None
    shape_mode: ShapeMode = ShapeMode.REGULAR

    # Mean squared error of the op's output vs. a torch.matmul reference
    # (summed over groups in grouped runs).
    sim: float = 0.0
    ms: float = 0.0
    tflops: float = 0.0
    gbps: float = 0.0
    # Percentages of the memory-bandwidth and compute rooflines.
    mem_bw_util: float = 0.0
    compute_util: float = 0.0

    @staticmethod
    def header(shape_mode: ShapeMode = ShapeMode.REGULAR) -> str:
        """Return the table title, column header line and dividers for ``shape_mode``."""
        if shape_mode == ShapeMode.GROUPED_TOTAL_M:
            shape_col = "(TotalM, N, K)"
        elif shape_mode == ShapeMode.GROUPED_TOTAL_K:
            shape_col = "(M, N, TotalK)"
        else:
            shape_col = "(M, N, K)"

        group_col = f"{'G':<6}" if shape_mode.is_grouped else ""
        header_line = (
            f"{'OpName':<30} {group_col} {shape_col:<25} "
            f"{'Sim':<10} {'Ms':<10} {'TFLOPS':<10} "
            f"{'GB/s':<10} {'Mem BW Util %':<14} {'Compute Util %':<10}"
        )
        divider = "-" * len(header_line)
        return f"GEMM Bench\n{divider}\n{header_line}\n{divider}"

    def _summed(self, value: Any) -> Any:
        """Return the total of a per-group dimension.

        A list is summed directly. A scalar in a grouped row means every group
        shares that size (see the class docstring), so the total is the size
        times the group count. Without a group count, the scalar is returned
        unchanged.
        """
        if isinstance(value, list):
            return sum(value)
        return value * self.groups if self.groups else value

    def __str__(self) -> str:
        if self.shape_mode == ShapeMode.GROUPED_TOTAL_M:
            shape = f"({self._summed(self.M)}, {self.N}, {self.K})"
        elif self.shape_mode == ShapeMode.GROUPED_TOTAL_K:
            shape = f"({self.M}, {self.N}, {self._summed(self.K)})"
        else:
            shape = f"({self.M}, {self.N}, {self.K})"

        # str() so a grouped row without a group count prints "None" instead of
        # raising TypeError from format(None, "<6").
        group_col = f"{str(self.groups):<6}" if self.shape_mode.is_grouped else ""
        compute_util_str = (
            f"{self.compute_util:<10.2f}" if self.compute_util > 0 else "N/A"
        )
        return (
            f"{self.op:<30} {group_col} {shape:<25} "
            f"{self.sim:<10.3f} {self.ms:<10.3f} "
            f"{self.tflops:<10.2f} {self.gbps:<10.2f} "
            f"{self.mem_bw_util:<14.2f} {compute_util_str}"
        )

    def as_dict(self) -> dict[str, Any]:
        """Flatten into a CSV-friendly row; measurement keys are prefixed with the op name."""
        result: dict[str, Any] = {
            "M": self.M,
            "N": self.N,
            "K": self.K,
            f"{self.op}_sim": self.sim,
            f"{self.op}_ms": self.ms,
            f"{self.op}_tflops": self.tflops,
            f"{self.op}_gb/s": self.gbps,
            f"{self.op}_mem_bw_util": self.mem_bw_util,
            f"{self.op}_compute_util": self.compute_util,
        }
        if self.groups is not None:
            result["groups"] = self.groups
        return result
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
def benchmark_grouped(
    gemm_ops: list[GemmOpBase],
    m: list[int],
    n: list[int],
    k: list[int],
    mem_bw_roofline_gbps: float,
    opts: BenchOptions,
    bench_quantize: bool = False,
    shape_mode: ShapeMode = ShapeMode.GROUPED,
) -> list[Metrics]:
    """Benchmark grouped GEMM ops on the problems (m[i], n[i], k[i]).

    Each op's output is compared to a per-group ``torch.matmul`` reference
    (the summed MSE is stored in ``Metrics.sim``). The op is then timed
    ``opts.num_iters`` times. TFLOPS, GB/s and utilization are totals over all
    non-empty groups, averaged over iterations. Ops whose preprocess,
    quantize or compute step raises are reported and skipped.

    Args:
        gemm_ops: Ops to benchmark.
        m, n, k: Per-group problem sizes; ``len(m)`` is the group count.
        mem_bw_roofline_gbps: Memory-bandwidth roofline used for "Mem BW Util %".
        opts: Benchmark options (iterations, tracing, fast accum, compile, ...).
        bench_quantize: If True, the timed region includes quantization.
        shape_mode: How shapes are labelled in the progress message and results.

    Returns:
        One Metrics per op that ran successfully.
    """
    num_groups = len(m)
    # Create input tensors. A[i] and B[i] are generated in interleaved order,
    # and B is stored as (N, K), so the reference is A @ B^T.
    A = []
    B = []
    for i in range(num_groups):
        A.append(torch.randn(m[i], k[i], device="cuda", dtype=torch.bfloat16))
        B.append(torch.randn(n[i], k[i], device="cuda", dtype=torch.bfloat16))
    # Compute baseline output for correctness checking.
    out_ref = [torch.matmul(a, b.t()) for a, b in zip(A, B)]
    total_m = sum(m)
    # Log a single value when every group shares it, otherwise the full list.
    log_m = m[0] if len(np.unique(m)) == 1 else m
    log_n = n[0] if len(np.unique(n)) == 1 else n
    log_k = k[0] if len(np.unique(k)) == 1 else k
    results: list[Metrics] = []
    # Benchmark each operator.
    for gemm_op in gemm_ops:
        # Build progress message based on shape mode.
        if shape_mode == ShapeMode.GROUPED_TOTAL_M:
            shape_str = f"(G={num_groups}, TotalM={total_m}, N={log_n}, K={log_k})"
        elif shape_mode == ShapeMode.GROUPED_TOTAL_K:
            shape_str = f"(G={num_groups}, M={log_m}, N={log_n}, TotalK={sum(k)})"
        else:
            shape_str = f"(G={num_groups}, M={log_m}, N={log_n}, K={log_k})"
        print(f"Benchmarking {gemm_op.name} with {shape_str}")
        metrics = Metrics(
            op=gemm_op.name,
            M=log_m,
            N=log_n,
            K=log_k,
            groups=num_groups,
            shape_mode=shape_mode,
        )
        # Set fast accum / torch.compile mode if applicable.
        if hasattr(gemm_op, "fast_accum"):
            gemm_op.fast_accum = opts.fast_accum
        if hasattr(gemm_op, "torch_compile"):
            gemm_op.torch_compile = opts.torch_compile

        # Get compute roofline for this op's compute dtype.
        compute_roofline_tflops = get_compute_roofline_tflops(gemm_op.compute_dtype)

        try:
            # Get the quantized tensors for this operator.
            preprocessed_args = gemm_op.preprocess(A, B)
            quantized_vals = gemm_op.quantize(*preprocessed_args)
            # Compute the output given quantized values.
            output = gemm_op.compute(*quantized_vals)
        except Exception as e:
            print(f"GEMM op {gemm_op.name} failed to run due to error: {e}.")
            continue
        # Some kernels may pad output, so keep only the real rows of each group.
        if isinstance(output, torch.Tensor) and output.ndim == 2:
            # Stacked (rows, N) output. Drop any padding rows past sum(m) first,
            # because torch.split requires the section sizes to cover the dim exactly.
            output = torch.split(output[:total_m], m, dim=0)
        else:
            # Per-group outputs (list or leading group dim) may each be padded.
            output = [o[: m[i]] for i, o in enumerate(output)]
        # Compare the quantize op output to reference as a sanity check.
        for i in range(num_groups):
            if m[i] > 0:
                metrics.sim += float(
                    torch.mean(torch.pow(output[i] - out_ref[i], 2)).item()
                )

        # Fused scatter-add kernels count the output bytes twice in the
        # bandwidth estimate (presumably a read plus a write).
        output_multiplier = 2 if "fuse_scatter_add" in gemm_op.name else 1
        bench_args = preprocessed_args if bench_quantize else quantized_vals
        for _ in range(opts.num_iters):
            # Now perform benchmark; bench_quantize decides whether quantization is timed.
            with profiler(enabled=opts.trace, with_stack=True):
                ms_runtime = gemm_op.benchmark(
                    *bench_args,
                    opts=opts,
                    bench_quantize=bool(bench_quantize),
                )

            # ms_runtime covers all groups, so per-group rates add up to the total.
            for i in range(num_groups):
                if m[i] <= 0:
                    continue
                tflops = 2 * m[i] * n[i] * k[i] / (ms_runtime / 1e3) / 1e12
                gbps = (
                    (
                        m[i] * k[i] * quantized_vals[0][0].element_size()
                        + n[i] * k[i] * quantized_vals[1][0].element_size()
                        + output_multiplier * m[i] * n[i] * output[0].element_size()
                    )
                    / (ms_runtime / 1e3)
                    / 1e9
                )
                metrics.gbps += gbps
                metrics.tflops += tflops
                metrics.mem_bw_util += (gbps / mem_bw_roofline_gbps) * 100
                if compute_roofline_tflops is not None:
                    metrics.compute_util += (tflops / compute_roofline_tflops) * 100
            metrics.ms += ms_runtime
        metrics.ms /= opts.num_iters
        metrics.tflops /= opts.num_iters
        metrics.gbps /= opts.num_iters
        metrics.mem_bw_util /= opts.num_iters
        metrics.compute_util /= opts.num_iters

        results.append(metrics)

    return results
|
|
410
|
+
|
|
411
|
+
|
|
412
|
+
def benchmark(
    gemm_ops: list[GemmOpBase],
    m: int,
    n: int,
    k: int,
    mem_bw_roofline_gbps: float,
    opts: BenchOptions,
    bench_quantize: bool = False,
    shape_mode: ShapeMode = ShapeMode.REGULAR,
) -> list[Metrics]:
    """Benchmark each op in ``gemm_ops`` on one (m, n, k) GEMM.

    Inputs are random bf16 tensors A (m, k) and B (n, k), and the reference is
    A @ B^T. For every op, the MSE against the reference goes into
    ``Metrics.sim``. The op is then timed ``opts.num_iters`` times, and the
    averaged ms / TFLOPS / GB/s / utilization are recorded. Ops whose
    preprocess, quantize or compute step raises are reported and skipped.

    Returns:
        One Metrics per op that ran successfully.
    """
    A = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
    B = torch.randn(n, k, device="cuda", dtype=torch.bfloat16)
    out_ref = torch.matmul(A, torch.transpose(B, -2, -1))
    flop_count = 2 * m * n * k

    results: list[Metrics] = []
    for gemm_op in gemm_ops:
        print(f"Benchmarking {gemm_op.name} with (M={m}, N={n}, K={k})")
        metrics = Metrics(op=gemm_op.name, M=m, N=n, K=k, shape_mode=shape_mode)

        # Propagate optional op knobs only when the op exposes them.
        if hasattr(gemm_op, "fast_accum"):
            gemm_op.fast_accum = opts.fast_accum
        if hasattr(gemm_op, "torch_compile"):
            gemm_op.torch_compile = opts.torch_compile

        roofline_tflops = get_compute_roofline_tflops(gemm_op.compute_dtype)

        try:
            preprocessed_args = gemm_op.preprocess(A, B)
            quantized_vals = gemm_op.quantize(*preprocessed_args)
            output = gemm_op.compute(*quantized_vals)
        except Exception as e:
            print(f"GEMM op {gemm_op.name} failed to run due to error: {e}.")
            continue

        # TODO(shikaili): This calculation is incorrect for scatter add fusion.
        metrics.sim = torch.mean(torch.pow(output - out_ref, 2)).item()

        # With bench_quantize the timed call starts from the preprocessed
        # inputs (so quantization is included); otherwise it starts from the
        # already-quantized values.
        timed_args = preprocessed_args if bench_quantize else quantized_vals
        for _ in range(opts.num_iters):
            with profiler(enabled=opts.trace, with_stack=True):
                ms_runtime = gemm_op.benchmark(
                    *timed_args,
                    opts=opts,
                    bench_quantize=bool(bench_quantize),
                )
            seconds = ms_runtime / 1e3

            tflops = flop_count / seconds / 1e12
            metrics.tflops += tflops

            bytes_moved = (
                quantized_vals[0].numel() * quantized_vals[0].element_size()
                + quantized_vals[1].numel() * quantized_vals[1].element_size()
                + output.numel() * output.element_size()
            )
            gbps = bytes_moved / seconds / 1e9
            metrics.gbps += gbps
            metrics.mem_bw_util += (gbps / mem_bw_roofline_gbps) * 100

            if roofline_tflops is not None:
                metrics.compute_util += (tflops / roofline_tflops) * 100
            metrics.ms += ms_runtime

        # Turn the per-iteration sums into averages.
        iters = opts.num_iters
        metrics.ms /= iters
        metrics.tflops /= iters
        metrics.gbps /= iters
        metrics.mem_bw_util /= iters
        metrics.compute_util /= iters

        results.append(metrics)

    return results
|
|
501
|
+
|
|
502
|
+
|
|
503
|
+
def plot_benchmark(results: list[Metrics], output_dir: str) -> None:
    """Save a log-scale bar plot of each kernel's TFLOPS per shape.

    The image is written to ``<output_dir>/gemm_ops_benchmark.png`` at 300 dpi.
    If ``results`` is empty, a message is printed and no file is written.
    """
    if not results:
        print("No benchmark results to plot.")
        return

    # One row per (shape, kernel) measurement, in the layout seaborn expects.
    data = [
        {
            "MNK": f"{metric.M}, {metric.N}, {metric.K}",
            "kernel": metric.op,
            "TFLOPS": metric.tflops,
        }
        for metric in results
    ]
    df = pd.DataFrame(data)

    plot = plt.figure()
    try:
        plt.xticks(rotation=30)
        plt.yscale("log")
        ax = sns.barplot(x="MNK", y="TFLOPS", hue="kernel", data=df)
        ax.tick_params(axis="x", labelsize=3)
        img_fn = os.path.join(output_dir, "gemm_ops_benchmark.png")
        plot.savefig(img_fn, dpi=300)
    finally:
        # pyplot keeps every open figure alive until it is explicitly closed.
        plt.close(plot)
    print(f"Plot saved to {img_fn}")
|
|
522
|
+
|
|
523
|
+
|
|
524
|
+
def collect_kernels_to_profile(
    kernels: Optional[list[str]], is_grouped: bool
) -> list[GemmOpBase]:
    """Return the supported ops for the requested GEMM type, optionally filtered by name.

    Args:
        kernels: Op names to keep, or None to keep every supported op.
            Whitespace around each name is ignored, so a "a, b" style list split
            on "," still matches. Requested names that are unknown or not
            supported here are reported with a warning.
        is_grouped: Select grouped-GEMM ops instead of regular ones.

    Returns:
        The matching ops, in ``get_gemm_ops()`` order.
    """
    gemm_type = GemmType.GROUPED if is_grouped else GemmType.REGULAR
    gemm_ops = [
        op
        for op in get_gemm_ops()
        if op.supported and gemm_type in op.supported_gemm_types
    ]
    if kernels is None:
        return gemm_ops

    wanted = {name.strip() for name in kernels if name.strip()}
    selected = [op for op in gemm_ops if op.name in wanted]
    missing = wanted - {op.name for op in selected}
    if missing:
        print(
            "Warning: skipping unknown or unsupported kernels: "
            f"{', '.join(sorted(missing))}"
        )
    return selected
|
|
536
|
+
|
|
537
|
+
|
|
538
|
+
def print_kernels(kernels: Optional[list[str]]) -> None:
    """Print a table of every registered GEMM op and the accelerators it supports.

    Rows are sorted by op name.

    Args:
        kernels: Currently unused; kept so existing callers continue to work.
            The table always lists all ops returned by ``get_gemm_ops()``.
    """
    rows = sorted(
        (
            op.name,
            ",".join(accelerator.name for accelerator in op.supported_accelerators),
        )
        for op in get_gemm_ops()
    )
    print(tabulate(rows, headers=["Name", "Accelerators"], tablefmt="orgtbl"))
|
|
547
|
+
|
|
548
|
+
|
|
549
|
+
@click.command()
|
|
550
|
+
@click.option(
|
|
551
|
+
"--output-dir",
|
|
552
|
+
default="/tmp",
|
|
553
|
+
help="Directory to save plots and csvs to",
|
|
554
|
+
)
|
|
555
|
+
@click.option(
|
|
556
|
+
"--num-iters",
|
|
557
|
+
default=1,
|
|
558
|
+
type=int,
|
|
559
|
+
help="Number of iterations to repeat each benchmark.",
|
|
560
|
+
)
|
|
561
|
+
@click.option(
|
|
562
|
+
"--export-csv",
|
|
563
|
+
is_flag=True,
|
|
564
|
+
help="Export results to a CSV file.",
|
|
565
|
+
)
|
|
566
|
+
@click.option(
|
|
567
|
+
"--plot",
|
|
568
|
+
is_flag=True,
|
|
569
|
+
help="Create a plot of the benchmark measurements.",
|
|
570
|
+
)
|
|
571
|
+
@click.option(
|
|
572
|
+
"--enable-amd-env-vars",
|
|
573
|
+
is_flag=True,
|
|
574
|
+
help="Enable a set of environment variables for AMD GPU performance",
|
|
575
|
+
)
|
|
576
|
+
@click.option(
|
|
577
|
+
"--bench-quantize",
|
|
578
|
+
is_flag=True,
|
|
579
|
+
help="If set, include quantization cost in benchmark.",
|
|
580
|
+
)
|
|
581
|
+
@click.option(
|
|
582
|
+
"--kernels",
|
|
583
|
+
default=None,
|
|
584
|
+
help="Comma separated list of kernels to benchmark. Defaults to all kernels.",
|
|
585
|
+
)
|
|
586
|
+
@click.option(
|
|
587
|
+
"--M",
|
|
588
|
+
default=None,
|
|
589
|
+
help="Comma separated list of M values to benchmark.",
|
|
590
|
+
)
|
|
591
|
+
@click.option(
|
|
592
|
+
"--N",
|
|
593
|
+
default=None,
|
|
594
|
+
help="Comma separated list of N values to benchmark",
|
|
595
|
+
)
|
|
596
|
+
@click.option(
|
|
597
|
+
"--K",
|
|
598
|
+
default=None,
|
|
599
|
+
help="Comma separated list of K values to benchmark.",
|
|
600
|
+
)
|
|
601
|
+
@click.option(
|
|
602
|
+
"--pair-NK",
|
|
603
|
+
is_flag=True,
|
|
604
|
+
help="If set, instead of benchmarking cartesian product of N * K, benchmark consecutive NK pairs together.",
|
|
605
|
+
)
|
|
606
|
+
@click.option(
|
|
607
|
+
"--grouped",
|
|
608
|
+
is_flag=True,
|
|
609
|
+
help="If set, do grouped gemm. In this mode, M, N, and K are interpreted "
|
|
610
|
+
"as the size of groups. The length of each must be the same.",
|
|
611
|
+
)
|
|
612
|
+
@click.option(
|
|
613
|
+
"--groups",
|
|
614
|
+
default=None,
|
|
615
|
+
help="If set with grouped mode, repeat MNK shapes this many times. Comma separated list of groups to benchmark",
|
|
616
|
+
)
|
|
617
|
+
@click.option(
    "--total-K",
    default=None,
    help="If set, adjusts the K values to sum to this number. "
    "This can help simulate real grouped workloads in backward wgrad. "
    "Comma separated list of total-K values to benchmark.",
)
@click.option(
    "--total-M",
    default=None,
    help="If set, adjusts the M values to sum to this number. "
    "This can help simulate real grouped workloads. "
    "Comma separated list of total-M values to benchmark.",
)
@click.option(
    "--no-cuda-graph",
    is_flag=True,
    help="If set, do not use cuda graph for benchmarking.",
)
@click.option(
    "--use-rotating-buffer-bench",
    is_flag=True,
    help="If set, use rotating buffer to benchmark.",
)
@click.option(
    "--shapes",
    default=None,
    help=f"Specific model shapes to use, options: {', '.join(shape_registry.keys())}.",
)
@click.option(
    "--trace",
    is_flag=True,
    help="If set, produce a performance trace of the benchmark.",
)
@click.option(
    "--disable-fast-accum",
    is_flag=True,
    help="If set, disable fast accumulation for FP8 implementations.",
)
@click.option(
    "--torch-compile",
    is_flag=True,
    help="If set, torch.compile will be used for scaled_mm backed ops.",
)
@click.option(
    "--rep",
    default=200,
    type=int,
    help="Repetition time in ms (int) for triton.testing.do_bench",
)
def invoke_main(
    output_dir: str,
    num_iters: int,
    export_csv: bool,
    plot: bool,
    enable_amd_env_vars: bool,
    bench_quantize: bool,
    kernels: Optional[str],
    m: Optional[str],
    n: Optional[str],
    k: Optional[str],
    pair_nk: bool,
    grouped: bool,
    groups: Optional[str],
    total_k: Optional[str],
    total_m: Optional[str],
    no_cuda_graph: bool,
    use_rotating_buffer_bench: bool,
    shapes: Optional[str],
    trace: bool,
    disable_fast_accum: bool,
    torch_compile: bool,
    rep: int,
) -> None:
    """Benchmark the selected GEMM kernels over a set of problem shapes.

    Steps:
    1. Parse the comma-separated kernel and shape options.
    2. Enumerate the (M, N, K) problems. They come from an explicit M/N/K
       list, a named entry in ``shape_registry``, or the built-in defaults.
       When ``groups`` is given, each problem is expanded into grouped
       problems, optionally redistributing M (``total_m``) or K
       (``total_k``) across the groups.
    3. Run every selected kernel on each problem.
    4. Print a results table plus a hardware/settings summary.
    5. Optionally write a CSV and/or plots into ``output_dir``.

    Raises:
        ValueError: If both ``total_m`` and ``total_k`` are given; if grouped
            mode (without ``groups``) is missing M, N or K, or they have
            different lengths; or if ``pair_nk`` is set and N and K have
            different lengths.
    """

    def _parse_ints(values: str) -> list[int]:
        # Parse a comma separated CLI value such as "1,2,3" into ints.
        # int() tolerates surrounding whitespace; empty entries (e.g. from a
        # trailing comma) are skipped instead of crashing.
        return [int(v) for v in values.split(",") if v.strip()]

    if enable_amd_env_vars:
        set_amd_env_vars()

    # Validate that total_m and total_k are mutually exclusive: only one
    # dimension can be redistributed across groups at a time.
    if total_m is not None and total_k is not None:
        raise ValueError(
            "total_m and total_k cannot be specified at the same time. "
            "Please provide only one of them."
        )

    if groups:
        # Supplying groups implies grouped mode.
        grouped = True
    elif total_m or total_k:
        # total_m / total_k are only consumed when expanding shapes per group
        # below; tell the user instead of silently dropping them.
        print("Warning: total_m/total_k are ignored unless groups is provided.")

    # If kernel filter is provided, parse it. Else, benchmark all kernels.
    # Names are stripped so "a, b" selects both "a" and "b".
    all_kernels = (
        [name.strip() for name in kernels.split(",") if name.strip()]
        if kernels
        else None
    )
    gemm_ops = collect_kernels_to_profile(all_kernels, grouped)

    if not gemm_ops:
        print("No valid kernels to benchmark. Available kernels:")
        print_kernels(all_kernels)
        sys.exit(1)

    if num_iters < 1:
        print("Warning: Number of iterations must be at least 1.")
        num_iters = 1

    # Enumerate shapes to benchmark.
    if grouped and not groups:
        # In grouped mode, M, N, and K represent the groups of a single gemm.
        # Explicit errors instead of asserts: asserts vanish under `python -O`.
        if m is None or n is None or k is None:
            raise ValueError("M, N, and K must all be provided in grouped mode.")
        M = _parse_ints(m)
        N = _parse_ints(n)
        K = _parse_ints(k)
        if not len(M) == len(N) == len(K):
            raise ValueError("M, N, and K must be the same length in grouped mode.")

        # Note this is a single grouped gemm.
        MNK = [[M, N, K]]
    else:
        if shapes:
            if shapes not in shape_registry:
                print(
                    f"Shape {shapes} not found in shape registry. Valid shapes: {', '.join(shape_registry.keys())}."
                )
                sys.exit(1)
            MNK = shape_registry[shapes]()
        else:
            M = (
                [1, 4, 8, 16, 32, 64, 128, 2048, 4096, 8192, 16384]
                if m is None
                else _parse_ints(m)
            )
            N = [1280, 2304, 7168, 8192, 16384] if n is None else _parse_ints(n)
            K = [1024, 3584, 8192, 16384] if k is None else _parse_ints(k)
            # List all shapes for simplicity.
            if pair_nk:
                # N and K are paired element-wise; only M is crossed with them.
                if len(N) != len(K):
                    raise ValueError("N and K must be the same length in pair_NK mode.")
                MNK = [
                    (m_val, n_val, k_val)
                    for m_val, (n_val, k_val) in itertools.product(M, zip(N, K))
                ]
            else:
                MNK = list(itertools.product(M, N, K))

    # When groups is provided transform shapes into grouped format.
    if groups:
        groups_list = _parse_ints(groups)
        if total_m:
            # Distribute each total M across g groups; N and K are repeated.
            total_m_list = _parse_ints(total_m)
            MNK = [
                [
                    generate_group_tensor(g, tm),
                    [n_val] * g,
                    [k_val] * g,
                ]
                for g in groups_list
                for tm in total_m_list
                for _, n_val, k_val in MNK
            ]
            shape_mode = ShapeMode.GROUPED_TOTAL_M
        elif total_k:
            # Distribute each total K across g groups; M and N are repeated.
            total_k_list = _parse_ints(total_k)
            MNK = [
                [
                    [m_val] * g,
                    [n_val] * g,
                    generate_group_tensor(g, tk),
                ]
                for g in groups_list
                for tk in total_k_list
                for m_val, n_val, _ in MNK
            ]
            shape_mode = ShapeMode.GROUPED_TOTAL_K
        else:
            # Repeat each (m, n, k) shape g times to form one grouped gemm.
            MNK = [
                [[m_val] * g, [n_val] * g, [k_val] * g]
                for g in groups_list
                for m_val, n_val, k_val in MNK
            ]
            shape_mode = ShapeMode.GROUPED
    elif grouped:
        shape_mode = ShapeMode.GROUPED
    else:
        shape_mode = ShapeMode.REGULAR

    # Iterate over shapes and benchmark.
    mem_bw_gbps = triton.testing.get_dram_gbps()
    benchmark_results: list[Metrics] = []
    csv_rows: list[dict[str, Any]] = []
    benchmark_func = benchmark_grouped if grouped else benchmark

    opts = BenchOptions(
        num_iters=num_iters,
        cuda_graph=not no_cuda_graph,
        rotating_buffer=use_rotating_buffer_bench,
        rep_ms=rep,
        trace=trace,
        fast_accum=not disable_fast_accum,
        torch_compile=torch_compile,
    )

    for m_shape, n_shape, k_shape in MNK:
        shape_measurements = benchmark_func(
            gemm_ops,
            m_shape,  # pyre-ignore[6]: Incompatible parameter type [6]
            n_shape,  # pyre-ignore[6]: Incompatible parameter type [6]
            k_shape,  # pyre-ignore[6]: Incompatible parameter type [6]
            mem_bw_gbps,
            opts,
            bench_quantize,
            shape_mode,
        )
        benchmark_results.extend(shape_measurements)
        # Merge all metrics returned for this shape into a single CSV row
        # (on key collisions, later metrics overwrite earlier ones).
        csv_row: dict[str, Any] = {}
        for metric in shape_measurements:
            csv_row.update(metric.as_dict())
        csv_rows.append(csv_row)

    print()
    print(Metrics.header(shape_mode))
    for metric in benchmark_results:
        print(metric)

    print()
    print(f"Hardware: {torch.cuda.get_device_name()}")
    print(f" Memory BW: {mem_bw_gbps:.2f} GB/s")

    print()
    print("Benchmark Settings:")
    print(f" CUDA graph: {not no_cuda_graph}")
    print(f" Buffer rotation: {use_rotating_buffer_bench}")
    print(f" Fast accumulation: {not disable_fast_accum}")
    print(f" Torch compile: {torch_compile}")

    if export_csv or plot:
        os.makedirs(output_dir, exist_ok=True)
    if export_csv:
        datetime_str = datetime.now().strftime("%Y%m%d_%H%M%S")
        csv_file = os.path.join(output_dir, f"gemm_ops_benchmark_{datetime_str}.csv")
        # Export results to a CSV file.
        df = pd.DataFrame(csv_rows)
        df.to_csv(csv_file, na_rep="NaN", index=False)
        print(f"CSV saved to {csv_file}")
    if plot:
        plot_benchmark(benchmark_results, output_dir)
|
|
856
|
+
|
|
857
|
+
|
|
858
|
+
if __name__ == "__main__":
    # Script entry point. invoke_main declares its parameters via click
    # options; calling it with no arguments relies on click (presumably a
    # @click.command() decorator outside this view) to parse sys.argv.
    invoke_main()  # pragma: no cover
|