wafer-core 0.1.27__py3-none-any.whl → 0.1.29__py3-none-any.whl

This diff shows the changes between two publicly released versions of a package, as they appear in one of the supported registries. It is provided for informational purposes only.
@@ -1,25 +1,18 @@
1
- """Fusion analysis using aligned kernel pairs.
1
+ """Fusion analysis for detecting kernel fusion differences between platforms.
2
2
 
3
- Detects kernel fusion patterns at the kernel-pair level within aligned layers.
3
+ Detects fusion differences between AMD and NVIDIA by analyzing:
4
+ 1. Kernel types unique to each platform (one platform fuses them away)
5
+ 2. Sequence patterns (what comes before/after unique kernels)
6
+ 3. Count imbalances (one platform has significantly more calls)
4
7
 
5
- Key insight: When one platform has an operation type like "RMSNorm+GEMM" (containing '+'),
6
- that platform IS fusing. The OTHER platform runs the components separately.
8
+ This is the pattern-based approach, which is more reliable than alignment-based detection.
7
9
  """
8
10
 
9
11
  from collections import Counter, defaultdict
10
12
  from dataclasses import dataclass, field
11
13
  from typing import Any
12
14
 
13
- from .aligner import KernelPair, LayerAlignment
14
-
15
-
16
- # Maps fused operation names to their component operations
17
- FUSED_OPERATION_COMPONENTS: dict[str, list[str]] = {
18
- "RMSNorm+GEMM": ["RMSNorm", "Dense GEMM"],
19
- "SwiGLU+GEMM": ["SwiGLU", "Dense GEMM"],
20
- "MoE GEMM+SwiGLU": ["MoE GEMM", "SwiGLU"],
21
- "Embedding+RMSNorm+GEMM": ["Elementwise", "RMSNorm", "Dense GEMM"], # Embedding is often Elementwise
22
- }
15
+ from .classifier import classify, Op
23
16
 
24
17
 
25
18
  @dataclass
@@ -43,311 +36,467 @@ class FusionAnalysis:
43
36
  summary: dict[str, Any] = field(default_factory=dict)
44
37
 
45
38
 
46
- def _is_fused_operation(operation: str) -> bool:
47
- """Check if an operation type represents a fused operation.
48
-
49
- Operations like 'RMSNorm+GEMM', 'MoE GEMM+SwiGLU' indicate fusion.
50
- Also handles 'Fused (Unknown)' from heuristic detection.
51
- """
52
- return "+" in operation or operation == "Fused (Unknown)"
39
+ def _classify_kernel(name: str, platform: str = "AMD") -> str:
40
+ """Classify a kernel name to its operation type."""
41
+ op, _pattern = classify(name, platform)
42
+ return op.value
53
43
 
54
44
 
55
- def _get_component_operations(fused_op: str) -> list[str]:
56
- """Get the component operations for a fused operation type."""
57
- if fused_op in FUSED_OPERATION_COMPONENTS:
58
- return FUSED_OPERATION_COMPONENTS[fused_op]
59
-
60
- # Parse from the operation name itself (e.g., "RMSNorm+GEMM" -> ["RMSNorm", "GEMM"])
61
- if "+" in fused_op:
62
- parts = [p.strip() for p in fused_op.split("+")]
63
- # Map shorthand to full names
64
- mapping = {
65
- "GEMM": "Dense GEMM",
66
- "SwiGLU": "Triton Fused",
67
- }
68
- return [mapping.get(p, p) for p in parts]
69
-
70
- return []
45
+ def _find_fusion_mappings(
46
+ trace1_kernels: list[dict],
47
+ trace2_kernels: list[dict],
48
+ trace1_name: str = "Trace1",
49
+ trace2_name: str = "Trace2",
50
+ trace1_platform: str = "AMD",
51
+ ) -> list[dict]:
52
+ """Find fusion mappings by analyzing kernel execution sequence patterns.
71
53
 
54
+ This function identifies when one platform runs multiple kernels separately
55
+ while the other platform fuses them into a single kernel.
72
56
 
