aspire-inference 0.1.0a10__py3-none-any.whl → 0.1.0a12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aspire/aspire.py CHANGED
@@ -1,5 +1,8 @@
1
+ import copy
1
2
  import logging
2
3
  import multiprocessing as mp
4
+ import pickle
5
+ from contextlib import contextmanager
3
6
  from inspect import signature
4
7
  from typing import Any, Callable
5
8
 
@@ -8,13 +11,20 @@ import h5py
8
11
  from .flows import get_flow_wrapper
9
12
  from .flows.base import Flow
10
13
  from .history import History
14
+ from .samplers.base import Sampler
11
15
  from .samples import Samples
12
16
  from .transforms import (
13
17
  CompositeTransform,
14
18
  FlowPreconditioningTransform,
15
19
  FlowTransform,
16
20
  )
17
- from .utils import recursively_save_to_h5_file
21
+ from .utils import (
22
+ AspireFile,
23
+ function_id,
24
+ load_from_h5_file,
25
+ recursively_save_to_h5_file,
26
+ resolve_xp,
27
+ )
18
28
 
19
29
  logger = logging.getLogger(__name__)
20
30
 
@@ -102,6 +112,7 @@ class Aspire:
102
112
  self.dtype = dtype
103
113
 
104
114
  self._flow = flow
115
+ self._sampler = None
105
116
 
106
117
  @property
107
118
  def flow(self):
@@ -114,7 +125,7 @@ class Aspire:
114
125
  self._flow = flow
115
126
 
116
127
  @property
117
- def sampler(self):
128
+ def sampler(self) -> Sampler | None:
118
129
  """The sampler object."""
119
130
  return self._sampler
120
131
 
@@ -192,7 +203,29 @@ class Aspire:
192
203
  **self.flow_kwargs,
193
204
  )
194
205
 
195
- def fit(self, samples: Samples, **kwargs) -> History:
206
+ def fit(
207
+ self,
208
+ samples: Samples,
209
+ checkpoint_path: str | None = None,
210
+ checkpoint_save_config: bool = True,
211
+ overwrite: bool = False,
212
+ **kwargs,
213
+ ) -> History:
214
+ """Fit the normalizing flow to the provided samples.
215
+
216
+ Parameters
217
+ ----------
218
+ samples : Samples
219
+ The samples to fit the flow to.
220
+ checkpoint_path : str | None
221
+ Path to save the checkpoint. If None, no checkpoint is saved.
222
+ checkpoint_save_config : bool
223
+ Whether to save the Aspire configuration to the checkpoint.
224
+ overwrite : bool
225
+ Whether to overwrite an existing flow in the checkpoint file.
226
+ kwargs : dict
227
+ Keyword arguments to pass to the flow's fit method.
228
+ """
196
229
  if self.xp is None:
197
230
  self.xp = samples.xp
198
231
 
@@ -202,6 +235,28 @@ class Aspire:
202
235
  self.training_samples = samples
203
236
  logger.info(f"Training with {len(samples.x)} samples")
204
237
  history = self.flow.fit(samples.x, **kwargs)
238
+ defaults = getattr(self, "_checkpoint_defaults", None)
239
+ if checkpoint_path is None and defaults:
240
+ checkpoint_path = defaults["path"]
241
+ checkpoint_save_config = defaults["save_config"]
242
+ saved_config = (
243
+ defaults.get("saved_config", False) if defaults else False
244
+ )
245
+ if checkpoint_path is not None:
246
+ with AspireFile(checkpoint_path, "a") as h5_file:
247
+ if checkpoint_save_config and not saved_config:
248
+ if "aspire_config" in h5_file:
249
+ del h5_file["aspire_config"]
250
+ self.save_config(h5_file, include_sampler_config=False)
251
+ if defaults is not None:
252
+ defaults["saved_config"] = True
253
+ # Save flow only if missing or overwrite=True
254
+ if "flow" in h5_file:
255
+ if overwrite:
256
+ del h5_file["flow"]
257
+ self.save_flow(h5_file)
258
+ else:
259
+ self.save_flow(h5_file)
205
260
  return history
206
261
 
207
262
  def get_sampler_class(self, sampler_type: str) -> Callable:
@@ -241,6 +296,13 @@ class Aspire:
241
296
  ----------
242
297
  sampler_type : str
243
298
  The type of sampler to use. Options are 'importance', 'emcee', or 'smc'.
299
+ preconditioning: str
300
+ Type of preconditioning to apply in the sampler. Options are
301
+ 'default', 'flow', or 'none'.
302
+ preconditioning_kwargs: dict
303
+ Keyword arguments to pass to the preconditioning transform.
304
+ kwargs : dict
305
+ Keyword arguments to pass to the sampler.
244
306
  """
245
307
  SamplerClass = self.get_sampler_class(sampler_type)
246
308
 
@@ -304,6 +366,9 @@ class Aspire:
304
366
  return_history: bool = False,
305
367
  preconditioning: str | None = None,
306
368
  preconditioning_kwargs: dict | None = None,
369
+ checkpoint_path: str | None = None,
370
+ checkpoint_every: int = 1,
371
+ checkpoint_save_config: bool = True,
307
372
  **kwargs,
308
373
  ) -> Samples:
309
374
  """Draw samples from the posterior distribution.
@@ -342,6 +407,14 @@ class Aspire:
342
407
  will default to 'none' and the other samplers to 'default'
343
408
  preconditioning_kwargs: dict
344
409
  Keyword arguments to pass to the preconditioning transform.
410
+ checkpoint_path : str | None
411
+ Path to save the checkpoint. If None, no checkpoint is saved unless
412
+ within an :py:meth:`auto_checkpoint` context or a custom callback
413
+ is provided.
414
+ checkpoint_every : int
415
+ Frequency (in number of sampler iterations) to save the checkpoint.
416
+ checkpoint_save_config : bool
417
+ Whether to save the Aspire configuration to the checkpoint.
345
418
  kwargs : dict
346
419
  Keyword arguments to pass to the sampler. These are passed
347
420
  automatically to the init method of the sampler or to the sample
@@ -352,6 +425,22 @@ class Aspire:
352
425
  samples : Samples
353
426
  Samples object contain samples and their corresponding weights.
