wafer-core 0.1.25__py3-none-any.whl → 0.1.27__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,635 @@
1
+ """Trace loading and parsing logic.
2
+
3
+ Loads JSON trace files from AMD/NVIDIA profilers and extracts kernel execution data,
4
+ Python call stacks, CPU operator mappings, and layer correlations.
5
+ """
6
+
7
+ import bisect
8
+ import sys
9
+ from collections import defaultdict
10
+ from collections.abc import Callable
11
+ from dataclasses import dataclass, field
12
+ from pathlib import Path
13
+ from typing import Any
14
+
15
+ try:
16
+ import ijson
17
+ except ImportError:
18
+ ijson = None
19
+
20
+ import orjson
21
+ import pandas as pd
22
+
23
+ from .architecture import detect_architecture
24
+ from .classifier import classify
25
+ from .layer_segmentation import segment_layers_by_architecture
26
+
27
+ ProgressCallback = Callable[[str, float], None]
28
+
29
+
30
+ @dataclass
31
+ class SinglePassResult:
32
+ """Collected data from single-pass event processing."""
33
+ cpu_op_mapping: dict[int, str] = field(default_factory=dict)
34
+ python_intervals: list[tuple[int, int, int, int | None, str]] = field(default_factory=list)
35
+ # Raw python events for lazy python_by_id construction (built on-demand in _get_python_stack_full)
36
+ python_events_raw: list[dict[str, Any]] = field(default_factory=list)
37
+ phases: list[dict[str, Any]] = field(default_factory=list)
38
+ correlation_groups: dict[int, dict[str, Any]] = field(default_factory=lambda: defaultdict(
39
+ lambda: {"count": 0, "has_attention": False, "has_ffn": False}
40
+ ))
41
+ kernel_events: list[dict[str, Any]] = field(default_factory=list)
42
+ # Lazily built when needed for stack resolution
43
+ _python_by_id: dict[int, dict[str, Any]] | None = field(default=None)
44
+
45
+ @property
46
+ def python_by_id(self) -> dict[int, dict[str, Any]]:
47
+ """Lazily build python_by_id from raw events on first access."""
48
+ if self._python_by_id is None:
49
+ self._python_by_id = {}
50
+ for ev in self.python_events_raw:
51
+ args = ev.get("args")
52
+ py_id = args.get("Python id") if args else None
53
+ if py_id is not None:
54
+ self._python_by_id[py_id] = {
55
+ "name": ev["name"],
56
+ "parent_id": args.get("Python parent id") if args else None,
57
+ }
58
+ return self._python_by_id
59
+
60
+
61
@dataclass
class LoadedTrace:
    """Complete trace data loaded once and reused by all analysis functions."""
    # "AMD" or "NVIDIA" (detected via roctracer_version / warpSize == 64).
    platform: str
    # GPU name from deviceProperties (falls back to "MI300X" / "Unknown GPU").
    gpu_name: str
    # Normalized device property dict (compute capability, memory, SM count, ...).
    device_props: dict[str, Any]
    # One row per kernel: name, dur_us, phase, op, pattern, layer,
    # correlation, cpu_op, python_stack.
    df: pd.DataFrame
    # {(op_value, phase): set of kernel name patterns}.
    patterns: dict[tuple[str, str], set[str]]
    # correlation_id -> layer number.
    layers: dict[int, int]
    # For fusion/warnings (kept from raw JSON)
    kernel_events: list[dict[str, Any]]
    # Every event from traceEvents, unfiltered.
    all_events: list[dict[str, Any]]
    # correlation_id -> kernel events sharing that correlation id.
    correlation_groups: dict[int, list[dict[str, Any]]]
    phases: list[dict[str, Any]]  # Phase annotations for alignment
