wafer-core 0.1.31__py3-none-any.whl → 0.1.32__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,165 @@
1
+ """Roofline analysis for kernel performance.
2
+
3
+ Calculates what percentage of peak hardware performance a kernel achieves.
4
+ """
5
+
6
+ from typing import Literal
7
+
8
+ from wafer_core.tools.dispatch_baseline.dtypes import HardwareSpec, KernelInfo, RooflineAnalysis
9
+
10
+
11
+ # Hardware specifications for bare metal GPUs we have access to
12
+ # Note: TFLOPS values are theoretical peak (without sparsity for realistic comparison)
13
+ HARDWARE_SPECS: dict[str, HardwareSpec] = {
14
+ "B200": HardwareSpec(
15
+ name="B200",
16
+ peak_fp16_tflops=2250.0, # 4500 with sparsity
17
+ peak_fp32_tflops=18.0,
18
+ peak_memory_bw_tbps=8.0, # HBM3e
19
+ peak_fp8_tflops=4500.0,
20
+ peak_int8_tops=4500.0,
21
+ shared_memory_per_sm_kb=228.0,
22
+ ),
23
+ "MI300X": HardwareSpec(
24
+ name="MI300X",
25
+ peak_fp16_tflops=1307.4, # 2614.9 with sparsity
26
+ peak_fp32_tflops=163.4,
27
+ peak_memory_bw_tbps=5.3,
28
+ peak_fp8_tflops=2614.9,
29
+ peak_int8_tops=2614.9,
30
+ shared_memory_per_sm_kb=64.0, # per CU
31
+ ),
32
+ }
33
+
34
+
35
def get_hardware_spec(hardware: str) -> HardwareSpec | None:
    """Look up a hardware specification by loosely-matched name.

    Args:
        hardware: Hardware name (e.g., "B200", "MI300X"); case-insensitive,
            and common form-factor suffixes ("-SXM", "-PCIE") are ignored.

    Returns:
        The matching HardwareSpec, or None when the name is unknown.
    """
    key = hardware.upper()
    # Strip form-factor decorations so "B200-SXM" still hits the table.
    for suffix in ("-SXM", "-PCIE"):
        key = key.replace(suffix, "")
    return HARDWARE_SPECS.get(key)
47
+
48
+
49
def compute_roofline(
    kernel: KernelInfo,
    hardware: str,
    flops_per_call: float,
    bytes_per_call: float,
) -> RooflineAnalysis | None:
    """Compute roofline analysis for a kernel.

    Note: compute_pct_of_peak is measured against the FP16 tensor peak, so
    kernels running in other precisions (FP8, FP32) will report a skewed
    percentage.

    Args:
        kernel: Kernel information with duration (duration_us must be > 0)
        hardware: Hardware name (e.g., "B200", "MI300X")
        flops_per_call: Total floating point operations per kernel call
        bytes_per_call: Total bytes transferred per kernel call

    Returns:
        RooflineAnalysis if the hardware spec is known and the kernel has a
        positive duration, None otherwise
    """
    hw_spec = get_hardware_spec(hardware)
    if hw_spec is None:
        return None

    # A zero/negative duration would make the throughput math meaningless.
    if kernel.duration_us <= 0:
        return None

    # Achieved throughput: us -> s, FLOPs -> TFLOPs, bytes -> TB.
    duration_sec = kernel.duration_us / 1e6
    achieved_tflops = (flops_per_call / 1e12) / duration_sec
    achieved_tbps = (bytes_per_call / 1e12) / duration_sec

    # Percentage of theoretical peak (FP16 peak — see docstring note).
    compute_pct = (achieved_tflops / hw_spec.peak_fp16_tflops) * 100
    memory_pct = (achieved_tbps / hw_spec.peak_memory_bw_tbps) * 100

    # Arithmetic intensity in FLOPs/byte; 0.0 when no bytes were moved.
    arithmetic_intensity = flops_per_call / bytes_per_call if bytes_per_call > 0 else 0.0

    # Ridge point = peak_flops / peak_bandwidth (FLOPs/byte).  TFLOPS/TBps
    # equals FLOPS/Bps, so the 1e12 scale factors cancel.
    ridge_point = hw_spec.peak_fp16_tflops / hw_spec.peak_memory_bw_tbps

    # A ±20% dead band around the ridge point is classified as "balanced".
    if arithmetic_intensity < ridge_point * 0.8:
        bottleneck: Literal["compute", "memory", "balanced"] = "memory"
    elif arithmetic_intensity > ridge_point * 1.2:
        bottleneck = "compute"
    else:
        bottleneck = "balanced"

    return RooflineAnalysis(
        achieved_tflops=achieved_tflops,
        achieved_memory_bw_tbps=achieved_tbps,
        compute_pct_of_peak=compute_pct,
        memory_bw_pct_of_peak=memory_pct,
        bottleneck=bottleneck,
        arithmetic_intensity=arithmetic_intensity,
    )
104
+
105
+
106
def estimate_matmul_flops(m: int, n: int, k: int) -> float:
    """Estimate FLOPs for C[M,N] = A[M,K] @ B[K,N].

    Each output element needs K multiply-adds, counted as 2 ops each,
    giving 2 * M * N * K total.
    """
    return float(2 * m * n * k)
113
+
114
+
115
def estimate_matmul_bytes(m: int, n: int, k: int, dtype_bytes: int = 2) -> float:
    """Estimate minimum bytes moved for C[M,N] = A[M,K] @ B[K,N].

    Lower bound assuming each operand touches memory exactly once:
    read A (M*K) + read B (K*N) + write C (M*N), scaled by element size.
    """
    total_elements = m * k + k * n + m * n
    return float(total_elements * dtype_bytes)
124
+
125
+
126
def estimate_softmax_flops(elements: int) -> float:
    """Estimate FLOPs for softmax over the given element count.

    Each element takes an exp, participates in the sum reduction, and is
    divided by the total — roughly 5 ops per element (conservative).
    """
    ops_per_element = 5.0
    return ops_per_element * elements
133
+
134
+
135
def estimate_softmax_bytes(elements: int, dtype_bytes: int = 2) -> float:
    """Estimate bytes moved by softmax: one read and one write per element."""
    passes = 2  # read input + write output
    return float(passes * elements * dtype_bytes)
141
+
142
+
143
def estimate_attention_flops(batch: int, heads: int, seq_len: int, head_dim: int) -> float:
    """Estimate FLOPs for a full attention computation.

    Three stages per (batch, head):
      scores   = Q @ K^T : 2 * seq_len^2 * head_dim
      softmax            : ~5 * seq_len^2
      weighted = P @ V   : 2 * seq_len^2 * head_dim
    """
    score_flops = 2.0 * batch * heads * seq_len * seq_len * head_dim
    softmax_flops = 5.0 * batch * heads * seq_len * seq_len
    weighted_flops = 2.0 * batch * heads * seq_len * head_dim * seq_len
    return score_flops + softmax_flops + weighted_flops
154
+
155
+
156
def estimate_attention_bytes(
    batch: int, heads: int, seq_len: int, head_dim: int, dtype_bytes: int = 2
) -> float:
    """Estimate bytes moved by attention.

    Four tensors of shape (batch, heads, seq_len, head_dim) touch memory:
    Q, K, and V are read, and the output is written.
    """
    per_tensor = batch * heads * seq_len * head_dim * dtype_bytes
    return float(4 * per_tensor)