mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mslk/__init__.py +56 -0
- mslk/attention/__init__.py +7 -0
- mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
- mslk/attention/flash_attn/__init__.py +22 -0
- mslk/attention/flash_attn/ampere_helpers.py +104 -0
- mslk/attention/flash_attn/barrier.py +72 -0
- mslk/attention/flash_attn/benchmark.py +269 -0
- mslk/attention/flash_attn/blackwell_helpers.py +754 -0
- mslk/attention/flash_attn/block_info.py +109 -0
- mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
- mslk/attention/flash_attn/block_sparsity.py +219 -0
- mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
- mslk/attention/flash_attn/copy_utils.py +341 -0
- mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
- mslk/attention/flash_attn/fast_math.py +22 -0
- mslk/attention/flash_attn/flash_bwd.py +1262 -0
- mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
- mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
- mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
- mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
- mslk/attention/flash_attn/flash_fwd.py +2471 -0
- mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
- mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
- mslk/attention/flash_attn/hopper_helpers.py +102 -0
- mslk/attention/flash_attn/interface.py +1771 -0
- mslk/attention/flash_attn/mask.py +610 -0
- mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
- mslk/attention/flash_attn/named_barrier.py +32 -0
- mslk/attention/flash_attn/pack_gqa.py +165 -0
- mslk/attention/flash_attn/paged_kv.py +176 -0
- mslk/attention/flash_attn/pipeline.py +273 -0
- mslk/attention/flash_attn/seqlen_info.py +139 -0
- mslk/attention/flash_attn/softmax.py +583 -0
- mslk/attention/flash_attn/testing.py +424 -0
- mslk/attention/flash_attn/tile_scheduler.py +720 -0
- mslk/attention/flash_attn/utils.py +860 -0
- mslk/attention/fmha/__init__.py +967 -0
- mslk/attention/fmha/_triton/__init__.py +6 -0
- mslk/attention/fmha/_triton/available.py +50 -0
- mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
- mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
- mslk/attention/fmha/attn_bias.py +2186 -0
- mslk/attention/fmha/attn_bias_utils.py +536 -0
- mslk/attention/fmha/ck.py +508 -0
- mslk/attention/fmha/ck_decoder.py +141 -0
- mslk/attention/fmha/ck_splitk.py +204 -0
- mslk/attention/fmha/common.py +598 -0
- mslk/attention/fmha/cutlass.py +461 -0
- mslk/attention/fmha/cutlass_blackwell.py +560 -0
- mslk/attention/fmha/dispatch.py +224 -0
- mslk/attention/fmha/flash.py +862 -0
- mslk/attention/fmha/flash3.py +858 -0
- mslk/attention/fmha/flash_mtia.py +245 -0
- mslk/attention/fmha/merge_training.py +192 -0
- mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
- mslk/attention/fmha/torch_attention_compat.py +154 -0
- mslk/attention/fmha/tree_attention.py +718 -0
- mslk/attention/fmha/triton_splitk.py +1378 -0
- mslk/attention/fmha/unbind.py +130 -0
- mslk/attention/fmha/utils/__init__.py +6 -0
- mslk/attention/fmha/utils/bench.py +74 -0
- mslk/attention/fmha/utils/cpp_lib.py +148 -0
- mslk/attention/fmha/utils/op_common.py +65 -0
- mslk/attention/gqa_attn_splitk/__init__.py +11 -0
- mslk/bench/comm/__init__.py +7 -0
- mslk/bench/comm/comm_bench.py +255 -0
- mslk/bench/common/__init__.py +5 -0
- mslk/bench/common/utils.py +148 -0
- mslk/bench/conv/__init__.py +7 -0
- mslk/bench/conv/conv_bench.py +551 -0
- mslk/bench/conv/conv_ops.py +213 -0
- mslk/bench/gemm/__init__.py +7 -0
- mslk/bench/gemm/gemm_bench.py +859 -0
- mslk/bench/gemm/gemm_ops.py +3342 -0
- mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
- mslk/bench/moe/__init__.py +7 -0
- mslk/bench/moe/gather_scatter_bench.py +356 -0
- mslk/bench/quantize/quantize_bench.py +345 -0
- mslk/bench/quantize/quantize_ops.py +266 -0
- mslk/comm/__init__.py +11 -0
- mslk/conv/__init__.py +11 -0
- mslk/gemm/__init__.py +18 -0
- mslk/gemm/triton/__init__.py +7 -0
- mslk/gemm/triton/fp8_gemm.py +2702 -0
- mslk/gemm/triton/grouped_gemm.py +1132 -0
- mslk/gemm/triton/matmul_perf_model.py +237 -0
- mslk/gemm/triton/utils.py +128 -0
- mslk/kv_cache/__init__.py +11 -0
- mslk/moe/__init__.py +26 -0
- mslk/moe/activation.py +291 -0
- mslk/moe/gather_scatter.py +739 -0
- mslk/moe/layers.py +1240 -0
- mslk/moe/shuffling.py +421 -0
- mslk/mslk.so +0 -0
- mslk/quantize/__init__.py +11 -0
- mslk/quantize/shuffle.py +306 -0
- mslk/quantize/triton/__init__.py +7 -0
- mslk/quantize/triton/fp4_quantize.py +5942 -0
- mslk/quantize/triton/fp8_quantize.py +1902 -0
- mslk/testing/__init__.py +7 -0
- mslk/testing/attributes.py +60 -0
- mslk/testing/rocm.py +91 -0
- mslk/utils/__init__.py +7 -0
- mslk/utils/torch/__init__.py +7 -0
- mslk/utils/torch/library.py +150 -0
- mslk/utils/triton/__init__.py +7 -0
- mslk/utils/triton/fp8_utils.py +72 -0
- mslk/utils/triton/utils.py +128 -0
- mslk/version.py +11 -0
- mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
- mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
- mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
- mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
- mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,345 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
# pyre-strict
|
|
8
|
+
|
|
9
|
+
import itertools
|
|
10
|
+
import os
|
|
11
|
+
import sys
|
|
12
|
+
from dataclasses import dataclass
|
|
13
|
+
from datetime import datetime
|
|
14
|
+
from typing import Callable, Optional
|
|
15
|
+
|
|
16
|
+
import click
|
|
17
|
+
import pandas as pd
|
|
18
|
+
import torch
|
|
19
|
+
import triton # @manual=//triton:triton
|
|
20
|
+
from mslk.bench.common.utils import BenchOptions, profiler
|
|
21
|
+
from mslk.bench.quantize.quantize_ops import get_ops, QuantizeOpBase
|
|
22
|
+
from tabulate import tabulate
|
|
23
|
+
|
|
24
|
+
# Signature of a registered shape generator: takes no arguments and returns a
# list of (M, K) problem shapes.
# A plain assignment is used instead of the PEP 695 ``type`` statement, which
# is a SyntaxError before Python 3.12 (this wheel is built for cp310).
ShapeFunction = Callable[[], list[tuple[int, int]]]

