wafer-core 0.1.26__py3-none-any.whl → 0.1.27__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -5,233 +5,268 @@ Python call stacks, CPU operator mappings, and layer correlations.
5
5
  """
6
6
 
7
7
  import bisect
8
- import json
8
+ import sys
9
9
  from collections import defaultdict
10
+ from collections.abc import Callable
11
+ from dataclasses import dataclass, field
10
12
  from pathlib import Path
11
13
  from typing import Any
12
14
 
15
+ try:
16
+ import ijson
17
+ except ImportError:
18
+ ijson = None
19
+
20
+ import orjson
13
21
  import pandas as pd
14
22
 
23
+ from .architecture import detect_architecture
15
24
  from .classifier import classify
25
+ from .layer_segmentation import segment_layers_by_architecture
16
26
 
27
+ ProgressCallback = Callable[[str, float], None]
17
28
 
18
- def extract_layer_mapping(events: list[dict[str, Any]], platform: str) -> dict[int, int]:
19
- """Extract correlation ID to layer number mapping.
20
-
21
- vLLM's execution graph creates large correlation groups for full transformer layers.
22
- Each layer's forward pass (norm + attention + FFN) gets grouped under one correlation ID,
23
- containing 200-400 kernels depending on batch size and sequence length.
24
-
25
- We identify layers as correlation groups with many kernels (70+), which filters out
26
- individual operations like sampling, logit processing, etc.
27
29
 
30
+ @dataclass
31
+ class SinglePassResult:
32
+ """Collected data from single-pass event processing."""
33
+ cpu_op_mapping: dict[int, str] = field(default_factory=dict)
34
+ python_intervals: list[tuple[int, int, int, int | None, str]] = field(default_factory=list)
35
+ # Raw python events for lazy python_by_id construction (built on-demand in _get_python_stack_full)
36
+ python_events_raw: list[dict[str, Any]] = field(default_factory=list)
37
+ phases: list[dict[str, Any]] = field(default_factory=list)
38
+ correlation_groups: dict[int, dict[str, Any]] = field(default_factory=lambda: defaultdict(
39
+ lambda: {"count": 0, "has_attention": False, "has_ffn": False}
40
+ ))
41
+ kernel_events: list[dict[str, Any]] = field(default_factory=list)
42
+ # Lazily built when needed for stack resolution
43
+ _python_by_id: dict[int, dict[str, Any]] | None = field(default=None)
44
+
45
+ @property
46
+ def python_by_id(self) -> dict[int, dict[str, Any]]:
47
+ """Lazily build python_by_id from raw events on first access."""
48
+ if self._python_by_id is None:
49
+ self._python_by_id = {}
50
+ for ev in self.python_events_raw:
51
+ args = ev.get("args")
52
+ py_id = args.get("Python id") if args else None
53
+ if py_id is not None:
54
+ self._python_by_id[py_id] = {
55
+ "name": ev["name"],
56
+ "parent_id": args.get("Python parent id") if args else None,
57
+ }
58
+ return self._python_by_id
59
+
60
+
61
+ @dataclass
62
+ class LoadedTrace:
63
+ """Complete trace data loaded once and reused by all analysis functions."""
64
+ platform: str
65
+ gpu_name: str
66
+ device_props: dict[str, Any]
67
+ df: pd.DataFrame
68
+ patterns: dict[tuple[str, str], set[str]]
69
+ layers: dict[int, int]
70
+ # For fusion/warnings (kept from raw JSON)
71
+ kernel_events: list[dict[str, Any]]
72
+ all_events: list[dict[str, Any]]
73
+ correlation_groups: dict[int, list[dict[str, Any]]]
74
+ phases: list[dict[str, Any]] # Phase annotations for alignment
75
+
76
+
77
+ def _process_events_single_pass(
78
+ events: list[dict[str, Any]],
79
+ include_stacks: bool = True,
80
+ ) -> SinglePassResult:
81
+ """Process all events in a single iteration.
82
+
83
+ Optimizations applied:
84
+ - Cache list.append methods for 2-3x speedup on hot paths
85
+ - Store raw python events, build python_by_id lazily (only ~48 lookups due to caching)
86
+ - Local variable caching for frequently accessed attributes
87
+
28
88
  Args:
29
89
  events: List of trace events
30
- platform: 'AMD' or 'NVIDIA'
31
-
32
- Returns:
33
- Dict mapping correlation ID to layer number
90
+ include_stacks: Whether to collect Python stack info (expensive operation).
91
+ When False, skips python_function processing entirely.
34
92
  """
