aspire-inference 0.1.0a10__tar.gz → 0.1.0a11__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/PKG-INFO +1 -1
  2. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/aspire_inference.egg-info/PKG-INFO +1 -1
  3. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/aspire_inference.egg-info/SOURCES.txt +2 -0
  4. aspire_inference-0.1.0a11/docs/checkpointing.rst +88 -0
  5. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/docs/index.rst +1 -0
  6. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/src/aspire/aspire.py +356 -4
  7. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/src/aspire/flows/torch/flows.py +1 -1
  8. aspire_inference-0.1.0a11/src/aspire/samplers/base.py +238 -0
  9. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/src/aspire/samplers/smc/base.py +133 -48
  10. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/src/aspire/samplers/smc/blackjax.py +8 -0
  11. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/src/aspire/samplers/smc/emcee.py +8 -0
  12. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/src/aspire/samplers/smc/minipcn.py +10 -2
  13. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/src/aspire/samples.py +11 -9
  14. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/src/aspire/utils.py +98 -0
  15. aspire_inference-0.1.0a11/tests/integration_tests/test_checkpointing.py +88 -0
  16. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/tests/test_samples.py +1 -1
  17. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/tests/test_utils.py +14 -1
  18. aspire_inference-0.1.0a10/src/aspire/samplers/base.py +0 -98
  19. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/.github/workflows/lint.yml +0 -0
  20. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/.github/workflows/publish.yml +0 -0
  21. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/.github/workflows/tests.yml +0 -0
  22. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/.gitignore +0 -0
  23. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/.pre-commit-config.yaml +0 -0
  24. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/LICENSE +0 -0
  25. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/README.md +0 -0
  26. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/aspire_inference.egg-info/dependency_links.txt +0 -0
  27. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/aspire_inference.egg-info/requires.txt +0 -0
  28. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/aspire_inference.egg-info/top_level.txt +0 -0
  29. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/docs/Makefile +0 -0
  30. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/docs/conf.py +0 -0
  31. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/docs/entry_points.rst +0 -0
  32. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/docs/examples.rst +0 -0
  33. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/docs/installation.rst +0 -0
  34. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/docs/multiprocessing.rst +0 -0
  35. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/docs/recipes.rst +0 -0
  36. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/docs/requirements.txt +0 -0
  37. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/docs/user_guide.rst +0 -0
  38. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/examples/basic_example.py +0 -0
  39. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/examples/blackjax_smc_example.py +0 -0
  40. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/examples/smc_example.py +0 -0
  41. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/pyproject.toml +0 -0
  42. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/readthedocs.yml +0 -0
  43. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/setup.cfg +0 -0
  44. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/src/aspire/__init__.py +0 -0
  45. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/src/aspire/flows/__init__.py +0 -0
  46. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/src/aspire/flows/base.py +0 -0
  47. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/src/aspire/flows/jax/__init__.py +0 -0
  48. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/src/aspire/flows/jax/flows.py +0 -0
  49. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/src/aspire/flows/jax/utils.py +0 -0
  50. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/src/aspire/flows/torch/__init__.py +0 -0
  51. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/src/aspire/history.py +0 -0
  52. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/src/aspire/plot.py +0 -0
  53. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/src/aspire/samplers/__init__.py +0 -0
  54. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/src/aspire/samplers/importance.py +0 -0
  55. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/src/aspire/samplers/mcmc.py +0 -0
  56. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/src/aspire/samplers/smc/__init__.py +0 -0
  57. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/src/aspire/transforms.py +0 -0
  58. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/tests/conftest.py +0 -0
  59. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/tests/integration_tests/conftest.py +0 -0
  60. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/tests/integration_tests/test_integration.py +0 -0
  61. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/tests/test_flows/test_flows_core.py +0 -0
  62. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/tests/test_flows/test_jax_flows/test_flowjax_flows.py +0 -0
  63. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/tests/test_flows/test_torch_flows/test_zuko_flows.py +0 -0
  64. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a11}/tests/test_transforms.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: aspire-inference
3
- Version: 0.1.0a10
3
+ Version: 0.1.0a11
4
4
  Summary: Accelerate Sequential Posterior Inference via REuse
5
5
  Author-email: "Michael J. Williams" <michaeljw1@googlemail.com>