354
427
  """
428
+ if (
429
+ sampler == "importance"
430
+ and hasattr(self, "_resume_sampler_type")
431
+ and self._resume_sampler_type
432
+ ):
433
+ sampler = self._resume_sampler_type
434
+
435
+ if "resume_from" not in kwargs and hasattr(
436
+ self, "_resume_from_default"
437
+ ):
438
+ kwargs["resume_from"] = self._resume_from_default
439
+ if hasattr(self, "_resume_overrides"):
440
+ kwargs.update(self._resume_overrides)
441
+ if hasattr(self, "_resume_n_samples") and n_samples == 1000:
442
+ n_samples = self._resume_n_samples
443
+
355
444
  SamplerClass = self.get_sampler_class(sampler)
356
445
  # Determine sampler initialization parameters
357
446
  # and remove them from kwargs
@@ -373,7 +462,73 @@ class Aspire:
373
462
  preconditioning_kwargs=preconditioning_kwargs,
374
463
  **sampler_kwargs,
375
464
  )
465
+ self._last_sampler_type = sampler
466
+ # Auto-checkpoint convenience: set defaults for checkpointing to a single file
467
+ defaults = getattr(self, "_checkpoint_defaults", None)
468
+ if checkpoint_path is None and defaults:
469
+ checkpoint_path = defaults["path"]
470
+ checkpoint_every = defaults["every"]
471
+ checkpoint_save_config = defaults["save_config"]
472
+ saved_flow = defaults.get("saved_flow", False) if defaults else False
473
+ saved_config = (
474
+ defaults.get("saved_config", False) if defaults else False
475
+ )
476
+ if checkpoint_path is not None:
477
+ kwargs.setdefault("checkpoint_file_path", checkpoint_path)
478
+ kwargs.setdefault("checkpoint_every", checkpoint_every)
479
+ with AspireFile(checkpoint_path, "a") as h5_file:
480
+ if checkpoint_save_config:
481
+ if "aspire_config" in h5_file:
482
+ del h5_file["aspire_config"]
483
+ self.save_config(
484
+ h5_file,
485
+ include_sampler_config=True,
486
+ include_sample_calls=False,
487
+ )
488
+ saved_config = True
489
+ if defaults is not None:
490
+ defaults["saved_config"] = True
491
+ if (
492
+ self.flow is not None
493
+ and not saved_flow
494
+ and "flow" not in h5_file
495
+ ):
496
+ self.save_flow(h5_file)
497
+ saved_flow = True
498
+ if defaults is not None:
499
+ defaults["saved_flow"] = True
500
+
376
501
  samples = self._sampler.sample(n_samples, **kwargs)
502
+ self._last_sample_posterior_kwargs = {
503
+ "n_samples": n_samples,
504
+ "sampler": sampler,
505
+ "xp": xp,
506
+ "return_history": return_history,
507
+ "preconditioning": preconditioning,
508
+ "preconditioning_kwargs": preconditioning_kwargs,
509
+ "sampler_init_kwargs": sampler_kwargs,
510
+ "sample_kwargs": copy.deepcopy(kwargs),
511
+ }
512
+ if checkpoint_path is not None:
513
+ with AspireFile(checkpoint_path, "a") as h5_file:
514
+ if checkpoint_save_config and not saved_config:
515
+ if "aspire_config" in h5_file:
516
+ del h5_file["aspire_config"]
517
+ self.save_config(
518
+ h5_file,
519
+ include_sampler_config=True,
520
+ include_sample_calls=False,
521
+ )
522
+ if defaults is not None:
523
+ defaults["saved_config"] = True
524
+ if (
525
+ self.flow is not None
526
+ and not saved_flow
527
+ and "flow" not in h5_file
528
+ ):
529
+ self.save_flow(h5_file)
530
+ if defaults is not None:
531
+ defaults["saved_flow"] = True
377
532
  if xp is not None:
378
533
  samples = samples.to_namespace(xp)
379
534
  samples.parameters = self.parameters
@@ -388,6 +543,122 @@ class Aspire:
388
543
  else:
389
544
  return samples
390
545
 
546
+ @classmethod
547
+ def resume_from_file(
548
+ cls,
549
+ file_path: str,
550
+ *,
551
+ log_likelihood: Callable,
552
+ log_prior: Callable,
553
+ sampler: str | None = None,
554
+ checkpoint_path: str = "checkpoint",
555
+ checkpoint_dset: str = "state",
556
+ flow_path: str = "flow",
557
+ config_path: str = "aspire_config",
558
+ resume_kwargs: dict | None = None,
559
+ ):
560
+ """
561
+ Recreate an Aspire object from a single file and prepare to resume sampling.
562
+
563
+ Parameters
564
+ ----------
565
+ file_path : str
566
+ Path to the HDF5 file containing config, flow, and checkpoint.
567
+ log_likelihood : Callable
568
+ Log-likelihood function (required, not pickled).
569
+ log_prior : Callable
570
+ Log-prior function (required, not pickled).
571
+ sampler : str
572
+ Sampler type to use (e.g., 'smc', 'minipcn_smc', 'emcee_smc'). If None,
573
+ will attempt to infer from saved config or checkpoint metadata.
574
+ checkpoint_path : str
575
+ HDF5 group path where the checkpoint is stored.
576
+ checkpoint_dset : str
577
+ Dataset name within the checkpoint group.
578
+ flow_path : str
579
+ HDF5 path to the saved flow.
580
+ config_path : str
581
+ HDF5 path to the saved Aspire config.
582
+ resume_kwargs : dict | None
583
+ Optional overrides to apply when resuming (e.g., checkpoint_every).
584
+ """
585
# File reading, Aspire construction and flow loading are delegated to
# cls._build_aspire_from_file; this method only attaches the resume
# attributes to the instance it returns.
+ (
586
+ aspire,
587
+ checkpoint_bytes,
588
+ checkpoint_state,
589
+ sampler_config,
590
+ saved_sampler_type,
591
+ n_samples,
592
+ ) = cls._build_aspire_from_file(
593
+ file_path=file_path,
594
+ log_likelihood=log_likelihood,
595
+ log_prior=log_prior,
596
+ checkpoint_path=checkpoint_path,
597
+ checkpoint_dset=checkpoint_dset,
598
+ flow_path=flow_path,
599
+ config_path=config_path,
600
+ )
601
+
602
+ sampler_config = sampler_config or {}
603
# Remove the "sampler_class" entry from the saved sampler config, if present.
+ sampler_config.pop("sampler_class", None)
604
+
605
# NOTE(review): this listing has lost its indentation, so which of the
# assignments below are guarded by this `if` cannot be confirmed from here.
+ if checkpoint_bytes is not None:
606
+ aspire._resume_from_default = checkpoint_bytes
607
# Precedence: explicit `sampler` argument, then the sampler type saved in the
# config, then the "sampler" entry of the decoded checkpoint state (if any).
+ aspire._resume_sampler_type = (
608
+ sampler
609
+ or saved_sampler_type
610
+ or (
611
+ checkpoint_state.get("sampler")
612
+ if checkpoint_state
613
+ else None
614
+ )
615
+ )
616
# NOTE(review): _build_aspire_from_file initialises n_samples to None and only
# sets it when the decoded state contains samples; sample_posterior substitutes
# this attribute when n_samples == 1000 -- confirm a None value is acceptable.
+ aspire._resume_n_samples = n_samples
617
# NOTE(review): the caller's resume_kwargs dict is stored by reference, not copied.
+ aspire._resume_overrides = resume_kwargs or {}
618
+ aspire._resume_sampler_config = sampler_config
619
# Checkpoint defaults point at the file that was just read; the config/flow
# save flags are all set to False here.
+ aspire._checkpoint_defaults = {
620
+ "path": file_path,
621
+ "every": 1,
622
+ "save_config": False,
623
+ "save_flow": False,
624
+ "saved_config": False,
625
+ "saved_flow": False,
626
+ }
627
+ return aspire
628
+
629
@contextmanager
def auto_checkpoint(
    self,
    path: str,
    every: int = 1,
    save_config: bool = True,
    save_flow: bool = True,
):
    """
    Temporarily install checkpoint defaults on this object.

    While the context is active ``self._checkpoint_defaults`` holds the
    given path, frequency and save flags (with the ``saved_config`` and
    ``saved_flow`` markers reset to False). On exit the previous defaults
    are put back; if there were none, the attribute is removed.

    Parameters
    ----------
    path : str
        File the checkpoints, config and flow should be written to.
    every : int
        Checkpoint frequency stored under the ``"every"`` key.
    save_config : bool
        Value stored under the ``"save_config"`` key.
    save_flow : bool
        Value stored under the ``"save_flow"`` key.

    Yields
    ------
    The object itself.
    """
    previous_defaults = getattr(self, "_checkpoint_defaults", None)
    self._checkpoint_defaults = dict(
        path=path,
        every=every,
        save_config=save_config,
        save_flow=save_flow,
        saved_config=False,
        saved_flow=False,
    )
    try:
        yield self
    finally:
        if previous_defaults is not None:
            # Nested/previous defaults existed: restore them untouched.
            self._checkpoint_defaults = previous_defaults
        elif hasattr(self, "_checkpoint_defaults"):
            # Nothing to restore, so leave the object without the attribute.
            del self._checkpoint_defaults
661
+
391
662
  def enable_pool(self, pool: mp.Pool, **kwargs):
392
663
  """Context manager to temporarily replace the log_likelihood method
