wafer-core 0.1.27__py3-none-any.whl → 0.1.29__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer_core/lib/trace_compare/aligner.py +13 -6
- wafer_core/lib/trace_compare/analyzer.py +12 -3
- wafer_core/lib/trace_compare/classifier.py +18 -9
- wafer_core/lib/trace_compare/fusion_analyzer.py +424 -275
- wafer_core/targets/__init__.py +47 -21
- wafer_core/targets/pool.py +181 -0
- wafer_core/targets/probe.py +113 -0
- wafer_core/targets/providers/__init__.py +46 -0
- wafer_core/targets/providers/baremetal.py +72 -0
- wafer_core/targets/providers/digitalocean.py +164 -0
- wafer_core/targets/providers/runpod.py +250 -0
- wafer_core/targets/reconcile.py +90 -0
- wafer_core/targets/spec_store.py +200 -0
- wafer_core/targets/state_cache.py +150 -0
- wafer_core/targets/types.py +141 -0
- wafer_core/utils/kernel_utils/targets/config.py +8 -24
- {wafer_core-0.1.27.dist-info → wafer_core-0.1.29.dist-info}/METADATA +1 -1
- {wafer_core-0.1.27.dist-info → wafer_core-0.1.29.dist-info}/RECORD +19 -9
- {wafer_core-0.1.27.dist-info → wafer_core-0.1.29.dist-info}/WHEEL +0 -0

wafer_core/lib/trace_compare/aligner.py

@@ -214,21 +214,28 @@ def align_kernels_within_layer(
     # The platform that HAS the kernel IS fusing; the other runs components separately
     is_fused_op = "+" in op_str
 
+    # Operations that can't be "fused away" - absence means alignment issue, not fusion
+    non_fusable_ops = {
+        "Attention (Prefill)", "Attention (Decode)", "Dense GEMM",
+        "KV Cache", "MoE GEMM", "MoE Routing"
+    }
+    is_non_fusable = op_str in non_fusable_ops
+
     fusion_note = None
     if amd_count > 0 and nvidia_count == 0:
         if is_fused_op:
             # AMD has a fused kernel like "RMSNorm+GEMM" → AMD IS fusing
             fusion_note = f"AMD fuses {op_str} into {amd_kernel_name}"
-
-            #
-            fusion_note = f"AMD runs {amd_kernel_name}, NVIDIA
+        elif not is_non_fusable:
+            # Only mark as fusion for ops that can legitimately be fused
+            fusion_note = f"AMD runs {amd_kernel_name}, NVIDIA may fuse into another kernel"
     elif amd_count == 0 and nvidia_count > 0:
         if is_fused_op:
             # NVIDIA has a fused kernel → NVIDIA IS fusing
             fusion_note = f"NVIDIA fuses {op_str} into {nvidia_kernel_name}"
-
-            #
-            fusion_note = f"NVIDIA runs {nvidia_kernel_name}, AMD
+        elif not is_non_fusable:
+            # Only mark as fusion for ops that can legitimately be fused
+            fusion_note = f"NVIDIA runs {nvidia_kernel_name}, AMD may fuse into another kernel"
     elif amd_count > nvidia_count * 1.5 and nvidia_count > 0:
         # AMD runs more kernels = NVIDIA is fusing some
         fusion_note = f"AMD runs {amd_kernel_name} {amd_count / nvidia_count:.1f}x more → NVIDIA fuses"
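
As a reading aid, here is the new decision rule from this hunk factored into a standalone sketch (one-sided branches only; the `amd_count > nvidia_count * 1.5` branch is unchanged). The wrapper function, constant name, and the example inputs in the asserts are illustrative, not wafer_core APIs.

```python
# Illustrative sketch of the aligner's new fusion-note rule; not a wafer_core API.
NON_FUSABLE_OPS = {
    "Attention (Prefill)", "Attention (Decode)", "Dense GEMM",
    "KV Cache", "MoE GEMM", "MoE Routing",
}

def fusion_note_for(op_str: str, amd_count: int, nvidia_count: int,
                    amd_kernel_name: str, nvidia_kernel_name: str) -> str | None:
    is_fused_op = "+" in op_str                  # e.g. "RMSNorm+GEMM"
    is_non_fusable = op_str in NON_FUSABLE_OPS   # absence = alignment issue, not fusion
    if amd_count > 0 and nvidia_count == 0:
        if is_fused_op:
            return f"AMD fuses {op_str} into {amd_kernel_name}"
        if not is_non_fusable:
            return f"AMD runs {amd_kernel_name}, NVIDIA may fuse into another kernel"
    elif amd_count == 0 and nvidia_count > 0:
        if is_fused_op:
            return f"NVIDIA fuses {op_str} into {nvidia_kernel_name}"
        if not is_non_fusable:
            return f"NVIDIA runs {nvidia_kernel_name}, AMD may fuse into another kernel"
    return None

# A missing non-fusable op no longer produces a misleading fusion note:
assert fusion_note_for("Dense GEMM", 4, 0, "amd_gemm_kernel", "") is None
# A missing fusable op still does:
assert fusion_note_for("RMSNorm", 2, 0, "rmsnorm_kernel", "") == \
    "AMD runs rmsnorm_kernel, NVIDIA may fuse into another kernel"
```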

wafer_core/lib/trace_compare/analyzer.py

@@ -429,13 +429,22 @@ def analyze_traces_aligned(
             "kernel_pairs": kernel_pairs,
         })
 
-
-    same_kernel_result = analyze_same_kernels_from_alignment(alignment.layer_alignments)
-
+    # Determine which trace is AMD vs NVIDIA for fusion analysis
     if trace1.platform == "AMD":
         amd_trace, nvidia_trace = trace1, trace2
+        fusion_amd_kernels = amd_kernels
+        fusion_nvidia_kernels = nvidia_kernels
     else:
         amd_trace, nvidia_trace = trace2, trace1
+        fusion_amd_kernels = nvidia_kernels
+        fusion_nvidia_kernels = amd_kernels
+
+    fusion_result = analyze_fusion_from_alignment(
+        alignment.layer_alignments,
+        amd_kernels=fusion_amd_kernels,
+        nvidia_kernels=fusion_nvidia_kernels,
+    )
+    same_kernel_result = analyze_same_kernels_from_alignment(alignment.layer_alignments)
 
     return {
         "metadata": {
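
The point of this reshuffle is that `analyze_fusion_from_alignment` always receives the AMD trace's kernels under `amd_kernels`, regardless of which trace was passed first. A minimal sketch of that swap, assuming only a trace object with a `platform` attribute (the `_Trace` stand-in, helper name, and sample kernel names below are hypothetical):

```python
from dataclasses import dataclass

@dataclass
class _Trace:          # stand-in for the real trace object; only `platform` is assumed
    platform: str      # "AMD" or "NVIDIA"

def order_kernels_for_fusion(trace1, trace2, trace1_kernels, trace2_kernels):
    """Return (amd_kernels, nvidia_kernels) regardless of which trace came first."""
    if trace1.platform == "AMD":
        return trace1_kernels, trace2_kernels
    return trace2_kernels, trace1_kernels

# trace1 is the NVIDIA run here, so the kernel lists are swapped before the call
amd_k, nvidia_k = order_kernels_for_fusion(
    _Trace("NVIDIA"), _Trace("AMD"), ["nvidia_kernel_a"], ["amd_kernel_b"]
)
assert (amd_k, nvidia_k) == (["amd_kernel_b"], ["nvidia_kernel_a"])
```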

wafer_core/lib/trace_compare/classifier.py

@@ -24,12 +24,16 @@ class Op(Enum):
 
     ATTN_PREFILL = "Attention (Prefill)"
     ATTN_DECODE = "Attention (Decode)"
+    # NVIDIA Flash Attention fuses QKV projection + Softmax + Attention
+    FLASH_ATTN_FUSED = "FlashAttention (QKV+Softmax+Attn)"
     KV_CACHE = "KV Cache"
     MOE_ROUTING = "MoE Routing"
     MOE_GEMM = "MoE GEMM"
     MOE_GEMM_SWIGLU = "MoE GEMM+SwiGLU"
     MOE_FINALIZE = "MoE Finalize"
     DENSE_GEMM = "Dense GEMM"
+    # NVIDIA cuBLASLt/CUTLASS can fuse GEMM with epilogue (bias + activation)
+    GEMM_BIAS_ACT = "GEMM+Bias+Activation"
     RMSNORM = "RMSNorm"
     RMSNORM_GEMM = "RMSNorm+GEMM"
     SWIGLU = "SwiGLU"
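
One detail worth noting: both new enum values contain a "+", so they are picked up by the aligner's `is_fused_op = "+" in op_str` check from the first hunk. A quick check against a trimmed copy of the enum (for illustration only):

```python
from enum import Enum

class Op(Enum):                      # trimmed copy, for illustration only
    ATTN_PREFILL = "Attention (Prefill)"
    FLASH_ATTN_FUSED = "FlashAttention (QKV+Softmax+Attn)"
    GEMM_BIAS_ACT = "GEMM+Bias+Activation"

# Both new values contain "+", so the aligner treats them as fused operations.
assert all("+" in op.value for op in (Op.FLASH_ATTN_FUSED, Op.GEMM_BIAS_ACT))
assert "+" not in Op.ATTN_PREFILL.value
```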

@@ -274,16 +278,20 @@ def classify(name: str, platform: str) -> tuple[Op, str]:
         if "3d" in nl:
             return Op.ATTN_DECODE, "kernel_unified_attention_3d"
         else:
-            # NVIDIA
-
-
-
-
+            # NVIDIA Flash Attention (fmhaSm100*) is a fused kernel
+            # It fuses QKV projection + Softmax + Attention into one kernel
+            if "fmhasm100" in nl:
+                if "fmhasm100a" in nl or "context" in nl:
+                    return Op.FLASH_ATTN_FUSED, "fmhaSm100a*_Context (QKV+Softmax+Attn)"
+                if "fmhasm100f" in nl or "forgen" in nl:
+                    return Op.FLASH_ATTN_FUSED, "fmhaSm100f*_ForGen (QKV+Softmax+Attn)"
+                return Op.FLASH_ATTN_FUSED, "fmhaSm100* (QKV+Softmax+Attn)"
             return Op.ATTN_PREFILL, name[:40]
 
-    # Flash Attention variants (vLLM)
+    # Flash Attention variants (vLLM) - these are fused on NVIDIA
     if "flash::flash_fwd_kernel" in name or "flash_fwd" in nl:
-
+        if platform != "AMD":
+            return Op.FLASH_ATTN_FUSED, "flash::flash_fwd_kernel (QKV+Softmax+Attn)"
         return Op.ATTN_PREFILL, "flash::flash_fwd_kernel"
 
     if "reshape_and_cache" in nl:
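
The matching rules added for the fmhaSm100 family, pulled out into a standalone sketch. The real `classify()` handles many more kernel families and returns `Op` members; plain strings and the hypothetical sample kernel names are used here only to keep the snippet self-contained.

```python
def classify_fmha(name: str) -> tuple[str, str] | None:
    """Reproduce only the fmhaSm100 branch added in this hunk."""
    nl = name.lower()
    if "fmhasm100" not in nl:
        return None
    if "fmhasm100a" in nl or "context" in nl:
        return "FLASH_ATTN_FUSED", "fmhaSm100a*_Context (QKV+Softmax+Attn)"
    if "fmhasm100f" in nl or "forgen" in nl:
        return "FLASH_ATTN_FUSED", "fmhaSm100f*_ForGen (QKV+Softmax+Attn)"
    return "FLASH_ATTN_FUSED", "fmhaSm100* (QKV+Softmax+Attn)"

# Hypothetical kernel names, just to exercise the branches:
assert classify_fmha("fmhaSm100aContextKernel")[1].startswith("fmhaSm100a*_Context")
assert classify_fmha("fmhaSm100fForGenKernel")[1].startswith("fmhaSm100f*_ForGen")
assert classify_fmha("some_other_kernel") is None
```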

@@ -306,9 +314,10 @@ def classify(name: str, platform: str) -> tuple[Op, str]:
     if "moe_sum" in nl:
         return Op.MOE_FINALIZE, "vllm::moe::moe_sum_*"
 
-    # vLLM act_and_mul
+    # vLLM act_and_mul - fuses activation with element-wise multiply (SiLU * x)
+    # This is a fused operation used in SwiGLU/MoE
    if "vllm::act_and_mul_kernel" in name or ("act_and_mul_kernel" in nl and "vllm" in nl):
-        return Op.
+        return Op.SWIGLU_GEMM, "vllm::act_and_mul_kernel (SwiGLU+Mul)"
 
     if "_matmul_ogs_" in nl:
         if "swiglu" in nl:
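
The act_and_mul match accepts either the fully qualified `vllm::act_and_mul_kernel` or a lowercased name that mentions both `act_and_mul_kernel` and `vllm`. Pulled out for illustration; the helper name and sample strings below are hypothetical:

```python
def is_act_and_mul(name: str) -> bool:
    """The substring condition added in this hunk, in isolation."""
    nl = name.lower()
    return "vllm::act_and_mul_kernel" in name or (
        "act_and_mul_kernel" in nl and "vllm" in nl
    )

assert is_act_and_mul("void vllm::act_and_mul_kernel<float>(...)")
assert is_act_and_mul("VLLM act_and_mul_kernel variant")  # case-insensitive fallback
assert not is_act_and_mul("act_and_mul_kernel")           # no vllm marker
```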