fpwap 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
fpwap/__init__.py ADDED
@@ -0,0 +1,43 @@
1
+ from fpwap.callbacks.base import Callback
2
+ from fpwap.engine import (
3
+ PreloopTiming,
4
+ ProfileReport,
5
+ Result,
6
+ SetupTiming,
7
+ Sweep,
8
+ TeardownTiming,
9
+ estimate_max_microbatch,
10
+ )
11
+ from fpwap.extractor import Extractor
12
+ from fpwap.preflight import PreflightReport
13
+ from fpwap.types import (
14
+ Artifact,
15
+ ArtifactKey,
16
+ Context,
17
+ Emit,
18
+ LayerArtifact,
19
+ RaggedTensor,
20
+ ResultArtifact,
21
+ WriteBack,
22
+ )
23
+
24
# Names re-exported from fpwap's submodules (engine, callbacks, extractor,
# preflight, types); this list is what `from fpwap import *` exposes.
__all__ = [
    "Artifact",
    "ArtifactKey",
    "Callback",
    "Context",
    "Emit",
    "Extractor",
    "LayerArtifact",
    "PreflightReport",
    "PreloopTiming",
    "ProfileReport",
    "RaggedTensor",
    "Result",
    "ResultArtifact",
    "SetupTiming",
    "Sweep",
    "TeardownTiming",
    "WriteBack",
    "estimate_max_microbatch",
]
fpwap/buffer.py ADDED
@@ -0,0 +1,164 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ from pathlib import Path
5
+
6
+ import numpy as np
7
+ import torch
8
+ from torch import Tensor
9
+
10
+
11
class ResidualBuffer:
    """Inter-layer transport for the fpwap loop.

    Two modes:
    - In-memory (path=None): pinned torch tensor. Fast async D2H via the CUDA
      copy engine. Default for workloads that fit in host RAM.
    - Disk-backed (path=<file>): numpy memmap. The OS page cache manages
      residency; the full [N, seq, H] corpus never needs to fit in RAM at once.
      bf16 is stored as uint16 bit-patterns (numpy has no bf16 dtype).

    Every write detaches its input, so the buffer never records autograd
    history even if the producing tensor has ``requires_grad=True``.

    Args:
        n_samples: Number of rows stored.
        seq_len: Sequence length of each row.
        hidden: Hidden size of each row.
        dtype: Storage dtype; written values are cast to it.
        device: Device of the in-memory tensor (only used when ``path`` is None).
        path: If given, back the buffer with a memmap at this file, opened with
            numpy's ``mode="w+"`` (create or overwrite).
    """

    def __init__(
        self,
        n_samples: int,
        seq_len: int,
        hidden: int,
        dtype: torch.dtype = torch.bfloat16,
        device: torch.device | str = "cpu",
        path: Path | None = None,
    ) -> None:
        self.n_samples = n_samples
        self.seq_len = seq_len
        self.hidden = hidden
        self.dtype = dtype
        self.device = torch.device(device)
        self.path = path
        self._shape = (n_samples, seq_len, hidden)
        # Reusable pinned host buffer for CUDA -> memmap writes (disk mode only).
        self._staging: Tensor | None = None

        if path is not None:
            self._bf16_as_u16 = dtype == torch.bfloat16
            np_dtype = _torch_to_numpy(dtype)
            self._np_dtype: type | None = np_dtype
            self._mm: np.memmap | None = np.memmap(
                path, dtype=np_dtype, mode="w+", shape=self._shape
            )
            self._advise_sequential()
            self._data: Tensor | None = None
        else:
            self._bf16_as_u16 = False
            self._np_dtype = None
            self._mm = None
            # Pinning only helps (and only works) when a CUDA device exists.
            pin = self.device.type == "cpu" and torch.cuda.is_available()
            self._data = torch.zeros(
                self._shape, dtype=dtype, device=self.device, pin_memory=pin,
            )

    def _advise_sequential(self) -> None:
        """Best-effort MADV_SEQUENTIAL hint on the memmap's pages.

        CPython's ``os`` module has no ``posix_madvise``; the supported route is
        ``mmap.mmap.madvise`` (Python 3.8+, constants are platform-dependent).
        The hint is advisory only, so a missing piece or an OS refusal is
        silently ignored.
        """
        import mmap

        # numpy keeps the underlying mmap.mmap on the private `_mmap` attribute.
        raw = getattr(self._mm, "_mmap", None)
        madvise = getattr(raw, "madvise", None)
        advice = getattr(mmap, "MADV_SEQUENTIAL", None)
        if madvise is None or advice is None:
            return
        # NOTE(review): __getitem__ reads arbitrary sample_ids; confirm a
        # sequential hint still helps that access pattern.
        try:
            madvise(advice)
        except (OSError, ValueError):
            pass

    def _mm_to_tensor(self, arr: np.ndarray) -> Tensor:
        """Wrap a host ndarray read from the memmap as a tensor of ``self.dtype``.

        bf16 rows are stored as uint16 bit patterns and reinterpreted here.
        """
        t = torch.from_numpy(arr)
        if self._bf16_as_u16:
            t = t.view(torch.bfloat16)
        return t

    def _tensor_to_np(self, values: Tensor) -> np.ndarray:
        """Detached CPU ndarray of ``values`` in the memmap's storage dtype."""
        host = values.detach().to(device="cpu", dtype=self.dtype)
        if self._bf16_as_u16:
            host = host.view(torch.uint16)
        return host.numpy()

    def _to_host_array(self, values: Tensor) -> np.ndarray:
        """Convert ``values`` to a host ndarray ready to assign into the memmap.

        CUDA tensors go through the reusable pinned staging buffer. The returned
        array aliases that buffer, so it must be consumed before the next write.
        Other devices fall back to a plain CPU conversion.
        """
        if values.device.type != "cuda":
            return self._tensor_to_np(values)
        staging = self._ensure_staging(values.shape)
        staging.copy_(values.detach().to(dtype=self.dtype), non_blocking=True)
        # Wait on the source tensor's device, not merely the current device,
        # so a multi-GPU caller never reads a half-filled staging buffer.
        torch.cuda.synchronize(values.device)
        host = staging.view(torch.uint16) if self._bf16_as_u16 else staging
        return host.numpy()

    def __getitem__(self, sample_ids: Tensor) -> Tensor:
        """Gather rows ``sample_ids``.

        In-memory mode indexes the backing tensor. Disk mode copies the rows
        out of the memmap into a fresh CPU tensor.
        """
        if self._data is not None:
            return self._data[sample_ids]
        assert self._mm is not None
        ids_np = sample_ids.detach().to(device="cpu", dtype=torch.int64).numpy()
        return self._mm_to_tensor(np.asarray(self._mm[ids_np]).copy())

    def __setitem__(self, sample_ids: Tensor, values: Tensor) -> None:
        """Store ``values`` (cast to ``self.dtype``) into rows ``sample_ids``."""
        values = values.detach()
        if self._data is not None:
            self._data[sample_ids] = values.to(dtype=self.dtype, device=self.device)
            return
        assert self._mm is not None
        ids_np = sample_ids.detach().to(device="cpu", dtype=torch.int64).numpy()
        self._mm[ids_np] = self._to_host_array(values)

    def read_slice(self, start: int, stop: int) -> Tensor:
        """Return rows ``[start, stop)``.

        In-memory mode returns a view of the backing tensor (no copy). Disk
        mode returns a fresh CPU tensor copied from the memmap.
        """
        if self._data is not None:
            return self._data[start:stop]
        assert self._mm is not None
        return self._mm_to_tensor(np.asarray(self._mm[start:stop]).copy())

    def _ensure_staging(self, shape: tuple[int, ...]) -> Tensor:
        """Pinned host tensor of ``shape`` / ``self.dtype``, reallocated on shape change."""
        if self._staging is not None and self._staging.shape == shape:
            return self._staging
        self._staging = torch.zeros(shape, dtype=self.dtype, pin_memory=True)
        return self._staging

    def write_slice(self, start: int, stop: int, values: Tensor) -> None:
        """Store ``values`` (cast to ``self.dtype``) into rows ``[start, stop)``."""
        values = values.detach()
        if self._data is not None:
            # non_blocking: from a CUDA source into pinned memory this copy is
            # asynchronous; the rows are only valid after a synchronize.
            self._data[start:stop].copy_(values.to(dtype=self.dtype), non_blocking=True)
            return
        assert self._mm is not None
        self._mm[start:stop] = self._to_host_array(values)

    def flush(self) -> None:
        """Flush pending memmap writes to the file (no-op in in-memory mode)."""
        if self._mm is not None:
            self._mm.flush()

    def close(self) -> None:
        """Flush and drop the memmap, and release the pinned staging buffer."""
        if self._mm is not None:
            self._mm.flush()
            self._mm = None
        self._staging = None
