kernelmeter 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,28 @@
1
+ """kernelmeter: CUDA device attributes without profiling, and kernel
2
+ benchmarks measured against the hardware's speed of light."""
3
+
4
+ from .bench import REGISTRY, BenchResult, BenchSpec, benchmark, device_peaks, run, run_registry
5
+ from .cudadrv import CudaDriverError, CudaNotAvailableError, Driver
6
+ from . import occupancy, roofline
7
+ from .occupancy import Occupancy
8
+ from .peaks import Peaks
9
+
10
__version__ = "0.2.0"

# Names exported by ``from kernelmeter import *``.
__all__ = [
    # types and driver errors
    "BenchResult", "BenchSpec",
    "CudaDriverError", "CudaNotAvailableError",
    "Driver", "Occupancy", "Peaks",
    # registry, entry points and submodules
    "REGISTRY", "benchmark", "device_peaks",
    "occupancy", "roofline", "run", "run_registry",
    "__version__",
]
kernelmeter/attrs.py ADDED
@@ -0,0 +1,149 @@
1
+ """Enumerate every device attribute the driver knows about.
2
+
3
+ Nsight Compute's ``device__attribute_*`` values come straight from
4
+ ``cuDeviceGetAttribute``, so we ask the driver directly. Every id from 1
5
+ to PROBE_MAX is probed: ids the driver rejects are skipped, ids that
6
+ succeed but aren't in the name table yet are reported as ``attribute_<id>``.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from .cudadrv import DeviceHandle, Driver
12
+
13
# Highest attribute id probed by query_all(). It sits well above the largest
# id named in KNOWN_ATTRS, so ids a newer driver accepts are still reported
# (as ``attribute_<id>``) before this table learns their names.
PROBE_MAX = 192

# CU_DEVICE_ATTRIBUTE_* names (snake_cased), keyed by enum value.
# Deprecated/reserved slots (44, 92-94, ...) are intentionally absent.
# Ids 122-134 are the attributes added in the CUDA 12.x driver headers.
KNOWN_ATTRS: dict[int, str] = {
    1: "max_threads_per_block",
    2: "max_block_dim_x",
    3: "max_block_dim_y",
    4: "max_block_dim_z",
    5: "max_grid_dim_x",
    6: "max_grid_dim_y",
    7: "max_grid_dim_z",
    8: "max_shared_memory_per_block",
    9: "total_constant_memory",
    10: "warp_size",
    11: "max_pitch",
    12: "max_registers_per_block",
    13: "clock_rate_khz",
    14: "texture_alignment",
    15: "gpu_overlap",
    16: "multiprocessor_count",
    17: "kernel_exec_timeout",
    18: "integrated",
    19: "can_map_host_memory",
    20: "compute_mode",
    21: "maximum_texture1d_width",
    22: "maximum_texture2d_width",
    23: "maximum_texture2d_height",
    24: "maximum_texture3d_width",
    25: "maximum_texture3d_height",
    26: "maximum_texture3d_depth",
    27: "maximum_texture2d_layered_width",
    28: "maximum_texture2d_layered_height",
    29: "maximum_texture2d_layered_layers",
    30: "surface_alignment",
    31: "concurrent_kernels",
    32: "ecc_enabled",
    33: "pci_bus_id",
    34: "pci_device_id",
    35: "tcc_driver",
    36: "memory_clock_rate_khz",
    37: "global_memory_bus_width_bits",
    38: "l2_cache_size",
    39: "max_threads_per_multiprocessor",
    40: "async_engine_count",
    41: "unified_addressing",
    42: "maximum_texture1d_layered_width",
    43: "maximum_texture1d_layered_layers",
    45: "maximum_texture2d_gather_width",
    46: "maximum_texture2d_gather_height",
    47: "maximum_texture3d_width_alternate",
    48: "maximum_texture3d_height_alternate",
    49: "maximum_texture3d_depth_alternate",
    50: "pci_domain_id",
    51: "texture_pitch_alignment",
    52: "maximum_texturecubemap_width",
    53: "maximum_texturecubemap_layered_width",
    54: "maximum_texturecubemap_layered_layers",
    55: "maximum_surface1d_width",
    56: "maximum_surface2d_width",
    57: "maximum_surface2d_height",
    58: "maximum_surface3d_width",
    59: "maximum_surface3d_height",
    60: "maximum_surface3d_depth",
    61: "maximum_surface1d_layered_width",
    62: "maximum_surface1d_layered_layers",
    63: "maximum_surface2d_layered_width",
    64: "maximum_surface2d_layered_height",
    65: "maximum_surface2d_layered_layers",
    66: "maximum_surfacecubemap_width",
    67: "maximum_surfacecubemap_layered_width",
    68: "maximum_surfacecubemap_layered_layers",
    69: "maximum_texture1d_linear_width",
    70: "maximum_texture2d_linear_width",
    71: "maximum_texture2d_linear_height",
    72: "maximum_texture2d_linear_pitch",
    73: "maximum_texture2d_mipmapped_width",
    74: "maximum_texture2d_mipmapped_height",
    75: "compute_capability_major",
    76: "compute_capability_minor",
    77: "maximum_texture1d_mipmapped_width",
    78: "stream_priorities_supported",
    79: "global_l1_cache_supported",
    80: "local_l1_cache_supported",
    81: "max_shared_memory_per_multiprocessor",
    82: "max_registers_per_multiprocessor",
    83: "managed_memory",
    84: "multi_gpu_board",
    85: "multi_gpu_board_group_id",
    86: "host_native_atomic_supported",
    87: "single_to_double_precision_perf_ratio",
    88: "pageable_memory_access",
    89: "concurrent_managed_access",
    90: "compute_preemption_supported",
    91: "can_use_host_pointer_for_registered_mem",
    95: "cooperative_launch",
    96: "cooperative_multi_device_launch",
    97: "max_shared_memory_per_block_optin",
    98: "can_flush_remote_writes",
    99: "host_register_supported",
    100: "pageable_memory_access_uses_host_page_tables",
    101: "direct_managed_mem_access_from_host",
    102: "virtual_memory_management_supported",
    103: "handle_type_posix_file_descriptor_supported",
    104: "handle_type_win32_handle_supported",
    105: "handle_type_win32_kmt_handle_supported",
    106: "max_blocks_per_multiprocessor",
    107: "generic_compression_supported",
    108: "max_persisting_l2_cache_size",
    109: "max_access_policy_window_size",
    110: "gpu_direct_rdma_with_cuda_vmm_supported",
    111: "reserved_shared_memory_per_block",
    112: "sparse_cuda_array_supported",
    113: "read_only_host_register_supported",
    114: "timeline_semaphore_interop_supported",
    115: "memory_pools_supported",
    116: "gpu_direct_rdma_supported",
    117: "gpu_direct_rdma_flush_writes_options",
    118: "gpu_direct_rdma_writes_ordering",
    119: "mempool_supported_handle_types",
    120: "cluster_launch",
    121: "deferred_mapping_cuda_array_supported",
    122: "can_use_64_bit_stream_mem_ops",
    123: "can_use_stream_wait_value_nor",
    124: "dma_buf_supported",
    125: "ipc_event_supported",
    126: "mem_sync_domain_count",
    127: "tensor_map_access_supported",
    128: "handle_type_fabric_supported",
    129: "unified_function_pointers",
    130: "numa_config",
    131: "numa_id",
    132: "multicast_supported",
    133: "mps_enabled",
    134: "host_numa_id",
}
136
+
137
+
138
def attr_name(attr_id: int) -> str:
    """Return the snake_case name for a CU_DEVICE_ATTRIBUTE_* id.

    Ids missing from KNOWN_ATTRS are named ``attribute_<id>`` so that values
    the driver reports are never silently dropped.
    """
    if attr_id in KNOWN_ATTRS:
        return KNOWN_ATTRS[attr_id]
    return f"attribute_{attr_id}"
140
+
141
+
142
def query_all(driver: Driver, device: DeviceHandle, probe_max: int = PROBE_MAX) -> dict[str, int]:
    """Return {attribute_name: value} for every attribute this driver supports.

    Ids 1..probe_max (inclusive) are queried in order through
    ``driver.attribute``. A ``None`` result means the id yielded no value
    and it is left out of the mapping.
    """
    supported: dict[str, int] = {}
    for attr_id in range(1, probe_max + 1):
        value = driver.attribute(device, attr_id)
        if value is None:
            continue  # no value for this id (per module docstring: rejected ids are skipped)
        supported[attr_name(attr_id)] = value
    return supported
kernelmeter/bench.py ADDED
@@ -0,0 +1,292 @@
1
+ """Benchmark kernels against references and against the hardware ceiling.
2
+
3
+ Workflow: decorate your implementation with @kernelmeter.benchmark, giving
4
+ it an argument factory, a reference implementation, and (optionally) the
5
+ bytes moved / FLOPs performed per call. Run ``kernelmeter bench yourfile.py``
6
+ and you get latency, achieved GB/s and TFLOP/s, percent of the device's
7
+ theoretical peak, and a numerical correctness check against the reference.
8
+
9
+ torch is imported lazily so the rest of the package (``kernelmeter info``)
10
+ works without it.
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import statistics
16
+ from dataclasses import dataclass
17
+ from typing import Any, Callable
18
+
19
+ from . import attrs as _attrs
20
+ from . import peaks as _peaks
21
+ from . import roofline as _roofline
22
+ from .cudadrv import CudaNotAvailableError, Driver
23
+
24
+
25
@dataclass
class BenchSpec:
    """Everything ``run`` needs to benchmark one kernel.

    Instances are normally created by the ``benchmark`` decorator and
    collected in ``REGISTRY``.
    """

    # Implementation under test; called as fn(*args).
    fn: Callable
    # Label used in results (the decorator falls back to fn.__name__).
    name: str
    # Zero-argument factory; run() calls it once and reuses the returned tuple.
    args: Callable[[], tuple]
    # Optional reference implementation, used for the correctness check and speedup.
    ref: Callable | None = None
    # Work per call: a constant, or a callable that receives the same *args.
    bytes_per_call: Callable[..., int] | int | None = None
    flops_per_call: Callable[..., int] | int | None = None
    peak_tflops: float | None = None  # override for tensor-core work
    # Untimed warm-up calls and timed calls passed to the timing loop.
    warmup: int = 10
    iters: int = 100
