wafer-core 0.1.32__py3-none-any.whl → 0.1.34__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,45 +1,179 @@
1
1
  """Fusion analysis for detecting kernel fusion differences between platforms.
2
2
 
3
- Detects fusion differences between AMD and NVIDIA by analyzing:
4
- 1. Kernel types unique to each platform (one platform fuses them away)
5
- 2. Sequence patterns (what comes before/after unique kernels)
6
- 3. Count imbalances (one platform has significantly more calls)
7
-
8
- This is the pattern-based approach which is more reliable than alignment-based detection.
3
+ Detects fusion differences between AMD and NVIDIA by analyzing how many kernels
4
+ each platform launches for the same logical operations.
9
5
  """
10
6
 
7
+ import json
11
8
  from collections import Counter, defaultdict
12
- from dataclasses import dataclass, field
9
+ from functools import lru_cache
10
+ from pathlib import Path
13
11
  from typing import Any
14
12
 
15
- from .classifier import classify, Op
13
+ from .classifier import classify_kernel
14
+
15
+
16
@lru_cache(maxsize=10000)
def _classify_kernel_cached(name: str) -> str:
    """Return ``classify_kernel(name)``, memoised per distinct kernel name.

    Results for up to 10,000 distinct names are kept in an LRU cache, so a
    name that shows up many times in a trace is only classified once.

    Args:
        name: Kernel name to classify.

    Returns:
        Whatever ``classify_kernel`` returns for ``name``.
    """
    kernel_type = classify_kernel(name)
    return kernel_type
20
+
21
+
22
+ def _load_trace_for_fusion(
23
+ file_path: str | Path,
24
+ ) -> tuple[str, str, list[dict[str, Any]], dict[int, list[dict[str, Any]]]]:
25
+ """Load trace and group kernels by correlation ID.
26
+
27
+ Args:
28
+ file_path: Path to trace file
29
+
30
+ Returns:
31
+ Tuple of (platform, gpu_name, all_kernels, corr_groups)
32
+ """
33
+ with open(file_path, "rb") as f:
34
+ trace = json.load(f)
35
+
36
+ # Detect platform
37
+ props = trace.get("deviceProperties", [{}])[0]
38
+ is_amd = trace.get("roctracer_version") or props.get("warpSize") == 64
39
+ platform = "AMD" if is_amd else "NVIDIA"
40
+ gpu_name = props.get("name", "MI300X" if is_amd else "Unknown GPU")
41
+
42
+ # Get all kernel events
43
+ events = trace.get("traceEvents", [])
44
+ kernels = [e for e in events if e.get("cat") == "kernel"]
45
+
46
+ # Group by correlation ID
47
+ corr_groups: dict[int, list[dict[str, Any]]] = defaultdict(list)
48
+ for k in kernels:
49
+ corr_id = k.get("args", {}).get("correlation")
50
+ if corr_id is not None:
51
+ corr_groups[corr_id].append(k)
52
+
53
+ return platform, gpu_name, kernels, dict(corr_groups)
54
+
55
+
56
def _compute_group_signature(kernels: list[dict[str, Any]]) -> tuple:
    """Build a hashable signature describing a correlation group.

    The signature is used as a dictionary key so that groups with the same
    coarse shape can be looked up directly instead of compared pairwise.

    Args:
        kernels: List of kernel events in the group

    Returns:
        Tuple of (size_bucket, has_attention, has_ffn, dominant_type) where
        ``size_bucket`` is the group size rounded *down* to a multiple of 10,
        ``has_attention`` is True if any kernel name contains "attention" or
        "fmha", ``has_ffn`` is True if any kernel name contains "cijk",
        "nvjet" or "gemm" (all case-insensitive), and ``dominant_type`` is
        the most frequent classified kernel type ("Other" for an empty group).
    """

    def name_mentions(event: dict[str, Any], needles: tuple[str, ...]) -> bool:
        # Case-insensitive substring test against the event's kernel name.
        lowered = event.get("name", "").lower()
        return any(needle in lowered for needle in needles)

    type_counts: Counter[str] = Counter()
    for event in kernels:
        type_counts[_classify_kernel_cached(event.get("name", ""))] += 1

    if type_counts:
        dominant_type = type_counts.most_common(1)[0][0]
    else:
        dominant_type = "Other"

    group_size = len(kernels)
    bucket = group_size - group_size % 10  # floor to a multiple of 10

    mentions_attention = any(
        name_mentions(event, ("attention", "fmha")) for event in kernels
    )
    mentions_ffn = any(
        name_mentions(event, ("cijk", "nvjet", "gemm")) for event in kernels
    )
    return (bucket, mentions_attention, mentions_ffn, dominant_type)
76
+
77
+
78
def _analyze_correlation_group(
    kernels: list[dict[str, Any]],
) -> tuple[dict[str, int], dict[str, float]]:
    """Summarise a correlation group by classified kernel type.

    Args:
        kernels: List of kernel events in the group

    Returns:
        Tuple of (counts, timings): ``counts`` maps each kernel type to the
        number of kernels of that type, and ``timings`` maps each kernel type
        to the sum of the ``dur`` fields of those kernels (a missing ``dur``
        counts as 0).
    """
    per_type_count: dict[str, int] = {}
    per_type_duration: dict[str, float] = {}

    for event in kernels:
        kind = _classify_kernel_cached(event.get("name", ""))
        per_type_count[kind] = per_type_count.get(kind, 0) + 1
        per_type_duration[kind] = per_type_duration.get(kind, 0.0) + event.get("dur", 0)

    return per_type_count, per_type_duration
29
99
 
30
100
 
