mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mslk/__init__.py +56 -0
- mslk/attention/__init__.py +7 -0
- mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
- mslk/attention/flash_attn/__init__.py +22 -0
- mslk/attention/flash_attn/ampere_helpers.py +104 -0
- mslk/attention/flash_attn/barrier.py +72 -0
- mslk/attention/flash_attn/benchmark.py +269 -0
- mslk/attention/flash_attn/blackwell_helpers.py +754 -0
- mslk/attention/flash_attn/block_info.py +109 -0
- mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
- mslk/attention/flash_attn/block_sparsity.py +219 -0
- mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
- mslk/attention/flash_attn/copy_utils.py +341 -0
- mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
- mslk/attention/flash_attn/fast_math.py +22 -0
- mslk/attention/flash_attn/flash_bwd.py +1262 -0
- mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
- mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
- mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
- mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
- mslk/attention/flash_attn/flash_fwd.py +2471 -0
- mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
- mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
- mslk/attention/flash_attn/hopper_helpers.py +102 -0
- mslk/attention/flash_attn/interface.py +1771 -0
- mslk/attention/flash_attn/mask.py +610 -0
- mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
- mslk/attention/flash_attn/named_barrier.py +32 -0
- mslk/attention/flash_attn/pack_gqa.py +165 -0
- mslk/attention/flash_attn/paged_kv.py +176 -0
- mslk/attention/flash_attn/pipeline.py +273 -0
- mslk/attention/flash_attn/seqlen_info.py +139 -0
- mslk/attention/flash_attn/softmax.py +583 -0
- mslk/attention/flash_attn/testing.py +424 -0
- mslk/attention/flash_attn/tile_scheduler.py +720 -0
- mslk/attention/flash_attn/utils.py +860 -0
- mslk/attention/fmha/__init__.py +967 -0
- mslk/attention/fmha/_triton/__init__.py +6 -0
- mslk/attention/fmha/_triton/available.py +50 -0
- mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
- mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
- mslk/attention/fmha/attn_bias.py +2186 -0
- mslk/attention/fmha/attn_bias_utils.py +536 -0
- mslk/attention/fmha/ck.py +508 -0
- mslk/attention/fmha/ck_decoder.py +141 -0
- mslk/attention/fmha/ck_splitk.py +204 -0
- mslk/attention/fmha/common.py +598 -0
- mslk/attention/fmha/cutlass.py +461 -0
- mslk/attention/fmha/cutlass_blackwell.py +560 -0
- mslk/attention/fmha/dispatch.py +224 -0
- mslk/attention/fmha/flash.py +862 -0
- mslk/attention/fmha/flash3.py +858 -0
- mslk/attention/fmha/flash_mtia.py +245 -0
- mslk/attention/fmha/merge_training.py +192 -0
- mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
- mslk/attention/fmha/torch_attention_compat.py +154 -0
- mslk/attention/fmha/tree_attention.py +718 -0
- mslk/attention/fmha/triton_splitk.py +1378 -0
- mslk/attention/fmha/unbind.py +130 -0
- mslk/attention/fmha/utils/__init__.py +6 -0
- mslk/attention/fmha/utils/bench.py +74 -0
- mslk/attention/fmha/utils/cpp_lib.py +148 -0
- mslk/attention/fmha/utils/op_common.py +65 -0
- mslk/attention/gqa_attn_splitk/__init__.py +11 -0
- mslk/bench/comm/__init__.py +7 -0
- mslk/bench/comm/comm_bench.py +255 -0
- mslk/bench/common/__init__.py +5 -0
- mslk/bench/common/utils.py +148 -0
- mslk/bench/conv/__init__.py +7 -0
- mslk/bench/conv/conv_bench.py +551 -0
- mslk/bench/conv/conv_ops.py +213 -0
- mslk/bench/gemm/__init__.py +7 -0
- mslk/bench/gemm/gemm_bench.py +859 -0
- mslk/bench/gemm/gemm_ops.py +3342 -0
- mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
- mslk/bench/moe/__init__.py +7 -0
- mslk/bench/moe/gather_scatter_bench.py +356 -0
- mslk/bench/quantize/quantize_bench.py +345 -0
- mslk/bench/quantize/quantize_ops.py +266 -0
- mslk/comm/__init__.py +11 -0
- mslk/conv/__init__.py +11 -0
- mslk/gemm/__init__.py +18 -0
- mslk/gemm/triton/__init__.py +7 -0
- mslk/gemm/triton/fp8_gemm.py +2702 -0
- mslk/gemm/triton/grouped_gemm.py +1132 -0
- mslk/gemm/triton/matmul_perf_model.py +237 -0
- mslk/gemm/triton/utils.py +128 -0
- mslk/kv_cache/__init__.py +11 -0
- mslk/moe/__init__.py +26 -0
- mslk/moe/activation.py +291 -0
- mslk/moe/gather_scatter.py +739 -0
- mslk/moe/layers.py +1240 -0
- mslk/moe/shuffling.py +421 -0
- mslk/mslk.so +0 -0
- mslk/quantize/__init__.py +11 -0
- mslk/quantize/shuffle.py +306 -0
- mslk/quantize/triton/__init__.py +7 -0
- mslk/quantize/triton/fp4_quantize.py +5942 -0
- mslk/quantize/triton/fp8_quantize.py +1902 -0
- mslk/testing/__init__.py +7 -0
- mslk/testing/attributes.py +60 -0
- mslk/testing/rocm.py +91 -0
- mslk/utils/__init__.py +7 -0
- mslk/utils/torch/__init__.py +7 -0
- mslk/utils/torch/library.py +150 -0
- mslk/utils/triton/__init__.py +7 -0
- mslk/utils/triton/fp8_utils.py +72 -0
- mslk/utils/triton/utils.py +128 -0
- mslk/version.py +11 -0
- mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
- mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
- mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
- mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
- mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
# pyre-unsafe
|
|
8
|
+
|
|
9
|
+
import contextlib
|
|
10
|
+
import copy
|
|
11
|
+
import os
|
|
12
|
+
import tempfile
|
|
13
|
+
import uuid
|
|
14
|
+
from dataclasses import dataclass
|
|
15
|
+
from typing import Any, Callable
|
|
16
|
+
|
|
17
|
+
import torch
|
|
18
|
+
import triton # @manual=//triton:triton
|
|
19
|
+
from torch.profiler import profile, ProfilerActivity # pyre-ignore
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass
class BenchOptions:
    """Shared configuration for the MSLK benchmark scripts.

    gemm_bench, conv_bench and quantize_bench each build one of these so that
    every benchmark honours the same set of knobs.
    """

    # Number of times each benchmark is repeated (results are averaged).
    num_iters: int = 1
    # Time kernels inside a CUDA graph to reduce kernel-launch overhead.
    cuda_graph: bool = True
    # Cycle through copies of the inputs so successive launches do not hit
    # caches warmed by the previous launch (more realistic memory-bound numbers).
    rotating_buffer: bool = False
    # Measurement window in milliseconds, forwarded to triton's benchmarking
    # helpers as `rep`.
    rep_ms: int = 200
    # Capture a PyTorch profiler trace of the benchmark (saved to Manifold or
    # the temp directory).
    trace: bool = False
    # Enable fast accumulation for FP8 implementations (only relevant on Hopper).
    fast_accum: bool = False
    # Use torch.compile for applicable operations.
    torch_compile: bool = False
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def _do_bench(
    fn: Callable[[], Any],
    opts: BenchOptions,
) -> float:
    """Time the zero-argument callable ``fn`` with triton's helpers.

    Without CUDA graphs this is a plain ``triton.testing.do_bench``. With CUDA
    graphs the measurement runs on a freshly created CUDA stream — presumably
    because graph capture is not permitted on the default stream (TODO confirm
    against triton's do_bench_cudagraph).

    Returns:
        The value reported by triton (milliseconds).
    """
    if not opts.cuda_graph:
        return triton.testing.do_bench(fn, rep=opts.rep_ms)
    side_stream = torch.cuda.Stream()
    with torch.cuda.stream(side_stream):
        return triton.testing.do_bench_cudagraph(fn, rep=opts.rep_ms)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def _tensor_nbytes(obj: Any) -> int:
    """Return the total byte size of every tensor in ``obj``.

    Tensors nested inside lists/tuples (e.g. grouped-GEMM style arguments that
    pass a list of tensors) are counted too; any other object contributes 0.
    """
    if isinstance(obj, torch.Tensor):
        return obj.element_size() * obj.numel()
    if isinstance(obj, (list, tuple)):
        return sum(_tensor_nbytes(item) for item in obj)
    return 0