36
+
37
+
38
@dataclass
class BenchResult:
    """Outcome of benchmarking one spec (or the error that prevented it)."""

    name: str
    # Mean / median / minimum of the per-iteration timings, in milliseconds.
    ms_mean: float
    ms_median: float
    ms_min: float
    # Achieved throughput, derived from ms_median.
    gbps: float | None = None
    tflops: float | None = None
    # Roofline position: arithmetic intensity, "mem"/"comp" bound, % attainable.
    intensity: float | None = None
    bound: str | None = None
    pct_roofline: float | None = None
    # Throughput as a percentage of the theoretical device peaks.
    pct_peak_bw: float | None = None
    pct_peak_fp32: float | None = None
    # Reference timing and how much faster the kernel is than it.
    ref_ms_median: float | None = None
    speedup_vs_ref: float | None = None
    # Correctness check against the reference, if one was given.
    correct: bool | None = None
    max_abs_err: float | None = None
    # Failure message when the benchmark raised instead of finishing.
    error: str | None = None

    def as_dict(self) -> dict:
        """Return the fields as a new plain dict (a shallow copy)."""
        return dict(vars(self))
59
+
60
+
61
# Specs registered by @benchmark, in decoration order; run_registry() walks it.
REGISTRY: list[BenchSpec] = []


