wafer-core 0.1.26__py3-none-any.whl → 0.1.28__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,349 @@
1
+ # Kernel Pattern Registry
2
+ # Version: 2025-01
3
+ # Last updated: 2025-01-28
4
+ # Update when: New GPU architecture, new library version, new model architecture
5
+
6
+ version: "2025-01"
7
+
8
+ # ============================================================================
9
+ # SUPPORTED HARDWARE
10
+ # ============================================================================
11
+ # NVIDIA:
12
+ # - SM100 (Blackwell): B200, B100
13
+ # - SM90 (Hopper): H100, H200
14
+ # - SM89 (Ada Lovelace): L40, RTX 4090
15
+ # - SM80 (Ampere): A100, A10, A30
16
+ #
17
+ # AMD:
18
+ # - CDNA 4 (gfx950): MI355X
19
+ # - CDNA 3 (gfx942): MI300X, MI300A, MI325X
20
+ # - CDNA 2 (gfx90a): MI250X, MI210
21
+ #
22
+ # Note: MI325X uses same gfx942 ISA as MI300X but with 256GB HBM3e memory
23
+ # ============================================================================
24
+
25
+ attention:
26
+ nvidia:
27
+ # SM100 (Blackwell B200/B100) - 'a' suffix = prefill/context, 'f' suffix = decode/forgen
28
+ - pattern: "fmhaSm100a*"
29
+ hardware: "SM100 (Blackwell)"
30
+ library: "Flash Attention 3"
31
+ phase: prefill
32
+ - pattern: "fmhaSm100f*"
33
+ hardware: "SM100 (Blackwell)"
34
+ library: "Flash Attention 3"
35
+ phase: decode
36
+ # SM90 (Hopper H100/H200) - Flash Attention 2/3
37
+ - pattern: "fmhaSm90*"
38
+ hardware: "SM90 (Hopper)"
39
+ library: "Flash Attention 3"
40
+ - pattern: "flash::flash_fwd_kernel*"
41
+ hardware: "SM90 (Hopper)"
42
+ library: "Flash Attention 2"
43
+ phase: prefill
44
+ - pattern: "flash_fwd_*"
45
+ hardware: "SM90 (Hopper)"
46
+ library: "Flash Attention 2"
47
+ - pattern: "fmha_v2_*flash_attention_forward*"
48
+ hardware: "SM90 (Hopper)"
49
+ library: "Flash Attention 2"
50
+ phase: prefill
51
+ - pattern: "fmha_v2_*"
52
+ hardware: "SM90 (Hopper)"
53
+ library: "Flash Attention 2"
54
+ # SM89 (Ada Lovelace L40/RTX 4090)
55
+ - pattern: "fmhaSm89*"
56
+ hardware: "SM89 (Ada Lovelace)"
57
+ library: "Flash Attention"
58
+ # SM80 (Ampere A100/A10)
59
+ - pattern: "fmhaSm80*"
60
+ hardware: "SM80 (Ampere)"
61
+ library: "Flash Attention"
62
+ - pattern: "fmha_*"
63
+ hardware: "SM80 (Ampere)"
64
+ library: "Flash Attention"
65
+ # Generic phase patterns (fallback)
66
+ - pattern: "*Context*"
67
+ phase: prefill
68
+ - pattern: "*context*"
69
+ phase: prefill
70
+ - pattern: "*ForGen*"
71
+ phase: decode
72
+ - pattern: "*forgen*"
73
+ phase: decode
74
+ amd:
75
+ # CDNA 4 (MI355X - gfx950) - Composable Kernel v2
76
+ - pattern: "*ck_fmha_*"
77
+ hardware: "CDNA 4 (MI355X)"
78
+ library: "Composable Kernel"
79
+ - pattern: "*flash_attn_ck*"
80
+ hardware: "CDNA 4 (MI355X)"
81
+ library: "Composable Kernel"
82
+ # CDNA 3 (MI300X/MI325X - gfx942) - Composable Kernel unified attention
83
+ - pattern: "*unified_attention_2d*"
84
+ hardware: "CDNA 3 (MI300X/MI325X)"
85
+ phase: prefill
86
+ library: "Composable Kernel"
87
+ - pattern: "*unified_attention_3d*"
88
+ hardware: "CDNA 3 (MI300X/MI325X)"
89
+ phase: decode
90
+ library: "Composable Kernel"
91
+ - pattern: "kernel_unified_attention_2d*"
92
+ hardware: "CDNA 3 (MI300X/MI325X)"
93
+ phase: prefill
94
+ library: "Composable Kernel"
95
+ - pattern: "kernel_unified_attention_3d*"
96
+ hardware: "CDNA 3 (MI300X/MI325X)"
97
+ phase: decode
98
+ library: "Composable Kernel"
99
+ - pattern: "attention_2d*"
100
+ phase: prefill
101
+ library: "Composable Kernel"
102
+ - pattern: "attention_3d*"
103
+ phase: decode
104
+ library: "Composable Kernel"
105
+ # Triton Flash Attention (works on all AMD GPUs)
106
+ - pattern: "triton_*flash*"
107
+ library: "Triton Flash Attention"
108
+ - pattern: "triton_*attention*"
109
+ library: "Triton"
110
+
111
+ gemm:
112
+ nvidia:
113
+ # cuBLASLt (H100/H200 optimized)
114
+ - pattern: "nvjet_*"
115
+ library: "cuBLASLt"
116
+ hardware: "SM90+ (Hopper/Blackwell)"
117
+ - pattern: "void cublasLt*"
118
+ library: "cuBLASLt"
119
+ # CUTLASS (all architectures)
120
+ - pattern: "cutlass*gemm*"
121
+ library: "CUTLASS 3.x"
122
+ - pattern: "cutlass_*"
123
+ library: "CUTLASS"
124
+ # cuBLAS legacy
125
+ - pattern: "cublas*"
126
+ library: "cuBLAS"
127
+ # FP8 GEMM (H100+ specific)
128
+ - pattern: "*fp8*gemm*"
129
+ library: "cuBLASLt FP8"
130
+ hardware: "SM90+ (Hopper)"
131
+ - pattern: "*e4m3*"
132
+ library: "cuBLASLt FP8"
133
+ hardware: "SM90+ (Hopper)"
134
+ amd:
135
+ # Tensile (all CDNA architectures)
136
+ - pattern: "Cijk_*"
137
+ library: "Tensile"
138
+ - pattern: "Custom_Cijk_*"
139
+ library: "Tensile"
140
+ # hipBLASLt (MI300X/MI325X/MI355X optimized)
141
+ - pattern: "wvSplitK*"
142
+ library: "hipBLASLt"
143
+ hardware: "CDNA 3/4 (MI300X/MI325X/MI355X)"
144
+ - pattern: "hipblaslt*"
145
+ library: "hipBLASLt"
146
+ - pattern: "hipblas*"
147
+ library: "hipBLAS"
148
+ # FP8 GEMM (MI300X+ specific)
149
+ - pattern: "*fp8*"
150
+ library: "hipBLASLt FP8"
151
+ hardware: "CDNA 3+ (MI300X/MI325X/MI355X)"
152
+ # CDNA 4 specific (MI355X - gfx950)
153
+ - pattern: "*gfx950*"
154
+ library: "Tensile"
155
+ hardware: "CDNA 4 (MI355X)"
156
+ # ISA-specific patterns (gfx942 = MI300X/MI325X, gfx950 = MI355X)
157
+ - pattern: "*ISA942*"
158
+ library: "Tensile"
159
+ hardware: "CDNA 3 (MI300X/MI325X)"
160
+ - pattern: "*ISA950*"
161
+ library: "Tensile"
162
+ hardware: "CDNA 4 (MI355X)"
163
+
164
+ ssm:
165
+ both:
166
+ - pattern: "selective_scan*"
167
+ model: "Mamba"
168
+ - pattern: "ssd_*"
169
+ model: "Mamba-2"
170
+ - pattern: "causal_conv1d*"
171
+ model: "Mamba"
172
+ - pattern: "mamba_*"
173
+ model: "Mamba"
174
+
175
+ rmsnorm:
176
+ both:
177
+ # Fused RMSNorm+GEMM patterns (AMD Triton fuses these)
178
+ # Key indicator: *rocm_unquantized_gemm* in kernel name
179
+ - pattern: "triton_*rocm_unquantized_gemm*rsqrt*"
180
+ library: "Triton"
181
+ fused_with: "GEMM"
182
+ - pattern: "triton_*rsqrt*rocm_unquantized_gemm*"
183
+ library: "Triton"
184
+ fused_with: "GEMM"
185
+ - pattern: "triton_*rsqrt*gemm*"
186
+ library: "Triton"
187
+ fused_with: "GEMM"
188
+ - pattern: "triton_*gemm*rsqrt*"
189
+ library: "Triton"
190
+ fused_with: "GEMM"
191
+ # Non-fused RMSNorm (no gemm in name)
192
+ - pattern: "triton_*rsqrt*"
193
+ library: "Triton"
194
+ - pattern: "*rmsnorm*"
195
+ library: "Various"
196
+
197
+ moe:
198
+ both:
199
+ - pattern: "_matmul_ogs_*"
200
+ library: "Triton"
201
+ - pattern: "bmm_*dynbatch*"
202
+ library: "Triton"
203
+ - pattern: "*routing*"
204
+ library: "Various"
205
+ - pattern: "*topk*"
206
+ library: "Various"
207
+ - pattern: "fused_moe_kernel*"
208
+ library: "vLLM"
209
+ - pattern: "*vllm::moe::*"
210
+ library: "vLLM"
211
+ - pattern: "*moe_align_block_size*"
212
+ library: "vLLM"
213
+ - pattern: "*count_and_sort_expert*"
214
+ library: "vLLM"
215
+ - pattern: "*topkGatingSoftmax*"
216
+ library: "vLLM"
217
+
218
+ # Activation functions (SwiGLU, SiLU, etc.)
219
+ activation:
220
+ both:
221
+ # Fused SwiGLU+GEMM (AMD Triton fuses these)
222
+ - pattern: "triton_*rocm_unquantized_gemm*silu*"
223
+ operation: "SwiGLU+GEMM"
224
+ library: "Triton"
225
+ fused_with: "GEMM"
226
+ - pattern: "triton_*silu*rocm_unquantized_gemm*"
227
+ operation: "SwiGLU+GEMM"
228
+ library: "Triton"
229
+ fused_with: "GEMM"
230
+ - pattern: "triton_*gemm*silu*"
231
+ operation: "SwiGLU+GEMM"
232
+ library: "Triton"
233
+ fused_with: "GEMM"
234
+ - pattern: "triton_*silu*gemm*"
235
+ operation: "SwiGLU+GEMM"
236
+ library: "Triton"
237
+ fused_with: "GEMM"
238
+ # Non-fused activation
239
+ - pattern: "*act_and_mul_kernel*"
240
+ operation: "SwiGLU"
241
+ library: "vLLM"
242
+ - pattern: "triton_*silu*"
243
+ operation: "SiLU"
244
+ library: "Triton"
245
+ - pattern: "*silu_kernel*"
246
+ operation: "SiLU"
247
+ library: "vLLM"
248
+ - pattern: "*gelu*"
249
+ operation: "GELU"
250
+ library: "Various"
251
+
252
+ # KV Cache operations
253
+ kv_cache:
254
+ both:
255
+ - pattern: "*reshape_and_cache*"
256
+ library: "vLLM"
257
+ - pattern: "*concat_and_cache*"
258
+ library: "vLLM"
259
+ - pattern: "*cache_mla*"
260
+ library: "vLLM"
261
+
262
+ # Softmax operations
263
+ softmax:
264
+ both:
265
+ - pattern: "*SoftMax*"
266
+ library: "PyTorch"
267
+ - pattern: "*softmax*"
268
+ library: "PyTorch"
269
+
270
+ # Triton fused operations (more specific patterns)
271
+ triton:
272
+ both:
273
+ - pattern: "triton_poi_fused_mul*silu*"
274
+ operation: "SwiGLU"
275
+ library: "Triton"
276
+ - pattern: "triton_poi_fused*"
277
+ operation: "Pointwise"
278
+ library: "Triton"
279
+ - pattern: "triton_red_fused*"
280
+ operation: "Reduction"
281
+ library: "Triton"
282
+ - pattern: "triton_per_fused*"
283
+ operation: "Persistent"
284
+ library: "Triton"
285
+
286
+ # Reduce/Scan operations
287
+ reduce:
288
+ nvidia:
289
+ - pattern: "*cub::*Reduce*"
290
+ library: "CUB"
291
+ - pattern: "*cub::*Scan*"
292
+ library: "CUB"
293
+ - pattern: "*splitKreduce*"
294
+ library: "cuBLASLt"
295
+ note: "GEMM epilogue reduction"
296
+ amd:
297
+ - pattern: "*rocprim::*reduce*"
298
+ library: "rocPRIM"
299
+ - pattern: "*rocprim::*scan*"
300
+ library: "rocPRIM"
301
+ - pattern: "reduce_segments*"
302
+ library: "vLLM"
303
+
304
+ # Sorting operations
305
+ sorting:
306
+ nvidia:
307
+ - pattern: "*RadixSort*"
308
+ library: "CUB"
309
+ - pattern: "*DeviceSort*"
310
+ library: "CUB"
311
+ amd:
312
+ - pattern: "*rocprim::*sort*"
313
+ library: "rocPRIM"
314
+ - pattern: "*rocprim::*merge*"
315
+ library: "rocPRIM"
316
+
317
+ # Memory/Copy operations
318
+ memory:
319
+ both:
320
+ - pattern: "*memcpy*"
321
+ library: "CUDA/HIP Runtime"
322
+ - pattern: "*direct_copy*"
323
+ library: "PyTorch"
324
+ - pattern: "*copy_page_indices*"
325
+ library: "vLLM"
326
+ - pattern: "*rocclr_copyBuffer*"
327
+ library: "AMD ROCclr"
328
+ - pattern: "*rocprim::*transform*"
329
+ library: "rocPRIM"
330
+
331
+ # Indexing/Scatter-Gather operations
332
+ indexing:
333
+ both:
334
+ - pattern: "*scatter_gather*"
335
+ library: "PyTorch"
336
+ - pattern: "*index_elementwise*"
337
+ library: "PyTorch"
338
+ - pattern: "*fill_reverse_indices*"
339
+ library: "PyTorch"
340
+
341
+ # Elementwise operations (fallback patterns)
342
+ elementwise:
343
+ both:
344
+ - pattern: "at::native::*elementwise*"
345
+ library: "PyTorch"
346
+ - pattern: "at::native::*vectorized*"
347
+ library: "PyTorch"
348
+ - pattern: "*distribution_elementwise*"
349
+ library: "PyTorch"
@@ -0,0 +1,114 @@
1
+ """Layer segmentation based on architecture type.
2
+
3
+ Segments kernels into transformer layers based on architecture-specific markers
4
+ (e.g., attention kernels for transformers, SSM scan kernels for Mamba).
5
+ """
6
+
7
+ import bisect
8
+ from typing import Any
9
+
10
+ from .architecture import ArchitectureType
11
+ from .warnings import TraceWarning
12
+
13
+
14
def segment_layers_by_architecture(
    kernels: list[dict[str, Any]],
    architecture: ArchitectureType,
) -> tuple[dict[int, list[dict[str, Any]]], list[TraceWarning]]:
    """Segment kernels into layers based on architecture.

    Layer boundaries are inferred from architecture-specific "marker"
    kernels (prefill/context attention kernels for transformers, scan
    kernels for SSM models). Each marker timestamp opens a new layer that
    extends until the next marker, or to the end of the trace for the
    final layer. Kernels that occur before the first marker are not
    assigned to any layer.

    Args:
        kernels: List of kernel events with 'name', 'ts', and other fields
        architecture: Detected architecture type

    Returns:
        Tuple of (layer_mapping, warnings)
        layer_mapping: Dict mapping layer_num -> list of kernel events
        warnings: List of warnings if segmentation fails or looks suspect
    """
    warnings: list[TraceWarning] = []

    if architecture == ArchitectureType.HYBRID:
        warnings.append(
            TraceWarning(
                code="HYBRID_ARCHITECTURE",
                severity="info",
                message="Hybrid architecture detected (both attention and SSM kernels). Layer segmentation unavailable.",
                suggestion="Hybrid models require custom segmentation logic. Layer analysis will be skipped.",
            )
        )
        return {}, warnings

    if architecture == ArchitectureType.UNKNOWN:
        warnings.append(
            TraceWarning(
                code="UNKNOWN_ARCHITECTURE",
                severity="warning",
                message="Cannot determine model architecture. Layer segmentation unavailable.",
                suggestion="Ensure trace contains recognizable kernel patterns (attention, SSM, etc.).",
            )
        )
        return {}, warnings

    # Collect timestamps of layer-marker kernels. Only the timestamp is
    # needed to define layer boundaries, so kernel names are not retained.
    marker_timestamps: list[int] = []

    for kernel in kernels:
        name_lower = kernel.get("name", "").lower()

        if architecture == ArchitectureType.TRANSFORMER:
            # Prefill/context-phase attention kernels mark the start of a
            # layer ("2d" and "fmhasm100a" are prefill variants — see the
            # kernel pattern registry).
            if any(pattern in name_lower for pattern in ("fmha", "attention", "flash")):
                if "context" in name_lower or "2d" in name_lower or "fmhasm100a" in name_lower:
                    marker_timestamps.append(kernel.get("ts", 0))
        elif architecture == ArchitectureType.SSM:
            if any(pattern in name_lower for pattern in ("selective_scan", "mamba", "ssd")):
                marker_timestamps.append(kernel.get("ts", 0))

    if not marker_timestamps:
        warnings.append(
            TraceWarning(
                code="NO_LAYER_MARKERS",
                severity="warning",
                message=f"No layer marker kernels found for {architecture.value} architecture.",
                suggestion="Ensure trace contains expected kernel patterns for this architecture type.",
            )
        )
        return {}, warnings

    marker_timestamps.sort()

    # Sort kernels by timestamp so each layer's span can be located via
    # binary search rather than a linear scan per marker.
    sorted_kernels = sorted(kernels, key=lambda k: k.get("ts", 0))
    kernel_timestamps = [k.get("ts", 0) for k in sorted_kernels]

    layer_mapping: dict[int, list[dict[str, Any]]] = {}

    for layer_num, ts_start in enumerate(marker_timestamps):
        # A layer spans [this marker, next marker); the last layer extends
        # to the end of the trace.
        start_idx = bisect.bisect_left(kernel_timestamps, ts_start)
        if layer_num + 1 < len(marker_timestamps):
            end_idx = bisect.bisect_left(kernel_timestamps, marker_timestamps[layer_num + 1])
        else:
            end_idx = len(sorted_kernels)

        layer_kernels = sorted_kernels[start_idx:end_idx]

        # Only record non-empty layers; empty layer numbers are skipped.
        if layer_kernels:
            layer_mapping[layer_num] = layer_kernels

    if layer_mapping:
        # Sanity check: flag large variance in per-layer kernel counts.
        # Only non-empty layers are stored, so every count >= 1 and
        # mean_count > 0 (no division-by-zero risk).
        counts = [len(layer_kernels) for layer_kernels in layer_mapping.values()]
        mean_count = sum(counts) / len(counts)
        if any(abs(count - mean_count) / mean_count > 0.3 for count in counts):
            warnings.append(
                TraceWarning(
                    code="LAYER_SIZE_VARIANCE",
                    severity="info",
                    message="Layer kernel counts vary significantly. Segmentation may be inaccurate.",
                    suggestion="This is normal for models with varying layer sizes or non-uniform workloads.",
                )
            )

    return layer_mapping, warnings