73
- def _find_component_kernels(
74
- layer_alignment: LayerAlignment,
75
- component_ops: list[str],
76
- platform: str,
77
- ) -> list[str]:
78
- """Find kernels for component operations on the specified platform.
79
-
80
57
  Args:
81
- layer_alignment: Layer alignment data
82
- component_ops: List of operation types to look for (e.g., ["RMSNorm", "Dense GEMM"])
83
- platform: "AMD" or "NVIDIA"
84
-
58
+ trace1_kernels: List of kernel events from first trace
59
+ trace2_kernels: List of kernel events from second trace
60
+ trace1_name: Name of first platform (e.g., "AMD")
61
+ trace2_name: Name of second platform (e.g., "NVIDIA")
62
+ trace1_platform: Platform string for classification ("AMD" or "NVIDIA")
63
+
85
64
  Returns:
86
- List of kernel names found for these operations
65
+ List of mapping dictionaries with fusion details
87
66
  """
88
- found_kernels: list[str] = []
67
+ mappings = []
68
+ trace2_platform = "NVIDIA" if trace1_platform == "AMD" else "AMD"
69
+
70
+ # Sort kernels by timestamp
71
+ trace1_sorted = sorted(trace1_kernels, key=lambda k: k.get("ts", 0))
72
+ trace2_sorted = sorted(trace2_kernels, key=lambda k: k.get("ts", 0))
73
+
74
+ # Classify all kernels
75
+ trace1_types = [_classify_kernel(k.get("name", ""), trace1_platform) for k in trace1_sorted]
76
+ trace2_types = [_classify_kernel(k.get("name", ""), trace2_platform) for k in trace2_sorted]
77
+
78
+ # Find kernel types unique to each trace
79
+ trace1_type_set = set(trace1_types)
80
+ trace2_type_set = set(trace2_types)
81
+
82
+ trace1_only = trace1_type_set - trace2_type_set
83
+ trace2_only = trace2_type_set - trace1_type_set
84
+
85
+ # For each unique type in trace1, check if it's a fused operation
86
+ # If trace1 has a unique kernel type that trace2 doesn't have, trace1 is likely fusing
87
+ for unique_type in trace1_only:
88
+ # Skip "Other" since it's too generic
89
+ if unique_type == "Other":
90
+ continue
91
+
92
+ # If the unique type contains '+', it's explicitly a fused kernel
93
+ # This means trace1 (which has it) is fusing, not trace2
94
+ if "+" in unique_type:
95
+ # Parse components from the fused op name
96
+ components = [c.strip() for c in unique_type.split("+")]
97
+ indices = [i for i, t in enumerate(trace1_types) if t == unique_type]
98
+
99
+ if len(indices) < 5:
100
+ continue
101
+
102
+ mappings.append({
103
+ "fused_platform": trace1_name,
104
+ "fused_kernel_type": unique_type,
105
+ "fused_count": len(indices),
106
+ "unfused_platform": trace2_name,
107
+ "unfused_sequence": components,
108
+ "unfused_count_per_type": {c: trace2_types.count(c) for c in components},
109
+ "pattern_count": len(indices),
110
+ "pattern_confidence": 1.0,
111
+ "evidence": f"{trace1_name} fuses {' + '.join(components)} into single kernel ({len(indices)} calls), {trace2_name} runs separately",
112
+ })
113
+ continue
114
+
115
+ # For non-fused unique types, find all occurrences
116
+ indices = [i for i, t in enumerate(trace1_types) if t == unique_type]
117
+
118
+ if len(indices) < 5: # Need enough samples to be meaningful
119
+ continue
120
+
121
+ # Analyze what comes before each occurrence
122
+ before_types: dict[str, int] = defaultdict(int)
123
+
124
+ for idx in indices:
125
+ if idx > 0:
126
+ before_types[trace1_types[idx - 1]] += 1
127
+
128
+ # Find the most common pattern
129
+ if not before_types:
130
+ continue
131
+ most_common_before = max(before_types.items(), key=lambda x: x[1])
132
+
133
+ # If there's a strong pattern (>80% of occurrences) and trace2 has the preceding type,
134
+ # trace1 runs them separately while trace2 might fuse
135
+ if most_common_before[1] / len(indices) > 0.8:
136
+ fusion_candidate = most_common_before[0]
137
+
138
+ # Verify trace2 also has the preceding type (unique_type is absent from trace2 by construction)
139
+ if fusion_candidate in trace2_type_set:
140
+ trace1_fusion_count = trace1_types.count(fusion_candidate)
141
+ trace2_fusion_count = trace2_types.count(fusion_candidate)
142
+
143
+ mappings.append({
144
+ "fused_platform": trace2_name,
145
+ "fused_kernel_type": fusion_candidate,
146
+ "fused_count": trace2_fusion_count,
147
+ "unfused_platform": trace1_name,
148
+ "unfused_sequence": [fusion_candidate, unique_type],
149
+ "unfused_count_per_type": {
150
+ fusion_candidate: trace1_fusion_count,
151
+ unique_type: len(indices),
152
+ },
153
+ "pattern_count": len(indices),
154
+ "pattern_confidence": most_common_before[1] / len(indices),
155
+ "evidence": f"{trace1_name} runs {fusion_candidate} + {unique_type} separately, {trace2_name} fuses into {fusion_candidate}",
156
+ })
157
+
158
+ # Also check trace2-only types
159
+ for unique_type in trace2_only:
160
+ if unique_type == "Other":
161
+ continue
162
+
163
+ # If the unique type contains '+', it's explicitly a fused kernel
164
+ # This means trace2 (which has it) is fusing, not trace1
165
+ if "+" in unique_type:
166
+ components = [c.strip() for c in unique_type.split("+")]
167
+ indices = [i for i, t in enumerate(trace2_types) if t == unique_type]
168
+
169
+ if len(indices) < 5:
170
+ continue
171
+
172
+ mappings.append({
173
+ "fused_platform": trace2_name,
174
+ "fused_kernel_type": unique_type,
175
+ "fused_count": len(indices),
176
+ "unfused_platform": trace1_name,
177
+ "unfused_sequence": components,
178
+ "unfused_count_per_type": {c: trace1_types.count(c) for c in components},
179
+ "pattern_count": len(indices),
180
+ "pattern_confidence": 1.0,
181
+ "evidence": f"{trace2_name} fuses {' + '.join(components)} into single kernel ({len(indices)} calls), {trace1_name} runs separately",
182
+ })
183
+ continue
184
+
185
+ indices = [i for i, t in enumerate(trace2_types) if t == unique_type]
186
+
187
+ if len(indices) < 5:
188
+ continue
189
+
190
+ before_types = defaultdict(int)
191
+
192
+ for idx in indices:
193
+ if idx > 0:
194
+ before_types[trace2_types[idx - 1]] += 1
195
+
196
+ if not before_types:
197
+ continue
198
+ most_common_before = max(before_types.items(), key=lambda x: x[1])
199
+
200
+ if most_common_before[1] / len(indices) > 0.8:
201
+ fusion_candidate = most_common_before[0]
202
+
203
+ if fusion_candidate in trace1_type_set:
204
+ trace1_fusion_count = trace1_types.count(fusion_candidate)
205
+ trace2_fusion_count = trace2_types.count(fusion_candidate)
206
+
207
+ mappings.append({
208
+ "fused_platform": trace1_name,
209
+ "fused_kernel_type": fusion_candidate,
210
+ "fused_count": trace1_fusion_count,
211
+ "unfused_platform": trace2_name,
212
+ "unfused_sequence": [fusion_candidate, unique_type],
213
+ "unfused_count_per_type": {
214
+ fusion_candidate: trace2_fusion_count,
215
+ unique_type: len(indices),
216
+ },
217
+ "pattern_count": len(indices),
218
+ "pattern_confidence": most_common_before[1] / len(indices),
219
+ "evidence": f"{trace2_name} runs {fusion_candidate} + {unique_type} separately, {trace1_name} fuses into {fusion_candidate}",
220
+ })
221
+
222
+ return mappings
223
+
224
+
225
+ def _find_count_imbalance_fusions(
226
+ trace1_kernels: list[dict],
227
+ trace2_kernels: list[dict],
228
+ trace1_name: str = "Trace1",
229
+ trace2_name: str = "Trace2",
230
+ trace1_platform: str = "AMD",
231
+ ) -> list[dict]:
232
+ """Find fusions by looking for significant count imbalances.
233
+
234
+ When one platform has significantly more kernel calls of a type (at least 3x),
235
+ it MAY suggest the other platform fuses those operations.
89
236
 
90
- for pair in layer_alignment.kernel_pairs:
91
- if pair.operation in component_ops:
92
- if platform == "AMD" and pair.amd_kernel and pair.amd_count > 0:
93
- found_kernels.append(pair.amd_kernel)
94
- elif platform == "NVIDIA" and pair.nvidia_kernel and pair.nvidia_count > 0:
95
- found_kernels.append(pair.nvidia_kernel)
237
+ NOTE: This is speculative - count differences can also indicate:
238
+ - Different algorithmic implementations
239
+ - Different library choices (cuBLAS vs hipBLAS)
240
+ - Different optimization strategies
96
241
 
97
- return found_kernels
242
+ Only very large imbalances (at least 3x) with high counts (at least 200 calls combined across both traces) are flagged.
243
+ """
244
+ mappings = []
245
+ trace2_platform = "NVIDIA" if trace1_platform == "AMD" else "AMD"
246
+
247
+ # Classify all kernels
248
+ trace1_types = [_classify_kernel(k.get("name", ""), trace1_platform) for k in trace1_kernels]
249
+ trace2_types = [_classify_kernel(k.get("name", ""), trace2_platform) for k in trace2_kernels]
250
+
251
+ # Count by type
252
+ trace1_counts = Counter(trace1_types)
253
+ trace2_counts = Counter(trace2_types)
254
+
255
+ # Find common types with significant differences
256
+ common_types = set(trace1_counts.keys()) & set(trace2_counts.keys())
257
+
258
+ # Skip types that are likely just implementation differences, not fusion
259
+ skip_types = {"Reduce", "Copy/Memory", "Sync", "Other", "Elementwise"}
260
+
261
+ for ktype in common_types:
262
+ if ktype in skip_types:
263
+ continue
264
+
265
+ trace1_count = trace1_counts[ktype]
266
+ trace2_count = trace2_counts[ktype]
267
+
268
+ # Skip low counts - need significant samples
269
+ if trace1_count + trace2_count < 200:
270
+ continue
98
271
 