def benchmark(
    name: str | None = None,
    *,
    args: Callable[[], tuple],
    ref: Callable | None = None,
    bytes_per_call: Callable[..., int] | int | None = None,
    flops_per_call: Callable[..., int] | int | None = None,
    peak_tflops: float | None = None,
    warmup: int = 10,
    iters: int = 100,
):
    """Register a function for ``kernelmeter bench``. Also usable directly:
    the decorated function is returned unchanged.

    Every keyword is stored verbatim on the BenchSpec appended to REGISTRY;
    ``name`` defaults to the function's ``__name__``.

    peak_tflops replaces the derived fp32 peak in the roofline when your
    kernel runs on other units (tensor cores, fp16, fp64).
    """

    def register(fn: Callable) -> Callable:
        spec = BenchSpec(
            fn=fn,
            name=name or fn.__name__,
            args=args,
            ref=ref,
            bytes_per_call=bytes_per_call,
            flops_per_call=flops_per_call,
            peak_tflops=peak_tflops,
            warmup=warmup,
            iters=iters,
        )
        REGISTRY.append(spec)
        return fn

    return register
99
+
100
+
101
+ # ---------------------------------------------------------------------------
102
+ # Pure math helpers (unit-tested without a GPU)
103
+ # ---------------------------------------------------------------------------
104
+
105
def summarize_times(times_ms: list[float]) -> tuple[float, float, float]:
    """Reduce per-iteration timings to ``(mean, median, min)``, all in ms."""
    mean_ms = statistics.fmean(times_ms)
    median_ms = statistics.median(times_ms)
    fastest_ms = min(times_ms)
    return mean_ms, median_ms, fastest_ms