75
+
76
+
77
def _process_events_single_pass(
    events: list[dict[str, Any]],
    include_stacks: bool = True,
) -> SinglePassResult:
    """Process all events in a single iteration.

    Dispatches on the event category:
      - "kernel": collected verbatim; also folded into per-correlation-id
        stats (count, attention/FFN substring markers) for layer detection.
      - "cpu_op": records External id -> op name for kernel attribution.
      - "python_function": stored raw plus as a (start, end, dur, id, name)
        interval for later binary-search stack resolution.
      - "user_annotation": "execute_context..." annotations become
        prefill/decode phase windows.

    Optimizations applied:
    - Cache list.append methods for 2-3x speedup on hot paths
    - Store raw python events, build python_by_id lazily (only ~48 lookups due to caching)
    - Local variable caching for frequently accessed attributes

    Args:
        events: List of trace events
        include_stacks: Whether to collect Python stack info (expensive operation).
            When False, skips python_function processing entirely.

    Returns:
        SinglePassResult; python_intervals is sorted by start time so
        callers can bisect on it.
    """
    result = SinglePassResult()
    correlation_groups: dict[int, dict[str, Any]] = defaultdict(
        lambda: {"count": 0, "has_attention": False, "has_ffn": False}
    )

    # Cache list.append methods for faster appending (measured 2-3x speedup)
    kernel_append = result.kernel_events.append
    python_interval_append = result.python_intervals.append
    python_raw_append = result.python_events_raw.append
    phase_append = result.phases.append
    cpu_op_mapping = result.cpu_op_mapping

    for ev in events:
        cat = ev.get("cat")

        if cat == "kernel":
            args = ev.get("args")
            corr_id = args.get("correlation") if args else None
            if corr_id is not None:
                kernel_name = ev.get("name", "").lower()
                grp = correlation_groups[corr_id]
                grp["count"] += 1
                # Substring heuristics: attention/FMHA kernels vs GEMM-style
                # kernels ("cijk_" rocBLAS, "nvjet", "wvsplitk", "gemm").
                if "attention" in kernel_name or "fmha" in kernel_name:
                    grp["has_attention"] = True
                if "cijk_" in kernel_name or "nvjet" in kernel_name or "wvsplitk" in kernel_name or "gemm" in kernel_name:
                    grp["has_ffn"] = True
            # Kernels are kept even without a correlation id.
            kernel_append(ev)

        elif cat == "cpu_op":
            args = ev.get("args")
            ext_id = args.get("External id") if args else None
            if ext_id is not None:
                # Later events with the same External id overwrite earlier ones.
                cpu_op_mapping[ext_id] = ev.get("name", "")

        elif cat == "python_function" and include_stacks:
            # Store raw event for lazy python_by_id construction
            python_raw_append(ev)
            # Build interval tuple for binary search
            args = ev.get("args")
            py_id = args.get("Python id") if args else None
            ts = ev["ts"]
            dur = ev.get("dur", 0)
            python_interval_append((ts, ts + dur, dur, py_id, ev["name"]))

        elif cat == "user_annotation":
            name = ev.get("name", "")
            if name.startswith("execute_context"):
                # Parse a "(N)" token count from the part right after
                # "context" — assumes annotation names like
                # "execute_context_(N)_..."; confirm against profiler output.
                tokens = 0
                parts = name.split("_")
                for i, p in enumerate(parts):
                    if i > 0 and parts[i-1] == "context" and "(" in p and ")" in p:
                        try:
                            tokens = int(p.split("(")[1].split(")")[0])
                            break
                        except Exception:
                            # Malformed annotation: keep tokens == 0 (-> decode).
                            pass
                # A generation_0 annotation with non-zero tokens is treated
                # as the prefill pass; everything else is decode.
                is_prefill = "generation_0" in name and tokens > 0
                phase_append({
                    "type": "prefill" if is_prefill else "decode",
                    "ts_start": ev["ts"],
                    "ts_end": ev["ts"] + ev["dur"],
                })

    # Sort once so stack resolution can bisect by start timestamp.
    if include_stacks and result.python_intervals:
        result.python_intervals.sort()

    # Freeze to a plain dict so lookups downstream don't silently create
    # empty groups the way a defaultdict would.
    result.correlation_groups = dict(correlation_groups)

    return result