35
- # Group kernels by correlation ID
36
- correlation_groups = defaultdict(
93
+ result = SinglePassResult()
94
+ correlation_groups: dict[int, dict[str, Any]] = defaultdict(
37
95
  lambda: {"count": 0, "has_attention": False, "has_ffn": False}
38
96
  )
39
-
97
+
98
+ # Cache list.append methods for faster appending (measured 2-3x speedup)
99
+ kernel_append = result.kernel_events.append
100
+ python_interval_append = result.python_intervals.append
101
+ python_raw_append = result.python_events_raw.append
102
+ phase_append = result.phases.append
103
+ cpu_op_mapping = result.cpu_op_mapping
104
+
40
105
  for ev in events:
41
- if ev.get("cat") != "kernel":
42
- continue
43
-
44
- corr_id = ev.get("args", {}).get("correlation")
45
- if corr_id is None:
46
- continue
47
-
48
- kernel_name = ev.get("name", "").lower()
49
-
50
- # Track what operations this correlation contains
51
- correlation_groups[corr_id]["count"] += 1
52
- if "attention" in kernel_name or "fmha" in kernel_name:
53
- correlation_groups[corr_id]["has_attention"] = True
54
- if any(x in kernel_name for x in ["cijk_", "nvjet", "wvsplitk", "gemm"]):
55
- correlation_groups[corr_id]["has_ffn"] = True
56
-
57
- # Map correlation IDs to layer numbers
58
- # Transformer layers have many kernels AND contain both attention and FFN ops
106
+ cat = ev.get("cat")
107
+
108
+ if cat == "kernel":
109
+ args = ev.get("args")
110
+ corr_id = args.get("correlation") if args else None
111
+ if corr_id is not None:
112
+ kernel_name = ev.get("name", "").lower()
113
+ grp = correlation_groups[corr_id]
114
+ grp["count"] += 1
115
+ if "attention" in kernel_name or "fmha" in kernel_name:
116
+ grp["has_attention"] = True
117
+ if "cijk_" in kernel_name or "nvjet" in kernel_name or "wvsplitk" in kernel_name or "gemm" in kernel_name:
118
+ grp["has_ffn"] = True
119
+ kernel_append(ev)
120
+
121
+ elif cat == "cpu_op":
122
+ args = ev.get("args")
123
+ ext_id = args.get("External id") if args else None
124
+ if ext_id is not None:
125
+ cpu_op_mapping[ext_id] = ev.get("name", "")
126
+
127
+ elif cat == "python_function" and include_stacks:
128
+ # Store raw event for lazy python_by_id construction
129
+ python_raw_append(ev)
130
+ # Build interval tuple for binary search
131
+ args = ev.get("args")
132
+ py_id = args.get("Python id") if args else None
133
+ ts = ev["ts"]
134
+ dur = ev.get("dur", 0)
135
+ python_interval_append((ts, ts + dur, dur, py_id, ev["name"]))
136
+
137
+ elif cat == "user_annotation":
138
+ name = ev.get("name", "")
139
+ if name.startswith("execute_context"):
140
+ tokens = 0
141
+ parts = name.split("_")
142
+ for i, p in enumerate(parts):
143
+ if i > 0 and parts[i-1] == "context" and "(" in p and ")" in p:
144
+ try:
145
+ tokens = int(p.split("(")[1].split(")")[0])
146
+ break
147
+ except Exception:
148
+ pass
149
+ is_prefill = "generation_0" in name and tokens > 0
150
+ phase_append({
151
+ "type": "prefill" if is_prefill else "decode",
152
+ "ts_start": ev["ts"],
153
+ "ts_end": ev["ts"] + ev["dur"],
154
+ })
155
+
156
+ if include_stacks and result.python_intervals:
157
+ result.python_intervals.sort()
158
+
159
+ result.correlation_groups = dict(correlation_groups)
160
+
161
+ return result
162
+
163
+
164
+ def _build_layer_mapping(correlation_groups: dict[int, dict[str, Any]]) -> dict[int, int]:
165
+ """Build layer mapping from correlation groups."""
59
166
  correlation_to_layer = {}
60
167
  layer_num = 0
61
-
168
+
62
169
  for corr_id in sorted(correlation_groups.keys()):
63
170
  group = correlation_groups[corr_id]
64
-
65
- # Identify complete transformer layers by their characteristics:
66
- # - Has attention operations (self-attention or cross-attention)
67
- # - Has FFN operations (feed-forward network)
68
- # - Has sufficient kernel count (70+): typical transformer block has ~80-100 kernels
69
- # including attention QKV projections, softmax, output projection, FFN layers,
70
- # normalization, and elementwise ops. This threshold filters out:
71
- # - Individual operations (1-10 kernels)
72
- # - Sampling/generation steps (20-40 kernels)
73
- # - Partial layer executions
171
+
74
172
  is_layer = (
75
173
  group["count"] >= 70 and group["has_attention"] and group["has_ffn"]
76
174
  )
77
-
175
+
78
176
  if is_layer:
79
177
  correlation_to_layer[corr_id] = layer_num
80
178
  layer_num += 1
81
-
179
+
82
180
  return correlation_to_layer
83
181
 
84
182
 
85
- def _build_python_stack_index(
86
- events: list[dict[str, Any]],
87
- ) -> tuple[list[tuple[int, int, int, int | None, str]], dict[int, dict[str, Any]]]:
88
- """Build Python call stack index for kernels.
89
-
90
- Args:
91
- events: List of trace events
92
-
93
- Returns:
94
- Tuple of (python_intervals, python_by_id)
95
- """
96
- python_by_id: dict[int, dict[str, Any]] = {}
97
- python_intervals: list[tuple[int, int, int, int | None, str]] = []
98
-
99
- for ev in events:
100
- if ev.get("cat") == "python_function":
101
- py_id = ev.get("args", {}).get("Python id")
102
- name = ev["name"]
103
- ts_start = ev["ts"]
104
- ts_end = ts_start + ev.get("dur", 0)
105
- duration = ev.get("dur", 0)
106
- parent_id = ev.get("args", {}).get("Python parent id")
107
-
108
- python_intervals.append((ts_start, ts_end, duration, py_id, name))
109
-
110
- if py_id is not None:
111
- python_by_id[py_id] = {
112
- "name": name,
113
- "parent_id": parent_id,
114
- "ts_start": ts_start,
115
- "ts_end": ts_end,
116
- "duration": duration,
117
- }
118
-
119
- # Sort by start time for efficient binary search
120
- python_intervals.sort()
121
-
122
- return python_intervals, python_by_id
123
-
124
-
125
183
  def _get_python_stack_full(
126
184
  timestamp: int,
127
185
  python_intervals: list[tuple[int, int, int, int | None, str]],
128
186
  python_by_id: dict[int, dict[str, Any]],
129
187
  ) -> tuple[str | None, list[str]]:
130
188
  """Get full Python call stack for a kernel launch.
131
-
132
- Args:
133
- timestamp: Kernel launch timestamp
134
- python_intervals: Sorted list of Python function intervals
135
- python_by_id: Mapping of Python ID to function info
136
-
137
- Returns:
138
- Tuple of (summary_string, full_stack_list)
189
+
190
+ Computes the chain on-demand by walking parent pointers.
191
+ This is fast because we only call this ~48 times due to cpu_op caching.
139
192
  """
140
- # Binary search for Python functions active at this timestamp
141
193
  idx = bisect.bisect_right(
142
194
  python_intervals, (timestamp, float("inf"), float("inf"), None, "")
143
195
  )
144
-
145
- # Find active functions
196
+
146
197
  active_funcs = []
147
198
  for i in range(idx - 1, max(0, idx - 1000), -1):
148
199
  ts_start, ts_end, duration, py_id, name = python_intervals[i]
149
200
  if ts_start <= timestamp <= ts_end:
150
201
  active_funcs.append((duration, py_id, name))
151
- if ts_end < timestamp - 1000000: # 1 second before
202
+ if ts_end < timestamp - 1000000:
152
203
  break
153
-
204
+
154
205
  if not active_funcs:
155
206
  return None, []
156
-
157
- # Get the innermost (most specific) function
207
+
158
208
  active_funcs.sort()
159
209
  leaf_duration, leaf_id, leaf_name = active_funcs[0]
160
-
161
- # Walk up parent chain to get FULL stack
210
+
162
211
  full_stack = []
163
212
  current_id = leaf_id
164
- visited = set()
165
-
166
- while (
167
- current_id is not None
168
- and current_id not in visited
169
- and current_id in python_by_id
170
- ):
213
+ visited: set[int] = set()
214
+
215
+ while current_id is not None and current_id not in visited and current_id in python_by_id:
171
216
  func = python_by_id[current_id]
172
- name = func["name"]
173
- full_stack.append(name)
174
-
217
+ full_stack.append(func["name"])
175
218
  visited.add(current_id)
176
219
  current_id = func["parent_id"]
177
-
178
- # Safety limit: prevent infinite loops from circular parent references
179
- # and bound memory usage. 50 frames is deeper than typical Python stacks.
180
220
  if len(full_stack) >= 50:
181
221
  break
182
-
183
- # Reverse so it's outermost -> innermost
222
+
184
223
  full_stack.reverse()
185
-
186
- # Create summary for text output: show the most informative vLLM/model function
224
+
187
225
  summary = None
188
- vllm_funcs = [
189
- f
190
- for f in full_stack
191
- if any(x in f.lower() for x in ["vllm/", "model", "<eval_with_key>"])
192
- ]
193
-
226
+ vllm_funcs = [f for f in full_stack if any(x in f.lower() for x in ["vllm/", "model", "<eval_with_key>"])]
227
+
194
228
  if vllm_funcs:
195
- # Get innermost vLLM function (most specific)
196
229
  summary = vllm_funcs[-1]
197
-
198
- # Check if it's a CUDA graph - add annotation
199
230
  if any("torch/cuda/graphs" in f for f in full_stack):
200
- # Shorten if too long
201
231
  if len(summary) > 45:
202
- parts = summary.split("/")[-1]
203
- summary = "vllm/..." + parts
232
+ summary = "vllm/..." + summary.split("/")[-1]
204
233
  summary = f"{summary} [CUDA graph]"
205
234
  elif len(summary) > 53:
206
- parts = summary.split("/")[-1]
207
- summary = "vllm/..." + parts
235
+ summary = "vllm/..." + summary.split("/")[-1]
208
236
  else:
209
- # Fallback to innermost function
210
237
  summary = leaf_name
211
-
238
+
212
239
  return summary, full_stack
213
240
 
214
241
 
215
242
  def load_trace(
216
243
  file_path: str | Path,
244
+ include_stacks: bool = True,
245
+ on_progress: ProgressCallback | None = None,
217
246
  ) -> tuple[str, str, dict[str, Any], pd.DataFrame, dict[tuple[str, str], set[str]], dict[int, int]]:
218
247
  """Load trace and return platform info, device properties, kernels, patterns, and layer mapping.
219
-
248
+
220
249
  Args:
221
- file_path: Path to JSON trace file
222
-
223
- Returns:
224
- Tuple of (platform, gpu_name, device_props, kernel_df, kernel_patterns, layer_mapping)
250
+ file_path: Path to the trace JSON file
251
+ include_stacks: Whether to resolve Python call stacks (slower but more info)
252
+ on_progress: Optional callback for progress updates: (stage_name, progress_fraction)
225
253
  """
254
+ def _progress(stage: str, pct: float) -> None:
255
+ if on_progress:
256
+ on_progress(stage, pct)
257
+
258
+ _progress("Reading file", 0.0)
226
259
  with open(file_path, "rb") as f:
227
- trace = json.load(f)
228
-
260
+ raw = f.read()
261
+
262
+ _progress("Parsing JSON", 0.1)
263
+ trace = orjson.loads(raw)
264
+
229
265
  props = trace.get("deviceProperties", [{}])[0]
230
266
  is_amd = trace.get("roctracer_version") or props.get("warpSize") == 64
231
267
  platform = "AMD" if is_amd else "NVIDIA"
232
268
  gpu_name = props.get("name", "MI300X" if is_amd else "Unknown GPU")
233
-
234
- # Extract relevant device properties
269
+
235
270
  device_props = {
236
271
  "name": gpu_name,
237
272
  "compute_capability": f"{props.get('computeMajor', 0)}.{props.get('computeMinor', 0)}",
@@ -241,96 +276,360 @@ def load_trace(
241
276
  "max_threads_per_block": props.get("maxThreadsPerBlock", 0),
242
277
  "shared_mem_per_block_kb": props.get("sharedMemPerBlock", 0) / 1024,
243
278
  }
244
-
279
+
280
+ _progress("Processing events", 0.4)
245
281
  events = trace.get("traceEvents", [])
246
-
247
- # Build mapping: external_id -> CPU operator name
248
- external_to_cpu = {}
249
- for ev in events:
250
- if ev.get("cat") == "cpu_op":
251
- ext_id = ev.get("args", {}).get("External id")
252
- cpu_op_name = ev.get("name", "")
253
- if ext_id is not None:
254
- external_to_cpu[ext_id] = cpu_op_name
255
-
256
- # Build Python call stack index for kernels without External IDs
257
- python_intervals, python_by_id = _build_python_stack_index(events)
258
-
259
- # Extract phases
260
- phases = []
261
- for ev in events:
262
- if ev.get("cat") == "user_annotation" and ev.get("name", "").startswith(
263
- "execute_context"
264
- ):
265
- name = ev["name"]
266
- # Parse execute_context_X(TOKENS)_generation_Y(Y)
267
- # We want the TOKENS from execute_context, not the generation number
268
- tokens = 0
269
- parts = name.split("_")
270
- for i, p in enumerate(parts):
271
- # Look for execute_context_X(TOKENS) specifically
272
- if i > 0 and parts[i-1] == "context" and "(" in p and ")" in p:
273
- try:
274
- tokens = int(p.split("(")[1].split(")")[0])
275
- break # Stop after finding context tokens
276
- except Exception:
277
- pass
278
- is_prefill = tokens >= 1024 and "generation_0" in name
279
- phases.append(
280
- {
281
- "type": "prefill" if is_prefill else "decode",
282
- "ts_start": ev["ts"],
283
- "ts_end": ev["ts"] + ev["dur"],
284
- }
285
- )
286
-
287
- # Extract layer mapping from correlation IDs
288
- layer_mapping = extract_layer_mapping(events, platform)
289
-
282
+ pass_result = _process_events_single_pass(events, include_stacks=include_stacks)
283
+
284
+ _progress("Detecting architecture", 0.6)
285
+ kernel_names = [ev.get("name", "") for ev in pass_result.kernel_events]
286
+ architecture, _ = detect_architecture(kernel_names)
287
+
288
+ layer_kernels_dict, layer_warnings = segment_layers_by_architecture(
289
+ pass_result.kernel_events,
290
+ architecture,
291
+ )
292
+
293
+ # Convert layer_kernels_dict (layer_num -> kernels) to correlation_id -> layer_num mapping
294
+ # Fall back to old method if architecture-based segmentation fails
295
+ if layer_kernels_dict:
296
+ layer_mapping: dict[int, int] = {}
297
+ for layer_num, kernels in layer_kernels_dict.items():
298
+ for kernel in kernels:
299
+ corr_id = kernel.get("args", {}).get("correlation")
300
+ if corr_id is not None:
301
+ layer_mapping[corr_id] = layer_num
302
+ else:
303
+ # Fallback to correlation-based method if architecture segmentation failed
304
+ layer_mapping = _build_layer_mapping(pass_result.correlation_groups)
305
+
290
306
  kernel_data = []
291
307
  kernel_patterns: dict[tuple[str, str], set[str]] = defaultdict(set)
292
-
293
- for ev in events:
294
- if ev.get("cat") != "kernel":
295
- continue
296
- name, dur, ts = ev["name"], ev.get("dur", 0), ev["ts"]
308
+ sorted_phases = sorted(pass_result.phases, key=lambda p: p["ts_start"])
309
+
310
+ phase_starts = [p["ts_start"] for p in sorted_phases]
311
+ phase_types = [p["type"] for p in sorted_phases]
312
+ phase_ends = [p["ts_end"] for p in sorted_phases]
313
+
314
+ def _get_phase_for_timestamp(ts: int) -> str:
315
+ """Get phase for a timestamp using binary search. O(log n)."""
316
+ if not phase_starts:
317
+ return "decode"
318
+ idx = bisect.bisect_right(phase_starts, ts) - 1
319
+ if idx >= 0 and phase_starts[idx] <= ts <= phase_ends[idx]:
320
+ return phase_types[idx]
321
+ return "decode"
322
+
323
+ cpu_op_cache: dict[str, str | None] = {}
324
+
325
+ _progress("Classifying kernels", 0.7)
326
+ for ev in pass_result.kernel_events:
327
+ name_raw = ev["name"]
328
+ name = sys.intern(name_raw)
329
+ dur, ts = ev.get("dur", 0), ev["ts"]
297
330
  corr_id = ev.get("args", {}).get("correlation")
298
331
  ext_id = ev.get("args", {}).get("External id")
299
-
300
- phase = "decode"
301
- for p in phases:
302
- if p["ts_start"] <= ts <= p["ts_end"]:
303
- phase = p["type"]
304
- break
305
-
332
+
333
+ phase = _get_phase_for_timestamp(ts)
334
+
306
335
  op, pattern = classify(name, platform)
307
336
  kernel_patterns[(op.value, phase)].add(pattern)
308
-
309
- # Assign layer number from correlation ID
337
+
310
338
  layer = layer_mapping.get(corr_id) if corr_id is not None else None
311
-
312
- # Get CPU operator name from external ID, or fallback to Python stack
313
- cpu_op = external_to_cpu.get(ext_id) if ext_id is not None else None
339
+ cpu_op = pass_result.cpu_op_mapping.get(ext_id) if ext_id is not None else None
340
+ python_stack: list[str] = []
341
+
342
+ if cpu_op is None and include_stacks:
343
+ if name in cpu_op_cache:
344
+ cpu_op = cpu_op_cache[name]
345
+ else:
346
+ cpu_op, python_stack = _get_python_stack_full(
347
+ ts, pass_result.python_intervals, pass_result.python_by_id
348
+ )
349
+ cpu_op_cache[name] = cpu_op
350
+
351
+ kernel_data.append({
352
+ "name": name,
353
+ "dur_us": dur,
354
+ "phase": phase,
355
+ "op": op.value,
356
+ "pattern": pattern,
357
+ "layer": layer,
358
+ "correlation": corr_id,
359
+ "cpu_op": cpu_op,
360
+ "python_stack": python_stack,
361
+ })
362
+
363
+ _progress("Building DataFrame", 0.95)
364
+ df = pd.DataFrame(kernel_data)
365
+ _progress("Complete", 1.0)
366
+
367
+ return platform, gpu_name, device_props, df, dict(kernel_patterns), layer_mapping
368
+
369
+
370
+ def load_trace_full(
371
+ file_path: str | Path,
372
+ include_stacks: bool = True,
373
+ on_progress: ProgressCallback | None = None,
374
+ ) -> LoadedTrace:
375
+ """Load trace once with all data needed by downstream analysis functions.
376
+
377
+ Args:
378
+ file_path: Path to the trace JSON file
379
+ include_stacks: Whether to resolve Python call stacks
380
+ on_progress: Optional callback for progress updates
381
+
382
+ Returns:
383
+ LoadedTrace with all trace data
384
+ """
385
+ def _progress(stage: str, pct: float) -> None:
386
+ if on_progress:
387
+ on_progress(stage, pct)
388
+
389
+ _progress("Reading file", 0.0)
390
+ with open(file_path, "rb") as f:
391
+ raw = f.read()
392
+
393
+ _progress("Parsing JSON", 0.1)
394
+ trace = orjson.loads(raw)
395
+ all_events = trace.get("traceEvents", [])
396
+
397
+ props = trace.get("deviceProperties", [{}])[0]
398
+ is_amd = trace.get("roctracer_version") or props.get("warpSize") == 64
399
+ platform = "AMD" if is_amd else "NVIDIA"
400
+ gpu_name = props.get("name", "MI300X" if is_amd else "Unknown GPU")
401
+
402
+ device_props = {
403
+ "name": gpu_name,
404
+ "compute_capability": f"{props.get('computeMajor', 0)}.{props.get('computeMinor', 0)}",
405
+ "total_memory_gb": props.get("totalGlobalMem", 0) / (1024**3),
406
+ "sm_count": props.get("numSms", 0),
407
+ "warp_size": props.get("warpSize", 32),
408
+ "max_threads_per_block": props.get("maxThreadsPerBlock", 0),
409
+ "shared_mem_per_block_kb": props.get("sharedMemPerBlock", 0) / 1024,
410
+ }
411
+
412
+ _progress("Processing events", 0.4)
413
+ pass_result = _process_events_single_pass(all_events, include_stacks=include_stacks)
414
+
415
+ _progress("Detecting architecture", 0.6)
416
+ kernel_names = [ev.get("name", "") for ev in pass_result.kernel_events]
417
+ architecture, _ = detect_architecture(kernel_names)
418
+
419
+ layer_kernels_dict, layer_warnings = segment_layers_by_architecture(
420
+ pass_result.kernel_events,
421
+ architecture,
422
+ )
423
+
424
+ # Convert layer_kernels_dict (layer_num -> kernels) to correlation_id -> layer_num mapping
425
+ if layer_kernels_dict:
426
+ layer_mapping: dict[int, int] = {}
427
+ for layer_num, kernels in layer_kernels_dict.items():
428
+ for kernel in kernels:
429
+ corr_id = kernel.get("args", {}).get("correlation")
430
+ if corr_id is not None:
431
+ layer_mapping[corr_id] = layer_num
432
+ else:
433
+ layer_mapping = _build_layer_mapping(pass_result.correlation_groups)
434
+
435
+ kernel_data = []
436
+ kernel_patterns: dict[tuple[str, str], set[str]] = defaultdict(set)
437
+ sorted_phases = sorted(pass_result.phases, key=lambda p: p["ts_start"])
438
+
439
+ phase_starts = [p["ts_start"] for p in sorted_phases]
440
+ phase_types = [p["type"] for p in sorted_phases]
441
+ phase_ends = [p["ts_end"] for p in sorted_phases]
442
+
443
+ def _get_phase_for_timestamp(ts: int) -> str:
444
+ """Get phase for a timestamp using binary search. O(log n)."""
445
+ if not phase_starts:
446
+ return "decode"
447
+ idx = bisect.bisect_right(phase_starts, ts) - 1
448
+ if idx >= 0 and phase_starts[idx] <= ts <= phase_ends[idx]:
449
+ return phase_types[idx]
450
+ return "decode"
451
+
452
+ cpu_op_cache: dict[str, str | None] = {}
453
+
454
+ _progress("Classifying kernels", 0.7)
455
+ for ev in pass_result.kernel_events:
456
+ name_raw = ev["name"]
457
+ name = sys.intern(name_raw)
458
+ dur, ts = ev.get("dur", 0), ev["ts"]
459
+ corr_id = ev.get("args", {}).get("correlation")
460
+ ext_id = ev.get("args", {}).get("External id")
461
+
462
+ phase = _get_phase_for_timestamp(ts)
463
+
464
+ op, pattern = classify(name, platform)
465
+ kernel_patterns[(op.value, phase)].add(pattern)
466
+
467
+ layer = layer_mapping.get(corr_id) if corr_id is not None else None
468
+ cpu_op = pass_result.cpu_op_mapping.get(ext_id) if ext_id is not None else None
314
469
  python_stack: list[str] = []
470
+
471
+ if cpu_op is None and include_stacks:
472
+ if name in cpu_op_cache:
473
+ cpu_op = cpu_op_cache[name]
474
+ else:
475
+ cpu_op, python_stack = _get_python_stack_full(
476
+ ts, pass_result.python_intervals, pass_result.python_by_id
477
+ )
478
+ cpu_op_cache[name] = cpu_op
479
+
480
+ kernel_data.append({
481
+ "name": name,
482
+ "dur_us": dur,
483
+ "phase": phase,
484
+ "op": op.value,
485
+ "pattern": pattern,
486
+ "layer": layer,
487
+ "correlation": corr_id,
488
+ "cpu_op": cpu_op,
489
+ "python_stack": python_stack,
490
+ })
491
+
492
+ _progress("Building DataFrame", 0.95)
493
+ df = pd.DataFrame(kernel_data)
494
+
495
+ kernel_events = pass_result.kernel_events
496
+ correlation_groups: dict[int, list[dict[str, Any]]] = defaultdict(list)
497
+ for ev in kernel_events:
498
+ corr_id = ev.get("args", {}).get("correlation")
499
+ if corr_id is not None:
500
+ correlation_groups[corr_id].append(ev)
501
+
502
+ _progress("Complete", 1.0)
503
+
504
+ return LoadedTrace(
505
+ platform=platform,
506
+ gpu_name=gpu_name,
507
+ device_props=device_props,
508
+ df=df,
509
+ patterns=dict(kernel_patterns),
510
+ layers=layer_mapping,
511
+ kernel_events=kernel_events,
512
+ all_events=all_events,
513
+ correlation_groups=dict(correlation_groups),
514
+ phases=pass_result.phases,
515
+ )
516
+
315
517
 
316
- # If no CPU op via External ID, try Python stack trace
317
- if cpu_op is None:
318
- cpu_op, python_stack = _get_python_stack_full(
319
- ts, python_intervals, python_by_id
320
- )
518
+ @dataclass
519
+ class StreamingMetadata:
520
+ """Early metadata available before full trace processing."""
521
+ platform: str
522
+ gpu_name: str
523
+ device_props: dict[str, Any]
524
+ file_size_mb: float
321
525
 
322
- kernel_data.append(
323
- {
324
- "name": name,
325
- "dur_us": dur,
326
- "phase": phase,
327
- "op": op.value,
328
- "pattern": pattern,
329
- "layer": layer,
330
- "correlation": corr_id,
331
- "cpu_op": cpu_op,
332
- "python_stack": python_stack, # Full stack for JSON output
333
- }
526
+
527
+ def _extract_metadata_fast(file_path: Path) -> StreamingMetadata:
528
+ """Extract trace metadata instantly using streaming parser (~2ms).
529
+
530
+ Uses ijson to read only the deviceProperties section without
531
+ parsing the entire file. Falls back to full parse if ijson unavailable.
532
+ """
533
+ file_size_mb = file_path.stat().st_size / (1024 * 1024)
534
+ platform = "Unknown"
535
+ gpu_name = "Unknown GPU"
536
+ device_props: dict[str, Any] = {}
537
+
538
+ if ijson is None:
539
+ # Fallback: parse full file (slower but works)
540
+ with open(file_path, "rb") as f:
541
+ trace = orjson.loads(f.read())
542
+ props = trace.get("deviceProperties", [{}])[0]
543
+ is_amd = trace.get("roctracer_version") or props.get("warpSize") == 64
544
+ platform = "AMD" if is_amd else "NVIDIA"
545
+ gpu_name = props.get("name", "MI300X" if is_amd else "Unknown GPU")
546
+ device_props = {
547
+ "name": gpu_name,
548
+ "compute_capability": f"{props.get('computeMajor', 0)}.{props.get('computeMinor', 0)}",
549
+ "total_memory_gb": props.get("totalGlobalMem", 0) / (1024**3),
550
+ "sm_count": props.get("numSms", 0),
551
+ "warp_size": props.get("warpSize", 32),
552
+ "max_threads_per_block": props.get("maxThreadsPerBlock", 0),
553
+ "shared_mem_per_block_kb": props.get("sharedMemPerBlock", 0) / 1024,
554
+ }
555
+ return StreamingMetadata(
556
+ platform=platform,
557
+ gpu_name=gpu_name,
558
+ device_props=device_props,
559
+ file_size_mb=file_size_mb,
334
560
  )
561
+
562
+ with open(file_path, "rb") as f:
563
+ parser = ijson.parse(f)
564
+ for prefix, event, value in parser:
565
+ if prefix == "deviceProperties.item.name":
566
+ gpu_name = value
567
+ elif prefix == "deviceProperties.item.warpSize":
568
+ platform = "AMD" if value == 64 else "NVIDIA"
569
+ elif prefix == "deviceProperties.item.computeMajor":
570
+ device_props["compute_major"] = value
571
+ elif prefix == "deviceProperties.item.computeMinor":
572
+ device_props["compute_minor"] = value
573
+ elif prefix == "deviceProperties.item.totalGlobalMem":
574
+ device_props["total_memory_gb"] = value / (1024**3)
575
+ elif prefix == "deviceProperties.item.numSms":
576
+ device_props["sm_count"] = value
577
+ elif prefix == "deviceProperties.item.maxThreadsPerBlock":
578
+ device_props["max_threads_per_block"] = value
579
+ elif prefix == "deviceProperties.item.sharedMemPerBlock":
580
+ device_props["shared_mem_per_block_kb"] = value / 1024
581
+ elif prefix == "traceEvents.item":
582
+ # Hit first event, stop - we only need metadata
583
+ break
584
+
585
+ # Fallback platform detection
586
+ if platform == "Unknown":
587
+ platform = "AMD" if "MI" in gpu_name or "Instinct" in gpu_name else "NVIDIA"
588
+
589
+ device_props["name"] = gpu_name
590
+ device_props["compute_capability"] = f"{device_props.get('compute_major', 0)}.{device_props.get('compute_minor', 0)}"
591
+ device_props["warp_size"] = 64 if platform == "AMD" else 32
592
+
593
+ return StreamingMetadata(
594
+ platform=platform,
595
+ gpu_name=gpu_name,
596
+ device_props=device_props,
597
+ file_size_mb=file_size_mb,
598
+ )
599
+
335
600
 
336
- return platform, gpu_name, device_props, pd.DataFrame(kernel_data), dict(kernel_patterns), layer_mapping
601
+ def load_trace_streaming(
602
+ file_path: str | Path,
603
+ include_stacks: bool = True,
604
+ on_metadata: Callable[[StreamingMetadata], None] | None = None,
605
+ on_progress: ProgressCallback | None = None,
606
+ ) -> tuple[str, str, dict[str, Any], pd.DataFrame, dict[tuple[str, str], set[str]], dict[int, int]]:
607
+ """Load trace with instant metadata feedback.
608
+
609
+ Hybrid approach:
610
+ 1. Phase 1 (~2ms): Extract metadata with ijson streaming
611
+ 2. Phase 2 (full): Parse with orjson and process
612
+
613
+ The on_metadata callback fires immediately with GPU/platform info,
614
+ allowing the UI to show feedback while the full load continues.
615
+
616
+ Args:
617
+ file_path: Path to the trace JSON file
618
+ include_stacks: Whether to resolve Python call stacks
619
+ on_metadata: Callback for instant metadata (fires in ~2ms)
620
+ on_progress: Callback for progress updates during processing
621
+ """
622
+ file_path = Path(file_path)
623
+
624
+ # Phase 1: Instant metadata extraction (~2ms)
625
+ metadata = _extract_metadata_fast(file_path)
626
+
627
+ if on_metadata:
628
+ on_metadata(metadata)
629
+
630
+ # Phase 2: Full load with orjson (fast) + progress updates
631
+ return load_trace(
632
+ file_path,
633
+ include_stacks=include_stacks,
634
+ on_progress=on_progress,
635
+ )