# Named shape generators, filled in by the ``register_shapes`` decorator.
shape_registry: dict[str, ShapeFunction] = {}
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def register_shapes(name: str) -> Callable[[ShapeFunction], ShapeFunction]:
    """Return a decorator that records a shape function under ``name``.

    The decorated function is stored in ``shape_registry`` (overwriting any
    previous entry with the same name) and handed back unchanged, so it can
    still be called directly.
    """

    def _register(shape_function: ShapeFunction) -> ShapeFunction:
        shape_registry.update({name: shape_function})
        return shape_function

    return _register
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@register_shapes("llm_eval")
def llm_eval() -> list[tuple[int, int]]:
    """(M, K) shapes registered under ``llm_eval``.

    Five M sizes at K=5120, followed by (1024, 7168) and (4096, 4096).
    """
    shapes = [(rows, 5120) for rows in (1, 1024, 2000, 4096, 16384)]
    shapes.append((1024, 7168))
    shapes.append((4096, 4096))
    return shapes
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
@register_shapes("decode_1024")
def decode_1024_shapes() -> list[tuple[int, int]]:
    """Single-row (M=1) shapes for K in {1024, 2048, 4096, 5120, 6144, 7168, 8192}."""
    return [(1, cols) for cols in (1024, 2048, 4096, 5120, 6144, 7168, 8192)]
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
@register_shapes("prefill_1024")
def prefill_1024_shapes() -> list[tuple[int, int]]:
    """Every (M, K) pair with M in {2048, 4096, 8192, 16384}, M-major order.

    K ranges over {1024, 2048, 4096, 5120, 6144, 7168, 8192}.
    """
    k_sizes = (1024, 2048, 4096, 5120, 6144, 7168, 8192)
    return [(rows, cols) for rows in (2048, 4096, 8192, 16384) for cols in k_sizes]
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
@dataclass
class Metrics:
    """Per-op, per-shape benchmark result for the quantize bench.

    ``sim`` holds the value the harness stores as reconstruction error,
    ``us`` the runtime in microseconds, ``gbps`` the achieved bandwidth and
    ``memory_bw_util`` the percentage of the roofline bandwidth.
    """

    op: str
    M: int = 0
    K: int = 0
    sim: float = 0.0
    us: float = 0.0
    gbps: float = 0.0
    memory_bw_util: float = 0.0

    @staticmethod
    def header() -> str:
        """Return the table title, a divider, column titles and a closing divider."""
        columns = (
            ("OpName", 20),
            ("Problem Shape", 15),
            ("Sim", 10),
            ("Us", 10),
            ("GB/s", 10),
            ("Mem BW Util %", 10),
        )
        title_row = " ".join(f"{title:<{width}}" for title, width in columns)
        rule = "-" * len(title_row)
        return "\n".join(["Quantize Bench", rule, title_row, rule])

    def __str__(self) -> str:
        """Format this result as one row aligned with ``header()``."""
        shape_text = f"({self.M}, {self.K})"
        cells = [
            f"{self.op:<20}",
            f"{shape_text:<15}",
            f"{self.sim:<10.3f}",
            f"{self.us:<10.3f}",
            f"{self.gbps:<10.2f}",
            f"{self.memory_bw_util:<10.2f}",
        ]
        return " ".join(cells)

    def as_dict(self) -> dict[str, float]:
        """Return a CSV row fragment; metric keys are prefixed with the op name."""
        prefix = self.op
        return {
            "M": self.M,
            "K": self.K,
            f"{prefix}_sim": self.sim,
            f"{prefix}_us": self.us,
            f"{prefix}_gb/s": self.gbps,
            f"{prefix}_memory_bw_util": self.memory_bw_util,
        }
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def _parse_int_list(values: Optional[str], dim: str) -> list[int]:
    """Parse a comma-separated string of ints, ignoring blank entries.

    Raises ``Exception("<dim> must be non-empty.")`` when ``values`` is None
    or contains no entries.
    """
    parsed = [int(tok) for tok in (values or "").split(",") if tok.strip()]
    if not parsed:
        raise Exception(f"{dim} must be non-empty.")
    return parsed


