wafer-core 0.1.28__py3-none-any.whl → 0.1.30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -145,7 +145,17 @@ def analyze_traces_from_loaded(
  trace2_total = trace2_agg["total_us"] / 1000
  trace1_count = int(trace1_agg["count"])
  trace2_count = int(trace2_agg["count"])
- ratio = trace1_avg / trace2_avg if trace2_avg > 0 else 1
+ # Speedup: ratio of total times (not per-call averages)
+ # Shows how many times faster/slower trace1 is compared to trace2
+ # > 1.0 means trace1 is slower, < 1.0 means trace1 is faster
+ # Using total time instead of avg time per call because operations may have
+ # vastly different call counts (e.g., fused vs unfused operations)
+ if trace2_total > 0:
+     ratio = trace1_total / trace2_total
+ elif trace1_total > 0:
+     ratio = float("inf") # trace2 has no time, trace1 is infinitely slower
+ else:
+     ratio = 1.0 # Both are zero
  gap_ms = trace1_total - trace2_total

  trace1_pattern = list(
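The replaced one-liner compared per-call averages; the new branches compare totals and handle zero denominators explicitly. As a self-contained sketch (the standalone function name is mine, not the package's), the added logic, which reappears below in analyze_traces_aligned, behaves like:

def total_time_ratio(trace1_total: float, trace2_total: float) -> float:
    """Ratio of trace1's total time to trace2's; > 1.0 means trace1 is slower."""
    if trace2_total > 0:
        return trace1_total / trace2_total
    if trace1_total > 0:
        return float("inf")  # trace2 recorded no time at all
    return 1.0  # both totals are zero: treat as parity

# total_time_ratio(12.0, 4.0) == 3.0   (trace1 is 3x slower)
# total_time_ratio(0.0, 0.0) == 1.0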
@@ -446,6 +456,11 @@ def analyze_traces_aligned(
  )
  same_kernel_result = analyze_same_kernels_from_alignment(alignment.layer_alignments)

+ # Note: amd_kernels = trace1's kernels (filtered if phase_filter != "all")
+ # nvidia_kernels = trace2's kernels (filtered if phase_filter != "all")
+ # The variable names are misleading but trace1_* should use amd_kernels,
+ # and trace2_* should use nvidia_kernels to match the filtered kernel counts/totals.
+
  return {
      "metadata": {
          "amd_gpu": amd_trace.gpu_name,
@@ -462,10 +477,10 @@ def analyze_traces_aligned(
  "trace2_platform": trace2.platform,
  "trace2_gpu": trace2.gpu_name,
  "trace2_device": trace2.device_props,
- "trace1_kernels": len(amd_trace.kernel_events),
- "trace2_kernels": len(nvidia_trace.kernel_events),
- "trace1_total_ms": sum(k.get("dur", 0) for k in amd_trace.kernel_events) / 1000,
- "trace2_total_ms": sum(k.get("dur", 0) for k in nvidia_trace.kernel_events) / 1000,
+ "trace1_kernels": len(amd_kernels),
+ "trace2_kernels": len(nvidia_kernels),
+ "trace1_total_ms": sum(k.get("dur", 0) for k in amd_kernels) / 1000,
+ "trace2_total_ms": sum(k.get("dur", 0) for k in nvidia_kernels) / 1000,
  "phase": phase_filter,
  "trace1_layers": alignment.num_layers,
  "trace2_layers": alignment.num_layers,
@@ -579,7 +594,17 @@ def analyze_traces_aligned(
  trace2_total = trace2_agg["total_us"] / 1000
  trace1_count = int(trace1_agg["count"])
  trace2_count = int(trace2_agg["count"])
- ratio = trace1_avg / trace2_avg if trace2_avg > 0 else 1
+ # Speedup: ratio of total times (not per-call averages)
+ # Shows how many times faster/slower trace1 is compared to trace2
+ # > 1.0 means trace1 is slower, < 1.0 means trace1 is faster
+ # Using total time instead of avg time per call because operations may have
+ # vastly different call counts (e.g., fused vs unfused operations)
+ if trace2_total > 0:
+     ratio = trace1_total / trace2_total
+ elif trace1_total > 0:
+     ratio = float("inf") # trace2 has no time, trace1 is infinitely slower
+ else:
+     ratio = 1.0 # Both are zero
  gap_ms = trace1_total - trace2_total

  trace1_pattern = list(
@@ -24,12 +24,16 @@ class Op(Enum):

  ATTN_PREFILL = "Attention (Prefill)"
  ATTN_DECODE = "Attention (Decode)"
+ # NVIDIA Flash Attention fuses QKV projection + Softmax + Attention
+ FLASH_ATTN_FUSED = "FlashAttention (QKV+Softmax+Attn)"
  KV_CACHE = "KV Cache"
  MOE_ROUTING = "MoE Routing"
  MOE_GEMM = "MoE GEMM"
  MOE_GEMM_SWIGLU = "MoE GEMM+SwiGLU"
  MOE_FINALIZE = "MoE Finalize"
  DENSE_GEMM = "Dense GEMM"
+ # NVIDIA cuBLASLt/CUTLASS can fuse GEMM with epilogue (bias + activation)
+ GEMM_BIAS_ACT = "GEMM+Bias+Activation"
  RMSNORM = "RMSNorm"
  RMSNORM_GEMM = "RMSNorm+GEMM"
  SWIGLU = "SwiGLU"
@@ -274,16 +278,20 @@ def classify(name: str, platform: str) -> tuple[Op, str]:
  if "3d" in nl:
      return Op.ATTN_DECODE, "kernel_unified_attention_3d"
  else:
-     # NVIDIA uses fmhaSm100 with 'a' (prefill/context) and 'f' (decode/forgen)
-     if "fmhasm100a" in nl or "context" in nl:
-         return Op.ATTN_PREFILL, "fmhaSm100a*_Context"
-     if "fmhasm100f" in nl or "forgen" in nl:
-         return Op.ATTN_DECODE, "fmhaSm100f*_ForGen"
+     # NVIDIA Flash Attention (fmhaSm100*) is a fused kernel
+     # It fuses QKV projection + Softmax + Attention into one kernel
+     if "fmhasm100" in nl:
+         if "fmhasm100a" in nl or "context" in nl:
+             return Op.FLASH_ATTN_FUSED, "fmhaSm100a*_Context (QKV+Softmax+Attn)"
+         if "fmhasm100f" in nl or "forgen" in nl:
+             return Op.FLASH_ATTN_FUSED, "fmhaSm100f*_ForGen (QKV+Softmax+Attn)"
+         return Op.FLASH_ATTN_FUSED, "fmhaSm100* (QKV+Softmax+Attn)"
      return Op.ATTN_PREFILL, name[:40]

- # Flash Attention variants (vLLM)
+ # Flash Attention variants (vLLM) - these are fused on NVIDIA
  if "flash::flash_fwd_kernel" in name or "flash_fwd" in nl:
-     # Could distinguish prefill/decode if needed, defaulting to prefill
+     if platform != "AMD":
+         return Op.FLASH_ATTN_FUSED, "flash::flash_fwd_kernel (QKV+Softmax+Attn)"
      return Op.ATTN_PREFILL, "flash::flash_fwd_kernel"

  if "reshape_and_cache" in nl:
@@ -306,9 +314,10 @@ def classify(name: str, platform: str) -> tuple[Op, str]:
  if "moe_sum" in nl:
      return Op.MOE_FINALIZE, "vllm::moe::moe_sum_*"

- # vLLM act_and_mul (can be mangled C++ name)
+ # vLLM act_and_mul - fuses activation with element-wise multiply (SiLU * x)
+ # This is a fused operation used in SwiGLU/MoE
  if "vllm::act_and_mul_kernel" in name or ("act_and_mul_kernel" in nl and "vllm" in nl):
-     return Op.MOE_GEMM_SWIGLU, "vllm::act_and_mul_kernel"
+     return Op.SWIGLU_GEMM, "vllm::act_and_mul_kernel (SwiGLU+Mul)"

  if "_matmul_ogs_" in nl:
      if "swiglu" in nl:
@@ -82,13 +82,37 @@ def _find_fusion_mappings(
  trace1_only = trace1_type_set - trace2_type_set
  trace2_only = trace2_type_set - trace1_type_set

- # For each unique type in trace1, find common sequence patterns
+ # For each unique type in trace1, check if it's a fused operation
+ # If trace1 has a unique kernel type that trace2 doesn't have, trace1 is likely fusing
  for unique_type in trace1_only:
      # Skip "Other" since it's too generic
      if unique_type == "Other":
          continue

-     # Find all occurrences of this type
+     # If the unique type contains '+', it's explicitly a fused kernel
+     # This means trace1 (which has it) is fusing, not trace2
+     if "+" in unique_type:
+         # Parse components from the fused op name
+         components = [c.strip() for c in unique_type.split("+")]
+         indices = [i for i, t in enumerate(trace1_types) if t == unique_type]
+
+         if len(indices) < 5:
+             continue
+
+         mappings.append({
+             "fused_platform": trace1_name,
+             "fused_kernel_type": unique_type,
+             "fused_count": len(indices),
+             "unfused_platform": trace2_name,
+             "unfused_sequence": components,
+             "unfused_count_per_type": {c: trace2_types.count(c) for c in components},
+             "pattern_count": len(indices),
+             "pattern_confidence": 1.0,
+             "evidence": f"{trace1_name} fuses {' + '.join(components)} into single kernel ({len(indices)} calls), {trace2_name} runs separately",
+         })
+         continue
+
+     # For non-fused unique types, find all occurrences
      indices = [i for i, t in enumerate(trace1_types) if t == unique_type]

      if len(indices) < 5: # Need enough samples to be meaningful
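The new branch keys off the '+' that the classifier puts in fused op names (e.g. 'RMSNorm+GEMM'). A toy, self-contained walk-through of that branch with made-up type lists, showing the mapping dict it emits:

trace1_name, trace2_name = "AMD", "NVIDIA"       # as passed from detect_fusion_patterns
trace1_types = ["RMSNorm+GEMM"] * 6              # trace1 runs the fused kernel
trace2_types = ["RMSNorm"] * 6 + ["GEMM"] * 6    # trace2 runs the pieces separately

unique_type = "RMSNorm+GEMM"                     # present only in trace1
components = [c.strip() for c in unique_type.split("+")]            # ["RMSNorm", "GEMM"]
indices = [i for i, t in enumerate(trace1_types) if t == unique_type]

if len(indices) >= 5:  # same minimum-sample guard as the `< 5: continue` above
    mapping = {
        "fused_platform": trace1_name,
        "fused_kernel_type": unique_type,
        "fused_count": len(indices),                                 # 6
        "unfused_platform": trace2_name,
        "unfused_sequence": components,
        "unfused_count_per_type": {c: trace2_types.count(c) for c in components},  # {"RMSNorm": 6, "GEMM": 6}
        "pattern_confidence": 1.0,
    }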
@@ -106,11 +130,12 @@ def _find_fusion_mappings(
      continue
  most_common_before = max(before_types.items(), key=lambda x: x[1])

- # If there's a strong pattern (>80% of occurrences)
+ # If there's a strong pattern (>80% of occurrences) and trace2 has the preceding type,
+ # trace1 runs them separately while trace2 might fuse
  if most_common_before[1] / len(indices) > 0.8:
      fusion_candidate = most_common_before[0]

-     # Verify trace2 has this type
+     # Verify trace2 has this type but NOT the unique_type
      if fusion_candidate in trace2_type_set:
          trace1_fusion_count = trace1_types.count(fusion_candidate)
          trace2_fusion_count = trace2_types.count(fusion_candidate)
@@ -127,7 +152,7 @@ def _find_fusion_mappings(
          },
          "pattern_count": len(indices),
          "pattern_confidence": most_common_before[1] / len(indices),
-         "evidence": f"{trace1_name} runs {fusion_candidate}+{unique_type} separately, {trace2_name} fuses into {fusion_candidate}",
+         "evidence": f"{trace1_name} runs {fusion_candidate} + {unique_type} separately, {trace2_name} fuses into {fusion_candidate}",
      })

  # Also check trace2-only types
@@ -135,6 +160,28 @@ def _find_fusion_mappings(
  if unique_type == "Other":
      continue

+ # If the unique type contains '+', it's explicitly a fused kernel
+ # This means trace2 (which has it) is fusing, not trace1
+ if "+" in unique_type:
+     components = [c.strip() for c in unique_type.split("+")]
+     indices = [i for i, t in enumerate(trace2_types) if t == unique_type]
+
+     if len(indices) < 5:
+         continue
+
+     mappings.append({
+         "fused_platform": trace2_name,
+         "fused_kernel_type": unique_type,
+         "fused_count": len(indices),
+         "unfused_platform": trace1_name,
+         "unfused_sequence": components,
+         "unfused_count_per_type": {c: trace1_types.count(c) for c in components},
+         "pattern_count": len(indices),
+         "pattern_confidence": 1.0,
+         "evidence": f"{trace2_name} fuses {' + '.join(components)} into single kernel ({len(indices)} calls), {trace1_name} runs separately",
+     })
+     continue
+
  indices = [i for i, t in enumerate(trace2_types) if t == unique_type]

  if len(indices) < 5:
@@ -169,7 +216,7 @@ def _find_fusion_mappings(
          },
          "pattern_count": len(indices),
          "pattern_confidence": most_common_before[1] / len(indices),
-         "evidence": f"{trace2_name} runs {fusion_candidate}+{unique_type} separately, {trace1_name} fuses into {fusion_candidate}",
+         "evidence": f"{trace2_name} runs {fusion_candidate} + {unique_type} separately, {trace1_name} fuses into {fusion_candidate}",
      })

  return mappings
@@ -184,8 +231,15 @@ def _find_count_imbalance_fusions(
  ) -> list[dict]:
      """Find fusions by looking for significant count imbalances.

-     When one platform has significantly more kernel calls of a type (>1.5x),
-     it suggests the other platform fuses those operations.
+     When one platform has significantly more kernel calls of a type (>3x),
+     it MAY suggest the other platform fuses those operations.
+
+     NOTE: This is speculative - count differences can also indicate:
+     - Different algorithmic implementations
+     - Different library choices (cuBLAS vs hipBLAS)
+     - Different optimization strategies
+
+     Only very large imbalances (>3x) with high counts are flagged.
      """
      mappings = []
      trace2_platform = "NVIDIA" if trace1_platform == "AMD" else "AMD"
@@ -201,24 +255,28 @@ def _find_count_imbalance_fusions(
  # Find common types with significant differences
  common_types = set(trace1_counts.keys()) & set(trace2_counts.keys())

+ # Skip types that are likely just implementation differences, not fusion
+ skip_types = {"Reduce", "Copy/Memory", "Sync", "Other", "Elementwise"}
+
  for ktype in common_types:
-     if ktype == "Other":
+     if ktype in skip_types:
          continue

      trace1_count = trace1_counts[ktype]
      trace2_count = trace2_counts[ktype]

-     # Skip trivial counts
-     if trace1_count + trace2_count < 50:
+     # Skip low counts - need significant samples
+     if trace1_count + trace2_count < 200:
          continue

-     # Check if there's a significant imbalance (>1.5x)
+     # Check if there's a very significant imbalance (>3x)
+     # Lower ratios are likely implementation differences, not fusion
      if trace1_count == 0 or trace2_count == 0:
          continue

      ratio = max(trace1_count, trace2_count) / min(trace1_count, trace2_count)

-     if ratio < 1.5:
+     if ratio < 3.0:
          continue

      # Determine which platform has more (unfused) and which has fewer (fused)
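Taken together, the tightened thresholds mean a common kernel type is only flagged when both traces have substantial counts and the ratio between them is at least 3x. A minimal sketch of that filter (the helper name and example counts are mine, not the package's):

def flag_possible_fusion(trace1_count: int, trace2_count: int) -> bool:
    if trace1_count + trace2_count < 200:  # skip low counts - need significant samples
        return False
    if trace1_count == 0 or trace2_count == 0:
        return False
    ratio = max(trace1_count, trace2_count) / min(trace1_count, trace2_count)
    return ratio >= 3.0                    # below 3x is likely an implementation difference

# flag_possible_fusion(600, 150) -> True  (4.0x imbalance, enough samples)
# flag_possible_fusion(300, 250) -> False (only 1.2x: not flagged)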
@@ -242,7 +300,7 @@ def _find_count_imbalance_fusions(
          "unfused_count_per_type": {ktype: unfused_count},
          "pattern_count": unfused_count - fused_count,
          "pattern_confidence": (unfused_count - fused_count) / unfused_count,
-         "evidence": f"{unfused_platform} calls {ktype} {ratio:.1f}x more ({unfused_count} vs {fused_count}), {fused_platform} likely fuses",
+         "evidence": f"{unfused_platform} calls {ktype} {ratio:.1f}x more ({unfused_count} vs {fused_count}) - possible fusion",
      })

  return mappings
@@ -325,12 +383,14 @@ def detect_fusion_patterns(
      amd_kernels: list[dict],
      nvidia_kernels: list[dict],
  ) -> FusionAnalysis:
-     """Detect fusion patterns using pattern-based analysis.
+     """Detect fusion patterns using explicit fused operation detection only.

-     This is the main entry point for fusion detection. It combines:
-     1. Explicit fused operations (kernels classified with '+' in name)
-     2. Sequence pattern analysis (unique kernel types with consistent patterns)
-     3. Count imbalance analysis (one platform has significantly more calls)
+     This approach only reports high-confidence fusions where kernels are
+     explicitly classified as fused (e.g., 'RMSNorm+GEMM', 'SwiGLU+GEMM').
+
+     We intentionally avoid speculative detection (sequence patterns, count
+     imbalances) because these produce too many false positives - count
+     differences are usually due to different implementations, not fusion.

      Args:
          amd_kernels: List of AMD kernel events
@@ -341,38 +401,19 @@ def detect_fusion_patterns(
  """
  all_mappings: list[dict] = []

- # 1. Find explicit fused operations (highest confidence)
+ # Only use explicit fused operations (highest confidence, no false positives)
+ # These are kernels explicitly classified with '+' in their operation type
  explicit_fusions = _find_explicit_fused_operations(
      amd_kernels, nvidia_kernels,
      trace1_name="AMD", trace2_name="NVIDIA",
      trace1_platform="AMD",
  )
  all_mappings.extend(explicit_fusions)
-
- # 2. Find sequence-based fusions
- sequence_fusions = _find_fusion_mappings(
-     amd_kernels, nvidia_kernels,
-     trace1_name="AMD", trace2_name="NVIDIA",
-     trace1_platform="AMD",
- )
- # Deduplicate: skip if same fused_kernel_type already found
- existing_types = {m["fused_kernel_type"] for m in all_mappings}
- for fusion in sequence_fusions:
-     if fusion["fused_kernel_type"] not in existing_types:
-         all_mappings.append(fusion)
-         existing_types.add(fusion["fused_kernel_type"])
-
- # 3. Find count-imbalance fusions
- count_fusions = _find_count_imbalance_fusions(
-     amd_kernels, nvidia_kernels,
-     trace1_name="AMD", trace2_name="NVIDIA",
-     trace1_platform="AMD",
- )
- # Deduplicate
- for fusion in count_fusions:
-     if fusion["fused_kernel_type"] not in existing_types:
-         all_mappings.append(fusion)
-         existing_types.add(fusion["fused_kernel_type"])
+
+ # NOTE: We intentionally skip sequence-based and count-imbalance detection
+ # because they produce false positives. Count differences between platforms
+ # are usually due to different library implementations (cuBLAS vs hipBLAS),
+ # not actual kernel fusion.

  # Convert to FusionPattern objects
  patterns: list[FusionPattern] = []
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: wafer-core
- Version: 0.1.28
+ Version: 0.1.30
  Summary: Core utilities and environments for Wafer GPU kernel optimization
  Requires-Python: >=3.10
  Requires-Dist: aiohttp>=3.9.0
@@ -321,12 +321,12 @@ wafer_core/lib/rocprofiler/systems/sample/profiler.py,sha256=CYZPTzNXd48LoCfmY6h
  wafer_core/lib/trace_compare/PERFORMANCE.md,sha256=jkJh7ApZi8H7NKTcz8v0LNtwSFtIUqY88e3QbL749ww,3823
  wafer_core/lib/trace_compare/__init__.py,sha256=CyUPbPQDYhVLCFFA7S_jNSilG3OgqYjmHSKfR5X11go,1377
  wafer_core/lib/trace_compare/aligner.py,sha256=1S8Ob3RaEsIjN0HdqEx0yGsW5uf_lMrJVSH_MnZhKok,13788
- wafer_core/lib/trace_compare/analyzer.py,sha256=YkuOPA3HFX_7mNUEhE9CMOtEMGLQd12lvUkvqqeQF14,29698
+ wafer_core/lib/trace_compare/analyzer.py,sha256=Ou_gooG027YVuYVF5oddAkMsObXrrPQLBPHUzSMA4Vg,31078
  wafer_core/lib/trace_compare/api.py,sha256=JSRTcd7eZK1Z8l18TFEiA5A8ENJS1TMz7oIiw1KBbAs,8796
  wafer_core/lib/trace_compare/architecture.py,sha256=8bqlAJQeJLBHblyXvFV-w55PIKiVQDPjDQZ8Jx4tuGg,2110
- wafer_core/lib/trace_compare/classifier.py,sha256=CDGzY9TY-I5wRuEGsu4mTCdljqVTOnLWyFLyNgmkGXI,16864
+ wafer_core/lib/trace_compare/classifier.py,sha256=cYAmDW8S75N6cE3mJNZM-UKCJSX7rFP-8klVrukBvNQ,17504
  wafer_core/lib/trace_compare/formatter.py,sha256=GNrCZ45ueBN05CEXjOtTuKvTI8z-g-ZZFil-ni3sWVY,37962
- wafer_core/lib/trace_compare/fusion_analyzer.py,sha256=ZbFXUuPOt8ezT08WfjlDx7XaUNoUgg9hlFTJb68-eo0,17433
+ wafer_core/lib/trace_compare/fusion_analyzer.py,sha256=ga0sfxx8OCQu9Hq7uJSAMfXhnCvBaAmzVofBN7_gdV8,19843
  wafer_core/lib/trace_compare/kernel_registry.yaml,sha256=0-knXwsF3pR1x1JdIz-aWaH-5xDgTylh53E47Kf6nHo,9808
  wafer_core/lib/trace_compare/layer_segmentation.py,sha256=kI_Y1e9nrKZfdwfcrGo4h7gpMxqXI_xkgXk46zuFen4,4642
  wafer_core/lib/trace_compare/loader.py,sha256=zBHI0r7CX_wJ2mz0_-s0lm9KGSdaVaq7OKyxUL6KIlw,23997
@@ -697,6 +697,6 @@ wafer_core/utils/modal_execution/modal_app.py,sha256=VfS2cX8gHtnlPXemmMcEwDPeQdh
  wafer_core/utils/modal_execution/modal_config.py,sha256=7cGX9TGqilQ3qxI3OFGXV5orjtyRU-PEDOJ4vP2oxno,4421
  wafer_core/utils/modal_execution/modal_execution.py,sha256=gChjnV6jqA3A7IRP3DfvV5cSfm_MN0X4f7JZufXgdZE,24594
  wafer_core/utils/modal_execution/test_modal.py,sha256=_jqou_hrLs1Daf1590Pnb0a_lXMMa2rczAPpW9HpoNQ,8153
- wafer_core-0.1.28.dist-info/METADATA,sha256=0x6opc3zOlxGhlZNJDVDY2LPnBZHYP5K4U0I6ZDl0Os,1477
- wafer_core-0.1.28.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
- wafer_core-0.1.28.dist-info/RECORD,,
+ wafer_core-0.1.30.dist-info/METADATA,sha256=YuF3VyyP3tvmv2S-7E8epi1J2_1e2yXJfapS1uGQ0Zs,1477
+ wafer_core-0.1.30.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+ wafer_core-0.1.30.dist-info/RECORD,,