wafer-core 0.1.26__py3-none-any.whl → 0.1.28__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,225 @@
1
+ """Unified API for trace comparison analysis.
2
+
3
+ Provides a single entry point that combines all analysis types:
4
+ - Operation-level comparison
5
+ - Layer-wise comparison
6
+ - Kernel-to-kernel alignment
7
+ - Fusion analysis
8
+ - Same kernel analysis
9
+ - Warnings
10
+ """
11
+
12
+ import sys
13
+ import time
14
+ from concurrent.futures import ProcessPoolExecutor
15
+ from dataclasses import dataclass
16
+ from pathlib import Path
17
+ from typing import Any, Callable, Literal
18
+
19
+ MetadataCallback = Callable[["StreamingMetadata", "StreamingMetadata"], None]
20
+
21
+ from .analyzer import analyze_traces_from_loaded, analyze_traces_aligned
22
+ from .fusion_analyzer import analyze_fusion_from_alignment
23
+ from .same_kernel_analyzer import analyze_same_kernels_from_alignment
24
+ from .architecture import ArchitectureType, detect_architecture
25
+ from .loader import load_trace_full, ProgressCallback, StreamingMetadata, _extract_metadata_fast
26
+ from .warnings import TraceWarning, detect_warnings
27
+
28
+
29
@dataclass(frozen=True)
class TraceComparisonResult:
    """Complete trace comparison result with all analysis types."""

    # Summary metadata from the operation-level comparison
    # (analyze_traces_from_loaded result's "metadata" dict).
    metadata: dict[str, Any]
    # Per-operation comparison rows from the operation-level analysis.
    operations: list[dict[str, Any]]
    # Layer-wise comparison rows; empty list when the comparison
    # produced no "layers" entry.
    layers: list[dict[str, Any]]
    # Fusion opportunities extracted from the alignment's fusion analysis
    # (empty when no fusion analysis was produced).
    fusion_opportunities: list[dict[str, Any]]
    # Kernel fusion mappings from the same fusion analysis.
    fusion_mappings: list[dict[str, Any]]
    # Warnings from both traces, de-duplicated by warning code.
    warnings: list[TraceWarning]
    # Architecture detected from the combined kernel names of both traces.
    architecture: ArchitectureType
    # New alignment-based fields (from analyze_traces_aligned); None when
    # the alignment step did not produce the corresponding section.
    layer_alignments: list[dict[str, Any]] | None = None
    fusion_analysis: dict[str, Any] | None = None
    same_kernel_analysis: dict[str, Any] | None = None
