wafer-core 0.1.27__py3-none-any.whl → 0.1.28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer_core/lib/trace_compare/aligner.py +13 -6
- wafer_core/lib/trace_compare/analyzer.py +12 -3
- wafer_core/lib/trace_compare/fusion_analyzer.py +392 -284
- wafer_core/targets/__init__.py +47 -21
- wafer_core/targets/pool.py +181 -0
- wafer_core/targets/probe.py +113 -0
- wafer_core/targets/providers/__init__.py +46 -0
- wafer_core/targets/providers/baremetal.py +72 -0
- wafer_core/targets/providers/digitalocean.py +164 -0
- wafer_core/targets/providers/runpod.py +250 -0
- wafer_core/targets/reconcile.py +90 -0
- wafer_core/targets/spec_store.py +200 -0
- wafer_core/targets/state_cache.py +150 -0
- wafer_core/targets/types.py +141 -0
- wafer_core/utils/kernel_utils/targets/config.py +8 -24
- {wafer_core-0.1.27.dist-info → wafer_core-0.1.28.dist-info}/METADATA +1 -1
- {wafer_core-0.1.27.dist-info → wafer_core-0.1.28.dist-info}/RECORD +18 -8
- {wafer_core-0.1.27.dist-info → wafer_core-0.1.28.dist-info}/WHEEL +0 -0
|
@@ -1,25 +1,18 @@
|
|
|
1
|
-
"""Fusion analysis
|
|
1
|
+
"""Fusion analysis for detecting kernel fusion differences between platforms.
|
|
2
2
|
|
|
3
|
-
Detects
|
|
3
|
+
Detects fusion differences between AMD and NVIDIA by analyzing:
|
|
4
|
+
1. Kernel types unique to each platform (one platform fuses them away)
|
|
5
|
+
2. Sequence patterns (what comes before/after unique kernels)
|
|
6
|
+
3. Count imbalances (one platform has significantly more calls)
|
|
4
7
|
|
|
5
|
-
|
|
6
|
-
that platform IS fusing. The OTHER platform runs the components separately.
|
|
8
|
+
This is the pattern-based approach which is more reliable than alignment-based detection.
|
|
7
9
|
"""
|
|
8
10
|
|
|
9
11
|
from collections import Counter, defaultdict
|
|
10
12
|
from dataclasses import dataclass, field
|
|
11
13
|
from typing import Any
|
|
12
14
|
|
|
13
|
-
from .
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
# Maps fused operation names to their component operations
|
|
17
|
-
FUSED_OPERATION_COMPONENTS: dict[str, list[str]] = {
|
|
18
|
-
"RMSNorm+GEMM": ["RMSNorm", "Dense GEMM"],
|
|
19
|
-
"SwiGLU+GEMM": ["SwiGLU", "Dense GEMM"],
|
|
20
|
-
"MoE GEMM+SwiGLU": ["MoE GEMM", "SwiGLU"],
|
|
21
|
-
"Embedding+RMSNorm+GEMM": ["Elementwise", "RMSNorm", "Dense GEMM"], # Embedding is often Elementwise
|
|
22
|
-
}
|
|
15
|
+
from .classifier import classify, Op
|
|
23
16
|
|
|
24
17
|
|
|
25
18
|
@dataclass
|
|
@@ -43,311 +36,426 @@ class FusionAnalysis:
|
|
|
43
36
|
summary: dict[str, Any] = field(default_factory=dict)
|
|
44
37
|
|
|
45
38
|
|
|
46
|
-
def
|
|
47
|
-
"""
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
mapping = {
|
|
65
|
-
"GEMM": "Dense GEMM",
|
|
66
|
-
"SwiGLU": "Triton Fused",
|
|
67
|
-
}
|
|
68
|
-
return [mapping.get(p, p) for p in parts]
|
|
69
|
-
|
|
70
|
-
return []
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
def _find_component_kernels(
|
|
74
|
-
layer_alignment: LayerAlignment,
|
|
75
|
-
component_ops: list[str],
|
|
76
|
-
platform: str,
|
|
77
|
-
) -> list[str]:
|
|
78
|
-
"""Find kernels for component operations on the specified platform.
|
|
79
|
-
|
|
39
|
+
def _classify_kernel(name: str, platform: str = "AMD") -> str:
|
|
40
|
+
"""Classify a kernel name to its operation type."""
|
|
41
|
+
op, _pattern = classify(name, platform)
|
|
42
|
+
return op.value
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _find_fusion_mappings(
|
|
46
|
+
trace1_kernels: list[dict],
|
|
47
|
+
trace2_kernels: list[dict],
|
|
48
|
+
trace1_name: str = "Trace1",
|
|
49
|
+
trace2_name: str = "Trace2",
|
|
50
|
+
trace1_platform: str = "AMD",
|
|
51
|
+
) -> list[dict]:
|
|
52
|
+
"""Find fusion mappings by analyzing kernel execution sequence patterns.
|
|
53
|
+
|
|
54
|
+
This function identifies when one platform runs multiple kernels separately
|
|
55
|
+
while the other platform fuses them into a single kernel.
|
|
56
|
+
|
|
80
57
|
Args:
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
58
|
+
trace1_kernels: List of kernel events from first trace
|
|
59
|
+
trace2_kernels: List of kernel events from second trace
|
|
60
|
+
trace1_name: Name of first platform (e.g., "AMD")
|
|
61
|
+
trace2_name: Name of second platform (e.g., "NVIDIA")
|
|
62
|
+
trace1_platform: Platform string for classification ("AMD" or "NVIDIA")
|
|
63
|
+
|
|
85
64
|
Returns:
|
|
86
|
-
List of
|
|
65
|
+
List of mapping dictionaries with fusion details
|
|
66
|
+
"""
|
|
67
|
+
mappings = []
|
|
68
|
+
trace2_platform = "NVIDIA" if trace1_platform == "AMD" else "AMD"
|
|
69
|
+
|
|
70
|
+
# Sort kernels by timestamp
|
|
71
|
+
trace1_sorted = sorted(trace1_kernels, key=lambda k: k.get("ts", 0))
|
|
72
|
+
trace2_sorted = sorted(trace2_kernels, key=lambda k: k.get("ts", 0))
|
|
73
|
+
|
|
74
|
+
# Classify all kernels
|
|
75
|
+
trace1_types = [_classify_kernel(k.get("name", ""), trace1_platform) for k in trace1_sorted]
|
|
76
|
+
trace2_types = [_classify_kernel(k.get("name", ""), trace2_platform) for k in trace2_sorted]
|
|
77
|
+
|
|
78
|
+
# Find kernel types unique to each trace
|
|
79
|
+
trace1_type_set = set(trace1_types)
|
|
80
|
+
trace2_type_set = set(trace2_types)
|
|
81
|
+
|
|
82
|
+
trace1_only = trace1_type_set - trace2_type_set
|
|
83
|
+
trace2_only = trace2_type_set - trace1_type_set
|
|
84
|
+
|
|
85
|
+
# For each unique type in trace1, find common sequence patterns
|
|
86
|
+
for unique_type in trace1_only:
|
|
87
|
+
# Skip "Other" since it's too generic
|
|
88
|
+
if unique_type == "Other":
|
|
89
|
+
continue
|
|
90
|
+
|
|
91
|
+
# Find all occurrences of this type
|
|
92
|
+
indices = [i for i, t in enumerate(trace1_types) if t == unique_type]
|
|
93
|
+
|
|
94
|
+
if len(indices) < 5: # Need enough samples to be meaningful
|
|
95
|
+
continue
|
|
96
|
+
|
|
97
|
+
# Analyze what comes before each occurrence
|
|
98
|
+
before_types: dict[str, int] = defaultdict(int)
|
|
99
|
+
|
|
100
|
+
for idx in indices:
|
|
101
|
+
if idx > 0:
|
|
102
|
+
before_types[trace1_types[idx - 1]] += 1
|
|
103
|
+
|
|
104
|
+
# Find the most common pattern
|
|
105
|
+
if not before_types:
|
|
106
|
+
continue
|
|
107
|
+
most_common_before = max(before_types.items(), key=lambda x: x[1])
|
|
108
|
+
|
|
109
|
+
# If there's a strong pattern (>80% of occurrences)
|
|
110
|
+
if most_common_before[1] / len(indices) > 0.8:
|
|
111
|
+
fusion_candidate = most_common_before[0]
|
|
112
|
+
|
|
113
|
+
# Verify trace2 has this type
|
|
114
|
+
if fusion_candidate in trace2_type_set:
|
|
115
|
+
trace1_fusion_count = trace1_types.count(fusion_candidate)
|
|
116
|
+
trace2_fusion_count = trace2_types.count(fusion_candidate)
|
|
117
|
+
|
|
118
|
+
mappings.append({
|
|
119
|
+
"fused_platform": trace2_name,
|
|
120
|
+
"fused_kernel_type": fusion_candidate,
|
|
121
|
+
"fused_count": trace2_fusion_count,
|
|
122
|
+
"unfused_platform": trace1_name,
|
|
123
|
+
"unfused_sequence": [fusion_candidate, unique_type],
|
|
124
|
+
"unfused_count_per_type": {
|
|
125
|
+
fusion_candidate: trace1_fusion_count,
|
|
126
|
+
unique_type: len(indices),
|
|
127
|
+
},
|
|
128
|
+
"pattern_count": len(indices),
|
|
129
|
+
"pattern_confidence": most_common_before[1] / len(indices),
|
|
130
|
+
"evidence": f"{trace1_name} runs {fusion_candidate}+{unique_type} separately, {trace2_name} fuses into {fusion_candidate}",
|
|
131
|
+
})
|
|
132
|
+
|
|
133
|
+
# Also check trace2-only types
|
|
134
|
+
for unique_type in trace2_only:
|
|
135
|
+
if unique_type == "Other":
|
|
136
|
+
continue
|
|
137
|
+
|
|
138
|
+
indices = [i for i, t in enumerate(trace2_types) if t == unique_type]
|
|
139
|
+
|
|
140
|
+
if len(indices) < 5:
|
|
141
|
+
continue
|
|
142
|
+
|
|
143
|
+
before_types = defaultdict(int)
|
|
144
|
+
|
|
145
|
+
for idx in indices:
|
|
146
|
+
if idx > 0:
|
|
147
|
+
before_types[trace2_types[idx - 1]] += 1
|
|
148
|
+
|
|
149
|
+
if not before_types:
|
|
150
|
+
continue
|
|
151
|
+
most_common_before = max(before_types.items(), key=lambda x: x[1])
|
|
152
|
+
|
|
153
|
+
if most_common_before[1] / len(indices) > 0.8:
|
|
154
|
+
fusion_candidate = most_common_before[0]
|
|
155
|
+
|
|
156
|
+
if fusion_candidate in trace1_type_set:
|
|
157
|
+
trace1_fusion_count = trace1_types.count(fusion_candidate)
|
|
158
|
+
trace2_fusion_count = trace2_types.count(fusion_candidate)
|
|
159
|
+
|
|
160
|
+
mappings.append({
|
|
161
|
+
"fused_platform": trace1_name,
|
|
162
|
+
"fused_kernel_type": fusion_candidate,
|
|
163
|
+
"fused_count": trace1_fusion_count,
|
|
164
|
+
"unfused_platform": trace2_name,
|
|
165
|
+
"unfused_sequence": [fusion_candidate, unique_type],
|
|
166
|
+
"unfused_count_per_type": {
|
|
167
|
+
fusion_candidate: trace2_fusion_count,
|
|
168
|
+
unique_type: len(indices),
|
|
169
|
+
},
|
|
170
|
+
"pattern_count": len(indices),
|
|
171
|
+
"pattern_confidence": most_common_before[1] / len(indices),
|
|
172
|
+
"evidence": f"{trace2_name} runs {fusion_candidate}+{unique_type} separately, {trace1_name} fuses into {fusion_candidate}",
|
|
173
|
+
})
|
|
174
|
+
|
|
175
|
+
return mappings
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def _find_count_imbalance_fusions(
|
|
179
|
+
trace1_kernels: list[dict],
|
|
180
|
+
trace2_kernels: list[dict],
|
|
181
|
+
trace1_name: str = "Trace1",
|
|
182
|
+
trace2_name: str = "Trace2",
|
|
183
|
+
trace1_platform: str = "AMD",
|
|
184
|
+
) -> list[dict]:
|
|
185
|
+
"""Find fusions by looking for significant count imbalances.
|
|
186
|
+
|
|
187
|
+
When one platform has significantly more kernel calls of a type (>1.5x),
|
|
188
|
+
it suggests the other platform fuses those operations.
|
|
189
|
+
"""
|
|
190
|
+
mappings = []
|
|
191
|
+
trace2_platform = "NVIDIA" if trace1_platform == "AMD" else "AMD"
|
|
192
|
+
|
|
193
|
+
# Classify all kernels
|
|
194
|
+
trace1_types = [_classify_kernel(k.get("name", ""), trace1_platform) for k in trace1_kernels]
|
|
195
|
+
trace2_types = [_classify_kernel(k.get("name", ""), trace2_platform) for k in trace2_kernels]
|
|
196
|
+
|
|
197
|
+
# Count by type
|
|
198
|
+
trace1_counts = Counter(trace1_types)
|
|
199
|
+
trace2_counts = Counter(trace2_types)
|
|
200
|
+
|
|
201
|
+
# Find common types with significant differences
|
|
202
|
+
common_types = set(trace1_counts.keys()) & set(trace2_counts.keys())
|
|
203
|
+
|
|
204
|
+
for ktype in common_types:
|
|
205
|
+
if ktype == "Other":
|
|
206
|
+
continue
|
|
207
|
+
|
|
208
|
+
trace1_count = trace1_counts[ktype]
|
|
209
|
+
trace2_count = trace2_counts[ktype]
|
|
210
|
+
|
|
211
|
+
# Skip trivial counts
|
|
212
|
+
if trace1_count + trace2_count < 50:
|
|
213
|
+
continue
|
|
214
|
+
|
|
215
|
+
# Check if there's a significant imbalance (>1.5x)
|
|
216
|
+
if trace1_count == 0 or trace2_count == 0:
|
|
217
|
+
continue
|
|
218
|
+
|
|
219
|
+
ratio = max(trace1_count, trace2_count) / min(trace1_count, trace2_count)
|
|
220
|
+
|
|
221
|
+
if ratio < 1.5:
|
|
222
|
+
continue
|
|
223
|
+
|
|
224
|
+
# Determine which platform has more (unfused) and which has fewer (fused)
|
|
225
|
+
if trace1_count > trace2_count:
|
|
226
|
+
unfused_platform = trace1_name
|
|
227
|
+
fused_platform = trace2_name
|
|
228
|
+
unfused_count = trace1_count
|
|
229
|
+
fused_count = trace2_count
|
|
230
|
+
else:
|
|
231
|
+
unfused_platform = trace2_name
|
|
232
|
+
fused_platform = trace1_name
|
|
233
|
+
unfused_count = trace2_count
|
|
234
|
+
fused_count = trace1_count
|
|
235
|
+
|
|
236
|
+
mappings.append({
|
|
237
|
+
"fused_platform": fused_platform,
|
|
238
|
+
"fused_kernel_type": ktype,
|
|
239
|
+
"fused_count": fused_count,
|
|
240
|
+
"unfused_platform": unfused_platform,
|
|
241
|
+
"unfused_sequence": [ktype],
|
|
242
|
+
"unfused_count_per_type": {ktype: unfused_count},
|
|
243
|
+
"pattern_count": unfused_count - fused_count,
|
|
244
|
+
"pattern_confidence": (unfused_count - fused_count) / unfused_count,
|
|
245
|
+
"evidence": f"{unfused_platform} calls {ktype} {ratio:.1f}x more ({unfused_count} vs {fused_count}), {fused_platform} likely fuses",
|
|
246
|
+
})
|
|
247
|
+
|
|
248
|
+
return mappings
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
def _find_explicit_fused_operations(
|
|
252
|
+
trace1_kernels: list[dict],
|
|
253
|
+
trace2_kernels: list[dict],
|
|
254
|
+
trace1_name: str = "Trace1",
|
|
255
|
+
trace2_name: str = "Trace2",
|
|
256
|
+
trace1_platform: str = "AMD",
|
|
257
|
+
) -> list[dict]:
|
|
258
|
+
"""Find explicit fused operations (kernels with '+' in their classification).
|
|
259
|
+
|
|
260
|
+
These are operations like 'RMSNorm+GEMM' that are explicitly classified as fused.
|
|
87
261
|
"""
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
262
|
+
mappings = []
|
|
263
|
+
trace2_platform = "NVIDIA" if trace1_platform == "AMD" else "AMD"
|
|
264
|
+
|
|
265
|
+
def get_fused_ops(kernels: list[dict], platform: str) -> dict[str, list[str]]:
|
|
266
|
+
"""Get fused operations and their kernel names."""
|
|
267
|
+
fused_ops: dict[str, list[str]] = defaultdict(list)
|
|
268
|
+
for k in kernels:
|
|
269
|
+
name = k.get("name", "")
|
|
270
|
+
op, _pattern = classify(name, platform)
|
|
271
|
+
if "+" in op.value or op == Op.FUSED_UNKNOWN:
|
|
272
|
+
fused_ops[op.value].append(name)
|
|
273
|
+
return dict(fused_ops)
|
|
274
|
+
|
|
275
|
+
trace1_fused = get_fused_ops(trace1_kernels, trace1_platform)
|
|
276
|
+
trace2_fused = get_fused_ops(trace2_kernels, trace2_platform)
|
|
277
|
+
|
|
278
|
+
# Find fused operations unique to trace1
|
|
279
|
+
for fused_op, kernels in trace1_fused.items():
|
|
280
|
+
if fused_op not in trace2_fused:
|
|
281
|
+
# Parse components from the fused op name
|
|
282
|
+
if "+" in fused_op:
|
|
283
|
+
components = [c.strip() for c in fused_op.split("+")]
|
|
284
|
+
else:
|
|
285
|
+
components = [fused_op]
|
|
286
|
+
|
|
287
|
+
mappings.append({
|
|
288
|
+
"fused_platform": trace1_name,
|
|
289
|
+
"fused_kernel_type": fused_op,
|
|
290
|
+
"fused_count": len(kernels),
|
|
291
|
+
"unfused_platform": trace2_name,
|
|
292
|
+
"unfused_sequence": components,
|
|
293
|
+
"unfused_count_per_type": {c: 0 for c in components}, # Unknown
|
|
294
|
+
"pattern_count": len(kernels),
|
|
295
|
+
"pattern_confidence": 1.0,
|
|
296
|
+
"evidence": f"{trace1_name} fuses {' + '.join(components)} into {fused_op} ({len(kernels)} calls)",
|
|
297
|
+
"fused_kernel_names": kernels[:3], # Sample of kernel names
|
|
298
|
+
})
|
|
299
|
+
|
|
300
|
+
# Find fused operations unique to trace2
|
|
301
|
+
for fused_op, kernels in trace2_fused.items():
|
|
302
|
+
if fused_op not in trace1_fused:
|
|
303
|
+
if "+" in fused_op:
|
|
304
|
+
components = [c.strip() for c in fused_op.split("+")]
|
|
305
|
+
else:
|
|
306
|
+
components = [fused_op]
|
|
307
|
+
|
|
308
|
+
mappings.append({
|
|
309
|
+
"fused_platform": trace2_name,
|
|
310
|
+
"fused_kernel_type": fused_op,
|
|
311
|
+
"fused_count": len(kernels),
|
|
312
|
+
"unfused_platform": trace1_name,
|
|
313
|
+
"unfused_sequence": components,
|
|
314
|
+
"unfused_count_per_type": {c: 0 for c in components},
|
|
315
|
+
"pattern_count": len(kernels),
|
|
316
|
+
"pattern_confidence": 1.0,
|
|
317
|
+
"evidence": f"{trace2_name} fuses {' + '.join(components)} into {fused_op} ({len(kernels)} calls)",
|
|
318
|
+
"fused_kernel_names": kernels[:3],
|
|
319
|
+
})
|
|
320
|
+
|
|
321
|
+
return mappings
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
def detect_fusion_patterns(
|
|
325
|
+
amd_kernels: list[dict],
|
|
326
|
+
nvidia_kernels: list[dict],
|
|
102
327
|
) -> FusionAnalysis:
|
|
103
|
-
"""Detect fusion patterns
|
|
328
|
+
"""Detect fusion patterns using pattern-based analysis.
|
|
329
|
+
|
|
330
|
+
This is the main entry point for fusion detection. It combines:
|
|
331
|
+
1. Explicit fused operations (kernels classified with '+' in name)
|
|
332
|
+
2. Sequence pattern analysis (unique kernel types with consistent patterns)
|
|
333
|
+
3. Count imbalance analysis (one platform has significantly more calls)
|
|
104
334
|
|
|
105
335
|
Args:
|
|
106
|
-
|
|
336
|
+
amd_kernels: List of AMD kernel events
|
|
337
|
+
nvidia_kernels: List of NVIDIA kernel events
|
|
107
338
|
|
|
108
339
|
Returns:
|
|
109
340
|
FusionAnalysis with detected patterns
|
|
110
341
|
"""
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
#
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
fused_platform="NVIDIA",
|
|
162
|
-
fused_kernel="(fused into nearby kernel)",
|
|
163
|
-
unfused_kernels=[pair.amd_kernel],
|
|
164
|
-
count=pair.amd_count,
|
|
165
|
-
evidence=evidence,
|
|
166
|
-
)
|
|
167
|
-
)
|
|
168
|
-
|
|
169
|
-
# Case 2: NVIDIA has a kernel for this operation, AMD doesn't
|
|
170
|
-
elif pair.nvidia_kernel and pair.nvidia_count > 0 and (pair.amd_kernel is None or pair.amd_count == 0):
|
|
171
|
-
if is_fused_op:
|
|
172
|
-
# NVIDIA HAS the fused kernel → NVIDIA is fusing
|
|
173
|
-
# Find what AMD runs separately
|
|
174
|
-
component_ops = _get_component_operations(pair.operation)
|
|
175
|
-
unfused_kernels = _find_component_kernels(layer_alignment, component_ops, "AMD")
|
|
176
|
-
|
|
177
|
-
pattern_key = (layer_alignment.layer, pair.operation, "NVIDIA")
|
|
178
|
-
if pattern_key not in seen_patterns:
|
|
179
|
-
seen_patterns.add(pattern_key)
|
|
180
|
-
|
|
181
|
-
if unfused_kernels:
|
|
182
|
-
evidence = f"NVIDIA fuses {' + '.join(component_ops)} into {pair.nvidia_kernel}, AMD runs {len(unfused_kernels)} separate kernels"
|
|
183
|
-
else:
|
|
184
|
-
evidence = f"NVIDIA fuses into {pair.nvidia_kernel}, AMD runs components separately"
|
|
185
|
-
|
|
186
|
-
fusion_patterns.append(
|
|
187
|
-
FusionPattern(
|
|
188
|
-
layer=layer_alignment.layer,
|
|
189
|
-
operation=pair.operation,
|
|
190
|
-
fused_platform="NVIDIA",
|
|
191
|
-
fused_kernel=pair.nvidia_kernel,
|
|
192
|
-
unfused_kernels=unfused_kernels if unfused_kernels else component_ops,
|
|
193
|
-
count=pair.nvidia_count,
|
|
194
|
-
evidence=evidence,
|
|
195
|
-
)
|
|
196
|
-
)
|
|
197
|
-
else:
|
|
198
|
-
# Regular case: NVIDIA has a kernel that AMD fuses into something else
|
|
199
|
-
pattern_key = (layer_alignment.layer, pair.operation, "AMD")
|
|
200
|
-
if pattern_key not in seen_patterns:
|
|
201
|
-
seen_patterns.add(pattern_key)
|
|
202
|
-
|
|
203
|
-
evidence = f"NVIDIA runs {pair.nvidia_kernel} separately ({pair.nvidia_count}x), AMD fuses this operation"
|
|
204
|
-
|
|
205
|
-
fusion_patterns.append(
|
|
206
|
-
FusionPattern(
|
|
207
|
-
layer=layer_alignment.layer,
|
|
208
|
-
operation=pair.operation,
|
|
209
|
-
fused_platform="AMD",
|
|
210
|
-
fused_kernel="(fused into nearby kernel)",
|
|
211
|
-
unfused_kernels=[pair.nvidia_kernel],
|
|
212
|
-
count=pair.nvidia_count,
|
|
213
|
-
evidence=evidence,
|
|
214
|
-
)
|
|
215
|
-
)
|
|
216
|
-
|
|
217
|
-
# Case 3: Both have kernels but with very different counts (partial fusion)
|
|
218
|
-
elif (
|
|
219
|
-
pair.amd_kernel
|
|
220
|
-
and pair.nvidia_kernel
|
|
221
|
-
and pair.amd_count > 0
|
|
222
|
-
and pair.nvidia_count > 0
|
|
223
|
-
):
|
|
224
|
-
count_ratio = pair.amd_count / pair.nvidia_count if pair.nvidia_count > 0 else float('inf')
|
|
225
|
-
|
|
226
|
-
if count_ratio > 1.5:
|
|
227
|
-
# AMD runs more → NVIDIA fuses some instances
|
|
228
|
-
pattern_key = (layer_alignment.layer, pair.operation, "NVIDIA")
|
|
229
|
-
if pattern_key not in seen_patterns:
|
|
230
|
-
seen_patterns.add(pattern_key)
|
|
231
|
-
|
|
232
|
-
evidence = (
|
|
233
|
-
f"AMD runs {pair.amd_kernel} {count_ratio:.1f}x more "
|
|
234
|
-
f"({pair.amd_count} vs {pair.nvidia_count}), NVIDIA partially fuses"
|
|
235
|
-
)
|
|
236
|
-
|
|
237
|
-
fusion_patterns.append(
|
|
238
|
-
FusionPattern(
|
|
239
|
-
layer=layer_alignment.layer,
|
|
240
|
-
operation=pair.operation,
|
|
241
|
-
fused_platform="NVIDIA",
|
|
242
|
-
fused_kernel=pair.nvidia_kernel,
|
|
243
|
-
unfused_kernels=[pair.amd_kernel],
|
|
244
|
-
count=pair.amd_count - pair.nvidia_count,
|
|
245
|
-
evidence=evidence,
|
|
246
|
-
)
|
|
247
|
-
)
|
|
248
|
-
|
|
249
|
-
elif count_ratio < 0.67:
|
|
250
|
-
# NVIDIA runs more → AMD fuses some instances
|
|
251
|
-
inverse_ratio = pair.nvidia_count / pair.amd_count if pair.amd_count > 0 else float('inf')
|
|
252
|
-
pattern_key = (layer_alignment.layer, pair.operation, "AMD")
|
|
253
|
-
if pattern_key not in seen_patterns:
|
|
254
|
-
seen_patterns.add(pattern_key)
|
|
255
|
-
|
|
256
|
-
evidence = (
|
|
257
|
-
f"NVIDIA runs {pair.nvidia_kernel} {inverse_ratio:.1f}x more "
|
|
258
|
-
f"({pair.nvidia_count} vs {pair.amd_count}), AMD partially fuses"
|
|
259
|
-
)
|
|
260
|
-
|
|
261
|
-
fusion_patterns.append(
|
|
262
|
-
FusionPattern(
|
|
263
|
-
layer=layer_alignment.layer,
|
|
264
|
-
operation=pair.operation,
|
|
265
|
-
fused_platform="AMD",
|
|
266
|
-
fused_kernel=pair.amd_kernel,
|
|
267
|
-
unfused_kernels=[pair.nvidia_kernel],
|
|
268
|
-
count=pair.nvidia_count - pair.amd_count,
|
|
269
|
-
evidence=evidence,
|
|
270
|
-
)
|
|
271
|
-
)
|
|
272
|
-
|
|
273
|
-
# Aggregate patterns by operation type and fused platform
|
|
274
|
-
aggregated: dict[tuple[str, str], FusionPattern] = {}
|
|
275
|
-
for pattern in fusion_patterns:
|
|
276
|
-
key = (pattern.operation, pattern.fused_platform)
|
|
277
|
-
if key in aggregated:
|
|
278
|
-
existing = aggregated[key]
|
|
279
|
-
existing.count += pattern.count
|
|
280
|
-
# Merge unfused kernels (deduplicate)
|
|
281
|
-
for k in pattern.unfused_kernels:
|
|
282
|
-
if k not in existing.unfused_kernels:
|
|
283
|
-
existing.unfused_kernels.append(k)
|
|
284
|
-
else:
|
|
285
|
-
aggregated[key] = FusionPattern(
|
|
286
|
-
layer=pattern.layer,
|
|
287
|
-
operation=pattern.operation,
|
|
288
|
-
fused_platform=pattern.fused_platform,
|
|
289
|
-
fused_kernel=pattern.fused_kernel,
|
|
290
|
-
unfused_kernels=list(pattern.unfused_kernels),
|
|
291
|
-
count=pattern.count,
|
|
292
|
-
evidence=pattern.evidence,
|
|
342
|
+
all_mappings: list[dict] = []
|
|
343
|
+
|
|
344
|
+
# 1. Find explicit fused operations (highest confidence)
|
|
345
|
+
explicit_fusions = _find_explicit_fused_operations(
|
|
346
|
+
amd_kernels, nvidia_kernels,
|
|
347
|
+
trace1_name="AMD", trace2_name="NVIDIA",
|
|
348
|
+
trace1_platform="AMD",
|
|
349
|
+
)
|
|
350
|
+
all_mappings.extend(explicit_fusions)
|
|
351
|
+
|
|
352
|
+
# 2. Find sequence-based fusions
|
|
353
|
+
sequence_fusions = _find_fusion_mappings(
|
|
354
|
+
amd_kernels, nvidia_kernels,
|
|
355
|
+
trace1_name="AMD", trace2_name="NVIDIA",
|
|
356
|
+
trace1_platform="AMD",
|
|
357
|
+
)
|
|
358
|
+
# Deduplicate: skip if same fused_kernel_type already found
|
|
359
|
+
existing_types = {m["fused_kernel_type"] for m in all_mappings}
|
|
360
|
+
for fusion in sequence_fusions:
|
|
361
|
+
if fusion["fused_kernel_type"] not in existing_types:
|
|
362
|
+
all_mappings.append(fusion)
|
|
363
|
+
existing_types.add(fusion["fused_kernel_type"])
|
|
364
|
+
|
|
365
|
+
# 3. Find count-imbalance fusions
|
|
366
|
+
count_fusions = _find_count_imbalance_fusions(
|
|
367
|
+
amd_kernels, nvidia_kernels,
|
|
368
|
+
trace1_name="AMD", trace2_name="NVIDIA",
|
|
369
|
+
trace1_platform="AMD",
|
|
370
|
+
)
|
|
371
|
+
# Deduplicate
|
|
372
|
+
for fusion in count_fusions:
|
|
373
|
+
if fusion["fused_kernel_type"] not in existing_types:
|
|
374
|
+
all_mappings.append(fusion)
|
|
375
|
+
existing_types.add(fusion["fused_kernel_type"])
|
|
376
|
+
|
|
377
|
+
# Convert to FusionPattern objects
|
|
378
|
+
patterns: list[FusionPattern] = []
|
|
379
|
+
for mapping in all_mappings:
|
|
380
|
+
fused_kernel_names = mapping.get("fused_kernel_names", [])
|
|
381
|
+
fused_kernel = fused_kernel_names[0] if fused_kernel_names else mapping["fused_kernel_type"]
|
|
382
|
+
|
|
383
|
+
patterns.append(
|
|
384
|
+
FusionPattern(
|
|
385
|
+
layer=0, # Pattern-based analysis doesn't track layers
|
|
386
|
+
operation=mapping["fused_kernel_type"],
|
|
387
|
+
fused_platform=mapping["fused_platform"],
|
|
388
|
+
fused_kernel=fused_kernel,
|
|
389
|
+
unfused_kernels=mapping["unfused_sequence"],
|
|
390
|
+
count=mapping["pattern_count"],
|
|
391
|
+
evidence=mapping["evidence"],
|
|
293
392
|
)
|
|
393
|
+
)
|
|
294
394
|
|
|
295
|
-
amd_fuses = sum(1 for p in
|
|
296
|
-
nvidia_fuses = sum(1 for p in
|
|
395
|
+
amd_fuses = sum(1 for p in patterns if p.fused_platform == "AMD")
|
|
396
|
+
nvidia_fuses = sum(1 for p in patterns if p.fused_platform == "NVIDIA")
|
|
297
397
|
|
|
298
398
|
return FusionAnalysis(
|
|
299
|
-
patterns=
|
|
399
|
+
patterns=patterns,
|
|
300
400
|
summary={
|
|
301
401
|
"amd_fuses": amd_fuses,
|
|
302
402
|
"nvidia_fuses": nvidia_fuses,
|
|
303
|
-
"total_fusion_opportunities": len(
|
|
403
|
+
"total_fusion_opportunities": len(patterns),
|
|
304
404
|
},
|
|
305
405
|
)
|
|
306
406
|
|
|
307
407
|
|
|
308
408
|
def analyze_fusion_from_alignment(
|
|
309
|
-
layer_alignments: list[
|
|
409
|
+
layer_alignments: list[Any],
|
|
410
|
+
amd_kernels: list[dict] | None = None,
|
|
411
|
+
nvidia_kernels: list[dict] | None = None,
|
|
310
412
|
) -> dict[str, Any]:
|
|
311
|
-
"""Analyze fusion from
|
|
413
|
+
"""Analyze fusion from kernel data (for API compatibility).
|
|
312
414
|
|
|
313
415
|
Args:
|
|
314
|
-
layer_alignments: List of aligned layers
|
|
416
|
+
layer_alignments: List of aligned layers (unused - kept for API compatibility)
|
|
417
|
+
amd_kernels: Optional list of AMD kernel events for pattern-based analysis
|
|
418
|
+
nvidia_kernels: Optional list of NVIDIA kernel events for pattern-based analysis
|
|
315
419
|
|
|
316
420
|
Returns:
|
|
317
|
-
Dictionary with fusion analysis results
|
|
421
|
+
Dictionary with fusion analysis results
|
|
318
422
|
"""
|
|
319
|
-
|
|
423
|
+
# If raw kernels provided, use pattern-based analysis (preferred)
|
|
424
|
+
if amd_kernels is not None and nvidia_kernels is not None:
|
|
425
|
+
fusion_analysis = detect_fusion_patterns(amd_kernels, nvidia_kernels)
|
|
426
|
+
else:
|
|
427
|
+
# Fallback: empty analysis if no kernel data
|
|
428
|
+
fusion_analysis = FusionAnalysis(
|
|
429
|
+
patterns=[],
|
|
430
|
+
summary={"amd_fuses": 0, "nvidia_fuses": 0, "total_fusion_opportunities": 0},
|
|
431
|
+
)
|
|
320
432
|
|
|
321
433
|
fusion_opportunities = []
|
|
322
434
|
fusion_mappings = []
|
|
323
435
|
|
|
324
436
|
for pattern in fusion_analysis.patterns:
|
|
325
437
|
unfused_platform = "NVIDIA" if pattern.fused_platform == "AMD" else "AMD"
|
|
326
|
-
|
|
327
|
-
fusion_opportunities.append(
|
|
328
|
-
{
|
|
329
|
-
"kernel_type": pattern.operation,
|
|
330
|
-
"layer": pattern.layer,
|
|
331
|
-
"fused_by": pattern.fused_platform,
|
|
332
|
-
"fused_kernel": pattern.fused_kernel,
|
|
333
|
-
"unfused_kernels": pattern.unfused_kernels,
|
|
334
|
-
"count": pattern.count,
|
|
335
|
-
"evidence": pattern.evidence,
|
|
336
|
-
}
|
|
337
|
-
)
|
|
338
438
|
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
439
|
+
fusion_opportunities.append({
|
|
440
|
+
"kernel_type": pattern.operation,
|
|
441
|
+
"layer": pattern.layer,
|
|
442
|
+
"fused_by": pattern.fused_platform,
|
|
443
|
+
"fused_kernel": pattern.fused_kernel,
|
|
444
|
+
"unfused_kernels": pattern.unfused_kernels,
|
|
445
|
+
"count": pattern.count,
|
|
446
|
+
"evidence": pattern.evidence,
|
|
447
|
+
})
|
|
448
|
+
|
|
449
|
+
fusion_mappings.append({
|
|
450
|
+
"fused_platform": pattern.fused_platform,
|
|
451
|
+
"fused_kernel_type": pattern.operation,
|
|
452
|
+
"fused_kernel_name": pattern.fused_kernel,
|
|
453
|
+
"unfused_platform": unfused_platform,
|
|
454
|
+
"unfused_sequence": pattern.unfused_kernels,
|
|
455
|
+
"pattern_count": pattern.count,
|
|
456
|
+
"evidence": pattern.evidence,
|
|
457
|
+
"layer": pattern.layer,
|
|
458
|
+
})
|
|
351
459
|
|
|
352
460
|
return {
|
|
353
461
|
"fusion_opportunities": fusion_opportunities,
|