wafer-core 0.1.26__py3-none-any.whl → 0.1.28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer_core/lib/trace_compare/PERFORMANCE.md +148 -0
- wafer_core/lib/trace_compare/__init__.py +22 -9
- wafer_core/lib/trace_compare/aligner.py +376 -0
- wafer_core/lib/trace_compare/analyzer.py +558 -159
- wafer_core/lib/trace_compare/api.py +225 -0
- wafer_core/lib/trace_compare/architecture.py +77 -0
- wafer_core/lib/trace_compare/classifier.py +307 -13
- wafer_core/lib/trace_compare/fusion_analyzer.py +280 -706
- wafer_core/lib/trace_compare/kernel_registry.yaml +349 -0
- wafer_core/lib/trace_compare/layer_segmentation.py +114 -0
- wafer_core/lib/trace_compare/loader.py +526 -227
- wafer_core/lib/trace_compare/same_kernel_analyzer.py +119 -0
- wafer_core/lib/trace_compare/warnings.py +99 -0
- wafer_core/targets/__init__.py +47 -21
- wafer_core/targets/pool.py +181 -0
- wafer_core/targets/probe.py +113 -0
- wafer_core/targets/providers/__init__.py +46 -0
- wafer_core/targets/providers/baremetal.py +72 -0
- wafer_core/targets/providers/digitalocean.py +164 -0
- wafer_core/targets/providers/runpod.py +250 -0
- wafer_core/targets/reconcile.py +90 -0
- wafer_core/targets/spec_store.py +200 -0
- wafer_core/targets/state_cache.py +150 -0
- wafer_core/targets/types.py +141 -0
- wafer_core/utils/kernel_utils/targets/config.py +8 -24
- {wafer_core-0.1.26.dist-info → wafer_core-0.1.28.dist-info}/METADATA +3 -1
- {wafer_core-0.1.26.dist-info → wafer_core-0.1.28.dist-info}/RECORD +28 -10
- {wafer_core-0.1.26.dist-info → wafer_core-0.1.28.dist-info}/WHEEL +0 -0
|
@@ -1,144 +1,45 @@
|
|
|
1
1
|
"""Fusion analysis for detecting kernel fusion differences between platforms.
|
|
2
2
|
|
|
3
|
-
Detects fusion differences between AMD and NVIDIA by analyzing
|
|
4
|
-
each platform
|
|
3
|
+
Detects fusion differences between AMD and NVIDIA by analyzing:
|
|
4
|
+
1. Kernel types unique to each platform (one platform fuses them away)
|
|
5
|
+
2. Sequence patterns (what comes before/after unique kernels)
|
|
6
|
+
3. Count imbalances (one platform has significantly more calls)
|
|
7
|
+
|
|
8
|
+
This is the pattern-based approach which is more reliable than alignment-based detection.
|
|
5
9
|
"""
|
|
6
10
|
|
|
7
|
-
import json
|
|
8
11
|
from collections import Counter, defaultdict
|
|
9
|
-
from
|
|
12
|
+
from dataclasses import dataclass, field
|
|
10
13
|
from typing import Any
|
|
11
14
|
|
|
12
|
-
from .classifier import
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
def _load_trace_for_fusion(
|
|
16
|
-
file_path: str | Path,
|
|
17
|
-
) -> tuple[str, str, list[dict[str, Any]], dict[int, list[dict[str, Any]]]]:
|
|
18
|
-
"""Load trace and group kernels by correlation ID.
|
|
19
|
-
|
|
20
|
-
Args:
|
|
21
|
-
file_path: Path to trace file
|
|
22
|
-
|
|
23
|
-
Returns:
|
|
24
|
-
Tuple of (platform, gpu_name, all_kernels, corr_groups)
|
|
25
|
-
"""
|
|
26
|
-
with open(file_path, "rb") as f:
|
|
27
|
-
trace = json.load(f)
|
|
28
|
-
|
|
29
|
-
# Detect platform
|
|
30
|
-
props = trace.get("deviceProperties", [{}])[0]
|
|
31
|
-
is_amd = trace.get("roctracer_version") or props.get("warpSize") == 64
|
|
32
|
-
platform = "AMD" if is_amd else "NVIDIA"
|
|
33
|
-
gpu_name = props.get("name", "MI300X" if is_amd else "Unknown GPU")
|
|
34
|
-
|
|
35
|
-
# Get all kernel events
|
|
36
|
-
events = trace.get("traceEvents", [])
|
|
37
|
-
kernels = [e for e in events if e.get("cat") == "kernel"]
|
|
38
|
-
|
|
39
|
-
# Group by correlation ID
|
|
40
|
-
corr_groups: dict[int, list[dict[str, Any]]] = defaultdict(list)
|
|
41
|
-
for k in kernels:
|
|
42
|
-
corr_id = k.get("args", {}).get("correlation")
|
|
43
|
-
if corr_id is not None:
|
|
44
|
-
corr_groups[corr_id].append(k)
|
|
15
|
+
from .classifier import classify, Op
|
|
45
16
|
|
|
46
|
-
return platform, gpu_name, kernels, dict(corr_groups)
|
|
47
17
|
|
|
18
|
+
@dataclass
|
|
19
|
+
class FusionPattern:
|
|
20
|
+
"""A detected fusion pattern."""
|
|
48
21
|
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
22
|
+
layer: int
|
|
23
|
+
operation: str # The fused operation name (e.g., "RMSNorm+GEMM")
|
|
24
|
+
fused_platform: str # Platform that fuses (has fewer kernels)
|
|
25
|
+
fused_kernel: str # The actual fused kernel name
|
|
26
|
+
unfused_kernels: list[str] # List of separate kernels on the other platform
|
|
27
|
+
count: int
|
|
28
|
+
evidence: str
|
|
53
29
|
|
|
54
|
-
Args:
|
|
55
|
-
kernels: List of kernel events in the group
|
|
56
|
-
|
|
57
|
-
Returns:
|
|
58
|
-
Tuple of (counts, timings) where counts maps kernel types to counts
|
|
59
|
-
and timings maps kernel types to total duration in microseconds
|
|
60
|
-
"""
|
|
61
|
-
counts: Counter[str] = Counter()
|
|
62
|
-
timings: dict[str, float] = defaultdict(float)
|
|
63
|
-
|
|
64
|
-
for k in kernels:
|
|
65
|
-
kernel_type = classify_kernel(k.get("name", ""))
|
|
66
|
-
counts[kernel_type] += 1
|
|
67
|
-
timings[kernel_type] += k.get("dur", 0)
|
|
68
|
-
|
|
69
|
-
return dict(counts), dict(timings)
|
|
70
30
|
|
|
31
|
+
@dataclass
|
|
32
|
+
class FusionAnalysis:
|
|
33
|
+
"""Complete fusion analysis result."""
|
|
71
34
|
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
nv_groups: dict[int, list[dict[str, Any]]],
|
|
75
|
-
size_tolerance: float = 0.25,
|
|
76
|
-
) -> list[tuple[int, int]]:
|
|
77
|
-
"""Match AMD and NVIDIA correlation groups by size and composition.
|
|
35
|
+
patterns: list[FusionPattern] = field(default_factory=list)
|
|
36
|
+
summary: dict[str, Any] = field(default_factory=dict)
|
|
78
37
|
|
|
79
|
-
Since correlation IDs don't match between platforms, we match groups that
|
|
80
|
-
have similar sizes (within tolerance) and likely represent the same operation.
|
|
81
|
-
|
|
82
|
-
Args:
|
|
83
|
-
amd_groups: AMD correlation groups
|
|
84
|
-
nv_groups: NVIDIA correlation groups
|
|
85
|
-
size_tolerance: Groups match if sizes are within this fraction (25% default, increased for better matching)
|
|
86
|
-
|
|
87
|
-
Returns:
|
|
88
|
-
List of (amd_corr_id, nv_corr_id) pairs
|
|
89
|
-
"""
|
|
90
|
-
matches = []
|
|
91
|
-
used_nv_ids: set[int] = set()
|
|
92
38
|
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
# Build size-indexed lookup for NVIDIA groups for faster filtering
|
|
98
|
-
nv_by_size = [(id, len(kernels)) for id, kernels in nv_groups.items()]
|
|
99
|
-
|
|
100
|
-
# Sort AMD groups by size (largest first for better matching)
|
|
101
|
-
amd_sorted = sorted(amd_groups.items(), key=lambda x: len(x[1]), reverse=True)
|
|
102
|
-
|
|
103
|
-
for amd_id, amd_kernels in amd_sorted:
|
|
104
|
-
amd_size = len(amd_kernels)
|
|
105
|
-
amd_comp = amd_comps[amd_id]
|
|
106
|
-
|
|
107
|
-
# Calculate size bounds for matching
|
|
108
|
-
min_size = amd_size / (1 + size_tolerance)
|
|
109
|
-
max_size = amd_size * (1 + size_tolerance)
|
|
110
|
-
|
|
111
|
-
# Find best matching NVIDIA group
|
|
112
|
-
best_match = None
|
|
113
|
-
best_score = float("inf")
|
|
114
|
-
|
|
115
|
-
for nv_id, nv_size in nv_by_size:
|
|
116
|
-
if nv_id in used_nv_ids:
|
|
117
|
-
continue
|
|
118
|
-
|
|
119
|
-
# Quick size filter
|
|
120
|
-
if nv_size < min_size or nv_size > max_size:
|
|
121
|
-
continue
|
|
122
|
-
|
|
123
|
-
# Use pre-computed composition
|
|
124
|
-
nv_comp = nv_comps[nv_id]
|
|
125
|
-
|
|
126
|
-
# Simple similarity: number of shared kernel types
|
|
127
|
-
shared_types = set(amd_comp.keys()) & set(nv_comp.keys())
|
|
128
|
-
similarity = len(shared_types)
|
|
129
|
-
|
|
130
|
-
# Prefer matches with more shared types and closer sizes
|
|
131
|
-
score = abs(amd_size - nv_size) - (similarity * 10)
|
|
132
|
-
|
|
133
|
-
if score < best_score:
|
|
134
|
-
best_score = score
|
|
135
|
-
best_match = nv_id
|
|
136
|
-
|
|
137
|
-
if best_match is not None:
|
|
138
|
-
matches.append((amd_id, best_match))
|
|
139
|
-
used_nv_ids.add(best_match)
|
|
140
|
-
|
|
141
|
-
return matches
|
|
39
|
+
def _classify_kernel(name: str, platform: str = "AMD") -> str:
|
|
40
|
+
"""Classify a kernel name to its operation type."""
|
|
41
|
+
op, _pattern = classify(name, platform)
|
|
42
|
+
return op.value
|
|
142
43
|
|
|
143
44
|
|
|
144
45
|
def _find_fusion_mappings(
|
|
@@ -146,6 +47,7 @@ def _find_fusion_mappings(
|
|
|
146
47
|
trace2_kernels: list[dict],
|
|
147
48
|
trace1_name: str = "Trace1",
|
|
148
49
|
trace2_name: str = "Trace2",
|
|
50
|
+
trace1_platform: str = "AMD",
|
|
149
51
|
) -> list[dict]:
|
|
150
52
|
"""Find fusion mappings by analyzing kernel execution sequence patterns.
|
|
151
53
|
|
|
@@ -157,29 +59,21 @@ def _find_fusion_mappings(
|
|
|
157
59
|
trace2_kernels: List of kernel events from second trace
|
|
158
60
|
trace1_name: Name of first platform (e.g., "AMD")
|
|
159
61
|
trace2_name: Name of second platform (e.g., "NVIDIA")
|
|
62
|
+
trace1_platform: Platform string for classification ("AMD" or "NVIDIA")
|
|
160
63
|
|
|
161
64
|
Returns:
|
|
162
|
-
List of mapping dictionaries
|
|
163
|
-
- fused_platform: Which platform fuses the operations
|
|
164
|
-
- fused_kernel_type: The single fused kernel type
|
|
165
|
-
- unfused_platform: Which platform runs them separately
|
|
166
|
-
- unfused_sequence: List of kernel types run separately
|
|
167
|
-
- pattern_count: How many times this pattern appears
|
|
168
|
-
- pattern_confidence: Fraction of occurrences following this pattern
|
|
169
|
-
- evidence: Human-readable description
|
|
65
|
+
List of mapping dictionaries with fusion details
|
|
170
66
|
"""
|
|
171
|
-
from collections import defaultdict
|
|
172
|
-
from wafer_core.lib.trace_compare.classifier import classify_kernel
|
|
173
|
-
|
|
174
67
|
mappings = []
|
|
68
|
+
trace2_platform = "NVIDIA" if trace1_platform == "AMD" else "AMD"
|
|
175
69
|
|
|
176
70
|
# Sort kernels by timestamp
|
|
177
|
-
trace1_sorted = sorted(trace1_kernels, key=lambda k: k.get(
|
|
178
|
-
trace2_sorted = sorted(trace2_kernels, key=lambda k: k.get(
|
|
71
|
+
trace1_sorted = sorted(trace1_kernels, key=lambda k: k.get("ts", 0))
|
|
72
|
+
trace2_sorted = sorted(trace2_kernels, key=lambda k: k.get("ts", 0))
|
|
179
73
|
|
|
180
74
|
# Classify all kernels
|
|
181
|
-
trace1_types = [
|
|
182
|
-
trace2_types = [
|
|
75
|
+
trace1_types = [_classify_kernel(k.get("name", ""), trace1_platform) for k in trace1_sorted]
|
|
76
|
+
trace2_types = [_classify_kernel(k.get("name", ""), trace2_platform) for k in trace2_sorted]
|
|
183
77
|
|
|
184
78
|
# Find kernel types unique to each trace
|
|
185
79
|
trace1_type_set = set(trace1_types)
|
|
@@ -190,30 +84,34 @@ def _find_fusion_mappings(
|
|
|
190
84
|
|
|
191
85
|
# For each unique type in trace1, find common sequence patterns
|
|
192
86
|
for unique_type in trace1_only:
|
|
87
|
+
# Skip "Other" since it's too generic
|
|
88
|
+
if unique_type == "Other":
|
|
89
|
+
continue
|
|
90
|
+
|
|
193
91
|
# Find all occurrences of this type
|
|
194
92
|
indices = [i for i, t in enumerate(trace1_types) if t == unique_type]
|
|
195
93
|
|
|
196
94
|
if len(indices) < 5: # Need enough samples to be meaningful
|
|
197
95
|
continue
|
|
198
96
|
|
|
199
|
-
# Analyze what comes before
|
|
200
|
-
before_types = defaultdict(int)
|
|
97
|
+
# Analyze what comes before each occurrence
|
|
98
|
+
before_types: dict[str, int] = defaultdict(int)
|
|
201
99
|
|
|
202
100
|
for idx in indices:
|
|
203
101
|
if idx > 0:
|
|
204
102
|
before_types[trace1_types[idx - 1]] += 1
|
|
205
103
|
|
|
206
|
-
# Find the most common pattern
|
|
207
|
-
|
|
104
|
+
# Find the most common pattern
|
|
105
|
+
if not before_types:
|
|
106
|
+
continue
|
|
107
|
+
most_common_before = max(before_types.items(), key=lambda x: x[1])
|
|
208
108
|
|
|
209
109
|
# If there's a strong pattern (>80% of occurrences)
|
|
210
110
|
if most_common_before[1] / len(indices) > 0.8:
|
|
211
|
-
# This suggests: Trace2 likely fuses [before_type + unique_type] into [before_type]
|
|
212
111
|
fusion_candidate = most_common_before[0]
|
|
213
112
|
|
|
214
113
|
# Verify trace2 has this type
|
|
215
114
|
if fusion_candidate in trace2_type_set:
|
|
216
|
-
# Count occurrences to compare
|
|
217
115
|
trace1_fusion_count = trace1_types.count(fusion_candidate)
|
|
218
116
|
trace2_fusion_count = trace2_types.count(fusion_candidate)
|
|
219
117
|
|
|
@@ -225,15 +123,18 @@ def _find_fusion_mappings(
|
|
|
225
123
|
"unfused_sequence": [fusion_candidate, unique_type],
|
|
226
124
|
"unfused_count_per_type": {
|
|
227
125
|
fusion_candidate: trace1_fusion_count,
|
|
228
|
-
unique_type: len(indices)
|
|
126
|
+
unique_type: len(indices),
|
|
229
127
|
},
|
|
230
128
|
"pattern_count": len(indices),
|
|
231
129
|
"pattern_confidence": most_common_before[1] / len(indices),
|
|
232
|
-
"evidence": f"{trace1_name} runs {fusion_candidate}+{unique_type} separately, {trace2_name} fuses into {fusion_candidate}"
|
|
130
|
+
"evidence": f"{trace1_name} runs {fusion_candidate}+{unique_type} separately, {trace2_name} fuses into {fusion_candidate}",
|
|
233
131
|
})
|
|
234
132
|
|
|
235
133
|
# Also check trace2-only types
|
|
236
134
|
for unique_type in trace2_only:
|
|
135
|
+
if unique_type == "Other":
|
|
136
|
+
continue
|
|
137
|
+
|
|
237
138
|
indices = [i for i, t in enumerate(trace2_types) if t == unique_type]
|
|
238
139
|
|
|
239
140
|
if len(indices) < 5:
|
|
@@ -245,7 +146,9 @@ def _find_fusion_mappings(
|
|
|
245
146
|
if idx > 0:
|
|
246
147
|
before_types[trace2_types[idx - 1]] += 1
|
|
247
148
|
|
|
248
|
-
|
|
149
|
+
if not before_types:
|
|
150
|
+
continue
|
|
151
|
+
most_common_before = max(before_types.items(), key=lambda x: x[1])
|
|
249
152
|
|
|
250
153
|
if most_common_before[1] / len(indices) > 0.8:
|
|
251
154
|
fusion_candidate = most_common_before[0]
|
|
@@ -262,629 +165,300 @@ def _find_fusion_mappings(
|
|
|
262
165
|
"unfused_sequence": [fusion_candidate, unique_type],
|
|
263
166
|
"unfused_count_per_type": {
|
|
264
167
|
fusion_candidate: trace2_fusion_count,
|
|
265
|
-
unique_type: len(indices)
|
|
168
|
+
unique_type: len(indices),
|
|
266
169
|
},
|
|
267
170
|
"pattern_count": len(indices),
|
|
268
171
|
"pattern_confidence": most_common_before[1] / len(indices),
|
|
269
|
-
"evidence": f"{trace2_name} runs {fusion_candidate}+{unique_type} separately, {trace1_name} fuses into {fusion_candidate}"
|
|
270
|
-
})
|
|
271
|
-
|
|
272
|
-
# NEW: Check for partial fusion patterns (kernel exists on both platforms but with different counts)
|
|
273
|
-
# If one platform has significantly fewer calls (>1.3x ratio), look for fusion patterns
|
|
274
|
-
common_types = trace1_type_set & trace2_type_set
|
|
275
|
-
|
|
276
|
-
for ktype in common_types:
|
|
277
|
-
trace1_count = trace1_types.count(ktype)
|
|
278
|
-
trace2_count = trace2_types.count(ktype)
|
|
279
|
-
|
|
280
|
-
# Check if there's a significant imbalance (one has >1.3x more)
|
|
281
|
-
if trace1_count == 0 or trace2_count == 0:
|
|
282
|
-
continue
|
|
283
|
-
|
|
284
|
-
ratio = max(trace1_count, trace2_count) / min(trace1_count, trace2_count)
|
|
285
|
-
|
|
286
|
-
if ratio < 1.3 or trace1_count + trace2_count < 100:
|
|
287
|
-
continue
|
|
288
|
-
|
|
289
|
-
# Determine which platform has more (unfused) and which has fewer (fused)
|
|
290
|
-
if trace1_count > trace2_count:
|
|
291
|
-
# Trace1 has more separate calls, Trace2 likely fuses
|
|
292
|
-
unfused_platform = trace1_name
|
|
293
|
-
fused_platform = trace2_name
|
|
294
|
-
unfused_count = trace1_count
|
|
295
|
-
fused_count = trace2_count
|
|
296
|
-
|
|
297
|
-
# Find what Trace2 might be fusing this into
|
|
298
|
-
# Use the most common kernel type in Trace2 as a likely fusion target
|
|
299
|
-
trace2_type_counts = defaultdict(int)
|
|
300
|
-
for t in trace2_types:
|
|
301
|
-
if t != ktype and t != "Other": # Skip the imbalanced type itself and "Other"
|
|
302
|
-
trace2_type_counts[t] += 1
|
|
303
|
-
|
|
304
|
-
if trace2_type_counts:
|
|
305
|
-
# Use the most common type as the fusion target
|
|
306
|
-
fusion_target = max(trace2_type_counts.items(), key=lambda x: x[1])[0]
|
|
307
|
-
|
|
308
|
-
mappings.append({
|
|
309
|
-
"fused_platform": fused_platform,
|
|
310
|
-
"fused_kernel_type": fusion_target,
|
|
311
|
-
"fused_count": fused_count,
|
|
312
|
-
"unfused_platform": unfused_platform,
|
|
313
|
-
"unfused_sequence": [ktype],
|
|
314
|
-
"unfused_count_per_type": {
|
|
315
|
-
ktype: unfused_count
|
|
316
|
-
},
|
|
317
|
-
"pattern_count": unfused_count - fused_count,
|
|
318
|
-
"pattern_confidence": (unfused_count - fused_count) / unfused_count,
|
|
319
|
-
"evidence": f"{unfused_platform} calls {ktype} {ratio:.1f}x more ({unfused_count} vs {fused_count}), {fused_platform} likely fuses into {fusion_target}"
|
|
320
|
-
})
|
|
321
|
-
else:
|
|
322
|
-
# Trace2 has more separate calls, Trace1 likely fuses
|
|
323
|
-
unfused_platform = trace2_name
|
|
324
|
-
fused_platform = trace1_name
|
|
325
|
-
unfused_count = trace2_count
|
|
326
|
-
fused_count = trace1_count
|
|
327
|
-
|
|
328
|
-
# Find what Trace1 might be fusing this into
|
|
329
|
-
# Use the most common kernel type in Trace1 as a likely fusion target
|
|
330
|
-
trace1_type_counts = defaultdict(int)
|
|
331
|
-
for t in trace1_types:
|
|
332
|
-
if t != ktype and t != "Other": # Skip the imbalanced type itself and "Other"
|
|
333
|
-
trace1_type_counts[t] += 1
|
|
334
|
-
|
|
335
|
-
if trace1_type_counts:
|
|
336
|
-
fusion_target = max(trace1_type_counts.items(), key=lambda x: x[1])[0]
|
|
337
|
-
|
|
338
|
-
mappings.append({
|
|
339
|
-
"fused_platform": fused_platform,
|
|
340
|
-
"fused_kernel_type": fusion_target,
|
|
341
|
-
"fused_count": fused_count,
|
|
342
|
-
"unfused_platform": unfused_platform,
|
|
343
|
-
"unfused_sequence": [ktype],
|
|
344
|
-
"unfused_count_per_type": {
|
|
345
|
-
ktype: unfused_count
|
|
346
|
-
},
|
|
347
|
-
"pattern_count": unfused_count - fused_count,
|
|
348
|
-
"pattern_confidence": (unfused_count - fused_count) / unfused_count,
|
|
349
|
-
"evidence": f"{unfused_platform} calls {ktype} {ratio:.1f}x more ({unfused_count} vs {fused_count}), {fused_platform} likely fuses into {fusion_target}"
|
|
172
|
+
"evidence": f"{trace2_name} runs {fusion_candidate}+{unique_type} separately, {trace1_name} fuses into {fusion_candidate}",
|
|
350
173
|
})
|
|
351
174
|
|
|
352
175
|
return mappings
|
|
353
176
|
|
|
354
177
|
|
|
355
|
-
def
|
|
178
|
+
def _find_count_imbalance_fusions(
|
|
356
179
|
trace1_kernels: list[dict],
|
|
357
180
|
trace2_kernels: list[dict],
|
|
358
|
-
trace1_name: str,
|
|
359
|
-
trace2_name: str,
|
|
181
|
+
trace1_name: str = "Trace1",
|
|
182
|
+
trace2_name: str = "Trace2",
|
|
183
|
+
trace1_platform: str = "AMD",
|
|
360
184
|
) -> list[dict]:
|
|
361
|
-
"""
|
|
185
|
+
"""Find fusions by looking for significant count imbalances.
|
|
362
186
|
|
|
363
|
-
|
|
364
|
-
|
|
187
|
+
When one platform has significantly more kernel calls of a type (>1.5x),
|
|
188
|
+
it suggests the other platform fuses those operations.
|
|
365
189
|
"""
|
|
366
|
-
from wafer_core.lib.trace_compare.classifier import classify_kernel
|
|
367
|
-
|
|
368
|
-
def analyze_chains(kernels):
|
|
369
|
-
"""Find chains of consecutive same-type kernels"""
|
|
370
|
-
sorted_kernels = sorted(kernels, key=lambda k: k.get('ts', 0))
|
|
371
|
-
types = [classify_kernel(k['name']) for k in sorted_kernels]
|
|
372
|
-
|
|
373
|
-
chains = defaultdict(list)
|
|
374
|
-
i = 0
|
|
375
|
-
while i < len(types):
|
|
376
|
-
ktype = types[i]
|
|
377
|
-
count = 0
|
|
378
|
-
while i < len(types) and types[i] == ktype:
|
|
379
|
-
count += 1
|
|
380
|
-
i += 1
|
|
381
|
-
chains[ktype].append(count)
|
|
382
|
-
|
|
383
|
-
return chains
|
|
384
|
-
|
|
385
|
-
trace1_chains = analyze_chains(trace1_kernels)
|
|
386
|
-
trace2_chains = analyze_chains(trace2_kernels)
|
|
387
|
-
|
|
388
190
|
mappings = []
|
|
389
|
-
|
|
191
|
+
trace2_platform = "NVIDIA" if trace1_platform == "AMD" else "AMD"
|
|
390
192
|
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
193
|
+
# Classify all kernels
|
|
194
|
+
trace1_types = [_classify_kernel(k.get("name", ""), trace1_platform) for k in trace1_kernels]
|
|
195
|
+
trace2_types = [_classify_kernel(k.get("name", ""), trace2_platform) for k in trace2_kernels]
|
|
394
196
|
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
197
|
+
# Count by type
|
|
198
|
+
trace1_counts = Counter(trace1_types)
|
|
199
|
+
trace2_counts = Counter(trace2_types)
|
|
398
200
|
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
t2_multi = [l for l in t2_lengths if l > 1]
|
|
201
|
+
# Find common types with significant differences
|
|
202
|
+
common_types = set(trace1_counts.keys()) & set(trace2_counts.keys())
|
|
402
203
|
|
|
403
|
-
|
|
204
|
+
for ktype in common_types:
|
|
205
|
+
if ktype == "Other":
|
|
404
206
|
continue
|
|
405
207
|
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
t1_chains = len(t1_multi) if t1_multi else len(t1_lengths)
|
|
409
|
-
t2_chains = len(t2_multi) if t2_multi else len(t2_lengths)
|
|
208
|
+
trace1_count = trace1_counts[ktype]
|
|
209
|
+
trace2_count = trace2_counts[ktype]
|
|
410
210
|
|
|
411
|
-
|
|
211
|
+
# Skip trivial counts
|
|
212
|
+
if trace1_count + trace2_count < 50:
|
|
412
213
|
continue
|
|
413
214
|
|
|
414
|
-
|
|
415
|
-
t2_avg_chain = sum(t2_multi) / len(t2_multi) if t2_multi else 1.0
|
|
416
|
-
|
|
417
|
-
chain_ratio = max(t1_avg_chain, t2_avg_chain) / min(t1_avg_chain, t2_avg_chain)
|
|
418
|
-
|
|
419
|
-
# Significant intra-fusion if chains are 2x+ different
|
|
420
|
-
if chain_ratio > 2.0 and abs(t1_total - t2_total) > 100:
|
|
421
|
-
if t1_avg_chain > t2_avg_chain:
|
|
422
|
-
unfused_platform = trace1_name
|
|
423
|
-
fused_platform = trace2_name
|
|
424
|
-
unfused_chains = t1_chains
|
|
425
|
-
fused_chains = t2_chains
|
|
426
|
-
unfused_avg = t1_avg_chain
|
|
427
|
-
fused_avg = t2_avg_chain
|
|
428
|
-
unfused_total = t1_total
|
|
429
|
-
fused_total = t2_total
|
|
430
|
-
else:
|
|
431
|
-
unfused_platform = trace2_name
|
|
432
|
-
fused_platform = trace1_name
|
|
433
|
-
unfused_chains = t2_chains
|
|
434
|
-
fused_chains = t1_chains
|
|
435
|
-
unfused_avg = t2_avg_chain
|
|
436
|
-
fused_avg = t1_avg_chain
|
|
437
|
-
unfused_total = t2_total
|
|
438
|
-
fused_total = t1_total
|
|
439
|
-
|
|
440
|
-
mappings.append({
|
|
441
|
-
"fused_platform": fused_platform,
|
|
442
|
-
"fused_kernel_type": ktype,
|
|
443
|
-
"fused_count": fused_total,
|
|
444
|
-
"unfused_platform": unfused_platform,
|
|
445
|
-
"unfused_sequence": [ktype, ktype], # Same type repeated
|
|
446
|
-
"unfused_count_per_type": {ktype: unfused_total},
|
|
447
|
-
"pattern_count": unfused_total - fused_total,
|
|
448
|
-
"pattern_confidence": min(unfused_chains, fused_chains) / max(unfused_chains, fused_chains),
|
|
449
|
-
"evidence": f"{unfused_platform} runs {ktype} in chains of {unfused_avg:.0f} calls ({unfused_chains} chains, {unfused_total:,} total), {fused_platform} fuses to {fused_avg:.0f} calls ({fused_chains} chains, {fused_total:,} total) - {chain_ratio:.1f}x more efficient"
|
|
450
|
-
})
|
|
451
|
-
|
|
452
|
-
return mappings
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
def _find_partial_fusion_via_groups(
|
|
456
|
-
trace1_large: dict[int, list[dict]],
|
|
457
|
-
trace2_large: dict[int, list[dict]],
|
|
458
|
-
matches: list[tuple[int, int]],
|
|
459
|
-
trace1_name: str,
|
|
460
|
-
trace2_name: str,
|
|
461
|
-
) -> list[dict]:
|
|
462
|
-
"""Find partial fusion patterns by analyzing correlation group differences.
|
|
463
|
-
|
|
464
|
-
When one platform has fewer of a kernel type, check what kernel types the
|
|
465
|
-
other platform has MORE of in those same groups - those are likely fusion targets.
|
|
466
|
-
"""
|
|
467
|
-
from collections import Counter
|
|
468
|
-
from wafer_core.lib.trace_compare.classifier import classify_kernel
|
|
469
|
-
|
|
470
|
-
mappings = []
|
|
471
|
-
|
|
472
|
-
# For each matched pair, track kernel type counts
|
|
473
|
-
trace1_all_types = []
|
|
474
|
-
trace2_all_types = []
|
|
475
|
-
|
|
476
|
-
for trace1_cid, trace2_cid in matches:
|
|
477
|
-
trace1_ktypes = [classify_kernel(k.get("name", "")) for k in trace1_large[trace1_cid]]
|
|
478
|
-
trace2_ktypes = [classify_kernel(k.get("name", "")) for k in trace2_large[trace2_cid]]
|
|
479
|
-
trace1_all_types.extend(trace1_ktypes)
|
|
480
|
-
trace2_all_types.extend(trace2_ktypes)
|
|
481
|
-
|
|
482
|
-
# Find kernel types with significant imbalances
|
|
483
|
-
trace1_counts = Counter(trace1_all_types)
|
|
484
|
-
trace2_counts = Counter(trace2_all_types)
|
|
485
|
-
all_types = set(trace1_counts.keys()) | set(trace2_counts.keys())
|
|
486
|
-
|
|
487
|
-
for ktype in all_types:
|
|
488
|
-
trace1_count = trace1_counts.get(ktype, 0)
|
|
489
|
-
trace2_count = trace2_counts.get(ktype, 0)
|
|
490
|
-
|
|
215
|
+
# Check if there's a significant imbalance (>1.5x)
|
|
491
216
|
if trace1_count == 0 or trace2_count == 0:
|
|
492
|
-
continue
|
|
217
|
+
continue
|
|
493
218
|
|
|
494
219
|
ratio = max(trace1_count, trace2_count) / min(trace1_count, trace2_count)
|
|
495
220
|
|
|
496
|
-
if ratio < 1.
|
|
497
|
-
continue
|
|
221
|
+
if ratio < 1.5:
|
|
222
|
+
continue
|
|
498
223
|
|
|
499
|
-
# Determine which platform has fewer (
|
|
224
|
+
# Determine which platform has more (unfused) and which has fewer (fused)
|
|
500
225
|
if trace1_count > trace2_count:
|
|
501
226
|
unfused_platform = trace1_name
|
|
502
227
|
fused_platform = trace2_name
|
|
503
228
|
unfused_count = trace1_count
|
|
504
229
|
fused_count = trace2_count
|
|
505
|
-
|
|
506
|
-
# Find groups where trace1 has this kernel but trace2 doesn't
|
|
507
|
-
fusion_targets = Counter()
|
|
508
|
-
groups_analyzed = 0
|
|
509
|
-
|
|
510
|
-
for trace1_cid, trace2_cid in matches:
|
|
511
|
-
trace1_ktypes = [classify_kernel(k.get("name", "")) for k in trace1_large[trace1_cid]]
|
|
512
|
-
trace2_ktypes = [classify_kernel(k.get("name", "")) for k in trace2_large[trace2_cid]]
|
|
513
|
-
|
|
514
|
-
trace1_has = ktype in trace1_ktypes
|
|
515
|
-
trace2_has = ktype in trace2_ktypes
|
|
516
|
-
|
|
517
|
-
if trace1_has and not trace2_has:
|
|
518
|
-
# What does trace2 have MORE of in this group?
|
|
519
|
-
trace1_kcounts = Counter(trace1_ktypes)
|
|
520
|
-
trace2_kcounts = Counter(trace2_ktypes)
|
|
521
|
-
|
|
522
|
-
for other_type in set(trace2_kcounts.keys()):
|
|
523
|
-
if other_type == ktype or other_type == "Other":
|
|
524
|
-
continue
|
|
525
|
-
diff = trace2_kcounts[other_type] - trace1_kcounts.get(other_type, 0)
|
|
526
|
-
if diff > 0:
|
|
527
|
-
fusion_targets[other_type] += diff
|
|
528
|
-
|
|
529
|
-
groups_analyzed += 1
|
|
530
|
-
|
|
531
|
-
if fusion_targets and groups_analyzed >= 5:
|
|
532
|
-
# Report top fusion targets
|
|
533
|
-
top_targets = fusion_targets.most_common(3)
|
|
534
|
-
target_str = ", ".join(f"{t[0]} (+{t[1]})" for t in top_targets)
|
|
535
|
-
|
|
536
|
-
mappings.append({
|
|
537
|
-
"fused_platform": fused_platform,
|
|
538
|
-
"fused_kernel_type": top_targets[0][0],
|
|
539
|
-
"fused_count": fused_count,
|
|
540
|
-
"unfused_platform": unfused_platform,
|
|
541
|
-
"unfused_sequence": [ktype],
|
|
542
|
-
"unfused_count_per_type": {ktype: unfused_count},
|
|
543
|
-
"pattern_count": unfused_count - fused_count,
|
|
544
|
-
"pattern_confidence": groups_analyzed / len(matches) if matches else 0,
|
|
545
|
-
"evidence": f"{unfused_platform} calls {ktype} {ratio:.1f}x more ({unfused_count} vs {fused_count}). In {groups_analyzed} groups where {unfused_platform} has {ktype}, {fused_platform} has more: {target_str}"
|
|
546
|
-
})
|
|
547
230
|
else:
|
|
548
|
-
# Symmetric case for trace2 > trace1
|
|
549
231
|
unfused_platform = trace2_name
|
|
550
232
|
fused_platform = trace1_name
|
|
551
233
|
unfused_count = trace2_count
|
|
552
234
|
fused_count = trace1_count
|
|
553
235
|
|
|
554
|
-
|
|
555
|
-
|
|
236
|
+
mappings.append({
|
|
237
|
+
"fused_platform": fused_platform,
|
|
238
|
+
"fused_kernel_type": ktype,
|
|
239
|
+
"fused_count": fused_count,
|
|
240
|
+
"unfused_platform": unfused_platform,
|
|
241
|
+
"unfused_sequence": [ktype],
|
|
242
|
+
"unfused_count_per_type": {ktype: unfused_count},
|
|
243
|
+
"pattern_count": unfused_count - fused_count,
|
|
244
|
+
"pattern_confidence": (unfused_count - fused_count) / unfused_count,
|
|
245
|
+
"evidence": f"{unfused_platform} calls {ktype} {ratio:.1f}x more ({unfused_count} vs {fused_count}), {fused_platform} likely fuses",
|
|
246
|
+
})
|
|
556
247
|
|
|
557
|
-
|
|
558
|
-
trace1_ktypes = [classify_kernel(k.get("name", "")) for k in trace1_large[trace1_cid]]
|
|
559
|
-
trace2_ktypes = [classify_kernel(k.get("name", "")) for k in trace2_large[trace2_cid]]
|
|
248
|
+
return mappings
|
|
560
249
|
|
|
561
|
-
trace1_has = ktype in trace1_ktypes
|
|
562
|
-
trace2_has = ktype in trace2_ktypes
|
|
563
250
|
|
|
564
|
-
|
|
565
|
-
|
|
566
|
-
|
|
251
|
+
def _find_explicit_fused_operations(
|
|
252
|
+
trace1_kernels: list[dict],
|
|
253
|
+
trace2_kernels: list[dict],
|
|
254
|
+
trace1_name: str = "Trace1",
|
|
255
|
+
trace2_name: str = "Trace2",
|
|
256
|
+
trace1_platform: str = "AMD",
|
|
257
|
+
) -> list[dict]:
|
|
258
|
+
"""Find explicit fused operations (kernels with '+' in their classification).
|
|
567
259
|
|
|
568
|
-
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
|
|
260
|
+
These are operations like 'RMSNorm+GEMM' that are explicitly classified as fused.
|
|
261
|
+
"""
|
|
262
|
+
mappings = []
|
|
263
|
+
trace2_platform = "NVIDIA" if trace1_platform == "AMD" else "AMD"
|
|
264
|
+
|
|
265
|
+
def get_fused_ops(kernels: list[dict], platform: str) -> dict[str, list[str]]:
|
|
266
|
+
"""Get fused operations and their kernel names."""
|
|
267
|
+
fused_ops: dict[str, list[str]] = defaultdict(list)
|
|
268
|
+
for k in kernels:
|
|
269
|
+
name = k.get("name", "")
|
|
270
|
+
op, _pattern = classify(name, platform)
|
|
271
|
+
if "+" in op.value or op == Op.FUSED_UNKNOWN:
|
|
272
|
+
fused_ops[op.value].append(name)
|
|
273
|
+
return dict(fused_ops)
|
|
274
|
+
|
|
275
|
+
trace1_fused = get_fused_ops(trace1_kernels, trace1_platform)
|
|
276
|
+
trace2_fused = get_fused_ops(trace2_kernels, trace2_platform)
|
|
277
|
+
|
|
278
|
+
# Find fused operations unique to trace1
|
|
279
|
+
for fused_op, kernels in trace1_fused.items():
|
|
280
|
+
if fused_op not in trace2_fused:
|
|
281
|
+
# Parse components from the fused op name
|
|
282
|
+
if "+" in fused_op:
|
|
283
|
+
components = [c.strip() for c in fused_op.split("+")]
|
|
284
|
+
else:
|
|
285
|
+
components = [fused_op]
|
|
574
286
|
|
|
575
|
-
|
|
287
|
+
mappings.append({
|
|
288
|
+
"fused_platform": trace1_name,
|
|
289
|
+
"fused_kernel_type": fused_op,
|
|
290
|
+
"fused_count": len(kernels),
|
|
291
|
+
"unfused_platform": trace2_name,
|
|
292
|
+
"unfused_sequence": components,
|
|
293
|
+
"unfused_count_per_type": {c: 0 for c in components}, # Unknown
|
|
294
|
+
"pattern_count": len(kernels),
|
|
295
|
+
"pattern_confidence": 1.0,
|
|
296
|
+
"evidence": f"{trace1_name} fuses {' + '.join(components)} into {fused_op} ({len(kernels)} calls)",
|
|
297
|
+
"fused_kernel_names": kernels[:3], # Sample of kernel names
|
|
298
|
+
})
|
|
576
299
|
|
|
577
|
-
|
|
578
|
-
|
|
579
|
-
|
|
300
|
+
# Find fused operations unique to trace2
|
|
301
|
+
for fused_op, kernels in trace2_fused.items():
|
|
302
|
+
if fused_op not in trace1_fused:
|
|
303
|
+
if "+" in fused_op:
|
|
304
|
+
components = [c.strip() for c in fused_op.split("+")]
|
|
305
|
+
else:
|
|
306
|
+
components = [fused_op]
|
|
580
307
|
|
|
581
|
-
|
|
582
|
-
|
|
583
|
-
|
|
584
|
-
|
|
585
|
-
|
|
586
|
-
|
|
587
|
-
|
|
588
|
-
|
|
589
|
-
|
|
590
|
-
|
|
591
|
-
|
|
308
|
+
mappings.append({
|
|
309
|
+
"fused_platform": trace2_name,
|
|
310
|
+
"fused_kernel_type": fused_op,
|
|
311
|
+
"fused_count": len(kernels),
|
|
312
|
+
"unfused_platform": trace1_name,
|
|
313
|
+
"unfused_sequence": components,
|
|
314
|
+
"unfused_count_per_type": {c: 0 for c in components},
|
|
315
|
+
"pattern_count": len(kernels),
|
|
316
|
+
"pattern_confidence": 1.0,
|
|
317
|
+
"evidence": f"{trace2_name} fuses {' + '.join(components)} into {fused_op} ({len(kernels)} calls)",
|
|
318
|
+
"fused_kernel_names": kernels[:3],
|
|
319
|
+
})
|
|
592
320
|
|
|
593
321
|
return mappings
|
|
594
322
|
|
|
595
323
|
|
|
596
|
-
def
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
|
|
600
|
-
|
|
601
|
-
|
|
324
|
+
def detect_fusion_patterns(
|
|
325
|
+
amd_kernels: list[dict],
|
|
326
|
+
nvidia_kernels: list[dict],
|
|
327
|
+
) -> FusionAnalysis:
|
|
328
|
+
"""Detect fusion patterns using pattern-based analysis.
|
|
329
|
+
|
|
330
|
+
This is the main entry point for fusion detection. It combines:
|
|
331
|
+
1. Explicit fused operations (kernels classified with '+' in name)
|
|
332
|
+
2. Sequence pattern analysis (unique kernel types with consistent patterns)
|
|
333
|
+
3. Count imbalance analysis (one platform has significantly more calls)
|
|
602
334
|
|
|
603
335
|
Args:
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
min_group_size: Only analyze correlation groups with at least this many kernels
|
|
336
|
+
amd_kernels: List of AMD kernel events
|
|
337
|
+
nvidia_kernels: List of NVIDIA kernel events
|
|
607
338
|
|
|
608
339
|
Returns:
|
|
609
|
-
|
|
610
|
-
- metadata: trace info
|
|
611
|
-
- global_counts: kernel type distribution across entire trace
|
|
612
|
-
- fusion_opportunities: significant fusion differences
|
|
613
|
-
- fusion_mappings: actual kernel-to-kernel mappings (NEW)
|
|
340
|
+
FusionAnalysis with detected patterns
|
|
614
341
|
"""
|
|
615
|
-
|
|
616
|
-
trace1_platform, trace1_gpu, trace1_kernels, trace1_corr_groups = _load_trace_for_fusion(
|
|
617
|
-
amd_trace_path
|
|
618
|
-
)
|
|
619
|
-
trace2_platform, trace2_gpu, trace2_kernels, trace2_corr_groups = _load_trace_for_fusion(
|
|
620
|
-
nv_trace_path
|
|
621
|
-
)
|
|
342
|
+
all_mappings: list[dict] = []
|
|
622
343
|
|
|
623
|
-
#
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
}
|
|
629
|
-
trace2_large = {
|
|
630
|
-
cid: kernels
|
|
631
|
-
for cid, kernels in trace2_corr_groups.items()
|
|
632
|
-
if len(kernels) >= min_group_size
|
|
633
|
-
}
|
|
634
|
-
|
|
635
|
-
# Match correlation groups between platforms
|
|
636
|
-
matches = _match_correlation_groups(trace1_large, trace2_large)
|
|
637
|
-
|
|
638
|
-
# Analyze differences in matched groups
|
|
639
|
-
fusion_diffs: dict[str, dict[str, Any]] = defaultdict(
|
|
640
|
-
lambda: {
|
|
641
|
-
"trace1_count": 0,
|
|
642
|
-
"trace2_count": 0,
|
|
643
|
-
"trace1_time_us": 0,
|
|
644
|
-
"trace2_time_us": 0,
|
|
645
|
-
"groups_with_diff": 0,
|
|
646
|
-
"total_groups": 0,
|
|
647
|
-
}
|
|
344
|
+
# 1. Find explicit fused operations (highest confidence)
|
|
345
|
+
explicit_fusions = _find_explicit_fused_operations(
|
|
346
|
+
amd_kernels, nvidia_kernels,
|
|
347
|
+
trace1_name="AMD", trace2_name="NVIDIA",
|
|
348
|
+
trace1_platform="AMD",
|
|
648
349
|
)
|
|
350
|
+
all_mappings.extend(explicit_fusions)
|
|
649
351
|
|
|
650
|
-
#
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
|
|
655
|
-
trace2_comp, trace2_times = _analyze_correlation_group(trace2_large[trace2_cid])
|
|
656
|
-
|
|
657
|
-
# Find all kernel types in either platform
|
|
658
|
-
all_types = set(trace1_comp.keys()) | set(trace2_comp.keys())
|
|
659
|
-
|
|
660
|
-
for ktype in all_types:
|
|
661
|
-
trace1_count = trace1_comp.get(ktype, 0)
|
|
662
|
-
trace2_count = trace2_comp.get(ktype, 0)
|
|
663
|
-
trace1_time = trace1_times.get(ktype, 0)
|
|
664
|
-
trace2_time = trace2_times.get(ktype, 0)
|
|
665
|
-
|
|
666
|
-
fusion_diffs[ktype]["trace1_count"] += trace1_count
|
|
667
|
-
fusion_diffs[ktype]["trace2_count"] += trace2_count
|
|
668
|
-
fusion_diffs[ktype]["trace1_time_us"] += trace1_time
|
|
669
|
-
fusion_diffs[ktype]["trace2_time_us"] += trace2_time
|
|
670
|
-
fusion_diffs[ktype]["total_groups"] += 1
|
|
671
|
-
|
|
672
|
-
if trace1_count != trace2_count:
|
|
673
|
-
fusion_diffs[ktype]["groups_with_diff"] += 1
|
|
674
|
-
|
|
675
|
-
# NEW: Find actual kernel mappings in this correlation group
|
|
676
|
-
group_mappings = _find_fusion_mappings(
|
|
677
|
-
trace1_large[trace1_cid],
|
|
678
|
-
trace2_large[trace2_cid],
|
|
679
|
-
trace1_name=trace1_platform,
|
|
680
|
-
trace2_name=trace2_platform
|
|
681
|
-
)
|
|
682
|
-
# Add correlation ID context to each mapping
|
|
683
|
-
for mapping in group_mappings:
|
|
684
|
-
mapping["correlation_group_trace1"] = trace1_cid
|
|
685
|
-
mapping["correlation_group_trace2"] = trace2_cid
|
|
686
|
-
all_fusion_mappings.extend(group_mappings)
|
|
687
|
-
|
|
688
|
-
# Also get global counts for context
|
|
689
|
-
global_trace1_counts: Counter[str] = Counter(
|
|
690
|
-
[classify_kernel(k.get("name", "")) for k in trace1_kernels]
|
|
352
|
+
# 2. Find sequence-based fusions
|
|
353
|
+
sequence_fusions = _find_fusion_mappings(
|
|
354
|
+
amd_kernels, nvidia_kernels,
|
|
355
|
+
trace1_name="AMD", trace2_name="NVIDIA",
|
|
356
|
+
trace1_platform="AMD",
|
|
691
357
|
)
|
|
692
|
-
|
|
693
|
-
|
|
358
|
+
# Deduplicate: skip if same fused_kernel_type already found
|
|
359
|
+
existing_types = {m["fused_kernel_type"] for m in all_mappings}
|
|
360
|
+
for fusion in sequence_fusions:
|
|
361
|
+
if fusion["fused_kernel_type"] not in existing_types:
|
|
362
|
+
all_mappings.append(fusion)
|
|
363
|
+
existing_types.add(fusion["fused_kernel_type"])
|
|
364
|
+
|
|
365
|
+
# 3. Find count-imbalance fusions
|
|
366
|
+
count_fusions = _find_count_imbalance_fusions(
|
|
367
|
+
amd_kernels, nvidia_kernels,
|
|
368
|
+
trace1_name="AMD", trace2_name="NVIDIA",
|
|
369
|
+
trace1_platform="AMD",
|
|
694
370
|
)
|
|
695
|
-
|
|
696
|
-
|
|
697
|
-
|
|
698
|
-
|
|
699
|
-
"
|
|
700
|
-
|
|
701
|
-
|
|
702
|
-
|
|
703
|
-
|
|
704
|
-
|
|
705
|
-
|
|
706
|
-
|
|
707
|
-
|
|
708
|
-
|
|
709
|
-
|
|
710
|
-
|
|
711
|
-
|
|
712
|
-
|
|
713
|
-
|
|
714
|
-
|
|
715
|
-
|
|
716
|
-
trace2_total = global_trace2_counts.get(ktype, 0)
|
|
717
|
-
|
|
718
|
-
if trace1_total > 0 or trace2_total > 0:
|
|
719
|
-
results["global_counts"][ktype] = {
|
|
720
|
-
"trace1_count": trace1_total,
|
|
721
|
-
"trace2_count": trace2_total,
|
|
722
|
-
"ratio": trace1_total / trace2_total if trace2_total > 0 else float("inf"),
|
|
723
|
-
}
|
|
724
|
-
|
|
725
|
-
# Fusion opportunities from matched correlation groups
|
|
726
|
-
for ktype, stats in fusion_diffs.items():
|
|
727
|
-
trace1_avg = (
|
|
728
|
-
stats["trace1_count"] / stats["total_groups"]
|
|
729
|
-
if stats["total_groups"] > 0
|
|
730
|
-
else 0
|
|
731
|
-
)
|
|
732
|
-
trace2_avg = (
|
|
733
|
-
stats["trace2_count"] / stats["total_groups"]
|
|
734
|
-
if stats["total_groups"] > 0
|
|
735
|
-
else 0
|
|
736
|
-
)
|
|
737
|
-
trace1_time_ms = stats["trace1_time_us"] / 1000
|
|
738
|
-
trace2_time_ms = stats["trace2_time_us"] / 1000
|
|
739
|
-
|
|
740
|
-
# Calculate significance
|
|
741
|
-
diff_ratio = trace1_avg / trace2_avg if trace2_avg > 0 else float("inf")
|
|
742
|
-
reverse_ratio = trace2_avg / trace1_avg if trace1_avg > 0 else float("inf")
|
|
743
|
-
|
|
744
|
-
# Only report if there's a significant difference
|
|
745
|
-
# Either: one platform has it and the other doesn't (ratio > 10)
|
|
746
|
-
# Or: one platform has significantly more (ratio > 2.0)
|
|
747
|
-
is_significant = (
|
|
748
|
-
(diff_ratio > 10.0 or reverse_ratio > 10.0)
|
|
749
|
-
or ( # One platform doesn't have it
|
|
750
|
-
(diff_ratio > 2.0 or reverse_ratio > 2.0)
|
|
751
|
-
and stats["trace1_count"] + stats["trace2_count"] > 20 # Significant difference # Not trivial counts
|
|
371
|
+
# Deduplicate
|
|
372
|
+
for fusion in count_fusions:
|
|
373
|
+
if fusion["fused_kernel_type"] not in existing_types:
|
|
374
|
+
all_mappings.append(fusion)
|
|
375
|
+
existing_types.add(fusion["fused_kernel_type"])
|
|
376
|
+
|
|
377
|
+
# Convert to FusionPattern objects
|
|
378
|
+
patterns: list[FusionPattern] = []
|
|
379
|
+
for mapping in all_mappings:
|
|
380
|
+
fused_kernel_names = mapping.get("fused_kernel_names", [])
|
|
381
|
+
fused_kernel = fused_kernel_names[0] if fused_kernel_names else mapping["fused_kernel_type"]
|
|
382
|
+
|
|
383
|
+
patterns.append(
|
|
384
|
+
FusionPattern(
|
|
385
|
+
layer=0, # Pattern-based analysis doesn't track layers
|
|
386
|
+
operation=mapping["fused_kernel_type"],
|
|
387
|
+
fused_platform=mapping["fused_platform"],
|
|
388
|
+
fused_kernel=fused_kernel,
|
|
389
|
+
unfused_kernels=mapping["unfused_sequence"],
|
|
390
|
+
count=mapping["pattern_count"],
|
|
391
|
+
evidence=mapping["evidence"],
|
|
752
392
|
)
|
|
753
393
|
)
|
|
754
394
|
|
|
755
|
-
|
|
756
|
-
|
|
757
|
-
if diff_ratio > 1.5:
|
|
758
|
-
fused_by = trace2_platform # Trace 2 has fewer calls, so it fuses more
|
|
759
|
-
ratio = diff_ratio
|
|
760
|
-
else:
|
|
761
|
-
fused_by = trace1_platform # Trace 1 has fewer calls, so it fuses more
|
|
762
|
-
ratio = reverse_ratio
|
|
763
|
-
|
|
764
|
-
# Calculate time ratio (who's faster for this operation)
|
|
765
|
-
time_ratio = trace1_time_ms / trace2_time_ms if trace2_time_ms > 0 else float("inf")
|
|
766
|
-
|
|
767
|
-
results["fusion_opportunities"].append(
|
|
768
|
-
{
|
|
769
|
-
"kernel_type": ktype,
|
|
770
|
-
"trace1_total": stats["trace1_count"],
|
|
771
|
-
"trace2_total": stats["trace2_count"],
|
|
772
|
-
"trace1_avg_per_group": trace1_avg,
|
|
773
|
-
"trace2_avg_per_group": trace2_avg,
|
|
774
|
-
"trace1_time_ms": trace1_time_ms,
|
|
775
|
-
"trace2_time_ms": trace2_time_ms,
|
|
776
|
-
"time_ratio": time_ratio,
|
|
777
|
-
"ratio": ratio,
|
|
778
|
-
"fused_by": fused_by,
|
|
779
|
-
"groups_affected": stats["groups_with_diff"],
|
|
780
|
-
"total_groups": stats["total_groups"],
|
|
781
|
-
}
|
|
782
|
-
)
|
|
395
|
+
amd_fuses = sum(1 for p in patterns if p.fused_platform == "AMD")
|
|
396
|
+
nvidia_fuses = sum(1 for p in patterns if p.fused_platform == "NVIDIA")
|
|
783
397
|
|
|
784
|
-
|
|
785
|
-
|
|
786
|
-
|
|
787
|
-
|
|
788
|
-
|
|
789
|
-
|
|
398
|
+
return FusionAnalysis(
|
|
399
|
+
patterns=patterns,
|
|
400
|
+
summary={
|
|
401
|
+
"amd_fuses": amd_fuses,
|
|
402
|
+
"nvidia_fuses": nvidia_fuses,
|
|
403
|
+
"total_fusion_opportunities": len(patterns),
|
|
404
|
+
},
|
|
405
|
+
)
|
|
790
406
|
|
|
791
|
-
trace1_total = global_trace1_counts.get(ktype, 0)
|
|
792
|
-
trace2_total = global_trace2_counts.get(ktype, 0)
|
|
793
407
|
|
|
794
|
-
|
|
795
|
-
|
|
796
|
-
|
|
408
|
+
def analyze_fusion_from_alignment(
|
|
409
|
+
layer_alignments: list[Any],
|
|
410
|
+
amd_kernels: list[dict] | None = None,
|
|
411
|
+
nvidia_kernels: list[dict] | None = None,
|
|
412
|
+
) -> dict[str, Any]:
|
|
413
|
+
"""Analyze fusion from kernel data (for API compatibility).
|
|
797
414
|
|
|
798
|
-
|
|
799
|
-
|
|
800
|
-
|
|
415
|
+
Args:
|
|
416
|
+
layer_alignments: List of aligned layers (unused - kept for API compatibility)
|
|
417
|
+
amd_kernels: Optional list of AMD kernel events for pattern-based analysis
|
|
418
|
+
nvidia_kernels: Optional list of NVIDIA kernel events for pattern-based analysis
|
|
801
419
|
|
|
802
|
-
|
|
803
|
-
|
|
804
|
-
|
|
805
|
-
|
|
420
|
+
Returns:
|
|
421
|
+
Dictionary with fusion analysis results
|
|
422
|
+
"""
|
|
423
|
+
# If raw kernels provided, use pattern-based analysis (preferred)
|
|
424
|
+
if amd_kernels is not None and nvidia_kernels is not None:
|
|
425
|
+
fusion_analysis = detect_fusion_patterns(amd_kernels, nvidia_kernels)
|
|
426
|
+
else:
|
|
427
|
+
# Fallback: empty analysis if no kernel data
|
|
428
|
+
fusion_analysis = FusionAnalysis(
|
|
429
|
+
patterns=[],
|
|
430
|
+
summary={"amd_fuses": 0, "nvidia_fuses": 0, "total_fusion_opportunities": 0},
|
|
806
431
|
)
|
|
807
432
|
|
|
808
|
-
|
|
809
|
-
|
|
810
|
-
|
|
811
|
-
|
|
812
|
-
|
|
813
|
-
|
|
814
|
-
|
|
815
|
-
|
|
816
|
-
|
|
817
|
-
|
|
818
|
-
|
|
819
|
-
|
|
820
|
-
|
|
821
|
-
|
|
822
|
-
|
|
823
|
-
|
|
824
|
-
|
|
825
|
-
|
|
826
|
-
|
|
827
|
-
|
|
828
|
-
|
|
829
|
-
|
|
830
|
-
|
|
831
|
-
|
|
832
|
-
|
|
833
|
-
|
|
834
|
-
|
|
835
|
-
|
|
836
|
-
|
|
837
|
-
|
|
838
|
-
|
|
839
|
-
|
|
840
|
-
"fused_by": fused_by,
|
|
841
|
-
"groups_affected": 0, # Unknown for global analysis
|
|
842
|
-
"total_groups": len(matches),
|
|
843
|
-
}
|
|
844
|
-
)
|
|
845
|
-
|
|
846
|
-
# Sort by impact (ratio * total count)
|
|
847
|
-
results["fusion_opportunities"].sort(
|
|
848
|
-
key=lambda x: x["ratio"] * (x["trace1_total"] + x["trace2_total"]), reverse=True
|
|
849
|
-
)
|
|
850
|
-
|
|
851
|
-
# ADD PARTIAL FUSION MAPPINGS using correlation group differential analysis
|
|
852
|
-
# This catches patterns like Sort that exist on both platforms but with different frequencies
|
|
853
|
-
partial_mappings = _find_partial_fusion_via_groups(
|
|
854
|
-
trace1_large,
|
|
855
|
-
trace2_large,
|
|
856
|
-
matches,
|
|
857
|
-
trace1_name=trace1_platform,
|
|
858
|
-
trace2_name=trace2_platform
|
|
859
|
-
)
|
|
860
|
-
all_fusion_mappings.extend(partial_mappings)
|
|
861
|
-
|
|
862
|
-
# DETECT INTRA-TYPE FUSION (same kernel type fused with itself, like Sort chains)
|
|
863
|
-
# Do this FIRST since it's more accurate than the fallback global analysis
|
|
864
|
-
intra_mappings = _detect_intra_type_fusion(
|
|
865
|
-
trace1_kernels,
|
|
866
|
-
trace2_kernels,
|
|
867
|
-
trace1_name=trace1_platform,
|
|
868
|
-
trace2_name=trace2_platform
|
|
869
|
-
)
|
|
870
|
-
all_fusion_mappings.extend(intra_mappings)
|
|
871
|
-
|
|
872
|
-
# Collect kernel types already handled by intra-type fusion
|
|
873
|
-
intra_handled_types = set(m["fused_kernel_type"] for m in intra_mappings)
|
|
874
|
-
|
|
875
|
-
# ALSO ADD GLOBAL FUSION MAPPINGS for kernels not in large correlation groups
|
|
876
|
-
# Skip types already handled by intra-type fusion (more accurate)
|
|
877
|
-
global_mappings = _find_fusion_mappings(
|
|
878
|
-
trace1_kernels,
|
|
879
|
-
trace2_kernels,
|
|
880
|
-
trace1_name=trace1_platform,
|
|
881
|
-
trace2_name=trace2_platform
|
|
882
|
-
)
|
|
883
|
-
# Filter: skip if already handled or if evidence is duplicate
|
|
884
|
-
existing_evidence = set(m["evidence"] for m in all_fusion_mappings)
|
|
885
|
-
for mapping in global_mappings:
|
|
886
|
-
ktype = mapping["unfused_sequence"][0] if mapping["unfused_sequence"] else None
|
|
887
|
-
if ktype not in intra_handled_types and mapping["evidence"] not in existing_evidence:
|
|
888
|
-
all_fusion_mappings.append(mapping)
|
|
889
|
-
|
|
890
|
-
return results
|
|
433
|
+
fusion_opportunities = []
|
|
434
|
+
fusion_mappings = []
|
|
435
|
+
|
|
436
|
+
for pattern in fusion_analysis.patterns:
|
|
437
|
+
unfused_platform = "NVIDIA" if pattern.fused_platform == "AMD" else "AMD"
|
|
438
|
+
|
|
439
|
+
fusion_opportunities.append({
|
|
440
|
+
"kernel_type": pattern.operation,
|
|
441
|
+
"layer": pattern.layer,
|
|
442
|
+
"fused_by": pattern.fused_platform,
|
|
443
|
+
"fused_kernel": pattern.fused_kernel,
|
|
444
|
+
"unfused_kernels": pattern.unfused_kernels,
|
|
445
|
+
"count": pattern.count,
|
|
446
|
+
"evidence": pattern.evidence,
|
|
447
|
+
})
|
|
448
|
+
|
|
449
|
+
fusion_mappings.append({
|
|
450
|
+
"fused_platform": pattern.fused_platform,
|
|
451
|
+
"fused_kernel_type": pattern.operation,
|
|
452
|
+
"fused_kernel_name": pattern.fused_kernel,
|
|
453
|
+
"unfused_platform": unfused_platform,
|
|
454
|
+
"unfused_sequence": pattern.unfused_kernels,
|
|
455
|
+
"pattern_count": pattern.count,
|
|
456
|
+
"evidence": pattern.evidence,
|
|
457
|
+
"layer": pattern.layer,
|
|
458
|
+
})
|
|
459
|
+
|
|
460
|
+
return {
|
|
461
|
+
"fusion_opportunities": fusion_opportunities,
|
|
462
|
+
"fusion_mappings": fusion_mappings,
|
|
463
|
+
"summary": fusion_analysis.summary,
|
|
464
|
+
}
|