fpwap 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (113) hide show
  1. fpwap-0.1.0/.github/workflows/ci.yml +36 -0
  2. fpwap-0.1.0/.github/workflows/release.yml +66 -0
  3. fpwap-0.1.0/.gitignore +19 -0
  4. fpwap-0.1.0/PKG-INFO +293 -0
  5. fpwap-0.1.0/README.md +272 -0
  6. fpwap-0.1.0/ROADMAP.md +108 -0
  7. fpwap-0.1.0/SPEC.md +690 -0
  8. fpwap-0.1.0/bench/results/70b_chunk_size_sweep_1k.csv +4 -0
  9. fpwap-0.1.0/bench/results/70b_streaming_n_scaling.csv +5 -0
  10. fpwap-0.1.0/bench/results/70b_streaming_n_scaling_cold.csv +5 -0
  11. fpwap-0.1.0/bench/results/8b_chunk_size_sweep_1k.csv +7 -0
  12. fpwap-0.1.0/bench/results/8b_chunk_size_sweep_8k.csv +6 -0
  13. fpwap-0.1.0/bench/results/8b_preloaded_n_scaling.csv +5 -0
  14. fpwap-0.1.0/pyproject.toml +62 -0
  15. fpwap-0.1.0/scripts/bench_bucketed_padding.py +157 -0
  16. fpwap-0.1.0/scripts/bench_n_scaling.py +311 -0
  17. fpwap-0.1.0/scripts/bench_warm_start.py +299 -0
  18. fpwap-0.1.0/scripts/benchmark.py +299 -0
  19. fpwap-0.1.0/scripts/harness_adapter.py +248 -0
  20. fpwap-0.1.0/scripts/study_chunk_size.py +187 -0
  21. fpwap-0.1.0/src/fpwap/__init__.py +43 -0
  22. fpwap-0.1.0/src/fpwap/buffer.py +164 -0
  23. fpwap-0.1.0/src/fpwap/callbacks/__init__.py +3 -0
  24. fpwap-0.1.0/src/fpwap/callbacks/base.py +51 -0
  25. fpwap-0.1.0/src/fpwap/callbacks/common.py +278 -0
  26. fpwap-0.1.0/src/fpwap/cost_model.py +108 -0
  27. fpwap-0.1.0/src/fpwap/engine.py +1808 -0
  28. fpwap-0.1.0/src/fpwap/extractor.py +113 -0
  29. fpwap-0.1.0/src/fpwap/loader.py +527 -0
  30. fpwap-0.1.0/src/fpwap/models/__init__.py +21 -0
  31. fpwap-0.1.0/src/fpwap/models/base.py +86 -0
  32. fpwap-0.1.0/src/fpwap/models/gpt2.py +105 -0
  33. fpwap-0.1.0/src/fpwap/models/llama.py +138 -0
  34. fpwap-0.1.0/src/fpwap/preflight.py +97 -0
  35. fpwap-0.1.0/src/fpwap/storage/__init__.py +78 -0
  36. fpwap-0.1.0/src/fpwap/storage/memmap.py +658 -0
  37. fpwap-0.1.0/src/fpwap/types.py +114 -0
  38. fpwap-0.1.0/tests/__init__.py +0 -0
  39. fpwap-0.1.0/tests/conftest.py +36 -0
  40. fpwap-0.1.0/tests/gpu/__init__.py +0 -0
  41. fpwap-0.1.0/tests/gpu/test_buffer_on_cpu.py +106 -0
  42. fpwap-0.1.0/tests/gpu/test_disk_buffer_bit_exact.py +144 -0
  43. fpwap-0.1.0/tests/gpu/test_emit_staging_gpu.py +92 -0
  44. fpwap-0.1.0/tests/gpu/test_final_norm_bit_exact.py +165 -0
  45. fpwap-0.1.0/tests/gpu/test_forward_bit_perfect.py +115 -0
  46. fpwap-0.1.0/tests/gpu/test_moe_streaming_bit_exact.py +171 -0
  47. fpwap-0.1.0/tests/gpu/test_real_llama_bit_exact.py +144 -0
  48. fpwap-0.1.0/tests/gpu/test_real_model_families_bit_exact.py +159 -0
  49. fpwap-0.1.0/tests/gpu/test_streaming_bit_exact.py +136 -0
  50. fpwap-0.1.0/tests/integration/__init__.py +0 -0
  51. fpwap-0.1.0/tests/integration/test_activations_as_path.py +363 -0
  52. fpwap-0.1.0/tests/integration/test_bucketed_padding.py +293 -0
  53. fpwap-0.1.0/tests/integration/test_buffer_aliasing_safety.py +108 -0
  54. fpwap-0.1.0/tests/integration/test_buffer_device.py +108 -0
  55. fpwap-0.1.0/tests/integration/test_chunk_drain_progress.py +170 -0
  56. fpwap-0.1.0/tests/integration/test_chunk_engine.py +313 -0
  57. fpwap-0.1.0/tests/integration/test_early_exit.py +230 -0
  58. fpwap-0.1.0/tests/integration/test_engine_forcing.py +117 -0
  59. fpwap-0.1.0/tests/integration/test_engine_llama.py +137 -0
  60. fpwap-0.1.0/tests/integration/test_engine_ragged_emit.py +187 -0
  61. fpwap-0.1.0/tests/integration/test_engine_streaming_cpu.py +140 -0
  62. fpwap-0.1.0/tests/integration/test_extra_hooks_gpt2.py +129 -0
  63. fpwap-0.1.0/tests/integration/test_extra_hooks_llama.py +127 -0
  64. fpwap-0.1.0/tests/integration/test_extractor.py +261 -0
  65. fpwap-0.1.0/tests/integration/test_harness_adapter.py +170 -0
  66. fpwap-0.1.0/tests/integration/test_incremental_pca.py +144 -0
  67. fpwap-0.1.0/tests/integration/test_layer_artifacts.py +86 -0
  68. fpwap-0.1.0/tests/integration/test_memmap_backend.py +214 -0
  69. fpwap-0.1.0/tests/integration/test_model_families.py +212 -0
  70. fpwap-0.1.0/tests/integration/test_padded_batch.py +146 -0
  71. fpwap-0.1.0/tests/integration/test_preflight_minimum.py +78 -0
  72. fpwap-0.1.0/tests/integration/test_progress_reporter.py +88 -0
  73. fpwap-0.1.0/tests/integration/test_raw_activations.py +86 -0
  74. fpwap-0.1.0/tests/integration/test_readme_workflow.py +100 -0
  75. fpwap-0.1.0/tests/integration/test_steer_and_diff.py +246 -0
  76. fpwap-0.1.0/tests/integration/test_sublayer_writeback.py +121 -0
  77. fpwap-0.1.0/tests/integration/test_teardown_progress.py +96 -0
  78. fpwap-0.1.0/tests/integration/test_verify.py +110 -0
  79. fpwap-0.1.0/tests/integration/test_warm_start.py +253 -0
  80. fpwap-0.1.0/tests/integration/test_writeback_residual_pre.py +110 -0
  81. fpwap-0.1.0/tests/unit/__init__.py +0 -0
  82. fpwap-0.1.0/tests/unit/test_auto_microbatch.py +192 -0
  83. fpwap-0.1.0/tests/unit/test_bucketing.py +123 -0
  84. fpwap-0.1.0/tests/unit/test_buffer_pinned.py +41 -0
  85. fpwap-0.1.0/tests/unit/test_checkpoint_conversion.py +243 -0
  86. fpwap-0.1.0/tests/unit/test_chunk_drain_timing.py +50 -0
  87. fpwap-0.1.0/tests/unit/test_chunk_helpers.py +63 -0
  88. fpwap-0.1.0/tests/unit/test_cost_model.py +173 -0
  89. fpwap-0.1.0/tests/unit/test_dtype_table.py +38 -0
  90. fpwap-0.1.0/tests/unit/test_emit_drain.py +103 -0
  91. fpwap-0.1.0/tests/unit/test_emit_ragged.py +275 -0
  92. fpwap-0.1.0/tests/unit/test_emit_staging.py +78 -0
  93. fpwap-0.1.0/tests/unit/test_emit_timing.py +67 -0
  94. fpwap-0.1.0/tests/unit/test_final_norm.py +137 -0
  95. fpwap-0.1.0/tests/unit/test_harness_fingerprint.py +56 -0
  96. fpwap-0.1.0/tests/unit/test_index_builder.py +34 -0
  97. fpwap-0.1.0/tests/unit/test_load_from_cache.py +99 -0
  98. fpwap-0.1.0/tests/unit/test_loader_layer.py +101 -0
  99. fpwap-0.1.0/tests/unit/test_memmap_shard_fadvise.py +89 -0
  100. fpwap-0.1.0/tests/unit/test_memmap_shard_zero_copy.py +38 -0
  101. fpwap-0.1.0/tests/unit/test_plumbing_moe_mlp_tuple.py +71 -0
  102. fpwap-0.1.0/tests/unit/test_preflight_math.py +69 -0
  103. fpwap-0.1.0/tests/unit/test_preflight_report.py +83 -0
  104. fpwap-0.1.0/tests/unit/test_preloop_timing.py +123 -0
  105. fpwap-0.1.0/tests/unit/test_profile_metrics.py +37 -0
  106. fpwap-0.1.0/tests/unit/test_result_artifact.py +68 -0
  107. fpwap-0.1.0/tests/unit/test_setup_timing.py +104 -0
  108. fpwap-0.1.0/tests/unit/test_shard_page_advisor.py +190 -0
  109. fpwap-0.1.0/tests/unit/test_snapshot_resolution.py +86 -0
  110. fpwap-0.1.0/tests/unit/test_teardown_timing.py +87 -0
  111. fpwap-0.1.0/tests/unit/test_tied_weights.py +72 -0
  112. fpwap-0.1.0/tests/unit/test_warm_start_stats.py +63 -0
  113. fpwap-0.1.0/uv.lock +1495 -0
