AlphaPFN 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
alphapfn/__init__.py ADDED
@@ -0,0 +1,17 @@
1
"""Fast entropy-search acquisition functions via in-context learning (AlphaPFN).

Typical use::

    from alphapfn import AlphaPFN

    model = AlphaPFN.from_pretrained(acquisition="JES")
    model.fit(train_X, train_Y)
    acq = model(X_test)  # X_test: (b, 1, d) candidates -> (b,) values
"""
from alphapfn.api import ALLOWED_ACQUISITIONS, AlphaPFN, AlphaPFNPosteriorMean

__version__ = "0.0.1"

# Names exported by `from alphapfn import *`.
__all__ = ["AlphaPFN", "AlphaPFNPosteriorMean", "ALLOWED_ACQUISITIONS", "__version__"]
alphapfn/api.py ADDED
@@ -0,0 +1,437 @@
1
+ """Public AlphaPFN model API.
2
+
3
+ The model itself is the acquisition function: `model(X_test)` returns
4
+ scalar acquisition values that an outer optimizer can maximize. When
5
+ botorch is installed AlphaPFN inherits `AcquisitionFunction` so it
6
+ plugs straight into `botorch.optim.optimize_acqf`. Without botorch,
7
+ `AcquisitionFunction` falls back to `nn.Module` and the decorator
8
+ becomes a no-op — the model is still callable; the user provides
9
+ their own optimizer.
10
+ """
11
+ from __future__ import annotations
12
+
13
+ from typing import Optional
14
+
15
+ import torch
16
+ from torch import Tensor, nn
17
+
18
+ from alphapfn.loader import load_predictor, ALLOWED_VERSIONS
19
+
20
try:
    from botorch.acquisition.acquisition import AcquisitionFunction
    from botorch.utils.transforms import t_batch_mode_transform
    _BOTORCH_AVAILABLE = True
except ImportError:
    import functools

    AcquisitionFunction = nn.Module  # type: ignore[misc,assignment]
    _BOTORCH_AVAILABLE = False

    def t_batch_mode_transform(expected_q=None, assert_output_shape=True):  # type: ignore[no-redef]
        """Minimal stand-in for botorch's `t_batch_mode_transform`.

        The wrapped method receives X of shape (b, q, d). A 2-D input
        (q, d) gets a leading batch dim. Inputs with fewer than 2 dims
        raise ValueError. When `expected_q` is given and q differs,
        AssertionError is raised (matching botorch). Non-tensor X is
        passed through untouched. `assert_output_shape` is accepted only
        for signature compatibility and is not enforced here.
        """
        def _decorator(fn):
            # functools.wraps keeps the wrapped method's __name__/__doc__,
            # so `help(AlphaPFN.forward)` still shows the real docstring.
            @functools.wraps(fn)
            def _wrapper(self, X, *args, **kwargs):
                if not isinstance(X, torch.Tensor):
                    return fn(self, X, *args, **kwargs)
                if X.dim() < 2:
                    raise ValueError(
                        f"{type(self).__name__} requires X to have at least 2 "
                        f"dimensions, but received X with {X.dim()} dimensions."
                    )
                if expected_q is not None and X.shape[-2] != expected_q:
                    raise AssertionError(
                        f"Expected X to be `batch_shape x q={expected_q} x d`, "
                        f"but got X with shape {tuple(X.shape)}."
                    )
                if X.dim() == 2:
                    X = X.unsqueeze(0)
                return fn(self, X, *args, **kwargs)
            return _wrapper
        return _decorator
52
+
53
+
54
# Acquisition names accepted by `AlphaPFN.from_pretrained(acquisition=...)`.
ALLOWED_ACQUISITIONS = ("EI", "UCB", "PES", "MES", "JES")

