wafer-core 0.1.25__py3-none-any.whl → 0.1.26__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer_core/lib/trace_compare/__init__.py +32 -0
- wafer_core/lib/trace_compare/analyzer.py +339 -0
- wafer_core/lib/trace_compare/classifier.py +192 -0
- wafer_core/lib/trace_compare/formatter.py +951 -0
- wafer_core/lib/trace_compare/fusion_analyzer.py +890 -0
- wafer_core/lib/trace_compare/loader.py +336 -0
- wafer_core/problem_config.py +3 -3
- wafer_core/rollouts/agent_presets/rlm_01_01.py +2 -2
- wafer_core/rollouts/dtypes.py +18 -3
- wafer_core/rollouts/providers/anthropic.py +35 -3
- wafer_core/utils/kernel_utils/defense.py +10 -0
- wafer_core/utils/kernel_utils/targets/config.py +10 -0
- {wafer_core-0.1.25.dist-info → wafer_core-0.1.26.dist-info}/METADATA +1 -1
- {wafer_core-0.1.25.dist-info → wafer_core-0.1.26.dist-info}/RECORD +15 -9
- {wafer_core-0.1.25.dist-info → wafer_core-0.1.26.dist-info}/WHEEL +0 -0
wafer_core/lib/trace_compare/fusion_analyzer.py (new file)

@@ -0,0 +1,890 @@

```python
"""Fusion analysis for detecting kernel fusion differences between platforms.

Detects fusion differences between AMD and NVIDIA by analyzing how many kernels
each platform launches for the same logical operations.
"""

import json
from collections import Counter, defaultdict
from pathlib import Path
from typing import Any

from .classifier import classify_kernel


def _load_trace_for_fusion(
    file_path: str | Path,
) -> tuple[str, str, list[dict[str, Any]], dict[int, list[dict[str, Any]]]]:
    """Load trace and group kernels by correlation ID.

    Args:
        file_path: Path to trace file

    Returns:
        Tuple of (platform, gpu_name, all_kernels, corr_groups)
    """
    with open(file_path, "rb") as f:
        trace = json.load(f)

    # Detect platform
    props = trace.get("deviceProperties", [{}])[0]
    is_amd = trace.get("roctracer_version") or props.get("warpSize") == 64
    platform = "AMD" if is_amd else "NVIDIA"
    gpu_name = props.get("name", "MI300X" if is_amd else "Unknown GPU")

    # Get all kernel events
    events = trace.get("traceEvents", [])
    kernels = [e for e in events if e.get("cat") == "kernel"]

    # Group by correlation ID
    corr_groups: dict[int, list[dict[str, Any]]] = defaultdict(list)
    for k in kernels:
        corr_id = k.get("args", {}).get("correlation")
        if corr_id is not None:
            corr_groups[corr_id].append(k)

    return platform, gpu_name, kernels, dict(corr_groups)
```
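For context, the loader expects Kineto/Chrome-trace JSON as exported by the PyTorch profiler. A minimal synthetic trace that exercises it, run from within the module (all field values are illustrative, not from the package):

```python
import json
import tempfile

# Hypothetical minimal trace in the Kineto/Chrome-trace shape the loader reads.
sample_trace = {
    "deviceProperties": [{"name": "MI300X", "warpSize": 64}],
    "roctracer_version": "4.1",  # presence of this key marks the trace as AMD
    "traceEvents": [
        {"cat": "kernel", "name": "gemm_kernel", "ts": 100, "dur": 25,
         "args": {"correlation": 7}},
        {"cat": "kernel", "name": "reduce_kernel", "ts": 130, "dur": 5,
         "args": {"correlation": 7}},
    ],
}

with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump(sample_trace, f)

platform, gpu_name, kernels, corr_groups = _load_trace_for_fusion(f.name)
assert platform == "AMD"
assert gpu_name == "MI300X"
assert len(corr_groups[7]) == 2  # both kernels share correlation ID 7
```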
```python
def _analyze_correlation_group(
    kernels: list[dict[str, Any]],
) -> tuple[dict[str, int], dict[str, float]]:
    """Analyze kernel composition within a correlation group.

    Args:
        kernels: List of kernel events in the group

    Returns:
        Tuple of (counts, timings) where counts maps kernel types to counts
        and timings maps kernel types to total duration in microseconds
    """
    counts: Counter[str] = Counter()
    timings: dict[str, float] = defaultdict(float)

    for k in kernels:
        kernel_type = classify_kernel(k.get("name", ""))
        counts[kernel_type] += 1
        timings[kernel_type] += k.get("dur", 0)

    return dict(counts), dict(timings)
```
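On a toy group this reduces to per-type counts and total durations. Kernel names and the resulting labels below are hypothetical; the real labels come from `classifier.py`, which this release also adds:

```python
# Hypothetical events; actual type labels depend on classify_kernel.
group = [
    {"name": "gemm_a", "dur": 10.0},
    {"name": "gemm_b", "dur": 12.0},
    {"name": "reduce_x", "dur": 3.0},
]
counts, timings = _analyze_correlation_group(group)
# Plausible result if the classifier labels these "GEMM" and "Reduce":
#   counts  == {"GEMM": 2, "Reduce": 1}
#   timings == {"GEMM": 22.0, "Reduce": 3.0}   # microseconds
```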
```python
def _match_correlation_groups(
    amd_groups: dict[int, list[dict[str, Any]]],
    nv_groups: dict[int, list[dict[str, Any]]],
    size_tolerance: float = 0.25,
) -> list[tuple[int, int]]:
    """Match AMD and NVIDIA correlation groups by size and composition.

    Since correlation IDs don't match between platforms, we match groups that
    have similar sizes (within tolerance) and likely represent the same operation.

    Args:
        amd_groups: AMD correlation groups
        nv_groups: NVIDIA correlation groups
        size_tolerance: Groups match if sizes are within this fraction (25% default, increased for better matching)

    Returns:
        List of (amd_corr_id, nv_corr_id) pairs
    """
    matches = []
    used_nv_ids: set[int] = set()

    # Pre-compute compositions for all groups to avoid redundant analysis
    amd_comps = {id: _analyze_correlation_group(kernels)[0] for id, kernels in amd_groups.items()}
    nv_comps = {id: _analyze_correlation_group(kernels)[0] for id, kernels in nv_groups.items()}

    # Build size-indexed lookup for NVIDIA groups for faster filtering
    nv_by_size = [(id, len(kernels)) for id, kernels in nv_groups.items()]

    # Sort AMD groups by size (largest first for better matching)
    amd_sorted = sorted(amd_groups.items(), key=lambda x: len(x[1]), reverse=True)

    for amd_id, amd_kernels in amd_sorted:
        amd_size = len(amd_kernels)
        amd_comp = amd_comps[amd_id]

        # Calculate size bounds for matching
        min_size = amd_size / (1 + size_tolerance)
        max_size = amd_size * (1 + size_tolerance)

        # Find best matching NVIDIA group
        best_match = None
        best_score = float("inf")

        for nv_id, nv_size in nv_by_size:
            if nv_id in used_nv_ids:
                continue

            # Quick size filter
            if nv_size < min_size or nv_size > max_size:
                continue

            # Use pre-computed composition
            nv_comp = nv_comps[nv_id]

            # Simple similarity: number of shared kernel types
            shared_types = set(amd_comp.keys()) & set(nv_comp.keys())
            similarity = len(shared_types)

            # Prefer matches with more shared types and closer sizes
            score = abs(amd_size - nv_size) - (similarity * 10)

            if score < best_score:
                best_score = score
                best_match = nv_id

        if best_match is not None:
            matches.append((amd_id, best_match))
            used_nv_ids.add(best_match)

    return matches
```
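The matcher is greedy: the largest AMD groups pick first, and the score trades size distance against a ten-point bonus per shared kernel type. Isolating the formula (same arithmetic as above, names mine):

```python
def match_score(amd_size: int, nv_size: int, shared_types: int) -> int:
    # Lower is better: size distance minus 10 points per shared kernel type.
    return abs(amd_size - nv_size) - shared_types * 10

# A 95-kernel NVIDIA group sharing 3 types beats an exact-size match sharing none.
assert match_score(100, 95, 3) == -25
assert match_score(100, 100, 0) == 0
```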
```python
def _find_fusion_mappings(
    trace1_kernels: list[dict],
    trace2_kernels: list[dict],
    trace1_name: str = "Trace1",
    trace2_name: str = "Trace2",
) -> list[dict]:
    """Find fusion mappings by analyzing kernel execution sequence patterns.

    This function identifies when one platform runs multiple kernels separately
    while the other platform fuses them into a single kernel.

    Args:
        trace1_kernels: List of kernel events from first trace
        trace2_kernels: List of kernel events from second trace
        trace1_name: Name of first platform (e.g., "AMD")
        trace2_name: Name of second platform (e.g., "NVIDIA")

    Returns:
        List of mapping dictionaries, each containing:
        - fused_platform: Which platform fuses the operations
        - fused_kernel_type: The single fused kernel type
        - unfused_platform: Which platform runs them separately
        - unfused_sequence: List of kernel types run separately
        - pattern_count: How many times this pattern appears
        - pattern_confidence: Fraction of occurrences following this pattern
        - evidence: Human-readable description
    """
    from collections import defaultdict
    from wafer_core.lib.trace_compare.classifier import classify_kernel

    mappings = []

    # Sort kernels by timestamp
    trace1_sorted = sorted(trace1_kernels, key=lambda k: k.get('ts', 0))
    trace2_sorted = sorted(trace2_kernels, key=lambda k: k.get('ts', 0))

    # Classify all kernels
    trace1_types = [classify_kernel(k.get('name', '')) for k in trace1_sorted]
    trace2_types = [classify_kernel(k.get('name', '')) for k in trace2_sorted]

    # Find kernel types unique to each trace
    trace1_type_set = set(trace1_types)
    trace2_type_set = set(trace2_types)

    trace1_only = trace1_type_set - trace2_type_set
    trace2_only = trace2_type_set - trace1_type_set

    # For each unique type in trace1, find common sequence patterns
    for unique_type in trace1_only:
        # Find all occurrences of this type
        indices = [i for i, t in enumerate(trace1_types) if t == unique_type]

        if len(indices) < 5:  # Need enough samples to be meaningful
            continue

        # Analyze what comes before/after each occurrence
        before_types = defaultdict(int)

        for idx in indices:
            if idx > 0:
                before_types[trace1_types[idx - 1]] += 1

        # Find the most common pattern (e.g., "Attention → Reduce")
        most_common_before = max(before_types.items(), key=lambda x: x[1]) if before_types else (None, 0)

        # If there's a strong pattern (>80% of occurrences)
        if most_common_before[1] / len(indices) > 0.8:
            # This suggests: Trace2 likely fuses [before_type + unique_type] into [before_type]
            fusion_candidate = most_common_before[0]

            # Verify trace2 has this type
            if fusion_candidate in trace2_type_set:
                # Count occurrences to compare
                trace1_fusion_count = trace1_types.count(fusion_candidate)
                trace2_fusion_count = trace2_types.count(fusion_candidate)

                mappings.append({
                    "fused_platform": trace2_name,
                    "fused_kernel_type": fusion_candidate,
                    "fused_count": trace2_fusion_count,
                    "unfused_platform": trace1_name,
                    "unfused_sequence": [fusion_candidate, unique_type],
                    "unfused_count_per_type": {
                        fusion_candidate: trace1_fusion_count,
                        unique_type: len(indices)
                    },
                    "pattern_count": len(indices),
                    "pattern_confidence": most_common_before[1] / len(indices),
                    "evidence": f"{trace1_name} runs {fusion_candidate}+{unique_type} separately, {trace2_name} fuses into {fusion_candidate}"
                })

    # Also check trace2-only types
    for unique_type in trace2_only:
        indices = [i for i, t in enumerate(trace2_types) if t == unique_type]

        if len(indices) < 5:
            continue

        before_types = defaultdict(int)

        for idx in indices:
            if idx > 0:
                before_types[trace2_types[idx - 1]] += 1

        most_common_before = max(before_types.items(), key=lambda x: x[1]) if before_types else (None, 0)

        if most_common_before[1] / len(indices) > 0.8:
            fusion_candidate = most_common_before[0]

            if fusion_candidate in trace1_type_set:
                trace1_fusion_count = trace1_types.count(fusion_candidate)
                trace2_fusion_count = trace2_types.count(fusion_candidate)

                mappings.append({
                    "fused_platform": trace1_name,
                    "fused_kernel_type": fusion_candidate,
                    "fused_count": trace1_fusion_count,
                    "unfused_platform": trace2_name,
                    "unfused_sequence": [fusion_candidate, unique_type],
                    "unfused_count_per_type": {
                        fusion_candidate: trace2_fusion_count,
                        unique_type: len(indices)
                    },
                    "pattern_count": len(indices),
                    "pattern_confidence": most_common_before[1] / len(indices),
                    "evidence": f"{trace2_name} runs {fusion_candidate}+{unique_type} separately, {trace1_name} fuses into {fusion_candidate}"
                })

    # NEW: Check for partial fusion patterns (kernel exists on both platforms but with different counts)
    # If one platform has significantly fewer calls (>1.3x ratio), look for fusion patterns
    common_types = trace1_type_set & trace2_type_set

    for ktype in common_types:
        trace1_count = trace1_types.count(ktype)
        trace2_count = trace2_types.count(ktype)

        # Check if there's a significant imbalance (one has >1.3x more)
        if trace1_count == 0 or trace2_count == 0:
            continue

        ratio = max(trace1_count, trace2_count) / min(trace1_count, trace2_count)

        if ratio < 1.3 or trace1_count + trace2_count < 100:
            continue

        # Determine which platform has more (unfused) and which has fewer (fused)
        if trace1_count > trace2_count:
            # Trace1 has more separate calls, Trace2 likely fuses
            unfused_platform = trace1_name
            fused_platform = trace2_name
            unfused_count = trace1_count
            fused_count = trace2_count

            # Find what Trace2 might be fusing this into
            # Use the most common kernel type in Trace2 as a likely fusion target
            trace2_type_counts = defaultdict(int)
            for t in trace2_types:
                if t != ktype and t != "Other":  # Skip the imbalanced type itself and "Other"
                    trace2_type_counts[t] += 1

            if trace2_type_counts:
                # Use the most common type as the fusion target
                fusion_target = max(trace2_type_counts.items(), key=lambda x: x[1])[0]

                mappings.append({
                    "fused_platform": fused_platform,
                    "fused_kernel_type": fusion_target,
                    "fused_count": fused_count,
                    "unfused_platform": unfused_platform,
                    "unfused_sequence": [ktype],
                    "unfused_count_per_type": {
                        ktype: unfused_count
                    },
                    "pattern_count": unfused_count - fused_count,
                    "pattern_confidence": (unfused_count - fused_count) / unfused_count,
                    "evidence": f"{unfused_platform} calls {ktype} {ratio:.1f}x more ({unfused_count} vs {fused_count}), {fused_platform} likely fuses into {fusion_target}"
                })
        else:
            # Trace2 has more separate calls, Trace1 likely fuses
            unfused_platform = trace2_name
            fused_platform = trace1_name
            unfused_count = trace2_count
            fused_count = trace1_count

            # Find what Trace1 might be fusing this into
            # Use the most common kernel type in Trace1 as a likely fusion target
            trace1_type_counts = defaultdict(int)
            for t in trace1_types:
                if t != ktype and t != "Other":  # Skip the imbalanced type itself and "Other"
                    trace1_type_counts[t] += 1

            if trace1_type_counts:
                fusion_target = max(trace1_type_counts.items(), key=lambda x: x[1])[0]

                mappings.append({
                    "fused_platform": fused_platform,
                    "fused_kernel_type": fusion_target,
                    "fused_count": fused_count,
                    "unfused_platform": unfused_platform,
                    "unfused_sequence": [ktype],
                    "unfused_count_per_type": {
                        ktype: unfused_count
                    },
                    "pattern_count": unfused_count - fused_count,
                    "pattern_confidence": (unfused_count - fused_count) / unfused_count,
                    "evidence": f"{unfused_platform} calls {ktype} {ratio:.1f}x more ({unfused_count} vs {fused_count}), {fused_platform} likely fuses into {fusion_target}"
                })

    return mappings
```
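Each detected mapping is a plain dict. For the docstring's "Attention → Reduce" pattern, a result could look like this (all values invented for illustration):

```python
example_mapping = {
    "fused_platform": "NVIDIA",
    "fused_kernel_type": "Attention",
    "fused_count": 480,
    "unfused_platform": "AMD",
    "unfused_sequence": ["Attention", "Reduce"],
    "unfused_count_per_type": {"Attention": 480, "Reduce": 470},
    "pattern_count": 470,
    "pattern_confidence": 0.97,
    "evidence": "AMD runs Attention+Reduce separately, NVIDIA fuses into Attention",
}
```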
```python
def _detect_intra_type_fusion(
    trace1_kernels: list[dict],
    trace2_kernels: list[dict],
    trace1_name: str,
    trace2_name: str,
) -> list[dict]:
    """Detect intra-type fusion where consecutive same-type kernels are fused.

    Example: AMD runs Sort→Sort→Sort (42 calls) while NVIDIA runs Sort→Sort (10 calls).
    This indicates NVIDIA has a more efficient Sort implementation that fuses operations.
    """
    from wafer_core.lib.trace_compare.classifier import classify_kernel

    def analyze_chains(kernels):
        """Find chains of consecutive same-type kernels"""
        sorted_kernels = sorted(kernels, key=lambda k: k.get('ts', 0))
        types = [classify_kernel(k['name']) for k in sorted_kernels]

        chains = defaultdict(list)
        i = 0
        while i < len(types):
            ktype = types[i]
            count = 0
            while i < len(types) and types[i] == ktype:
                count += 1
                i += 1
            chains[ktype].append(count)

        return chains

    trace1_chains = analyze_chains(trace1_kernels)
    trace2_chains = analyze_chains(trace2_kernels)

    mappings = []
    all_types = set(trace1_chains.keys()) | set(trace2_chains.keys())

    for ktype in all_types:
        t1_lengths = trace1_chains.get(ktype, [])
        t2_lengths = trace2_chains.get(ktype, [])

        # Skip if not enough data
        if len(t1_lengths) < 5 and len(t2_lengths) < 5:
            continue

        # Filter to chains with multiple kernels
        t1_multi = [l for l in t1_lengths if l > 1]
        t2_multi = [l for l in t2_lengths if l > 1]

        if not t1_multi and not t2_multi:
            continue

        t1_total = sum(t1_lengths)
        t2_total = sum(t2_lengths)
        t1_chains = len(t1_multi) if t1_multi else len(t1_lengths)
        t2_chains = len(t2_multi) if t2_multi else len(t2_lengths)

        if t1_chains == 0 or t2_chains == 0:
            continue

        t1_avg_chain = sum(t1_multi) / len(t1_multi) if t1_multi else 1.0
        t2_avg_chain = sum(t2_multi) / len(t2_multi) if t2_multi else 1.0

        chain_ratio = max(t1_avg_chain, t2_avg_chain) / min(t1_avg_chain, t2_avg_chain)

        # Significant intra-fusion if chains are 2x+ different
        if chain_ratio > 2.0 and abs(t1_total - t2_total) > 100:
            if t1_avg_chain > t2_avg_chain:
                unfused_platform = trace1_name
                fused_platform = trace2_name
                unfused_chains = t1_chains
                fused_chains = t2_chains
                unfused_avg = t1_avg_chain
                fused_avg = t2_avg_chain
                unfused_total = t1_total
                fused_total = t2_total
            else:
                unfused_platform = trace2_name
                fused_platform = trace1_name
                unfused_chains = t2_chains
                fused_chains = t1_chains
                unfused_avg = t2_avg_chain
                fused_avg = t1_avg_chain
                unfused_total = t2_total
                fused_total = t1_total

            mappings.append({
                "fused_platform": fused_platform,
                "fused_kernel_type": ktype,
                "fused_count": fused_total,
                "unfused_platform": unfused_platform,
                "unfused_sequence": [ktype, ktype],  # Same type repeated
                "unfused_count_per_type": {ktype: unfused_total},
                "pattern_count": unfused_total - fused_total,
                "pattern_confidence": min(unfused_chains, fused_chains) / max(unfused_chains, fused_chains),
                "evidence": f"{unfused_platform} runs {ktype} in chains of {unfused_avg:.0f} calls ({unfused_chains} chains, {unfused_total:,} total), {fused_platform} fuses to {fused_avg:.0f} calls ({fused_chains} chains, {fused_total:,} total) - {chain_ratio:.1f}x more efficient"
            })

    return mappings
```
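The nested `analyze_chains` helper is run-length encoding over the classified kernel stream. A standalone re-implementation with a worked toy case (type labels invented):

```python
from collections import defaultdict

def chain_lengths(types: list[str]) -> dict[str, list[int]]:
    """Run-length encode consecutive identical types, mirroring analyze_chains."""
    chains: dict[str, list[int]] = defaultdict(list)
    i = 0
    while i < len(types):
        ktype, start = types[i], i
        while i < len(types) and types[i] == ktype:
            i += 1
        chains[ktype].append(i - start)
    return dict(chains)

assert chain_lengths(["Sort", "Sort", "Sort", "GEMM", "Sort"]) == {
    "Sort": [3, 1],
    "GEMM": [1],
}
```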
```python
def _find_partial_fusion_via_groups(
    trace1_large: dict[int, list[dict]],
    trace2_large: dict[int, list[dict]],
    matches: list[tuple[int, int]],
    trace1_name: str,
    trace2_name: str,
) -> list[dict]:
    """Find partial fusion patterns by analyzing correlation group differences.

    When one platform has fewer of a kernel type, check what kernel types the
    other platform has MORE of in those same groups - those are likely fusion targets.
    """
    from collections import Counter
    from wafer_core.lib.trace_compare.classifier import classify_kernel

    mappings = []

    # For each matched pair, track kernel type counts
    trace1_all_types = []
    trace2_all_types = []

    for trace1_cid, trace2_cid in matches:
        trace1_ktypes = [classify_kernel(k.get("name", "")) for k in trace1_large[trace1_cid]]
        trace2_ktypes = [classify_kernel(k.get("name", "")) for k in trace2_large[trace2_cid]]
        trace1_all_types.extend(trace1_ktypes)
        trace2_all_types.extend(trace2_ktypes)

    # Find kernel types with significant imbalances
    trace1_counts = Counter(trace1_all_types)
    trace2_counts = Counter(trace2_all_types)
    all_types = set(trace1_counts.keys()) | set(trace2_counts.keys())

    for ktype in all_types:
        trace1_count = trace1_counts.get(ktype, 0)
        trace2_count = trace2_counts.get(ktype, 0)

        if trace1_count == 0 or trace2_count == 0:
            continue  # Handled by sequence-based detection

        ratio = max(trace1_count, trace2_count) / min(trace1_count, trace2_count)

        if ratio < 1.3 or trace1_count + trace2_count < 100:
            continue  # Not significant

        # Determine which platform has fewer (fuses more)
        if trace1_count > trace2_count:
            unfused_platform = trace1_name
            fused_platform = trace2_name
            unfused_count = trace1_count
            fused_count = trace2_count

            # Find groups where trace1 has this kernel but trace2 doesn't
            fusion_targets = Counter()
            groups_analyzed = 0

            for trace1_cid, trace2_cid in matches:
                trace1_ktypes = [classify_kernel(k.get("name", "")) for k in trace1_large[trace1_cid]]
                trace2_ktypes = [classify_kernel(k.get("name", "")) for k in trace2_large[trace2_cid]]

                trace1_has = ktype in trace1_ktypes
                trace2_has = ktype in trace2_ktypes

                if trace1_has and not trace2_has:
                    # What does trace2 have MORE of in this group?
                    trace1_kcounts = Counter(trace1_ktypes)
                    trace2_kcounts = Counter(trace2_ktypes)

                    for other_type in set(trace2_kcounts.keys()):
                        if other_type == ktype or other_type == "Other":
                            continue
                        diff = trace2_kcounts[other_type] - trace1_kcounts.get(other_type, 0)
                        if diff > 0:
                            fusion_targets[other_type] += diff

                    groups_analyzed += 1

            if fusion_targets and groups_analyzed >= 5:
                # Report top fusion targets
                top_targets = fusion_targets.most_common(3)
                target_str = ", ".join(f"{t[0]} (+{t[1]})" for t in top_targets)

                mappings.append({
                    "fused_platform": fused_platform,
                    "fused_kernel_type": top_targets[0][0],
                    "fused_count": fused_count,
                    "unfused_platform": unfused_platform,
                    "unfused_sequence": [ktype],
                    "unfused_count_per_type": {ktype: unfused_count},
                    "pattern_count": unfused_count - fused_count,
                    "pattern_confidence": groups_analyzed / len(matches) if matches else 0,
                    "evidence": f"{unfused_platform} calls {ktype} {ratio:.1f}x more ({unfused_count} vs {fused_count}). In {groups_analyzed} groups where {unfused_platform} has {ktype}, {fused_platform} has more: {target_str}"
                })
        else:
            # Symmetric case for trace2 > trace1
            unfused_platform = trace2_name
            fused_platform = trace1_name
            unfused_count = trace2_count
            fused_count = trace1_count

            fusion_targets = Counter()
            groups_analyzed = 0

            for trace1_cid, trace2_cid in matches:
                trace1_ktypes = [classify_kernel(k.get("name", "")) for k in trace1_large[trace1_cid]]
                trace2_ktypes = [classify_kernel(k.get("name", "")) for k in trace2_large[trace2_cid]]

                trace1_has = ktype in trace1_ktypes
                trace2_has = ktype in trace2_ktypes

                if trace2_has and not trace1_has:
                    trace1_kcounts = Counter(trace1_ktypes)
                    trace2_kcounts = Counter(trace2_ktypes)

                    for other_type in set(trace1_kcounts.keys()):
                        if other_type == ktype or other_type == "Other":
                            continue
                        diff = trace1_kcounts[other_type] - trace2_kcounts.get(other_type, 0)
                        if diff > 0:
                            fusion_targets[other_type] += diff

                    groups_analyzed += 1

            if fusion_targets and groups_analyzed >= 5:
                top_targets = fusion_targets.most_common(3)
                target_str = ", ".join(f"{t[0]} (+{t[1]})" for t in top_targets)

                mappings.append({
                    "fused_platform": fused_platform,
                    "fused_kernel_type": top_targets[0][0],
                    "fused_count": fused_count,
                    "unfused_platform": unfused_platform,
                    "unfused_sequence": [ktype],
                    "unfused_count_per_type": {ktype: unfused_count},
                    "pattern_count": unfused_count - fused_count,
                    "pattern_confidence": groups_analyzed / len(matches) if matches else 0,
                    "evidence": f"{unfused_platform} calls {ktype} {ratio:.1f}x more ({unfused_count} vs {fused_count}). In {groups_analyzed} groups where {unfused_platform} has {ktype}, {fused_platform} has more: {target_str}"
                })

    return mappings
```
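The per-group differential is Counter arithmetic: when one side's group lacks the imbalanced type, whatever the other side has extra of is credited as a fusion target. With toy numbers (a hypothetical matched pair where AMD's group has Sort and NVIDIA's doesn't):

```python
from collections import Counter

amd_group = Counter({"Sort": 3, "GEMM": 10})
nv_group = Counter({"GEMM": 12, "Scan": 1})

fusion_targets = Counter()
for other_type, nv_count in nv_group.items():
    diff = nv_count - amd_group.get(other_type, 0)
    if diff > 0:
        fusion_targets[other_type] += diff  # NVIDIA has these "extra"

assert fusion_targets == Counter({"GEMM": 2, "Scan": 1})  # likely fusion targets
```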
```python
def analyze_fusion_differences(
    amd_trace_path: str | Path,
    nv_trace_path: str | Path,
    min_group_size: int = 50,
) -> dict[str, Any]:
    """Main fusion analysis function.

    Args:
        amd_trace_path: Path to AMD trace
        nv_trace_path: Path to NVIDIA trace
        min_group_size: Only analyze correlation groups with at least this many kernels

    Returns:
        Dictionary with analysis results containing:
        - metadata: trace info
        - global_counts: kernel type distribution across entire trace
        - fusion_opportunities: significant fusion differences
        - fusion_mappings: actual kernel-to-kernel mappings (NEW)
    """
    # Load traces (maintain order - don't swap)
    trace1_platform, trace1_gpu, trace1_kernels, trace1_corr_groups = _load_trace_for_fusion(
        amd_trace_path
    )
    trace2_platform, trace2_gpu, trace2_kernels, trace2_corr_groups = _load_trace_for_fusion(
        nv_trace_path
    )

    # Filter to "significant" correlation groups
    trace1_large = {
        cid: kernels
        for cid, kernels in trace1_corr_groups.items()
        if len(kernels) >= min_group_size
    }
    trace2_large = {
        cid: kernels
        for cid, kernels in trace2_corr_groups.items()
        if len(kernels) >= min_group_size
    }

    # Match correlation groups between platforms
    matches = _match_correlation_groups(trace1_large, trace2_large)

    # Analyze differences in matched groups
    fusion_diffs: dict[str, dict[str, Any]] = defaultdict(
        lambda: {
            "trace1_count": 0,
            "trace2_count": 0,
            "trace1_time_us": 0,
            "trace2_time_us": 0,
            "groups_with_diff": 0,
            "total_groups": 0,
        }
    )

    # NEW: Collect actual fusion mappings
    all_fusion_mappings = []

    for trace1_cid, trace2_cid in matches:
        trace1_comp, trace1_times = _analyze_correlation_group(trace1_large[trace1_cid])
        trace2_comp, trace2_times = _analyze_correlation_group(trace2_large[trace2_cid])

        # Find all kernel types in either platform
        all_types = set(trace1_comp.keys()) | set(trace2_comp.keys())

        for ktype in all_types:
            trace1_count = trace1_comp.get(ktype, 0)
            trace2_count = trace2_comp.get(ktype, 0)
            trace1_time = trace1_times.get(ktype, 0)
            trace2_time = trace2_times.get(ktype, 0)

            fusion_diffs[ktype]["trace1_count"] += trace1_count
            fusion_diffs[ktype]["trace2_count"] += trace2_count
            fusion_diffs[ktype]["trace1_time_us"] += trace1_time
            fusion_diffs[ktype]["trace2_time_us"] += trace2_time
            fusion_diffs[ktype]["total_groups"] += 1

            if trace1_count != trace2_count:
                fusion_diffs[ktype]["groups_with_diff"] += 1

        # NEW: Find actual kernel mappings in this correlation group
        group_mappings = _find_fusion_mappings(
            trace1_large[trace1_cid],
            trace2_large[trace2_cid],
            trace1_name=trace1_platform,
            trace2_name=trace2_platform
        )
        # Add correlation ID context to each mapping
        for mapping in group_mappings:
            mapping["correlation_group_trace1"] = trace1_cid
            mapping["correlation_group_trace2"] = trace2_cid
        all_fusion_mappings.extend(group_mappings)

    # Also get global counts for context
    global_trace1_counts: Counter[str] = Counter(
        [classify_kernel(k.get("name", "")) for k in trace1_kernels]
    )
    global_trace2_counts: Counter[str] = Counter(
        [classify_kernel(k.get("name", "")) for k in trace2_kernels]
    )

    # Build results
    results: dict[str, Any] = {
        "metadata": {
            "trace1_gpu": trace1_gpu,
            "trace2_gpu": trace2_gpu,
            "trace1_total_kernels": len(trace1_kernels),
            "trace2_total_kernels": len(trace2_kernels),
            "trace1_correlation_groups": len(trace1_large),
            "trace2_correlation_groups": len(trace2_large),
            "matched_groups": len(matches),
        },
        "global_counts": {},
        "fusion_opportunities": [],
        "fusion_mappings": all_fusion_mappings,  # NEW: Include actual mappings
    }

    # Global counts for all kernel types
    all_ktypes = set(global_trace1_counts.keys()) | set(global_trace2_counts.keys())
    for ktype in all_ktypes:
        trace1_total = global_trace1_counts.get(ktype, 0)
        trace2_total = global_trace2_counts.get(ktype, 0)

        if trace1_total > 0 or trace2_total > 0:
            results["global_counts"][ktype] = {
                "trace1_count": trace1_total,
                "trace2_count": trace2_total,
                "ratio": trace1_total / trace2_total if trace2_total > 0 else float("inf"),
            }

    # Fusion opportunities from matched correlation groups
    for ktype, stats in fusion_diffs.items():
        trace1_avg = (
            stats["trace1_count"] / stats["total_groups"]
            if stats["total_groups"] > 0
            else 0
        )
        trace2_avg = (
            stats["trace2_count"] / stats["total_groups"]
            if stats["total_groups"] > 0
            else 0
        )
        trace1_time_ms = stats["trace1_time_us"] / 1000
        trace2_time_ms = stats["trace2_time_us"] / 1000

        # Calculate significance
        diff_ratio = trace1_avg / trace2_avg if trace2_avg > 0 else float("inf")
        reverse_ratio = trace2_avg / trace1_avg if trace1_avg > 0 else float("inf")

        # Only report if there's a significant difference
        # Either: one platform has it and the other doesn't (ratio > 10)
        # Or: one platform has significantly more (ratio > 2.0)
        is_significant = (
            (diff_ratio > 10.0 or reverse_ratio > 10.0)  # One platform doesn't have it
            or (
                (diff_ratio > 2.0 or reverse_ratio > 2.0)  # Significant difference
                and stats["trace1_count"] + stats["trace2_count"] > 20  # Not trivial counts
            )
        )

        if is_significant:
            # Determine who fuses (who has FEWER calls = more fusion)
            if diff_ratio > 1.5:
                fused_by = trace2_platform  # Trace 2 has fewer calls, so it fuses more
                ratio = diff_ratio
            else:
                fused_by = trace1_platform  # Trace 1 has fewer calls, so it fuses more
                ratio = reverse_ratio

            # Calculate time ratio (who's faster for this operation)
            time_ratio = trace1_time_ms / trace2_time_ms if trace2_time_ms > 0 else float("inf")

            results["fusion_opportunities"].append(
                {
                    "kernel_type": ktype,
                    "trace1_total": stats["trace1_count"],
                    "trace2_total": stats["trace2_count"],
                    "trace1_avg_per_group": trace1_avg,
                    "trace2_avg_per_group": trace2_avg,
                    "trace1_time_ms": trace1_time_ms,
                    "trace2_time_ms": trace2_time_ms,
                    "time_ratio": time_ratio,
                    "ratio": ratio,
                    "fused_by": fused_by,
                    "groups_affected": stats["groups_with_diff"],
                    "total_groups": stats["total_groups"],
                }
            )

    # ADDITIONAL: Check global counts for significant differences not captured above
    # This catches patterns like Sort that may be in small groups or distributed differently
    for ktype in all_ktypes:
        # Skip if already added from correlation group analysis
        if any(opp["kernel_type"] == ktype for opp in results["fusion_opportunities"]):
            continue

        trace1_total = global_trace1_counts.get(ktype, 0)
        trace2_total = global_trace2_counts.get(ktype, 0)

        # Skip trivial counts
        if trace1_total + trace2_total < 100:
            continue

        # Calculate global ratio
        global_ratio = trace1_total / trace2_total if trace2_total > 0 else float("inf")
        global_reverse_ratio = trace2_total / trace1_total if trace1_total > 0 else float("inf")

        # Check if globally significant (more aggressive threshold for comprehensive detection)
        is_globally_significant = (
            (global_ratio > 2.0 or global_reverse_ratio > 2.0)
            and (trace1_total + trace2_total > 100)
        )

        if is_globally_significant:
            # Get timing info from all kernels (not just matched groups)
            trace1_time = sum(
                k.get("dur", 0) for k in trace1_kernels
                if classify_kernel(k.get("name", "")) == ktype
            ) / 1000  # Convert to ms
            trace2_time = sum(
                k.get("dur", 0) for k in trace2_kernels
                if classify_kernel(k.get("name", "")) == ktype
            ) / 1000

            # Determine who fuses (who has FEWER calls = more fusion)
            if global_ratio > 1.5:
                fused_by = "Trace 2"  # Trace 2 has fewer calls
                ratio = global_ratio
            else:
                fused_by = "Trace 1"  # Trace 1 has fewer calls
                ratio = global_reverse_ratio

            time_ratio = trace1_time / trace2_time if trace2_time > 0 else float("inf")

            results["fusion_opportunities"].append(
                {
                    "kernel_type": ktype,
                    "trace1_total": trace1_total,
                    "trace2_total": trace2_total,
                    "trace1_avg_per_group": trace1_total / len(matches) if matches else 0,
                    "trace2_avg_per_group": trace2_total / len(matches) if matches else 0,
                    "trace1_time_ms": trace1_time,
                    "trace2_time_ms": trace2_time,
                    "time_ratio": time_ratio,
                    "ratio": ratio,
                    "fused_by": fused_by,
                    "groups_affected": 0,  # Unknown for global analysis
                    "total_groups": len(matches),
                }
            )

    # Sort by impact (ratio * total count)
    results["fusion_opportunities"].sort(
        key=lambda x: x["ratio"] * (x["trace1_total"] + x["trace2_total"]), reverse=True
    )

    # ADD PARTIAL FUSION MAPPINGS using correlation group differential analysis
    # This catches patterns like Sort that exist on both platforms but with different frequencies
    partial_mappings = _find_partial_fusion_via_groups(
        trace1_large,
        trace2_large,
        matches,
        trace1_name=trace1_platform,
        trace2_name=trace2_platform
    )
    all_fusion_mappings.extend(partial_mappings)

    # DETECT INTRA-TYPE FUSION (same kernel type fused with itself, like Sort chains)
    # Do this FIRST since it's more accurate than the fallback global analysis
    intra_mappings = _detect_intra_type_fusion(
        trace1_kernels,
        trace2_kernels,
        trace1_name=trace1_platform,
        trace2_name=trace2_platform
    )
    all_fusion_mappings.extend(intra_mappings)

    # Collect kernel types already handled by intra-type fusion
    intra_handled_types = set(m["fused_kernel_type"] for m in intra_mappings)

    # ALSO ADD GLOBAL FUSION MAPPINGS for kernels not in large correlation groups
    # Skip types already handled by intra-type fusion (more accurate)
    global_mappings = _find_fusion_mappings(
        trace1_kernels,
        trace2_kernels,
        trace1_name=trace1_platform,
        trace2_name=trace2_platform
    )
    # Filter: skip if already handled or if evidence is duplicate
    existing_evidence = set(m["evidence"] for m in all_fusion_mappings)
    for mapping in global_mappings:
        ktype = mapping["unfused_sequence"][0] if mapping["unfused_sequence"] else None
        if ktype not in intra_handled_types and mapping["evidence"] not in existing_evidence:
            all_fusion_mappings.append(mapping)

    return results
```
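A minimal driver for the public entry point, assuming two profiler traces on disk (the paths are placeholders; the result keys come from the function above):

```python
from wafer_core.lib.trace_compare.fusion_analyzer import analyze_fusion_differences

results = analyze_fusion_differences("amd_trace.json", "nv_trace.json")

meta = results["metadata"]
print(f'{meta["matched_groups"]} correlation groups matched')

for opp in results["fusion_opportunities"][:5]:
    print(f'{opp["kernel_type"]}: {opp["ratio"]:.1f}x call ratio, fused by {opp["fused_by"]}')

for mapping in results["fusion_mappings"][:5]:
    print(mapping["evidence"])
```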