pytscope 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41) hide show
  1. pytscope/__init__.py +57 -0
  2. pytscope/analyzers/__init__.py +15 -0
  3. pytscope/analyzers/convergence.py +77 -0
  4. pytscope/analyzers/distributed.py +230 -0
  5. pytscope/analyzers/efficiency.py +149 -0
  6. pytscope/analyzers/memory.py +58 -0
  7. pytscope/analyzers/pipeline.py +149 -0
  8. pytscope/analyzers/repro.py +205 -0
  9. pytscope/analyzers/stats.py +110 -0
  10. pytscope/analyzers/timing.py +109 -0
  11. pytscope/analyzers/trace.py +226 -0
  12. pytscope/auto.py +290 -0
  13. pytscope/cli.py +203 -0
  14. pytscope/collectors/__init__.py +3 -0
  15. pytscope/collectors/memory.py +40 -0
  16. pytscope/core/__init__.py +27 -0
  17. pytscope/core/distributed.py +49 -0
  18. pytscope/core/events.py +76 -0
  19. pytscope/core/provenance.py +47 -0
  20. pytscope/core/store.py +101 -0
  21. pytscope/diagnosis/__init__.py +12 -0
  22. pytscope/diagnosis/engine.py +70 -0
  23. pytscope/diagnosis/rules.py +129 -0
  24. pytscope/diagnosis/rules_convergence.py +54 -0
  25. pytscope/diagnosis/rules_cross.py +116 -0
  26. pytscope/diagnosis/rules_distributed.py +184 -0
  27. pytscope/diagnosis/rules_efficiency.py +65 -0
  28. pytscope/diagnosis/rules_memory.py +71 -0
  29. pytscope/hardware.py +93 -0
  30. pytscope/integrations/__init__.py +7 -0
  31. pytscope/integrations/huggingface.py +45 -0
  32. pytscope/integrations/lightning.py +77 -0
  33. pytscope/profiler.py +252 -0
  34. pytscope/py.typed +0 -0
  35. pytscope/report/__init__.py +3 -0
  36. pytscope/report/cli_report.py +394 -0
  37. pytscope-0.2.1.dist-info/METADATA +367 -0
  38. pytscope-0.2.1.dist-info/RECORD +41 -0
  39. pytscope-0.2.1.dist-info/WHEEL +4 -0
  40. pytscope-0.2.1.dist-info/entry_points.txt +2 -0
  41. pytscope-0.2.1.dist-info/licenses/LICENSE +21 -0