# Acquisitions served by a dedicated, separately-loaded head, mapped to the
# predictor name handed to `load_predictor`. EI/UCB are computed from the
# PPD model instead and therefore do not appear here.
_DIRECT_HEADS = {name: name.lower() for name in ("PES", "MES", "JES")}
56
+
57
+
58
class AlphaPFN(AcquisitionFunction):
    """Acquisition-function-shaped wrapper around the pretrained PFN.

    Construct via `AlphaPFN.from_pretrained(acquisition=..., version=...)`,
    call `fit(train_X, train_Y)` once, then call the instance on candidate
    tensors of shape `(b, 1, d)` (or `(1, d)`) to get acquisition values
    of shape `(b,)`.
    """

    # Error text for a `base_model` request when no PPD model was loaded.
    # Shared by the `base_model` property and `__getattr__` below.
    _NO_BASE_MODEL_MSG = (
        "No base model was loaded. Pass `load_base_model=True` to "
        "from_pretrained()."
    )

    def __init__(
        self,
        ppd_model: Optional[nn.Module],
        head_model: Optional[nn.Module],
        acquisition: Optional[str],
        ucb_beta: float = 2.0,
        is_base_model: bool = False,
        strict: bool = True,
        _registered_model: Optional[nn.Module] = None,
    ) -> None:
        """Store already-loaded predictors and configuration.

        Prefer `from_pretrained`, which validates `acquisition`/`version`
        and loads the right predictors. No loading or validation of
        `acquisition` happens here.

        Args:
            ppd_model: PPD (base) model, or None if it was not loaded.
            head_model: Direct PES/MES/JES head, or None.
            acquisition: One of ALLOWED_ACQUISITIONS, or None for a
                base-model-only instance (forward() then raises).
            ucb_beta: Number of standard deviations used by UCB
                (see `_ucb_from_ppd`).
            is_base_model: Whether the base model was explicitly requested.
            strict: Enforce the input-contract checks (see `from_pretrained`).
            _registered_model: Module passed as `model=` to botorch's
                `AcquisitionFunction`. Defaults to `ppd_model`, else
                `head_model`.
        """
        if _BOTORCH_AVAILABLE:
            # Explicit `is not None` checks: `a or b` would consult nn.Module
            # truthiness, which container modules (e.g. nn.Sequential) derive
            # from __len__.
            registered = _registered_model
            if registered is None:
                registered = ppd_model if ppd_model is not None else head_model
            super().__init__(model=registered)
        else:
            # The nn.Module fallback rejects unexpected keyword arguments.
            super().__init__()
        self._ppd_model = ppd_model
        self._head_model = head_model
        self._acquisition = acquisition
        self._is_base_model = bool(is_base_model)
        self._ucb_beta = float(ucb_beta)
        self._strict = bool(strict)
        # Training cache (set by fit())
        self._train_X: Optional[Tensor] = None
        self._train_Y: Optional[Tensor] = None
        self._fitted: bool = False

    # ------------------------------------------------------------------
    # Construction
    # ------------------------------------------------------------------

    @classmethod
    def from_pretrained(
        cls,
        acquisition: Optional[str] = None,
        version: str = "v1",
        *,
        load_base_model: bool = False,
        ucb_beta: float = 2.0,
        strict: bool = True,
    ) -> "AlphaPFN":
        """Load a pretrained AlphaPFN.

        Pass at least one of:
          - acquisition ∈ {"EI", "UCB", "PES", "MES", "JES"}
              → callable model returning acquisition values via forward().
          - load_base_model=True
              → loads the PPD base model. forward() is not implemented;
                access the underlying model via `acqf.base_model`.

        Loading rules:
          - EI / UCB always implicitly load the base (PPD) model, because
            the acquisition is computed from PPD logits. Passing
            `load_base_model=True` alongside is redundant but permitted.
          - PES / MES / JES load a separately-trained head. The base
            model is loaded only if `load_base_model=True`.
          - Without `acquisition`, `load_base_model=True` is required.

        Input contract:
          - Maximization. AlphaPFN scores points for a *maximization*
            objective (f_best is train_Y.max(), and EI/UCB/PES/MES/JES all
            return higher values for "better" inputs). For minimization,
            pass -f(X) to fit and negate the result. The maximization
            assumption is NOT checked at runtime.
          - X must lie in [0, 1]^d. Rescale your search space first.
          - `fit` standardizes train_Y internally by default. With
            `fit(..., standardize_y=False)`, train_Y must already be
            approximately standardized (|mean(y)| <= 0.5, |std(y) - 1| <= 0.5).

        With `strict=True` (default), the unit-cube check on X runs in
        `fit`, `forward` and `posterior_mean`. The standardization check
        on train_Y runs in `fit` only when `standardize_y=False`.
        Violations (including NaN inputs) raise ValueError. Pass
        `strict=False` if you intentionally use out-of-cube inputs or
        non-standard targets.

        Raises:
            ValueError: if neither option is given, or `acquisition` /
                `version` is not supported.
        """
        if acquisition is None and not load_base_model:
            raise ValueError(
                "Specify acquisition=... or load_base_model=True (or both)."
            )
        if acquisition is not None and acquisition not in ALLOWED_ACQUISITIONS:
            raise ValueError(
                f"acquisition={acquisition!r} is not supported. "
                f"Allowed: {ALLOWED_ACQUISITIONS}"
            )
        if version not in ALLOWED_VERSIONS:
            raise ValueError(
                f"version={version!r} is not supported. "
                f"Allowed: {sorted(ALLOWED_VERSIONS)}"
            )

        # PPD/base model is loaded if:
        #   - user asked via load_base_model=True, OR
        #   - acquisition needs it (EI/UCB compute on top of PPD).
        ppd_needed = load_base_model or acquisition in {"EI", "UCB"}
        ppd_model = load_predictor("ppd", version=version) if ppd_needed else None

        head_model = None
        if acquisition in _DIRECT_HEADS:
            head_model = load_predictor(_DIRECT_HEADS[acquisition], version=version)

        # For PES/MES/JES without load_base_model, _ppd_model stays None and
        # forward() uses the head only. botorch still needs an `nn.Module`
        # to register, so fall back to the head model in that case.
        registered_model = ppd_model if ppd_model is not None else head_model
        return cls(
            ppd_model=ppd_model,
            head_model=head_model,
            acquisition=acquisition,
            ucb_beta=ucb_beta,
            is_base_model=load_base_model,
            strict=strict,
            _registered_model=registered_model,
        )

    @property
    def base_model(self) -> nn.Module:
        """Access the loaded PPD/base model directly.

        Useful when `load_base_model=True`: the conditioning interface
        (x*, f*) is not yet exposed on AlphaPFN, so callers have to
        invoke the underlying model themselves for now.

        Raises:
            AttributeError: if no base model was loaded (PES/MES/JES
                without load_base_model=True).
        """
        if self._ppd_model is None:
            raise AttributeError(self._NO_BASE_MODEL_MSG)
        return self._ppd_model

    def __getattr__(self, name: str):
        # When a property raises AttributeError, Python falls back to
        # __getattr__. nn.Module.__getattr__ would then replace the
        # informative `base_model` error with a generic "has no attribute"
        # one, so re-raise the informative message for that name only.
        if name == "base_model":
            raise AttributeError(self._NO_BASE_MODEL_MSG)
        return super().__getattr__(name)

    # ------------------------------------------------------------------
    # Input-contract checks (strict-mode)
    # ------------------------------------------------------------------

    _X_BOUND_EPS = 1e-6
    _Y_MEAN_TOL = 0.5
    _Y_STD_TOL = 0.5

    def _check_X_in_cube(self, X: Tensor, *, where: str) -> None:
        """Raise ValueError unless X ⊂ [0, 1]^d (within `_X_BOUND_EPS`).

        No-op when `strict=False` or X is empty (min/max are undefined on
        an empty tensor). NaN is rejected explicitly because it compares
        False against both bounds and would otherwise slip through.
        """
        if not self._strict or X.numel() == 0:
            return
        if bool(torch.isnan(X).any()):
            raise ValueError(
                f"{where}: X contains NaN values. The pretrained model "
                f"assumes inputs are normalized to the unit cube. If you "
                f"intentionally pass such inputs, construct AlphaPFN with "
                f"strict=False."
            )
        lo = float(X.min().item())
        hi = float(X.max().item())
        if lo < -self._X_BOUND_EPS or hi > 1.0 + self._X_BOUND_EPS:
            raise ValueError(
                f"{where}: X must lie in [0, 1]^d but got min={lo:.4g}, "
                f"max={hi:.4g}. The pretrained model assumes inputs "
                f"are normalized to the unit cube. If you intentionally "
                f"pass out-of-cube inputs, construct AlphaPFN with "
                f"strict=False."
            )

    def _check_y_standardized(self, y: Tensor) -> None:
        """Raise ValueError unless y is finite and approximately standardized.

        No-op when `strict=False`. The mean/std tolerance check is skipped
        for fewer than 2 points, where std() is undefined.
        """
        if not self._strict:
            return
        if not bool(torch.isfinite(y).all()):
            raise ValueError(
                "fit: train_Y contains non-finite values (NaN/inf). The "
                "pretrained model assumes roughly-standard targets. If you "
                "intentionally pass non-standard targets, construct "
                "AlphaPFN with strict=False."
            )
        if y.numel() < 2:
            return  # std() undefined / unstable with < 2 points
        mean = float(y.mean().item())
        std = float(y.std().item())
        if abs(mean) > self._Y_MEAN_TOL or abs(std - 1.0) > self._Y_STD_TOL:
            raise ValueError(
                f"fit: train_Y must be approximately standardized "
                f"(|mean| <= {self._Y_MEAN_TOL}, |std-1| <= {self._Y_STD_TOL}); "
                f"got mean={mean:.4g}, std={std:.4g}. The pretrained model "
                f"assumes roughly-standard targets. Standardize before "
                f"calling fit, e.g. `y = (y - y.mean()) / (y.std() + 1e-8)`. "
                f"If you intentionally pass non-standard targets, construct "
                f"AlphaPFN with strict=False."
            )

    # ------------------------------------------------------------------
    # Fit (one-shot; stores the train data; no real "fitting" happens)
    # ------------------------------------------------------------------

    def fit(
        self,
        train_X: Tensor,
        train_Y: Tensor,
        *,
        standardize_y: bool = True,
    ) -> "AlphaPFN":
        """Provide training context. Required before forward().

        Args:
            train_X: (n, d) inputs in [0, 1]^d.
            train_Y: (n,) or (n, 1) targets in the original scale.
            standardize_y: When True (default), targets with >= 2 points
                are standardized internally. The pretrained model assumes
                ~N(0, 1) targets, and acquisition argmax is invariant under
                affine y-rescaling, so callers don't need to standardize
                themselves. Pass `standardize_y=False` if you have already
                standardized; the strict-mode contract check then applies.

        Returns:
            self, for chaining.

        Raises:
            ValueError: on shape mismatch, or a strict-mode contract violation.
        """
        if train_X.ndim != 2:
            raise ValueError(
                f"train_X must be (n, d); got shape {tuple(train_X.shape)}"
            )
        if train_Y.ndim == 2 and train_Y.shape[-1] == 1:
            train_Y = train_Y.squeeze(-1)
        if train_Y.ndim != 1:
            raise ValueError(
                f"train_Y must be (n,) or (n, 1); got shape {tuple(train_Y.shape)}"
            )
        if train_X.shape[0] != train_Y.shape[0]:
            raise ValueError(
                f"train_X / train_Y leading-dim mismatch: "
                f"{train_X.shape[0]} vs {train_Y.shape[0]}"
            )
        self._check_X_in_cube(train_X, where="fit")
        if standardize_y:
            if train_Y.numel() >= 2:
                std = train_Y.std()
                # 1e-8 guards against division by zero for constant targets.
                train_Y = (train_Y - train_Y.mean()) / (std + 1e-8)
        else:
            self._check_y_standardized(train_Y)
        self._train_X = train_X.detach()
        self._train_Y = train_Y.detach()
        self._fitted = True
        return self

    # ------------------------------------------------------------------
    # Forward / __call__: returns acquisition values
    # ------------------------------------------------------------------

    def _prepare_inputs(self, X_test: Tensor) -> tuple[Tensor, Tensor, Tensor]:
        """Build float32 PFN inputs on X_test's device.

        Returns:
            (x_train, y_train, x_test) with shapes (n, 1, d), (n, 1) and
            (m, 1, d). The inserted axis 1 is a singleton batch dim.
        """
        assert self._train_X is not None and self._train_Y is not None
        device = X_test.device
        # Move the cached context to the query's device (it may have been
        # fitted on another one); cast straight to float32 to avoid a lossy
        # round trip through X_test's dtype.
        x_train = self._train_X.to(device=device, dtype=torch.float32).unsqueeze(1)
        y_train = self._train_Y.to(device=device, dtype=torch.float32).unsqueeze(1)
        x_test = X_test.float().unsqueeze(1)
        return x_train, y_train, x_test

    def _run_ppd(self, X_test: Tensor) -> Tensor:
        """Return PPD logits at X_test (m, d), shape (m, num_bars)."""
        x_train, y_train, x_test = self._prepare_inputs(X_test)
        gp_dim = x_train.shape[-1]
        # NaN style tensors: no style conditioning is supplied.
        nan_style = torch.full(
            (1, gp_dim, 1), float("nan"), dtype=torch.float32, device=X_test.device
        )
        nan_y_style = torch.full(
            (1, 1), float("nan"), dtype=torch.float32, device=X_test.device
        )
        logits = self._ppd_model(
            x=x_train,
            y=y_train,
            test_x=x_test,
            style=nan_style,
            y_style=nan_y_style,
        )
        return logits.squeeze(1)  # (m, num_bars)

    def _run_head(self, X_test: Tensor) -> Tensor:
        """Return direct-head scalar acquisition values at X_test, shape (m,)."""
        assert self._head_model is not None
        x_train, y_train, x_test = self._prepare_inputs(X_test)
        logits = self._head_model(x_train, y_train, x_test)
        return self._head_model.criterion.mean(logits).flatten()

    def posterior_mean(self, X: Tensor) -> Tensor:
        """Return the PPD posterior predictive mean E[f | D] at X (m, d).

        Requires the base (PPD) model; pass `load_base_model=True` to
        `from_pretrained`. Pair with `AlphaPFNPosteriorMean` and
        `botorch.optim.optimize_acqf` to find x̂ = argmax E[f | D]
        (the predicted optimizer, useful for inference-regret reporting).

        Raises:
            RuntimeError: if no base model was loaded or fit() was not called.
        """
        if self._ppd_model is None:
            raise RuntimeError(
                "posterior_mean requires the base model; pass "
                "load_base_model=True to from_pretrained()."
            )
        if not self._fitted:
            raise RuntimeError("call .fit(train_X, train_Y) before posterior_mean()")
        self._check_X_in_cube(X, where="posterior_mean")
        logits = self._run_ppd(X)
        return self._ppd_model.criterion.mean(logits).flatten()

    def _ei_from_ppd(self, X_test: Tensor) -> Tensor:
        """Expected improvement over the best (standardized) training target."""
        assert self._train_Y is not None
        f_best = float(self._train_Y.max().item())
        logits = self._run_ppd(X_test)
        return self._ppd_model.criterion.ei(logits, f_best).flatten()

    def _ucb_from_ppd(self, X_test: Tensor) -> Tensor:
        """UCB as the PPD quantile at p = Phi(ucb_beta).

        For a Gaussian this equals mean + ucb_beta * std (a beta-sigma
        bound). NOTE: botorch's `UpperConfidenceBound` scales by
        sqrt(beta) instead, so the two `beta`s are not interchangeable.
        """
        import math
        p = 0.5 * (1.0 + math.erf(self._ucb_beta / math.sqrt(2.0)))
        logits = self._run_ppd(X_test)
        return self._ppd_model.criterion.icdf(logits, p).flatten()

    @t_batch_mode_transform(expected_q=1)
    def forward(self, X: Tensor) -> Tensor:
        """Return acquisition values at X.

        Input contract (botorch-style, applied with or without botorch):
            X has shape (b, q=1, d), or (q=1, d), in which case the
            decorator prepends a leading batch dim. The decorator
            asserts q == 1.

        Output: (b,) for scalar acquisitions (EI/UCB/PES/MES/JES).

        When only `load_base_model=True` was passed (no acquisition),
        forward() is not implemented. The base model has a different
        interface (conditioning on optimizer x* and optimum f*) which is
        not yet designed. Use `model.base_model` directly for now.

        Raises:
            RuntimeError: if fit() was not called.
            NotImplementedError: for a base-model-only instance.
        """
        if not self._fitted:
            raise RuntimeError("call .fit(train_X, train_Y) before forward()")

        # After the decorator, X has shape (b, 1, d). Collapse the q dim.
        X = X.squeeze(-2)
        self._check_X_in_cube(X, where="forward")

        acq = self._acquisition
        if acq is None:
            raise NotImplementedError(
                "This model was loaded with load_base_model=True. "
                "The base-model forward interface (with x*/f* conditioning) "
                "is not designed yet. Access the underlying model via "
                "`acqf.base_model` for now."
            )
        if acq == "EI":
            return self._ei_from_ppd(X)
        if acq == "UCB":
            return self._ucb_from_ppd(X)
        if acq in {"PES", "MES", "JES"}:
            return self._run_head(X)
        raise ValueError(f"Unknown acquisition: {acq!r}")