393
664
  with a version that uses a multiprocessing pool to parallelize
@@ -417,8 +688,8 @@ class Aspire:
417
688
  method of the sampler.
418
689
  """
419
690
  config = {
420
- "log_likelihood": self.log_likelihood.__name__,
421
- "log_prior": self.log_prior.__name__,
691
+ "log_likelihood": function_id(self.log_likelihood),
692
+ "log_prior": function_id(self.log_prior),
422
693
  "dims": self.dims,
423
694
  "parameters": self.parameters,
424
695
  "periodic_parameters": self.periodic_parameters,
@@ -432,12 +703,16 @@ class Aspire:
432
703
  "flow_kwargs": self.flow_kwargs,
433
704
  "eps": self.eps,
434
705
  }
706
+ if hasattr(self, "_last_sampler_type"):
707
+ config["sampler_type"] = self._last_sampler_type
435
708
  if include_sampler_config:
709
+ if self.sampler is None:
710
+ raise ValueError("Sampler has not been initialized.")
436
711
  config["sampler_config"] = self.sampler.config_dict(**kwargs)
437
712
  return config
438
713
 
439
714
  def save_config(
440
- self, h5_file: h5py.File, path="aspire_config", **kwargs
715
+ self, h5_file: h5py.File | AspireFile, path="aspire_config", **kwargs
441
716
  ) -> None:
442
717
  """Save the configuration to an HDF5 file.
443
718
 
@@ -484,6 +759,7 @@ class Aspire:
484
759
  FlowClass, xp = get_flow_wrapper(
485
760
  backend=self.flow_backend, flow_matching=self.flow_matching
486
761
  )
762
+ logger.debug(f"Loading flow of type {FlowClass} from {path}")
487
763
  self._flow = FlowClass.load(h5_file, path=path)
488
764
 
489
765
  def save_config_to_json(self, filename: str) -> None:
@@ -504,3 +780,80 @@ class Aspire:
504
780
  x, log_q = self.flow.sample_and_log_prob(n_samples)
505
781
  samples = Samples(x=x, log_q=log_q, xp=xp, parameters=self.parameters)
506
782
  return samples
783
+
784
+ # --- Resume helpers ---
785
@staticmethod
def _build_aspire_from_file(
    file_path: str,
    log_likelihood: Callable,
    log_prior: Callable,
    checkpoint_path: str,
    checkpoint_dset: str,
    flow_path: str,
    config_path: str,
):
    """Construct an Aspire instance, load its flow, and gather checkpoint metadata.

    Parameters
    ----------
    file_path : str
        File holding the saved config, flow and (optionally) a checkpoint.
    log_likelihood, log_prior : Callable
        Functions injected into the rebuilt configuration.
    checkpoint_path, checkpoint_dset : str
        Group and dataset names of the pickled checkpoint inside the file.
    flow_path : str
        Path of the saved flow inside the file.
    config_path : str
        Path of the saved Aspire configuration inside the file.

    Returns
    -------
    tuple
        ``(aspire, checkpoint_bytes, checkpoint_state, sampler_config,
        saved_sampler_type, n_samples)``. ``checkpoint_bytes`` and
        ``checkpoint_state`` are both None when no usable checkpoint was
        found, so callers never receive bytes that cannot be decoded.

    Raises
    ------
    ValueError
        If the config or the flow is missing from the file.
    """
    with AspireFile(file_path, "r") as h5_file:
        if config_path not in h5_file:
            raise ValueError(
                f"Config path '{config_path}' not found in {file_path}"
            )
        config_dict = load_from_h5_file(h5_file, config_path)
        try:
            checkpoint_bytes = h5_file[checkpoint_path][checkpoint_dset][
                ...
            ].tobytes()
        except Exception:
            # Best effort: a file without a checkpoint is a valid input.
            logger.warning(
                "Checkpoint not found at %s/%s in %s; will resume without a checkpoint.",
                checkpoint_path,
                checkpoint_dset,
                file_path,
            )
            logger.debug("Checkpoint read failure details", exc_info=True)
            checkpoint_bytes = None

    sampler_config = config_dict.pop("sampler_config", None)
    saved_sampler_type = config_dict.pop("sampler_type", None)
    if isinstance(config_dict.get("xp"), str):
        config_dict["xp"] = resolve_xp(config_dict["xp"])
    config_dict["log_likelihood"] = log_likelihood
    config_dict["log_prior"] = log_prior

    aspire = Aspire(**config_dict)

    with AspireFile(file_path, "r") as h5_file:
        if flow_path not in h5_file:
            raise ValueError(
                f"Flow path '{flow_path}' not found in {file_path}"
            )
        logger.info(f"Loading flow from {flow_path} in {file_path}")
        aspire.load_flow(h5_file, path=flow_path)

    n_samples = None
    checkpoint_state = None
    if checkpoint_bytes is not None:
        try:
            # NOTE(security): pickle can execute arbitrary code while
            # loading; only resume from checkpoint files you trust.
            decoded = pickle.loads(checkpoint_bytes)
            if not isinstance(decoded, dict):
                raise TypeError(
                    f"Expected a dict checkpoint, got {type(decoded).__name__}"
                )
        except Exception:
            logger.warning(
                "Failed to decode checkpoint; proceeding without resume state."
            )
            logger.debug("Checkpoint decode failure details", exc_info=True)
            # Drop the bytes too: handing undecodable bytes to the sampler
            # would only make the same failure resurface at sample time.
            checkpoint_bytes = None
        else:
            checkpoint_state = decoded
            samples_saved = checkpoint_state.get("samples")
            if samples_saved is not None:
                try:
                    n_samples = len(samples_saved)
                except TypeError:
                    # Unsized samples object: leave n_samples unset.
                    n_samples = None
                if aspire.xp is None and hasattr(samples_saved, "xp"):
                    aspire.xp = samples_saved.xp

    return (
        aspire,
        checkpoint_bytes,
        checkpoint_state,
        sampler_config,
        saved_sampler_type,
        n_samples,
    )
@@ -92,7 +92,7 @@ class BaseTorchFlow(Flow):
92
92
  config = load_from_h5_file(flow_grp, "config")
93
93
  config["dtype"] = decode_dtype(torch, config.get("dtype"))
94
94
  if "data_transform" in flow_grp:
95
- from ..transforms import BaseTransform
95
+ from ...transforms import BaseTransform
96
96
 