111
+
112
+
113
def achieved_gbps(bytes_per_call: int, ms: float) -> float:
    """Convert bytes moved in ``ms`` milliseconds to GB/s (10**9 bytes/s)."""
    seconds = ms * 1e-3
    bytes_per_second = bytes_per_call / seconds
    return bytes_per_second / 1e9
115
+
116
+
117
def achieved_tflops(flops_per_call: int, ms: float) -> float:
    """Convert floating-point operations done in ``ms`` milliseconds to TFLOP/s."""
    seconds = ms * 1e-3
    flops_per_second = flops_per_call / seconds
    return flops_per_second / 1e12
119
+
120
+
121
def pct_of_peak(achieved: float, peak: float | None) -> float | None:
    """Express ``achieved`` as a percentage of ``peak``.

    Returns None when the peak is unknown (None) or zero.
    """
    return 100.0 * achieved / peak if peak else None
125
+
126
+
127
+ def _resolve(metric: Callable[..., int] | int | None, args: tuple) -> int | None:
128
+ if metric is None:
129
+ return None
130
+ if callable(metric):
131
+ return metric(*args)
132
+ return int(metric)
133
+
134
+
135
def roofline_score(
    nbytes: int | None,
    nflops: int | None,
    gbps: float | None,
    tflops: float | None,
    peaks: _peaks.Peaks,
    peak_tflops_override: float | None = None,
) -> tuple[float | None, str | None, float | None]:
    """(intensity, bound, pct of attainable) given whatever metrics exist.

    With only bytes the kernel is treated as memory-bound, with only flops
    as compute-bound. With both, the roofline decides. Any missing (falsy)
    input needed by a case skips that case; if none applies, all three
    results are None.
    """
    compute_peak = peak_tflops_override or peaks.fp32_tflops
    bandwidth_peak = peaks.mem_bandwidth_gbs

    if all((nbytes, nflops, compute_peak, bandwidth_peak, tflops)):
        ai = _roofline.intensity(nflops, nbytes)
        ceiling = _roofline.attainable_tflops(ai, compute_peak, bandwidth_peak)
        limiter = _roofline.bound(ai, compute_peak, bandwidth_peak)
        return ai, limiter, 100.0 * tflops / ceiling

    if all((nbytes, gbps, bandwidth_peak)):
        return None, "mem", 100.0 * gbps / bandwidth_peak

    if all((nflops, tflops, compute_peak)):
        return None, "comp", 100.0 * tflops / compute_peak

    return None, None, None
160
+
161
+
162
def diff_results(baseline: list[dict], results: list["BenchResult"], threshold_pct: float = 5.0):
    """Compare a run against a saved baseline. Returns (rows, regressions)
    where rows are (name, old_ms, new_ms, delta_pct) and regressions lists
    names that got slower by more than the threshold.

    Results that errored, or that have no baseline entry, are skipped. A
    zero baseline median yields a delta of 0.0 rather than dividing by zero.
    """
    baseline_by_name = {entry["name"]: entry for entry in baseline}
    rows = []
    regressions = []
    for result in results:
        if result.error:
            continue
        if result.name not in baseline_by_name:
            continue
        previous_ms = baseline_by_name[result.name]["ms_median"]
        if previous_ms:
            delta_pct = 100.0 * (result.ms_median - previous_ms) / previous_ms
        else:
            delta_pct = 0.0
        rows.append((result.name, previous_ms, result.ms_median, delta_pct))
        if delta_pct > threshold_pct:
            regressions.append(result.name)
    return rows, regressions