44
+
45
+
46
def analyze_trace_pair(
    trace1_path: str | Path,
    trace2_path: str | Path,
    phase: Literal["all", "prefill", "decode"] = "all",
    include_stacks: bool = True,
    on_progress: ProgressCallback | None = None,
    on_metadata: MetadataCallback | None = None,
) -> TraceComparisonResult:
    """Single entry point combining all analyses.

    Args:
        trace1_path: Path to first trace file
        trace2_path: Path to second trace file
        phase: Filter by phase ('all', 'prefill', or 'decode')
        include_stacks: Whether to include Python stack traces
        on_progress: Optional callback for progress updates: (stage_name, progress_fraction)
        on_metadata: Optional callback for early metadata (~2ms): (trace1_meta, trace2_meta)

    Returns:
        Complete comparison result with all analysis types
    """
    trace1_path = Path(trace1_path)
    trace2_path = Path(trace2_path)

    def _progress(stage: str, fraction: float) -> None:
        # Forward progress only when the caller supplied a callback.
        if on_progress:
            on_progress(stage, fraction)

    if on_metadata:
        # Cheap header-only scan so callers can render trace metadata
        # before the (slow) full load completes.
        meta1 = _extract_metadata_fast(trace1_path)
        meta2 = _extract_metadata_fast(trace2_path)
        on_metadata(meta1, meta2)

    t0 = time.perf_counter()
    _progress("Loading traces", 0.0)

    # Load both traces in parallel worker processes.
    with ProcessPoolExecutor(max_workers=2) as executor:
        future1 = executor.submit(load_trace_full, trace1_path, include_stacks, None)
        future2 = executor.submit(load_trace_full, trace2_path, include_stacks, None)

        progress_thread = None
        stop_progress = None
        if on_progress:
            import threading

            stop_progress = threading.Event()

            def progress_updater() -> None:
                # Real loading progress is not observable from the worker
                # processes, so simulate it against a fixed time budget,
                # walking through named stages up to 75%.
                base_progress = 0.0
                start_time = time.perf_counter()
                estimated_duration = 30.0  # heuristic full-load time budget (seconds)

                stages = [
                    ("Reading files", 0.05),
                    ("Parsing JSON", 0.15),
                    ("Processing events", 0.40),
                    ("Building DataFrames", 0.60),
                    ("Finalizing", 0.75),
                ]

                stage_idx = 0
                last_progress = 0.0
                while not stop_progress.is_set():
                    # Cap elapsed so the simulated fraction never exceeds
                    # the final stage's ceiling.
                    elapsed = min(time.perf_counter() - start_time, estimated_duration)

                    # NOTE: stage_idx is capped at len(stages) - 1 below, so
                    # stages[stage_idx] is always valid (the original code
                    # carried an unreachable fallback branch here).
                    stage_name, stage_max = stages[stage_idx]
                    stage_progress = (elapsed / estimated_duration) * (stage_max - base_progress)
                    current_progress = base_progress + stage_progress

                    # Throttle callbacks to >= 1% increments.
                    if current_progress - last_progress >= 0.01:
                        _progress(f"Loading: {stage_name}", current_progress)
                        last_progress = current_progress

                    if current_progress >= stage_max and stage_idx < len(stages) - 1:
                        base_progress = stage_max
                        stage_idx += 1

                    if stop_progress.wait(timeout=0.2):
                        break

            progress_thread = threading.Thread(target=progress_updater, daemon=True)
            progress_thread.start()

        try:
            trace1 = future1.result()
            trace2 = future2.result()
        finally:
            # Fix: stop the simulated-progress thread even when loading
            # raises, so a failed load does not keep emitting callbacks.
            if progress_thread is not None:
                stop_progress.set()
                progress_thread.join(timeout=1.0)

    # Normalize trace order: trace1 should always be AMD, trace2 should be NVIDIA
    # This ensures consistent output where trace1_* fields always refer to AMD
    if trace1.platform != "AMD" and trace2.platform == "AMD":
        trace1, trace2 = trace2, trace1

    t1 = time.perf_counter()
    print(f"Trace loading: {t1-t0:.1f}s", file=sys.stderr)
    _progress("Traces loaded", 0.8)

    t2_start = time.perf_counter()
    _progress("Detecting architecture", 0.8)
    # Architecture is detected from the union of kernel names in both traces.
    all_kernel_names = list(trace1.df["name"].unique()) + list(trace2.df["name"].unique())
    architecture, _ = detect_architecture(all_kernel_names)

    _progress("Comparing operations", 0.85)
    comparison_result = analyze_traces_from_loaded(trace1, trace2, phase_filter=phase)
    t2_end = time.perf_counter()
    print(f"Operation analysis: {t2_end-t2_start:.1f}s", file=sys.stderr)

    _progress("Aligning kernels", 0.9)
    t3_start = time.perf_counter()
    alignment_result = analyze_traces_aligned(trace1, trace2, phase_filter=phase)
    t3_end = time.perf_counter()
    print(f"Alignment analysis: {t3_end-t3_start:.1f}s", file=sys.stderr)

    t4_start = time.perf_counter()
    kernel_names1 = [ev.get("name", "") for ev in trace1.all_events if ev.get("cat") == "kernel"]
    kernel_names2 = [ev.get("name", "") for ev in trace2.all_events if ev.get("cat") == "kernel"]

    # Phase boundaries are user annotations named "execute_context*".
    phases1 = [
        ev for ev in trace1.all_events
        if ev.get("cat") == "user_annotation" and ev.get("name", "").startswith("execute_context")
    ]
    phases2 = [
        ev for ev in trace2.all_events
        if ev.get("cat") == "user_annotation" and ev.get("name", "").startswith("execute_context")
    ]

    warnings1 = detect_warnings(
        trace1.all_events,
        kernel_names1,
        phases1,
        comparison_result["metadata"].get("trace1_layers", 0),
        len(kernel_names1),
    )
    warnings2 = detect_warnings(
        trace2.all_events,
        kernel_names2,
        phases2,
        comparison_result["metadata"].get("trace2_layers", 0),
        len(kernel_names2),
    )
    t4_end = time.perf_counter()
    print(f"Warning detection: {t4_end-t4_start:.1f}s", file=sys.stderr)

    # De-duplicate warnings across both traces by warning code, keeping
    # first occurrence (trace1's warnings win ties).
    all_warnings: list[TraceWarning] = []
    seen_codes: set[str] = set()
    for warning in warnings1 + warnings2:
        if warning.code not in seen_codes:
            all_warnings.append(warning)
            seen_codes.add(warning.code)

    print(f"Total analysis time: {t4_end-t0:.1f}s", file=sys.stderr)

    _progress("Complete", 1.0)

    # Flatten the optional fusion section into the two top-level lists.
    fusion_opportunities = []
    fusion_mappings = []
    if alignment_result.get("fusion_analysis"):
        fusion_analysis = alignment_result["fusion_analysis"]
        fusion_opportunities = fusion_analysis.get("fusion_opportunities", [])
        fusion_mappings = fusion_analysis.get("fusion_mappings", [])

    return TraceComparisonResult(
        metadata=comparison_result["metadata"],
        operations=comparison_result["operations"],
        layers=comparison_result.get("layers", []),
        fusion_opportunities=fusion_opportunities,
        fusion_mappings=fusion_mappings,
        warnings=all_warnings,
        architecture=architecture,
        layer_alignments=alignment_result.get("layer_alignments"),
        fusion_analysis=alignment_result.get("fusion_analysis"),
        same_kernel_analysis=alignment_result.get("same_kernel_analysis"),
    )
