wafer-core 0.1.33__py3-none-any.whl → 0.1.35__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -7,24 +7,14 @@ Python call stacks, CPU operator mappings, and layer correlations.
 import bisect
 import sys
 from collections import defaultdict
-from collections.abc import Callable
 from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Any

-try:
-    import ijson
-except ImportError:
-    ijson = None
-
 import orjson
 import pandas as pd

-from .architecture import detect_architecture
 from .classifier import classify
-from .layer_segmentation import segment_layers_by_architecture
-
-ProgressCallback = Callable[[str, float], None]


 @dataclass
@@ -32,7 +22,7 @@ class SinglePassResult:
     """Collected data from single-pass event processing."""
     cpu_op_mapping: dict[int, str] = field(default_factory=dict)
     python_intervals: list[tuple[int, int, int, int | None, str]] = field(default_factory=list)
-    # Raw python events for lazy python_by_id construction (built on-demand in _get_python_stack_full)
+    # Raw python events for lazy python_by_id construction (built on-demand)
     python_events_raw: list[dict[str, Any]] = field(default_factory=list)
     phases: list[dict[str, Any]] = field(default_factory=list)
     correlation_groups: dict[int, dict[str, Any]] = field(default_factory=lambda: defaultdict(
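
The comment above describes a lazy index, and the next hunk keeps the matching python_by_id property: raw python_function events are stored cheaply during the single pass and the id-to-frame dict is only built on first access. A minimal sketch of that pattern, using the field names from the diff; the real property body is not shown here, so the exact construction (keyed by "Python id", mirroring what _build_python_stack_index produces later in this diff) is an assumption:

    from dataclasses import dataclass, field
    from typing import Any

    @dataclass
    class LazyIndexSketch:
        # Raw python_function events collected during the single pass
        python_events_raw: list[dict[str, Any]] = field(default_factory=list)
        _python_by_id: dict[int, dict[str, Any]] | None = field(default=None)

        @property
        def python_by_id(self) -> dict[int, dict[str, Any]]:
            # Build the index once, on first access (assumed shape)
            if self._python_by_id is None:
                index: dict[int, dict[str, Any]] = {}
                for ev in self.python_events_raw:
                    py_id = ev.get("args", {}).get("Python id")
                    if py_id is not None:
                        ts, dur = ev["ts"], ev.get("dur", 0)
                        index[py_id] = {
                            "name": ev["name"],
                            "parent_id": ev.get("args", {}).get("Python parent id"),
                            "ts_start": ts,
                            "ts_end": ts + dur,
                            "duration": dur,
                        }
                self._python_by_id = index
            return self._python_by_id
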
@@ -41,7 +31,7 @@ class SinglePassResult:
     kernel_events: list[dict[str, Any]] = field(default_factory=list)
     # Lazily built when needed for stack resolution
     _python_by_id: dict[int, dict[str, Any]] | None = field(default=None)
-
+
     @property
     def python_by_id(self) -> dict[int, dict[str, Any]]:
         """Lazily build python_by_id from raw events on first access."""
@@ -58,53 +48,101 @@ class SinglePassResult:
         return self._python_by_id


-@dataclass
-class LoadedTrace:
-    """Complete trace data loaded once and reused by all analysis functions."""
-    platform: str
-    gpu_name: str
-    device_props: dict[str, Any]
-    df: pd.DataFrame
-    patterns: dict[tuple[str, str], set[str]]
-    layers: dict[int, int]
-    # For fusion/warnings (kept from raw JSON)
-    kernel_events: list[dict[str, Any]]
-    all_events: list[dict[str, Any]]
-    correlation_groups: dict[int, list[dict[str, Any]]]
-    phases: list[dict[str, Any]]  # Phase annotations for alignment
+def extract_layer_mapping(events: list[dict[str, Any]], platform: str) -> dict[int, int]:
+    """Extract correlation ID to layer number mapping.
+
+    vLLM's execution graph creates large correlation groups for full transformer layers.
+    Each layer's forward pass (norm + attention + FFN) gets grouped under one correlation ID,
+    containing 200-400 kernels depending on batch size and sequence length.
+
+    We identify layers as correlation groups with many kernels (70+), which filters out
+    individual operations like sampling, logit processing, etc.
+
+    Args:
+        events: List of trace events
+        platform: 'AMD' or 'NVIDIA'
+
+    Returns:
+        Dict mapping correlation ID to layer number
+    """
+    # Group kernels by correlation ID
+    correlation_groups = defaultdict(
+        lambda: {"count": 0, "has_attention": False, "has_ffn": False}
+    )
+
+    for ev in events:
+        if ev.get("cat") != "kernel":
+            continue
+
+        corr_id = ev.get("args", {}).get("correlation")
+        if corr_id is None:
+            continue
+
+        kernel_name = ev.get("name", "").lower()
+
+        # Track what operations this correlation contains
+        correlation_groups[corr_id]["count"] += 1
+        if "attention" in kernel_name or "fmha" in kernel_name:
+            correlation_groups[corr_id]["has_attention"] = True
+        if any(x in kernel_name for x in ["cijk_", "nvjet", "wvsplitk", "gemm"]):
+            correlation_groups[corr_id]["has_ffn"] = True
+
+    # Map correlation IDs to layer numbers
+    # Transformer layers have many kernels AND contain both attention and FFN ops
+    correlation_to_layer = {}
+    layer_num = 0
+
+    for corr_id in sorted(correlation_groups.keys()):
+        group = correlation_groups[corr_id]
+
+        # Identify complete transformer layers by their characteristics:
+        # - Has attention operations (self-attention or cross-attention)
+        # - Has FFN operations (feed-forward network)
+        # - Has sufficient kernel count (70+): typical transformer block has ~80-100 kernels
+        #   including attention QKV projections, softmax, output projection, FFN layers,
+        #   normalization, and elementwise ops. This threshold filters out:
+        #   - Individual operations (1-10 kernels)
+        #   - Sampling/generation steps (20-40 kernels)
+        #   - Partial layer executions
+        is_layer = (
+            group["count"] >= 70 and group["has_attention"] and group["has_ffn"]
+        )
+
+        if is_layer:
+            correlation_to_layer[corr_id] = layer_num
+            layer_num += 1
+
+    return correlation_to_layer


 def _process_events_single_pass(
     events: list[dict[str, Any]],
-    include_stacks: bool = True,
 ) -> SinglePassResult:
     """Process all events in a single iteration.
-
+
     Optimizations applied:
     - Cache list.append methods for 2-3x speedup on hot paths
     - Store raw python events, build python_by_id lazily (only ~48 lookups due to caching)
     - Local variable caching for frequently accessed attributes
-
+
     Args:
         events: List of trace events
-        include_stacks: Whether to collect Python stack info (expensive operation).
-            When False, skips python_function processing entirely.
     """
     result = SinglePassResult()
     correlation_groups: dict[int, dict[str, Any]] = defaultdict(
         lambda: {"count": 0, "has_attention": False, "has_ffn": False}
     )
-
+
     # Cache list.append methods for faster appending (measured 2-3x speedup)
     kernel_append = result.kernel_events.append
     python_interval_append = result.python_intervals.append
     python_raw_append = result.python_events_raw.append
     phase_append = result.phases.append
     cpu_op_mapping = result.cpu_op_mapping
-
+
     for ev in events:
         cat = ev.get("cat")
-
+
         if cat == "kernel":
             args = ev.get("args")
             corr_id = args.get("correlation") if args else None
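
The new extract_layer_mapping above replaces the removed LoadedTrace/architecture-segmentation path with a pure heuristic over correlation groups: a group counts as a transformer layer only when it has 70 or more kernels and contains both attention and GEMM/FFN kernels. A toy usage sketch follows; the kernel names are made up (they only need to match the substrings the heuristic looks for) and the import path is a placeholder, since the module's filename is not shown in this diff:

    from typing import Any

    # from wafer_core.trace_loader import extract_layer_mapping  # hypothetical import path

    def kernel_event(corr_id: int, name: str) -> dict[str, Any]:
        """Minimal Chrome-trace kernel event carrying a correlation ID."""
        return {"cat": "kernel", "name": name, "args": {"correlation": corr_id}}

    events: list[dict[str, Any]] = []
    # Correlation 7: looks like a full layer -- attention kernels plus 70 GEMMs (80 kernels total)
    events += [kernel_event(7, "fmha_fwd_kernel") for _ in range(10)]
    events += [kernel_event(7, "Cijk_Alik_Bljk_gemm") for _ in range(70)]
    # Correlation 9: a small elementwise/sampling step -- filtered out by the heuristic
    events += [kernel_event(9, "elementwise_kernel") for _ in range(5)]

    # extract_layer_mapping(events, platform="AMD") -> {7: 0}
    # Only the large group containing both attention and FFN kernels is assigned a layer number.
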
@@ -117,14 +155,14 @@ def _process_events_single_pass(
             if "cijk_" in kernel_name or "nvjet" in kernel_name or "wvsplitk" in kernel_name or "gemm" in kernel_name:
                 grp["has_ffn"] = True
             kernel_append(ev)
-
+
         elif cat == "cpu_op":
             args = ev.get("args")
             ext_id = args.get("External id") if args else None
             if ext_id is not None:
                 cpu_op_mapping[ext_id] = ev.get("name", "")
-
-        elif cat == "python_function" and include_stacks:
+
+        elif cat == "python_function":
             # Store raw event for lazy python_by_id construction
             python_raw_append(ev)
             # Build interval tuple for binary search
@@ -133,7 +171,7 @@ def _process_events_single_pass(
             ts = ev["ts"]
             dur = ev.get("dur", 0)
             python_interval_append((ts, ts + dur, dur, py_id, ev["name"]))
-
+
         elif cat == "user_annotation":
             name = ev.get("name", "")
             if name.startswith("execute_context"):
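
The next hunk tightens the prefill test from ("generation_0" in name and tokens > 0) to (tokens >= 1024 and "generation_0" in name), so small generation_0 batches are now labelled decode. A small sketch of the predicate; the annotation names are illustrative, since the diff only shows that tokens is parsed out of an execute_context annotation:

    # Sketch of the 0.1.35 phase rule; annotation names here are made up for illustration.
    def classify_phase(name: str, tokens: int) -> str:
        is_prefill = tokens >= 1024 and "generation_0" in name
        return "prefill" if is_prefill else "decode"

    assert classify_phase("execute_context_generation_0", 2048) == "prefill"
    assert classify_phase("execute_context_generation_0", 64) == "decode"    # prefill under the 0.1.33 rule
    assert classify_phase("execute_context_generation_3", 4096) == "decode"  # later generations are decode
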
@@ -146,38 +184,59 @@ def _process_events_single_pass(
                             break
                 except Exception:
                     pass
-                is_prefill = "generation_0" in name and tokens > 0
+                is_prefill = tokens >= 1024 and "generation_0" in name
                 phase_append({
                     "type": "prefill" if is_prefill else "decode",
                     "ts_start": ev["ts"],
                     "ts_end": ev["ts"] + ev["dur"],
                 })
-
+
-    if include_stacks and result.python_intervals:
+    if result.python_intervals:
         result.python_intervals.sort()
-
+
     result.correlation_groups = dict(correlation_groups)
-
+
     return result


-def _build_layer_mapping(correlation_groups: dict[int, dict[str, Any]]) -> dict[int, int]:
-    """Build layer mapping from correlation groups."""
-    correlation_to_layer = {}
-    layer_num = 0
-
-    for corr_id in sorted(correlation_groups.keys()):
-        group = correlation_groups[corr_id]
-
-        is_layer = (
-            group["count"] >= 70 and group["has_attention"] and group["has_ffn"]
-        )
-
-        if is_layer:
-            correlation_to_layer[corr_id] = layer_num
-            layer_num += 1
-
-    return correlation_to_layer
+def _build_python_stack_index(
+    events: list[dict[str, Any]],
+) -> tuple[list[tuple[int, int, int, int | None, str]], dict[int, dict[str, Any]]]:
+    """Build Python call stack index for kernels.
+
+    Args:
+        events: List of trace events
+
+    Returns:
+        Tuple of (python_intervals, python_by_id)
+    """
+    python_by_id: dict[int, dict[str, Any]] = {}
+    python_intervals: list[tuple[int, int, int, int | None, str]] = []
+
+    for ev in events:
+        if ev.get("cat") == "python_function":
+            py_id = ev.get("args", {}).get("Python id")
+            name = ev["name"]
+            ts_start = ev["ts"]
+            ts_end = ts_start + ev.get("dur", 0)
+            duration = ev.get("dur", 0)
+            parent_id = ev.get("args", {}).get("Python parent id")
+
+            python_intervals.append((ts_start, ts_end, duration, py_id, name))
+
+            if py_id is not None:
+                python_by_id[py_id] = {
+                    "name": name,
+                    "parent_id": parent_id,
+                    "ts_start": ts_start,
+                    "ts_end": ts_end,
+                    "duration": duration,
+                }
+
+    # Sort by start time for efficient binary search
+    python_intervals.sort()
+
+    return python_intervals, python_by_id


 def _get_python_stack_full(
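
The new _build_python_stack_index above keeps every python_function event both as an id-keyed dict and as (ts_start, ts_end, duration, py_id, name) tuples sorted by start time; the next hunk's _get_python_stack_full then uses bisect on those tuples to find which frames were live when a kernel launched. A self-contained sketch of that lookup with made-up intervals:

    import bisect

    # Intervals sorted by start time, as produced by _build_python_stack_index.
    intervals = [
        (100, 10_000, 9_900, 1, "run_model"),     # outer frame
        (200, 5_000, 4_800, 2, "forward"),        # nested frame
        (6_000, 6_500, 500, 3, "sample_tokens"),  # starts after the query timestamp
    ]
    intervals.sort()

    ts = 300  # kernel launch timestamp
    # bisect_right finds the first interval starting after ts, so only intervals[:idx]
    # can contain it; the real code scans them newest-first with 1000-entry/1-second caps.
    idx = bisect.bisect_right(intervals, (ts, float("inf"), float("inf"), None, ""))
    active = [
        (dur, py_id, name)
        for ts_start, ts_end, dur, py_id, name in intervals[:idx]
        if ts_start <= ts <= ts_end
    ]
    active.sort()      # shortest duration first == innermost frame
    print(active[0])   # (4800, 2, 'forward'): the leaf; parents are reached via parent_id links
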
@@ -186,87 +245,112 @@ def _get_python_stack_full(
     python_by_id: dict[int, dict[str, Any]],
 ) -> tuple[str | None, list[str]]:
     """Get full Python call stack for a kernel launch.
-
-    Computes the chain on-demand by walking parent pointers.
-    This is fast because we only call this ~48 times due to cpu_op caching.
+
+    Args:
+        timestamp: Kernel launch timestamp
+        python_intervals: Sorted list of Python function intervals
+        python_by_id: Mapping of Python ID to function info
+
+    Returns:
+        Tuple of (summary_string, full_stack_list)
     """
+    # Binary search for Python functions active at this timestamp
     idx = bisect.bisect_right(
         python_intervals, (timestamp, float("inf"), float("inf"), None, "")
     )
-
+
+    # Find active functions
     active_funcs = []
     for i in range(idx - 1, max(0, idx - 1000), -1):
         ts_start, ts_end, duration, py_id, name = python_intervals[i]
         if ts_start <= timestamp <= ts_end:
             active_funcs.append((duration, py_id, name))
-        if ts_end < timestamp - 1000000:
+        if ts_end < timestamp - 1000000:  # 1 second before
             break
-
+
     if not active_funcs:
         return None, []
-
+
+    # Get the innermost (most specific) function
     active_funcs.sort()
     leaf_duration, leaf_id, leaf_name = active_funcs[0]
-
+
+    # Walk up parent chain to get FULL stack
     full_stack = []
     current_id = leaf_id
-    visited: set[int] = set()
-
-    while current_id is not None and current_id not in visited and current_id in python_by_id:
+    visited = set()
+
+    while (
+        current_id is not None
+        and current_id not in visited
+        and current_id in python_by_id
+    ):
         func = python_by_id[current_id]
-        full_stack.append(func["name"])
+        name = func["name"]
+        full_stack.append(name)
+
         visited.add(current_id)
         current_id = func["parent_id"]
+
+        # Safety limit: prevent infinite loops from circular parent references
+        # and bound memory usage. 50 frames is deeper than typical Python stacks.
         if len(full_stack) >= 50:
             break
-
+
+    # Reverse so it's outermost -> innermost
     full_stack.reverse()
-
+
+    # Create summary for text output: show the most informative vLLM/model function
     summary = None
-    vllm_funcs = [f for f in full_stack if any(x in f.lower() for x in ["vllm/", "model", "<eval_with_key>"])]
-
+    vllm_funcs = [
+        f
+        for f in full_stack
+        if any(x in f.lower() for x in ["vllm/", "model", "<eval_with_key>"])
+    ]
+
     if vllm_funcs:
+        # Get innermost vLLM function (most specific)
         summary = vllm_funcs[-1]
+
+        # Check if it's a CUDA graph - add annotation
         if any("torch/cuda/graphs" in f for f in full_stack):
+            # Shorten if too long
            if len(summary) > 45:
-                summary = "vllm/..." + summary.split("/")[-1]
+                parts = summary.split("/")[-1]
+                summary = "vllm/..." + parts
            summary = f"{summary} [CUDA graph]"
        elif len(summary) > 53:
-            summary = "vllm/..." + summary.split("/")[-1]
+            parts = summary.split("/")[-1]
+            summary = "vllm/..." + parts
     else:
+        # Fallback to innermost function
        summary = leaf_name
-
+
     return summary, full_stack


 def load_trace(
     file_path: str | Path,
-    include_stacks: bool = True,
-    on_progress: ProgressCallback | None = None,
 ) -> tuple[str, str, dict[str, Any], pd.DataFrame, dict[tuple[str, str], set[str]], dict[int, int]]:
     """Load trace and return platform info, device properties, kernels, patterns, and layer mapping.
-
+
     Args:
-        file_path: Path to the trace JSON file
-        include_stacks: Whether to resolve Python call stacks (slower but more info)
-        on_progress: Optional callback for progress updates: (stage_name, progress_fraction)
+        file_path: Path to JSON trace file
+
+    Returns:
+        Tuple of (platform, gpu_name, device_props, kernel_df, kernel_patterns, layer_mapping)
     """
-    def _progress(stage: str, pct: float) -> None:
-        if on_progress:
-            on_progress(stage, pct)
-
-    _progress("Reading file", 0.0)
     with open(file_path, "rb") as f:
         raw = f.read()
-
-    _progress("Parsing JSON", 0.1)
+
     trace = orjson.loads(raw)
-
+
     props = trace.get("deviceProperties", [{}])[0]
     is_amd = trace.get("roctracer_version") or props.get("warpSize") == 64
     platform = "AMD" if is_amd else "NVIDIA"
     gpu_name = props.get("name", "MI300X" if is_amd else "Unknown GPU")
-
+
+    # Extract relevant device properties
     device_props = {
         "name": gpu_name,
         "compute_capability": f"{props.get('computeMajor', 0)}.{props.get('computeMinor', 0)}",
@@ -276,170 +360,30 @@ def load_trace(
         "max_threads_per_block": props.get("maxThreadsPerBlock", 0),
         "shared_mem_per_block_kb": props.get("sharedMemPerBlock", 0) / 1024,
     }
-
-    _progress("Processing events", 0.4)
+
     events = trace.get("traceEvents", [])
-    pass_result = _process_events_single_pass(events, include_stacks=include_stacks)
-
-    _progress("Detecting architecture", 0.6)
-    kernel_names = [ev.get("name", "") for ev in pass_result.kernel_events]
-    architecture, _ = detect_architecture(kernel_names)
-
-    layer_kernels_dict, layer_warnings = segment_layers_by_architecture(
-        pass_result.kernel_events,
-        architecture,
-    )
-
-    # Convert layer_kernels_dict (layer_num -> kernels) to correlation_id -> layer_num mapping
-    # Fall back to old method if architecture-based segmentation fails
-    if layer_kernels_dict:
-        layer_mapping: dict[int, int] = {}
-        for layer_num, kernels in layer_kernels_dict.items():
-            for kernel in kernels:
-                corr_id = kernel.get("args", {}).get("correlation")
-                if corr_id is not None:
-                    layer_mapping[corr_id] = layer_num
-    else:
-        # Fallback to correlation-based method if architecture segmentation failed
-        layer_mapping = _build_layer_mapping(pass_result.correlation_groups)
-
-    kernel_data = []
-    kernel_patterns: dict[tuple[str, str], set[str]] = defaultdict(set)
-    sorted_phases = sorted(pass_result.phases, key=lambda p: p["ts_start"])
-
-    phase_starts = [p["ts_start"] for p in sorted_phases]
-    phase_types = [p["type"] for p in sorted_phases]
-    phase_ends = [p["ts_end"] for p in sorted_phases]
-
-    def _get_phase_for_timestamp(ts: int) -> str:
-        """Get phase for a timestamp using binary search. O(log n)."""
-        if not phase_starts:
-            return "decode"
-        idx = bisect.bisect_right(phase_starts, ts) - 1
-        if idx >= 0 and phase_starts[idx] <= ts <= phase_ends[idx]:
-            return phase_types[idx]
-        return "decode"
-
-    cpu_op_cache: dict[str, str | None] = {}
-
-    _progress("Classifying kernels", 0.7)
-    for ev in pass_result.kernel_events:
-        name_raw = ev["name"]
-        name = sys.intern(name_raw)
-        dur, ts = ev.get("dur", 0), ev["ts"]
-        corr_id = ev.get("args", {}).get("correlation")
-        ext_id = ev.get("args", {}).get("External id")
-
-        phase = _get_phase_for_timestamp(ts)
-
-        op, pattern = classify(name, platform)
-        kernel_patterns[(op.value, phase)].add(pattern)
-
-        layer = layer_mapping.get(corr_id) if corr_id is not None else None
-        cpu_op = pass_result.cpu_op_mapping.get(ext_id) if ext_id is not None else None
-        python_stack: list[str] = []
-
-        if cpu_op is None and include_stacks:
-            if name in cpu_op_cache:
-                cpu_op = cpu_op_cache[name]
-            else:
-                cpu_op, python_stack = _get_python_stack_full(
-                    ts, pass_result.python_intervals, pass_result.python_by_id
-                )
-                cpu_op_cache[name] = cpu_op
-
-        kernel_data.append({
-            "name": name,
-            "dur_us": dur,
-            "phase": phase,
-            "op": op.value,
-            "pattern": pattern,
-            "layer": layer,
-            "correlation": corr_id,
-            "cpu_op": cpu_op,
-            "python_stack": python_stack,
-        })
-
-    _progress("Building DataFrame", 0.95)
-    df = pd.DataFrame(kernel_data)
-    _progress("Complete", 1.0)
-
-    return platform, gpu_name, device_props, df, dict(kernel_patterns), layer_mapping
-
-
-def load_trace_full(
-    file_path: str | Path,
-    include_stacks: bool = True,
-    on_progress: ProgressCallback | None = None,
-) -> LoadedTrace:
-    """Load trace once with all data needed by downstream analysis functions.
-
-    Args:
-        file_path: Path to the trace JSON file
-        include_stacks: Whether to resolve Python call stacks
-        on_progress: Optional callback for progress updates
-
-    Returns:
-        LoadedTrace with all trace data
-    """
-    def _progress(stage: str, pct: float) -> None:
-        if on_progress:
-            on_progress(stage, pct)
-
-    _progress("Reading file", 0.0)
-    with open(file_path, "rb") as f:
-        raw = f.read()
-
-    _progress("Parsing JSON", 0.1)
-    trace = orjson.loads(raw)
-    all_events = trace.get("traceEvents", [])
-
-    props = trace.get("deviceProperties", [{}])[0]
-    is_amd = trace.get("roctracer_version") or props.get("warpSize") == 64
-    platform = "AMD" if is_amd else "NVIDIA"
-    gpu_name = props.get("name", "MI300X" if is_amd else "Unknown GPU")
-
-    device_props = {
-        "name": gpu_name,
-        "compute_capability": f"{props.get('computeMajor', 0)}.{props.get('computeMinor', 0)}",
-        "total_memory_gb": props.get("totalGlobalMem", 0) / (1024**3),
-        "sm_count": props.get("numSms", 0),
-        "warp_size": props.get("warpSize", 32),
-        "max_threads_per_block": props.get("maxThreadsPerBlock", 0),
-        "shared_mem_per_block_kb": props.get("sharedMemPerBlock", 0) / 1024,
+
+    # Single-pass event processing for all metadata
+    pass_result = _process_events_single_pass(events)
+
+    # Extract layer mapping from correlation groups
+    layer_mapping = {
+        corr_id: layer_num
+        for layer_num, (corr_id, grp) in enumerate(
+            (cid, g) for cid, g in sorted(pass_result.correlation_groups.items())
+            if g["count"] >= 70 and g["has_attention"] and g["has_ffn"]
+        )
     }
-
-    _progress("Processing events", 0.4)
-    pass_result = _process_events_single_pass(all_events, include_stacks=include_stacks)
-
-    _progress("Detecting architecture", 0.6)
-    kernel_names = [ev.get("name", "") for ev in pass_result.kernel_events]
-    architecture, _ = detect_architecture(kernel_names)
-
-    layer_kernels_dict, layer_warnings = segment_layers_by_architecture(
-        pass_result.kernel_events,
-        architecture,
-    )
-
-    # Convert layer_kernels_dict (layer_num -> kernels) to correlation_id -> layer_num mapping
-    if layer_kernels_dict:
-        layer_mapping: dict[int, int] = {}
-        for layer_num, kernels in layer_kernels_dict.items():
-            for kernel in kernels:
-                corr_id = kernel.get("args", {}).get("correlation")
-                if corr_id is not None:
-                    layer_mapping[corr_id] = layer_num
-    else:
-        layer_mapping = _build_layer_mapping(pass_result.correlation_groups)
-
+
     kernel_data = []
     kernel_patterns: dict[tuple[str, str], set[str]] = defaultdict(set)
+
+    # Pre-sort phases for binary search
     sorted_phases = sorted(pass_result.phases, key=lambda p: p["ts_start"])
-
     phase_starts = [p["ts_start"] for p in sorted_phases]
     phase_types = [p["type"] for p in sorted_phases]
     phase_ends = [p["ts_end"] for p in sorted_phases]
-
+
     def _get_phase_for_timestamp(ts: int) -> str:
         """Get phase for a timestamp using binary search. O(log n)."""
         if not phase_starts:
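
In the new load_trace the layer mapping is built inline as a dict comprehension over pass_result.correlation_groups; it applies the same 70-kernel/attention/FFN filter as extract_layer_mapping and the removed _build_layer_mapping, just without an explicit counter. A small check with made-up groups showing the two forms agree:

    from typing import Any

    # Made-up correlation groups for illustration.
    groups: dict[int, dict[str, Any]] = {
        3: {"count": 85, "has_attention": True, "has_ffn": True},   # a layer
        5: {"count": 12, "has_attention": False, "has_ffn": True},  # too small, no attention
        8: {"count": 91, "has_attention": True, "has_ffn": True},   # a layer
    }

    # Comprehension form, as added in 0.1.35's load_trace
    layer_mapping = {
        corr_id: layer_num
        for layer_num, (corr_id, grp) in enumerate(
            (cid, g) for cid, g in sorted(groups.items())
            if g["count"] >= 70 and g["has_attention"] and g["has_ffn"]
        )
    }

    # Loop form, as in the removed _build_layer_mapping
    expected: dict[int, int] = {}
    next_layer = 0
    for corr_id in sorted(groups):
        g = groups[corr_id]
        if g["count"] >= 70 and g["has_attention"] and g["has_ffn"]:
            expected[corr_id] = next_layer
            next_layer += 1

    assert layer_mapping == expected == {3: 0, 8: 1}
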
@@ -448,27 +392,31 @@ def load_trace_full(
         if idx >= 0 and phase_starts[idx] <= ts <= phase_ends[idx]:
             return phase_types[idx]
         return "decode"
-
+
+    # Cache CPU op lookups by kernel name (reduces 779k lookups to ~48)
     cpu_op_cache: dict[str, str | None] = {}
-
-    _progress("Classifying kernels", 0.7)
+
     for ev in pass_result.kernel_events:
         name_raw = ev["name"]
-        name = sys.intern(name_raw)
+        name = sys.intern(name_raw)  # String interning for memory efficiency
         dur, ts = ev.get("dur", 0), ev["ts"]
         corr_id = ev.get("args", {}).get("correlation")
         ext_id = ev.get("args", {}).get("External id")
-
+
         phase = _get_phase_for_timestamp(ts)
-
+
         op, pattern = classify(name, platform)
         kernel_patterns[(op.value, phase)].add(pattern)
-
+
+        # Assign layer number from correlation ID
         layer = layer_mapping.get(corr_id) if corr_id is not None else None
+
+        # Get CPU operator name from external ID, or fallback to Python stack
         cpu_op = pass_result.cpu_op_mapping.get(ext_id) if ext_id is not None else None
         python_stack: list[str] = []
-
-        if cpu_op is None and include_stacks:
+
+        # If no CPU op via External ID, try Python stack trace with caching
+        if cpu_op is None:
            if name in cpu_op_cache:
                cpu_op = cpu_op_cache[name]
            else:
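
The cpu_op_cache in this hunk memoizes the expensive stack resolution by kernel name; the comment above claims this cuts roughly 779k lookups down to about 48, one per distinct name. The implied trade-off is that every kernel sharing a name inherits the stack resolved for its first occurrence. A minimal sketch of the pattern:

    from collections.abc import Callable

    def make_cached_resolver(resolve: Callable[[int], str | None]) -> Callable[[str, int], str | None]:
        """Memoize an expensive timestamp-based lookup by kernel name."""
        cache: dict[str, str | None] = {}

        def lookup(name: str, ts: int) -> str | None:
            if name not in cache:
                cache[name] = resolve(ts)  # expensive walk, done once per distinct name
            return cache[name]

        return lookup

    calls = []
    lookup = make_cached_resolver(lambda ts: calls.append(ts) or "aten::mm")
    for ts in range(1000):
        lookup("gemm_kernel", ts)
    assert len(calls) == 1  # resolved once despite 1000 launches of the same kernel name
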
@@ -476,160 +424,19 @@ def load_trace_full(
                     ts, pass_result.python_intervals, pass_result.python_by_id
                 )
                 cpu_op_cache[name] = cpu_op
-
-        kernel_data.append({
-            "name": name,
-            "dur_us": dur,
-            "phase": phase,
-            "op": op.value,
-            "pattern": pattern,
-            "layer": layer,
-            "correlation": corr_id,
-            "cpu_op": cpu_op,
-            "python_stack": python_stack,
-        })
-
-    _progress("Building DataFrame", 0.95)
-    df = pd.DataFrame(kernel_data)
-
-    kernel_events = pass_result.kernel_events
-    correlation_groups: dict[int, list[dict[str, Any]]] = defaultdict(list)
-    for ev in kernel_events:
-        corr_id = ev.get("args", {}).get("correlation")
-        if corr_id is not None:
-            correlation_groups[corr_id].append(ev)
-
-    _progress("Complete", 1.0)
-
-    return LoadedTrace(
-        platform=platform,
-        gpu_name=gpu_name,
-        device_props=device_props,
-        df=df,
-        patterns=dict(kernel_patterns),
-        layers=layer_mapping,
-        kernel_events=kernel_events,
-        all_events=all_events,
-        correlation_groups=dict(correlation_groups),
-        phases=pass_result.phases,
-    )

-
-@dataclass
-class StreamingMetadata:
-    """Early metadata available before full trace processing."""
-    platform: str
-    gpu_name: str
-    device_props: dict[str, Any]
-    file_size_mb: float
-
-
-def _extract_metadata_fast(file_path: Path) -> StreamingMetadata:
-    """Extract trace metadata instantly using streaming parser (~2ms).
-
-    Uses ijson to read only the deviceProperties section without
-    parsing the entire file. Falls back to full parse if ijson unavailable.
-    """
-    file_size_mb = file_path.stat().st_size / (1024 * 1024)
-    platform = "Unknown"
-    gpu_name = "Unknown GPU"
-    device_props: dict[str, Any] = {}
-
-    if ijson is None:
-        # Fallback: parse full file (slower but works)
-        with open(file_path, "rb") as f:
-            trace = orjson.loads(f.read())
-        props = trace.get("deviceProperties", [{}])[0]
-        is_amd = trace.get("roctracer_version") or props.get("warpSize") == 64
-        platform = "AMD" if is_amd else "NVIDIA"
-        gpu_name = props.get("name", "MI300X" if is_amd else "Unknown GPU")
-        device_props = {
-            "name": gpu_name,
-            "compute_capability": f"{props.get('computeMajor', 0)}.{props.get('computeMinor', 0)}",
-            "total_memory_gb": props.get("totalGlobalMem", 0) / (1024**3),
-            "sm_count": props.get("numSms", 0),
-            "warp_size": props.get("warpSize", 32),
-            "max_threads_per_block": props.get("maxThreadsPerBlock", 0),
-            "shared_mem_per_block_kb": props.get("sharedMemPerBlock", 0) / 1024,
-        }
-        return StreamingMetadata(
-            platform=platform,
-            gpu_name=gpu_name,
-            device_props=device_props,
-            file_size_mb=file_size_mb,
+        kernel_data.append(
+            {
+                "name": name,
+                "dur_us": dur,
+                "phase": phase,
+                "op": op.value,
+                "pattern": pattern,
+                "layer": layer,
+                "correlation": corr_id,
+                "cpu_op": cpu_op,
+                "python_stack": python_stack,  # Full stack for JSON output
+            }
         )
-
-    with open(file_path, "rb") as f:
-        parser = ijson.parse(f)
-        for prefix, event, value in parser:
-            if prefix == "deviceProperties.item.name":
-                gpu_name = value
-            elif prefix == "deviceProperties.item.warpSize":
-                platform = "AMD" if value == 64 else "NVIDIA"
-            elif prefix == "deviceProperties.item.computeMajor":
-                device_props["compute_major"] = value
-            elif prefix == "deviceProperties.item.computeMinor":
-                device_props["compute_minor"] = value
-            elif prefix == "deviceProperties.item.totalGlobalMem":
-                device_props["total_memory_gb"] = value / (1024**3)
-            elif prefix == "deviceProperties.item.numSms":
-                device_props["sm_count"] = value
-            elif prefix == "deviceProperties.item.maxThreadsPerBlock":
-                device_props["max_threads_per_block"] = value
-            elif prefix == "deviceProperties.item.sharedMemPerBlock":
-                device_props["shared_mem_per_block_kb"] = value / 1024
-            elif prefix == "traceEvents.item":
-                # Hit first event, stop - we only need metadata
-                break
-
-    # Fallback platform detection
-    if platform == "Unknown":
-        platform = "AMD" if "MI" in gpu_name or "Instinct" in gpu_name else "NVIDIA"
-
-    device_props["name"] = gpu_name
-    device_props["compute_capability"] = f"{device_props.get('compute_major', 0)}.{device_props.get('compute_minor', 0)}"
-    device_props["warp_size"] = 64 if platform == "AMD" else 32
-
-    return StreamingMetadata(
-        platform=platform,
-        gpu_name=gpu_name,
-        device_props=device_props,
-        file_size_mb=file_size_mb,
-    )
-

-def load_trace_streaming(
-    file_path: str | Path,
-    include_stacks: bool = True,
-    on_metadata: Callable[[StreamingMetadata], None] | None = None,
-    on_progress: ProgressCallback | None = None,
-) -> tuple[str, str, dict[str, Any], pd.DataFrame, dict[tuple[str, str], set[str]], dict[int, int]]:
-    """Load trace with instant metadata feedback.
-
-    Hybrid approach:
-    1. Phase 1 (~2ms): Extract metadata with ijson streaming
-    2. Phase 2 (full): Parse with orjson and process
-
-    The on_metadata callback fires immediately with GPU/platform info,
-    allowing the UI to show feedback while the full load continues.
-
-    Args:
-        file_path: Path to the trace JSON file
-        include_stacks: Whether to resolve Python call stacks
-        on_metadata: Callback for instant metadata (fires in ~2ms)
-        on_progress: Callback for progress updates during processing
-    """
-    file_path = Path(file_path)
-
-    # Phase 1: Instant metadata extraction (~2ms)
-    metadata = _extract_metadata_fast(file_path)
-
-    if on_metadata:
-        on_metadata(metadata)
-
-    # Phase 2: Full load with orjson (fast) + progress updates
-    return load_trace(
-        file_path,
-        include_stacks=include_stacks,
-        on_progress=on_progress,
-    )
+    return platform, gpu_name, device_props, pd.DataFrame(kernel_data), dict(kernel_patterns), layer_mapping
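
With load_trace_full, load_trace_streaming, and the include_stacks/on_progress parameters removed, 0.1.35 exposes a single entry point that returns a six-tuple. A usage sketch follows; the import path is assumed, since the module's filename is not part of this diff, and the DataFrame columns are the keys written into kernel_data above:

    # from wafer_core.trace_loader import load_trace  # hypothetical import path

    platform, gpu_name, device_props, df, patterns, layer_mapping = load_trace("trace.json")

    print(platform, gpu_name, device_props["sm_count"])

    # One row per GPU kernel:
    # name, dur_us, phase, op, pattern, layer, correlation, cpu_op, python_stack
    per_layer_us = df.groupby("layer")["dur_us"].sum()       # time per transformer layer
    phase_split = df.groupby("phase")["dur_us"].sum()        # prefill vs decode totals
    top_ops = df.groupby("op")["dur_us"].sum().nlargest(5)   # dominant operation categories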