404
+
405
+
406
class AlphaPFNPosteriorMean(AcquisitionFunction):
    """Scalar acquisition returning `pfn.posterior_mean(X)`.

    Drop-in for `botorch.optim.optimize_acqf` to find
    x̂ = argmax E[f | D], i.e. the model's predicted optimizer (useful
    for reporting inference regret in BO benchmarks).

    Requires the wrapped `AlphaPFN` to have been loaded with
    `load_base_model=True` and `.fit(...)` called.

    Example:
        from alphapfn import AlphaPFN, AlphaPFNPosteriorMean
        from botorch.optim import optimize_acqf

        pfn = AlphaPFN.from_pretrained(acquisition="JES", load_base_model=True)
        pfn.fit(X, y_std)
        x_hat, _ = optimize_acqf(AlphaPFNPosteriorMean(pfn),
                                 bounds=bounds, q=1,
                                 num_restarts=5, raw_samples=64)
    """

    def __init__(self, pfn: "AlphaPFN"):
        """Wrap a fitted `AlphaPFN` that has its base (PPD) model loaded.

        Raises:
            RuntimeError: if `pfn` was loaded without the base model.
        """
        if pfn._ppd_model is None:
            raise RuntimeError(
                "AlphaPFNPosteriorMean requires the base PPD model; pass "
                "load_base_model=True to AlphaPFN.from_pretrained()."
            )
        # Only botorch's AcquisitionFunction accepts `model=`. The nn.Module
        # fallback used when botorch is missing raises TypeError on
        # unexpected keyword arguments, so pass nothing in that case.
        if _BOTORCH_AVAILABLE:
            super().__init__(model=pfn._ppd_model)
        else:
            super().__init__()
        self.pfn = pfn

    @t_batch_mode_transform(expected_q=1)
    def forward(self, X: Tensor) -> Tensor:
        """Return E[f | D] at X of shape (b, 1, d) (or (1, d)); output (b,)."""
        return self.pfn.posterior_mean(X.squeeze(-2))
