mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. mslk/__init__.py +56 -0
  2. mslk/attention/__init__.py +7 -0
  3. mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
  4. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
  5. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
  6. mslk/attention/flash_attn/__init__.py +22 -0
  7. mslk/attention/flash_attn/ampere_helpers.py +104 -0
  8. mslk/attention/flash_attn/barrier.py +72 -0
  9. mslk/attention/flash_attn/benchmark.py +269 -0
  10. mslk/attention/flash_attn/blackwell_helpers.py +754 -0
  11. mslk/attention/flash_attn/block_info.py +109 -0
  12. mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
  13. mslk/attention/flash_attn/block_sparsity.py +219 -0
  14. mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
  15. mslk/attention/flash_attn/copy_utils.py +341 -0
  16. mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
  17. mslk/attention/flash_attn/fast_math.py +22 -0
  18. mslk/attention/flash_attn/flash_bwd.py +1262 -0
  19. mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
  20. mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
  21. mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
  22. mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
  23. mslk/attention/flash_attn/flash_fwd.py +2471 -0
  24. mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
  25. mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
  26. mslk/attention/flash_attn/hopper_helpers.py +102 -0
  27. mslk/attention/flash_attn/interface.py +1771 -0
  28. mslk/attention/flash_attn/mask.py +610 -0
  29. mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
  30. mslk/attention/flash_attn/named_barrier.py +32 -0
  31. mslk/attention/flash_attn/pack_gqa.py +165 -0
  32. mslk/attention/flash_attn/paged_kv.py +176 -0
  33. mslk/attention/flash_attn/pipeline.py +273 -0
  34. mslk/attention/flash_attn/seqlen_info.py +139 -0
  35. mslk/attention/flash_attn/softmax.py +583 -0
  36. mslk/attention/flash_attn/testing.py +424 -0
  37. mslk/attention/flash_attn/tile_scheduler.py +720 -0
  38. mslk/attention/flash_attn/utils.py +860 -0
  39. mslk/attention/fmha/__init__.py +967 -0
  40. mslk/attention/fmha/_triton/__init__.py +6 -0
  41. mslk/attention/fmha/_triton/available.py +50 -0
  42. mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
  43. mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
  44. mslk/attention/fmha/attn_bias.py +2186 -0
  45. mslk/attention/fmha/attn_bias_utils.py +536 -0
  46. mslk/attention/fmha/ck.py +508 -0
  47. mslk/attention/fmha/ck_decoder.py +141 -0
  48. mslk/attention/fmha/ck_splitk.py +204 -0
  49. mslk/attention/fmha/common.py +598 -0
  50. mslk/attention/fmha/cutlass.py +461 -0
  51. mslk/attention/fmha/cutlass_blackwell.py +560 -0
  52. mslk/attention/fmha/dispatch.py +224 -0
  53. mslk/attention/fmha/flash.py +862 -0
  54. mslk/attention/fmha/flash3.py +858 -0
  55. mslk/attention/fmha/flash_mtia.py +245 -0
  56. mslk/attention/fmha/merge_training.py +192 -0
  57. mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
  58. mslk/attention/fmha/torch_attention_compat.py +154 -0
  59. mslk/attention/fmha/tree_attention.py +718 -0
  60. mslk/attention/fmha/triton_splitk.py +1378 -0
  61. mslk/attention/fmha/unbind.py +130 -0
  62. mslk/attention/fmha/utils/__init__.py +6 -0
  63. mslk/attention/fmha/utils/bench.py +74 -0
  64. mslk/attention/fmha/utils/cpp_lib.py +148 -0
  65. mslk/attention/fmha/utils/op_common.py +65 -0
  66. mslk/attention/gqa_attn_splitk/__init__.py +11 -0
  67. mslk/bench/comm/__init__.py +7 -0
  68. mslk/bench/comm/comm_bench.py +255 -0
  69. mslk/bench/common/__init__.py +5 -0
  70. mslk/bench/common/utils.py +148 -0
  71. mslk/bench/conv/__init__.py +7 -0
  72. mslk/bench/conv/conv_bench.py +551 -0
  73. mslk/bench/conv/conv_ops.py +213 -0
  74. mslk/bench/gemm/__init__.py +7 -0
  75. mslk/bench/gemm/gemm_bench.py +859 -0
  76. mslk/bench/gemm/gemm_ops.py +3342 -0
  77. mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
  78. mslk/bench/moe/__init__.py +7 -0
  79. mslk/bench/moe/gather_scatter_bench.py +356 -0
  80. mslk/bench/quantize/quantize_bench.py +345 -0
  81. mslk/bench/quantize/quantize_ops.py +266 -0
  82. mslk/comm/__init__.py +11 -0
  83. mslk/conv/__init__.py +11 -0
  84. mslk/gemm/__init__.py +18 -0
  85. mslk/gemm/triton/__init__.py +7 -0
  86. mslk/gemm/triton/fp8_gemm.py +2702 -0
  87. mslk/gemm/triton/grouped_gemm.py +1132 -0
  88. mslk/gemm/triton/matmul_perf_model.py +237 -0
  89. mslk/gemm/triton/utils.py +128 -0
  90. mslk/kv_cache/__init__.py +11 -0
  91. mslk/moe/__init__.py +26 -0
  92. mslk/moe/activation.py +291 -0
  93. mslk/moe/gather_scatter.py +739 -0
  94. mslk/moe/layers.py +1240 -0
  95. mslk/moe/shuffling.py +421 -0
  96. mslk/mslk.so +0 -0
  97. mslk/quantize/__init__.py +11 -0
  98. mslk/quantize/shuffle.py +306 -0
  99. mslk/quantize/triton/__init__.py +7 -0
  100. mslk/quantize/triton/fp4_quantize.py +5942 -0
  101. mslk/quantize/triton/fp8_quantize.py +1902 -0
  102. mslk/testing/__init__.py +7 -0
  103. mslk/testing/attributes.py +60 -0
  104. mslk/testing/rocm.py +91 -0
  105. mslk/utils/__init__.py +7 -0
  106. mslk/utils/torch/__init__.py +7 -0
  107. mslk/utils/torch/library.py +150 -0
  108. mslk/utils/triton/__init__.py +7 -0
  109. mslk/utils/triton/fp8_utils.py +72 -0
  110. mslk/utils/triton/utils.py +128 -0
  111. mslk/version.py +11 -0
  112. mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
  113. mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
  114. mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
  115. mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
  116. mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0

mslk/bench/common/__init__.py
@@ -0,0 +1,5 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.

mslk/bench/common/utils.py
@@ -0,0 +1,148 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+
+ import contextlib
+ import copy
+ import os
+ import tempfile
+ import uuid
+ from dataclasses import dataclass
+ from typing import Any, Callable
+
+ import torch
+ import triton # @manual=//triton:triton
+ from torch.profiler import profile, ProfilerActivity # pyre-ignore
+
+
+ @dataclass
+ class BenchOptions:
+     """Common benchmark options used across all benchmark scripts.
+
+     This dataclass encapsulates all configuration options for running benchmarks
+     in the MSLK benchmark suite. It is used by gemm_bench, conv_bench, and
+     quantize_bench to maintain consistent benchmarking behavior.
+
+     Attributes:
+         num_iters: Number of iterations to repeat each benchmark for averaging.
+         cuda_graph: Whether to use CUDA graphs for benchmarking. CUDA graphs
+             reduce kernel launch overhead and provide more accurate measurements
+             for GPU-bound workloads.
+         rotating_buffer: Whether to use a rotating buffer during benchmarking.
+             This helps flush L2/L3 cache between iterations to get more realistic
+             memory-bound performance measurements.
+         rep_ms: Repetition time in milliseconds for triton.testing.do_bench.
+             Controls how long each benchmark runs before measuring.
+         trace: Whether to produce a performance trace of the benchmark using
+             PyTorch profiler. Traces are saved to Manifold or temp directory.
+         fast_accum: Whether to enable fast accumulation for FP8 implementations.
+             This is only relevant for Hopper GPUs.
+         torch_compile: Whether to use torch.compile for applicable operations.
+     """
+
+     num_iters: int = 1
+     cuda_graph: bool = True
+     rotating_buffer: bool = False
+     rep_ms: int = 200
+     trace: bool = False
+     fast_accum: bool = False
+     torch_compile: bool = False
+
+
+ def _do_bench(
+     fn: Callable[[], Any],
+     opts: BenchOptions,
+ ) -> float:
+     if opts.cuda_graph:
+         with torch.cuda.stream(torch.cuda.Stream()):
+             return triton.testing.do_bench_cudagraph(fn, rep=opts.rep_ms)
+     else:
+         return triton.testing.do_bench(fn, rep=opts.rep_ms)
+
+
+ def do_bench(
+     fn: Callable[..., Any],
+     args: tuple[Any, ...],
+     opts: BenchOptions,
+ ) -> float:
+     """
+     Benchmark a function using triton's benchmarking utilities.
+
+     Args:
+         fn: The function to benchmark.
+         args: Tuple of arguments to pass to the function.
+         opts: Benchmark options
+
+     Returns:
+         The runtime in milliseconds.
+     """
+     if not opts.rotating_buffer:
+         return _do_bench(lambda: fn(*args), opts)
+
+     # Calculate input size to determine how many copies we need.
+     input_size_bytes = sum(
+         t.element_size() * t.numel() for t in args if isinstance(t, torch.Tensor)
+     )
+
+     # Use a 50MB buffer, this may need to be variable in the future.
+     rotating_buffer_size_bytes = 50 * 1024 * 1024
+     # Make at least one copy of the inputs.
+     copy_cnt = max(rotating_buffer_size_bytes // input_size_bytes, 1)
+
+     args_list = [args]
+     for _ in range(copy_cnt):
+         args_list.append(copy.deepcopy(args))
+
+     # We benchmark on a different stream, so a sync is required.
+     torch.cuda.synchronize()
+
+     def rotating_buffer_fn() -> None:
+         for a in args_list:
+             fn(*a)
+
+     return _do_bench(rotating_buffer_fn, opts) / len(args_list)
+
+
+ def profiler(
+     enabled: bool,
+     with_stack: bool = False,
+     record_shapes: bool = False,
+ ):
+     """
+     Returns a profiler context manager if enabled, otherwise a null context.
+
+     When enabled, profiles CPU and CUDA activities.
+
+     Args:
+         enabled: Whether to enable profiling.
+         with_stack: Whether to record stack traces.
+         record_shapes: Whether to record tensor shapes.
+
+     Returns:
+         A context manager - either a torch profiler or nullcontext.
+     """
+
+     def _kineto_trace_handler(p: torch.profiler.profile) -> None:
+         trace_filename = f"mslk_{os.getpid()}_{uuid.uuid4().hex}.json"
+
+         if os.path.exists("/etc/fbwhoami"):
+             trace_url = f"manifold://gpu_traces/tree/accelerator/{trace_filename}"
+         else:
+             trace_url = os.path.join(tempfile.gettempdir(), trace_filename)
+
+         p.export_chrome_trace(trace_url)
+
+     return (
+         profile(
+             activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], # pyre-ignore
+             on_trace_ready=_kineto_trace_handler,
+             with_stack=with_stack,
+             record_shapes=record_shapes,
+         )
+         if enabled
+         else contextlib.nullcontext()
+     )
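
For orientation, the sketch below shows one way the helpers added in mslk/bench/common/utils.py could be driven. It is illustrative only: the matmul workload, tensor shapes, and option values are arbitrary examples and not part of the package.

# Illustrative sketch (not from the wheel): driving BenchOptions / do_bench / profiler.
import torch

from mslk.bench.common.utils import BenchOptions, do_bench, profiler


def example() -> None:
    # Arbitrary example workload; any callable plus an args tuple works.
    a = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)

    # cuda_graph=True routes to triton.testing.do_bench_cudagraph on a side stream;
    # rotating_buffer=True would instead cycle deep copies of (a, b) between calls
    # to keep the inputs out of L2 cache.
    opts = BenchOptions(num_iters=1, cuda_graph=True, rotating_buffer=False, rep_ms=200)

    # profiler(enabled=False) is a nullcontext; enabled=True would export a Chrome
    # trace through the Kineto handler defined in the file above.
    with profiler(enabled=False):
        ms = do_bench(torch.matmul, (a, b), opts)
    print(f"matmul: {ms:.3f} ms")


if __name__ == "__main__":
    example()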

mslk/bench/conv/__init__.py
@@ -0,0 +1,7 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict

mslk/bench/conv/conv_bench.py
@@ -0,0 +1,551 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import itertools
+ import os
+ import sys
+ from dataclasses import dataclass
+ from datetime import datetime
+ from typing import Any, Optional
+
+ import click
+ import matplotlib.pyplot as plt
+ import pandas as pd
+ import seaborn as sns
+ import torch
+ import triton # @manual=//triton:triton
+ from mslk.bench.common.utils import BenchOptions, profiler
+ from mslk.bench.conv.conv_ops import ConvOpBase, get_conv_ops
+ from tabulate import tabulate
+
+
+ shape_registry = {}
+
+
+ def register_shapes(name):
+     def decorator(op):
+         shape_registry[name] = op
+         return op
+
+     return decorator
+
+
+ @register_shapes("default")
+ def default_shapes() -> list[
+     tuple[int, int, int, int, int, int, int, int, int, int, int, int]
+ ]:
+     """
+     Default convolution shapes for benchmarking.
+
+     Returns tuples of (N, D, H, W, C, K, T, R, S, pad, stride, dilation)
+     where:
+         N = batch size
+         D, H, W = input spatial dimensions (depth, height, width)
+         C = input channels
+         K = output channels (filters)
+         T, R, S = kernel spatial dimensions
+         pad, stride, dilation = convolution parameters (applied uniformly to D, H, W)
+     """
+     shapes = []
+     # Common batch sizes
+     for N in [1, 4, 8, 16]:
+         # Common configurations
+         shapes += [
+             # Small spatial dimensions, various channel sizes
+             (N, 8, 8, 8, 64, 64, 3, 3, 3, 1, 1, 1),
+             (N, 8, 8, 8, 64, 128, 3, 3, 3, 1, 1, 1),
+             (N, 8, 8, 8, 128, 128, 3, 3, 3, 1, 1, 1),
+             # Medium spatial dimensions
+             (N, 16, 16, 16, 64, 64, 3, 3, 3, 1, 1, 1),
+             (N, 16, 16, 16, 64, 128, 3, 3, 3, 1, 1, 1),
+             # Larger spatial dimensions
+             (N, 32, 32, 32, 32, 64, 3, 3, 3, 1, 1, 1),
+             (N, 32, 32, 32, 64, 64, 3, 3, 3, 1, 1, 1),
+             # 1x1x1 convolutions (common in ResNets)
+             (N, 16, 16, 16, 64, 128, 1, 1, 1, 0, 1, 1),
+             (N, 16, 16, 16, 128, 256, 1, 1, 1, 0, 1, 1),
+         ]
+     return shapes
+
+
+ @dataclass
+ class Metrics:
+     op: str
+     N: int = 0
+     D: int = 0
+     H: int = 0
+     W: int = 0
+     C: int = 0
+     K: int = 0
+     T: int = 0
+     R: int = 0
+     S: int = 0
+     pad: int = 0
+     stride: int = 0
+     dilation: int = 0
+
+     sim: float = 0.0
+     ms: float = 0.0
+     tflops: float = 0.0
+     gbps: float = 0.0
+
+     @staticmethod
+     def header() -> str:
+         header = (
+             f"{'OpName':<20} {'(N,D,H,W,C,K,T,R,S,pad,stride,dilation)':<50} "
+             f"{'Sim':<10} {'ms':<10} {'TFLOPS':<10} {'GB/s':<10}"
+         )
+         divider = "-" * len(header)
+         return f"Conv Bench\n{divider}\n{header}\n{divider}"
+
+     def __str__(self) -> str:
+         problem_shape = (
+             f"({self.N},{self.D},{self.H},{self.W},{self.C},{self.K},"
+             f"{self.T},{self.R},{self.S},{self.pad},{self.stride},{self.dilation})"
+         )
+         return (
+             f"{self.op:<20} {problem_shape:<50} "
+             f"{self.sim:<10.3f} {self.ms:<10.3f} {self.tflops:<10.2f} {self.gbps:<10.2f}"
+         )
+
+     def as_dict(self) -> dict[str, float]:
+         return {
+             "N": self.N,
+             "D": self.D,
+             "H": self.H,
+             "W": self.W,
+             "C": self.C,
+             "K": self.K,
+             "T": self.T,
+             "R": self.R,
+             "S": self.S,
+             "pad": self.pad,
+             "stride": self.stride,
+             "dilation": self.dilation,
+             f"{self.op}_sim": self.sim,
+             f"{self.op}_ms": self.ms,
+             f"{self.op}_tflops": self.tflops,
+             f"{self.op}_gb/s": self.gbps,
+         }
+
+
+ def benchmark(
+     conv_ops: list[ConvOpBase],
+     n: int,
+     d: int,
+     h: int,
+     w: int,
+     c: int,
+     k: int,
+     t: int,
+     r: int,
+     s: int,
+     pad: int,
+     stride: int,
+     dilation: int,
+     opts: BenchOptions,
+     bench_quantize: bool = False,
+ ) -> list[Metrics]:
+     # Create input tensors in NCDHW format
+     activation = torch.randn(n, c, d, h, w, device="cuda", dtype=torch.bfloat16)
+     # Create filter tensors in KCTRS format
+     filter = torch.randn(k, c, t, r, s, device="cuda", dtype=torch.bfloat16)
+
+     # Convolution parameters (uniform for all dimensions)
+     padding = [pad, pad, pad]
+     stride_vec = [stride, stride, stride]
+     dilation_vec = [dilation, dilation, dilation]
+
+     # Compute baseline output for correctness checking using PyTorch
+     # Convert to NCDHW for PyTorch
+     out_ref = torch.nn.functional.conv3d(
+         activation,
+         filter,
+         bias=None,
+         stride=stride_vec,
+         padding=padding,
+         dilation=dilation_vec,
+     )
+
+     # Keep track of results.
+     results = []
+
+     # Benchmark each operator.
+     for conv_op in conv_ops:
+         print(
+             f"Benchmarking {conv_op.name} with "
+             f"(N={n}, D={d}, H={h}, W={w}, C={c}, K={k}, T={t}, R={r}, S={s}, "
+             f"pad={pad}, stride={stride}, dilation={dilation})"
+         )
+         metrics = Metrics(
+             op=conv_op.name,
+             N=n,
+             D=d,
+             H=h,
+             W=w,
+             C=c,
+             K=k,
+             T=t,
+             R=r,
+             S=s,
+             pad=pad,
+             stride=stride,
+             dilation=dilation,
+         )
+         if hasattr(conv_op, "torch_compile"):
+             conv_op.torch_compile = opts.torch_compile
+         # Preprocess data if needed.
+         preprocessed_args = conv_op.preprocess(
+             activation, filter, padding, stride_vec, dilation_vec
+         )
+         # Get the quantized tensors for this operator.
+         quantized_vals = conv_op.quantize(*preprocessed_args)
+         # Compute the output given quantized values.
+         output = conv_op.compute(*quantized_vals)
+         # Compare the quantize op output to reference as a sanity check.
+         metrics.sim = torch.mean(torch.pow(output - out_ref, 2)).item()
+
+         # Compute output spatial dimensions
+         z = 1 + (d + 2 * pad - ((t - 1) * dilation + 1)) // stride
+         p = 1 + (h + 2 * pad - ((r - 1) * dilation + 1)) // stride
+         q = 1 + (w + 2 * pad - ((s - 1) * dilation + 1)) // stride
+
+         for _ in range(opts.num_iters):
+             # Now perform benchmark.
+             if bench_quantize:
+                 # Benchmark both quantize and compute.
+                 with profiler(enabled=opts.trace, with_stack=True):
+                     ms_runtime = conv_op.benchmark(
+                         *preprocessed_args,
+                         opts=opts,
+                         bench_quantize=True,
+                     )
+             else:
+                 with profiler(enabled=opts.trace, with_stack=True):
+                     ms_runtime = conv_op.benchmark(
+                         *quantized_vals,
+                         opts=opts,
+                         bench_quantize=False,
+                     )
+
+             # Compute performance metrics
+             # FLOPs for convolution: 2 * N * Z * P * Q * K * T * R * S * C
+             flops = 2 * n * z * p * q * k * t * r * s * c
+             metrics.tflops += flops / (ms_runtime / 1e3) / 1e12
+
+             # Compute memory bandwidth
+             # Input: N * D * H * W * C, Filter: K * T * R * S * C, Output: N * Z * P * Q * K
+             input_size = n * d * h * w * c * quantized_vals[0].element_size()
+             filter_size = k * t * r * s * c * quantized_vals[1].element_size()
+             output_size = n * z * p * q * k * output.element_size()
+
+             metrics.gbps += (
+                 (input_size + filter_size + output_size) / (ms_runtime / 1e3) / 1e9
+             )
+             metrics.ms += ms_runtime
+
+         # Average metrics over iterations.
+         metrics.ms /= opts.num_iters
+         metrics.tflops /= opts.num_iters
+         metrics.gbps /= opts.num_iters
+
+         results.append(metrics)
+
+     return results
+
+
+ def plot_benchmark(results: list[dict[str, Any]], output_dir: str) -> None:
+     """Create a barplot visualizing the TFLOPS of each kernel."""
+     # Reprocess into new dataframe with proper graph format.
+     data = []
+     # Extract measurements for each shape.
+     for impl in results:
+         shape_str = f"N{impl['N']}_D{impl['D']}_H{impl['H']}_W{impl['W']}_C{impl['C']}_K{impl['K']}"
+         # Iterate over keys to find tflops entries.
+         for key in impl:
+             if "tflops" in key:
+                 op_name = key.split("_tflops")[0]
+                 op_tflops = impl[key]
+                 data.append(
+                     {"Shape": shape_str, "kernel": op_name, "TFLOPS": op_tflops}
+                 )
+
+     # Create a barplot using seaborn.
+     df = pd.DataFrame(data)
+     plot = plt.figure()
+     plt.xticks(rotation=30)
+     plt.yscale("log")
+     ax = sns.barplot(x="Shape", y="TFLOPS", hue="kernel", data=df)
+     ax.tick_params(axis="x", labelsize=3)
+     img_fn = os.path.join(output_dir, "conv_ops_benchmark.png")
+     plot.savefig(img_fn, dpi=300)
+     print(f"Plot saved to {img_fn}")
+
+
+ def collect_kernels_to_profile(kernels: Optional[list[str]]) -> list[ConvOpBase]:
+     # Get existing convolution operators.
+     conv_ops = [op for op in get_conv_ops() if op.supported]
+     if kernels is None:
+         return conv_ops
+     return [op for op in conv_ops if op.name in kernels]
+
+
+ def print_kernels(kernels: Optional[list[str]]) -> list[ConvOpBase]:
+     data = sorted(
+         [
+             (op.name, "Yes" if op.cuda else "No", "Yes" if op.hip else "No")
+             for op in get_conv_ops()
+         ]
+     )
+     print(tabulate(data, headers=["Name", "CUDA", "ROCm"], tablefmt="orgtbl"))
+
+
+ @click.command()
+ @click.option(
+     "--output-dir",
+     default="/tmp",
+     help="Directory to save plots and csvs to",
+ )
+ @click.option(
+     "--num-iters",
+     default=1,
+     type=int,
+     help="Number of iterations to repeat each benchmark.",
+ )
+ @click.option(
+     "--export-csv",
+     is_flag=True,
+     help="Export results to a CSV file.",
+ )
+ @click.option(
+     "--plot",
+     is_flag=True,
+     help="Create a plot of the benchmark measurements.",
+ )
+ @click.option(
+     "--bench-quantize",
+     is_flag=True,
+     help="If set, include quantization cost in benchmark.",
+ )
+ @click.option(
+     "--kernels",
+     default=None,
+     help="Comma separated list of kernels to benchmark. Defaults to all kernels.",
+ )
+ @click.option(
+     "--N",
+     default=None,
+     help="Comma separated list of batch sizes to benchmark.",
+ )
+ @click.option(
+     "--D",
+     default=None,
+     help="Comma separated list of depth values to benchmark.",
+ )
+ @click.option(
+     "--H",
+     default=None,
+     help="Comma separated list of height values to benchmark.",
+ )
+ @click.option(
+     "--W",
+     default=None,
+     help="Comma separated list of width values to benchmark.",
+ )
+ @click.option(
+     "--C",
+     default=None,
+     help="Comma separated list of input channel values to benchmark.",
+ )
+ @click.option(
+     "--K",
+     default=None,
+     help="Comma separated list of output channel (filter) values to benchmark.",
+ )
+ @click.option(
+     "--T",
+     default=None,
+     help="Comma separated list of kernel depth values to benchmark.",
+ )
+ @click.option(
+     "--R",
+     default=None,
+     help="Comma separated list of kernel height values to benchmark.",
+ )
+ @click.option(
+     "--S",
+     default=None,
+     help="Comma separated list of kernel width values to benchmark.",
+ )
+ @click.option(
+     "--pad",
+     default=None,
+     help="Comma separated list of padding values to benchmark.",
+ )
+ @click.option(
+     "--stride",
+     default=None,
+     help="Comma separated list of stride values to benchmark.",
+ )
+ @click.option(
+     "--dilation",
+     default=None,
+     help="Comma separated list of dilation values to benchmark.",
+ )
+ @click.option(
+     "--no-cuda-graph",
+     is_flag=True,
+     help="If set, do not use cuda graph for benchmarking.",
+ )
+ @click.option(
+     "--shapes",
+     default=None,
+     help=f"Specific model shapes to use, options: {', '.join(shape_registry.keys())}.",
+ )
+ @click.option(
+     "--trace",
+     is_flag=True,
+     help="If set, produce a performance trace of the benchmark.",
+ )
+ @click.option(
+     "--torch-compile",
+     is_flag=True,
+     help="If set, torch.compile will be used for aten backed ops.",
+ )
+ def invoke_main(
+     output_dir: str,
+     num_iters: int,
+     export_csv: bool,
+     plot: bool,
+     bench_quantize: bool,
+     kernels: Optional[str],
+     n: Optional[str],
+     d: Optional[str],
+     h: Optional[str],
+     w: Optional[str],
+     c: Optional[str],
+     k: Optional[str],
+     t: Optional[str],
+     r: Optional[str],
+     s: Optional[str],
+     pad: Optional[str],
+     stride: Optional[str],
+     dilation: Optional[str],
+     no_cuda_graph: bool,
+     shapes: Optional[str],
+     trace: bool,
+     torch_compile: bool,
+ ):
+     # If kernel filter is provided, parse it. Else, benchmark all kernels.
+     all_kernels = kernels.strip().split(",") if kernels else None
+     conv_ops = collect_kernels_to_profile(all_kernels)
+
+     if len(conv_ops) == 0:
+         print("No valid kernels to benchmark. Available kernels:")
+         print_kernels(all_kernels)
+         sys.exit(1)
+
+     if num_iters < 1:
+         print("Warning: Number of iterations must be at least 1.")
+         num_iters = 1
+
+     # Enumerate shapes to benchmark.
+     if shapes:
+         if shapes not in shape_registry:
+             print(
+                 f"Shape {shapes} not found in shape registry. Valid shapes: {', '.join(shape_registry.keys())}."
+             )
+             sys.exit(1)
+         conv_shapes = shape_registry[shapes]()
+     else:
+         # Parse individual dimension parameters
+         N = [int(n_val) for n_val in n.strip().split(",")] if n else [1, 4, 8]
+         D = [int(d_val) for d_val in d.strip().split(",")] if d else [8, 16]
+         H = [int(h_val) for h_val in h.strip().split(",")] if h else [8, 16]
+         W = [int(w_val) for w_val in w.strip().split(",")] if w else [8, 16]
+         C = [int(c_val) for c_val in c.strip().split(",")] if c else [64, 128]
+         K = [int(k_val) for k_val in k.strip().split(",")] if k else [64, 128]
+         T = [int(t_val) for t_val in t.strip().split(",")] if t else [3]
+         R = [int(r_val) for r_val in r.strip().split(",")] if r else [3]
+         S = [int(s_val) for s_val in s.strip().split(",")] if s else [3]
+         Pad = [int(p_val) for p_val in pad.strip().split(",")] if pad else [1]
+         Stride = (
+             [int(st_val) for st_val in stride.strip().split(",")] if stride else [1]
+         )
+         Dilation = (
+             [int(di_val) for di_val in dilation.strip().split(",")] if dilation else [1]
+         )
+
+         # Create all combinations
+         conv_shapes = list(
+             itertools.product(N, D, H, W, C, K, T, R, S, Pad, Stride, Dilation)
+         )
+
+     # Iterate over shapes and benchmark.
+     benchmark_results = []
+     csv = []
+     opts = BenchOptions(
+         num_iters=num_iters,
+         cuda_graph=not no_cuda_graph,
+         trace=trace,
+         torch_compile=torch_compile,
+     )
+
+     for n, d, h, w, c, k, t, r, s, pad, stride, dilation in conv_shapes:
+         conv_measurements = benchmark(
+             conv_ops,
+             n,
+             d,
+             h,
+             w,
+             c,
+             k,
+             t,
+             r,
+             s,
+             pad,
+             stride,
+             dilation,
+             opts,
+             bench_quantize,
+         )
+         benchmark_results.extend(conv_measurements)
+         csv_row = {}
+         for metric in conv_measurements:
+             csv_row.update(metric.as_dict())
+         csv.append(csv_row)
+
+     print(Metrics.header())
+     for metric in benchmark_results:
+         print(metric)
+
+     mem_bw_roofline_gbps = triton.testing.get_dram_gbps()
+
+     print("")
+     print(f"Hardware: {torch.cuda.get_device_name()}")
+     print(f" Memory BW Roofline: {mem_bw_roofline_gbps} GB/s")
+
+     print("")
+     print("Benchmark Settings:")
+     print(f" CUDA graph: {opts.cuda_graph}")
+     print(f" Bench quantize: {bench_quantize}")
+     print(f" Torch compile: {opts.torch_compile}")
+
+     if export_csv or plot:
+         os.makedirs(output_dir, exist_ok=True)
+     if export_csv:
+         datetime_str = datetime.now().strftime("%Y%m%d_%H%M%S")
+         csv_file = os.path.join(output_dir, f"conv_ops_benchmark_{datetime_str}.csv")
+         # Export results to a CSV file.
+         df = pd.DataFrame(csv)
+         df.to_csv(csv_file, na_rep="NaN", index=False)
+         print(f"CSV saved to {csv_file}")
+     if plot:
+         plot_benchmark(csv, output_dir)
+
+
+ if __name__ == "__main__":
+     invoke_main() # pragma: no cover
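
As a sanity check on the metric formulas in conv_bench.py above, here is the same output-shape and FLOP arithmetic applied to one of the default shapes. The chosen shape, the 0.1 ms runtime, and the module invocation in the comment are illustrative assumptions, not measured results or a documented entry point.

# Worked example of the arithmetic used by benchmark() above (illustrative only).
# A run might look like: python -m mslk.bench.conv.conv_bench --shapes default --export-csv
# (assumed invocation based on the __main__ guard; adjust to the installed layout).
N, D, H, W, C, K, T, R, S, pad, stride, dilation = (1, 16, 16, 16, 64, 128, 3, 3, 3, 1, 1, 1)

# Output spatial dims: 1 + (in + 2*pad - ((kernel - 1)*dilation + 1)) // stride
Z = 1 + (D + 2 * pad - ((T - 1) * dilation + 1)) // stride  # 16
P = 1 + (H + 2 * pad - ((R - 1) * dilation + 1)) // stride  # 16
Q = 1 + (W + 2 * pad - ((S - 1) * dilation + 1)) // stride  # 16

# FLOPs = 2 * N * Z * P * Q * K * T * R * S * C
flops = 2 * N * Z * P * Q * K * T * R * S * C
print(flops)  # 1811939328, i.e. ~1.8 GFLOP per forward pass

# At a hypothetical runtime of 0.1 ms, the reported throughput would be
# flops / (0.1 / 1e3) / 1e12 ~= 18.1 TFLOPS.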