aspire-inference 0.1.0a10__py3-none-any.whl → 0.1.0a11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aspire/aspire.py +356 -4
- aspire/flows/torch/flows.py +1 -1
- aspire/samplers/base.py +144 -4
- aspire/samplers/smc/base.py +133 -48
- aspire/samplers/smc/blackjax.py +8 -0
- aspire/samplers/smc/emcee.py +8 -0
- aspire/samplers/smc/minipcn.py +10 -2
- aspire/samples.py +11 -9
- aspire/utils.py +98 -0
- {aspire_inference-0.1.0a10.dist-info → aspire_inference-0.1.0a11.dist-info}/METADATA +1 -1
- {aspire_inference-0.1.0a10.dist-info → aspire_inference-0.1.0a11.dist-info}/RECORD +14 -14
- {aspire_inference-0.1.0a10.dist-info → aspire_inference-0.1.0a11.dist-info}/WHEEL +0 -0
- {aspire_inference-0.1.0a10.dist-info → aspire_inference-0.1.0a11.dist-info}/licenses/LICENSE +0 -0
- {aspire_inference-0.1.0a10.dist-info → aspire_inference-0.1.0a11.dist-info}/top_level.txt +0 -0
aspire/aspire.py
CHANGED
|
@@ -1,5 +1,8 @@
|
|
|
1
|
+
import copy
|
|
1
2
|
import logging
|
|
2
3
|
import multiprocessing as mp
|
|
4
|
+
import pickle
|
|
5
|
+
from contextlib import contextmanager
|
|
3
6
|
from inspect import signature
|
|
4
7
|
from typing import Any, Callable
|
|
5
8
|
|
|
@@ -8,13 +11,19 @@ import h5py
|
|
|
8
11
|
from .flows import get_flow_wrapper
|
|
9
12
|
from .flows.base import Flow
|
|
10
13
|
from .history import History
|
|
14
|
+
from .samplers.base import Sampler
|
|
11
15
|
from .samples import Samples
|
|
12
16
|
from .transforms import (
|
|
13
17
|
CompositeTransform,
|
|
14
18
|
FlowPreconditioningTransform,
|
|
15
19
|
FlowTransform,
|
|
16
20
|
)
|
|
17
|
-
from .utils import
|
|
21
|
+
from .utils import (
|
|
22
|
+
AspireFile,
|
|
23
|
+
load_from_h5_file,
|
|
24
|
+
recursively_save_to_h5_file,
|
|
25
|
+
resolve_xp,
|
|
26
|
+
)
|
|
18
27
|
|
|
19
28
|
logger = logging.getLogger(__name__)
|
|
20
29
|
|
|
@@ -102,6 +111,7 @@ class Aspire:
|
|
|
102
111
|
self.dtype = dtype
|
|
103
112
|
|
|
104
113
|
self._flow = flow
|
|
114
|
+
self._sampler = None
|
|
105
115
|
|
|
106
116
|
@property
|
|
107
117
|
def flow(self):
|
|
@@ -114,7 +124,7 @@ class Aspire:
|
|
|
114
124
|
self._flow = flow
|
|
115
125
|
|
|
116
126
|
@property
|
|
117
|
-
def sampler(self):
|
|
127
|
+
def sampler(self) -> Sampler | None:
|
|
118
128
|
"""The sampler object."""
|
|
119
129
|
return self._sampler
|
|
120
130
|
|
|
@@ -192,7 +202,29 @@ class Aspire:
|
|
|
192
202
|
**self.flow_kwargs,
|
|
193
203
|
)
|
|
194
204
|
|
|
195
|
-
def fit(
|
|
205
|
+
def fit(
|
|
206
|
+
self,
|
|
207
|
+
samples: Samples,
|
|
208
|
+
checkpoint_path: str | None = None,
|
|
209
|
+
checkpoint_save_config: bool = True,
|
|
210
|
+
overwrite: bool = False,
|
|
211
|
+
**kwargs,
|
|
212
|
+
) -> History:
|
|
213
|
+
"""Fit the normalizing flow to the provided samples.
|
|
214
|
+
|
|
215
|
+
Parameters
|
|
216
|
+
----------
|
|
217
|
+
samples : Samples
|
|
218
|
+
The samples to fit the flow to.
|
|
219
|
+
checkpoint_path : str | None
|
|
220
|
+
Path to save the checkpoint. If None, no checkpoint is saved.
|
|
221
|
+
checkpoint_save_config : bool
|
|
222
|
+
Whether to save the Aspire configuration to the checkpoint.
|
|
223
|
+
overwrite : bool
|
|
224
|
+
Whether to overwrite an existing flow in the checkpoint file.
|
|
225
|
+
kwargs : dict
|
|
226
|
+
Keyword arguments to pass to the flow's fit method.
|
|
227
|
+
"""
|
|
196
228
|
if self.xp is None:
|
|
197
229
|
self.xp = samples.xp
|
|
198
230
|
|
|
@@ -202,6 +234,28 @@ class Aspire:
|
|
|
202
234
|
self.training_samples = samples
|
|
203
235
|
logger.info(f"Training with {len(samples.x)} samples")
|
|
204
236
|
history = self.flow.fit(samples.x, **kwargs)
|
|
237
|
+
defaults = getattr(self, "_checkpoint_defaults", None)
|
|
238
|
+
if checkpoint_path is None and defaults:
|
|
239
|
+
checkpoint_path = defaults["path"]
|
|
240
|
+
checkpoint_save_config = defaults["save_config"]
|
|
241
|
+
saved_config = (
|
|
242
|
+
defaults.get("saved_config", False) if defaults else False
|
|
243
|
+
)
|
|
244
|
+
if checkpoint_path is not None:
|
|
245
|
+
with AspireFile(checkpoint_path, "a") as h5_file:
|
|
246
|
+
if checkpoint_save_config and not saved_config:
|
|
247
|
+
if "aspire_config" in h5_file:
|
|
248
|
+
del h5_file["aspire_config"]
|
|
249
|
+
self.save_config(h5_file, include_sampler_config=False)
|
|
250
|
+
if defaults is not None:
|
|
251
|
+
defaults["saved_config"] = True
|
|
252
|
+
# Save flow only if missing or overwrite=True
|
|
253
|
+
if "flow" in h5_file:
|
|
254
|
+
if overwrite:
|
|
255
|
+
del h5_file["flow"]
|
|
256
|
+
self.save_flow(h5_file)
|
|
257
|
+
else:
|
|
258
|
+
self.save_flow(h5_file)
|
|
205
259
|
return history
|
|
206
260
|
|
|
207
261
|
def get_sampler_class(self, sampler_type: str) -> Callable:
|
|
@@ -241,6 +295,13 @@ class Aspire:
|
|
|
241
295
|
----------
|
|
242
296
|
sampler_type : str
|
|
243
297
|
The type of sampler to use. Options are 'importance', 'emcee', or 'smc'.
|
|
298
|
+
preconditioning: str
|
|
299
|
+
Type of preconditioning to apply in the sampler. Options are
|
|
300
|
+
'default', 'flow', or 'none'.
|
|
301
|
+
preconditioning_kwargs: dict
|
|
302
|
+
Keyword arguments to pass to the preconditioning transform.
|
|
303
|
+
kwargs : dict
|
|
304
|
+
Keyword arguments to pass to the sampler.
|
|
244
305
|
"""
|
|
245
306
|
SamplerClass = self.get_sampler_class(sampler_type)
|
|
246
307
|
|
|
@@ -304,6 +365,9 @@ class Aspire:
|
|
|
304
365
|
return_history: bool = False,
|
|
305
366
|
preconditioning: str | None = None,
|
|
306
367
|
preconditioning_kwargs: dict | None = None,
|
|
368
|
+
checkpoint_path: str | None = None,
|
|
369
|
+
checkpoint_every: int = 1,
|
|
370
|
+
checkpoint_save_config: bool = True,
|
|
307
371
|
**kwargs,
|
|
308
372
|
) -> Samples:
|
|
309
373
|
"""Draw samples from the posterior distribution.
|
|
@@ -342,6 +406,14 @@ class Aspire:
|
|
|
342
406
|
will default to 'none' and the other samplers to 'default'
|
|
343
407
|
preconditioning_kwargs: dict
|
|
344
408
|
Keyword arguments to pass to the preconditioning transform.
|
|
409
|
+
checkpoint_path : str | None
|
|
410
|
+
Path to save the checkpoint. If None, no checkpoint is saved unless
|
|
411
|
+
within an :py:meth:`auto_checkpoint` context or a custom callback
|
|
412
|
+
is provided.
|
|
413
|
+
checkpoint_every : int
|
|
414
|
+
Frequency (in number of sampler iterations) to save the checkpoint.
|
|
415
|
+
checkpoint_save_config : bool
|
|
416
|
+
Whether to save the Aspire configuration to the checkpoint.
|
|
345
417
|
kwargs : dict
|
|
346
418
|
Keyword arguments to pass to the sampler. These are passed
|
|
347
419
|
automatically to the init method of the sampler or to the sample
|
|
@@ -352,6 +424,22 @@ class Aspire:
|
|
|
352
424
|
samples : Samples
|
|
353
425
|
Samples object contain samples and their corresponding weights.
|
|
354
426
|
"""
|
|
427
|
+
if (
|
|
428
|
+
sampler == "importance"
|
|
429
|
+
and hasattr(self, "_resume_sampler_type")
|
|
430
|
+
and self._resume_sampler_type
|
|
431
|
+
):
|
|
432
|
+
sampler = self._resume_sampler_type
|
|
433
|
+
|
|
434
|
+
if "resume_from" not in kwargs and hasattr(
|
|
435
|
+
self, "_resume_from_default"
|
|
436
|
+
):
|
|
437
|
+
kwargs["resume_from"] = self._resume_from_default
|
|
438
|
+
if hasattr(self, "_resume_overrides"):
|
|
439
|
+
kwargs.update(self._resume_overrides)
|
|
440
|
+
if hasattr(self, "_resume_n_samples") and n_samples == 1000:
|
|
441
|
+
n_samples = self._resume_n_samples
|
|
442
|
+
|
|
355
443
|
SamplerClass = self.get_sampler_class(sampler)
|
|
356
444
|
# Determine sampler initialization parameters
|
|
357
445
|
# and remove them from kwargs
|
|
@@ -373,7 +461,73 @@ class Aspire:
|
|
|
373
461
|
preconditioning_kwargs=preconditioning_kwargs,
|
|
374
462
|
**sampler_kwargs,
|
|
375
463
|
)
|
|
464
|
+
self._last_sampler_type = sampler
|
|
465
|
+
# Auto-checkpoint convenience: set defaults for checkpointing to a single file
|
|
466
|
+
defaults = getattr(self, "_checkpoint_defaults", None)
|
|
467
|
+
if checkpoint_path is None and defaults:
|
|
468
|
+
checkpoint_path = defaults["path"]
|
|
469
|
+
checkpoint_every = defaults["every"]
|
|
470
|
+
checkpoint_save_config = defaults["save_config"]
|
|
471
|
+
saved_flow = defaults.get("saved_flow", False) if defaults else False
|
|
472
|
+
saved_config = (
|
|
473
|
+
defaults.get("saved_config", False) if defaults else False
|
|
474
|
+
)
|
|
475
|
+
if checkpoint_path is not None:
|
|
476
|
+
kwargs.setdefault("checkpoint_file_path", checkpoint_path)
|
|
477
|
+
kwargs.setdefault("checkpoint_every", checkpoint_every)
|
|
478
|
+
with AspireFile(checkpoint_path, "a") as h5_file:
|
|
479
|
+
if checkpoint_save_config:
|
|
480
|
+
if "aspire_config" in h5_file:
|
|
481
|
+
del h5_file["aspire_config"]
|
|
482
|
+
self.save_config(
|
|
483
|
+
h5_file,
|
|
484
|
+
include_sampler_config=True,
|
|
485
|
+
include_sample_calls=False,
|
|
486
|
+
)
|
|
487
|
+
saved_config = True
|
|
488
|
+
if defaults is not None:
|
|
489
|
+
defaults["saved_config"] = True
|
|
490
|
+
if (
|
|
491
|
+
self.flow is not None
|
|
492
|
+
and not saved_flow
|
|
493
|
+
and "flow" not in h5_file
|
|
494
|
+
):
|
|
495
|
+
self.save_flow(h5_file)
|
|
496
|
+
saved_flow = True
|
|
497
|
+
if defaults is not None:
|
|
498
|
+
defaults["saved_flow"] = True
|
|
499
|
+
|
|
376
500
|
samples = self._sampler.sample(n_samples, **kwargs)
|
|
501
|
+
self._last_sample_posterior_kwargs = {
|
|
502
|
+
"n_samples": n_samples,
|
|
503
|
+
"sampler": sampler,
|
|
504
|
+
"xp": xp,
|
|
505
|
+
"return_history": return_history,
|
|
506
|
+
"preconditioning": preconditioning,
|
|
507
|
+
"preconditioning_kwargs": preconditioning_kwargs,
|
|
508
|
+
"sampler_init_kwargs": sampler_kwargs,
|
|
509
|
+
"sample_kwargs": copy.deepcopy(kwargs),
|
|
510
|
+
}
|
|
511
|
+
if checkpoint_path is not None:
|
|
512
|
+
with AspireFile(checkpoint_path, "a") as h5_file:
|
|
513
|
+
if checkpoint_save_config and not saved_config:
|
|
514
|
+
if "aspire_config" in h5_file:
|
|
515
|
+
del h5_file["aspire_config"]
|
|
516
|
+
self.save_config(
|
|
517
|
+
h5_file,
|
|
518
|
+
include_sampler_config=True,
|
|
519
|
+
include_sample_calls=False,
|
|
520
|
+
)
|
|
521
|
+
if defaults is not None:
|
|
522
|
+
defaults["saved_config"] = True
|
|
523
|
+
if (
|
|
524
|
+
self.flow is not None
|
|
525
|
+
and not saved_flow
|
|
526
|
+
and "flow" not in h5_file
|
|
527
|
+
):
|
|
528
|
+
self.save_flow(h5_file)
|
|
529
|
+
if defaults is not None:
|
|
530
|
+
defaults["saved_flow"] = True
|
|
377
531
|
if xp is not None:
|
|
378
532
|
samples = samples.to_namespace(xp)
|
|
379
533
|
samples.parameters = self.parameters
|
|
@@ -388,6 +542,122 @@ class Aspire:
|
|
|
388
542
|
else:
|
|
389
543
|
return samples
|
|
390
544
|
|
|
545
|
+
@classmethod
|
|
546
|
+
def resume_from_file(
|
|
547
|
+
cls,
|
|
548
|
+
file_path: str,
|
|
549
|
+
*,
|
|
550
|
+
log_likelihood: Callable,
|
|
551
|
+
log_prior: Callable,
|
|
552
|
+
sampler: str | None = None,
|
|
553
|
+
checkpoint_path: str = "checkpoint",
|
|
554
|
+
checkpoint_dset: str = "state",
|
|
555
|
+
flow_path: str = "flow",
|
|
556
|
+
config_path: str = "aspire_config",
|
|
557
|
+
resume_kwargs: dict | None = None,
|
|
558
|
+
):
|
|
559
|
+
"""
|
|
560
|
+
Recreate an Aspire object from a single file and prepare to resume sampling.
|
|
561
|
+
|
|
562
|
+
Parameters
|
|
563
|
+
----------
|
|
564
|
+
file_path : str
|
|
565
|
+
Path to the HDF5 file containing config, flow, and checkpoint.
|
|
566
|
+
log_likelihood : Callable
|
|
567
|
+
Log-likelihood function (required, not pickled).
|
|
568
|
+
log_prior : Callable
|
|
569
|
+
Log-prior function (required, not pickled).
|
|
570
|
+
sampler : str
|
|
571
|
+
Sampler type to use (e.g., 'smc', 'minipcn_smc', 'emcee_smc'). If None,
|
|
572
|
+
will attempt to infer from saved config or checkpoint metadata.
|
|
573
|
+
checkpoint_path : str
|
|
574
|
+
HDF5 group path where the checkpoint is stored.
|
|
575
|
+
checkpoint_dset : str
|
|
576
|
+
Dataset name within the checkpoint group.
|
|
577
|
+
flow_path : str
|
|
578
|
+
HDF5 path to the saved flow.
|
|
579
|
+
config_path : str
|
|
580
|
+
HDF5 path to the saved Aspire config.
|
|
581
|
+
resume_kwargs : dict | None
|
|
582
|
+
Optional overrides to apply when resuming (e.g., checkpoint_every).
|
|
583
|
+
"""
|
|
584
|
+
(
|
|
585
|
+
aspire,
|
|
586
|
+
checkpoint_bytes,
|
|
587
|
+
checkpoint_state,
|
|
588
|
+
sampler_config,
|
|
589
|
+
saved_sampler_type,
|
|
590
|
+
n_samples,
|
|
591
|
+
) = cls._build_aspire_from_file(
|
|
592
|
+
file_path=file_path,
|
|
593
|
+
log_likelihood=log_likelihood,
|
|
594
|
+
log_prior=log_prior,
|
|
595
|
+
checkpoint_path=checkpoint_path,
|
|
596
|
+
checkpoint_dset=checkpoint_dset,
|
|
597
|
+
flow_path=flow_path,
|
|
598
|
+
config_path=config_path,
|
|
599
|
+
)
|
|
600
|
+
|
|
601
|
+
sampler_config = sampler_config or {}
|
|
602
|
+
sampler_config.pop("sampler_class", None)
|
|
603
|
+
|
|
604
|
+
if checkpoint_bytes is not None:
|
|
605
|
+
aspire._resume_from_default = checkpoint_bytes
|
|
606
|
+
aspire._resume_sampler_type = (
|
|
607
|
+
sampler
|
|
608
|
+
or saved_sampler_type
|
|
609
|
+
or (
|
|
610
|
+
checkpoint_state.get("sampler")
|
|
611
|
+
if checkpoint_state
|
|
612
|
+
else None
|
|
613
|
+
)
|
|
614
|
+
)
|
|
615
|
+
aspire._resume_n_samples = n_samples
|
|
616
|
+
aspire._resume_overrides = resume_kwargs or {}
|
|
617
|
+
aspire._resume_sampler_config = sampler_config
|
|
618
|
+
aspire._checkpoint_defaults = {
|
|
619
|
+
"path": file_path,
|
|
620
|
+
"every": 1,
|
|
621
|
+
"save_config": False,
|
|
622
|
+
"save_flow": False,
|
|
623
|
+
"saved_config": False,
|
|
624
|
+
"saved_flow": False,
|
|
625
|
+
}
|
|
626
|
+
return aspire
|
|
627
|
+
|
|
628
|
+
@contextmanager
def auto_checkpoint(
    self,
    path: str,
    every: int = 1,
    save_config: bool = True,
    save_flow: bool = True,
):
    """
    Temporarily install default checkpoint settings on this instance.

    While the context is active, ``self._checkpoint_defaults`` holds the
    given path and frequency plus ``saved_config``/``saved_flow`` flags
    (initially False). ``fit`` and ``sample_posterior`` fall back to these
    settings when they are called without an explicit ``checkpoint_path``.
    On exit, whatever defaults existed before are restored. If there were
    none, the attribute is removed.

    NOTE(review): no visible code reads the ``save_flow`` entry; confirm
    whether it is meant to gate flow saving.

    Yields
    ------
    Aspire
        This same instance.
    """
    outer_defaults = getattr(self, "_checkpoint_defaults", None)
    self._checkpoint_defaults = dict(
        path=path,
        every=every,
        save_config=save_config,
        save_flow=save_flow,
        saved_config=False,
        saved_flow=False,
    )
    try:
        yield self
    finally:
        if outer_defaults is not None:
            self._checkpoint_defaults = outer_defaults
        elif hasattr(self, "_checkpoint_defaults"):
            del self._checkpoint_defaults
|
|
660
|
+
|
|
391
661
|
def enable_pool(self, pool: mp.Pool, **kwargs):
|
|
392
662
|
"""Context manager to temporarily replace the log_likelihood method
|
|
393
663
|
with a version that uses a multiprocessing pool to parallelize
|
|
@@ -432,12 +702,16 @@ class Aspire:
|
|
|
432
702
|
"flow_kwargs": self.flow_kwargs,
|
|
433
703
|
"eps": self.eps,
|
|
434
704
|
}
|
|
705
|
+
if hasattr(self, "_last_sampler_type"):
|
|
706
|
+
config["sampler_type"] = self._last_sampler_type
|
|
435
707
|
if include_sampler_config:
|
|
708
|
+
if self.sampler is None:
|
|
709
|
+
raise ValueError("Sampler has not been initialized.")
|
|
436
710
|
config["sampler_config"] = self.sampler.config_dict(**kwargs)
|
|
437
711
|
return config
|
|
438
712
|
|
|
439
713
|
def save_config(
|
|
440
|
-
self, h5_file: h5py.File, path="aspire_config", **kwargs
|
|
714
|
+
self, h5_file: h5py.File | AspireFile, path="aspire_config", **kwargs
|
|
441
715
|
) -> None:
|
|
442
716
|
"""Save the configuration to an HDF5 file.
|
|
443
717
|
|
|
@@ -484,6 +758,7 @@ class Aspire:
|
|
|
484
758
|
FlowClass, xp = get_flow_wrapper(
|
|
485
759
|
backend=self.flow_backend, flow_matching=self.flow_matching
|
|
486
760
|
)
|
|
761
|
+
logger.debug(f"Loading flow of type {FlowClass} from {path}")
|
|
487
762
|
self._flow = FlowClass.load(h5_file, path=path)
|
|
488
763
|
|
|
489
764
|
def save_config_to_json(self, filename: str) -> None:
|
|
@@ -504,3 +779,80 @@ class Aspire:
|
|
|
504
779
|
x, log_q = self.flow.sample_and_log_prob(n_samples)
|
|
505
780
|
samples = Samples(x=x, log_q=log_q, xp=xp, parameters=self.parameters)
|
|
506
781
|
return samples
|
|
782
|
+
|
|
783
|
+
# --- Resume helpers ---
|
|
784
|
+
@staticmethod
def _build_aspire_from_file(
    file_path: str,
    log_likelihood: Callable,
    log_prior: Callable,
    checkpoint_path: str,
    checkpoint_dset: str,
    flow_path: str,
    config_path: str,
):
    """Construct an Aspire instance, load flow, and gather checkpoint metadata from file.

    Parameters
    ----------
    file_path : str
        HDF5 file holding the saved config, flow and (optionally) checkpoint.
    log_likelihood, log_prior : Callable
        Functions injected into the saved config; they are not stored in
        the file.
    checkpoint_path, checkpoint_dset : str
        Group and dataset of the pickled checkpoint blob.
    flow_path : str
        HDF5 path of the saved flow.
    config_path : str
        HDF5 path of the saved Aspire config.

    Returns
    -------
    tuple
        ``(aspire, checkpoint_bytes, checkpoint_state, sampler_config,
        saved_sampler_type, n_samples)``. ``checkpoint_bytes``,
        ``checkpoint_state`` and ``n_samples`` are None when no usable
        checkpoint was found.

    Raises
    ------
    ValueError
        If ``config_path`` or ``flow_path`` is not present in the file.
    """
    with AspireFile(file_path, "r") as h5_file:
        if config_path not in h5_file:
            raise ValueError(
                f"Config path '{config_path}' not found in {file_path}"
            )
        config_dict = load_from_h5_file(h5_file, config_path)
        try:
            checkpoint_bytes = h5_file[checkpoint_path][checkpoint_dset][
                ...
            ].tobytes()
        except Exception:
            # Best effort: without a checkpoint we can still rebuild the
            # Aspire object from config + flow and start sampling afresh.
            logger.warning(
                "Checkpoint not found at %s/%s in %s; will resume without a checkpoint.",
                checkpoint_path,
                checkpoint_dset,
                file_path,
            )
            checkpoint_bytes = None

        sampler_config = config_dict.pop("sampler_config", None)
        saved_sampler_type = config_dict.pop("sampler_type", None)
        if isinstance(config_dict.get("xp"), str):
            config_dict["xp"] = resolve_xp(config_dict["xp"])
        config_dict["log_likelihood"] = log_likelihood
        config_dict["log_prior"] = log_prior

        aspire = Aspire(**config_dict)

        if flow_path not in h5_file:
            raise ValueError(
                f"Flow path '{flow_path}' not found in {file_path}"
            )
        logger.info("Loading flow from %s in %s", flow_path, file_path)
        aspire.load_flow(h5_file, path=flow_path)

    n_samples = None
    checkpoint_state = None
    if checkpoint_bytes is not None:
        try:
            # SECURITY: pickle.loads can execute arbitrary code; only resume
            # from checkpoint files you trust.
            checkpoint_state = pickle.loads(checkpoint_bytes)
            if not isinstance(checkpoint_state, dict):
                raise TypeError(
                    "Checkpoint decoded to "
                    f"{type(checkpoint_state).__name__}, expected dict"
                )
            samples_saved = checkpoint_state.get("samples")
            if samples_saved is not None:
                n_samples = len(samples_saved)
                if aspire.xp is None and hasattr(samples_saved, "xp"):
                    aspire.xp = samples_saved.xp
        except Exception:
            logger.warning(
                "Failed to decode checkpoint; proceeding without resume state."
            )
            # Also drop the undecodable payload. Otherwise the caller would
            # still try to resume from it, contradicting the warning above.
            checkpoint_bytes = None
            checkpoint_state = None
            n_samples = None

    return (
        aspire,
        checkpoint_bytes,
        checkpoint_state,
        sampler_config,
        saved_sampler_type,
        n_samples,
    )
|
aspire/flows/torch/flows.py
CHANGED
|
@@ -92,7 +92,7 @@ class BaseTorchFlow(Flow):
|
|
|
92
92
|
config = load_from_h5_file(flow_grp, "config")
|
|
93
93
|
config["dtype"] = decode_dtype(torch, config.get("dtype"))
|
|
94
94
|
if "data_transform" in flow_grp:
|
|
95
|
-
from
|
|
95
|
+
from ...transforms import BaseTransform
|
|
96
96
|
|
|
97
97
|
data_transform = BaseTransform.load(
|
|
98
98
|
flow_grp,
|
aspire/samplers/base.py
CHANGED
|
@@ -1,10 +1,12 @@
|
|
|
1
1
|
import logging
|
|
2
|
+
import pickle
|
|
3
|
+
from pathlib import Path
|
|
2
4
|
from typing import Any, Callable
|
|
3
5
|
|
|
4
6
|
from ..flows.base import Flow
|
|
5
7
|
from ..samples import Samples
|
|
6
8
|
from ..transforms import IdentityTransform
|
|
7
|
-
from ..utils import asarray, track_calls
|
|
9
|
+
from ..utils import AspireFile, asarray, dump_state, track_calls
|
|
8
10
|
|
|
9
11
|
logger = logging.getLogger(__name__)
|
|
10
12
|
|
|
@@ -49,6 +51,8 @@ class Sampler:
|
|
|
49
51
|
self.parameters = parameters
|
|
50
52
|
self.history = None
|
|
51
53
|
self.n_likelihood_evaluations = 0
|
|
54
|
+
self._last_checkpoint_state: dict | None = None
|
|
55
|
+
self._last_checkpoint_bytes: bytes | None = None
|
|
52
56
|
if preconditioning_transform is None:
|
|
53
57
|
self.preconditioning_transform = IdentityTransform(xp=self.xp)
|
|
54
58
|
else:
|
|
@@ -75,7 +79,7 @@ class Sampler:
|
|
|
75
79
|
self.n_likelihood_evaluations += len(samples)
|
|
76
80
|
return self._log_likelihood(samples)
|
|
77
81
|
|
|
78
|
-
def config_dict(self, include_sample_calls: bool =
|
|
82
|
+
def config_dict(self, include_sample_calls: bool = False) -> dict:
|
|
79
83
|
"""
|
|
80
84
|
Returns a dictionary with the configuration of the sampler.
|
|
81
85
|
|
|
@@ -83,9 +87,9 @@ class Sampler:
|
|
|
83
87
|
----------
|
|
84
88
|
include_sample_calls : bool
|
|
85
89
|
Whether to include the sample calls in the configuration.
|
|
86
|
-
Default is
|
|
90
|
+
Default is False.
|
|
87
91
|
"""
|
|
88
|
-
config = {}
|
|
92
|
+
config = {"sampler_class": self.__class__.__name__}
|
|
89
93
|
if include_sample_calls:
|
|
90
94
|
if hasattr(self, "sample") and hasattr(self.sample, "calls"):
|
|
91
95
|
config["sample_calls"] = self.sample.calls.to_dict(
|
|
@@ -96,3 +100,139 @@ class Sampler:
|
|
|
96
100
|
"Sampler does not have a sample method with calls attribute."
|
|
97
101
|
)
|
|
98
102
|
return config
|
|
103
|
+
|
|
104
|
+
# --- Checkpointing helpers shared across samplers ---
|
|
105
|
+
def _checkpoint_extra_state(self) -> dict:
|
|
106
|
+
"""Sampler-specific extras for checkpointing (override in subclasses)."""
|
|
107
|
+
return {}
|
|
108
|
+
|
|
109
|
+
def _restore_extra_state(self, state: dict) -> None:
|
|
110
|
+
"""Restore sampler-specific extras (override in subclasses)."""
|
|
111
|
+
_ = state # no-op for base
|
|
112
|
+
|
|
113
|
+
def build_checkpoint_state(
    self,
    samples: Samples,
    iteration: int | None = None,
    meta: dict | None = None,
) -> dict:
    """Assemble the dictionary that is pickled as a sampler checkpoint.

    Parameters
    ----------
    samples : Samples
        Current samples, stored as-is under ``"samples"``.
    iteration : int | None
        Iteration counter to record; may be None.
    meta : dict | None
        Free-form metadata. A fresh empty dict is stored when it is falsy.

    Returns
    -------
    dict
        Keys ``sampler`` (class name), ``iteration``, ``samples``,
        ``config`` (``config_dict`` without sample calls), ``parameters``
        and ``meta``, then updated with ``_checkpoint_extra_state()``.
        The extras win on key clashes.
    """
    state = {
        "sampler": self.__class__.__name__,
        "iteration": iteration,
        "samples": samples,
        "config": self.config_dict(include_sample_calls=False),
        "parameters": self.parameters,
        "meta": meta if meta else {},
    }
    state |= self._checkpoint_extra_state()
    return state
|
|
131
|
+
|
|
132
|
+
def serialize_checkpoint(
|
|
133
|
+
self, state: dict, protocol: int | None = None
|
|
134
|
+
) -> bytes:
|
|
135
|
+
"""Serialize a checkpoint state to bytes with pickle."""
|
|
136
|
+
protocol = (
|
|
137
|
+
pickle.HIGHEST_PROTOCOL if protocol is None else int(protocol)
|
|
138
|
+
)
|
|
139
|
+
return pickle.dumps(state, protocol=protocol)
|
|
140
|
+
|
|
141
|
+
def default_checkpoint_callback(self, state: dict) -> None:
    """Keep the latest checkpoint in memory on the sampler.

    The state object is stored first and then its pickled form, produced
    by ``serialize_checkpoint``. Both are exposed through the
    ``last_checkpoint_state`` / ``last_checkpoint_bytes`` properties.
    """
    self._last_checkpoint_state = state
    pickled = self.serialize_checkpoint(state)
    self._last_checkpoint_bytes = pickled
|
|
145
|
+
|
|
146
|
+
def default_file_checkpoint_callback(
|
|
147
|
+
self, file_path: str | Path | None
|
|
148
|
+
) -> Callable[[dict], None]:
|
|
149
|
+
"""Return a simple default callback that overwrites an HDF5 file."""
|
|
150
|
+
if file_path is None:
|
|
151
|
+
return self.default_checkpoint_callback
|
|
152
|
+
file_path = Path(file_path)
|
|
153
|
+
lower_path = file_path.name.lower()
|
|
154
|
+
if not lower_path.endswith((".h5", ".hdf5")):
|
|
155
|
+
raise ValueError(
|
|
156
|
+
"Checkpoint file must be an HDF5 file (.h5 or .hdf5)."
|
|
157
|
+
)
|
|
158
|
+
|
|
159
|
+
def _callback(state: dict) -> None:
|
|
160
|
+
with AspireFile(file_path, "a") as h5_file:
|
|
161
|
+
self.save_checkpoint_to_hdf(
|
|
162
|
+
state, h5_file, path="checkpoint", dsetname="state"
|
|
163
|
+
)
|
|
164
|
+
self.default_checkpoint_callback(state)
|
|
165
|
+
|
|
166
|
+
return _callback
|
|
167
|
+
|
|
168
|
+
def save_checkpoint_to_hdf(
    self,
    state: dict,
    h5_file,
    path: str = "sampler_checkpoints",
    dsetname: str | None = None,
    protocol: int | None = None,
) -> None:
    """Save a checkpoint state into an HDF5 file as a pickled blob.

    Parameters
    ----------
    state : dict
        Checkpoint payload, typically from ``build_checkpoint_state``.
    h5_file
        Open HDF5 file handle passed straight to ``dump_state``.
    path : str
        Group in which the blob is written.
    dsetname : str | None
        Dataset name. If None, it is derived from ``state["iteration"]`` as
        ``iter_<n>``, or ``iter_unknown`` when no iteration is recorded.
    protocol : int | None
        Pickle protocol. None selects ``pickle.HIGHEST_PROTOCOL``; any
        explicit value, including 0, is honoured.
    """
    if dsetname is None:
        # build_checkpoint_state always stores an "iteration" key, possibly
        # None, so treat None the same as a missing key.
        iteration = state.get("iteration")
        iter_str = "unknown" if iteration is None else iteration
        dsetname = f"iter_{iter_str}"
    if protocol is None:
        protocol = pickle.HIGHEST_PROTOCOL
    dump_state(
        state,
        h5_file,
        path=path,
        dsetname=dsetname,
        protocol=protocol,
    )
|
|
187
|
+
|
|
188
|
+
def load_checkpoint_from_file(
|
|
189
|
+
self,
|
|
190
|
+
file_path: str | Path,
|
|
191
|
+
h5_path: str = "checkpoint",
|
|
192
|
+
dsetname: str = "state",
|
|
193
|
+
) -> dict:
|
|
194
|
+
"""Load a checkpoint dictionary from .pkl or .hdf5 file."""
|
|
195
|
+
file_path = Path(file_path)
|
|
196
|
+
lower_path = file_path.name.lower()
|
|
197
|
+
if lower_path.endswith((".h5", ".hdf5")):
|
|
198
|
+
with AspireFile(file_path, "r") as h5_file:
|
|
199
|
+
data = h5_file[h5_path][dsetname][...]
|
|
200
|
+
checkpoint_bytes = data.tobytes()
|
|
201
|
+
else:
|
|
202
|
+
with open(file_path, "rb") as f:
|
|
203
|
+
checkpoint_bytes = f.read()
|
|
204
|
+
return pickle.loads(checkpoint_bytes)
|
|
205
|
+
|
|
206
|
+
def restore_from_checkpoint(
    self, source: str | Path | bytes | bytearray | memoryview | dict
) -> tuple[Samples, dict]:
    """Restore sampler state from a checkpoint source.

    Parameters
    ----------
    source : str | Path | bytes-like | dict
        A path to a checkpoint file (read with
        ``load_checkpoint_from_file``), the pickled checkpoint bytes, or an
        already-decoded checkpoint dictionary.

    Returns
    -------
    tuple[Samples, dict]
        The saved samples converted with ``Samples.from_samples`` to this
        sampler's ``xp``/``dtype``, and the full checkpoint dictionary.

    Raises
    ------
    TypeError
        If ``source`` has an unsupported type, or does not decode to a dict.
    ValueError
        If the checkpoint has no ``"samples"`` entry.
    """
    if isinstance(source, (str, Path)):
        state = self.load_checkpoint_from_file(source)
    elif isinstance(source, (bytes, bytearray, memoryview)):
        # SECURITY: unpickling can execute arbitrary code; trusted data only.
        state = pickle.loads(source)
    elif isinstance(source, dict):
        state = source
    else:
        raise TypeError("Unsupported checkpoint source type.")

    if not isinstance(state, dict):
        raise TypeError(
            f"Checkpoint must decode to a dict, got {type(state).__name__}."
        )

    samples_saved = state.get("samples")
    if samples_saved is None:
        raise ValueError("Checkpoint missing samples.")

    samples = Samples.from_samples(
        samples_saved, xp=self.xp, dtype=self.dtype
    )
    # Allow subclasses to restore sampler-specific components
    self._restore_extra_state(state)
    return samples, state
|
|
229
|
+
|
|
230
|
+
@property
|
|
231
|
+
def last_checkpoint_state(self) -> dict | None:
|
|
232
|
+
"""Return the most recent checkpoint state stored by the default callback."""
|
|
233
|
+
return self._last_checkpoint_state
|
|
234
|
+
|
|
235
|
+
@property
|
|
236
|
+
def last_checkpoint_bytes(self) -> bytes | None:
|
|
237
|
+
"""Return the most recent pickled checkpoint produced by the default callback."""
|
|
238
|
+
return self._last_checkpoint_bytes
|
aspire/samplers/smc/base.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import copy
|
|
1
2
|
import logging
|
|
2
3
|
from typing import Any, Callable
|
|
3
4
|
|
|
@@ -158,11 +159,24 @@ class SMCSampler(MCMCSampler):
|
|
|
158
159
|
target_efficiency: float = 0.5,
|
|
159
160
|
target_efficiency_rate: float = 1.0,
|
|
160
161
|
n_final_samples: int | None = None,
|
|
162
|
+
checkpoint_callback: Callable[[dict], None] | None = None,
|
|
163
|
+
checkpoint_every: int | None = None,
|
|
164
|
+
checkpoint_file_path: str | None = None,
|
|
165
|
+
resume_from: str | bytes | dict | None = None,
|
|
161
166
|
) -> SMCSamples:
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
samples,
|
|
165
|
-
|
|
167
|
+
resumed = resume_from is not None
|
|
168
|
+
if resumed:
|
|
169
|
+
samples, beta, iterations = self.restore_from_checkpoint(
|
|
170
|
+
resume_from
|
|
171
|
+
)
|
|
172
|
+
else:
|
|
173
|
+
samples = self.draw_initial_samples(n_samples)
|
|
174
|
+
samples = SMCSamples.from_samples(
|
|
175
|
+
samples, xp=self.xp, beta=0.0, dtype=self.dtype
|
|
176
|
+
)
|
|
177
|
+
beta = 0.0
|
|
178
|
+
iterations = 0
|
|
179
|
+
self.history = SMCHistory()
|
|
166
180
|
self.fit_preconditioning_transform(samples.x)
|
|
167
181
|
|
|
168
182
|
if self.xp.isnan(samples.log_q).any():
|
|
@@ -178,8 +192,6 @@ class SMCSampler(MCMCSampler):
|
|
|
178
192
|
self.sampler_kwargs = self.sampler_kwargs or {}
|
|
179
193
|
n_final_steps = self.sampler_kwargs.pop("n_final_steps", None)
|
|
180
194
|
|
|
181
|
-
self.history = SMCHistory()
|
|
182
|
-
|
|
183
195
|
self.target_efficiency = target_efficiency
|
|
184
196
|
self.target_efficiency_rate = target_efficiency_rate
|
|
185
197
|
|
|
@@ -190,7 +202,6 @@ class SMCSampler(MCMCSampler):
|
|
|
190
202
|
else:
|
|
191
203
|
beta_step = np.nan
|
|
192
204
|
self.adaptive = adaptive
|
|
193
|
-
beta = 0.0
|
|
194
205
|
|
|
195
206
|
if min_step is None:
|
|
196
207
|
if max_n_steps is None:
|
|
@@ -202,55 +213,85 @@ class SMCSampler(MCMCSampler):
|
|
|
202
213
|
else:
|
|
203
214
|
self.adaptive_min_step = False
|
|
204
215
|
|
|
205
|
-
iterations = 0
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
beta, min_step = self.determine_beta(
|
|
210
|
-
samples,
|
|
211
|
-
beta,
|
|
212
|
-
beta_step,
|
|
213
|
-
min_step,
|
|
216
|
+
iterations = iterations or 0
|
|
217
|
+
if checkpoint_callback is None and checkpoint_every is not None:
|
|
218
|
+
checkpoint_callback = self.default_file_checkpoint_callback(
|
|
219
|
+
checkpoint_file_path
|
|
214
220
|
)
|
|
215
|
-
|
|
216
|
-
|
|
221
|
+
if checkpoint_callback is not None and checkpoint_every is None:
|
|
222
|
+
checkpoint_every = 1
|
|
223
|
+
|
|
224
|
+
run_smc_loop = True
|
|
225
|
+
if resumed:
|
|
226
|
+
last_beta = self.history.beta[-1] if self.history.beta else beta
|
|
227
|
+
if last_beta >= 1.0:
|
|
228
|
+
run_smc_loop = False
|
|
229
|
+
|
|
230
|
+
def maybe_checkpoint(force: bool = False):
|
|
231
|
+
if checkpoint_callback is None:
|
|
232
|
+
return
|
|
233
|
+
should_checkpoint = force or (
|
|
234
|
+
checkpoint_every is not None
|
|
235
|
+
and checkpoint_every > 0
|
|
236
|
+
and iterations % checkpoint_every == 0
|
|
217
237
|
)
|
|
238
|
+
if not should_checkpoint:
|
|
239
|
+
return
|
|
240
|
+
state = self.build_checkpoint_state(samples, iterations, beta)
|
|
241
|
+
checkpoint_callback(state)
|
|
242
|
+
|
|
243
|
+
if run_smc_loop:
|
|
244
|
+
while True:
|
|
245
|
+
iterations += 1
|
|
246
|
+
|
|
247
|
+
beta, min_step = self.determine_beta(
|
|
248
|
+
samples,
|
|
249
|
+
beta,
|
|
250
|
+
beta_step,
|
|
251
|
+
min_step,
|
|
252
|
+
)
|
|
253
|
+
self.history.eff_target.append(
|
|
254
|
+
self.current_target_efficiency(beta)
|
|
255
|
+
)
|
|
218
256
|
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
257
|
+
logger.info(f"it {iterations} - beta: {beta}")
|
|
258
|
+
self.history.beta.append(beta)
|
|
259
|
+
|
|
260
|
+
ess = effective_sample_size(samples.log_weights(beta))
|
|
261
|
+
eff = ess / len(samples)
|
|
262
|
+
if eff < 0.1:
|
|
263
|
+
logger.warning(
|
|
264
|
+
f"it {iterations} - Low sample efficiency: {eff:.2f}"
|
|
265
|
+
)
|
|
266
|
+
self.history.ess.append(ess)
|
|
267
|
+
logger.info(
|
|
268
|
+
f"it {iterations} - ESS: {ess:.1f} ({eff:.2f} efficiency)"
|
|
269
|
+
)
|
|
270
|
+
self.history.ess_target.append(
|
|
271
|
+
effective_sample_size(samples.log_weights(1.0))
|
|
227
272
|
)
|
|
228
|
-
self.history.ess.append(ess)
|
|
229
|
-
logger.info(
|
|
230
|
-
f"it {iterations} - ESS: {ess:.1f} ({eff:.2f} efficiency)"
|
|
231
|
-
)
|
|
232
|
-
self.history.ess_target.append(
|
|
233
|
-
effective_sample_size(samples.log_weights(1.0))
|
|
234
|
-
)
|
|
235
273
|
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
274
|
+
log_evidence_ratio = samples.log_evidence_ratio(beta)
|
|
275
|
+
log_evidence_ratio_var = samples.log_evidence_ratio_variance(
|
|
276
|
+
beta
|
|
277
|
+
)
|
|
278
|
+
self.history.log_norm_ratio.append(log_evidence_ratio)
|
|
279
|
+
self.history.log_norm_ratio_var.append(log_evidence_ratio_var)
|
|
280
|
+
logger.info(
|
|
281
|
+
f"it {iterations} - Log evidence ratio: {log_evidence_ratio:.2f} +/- {np.sqrt(log_evidence_ratio_var):.2f}"
|
|
282
|
+
)
|
|
243
283
|
|
|
244
|
-
|
|
284
|
+
samples = samples.resample(beta, rng=self.rng)
|
|
245
285
|
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
286
|
+
samples = self.mutate(samples, beta)
|
|
287
|
+
maybe_checkpoint()
|
|
288
|
+
if beta == 1.0 or (
|
|
289
|
+
max_n_steps is not None and iterations >= max_n_steps
|
|
290
|
+
):
|
|
291
|
+
break
|
|
251
292
|
|
|
252
|
-
# If n_final_samples is
|
|
253
|
-
if n_final_samples is not None:
|
|
293
|
+
# If n_final_samples is specified and differs, perform additional mutation steps
|
|
294
|
+
if n_final_samples is not None and len(samples.x) != n_final_samples:
|
|
254
295
|
logger.info(f"Generating {n_final_samples} final samples")
|
|
255
296
|
final_samples = samples.resample(
|
|
256
297
|
1.0, n_samples=n_final_samples, rng=self.rng
|
|
@@ -263,6 +304,7 @@ class SMCSampler(MCMCSampler):
|
|
|
263
304
|
samples.log_evidence_error = samples.xp.sqrt(
|
|
264
305
|
samples.xp.sum(asarray(self.history.log_norm_ratio_var, self.xp))
|
|
265
306
|
)
|
|
307
|
+
maybe_checkpoint(force=True)
|
|
266
308
|
|
|
267
309
|
final_samples = samples.to_standard_samples()
|
|
268
310
|
logger.info(
|
|
@@ -289,6 +331,49 @@ class SMCSampler(MCMCSampler):
|
|
|
289
331
|
)
|
|
290
332
|
return log_prob
|
|
291
333
|
|
|
334
|
+
def build_checkpoint_state(
|
|
335
|
+
self, samples: SMCSamples, iteration: int, beta: float
|
|
336
|
+
) -> dict:
|
|
337
|
+
"""Prepare a serializable checkpoint payload for the sampler state."""
|
|
338
|
+
return super().build_checkpoint_state(
|
|
339
|
+
samples,
|
|
340
|
+
iteration,
|
|
341
|
+
meta={"beta": beta},
|
|
342
|
+
)
|
|
343
|
+
|
|
344
|
+
def _checkpoint_extra_state(self) -> dict:
|
|
345
|
+
history_copy = copy.deepcopy(self.history)
|
|
346
|
+
rng_state = (
|
|
347
|
+
self.rng.bit_generator.state
|
|
348
|
+
if hasattr(self.rng, "bit_generator")
|
|
349
|
+
else None
|
|
350
|
+
)
|
|
351
|
+
return {
|
|
352
|
+
"history": history_copy,
|
|
353
|
+
"rng_state": rng_state,
|
|
354
|
+
"sampler_kwargs": getattr(self, "sampler_kwargs", None),
|
|
355
|
+
}
|
|
356
|
+
|
|
357
|
+
def restore_from_checkpoint(
|
|
358
|
+
self, source: str | bytes | dict
|
|
359
|
+
) -> tuple[SMCSamples, float, int]:
|
|
360
|
+
samples, state = super().restore_from_checkpoint(source)
|
|
361
|
+
meta = state.get("meta", {}) if isinstance(state, dict) else {}
|
|
362
|
+
beta = None
|
|
363
|
+
if isinstance(meta, dict):
|
|
364
|
+
beta = meta.get("beta", None)
|
|
365
|
+
if beta is None:
|
|
366
|
+
beta = state.get("beta", 0.0)
|
|
367
|
+
iteration = state.get("iteration", 0)
|
|
368
|
+
self.history = state.get("history", SMCHistory())
|
|
369
|
+
rng_state = state.get("rng_state")
|
|
370
|
+
if rng_state is not None and hasattr(self.rng, "bit_generator"):
|
|
371
|
+
self.rng.bit_generator.state = rng_state
|
|
372
|
+
samples = SMCSamples.from_samples(
|
|
373
|
+
samples, xp=self.xp, beta=beta, dtype=self.dtype
|
|
374
|
+
)
|
|
375
|
+
return samples, beta, iteration
|
|
376
|
+
|
|
292
377
|
|
|
293
378
|
class NumpySMCSampler(SMCSampler):
|
|
294
379
|
def __init__(
|
aspire/samplers/smc/blackjax.py
CHANGED
|
@@ -84,6 +84,10 @@ class BlackJAXSMC(SMCSampler):
|
|
|
84
84
|
n_final_samples: int | None = None,
|
|
85
85
|
sampler_kwargs: dict | None = None,
|
|
86
86
|
rng_key=None,
|
|
87
|
+
checkpoint_callback=None,
|
|
88
|
+
checkpoint_every: int | None = None,
|
|
89
|
+
checkpoint_file_path: str | None = None,
|
|
90
|
+
resume_from: str | bytes | dict | None = None,
|
|
87
91
|
):
|
|
88
92
|
"""Sample using BlackJAX SMC.
|
|
89
93
|
|
|
@@ -132,6 +136,10 @@ class BlackJAXSMC(SMCSampler):
|
|
|
132
136
|
target_efficiency=target_efficiency,
|
|
133
137
|
target_efficiency_rate=target_efficiency_rate,
|
|
134
138
|
n_final_samples=n_final_samples,
|
|
139
|
+
checkpoint_callback=checkpoint_callback,
|
|
140
|
+
checkpoint_every=checkpoint_every,
|
|
141
|
+
checkpoint_file_path=checkpoint_file_path,
|
|
142
|
+
resume_from=resume_from,
|
|
135
143
|
)
|
|
136
144
|
|
|
137
145
|
def mutate(self, particles, beta, n_steps=None):
|
aspire/samplers/smc/emcee.py
CHANGED
|
@@ -21,6 +21,10 @@ class EmceeSMC(NumpySMCSampler):
|
|
|
21
21
|
target_efficiency_rate: float = 1.0,
|
|
22
22
|
sampler_kwargs: dict | None = None,
|
|
23
23
|
n_final_samples: int | None = None,
|
|
24
|
+
checkpoint_callback=None,
|
|
25
|
+
checkpoint_every: int | None = None,
|
|
26
|
+
checkpoint_file_path: str | None = None,
|
|
27
|
+
resume_from: str | bytes | dict | None = None,
|
|
24
28
|
):
|
|
25
29
|
self.sampler_kwargs = sampler_kwargs or {}
|
|
26
30
|
self.sampler_kwargs.setdefault("nsteps", 5 * self.dims)
|
|
@@ -33,6 +37,10 @@ class EmceeSMC(NumpySMCSampler):
|
|
|
33
37
|
target_efficiency=target_efficiency,
|
|
34
38
|
target_efficiency_rate=target_efficiency_rate,
|
|
35
39
|
n_final_samples=n_final_samples,
|
|
40
|
+
checkpoint_callback=checkpoint_callback,
|
|
41
|
+
checkpoint_every=checkpoint_every,
|
|
42
|
+
checkpoint_file_path=checkpoint_file_path,
|
|
43
|
+
resume_from=resume_from,
|
|
36
44
|
)
|
|
37
45
|
|
|
38
46
|
def mutate(self, particles, beta, n_steps=None):
|
aspire/samplers/smc/minipcn.py
CHANGED
|
@@ -8,10 +8,10 @@ from ...utils import (
|
|
|
8
8
|
determine_backend_name,
|
|
9
9
|
track_calls,
|
|
10
10
|
)
|
|
11
|
-
from .base import
|
|
11
|
+
from .base import SMCSampler
|
|
12
12
|
|
|
13
13
|
|
|
14
|
-
class MiniPCNSMC(
|
|
14
|
+
class MiniPCNSMC(SMCSampler):
|
|
15
15
|
"""MiniPCN SMC sampler."""
|
|
16
16
|
|
|
17
17
|
rng = None
|
|
@@ -32,6 +32,10 @@ class MiniPCNSMC(NumpySMCSampler):
|
|
|
32
32
|
n_final_samples: int | None = None,
|
|
33
33
|
sampler_kwargs: dict | None = None,
|
|
34
34
|
rng: np.random.Generator | None = None,
|
|
35
|
+
checkpoint_callback=None,
|
|
36
|
+
checkpoint_every: int | None = None,
|
|
37
|
+
checkpoint_file_path: str | None = None,
|
|
38
|
+
resume_from: str | bytes | dict | None = None,
|
|
35
39
|
):
|
|
36
40
|
from orng import ArrayRNG
|
|
37
41
|
|
|
@@ -50,6 +54,10 @@ class MiniPCNSMC(NumpySMCSampler):
|
|
|
50
54
|
n_final_samples=n_final_samples,
|
|
51
55
|
min_step=min_step,
|
|
52
56
|
max_n_steps=max_n_steps,
|
|
57
|
+
checkpoint_callback=checkpoint_callback,
|
|
58
|
+
checkpoint_every=checkpoint_every,
|
|
59
|
+
checkpoint_file_path=checkpoint_file_path,
|
|
60
|
+
resume_from=resume_from,
|
|
53
61
|
)
|
|
54
62
|
|
|
55
63
|
def mutate(self, particles, beta, n_steps=None):
|
aspire/samples.py
CHANGED
|
@@ -9,19 +9,18 @@ from typing import Any, Callable
|
|
|
9
9
|
import numpy as np
|
|
10
10
|
from array_api_compat import (
|
|
11
11
|
array_namespace,
|
|
12
|
-
is_numpy_namespace,
|
|
13
|
-
to_device,
|
|
14
12
|
)
|
|
15
|
-
from array_api_compat import device as api_device
|
|
16
13
|
from array_api_compat.common._typing import Array
|
|
17
14
|
from matplotlib.figure import Figure
|
|
18
15
|
|
|
19
16
|
from .utils import (
|
|
20
17
|
asarray,
|
|
21
18
|
convert_dtype,
|
|
19
|
+
infer_device,
|
|
22
20
|
logsumexp,
|
|
23
21
|
recursively_save_to_h5_file,
|
|
24
22
|
resolve_dtype,
|
|
23
|
+
safe_to_device,
|
|
25
24
|
to_numpy,
|
|
26
25
|
)
|
|
27
26
|
|
|
@@ -67,8 +66,6 @@ class BaseSamples:
|
|
|
67
66
|
if self.xp is None:
|
|
68
67
|
self.xp = array_namespace(self.x)
|
|
69
68
|
# Numpy arrays need to be on the CPU before being converted
|
|
70
|
-
if is_numpy_namespace(self.xp):
|
|
71
|
-
self.device = "cpu"
|
|
72
69
|
if self.dtype is not None:
|
|
73
70
|
self.dtype = resolve_dtype(self.dtype, self.xp)
|
|
74
71
|
else:
|
|
@@ -76,7 +73,7 @@ class BaseSamples:
|
|
|
76
73
|
self.dtype = None
|
|
77
74
|
self.x = self.array_to_namespace(self.x, dtype=self.dtype)
|
|
78
75
|
if self.device is None:
|
|
79
|
-
self.device =
|
|
76
|
+
self.device = infer_device(self.x, self.xp)
|
|
80
77
|
if self.log_likelihood is not None:
|
|
81
78
|
self.log_likelihood = self.array_to_namespace(
|
|
82
79
|
self.log_likelihood, dtype=self.dtype
|
|
@@ -140,8 +137,7 @@ class BaseSamples:
|
|
|
140
137
|
else:
|
|
141
138
|
kwargs["dtype"] = self.dtype
|
|
142
139
|
x = asarray(x, self.xp, **kwargs)
|
|
143
|
-
|
|
144
|
-
x = to_device(x, self.device)
|
|
140
|
+
x = safe_to_device(x, self.device, self.xp)
|
|
145
141
|
return x
|
|
146
142
|
|
|
147
143
|
def to_dict(self, flat: bool = True):
|
|
@@ -174,7 +170,6 @@ class BaseSamples:
|
|
|
174
170
|
----------
|
|
175
171
|
parameters : list[str] | None
|
|
176
172
|
List of parameters to plot. If None, all parameters are plotted.
|
|
177
|
-
fig : matplotlib.figure.Figure | None
|
|
178
173
|
Figure to plot on. If None, a new figure is created.
|
|
179
174
|
**kwargs : dict
|
|
180
175
|
Additional keyword arguments to pass to corner.corner(). Kwargs
|
|
@@ -300,6 +295,13 @@ class BaseSamples:
|
|
|
300
295
|
def __setstate__(self, state):
|
|
301
296
|
# Restore xp by checking the namespace of x
|
|
302
297
|
state["xp"] = array_namespace(state["x"])
|
|
298
|
+
# device may be string; leave as-is or None
|
|
299
|
+
device = state.get("device")
|
|
300
|
+
if device is not None and "jax" in getattr(
|
|
301
|
+
state["xp"], "__name__", ""
|
|
302
|
+
):
|
|
303
|
+
device = None
|
|
304
|
+
state["device"] = device
|
|
303
305
|
self.__dict__.update(state)
|
|
304
306
|
|
|
305
307
|
|
aspire/utils.py
CHANGED
|
@@ -2,9 +2,11 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
import inspect
|
|
4
4
|
import logging
|
|
5
|
+
import pickle
|
|
5
6
|
from contextlib import contextmanager
|
|
6
7
|
from dataclasses import dataclass
|
|
7
8
|
from functools import partial
|
|
9
|
+
from io import BytesIO
|
|
8
10
|
from typing import TYPE_CHECKING, Any
|
|
9
11
|
|
|
10
12
|
import array_api_compat.numpy as np
|
|
@@ -601,6 +603,7 @@ def decode_from_hdf5(value: Any) -> Any:
|
|
|
601
603
|
return None
|
|
602
604
|
if value == "__empty_dict__":
|
|
603
605
|
return {}
|
|
606
|
+
return value
|
|
604
607
|
|
|
605
608
|
if isinstance(value, np.ndarray):
|
|
606
609
|
# Try to collapse 0-D arrays into scalars
|
|
@@ -629,6 +632,101 @@ def decode_from_hdf5(value: Any) -> Any:
|
|
|
629
632
|
return value
|
|
630
633
|
|
|
631
634
|
|
|
635
|
+
def dump_pickle_to_hdf(memfp, fp, path=None, dsetname="state"):
|
|
636
|
+
"""Dump pickled data to an HDF5 file object."""
|
|
637
|
+
memfp.seek(0)
|
|
638
|
+
bdata = np.frombuffer(memfp.read(), dtype="S1")
|
|
639
|
+
target = fp.require_group(path) if path is not None else fp
|
|
640
|
+
if dsetname not in target:
|
|
641
|
+
target.create_dataset(
|
|
642
|
+
dsetname, shape=bdata.shape, maxshape=(None,), dtype=bdata.dtype
|
|
643
|
+
)
|
|
644
|
+
elif bdata.size != target[dsetname].shape[0]:
|
|
645
|
+
target[dsetname].resize((bdata.size,))
|
|
646
|
+
target[dsetname][:] = bdata
|
|
647
|
+
|
|
648
|
+
|
|
649
|
+
def dump_state(
|
|
650
|
+
state,
|
|
651
|
+
fp,
|
|
652
|
+
path=None,
|
|
653
|
+
dsetname="state",
|
|
654
|
+
protocol=pickle.HIGHEST_PROTOCOL,
|
|
655
|
+
):
|
|
656
|
+
"""Pickle a state object and store it in an HDF5 dataset."""
|
|
657
|
+
memfp = BytesIO()
|
|
658
|
+
pickle.dump(state, memfp, protocol=protocol)
|
|
659
|
+
dump_pickle_to_hdf(memfp, fp, path=path, dsetname=dsetname)
|
|
660
|
+
|
|
661
|
+
|
|
662
|
+
def resolve_xp(xp_name: str | None):
|
|
663
|
+
"""
|
|
664
|
+
Resolve a backend name to the corresponding array_api_compat module.
|
|
665
|
+
|
|
666
|
+
Returns None if the name is None or cannot be resolved.
|
|
667
|
+
"""
|
|
668
|
+
if xp_name is None:
|
|
669
|
+
return None
|
|
670
|
+
name = xp_name.lower()
|
|
671
|
+
if name.startswith("array_api_compat."):
|
|
672
|
+
name = name.removeprefix("array_api_compat.")
|
|
673
|
+
try:
|
|
674
|
+
if name in {"numpy", "numpy.ndarray"}:
|
|
675
|
+
import array_api_compat.numpy as np_xp
|
|
676
|
+
|
|
677
|
+
return np_xp
|
|
678
|
+
if name in {"jax", "jax.numpy"}:
|
|
679
|
+
import jax.numpy as jnp
|
|
680
|
+
|
|
681
|
+
return jnp
|
|
682
|
+
if name in {"torch"}:
|
|
683
|
+
import array_api_compat.torch as torch_xp
|
|
684
|
+
|
|
685
|
+
return torch_xp
|
|
686
|
+
except Exception:
|
|
687
|
+
logger.warning(
|
|
688
|
+
"Failed to resolve xp '%s', defaulting to None", xp_name
|
|
689
|
+
)
|
|
690
|
+
return None
|
|
691
|
+
|
|
692
|
+
|
|
693
|
+
def infer_device(x, xp):
|
|
694
|
+
"""
|
|
695
|
+
Best-effort device inference that avoids non-portable identifiers.
|
|
696
|
+
|
|
697
|
+
Returns None for numpy/jax backends; returns the backend device object
|
|
698
|
+
for torch/cupy if available.
|
|
699
|
+
"""
|
|
700
|
+
if xp is None or is_numpy_namespace(xp) or is_jax_namespace(xp):
|
|
701
|
+
return None
|
|
702
|
+
try:
|
|
703
|
+
from array_api_compat import device
|
|
704
|
+
|
|
705
|
+
return device(x)
|
|
706
|
+
except Exception:
|
|
707
|
+
return None
|
|
708
|
+
|
|
709
|
+
|
|
710
|
+
def safe_to_device(x, device, xp):
|
|
711
|
+
"""
|
|
712
|
+
Move to device if specified; otherwise return input.
|
|
713
|
+
|
|
714
|
+
Skips moves for numpy/jax/None devices; logs and returns input on failure.
|
|
715
|
+
"""
|
|
716
|
+
if device is None:
|
|
717
|
+
return x
|
|
718
|
+
if xp is None or is_numpy_namespace(xp) or is_jax_namespace(xp):
|
|
719
|
+
return x
|
|
720
|
+
try:
|
|
721
|
+
return to_device(x, device)
|
|
722
|
+
except Exception:
|
|
723
|
+
logger.warning(
|
|
724
|
+
"Failed to move array to device %s; leaving on current device",
|
|
725
|
+
device,
|
|
726
|
+
)
|
|
727
|
+
return x
|
|
728
|
+
|
|
729
|
+
|
|
632
730
|
def recursively_save_to_h5_file(h5_file, path, dictionary):
|
|
633
731
|
"""Save a dictionary to an HDF5 file with flattened keys under a given group path."""
|
|
634
732
|
# Ensure the group exists (or open it if already present)
|
|
@@ -1,28 +1,28 @@
|
|
|
1
1
|
aspire/__init__.py,sha256=B2IETvlpB0oBh57prRYLdi8jB5yFGw8qVviGdf1NcnE,409
|
|
2
|
-
aspire/aspire.py,sha256=
|
|
2
|
+
aspire/aspire.py,sha256=lr0bD5GDWdlAGfODGzj4BoELKUF5HYAMb8yYGvRR_y0,30860
|
|
3
3
|
aspire/history.py,sha256=l_j-riZKbTWK7Wz9zvvD_mTk9psNCKItiveYhr_pYv8,4313
|
|
4
4
|
aspire/plot.py,sha256=oXwUDOb_953_ADm2KLk41JIfpE3JeiiQiSYKvUVwLqw,1423
|
|
5
|
-
aspire/samples.py,sha256=
|
|
5
|
+
aspire/samples.py,sha256=v7y8DkirUCHOJbCE-o9y2K7xzU2HicIo_O0CdFhLgXE,19478
|
|
6
6
|
aspire/transforms.py,sha256=CHrfPQHEyHQ9I0WWiAgUplWwxYypZsD7uCYIHUbSFtY,24974
|
|
7
|
-
aspire/utils.py,sha256=
|
|
7
|
+
aspire/utils.py,sha256=87avRkTce9QvELpIcqlKauSoUZSYi1fqe1asC97TzqA,26947
|
|
8
8
|
aspire/flows/__init__.py,sha256=GUZToPVNJoTwULpbeW10UijfQukNrILoAQ_ubeq7G3w,2110
|
|
9
9
|
aspire/flows/base.py,sha256=5UWKAiXDXLJ6Sg6a380ajLrGFaZSQyOnFEihQiiA4ko,2237
|
|
10
10
|
aspire/flows/jax/__init__.py,sha256=7cmiY_MbEC8RDA8Cmi8HVnNJm0sqFKlBsDethdsy5lA,52
|
|
11
11
|
aspire/flows/jax/flows.py,sha256=1HnVgQ1GUXNcvxiZqEV19H2QI9Th5bWX_QbNfGaUhuA,6625
|
|
12
12
|
aspire/flows/jax/utils.py,sha256=5T6UrgpARG9VywC9qmTl45LjyZWuEdkW3XUladE6xJE,1518
|
|
13
13
|
aspire/flows/torch/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
14
|
-
aspire/flows/torch/flows.py,sha256=
|
|
14
|
+
aspire/flows/torch/flows.py,sha256=QcQOcFZEsLWHPwbQUFGOFdfEslyc59Vf_UEsS0xAGPo,11673
|
|
15
15
|
aspire/samplers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
16
|
-
aspire/samplers/base.py,sha256=
|
|
16
|
+
aspire/samplers/base.py,sha256=ygqrvqSedWSb0cz8DQ_MHokOOxi6aBRdHxf_qoEPwUE,8243
|
|
17
17
|
aspire/samplers/importance.py,sha256=opn_jY-V8snUz0JztLBtnaTT3WfrZ5OSpHBV5WAuM3M,676
|
|
18
18
|
aspire/samplers/mcmc.py,sha256=ihHgzqvSoy1oxdFBjyqNUbCuRX7CqWjlshCUZcgEL5E,5151
|
|
19
19
|
aspire/samplers/smc/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
20
|
-
aspire/samplers/smc/base.py,sha256=
|
|
21
|
-
aspire/samplers/smc/blackjax.py,sha256=
|
|
22
|
-
aspire/samplers/smc/emcee.py,sha256=
|
|
23
|
-
aspire/samplers/smc/minipcn.py,sha256=
|
|
24
|
-
aspire_inference-0.1.
|
|
25
|
-
aspire_inference-0.1.
|
|
26
|
-
aspire_inference-0.1.
|
|
27
|
-
aspire_inference-0.1.
|
|
28
|
-
aspire_inference-0.1.
|
|
20
|
+
aspire/samplers/smc/base.py,sha256=40A9yVuKS1F8cPzbfVQ9rNk3y07mnkfbuyRDIh_fy5A,14122
|
|
21
|
+
aspire/samplers/smc/blackjax.py,sha256=2riWDSRmpL5lGmnhNtdieiRs0oYC6XZA2X-nVlQaqpE,12490
|
|
22
|
+
aspire/samplers/smc/emcee.py,sha256=4CI9GvH69FCoLiFBbKKYwYocYyiM95IijC5EvrcAmUo,2891
|
|
23
|
+
aspire/samplers/smc/minipcn.py,sha256=IJ5466VvARd4qZCWXXl-l3BPaKW1AgcwmbP3ISL2bto,3368
|
|
24
|
+
aspire_inference-0.1.0a11.dist-info/licenses/LICENSE,sha256=DN-eRtBfS9dZyT0Ds0Mdn2Y4nb-ZQ7h71vpASYBm5k4,1076
|
|
25
|
+
aspire_inference-0.1.0a11.dist-info/METADATA,sha256=lLd2d5HR-t942wKLyYbdJ1DL9CKl7tkpBul_vX8DU4M,3869
|
|
26
|
+
aspire_inference-0.1.0a11.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
27
|
+
aspire_inference-0.1.0a11.dist-info/top_level.txt,sha256=9FRIYEl2xwVFG7jSOBHsElHQ0y3_4fq01Cf4_OyMQn8,7
|
|
28
|
+
aspire_inference-0.1.0a11.dist-info/RECORD,,
|
|
File without changes
|
{aspire_inference-0.1.0a10.dist-info → aspire_inference-0.1.0a11.dist-info}/licenses/LICENSE
RENAMED
|
File without changes
|
|
File without changes
|