@@ -0,0 +1,164 @@
1
+ """Checkpoint cache and lazy fetch.
2
+
3
+ Public surface:
4
+ `ensure_checkpoints(version)` -> Path to <cache>/<version>/, downloading
5
+ and extracting the published bundle on first use.
6
+
7
+ Resolution order:
8
+ 1. `path=` argument to `AlphaPFN.from_pretrained` (handled in loader.py)
9
+ 2. `$ALPHAPFN_CACHE_DIR/<version>/`
10
+ 3. Platform-aware user cache: `platformdirs.user_cache_dir("alphapfn")`
11
+ (Linux: ~/.cache/alphapfn, macOS: ~/Library/Caches/alphapfn,
12
+ Windows: %LOCALAPPDATA%\\alphapfn)
13
+
14
+ If the cache is empty, the bundle is fetched from
15
+ `$ALPHAPFN_BASE_URL/alpha_pfn_<version>.zip` (default base URL is the
16
+ ML Freiburg artifact host) and extracted atomically.
17
+ """
18
+ from __future__ import annotations
19
+
20
+ import os
21
+ import shutil
22
+ import sys
23
+ import tempfile
24
+ import urllib.request
25
+ import zipfile
26
+ from pathlib import Path
27
+
28
+ from platformdirs import user_cache_dir
29
+
30
+
31
# Default artifact host for published checkpoint bundles (ML Freiburg).
# Can be overridden at runtime via $ALPHAPFN_BASE_URL (see `_bundle_url`).
DEFAULT_BASE_URL = (
    "https://ml.informatik.uni-freiburg.de"
    "/research-artifacts/rakotoah/alpha_pfn"
)