6
6
  License: MIT
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: aspire-inference
3
- Version: 0.1.0a10
3
+ Version: 0.1.0a11
4
4
  Summary: Accelerate Sequential Posterior Inference via REuse
5
5
  Author-email: "Michael J. Williams" <michaeljw1@googlemail.com>
6
6
  License: MIT
@@ -13,6 +13,7 @@ aspire_inference.egg-info/dependency_links.txt
13
13
  aspire_inference.egg-info/requires.txt
14
14
  aspire_inference.egg-info/top_level.txt
15
15
  docs/Makefile
16
+ docs/checkpointing.rst
16
17
  docs/conf.py
17
18
  docs/entry_points.rst
18
19
  docs/examples.rst
@@ -53,6 +54,7 @@ tests/test_samples.py
53
54
  tests/test_transforms.py
54
55
  tests/test_utils.py
55
56
  tests/integration_tests/conftest.py
57
+ tests/integration_tests/test_checkpointing.py
56
58
  tests/integration_tests/test_integration.py
57
59
  tests/test_flows/test_flows_core.py
58
60
  tests/test_flows/test_jax_flows/test_flowjax_flows.py
@@ -0,0 +1,88 @@
1
+ Checkpointing and Resuming
2
+ ==========================
3
+
4
+ Aspire provides a few simple patterns to resume long runs.
5
+
6
+ Saving checkpoints while sampling
7
+ ---------------------------------
8
+
9
+ - Pass ``checkpoint_path`` (an HDF5 file) to :py:meth:`aspire.Aspire.sample_posterior`
10
+ to write checkpoints as the sampler runs. Use ``checkpoint_every`` to control
11
+ frequency and ``checkpoint_save_config`` to control whether the Aspire configuration is saved.
12
+ - For a convenience wrapper, wrap your sampling in
13
+ ``with aspire.auto_checkpoint("run.h5", every=1): ...``. Inside the context,
14
+ ``sample_posterior`` will default to checkpointing to that file, and the config/flow
15
+ will be updated as needed.
16
+
17
+ What gets saved
18
+ ^^^^^^^^^^^^^^^
19
+
20
+ - The sampler stores checkpoints under ``/checkpoint/state`` in the HDF5 file.
21
+ - Aspire writes ``/aspire_config`` (with ``sampler_type`` and ``sampler_config``) and
22
+ ``/flow``. An existing ``/aspire_config`` is overwritten when saving; an existing ``/flow`` is kept unless ``fit`` is called with ``overwrite=True``.
23
+
24
+ Resuming from a file
25
+ --------------------
26
+
27
+ - Use :py:meth:`aspire.Aspire.resume_from_file` to rebuild an Aspire instance and flow
28
+ from a checkpoint file:
29
+
30
+ .. code-block:: python
31
+
32
+ aspire = Aspire.resume_from_file(
33
+ "run.h5",
34
+ log_likelihood=log_likelihood,
35
+ log_prior=log_prior,
36
+ )
37
+ # Optionally continue checkpointing to the same file
38
+ with aspire.auto_checkpoint("run.h5", every=1):
39
+ samples = aspire.sample_posterior()
40
+
41
+ - ``resume_from_file`` loads config, flow, and the last checkpoint (if present), and
42
+ primes the instance to resume sampling; you can still override sampler kwargs when
43
+ calling ``sample_posterior``.
44
+
45
+ Manual resume via ``sample_posterior`` args
46
+ -------------------------------------------
47
+
48
+ - If you have a checkpoint blob (bytes or dict) already, you can pass it directly:
49
+
50
+ .. code-block:: python
51
+
52
+ samples = aspire.sample_posterior(
53
+ n_samples=...,
54
+ sampler="smc",
55
+ resume_from=checkpoint_bytes_or_dict,
56
+ checkpoint_path="run.h5", # optional: keep writing checkpoints
57
+ )
58
+
59
+ - To resume from a file without using ``resume_from_file``, load the checkpoint bytes
60
+ and flow yourself, then call ``sample_posterior``:
61
+
62
+ .. code-block:: python
63
+
64
+ from aspire.utils import AspireFile
65
+
66
+ aspire = Aspire(..., flow_backend="zuko")
67
+ with AspireFile("run.h5", "r") as f:
68
+ aspire.load_flow(f, path="flow")
69
+ # Standard layout is /checkpoint/state; adjust if you used a different path
70
+ checkpoint_bytes = f["checkpoint"]["state"][...].tobytes()
71
+ samples = aspire.sample_posterior(
72
+ n_samples=..., sampler="smc", resume_from="run.h5"
73
+ # or resume_from=checkpoint_bytes if you prefer to pass bytes directly
74
+ )
75
+
76
+ Notes and tips
77
+ --------------
78
+
79
+ - Checkpoint files must be HDF5 (``.h5``/``.hdf5``).
80
+ - If a checkpoint is missing in the file (e.g., sampling never wrote one), the flow
81
+ and config are still loaded; you can simply start sampling again and checkpointing
82
+ will continue to the same file.
83
+ - For manual control, you can always call ``save_config`` / ``save_flow`` yourself
84
+ on an :class:`aspire.utils.AspireFile`.
85
+ - SMC samplers also accept a custom ``checkpoint_callback`` and ``checkpoint_every`` if
86
+ you want full control over how checkpoints are persisted or inspected. Provide a
87
+ callable that accepts the checkpoint state dict; from there you can, for example,
88
+ serialize to another format or push to remote storage.
@@ -57,6 +57,7 @@ examples, and the complete API reference.
57
57
 