162
+
163
+
164
+ def _build_layer_mapping(correlation_groups: dict[int, dict[str, Any]]) -> dict[int, int]:
165
+ """Build layer mapping from correlation groups."""
166
+ correlation_to_layer = {}
167
+ layer_num = 0
168
+
169
+ for corr_id in sorted(correlation_groups.keys()):
170
+ group = correlation_groups[corr_id]
171
+
172
+ is_layer = (
173
+ group["count"] >= 70 and group["has_attention"] and group["has_ffn"]
174
+ )
175
+
176
+ if is_layer:
177
+ correlation_to_layer[corr_id] = layer_num
178
+ layer_num += 1
179
+
180
+ return correlation_to_layer
181
+
182
+
183
+ def _get_python_stack_full(
184
+ timestamp: int,
185
+ python_intervals: list[tuple[int, int, int, int | None, str]],
186
+ python_by_id: dict[int, dict[str, Any]],
187
+ ) -> tuple[str | None, list[str]]:
188
+ """Get full Python call stack for a kernel launch.
189
+
190
+ Computes the chain on-demand by walking parent pointers.
191
+ This is fast because we only call this ~48 times due to cpu_op caching.
192
+ """
193
+ idx = bisect.bisect_right(
194
+ python_intervals, (timestamp, float("inf"), float("inf"), None, "")
195
+ )
196
+
197
+ active_funcs = []
198
+ for i in range(idx - 1, max(0, idx - 1000), -1):
199
+ ts_start, ts_end, duration, py_id, name = python_intervals[i]
200
+ if ts_start <= timestamp <= ts_end:
201
+ active_funcs.append((duration, py_id, name))
202
+ if ts_end < timestamp - 1000000:
203
+ break
204
+
205
+ if not active_funcs:
206
+ return None, []
207
+
208
+ active_funcs.sort()
209
+ leaf_duration, leaf_id, leaf_name = active_funcs[0]
210
+
211
+ full_stack = []
212
+ current_id = leaf_id
213
+ visited: set[int] = set()
214
+
215
+ while current_id is not None and current_id not in visited and current_id in python_by_id:
216
+ func = python_by_id[current_id]
217
+ full_stack.append(func["name"])
218
+ visited.add(current_id)
219
+ current_id = func["parent_id"]
220
+ if len(full_stack) >= 50:
221
+ break
222
+
223
+ full_stack.reverse()
224
+
225
+ summary = None
226
+ vllm_funcs = [f for f in full_stack if any(x in f.lower() for x in ["vllm/", "model", "<eval_with_key>"])]
227
+
228
+ if vllm_funcs:
229
+ summary = vllm_funcs[-1]
230
+ if any("torch/cuda/graphs" in f for f in full_stack):
231
+ if len(summary) > 45:
232
+ summary = "vllm/..." + summary.split("/")[-1]
233
+ summary = f"{summary} [CUDA graph]"
234
+ elif len(summary) > 53:
235
+ summary = "vllm/..." + summary.split("/")[-1]
236
+ else:
237
+ summary = leaf_name
238
+
239
+ return summary, full_stack
240
+
241
+
242
def load_trace(
    file_path: str | Path,
    include_stacks: bool = True,
    on_progress: ProgressCallback | None = None,
) -> tuple[str, str, dict[str, Any], pd.DataFrame, dict[tuple[str, str], set[str]], dict[int, int]]:
    """Load trace and return platform info, device properties, kernels, patterns, and layer mapping.

    Args:
        file_path: Path to the trace JSON file
        include_stacks: Whether to resolve Python call stacks (slower but more info)
        on_progress: Optional callback for progress updates: (stage_name, progress_fraction)

    Returns:
        (platform, gpu_name, device_props, kernel DataFrame,
        {(op_value, phase): pattern set}, {correlation_id: layer_number}).
    """
    def _progress(stage: str, pct: float) -> None:
        # No-op unless the caller supplied a progress callback.
        if on_progress:
            on_progress(stage, pct)

    _progress("Reading file", 0.0)
    with open(file_path, "rb") as f:
        raw = f.read()

    _progress("Parsing JSON", 0.1)
    trace = orjson.loads(raw)

    # AMD traces carry a roctracer_version key; warpSize 64 is AMD-specific.
    props = trace.get("deviceProperties", [{}])[0]
    is_amd = trace.get("roctracer_version") or props.get("warpSize") == 64
    platform = "AMD" if is_amd else "NVIDIA"
    gpu_name = props.get("name", "MI300X" if is_amd else "Unknown GPU")

    # Normalize the raw profiler fields into a stable property dict.
    device_props = {
        "name": gpu_name,
        "compute_capability": f"{props.get('computeMajor', 0)}.{props.get('computeMinor', 0)}",
        "total_memory_gb": props.get("totalGlobalMem", 0) / (1024**3),
        "sm_count": props.get("numSms", 0),
        "warp_size": props.get("warpSize", 32),
        "max_threads_per_block": props.get("maxThreadsPerBlock", 0),
        "shared_mem_per_block_kb": props.get("sharedMemPerBlock", 0) / 1024,
    }

    _progress("Processing events", 0.4)
    events = trace.get("traceEvents", [])
    pass_result = _process_events_single_pass(events, include_stacks=include_stacks)

    _progress("Detecting architecture", 0.6)
    kernel_names = [ev.get("name", "") for ev in pass_result.kernel_events]
    architecture, _ = detect_architecture(kernel_names)

    # NOTE(review): layer_warnings is collected but not surfaced here.
    layer_kernels_dict, layer_warnings = segment_layers_by_architecture(
        pass_result.kernel_events,
        architecture,
    )

    # Convert layer_kernels_dict (layer_num -> kernels) to correlation_id -> layer_num mapping
    # Fall back to old method if architecture-based segmentation fails
    if layer_kernels_dict:
        layer_mapping: dict[int, int] = {}
        for layer_num, kernels in layer_kernels_dict.items():
            for kernel in kernels:
                corr_id = kernel.get("args", {}).get("correlation")
                if corr_id is not None:
                    layer_mapping[corr_id] = layer_num
    else:
        # Fallback to correlation-based method if architecture segmentation failed
        layer_mapping = _build_layer_mapping(pass_result.correlation_groups)

    kernel_data = []
    kernel_patterns: dict[tuple[str, str], set[str]] = defaultdict(set)
    sorted_phases = sorted(pass_result.phases, key=lambda p: p["ts_start"])

    # Parallel arrays for binary-searching the phase containing a timestamp.
    phase_starts = [p["ts_start"] for p in sorted_phases]
    phase_types = [p["type"] for p in sorted_phases]
    phase_ends = [p["ts_end"] for p in sorted_phases]

    def _get_phase_for_timestamp(ts: int) -> str:
        """Get phase for a timestamp using binary search. O(log n).

        Falls back to "decode" when no phase window contains the timestamp.
        """
        if not phase_starts:
            return "decode"
        idx = bisect.bisect_right(phase_starts, ts) - 1
        if idx >= 0 and phase_starts[idx] <= ts <= phase_ends[idx]:
            return phase_types[idx]
        return "decode"

    # Kernel-name -> resolved cpu_op summary. Stack resolution is expensive
    # and same-named kernels are assumed to resolve to the same summary.
    cpu_op_cache: dict[str, str | None] = {}

    _progress("Classifying kernels", 0.7)
    for ev in pass_result.kernel_events:
        name_raw = ev["name"]
        # Intern: many kernels share names, so comparisons/dict lookups hit
        # the identity fast path and duplicate strings share memory.
        name = sys.intern(name_raw)
        dur, ts = ev.get("dur", 0), ev["ts"]
        corr_id = ev.get("args", {}).get("correlation")
        ext_id = ev.get("args", {}).get("External id")

        phase = _get_phase_for_timestamp(ts)

        op, pattern = classify(name, platform)
        kernel_patterns[(op.value, phase)].add(pattern)

        layer = layer_mapping.get(corr_id) if corr_id is not None else None
        cpu_op = pass_result.cpu_op_mapping.get(ext_id) if ext_id is not None else None
        python_stack: list[str] = []

        # When External-id attribution misses, fall back to (cached) Python
        # stack resolution. NOTE(review): cache hits leave python_stack
        # empty — only the first kernel with a given name gets a full stack.
        if cpu_op is None and include_stacks:
            if name in cpu_op_cache:
                cpu_op = cpu_op_cache[name]
            else:
                cpu_op, python_stack = _get_python_stack_full(
                    ts, pass_result.python_intervals, pass_result.python_by_id
                )
                cpu_op_cache[name] = cpu_op

        kernel_data.append({
            "name": name,
            # "dur" is copied through as "dur_us" — assumes the profiler
            # emits durations in microseconds; confirm for new trace sources.
            "dur_us": dur,
            "phase": phase,
            "op": op.value,
            "pattern": pattern,
            "layer": layer,
            "correlation": corr_id,
            "cpu_op": cpu_op,
            "python_stack": python_stack,
        })

    _progress("Building DataFrame", 0.95)
    df = pd.DataFrame(kernel_data)
    _progress("Complete", 1.0)

    return platform, gpu_name, device_props, df, dict(kernel_patterns), layer_mapping
