aspire-inference 0.1.0a10__tar.gz → 0.1.0a11__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/PKG-INFO +1 -1
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/aspire_inference.egg-info/PKG-INFO +1 -1
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/aspire_inference.egg-info/SOURCES.txt +2 -0
- aspire_inference-0.1.0a11/docs/checkpointing.rst +88 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/docs/index.rst +1 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/src/aspire/aspire.py +356 -4
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/src/aspire/flows/torch/flows.py +1 -1
- aspire_inference-0.1.0a11/src/aspire/samplers/base.py +238 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/src/aspire/samplers/smc/base.py +133 -48
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/src/aspire/samplers/smc/blackjax.py +8 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/src/aspire/samplers/smc/emcee.py +8 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/src/aspire/samplers/smc/minipcn.py +10 -2
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/src/aspire/samples.py +11 -9
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/src/aspire/utils.py +98 -0
- aspire_inference-0.1.0a11/tests/integration_tests/test_checkpointing.py +88 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/tests/test_samples.py +1 -1
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/tests/test_utils.py +14 -1
- aspire_inference-0.1.0a10/src/aspire/samplers/base.py +0 -98
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/.github/workflows/lint.yml +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/.github/workflows/publish.yml +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/.github/workflows/tests.yml +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/.gitignore +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/.pre-commit-config.yaml +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/LICENSE +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/README.md +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/aspire_inference.egg-info/dependency_links.txt +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/aspire_inference.egg-info/requires.txt +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/aspire_inference.egg-info/top_level.txt +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/docs/Makefile +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/docs/conf.py +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/docs/entry_points.rst +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/docs/examples.rst +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/docs/installation.rst +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/docs/multiprocessing.rst +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/docs/recipes.rst +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/docs/requirements.txt +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/docs/user_guide.rst +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/examples/basic_example.py +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/examples/blackjax_smc_example.py +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/examples/smc_example.py +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/pyproject.toml +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/readthedocs.yml +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/setup.cfg +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/src/aspire/__init__.py +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/src/aspire/flows/__init__.py +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/src/aspire/flows/base.py +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/src/aspire/flows/jax/__init__.py +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/src/aspire/flows/jax/flows.py +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/src/aspire/flows/jax/utils.py +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/src/aspire/flows/torch/__init__.py +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/src/aspire/history.py +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/src/aspire/plot.py +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/src/aspire/samplers/__init__.py +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/src/aspire/samplers/importance.py +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/src/aspire/samplers/mcmc.py +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/src/aspire/samplers/smc/__init__.py +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/src/aspire/transforms.py +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/tests/conftest.py +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/tests/integration_tests/conftest.py +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/tests/integration_tests/test_integration.py +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/tests/test_flows/test_flows_core.py +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/tests/test_flows/test_jax_flows/test_flowjax_flows.py +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/tests/test_flows/test_torch_flows/test_zuko_flows.py +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/tests/test_transforms.py +0 -0
{aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/aspire_inference.egg-info/SOURCES.txt RENAMED

@@ -13,6 +13,7 @@ aspire_inference.egg-info/dependency_links.txt
 aspire_inference.egg-info/requires.txt
 aspire_inference.egg-info/top_level.txt
 docs/Makefile
+docs/checkpointing.rst
 docs/conf.py
 docs/entry_points.rst
 docs/examples.rst
@@ -53,6 +54,7 @@ tests/test_samples.py
 tests/test_transforms.py
 tests/test_utils.py
 tests/integration_tests/conftest.py
+tests/integration_tests/test_checkpointing.py
 tests/integration_tests/test_integration.py
 tests/test_flows/test_flows_core.py
 tests/test_flows/test_jax_flows/test_flowjax_flows.py
aspire_inference-0.1.0a11/docs/checkpointing.rst ADDED

@@ -0,0 +1,88 @@
+Checkpointing and Resuming
+==========================
+
+Aspire provides a few simple patterns to resume long runs.
+
+Saving checkpoints while sampling
+---------------------------------
+
+- Pass ``checkpoint_path`` (an HDF5 file) to :py:meth:`aspire.Aspire.sample_posterior`
+  to write checkpoints as the sampler runs. Use ``checkpoint_every`` to control
+  frequency and ``checkpoint_save_config/flow`` to control what metadata is saved.
+- For a convenience wrapper, wrap your sampling in
+  ``with aspire.auto_checkpoint("run.h5", every=1): ...``. Inside the context,
+  ``sample_posterior`` will default to checkpointing to that file, and the config/flow
+  will be updated as needed.
+
+What gets saved
+^^^^^^^^^^^^^^^
+
+- The sampler stores checkpoints under ``/checkpoint/state`` in the HDF5 file.
+- Aspire writes ``/aspire_config`` (with ``sampler_type`` and ``sampler_config``) and
+  ``/flow``. If these already exist, they are overwritten when saving.
+
+Resuming from a file
+--------------------
+
+- Use :py:meth:`aspire.Aspire.resume_from_file` to rebuild an Aspire instance and flow
+  from a checkpoint file:
+
+  .. code-block:: python
+
+     aspire = Aspire.resume_from_file(
+         "run.h5",
+         log_likelihood=log_likelihood,
+         log_prior=log_prior,
+     )
+     # Optionally continue checkpointing to the same file
+     with aspire.auto_checkpoint("run.h5", every=1):
+         samples = aspire.sample_posterior()
+
+- ``resume_from_file`` loads config, flow, and the last checkpoint (if present), and
+  primes the instance to resume sampling; you can still override sampler kwargs when
+  calling ``sample_posterior``.
+
+Manual resume via ``sample_posterior`` args
+-------------------------------------------
+
+- If you have a checkpoint blob (bytes or dict) already, you can pass it directly:
+
+  .. code-block:: python
+
+     samples = aspire.sample_posterior(
+         n_samples=...,
+         sampler="smc",
+         resume_from=checkpoint_bytes_or_dict,
+         checkpoint_path="run.h5",  # optional: keep writing checkpoints
+     )
+
+- To resume from a file without using ``resume_from_file``, load the checkpoint bytes
+  and flow yourself, then call ``sample_posterior``:
+
+  .. code-block:: python
+
+     from aspire.utils import AspireFile
+
+     aspire = Aspire(..., flow_backend="zuko")
+     with AspireFile("run.h5", "r") as f:
+         aspire.load_flow(f, path="flow")
+         # Standard layout is /checkpoint/state; adjust if you used a different path
+         checkpoint_bytes = f["checkpoint"]["state"][...].tobytes()
+     samples = aspire.sample_posterior(
+         n_samples=..., sampler="smc", resume_from="run.h5"
+         # or resume_from=checkpoint_bytes if you prefer to pass bytes directly
+     )
+
+Notes and tips
+--------------
+
+- Checkpoint files must be HDF5 (``.h5``/``.hdf5``).
+- If a checkpoint is missing in the file (e.g., sampling never wrote one), the flow
+  and config are still loaded; you can simply start sampling again and checkpointing
+  will continue to the same file.
+- For manual control, you can always call ``save_config`` / ``save_flow`` yourself
+  on an :class:`aspire.utils.AspireFile`.
+- SMC samplers also accept a custom ``checkpoint_callback`` and ``checkpoint_every`` if
+  you want full control over how checkpoints are persisted or inspected. Provide a
+  callable that accepts the checkpoint state dict; from there you can, for example,
+  serialize to another format or push to remote storage.
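A minimal sketch of the custom ``checkpoint_callback`` mentioned in the notes above: a callable that receives the checkpoint state dict every ``checkpoint_every`` iterations. Passing it through ``sample_posterior`` keyword arguments (so it is forwarded to the SMC sampler), the file name, and ``aspire`` being an already-configured instance are assumptions for illustration.

import pickle

def checkpoint_to_pickle(state: dict) -> None:
    # Persist the raw checkpoint state dict however you like;
    # here it is simply pickled to a local file.
    with open("smc_checkpoint.pkl", "wb") as f:
        pickle.dump(state, f)

# Assumed wiring: extra keyword arguments are forwarded to the SMC sampler.
samples = aspire.sample_posterior(
    n_samples=1000,
    sampler="smc",
    checkpoint_callback=checkpoint_to_pickle,
    checkpoint_every=5,
)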
{aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/src/aspire/aspire.py RENAMED

@@ -1,5 +1,8 @@
+import copy
 import logging
 import multiprocessing as mp
+import pickle
+from contextlib import contextmanager
 from inspect import signature
 from typing import Any, Callable
 
@@ -8,13 +11,19 @@ import h5py
 from .flows import get_flow_wrapper
 from .flows.base import Flow
 from .history import History
+from .samplers.base import Sampler
 from .samples import Samples
 from .transforms import (
     CompositeTransform,
     FlowPreconditioningTransform,
     FlowTransform,
 )
-from .utils import
+from .utils import (
+    AspireFile,
+    load_from_h5_file,
+    recursively_save_to_h5_file,
+    resolve_xp,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -102,6 +111,7 @@ class Aspire:
         self.dtype = dtype
 
         self._flow = flow
+        self._sampler = None
 
     @property
     def flow(self):
@@ -114,7 +124,7 @@ class Aspire:
         self._flow = flow
 
     @property
-    def sampler(self):
+    def sampler(self) -> Sampler | None:
         """The sampler object."""
         return self._sampler
 
@@ -192,7 +202,29 @@ class Aspire:
             **self.flow_kwargs,
         )
 
-    def fit(
+    def fit(
+        self,
+        samples: Samples,
+        checkpoint_path: str | None = None,
+        checkpoint_save_config: bool = True,
+        overwrite: bool = False,
+        **kwargs,
+    ) -> History:
+        """Fit the normalizing flow to the provided samples.
+
+        Parameters
+        ----------
+        samples : Samples
+            The samples to fit the flow to.
+        checkpoint_path : str | None
+            Path to save the checkpoint. If None, no checkpoint is saved.
+        checkpoint_save_config : bool
+            Whether to save the Aspire configuration to the checkpoint.
+        overwrite : bool
+            Whether to overwrite an existing flow in the checkpoint file.
+        kwargs : dict
+            Keyword arguments to pass to the flow's fit method.
+        """
         if self.xp is None:
             self.xp = samples.xp
 
@@ -202,6 +234,28 @@ class Aspire:
         self.training_samples = samples
         logger.info(f"Training with {len(samples.x)} samples")
         history = self.flow.fit(samples.x, **kwargs)
+        defaults = getattr(self, "_checkpoint_defaults", None)
+        if checkpoint_path is None and defaults:
+            checkpoint_path = defaults["path"]
+            checkpoint_save_config = defaults["save_config"]
+        saved_config = (
+            defaults.get("saved_config", False) if defaults else False
+        )
+        if checkpoint_path is not None:
+            with AspireFile(checkpoint_path, "a") as h5_file:
+                if checkpoint_save_config and not saved_config:
+                    if "aspire_config" in h5_file:
+                        del h5_file["aspire_config"]
+                    self.save_config(h5_file, include_sampler_config=False)
+                    if defaults is not None:
+                        defaults["saved_config"] = True
+                # Save flow only if missing or overwrite=True
+                if "flow" in h5_file:
+                    if overwrite:
+                        del h5_file["flow"]
+                        self.save_flow(h5_file)
+                else:
+                    self.save_flow(h5_file)
         return history
 
     def get_sampler_class(self, sampler_type: str) -> Callable:
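A usage sketch of the checkpointing arguments added to ``fit`` above; ``aspire`` and ``training_samples`` are assumed to be an existing instance and a ``Samples`` object, and the file name is illustrative.

# Fit the flow and also write it (plus the Aspire config) to an HDF5
# checkpoint file; overwrite=True replaces any /flow group already present.
history = aspire.fit(
    training_samples,
    checkpoint_path="run.h5",
    checkpoint_save_config=True,
    overwrite=True,
)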
@@ -241,6 +295,13 @@ class Aspire:
         ----------
         sampler_type : str
             The type of sampler to use. Options are 'importance', 'emcee', or 'smc'.
+        preconditioning: str
+            Type of preconditioning to apply in the sampler. Options are
+            'default', 'flow', or 'none'.
+        preconditioning_kwargs: dict
+            Keyword arguments to pass to the preconditioning transform.
+        kwargs : dict
+            Keyword arguments to pass to the sampler.
         """
         SamplerClass = self.get_sampler_class(sampler_type)
 
@@ -304,6 +365,9 @@ class Aspire:
         return_history: bool = False,
         preconditioning: str | None = None,
         preconditioning_kwargs: dict | None = None,
+        checkpoint_path: str | None = None,
+        checkpoint_every: int = 1,
+        checkpoint_save_config: bool = True,
         **kwargs,
     ) -> Samples:
         """Draw samples from the posterior distribution.
@@ -342,6 +406,14 @@ class Aspire:
             will default to 'none' and the other samplers to 'default'
         preconditioning_kwargs: dict
             Keyword arguments to pass to the preconditioning transform.
+        checkpoint_path : str | None
+            Path to save the checkpoint. If None, no checkpoint is saved unless
+            within an :py:meth:`auto_checkpoint` context or a custom callback
+            is provided.
+        checkpoint_every : int
+            Frequency (in number of sampler iterations) to save the checkpoint.
+        checkpoint_save_config : bool
+            Whether to save the Aspire configuration to the checkpoint.
         kwargs : dict
             Keyword arguments to pass to the sampler. These are passed
             automatically to the init method of the sampler or to the sample
@@ -352,6 +424,22 @@ class Aspire:
         samples : Samples
             Samples object contain samples and their corresponding weights.
         """
+        if (
+            sampler == "importance"
+            and hasattr(self, "_resume_sampler_type")
+            and self._resume_sampler_type
+        ):
+            sampler = self._resume_sampler_type
+
+        if "resume_from" not in kwargs and hasattr(
+            self, "_resume_from_default"
+        ):
+            kwargs["resume_from"] = self._resume_from_default
+        if hasattr(self, "_resume_overrides"):
+            kwargs.update(self._resume_overrides)
+        if hasattr(self, "_resume_n_samples") and n_samples == 1000:
+            n_samples = self._resume_n_samples
+
         SamplerClass = self.get_sampler_class(sampler)
         # Determine sampler initialization parameters
         # and remove them from kwargs
@@ -373,7 +461,73 @@ class Aspire:
             preconditioning_kwargs=preconditioning_kwargs,
             **sampler_kwargs,
         )
+        self._last_sampler_type = sampler
+        # Auto-checkpoint convenience: set defaults for checkpointing to a single file
+        defaults = getattr(self, "_checkpoint_defaults", None)
+        if checkpoint_path is None and defaults:
+            checkpoint_path = defaults["path"]
+            checkpoint_every = defaults["every"]
+            checkpoint_save_config = defaults["save_config"]
+        saved_flow = defaults.get("saved_flow", False) if defaults else False
+        saved_config = (
+            defaults.get("saved_config", False) if defaults else False
+        )
+        if checkpoint_path is not None:
+            kwargs.setdefault("checkpoint_file_path", checkpoint_path)
+            kwargs.setdefault("checkpoint_every", checkpoint_every)
+            with AspireFile(checkpoint_path, "a") as h5_file:
+                if checkpoint_save_config:
+                    if "aspire_config" in h5_file:
+                        del h5_file["aspire_config"]
+                    self.save_config(
+                        h5_file,
+                        include_sampler_config=True,
+                        include_sample_calls=False,
+                    )
+                    saved_config = True
+                    if defaults is not None:
+                        defaults["saved_config"] = True
+                if (
+                    self.flow is not None
+                    and not saved_flow
+                    and "flow" not in h5_file
+                ):
+                    self.save_flow(h5_file)
+                    saved_flow = True
+                    if defaults is not None:
+                        defaults["saved_flow"] = True
+
         samples = self._sampler.sample(n_samples, **kwargs)
+        self._last_sample_posterior_kwargs = {
+            "n_samples": n_samples,
+            "sampler": sampler,
+            "xp": xp,
+            "return_history": return_history,
+            "preconditioning": preconditioning,
+            "preconditioning_kwargs": preconditioning_kwargs,
+            "sampler_init_kwargs": sampler_kwargs,
+            "sample_kwargs": copy.deepcopy(kwargs),
+        }
+        if checkpoint_path is not None:
+            with AspireFile(checkpoint_path, "a") as h5_file:
+                if checkpoint_save_config and not saved_config:
+                    if "aspire_config" in h5_file:
+                        del h5_file["aspire_config"]
+                    self.save_config(
+                        h5_file,
+                        include_sampler_config=True,
+                        include_sample_calls=False,
+                    )
+                    if defaults is not None:
+                        defaults["saved_config"] = True
+                if (
+                    self.flow is not None
+                    and not saved_flow
+                    and "flow" not in h5_file
+                ):
+                    self.save_flow(h5_file)
+                    if defaults is not None:
+                        defaults["saved_flow"] = True
         if xp is not None:
             samples = samples.to_namespace(xp)
         samples.parameters = self.parameters
@@ -388,6 +542,122 @@ class Aspire:
         else:
             return samples
 
+    @classmethod
+    def resume_from_file(
+        cls,
+        file_path: str,
+        *,
+        log_likelihood: Callable,
+        log_prior: Callable,
+        sampler: str | None = None,
+        checkpoint_path: str = "checkpoint",
+        checkpoint_dset: str = "state",
+        flow_path: str = "flow",
+        config_path: str = "aspire_config",
+        resume_kwargs: dict | None = None,
+    ):
+        """
+        Recreate an Aspire object from a single file and prepare to resume sampling.
+
+        Parameters
+        ----------
+        file_path : str
+            Path to the HDF5 file containing config, flow, and checkpoint.
+        log_likelihood : Callable
+            Log-likelihood function (required, not pickled).
+        log_prior : Callable
+            Log-prior function (required, not pickled).
+        sampler : str
+            Sampler type to use (e.g., 'smc', 'minipcn_smc', 'emcee_smc'). If None,
+            will attempt to infer from saved config or checkpoint metadata.
+        checkpoint_path : str
+            HDF5 group path where the checkpoint is stored.
+        checkpoint_dset : str
+            Dataset name within the checkpoint group.
+        flow_path : str
+            HDF5 path to the saved flow.
+        config_path : str
+            HDF5 path to the saved Aspire config.
+        resume_kwargs : dict | None
+            Optional overrides to apply when resuming (e.g., checkpoint_every).
+        """
+        (
+            aspire,
+            checkpoint_bytes,
+            checkpoint_state,
+            sampler_config,
+            saved_sampler_type,
+            n_samples,
+        ) = cls._build_aspire_from_file(
+            file_path=file_path,
+            log_likelihood=log_likelihood,
+            log_prior=log_prior,
+            checkpoint_path=checkpoint_path,
+            checkpoint_dset=checkpoint_dset,
+            flow_path=flow_path,
+            config_path=config_path,
+        )
+
+        sampler_config = sampler_config or {}
+        sampler_config.pop("sampler_class", None)
+
+        if checkpoint_bytes is not None:
+            aspire._resume_from_default = checkpoint_bytes
+        aspire._resume_sampler_type = (
+            sampler
+            or saved_sampler_type
+            or (
+                checkpoint_state.get("sampler")
+                if checkpoint_state
+                else None
+            )
+        )
+        aspire._resume_n_samples = n_samples
+        aspire._resume_overrides = resume_kwargs or {}
+        aspire._resume_sampler_config = sampler_config
+        aspire._checkpoint_defaults = {
+            "path": file_path,
+            "every": 1,
+            "save_config": False,
+            "save_flow": False,
+            "saved_config": False,
+            "saved_flow": False,
+        }
+        return aspire
+
+    @contextmanager
+    def auto_checkpoint(
+        self,
+        path: str,
+        every: int = 1,
+        save_config: bool = True,
+        save_flow: bool = True,
+    ):
+        """
+        Context manager to auto-save checkpoints, config, and flow to a file.
+
+        Within the context, sample_posterior will default to writing checkpoints
+        to the given path with the specified frequency, and will append config/flow
+        after sampling.
+        """
+        prev = getattr(self, "_checkpoint_defaults", None)
+        self._checkpoint_defaults = {
+            "path": path,
+            "every": every,
+            "save_config": save_config,
+            "save_flow": save_flow,
+            "saved_config": False,
+            "saved_flow": False,
+        }
+        try:
+            yield self
+        finally:
+            if prev is None:
+                if hasattr(self, "_checkpoint_defaults"):
+                    delattr(self, "_checkpoint_defaults")
+            else:
+                self._checkpoint_defaults = prev
+
     def enable_pool(self, pool: mp.Pool, **kwargs):
         """Context manager to temporarily replace the log_likelihood method
         with a version that uses a multiprocessing pool to parallelize
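A usage sketch of ``resume_from_file`` combined with ``auto_checkpoint`` as added above; ``log_likelihood`` and ``log_prior`` are assumed user-defined callables and the override values are illustrative.

# Rebuild the Aspire instance from a single HDF5 file; the sampler type is
# inferred from the saved config/checkpoint when ``sampler`` is not given.
aspire = Aspire.resume_from_file(
    "run.h5",
    log_likelihood=log_likelihood,
    log_prior=log_prior,
    resume_kwargs={"checkpoint_every": 5},  # applied when sampling resumes
)
# Keep writing checkpoints to the same file while the run continues.
with aspire.auto_checkpoint("run.h5", every=5):
    samples = aspire.sample_posterior()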
@@ -432,12 +702,16 @@ class Aspire:
             "flow_kwargs": self.flow_kwargs,
             "eps": self.eps,
         }
+        if hasattr(self, "_last_sampler_type"):
+            config["sampler_type"] = self._last_sampler_type
         if include_sampler_config:
+            if self.sampler is None:
+                raise ValueError("Sampler has not been initialized.")
             config["sampler_config"] = self.sampler.config_dict(**kwargs)
         return config
 
     def save_config(
-        self, h5_file: h5py.File, path="aspire_config", **kwargs
+        self, h5_file: h5py.File | AspireFile, path="aspire_config", **kwargs
     ) -> None:
         """Save the configuration to an HDF5 file.
 
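For manual control, the config and flow can also be written directly to an ``AspireFile``, mirroring the calls ``sample_posterior`` makes above; a sketch assuming a fitted flow (note that ``include_sampler_config=True`` now raises if no sampler has been initialized).

from aspire.utils import AspireFile

# Write /aspire_config (without sampler config) and /flow by hand.
with AspireFile("run.h5", "a") as f:
    aspire.save_config(f, include_sampler_config=False)
    aspire.save_flow(f)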
@@ -484,6 +758,7 @@ class Aspire:
         FlowClass, xp = get_flow_wrapper(
             backend=self.flow_backend, flow_matching=self.flow_matching
         )
+        logger.debug(f"Loading flow of type {FlowClass} from {path}")
         self._flow = FlowClass.load(h5_file, path=path)
 
     def save_config_to_json(self, filename: str) -> None:
@@ -504,3 +779,80 @@ class Aspire:
         x, log_q = self.flow.sample_and_log_prob(n_samples)
         samples = Samples(x=x, log_q=log_q, xp=xp, parameters=self.parameters)
         return samples
+
+    # --- Resume helpers ---
+    @staticmethod
+    def _build_aspire_from_file(
+        file_path: str,
+        log_likelihood: Callable,
+        log_prior: Callable,
+        checkpoint_path: str,
+        checkpoint_dset: str,
+        flow_path: str,
+        config_path: str,
+    ):
+        """Construct an Aspire instance, load flow, and gather checkpoint metadata from file."""
+        with AspireFile(file_path, "r") as h5_file:
+            if config_path not in h5_file:
+                raise ValueError(
+                    f"Config path '{config_path}' not found in {file_path}"
+                )
+            config_dict = load_from_h5_file(h5_file, config_path)
+            try:
+                checkpoint_bytes = h5_file[checkpoint_path][checkpoint_dset][
+                    ...
+                ].tobytes()
+            except Exception:
+                logger.warning(
+                    "Checkpoint not found at %s/%s in %s; will resume without a checkpoint.",
+                    checkpoint_path,
+                    checkpoint_dset,
+                    file_path,
+                )
+                checkpoint_bytes = None
+
+        sampler_config = config_dict.pop("sampler_config", None)
+        saved_sampler_type = config_dict.pop("sampler_type", None)
+        if isinstance(config_dict.get("xp"), str):
+            config_dict["xp"] = resolve_xp(config_dict["xp"])
+        config_dict["log_likelihood"] = log_likelihood
+        config_dict["log_prior"] = log_prior
+
+        aspire = Aspire(**config_dict)
+
+        with AspireFile(file_path, "r") as h5_file:
+            if flow_path in h5_file:
+                logger.info(f"Loading flow from {flow_path} in {file_path}")
+                aspire.load_flow(h5_file, path=flow_path)
+            else:
+                raise ValueError(
+                    f"Flow path '{flow_path}' not found in {file_path}"
+                )
+
+        n_samples = None
+        checkpoint_state = None
+        if checkpoint_bytes is not None:
+            try:
+                checkpoint_state = pickle.loads(checkpoint_bytes)
+                samples_saved = (
+                    checkpoint_state.get("samples")
+                    if checkpoint_state
+                    else None
+                )
+                if samples_saved is not None:
+                    n_samples = len(samples_saved)
+                    if aspire.xp is None and hasattr(samples_saved, "xp"):
+                        aspire.xp = samples_saved.xp
+            except Exception:
+                logger.warning(
+                    "Failed to decode checkpoint; proceeding without resume state."
+                )
+
+        return (
+            aspire,
+            checkpoint_bytes,
+            checkpoint_state,
+            sampler_config,
+            saved_sampler_type,
+            n_samples,
+        )
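``_build_aspire_from_file`` reads the checkpoint as pickled bytes from the ``/checkpoint/state`` dataset and unpickles it; the same layout can be used to inspect a checkpoint without resuming. A sketch, assuming the default group and dataset names:

import pickle

from aspire.utils import AspireFile

with AspireFile("run.h5", "r") as f:
    state = pickle.loads(f["checkpoint"]["state"][...].tobytes())

# Keys such as "sampler" and "samples" are what Aspire consults when resuming.
print(state.get("sampler"), state.get("samples"))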
{aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/src/aspire/flows/torch/flows.py RENAMED

@@ -92,7 +92,7 @@ class BaseTorchFlow(Flow):
         config = load_from_h5_file(flow_grp, "config")
         config["dtype"] = decode_dtype(torch, config.get("dtype"))
         if "data_transform" in flow_grp:
-            from
+            from ...transforms import BaseTransform
 
             data_transform = BaseTransform.load(
                 flow_grp,