145
+
146
+
147
+ _TORCH_TO_NUMPY = {
148
+ torch.float32: np.float32,
149
+ torch.float64: np.float64,
150
+ torch.float16: np.float16,
151
+ torch.bfloat16: np.uint16,
152
+ torch.int32: np.int32,
153
+ torch.int64: np.int64,
154
+ torch.int16: np.int16,
155
+ torch.int8: np.int8,
156
+ torch.uint8: np.uint8,
157
+ }
158
+
159
+
160
def _torch_to_numpy(dtype: torch.dtype) -> type:
    """Look up the numpy dtype used to store ``dtype`` in a memmap.

    Raises:
        ValueError: if ``dtype`` has no entry in ``_TORCH_TO_NUMPY``.
    """
    try:
        return _TORCH_TO_NUMPY[dtype]
    except KeyError:
        raise ValueError(f"unsupported dtype for memmap: {dtype}") from None
@@ -0,0 +1,3 @@
1
+ from fpwap.callbacks.base import Callback
2
+
3
# Public API of the fpwap.callbacks package.
__all__ = ["Callback"]
@@ -0,0 +1,51 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Sequence
4
+ from typing import Literal
5
+
6
+ import torch
7
+ from torch import Tensor
8
+
9
+ from fpwap.types import (
10
+ Artifact,
11
+ BatchResult,
12
+ Context,
13
+ HookName,
14
+ LayerArtifact,
15
+ Phase,
16
+ )
17
+
18
+
19
class Callback:
    """Base class for fpwap callbacks; every hook is a no-op by default.

    Subclasses override the class-level declarations below and any subset of
    the hooks.

    Class attributes:
        target_layers: Layer indices to attach to, or ``"all"``.
        target_hooks: Hook names to receive activations from.
        phase: Callback phase; ``"read"`` unless overridden (the built-in
            steering callback uses ``"write"``).
        needs_grad: Declared gradient requirement; ``False`` by default.
        accum_dtype: Declared accumulation dtype; ``torch.float32`` by default.
    """

    target_layers: Sequence[int] | Literal["all"] = "all"
    target_hooks: Sequence[HookName] = ("residual_post",)
    phase: Phase = "read"
    needs_grad: bool = False
    accum_dtype: torch.dtype = torch.float32

    def on_sweep_start(self, ctx: Context) -> None:
        """Sweep-start hook receiving the run ``Context``; does nothing here."""

    def on_layer_start(self, layer_idx: int) -> None:
        """Layer-start hook for ``layer_idx``; does nothing here."""

    def on_batch(
        self,
        layer_idx: int,
        hook: HookName,
        acts: Tensor,
        sample_ids: Tensor,
    ) -> BatchResult:
        """Per-microbatch hook.

        The base implementation contributes nothing and returns ``None``.
        """

    def on_layer_end(self, layer_idx: int) -> LayerArtifact | None:
        """Layer-end hook; the base produces no per-layer artifact."""

    def on_sweep_end(self) -> Artifact | None:
        """Sweep-end hook; the base produces no final artifact."""

    def checkpoint_state(self) -> bytes:
        """Serialize callback state; the base class is stateless."""
        return bytes()

    def restore_state(self, state: bytes) -> None:
        """Restore state produced by ``checkpoint_state``; nothing to restore here."""
