wafer-core 0.1.26__py3-none-any.whl → 0.1.28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer_core/lib/trace_compare/PERFORMANCE.md +148 -0
- wafer_core/lib/trace_compare/__init__.py +22 -9
- wafer_core/lib/trace_compare/aligner.py +376 -0
- wafer_core/lib/trace_compare/analyzer.py +558 -159
- wafer_core/lib/trace_compare/api.py +225 -0
- wafer_core/lib/trace_compare/architecture.py +77 -0
- wafer_core/lib/trace_compare/classifier.py +307 -13
- wafer_core/lib/trace_compare/fusion_analyzer.py +280 -706
- wafer_core/lib/trace_compare/kernel_registry.yaml +349 -0
- wafer_core/lib/trace_compare/layer_segmentation.py +114 -0
- wafer_core/lib/trace_compare/loader.py +526 -227
- wafer_core/lib/trace_compare/same_kernel_analyzer.py +119 -0
- wafer_core/lib/trace_compare/warnings.py +99 -0
- wafer_core/targets/__init__.py +47 -21
- wafer_core/targets/pool.py +181 -0
- wafer_core/targets/probe.py +113 -0
- wafer_core/targets/providers/__init__.py +46 -0
- wafer_core/targets/providers/baremetal.py +72 -0
- wafer_core/targets/providers/digitalocean.py +164 -0
- wafer_core/targets/providers/runpod.py +250 -0
- wafer_core/targets/reconcile.py +90 -0
- wafer_core/targets/spec_store.py +200 -0
- wafer_core/targets/state_cache.py +150 -0
- wafer_core/targets/types.py +141 -0
- wafer_core/utils/kernel_utils/targets/config.py +8 -24
- {wafer_core-0.1.26.dist-info → wafer_core-0.1.28.dist-info}/METADATA +3 -1
- {wafer_core-0.1.26.dist-info → wafer_core-0.1.28.dist-info}/RECORD +28 -10
- {wafer_core-0.1.26.dist-info → wafer_core-0.1.28.dist-info}/WHEEL +0 -0
|
@@ -5,233 +5,268 @@ Python call stacks, CPU operator mappings, and layer correlations.
|
|
|
5
5
|
"""
|
|
6
6
|
|
|
7
7
|
import bisect
|
|
8
|
-
import
|
|
8
|
+
import sys
|
|
9
9
|
from collections import defaultdict
|
|
10
|
+
from collections.abc import Callable
|
|
11
|
+
from dataclasses import dataclass, field
|
|
10
12
|
from pathlib import Path
|
|
11
13
|
from typing import Any
|
|
12
14
|
|
|
15
|
+
try:
|
|
16
|
+
import ijson
|
|
17
|
+
except ImportError:
|
|
18
|
+
ijson = None
|
|
19
|
+
|
|
20
|
+
import orjson
|
|
13
21
|
import pandas as pd
|
|
14
22
|
|
|
23
|
+
from .architecture import detect_architecture
|
|
15
24
|
from .classifier import classify
|
|
25
|
+
from .layer_segmentation import segment_layers_by_architecture
|
|
16
26
|
|
|
27
|
+
# Progress hook signature used by the loaders below: called as
# on_progress(stage_name, fraction), where the loaders report fractions from 0.0
# ("Reading file") up to 1.0 ("Complete").
ProgressCallback = Callable[[str, float], None]
|
|
17
28
|
|
|
18
|
-
def extract_layer_mapping(events: list[dict[str, Any]], platform: str) -> dict[int, int]:
|
|
19
|
-
"""Extract correlation ID to layer number mapping.
|
|
20
|
-
|
|
21
|
-
vLLM's execution graph creates large correlation groups for full transformer layers.
|
|
22
|
-
Each layer's forward pass (norm + attention + FFN) gets grouped under one correlation ID,
|
|
23
|
-
containing 200-400 kernels depending on batch size and sequence length.
|
|
24
|
-
|
|
25
|
-
We identify layers as correlation groups with many kernels (70+), which filters out
|
|
26
|
-
individual operations like sampling, logit processing, etc.
|
|
27
29
|
|
|
30
|
+
@dataclass
|
|
31
|
+
class SinglePassResult:
|
|
32
|
+
"""Collected data from single-pass event processing."""
|
|
33
|
+
cpu_op_mapping: dict[int, str] = field(default_factory=dict)
|
|
34
|
+
python_intervals: list[tuple[int, int, int, int | None, str]] = field(default_factory=list)
|
|
35
|
+
# Raw python events for lazy python_by_id construction (built on-demand in _get_python_stack_full)
|
|
36
|
+
python_events_raw: list[dict[str, Any]] = field(default_factory=list)
|
|
37
|
+
phases: list[dict[str, Any]] = field(default_factory=list)
|
|
38
|
+
correlation_groups: dict[int, dict[str, Any]] = field(default_factory=lambda: defaultdict(
|
|
39
|
+
lambda: {"count": 0, "has_attention": False, "has_ffn": False}
|
|
40
|
+
))
|
|
41
|
+
kernel_events: list[dict[str, Any]] = field(default_factory=list)
|
|
42
|
+
# Lazily built when needed for stack resolution
|
|
43
|
+
_python_by_id: dict[int, dict[str, Any]] | None = field(default=None)
|
|
44
|
+
|
|
45
|
+
@property
|
|
46
|
+
def python_by_id(self) -> dict[int, dict[str, Any]]:
|
|
47
|
+
"""Lazily build python_by_id from raw events on first access."""
|
|
48
|
+
if self._python_by_id is None:
|
|
49
|
+
self._python_by_id = {}
|
|
50
|
+
for ev in self.python_events_raw:
|
|
51
|
+
args = ev.get("args")
|
|
52
|
+
py_id = args.get("Python id") if args else None
|
|
53
|
+
if py_id is not None:
|
|
54
|
+
self._python_by_id[py_id] = {
|
|
55
|
+
"name": ev["name"],
|
|
56
|
+
"parent_id": args.get("Python parent id") if args else None,
|
|
57
|
+
}
|
|
58
|
+
return self._python_by_id
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
@dataclass
class LoadedTrace:
    """Complete trace data loaded once and reused by all analysis functions.

    The bulky payload fields (the kernel DataFrame, the raw event lists, the
    correlation groups, the layer map and the phase list) are excluded from the
    generated ``repr``. Printing, logging or inspecting a ``LoadedTrace`` in a
    debugger then shows the small identifying fields instead of formatting
    every event of a large trace. Construction and equality are unchanged.
    """

    platform: str  # "AMD" or "NVIDIA"
    gpu_name: str
    device_props: dict[str, Any]
    # One row per kernel event (name, dur_us, phase, op, pattern, layer, ...).
    df: pd.DataFrame = field(repr=False)
    # (op, phase) -> set of kernel patterns seen for that pair.
    patterns: dict[tuple[str, str], set[str]]
    # correlation id -> layer number.
    layers: dict[int, int] = field(repr=False)
    # For fusion/warnings (kept from raw JSON)
    kernel_events: list[dict[str, Any]] = field(repr=False)
    all_events: list[dict[str, Any]] = field(repr=False)
    # correlation id -> kernel events sharing that correlation id.
    correlation_groups: dict[int, list[dict[str, Any]]] = field(repr=False)
    phases: list[dict[str, Any]] = field(repr=False)  # Phase annotations for alignment
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def _process_events_single_pass(
    events: list[dict[str, Any]],
    include_stacks: bool = True,
) -> SinglePassResult:
    """Process all events in a single iteration.

    Events are bucketed by their ``cat`` field:

    * ``"kernel"``: appended to ``kernel_events``. Kernels that carry a
      correlation id also update that id's counters: kernel count, whether an
      attention-like kernel was seen ("attention"/"fmha"), and whether a
      GEMM-like kernel was seen ("cijk_"/"nvjet"/"wvsplitk"/"gemm").
    * ``"cpu_op"``: ``External id`` -> op name, stored in ``cpu_op_mapping``.
    * ``"python_function"`` (only when ``include_stacks``): the raw event is kept
      for the lazy ``python_by_id`` index, and a
      ``(start, end, dur, python_id, name)`` interval is recorded for binary search.
    * ``"user_annotation"`` whose name starts with ``execute_context``: becomes a
      phase record. It is ``"prefill"`` when the name contains ``generation_0``
      and a positive context token count, otherwise ``"decode"``.

    Optimizations applied:
    - Cache list.append methods in locals on the hot path
    - Store raw python events; python_by_id is built lazily by the result object
    - Local variable caching for frequently accessed attributes

    Args:
        events: List of trace events
        include_stacks: Whether to collect Python stack info (expensive operation).
            When False, skips python_function processing entirely.

    Returns:
        A ``SinglePassResult``. ``python_intervals`` is sorted by (start, end),
        and ``correlation_groups`` is a plain dict.
    """
    result = SinglePassResult()
    correlation_groups: dict[int, dict[str, Any]] = defaultdict(
        lambda: {"count": 0, "has_attention": False, "has_ffn": False}
    )

    # Bound methods / attributes hoisted out of the loop.
    kernel_append = result.kernel_events.append
    python_interval_append = result.python_intervals.append
    python_raw_append = result.python_events_raw.append
    phase_append = result.phases.append
    cpu_op_mapping = result.cpu_op_mapping

    for ev in events:
        cat = ev.get("cat")

        if cat == "kernel":
            args = ev.get("args")
            corr_id = args.get("correlation") if args else None
            if corr_id is not None:
                kernel_name = ev.get("name", "").lower()
                grp = correlation_groups[corr_id]
                grp["count"] += 1
                if "attention" in kernel_name or "fmha" in kernel_name:
                    grp["has_attention"] = True
                if (
                    "cijk_" in kernel_name
                    or "nvjet" in kernel_name
                    or "wvsplitk" in kernel_name
                    or "gemm" in kernel_name
                ):
                    grp["has_ffn"] = True
            kernel_append(ev)

        elif cat == "cpu_op":
            args = ev.get("args")
            ext_id = args.get("External id") if args else None
            if ext_id is not None:
                cpu_op_mapping[ext_id] = ev.get("name", "")

        elif cat == "python_function" and include_stacks:
            # Raw event feeds the lazy python_by_id index.
            python_raw_append(ev)
            # Interval tuple for binary search at stack-resolution time.
            args = ev.get("args")
            py_id = args.get("Python id") if args else None
            ts = ev["ts"]
            dur = ev.get("dur", 0)
            python_interval_append((ts, ts + dur, dur, py_id, ev["name"]))

        elif cat == "user_annotation":
            name = ev.get("name", "")
            if name.startswith("execute_context"):
                # Token count is the "(N)" payload of the part following "context",
                # e.g. "execute_context_0(128)_generation_0(0)" -> 128.
                tokens = 0
                parts = name.split("_")
                for i, p in enumerate(parts):
                    if i > 0 and parts[i - 1] == "context" and "(" in p and ")" in p:
                        try:
                            tokens = int(p.split("(")[1].split(")")[0])
                            break
                        except ValueError:
                            # Non-numeric payload: keep scanning the remaining parts.
                            pass
                is_prefill = "generation_0" in name and tokens > 0
                ts = ev["ts"]
                phase_append({
                    "type": "prefill" if is_prefill else "decode",
                    "ts_start": ts,
                    # Same missing-duration default as the other branches.
                    "ts_end": ts + ev.get("dur", 0),
                })

    if include_stacks and result.python_intervals:
        # Sort on (start, end) only. Comparing whole tuples would reach the
        # python_id slot on ties and raise TypeError for None vs int. The
        # binary search in stack resolution only looks at (start, end).
        result.python_intervals.sort(key=lambda iv: (iv[0], iv[1]))

    result.correlation_groups = dict(correlation_groups)

    return result
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def _build_layer_mapping(correlation_groups: dict[int, dict[str, Any]]) -> dict[int, int]:
|
|
165
|
+
"""Build layer mapping from correlation groups."""
|
|
59
166
|
correlation_to_layer = {}
|
|
60
167
|
layer_num = 0
|
|
61
|
-
|
|
168
|
+
|
|
62
169
|
for corr_id in sorted(correlation_groups.keys()):
|
|
63
170
|
group = correlation_groups[corr_id]
|
|
64
|
-
|
|
65
|
-
# Identify complete transformer layers by their characteristics:
|
|
66
|
-
# - Has attention operations (self-attention or cross-attention)
|
|
67
|
-
# - Has FFN operations (feed-forward network)
|
|
68
|
-
# - Has sufficient kernel count (70+): typical transformer block has ~80-100 kernels
|
|
69
|
-
# including attention QKV projections, softmax, output projection, FFN layers,
|
|
70
|
-
# normalization, and elementwise ops. This threshold filters out:
|
|
71
|
-
# - Individual operations (1-10 kernels)
|
|
72
|
-
# - Sampling/generation steps (20-40 kernels)
|
|
73
|
-
# - Partial layer executions
|
|
171
|
+
|
|
74
172
|
is_layer = (
|
|
75
173
|
group["count"] >= 70 and group["has_attention"] and group["has_ffn"]
|
|
76
174
|
)
|
|
77
|
-
|
|
175
|
+
|
|
78
176
|
if is_layer:
|
|
79
177
|
correlation_to_layer[corr_id] = layer_num
|
|
80
178
|
layer_num += 1
|
|
81
|
-
|
|
179
|
+
|
|
82
180
|
return correlation_to_layer
|
|
83
181
|
|
|
84
182
|
|
|
85
|
-
def _build_python_stack_index(
|
|
86
|
-
events: list[dict[str, Any]],
|
|
87
|
-
) -> tuple[list[tuple[int, int, int, int | None, str]], dict[int, dict[str, Any]]]:
|
|
88
|
-
"""Build Python call stack index for kernels.
|
|
89
|
-
|
|
90
|
-
Args:
|
|
91
|
-
events: List of trace events
|
|
92
|
-
|
|
93
|
-
Returns:
|
|
94
|
-
Tuple of (python_intervals, python_by_id)
|
|
95
|
-
"""
|
|
96
|
-
python_by_id: dict[int, dict[str, Any]] = {}
|
|
97
|
-
python_intervals: list[tuple[int, int, int, int | None, str]] = []
|
|
98
|
-
|
|
99
|
-
for ev in events:
|
|
100
|
-
if ev.get("cat") == "python_function":
|
|
101
|
-
py_id = ev.get("args", {}).get("Python id")
|
|
102
|
-
name = ev["name"]
|
|
103
|
-
ts_start = ev["ts"]
|
|
104
|
-
ts_end = ts_start + ev.get("dur", 0)
|
|
105
|
-
duration = ev.get("dur", 0)
|
|
106
|
-
parent_id = ev.get("args", {}).get("Python parent id")
|
|
107
|
-
|
|
108
|
-
python_intervals.append((ts_start, ts_end, duration, py_id, name))
|
|
109
|
-
|
|
110
|
-
if py_id is not None:
|
|
111
|
-
python_by_id[py_id] = {
|
|
112
|
-
"name": name,
|
|
113
|
-
"parent_id": parent_id,
|
|
114
|
-
"ts_start": ts_start,
|
|
115
|
-
"ts_end": ts_end,
|
|
116
|
-
"duration": duration,
|
|
117
|
-
}
|
|
118
|
-
|
|
119
|
-
# Sort by start time for efficient binary search
|
|
120
|
-
python_intervals.sort()
|
|
121
|
-
|
|
122
|
-
return python_intervals, python_by_id
|
|
123
|
-
|
|
124
|
-
|
|
125
183
|
def _get_python_stack_full(
|
|
126
184
|
timestamp: int,
|
|
127
185
|
python_intervals: list[tuple[int, int, int, int | None, str]],
|
|
128
186
|
python_by_id: dict[int, dict[str, Any]],
|
|
129
187
|
) -> tuple[str | None, list[str]]:
|
|
130
188
|
"""Get full Python call stack for a kernel launch.
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
python_intervals: Sorted list of Python function intervals
|
|
135
|
-
python_by_id: Mapping of Python ID to function info
|
|
136
|
-
|
|
137
|
-
Returns:
|
|
138
|
-
Tuple of (summary_string, full_stack_list)
|
|
189
|
+
|
|
190
|
+
Computes the chain on-demand by walking parent pointers.
|
|
191
|
+
This is fast because we only call this ~48 times due to cpu_op caching.
|
|
139
192
|
"""
|
|
140
|
-
# Binary search for Python functions active at this timestamp
|
|
141
193
|
idx = bisect.bisect_right(
|
|
142
194
|
python_intervals, (timestamp, float("inf"), float("inf"), None, "")
|
|
143
195
|
)
|
|
144
|
-
|
|
145
|
-
# Find active functions
|
|
196
|
+
|
|
146
197
|
active_funcs = []
|
|
147
198
|
for i in range(idx - 1, max(0, idx - 1000), -1):
|
|
148
199
|
ts_start, ts_end, duration, py_id, name = python_intervals[i]
|
|
149
200
|
if ts_start <= timestamp <= ts_end:
|
|
150
201
|
active_funcs.append((duration, py_id, name))
|
|
151
|
-
if ts_end < timestamp - 1000000:
|
|
202
|
+
if ts_end < timestamp - 1000000:
|
|
152
203
|
break
|
|
153
|
-
|
|
204
|
+
|
|
154
205
|
if not active_funcs:
|
|
155
206
|
return None, []
|
|
156
|
-
|
|
157
|
-
# Get the innermost (most specific) function
|
|
207
|
+
|
|
158
208
|
active_funcs.sort()
|
|
159
209
|
leaf_duration, leaf_id, leaf_name = active_funcs[0]
|
|
160
|
-
|
|
161
|
-
# Walk up parent chain to get FULL stack
|
|
210
|
+
|
|
162
211
|
full_stack = []
|
|
163
212
|
current_id = leaf_id
|
|
164
|
-
visited = set()
|
|
165
|
-
|
|
166
|
-
while
|
|
167
|
-
current_id is not None
|
|
168
|
-
and current_id not in visited
|
|
169
|
-
and current_id in python_by_id
|
|
170
|
-
):
|
|
213
|
+
visited: set[int] = set()
|
|
214
|
+
|
|
215
|
+
while current_id is not None and current_id not in visited and current_id in python_by_id:
|
|
171
216
|
func = python_by_id[current_id]
|
|
172
|
-
|
|
173
|
-
full_stack.append(name)
|
|
174
|
-
|
|
217
|
+
full_stack.append(func["name"])
|
|
175
218
|
visited.add(current_id)
|
|
176
219
|
current_id = func["parent_id"]
|
|
177
|
-
|
|
178
|
-
# Safety limit: prevent infinite loops from circular parent references
|
|
179
|
-
# and bound memory usage. 50 frames is deeper than typical Python stacks.
|
|
180
220
|
if len(full_stack) >= 50:
|
|
181
221
|
break
|
|
182
|
-
|
|
183
|
-
# Reverse so it's outermost -> innermost
|
|
222
|
+
|
|
184
223
|
full_stack.reverse()
|
|
185
|
-
|
|
186
|
-
# Create summary for text output: show the most informative vLLM/model function
|
|
224
|
+
|
|
187
225
|
summary = None
|
|
188
|
-
vllm_funcs = [
|
|
189
|
-
|
|
190
|
-
for f in full_stack
|
|
191
|
-
if any(x in f.lower() for x in ["vllm/", "model", "<eval_with_key>"])
|
|
192
|
-
]
|
|
193
|
-
|
|
226
|
+
vllm_funcs = [f for f in full_stack if any(x in f.lower() for x in ["vllm/", "model", "<eval_with_key>"])]
|
|
227
|
+
|
|
194
228
|
if vllm_funcs:
|
|
195
|
-
# Get innermost vLLM function (most specific)
|
|
196
229
|
summary = vllm_funcs[-1]
|
|
197
|
-
|
|
198
|
-
# Check if it's a CUDA graph - add annotation
|
|
199
230
|
if any("torch/cuda/graphs" in f for f in full_stack):
|
|
200
|
-
# Shorten if too long
|
|
201
231
|
if len(summary) > 45:
|
|
202
|
-
|
|
203
|
-
summary = "vllm/..." + parts
|
|
232
|
+
summary = "vllm/..." + summary.split("/")[-1]
|
|
204
233
|
summary = f"{summary} [CUDA graph]"
|
|
205
234
|
elif len(summary) > 53:
|
|
206
|
-
|
|
207
|
-
summary = "vllm/..." + parts
|
|
235
|
+
summary = "vllm/..." + summary.split("/")[-1]
|
|
208
236
|
else:
|
|
209
|
-
# Fallback to innermost function
|
|
210
237
|
summary = leaf_name
|
|
211
|
-
|
|
238
|
+
|
|
212
239
|
return summary, full_stack
|
|
213
240
|
|
|
214
241
|
|
|
215
242
|
def load_trace(
    file_path: str | Path,
    include_stacks: bool = True,
    on_progress: ProgressCallback | None = None,
) -> tuple[str, str, dict[str, Any], pd.DataFrame, dict[tuple[str, str], set[str]], dict[int, int]]:
    """Load trace and return platform info, device properties, kernels, patterns, and layer mapping.

    This is the tuple-returning form of :func:`load_trace_full`. It shares that
    function's implementation (JSON parsing, phase detection, layer mapping,
    stack resolution and progress reporting) instead of keeping a second copy
    that could drift out of sync.

    Args:
        file_path: Path to the trace JSON file
        include_stacks: Whether to resolve Python call stacks (slower but more info)
        on_progress: Optional callback for progress updates: (stage_name, progress_fraction)

    Returns:
        Tuple of (platform, gpu_name, device_props, kernel_df, kernel_patterns, layer_mapping)
    """
    loaded = load_trace_full(file_path, include_stacks=include_stacks, on_progress=on_progress)
    return (
        loaded.platform,
        loaded.gpu_name,
        loaded.device_props,
        loaded.df,
        loaded.patterns,
        loaded.layers,
    )