def get_problem_shapes(
    shapes: Optional[str],
    m: Optional[str],
    k: Optional[str],
    pair_mk: bool,
) -> list[tuple[int, int]]:
    """Resolve the (M, K) problem shapes to benchmark.

    Args:
        shapes: Comma-separated names from ``shape_registry``. When given, it
            takes precedence over ``m``/``k``. Surrounding whitespace and
            blank entries are ignored. An unknown name prints the valid names
            and exits the process with status 1.
        m: Comma-separated M values (used when ``shapes`` is empty).
        k: Comma-separated K values (used when ``shapes`` is empty).
        pair_mk: Zip M and K element-wise instead of taking their cartesian
            product.

    Returns:
        The list of (M, K) tuples. Registry shapes are de-duplicated and
        sorted.

    Raises:
        Exception: If M or K is missing/empty, or if their lengths differ in
            ``pair_mk`` mode.
    """
    if shapes:
        all_shapes: set[tuple[int, int]] = set()
        for shape in (name.strip() for name in shapes.split(",")):
            if not shape:
                continue
            if shape not in shape_registry:
                print(
                    f"Shape {shape} not found in shape registry. Valid shapes: {', '.join(shape_registry.keys())}."
                )
                sys.exit(1)
            all_shapes.update(shape_registry[shape]())
        # Sets are unordered; sort so repeated runs benchmark in the same order.
        return sorted(all_shapes)

    M = _parse_int_list(m, "M")
    K = _parse_int_list(k, "K")

    if pair_mk:
        if len(M) != len(K):
            raise Exception("M and K must be the same length in pair_MK mode.")
        return list(zip(M, K))
    return list(itertools.product(M, K))
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def benchmark(
    quantize_ops: list[QuantizeOpBase],
    m: int,
    k: int,
    mem_bw_roofline_gbps: float,
    opts: BenchOptions,
) -> list[Metrics]:
    """Benchmark each op on a random (m, k) bfloat16 CUDA tensor.

    For every op, quantize/dequantize are run once and the mean squared
    difference from the input is stored in ``Metrics.sim``. The op's
    ``benchmark`` is then called ``opts.num_iters`` times, each call wrapped
    in ``profiler``. The reported runtime (us), bandwidth (GB/s) and
    bandwidth utilization (% of ``mem_bw_roofline_gbps``) are the averages
    over those calls. Bandwidth counts the input bytes plus the bytes of
    every tensor returned by ``quantize``.
    """
    x = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
    x_bytes = x.numel() * x.element_size()

    results: list[Metrics] = []
    for op in quantize_ops:
        metrics = Metrics(op=op.name, M=m, K=k)
        extra_args = op.preprocess(x)
        quantized = op.quantize(x, *extra_args)
        reconstructed = op.dequantize(*quantized)
        metrics.sim = torch.mean(torch.pow(reconstructed - x, 2)).item()

        # Traffic per call does not change across iterations, so compute it once.
        moved_bytes = x_bytes + sum(t.numel() * t.element_size() for t in quantized)

        for _ in range(opts.num_iters):
            with profiler(enabled=opts.trace, with_stack=True):
                elapsed_ms = op.benchmark(
                    x,
                    extra_args,
                    opts=opts,
                )
            achieved_gbps = (moved_bytes / 1e9) / (elapsed_ms / 1e3)
            metrics.gbps += achieved_gbps
            metrics.us += elapsed_ms * 1000
            metrics.memory_bw_util += (achieved_gbps / mem_bw_roofline_gbps) * 100

        # Turn the accumulated sums into per-iteration averages.
        metrics.us /= opts.num_iters
        metrics.gbps /= opts.num_iters
        metrics.memory_bw_util /= opts.num_iters

        results.append(metrics)

    return results
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
def collect_kernels_to_profile(kernels: Optional[list[str]]) -> list[QuantizeOpBase]:
    """Return the registered quantize ops that are supported on this device.

    When ``kernels`` is given, only ops whose ``name`` is in it are kept;
    ``None`` keeps every supported op.
    """
    selected = [op for op in get_ops() if op.supported]
    if kernels is not None:
        selected = [op for op in selected if op.name in kernels]
    return selected
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
def print_kernels(kernels: Optional[list[str]]) -> None:
    """Print a table of every registered quantize op and its CUDA/ROCm support.

    ``kernels`` is accepted but not used: all registered ops are listed,
    sorted by name.
    """

    def _yes_no(flag: bool) -> str:
        return "Yes" if flag else "No"

    rows = [(op.name, _yes_no(op.cuda), _yes_no(op.hip)) for op in get_ops()]
    rows.sort()
    print(tabulate(rows, headers=["Name", "CUDA", "ROCm"], tablefmt="orgtbl"))
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
@click.command()
@click.option(
    "--output-dir",
    default="/tmp",
    help="Directory to save plots and csvs to",
)
@click.option(
    "--num-iters",
    default=1,
    type=int,
    help="Number of iterations to repeat each benchmark.",
)
@click.option(
    "--export-csv",
    is_flag=True,
    help="Export results to a CSV file.",
)
@click.option(
    "--kernels",
    default=None,
    help="Comma separated list of kernels to benchmark. Defaults to all kernels.",
)
@click.option(
    "--M",
    default=None,
    help="Comma separated list of M values to benchmark.",
)
@click.option(
    "--K",
    default=None,
    help="Comma separated list of K values to benchmark.",
)
@click.option(
    "--pair-MK",
    is_flag=True,
    help="If set, instead of benchmarking cartesian product of M * K, benchmark consecutive MK pairs together.",
)
@click.option(
    "--no-cuda-graph",
    is_flag=True,
    help="If set, do not use cuda graph for benchmarking.",
)
@click.option(
    "--no-rotating-buffer",
    is_flag=True,
    help="If set, do not use rotating buffer for benchmarking.",
)
@click.option(
    "--shapes",
    default=None,
    help=f"Specific model shapes to use, options: {', '.join(shape_registry.keys())}.",
)
@click.option(
    "--trace",
    is_flag=True,
    help="If set, produce a performance trace of the benchmark.",
)
def invoke_main(
    output_dir: str,
    num_iters: int,
    export_csv: bool,
    kernels: Optional[str],
    m: Optional[str],
    k: Optional[str],
    pair_mk: bool,
    no_cuda_graph: bool,
    no_rotating_buffer: bool,
    shapes: Optional[str],
    trace: bool,
) -> None:
    """CLI entry point for the quantize benchmark.

    The flow is:
    1. Select the requested (and device-supported) quantize ops.
    2. Benchmark each op on every (M, K) problem shape.
    3. Print a results table plus hardware and benchmark-setting info.
    4. Optionally write the results to a timestamped CSV file in
       ``output_dir``.

    Exits with status 1 when no selected kernel is available.
    """
    # If a kernel filter is provided, parse it; otherwise benchmark all kernels.
    # Names are stripped so "--kernels A, B" selects B as well as A.
    all_kernels = (
        [name.strip() for name in kernels.split(",") if name.strip()]
        if kernels
        else None
    )
    quantize_ops = collect_kernels_to_profile(all_kernels)

    if len(quantize_ops) == 0:
        print("No valid kernels to benchmark. Available kernels:")
        print_kernels(all_kernels)
        sys.exit(1)

    if num_iters < 1:
        print("Warning: Number of iterations must be at least 1.")
        num_iters = 1

    mem_bw_roofline_gbps = triton.testing.get_dram_gbps()
    MK = get_problem_shapes(shapes, m, k, pair_mk)

    opts = BenchOptions(
        num_iters=num_iters,
        cuda_graph=not no_cuda_graph,
        rotating_buffer=not no_rotating_buffer,
        trace=trace,
    )

    # Iterate over shapes and benchmark; one CSV row per (M, K) holds all ops.
    benchmark_results = []
    csv = []
    for M, K in MK:
        quantize_measurements = benchmark(
            quantize_ops,
            M,
            K,
            mem_bw_roofline_gbps,
            opts,
        )
        benchmark_results.extend(quantize_measurements)
        csv_row = {}
        for metric in quantize_measurements:
            csv_row.update(metric.as_dict())
        csv.append(csv_row)

    print(Metrics.header())
    for metric in benchmark_results:
        print(metric)

    print("")
    print(f"Hardware: {torch.cuda.get_device_name()}")
    print(f" Memory BW Roofline: {mem_bw_roofline_gbps} GB/s")

    print("")
    print("Benchmark Settings:")
    print(f" CUDA graph: {opts.cuda_graph}")
    print(f" Buffer rotation: {opts.rotating_buffer}")

    if export_csv:
        os.makedirs(output_dir, exist_ok=True)
        datetime_str = datetime.now().strftime("%Y%m%d_%H%M%S")
        csv_file = os.path.join(
            output_dir, f"quantize_ops_benchmark_{datetime_str}.csv"
        )
        # Export results to a CSV file.
        df = pd.DataFrame(csv)
        df.to_csv(csv_file, na_rep="NaN", index=False)
        print(f"CSV saved to {csv_file}")
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
# Run the click command when this file is executed as a script.
if __name__ == "__main__":
    invoke_main()