178
+
179
+
180
+ # ---------------------------------------------------------------------------
181
+ # GPU execution (requires torch + an NVIDIA device)
182
+ # ---------------------------------------------------------------------------
183
+
184
def _time_fn(fn: Callable, args: tuple, warmup: int, iters: int, flush_l2: bool) -> list[float]:
    """Time ``fn(*args)`` on the GPU; return one duration (ms) per timed call.

    ``warmup`` untimed calls run first, followed by a device synchronize.
    Each of the ``iters`` timed calls is bracketed by its own pair of CUDA
    events, so every returned value is an ``Event.elapsed_time`` reading.
    When ``flush_l2`` is true a 256 MiB scratch buffer is zeroed before each
    timed call; the zeroing is enqueued before the start event, so it is not
    part of the measured interval.
    """
    import torch

    # Untimed warm-up calls, then wait for all of them to finish.
    for _ in range(warmup):
        fn(*args)
    torch.cuda.synchronize()

    # Flushing L2 between iterations keeps memory-bound numbers honest:
    # without it, small workloads get served from cache and report
    # impossible bandwidths. Same approach as triton.testing.do_bench.
    scratch = None
    if flush_l2:
        scratch = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device="cuda")

    begin_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
    end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
    for begin, end in zip(begin_events, end_events):
        if scratch is not None:
            scratch.zero_()
        begin.record()
        fn(*args)
        end.record()
    # Events are only readable once the recorded work has completed.
    torch.cuda.synchronize()
    return [begin.elapsed_time(end) for begin, end in zip(begin_events, end_events)]
206
+
207
+
208
+ def _check_correctness(spec: BenchSpec, args: tuple) -> tuple[bool, float]:
209
+ import torch
210
+
211
+ got = spec.fn(*args)
212
+ want = spec.ref(*args)
213
+ max_err = (got.float() - want.float()).abs().max().item()
214
+ try:
215
+ torch.testing.assert_close(got, want, rtol=1.6e-2, atol=1e-3, check_dtype=False)
216
+ return True, max_err
217
+ except AssertionError:
218
+ return False, max_err
219
+
220
+
221
def device_peaks(device: int = 0) -> _peaks.Peaks:
    """Derive theoretical peaks from one device's driver attributes.

    Args:
        device: ordinal handed to ``Driver.device``. It defaults to 0, the
            device this function always used, so existing calls are unchanged.

    Raises:
        Whatever ``Driver()`` / ``Driver.device`` raise. Callers in this
        package catch ``CudaNotAvailableError`` around this call, which
        suggests that is what a missing driver produces (NOTE(review): confirm
        in cudadrv).
    """
    drv = Driver()
    dev = drv.device(device)
    return _peaks.derive(_attrs.query_all(drv, dev))
