wafer-core 0.1.25__py3-none-any.whl → 0.1.27__py3-none-any.whl
This diff compares publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.
- wafer_core/lib/trace_compare/PERFORMANCE.md +148 -0
- wafer_core/lib/trace_compare/__init__.py +45 -0
- wafer_core/lib/trace_compare/aligner.py +369 -0
- wafer_core/lib/trace_compare/analyzer.py +729 -0
- wafer_core/lib/trace_compare/api.py +225 -0
- wafer_core/lib/trace_compare/architecture.py +77 -0
- wafer_core/lib/trace_compare/classifier.py +486 -0
- wafer_core/lib/trace_compare/formatter.py +951 -0
- wafer_core/lib/trace_compare/fusion_analyzer.py +356 -0
- wafer_core/lib/trace_compare/kernel_registry.yaml +349 -0
- wafer_core/lib/trace_compare/layer_segmentation.py +114 -0
- wafer_core/lib/trace_compare/loader.py +635 -0
- wafer_core/lib/trace_compare/same_kernel_analyzer.py +119 -0
- wafer_core/lib/trace_compare/warnings.py +99 -0
- wafer_core/problem_config.py +3 -3
- wafer_core/rollouts/agent_presets/rlm_01_01.py +2 -2
- wafer_core/rollouts/dtypes.py +18 -3
- wafer_core/rollouts/providers/anthropic.py +35 -3
- wafer_core/utils/kernel_utils/defense.py +10 -0
- wafer_core/utils/kernel_utils/targets/config.py +10 -0
- {wafer_core-0.1.25.dist-info → wafer_core-0.1.27.dist-info}/METADATA +3 -1
- {wafer_core-0.1.25.dist-info → wafer_core-0.1.27.dist-info}/RECORD +23 -9
- {wafer_core-0.1.25.dist-info → wafer_core-0.1.27.dist-info}/WHEEL +0 -0
The hunk below is the new module wafer_core/lib/trace_compare/loader.py (+635 lines):

@@ -0,0 +1,635 @@

"""Trace loading and parsing logic.

Loads JSON trace files from AMD/NVIDIA profilers and extracts kernel execution data,
Python call stacks, CPU operator mappings, and layer correlations.
"""

import bisect
import sys
from collections import defaultdict
from collections.abc import Callable
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

try:
    import ijson
except ImportError:
    ijson = None

import orjson
import pandas as pd

from .architecture import detect_architecture
from .classifier import classify
from .layer_segmentation import segment_layers_by_architecture

ProgressCallback = Callable[[str, float], None]


@dataclass
class SinglePassResult:
    """Collected data from single-pass event processing."""
    cpu_op_mapping: dict[int, str] = field(default_factory=dict)
    python_intervals: list[tuple[int, int, int, int | None, str]] = field(default_factory=list)
    # Raw python events for lazy python_by_id construction (built on-demand in _get_python_stack_full)
    python_events_raw: list[dict[str, Any]] = field(default_factory=list)
    phases: list[dict[str, Any]] = field(default_factory=list)
    correlation_groups: dict[int, dict[str, Any]] = field(default_factory=lambda: defaultdict(
        lambda: {"count": 0, "has_attention": False, "has_ffn": False}
    ))
    kernel_events: list[dict[str, Any]] = field(default_factory=list)
    # Lazily built when needed for stack resolution
    _python_by_id: dict[int, dict[str, Any]] | None = field(default=None)

    @property
    def python_by_id(self) -> dict[int, dict[str, Any]]:
        """Lazily build python_by_id from raw events on first access."""
        if self._python_by_id is None:
            self._python_by_id = {}
            for ev in self.python_events_raw:
                args = ev.get("args")
                py_id = args.get("Python id") if args else None
                if py_id is not None:
                    self._python_by_id[py_id] = {
                        "name": ev["name"],
                        "parent_id": args.get("Python parent id") if args else None,
                    }
        return self._python_by_id


@dataclass
class LoadedTrace:
    """Complete trace data loaded once and reused by all analysis functions."""
    platform: str
    gpu_name: str
    device_props: dict[str, Any]
    df: pd.DataFrame
    patterns: dict[tuple[str, str], set[str]]
    layers: dict[int, int]
    # For fusion/warnings (kept from raw JSON)
    kernel_events: list[dict[str, Any]]
    all_events: list[dict[str, Any]]
    correlation_groups: dict[int, list[dict[str, Any]]]
    phases: list[dict[str, Any]]  # Phase annotations for alignment


def _process_events_single_pass(
    events: list[dict[str, Any]],
    include_stacks: bool = True,
) -> SinglePassResult:
    """Process all events in a single iteration.

    Optimizations applied:
    - Cache list.append methods for 2-3x speedup on hot paths
    - Store raw python events, build python_by_id lazily (only ~48 lookups due to caching)
    - Local variable caching for frequently accessed attributes

    Args:
        events: List of trace events
        include_stacks: Whether to collect Python stack info (expensive operation).
            When False, skips python_function processing entirely.
    """
    result = SinglePassResult()
    correlation_groups: dict[int, dict[str, Any]] = defaultdict(
        lambda: {"count": 0, "has_attention": False, "has_ffn": False}
    )

    # Cache list.append methods for faster appending (measured 2-3x speedup)
    kernel_append = result.kernel_events.append
    python_interval_append = result.python_intervals.append
    python_raw_append = result.python_events_raw.append
    phase_append = result.phases.append
    cpu_op_mapping = result.cpu_op_mapping

    for ev in events:
        cat = ev.get("cat")

        if cat == "kernel":
            args = ev.get("args")
            corr_id = args.get("correlation") if args else None
            if corr_id is not None:
                kernel_name = ev.get("name", "").lower()
                grp = correlation_groups[corr_id]
                grp["count"] += 1
                if "attention" in kernel_name or "fmha" in kernel_name:
                    grp["has_attention"] = True
                if "cijk_" in kernel_name or "nvjet" in kernel_name or "wvsplitk" in kernel_name or "gemm" in kernel_name:
                    grp["has_ffn"] = True
            kernel_append(ev)

        elif cat == "cpu_op":
            args = ev.get("args")
            ext_id = args.get("External id") if args else None
            if ext_id is not None:
                cpu_op_mapping[ext_id] = ev.get("name", "")

        elif cat == "python_function" and include_stacks:
            # Store raw event for lazy python_by_id construction
            python_raw_append(ev)
            # Build interval tuple for binary search
            args = ev.get("args")
            py_id = args.get("Python id") if args else None
            ts = ev["ts"]
            dur = ev.get("dur", 0)
            python_interval_append((ts, ts + dur, dur, py_id, ev["name"]))

        elif cat == "user_annotation":
            name = ev.get("name", "")
            if name.startswith("execute_context"):
                tokens = 0
                parts = name.split("_")
                for i, p in enumerate(parts):
                    if i > 0 and parts[i-1] == "context" and "(" in p and ")" in p:
                        try:
                            tokens = int(p.split("(")[1].split(")")[0])
                            break
                        except Exception:
                            pass
                is_prefill = "generation_0" in name and tokens > 0
                phase_append({
                    "type": "prefill" if is_prefill else "decode",
                    "ts_start": ev["ts"],
                    "ts_end": ev["ts"] + ev["dur"],
                })

    if include_stacks and result.python_intervals:
        result.python_intervals.sort()

    result.correlation_groups = dict(correlation_groups)

    return result
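As an editorial aside (not part of the package): a minimal sketch of what this single-pass helper expects, assuming wafer-core 0.1.27 is installed and calling the private function directly. The events and kernel name are made up; only the fields the loop reads (`cat`, `name`, `ts`, `dur`, and the `correlation` / `External id` args) are filled in.

```python
from wafer_core.lib.trace_compare.loader import _process_events_single_pass

# Made-up Chrome-trace-style events (illustrative only).
events = [
    {"cat": "cpu_op", "name": "aten::mm", "args": {"External id": 7}},
    {
        "cat": "kernel",
        "name": "Cijk_gemm_example_kernel",  # made-up name; matches the "cijk_"/"gemm" checks above
        "ts": 1_000,
        "dur": 42,
        "args": {"correlation": 7, "External id": 7},
    },
]

result = _process_events_single_pass(events, include_stacks=False)
print(len(result.kernel_events))     # 1
print(result.cpu_op_mapping)         # {7: 'aten::mm'}
print(result.correlation_groups[7])  # {'count': 1, 'has_attention': False, 'has_ffn': True}
```

The loader.py hunk continues below.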
def _build_layer_mapping(correlation_groups: dict[int, dict[str, Any]]) -> dict[int, int]:
    """Build layer mapping from correlation groups."""
    correlation_to_layer = {}
    layer_num = 0

    for corr_id in sorted(correlation_groups.keys()):
        group = correlation_groups[corr_id]

        is_layer = (
            group["count"] >= 70 and group["has_attention"] and group["has_ffn"]
        )

        if is_layer:
            correlation_to_layer[corr_id] = layer_num
            layer_num += 1

    return correlation_to_layer
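Another illustrative aside: the heuristic above treats a correlation group as a decoder layer only when it has at least 70 kernels plus both an attention-like and a GEMM-like kernel. A synthetic sketch (group contents made up):

```python
from wafer_core.lib.trace_compare.loader import _build_layer_mapping

groups = {
    101: {"count": 80, "has_attention": True,  "has_ffn": True},   # qualifies -> layer 0
    102: {"count": 12, "has_attention": False, "has_ffn": True},   # too few kernels -> skipped
    103: {"count": 75, "has_attention": True,  "has_ffn": True},   # qualifies -> layer 1
}
print(_build_layer_mapping(groups))  # {101: 0, 103: 1}
```

The hunk continues below.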
def _get_python_stack_full(
    timestamp: int,
    python_intervals: list[tuple[int, int, int, int | None, str]],
    python_by_id: dict[int, dict[str, Any]],
) -> tuple[str | None, list[str]]:
    """Get full Python call stack for a kernel launch.

    Computes the chain on-demand by walking parent pointers.
    This is fast because we only call this ~48 times due to cpu_op caching.
    """
    idx = bisect.bisect_right(
        python_intervals, (timestamp, float("inf"), float("inf"), None, "")
    )

    active_funcs = []
    for i in range(idx - 1, max(0, idx - 1000), -1):
        ts_start, ts_end, duration, py_id, name = python_intervals[i]
        if ts_start <= timestamp <= ts_end:
            active_funcs.append((duration, py_id, name))
        if ts_end < timestamp - 1000000:
            break

    if not active_funcs:
        return None, []

    active_funcs.sort()
    leaf_duration, leaf_id, leaf_name = active_funcs[0]

    full_stack = []
    current_id = leaf_id
    visited: set[int] = set()

    while current_id is not None and current_id not in visited and current_id in python_by_id:
        func = python_by_id[current_id]
        full_stack.append(func["name"])
        visited.add(current_id)
        current_id = func["parent_id"]
        if len(full_stack) >= 50:
            break

    full_stack.reverse()

    summary = None
    vllm_funcs = [f for f in full_stack if any(x in f.lower() for x in ["vllm/", "model", "<eval_with_key>"])]

    if vllm_funcs:
        summary = vllm_funcs[-1]
        if any("torch/cuda/graphs" in f for f in full_stack):
            if len(summary) > 45:
                summary = "vllm/..." + summary.split("/")[-1]
            summary = f"{summary} [CUDA graph]"
        elif len(summary) > 53:
            summary = "vllm/..." + summary.split("/")[-1]
    else:
        summary = leaf_name

    return summary, full_stack
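An illustrative aside on the interval lookup above: intervals are (ts_start, ts_end, dur, python_id, name) tuples sorted by start time, and python_by_id maps frame ids to name/parent pairs, as built in the single-pass helper. The frame names and ids below are made up, and the call targets the private helper directly.

```python
from wafer_core.lib.trace_compare.loader import _get_python_stack_full

intervals = [
    (0,   500, 500, 1, "vllm/worker.py(120): execute_model"),
    (100, 400, 300, 2, "vllm/model.py(88): forward"),
]
by_id = {
    1: {"name": "vllm/worker.py(120): execute_model", "parent_id": None},
    2: {"name": "vllm/model.py(88): forward", "parent_id": 1},
}

summary, stack = _get_python_stack_full(250, intervals, by_id)
print(summary)  # 'vllm/model.py(88): forward'
print(stack)    # ['vllm/worker.py(120): execute_model', 'vllm/model.py(88): forward']
```

The hunk continues below.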
def load_trace(
    file_path: str | Path,
    include_stacks: bool = True,
    on_progress: ProgressCallback | None = None,
) -> tuple[str, str, dict[str, Any], pd.DataFrame, dict[tuple[str, str], set[str]], dict[int, int]]:
    """Load trace and return platform info, device properties, kernels, patterns, and layer mapping.

    Args:
        file_path: Path to the trace JSON file
        include_stacks: Whether to resolve Python call stacks (slower but more info)
        on_progress: Optional callback for progress updates: (stage_name, progress_fraction)
    """
    def _progress(stage: str, pct: float) -> None:
        if on_progress:
            on_progress(stage, pct)

    _progress("Reading file", 0.0)
    with open(file_path, "rb") as f:
        raw = f.read()

    _progress("Parsing JSON", 0.1)
    trace = orjson.loads(raw)

    props = trace.get("deviceProperties", [{}])[0]
    is_amd = trace.get("roctracer_version") or props.get("warpSize") == 64
    platform = "AMD" if is_amd else "NVIDIA"
    gpu_name = props.get("name", "MI300X" if is_amd else "Unknown GPU")

    device_props = {
        "name": gpu_name,
        "compute_capability": f"{props.get('computeMajor', 0)}.{props.get('computeMinor', 0)}",
        "total_memory_gb": props.get("totalGlobalMem", 0) / (1024**3),
        "sm_count": props.get("numSms", 0),
        "warp_size": props.get("warpSize", 32),
        "max_threads_per_block": props.get("maxThreadsPerBlock", 0),
        "shared_mem_per_block_kb": props.get("sharedMemPerBlock", 0) / 1024,
    }

    _progress("Processing events", 0.4)
    events = trace.get("traceEvents", [])
    pass_result = _process_events_single_pass(events, include_stacks=include_stacks)

    _progress("Detecting architecture", 0.6)
    kernel_names = [ev.get("name", "") for ev in pass_result.kernel_events]
    architecture, _ = detect_architecture(kernel_names)

    layer_kernels_dict, layer_warnings = segment_layers_by_architecture(
        pass_result.kernel_events,
        architecture,
    )

    # Convert layer_kernels_dict (layer_num -> kernels) to correlation_id -> layer_num mapping
    # Fall back to old method if architecture-based segmentation fails
    if layer_kernels_dict:
        layer_mapping: dict[int, int] = {}
        for layer_num, kernels in layer_kernels_dict.items():
            for kernel in kernels:
                corr_id = kernel.get("args", {}).get("correlation")
                if corr_id is not None:
                    layer_mapping[corr_id] = layer_num
    else:
        # Fallback to correlation-based method if architecture segmentation failed
        layer_mapping = _build_layer_mapping(pass_result.correlation_groups)

    kernel_data = []
    kernel_patterns: dict[tuple[str, str], set[str]] = defaultdict(set)
    sorted_phases = sorted(pass_result.phases, key=lambda p: p["ts_start"])

    phase_starts = [p["ts_start"] for p in sorted_phases]
    phase_types = [p["type"] for p in sorted_phases]
    phase_ends = [p["ts_end"] for p in sorted_phases]

    def _get_phase_for_timestamp(ts: int) -> str:
        """Get phase for a timestamp using binary search. O(log n)."""
        if not phase_starts:
            return "decode"
        idx = bisect.bisect_right(phase_starts, ts) - 1
        if idx >= 0 and phase_starts[idx] <= ts <= phase_ends[idx]:
            return phase_types[idx]
        return "decode"

    cpu_op_cache: dict[str, str | None] = {}

    _progress("Classifying kernels", 0.7)
    for ev in pass_result.kernel_events:
        name_raw = ev["name"]
        name = sys.intern(name_raw)
        dur, ts = ev.get("dur", 0), ev["ts"]
        corr_id = ev.get("args", {}).get("correlation")
        ext_id = ev.get("args", {}).get("External id")

        phase = _get_phase_for_timestamp(ts)

        op, pattern = classify(name, platform)
        kernel_patterns[(op.value, phase)].add(pattern)

        layer = layer_mapping.get(corr_id) if corr_id is not None else None
        cpu_op = pass_result.cpu_op_mapping.get(ext_id) if ext_id is not None else None
        python_stack: list[str] = []

        if cpu_op is None and include_stacks:
            if name in cpu_op_cache:
                cpu_op = cpu_op_cache[name]
            else:
                cpu_op, python_stack = _get_python_stack_full(
                    ts, pass_result.python_intervals, pass_result.python_by_id
                )
                cpu_op_cache[name] = cpu_op

        kernel_data.append({
            "name": name,
            "dur_us": dur,
            "phase": phase,
            "op": op.value,
            "pattern": pattern,
            "layer": layer,
            "correlation": corr_id,
            "cpu_op": cpu_op,
            "python_stack": python_stack,
        })

    _progress("Building DataFrame", 0.95)
    df = pd.DataFrame(kernel_data)
    _progress("Complete", 1.0)

    return platform, gpu_name, device_props, df, dict(kernel_patterns), layer_mapping
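An illustrative usage sketch for the tuple-returning loader above, assuming the package is installed; the trace path is hypothetical.

```python
from wafer_core.lib.trace_compare.loader import load_trace

def show_progress(stage: str, pct: float) -> None:
    print(f"[{pct:4.0%}] {stage}")

platform, gpu_name, device_props, df, patterns, layers = load_trace(
    "traces/llama70b_mi300x.json",  # hypothetical path
    include_stacks=False,           # skip Python-stack resolution for speed
    on_progress=show_progress,
)

print(platform, gpu_name, device_props["sm_count"])
# Total kernel time by (op, phase), using the DataFrame columns built above.
print(df.groupby(["op", "phase"])["dur_us"].sum().sort_values(ascending=False).head())
```

The hunk continues below.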
def load_trace_full(
    file_path: str | Path,
    include_stacks: bool = True,
    on_progress: ProgressCallback | None = None,
) -> LoadedTrace:
    """Load trace once with all data needed by downstream analysis functions.

    Args:
        file_path: Path to the trace JSON file
        include_stacks: Whether to resolve Python call stacks
        on_progress: Optional callback for progress updates

    Returns:
        LoadedTrace with all trace data
    """
    def _progress(stage: str, pct: float) -> None:
        if on_progress:
            on_progress(stage, pct)

    _progress("Reading file", 0.0)
    with open(file_path, "rb") as f:
        raw = f.read()

    _progress("Parsing JSON", 0.1)
    trace = orjson.loads(raw)
    all_events = trace.get("traceEvents", [])

    props = trace.get("deviceProperties", [{}])[0]
    is_amd = trace.get("roctracer_version") or props.get("warpSize") == 64
    platform = "AMD" if is_amd else "NVIDIA"
    gpu_name = props.get("name", "MI300X" if is_amd else "Unknown GPU")

    device_props = {
        "name": gpu_name,
        "compute_capability": f"{props.get('computeMajor', 0)}.{props.get('computeMinor', 0)}",
        "total_memory_gb": props.get("totalGlobalMem", 0) / (1024**3),
        "sm_count": props.get("numSms", 0),
        "warp_size": props.get("warpSize", 32),
        "max_threads_per_block": props.get("maxThreadsPerBlock", 0),
        "shared_mem_per_block_kb": props.get("sharedMemPerBlock", 0) / 1024,
    }

    _progress("Processing events", 0.4)
    pass_result = _process_events_single_pass(all_events, include_stacks=include_stacks)

    _progress("Detecting architecture", 0.6)
    kernel_names = [ev.get("name", "") for ev in pass_result.kernel_events]
    architecture, _ = detect_architecture(kernel_names)

    layer_kernels_dict, layer_warnings = segment_layers_by_architecture(
        pass_result.kernel_events,
        architecture,
    )

    # Convert layer_kernels_dict (layer_num -> kernels) to correlation_id -> layer_num mapping
    if layer_kernels_dict:
        layer_mapping: dict[int, int] = {}
        for layer_num, kernels in layer_kernels_dict.items():
            for kernel in kernels:
                corr_id = kernel.get("args", {}).get("correlation")
                if corr_id is not None:
                    layer_mapping[corr_id] = layer_num
    else:
        layer_mapping = _build_layer_mapping(pass_result.correlation_groups)

    kernel_data = []
    kernel_patterns: dict[tuple[str, str], set[str]] = defaultdict(set)
    sorted_phases = sorted(pass_result.phases, key=lambda p: p["ts_start"])

    phase_starts = [p["ts_start"] for p in sorted_phases]
    phase_types = [p["type"] for p in sorted_phases]
    phase_ends = [p["ts_end"] for p in sorted_phases]

    def _get_phase_for_timestamp(ts: int) -> str:
        """Get phase for a timestamp using binary search. O(log n)."""
        if not phase_starts:
            return "decode"
        idx = bisect.bisect_right(phase_starts, ts) - 1
        if idx >= 0 and phase_starts[idx] <= ts <= phase_ends[idx]:
            return phase_types[idx]
        return "decode"

    cpu_op_cache: dict[str, str | None] = {}

    _progress("Classifying kernels", 0.7)
    for ev in pass_result.kernel_events:
        name_raw = ev["name"]
        name = sys.intern(name_raw)
        dur, ts = ev.get("dur", 0), ev["ts"]
        corr_id = ev.get("args", {}).get("correlation")
        ext_id = ev.get("args", {}).get("External id")

        phase = _get_phase_for_timestamp(ts)

        op, pattern = classify(name, platform)
        kernel_patterns[(op.value, phase)].add(pattern)

        layer = layer_mapping.get(corr_id) if corr_id is not None else None
        cpu_op = pass_result.cpu_op_mapping.get(ext_id) if ext_id is not None else None
        python_stack: list[str] = []

        if cpu_op is None and include_stacks:
            if name in cpu_op_cache:
                cpu_op = cpu_op_cache[name]
            else:
                cpu_op, python_stack = _get_python_stack_full(
                    ts, pass_result.python_intervals, pass_result.python_by_id
                )
                cpu_op_cache[name] = cpu_op

        kernel_data.append({
            "name": name,
            "dur_us": dur,
            "phase": phase,
            "op": op.value,
            "pattern": pattern,
            "layer": layer,
            "correlation": corr_id,
            "cpu_op": cpu_op,
            "python_stack": python_stack,
        })

    _progress("Building DataFrame", 0.95)
    df = pd.DataFrame(kernel_data)

    kernel_events = pass_result.kernel_events
    correlation_groups: dict[int, list[dict[str, Any]]] = defaultdict(list)
    for ev in kernel_events:
        corr_id = ev.get("args", {}).get("correlation")
        if corr_id is not None:
            correlation_groups[corr_id].append(ev)

    _progress("Complete", 1.0)

    return LoadedTrace(
        platform=platform,
        gpu_name=gpu_name,
        device_props=device_props,
        df=df,
        patterns=dict(kernel_patterns),
        layers=layer_mapping,
        kernel_events=kernel_events,
        all_events=all_events,
        correlation_groups=dict(correlation_groups),
        phases=pass_result.phases,
    )
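load_trace_full performs the same processing but returns a single LoadedTrace bundle, keeping the raw kernel events alongside the classified DataFrame. A hedged usage sketch (path hypothetical, package assumed installed):

```python
from wafer_core.lib.trace_compare.loader import load_trace_full

trace = load_trace_full("traces/decode_run.json", include_stacks=True)  # hypothetical path

print(trace.platform, trace.gpu_name)

# Per-layer kernel time from the classified DataFrame.
per_layer = trace.df.dropna(subset=["layer"]).groupby("layer")["dur_us"].sum()
print(per_layer.head())

# Raw events remain available for the fusion/warning passes noted in LoadedTrace.
print(len(trace.kernel_events), "kernel events,", len(trace.all_events), "events total")
```

The hunk continues below.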
@dataclass
class StreamingMetadata:
    """Early metadata available before full trace processing."""
    platform: str
    gpu_name: str
    device_props: dict[str, Any]
    file_size_mb: float


def _extract_metadata_fast(file_path: Path) -> StreamingMetadata:
    """Extract trace metadata instantly using streaming parser (~2ms).

    Uses ijson to read only the deviceProperties section without
    parsing the entire file. Falls back to full parse if ijson unavailable.
    """
    file_size_mb = file_path.stat().st_size / (1024 * 1024)
    platform = "Unknown"
    gpu_name = "Unknown GPU"
    device_props: dict[str, Any] = {}

    if ijson is None:
        # Fallback: parse full file (slower but works)
        with open(file_path, "rb") as f:
            trace = orjson.loads(f.read())
        props = trace.get("deviceProperties", [{}])[0]
        is_amd = trace.get("roctracer_version") or props.get("warpSize") == 64
        platform = "AMD" if is_amd else "NVIDIA"
        gpu_name = props.get("name", "MI300X" if is_amd else "Unknown GPU")
        device_props = {
            "name": gpu_name,
            "compute_capability": f"{props.get('computeMajor', 0)}.{props.get('computeMinor', 0)}",
            "total_memory_gb": props.get("totalGlobalMem", 0) / (1024**3),
            "sm_count": props.get("numSms", 0),
            "warp_size": props.get("warpSize", 32),
            "max_threads_per_block": props.get("maxThreadsPerBlock", 0),
            "shared_mem_per_block_kb": props.get("sharedMemPerBlock", 0) / 1024,
        }
        return StreamingMetadata(
            platform=platform,
            gpu_name=gpu_name,
            device_props=device_props,
            file_size_mb=file_size_mb,
        )

    with open(file_path, "rb") as f:
        parser = ijson.parse(f)
        for prefix, event, value in parser:
            if prefix == "deviceProperties.item.name":
                gpu_name = value
            elif prefix == "deviceProperties.item.warpSize":
                platform = "AMD" if value == 64 else "NVIDIA"
            elif prefix == "deviceProperties.item.computeMajor":
                device_props["compute_major"] = value
            elif prefix == "deviceProperties.item.computeMinor":
                device_props["compute_minor"] = value
            elif prefix == "deviceProperties.item.totalGlobalMem":
                device_props["total_memory_gb"] = value / (1024**3)
            elif prefix == "deviceProperties.item.numSms":
                device_props["sm_count"] = value
            elif prefix == "deviceProperties.item.maxThreadsPerBlock":
                device_props["max_threads_per_block"] = value
            elif prefix == "deviceProperties.item.sharedMemPerBlock":
                device_props["shared_mem_per_block_kb"] = value / 1024
            elif prefix == "traceEvents.item":
                # Hit first event, stop - we only need metadata
                break

    # Fallback platform detection
    if platform == "Unknown":
        platform = "AMD" if "MI" in gpu_name or "Instinct" in gpu_name else "NVIDIA"

    device_props["name"] = gpu_name
    device_props["compute_capability"] = f"{device_props.get('compute_major', 0)}.{device_props.get('compute_minor', 0)}"
    device_props["warp_size"] = 64 if platform == "AMD" else 32

    return StreamingMetadata(
        platform=platform,
        gpu_name=gpu_name,
        device_props=device_props,
        file_size_mb=file_size_mb,
    )


def load_trace_streaming(
    file_path: str | Path,
    include_stacks: bool = True,
    on_metadata: Callable[[StreamingMetadata], None] | None = None,
    on_progress: ProgressCallback | None = None,
) -> tuple[str, str, dict[str, Any], pd.DataFrame, dict[tuple[str, str], set[str]], dict[int, int]]:
    """Load trace with instant metadata feedback.

    Hybrid approach:
    1. Phase 1 (~2ms): Extract metadata with ijson streaming
    2. Phase 2 (full): Parse with orjson and process

    The on_metadata callback fires immediately with GPU/platform info,
    allowing the UI to show feedback while the full load continues.

    Args:
        file_path: Path to the trace JSON file
        include_stacks: Whether to resolve Python call stacks
        on_metadata: Callback for instant metadata (fires in ~2ms)
        on_progress: Callback for progress updates during processing
    """
    file_path = Path(file_path)

    # Phase 1: Instant metadata extraction (~2ms)
    metadata = _extract_metadata_fast(file_path)

    if on_metadata:
        on_metadata(metadata)

    # Phase 2: Full load with orjson (fast) + progress updates
    return load_trace(
        file_path,
        include_stacks=include_stacks,
        on_progress=on_progress,
    )
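Finally, an illustrative sketch of the two-phase loader defined above: on_metadata fires once the streaming parser has read deviceProperties, and the full load then proceeds with progress callbacks. The path and callbacks are made up.

```python
from wafer_core.lib.trace_compare.loader import StreamingMetadata, load_trace_streaming

def on_metadata(meta: StreamingMetadata) -> None:
    # Fires early, before the full orjson parse finishes.
    print(f"{meta.platform} / {meta.gpu_name} ({meta.file_size_mb:.0f} MB)")

def on_progress(stage: str, pct: float) -> None:
    print(f"  {stage}: {pct:.0%}")

platform, gpu_name, device_props, df, patterns, layers = load_trace_streaming(
    "traces/prefill_run.json",  # hypothetical path
    include_stacks=False,
    on_metadata=on_metadata,
    on_progress=on_progress,
)
print(df[["name", "dur_us", "phase"]].head())
```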