wafer-core 0.1.27__py3-none-any.whl → 0.1.28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer_core/lib/trace_compare/aligner.py +13 -6
- wafer_core/lib/trace_compare/analyzer.py +12 -3
- wafer_core/lib/trace_compare/fusion_analyzer.py +392 -284
- wafer_core/targets/__init__.py +47 -21
- wafer_core/targets/pool.py +181 -0
- wafer_core/targets/probe.py +113 -0
- wafer_core/targets/providers/__init__.py +46 -0
- wafer_core/targets/providers/baremetal.py +72 -0
- wafer_core/targets/providers/digitalocean.py +164 -0
- wafer_core/targets/providers/runpod.py +250 -0
- wafer_core/targets/reconcile.py +90 -0
- wafer_core/targets/spec_store.py +200 -0
- wafer_core/targets/state_cache.py +150 -0
- wafer_core/targets/types.py +141 -0
- wafer_core/utils/kernel_utils/targets/config.py +8 -24
- {wafer_core-0.1.27.dist-info → wafer_core-0.1.28.dist-info}/METADATA +1 -1
- {wafer_core-0.1.27.dist-info → wafer_core-0.1.28.dist-info}/RECORD +18 -8
- {wafer_core-0.1.27.dist-info → wafer_core-0.1.28.dist-info}/WHEEL +0 -0
|
@@ -214,21 +214,28 @@ def align_kernels_within_layer(
|
|
|
214
214
|
# The platform that HAS the kernel IS fusing; the other runs components separately
|
|
215
215
|
is_fused_op = "+" in op_str
|
|
216
216
|
|
|
217
|
+
# Operations that can't be "fused away" - absence means alignment issue, not fusion
|
|
218
|
+
non_fusable_ops = {
|
|
219
|
+
"Attention (Prefill)", "Attention (Decode)", "Dense GEMM",
|
|
220
|
+
"KV Cache", "MoE GEMM", "MoE Routing"
|
|
221
|
+
}
|
|
222
|
+
is_non_fusable = op_str in non_fusable_ops
|
|
223
|
+
|
|
217
224
|
fusion_note = None
|
|
218
225
|
if amd_count > 0 and nvidia_count == 0:
|
|
219
226
|
if is_fused_op:
|
|
220
227
|
# AMD has a fused kernel like "RMSNorm+GEMM" → AMD IS fusing
|
|
221
228
|
fusion_note = f"AMD fuses {op_str} into {amd_kernel_name}"
|
|
222
|
-
|
|
223
|
-
#
|
|
224
|
-
fusion_note = f"AMD runs {amd_kernel_name}, NVIDIA
|
|
229
|
+
elif not is_non_fusable:
|
|
230
|
+
# Only mark as fusion for ops that can legitimately be fused
|
|
231
|
+
fusion_note = f"AMD runs {amd_kernel_name}, NVIDIA may fuse into another kernel"
|
|
225
232
|
elif amd_count == 0 and nvidia_count > 0:
|
|
226
233
|
if is_fused_op:
|
|
227
234
|
# NVIDIA has a fused kernel → NVIDIA IS fusing
|
|
228
235
|
fusion_note = f"NVIDIA fuses {op_str} into {nvidia_kernel_name}"
|
|
229
|
-
|
|
230
|
-
#
|
|
231
|
-
fusion_note = f"NVIDIA runs {nvidia_kernel_name}, AMD
|
|
236
|
+
elif not is_non_fusable:
|
|
237
|
+
# Only mark as fusion for ops that can legitimately be fused
|
|
238
|
+
fusion_note = f"NVIDIA runs {nvidia_kernel_name}, AMD may fuse into another kernel"
|
|
232
239
|
elif amd_count > nvidia_count * 1.5 and nvidia_count > 0:
|
|
233
240
|
# AMD runs more kernels = NVIDIA is fusing some
|
|
234
241
|
fusion_note = f"AMD runs {amd_kernel_name} {amd_count / nvidia_count:.1f}x more → NVIDIA fuses"
|
|
@@ -429,13 +429,22 @@ def analyze_traces_aligned(
|
|
|
429
429
|
"kernel_pairs": kernel_pairs,
|
|
430
430
|
})
|
|
431
431
|
|
|
432
|
-
|
|
433
|
-
same_kernel_result = analyze_same_kernels_from_alignment(alignment.layer_alignments)
|
|
434
|
-
|
|
432
|
+
# Determine which trace is AMD vs NVIDIA for fusion analysis
|
|
435
433
|
if trace1.platform == "AMD":
|
|
436
434
|
amd_trace, nvidia_trace = trace1, trace2
|
|
435
|
+
fusion_amd_kernels = amd_kernels
|
|
436
|
+
fusion_nvidia_kernels = nvidia_kernels
|
|
437
437
|
else:
|
|
438
438
|
amd_trace, nvidia_trace = trace2, trace1
|
|
439
|
+
fusion_amd_kernels = nvidia_kernels
|
|
440
|
+
fusion_nvidia_kernels = amd_kernels
|
|
441
|
+
|
|
442
|
+
fusion_result = analyze_fusion_from_alignment(
|
|
443
|
+
alignment.layer_alignments,
|
|
444
|
+
amd_kernels=fusion_amd_kernels,
|
|
445
|
+
nvidia_kernels=fusion_nvidia_kernels,
|
|
446
|
+
)
|
|
447
|
+
same_kernel_result = analyze_same_kernels_from_alignment(alignment.layer_alignments)
|
|
439
448
|
|
|
440
449
|
return {
|
|
441
450
|
"metadata": {
|