wafer-core 0.1.26__py3-none-any.whl → 0.1.27__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,890 +1,356 @@
1
- """Fusion analysis for detecting kernel fusion differences between platforms.
1
+ """Fusion analysis using aligned kernel pairs.
2
2
 
3
- Detects fusion differences between AMD and NVIDIA by analyzing how many kernels
4
- each platform launches for the same logical operations.
3
+ Detects kernel fusion patterns at the kernel-pair level within aligned layers.
4
+
5
+ Key insight: When one platform has an operation type like "RMSNorm+GEMM" (containing '+'),
6
+ that platform IS fusing. The OTHER platform runs the components separately.
5
7
  """
6
8
 
7
- import json
8
9
  from collections import Counter, defaultdict
9
- from pathlib import Path
10
+ from dataclasses import dataclass, field
10
11
  from typing import Any
11
12
 
12
- from .classifier import classify_kernel
13
-
13
+ from .aligner import KernelPair, LayerAlignment
14
14
 
15
- def _load_trace_for_fusion(
16
- file_path: str | Path,
17
- ) -> tuple[str, str, list[dict[str, Any]], dict[int, list[dict[str, Any]]]]:
18
- """Load trace and group kernels by correlation ID.
19
15
 
20
- Args:
21
- file_path: Path to trace file
22
-
23
- Returns:
24
- Tuple of (platform, gpu_name, all_kernels, corr_groups)
25
- """
26
- with open(file_path, "rb") as f:
27
- trace = json.load(f)
16
+ # Maps fused operation names to their component operations
17
+ FUSED_OPERATION_COMPONENTS: dict[str, list[str]] = {
18
+ "RMSNorm+GEMM": ["RMSNorm", "Dense GEMM"],
19
+ "SwiGLU+GEMM": ["SwiGLU", "Dense GEMM"],
20
+ "MoE GEMM+SwiGLU": ["MoE GEMM", "SwiGLU"],
21
+ "Embedding+RMSNorm+GEMM": ["Elementwise", "RMSNorm", "Dense GEMM"], # Embedding is often Elementwise
22
+ }
28
23
 
29
- # Detect platform
30
- props = trace.get("deviceProperties", [{}])[0]
31
- is_amd = trace.get("roctracer_version") or props.get("warpSize") == 64
32
- platform = "AMD" if is_amd else "NVIDIA"
33
- gpu_name = props.get("name", "MI300X" if is_amd else "Unknown GPU")
34
24
 
35
- # Get all kernel events
36
- events = trace.get("traceEvents", [])
37
- kernels = [e for e in events if e.get("cat") == "kernel"]
25
+ @dataclass
26
+ class FusionPattern:
27
+ """A detected fusion pattern."""
38
28
 
39
- # Group by correlation ID
40
- corr_groups: dict[int, list[dict[str, Any]]] = defaultdict(list)
41
- for k in kernels:
42
- corr_id = k.get("args", {}).get("correlation")
43
- if corr_id is not None:
44
- corr_groups[corr_id].append(k)
29
+ layer: int
30
+ operation: str # The fused operation name (e.g., "RMSNorm+GEMM")
31
+ fused_platform: str # Platform that fuses (has fewer kernels)
32
+ fused_kernel: str # The actual fused kernel name
33
+ unfused_kernels: list[str] # List of separate kernels on the other platform
34
+ count: int
35
+ evidence: str
45
36
 
46
- return platform, gpu_name, kernels, dict(corr_groups)
47
37
 
38
+ @dataclass
39
+ class FusionAnalysis:
40
+ """Complete fusion analysis result."""
48
41
 
49
- def _analyze_correlation_group(
50
- kernels: list[dict[str, Any]],
51
- ) -> tuple[dict[str, int], dict[str, float]]:
52
- """Analyze kernel composition within a correlation group.
42
+ patterns: list[FusionPattern] = field(default_factory=list)
43
+ summary: dict[str, Any] = field(default_factory=dict)
53
44
 
54
- Args:
55
- kernels: List of kernel events in the group
56
45
 
57
- Returns:
58
- Tuple of (counts, timings) where counts maps kernel types to counts
59
- and timings maps kernel types to total duration in microseconds
46
+ def _is_fused_operation(operation: str) -> bool:
47
+ """Check if an operation type represents a fused operation.
48
+
49
+ Operations like 'RMSNorm+GEMM', 'MoE GEMM+SwiGLU' indicate fusion.
50
+ Also handles 'Fused (Unknown)' from heuristic detection.
60
51
  """
61
- counts: Counter[str] = Counter()
62
- timings: dict[str, float] = defaultdict(float)
63
-
64
- for k in kernels:
65
- kernel_type = classify_kernel(k.get("name", ""))
66
- counts[kernel_type] += 1
67
- timings[kernel_type] += k.get("dur", 0)
52
+ return "+" in operation or operation == "Fused (Unknown)"
68
53
 
69
- return dict(counts), dict(timings)
70
54
 
55
+ def _get_component_operations(fused_op: str) -> list[str]:
56
+ """Get the component operations for a fused operation type."""
57
+ if fused_op in FUSED_OPERATION_COMPONENTS:
58
+ return FUSED_OPERATION_COMPONENTS[fused_op]
59
+
60
+ # Parse from the operation name itself (e.g., "RMSNorm+GEMM" -> ["RMSNorm", "GEMM"])
61
+ if "+" in fused_op:
62
+ parts = [p.strip() for p in fused_op.split("+")]
63
+ # Map shorthand to full names
64
+ mapping = {
65
+ "GEMM": "Dense GEMM",
66
+ "SwiGLU": "Triton Fused",
67
+ }
68
+ return [mapping.get(p, p) for p in parts]
69
+
70
+ return []
71
71
 
72
- def _match_correlation_groups(
73
- amd_groups: dict[int, list[dict[str, Any]]],
74
- nv_groups: dict[int, list[dict[str, Any]]],
75
- size_tolerance: float = 0.25,
76
- ) -> list[tuple[int, int]]:
77
- """Match AMD and NVIDIA correlation groups by size and composition.
78
-
79
- Since correlation IDs don't match between platforms, we match groups that
80
- have similar sizes (within tolerance) and likely represent the same operation.
81
72
 
73
+ def _find_component_kernels(
74
+ layer_alignment: LayerAlignment,
75
+ component_ops: list[str],
76
+ platform: str,
77
+ ) -> list[str]:
78
+ """Find kernels for component operations on the specified platform.
79
+
82
80
  Args:
83
- amd_groups: AMD correlation groups
84
- nv_groups: NVIDIA correlation groups
85
- size_tolerance: Groups match if sizes are within this fraction (25% default, increased for better matching)
86
-
81
+ layer_alignment: Layer alignment data
82
+ component_ops: List of operation types to look for (e.g., ["RMSNorm", "Dense GEMM"])
83
+ platform: "AMD" or "NVIDIA"
84
+
87
85
  Returns:
88
- List of (amd_corr_id, nv_corr_id) pairs
86
+ List of kernel names found for these operations
89
87
  """
90
- matches = []
91
- used_nv_ids: set[int] = set()
92
-
93
- # Pre-compute compositions for all groups to avoid redundant analysis
94
- amd_comps = {id: _analyze_correlation_group(kernels)[0] for id, kernels in amd_groups.items()}
95
- nv_comps = {id: _analyze_correlation_group(kernels)[0] for id, kernels in nv_groups.items()}
96
-
97
- # Build size-indexed lookup for NVIDIA groups for faster filtering
98
- nv_by_size = [(id, len(kernels)) for id, kernels in nv_groups.items()]
99
-
100
- # Sort AMD groups by size (largest first for better matching)
101
- amd_sorted = sorted(amd_groups.items(), key=lambda x: len(x[1]), reverse=True)
102
-
103
- for amd_id, amd_kernels in amd_sorted:
104
- amd_size = len(amd_kernels)
105
- amd_comp = amd_comps[amd_id]
106
-
107
- # Calculate size bounds for matching
108
- min_size = amd_size / (1 + size_tolerance)
109
- max_size = amd_size * (1 + size_tolerance)
110
-
111
- # Find best matching NVIDIA group
112
- best_match = None
113
- best_score = float("inf")
114
-
115
- for nv_id, nv_size in nv_by_size:
116
- if nv_id in used_nv_ids:
117
- continue
118
-
119
- # Quick size filter
120
- if nv_size < min_size or nv_size > max_size:
121
- continue
122
-
123
- # Use pre-computed composition
124
- nv_comp = nv_comps[nv_id]
125
-
126
- # Simple similarity: number of shared kernel types
127
- shared_types = set(amd_comp.keys()) & set(nv_comp.keys())
128
- similarity = len(shared_types)
129
-
130
- # Prefer matches with more shared types and closer sizes
131
- score = abs(amd_size - nv_size) - (similarity * 10)
132
-
133
- if score < best_score:
134
- best_score = score
135
- best_match = nv_id
136
-
137
- if best_match is not None:
138
- matches.append((amd_id, best_match))
139
- used_nv_ids.add(best_match)
140
-
141
- return matches
142
-
88
+ found_kernels: list[str] = []
89
+
90
+ for pair in layer_alignment.kernel_pairs:
91
+ if pair.operation in component_ops:
92
+ if platform == "AMD" and pair.amd_kernel and pair.amd_count > 0:
93
+ found_kernels.append(pair.amd_kernel)
94
+ elif platform == "NVIDIA" and pair.nvidia_kernel and pair.nvidia_count > 0:
95
+ found_kernels.append(pair.nvidia_kernel)
96
+
97
+ return found_kernels
143
98
 
144
- def _find_fusion_mappings(
145
- trace1_kernels: list[dict],
146
- trace2_kernels: list[dict],
147
- trace1_name: str = "Trace1",
148
- trace2_name: str = "Trace2",
149
- ) -> list[dict]:
150
- """Find fusion mappings by analyzing kernel execution sequence patterns.
151
99
 
152
- This function identifies when one platform runs multiple kernels separately
153
- while the other platform fuses them into a single kernel.
100
+ def detect_fusion_in_aligned_layers(
101
+ layer_alignments: list[LayerAlignment],
102
+ ) -> FusionAnalysis:
103
+ """Detect fusion patterns from aligned layer data.
154
104
 
155
105
  Args:
156
- trace1_kernels: List of kernel events from first trace
157
- trace2_kernels: List of kernel events from second trace
158
- trace1_name: Name of first platform (e.g., "AMD")
159
- trace2_name: Name of second platform (e.g., "NVIDIA")
106
+ layer_alignments: List of aligned layers
160
107
 
161
108
  Returns:
162
- List of mapping dictionaries, each containing:
163
- - fused_platform: Which platform fuses the operations
164
- - fused_kernel_type: The single fused kernel type
165
- - unfused_platform: Which platform runs them separately
166
- - unfused_sequence: List of kernel types run separately
167
- - pattern_count: How many times this pattern appears
168
- - pattern_confidence: Fraction of occurrences following this pattern
169
- - evidence: Human-readable description
109
+ FusionAnalysis with detected patterns
170
110
  """
171
- from collections import defaultdict
172
- from wafer_core.lib.trace_compare.classifier import classify_kernel
111
+ fusion_patterns: list[FusionPattern] = []
173
112
 
