fpwap 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fpwap/__init__.py +43 -0
- fpwap/buffer.py +164 -0
- fpwap/callbacks/__init__.py +3 -0
- fpwap/callbacks/base.py +51 -0
- fpwap/callbacks/common.py +278 -0
- fpwap/cost_model.py +108 -0
- fpwap/engine.py +1808 -0
- fpwap/extractor.py +113 -0
- fpwap/loader.py +527 -0
- fpwap/models/__init__.py +21 -0
- fpwap/models/base.py +86 -0
- fpwap/models/gpt2.py +105 -0
- fpwap/models/llama.py +138 -0
- fpwap/preflight.py +97 -0
- fpwap/storage/__init__.py +78 -0
- fpwap/storage/memmap.py +658 -0
- fpwap/types.py +114 -0
- fpwap-0.1.0.dist-info/METADATA +293 -0
- fpwap-0.1.0.dist-info/RECORD +20 -0
- fpwap-0.1.0.dist-info/WHEEL +4 -0
fpwap/__init__.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
from fpwap.callbacks.base import Callback
|
|
2
|
+
from fpwap.engine import (
|
|
3
|
+
PreloopTiming,
|
|
4
|
+
ProfileReport,
|
|
5
|
+
Result,
|
|
6
|
+
SetupTiming,
|
|
7
|
+
Sweep,
|
|
8
|
+
TeardownTiming,
|
|
9
|
+
estimate_max_microbatch,
|
|
10
|
+
)
|
|
11
|
+
from fpwap.extractor import Extractor
|
|
12
|
+
from fpwap.preflight import PreflightReport
|
|
13
|
+
from fpwap.types import (
|
|
14
|
+
Artifact,
|
|
15
|
+
ArtifactKey,
|
|
16
|
+
Context,
|
|
17
|
+
Emit,
|
|
18
|
+
LayerArtifact,
|
|
19
|
+
RaggedTensor,
|
|
20
|
+
ResultArtifact,
|
|
21
|
+
WriteBack,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
__all__ = [
|
|
25
|
+
"Artifact",
|
|
26
|
+
"ArtifactKey",
|
|
27
|
+
"Callback",
|
|
28
|
+
"Context",
|
|
29
|
+
"Emit",
|
|
30
|
+
"Extractor",
|
|
31
|
+
"LayerArtifact",
|
|
32
|
+
"PreflightReport",
|
|
33
|
+
"PreloopTiming",
|
|
34
|
+
"ProfileReport",
|
|
35
|
+
"RaggedTensor",
|
|
36
|
+
"Result",
|
|
37
|
+
"ResultArtifact",
|
|
38
|
+
"SetupTiming",
|
|
39
|
+
"Sweep",
|
|
40
|
+
"TeardownTiming",
|
|
41
|
+
"WriteBack",
|
|
42
|
+
"estimate_max_microbatch",
|
|
43
|
+
]
|
fpwap/buffer.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import torch
|
|
8
|
+
from torch import Tensor
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ResidualBuffer:
    """Inter-layer transport for the fpwap loop.

    Holds one ``[n_samples, seq_len, hidden]`` array in one of two modes:

    - In-memory (``path=None``): a single torch tensor on ``device``. It is
      allocated in pinned host memory when ``device`` is the CPU and CUDA is
      available, so device<->host copies can be issued asynchronously.
    - Disk-backed (``path=<file>``): a numpy memmap opened with mode ``"w+"``
      (the file is created or overwritten). bf16 is stored as uint16
      bit-patterns because numpy has no bf16 dtype.

    Args:
        n_samples: Size of the leading (sample) axis.
        seq_len: Size of the sequence axis.
        hidden: Size of the hidden axis.
        dtype: Storage dtype; incoming values are cast to it on write.
        device: Device of the in-memory tensor (not used in disk-backed mode).
        path: Backing file for disk-backed mode, or ``None`` for in-memory.

    Raises:
        ValueError: In disk-backed mode, if ``dtype`` has no numpy equivalent.
    """

    def __init__(
        self,
        n_samples: int,
        seq_len: int,
        hidden: int,
        dtype: torch.dtype = torch.bfloat16,
        device: torch.device | str = "cpu",
        path: Path | None = None,
    ) -> None:
        self.n_samples = n_samples
        self.seq_len = seq_len
        self.hidden = hidden
        self.dtype = dtype
        self.device = torch.device(device)
        self.path = path
        self._shape = (n_samples, seq_len, hidden)

        if path is not None:
            self._bf16_as_u16 = dtype == torch.bfloat16
            np_dtype = _torch_to_numpy(dtype)
            self._np_dtype: type | None = np_dtype
            self._mm: np.memmap | None = np.memmap(
                path, dtype=np_dtype, mode="w+", shape=self._shape
            )
            # Best-effort access-pattern hint; failures are deliberately ignored.
            # NOTE(review): the standard `os` module does not appear to expose
            # `posix_madvise` (only `posix_fadvise`), so this guard is probably
            # never true. If the hint is wanted, `mmap.mmap.madvise` is the
            # documented route -- confirm before changing.
            if hasattr(os, "posix_madvise"):
                try:
                    os.posix_madvise(
                        self._mm.ctypes.data, self._mm.nbytes,
                        os.POSIX_MADV_SEQUENTIAL,  # type: ignore[attr-defined]
                    )
                except OSError:
                    pass
            # Pinned staging tensor is allocated lazily on the first CUDA write.
            self._staging: Tensor | None = None
            self._data: Tensor | None = None
        else:
            self._bf16_as_u16 = False
            self._np_dtype = None
            self._mm = None
            self._staging = None
            pin = self.device.type == "cpu" and torch.cuda.is_available()
            self._data = torch.zeros(
                self._shape, dtype=dtype, device=self.device, pin_memory=pin,
            )

    def _mm_to_tensor(self, arr: np.ndarray) -> Tensor:
        """Wrap a numpy array as a tensor, reinterpreting uint16 bits as bf16 if needed."""
        t = torch.from_numpy(arr)
        if self._bf16_as_u16:
            t = t.view(torch.bfloat16)
        return t

    def _tensor_to_np(self, values: Tensor) -> np.ndarray:
        """Detach ``values``, move to CPU in the storage dtype, and return a numpy array."""
        host = values.detach().to(device="cpu", dtype=self.dtype)
        if self._bf16_as_u16:
            host = host.view(torch.uint16)
        return host.numpy()

    def _to_host_numpy(self, values: Tensor) -> np.ndarray:
        """Return ``values`` as a host numpy array in the on-disk representation.

        CUDA inputs are copied through the reusable pinned staging tensor, so
        the returned array may alias that tensor and must be consumed before
        the next call. Inputs are detached first: copying a grad-tracking
        tensor into the staging buffer would attach the buffer to the autograd
        graph and make the ``.numpy()`` conversion fail.
        """
        if values.device.type != "cuda":
            return self._tensor_to_np(values)
        staging = self._ensure_staging(values.shape)
        src = values.detach()
        if src.dtype != self.dtype:
            src = src.to(dtype=self.dtype)
        staging.copy_(src, non_blocking=True)
        # The copy was issued asynchronously; wait before reading host memory.
        torch.cuda.synchronize()
        host = staging.view(torch.uint16) if self._bf16_as_u16 else staging
        return host.numpy()

    def __getitem__(self, sample_ids: Tensor) -> Tensor:
        """Gather the rows selected by ``sample_ids`` along the sample axis."""
        if self._data is not None:
            return self._data[sample_ids]
        assert self._mm is not None
        ids_np = sample_ids.detach().to(device="cpu", dtype=torch.int64).numpy()
        # Copy so the returned tensor does not reference the memmap.
        return self._mm_to_tensor(np.asarray(self._mm[ids_np]).copy())

    def __setitem__(self, sample_ids: Tensor, values: Tensor) -> None:
        """Scatter ``values`` into the rows selected by ``sample_ids``."""
        if self._data is not None:
            self._data[sample_ids] = values.to(dtype=self.dtype, device=self.device)
            return
        assert self._mm is not None
        ids_np = sample_ids.detach().to(device="cpu", dtype=torch.int64).numpy()
        self._mm[ids_np] = self._to_host_numpy(values)

    def read_slice(self, start: int, stop: int) -> Tensor:
        """Return samples ``start:stop``.

        In-memory mode returns a view of the backing tensor; disk-backed mode
        returns a fresh copy of the memmap region.
        """
        if self._data is not None:
            return self._data[start:stop]
        assert self._mm is not None
        return self._mm_to_tensor(np.asarray(self._mm[start:stop]).copy())

    def _ensure_staging(self, shape: tuple[int, ...]) -> Tensor:
        """Return a pinned staging tensor of ``shape``, reallocating only on shape change."""
        if self._staging is not None and self._staging.shape == shape:
            return self._staging
        self._staging = torch.zeros(shape, dtype=self.dtype, pin_memory=True)
        return self._staging

    def write_slice(self, start: int, stop: int, values: Tensor) -> None:
        """Write ``values`` into samples ``start:stop``, casting to the storage dtype."""
        if self._data is not None:
            if values.dtype != self.dtype:
                values = values.to(dtype=self.dtype)
            # No synchronization is issued here for the non-blocking copy.
            # NOTE(review): presumably the caller synchronizes before reading
            # the slice back -- confirm against the engine.
            self._data[start:stop].copy_(values, non_blocking=True)
            return
        assert self._mm is not None
        self._mm[start:stop] = self._to_host_numpy(values)

    def flush(self) -> None:
        """Flush the memmap to its backing file (no-op in in-memory mode)."""
        if self._mm is not None:
            self._mm.flush()

    def close(self) -> None:
        """Flush and drop the memmap reference (no-op in in-memory mode)."""
        if self._mm is not None:
            self._mm.flush()
            del self._mm
            self._mm = None
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
# Torch dtype -> numpy scalar type used for memmap storage.
_TORCH_TO_NUMPY = {
    # Floating point.
    torch.float32: np.float32,
    torch.float64: np.float64,
    torch.float16: np.float16,
    # numpy has no bf16; the raw 16-bit patterns are stored as uint16.
    torch.bfloat16: np.uint16,
    # Integers.
    torch.int32: np.int32,
    torch.int64: np.int64,
    torch.int16: np.int16,
    torch.int8: np.int8,
    torch.uint8: np.uint8,
}


def _torch_to_numpy(dtype: torch.dtype) -> type:
    """Return the numpy scalar type that stores ``dtype`` in a memmap.

    Raises:
        ValueError: If ``dtype`` is not in the supported table.
    """
    if dtype not in _TORCH_TO_NUMPY:
        raise ValueError(f"unsupported dtype for memmap: {dtype}")
    return _TORCH_TO_NUMPY[dtype]
|
fpwap/callbacks/base.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Sequence
|
|
4
|
+
from typing import Literal
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
from torch import Tensor
|
|
8
|
+
|
|
9
|
+
from fpwap.types import (
|
|
10
|
+
Artifact,
|
|
11
|
+
BatchResult,
|
|
12
|
+
Context,
|
|
13
|
+
HookName,
|
|
14
|
+
LayerArtifact,
|
|
15
|
+
Phase,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class Callback:
    """Base class for fpwap callbacks; every hook is a no-op by default.

    Subclasses override the class-level targeting attributes and whichever
    hooks they need. The hook names suggest the order in which the engine
    invokes them (sweep start -> per layer: start, batches, end -> sweep end),
    but the call sites are not part of this module.
    """

    # Layer indices this callback applies to, or the literal "all".
    target_layers: Sequence[int] | Literal["all"] = "all"
    # Hook points at which `on_batch` should receive activations.
    target_hooks: Sequence[HookName] = ("residual_post",)
    # Callback phase; the base default is "read".
    phase: Phase = "read"
    # Whether the callback asks for gradients; off by default.
    needs_grad: bool = False
    # dtype subclasses may use for their accumulators.
    accum_dtype: torch.dtype = torch.float32

    def on_sweep_start(self, ctx: Context) -> None:
        """Sweep-start hook; the base implementation does nothing."""

    def on_layer_start(self, layer_idx: int) -> None:
        """Layer-start hook; the base implementation does nothing."""

    def on_batch(
        self,
        layer_idx: int,
        hook: HookName,
        acts: Tensor,
        sample_ids: Tensor,
    ) -> BatchResult:
        """Per-batch hook; the base implementation returns ``None``."""

    def on_layer_end(self, layer_idx: int) -> LayerArtifact | None:
        """Layer-end hook; the base implementation returns ``None`` (no artifact)."""

    def on_sweep_end(self) -> Artifact | None:
        """Sweep-end hook; the base implementation returns ``None`` (no artifact)."""

    def checkpoint_state(self) -> bytes:
        """Return serialized callback state; the base implementation has none."""
        return b""

    def restore_state(self, state: bytes) -> None:
        """Restore state from ``checkpoint_state`` output; the base ignores it."""
|
fpwap/callbacks/common.py
ADDED
|
@@ -0,0 +1,278 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Sequence
|
|
4
|
+
from typing import Literal
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
from torch import Tensor
|
|
8
|
+
|
|
9
|
+
from fpwap.callbacks.base import Callback
|
|
10
|
+
from fpwap.types import (
|
|
11
|
+
Artifact,
|
|
12
|
+
BatchResult,
|
|
13
|
+
HookName,
|
|
14
|
+
LayerArtifact,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class RawActivations(Callback):
    """Persist per-sample activations, pooled (`last_token_only=True`) by default.

    Returns an `Emit` each microbatch; the engine routes these into the run's
    result so `result.activations(layer, hook)` can return them concatenated
    in sample order. For datasets too large to hold in memory, swap in a
    disk-backed StorageBackend.

    Args:
        layers: Layer indices to capture, or ``"all"``.
        hook: Hook point to capture at.
        last_token_only: When true, keep only index ``-1`` along axis 1 of
            activations that have a sequence axis.
        out_dtype: dtype the emitted tensor is cast to.
    """

    phase = "read"

    def __init__(
        self,
        layers: Sequence[int] | Literal["all"] = "all",
        hook: HookName = "residual_post",
        last_token_only: bool = True,
        out_dtype: torch.dtype = torch.bfloat16,
    ) -> None:
        self.target_layers = layers
        self.target_hooks = (hook,)
        self.last_token_only = last_token_only
        self.out_dtype = out_dtype

    def on_batch(
        self,
        layer_idx: int,
        hook: HookName,
        acts: Tensor,
        sample_ids: Tensor,
    ) -> BatchResult:
        """Emit this microbatch's activations, pooled and cast to ``out_dtype``."""
        from fpwap.types import Emit

        # Pool only when there is a sequence axis to pool over. Activations
        # that are already 2D [N, H] pass through instead of raising
        # IndexError, matching the rank guard in IncrementalPCA/DiffOfMeans.
        # Rank >= 3 keeps the previous behaviour unchanged.
        if self.last_token_only and acts.dim() >= 3:
            pooled = acts[:, -1, :]
        else:
            pooled = acts
        return Emit(pooled.to(self.out_dtype), dtype=self.out_dtype)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class IncrementalPCA(Callback):
    """One-pass streaming PCA with an ``H x H`` accumulator per layer.

    Each microbatch adds to a per-layer running sum of rows, a running
    ``X^T X`` and a row count, all kept in fp32 on the device the activations
    arrive on. `on_layer_end` forms ``cov = sumxx / n - outer(mean, mean)``
    (normalized by ``n``), eigendecomposes it, and returns the leading
    ``n_components`` eigenvectors as ``LayerArtifact(kind="pca_basis")`` with
    payload keys ``"basis"`` ([H, k]), ``"mean"`` and ``"explained_variance"``,
    all moved to CPU.

    Pooling: with ``last_token_only=True`` (the default) a 3D ``[N, S, H]``
    activation is reduced to ``[N, H]`` by taking index ``-1`` on axis 1.
    Anything that is not 2D after that step is rejected.
    """

    phase = "read"

    def __init__(
        self,
        layers: Sequence[int] | Literal["all"] = "all",
        n_components: int = 64,
        hook: HookName = "residual_post",
        last_token_only: bool = True,
    ) -> None:
        self.target_layers = layers
        self.target_hooks = (hook,)
        self.n_components = n_components
        self.last_token_only = last_token_only
        # Per-layer accumulators, keyed by layer index.
        self._sums: dict[int, Tensor] = {}
        self._sumxx: dict[int, Tensor] = {}
        self._counts: dict[int, int] = {}

    def on_batch(
        self,
        layer_idx: int,
        hook: HookName,
        acts: Tensor,
        sample_ids: Tensor,
    ) -> BatchResult:
        """Fold one microbatch into the running statistics for ``layer_idx``.

        Raises:
            ValueError: If the activations are not 2D after optional pooling.
        """
        should_pool = self.last_token_only and acts.dim() == 3
        x = acts[:, -1, :] if should_pool else acts
        if x.dim() != 2:
            raise ValueError(
                f"IncrementalPCA expects 2D activations [N, H] after pooling, got {tuple(x.shape)}"
            )
        x = x.to(torch.float32)

        if layer_idx not in self._sums:
            # First batch for this layer: allocate zeroed accumulators.
            width = x.shape[-1]
            self._sums[layer_idx] = torch.zeros(width, dtype=torch.float32, device=x.device)
            self._sumxx[layer_idx] = torch.zeros(
                width, width, dtype=torch.float32, device=x.device
            )
            self._counts[layer_idx] = 0

        self._sums[layer_idx].add_(x.sum(dim=0))
        self._sumxx[layer_idx].add_(x.T @ x)
        self._counts[layer_idx] += int(x.shape[0])
        return None

    def on_layer_end(self, layer_idx: int) -> LayerArtifact | None:
        """Build the PCA artifact for ``layer_idx``, or ``None`` if nothing was seen."""
        if layer_idx not in self._sums:
            return None
        n_rows = self._counts[layer_idx]
        if n_rows == 0:
            return None

        mu = self._sums[layer_idx] / n_rows
        second_moment = self._sumxx[layer_idx] / n_rows
        cov = second_moment - torch.outer(mu, mu)

        # eigh yields ascending eigenvalues; PCA convention is largest first.
        evals, evecs = torch.linalg.eigh(cov)
        order = torch.argsort(evals, descending=True)
        evals, evecs = evals[order], evecs[:, order]

        k = min(self.n_components, evecs.shape[1])
        basis = evecs[:, :k].contiguous()

        # The artifact is materialized; release this layer's accumulators.
        for store in (self._sums, self._sumxx, self._counts):
            del store[layer_idx]

        payload = {
            "basis": basis.cpu(),
            "mean": mu.cpu(),
            "explained_variance": evals[:k].cpu(),
        }
        return LayerArtifact(kind="pca_basis", payload=payload)
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
class DiffOfMeans(Callback):
    """Streaming per-class activation means for labeled data.

    ``labels`` is a 1D integer tensor indexed by sample id (it is cast to
    int64 at construction). Each microbatch adds to a per-layer, per-class
    fp32 running sum and count on the device the activations arrive on.
    `on_layer_end` returns ``LayerArtifact(kind="diff_of_means")`` whose
    payload holds ``"direction"`` (``mean_1 - mean_0``, or ``None`` unless
    both classes 0 and 1 were seen), ``"means"`` and ``"counts"`` keyed by
    class; tensors are moved to CPU.

    Pooling: with ``last_token_only=True`` (the default) a 3D activation is
    reduced to 2D by taking index ``-1`` on axis 1.
    """

    phase = "read"

    def __init__(
        self,
        labels: Tensor,
        layers: Sequence[int] | Literal["all"] = "all",
        hook: HookName = "residual_post",
        last_token_only: bool = True,
    ) -> None:
        self.target_layers = layers
        self.target_hooks = (hook,)
        self.labels = labels.to(torch.int64)
        self.last_token_only = last_token_only
        # layer index -> class -> running sum / row count.
        self._sums: dict[int, dict[int, Tensor]] = {}
        self._counts: dict[int, dict[int, int]] = {}

    def on_batch(
        self,
        layer_idx: int,
        hook: HookName,
        acts: Tensor,
        sample_ids: Tensor,
    ) -> BatchResult:
        """Add one microbatch to the per-class sums for ``layer_idx``.

        Raises:
            ValueError: If the activations are not 2D after optional pooling.
        """
        should_pool = self.last_token_only and acts.dim() == 3
        x = acts[:, -1, :] if should_pool else acts
        if x.dim() != 2:
            raise ValueError(
                f"DiffOfMeans expects 2D activations [N, H] after pooling, got {tuple(x.shape)}"
            )
        x = x.to(torch.float32)

        # Look the labels up with CPU indices, then move them next to the
        # activations so the boolean masks below live on the same device.
        batch_labels = self.labels[sample_ids.detach().cpu()].to(x.device)

        layer_sums = self._sums.setdefault(layer_idx, {})
        layer_counts = self._counts.setdefault(layer_idx, {})
        for cls in torch.unique(batch_labels).tolist():
            rows = x[batch_labels == cls]
            if cls not in layer_sums:
                layer_sums[cls] = torch.zeros(
                    x.shape[-1], dtype=torch.float32, device=x.device
                )
                layer_counts[cls] = 0
            layer_sums[cls].add_(rows.sum(dim=0))
            layer_counts[cls] += int(rows.shape[0])
        return None

    def on_layer_end(self, layer_idx: int) -> LayerArtifact | None:
        """Build the diff-of-means artifact for ``layer_idx``, or ``None`` if unseen."""
        if layer_idx not in self._sums:
            return None
        layer_sums = self._sums[layer_idx]
        layer_counts = self._counts[layer_idx]
        means = {cls: total / layer_counts[cls] for cls, total in layer_sums.items()}

        # Release this layer's accumulators; the local references stay valid.
        self._sums.pop(layer_idx)
        self._counts.pop(layer_idx)

        # Binary convention: a direction exists only if both 0 and 1 were seen.
        has_both = 0 in means and 1 in means
        direction = means[1] - means[0] if has_both else None

        payload = {
            "direction": None if direction is None else direction.cpu(),
            "means": {int(cls): mean.cpu() for cls, mean in means.items()},
            "counts": {int(cls): layer_counts[cls] for cls in layer_counts},
        }
        return LayerArtifact(kind="diff_of_means", payload=payload)
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
class SteerInBasis(Callback):
    """Add a scaled, pre-computed direction to the activations.

    Each batch returns ``WriteBack(acts + alpha * direction)``, where the
    direction is taken from ``basis_artifact.payload``:

    - dict with ``"basis"``: column ``direction_idx`` of that ``[H, K]``
      matrix (the shape `IncrementalPCA` produces);
    - dict with a non-``None`` ``"direction"``: that vector as-is (what
      `DiffOfMeans` produces); ``direction_idx`` is ignored;
    - bare tensor: column ``direction_idx`` if 2D, otherwise the tensor itself.

    The direction is moved to the activations' device and dtype and relies
    on broadcasting against ``acts``. Declared with ``phase="write"``.
    """

    phase = "write"

    def __init__(
        self,
        basis_artifact: Artifact,
        direction_idx: int,
        alpha: float,
        layers: Sequence[int] | Literal["all"] = "all",
        hook: HookName = "residual_post",
    ) -> None:
        self.target_layers = layers
        self.target_hooks = (hook,)
        self.basis = basis_artifact
        self.direction_idx = direction_idx
        self.alpha = alpha

    def on_batch(
        self,
        layer_idx: int,
        hook: HookName,
        acts: Tensor,
        sample_ids: Tensor,
    ) -> BatchResult:
        """Return a `WriteBack` carrying the steered activations.

        Raises:
            KeyError: If a dict payload has neither a ``"basis"`` entry nor a
                non-``None`` ``"direction"`` entry.
        """
        from fpwap.types import WriteBack

        payload = self.basis.payload
        if not isinstance(payload, dict):
            # Bare tensor payload: pick a column from a matrix, else use as-is.
            steer = payload[:, self.direction_idx] if payload.dim() == 2 else payload
        elif "basis" in payload:
            steer = payload["basis"][:, self.direction_idx]
        elif payload.get("direction") is not None:
            steer = payload["direction"]
        else:
            raise KeyError(
                "basis_artifact payload has no 'basis' or 'direction' key"
            )

        steer = steer.to(device=acts.device, dtype=acts.dtype)
        return WriteBack(acts + self.alpha * steer)
|
fpwap/cost_model.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
"""Preflight cost model: predict per-layer latency and recommend config.
|
|
2
|
+
|
|
3
|
+
Pure arithmetic — no GPU, no model weights, no torch. CI-safe.
|
|
4
|
+
|
|
5
|
+
The cost model per layer:
|
|
6
|
+
compute = fwd_per_microbatch_s × ceil(n_samples / microbatch_size)
|
|
7
|
+
load = weight_load_s
|
|
8
|
+
|
|
9
|
+
Without prefetch: per_layer = load + compute
|
|
10
|
+
With prefetch: per_layer = max(load, compute)
|
|
11
|
+
|
|
12
|
+
Total wall = embed_s + per_layer × n_layers
|
|
13
|
+
"""
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
import math
|
|
17
|
+
from collections.abc import Sequence
|
|
18
|
+
from dataclasses import dataclass
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclass(frozen=True)
class CostModelInput:
    """Immutable inputs to the preflight cost model for one candidate config."""

    # Number of layers; the per-layer cost is multiplied by this.
    n_layers: int
    # Number of samples; with `microbatch_size` this fixes the microbatch count.
    n_samples: int
    # Tokens per sample; used only for the throughput figure.
    seq_len: int
    # Samples per microbatch.
    microbatch_size: int
    # Time to load one layer's weights, in seconds.
    weight_load_s: float
    # Forward time for one microbatch through one layer, in seconds.
    fwd_per_microbatch_s: float
    # One-off embedding time added to the total, in seconds.
    embed_s: float
    # Size of one layer's weights, in bytes.
    layer_weight_bytes: int
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@dataclass(frozen=True)
class CostModelPrediction:
    """Immutable output of the cost model for one candidate config."""

    # Predicted time per layer, in seconds.
    per_layer_s: float
    # Predicted end-to-end wall time, in seconds.
    total_wall_s: float
    # Predicted tokens per second over the whole run.
    throughput_tok_s: float
    # One of "load", "compute", or "balanced".
    bottleneck: str
    # Weight-load time relative to the per-layer time.
    load_pct: float
    # Compute time relative to the per-layer time.
    compute_pct: float
    # Total weight bytes across all layers, in units of 1e9 bytes.
    weight_io_gb: float
    # Whether the prediction assumed load/compute overlap.
    prefetch: bool
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def predict(inp: CostModelInput, *, prefetch: bool) -> CostModelPrediction:
    """Predict per-layer and total latency for one candidate configuration.

    Per layer, compute time is ``fwd_per_microbatch_s`` times the number of
    microbatches (``ceil(n_samples / microbatch_size)``) and load time is
    ``weight_load_s``. With ``prefetch`` the two overlap, so the layer costs
    the larger of them; without it they add. Total wall time is ``embed_s``
    plus the per-layer cost times ``n_layers``.

    The bottleneck label compares load to compute: ``"load"`` above a 1.2x
    ratio, ``"compute"`` below 1/1.2, otherwise ``"balanced"``. Zero compute
    time counts as an infinite ratio.

    Args:
        inp: Workload and timing inputs.
        prefetch: Whether weight loading overlaps with compute.

    Returns:
        The prediction, with ``prefetch`` echoed back.
    """
    load_s = inp.weight_load_s
    compute_s = inp.fwd_per_microbatch_s * math.ceil(inp.n_samples / inp.microbatch_size)

    per_layer_s = max(load_s, compute_s) if prefetch else load_s + compute_s
    total_wall_s = inp.embed_s + per_layer_s * inp.n_layers

    # Guard the divisions so a degenerate zero-time input yields 0.0.
    if total_wall_s > 0:
        throughput = inp.n_samples * inp.seq_len / total_wall_s
    else:
        throughput = 0.0

    if per_layer_s > 0:
        load_pct, compute_pct = load_s / per_layer_s, compute_s / per_layer_s
    else:
        load_pct = compute_pct = 0.0

    ratio = load_s / compute_s if compute_s > 0 else float("inf")
    bottleneck = "balanced"
    if ratio > 1.2:
        bottleneck = "load"
    elif ratio < 1 / 1.2:
        bottleneck = "compute"

    return CostModelPrediction(
        per_layer_s=per_layer_s,
        total_wall_s=total_wall_s,
        throughput_tok_s=throughput,
        bottleneck=bottleneck,
        load_pct=load_pct,
        compute_pct=compute_pct,
        weight_io_gb=inp.layer_weight_bytes * inp.n_layers / 1e9,
        prefetch=prefetch,
    )
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
@dataclass(frozen=True)
class Recommendation:
    """A chosen candidate configuration together with its prediction."""

    # The candidate's cost-model inputs.
    input: CostModelInput
    # The prefetch setting the candidate was evaluated with.
    prefetch: bool
    # The cost-model output for this candidate.
    prediction: CostModelPrediction
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def recommend(
    candidates: Sequence[tuple[CostModelInput, bool]],
) -> Recommendation:
    """Pick the candidate with the highest predicted throughput.

    Args:
        candidates: ``(input, prefetch)`` pairs to evaluate, in priority
            order; on equal throughput the earliest candidate wins.

    Returns:
        The winning candidate and its prediction.

    Raises:
        ValueError: If ``candidates`` is empty.
    """
    if not candidates:
        raise ValueError("candidates must be non-empty")

    scored = [
        Recommendation(input=inp, prefetch=pf, prediction=predict(inp, prefetch=pf))
        for inp, pf in candidates
    ]
    # max() keeps the first of several equal maxima, so ties go to the
    # earliest candidate.
    return max(scored, key=lambda rec: rec.prediction.throughput_tok_s)
|