aspire-inference 0.1.0a7__py3-none-any.whl

aspire/__init__.py ADDED
@@ -0,0 +1,19 @@
+ """
+ aspire: Accelerated Sequential Posterior Inference via REuse
+ """
+
+ import logging
+ from importlib.metadata import PackageNotFoundError, version
+
+ from .aspire import Aspire
+
+ try:
+     __version__ = version("aspire")
+ except PackageNotFoundError:
+     __version__ = "unknown"
+
+ logging.getLogger(__name__).addHandler(logging.NullHandler())
+
+ __all__ = [
+     "Aspire",
+ ]
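
The package only attaches a logging.NullHandler, so none of the logger.info messages emitted in the modules below are visible unless the host application configures logging itself. A minimal sketch using only the standard library (not an aspire-specific API):

    import logging

    # Opt in to aspire's INFO-level messages (e.g. "Evaluating log prior").
    logging.basicConfig(level=logging.INFO)
    logging.getLogger("aspire").setLevel(logging.INFO)
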
aspire/aspire.py ADDED
@@ -0,0 +1,506 @@
+ import logging
+ import multiprocessing as mp
+ from inspect import signature
+ from typing import Any, Callable
+
+ import h5py
+
+ from .flows import get_flow_wrapper
+ from .flows.base import Flow
+ from .history import History
+ from .samples import Samples
+ from .transforms import (
+     CompositeTransform,
+     FlowPreconditioningTransform,
+     FlowTransform,
+ )
+ from .utils import recursively_save_to_h5_file
+
+ logger = logging.getLogger(__name__)
+
+
+ class Aspire:
+     """Accelerated Sequential Posterior Inference via REuse (aspire).
+
+     Parameters
+     ----------
+     log_likelihood : Callable
+         The log likelihood function.
+     log_prior : Callable
+         The log prior function.
+     dims : int
+         The number of dimensions.
+     parameters : list[str] | None
+         The list of parameter names. If None, samples objects will not
+         have parameter names specified.
+     periodic_parameters : list[str] | None
+         The list of periodic parameters.
+     prior_bounds : dict[str, tuple[float, float]] | None
+         The bounds for the prior. If None, some parameter transforms cannot
+         be applied.
+     bounded_to_unbounded : bool
+         Whether to transform bounded parameters to unbounded ones.
+     bounded_transform : str
+         The transformation to use for bounded parameters. Options are
+         'logit', 'exp', or 'tanh'.
+     device : str | None
+         The device to use for the flow. If None, the default device will be
+         used. This is only used with the PyTorch backend.
+     xp : Callable | None
+         The array backend to use. If None, the default backend will be
+         used.
+     flow : Flow | None
+         The flow object, if it already exists.
+         If None, a new flow will be created.
+     flow_backend : str
+         The backend to use for the flow. Options are 'zuko' or 'flowjax'.
+     flow_matching : bool
+         Whether to use flow matching.
+     eps : float
+         The epsilon value to use for data transforms.
+     dtype : Any | str | None
+         The data type to use for the samples, flow and transforms.
+     **kwargs
+         Keyword arguments to pass to the flow.
+     """
+
+     def __init__(
+         self,
+         *,
+         log_likelihood: Callable,
+         log_prior: Callable,
+         dims: int,
+         parameters: list[str] | None = None,
+         periodic_parameters: list[str] | None = None,
+         prior_bounds: dict[str, tuple[float, float]] | None = None,
+         bounded_to_unbounded: bool = True,
+         bounded_transform: str = "logit",
+         device: str | None = None,
+         xp: Callable | None = None,
+         flow: Flow | None = None,
+         flow_backend: str = "zuko",
+         flow_matching: bool = False,
+         eps: float = 1e-6,
+         dtype: Any | str | None = None,
+         **kwargs,
+     ) -> None:
+         self.log_likelihood = log_likelihood
+         self.log_prior = log_prior
+         self.dims = dims
+         self.parameters = parameters
+         self.device = device
+         self.eps = eps
+
+         self.periodic_parameters = periodic_parameters
+         self.prior_bounds = prior_bounds
+         self.bounded_to_unbounded = bounded_to_unbounded
+         self.bounded_transform = bounded_transform
+         self.flow_matching = flow_matching
+         self.flow_backend = flow_backend
+         self.flow_kwargs = kwargs
+         self.xp = xp
+         self.dtype = dtype
+
+         self._flow = flow
+
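
A minimal construction sketch based on the signature above. The toy likelihood and prior, and the assumption that the callables receive a Samples object exposing an (n_samples, dims) array as samples.x (inferred from convert_to_samples below), are illustrative rather than part of the documented API:

    import numpy as np

    from aspire import Aspire

    def log_likelihood(samples):
        # Standard-normal toy likelihood over the raw parameter array.
        return -0.5 * np.sum(samples.x**2, axis=-1)

    def log_prior(samples):
        # Flat prior on [-5, 5]^2, evaluated up to a constant.
        inside = np.all(np.abs(samples.x) <= 5.0, axis=-1)
        return np.where(inside, -samples.x.shape[-1] * np.log(10.0), -np.inf)

    aspire = Aspire(
        log_likelihood=log_likelihood,
        log_prior=log_prior,
        dims=2,
        parameters=["x0", "x1"],
        prior_bounds={"x0": (-5.0, 5.0), "x1": (-5.0, 5.0)},
        flow_backend="zuko",
    )
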
+     @property
+     def flow(self):
+         """The normalizing flow object."""
+         return self._flow
+
+     @flow.setter
+     def flow(self, flow: Flow):
+         """Set the normalizing flow object."""
+         self._flow = flow
+
+     @property
+     def sampler(self):
+         """The sampler object."""
+         return self._sampler
+
+     @property
+     def n_likelihood_evaluations(self):
+         """The number of likelihood evaluations."""
+         if hasattr(self, "_sampler"):
+             return self._sampler.n_likelihood_evaluations
+         else:
+             return None
+
+     def convert_to_samples(
+         self,
+         x,
+         log_likelihood=None,
+         log_prior=None,
+         log_q=None,
+         evaluate: bool = True,
+         xp=None,
+     ) -> Samples:
+         if xp is None:
+             xp = self.xp
+         samples = Samples(
+             x=x,
+             parameters=self.parameters,
+             log_likelihood=log_likelihood,
+             log_prior=log_prior,
+             log_q=log_q,
+             xp=xp,
+             dtype=self.dtype,
+         )
+
+         if evaluate:
+             if log_prior is None:
+                 logger.info("Evaluating log prior")
+                 samples.log_prior = samples.xp.to_device(
+                     self.log_prior(samples), samples.device
+                 )
+             if log_likelihood is None:
+                 logger.info("Evaluating log likelihood")
+                 samples.log_likelihood = samples.xp.to_device(
+                     self.log_likelihood(samples), samples.device
+                 )
+             samples.compute_weights()
+         return samples
+
+     def init_flow(self):
+         FlowClass, xp = get_flow_wrapper(
+             backend=self.flow_backend, flow_matching=self.flow_matching
+         )
+
+         data_transform = FlowTransform(
+             parameters=self.parameters,
+             prior_bounds=self.prior_bounds,
+             bounded_to_unbounded=self.bounded_to_unbounded,
+             bounded_transform=self.bounded_transform,
+             device=self.device,
+             xp=xp,
+             eps=self.eps,
+             dtype=self.dtype,
+         )
+
+         # Check if FlowClass takes `parameters` as an argument
+         flow_init_params = signature(FlowClass.__init__).parameters
+         if "parameters" in flow_init_params:
+             self.flow_kwargs["parameters"] = self.parameters.copy()
+
+         logger.info(f"Configuring {FlowClass} with kwargs: {self.flow_kwargs}")
+
+         self._flow = FlowClass(
+             dims=self.dims,
+             device=self.device,
+             data_transform=data_transform,
+             dtype=self.dtype,
+             **self.flow_kwargs,
+         )
+
+     def fit(self, samples: Samples, **kwargs) -> History:
+         if self.xp is None:
+             self.xp = samples.xp
+
+         if self.flow is None:
+             self.init_flow()
+
+         self.training_samples = samples
+         logger.info(f"Training with {len(samples.x)} samples")
+         history = self.flow.fit(samples.x, **kwargs)
+         return history
+
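
Continuing the earlier sketch, training data can be wrapped with convert_to_samples before calling fit; with evaluate=True the log prior and log likelihood are evaluated on the way in, and any extra keyword arguments to fit are forwarded to the backend flow's fit method:

    # Draw samples from the toy uniform prior defined earlier.
    x = np.random.uniform(-5.0, 5.0, size=(5000, 2))

    # Wrap the raw array; this also computes log_prior and log_likelihood.
    training_samples = aspire.convert_to_samples(x)

    # Train the normalizing flow on the training samples.
    history = aspire.fit(training_samples)
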
+     def get_sampler_class(self, sampler_type: str) -> Callable:
+         """Get the sampler class based on the sampler type.
+
+         Parameters
+         ----------
+         sampler_type : str
+             The type of sampler to use. Options are 'importance', 'emcee', 'emcee_smc', 'minipcn', 'smc' (alias 'minipcn_smc'), or 'blackjax_smc'.
+         """
+         if sampler_type == "importance":
+             from .samplers.importance import ImportanceSampler as SamplerClass
+         elif sampler_type == "emcee":
+             from .samplers.mcmc import Emcee as SamplerClass
+         elif sampler_type == "emcee_smc":
+             from .samplers.smc.emcee import EmceeSMC as SamplerClass
+         elif sampler_type == "minipcn":
+             from .samplers.mcmc import MiniPCN as SamplerClass
+         elif sampler_type in ["smc", "minipcn_smc"]:
+             from .samplers.smc.minipcn import MiniPCNSMC as SamplerClass
+         elif sampler_type == "blackjax_smc":
+             from .samplers.smc.blackjax import BlackJAXSMC as SamplerClass
+         else:
+             raise ValueError(f"Unknown sampler type: {sampler_type}")
+         return SamplerClass
+
+     def init_sampler(
+         self,
+         sampler_type: str,
+         preconditioning: str | None = None,
+         preconditioning_kwargs: dict | None = None,
+         **kwargs,
+     ) -> Callable:
+         """Initialize the sampler for posterior sampling.
+
+         Parameters
+         ----------
+         sampler_type : str
+             The type of sampler to use. See :py:meth:`get_sampler_class` for the available options.
+         """
+         SamplerClass = self.get_sampler_class(sampler_type)
+
+         if sampler_type != "importance" and preconditioning is None:
+             preconditioning = "default"
+
+         preconditioning = preconditioning.lower() if preconditioning else None
+
+         if preconditioning is None or preconditioning == "none":
+             transform = None
+         elif preconditioning in ["standard", "default"]:
+             preconditioning_kwargs = preconditioning_kwargs or {}
+             preconditioning_kwargs.setdefault("affine_transform", False)
+             preconditioning_kwargs.setdefault("bounded_to_unbounded", False)
+             preconditioning_kwargs.setdefault("bounded_transform", "logit")
+             transform = CompositeTransform(
+                 parameters=self.parameters,
+                 prior_bounds=self.prior_bounds,
+                 periodic_parameters=self.periodic_parameters,
+                 xp=self.xp,
+                 device=self.device,
+                 dtype=self.dtype,
+                 **preconditioning_kwargs,
+             )
+         elif preconditioning == "flow":
+             preconditioning_kwargs = preconditioning_kwargs or {}
+             preconditioning_kwargs.setdefault("affine_transform", False)
+             transform = FlowPreconditioningTransform(
+                 parameters=self.parameters,
+                 flow_backend=self.flow_backend,
+                 flow_kwargs=self.flow_kwargs,
+                 flow_matching=self.flow_matching,
+                 periodic_parameters=self.periodic_parameters,
+                 bounded_to_unbounded=self.bounded_to_unbounded,
+                 prior_bounds=self.prior_bounds,
+                 xp=self.xp,
+                 dtype=self.dtype,
+                 device=self.device,
+                 **preconditioning_kwargs,
+             )
+         else:
+             raise ValueError(f"Unknown preconditioning: {preconditioning}")
+
+         sampler = SamplerClass(
+             log_likelihood=self.log_likelihood,
+             log_prior=self.log_prior,
+             dims=self.dims,
+             prior_flow=self.flow,
+             xp=self.xp,
+             dtype=self.dtype,
+             preconditioning_transform=transform,
+             **kwargs,
+         )
+         return sampler
+
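
sample_posterior normally builds the sampler itself, so calling init_sampler directly is only needed for custom workflows; a sketch of the three preconditioning modes accepted above:

    # No preconditioning (what the importance sampler defaults to).
    sampler = aspire.init_sampler("importance", preconditioning="none")

    # Default preconditioning: a CompositeTransform built from the prior
    # bounds and periodic parameters.
    sampler = aspire.init_sampler("minipcn", preconditioning="default")

    # Flow-based preconditioning, with extra options forwarded to the
    # FlowPreconditioningTransform.
    sampler = aspire.init_sampler(
        "smc",
        preconditioning="flow",
        preconditioning_kwargs={"affine_transform": True},
    )
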
+     def sample_posterior(
+         self,
+         n_samples: int = 1000,
+         sampler: str = "importance",
+         xp: Any = None,
+         return_history: bool = False,
+         preconditioning: str | None = None,
+         preconditioning_kwargs: dict | None = None,
+         **kwargs,
+     ) -> Samples:
+         """Draw samples from the posterior distribution.
+
+         If using a sampler that calls an external sampler, e.g.
+         :code:`minipcn`, then keyword arguments for that sampler should be
+         specified in :code:`sampler_kwargs`. For example:
+
+         .. code-block:: python
+
+             aspire = Aspire(...)
+             aspire.sample_posterior(
+                 n_samples=1000,
+                 sampler="minipcn_smc",
+                 adaptive=True,
+                 sampler_kwargs=dict(
+                     n_steps=100,
+                     step_fn="tpcn",
+                 )
+             )
+
+         Parameters
+         ----------
+         n_samples : int
+             The number of samples to draw.
+         sampler : str
+             Sampling algorithm to use for drawing the posterior samples.
+         xp : Any
+             Array API for the final samples.
+         return_history : bool
+             Whether to return the history of the sampler.
+         preconditioning : str
+             Type of preconditioning to apply in the sampler. Options are
+             'default', 'flow', or 'none'. If not specified, the default
+             will depend on the sampler being used. The importance sampler
+             will default to 'none' and the other samplers to 'default'.
+         preconditioning_kwargs : dict
+             Keyword arguments to pass to the preconditioning transform.
+         kwargs : dict
+             Keyword arguments to pass to the sampler. These are split
+             automatically between the init method of the sampler and its
+             sample method based on their signatures.
+
+         Returns
+         -------
+         samples : Samples
+             Samples object containing samples and their corresponding weights.
+         """
+         SamplerClass = self.get_sampler_class(sampler)
+         # Determine sampler initialization parameters
+         # and remove them from kwargs
+         sampler_init_kwargs = signature(SamplerClass.__init__).parameters
+         sampler_kwargs = {
+             k: v
+             for k, v in kwargs.items()
+             if k in sampler_init_kwargs and k != "self"
+         }
+         kwargs = {
+             k: v
+             for k, v in kwargs.items()
+             if k not in sampler_init_kwargs or k == "self"
+         }
+
+         self._sampler = self.init_sampler(
+             sampler,
+             preconditioning=preconditioning,
+             preconditioning_kwargs=preconditioning_kwargs,
+             **sampler_kwargs,
+         )
+         samples = self._sampler.sample(n_samples, **kwargs)
+         if xp is not None:
+             samples = samples.to_namespace(xp)
+         samples.parameters = self.parameters
+         logger.info(f"Sampled {len(samples)} samples from the posterior")
+         logger.info(
+             f"Number of likelihood evaluations: {self.n_likelihood_evaluations}"
+         )
+         logger.info("Sample summary:")
+         logger.info(samples)
+         if return_history:
+             return samples, self._sampler.history
+         else:
+             return samples
+
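
Because the remaining keyword arguments are split by inspecting the sampler's __init__ signature, init-time and sample-time options can be mixed in a single call. A sketch reusing the option names from the docstring example above (other options are sampler-specific):

    samples, history = aspire.sample_posterior(
        n_samples=2000,
        sampler="minipcn_smc",
        return_history=True,
        # Passed through to the external minipcn sampler, as in the
        # docstring example above.
        sampler_kwargs=dict(n_steps=100, step_fn="tpcn"),
    )
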
+     def enable_pool(self, pool: mp.Pool, **kwargs):
+         """Context manager to temporarily replace the log_likelihood method
+         with a version that uses a multiprocessing pool to parallelize
+         computation.
+
+         Parameters
+         ----------
+         pool : multiprocessing.Pool
+             The pool to use for parallel computation.
+         """
+         from .utils import PoolHandler
+
+         return PoolHandler(self, pool, **kwargs)
+
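
A usage sketch for the pool handler; the with-statement form follows from the docstring describing enable_pool as a context manager, and it assumes the log-likelihood function is picklable:

    import multiprocessing as mp

    if __name__ == "__main__":
        with mp.Pool(4) as pool:
            # Inside this block the log likelihood is evaluated via the pool.
            with aspire.enable_pool(pool):
                samples = aspire.sample_posterior(n_samples=1000)
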
+     def config_dict(
+         self, include_sampler_config: bool = True, **kwargs
+     ) -> dict:
+         """Return a dictionary with the configuration of the aspire object.
+
+         Parameters
+         ----------
+         include_sampler_config : bool
+             Whether to include the configuration of the sampler. Default is
+             True.
+         kwargs : dict
+             Additional keyword arguments to pass to the :py:meth:`config_dict`
+             method of the sampler.
+         """
+         config = {
+             "log_likelihood": self.log_likelihood.__name__,
+             "log_prior": self.log_prior.__name__,
+             "dims": self.dims,
+             "parameters": self.parameters,
+             "periodic_parameters": self.periodic_parameters,
+             "prior_bounds": self.prior_bounds,
+             "bounded_to_unbounded": self.bounded_to_unbounded,
+             "bounded_transform": self.bounded_transform,
+             "flow_matching": self.flow_matching,
+             "device": self.device,
+             "xp": self.xp.__name__ if self.xp else None,
+             "flow_backend": self.flow_backend,
+             "flow_kwargs": self.flow_kwargs,
+             "eps": self.eps,
+         }
+         if include_sampler_config:
+             config["sampler_config"] = self.sampler.config_dict(**kwargs)
+         return config
+
+     def save_config(
+         self, h5_file: h5py.File, path="aspire_config", **kwargs
+     ) -> None:
+         """Save the configuration to an HDF5 file.
+
+         Parameters
+         ----------
+         h5_file : h5py.File
+             The HDF5 file to save the configuration to.
+         path : str
+             The path in the HDF5 file to save the configuration to.
+         kwargs : dict
+             Additional keyword arguments to pass to the :py:meth:`config_dict`
+             method.
+         """
+         recursively_save_to_h5_file(
+             h5_file,
+             path,
+             self.config_dict(**kwargs),
+         )
+
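
A sketch of persisting a run to a single HDF5 file using save_config together with save_flow (defined below); the file name and the choice to skip the sampler block are arbitrary:

    import h5py

    with h5py.File("aspire_result.h5", "w") as f:
        # include_sampler_config is forwarded to config_dict; skip the
        # sampler block if sample_posterior has not been run yet.
        aspire.save_config(f, include_sampler_config=False)
        aspire.save_flow(f, path="flow")
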
+     def save_flow(self, h5_file: h5py.File, path="flow") -> None:
+         """Save the flow to an HDF5 file.
+
+         Parameters
+         ----------
+         h5_file : h5py.File
+             The HDF5 file to save the flow to.
+         path : str
+             The path in the HDF5 file to save the flow to.
+         """
+         if self.flow is None:
+             raise ValueError("Flow has not been initialized.")
+         self.flow.save(h5_file, path=path)
+
+     def load_flow(self, h5_file: h5py.File, path="flow") -> None:
+         """Load the flow from an HDF5 file.
+
+         Parameters
+         ----------
+         h5_file : h5py.File
+             The HDF5 file to load the flow from.
+         path : str
+             The path in the HDF5 file to load the flow from.
+         """
+         FlowClass, xp = get_flow_wrapper(
+             backend=self.flow_backend, flow_matching=self.flow_matching
+         )
+         self._flow = FlowClass.load(h5_file, path=path)
+
+     def save_config_to_json(self, filename: str) -> None:
+         """Save the configuration to a JSON file."""
+         import json
+
+         with open(filename, "w") as f:
+             json.dump(self.config_dict(), f, indent=4)
+
+     def sample_flow(self, n_samples: int = 1, xp=None) -> Samples:
+         """Sample from the flow directly.
+
+         Includes the data transform, but does not compute
+         log likelihood or log prior.
+         """
+         if self.flow is None:
+             self.init_flow()
+         x, log_q = self.flow.sample_and_log_prob(n_samples)
+         samples = Samples(x=x, log_q=log_q, xp=xp, parameters=self.parameters)
+         return samples
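
A short sketch tying load_flow and sample_flow together; the file name mirrors the earlier save sketch, and the draws carry only log_q values (no likelihood or prior evaluation, as noted in the docstring):

    import h5py

    with h5py.File("aspire_result.h5", "r") as f:
        aspire.load_flow(f, path="flow")

    # Draw directly from the (data-transformed) flow.
    flow_samples = aspire.sample_flow(n_samples=1000)
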
@@ -0,0 +1,40 @@
+ def get_flow_wrapper(backend: str = "zuko", flow_matching: bool = False):
+     """Get the wrapper for the flow implementation."""
+     if backend == "zuko":
+         import array_api_compat.torch as torch_api
+
+         from .torch.flows import ZukoFlow, ZukoFlowMatching
+
+         if flow_matching:
+             return ZukoFlowMatching, torch_api
+         else:
+             return ZukoFlow, torch_api
+     elif backend == "flowjax":
+         import jax.numpy as jnp
+
+         from .jax.flows import FlowJax
+
+         if flow_matching:
+             raise NotImplementedError(
+                 "Flow matching not implemented for JAX backend"
+             )
+         return FlowJax, jnp
+     else:
+         from importlib.metadata import entry_points
+
+         eps = {
+             ep.name.lower(): ep
+             for ep in entry_points().get("aspire.flows", [])
+         }
+         if backend in eps:
+             FlowClass = eps[backend].load()
+             xp = getattr(FlowClass, "xp", None)
+             if xp is None:
+                 raise ValueError(
+                     f"Flow class {backend} does not define an `xp` attribute"
+                 )
+             return FlowClass, xp
+         else:
+             raise ValueError(
+                 f"Unknown flow class: {backend}. Available classes: {list(eps.keys())}"
+             )
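
The fallback branch resolves unknown backends through the "aspire.flows" entry-point group and requires the loaded class to expose an xp attribute. A sketch of the call as used by Aspire.init_flow (the import path follows from "from .flows import get_flow_wrapper" in aspire.py):

    from aspire.flows import get_flow_wrapper

    # Built-in backends ('zuko', 'flowjax') resolve directly; any other name
    # is looked up in the "aspire.flows" entry-point group.
    FlowClass, xp = get_flow_wrapper(backend="zuko", flow_matching=False)
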
aspire/flows/base.py ADDED
@@ -0,0 +1,84 @@
+ import inspect
+ import logging
+ from typing import Any
+
+ from ..history import FlowHistory
+ from ..transforms import BaseTransform, IdentityTransform
+
+ logger = logging.getLogger(__name__)
+
+
+ class Flow:
+     xp = None  # type: Any
+
+     def __init__(
+         self,
+         dims: int,
+         device: Any,
+         data_transform: BaseTransform | None = None,
+     ):
+         self.dims = dims
+         self.device = device
+
+         if data_transform is None:
+             data_transform = IdentityTransform(self.xp)
+             logger.info("No data_transform provided, using IdentityTransform.")
+
+         self.data_transform = data_transform
+
+     def log_prob(self, x):
+         raise NotImplementedError
+
+     def sample(self, x):
+         raise NotImplementedError
+
+     def sample_and_log_prob(self, n_samples):
+         raise NotImplementedError
+
+     def fit(self, samples, **kwargs) -> FlowHistory:
+         raise NotImplementedError
+
+     def fit_data_transform(self, x):
+         return self.data_transform.fit(x)
+
+     def rescale(self, x):
+         return self.data_transform.forward(x)
+
+     def inverse_rescale(self, x):
+         return self.data_transform.inverse(x)
+
+     def config_dict(self):
+         """Return a dictionary of the configuration of the flow.
+
+         This can be used to recreate the flow by passing the dictionary
+         as keyword arguments to the constructor.
+
+         This is automatically populated with the arguments passed to the
+         constructor.
+
+         Returns
+         -------
+         config : dict
+             The configuration dictionary.
+         """
+         return getattr(self, "_init_args", {})
+
+     def save(self, h5_file, path="flow"):
+         raise NotImplementedError
+
+     @classmethod
+     def load(cls, h5_file, path="flow"):
+         raise NotImplementedError
+
+     def __new__(cls, *args, **kwargs):
+         # Create instance
+         obj = super().__new__(cls)
+         # Inspect the subclass's __init__ signature
+         sig = inspect.signature(cls.__init__)
+         bound = sig.bind_partial(obj, *args, **kwargs)
+         bound.apply_defaults()
+         # Save args (excluding self)
+         obj._init_args = {
+             k: v for k, v in bound.arguments.items() if k != "self"
+         }
+         return obj
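
Because __new__ records the bound constructor arguments, every subclass gets config_dict for free. A hypothetical subclass as a sketch (the name and hidden_units argument are made up, and the default IdentityTransform path is assumed to accept the class-level xp):

    class MyFlow(Flow):
        # Illustrative stub only: a real wrapper would also build the
        # underlying flow and set the `xp` class attribute.
        def __init__(self, dims, device, data_transform=None, hidden_units=64):
            super().__init__(dims, device, data_transform=data_transform)
            self.hidden_units = hidden_units

    flow = MyFlow(2, "cpu", hidden_units=128)
    # -> {'dims': 2, 'device': 'cpu', 'data_transform': None, 'hidden_units': 128}
    print(flow.config_dict())
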
@@ -0,0 +1,3 @@
+ from .utils import get_flow
+
+ __all__ = ["get_flow"]