wafer-core 0.1.26__py3-none-any.whl → 0.1.27__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
@@ -2,9 +2,21 @@
 
 Classifies GPU kernels into operation categories (attention, GEMM, normalization, etc.)
 based on kernel name patterns and platform-specific conventions.
+
+Can optionally load patterns from kernel_registry.yaml, but falls back to hardcoded patterns
+for comprehensive coverage.
 """
 
+import fnmatch
 from enum import Enum
+from functools import lru_cache
+from pathlib import Path
+from typing import Any
+
+try:
+    import yaml
+except ImportError:
+    yaml = None
 
 
 class Op(Enum):
@@ -20,16 +32,211 @@ class Op(Enum):
     DENSE_GEMM = "Dense GEMM"
     RMSNORM = "RMSNorm"
     RMSNORM_GEMM = "RMSNorm+GEMM"
+    SWIGLU = "SwiGLU"
+    SWIGLU_GEMM = "SwiGLU+GEMM"
+    EMBEDDING_RMSNORM_GEMM = "Embedding+RMSNorm+GEMM"
+    SOFTMAX = "SoftMax"
     TRITON_FUSED = "Triton Fused"
     ELEMENTWISE = "Elementwise"
     SORTING = "Sorting"
     REDUCE = "Reduce"
+    INDEXING = "Indexing"
     COPY_MEMORY = "Copy/Memory"
+    FUSED_UNKNOWN = "Fused (Unknown)"  # Heuristically detected fusion
     OTHER = "Other"
 
 
+# Keywords that indicate specific operations - used for heuristic fusion detection
+FUSION_KEYWORDS: dict[str, str] = {
+    # Normalization
+    "rsqrt": "Norm",
+    "rmsnorm": "RMSNorm",
+    "layernorm": "LayerNorm",
+    # GEMM
+    "gemm": "GEMM",
+    "matmul": "GEMM",
+    "mm_": "GEMM",
+    # Activations
+    "silu": "SiLU",
+    "swiglu": "SwiGLU",
+    "gelu": "GELU",
+    "relu": "ReLU",
+    # Other ops
+    "softmax": "Softmax",
+    "attention": "Attention",
+    "embedding": "Embedding",
+    "reduce": "Reduce",
+}
+
+
+# Load kernel registry YAML if available
+_KERNEL_REGISTRY: dict[str, Any] | None = None
+
+
+def _load_kernel_registry() -> dict[str, Any] | None:
+    """Load kernel pattern registry from YAML file."""
+    global _KERNEL_REGISTRY
+
+    if _KERNEL_REGISTRY is not None:
+        return _KERNEL_REGISTRY
+
+    if yaml is None:
+        return None
+
+    registry_path = Path(__file__).parent / "kernel_registry.yaml"
+    if not registry_path.exists():
+        return None
+
+    try:
+        with open(registry_path) as f:
+            _KERNEL_REGISTRY = yaml.safe_load(f)
+        return _KERNEL_REGISTRY
+    except Exception:
+        return None
+
+
+def _match_registry_pattern(name: str, category: str, platform: str) -> tuple[Op, str] | None:
+    """Try to match kernel name against registry patterns.
+
+    Returns (Op, pattern_name) if match found, None otherwise.
+    """
+    registry = _load_kernel_registry()
+    if not registry:
+        return None
+
+    nl = name.lower()
+    category_data = registry.get(category, {})
+
+    platform_key = "amd" if platform.lower() == "amd" else "nvidia"
+    patterns = category_data.get(platform_key, [])
+    both_patterns = category_data.get("both", [])
+    patterns = patterns + both_patterns
+
+    for pattern_entry in patterns:
+        pattern = pattern_entry.get("pattern", "")
+        if not pattern:
+            continue
+
+        if fnmatch.fnmatch(name, pattern) or fnmatch.fnmatch(nl, pattern.lower()):
+            if category == "attention":
+                phase = pattern_entry.get("phase")
+                if phase == "prefill":
+                    return Op.ATTN_PREFILL, pattern
+                elif phase == "decode":
+                    return Op.ATTN_DECODE, pattern
+                return Op.ATTN_PREFILL, pattern
+            elif category == "gemm":
+                return Op.DENSE_GEMM, pattern
+            elif category == "rmsnorm":
+                # Detect fused operations (AMD Triton often fuses RMSNorm with GEMM)
+                has_gemm = "gemm" in nl or "unquantized_gemm" in nl
+                has_embedding = "embedding" in nl
+                if has_gemm and has_embedding:
+                    # Embedding + RMSNorm + GEMM all fused together
+                    return Op.EMBEDDING_RMSNORM_GEMM, pattern
+                elif has_gemm:
+                    # RMSNorm + GEMM fused
+                    return Op.RMSNORM_GEMM, pattern
+                return Op.RMSNORM, pattern
+            elif category == "moe":
+                # Distinguish between MoE sub-operations
+                if "swiglu" in nl:
+                    return Op.MOE_GEMM_SWIGLU, pattern
+                if any(x in nl for x in ["routing", "topk", "align_block", "count_and_sort", "gating"]):
+                    return Op.MOE_ROUTING, pattern
+                if "finalize" in nl or "scatter" in nl:
+                    return Op.MOE_FINALIZE, pattern
+                return Op.MOE_GEMM, pattern
+            elif category == "kv_cache":
+                return Op.KV_CACHE, pattern
+            elif category == "softmax":
+                return Op.SOFTMAX, pattern
+            elif category == "reduce":
+                return Op.REDUCE, pattern
+            elif category == "sorting":
+                return Op.SORTING, pattern
+            elif category == "memory":
+                return Op.COPY_MEMORY, pattern
+            elif category == "indexing":
+                return Op.INDEXING, pattern
+            elif category == "elementwise":
+                return Op.ELEMENTWISE, pattern
+            elif category == "triton":
+                return Op.TRITON_FUSED, pattern
+            elif category == "activation":
+                # Check for fused SwiGLU+GEMM (AMD Triton)
+                has_gemm = "gemm" in nl or "unquantized_gemm" in nl
+                has_silu = "silu" in nl or "swiglu" in nl
+                if has_gemm and has_silu:
+                    return Op.SWIGLU_GEMM, pattern
+                elif has_silu:
+                    return Op.SWIGLU, pattern
+                # Other activations (GELU, etc.)
+                return Op.TRITON_FUSED, pattern
+
+    return None
+
+
+def _detect_heuristic_fusion(name: str) -> tuple[Op, str] | None:
+    """Heuristically detect potential fusions based on multiple operation keywords.
+
+    This is a fallback for kernels we haven't explicitly classified.
+    If a kernel name contains 2+ distinct operation keywords, it's likely fused.
+
+    Returns (Op.FUSED_UNKNOWN, "Component1+Component2+...") if suspected fusion.
+    The pattern name contains the fused components for display.
+    """
+    nl = name.lower()
+
+    # Only check Triton kernels - these are most likely to be fused
+    if "triton" not in nl:
+        return None
+
+    # Find all operation keywords present in the name
+    # Use ordered list to maintain consistent ordering
+    found_ops: list[str] = []
+    keyword_priority = [
+        # Order matters - more specific first
+        ("embedding", "Embedding"),
+        ("rmsnorm", "RMSNorm"),
+        ("layernorm", "LayerNorm"),
+        ("rsqrt", "Norm"),  # Generic norm indicator
+        ("swiglu", "SwiGLU"),
+        ("silu", "SiLU"),
+        ("gelu", "GELU"),
+        ("relu", "ReLU"),
+        ("gemm", "GEMM"),
+        ("matmul", "GEMM"),
+        ("mm_", "GEMM"),
+        ("softmax", "Softmax"),
+        ("attention", "Attention"),
+        ("reduce", "Reduce"),
+    ]
+
+    for keyword, op_name in keyword_priority:
+        if keyword in nl and op_name not in found_ops:
+            # Avoid duplicates like "RMSNorm" and "Norm"
+            if op_name == "Norm" and any(n in found_ops for n in ["RMSNorm", "LayerNorm"]):
+                continue
+            # Avoid duplicates like "SwiGLU" and "SiLU"
+            if op_name == "SiLU" and "SwiGLU" in found_ops:
+                continue
+            found_ops.append(op_name)
+
+    # If 2+ operations detected, it's likely a fusion
+    if len(found_ops) >= 2:
+        fused_name = "+".join(found_ops)
+        # The pattern name IS the fused operation name for display
+        return Op.FUSED_UNKNOWN, fused_name
+
+    return None
+
+
+@lru_cache(maxsize=4096)
 def classify(name: str, platform: str) -> tuple[Op, str]:
     """Classify kernel by operation type.
+
+    Cached because PyTorch traces have ~48 unique kernel names repeated 810k times.
 
     Args:
         name: Kernel name from trace
@@ -39,8 +246,27 @@ def classify(name: str, platform: str) -> tuple[Op, str]:
         Tuple of (operation type, pattern name)
     """
     nl = name.lower()
-
-    # Attention
+
+    # Check registry patterns first (order matters - more specific categories first)
+    registry_categories = [
+        "attention",  # Attention ops (prefill/decode)
+        "gemm",  # Dense GEMM
+        "rmsnorm",  # RMSNorm
+        "moe",  # MoE operations
+        "kv_cache",  # KV Cache
+        "activation",  # SwiGLU, SiLU, GELU
+        "softmax",  # Softmax
+        "reduce",  # Reduce/Scan
+        "sorting",  # Sorting
+        "memory",  # Memory/Copy
+        "indexing",  # Index/Scatter-Gather
+        "triton",  # Triton fused ops
+        "elementwise",  # Elementwise (last - most generic)
+    ]
+    for category in registry_categories:
+        result = _match_registry_pattern(name, category, platform)
+        if result:
+            return result
     if "attention" in nl or "fmha" in nl:
         if platform == "AMD":
             if "2d" in nl:
@@ -55,10 +281,35 @@ def classify(name: str, platform: str) -> tuple[Op, str]:
             return Op.ATTN_DECODE, "fmhaSm100f*_ForGen"
         return Op.ATTN_PREFILL, name[:40]
 
+    # Flash Attention variants (vLLM)
+    if "flash::flash_fwd_kernel" in name or "flash_fwd" in nl:
+        # Could distinguish prefill/decode if needed, defaulting to prefill
+        return Op.ATTN_PREFILL, "flash::flash_fwd_kernel"
+
     if "reshape_and_cache" in nl:
         return Op.KV_CACHE, "reshape_and_cache_*"
 
+    # KV Cache variants (vLLM)
+    if "concat_and_cache" in nl or "cache_mla" in nl:
+        return Op.KV_CACHE, "vllm::concat_and_cache_*"
+
     # MoE
+    # vLLM MoE kernels - these are very common in MoE models
+    if "fused_moe_kernel" in nl:
+        return Op.MOE_GEMM, "fused_moe_kernel"
+
+    if "vllm::moe::" in name:
+        if "moe_align_block_size" in nl:
+            if "small_batch" in nl:
+                return Op.MOE_ROUTING, "vllm::moe::moe_align_block_size_small_batch_*"
+            return Op.MOE_ROUTING, "vllm::moe::moe_align_block_size_*"
+        if "moe_sum" in nl:
+            return Op.MOE_FINALIZE, "vllm::moe::moe_sum_*"
+
+    # vLLM act_and_mul (can be mangled C++ name)
+    if "vllm::act_and_mul_kernel" in name or ("act_and_mul_kernel" in nl and "vllm" in nl):
+        return Op.MOE_GEMM_SWIGLU, "vllm::act_and_mul_kernel"
+
     if "_matmul_ogs_" in nl:
         if "swiglu" in nl:
             return Op.MOE_GEMM_SWIGLU, "_matmul_ogs_*_swiglu"
@@ -69,8 +320,13 @@ def classify(name: str, platform: str) -> tuple[Op, str]:
             return Op.MOE_GEMM_SWIGLU, "bmm_*_swiGlu_dynBatch"
         return Op.MOE_GEMM, "bmm_*_dynBatch"
 
+    # Generic MoE routing patterns (check before finalize)
     if any(x in nl for x in ["topk", "routing", "bitmatrix", "moe_forward", "_combined_routing"]):
+        if "moe::dev::routing::" in name or "moe::" in name:
+            return Op.MOE_ROUTING, "moe::dev::routing::*"
         return Op.MOE_ROUTING, "moe_routing_*"
+
+    # MoE finalize patterns
     if "finalize" in nl or ("scatter" in nl and "moe" in nl):
         return Op.MOE_FINALIZE, "moe_finalize_*"
 
@@ -87,6 +343,18 @@ def classify(name: str, platform: str) -> tuple[Op, str]:
         return Op.DENSE_GEMM, "nvjet_* (cuBLASLt)"
     if "wvsplitk" in nl or name.startswith("void wvSplitK"):
         return Op.DENSE_GEMM, "wvSplitK_* (hipBLASLt)"
+
+    # CUTLASS GEMM variants
+    if "cutlass" in nl and ("sgemm" in nl or "gemm" in nl or "cutlass3x" in name.lower()):
+        return Op.DENSE_GEMM, "cutlass*_gemm"
+
+    # GEMV (matrix-vector) operations - treat as GEMM
+    if "gemv" in nl or "gemvx" in nl:
+        return Op.DENSE_GEMM, "gemv*"
+
+    # Generic GEMM patterns
+    if "gemmsn" in nl or name.startswith("void gemmSN"):
+        return Op.DENSE_GEMM, "gemmSN_*"
 
     # Triton fused operations - very common
     if "triton_poi" in nl or "triton_red" in nl or "triton_per" in nl:
@@ -95,9 +363,21 @@ def classify(name: str, platform: str) -> tuple[Op, str]:
             return Op.TRITON_FUSED, "triton_*_silu"
         return Op.TRITON_FUSED, "triton_*"
 
-    # PyTorch native operations
-    if "at::native::" in name:
-        return Op.ELEMENTWISE, "at::native::*"
+    # SoftMax operations
+    if "softmax" in nl:
+        return Op.SOFTMAX, "softmax_*"
+
+    # Reduce operations - catch more patterns
+    if "reduce" in nl:
+        if "reduce_segments" in nl or "devicereduce" in nl:
+            return Op.REDUCE, "reduce_segments"
+        if "reduce_kernel" in nl:
+            return Op.REDUCE, "reduce_kernel"
+        return Op.REDUCE, "reduce_*"
+
+    # Scan operations (library internals - similar to reduce)
+    if "scan" in nl and ("cub::" in name or "rocprim::" in name or "at_cuda_detail::cub::" in name):
+        return Op.REDUCE, "cub/rocprim::scan_*"
 
     # Sorting operations (common in sampling/topk)
     if "sort" in nl or "radixsort" in nl or "merge" in nl:
@@ -106,21 +386,31 @@ def classify(name: str, platform: str) -> tuple[Op, str]:
         else:
             return Op.SORTING, "cub::DeviceRadixSort*"
 
-    # Reduce operations
-    if "reduce" in nl and ("reduce_segments" in nl or "devicereduce" in nl or "devicescan" in nl):
-        if platform == "AMD":
-            return Op.REDUCE, "reduce_segments"
-        else:
-            return Op.REDUCE, "cub::DeviceReduce*"
-
+    # Indexing/Scatter/Gather operations
+    if any(x in nl for x in ["indices", "scatter", "gather", "index_select", "embedding"]):
+        return Op.INDEXING, "index/scatter_*"
+
     # Memory copy operations
     if "copy" in nl or "memcpy" in nl or "_copy_page_indices" in nl:
        return Op.COPY_MEMORY, "copy_*"
 
+    # PyTorch native operations (catch-all for at::native)
+    if "at::native::" in name:
+        # Try to be more specific
+        if "fill" in nl:
+            return Op.ELEMENTWISE, "at::native::fill_*"
+        return Op.ELEMENTWISE, "at::native::*"
+
 
     # ROCm/CUDA library kernels (other)
     if "rocprim::" in name or "cub::" in name:
         return Op.OTHER, "rocprim/cub_*"
+    # Fallback: Heuristic fusion detection for unknown Triton kernels
+    # If a kernel has multiple operation keywords, it's likely fused
+    heuristic_result = _detect_heuristic_fusion(name)
+    if heuristic_result:
+        return heuristic_result
+
     return Op.OTHER, name[:40]
 
 
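Note on the heuristic fallback added above: the snippet below is a simplified, standalone reproduction of the two-keyword rule for illustration only; it is not the package's API, and the keyword list is truncated relative to the real keyword_priority table.

    # Standalone sketch of the two-keyword fusion heuristic (illustrative only).
    KEYWORDS = [("rmsnorm", "RMSNorm"), ("silu", "SiLU"), ("gemm", "GEMM"), ("softmax", "Softmax")]

    def sketch_fusion_label(kernel_name: str) -> str | None:
        nl = kernel_name.lower()
        if "triton" not in nl:  # heuristic only applies to Triton kernels
            return None
        found: list[str] = []
        for keyword, label in KEYWORDS:
            if keyword in nl and label not in found:
                found.append(label)
        # Two or more distinct op keywords -> treat the kernel as fused
        return "+".join(found) if len(found) >= 2 else None

    print(sketch_fusion_label("triton_poi_fused_rmsnorm_gemm_0"))  # -> "RMSNorm+GEMM"

In the real module, such a name would map to Op.FUSED_UNKNOWN with the joined component string as its display pattern.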
@@ -170,7 +460,7 @@ def classify_kernel(name: str) -> str:
         return "Triton_Persistent"
 
     # Reduce operations
-    if "reduce_segments" in nl or "devicereduce" in nl:
+    if "reduce" in nl:
         return "Reduce"
 
     # Sort operations
@@ -180,6 +470,10 @@ def classify_kernel(name: str) -> str:
     # Softmax
     if "softmax" in nl:
         return "Softmax"
+
+    # Indexing/Scatter/Gather
+    if any(x in nl for x in ["indices", "scatter", "gather", "index_select", "embedding"]):
+        return "Indexing"
 
     # Elementwise operations
     if any(x in nl for x in ["elementwise", "unrolled_elementwise"]):
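For context, a minimal usage sketch of the updated classify() follows. The import path is an assumption (the wheel's module layout is not shown in this diff), and the expected result assumes no kernel_registry.yaml is present, so classification falls through to the hardcoded rules.

    # Hypothetical import path; adjust to wherever kernel classification lives in wafer-core.
    from wafer_core.kernel_classify import Op, classify  # assumed module name

    # With no registry YAML loaded, this hits the hardcoded vLLM MoE rule.
    op, pattern = classify("fused_moe_kernel_0d1d2d", platform="NVIDIA")
    print(op, pattern)  # expected: Op.MOE_GEMM fused_moe_kernel

    # Repeated calls with the same (name, platform) are served from the
    # lru_cache(maxsize=4096) added in 0.1.27.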