wafer-core 0.1.27__py3-none-any.whl → 0.1.28__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,25 +1,18 @@
1
- """Fusion analysis using aligned kernel pairs.
1
+ """Fusion analysis for detecting kernel fusion differences between platforms.
2
2
 
3
- Detects kernel fusion patterns at the kernel-pair level within aligned layers.
3
+ Detects fusion differences between AMD and NVIDIA by analyzing:
4
+ 1. Kernel types unique to each platform (one platform fuses them away)
5
+ 2. Sequence patterns (what comes before/after unique kernels)
6
+ 3. Count imbalances (one platform has significantly more calls)
4
7
 
5
- Key insight: When one platform has an operation type like "RMSNorm+GEMM" (containing '+'),
6
- that platform IS fusing. The OTHER platform runs the components separately.
8
+ This is the pattern-based approach which is more reliable than alignment-based detection.
7
9
  """
8
10
 
9
11
  from collections import Counter, defaultdict
10
12
  from dataclasses import dataclass, field
11
13
  from typing import Any
12
14
 
13
- from .aligner import KernelPair, LayerAlignment
14
-
15
-
16
- # Maps fused operation names to their component operations
17
- FUSED_OPERATION_COMPONENTS: dict[str, list[str]] = {
18
- "RMSNorm+GEMM": ["RMSNorm", "Dense GEMM"],
19
- "SwiGLU+GEMM": ["SwiGLU", "Dense GEMM"],
20
- "MoE GEMM+SwiGLU": ["MoE GEMM", "SwiGLU"],
21
- "Embedding+RMSNorm+GEMM": ["Elementwise", "RMSNorm", "Dense GEMM"], # Embedding is often Elementwise
22
- }
15
+ from .classifier import classify, Op
23
16
 
24
17
 
25
18
  @dataclass
@@ -43,311 +36,426 @@ class FusionAnalysis:
43
36
  summary: dict[str, Any] = field(default_factory=dict)
44
37
 
45
38
 
46
- def _is_fused_operation(operation: str) -> bool:
47
- """Check if an operation type represents a fused operation.
48
-
49
- Operations like 'RMSNorm+GEMM', 'MoE GEMM+SwiGLU' indicate fusion.
50
- Also handles 'Fused (Unknown)' from heuristic detection.
51
- """
52
- return "+" in operation or operation == "Fused (Unknown)"
53
-
54
-
55
- def _get_component_operations(fused_op: str) -> list[str]:
56
- """Get the component operations for a fused operation type."""
57
- if fused_op in FUSED_OPERATION_COMPONENTS:
58
- return FUSED_OPERATION_COMPONENTS[fused_op]
59
-
60
- # Parse from the operation name itself (e.g., "RMSNorm+GEMM" -> ["RMSNorm", "GEMM"])
61
- if "+" in fused_op:
62
- parts = [p.strip() for p in fused_op.split("+")]
63
- # Map shorthand to full names
64
- mapping = {
65
- "GEMM": "Dense GEMM",
66
- "SwiGLU": "Triton Fused",
67
- }
68
- return [mapping.get(p, p) for p in parts]
69
-
70
- return []
71
-
72
-
73
- def _find_component_kernels(
74
- layer_alignment: LayerAlignment,
75
- component_ops: list[str],
76
- platform: str,
77
- ) -> list[str]:
78
- """Find kernels for component operations on the specified platform.
79
-
39
+ def _classify_kernel(name: str, platform: str = "AMD") -> str:
40
+ """Classify a kernel name to its operation type."""
41
+ op, _pattern = classify(name, platform)
42
+ return op.value
43
+
44
+
45
+ def _find_fusion_mappings(
46
+ trace1_kernels: list[dict],
47
+ trace2_kernels: list[dict],
48
+ trace1_name: str = "Trace1",
49
+ trace2_name: str = "Trace2",
50
+ trace1_platform: str = "AMD",
51
+ ) -> list[dict]:
52
+ """Find fusion mappings by analyzing kernel execution sequence patterns.
53
+
54
+ This function identifies when one platform runs multiple kernels separately
55
+ while the other platform fuses them into a single kernel.
56
+
80
57
  Args:
81
- layer_alignment: Layer alignment data
82
- component_ops: List of operation types to look for (e.g., ["RMSNorm", "Dense GEMM"])
83
- platform: "AMD" or "NVIDIA"
84
-
58
+ trace1_kernels: List of kernel events from first trace
59
+ trace2_kernels: List of kernel events from second trace
60
+ trace1_name: Name of first platform (e.g., "AMD")
61
+ trace2_name: Name of second platform (e.g., "NVIDIA")
62
+ trace1_platform: Platform string for classification ("AMD" or "NVIDIA")
63
+
85
64
  Returns:
86
- List of kernel names found for these operations
65
+ List of mapping dictionaries with fusion details
66
+ """
67
+ mappings = []
68
+ trace2_platform = "NVIDIA" if trace1_platform == "AMD" else "AMD"
69
+
70
+ # Sort kernels by timestamp
71
+ trace1_sorted = sorted(trace1_kernels, key=lambda k: k.get("ts", 0))
72
+ trace2_sorted = sorted(trace2_kernels, key=lambda k: k.get("ts", 0))
73
+
74
+ # Classify all kernels
75
+ trace1_types = [_classify_kernel(k.get("name", ""), trace1_platform) for k in trace1_sorted]
76
+ trace2_types = [_classify_kernel(k.get("name", ""), trace2_platform) for k in trace2_sorted]
77
+
78
+ # Find kernel types unique to each trace
79
+ trace1_type_set = set(trace1_types)
80
+ trace2_type_set = set(trace2_types)
81
+
82
+ trace1_only = trace1_type_set - trace2_type_set
83
+ trace2_only = trace2_type_set - trace1_type_set
84
+
85
+ # For each unique type in trace1, find common sequence patterns
86
+ for unique_type in trace1_only:
87
+ # Skip "Other" since it's too generic
88
+ if unique_type == "Other":
89
+ continue
90
+
91
+ # Find all occurrences of this type
92
+ indices = [i for i, t in enumerate(trace1_types) if t == unique_type]
93
+
94
+ if len(indices) < 5: # Need enough samples to be meaningful
95
+ continue
96
+
97
+ # Analyze what comes before each occurrence
98
+ before_types: dict[str, int] = defaultdict(int)
99
+
100
+ for idx in indices:
101
+ if idx > 0:
102
+ before_types[trace1_types[idx - 1]] += 1
103
+
104
+ # Find the most common pattern
105
+ if not before_types:
106
+ continue
107
+ most_common_before = max(before_types.items(), key=lambda x: x[1])
108
+
109
+ # If there's a strong pattern (>80% of occurrences)
110
+ if most_common_before[1] / len(indices) > 0.8:
111
+ fusion_candidate = most_common_before[0]
112
+
113
+ # Verify trace2 has this type
114
+ if fusion_candidate in trace2_type_set:
115
+ trace1_fusion_count = trace1_types.count(fusion_candidate)
116
+ trace2_fusion_count = trace2_types.count(fusion_candidate)
117
+
118
+ mappings.append({
119
+ "fused_platform": trace2_name,
120
+ "fused_kernel_type": fusion_candidate,
121
+ "fused_count": trace2_fusion_count,
122
+ "unfused_platform": trace1_name,
123
+ "unfused_sequence": [fusion_candidate, unique_type],
124
+ "unfused_count_per_type": {
125
+ fusion_candidate: trace1_fusion_count,
126
+ unique_type: len(indices),
127
+ },
128
+ "pattern_count": len(indices),
129
+ "pattern_confidence": most_common_before[1] / len(indices),
130
+ "evidence": f"{trace1_name} runs {fusion_candidate}+{unique_type} separately, {trace2_name} fuses into {fusion_candidate}",
131
+ })
132
+
133
+ # Also check trace2-only types
134
+ for unique_type in trace2_only:
135
+ if unique_type == "Other":
136
+ continue
137
+
138
+ indices = [i for i, t in enumerate(trace2_types) if t == unique_type]
139
+
140
+ if len(indices) < 5:
141
+ continue
142
+
143
+ before_types = defaultdict(int)
144
+
145
+ for idx in indices:
146
+ if idx > 0:
147
+ before_types[trace2_types[idx - 1]] += 1
148
+
149
+ if not before_types:
150
+ continue
151
+ most_common_before = max(before_types.items(), key=lambda x: x[1])
152
+
153
+ if most_common_before[1] / len(indices) > 0.8:
154
+ fusion_candidate = most_common_before[0]
155
+
156
+ if fusion_candidate in trace1_type_set:
157
+ trace1_fusion_count = trace1_types.count(fusion_candidate)
158
+ trace2_fusion_count = trace2_types.count(fusion_candidate)
159
+
160
+ mappings.append({
161
+ "fused_platform": trace1_name,
162
+ "fused_kernel_type": fusion_candidate,
163
+ "fused_count": trace1_fusion_count,
164
+ "unfused_platform": trace2_name,
165
+ "unfused_sequence": [fusion_candidate, unique_type],
166
+ "unfused_count_per_type": {
167
+ fusion_candidate: trace2_fusion_count,
168
+ unique_type: len(indices),
169
+ },
170
+ "pattern_count": len(indices),
171
+ "pattern_confidence": most_common_before[1] / len(indices),
172
+ "evidence": f"{trace2_name} runs {fusion_candidate}+{unique_type} separately, {trace1_name} fuses into {fusion_candidate}",
173
+ })
174
+
175
+ return mappings
176
+
177
+
178
+ def _find_count_imbalance_fusions(
179
+ trace1_kernels: list[dict],
180
+ trace2_kernels: list[dict],
181
+ trace1_name: str = "Trace1",
182
+ trace2_name: str = "Trace2",
183
+ trace1_platform: str = "AMD",
184
+ ) -> list[dict]:
185
+ """Find fusions by looking for significant count imbalances.
186
+
187
+ When one platform has significantly more kernel calls of a type (>1.5x),
188
+ it suggests the other platform fuses those operations.
189
+ """
190
+ mappings = []
191
+ trace2_platform = "NVIDIA" if trace1_platform == "AMD" else "AMD"
192
+
193
+ # Classify all kernels
194
+ trace1_types = [_classify_kernel(k.get("name", ""), trace1_platform) for k in trace1_kernels]
195
+ trace2_types = [_classify_kernel(k.get("name", ""), trace2_platform) for k in trace2_kernels]
196
+
197
+ # Count by type
198
+ trace1_counts = Counter(trace1_types)
199
+ trace2_counts = Counter(trace2_types)
200
+
201
+ # Find common types with significant differences
202
+ common_types = set(trace1_counts.keys()) & set(trace2_counts.keys())
203
+
204
+ for ktype in common_types:
205
+ if ktype == "Other":
206
+ continue
207
+
208
+ trace1_count = trace1_counts[ktype]
209
+ trace2_count = trace2_counts[ktype]
210
+
211
+ # Skip trivial counts
212
+ if trace1_count + trace2_count < 50:
213
+ continue
214
+
215
+ # Check if there's a significant imbalance (>1.5x)
216
+ if trace1_count == 0 or trace2_count == 0:
217
+ continue
218
+
219
+ ratio = max(trace1_count, trace2_count) / min(trace1_count, trace2_count)
220
+
221
+ if ratio < 1.5:
222
+ continue
223
+
224
+ # Determine which platform has more (unfused) and which has fewer (fused)
225
+ if trace1_count > trace2_count:
226
+ unfused_platform = trace1_name
227
+ fused_platform = trace2_name
228
+ unfused_count = trace1_count
229
+ fused_count = trace2_count
230
+ else:
231
+ unfused_platform = trace2_name
232
+ fused_platform = trace1_name
233
+ unfused_count = trace2_count
234
+ fused_count = trace1_count
235
+
236
+ mappings.append({
237
+ "fused_platform": fused_platform,
238
+ "fused_kernel_type": ktype,
239
+ "fused_count": fused_count,
240
+ "unfused_platform": unfused_platform,
241
+ "unfused_sequence": [ktype],
242
+ "unfused_count_per_type": {ktype: unfused_count},
243
+ "pattern_count": unfused_count - fused_count,
244
+ "pattern_confidence": (unfused_count - fused_count) / unfused_count,
245
+ "evidence": f"{unfused_platform} calls {ktype} {ratio:.1f}x more ({unfused_count} vs {fused_count}), {fused_platform} likely fuses",
246
+ })
247
+
248
+ return mappings
249
+
250
+
251
+ def _find_explicit_fused_operations(
252
+ trace1_kernels: list[dict],
253
+ trace2_kernels: list[dict],
254
+ trace1_name: str = "Trace1",
255
+ trace2_name: str = "Trace2",
256
+ trace1_platform: str = "AMD",
257
+ ) -> list[dict]:
258
+ """Find explicit fused operations (kernels with '+' in their classification).
259
+
260
+ These are operations like 'RMSNorm+GEMM' that are explicitly classified as fused.
87
261
  """
