aspire-inference 0.1.0a9__py3-none-any.whl → 0.1.0a11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aspire/aspire.py CHANGED
@@ -1,5 +1,8 @@
+import copy
 import logging
 import multiprocessing as mp
+import pickle
+from contextlib import contextmanager
 from inspect import signature
 from typing import Any, Callable
 
@@ -8,13 +11,19 @@ import h5py
 from .flows import get_flow_wrapper
 from .flows.base import Flow
 from .history import History
+from .samplers.base import Sampler
 from .samples import Samples
 from .transforms import (
     CompositeTransform,
     FlowPreconditioningTransform,
     FlowTransform,
 )
-from .utils import recursively_save_to_h5_file
+from .utils import (
+    AspireFile,
+    load_from_h5_file,
+    recursively_save_to_h5_file,
+    resolve_xp,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -102,6 +111,7 @@ class Aspire:
         self.dtype = dtype
 
         self._flow = flow
+        self._sampler = None
 
     @property
     def flow(self):
@@ -114,7 +124,7 @@ class Aspire:
         self._flow = flow
 
     @property
-    def sampler(self):
+    def sampler(self) -> Sampler | None:
         """The sampler object."""
         return self._sampler
 
@@ -192,7 +202,29 @@ class Aspire:
             **self.flow_kwargs,
         )
 
-    def fit(self, samples: Samples, **kwargs) -> History:
+    def fit(
+        self,
+        samples: Samples,
+        checkpoint_path: str | None = None,
+        checkpoint_save_config: bool = True,
+        overwrite: bool = False,
+        **kwargs,
+    ) -> History:
+        """Fit the normalizing flow to the provided samples.
+
+        Parameters
+        ----------
+        samples : Samples
+            The samples to fit the flow to.
+        checkpoint_path : str | None
+            Path to save the checkpoint. If None, no checkpoint is saved.
+        checkpoint_save_config : bool
+            Whether to save the Aspire configuration to the checkpoint.
+        overwrite : bool
+            Whether to overwrite an existing flow in the checkpoint file.
+        kwargs : dict
+            Keyword arguments to pass to the flow's fit method.
+        """
         if self.xp is None:
             self.xp = samples.xp
 
@@ -202,6 +234,28 @@ class Aspire:
         self.training_samples = samples
         logger.info(f"Training with {len(samples.x)} samples")
         history = self.flow.fit(samples.x, **kwargs)
+        defaults = getattr(self, "_checkpoint_defaults", None)
+        if checkpoint_path is None and defaults:
+            checkpoint_path = defaults["path"]
+            checkpoint_save_config = defaults["save_config"]
+        saved_config = (
+            defaults.get("saved_config", False) if defaults else False
+        )
+        if checkpoint_path is not None:
+            with AspireFile(checkpoint_path, "a") as h5_file:
+                if checkpoint_save_config and not saved_config:
+                    if "aspire_config" in h5_file:
+                        del h5_file["aspire_config"]
+                    self.save_config(h5_file, include_sampler_config=False)
+                    if defaults is not None:
+                        defaults["saved_config"] = True
+                # Save flow only if missing or overwrite=True
+                if "flow" in h5_file:
+                    if overwrite:
+                        del h5_file["flow"]
+                        self.save_flow(h5_file)
+                else:
+                    self.save_flow(h5_file)
         return history
 
     def get_sampler_class(self, sampler_type: str) -> Callable:
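A minimal sketch of how the new fit-time checkpointing might be used from the caller's side. The aspire instance, the samples object, and the "run.h5" file name are illustrative, not taken from this diff:

    # Hypothetical usage: fit the flow and checkpoint config + flow to one file.
    history = aspire.fit(
        samples,
        checkpoint_path="run.h5",      # config and flow are written here
        checkpoint_save_config=True,   # store the Aspire config once
        overwrite=False,               # keep an existing "flow" group intact
    )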
@@ -241,6 +295,13 @@ class Aspire:
         ----------
         sampler_type : str
             The type of sampler to use. Options are 'importance', 'emcee', or 'smc'.
+        preconditioning : str
+            Type of preconditioning to apply in the sampler. Options are
+            'default', 'flow', or 'none'.
+        preconditioning_kwargs : dict
+            Keyword arguments to pass to the preconditioning transform.
+        kwargs : dict
+            Keyword arguments to pass to the sampler.
         """
         SamplerClass = self.get_sampler_class(sampler_type)
 
@@ -304,6 +365,9 @@ class Aspire:
         return_history: bool = False,
         preconditioning: str | None = None,
         preconditioning_kwargs: dict | None = None,
+        checkpoint_path: str | None = None,
+        checkpoint_every: int = 1,
+        checkpoint_save_config: bool = True,
         **kwargs,
     ) -> Samples:
         """Draw samples from the posterior distribution.
@@ -342,6 +406,14 @@ class Aspire:
             will default to 'none' and the other samplers to 'default'
         preconditioning_kwargs : dict
             Keyword arguments to pass to the preconditioning transform.
+        checkpoint_path : str | None
+            Path to save the checkpoint. If None, no checkpoint is saved unless
+            within an :py:meth:`auto_checkpoint` context or a custom callback
+            is provided.
+        checkpoint_every : int
+            Frequency (in number of sampler iterations) at which to save the checkpoint.
+        checkpoint_save_config : bool
+            Whether to save the Aspire configuration to the checkpoint.
         kwargs : dict
             Keyword arguments to pass to the sampler. These are passed
             automatically to the init method of the sampler or to the sample
352
424
  samples : Samples
353
425
  Samples object contain samples and their corresponding weights.
354
426
  """
427
+ if (
428
+ sampler == "importance"
429
+ and hasattr(self, "_resume_sampler_type")
430
+ and self._resume_sampler_type
431
+ ):
432
+ sampler = self._resume_sampler_type
433
+
434
+ if "resume_from" not in kwargs and hasattr(
435
+ self, "_resume_from_default"
436
+ ):
437
+ kwargs["resume_from"] = self._resume_from_default
438
+ if hasattr(self, "_resume_overrides"):
439
+ kwargs.update(self._resume_overrides)
440
+ if hasattr(self, "_resume_n_samples") and n_samples == 1000:
441
+ n_samples = self._resume_n_samples
442
+
355
443
  SamplerClass = self.get_sampler_class(sampler)
356
444
  # Determine sampler initialization parameters
357
445
  # and remove them from kwargs
@@ -373,7 +461,73 @@ class Aspire:
             preconditioning_kwargs=preconditioning_kwargs,
             **sampler_kwargs,
         )
+        self._last_sampler_type = sampler
+        # Auto-checkpoint convenience: set defaults for checkpointing to a single file
+        defaults = getattr(self, "_checkpoint_defaults", None)
+        if checkpoint_path is None and defaults:
+            checkpoint_path = defaults["path"]
+            checkpoint_every = defaults["every"]
+            checkpoint_save_config = defaults["save_config"]
+        saved_flow = defaults.get("saved_flow", False) if defaults else False
+        saved_config = (
+            defaults.get("saved_config", False) if defaults else False
+        )
+        if checkpoint_path is not None:
+            kwargs.setdefault("checkpoint_file_path", checkpoint_path)
+            kwargs.setdefault("checkpoint_every", checkpoint_every)
+            with AspireFile(checkpoint_path, "a") as h5_file:
+                if checkpoint_save_config:
+                    if "aspire_config" in h5_file:
+                        del h5_file["aspire_config"]
+                    self.save_config(
+                        h5_file,
+                        include_sampler_config=True,
+                        include_sample_calls=False,
+                    )
+                    saved_config = True
+                    if defaults is not None:
+                        defaults["saved_config"] = True
+                if (
+                    self.flow is not None
+                    and not saved_flow
+                    and "flow" not in h5_file
+                ):
+                    self.save_flow(h5_file)
+                    saved_flow = True
+                    if defaults is not None:
+                        defaults["saved_flow"] = True
+
         samples = self._sampler.sample(n_samples, **kwargs)
+        self._last_sample_posterior_kwargs = {
+            "n_samples": n_samples,
+            "sampler": sampler,
+            "xp": xp,
+            "return_history": return_history,
+            "preconditioning": preconditioning,
+            "preconditioning_kwargs": preconditioning_kwargs,
+            "sampler_init_kwargs": sampler_kwargs,
+            "sample_kwargs": copy.deepcopy(kwargs),
+        }
+        if checkpoint_path is not None:
+            with AspireFile(checkpoint_path, "a") as h5_file:
+                if checkpoint_save_config and not saved_config:
+                    if "aspire_config" in h5_file:
+                        del h5_file["aspire_config"]
+                    self.save_config(
+                        h5_file,
+                        include_sampler_config=True,
+                        include_sample_calls=False,
+                    )
+                    if defaults is not None:
+                        defaults["saved_config"] = True
+                if (
+                    self.flow is not None
+                    and not saved_flow
+                    and "flow" not in h5_file
+                ):
+                    self.save_flow(h5_file)
+                    if defaults is not None:
+                        defaults["saved_flow"] = True
         if xp is not None:
             samples = samples.to_namespace(xp)
         samples.parameters = self.parameters
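A short sketch of the matching call pattern for checkpointed sampling. Names and values are illustrative; note that checkpoint_path and checkpoint_every are forwarded to the sampler via the checkpoint_file_path and checkpoint_every keyword arguments set above:

    # Hypothetical usage: sample while writing periodic checkpoints.
    samples = aspire.sample_posterior(
        n_samples=2000,
        sampler="smc",
        checkpoint_path="run.h5",   # forwarded as checkpoint_file_path
        checkpoint_every=5,         # checkpoint every 5 sampler iterations
    )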
@@ -388,6 +542,122 @@ class Aspire:
         else:
             return samples
 
+    @classmethod
+    def resume_from_file(
+        cls,
+        file_path: str,
+        *,
+        log_likelihood: Callable,
+        log_prior: Callable,
+        sampler: str | None = None,
+        checkpoint_path: str = "checkpoint",
+        checkpoint_dset: str = "state",
+        flow_path: str = "flow",
+        config_path: str = "aspire_config",
+        resume_kwargs: dict | None = None,
+    ):
+        """
+        Recreate an Aspire object from a single file and prepare to resume sampling.
+
+        Parameters
+        ----------
+        file_path : str
+            Path to the HDF5 file containing config, flow, and checkpoint.
+        log_likelihood : Callable
+            Log-likelihood function (required, not pickled).
+        log_prior : Callable
+            Log-prior function (required, not pickled).
+        sampler : str
+            Sampler type to use (e.g., 'smc', 'minipcn_smc', 'emcee_smc'). If None,
+            will attempt to infer it from the saved config or checkpoint metadata.
+        checkpoint_path : str
+            HDF5 group path where the checkpoint is stored.
+        checkpoint_dset : str
+            Dataset name within the checkpoint group.
+        flow_path : str
+            HDF5 path to the saved flow.
+        config_path : str
+            HDF5 path to the saved Aspire config.
+        resume_kwargs : dict | None
+            Optional overrides to apply when resuming (e.g., checkpoint_every).
+        """
+        (
+            aspire,
+            checkpoint_bytes,
+            checkpoint_state,
+            sampler_config,
+            saved_sampler_type,
+            n_samples,
+        ) = cls._build_aspire_from_file(
+            file_path=file_path,
+            log_likelihood=log_likelihood,
+            log_prior=log_prior,
+            checkpoint_path=checkpoint_path,
+            checkpoint_dset=checkpoint_dset,
+            flow_path=flow_path,
+            config_path=config_path,
+        )
+
+        sampler_config = sampler_config or {}
+        sampler_config.pop("sampler_class", None)
+
+        if checkpoint_bytes is not None:
+            aspire._resume_from_default = checkpoint_bytes
+        aspire._resume_sampler_type = (
+            sampler
+            or saved_sampler_type
+            or (
+                checkpoint_state.get("sampler")
+                if checkpoint_state
+                else None
+            )
+        )
+        aspire._resume_n_samples = n_samples
+        aspire._resume_overrides = resume_kwargs or {}
+        aspire._resume_sampler_config = sampler_config
+        aspire._checkpoint_defaults = {
+            "path": file_path,
+            "every": 1,
+            "save_config": False,
+            "save_flow": False,
+            "saved_config": False,
+            "saved_flow": False,
+        }
+        return aspire
+
+    @contextmanager
+    def auto_checkpoint(
+        self,
+        path: str,
+        every: int = 1,
+        save_config: bool = True,
+        save_flow: bool = True,
+    ):
+        """
+        Context manager to auto-save checkpoints, config, and flow to a file.
+
+        Within the context, sample_posterior will default to writing checkpoints
+        to the given path with the specified frequency, and will append the
+        config and flow after sampling.
+        """
+        prev = getattr(self, "_checkpoint_defaults", None)
+        self._checkpoint_defaults = {
+            "path": path,
+            "every": every,
+            "save_config": save_config,
+            "save_flow": save_flow,
+            "saved_config": False,
+            "saved_flow": False,
+        }
+        try:
+            yield self
+        finally:
+            if prev is None:
+                if hasattr(self, "_checkpoint_defaults"):
+                    delattr(self, "_checkpoint_defaults")
+            else:
+                self._checkpoint_defaults = prev
+
     def enable_pool(self, pool: mp.Pool, **kwargs):
         """Context manager to temporarily replace the log_likelihood method
         with a version that uses a multiprocessing pool to parallelize
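Taken together, auto_checkpoint and resume_from_file support a single-file train/sample/resume workflow. A hedged sketch, assuming user-supplied log_likelihood and log_prior callables and an illustrative "run.h5" path:

    # Hypothetical usage: everything goes into one HDF5 file.
    with aspire.auto_checkpoint("run.h5", every=2):
        aspire.fit(samples)
        aspire.sample_posterior(n_samples=2000, sampler="smc")

    # Later, in a fresh process: rebuild and continue. The likelihood and
    # prior must be passed again because they are never pickled.
    restored = Aspire.resume_from_file(
        "run.h5",
        log_likelihood=log_likelihood,
        log_prior=log_prior,
    )
    # With the defaults, the resume hooks substitute the saved sampler
    # type, sample count, and checkpoint state.
    restored.sample_posterior()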
@@ -432,12 +702,16 @@ class Aspire:
             "flow_kwargs": self.flow_kwargs,
             "eps": self.eps,
         }
+        if hasattr(self, "_last_sampler_type"):
+            config["sampler_type"] = self._last_sampler_type
         if include_sampler_config:
+            if self.sampler is None:
+                raise ValueError("Sampler has not been initialized.")
             config["sampler_config"] = self.sampler.config_dict(**kwargs)
         return config
 
     def save_config(
-        self, h5_file: h5py.File, path="aspire_config", **kwargs
+        self, h5_file: h5py.File | AspireFile, path="aspire_config", **kwargs
     ) -> None:
         """Save the configuration to an HDF5 file.
 
@@ -484,6 +758,7 @@ class Aspire:
         FlowClass, xp = get_flow_wrapper(
             backend=self.flow_backend, flow_matching=self.flow_matching
         )
+        logger.debug(f"Loading flow of type {FlowClass} from {path}")
         self._flow = FlowClass.load(h5_file, path=path)
 
     def save_config_to_json(self, filename: str) -> None:
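The widened save_config signature and the new guard in config_dict work together: keyword arguments such as include_sampler_config flow through save_config to config_dict, which now fails fast instead of dereferencing a missing sampler. A hedged sketch with illustrative names:

    with AspireFile("run.h5", "a") as f:
        # Safe before any sampling: skip the sampler config.
        aspire.save_config(f, include_sampler_config=False)
        # Would raise ValueError("Sampler has not been initialized.")
        # if sample_posterior has not been called yet:
        # aspire.save_config(f, include_sampler_config=True)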
@@ -504,3 +779,80 @@ class Aspire:
         x, log_q = self.flow.sample_and_log_prob(n_samples)
         samples = Samples(x=x, log_q=log_q, xp=xp, parameters=self.parameters)
         return samples
+
+    # --- Resume helpers ---
+    @staticmethod
+    def _build_aspire_from_file(
+        file_path: str,
+        log_likelihood: Callable,
+        log_prior: Callable,
+        checkpoint_path: str,
+        checkpoint_dset: str,
+        flow_path: str,
+        config_path: str,
+    ):
+        """Construct an Aspire instance, load the flow, and gather checkpoint metadata from the file."""
+        with AspireFile(file_path, "r") as h5_file:
+            if config_path not in h5_file:
+                raise ValueError(
+                    f"Config path '{config_path}' not found in {file_path}"
+                )
+            config_dict = load_from_h5_file(h5_file, config_path)
+            try:
+                checkpoint_bytes = h5_file[checkpoint_path][checkpoint_dset][
+                    ...
+                ].tobytes()
+            except Exception:
+                logger.warning(
+                    "Checkpoint not found at %s/%s in %s; will resume without a checkpoint.",
+                    checkpoint_path,
+                    checkpoint_dset,
+                    file_path,
+                )
+                checkpoint_bytes = None
+
+        sampler_config = config_dict.pop("sampler_config", None)
+        saved_sampler_type = config_dict.pop("sampler_type", None)
+        if isinstance(config_dict.get("xp"), str):
+            config_dict["xp"] = resolve_xp(config_dict["xp"])
+        config_dict["log_likelihood"] = log_likelihood
+        config_dict["log_prior"] = log_prior
+
+        aspire = Aspire(**config_dict)
+
+        with AspireFile(file_path, "r") as h5_file:
+            if flow_path in h5_file:
+                logger.info(f"Loading flow from {flow_path} in {file_path}")
+                aspire.load_flow(h5_file, path=flow_path)
+            else:
+                raise ValueError(
+                    f"Flow path '{flow_path}' not found in {file_path}"
+                )
+
+        n_samples = None
+        checkpoint_state = None
+        if checkpoint_bytes is not None:
+            try:
+                checkpoint_state = pickle.loads(checkpoint_bytes)
+                samples_saved = (
+                    checkpoint_state.get("samples")
+                    if checkpoint_state
+                    else None
+                )
+                if samples_saved is not None:
+                    n_samples = len(samples_saved)
+                    if aspire.xp is None and hasattr(samples_saved, "xp"):
+                        aspire.xp = samples_saved.xp
+            except Exception:
+                logger.warning(
+                    "Failed to decode checkpoint; proceeding without resume state."
+                )
+
+        return (
+            aspire,
+            checkpoint_bytes,
+            checkpoint_state,
+            sampler_config,
+            saved_sampler_type,
+            n_samples,
+        )
@@ -92,7 +92,7 @@ class BaseTorchFlow(Flow):
         config = load_from_h5_file(flow_grp, "config")
         config["dtype"] = decode_dtype(torch, config.get("dtype"))
         if "data_transform" in flow_grp:
-            from ..transforms import BaseTransform
+            from ...transforms import BaseTransform
 
             data_transform = BaseTransform.load(
                 flow_grp,
aspire/samplers/base.py CHANGED
@@ -1,10 +1,12 @@
 import logging
+import pickle
+from pathlib import Path
 from typing import Any, Callable
 
 from ..flows.base import Flow
 from ..samples import Samples
 from ..transforms import IdentityTransform
-from ..utils import track_calls
+from ..utils import AspireFile, asarray, dump_state, track_calls
 
 logger = logging.getLogger(__name__)
 
@@ -49,6 +51,8 @@ class Sampler:
         self.parameters = parameters
         self.history = None
         self.n_likelihood_evaluations = 0
+        self._last_checkpoint_state: dict | None = None
+        self._last_checkpoint_bytes: bytes | None = None
         if preconditioning_transform is None:
             self.preconditioning_transform = IdentityTransform(xp=self.xp)
         else:
@@ -56,7 +60,11 @@
 
     def fit_preconditioning_transform(self, x):
         """Fit the data transform to the data."""
-        x = self.preconditioning_transform.xp.asarray(x)
+        x = asarray(
+            x,
+            xp=self.preconditioning_transform.xp,
+            dtype=self.preconditioning_transform.dtype,
+        )
         return self.preconditioning_transform.fit(x)
 
     @track_calls
@@ -71,7 +79,7 @@
         self.n_likelihood_evaluations += len(samples)
         return self._log_likelihood(samples)
 
-    def config_dict(self, include_sample_calls: bool = True) -> dict:
+    def config_dict(self, include_sample_calls: bool = False) -> dict:
         """
         Returns a dictionary with the configuration of the sampler.
 
@@ -79,9 +87,9 @@
         ----------
         include_sample_calls : bool
             Whether to include the sample calls in the configuration.
-            Default is True.
+            Default is False.
         """
-        config = {}
+        config = {"sampler_class": self.__class__.__name__}
         if include_sample_calls:
             if hasattr(self, "sample") and hasattr(self.sample, "calls"):
                 config["sample_calls"] = self.sample.calls.to_dict(
@@ -92,3 +100,139 @@
                 "Sampler does not have a sample method with calls attribute."
             )
         return config
+
+    # --- Checkpointing helpers shared across samplers ---
+    def _checkpoint_extra_state(self) -> dict:
+        """Sampler-specific extras for checkpointing (override in subclasses)."""
+        return {}
+
+    def _restore_extra_state(self, state: dict) -> None:
+        """Restore sampler-specific extras (override in subclasses)."""
+        _ = state  # no-op for the base class
+
+    def build_checkpoint_state(
+        self,
+        samples: Samples,
+        iteration: int | None = None,
+        meta: dict | None = None,
+    ) -> dict:
+        """Prepare a serializable checkpoint payload for the sampler state."""
+        checkpoint_samples = samples
+        base_state = {
+            "sampler": self.__class__.__name__,
+            "iteration": iteration,
+            "samples": checkpoint_samples,
+            "config": self.config_dict(include_sample_calls=False),
+            "parameters": self.parameters,
+            "meta": meta or {},
+        }
+        base_state.update(self._checkpoint_extra_state())
+        return base_state
+
+    def serialize_checkpoint(
+        self, state: dict, protocol: int | None = None
+    ) -> bytes:
+        """Serialize a checkpoint state to bytes with pickle."""
+        protocol = (
+            pickle.HIGHEST_PROTOCOL if protocol is None else int(protocol)
+        )
+        return pickle.dumps(state, protocol=protocol)
+
+    def default_checkpoint_callback(self, state: dict) -> None:
+        """Store the latest checkpoint (state + pickled bytes) on the sampler."""
+        self._last_checkpoint_state = state
+        self._last_checkpoint_bytes = self.serialize_checkpoint(state)
+
+    def default_file_checkpoint_callback(
+        self, file_path: str | Path | None
+    ) -> Callable[[dict], None]:
+        """Return a simple default callback that overwrites an HDF5 file."""
+        if file_path is None:
+            return self.default_checkpoint_callback
+        file_path = Path(file_path)
+        lower_path = file_path.name.lower()
+        if not lower_path.endswith((".h5", ".hdf5")):
+            raise ValueError(
+                "Checkpoint file must be an HDF5 file (.h5 or .hdf5)."
+            )
+
+        def _callback(state: dict) -> None:
+            with AspireFile(file_path, "a") as h5_file:
+                self.save_checkpoint_to_hdf(
+                    state, h5_file, path="checkpoint", dsetname="state"
+                )
+            self.default_checkpoint_callback(state)
+
+        return _callback
+
+    def save_checkpoint_to_hdf(
+        self,
+        state: dict,
+        h5_file,
+        path: str = "sampler_checkpoints",
+        dsetname: str | None = None,
+        protocol: int | None = None,
+    ) -> None:
+        """Save a checkpoint state into an HDF5 file as a pickled blob."""
+        if dsetname is None:
+            iter_str = state.get("iteration", "unknown")
+            dsetname = f"iter_{iter_str}"
+        dump_state(
+            state,
+            h5_file,
+            path=path,
+            dsetname=dsetname,
+            protocol=protocol or pickle.HIGHEST_PROTOCOL,
+        )
+
+    def load_checkpoint_from_file(
+        self,
+        file_path: str | Path,
+        h5_path: str = "checkpoint",
+        dsetname: str = "state",
+    ) -> dict:
+        """Load a checkpoint dictionary from a .pkl or .hdf5 file."""
+        file_path = Path(file_path)
+        lower_path = file_path.name.lower()
+        if lower_path.endswith((".h5", ".hdf5")):
+            with AspireFile(file_path, "r") as h5_file:
+                data = h5_file[h5_path][dsetname][...]
+            checkpoint_bytes = data.tobytes()
+        else:
+            with open(file_path, "rb") as f:
+                checkpoint_bytes = f.read()
+        return pickle.loads(checkpoint_bytes)
+
+    def restore_from_checkpoint(
+        self, source: str | bytes | dict
+    ) -> tuple[Samples, dict]:
+        """Restore sampler state from a checkpoint source."""
+        if isinstance(source, str):
+            state = self.load_checkpoint_from_file(source)
+        elif isinstance(source, bytes):
+            state = pickle.loads(source)
+        elif isinstance(source, dict):
+            state = source
+        else:
+            raise TypeError("Unsupported checkpoint source type.")
+
+        samples_saved = state.get("samples")
+        if samples_saved is None:
+            raise ValueError("Checkpoint missing samples.")
+
+        samples = Samples.from_samples(
+            samples_saved, xp=self.xp, dtype=self.dtype
+        )
+        # Allow subclasses to restore sampler-specific components
+        self._restore_extra_state(state)
+        return samples, state
+
+    @property
+    def last_checkpoint_state(self) -> dict | None:
+        """Return the most recent checkpoint state stored by the default callback."""
+        return self._last_checkpoint_state
+
+    @property
+    def last_checkpoint_bytes(self) -> bytes | None:
+        """Return the most recent pickled checkpoint produced by the default callback."""
+        return self._last_checkpoint_bytes
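A hedged sketch of how these shared helpers fit together in a sampler's loop. Here sampler is any concrete Sampler instance, samples is a Samples object, and "run.h5" is an illustrative path:

    # Build a file-backed callback; writes to checkpoint/state and also
    # caches the state and pickled bytes on the sampler.
    callback = sampler.default_file_checkpoint_callback("run.h5")

    # Inside the sampling loop: snapshot the current state and persist it.
    state = sampler.build_checkpoint_state(samples, iteration=10)
    callback(state)

    # Later: restore from a path, raw bytes, or an in-memory state dict.
    # This also lets subclasses rebuild their extras via _restore_extra_state.
    restored_samples, state = sampler.restore_from_checkpoint("run.h5")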