|
|
368
|
+
|
|
369
|
+
|
|
370
|
+
def load_trace_full(
    file_path: str | Path,
    include_stacks: bool = True,
    on_progress: ProgressCallback | None = None,
) -> LoadedTrace:
    """Load trace once with all data needed by downstream analysis functions.

    Pipeline:

    1. Read the file and parse it with orjson.
    2. Derive platform and device info.
    3. Make a single pass over the events.
    4. Build the layer mapping: architecture-aware segmentation, falling back
       to the correlation-group heuristic.
    5. Classify each kernel into a DataFrame row.

    Args:
        file_path: Path to the trace JSON file
        include_stacks: Whether to resolve Python call stacks
        on_progress: Optional callback for progress updates, called as
            ``(stage_name, fraction)`` with fractions from 0.0 to 1.0

    Returns:
        LoadedTrace with all trace data
    """
    def _progress(stage: str, pct: float) -> None:
        if on_progress:
            on_progress(stage, pct)

    def _event_args(ev: dict[str, Any]) -> dict[str, Any]:
        # Trace events may carry "args": null. Treat that like a missing dict
        # (as the single-pass processor does) instead of raising AttributeError.
        return ev.get("args") or {}

    _progress("Reading file", 0.0)
    with open(file_path, "rb") as f:
        raw = f.read()

    _progress("Parsing JSON", 0.1)
    trace = orjson.loads(raw)
    del raw  # drop the raw bytes before building the derived structures
    all_events = trace.get("traceEvents", [])

    # An empty or null deviceProperties list falls back to "no properties known".
    props = (trace.get("deviceProperties") or [{}])[0]
    is_amd = trace.get("roctracer_version") or props.get("warpSize") == 64
    platform = "AMD" if is_amd else "NVIDIA"
    gpu_name = props.get("name", "MI300X" if is_amd else "Unknown GPU")

    device_props = {
        "name": gpu_name,
        "compute_capability": f"{props.get('computeMajor', 0)}.{props.get('computeMinor', 0)}",
        "total_memory_gb": props.get("totalGlobalMem", 0) / (1024**3),
        "sm_count": props.get("numSms", 0),
        "warp_size": props.get("warpSize", 32),
        "max_threads_per_block": props.get("maxThreadsPerBlock", 0),
        "shared_mem_per_block_kb": props.get("sharedMemPerBlock", 0) / 1024,
    }

    _progress("Processing events", 0.4)
    pass_result = _process_events_single_pass(all_events, include_stacks=include_stacks)

    _progress("Detecting architecture", 0.6)
    kernel_names = [ev.get("name", "") for ev in pass_result.kernel_events]
    architecture, _ = detect_architecture(kernel_names)

    layer_kernels_dict, _layer_warnings = segment_layers_by_architecture(
        pass_result.kernel_events,
        architecture,
    )

    # Convert layer_num -> kernels into correlation_id -> layer_num. Use the
    # correlation-group heuristic when architecture segmentation finds nothing.
    if layer_kernels_dict:
        layer_mapping: dict[int, int] = {}
        for layer_num, kernels in layer_kernels_dict.items():
            for kernel in kernels:
                corr_id = _event_args(kernel).get("correlation")
                if corr_id is not None:
                    layer_mapping[corr_id] = layer_num
    else:
        layer_mapping = _build_layer_mapping(pass_result.correlation_groups)

    kernel_data = []
    kernel_patterns: dict[tuple[str, str], set[str]] = defaultdict(set)

    # Parallel arrays over phases sorted by start, for bisect lookups.
    sorted_phases = sorted(pass_result.phases, key=lambda p: p["ts_start"])
    phase_starts = [p["ts_start"] for p in sorted_phases]
    phase_types = [p["type"] for p in sorted_phases]
    phase_ends = [p["ts_end"] for p in sorted_phases]

    def _get_phase_for_timestamp(ts: int) -> str:
        """Get phase for a timestamp using binary search; "decode" when outside every phase."""
        if not phase_starts:
            return "decode"
        idx = bisect.bisect_right(phase_starts, ts) - 1
        if idx >= 0 and phase_starts[idx] <= ts <= phase_ends[idx]:
            return phase_types[idx]
        return "decode"

    cpu_op_mapping = pass_result.cpu_op_mapping
    cpu_op_cache: dict[str, str | None] = {}

    _progress("Classifying kernels", 0.7)
    for ev in pass_result.kernel_events:
        name = sys.intern(ev["name"])  # many launches share a name
        dur, ts = ev.get("dur", 0), ev["ts"]
        args = _event_args(ev)
        corr_id = args.get("correlation")
        ext_id = args.get("External id")

        phase = _get_phase_for_timestamp(ts)

        op, pattern = classify(name, platform)
        kernel_patterns[(op.value, phase)].add(pattern)

        layer = layer_mapping.get(corr_id) if corr_id is not None else None
        cpu_op = cpu_op_mapping.get(ext_id) if ext_id is not None else None
        python_stack: list[str] = []

        if cpu_op is None and include_stacks:
            # Stack resolution is cached per kernel NAME. The first launch of a
            # kernel pays for the lookup and gets the full python_stack; later
            # launches reuse the cached summary and keep an empty python_stack.
            if name in cpu_op_cache:
                cpu_op = cpu_op_cache[name]
            else:
                cpu_op, python_stack = _get_python_stack_full(
                    ts, pass_result.python_intervals, pass_result.python_by_id
                )
                cpu_op_cache[name] = cpu_op

        kernel_data.append({
            "name": name,
            "dur_us": dur,
            "phase": phase,
            "op": op.value,
            "pattern": pattern,
            "layer": layer,
            "correlation": corr_id,
            "cpu_op": cpu_op,
            "python_stack": python_stack,
        })

    _progress("Building DataFrame", 0.95)
    df = pd.DataFrame(kernel_data)

    kernel_events = pass_result.kernel_events
    correlation_groups: dict[int, list[dict[str, Any]]] = defaultdict(list)
    for ev in kernel_events:
        corr_id = _event_args(ev).get("correlation")
        if corr_id is not None:
            correlation_groups[corr_id].append(ev)

    _progress("Complete", 1.0)

    return LoadedTrace(
        platform=platform,
        gpu_name=gpu_name,
        device_props=device_props,
        df=df,
        patterns=dict(kernel_patterns),
        layers=layer_mapping,
        kernel_events=kernel_events,
        all_events=all_events,
        correlation_groups=dict(correlation_groups),
        phases=pass_result.phases,
    )
