aspire-inference 0.1.0a10__py3-none-any.whl → 0.1.0a11__py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
aspire/aspire.py CHANGED
@@ -1,5 +1,8 @@
+import copy
 import logging
 import multiprocessing as mp
+import pickle
+from contextlib import contextmanager
 from inspect import signature
 from typing import Any, Callable
 
@@ -8,13 +11,19 @@ import h5py
 from .flows import get_flow_wrapper
 from .flows.base import Flow
 from .history import History
+from .samplers.base import Sampler
 from .samples import Samples
 from .transforms import (
     CompositeTransform,
     FlowPreconditioningTransform,
     FlowTransform,
 )
-from .utils import recursively_save_to_h5_file
+from .utils import (
+    AspireFile,
+    load_from_h5_file,
+    recursively_save_to_h5_file,
+    resolve_xp,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -102,6 +111,7 @@ class Aspire:
         self.dtype = dtype
 
         self._flow = flow
+        self._sampler = None
 
     @property
     def flow(self):
@@ -114,7 +124,7 @@ class Aspire:
         self._flow = flow
 
     @property
-    def sampler(self):
+    def sampler(self) -> Sampler | None:
         """The sampler object."""
         return self._sampler
 
@@ -192,7 +202,29 @@ class Aspire:
             **self.flow_kwargs,
         )
 
-    def fit(self, samples: Samples, **kwargs) -> History:
+    def fit(
+        self,
+        samples: Samples,
+        checkpoint_path: str | None = None,
+        checkpoint_save_config: bool = True,
+        overwrite: bool = False,
+        **kwargs,
+    ) -> History:
+        """Fit the normalizing flow to the provided samples.
+
+        Parameters
+        ----------
+        samples : Samples
+            The samples to fit the flow to.
+        checkpoint_path : str | None
+            Path to save the checkpoint. If None, no checkpoint is saved.
+        checkpoint_save_config : bool
+            Whether to save the Aspire configuration to the checkpoint.
+        overwrite : bool
+            Whether to overwrite an existing flow in the checkpoint file.
+        kwargs : dict
+            Keyword arguments to pass to the flow's fit method.
+        """
         if self.xp is None:
             self.xp = samples.xp
 
@@ -202,6 +234,28 @@ class Aspire:
         self.training_samples = samples
         logger.info(f"Training with {len(samples.x)} samples")
         history = self.flow.fit(samples.x, **kwargs)
+        defaults = getattr(self, "_checkpoint_defaults", None)
+        if checkpoint_path is None and defaults:
+            checkpoint_path = defaults["path"]
+            checkpoint_save_config = defaults["save_config"]
+        saved_config = (
+            defaults.get("saved_config", False) if defaults else False
+        )
+        if checkpoint_path is not None:
+            with AspireFile(checkpoint_path, "a") as h5_file:
+                if checkpoint_save_config and not saved_config:
+                    if "aspire_config" in h5_file:
+                        del h5_file["aspire_config"]
+                    self.save_config(h5_file, include_sampler_config=False)
+                    if defaults is not None:
+                        defaults["saved_config"] = True
+                # Save flow only if missing or overwrite=True
+                if "flow" in h5_file:
+                    if overwrite:
+                        del h5_file["flow"]
+                        self.save_flow(h5_file)
+                else:
+                    self.save_flow(h5_file)
         return history
 
     def get_sampler_class(self, sampler_type: str) -> Callable:
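The new fit() checkpointing can be exercised directly. A minimal sketch, assuming an existing Aspire instance and a Samples object (the instance and variable names are illustrative; only the fit() keywords come from this diff):

    history = aspire.fit(
        training_samples,              # a Samples instance
        checkpoint_path="run.h5",      # config and flow are appended to this file
        checkpoint_save_config=True,   # writes/refreshes the "aspire_config" group
        overwrite=False,               # keep an existing "flow" group if present
    )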
@@ -241,6 +295,13 @@ class Aspire:
         ----------
         sampler_type : str
             The type of sampler to use. Options are 'importance', 'emcee', or 'smc'.
+        preconditioning: str
+            Type of preconditioning to apply in the sampler. Options are
+            'default', 'flow', or 'none'.
+        preconditioning_kwargs: dict
+            Keyword arguments to pass to the preconditioning transform.
+        kwargs : dict
+            Keyword arguments to pass to the sampler.
         """
         SamplerClass = self.get_sampler_class(sampler_type)
 
@@ -304,6 +365,9 @@ class Aspire:
         return_history: bool = False,
         preconditioning: str | None = None,
         preconditioning_kwargs: dict | None = None,
+        checkpoint_path: str | None = None,
+        checkpoint_every: int = 1,
+        checkpoint_save_config: bool = True,
         **kwargs,
     ) -> Samples:
         """Draw samples from the posterior distribution.
@@ -342,6 +406,14 @@ class Aspire:
             will default to 'none' and the other samplers to 'default'
         preconditioning_kwargs: dict
             Keyword arguments to pass to the preconditioning transform.
+        checkpoint_path : str | None
+            Path to save the checkpoint. If None, no checkpoint is saved unless
+            within an :py:meth:`auto_checkpoint` context or a custom callback
+            is provided.
+        checkpoint_every : int
+            Frequency (in number of sampler iterations) to save the checkpoint.
+        checkpoint_save_config : bool
+            Whether to save the Aspire configuration to the checkpoint.
         kwargs : dict
             Keyword arguments to pass to the sampler. These are passed
             automatically to the init method of the sampler or to the sample
@@ -352,6 +424,22 @@ class Aspire:
         samples : Samples
             Samples object contain samples and their corresponding weights.
         """
+        if (
+            sampler == "importance"
+            and hasattr(self, "_resume_sampler_type")
+            and self._resume_sampler_type
+        ):
+            sampler = self._resume_sampler_type
+
+        if "resume_from" not in kwargs and hasattr(
+            self, "_resume_from_default"
+        ):
+            kwargs["resume_from"] = self._resume_from_default
+        if hasattr(self, "_resume_overrides"):
+            kwargs.update(self._resume_overrides)
+        if hasattr(self, "_resume_n_samples") and n_samples == 1000:
+            n_samples = self._resume_n_samples
+
         SamplerClass = self.get_sampler_class(sampler)
         # Determine sampler initialization parameters
         # and remove them from kwargs
@@ -373,7 +461,73 @@ class Aspire:
             preconditioning_kwargs=preconditioning_kwargs,
             **sampler_kwargs,
         )
+        self._last_sampler_type = sampler
+        # Auto-checkpoint convenience: set defaults for checkpointing to a single file
+        defaults = getattr(self, "_checkpoint_defaults", None)
+        if checkpoint_path is None and defaults:
+            checkpoint_path = defaults["path"]
+            checkpoint_every = defaults["every"]
+            checkpoint_save_config = defaults["save_config"]
+        saved_flow = defaults.get("saved_flow", False) if defaults else False
+        saved_config = (
+            defaults.get("saved_config", False) if defaults else False
+        )
+        if checkpoint_path is not None:
+            kwargs.setdefault("checkpoint_file_path", checkpoint_path)
+            kwargs.setdefault("checkpoint_every", checkpoint_every)
+            with AspireFile(checkpoint_path, "a") as h5_file:
+                if checkpoint_save_config:
+                    if "aspire_config" in h5_file:
+                        del h5_file["aspire_config"]
+                    self.save_config(
+                        h5_file,
+                        include_sampler_config=True,
+                        include_sample_calls=False,
+                    )
+                    saved_config = True
+                    if defaults is not None:
+                        defaults["saved_config"] = True
+                if (
+                    self.flow is not None
+                    and not saved_flow
+                    and "flow" not in h5_file
+                ):
+                    self.save_flow(h5_file)
+                    saved_flow = True
+                    if defaults is not None:
+                        defaults["saved_flow"] = True
+
         samples = self._sampler.sample(n_samples, **kwargs)
+        self._last_sample_posterior_kwargs = {
+            "n_samples": n_samples,
+            "sampler": sampler,
+            "xp": xp,
+            "return_history": return_history,
+            "preconditioning": preconditioning,
+            "preconditioning_kwargs": preconditioning_kwargs,
+            "sampler_init_kwargs": sampler_kwargs,
+            "sample_kwargs": copy.deepcopy(kwargs),
+        }
+        if checkpoint_path is not None:
+            with AspireFile(checkpoint_path, "a") as h5_file:
+                if checkpoint_save_config and not saved_config:
+                    if "aspire_config" in h5_file:
+                        del h5_file["aspire_config"]
+                    self.save_config(
+                        h5_file,
+                        include_sampler_config=True,
+                        include_sample_calls=False,
+                    )
+                    if defaults is not None:
+                        defaults["saved_config"] = True
+                if (
+                    self.flow is not None
+                    and not saved_flow
+                    and "flow" not in h5_file
+                ):
+                    self.save_flow(h5_file)
+                    if defaults is not None:
+                        defaults["saved_flow"] = True
         if xp is not None:
             samples = samples.to_namespace(xp)
         samples.parameters = self.parameters
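A minimal sketch of the new sample_posterior checkpoint keywords (only the keyword names come from this diff; the sampler choice and path are illustrative):

    samples = aspire.sample_posterior(
        n_samples=2000,
        sampler="smc",
        checkpoint_path="run.h5",  # config, flow, and sampler state land here
        checkpoint_every=1,        # forwarded to the sampler as checkpoint_every
    )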
@@ -388,6 +542,122 @@ class Aspire:
         else:
             return samples
 
+    @classmethod
+    def resume_from_file(
+        cls,
+        file_path: str,
+        *,
+        log_likelihood: Callable,
+        log_prior: Callable,
+        sampler: str | None = None,
+        checkpoint_path: str = "checkpoint",
+        checkpoint_dset: str = "state",
+        flow_path: str = "flow",
+        config_path: str = "aspire_config",
+        resume_kwargs: dict | None = None,
+    ):
+        """
+        Recreate an Aspire object from a single file and prepare to resume sampling.
+
+        Parameters
+        ----------
+        file_path : str
+            Path to the HDF5 file containing config, flow, and checkpoint.
+        log_likelihood : Callable
+            Log-likelihood function (required, not pickled).
+        log_prior : Callable
+            Log-prior function (required, not pickled).
+        sampler : str
+            Sampler type to use (e.g., 'smc', 'minipcn_smc', 'emcee_smc'). If None,
+            will attempt to infer from saved config or checkpoint metadata.
+        checkpoint_path : str
+            HDF5 group path where the checkpoint is stored.
+        checkpoint_dset : str
+            Dataset name within the checkpoint group.
+        flow_path : str
+            HDF5 path to the saved flow.
+        config_path : str
+            HDF5 path to the saved Aspire config.
+        resume_kwargs : dict | None
+            Optional overrides to apply when resuming (e.g., checkpoint_every).
+        """
+        (
+            aspire,
+            checkpoint_bytes,
+            checkpoint_state,
+            sampler_config,
+            saved_sampler_type,
+            n_samples,
+        ) = cls._build_aspire_from_file(
+            file_path=file_path,
+            log_likelihood=log_likelihood,
+            log_prior=log_prior,
+            checkpoint_path=checkpoint_path,
+            checkpoint_dset=checkpoint_dset,
+            flow_path=flow_path,
+            config_path=config_path,
+        )
+
+        sampler_config = sampler_config or {}
+        sampler_config.pop("sampler_class", None)
+
+        if checkpoint_bytes is not None:
+            aspire._resume_from_default = checkpoint_bytes
+        aspire._resume_sampler_type = (
+            sampler
+            or saved_sampler_type
+            or (
+                checkpoint_state.get("sampler")
+                if checkpoint_state
+                else None
+            )
+        )
+        aspire._resume_n_samples = n_samples
+        aspire._resume_overrides = resume_kwargs or {}
+        aspire._resume_sampler_config = sampler_config
+        aspire._checkpoint_defaults = {
+            "path": file_path,
+            "every": 1,
+            "save_config": False,
+            "save_flow": False,
+            "saved_config": False,
+            "saved_flow": False,
+        }
+        return aspire
+
+    @contextmanager
+    def auto_checkpoint(
+        self,
+        path: str,
+        every: int = 1,
+        save_config: bool = True,
+        save_flow: bool = True,
+    ):
+        """
+        Context manager to auto-save checkpoints, config, and flow to a file.
+
+        Within the context, sample_posterior will default to writing checkpoints
+        to the given path with the specified frequency, and will append config/flow
+        after sampling.
+        """
+        prev = getattr(self, "_checkpoint_defaults", None)
+        self._checkpoint_defaults = {
+            "path": path,
+            "every": every,
+            "save_config": save_config,
+            "save_flow": save_flow,
+            "saved_config": False,
+            "saved_flow": False,
+        }
+        try:
+            yield self
+        finally:
+            if prev is None:
+                if hasattr(self, "_checkpoint_defaults"):
+                    delattr(self, "_checkpoint_defaults")
+            else:
+                self._checkpoint_defaults = prev
+
     def enable_pool(self, pool: mp.Pool, **kwargs):
         """Context manager to temporarily replace the log_likelihood method
         with a version that uses a multiprocessing pool to parallelize
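Together, auto_checkpoint and resume_from_file give a single-file workflow. A minimal sketch, assuming user-supplied my_log_likelihood/my_log_prior callables (they are never pickled, so they must be re-supplied on resume):

    # Fresh run that checkpoints everything into one file:
    with aspire.auto_checkpoint("run.h5", every=2):
        aspire.sample_posterior(n_samples=2000, sampler="smc")

    # After an interruption, rebuild and continue from the same file:
    aspire = Aspire.resume_from_file(
        "run.h5",
        log_likelihood=my_log_likelihood,
        log_prior=my_log_prior,
    )
    samples = aspire.sample_posterior()  # picks up the saved sampler state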
@@ -432,12 +702,16 @@ class Aspire:
             "flow_kwargs": self.flow_kwargs,
             "eps": self.eps,
         }
+        if hasattr(self, "_last_sampler_type"):
+            config["sampler_type"] = self._last_sampler_type
         if include_sampler_config:
+            if self.sampler is None:
+                raise ValueError("Sampler has not been initialized.")
             config["sampler_config"] = self.sampler.config_dict(**kwargs)
         return config
 
     def save_config(
-        self, h5_file: h5py.File, path="aspire_config", **kwargs
+        self, h5_file: h5py.File | AspireFile, path="aspire_config", **kwargs
     ) -> None:
         """Save the configuration to an HDF5 file.
 
@@ -484,6 +758,7 @@ class Aspire:
         FlowClass, xp = get_flow_wrapper(
             backend=self.flow_backend, flow_matching=self.flow_matching
         )
+        logger.debug(f"Loading flow of type {FlowClass} from {path}")
         self._flow = FlowClass.load(h5_file, path=path)
 
     def save_config_to_json(self, filename: str) -> None:
@@ -504,3 +779,80 @@ class Aspire:
         x, log_q = self.flow.sample_and_log_prob(n_samples)
         samples = Samples(x=x, log_q=log_q, xp=xp, parameters=self.parameters)
         return samples
+
+    # --- Resume helpers ---
+    @staticmethod
+    def _build_aspire_from_file(
+        file_path: str,
+        log_likelihood: Callable,
+        log_prior: Callable,
+        checkpoint_path: str,
+        checkpoint_dset: str,
+        flow_path: str,
+        config_path: str,
+    ):
+        """Construct an Aspire instance, load flow, and gather checkpoint metadata from file."""
+        with AspireFile(file_path, "r") as h5_file:
+            if config_path not in h5_file:
+                raise ValueError(
+                    f"Config path '{config_path}' not found in {file_path}"
+                )
+            config_dict = load_from_h5_file(h5_file, config_path)
+            try:
+                checkpoint_bytes = h5_file[checkpoint_path][checkpoint_dset][
+                    ...
+                ].tobytes()
+            except Exception:
+                logger.warning(
+                    "Checkpoint not found at %s/%s in %s; will resume without a checkpoint.",
+                    checkpoint_path,
+                    checkpoint_dset,
+                    file_path,
+                )
+                checkpoint_bytes = None
+
+        sampler_config = config_dict.pop("sampler_config", None)
+        saved_sampler_type = config_dict.pop("sampler_type", None)
+        if isinstance(config_dict.get("xp"), str):
+            config_dict["xp"] = resolve_xp(config_dict["xp"])
+        config_dict["log_likelihood"] = log_likelihood
+        config_dict["log_prior"] = log_prior
+
+        aspire = Aspire(**config_dict)
+
+        with AspireFile(file_path, "r") as h5_file:
+            if flow_path in h5_file:
+                logger.info(f"Loading flow from {flow_path} in {file_path}")
+                aspire.load_flow(h5_file, path=flow_path)
+            else:
+                raise ValueError(
+                    f"Flow path '{flow_path}' not found in {file_path}"
+                )
+
+        n_samples = None
+        checkpoint_state = None
+        if checkpoint_bytes is not None:
+            try:
+                checkpoint_state = pickle.loads(checkpoint_bytes)
+                samples_saved = (
+                    checkpoint_state.get("samples")
+                    if checkpoint_state
+                    else None
+                )
+                if samples_saved is not None:
+                    n_samples = len(samples_saved)
+                    if aspire.xp is None and hasattr(samples_saved, "xp"):
+                        aspire.xp = samples_saved.xp
+            except Exception:
+                logger.warning(
+                    "Failed to decode checkpoint; proceeding without resume state."
+                )
+
+        return (
+            aspire,
+            checkpoint_bytes,
+            checkpoint_state,
+            sampler_config,
+            saved_sampler_type,
+            n_samples,
+        )
aspire/flows/torch/flows.py CHANGED
@@ -92,7 +92,7 @@ class BaseTorchFlow(Flow):
         config = load_from_h5_file(flow_grp, "config")
         config["dtype"] = decode_dtype(torch, config.get("dtype"))
         if "data_transform" in flow_grp:
-            from ..transforms import BaseTransform
+            from ...transforms import BaseTransform
 
             data_transform = BaseTransform.load(
                 flow_grp,
aspire/samplers/base.py CHANGED
@@ -1,10 +1,12 @@
 import logging
+import pickle
+from pathlib import Path
 from typing import Any, Callable
 
 from ..flows.base import Flow
 from ..samples import Samples
 from ..transforms import IdentityTransform
-from ..utils import asarray, track_calls
+from ..utils import AspireFile, asarray, dump_state, track_calls
 
 logger = logging.getLogger(__name__)
 
@@ -49,6 +51,8 @@ class Sampler:
         self.parameters = parameters
         self.history = None
         self.n_likelihood_evaluations = 0
+        self._last_checkpoint_state: dict | None = None
+        self._last_checkpoint_bytes: bytes | None = None
         if preconditioning_transform is None:
             self.preconditioning_transform = IdentityTransform(xp=self.xp)
         else:
@@ -75,7 +79,7 @@ class Sampler:
         self.n_likelihood_evaluations += len(samples)
         return self._log_likelihood(samples)
 
-    def config_dict(self, include_sample_calls: bool = True) -> dict:
+    def config_dict(self, include_sample_calls: bool = False) -> dict:
         """
         Returns a dictionary with the configuration of the sampler.
 
@@ -83,9 +87,9 @@ class Sampler:
         ----------
         include_sample_calls : bool
             Whether to include the sample calls in the configuration.
-            Default is True.
+            Default is False.
         """
-        config = {}
+        config = {"sampler_class": self.__class__.__name__}
         if include_sample_calls:
             if hasattr(self, "sample") and hasattr(self.sample, "calls"):
                 config["sample_calls"] = self.sample.calls.to_dict(
@@ -96,3 +100,139 @@ class Sampler:
                 "Sampler does not have a sample method with calls attribute."
             )
         return config
+
+    # --- Checkpointing helpers shared across samplers ---
+    def _checkpoint_extra_state(self) -> dict:
+        """Sampler-specific extras for checkpointing (override in subclasses)."""
+        return {}
+
+    def _restore_extra_state(self, state: dict) -> None:
+        """Restore sampler-specific extras (override in subclasses)."""
+        _ = state  # no-op for base
+
+    def build_checkpoint_state(
+        self,
+        samples: Samples,
+        iteration: int | None = None,
+        meta: dict | None = None,
+    ) -> dict:
+        """Prepare a serializable checkpoint payload for the sampler state."""
+        checkpoint_samples = samples
+        base_state = {
+            "sampler": self.__class__.__name__,
+            "iteration": iteration,
+            "samples": checkpoint_samples,
+            "config": self.config_dict(include_sample_calls=False),
+            "parameters": self.parameters,
+            "meta": meta or {},
+        }
+        base_state.update(self._checkpoint_extra_state())
+        return base_state
+
+    def serialize_checkpoint(
+        self, state: dict, protocol: int | None = None
+    ) -> bytes:
+        """Serialize a checkpoint state to bytes with pickle."""
+        protocol = (
+            pickle.HIGHEST_PROTOCOL if protocol is None else int(protocol)
+        )
+        return pickle.dumps(state, protocol=protocol)
+
+    def default_checkpoint_callback(self, state: dict) -> None:
+        """Store the latest checkpoint (state + pickled bytes) on the sampler."""
+        self._last_checkpoint_state = state
+        self._last_checkpoint_bytes = self.serialize_checkpoint(state)
+
+    def default_file_checkpoint_callback(
+        self, file_path: str | Path | None
+    ) -> Callable[[dict], None]:
+        """Return a simple default callback that overwrites an HDF5 file."""
+        if file_path is None:
+            return self.default_checkpoint_callback
+        file_path = Path(file_path)
+        lower_path = file_path.name.lower()
+        if not lower_path.endswith((".h5", ".hdf5")):
+            raise ValueError(
+                "Checkpoint file must be an HDF5 file (.h5 or .hdf5)."
+            )
+
+        def _callback(state: dict) -> None:
+            with AspireFile(file_path, "a") as h5_file:
+                self.save_checkpoint_to_hdf(
+                    state, h5_file, path="checkpoint", dsetname="state"
+                )
+            self.default_checkpoint_callback(state)
+
+        return _callback
+
+    def save_checkpoint_to_hdf(
+        self,
+        state: dict,
+        h5_file,
+        path: str = "sampler_checkpoints",
+        dsetname: str | None = None,
+        protocol: int | None = None,
+    ) -> None:
+        """Save a checkpoint state into an HDF5 file as a pickled blob."""
+        if dsetname is None:
+            iter_str = state.get("iteration", "unknown")
+            dsetname = f"iter_{iter_str}"
+        dump_state(
+            state,
+            h5_file,
+            path=path,
+            dsetname=dsetname,
+            protocol=protocol or pickle.HIGHEST_PROTOCOL,
+        )
+
+    def load_checkpoint_from_file(
+        self,
+        file_path: str | Path,
+        h5_path: str = "checkpoint",
+        dsetname: str = "state",
+    ) -> dict:
+        """Load a checkpoint dictionary from .pkl or .hdf5 file."""
+        file_path = Path(file_path)
+        lower_path = file_path.name.lower()
+        if lower_path.endswith((".h5", ".hdf5")):
+            with AspireFile(file_path, "r") as h5_file:
+                data = h5_file[h5_path][dsetname][...]
+            checkpoint_bytes = data.tobytes()
+        else:
+            with open(file_path, "rb") as f:
+                checkpoint_bytes = f.read()
+        return pickle.loads(checkpoint_bytes)
+
+    def restore_from_checkpoint(
+        self, source: str | bytes | dict
+    ) -> tuple[Samples, dict]:
+        """Restore sampler state from a checkpoint source."""
+        if isinstance(source, str):
+            state = self.load_checkpoint_from_file(source)
+        elif isinstance(source, bytes):
+            state = pickle.loads(source)
+        elif isinstance(source, dict):
+            state = source
+        else:
+            raise TypeError("Unsupported checkpoint source type.")
+
+        samples_saved = state.get("samples")
+        if samples_saved is None:
+            raise ValueError("Checkpoint missing samples.")
+
+        samples = Samples.from_samples(
+            samples_saved, xp=self.xp, dtype=self.dtype
+        )
+        # Allow subclasses to restore sampler-specific components
+        self._restore_extra_state(state)
+        return samples, state
+
+    @property
+    def last_checkpoint_state(self) -> dict | None:
+        """Return the most recent checkpoint state stored by the default callback."""
+        return self._last_checkpoint_state
+
+    @property
+    def last_checkpoint_bytes(self) -> bytes | None:
+        """Return the most recent pickled checkpoint produced by the default callback."""
+        return self._last_checkpoint_bytes
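Any callable accepting the state dict can act as a checkpoint callback. A sketch of a custom callback alongside the built-in restore path (file names are illustrative; `sampler` stands for any Sampler subclass instance):

    import pickle

    def my_callback(state: dict) -> None:
        # keep one pickle per iteration instead of overwriting a single file
        with open(f"ckpt_{state['iteration']}.pkl", "wb") as f:
            pickle.dump(state, f, protocol=pickle.HIGHEST_PROTOCOL)

    # restore_from_checkpoint accepts a path, raw bytes, or a state dict:
    # samples, state = sampler.restore_from_checkpoint("ckpt_3.pkl")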
aspire/samplers/smc/base.py CHANGED
@@ -1,3 +1,4 @@
+import copy
 import logging
 from typing import Any, Callable
 
@@ -158,11 +159,24 @@ class SMCSampler(MCMCSampler):
         target_efficiency: float = 0.5,
         target_efficiency_rate: float = 1.0,
         n_final_samples: int | None = None,
+        checkpoint_callback: Callable[[dict], None] | None = None,
+        checkpoint_every: int | None = None,
+        checkpoint_file_path: str | None = None,
+        resume_from: str | bytes | dict | None = None,
     ) -> SMCSamples:
-        samples = self.draw_initial_samples(n_samples)
-        samples = SMCSamples.from_samples(
-            samples, xp=self.xp, beta=0.0, dtype=self.dtype
-        )
+        resumed = resume_from is not None
+        if resumed:
+            samples, beta, iterations = self.restore_from_checkpoint(
+                resume_from
+            )
+        else:
+            samples = self.draw_initial_samples(n_samples)
+            samples = SMCSamples.from_samples(
+                samples, xp=self.xp, beta=0.0, dtype=self.dtype
+            )
+            beta = 0.0
+            iterations = 0
+            self.history = SMCHistory()
         self.fit_preconditioning_transform(samples.x)
 
         if self.xp.isnan(samples.log_q).any():
@@ -178,8 +192,6 @@ class SMCSampler(MCMCSampler):
         self.sampler_kwargs = self.sampler_kwargs or {}
         n_final_steps = self.sampler_kwargs.pop("n_final_steps", None)
 
-        self.history = SMCHistory()
-
         self.target_efficiency = target_efficiency
         self.target_efficiency_rate = target_efficiency_rate
 
@@ -190,7 +202,6 @@
         else:
             beta_step = np.nan
         self.adaptive = adaptive
-        beta = 0.0
 
         if min_step is None:
             if max_n_steps is None:
@@ -202,55 +213,85 @@
         else:
             self.adaptive_min_step = False
 
-        iterations = 0
-        while True:
-            iterations += 1
-
-            beta, min_step = self.determine_beta(
-                samples,
-                beta,
-                beta_step,
-                min_step,
-            )
-            self.history.eff_target.append(
-                self.current_target_efficiency(beta)
-            )
+        iterations = iterations or 0
+        if checkpoint_callback is None and checkpoint_every is not None:
+            checkpoint_callback = self.default_file_checkpoint_callback(
+                checkpoint_file_path
+            )
+        if checkpoint_callback is not None and checkpoint_every is None:
+            checkpoint_every = 1
+
+        run_smc_loop = True
+        if resumed:
+            last_beta = self.history.beta[-1] if self.history.beta else beta
+            if last_beta >= 1.0:
+                run_smc_loop = False
+
+        def maybe_checkpoint(force: bool = False):
+            if checkpoint_callback is None:
+                return
+            should_checkpoint = force or (
+                checkpoint_every is not None
+                and checkpoint_every > 0
+                and iterations % checkpoint_every == 0
+            )
+            if not should_checkpoint:
+                return
+            state = self.build_checkpoint_state(samples, iterations, beta)
+            checkpoint_callback(state)
+
+        if run_smc_loop:
+            while True:
+                iterations += 1
+
+                beta, min_step = self.determine_beta(
+                    samples,
+                    beta,
+                    beta_step,
+                    min_step,
+                )
+                self.history.eff_target.append(
+                    self.current_target_efficiency(beta)
+                )
 
-            logger.info(f"it {iterations} - beta: {beta}")
-            self.history.beta.append(beta)
-
-            ess = effective_sample_size(samples.log_weights(beta))
-            eff = ess / len(samples)
-            if eff < 0.1:
-                logger.warning(
-                    f"it {iterations} - Low sample efficiency: {eff:.2f}"
-                )
-            self.history.ess.append(ess)
-            logger.info(
-                f"it {iterations} - ESS: {ess:.1f} ({eff:.2f} efficiency)"
-            )
-            self.history.ess_target.append(
-                effective_sample_size(samples.log_weights(1.0))
-            )
+                logger.info(f"it {iterations} - beta: {beta}")
+                self.history.beta.append(beta)
+
+                ess = effective_sample_size(samples.log_weights(beta))
+                eff = ess / len(samples)
+                if eff < 0.1:
+                    logger.warning(
+                        f"it {iterations} - Low sample efficiency: {eff:.2f}"
+                    )
+                self.history.ess.append(ess)
+                logger.info(
+                    f"it {iterations} - ESS: {ess:.1f} ({eff:.2f} efficiency)"
+                )
+                self.history.ess_target.append(
+                    effective_sample_size(samples.log_weights(1.0))
+                )
 
-            log_evidence_ratio = samples.log_evidence_ratio(beta)
-            log_evidence_ratio_var = samples.log_evidence_ratio_variance(beta)
-            self.history.log_norm_ratio.append(log_evidence_ratio)
-            self.history.log_norm_ratio_var.append(log_evidence_ratio_var)
-            logger.info(
-                f"it {iterations} - Log evidence ratio: {log_evidence_ratio:.2f} +/- {np.sqrt(log_evidence_ratio_var):.2f}"
-            )
+                log_evidence_ratio = samples.log_evidence_ratio(beta)
+                log_evidence_ratio_var = samples.log_evidence_ratio_variance(
+                    beta
+                )
+                self.history.log_norm_ratio.append(log_evidence_ratio)
+                self.history.log_norm_ratio_var.append(log_evidence_ratio_var)
+                logger.info(
+                    f"it {iterations} - Log evidence ratio: {log_evidence_ratio:.2f} +/- {np.sqrt(log_evidence_ratio_var):.2f}"
+                )
 
-            samples = samples.resample(beta, rng=self.rng)
+                samples = samples.resample(beta, rng=self.rng)
 
-            samples = self.mutate(samples, beta)
-            if beta == 1.0 or (
-                max_n_steps is not None and iterations >= max_n_steps
-            ):
-                break
+                samples = self.mutate(samples, beta)
+                maybe_checkpoint()
+                if beta == 1.0 or (
+                    max_n_steps is not None and iterations >= max_n_steps
+                ):
+                    break
 
-        # If n_final_samples is not None, perform additional mutations steps
-        if n_final_samples is not None:
+        # If n_final_samples is specified and differs, perform additional mutation steps
+        if n_final_samples is not None and len(samples.x) != n_final_samples:
             logger.info(f"Generating {n_final_samples} final samples")
             final_samples = samples.resample(
                 1.0, n_samples=n_final_samples, rng=self.rng
@@ -263,6 +304,7 @@
         samples.log_evidence_error = samples.xp.sqrt(
             samples.xp.sum(asarray(self.history.log_norm_ratio_var, self.xp))
         )
+        maybe_checkpoint(force=True)
 
        final_samples = samples.to_standard_samples()
        logger.info(
@@ -289,6 +331,49 @@
         )
         return log_prob
 
+    def build_checkpoint_state(
+        self, samples: SMCSamples, iteration: int, beta: float
+    ) -> dict:
+        """Prepare a serializable checkpoint payload for the sampler state."""
+        return super().build_checkpoint_state(
+            samples,
+            iteration,
+            meta={"beta": beta},
+        )
+
+    def _checkpoint_extra_state(self) -> dict:
+        history_copy = copy.deepcopy(self.history)
+        rng_state = (
+            self.rng.bit_generator.state
+            if hasattr(self.rng, "bit_generator")
+            else None
+        )
+        return {
+            "history": history_copy,
+            "rng_state": rng_state,
+            "sampler_kwargs": getattr(self, "sampler_kwargs", None),
+        }
+
+    def restore_from_checkpoint(
+        self, source: str | bytes | dict
+    ) -> tuple[SMCSamples, float, int]:
+        samples, state = super().restore_from_checkpoint(source)
+        meta = state.get("meta", {}) if isinstance(state, dict) else {}
+        beta = None
+        if isinstance(meta, dict):
+            beta = meta.get("beta", None)
+        if beta is None:
+            beta = state.get("beta", 0.0)
+        iteration = state.get("iteration", 0)
+        self.history = state.get("history", SMCHistory())
+        rng_state = state.get("rng_state")
+        if rng_state is not None and hasattr(self.rng, "bit_generator"):
+            self.rng.bit_generator.state = rng_state
+        samples = SMCSamples.from_samples(
+            samples, xp=self.xp, beta=beta, dtype=self.dtype
+        )
+        return samples, beta, iteration
+
 
 class NumpySMCSampler(SMCSampler):
     def __init__(
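A sketch of the resulting SMC checkpoint/resume cycle (the keyword names come from this diff; `smc_sampler` and the path are illustrative):

    # Checkpoint to HDF5 every iteration while sampling:
    samples = smc_sampler.sample(
        1000,
        checkpoint_every=1,
        checkpoint_file_path="smc.h5",  # checkpoint/state is overwritten in place
    )

    # After an interruption, restore beta, history, and RNG state and continue:
    samples = smc_sampler.sample(1000, resume_from="smc.h5")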
aspire/samplers/smc/blackjax.py CHANGED
@@ -84,6 +84,10 @@ class BlackJAXSMC(SMCSampler):
         n_final_samples: int | None = None,
         sampler_kwargs: dict | None = None,
         rng_key=None,
+        checkpoint_callback=None,
+        checkpoint_every: int | None = None,
+        checkpoint_file_path: str | None = None,
+        resume_from: str | bytes | dict | None = None,
     ):
         """Sample using BlackJAX SMC.
 
@@ -132,6 +136,10 @@ class BlackJAXSMC(SMCSampler):
             target_efficiency=target_efficiency,
             target_efficiency_rate=target_efficiency_rate,
             n_final_samples=n_final_samples,
+            checkpoint_callback=checkpoint_callback,
+            checkpoint_every=checkpoint_every,
+            checkpoint_file_path=checkpoint_file_path,
+            resume_from=resume_from,
         )
 
     def mutate(self, particles, beta, n_steps=None):
aspire/samplers/smc/emcee.py CHANGED
@@ -21,6 +21,10 @@ class EmceeSMC(NumpySMCSampler):
         target_efficiency_rate: float = 1.0,
         sampler_kwargs: dict | None = None,
         n_final_samples: int | None = None,
+        checkpoint_callback=None,
+        checkpoint_every: int | None = None,
+        checkpoint_file_path: str | None = None,
+        resume_from: str | bytes | dict | None = None,
     ):
         self.sampler_kwargs = sampler_kwargs or {}
         self.sampler_kwargs.setdefault("nsteps", 5 * self.dims)
@@ -33,6 +37,10 @@ class EmceeSMC(NumpySMCSampler):
             target_efficiency=target_efficiency,
             target_efficiency_rate=target_efficiency_rate,
             n_final_samples=n_final_samples,
+            checkpoint_callback=checkpoint_callback,
+            checkpoint_every=checkpoint_every,
+            checkpoint_file_path=checkpoint_file_path,
+            resume_from=resume_from,
         )
 
     def mutate(self, particles, beta, n_steps=None):
aspire/samplers/smc/minipcn.py CHANGED
@@ -8,10 +8,10 @@ from ...utils import (
     determine_backend_name,
     track_calls,
 )
-from .base import NumpySMCSampler
+from .base import SMCSampler
 
 
-class MiniPCNSMC(NumpySMCSampler):
+class MiniPCNSMC(SMCSampler):
     """MiniPCN SMC sampler."""
 
     rng = None
@@ -32,6 +32,10 @@ class MiniPCNSMC:
         n_final_samples: int | None = None,
         sampler_kwargs: dict | None = None,
         rng: np.random.Generator | None = None,
+        checkpoint_callback=None,
+        checkpoint_every: int | None = None,
+        checkpoint_file_path: str | None = None,
+        resume_from: str | bytes | dict | None = None,
     ):
         from orng import ArrayRNG
 
@@ -50,6 +54,10 @@ class MiniPCNSMC:
             n_final_samples=n_final_samples,
             min_step=min_step,
             max_n_steps=max_n_steps,
+            checkpoint_callback=checkpoint_callback,
+            checkpoint_every=checkpoint_every,
+            checkpoint_file_path=checkpoint_file_path,
+            resume_from=resume_from,
         )
 
     def mutate(self, particles, beta, n_steps=None):
aspire/samples.py CHANGED
@@ -9,19 +9,18 @@ from typing import Any, Callable
 import numpy as np
 from array_api_compat import (
     array_namespace,
-    is_numpy_namespace,
-    to_device,
 )
-from array_api_compat import device as api_device
 from array_api_compat.common._typing import Array
 from matplotlib.figure import Figure
 
 from .utils import (
     asarray,
     convert_dtype,
+    infer_device,
     logsumexp,
     recursively_save_to_h5_file,
     resolve_dtype,
+    safe_to_device,
     to_numpy,
 )
 
@@ -67,8 +66,6 @@ class BaseSamples:
         if self.xp is None:
             self.xp = array_namespace(self.x)
         # Numpy arrays need to be on the CPU before being converted
-        if is_numpy_namespace(self.xp):
-            self.device = "cpu"
         if self.dtype is not None:
             self.dtype = resolve_dtype(self.dtype, self.xp)
         else:
@@ -76,7 +73,7 @@ class BaseSamples:
             self.dtype = None
         self.x = self.array_to_namespace(self.x, dtype=self.dtype)
         if self.device is None:
-            self.device = api_device(self.x)
+            self.device = infer_device(self.x, self.xp)
         if self.log_likelihood is not None:
             self.log_likelihood = self.array_to_namespace(
                 self.log_likelihood, dtype=self.dtype
@@ -140,8 +137,7 @@ class BaseSamples:
         else:
             kwargs["dtype"] = self.dtype
         x = asarray(x, self.xp, **kwargs)
-        if self.device:
-            x = to_device(x, self.device)
+        x = safe_to_device(x, self.device, self.xp)
         return x
 
     def to_dict(self, flat: bool = True):
@@ -174,7 +170,6 @@ class BaseSamples:
         ----------
         parameters : list[str] | None
             List of parameters to plot. If None, all parameters are plotted.
-        fig : matplotlib.figure.Figure | None
             Figure to plot on. If None, a new figure is created.
         **kwargs : dict
             Additional keyword arguments to pass to corner.corner(). Kwargs
@@ -300,6 +295,13 @@ class BaseSamples:
     def __setstate__(self, state):
         # Restore xp by checking the namespace of x
         state["xp"] = array_namespace(state["x"])
+        # device may be string; leave as-is or None
+        device = state.get("device")
+        if device is not None and "jax" in getattr(
+            state["xp"], "__name__", ""
+        ):
+            device = None
+        state["device"] = device
         self.__dict__.update(state)
 
 
aspire/utils.py CHANGED
@@ -2,9 +2,11 @@ from __future__ import annotations
 
 import inspect
 import logging
+import pickle
 from contextlib import contextmanager
 from dataclasses import dataclass
 from functools import partial
+from io import BytesIO
 from typing import TYPE_CHECKING, Any
 
 import array_api_compat.numpy as np
@@ -601,6 +603,7 @@ def decode_from_hdf5(value: Any) -> Any:
             return None
         if value == "__empty_dict__":
             return {}
+        return value
 
     if isinstance(value, np.ndarray):
         # Try to collapse 0-D arrays into scalars
@@ -629,6 +632,101 @@ def decode_from_hdf5(value: Any) -> Any:
     return value
 
 
+def dump_pickle_to_hdf(memfp, fp, path=None, dsetname="state"):
+    """Dump pickled data to an HDF5 file object."""
+    memfp.seek(0)
+    bdata = np.frombuffer(memfp.read(), dtype="S1")
+    target = fp.require_group(path) if path is not None else fp
+    if dsetname not in target:
+        target.create_dataset(
+            dsetname, shape=bdata.shape, maxshape=(None,), dtype=bdata.dtype
+        )
+    elif bdata.size != target[dsetname].shape[0]:
+        target[dsetname].resize((bdata.size,))
+    target[dsetname][:] = bdata
+
+
+def dump_state(
+    state,
+    fp,
+    path=None,
+    dsetname="state",
+    protocol=pickle.HIGHEST_PROTOCOL,
+):
+    """Pickle a state object and store it in an HDF5 dataset."""
+    memfp = BytesIO()
+    pickle.dump(state, memfp, protocol=protocol)
+    dump_pickle_to_hdf(memfp, fp, path=path, dsetname=dsetname)
+
+
+def resolve_xp(xp_name: str | None):
+    """
+    Resolve a backend name to the corresponding array_api_compat module.
+
+    Returns None if the name is None or cannot be resolved.
+    """
+    if xp_name is None:
+        return None
+    name = xp_name.lower()
+    if name.startswith("array_api_compat."):
+        name = name.removeprefix("array_api_compat.")
+    try:
+        if name in {"numpy", "numpy.ndarray"}:
+            import array_api_compat.numpy as np_xp
+
+            return np_xp
+        if name in {"jax", "jax.numpy"}:
+            import jax.numpy as jnp
+
+            return jnp
+        if name in {"torch"}:
+            import array_api_compat.torch as torch_xp
+
+            return torch_xp
+    except Exception:
+        logger.warning(
+            "Failed to resolve xp '%s', defaulting to None", xp_name
+        )
+    return None
+
+
+def infer_device(x, xp):
+    """
+    Best-effort device inference that avoids non-portable identifiers.
+
+    Returns None for numpy/jax backends; returns the backend device object
+    for torch/cupy if available.
+    """
+    if xp is None or is_numpy_namespace(xp) or is_jax_namespace(xp):
+        return None
+    try:
+        from array_api_compat import device
+
+        return device(x)
+    except Exception:
+        return None
+
+
+def safe_to_device(x, device, xp):
+    """
+    Move to device if specified; otherwise return input.
+
+    Skips moves for numpy/jax/None devices; logs and returns input on failure.
+    """
+    if device is None:
+        return x
+    if xp is None or is_numpy_namespace(xp) or is_jax_namespace(xp):
+        return x
+    try:
+        return to_device(x, device)
+    except Exception:
+        logger.warning(
+            "Failed to move array to device %s; leaving on current device",
+            device,
+        )
+        return x
+
+
 def recursively_save_to_h5_file(h5_file, path, dictionary):
     """Save a dictionary to an HDF5 file with flattened keys under a given group path."""
     # Ensure the group exists (or open it if already present)
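A sketch of the dump_state round trip through an HDF5 dataset of pickled bytes (requires h5py; the file name is illustrative):

    import pickle

    import h5py

    from aspire.utils import dump_state

    with h5py.File("state.h5", "a") as f:
        dump_state({"iteration": 3}, f, path="checkpoint", dsetname="state")
    with h5py.File("state.h5", "r") as f:
        state = pickle.loads(f["checkpoint"]["state"][...].tobytes())
    assert state["iteration"] == 3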
aspire_inference-0.1.0a10.dist-info/METADATA → aspire_inference-0.1.0a11.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: aspire-inference
-Version: 0.1.0a10
+Version: 0.1.0a11
 Summary: Accelerate Sequential Posterior Inference via REuse
 Author-email: "Michael J. Williams" <michaeljw1@googlemail.com>
 License: MIT
aspire_inference-0.1.0a10.dist-info/RECORD → aspire_inference-0.1.0a11.dist-info/RECORD RENAMED
@@ -1,28 +1,28 @@
 aspire/__init__.py,sha256=B2IETvlpB0oBh57prRYLdi8jB5yFGw8qVviGdf1NcnE,409
-aspire/aspire.py,sha256=M5o-QxLthE_5daa1trgUfWxPz-g4rmpEUKimKosw4lw,17400
+aspire/aspire.py,sha256=lr0bD5GDWdlAGfODGzj4BoELKUF5HYAMb8yYGvRR_y0,30860
 aspire/history.py,sha256=l_j-riZKbTWK7Wz9zvvD_mTk9psNCKItiveYhr_pYv8,4313
 aspire/plot.py,sha256=oXwUDOb_953_ADm2KLk41JIfpE3JeiiQiSYKvUVwLqw,1423
-aspire/samples.py,sha256=z5x5hpWuVFH1hYhltmROAe8pbWxGD2UvHi3vcc132dg,19399
+aspire/samples.py,sha256=v7y8DkirUCHOJbCE-o9y2K7xzU2HicIo_O0CdFhLgXE,19478
 aspire/transforms.py,sha256=CHrfPQHEyHQ9I0WWiAgUplWwxYypZsD7uCYIHUbSFtY,24974
-aspire/utils.py,sha256=eKGmchpuoNL15Xbu-AGoeZ00PcQEykEQiDZMnnRyV6A,24234
+aspire/utils.py,sha256=87avRkTce9QvELpIcqlKauSoUZSYi1fqe1asC97TzqA,26947
 aspire/flows/__init__.py,sha256=GUZToPVNJoTwULpbeW10UijfQukNrILoAQ_ubeq7G3w,2110
 aspire/flows/base.py,sha256=5UWKAiXDXLJ6Sg6a380ajLrGFaZSQyOnFEihQiiA4ko,2237
 aspire/flows/jax/__init__.py,sha256=7cmiY_MbEC8RDA8Cmi8HVnNJm0sqFKlBsDethdsy5lA,52
 aspire/flows/jax/flows.py,sha256=1HnVgQ1GUXNcvxiZqEV19H2QI9Th5bWX_QbNfGaUhuA,6625
 aspire/flows/jax/utils.py,sha256=5T6UrgpARG9VywC9qmTl45LjyZWuEdkW3XUladE6xJE,1518
 aspire/flows/torch/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-aspire/flows/torch/flows.py,sha256=0_YkiMT49QolyQnEFsh28tfKLnURVF0Z6aTnaWLIUDI,11672
+aspire/flows/torch/flows.py,sha256=QcQOcFZEsLWHPwbQUFGOFdfEslyc59Vf_UEsS0xAGPo,11673
 aspire/samplers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-aspire/samplers/base.py,sha256=VEHawyVA33jXHMo63p5hBHkp9k2qxU_bOxh5iaZSXew,3011
+aspire/samplers/base.py,sha256=ygqrvqSedWSb0cz8DQ_MHokOOxi6aBRdHxf_qoEPwUE,8243
 aspire/samplers/importance.py,sha256=opn_jY-V8snUz0JztLBtnaTT3WfrZ5OSpHBV5WAuM3M,676
 aspire/samplers/mcmc.py,sha256=ihHgzqvSoy1oxdFBjyqNUbCuRX7CqWjlshCUZcgEL5E,5151
 aspire/samplers/smc/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-aspire/samplers/smc/base.py,sha256=66f_ORUvcKRqMIW35qjhUc-c0PFuY87lJa91MpSaTZI,10729
-aspire/samplers/smc/blackjax.py,sha256=IcTguAETiPmgFofmVW2GN40P5HBIxkmyd2VR8AU8f4k,12115
-aspire/samplers/smc/emcee.py,sha256=Wm0vvAlCcRhJMBt7_fU2ZnjDb8SN8jgUOTXLzNstRpA,2516
-aspire/samplers/smc/minipcn.py,sha256=iQUBBwHZ_D4CxNjARMngklRvx6yTlEDKdeidyYCgqM4,3003
-aspire_inference-0.1.0a10.dist-info/licenses/LICENSE,sha256=DN-eRtBfS9dZyT0Ds0Mdn2Y4nb-ZQ7h71vpASYBm5k4,1076
-aspire_inference-0.1.0a10.dist-info/METADATA,sha256=SdvlXKjQn0uJQPpWzoZKAH3oMJSKpZnvlUzxPsIwNlY,3869
-aspire_inference-0.1.0a10.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-aspire_inference-0.1.0a10.dist-info/top_level.txt,sha256=9FRIYEl2xwVFG7jSOBHsElHQ0y3_4fq01Cf4_OyMQn8,7
-aspire_inference-0.1.0a10.dist-info/RECORD,,
+aspire/samplers/smc/base.py,sha256=40A9yVuKS1F8cPzbfVQ9rNk3y07mnkfbuyRDIh_fy5A,14122
+aspire/samplers/smc/blackjax.py,sha256=2riWDSRmpL5lGmnhNtdieiRs0oYC6XZA2X-nVlQaqpE,12490
+aspire/samplers/smc/emcee.py,sha256=4CI9GvH69FCoLiFBbKKYwYocYyiM95IijC5EvrcAmUo,2891
+aspire/samplers/smc/minipcn.py,sha256=IJ5466VvARd4qZCWXXl-l3BPaKW1AgcwmbP3ISL2bto,3368
+aspire_inference-0.1.0a11.dist-info/licenses/LICENSE,sha256=DN-eRtBfS9dZyT0Ds0Mdn2Y4nb-ZQ7h71vpASYBm5k4,1076
+aspire_inference-0.1.0a11.dist-info/METADATA,sha256=lLd2d5HR-t942wKLyYbdJ1DL9CKl7tkpBul_vX8DU4M,3869
+aspire_inference-0.1.0a11.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+aspire_inference-0.1.0a11.dist-info/top_level.txt,sha256=9FRIYEl2xwVFG7jSOBHsElHQ0y3_4fq01Cf4_OyMQn8,7
+aspire_inference-0.1.0a11.dist-info/RECORD,,