wafer-core 0.1.25__py3-none-any.whl → 0.1.27__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,486 @@
+ """Kernel classification logic for trace comparison.
+
+ Classifies GPU kernels into operation categories (attention, GEMM, normalization, etc.)
+ based on kernel name patterns and platform-specific conventions.
+
+ Patterns can optionally be loaded from kernel_registry.yaml; the hardcoded patterns
+ below provide comprehensive fallback coverage when the registry is unavailable or has
+ no matching pattern.
+ """
+
+ import fnmatch
+ from enum import Enum
+ from functools import lru_cache
+ from pathlib import Path
+ from typing import Any
+
+ try:
+     import yaml
+ except ImportError:
+     yaml = None
+
+
+ class Op(Enum):
+     """Kernel operation categories."""
+
+     ATTN_PREFILL = "Attention (Prefill)"
+     ATTN_DECODE = "Attention (Decode)"
+     KV_CACHE = "KV Cache"
+     MOE_ROUTING = "MoE Routing"
+     MOE_GEMM = "MoE GEMM"
+     MOE_GEMM_SWIGLU = "MoE GEMM+SwiGLU"
+     MOE_FINALIZE = "MoE Finalize"
+     DENSE_GEMM = "Dense GEMM"
+     RMSNORM = "RMSNorm"
+     RMSNORM_GEMM = "RMSNorm+GEMM"
+     SWIGLU = "SwiGLU"
+     SWIGLU_GEMM = "SwiGLU+GEMM"
+     EMBEDDING_RMSNORM_GEMM = "Embedding+RMSNorm+GEMM"
+     SOFTMAX = "SoftMax"
+     TRITON_FUSED = "Triton Fused"
+     ELEMENTWISE = "Elementwise"
+     SORTING = "Sorting"
+     REDUCE = "Reduce"
+     INDEXING = "Indexing"
+     COPY_MEMORY = "Copy/Memory"
+     FUSED_UNKNOWN = "Fused (Unknown)"  # Heuristically detected fusion
+     OTHER = "Other"
+
+
+ # Keywords that indicate specific operations - used for heuristic fusion detection
+ FUSION_KEYWORDS: dict[str, str] = {
+     # Normalization
+     "rsqrt": "Norm",
+     "rmsnorm": "RMSNorm",
+     "layernorm": "LayerNorm",
+     # GEMM
+     "gemm": "GEMM",
+     "matmul": "GEMM",
+     "mm_": "GEMM",
+     # Activations
+     "silu": "SiLU",
+     "swiglu": "SwiGLU",
+     "gelu": "GELU",
+     "relu": "ReLU",
+     # Other ops
+     "softmax": "Softmax",
+     "attention": "Attention",
+     "embedding": "Embedding",
+     "reduce": "Reduce",
+ }
+
+
+ # Load kernel registry YAML if available
+ _KERNEL_REGISTRY: dict[str, Any] | None = None
+
+
+ def _load_kernel_registry() -> dict[str, Any] | None:
+     """Load kernel pattern registry from YAML file."""
+     global _KERNEL_REGISTRY
+
+     if _KERNEL_REGISTRY is not None:
+         return _KERNEL_REGISTRY
+
+     if yaml is None:
+         return None
+
+     registry_path = Path(__file__).parent / "kernel_registry.yaml"
+     if not registry_path.exists():
+         return None
+
+     try:
+         with open(registry_path) as f:
+             _KERNEL_REGISTRY = yaml.safe_load(f)
+         return _KERNEL_REGISTRY
+     except Exception:
+         return None
+
+
+ def _match_registry_pattern(name: str, category: str, platform: str) -> tuple[Op, str] | None:
+     """Try to match a kernel name against registry patterns.
+
+     Returns (Op, pattern_name) if a match is found, None otherwise.
+     """
+     registry = _load_kernel_registry()
+     if not registry:
+         return None
+
+     nl = name.lower()
+     category_data = registry.get(category, {})
+
+     platform_key = "amd" if platform.lower() == "amd" else "nvidia"
+     patterns = category_data.get(platform_key, [])
+     both_patterns = category_data.get("both", [])
+     patterns = patterns + both_patterns
+
+     for pattern_entry in patterns:
+         pattern = pattern_entry.get("pattern", "")
+         if not pattern:
+             continue
+
+         if fnmatch.fnmatch(name, pattern) or fnmatch.fnmatch(nl, pattern.lower()):
+             if category == "attention":
+                 phase = pattern_entry.get("phase")
+                 if phase == "prefill":
+                     return Op.ATTN_PREFILL, pattern
+                 elif phase == "decode":
+                     return Op.ATTN_DECODE, pattern
+                 return Op.ATTN_PREFILL, pattern
+             elif category == "gemm":
+                 return Op.DENSE_GEMM, pattern
+             elif category == "rmsnorm":
+                 # Detect fused operations (AMD Triton often fuses RMSNorm with GEMM)
+                 has_gemm = "gemm" in nl or "unquantized_gemm" in nl
+                 has_embedding = "embedding" in nl
+                 if has_gemm and has_embedding:
+                     # Embedding + RMSNorm + GEMM all fused together
+                     return Op.EMBEDDING_RMSNORM_GEMM, pattern
+                 elif has_gemm:
+                     # RMSNorm + GEMM fused
+                     return Op.RMSNORM_GEMM, pattern
+                 return Op.RMSNORM, pattern
+             elif category == "moe":
+                 # Distinguish between MoE sub-operations
+                 if "swiglu" in nl:
+                     return Op.MOE_GEMM_SWIGLU, pattern
+                 if any(x in nl for x in ["routing", "topk", "align_block", "count_and_sort", "gating"]):
+                     return Op.MOE_ROUTING, pattern
+                 if "finalize" in nl or "scatter" in nl:
+                     return Op.MOE_FINALIZE, pattern
+                 return Op.MOE_GEMM, pattern
+             elif category == "kv_cache":
+                 return Op.KV_CACHE, pattern
+             elif category == "softmax":
+                 return Op.SOFTMAX, pattern
+             elif category == "reduce":
+                 return Op.REDUCE, pattern
+             elif category == "sorting":
+                 return Op.SORTING, pattern
+             elif category == "memory":
+                 return Op.COPY_MEMORY, pattern
+             elif category == "indexing":
+                 return Op.INDEXING, pattern
+             elif category == "elementwise":
+                 return Op.ELEMENTWISE, pattern
+             elif category == "triton":
+                 return Op.TRITON_FUSED, pattern
+             elif category == "activation":
+                 # Check for fused SwiGLU+GEMM (AMD Triton)
+                 has_gemm = "gemm" in nl or "unquantized_gemm" in nl
+                 has_silu = "silu" in nl or "swiglu" in nl
+                 if has_gemm and has_silu:
+                     return Op.SWIGLU_GEMM, pattern
+                 elif has_silu:
+                     return Op.SWIGLU, pattern
+                 # Other activations (GELU, etc.)
+                 return Op.TRITON_FUSED, pattern
+
+     return None
+
+
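For reference, _match_registry_pattern only relies on a small slice of the registry's structure: a mapping from category name to per-platform lists of pattern entries, each carrying a "pattern" glob and, for attention, an optional "phase". The sketch below shows that shape as it would come back from yaml.safe_load; the specific glob strings are hypothetical and not taken from kernel_registry.yaml itself.

    # Hypothetical registry excerpt; the key names ("attention", "gemm", "amd",
    # "nvidia", "both", "pattern", "phase") are the ones the matching code reads.
    example_registry = {
        "attention": {
            "amd": [
                {"pattern": "*kernel_unified_attention_2d*", "phase": "prefill"},
                {"pattern": "*kernel_unified_attention_3d*", "phase": "decode"},
            ],
            "nvidia": [
                {"pattern": "fmhaSm100a*", "phase": "prefill"},
            ],
            "both": [],
        },
        "gemm": {
            "both": [{"pattern": "*cutlass*gemm*"}],
        },
    }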
+ def _detect_heuristic_fusion(name: str) -> tuple[Op, str] | None:
+     """Heuristically detect potential fusions based on multiple operation keywords.
+
+     This is a fallback for kernels we haven't explicitly classified.
+     If a kernel name contains 2+ distinct operation keywords, it is likely fused.
+
+     Returns (Op.FUSED_UNKNOWN, "Component1+Component2+...") if a fusion is suspected.
+     The pattern name contains the fused components for display.
+     """
+     nl = name.lower()
+
+     # Only check Triton kernels - these are most likely to be fused
+     if "triton" not in nl:
+         return None
+
+     # Find all operation keywords present in the name.
+     # An ordered list keeps the component ordering consistent.
+     found_ops: list[str] = []
+     keyword_priority = [
+         # Order matters - more specific first
+         ("embedding", "Embedding"),
+         ("rmsnorm", "RMSNorm"),
+         ("layernorm", "LayerNorm"),
+         ("rsqrt", "Norm"),  # Generic norm indicator
+         ("swiglu", "SwiGLU"),
+         ("silu", "SiLU"),
+         ("gelu", "GELU"),
+         ("relu", "ReLU"),
+         ("gemm", "GEMM"),
+         ("matmul", "GEMM"),
+         ("mm_", "GEMM"),
+         ("softmax", "Softmax"),
+         ("attention", "Attention"),
+         ("reduce", "Reduce"),
+     ]
+
+     for keyword, op_name in keyword_priority:
+         if keyword in nl and op_name not in found_ops:
+             # Avoid duplicates like "RMSNorm" and "Norm"
+             if op_name == "Norm" and any(n in found_ops for n in ["RMSNorm", "LayerNorm"]):
+                 continue
+             # Avoid duplicates like "SwiGLU" and "SiLU"
+             if op_name == "SiLU" and "SwiGLU" in found_ops:
+                 continue
+             found_ops.append(op_name)
+
+     # If 2+ operations are detected, it's likely a fusion
+     if len(found_ops) >= 2:
+         fused_name = "+".join(found_ops)
+         # The pattern name IS the fused operation name for display
+         return Op.FUSED_UNKNOWN, fused_name
+
+     return None
+
+
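As a concrete illustration of the heuristic above, a Triton kernel whose name contains two of the listed keywords is reported as a suspected fusion. The kernel name in this sketch is made up for the example:

    # Hypothetical kernel name; "rmsnorm" and "matmul" both appear in keyword_priority,
    # so the helper reports a two-component fusion.
    op, components = _detect_heuristic_fusion("triton_fused_rmsnorm_matmul_kernel")
    # op == Op.FUSED_UNKNOWN, components == "RMSNorm+GEMM"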
+ @lru_cache(maxsize=4096)
+ def classify(name: str, platform: str) -> tuple[Op, str]:
+     """Classify kernel by operation type.
+
+     Cached because PyTorch traces have ~48 unique kernel names repeated 810k times.
+
+     Args:
+         name: Kernel name from trace
+         platform: 'AMD' or 'NVIDIA'
+
+     Returns:
+         Tuple of (operation type, pattern name)
+     """
+     nl = name.lower()
+
+     # Check registry patterns first (order matters - more specific categories first)
+     registry_categories = [
+         "attention",  # Attention ops (prefill/decode)
+         "gemm",  # Dense GEMM
+         "rmsnorm",  # RMSNorm
+         "moe",  # MoE operations
+         "kv_cache",  # KV Cache
+         "activation",  # SwiGLU, SiLU, GELU
+         "softmax",  # Softmax
+         "reduce",  # Reduce/Scan
+         "sorting",  # Sorting
+         "memory",  # Memory/Copy
+         "indexing",  # Index/Scatter-Gather
+         "triton",  # Triton fused ops
+         "elementwise",  # Elementwise (last - most generic)
+     ]
+     for category in registry_categories:
+         result = _match_registry_pattern(name, category, platform)
+         if result:
+             return result
+
+     # Hardcoded fallback patterns below (used when no registry pattern matches)
+     # Attention
+     if "attention" in nl or "fmha" in nl:
+         if platform == "AMD":
+             if "2d" in nl:
+                 return Op.ATTN_PREFILL, "kernel_unified_attention_2d"
+             if "3d" in nl:
+                 return Op.ATTN_DECODE, "kernel_unified_attention_3d"
+         else:
+             # NVIDIA uses fmhaSm100 with 'a' (prefill/context) and 'f' (decode/forgen)
+             if "fmhasm100a" in nl or "context" in nl:
+                 return Op.ATTN_PREFILL, "fmhaSm100a*_Context"
+             if "fmhasm100f" in nl or "forgen" in nl:
+                 return Op.ATTN_DECODE, "fmhaSm100f*_ForGen"
+         return Op.ATTN_PREFILL, name[:40]
+
+     # Flash Attention variants (vLLM)
+     if "flash::flash_fwd_kernel" in name or "flash_fwd" in nl:
+         # Could distinguish prefill/decode if needed, defaulting to prefill
+         return Op.ATTN_PREFILL, "flash::flash_fwd_kernel"
+
+     # KV Cache
+     if "reshape_and_cache" in nl:
+         return Op.KV_CACHE, "reshape_and_cache_*"
+
+     # KV Cache variants (vLLM)
+     if "concat_and_cache" in nl or "cache_mla" in nl:
+         return Op.KV_CACHE, "vllm::concat_and_cache_*"
+
+     # MoE
+     # vLLM MoE kernels - these are very common in MoE models
+     if "fused_moe_kernel" in nl:
+         return Op.MOE_GEMM, "fused_moe_kernel"
+
+     if "vllm::moe::" in name:
+         if "moe_align_block_size" in nl:
+             if "small_batch" in nl:
+                 return Op.MOE_ROUTING, "vllm::moe::moe_align_block_size_small_batch_*"
+             return Op.MOE_ROUTING, "vllm::moe::moe_align_block_size_*"
+         if "moe_sum" in nl:
+             return Op.MOE_FINALIZE, "vllm::moe::moe_sum_*"
+
+     # vLLM act_and_mul (can be a mangled C++ name)
+     if "vllm::act_and_mul_kernel" in name or ("act_and_mul_kernel" in nl and "vllm" in nl):
+         return Op.MOE_GEMM_SWIGLU, "vllm::act_and_mul_kernel"
+
+     if "_matmul_ogs_" in nl:
+         if "swiglu" in nl:
+             return Op.MOE_GEMM_SWIGLU, "_matmul_ogs_*_swiglu"
+         return Op.MOE_GEMM, "_matmul_ogs_*"
+
+     if name.startswith("bmm_") and "dynbatch" in nl:
+         if "swiglu" in nl:
+             return Op.MOE_GEMM_SWIGLU, "bmm_*_swiGlu_dynBatch"
+         return Op.MOE_GEMM, "bmm_*_dynBatch"
+
+     # Generic MoE routing patterns (check before finalize)
+     if any(x in nl for x in ["topk", "routing", "bitmatrix", "moe_forward", "_combined_routing"]):
+         if "moe::dev::routing::" in name or "moe::" in name:
+             return Op.MOE_ROUTING, "moe::dev::routing::*"
+         return Op.MOE_ROUTING, "moe_routing_*"
+
+     # MoE finalize patterns
+     if "finalize" in nl or ("scatter" in nl and "moe" in nl):
+         return Op.MOE_FINALIZE, "moe_finalize_*"
+
+     # RMSNorm - match actual patterns from traces
+     if "triton" in nl and ("rsqrt" in nl or ("mean" in nl and "mul" in nl and "pow" in nl)):
+         if "gemm" in nl or "addmm" in nl:
+             return Op.RMSNORM_GEMM, "triton_*_rmsnorm_gemm"
+         return Op.RMSNORM, "triton_*_rsqrt"
+
+     # Dense GEMM - these are the most common kernels
+     if name.startswith("Cijk_") or name.startswith("Custom_Cijk_"):
+         return Op.DENSE_GEMM, "Cijk_* (Tensile)"
+     if name.startswith("nvjet_") or "cublaslt" in nl:
+         return Op.DENSE_GEMM, "nvjet_* (cuBLASLt)"
+     if "wvsplitk" in nl or name.startswith("void wvSplitK"):
+         return Op.DENSE_GEMM, "wvSplitK_* (hipBLASLt)"
+
+     # CUTLASS GEMM variants
+     if "cutlass" in nl and ("gemm" in nl or "cutlass3x" in nl):
+         return Op.DENSE_GEMM, "cutlass*_gemm"
+
+     # GEMV (matrix-vector) operations - treat as GEMM
+     if "gemv" in nl:
+         return Op.DENSE_GEMM, "gemv*"
+
+     # Generic GEMM patterns
+     if "gemmsn" in nl or name.startswith("void gemmSN"):
+         return Op.DENSE_GEMM, "gemmSN_*"
+
+     # Triton fused operations - very common
+     if "triton_poi" in nl or "triton_red" in nl or "triton_per" in nl:
+         # Distinguish between different fusion patterns
+         if "silu" in nl or "swiglu" in nl:
+             return Op.TRITON_FUSED, "triton_*_silu"
+         return Op.TRITON_FUSED, "triton_*"
+
+     # SoftMax operations
+     if "softmax" in nl:
+         return Op.SOFTMAX, "softmax_*"
+
+     # Reduce operations - catch more patterns
+     if "reduce" in nl:
+         if "reduce_segments" in nl or "devicereduce" in nl:
+             return Op.REDUCE, "reduce_segments"
+         if "reduce_kernel" in nl:
+             return Op.REDUCE, "reduce_kernel"
+         return Op.REDUCE, "reduce_*"
+
+     # Scan operations (library internals - similar to reduce)
+     if "scan" in nl and ("cub::" in name or "rocprim::" in name or "at_cuda_detail::cub::" in name):
+         return Op.REDUCE, "cub/rocprim::scan_*"
+
+     # Sorting operations (common in sampling/topk)
+     if "sort" in nl or "radixsort" in nl or "merge" in nl:
+         if platform == "AMD":
+             return Op.SORTING, "rocprim::sort/merge_*"
+         else:
+             return Op.SORTING, "cub::DeviceRadixSort*"
+
+     # Indexing/Scatter/Gather operations
+     if any(x in nl for x in ["indices", "scatter", "gather", "index_select", "embedding"]):
+         return Op.INDEXING, "index/scatter_*"
+
+     # Memory copy operations
+     if "copy" in nl or "memcpy" in nl or "_copy_page_indices" in nl:
+         return Op.COPY_MEMORY, "copy_*"
+
+     # PyTorch native operations (catch-all for at::native)
+     if "at::native::" in name:
+         # Try to be more specific
+         if "fill" in nl:
+             return Op.ELEMENTWISE, "at::native::fill_*"
+         return Op.ELEMENTWISE, "at::native::*"
+
+     # ROCm/CUDA library kernels (other)
+     if "rocprim::" in name or "cub::" in name:
+         return Op.OTHER, "rocprim/cub_*"
+
+     # Fallback: heuristic fusion detection for unknown Triton kernels.
+     # If a kernel has multiple operation keywords, it's likely fused.
+     heuristic_result = _detect_heuristic_fusion(name)
+     if heuristic_result:
+         return heuristic_result
+
+     return Op.OTHER, name[:40]
+
+
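A short usage sketch of classify, assuming no kernel_registry.yaml is present so only the hardcoded fallbacks above apply; the kernel names are illustrative rather than taken from a real trace.

    # Tensile GEMM kernel name prefix on AMD -> Dense GEMM bucket.
    op, pattern = classify("Cijk_Alik_Bljk_HHS_BH", "AMD")
    # op == Op.DENSE_GEMM, pattern == "Cijk_* (Tensile)"

    # vLLM paged-KV cache update kernel -> KV Cache bucket.
    op, pattern = classify("vllm::reshape_and_cache_flash_kernel", "NVIDIA")
    # op == Op.KV_CACHE, pattern == "reshape_and_cache_*"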
+ def classify_kernel(name: str) -> str:
+     """Simplified kernel classification for fusion analysis.
+
+     Args:
+         name: Kernel name from trace
+
+     Returns:
+         Simple category name consistent across platforms
+     """
+     nl = name.lower()
+
+     # GEMM operations (matrix multiplication)
+     if any(x in nl for x in ["cijk_", "nvjet", "wvsplitk", "cublas", "hipblas", "tensile"]):
+         return "GEMM"
+
+     # Attention
+     if "attention" in nl or "fmha" in nl:
+         return "Attention"
+
+     # KV Cache
+     if "reshape_and_cache" in nl:
+         return "KV_Cache"
+
+     # RMSNorm / LayerNorm
+     if "triton" in nl and "rsqrt" in nl:
+         return "RMSNorm"
+     if "layernorm" in nl or "rmsnorm" in nl:
+         return "RMSNorm"
+
+     # SwiGLU / Activations
+     if "silu" in nl or "swiglu" in nl:
+         return "SwiGLU"
+     if "gelu" in nl:
+         return "GELU"
+     if "relu" in nl and "gelu" not in nl:
+         return "ReLU"
+
+     # Triton fused operations (generic)
+     if "triton_poi" in nl:
+         return "Triton_Pointwise"
+     if "triton_red" in nl:
+         return "Triton_Reduce"
+     if "triton_per" in nl:
+         return "Triton_Persistent"
+
+     # Reduce operations
+     if "reduce" in nl:
+         return "Reduce"
+
+     # Sort operations
+     if "sort" in nl or "radixsort" in nl or "merge" in nl:
+         return "Sort"
+
+     # Softmax
+     if "softmax" in nl:
+         return "Softmax"
+
+     # Indexing/Scatter/Gather
+     if any(x in nl for x in ["indices", "scatter", "gather", "index_select", "embedding"]):
+         return "Indexing"
+
+     # Elementwise operations
+     if any(x in nl for x in ["elementwise", "unrolled_elementwise"]):
+         return "Elementwise"
+
+     # Copy/Memory operations
+     if "copy" in nl or "memcpy" in nl:
+         return "MemCopy"
+
+     return "Other"
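And a corresponding sketch for classify_kernel, which collapses kernel names into coarse, platform-neutral buckets; again the kernel names are illustrative.

    classify_kernel("Cijk_Alik_Bljk_HHS_BH")        # -> "GEMM" (matches the "cijk_" prefix)
    classify_kernel("triton_poi_fused_add_mul_0")   # -> "Triton_Pointwise"
    classify_kernel("some_unrecognized_kernel")     # -> "Other"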