aspire-inference 0.1.0a10__tar.gz → 0.1.0a12__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/PKG-INFO +1 -1
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/aspire_inference.egg-info/PKG-INFO +1 -1
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/aspire_inference.egg-info/SOURCES.txt +2 -0
- aspire_inference-0.1.0a12/docs/checkpointing.rst +88 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/docs/index.rst +1 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/src/aspire/aspire.py +359 -6
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/src/aspire/flows/torch/flows.py +1 -1
- aspire_inference-0.1.0a12/src/aspire/samplers/base.py +238 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/src/aspire/samplers/smc/base.py +133 -48
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/src/aspire/samplers/smc/blackjax.py +8 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/src/aspire/samplers/smc/emcee.py +8 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/src/aspire/samplers/smc/minipcn.py +10 -2
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/src/aspire/samples.py +11 -9
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/src/aspire/utils.py +119 -0
- aspire_inference-0.1.0a12/tests/integration_tests/test_checkpointing.py +88 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/tests/test_samples.py +1 -1
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/tests/test_utils.py +30 -1
- aspire_inference-0.1.0a10/src/aspire/samplers/base.py +0 -98
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/.github/workflows/lint.yml +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/.github/workflows/publish.yml +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/.github/workflows/tests.yml +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/.gitignore +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/.pre-commit-config.yaml +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/LICENSE +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/README.md +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/aspire_inference.egg-info/dependency_links.txt +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/aspire_inference.egg-info/requires.txt +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/aspire_inference.egg-info/top_level.txt +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/docs/Makefile +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/docs/conf.py +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/docs/entry_points.rst +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/docs/examples.rst +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/docs/installation.rst +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/docs/multiprocessing.rst +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/docs/recipes.rst +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/docs/requirements.txt +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/docs/user_guide.rst +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/examples/basic_example.py +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/examples/blackjax_smc_example.py +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/examples/smc_example.py +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/pyproject.toml +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/readthedocs.yml +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/setup.cfg +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/src/aspire/__init__.py +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/src/aspire/flows/__init__.py +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/src/aspire/flows/base.py +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/src/aspire/flows/jax/__init__.py +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/src/aspire/flows/jax/flows.py +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/src/aspire/flows/jax/utils.py +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/src/aspire/flows/torch/__init__.py +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/src/aspire/history.py +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/src/aspire/plot.py +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/src/aspire/samplers/__init__.py +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/src/aspire/samplers/importance.py +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/src/aspire/samplers/mcmc.py +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/src/aspire/samplers/smc/__init__.py +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/src/aspire/transforms.py +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/tests/conftest.py +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/tests/integration_tests/conftest.py +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/tests/integration_tests/test_integration.py +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/tests/test_flows/test_flows_core.py +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/tests/test_flows/test_jax_flows/test_flowjax_flows.py +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/tests/test_flows/test_torch_flows/test_zuko_flows.py +0 -0
- {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/tests/test_transforms.py +0 -0
{aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/aspire_inference.egg-info/SOURCES.txt
RENAMED
@@ -13,6 +13,7 @@ aspire_inference.egg-info/dependency_links.txt
 aspire_inference.egg-info/requires.txt
 aspire_inference.egg-info/top_level.txt
 docs/Makefile
+docs/checkpointing.rst
 docs/conf.py
 docs/entry_points.rst
 docs/examples.rst
@@ -53,6 +54,7 @@ tests/test_samples.py
 tests/test_transforms.py
 tests/test_utils.py
 tests/integration_tests/conftest.py
+tests/integration_tests/test_checkpointing.py
 tests/integration_tests/test_integration.py
 tests/test_flows/test_flows_core.py
 tests/test_flows/test_jax_flows/test_flowjax_flows.py
aspire_inference-0.1.0a12/docs/checkpointing.rst
ADDED
@@ -0,0 +1,88 @@
+Checkpointing and Resuming
+==========================
+
+Aspire provides a few simple patterns to resume long runs.
+
+Saving checkpoints while sampling
+---------------------------------
+
+- Pass ``checkpoint_path`` (an HDF5 file) to :py:meth:`aspire.Aspire.sample_posterior`
+  to write checkpoints as the sampler runs. Use ``checkpoint_every`` to control
+  frequency and ``checkpoint_save_config/flow`` to control what metadata is saved.
+- For a convenience wrapper, wrap your sampling in
+  ``with aspire.auto_checkpoint("run.h5", every=1): ...``. Inside the context,
+  ``sample_posterior`` will default to checkpointing to that file, and the config/flow
+  will be updated as needed.
+
+What gets saved
+^^^^^^^^^^^^^^^
+
+- The sampler stores checkpoints under ``/checkpoint/state`` in the HDF5 file.
+- Aspire writes ``/aspire_config`` (with ``sampler_type`` and ``sampler_config``) and
+  ``/flow``. If these already exist, they are overwritten when saving.
+
+Resuming from a file
+--------------------
+
+- Use :py:meth:`aspire.Aspire.resume_from_file` to rebuild an Aspire instance and flow
+  from a checkpoint file:
+
+  .. code-block:: python
+
+      aspire = Aspire.resume_from_file(
+          "run.h5",
+          log_likelihood=log_likelihood,
+          log_prior=log_prior,
+      )
+      # Optionally continue checkpointing to the same file
+      with aspire.auto_checkpoint("run.h5", every=1):
+          samples = aspire.sample_posterior()
+
+- ``resume_from_file`` loads config, flow, and the last checkpoint (if present), and
+  primes the instance to resume sampling; you can still override sampler kwargs when
+  calling ``sample_posterior``.
+
+Manual resume via ``sample_posterior`` args
+-------------------------------------------
+
+- If you have a checkpoint blob (bytes or dict) already, you can pass it directly:
+
+  .. code-block:: python
+
+      samples = aspire.sample_posterior(
+          n_samples=...,
+          sampler="smc",
+          resume_from=checkpoint_bytes_or_dict,
+          checkpoint_path="run.h5",  # optional: keep writing checkpoints
+      )
+
+- To resume from a file without using ``resume_from_file``, load the checkpoint bytes
+  and flow yourself, then call ``sample_posterior``:
+
+  .. code-block:: python
+
+      from aspire.utils import AspireFile
+
+      aspire = Aspire(..., flow_backend="zuko")
+      with AspireFile("run.h5", "r") as f:
+          aspire.load_flow(f, path="flow")
+          # Standard layout is /checkpoint/state; adjust if you used a different path
+          checkpoint_bytes = f["checkpoint"]["state"][...].tobytes()
+      samples = aspire.sample_posterior(
+          n_samples=..., sampler="smc", resume_from="run.h5"
+          # or resume_from=checkpoint_bytes if you prefer to pass bytes directly
+      )
+
+Notes and tips
+--------------
+
+- Checkpoint files must be HDF5 (``.h5``/``.hdf5``).
+- If a checkpoint is missing in the file (e.g., sampling never wrote one), the flow
+  and config are still loaded; you can simply start sampling again and checkpointing
+  will continue to the same file.
+- For manual control, you can always call ``save_config`` / ``save_flow`` yourself
+  on an :class:`aspire.utils.AspireFile`.
+- SMC samplers also accept a custom ``checkpoint_callback`` and ``checkpoint_every`` if
+  you want full control over how checkpoints are persisted or inspected. Provide a
+  callable that accepts the checkpoint state dict; from there you can, for example,
+  serialize to another format or push to remote storage.
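Editor's note: the "Notes and tips" section above mentions a custom ``checkpoint_callback`` but does not show one. Below is a minimal sketch, not taken from this release; it assumes a configured ``aspire`` instance and that the callback keyword is forwarded through ``sample_posterior`` kwargs to the SMC sampler (as the docstrings describe for sampler kwargs). The callback simply pickles the state dict it receives.

    import pickle

    def checkpoint_callback(state: dict) -> None:
        # `state` is the checkpoint state dict handed over by the SMC sampler.
        with open("latest_checkpoint.pkl", "wb") as f:
            pickle.dump(state, f)

    samples = aspire.sample_posterior(
        n_samples=2000,
        sampler="smc",
        checkpoint_callback=checkpoint_callback,  # assumed to be forwarded to the sampler
        checkpoint_every=5,
    )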
{aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/src/aspire/aspire.py
RENAMED
@@ -1,5 +1,8 @@
+import copy
 import logging
 import multiprocessing as mp
+import pickle
+from contextlib import contextmanager
 from inspect import signature
 from typing import Any, Callable
 
@@ -8,13 +11,20 @@ import h5py
 from .flows import get_flow_wrapper
 from .flows.base import Flow
 from .history import History
+from .samplers.base import Sampler
 from .samples import Samples
 from .transforms import (
     CompositeTransform,
     FlowPreconditioningTransform,
     FlowTransform,
 )
-from .utils import
+from .utils import (
+    AspireFile,
+    function_id,
+    load_from_h5_file,
+    recursively_save_to_h5_file,
+    resolve_xp,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -102,6 +112,7 @@ class Aspire:
         self.dtype = dtype
 
         self._flow = flow
+        self._sampler = None
 
     @property
     def flow(self):
@@ -114,7 +125,7 @@ class Aspire:
         self._flow = flow
 
     @property
-    def sampler(self):
+    def sampler(self) -> Sampler | None:
         """The sampler object."""
         return self._sampler
 
@@ -192,7 +203,29 @@ class Aspire:
             **self.flow_kwargs,
         )
 
-    def fit(
+    def fit(
+        self,
+        samples: Samples,
+        checkpoint_path: str | None = None,
+        checkpoint_save_config: bool = True,
+        overwrite: bool = False,
+        **kwargs,
+    ) -> History:
+        """Fit the normalizing flow to the provided samples.
+
+        Parameters
+        ----------
+        samples : Samples
+            The samples to fit the flow to.
+        checkpoint_path : str | None
+            Path to save the checkpoint. If None, no checkpoint is saved.
+        checkpoint_save_config : bool
+            Whether to save the Aspire configuration to the checkpoint.
+        overwrite : bool
+            Whether to overwrite an existing flow in the checkpoint file.
+        kwargs : dict
+            Keyword arguments to pass to the flow's fit method.
+        """
         if self.xp is None:
             self.xp = samples.xp
 
@@ -202,6 +235,28 @@ class Aspire:
         self.training_samples = samples
         logger.info(f"Training with {len(samples.x)} samples")
         history = self.flow.fit(samples.x, **kwargs)
+        defaults = getattr(self, "_checkpoint_defaults", None)
+        if checkpoint_path is None and defaults:
+            checkpoint_path = defaults["path"]
+            checkpoint_save_config = defaults["save_config"]
+        saved_config = (
+            defaults.get("saved_config", False) if defaults else False
+        )
+        if checkpoint_path is not None:
+            with AspireFile(checkpoint_path, "a") as h5_file:
+                if checkpoint_save_config and not saved_config:
+                    if "aspire_config" in h5_file:
+                        del h5_file["aspire_config"]
+                    self.save_config(h5_file, include_sampler_config=False)
+                    if defaults is not None:
+                        defaults["saved_config"] = True
+                # Save flow only if missing or overwrite=True
+                if "flow" in h5_file:
+                    if overwrite:
+                        del h5_file["flow"]
+                        self.save_flow(h5_file)
+                else:
+                    self.save_flow(h5_file)
         return history
 
     def get_sampler_class(self, sampler_type: str) -> Callable:
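Editor's note: for context, a minimal usage sketch of the new ``fit`` checkpointing arguments added above (not part of the diff; it assumes an existing ``aspire`` instance and a ``training_samples`` ``Samples`` object):

    history = aspire.fit(
        training_samples,
        checkpoint_path="run.h5",      # config and flow are written to this HDF5 file
        checkpoint_save_config=True,   # also store the Aspire configuration
        overwrite=False,               # keep an existing /flow group unless set to True
    )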
@@ -241,6 +296,13 @@ class Aspire:
         ----------
         sampler_type : str
             The type of sampler to use. Options are 'importance', 'emcee', or 'smc'.
+        preconditioning: str
+            Type of preconditioning to apply in the sampler. Options are
+            'default', 'flow', or 'none'.
+        preconditioning_kwargs: dict
+            Keyword arguments to pass to the preconditioning transform.
+        kwargs : dict
+            Keyword arguments to pass to the sampler.
         """
         SamplerClass = self.get_sampler_class(sampler_type)
 
@@ -304,6 +366,9 @@ class Aspire:
         return_history: bool = False,
         preconditioning: str | None = None,
         preconditioning_kwargs: dict | None = None,
+        checkpoint_path: str | None = None,
+        checkpoint_every: int = 1,
+        checkpoint_save_config: bool = True,
         **kwargs,
     ) -> Samples:
         """Draw samples from the posterior distribution.
@@ -342,6 +407,14 @@ class Aspire:
             will default to 'none' and the other samplers to 'default'
         preconditioning_kwargs: dict
             Keyword arguments to pass to the preconditioning transform.
+        checkpoint_path : str | None
+            Path to save the checkpoint. If None, no checkpoint is saved unless
+            within an :py:meth:`auto_checkpoint` context or a custom callback
+            is provided.
+        checkpoint_every : int
+            Frequency (in number of sampler iterations) to save the checkpoint.
+        checkpoint_save_config : bool
+            Whether to save the Aspire configuration to the checkpoint.
         kwargs : dict
             Keyword arguments to pass to the sampler. These are passed
             automatically to the init method of the sampler or to the sample
@@ -352,6 +425,22 @@ class Aspire:
         samples : Samples
             Samples object contain samples and their corresponding weights.
         """
+        if (
+            sampler == "importance"
+            and hasattr(self, "_resume_sampler_type")
+            and self._resume_sampler_type
+        ):
+            sampler = self._resume_sampler_type
+
+        if "resume_from" not in kwargs and hasattr(
+            self, "_resume_from_default"
+        ):
+            kwargs["resume_from"] = self._resume_from_default
+        if hasattr(self, "_resume_overrides"):
+            kwargs.update(self._resume_overrides)
+        if hasattr(self, "_resume_n_samples") and n_samples == 1000:
+            n_samples = self._resume_n_samples
+
         SamplerClass = self.get_sampler_class(sampler)
         # Determine sampler initialization parameters
         # and remove them from kwargs
@@ -373,7 +462,73 @@ class Aspire:
             preconditioning_kwargs=preconditioning_kwargs,
             **sampler_kwargs,
         )
+        self._last_sampler_type = sampler
+        # Auto-checkpoint convenience: set defaults for checkpointing to a single file
+        defaults = getattr(self, "_checkpoint_defaults", None)
+        if checkpoint_path is None and defaults:
+            checkpoint_path = defaults["path"]
+            checkpoint_every = defaults["every"]
+            checkpoint_save_config = defaults["save_config"]
+        saved_flow = defaults.get("saved_flow", False) if defaults else False
+        saved_config = (
+            defaults.get("saved_config", False) if defaults else False
+        )
+        if checkpoint_path is not None:
+            kwargs.setdefault("checkpoint_file_path", checkpoint_path)
+            kwargs.setdefault("checkpoint_every", checkpoint_every)
+            with AspireFile(checkpoint_path, "a") as h5_file:
+                if checkpoint_save_config:
+                    if "aspire_config" in h5_file:
+                        del h5_file["aspire_config"]
+                    self.save_config(
+                        h5_file,
+                        include_sampler_config=True,
+                        include_sample_calls=False,
+                    )
+                    saved_config = True
+                    if defaults is not None:
+                        defaults["saved_config"] = True
+                if (
+                    self.flow is not None
+                    and not saved_flow
+                    and "flow" not in h5_file
+                ):
+                    self.save_flow(h5_file)
+                    saved_flow = True
+                    if defaults is not None:
+                        defaults["saved_flow"] = True
+
         samples = self._sampler.sample(n_samples, **kwargs)
+        self._last_sample_posterior_kwargs = {
+            "n_samples": n_samples,
+            "sampler": sampler,
+            "xp": xp,
+            "return_history": return_history,
+            "preconditioning": preconditioning,
+            "preconditioning_kwargs": preconditioning_kwargs,
+            "sampler_init_kwargs": sampler_kwargs,
+            "sample_kwargs": copy.deepcopy(kwargs),
+        }
+        if checkpoint_path is not None:
+            with AspireFile(checkpoint_path, "a") as h5_file:
+                if checkpoint_save_config and not saved_config:
+                    if "aspire_config" in h5_file:
+                        del h5_file["aspire_config"]
+                    self.save_config(
+                        h5_file,
+                        include_sampler_config=True,
+                        include_sample_calls=False,
+                    )
+                    if defaults is not None:
+                        defaults["saved_config"] = True
+                if (
+                    self.flow is not None
+                    and not saved_flow
+                    and "flow" not in h5_file
+                ):
+                    self.save_flow(h5_file)
+                    if defaults is not None:
+                        defaults["saved_flow"] = True
         if xp is not None:
             samples = samples.to_namespace(xp)
         samples.parameters = self.parameters
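Editor's note: the hunk above wires ``checkpoint_path``/``checkpoint_every`` into ``sample_posterior`` and forwards them to the sampler as ``checkpoint_file_path``/``checkpoint_every``. A minimal call sketch, assuming a fitted ``aspire`` instance (names and values are illustrative, not from the diff):

    samples = aspire.sample_posterior(
        n_samples=2000,
        sampler="smc",
        checkpoint_path="run.h5",     # config, flow, and /checkpoint/state land here
        checkpoint_every=5,           # write a checkpoint every 5 sampler iterations
        checkpoint_save_config=True,
    )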
@@ -388,6 +543,122 @@ class Aspire:
         else:
             return samples
 
+    @classmethod
+    def resume_from_file(
+        cls,
+        file_path: str,
+        *,
+        log_likelihood: Callable,
+        log_prior: Callable,
+        sampler: str | None = None,
+        checkpoint_path: str = "checkpoint",
+        checkpoint_dset: str = "state",
+        flow_path: str = "flow",
+        config_path: str = "aspire_config",
+        resume_kwargs: dict | None = None,
+    ):
+        """
+        Recreate an Aspire object from a single file and prepare to resume sampling.
+
+        Parameters
+        ----------
+        file_path : str
+            Path to the HDF5 file containing config, flow, and checkpoint.
+        log_likelihood : Callable
+            Log-likelihood function (required, not pickled).
+        log_prior : Callable
+            Log-prior function (required, not pickled).
+        sampler : str
+            Sampler type to use (e.g., 'smc', 'minipcn_smc', 'emcee_smc'). If None,
+            will attempt to infer from saved config or checkpoint metadata.
+        checkpoint_path : str
+            HDF5 group path where the checkpoint is stored.
+        checkpoint_dset : str
+            Dataset name within the checkpoint group.
+        flow_path : str
+            HDF5 path to the saved flow.
+        config_path : str
+            HDF5 path to the saved Aspire config.
+        resume_kwargs : dict | None
+            Optional overrides to apply when resuming (e.g., checkpoint_every).
+        """
+        (
+            aspire,
+            checkpoint_bytes,
+            checkpoint_state,
+            sampler_config,
+            saved_sampler_type,
+            n_samples,
+        ) = cls._build_aspire_from_file(
+            file_path=file_path,
+            log_likelihood=log_likelihood,
+            log_prior=log_prior,
+            checkpoint_path=checkpoint_path,
+            checkpoint_dset=checkpoint_dset,
+            flow_path=flow_path,
+            config_path=config_path,
+        )
+
+        sampler_config = sampler_config or {}
+        sampler_config.pop("sampler_class", None)
+
+        if checkpoint_bytes is not None:
+            aspire._resume_from_default = checkpoint_bytes
+            aspire._resume_sampler_type = (
+                sampler
+                or saved_sampler_type
+                or (
+                    checkpoint_state.get("sampler")
+                    if checkpoint_state
+                    else None
+                )
+            )
+        aspire._resume_n_samples = n_samples
+        aspire._resume_overrides = resume_kwargs or {}
+        aspire._resume_sampler_config = sampler_config
+        aspire._checkpoint_defaults = {
+            "path": file_path,
+            "every": 1,
+            "save_config": False,
+            "save_flow": False,
+            "saved_config": False,
+            "saved_flow": False,
+        }
+        return aspire
+
+    @contextmanager
+    def auto_checkpoint(
+        self,
+        path: str,
+        every: int = 1,
+        save_config: bool = True,
+        save_flow: bool = True,
+    ):
+        """
+        Context manager to auto-save checkpoints, config, and flow to a file.
+
+        Within the context, sample_posterior will default to writing checkpoints
+        to the given path with the specified frequency, and will append config/flow
+        after sampling.
+        """
+        prev = getattr(self, "_checkpoint_defaults", None)
+        self._checkpoint_defaults = {
+            "path": path,
+            "every": every,
+            "save_config": save_config,
+            "save_flow": save_flow,
+            "saved_config": False,
+            "saved_flow": False,
+        }
+        try:
+            yield self
+        finally:
+            if prev is None:
+                if hasattr(self, "_checkpoint_defaults"):
+                    delattr(self, "_checkpoint_defaults")
+            else:
+                self._checkpoint_defaults = prev
+
     def enable_pool(self, pool: mp.Pool, **kwargs):
         """Context manager to temporarily replace the log_likelihood method
         with a version that uses a multiprocessing pool to parallelize
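Editor's note: taken together with the ``fit`` changes earlier in this diff, ``auto_checkpoint`` lets both training and sampling default to the same file. A hedged sketch, assuming an ``aspire`` instance, ``training_samples``, and the ``log_likelihood``/``log_prior`` callables already exist:

    with aspire.auto_checkpoint("run.h5", every=1):
        aspire.fit(training_samples)      # flow and config are written to run.h5
        posterior = aspire.sample_posterior(
            n_samples=2000, sampler="smc"  # checkpoints go to /checkpoint/state
        )

    # Later, in a new session, rebuild and resume from the same file:
    aspire = Aspire.resume_from_file(
        "run.h5", log_likelihood=log_likelihood, log_prior=log_prior
    )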
@@ -417,8 +688,8 @@ class Aspire:
         method of the sampler.
         """
         config = {
-            "log_likelihood": self.log_likelihood,
-            "log_prior": self.log_prior,
+            "log_likelihood": function_id(self.log_likelihood),
+            "log_prior": function_id(self.log_prior),
             "dims": self.dims,
             "parameters": self.parameters,
             "periodic_parameters": self.periodic_parameters,
@@ -432,12 +703,16 @@ class Aspire:
             "flow_kwargs": self.flow_kwargs,
             "eps": self.eps,
         }
+        if hasattr(self, "_last_sampler_type"):
+            config["sampler_type"] = self._last_sampler_type
         if include_sampler_config:
+            if self.sampler is None:
+                raise ValueError("Sampler has not been initialized.")
             config["sampler_config"] = self.sampler.config_dict(**kwargs)
         return config
 
     def save_config(
-        self, h5_file: h5py.File, path="aspire_config", **kwargs
+        self, h5_file: h5py.File | AspireFile, path="aspire_config", **kwargs
     ) -> None:
         """Save the configuration to an HDF5 file.
 
@@ -484,6 +759,7 @@ class Aspire:
         FlowClass, xp = get_flow_wrapper(
             backend=self.flow_backend, flow_matching=self.flow_matching
         )
+        logger.debug(f"Loading flow of type {FlowClass} from {path}")
         self._flow = FlowClass.load(h5_file, path=path)
 
     def save_config_to_json(self, filename: str) -> None:
@@ -504,3 +780,80 @@ class Aspire:
         x, log_q = self.flow.sample_and_log_prob(n_samples)
         samples = Samples(x=x, log_q=log_q, xp=xp, parameters=self.parameters)
         return samples
+
+    # --- Resume helpers ---
+    @staticmethod
+    def _build_aspire_from_file(
+        file_path: str,
+        log_likelihood: Callable,
+        log_prior: Callable,
+        checkpoint_path: str,
+        checkpoint_dset: str,
+        flow_path: str,
+        config_path: str,
+    ):
+        """Construct an Aspire instance, load flow, and gather checkpoint metadata from file."""
+        with AspireFile(file_path, "r") as h5_file:
+            if config_path not in h5_file:
+                raise ValueError(
+                    f"Config path '{config_path}' not found in {file_path}"
+                )
+            config_dict = load_from_h5_file(h5_file, config_path)
+            try:
+                checkpoint_bytes = h5_file[checkpoint_path][checkpoint_dset][
+                    ...
+                ].tobytes()
+            except Exception:
+                logger.warning(
+                    "Checkpoint not found at %s/%s in %s; will resume without a checkpoint.",
+                    checkpoint_path,
+                    checkpoint_dset,
+                    file_path,
+                )
+                checkpoint_bytes = None
+
+        sampler_config = config_dict.pop("sampler_config", None)
+        saved_sampler_type = config_dict.pop("sampler_type", None)
+        if isinstance(config_dict.get("xp"), str):
+            config_dict["xp"] = resolve_xp(config_dict["xp"])
+        config_dict["log_likelihood"] = log_likelihood
+        config_dict["log_prior"] = log_prior
+
+        aspire = Aspire(**config_dict)
+
+        with AspireFile(file_path, "r") as h5_file:
+            if flow_path in h5_file:
+                logger.info(f"Loading flow from {flow_path} in {file_path}")
+                aspire.load_flow(h5_file, path=flow_path)
+            else:
+                raise ValueError(
+                    f"Flow path '{flow_path}' not found in {file_path}"
+                )
+
+        n_samples = None
+        checkpoint_state = None
+        if checkpoint_bytes is not None:
+            try:
+                checkpoint_state = pickle.loads(checkpoint_bytes)
+                samples_saved = (
+                    checkpoint_state.get("samples")
+                    if checkpoint_state
+                    else None
+                )
+                if samples_saved is not None:
+                    n_samples = len(samples_saved)
+                    if aspire.xp is None and hasattr(samples_saved, "xp"):
+                        aspire.xp = samples_saved.xp
+            except Exception:
+                logger.warning(
+                    "Failed to decode checkpoint; proceeding without resume state."
+                )
+
+        return (
+            aspire,
+            checkpoint_bytes,
+            checkpoint_state,
+            sampler_config,
+            saved_sampler_type,
+            n_samples,
+        )
{aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/src/aspire/flows/torch/flows.py
RENAMED
@@ -92,7 +92,7 @@ class BaseTorchFlow(Flow):
         config = load_from_h5_file(flow_grp, "config")
         config["dtype"] = decode_dtype(torch, config.get("dtype"))
         if "data_transform" in flow_grp:
-            from
+            from ...transforms import BaseTransform
 
             data_transform = BaseTransform.load(
                 flow_grp,