wafer-core 0.1.26__py3-none-any.whl → 0.1.27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer_core/lib/trace_compare/PERFORMANCE.md +148 -0
- wafer_core/lib/trace_compare/__init__.py +22 -9
- wafer_core/lib/trace_compare/aligner.py +369 -0
- wafer_core/lib/trace_compare/analyzer.py +549 -159
- wafer_core/lib/trace_compare/api.py +225 -0
- wafer_core/lib/trace_compare/architecture.py +77 -0
- wafer_core/lib/trace_compare/classifier.py +307 -13
- wafer_core/lib/trace_compare/fusion_analyzer.py +311 -845
- wafer_core/lib/trace_compare/kernel_registry.yaml +349 -0
- wafer_core/lib/trace_compare/layer_segmentation.py +114 -0
- wafer_core/lib/trace_compare/loader.py +526 -227
- wafer_core/lib/trace_compare/same_kernel_analyzer.py +119 -0
- wafer_core/lib/trace_compare/warnings.py +99 -0
- {wafer_core-0.1.26.dist-info → wafer_core-0.1.27.dist-info}/METADATA +3 -1
- {wafer_core-0.1.26.dist-info → wafer_core-0.1.27.dist-info}/RECORD +16 -8
- {wafer_core-0.1.26.dist-info → wafer_core-0.1.27.dist-info}/WHEEL +0 -0
|
@@ -1,890 +1,356 @@
|
|
|
1
|
-
"""Fusion analysis
|
|
1
|
+
"""Fusion analysis using aligned kernel pairs.
|
|
2
2
|
|
|
3
|
-
Detects fusion
|
|
4
|
-
|
|
3
|
+
Detects kernel fusion patterns at the kernel-pair level within aligned layers.
|
|
4
|
+
|
|
5
|
+
Key insight: When one platform has an operation type like "RMSNorm+GEMM" (containing '+'),
|
|
6
|
+
that platform IS fusing. The OTHER platform runs the components separately.
|
|
5
7
|
"""
|
|
6
8
|
|
|
7
|
-
import json
|
|
8
9
|
from collections import Counter, defaultdict
|
|
9
|
-
from
|
|
10
|
+
from dataclasses import dataclass, field
|
|
10
11
|
from typing import Any
|
|
11
12
|
|
|
12
|
-
from .
|
|
13
|
-
|
|
13
|
+
from .aligner import KernelPair, LayerAlignment
|
|
14
14
|
|
|
15
|
-
def _load_trace_for_fusion(
|
|
16
|
-
file_path: str | Path,
|
|
17
|
-
) -> tuple[str, str, list[dict[str, Any]], dict[int, list[dict[str, Any]]]]:
|
|
18
|
-
"""Load trace and group kernels by correlation ID.
|
|
19
15
|
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
"""
|
|
26
|
-
|
|
27
|
-
trace = json.load(f)
|
|
16
|
+
# Maps fused operation names to their component operations
|
|
17
|
+
FUSED_OPERATION_COMPONENTS: dict[str, list[str]] = {
|
|
18
|
+
"RMSNorm+GEMM": ["RMSNorm", "Dense GEMM"],
|
|
19
|
+
"SwiGLU+GEMM": ["SwiGLU", "Dense GEMM"],
|
|
20
|
+
"MoE GEMM+SwiGLU": ["MoE GEMM", "SwiGLU"],
|
|
21
|
+
"Embedding+RMSNorm+GEMM": ["Elementwise", "RMSNorm", "Dense GEMM"], # Embedding is often Elementwise
|
|
22
|
+
}
|
|
28
23
|
|
|
29
|
-
# Detect platform
|
|
30
|
-
props = trace.get("deviceProperties", [{}])[0]
|
|
31
|
-
is_amd = trace.get("roctracer_version") or props.get("warpSize") == 64
|
|
32
|
-
platform = "AMD" if is_amd else "NVIDIA"
|
|
33
|
-
gpu_name = props.get("name", "MI300X" if is_amd else "Unknown GPU")
|
|
34
24
|
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
25
|
+
@dataclass
|
|
26
|
+
class FusionPattern:
|
|
27
|
+
"""A detected fusion pattern."""
|
|
38
28
|
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
29
|
+
layer: int
|
|
30
|
+
operation: str # The fused operation name (e.g., "RMSNorm+GEMM")
|
|
31
|
+
fused_platform: str # Platform that fuses (has fewer kernels)
|
|
32
|
+
fused_kernel: str # The actual fused kernel name
|
|
33
|
+
unfused_kernels: list[str] # List of separate kernels on the other platform
|
|
34
|
+
count: int
|
|
35
|
+
evidence: str
|
|
45
36
|
|
|
46
|
-
return platform, gpu_name, kernels, dict(corr_groups)
|
|
47
37
|
|
|
38
|
+
@dataclass
|
|
39
|
+
class FusionAnalysis:
|
|
40
|
+
"""Complete fusion analysis result."""
|
|
48
41
|
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
) -> tuple[dict[str, int], dict[str, float]]:
|
|
52
|
-
"""Analyze kernel composition within a correlation group.
|
|
42
|
+
patterns: list[FusionPattern] = field(default_factory=list)
|
|
43
|
+
summary: dict[str, Any] = field(default_factory=dict)
|
|
53
44
|
|
|
54
|
-
Args:
|
|
55
|
-
kernels: List of kernel events in the group
|
|
56
45
|
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
46
|
+
def _is_fused_operation(operation: str) -> bool:
|
|
47
|
+
"""Check if an operation type represents a fused operation.
|
|
48
|
+
|
|
49
|
+
Operations like 'RMSNorm+GEMM', 'MoE GEMM+SwiGLU' indicate fusion.
|
|
50
|
+
Also handles 'Fused (Unknown)' from heuristic detection.
|
|
60
51
|
"""
|
|
61
|
-
|
|
62
|
-
timings: dict[str, float] = defaultdict(float)
|
|
63
|
-
|
|
64
|
-
for k in kernels:
|
|
65
|
-
kernel_type = classify_kernel(k.get("name", ""))
|
|
66
|
-
counts[kernel_type] += 1
|
|
67
|
-
timings[kernel_type] += k.get("dur", 0)
|
|
52
|
+
return "+" in operation or operation == "Fused (Unknown)"
|
|
68
53
|
|
|
69
|
-
return dict(counts), dict(timings)
|
|
70
54
|
|
|
55
|
+
def _get_component_operations(fused_op: str) -> list[str]:
|
|
56
|
+
"""Get the component operations for a fused operation type."""
|
|
57
|
+
if fused_op in FUSED_OPERATION_COMPONENTS:
|
|
58
|
+
return FUSED_OPERATION_COMPONENTS[fused_op]
|
|
59
|
+
|
|
60
|
+
# Fallback for names not in FUSED_OPERATION_COMPONENTS: split on "+" and expand shorthand (e.g., "LayerNorm+GEMM" -> ["LayerNorm", "Dense GEMM"])
|
|
61
|
+
if "+" in fused_op:
|
|
62
|
+
parts = [p.strip() for p in fused_op.split("+")]
|
|
63
|
+
# Map shorthand to full names
|
|
64
|
+
mapping = {
|
|
65
|
+
"GEMM": "Dense GEMM",
|
|
66
|
+
"SwiGLU": "Triton Fused",
|
|
67
|
+
}
|
|
68
|
+
return [mapping.get(p, p) for p in parts]
|
|
69
|
+
|
|
70
|
+
return []
|
|
71
71
|
|
|
72
|
-
def _match_correlation_groups(
|
|
73
|
-
amd_groups: dict[int, list[dict[str, Any]]],
|
|
74
|
-
nv_groups: dict[int, list[dict[str, Any]]],
|
|
75
|
-
size_tolerance: float = 0.25,
|
|
76
|
-
) -> list[tuple[int, int]]:
|
|
77
|
-
"""Match AMD and NVIDIA correlation groups by size and composition.
|
|
78
|
-
|
|
79
|
-
Since correlation IDs don't match between platforms, we match groups that
|
|
80
|
-
have similar sizes (within tolerance) and likely represent the same operation.
|
|
81
72
|
|
|
73
|
+
def _find_component_kernels(
|
|
74
|
+
layer_alignment: LayerAlignment,
|
|
75
|
+
component_ops: list[str],
|
|
76
|
+
platform: str,
|
|
77
|
+
) -> list[str]:
|
|
78
|
+
"""Find kernels for component operations on the specified platform.
|
|
79
|
+
|
|
82
80
|
Args:
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
81
|
+
layer_alignment: Layer alignment data
|
|
82
|
+
component_ops: List of operation types to look for (e.g., ["RMSNorm", "Dense GEMM"])
|
|
83
|
+
platform: "AMD" or "NVIDIA"
|
|
84
|
+
|
|
87
85
|
Returns:
|
|
88
|
-
List of
|
|
86
|
+
List of kernel names found for these operations
|
|
89
87
|
"""
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
# Sort AMD groups by size (largest first for better matching)
|
|
101
|
-
amd_sorted = sorted(amd_groups.items(), key=lambda x: len(x[1]), reverse=True)
|
|
102
|
-
|
|
103
|
-
for amd_id, amd_kernels in amd_sorted:
|
|
104
|
-
amd_size = len(amd_kernels)
|
|
105
|
-
amd_comp = amd_comps[amd_id]
|
|
106
|
-
|
|
107
|
-
# Calculate size bounds for matching
|
|
108
|
-
min_size = amd_size / (1 + size_tolerance)
|
|
109
|
-
max_size = amd_size * (1 + size_tolerance)
|
|
110
|
-
|
|
111
|
-
# Find best matching NVIDIA group
|
|
112
|
-
best_match = None
|
|
113
|
-
best_score = float("inf")
|
|
114
|
-
|
|
115
|
-
for nv_id, nv_size in nv_by_size:
|
|
116
|
-
if nv_id in used_nv_ids:
|
|
117
|
-
continue
|
|
118
|
-
|
|
119
|
-
# Quick size filter
|
|
120
|
-
if nv_size < min_size or nv_size > max_size:
|
|
121
|
-
continue
|
|
122
|
-
|
|
123
|
-
# Use pre-computed composition
|
|
124
|
-
nv_comp = nv_comps[nv_id]
|
|
125
|
-
|
|
126
|
-
# Simple similarity: number of shared kernel types
|
|
127
|
-
shared_types = set(amd_comp.keys()) & set(nv_comp.keys())
|
|
128
|
-
similarity = len(shared_types)
|
|
129
|
-
|
|
130
|
-
# Prefer matches with more shared types and closer sizes
|
|
131
|
-
score = abs(amd_size - nv_size) - (similarity * 10)
|
|
132
|
-
|
|
133
|
-
if score < best_score:
|
|
134
|
-
best_score = score
|
|
135
|
-
best_match = nv_id
|
|
136
|
-
|
|
137
|
-
if best_match is not None:
|
|
138
|
-
matches.append((amd_id, best_match))
|
|
139
|
-
used_nv_ids.add(best_match)
|
|
140
|
-
|
|
141
|
-
return matches
|
|
142
|
-
|
|
88
|
+
found_kernels: list[str] = []
|
|
89
|
+
|
|
90
|
+
for pair in layer_alignment.kernel_pairs:
|
|
91
|
+
if pair.operation in component_ops:
|
|
92
|
+
if platform == "AMD" and pair.amd_kernel and pair.amd_count > 0:
|
|
93
|
+
found_kernels.append(pair.amd_kernel)
|
|
94
|
+
elif platform == "NVIDIA" and pair.nvidia_kernel and pair.nvidia_count > 0:
|
|
95
|
+
found_kernels.append(pair.nvidia_kernel)
|
|
96
|
+
|
|
97
|
+
return found_kernels
|
|
143
98
|
|
|
144
|
-
def _find_fusion_mappings(
|
|
145
|
-
trace1_kernels: list[dict],
|
|
146
|
-
trace2_kernels: list[dict],
|
|
147
|
-
trace1_name: str = "Trace1",
|
|
148
|
-
trace2_name: str = "Trace2",
|
|
149
|
-
) -> list[dict]:
|
|
150
|
-
"""Find fusion mappings by analyzing kernel execution sequence patterns.
|
|
151
99
|
|
|
152
|
-
|
|
153
|
-
|
|
100
|
+
def detect_fusion_in_aligned_layers(
|
|
101
|
+
layer_alignments: list[LayerAlignment],
|
|
102
|
+
) -> FusionAnalysis:
|
|
103
|
+
"""Detect fusion patterns from aligned layer data.
|
|
154
104
|
|
|
155
105
|
Args:
|
|
156
|
-
|
|
157
|
-
trace2_kernels: List of kernel events from second trace
|
|
158
|
-
trace1_name: Name of first platform (e.g., "AMD")
|
|
159
|
-
trace2_name: Name of second platform (e.g., "NVIDIA")
|
|
106
|
+
layer_alignments: List of aligned layers
|
|
160
107
|
|
|
161
108
|
Returns:
|
|
162
|
-
|
|
163
|
-
- fused_platform: Which platform fuses the operations
|
|
164
|
-
- fused_kernel_type: The single fused kernel type
|
|
165
|
-
- unfused_platform: Which platform runs them separately
|
|
166
|
-
- unfused_sequence: List of kernel types run separately
|
|
167
|
-
- pattern_count: How many times this pattern appears
|
|
168
|
-
- pattern_confidence: Fraction of occurrences following this pattern
|
|
169
|
-
- evidence: Human-readable description
|
|
109
|
+
FusionAnalysis with detected patterns
|
|
170
110
|
"""
|
|
171
|
-
|
|
172
|
-
from wafer_core.lib.trace_compare.classifier import classify_kernel
|
|
111
|
+
fusion_patterns: list[FusionPattern] = []
|
|
173
112
|
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
"
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
"
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
ktype: unfused_count
|
|
346
|
-
},
|
|
347
|
-
"pattern_count": unfused_count - fused_count,
|
|
348
|
-
"pattern_confidence": (unfused_count - fused_count) / unfused_count,
|
|
349
|
-
"evidence": f"{unfused_platform} calls {ktype} {ratio:.1f}x more ({unfused_count} vs {fused_count}), {fused_platform} likely fuses into {fusion_target}"
|
|
350
|
-
})
|
|
351
|
-
|
|
352
|
-
return mappings
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
def _detect_intra_type_fusion(
|
|
356
|
-
trace1_kernels: list[dict],
|
|
357
|
-
trace2_kernels: list[dict],
|
|
358
|
-
trace1_name: str,
|
|
359
|
-
trace2_name: str,
|
|
360
|
-
) -> list[dict]:
|
|
361
|
-
"""Detect intra-type fusion where consecutive same-type kernels are fused.
|
|
362
|
-
|
|
363
|
-
Example: AMD runs Sort→Sort→Sort (42 calls) while NVIDIA runs Sort→Sort (10 calls)
|
|
364
|
-
This indicates NVIDIA has a more efficient Sort implementation that fuses operations.
|
|
365
|
-
"""
|
|
366
|
-
from wafer_core.lib.trace_compare.classifier import classify_kernel
|
|
367
|
-
|
|
368
|
-
def analyze_chains(kernels):
|
|
369
|
-
"""Find chains of consecutive same-type kernels"""
|
|
370
|
-
sorted_kernels = sorted(kernels, key=lambda k: k.get('ts', 0))
|
|
371
|
-
types = [classify_kernel(k['name']) for k in sorted_kernels]
|
|
372
|
-
|
|
373
|
-
chains = defaultdict(list)
|
|
374
|
-
i = 0
|
|
375
|
-
while i < len(types):
|
|
376
|
-
ktype = types[i]
|
|
377
|
-
count = 0
|
|
378
|
-
while i < len(types) and types[i] == ktype:
|
|
379
|
-
count += 1
|
|
380
|
-
i += 1
|
|
381
|
-
chains[ktype].append(count)
|
|
382
|
-
|
|
383
|
-
return chains
|
|
384
|
-
|
|
385
|
-
trace1_chains = analyze_chains(trace1_kernels)
|
|
386
|
-
trace2_chains = analyze_chains(trace2_kernels)
|
|
387
|
-
|
|
388
|
-
mappings = []
|
|
389
|
-
all_types = set(trace1_chains.keys()) | set(trace2_chains.keys())
|
|
390
|
-
|
|
391
|
-
for ktype in all_types:
|
|
392
|
-
t1_lengths = trace1_chains.get(ktype, [])
|
|
393
|
-
t2_lengths = trace2_chains.get(ktype, [])
|
|
394
|
-
|
|
395
|
-
# Skip if not enough data
|
|
396
|
-
if len(t1_lengths) < 5 and len(t2_lengths) < 5:
|
|
397
|
-
continue
|
|
398
|
-
|
|
399
|
-
# Filter to chains with multiple kernels
|
|
400
|
-
t1_multi = [l for l in t1_lengths if l > 1]
|
|
401
|
-
t2_multi = [l for l in t2_lengths if l > 1]
|
|
402
|
-
|
|
403
|
-
if not t1_multi and not t2_multi:
|
|
404
|
-
continue
|
|
405
|
-
|
|
406
|
-
t1_total = sum(t1_lengths)
|
|
407
|
-
t2_total = sum(t2_lengths)
|
|
408
|
-
t1_chains = len(t1_multi) if t1_multi else len(t1_lengths)
|
|
409
|
-
t2_chains = len(t2_multi) if t2_multi else len(t2_lengths)
|
|
410
|
-
|
|
411
|
-
if t1_chains == 0 or t2_chains == 0:
|
|
412
|
-
continue
|
|
413
|
-
|
|
414
|
-
t1_avg_chain = sum(t1_multi) / len(t1_multi) if t1_multi else 1.0
|
|
415
|
-
t2_avg_chain = sum(t2_multi) / len(t2_multi) if t2_multi else 1.0
|
|
416
|
-
|
|
417
|
-
chain_ratio = max(t1_avg_chain, t2_avg_chain) / min(t1_avg_chain, t2_avg_chain)
|
|
418
|
-
|
|
419
|
-
# Significant intra-fusion if chains are 2x+ different
|
|
420
|
-
if chain_ratio > 2.0 and abs(t1_total - t2_total) > 100:
|
|
421
|
-
if t1_avg_chain > t2_avg_chain:
|
|
422
|
-
unfused_platform = trace1_name
|
|
423
|
-
fused_platform = trace2_name
|
|
424
|
-
unfused_chains = t1_chains
|
|
425
|
-
fused_chains = t2_chains
|
|
426
|
-
unfused_avg = t1_avg_chain
|
|
427
|
-
fused_avg = t2_avg_chain
|
|
428
|
-
unfused_total = t1_total
|
|
429
|
-
fused_total = t2_total
|
|
430
|
-
else:
|
|
431
|
-
unfused_platform = trace2_name
|
|
432
|
-
fused_platform = trace1_name
|
|
433
|
-
unfused_chains = t2_chains
|
|
434
|
-
fused_chains = t1_chains
|
|
435
|
-
unfused_avg = t2_avg_chain
|
|
436
|
-
fused_avg = t1_avg_chain
|
|
437
|
-
unfused_total = t2_total
|
|
438
|
-
fused_total = t1_total
|
|
439
|
-
|
|
440
|
-
mappings.append({
|
|
441
|
-
"fused_platform": fused_platform,
|
|
442
|
-
"fused_kernel_type": ktype,
|
|
443
|
-
"fused_count": fused_total,
|
|
444
|
-
"unfused_platform": unfused_platform,
|
|
445
|
-
"unfused_sequence": [ktype, ktype], # Same type repeated
|
|
446
|
-
"unfused_count_per_type": {ktype: unfused_total},
|
|
447
|
-
"pattern_count": unfused_total - fused_total,
|
|
448
|
-
"pattern_confidence": min(unfused_chains, fused_chains) / max(unfused_chains, fused_chains),
|
|
449
|
-
"evidence": f"{unfused_platform} runs {ktype} in chains of {unfused_avg:.0f} calls ({unfused_chains} chains, {unfused_total:,} total), {fused_platform} fuses to {fused_avg:.0f} calls ({fused_chains} chains, {fused_total:,} total) - {chain_ratio:.1f}x more efficient"
|
|
450
|
-
})
|
|
451
|
-
|
|
452
|
-
return mappings
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
def _find_partial_fusion_via_groups(
|
|
456
|
-
trace1_large: dict[int, list[dict]],
|
|
457
|
-
trace2_large: dict[int, list[dict]],
|
|
458
|
-
matches: list[tuple[int, int]],
|
|
459
|
-
trace1_name: str,
|
|
460
|
-
trace2_name: str,
|
|
461
|
-
) -> list[dict]:
|
|
462
|
-
"""Find partial fusion patterns by analyzing correlation group differences.
|
|
463
|
-
|
|
464
|
-
When one platform has fewer of a kernel type, check what kernel types the
|
|
465
|
-
other platform has MORE of in those same groups - those are likely fusion targets.
|
|
466
|
-
"""
|
|
467
|
-
from collections import Counter
|
|
468
|
-
from wafer_core.lib.trace_compare.classifier import classify_kernel
|
|
469
|
-
|
|
470
|
-
mappings = []
|
|
471
|
-
|
|
472
|
-
# For each matched pair, track kernel type counts
|
|
473
|
-
trace1_all_types = []
|
|
474
|
-
trace2_all_types = []
|
|
475
|
-
|
|
476
|
-
for trace1_cid, trace2_cid in matches:
|
|
477
|
-
trace1_ktypes = [classify_kernel(k.get("name", "")) for k in trace1_large[trace1_cid]]
|
|
478
|
-
trace2_ktypes = [classify_kernel(k.get("name", "")) for k in trace2_large[trace2_cid]]
|
|
479
|
-
trace1_all_types.extend(trace1_ktypes)
|
|
480
|
-
trace2_all_types.extend(trace2_ktypes)
|
|
481
|
-
|
|
482
|
-
# Find kernel types with significant imbalances
|
|
483
|
-
trace1_counts = Counter(trace1_all_types)
|
|
484
|
-
trace2_counts = Counter(trace2_all_types)
|
|
485
|
-
all_types = set(trace1_counts.keys()) | set(trace2_counts.keys())
|
|
486
|
-
|
|
487
|
-
for ktype in all_types:
|
|
488
|
-
trace1_count = trace1_counts.get(ktype, 0)
|
|
489
|
-
trace2_count = trace2_counts.get(ktype, 0)
|
|
490
|
-
|
|
491
|
-
if trace1_count == 0 or trace2_count == 0:
|
|
492
|
-
continue # Handled by sequence-based detection
|
|
493
|
-
|
|
494
|
-
ratio = max(trace1_count, trace2_count) / min(trace1_count, trace2_count)
|
|
495
|
-
|
|
496
|
-
if ratio < 1.3 or trace1_count + trace2_count < 100:
|
|
497
|
-
continue # Not significant
|
|
498
|
-
|
|
499
|
-
# Determine which platform has fewer (fuses more)
|
|
500
|
-
if trace1_count > trace2_count:
|
|
501
|
-
unfused_platform = trace1_name
|
|
502
|
-
fused_platform = trace2_name
|
|
503
|
-
unfused_count = trace1_count
|
|
504
|
-
fused_count = trace2_count
|
|
505
|
-
|
|
506
|
-
# Find groups where trace1 has this kernel but trace2 doesn't
|
|
507
|
-
fusion_targets = Counter()
|
|
508
|
-
groups_analyzed = 0
|
|
509
|
-
|
|
510
|
-
for trace1_cid, trace2_cid in matches:
|
|
511
|
-
trace1_ktypes = [classify_kernel(k.get("name", "")) for k in trace1_large[trace1_cid]]
|
|
512
|
-
trace2_ktypes = [classify_kernel(k.get("name", "")) for k in trace2_large[trace2_cid]]
|
|
513
|
-
|
|
514
|
-
trace1_has = ktype in trace1_ktypes
|
|
515
|
-
trace2_has = ktype in trace2_ktypes
|
|
516
|
-
|
|
517
|
-
if trace1_has and not trace2_has:
|
|
518
|
-
# What does trace2 have MORE of in this group?
|
|
519
|
-
trace1_kcounts = Counter(trace1_ktypes)
|
|
520
|
-
trace2_kcounts = Counter(trace2_ktypes)
|
|
521
|
-
|
|
522
|
-
for other_type in set(trace2_kcounts.keys()):
|
|
523
|
-
if other_type == ktype or other_type == "Other":
|
|
524
|
-
continue
|
|
525
|
-
diff = trace2_kcounts[other_type] - trace1_kcounts.get(other_type, 0)
|
|
526
|
-
if diff > 0:
|
|
527
|
-
fusion_targets[other_type] += diff
|
|
528
|
-
|
|
529
|
-
groups_analyzed += 1
|
|
530
|
-
|
|
531
|
-
if fusion_targets and groups_analyzed >= 5:
|
|
532
|
-
# Report top fusion targets
|
|
533
|
-
top_targets = fusion_targets.most_common(3)
|
|
534
|
-
target_str = ", ".join(f"{t[0]} (+{t[1]})" for t in top_targets)
|
|
535
|
-
|
|
536
|
-
mappings.append({
|
|
537
|
-
"fused_platform": fused_platform,
|
|
538
|
-
"fused_kernel_type": top_targets[0][0],
|
|
539
|
-
"fused_count": fused_count,
|
|
540
|
-
"unfused_platform": unfused_platform,
|
|
541
|
-
"unfused_sequence": [ktype],
|
|
542
|
-
"unfused_count_per_type": {ktype: unfused_count},
|
|
543
|
-
"pattern_count": unfused_count - fused_count,
|
|
544
|
-
"pattern_confidence": groups_analyzed / len(matches) if matches else 0,
|
|
545
|
-
"evidence": f"{unfused_platform} calls {ktype} {ratio:.1f}x more ({unfused_count} vs {fused_count}). In {groups_analyzed} groups where {unfused_platform} has {ktype}, {fused_platform} has more: {target_str}"
|
|
546
|
-
})
|
|
113
|
+
# Track patterns we've already seen to avoid duplicates
|
|
114
|
+
seen_patterns: set[tuple[int, str, str]] = set() # (layer, operation, fused_platform)
|
|
115
|
+
|
|
116
|
+
for layer_alignment in layer_alignments:
|
|
117
|
+
for pair in layer_alignment.kernel_pairs:
|
|
118
|
+
is_fused_op = _is_fused_operation(pair.operation)
|
|
119
|
+
|
|
120
|
+
# Case 1: AMD has a kernel for this operation, NVIDIA doesn't
|
|
121
|
+
if pair.amd_kernel and pair.amd_count > 0 and (pair.nvidia_kernel is None or pair.nvidia_count == 0):
|
|
122
|
+
if is_fused_op:
|
|
123
|
+
# AMD HAS the fused kernel → AMD is fusing
|
|
124
|
+
# Find what NVIDIA runs separately
|
|
125
|
+
component_ops = _get_component_operations(pair.operation)
|
|
126
|
+
unfused_kernels = _find_component_kernels(layer_alignment, component_ops, "NVIDIA")
|
|
127
|
+
|
|
128
|
+
pattern_key = (layer_alignment.layer, pair.operation, "AMD")
|
|
129
|
+
if pattern_key not in seen_patterns:
|
|
130
|
+
seen_patterns.add(pattern_key)
|
|
131
|
+
|
|
132
|
+
if unfused_kernels:
|
|
133
|
+
evidence = f"AMD fuses {' + '.join(component_ops)} into {pair.amd_kernel}, NVIDIA runs {len(unfused_kernels)} separate kernels"
|
|
134
|
+
else:
|
|
135
|
+
evidence = f"AMD fuses into {pair.amd_kernel}, NVIDIA runs components separately"
|
|
136
|
+
|
|
137
|
+
fusion_patterns.append(
|
|
138
|
+
FusionPattern(
|
|
139
|
+
layer=layer_alignment.layer,
|
|
140
|
+
operation=pair.operation,
|
|
141
|
+
fused_platform="AMD",
|
|
142
|
+
fused_kernel=pair.amd_kernel,
|
|
143
|
+
unfused_kernels=unfused_kernels if unfused_kernels else component_ops,
|
|
144
|
+
count=pair.amd_count,
|
|
145
|
+
evidence=evidence,
|
|
146
|
+
)
|
|
147
|
+
)
|
|
148
|
+
else:
|
|
149
|
+
# Regular case: AMD has a kernel that NVIDIA fuses into something else
|
|
150
|
+
# This means NVIDIA is fusing (it doesn't need this separate kernel)
|
|
151
|
+
pattern_key = (layer_alignment.layer, pair.operation, "NVIDIA")
|
|
152
|
+
if pattern_key not in seen_patterns:
|
|
153
|
+
seen_patterns.add(pattern_key)
|
|
154
|
+
|
|
155
|
+
evidence = f"AMD runs {pair.amd_kernel} separately ({pair.amd_count}x), NVIDIA fuses this operation"
|
|
156
|
+
|
|
157
|
+
fusion_patterns.append(
|
|
158
|
+
FusionPattern(
|
|
159
|
+
layer=layer_alignment.layer,
|
|
160
|
+
operation=pair.operation,
|
|
161
|
+
fused_platform="NVIDIA",
|
|
162
|
+
fused_kernel="(fused into nearby kernel)",
|
|
163
|
+
unfused_kernels=[pair.amd_kernel],
|
|
164
|
+
count=pair.amd_count,
|
|
165
|
+
evidence=evidence,
|
|
166
|
+
)
|
|
167
|
+
)
|
|
168
|
+
|
|
169
|
+
# Case 2: NVIDIA has a kernel for this operation, AMD doesn't
|
|
170
|
+
elif pair.nvidia_kernel and pair.nvidia_count > 0 and (pair.amd_kernel is None or pair.amd_count == 0):
|
|
171
|
+
if is_fused_op:
|
|
172
|
+
# NVIDIA HAS the fused kernel → NVIDIA is fusing
|
|
173
|
+
# Find what AMD runs separately
|
|
174
|
+
component_ops = _get_component_operations(pair.operation)
|
|
175
|
+
unfused_kernels = _find_component_kernels(layer_alignment, component_ops, "AMD")
|
|
176
|
+
|
|
177
|
+
pattern_key = (layer_alignment.layer, pair.operation, "NVIDIA")
|
|
178
|
+
if pattern_key not in seen_patterns:
|
|
179
|
+
seen_patterns.add(pattern_key)
|
|
180
|
+
|
|
181
|
+
if unfused_kernels:
|
|
182
|
+
evidence = f"NVIDIA fuses {' + '.join(component_ops)} into {pair.nvidia_kernel}, AMD runs {len(unfused_kernels)} separate kernels"
|
|
183
|
+
else:
|
|
184
|
+
evidence = f"NVIDIA fuses into {pair.nvidia_kernel}, AMD runs components separately"
|
|
185
|
+
|
|
186
|
+
fusion_patterns.append(
|
|
187
|
+
FusionPattern(
|
|
188
|
+
layer=layer_alignment.layer,
|
|
189
|
+
operation=pair.operation,
|
|
190
|
+
fused_platform="NVIDIA",
|
|
191
|
+
fused_kernel=pair.nvidia_kernel,
|
|
192
|
+
unfused_kernels=unfused_kernels if unfused_kernels else component_ops,
|
|
193
|
+
count=pair.nvidia_count,
|
|
194
|
+
evidence=evidence,
|
|
195
|
+
)
|
|
196
|
+
)
|
|
197
|
+
else:
|
|
198
|
+
# Regular case: NVIDIA has a kernel that AMD fuses into something else
|
|
199
|
+
pattern_key = (layer_alignment.layer, pair.operation, "AMD")
|
|
200
|
+
if pattern_key not in seen_patterns:
|
|
201
|
+
seen_patterns.add(pattern_key)
|
|
202
|
+
|
|
203
|
+
evidence = f"NVIDIA runs {pair.nvidia_kernel} separately ({pair.nvidia_count}x), AMD fuses this operation"
|
|
204
|
+
|
|
205
|
+
fusion_patterns.append(
|
|
206
|
+
FusionPattern(
|
|
207
|
+
layer=layer_alignment.layer,
|
|
208
|
+
operation=pair.operation,
|
|
209
|
+
fused_platform="AMD",
|
|
210
|
+
fused_kernel="(fused into nearby kernel)",
|
|
211
|
+
unfused_kernels=[pair.nvidia_kernel],
|
|
212
|
+
count=pair.nvidia_count,
|
|
213
|
+
evidence=evidence,
|
|
214
|
+
)
|
|
215
|
+
)
|
|
216
|
+
|
|
217
|
+
# Case 3: Both have kernels but with very different counts (partial fusion)
|
|
218
|
+
elif (
|
|
219
|
+
pair.amd_kernel
|
|
220
|
+
and pair.nvidia_kernel
|
|
221
|
+
and pair.amd_count > 0
|
|
222
|
+
and pair.nvidia_count > 0
|
|
223
|
+
):
|
|
224
|
+
count_ratio = pair.amd_count / pair.nvidia_count if pair.nvidia_count > 0 else float('inf')
|
|
225
|
+
|
|
226
|
+
if count_ratio > 1.5:
|
|
227
|
+
# AMD runs more → NVIDIA fuses some instances
|
|
228
|
+
pattern_key = (layer_alignment.layer, pair.operation, "NVIDIA")
|
|
229
|
+
if pattern_key not in seen_patterns:
|
|
230
|
+
seen_patterns.add(pattern_key)
|
|
231
|
+
|
|
232
|
+
evidence = (
|
|
233
|
+
f"AMD runs {pair.amd_kernel} {count_ratio:.1f}x more "
|
|
234
|
+
f"({pair.amd_count} vs {pair.nvidia_count}), NVIDIA partially fuses"
|
|
235
|
+
)
|
|
236
|
+
|
|
237
|
+
fusion_patterns.append(
|
|
238
|
+
FusionPattern(
|
|
239
|
+
layer=layer_alignment.layer,
|
|
240
|
+
operation=pair.operation,
|
|
241
|
+
fused_platform="NVIDIA",
|
|
242
|
+
fused_kernel=pair.nvidia_kernel,
|
|
243
|
+
unfused_kernels=[pair.amd_kernel],
|
|
244
|
+
count=pair.amd_count - pair.nvidia_count,
|
|
245
|
+
evidence=evidence,
|
|
246
|
+
)
|
|
247
|
+
)
|
|
248
|
+
|
|
249
|
+
elif count_ratio < 0.67:
|
|
250
|
+
# NVIDIA runs more → AMD fuses some instances
|
|
251
|
+
inverse_ratio = pair.nvidia_count / pair.amd_count if pair.amd_count > 0 else float('inf')
|
|
252
|
+
pattern_key = (layer_alignment.layer, pair.operation, "AMD")
|
|
253
|
+
if pattern_key not in seen_patterns:
|
|
254
|
+
seen_patterns.add(pattern_key)
|
|
255
|
+
|
|
256
|
+
evidence = (
|
|
257
|
+
f"NVIDIA runs {pair.nvidia_kernel} {inverse_ratio:.1f}x more "
|
|
258
|
+
f"({pair.nvidia_count} vs {pair.amd_count}), AMD partially fuses"
|
|
259
|
+
)
|
|
260
|
+
|
|
261
|
+
fusion_patterns.append(
|
|
262
|
+
FusionPattern(
|
|
263
|
+
layer=layer_alignment.layer,
|
|
264
|
+
operation=pair.operation,
|
|
265
|
+
fused_platform="AMD",
|
|
266
|
+
fused_kernel=pair.amd_kernel,
|
|
267
|
+
unfused_kernels=[pair.nvidia_kernel],
|
|
268
|
+
count=pair.nvidia_count - pair.amd_count,
|
|
269
|
+
evidence=evidence,
|
|
270
|
+
)
|
|
271
|
+
)
|
|
272
|
+
|
|
273
|
+
# Aggregate patterns by operation type and fused platform
|
|
274
|
+
aggregated: dict[tuple[str, str], FusionPattern] = {}
|
|
275
|
+
for pattern in fusion_patterns:
|
|
276
|
+
key = (pattern.operation, pattern.fused_platform)
|
|
277
|
+
if key in aggregated:
|
|
278
|
+
existing = aggregated[key]
|
|
279
|
+
existing.count += pattern.count
|
|
280
|
+
# Merge unfused kernels (deduplicate)
|
|
281
|
+
for k in pattern.unfused_kernels:
|
|
282
|
+
if k not in existing.unfused_kernels:
|
|
283
|
+
existing.unfused_kernels.append(k)
|
|
547
284
|
else:
|
|
548
|
-
|
|
549
|
-
|
|
550
|
-
|
|
551
|
-
|
|
552
|
-
|
|
553
|
-
|
|
554
|
-
|
|
555
|
-
|
|
556
|
-
|
|
557
|
-
for trace1_cid, trace2_cid in matches:
|
|
558
|
-
trace1_ktypes = [classify_kernel(k.get("name", "")) for k in trace1_large[trace1_cid]]
|
|
559
|
-
trace2_ktypes = [classify_kernel(k.get("name", "")) for k in trace2_large[trace2_cid]]
|
|
560
|
-
|
|
561
|
-
trace1_has = ktype in trace1_ktypes
|
|
562
|
-
trace2_has = ktype in trace2_ktypes
|
|
563
|
-
|
|
564
|
-
if trace2_has and not trace1_has:
|
|
565
|
-
trace1_kcounts = Counter(trace1_ktypes)
|
|
566
|
-
trace2_kcounts = Counter(trace2_ktypes)
|
|
567
|
-
|
|
568
|
-
for other_type in set(trace1_kcounts.keys()):
|
|
569
|
-
if other_type == ktype or other_type == "Other":
|
|
570
|
-
continue
|
|
571
|
-
diff = trace1_kcounts[other_type] - trace2_kcounts.get(other_type, 0)
|
|
572
|
-
if diff > 0:
|
|
573
|
-
fusion_targets[other_type] += diff
|
|
574
|
-
|
|
575
|
-
groups_analyzed += 1
|
|
576
|
-
|
|
577
|
-
if fusion_targets and groups_analyzed >= 5:
|
|
578
|
-
top_targets = fusion_targets.most_common(3)
|
|
579
|
-
target_str = ", ".join(f"{t[0]} (+{t[1]})" for t in top_targets)
|
|
580
|
-
|
|
581
|
-
mappings.append({
|
|
582
|
-
"fused_platform": fused_platform,
|
|
583
|
-
"fused_kernel_type": top_targets[0][0],
|
|
584
|
-
"fused_count": fused_count,
|
|
585
|
-
"unfused_platform": unfused_platform,
|
|
586
|
-
"unfused_sequence": [ktype],
|
|
587
|
-
"unfused_count_per_type": {ktype: unfused_count},
|
|
588
|
-
"pattern_count": unfused_count - fused_count,
|
|
589
|
-
"pattern_confidence": groups_analyzed / len(matches) if matches else 0,
|
|
590
|
-
"evidence": f"{unfused_platform} calls {ktype} {ratio:.1f}x more ({unfused_count} vs {fused_count}). In {groups_analyzed} groups where {unfused_platform} has {ktype}, {fused_platform} has more: {target_str}"
|
|
591
|
-
})
|
|
592
|
-
|
|
593
|
-
return mappings
|
|
594
|
-
|
|
595
|
-
|
|
596
|
-
def analyze_fusion_differences(
|
|
597
|
-
amd_trace_path: str | Path,
|
|
598
|
-
nv_trace_path: str | Path,
|
|
599
|
-
min_group_size: int = 50,
|
|
600
|
-
) -> dict[str, Any]:
|
|
601
|
-
"""Main fusion analysis function.
|
|
285
|
+
aggregated[key] = FusionPattern(
|
|
286
|
+
layer=pattern.layer,
|
|
287
|
+
operation=pattern.operation,
|
|
288
|
+
fused_platform=pattern.fused_platform,
|
|
289
|
+
fused_kernel=pattern.fused_kernel,
|
|
290
|
+
unfused_kernels=list(pattern.unfused_kernels),
|
|
291
|
+
count=pattern.count,
|
|
292
|
+
evidence=pattern.evidence,
|
|
293
|
+
)
|
|
602
294
|
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
nv_trace_path: Path to NVIDIA trace
|
|
606
|
-
min_group_size: Only analyze correlation groups with at least this many kernels
|
|
295
|
+
amd_fuses = sum(1 for p in aggregated.values() if p.fused_platform == "AMD")
|
|
296
|
+
nvidia_fuses = sum(1 for p in aggregated.values() if p.fused_platform == "NVIDIA")
|
|
607
297
|
|
|
608
|
-
|
|
609
|
-
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
|
|
615
|
-
# Load traces (maintain order - don't swap)
|
|
616
|
-
trace1_platform, trace1_gpu, trace1_kernels, trace1_corr_groups = _load_trace_for_fusion(
|
|
617
|
-
amd_trace_path
|
|
618
|
-
)
|
|
619
|
-
trace2_platform, trace2_gpu, trace2_kernels, trace2_corr_groups = _load_trace_for_fusion(
|
|
620
|
-
nv_trace_path
|
|
298
|
+
return FusionAnalysis(
|
|
299
|
+
patterns=list(aggregated.values()),
|
|
300
|
+
summary={
|
|
301
|
+
"amd_fuses": amd_fuses,
|
|
302
|
+
"nvidia_fuses": nvidia_fuses,
|
|
303
|
+
"total_fusion_opportunities": len(aggregated),
|
|
304
|
+
},
|
|
621
305
|
)
|
|
622
306
|
|
|
623
|
-
# Filter to "significant" correlation groups
|
|
624
|
-
trace1_large = {
|
|
625
|
-
cid: kernels
|
|
626
|
-
for cid, kernels in trace1_corr_groups.items()
|
|
627
|
-
if len(kernels) >= min_group_size
|
|
628
|
-
}
|
|
629
|
-
trace2_large = {
|
|
630
|
-
cid: kernels
|
|
631
|
-
for cid, kernels in trace2_corr_groups.items()
|
|
632
|
-
if len(kernels) >= min_group_size
|
|
633
|
-
}
|
|
634
|
-
|
|
635
|
-
# Match correlation groups between platforms
|
|
636
|
-
matches = _match_correlation_groups(trace1_large, trace2_large)
|
|
637
|
-
|
|
638
|
-
# Analyze differences in matched groups
|
|
639
|
-
fusion_diffs: dict[str, dict[str, Any]] = defaultdict(
|
|
640
|
-
lambda: {
|
|
641
|
-
"trace1_count": 0,
|
|
642
|
-
"trace2_count": 0,
|
|
643
|
-
"trace1_time_us": 0,
|
|
644
|
-
"trace2_time_us": 0,
|
|
645
|
-
"groups_with_diff": 0,
|
|
646
|
-
"total_groups": 0,
|
|
647
|
-
}
|
|
648
|
-
)
|
|
649
307
|
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
trace1_comp, trace1_times = _analyze_correlation_group(trace1_large[trace1_cid])
|
|
655
|
-
trace2_comp, trace2_times = _analyze_correlation_group(trace2_large[trace2_cid])
|
|
656
|
-
|
|
657
|
-
# Find all kernel types in either platform
|
|
658
|
-
all_types = set(trace1_comp.keys()) | set(trace2_comp.keys())
|
|
659
|
-
|
|
660
|
-
for ktype in all_types:
|
|
661
|
-
trace1_count = trace1_comp.get(ktype, 0)
|
|
662
|
-
trace2_count = trace2_comp.get(ktype, 0)
|
|
663
|
-
trace1_time = trace1_times.get(ktype, 0)
|
|
664
|
-
trace2_time = trace2_times.get(ktype, 0)
|
|
665
|
-
|
|
666
|
-
fusion_diffs[ktype]["trace1_count"] += trace1_count
|
|
667
|
-
fusion_diffs[ktype]["trace2_count"] += trace2_count
|
|
668
|
-
fusion_diffs[ktype]["trace1_time_us"] += trace1_time
|
|
669
|
-
fusion_diffs[ktype]["trace2_time_us"] += trace2_time
|
|
670
|
-
fusion_diffs[ktype]["total_groups"] += 1
|
|
671
|
-
|
|
672
|
-
if trace1_count != trace2_count:
|
|
673
|
-
fusion_diffs[ktype]["groups_with_diff"] += 1
|
|
674
|
-
|
|
675
|
-
# NEW: Find actual kernel mappings in this correlation group
|
|
676
|
-
group_mappings = _find_fusion_mappings(
|
|
677
|
-
trace1_large[trace1_cid],
|
|
678
|
-
trace2_large[trace2_cid],
|
|
679
|
-
trace1_name=trace1_platform,
|
|
680
|
-
trace2_name=trace2_platform
|
|
681
|
-
)
|
|
682
|
-
# Add correlation ID context to each mapping
|
|
683
|
-
for mapping in group_mappings:
|
|
684
|
-
mapping["correlation_group_trace1"] = trace1_cid
|
|
685
|
-
mapping["correlation_group_trace2"] = trace2_cid
|
|
686
|
-
all_fusion_mappings.extend(group_mappings)
|
|
687
|
-
|
|
688
|
-
# Also get global counts for context
|
|
689
|
-
global_trace1_counts: Counter[str] = Counter(
|
|
690
|
-
[classify_kernel(k.get("name", "")) for k in trace1_kernels]
|
|
691
|
-
)
|
|
692
|
-
global_trace2_counts: Counter[str] = Counter(
|
|
693
|
-
[classify_kernel(k.get("name", "")) for k in trace2_kernels]
|
|
694
|
-
)
|
|
308
|
+
def analyze_fusion_from_alignment(
|
|
309
|
+
layer_alignments: list[LayerAlignment],
|
|
310
|
+
) -> dict[str, Any]:
|
|
311
|
+
"""Analyze fusion from alignment data (for API compatibility).
|
|
695
312
|
|
|
696
|
-
|
|
697
|
-
|
|
698
|
-
"metadata": {
|
|
699
|
-
"trace1_gpu": trace1_gpu,
|
|
700
|
-
"trace2_gpu": trace2_gpu,
|
|
701
|
-
"trace1_total_kernels": len(trace1_kernels),
|
|
702
|
-
"trace2_total_kernels": len(trace2_kernels),
|
|
703
|
-
"trace1_correlation_groups": len(trace1_large),
|
|
704
|
-
"trace2_correlation_groups": len(trace2_large),
|
|
705
|
-
"matched_groups": len(matches),
|
|
706
|
-
},
|
|
707
|
-
"global_counts": {},
|
|
708
|
-
"fusion_opportunities": [],
|
|
709
|
-
"fusion_mappings": all_fusion_mappings, # NEW: Include actual mappings
|
|
710
|
-
}
|
|
313
|
+
Args:
|
|
314
|
+
layer_alignments: List of aligned layers
|
|
711
315
|
|
|
712
|
-
|
|
713
|
-
|
|
714
|
-
|
|
715
|
-
|
|
716
|
-
|
|
717
|
-
|
|
718
|
-
|
|
719
|
-
|
|
720
|
-
|
|
721
|
-
|
|
722
|
-
|
|
316
|
+
Returns:
|
|
317
|
+
Dictionary with fusion analysis results (compatible with old API)
|
|
318
|
+
"""
|
|
319
|
+
fusion_analysis = detect_fusion_in_aligned_layers(layer_alignments)
|
|
320
|
+
|
|
321
|
+
fusion_opportunities = []
|
|
322
|
+
fusion_mappings = []
|
|
323
|
+
|
|
324
|
+
for pattern in fusion_analysis.patterns:
|
|
325
|
+
unfused_platform = "NVIDIA" if pattern.fused_platform == "AMD" else "AMD"
|
|
326
|
+
|
|
327
|
+
fusion_opportunities.append(
|
|
328
|
+
{
|
|
329
|
+
"kernel_type": pattern.operation,
|
|
330
|
+
"layer": pattern.layer,
|
|
331
|
+
"fused_by": pattern.fused_platform,
|
|
332
|
+
"fused_kernel": pattern.fused_kernel,
|
|
333
|
+
"unfused_kernels": pattern.unfused_kernels,
|
|
334
|
+
"count": pattern.count,
|
|
335
|
+
"evidence": pattern.evidence,
|
|
723
336
|
}
|
|
724
|
-
|
|
725
|
-
# Fusion opportunities from matched correlation groups
|
|
726
|
-
for ktype, stats in fusion_diffs.items():
|
|
727
|
-
trace1_avg = (
|
|
728
|
-
stats["trace1_count"] / stats["total_groups"]
|
|
729
|
-
if stats["total_groups"] > 0
|
|
730
|
-
else 0
|
|
731
337
|
)
|
|
732
|
-
trace2_avg = (
|
|
733
|
-
stats["trace2_count"] / stats["total_groups"]
|
|
734
|
-
if stats["total_groups"] > 0
|
|
735
|
-
else 0
|
|
736
|
-
)
|
|
737
|
-
trace1_time_ms = stats["trace1_time_us"] / 1000
|
|
738
|
-
trace2_time_ms = stats["trace2_time_us"] / 1000
|
|
739
|
-
|
|
740
|
-
# Calculate significance
|
|
741
|
-
diff_ratio = trace1_avg / trace2_avg if trace2_avg > 0 else float("inf")
|
|
742
|
-
reverse_ratio = trace2_avg / trace1_avg if trace1_avg > 0 else float("inf")
|
|
743
|
-
|
|
744
|
-
# Only report if there's a significant difference
|
|
745
|
-
# Either: one platform has it and the other doesn't (ratio > 10)
|
|
746
|
-
# Or: one platform has significantly more (ratio > 2.0)
|
|
747
|
-
is_significant = (
|
|
748
|
-
(diff_ratio > 10.0 or reverse_ratio > 10.0)
|
|
749
|
-
or ( # One platform doesn't have it
|
|
750
|
-
(diff_ratio > 2.0 or reverse_ratio > 2.0)
|
|
751
|
-
and stats["trace1_count"] + stats["trace2_count"] > 20 # Significant difference # Not trivial counts
|
|
752
|
-
)
|
|
753
|
-
)
|
|
754
|
-
|
|
755
|
-
if is_significant:
|
|
756
|
-
# Determine who fuses (who has FEWER calls = more fusion)
|
|
757
|
-
if diff_ratio > 1.5:
|
|
758
|
-
fused_by = trace2_platform # Trace 2 has fewer calls, so it fuses more
|
|
759
|
-
ratio = diff_ratio
|
|
760
|
-
else:
|
|
761
|
-
fused_by = trace1_platform # Trace 1 has fewer calls, so it fuses more
|
|
762
|
-
ratio = reverse_ratio
|
|
763
|
-
|
|
764
|
-
# Calculate time ratio (who's faster for this operation)
|
|
765
|
-
time_ratio = trace1_time_ms / trace2_time_ms if trace2_time_ms > 0 else float("inf")
|
|
766
|
-
|
|
767
|
-
results["fusion_opportunities"].append(
|
|
768
|
-
{
|
|
769
|
-
"kernel_type": ktype,
|
|
770
|
-
"trace1_total": stats["trace1_count"],
|
|
771
|
-
"trace2_total": stats["trace2_count"],
|
|
772
|
-
"trace1_avg_per_group": trace1_avg,
|
|
773
|
-
"trace2_avg_per_group": trace2_avg,
|
|
774
|
-
"trace1_time_ms": trace1_time_ms,
|
|
775
|
-
"trace2_time_ms": trace2_time_ms,
|
|
776
|
-
"time_ratio": time_ratio,
|
|
777
|
-
"ratio": ratio,
|
|
778
|
-
"fused_by": fused_by,
|
|
779
|
-
"groups_affected": stats["groups_with_diff"],
|
|
780
|
-
"total_groups": stats["total_groups"],
|
|
781
|
-
}
|
|
782
|
-
)
|
|
783
338
|
|
|
784
|
-
|
|
785
|
-
|
|
786
|
-
|
|
787
|
-
|
|
788
|
-
|
|
789
|
-
|
|
790
|
-
|
|
791
|
-
|
|
792
|
-
|
|
793
|
-
|
|
794
|
-
|
|
795
|
-
if trace1_total + trace2_total < 100:
|
|
796
|
-
continue
|
|
797
|
-
|
|
798
|
-
# Calculate global ratio
|
|
799
|
-
global_ratio = trace1_total / trace2_total if trace2_total > 0 else float("inf")
|
|
800
|
-
global_reverse_ratio = trace2_total / trace1_total if trace1_total > 0 else float("inf")
|
|
801
|
-
|
|
802
|
-
# Check if globally significant (more aggressive threshold for comprehensive detection)
|
|
803
|
-
is_globally_significant = (
|
|
804
|
-
(global_ratio > 2.0 or global_reverse_ratio > 2.0)
|
|
805
|
-
and (trace1_total + trace2_total > 100)
|
|
339
|
+
fusion_mappings.append(
|
|
340
|
+
{
|
|
341
|
+
"fused_platform": pattern.fused_platform,
|
|
342
|
+
"fused_kernel_type": pattern.operation,
|
|
343
|
+
"fused_kernel_name": pattern.fused_kernel,
|
|
344
|
+
"unfused_platform": unfused_platform,
|
|
345
|
+
"unfused_sequence": pattern.unfused_kernels,
|
|
346
|
+
"pattern_count": pattern.count,
|
|
347
|
+
"evidence": pattern.evidence,
|
|
348
|
+
"layer": pattern.layer,
|
|
349
|
+
}
|
|
806
350
|
)
|
|
807
351
|
|
|
808
|
-
|
|
809
|
-
|
|
810
|
-
|
|
811
|
-
|
|
812
|
-
|
|
813
|
-
) / 1000 # Convert to ms
|
|
814
|
-
trace2_time = sum(
|
|
815
|
-
k.get("dur", 0) for k in trace2_kernels
|
|
816
|
-
if classify_kernel(k.get("name", "")) == ktype
|
|
817
|
-
) / 1000
|
|
818
|
-
|
|
819
|
-
# Determine who fuses (who has FEWER calls = more fusion)
|
|
820
|
-
if global_ratio > 1.5:
|
|
821
|
-
fused_by = "Trace 2" # Trace 2 has fewer calls
|
|
822
|
-
ratio = global_ratio
|
|
823
|
-
else:
|
|
824
|
-
fused_by = "Trace 1" # Trace 1 has fewer calls
|
|
825
|
-
ratio = global_reverse_ratio
|
|
826
|
-
|
|
827
|
-
time_ratio = trace1_time / trace2_time if trace2_time > 0 else float("inf")
|
|
828
|
-
|
|
829
|
-
results["fusion_opportunities"].append(
|
|
830
|
-
{
|
|
831
|
-
"kernel_type": ktype,
|
|
832
|
-
"trace1_total": trace1_total,
|
|
833
|
-
"trace2_total": trace2_total,
|
|
834
|
-
"trace1_avg_per_group": trace1_total / len(matches) if matches else 0,
|
|
835
|
-
"trace2_avg_per_group": trace2_total / len(matches) if matches else 0,
|
|
836
|
-
"trace1_time_ms": trace1_time,
|
|
837
|
-
"trace2_time_ms": trace2_time,
|
|
838
|
-
"time_ratio": time_ratio,
|
|
839
|
-
"ratio": ratio,
|
|
840
|
-
"fused_by": fused_by,
|
|
841
|
-
"groups_affected": 0, # Unknown for global analysis
|
|
842
|
-
"total_groups": len(matches),
|
|
843
|
-
}
|
|
844
|
-
)
|
|
845
|
-
|
|
846
|
-
# Sort by impact (ratio * total count)
|
|
847
|
-
results["fusion_opportunities"].sort(
|
|
848
|
-
key=lambda x: x["ratio"] * (x["trace1_total"] + x["trace2_total"]), reverse=True
|
|
849
|
-
)
|
|
850
|
-
|
|
851
|
-
# ADD PARTIAL FUSION MAPPINGS using correlation group differential analysis
|
|
852
|
-
# This catches patterns like Sort that exist on both platforms but with different frequencies
|
|
853
|
-
partial_mappings = _find_partial_fusion_via_groups(
|
|
854
|
-
trace1_large,
|
|
855
|
-
trace2_large,
|
|
856
|
-
matches,
|
|
857
|
-
trace1_name=trace1_platform,
|
|
858
|
-
trace2_name=trace2_platform
|
|
859
|
-
)
|
|
860
|
-
all_fusion_mappings.extend(partial_mappings)
|
|
861
|
-
|
|
862
|
-
# DETECT INTRA-TYPE FUSION (same kernel type fused with itself, like Sort chains)
|
|
863
|
-
# Do this FIRST since it's more accurate than the fallback global analysis
|
|
864
|
-
intra_mappings = _detect_intra_type_fusion(
|
|
865
|
-
trace1_kernels,
|
|
866
|
-
trace2_kernels,
|
|
867
|
-
trace1_name=trace1_platform,
|
|
868
|
-
trace2_name=trace2_platform
|
|
869
|
-
)
|
|
870
|
-
all_fusion_mappings.extend(intra_mappings)
|
|
871
|
-
|
|
872
|
-
# Collect kernel types already handled by intra-type fusion
|
|
873
|
-
intra_handled_types = set(m["fused_kernel_type"] for m in intra_mappings)
|
|
874
|
-
|
|
875
|
-
# ALSO ADD GLOBAL FUSION MAPPINGS for kernels not in large correlation groups
|
|
876
|
-
# Skip types already handled by intra-type fusion (more accurate)
|
|
877
|
-
global_mappings = _find_fusion_mappings(
|
|
878
|
-
trace1_kernels,
|
|
879
|
-
trace2_kernels,
|
|
880
|
-
trace1_name=trace1_platform,
|
|
881
|
-
trace2_name=trace2_platform
|
|
882
|
-
)
|
|
883
|
-
# Filter: skip if already handled or if evidence is duplicate
|
|
884
|
-
existing_evidence = set(m["evidence"] for m in all_fusion_mappings)
|
|
885
|
-
for mapping in global_mappings:
|
|
886
|
-
ktype = mapping["unfused_sequence"][0] if mapping["unfused_sequence"] else None
|
|
887
|
-
if ktype not in intra_handled_types and mapping["evidence"] not in existing_evidence:
|
|
888
|
-
all_fusion_mappings.append(mapping)
|
|
889
|
-
|
|
890
|
-
return results
|
|
352
|
+
return {
|
|
353
|
+
"fusion_opportunities": fusion_opportunities,
|
|
354
|
+
"fusion_mappings": fusion_mappings,
|
|
355
|
+
"summary": fusion_analysis.summary,
|
|
356
|
+
}
|