wafer-core 0.1.27__py3-none-any.whl → 0.1.29__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer_core/lib/trace_compare/aligner.py +13 -6
- wafer_core/lib/trace_compare/analyzer.py +12 -3
- wafer_core/lib/trace_compare/classifier.py +18 -9
- wafer_core/lib/trace_compare/fusion_analyzer.py +424 -275
- wafer_core/targets/__init__.py +47 -21
- wafer_core/targets/pool.py +181 -0
- wafer_core/targets/probe.py +113 -0
- wafer_core/targets/providers/__init__.py +46 -0
- wafer_core/targets/providers/baremetal.py +72 -0
- wafer_core/targets/providers/digitalocean.py +164 -0
- wafer_core/targets/providers/runpod.py +250 -0
- wafer_core/targets/reconcile.py +90 -0
- wafer_core/targets/spec_store.py +200 -0
- wafer_core/targets/state_cache.py +150 -0
- wafer_core/targets/types.py +141 -0
- wafer_core/utils/kernel_utils/targets/config.py +8 -24
- {wafer_core-0.1.27.dist-info → wafer_core-0.1.29.dist-info}/METADATA +1 -1
- {wafer_core-0.1.27.dist-info → wafer_core-0.1.29.dist-info}/RECORD +19 -9
- {wafer_core-0.1.27.dist-info → wafer_core-0.1.29.dist-info}/WHEEL +0 -0
|
@@ -1,25 +1,18 @@
|
|
|
1
|
-
"""Fusion analysis
|
|
1
|
+
"""Fusion analysis for detecting kernel fusion differences between platforms.
|
|
2
2
|
|
|
3
|
-
Detects
|
|
3
|
+
Detects fusion differences between AMD and NVIDIA by analyzing:
|
|
4
|
+
1. Kernel types unique to each platform (one platform fuses them away)
|
|
5
|
+
2. Sequence patterns (what comes before/after unique kernels)
|
|
6
|
+
3. Count imbalances (one platform has significantly more calls)
|
|
4
7
|
|
|
5
|
-
|
|
6
|
-
that platform IS fusing. The OTHER platform runs the components separately.
|
|
8
|
+
This is the pattern-based approach which is more reliable than alignment-based detection.
|
|
7
9
|
"""
|
|
8
10
|
|
|
9
11
|
from collections import Counter, defaultdict
|
|
10
12
|
from dataclasses import dataclass, field
|
|
11
13
|
from typing import Any
|
|
12
14
|
|
|
13
|
-
from .
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
# Maps fused operation names to their component operations
|
|
17
|
-
FUSED_OPERATION_COMPONENTS: dict[str, list[str]] = {
|
|
18
|
-
"RMSNorm+GEMM": ["RMSNorm", "Dense GEMM"],
|
|
19
|
-
"SwiGLU+GEMM": ["SwiGLU", "Dense GEMM"],
|
|
20
|
-
"MoE GEMM+SwiGLU": ["MoE GEMM", "SwiGLU"],
|
|
21
|
-
"Embedding+RMSNorm+GEMM": ["Elementwise", "RMSNorm", "Dense GEMM"], # Embedding is often Elementwise
|
|
22
|
-
}
|
|
15
|
+
from .classifier import classify, Op
|
|
23
16
|
|
|
24
17
|
|
|
25
18
|
@dataclass
|
|
@@ -43,311 +36,467 @@ class FusionAnalysis:
|
|
|
43
36
|
summary: dict[str, Any] = field(default_factory=dict)
|
|
44
37
|
|
|
45
38
|
|
|
46
|
-
def
|
|
47
|
-
"""
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
Also handles 'Fused (Unknown)' from heuristic detection.
|
|
51
|
-
"""
|
|
52
|
-
return "+" in operation or operation == "Fused (Unknown)"
|
|
39
|
+
def _classify_kernel(name: str, platform: str = "AMD") -> str:
|
|
40
|
+
"""Classify a kernel name to its operation type."""
|
|
41
|
+
op, _pattern = classify(name, platform)
|
|
42
|
+
return op.value
|
|
53
43
|
|
|
54
44
|
|
|
55
|
-
def
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
# Map shorthand to full names
|
|
64
|
-
mapping = {
|
|
65
|
-
"GEMM": "Dense GEMM",
|
|
66
|
-
"SwiGLU": "Triton Fused",
|
|
67
|
-
}
|
|
68
|
-
return [mapping.get(p, p) for p in parts]
|
|
69
|
-
|
|
70
|
-
return []
|
|
45
|
+
def _find_fusion_mappings(
|
|
46
|
+
trace1_kernels: list[dict],
|
|
47
|
+
trace2_kernels: list[dict],
|
|
48
|
+
trace1_name: str = "Trace1",
|
|
49
|
+
trace2_name: str = "Trace2",
|
|
50
|
+
trace1_platform: str = "AMD",
|
|
51
|
+
) -> list[dict]:
|
|
52
|
+
"""Find fusion mappings by analyzing kernel execution sequence patterns.
|
|
71
53
|
|
|
54
|
+
This function identifies when one platform runs multiple kernels separately
|
|
55
|
+
while the other platform fuses them into a single kernel.
|
|
72
56
|
|
|
73
|
-
def _find_component_kernels(
|
|
74
|
-
layer_alignment: LayerAlignment,
|
|
75
|
-
component_ops: list[str],
|
|
76
|
-
platform: str,
|
|
77
|
-
) -> list[str]:
|
|
78
|
-
"""Find kernels for component operations on the specified platform.
|
|
79
|
-
|
|
80
57
|
Args:
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
58
|
+
trace1_kernels: List of kernel events from first trace
|
|
59
|
+
trace2_kernels: List of kernel events from second trace
|
|
60
|
+
trace1_name: Name of first platform (e.g., "AMD")
|
|
61
|
+
trace2_name: Name of second platform (e.g., "NVIDIA")
|
|
62
|
+
trace1_platform: Platform string for classification ("AMD" or "NVIDIA")
|
|
63
|
+
|
|
85
64
|
Returns:
|
|
86
|
-
List of
|
|
65
|
+
List of mapping dictionaries with fusion details
|
|
87
66
|
"""
|
|
88
|
-
|
|
67
|
+
mappings = []
|
|
68
|
+
trace2_platform = "NVIDIA" if trace1_platform == "AMD" else "AMD"
|
|
69
|
+
|
|
70
|
+
# Sort kernels by timestamp
|
|
71
|
+
trace1_sorted = sorted(trace1_kernels, key=lambda k: k.get("ts", 0))
|
|
72
|
+
trace2_sorted = sorted(trace2_kernels, key=lambda k: k.get("ts", 0))
|
|
73
|
+
|
|
74
|
+
# Classify all kernels
|
|
75
|
+
trace1_types = [_classify_kernel(k.get("name", ""), trace1_platform) for k in trace1_sorted]
|
|
76
|
+
trace2_types = [_classify_kernel(k.get("name", ""), trace2_platform) for k in trace2_sorted]
|
|
77
|
+
|
|
78
|
+
# Find kernel types unique to each trace
|
|
79
|
+
trace1_type_set = set(trace1_types)
|
|
80
|
+
trace2_type_set = set(trace2_types)
|
|
81
|
+
|
|
82
|
+
trace1_only = trace1_type_set - trace2_type_set
|
|
83
|
+
trace2_only = trace2_type_set - trace1_type_set
|
|
84
|
+
|
|
85
|
+
# For each unique type in trace1, check if it's a fused operation
|
|
86
|
+
# If trace1 has a unique kernel type that trace2 doesn't have, trace1 is likely fusing
|
|
87
|
+
for unique_type in trace1_only:
|
|
88
|
+
# Skip "Other" since it's too generic
|
|
89
|
+
if unique_type == "Other":
|
|
90
|
+
continue
|
|
91
|
+
|
|
92
|
+
# If the unique type contains '+', it's explicitly a fused kernel
|
|
93
|
+
# This means trace1 (which has it) is fusing, not trace2
|
|
94
|
+
if "+" in unique_type:
|
|
95
|
+
# Parse components from the fused op name
|
|
96
|
+
components = [c.strip() for c in unique_type.split("+")]
|
|
97
|
+
indices = [i for i, t in enumerate(trace1_types) if t == unique_type]
|
|
98
|
+
|
|
99
|
+
if len(indices) < 5:
|
|
100
|
+
continue
|
|
101
|
+
|
|
102
|
+
mappings.append({
|
|
103
|
+
"fused_platform": trace1_name,
|
|
104
|
+
"fused_kernel_type": unique_type,
|
|
105
|
+
"fused_count": len(indices),
|
|
106
|
+
"unfused_platform": trace2_name,
|
|
107
|
+
"unfused_sequence": components,
|
|
108
|
+
"unfused_count_per_type": {c: trace2_types.count(c) for c in components},
|
|
109
|
+
"pattern_count": len(indices),
|
|
110
|
+
"pattern_confidence": 1.0,
|
|
111
|
+
"evidence": f"{trace1_name} fuses {' + '.join(components)} into single kernel ({len(indices)} calls), {trace2_name} runs separately",
|
|
112
|
+
})
|
|
113
|
+
continue
|
|
114
|
+
|
|
115
|
+
# For non-fused unique types, find all occurrences
|
|
116
|
+
indices = [i for i, t in enumerate(trace1_types) if t == unique_type]
|
|
117
|
+
|
|
118
|
+
if len(indices) < 5: # Need enough samples to be meaningful
|
|
119
|
+
continue
|
|
120
|
+
|
|
121
|
+
# Analyze what comes before each occurrence
|
|
122
|
+
before_types: dict[str, int] = defaultdict(int)
|
|
123
|
+
|
|
124
|
+
for idx in indices:
|
|
125
|
+
if idx > 0:
|
|
126
|
+
before_types[trace1_types[idx - 1]] += 1
|
|
127
|
+
|
|
128
|
+
# Find the most common pattern
|
|
129
|
+
if not before_types:
|
|
130
|
+
continue
|
|
131
|
+
most_common_before = max(before_types.items(), key=lambda x: x[1])
|
|
132
|
+
|
|
133
|
+
# If there's a strong pattern (>80% of occurrences) and trace2 has the preceding type,
|
|
134
|
+
# trace1 runs them separately while trace2 might fuse
|
|
135
|
+
if most_common_before[1] / len(indices) > 0.8:
|
|
136
|
+
fusion_candidate = most_common_before[0]
|
|
137
|
+
|
|
138
|
+
# Verify trace2 has this type but NOT the unique_type
|
|
139
|
+
if fusion_candidate in trace2_type_set:
|
|
140
|
+
trace1_fusion_count = trace1_types.count(fusion_candidate)
|
|
141
|
+
trace2_fusion_count = trace2_types.count(fusion_candidate)
|
|
142
|
+
|
|
143
|
+
mappings.append({
|
|
144
|
+
"fused_platform": trace2_name,
|
|
145
|
+
"fused_kernel_type": fusion_candidate,
|
|
146
|
+
"fused_count": trace2_fusion_count,
|
|
147
|
+
"unfused_platform": trace1_name,
|
|
148
|
+
"unfused_sequence": [fusion_candidate, unique_type],
|
|
149
|
+
"unfused_count_per_type": {
|
|
150
|
+
fusion_candidate: trace1_fusion_count,
|
|
151
|
+
unique_type: len(indices),
|
|
152
|
+
},
|
|
153
|
+
"pattern_count": len(indices),
|
|
154
|
+
"pattern_confidence": most_common_before[1] / len(indices),
|
|
155
|
+
"evidence": f"{trace1_name} runs {fusion_candidate} + {unique_type} separately, {trace2_name} fuses into {fusion_candidate}",
|
|
156
|
+
})
|
|
157
|
+
|
|
158
|
+
# Also check trace2-only types
|
|
159
|
+
for unique_type in trace2_only:
|
|
160
|
+
if unique_type == "Other":
|
|
161
|
+
continue
|
|
162
|
+
|
|
163
|
+
# If the unique type contains '+', it's explicitly a fused kernel
|
|
164
|
+
# This means trace2 (which has it) is fusing, not trace1
|
|
165
|
+
if "+" in unique_type:
|
|
166
|
+
components = [c.strip() for c in unique_type.split("+")]
|
|
167
|
+
indices = [i for i, t in enumerate(trace2_types) if t == unique_type]
|
|
168
|
+
|
|
169
|
+
if len(indices) < 5:
|
|
170
|
+
continue
|
|
171
|
+
|
|
172
|
+
mappings.append({
|
|
173
|
+
"fused_platform": trace2_name,
|
|
174
|
+
"fused_kernel_type": unique_type,
|
|
175
|
+
"fused_count": len(indices),
|
|
176
|
+
"unfused_platform": trace1_name,
|
|
177
|
+
"unfused_sequence": components,
|
|
178
|
+
"unfused_count_per_type": {c: trace1_types.count(c) for c in components},
|
|
179
|
+
"pattern_count": len(indices),
|
|
180
|
+
"pattern_confidence": 1.0,
|
|
181
|
+
"evidence": f"{trace2_name} fuses {' + '.join(components)} into single kernel ({len(indices)} calls), {trace1_name} runs separately",
|
|
182
|
+
})
|
|
183
|
+
continue
|
|
184
|
+
|
|
185
|
+
indices = [i for i, t in enumerate(trace2_types) if t == unique_type]
|
|
186
|
+
|
|
187
|
+
if len(indices) < 5:
|
|
188
|
+
continue
|
|
189
|
+
|
|
190
|
+
before_types = defaultdict(int)
|
|
191
|
+
|
|
192
|
+
for idx in indices:
|
|
193
|
+
if idx > 0:
|
|
194
|
+
before_types[trace2_types[idx - 1]] += 1
|
|
195
|
+
|
|
196
|
+
if not before_types:
|
|
197
|
+
continue
|
|
198
|
+
most_common_before = max(before_types.items(), key=lambda x: x[1])
|
|
199
|
+
|
|
200
|
+
if most_common_before[1] / len(indices) > 0.8:
|
|
201
|
+
fusion_candidate = most_common_before[0]
|
|
202
|
+
|
|
203
|
+
if fusion_candidate in trace1_type_set:
|
|
204
|
+
trace1_fusion_count = trace1_types.count(fusion_candidate)
|
|
205
|
+
trace2_fusion_count = trace2_types.count(fusion_candidate)
|
|
206
|
+
|
|
207
|
+
mappings.append({
|
|
208
|
+
"fused_platform": trace1_name,
|
|
209
|
+
"fused_kernel_type": fusion_candidate,
|
|
210
|
+
"fused_count": trace1_fusion_count,
|
|
211
|
+
"unfused_platform": trace2_name,
|
|
212
|
+
"unfused_sequence": [fusion_candidate, unique_type],
|
|
213
|
+
"unfused_count_per_type": {
|
|
214
|
+
fusion_candidate: trace2_fusion_count,
|
|
215
|
+
unique_type: len(indices),
|
|
216
|
+
},
|
|
217
|
+
"pattern_count": len(indices),
|
|
218
|
+
"pattern_confidence": most_common_before[1] / len(indices),
|
|
219
|
+
"evidence": f"{trace2_name} runs {fusion_candidate} + {unique_type} separately, {trace1_name} fuses into {fusion_candidate}",
|
|
220
|
+
})
|
|
221
|
+
|
|
222
|
+
return mappings
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
def _find_count_imbalance_fusions(
|
|
226
|
+
trace1_kernels: list[dict],
|
|
227
|
+
trace2_kernels: list[dict],
|
|
228
|
+
trace1_name: str = "Trace1",
|
|
229
|
+
trace2_name: str = "Trace2",
|
|
230
|
+
trace1_platform: str = "AMD",
|
|
231
|
+
) -> list[dict]:
|
|
232
|
+
"""Find fusions by looking for significant count imbalances.
|
|
233
|
+
|
|
234
|
+
When one platform has significantly more kernel calls of a type (>3x),
|
|
235
|
+
it MAY suggest the other platform fuses those operations.
|
|
89
236
|
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
elif platform == "NVIDIA" and pair.nvidia_kernel and pair.nvidia_count > 0:
|
|
95
|
-
found_kernels.append(pair.nvidia_kernel)
|
|
237
|
+
NOTE: This is speculative - count differences can also indicate:
|
|
238
|
+
- Different algorithmic implementations
|
|
239
|
+
- Different library choices (cuBLAS vs hipBLAS)
|
|
240
|
+
- Different optimization strategies
|
|
96
241
|
|
|
97
|
-
|
|
242
|
+
Only very large imbalances (>3x) with high counts are flagged.
|
|
243
|
+
"""
|
|
244
|
+
mappings = []
|
|
245
|
+
trace2_platform = "NVIDIA" if trace1_platform == "AMD" else "AMD"
|
|
246
|
+
|
|
247
|
+
# Classify all kernels
|
|
248
|
+
trace1_types = [_classify_kernel(k.get("name", ""), trace1_platform) for k in trace1_kernels]
|
|
249
|
+
trace2_types = [_classify_kernel(k.get("name", ""), trace2_platform) for k in trace2_kernels]
|
|
250
|
+
|
|
251
|
+
# Count by type
|
|
252
|
+
trace1_counts = Counter(trace1_types)
|
|
253
|
+
trace2_counts = Counter(trace2_types)
|
|
254
|
+
|
|
255
|
+
# Find common types with significant differences
|
|
256
|
+
common_types = set(trace1_counts.keys()) & set(trace2_counts.keys())
|
|
257
|
+
|
|
258
|
+
# Skip types that are likely just implementation differences, not fusion
|
|
259
|
+
skip_types = {"Reduce", "Copy/Memory", "Sync", "Other", "Elementwise"}
|
|
260
|
+
|
|
261
|
+
for ktype in common_types:
|
|
262
|
+
if ktype in skip_types:
|
|
263
|
+
continue
|
|
264
|
+
|
|
265
|
+
trace1_count = trace1_counts[ktype]
|
|
266
|
+
trace2_count = trace2_counts[ktype]
|
|
267
|
+
|
|
268
|
+
# Skip low counts - need significant samples
|
|
269
|
+
if trace1_count + trace2_count < 200:
|
|
270
|
+
continue
|
|
98
271
|
|
|
272
|
+
# Check if there's a very significant imbalance (>3x)
|
|
273
|
+
# Lower ratios are likely implementation differences, not fusion
|
|
274
|
+
if trace1_count == 0 or trace2_count == 0:
|
|
275
|
+
continue
|
|
99
276
|
|
|
100
|
-
|
|
101
|
-
|
|
277
|
+
ratio = max(trace1_count, trace2_count) / min(trace1_count, trace2_count)
|
|
278
|
+
|
|
279
|
+
if ratio < 3.0:
|
|
280
|
+
continue
|
|
281
|
+
|
|
282
|
+
# Determine which platform has more (unfused) and which has fewer (fused)
|
|
283
|
+
if trace1_count > trace2_count:
|
|
284
|
+
unfused_platform = trace1_name
|
|
285
|
+
fused_platform = trace2_name
|
|
286
|
+
unfused_count = trace1_count
|
|
287
|
+
fused_count = trace2_count
|
|
288
|
+
else:
|
|
289
|
+
unfused_platform = trace2_name
|
|
290
|
+
fused_platform = trace1_name
|
|
291
|
+
unfused_count = trace2_count
|
|
292
|
+
fused_count = trace1_count
|
|
293
|
+
|
|
294
|
+
mappings.append({
|
|
295
|
+
"fused_platform": fused_platform,
|
|
296
|
+
"fused_kernel_type": ktype,
|
|
297
|
+
"fused_count": fused_count,
|
|
298
|
+
"unfused_platform": unfused_platform,
|
|
299
|
+
"unfused_sequence": [ktype],
|
|
300
|
+
"unfused_count_per_type": {ktype: unfused_count},
|
|
301
|
+
"pattern_count": unfused_count - fused_count,
|
|
302
|
+
"pattern_confidence": (unfused_count - fused_count) / unfused_count,
|
|
303
|
+
"evidence": f"{unfused_platform} calls {ktype} {ratio:.1f}x more ({unfused_count} vs {fused_count}) - possible fusion",
|
|
304
|
+
})
|
|
305
|
+
|
|
306
|
+
return mappings
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
def _find_explicit_fused_operations(
|
|
310
|
+
trace1_kernels: list[dict],
|
|
311
|
+
trace2_kernels: list[dict],
|
|
312
|
+
trace1_name: str = "Trace1",
|
|
313
|
+
trace2_name: str = "Trace2",
|
|
314
|
+
trace1_platform: str = "AMD",
|
|
315
|
+
) -> list[dict]:
|
|
316
|
+
"""Find explicit fused operations (kernels with '+' in their classification).
|
|
317
|
+
|
|
318
|
+
These are operations like 'RMSNorm+GEMM' that are explicitly classified as fused.
|
|
319
|
+
"""
|
|
320
|
+
mappings = []
|
|
321
|
+
trace2_platform = "NVIDIA" if trace1_platform == "AMD" else "AMD"
|
|
322
|
+
|
|
323
|
+
def get_fused_ops(kernels: list[dict], platform: str) -> dict[str, list[str]]:
|
|
324
|
+
"""Get fused operations and their kernel names."""
|
|
325
|
+
fused_ops: dict[str, list[str]] = defaultdict(list)
|
|
326
|
+
for k in kernels:
|
|
327
|
+
name = k.get("name", "")
|
|
328
|
+
op, _pattern = classify(name, platform)
|
|
329
|
+
if "+" in op.value or op == Op.FUSED_UNKNOWN:
|
|
330
|
+
fused_ops[op.value].append(name)
|
|
331
|
+
return dict(fused_ops)
|
|
332
|
+
|
|
333
|
+
trace1_fused = get_fused_ops(trace1_kernels, trace1_platform)
|
|
334
|
+
trace2_fused = get_fused_ops(trace2_kernels, trace2_platform)
|
|
335
|
+
|
|
336
|
+
# Find fused operations unique to trace1
|
|
337
|
+
for fused_op, kernels in trace1_fused.items():
|
|
338
|
+
if fused_op not in trace2_fused:
|
|
339
|
+
# Parse components from the fused op name
|
|
340
|
+
if "+" in fused_op:
|
|
341
|
+
components = [c.strip() for c in fused_op.split("+")]
|
|
342
|
+
else:
|
|
343
|
+
components = [fused_op]
|
|
344
|
+
|
|
345
|
+
mappings.append({
|
|
346
|
+
"fused_platform": trace1_name,
|
|
347
|
+
"fused_kernel_type": fused_op,
|
|
348
|
+
"fused_count": len(kernels),
|
|
349
|
+
"unfused_platform": trace2_name,
|
|
350
|
+
"unfused_sequence": components,
|
|
351
|
+
"unfused_count_per_type": {c: 0 for c in components}, # Unknown
|
|
352
|
+
"pattern_count": len(kernels),
|
|
353
|
+
"pattern_confidence": 1.0,
|
|
354
|
+
"evidence": f"{trace1_name} fuses {' + '.join(components)} into {fused_op} ({len(kernels)} calls)",
|
|
355
|
+
"fused_kernel_names": kernels[:3], # Sample of kernel names
|
|
356
|
+
})
|
|
357
|
+
|
|
358
|
+
# Find fused operations unique to trace2
|
|
359
|
+
for fused_op, kernels in trace2_fused.items():
|
|
360
|
+
if fused_op not in trace1_fused:
|
|
361
|
+
if "+" in fused_op:
|
|
362
|
+
components = [c.strip() for c in fused_op.split("+")]
|
|
363
|
+
else:
|
|
364
|
+
components = [fused_op]
|
|
365
|
+
|
|
366
|
+
mappings.append({
|
|
367
|
+
"fused_platform": trace2_name,
|
|
368
|
+
"fused_kernel_type": fused_op,
|
|
369
|
+
"fused_count": len(kernels),
|
|
370
|
+
"unfused_platform": trace1_name,
|
|
371
|
+
"unfused_sequence": components,
|
|
372
|
+
"unfused_count_per_type": {c: 0 for c in components},
|
|
373
|
+
"pattern_count": len(kernels),
|
|
374
|
+
"pattern_confidence": 1.0,
|
|
375
|
+
"evidence": f"{trace2_name} fuses {' + '.join(components)} into {fused_op} ({len(kernels)} calls)",
|
|
376
|
+
"fused_kernel_names": kernels[:3],
|
|
377
|
+
})
|
|
378
|
+
|
|
379
|
+
return mappings
|
|
380
|
+
|
|
381
|
+
|
|
382
|
+
def detect_fusion_patterns(
|
|
383
|
+
amd_kernels: list[dict],
|
|
384
|
+
nvidia_kernels: list[dict],
|
|
102
385
|
) -> FusionAnalysis:
|
|
103
|
-
"""Detect fusion patterns
|
|
386
|
+
"""Detect fusion patterns using explicit fused operation detection only.
|
|
387
|
+
|
|
388
|
+
This approach only reports high-confidence fusions where kernels are
|
|
389
|
+
explicitly classified as fused (e.g., 'RMSNorm+GEMM', 'SwiGLU+GEMM').
|
|
390
|
+
|
|
391
|
+
We intentionally avoid speculative detection (sequence patterns, count
|
|
392
|
+
imbalances) because these produce too many false positives - count
|
|
393
|
+
differences are usually due to different implementations, not fusion.
|
|
104
394
|
|
|
105
395
|
Args:
|
|
106
|
-
|
|
396
|
+
amd_kernels: List of AMD kernel events
|
|
397
|
+
nvidia_kernels: List of NVIDIA kernel events
|
|
107
398
|
|
|
108
399
|
Returns:
|
|
109
400
|
FusionAnalysis with detected patterns
|
|
110
401
|
"""
|
|
111
|
-
|
|
402
|
+
all_mappings: list[dict] = []
|
|
403
|
+
|
|
404
|
+
# Only use explicit fused operations (highest confidence, no false positives)
|
|
405
|
+
# These are kernels explicitly classified with '+' in their operation type
|
|
406
|
+
explicit_fusions = _find_explicit_fused_operations(
|
|
407
|
+
amd_kernels, nvidia_kernels,
|
|
408
|
+
trace1_name="AMD", trace2_name="NVIDIA",
|
|
409
|
+
trace1_platform="AMD",
|
|
410
|
+
)
|
|
411
|
+
all_mappings.extend(explicit_fusions)
|
|
112
412
|
|
|
113
|
-
#
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
evidence = f"AMD fuses {' + '.join(component_ops)} into {pair.amd_kernel}, NVIDIA runs {len(unfused_kernels)} separate kernels"
|
|
134
|
-
else:
|
|
135
|
-
evidence = f"AMD fuses into {pair.amd_kernel}, NVIDIA runs components separately"
|
|
136
|
-
|
|
137
|
-
fusion_patterns.append(
|
|
138
|
-
FusionPattern(
|
|
139
|
-
layer=layer_alignment.layer,
|
|
140
|
-
operation=pair.operation,
|
|
141
|
-
fused_platform="AMD",
|
|
142
|
-
fused_kernel=pair.amd_kernel,
|
|
143
|
-
unfused_kernels=unfused_kernels if unfused_kernels else component_ops,
|
|
144
|
-
count=pair.amd_count,
|
|
145
|
-
evidence=evidence,
|
|
146
|
-
)
|
|
147
|
-
)
|
|
148
|
-
else:
|
|
149
|
-
# Regular case: AMD has a kernel that NVIDIA fuses into something else
|
|
150
|
-
# This means NVIDIA is fusing (it doesn't need this separate kernel)
|
|
151
|
-
pattern_key = (layer_alignment.layer, pair.operation, "NVIDIA")
|
|
152
|
-
if pattern_key not in seen_patterns:
|
|
153
|
-
seen_patterns.add(pattern_key)
|
|
154
|
-
|
|
155
|
-
evidence = f"AMD runs {pair.amd_kernel} separately ({pair.amd_count}x), NVIDIA fuses this operation"
|
|
156
|
-
|
|
157
|
-
fusion_patterns.append(
|
|
158
|
-
FusionPattern(
|
|
159
|
-
layer=layer_alignment.layer,
|
|
160
|
-
operation=pair.operation,
|
|
161
|
-
fused_platform="NVIDIA",
|
|
162
|
-
fused_kernel="(fused into nearby kernel)",
|
|
163
|
-
unfused_kernels=[pair.amd_kernel],
|
|
164
|
-
count=pair.amd_count,
|
|
165
|
-
evidence=evidence,
|
|
166
|
-
)
|
|
167
|
-
)
|
|
168
|
-
|
|
169
|
-
# Case 2: NVIDIA has a kernel for this operation, AMD doesn't
|
|
170
|
-
elif pair.nvidia_kernel and pair.nvidia_count > 0 and (pair.amd_kernel is None or pair.amd_count == 0):
|
|
171
|
-
if is_fused_op:
|
|
172
|
-
# NVIDIA HAS the fused kernel → NVIDIA is fusing
|
|
173
|
-
# Find what AMD runs separately
|
|
174
|
-
component_ops = _get_component_operations(pair.operation)
|
|
175
|
-
unfused_kernels = _find_component_kernels(layer_alignment, component_ops, "AMD")
|
|
176
|
-
|
|
177
|
-
pattern_key = (layer_alignment.layer, pair.operation, "NVIDIA")
|
|
178
|
-
if pattern_key not in seen_patterns:
|
|
179
|
-
seen_patterns.add(pattern_key)
|
|
180
|
-
|
|
181
|
-
if unfused_kernels:
|
|
182
|
-
evidence = f"NVIDIA fuses {' + '.join(component_ops)} into {pair.nvidia_kernel}, AMD runs {len(unfused_kernels)} separate kernels"
|
|
183
|
-
else:
|
|
184
|
-
evidence = f"NVIDIA fuses into {pair.nvidia_kernel}, AMD runs components separately"
|
|
185
|
-
|
|
186
|
-
fusion_patterns.append(
|
|
187
|
-
FusionPattern(
|
|
188
|
-
layer=layer_alignment.layer,
|
|
189
|
-
operation=pair.operation,
|
|
190
|
-
fused_platform="NVIDIA",
|
|
191
|
-
fused_kernel=pair.nvidia_kernel,
|
|
192
|
-
unfused_kernels=unfused_kernels if unfused_kernels else component_ops,
|
|
193
|
-
count=pair.nvidia_count,
|
|
194
|
-
evidence=evidence,
|
|
195
|
-
)
|
|
196
|
-
)
|
|
197
|
-
else:
|
|
198
|
-
# Regular case: NVIDIA has a kernel that AMD fuses into something else
|
|
199
|
-
pattern_key = (layer_alignment.layer, pair.operation, "AMD")
|
|
200
|
-
if pattern_key not in seen_patterns:
|
|
201
|
-
seen_patterns.add(pattern_key)
|
|
202
|
-
|
|
203
|
-
evidence = f"NVIDIA runs {pair.nvidia_kernel} separately ({pair.nvidia_count}x), AMD fuses this operation"
|
|
204
|
-
|
|
205
|
-
fusion_patterns.append(
|
|
206
|
-
FusionPattern(
|
|
207
|
-
layer=layer_alignment.layer,
|
|
208
|
-
operation=pair.operation,
|
|
209
|
-
fused_platform="AMD",
|
|
210
|
-
fused_kernel="(fused into nearby kernel)",
|
|
211
|
-
unfused_kernels=[pair.nvidia_kernel],
|
|
212
|
-
count=pair.nvidia_count,
|
|
213
|
-
evidence=evidence,
|
|
214
|
-
)
|
|
215
|
-
)
|
|
216
|
-
|
|
217
|
-
# Case 3: Both have kernels but with very different counts (partial fusion)
|
|
218
|
-
elif (
|
|
219
|
-
pair.amd_kernel
|
|
220
|
-
and pair.nvidia_kernel
|
|
221
|
-
and pair.amd_count > 0
|
|
222
|
-
and pair.nvidia_count > 0
|
|
223
|
-
):
|
|
224
|
-
count_ratio = pair.amd_count / pair.nvidia_count if pair.nvidia_count > 0 else float('inf')
|
|
225
|
-
|
|
226
|
-
if count_ratio > 1.5:
|
|
227
|
-
# AMD runs more → NVIDIA fuses some instances
|
|
228
|
-
pattern_key = (layer_alignment.layer, pair.operation, "NVIDIA")
|
|
229
|
-
if pattern_key not in seen_patterns:
|
|
230
|
-
seen_patterns.add(pattern_key)
|
|
231
|
-
|
|
232
|
-
evidence = (
|
|
233
|
-
f"AMD runs {pair.amd_kernel} {count_ratio:.1f}x more "
|
|
234
|
-
f"({pair.amd_count} vs {pair.nvidia_count}), NVIDIA partially fuses"
|
|
235
|
-
)
|
|
236
|
-
|
|
237
|
-
fusion_patterns.append(
|
|
238
|
-
FusionPattern(
|
|
239
|
-
layer=layer_alignment.layer,
|
|
240
|
-
operation=pair.operation,
|
|
241
|
-
fused_platform="NVIDIA",
|
|
242
|
-
fused_kernel=pair.nvidia_kernel,
|
|
243
|
-
unfused_kernels=[pair.amd_kernel],
|
|
244
|
-
count=pair.amd_count - pair.nvidia_count,
|
|
245
|
-
evidence=evidence,
|
|
246
|
-
)
|
|
247
|
-
)
|
|
248
|
-
|
|
249
|
-
elif count_ratio < 0.67:
|
|
250
|
-
# NVIDIA runs more → AMD fuses some instances
|
|
251
|
-
inverse_ratio = pair.nvidia_count / pair.amd_count if pair.amd_count > 0 else float('inf')
|
|
252
|
-
pattern_key = (layer_alignment.layer, pair.operation, "AMD")
|
|
253
|
-
if pattern_key not in seen_patterns:
|
|
254
|
-
seen_patterns.add(pattern_key)
|
|
255
|
-
|
|
256
|
-
evidence = (
|
|
257
|
-
f"NVIDIA runs {pair.nvidia_kernel} {inverse_ratio:.1f}x more "
|
|
258
|
-
f"({pair.nvidia_count} vs {pair.amd_count}), AMD partially fuses"
|
|
259
|
-
)
|
|
260
|
-
|
|
261
|
-
fusion_patterns.append(
|
|
262
|
-
FusionPattern(
|
|
263
|
-
layer=layer_alignment.layer,
|
|
264
|
-
operation=pair.operation,
|
|
265
|
-
fused_platform="AMD",
|
|
266
|
-
fused_kernel=pair.amd_kernel,
|
|
267
|
-
unfused_kernels=[pair.nvidia_kernel],
|
|
268
|
-
count=pair.nvidia_count - pair.amd_count,
|
|
269
|
-
evidence=evidence,
|
|
270
|
-
)
|
|
271
|
-
)
|
|
272
|
-
|
|
273
|
-
# Aggregate patterns by operation type and fused platform
|
|
274
|
-
aggregated: dict[tuple[str, str], FusionPattern] = {}
|
|
275
|
-
for pattern in fusion_patterns:
|
|
276
|
-
key = (pattern.operation, pattern.fused_platform)
|
|
277
|
-
if key in aggregated:
|
|
278
|
-
existing = aggregated[key]
|
|
279
|
-
existing.count += pattern.count
|
|
280
|
-
# Merge unfused kernels (deduplicate)
|
|
281
|
-
for k in pattern.unfused_kernels:
|
|
282
|
-
if k not in existing.unfused_kernels:
|
|
283
|
-
existing.unfused_kernels.append(k)
|
|
284
|
-
else:
|
|
285
|
-
aggregated[key] = FusionPattern(
|
|
286
|
-
layer=pattern.layer,
|
|
287
|
-
operation=pattern.operation,
|
|
288
|
-
fused_platform=pattern.fused_platform,
|
|
289
|
-
fused_kernel=pattern.fused_kernel,
|
|
290
|
-
unfused_kernels=list(pattern.unfused_kernels),
|
|
291
|
-
count=pattern.count,
|
|
292
|
-
evidence=pattern.evidence,
|
|
413
|
+
# NOTE: We intentionally skip sequence-based and count-imbalance detection
|
|
414
|
+
# because they produce false positives. Count differences between platforms
|
|
415
|
+
# are usually due to different library implementations (cuBLAS vs hipBLAS),
|
|
416
|
+
# not actual kernel fusion.
|
|
417
|
+
|
|
418
|
+
# Convert to FusionPattern objects
|
|
419
|
+
patterns: list[FusionPattern] = []
|
|
420
|
+
for mapping in all_mappings:
|
|
421
|
+
fused_kernel_names = mapping.get("fused_kernel_names", [])
|
|
422
|
+
fused_kernel = fused_kernel_names[0] if fused_kernel_names else mapping["fused_kernel_type"]
|
|
423
|
+
|
|
424
|
+
patterns.append(
|
|
425
|
+
FusionPattern(
|
|
426
|
+
layer=0, # Pattern-based analysis doesn't track layers
|
|
427
|
+
operation=mapping["fused_kernel_type"],
|
|
428
|
+
fused_platform=mapping["fused_platform"],
|
|
429
|
+
fused_kernel=fused_kernel,
|
|
430
|
+
unfused_kernels=mapping["unfused_sequence"],
|
|
431
|
+
count=mapping["pattern_count"],
|
|
432
|
+
evidence=mapping["evidence"],
|
|
293
433
|
)
|
|
434
|
+
)
|
|
294
435
|
|
|
295
|
-
amd_fuses = sum(1 for p in
|
|
296
|
-
nvidia_fuses = sum(1 for p in
|
|
436
|
+
amd_fuses = sum(1 for p in patterns if p.fused_platform == "AMD")
|
|
437
|
+
nvidia_fuses = sum(1 for p in patterns if p.fused_platform == "NVIDIA")
|
|
297
438
|
|
|
298
439
|
return FusionAnalysis(
|
|
299
|
-
patterns=
|
|
440
|
+
patterns=patterns,
|
|
300
441
|
summary={
|
|
301
442
|
"amd_fuses": amd_fuses,
|
|
302
443
|
"nvidia_fuses": nvidia_fuses,
|
|
303
|
-
"total_fusion_opportunities": len(
|
|
444
|
+
"total_fusion_opportunities": len(patterns),
|
|
304
445
|
},
|
|
305
446
|
)
|
|
306
447
|
|
|
307
448
|
|
|
308
449
|
def analyze_fusion_from_alignment(
|
|
309
|
-
layer_alignments: list[
|
|
450
|
+
layer_alignments: list[Any],
|
|
451
|
+
amd_kernels: list[dict] | None = None,
|
|
452
|
+
nvidia_kernels: list[dict] | None = None,
|
|
310
453
|
) -> dict[str, Any]:
|
|
311
|
-
"""Analyze fusion from
|
|
454
|
+
"""Analyze fusion from kernel data (for API compatibility).
|
|
312
455
|
|
|
313
456
|
Args:
|
|
314
|
-
layer_alignments: List of aligned layers
|
|
457
|
+
layer_alignments: List of aligned layers (unused - kept for API compatibility)
|
|
458
|
+
amd_kernels: Optional list of AMD kernel events for pattern-based analysis
|
|
459
|
+
nvidia_kernels: Optional list of NVIDIA kernel events for pattern-based analysis
|
|
315
460
|
|
|
316
461
|
Returns:
|
|
317
|
-
Dictionary with fusion analysis results
|
|
462
|
+
Dictionary with fusion analysis results
|
|
318
463
|
"""
|
|
319
|
-
|
|
464
|
+
# If raw kernels provided, use pattern-based analysis (preferred)
|
|
465
|
+
if amd_kernels is not None and nvidia_kernels is not None:
|
|
466
|
+
fusion_analysis = detect_fusion_patterns(amd_kernels, nvidia_kernels)
|
|
467
|
+
else:
|
|
468
|
+
# Fallback: empty analysis if no kernel data
|
|
469
|
+
fusion_analysis = FusionAnalysis(
|
|
470
|
+
patterns=[],
|
|
471
|
+
summary={"amd_fuses": 0, "nvidia_fuses": 0, "total_fusion_opportunities": 0},
|
|
472
|
+
)
|
|
320
473
|
|
|
321
474
|
fusion_opportunities = []
|
|
322
475
|
fusion_mappings = []
|
|
323
476
|
|
|
324
477
|
for pattern in fusion_analysis.patterns:
|
|
325
478
|
unfused_platform = "NVIDIA" if pattern.fused_platform == "AMD" else "AMD"
|
|
326
|
-
|
|
327
|
-
fusion_opportunities.append(
|
|
328
|
-
{
|
|
329
|
-
"kernel_type": pattern.operation,
|
|
330
|
-
"layer": pattern.layer,
|
|
331
|
-
"fused_by": pattern.fused_platform,
|
|
332
|
-
"fused_kernel": pattern.fused_kernel,
|
|
333
|
-
"unfused_kernels": pattern.unfused_kernels,
|
|
334
|
-
"count": pattern.count,
|
|
335
|
-
"evidence": pattern.evidence,
|
|
336
|
-
}
|
|
337
|
-
)
|
|
338
479
|
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
480
|
+
fusion_opportunities.append({
|
|
481
|
+
"kernel_type": pattern.operation,
|
|
482
|
+
"layer": pattern.layer,
|
|
483
|
+
"fused_by": pattern.fused_platform,
|
|
484
|
+
"fused_kernel": pattern.fused_kernel,
|
|
485
|
+
"unfused_kernels": pattern.unfused_kernels,
|
|
486
|
+
"count": pattern.count,
|
|
487
|
+
"evidence": pattern.evidence,
|
|
488
|
+
})
|
|
489
|
+
|
|
490
|
+
fusion_mappings.append({
|
|
491
|
+
"fused_platform": pattern.fused_platform,
|
|
492
|
+
"fused_kernel_type": pattern.operation,
|
|
493
|
+
"fused_kernel_name": pattern.fused_kernel,
|
|
494
|
+
"unfused_platform": unfused_platform,
|
|
495
|
+
"unfused_sequence": pattern.unfused_kernels,
|
|
496
|
+
"pattern_count": pattern.count,
|
|
497
|
+
"evidence": pattern.evidence,
|
|
498
|
+
"layer": pattern.layer,
|
|
499
|
+
})
|
|
351
500
|
|
|
352
501
|
return {
|
|
353
502
|
"fusion_opportunities": fusion_opportunities,
|