# Every predictor lives under <cache>/<version>/. The ppd weights file acts
# as a canary: if it exists, the bundle is assumed to be fully extracted.
_CANARY_PREDICTOR, _CANARY_FILE = "ppd", "weights.safetensors"
39
+
40
+
41
def get_cache_dir() -> Path:
    """Return the root directory of the alphapfn checkpoint cache.

    A non-empty $ALPHAPFN_CACHE_DIR wins (with ``~`` expanded). Otherwise
    the per-user platform cache directory for "alphapfn" is used.
    """
    override = os.environ.get("ALPHAPFN_CACHE_DIR")
    if not override:
        return Path(user_cache_dir("alphapfn"))
    return Path(override).expanduser()
50
+
51
+
52
def _bundle_url(version: str) -> str:
    """Return the download URL of the checkpoint bundle for `version`.

    The base is $ALPHAPFN_BASE_URL when set and non-empty, else
    DEFAULT_BASE_URL. An empty value counts as unset (as in
    `get_cache_dir`) instead of yielding a relative "/alpha_pfn_..." URL.
    Trailing slashes on the base are dropped.
    """
    base = (os.environ.get("ALPHAPFN_BASE_URL") or DEFAULT_BASE_URL).rstrip("/")
    return f"{base}/alpha_pfn_{version}.zip"
