wafer-core 0.1.33__py3-none-any.whl → 0.1.35__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer_core/lib/trace_compare/__init__.py +9 -22
- wafer_core/lib/trace_compare/analyzer.py +160 -584
- wafer_core/lib/trace_compare/classifier.py +18 -321
- wafer_core/lib/trace_compare/fusion_analyzer.py +753 -329
- wafer_core/lib/trace_compare/loader.py +220 -413
- wafer_core/targets/__init__.py +21 -47
- wafer_core/utils/kernel_utils/defense.py +1 -813
- wafer_core/utils/kernel_utils/targets/config.py +24 -8
- {wafer_core-0.1.33.dist-info → wafer_core-0.1.35.dist-info}/METADATA +1 -1
- {wafer_core-0.1.33.dist-info → wafer_core-0.1.35.dist-info}/RECORD +11 -11
- {wafer_core-0.1.33.dist-info → wafer_core-0.1.35.dist-info}/WHEEL +0 -0
|
@@ -1,45 +1,179 @@
|
|
|
1
1
|
"""Fusion analysis for detecting kernel fusion differences between platforms.
|
|
2
2
|
|
|
3
|
-
Detects fusion differences between AMD and NVIDIA by analyzing
|
|
4
|
-
|
|
5
|
-
2. Sequence patterns (what comes before/after unique kernels)
|
|
6
|
-
3. Count imbalances (one platform has significantly more calls)
|
|
7
|
-
|
|
8
|
-
This is the pattern-based approach which is more reliable than alignment-based detection.
|
|
3
|
+
Detects fusion differences between AMD and NVIDIA by analyzing how many kernels
|
|
4
|
+
each platform launches for the same logical operations.
|
|
9
5
|
"""
|
|
10
6
|
|
|
7
|
+
import json
|
|
11
8
|
from collections import Counter, defaultdict
|
|
12
|
-
from
|
|
9
|
+
from functools import lru_cache
|
|
10
|
+
from pathlib import Path
|
|
13
11
|
from typing import Any
|
|
14
12
|
|
|
15
|
-
from .classifier import
|
|
13
|
+
from .classifier import classify_kernel
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@lru_cache(maxsize=10000)
def _classify_kernel_cached(name: str) -> str:
    """Return ``classify_kernel(name)``, memoized by kernel name.

    Traces repeat the same kernel names many times, so up to 10,000 distinct
    names are cached and the classifier runs only once per distinct name.
    """
    kernel_type = classify_kernel(name)
    return kernel_type
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _load_trace_for_fusion(
|
|
23
|
+
file_path: str | Path,
|
|
24
|
+
) -> tuple[str, str, list[dict[str, Any]], dict[int, list[dict[str, Any]]]]:
|
|
25
|
+
"""Load trace and group kernels by correlation ID.
|
|
26
|
+
|
|
27
|
+
Args:
|
|
28
|
+
file_path: Path to trace file
|
|
29
|
+
|
|
30
|
+
Returns:
|
|
31
|
+
Tuple of (platform, gpu_name, all_kernels, corr_groups)
|
|
32
|
+
"""
|
|
33
|
+
with open(file_path, "rb") as f:
|
|
34
|
+
trace = json.load(f)
|
|
35
|
+
|
|
36
|
+
# Detect platform
|
|
37
|
+
props = trace.get("deviceProperties", [{}])[0]
|
|
38
|
+
is_amd = trace.get("roctracer_version") or props.get("warpSize") == 64
|
|
39
|
+
platform = "AMD" if is_amd else "NVIDIA"
|
|
40
|
+
gpu_name = props.get("name", "MI300X" if is_amd else "Unknown GPU")
|
|
41
|
+
|
|
42
|
+
# Get all kernel events
|
|
43
|
+
events = trace.get("traceEvents", [])
|
|
44
|
+
kernels = [e for e in events if e.get("cat") == "kernel"]
|
|
45
|
+
|
|
46
|
+
# Group by correlation ID
|
|
47
|
+
corr_groups: dict[int, list[dict[str, Any]]] = defaultdict(list)
|
|
48
|
+
for k in kernels:
|
|
49
|
+
corr_id = k.get("args", {}).get("correlation")
|
|
50
|
+
if corr_id is not None:
|
|
51
|
+
corr_groups[corr_id].append(k)
|
|
52
|
+
|
|
53
|
+
return platform, gpu_name, kernels, dict(corr_groups)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def _compute_group_signature(kernels: list[dict[str, Any]]) -> tuple:
    """Compute a hashable signature for a correlation group for O(1) lookups.

    Groups with equal signatures share a dict bucket. Matching can then fetch
    candidate groups by lookup instead of comparing every pair of groups.

    Args:
        kernels: List of kernel events in the group

    Returns:
        Tuple of (size_bucket, has_attention, has_ffn, dominant_type):
        - size_bucket: ``len(kernels)`` rounded DOWN to a multiple of 10.
        - has_attention: some lowercased name contains "attention" or "fmha".
        - has_ffn: some lowercased name contains "cijk", "nvjet" or "gemm".
        - dominant_type: the most frequent classification ("Other" if the
          group is empty).
    """
    # A null "name" is treated like a missing one.
    names = [k.get("name") or "" for k in kernels]
    lowered = [n.lower() for n in names]

    counts = Counter(_classify_kernel_cached(n) for n in names)
    dominant = counts.most_common(1)[0][0] if counts else "Other"

    # Floor to a multiple of 10 (e.g. 19 -> 10, not 20). Callers rely on this
    # when probing the neighbouring +/-10 buckets.
    size_bucket = len(kernels) // 10 * 10

    has_attn = any("attention" in n or "fmha" in n for n in lowered)
    ffn_markers = ("cijk", "nvjet", "gemm")
    has_ffn = any(marker in n for n in lowered for marker in ffn_markers)

    return (size_bucket, has_attn, has_ffn, dominant)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _analyze_correlation_group(
    kernels: list[dict[str, Any]],
) -> tuple[dict[str, int], dict[str, float]]:
    """Summarize the kernel-type makeup of one correlation group.

    Args:
        kernels: Kernel events belonging to the group.

    Returns:
        ``(counts, timings)``. ``counts`` maps each kernel type to how many
        kernels of that type the group holds. ``timings`` maps each type to the
        sum of its kernels' ``dur`` values (missing ``dur`` counts as 0), in
        microseconds. Both dicts list types in order of first appearance.
    """
    categories = [_classify_kernel_cached(event.get("name", "")) for event in kernels]

    durations: dict[str, float] = defaultdict(float)
    for category, event in zip(categories, kernels):
        durations[category] += event.get("dur", 0)

    return dict(Counter(categories)), dict(durations)
|
|
29
99
|
|
|
30
100
|
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
101
|
+
def _match_correlation_groups(
    trace1_groups: dict[int, list[dict[str, Any]]],
    trace2_groups: dict[int, list[dict[str, Any]]],
    size_tolerance: float = 0.25,
) -> list[tuple[int, int]]:
    """Match correlation groups using hybrid signature + composition approach.

    Signature buckets narrow the candidate set by dict lookup, so every pair of
    groups does not need to be compared. Composition scoring then picks the
    best candidate when several remain.

    Trace1 groups are visited largest first. Each is paired with at most one
    trace2 group that has not been used yet.

    Candidate selection:
    - A candidate has the same signature and a size within
      ``[size / (1 + tol), size * (1 + tol)]``.
    - If no such candidate exists, the neighbouring size buckets (+/-10, same
      attention/FFN/dominant-type fields) are searched with the same size
      filter.
    - Among the candidates, the lowest score wins:
      ``|size difference| - 10 * (number of shared kernel types)``.
    - Ties go to the earliest candidate.

    Args:
        trace1_groups: Trace 1 correlation groups
        trace2_groups: Trace 2 correlation groups
        size_tolerance: Groups match if sizes are within this fraction

    Returns:
        List of (trace1_corr_id, trace2_corr_id) pairs
    """
    matches: list[tuple[int, int]] = []
    used_trace2_ids: set[int] = set()

    # Kernel-type composition of each group, used only for candidate scoring.
    trace1_comps = {gid: _analyze_correlation_group(ks)[0] for gid, ks in trace1_groups.items()}
    trace2_comps = {gid: _analyze_correlation_group(ks)[0] for gid, ks in trace2_groups.items()}

    # Bucket trace2 groups by signature for fast candidate lookup.
    trace2_by_sig: dict[tuple, list[tuple[int, int]]] = defaultdict(list)
    for gid, ks in trace2_groups.items():
        trace2_by_sig[_compute_group_signature(ks)].append((gid, len(ks)))

    def _usable(sig: tuple, min_size: float, max_size: float) -> list[tuple[int, int]]:
        """Unused trace2 groups in bucket ``sig`` whose size is within bounds."""
        return [
            (gid, size)
            for gid, size in trace2_by_sig.get(sig, ())
            if gid not in used_trace2_ids and min_size <= size <= max_size
        ]

    trace1_sorted = sorted(trace1_groups.items(), key=lambda item: len(item[1]), reverse=True)

    for trace1_id, trace1_kernels in trace1_sorted:
        trace1_size = len(trace1_kernels)
        trace1_types = set(trace1_comps[trace1_id])
        sig = _compute_group_signature(trace1_kernels)
        min_size = trace1_size / (1 + size_tolerance)
        max_size = trace1_size * (1 + size_tolerance)

        candidates = _usable(sig, min_size, max_size)

        # Widen to adjacent buckets when nothing in the exact bucket is usable.
        # This also covers buckets whose entries all fail the size filter: a
        # 19-kernel group (bucket 10) can still match a 20-kernel group (bucket 20).
        if not candidates:
            bucket, has_attn, has_ffn, dominant = sig
            for adj_bucket in (bucket - 10, bucket + 10):
                candidates.extend(_usable((adj_bucket, has_attn, has_ffn, dominant), min_size, max_size))

        if not candidates:
            continue

        def _score(candidate: tuple[int, int]) -> float:
            gid, size = candidate
            shared_types = len(trace1_types.intersection(trace2_comps[gid]))
            return abs(trace1_size - size) - shared_types * 10

        # min() keeps the first of equally-scored candidates.
        best_id, _ = min(candidates, key=_score)
        matches.append((trace1_id, best_id))
        used_trace2_ids.add(best_id)

    return matches
|
|
43
177
|
|
|
44
178
|
|
|
45
179
|
def _find_fusion_mappings(
|
|
@@ -47,7 +181,6 @@ def _find_fusion_mappings(
|
|
|
47
181
|
trace2_kernels: list[dict],
|
|
48
182
|
trace1_name: str = "Trace1",
|
|
49
183
|
trace2_name: str = "Trace2",
|
|
50
|
-
trace1_platform: str = "AMD",
|
|
51
184
|
) -> list[dict]:
|
|
52
185
|
"""Find fusion mappings by analyzing kernel execution sequence patterns.
|
|
53
186
|
|
|
@@ -59,21 +192,29 @@ def _find_fusion_mappings(
|
|
|
59
192
|
trace2_kernels: List of kernel events from second trace
|
|
60
193
|
trace1_name: Name of first platform (e.g., "AMD")
|
|
61
194
|
trace2_name: Name of second platform (e.g., "NVIDIA")
|
|
62
|
-
trace1_platform: Platform string for classification ("AMD" or "NVIDIA")
|
|
63
195
|
|
|
64
196
|
Returns:
|
|
65
|
-
List of mapping dictionaries
|
|
197
|
+
List of mapping dictionaries, each containing:
|
|
198
|
+
- fused_platform: Which platform fuses the operations
|
|
199
|
+
- fused_kernel_type: The single fused kernel type
|
|
200
|
+
- unfused_platform: Which platform runs them separately
|
|
201
|
+
- unfused_sequence: List of kernel types run separately
|
|
202
|
+
- pattern_count: How many times this pattern appears
|
|
203
|
+
- pattern_confidence: Fraction of occurrences following this pattern
|
|
204
|
+
- evidence: Human-readable description
|
|
66
205
|
"""
|
|
206
|
+
from collections import defaultdict
|
|
207
|
+
from wafer_core.lib.trace_compare.classifier import classify_kernel
|
|
208
|
+
|
|
67
209
|
mappings = []
|
|
68
|
-
trace2_platform = "NVIDIA" if trace1_platform == "AMD" else "AMD"
|
|
69
210
|
|
|
70
211
|
# Sort kernels by timestamp
|
|
71
|
-
trace1_sorted = sorted(trace1_kernels, key=lambda k: k.get(
|
|
72
|
-
trace2_sorted = sorted(trace2_kernels, key=lambda k: k.get(
|
|
212
|
+
trace1_sorted = sorted(trace1_kernels, key=lambda k: k.get('ts', 0))
|
|
213
|
+
trace2_sorted = sorted(trace2_kernels, key=lambda k: k.get('ts', 0))
|
|
73
214
|
|
|
74
215
|
# Classify all kernels
|
|
75
|
-
trace1_types = [
|
|
76
|
-
trace2_types = [
|
|
216
|
+
trace1_types = [_classify_kernel_cached(k.get('name', '')) for k in trace1_sorted]
|
|
217
|
+
trace2_types = [_classify_kernel_cached(k.get('name', '')) for k in trace2_sorted]
|
|
77
218
|
|
|
78
219
|
# Find kernel types unique to each trace
|
|
79
220
|
trace1_type_set = set(trace1_types)
|
|
@@ -82,61 +223,32 @@ def _find_fusion_mappings(
|
|
|
82
223
|
trace1_only = trace1_type_set - trace2_type_set
|
|
83
224
|
trace2_only = trace2_type_set - trace1_type_set
|
|
84
225
|
|
|
85
|
-
# For each unique type in trace1,
|
|
86
|
-
# If trace1 has a unique kernel type that trace2 doesn't have, trace1 is likely fusing
|
|
226
|
+
# For each unique type in trace1, find common sequence patterns
|
|
87
227
|
for unique_type in trace1_only:
|
|
88
|
-
#
|
|
89
|
-
if unique_type == "Other":
|
|
90
|
-
continue
|
|
91
|
-
|
|
92
|
-
# If the unique type contains '+', it's explicitly a fused kernel
|
|
93
|
-
# This means trace1 (which has it) is fusing, not trace2
|
|
94
|
-
if "+" in unique_type:
|
|
95
|
-
# Parse components from the fused op name
|
|
96
|
-
components = [c.strip() for c in unique_type.split("+")]
|
|
97
|
-
indices = [i for i, t in enumerate(trace1_types) if t == unique_type]
|
|
98
|
-
|
|
99
|
-
if len(indices) < 5:
|
|
100
|
-
continue
|
|
101
|
-
|
|
102
|
-
mappings.append({
|
|
103
|
-
"fused_platform": trace1_name,
|
|
104
|
-
"fused_kernel_type": unique_type,
|
|
105
|
-
"fused_count": len(indices),
|
|
106
|
-
"unfused_platform": trace2_name,
|
|
107
|
-
"unfused_sequence": components,
|
|
108
|
-
"unfused_count_per_type": {c: trace2_types.count(c) for c in components},
|
|
109
|
-
"pattern_count": len(indices),
|
|
110
|
-
"pattern_confidence": 1.0,
|
|
111
|
-
"evidence": f"{trace1_name} fuses {' + '.join(components)} into single kernel ({len(indices)} calls), {trace2_name} runs separately",
|
|
112
|
-
})
|
|
113
|
-
continue
|
|
114
|
-
|
|
115
|
-
# For non-fused unique types, find all occurrences
|
|
228
|
+
# Find all occurrences of this type
|
|
116
229
|
indices = [i for i, t in enumerate(trace1_types) if t == unique_type]
|
|
117
230
|
|
|
118
231
|
if len(indices) < 5: # Need enough samples to be meaningful
|
|
119
232
|
continue
|
|
120
233
|
|
|
121
|
-
# Analyze what comes before each occurrence
|
|
122
|
-
before_types
|
|
234
|
+
# Analyze what comes before/after each occurrence
|
|
235
|
+
before_types = defaultdict(int)
|
|
123
236
|
|
|
124
237
|
for idx in indices:
|
|
125
238
|
if idx > 0:
|
|
126
239
|
before_types[trace1_types[idx - 1]] += 1
|
|
127
240
|
|
|
128
|
-
# Find the most common pattern
|
|
129
|
-
|
|
130
|
-
continue
|
|
131
|
-
most_common_before = max(before_types.items(), key=lambda x: x[1])
|
|
241
|
+
# Find the most common pattern (e.g., "Attention → Reduce")
|
|
242
|
+
most_common_before = max(before_types.items(), key=lambda x: x[1]) if before_types else (None, 0)
|
|
132
243
|
|
|
133
|
-
# If there's a strong pattern (>80% of occurrences)
|
|
134
|
-
# trace1 runs them separately while trace2 might fuse
|
|
244
|
+
# If there's a strong pattern (>80% of occurrences)
|
|
135
245
|
if most_common_before[1] / len(indices) > 0.8:
|
|
246
|
+
# This suggests: Trace2 likely fuses [before_type + unique_type] into [before_type]
|
|
136
247
|
fusion_candidate = most_common_before[0]
|
|
137
248
|
|
|
138
|
-
# Verify trace2 has this type
|
|
249
|
+
# Verify trace2 has this type
|
|
139
250
|
if fusion_candidate in trace2_type_set:
|
|
251
|
+
# Count occurrences to compare
|
|
140
252
|
trace1_fusion_count = trace1_types.count(fusion_candidate)
|
|
141
253
|
trace2_fusion_count = trace2_types.count(fusion_candidate)
|
|
142
254
|
|
|
@@ -148,40 +260,15 @@ def _find_fusion_mappings(
|
|
|
148
260
|
"unfused_sequence": [fusion_candidate, unique_type],
|
|
149
261
|
"unfused_count_per_type": {
|
|
150
262
|
fusion_candidate: trace1_fusion_count,
|
|
151
|
-
unique_type: len(indices)
|
|
263
|
+
unique_type: len(indices)
|
|
152
264
|
},
|
|
153
265
|
"pattern_count": len(indices),
|
|
154
266
|
"pattern_confidence": most_common_before[1] / len(indices),
|
|
155
|
-
"evidence": f"{trace1_name} runs {fusion_candidate}
|
|
267
|
+
"evidence": f"{trace1_name} runs {fusion_candidate}+{unique_type} separately, {trace2_name} fuses into {fusion_candidate}"
|
|
156
268
|
})
|
|
157
269
|
|
|
158
270
|
# Also check trace2-only types
|
|
159
271
|
for unique_type in trace2_only:
|
|
160
|
-
if unique_type == "Other":
|
|
161
|
-
continue
|
|
162
|
-
|
|
163
|
-
# If the unique type contains '+', it's explicitly a fused kernel
|
|
164
|
-
# This means trace2 (which has it) is fusing, not trace1
|
|
165
|
-
if "+" in unique_type:
|
|
166
|
-
components = [c.strip() for c in unique_type.split("+")]
|
|
167
|
-
indices = [i for i, t in enumerate(trace2_types) if t == unique_type]
|
|
168
|
-
|
|
169
|
-
if len(indices) < 5:
|
|
170
|
-
continue
|
|
171
|
-
|
|
172
|
-
mappings.append({
|
|
173
|
-
"fused_platform": trace2_name,
|
|
174
|
-
"fused_kernel_type": unique_type,
|
|
175
|
-
"fused_count": len(indices),
|
|
176
|
-
"unfused_platform": trace1_name,
|
|
177
|
-
"unfused_sequence": components,
|
|
178
|
-
"unfused_count_per_type": {c: trace1_types.count(c) for c in components},
|
|
179
|
-
"pattern_count": len(indices),
|
|
180
|
-
"pattern_confidence": 1.0,
|
|
181
|
-
"evidence": f"{trace2_name} fuses {' + '.join(components)} into single kernel ({len(indices)} calls), {trace1_name} runs separately",
|
|
182
|
-
})
|
|
183
|
-
continue
|
|
184
|
-
|
|
185
272
|
indices = [i for i, t in enumerate(trace2_types) if t == unique_type]
|
|
186
273
|
|
|
187
274
|
if len(indices) < 5:
|
|
@@ -193,9 +280,7 @@ def _find_fusion_mappings(
|
|
|
193
280
|
if idx > 0:
|
|
194
281
|
before_types[trace2_types[idx - 1]] += 1
|
|
195
282
|
|
|
196
|
-
|
|
197
|
-
continue
|
|
198
|
-
most_common_before = max(before_types.items(), key=lambda x: x[1])
|
|
283
|
+
most_common_before = max(before_types.items(), key=lambda x: x[1]) if before_types else (None, 0)
|
|
199
284
|
|
|
200
285
|
if most_common_before[1] / len(indices) > 0.8:
|
|
201
286
|
fusion_candidate = most_common_before[0]
|
|
@@ -212,294 +297,633 @@ def _find_fusion_mappings(
|
|
|
212
297
|
"unfused_sequence": [fusion_candidate, unique_type],
|
|
213
298
|
"unfused_count_per_type": {
|
|
214
299
|
fusion_candidate: trace2_fusion_count,
|
|
215
|
-
unique_type: len(indices)
|
|
300
|
+
unique_type: len(indices)
|
|
216
301
|
},
|
|
217
302
|
"pattern_count": len(indices),
|
|
218
303
|
"pattern_confidence": most_common_before[1] / len(indices),
|
|
219
|
-
"evidence": f"{trace2_name} runs {fusion_candidate}
|
|
304
|
+
"evidence": f"{trace2_name} runs {fusion_candidate}+{unique_type} separately, {trace1_name} fuses into {fusion_candidate}"
|
|
305
|
+
})
|
|
306
|
+
|
|
307
|
+
# NEW: Check for partial fusion patterns (kernel exists on both platforms but with different counts)
|
|
308
|
+
# If one platform has significantly fewer calls (>1.3x ratio), look for fusion patterns
|
|
309
|
+
common_types = trace1_type_set & trace2_type_set
|
|
310
|
+
|
|
311
|
+
for ktype in common_types:
|
|
312
|
+
trace1_count = trace1_types.count(ktype)
|
|
313
|
+
trace2_count = trace2_types.count(ktype)
|
|
314
|
+
|
|
315
|
+
# Check if there's a significant imbalance (one has >1.3x more)
|
|
316
|
+
if trace1_count == 0 or trace2_count == 0:
|
|
317
|
+
continue
|
|
318
|
+
|
|
319
|
+
ratio = max(trace1_count, trace2_count) / min(trace1_count, trace2_count)
|
|
320
|
+
|
|
321
|
+
if ratio < 1.3 or trace1_count + trace2_count < 100:
|
|
322
|
+
continue
|
|
323
|
+
|
|
324
|
+
# Determine which platform has more (unfused) and which has fewer (fused)
|
|
325
|
+
if trace1_count > trace2_count:
|
|
326
|
+
# Trace1 has more separate calls, Trace2 likely fuses
|
|
327
|
+
unfused_platform = trace1_name
|
|
328
|
+
fused_platform = trace2_name
|
|
329
|
+
unfused_count = trace1_count
|
|
330
|
+
fused_count = trace2_count
|
|
331
|
+
|
|
332
|
+
# Find what Trace2 might be fusing this into
|
|
333
|
+
# Use the most common kernel type in Trace2 as a likely fusion target
|
|
334
|
+
trace2_type_counts = defaultdict(int)
|
|
335
|
+
for t in trace2_types:
|
|
336
|
+
if t != ktype and t != "Other": # Skip the imbalanced type itself and "Other"
|
|
337
|
+
trace2_type_counts[t] += 1
|
|
338
|
+
|
|
339
|
+
if trace2_type_counts:
|
|
340
|
+
# Use the most common type as the fusion target
|
|
341
|
+
fusion_target = max(trace2_type_counts.items(), key=lambda x: x[1])[0]
|
|
342
|
+
|
|
343
|
+
mappings.append({
|
|
344
|
+
"fused_platform": fused_platform,
|
|
345
|
+
"fused_kernel_type": fusion_target,
|
|
346
|
+
"fused_count": fused_count,
|
|
347
|
+
"unfused_platform": unfused_platform,
|
|
348
|
+
"unfused_sequence": [ktype],
|
|
349
|
+
"unfused_count_per_type": {
|
|
350
|
+
ktype: unfused_count
|
|
351
|
+
},
|
|
352
|
+
"pattern_count": unfused_count - fused_count,
|
|
353
|
+
"pattern_confidence": (unfused_count - fused_count) / unfused_count,
|
|
354
|
+
"evidence": f"{unfused_platform} calls {ktype} {ratio:.1f}x more ({unfused_count} vs {fused_count}), {fused_platform} likely fuses into {fusion_target}"
|
|
355
|
+
})
|
|
356
|
+
else:
|
|
357
|
+
# Trace2 has more separate calls, Trace1 likely fuses
|
|
358
|
+
unfused_platform = trace2_name
|
|
359
|
+
fused_platform = trace1_name
|
|
360
|
+
unfused_count = trace2_count
|
|
361
|
+
fused_count = trace1_count
|
|
362
|
+
|
|
363
|
+
# Find what Trace1 might be fusing this into
|
|
364
|
+
# Use the most common kernel type in Trace1 as a likely fusion target
|
|
365
|
+
trace1_type_counts = defaultdict(int)
|
|
366
|
+
for t in trace1_types:
|
|
367
|
+
if t != ktype and t != "Other": # Skip the imbalanced type itself and "Other"
|
|
368
|
+
trace1_type_counts[t] += 1
|
|
369
|
+
|
|
370
|
+
if trace1_type_counts:
|
|
371
|
+
fusion_target = max(trace1_type_counts.items(), key=lambda x: x[1])[0]
|
|
372
|
+
|
|
373
|
+
mappings.append({
|
|
374
|
+
"fused_platform": fused_platform,
|
|
375
|
+
"fused_kernel_type": fusion_target,
|
|
376
|
+
"fused_count": fused_count,
|
|
377
|
+
"unfused_platform": unfused_platform,
|
|
378
|
+
"unfused_sequence": [ktype],
|
|
379
|
+
"unfused_count_per_type": {
|
|
380
|
+
ktype: unfused_count
|
|
381
|
+
},
|
|
382
|
+
"pattern_count": unfused_count - fused_count,
|
|
383
|
+
"pattern_confidence": (unfused_count - fused_count) / unfused_count,
|
|
384
|
+
"evidence": f"{unfused_platform} calls {ktype} {ratio:.1f}x more ({unfused_count} vs {fused_count}), {fused_platform} likely fuses into {fusion_target}"
|
|
220
385
|
})
|
|
221
386
|
|
|
222
387
|
return mappings
|
|
223
388
|
|
|
224
389
|
|
|
225
|
-
def
|
|
390
|
+
def _detect_intra_type_fusion(
    trace1_kernels: list[dict],
    trace2_kernels: list[dict],
    trace1_name: str,
    trace2_name: str,
) -> list[dict]:
    """Detect intra-type fusion where consecutive same-type kernels are fused.

    Example: AMD runs Sort→Sort→Sort (42 calls) while NVIDIA runs Sort→Sort (10 calls)
    This indicates NVIDIA has a more efficient Sort implementation that fuses operations.

    Each trace's kernels are ordered by ``ts`` and classified. Runs ("chains")
    of consecutive kernels of the same type are measured. A kernel type is
    reported when all of the following hold:
    - at least one trace has 5 or more chains of it;
    - the average length of multi-kernel chains differs by more than 2x
      between the traces;
    - the longer-chain ("unfused") platform launches more than 100 more
      kernels of that type in total than the other ("fused") platform.

    Args:
        trace1_kernels: Kernel events of the first trace.
        trace2_kernels: Kernel events of the second trace.
        trace1_name: Display name of the first platform.
        trace2_name: Display name of the second platform.

    Returns:
        Mapping dicts, ordered by kernel type. Each has the keys
        fused_platform, fused_kernel_type, fused_count, unfused_platform,
        unfused_sequence, unfused_count_per_type, pattern_count,
        pattern_confidence and evidence.
    """
    from itertools import groupby

    def analyze_chains(kernels: list[dict]) -> dict[str, list[int]]:
        """Map each kernel type to the lengths of its consecutive runs (ts order)."""
        ordered = sorted(kernels, key=lambda k: k.get("ts", 0))
        types = [_classify_kernel_cached(k.get("name", "")) for k in ordered]
        chains: dict[str, list[int]] = defaultdict(list)
        for ktype, run in groupby(types):
            chains[ktype].append(sum(1 for _ in run))
        return chains

    trace1_chains = analyze_chains(trace1_kernels)
    trace2_chains = analyze_chains(trace2_kernels)

    mappings: list[dict] = []
    # Sorted so that the output order does not depend on set iteration order.
    for ktype in sorted(set(trace1_chains) | set(trace2_chains)):
        t1_lengths = trace1_chains.get(ktype, [])
        t2_lengths = trace2_chains.get(ktype, [])

        # Skip if not enough data
        if len(t1_lengths) < 5 and len(t2_lengths) < 5:
            continue

        # Only chains of 2+ kernels say anything about fusion of repeats.
        t1_multi = [n for n in t1_lengths if n > 1]
        t2_multi = [n for n in t2_lengths if n > 1]
        if not t1_multi and not t2_multi:
            continue

        t1_total = sum(t1_lengths)
        t2_total = sum(t2_lengths)
        t1_chains = len(t1_multi) or len(t1_lengths)
        t2_chains = len(t2_multi) or len(t2_lengths)
        if t1_chains == 0 or t2_chains == 0:
            continue  # type is absent from one trace; nothing to compare

        t1_avg_chain = sum(t1_multi) / len(t1_multi) if t1_multi else 1.0
        t2_avg_chain = sum(t2_multi) / len(t2_multi) if t2_multi else 1.0
        chain_ratio = max(t1_avg_chain, t2_avg_chain) / min(t1_avg_chain, t2_avg_chain)

        # Significant intra-fusion only if chains are more than 2x different.
        if chain_ratio <= 2.0:
            continue

        if t1_avg_chain > t2_avg_chain:
            unfused_platform, unfused_chains, unfused_avg, unfused_total = trace1_name, t1_chains, t1_avg_chain, t1_total
            fused_platform, fused_chains, fused_avg, fused_total = trace2_name, t2_chains, t2_avg_chain, t2_total
        else:
            unfused_platform, unfused_chains, unfused_avg, unfused_total = trace2_name, t2_chains, t2_avg_chain, t2_total
            fused_platform, fused_chains, fused_avg, fused_total = trace1_name, t1_chains, t1_avg_chain, t1_total

        # Fusion implies FEWER launches on the fused side. Require the
        # longer-chain platform to launch >100 more kernels, otherwise
        # pattern_count would be negative and the evidence misleading.
        if unfused_total - fused_total <= 100:
            continue

        mappings.append({
            "fused_platform": fused_platform,
            "fused_kernel_type": ktype,
            "fused_count": fused_total,
            "unfused_platform": unfused_platform,
            "unfused_sequence": [ktype, ktype],  # Same type repeated
            "unfused_count_per_type": {ktype: unfused_total},
            "pattern_count": unfused_total - fused_total,
            "pattern_confidence": min(unfused_chains, fused_chains) / max(unfused_chains, fused_chains),
            "evidence": f"{unfused_platform} runs {ktype} in chains of {unfused_avg:.0f} calls ({unfused_chains} chains, {unfused_total:,} total), {fused_platform} fuses to {fused_avg:.0f} calls ({fused_chains} chains, {fused_total:,} total) - {chain_ratio:.1f}x more efficient"
        })

    return mappings
|
|
488
|
+
|
|
489
|
+
|
|
490
|
+
def _find_partial_fusion_via_groups(
    trace1_large: dict[int, list[dict]],
    trace2_large: dict[int, list[dict]],
    matches: list[tuple[int, int]],
    trace1_name: str,
    trace2_name: str,
) -> list[dict]:
    """Find partial fusion patterns by analyzing correlation group differences.

    When one platform has fewer of a kernel type across the matched groups, look
    at the groups where only the other (higher-count) platform has that type and
    count which kernel types the lower-count platform has MORE of there; those
    are likely fusion targets.

    Args:
        trace1_large: Correlation-group id -> kernel events for trace 1.
        trace2_large: Correlation-group id -> kernel events for trace 2.
        matches: ``(trace1_group_id, trace2_group_id)`` pairs to compare.
        trace1_name: Display name for trace 1 in the produced mappings.
        trace2_name: Display name for trace 2 in the produced mappings.

    Returns:
        One mapping dict per kernel type that is present on both sides, whose
        aggregate counts differ by at least 1.3x with at least 100 combined
        calls, and that yields fusion targets from at least 5 matched groups.
    """
    # Classify every matched group exactly once; the per-type loop below reuses
    # these counters instead of reclassifying each kernel for every type.
    group_counts: list[tuple[Counter, Counter]] = [
        (
            Counter(_classify_kernel_cached(k.get("name", "")) for k in trace1_large[trace1_cid]),
            Counter(_classify_kernel_cached(k.get("name", "")) for k in trace2_large[trace2_cid]),
        )
        for trace1_cid, trace2_cid in matches
    ]

    trace1_counts: Counter = Counter()
    trace2_counts: Counter = Counter()
    for trace1_group, trace2_group in group_counts:
        trace1_counts.update(trace1_group)
        trace2_counts.update(trace2_group)

    mappings: list[dict] = []
    # Sorted so output order (and tie-breaking below) does not depend on
    # string hash randomization.
    for ktype in sorted(set(trace1_counts) | set(trace2_counts), key=str):
        trace1_count = trace1_counts[ktype]
        trace2_count = trace2_counts[ktype]

        if trace1_count == 0 or trace2_count == 0:
            continue  # Handled by sequence-based detection

        ratio = max(trace1_count, trace2_count) / min(trace1_count, trace2_count)
        if ratio < 1.3 or trace1_count + trace2_count < 100:
            continue  # Not significant

        # The platform with MORE calls is treated as unfused; fewer calls = fuses more.
        trace1_is_unfused = trace1_count > trace2_count
        if trace1_is_unfused:
            unfused_platform, fused_platform = trace1_name, trace2_name
            unfused_count, fused_count = trace1_count, trace2_count
        else:
            unfused_platform, fused_platform = trace2_name, trace1_name
            unfused_count, fused_count = trace2_count, trace1_count

        fusion_targets, groups_analyzed = _partial_fusion_targets(
            ktype, group_counts, trace1_is_unfused
        )
        if not fusion_targets or groups_analyzed < 5:
            continue

        # Report top fusion targets
        top_targets = fusion_targets.most_common(3)
        target_str = ", ".join(f"{t[0]} (+{t[1]})" for t in top_targets)

        mappings.append({
            "fused_platform": fused_platform,
            "fused_kernel_type": top_targets[0][0],
            "fused_count": fused_count,
            "unfused_platform": unfused_platform,
            "unfused_sequence": [ktype],
            "unfused_count_per_type": {ktype: unfused_count},
            "pattern_count": unfused_count - fused_count,
            "pattern_confidence": groups_analyzed / len(matches) if matches else 0,
            "evidence": f"{unfused_platform} calls {ktype} {ratio:.1f}x more ({unfused_count} vs {fused_count}). In {groups_analyzed} groups where {unfused_platform} has {ktype}, {fused_platform} has more: {target_str}"
        })

    return mappings


def _partial_fusion_targets(
    ktype: str,
    group_counts: list[tuple[Counter, Counter]],
    trace1_is_unfused: bool,
) -> tuple[Counter, int]:
    """Tally kernel types the fused side has extra of, in groups where only the unfused side has *ktype*.

    Args:
        ktype: Kernel type the unfused side launches more of.
        group_counts: Per matched group, ``(trace1_type_counts, trace2_type_counts)``.
        trace1_is_unfused: True when trace 1 is the side with more *ktype* calls.

    Returns:
        ``(fusion_targets, groups_analyzed)``. ``fusion_targets`` maps each other
        kernel type (excluding *ktype* and ``"Other"``) to the summed surplus the
        fused side has over the unfused side. ``groups_analyzed`` is the number of
        groups that contained *ktype* only on the unfused side.
    """
    fusion_targets: Counter = Counter()
    groups_analyzed = 0

    for trace1_group, trace2_group in group_counts:
        unfused, fused = (trace1_group, trace2_group) if trace1_is_unfused else (trace2_group, trace1_group)
        if ktype not in unfused or ktype in fused:
            continue

        # Sorted iteration gives Counter a deterministic insertion order, which is
        # what most_common() uses to break ties.
        for other_type in sorted(fused, key=str):
            if other_type == ktype or other_type == "Other":
                continue
            diff = fused[other_type] - unfused.get(other_type, 0)
            if diff > 0:
                fusion_targets[other_type] += diff

        groups_analyzed += 1

    return fusion_targets, groups_analyzed
|
|
380
629
|
|
|
381
630
|
|
|
382
|
-
def
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
This approach only reports high-confidence fusions where kernels are
|
|
389
|
-
explicitly classified as fused (e.g., 'RMSNorm+GEMM', 'SwiGLU+GEMM').
|
|
390
|
-
|
|
391
|
-
We intentionally avoid speculative detection (sequence patterns, count
|
|
392
|
-
imbalances) because these produce too many false positives - count
|
|
393
|
-
differences are usually due to different implementations, not fusion.
|
|
631
|
+
def analyze_fusion_differences(
    amd_trace_path: str | Path,
    nv_trace_path: str | Path,
    min_group_size: int = 50,
) -> dict[str, Any]:
    """Main fusion analysis function.

    Trace order is preserved: ``amd_trace_path`` is reported as "Trace 1" and
    ``nv_trace_path`` as "Trace 2".

    Args:
        amd_trace_path: Path to AMD trace
        nv_trace_path: Path to NVIDIA trace
        min_group_size: Only analyze correlation groups with at least this many kernels

    Returns:
        Dictionary with analysis results containing:
        - metadata: trace info
        - global_counts: kernel type distribution across entire trace
        - fusion_opportunities: significant fusion differences, sorted by
          ``ratio * (trace1_total + trace2_total)`` descending
        - fusion_mappings: kernel-to-kernel mappings from matched groups,
          partial-fusion, intra-type and whole-trace detection
    """
    # Load traces (maintain order - don't swap)
    trace1_platform, trace1_gpu, trace1_kernels, trace1_corr_groups = _load_trace_for_fusion(
        amd_trace_path
    )
    trace2_platform, trace2_gpu, trace2_kernels, trace2_corr_groups = _load_trace_for_fusion(
        nv_trace_path
    )

    # Override platform names with generic "Trace 1" and "Trace 2" for UI consistency
    trace1_platform = "Trace 1"
    trace2_platform = "Trace 2"

    # Filter to "significant" correlation groups
    trace1_large = {
        cid: kernels
        for cid, kernels in trace1_corr_groups.items()
        if len(kernels) >= min_group_size
    }
    trace2_large = {
        cid: kernels
        for cid, kernels in trace2_corr_groups.items()
        if len(kernels) >= min_group_size
    }

    # Match correlation groups between platforms
    matches = _match_correlation_groups(trace1_large, trace2_large)

    # Per-kernel-type totals accumulated over all matched groups.
    fusion_diffs: dict[str, dict[str, Any]] = defaultdict(
        lambda: {
            "trace1_count": 0,
            "trace2_count": 0,
            "trace1_time_us": 0,
            "trace2_time_us": 0,
            "groups_with_diff": 0,
            "total_groups": 0,
        }
    )

    # This list object is stored in results["fusion_mappings"] and extended in
    # place afterwards, so later additions show up in the returned results.
    all_fusion_mappings: list[dict] = []

    for trace1_cid, trace2_cid in matches:
        trace1_comp, trace1_times = _analyze_correlation_group(trace1_large[trace1_cid])
        trace2_comp, trace2_times = _analyze_correlation_group(trace2_large[trace2_cid])

        # Sorted so accumulation (and therefore report) order is reproducible.
        for ktype in sorted(set(trace1_comp) | set(trace2_comp), key=str):
            trace1_count = trace1_comp.get(ktype, 0)
            trace2_count = trace2_comp.get(ktype, 0)

            stats = fusion_diffs[ktype]
            stats["trace1_count"] += trace1_count
            stats["trace2_count"] += trace2_count
            stats["trace1_time_us"] += trace1_times.get(ktype, 0)
            stats["trace2_time_us"] += trace2_times.get(ktype, 0)
            stats["total_groups"] += 1
            if trace1_count != trace2_count:
                stats["groups_with_diff"] += 1

        # Find actual kernel mappings in this correlation group
        group_mappings = _find_fusion_mappings(
            trace1_large[trace1_cid],
            trace2_large[trace2_cid],
            trace1_name=trace1_platform,
            trace2_name=trace2_platform,
        )
        # Add correlation ID context to each mapping
        for mapping in group_mappings:
            mapping["correlation_group_trace1"] = trace1_cid
            mapping["correlation_group_trace2"] = trace2_cid
        all_fusion_mappings.extend(group_mappings)

    # Classify every kernel once. The per-type buckets give the global counts
    # here and the whole-trace timings further down, without rescanning the
    # trace for each kernel type.
    trace1_by_type = _kernels_by_fusion_type(trace1_kernels)
    trace2_by_type = _kernels_by_fusion_type(trace2_kernels)
    global_trace1_counts = {ktype: len(ks) for ktype, ks in trace1_by_type.items()}
    global_trace2_counts = {ktype: len(ks) for ktype, ks in trace2_by_type.items()}

    # Build results
    results: dict[str, Any] = {
        "metadata": {
            "trace1_gpu": trace1_gpu,
            "trace2_gpu": trace2_gpu,
            "trace1_total_kernels": len(trace1_kernels),
            "trace2_total_kernels": len(trace2_kernels),
            "trace1_correlation_groups": len(trace1_large),
            "trace2_correlation_groups": len(trace2_large),
            "matched_groups": len(matches),
        },
        "global_counts": {},
        "fusion_opportunities": [],
        "fusion_mappings": all_fusion_mappings,
    }

    # Global counts for all kernel types. Every type listed here occurs at least
    # once in one of the traces, because it came from classifying real kernels.
    all_ktypes = sorted(set(global_trace1_counts) | set(global_trace2_counts), key=str)
    for ktype in all_ktypes:
        trace1_total = global_trace1_counts.get(ktype, 0)
        trace2_total = global_trace2_counts.get(ktype, 0)
        results["global_counts"][ktype] = {
            "trace1_count": trace1_total,
            "trace2_count": trace2_total,
            "ratio": trace1_total / trace2_total if trace2_total > 0 else float("inf"),
        }

    # Fusion opportunities from matched correlation groups
    for ktype, stats in fusion_diffs.items():
        total_groups = stats["total_groups"]
        trace1_avg = stats["trace1_count"] / total_groups if total_groups > 0 else 0
        trace2_avg = stats["trace2_count"] / total_groups if total_groups > 0 else 0
        trace1_time_ms = stats["trace1_time_us"] / 1000
        trace2_time_ms = stats["trace2_time_us"] / 1000

        diff_ratio = trace1_avg / trace2_avg if trace2_avg > 0 else float("inf")
        reverse_ratio = trace2_avg / trace1_avg if trace1_avg > 0 else float("inf")

        # Significant when one side has >10x as many calls per group (including
        # when the other side has none), or >2x as many with more than 20
        # combined calls so trivial counts are ignored.
        is_significant = (diff_ratio > 10.0 or reverse_ratio > 10.0) or (
            (diff_ratio > 2.0 or reverse_ratio > 2.0)
            and stats["trace1_count"] + stats["trace2_count"] > 20
        )
        if not is_significant:
            continue

        # The side with FEWER calls is credited with fusing more.
        if diff_ratio > 1.5:
            fused_by, ratio = trace2_platform, diff_ratio
        else:
            fused_by, ratio = trace1_platform, reverse_ratio

        # Calculate time ratio (who's faster for this operation)
        time_ratio = trace1_time_ms / trace2_time_ms if trace2_time_ms > 0 else float("inf")

        results["fusion_opportunities"].append(
            {
                "kernel_type": ktype,
                "trace1_total": stats["trace1_count"],
                "trace2_total": stats["trace2_count"],
                "trace1_avg_per_group": trace1_avg,
                "trace2_avg_per_group": trace2_avg,
                "trace1_time_ms": trace1_time_ms,
                "trace2_time_ms": trace2_time_ms,
                "time_ratio": time_ratio,
                "ratio": ratio,
                "fused_by": fused_by,
                "groups_affected": stats["groups_with_diff"],
                "total_groups": total_groups,
            }
        )

    # Whole-trace fallback: catch significant count differences the matched-group
    # pass missed (e.g. types that live in small or differently shaped groups).
    already_reported = {opp["kernel_type"] for opp in results["fusion_opportunities"]}
    for ktype in all_ktypes:
        if ktype in already_reported:
            continue

        trace1_total = global_trace1_counts.get(ktype, 0)
        trace2_total = global_trace2_counts.get(ktype, 0)

        # Require a combined count above 100 to skip trivial types.
        if trace1_total + trace2_total <= 100:
            continue

        global_ratio = trace1_total / trace2_total if trace2_total > 0 else float("inf")
        global_reverse_ratio = trace2_total / trace1_total if trace1_total > 0 else float("inf")
        if not (global_ratio > 2.0 or global_reverse_ratio > 2.0):
            continue

        # Timing over all kernels of this type (not just matched groups), in ms.
        trace1_time = sum(k.get("dur", 0) for k in trace1_by_type.get(ktype, ())) / 1000
        trace2_time = sum(k.get("dur", 0) for k in trace2_by_type.get(ktype, ())) / 1000

        # The side with FEWER calls is credited with fusing more.
        if global_ratio > 1.5:
            fused_by, ratio = trace2_platform, global_ratio
        else:
            fused_by, ratio = trace1_platform, global_reverse_ratio

        time_ratio = trace1_time / trace2_time if trace2_time > 0 else float("inf")

        results["fusion_opportunities"].append(
            {
                "kernel_type": ktype,
                "trace1_total": trace1_total,
                "trace2_total": trace2_total,
                "trace1_avg_per_group": trace1_total / len(matches) if matches else 0,
                "trace2_avg_per_group": trace2_total / len(matches) if matches else 0,
                "trace1_time_ms": trace1_time,
                "trace2_time_ms": trace2_time,
                "time_ratio": time_ratio,
                "ratio": ratio,
                "fused_by": fused_by,
                "groups_affected": 0,  # Unknown for global analysis
                "total_groups": len(matches),
            }
        )

    # Sort by impact (ratio * total count)
    results["fusion_opportunities"].sort(
        key=lambda x: x["ratio"] * (x["trace1_total"] + x["trace2_total"]), reverse=True
    )

    # Partial fusion: types present on both platforms but at different frequencies.
    partial_mappings = _find_partial_fusion_via_groups(
        trace1_large,
        trace2_large,
        matches,
        trace1_name=trace1_platform,
        trace2_name=trace2_platform,
    )
    all_fusion_mappings.extend(partial_mappings)

    # Intra-type fusion (same kernel type fused with itself, like Sort chains).
    # This runs before the whole-trace mapping fallback so that the fallback can
    # skip every type this pass already explains.
    intra_mappings = _detect_intra_type_fusion(
        trace1_kernels,
        trace2_kernels,
        trace1_name=trace1_platform,
        trace2_name=trace2_platform,
    )
    all_fusion_mappings.extend(intra_mappings)
    intra_handled_types = {m["fused_kernel_type"] for m in intra_mappings}

    # Whole-trace mappings for kernels not in large correlation groups.
    global_mappings = _find_fusion_mappings(
        trace1_kernels,
        trace2_kernels,
        trace1_name=trace1_platform,
        trace2_name=trace2_platform,
    )
    # Skip types already handled by intra-type fusion and any mapping whose
    # evidence text was already reported, including an earlier whole-trace one.
    existing_evidence = {m["evidence"] for m in all_fusion_mappings}
    for mapping in global_mappings:
        ktype = mapping["unfused_sequence"][0] if mapping["unfused_sequence"] else None
        if ktype in intra_handled_types or mapping["evidence"] in existing_evidence:
            continue
        all_fusion_mappings.append(mapping)
        existing_evidence.add(mapping["evidence"])

    return results


def _kernels_by_fusion_type(kernels: list[dict]) -> dict[str, list[dict]]:
    """Bucket kernel events by their classified kernel type, preserving trace order."""
    by_type: dict[str, list[dict]] = defaultdict(list)
    for kernel in kernels:
        by_type[_classify_kernel_cached(kernel.get("name", ""))].append(kernel)
    return dict(by_type)
|