@@ -0,0 +1,278 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Sequence
4
+ from typing import Literal
5
+
6
+ import torch
7
+ from torch import Tensor
8
+
9
+ from fpwap.callbacks.base import Callback
10
+ from fpwap.types import (
11
+ Artifact,
12
+ BatchResult,
13
+ HookName,
14
+ LayerArtifact,
15
+ )
16
+
17
+
18
class RawActivations(Callback):
    """Persist per-sample activations, pooled (`last_token_only=True`) by default.

    Returns an `Emit` each microbatch; the engine routes these into the run's
    result so `result.activations(layer, hook)` can return them concatenated
    in sample order. For datasets too large to hold in memory, swap in a
    disk-backed StorageBackend.

    Pooling matches `IncrementalPCA` / `DiffOfMeans`: only 3D `[N, S, H]`
    activations are reduced to their last position. Activations of any other
    rank are emitted as-is instead of raising. Emitted tensors are detached,
    and pooled ones own fresh storage rather than viewing the microbatch.

    Args:
        layers: Layer indices to record, or ``"all"``.
        hook: Hook name to record from.
        last_token_only: Keep only the last sequence position of 3D inputs.
        out_dtype: dtype of the emitted tensors.
    """

    phase = "read"

    def __init__(
        self,
        layers: Sequence[int] | Literal["all"] = "all",
        hook: HookName = "residual_post",
        last_token_only: bool = True,
        out_dtype: torch.dtype = torch.bfloat16,
    ) -> None:
        self.target_layers = layers
        self.target_hooks = (hook,)
        self.last_token_only = last_token_only
        self.out_dtype = out_dtype

    def on_batch(
        self,
        layer_idx: int,
        hook: HookName,
        acts: Tensor,
        sample_ids: Tensor,
    ) -> BatchResult:
        """Emit this microbatch's (optionally pooled) activations."""
        from fpwap.types import Emit

        acts = acts.detach()
        if self.last_token_only and acts.dim() == 3:
            # copy=True: a bare slice is a view that would keep the whole
            # [N, S, H] storage alive for as long as the Emit is retained.
            out = acts[:, -1, :].to(self.out_dtype, copy=True)
        else:
            out = acts.to(self.out_dtype)
        return Emit(out, dtype=self.out_dtype)
