wafer-core 0.1.26__py3-none-any.whl → 0.1.28__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,144 +1,45 @@
  """Fusion analysis for detecting kernel fusion differences between platforms.
 
- Detects fusion differences between AMD and NVIDIA by analyzing how many kernels
- each platform launches for the same logical operations.
+ Detects fusion differences between AMD and NVIDIA by analyzing:
+ 1. Kernel types unique to each platform (one platform fuses them away)
+ 2. Sequence patterns (what comes before/after unique kernels)
+ 3. Count imbalances (one platform has significantly more calls)
+
+ This is the pattern-based approach which is more reliable than alignment-based detection.
  """
 
- import json
  from collections import Counter, defaultdict
- from pathlib import Path
+ from dataclasses import dataclass, field
  from typing import Any
 
- from .classifier import classify_kernel
-
-
- def _load_trace_for_fusion(
-     file_path: str | Path,
- ) -> tuple[str, str, list[dict[str, Any]], dict[int, list[dict[str, Any]]]]:
-     """Load trace and group kernels by correlation ID.
-
-     Args:
-         file_path: Path to trace file
-
-     Returns:
-         Tuple of (platform, gpu_name, all_kernels, corr_groups)
-     """
-     with open(file_path, "rb") as f:
-         trace = json.load(f)
-
-     # Detect platform
-     props = trace.get("deviceProperties", [{}])[0]
-     is_amd = trace.get("roctracer_version") or props.get("warpSize") == 64
-     platform = "AMD" if is_amd else "NVIDIA"
-     gpu_name = props.get("name", "MI300X" if is_amd else "Unknown GPU")
-
-     # Get all kernel events
-     events = trace.get("traceEvents", [])
-     kernels = [e for e in events if e.get("cat") == "kernel"]
-
-     # Group by correlation ID
-     corr_groups: dict[int, list[dict[str, Any]]] = defaultdict(list)
-     for k in kernels:
-         corr_id = k.get("args", {}).get("correlation")
-         if corr_id is not None:
-             corr_groups[corr_id].append(k)
+ from .classifier import classify, Op
 
-     return platform, gpu_name, kernels, dict(corr_groups)
 
+ @dataclass
+ class FusionPattern:
+     """A detected fusion pattern."""
 
- def _analyze_correlation_group(
-     kernels: list[dict[str, Any]],
- ) -> tuple[dict[str, int], dict[str, float]]:
-     """Analyze kernel composition within a correlation group.
+     layer: int
+     operation: str  # The fused operation name (e.g., "RMSNorm+GEMM")
+     fused_platform: str  # Platform that fuses (has fewer kernels)
+     fused_kernel: str  # The actual fused kernel name
+     unfused_kernels: list[str]  # List of separate kernels on the other platform
+     count: int
+     evidence: str
 
-     Args:
-         kernels: List of kernel events in the group
-
-     Returns:
-         Tuple of (counts, timings) where counts maps kernel types to counts
-         and timings maps kernel types to total duration in microseconds
-     """
-     counts: Counter[str] = Counter()
-     timings: dict[str, float] = defaultdict(float)
-
-     for k in kernels:
-         kernel_type = classify_kernel(k.get("name", ""))
-         counts[kernel_type] += 1
-         timings[kernel_type] += k.get("dur", 0)
-
-     return dict(counts), dict(timings)
 
+ @dataclass
+ class FusionAnalysis:
+     """Complete fusion analysis result."""
 
- def _match_correlation_groups(
-     amd_groups: dict[int, list[dict[str, Any]]],
-     nv_groups: dict[int, list[dict[str, Any]]],
-     size_tolerance: float = 0.25,
- ) -> list[tuple[int, int]]:
-     """Match AMD and NVIDIA correlation groups by size and composition.
+     patterns: list[FusionPattern] = field(default_factory=list)
+     summary: dict[str, Any] = field(default_factory=dict)
 
-     Since correlation IDs don't match between platforms, we match groups that
-     have similar sizes (within tolerance) and likely represent the same operation.
-
-     Args:
-         amd_groups: AMD correlation groups
-         nv_groups: NVIDIA correlation groups
-         size_tolerance: Groups match if sizes are within this fraction (25% default, increased for better matching)
-
-     Returns:
-         List of (amd_corr_id, nv_corr_id) pairs
-     """
-     matches = []
-     used_nv_ids: set[int] = set()
 
-     # Pre-compute compositions for all groups to avoid redundant analysis
-     amd_comps = {id: _analyze_correlation_group(kernels)[0] for id, kernels in amd_groups.items()}
-     nv_comps = {id: _analyze_correlation_group(kernels)[0] for id, kernels in nv_groups.items()}
-
-     # Build size-indexed lookup for NVIDIA groups for faster filtering
-     nv_by_size = [(id, len(kernels)) for id, kernels in nv_groups.items()]
-
-     # Sort AMD groups by size (largest first for better matching)
-     amd_sorted = sorted(amd_groups.items(), key=lambda x: len(x[1]), reverse=True)
-
-     for amd_id, amd_kernels in amd_sorted:
-         amd_size = len(amd_kernels)
-         amd_comp = amd_comps[amd_id]
-
-         # Calculate size bounds for matching
-         min_size = amd_size / (1 + size_tolerance)
-         max_size = amd_size * (1 + size_tolerance)
-
-         # Find best matching NVIDIA group
-         best_match = None
-         best_score = float("inf")
-
-         for nv_id, nv_size in nv_by_size:
-             if nv_id in used_nv_ids:
-                 continue
-
-             # Quick size filter
-             if nv_size < min_size or nv_size > max_size:
-                 continue
-
-             # Use pre-computed composition
-             nv_comp = nv_comps[nv_id]
-
-             # Simple similarity: number of shared kernel types
-             shared_types = set(amd_comp.keys()) & set(nv_comp.keys())
-             similarity = len(shared_types)
-
-             # Prefer matches with more shared types and closer sizes
-             score = abs(amd_size - nv_size) - (similarity * 10)
-
-             if score < best_score:
-                 best_score = score
-                 best_match = nv_id
-
-         if best_match is not None:
-             matches.append((amd_id, best_match))
-             used_nv_ids.add(best_match)
-
-     return matches
+ def _classify_kernel(name: str, platform: str = "AMD") -> str:
+     """Classify a kernel name to its operation type."""
+     op, _pattern = classify(name, platform)
+     return op.value
 
 
  def _find_fusion_mappings(
@@ -146,6 +47,7 @@ def _find_fusion_mappings(
      trace2_kernels: list[dict],
      trace1_name: str = "Trace1",
      trace2_name: str = "Trace2",
+     trace1_platform: str = "AMD",
  ) -> list[dict]:
      """Find fusion mappings by analyzing kernel execution sequence patterns.
 
@@ -157,29 +59,21 @@ def _find_fusion_mappings(
          trace2_kernels: List of kernel events from second trace
          trace1_name: Name of first platform (e.g., "AMD")
          trace2_name: Name of second platform (e.g., "NVIDIA")
+         trace1_platform: Platform string for classification ("AMD" or "NVIDIA")
 
      Returns:
-         List of mapping dictionaries, each containing:
-         - fused_platform: Which platform fuses the operations
-         - fused_kernel_type: The single fused kernel type
-         - unfused_platform: Which platform runs them separately
-         - unfused_sequence: List of kernel types run separately
-         - pattern_count: How many times this pattern appears
-         - pattern_confidence: Fraction of occurrences following this pattern
-         - evidence: Human-readable description
+         List of mapping dictionaries with fusion details
      """
-     from collections import defaultdict
-     from wafer_core.lib.trace_compare.classifier import classify_kernel
-
      mappings = []
+     trace2_platform = "NVIDIA" if trace1_platform == "AMD" else "AMD"
 
      # Sort kernels by timestamp
-     trace1_sorted = sorted(trace1_kernels, key=lambda k: k.get('ts', 0))
-     trace2_sorted = sorted(trace2_kernels, key=lambda k: k.get('ts', 0))
+     trace1_sorted = sorted(trace1_kernels, key=lambda k: k.get("ts", 0))
+     trace2_sorted = sorted(trace2_kernels, key=lambda k: k.get("ts", 0))
 
      # Classify all kernels
-     trace1_types = [classify_kernel(k.get('name', '')) for k in trace1_sorted]
-     trace2_types = [classify_kernel(k.get('name', '')) for k in trace2_sorted]
+     trace1_types = [_classify_kernel(k.get("name", ""), trace1_platform) for k in trace1_sorted]
+     trace2_types = [_classify_kernel(k.get("name", ""), trace2_platform) for k in trace2_sorted]
 
      # Find kernel types unique to each trace
      trace1_type_set = set(trace1_types)
@@ -190,30 +84,34 @@ def _find_fusion_mappings(
 
      # For each unique type in trace1, find common sequence patterns
      for unique_type in trace1_only:
+         # Skip "Other" since it's too generic
+         if unique_type == "Other":
+             continue
+
          # Find all occurrences of this type
          indices = [i for i, t in enumerate(trace1_types) if t == unique_type]
 
          if len(indices) < 5:  # Need enough samples to be meaningful
              continue
 
-         # Analyze what comes before/after each occurrence
-         before_types = defaultdict(int)
+         # Analyze what comes before each occurrence
+         before_types: dict[str, int] = defaultdict(int)
 
          for idx in indices:
              if idx > 0:
                  before_types[trace1_types[idx - 1]] += 1
 
-         # Find the most common pattern (e.g., "Attention → Reduce")
-         most_common_before = max(before_types.items(), key=lambda x: x[1]) if before_types else (None, 0)
+         # Find the most common pattern
+         if not before_types:
+             continue
+         most_common_before = max(before_types.items(), key=lambda x: x[1])
 
          # If there's a strong pattern (>80% of occurrences)
          if most_common_before[1] / len(indices) > 0.8:
-             # This suggests: Trace2 likely fuses [before_type + unique_type] into [before_type]
              fusion_candidate = most_common_before[0]
 
              # Verify trace2 has this type
              if fusion_candidate in trace2_type_set:
-                 # Count occurrences to compare
                  trace1_fusion_count = trace1_types.count(fusion_candidate)
                  trace2_fusion_count = trace2_types.count(fusion_candidate)
 
@@ -225,15 +123,18 @@ def _find_fusion_mappings(
                      "unfused_sequence": [fusion_candidate, unique_type],
                      "unfused_count_per_type": {
                          fusion_candidate: trace1_fusion_count,
-                         unique_type: len(indices)
+                         unique_type: len(indices),
                      },
                      "pattern_count": len(indices),
                      "pattern_confidence": most_common_before[1] / len(indices),
-                     "evidence": f"{trace1_name} runs {fusion_candidate}+{unique_type} separately, {trace2_name} fuses into {fusion_candidate}"
+                     "evidence": f"{trace1_name} runs {fusion_candidate}+{unique_type} separately, {trace2_name} fuses into {fusion_candidate}",
                  })
 
      # Also check trace2-only types
      for unique_type in trace2_only:
+         if unique_type == "Other":
+             continue
+
          indices = [i for i, t in enumerate(trace2_types) if t == unique_type]
 
          if len(indices) < 5:
@@ -245,7 +146,9 @@ def _find_fusion_mappings(
              if idx > 0:
                  before_types[trace2_types[idx - 1]] += 1
 
-         most_common_before = max(before_types.items(), key=lambda x: x[1]) if before_types else (None, 0)
+         if not before_types:
+             continue
+         most_common_before = max(before_types.items(), key=lambda x: x[1])
 
          if most_common_before[1] / len(indices) > 0.8:
              fusion_candidate = most_common_before[0]
@@ -262,629 +165,300 @@ def _find_fusion_mappings(
                      "unfused_sequence": [fusion_candidate, unique_type],
                      "unfused_count_per_type": {
                          fusion_candidate: trace2_fusion_count,
-                         unique_type: len(indices)
+                         unique_type: len(indices),
                      },
                      "pattern_count": len(indices),
                      "pattern_confidence": most_common_before[1] / len(indices),
-                     "evidence": f"{trace2_name} runs {fusion_candidate}+{unique_type} separately, {trace1_name} fuses into {fusion_candidate}"
-                 })
-
-     # NEW: Check for partial fusion patterns (kernel exists on both platforms but with different counts)
-     # If one platform has significantly fewer calls (>1.3x ratio), look for fusion patterns
-     common_types = trace1_type_set & trace2_type_set
-
-     for ktype in common_types:
-         trace1_count = trace1_types.count(ktype)
-         trace2_count = trace2_types.count(ktype)
-
-         # Check if there's a significant imbalance (one has >1.3x more)
-         if trace1_count == 0 or trace2_count == 0:
-             continue
-
-         ratio = max(trace1_count, trace2_count) / min(trace1_count, trace2_count)
-
-         if ratio < 1.3 or trace1_count + trace2_count < 100:
-             continue
-
-         # Determine which platform has more (unfused) and which has fewer (fused)
-         if trace1_count > trace2_count:
-             # Trace1 has more separate calls, Trace2 likely fuses
-             unfused_platform = trace1_name
-             fused_platform = trace2_name
-             unfused_count = trace1_count
-             fused_count = trace2_count
-
-             # Find what Trace2 might be fusing this into
-             # Use the most common kernel type in Trace2 as a likely fusion target
-             trace2_type_counts = defaultdict(int)
-             for t in trace2_types:
-                 if t != ktype and t != "Other":  # Skip the imbalanced type itself and "Other"
-                     trace2_type_counts[t] += 1
-
-             if trace2_type_counts:
-                 # Use the most common type as the fusion target
-                 fusion_target = max(trace2_type_counts.items(), key=lambda x: x[1])[0]
-
-                 mappings.append({
-                     "fused_platform": fused_platform,
-                     "fused_kernel_type": fusion_target,
-                     "fused_count": fused_count,
-                     "unfused_platform": unfused_platform,
-                     "unfused_sequence": [ktype],
-                     "unfused_count_per_type": {
-                         ktype: unfused_count
-                     },
-                     "pattern_count": unfused_count - fused_count,
-                     "pattern_confidence": (unfused_count - fused_count) / unfused_count,
-                     "evidence": f"{unfused_platform} calls {ktype} {ratio:.1f}x more ({unfused_count} vs {fused_count}), {fused_platform} likely fuses into {fusion_target}"
-                 })
-         else:
-             # Trace2 has more separate calls, Trace1 likely fuses
-             unfused_platform = trace2_name
-             fused_platform = trace1_name
-             unfused_count = trace2_count
-             fused_count = trace1_count
-
-             # Find what Trace1 might be fusing this into
-             # Use the most common kernel type in Trace1 as a likely fusion target
-             trace1_type_counts = defaultdict(int)
-             for t in trace1_types:
-                 if t != ktype and t != "Other":  # Skip the imbalanced type itself and "Other"
-                     trace1_type_counts[t] += 1
-
-             if trace1_type_counts:
-                 fusion_target = max(trace1_type_counts.items(), key=lambda x: x[1])[0]
-
-                 mappings.append({
-                     "fused_platform": fused_platform,
-                     "fused_kernel_type": fusion_target,
-                     "fused_count": fused_count,
-                     "unfused_platform": unfused_platform,
-                     "unfused_sequence": [ktype],
-                     "unfused_count_per_type": {
-                         ktype: unfused_count
-                     },
-                     "pattern_count": unfused_count - fused_count,
-                     "pattern_confidence": (unfused_count - fused_count) / unfused_count,
-                     "evidence": f"{unfused_platform} calls {ktype} {ratio:.1f}x more ({unfused_count} vs {fused_count}), {fused_platform} likely fuses into {fusion_target}"
+                     "evidence": f"{trace2_name} runs {fusion_candidate}+{unique_type} separately, {trace1_name} fuses into {fusion_candidate}",
                  })
 
      return mappings
 
 
- def _detect_intra_type_fusion(
+ def _find_count_imbalance_fusions(
      trace1_kernels: list[dict],
      trace2_kernels: list[dict],
-     trace1_name: str,
-     trace2_name: str,
+     trace1_name: str = "Trace1",
+     trace2_name: str = "Trace2",
+     trace1_platform: str = "AMD",
  ) -> list[dict]:
-     """Detect intra-type fusion where consecutive same-type kernels are fused.
+     """Find fusions by looking for significant count imbalances.
 
-     Example: AMD runs Sort→Sort→Sort (42 calls) while NVIDIA runs Sort→Sort (10 calls)
-     This indicates NVIDIA has a more efficient Sort implementation that fuses operations.
+     When one platform has significantly more kernel calls of a type (>1.5x),
+     it suggests the other platform fuses those operations.
      """
-     from wafer_core.lib.trace_compare.classifier import classify_kernel
-
-     def analyze_chains(kernels):
-         """Find chains of consecutive same-type kernels"""
-         sorted_kernels = sorted(kernels, key=lambda k: k.get('ts', 0))
-         types = [classify_kernel(k['name']) for k in sorted_kernels]
-
-         chains = defaultdict(list)
-         i = 0
-         while i < len(types):
-             ktype = types[i]
-             count = 0
-             while i < len(types) and types[i] == ktype:
-                 count += 1
-                 i += 1
-             chains[ktype].append(count)
-
-         return chains
-
-     trace1_chains = analyze_chains(trace1_kernels)
-     trace2_chains = analyze_chains(trace2_kernels)
-
      mappings = []
-     all_types = set(trace1_chains.keys()) | set(trace2_chains.keys())
+     trace2_platform = "NVIDIA" if trace1_platform == "AMD" else "AMD"
 
-     for ktype in all_types:
-         t1_lengths = trace1_chains.get(ktype, [])
-         t2_lengths = trace2_chains.get(ktype, [])
+     # Classify all kernels
+     trace1_types = [_classify_kernel(k.get("name", ""), trace1_platform) for k in trace1_kernels]
+     trace2_types = [_classify_kernel(k.get("name", ""), trace2_platform) for k in trace2_kernels]
 
-         # Skip if not enough data
-         if len(t1_lengths) < 5 and len(t2_lengths) < 5:
-             continue
+     # Count by type
+     trace1_counts = Counter(trace1_types)
+     trace2_counts = Counter(trace2_types)
 
-         # Filter to chains with multiple kernels
-         t1_multi = [l for l in t1_lengths if l > 1]
-         t2_multi = [l for l in t2_lengths if l > 1]
+     # Find common types with significant differences
+     common_types = set(trace1_counts.keys()) & set(trace2_counts.keys())
 
-         if not t1_multi and not t2_multi:
+     for ktype in common_types:
+         if ktype == "Other":
              continue
 
-         t1_total = sum(t1_lengths)
-         t2_total = sum(t2_lengths)
-         t1_chains = len(t1_multi) if t1_multi else len(t1_lengths)
-         t2_chains = len(t2_multi) if t2_multi else len(t2_lengths)
+         trace1_count = trace1_counts[ktype]
+         trace2_count = trace2_counts[ktype]
 
-         if t1_chains == 0 or t2_chains == 0:
+         # Skip trivial counts
+         if trace1_count + trace2_count < 50:
              continue
 
-         t1_avg_chain = sum(t1_multi) / len(t1_multi) if t1_multi else 1.0
-         t2_avg_chain = sum(t2_multi) / len(t2_multi) if t2_multi else 1.0
-
-         chain_ratio = max(t1_avg_chain, t2_avg_chain) / min(t1_avg_chain, t2_avg_chain)
-
-         # Significant intra-fusion if chains are 2x+ different
-         if chain_ratio > 2.0 and abs(t1_total - t2_total) > 100:
-             if t1_avg_chain > t2_avg_chain:
-                 unfused_platform = trace1_name
-                 fused_platform = trace2_name
-                 unfused_chains = t1_chains
-                 fused_chains = t2_chains
-                 unfused_avg = t1_avg_chain
-                 fused_avg = t2_avg_chain
-                 unfused_total = t1_total
-                 fused_total = t2_total
-             else:
-                 unfused_platform = trace2_name
-                 fused_platform = trace1_name
-                 unfused_chains = t2_chains
-                 fused_chains = t1_chains
-                 unfused_avg = t2_avg_chain
-                 fused_avg = t1_avg_chain
-                 unfused_total = t2_total
-                 fused_total = t1_total
-
-             mappings.append({
-                 "fused_platform": fused_platform,
-                 "fused_kernel_type": ktype,
-                 "fused_count": fused_total,
-                 "unfused_platform": unfused_platform,
-                 "unfused_sequence": [ktype, ktype],  # Same type repeated
-                 "unfused_count_per_type": {ktype: unfused_total},
-                 "pattern_count": unfused_total - fused_total,
-                 "pattern_confidence": min(unfused_chains, fused_chains) / max(unfused_chains, fused_chains),
-                 "evidence": f"{unfused_platform} runs {ktype} in chains of {unfused_avg:.0f} calls ({unfused_chains} chains, {unfused_total:,} total), {fused_platform} fuses to {fused_avg:.0f} calls ({fused_chains} chains, {fused_total:,} total) - {chain_ratio:.1f}x more efficient"
-             })
-
-     return mappings
-
-
- def _find_partial_fusion_via_groups(
-     trace1_large: dict[int, list[dict]],
-     trace2_large: dict[int, list[dict]],
-     matches: list[tuple[int, int]],
-     trace1_name: str,
-     trace2_name: str,
- ) -> list[dict]:
-     """Find partial fusion patterns by analyzing correlation group differences.
-
-     When one platform has fewer of a kernel type, check what kernel types the
-     other platform has MORE of in those same groups - those are likely fusion targets.
-     """
-     from collections import Counter
-     from wafer_core.lib.trace_compare.classifier import classify_kernel
-
-     mappings = []
-
-     # For each matched pair, track kernel type counts
-     trace1_all_types = []
-     trace2_all_types = []
-
-     for trace1_cid, trace2_cid in matches:
-         trace1_ktypes = [classify_kernel(k.get("name", "")) for k in trace1_large[trace1_cid]]
-         trace2_ktypes = [classify_kernel(k.get("name", "")) for k in trace2_large[trace2_cid]]
-         trace1_all_types.extend(trace1_ktypes)
-         trace2_all_types.extend(trace2_ktypes)
-
-     # Find kernel types with significant imbalances
-     trace1_counts = Counter(trace1_all_types)
-     trace2_counts = Counter(trace2_all_types)
-     all_types = set(trace1_counts.keys()) | set(trace2_counts.keys())
-
-     for ktype in all_types:
-         trace1_count = trace1_counts.get(ktype, 0)
-         trace2_count = trace2_counts.get(ktype, 0)
-
+         # Check if there's a significant imbalance (>1.5x)
          if trace1_count == 0 or trace2_count == 0:
-             continue  # Handled by sequence-based detection
+             continue
 
          ratio = max(trace1_count, trace2_count) / min(trace1_count, trace2_count)
 
-         if ratio < 1.3 or trace1_count + trace2_count < 100:
-             continue  # Not significant
+         if ratio < 1.5:
+             continue
 
-         # Determine which platform has fewer (fuses more)
+         # Determine which platform has more (unfused) and which has fewer (fused)
          if trace1_count > trace2_count:
              unfused_platform = trace1_name
              fused_platform = trace2_name
              unfused_count = trace1_count
              fused_count = trace2_count
-
-             # Find groups where trace1 has this kernel but trace2 doesn't
-             fusion_targets = Counter()
-             groups_analyzed = 0
-
-             for trace1_cid, trace2_cid in matches:
-                 trace1_ktypes = [classify_kernel(k.get("name", "")) for k in trace1_large[trace1_cid]]
-                 trace2_ktypes = [classify_kernel(k.get("name", "")) for k in trace2_large[trace2_cid]]
-
-                 trace1_has = ktype in trace1_ktypes
-                 trace2_has = ktype in trace2_ktypes
-
-                 if trace1_has and not trace2_has:
-                     # What does trace2 have MORE of in this group?
-                     trace1_kcounts = Counter(trace1_ktypes)
-                     trace2_kcounts = Counter(trace2_ktypes)
-
-                     for other_type in set(trace2_kcounts.keys()):
-                         if other_type == ktype or other_type == "Other":
-                             continue
-                         diff = trace2_kcounts[other_type] - trace1_kcounts.get(other_type, 0)
-                         if diff > 0:
-                             fusion_targets[other_type] += diff
-
-                     groups_analyzed += 1
-
-             if fusion_targets and groups_analyzed >= 5:
-                 # Report top fusion targets
-                 top_targets = fusion_targets.most_common(3)
-                 target_str = ", ".join(f"{t[0]} (+{t[1]})" for t in top_targets)
-
-                 mappings.append({
-                     "fused_platform": fused_platform,
-                     "fused_kernel_type": top_targets[0][0],
-                     "fused_count": fused_count,
-                     "unfused_platform": unfused_platform,
-                     "unfused_sequence": [ktype],
-                     "unfused_count_per_type": {ktype: unfused_count},
-                     "pattern_count": unfused_count - fused_count,
-                     "pattern_confidence": groups_analyzed / len(matches) if matches else 0,
-                     "evidence": f"{unfused_platform} calls {ktype} {ratio:.1f}x more ({unfused_count} vs {fused_count}). In {groups_analyzed} groups where {unfused_platform} has {ktype}, {fused_platform} has more: {target_str}"
-                 })
          else:
-             # Symmetric case for trace2 > trace1
              unfused_platform = trace2_name
              fused_platform = trace1_name
              unfused_count = trace2_count
              fused_count = trace1_count
 
-             fusion_targets = Counter()
-             groups_analyzed = 0
+         mappings.append({
+             "fused_platform": fused_platform,
+             "fused_kernel_type": ktype,
+             "fused_count": fused_count,
+             "unfused_platform": unfused_platform,
+             "unfused_sequence": [ktype],
+             "unfused_count_per_type": {ktype: unfused_count},
+             "pattern_count": unfused_count - fused_count,
+             "pattern_confidence": (unfused_count - fused_count) / unfused_count,
+             "evidence": f"{unfused_platform} calls {ktype} {ratio:.1f}x more ({unfused_count} vs {fused_count}), {fused_platform} likely fuses",
+         })
 
-             for trace1_cid, trace2_cid in matches:
-                 trace1_ktypes = [classify_kernel(k.get("name", "")) for k in trace1_large[trace1_cid]]
-                 trace2_ktypes = [classify_kernel(k.get("name", "")) for k in trace2_large[trace2_cid]]
+     return mappings
 
-                 trace1_has = ktype in trace1_ktypes
-                 trace2_has = ktype in trace2_ktypes
 
-                 if trace2_has and not trace1_has:
-                     trace1_kcounts = Counter(trace1_ktypes)
-                     trace2_kcounts = Counter(trace2_ktypes)
+ def _find_explicit_fused_operations(
+     trace1_kernels: list[dict],
+     trace2_kernels: list[dict],
+     trace1_name: str = "Trace1",
+     trace2_name: str = "Trace2",
+     trace1_platform: str = "AMD",
+ ) -> list[dict]:
+     """Find explicit fused operations (kernels with '+' in their classification).
 
-                     for other_type in set(trace1_kcounts.keys()):
-                         if other_type == ktype or other_type == "Other":
-                             continue
-                         diff = trace1_kcounts[other_type] - trace2_kcounts.get(other_type, 0)
-                         if diff > 0:
-                             fusion_targets[other_type] += diff
+     These are operations like 'RMSNorm+GEMM' that are explicitly classified as fused.
+     """
+     mappings = []
+     trace2_platform = "NVIDIA" if trace1_platform == "AMD" else "AMD"
+
+     def get_fused_ops(kernels: list[dict], platform: str) -> dict[str, list[str]]:
+         """Get fused operations and their kernel names."""
+         fused_ops: dict[str, list[str]] = defaultdict(list)
+         for k in kernels:
+             name = k.get("name", "")
+             op, _pattern = classify(name, platform)
+             if "+" in op.value or op == Op.FUSED_UNKNOWN:
+                 fused_ops[op.value].append(name)
+         return dict(fused_ops)
+
+     trace1_fused = get_fused_ops(trace1_kernels, trace1_platform)
+     trace2_fused = get_fused_ops(trace2_kernels, trace2_platform)
+
+     # Find fused operations unique to trace1
+     for fused_op, kernels in trace1_fused.items():
+         if fused_op not in trace2_fused:
+             # Parse components from the fused op name
+             if "+" in fused_op:
+                 components = [c.strip() for c in fused_op.split("+")]
+             else:
+                 components = [fused_op]
 
-                     groups_analyzed += 1
+             mappings.append({
+                 "fused_platform": trace1_name,
+                 "fused_kernel_type": fused_op,
+                 "fused_count": len(kernels),
+                 "unfused_platform": trace2_name,
+                 "unfused_sequence": components,
+                 "unfused_count_per_type": {c: 0 for c in components},  # Unknown
+                 "pattern_count": len(kernels),
+                 "pattern_confidence": 1.0,
+                 "evidence": f"{trace1_name} fuses {' + '.join(components)} into {fused_op} ({len(kernels)} calls)",
+                 "fused_kernel_names": kernels[:3],  # Sample of kernel names
+             })
 
-             if fusion_targets and groups_analyzed >= 5:
-                 top_targets = fusion_targets.most_common(3)
-                 target_str = ", ".join(f"{t[0]} (+{t[1]})" for t in top_targets)
+     # Find fused operations unique to trace2
+     for fused_op, kernels in trace2_fused.items():
+         if fused_op not in trace1_fused:
+             if "+" in fused_op:
+                 components = [c.strip() for c in fused_op.split("+")]
+             else:
+                 components = [fused_op]
 
-                 mappings.append({
-                     "fused_platform": fused_platform,
-                     "fused_kernel_type": top_targets[0][0],
-                     "fused_count": fused_count,
-                     "unfused_platform": unfused_platform,
-                     "unfused_sequence": [ktype],
-                     "unfused_count_per_type": {ktype: unfused_count},
-                     "pattern_count": unfused_count - fused_count,
-                     "pattern_confidence": groups_analyzed / len(matches) if matches else 0,
-                     "evidence": f"{unfused_platform} calls {ktype} {ratio:.1f}x more ({unfused_count} vs {fused_count}). In {groups_analyzed} groups where {unfused_platform} has {ktype}, {fused_platform} has more: {target_str}"
-                 })
+             mappings.append({
+                 "fused_platform": trace2_name,
+                 "fused_kernel_type": fused_op,
+                 "fused_count": len(kernels),
+                 "unfused_platform": trace1_name,
+                 "unfused_sequence": components,
+                 "unfused_count_per_type": {c: 0 for c in components},
+                 "pattern_count": len(kernels),
+                 "pattern_confidence": 1.0,
+                 "evidence": f"{trace2_name} fuses {' + '.join(components)} into {fused_op} ({len(kernels)} calls)",
+                 "fused_kernel_names": kernels[:3],
+             })
 
      return mappings
 
 
- def analyze_fusion_differences(
-     amd_trace_path: str | Path,
-     nv_trace_path: str | Path,
-     min_group_size: int = 50,
- ) -> dict[str, Any]:
-     """Main fusion analysis function.
+ def detect_fusion_patterns(
+     amd_kernels: list[dict],
+     nvidia_kernels: list[dict],
+ ) -> FusionAnalysis:
+     """Detect fusion patterns using pattern-based analysis.
+
+     This is the main entry point for fusion detection. It combines:
+     1. Explicit fused operations (kernels classified with '+' in name)
+     2. Sequence pattern analysis (unique kernel types with consistent patterns)
+     3. Count imbalance analysis (one platform has significantly more calls)
 
      Args:
-         amd_trace_path: Path to AMD trace
-         nv_trace_path: Path to NVIDIA trace
-         min_group_size: Only analyze correlation groups with at least this many kernels
+         amd_kernels: List of AMD kernel events
+         nvidia_kernels: List of NVIDIA kernel events
 
      Returns:
-         Dictionary with analysis results containing:
-         - metadata: trace info
-         - global_counts: kernel type distribution across entire trace
-         - fusion_opportunities: significant fusion differences
-         - fusion_mappings: actual kernel-to-kernel mappings (NEW)
+         FusionAnalysis with detected patterns
      """
-     # Load traces (maintain order - don't swap)
-     trace1_platform, trace1_gpu, trace1_kernels, trace1_corr_groups = _load_trace_for_fusion(
-         amd_trace_path
-     )
-     trace2_platform, trace2_gpu, trace2_kernels, trace2_corr_groups = _load_trace_for_fusion(
-         nv_trace_path
-     )
+     all_mappings: list[dict] = []
 
-     # Filter to "significant" correlation groups
-     trace1_large = {
-         cid: kernels
-         for cid, kernels in trace1_corr_groups.items()
-         if len(kernels) >= min_group_size
-     }
-     trace2_large = {
-         cid: kernels
-         for cid, kernels in trace2_corr_groups.items()
-         if len(kernels) >= min_group_size
-     }
-
-     # Match correlation groups between platforms
-     matches = _match_correlation_groups(trace1_large, trace2_large)
-
-     # Analyze differences in matched groups
-     fusion_diffs: dict[str, dict[str, Any]] = defaultdict(
-         lambda: {
-             "trace1_count": 0,
-             "trace2_count": 0,
-             "trace1_time_us": 0,
-             "trace2_time_us": 0,
-             "groups_with_diff": 0,
-             "total_groups": 0,
-         }
+     # 1. Find explicit fused operations (highest confidence)
+     explicit_fusions = _find_explicit_fused_operations(
+         amd_kernels, nvidia_kernels,
+         trace1_name="AMD", trace2_name="NVIDIA",
+         trace1_platform="AMD",
      )
+     all_mappings.extend(explicit_fusions)
 
-     # NEW: Collect actual fusion mappings
-     all_fusion_mappings = []
-
-     for trace1_cid, trace2_cid in matches:
-         trace1_comp, trace1_times = _analyze_correlation_group(trace1_large[trace1_cid])
-         trace2_comp, trace2_times = _analyze_correlation_group(trace2_large[trace2_cid])
-
-         # Find all kernel types in either platform
-         all_types = set(trace1_comp.keys()) | set(trace2_comp.keys())
-
-         for ktype in all_types:
-             trace1_count = trace1_comp.get(ktype, 0)
-             trace2_count = trace2_comp.get(ktype, 0)
-             trace1_time = trace1_times.get(ktype, 0)
-             trace2_time = trace2_times.get(ktype, 0)
-
-             fusion_diffs[ktype]["trace1_count"] += trace1_count
-             fusion_diffs[ktype]["trace2_count"] += trace2_count
-             fusion_diffs[ktype]["trace1_time_us"] += trace1_time
-             fusion_diffs[ktype]["trace2_time_us"] += trace2_time
-             fusion_diffs[ktype]["total_groups"] += 1
-
-             if trace1_count != trace2_count:
-                 fusion_diffs[ktype]["groups_with_diff"] += 1
-
-         # NEW: Find actual kernel mappings in this correlation group
-         group_mappings = _find_fusion_mappings(
-             trace1_large[trace1_cid],
-             trace2_large[trace2_cid],
-             trace1_name=trace1_platform,
-             trace2_name=trace2_platform
-         )
-         # Add correlation ID context to each mapping
-         for mapping in group_mappings:
-             mapping["correlation_group_trace1"] = trace1_cid
-             mapping["correlation_group_trace2"] = trace2_cid
-         all_fusion_mappings.extend(group_mappings)
-
-     # Also get global counts for context
-     global_trace1_counts: Counter[str] = Counter(
-         [classify_kernel(k.get("name", "")) for k in trace1_kernels]
+     # 2. Find sequence-based fusions
+     sequence_fusions = _find_fusion_mappings(
+         amd_kernels, nvidia_kernels,
+         trace1_name="AMD", trace2_name="NVIDIA",
+         trace1_platform="AMD",
      )
-     global_trace2_counts: Counter[str] = Counter(
-         [classify_kernel(k.get("name", "")) for k in trace2_kernels]
+     # Deduplicate: skip if same fused_kernel_type already found
+     existing_types = {m["fused_kernel_type"] for m in all_mappings}
+     for fusion in sequence_fusions:
+         if fusion["fused_kernel_type"] not in existing_types:
+             all_mappings.append(fusion)
+             existing_types.add(fusion["fused_kernel_type"])
+
+     # 3. Find count-imbalance fusions
+     count_fusions = _find_count_imbalance_fusions(
+         amd_kernels, nvidia_kernels,
+         trace1_name="AMD", trace2_name="NVIDIA",
+         trace1_platform="AMD",
      )
-
-     # Build results
-     results: dict[str, Any] = {
-         "metadata": {
-             "trace1_gpu": trace1_gpu,
-             "trace2_gpu": trace2_gpu,
-             "trace1_total_kernels": len(trace1_kernels),
-             "trace2_total_kernels": len(trace2_kernels),
-             "trace1_correlation_groups": len(trace1_large),
-             "trace2_correlation_groups": len(trace2_large),
-             "matched_groups": len(matches),
-         },
-         "global_counts": {},
-         "fusion_opportunities": [],
-         "fusion_mappings": all_fusion_mappings,  # NEW: Include actual mappings
-     }
-
-     # Global counts for all kernel types
-     all_ktypes = set(global_trace1_counts.keys()) | set(global_trace2_counts.keys())
-     for ktype in all_ktypes:
-         trace1_total = global_trace1_counts.get(ktype, 0)
-         trace2_total = global_trace2_counts.get(ktype, 0)
-
-         if trace1_total > 0 or trace2_total > 0:
-             results["global_counts"][ktype] = {
-                 "trace1_count": trace1_total,
-                 "trace2_count": trace2_total,
-                 "ratio": trace1_total / trace2_total if trace2_total > 0 else float("inf"),
-             }
-
-     # Fusion opportunities from matched correlation groups
-     for ktype, stats in fusion_diffs.items():
-         trace1_avg = (
-             stats["trace1_count"] / stats["total_groups"]
-             if stats["total_groups"] > 0
-             else 0
-         )
-         trace2_avg = (
-             stats["trace2_count"] / stats["total_groups"]
-             if stats["total_groups"] > 0
-             else 0
-         )
-         trace1_time_ms = stats["trace1_time_us"] / 1000
-         trace2_time_ms = stats["trace2_time_us"] / 1000
-
-         # Calculate significance
-         diff_ratio = trace1_avg / trace2_avg if trace2_avg > 0 else float("inf")
-         reverse_ratio = trace2_avg / trace1_avg if trace1_avg > 0 else float("inf")
-
-         # Only report if there's a significant difference
-         # Either: one platform has it and the other doesn't (ratio > 10)
-         # Or: one platform has significantly more (ratio > 2.0)
-         is_significant = (
-             (diff_ratio > 10.0 or reverse_ratio > 10.0)
-             or (  # One platform doesn't have it
-                 (diff_ratio > 2.0 or reverse_ratio > 2.0)
-                 and stats["trace1_count"] + stats["trace2_count"] > 20  # Significant difference  # Not trivial counts
+     # Deduplicate
+     for fusion in count_fusions:
+         if fusion["fused_kernel_type"] not in existing_types:
+             all_mappings.append(fusion)
+             existing_types.add(fusion["fused_kernel_type"])
+
+     # Convert to FusionPattern objects
+     patterns: list[FusionPattern] = []
+     for mapping in all_mappings:
+         fused_kernel_names = mapping.get("fused_kernel_names", [])
+         fused_kernel = fused_kernel_names[0] if fused_kernel_names else mapping["fused_kernel_type"]
+
+         patterns.append(
+             FusionPattern(
+                 layer=0,  # Pattern-based analysis doesn't track layers
+                 operation=mapping["fused_kernel_type"],
+                 fused_platform=mapping["fused_platform"],
+                 fused_kernel=fused_kernel,
+                 unfused_kernels=mapping["unfused_sequence"],
+                 count=mapping["pattern_count"],
+                 evidence=mapping["evidence"],
              )
          )
 
-         if is_significant:
-             # Determine who fuses (who has FEWER calls = more fusion)
-             if diff_ratio > 1.5:
-                 fused_by = trace2_platform  # Trace 2 has fewer calls, so it fuses more
-                 ratio = diff_ratio
-             else:
-                 fused_by = trace1_platform  # Trace 1 has fewer calls, so it fuses more
-                 ratio = reverse_ratio
-
-             # Calculate time ratio (who's faster for this operation)
-             time_ratio = trace1_time_ms / trace2_time_ms if trace2_time_ms > 0 else float("inf")
-
-             results["fusion_opportunities"].append(
-                 {
-                     "kernel_type": ktype,
-                     "trace1_total": stats["trace1_count"],
-                     "trace2_total": stats["trace2_count"],
-                     "trace1_avg_per_group": trace1_avg,
-                     "trace2_avg_per_group": trace2_avg,
-                     "trace1_time_ms": trace1_time_ms,
-                     "trace2_time_ms": trace2_time_ms,
-                     "time_ratio": time_ratio,
-                     "ratio": ratio,
-                     "fused_by": fused_by,
-                     "groups_affected": stats["groups_with_diff"],
-                     "total_groups": stats["total_groups"],
-                 }
-             )
+     amd_fuses = sum(1 for p in patterns if p.fused_platform == "AMD")
+     nvidia_fuses = sum(1 for p in patterns if p.fused_platform == "NVIDIA")
 
-     # ADDITIONAL: Check global counts for significant differences not captured above
-     # This catches patterns like Sort that may be in small groups or distributed differently
-     for ktype in all_ktypes:
-         # Skip if already added from correlation group analysis
-         if any(opp["kernel_type"] == ktype for opp in results["fusion_opportunities"]):
-             continue
+     return FusionAnalysis(
+         patterns=patterns,
+         summary={
+             "amd_fuses": amd_fuses,
+             "nvidia_fuses": nvidia_fuses,
+             "total_fusion_opportunities": len(patterns),
+         },
+     )
 
-         trace1_total = global_trace1_counts.get(ktype, 0)
-         trace2_total = global_trace2_counts.get(ktype, 0)
 
-         # Skip trivial counts
-         if trace1_total + trace2_total < 100:
-             continue
+ def analyze_fusion_from_alignment(
+     layer_alignments: list[Any],
+     amd_kernels: list[dict] | None = None,
+     nvidia_kernels: list[dict] | None = None,
+ ) -> dict[str, Any]:
+     """Analyze fusion from kernel data (for API compatibility).
 
-         # Calculate global ratio
-         global_ratio = trace1_total / trace2_total if trace2_total > 0 else float("inf")
-         global_reverse_ratio = trace2_total / trace1_total if trace1_total > 0 else float("inf")
+     Args:
+         layer_alignments: List of aligned layers (unused - kept for API compatibility)
+         amd_kernels: Optional list of AMD kernel events for pattern-based analysis
+         nvidia_kernels: Optional list of NVIDIA kernel events for pattern-based analysis
 
-         # Check if globally significant (more aggressive threshold for comprehensive detection)
-         is_globally_significant = (
-             (global_ratio > 2.0 or global_reverse_ratio > 2.0)
-             and (trace1_total + trace2_total > 100)
+     Returns:
+         Dictionary with fusion analysis results
+     """
+     # If raw kernels provided, use pattern-based analysis (preferred)
+     if amd_kernels is not None and nvidia_kernels is not None:
+         fusion_analysis = detect_fusion_patterns(amd_kernels, nvidia_kernels)
+     else:
+         # Fallback: empty analysis if no kernel data
+         fusion_analysis = FusionAnalysis(
+             patterns=[],
+             summary={"amd_fuses": 0, "nvidia_fuses": 0, "total_fusion_opportunities": 0},
         )
 
-         if is_globally_significant:
-             # Get timing info from all kernels (not just matched groups)
-             trace1_time = sum(
-                 k.get("dur", 0) for k in trace1_kernels
-                 if classify_kernel(k.get("name", "")) == ktype
-             ) / 1000  # Convert to ms
-             trace2_time = sum(
-                 k.get("dur", 0) for k in trace2_kernels
-                 if classify_kernel(k.get("name", "")) == ktype
-             ) / 1000
-
-             # Determine who fuses (who has FEWER calls = more fusion)
-             if global_ratio > 1.5:
-                 fused_by = "Trace 2"  # Trace 2 has fewer calls
-                 ratio = global_ratio
-             else:
-                 fused_by = "Trace 1"  # Trace 1 has fewer calls
-                 ratio = global_reverse_ratio
-
-             time_ratio = trace1_time / trace2_time if trace2_time > 0 else float("inf")
-
-             results["fusion_opportunities"].append(
-                 {
-                     "kernel_type": ktype,
-                     "trace1_total": trace1_total,
-                     "trace2_total": trace2_total,
-                     "trace1_avg_per_group": trace1_total / len(matches) if matches else 0,
-                     "trace2_avg_per_group": trace2_total / len(matches) if matches else 0,
-                     "trace1_time_ms": trace1_time,
-                     "trace2_time_ms": trace2_time,
-                     "time_ratio": time_ratio,
-                     "ratio": ratio,
-                     "fused_by": fused_by,
-                     "groups_affected": 0,  # Unknown for global analysis
-                     "total_groups": len(matches),
-                 }
-             )
-
-     # Sort by impact (ratio * total count)
-     results["fusion_opportunities"].sort(
-         key=lambda x: x["ratio"] * (x["trace1_total"] + x["trace2_total"]), reverse=True
-     )
-
-     # ADD PARTIAL FUSION MAPPINGS using correlation group differential analysis
-     # This catches patterns like Sort that exist on both platforms but with different frequencies
-     partial_mappings = _find_partial_fusion_via_groups(
-         trace1_large,
-         trace2_large,
-         matches,
-         trace1_name=trace1_platform,
-         trace2_name=trace2_platform
-     )
-     all_fusion_mappings.extend(partial_mappings)
-
-     # DETECT INTRA-TYPE FUSION (same kernel type fused with itself, like Sort chains)
-     # Do this FIRST since it's more accurate than the fallback global analysis
-     intra_mappings = _detect_intra_type_fusion(
-         trace1_kernels,
-         trace2_kernels,
-         trace1_name=trace1_platform,
-         trace2_name=trace2_platform
-     )
-     all_fusion_mappings.extend(intra_mappings)
-
-     # Collect kernel types already handled by intra-type fusion
-     intra_handled_types = set(m["fused_kernel_type"] for m in intra_mappings)
-
-     # ALSO ADD GLOBAL FUSION MAPPINGS for kernels not in large correlation groups
-     # Skip types already handled by intra-type fusion (more accurate)
-     global_mappings = _find_fusion_mappings(
-         trace1_kernels,
-         trace2_kernels,
-         trace1_name=trace1_platform,
-         trace2_name=trace2_platform
-     )
-     # Filter: skip if already handled or if evidence is duplicate
-     existing_evidence = set(m["evidence"] for m in all_fusion_mappings)
-     for mapping in global_mappings:
-         ktype = mapping["unfused_sequence"][0] if mapping["unfused_sequence"] else None
-         if ktype not in intra_handled_types and mapping["evidence"] not in existing_evidence:
-             all_fusion_mappings.append(mapping)
-
-     return results
+     fusion_opportunities = []
+     fusion_mappings = []
+
+     for pattern in fusion_analysis.patterns:
+         unfused_platform = "NVIDIA" if pattern.fused_platform == "AMD" else "AMD"
+
+         fusion_opportunities.append({
+             "kernel_type": pattern.operation,
+             "layer": pattern.layer,
+             "fused_by": pattern.fused_platform,
+             "fused_kernel": pattern.fused_kernel,
+             "unfused_kernels": pattern.unfused_kernels,
+             "count": pattern.count,
+             "evidence": pattern.evidence,
+         })
+
+         fusion_mappings.append({
+             "fused_platform": pattern.fused_platform,
+             "fused_kernel_type": pattern.operation,
+             "fused_kernel_name": pattern.fused_kernel,
+             "unfused_platform": unfused_platform,
+             "unfused_sequence": pattern.unfused_kernels,
+             "pattern_count": pattern.count,
+             "evidence": pattern.evidence,
+             "layer": pattern.layer,
+         })
+
+     return {
+         "fusion_opportunities": fusion_opportunities,
+         "fusion_mappings": fusion_mappings,
+         "summary": fusion_analysis.summary,
+     }
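
A minimal usage sketch of the new 0.1.28 entry points, for orientation only. The module path is an assumption inferred from the wafer_core.lib.trace_compare.classifier imports removed above, and the kernel events are hypothetical Chrome-trace-style dicts ("name"/"ts"/"dur"); with toy inputs this small, the detection thresholds in the code (at least 5 occurrences, 50+ combined calls) mean no patterns would actually be reported.

# Hypothetical sketch, not part of the package diff. Assumes the fusion module
# lives under wafer_core.lib.trace_compare (path inferred, not confirmed).
from wafer_core.lib.trace_compare.fusion import (
    analyze_fusion_from_alignment,
    detect_fusion_patterns,
)

# Chrome-trace-style kernel events; real inputs come from parsed GPU traces.
amd_kernels = [{"name": "fused_rmsnorm_gemm_kernel", "ts": 0, "dur": 12}]
nvidia_kernels = [
    {"name": "rmsnorm_kernel", "ts": 0, "dur": 4},
    {"name": "gemm_kernel", "ts": 4, "dur": 9},
]

# Dataclass result: FusionAnalysis(patterns=[FusionPattern(...)], summary={...})
analysis = detect_fusion_patterns(amd_kernels, nvidia_kernels)
for p in analysis.patterns:
    print(p.fused_platform, p.operation, p.evidence)

# Dict-shaped wrapper kept for API compatibility; layer_alignments is unused.
report = analyze_fusion_from_alignment(
    [], amd_kernels=amd_kernels, nvidia_kernels=nvidia_kernels
)
print(report["summary"])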