97
97
  data_transform = BaseTransform.load(
98
98
  flow_grp,
aspire/samplers/base.py CHANGED
@@ -1,10 +1,12 @@
1
1
  import logging
2
+ import pickle
3
+ from pathlib import Path
2
4
  from typing import Any, Callable
3
5
 
4
6
  from ..flows.base import Flow
5
7
  from ..samples import Samples
6
8
  from ..transforms import IdentityTransform
7
- from ..utils import asarray, track_calls
9
+ from ..utils import AspireFile, asarray, dump_state, track_calls
8
10
 
9
11
  logger = logging.getLogger(__name__)
10
12
 
@@ -49,6 +51,8 @@ class Sampler:
49
51
  self.parameters = parameters
50
52
  self.history = None
51
53
  self.n_likelihood_evaluations = 0
54
+ self._last_checkpoint_state: dict | None = None
55
+ self._last_checkpoint_bytes: bytes | None = None
52
56
  if preconditioning_transform is None:
53
57
  self.preconditioning_transform = IdentityTransform(xp=self.xp)
54
58
  else:
@@ -75,7 +79,7 @@ class Sampler:
75
79
  self.n_likelihood_evaluations += len(samples)
76
80
  return self._log_likelihood(samples)
77
81
 
78
- def config_dict(self, include_sample_calls: bool = True) -> dict:
82
+ def config_dict(self, include_sample_calls: bool = False) -> dict:
79
83
  """
80
84
  Returns a dictionary with the configuration of the sampler.
81
85
 
@@ -83,9 +87,9 @@ class Sampler:
83
87
  ----------
84
88
  include_sample_calls : bool
85
89
  Whether to include the sample calls in the configuration.
86
- Default is True.
90
+ Default is False.
87
91
  """
88
- config = {}
92
+ config = {"sampler_class": self.__class__.__name__}
89
93
  if include_sample_calls:
90
94
  if hasattr(self, "sample") and hasattr(self.sample, "calls"):
91
95
  config["sample_calls"] = self.sample.calls.to_dict(
@@ -96,3 +100,139 @@ class Sampler:
96
100
  "Sampler does not have a sample method with calls attribute."
97
101
  )
98
102
  return config
103
+
104
+ # --- Checkpointing helpers shared across samplers ---
105
+ def _checkpoint_extra_state(self) -> dict:
106
+ """Sampler-specific extras for checkpointing (override in subclasses)."""
107
+ return {}
108
+
109
+ def _restore_extra_state(self, state: dict) -> None:
110
+ """Restore sampler-specific extras (override in subclasses)."""
111
+ _ = state # no-op for base
112
+
113
def build_checkpoint_state(
    self,
    samples: Samples,
    iteration: int | None = None,
    meta: dict | None = None,
) -> dict:
    """Assemble the checkpoint dictionary for the current sampler state.

    Parameters
    ----------
    samples : Samples
        Samples stored under the ``"samples"`` key (by reference).
    iteration : int | None
        Iteration counter stored under ``"iteration"``.
    meta : dict | None
        Extra metadata; a falsy value is stored as an empty dict.

    Returns
    -------
    dict
        The common entries, updated with whatever
        ``_checkpoint_extra_state`` returns.
    """
    payload = dict(
        sampler=self.__class__.__name__,
        iteration=iteration,
        samples=samples,
        config=self.config_dict(include_sample_calls=False),
        parameters=self.parameters,
        meta=meta if meta else {},
    )
    # Subclass extras are merged last, so they win on key collisions.
    payload.update(self._checkpoint_extra_state())
    return payload
131
+
132
+ def serialize_checkpoint(
133
+ self, state: dict, protocol: int | None = None
134
+ ) -> bytes:
135
+ """Serialize a checkpoint state to bytes with pickle."""
136
+ protocol = (
137
+ pickle.HIGHEST_PROTOCOL if protocol is None else int(protocol)
138
+ )
139
+ return pickle.dumps(state, protocol=protocol)
140
+
141
def default_checkpoint_callback(self, state: dict) -> None:
    """Remember the latest checkpoint on the sampler.

    The state is stored first, then its pickled form produced by
    ``serialize_checkpoint`` is stored alongside it.
    """
    self._last_checkpoint_state = state
    pickled_state = self.serialize_checkpoint(state)
    self._last_checkpoint_bytes = pickled_state
145
+
146
+ def default_file_checkpoint_callback(
147
+ self, file_path: str | Path | None
148
+ ) -> Callable[[dict], None]:
149
+ """Return a simple default callback that overwrites an HDF5 file."""
150
+ if file_path is None:
151
+ return self.default_checkpoint_callback
152
+ file_path = Path(file_path)
153
+ lower_path = file_path.name.lower()
154
+ if not lower_path.endswith((".h5", ".hdf5")):
155
+ raise ValueError(
156
+ "Checkpoint file must be an HDF5 file (.h5 or .hdf5)."
157
+ )
158
+
159
+ def _callback(state: dict) -> None:
160
+ with AspireFile(file_path, "a") as h5_file:
161
+ self.save_checkpoint_to_hdf(
162
+ state, h5_file, path="checkpoint", dsetname="state"
163
+ )
164
+ self.default_checkpoint_callback(state)
165
+
166
+ return _callback
167
+
168
def save_checkpoint_to_hdf(
    self,
    state: dict,
    h5_file,
    path: str = "sampler_checkpoints",
    dsetname: str | None = None,
    protocol: int | None = None,
) -> None:
    """Save a checkpoint state into an HDF5 file as a pickled blob.

    Parameters
    ----------
    state : dict
        Checkpoint state to store.
    h5_file
        Open HDF5 file handle passed through to ``dump_state``.
    path : str
        Group path inside the file.
    dsetname : str | None
        Dataset name; defaults to ``iter_<iteration>`` using the state's
        ``"iteration"`` entry (``"unknown"`` when absent).
    protocol : int | None
        Pickle protocol; ``None`` selects ``pickle.HIGHEST_PROTOCOL``.
    """
    if dsetname is None:
        iter_str = state.get("iteration", "unknown")
        dsetname = f"iter_{iter_str}"
    # Test against None rather than truthiness so that protocol 0, a valid
    # pickle protocol, is honoured (consistent with serialize_checkpoint).
    if protocol is None:
        protocol = pickle.HIGHEST_PROTOCOL
    dump_state(
        state,
        h5_file,
        path=path,
        dsetname=dsetname,
        protocol=protocol,
    )
187
+
188
+ def load_checkpoint_from_file(
189
+ self,
190
+ file_path: str | Path,
191
+ h5_path: str = "checkpoint",
192
+ dsetname: str = "state",
193
+ ) -> dict:
194
+ """Load a checkpoint dictionary from .pkl or .hdf5 file."""
195
+ file_path = Path(file_path)
196
+ lower_path = file_path.name.lower()
197
+ if lower_path.endswith((".h5", ".hdf5")):
198
+ with AspireFile(file_path, "r") as h5_file:
199
+ data = h5_file[h5_path][dsetname][...]
200
+ checkpoint_bytes = data.tobytes()
201
+ else:
202
+ with open(file_path, "rb") as f:
203
+ checkpoint_bytes = f.read()
204
+ return pickle.loads(checkpoint_bytes)
205
+
206
def restore_from_checkpoint(
    self, source: str | Path | bytes | dict
) -> tuple[Samples, dict]:
    """Restore sampler state from a checkpoint source.

    Parameters
    ----------
    source : str | Path | bytes | dict
        A checkpoint file path (loaded with ``load_checkpoint_from_file``),
        pickled checkpoint bytes, or an already-decoded state dictionary.

    Returns
    -------
    tuple[Samples, dict]
        The samples converted to this sampler's namespace and dtype, and
        the full checkpoint state.

    Raises
    ------
    TypeError
        If ``source`` is of an unsupported type.
    ValueError
        If the checkpoint has no ``"samples"`` entry.
    """
    if isinstance(source, (str, Path)):
        # Path objects are accepted too, matching load_checkpoint_from_file.
        state = self.load_checkpoint_from_file(source)
    elif isinstance(source, (bytes, bytearray)):
        # NOTE(security): pickle can execute arbitrary code while loading;
        # only restore from checkpoint bytes you trust.
        state = pickle.loads(source)
    elif isinstance(source, dict):
        state = source
    else:
        raise TypeError("Unsupported checkpoint source type.")

    samples_saved = state.get("samples")
    if samples_saved is None:
        raise ValueError("Checkpoint missing samples.")

    samples = Samples.from_samples(
        samples_saved, xp=self.xp, dtype=self.dtype
    )
    # Allow subclasses to restore sampler-specific components
    self._restore_extra_state(state)
    return samples, state
229
+
230
+ @property
231
+ def last_checkpoint_state(self) -> dict | None:
232
+ """Return the most recent checkpoint state stored by the default callback."""
233
+ return self._last_checkpoint_state
234
+
235
+ @property
236
+ def last_checkpoint_bytes(self) -> bytes | None:
237
+ """Return the most recent pickled checkpoint produced by the default callback."""
238
+ return self._last_checkpoint_bytes
@@ -1,3 +1,4 @@
1
+ import copy
1
2
  import logging
2
3
  from typing import Any, Callable
3
4
 
@@ -158,11 +159,24 @@ class SMCSampler(MCMCSampler):
158
159
  target_efficiency: float = 0.5,
159
160
  target_efficiency_rate: float = 1.0,
160
161
  n_final_samples: int | None = None,
162
+ checkpoint_callback: Callable[[dict], None] | None = None,
163
+ checkpoint_every: int | None = None,
164
+ checkpoint_file_path: str | None = None,
165
+ resume_from: str | bytes | dict | None = None,
161
166
  ) -> SMCSamples:
162
- samples = self.draw_initial_samples(n_samples)
163
- samples = SMCSamples.from_samples(
164
- samples, xp=self.xp, beta=0.0, dtype=self.dtype
165
- )
167
+ resumed = resume_from is not None
168
+ if resumed:
169
+ samples, beta, iterations = self.restore_from_checkpoint(
170
+ resume_from
171
+ )
172
+ else:
173
+ samples = self.draw_initial_samples(n_samples)
174
+ samples = SMCSamples.from_samples(
175
+ samples, xp=self.xp, beta=0.0, dtype=self.dtype
176
+ )
177
+ beta = 0.0
178
+ iterations = 0
179
+ self.history = SMCHistory()
166
180
  self.fit_preconditioning_transform(samples.x)
167
181
 
168
182
  if self.xp.isnan(samples.log_q).any():
@@ -178,8 +192,6 @@ class SMCSampler(MCMCSampler):
178
192
  self.sampler_kwargs = self.sampler_kwargs or {}
179
193
  n_final_steps = self.sampler_kwargs.pop("n_final_steps", None)
180
194
 
181
- self.history = SMCHistory()
182
-
183
195
  self.target_efficiency = target_efficiency
184
196
  self.target_efficiency_rate = target_efficiency_rate
185
197
 
@@ -190,7 +202,6 @@ class SMCSampler(MCMCSampler):
190
202
  else:
191
203
  beta_step = np.nan
192
204
  self.adaptive = adaptive
193
- beta = 0.0
194
205
 
195
206
  if min_step is None:
196
207
  if max_n_steps is None:
@@ -202,55 +213,85 @@ class SMCSampler(MCMCSampler):
202
213
  else:
203
214
  self.adaptive_min_step = False
204
215
 