@@ -0,0 +1,77 @@
1
+ """Architecture detection from kernel patterns.
2
+
3
+ Detects model architecture type (Transformer, SSM, Hybrid) from kernel names
4
+ to enable architecture-specific layer segmentation.
5
+ """
6
+
7
+ from enum import Enum
8
+ from typing import Literal
9
+
10
+
11
class ArchitectureType(Enum):
    """Model architecture types."""

    TRANSFORMER = "transformer"  # attention-family kernels only
    SSM = "ssm"  # state-space (Mamba-style) kernels only
    HYBRID = "hybrid"  # both families present
    UNKNOWN = "unknown"  # neither family detected


def detect_architecture(kernel_names: list[str]) -> tuple[ArchitectureType, list[str]]:
    """Detect model architecture from kernel patterns.

    Args:
        kernel_names: List of all kernel names from trace

    Returns:
        Tuple of (architecture_type, detected_markers)
        detected_markers: List of kernel names that indicate the architecture
        (up to 5 attention markers followed by up to 5 SSM markers, in trace
        order; empty for UNKNOWN)
    """
    # Case-insensitive substrings identifying attention-based kernels.
    attention_patterns = ("fmha", "attention", "flash", "sdpa")
    # Substrings identifying state-space-model (Mamba-style) kernels.
    ssm_patterns = ("selective_scan", "mamba", "ssd", "causal_conv")

    # Single pass: lower-case each name once and classify it against both
    # pattern families.  (The original lowered and re-scanned the list for
    # each of the four checks.)  A name matching both families appears in
    # both marker lists, matching the original behavior.
    attention_markers: list[str] = []
    ssm_markers: list[str] = []
    for name in kernel_names:
        lowered = name.lower()
        if any(pattern in lowered for pattern in attention_patterns):
            attention_markers.append(name)
        if any(pattern in lowered for pattern in ssm_patterns):
            ssm_markers.append(name)

    # Keep at most 5 example markers per family, attention family first.
    markers = attention_markers[:5] + ssm_markers[:5]

    if attention_markers and ssm_markers:
        return ArchitectureType.HYBRID, markers
    if attention_markers:
        return ArchitectureType.TRANSFORMER, markers
    if ssm_markers:
        return ArchitectureType.SSM, markers
    return ArchitectureType.UNKNOWN, []