aspire-inference 0.1.0a10__tar.gz → 0.1.0a12__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/PKG-INFO +1 -1
  2. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/aspire_inference.egg-info/PKG-INFO +1 -1
  3. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/aspire_inference.egg-info/SOURCES.txt +2 -0
  4. aspire_inference-0.1.0a12/docs/checkpointing.rst +88 -0
  5. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/docs/index.rst +1 -0
  6. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/src/aspire/aspire.py +359 -6
  7. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/src/aspire/flows/torch/flows.py +1 -1
  8. aspire_inference-0.1.0a12/src/aspire/samplers/base.py +238 -0
  9. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/src/aspire/samplers/smc/base.py +133 -48
  10. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/src/aspire/samplers/smc/blackjax.py +8 -0
  11. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/src/aspire/samplers/smc/emcee.py +8 -0
  12. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/src/aspire/samplers/smc/minipcn.py +10 -2
  13. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/src/aspire/samples.py +11 -9
  14. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/src/aspire/utils.py +119 -0
  15. aspire_inference-0.1.0a12/tests/integration_tests/test_checkpointing.py +88 -0
  16. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/tests/test_samples.py +1 -1
  17. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/tests/test_utils.py +30 -1
  18. aspire_inference-0.1.0a10/src/aspire/samplers/base.py +0 -98
  19. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/.github/workflows/lint.yml +0 -0
  20. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/.github/workflows/publish.yml +0 -0
  21. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/.github/workflows/tests.yml +0 -0
  22. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/.gitignore +0 -0
  23. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/.pre-commit-config.yaml +0 -0
  24. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/LICENSE +0 -0
  25. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/README.md +0 -0
  26. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/aspire_inference.egg-info/dependency_links.txt +0 -0
  27. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/aspire_inference.egg-info/requires.txt +0 -0
  28. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/aspire_inference.egg-info/top_level.txt +0 -0
  29. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/docs/Makefile +0 -0
  30. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/docs/conf.py +0 -0
  31. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/docs/entry_points.rst +0 -0
  32. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/docs/examples.rst +0 -0
  33. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/docs/installation.rst +0 -0
  34. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/docs/multiprocessing.rst +0 -0
  35. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/docs/recipes.rst +0 -0
  36. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/docs/requirements.txt +0 -0
  37. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/docs/user_guide.rst +0 -0
  38. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/examples/basic_example.py +0 -0
  39. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/examples/blackjax_smc_example.py +0 -0
  40. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/examples/smc_example.py +0 -0
  41. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/pyproject.toml +0 -0
  42. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/readthedocs.yml +0 -0
  43. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/setup.cfg +0 -0
  44. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/src/aspire/__init__.py +0 -0
  45. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/src/aspire/flows/__init__.py +0 -0
  46. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/src/aspire/flows/base.py +0 -0
  47. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/src/aspire/flows/jax/__init__.py +0 -0
  48. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/src/aspire/flows/jax/flows.py +0 -0
  49. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/src/aspire/flows/jax/utils.py +0 -0
  50. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/src/aspire/flows/torch/__init__.py +0 -0
  51. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/src/aspire/history.py +0 -0
  52. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/src/aspire/plot.py +0 -0
  53. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/src/aspire/samplers/__init__.py +0 -0
  54. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/src/aspire/samplers/importance.py +0 -0
  55. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/src/aspire/samplers/mcmc.py +0 -0
  56. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/src/aspire/samplers/smc/__init__.py +0 -0
  57. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/src/aspire/transforms.py +0 -0
  58. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/tests/conftest.py +0 -0
  59. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/tests/integration_tests/conftest.py +0 -0
  60. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/tests/integration_tests/test_integration.py +0 -0
  61. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/tests/test_flows/test_flows_core.py +0 -0
  62. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/tests/test_flows/test_jax_flows/test_flowjax_flows.py +0 -0
  63. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/tests/test_flows/test_torch_flows/test_zuko_flows.py +0 -0
  64. {aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/tests/test_transforms.py +0 -0
{aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: aspire-inference
-Version: 0.1.0a10
+Version: 0.1.0a12
 Summary: Accelerate Sequential Posterior Inference via REuse
 Author-email: "Michael J. Williams" <michaeljw1@googlemail.com>
 License: MIT
{aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/aspire_inference.egg-info/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: aspire-inference
-Version: 0.1.0a10
+Version: 0.1.0a12
 Summary: Accelerate Sequential Posterior Inference via REuse
 Author-email: "Michael J. Williams" <michaeljw1@googlemail.com>
 License: MIT
{aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/aspire_inference.egg-info/SOURCES.txt

@@ -13,6 +13,7 @@ aspire_inference.egg-info/dependency_links.txt
 aspire_inference.egg-info/requires.txt
 aspire_inference.egg-info/top_level.txt
 docs/Makefile
+docs/checkpointing.rst
 docs/conf.py
 docs/entry_points.rst
 docs/examples.rst
@@ -53,6 +54,7 @@ tests/test_samples.py
 tests/test_transforms.py
 tests/test_utils.py
 tests/integration_tests/conftest.py
+tests/integration_tests/test_checkpointing.py
 tests/integration_tests/test_integration.py
 tests/test_flows/test_flows_core.py
 tests/test_flows/test_jax_flows/test_flowjax_flows.py
aspire_inference-0.1.0a12/docs/checkpointing.rst

@@ -0,0 +1,88 @@
+Checkpointing and Resuming
+==========================
+
+Aspire provides a few simple patterns to resume long runs.
+
+Saving checkpoints while sampling
+---------------------------------
+
+- Pass ``checkpoint_path`` (an HDF5 file) to :py:meth:`aspire.Aspire.sample_posterior`
+  to write checkpoints as the sampler runs. Use ``checkpoint_every`` to control
+  frequency and ``checkpoint_save_config/flow`` to control what metadata is saved.
+- For a convenience wrapper, wrap your sampling in
+  ``with aspire.auto_checkpoint("run.h5", every=1): ...``. Inside the context,
+  ``sample_posterior`` will default to checkpointing to that file, and the config/flow
+  will be updated as needed.
+
+What gets saved
+^^^^^^^^^^^^^^^
+
+- The sampler stores checkpoints under ``/checkpoint/state`` in the HDF5 file.
+- Aspire writes ``/aspire_config`` (with ``sampler_type`` and ``sampler_config``) and
+  ``/flow``. If these already exist, they are overwritten when saving.
+
+Resuming from a file
+--------------------
+
+- Use :py:meth:`aspire.Aspire.resume_from_file` to rebuild an Aspire instance and flow
+  from a checkpoint file:
+
+  .. code-block:: python
+
+      aspire = Aspire.resume_from_file(
+          "run.h5",
+          log_likelihood=log_likelihood,
+          log_prior=log_prior,
+      )
+      # Optionally continue checkpointing to the same file
+      with aspire.auto_checkpoint("run.h5", every=1):
+          samples = aspire.sample_posterior()
+
+- ``resume_from_file`` loads config, flow, and the last checkpoint (if present), and
+  primes the instance to resume sampling; you can still override sampler kwargs when
+  calling ``sample_posterior``.
+
+Manual resume via ``sample_posterior`` args
+-------------------------------------------
+
+- If you have a checkpoint blob (bytes or dict) already, you can pass it directly:
+
+  .. code-block:: python
+
+      samples = aspire.sample_posterior(
+          n_samples=...,
+          sampler="smc",
+          resume_from=checkpoint_bytes_or_dict,
+          checkpoint_path="run.h5",  # optional: keep writing checkpoints
+      )
+
+- To resume from a file without using ``resume_from_file``, load the checkpoint bytes
+  and flow yourself, then call ``sample_posterior``:
+
+  .. code-block:: python
+
+      from aspire.utils import AspireFile
+
+      aspire = Aspire(..., flow_backend="zuko")
+      with AspireFile("run.h5", "r") as f:
+          aspire.load_flow(f, path="flow")
+          # Standard layout is /checkpoint/state; adjust if you used a different path
+          checkpoint_bytes = f["checkpoint"]["state"][...].tobytes()
+      samples = aspire.sample_posterior(
+          n_samples=..., sampler="smc", resume_from="run.h5"
+          # or resume_from=checkpoint_bytes if you prefer to pass bytes directly
+      )
+
+Notes and tips
+--------------
+
+- Checkpoint files must be HDF5 (``.h5``/``.hdf5``).
+- If a checkpoint is missing in the file (e.g., sampling never wrote one), the flow
+  and config are still loaded; you can simply start sampling again and checkpointing
+  will continue to the same file.
+- For manual control, you can always call ``save_config`` / ``save_flow`` yourself
+  on an :class:`aspire.utils.AspireFile`.
+- SMC samplers also accept a custom ``checkpoint_callback`` and ``checkpoint_every`` if
+  you want full control over how checkpoints are persisted or inspected. Provide a
+  callable that accepts the checkpoint state dict; from there you can, for example,
+  serialize to another format or push to remote storage.
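
For context on the last note above, here is a minimal sketch of a custom checkpoint callback. The callback name and the pickle-based persistence are illustrative assumptions; only the facts that SMC samplers accept ``checkpoint_callback``/``checkpoint_every`` and that the callback receives the checkpoint state dict come from the documentation above, and the kwargs are assumed to be forwarded to the sampler by ``sample_posterior`` as its docstring describes for sampler keyword arguments.

    import pickle

    def save_state_elsewhere(state: dict) -> None:
        # Hypothetical callback: persist the checkpoint state dict in another format.
        with open("checkpoint_state.pkl", "wb") as handle:
            pickle.dump(state, handle)

    # 'aspire' is assumed to be an existing, fitted Aspire instance.
    samples = aspire.sample_posterior(
        n_samples=1000,
        sampler="smc",
        checkpoint_callback=save_state_elsewhere,
        checkpoint_every=5,
    )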
{aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/docs/index.rst

@@ -57,6 +57,7 @@ examples, and the complete API reference.
 
    installation
    user_guide
+   checkpointing
    recipes
    multiprocessing
    examples
{aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/src/aspire/aspire.py

@@ -1,5 +1,8 @@
+import copy
 import logging
 import multiprocessing as mp
+import pickle
+from contextlib import contextmanager
 from inspect import signature
 from typing import Any, Callable
 
@@ -8,13 +11,20 @@ import h5py
 from .flows import get_flow_wrapper
 from .flows.base import Flow
 from .history import History
+from .samplers.base import Sampler
 from .samples import Samples
 from .transforms import (
     CompositeTransform,
     FlowPreconditioningTransform,
     FlowTransform,
 )
-from .utils import recursively_save_to_h5_file
+from .utils import (
+    AspireFile,
+    function_id,
+    load_from_h5_file,
+    recursively_save_to_h5_file,
+    resolve_xp,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -102,6 +112,7 @@ class Aspire:
         self.dtype = dtype
 
         self._flow = flow
+        self._sampler = None
 
     @property
     def flow(self):
@@ -114,7 +125,7 @@
         self._flow = flow
 
     @property
-    def sampler(self):
+    def sampler(self) -> Sampler | None:
         """The sampler object."""
         return self._sampler
 
@@ -192,7 +203,29 @@
             **self.flow_kwargs,
         )
 
-    def fit(self, samples: Samples, **kwargs) -> History:
+    def fit(
+        self,
+        samples: Samples,
+        checkpoint_path: str | None = None,
+        checkpoint_save_config: bool = True,
+        overwrite: bool = False,
+        **kwargs,
+    ) -> History:
+        """Fit the normalizing flow to the provided samples.
+
+        Parameters
+        ----------
+        samples : Samples
+            The samples to fit the flow to.
+        checkpoint_path : str | None
+            Path to save the checkpoint. If None, no checkpoint is saved.
+        checkpoint_save_config : bool
+            Whether to save the Aspire configuration to the checkpoint.
+        overwrite : bool
+            Whether to overwrite an existing flow in the checkpoint file.
+        kwargs : dict
+            Keyword arguments to pass to the flow's fit method.
+        """
         if self.xp is None:
             self.xp = samples.xp
 
@@ -202,6 +235,28 @@
         self.training_samples = samples
         logger.info(f"Training with {len(samples.x)} samples")
         history = self.flow.fit(samples.x, **kwargs)
+        defaults = getattr(self, "_checkpoint_defaults", None)
+        if checkpoint_path is None and defaults:
+            checkpoint_path = defaults["path"]
+            checkpoint_save_config = defaults["save_config"]
+        saved_config = (
+            defaults.get("saved_config", False) if defaults else False
+        )
+        if checkpoint_path is not None:
+            with AspireFile(checkpoint_path, "a") as h5_file:
+                if checkpoint_save_config and not saved_config:
+                    if "aspire_config" in h5_file:
+                        del h5_file["aspire_config"]
+                    self.save_config(h5_file, include_sampler_config=False)
+                    if defaults is not None:
+                        defaults["saved_config"] = True
+                # Save flow only if missing or overwrite=True
+                if "flow" in h5_file:
+                    if overwrite:
+                        del h5_file["flow"]
+                        self.save_flow(h5_file)
+                else:
+                    self.save_flow(h5_file)
         return history
 
     def get_sampler_class(self, sampler_type: str) -> Callable:
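
Given the new ``fit`` signature and the checkpoint logic in the hunk above, a minimal usage sketch; the file name is illustrative and ``samples`` is assumed to be an existing ``Samples`` object.

    # Fit the flow and write it (plus the Aspire config) to an HDF5 checkpoint file.
    history = aspire.fit(
        samples,
        checkpoint_path="run.h5",      # saved via AspireFile after training
        checkpoint_save_config=True,   # writes /aspire_config unless already saved
        overwrite=True,                # replace an existing /flow group if present
    )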
@@ -241,6 +296,13 @@
         ----------
         sampler_type : str
             The type of sampler to use. Options are 'importance', 'emcee', or 'smc'.
+        preconditioning: str
+            Type of preconditioning to apply in the sampler. Options are
+            'default', 'flow', or 'none'.
+        preconditioning_kwargs: dict
+            Keyword arguments to pass to the preconditioning transform.
+        kwargs : dict
+            Keyword arguments to pass to the sampler.
         """
         SamplerClass = self.get_sampler_class(sampler_type)
 
@@ -304,6 +366,9 @@
         return_history: bool = False,
         preconditioning: str | None = None,
         preconditioning_kwargs: dict | None = None,
+        checkpoint_path: str | None = None,
+        checkpoint_every: int = 1,
+        checkpoint_save_config: bool = True,
         **kwargs,
     ) -> Samples:
         """Draw samples from the posterior distribution.
@@ -342,6 +407,14 @@
             will default to 'none' and the other samplers to 'default'
         preconditioning_kwargs: dict
             Keyword arguments to pass to the preconditioning transform.
+        checkpoint_path : str | None
+            Path to save the checkpoint. If None, no checkpoint is saved unless
+            within an :py:meth:`auto_checkpoint` context or a custom callback
+            is provided.
+        checkpoint_every : int
+            Frequency (in number of sampler iterations) to save the checkpoint.
+        checkpoint_save_config : bool
+            Whether to save the Aspire configuration to the checkpoint.
         kwargs : dict
             Keyword arguments to pass to the sampler. These are passed
             automatically to the init method of the sampler or to the sample
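
A short sketch of the checkpoint parameters documented above, assuming an already-fitted ``aspire`` instance; the file name and sampler choice are illustrative.

    # Write checkpoints to run.h5 every second sampler iteration.
    samples = aspire.sample_posterior(
        n_samples=2000,
        sampler="smc",
        checkpoint_path="run.h5",
        checkpoint_every=2,
        checkpoint_save_config=True,
    )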
@@ -352,6 +425,22 @@
         samples : Samples
             Samples object contain samples and their corresponding weights.
         """
+        if (
+            sampler == "importance"
+            and hasattr(self, "_resume_sampler_type")
+            and self._resume_sampler_type
+        ):
+            sampler = self._resume_sampler_type
+
+        if "resume_from" not in kwargs and hasattr(
+            self, "_resume_from_default"
+        ):
+            kwargs["resume_from"] = self._resume_from_default
+        if hasattr(self, "_resume_overrides"):
+            kwargs.update(self._resume_overrides)
+        if hasattr(self, "_resume_n_samples") and n_samples == 1000:
+            n_samples = self._resume_n_samples
+
         SamplerClass = self.get_sampler_class(sampler)
         # Determine sampler initialization parameters
         # and remove them from kwargs
@@ -373,7 +462,73 @@
             preconditioning_kwargs=preconditioning_kwargs,
             **sampler_kwargs,
         )
+        self._last_sampler_type = sampler
+        # Auto-checkpoint convenience: set defaults for checkpointing to a single file
+        defaults = getattr(self, "_checkpoint_defaults", None)
+        if checkpoint_path is None and defaults:
+            checkpoint_path = defaults["path"]
+            checkpoint_every = defaults["every"]
+            checkpoint_save_config = defaults["save_config"]
+        saved_flow = defaults.get("saved_flow", False) if defaults else False
+        saved_config = (
+            defaults.get("saved_config", False) if defaults else False
+        )
+        if checkpoint_path is not None:
+            kwargs.setdefault("checkpoint_file_path", checkpoint_path)
+            kwargs.setdefault("checkpoint_every", checkpoint_every)
+            with AspireFile(checkpoint_path, "a") as h5_file:
+                if checkpoint_save_config:
+                    if "aspire_config" in h5_file:
+                        del h5_file["aspire_config"]
+                    self.save_config(
+                        h5_file,
+                        include_sampler_config=True,
+                        include_sample_calls=False,
+                    )
+                    saved_config = True
+                    if defaults is not None:
+                        defaults["saved_config"] = True
+                if (
+                    self.flow is not None
+                    and not saved_flow
+                    and "flow" not in h5_file
+                ):
+                    self.save_flow(h5_file)
+                    saved_flow = True
+                    if defaults is not None:
+                        defaults["saved_flow"] = True
+
         samples = self._sampler.sample(n_samples, **kwargs)
+        self._last_sample_posterior_kwargs = {
+            "n_samples": n_samples,
+            "sampler": sampler,
+            "xp": xp,
+            "return_history": return_history,
+            "preconditioning": preconditioning,
+            "preconditioning_kwargs": preconditioning_kwargs,
+            "sampler_init_kwargs": sampler_kwargs,
+            "sample_kwargs": copy.deepcopy(kwargs),
+        }
+        if checkpoint_path is not None:
+            with AspireFile(checkpoint_path, "a") as h5_file:
+                if checkpoint_save_config and not saved_config:
+                    if "aspire_config" in h5_file:
+                        del h5_file["aspire_config"]
+                    self.save_config(
+                        h5_file,
+                        include_sampler_config=True,
+                        include_sample_calls=False,
+                    )
+                    if defaults is not None:
+                        defaults["saved_config"] = True
+                if (
+                    self.flow is not None
+                    and not saved_flow
+                    and "flow" not in h5_file
+                ):
+                    self.save_flow(h5_file)
+                    if defaults is not None:
+                        defaults["saved_flow"] = True
         if xp is not None:
             samples = samples.to_namespace(xp)
         samples.parameters = self.parameters
@@ -388,6 +543,122 @@
         else:
             return samples
 
+    @classmethod
+    def resume_from_file(
+        cls,
+        file_path: str,
+        *,
+        log_likelihood: Callable,
+        log_prior: Callable,
+        sampler: str | None = None,
+        checkpoint_path: str = "checkpoint",
+        checkpoint_dset: str = "state",
+        flow_path: str = "flow",
+        config_path: str = "aspire_config",
+        resume_kwargs: dict | None = None,
+    ):
+        """
+        Recreate an Aspire object from a single file and prepare to resume sampling.
+
+        Parameters
+        ----------
+        file_path : str
+            Path to the HDF5 file containing config, flow, and checkpoint.
+        log_likelihood : Callable
+            Log-likelihood function (required, not pickled).
+        log_prior : Callable
+            Log-prior function (required, not pickled).
+        sampler : str
+            Sampler type to use (e.g., 'smc', 'minipcn_smc', 'emcee_smc'). If None,
+            will attempt to infer from saved config or checkpoint metadata.
+        checkpoint_path : str
+            HDF5 group path where the checkpoint is stored.
+        checkpoint_dset : str
+            Dataset name within the checkpoint group.
+        flow_path : str
+            HDF5 path to the saved flow.
+        config_path : str
+            HDF5 path to the saved Aspire config.
+        resume_kwargs : dict | None
+            Optional overrides to apply when resuming (e.g., checkpoint_every).
+        """
+        (
+            aspire,
+            checkpoint_bytes,
+            checkpoint_state,
+            sampler_config,
+            saved_sampler_type,
+            n_samples,
+        ) = cls._build_aspire_from_file(
+            file_path=file_path,
+            log_likelihood=log_likelihood,
+            log_prior=log_prior,
+            checkpoint_path=checkpoint_path,
+            checkpoint_dset=checkpoint_dset,
+            flow_path=flow_path,
+            config_path=config_path,
+        )
+
+        sampler_config = sampler_config or {}
+        sampler_config.pop("sampler_class", None)
+
+        if checkpoint_bytes is not None:
+            aspire._resume_from_default = checkpoint_bytes
+            aspire._resume_sampler_type = (
+                sampler
+                or saved_sampler_type
+                or (
+                    checkpoint_state.get("sampler")
+                    if checkpoint_state
+                    else None
+                )
+            )
+            aspire._resume_n_samples = n_samples
+            aspire._resume_overrides = resume_kwargs or {}
+            aspire._resume_sampler_config = sampler_config
+        aspire._checkpoint_defaults = {
+            "path": file_path,
+            "every": 1,
+            "save_config": False,
+            "save_flow": False,
+            "saved_config": False,
+            "saved_flow": False,
+        }
+        return aspire
+
+    @contextmanager
+    def auto_checkpoint(
+        self,
+        path: str,
+        every: int = 1,
+        save_config: bool = True,
+        save_flow: bool = True,
+    ):
+        """
+        Context manager to auto-save checkpoints, config, and flow to a file.
+
+        Within the context, sample_posterior will default to writing checkpoints
+        to the given path with the specified frequency, and will append config/flow
+        after sampling.
+        """
+        prev = getattr(self, "_checkpoint_defaults", None)
+        self._checkpoint_defaults = {
+            "path": path,
+            "every": every,
+            "save_config": save_config,
+            "save_flow": save_flow,
+            "saved_config": False,
+            "saved_flow": False,
+        }
+        try:
+            yield self
+        finally:
+            if prev is None:
+                if hasattr(self, "_checkpoint_defaults"):
+                    delattr(self, "_checkpoint_defaults")
+            else:
+                self._checkpoint_defaults = prev
+
     def enable_pool(self, pool: mp.Pool, **kwargs):
         """Context manager to temporarily replace the log_likelihood method
         with a version that uses a multiprocessing pool to parallelize
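
Putting the new ``auto_checkpoint`` and ``resume_from_file`` methods together, a hedged end-to-end sketch; ``log_likelihood``, ``log_prior``, ``training_samples``, and the initial ``aspire`` instance are assumed to exist already, and the file name is illustrative.

    # First run: fit and sample, checkpointing everything to a single file.
    with aspire.auto_checkpoint("run.h5", every=1):
        aspire.fit(training_samples)
        samples = aspire.sample_posterior(n_samples=1000, sampler="smc")

    # Later, e.g. after an interruption: rebuild the instance and resume.
    aspire = Aspire.resume_from_file(
        "run.h5",
        log_likelihood=log_likelihood,
        log_prior=log_prior,
    )
    # Picks up the saved sampler type, n_samples, and checkpoint state by default.
    samples = aspire.sample_posterior()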
@@ -417,8 +688,8 @@
             method of the sampler.
         """
         config = {
-            "log_likelihood": self.log_likelihood.__name__,
-            "log_prior": self.log_prior.__name__,
+            "log_likelihood": function_id(self.log_likelihood),
+            "log_prior": function_id(self.log_prior),
             "dims": self.dims,
             "parameters": self.parameters,
             "periodic_parameters": self.periodic_parameters,
@@ -432,12 +703,16 @@
             "flow_kwargs": self.flow_kwargs,
             "eps": self.eps,
         }
+        if hasattr(self, "_last_sampler_type"):
+            config["sampler_type"] = self._last_sampler_type
         if include_sampler_config:
+            if self.sampler is None:
+                raise ValueError("Sampler has not been initialized.")
             config["sampler_config"] = self.sampler.config_dict(**kwargs)
         return config
 
     def save_config(
-        self, h5_file: h5py.File, path="aspire_config", **kwargs
+        self, h5_file: h5py.File | AspireFile, path="aspire_config", **kwargs
     ) -> None:
         """Save the configuration to an HDF5 file.
 
@@ -484,6 +759,7 @@
         FlowClass, xp = get_flow_wrapper(
             backend=self.flow_backend, flow_matching=self.flow_matching
         )
+        logger.debug(f"Loading flow of type {FlowClass} from {path}")
         self._flow = FlowClass.load(h5_file, path=path)
 
     def save_config_to_json(self, filename: str) -> None:
@@ -504,3 +780,80 @@
         x, log_q = self.flow.sample_and_log_prob(n_samples)
         samples = Samples(x=x, log_q=log_q, xp=xp, parameters=self.parameters)
         return samples
+
+    # --- Resume helpers ---
+    @staticmethod
+    def _build_aspire_from_file(
+        file_path: str,
+        log_likelihood: Callable,
+        log_prior: Callable,
+        checkpoint_path: str,
+        checkpoint_dset: str,
+        flow_path: str,
+        config_path: str,
+    ):
+        """Construct an Aspire instance, load flow, and gather checkpoint metadata from file."""
+        with AspireFile(file_path, "r") as h5_file:
+            if config_path not in h5_file:
+                raise ValueError(
+                    f"Config path '{config_path}' not found in {file_path}"
+                )
+            config_dict = load_from_h5_file(h5_file, config_path)
+            try:
+                checkpoint_bytes = h5_file[checkpoint_path][checkpoint_dset][
+                    ...
+                ].tobytes()
+            except Exception:
+                logger.warning(
+                    "Checkpoint not found at %s/%s in %s; will resume without a checkpoint.",
+                    checkpoint_path,
+                    checkpoint_dset,
+                    file_path,
+                )
+                checkpoint_bytes = None
+
+        sampler_config = config_dict.pop("sampler_config", None)
+        saved_sampler_type = config_dict.pop("sampler_type", None)
+        if isinstance(config_dict.get("xp"), str):
+            config_dict["xp"] = resolve_xp(config_dict["xp"])
+        config_dict["log_likelihood"] = log_likelihood
+        config_dict["log_prior"] = log_prior
+
+        aspire = Aspire(**config_dict)
+
+        with AspireFile(file_path, "r") as h5_file:
+            if flow_path in h5_file:
+                logger.info(f"Loading flow from {flow_path} in {file_path}")
+                aspire.load_flow(h5_file, path=flow_path)
+            else:
+                raise ValueError(
+                    f"Flow path '{flow_path}' not found in {file_path}"
+                )
+
+        n_samples = None
+        checkpoint_state = None
+        if checkpoint_bytes is not None:
+            try:
+                checkpoint_state = pickle.loads(checkpoint_bytes)
+                samples_saved = (
+                    checkpoint_state.get("samples")
+                    if checkpoint_state
+                    else None
+                )
+                if samples_saved is not None:
+                    n_samples = len(samples_saved)
+                    if aspire.xp is None and hasattr(samples_saved, "xp"):
+                        aspire.xp = samples_saved.xp
+            except Exception:
+                logger.warning(
+                    "Failed to decode checkpoint; proceeding without resume state."
+                )
+
+        return (
+            aspire,
+            checkpoint_bytes,
+            checkpoint_state,
+            sampler_config,
+            saved_sampler_type,
+            n_samples,
+        )
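
To see what ``_build_aspire_from_file`` expects on disk, a small inspection sketch with plain ``h5py``; group names follow the default layout from the new docs (``/aspire_config``, ``/flow``, ``/checkpoint/state``) and the file name is illustrative.

    import pickle

    import h5py

    with h5py.File("run.h5", "r") as f:
        print(list(f.keys()))  # expected to include 'aspire_config', 'checkpoint', 'flow'
        if "checkpoint" in f and "state" in f["checkpoint"]:
            # Mirrors how the helper above reads the pickled checkpoint state.
            state = pickle.loads(f["checkpoint"]["state"][...].tobytes())
            print(type(state))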
{aspire_inference-0.1.0a10 → aspire_inference-0.1.0a12}/src/aspire/flows/torch/flows.py

@@ -92,7 +92,7 @@ class BaseTorchFlow(Flow):
         config = load_from_h5_file(flow_grp, "config")
         config["dtype"] = decode_dtype(torch, config.get("dtype"))
         if "data_transform" in flow_grp:
-            from ..transforms import BaseTransform
+            from ...transforms import BaseTransform
 
             data_transform = BaseTransform.load(
                 flow_grp,