368
+
369
+
370
def load_trace_full(
    file_path: str | Path,
    include_stacks: bool = True,
    on_progress: ProgressCallback | None = None,
) -> LoadedTrace:
    """Load trace once with all data needed by downstream analysis functions.

    Superset of load_trace(): additionally retains the raw event list,
    per-correlation kernel groupings, and phase windows in the returned
    LoadedTrace. NOTE(review): the body deliberately mirrors load_trace()
    step for step — keep the two in sync when changing either.

    Args:
        file_path: Path to the trace JSON file
        include_stacks: Whether to resolve Python call stacks
        on_progress: Optional callback for progress updates

    Returns:
        LoadedTrace with all trace data
    """
    def _progress(stage: str, pct: float) -> None:
        # No-op unless the caller supplied a progress callback.
        if on_progress:
            on_progress(stage, pct)

    _progress("Reading file", 0.0)
    with open(file_path, "rb") as f:
        raw = f.read()

    _progress("Parsing JSON", 0.1)
    trace = orjson.loads(raw)
    all_events = trace.get("traceEvents", [])

    # AMD traces carry a roctracer_version key; warpSize 64 is AMD-specific.
    props = trace.get("deviceProperties", [{}])[0]
    is_amd = trace.get("roctracer_version") or props.get("warpSize") == 64
    platform = "AMD" if is_amd else "NVIDIA"
    gpu_name = props.get("name", "MI300X" if is_amd else "Unknown GPU")

    # Normalize the raw profiler fields into a stable property dict.
    device_props = {
        "name": gpu_name,
        "compute_capability": f"{props.get('computeMajor', 0)}.{props.get('computeMinor', 0)}",
        "total_memory_gb": props.get("totalGlobalMem", 0) / (1024**3),
        "sm_count": props.get("numSms", 0),
        "warp_size": props.get("warpSize", 32),
        "max_threads_per_block": props.get("maxThreadsPerBlock", 0),
        "shared_mem_per_block_kb": props.get("sharedMemPerBlock", 0) / 1024,
    }

    _progress("Processing events", 0.4)
    pass_result = _process_events_single_pass(all_events, include_stacks=include_stacks)

    _progress("Detecting architecture", 0.6)
    kernel_names = [ev.get("name", "") for ev in pass_result.kernel_events]
    architecture, _ = detect_architecture(kernel_names)

    # NOTE(review): layer_warnings is collected but not surfaced here.
    layer_kernels_dict, layer_warnings = segment_layers_by_architecture(
        pass_result.kernel_events,
        architecture,
    )

    # Convert layer_kernels_dict (layer_num -> kernels) to correlation_id -> layer_num mapping
    if layer_kernels_dict:
        layer_mapping: dict[int, int] = {}
        for layer_num, kernels in layer_kernels_dict.items():
            for kernel in kernels:
                corr_id = kernel.get("args", {}).get("correlation")
                if corr_id is not None:
                    layer_mapping[corr_id] = layer_num
    else:
        # Fallback to correlation-based method if segmentation found nothing.
        layer_mapping = _build_layer_mapping(pass_result.correlation_groups)

    kernel_data = []
    kernel_patterns: dict[tuple[str, str], set[str]] = defaultdict(set)
    sorted_phases = sorted(pass_result.phases, key=lambda p: p["ts_start"])

    # Parallel arrays for binary-searching the phase containing a timestamp.
    phase_starts = [p["ts_start"] for p in sorted_phases]
    phase_types = [p["type"] for p in sorted_phases]
    phase_ends = [p["ts_end"] for p in sorted_phases]

    def _get_phase_for_timestamp(ts: int) -> str:
        """Get phase for a timestamp using binary search. O(log n).

        Falls back to "decode" when no phase window contains the timestamp.
        """
        if not phase_starts:
            return "decode"
        idx = bisect.bisect_right(phase_starts, ts) - 1
        if idx >= 0 and phase_starts[idx] <= ts <= phase_ends[idx]:
            return phase_types[idx]
        return "decode"

    # Kernel-name -> resolved cpu_op summary. Stack resolution is expensive
    # and same-named kernels are assumed to resolve to the same summary.
    cpu_op_cache: dict[str, str | None] = {}

    _progress("Classifying kernels", 0.7)
    for ev in pass_result.kernel_events:
        name_raw = ev["name"]
        # Intern: many kernels share names, so comparisons/dict lookups hit
        # the identity fast path and duplicate strings share memory.
        name = sys.intern(name_raw)
        dur, ts = ev.get("dur", 0), ev["ts"]
        corr_id = ev.get("args", {}).get("correlation")
        ext_id = ev.get("args", {}).get("External id")

        phase = _get_phase_for_timestamp(ts)

        op, pattern = classify(name, platform)
        kernel_patterns[(op.value, phase)].add(pattern)

        layer = layer_mapping.get(corr_id) if corr_id is not None else None
        cpu_op = pass_result.cpu_op_mapping.get(ext_id) if ext_id is not None else None
        python_stack: list[str] = []

        # When External-id attribution misses, fall back to (cached) Python
        # stack resolution. NOTE(review): cache hits leave python_stack
        # empty — only the first kernel with a given name gets a full stack.
        if cpu_op is None and include_stacks:
            if name in cpu_op_cache:
                cpu_op = cpu_op_cache[name]
            else:
                cpu_op, python_stack = _get_python_stack_full(
                    ts, pass_result.python_intervals, pass_result.python_by_id
                )
                cpu_op_cache[name] = cpu_op

        kernel_data.append({
            "name": name,
            # "dur" is copied through as "dur_us" — assumes the profiler
            # emits durations in microseconds; confirm for new trace sources.
            "dur_us": dur,
            "phase": phase,
            "op": op.value,
            "pattern": pattern,
            "layer": layer,
            "correlation": corr_id,
            "cpu_op": cpu_op,
            "python_stack": python_stack,
        })

    _progress("Building DataFrame", 0.95)
    df = pd.DataFrame(kernel_data)

    # Group kernels by correlation id for downstream fusion/warning analysis.
    kernel_events = pass_result.kernel_events
    correlation_groups: dict[int, list[dict[str, Any]]] = defaultdict(list)
    for ev in kernel_events:
        corr_id = ev.get("args", {}).get("correlation")
        if corr_id is not None:
            correlation_groups[corr_id].append(ev)

    _progress("Complete", 1.0)

    return LoadedTrace(
        platform=platform,
        gpu_name=gpu_name,
        device_props=device_props,
        df=df,
        patterns=dict(kernel_patterns),
        layers=layer_mapping,
        kernel_events=kernel_events,
        all_events=all_events,
        correlation_groups=dict(correlation_groups),
        phases=pass_result.phases,
    )