|
|
@@ -0,0 +1,266 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
# pyre-strict
|
|
8
|
+
|
|
9
|
+
import abc
|
|
10
|
+
from typing import Any, TypeVar
|
|
11
|
+
|
|
12
|
+
import torch
|
|
13
|
+
from mslk.bench.common.utils import BenchOptions, do_bench
|
|
14
|
+
from mslk.quantize.triton.fp4_quantize import triton_quantize_nvfp4
|
|
15
|
+
from mslk.quantize.triton.fp8_quantize import (
|
|
16
|
+
dequantize_fp8_block,
|
|
17
|
+
dequantize_fp8_row,
|
|
18
|
+
triton_quantize_fp8_block,
|
|
19
|
+
triton_quantize_fp8_group,
|
|
20
|
+
triton_quantize_fp8_row,
|
|
21
|
+
triton_quantize_fp8_tensor,
|
|
22
|
+
)
|
|
23
|
+
from mslk.test.quantize.triton.fp4_quantize_test import (
|
|
24
|
+
dequantize_nvfp4,
|
|
25
|
+
global_scale_nvfp4,
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class QuantizeOpBase(metaclass=abc.ABCMeta):
    """Helper abstract class to define expected methods of quantize ops.

    Subclasses must implement:
    - ``quantize`` and ``dequantize``;
    - the ``hip`` / ``cuda`` properties, which declare platform support.

    ``preprocess`` may return extra positional arguments. ``benchmark``
    forwards them to ``quantize`` as ``quantize(input, *args)``, so ops that
    override ``preprocess`` should accept them.
    """

    @abc.abstractmethod
    def quantize(self, input: torch.Tensor) -> Any:
        """Function which quantizes inputs."""
        pass

    @abc.abstractmethod
    def dequantize(self, *args: Any) -> torch.Tensor:
        """Function which dequantizes inputs. Used for sanity checking."""
        pass

    # ``abc.abstractproperty`` is deprecated; stacking ``property`` on
    # ``abstractmethod`` is the supported equivalent.
    @property
    @abc.abstractmethod
    def hip(self) -> bool:
        """Whether this operator supports AMD or not."""
        pass

    @property
    @abc.abstractmethod
    def cuda(self) -> bool:
        """Whether this operator supports Nvidia or not."""
        pass

    def preprocess(self, input: torch.Tensor) -> Any:
        """This is used for ops that require additional preprocessing. This method will not be counted in benchmarking."""
        return ()

    @property
    def name(self) -> str:
        """Name of this operator (its class name)."""
        return self.__class__.__name__

    @property
    def supported(self) -> bool:
        """Whether this op will run on the current device.

        ROCm builds use ``hip``, CUDA builds use ``cuda``, and any other build
        is unsupported.
        """
        if torch.version.hip is not None:
            return self.hip
        elif torch.version.cuda is not None:
            return self.cuda
        else:
            return False

    def benchmark(
        self,
        input: torch.Tensor,
        args: Any,
        opts: BenchOptions,
    ) -> float:
        """Benchmark runtime of this operator using do_bench from common.

        ``args`` is the tuple returned by ``preprocess``; it is unpacked after
        ``input`` in each timed ``quantize`` call.
        """
        return do_bench(
            lambda inp, *a: self.quantize(inp, *a),
            (input, *args),
            opts,
        )
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
# Maps each registered op's class name to one instance of that op; filled in
# by the ``register_op`` decorator.
op_registry: dict[str, QuantizeOpBase] = {}

# Bound type variable so ``register_op`` returns the same concrete class type
# it was given.
T = TypeVar("T", bound=QuantizeOpBase)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def register_op(op_class: type[T]) -> type[T]:
    """Class decorator that adds an instance of ``op_class`` to ``op_registry``.

    The op is instantiated once, when the decorator runs, and is keyed by its
    class name. A later class with the same name replaces the earlier entry.
    The class itself is returned unchanged.
    """
    instance = op_class()
    op_registry[op_class.__name__] = instance
    return op_class
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def get_ops() -> list[QuantizeOpBase]:
    """Return every registered quantize op instance, in registration order."""
    return [*op_registry.values()]
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
@register_op
class TritonFP8Rowwise(QuantizeOpBase):
    """FP8 quantization through ``triton_quantize_fp8_row``; runs on CUDA and ROCm."""

    def quantize(self, input: torch.Tensor) -> Any:
        # dequantize() expects the two values produced here, unpacked.
        return triton_quantize_fp8_row(input)

    def dequantize(self, *args: Any) -> torch.Tensor:
        q_tensor: torch.Tensor
        scale_tensor: torch.Tensor
        q_tensor, scale_tensor = args
        return dequantize_fp8_row(q_tensor, scale_tensor)

    @property
    def hip(self) -> bool:
        return True

    @property
    def cuda(self) -> bool:
        return True
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
@register_op
class TritonFP8Blockwise(QuantizeOpBase):
    """FP8 quantization through ``triton_quantize_fp8_block`` with 128 x 128 blocks.

    The same ``block_m`` / ``block_k`` are used to quantize and dequantize.
    Runs on CUDA and ROCm.
    """

    def __init__(self) -> None:
        super().__init__()
        self.block_m, self.block_k = 128, 128

    def quantize(self, input: torch.Tensor) -> Any:
        return triton_quantize_fp8_block(
            input, block_m=self.block_m, block_k=self.block_k
        )

    def dequantize(self, *args: Any) -> torch.Tensor:
        q_tensor: torch.Tensor
        block_scale: torch.Tensor
        q_tensor, block_scale = args
        return dequantize_fp8_block(q_tensor, block_scale, self.block_m, self.block_k)

    @property
    def hip(self) -> bool:
        return True

    @property
    def cuda(self) -> bool:
        return True
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
@register_op
class TritonFP8Groupwise(QuantizeOpBase):
    """FP8 quantization through ``triton_quantize_fp8_group`` with groups of 128.

    Runs on CUDA and ROCm.
    """

    def __init__(self) -> None:
        super().__init__()
        # Consecutive elements along dim 1 that share one scale (see dequantize).
        self.group_size = 128

    def quantize(self, input: torch.Tensor) -> Any:
        return triton_quantize_fp8_group(input, group_size=self.group_size)

    def dequantize(self, *args: Any) -> torch.Tensor:
        q_tensor: torch.Tensor
        group_scale: torch.Tensor
        q_tensor, group_scale = args

        values = q_tensor.to(torch.float)
        # Split dim 1 into (num_groups, group_size) so each group broadcasts
        # against its own scale, then restore the original shape.
        # NOTE(review): assumes dim 1 is divisible by group_size.
        grouped = values.view(-1, values.shape[1] // self.group_size, self.group_size)
        scaled = grouped * group_scale.unsqueeze(-1)
        return scaled.view(values.shape)

    @property
    def hip(self) -> bool:
        return True

    @property
    def cuda(self) -> bool:
        return True
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
@register_op
class TritonNVFP4(QuantizeOpBase):
    """NVFP4 quantization through ``triton_quantize_nvfp4``; CUDA only.

    A global scale is computed in ``preprocess``, so it is excluded from the
    timed region. It is then passed to ``quantize`` as its first extra
    argument.

    NOTE(review): ``global_scale_nvfp4`` / ``dequantize_nvfp4`` are imported
    from ``mslk.test...``. No ``mslk/test`` package appears in this wheel's
    file list, so confirm that import resolves in installed builds.
    """

    def preprocess(self, input: torch.Tensor) -> Any:
        return (global_scale_nvfp4(input),)

    def quantize(self, input: torch.Tensor, *args: Any) -> Any:
        global_scale: torch.Tensor = args[0]
        packed, block_scales = triton_quantize_nvfp4(input, global_scale)
        # The packed payload is reinterpreted as uint8. global_scale is
        # returned too so dequantize receives everything it needs.
        return packed.view(torch.uint8), block_scales, global_scale

    def dequantize(self, *args: Any) -> torch.Tensor:
        packed: torch.Tensor
        block_scales: torch.Tensor
        global_scale: torch.Tensor
        packed, block_scales, global_scale = args
        return dequantize_nvfp4(packed, block_scales, global_scale)

    @property
    def hip(self) -> bool:
        return False

    @property
    def cuda(self) -> bool:
        return True
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
@register_op
class CudaFP8Rowwise(QuantizeOpBase):
    """FP8 quantization through the ``torch.ops.mslk.quantize_fp8_per_row`` op.

    Dequantization reuses the Triton ``dequantize_fp8_row`` helper. Runs on
    CUDA and ROCm.
    """

    def quantize(self, input: torch.Tensor) -> Any:
        return torch.ops.mslk.quantize_fp8_per_row(input)

    def dequantize(self, *args: Any) -> torch.Tensor:
        q_tensor: torch.Tensor
        scale_tensor: torch.Tensor
        q_tensor, scale_tensor = args
        return dequantize_fp8_row(q_tensor, scale_tensor)

    @property
    def hip(self) -> bool:
        return True

    @property
    def cuda(self) -> bool:
        return True
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
@register_op
class CudaFP8Tensorwise(QuantizeOpBase):
    """FP8 quantization through the ``torch.ops.mslk.quantize_fp8_per_tensor`` op.

    Runs on CUDA and ROCm.
    """

    def quantize(self, input: torch.Tensor) -> Any:
        return torch.ops.mslk.quantize_fp8_per_tensor(input)

    def dequantize(self, *args: Any) -> torch.Tensor:
        q_tensor: torch.Tensor
        scale_tensor: torch.Tensor
        q_tensor, scale_tensor = args
        # Upcast to float32, then multiply by the (broadcast) scale.
        upcast = q_tensor.to(torch.float32)
        return upcast * scale_tensor

    @property
    def hip(self) -> bool:
        return True

    @property
    def cuda(self) -> bool:
        return True
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
@register_op
class TritonFP8Tensorwise(QuantizeOpBase):
    """FP8 quantization through ``triton_quantize_fp8_tensor``; runs on CUDA and ROCm."""

    def quantize(self, input: torch.Tensor) -> Any:
        return triton_quantize_fp8_tensor(input)

    def dequantize(self, *args: Any) -> torch.Tensor:
        q_tensor: torch.Tensor
        scale_tensor: torch.Tensor
        q_tensor, scale_tensor = args
        # Upcast to float32, then multiply by the (broadcast) scale.
        upcast = q_tensor.to(torch.float32)
        return upcast * scale_tensor

    @property
    def hip(self) -> bool:
        return True

    @property
    def cuda(self) -> bool:
        return True
|
mslk/comm/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
# pyre-strict
|
|
8
|
+
|
|
9
|
+
from mslk.utils.torch.library import load_library_buck
|
|
10
|
+
|
|
11
|
+
# Import-time load of the comm operator library target via load_library_buck;
# presumably this registers the car_ops custom ops (TODO confirm).
load_library_buck("//mslk/csrc/comm:car_ops")
|
mslk/conv/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
# pyre-strict
|
|
8
|
+
|
|
9
|
+
from mslk.utils.torch.library import load_library_buck
|
|
10
|
+
|
|
11
|
+
# Import-time load of the conv operator library target via load_library_buck;
# presumably this registers the conv custom ops (TODO confirm).
load_library_buck("//mslk/csrc/conv:conv_ops")
|
mslk/gemm/__init__.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
# pyre-strict
|
|
8
|
+
|
|
9
|
+
from mslk.utils.torch.library import load_library_buck
|
|
10
|
+
|
|
11
|
+
# Import-time load of the core GEMM operator library target.
load_library_buck("//mslk/csrc/gemm:gemm_ops")

# Extra targets loaded individually after the core library; per their names,
# these are the grouped bf16 CUTLASS grad/wgrad kernels.
gemm_ops = [
    "//mslk/csrc/gemm/cutlass:cutlass_bf16bf16bf16_grouped_grad",
    "//mslk/csrc/gemm/cutlass:cutlass_bf16bf16bf16_grouped_wgrad",
]
for op in gemm_ops:
    load_library_buck(op)
|