def do_bench(
    fn: Callable[..., Any],
    args: tuple[Any, ...],
    opts: BenchOptions,
) -> float:
    """
    Benchmark a function using triton's benchmarking utilities.

    When ``opts.rotating_buffer`` is set, ``fn`` is run over the original
    ``args`` followed by enough deep copies to span roughly 50MB, so that
    consecutive calls read different memory; the reported time is per call.

    Args:
        fn: The function to benchmark.
        args: Tuple of arguments to pass to the function.
        opts: Benchmark options

    Returns:
        The runtime in milliseconds.
    """
    if not opts.rotating_buffer:
        return _do_bench(lambda: fn(*args), opts)

    # Calculate input size to determine how many copies we need. Tensors nested
    # in list/tuple arguments count as well.
    input_size_bytes = _tensor_nbytes(args)

    # Use a 50MB buffer, this may need to be variable in the future.
    rotating_buffer_size_bytes = 50 * 1024 * 1024
    # Make at least one copy of the inputs. When there are no tensor bytes at
    # all (no tensors, or only empty ones) fall back to a single copy instead of
    # dividing by zero.
    if input_size_bytes > 0:
        copy_cnt = max(rotating_buffer_size_bytes // input_size_bytes, 1)
    else:
        copy_cnt = 1

    args_list = [args]
    args_list.extend(copy.deepcopy(args) for _ in range(copy_cnt))

    # The benchmark may run on a different stream (CUDA graph mode), so make
    # sure all copies have finished before timing starts.
    torch.cuda.synchronize()

    def rotating_buffer_fn() -> None:
        for a in args_list:
            fn(*a)

    # The timed callable makes len(args_list) calls; report the per-call time.
    return _do_bench(rotating_buffer_fn, opts) / len(args_list)
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def profiler(
    enabled: bool,
    with_stack: bool = False,
    record_shapes: bool = False,
):
    """
    Build a context manager that optionally profiles the code it wraps.

    If ``enabled`` is False, a no-op ``contextlib.nullcontext`` is returned.
    Otherwise a ``torch.profiler.profile`` recording CPU and CUDA activity is
    returned. Its trace is exported as JSON named ``mslk_<pid>_<uuid>.json``:
    under ``manifold://gpu_traces/tree/accelerator/`` when ``/etc/fbwhoami``
    exists, and in ``tempfile.gettempdir()`` otherwise.

    Args:
        enabled: Whether to enable profiling.
        with_stack: Whether to record stack traces.
        record_shapes: Whether to record tensor shapes.

    Returns:
        A context manager - either a torch profiler or nullcontext.
    """
    if not enabled:
        return contextlib.nullcontext()

    def _export_trace(prof: torch.profiler.profile) -> None:
        # The pid plus a random uuid keeps repeated/concurrent runs from
        # overwriting each other's traces.
        filename = f"mslk_{os.getpid()}_{uuid.uuid4().hex}.json"
        if os.path.exists("/etc/fbwhoami"):
            destination = f"manifold://gpu_traces/tree/accelerator/{filename}"
        else:
            destination = os.path.join(tempfile.gettempdir(), filename)
        prof.export_chrome_trace(destination)

    return profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],  # pyre-ignore
        on_trace_ready=_export_trace,
        with_stack=with_stack,
        record_shapes=record_shapes,
    )