205
- iterations = 0
206
- while True:
207
- iterations += 1
208
-
209
- beta, min_step = self.determine_beta(
210
- samples,
211
- beta,
212
- beta_step,
213
- min_step,
216
+ iterations = iterations or 0
217
+ if checkpoint_callback is None and checkpoint_every is not None:
218
+ checkpoint_callback = self.default_file_checkpoint_callback(
219
+ checkpoint_file_path
214
220
  )
215
- self.history.eff_target.append(
216
- self.current_target_efficiency(beta)
221
+ if checkpoint_callback is not None and checkpoint_every is None:
222
+ checkpoint_every = 1
223
+
224
+ run_smc_loop = True
225
+ if resumed:
226
+ last_beta = self.history.beta[-1] if self.history.beta else beta
227
+ if last_beta >= 1.0:
228
+ run_smc_loop = False
229
+
230
+ def maybe_checkpoint(force: bool = False):
231
+ if checkpoint_callback is None:
232
+ return
233
+ should_checkpoint = force or (
234
+ checkpoint_every is not None
235
+ and checkpoint_every > 0
236
+ and iterations % checkpoint_every == 0
217
237
  )
238
+ if not should_checkpoint:
239
+ return
240
+ state = self.build_checkpoint_state(samples, iterations, beta)
241
+ checkpoint_callback(state)
242
+
243
+ if run_smc_loop:
244
+ while True:
245
+ iterations += 1
246
+
247
+ beta, min_step = self.determine_beta(
248
+ samples,
249
+ beta,
250
+ beta_step,
251
+ min_step,
252
+ )
253
+ self.history.eff_target.append(
254
+ self.current_target_efficiency(beta)
255
+ )
218
256
 
219
- logger.info(f"it {iterations} - beta: {beta}")
220
- self.history.beta.append(beta)
221
-
222
- ess = effective_sample_size(samples.log_weights(beta))
223
- eff = ess / len(samples)
224
- if eff < 0.1:
225
- logger.warning(
226
- f"it {iterations} - Low sample efficiency: {eff:.2f}"
257
+ logger.info(f"it {iterations} - beta: {beta}")
258
+ self.history.beta.append(beta)
259
+
260
+ ess = effective_sample_size(samples.log_weights(beta))
261
+ eff = ess / len(samples)
262
+ if eff < 0.1:
263
+ logger.warning(
264
+ f"it {iterations} - Low sample efficiency: {eff:.2f}"
265
+ )
266
+ self.history.ess.append(ess)
267
+ logger.info(
268
+ f"it {iterations} - ESS: {ess:.1f} ({eff:.2f} efficiency)"
269
+ )
270
+ self.history.ess_target.append(
271
+ effective_sample_size(samples.log_weights(1.0))
227
272
  )
228
- self.history.ess.append(ess)
229
- logger.info(
230
- f"it {iterations} - ESS: {ess:.1f} ({eff:.2f} efficiency)"
231
- )
232
- self.history.ess_target.append(
233
- effective_sample_size(samples.log_weights(1.0))
234
- )
235
273
 
236
- log_evidence_ratio = samples.log_evidence_ratio(beta)
237
- log_evidence_ratio_var = samples.log_evidence_ratio_variance(beta)
238
- self.history.log_norm_ratio.append(log_evidence_ratio)
239
- self.history.log_norm_ratio_var.append(log_evidence_ratio_var)
240
- logger.info(
241
- f"it {iterations} - Log evidence ratio: {log_evidence_ratio:.2f} +/- {np.sqrt(log_evidence_ratio_var):.2f}"
242
- )
274
+ log_evidence_ratio = samples.log_evidence_ratio(beta)
275
+ log_evidence_ratio_var = samples.log_evidence_ratio_variance(
276
+ beta
277
+ )
278
+ self.history.log_norm_ratio.append(log_evidence_ratio)
279
+ self.history.log_norm_ratio_var.append(log_evidence_ratio_var)
280
+ logger.info(
281
+ f"it {iterations} - Log evidence ratio: {log_evidence_ratio:.2f} +/- {np.sqrt(log_evidence_ratio_var):.2f}"
282
+ )
243
283
 
244
- samples = samples.resample(beta, rng=self.rng)
284
+ samples = samples.resample(beta, rng=self.rng)
245
285
 
246
- samples = self.mutate(samples, beta)
247
- if beta == 1.0 or (
248
- max_n_steps is not None and iterations >= max_n_steps
249
- ):
250
- break
286
+ samples = self.mutate(samples, beta)
287
+ maybe_checkpoint()
288
+ if beta == 1.0 or (
289
+ max_n_steps is not None and iterations >= max_n_steps
290
+ ):
291
+ break
251
292
 
252
- # If n_final_samples is not None, perform additional mutations steps
253
- if n_final_samples is not None:
293
+ # If n_final_samples is specified and differs, perform additional mutation steps
294
+ if n_final_samples is not None and len(samples.x) != n_final_samples:
254
295
  logger.info(f"Generating {n_final_samples} final samples")
255
296
  final_samples = samples.resample(
256
297
  1.0, n_samples=n_final_samples, rng=self.rng
@@ -263,6 +304,7 @@ class SMCSampler(MCMCSampler):
263
304
  samples.log_evidence_error = samples.xp.sqrt(
264
305
  samples.xp.sum(asarray(self.history.log_norm_ratio_var, self.xp))
265
306
  )
307
+ maybe_checkpoint(force=True)
266
308
 
267
309
  final_samples = samples.to_standard_samples()
268
310
  logger.info(
@@ -289,6 +331,49 @@ class SMCSampler(MCMCSampler):
289
331
  )
290
332
  return log_prob
291
333
 
334
def build_checkpoint_state(
    self, samples: SMCSamples, iteration: int, beta: float
) -> dict:
    """Build the checkpoint payload for an SMC run.

    Delegates to the parent implementation, recording the current
    ``beta`` in the payload's ``meta`` entry.
    """
    smc_meta = {"beta": beta}
    return super().build_checkpoint_state(samples, iteration, meta=smc_meta)
343
+
344
+ def _checkpoint_extra_state(self) -> dict:
345
+ history_copy = copy.deepcopy(self.history)
346
+ rng_state = (
347
+ self.rng.bit_generator.state
348
+ if hasattr(self.rng, "bit_generator")
349
+ else None
350
+ )
351
+ return {
352
+ "history": history_copy,
353
+ "rng_state": rng_state,
354
+ "sampler_kwargs": getattr(self, "sampler_kwargs", None),
355
+ }
356
+
357
def restore_from_checkpoint(
    self, source: str | bytes | dict
) -> tuple[SMCSamples, float, int]:
    """Restore an SMC run from a checkpoint source.

    Parameters
    ----------
    source : str | bytes | dict
        Checkpoint file path, pickled bytes, or decoded state; handled by
        the parent ``restore_from_checkpoint``.

    Returns
    -------
    tuple[SMCSamples, float, int]
        The restored samples, the inverse temperature ``beta`` and the
        iteration counter. ``beta`` is read from ``state["meta"]["beta"]``,
        then ``state["beta"]``, and defaults to 0.0; a missing or None
        iteration becomes 0.

    Side effects: replaces ``self.history`` and, when both a saved state
    and a ``bit_generator`` are available, the rng's bit-generator state.
    """
    samples, state = super().restore_from_checkpoint(source)

    meta = state.get("meta")
    beta = meta.get("beta") if isinstance(meta, dict) else None
    if beta is None:
        beta = state.get("beta")
    if beta is None:
        # Neither location holds a usable value: start from the prior.
        beta = 0.0

    iteration = state.get("iteration") or 0

    # A checkpoint may carry ``history: None`` (the sampler's initial
    # value); fall back to a fresh history so resumed runs can append.
    history = state.get("history")
    self.history = history if history is not None else SMCHistory()

    rng_state = state.get("rng_state")
    if rng_state is not None and hasattr(self.rng, "bit_generator"):
        self.rng.bit_generator.state = rng_state

    samples = SMCSamples.from_samples(
        samples, xp=self.xp, beta=beta, dtype=self.dtype
    )
    return samples, beta, iteration
376
+
292
377
 
293
378
  class NumpySMCSampler(SMCSampler):
294
379
  def __init__(
@@ -84,6 +84,10 @@ class BlackJAXSMC(SMCSampler):
84
84
  n_final_samples: int | None = None,
85
85
  sampler_kwargs: dict | None = None,
86
86
  rng_key=None,
87
+ checkpoint_callback=None,
88
+ checkpoint_every: int | None = None,
89
+ checkpoint_file_path: str | None = None,
90
+ resume_from: str | bytes | dict | None = None,
87
91
  ):
88
92
  """Sample using BlackJAX SMC.
89
93
 
@@ -132,6 +136,10 @@ class BlackJAXSMC(SMCSampler):
132
136
  target_efficiency=target_efficiency,
133
137
  target_efficiency_rate=target_efficiency_rate,
134
138
  n_final_samples=n_final_samples,
139
+ checkpoint_callback=checkpoint_callback,
140
+ checkpoint_every=checkpoint_every,
141
+ checkpoint_file_path=checkpoint_file_path,
142
+ resume_from=resume_from,
135
143
  )
136
144
 
137
145
  def mutate(self, particles, beta, n_steps=None):
@@ -21,6 +21,10 @@ class EmceeSMC(NumpySMCSampler):
21
21
  target_efficiency_rate: float = 1.0,
22
22
  sampler_kwargs: dict | None = None,
23
23
  n_final_samples: int | None = None,
24
+ checkpoint_callback=None,
25
+ checkpoint_every: int | None = None,
26
+ checkpoint_file_path: str | None = None,
27
+ resume_from: str | bytes | dict | None = None,
24
28
  ):
25
29
  self.sampler_kwargs = sampler_kwargs or {}
26
30
  self.sampler_kwargs.setdefault("nsteps", 5 * self.dims)
@@ -33,6 +37,10 @@ class EmceeSMC(NumpySMCSampler):
33
37
  target_efficiency=target_efficiency,
34
38
  target_efficiency_rate=target_efficiency_rate,
35
39
  n_final_samples=n_final_samples,
40
+ checkpoint_callback=checkpoint_callback,
41
+ checkpoint_every=checkpoint_every,
42
+ checkpoint_file_path=checkpoint_file_path,
43
+ resume_from=resume_from,
36
44
  )
37
45
 
38
46
  def mutate(self, particles, beta, n_steps=None):
@@ -8,10 +8,10 @@ from ...utils import (
8
8
  determine_backend_name,
9
9
  track_calls,
10
10
  )
11
- from .base import NumpySMCSampler
11
+ from .base import SMCSampler
12
12
 
13
13
 
14
- class MiniPCNSMC(NumpySMCSampler):
14
+ class MiniPCNSMC(SMCSampler):
15
15
  """MiniPCN SMC sampler."""
16
16
 
17
17
  rng = None
@@ -32,6 +32,10 @@ class MiniPCNSMC(NumpySMCSampler):
32
32
  n_final_samples: int | None = None,
33
33
  sampler_kwargs: dict | None = None,
34
34
  rng: np.random.Generator | None = None,
35
+ checkpoint_callback=None,
36
+ checkpoint_every: int | None = None,
37
+ checkpoint_file_path: str | None = None,
38
+ resume_from: str | bytes | dict | None = None,
35
39
  ):
36
40
  from orng import ArrayRNG
37
41
 
@@ -50,6 +54,10 @@ class MiniPCNSMC(NumpySMCSampler):
50
54
  n_final_samples=n_final_samples,
51
55
  min_step=min_step,
52
56
  max_n_steps=max_n_steps,
57
+ checkpoint_callback=checkpoint_callback,
58
+ checkpoint_every=checkpoint_every,
59
+ checkpoint_file_path=checkpoint_file_path,
60
+ resume_from=resume_from,
53
61
  )
54
62
 
55
63
  def mutate(self, particles, beta, n_steps=None):
aspire/samples.py CHANGED
@@ -9,19 +9,18 @@ from typing import Any, Callable
9
9
  import numpy as np
10
10
  from array_api_compat import (
11
11
  array_namespace,
12
- is_numpy_namespace,
13
- to_device,
14
12
  )
15
- from array_api_compat import device as api_device
16
13
  from array_api_compat.common._typing import Array
17
14
  from matplotlib.figure import Figure
18
15
 
19
16
  from .utils import (
20
17
  asarray,
21
18
  convert_dtype,
19
+ infer_device,
22
20
  logsumexp,
23
21
  recursively_save_to_h5_file,
24
22
  resolve_dtype,
23
+ safe_to_device,
25
24
  to_numpy,
26
25
  )
27
26
 
@@ -67,8 +66,6 @@ class BaseSamples:
67
66
  if self.xp is None:
68
67
  self.xp = array_namespace(self.x)
69
68
  # Numpy arrays need to be on the CPU before being converted
70
- if is_numpy_namespace(self.xp):
71
- self.device = "cpu"
72
69
  if self.dtype is not None:
73
70
  self.dtype = resolve_dtype(self.dtype, self.xp)
74
71
  else:
@@ -76,7 +73,7 @@ class BaseSamples:
76
73
  self.dtype = None
77
74
  self.x = self.array_to_namespace(self.x, dtype=self.dtype)
78
75
  if self.device is None:
79
- self.device = api_device(self.x)
76
+ self.device = infer_device(self.x, self.xp)
80
77
  if self.log_likelihood is not None:
81
78
  self.log_likelihood = self.array_to_namespace(
82
79
  self.log_likelihood, dtype=self.dtype
@@ -140,8 +137,7 @@ class BaseSamples:
140
137
  else:
141
138
  kwargs["dtype"] = self.dtype
142
139
  x = asarray(x, self.xp, **kwargs)
143
- if self.device:
144
- x = to_device(x, self.device)
140
+ x = safe_to_device(x, self.device, self.xp)
145
141
  return x
146
142
 
147
143
  def to_dict(self, flat: bool = True):
@@ -174,7 +170,6 @@ class BaseSamples:
174
170
  ----------
175
171
  parameters : list[str] | None
176
172
  List of parameters to plot. If None, all parameters are plotted.
177
- fig : matplotlib.figure.Figure | None
178
173
  Figure to plot on. If None, a new figure is created.
179
174
  **kwargs : dict
180
175
  Additional keyword arguments to pass to corner.corner(). Kwargs
@@ -300,6 +295,13 @@ class BaseSamples:
300
295
  def __setstate__(self, state):
301
296
  # Restore xp by checking the namespace of x
302
297
  state["xp"] = array_namespace(state["x"])
298
+ # device may be string; leave as-is or None
299
+ device = state.get("device")
300
+ if device is not None and "jax" in getattr(
301
+ state["xp"], "__name__", ""
302
+ ):
303
+ device = None
304
+ state["device"] = device
303
305
  self.__dict__.update(state)
304
306
 
305
307
 
aspire/utils.py CHANGED
@@ -1,10 +1,13 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import functools
3
4
  import inspect
4
5
  import logging
6
+ import pickle
5
7
  from contextlib import contextmanager
6
8
  from dataclasses import dataclass
7
9
  from functools import partial
10
+ from io import BytesIO
8
11
  from typing import TYPE_CHECKING, Any
9
12
 
10
13
  import array_api_compat.numpy as np
@@ -601,6 +604,7 @@ def decode_from_hdf5(value: Any) -> Any:
601
604
  return None
602
605
  if value == "__empty_dict__":
603
606
  return {}
607
+ return value
604
608
 
605
609
  if isinstance(value, np.ndarray):
606
610
  # Try to collapse 0-D arrays into scalars
@@ -629,6 +633,101 @@ def decode_from_hdf5(value: Any) -> Any:
629
633
  return value
630
634
 
631
635
 
636
def dump_pickle_to_hdf(memfp, fp, path=None, dsetname="state"):
    """Write the full contents of a byte buffer into an HDF5 dataset.

    Parameters
    ----------
    memfp
        Seekable binary buffer; it is rewound to the start and read fully.
    fp
        Open HDF5 file (or group) object to write into.
    path : str | None
        Group path under ``fp``; obtained with ``require_group``. If None,
        the dataset is written directly under ``fp``.
    dsetname : str
        Name of the dataset that receives the bytes.
    """
    memfp.seek(0)
    raw = memfp.read()
    # View the bytes as an array of one-byte strings.
    payload = np.frombuffer(raw, dtype="S1")
    group = fp if path is None else fp.require_group(path)
    if dsetname in group:
        # Reuse the existing dataset, resizing only if the length changed.
        if payload.size != group[dsetname].shape[0]:
            group[dsetname].resize((payload.size,))
    else:
        # An unlimited maxshape allows later dumps of a different length.
        group.create_dataset(
            dsetname,
            shape=payload.shape,
            maxshape=(None,),
            dtype=payload.dtype,
        )
    group[dsetname][:] = payload
648
+
649
+
650
def dump_state(
    state,
    fp,
    path=None,
    dsetname="state",
    protocol=pickle.HIGHEST_PROTOCOL,
):
    """Pickle ``state`` and store the bytes in an HDF5 dataset.

    Parameters
    ----------
    state
        Object to pickle.
    fp
        Open HDF5 file (or group) object, passed to ``dump_pickle_to_hdf``.
    path : str | None
        Group path forwarded to ``dump_pickle_to_hdf``.
    dsetname : str
        Dataset name forwarded to ``dump_pickle_to_hdf``.
    protocol : int
        Pickle protocol; defaults to ``pickle.HIGHEST_PROTOCOL``.
    """
    # dump_pickle_to_hdf rewinds the buffer itself before reading it.
    buffer = BytesIO(pickle.dumps(state, protocol=protocol))
    dump_pickle_to_hdf(buffer, fp, path=path, dsetname=dsetname)
661
+
662
+
663
+ def resolve_xp(xp_name: str | None):
664
+ """
665
+ Resolve a backend name to the corresponding array_api_compat module.
666
+
667
+ Returns None if the name is None or cannot be resolved.
668
+ """
669
+ if xp_name is None:
670
+ return None
671
+ name = xp_name.lower()
672
+ if name.startswith("array_api_compat."):
673
+ name = name.removeprefix("array_api_compat.")
674
+ try:
675
+ if name in {"numpy", "numpy.ndarray"}:
676
+ import array_api_compat.numpy as np_xp
677
+
678
+ return np_xp
679
+ if name in {"jax", "jax.numpy"}:
680
+ import jax.numpy as jnp
681
+
682
+ return jnp
683
+ if name in {"torch"}:
684
+ import array_api_compat.torch as torch_xp
685
+
686
+ return torch_xp
687
+ except Exception:
688
+ logger.warning(
689
+ "Failed to resolve xp '%s', defaulting to None", xp_name
690
+ )
691
+ return None
692
+
693
+
694
def infer_device(x, xp):
    """
    Best-effort lookup of the device that holds ``x``.

    Returns None when ``xp`` is None or is a numpy or jax namespace. For
    any other namespace, returns ``array_api_compat.device(x)``, or None
    if that import or call raises.
    """
    if xp is None:
        return None
    if is_numpy_namespace(xp) or is_jax_namespace(xp):
        return None
    try:
        from array_api_compat import device as lookup_device

        found = lookup_device(x)
    except Exception:
        # Best-effort: any failure means the device is reported as unknown.
        found = None
    return found
709
+
710
+
711
def safe_to_device(x, device, xp):
    """
    Move ``x`` to ``device`` when a move is applicable, else return ``x``.

    No move is attempted when ``device`` is None, when ``xp`` is None, or
    when ``xp`` is a numpy or jax namespace. If ``to_device`` raises, a
    warning is logged and the input is returned unchanged.
    """
    skip_move = (
        device is None
        or xp is None
        or is_numpy_namespace(xp)
        or is_jax_namespace(xp)
    )
    if skip_move:
        return x
    try:
        moved = to_device(x, device)
    except Exception:
        logger.warning(
            "Failed to move array to device %s; leaving on current device",
            device,
        )
        return x
    return moved
729
+
730
+
632
731
  def recursively_save_to_h5_file(h5_file, path, dictionary):
633
732
  """Save a dictionary to an HDF5 file with flattened keys under a given group path."""
634
733
  # Ensure the group exists (or open it if already present)
@@ -813,3 +912,23 @@ def track_calls(wrapped=None):
813
912
  return wrapped_func(*args, **kwargs)
814
913
 
815
914
  return wrapper(wrapped) if wrapped else wrapper
915
+
916
+
917
def function_id(fn: Any) -> str:
    """Get an identifier of the form ``module:qualname`` for a function.

    The identifier is built from names only, so distinct objects that
    share a module and qualified name receive the same identifier.

    Parameters
    ----------
    fn : Any
        The function to get the identifier for. For a
        ``functools.partial`` the wrapped function (``fn.func``) is used.

    Returns
    -------
    str
        ``"<module>:<qualname>"``. When the object has no ``__module__``
        or ``__qualname__`` attribute, the corresponding attribute of its
        type is used instead.
    """
    base = fn.func if isinstance(fn, functools.partial) else fn
    # Some objects do not expose ``__module__`` on the instance; fall back
    # to the type, mirroring the ``__qualname__`` fallback below.
    module = getattr(base, "__module__", type(base).__module__)
    qualname = getattr(base, "__qualname__", type(base).__name__)
    return f"{module}:{qualname}"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: aspire-inference
3
- Version: 0.1.0a10
3
+ Version: 0.1.0a12
4
4
  Summary: Accelerate Sequential Posterior Inference via REuse
5
5
  Author-email: "Michael J. Williams" <michaeljw1@googlemail.com>
6
6
  License: MIT
@@ -1,28 +1,28 @@
1
1
  aspire/__init__.py,sha256=B2IETvlpB0oBh57prRYLdi8jB5yFGw8qVviGdf1NcnE,409
2
- aspire/aspire.py,sha256=M5o-QxLthE_5daa1trgUfWxPz-g4rmpEUKimKosw4lw,17400
2
+ aspire/aspire.py,sha256=7DDRpwMezJABzX3AyHamRf8hjLAEeqCtg-_s5qSRjg0,30885
3
3
  aspire/history.py,sha256=l_j-riZKbTWK7Wz9zvvD_mTk9psNCKItiveYhr_pYv8,4313
4
4
  aspire/plot.py,sha256=oXwUDOb_953_ADm2KLk41JIfpE3JeiiQiSYKvUVwLqw,1423
5
- aspire/samples.py,sha256=z5x5hpWuVFH1hYhltmROAe8pbWxGD2UvHi3vcc132dg,19399
5
+ aspire/samples.py,sha256=v7y8DkirUCHOJbCE-o9y2K7xzU2HicIo_O0CdFhLgXE,19478
6
6
  aspire/transforms.py,sha256=CHrfPQHEyHQ9I0WWiAgUplWwxYypZsD7uCYIHUbSFtY,24974
7
- aspire/utils.py,sha256=eKGmchpuoNL15Xbu-AGoeZ00PcQEykEQiDZMnnRyV6A,24234
7
+ aspire/utils.py,sha256=sIONKn3gT7i3hVdlK9bRWy_I79rdk0QPkXTA4O1FlCI,27405
8
8
  aspire/flows/__init__.py,sha256=GUZToPVNJoTwULpbeW10UijfQukNrILoAQ_ubeq7G3w,2110
9
9
  aspire/flows/base.py,sha256=5UWKAiXDXLJ6Sg6a380ajLrGFaZSQyOnFEihQiiA4ko,2237
10
10
  aspire/flows/jax/__init__.py,sha256=7cmiY_MbEC8RDA8Cmi8HVnNJm0sqFKlBsDethdsy5lA,52
11
11
  aspire/flows/jax/flows.py,sha256=1HnVgQ1GUXNcvxiZqEV19H2QI9Th5bWX_QbNfGaUhuA,6625
12
12
  aspire/flows/jax/utils.py,sha256=5T6UrgpARG9VywC9qmTl45LjyZWuEdkW3XUladE6xJE,1518
13
13
  aspire/flows/torch/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
- aspire/flows/torch/flows.py,sha256=0_YkiMT49QolyQnEFsh28tfKLnURVF0Z6aTnaWLIUDI,11672
14
+ aspire/flows/torch/flows.py,sha256=QcQOcFZEsLWHPwbQUFGOFdfEslyc59Vf_UEsS0xAGPo,11673
15
15
  aspire/samplers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
16
- aspire/samplers/base.py,sha256=VEHawyVA33jXHMo63p5hBHkp9k2qxU_bOxh5iaZSXew,3011
16
+ aspire/samplers/base.py,sha256=ygqrvqSedWSb0cz8DQ_MHokOOxi6aBRdHxf_qoEPwUE,8243
17
17
  aspire/samplers/importance.py,sha256=opn_jY-V8snUz0JztLBtnaTT3WfrZ5OSpHBV5WAuM3M,676
18
18
  aspire/samplers/mcmc.py,sha256=ihHgzqvSoy1oxdFBjyqNUbCuRX7CqWjlshCUZcgEL5E,5151
19
19
  aspire/samplers/smc/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
20
- aspire/samplers/smc/base.py,sha256=66f_ORUvcKRqMIW35qjhUc-c0PFuY87lJa91MpSaTZI,10729
21
- aspire/samplers/smc/blackjax.py,sha256=IcTguAETiPmgFofmVW2GN40P5HBIxkmyd2VR8AU8f4k,12115
22
- aspire/samplers/smc/emcee.py,sha256=Wm0vvAlCcRhJMBt7_fU2ZnjDb8SN8jgUOTXLzNstRpA,2516
23
- aspire/samplers/smc/minipcn.py,sha256=iQUBBwHZ_D4CxNjARMngklRvx6yTlEDKdeidyYCgqM4,3003
24
- aspire_inference-0.1.0a10.dist-info/licenses/LICENSE,sha256=DN-eRtBfS9dZyT0Ds0Mdn2Y4nb-ZQ7h71vpASYBm5k4,1076
25
- aspire_inference-0.1.0a10.dist-info/METADATA,sha256=SdvlXKjQn0uJQPpWzoZKAH3oMJSKpZnvlUzxPsIwNlY,3869
26
- aspire_inference-0.1.0a10.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
27
- aspire_inference-0.1.0a10.dist-info/top_level.txt,sha256=9FRIYEl2xwVFG7jSOBHsElHQ0y3_4fq01Cf4_OyMQn8,7
28
- aspire_inference-0.1.0a10.dist-info/RECORD,,
20
+ aspire/samplers/smc/base.py,sha256=40A9yVuKS1F8cPzbfVQ9rNk3y07mnkfbuyRDIh_fy5A,14122
21
+ aspire/samplers/smc/blackjax.py,sha256=2riWDSRmpL5lGmnhNtdieiRs0oYC6XZA2X-nVlQaqpE,12490
22
+ aspire/samplers/smc/emcee.py,sha256=4CI9GvH69FCoLiFBbKKYwYocYyiM95IijC5EvrcAmUo,2891
23
+ aspire/samplers/smc/minipcn.py,sha256=IJ5466VvARd4qZCWXXl-l3BPaKW1AgcwmbP3ISL2bto,3368
24
+ aspire_inference-0.1.0a12.dist-info/licenses/LICENSE,sha256=DN-eRtBfS9dZyT0Ds0Mdn2Y4nb-ZQ7h71vpASYBm5k4,1076
25
+ aspire_inference-0.1.0a12.dist-info/METADATA,sha256=snzuBueTUZazIqKt6yEYfI4JdY3QVhY0C-vxl7Urauw,3869
26
+ aspire_inference-0.1.0a12.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
27
+ aspire_inference-0.1.0a12.dist-info/top_level.txt,sha256=9FRIYEl2xwVFG7jSOBHsElHQ0y3_4fq01Cf4_OyMQn8,7
28
+ aspire_inference-0.1.0a12.dist-info/RECORD,,