wafer-core 0.1.32__py3-none-any.whl → 0.1.34__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
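
As orientation before the hunks, the short sketch below exercises the classification API this diff modifies. Only the Op enum and the classify(name, platform) -> tuple[Op, str] signature are taken from the code shown in the diff; the module path and the example kernel names are assumptions for illustration, and the noted results follow the 0.1.34 branches visible in the hunks.

# Usage sketch only: the import path and the kernel names are illustrative assumptions;
# Op and classify(name, platform) -> tuple[Op, str] are the objects shown in the diff below.
from wafer_core.kernel_classify import Op, classify  # hypothetical module path

op, pattern = classify("kernel_unified_attention_3d", "AMD")
# -> Op.ATTN_DECODE, "kernel_unified_attention_3d"  (AMD unified-attention decode path)

op, pattern = classify("vllm::reshape_and_cache_flash", "NVIDIA")
# -> Op.KV_CACHE, "reshape_and_cache_*"  (KV-cache write kernel, unchanged by this diff)

op, pattern = classify("fmhaSm100aWarpspec_Context", "NVIDIA")
# -> Op.ATTN_PREFILL, "fmhaSm100a*_Context" in 0.1.34; 0.1.32 returned the
#    now-removed Op.FLASH_ATTN_FUSED for the same name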
@@ -2,21 +2,9 @@
  
  Classifies GPU kernels into operation categories (attention, GEMM, normalization, etc.)
  based on kernel name patterns and platform-specific conventions.
-
- Can optionally load patterns from kernel_registry.yaml, but falls back to hardcoded patterns
- for comprehensive coverage.
  """
  
- import fnmatch
  from enum import Enum
- from functools import lru_cache
- from pathlib import Path
- from typing import Any
-
- try:
-     import yaml
- except ImportError:
-     yaml = None
  
  
  class Op(Enum):
@@ -24,223 +12,24 @@ class Op(Enum):
  
      ATTN_PREFILL = "Attention (Prefill)"
      ATTN_DECODE = "Attention (Decode)"
-     # NVIDIA Flash Attention fuses QKV projection + Softmax + Attention
-     FLASH_ATTN_FUSED = "FlashAttention (QKV+Softmax+Attn)"
      KV_CACHE = "KV Cache"
      MOE_ROUTING = "MoE Routing"
      MOE_GEMM = "MoE GEMM"
      MOE_GEMM_SWIGLU = "MoE GEMM+SwiGLU"
      MOE_FINALIZE = "MoE Finalize"
      DENSE_GEMM = "Dense GEMM"
-     # NVIDIA cuBLASLt/CUTLASS can fuse GEMM with epilogue (bias + activation)
-     GEMM_BIAS_ACT = "GEMM+Bias+Activation"
      RMSNORM = "RMSNorm"
      RMSNORM_GEMM = "RMSNorm+GEMM"
-     SWIGLU = "SwiGLU"
-     SWIGLU_GEMM = "SwiGLU+GEMM"
-     EMBEDDING_RMSNORM_GEMM = "Embedding+RMSNorm+GEMM"
-     SOFTMAX = "SoftMax"
      TRITON_FUSED = "Triton Fused"
      ELEMENTWISE = "Elementwise"
      SORTING = "Sorting"
      REDUCE = "Reduce"
-     INDEXING = "Indexing"
      COPY_MEMORY = "Copy/Memory"
-     FUSED_UNKNOWN = "Fused (Unknown)"  # Heuristically detected fusion
      OTHER = "Other"
  
  
- # Keywords that indicate specific operations - used for heuristic fusion detection
- FUSION_KEYWORDS: dict[str, str] = {
-     # Normalization
-     "rsqrt": "Norm",
-     "rmsnorm": "RMSNorm",
-     "layernorm": "LayerNorm",
-     # GEMM
-     "gemm": "GEMM",
-     "matmul": "GEMM",
-     "mm_": "GEMM",
-     # Activations
-     "silu": "SiLU",
-     "swiglu": "SwiGLU",
-     "gelu": "GELU",
-     "relu": "ReLU",
-     # Other ops
-     "softmax": "Softmax",
-     "attention": "Attention",
-     "embedding": "Embedding",
-     "reduce": "Reduce",
- }
-
-
- # Load kernel registry YAML if available
- _KERNEL_REGISTRY: dict[str, Any] | None = None
-
-
- def _load_kernel_registry() -> dict[str, Any] | None:
-     """Load kernel pattern registry from YAML file."""
-     global _KERNEL_REGISTRY
-
-     if _KERNEL_REGISTRY is not None:
-         return _KERNEL_REGISTRY
-
-     if yaml is None:
-         return None
-
-     registry_path = Path(__file__).parent / "kernel_registry.yaml"
-     if not registry_path.exists():
-         return None
-
-     try:
-         with open(registry_path) as f:
-             _KERNEL_REGISTRY = yaml.safe_load(f)
-         return _KERNEL_REGISTRY
-     except Exception:
-         return None
-
-
- def _match_registry_pattern(name: str, category: str, platform: str) -> tuple[Op, str] | None:
-     """Try to match kernel name against registry patterns.
-
-     Returns (Op, pattern_name) if match found, None otherwise.
-     """
-     registry = _load_kernel_registry()
-     if not registry:
-         return None
-
-     nl = name.lower()
-     category_data = registry.get(category, {})
-
-     platform_key = "amd" if platform.lower() == "amd" else "nvidia"
-     patterns = category_data.get(platform_key, [])
-     both_patterns = category_data.get("both", [])
-     patterns = patterns + both_patterns
-
-     for pattern_entry in patterns:
-         pattern = pattern_entry.get("pattern", "")
-         if not pattern:
-             continue
-
-         if fnmatch.fnmatch(name, pattern) or fnmatch.fnmatch(nl, pattern.lower()):
-             if category == "attention":
-                 phase = pattern_entry.get("phase")
-                 if phase == "prefill":
-                     return Op.ATTN_PREFILL, pattern
-                 elif phase == "decode":
-                     return Op.ATTN_DECODE, pattern
-                 return Op.ATTN_PREFILL, pattern
-             elif category == "gemm":
-                 return Op.DENSE_GEMM, pattern
-             elif category == "rmsnorm":
-                 # Detect fused operations (AMD Triton often fuses RMSNorm with GEMM)
-                 has_gemm = "gemm" in nl or "unquantized_gemm" in nl
-                 has_embedding = "embedding" in nl
-                 if has_gemm and has_embedding:
-                     # Embedding + RMSNorm + GEMM all fused together
-                     return Op.EMBEDDING_RMSNORM_GEMM, pattern
-                 elif has_gemm:
-                     # RMSNorm + GEMM fused
-                     return Op.RMSNORM_GEMM, pattern
-                 return Op.RMSNORM, pattern
-             elif category == "moe":
-                 # Distinguish between MoE sub-operations
-                 if "swiglu" in nl:
-                     return Op.MOE_GEMM_SWIGLU, pattern
-                 if any(x in nl for x in ["routing", "topk", "align_block", "count_and_sort", "gating"]):
-                     return Op.MOE_ROUTING, pattern
-                 if "finalize" in nl or "scatter" in nl:
-                     return Op.MOE_FINALIZE, pattern
-                 return Op.MOE_GEMM, pattern
-             elif category == "kv_cache":
-                 return Op.KV_CACHE, pattern
-             elif category == "softmax":
-                 return Op.SOFTMAX, pattern
-             elif category == "reduce":
-                 return Op.REDUCE, pattern
-             elif category == "sorting":
-                 return Op.SORTING, pattern
-             elif category == "memory":
-                 return Op.COPY_MEMORY, pattern
-             elif category == "indexing":
-                 return Op.INDEXING, pattern
-             elif category == "elementwise":
-                 return Op.ELEMENTWISE, pattern
-             elif category == "triton":
-                 return Op.TRITON_FUSED, pattern
-             elif category == "activation":
-                 # Check for fused SwiGLU+GEMM (AMD Triton)
-                 has_gemm = "gemm" in nl or "unquantized_gemm" in nl
-                 has_silu = "silu" in nl or "swiglu" in nl
-                 if has_gemm and has_silu:
-                     return Op.SWIGLU_GEMM, pattern
-                 elif has_silu:
-                     return Op.SWIGLU, pattern
-                 # Other activations (GELU, etc.)
-                 return Op.TRITON_FUSED, pattern
-
-     return None
-
-
- def _detect_heuristic_fusion(name: str) -> tuple[Op, str] | None:
-     """Heuristically detect potential fusions based on multiple operation keywords.
-
-     This is a fallback for kernels we haven't explicitly classified.
-     If a kernel name contains 2+ distinct operation keywords, it's likely fused.
-
-     Returns (Op.FUSED_UNKNOWN, "Component1+Component2+...") if suspected fusion.
-     The pattern name contains the fused components for display.
-     """
-     nl = name.lower()
-
-     # Only check Triton kernels - these are most likely to be fused
-     if "triton" not in nl:
-         return None
-
-     # Find all operation keywords present in the name
-     # Use ordered list to maintain consistent ordering
-     found_ops: list[str] = []
-     keyword_priority = [
-         # Order matters - more specific first
-         ("embedding", "Embedding"),
-         ("rmsnorm", "RMSNorm"),
-         ("layernorm", "LayerNorm"),
-         ("rsqrt", "Norm"),  # Generic norm indicator
-         ("swiglu", "SwiGLU"),
-         ("silu", "SiLU"),
-         ("gelu", "GELU"),
-         ("relu", "ReLU"),
-         ("gemm", "GEMM"),
-         ("matmul", "GEMM"),
-         ("mm_", "GEMM"),
-         ("softmax", "Softmax"),
-         ("attention", "Attention"),
-         ("reduce", "Reduce"),
-     ]
-
-     for keyword, op_name in keyword_priority:
-         if keyword in nl and op_name not in found_ops:
-             # Avoid duplicates like "RMSNorm" and "Norm"
-             if op_name == "Norm" and any(n in found_ops for n in ["RMSNorm", "LayerNorm"]):
-                 continue
-             # Avoid duplicates like "SwiGLU" and "SiLU"
-             if op_name == "SiLU" and "SwiGLU" in found_ops:
-                 continue
-             found_ops.append(op_name)
-
-     # If 2+ operations detected, it's likely a fusion
-     if len(found_ops) >= 2:
-         fused_name = "+".join(found_ops)
-         # The pattern name IS the fused operation name for display
-         return Op.FUSED_UNKNOWN, fused_name
-
-     return None
-
-
- @lru_cache(maxsize=4096)
  def classify(name: str, platform: str) -> tuple[Op, str]:
      """Classify kernel by operation type.
-
-     Cached because PyTorch traces have ~48 unique kernel names repeated 810k times.
  
      Args:
          name: Kernel name from trace
@@ -250,27 +39,8 @@ def classify(name: str, platform: str) -> tuple[Op, str]:
          Tuple of (operation type, pattern name)
      """
      nl = name.lower()
-
-     # Check registry patterns first (order matters - more specific categories first)
-     registry_categories = [
-         "attention",  # Attention ops (prefill/decode)
-         "gemm",  # Dense GEMM
-         "rmsnorm",  # RMSNorm
-         "moe",  # MoE operations
-         "kv_cache",  # KV Cache
-         "activation",  # SwiGLU, SiLU, GELU
-         "softmax",  # Softmax
-         "reduce",  # Reduce/Scan
-         "sorting",  # Sorting
-         "memory",  # Memory/Copy
-         "indexing",  # Index/Scatter-Gather
-         "triton",  # Triton fused ops
-         "elementwise",  # Elementwise (last - most generic)
-     ]
-     for category in registry_categories:
-         result = _match_registry_pattern(name, category, platform)
-         if result:
-             return result
+
+     # Attention
      if "attention" in nl or "fmha" in nl:
          if platform == "AMD":
              if "2d" in nl:
@@ -278,47 +48,17 @@ def classify(name: str, platform: str) -> tuple[Op, str]:
              if "3d" in nl:
                  return Op.ATTN_DECODE, "kernel_unified_attention_3d"
          else:
-             # NVIDIA Flash Attention (fmhaSm100*) is a fused kernel
-             # It fuses QKV projection + Softmax + Attention into one kernel
-             if "fmhasm100" in nl:
-                 if "fmhasm100a" in nl or "context" in nl:
-                     return Op.FLASH_ATTN_FUSED, "fmhaSm100a*_Context (QKV+Softmax+Attn)"
-                 if "fmhasm100f" in nl or "forgen" in nl:
-                     return Op.FLASH_ATTN_FUSED, "fmhaSm100f*_ForGen (QKV+Softmax+Attn)"
-                 return Op.FLASH_ATTN_FUSED, "fmhaSm100* (QKV+Softmax+Attn)"
+             # NVIDIA uses fmhaSm100 with 'a' (prefill/context) and 'f' (decode/forgen)
+             if "fmhasm100a" in nl or "context" in nl:
+                 return Op.ATTN_PREFILL, "fmhaSm100a*_Context"
+             if "fmhasm100f" in nl or "forgen" in nl:
+                 return Op.ATTN_DECODE, "fmhaSm100f*_ForGen"
              return Op.ATTN_PREFILL, name[:40]
  
-     # Flash Attention variants (vLLM) - these are fused on NVIDIA
-     if "flash::flash_fwd_kernel" in name or "flash_fwd" in nl:
-         if platform != "AMD":
-             return Op.FLASH_ATTN_FUSED, "flash::flash_fwd_kernel (QKV+Softmax+Attn)"
-         return Op.ATTN_PREFILL, "flash::flash_fwd_kernel"
-
      if "reshape_and_cache" in nl:
          return Op.KV_CACHE, "reshape_and_cache_*"
  
-     # KV Cache variants (vLLM)
-     if "concat_and_cache" in nl or "cache_mla" in nl:
-         return Op.KV_CACHE, "vllm::concat_and_cache_*"
-
      # MoE
-     # vLLM MoE kernels - these are very common in MoE models
-     if "fused_moe_kernel" in nl:
-         return Op.MOE_GEMM, "fused_moe_kernel"
-
-     if "vllm::moe::" in name:
-         if "moe_align_block_size" in nl:
-             if "small_batch" in nl:
-                 return Op.MOE_ROUTING, "vllm::moe::moe_align_block_size_small_batch_*"
-             return Op.MOE_ROUTING, "vllm::moe::moe_align_block_size_*"
-         if "moe_sum" in nl:
-             return Op.MOE_FINALIZE, "vllm::moe::moe_sum_*"
-
-     # vLLM act_and_mul - fuses activation with element-wise multiply (SiLU * x)
-     # This is a fused operation used in SwiGLU/MoE
-     if "vllm::act_and_mul_kernel" in name or ("act_and_mul_kernel" in nl and "vllm" in nl):
-         return Op.SWIGLU_GEMM, "vllm::act_and_mul_kernel (SwiGLU+Mul)"
-
      if "_matmul_ogs_" in nl:
          if "swiglu" in nl:
              return Op.MOE_GEMM_SWIGLU, "_matmul_ogs_*_swiglu"
@@ -329,13 +69,8 @@ def classify(name: str, platform: str) -> tuple[Op, str]:
              return Op.MOE_GEMM_SWIGLU, "bmm_*_swiGlu_dynBatch"
          return Op.MOE_GEMM, "bmm_*_dynBatch"
  
-     # Generic MoE routing patterns (check before finalize)
      if any(x in nl for x in ["topk", "routing", "bitmatrix", "moe_forward", "_combined_routing"]):
-         if "moe::dev::routing::" in name or "moe::" in name:
-             return Op.MOE_ROUTING, "moe::dev::routing::*"
          return Op.MOE_ROUTING, "moe_routing_*"
-
-     # MoE finalize patterns
      if "finalize" in nl or ("scatter" in nl and "moe" in nl):
          return Op.MOE_FINALIZE, "moe_finalize_*"
  
@@ -352,18 +87,6 @@ def classify(name: str, platform: str) -> tuple[Op, str]:
          return Op.DENSE_GEMM, "nvjet_* (cuBLASLt)"
      if "wvsplitk" in nl or name.startswith("void wvSplitK"):
          return Op.DENSE_GEMM, "wvSplitK_* (hipBLASLt)"
-
-     # CUTLASS GEMM variants
-     if "cutlass" in nl and ("sgemm" in nl or "gemm" in nl or "cutlass3x" in name.lower()):
-         return Op.DENSE_GEMM, "cutlass*_gemm"
-
-     # GEMV (matrix-vector) operations - treat as GEMM
-     if "gemv" in nl or "gemvx" in nl:
-         return Op.DENSE_GEMM, "gemv*"
-
-     # Generic GEMM patterns
-     if "gemmsn" in nl or name.startswith("void gemmSN"):
-         return Op.DENSE_GEMM, "gemmSN_*"
  
      # Triton fused operations - very common
      if "triton_poi" in nl or "triton_red" in nl or "triton_per" in nl:
@@ -372,21 +95,9 @@ def classify(name: str, platform: str) -> tuple[Op, str]:
              return Op.TRITON_FUSED, "triton_*_silu"
          return Op.TRITON_FUSED, "triton_*"
  
-     # SoftMax operations
-     if "softmax" in nl:
-         return Op.SOFTMAX, "softmax_*"
-
-     # Reduce operations - catch more patterns
-     if "reduce" in nl:
-         if "reduce_segments" in nl or "devicereduce" in nl:
-             return Op.REDUCE, "reduce_segments"
-         if "reduce_kernel" in nl:
-             return Op.REDUCE, "reduce_kernel"
-         return Op.REDUCE, "reduce_*"
-
-     # Scan operations (library internals - similar to reduce)
-     if "scan" in nl and ("cub::" in name or "rocprim::" in name or "at_cuda_detail::cub::" in name):
-         return Op.REDUCE, "cub/rocprim::scan_*"
+     # PyTorch native operations
+     if "at::native::" in name:
+         return Op.ELEMENTWISE, "at::native::*"
  
      # Sorting operations (common in sampling/topk)
      if "sort" in nl or "radixsort" in nl or "merge" in nl:
@@ -395,31 +106,21 @@ def classify(name: str, platform: str) -> tuple[Op, str]:
          else:
              return Op.SORTING, "cub::DeviceRadixSort*"
  
-     # Indexing/Scatter/Gather operations
-     if any(x in nl for x in ["indices", "scatter", "gather", "index_select", "embedding"]):
-         return Op.INDEXING, "index/scatter_*"
-
+     # Reduce operations
+     if "reduce" in nl and ("reduce_segments" in nl or "devicereduce" in nl or "devicescan" in nl):
+         if platform == "AMD":
+             return Op.REDUCE, "reduce_segments"
+         else:
+             return Op.REDUCE, "cub::DeviceReduce*"
+
      # Memory copy operations
      if "copy" in nl or "memcpy" in nl or "_copy_page_indices" in nl:
          return Op.COPY_MEMORY, "copy_*"
  
-     # PyTorch native operations (catch-all for at::native)
-     if "at::native::" in name:
-         # Try to be more specific
-         if "fill" in nl:
-             return Op.ELEMENTWISE, "at::native::fill_*"
-         return Op.ELEMENTWISE, "at::native::*"
-
      # ROCm/CUDA library kernels (other)
      if "rocprim::" in name or "cub::" in name:
          return Op.OTHER, "rocprim/cub_*"
  
-     # Fallback: Heuristic fusion detection for unknown Triton kernels
-     # If a kernel has multiple operation keywords, it's likely fused
-     heuristic_result = _detect_heuristic_fusion(name)
-     if heuristic_result:
-         return heuristic_result
-
      return Op.OTHER, name[:40]
  
  
@@ -469,7 +170,7 @@ def classify_kernel(name: str) -> str:
          return "Triton_Persistent"
  
      # Reduce operations
-     if "reduce" in nl:
+     if "reduce_segments" in nl or "devicereduce" in nl:
          return "Reduce"
  
      # Sort operations
@@ -479,10 +180,6 @@ def classify_kernel(name: str) -> str:
      # Softmax
      if "softmax" in nl:
          return "Softmax"
-
-     # Indexing/Scatter/Gather
-     if any(x in nl for x in ["indices", "scatter", "gather", "index_select", "embedding"]):
-         return "Indexing"
  
      # Elementwise operations
      if any(x in nl for x in ["elementwise", "unrolled_elementwise"]):