516
+
517
+
518
@dataclass
class StreamingMetadata:
    """Early metadata available before full trace processing."""
    # "AMD", "NVIDIA", or "Unknown" when detection could not decide.
    platform: str
    # Device name from deviceProperties; "Unknown GPU" when absent.
    gpu_name: str
    # Partially normalized device property dict; the ijson fast path may
    # omit fields the full parse would default (see _extract_metadata_fast).
    device_props: dict[str, Any]
    # Size of the trace file on disk, in MiB.
    file_size_mb: float
525
+
526
+
527
def _extract_metadata_fast(file_path: Path) -> StreamingMetadata:
    """Extract trace metadata instantly using streaming parser (~2ms).

    Uses ijson to read only the deviceProperties section without
    parsing the entire file. Falls back to full parse if ijson unavailable.

    Args:
        file_path: Path to the trace JSON file.

    Returns:
        StreamingMetadata with platform/GPU info and file size in MiB.
        NOTE(review): the ijson path only populates device_props keys whose
        source fields appear before the first trace event, so e.g.
        total_memory_gb/sm_count may be missing — unlike the orjson
        fallback, which defaults them to 0.
    """
    file_size_mb = file_path.stat().st_size / (1024 * 1024)
    platform = "Unknown"
    gpu_name = "Unknown GPU"
    device_props: dict[str, Any] = {}

    if ijson is None:
        # Fallback: parse full file (slower but works)
        with open(file_path, "rb") as f:
            trace = orjson.loads(f.read())
            # AMD traces carry roctracer_version; warpSize 64 is AMD-specific.
            props = trace.get("deviceProperties", [{}])[0]
            is_amd = trace.get("roctracer_version") or props.get("warpSize") == 64
            platform = "AMD" if is_amd else "NVIDIA"
            gpu_name = props.get("name", "MI300X" if is_amd else "Unknown GPU")
            device_props = {
                "name": gpu_name,
                "compute_capability": f"{props.get('computeMajor', 0)}.{props.get('computeMinor', 0)}",
                "total_memory_gb": props.get("totalGlobalMem", 0) / (1024**3),
                "sm_count": props.get("numSms", 0),
                "warp_size": props.get("warpSize", 32),
                "max_threads_per_block": props.get("maxThreadsPerBlock", 0),
                "shared_mem_per_block_kb": props.get("sharedMemPerBlock", 0) / 1024,
            }
            return StreamingMetadata(
                platform=platform,
                gpu_name=gpu_name,
                device_props=device_props,
                file_size_mb=file_size_mb,
            )

    # Streaming path: walk the ijson token stream and pick off the
    # deviceProperties fields we care about, stopping at the first event.
    with open(file_path, "rb") as f:
        parser = ijson.parse(f)
        for prefix, event, value in parser:
            if prefix == "deviceProperties.item.name":
                gpu_name = value
            elif prefix == "deviceProperties.item.warpSize":
                # warpSize 64 implies AMD; 32 implies NVIDIA.
                platform = "AMD" if value == 64 else "NVIDIA"
            elif prefix == "deviceProperties.item.computeMajor":
                device_props["compute_major"] = value
            elif prefix == "deviceProperties.item.computeMinor":
                device_props["compute_minor"] = value
            elif prefix == "deviceProperties.item.totalGlobalMem":
                device_props["total_memory_gb"] = value / (1024**3)
            elif prefix == "deviceProperties.item.numSms":
                device_props["sm_count"] = value
            elif prefix == "deviceProperties.item.maxThreadsPerBlock":
                device_props["max_threads_per_block"] = value
            elif prefix == "deviceProperties.item.sharedMemPerBlock":
                device_props["shared_mem_per_block_kb"] = value / 1024
            elif prefix == "traceEvents.item":
                # Hit first event, stop - we only need metadata
                break

    # Fallback platform detection
    if platform == "Unknown":
        # Heuristic on the GPU name (AMD Instinct / MI-series) when no
        # warpSize field was seen before the first event.
        platform = "AMD" if "MI" in gpu_name or "Instinct" in gpu_name else "NVIDIA"

    device_props["name"] = gpu_name
    device_props["compute_capability"] = f"{device_props.get('compute_major', 0)}.{device_props.get('compute_minor', 0)}"
    device_props["warp_size"] = 64 if platform == "AMD" else 32

    return StreamingMetadata(
        platform=platform,
        gpu_name=gpu_name,
        device_props=device_props,
        file_size_mb=file_size_mb,
    )