|
|
@@ -0,0 +1,551 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
import itertools
|
|
8
|
+
import os
|
|
9
|
+
import sys
|
|
10
|
+
from dataclasses import dataclass
|
|
11
|
+
from datetime import datetime
|
|
12
|
+
from typing import Any, Optional
|
|
13
|
+
|
|
14
|
+
import click
|
|
15
|
+
import matplotlib.pyplot as plt
|
|
16
|
+
import pandas as pd
|
|
17
|
+
import seaborn as sns
|
|
18
|
+
import torch
|
|
19
|
+
import triton # @manual=//triton:triton
|
|
20
|
+
from mslk.bench.common.utils import BenchOptions, profiler
|
|
21
|
+
from mslk.bench.conv.conv_ops import ConvOpBase, get_conv_ops
|
|
22
|
+
from tabulate import tabulate
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
# Maps a shape-set name (the value accepted by --shapes) to the zero-argument
# function that produces that set's convolution problem shapes.
shape_registry: dict[str, Any] = {}


def register_shapes(name):
    """Decorator factory: store the decorated shape generator under ``name``.

    The decorated function is recorded in ``shape_registry`` and returned
    unchanged, so it can still be called directly.
    """

    def _register(shape_fn):
        shape_registry[name] = shape_fn
        return shape_fn

    return _register
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@register_shapes("default")
def default_shapes() -> list[
    tuple[int, int, int, int, int, int, int, int, int, int, int, int]
]:
    """
    Default convolution shapes for benchmarking.

    Returns tuples of (N, D, H, W, C, K, T, R, S, pad, stride, dilation)
    where:
        N = batch size
        D, H, W = input spatial dimensions (depth, height, width)
        C = input channels
        K = output channels (filters)
        T, R, S = kernel spatial dimensions
        pad, stride, dilation = convolution parameters (applied uniformly to D, H, W)
    """
    # Every configuration below is emitted once per batch size, batch-major.
    # Columns: (D, H, W, C, K, T, R, S, pad, stride, dilation).
    per_batch_configs = (
        # Small spatial dimensions, various channel sizes
        (8, 8, 8, 64, 64, 3, 3, 3, 1, 1, 1),
        (8, 8, 8, 64, 128, 3, 3, 3, 1, 1, 1),
        (8, 8, 8, 128, 128, 3, 3, 3, 1, 1, 1),
        # Medium spatial dimensions
        (16, 16, 16, 64, 64, 3, 3, 3, 1, 1, 1),
        (16, 16, 16, 64, 128, 3, 3, 3, 1, 1, 1),
        # Larger spatial dimensions
        (32, 32, 32, 32, 64, 3, 3, 3, 1, 1, 1),
        (32, 32, 32, 64, 64, 3, 3, 3, 1, 1, 1),
        # 1x1x1 convolutions (common in ResNets)
        (16, 16, 16, 64, 128, 1, 1, 1, 0, 1, 1),
        (16, 16, 16, 128, 256, 1, 1, 1, 0, 1, 1),
    )
    # Common batch sizes
    batch_sizes = (1, 4, 8, 16)
    return [(batch, *cfg) for batch in batch_sizes for cfg in per_batch_configs]
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
@dataclass
class Metrics:
    """Benchmark result of one operator on one convolution problem shape.

    The shape fields follow the (N, D, H, W, C, K, T, R, S, pad, stride,
    dilation) convention of the shape registry; the trailing fields hold the
    measurements.
    """

    op: str
    N: int = 0
    D: int = 0
    H: int = 0
    W: int = 0
    C: int = 0
    K: int = 0
    T: int = 0
    R: int = 0
    S: int = 0
    pad: int = 0
    stride: int = 0
    dilation: int = 0

    # Mean squared error of the op output against the reference output.
    sim: float = 0.0
    # Runtime in milliseconds.
    ms: float = 0.0
    tflops: float = 0.0
    gbps: float = 0.0

    @staticmethod
    def header() -> str:
        """Return the title, column header and divider lines for the table."""
        columns = (
            f"{'OpName':<20} {'(N,D,H,W,C,K,T,R,S,pad,stride,dilation)':<50} "
            f"{'Sim':<10} {'ms':<10} {'TFLOPS':<10} {'GB/s':<10}"
        )
        rule = "-" * len(columns)
        return "\n".join(("Conv Bench", rule, columns, rule))

    def __str__(self) -> str:
        dims = (
            self.N,
            self.D,
            self.H,
            self.W,
            self.C,
            self.K,
            self.T,
            self.R,
            self.S,
            self.pad,
            self.stride,
            self.dilation,
        )
        problem_shape = "(" + ",".join(f"{v}" for v in dims) + ")"
        return (
            f"{self.op:<20} {problem_shape:<50} "
            f"{self.sim:<10.3f} {self.ms:<10.3f} {self.tflops:<10.2f} {self.gbps:<10.2f}"
        )

    def as_dict(self) -> dict[str, float]:
        """Return a CSV-ready row: shape columns, then op-prefixed metrics."""
        shape_keys = (
            "N",
            "D",
            "H",
            "W",
            "C",
            "K",
            "T",
            "R",
            "S",
            "pad",
            "stride",
            "dilation",
        )
        row = {key: getattr(self, key) for key in shape_keys}
        row[f"{self.op}_sim"] = self.sim
        row[f"{self.op}_ms"] = self.ms
        row[f"{self.op}_tflops"] = self.tflops
        row[f"{self.op}_gb/s"] = self.gbps
        return row
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def benchmark(
    conv_ops: list[ConvOpBase],
    n: int,
    d: int,
    h: int,
    w: int,
    c: int,
    k: int,
    t: int,
    r: int,
    s: int,
    pad: int,
    stride: int,
    dilation: int,
    opts: BenchOptions,
    bench_quantize: bool = False,
) -> list[Metrics]:
    """Benchmark every operator in ``conv_ops`` on a single 3D convolution shape.

    Random bf16 CUDA tensors are created for the activation (N, C, D, H, W) and
    the filter (K, C, T, R, S). ``torch.nn.functional.conv3d`` provides the
    reference output. For each op:
    - its output on the quantized inputs is compared with the reference, and
      the mean squared error is stored as ``sim``;
    - it is timed ``opts.num_iters`` times;
    - ms, TFLOPS and GB/s are averaged over those iterations.

    Args:
        conv_ops: Operators to benchmark.
        n, d, h, w: Batch size and input depth / height / width.
        c, k: Input and output channels.
        t, r, s: Filter depth / height / width.
        pad, stride, dilation: Applied uniformly to all three spatial dims.
        opts: Shared benchmark options.
        bench_quantize: If True, time quantization + compute starting from the
            preprocessed inputs; otherwise time compute only.

    Returns:
        One ``Metrics`` per operator, in ``conv_ops`` order.
    """
    # Activation tensor in NCDHW layout.
    activation = torch.randn(n, c, d, h, w, device="cuda", dtype=torch.bfloat16)
    # Filter tensor in KCTRS layout (named `weight` so the builtin `filter` is
    # not shadowed).
    weight = torch.randn(k, c, t, r, s, device="cuda", dtype=torch.bfloat16)

    # Convolution parameters (uniform for all dimensions)
    padding = [pad, pad, pad]
    stride_vec = [stride, stride, stride]
    dilation_vec = [dilation, dilation, dilation]

    # Reference output for correctness checking. The tensors are already in the
    # NCDHW / KCTRS layouts that conv3d expects, so no conversion is needed.
    out_ref = torch.nn.functional.conv3d(
        activation,
        weight,
        bias=None,
        stride=stride_vec,
        padding=padding,
        dilation=dilation_vec,
    )

    # Output spatial dimensions and FLOP count depend only on the problem
    # shape, so compute them once instead of per operator / per iteration.
    z = 1 + (d + 2 * pad - ((t - 1) * dilation + 1)) // stride
    p = 1 + (h + 2 * pad - ((r - 1) * dilation + 1)) // stride
    q = 1 + (w + 2 * pad - ((s - 1) * dilation + 1)) // stride
    # FLOPs for convolution: 2 * N * Z * P * Q * K * T * R * S * C
    flops = 2 * n * z * p * q * k * t * r * s * c

    # Averaging below divides by the iteration count; clamp it to at least one
    # iteration (mirroring the CLI clamping in invoke_main) to avoid a
    # ZeroDivisionError for direct callers.
    num_iters = max(opts.num_iters, 1)

    # Keep track of results.
    results = []

    # Benchmark each operator.
    for conv_op in conv_ops:
        print(
            f"Benchmarking {conv_op.name} with "
            f"(N={n}, D={d}, H={h}, W={w}, C={c}, K={k}, T={t}, R={r}, S={s}, "
            f"pad={pad}, stride={stride}, dilation={dilation})"
        )
        metrics = Metrics(
            op=conv_op.name,
            N=n,
            D=d,
            H=h,
            W=w,
            C=c,
            K=k,
            T=t,
            R=r,
            S=s,
            pad=pad,
            stride=stride,
            dilation=dilation,
        )
        if hasattr(conv_op, "torch_compile"):
            conv_op.torch_compile = opts.torch_compile
        # Preprocess data if needed.
        preprocessed_args = conv_op.preprocess(
            activation, weight, padding, stride_vec, dilation_vec
        )
        # Get the quantized tensors for this operator.
        quantized_vals = conv_op.quantize(*preprocessed_args)
        # Compute the output given quantized values.
        output = conv_op.compute(*quantized_vals)
        # Compare the quantize op output to reference as a sanity check.
        metrics.sim = torch.mean(torch.pow(output - out_ref, 2)).item()

        # Bytes moved: input + filter (at the op's quantized element sizes) + output.
        input_size = n * d * h * w * c * quantized_vals[0].element_size()
        filter_size = k * t * r * s * c * quantized_vals[1].element_size()
        output_size = n * z * p * q * k * output.element_size()
        total_bytes = input_size + filter_size + output_size

        # With bench_quantize the timed region starts from the preprocessed
        # inputs (quantize + compute); otherwise only compute is timed.
        bench_args = preprocessed_args if bench_quantize else quantized_vals

        for _ in range(num_iters):
            with profiler(enabled=opts.trace, with_stack=True):
                ms_runtime = conv_op.benchmark(
                    *bench_args,
                    opts=opts,
                    bench_quantize=bool(bench_quantize),
                )

            metrics.tflops += flops / (ms_runtime / 1e3) / 1e12
            metrics.gbps += total_bytes / (ms_runtime / 1e3) / 1e9
            metrics.ms += ms_runtime

        # Average metrics over iterations.
        metrics.ms /= num_iters
        metrics.tflops /= num_iters
        metrics.gbps /= num_iters

        results.append(metrics)

    return results
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
def plot_benchmark(results: list[dict[str, Any]], output_dir: str) -> None:
    """Create a barplot visualizing the TFLOPS of each kernel.

    Args:
        results: One dict per benchmarked shape. Each holds the shape keys
            (N, D, H, W, C, K, T, R, S, pad, stride, dilation) plus per-op
            metric keys such as ``<op>_tflops``.
        output_dir: Directory where ``conv_ops_benchmark.png`` is written.
    """
    # Reprocess into new dataframe with proper graph format.
    data = []
    # Extract measurements for each shape.
    for impl in results:
        # Encode every problem dimension in the label. Shapes that differ only
        # in filter size / pad / stride / dilation (e.g. the 3x3x3 vs 1x1x1
        # entries of the default shape set) would otherwise share one x
        # position, and seaborn's barplot would aggregate (average) them.
        shape_str = (
            f"N{impl['N']}_D{impl['D']}_H{impl['H']}_W{impl['W']}"
            f"_C{impl['C']}_K{impl['K']}_T{impl['T']}_R{impl['R']}_S{impl['S']}"
            f"_pad{impl['pad']}_stride{impl['stride']}_dil{impl['dilation']}"
        )
        # Only keys ending in "_tflops" are TFLOPS entries; the prefix is the op name.
        for key, value in impl.items():
            if key.endswith("_tflops"):
                op_name = key[: -len("_tflops")]
                data.append({"Shape": shape_str, "kernel": op_name, "TFLOPS": value})

    if not data:
        # Nothing to draw; seaborn would fail on a DataFrame without these columns.
        print("No TFLOPS measurements to plot.")
        return

    # Create a barplot using seaborn.
    df = pd.DataFrame(data)
    plot = plt.figure()
    plt.xticks(rotation=30)
    plt.yscale("log")
    ax = sns.barplot(x="Shape", y="TFLOPS", hue="kernel", data=df)
    ax.tick_params(axis="x", labelsize=3)
    img_fn = os.path.join(output_dir, "conv_ops_benchmark.png")
    plot.savefig(img_fn, dpi=300)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(plot)
    print(f"Plot saved to {img_fn}")
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
def collect_kernels_to_profile(kernels: Optional[list[str]]) -> list[ConvOpBase]:
    """Return the supported conv ops, optionally restricted to ``kernels``.

    Args:
        kernels: Names of the ops to keep, or None to keep every supported op.

    Returns:
        The supported ops from ``get_conv_ops()``, in their original order.
    """
    supported = [op for op in get_conv_ops() if op.supported]
    if kernels is not None:
        supported = [op for op in supported if op.name in kernels]
    return supported
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
def print_kernels(kernels: Optional[list[str]]) -> None:
    """Print a table of every known conv op and the backends it supports.

    The function only prints and returns None. The original annotation
    claimed ``list[ConvOpBase]``, which was wrong.

    Args:
        kernels: Currently unused. It is accepted for call-site compatibility;
            the table always lists every op returned by ``get_conv_ops()``.
    """
    rows = sorted(
        (op.name, "Yes" if op.cuda else "No", "Yes" if op.hip else "No")
        for op in get_conv_ops()
    )
    print(tabulate(rows, headers=["Name", "CUDA", "ROCm"], tablefmt="orgtbl"))
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
def _parse_int_list(value: Optional[str], default: list[int], option: str) -> list[int]:
    """Parse a comma separated CLI value into ints, or return ``default`` if unset.

    Whitespace around entries and empty entries (e.g. a trailing comma) are
    ignored. Non-integer entries, or a value with no entries at all, are
    reported as a click usage error naming ``option``.
    """
    if not value:
        return list(default)
    items = [item.strip() for item in value.split(",") if item.strip()]
    try:
        parsed = [int(item) for item in items]
    except ValueError:
        raise click.BadParameter(
            f"expected a comma separated list of integers, got {value!r}",
            param_hint=option,
        ) from None
    if not parsed:
        raise click.BadParameter(
            f"expected a comma separated list of integers, got {value!r}",
            param_hint=option,
        )
    return parsed