31
- @dataclass
32
- class FusionAnalysis:
33
- """Complete fusion analysis result."""
101
def _match_correlation_groups(
    trace1_groups: dict[int, list[dict[str, Any]]],
    trace2_groups: dict[int, list[dict[str, Any]]],
    size_tolerance: float = 0.25,
) -> list[tuple[int, int]]:
    """Pair up correlation groups from two traces.

    Trace 2 groups are indexed by their signature. Each trace 1 group
    (largest first) looks up unused trace 2 groups with the same signature,
    widening to the two neighbouring size buckets only when that lookup is
    empty. Candidates outside the size tolerance are dropped and the rest are
    scored by ``abs(size difference) - 10 * (number of shared kernel types)``;
    the lowest score wins and the earliest candidate wins ties. Every trace 2
    group is matched at most once.

    Args:
        trace1_groups: Trace 1 correlation groups
        trace2_groups: Trace 2 correlation groups
        size_tolerance: A candidate of size ``s`` is kept for a trace 1 group
            of size ``n`` when ``n / (1 + tol) <= s <= n * (1 + tol)``

    Returns:
        List of (trace1_corr_id, trace2_corr_id) pairs
    """
    paired: list[tuple[int, int]] = []
    consumed: set[int] = set()

    # Kernel-type counts per group, used only for the similarity score.
    comps1 = {
        cid: _analyze_correlation_group(events)[0]
        for cid, events in trace1_groups.items()
    }
    comps2 = {
        cid: _analyze_correlation_group(events)[0]
        for cid, events in trace2_groups.items()
    }

    # signature -> [(trace2 correlation id, group size), ...]
    by_signature: dict[tuple, list[tuple[int, int]]] = defaultdict(list)
    for cid, events in trace2_groups.items():
        by_signature[_compute_group_signature(events)].append((cid, len(events)))

    def unused_with(signature: tuple) -> list[tuple[int, int]]:
        # Trace 2 groups under this signature that have not been matched yet.
        return [
            (cid, size)
            for cid, size in by_signature.get(signature, [])
            if cid not in consumed
        ]

    largest_first = sorted(
        trace1_groups.items(), key=lambda item: len(item[1]), reverse=True
    )

    for cid1, events1 in largest_first:
        size1 = len(events1)
        types1 = set(comps1[cid1].keys())
        bucket, has_attn, has_ffn, dominant = _compute_group_signature(events1)

        pool = unused_with((bucket, has_attn, has_ffn, dominant))
        if not pool:
            # Nothing in the exact bucket: try one bucket below, then one above.
            for neighbour in (bucket - 10, bucket + 10):
                pool.extend(unused_with((neighbour, has_attn, has_ffn, dominant)))

        lower_bound = size1 / (1 + size_tolerance)
        upper_bound = size1 * (1 + size_tolerance)
        pool = [(cid, size) for cid, size in pool if lower_bound <= size <= upper_bound]
        if not pool:
            continue

        def penalty(entry: tuple[int, int]) -> int:
            cid2, size2 = entry
            overlap = len(types1 & set(comps2[cid2].keys()))
            return abs(size1 - size2) - (overlap * 10)

        # min() returns the first minimal element, so ties go to the earliest
        # candidate in the pool.
        winner = min(pool, key=penalty)[0]
        paired.append((cid1, winner))
        consumed.add(winner)

    return paired
43
177
 
44
178
 
