wafer-core 0.1.27__py3-none-any.whl → 0.1.28__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -214,21 +214,28 @@ def align_kernels_within_layer(
214
214
  # The platform that HAS the kernel IS fusing; the other runs components separately
215
215
  is_fused_op = "+" in op_str
216
216
 
217
+ # Operations that can't be "fused away" - absence means alignment issue, not fusion
218
+ non_fusable_ops = {
219
+ "Attention (Prefill)", "Attention (Decode)", "Dense GEMM",
220
+ "KV Cache", "MoE GEMM", "MoE Routing"
221
+ }
222
+ is_non_fusable = op_str in non_fusable_ops
223
+
217
224
  fusion_note = None
218
225
  if amd_count > 0 and nvidia_count == 0:
219
226
  if is_fused_op:
220
227
  # AMD has a fused kernel like "RMSNorm+GEMM" → AMD IS fusing
221
228
  fusion_note = f"AMD fuses {op_str} into {amd_kernel_name}"
222
- else:
223
- # AMD has a regular kernel that NVIDIA doesn't need → NVIDIA fuses it elsewhere
224
- fusion_note = f"AMD runs {amd_kernel_name}, NVIDIA fuses into another kernel"
229
+ elif not is_non_fusable:
230
+ # Only mark as fusion for ops that can legitimately be fused
231
+ fusion_note = f"AMD runs {amd_kernel_name}, NVIDIA may fuse into another kernel"
225
232
  elif amd_count == 0 and nvidia_count > 0:
226
233
  if is_fused_op:
227
234
  # NVIDIA has a fused kernel → NVIDIA IS fusing
228
235
  fusion_note = f"NVIDIA fuses {op_str} into {nvidia_kernel_name}"
229
- else:
230
- # NVIDIA has a regular kernel that AMD doesn't need → AMD fuses it elsewhere
231
- fusion_note = f"NVIDIA runs {nvidia_kernel_name}, AMD fuses into another kernel"
236
+ elif not is_non_fusable:
237
+ # Only mark as fusion for ops that can legitimately be fused
238
+ fusion_note = f"NVIDIA runs {nvidia_kernel_name}, AMD may fuse into another kernel"
232
239
  elif amd_count > nvidia_count * 1.5 and nvidia_count > 0:
233
240
  # AMD runs more kernels = NVIDIA is fusing some
234
241
  fusion_note = f"AMD runs {amd_kernel_name} {amd_count / nvidia_count:.1f}x more → NVIDIA fuses"
@@ -429,13 +429,22 @@ def analyze_traces_aligned(
429
429
  "kernel_pairs": kernel_pairs,
430
430
  })
431
431
 
432
- fusion_result = analyze_fusion_from_alignment(alignment.layer_alignments)
433
- same_kernel_result = analyze_same_kernels_from_alignment(alignment.layer_alignments)
434
-
432
+ # Determine which trace is AMD vs NVIDIA for fusion analysis
435
433
  if trace1.platform == "AMD":
436
434
  amd_trace, nvidia_trace = trace1, trace2
435
+ fusion_amd_kernels = amd_kernels
436
+ fusion_nvidia_kernels = nvidia_kernels
437
437
  else:
438
438
  amd_trace, nvidia_trace = trace2, trace1
439
+ fusion_amd_kernels = nvidia_kernels
440
+ fusion_nvidia_kernels = amd_kernels
441
+
442
+ fusion_result = analyze_fusion_from_alignment(
443
+ alignment.layer_alignments,
444
+ amd_kernels=fusion_amd_kernels,
445
+ nvidia_kernels=fusion_nvidia_kernels,
446
+ )
447
+ same_kernel_result = analyze_same_kernels_from_alignment(alignment.layer_alignments)
439
448
 
440
449
  return {
441
450
  "metadata": {