|
|
516
|
+
|
|
315
517
|
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
518
|
+
@dataclass
class StreamingMetadata:
    """Cheap trace summary that can be reported before the full load finishes."""

    platform: str  # platform label, e.g. "AMD" or "NVIDIA"
    gpu_name: str  # device name taken from the trace's device properties
    device_props: dict[str, Any]  # selected device properties (name, compute capability, ...)
    file_size_mb: float  # on-disk size of the trace file, in MiB
|
|
321
525
|
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
526
|
+
|
|
527
|
+
def _extract_metadata_fast(file_path: Path) -> StreamingMetadata:
    """Extract trace metadata cheaply, ideally without parsing the whole file.

    With ijson available, only the JSON prefix up to the start of
    ``traceEvents`` is streamed. Without it, the whole file is parsed with
    orjson. Both paths apply the same rules as the full loader, so the early
    metadata agrees with what the full load later reports:

    * Only the first ``deviceProperties`` entry is used.
    * The platform is AMD when ``roctracer_version`` is truthy or the warp size
      is 64. When neither signal nor a warp size is available, the GPU name is
      checked for "MI"/"Instinct".
    * ``device_props`` always carries the same keys, with defaults for missing
      values.

    Note: on the ijson path, anything that appears after ``traceEvents`` in
    the file (e.g. a trailing ``roctracer_version``) is not seen.

    Args:
        file_path: Path to the trace JSON file.

    Returns:
        StreamingMetadata describing the first device in the trace.
    """
    item_prefix = "deviceProperties.item."
    scalar_events = {"null", "boolean", "integer", "double", "number", "string"}

    def _read_header_streaming() -> tuple[dict[str, Any], bool]:
        """Stream the file prefix; return (first device's scalar props, roctracer flag)."""
        first_props: dict[str, Any] = {}
        has_roctracer = False
        first_device_done = False
        with open(file_path, "rb") as f:
            for prefix, event, value in ijson.parse(f):
                if prefix == "traceEvents":
                    # Reached the (large) event list; metadata precedes it in typical traces.
                    break
                if prefix == "roctracer_version":
                    if event in scalar_events:
                        has_roctracer = bool(value)
                elif prefix == "deviceProperties.item" and event == "end_map":
                    # Ignore later devices, matching deviceProperties[0] in the full loader.
                    first_device_done = True
                elif (
                    not first_device_done
                    and event in scalar_events
                    and prefix.startswith(item_prefix)
                    and "." not in prefix[len(item_prefix):]
                ):
                    first_props[prefix[len(item_prefix):]] = value
        return first_props, has_roctracer

    def _read_header_full() -> tuple[dict[str, Any], bool]:
        """Parse the whole file (slower, no ijson needed); same return shape."""
        with open(file_path, "rb") as f:
            trace = orjson.loads(f.read())
        devices = trace.get("deviceProperties") or [{}]
        return devices[0], bool(trace.get("roctracer_version"))

    file_size_mb = file_path.stat().st_size / (1024 * 1024)
    props, has_roctracer = _read_header_full() if ijson is None else _read_header_streaming()

    warp_size = props.get("warpSize")
    if has_roctracer or warp_size == 64:
        is_amd = True
    elif warp_size is None:
        # No hard signal available: fall back to the device name.
        name_hint = str(props.get("name", ""))
        is_amd = "MI" in name_hint or "Instinct" in name_hint
    else:
        is_amd = False
    platform = "AMD" if is_amd else "NVIDIA"
    gpu_name = props.get("name", "MI300X" if is_amd else "Unknown GPU")

    device_props: dict[str, Any] = {
        "name": gpu_name,
        "compute_capability": f"{props.get('computeMajor', 0)}.{props.get('computeMinor', 0)}",
        # Raw components kept for callers of the previous streaming output.
        "compute_major": props.get("computeMajor", 0),
        "compute_minor": props.get("computeMinor", 0),
        "total_memory_gb": props.get("totalGlobalMem", 0) / (1024**3),
        "sm_count": props.get("numSms", 0),
        "warp_size": warp_size if warp_size is not None else (64 if is_amd else 32),
        "max_threads_per_block": props.get("maxThreadsPerBlock", 0),
        "shared_mem_per_block_kb": props.get("sharedMemPerBlock", 0) / 1024,
    }

    return StreamingMetadata(
        platform=platform,
        gpu_name=gpu_name,
        device_props=device_props,
        file_size_mb=file_size_mb,
    )