272
+ # Check if there's a very significant imbalance (>3x)
273
+ # Lower ratios are likely implementation differences, not fusion
274
+ if trace1_count == 0 or trace2_count == 0:
275
+ continue
99
276
 
100
- def detect_fusion_in_aligned_layers(
101
- layer_alignments: list[LayerAlignment],
277
+ ratio = max(trace1_count, trace2_count) / min(trace1_count, trace2_count)
278
+
279
+ if ratio < 3.0:
280
+ continue
281
+
282
+ # Determine which platform has more (unfused) and which has fewer (fused)
283
+ if trace1_count > trace2_count:
284
+ unfused_platform = trace1_name
285
+ fused_platform = trace2_name
286
+ unfused_count = trace1_count
287
+ fused_count = trace2_count
288
+ else:
289
+ unfused_platform = trace2_name
290
+ fused_platform = trace1_name
291
+ unfused_count = trace2_count
292
+ fused_count = trace1_count
293
+
294
+ mappings.append({
295
+ "fused_platform": fused_platform,
296
+ "fused_kernel_type": ktype,
297
+ "fused_count": fused_count,
298
+ "unfused_platform": unfused_platform,
299
+ "unfused_sequence": [ktype],
300
+ "unfused_count_per_type": {ktype: unfused_count},
301
+ "pattern_count": unfused_count - fused_count,
302
+ "pattern_confidence": (unfused_count - fused_count) / unfused_count,
303
+ "evidence": f"{unfused_platform} calls {ktype} {ratio:.1f}x more ({unfused_count} vs {fused_count}) - possible fusion",
304
+ })
305
+
306
+ return mappings
307
+
308
+
309
+ def _find_explicit_fused_operations(
310
+ trace1_kernels: list[dict],
311
+ trace2_kernels: list[dict],
312
+ trace1_name: str = "Trace1",
313
+ trace2_name: str = "Trace2",
314
+ trace1_platform: str = "AMD",
315
+ ) -> list[dict]:
316
+ """Find explicit fused operations (kernels whose classification contains '+' or is Op.FUSED_UNKNOWN).
317
+
318
+ These are operations like 'RMSNorm+GEMM' that are explicitly classified as fused.
319
+ """
320
+ mappings = []
321
+ trace2_platform = "NVIDIA" if trace1_platform == "AMD" else "AMD"
322
+
323
+ def get_fused_ops(kernels: list[dict], platform: str) -> dict[str, list[str]]:
324
+ """Get fused operations and their kernel names."""
325
+ fused_ops: dict[str, list[str]] = defaultdict(list)
326
+ for k in kernels:
327
+ name = k.get("name", "")
328
+ op, _pattern = classify(name, platform)
329
+ if "+" in op.value or op == Op.FUSED_UNKNOWN:
330
+ fused_ops[op.value].append(name)
331
+ return dict(fused_ops)
332
+
333
+ trace1_fused = get_fused_ops(trace1_kernels, trace1_platform)
334
+ trace2_fused = get_fused_ops(trace2_kernels, trace2_platform)
335
+
336
+ # Find fused operations unique to trace1
337
+ for fused_op, kernels in trace1_fused.items():
338
+ if fused_op not in trace2_fused:
339
+ # Parse components from the fused op name
340
+ if "+" in fused_op:
341
+ components = [c.strip() for c in fused_op.split("+")]
342
+ else:
343
+ components = [fused_op]
344
+
345
+ mappings.append({
346
+ "fused_platform": trace1_name,
347
+ "fused_kernel_type": fused_op,
348
+ "fused_count": len(kernels),
349
+ "unfused_platform": trace2_name,
350
+ "unfused_sequence": components,
351
+ "unfused_count_per_type": {c: 0 for c in components}, # Unknown
352
+ "pattern_count": len(kernels),
353
+ "pattern_confidence": 1.0,
354
+ "evidence": f"{trace1_name} fuses {' + '.join(components)} into {fused_op} ({len(kernels)} calls)",
355
+ "fused_kernel_names": kernels[:3], # Sample of kernel names
356
+ })
357
+
358
+ # Find fused operations unique to trace2
359
+ for fused_op, kernels in trace2_fused.items():
360
+ if fused_op not in trace1_fused:
361
+ if "+" in fused_op:
362
+ components = [c.strip() for c in fused_op.split("+")]
363
+ else:
364
+ components = [fused_op]
365
+
366
+ mappings.append({
367
+ "fused_platform": trace2_name,
368
+ "fused_kernel_type": fused_op,
369
+ "fused_count": len(kernels),
370
+ "unfused_platform": trace1_name,
371
+ "unfused_sequence": components,
372
+ "unfused_count_per_type": {c: 0 for c in components},
373
+ "pattern_count": len(kernels),
374
+ "pattern_confidence": 1.0,
375
+ "evidence": f"{trace2_name} fuses {' + '.join(components)} into {fused_op} ({len(kernels)} calls)",
376
+ "fused_kernel_names": kernels[:3],
377
+ })
378
+
379
+ return mappings
380
+
381
+
382
+ def detect_fusion_patterns(
383
+ amd_kernels: list[dict],
384
+ nvidia_kernels: list[dict],
102
385
  ) -> FusionAnalysis:
103
- """Detect fusion patterns from aligned layer data.
386
+ """Detect fusion patterns using explicit fused operation detection only.
387
+
388
+ This approach only reports high-confidence fusions where kernels are
389
+ explicitly classified as fused (e.g., 'RMSNorm+GEMM', 'SwiGLU+GEMM').
390
+
391
+ We intentionally avoid speculative detection (sequence patterns, count
392
+ imbalances) because these produce too many false positives - count
393
+ differences are usually due to different implementations, not fusion.
104
394
 
105
395
  Args:
106
- layer_alignments: List of aligned layers
396
+ amd_kernels: List of AMD kernel events
397
+ nvidia_kernels: List of NVIDIA kernel events
107
398
 
108
399
  Returns:
109
400
  FusionAnalysis with detected patterns
110
401
  """
111
- fusion_patterns: list[FusionPattern] = []
402
+ all_mappings: list[dict] = []
403
+
404
+ # Only use explicit fused operations (highest confidence, no false positives)
405
+ # These are kernels explicitly classified with '+' in their operation type (or as Op.FUSED_UNKNOWN)
406
+ explicit_fusions = _find_explicit_fused_operations(
407
+ amd_kernels, nvidia_kernels,
408
+ trace1_name="AMD", trace2_name="NVIDIA",
409
+ trace1_platform="AMD",
410
+ )
411
+ all_mappings.extend(explicit_fusions)
112
412
 
113
- # Track patterns we've already seen to avoid duplicates
114
- seen_patterns: set[tuple[int, str, str]] = set() # (layer, operation, fused_platform)
115
-
116
- for layer_alignment in layer_alignments:
117
- for pair in layer_alignment.kernel_pairs:
118
- is_fused_op = _is_fused_operation(pair.operation)
119
-
120
- # Case 1: AMD has a kernel for this operation, NVIDIA doesn't
121
- if pair.amd_kernel and pair.amd_count > 0 and (pair.nvidia_kernel is None or pair.nvidia_count == 0):
122
- if is_fused_op:
123
- # AMD HAS the fused kernel → AMD is fusing
124
- # Find what NVIDIA runs separately
125
- component_ops = _get_component_operations(pair.operation)
126
- unfused_kernels = _find_component_kernels(layer_alignment, component_ops, "NVIDIA")
127
-
128
- pattern_key = (layer_alignment.layer, pair.operation, "AMD")
129
- if pattern_key not in seen_patterns:
130
- seen_patterns.add(pattern_key)
131
-
132
- if unfused_kernels:
133
- evidence = f"AMD fuses {' + '.join(component_ops)} into {pair.amd_kernel}, NVIDIA runs {len(unfused_kernels)} separate kernels"
134
- else:
135
- evidence = f"AMD fuses into {pair.amd_kernel}, NVIDIA runs components separately"
136
-
137
- fusion_patterns.append(
138
- FusionPattern(
139
- layer=layer_alignment.layer,
140
- operation=pair.operation,
141
- fused_platform="AMD",
142
- fused_kernel=pair.amd_kernel,
143
- unfused_kernels=unfused_kernels if unfused_kernels else component_ops,
144
- count=pair.amd_count,
145
- evidence=evidence,
146
- )
147
- )
148
- else:
149
- # Regular case: AMD has a kernel that NVIDIA fuses into something else
150
- # This means NVIDIA is fusing (it doesn't need this separate kernel)
151
- pattern_key = (layer_alignment.layer, pair.operation, "NVIDIA")
152
- if pattern_key not in seen_patterns:
153
- seen_patterns.add(pattern_key)
154
-
155
- evidence = f"AMD runs {pair.amd_kernel} separately ({pair.amd_count}x), NVIDIA fuses this operation"
156
-
157
- fusion_patterns.append(
158
- FusionPattern(
159
- layer=layer_alignment.layer,
160
- operation=pair.operation,
161
- fused_platform="NVIDIA",
162
- fused_kernel="(fused into nearby kernel)",
163
- unfused_kernels=[pair.amd_kernel],
164
- count=pair.amd_count,
165
- evidence=evidence,
166
- )
167
- )
168
-
169
- # Case 2: NVIDIA has a kernel for this operation, AMD doesn't
170
- elif pair.nvidia_kernel and pair.nvidia_count > 0 and (pair.amd_kernel is None or pair.amd_count == 0):
171
- if is_fused_op:
172
- # NVIDIA HAS the fused kernel → NVIDIA is fusing
173
- # Find what AMD runs separately
174
- component_ops = _get_component_operations(pair.operation)
175
- unfused_kernels = _find_component_kernels(layer_alignment, component_ops, "AMD")
176
-
177
- pattern_key = (layer_alignment.layer, pair.operation, "NVIDIA")
178
- if pattern_key not in seen_patterns:
179
- seen_patterns.add(pattern_key)
180
-
181
- if unfused_kernels:
182
- evidence = f"NVIDIA fuses {' + '.join(component_ops)} into {pair.nvidia_kernel}, AMD runs {len(unfused_kernels)} separate kernels"
183
- else:
184
- evidence = f"NVIDIA fuses into {pair.nvidia_kernel}, AMD runs components separately"
185
-
186
- fusion_patterns.append(
187
- FusionPattern(
188
- layer=layer_alignment.layer,
189
- operation=pair.operation,
190
- fused_platform="NVIDIA",
191
- fused_kernel=pair.nvidia_kernel,
192
- unfused_kernels=unfused_kernels if unfused_kernels else component_ops,
193
- count=pair.nvidia_count,
194
- evidence=evidence,
195
- )
196
- )
197
- else:
198
- # Regular case: NVIDIA has a kernel that AMD fuses into something else
199
- pattern_key = (layer_alignment.layer, pair.operation, "AMD")
200
- if pattern_key not in seen_patterns:
201
- seen_patterns.add(pattern_key)
202
-
203
- evidence = f"NVIDIA runs {pair.nvidia_kernel} separately ({pair.nvidia_count}x), AMD fuses this operation"
204
-
205
- fusion_patterns.append(
206
- FusionPattern(
207
- layer=layer_alignment.layer,
208
- operation=pair.operation,
209
- fused_platform="AMD",
210
- fused_kernel="(fused into nearby kernel)",
211
- unfused_kernels=[pair.nvidia_kernel],
212
- count=pair.nvidia_count,
213
- evidence=evidence,
214
- )
215
- )
216
-
217
- # Case 3: Both have kernels but with very different counts (partial fusion)
218
- elif (
219
- pair.amd_kernel
220
- and pair.nvidia_kernel
221
- and pair.amd_count > 0
222
- and pair.nvidia_count > 0
223
- ):
224
- count_ratio = pair.amd_count / pair.nvidia_count if pair.nvidia_count > 0 else float('inf')
225
-
226
- if count_ratio > 1.5:
227
- # AMD runs more → NVIDIA fuses some instances
228
- pattern_key = (layer_alignment.layer, pair.operation, "NVIDIA")
229
- if pattern_key not in seen_patterns:
230
- seen_patterns.add(pattern_key)
231
-
232
- evidence = (
233
- f"AMD runs {pair.amd_kernel} {count_ratio:.1f}x more "
234
- f"({pair.amd_count} vs {pair.nvidia_count}), NVIDIA partially fuses"
235
- )
236
-
237
- fusion_patterns.append(
238
- FusionPattern(
239
- layer=layer_alignment.layer,
240
- operation=pair.operation,
241
- fused_platform="NVIDIA",
242
- fused_kernel=pair.nvidia_kernel,
243
- unfused_kernels=[pair.amd_kernel],
244
- count=pair.amd_count - pair.nvidia_count,
245
- evidence=evidence,
246
- )
247
- )
248
-
249
- elif count_ratio < 0.67:
250
- # NVIDIA runs more → AMD fuses some instances
251
- inverse_ratio = pair.nvidia_count / pair.amd_count if pair.amd_count > 0 else float('inf')
252
- pattern_key = (layer_alignment.layer, pair.operation, "AMD")
253
- if pattern_key not in seen_patterns:
254
- seen_patterns.add(pattern_key)
255
-
256
- evidence = (
257
- f"NVIDIA runs {pair.nvidia_kernel} {inverse_ratio:.1f}x more "
258
- f"({pair.nvidia_count} vs {pair.amd_count}), AMD partially fuses"
259
- )
260
-
261
- fusion_patterns.append(
262
- FusionPattern(
263
- layer=layer_alignment.layer,
264
- operation=pair.operation,
265
- fused_platform="AMD",
266
- fused_kernel=pair.amd_kernel,
267
- unfused_kernels=[pair.nvidia_kernel],
268
- count=pair.nvidia_count - pair.amd_count,
269
- evidence=evidence,
270
- )
271
- )
272
-
273
- # Aggregate patterns by operation type and fused platform
274
- aggregated: dict[tuple[str, str], FusionPattern] = {}
275
- for pattern in fusion_patterns:
276
- key = (pattern.operation, pattern.fused_platform)
277
- if key in aggregated:
278
- existing = aggregated[key]
279
- existing.count += pattern.count
280
- # Merge unfused kernels (deduplicate)
281
- for k in pattern.unfused_kernels:
282
- if k not in existing.unfused_kernels:
283
- existing.unfused_kernels.append(k)
284
- else:
285
- aggregated[key] = FusionPattern(
286
- layer=pattern.layer,
287
- operation=pattern.operation,
288
- fused_platform=pattern.fused_platform,
289
- fused_kernel=pattern.fused_kernel,
290
- unfused_kernels=list(pattern.unfused_kernels),
291
- count=pattern.count,
292
- evidence=pattern.evidence,
413
+ # NOTE: We intentionally skip sequence-based and count-imbalance detection
414
+ # because they produce false positives. Count differences between platforms
415
+ # are usually due to different library implementations (cuBLAS vs hipBLAS),
416
+ # not actual kernel fusion.
417
+
418
+ # Convert to FusionPattern objects
419
+ patterns: list[FusionPattern] = []
420
+ for mapping in all_mappings:
421
+ fused_kernel_names = mapping.get("fused_kernel_names", [])
422
+ fused_kernel = fused_kernel_names[0] if fused_kernel_names else mapping["fused_kernel_type"]
423
+
424
+ patterns.append(
425
+ FusionPattern(
426
+ layer=0, # Pattern-based analysis doesn't track layers
427
+ operation=mapping["fused_kernel_type"],
428
+ fused_platform=mapping["fused_platform"],
429
+ fused_kernel=fused_kernel,
430
+ unfused_kernels=mapping["unfused_sequence"],
431
+ count=mapping["pattern_count"],
432
+ evidence=mapping["evidence"],
293
433
  )
434
+ )
294
435
 
295
- amd_fuses = sum(1 for p in aggregated.values() if p.fused_platform == "AMD")
296
- nvidia_fuses = sum(1 for p in aggregated.values() if p.fused_platform == "NVIDIA")
436
+ amd_fuses = sum(1 for p in patterns if p.fused_platform == "AMD")
437
+ nvidia_fuses = sum(1 for p in patterns if p.fused_platform == "NVIDIA")
297
438
 
298
439
  return FusionAnalysis(
299
- patterns=list(aggregated.values()),
440
+ patterns=patterns,
300
441
  summary={
301
442
  "amd_fuses": amd_fuses,
302
443
  "nvidia_fuses": nvidia_fuses,
303
- "total_fusion_opportunities": len(aggregated),
444
+ "total_fusion_opportunities": len(patterns),
304
445
  },
305
446
  )
306
447
 
307
448
 
308
449
  def analyze_fusion_from_alignment(
309
- layer_alignments: list[LayerAlignment],
450
+ layer_alignments: list[Any],
451
+ amd_kernels: list[dict] | None = None,
452
+ nvidia_kernels: list[dict] | None = None,
310
453
  ) -> dict[str, Any]:
311
- """Analyze fusion from alignment data (for API compatibility).
454
+ """Analyze fusion from kernel data (for API compatibility).
312
455
 
313
456
  Args:
314
- layer_alignments: List of aligned layers
457
+ layer_alignments: List of aligned layers (unused - kept for API compatibility)
458
+ amd_kernels: Optional list of AMD kernel events for pattern-based analysis
459
+ nvidia_kernels: Optional list of NVIDIA kernel events for pattern-based analysis
315
460
 
316
461
  Returns:
317
- Dictionary with fusion analysis results (compatible with old API)
462
+ Dictionary with fusion analysis results
318
463
  """
319
- fusion_analysis = detect_fusion_in_aligned_layers(layer_alignments)
464
+ # If raw kernels provided, use pattern-based analysis (preferred)
465
+ if amd_kernels is not None and nvidia_kernels is not None:
466
+ fusion_analysis = detect_fusion_patterns(amd_kernels, nvidia_kernels)
467
+ else:
468
+ # Fallback: empty analysis if no kernel data
469
+ fusion_analysis = FusionAnalysis(
470
+ patterns=[],
471
+ summary={"amd_fuses": 0, "nvidia_fuses": 0, "total_fusion_opportunities": 0},
472
+ )
320
473
 
321
474
  fusion_opportunities = []
322
475
  fusion_mappings = []
323
476
 
324
477
  for pattern in fusion_analysis.patterns:
325
478
  unfused_platform = "NVIDIA" if pattern.fused_platform == "AMD" else "AMD"
326
-
327
- fusion_opportunities.append(
328
- {
329
- "kernel_type": pattern.operation,
330
- "layer": pattern.layer,
331
- "fused_by": pattern.fused_platform,
332
- "fused_kernel": pattern.fused_kernel,
333
- "unfused_kernels": pattern.unfused_kernels,
334
- "count": pattern.count,
335
- "evidence": pattern.evidence,
336
- }
337
- )
338
479
 
339
- fusion_mappings.append(
340
- {
341
- "fused_platform": pattern.fused_platform,
342
- "fused_kernel_type": pattern.operation,
343
- "fused_kernel_name": pattern.fused_kernel,
344
- "unfused_platform": unfused_platform,
345
- "unfused_sequence": pattern.unfused_kernels,
346
- "pattern_count": pattern.count,
347
- "evidence": pattern.evidence,
348
- "layer": pattern.layer,
349
- }
350
- )
480
+ fusion_opportunities.append({
481
+ "kernel_type": pattern.operation,
482
+ "layer": pattern.layer,
483
+ "fused_by": pattern.fused_platform,
484
+ "fused_kernel": pattern.fused_kernel,
485
+ "unfused_kernels": pattern.unfused_kernels,
486
+ "count": pattern.count,
487
+ "evidence": pattern.evidence,
488
+ })
489
+
490
+ fusion_mappings.append({
491
+ "fused_platform": pattern.fused_platform,
492
+ "fused_kernel_type": pattern.operation,
493
+ "fused_kernel_name": pattern.fused_kernel,
494
+ "unfused_platform": unfused_platform,
495
+ "unfused_sequence": pattern.unfused_kernels,
496
+ "pattern_count": pattern.count,
497
+ "evidence": pattern.evidence,
498
+ "layer": pattern.layer,
499
+ })
351
500
 
352
501
  return {
353
502
  "fusion_opportunities": fusion_opportunities,