58
58
  installation
59
59
  user_guide
60
+ checkpointing
60
61
  recipes
61
62
  multiprocessing
62
63
  examples
@@ -1,5 +1,8 @@
1
+ import copy
1
2
  import logging
2
3
  import multiprocessing as mp
4
+ import pickle
5
+ from contextlib import contextmanager
3
6
  from inspect import signature
4
7
  from typing import Any, Callable
5
8
 
@@ -8,13 +11,19 @@ import h5py
8
11
  from .flows import get_flow_wrapper
9
12
  from .flows.base import Flow
10
13
  from .history import History
14
+ from .samplers.base import Sampler
11
15
  from .samples import Samples
12
16
  from .transforms import (
13
17
  CompositeTransform,
14
18
  FlowPreconditioningTransform,
15
19
  FlowTransform,
16
20
  )
17
- from .utils import recursively_save_to_h5_file
21
+ from .utils import (
22
+ AspireFile,
23
+ load_from_h5_file,
24
+ recursively_save_to_h5_file,
25
+ resolve_xp,
26
+ )
18
27
 
19
28
  logger = logging.getLogger(__name__)
20
29
 
@@ -102,6 +111,7 @@ class Aspire:
102
111
  self.dtype = dtype
103
112
 
104
113
  self._flow = flow
114
+ self._sampler = None
105
115
 
106
116
  @property
107
117
  def flow(self):
@@ -114,7 +124,7 @@ class Aspire:
114
124
  self._flow = flow
115
125
 
116
126
  @property
117
- def sampler(self):
127
+ def sampler(self) -> Sampler | None:
118
128
  """The sampler object."""
119
129
  return self._sampler
120
130
 
@@ -192,7 +202,29 @@ class Aspire:
192
202
  **self.flow_kwargs,
193
203
  )
194
204
 
195
- def fit(self, samples: Samples, **kwargs) -> History:
205
+ def fit(
206
+ self,
207
+ samples: Samples,
208
+ checkpoint_path: str | None = None,
209
+ checkpoint_save_config: bool = True,
210
+ overwrite: bool = False,
211
+ **kwargs,
212
+ ) -> History:
213
+ """Fit the normalizing flow to the provided samples.
214
+
215
+ Parameters
216
+ ----------
217
+ samples : Samples
218
+ The samples to fit the flow to.
219
+ checkpoint_path : str | None
220
+ Path to save the checkpoint. If None, no checkpoint is saved.
221
+ checkpoint_save_config : bool
222
+ Whether to save the Aspire configuration to the checkpoint.
223
+ overwrite : bool
224
+ Whether to overwrite an existing flow in the checkpoint file.
225
+ kwargs : dict
226
+ Keyword arguments to pass to the flow's fit method.
227
+ """
196
228
  if self.xp is None:
197
229
  self.xp = samples.xp
198
230
 