174
- mappings = []
175
-
176
- # Sort kernels by timestamp
177
- trace1_sorted = sorted(trace1_kernels, key=lambda k: k.get('ts', 0))
178
- trace2_sorted = sorted(trace2_kernels, key=lambda k: k.get('ts', 0))
179
-
180
- # Classify all kernels
181
- trace1_types = [classify_kernel(k.get('name', '')) for k in trace1_sorted]
182
- trace2_types = [classify_kernel(k.get('name', '')) for k in trace2_sorted]
183
-
184
- # Find kernel types unique to each trace
185
- trace1_type_set = set(trace1_types)
186
- trace2_type_set = set(trace2_types)
187
-
188
- trace1_only = trace1_type_set - trace2_type_set
189
- trace2_only = trace2_type_set - trace1_type_set
190
-
191
- # For each unique type in trace1, find common sequence patterns
192
- for unique_type in trace1_only:
193
- # Find all occurrences of this type
194
- indices = [i for i, t in enumerate(trace1_types) if t == unique_type]
195
-
196
- if len(indices) < 5: # Need enough samples to be meaningful
197
- continue
198
-
199
- # Analyze what comes before/after each occurrence
200
- before_types = defaultdict(int)
201
-
202
- for idx in indices:
203
- if idx > 0:
204
- before_types[trace1_types[idx - 1]] += 1
205
-
206
- # Find the most common pattern (e.g., "Attention → Reduce")
207
- most_common_before = max(before_types.items(), key=lambda x: x[1]) if before_types else (None, 0)
208
-
209
- # If there's a strong pattern (>80% of occurrences)
210
- if most_common_before[1] / len(indices) > 0.8:
211
- # This suggests: Trace2 likely fuses [before_type + unique_type] into [before_type]
212
- fusion_candidate = most_common_before[0]
213
-
214
- # Verify trace2 has this type
215
- if fusion_candidate in trace2_type_set:
216
- # Count occurrences to compare
217
- trace1_fusion_count = trace1_types.count(fusion_candidate)
218
- trace2_fusion_count = trace2_types.count(fusion_candidate)
219
-
220
- mappings.append({
221
- "fused_platform": trace2_name,
222
- "fused_kernel_type": fusion_candidate,
223
- "fused_count": trace2_fusion_count,
224
- "unfused_platform": trace1_name,
225
- "unfused_sequence": [fusion_candidate, unique_type],
226
- "unfused_count_per_type": {
227
- fusion_candidate: trace1_fusion_count,
228
- unique_type: len(indices)
229
- },
230
- "pattern_count": len(indices),
231
- "pattern_confidence": most_common_before[1] / len(indices),
232
- "evidence": f"{trace1_name} runs {fusion_candidate}+{unique_type} separately, {trace2_name} fuses into {fusion_candidate}"
233
- })
234
-
235
- # Also check trace2-only types
236
- for unique_type in trace2_only:
237
- indices = [i for i, t in enumerate(trace2_types) if t == unique_type]
238
-
239
- if len(indices) < 5:
240
- continue
241
-
242
- before_types = defaultdict(int)
243
-
244
- for idx in indices:
245
- if idx > 0:
246
- before_types[trace2_types[idx - 1]] += 1
247
-
248
- most_common_before = max(before_types.items(), key=lambda x: x[1]) if before_types else (None, 0)
249
-
250
- if most_common_before[1] / len(indices) > 0.8:
251
- fusion_candidate = most_common_before[0]
252
-
253
- if fusion_candidate in trace1_type_set:
254
- trace1_fusion_count = trace1_types.count(fusion_candidate)
255
- trace2_fusion_count = trace2_types.count(fusion_candidate)
256
-
257
- mappings.append({
258
- "fused_platform": trace1_name,
259
- "fused_kernel_type": fusion_candidate,
260
- "fused_count": trace1_fusion_count,
261
- "unfused_platform": trace2_name,
262
- "unfused_sequence": [fusion_candidate, unique_type],
263
- "unfused_count_per_type": {
264
- fusion_candidate: trace2_fusion_count,
265
- unique_type: len(indices)
266
- },
267
- "pattern_count": len(indices),
268
- "pattern_confidence": most_common_before[1] / len(indices),
269
- "evidence": f"{trace2_name} runs {fusion_candidate}+{unique_type} separately, {trace1_name} fuses into {fusion_candidate}"
270
- })
271
-
272
- # NEW: Check for partial fusion patterns (kernel exists on both platforms but with different counts)
273
- # If one platform has significantly fewer calls (>1.3x ratio), look for fusion patterns
274
- common_types = trace1_type_set & trace2_type_set
275
-
276
- for ktype in common_types:
277
- trace1_count = trace1_types.count(ktype)
278
- trace2_count = trace2_types.count(ktype)
279
-
280
- # Check if there's a significant imbalance (one has >1.3x more)
281
- if trace1_count == 0 or trace2_count == 0:
282
- continue
283
-
284
- ratio = max(trace1_count, trace2_count) / min(trace1_count, trace2_count)
285
-
286
- if ratio < 1.3 or trace1_count + trace2_count < 100:
287
- continue
288
-
289
- # Determine which platform has more (unfused) and which has fewer (fused)
290
- if trace1_count > trace2_count:
291
- # Trace1 has more separate calls, Trace2 likely fuses
292
- unfused_platform = trace1_name
293
- fused_platform = trace2_name
294
- unfused_count = trace1_count
295
- fused_count = trace2_count
296
-
297
- # Find what Trace2 might be fusing this into
298
- # Use the most common kernel type in Trace2 as a likely fusion target
299
- trace2_type_counts = defaultdict(int)
300
- for t in trace2_types:
301
- if t != ktype and t != "Other": # Skip the imbalanced type itself and "Other"
302
- trace2_type_counts[t] += 1
303
-
304
- if trace2_type_counts:
305
- # Use the most common type as the fusion target
306
- fusion_target = max(trace2_type_counts.items(), key=lambda x: x[1])[0]
307
-
308
- mappings.append({
309
- "fused_platform": fused_platform,
310
- "fused_kernel_type": fusion_target,
311
- "fused_count": fused_count,
312
- "unfused_platform": unfused_platform,
313
- "unfused_sequence": [ktype],
314
- "unfused_count_per_type": {
315
- ktype: unfused_count
316
- },
317
- "pattern_count": unfused_count - fused_count,
318
- "pattern_confidence": (unfused_count - fused_count) / unfused_count,
319
- "evidence": f"{unfused_platform} calls {ktype} {ratio:.1f}x more ({unfused_count} vs {fused_count}), {fused_platform} likely fuses into {fusion_target}"
320
- })
321
- else:
322
- # Trace2 has more separate calls, Trace1 likely fuses
323
- unfused_platform = trace2_name
324
- fused_platform = trace1_name
325
- unfused_count = trace2_count
326
- fused_count = trace1_count
327
-
328
- # Find what Trace1 might be fusing this into
329
- # Use the most common kernel type in Trace1 as a likely fusion target
330
- trace1_type_counts = defaultdict(int)
331
- for t in trace1_types:
332
- if t != ktype and t != "Other": # Skip the imbalanced type itself and "Other"
333
- trace1_type_counts[t] += 1
334
-
335
- if trace1_type_counts:
336
- fusion_target = max(trace1_type_counts.items(), key=lambda x: x[1])[0]
337
-
338
- mappings.append({
339
- "fused_platform": fused_platform,
340
- "fused_kernel_type": fusion_target,
341
- "fused_count": fused_count,
342
- "unfused_platform": unfused_platform,
343
- "unfused_sequence": [ktype],
344
- "unfused_count_per_type": {
345
- ktype: unfused_count
346
- },
347
- "pattern_count": unfused_count - fused_count,
348
- "pattern_confidence": (unfused_count - fused_count) / unfused_count,
349
- "evidence": f"{unfused_platform} calls {ktype} {ratio:.1f}x more ({unfused_count} vs {fused_count}), {fused_platform} likely fuses into {fusion_target}"
350
- })
351
-
352
- return mappings
353
-
354
-
355
- def _detect_intra_type_fusion(
356
- trace1_kernels: list[dict],
357
- trace2_kernels: list[dict],
358
- trace1_name: str,
359
- trace2_name: str,
360
- ) -> list[dict]:
361
- """Detect intra-type fusion where consecutive same-type kernels are fused.
362
-
363
- Example: AMD runs Sort→Sort→Sort (42 calls) while NVIDIA runs Sort→Sort (10 calls)
364
- This indicates NVIDIA has a more efficient Sort implementation that fuses operations.
365
- """
366
- from wafer_core.lib.trace_compare.classifier import classify_kernel
367
-
368
- def analyze_chains(kernels):
369
- """Find chains of consecutive same-type kernels"""
370
- sorted_kernels = sorted(kernels, key=lambda k: k.get('ts', 0))
371
- types = [classify_kernel(k['name']) for k in sorted_kernels]
372
-
373
- chains = defaultdict(list)
374
- i = 0
375
- while i < len(types):
376
- ktype = types[i]
377
- count = 0
378
- while i < len(types) and types[i] == ktype:
379
- count += 1
380
- i += 1
381
- chains[ktype].append(count)
382
-
383
- return chains
384
-
385
- trace1_chains = analyze_chains(trace1_kernels)
386
- trace2_chains = analyze_chains(trace2_kernels)
387
-
388
- mappings = []
389
- all_types = set(trace1_chains.keys()) | set(trace2_chains.keys())
390
-
391
- for ktype in all_types:
392
- t1_lengths = trace1_chains.get(ktype, [])
393
- t2_lengths = trace2_chains.get(ktype, [])
394
-
395
- # Skip if not enough data
396
- if len(t1_lengths) < 5 and len(t2_lengths) < 5:
397
- continue
398
-
399
- # Filter to chains with multiple kernels
400
- t1_multi = [l for l in t1_lengths if l > 1]
401
- t2_multi = [l for l in t2_lengths if l > 1]
402
-
403
- if not t1_multi and not t2_multi:
404
- continue
405
-
406
- t1_total = sum(t1_lengths)
407
- t2_total = sum(t2_lengths)
408
- t1_chains = len(t1_multi) if t1_multi else len(t1_lengths)
409
- t2_chains = len(t2_multi) if t2_multi else len(t2_lengths)
410
-
411
- if t1_chains == 0 or t2_chains == 0:
412
- continue
413
-
414
- t1_avg_chain = sum(t1_multi) / len(t1_multi) if t1_multi else 1.0
415
- t2_avg_chain = sum(t2_multi) / len(t2_multi) if t2_multi else 1.0
416
-
417
- chain_ratio = max(t1_avg_chain, t2_avg_chain) / min(t1_avg_chain, t2_avg_chain)
418
-
419
- # Significant intra-fusion if chains are 2x+ different
420
- if chain_ratio > 2.0 and abs(t1_total - t2_total) > 100:
421
- if t1_avg_chain > t2_avg_chain:
422
- unfused_platform = trace1_name
423
- fused_platform = trace2_name
424
- unfused_chains = t1_chains
425
- fused_chains = t2_chains
426
- unfused_avg = t1_avg_chain
427
- fused_avg = t2_avg_chain
428
- unfused_total = t1_total
429
- fused_total = t2_total
430
- else:
431
- unfused_platform = trace2_name
432
- fused_platform = trace1_name
433
- unfused_chains = t2_chains
434
- fused_chains = t1_chains
435
- unfused_avg = t2_avg_chain
436
- fused_avg = t1_avg_chain
437
- unfused_total = t2_total
438
- fused_total = t1_total
439
-
440
- mappings.append({
441
- "fused_platform": fused_platform,
442
- "fused_kernel_type": ktype,
443
- "fused_count": fused_total,
444
- "unfused_platform": unfused_platform,
445
- "unfused_sequence": [ktype, ktype], # Same type repeated
446
- "unfused_count_per_type": {ktype: unfused_total},
447
- "pattern_count": unfused_total - fused_total,
448
- "pattern_confidence": min(unfused_chains, fused_chains) / max(unfused_chains, fused_chains),
449
- "evidence": f"{unfused_platform} runs {ktype} in chains of {unfused_avg:.0f} calls ({unfused_chains} chains, {unfused_total:,} total), {fused_platform} fuses to {fused_avg:.0f} calls ({fused_chains} chains, {fused_total:,} total) - {chain_ratio:.1f}x more efficient"
450
- })
451
-
452
- return mappings
453
-
454
-
455
- def _find_partial_fusion_via_groups(
456
- trace1_large: dict[int, list[dict]],
457
- trace2_large: dict[int, list[dict]],
458
- matches: list[tuple[int, int]],
459
- trace1_name: str,
460
- trace2_name: str,
461
- ) -> list[dict]:
462
- """Find partial fusion patterns by analyzing correlation group differences.
463
-
464
- When one platform has fewer of a kernel type, check what kernel types the
465
- other platform has MORE of in those same groups - those are likely fusion targets.
466
- """
467
- from collections import Counter
468
- from wafer_core.lib.trace_compare.classifier import classify_kernel
469
-
470
- mappings = []
471
-
472
- # For each matched pair, track kernel type counts
473
- trace1_all_types = []
474
- trace2_all_types = []
475
-
476
- for trace1_cid, trace2_cid in matches:
477
- trace1_ktypes = [classify_kernel(k.get("name", "")) for k in trace1_large[trace1_cid]]
478
- trace2_ktypes = [classify_kernel(k.get("name", "")) for k in trace2_large[trace2_cid]]
479
- trace1_all_types.extend(trace1_ktypes)
480
- trace2_all_types.extend(trace2_ktypes)
481
-
482
- # Find kernel types with significant imbalances
483
- trace1_counts = Counter(trace1_all_types)
484
- trace2_counts = Counter(trace2_all_types)
485
- all_types = set(trace1_counts.keys()) | set(trace2_counts.keys())
486
-
487
- for ktype in all_types:
488
- trace1_count = trace1_counts.get(ktype, 0)
489
- trace2_count = trace2_counts.get(ktype, 0)
490
-
491
- if trace1_count == 0 or trace2_count == 0:
492
- continue # Handled by sequence-based detection
493
-
494
- ratio = max(trace1_count, trace2_count) / min(trace1_count, trace2_count)
495
-
496
- if ratio < 1.3 or trace1_count + trace2_count < 100:
497
- continue # Not significant
498
-
499
- # Determine which platform has fewer (fuses more)
500
- if trace1_count > trace2_count:
501
- unfused_platform = trace1_name
502
- fused_platform = trace2_name
503
- unfused_count = trace1_count
504
- fused_count = trace2_count
505
-
506
- # Find groups where trace1 has this kernel but trace2 doesn't
507
- fusion_targets = Counter()
508
- groups_analyzed = 0
509
-
510
- for trace1_cid, trace2_cid in matches:
511
- trace1_ktypes = [classify_kernel(k.get("name", "")) for k in trace1_large[trace1_cid]]
512
- trace2_ktypes = [classify_kernel(k.get("name", "")) for k in trace2_large[trace2_cid]]
513
-
514
- trace1_has = ktype in trace1_ktypes
515
- trace2_has = ktype in trace2_ktypes
516
-
517
- if trace1_has and not trace2_has:
518
- # What does trace2 have MORE of in this group?
519
- trace1_kcounts = Counter(trace1_ktypes)
520
- trace2_kcounts = Counter(trace2_ktypes)
521
-
522
- for other_type in set(trace2_kcounts.keys()):
523
- if other_type == ktype or other_type == "Other":
524
- continue
525
- diff = trace2_kcounts[other_type] - trace1_kcounts.get(other_type, 0)
526
- if diff > 0:
527
- fusion_targets[other_type] += diff
528
-
529
- groups_analyzed += 1
530
-
531
- if fusion_targets and groups_analyzed >= 5:
532
- # Report top fusion targets
533
- top_targets = fusion_targets.most_common(3)
534
- target_str = ", ".join(f"{t[0]} (+{t[1]})" for t in top_targets)
535
-
536
- mappings.append({
537
- "fused_platform": fused_platform,
538
- "fused_kernel_type": top_targets[0][0],
539
- "fused_count": fused_count,
540
- "unfused_platform": unfused_platform,
541
- "unfused_sequence": [ktype],
542
- "unfused_count_per_type": {ktype: unfused_count},
543
- "pattern_count": unfused_count - fused_count,
544
- "pattern_confidence": groups_analyzed / len(matches) if matches else 0,
545
- "evidence": f"{unfused_platform} calls {ktype} {ratio:.1f}x more ({unfused_count} vs {fused_count}). In {groups_analyzed} groups where {unfused_platform} has {ktype}, {fused_platform} has more: {target_str}"
546
- })
113
+ # Track patterns we've already seen to avoid duplicates
114
+ seen_patterns: set[tuple[int, str, str]] = set() # (layer, operation, fused_platform)
115
+
116
+ for layer_alignment in layer_alignments:
117
+ for pair in layer_alignment.kernel_pairs:
118
+ is_fused_op = _is_fused_operation(pair.operation)
119
+
120
+ # Case 1: AMD has a kernel for this operation, NVIDIA doesn't
121
+ if pair.amd_kernel and pair.amd_count > 0 and (pair.nvidia_kernel is None or pair.nvidia_count == 0):
122
+ if is_fused_op:
123
+ # AMD HAS the fused kernel → AMD is fusing
124
+ # Find what NVIDIA runs separately
125
+ component_ops = _get_component_operations(pair.operation)
126
+ unfused_kernels = _find_component_kernels(layer_alignment, component_ops, "NVIDIA")
127
+
128
+ pattern_key = (layer_alignment.layer, pair.operation, "AMD")
129
+ if pattern_key not in seen_patterns:
130
+ seen_patterns.add(pattern_key)
131
+
132
+ if unfused_kernels:
133
+ evidence = f"AMD fuses {' + '.join(component_ops)} into {pair.amd_kernel}, NVIDIA runs {len(unfused_kernels)} separate kernels"
134
+ else:
135
+ evidence = f"AMD fuses into {pair.amd_kernel}, NVIDIA runs components separately"
136
+
137
+ fusion_patterns.append(
138
+ FusionPattern(
139
+ layer=layer_alignment.layer,
140
+ operation=pair.operation,
141
+ fused_platform="AMD",
142
+ fused_kernel=pair.amd_kernel,
143
+ unfused_kernels=unfused_kernels if unfused_kernels else component_ops,
144
+ count=pair.amd_count,
145
+ evidence=evidence,
146
+ )
147
+ )
148
+ else:
149
+ # Regular case: AMD has a kernel that NVIDIA fuses into something else
150
+ # This means NVIDIA is fusing (it doesn't need this separate kernel)
151
+ pattern_key = (layer_alignment.layer, pair.operation, "NVIDIA")
152
+ if pattern_key not in seen_patterns:
153
+ seen_patterns.add(pattern_key)
154
+
155
+ evidence = f"AMD runs {pair.amd_kernel} separately ({pair.amd_count}x), NVIDIA fuses this operation"
156
+
157
+ fusion_patterns.append(
158
+ FusionPattern(
159
+ layer=layer_alignment.layer,
160
+ operation=pair.operation,
161
+ fused_platform="NVIDIA",
162
+ fused_kernel="(fused into nearby kernel)",
163
+ unfused_kernels=[pair.amd_kernel],
164
+ count=pair.amd_count,
165
+ evidence=evidence,
166
+ )
167
+ )
168
+
169
+ # Case 2: NVIDIA has a kernel for this operation, AMD doesn't
170
+ elif pair.nvidia_kernel and pair.nvidia_count > 0 and (pair.amd_kernel is None or pair.amd_count == 0):
171
+ if is_fused_op:
172
+ # NVIDIA HAS the fused kernel → NVIDIA is fusing
173
+ # Find what AMD runs separately
174
+ component_ops = _get_component_operations(pair.operation)
175
+ unfused_kernels = _find_component_kernels(layer_alignment, component_ops, "AMD")
176
+
177
+ pattern_key = (layer_alignment.layer, pair.operation, "NVIDIA")
178
+ if pattern_key not in seen_patterns:
179
+ seen_patterns.add(pattern_key)
180
+
181
+ if unfused_kernels:
182
+ evidence = f"NVIDIA fuses {' + '.join(component_ops)} into {pair.nvidia_kernel}, AMD runs {len(unfused_kernels)} separate kernels"
183
+ else:
184
+ evidence = f"NVIDIA fuses into {pair.nvidia_kernel}, AMD runs components separately"
185
+
186
+ fusion_patterns.append(
187
+ FusionPattern(
188
+ layer=layer_alignment.layer,
189
+ operation=pair.operation,
190
+ fused_platform="NVIDIA",
191
+ fused_kernel=pair.nvidia_kernel,
192
+ unfused_kernels=unfused_kernels if unfused_kernels else component_ops,
193
+ count=pair.nvidia_count,
194
+ evidence=evidence,
195
+ )
196
+ )
197
+ else:
198
+ # Regular case: NVIDIA has a kernel that AMD fuses into something else
199
+ pattern_key = (layer_alignment.layer, pair.operation, "AMD")
200
+ if pattern_key not in seen_patterns:
201
+ seen_patterns.add(pattern_key)
202
+
203
+ evidence = f"NVIDIA runs {pair.nvidia_kernel} separately ({pair.nvidia_count}x), AMD fuses this operation"
204
+
205
+ fusion_patterns.append(
206
+ FusionPattern(
207
+ layer=layer_alignment.layer,
208
+ operation=pair.operation,
209
+ fused_platform="AMD",
210
+ fused_kernel="(fused into nearby kernel)",
211
+ unfused_kernels=[pair.nvidia_kernel],
212
+ count=pair.nvidia_count,
213
+ evidence=evidence,
214
+ )
215
+ )
216
+
217
+ # Case 3: Both have kernels but with very different counts (partial fusion)
218
+ elif (
219
+ pair.amd_kernel
220
+ and pair.nvidia_kernel
221
+ and pair.amd_count > 0
222
+ and pair.nvidia_count > 0
223
+ ):
224
+ count_ratio = pair.amd_count / pair.nvidia_count if pair.nvidia_count > 0 else float('inf')
225
+
226
+ if count_ratio > 1.5:
227
+ # AMD runs more → NVIDIA fuses some instances
228
+ pattern_key = (layer_alignment.layer, pair.operation, "NVIDIA")
229
+ if pattern_key not in seen_patterns:
230
+ seen_patterns.add(pattern_key)
231
+
232
+ evidence = (
233
+ f"AMD runs {pair.amd_kernel} {count_ratio:.1f}x more "
234
+ f"({pair.amd_count} vs {pair.nvidia_count}), NVIDIA partially fuses"
235
+ )
236
+
237
+ fusion_patterns.append(
238
+ FusionPattern(
239
+ layer=layer_alignment.layer,
240
+ operation=pair.operation,
241
+ fused_platform="NVIDIA",
242
+ fused_kernel=pair.nvidia_kernel,
243
+ unfused_kernels=[pair.amd_kernel],
244
+ count=pair.amd_count - pair.nvidia_count,
245
+ evidence=evidence,
246
+ )
247
+ )
248
+
249
+ elif count_ratio < 0.67:
250
+ # NVIDIA runs more → AMD fuses some instances
251
+ inverse_ratio = pair.nvidia_count / pair.amd_count if pair.amd_count > 0 else float('inf')
252
+ pattern_key = (layer_alignment.layer, pair.operation, "AMD")
253
+ if pattern_key not in seen_patterns:
254
+ seen_patterns.add(pattern_key)
255
+
256
+ evidence = (
257
+ f"NVIDIA runs {pair.nvidia_kernel} {inverse_ratio:.1f}x more "
258
+ f"({pair.nvidia_count} vs {pair.amd_count}), AMD partially fuses"
259
+ )
260
+
261
+ fusion_patterns.append(
262
+ FusionPattern(
263
+ layer=layer_alignment.layer,
264
+ operation=pair.operation,
265
+ fused_platform="AMD",
266
+ fused_kernel=pair.amd_kernel,
267
+ unfused_kernels=[pair.nvidia_kernel],
268
+ count=pair.nvidia_count - pair.amd_count,
269
+ evidence=evidence,
270
+ )
271
+ )
272
+
273
+ # Aggregate patterns by operation type and fused platform
274
+ aggregated: dict[tuple[str, str], FusionPattern] = {}
275
+ for pattern in fusion_patterns:
276
+ key = (pattern.operation, pattern.fused_platform)
277
+ if key in aggregated:
278
+ existing = aggregated[key]
279
+ existing.count += pattern.count
280
+ # Merge unfused kernels (deduplicate)
281
+ for k in pattern.unfused_kernels:
282
+ if k not in existing.unfused_kernels:
283
+ existing.unfused_kernels.append(k)
547
284
  else:
548
- # Symmetric case for trace2 > trace1
549
- unfused_platform = trace2_name
550
- fused_platform = trace1_name
551
- unfused_count = trace2_count
552
- fused_count = trace1_count
553
-
554
- fusion_targets = Counter()
555
- groups_analyzed = 0
556
-
557
- for trace1_cid, trace2_cid in matches:
558
- trace1_ktypes = [classify_kernel(k.get("name", "")) for k in trace1_large[trace1_cid]]
559
- trace2_ktypes = [classify_kernel(k.get("name", "")) for k in trace2_large[trace2_cid]]
560
-
561
- trace1_has = ktype in trace1_ktypes
562
- trace2_has = ktype in trace2_ktypes
563
-
564
- if trace2_has and not trace1_has:
565
- trace1_kcounts = Counter(trace1_ktypes)
566
- trace2_kcounts = Counter(trace2_ktypes)
567
-
568
- for other_type in set(trace1_kcounts.keys()):
569
- if other_type == ktype or other_type == "Other":
570
- continue
571
- diff = trace1_kcounts[other_type] - trace2_kcounts.get(other_type, 0)
572
- if diff > 0:
573
- fusion_targets[other_type] += diff
574
-
575
- groups_analyzed += 1
576
-
577
- if fusion_targets and groups_analyzed >= 5:
578
- top_targets = fusion_targets.most_common(3)
579
- target_str = ", ".join(f"{t[0]} (+{t[1]})" for t in top_targets)
580
-
581
- mappings.append({
582
- "fused_platform": fused_platform,
583
- "fused_kernel_type": top_targets[0][0],
584
- "fused_count": fused_count,
585
- "unfused_platform": unfused_platform,
586
- "unfused_sequence": [ktype],
587
- "unfused_count_per_type": {ktype: unfused_count},
588
- "pattern_count": unfused_count - fused_count,
589
- "pattern_confidence": groups_analyzed / len(matches) if matches else 0,
590
- "evidence": f"{unfused_platform} calls {ktype} {ratio:.1f}x more ({unfused_count} vs {fused_count}). In {groups_analyzed} groups where {unfused_platform} has {ktype}, {fused_platform} has more: {target_str}"
591
- })
592
-
593
- return mappings
594
-
595
-
596
- def analyze_fusion_differences(
597
- amd_trace_path: str | Path,
598
- nv_trace_path: str | Path,
599
- min_group_size: int = 50,
600
- ) -> dict[str, Any]:
601
- """Main fusion analysis function.
285
+ aggregated[key] = FusionPattern(
286
+ layer=pattern.layer,
287
+ operation=pattern.operation,
288
+ fused_platform=pattern.fused_platform,
289
+ fused_kernel=pattern.fused_kernel,
290
+ unfused_kernels=list(pattern.unfused_kernels),
291
+ count=pattern.count,
292
+ evidence=pattern.evidence,
293
+ )
602
294
 
603
- Args:
604
- amd_trace_path: Path to AMD trace
605
- nv_trace_path: Path to NVIDIA trace
606
- min_group_size: Only analyze correlation groups with at least this many kernels
295
+ amd_fuses = sum(1 for p in aggregated.values() if p.fused_platform == "AMD")
296
+ nvidia_fuses = sum(1 for p in aggregated.values() if p.fused_platform == "NVIDIA")
607
297
 
608
- Returns:
609
- Dictionary with analysis results containing:
610
- - metadata: trace info
611
- - global_counts: kernel type distribution across entire trace
612
- - fusion_opportunities: significant fusion differences
613
- - fusion_mappings: actual kernel-to-kernel mappings (NEW)
614
- """
615
- # Load traces (maintain order - don't swap)
616
- trace1_platform, trace1_gpu, trace1_kernels, trace1_corr_groups = _load_trace_for_fusion(
617
- amd_trace_path
618
- )
619
- trace2_platform, trace2_gpu, trace2_kernels, trace2_corr_groups = _load_trace_for_fusion(
620
- nv_trace_path
298
+ return FusionAnalysis(
299
+ patterns=list(aggregated.values()),
300
+ summary={
301
+ "amd_fuses": amd_fuses,
302
+ "nvidia_fuses": nvidia_fuses,
303
+ "total_fusion_opportunities": len(aggregated),
304
+ },
621
305
  )
622
306
 
623
- # Filter to "significant" correlation groups
624
- trace1_large = {
625
- cid: kernels
626
- for cid, kernels in trace1_corr_groups.items()
627
- if len(kernels) >= min_group_size
628
- }
629
- trace2_large = {
630
- cid: kernels
631
- for cid, kernels in trace2_corr_groups.items()
632
- if len(kernels) >= min_group_size
633
- }
634
-
635
- # Match correlation groups between platforms
636
- matches = _match_correlation_groups(trace1_large, trace2_large)
637
-
638
- # Analyze differences in matched groups
639
- fusion_diffs: dict[str, dict[str, Any]] = defaultdict(
640
- lambda: {
641
- "trace1_count": 0,
642
- "trace2_count": 0,
643
- "trace1_time_us": 0,
644
- "trace2_time_us": 0,
645
- "groups_with_diff": 0,
646
- "total_groups": 0,
647
- }
648
- )
649
307
 
650
- # NEW: Collect actual fusion mappings
651
- all_fusion_mappings = []
652
-
653
- for trace1_cid, trace2_cid in matches:
654
- trace1_comp, trace1_times = _analyze_correlation_group(trace1_large[trace1_cid])
655
- trace2_comp, trace2_times = _analyze_correlation_group(trace2_large[trace2_cid])
656
-
657
- # Find all kernel types in either platform
658
- all_types = set(trace1_comp.keys()) | set(trace2_comp.keys())
659
-
660
- for ktype in all_types:
661
- trace1_count = trace1_comp.get(ktype, 0)
662
- trace2_count = trace2_comp.get(ktype, 0)
663
- trace1_time = trace1_times.get(ktype, 0)
664
- trace2_time = trace2_times.get(ktype, 0)
665
-
666
- fusion_diffs[ktype]["trace1_count"] += trace1_count
667
- fusion_diffs[ktype]["trace2_count"] += trace2_count
668
- fusion_diffs[ktype]["trace1_time_us"] += trace1_time
669
- fusion_diffs[ktype]["trace2_time_us"] += trace2_time
670
- fusion_diffs[ktype]["total_groups"] += 1
671
-
672
- if trace1_count != trace2_count:
673
- fusion_diffs[ktype]["groups_with_diff"] += 1
674
-
675
- # NEW: Find actual kernel mappings in this correlation group
676
- group_mappings = _find_fusion_mappings(
677
- trace1_large[trace1_cid],
678
- trace2_large[trace2_cid],
679
- trace1_name=trace1_platform,
680
- trace2_name=trace2_platform
681
- )
682
- # Add correlation ID context to each mapping
683
- for mapping in group_mappings:
684
- mapping["correlation_group_trace1"] = trace1_cid
685
- mapping["correlation_group_trace2"] = trace2_cid
686
- all_fusion_mappings.extend(group_mappings)
687
-
688
- # Also get global counts for context
689
- global_trace1_counts: Counter[str] = Counter(
690
- [classify_kernel(k.get("name", "")) for k in trace1_kernels]
691
- )
692
- global_trace2_counts: Counter[str] = Counter(
693
- [classify_kernel(k.get("name", "")) for k in trace2_kernels]
694
- )
308
+ def analyze_fusion_from_alignment(
309
+ layer_alignments: list[LayerAlignment],
310
+ ) -> dict[str, Any]:
311
+ """Analyze fusion from alignment data (for API compatibility).
695
312
 
696
- # Build results
697
- results: dict[str, Any] = {
698
- "metadata": {
699
- "trace1_gpu": trace1_gpu,
700
- "trace2_gpu": trace2_gpu,
701
- "trace1_total_kernels": len(trace1_kernels),
702
- "trace2_total_kernels": len(trace2_kernels),
703
- "trace1_correlation_groups": len(trace1_large),
704
- "trace2_correlation_groups": len(trace2_large),
705
- "matched_groups": len(matches),
706
- },
707
- "global_counts": {},
708
- "fusion_opportunities": [],
709
- "fusion_mappings": all_fusion_mappings, # NEW: Include actual mappings
710
- }
313
+ Args:
314
+ layer_alignments: List of aligned layers
711
315
 
712
- # Global counts for all kernel types
713
- all_ktypes = set(global_trace1_counts.keys()) | set(global_trace2_counts.keys())
714
- for ktype in all_ktypes:
715
- trace1_total = global_trace1_counts.get(ktype, 0)
716
- trace2_total = global_trace2_counts.get(ktype, 0)
717
-
718
- if trace1_total > 0 or trace2_total > 0:
719
- results["global_counts"][ktype] = {
720
- "trace1_count": trace1_total,
721
- "trace2_count": trace2_total,
722
- "ratio": trace1_total / trace2_total if trace2_total > 0 else float("inf"),
316
+ Returns:
317
+ Dictionary with fusion analysis results (compatible with old API)
318
+ """
319
+ fusion_analysis = detect_fusion_in_aligned_layers(layer_alignments)
320
+
321
+ fusion_opportunities = []
322
+ fusion_mappings = []
323
+
324
+ for pattern in fusion_analysis.patterns:
325
+ unfused_platform = "NVIDIA" if pattern.fused_platform == "AMD" else "AMD"
326
+
327
+ fusion_opportunities.append(
328
+ {
329
+ "kernel_type": pattern.operation,
330
+ "layer": pattern.layer,
331
+ "fused_by": pattern.fused_platform,
332
+ "fused_kernel": pattern.fused_kernel,
333
+ "unfused_kernels": pattern.unfused_kernels,
334
+ "count": pattern.count,
335
+ "evidence": pattern.evidence,
723
336
  }
724
-
725
- # Fusion opportunities from matched correlation groups
726
- for ktype, stats in fusion_diffs.items():
727
- trace1_avg = (
728
- stats["trace1_count"] / stats["total_groups"]
729
- if stats["total_groups"] > 0
730
- else 0
731
337
  )
732
- trace2_avg = (
733
- stats["trace2_count"] / stats["total_groups"]
734
- if stats["total_groups"] > 0
735
- else 0
736
- )
737
- trace1_time_ms = stats["trace1_time_us"] / 1000
738
- trace2_time_ms = stats["trace2_time_us"] / 1000
739
-
740
- # Calculate significance
741
- diff_ratio = trace1_avg / trace2_avg if trace2_avg > 0 else float("inf")
742
- reverse_ratio = trace2_avg / trace1_avg if trace1_avg > 0 else float("inf")
743
-
744
- # Only report if there's a significant difference
745
- # Either: one platform has it and the other doesn't (ratio > 10)
746
- # Or: one platform has significantly more (ratio > 2.0)
747
- is_significant = (
748
- (diff_ratio > 10.0 or reverse_ratio > 10.0)
749
- or ( # One platform doesn't have it
750
- (diff_ratio > 2.0 or reverse_ratio > 2.0)
751
- and stats["trace1_count"] + stats["trace2_count"] > 20 # Significant difference # Not trivial counts
752
- )
753
- )
754
-
755
- if is_significant:
756
- # Determine who fuses (who has FEWER calls = more fusion)
757
- if diff_ratio > 1.5:
758
- fused_by = trace2_platform # Trace 2 has fewer calls, so it fuses more
759
- ratio = diff_ratio
760
- else:
761
- fused_by = trace1_platform # Trace 1 has fewer calls, so it fuses more
762
- ratio = reverse_ratio
763
-
764
- # Calculate time ratio (who's faster for this operation)
765
- time_ratio = trace1_time_ms / trace2_time_ms if trace2_time_ms > 0 else float("inf")
766
-
767
- results["fusion_opportunities"].append(
768
- {
769
- "kernel_type": ktype,
770
- "trace1_total": stats["trace1_count"],
771
- "trace2_total": stats["trace2_count"],
772
- "trace1_avg_per_group": trace1_avg,
773
- "trace2_avg_per_group": trace2_avg,
774
- "trace1_time_ms": trace1_time_ms,
775
- "trace2_time_ms": trace2_time_ms,
776
- "time_ratio": time_ratio,
777
- "ratio": ratio,
778
- "fused_by": fused_by,
779
- "groups_affected": stats["groups_with_diff"],
780
- "total_groups": stats["total_groups"],
781
- }
782
- )
783
338
 
784
- # ADDITIONAL: Check global counts for significant differences not captured above
785
- # This catches patterns like Sort that may be in small groups or distributed differently
786
- for ktype in all_ktypes:
787
- # Skip if already added from correlation group analysis
788
- if any(opp["kernel_type"] == ktype for opp in results["fusion_opportunities"]):
789
- continue
790
-
791
- trace1_total = global_trace1_counts.get(ktype, 0)
792
- trace2_total = global_trace2_counts.get(ktype, 0)
793
-
794
- # Skip trivial counts
795
- if trace1_total + trace2_total < 100:
796
- continue
797
-
798
- # Calculate global ratio
799
- global_ratio = trace1_total / trace2_total if trace2_total > 0 else float("inf")
800
- global_reverse_ratio = trace2_total / trace1_total if trace1_total > 0 else float("inf")
801
-
802
- # Check if globally significant (more aggressive threshold for comprehensive detection)
803
- is_globally_significant = (
804
- (global_ratio > 2.0 or global_reverse_ratio > 2.0)
805
- and (trace1_total + trace2_total > 100)
339
+ fusion_mappings.append(
340
+ {
341
+ "fused_platform": pattern.fused_platform,
342
+ "fused_kernel_type": pattern.operation,
343
+ "fused_kernel_name": pattern.fused_kernel,
344
+ "unfused_platform": unfused_platform,
345
+ "unfused_sequence": pattern.unfused_kernels,
346
+ "pattern_count": pattern.count,
347
+ "evidence": pattern.evidence,
348
+ "layer": pattern.layer,
349
+ }
806
350
  )
807
351
 
808
- if is_globally_significant:
809
- # Get timing info from all kernels (not just matched groups)
810
- trace1_time = sum(
811
- k.get("dur", 0) for k in trace1_kernels
812
- if classify_kernel(k.get("name", "")) == ktype
813
- ) / 1000 # Convert to ms
814
- trace2_time = sum(
815
- k.get("dur", 0) for k in trace2_kernels
816
- if classify_kernel(k.get("name", "")) == ktype
817
- ) / 1000
818
-
819
- # Determine who fuses (who has FEWER calls = more fusion)
820
- if global_ratio > 1.5:
821
- fused_by = "Trace 2" # Trace 2 has fewer calls
822
- ratio = global_ratio
823
- else:
824
- fused_by = "Trace 1" # Trace 1 has fewer calls
825
- ratio = global_reverse_ratio
826
-
827
- time_ratio = trace1_time / trace2_time if trace2_time > 0 else float("inf")
828
-
829
- results["fusion_opportunities"].append(
830
- {
831
- "kernel_type": ktype,
832
- "trace1_total": trace1_total,
833
- "trace2_total": trace2_total,
834
- "trace1_avg_per_group": trace1_total / len(matches) if matches else 0,
835
- "trace2_avg_per_group": trace2_total / len(matches) if matches else 0,
836
- "trace1_time_ms": trace1_time,
837
- "trace2_time_ms": trace2_time,
838
- "time_ratio": time_ratio,
839
- "ratio": ratio,
840
- "fused_by": fused_by,
841
- "groups_affected": 0, # Unknown for global analysis
842
- "total_groups": len(matches),
843
- }
844
- )
845
-
846
- # Sort by impact (ratio * total count)
847
- results["fusion_opportunities"].sort(
848
- key=lambda x: x["ratio"] * (x["trace1_total"] + x["trace2_total"]), reverse=True
849
- )
850
-
851
- # ADD PARTIAL FUSION MAPPINGS using correlation group differential analysis
852
- # This catches patterns like Sort that exist on both platforms but with different frequencies
853
- partial_mappings = _find_partial_fusion_via_groups(
854
- trace1_large,
855
- trace2_large,
856
- matches,
857
- trace1_name=trace1_platform,
858
- trace2_name=trace2_platform
859
- )
860
- all_fusion_mappings.extend(partial_mappings)
861
-
862
- # DETECT INTRA-TYPE FUSION (same kernel type fused with itself, like Sort chains)
863
- # Do this FIRST since it's more accurate than the fallback global analysis
864
- intra_mappings = _detect_intra_type_fusion(
865
- trace1_kernels,
866
- trace2_kernels,
867
- trace1_name=trace1_platform,
868
- trace2_name=trace2_platform
869
- )
870
- all_fusion_mappings.extend(intra_mappings)
871
-
872
- # Collect kernel types already handled by intra-type fusion
873
- intra_handled_types = set(m["fused_kernel_type"] for m in intra_mappings)
874
-
875
- # ALSO ADD GLOBAL FUSION MAPPINGS for kernels not in large correlation groups
876
- # Skip types already handled by intra-type fusion (more accurate)
877
- global_mappings = _find_fusion_mappings(
878
- trace1_kernels,
879
- trace2_kernels,
880
- trace1_name=trace1_platform,
881
- trace2_name=trace2_platform
882
- )
883
- # Filter: skip if already handled or if evidence is duplicate
884
- existing_evidence = set(m["evidence"] for m in all_fusion_mappings)
885
- for mapping in global_mappings:
886
- ktype = mapping["unfused_sequence"][0] if mapping["unfused_sequence"] else None
887
- if ktype not in intra_handled_types and mapping["evidence"] not in existing_evidence:
888
- all_fusion_mappings.append(mapping)
889
-
890
- return results
352
+ return {
353
+ "fusion_opportunities": fusion_opportunities,
354
+ "fusion_mappings": fusion_mappings,
355
+ "summary": fusion_analysis.summary,
356
+ }