@@ -0,0 +1,36 @@
1
+ name: ci
2
+
3
+ on:
4
+ push:
5
+ branches: [main]
6
+ pull_request:
7
+
8
+ jobs:
9
+ lint-typecheck-test:
10
+ runs-on: ubuntu-latest
11
+ strategy:
12
+ fail-fast: false
13
+ matrix:
14
+ python-version: ["3.11", "3.12"]
15
+ steps:
16
+ - uses: actions/checkout@v4
17
+
18
+ - name: Install uv
19
+ uses: astral-sh/setup-uv@v3
20
+ with:
21
+ enable-cache: true
22
+
23
+ - name: Set up Python ${{ matrix.python-version }}
24
+ run: uv python install ${{ matrix.python-version }}
25
+
26
+ - name: Sync dev dependencies
27
+ run: uv sync --extra dev
28
+
29
+ - name: Lint (ruff)
30
+ run: uv run ruff check .
31
+
32
+ - name: Typecheck (mypy)
33
+ run: uv run mypy src
34
+
35
+ - name: Unit tests (no GPU)
36
+ run: uv run pytest -m 'not gpu'
@@ -0,0 +1,66 @@
1
+ name: release
2
+
3
+ # Tag, create a GitHub release, and publish to PyPI whenever main lands
4
+ # with a pyproject version that has no tag yet. Releasing = bumping
5
+ # `version` in pyproject.toml; merges that don't bump are no-ops here.
6
+ #
7
+ # PyPI auth is trusted publishing (OIDC) — registered on pypi.org for
8
+ # this repo + workflow + the `pypi` environment. No tokens.
9
+
10
+ on:
11
+ push:
12
+ branches: [main]
13
+
14
+ permissions:
15
+ contents: write
16
+
17
+ jobs:
18
+ tag-release:
19
+ runs-on: ubuntu-latest
20
+ outputs:
21
+ released: ${{ steps.release.outputs.released }}
22
+ version: ${{ steps.version.outputs.version }}
23
+ steps:
24
+ - uses: actions/checkout@v4
25
+
26
+ - name: Read version
27
+ id: version
28
+ run: |
29
+ echo "version=$(grep -Po '(?<=^version = ")[^"]*' pyproject.toml)" >> "$GITHUB_OUTPUT"
30
+
31
+ - name: Create release if tag missing
32
+ id: release
33
+ env:
34
+ GH_TOKEN: ${{ github.token }}
35
+ run: |
36
+ tag="v${{ steps.version.outputs.version }}"
37
+ if gh release view "$tag" --repo "$GITHUB_REPOSITORY" >/dev/null 2>&1; then
38
+ echo "release $tag already exists — skipping"
39
+ echo "released=false" >> "$GITHUB_OUTPUT"
40
+ else
41
+ gh release create "$tag" \
42
+ --repo "$GITHUB_REPOSITORY" \
43
+ --target "$GITHUB_SHA" \
44
+ --title "fpwap $tag" \
45
+ --generate-notes
46
+ echo "released=true" >> "$GITHUB_OUTPUT"
47
+ fi
48
+
49
+ publish-pypi:
50
+ needs: tag-release
51
+ if: needs.tag-release.outputs.released == 'true'
52
+ runs-on: ubuntu-latest
53
+ environment: pypi
54
+ permissions:
55
+ id-token: write # OIDC for PyPI trusted publishing
56
+ steps:
57
+ - uses: actions/checkout@v4
58
+
59
+ - name: Install uv
60
+ uses: astral-sh/setup-uv@v3
61
+
62
+ - name: Build sdist and wheel
63
+ run: uv build
64
+
65
+ - name: Publish to PyPI
66
+ uses: pypa/gh-action-pypi-publish@release/v1
fpwap-0.1.0/.gitignore ADDED
@@ -0,0 +1,19 @@
1
+ __pycache__/
2
+ *.pyc
3
+ *.pyo
4
+ *.egg-info/
5
+ build/
6
+ dist/
7
+ .pytest_cache/
8
+ .ruff_cache/
9
+ .mypy_cache/
10
+ .coverage
11
+ htmlcov/
12
+
13
+ .venv/
14
+ .env
15
+
16
+ .vscode/
17
+ .idea/
18
+ CLAUDE.md
19
+ .claude/
fpwap-0.1.0/PKG-INFO ADDED
@@ -0,0 +1,293 @@
1
+ Metadata-Version: 2.4
2
+ Name: fpwap
3
+ Version: 0.1.0
4
+ Summary: Forward Pass Weight Amortization Protocol — invert the inference loop for large transformer models.
5
+ Author: Michael Klear
6
+ License: MIT
7
+ Requires-Python: >=3.11
8
+ Requires-Dist: accelerate>=1.0
9
+ Requires-Dist: numpy>=1.26
10
+ Requires-Dist: pyarrow>=15.0
11
+ Requires-Dist: safetensors>=0.4
12
+ Requires-Dist: torch>=2.1
13
+ Requires-Dist: tqdm>=4.66
14
+ Requires-Dist: transformers>=4.40
15
+ Provides-Extra: dev
16
+ Requires-Dist: mypy>=1.10; extra == 'dev'
17
+ Requires-Dist: pytest-cov>=5.0; extra == 'dev'
18
+ Requires-Dist: pytest>=8.0; extra == 'dev'
19
+ Requires-Dist: ruff>=0.6; extra == 'dev'
20
+ Description-Content-Type: text/markdown
21
+
22
+ # fpwap — Forward Pass Weight Amortization Protocol
23
+
24
+ A single-purpose library for running activation extraction over large transformer models **whose weights don't fit in your GPU**, across datasets of **thousands of prompts**, on **consumer hardware**, at **full precision**.
25
+
26
+ ## The regime
27
+
28
+ You're a mech-interp researcher. Your model is bigger than your VRAM. Your dataset is thousands of prompts. Adjacent tools each fail in a way that changes what you're studying:
29
+
30
+ - **Quantization** (bitsandbytes, GPTQ) changes the activations you're trying to read.
31
+ - **Inference servers** (vLLM, TGI) optimize next-token throughput, not residual-stream extraction.
32
+ - **`accelerate.cpu_offload`** re-streams the full set of layer weights for every batch — 10k prompts × 80 layers on a 70B model is hundreds of TB of weight I/O, hours of wall-clock per dataset pass.
33
+ - **Cloud GPUs** break your interactive iteration loop and cost hundreds per experiment.
34
+
35
+ fpwap inverts the inference loop: **load each layer once, stream the whole dataset through it**, spill intermediates to disk, move on. Total weight I/O drops from `O(N_batches × N_layers)` to `O(N_layers)`. A 10k-sample Llama-3.1-70B extraction on a 32 GB consumer GPU runs in roughly the wall-clock of a single batch under the naive approach — with the same weights, no quantization, no cloud.
36
+
37
+ ## Aspirational performance
38
+
39
+ These are targets, not measurements: the numbers fpwap is being built to hit. Each row unlocks only after its milestone lands (70B gates on the bit-perfect test; 405B gates on the mmap-from-HF-cache path). Each target is replaced by a measured benchmark as it comes in.
40
+
41
+ ### Reference machine
42
+
43
+ | Component | Spec |
44
+ | --------- | ---- |
45
+ | GPU | NVIDIA RTX 5090, 32 GB VRAM |
46
+ | CPU | Modern desktop-class, 16+ cores |
47
+ | RAM | 128 GB DDR5 |
48
+ | Storage | NVMe SSD (Gen 4+), ≥ 1 TB free |
49
+ | Interconnect | PCIe 5.0 x16 |
50
+ | Network | None — fully local, no cloud |
51
+
52
+ ### Dataset-scale activation extraction (10,000 prompts × 256 tokens = 2.56M tokens)
53
+
54
+ Residual stream (`residual_post`) captured at every layer, pooled to last token, persisted to disk. `RawActivations(layers="all")`.
55
+
56
+ | Model | Weights (bf16) | Loading strategy | Wall-clock target | Throughput target | vs. naive `accelerate.cpu_offload` |
57
+ | ----- | -------------- | ---------------- | ----------------- | ----------------- | ----------------------------------- |
58
+ | Llama-3.1-8B | 16 GB | `cpu_offload` | ≤ 8 min | ≥ 5,000 tok/s | ≥ 4× faster |
59
+ | Llama-3.1-70B | 140 GB | `disk_offload` | ≤ 45 min | ≥ 950 tok/s | ≥ 4× faster (naive ≈ 3 h) |
60
+ | Llama-3.1-405B | 810 GB | `mmap_from_cache` | ≤ 4 h | ≥ 180 tok/s | naive infeasible (OOM in RAM) |
61
+
62
+ Throughput is end-to-end tokens per second — total tokens processed (samples × seq_len) divided by wall-clock from `fpwap(...).run()` entry to return, including weight I/O, forward, callbacks, and buffer write.
63
+
64
+ ### Single-pass cost per layer (Llama-3.1-70B, 1.75 GB weights per layer)
65
+
66
+ The inner loop that fpwap is optimizing. On the reference machine, per layer, per full sweep of 10k × 256-token samples:
67
+
68
+ | Phase | Budget | Notes |
69
+ | ----- | ------ | ----- |
70
+ | Weight load | ≤ 1.0 s | NVMe → CPU → GPU, `disk_offload` path; once per layer, not once per batch |
71
+ | Forward | ≤ 15 s | 10k samples, bf16, batched at engine's discretion |
72
+ | Callback | ≤ 1.0 s | Aggregate across all registered callbacks for this layer |
73
+ | Buffer write | ≤ 1.0 s | Pooled activations to memmap; raw `[N, S, H]` budget is higher |
74
+ | **Per-layer total** | **≤ 18 s** | × 80 layers ≈ 24 min (leaves headroom vs. 45 min end-to-end target) |
75
+
76
+ ### Overhead budgets
77
+
78
+ | Surface | Budget | Why |
79
+ | ------- | ------ | --- |
80
+ | Profile + progress, combined | < 1% wall-clock | Has to stay on by default — see the [Observability](#observability) section |
81
+ | `verify=True` (vs. naive `cpu_offload` at every layer) | 2–3× slower | Correctness debugging only; not for production runs |
82
+ | Preflight | < 5 s | Rejects infeasible configurations before GPU contact |
83
+
84
+ ## The API
85
+
86
+ One verb. One callback class. One result.
87
+
88
+ ```python
89
+ from fpwap import Sweep
90
+ from fpwap.callbacks.common import RawActivations, IncrementalPCA, DiffOfMeans
91
+
92
+ run = Sweep(
93
+ model="meta-llama/Llama-3.1-70B",
94
+ dataset=my_dataset, # iterable of {"input_ids": ..., "label": ...}
95
+ seq_len=256,
96
+ callbacks=[
97
+ RawActivations(layers=[40, 45, 50]), # pooled by default
98
+ IncrementalPCA(layers="all", n_components=64),
99
+ DiffOfMeans(layers="all", label_fn=lambda s: s["label"]),
100
+ ],
101
+ )
102
+
103
+ plan = run.preflight()
104
+ print(plan.summary()) # check feasibility before GPU contact
105
+
106
+ result = run.run()
107
+ acts = result.activations(layer=45, hook="residual_post") # [N, H]
108
+ basis = result.artifact("pca_basis", layer=45)
109
+ ```
110
+
111
+ That is the entire user-facing surface for read-only workflows. No backend objects to construct. No `batch_size` knob to foot-gun. No `loader` / `accumulator` triple to wire up. Construction is cheap; `.preflight()` inspects the plan and rejects infeasible configurations with actionable messages; `.run()` executes.
112
+
113
+ ### Layer indexing
114
+
115
+ Hook names follow the HF `hidden_states` convention:
116
+
117
+ | Hook | Equals |
118
+ | ---- | ------ |
119
+ | `residual_pre` at layer `L` | `hidden_states[L]` (input to block `L`) |
120
+ | `residual_post` at layer `L` | `hidden_states[L+1]` (output of block `L`) |
121
+ | `attn_out` at layer `L` | attention sub-layer output at block `L` |
122
+ | `mlp_out` at layer `L` | MLP sub-layer output at block `L` |
123
+
124
+ No off-by-one translation at the call site.
125
+
126
+ ### Writing your own callback
127
+
128
+ Subclass `Callback`. Declare which layers and hooks you want; implement `on_batch`. Return an `Emit` to persist a tensor, a `WriteBack` to modify the residual before the next layer, or `None` to no-op.
129
+
130
+ ```python
131
+ from fpwap import Callback, Emit
132
+
133
+ class LastTokenLogNorm(Callback):
134
+ target_layers = [32]
135
+ target_hooks = ("residual_post",)
136
+ phase = "read"
137
+
138
+ def on_batch(self, layer_idx, hook, acts, sample_ids):
139
+ return Emit(acts[:, -1, :].norm(dim=-1).log())
140
+ ```
141
+
142
+ ### Write-backs and multi-pass workflows
143
+
144
+ The same entry point handles steering. A callback with `phase = "write"` modifies the residual stream between layers; artifacts from one run feed the next.
145
+
146
+ ```python
147
+ from fpwap.callbacks.common import SteerInBasis
148
+
149
+ # Pass 2: steer in the basis fit during pass 1
150
+ steer = Sweep(
151
+ model="meta-llama/Llama-3.1-70B",
152
+ dataset=my_dataset,
153
+ seq_len=256,
154
+ callbacks=[
155
+ SteerInBasis(
156
+ basis_artifact=result.artifact("pca_basis", layer=45),
157
+ direction_idx=0,
158
+ alpha=2.0,
159
+ layers=[45],
160
+ ),
161
+ ],
162
+ )
163
+ steered = steer.run()
164
+ ```
165
+
166
+ ## Observability
167
+
168
+ Performance is the product, so every run is profiled by default with a measurement overhead small enough (target: under 1% wall-clock) that you never have to opt in. When a run is slower than you want, the answer is already in `result.profile` — no re-running with `profile=True`.
169
+
170
+ ```python
171
+ result = run.run()
172
+
173
+ result.profile.summary() # human-readable breakdown per layer
174
+ result.profile.by_phase() # load / forward / callback / write
175
+ result.profile.slowest_layer() # where the time went
176
+ result.profile.bytes_moved() # weight I/O, buffer I/O
177
+ ```
178
+
179
+ Interactive progress is on by default — a tqdm-style bar across layers × batches, because a run on the workstation under your desk should not sit silent for 40 minutes. Disable with `progress=False`; pass a callable (`progress=my_reporter`) to stream events into wandb, rich, or any other backend.
180
+
181
+ ### Known cliff: CUDA allocator fragmentation on K-sweep configs
182
+
183
+ If a K-sweep run on tight VRAM (K-packed sweeps, `chunk_size=1`, large per-K residual buffer) shows episodic multi-minute pauses every few sweeps with the process stuck in D-state but NVMe mostly idle in `iostat`, the cause is almost certainly CUDA caching-allocator fragmentation, not host I/O. Set `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` before launch — segments grow contiguously on demand and the cliff disappears (verified on a 70B / K=30 / 14k repro: 2:20 → 0:55 wall, no other change). See [#72](https://github.com/AlliedToasters/fpwap/issues/72) for the diagnostic walk.
184
+
185
+ ## Reference callbacks
186
+
187
+ Four callbacks ship with the library as examples and integration tests:
188
+
189
+ - **`RawActivations`** — persist per-sample activations, pooled (`last_token_only=True`) by default to avoid an `[N, S, H]` memory landmine.
190
+ - **`IncrementalPCA`** — fit a PCA basis over the entire dataset in a single pass.
191
+ - **`DiffOfMeans`** — compute per-class activation means for binary-labeled data.
192
+ - **`SteerInBasis`** — additive intervention in a pre-computed basis; `phase = "write"`.
193
+
194
+ Anything beyond these four is a consumer's problem.
195
+
196
+ ## Integrating fpwap into a research codebase
197
+
198
+ The recommended shape is a single classmethod on your codebase's activation-source type, inserted **above** any per-batch sharding your framework does:
199
+
200
+ ```python
201
+ class Activations:
202
+ @classmethod
203
+ def from_fpwap(cls, model_id, prompts, layers, pool="last_token"):
204
+ run = Sweep(
205
+ model=model_id,
206
+ dataset=_as_dataset(prompts),
207
+ seq_len=...,
208
+ callbacks=[
209
+ RawActivations(
210
+ layers=layers,
211
+ last_token_only=(pool == "last_token"),
212
+ ),
213
+ ],
214
+ )
215
+ return cls.from_result(run.run())
216
+ ```
217
+
218
+ Branch on `use_fpwap` at your dispatch layer — the same place you'd branch between `from_model`, `from_goodfire`, etc. — not inside a per-batch loop. fpwap's value (amortizing layer loads across the whole dataset) only materializes if it sees the whole dataset; if your framework shards externally and calls an extractor per shard, lift the dispatch up one level before integrating.
219
+
220
+ ## Scope
221
+
222
+ fpwap is a plumbing layer. It produces activations and accepts transforms. It does not know what a probe is. Linear probe fitting, SAE training, attribution analysis, and any other statistical modeling of activations belong in consumer libraries. If it requires knowing what a probe is, it's out of scope.
223
+
224
+ ## Related work
225
+
226
+ The loop inversion at the heart of fpwap — load each layer once, stream the dataset through it — was explored independently by [FlexGen](https://arxiv.org/abs/2303.06865) (Sheng et al., ICML 2023) for high-throughput generative inference on a single GPU. FlexGen calls this a "zig-zag block schedule" and proves it is within 2× of I/O-optimal (Theorem 4.1) — a result that applies directly to fpwap's loop, since our schedule is the same modulo KV cache. FlexGen solves a harder scheduling problem (KV cache placement across GPU/CPU/disk, multi-step autoregressive decoding, CPU compute delegation) and applies 4-bit group-wise quantization to further compress weights. fpwap targets a narrower regime — forward-pass activation extraction for mechanistic interpretability — where full precision is non-negotiable and generation is not needed, so the implementation is much simpler. The absence of KV cache and autoregressive decoding also means fpwap's cost model has fewer free variables, making strategy selection tractable without an LP solver.
227
+
228
+ ## Status
229
+
230
+ **Llama-3.1-405B on a single RTX 5090 (32 GB VRAM), streaming 803 GB of
231
+ bf16 weights from NVMe — 45.7 tok/s in under 12 minutes.** 70B at
232
+ 10,000 prompts × 128 tokens hits 1,221 tok/s. That's the regime fpwap
233
+ exists for: the model doesn't fit in VRAM (or even RAM), the dataset is
234
+ thousands of prompts, and no quantization is involved. Measured on the
235
+ reference machine (RTX 5090, 128 GB DDR5, PCIe 5.0 NVMe):
236
+
237
+ | Model | Path | Samples × seq_len | Throughput (bf16) | SPEC target |
238
+ | ----- | ---- | ------------------ | ----------------- | ----------- |
239
+ | Llama-3.1-405B-Instruct | streaming, prefetch | 256 × 128 | **45.7 tok/s** | ≥ 180 |
240
+ | Llama-3.3-70B-Instruct | streaming, prefetch | 10,000 × 128 | **1,221 tok/s** | ≥ 950 |
241
+ | Llama-3.3-70B-Instruct | streaming | 1,024 × 128 | 1,026 tok/s | ≥ 950 |
242
+ | Llama-3.1-8B-Instruct | streaming | 1,024 × 128 | 10,442 tok/s | ≥ 5,000 |
243
+ | Llama-3.1-8B-Instruct | preloaded | 256 × 128 | 11,894 tok/s | ≥ 5,000 |
244
+
245
+ The 405B number is end-to-end across 126 layers streaming 803 GB of
246
+ weights from NVMe SSD at 1.12 GB/s sustained, with prefetch fully hiding
247
+ disk reads behind compute (0.000s load per layer at steady state). The
248
+ 70B hero number is end-to-end across 80 layers with a pinned-CPU
249
+ residual buffer (21 GB), async D2H, and a worker-thread weight prefetch
250
+ that overlaps layer L+1's safetensors read with layer L's compute.
251
+
252
+ Baseline sanity: an 8B streaming-vs-naive head-to-head shows a **7.25×
253
+ speedup** at 1024 × 128 (SPEC §17 ratio target ≥ 4×). The naive baseline
254
+ is `accelerate.cpu_offload` at 1,440 tok/s, reproducible via
255
+ `scripts/benchmark.py --mode naive`. 70B can't ratio-test on this machine
256
+ (141 GB bf16 > 128 GB RAM for `cpu_offload`); the 70B claim is absolute
257
+ throughput.
258
+
259
+ Correctness: `tests/gpu/test_real_llama_bit_exact.py` runs Llama-3.2-1B in bf16
260
+ on CUDA and compares every layer's `residual_post` against a naive HF forward —
261
+ bit-exact (`torch.equal`) at every real token position. When microbatch_size
262
+ equals the naive batch size, bf16 is deterministic; at different microbatch
263
+ sizes, outputs diverge by LSB accumulation noise (see the memory note on
264
+ `bf16_microbatch_determinism`).
265
+
266
+ What's wired: pre-loaded and streaming model paths, `Sweep` + `Callback` +
267
+ `Result` API, padded-batch + attention-mask propagation, RoPE-aware Llama
268
+ plumbing, GPT-2 plumbing, all four hooks (`residual_pre`, `attn_out`,
269
+ `mlp_out`, `residual_post`) with fast-path block forward when no sub-layer
270
+ hook is wanted and WriteBack at every hook (sub-layer WriteBack is
271
+ threaded through the block mid-forward so the modified tensor actually
272
+ affects downstream compute), all four reference callbacks shipped
273
+ (`RawActivations`, `IncrementalPCA`, `DiffOfMeans`, `SteerInBasis`),
274
+ `result.activations(...)`, tqdm progress plus callable `progress=reporter`
275
+ emitting `ProgressEvent`s for wandb/rich sinks, pinned-CPU
276
+ `buffer_device="cpu"` with async D2H copy (so oversized residual buffers
277
+ don't block compute), worker-thread **concurrent weight prefetch** on the
278
+ streaming path (layer L+1's safetensors read + H2D overlap with layer L's
279
+ compute), `MemmapBackend` for disk-backed emits,
280
+ `ProfileReport.throughput_tok_per_s()` / `weight_bandwidth_gb_per_s()`,
281
+ `verify=True` fail-fast against a naive-forward baseline (pre-loaded
282
+ models), per-layer `on_layer_end` artifacts collected into
283
+ `result.artifacts`.
284
+
285
+ Model families covered by the structural matcher: Llama, Mistral, Qwen2,
286
+ Gemma, DeepSeek-V2, and any future HF causal LM exposing the same
287
+ `model.{embed_tokens, layers, rotary_emb}` layout. GPT-2 covered by its
288
+ own plumbing.
289
+
290
+ What's not yet wired: checkpoint/resume, NVMe-backed ResidualBuffer,
291
+ `verify=True` on the streaming path (pre-loaded only).
292
+
293
+ See `SPEC.md` for the full design.