55
+
56
+
57
def _is_populated(version_dir: Path) -> bool:
    """Report whether `version_dir` holds an extracted bundle (canary file present)."""
    canary = version_dir.joinpath(_CANARY_PREDICTOR, _CANARY_FILE)
    return canary.exists()
59
+
60
+
61
def _human_size(num_bytes: int) -> str:
    """Format a byte count with a 1024-based unit (B/KB/MB/GB) and one decimal.

    Values of 1024 GB and above stay expressed in GB.
    """
    units = ("B", "KB", "MB", "GB")
    value = float(num_bytes)
    idx = 0
    while value >= 1024 and idx < len(units) - 1:
        value /= 1024
        idx += 1
    return f"{value:.1f} {units[idx]}"
68
+
69
+
70
def _download(url: str, dest: Path, *, timeout: float | None = 60.0) -> None:
    """Stream `url` to `dest` with a single-line progress indicator on stderr.

    Args:
        url: Source URL (anything `urllib.request.urlopen` accepts).
        dest: Destination file; created or truncated.
        timeout: Seconds a single blocking socket operation may take,
            forwarded to `urlopen`, so a stalled connection fails instead
            of hanging forever. `None` disables the timeout.

    Raises:
        urllib.error.ContentTooShortError: if the server advertised a
            Content-Length and fewer bytes arrived (truncated download).
    """
    import urllib.error  # only needed for the truncation error below

    print(f"alphapfn: downloading {url}", file=sys.stderr)
    show_progress = sys.stderr.isatty()
    printed_progress = False
    chunk = 1 << 20  # 1 MiB
    written = 0
    with urllib.request.urlopen(url, timeout=timeout) as response:
        # 0 when the server sends no Content-Length (e.g. chunked transfer).
        total = int(response.headers.get("Content-Length", 0))
        with open(dest, "wb") as out:
            while True:
                buf = response.read(chunk)
                if not buf:
                    break
                out.write(buf)
                written += len(buf)
                if total > 0 and show_progress:
                    pct = 100.0 * written / total
                    print(
                        f"\ralphapfn: {_human_size(written)}/{_human_size(total)} ({pct:5.1f}%)",
                        end="",
                        file=sys.stderr,
                    )
                    printed_progress = True
    if printed_progress:
        print("", file=sys.stderr)  # newline after progress line
    # A premature connection close makes read() return b"" without raising,
    # so compare against the advertised size explicitly.
    if total > 0 and written < total:
        raise urllib.error.ContentTooShortError(
            f"alphapfn: download truncated: got {written} of {total} bytes from {url}",
            None,
        )