52
+
53
+
54
class IncrementalPCA(Callback):
    """Streaming PCA over the dataset. One pass, O(H²) memory per layer.

    Per layer it keeps the running mean and the centered scatter matrix
    ``M2 = sum((x - mean)(x - mean)^T)`` in fp32. Each microbatch is merged
    with Chan et al.'s pairwise update. This avoids the catastrophic
    cancellation of ``E[XX^T] - mean mean^T`` in fp32 when the activation mean
    is large relative to its spread. At `on_layer_end` it forms the population
    covariance ``M2 / n`` and returns the top-k eigenvectors as a
    `LayerArtifact(kind="pca_basis", ...)`. The engine routes it into
    `result.artifact("pca_basis", layer=i)`.

    Pooling: `last_token_only=True` by default (matches RawActivations).
    Any 3D `[N, S, H]` activation is pooled to `[N, H]` before accumulation;
    users needing other pooling should pre-pool in an upstream callback or
    subclass. Accumulators run on the execution device to avoid per-batch
    H2D; the final basis is moved to CPU in the artifact payload.

    Args:
        layers: Layer indices to fit, or ``"all"``.
        n_components: Number of leading components kept (capped at H).
        hook: Hook name to read activations from.
        last_token_only: Pool 3D activations to their last position.
    """

    phase = "read"

    def __init__(
        self,
        layers: Sequence[int] | Literal["all"] = "all",
        n_components: int = 64,
        hook: HookName = "residual_post",
        last_token_only: bool = True,
    ) -> None:
        self.target_layers = layers
        self.target_hooks = (hook,)
        self.n_components = n_components
        self.last_token_only = last_token_only
        self._means: dict[int, Tensor] = {}
        self._m2: dict[int, Tensor] = {}
        self._counts: dict[int, int] = {}

    def on_batch(
        self,
        layer_idx: int,
        hook: HookName,
        acts: Tensor,
        sample_ids: Tensor,
    ) -> BatchResult:
        """Merge this microbatch's mean and centered scatter into the layer's state.

        Raises:
            ValueError: if the (pooled) activations are not 2D.
        """
        x = acts[:, -1, :] if (self.last_token_only and acts.dim() == 3) else acts
        if x.dim() != 2:
            raise ValueError(
                f"IncrementalPCA expects 2D activations [N, H] after pooling, "
                f"got {tuple(x.shape)}"
            )
        m = int(x.shape[0])
        if m == 0:
            # Nothing to merge; an empty batch would only yield a NaN mean.
            return None
        x = x.detach().to(torch.float32)
        batch_mean = x.mean(dim=0)
        centered = x - batch_mean
        batch_m2 = centered.T @ centered

        n = self._counts.get(layer_idx, 0)
        if n == 0:
            self._means[layer_idx] = batch_mean
            self._m2[layer_idx] = batch_m2
            self._counts[layer_idx] = m
            return None

        # Chan et al. pairwise merge of (n, mean_a, M2_a) with (m, mean_b, M2_b).
        total = n + m
        delta = batch_mean - self._means[layer_idx]
        self._means[layer_idx].add_(delta * (m / total))
        self._m2[layer_idx].add_(batch_m2 + torch.outer(delta, delta) * (n * m / total))
        self._counts[layer_idx] = total
        return None

    def on_layer_end(self, layer_idx: int) -> LayerArtifact | None:
        """Return the layer's PCA basis, or ``None`` if no samples were seen.

        Per-layer accumulators are dropped either way.
        """
        n = self._counts.pop(layer_idx, 0)
        mean = self._means.pop(layer_idx, None)
        m2 = self._m2.pop(layer_idx, None)
        if n == 0 or mean is None or m2 is None:
            return None
        cov = m2 / n
        # eigh returns ascending eigenvalues; flip for the descending PCA convention.
        eigvals, eigvecs = torch.linalg.eigh(cov)
        eigvals = eigvals.flip(0)
        eigvecs = eigvecs.flip(1)
        k = min(self.n_components, eigvecs.shape[1])
        basis = eigvecs[:, :k].contiguous()
        # The covariance is PSD in exact arithmetic; clamp round-off negatives.
        explained = eigvals[:k].clamp_min(0.0)
        return LayerArtifact(
            kind="pca_basis",
            payload={
                "basis": basis.cpu(),
                "mean": mean.cpu(),
                "explained_variance": explained.cpu(),
            },
        )