225
+
226
+
227
def run(spec: BenchSpec, peaks: _peaks.Peaks | None = None, flush_l2: bool = True) -> BenchResult:
    """Execute one spec on the current CUDA device.

    Steps:
    1. Resolve peaks from the driver (``device_peaks()``), falling back to
       unknown peaks when CUDA is not available.
    2. Build the argument tuple once.
    3. Check correctness against ``spec.ref`` if one is given.
    4. Time ``spec.fn`` and derive throughput and roofline metrics from the
       median time.
    5. Time the reference for a speedup figure.

    Args:
        spec: the benchmark to execute.
        peaks: precomputed device peaks; queried when omitted.
        flush_l2: forwarded to the timing loop (zero a scratch buffer before
            every timed call).

    Returns:
        A BenchResult. Throughput, roofline and speedup fields are None when
        the inputs they need are missing or the median time is 0 ms.
    """
    if peaks is None:
        try:
            peaks = device_peaks()
        except CudaNotAvailableError:
            peaks = _peaks.Peaks(None, None, None)

    # One argument tuple is shared by the correctness check, both timing runs
    # and the bytes/flops callables.
    args = spec.args()

    correct = max_err = None
    if spec.ref is not None:
        correct, max_err = _check_correctness(spec, args)

    times = _time_fn(spec.fn, args, spec.warmup, spec.iters, flush_l2)
    ms_mean, ms_median, ms_min = summarize_times(times)

    nbytes = _resolve(spec.bytes_per_call, args)
    nflops = _resolve(spec.flops_per_call, args)
    # Guard the throughput divisions the same way as the speedup below. A
    # 0 ms median (possible for kernels shorter than the event timer can
    # resolve) would otherwise raise ZeroDivisionError and discard the result.
    has_time = ms_median > 0
    gbps = achieved_gbps(nbytes, ms_median) if nbytes and has_time else None
    tflops = achieved_tflops(nflops, ms_median) if nflops and has_time else None
    ai, kernel_bound, pct_roof = roofline_score(
        nbytes, nflops, gbps, tflops, peaks, spec.peak_tflops
    )

    ref_ms = speedup = None
    if spec.ref is not None:
        ref_times = _time_fn(spec.ref, args, spec.warmup, spec.iters, flush_l2)
        ref_ms = summarize_times(ref_times)[1]
        speedup = ref_ms / ms_median if has_time else None

    return BenchResult(
        name=spec.name,
        ms_mean=ms_mean,
        ms_median=ms_median,
        ms_min=ms_min,
        gbps=gbps,
        tflops=tflops,
        intensity=ai,
        bound=kernel_bound,
        pct_roofline=pct_roof,
        pct_peak_bw=pct_of_peak(gbps, peaks.mem_bandwidth_gbs) if gbps else None,
        pct_peak_fp32=pct_of_peak(tflops, peaks.fp32_tflops) if tflops else None,
        ref_ms_median=ref_ms,
        speedup_vs_ref=speedup,
        correct=correct,
        max_abs_err=max_err,
    )
275
+
276
+
277
def run_registry(flush_l2: bool = True) -> list[BenchResult]:
    """Run every registered spec and return one result per spec, in order.

    Device peaks are queried once and shared by all runs; without CUDA they
    fall back to unknown peaks. If a spec raises, it becomes a BenchResult
    with zeroed timings and a non-empty ``error``, and the remaining specs
    still run.
    """
    try:
        peaks = device_peaks()
    except CudaNotAvailableError:
        peaks = _peaks.Peaks(None, None, None)
    results = []
    for spec in REGISTRY:
        try:
            results.append(run(spec, peaks=peaks, flush_l2=flush_l2))
        except Exception as exc:  # surface per-kernel failures, keep going
            # str() of an exception raised without a message (e.g. a bare
            # AssertionError()) is "". That would store a falsy error, so an
            # `if result.error` check (as in diff_results) would treat the
            # failure as success. Fall back to the exception type name.
            message = str(exc) or type(exc).__name__
            results.append(
                BenchResult(
                    name=spec.name, ms_mean=0.0, ms_median=0.0, ms_min=0.0, error=message
                )
            )
    return results