93
+
94
+
95
def _download_and_extract(version: str, version_dir: Path) -> None:
    """Download the bundle for `version` and atomically install it at `version_dir`.

    Everything is staged inside a temporary directory next to `version_dir`
    (same filesystem), so the final `os.replace` is a rename.

    Raises:
        RuntimeError: if the archive layout is not recognized.
        OSError: if publishing fails and no other process has populated
            `version_dir` in the meantime.
    """
    version_dir.parent.mkdir(parents=True, exist_ok=True)
    url = _bundle_url(version)

    with tempfile.TemporaryDirectory(
        prefix=f"alphapfn-{version}-", dir=str(version_dir.parent)
    ) as tmpdir:
        tmp = Path(tmpdir)
        zip_path = tmp / "bundle.zip"
        _download(url, zip_path)

        extract_dir = tmp / "extracted"
        extract_dir.mkdir()
        with zipfile.ZipFile(zip_path) as zf:
            zf.extractall(extract_dir)

        # The published bundle's top-level layout is `v1/<predictor>/...`.
        # Archives created on macOS often also carry a `__MACOSX/` metadata
        # tree (or other dot-directories). Ignore those so they do not
        # defeat the single-root detection below.
        roots = [
            p
            for p in extract_dir.iterdir()
            if p.is_dir() and p.name != "__MACOSX" and not p.name.startswith(".")
        ]
        if len(roots) == 1 and (roots[0] / _CANARY_PREDICTOR).is_dir():
            src = roots[0]
        elif (extract_dir / _CANARY_PREDICTOR).is_dir():
            src = extract_dir
        else:
            raise RuntimeError(
                f"alphapfn: unexpected bundle layout in {zip_path}. "
                f"Expected a top-level dir containing {_CANARY_PREDICTOR}/. "
                f"Found: {[p.name for p in extract_dir.iterdir()]}"
            )

        # Atomic publish: rename a sibling temp dir into place, so a
        # concurrent loader either sees the final dir or doesn't.
        staged = version_dir.parent / f".{version_dir.name}.staging-{os.getpid()}"
        if staged.exists():
            shutil.rmtree(staged)
        shutil.move(str(src), str(staged))
        try:
            os.replace(str(staged), str(version_dir))
        except OSError:
            # Another process may have published first; clean up.
            shutil.rmtree(staged, ignore_errors=True)
            if not _is_populated(version_dir):
                raise
139
+
140
+
141
def ensure_checkpoints(version: str) -> Path:
    """Return the cache directory for `version`, fetching the bundle on first use.

    Raises:
        RuntimeError: if the download/extraction fails (wrapping the
            original error), or if it finishes without producing the
            canary file.
    """
    target_dir = get_cache_dir() / version
    if not _is_populated(target_dir):
        try:
            _download_and_extract(version, target_dir)
        except Exception as err:
            raise RuntimeError(
                f"alphapfn: failed to download checkpoints for version={version!r} "
                f"from {_bundle_url(version)}: {err}\n"
                f"You can pre-populate the cache manually by extracting the bundle "
                f"into {target_dir} (so that {target_dir}/{_CANARY_PREDICTOR}/"
                f"{_CANARY_FILE} exists). The cache root can be overridden via "
                f"$ALPHAPFN_CACHE_DIR."
            ) from err
        # Guard against a "successful" install that still lacks the canary.
        if not _is_populated(target_dir):
            raise RuntimeError(
                f"alphapfn: download appeared to succeed but the cache is still "
                f"missing {_CANARY_PREDICTOR}/{_CANARY_FILE} under {target_dir}."
            )
    return target_dir