139
+
140
+
141
class DiffOfMeans(Callback):
    """Streaming per-class activation means, aimed at binary labels.

    ``labels`` is a 1D integer tensor indexed by sample id, so the callback
    never needs access to dataset items. For every label seen in a layer, fp32
    sums and counts are accumulated on the activations' device. At
    `on_layer_end` the layer yields a `LayerArtifact(kind="diff_of_means")`.
    Its payload holds ``direction = mean_1 - mean_0`` (``None`` unless both
    labels 0 and 1 occurred), every class mean, and every class count. The
    direction is a common probing target ("does this layer's residual stream
    encode label L?").

    Pooling: `last_token_only=True` by default, matching RawActivations; only
    3D inputs are pooled.

    Args:
        labels: Per-sample integer labels, indexed by sample id.
        layers: Layer indices to accumulate, or ``"all"``.
        hook: Hook name to read activations from.
        last_token_only: Pool 3D activations to their last position.
    """

    phase = "read"

    def __init__(
        self,
        labels: Tensor,
        layers: Sequence[int] | Literal["all"] = "all",
        hook: HookName = "residual_post",
        last_token_only: bool = True,
    ) -> None:
        self.target_layers = layers
        self.target_hooks = (hook,)
        self.labels = labels.to(torch.int64)
        self.last_token_only = last_token_only
        self._sums: dict[int, dict[int, Tensor]] = {}
        self._counts: dict[int, dict[int, int]] = {}

    def on_batch(
        self,
        layer_idx: int,
        hook: HookName,
        acts: Tensor,
        sample_ids: Tensor,
    ) -> BatchResult:
        """Add this microbatch's rows into the per-label sums and counts.

        Raises:
            ValueError: if the (pooled) activations are not 2D.
        """
        should_pool = self.last_token_only and acts.dim() == 3
        feats = acts[:, -1, :] if should_pool else acts
        if feats.dim() != 2:
            raise ValueError(
                f"DiffOfMeans expects 2D activations [N, H] after pooling, "
                f"got {tuple(feats.shape)}"
            )
        feats = feats.to(torch.float32)
        # Sample ids go to CPU for the lookup into the user-provided labels,
        # then the labels follow the activations for masking.
        batch_labels = self.labels[sample_ids.detach().cpu()].to(feats.device)
        layer_sums = self._sums.setdefault(layer_idx, {})
        layer_counts = self._counts.setdefault(layer_idx, {})
        for label in torch.unique(batch_labels).tolist():
            rows = feats[batch_labels == label]
            if label not in layer_sums:
                layer_sums[label] = torch.zeros(
                    feats.shape[-1], dtype=torch.float32, device=feats.device
                )
                layer_counts[label] = 0
            layer_sums[label] += rows.sum(dim=0)
            layer_counts[label] += int(rows.shape[0])
        return None

    def on_layer_end(self, layer_idx: int) -> LayerArtifact | None:
        """Finalize class means for ``layer_idx`` and release its accumulators."""
        layer_sums = self._sums.pop(layer_idx, None)
        if layer_sums is None:
            return None
        layer_counts = self._counts.pop(layer_idx)
        means = {label: total / layer_counts[label] for label, total in layer_sums.items()}
        # Binary convention: a direction exists only when labels 0 and 1 both occurred.
        direction = means[1] - means[0] if (0 in means and 1 in means) else None
        payload = {
            "direction": None if direction is None else direction.cpu(),
            "means": {int(label): mean.cpu() for label, mean in means.items()},
            "counts": {int(label): layer_counts[label] for label in layer_counts},
        }
        return LayerArtifact(kind="diff_of_means", payload=payload)