88
- found_kernels: list[str] = []
89
-
90
- for pair in layer_alignment.kernel_pairs:
91
- if pair.operation in component_ops:
92
- if platform == "AMD" and pair.amd_kernel and pair.amd_count > 0:
93
- found_kernels.append(pair.amd_kernel)
94
- elif platform == "NVIDIA" and pair.nvidia_kernel and pair.nvidia_count > 0:
95
- found_kernels.append(pair.nvidia_kernel)
96
-
97
- return found_kernels
98
-
99
-
100
- def detect_fusion_in_aligned_layers(
101
- layer_alignments: list[LayerAlignment],
262
+ mappings = []
263
+ trace2_platform = "NVIDIA" if trace1_platform == "AMD" else "AMD"
264
+
265
+ def get_fused_ops(kernels: list[dict], platform: str) -> dict[str, list[str]]:
266
+ """Get fused operations and their kernel names."""
267
+ fused_ops: dict[str, list[str]] = defaultdict(list)
268
+ for k in kernels:
269
+ name = k.get("name", "")
270
+ op, _pattern = classify(name, platform)
271
+ if "+" in op.value or op == Op.FUSED_UNKNOWN:
272
+ fused_ops[op.value].append(name)
273
+ return dict(fused_ops)
274
+
275
+ trace1_fused = get_fused_ops(trace1_kernels, trace1_platform)
276
+ trace2_fused = get_fused_ops(trace2_kernels, trace2_platform)
277
+
278
+ # Find fused operations unique to trace1
279
+ for fused_op, kernels in trace1_fused.items():
280
+ if fused_op not in trace2_fused:
281
+ # Parse components from the fused op name
282
+ if "+" in fused_op:
283
+ components = [c.strip() for c in fused_op.split("+")]
284
+ else:
285
+ components = [fused_op]
286
+
287
+ mappings.append({
288
+ "fused_platform": trace1_name,
289
+ "fused_kernel_type": fused_op,
290
+ "fused_count": len(kernels),
291
+ "unfused_platform": trace2_name,
292
+ "unfused_sequence": components,
293
+ "unfused_count_per_type": {c: 0 for c in components}, # Unknown
294
+ "pattern_count": len(kernels),
295
+ "pattern_confidence": 1.0,
296
+ "evidence": f"{trace1_name} fuses {' + '.join(components)} into {fused_op} ({len(kernels)} calls)",
297
+ "fused_kernel_names": kernels[:3], # Sample of kernel names
298
+ })
299
+
300
+ # Find fused operations unique to trace2
301
+ for fused_op, kernels in trace2_fused.items():
302
+ if fused_op not in trace1_fused:
303
+ if "+" in fused_op:
304
+ components = [c.strip() for c in fused_op.split("+")]
305
+ else:
306
+ components = [fused_op]
307
+
308
+ mappings.append({
309
+ "fused_platform": trace2_name,
310
+ "fused_kernel_type": fused_op,
311
+ "fused_count": len(kernels),
312
+ "unfused_platform": trace1_name,
313
+ "unfused_sequence": components,
314
+ "unfused_count_per_type": {c: 0 for c in components},
315
+ "pattern_count": len(kernels),
316
+ "pattern_confidence": 1.0,
317
+ "evidence": f"{trace2_name} fuses {' + '.join(components)} into {fused_op} ({len(kernels)} calls)",
318
+ "fused_kernel_names": kernels[:3],
319
+ })
320
+
321
+ return mappings
322
+
323
+
324
+ def detect_fusion_patterns(
325
+ amd_kernels: list[dict],
326
+ nvidia_kernels: list[dict],
102
327
  ) -> FusionAnalysis:
103
- """Detect fusion patterns from aligned layer data.
328
+ """Detect fusion patterns using pattern-based analysis.
329
+
330
+ This is the main entry point for fusion detection. It combines:
331
+ 1. Explicit fused operations (kernels classified with '+' in name)
332
+ 2. Sequence pattern analysis (unique kernel types with consistent patterns)
333
+ 3. Count imbalance analysis (one platform has significantly more calls)
104
334
 
105
335
  Args:
106
- layer_alignments: List of aligned layers
336
+ amd_kernels: List of AMD kernel events
337
+ nvidia_kernels: List of NVIDIA kernel events
107
338
 
108
339
  Returns:
109
340
  FusionAnalysis with detected patterns
110
341
  """
111
- fusion_patterns: list[FusionPattern] = []
112
-
113
- # Track patterns we've already seen to avoid duplicates
114
- seen_patterns: set[tuple[int, str, str]] = set() # (layer, operation, fused_platform)
115
-
116
- for layer_alignment in layer_alignments:
117
- for pair in layer_alignment.kernel_pairs:
118
- is_fused_op = _is_fused_operation(pair.operation)
119
-
120
- # Case 1: AMD has a kernel for this operation, NVIDIA doesn't
121
- if pair.amd_kernel and pair.amd_count > 0 and (pair.nvidia_kernel is None or pair.nvidia_count == 0):
122
- if is_fused_op:
123
- # AMD HAS the fused kernel → AMD is fusing
124
- # Find what NVIDIA runs separately
125
- component_ops = _get_component_operations(pair.operation)
126
- unfused_kernels = _find_component_kernels(layer_alignment, component_ops, "NVIDIA")
127
-
128
- pattern_key = (layer_alignment.layer, pair.operation, "AMD")
129
- if pattern_key not in seen_patterns:
130
- seen_patterns.add(pattern_key)
131
-
132
- if unfused_kernels:
133
- evidence = f"AMD fuses {' + '.join(component_ops)} into {pair.amd_kernel}, NVIDIA runs {len(unfused_kernels)} separate kernels"
134
- else:
135
- evidence = f"AMD fuses into {pair.amd_kernel}, NVIDIA runs components separately"
136
-
137
- fusion_patterns.append(
138
- FusionPattern(
139
- layer=layer_alignment.layer,
140
- operation=pair.operation,
141
- fused_platform="AMD",
142
- fused_kernel=pair.amd_kernel,
143
- unfused_kernels=unfused_kernels if unfused_kernels else component_ops,
144
- count=pair.amd_count,
145
- evidence=evidence,
146
- )
147
- )
148
- else:
149
- # Regular case: AMD has a kernel that NVIDIA fuses into something else
150
- # This means NVIDIA is fusing (it doesn't need this separate kernel)
151
- pattern_key = (layer_alignment.layer, pair.operation, "NVIDIA")
152
- if pattern_key not in seen_patterns:
153
- seen_patterns.add(pattern_key)
154
-
155
- evidence = f"AMD runs {pair.amd_kernel} separately ({pair.amd_count}x), NVIDIA fuses this operation"
156
-
157
- fusion_patterns.append(
158
- FusionPattern(
159
- layer=layer_alignment.layer,
160
- operation=pair.operation,
161
- fused_platform="NVIDIA",
162
- fused_kernel="(fused into nearby kernel)",
163
- unfused_kernels=[pair.amd_kernel],
164
- count=pair.amd_count,
165
- evidence=evidence,
166
- )
167
- )
168
-
169
- # Case 2: NVIDIA has a kernel for this operation, AMD doesn't
170
- elif pair.nvidia_kernel and pair.nvidia_count > 0 and (pair.amd_kernel is None or pair.amd_count == 0):
171
- if is_fused_op:
172
- # NVIDIA HAS the fused kernel → NVIDIA is fusing
173
- # Find what AMD runs separately
174
- component_ops = _get_component_operations(pair.operation)
175
- unfused_kernels = _find_component_kernels(layer_alignment, component_ops, "AMD")
176
-
177
- pattern_key = (layer_alignment.layer, pair.operation, "NVIDIA")
178
- if pattern_key not in seen_patterns:
179
- seen_patterns.add(pattern_key)
180
-
181
- if unfused_kernels:
182
- evidence = f"NVIDIA fuses {' + '.join(component_ops)} into {pair.nvidia_kernel}, AMD runs {len(unfused_kernels)} separate kernels"
183
- else:
184
- evidence = f"NVIDIA fuses into {pair.nvidia_kernel}, AMD runs components separately"
185
-
186
- fusion_patterns.append(
187
- FusionPattern(
188
- layer=layer_alignment.layer,
189
- operation=pair.operation,
190
- fused_platform="NVIDIA",
191
- fused_kernel=pair.nvidia_kernel,
192
- unfused_kernels=unfused_kernels if unfused_kernels else component_ops,
193
- count=pair.nvidia_count,
194
- evidence=evidence,
195
- )
196
- )
197
- else:
198
- # Regular case: NVIDIA has a kernel that AMD fuses into something else
199
- pattern_key = (layer_alignment.layer, pair.operation, "AMD")
200
- if pattern_key not in seen_patterns:
201
- seen_patterns.add(pattern_key)
202
-
203
- evidence = f"NVIDIA runs {pair.nvidia_kernel} separately ({pair.nvidia_count}x), AMD fuses this operation"
204
-
205
- fusion_patterns.append(
206
- FusionPattern(
207
- layer=layer_alignment.layer,
208
- operation=pair.operation,
209
- fused_platform="AMD",
210
- fused_kernel="(fused into nearby kernel)",
211
- unfused_kernels=[pair.nvidia_kernel],
212
- count=pair.nvidia_count,
213
- evidence=evidence,
214
- )
215
- )
216
-
217
- # Case 3: Both have kernels but with very different counts (partial fusion)
218
- elif (
219
- pair.amd_kernel
220
- and pair.nvidia_kernel
221
- and pair.amd_count > 0
222
- and pair.nvidia_count > 0
223
- ):
224
- count_ratio = pair.amd_count / pair.nvidia_count if pair.nvidia_count > 0 else float('inf')
225
-
226
- if count_ratio > 1.5:
227
- # AMD runs more → NVIDIA fuses some instances
228
- pattern_key = (layer_alignment.layer, pair.operation, "NVIDIA")
229
- if pattern_key not in seen_patterns:
230
- seen_patterns.add(pattern_key)
231
-
232
- evidence = (
233
- f"AMD runs {pair.amd_kernel} {count_ratio:.1f}x more "
234
- f"({pair.amd_count} vs {pair.nvidia_count}), NVIDIA partially fuses"
235
- )
236
-
237
- fusion_patterns.append(
238
- FusionPattern(
239
- layer=layer_alignment.layer,
240
- operation=pair.operation,
241
- fused_platform="NVIDIA",
242
- fused_kernel=pair.nvidia_kernel,
243
- unfused_kernels=[pair.amd_kernel],
244
- count=pair.amd_count - pair.nvidia_count,
245
- evidence=evidence,
246
- )
247
- )
248
-
249
- elif count_ratio < 0.67:
250
- # NVIDIA runs more → AMD fuses some instances
251
- inverse_ratio = pair.nvidia_count / pair.amd_count if pair.amd_count > 0 else float('inf')
252
- pattern_key = (layer_alignment.layer, pair.operation, "AMD")
253
- if pattern_key not in seen_patterns:
254
- seen_patterns.add(pattern_key)
255
-
256
- evidence = (
257
- f"NVIDIA runs {pair.nvidia_kernel} {inverse_ratio:.1f}x more "
258
- f"({pair.nvidia_count} vs {pair.amd_count}), AMD partially fuses"
259
- )
260
-
261
- fusion_patterns.append(
262
- FusionPattern(
263
- layer=layer_alignment.layer,
264
- operation=pair.operation,
265
- fused_platform="AMD",
266
- fused_kernel=pair.amd_kernel,
267
- unfused_kernels=[pair.nvidia_kernel],
268
- count=pair.nvidia_count - pair.amd_count,
269
- evidence=evidence,
270
- )
271
- )
272
-
273
- # Aggregate patterns by operation type and fused platform
274
- aggregated: dict[tuple[str, str], FusionPattern] = {}
275
- for pattern in fusion_patterns:
276
- key = (pattern.operation, pattern.fused_platform)
277
- if key in aggregated:
278
- existing = aggregated[key]
279
- existing.count += pattern.count
280
- # Merge unfused kernels (deduplicate)
281
- for k in pattern.unfused_kernels:
282
- if k not in existing.unfused_kernels:
283
- existing.unfused_kernels.append(k)
284
- else:
285
- aggregated[key] = FusionPattern(
286
- layer=pattern.layer,
287
- operation=pattern.operation,
288
- fused_platform=pattern.fused_platform,
289
- fused_kernel=pattern.fused_kernel,
290
- unfused_kernels=list(pattern.unfused_kernels),
291
- count=pattern.count,
292
- evidence=pattern.evidence,
342
+ all_mappings: list[dict] = []
343
+
344
+ # 1. Find explicit fused operations (highest confidence)
345
+ explicit_fusions = _find_explicit_fused_operations(
346
+ amd_kernels, nvidia_kernels,
347
+ trace1_name="AMD", trace2_name="NVIDIA",
348
+ trace1_platform="AMD",
349
+ )
350
+ all_mappings.extend(explicit_fusions)
351
+
352
+ # 2. Find sequence-based fusions
353
+ sequence_fusions = _find_fusion_mappings(
354
+ amd_kernels, nvidia_kernels,
355
+ trace1_name="AMD", trace2_name="NVIDIA",
356
+ trace1_platform="AMD",
357
+ )
358
+ # Deduplicate: skip if same fused_kernel_type already found
359
+ existing_types = {m["fused_kernel_type"] for m in all_mappings}
360
+ for fusion in sequence_fusions:
361
+ if fusion["fused_kernel_type"] not in existing_types:
362
+ all_mappings.append(fusion)
363
+ existing_types.add(fusion["fused_kernel_type"])
364
+
365
+ # 3. Find count-imbalance fusions
366
+ count_fusions = _find_count_imbalance_fusions(
367
+ amd_kernels, nvidia_kernels,
368
+ trace1_name="AMD", trace2_name="NVIDIA",
369
+ trace1_platform="AMD",
370
+ )
371
+ # Deduplicate
372
+ for fusion in count_fusions:
373
+ if fusion["fused_kernel_type"] not in existing_types:
374
+ all_mappings.append(fusion)
375
+ existing_types.add(fusion["fused_kernel_type"])
376
+
377
+ # Convert to FusionPattern objects
378
+ patterns: list[FusionPattern] = []
379
+ for mapping in all_mappings:
380
+ fused_kernel_names = mapping.get("fused_kernel_names", [])
381
+ fused_kernel = fused_kernel_names[0] if fused_kernel_names else mapping["fused_kernel_type"]
382
+
383
+ patterns.append(
384
+ FusionPattern(
385
+ layer=0, # Pattern-based analysis doesn't track layers
386
+ operation=mapping["fused_kernel_type"],
387
+ fused_platform=mapping["fused_platform"],
388
+ fused_kernel=fused_kernel,
389
+ unfused_kernels=mapping["unfused_sequence"],
390
+ count=mapping["pattern_count"],
391
+ evidence=mapping["evidence"],
293
392
  )
393
+ )
294
394
 
295
- amd_fuses = sum(1 for p in aggregated.values() if p.fused_platform == "AMD")
296
- nvidia_fuses = sum(1 for p in aggregated.values() if p.fused_platform == "NVIDIA")
395
+ amd_fuses = sum(1 for p in patterns if p.fused_platform == "AMD")
396
+ nvidia_fuses = sum(1 for p in patterns if p.fused_platform == "NVIDIA")
297
397
 
298
398
  return FusionAnalysis(
299
- patterns=list(aggregated.values()),
399
+ patterns=patterns,
300
400
  summary={
301
401
  "amd_fuses": amd_fuses,
302
402
  "nvidia_fuses": nvidia_fuses,
303
- "total_fusion_opportunities": len(aggregated),
403
+ "total_fusion_opportunities": len(patterns),
304
404
  },
305
405
  )
306
406
 
307
407
 
308
408
  def analyze_fusion_from_alignment(
309
- layer_alignments: list[LayerAlignment],
409
+ layer_alignments: list[Any],
410
+ amd_kernels: list[dict] | None = None,
411
+ nvidia_kernels: list[dict] | None = None,
310
412
  ) -> dict[str, Any]:
311
- """Analyze fusion from alignment data (for API compatibility).
413
+ """Analyze fusion from kernel data (for API compatibility).
312
414
 
313
415
  Args:
314
- layer_alignments: List of aligned layers
416
+ layer_alignments: List of aligned layers (unused - kept for API compatibility)
417
+ amd_kernels: Optional list of AMD kernel events for pattern-based analysis
418
+ nvidia_kernels: Optional list of NVIDIA kernel events for pattern-based analysis
315
419
 
316
420
  Returns:
317
- Dictionary with fusion analysis results (compatible with old API)
421
+ Dictionary with fusion analysis results
318
422
  """
319
- fusion_analysis = detect_fusion_in_aligned_layers(layer_alignments)
423
+ # If raw kernels provided, use pattern-based analysis (preferred)
424
+ if amd_kernels is not None and nvidia_kernels is not None:
425
+ fusion_analysis = detect_fusion_patterns(amd_kernels, nvidia_kernels)
426
+ else:
427
+ # Fallback: empty analysis if no kernel data
428
+ fusion_analysis = FusionAnalysis(
429
+ patterns=[],
430
+ summary={"amd_fuses": 0, "nvidia_fuses": 0, "total_fusion_opportunities": 0},
431
+ )
320
432
 
321
433
  fusion_opportunities = []
322
434
  fusion_mappings = []
323
435
 
324
436
  for pattern in fusion_analysis.patterns:
325
437
  unfused_platform = "NVIDIA" if pattern.fused_platform == "AMD" else "AMD"
326
-
327
- fusion_opportunities.append(
328
- {
329
- "kernel_type": pattern.operation,
330
- "layer": pattern.layer,
331
- "fused_by": pattern.fused_platform,
332
- "fused_kernel": pattern.fused_kernel,
333
- "unfused_kernels": pattern.unfused_kernels,
334
- "count": pattern.count,
335
- "evidence": pattern.evidence,
336
- }
337
- )
338
438
 
339
- fusion_mappings.append(
340
- {
341
- "fused_platform": pattern.fused_platform,
342
- "fused_kernel_type": pattern.operation,
343
- "fused_kernel_name": pattern.fused_kernel,
344
- "unfused_platform": unfused_platform,
345
- "unfused_sequence": pattern.unfused_kernels,
346
- "pattern_count": pattern.count,
347
- "evidence": pattern.evidence,
348
- "layer": pattern.layer,
349
- }
350
- )
439
+ fusion_opportunities.append({
440
+ "kernel_type": pattern.operation,
441
+ "layer": pattern.layer,
442
+ "fused_by": pattern.fused_platform,
443
+ "fused_kernel": pattern.fused_kernel,
444
+ "unfused_kernels": pattern.unfused_kernels,
445
+ "count": pattern.count,
446
+ "evidence": pattern.evidence,
447
+ })
448
+
449
+ fusion_mappings.append({
450
+ "fused_platform": pattern.fused_platform,
451
+ "fused_kernel_type": pattern.operation,
452
+ "fused_kernel_name": pattern.fused_kernel,
453
+ "unfused_platform": unfused_platform,
454
+ "unfused_sequence": pattern.unfused_kernels,
455
+ "pattern_count": pattern.count,
456
+ "evidence": pattern.evidence,
457
+ "layer": pattern.layer,
458
+ })
351
459
 
352
460
  return {
353
461
  "fusion_opportunities": fusion_opportunities,