599
+
600
+
601
def load_trace_streaming(
    file_path: str | Path,
    include_stacks: bool = True,
    on_metadata: Callable[[StreamingMetadata], None] | None = None,
    on_progress: ProgressCallback | None = None,
) -> tuple[str, str, dict[str, Any], pd.DataFrame, dict[tuple[str, str], set[str]], dict[int, int]]:
    """Load a trace while surfacing device metadata as early as possible.

    Two phases:
      1. A streaming header pass (~2ms via ijson) so `on_metadata` can fire
         with GPU/platform info before the expensive work starts — the UI
         gets immediate feedback while the full load continues.
      2. The normal full load via load_trace(), forwarding `on_progress`.

    Args:
        file_path: Path to the trace JSON file
        include_stacks: Whether to resolve Python call stacks
        on_metadata: Callback for instant metadata (fires in ~2ms)
        on_progress: Callback for progress updates during processing

    Returns:
        Same tuple as load_trace().
    """
    trace_path = Path(file_path)

    # Phase 1: cheap header-only metadata extraction.
    early_metadata = _extract_metadata_fast(trace_path)
    if on_metadata:
        on_metadata(early_metadata)

    # Phase 2: full orjson-based load with progress reporting.
    return load_trace(
        trace_path,
        include_stacks=include_stacks,
        on_progress=on_progress,
    )