222
+
223
+
224
class SteerInBasis(Callback):
    """Add ``alpha`` times a fixed direction to the activations (phase "write").

    The direction is column ``direction_idx`` of ``basis_artifact.payload["basis"]``.
    That basis is expected to be ``[H, n_components]``, as produced by
    `IncrementalPCA`. Alternatively the direction is the 1D
    ``payload["direction"]`` (as produced by `DiffOfMeans`), or the payload
    itself when it is a tensor. The direction broadcasts over the batch and
    sequence dims. Because the callback lives in `phase="write"`, the returned
    `WriteBack` replaces the residual that feeds the next layer (or the
    buffer, if targeted at `residual_post`).

    Args:
        basis_artifact: Artifact whose payload supplies the direction.
        direction_idx: Column of a 2D basis to use (ignored for 1D directions).
        alpha: Scale applied to the direction.
        layers: Layer indices to steer, or ``"all"``.
        hook: Hook name to intervene on.
    """

    phase = "write"

    def __init__(
        self,
        basis_artifact: Artifact,
        direction_idx: int,
        alpha: float,
        layers: Sequence[int] | Literal["all"] = "all",
        hook: HookName = "residual_post",
    ) -> None:
        self.target_layers = layers
        self.target_hooks = (hook,)
        self.basis = basis_artifact
        self.direction_idx = direction_idx
        self.alpha = alpha

    def _direction_vector(self) -> Tensor:
        """Resolve the steering direction from the artifact payload.

        Raises:
            KeyError: if a dict payload has neither a ``"basis"`` nor a
                non-None ``"direction"`` entry.
        """
        payload = self.basis.payload
        if not isinstance(payload, dict):
            return payload[:, self.direction_idx] if payload.dim() == 2 else payload
        if "basis" in payload:
            return payload["basis"][:, self.direction_idx]
        if "direction" in payload and payload["direction"] is not None:
            return payload["direction"]
        raise KeyError(
            "basis_artifact payload has no 'basis' or 'direction' key"
        )

    def on_batch(
        self,
        layer_idx: int,
        hook: HookName,
        acts: Tensor,
        sample_ids: Tensor,
    ) -> BatchResult:
        """Return ``acts + alpha * direction`` as a `WriteBack`."""
        from fpwap.types import WriteBack

        steer = self._direction_vector().to(device=acts.device, dtype=acts.dtype)
        return WriteBack(acts + self.alpha * steer)
fpwap/cost_model.py ADDED
@@ -0,0 +1,108 @@
1
+ """Preflight cost model: predict per-layer latency and recommend config.
2
+
3
+ Pure arithmetic — no GPU, no model weights, no torch. CI-safe.
4
+
5
+ The cost model per layer:
6
+ compute = fwd_per_microbatch_s × ceil(n_samples / microbatch_size)
7
+ load = weight_load_s
8
+
9
+ Without prefetch: per_layer = load + compute
10
+ With prefetch: per_layer = max(load, compute)
11
+
12
+ Total wall = embed_s + per_layer × n_layers
13
+ """
14
+ from __future__ import annotations
15
+
16
+ import math
17
+ from collections.abc import Sequence
18
+ from dataclasses import dataclass
19
+
20
+
21
@dataclass(frozen=True)
class CostModelInput:
    """Workload shape and measured timings fed to `predict`.

    Fields ending in ``_s`` are durations (seconds, per the naming and the
    module's latency model).
    """

    n_layers: int  # number of layers in the sweep
    n_samples: int  # dataset size
    seq_len: int  # tokens per sample; only used for throughput
    microbatch_size: int  # samples per forward microbatch; must be > 0
    weight_load_s: float  # time to load one layer's weights
    fwd_per_microbatch_s: float  # one microbatch through one layer
    embed_s: float  # one-off cost added once to the total wall time
    layer_weight_bytes: int  # per-layer weight size, used for weight_io_gb