45
179
  def _find_fusion_mappings(
@@ -47,7 +181,6 @@ def _find_fusion_mappings(
47
181
  trace2_kernels: list[dict],
48
182
  trace1_name: str = "Trace1",
49
183
  trace2_name: str = "Trace2",
50
- trace1_platform: str = "AMD",
51
184
  ) -> list[dict]:
52
185
  """Find fusion mappings by analyzing kernel execution sequence patterns.
53
186
 
@@ -59,21 +192,29 @@ def _find_fusion_mappings(
59
192
  trace2_kernels: List of kernel events from second trace
60
193
  trace1_name: Name of first platform (e.g., "AMD")
61
194
  trace2_name: Name of second platform (e.g., "NVIDIA")
62
- trace1_platform: Platform string for classification ("AMD" or "NVIDIA")
63
195
 
64
196
  Returns:
65
- List of mapping dictionaries with fusion details
197
+ List of mapping dictionaries, each containing:
198
+ - fused_platform: Which platform fuses the operations
199
+ - fused_kernel_type: The single fused kernel type
200
+ - unfused_platform: Which platform runs them separately
201
+ - unfused_sequence: List of kernel types run separately
202
+ - pattern_count: How many times this pattern appears
203
+ - pattern_confidence: Fraction of occurrences following this pattern
204
+ - evidence: Human-readable description
66
205
  """
206
+ from collections import defaultdict
207
+ from wafer_core.lib.trace_compare.classifier import classify_kernel
208
+
67
209
  mappings = []
68
- trace2_platform = "NVIDIA" if trace1_platform == "AMD" else "AMD"
69
210
 
70
211
  # Sort kernels by timestamp
71
- trace1_sorted = sorted(trace1_kernels, key=lambda k: k.get("ts", 0))
72
- trace2_sorted = sorted(trace2_kernels, key=lambda k: k.get("ts", 0))
212
+ trace1_sorted = sorted(trace1_kernels, key=lambda k: k.get('ts', 0))
213
+ trace2_sorted = sorted(trace2_kernels, key=lambda k: k.get('ts', 0))
73
214
 
74
215
  # Classify all kernels
75
- trace1_types = [_classify_kernel(k.get("name", ""), trace1_platform) for k in trace1_sorted]
76
- trace2_types = [_classify_kernel(k.get("name", ""), trace2_platform) for k in trace2_sorted]
216
+ trace1_types = [_classify_kernel_cached(k.get('name', '')) for k in trace1_sorted]
217
+ trace2_types = [_classify_kernel_cached(k.get('name', '')) for k in trace2_sorted]
77
218
 
78
219
  # Find kernel types unique to each trace
79
220
  trace1_type_set = set(trace1_types)
@@ -82,61 +223,32 @@ def _find_fusion_mappings(
82
223
  trace1_only = trace1_type_set - trace2_type_set
83
224
  trace2_only = trace2_type_set - trace1_type_set
84
225
 
85
- # For each unique type in trace1, check if it's a fused operation
86
- # If trace1 has a unique kernel type that trace2 doesn't have, trace1 is likely fusing
226
+ # For each unique type in trace1, find common sequence patterns
87
227
  for unique_type in trace1_only:
88
- # Skip "Other" since it's too generic
89
- if unique_type == "Other":
90
- continue
91
-
92
- # If the unique type contains '+', it's explicitly a fused kernel
93
- # This means trace1 (which has it) is fusing, not trace2
94
- if "+" in unique_type:
95
- # Parse components from the fused op name
96
- components = [c.strip() for c in unique_type.split("+")]
97
- indices = [i for i, t in enumerate(trace1_types) if t == unique_type]
98
-
99
- if len(indices) < 5:
100
- continue
101
-
102
- mappings.append({
103
- "fused_platform": trace1_name,
104
- "fused_kernel_type": unique_type,
105
- "fused_count": len(indices),
106
- "unfused_platform": trace2_name,
107
- "unfused_sequence": components,
108
- "unfused_count_per_type": {c: trace2_types.count(c) for c in components},
109
- "pattern_count": len(indices),
110
- "pattern_confidence": 1.0,
111
- "evidence": f"{trace1_name} fuses {' + '.join(components)} into single kernel ({len(indices)} calls), {trace2_name} runs separately",
112
- })
113
- continue
114
-
115
- # For non-fused unique types, find all occurrences
228
+ # Find all occurrences of this type
116
229
  indices = [i for i, t in enumerate(trace1_types) if t == unique_type]
117
230
 
118
231
  if len(indices) < 5: # Need enough samples to be meaningful
119
232
  continue
120
233
 
121
- # Analyze what comes before each occurrence
122
- before_types: dict[str, int] = defaultdict(int)
234
+ # Analyze what comes before/after each occurrence
235
+ before_types = defaultdict(int)
123
236
 
124
237
  for idx in indices:
125
238
  if idx > 0:
126
239
  before_types[trace1_types[idx - 1]] += 1
127
240
 
128
- # Find the most common pattern
129
- if not before_types:
130
- continue
131
- most_common_before = max(before_types.items(), key=lambda x: x[1])
241
+ # Find the most common pattern (e.g., "Attention → Reduce")
242
+ most_common_before = max(before_types.items(), key=lambda x: x[1]) if before_types else (None, 0)
132
243
 
133
- # If there's a strong pattern (>80% of occurrences) and trace2 has the preceding type,
134
- # trace1 runs them separately while trace2 might fuse
244
+ # If there's a strong pattern (>80% of occurrences)
135
245
  if most_common_before[1] / len(indices) > 0.8:
246
+ # This suggests: Trace2 likely fuses [before_type + unique_type] into [before_type]
136
247
  fusion_candidate = most_common_before[0]
137
248
 
138
- # Verify trace2 has this type but NOT the unique_type
249
+ # Verify trace2 has this type
139
250
  if fusion_candidate in trace2_type_set:
251
+ # Count occurrences to compare
140
252
  trace1_fusion_count = trace1_types.count(fusion_candidate)
141
253
  trace2_fusion_count = trace2_types.count(fusion_candidate)
142
254
 
@@ -148,40 +260,15 @@ def _find_fusion_mappings(
148
260
  "unfused_sequence": [fusion_candidate, unique_type],
149
261
  "unfused_count_per_type": {
150
262
  fusion_candidate: trace1_fusion_count,
151
- unique_type: len(indices),
263
+ unique_type: len(indices)
152
264
  },
153
265
  "pattern_count": len(indices),
154
266
  "pattern_confidence": most_common_before[1] / len(indices),
155
- "evidence": f"{trace1_name} runs {fusion_candidate} + {unique_type} separately, {trace2_name} fuses into {fusion_candidate}",
267
+ "evidence": f"{trace1_name} runs {fusion_candidate}+{unique_type} separately, {trace2_name} fuses into {fusion_candidate}"
156
268
  })
157
269
 
158
270
  # Also check trace2-only types
159
271
  for unique_type in trace2_only:
160
- if unique_type == "Other":
161
- continue
162
-
163
- # If the unique type contains '+', it's explicitly a fused kernel
164
- # This means trace2 (which has it) is fusing, not trace1
165
- if "+" in unique_type:
166
- components = [c.strip() for c in unique_type.split("+")]
167
- indices = [i for i, t in enumerate(trace2_types) if t == unique_type]
168
-
169
- if len(indices) < 5:
170
- continue
171
-
172
- mappings.append({
173
- "fused_platform": trace2_name,
174
- "fused_kernel_type": unique_type,
175
- "fused_count": len(indices),
176
- "unfused_platform": trace1_name,
177
- "unfused_sequence": components,
178
- "unfused_count_per_type": {c: trace1_types.count(c) for c in components},
179
- "pattern_count": len(indices),
180
- "pattern_confidence": 1.0,
181
- "evidence": f"{trace2_name} fuses {' + '.join(components)} into single kernel ({len(indices)} calls), {trace1_name} runs separately",
182
- })
183
- continue
184
-
185
272
  indices = [i for i, t in enumerate(trace2_types) if t == unique_type]
186
273
 
187
274
  if len(indices) < 5:
@@ -193,9 +280,7 @@ def _find_fusion_mappings(
193
280
  if idx > 0:
194
281
  before_types[trace2_types[idx - 1]] += 1
195
282
 
196
- if not before_types:
197
- continue
198
- most_common_before = max(before_types.items(), key=lambda x: x[1])
283
+ most_common_before = max(before_types.items(), key=lambda x: x[1]) if before_types else (None, 0)
199
284
 
200
285
  if most_common_before[1] / len(indices) > 0.8:
201
286
  fusion_candidate = most_common_before[0]
@@ -212,294 +297,633 @@ def _find_fusion_mappings(
212
297
  "unfused_sequence": [fusion_candidate, unique_type],
213
298
  "unfused_count_per_type": {
214
299
  fusion_candidate: trace2_fusion_count,
215
- unique_type: len(indices),
300
+ unique_type: len(indices)
216
301
  },
217
302
  "pattern_count": len(indices),
218
303
  "pattern_confidence": most_common_before[1] / len(indices),
219
- "evidence": f"{trace2_name} runs {fusion_candidate} + {unique_type} separately, {trace1_name} fuses into {fusion_candidate}",
304
+ "evidence": f"{trace2_name} runs {fusion_candidate}+{unique_type} separately, {trace1_name} fuses into {fusion_candidate}"
305
+ })
306
+
307
+ # NEW: Check for partial fusion patterns (kernel exists on both platforms but with different counts)
308
+ # If one platform has significantly fewer calls (>1.3x ratio), look for fusion patterns
309
+ common_types = trace1_type_set & trace2_type_set
310
+
311
+ for ktype in common_types:
312
+ trace1_count = trace1_types.count(ktype)
313
+ trace2_count = trace2_types.count(ktype)
314
+
315
+ # Check if there's a significant imbalance (one has >1.3x more)
316
+ if trace1_count == 0 or trace2_count == 0:
317
+ continue
318
+
319
+ ratio = max(trace1_count, trace2_count) / min(trace1_count, trace2_count)
320
+
321
+ if ratio < 1.3 or trace1_count + trace2_count < 100:
322
+ continue
323
+
324
+ # Determine which platform has more (unfused) and which has fewer (fused)
325
+ if trace1_count > trace2_count:
326
+ # Trace1 has more separate calls, Trace2 likely fuses
327
+ unfused_platform = trace1_name
328
+ fused_platform = trace2_name
329
+ unfused_count = trace1_count
330
+ fused_count = trace2_count
331
+
332
+ # Find what Trace2 might be fusing this into
333
+ # Use the most common kernel type in Trace2 as a likely fusion target
334
+ trace2_type_counts = defaultdict(int)
335
+ for t in trace2_types:
336
+ if t != ktype and t != "Other": # Skip the imbalanced type itself and "Other"
337
+ trace2_type_counts[t] += 1
338
+
339
+ if trace2_type_counts:
340
+ # Use the most common type as the fusion target
341
+ fusion_target = max(trace2_type_counts.items(), key=lambda x: x[1])[0]
342
+
343
+ mappings.append({
344
+ "fused_platform": fused_platform,
345
+ "fused_kernel_type": fusion_target,
346
+ "fused_count": fused_count,
347
+ "unfused_platform": unfused_platform,
348
+ "unfused_sequence": [ktype],
349
+ "unfused_count_per_type": {
350
+ ktype: unfused_count
351
+ },
352
+ "pattern_count": unfused_count - fused_count,
353
+ "pattern_confidence": (unfused_count - fused_count) / unfused_count,
354
+ "evidence": f"{unfused_platform} calls {ktype} {ratio:.1f}x more ({unfused_count} vs {fused_count}), {fused_platform} likely fuses into {fusion_target}"
355
+ })
356
+ else:
357
+ # Trace2 has more separate calls, Trace1 likely fuses
358
+ unfused_platform = trace2_name
359
+ fused_platform = trace1_name
360
+ unfused_count = trace2_count
361
+ fused_count = trace1_count
362
+
363
+ # Find what Trace1 might be fusing this into
364
+ # Use the most common kernel type in Trace1 as a likely fusion target
365
+ trace1_type_counts = defaultdict(int)
366
+ for t in trace1_types:
367
+ if t != ktype and t != "Other": # Skip the imbalanced type itself and "Other"
368
+ trace1_type_counts[t] += 1
369
+
370
+ if trace1_type_counts:
371
+ fusion_target = max(trace1_type_counts.items(), key=lambda x: x[1])[0]
372
+
373
+ mappings.append({
374
+ "fused_platform": fused_platform,
375
+ "fused_kernel_type": fusion_target,
376
+ "fused_count": fused_count,
377
+ "unfused_platform": unfused_platform,
378
+ "unfused_sequence": [ktype],
379
+ "unfused_count_per_type": {
380
+ ktype: unfused_count
381
+ },
382
+ "pattern_count": unfused_count - fused_count,
383
+ "pattern_confidence": (unfused_count - fused_count) / unfused_count,
384
+ "evidence": f"{unfused_platform} calls {ktype} {ratio:.1f}x more ({unfused_count} vs {fused_count}), {fused_platform} likely fuses into {fusion_target}"
220
385
  })
221
386
 
222
387
  return mappings
223
388
 
224
389
 
225
- def _find_count_imbalance_fusions(
390
def _detect_intra_type_fusion(
    trace1_kernels: list[dict],
    trace2_kernels: list[dict],
    trace1_name: str,
    trace2_name: str,
) -> list[dict]:
    """Detect intra-type fusion where consecutive same-type kernels are fused.

    Example: AMD runs Sort→Sort→Sort (42 calls) while NVIDIA runs Sort→Sort (10 calls)
    This indicates NVIDIA has a more efficient Sort implementation that fuses operations.

    Each trace is sorted by ``ts`` and split into runs ("chains") of
    consecutive kernels with the same classified type. A kernel type is
    reported when the average length of its multi-kernel chains differs by
    more than 2x between the traces and the total kernel counts of that type
    differ by more than 100.

    Args:
        trace1_kernels: Kernel events from the first trace
        trace2_kernels: Kernel events from the second trace
        trace1_name: Display name of the first platform
        trace2_name: Display name of the second platform

    Returns:
        List of mapping dictionaries, ordered by kernel type name. The
        platform with the shorter average chain is reported as the fused one.
    """
    from itertools import groupby

    def analyze_chains(kernels: list[dict]) -> dict[str, list[int]]:
        """Return, per kernel type, the lengths of its consecutive runs."""
        sorted_kernels = sorted(kernels, key=lambda k: k.get("ts", 0))
        # .get() keeps events without a name from raising KeyError, matching
        # how the other helpers in this module read kernel names.
        types = (_classify_kernel_cached(k.get("name", "")) for k in sorted_kernels)

        chains: dict[str, list[int]] = defaultdict(list)
        for ktype, run in groupby(types):
            chains[ktype].append(sum(1 for _ in run))
        return chains

    trace1_chains = analyze_chains(trace1_kernels)
    trace2_chains = analyze_chains(trace2_kernels)

    mappings: list[dict] = []

    # Sorted so the order of the reported mappings is reproducible between
    # runs; iterating the raw set of strings is not.
    for ktype in sorted(set(trace1_chains) | set(trace2_chains)):
        t1_lengths = trace1_chains.get(ktype, [])
        t2_lengths = trace2_chains.get(ktype, [])

        # Skip if not enough data (fewer than 5 chains on both sides)
        if len(t1_lengths) < 5 and len(t2_lengths) < 5:
            continue

        # Filter to chains with multiple kernels
        t1_multi = [length for length in t1_lengths if length > 1]
        t2_multi = [length for length in t2_lengths if length > 1]

        if not t1_multi and not t2_multi:
            continue

        t1_total = sum(t1_lengths)
        t2_total = sum(t2_lengths)
        # Count multi-kernel chains when there are any, otherwise all chains.
        t1_chains = len(t1_multi) if t1_multi else len(t1_lengths)
        t2_chains = len(t2_multi) if t2_multi else len(t2_lengths)

        # The type is absent from one of the traces.
        if t1_chains == 0 or t2_chains == 0:
            continue

        # A side without multi-kernel chains is treated as chain length 1.
        t1_avg_chain = sum(t1_multi) / len(t1_multi) if t1_multi else 1.0
        t2_avg_chain = sum(t2_multi) / len(t2_multi) if t2_multi else 1.0

        chain_ratio = max(t1_avg_chain, t2_avg_chain) / min(t1_avg_chain, t2_avg_chain)

        # Significant intra-fusion if chains are 2x+ different
        if not (chain_ratio > 2.0 and abs(t1_total - t2_total) > 100):
            continue

        side1 = (trace1_name, t1_chains, t1_avg_chain, t1_total)
        side2 = (trace2_name, t2_chains, t2_avg_chain, t2_total)
        # The platform with the longer average chain is the unfused one.
        unfused, fused = (side1, side2) if t1_avg_chain > t2_avg_chain else (side2, side1)
        unfused_platform, unfused_chains, unfused_avg, unfused_total = unfused
        fused_platform, fused_chains, fused_avg, fused_total = fused

        mappings.append({
            "fused_platform": fused_platform,
            "fused_kernel_type": ktype,
            "fused_count": fused_total,
            "unfused_platform": unfused_platform,
            "unfused_sequence": [ktype, ktype],  # Same type repeated
            "unfused_count_per_type": {ktype: unfused_total},
            "pattern_count": unfused_total - fused_total,
            "pattern_confidence": min(unfused_chains, fused_chains) / max(unfused_chains, fused_chains),
            "evidence": f"{unfused_platform} runs {ktype} in chains of {unfused_avg:.0f} calls ({unfused_chains} chains, {unfused_total:,} total), {fused_platform} fuses to {fused_avg:.0f} calls ({fused_chains} chains, {fused_total:,} total) - {chain_ratio:.1f}x more efficient",
        })

    return mappings
488
+
489
+
490
+ def _find_partial_fusion_via_groups(
491
+ trace1_large: dict[int, list[dict]],
492
+ trace2_large: dict[int, list[dict]],
493
+ matches: list[tuple[int, int]],
494
+ trace1_name: str,
495
+ trace2_name: str,
496
+ ) -> list[dict]:
497
+ """Find partial fusion patterns by analyzing correlation group differences.
498
+
499
+ When one platform has fewer of a kernel type, check what kernel types the
500
+ other platform has MORE of in those same groups - those are likely fusion targets.
501
+ """
502
+ from collections import Counter
503
+ from wafer_core.lib.trace_compare.classifier import classify_kernel
504
+
505
+ mappings = []
506
+
507
+ # For each matched pair, track kernel type counts
508
+ trace1_all_types = []
509
+ trace2_all_types = []
510
+
511
+ for trace1_cid, trace2_cid in matches:
512
+ trace1_ktypes = [_classify_kernel_cached(k.get("name", "")) for k in trace1_large[trace1_cid]]
513
+ trace2_ktypes = [_classify_kernel_cached(k.get("name", "")) for k in trace2_large[trace2_cid]]
514
+ trace1_all_types.extend(trace1_ktypes)
515
+ trace2_all_types.extend(trace2_ktypes)
516
+
517
+ # Find kernel types with significant imbalances
518
+ trace1_counts = Counter(trace1_all_types)
519
+ trace2_counts = Counter(trace2_all_types)
520
+ all_types = set(trace1_counts.keys()) | set(trace2_counts.keys())
521
+
522
+ for ktype in all_types:
523
+ trace1_count = trace1_counts.get(ktype, 0)
524
+ trace2_count = trace2_counts.get(ktype, 0)
525
+
526
+ if trace1_count == 0 or trace2_count == 0:
527
+ continue # Handled by sequence-based detection
528
+
277
529
  ratio = max(trace1_count, trace2_count) / min(trace1_count, trace2_count)
278
530
 
279
- if ratio < 3.0:
280
- continue
531
+ if ratio < 1.3 or trace1_count + trace2_count < 100:
532
+ continue # Not significant
281
533
 
282
- # Determine which platform has more (unfused) and which has fewer (fused)
534
+ # Determine which platform has fewer (fuses more)
283
535
  if trace1_count > trace2_count:
284
536
  unfused_platform = trace1_name
285
537
  fused_platform = trace2_name
286
538
  unfused_count = trace1_count
287
539
  fused_count = trace2_count
540
+
541
+ # Find groups where trace1 has this kernel but trace2 doesn't
542
+ fusion_targets = Counter()
543
+ groups_analyzed = 0
544
+
545
+ for trace1_cid, trace2_cid in matches:
546
+ trace1_ktypes = [_classify_kernel_cached(k.get("name", "")) for k in trace1_large[trace1_cid]]
547
+ trace2_ktypes = [_classify_kernel_cached(k.get("name", "")) for k in trace2_large[trace2_cid]]
548
+
549
+ trace1_has = ktype in trace1_ktypes
550
+ trace2_has = ktype in trace2_ktypes
551
+
552
+ if trace1_has and not trace2_has:
553
+ # What does trace2 have MORE of in this group?
554
+ trace1_kcounts = Counter(trace1_ktypes)
555
+ trace2_kcounts = Counter(trace2_ktypes)
556
+
557
+ for other_type in set(trace2_kcounts.keys()):
558
+ if other_type == ktype or other_type == "Other":
559
+ continue
560
+ diff = trace2_kcounts[other_type] - trace1_kcounts.get(other_type, 0)
561
+ if diff > 0:
562
+ fusion_targets[other_type] += diff
563
+
564
+ groups_analyzed += 1
565
+
566
+ if fusion_targets and groups_analyzed >= 5:
567
+ # Report top fusion targets
568
+ top_targets = fusion_targets.most_common(3)
569
+ target_str = ", ".join(f"{t[0]} (+{t[1]})" for t in top_targets)
570
+
571
+ mappings.append({
572
+ "fused_platform": fused_platform,
573
+ "fused_kernel_type": top_targets[0][0],
574
+ "fused_count": fused_count,
575
+ "unfused_platform": unfused_platform,
576
+ "unfused_sequence": [ktype],
577
+ "unfused_count_per_type": {ktype: unfused_count},
578
+ "pattern_count": unfused_count - fused_count,
579
+ "pattern_confidence": groups_analyzed / len(matches) if matches else 0,
580
+ "evidence": f"{unfused_platform} calls {ktype} {ratio:.1f}x more ({unfused_count} vs {fused_count}). In {groups_analyzed} groups where {unfused_platform} has {ktype}, {fused_platform} has more: {target_str}"
581
+ })
288
582
  else:
583
+ # Symmetric case for trace2 > trace1
289
584
  unfused_platform = trace2_name
290
585
  fused_platform = trace1_name
291
586
  unfused_count = trace2_count
292
587
  fused_count = trace1_count
293
588
 
294
- mappings.append({
295
- "fused_platform": fused_platform,
296
- "fused_kernel_type": ktype,
297
- "fused_count": fused_count,
298
- "unfused_platform": unfused_platform,
299
- "unfused_sequence": [ktype],
300
- "unfused_count_per_type": {ktype: unfused_count},
301
- "pattern_count": unfused_count - fused_count,
302
- "pattern_confidence": (unfused_count - fused_count) / unfused_count,
303
- "evidence": f"{unfused_platform} calls {ktype} {ratio:.1f}x more ({unfused_count} vs {fused_count}) - possible fusion",
304
- })
589
+ fusion_targets = Counter()
590
+ groups_analyzed = 0
305
591
 
306
- return mappings
592
+ for trace1_cid, trace2_cid in matches:
593
+ trace1_ktypes = [_classify_kernel_cached(k.get("name", "")) for k in trace1_large[trace1_cid]]
594
+ trace2_ktypes = [_classify_kernel_cached(k.get("name", "")) for k in trace2_large[trace2_cid]]
307
595
 
596
+ trace1_has = ktype in trace1_ktypes
597
+ trace2_has = ktype in trace2_ktypes
308
598
 
309
- def _find_explicit_fused_operations(
310
- trace1_kernels: list[dict],
311
- trace2_kernels: list[dict],
312
- trace1_name: str = "Trace1",
313
- trace2_name: str = "Trace2",
314
- trace1_platform: str = "AMD",
315
- ) -> list[dict]:
316
- """Find explicit fused operations (kernels with '+' in their classification).
599
+ if trace2_has and not trace1_has:
600
+ trace1_kcounts = Counter(trace1_ktypes)
601
+ trace2_kcounts = Counter(trace2_ktypes)
317
602
 
318
- These are operations like 'RMSNorm+GEMM' that are explicitly classified as fused.
319
- """
320
- mappings = []
321
- trace2_platform = "NVIDIA" if trace1_platform == "AMD" else "AMD"
322
-
323
- def get_fused_ops(kernels: list[dict], platform: str) -> dict[str, list[str]]:
324
- """Get fused operations and their kernel names."""
325
- fused_ops: dict[str, list[str]] = defaultdict(list)
326
- for k in kernels:
327
- name = k.get("name", "")
328
- op, _pattern = classify(name, platform)
329
- if "+" in op.value or op == Op.FUSED_UNKNOWN:
330
- fused_ops[op.value].append(name)
331
- return dict(fused_ops)
332
-
333
- trace1_fused = get_fused_ops(trace1_kernels, trace1_platform)
334
- trace2_fused = get_fused_ops(trace2_kernels, trace2_platform)
335
-
336
- # Find fused operations unique to trace1
337
- for fused_op, kernels in trace1_fused.items():
338
- if fused_op not in trace2_fused:
339
- # Parse components from the fused op name
340
- if "+" in fused_op:
341
- components = [c.strip() for c in fused_op.split("+")]
342
- else:
343
- components = [fused_op]
603
+ for other_type in set(trace1_kcounts.keys()):
604
+ if other_type == ktype or other_type == "Other":
605
+ continue
606
+ diff = trace1_kcounts[other_type] - trace2_kcounts.get(other_type, 0)
607
+ if diff > 0:
608
+ fusion_targets[other_type] += diff
344
609
 
345
- mappings.append({
346
- "fused_platform": trace1_name,
347
- "fused_kernel_type": fused_op,
348
- "fused_count": len(kernels),
349
- "unfused_platform": trace2_name,
350
- "unfused_sequence": components,
351
- "unfused_count_per_type": {c: 0 for c in components}, # Unknown
352
- "pattern_count": len(kernels),
353
- "pattern_confidence": 1.0,
354
- "evidence": f"{trace1_name} fuses {' + '.join(components)} into {fused_op} ({len(kernels)} calls)",
355
- "fused_kernel_names": kernels[:3], # Sample of kernel names
356
- })
610
+ groups_analyzed += 1
357
611
 
358
- # Find fused operations unique to trace2
359
- for fused_op, kernels in trace2_fused.items():
360
- if fused_op not in trace1_fused:
361
- if "+" in fused_op:
362
- components = [c.strip() for c in fused_op.split("+")]
363
- else:
364
- components = [fused_op]
612
+ if fusion_targets and groups_analyzed >= 5:
613
+ top_targets = fusion_targets.most_common(3)
614
+ target_str = ", ".join(f"{t[0]} (+{t[1]})" for t in top_targets)
365
615
 
366
- mappings.append({
367
- "fused_platform": trace2_name,
368
- "fused_kernel_type": fused_op,
369
- "fused_count": len(kernels),
370
- "unfused_platform": trace1_name,
371
- "unfused_sequence": components,
372
- "unfused_count_per_type": {c: 0 for c in components},
373
- "pattern_count": len(kernels),
374
- "pattern_confidence": 1.0,
375
- "evidence": f"{trace2_name} fuses {' + '.join(components)} into {fused_op} ({len(kernels)} calls)",
376
- "fused_kernel_names": kernels[:3],
377
- })
616
+ mappings.append({
617
+ "fused_platform": fused_platform,
618
+ "fused_kernel_type": top_targets[0][0],
619
+ "fused_count": fused_count,
620
+ "unfused_platform": unfused_platform,
621
+ "unfused_sequence": [ktype],
622
+ "unfused_count_per_type": {ktype: unfused_count},
623
+ "pattern_count": unfused_count - fused_count,
624
+ "pattern_confidence": groups_analyzed / len(matches) if matches else 0,
625
+ "evidence": f"{unfused_platform} calls {ktype} {ratio:.1f}x more ({unfused_count} vs {fused_count}). In {groups_analyzed} groups where {unfused_platform} has {ktype}, {fused_platform} has more: {target_str}"
626
+ })
378
627
 
379
628
  return mappings
380
629
 
381
630
 
382
- def detect_fusion_patterns(
383
- amd_kernels: list[dict],
384
- nvidia_kernels: list[dict],
385
- ) -> FusionAnalysis:
386
- """Detect fusion patterns using explicit fused operation detection only.
387
-
388
- This approach only reports high-confidence fusions where kernels are
389
- explicitly classified as fused (e.g., 'RMSNorm+GEMM', 'SwiGLU+GEMM').
390
-
391
- We intentionally avoid speculative detection (sequence patterns, count
392
- imbalances) because these produce too many false positives - count
393
- differences are usually due to different implementations, not fusion.
631
def analyze_fusion_differences(
    amd_trace_path: str | Path,
    nv_trace_path: str | Path,
    min_group_size: int = 50,
) -> dict[str, Any]:
    """Compare two traces and report where one launches fewer kernels than the other.

    The traces are always reported under the generic labels "Trace 1"
    (``amd_trace_path``) and "Trace 2" (``nv_trace_path``); the argument order
    is never swapped, and the platform detected by the loader is not used.

    Args:
        amd_trace_path: Path to the first trace (reported as "Trace 1").
        nv_trace_path: Path to the second trace (reported as "Trace 2").
        min_group_size: Only analyze correlation groups with at least this
            many kernels.

    Returns:
        Dictionary with analysis results containing:
            - metadata: GPU names, kernel totals, number of large correlation
              groups per trace, and number of matched groups
            - global_counts: per kernel type, the count in each trace and the
              trace1/trace2 ratio (``float("inf")`` when trace 2 has none)
            - fusion_opportunities: kernel types whose call counts differ
              significantly, sorted by ``ratio * (trace1_total + trace2_total)``
              in descending order
            - fusion_mappings: mappings from matched correlation groups, the
              partial-fusion analysis, the intra-type analysis, and the
              whole-trace analysis (de-duplicated by evidence string)
    """
    # Keep the caller's order. The detected platform is discarded because the
    # output uses generic labels for UI consistency.
    _, trace1_gpu, trace1_kernels, trace1_corr_groups = _load_trace_for_fusion(
        amd_trace_path
    )
    _, trace2_gpu, trace2_kernels, trace2_corr_groups = _load_trace_for_fusion(
        nv_trace_path
    )
    trace1_platform = "Trace 1"
    trace2_platform = "Trace 2"

    # Restrict the per-group analysis to "significant" correlation groups.
    trace1_large = {
        cid: kernels
        for cid, kernels in trace1_corr_groups.items()
        if len(kernels) >= min_group_size
    }
    trace2_large = {
        cid: kernels
        for cid, kernels in trace2_corr_groups.items()
        if len(kernels) >= min_group_size
    }

    # Pair up correlation groups between the two traces.
    matches = _match_correlation_groups(trace1_large, trace2_large)

    # Per kernel type, totals accumulated over all matched groups.
    fusion_diffs: dict[str, dict[str, Any]] = defaultdict(
        lambda: {
            "trace1_count": 0,
            "trace2_count": 0,
            "trace1_time_us": 0,
            "trace2_time_us": 0,
            "groups_with_diff": 0,
            "total_groups": 0,
        }
    )
    all_fusion_mappings: list[dict[str, Any]] = []

    for trace1_cid, trace2_cid in matches:
        trace1_group = trace1_large[trace1_cid]
        trace2_group = trace2_large[trace2_cid]
        trace1_comp, trace1_times = _analyze_correlation_group(trace1_group)
        trace2_comp, trace2_times = _analyze_correlation_group(trace2_group)

        # Every kernel type seen in this group pair on either side.
        for ktype in set(trace1_comp.keys()) | set(trace2_comp.keys()):
            trace1_count = trace1_comp.get(ktype, 0)
            trace2_count = trace2_comp.get(ktype, 0)

            stats = fusion_diffs[ktype]
            stats["trace1_count"] += trace1_count
            stats["trace2_count"] += trace2_count
            stats["trace1_time_us"] += trace1_times.get(ktype, 0)
            stats["trace2_time_us"] += trace2_times.get(ktype, 0)
            stats["total_groups"] += 1
            if trace1_count != trace2_count:
                stats["groups_with_diff"] += 1

        # Kernel mappings found inside this matched group pair, tagged with
        # the correlation IDs they came from.
        group_mappings = _find_fusion_mappings(
            trace1_group,
            trace2_group,
            trace1_name=trace1_platform,
            trace2_name=trace2_platform,
        )
        for mapping in group_mappings:
            mapping["correlation_group_trace1"] = trace1_cid
            mapping["correlation_group_trace2"] = trace2_cid
        all_fusion_mappings.extend(group_mappings)

    # Classify every kernel exactly once per trace. The grouped durations give
    # both the global counts and (later) the global timings without
    # re-scanning the traces for each kernel type.
    trace1_durations = _group_durations_by_kernel_type(trace1_kernels)
    trace2_durations = _group_durations_by_kernel_type(trace2_kernels)
    global_trace1_counts: Counter[str] = Counter(
        {ktype: len(durs) for ktype, durs in trace1_durations.items()}
    )
    global_trace2_counts: Counter[str] = Counter(
        {ktype: len(durs) for ktype, durs in trace2_durations.items()}
    )

    results: dict[str, Any] = {
        "metadata": {
            "trace1_gpu": trace1_gpu,
            "trace2_gpu": trace2_gpu,
            "trace1_total_kernels": len(trace1_kernels),
            "trace2_total_kernels": len(trace2_kernels),
            "trace1_correlation_groups": len(trace1_large),
            "trace2_correlation_groups": len(trace2_large),
            "matched_groups": len(matches),
        },
        "global_counts": {},
        "fusion_opportunities": [],
        # Same list object as all_fusion_mappings: mappings appended further
        # down are therefore part of the returned result.
        "fusion_mappings": all_fusion_mappings,
    }
    opportunities: list[dict[str, Any]] = results["fusion_opportunities"]

    # Global counts. Every type in all_ktypes occurs at least once in one of
    # the traces, so no zero-count filtering is needed.
    all_ktypes = set(global_trace1_counts.keys()) | set(global_trace2_counts.keys())
    for ktype in all_ktypes:
        trace1_total = global_trace1_counts.get(ktype, 0)
        trace2_total = global_trace2_counts.get(ktype, 0)
        results["global_counts"][ktype] = {
            "trace1_count": trace1_total,
            "trace2_count": trace2_total,
            "ratio": trace1_total / trace2_total if trace2_total > 0 else float("inf"),
        }

    # Pass 1: opportunities derived from matched correlation groups.
    reported_types: set[str] = set()
    for ktype, stats in fusion_diffs.items():
        total_groups = stats["total_groups"]
        trace1_avg = stats["trace1_count"] / total_groups if total_groups > 0 else 0
        trace2_avg = stats["trace2_count"] / total_groups if total_groups > 0 else 0
        trace1_time_ms = stats["trace1_time_us"] / 1000
        trace2_time_ms = stats["trace2_time_us"] / 1000

        diff_ratio = trace1_avg / trace2_avg if trace2_avg > 0 else float("inf")
        reverse_ratio = trace2_avg / trace1_avg if trace1_avg > 0 else float("inf")

        # Report only significant differences:
        #   - more than 10x either way (includes "one side lacks it", which
        #     yields an infinite ratio), or
        #   - more than 2x either way AND more than 20 calls in total, so
        #     trivial counts are ignored.
        is_significant = (diff_ratio > 10.0 or reverse_ratio > 10.0) or (
            (diff_ratio > 2.0 or reverse_ratio > 2.0)
            and stats["trace1_count"] + stats["trace2_count"] > 20
        )
        if not is_significant:
            continue

        fused_by, ratio = _pick_fused_side(diff_ratio, reverse_ratio)
        time_ratio = trace1_time_ms / trace2_time_ms if trace2_time_ms > 0 else float("inf")

        opportunities.append(
            {
                "kernel_type": ktype,
                "trace1_total": stats["trace1_count"],
                "trace2_total": stats["trace2_count"],
                "trace1_avg_per_group": trace1_avg,
                "trace2_avg_per_group": trace2_avg,
                "trace1_time_ms": trace1_time_ms,
                "trace2_time_ms": trace2_time_ms,
                "time_ratio": time_ratio,
                "ratio": ratio,
                "fused_by": fused_by,
                "groups_affected": stats["groups_with_diff"],
                "total_groups": total_groups,
            }
        )
        reported_types.add(ktype)

    # Pass 2: whole-trace counts, for differences the group analysis missed
    # (e.g. kernels in small groups or distributed differently).
    for ktype in all_ktypes:
        if ktype in reported_types:
            continue

        trace1_total = global_trace1_counts.get(ktype, 0)
        trace2_total = global_trace2_counts.get(ktype, 0)

        # Skip trivial counts.
        if trace1_total + trace2_total < 100:
            continue

        global_ratio = trace1_total / trace2_total if trace2_total > 0 else float("inf")
        global_reverse_ratio = trace2_total / trace1_total if trace1_total > 0 else float("inf")

        # NOTE: together with the guard above, a combined count of exactly 100
        # is never reported (strict ">" here); kept as in the original.
        is_globally_significant = (global_ratio > 2.0 or global_reverse_ratio > 2.0) and (
            trace1_total + trace2_total > 100
        )
        if not is_globally_significant:
            continue

        # Timing over all kernels of this type, not just matched groups.
        # Dividing by 1000 converts to ms (assumes "dur" is in microseconds,
        # as the *_time_us / *_time_ms key names suggest).
        trace1_time = sum(trace1_durations.get(ktype, ())) / 1000
        trace2_time = sum(trace2_durations.get(ktype, ())) / 1000

        fused_by, ratio = _pick_fused_side(global_ratio, global_reverse_ratio)
        time_ratio = trace1_time / trace2_time if trace2_time > 0 else float("inf")

        opportunities.append(
            {
                "kernel_type": ktype,
                "trace1_total": trace1_total,
                "trace2_total": trace2_total,
                "trace1_avg_per_group": trace1_total / len(matches) if matches else 0,
                "trace2_avg_per_group": trace2_total / len(matches) if matches else 0,
                "trace1_time_ms": trace1_time,
                "trace2_time_ms": trace2_time,
                "time_ratio": time_ratio,
                "ratio": ratio,
                "fused_by": fused_by,
                "groups_affected": 0,  # Unknown for global analysis
                "total_groups": len(matches),
            }
        )

    # Sort by impact (ratio * total count), largest first.
    opportunities.sort(
        key=lambda x: x["ratio"] * (x["trace1_total"] + x["trace2_total"]), reverse=True
    )

    # Partial fusion: types present on both sides but with different
    # frequencies, found via differential analysis of matched groups.
    partial_mappings = _find_partial_fusion_via_groups(
        trace1_large,
        trace2_large,
        matches,
        trace1_name=trace1_platform,
        trace2_name=trace2_platform,
    )
    all_fusion_mappings.extend(partial_mappings)

    # Intra-type fusion (a kernel type fused with itself). This runs before
    # the whole-trace fallback so its results can suppress the fallback's
    # entries for the same type.
    intra_mappings = _detect_intra_type_fusion(
        trace1_kernels,
        trace2_kernels,
        trace1_name=trace1_platform,
        trace2_name=trace2_platform,
    )
    all_fusion_mappings.extend(intra_mappings)
    intra_handled_types = {m["fused_kernel_type"] for m in intra_mappings}

    # Whole-trace fallback for kernels outside large correlation groups.
    global_mappings = _find_fusion_mappings(
        trace1_kernels,
        trace2_kernels,
        trace1_name=trace1_platform,
        trace2_name=trace2_platform,
    )
    # Keep a fallback mapping only if its type was not handled by the
    # intra-type pass and its evidence string was not already collected.
    existing_evidence = {m["evidence"] for m in all_fusion_mappings}
    for mapping in global_mappings:
        unfused_sequence = mapping["unfused_sequence"]
        ktype = unfused_sequence[0] if unfused_sequence else None
        if ktype not in intra_handled_types and mapping["evidence"] not in existing_evidence:
            all_fusion_mappings.append(mapping)

    return results


def _group_durations_by_kernel_type(
    kernels: list[dict[str, Any]],
) -> dict[str, list[Any]]:
    """Group each kernel's ``dur`` value (0 when absent) by its classified type.

    Types appear in first-occurrence order, and each list keeps the kernels'
    original order, so ``sum()`` over a list equals a filtered scan of the trace.
    """
    durations: dict[str, list[Any]] = defaultdict(list)
    for kernel in kernels:
        ktype = _classify_kernel_cached(kernel.get("name", ""))
        durations[ktype].append(kernel.get("dur", 0))
    return durations


def _pick_fused_side(forward_ratio: float, reverse_ratio: float) -> tuple[str, float]:
    """Return the label of the trace assumed to fuse more, plus the reported ratio.

    ``forward_ratio`` is trace1/trace2 and ``reverse_ratio`` is trace2/trace1.
    The trace with fewer calls is taken to be the one that fuses.
    """
    if forward_ratio > 1.5:
        return "Trace 2", forward_ratio  # Trace 2 has fewer calls
    return "Trace 1", reverse_ratio  # Trace 1 has fewer calls