@@ -202,6 +234,28 @@ class Aspire:
202
234
  self.training_samples = samples
203
235
  logger.info(f"Training with {len(samples.x)} samples")
204
236
  history = self.flow.fit(samples.x, **kwargs)
237
+ defaults = getattr(self, "_checkpoint_defaults", None)
238
+ if checkpoint_path is None and defaults:
239
+ checkpoint_path = defaults["path"]
240
+ checkpoint_save_config = defaults["save_config"]
241
+ saved_config = (
242
+ defaults.get("saved_config", False) if defaults else False
243
+ )
244
+ if checkpoint_path is not None:
245
+ with AspireFile(checkpoint_path, "a") as h5_file:
246
+ if checkpoint_save_config and not saved_config:
247
+ if "aspire_config" in h5_file:
248
+ del h5_file["aspire_config"]
249
+ self.save_config(h5_file, include_sampler_config=False)
250
+ if defaults is not None:
251
+ defaults["saved_config"] = True
252
+ # Save flow only if missing or overwrite=True
253
+ if "flow" in h5_file:
254
+ if overwrite:
255
+ del h5_file["flow"]
256
+ self.save_flow(h5_file)
257
+ else:
258
+ self.save_flow(h5_file)
205
259
  return history
206
260
 
207
261
  def get_sampler_class(self, sampler_type: str) -> Callable:
@@ -241,6 +295,13 @@ class Aspire:
241
295
  ----------
242
296
  sampler_type : str
243
297
  The type of sampler to use. Options are 'importance', 'emcee', or 'smc'.
298
+ preconditioning: str
299
+ Type of preconditioning to apply in the sampler. Options are
300
+ 'default', 'flow', or 'none'.
301
+ preconditioning_kwargs: dict
302
+ Keyword arguments to pass to the preconditioning transform.
303
+ kwargs : dict
304
+ Keyword arguments to pass to the sampler.
244
305
  """
245
306
  SamplerClass = self.get_sampler_class(sampler_type)
246
307
 
@@ -304,6 +365,9 @@ class Aspire:
304
365
  return_history: bool = False,
305
366
  preconditioning: str | None = None,
306
367
  preconditioning_kwargs: dict | None = None,
368
+ checkpoint_path: str | None = None,
369
+ checkpoint_every: int = 1,
370
+ checkpoint_save_config: bool = True,
307
371
  **kwargs,
308
372
  ) -> Samples:
309
373
  """Draw samples from the posterior distribution.
@@ -342,6 +406,14 @@ class Aspire:
342
406
  will default to 'none' and the other samplers to 'default'
343
407
  preconditioning_kwargs: dict
344
408
  Keyword arguments to pass to the preconditioning transform.
409
+ checkpoint_path : str | None
410
+ Path to save the checkpoint. If None, no checkpoint is saved unless
411
+ within an :py:meth:`auto_checkpoint` context or a custom callback
412
+ is provided.
413
+ checkpoint_every : int
414
+ Frequency (in number of sampler iterations) to save the checkpoint.
415
+ checkpoint_save_config : bool
416
+ Whether to save the Aspire configuration to the checkpoint.
345
417
  kwargs : dict
346
418
  Keyword arguments to pass to the sampler. These are passed
347
419
  automatically to the init method of the sampler or to the sample
@@ -352,6 +424,22 @@ class Aspire:
352
424
  samples : Samples
353
425
  Samples object containing samples and their corresponding weights.