kernelmeter/ceiling.py ADDED
@@ -0,0 +1,111 @@
1
+ """Measure what the card actually delivers, not what the spec sheet says.
2
+
3
+ Theoretical peaks are computed from max boost clocks and are never fully
4
+ reachable. This module runs the four STREAM kernels (copy, scale, add,
5
+ triad) through torch to find the real bandwidth ceiling, and a large
6
+ TF32-disabled matmul to find the real FP32 ceiling. Judge your kernels
7
+ against these numbers and 100% actually means 100%.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ from dataclasses import dataclass
13
+
14
+ from . import bench as _bench
15
+ from . import peaks as _peaks
16
+
17
+
18
@dataclass
class CeilingResult:
    """One row of the ceiling report: a named test and what it achieved."""

    name: str
    # Time per call in milliseconds (measure() stores the median).
    ms: float
    # Achieved bandwidth (STREAM rows) or compute throughput (matmul row).
    gbps: float | None = None
    tflops: float | None = None
    # Achieved value as a percentage of the theoretical device peak.
    pct_theoretical: float | None = None

    def as_dict(self) -> dict:
        """Return the fields as a new plain dict (a shallow copy)."""
        return {field_name: value for field_name, value in vars(self).items()}
28
+
29
+
30
+ def _stream_specs(n: int):
31
+ """The four STREAM kernels and the bytes each one moves."""
32
+ import torch
33
+
34
+ a = torch.randn(n, device="cuda")
35
+ b = torch.empty_like(a)
36
+ c = torch.empty_like(a)
37
+ q = 2.5
38
+ elt = a.element_size()
39
+ return [
40
+ ("copy", lambda: b.copy_(a), 2 * n * elt),
41
+ ("scale", lambda: torch.mul(a, q, out=b), 2 * n * elt),
42
+ ("add", lambda: torch.add(a, b, out=c), 3 * n * elt),
43
+ ("triad", lambda: torch.add(b, c, alpha=q, out=a), 3 * n * elt),
44
+ ]
45
+
46
+
47
def _median_ms(thunk, warmup: int, iters: int) -> float:
    """Median time in ms of ``thunk()`` under the shared L2-flushing timer."""
    times = _bench._time_fn(lambda *_: thunk(), (), warmup, iters, flush_l2=True)
    return _bench.summarize_times(times)[1]


def measure(
    mb: int = 256, matmul_n: int = 4096, warmup: int = 10, iters: int = 50
) -> list[CeilingResult]:
    """Measure achieved STREAM bandwidth and FP32 matmul throughput.

    Args:
        mb: size of each STREAM array in MiB of fp32 elements.
        matmul_n: side length of the square fp32 matmul.
        warmup: untimed calls before timing each test.
        iters: timed calls per test.

    Returns:
        The four STREAM rows followed by one "fp32 matmul" row. Each row
        carries its percentage of the theoretical peak when that peak is
        known.
    """
    import torch

    try:
        theoretical = _bench.device_peaks()
    except Exception:  # best effort: unknown peaks only leave pct_theoretical as None
        theoretical = _peaks.Peaks(None, None, None)

    elements = mb * 1024 * 1024 // 4  # fp32 elements
    results = []
    for name, thunk, nbytes in _stream_specs(elements):
        ms = _median_ms(thunk, warmup, iters)
        gbps = _bench.achieved_gbps(nbytes, ms)
        results.append(
            CeilingResult(
                name=name,
                ms=ms,
                gbps=gbps,
                pct_theoretical=_bench.pct_of_peak(gbps, theoretical.mem_bandwidth_gbs),
            )
        )

    # Disable TF32 so the matmul stays on the fp32 units; the caller's setting
    # is restored even if the measurement fails.
    previous_tf32 = torch.backends.cuda.matmul.allow_tf32
    torch.backends.cuda.matmul.allow_tf32 = False
    try:
        m = torch.randn(matmul_n, matmul_n, device="cuda")
        out = torch.empty_like(m)
        ms = _median_ms(lambda: torch.mm(m, m, out=out), warmup, iters)
        tflops = _bench.achieved_tflops(2 * matmul_n**3, ms)  # n^3 multiply-adds
        results.append(
            CeilingResult(
                name="fp32 matmul",
                ms=ms,
                tflops=tflops,
                pct_theoretical=_bench.pct_of_peak(tflops, theoretical.fp32_tflops),
            )
        )
    finally:
        torch.backends.cuda.matmul.allow_tf32 = previous_tf32

    return results
94
+
95
+
96
def format_table(results: list[CeilingResult]) -> list[str]:
    """Render ceiling results as aligned text lines.

    Output: a header, a dashed rule, one row per result with "-" for
    missing values and, if any row has a non-zero bandwidth, a blank line
    followed by the best bandwidth seen.
    """

    def cell(value, spec):
        return "-" if value is None else format(value, spec)

    header = f"{'test':<14} {'median ms':>10} {'GB/s':>9} {'TFLOP/s':>9} {'% of theoretical':>17}"
    lines = [header, "-" * len(header)]
    for res in results:
        bw = cell(res.gbps, ".1f")
        flops = cell(res.tflops, ".2f")
        pct = "-" if res.pct_theoretical is None else f"{res.pct_theoretical:.1f}%"
        lines.append(f"{res.name:<14} {res.ms:>10.4f} {bw:>9} {flops:>9} {pct:>17}")

    bandwidths = [res.gbps for res in results if res.gbps]
    if bandwidths:
        best_bw = max(bandwidths)
        lines.extend(
            [
                "",
                f"measured bandwidth ceiling: {best_bw:.1f} GB/s "
                "(use this as the honest 100% for memory-bound kernels)",
            ]
        )
    return lines