aspire-inference 0.1.0a9__py3-none-any.whl → 0.1.0a11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aspire/aspire.py +356 -4
- aspire/flows/torch/flows.py +1 -1
- aspire/samplers/base.py +149 -5
- aspire/samplers/smc/base.py +133 -48
- aspire/samplers/smc/blackjax.py +8 -0
- aspire/samplers/smc/emcee.py +8 -0
- aspire/samplers/smc/minipcn.py +26 -6
- aspire/samples.py +21 -15
- aspire/utils.py +157 -4
- {aspire_inference-0.1.0a9.dist-info → aspire_inference-0.1.0a11.dist-info}/METADATA +23 -4
- aspire_inference-0.1.0a11.dist-info/RECORD +28 -0
- aspire_inference-0.1.0a9.dist-info/RECORD +0 -28
- {aspire_inference-0.1.0a9.dist-info → aspire_inference-0.1.0a11.dist-info}/WHEEL +0 -0
- {aspire_inference-0.1.0a9.dist-info → aspire_inference-0.1.0a11.dist-info}/licenses/LICENSE +0 -0
- {aspire_inference-0.1.0a9.dist-info → aspire_inference-0.1.0a11.dist-info}/top_level.txt +0 -0
aspire/aspire.py
CHANGED
|
@@ -1,5 +1,8 @@
|
|
|
1
|
+
import copy
|
|
1
2
|
import logging
|
|
2
3
|
import multiprocessing as mp
|
|
4
|
+
import pickle
|
|
5
|
+
from contextlib import contextmanager
|
|
3
6
|
from inspect import signature
|
|
4
7
|
from typing import Any, Callable
|
|
5
8
|
|
|
@@ -8,13 +11,19 @@ import h5py
|
|
|
8
11
|
from .flows import get_flow_wrapper
|
|
9
12
|
from .flows.base import Flow
|
|
10
13
|
from .history import History
|
|
14
|
+
from .samplers.base import Sampler
|
|
11
15
|
from .samples import Samples
|
|
12
16
|
from .transforms import (
|
|
13
17
|
CompositeTransform,
|
|
14
18
|
FlowPreconditioningTransform,
|
|
15
19
|
FlowTransform,
|
|
16
20
|
)
|
|
17
|
-
from .utils import
|
|
21
|
+
from .utils import (
|
|
22
|
+
AspireFile,
|
|
23
|
+
load_from_h5_file,
|
|
24
|
+
recursively_save_to_h5_file,
|
|
25
|
+
resolve_xp,
|
|
26
|
+
)
|
|
18
27
|
|
|
19
28
|
logger = logging.getLogger(__name__)
|
|
20
29
|
|
|
@@ -102,6 +111,7 @@ class Aspire:
|
|
|
102
111
|
self.dtype = dtype
|
|
103
112
|
|
|
104
113
|
self._flow = flow
|
|
114
|
+
self._sampler = None
|
|
105
115
|
|
|
106
116
|
@property
|
|
107
117
|
def flow(self):
|
|
@@ -114,7 +124,7 @@ class Aspire:
|
|
|
114
124
|
self._flow = flow
|
|
115
125
|
|
|
116
126
|
@property
|
|
117
|
-
def sampler(self):
|
|
127
|
+
def sampler(self) -> Sampler | None:
|
|
118
128
|
"""The sampler object."""
|
|
119
129
|
return self._sampler
|
|
120
130
|
|
|
@@ -192,7 +202,29 @@ class Aspire:
|
|
|
192
202
|
**self.flow_kwargs,
|
|
193
203
|
)
|
|
194
204
|
|
|
195
|
-
def fit(
|
|
205
|
+
def fit(
|
|
206
|
+
self,
|
|
207
|
+
samples: Samples,
|
|
208
|
+
checkpoint_path: str | None = None,
|
|
209
|
+
checkpoint_save_config: bool = True,
|
|
210
|
+
overwrite: bool = False,
|
|
211
|
+
**kwargs,
|
|
212
|
+
) -> History:
|
|
213
|
+
"""Fit the normalizing flow to the provided samples.
|
|
214
|
+
|
|
215
|
+
Parameters
|
|
216
|
+
----------
|
|
217
|
+
samples : Samples
|
|
218
|
+
The samples to fit the flow to.
|
|
219
|
+
checkpoint_path : str | None
|
|
220
|
+
Path to save the checkpoint. If None, no checkpoint is saved.
|
|
221
|
+
checkpoint_save_config : bool
|
|
222
|
+
Whether to save the Aspire configuration to the checkpoint.
|
|
223
|
+
overwrite : bool
|
|
224
|
+
Whether to overwrite an existing flow in the checkpoint file.
|
|
225
|
+
kwargs : dict
|
|
226
|
+
Keyword arguments to pass to the flow's fit method.
|
|
227
|
+
"""
|
|
196
228
|
if self.xp is None:
|
|
197
229
|
self.xp = samples.xp
|
|
198
230
|
|
|
@@ -202,6 +234,28 @@ class Aspire:
|
|
|
202
234
|
self.training_samples = samples
|
|
203
235
|
logger.info(f"Training with {len(samples.x)} samples")
|
|
204
236
|
history = self.flow.fit(samples.x, **kwargs)
|
|
237
|
+
defaults = getattr(self, "_checkpoint_defaults", None)
|
|
238
|
+
if checkpoint_path is None and defaults:
|
|
239
|
+
checkpoint_path = defaults["path"]
|
|
240
|
+
checkpoint_save_config = defaults["save_config"]
|
|
241
|
+
saved_config = (
|
|
242
|
+
defaults.get("saved_config", False) if defaults else False
|
|
243
|
+
)
|
|
244
|
+
if checkpoint_path is not None:
|
|
245
|
+
with AspireFile(checkpoint_path, "a") as h5_file:
|
|
246
|
+
if checkpoint_save_config and not saved_config:
|
|
247
|
+
if "aspire_config" in h5_file:
|
|
248
|
+
del h5_file["aspire_config"]
|
|
249
|
+
self.save_config(h5_file, include_sampler_config=False)
|
|
250
|
+
if defaults is not None:
|
|
251
|
+
defaults["saved_config"] = True
|
|
252
|
+
# Save flow only if missing or overwrite=True
|
|
253
|
+
if "flow" in h5_file:
|
|
254
|
+
if overwrite:
|
|
255
|
+
del h5_file["flow"]
|
|
256
|
+
self.save_flow(h5_file)
|
|
257
|
+
else:
|
|
258
|
+
self.save_flow(h5_file)
|
|
205
259
|
return history
|
|
206
260
|
|
|
207
261
|
def get_sampler_class(self, sampler_type: str) -> Callable:
|
|
@@ -241,6 +295,13 @@ class Aspire:
|
|
|
241
295
|
----------
|
|
242
296
|
sampler_type : str
|
|
243
297
|
The type of sampler to use. Options are 'importance', 'emcee', or 'smc'.
|
|
298
|
+
preconditioning: str
|
|
299
|
+
Type of preconditioning to apply in the sampler. Options are
|
|
300
|
+
'default', 'flow', or 'none'.
|
|
301
|
+
preconditioning_kwargs: dict
|
|
302
|
+
Keyword arguments to pass to the preconditioning transform.
|
|
303
|
+
kwargs : dict
|
|
304
|
+
Keyword arguments to pass to the sampler.
|
|
244
305
|
"""
|
|
245
306
|
SamplerClass = self.get_sampler_class(sampler_type)
|
|
246
307
|
|
|
@@ -304,6 +365,9 @@ class Aspire:
|
|
|
304
365
|
return_history: bool = False,
|
|
305
366
|
preconditioning: str | None = None,
|
|
306
367
|
preconditioning_kwargs: dict | None = None,
|
|
368
|
+
checkpoint_path: str | None = None,
|
|
369
|
+
checkpoint_every: int = 1,
|
|
370
|
+
checkpoint_save_config: bool = True,
|
|
307
371
|
**kwargs,
|
|
308
372
|
) -> Samples:
|
|
309
373
|
"""Draw samples from the posterior distribution.
|
|
@@ -342,6 +406,14 @@ class Aspire:
|
|
|
342
406
|
will default to 'none' and the other samplers to 'default'
|
|
343
407
|
preconditioning_kwargs: dict
|
|
344
408
|
Keyword arguments to pass to the preconditioning transform.
|
|
409
|
+
checkpoint_path : str | None
|
|
410
|
+
Path to save the checkpoint. If None, no checkpoint is saved unless
|
|
411
|
+
within an :py:meth:`auto_checkpoint` context or a custom callback
|
|
412
|
+
is provided.
|
|
413
|
+
checkpoint_every : int
|
|
414
|
+
Frequency (in number of sampler iterations) to save the checkpoint.
|
|
415
|
+
checkpoint_save_config : bool
|
|
416
|
+
Whether to save the Aspire configuration to the checkpoint.
|
|
345
417
|
kwargs : dict
|
|
346
418
|
Keyword arguments to pass to the sampler. These are passed
|
|
347
419
|
automatically to the init method of the sampler or to the sample
|
|
@@ -352,6 +424,22 @@ class Aspire:
|
|
|
352
424
|
samples : Samples
|
|
353
425
|
Samples object contain samples and their corresponding weights.
|
|
354
426
|
"""
|
|
427
|
+
if (
|
|
428
|
+
sampler == "importance"
|
|
429
|
+
and hasattr(self, "_resume_sampler_type")
|
|
430
|
+
and self._resume_sampler_type
|
|
431
|
+
):
|
|
432
|
+
sampler = self._resume_sampler_type
|
|
433
|
+
|
|
434
|
+
if "resume_from" not in kwargs and hasattr(
|
|
435
|
+
self, "_resume_from_default"
|
|
436
|
+
):
|
|
437
|
+
kwargs["resume_from"] = self._resume_from_default
|
|
438
|
+
if hasattr(self, "_resume_overrides"):
|
|
439
|
+
kwargs.update(self._resume_overrides)
|
|
440
|
+
if hasattr(self, "_resume_n_samples") and n_samples == 1000:
|
|
441
|
+
n_samples = self._resume_n_samples
|
|
442
|
+
|
|
355
443
|
SamplerClass = self.get_sampler_class(sampler)
|
|
356
444
|
# Determine sampler initialization parameters
|
|
357
445
|
# and remove them from kwargs
|
|
@@ -373,7 +461,73 @@ class Aspire:
|
|
|
373
461
|
preconditioning_kwargs=preconditioning_kwargs,
|
|
374
462
|
**sampler_kwargs,
|
|
375
463
|
)
|
|
464
|
+
self._last_sampler_type = sampler
|
|
465
|
+
# Auto-checkpoint convenience: set defaults for checkpointing to a single file
|
|
466
|
+
defaults = getattr(self, "_checkpoint_defaults", None)
|
|
467
|
+
if checkpoint_path is None and defaults:
|
|
468
|
+
checkpoint_path = defaults["path"]
|
|
469
|
+
checkpoint_every = defaults["every"]
|
|
470
|
+
checkpoint_save_config = defaults["save_config"]
|
|
471
|
+
saved_flow = defaults.get("saved_flow", False) if defaults else False
|
|
472
|
+
saved_config = (
|
|
473
|
+
defaults.get("saved_config", False) if defaults else False
|
|
474
|
+
)
|
|
475
|
+
if checkpoint_path is not None:
|
|
476
|
+
kwargs.setdefault("checkpoint_file_path", checkpoint_path)
|
|
477
|
+
kwargs.setdefault("checkpoint_every", checkpoint_every)
|
|
478
|
+
with AspireFile(checkpoint_path, "a") as h5_file:
|
|
479
|
+
if checkpoint_save_config:
|
|
480
|
+
if "aspire_config" in h5_file:
|
|
481
|
+
del h5_file["aspire_config"]
|
|
482
|
+
self.save_config(
|
|
483
|
+
h5_file,
|
|
484
|
+
include_sampler_config=True,
|
|
485
|
+
include_sample_calls=False,
|
|
486
|
+
)
|
|
487
|
+
saved_config = True
|
|
488
|
+
if defaults is not None:
|
|
489
|
+
defaults["saved_config"] = True
|
|
490
|
+
if (
|
|
491
|
+
self.flow is not None
|
|
492
|
+
and not saved_flow
|
|
493
|
+
and "flow" not in h5_file
|
|
494
|
+
):
|
|
495
|
+
self.save_flow(h5_file)
|
|
496
|
+
saved_flow = True
|
|
497
|
+
if defaults is not None:
|
|
498
|
+
defaults["saved_flow"] = True
|
|
499
|
+
|
|
376
500
|
samples = self._sampler.sample(n_samples, **kwargs)
|
|
501
|
+
self._last_sample_posterior_kwargs = {
|
|
502
|
+
"n_samples": n_samples,
|
|
503
|
+
"sampler": sampler,
|
|
504
|
+
"xp": xp,
|
|
505
|
+
"return_history": return_history,
|
|
506
|
+
"preconditioning": preconditioning,
|
|
507
|
+
"preconditioning_kwargs": preconditioning_kwargs,
|
|
508
|
+
"sampler_init_kwargs": sampler_kwargs,
|
|
509
|
+
"sample_kwargs": copy.deepcopy(kwargs),
|
|
510
|
+
}
|
|
511
|
+
if checkpoint_path is not None:
|
|
512
|
+
with AspireFile(checkpoint_path, "a") as h5_file:
|
|
513
|
+
if checkpoint_save_config and not saved_config:
|
|
514
|
+
if "aspire_config" in h5_file:
|
|
515
|
+
del h5_file["aspire_config"]
|
|
516
|
+
self.save_config(
|
|
517
|
+
h5_file,
|
|
518
|
+
include_sampler_config=True,
|
|
519
|
+
include_sample_calls=False,
|
|
520
|
+
)
|
|
521
|
+
if defaults is not None:
|
|
522
|
+
defaults["saved_config"] = True
|
|
523
|
+
if (
|
|
524
|
+
self.flow is not None
|
|
525
|
+
and not saved_flow
|
|
526
|
+
and "flow" not in h5_file
|
|
527
|
+
):
|
|
528
|
+
self.save_flow(h5_file)
|
|
529
|
+
if defaults is not None:
|
|
530
|
+
defaults["saved_flow"] = True
|
|
377
531
|
if xp is not None:
|
|
378
532
|
samples = samples.to_namespace(xp)
|
|
379
533
|
samples.parameters = self.parameters
|
|
@@ -388,6 +542,122 @@ class Aspire:
|
|
|
388
542
|
else:
|
|
389
543
|
return samples
|
|
390
544
|
|
|
545
|
+
@classmethod
|
|
546
|
+
def resume_from_file(
|
|
547
|
+
cls,
|
|
548
|
+
file_path: str,
|
|
549
|
+
*,
|
|
550
|
+
log_likelihood: Callable,
|
|
551
|
+
log_prior: Callable,
|
|
552
|
+
sampler: str | None = None,
|
|
553
|
+
checkpoint_path: str = "checkpoint",
|
|
554
|
+
checkpoint_dset: str = "state",
|
|
555
|
+
flow_path: str = "flow",
|
|
556
|
+
config_path: str = "aspire_config",
|
|
557
|
+
resume_kwargs: dict | None = None,
|
|
558
|
+
):
|
|
559
|
+
"""
|
|
560
|
+
Recreate an Aspire object from a single file and prepare to resume sampling.
|
|
561
|
+
|
|
562
|
+
Parameters
|
|
563
|
+
----------
|
|
564
|
+
file_path : str
|
|
565
|
+
Path to the HDF5 file containing config, flow, and checkpoint.
|
|
566
|
+
log_likelihood : Callable
|
|
567
|
+
Log-likelihood function (required, not pickled).
|
|
568
|
+
log_prior : Callable
|
|
569
|
+
Log-prior function (required, not pickled).
|
|
570
|
+
sampler : str
|
|
571
|
+
Sampler type to use (e.g., 'smc', 'minipcn_smc', 'emcee_smc'). If None,
|
|
572
|
+
will attempt to infer from saved config or checkpoint metadata.
|
|
573
|
+
checkpoint_path : str
|
|
574
|
+
HDF5 group path where the checkpoint is stored.
|
|
575
|
+
checkpoint_dset : str
|
|
576
|
+
Dataset name within the checkpoint group.
|
|
577
|
+
flow_path : str
|
|
578
|
+
HDF5 path to the saved flow.
|
|
579
|
+
config_path : str
|
|
580
|
+
HDF5 path to the saved Aspire config.
|
|
581
|
+
resume_kwargs : dict | None
|
|
582
|
+
Optional overrides to apply when resuming (e.g., checkpoint_every).
|
|
583
|
+
"""
|
|
584
|
+
(
|
|
585
|
+
aspire,
|
|
586
|
+
checkpoint_bytes,
|
|
587
|
+
checkpoint_state,
|
|
588
|
+
sampler_config,
|
|
589
|
+
saved_sampler_type,
|
|
590
|
+
n_samples,
|
|
591
|
+
) = cls._build_aspire_from_file(
|
|
592
|
+
file_path=file_path,
|
|
593
|
+
log_likelihood=log_likelihood,
|
|
594
|
+
log_prior=log_prior,
|
|
595
|
+
checkpoint_path=checkpoint_path,
|
|
596
|
+
checkpoint_dset=checkpoint_dset,
|
|
597
|
+
flow_path=flow_path,
|
|
598
|
+
config_path=config_path,
|
|
599
|
+
)
|
|
600
|
+
|
|
601
|
+
sampler_config = sampler_config or {}
|
|
602
|
+
sampler_config.pop("sampler_class", None)
|
|
603
|
+
|
|
604
|
+
if checkpoint_bytes is not None:
|
|
605
|
+
aspire._resume_from_default = checkpoint_bytes
|
|
606
|
+
aspire._resume_sampler_type = (
|
|
607
|
+
sampler
|
|
608
|
+
or saved_sampler_type
|
|
609
|
+
or (
|
|
610
|
+
checkpoint_state.get("sampler")
|
|
611
|
+
if checkpoint_state
|
|
612
|
+
else None
|
|
613
|
+
)
|
|
614
|
+
)
|
|
615
|
+
aspire._resume_n_samples = n_samples
|
|
616
|
+
aspire._resume_overrides = resume_kwargs or {}
|
|
617
|
+
aspire._resume_sampler_config = sampler_config
|
|
618
|
+
aspire._checkpoint_defaults = {
|
|
619
|
+
"path": file_path,
|
|
620
|
+
"every": 1,
|
|
621
|
+
"save_config": False,
|
|
622
|
+
"save_flow": False,
|
|
623
|
+
"saved_config": False,
|
|
624
|
+
"saved_flow": False,
|
|
625
|
+
}
|
|
626
|
+
return aspire
|
|
627
|
+
|
|
628
|
+
@contextmanager
def auto_checkpoint(
    self,
    path: str,
    every: int = 1,
    save_config: bool = True,
    save_flow: bool = True,
):
    """
    Temporarily install default checkpoint settings on this instance.

    While the context is active, ``self._checkpoint_defaults`` holds the
    given path, frequency and save flags (plus ``saved_config`` and
    ``saved_flow`` trackers, both starting as False). ``fit`` and
    ``sample_posterior`` fall back to these defaults when no explicit
    ``checkpoint_path`` is passed. On exit the previous defaults are put
    back, or the attribute is removed if there were none.

    Parameters
    ----------
    path : str
        Checkpoint file path to use by default inside the context.
    every : int
        Default checkpoint frequency.
    save_config : bool
        Default for whether the configuration is saved.
    save_flow : bool
        Stored in the defaults under ``"save_flow"``.

    Yields
    ------
    The instance itself.
    """
    previous_defaults = getattr(self, "_checkpoint_defaults", None)
    active_defaults = {
        "path": path,
        "every": every,
        "save_config": save_config,
        "save_flow": save_flow,
        "saved_config": False,
        "saved_flow": False,
    }
    self._checkpoint_defaults = active_defaults
    try:
        yield self
    finally:
        if previous_defaults is not None:
            # Nested or pre-configured use: hand back what was there before.
            self._checkpoint_defaults = previous_defaults
        elif hasattr(self, "_checkpoint_defaults"):
            delattr(self, "_checkpoint_defaults")
|
|
660
|
+
|
|
391
661
|
def enable_pool(self, pool: mp.Pool, **kwargs):
|
|
392
662
|
"""Context manager to temporarily replace the log_likelihood method
|
|
393
663
|
with a version that uses a multiprocessing pool to parallelize
|
|
@@ -432,12 +702,16 @@ class Aspire:
|
|
|
432
702
|
"flow_kwargs": self.flow_kwargs,
|
|
433
703
|
"eps": self.eps,
|
|
434
704
|
}
|
|
705
|
+
if hasattr(self, "_last_sampler_type"):
|
|
706
|
+
config["sampler_type"] = self._last_sampler_type
|
|
435
707
|
if include_sampler_config:
|
|
708
|
+
if self.sampler is None:
|
|
709
|
+
raise ValueError("Sampler has not been initialized.")
|
|
436
710
|
config["sampler_config"] = self.sampler.config_dict(**kwargs)
|
|
437
711
|
return config
|
|
438
712
|
|
|
439
713
|
def save_config(
|
|
440
|
-
self, h5_file: h5py.File, path="aspire_config", **kwargs
|
|
714
|
+
self, h5_file: h5py.File | AspireFile, path="aspire_config", **kwargs
|
|
441
715
|
) -> None:
|
|
442
716
|
"""Save the configuration to an HDF5 file.
|
|
443
717
|
|
|
@@ -484,6 +758,7 @@ class Aspire:
|
|
|
484
758
|
FlowClass, xp = get_flow_wrapper(
|
|
485
759
|
backend=self.flow_backend, flow_matching=self.flow_matching
|
|
486
760
|
)
|
|
761
|
+
logger.debug(f"Loading flow of type {FlowClass} from {path}")
|
|
487
762
|
self._flow = FlowClass.load(h5_file, path=path)
|
|
488
763
|
|
|
489
764
|
def save_config_to_json(self, filename: str) -> None:
|
|
@@ -504,3 +779,80 @@ class Aspire:
|
|
|
504
779
|
x, log_q = self.flow.sample_and_log_prob(n_samples)
|
|
505
780
|
samples = Samples(x=x, log_q=log_q, xp=xp, parameters=self.parameters)
|
|
506
781
|
return samples
|
|
782
|
+
|
|
783
|
+
# --- Resume helpers ---
|
|
784
|
+
@staticmethod
def _build_aspire_from_file(
    file_path: str,
    log_likelihood: Callable,
    log_prior: Callable,
    checkpoint_path: str,
    checkpoint_dset: str,
    flow_path: str,
    config_path: str,
):
    """Construct an Aspire instance, load flow, and gather checkpoint metadata from file.

    Parameters
    ----------
    file_path : str
        Path of the file to read.
    log_likelihood : Callable
        Log-likelihood function injected into the loaded configuration.
    log_prior : Callable
        Log-prior function injected into the loaded configuration.
    checkpoint_path : str
        Group path of the checkpoint inside the file.
    checkpoint_dset : str
        Dataset name of the checkpoint inside that group.
    flow_path : str
        Path of the saved flow inside the file.
    config_path : str
        Path of the saved configuration inside the file.

    Returns
    -------
    tuple
        ``(aspire, checkpoint_bytes, checkpoint_state, sampler_config,
        saved_sampler_type, n_samples)``. ``checkpoint_bytes`` and
        ``checkpoint_state`` are None when no checkpoint could be read or
        decoded; ``n_samples`` is None unless the checkpoint holds samples.

    Raises
    ------
    ValueError
        If ``config_path`` or ``flow_path`` is missing from the file.
    """
    with AspireFile(file_path, "r") as h5_file:
        if config_path not in h5_file:
            raise ValueError(
                f"Config path '{config_path}' not found in {file_path}"
            )
        config_dict = load_from_h5_file(h5_file, config_path)
        try:
            checkpoint_bytes = h5_file[checkpoint_path][checkpoint_dset][
                ...
            ].tobytes()
        except Exception:
            # Best effort by design: a file without a checkpoint is still a
            # valid starting point. The cause is kept at debug level so a
            # genuine read error is not indistinguishable from "missing".
            logger.warning(
                "Checkpoint not found at %s/%s in %s; will resume without a checkpoint.",
                checkpoint_path,
                checkpoint_dset,
                file_path,
            )
            logger.debug("Checkpoint read failure details", exc_info=True)
            checkpoint_bytes = None

    sampler_config = config_dict.pop("sampler_config", None)
    saved_sampler_type = config_dict.pop("sampler_type", None)
    # The array namespace may have been stored by name; turn it back into
    # the namespace object before it is passed to the constructor.
    if isinstance(config_dict.get("xp"), str):
        config_dict["xp"] = resolve_xp(config_dict["xp"])
    config_dict["log_likelihood"] = log_likelihood
    config_dict["log_prior"] = log_prior

    aspire = Aspire(**config_dict)

    with AspireFile(file_path, "r") as h5_file:
        if flow_path not in h5_file:
            raise ValueError(
                f"Flow path '{flow_path}' not found in {file_path}"
            )
        logger.info("Loading flow from %s in %s", flow_path, file_path)
        aspire.load_flow(h5_file, path=flow_path)

    n_samples = None
    checkpoint_state = None
    if checkpoint_bytes is not None:
        try:
            # NOTE(security): unpickling executes arbitrary code embedded in
            # the payload; only resume from checkpoint files you trust.
            checkpoint_state = pickle.loads(checkpoint_bytes)
        except Exception:
            logger.warning(
                "Failed to decode checkpoint; proceeding without resume state."
            )
            logger.debug("Checkpoint decode failure details", exc_info=True)
            # Drop the raw bytes as well: returning them would let the
            # caller schedule a resume from a payload that cannot be loaded.
            checkpoint_bytes = None
        else:
            try:
                samples_saved = (
                    checkpoint_state.get("samples")
                    if checkpoint_state
                    else None
                )
                if samples_saved is not None:
                    n_samples = len(samples_saved)
                    if aspire.xp is None and hasattr(samples_saved, "xp"):
                        aspire.xp = samples_saved.xp
            except Exception:
                # The checkpoint decoded, so it stays usable for resuming;
                # only the convenience metadata is unavailable.
                logger.warning(
                    "Failed to read sample metadata from checkpoint; "
                    "the sample count will not be restored."
                )
                logger.debug(
                    "Checkpoint metadata failure details", exc_info=True
                )

    return (
        aspire,
        checkpoint_bytes,
        checkpoint_state,
        sampler_config,
        saved_sampler_type,
        n_samples,
    )
|
aspire/flows/torch/flows.py
CHANGED
|
@@ -92,7 +92,7 @@ class BaseTorchFlow(Flow):
|
|
|
92
92
|
config = load_from_h5_file(flow_grp, "config")
|
|
93
93
|
config["dtype"] = decode_dtype(torch, config.get("dtype"))
|
|
94
94
|
if "data_transform" in flow_grp:
|
|
95
|
-
from
|
|
95
|
+
from ...transforms import BaseTransform
|
|
96
96
|
|
|
97
97
|
data_transform = BaseTransform.load(
|
|
98
98
|
flow_grp,
|
aspire/samplers/base.py
CHANGED
|
@@ -1,10 +1,12 @@
|
|
|
1
1
|
import logging
|
|
2
|
+
import pickle
|
|
3
|
+
from pathlib import Path
|
|
2
4
|
from typing import Any, Callable
|
|
3
5
|
|
|
4
6
|
from ..flows.base import Flow
|
|
5
7
|
from ..samples import Samples
|
|
6
8
|
from ..transforms import IdentityTransform
|
|
7
|
-
from ..utils import track_calls
|
|
9
|
+
from ..utils import AspireFile, asarray, dump_state, track_calls
|
|
8
10
|
|
|
9
11
|
logger = logging.getLogger(__name__)
|
|
10
12
|
|
|
@@ -49,6 +51,8 @@ class Sampler:
|
|
|
49
51
|
self.parameters = parameters
|
|
50
52
|
self.history = None
|
|
51
53
|
self.n_likelihood_evaluations = 0
|
|
54
|
+
self._last_checkpoint_state: dict | None = None
|
|
55
|
+
self._last_checkpoint_bytes: bytes | None = None
|
|
52
56
|
if preconditioning_transform is None:
|
|
53
57
|
self.preconditioning_transform = IdentityTransform(xp=self.xp)
|
|
54
58
|
else:
|
|
@@ -56,7 +60,11 @@ class Sampler:
|
|
|
56
60
|
|
|
57
61
|
def fit_preconditioning_transform(self, x):
|
|
58
62
|
"""Fit the data transform to the data."""
|
|
59
|
-
x =
|
|
63
|
+
x = asarray(
|
|
64
|
+
x,
|
|
65
|
+
xp=self.preconditioning_transform.xp,
|
|
66
|
+
dtype=self.preconditioning_transform.dtype,
|
|
67
|
+
)
|
|
60
68
|
return self.preconditioning_transform.fit(x)
|
|
61
69
|
|
|
62
70
|
@track_calls
|
|
@@ -71,7 +79,7 @@ class Sampler:
|
|
|
71
79
|
self.n_likelihood_evaluations += len(samples)
|
|
72
80
|
return self._log_likelihood(samples)
|
|
73
81
|
|
|
74
|
-
def config_dict(self, include_sample_calls: bool =
|
|
82
|
+
def config_dict(self, include_sample_calls: bool = False) -> dict:
|
|
75
83
|
"""
|
|
76
84
|
Returns a dictionary with the configuration of the sampler.
|
|
77
85
|
|
|
@@ -79,9 +87,9 @@ class Sampler:
|
|
|
79
87
|
----------
|
|
80
88
|
include_sample_calls : bool
|
|
81
89
|
Whether to include the sample calls in the configuration.
|
|
82
|
-
Default is
|
|
90
|
+
Default is False.
|
|
83
91
|
"""
|
|
84
|
-
config = {}
|
|
92
|
+
config = {"sampler_class": self.__class__.__name__}
|
|
85
93
|
if include_sample_calls:
|
|
86
94
|
if hasattr(self, "sample") and hasattr(self.sample, "calls"):
|
|
87
95
|
config["sample_calls"] = self.sample.calls.to_dict(
|
|
@@ -92,3 +100,139 @@ class Sampler:
|
|
|
92
100
|
"Sampler does not have a sample method with calls attribute."
|
|
93
101
|
)
|
|
94
102
|
return config
|
|
103
|
+
|
|
104
|
+
# --- Checkpointing helpers shared across samplers ---
|
|
105
|
+
def _checkpoint_extra_state(self) -> dict:
    """Return sampler-specific checkpoint entries; the base sampler has none.

    Subclasses override this to contribute additional keys, which
    ``build_checkpoint_state`` merges into the checkpoint payload.
    """
    extras: dict = {}
    return extras
|
|
108
|
+
|
|
109
|
+
def _restore_extra_state(self, state: dict) -> None:
    """Restore sampler-specific entries from ``state``; a no-op in the base class.

    Subclasses override this to pick up whatever they added through
    ``_checkpoint_extra_state``.
    """
    # Nothing to restore for the base sampler; ``state`` is deliberately unused.
    return None
|
|
112
|
+
|
|
113
|
+
def build_checkpoint_state(
    self,
    samples: Samples,
    iteration: int | None = None,
    meta: dict | None = None,
) -> dict:
    """Assemble the checkpoint dictionary describing the current sampler state.

    Parameters
    ----------
    samples : Samples
        Samples stored under the ``"samples"`` key, unchanged.
    iteration : int | None
        Iteration stored under the ``"iteration"`` key.
    meta : dict | None
        Free-form metadata; an empty dict is stored when this is None or empty.

    Returns
    -------
    dict
        The sampler class name, iteration, samples, sampler config (without
        sample calls), parameters and metadata, updated with whatever
        ``_checkpoint_extra_state`` returns.
    """
    sampler_name = self.__class__.__name__
    sampler_config = self.config_dict(include_sample_calls=False)
    payload = {
        "sampler": sampler_name,
        "iteration": iteration,
        "samples": samples,
        "config": sampler_config,
        "parameters": self.parameters,
        "meta": meta if meta else {},
    }
    # Subclass extras are applied last, so they may add to or replace the
    # core entries above.
    extras = self._checkpoint_extra_state()
    payload.update(extras)
    return payload
|
|
131
|
+
|
|
132
|
+
def serialize_checkpoint(
|
|
133
|
+
self, state: dict, protocol: int | None = None
|
|
134
|
+
) -> bytes:
|
|
135
|
+
"""Serialize a checkpoint state to bytes with pickle."""
|
|
136
|
+
protocol = (
|
|
137
|
+
pickle.HIGHEST_PROTOCOL if protocol is None else int(protocol)
|
|
138
|
+
)
|
|
139
|
+
return pickle.dumps(state, protocol=protocol)
|
|
140
|
+
|
|
141
|
+
def default_checkpoint_callback(self, state: dict) -> None:
    """Remember the latest checkpoint on the sampler, both raw and pickled.

    The state object is kept as-is in ``_last_checkpoint_state`` and its
    serialized form (via ``serialize_checkpoint``) in
    ``_last_checkpoint_bytes``.
    """
    self._last_checkpoint_state = state
    pickled_state = self.serialize_checkpoint(state)
    self._last_checkpoint_bytes = pickled_state
|
|
145
|
+
|
|
146
|
+
def default_file_checkpoint_callback(
|
|
147
|
+
self, file_path: str | Path | None
|
|
148
|
+
) -> Callable[[dict], None]:
|
|
149
|
+
"""Return a simple default callback that overwrites an HDF5 file."""
|
|
150
|
+
if file_path is None:
|
|
151
|
+
return self.default_checkpoint_callback
|
|
152
|
+
file_path = Path(file_path)
|
|
153
|
+
lower_path = file_path.name.lower()
|
|
154
|
+
if not lower_path.endswith((".h5", ".hdf5")):
|
|
155
|
+
raise ValueError(
|
|
156
|
+
"Checkpoint file must be an HDF5 file (.h5 or .hdf5)."
|
|
157
|
+
)
|
|
158
|
+
|
|
159
|
+
def _callback(state: dict) -> None:
|
|
160
|
+
with AspireFile(file_path, "a") as h5_file:
|
|
161
|
+
self.save_checkpoint_to_hdf(
|
|
162
|
+
state, h5_file, path="checkpoint", dsetname="state"
|
|
163
|
+
)
|
|
164
|
+
self.default_checkpoint_callback(state)
|
|
165
|
+
|
|
166
|
+
return _callback
|
|
167
|
+
|
|
168
|
+
def save_checkpoint_to_hdf(
    self,
    state: dict,
    h5_file,
    path: str = "sampler_checkpoints",
    dsetname: str | None = None,
    protocol: int | None = None,
) -> None:
    """Save a checkpoint state into an HDF5 file as a pickled blob.

    Parameters
    ----------
    state : dict
        Checkpoint state to store.
    h5_file
        Open HDF5 file handle, forwarded to ``dump_state``.
    path : str
        Group path inside the file, forwarded to ``dump_state``.
    dsetname : str | None
        Dataset name. When None it is derived from the state's iteration as
        ``iter_<iteration>``, or ``iter_unknown`` if no iteration is
        recorded.
    protocol : int | None
        Pickle protocol; ``pickle.HIGHEST_PROTOCOL`` is used when None.
    """
    if dsetname is None:
        iteration = state.get("iteration")
        # A state may carry "iteration": None (build_checkpoint_state always
        # sets the key), which should read as unknown rather than "None".
        iter_str = "unknown" if iteration is None else iteration
        dsetname = f"iter_{iter_str}"
    # Compare against None explicitly: protocol 0 is a valid pickle protocol
    # and must not be replaced by the default just because it is falsy.
    if protocol is None:
        protocol = pickle.HIGHEST_PROTOCOL
    dump_state(
        state,
        h5_file,
        path=path,
        dsetname=dsetname,
        protocol=protocol,
    )
|
|
187
|
+
|
|
188
|
+
def load_checkpoint_from_file(
|
|
189
|
+
self,
|
|
190
|
+
file_path: str | Path,
|
|
191
|
+
h5_path: str = "checkpoint",
|
|
192
|
+
dsetname: str = "state",
|
|
193
|
+
) -> dict:
|
|
194
|
+
"""Load a checkpoint dictionary from .pkl or .hdf5 file."""
|
|
195
|
+
file_path = Path(file_path)
|
|
196
|
+
lower_path = file_path.name.lower()
|
|
197
|
+
if lower_path.endswith((".h5", ".hdf5")):
|
|
198
|
+
with AspireFile(file_path, "r") as h5_file:
|
|
199
|
+
data = h5_file[h5_path][dsetname][...]
|
|
200
|
+
checkpoint_bytes = data.tobytes()
|
|
201
|
+
else:
|
|
202
|
+
with open(file_path, "rb") as f:
|
|
203
|
+
checkpoint_bytes = f.read()
|
|
204
|
+
return pickle.loads(checkpoint_bytes)
|
|
205
|
+
|
|
206
|
+
def restore_from_checkpoint(
    self, source: str | Path | bytes | dict
) -> tuple[Samples, dict]:
    """Restore sampler state from a checkpoint source.

    Parameters
    ----------
    source : str | Path | bytes | dict
        A file path (loaded with ``load_checkpoint_from_file``), a pickled
        checkpoint as a bytes-like object, or an already decoded checkpoint
        dictionary.

    Returns
    -------
    tuple[Samples, dict]
        The checkpointed samples rebuilt with this sampler's ``xp`` and
        ``dtype``, and the full checkpoint state.

    Raises
    ------
    TypeError
        If ``source`` is none of the supported types.
    ValueError
        If the checkpoint holds no samples.
    """
    # Path objects are accepted alongside str, matching what
    # load_checkpoint_from_file itself takes.
    if isinstance(source, (str, Path)):
        state = self.load_checkpoint_from_file(source)
    elif isinstance(source, (bytes, bytearray, memoryview)):
        # NOTE(security): unpickling executes arbitrary code embedded in the
        # payload; only restore from checkpoints you trust.
        state = pickle.loads(source)
    elif isinstance(source, dict):
        state = source
    else:
        raise TypeError("Unsupported checkpoint source type.")

    samples_saved = state.get("samples")
    if samples_saved is None:
        raise ValueError("Checkpoint missing samples.")

    samples = Samples.from_samples(
        samples_saved, xp=self.xp, dtype=self.dtype
    )
    # Allow subclasses to restore sampler-specific components
    self._restore_extra_state(state)
    return samples, state
|
|
229
|
+
|
|
230
|
+
@property
|
|
231
|
+
def last_checkpoint_state(self) -> dict | None:
|
|
232
|
+
"""Return the most recent checkpoint state stored by the default callback."""
|
|
233
|
+
return self._last_checkpoint_state
|
|
234
|
+
|
|
235
|
+
@property
|
|
236
|
+
def last_checkpoint_bytes(self) -> bytes | None:
|
|
237
|
+
"""Return the most recent pickled checkpoint produced by the default callback."""
|
|
238
|
+
return self._last_checkpoint_bytes
|