354
426
  """
427
+ if (
428
+ sampler == "importance"
429
+ and hasattr(self, "_resume_sampler_type")
430
+ and self._resume_sampler_type
431
+ ):
432
+ sampler = self._resume_sampler_type
433
+
434
+ if "resume_from" not in kwargs and hasattr(
435
+ self, "_resume_from_default"
436
+ ):
437
+ kwargs["resume_from"] = self._resume_from_default
438
+ if hasattr(self, "_resume_overrides"):
439
+ kwargs.update(self._resume_overrides)
440
+ if hasattr(self, "_resume_n_samples") and n_samples == 1000:
441
+ n_samples = self._resume_n_samples
442
+
355
443
  SamplerClass = self.get_sampler_class(sampler)
356
444
  # Determine sampler initialization parameters
357
445
  # and remove them from kwargs
@@ -373,7 +461,73 @@ class Aspire:
373
461
  preconditioning_kwargs=preconditioning_kwargs,
374
462
  **sampler_kwargs,
375
463
  )
464
+ self._last_sampler_type = sampler
465
+ # Auto-checkpoint convenience: set defaults for checkpointing to a single file
466
+ defaults = getattr(self, "_checkpoint_defaults", None)
467
+ if checkpoint_path is None and defaults:
468
+ checkpoint_path = defaults["path"]
469
+ checkpoint_every = defaults["every"]
470
+ checkpoint_save_config = defaults["save_config"]
471
+ saved_flow = defaults.get("saved_flow", False) if defaults else False
472
+ saved_config = (
473
+ defaults.get("saved_config", False) if defaults else False
474
+ )
475
+ if checkpoint_path is not None:
476
+ kwargs.setdefault("checkpoint_file_path", checkpoint_path)
477
+ kwargs.setdefault("checkpoint_every", checkpoint_every)
478
+ with AspireFile(checkpoint_path, "a") as h5_file:
479
+ if checkpoint_save_config:
480
+ if "aspire_config" in h5_file:
481
+ del h5_file["aspire_config"]
482
+ self.save_config(
483
+ h5_file,
484
+ include_sampler_config=True,
485
+ include_sample_calls=False,
486
+ )
487
+ saved_config = True
488
+ if defaults is not None:
489
+ defaults["saved_config"] = True
490
+ if (
491
+ self.flow is not None
492
+ and not saved_flow
493
+ and "flow" not in h5_file
494
+ ):
495
+ self.save_flow(h5_file)
496
+ saved_flow = True
497
+ if defaults is not None:
498
+ defaults["saved_flow"] = True
499
+
376
500
  samples = self._sampler.sample(n_samples, **kwargs)
501
+ self._last_sample_posterior_kwargs = {
502
+ "n_samples": n_samples,
503
+ "sampler": sampler,
504
+ "xp": xp,
505
+ "return_history": return_history,
506
+ "preconditioning": preconditioning,
507
+ "preconditioning_kwargs": preconditioning_kwargs,
508
+ "sampler_init_kwargs": sampler_kwargs,
509
+ "sample_kwargs": copy.deepcopy(kwargs),
510
+ }
511
+ if checkpoint_path is not None:
512
+ with AspireFile(checkpoint_path, "a") as h5_file:
513
+ if checkpoint_save_config and not saved_config:
514
+ if "aspire_config" in h5_file:
515
+ del h5_file["aspire_config"]
516
+ self.save_config(
517
+ h5_file,
518
+ include_sampler_config=True,
519
+ include_sample_calls=False,
520
+ )
521
+ if defaults is not None:
522
+ defaults["saved_config"] = True
523
+ if (
524
+ self.flow is not None
525
+ and not saved_flow
526
+ and "flow" not in h5_file
527
+ ):
528
+ self.save_flow(h5_file)
529
+ if defaults is not None:
530
+ defaults["saved_flow"] = True
377
531
  if xp is not None:
378
532
  samples = samples.to_namespace(xp)
379
533
  samples.parameters = self.parameters
@@ -388,6 +542,122 @@ class Aspire:
388
542
  else:
389
543
  return samples
390
544
 
545
+ @classmethod
546
+ def resume_from_file(
547
+ cls,
548
+ file_path: str,
549
+ *,
550
+ log_likelihood: Callable,
551
+ log_prior: Callable,
552
+ sampler: str | None = None,
553
+ checkpoint_path: str = "checkpoint",
554
+ checkpoint_dset: str = "state",
555
+ flow_path: str = "flow",
556
+ config_path: str = "aspire_config",
557
+ resume_kwargs: dict | None = None,
558
+ ):
559
+ """
560
+ Recreate an Aspire object from a single file and prepare to resume sampling.
561
+
562
+ Parameters
563
+ ----------
564
+ file_path : str
565
+ Path to the HDF5 file containing config, flow, and checkpoint.
566
+ log_likelihood : Callable
567
+ Log-likelihood function (required, not pickled).
568
+ log_prior : Callable
569
+ Log-prior function (required, not pickled).
570
+ sampler : str | None
571
+ Sampler type to use (e.g., 'smc', 'minipcn_smc', 'emcee_smc'). If None,
572
+ will attempt to infer from saved config or checkpoint metadata.
573
+ checkpoint_path : str
574
+ HDF5 group path where the checkpoint is stored.
575
+ checkpoint_dset : str
576
+ Dataset name within the checkpoint group.
577
+ flow_path : str
578
+ HDF5 path to the saved flow.
579
+ config_path : str
580
+ HDF5 path to the saved Aspire config.
581
+ resume_kwargs : dict | None
582
+ Optional overrides to apply when resuming (e.g., checkpoint_every).
583
+ """
584
+ (
585
+ aspire,
586
+ checkpoint_bytes,
587
+ checkpoint_state,
588
+ sampler_config,
589
+ saved_sampler_type,
590
+ n_samples,
591
+ ) = cls._build_aspire_from_file(
592
+ file_path=file_path,
593
+ log_likelihood=log_likelihood,
594
+ log_prior=log_prior,
595
+ checkpoint_path=checkpoint_path,
596
+ checkpoint_dset=checkpoint_dset,
597
+ flow_path=flow_path,
598
+ config_path=config_path,
599
+ )
600
+
601
+ sampler_config = sampler_config or {}
602
+ sampler_config.pop("sampler_class", None)
603
+
604
+ if checkpoint_bytes is not None:
605
+ aspire._resume_from_default = checkpoint_bytes
606
+ aspire._resume_sampler_type = (
607
+ sampler
608
+ or saved_sampler_type
609
+ or (
610
+ checkpoint_state.get("sampler")
611
+ if checkpoint_state
612
+ else None
613
+ )
614
+ )
615
+ aspire._resume_n_samples = n_samples
616
+ aspire._resume_overrides = resume_kwargs or {}
617
+ aspire._resume_sampler_config = sampler_config
618
+ aspire._checkpoint_defaults = {
619
+ "path": file_path,
620
+ "every": 1,
621
+ "save_config": False,
622
+ "save_flow": False,
623
+ "saved_config": False,
624
+ "saved_flow": False,
625
+ }
626
+ return aspire
627
+
628
+ @contextmanager
629
+ def auto_checkpoint(
630
+ self,
631
+ path: str,
632
+ every: int = 1,
633
+ save_config: bool = True,
634
+ save_flow: bool = True,
635
+ ):
636
+ """
637
+ Context manager to auto-save checkpoints, config, and flow to a file.
638
+
639
+ Within the context, sample_posterior will default to writing checkpoints
640
+ to the given path with the specified frequency, and will append config/flow
641
+ after sampling.
642
+ """
643
+ prev = getattr(self, "_checkpoint_defaults", None)
644
+ self._checkpoint_defaults = {
645
+ "path": path,
646
+ "every": every,
647
+ "save_config": save_config,
648
+ "save_flow": save_flow,
649
+ "saved_config": False,
650
+ "saved_flow": False,
651
+ }
652
+ try:
653
+ yield self
654
+ finally:
655
+ if prev is None:
656
+ if hasattr(self, "_checkpoint_defaults"):
657
+ delattr(self, "_checkpoint_defaults")
658
+ else:
659
+ self._checkpoint_defaults = prev
660
+
391
661
  def enable_pool(self, pool: mp.Pool, **kwargs):
392
662
  """Context manager to temporarily replace the log_likelihood method