@click.command()
@click.option(
    "--output-dir",
    default="/tmp",
    help="Directory to save plots and csvs to",
)
@click.option(
    "--num-iters",
    default=1,
    type=int,
    help="Number of iterations to repeat each benchmark.",
)
@click.option(
    "--export-csv",
    is_flag=True,
    help="Export results to a CSV file.",
)
@click.option(
    "--plot",
    is_flag=True,
    help="Create a plot of the benchmark measurements.",
)
@click.option(
    "--bench-quantize",
    is_flag=True,
    help="If set, include quantization cost in benchmark.",
)
@click.option(
    "--kernels",
    default=None,
    help="Comma separated list of kernels to benchmark. Defaults to all kernels.",
)
@click.option(
    "--N",
    default=None,
    help="Comma separated list of batch sizes to benchmark.",
)
@click.option(
    "--D",
    default=None,
    help="Comma separated list of depth values to benchmark.",
)
@click.option(
    "--H",
    default=None,
    help="Comma separated list of height values to benchmark.",
)
@click.option(
    "--W",
    default=None,
    help="Comma separated list of width values to benchmark.",
)
@click.option(
    "--C",
    default=None,
    help="Comma separated list of input channel values to benchmark.",
)
@click.option(
    "--K",
    default=None,
    help="Comma separated list of output channel (filter) values to benchmark.",
)
@click.option(
    "--T",
    default=None,
    help="Comma separated list of kernel depth values to benchmark.",
)
@click.option(
    "--R",
    default=None,
    help="Comma separated list of kernel height values to benchmark.",
)
@click.option(
    "--S",
    default=None,
    help="Comma separated list of kernel width values to benchmark.",
)
@click.option(
    "--pad",
    default=None,
    help="Comma separated list of padding values to benchmark.",
)
@click.option(
    "--stride",
    default=None,
    help="Comma separated list of stride values to benchmark.",
)
@click.option(
    "--dilation",
    default=None,
    help="Comma separated list of dilation values to benchmark.",
)
@click.option(
    "--no-cuda-graph",
    is_flag=True,
    help="If set, do not use cuda graph for benchmarking.",
)
@click.option(
    "--shapes",
    default=None,
    help=f"Specific model shapes to use, options: {', '.join(shape_registry.keys())}.",
)
@click.option(
    "--trace",
    is_flag=True,
    help="If set, produce a performance trace of the benchmark.",
)
@click.option(
    "--torch-compile",
    is_flag=True,
    help="If set, torch.compile will be used for aten backed ops.",
)
def invoke_main(
    output_dir: str,
    num_iters: int,
    export_csv: bool,
    plot: bool,
    bench_quantize: bool,
    kernels: Optional[str],
    n: Optional[str],
    d: Optional[str],
    h: Optional[str],
    w: Optional[str],
    c: Optional[str],
    k: Optional[str],
    t: Optional[str],
    r: Optional[str],
    s: Optional[str],
    pad: Optional[str],
    stride: Optional[str],
    dilation: Optional[str],
    no_cuda_graph: bool,
    shapes: Optional[str],
    trace: bool,
    torch_compile: bool,
):
    """CLI entry point: benchmark conv kernels over a set of problem shapes.

    Shapes come from one of two sources:
    - a named shape set in ``shape_registry`` (``--shapes``), or
    - the cartesian product of the per-dimension options, with unset options
      falling back to built-in defaults.

    Results are printed as a table. They can optionally be exported to CSV
    and/or plotted into ``output_dir``.
    """
    # If kernel filter is provided, parse it. Else, benchmark all kernels.
    # Each entry is stripped so "a, b" selects both "a" and "b".
    all_kernels = (
        [name.strip() for name in kernels.split(",") if name.strip()]
        if kernels
        else None
    )
    conv_ops = collect_kernels_to_profile(all_kernels)

    if len(conv_ops) == 0:
        print("No valid kernels to benchmark. Available kernels:")
        print_kernels(all_kernels)
        sys.exit(1)

    # Tell the user about requested kernels that will not be run, instead of
    # dropping them silently.
    if all_kernels is not None:
        selected = {op.name for op in conv_ops}
        missing = [name for name in all_kernels if name not in selected]
        if missing:
            print(
                f"Warning: skipping unknown or unsupported kernels: {', '.join(missing)}"
            )

    if num_iters < 1:
        print("Warning: Number of iterations must be at least 1.")
        num_iters = 1

    # Enumerate shapes to benchmark.
    if shapes:
        if shapes not in shape_registry:
            print(
                f"Shape {shapes} not found in shape registry. Valid shapes: {', '.join(shape_registry.keys())}."
            )
            sys.exit(1)
        conv_shapes = shape_registry[shapes]()
    else:
        # Cartesian product of the per-dimension lists, ordered as
        # (N, D, H, W, C, K, T, R, S, pad, stride, dilation).
        conv_shapes = list(
            itertools.product(
                _parse_int_list(n, [1, 4, 8], "--N"),
                _parse_int_list(d, [8, 16], "--D"),
                _parse_int_list(h, [8, 16], "--H"),
                _parse_int_list(w, [8, 16], "--W"),
                _parse_int_list(c, [64, 128], "--C"),
                _parse_int_list(k, [64, 128], "--K"),
                _parse_int_list(t, [3], "--T"),
                _parse_int_list(r, [3], "--R"),
                _parse_int_list(s, [3], "--S"),
                _parse_int_list(pad, [1], "--pad"),
                _parse_int_list(stride, [1], "--stride"),
                _parse_int_list(dilation, [1], "--dilation"),
            )
        )

    # Iterate over shapes and benchmark.
    benchmark_results = []
    csv_rows = []
    opts = BenchOptions(
        num_iters=num_iters,
        cuda_graph=not no_cuda_graph,
        trace=trace,
        torch_compile=torch_compile,
    )

    for shape in conv_shapes:
        # `shape` is (N, D, H, W, C, K, T, R, S, pad, stride, dilation), which
        # matches benchmark()'s positional shape parameters.
        conv_measurements = benchmark(conv_ops, *shape, opts, bench_quantize)
        benchmark_results.extend(conv_measurements)
        # All ops of one shape share the shape keys, so merging their dicts
        # gives a single CSV row with one set of columns per op.
        csv_row = {}
        for metric in conv_measurements:
            csv_row.update(metric.as_dict())
        csv_rows.append(csv_row)

    print(Metrics.header())
    for metric in benchmark_results:
        print(metric)

    mem_bw_roofline_gbps = triton.testing.get_dram_gbps()

    print("")
    print(f"Hardware: {torch.cuda.get_device_name()}")
    print(f"    Memory BW Roofline: {mem_bw_roofline_gbps} GB/s")

    print("")
    print("Benchmark Settings:")
    print(f"    CUDA graph: {opts.cuda_graph}")
    print(f"    Bench quantize: {bench_quantize}")
    print(f"    Torch compile: {opts.torch_compile}")

    if export_csv or plot:
        os.makedirs(output_dir, exist_ok=True)
    if export_csv:
        datetime_str = datetime.now().strftime("%Y%m%d_%H%M%S")
        csv_file = os.path.join(output_dir, f"conv_ops_benchmark_{datetime_str}.csv")
        # Export results to a CSV file.
        df = pd.DataFrame(csv_rows)
        df.to_csv(csv_file, na_rep="NaN", index=False)
        print(f"CSV saved to {csv_file}")
    if plot:
        plot_benchmark(csv_rows, output_dir)
|
|
548
|
+
|
|
549
|
+
|
|
550
|
+
if __name__ == "__main__":
    # invoke_main is a click command: click parses the command-line arguments
    # itself, so it is called with no explicit arguments here.
    invoke_main()  # pragma: no cover
|