|
|
599
|
+
|
|
335
600
|
|
|
336
|
-
|
|
601
|
+
def load_trace_streaming(
    file_path: str | Path,
    include_stacks: bool = True,
    on_metadata: Callable[[StreamingMetadata], None] | None = None,
    on_progress: ProgressCallback | None = None,
) -> tuple[str, str, dict[str, Any], pd.DataFrame, dict[tuple[str, str], set[str]], dict[int, int]]:
    """Load a trace, reporting cheap metadata first and then doing the full load.

    The load runs in two steps:

    1. Read just the device/platform header via ``_extract_metadata_fast`` and
       hand it to ``on_metadata``, so a UI can show GPU/platform info right away.
    2. Run the regular :func:`load_trace` pipeline, forwarding ``on_progress``.

    Args:
        file_path: Path to the trace JSON file
        include_stacks: Whether to resolve Python call stacks
        on_metadata: Callback receiving the early StreamingMetadata
        on_progress: Callback for progress updates during processing

    Returns:
        The same tuple as :func:`load_trace`.
    """
    path = Path(file_path)

    # Step 1: cheap header read, reported before the heavy work starts.
    header = _extract_metadata_fast(path)
    if on_metadata:
        on_metadata(header)

    # Step 2: full load with progress reporting.
    return load_trace(path, include_stacks=include_stacks, on_progress=on_progress)
|