393
663
  with a version that uses a multiprocessing pool to parallelize
@@ -432,12 +702,16 @@ class Aspire:
432
702
  "flow_kwargs": self.flow_kwargs,
433
703
  "eps": self.eps,
434
704
  }
705
+ if hasattr(self, "_last_sampler_type"):
706
+ config["sampler_type"] = self._last_sampler_type
435
707
  if include_sampler_config:
708
+ if self.sampler is None:
709
+ raise ValueError("Sampler has not been initialized.")
436
710
  config["sampler_config"] = self.sampler.config_dict(**kwargs)
437
711
  return config
438
712
 
439
713
  def save_config(
440
- self, h5_file: h5py.File, path="aspire_config", **kwargs
714
+ self, h5_file: h5py.File | AspireFile, path="aspire_config", **kwargs
441
715
  ) -> None:
442
716
  """Save the configuration to an HDF5 file.
443
717
 
@@ -484,6 +758,7 @@ class Aspire:
484
758
  FlowClass, xp = get_flow_wrapper(
485
759
  backend=self.flow_backend, flow_matching=self.flow_matching
486
760
  )
761
+ logger.debug(f"Loading flow of type {FlowClass} from {path}")
487
762
  self._flow = FlowClass.load(h5_file, path=path)
488
763
 
489
764
  def save_config_to_json(self, filename: str) -> None:
@@ -504,3 +779,80 @@ class Aspire:
504
779
  x, log_q = self.flow.sample_and_log_prob(n_samples)
505
780
  samples = Samples(x=x, log_q=log_q, xp=xp, parameters=self.parameters)
506
781
  return samples
782
+
783
+ # --- Resume helpers ---
784
+ @staticmethod
785
+ def _build_aspire_from_file(
786
+ file_path: str,
787
+ log_likelihood: Callable,
788
+ log_prior: Callable,
789
+ checkpoint_path: str,
790
+ checkpoint_dset: str,
791
+ flow_path: str,
792
+ config_path: str,
793
+ ):
794
+ """Construct an Aspire instance, load flow, and gather checkpoint metadata from file."""
795
+ with AspireFile(file_path, "r") as h5_file:
796
+ if config_path not in h5_file:
797
+ raise ValueError(
798
+ f"Config path '{config_path}' not found in {file_path}"
799
+ )
800
+ config_dict = load_from_h5_file(h5_file, config_path)
801
+ try:
802
+ checkpoint_bytes = h5_file[checkpoint_path][checkpoint_dset][
803
+ ...
804
+ ].tobytes()
805
+ except Exception:
806
+ logger.warning(
807
+ "Checkpoint not found at %s/%s in %s; will resume without a checkpoint.",
808
+ checkpoint_path,
809
+ checkpoint_dset,
810
+ file_path,
811
+ )
812
+ checkpoint_bytes = None
813
+
814
+ sampler_config = config_dict.pop("sampler_config", None)
815
+ saved_sampler_type = config_dict.pop("sampler_type", None)
816
+ if isinstance(config_dict.get("xp"), str):
817
+ config_dict["xp"] = resolve_xp(config_dict["xp"])
818
+ config_dict["log_likelihood"] = log_likelihood
819
+ config_dict["log_prior"] = log_prior
820
+
821
+ aspire = Aspire(**config_dict)
822
+
823
+ with AspireFile(file_path, "r") as h5_file:
824
+ if flow_path in h5_file:
825
+ logger.info(f"Loading flow from {flow_path} in {file_path}")
826
+ aspire.load_flow(h5_file, path=flow_path)
827
+ else:
828
+ raise ValueError(
829
+ f"Flow path '{flow_path}' not found in {file_path}"
830
+ )
831
+
832
+ n_samples = None
833
+ checkpoint_state = None
834
+ if checkpoint_bytes is not None:
835
+ try:
836
+ checkpoint_state = pickle.loads(checkpoint_bytes)
837
+ samples_saved = (
838
+ checkpoint_state.get("samples")
839
+ if checkpoint_state
840
+ else None
841
+ )
842
+ if samples_saved is not None:
843
+ n_samples = len(samples_saved)
844
+ if aspire.xp is None and hasattr(samples_saved, "xp"):
845
+ aspire.xp = samples_saved.xp
846
+ except Exception:
847
+ logger.warning(
848
+ "Failed to decode checkpoint; proceeding without resume state."
849
+ )
850
+
851
+ return (
852
+ aspire,
853
+ checkpoint_bytes,
854
+ checkpoint_state,
855
+ sampler_config,
856
+ saved_sampler_type,
857
+ n_samples,
858
+ )
@@ -92,7 +92,7 @@ class BaseTorchFlow(Flow):
92
92
  config = load_from_h5_file(flow_grp, "config")
93
93
  config["dtype"] = decode_dtype(torch, config.get("dtype"))
94
94
  if "data_transform" in flow_grp:
95
- from ..transforms import BaseTransform
95
+ from ...transforms import BaseTransform
96
96
 
97
97
  data_transform = BaseTransform.load(
98
98
  flow_grp,