31
+
32
+
33
@dataclass(frozen=True)
class CostModelPrediction:
    """Output of `predict` for one (input, prefetch) configuration."""

    per_layer_s: float  # max(load, compute) with prefetch, else load + compute
    total_wall_s: float  # embed_s + per_layer_s * n_layers
    throughput_tok_s: float  # n_samples * seq_len / total_wall_s (0.0 if not > 0)
    bottleneck: str  # "load", "compute", or "balanced"
    load_pct: float  # load / per_layer_s as a fraction, not multiplied by 100
    compute_pct: float  # compute / per_layer_s as a fraction, not multiplied by 100
    weight_io_gb: float  # layer_weight_bytes * n_layers / 1e9
    prefetch: bool  # whether this prediction assumed prefetch
43
+
44
+
45
def predict(inp: CostModelInput, *, prefetch: bool) -> CostModelPrediction:
    """Predict per-layer latency, total wall time and throughput for one config.

    Args:
        inp: Workload shape and measured timings.
        prefetch: If True, weight loading is modeled as fully overlapping
            compute (``max``); otherwise the two are serialized (sum).

    Returns:
        A `CostModelPrediction`. ``bottleneck`` is "load" when load time
        exceeds compute by more than 20%, and "compute" when compute exceeds
        load by more than 20%. Otherwise it is "balanced", which includes the
        case where both are zero.

    Raises:
        ValueError: if ``inp.microbatch_size`` is not positive.
    """
    if inp.microbatch_size <= 0:
        raise ValueError(
            f"microbatch_size must be positive, got {inp.microbatch_size}"
        )

    n_microbatches = math.ceil(inp.n_samples / inp.microbatch_size)
    compute_s = inp.fwd_per_microbatch_s * n_microbatches
    load_s = inp.weight_load_s

    per_layer_s = max(load_s, compute_s) if prefetch else load_s + compute_s

    total_wall_s = inp.embed_s + per_layer_s * inp.n_layers
    total_tokens = inp.n_samples * inp.seq_len
    throughput = total_tokens / total_wall_s if total_wall_s > 0 else 0.0

    weight_io_gb = inp.layer_weight_bytes * inp.n_layers / 1e9

    # Shares of the per-layer time as fractions (not percentages).
    if per_layer_s > 0:
        load_pct = load_s / per_layer_s
        compute_pct = compute_s / per_layer_s
    else:
        load_pct = 0.0
        compute_pct = 0.0

    # load/compute ratio; anything within a 1.2x band of 1.0 is "balanced".
    if compute_s > 0:
        ratio = load_s / compute_s
    elif load_s > 0:
        ratio = math.inf  # pure load, no compute
    else:
        ratio = 1.0  # neither side costs anything
    if ratio > 1.2:
        bottleneck = "load"
    elif ratio < 1 / 1.2:
        bottleneck = "compute"
    else:
        bottleneck = "balanced"

    return CostModelPrediction(
        per_layer_s=per_layer_s,
        total_wall_s=total_wall_s,
        throughput_tok_s=throughput,
        bottleneck=bottleneck,
        load_pct=load_pct,
        compute_pct=compute_pct,
        weight_io_gb=weight_io_gb,
        prefetch=prefetch,
    )
86
+
87
+
88
@dataclass(frozen=True)
class Recommendation:
    """Winning candidate returned by `recommend`, with its prediction."""

    input: CostModelInput  # the chosen cost-model input
    prefetch: bool  # the chosen prefetch setting
    prediction: CostModelPrediction  # predict(input, prefetch=prefetch)
93
+
94
+
95
def recommend(
    candidates: Sequence[tuple[CostModelInput, bool]],
) -> Recommendation:
    """Return the (input, prefetch) candidate with the highest predicted throughput.

    On ties the earliest candidate wins.

    Raises:
        ValueError: if ``candidates`` is empty.
    """
    if not candidates:
        raise ValueError("candidates must be non-empty")

    scored = [
        Recommendation(input=inp, prefetch=pf, prediction=predict(inp, prefetch=pf))
        for inp, pf in candidates
    ]
    # max() only replaces its current pick on a strictly greater key, so the
    # first of equally good candidates is kept.
    winner = max(
        scored, key=lambda rec: rec.prediction.throughput_tok_s, default=None
    )
    assert winner is not None
    return winner
+ return best