pytscope/__init__.py ADDED
@@ -0,0 +1,57 @@
1
+ """pytscope — an intelligence layer for ML training.
2
+
3
+ One telemetry backbone (aligned per-step records) feeds pluggable analyzers
4
+ (timing, memory, convergence, distributed, efficiency, and more) and a diagnosis engine
5
+ that turns raw numbers into ranked, actionable findings.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from .analyzers.convergence import ConvergenceSummary, analyze_convergence
11
+ from .analyzers.distributed import (
12
+ DistributedSummary,
13
+ analyze_distributed,
14
+ load_multirank,
15
+ )
16
+ from .analyzers.efficiency import EfficiencyBudget, analyze_efficiency
17
+ from .analyzers.memory import MemorySummary, analyze_memory
18
+ from .analyzers.pipeline import PipelineSummary, analyze_pipeline
19
+ from .analyzers.timing import TimingSummary, analyze_timing
20
+ from .analyzers.trace import TraceSummary, analyze_trace, analyze_trace_file
21
+ from .auto import AutoProfiler
22
+ from .core.events import StepRecord
23
+ from .core.store import RunStore
24
+ from .diagnosis.engine import DiagnosisContext, Finding, run_diagnosis
25
+ from .hardware import measure_flops, peak_flops_for
26
+ from .profiler import Profiler
27
+
28
# Package version string; keep in sync with the wheel metadata.
__version__ = "0.2.1"

# Public API re-exported at package level (also governs ``from pytscope import *``).
__all__ = [
    # Entry points / storage.
    "Profiler",
    "AutoProfiler",
    "RunStore",
    "StepRecord",
    # Analyzers and their summary types.
    "analyze_timing",
    "TimingSummary",
    "analyze_memory",
    "MemorySummary",
    "analyze_convergence",
    "ConvergenceSummary",
    "analyze_distributed",
    "DistributedSummary",
    "load_multirank",
    "analyze_pipeline",
    "PipelineSummary",
    "analyze_trace",
    "analyze_trace_file",
    "TraceSummary",
    "analyze_efficiency",
    "EfficiencyBudget",
    # Hardware helpers.
    "measure_flops",
    "peak_flops_for",
    # Diagnosis engine.
    "run_diagnosis",
    "DiagnosisContext",
    "Finding",
    "__version__",
]
@@ -0,0 +1,15 @@
1
+ from .convergence import ConvergenceSummary, analyze_convergence
2
+ from .memory import MemorySummary, analyze_memory
3
+ from .repro import RunDiff, diff_runs
4
+ from .timing import TimingSummary, analyze_timing
5
+
6
# Names re-exported by ``pytscope.analyzers`` (also governs ``import *``).
__all__ = [
    "analyze_timing",
    "TimingSummary",
    "analyze_memory",
    "MemorySummary",
    "analyze_convergence",
    "ConvergenceSummary",
    "diff_runs",
    "RunDiff",
]
+ ]
@@ -0,0 +1,77 @@
1
+ """Convergence analyzer (vertical #3) — reads per-step ``scalars``.
2
+
3
+ Consumes loss / grad_norm / lr that the user logs via ``prof.log(...)`` (or the
4
+ Lightning callback's loss). Pure functions over the timeline.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import math
10
+ from dataclasses import dataclass, field
11
+
12
+ from ..core.events import StepRecord
13
+ from .stats import local_spikes, median
14
+
15
+
16
@dataclass
class ConvergenceSummary:
    """Result of ``analyze_convergence``; every field has a default."""

    has_loss: bool = False  # at least one step logged a "loss" scalar
    has_grad_norm: bool = False  # at least one step logged a "grad_norm" scalar
    n_steps: int = 0  # number of records passed in (with or without loss)
    final_loss: float | None = None  # last finite loss value
    best_loss: float | None = None  # minimum finite loss value
    loss_trend: str = "unknown"  # improving | plateau | worsening | diverged | unknown
    diverged_at: int | None = None  # ``step`` number of the first non-finite loss
    loss_spikes: list[int] = field(default_factory=list)  # ``step`` numbers
    grad_norm_spikes: list[int] = field(default_factory=list)  # ``step`` numbers
27
+
28
+
29
def _trend(losses: list[float], tol: float = 0.02) -> str:
    """Classify the overall direction of a loss curve.

    The median of the first ~10% of values (at least one) is compared with the
    median of the last ~10%. The relative change, scaled by ``abs(head)`` (or
    by 1.0 when the head median is exactly zero), is checked against ``tol``.

    Returns "unknown" for fewer than four values, otherwise "improving",
    "worsening" or "plateau".
    """
    count = len(losses)
    if count < 4:
        return "unknown"

    window = count // 10 or 1
    head = median(losses[:window])
    tail = median(losses[-window:])

    scale = 1.0 if head == 0 else abs(head)
    change = (tail - head) / scale

    if change < -tol:
        return "improving"
    return "worsening" if change > tol else "plateau"
43
+
44
+
45
def analyze_convergence(steps: list[StepRecord]) -> ConvergenceSummary:
    """Summarize loss / grad-norm behaviour from the logged per-step scalars.

    Only records whose ``scalars`` contain ``"loss"`` feed the loss statistics.
    ``"grad_norm"`` is optional; records without it contribute ``None`` to the
    grad series. Spike positions returned by ``local_spikes`` are positions
    within the series handed to it. They are translated back to each record's
    ``step`` number.
    """
    logged = [(rec.step, rec.scalars["loss"]) for rec in steps if "loss" in rec.scalars]
    grad_values = [rec.scalars.get("grad_norm") for rec in steps]
    grad_present = any(value is not None for value in grad_values)

    if not logged:
        return ConvergenceSummary(
            has_loss=False, has_grad_norm=grad_present, n_steps=len(steps)
        )

    loss_steps = [step_no for step_no, _ in logged]
    loss_values = [value for _, value in logged]

    # First step whose loss is NaN/inf, if any.
    diverged_at = None
    for step_no, value in logged:
        if not math.isfinite(value):
            diverged_at = step_no
            break
    finite_losses = [value for value in loss_values if math.isfinite(value)]

    # Loss positions index ``loss_steps``; grad positions index ``steps``.
    loss_spikes = [loss_steps[pos] for pos in sorted(local_spikes(loss_values))]
    if grad_present:
        grad_positions = sorted(local_spikes(grad_values))
    else:
        grad_positions = []
    grad_spikes = [steps[pos].step for pos in grad_positions]

    if diverged_at is None:
        trend = _trend(finite_losses)
    else:
        trend = "diverged"

    return ConvergenceSummary(
        has_loss=True,
        has_grad_norm=grad_present,
        n_steps=len(steps),
        final_loss=finite_losses[-1] if finite_losses else None,
        best_loss=min(finite_losses) if finite_losses else None,
        loss_trend=trend,
        diverged_at=diverged_at,
        loss_spikes=loss_spikes,
        grad_norm_spikes=grad_spikes,
    )
@@ -0,0 +1,230 @@
1
+ """Distributed (data-parallel) analyzer — the headline vertical.
2
+
3
+ In synchronous data-parallel training every rank must reach the gradient
4
+ all-reduce before *any* rank can proceed. The step is therefore gated by the
5
+ **slowest** rank's compute (the critical path); every faster rank sits idle at
6
+ the barrier. That idle time is pure waste, and it is invisible to any
7
+ single-rank profiler — you only see it by putting all ranks on one timeline.
8
+
9
+ This module aligns the per-rank step timelines and computes:
10
+
11
+ - **Critical-path wall loss** — wall time lost because the step waits for the
12
+ slowest rank instead of the average rank.
13
+ - **Straggler attribution with a persistence test** — *which* rank is slow and,
14
+ crucially, whether it is a *consistent* straggler (a bad GPU/node) or just
15
+ noise. We test the count of steps a rank is the critical path against the
16
+ null hypothesis that the slowest rank is uniformly random (Binomial(S, 1/N)),
17
+ via a one-sided normal approximation z-score.
18
+ - **Load imbalance** — robust coefficient of variation of per-rank compute.
19
+ - **Communication fraction** — share of step time spent in collectives.
20
+ - **Sync skew** — how far ahead the fastest ranks arrive at the barrier.
21
+
22
+ Everything is pure-stdlib and numerically careful (``math.fsum``, robust
23
+ medians from :mod:`pytscope.analyzers.stats`).
24
+ """
25
+
26
+ from __future__ import annotations
27
+
28
+ import math
29
+ from dataclasses import dataclass, field
30
+ from pathlib import Path
31
+
32
+ from ..core.events import COMM, COMPUTE_PHASES, StepRecord
33
+ from ..core.store import RunStore
34
+ from .stats import median
35
+
36
+
37
@dataclass
class RankStat:
    """Per-rank statistics produced by ``analyze_distributed``."""

    rank: int
    mean_compute: float  # mean per-step local compute (non-comm), seconds
    median_compute: float  # median per-step local compute, seconds
    slowest_count: int  # steps where this rank was the critical path (ties -> lowest rank id)
    slowest_fraction: float  # slowest_count / number of aligned steps
    straggler_z: float  # z-score of slowest_count vs Binomial(S, 1/N)
    rel_slowdown: float  # median(this_rank_compute / step_median_compute) - 1
46
+
47
+
48
@dataclass
class DistributedSummary:
    """Aggregate result of ``analyze_distributed`` over the aligned steps."""

    world_size: int  # number of ranks analyzed
    n_steps: int  # number of aligned steps analyzed
    mean_step_wall: float  # mean critical-path step time (max compute + max comm)
    mean_comm_fraction: float  # mean over steps of (max comm / step wall)
    imbalance_cv: float  # median over steps of CV(per-rank compute)
    sync_skew: float  # median over steps of (max - median) compute, seconds
    wall_frac_lost_to_imbalance: float  # sum(max - mean) / sum(max) of per-step compute
    ranks: list[RankStat] = field(default_factory=list)  # one entry per rank, by rank id
    straggler: RankStat | None = None  # the single persistent straggler, if exactly one

    @property
    def has_straggler(self) -> bool:
        """True when a single persistent straggler was identified."""
        return self.straggler is not None
63
+
64
+
65
+ # --- loading -------------------------------------------------------------
66
+
67
+
68
def _rank_id(path: Path) -> int | None:
    """Return ``k`` if ``path`` is a directory named ``rank{k}``, else ``None``.

    The suffix must consist of decimal digits (``str.isdecimal``). ``isdigit``
    also accepts characters such as superscript digits that ``int()`` rejects,
    which would turn a stray directory name into a ``ValueError``.
    """
    if not path.is_dir():
        return None
    name = path.name
    if not name.startswith("rank"):
        return None
    suffix = name[len("rank") :]
    return int(suffix) if suffix.isdecimal() else None


def load_multirank(run_dir) -> dict[int, RunStore]:
    """Load a distributed run written by ``Profiler(distributed=True)``.

    Layout is ``run_dir/rank{k}/{steps.jsonl,run.json}``. Returns a mapping of
    rank -> RunStore, inserted in ascending numeric rank order. Returns ``{}``
    if ``run_dir`` is not a directory or has no per-rank subdirs.
    """
    run_dir = Path(run_dir)
    if not run_dir.is_dir():
        return {}
    found: list[tuple[int, Path]] = []
    for child in run_dir.iterdir():
        rank = _rank_id(child)
        if rank is not None:
            found.append((rank, child))
    # Numeric sort puts ``rank10`` after ``rank2``. When two directories map to
    # the same id (``rank1`` / ``rank01``), the path that sorts later wins.
    return {rank: RunStore.load(child) for rank, child in sorted(found)}


def is_multirank(run_dir) -> bool:
    """Return True if ``run_dir`` contains at least one ``rank{k}`` subdirectory.

    Uses the same name rule as ``load_multirank``, so a True result means
    loading will find at least one rank.
    """
    run_dir = Path(run_dir)
    if not run_dir.is_dir():
        return False
    return any(_rank_id(c) is not None for c in run_dir.iterdir())
+ )
94
+
95
+
96
+ # --- helpers -------------------------------------------------------------
97
+
98
+
99
def _compute_seconds(rec: StepRecord) -> float:
    """Sum (with ``math.fsum``) the step's phase times whose name is in ``COMPUTE_PHASES``."""
    compute_parts = [
        seconds for phase_name, seconds in rec.phases.items() if phase_name in COMPUTE_PHASES
    ]
    return math.fsum(compute_parts)
102
+
103
+
104
def _comm_seconds(rec: StepRecord) -> float:
    """Seconds recorded under the ``COMM`` phase for this step (0.0 if absent)."""
    phases = rec.phases
    return float(phases[COMM]) if COMM in phases else 0.0
106
+
107
+
108
+ def _mean(xs: list[float]) -> float:
109
+ return math.fsum(xs) / len(xs) if xs else 0.0
110
+
111
+
112
+ def _cv(xs: list[float]) -> float:
113
+ """Coefficient of variation (std / mean), population std. 0 if mean<=0."""
114
+ n = len(xs)
115
+ if n < 2:
116
+ return 0.0
117
+ m = _mean(xs)
118
+ if m <= 0:
119
+ return 0.0
120
+ var = math.fsum((x - m) ** 2 for x in xs) / n
121
+ return math.sqrt(var) / m
122
+
123
+
124
+ # --- the analyzer --------------------------------------------------------
125
+
126
+
127
def analyze_distributed(
    ranks: dict[int, RunStore],
    straggler_z: float = 3.0,
    min_rel_slowdown: float = 0.05,
) -> DistributedSummary | None:
    """Analyze aligned per-rank timelines.

    Steps are aligned by their ``step`` number. Only steps recorded by every
    rank are analyzed.

    ``straggler_z``: z-score threshold for calling a rank a *persistent*
    straggler (3.0 ≈ p<0.002, one-sided). ``min_rel_slowdown``: also require the
    flagged rank to be at least this much slower than the per-step median, so we
    don't flag a statistically-consistent-but-negligible straggler.

    Returns ``None`` when fewer than two ranks are given or when the ranks share
    no step numbers.
    """
    if len(ranks) < 2:
        return None

    rank_ids = sorted(ranks)
    n = len(rank_ids)
    # Per-rank step index -> record, so we can align by the shared step number.
    by_rank: dict[int, dict[int, StepRecord]] = {
        r: {rec.step: rec for rec in ranks[r].steps} for r in rank_ids
    }
    # Only steps present on every rank can be compared across ranks.
    common_steps = sorted(set.intersection(*[set(by_rank[r]) for r in rank_ids]))
    if not common_steps:
        return None

    # Accumulators.
    step_walls: list[float] = []
    comm_fracs: list[float] = []
    imbalance_cvs: list[float] = []
    sync_skews: list[float] = []
    slowest_count = {r: 0 for r in rank_ids}
    rel_slowdowns: dict[int, list[float]] = {r: [] for r in rank_ids}
    compute_samples: dict[int, list[float]] = {r: [] for r in rank_ids}
    sum_max = 0.0
    sum_lost = 0.0  # sum of (max - mean) compute

    for s in common_steps:
        comps = {r: _compute_seconds(by_rank[r][s]) for r in rank_ids}
        comms = {r: _comm_seconds(by_rank[r][s]) for r in rank_ids}
        cvals = [comps[r] for r in rank_ids]
        cmax = max(cvals)
        cmean = _mean(cvals)
        cmed = median(cvals)
        # Critical path: slowest compute, plus the max comm observed that step.
        wall = cmax + max(comms.values())
        step_walls.append(wall)
        comm_fracs.append((max(comms.values()) / wall) if wall > 0 else 0.0)
        imbalance_cvs.append(_cv(cvals))
        sync_skews.append(cmax - cmed)
        sum_max += cmax
        sum_lost += cmax - cmean

        # Critical rank (argmax compute); ties -> lowest rank id (deterministic).
        # NOTE(review): exact ties always credit the lowest id, which inflates
        # that rank's straggler_z when compute times tie; confirm this is acceptable.
        crit = min(rank_ids, key=lambda r: (-comps[r], r))
        slowest_count[crit] += 1
        for r in rank_ids:
            compute_samples[r].append(comps[r])
            # Steps with a non-positive median compute give no relative slowdown.
            if cmed > 0:
                rel_slowdowns[r].append(comps[r] / cmed - 1.0)

    s_total = len(common_steps)
    # Binomial(S, 1/N) null for "is rank r the slowest".
    p0 = 1.0 / n
    sd = math.sqrt(s_total * p0 * (1.0 - p0)) or 1.0
    expected = s_total * p0

    rank_stats: list[RankStat] = []
    for r in rank_ids:
        # One-sided normal approximation: how far above chance this rank is slowest.
        z = (slowest_count[r] - expected) / sd
        rank_stats.append(
            RankStat(
                rank=r,
                mean_compute=_mean(compute_samples[r]),
                median_compute=median(compute_samples[r]),
                slowest_count=slowest_count[r],
                slowest_fraction=slowest_count[r] / s_total,
                straggler_z=z,
                rel_slowdown=median(rel_slowdowns[r]) if rel_slowdowns[r] else 0.0,
            )
        )

    # A *single* persistent straggler exists only if exactly one rank clears both
    # the statistical-persistence bar (z) and the practical-magnitude bar
    # (rel_slowdown). If several ranks do (e.g. two slow nodes alternating as the
    # critical path), that's load imbalance, not one straggler — leave it to the
    # imbalance rule rather than fingering one rank misleadingly.
    significant = [
        rs
        for rs in rank_stats
        if rs.straggler_z >= straggler_z and rs.rel_slowdown >= min_rel_slowdown
    ]
    straggler = significant[0] if len(significant) == 1 else None

    return DistributedSummary(
        world_size=n,
        n_steps=s_total,
        mean_step_wall=_mean(step_walls),
        mean_comm_fraction=_mean(comm_fracs),
        imbalance_cv=median(imbalance_cvs),
        sync_skew=median(sync_skews),
        wall_frac_lost_to_imbalance=(sum_lost / sum_max) if sum_max > 0 else 0.0,
        ranks=rank_stats,
        straggler=straggler,
    )
@@ -0,0 +1,149 @@
1
+ """The Training Efficiency Budget — one accounting identity for the whole run.
2
+
3
+ Every other analyzer produces a *finding*. This one produces a **budget**: it
4
+ decomposes the wall time of training into named line items that, by construction,
5
+ sum back to the measured wall time. Anchored at the top by ``useful_compute`` —
6
+ the time the math would take at the hardware's peak throughput — the rest of the
7
+ budget is, line by line, *recoverable* time:
8
+
9
+ wall = useful_compute (irreducible: the FLOPs, at peak)
10
+ + compute_overhead (your kernels don't hit peak)
11
+ + data_stall (waiting on the dataloader)
12
+ + communication (collective time on the timeline)
13
+ + other (everything else attributed)
14
+
15
+ ``MFU`` (Model FLOPs Utilization) falls straight out: ``useful_compute / wall``.
16
+ Because the phase timeline already partitions each step, the decomposition is
17
+ **exact** — the line items sum to the attributed wall with no fudge factor, which
18
+ makes the model falsifiable (a wrong term shows up as a non-zero residual).
19
+
20
+ This turns a profiler into an advisor: each recoverable line is a number of
21
+ seconds you could win back, so fixes rank themselves by payoff.
22
+
23
+ FLOPs and peak are optional. With them you get a true MFU anchor; without them
24
+ the budget still decomposes wall time, with ``useful_compute`` falling back to
25
+ measured compute (MFU unknown).
26
+ """
27
+
28
+ from __future__ import annotations
29
+
30
+ import math
31
+ from collections.abc import Sequence
32
+ from dataclasses import dataclass, field
33
+
34
+ from ..core.events import (
35
+ COMM,
36
+ COMPUTE,
37
+ DATA,
38
+ OTHER,
39
+ StepRecord,
40
+ )
41
+
42
# Phases that count as local compute (where FLOPs are spent).
# NOTE(review): the distributed analyzer imports ``COMPUTE_PHASES`` from
# ``core.events`` instead of defining its own tuple; confirm the two agree.
_COMPUTE_PHASES = ("forward", "backward", "optimizer", COMPUTE)
44
+
45
+
46
@dataclass
class BudgetLine:
    """One named line item of an ``EfficiencyBudget``."""

    name: str  # e.g. "useful_compute", "compute_overhead", "data_stall"
    seconds: float  # total seconds for this item over the run
    fraction: float  # of attributed wall
    recoverable: bool  # is this time you could win back?
52
+
53
+
54
@dataclass
class EfficiencyBudget:
    """Wall-time budget for a run, broken into named ``BudgetLine`` items."""

    wall: float  # attributed wall: compute + data + comm + other totals, seconds
    n_steps: int
    compute_measured: float  # forward+backward+optimizer+compute, seconds
    ideal_compute: float | None  # FLOPs-anchored minimum, seconds (if known)
    mfu: float | None  # useful_compute / wall, in [0, 1] (if FLOPs+peak known)
    flops_per_step: float | None
    peak_flops: float | None
    lines: list[BudgetLine] = field(default_factory=list)

    @property
    def efficiency(self) -> float:
        """Share of ``wall`` taken by the ``useful_compute`` line.

        A missing line counts as zero, and a non-positive wall yields 0.0.
        """
        useful = 0.0
        for item in self.lines:
            if item.name == "useful_compute":
                useful = item.seconds
                break
        if self.wall > 0:
            return useful / self.wall
        return 0.0

    @property
    def recoverable_lines(self) -> list[BudgetLine]:
        """Recoverable items with positive time, largest first (stable on ties)."""
        ranked = [
            item for item in self.lines if item.recoverable and item.seconds > 0
        ]
        ranked.sort(key=lambda item: item.seconds, reverse=True)
        return ranked

    @property
    def top_recoverable(self) -> BudgetLine | None:
        """The largest recoverable item, or ``None`` when there is none."""
        return next(iter(self.recoverable_lines), None)
86
+
87
+
88
+ def _phase_total(steps: Sequence[StepRecord], phase: str) -> float:
89
+ return math.fsum(s.phases.get(phase, 0.0) for s in steps)
90
+
91
+
92
def analyze_efficiency(
    steps: Sequence[StepRecord],
    *,
    flops_per_step: float | None = None,
    peak_flops: float | None = None,
) -> EfficiencyBudget | None:
    """Build the efficiency budget from the per-step phase timeline.

    ``flops_per_step``: training FLOPs per step (forward+backward; see
    ``hardware.measure_flops``). ``peak_flops``: device peak FLOP/s. Both are
    optional; supply both (strictly positive) for a true MFU anchor. If either
    is missing, zero, negative or NaN, all measured compute is treated as
    useful and ``mfu`` is ``None``.

    Returns ``None`` when ``steps`` is empty or the attributed wall
    (compute + data + comm + other) is not positive.
    """
    steps = list(steps)
    if not steps:
        return None

    data = _phase_total(steps, DATA)
    comm = _phase_total(steps, COMM)
    other = _phase_total(steps, OTHER)
    compute_measured = math.fsum(_phase_total(steps, p) for p in _COMPUTE_PHASES)
    # The "attributed" wall covers only the phases the budget names.
    wall = compute_measured + data + comm + other
    if wall <= 0:
        return None

    n = len(steps)
    ideal_compute: float | None = None
    mfu: float | None = None
    # Explicit comparisons (not truthiness) so a negative or NaN FLOP count is
    # rejected instead of producing a negative "useful" line and an overhead
    # larger than the measured compute.
    if (
        flops_per_step is not None
        and flops_per_step > 0
        and peak_flops is not None
        and peak_flops > 0
    ):
        ideal_compute = flops_per_step * n / peak_flops

    # Useful compute is the irreducible work. If we know the FLOPs-anchored ideal,
    # it caps useful at the measured compute (you can't be "more than 100%
    # efficient"; if the estimate exceeds measured, treat compute as all-useful
    # and surface no negative overhead).
    if ideal_compute is not None:
        useful = min(ideal_compute, compute_measured)
        mfu = useful / wall
    else:
        useful = compute_measured
    overhead = compute_measured - useful

    lines = [
        BudgetLine("useful_compute", useful, useful / wall, recoverable=False),
        BudgetLine("compute_overhead", overhead, overhead / wall, recoverable=True),
        BudgetLine("data_stall", data, data / wall, recoverable=True),
        BudgetLine("communication", comm, comm / wall, recoverable=True),
        BudgetLine("other", other, other / wall, recoverable=True),
    ]
    return EfficiencyBudget(
        wall=wall,
        n_steps=n,
        compute_measured=compute_measured,
        ideal_compute=ideal_compute,
        mfu=mfu,
        flops_per_step=flops_per_step,
        peak_flops=peak_flops,
        lines=lines,
    )
@@ -0,0 +1,58 @@
1
+ """Memory analyzer (vertical #2) — reads the per-step ``memory`` block.
2
+
3
+ Operates on whatever the memory collector stored (bytes): ``alloc``,
4
+ ``reserved``, ``peak_alloc``, ``peak_reserved``. Pure functions over the
5
+ timeline, like the timing analyzer.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import math
11
+ from dataclasses import dataclass, field
12
+
13
+ from ..core.events import StepRecord
14
+ from .stats import robust_slope
15
+
16
+
17
@dataclass
class MemorySummary:
    """Result of ``analyze_memory``; byte values are as stored by the collector."""

    has_memory: bool = False  # at least one step recorded an "alloc" value
    n_steps: int = 0  # number of records passed in
    peak_alloc_bytes: float = 0.0
    peak_reserved_bytes: float = 0.0
    mean_alloc_bytes: float = 0.0
    # Fraction of reserved memory not actually allocated — fragmentation/slack.
    fragmentation: float = 0.0
    # Robust per-step growth of allocated memory — a leak signal.
    growth_bytes_per_step: float = 0.0
    alloc_series: list[float] = field(default_factory=list)  # "alloc" values in record order
29
+
30
+
31
def analyze_memory(steps: list[StepRecord]) -> MemorySummary:
    """Summarize the per-step ``memory`` blocks of a run.

    Reads the optional keys ``alloc``, ``reserved``, ``peak_alloc`` and
    ``peak_reserved`` from each record's ``memory`` mapping. Returns a summary
    with ``has_memory=False`` when no record carries ``alloc``.
    ``growth_bytes_per_step`` is ``stats.robust_slope`` over the alloc series.
    That series is taken per record that has ``alloc``, so it is presumably
    per step only when every step records it.
    """
    allocs = [s.memory["alloc"] for s in steps if "alloc" in s.memory]
    if not allocs:
        return MemorySummary(has_memory=False, n_steps=len(steps))

    # Peaks fall back to the instantaneous value when no peak was recorded.
    peak_alloc = [
        s.memory.get("peak_alloc", s.memory.get("alloc", 0.0)) for s in steps if s.memory
    ]
    peak_reserved = [
        s.memory.get("peak_reserved", s.memory.get("reserved", 0.0))
        for s in steps
        if s.memory
    ]

    # Fragmentation pairs each step's own alloc with its own reserved. Zipping
    # two independently filtered lists would misalign the pairs whenever some
    # steps record only one of the two keys.
    frag_samples = [
        (s.memory["reserved"] - s.memory["alloc"]) / s.memory["reserved"]
        for s in steps
        if "alloc" in s.memory and "reserved" in s.memory and s.memory["reserved"] > 0
    ]
    fragmentation = math.fsum(frag_samples) / len(frag_samples) if frag_samples else 0.0

    return MemorySummary(
        has_memory=True,
        n_steps=len(steps),
        peak_alloc_bytes=max(peak_alloc) if peak_alloc else max(allocs),
        peak_reserved_bytes=max(peak_reserved) if peak_reserved else 0.0,
        mean_alloc_bytes=math.fsum(allocs) / len(allocs),
        fragmentation=fragmentation,
        growth_bytes_per_step=robust_slope(allocs),
        alloc_series=allocs,
    )
+ )