wafer-core 0.1.27 → 0.1.29 (py3-none-any.whl)

@@ -214,21 +214,28 @@ def align_kernels_within_layer(
     # The platform that HAS the kernel IS fusing; the other runs components separately
     is_fused_op = "+" in op_str
 
+    # Operations that can't be "fused away" - absence means alignment issue, not fusion
+    non_fusable_ops = {
+        "Attention (Prefill)", "Attention (Decode)", "Dense GEMM",
+        "KV Cache", "MoE GEMM", "MoE Routing"
+    }
+    is_non_fusable = op_str in non_fusable_ops
+
     fusion_note = None
     if amd_count > 0 and nvidia_count == 0:
         if is_fused_op:
             # AMD has a fused kernel like "RMSNorm+GEMM" → AMD IS fusing
             fusion_note = f"AMD fuses {op_str} into {amd_kernel_name}"
-        else:
-            # AMD has a regular kernel that NVIDIA doesn't need → NVIDIA fuses it elsewhere
-            fusion_note = f"AMD runs {amd_kernel_name}, NVIDIA fuses into another kernel"
+        elif not is_non_fusable:
+            # Only mark as fusion for ops that can legitimately be fused
+            fusion_note = f"AMD runs {amd_kernel_name}, NVIDIA may fuse into another kernel"
     elif amd_count == 0 and nvidia_count > 0:
         if is_fused_op:
             # NVIDIA has a fused kernel → NVIDIA IS fusing
             fusion_note = f"NVIDIA fuses {op_str} into {nvidia_kernel_name}"
-        else:
-            # NVIDIA has a regular kernel that AMD doesn't need → AMD fuses it elsewhere
-            fusion_note = f"NVIDIA runs {nvidia_kernel_name}, AMD fuses into another kernel"
+        elif not is_non_fusable:
+            # Only mark as fusion for ops that can legitimately be fused
+            fusion_note = f"NVIDIA runs {nvidia_kernel_name}, AMD may fuse into another kernel"
     elif amd_count > nvidia_count * 1.5 and nvidia_count > 0:
         # AMD runs more kernels = NVIDIA is fusing some
         fusion_note = f"AMD runs {amd_kernel_name} {amd_count / nvidia_count:.1f}x more → NVIDIA fuses"
@@ -429,13 +429,22 @@ def analyze_traces_aligned(
         "kernel_pairs": kernel_pairs,
     })
 
-    fusion_result = analyze_fusion_from_alignment(alignment.layer_alignments)
-    same_kernel_result = analyze_same_kernels_from_alignment(alignment.layer_alignments)
-
+    # Determine which trace is AMD vs NVIDIA for fusion analysis
     if trace1.platform == "AMD":
         amd_trace, nvidia_trace = trace1, trace2
+        fusion_amd_kernels = amd_kernels
+        fusion_nvidia_kernels = nvidia_kernels
     else:
         amd_trace, nvidia_trace = trace2, trace1
+        fusion_amd_kernels = nvidia_kernels
+        fusion_nvidia_kernels = amd_kernels
+
+    fusion_result = analyze_fusion_from_alignment(
+        alignment.layer_alignments,
+        amd_kernels=fusion_amd_kernels,
+        nvidia_kernels=fusion_nvidia_kernels,
+    )
+    same_kernel_result = analyze_same_kernels_from_alignment(alignment.layer_alignments)
 
     return {
         "metadata": {
@@ -24,12 +24,16 @@ class Op(Enum):
 
     ATTN_PREFILL = "Attention (Prefill)"
     ATTN_DECODE = "Attention (Decode)"
+    # NVIDIA Flash Attention fuses QKV projection + Softmax + Attention
+    FLASH_ATTN_FUSED = "FlashAttention (QKV+Softmax+Attn)"
     KV_CACHE = "KV Cache"
     MOE_ROUTING = "MoE Routing"
     MOE_GEMM = "MoE GEMM"
     MOE_GEMM_SWIGLU = "MoE GEMM+SwiGLU"
     MOE_FINALIZE = "MoE Finalize"
     DENSE_GEMM = "Dense GEMM"
+    # NVIDIA cuBLASLt/CUTLASS can fuse GEMM with epilogue (bias + activation)
+    GEMM_BIAS_ACT = "GEMM+Bias+Activation"
     RMSNORM = "RMSNorm"
     RMSNORM_GEMM = "RMSNorm+GEMM"
     SWIGLU = "SwiGLU"
@@ -274,16 +278,20 @@ def classify(name: str, platform: str) -> tuple[Op, str]:
         if "3d" in nl:
             return Op.ATTN_DECODE, "kernel_unified_attention_3d"
         else:
-            # NVIDIA uses fmhaSm100 with 'a' (prefill/context) and 'f' (decode/forgen)
-            if "fmhasm100a" in nl or "context" in nl:
-                return Op.ATTN_PREFILL, "fmhaSm100a*_Context"
-            if "fmhasm100f" in nl or "forgen" in nl:
-                return Op.ATTN_DECODE, "fmhaSm100f*_ForGen"
+            # NVIDIA Flash Attention (fmhaSm100*) is a fused kernel
+            # It fuses QKV projection + Softmax + Attention into one kernel
+            if "fmhasm100" in nl:
+                if "fmhasm100a" in nl or "context" in nl:
+                    return Op.FLASH_ATTN_FUSED, "fmhaSm100a*_Context (QKV+Softmax+Attn)"
+                if "fmhasm100f" in nl or "forgen" in nl:
+                    return Op.FLASH_ATTN_FUSED, "fmhaSm100f*_ForGen (QKV+Softmax+Attn)"
+                return Op.FLASH_ATTN_FUSED, "fmhaSm100* (QKV+Softmax+Attn)"
             return Op.ATTN_PREFILL, name[:40]
 
-    # Flash Attention variants (vLLM)
+    # Flash Attention variants (vLLM) - these are fused on NVIDIA
     if "flash::flash_fwd_kernel" in name or "flash_fwd" in nl:
-        # Could distinguish prefill/decode if needed, defaulting to prefill
+        if platform != "AMD":
+            return Op.FLASH_ATTN_FUSED, "flash::flash_fwd_kernel (QKV+Softmax+Attn)"
         return Op.ATTN_PREFILL, "flash::flash_fwd_kernel"
 
     if "reshape_and_cache" in nl:
@@ -306,9 +314,10 @@ def classify(name: str, platform: str) -> tuple[Op, str]:
     if "moe_sum" in nl:
         return Op.MOE_FINALIZE, "vllm::moe::moe_sum_*"
 
-    # vLLM act_and_mul (can be mangled C++ name)
+    # vLLM act_and_mul - fuses activation with element-wise multiply (SiLU * x)
+    # This is a fused operation used in SwiGLU/MoE
    if "vllm::act_and_mul_kernel" in name or ("act_and_mul_kernel" in nl and "vllm" in nl):
-        return Op.MOE_GEMM_SWIGLU, "vllm::act_and_mul_kernel"
+        return Op.SWIGLU_GEMM, "vllm::act_and_mul_kernel (SwiGLU+Mul)"
 
     if "_matmul_ogs_" in nl:
         if "swiglu" in nl: