wafer-core 0.1.26-py3-none-any.whl → 0.1.27-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer_core/lib/trace_compare/PERFORMANCE.md +148 -0
- wafer_core/lib/trace_compare/__init__.py +22 -9
- wafer_core/lib/trace_compare/aligner.py +369 -0
- wafer_core/lib/trace_compare/analyzer.py +549 -159
- wafer_core/lib/trace_compare/api.py +225 -0
- wafer_core/lib/trace_compare/architecture.py +77 -0
- wafer_core/lib/trace_compare/classifier.py +307 -13
- wafer_core/lib/trace_compare/fusion_analyzer.py +311 -845
- wafer_core/lib/trace_compare/kernel_registry.yaml +349 -0
- wafer_core/lib/trace_compare/layer_segmentation.py +114 -0
- wafer_core/lib/trace_compare/loader.py +526 -227
- wafer_core/lib/trace_compare/same_kernel_analyzer.py +119 -0
- wafer_core/lib/trace_compare/warnings.py +99 -0
- {wafer_core-0.1.26.dist-info → wafer_core-0.1.27.dist-info}/METADATA +3 -1
- {wafer_core-0.1.26.dist-info → wafer_core-0.1.27.dist-info}/RECORD +16 -8
- {wafer_core-0.1.26.dist-info → wafer_core-0.1.27.dist-info}/WHEEL +0 -0
wafer_core/lib/trace_compare/api.py (new file, +225 lines)

@@ -0,0 +1,225 @@

```python
"""Unified API for trace comparison analysis.

Provides a single entry point that combines all analysis types:
- Operation-level comparison
- Layer-wise comparison
- Kernel-to-kernel alignment
- Fusion analysis
- Same kernel analysis
- Warnings
"""

import sys
import time
from concurrent.futures import ProcessPoolExecutor
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Literal

MetadataCallback = Callable[["StreamingMetadata", "StreamingMetadata"], None]

from .analyzer import analyze_traces_from_loaded, analyze_traces_aligned
from .fusion_analyzer import analyze_fusion_from_alignment
from .same_kernel_analyzer import analyze_same_kernels_from_alignment
from .architecture import ArchitectureType, detect_architecture
from .loader import load_trace_full, ProgressCallback, StreamingMetadata, _extract_metadata_fast
from .warnings import TraceWarning, detect_warnings


@dataclass(frozen=True)
class TraceComparisonResult:
    """Complete trace comparison result with all analysis types."""

    metadata: dict[str, Any]
    operations: list[dict[str, Any]]
    layers: list[dict[str, Any]]
    fusion_opportunities: list[dict[str, Any]]
    fusion_mappings: list[dict[str, Any]]
    warnings: list[TraceWarning]
    architecture: ArchitectureType
    # New alignment-based fields
    layer_alignments: list[dict[str, Any]] | None = None
    fusion_analysis: dict[str, Any] | None = None
    same_kernel_analysis: dict[str, Any] | None = None


def analyze_trace_pair(
    trace1_path: str | Path,
    trace2_path: str | Path,
    phase: Literal["all", "prefill", "decode"] = "all",
    include_stacks: bool = True,
    on_progress: ProgressCallback | None = None,
    on_metadata: MetadataCallback | None = None,
) -> TraceComparisonResult:
    """Single entry point combining all analyses.

    Args:
        trace1_path: Path to first trace file
        trace2_path: Path to second trace file
        phase: Filter by phase ('all', 'prefill', or 'decode')
        include_stacks: Whether to include Python stack traces
        on_progress: Optional callback for progress updates: (stage_name, progress_fraction)
        on_metadata: Optional callback for early metadata (~2ms): (trace1_meta, trace2_meta)

    Returns:
        Complete comparison result with all analysis types
    """
    trace1_path = Path(trace1_path)
    trace2_path = Path(trace2_path)

    def _progress(stage: str, fraction: float) -> None:
        if on_progress:
            on_progress(stage, fraction)

    if on_metadata:
        meta1 = _extract_metadata_fast(trace1_path)
        meta2 = _extract_metadata_fast(trace2_path)
        on_metadata(meta1, meta2)

    t0 = time.perf_counter()
    _progress("Loading traces", 0.0)

    with ProcessPoolExecutor(max_workers=2) as executor:
        future1 = executor.submit(load_trace_full, trace1_path, include_stacks, None)
        future2 = executor.submit(load_trace_full, trace2_path, include_stacks, None)

        if on_progress:
            import threading
            stop_progress = threading.Event()

            def progress_updater():
                base_progress = 0.0
                max_progress = 0.75
                start_time = time.perf_counter()
                estimated_duration = 30.0

                stages = [
                    ("Reading files", 0.05),
                    ("Parsing JSON", 0.15),
                    ("Processing events", 0.40),
                    ("Building DataFrames", 0.60),
                    ("Finalizing", 0.75),
                ]

                stage_idx = 0
                last_progress = 0.0
                while not stop_progress.is_set():
                    elapsed = time.perf_counter() - start_time
                    if elapsed > estimated_duration:
                        elapsed = estimated_duration

                    if stage_idx < len(stages):
                        stage_name, stage_max = stages[stage_idx]
                        stage_progress = min(elapsed / estimated_duration, 1.0) * (stage_max - base_progress)
                        current_progress = base_progress + stage_progress

                        if current_progress - last_progress >= 0.01:
                            _progress(f"Loading: {stage_name}", current_progress)
                            last_progress = current_progress

                        if current_progress >= stage_max and stage_idx < len(stages) - 1:
                            base_progress = stage_max
                            stage_idx += 1
                    else:
                        if max_progress - last_progress >= 0.01:
                            _progress("Loading traces", max_progress)
                            last_progress = max_progress

                    if stop_progress.wait(timeout=0.2):
                        break

            progress_thread = threading.Thread(target=progress_updater, daemon=True)
            progress_thread.start()

        trace1 = future1.result()
        trace2 = future2.result()

        if on_progress:
            stop_progress.set()
            progress_thread.join(timeout=1.0)

    # Normalize trace order: trace1 should always be AMD, trace2 should be NVIDIA
    # This ensures consistent output where trace1_* fields always refer to AMD
    if trace1.platform != "AMD" and trace2.platform == "AMD":
        trace1, trace2 = trace2, trace1

    t1 = time.perf_counter()
    print(f"Trace loading: {t1-t0:.1f}s", file=sys.stderr)
    _progress("Traces loaded", 0.8)

    t2_start = time.perf_counter()
    _progress("Detecting architecture", 0.8)
    all_kernel_names = list(trace1.df["name"].unique()) + list(trace2.df["name"].unique())
    architecture, _ = detect_architecture(all_kernel_names)

    _progress("Comparing operations", 0.85)
    comparison_result = analyze_traces_from_loaded(trace1, trace2, phase_filter=phase)
    t2_end = time.perf_counter()
    print(f"Operation analysis: {t2_end-t2_start:.1f}s", file=sys.stderr)

    _progress("Aligning kernels", 0.9)
    t3_start = time.perf_counter()
    alignment_result = analyze_traces_aligned(trace1, trace2, phase_filter=phase)
    t3_end = time.perf_counter()
    print(f"Alignment analysis: {t3_end-t3_start:.1f}s", file=sys.stderr)

    t4_start = time.perf_counter()
    kernel_names1 = [ev.get("name", "") for ev in trace1.all_events if ev.get("cat") == "kernel"]
    kernel_names2 = [ev.get("name", "") for ev in trace2.all_events if ev.get("cat") == "kernel"]

    phases1 = [
        ev for ev in trace1.all_events
        if ev.get("cat") == "user_annotation" and ev.get("name", "").startswith("execute_context")
    ]
    phases2 = [
        ev for ev in trace2.all_events
        if ev.get("cat") == "user_annotation" and ev.get("name", "").startswith("execute_context")
    ]

    warnings1 = detect_warnings(
        trace1.all_events,
        kernel_names1,
        phases1,
        comparison_result["metadata"].get("trace1_layers", 0),
        len(kernel_names1),
    )
    warnings2 = detect_warnings(
        trace2.all_events,
        kernel_names2,
        phases2,
        comparison_result["metadata"].get("trace2_layers", 0),
        len(kernel_names2),
    )
    t4_end = time.perf_counter()
    print(f"Warning detection: {t4_end-t4_start:.1f}s", file=sys.stderr)

    all_warnings: list[TraceWarning] = []
    seen_codes: set[str] = set()
    for warning in warnings1 + warnings2:
        if warning.code not in seen_codes:
            all_warnings.append(warning)
            seen_codes.add(warning.code)

    print(f"Total analysis time: {t4_end-t0:.1f}s", file=sys.stderr)

    _progress("Complete", 1.0)

    fusion_opportunities = []
    fusion_mappings = []
    if alignment_result.get("fusion_analysis"):
        fusion_analysis = alignment_result["fusion_analysis"]
        fusion_opportunities = fusion_analysis.get("fusion_opportunities", [])
        fusion_mappings = fusion_analysis.get("fusion_mappings", [])

    return TraceComparisonResult(
        metadata=comparison_result["metadata"],
        operations=comparison_result["operations"],
        layers=comparison_result.get("layers", []),
        fusion_opportunities=fusion_opportunities,
        fusion_mappings=fusion_mappings,
        warnings=all_warnings,
        architecture=architecture,
        layer_alignments=alignment_result.get("layer_alignments"),
        fusion_analysis=alignment_result.get("fusion_analysis"),
        same_kernel_analysis=alignment_result.get("same_kernel_analysis"),
    )
```
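Taken together, `analyze_trace_pair` is the one call a consumer needs: it loads both traces in parallel, detects the architecture, and runs the operation, alignment, fusion, and warning analyses before returning a single frozen result object. Below is a minimal usage sketch, assuming the module is importable as `wafer_core.lib.trace_compare.api` (mirroring the file path in the list above); the trace paths and the `show_progress` helper are hypothetical placeholders, not part of the package.

```python
from pathlib import Path

from wafer_core.lib.trace_compare.api import analyze_trace_pair


def show_progress(stage: str, fraction: float) -> None:
    # Matches the on_progress contract above: (stage_name, progress_fraction in [0, 1]).
    print(f"[{fraction:5.1%}] {stage}")


# Hypothetical trace files; any pair of Chrome-trace JSON profiles would do.
result = analyze_trace_pair(
    Path("traces/amd_run.json"),
    Path("traces/nvidia_run.json"),
    phase="decode",          # or "all" / "prefill"
    include_stacks=False,
    on_progress=show_progress,
)

# The API normalizes ordering so trace1_* fields refer to the AMD trace.
print(result.architecture)
print(f"{len(result.operations)} compared operations, {len(result.layers)} layers")
for warning in result.warnings:
    print(warning.code)
```

The `on_metadata` hook serves UIs that want header information within milliseconds, before the full load (estimated at up to ~30 s in the progress heuristic) completes.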
wafer_core/lib/trace_compare/architecture.py (new file, +77 lines)

@@ -0,0 +1,77 @@

```python
"""Architecture detection from kernel patterns.

Detects model architecture type (Transformer, SSM, Hybrid) from kernel names
to enable architecture-specific layer segmentation.
"""

from enum import Enum
from typing import Literal


class ArchitectureType(Enum):
    """Model architecture types."""

    TRANSFORMER = "transformer"
    SSM = "ssm"
    HYBRID = "hybrid"
    UNKNOWN = "unknown"


def detect_architecture(kernel_names: list[str]) -> tuple[ArchitectureType, list[str]]:
    """Detect model architecture from kernel patterns.

    Args:
        kernel_names: List of all kernel names from trace

    Returns:
        Tuple of (architecture_type, detected_markers)
        detected_markers: List of kernel names that indicate the architecture
    """
    kernel_names_lower = [name.lower() for name in kernel_names]

    attention_patterns = [
        "fmha",
        "attention",
        "flash",
        "sdpa",
    ]

    ssm_patterns = [
        "selective_scan",
        "mamba",
        "ssd",
        "causal_conv",
    ]

    has_attention = any(
        any(pattern in name for pattern in attention_patterns)
        for name in kernel_names_lower
    )

    has_ssm = any(
        any(pattern in name for pattern in ssm_patterns)
        for name in kernel_names_lower
    )

    markers: list[str] = []
    if has_attention:
        attention_markers = [
            name for name in kernel_names
            if any(pattern in name.lower() for pattern in attention_patterns)
        ]
        markers.extend(attention_markers[:5])

    if has_ssm:
        ssm_markers = [
            name for name in kernel_names
            if any(pattern in name.lower() for pattern in ssm_patterns)
        ]
        markers.extend(ssm_markers[:5])
    if has_attention and not has_ssm:
        return ArchitectureType.TRANSFORMER, markers
    elif has_ssm and not has_attention:
        return ArchitectureType.SSM, markers
    elif has_attention and has_ssm:
        return ArchitectureType.HYBRID, markers
    else:
        return ArchitectureType.UNKNOWN, []
```
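`detect_architecture` is a pure function over kernel names, so it is easy to exercise in isolation. A small sketch follows, assuming the module path mirrors the file location above; the kernel names are made up for illustration and would normally come from trace events.

```python
from wafer_core.lib.trace_compare.architecture import ArchitectureType, detect_architecture

# Made-up kernel names for illustration only.
transformer_kernels = ["fmha_fwd_kernel", "gemm_bf16_tn", "triton_softmax"]
hybrid_kernels = transformer_kernels + ["mamba_selective_scan_fwd"]

arch, markers = detect_architecture(transformer_kernels)
assert arch is ArchitectureType.TRANSFORMER
print(markers)  # up to 5 attention-style names, e.g. ['fmha_fwd_kernel']

arch, markers = detect_architecture(hybrid_kernels)
assert arch is ArchitectureType.HYBRID  # attention and SSM patterns both present
```

Matching is substring-based on lowercased names and the marker list is truncated to five entries per family, so markers are representative rather than exhaustive; a kernel list with neither attention nor SSM patterns yields `ArchitectureType.UNKNOWN` with no markers.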