aspire-inference 0.1.0a9__py3-none-any.whl → 0.1.0a10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aspire/samplers/base.py CHANGED
@@ -4,7 +4,7 @@ from typing import Any, Callable
4
4
  from ..flows.base import Flow
5
5
  from ..samples import Samples
6
6
  from ..transforms import IdentityTransform
7
- from ..utils import track_calls
7
+ from ..utils import asarray, track_calls
8
8
 
9
9
  logger = logging.getLogger(__name__)
10
10
 
@@ -56,7 +56,11 @@ class Sampler:
56
56
 
57
57
  def fit_preconditioning_transform(self, x):
58
58
  """Fit the data transform to the data."""
59
- x = self.preconditioning_transform.xp.asarray(x)
59
+ x = asarray(
60
+ x,
61
+ xp=self.preconditioning_transform.xp,
62
+ dtype=self.preconditioning_transform.dtype,
63
+ )
60
64
  return self.preconditioning_transform.fit(x)
61
65
 
62
66
  @track_calls
@@ -3,7 +3,11 @@ from functools import partial
3
3
  import numpy as np
4
4
 
5
5
  from ...samples import SMCSamples
6
- from ...utils import to_numpy, track_calls
6
+ from ...utils import (
7
+ asarray,
8
+ determine_backend_name,
9
+ track_calls,
10
+ )
7
11
  from .base import NumpySMCSampler
8
12
 
9
13
 
@@ -13,7 +17,7 @@ class MiniPCNSMC(NumpySMCSampler):
13
17
  rng = None
14
18
 
15
19
  def log_prob(self, x, beta=None):
16
- return to_numpy(super().log_prob(x, beta))
20
+ return super().log_prob(x, beta)
17
21
 
18
22
  @track_calls
19
23
  def sample(
@@ -29,11 +33,14 @@ class MiniPCNSMC(NumpySMCSampler):
29
33
  sampler_kwargs: dict | None = None,
30
34
  rng: np.random.Generator | None = None,
31
35
  ):
36
+ from orng import ArrayRNG
37
+
32
38
  self.sampler_kwargs = sampler_kwargs or {}
33
39
  self.sampler_kwargs.setdefault("n_steps", 5 * self.dims)
34
40
  self.sampler_kwargs.setdefault("target_acceptance_rate", 0.234)
35
41
  self.sampler_kwargs.setdefault("step_fn", "tpcn")
36
- self.rng = rng or np.random.default_rng()
42
+ self.backend_str = determine_backend_name(xp=self.xp)
43
+ self.rng = rng or ArrayRNG(backend=self.backend_str)
37
44
  return super().sample(
38
45
  n_samples,
39
46
  n_steps=n_steps,
@@ -58,9 +65,14 @@ class MiniPCNSMC(NumpySMCSampler):
58
65
  target_acceptance_rate=self.sampler_kwargs[
59
66
  "target_acceptance_rate"
60
67
  ],
68
+ xp=self.xp,
61
69
  )
62
70
  # Map to transformed dimension for sampling
63
- z = to_numpy(self.fit_preconditioning_transform(particles.x))
71
+ z = asarray(
72
+ self.fit_preconditioning_transform(particles.x),
73
+ xp=self.xp,
74
+ dtype=self.dtype,
75
+ )
64
76
  chain, history = sampler.sample(
65
77
  z,
66
78
  n_steps=n_steps or self.sampler_kwargs["n_steps"],
aspire/samples.py CHANGED
@@ -425,19 +425,23 @@ class Samples(BaseSamples):
425
425
 
426
426
  def to_namespace(self, xp):
427
427
  return self.__class__(
428
- x=asarray(self.x, xp),
428
+ x=asarray(self.x, xp, dtype=self.dtype),
429
429
  parameters=self.parameters,
430
- log_likelihood=asarray(self.log_likelihood, xp)
430
+ log_likelihood=asarray(self.log_likelihood, xp, dtype=self.dtype)
431
431
  if self.log_likelihood is not None
432
432
  else None,
433
- log_prior=asarray(self.log_prior, xp)
433
+ log_prior=asarray(self.log_prior, xp, dtype=self.dtype)
434
434
  if self.log_prior is not None
435
435
  else None,
436
- log_q=asarray(self.log_q, xp) if self.log_q is not None else None,
437
- log_evidence=asarray(self.log_evidence, xp)
436
+ log_q=asarray(self.log_q, xp, dtype=self.dtype)
437
+ if self.log_q is not None
438
+ else None,
439
+ log_evidence=asarray(self.log_evidence, xp, dtype=self.dtype)
438
440
  if self.log_evidence is not None
439
441
  else None,
440
- log_evidence_error=asarray(self.log_evidence_error, xp)
442
+ log_evidence_error=asarray(
443
+ self.log_evidence_error, xp, dtype=self.dtype
444
+ )
441
445
  if self.log_evidence_error is not None
442
446
  else None,
443
447
  )
aspire/utils.py CHANGED
@@ -12,7 +12,13 @@ import h5py
12
12
  import wrapt
13
13
  from array_api_compat import (
14
14
  array_namespace,
15
+ is_cupy_namespace,
16
+ is_dask_namespace,
15
17
  is_jax_array,
18
+ is_jax_namespace,
19
+ is_ndonnx_namespace,
20
+ is_numpy_namespace,
21
+ is_pydata_sparse_namespace,
16
22
  is_torch_array,
17
23
  is_torch_namespace,
18
24
  to_device,
@@ -28,6 +34,17 @@ if TYPE_CHECKING:
28
34
  logger = logging.getLogger(__name__)
29
35
 
30
36
 
37
+ IS_NAMESPACE_FUNCTIONS = {
38
+ "numpy": is_numpy_namespace,
39
+ "torch": is_torch_namespace,
40
+ "jax": is_jax_namespace,
41
+ "cupy": is_cupy_namespace,
42
+ "dask": is_dask_namespace,
43
+ "pydata_sparse": is_pydata_sparse_namespace,
44
+ "ndonnx": is_ndonnx_namespace,
45
+ }
46
+
47
+
31
48
  def configure_logger(
32
49
  log_level: str | int = "INFO",
33
50
  additional_loggers: list[str] = None,
@@ -234,7 +251,7 @@ def to_numpy(x: Array, **kwargs) -> np.ndarray:
234
251
  return np.asarray(x, **kwargs)
235
252
 
236
253
 
237
- def asarray(x, xp: Any = None, **kwargs) -> Array:
254
+ def asarray(x, xp: Any = None, dtype: Any | None = None, **kwargs) -> Array:
238
255
  """Convert an array to the specified array API.
239
256
 
240
257
  Parameters
@@ -244,13 +261,51 @@ def asarray(x, xp: Any = None, **kwargs) -> Array:
244
261
  xp : Any
245
262
  The array API to use for the conversion. If None, the array API
246
263
  is inferred from the input array.
264
+ dtype : Any | str | None
265
+ The dtype to use for the conversion. If None, the dtype is not changed.
247
266
  kwargs : dict
248
267
  Additional keyword arguments to pass to xp.asarray.
249
268
  """
269
+ # Handle DLPack conversion from JAX to PyTorch to avoid shape issues when
270
+ # passing JAX arrays directly to torch.asarray.
250
271
  if is_jax_array(x) and is_torch_namespace(xp):
251
- return xp.utils.dlpack.from_dlpack(x)
252
- else:
253
- return xp.asarray(x, **kwargs)
272
+ tensor = xp.utils.dlpack.from_dlpack(x)
273
+ if dtype is not None:
274
+ tensor = tensor.to(resolve_dtype(dtype, xp=xp))
275
+ return tensor
276
+
277
+ if dtype is not None:
278
+ kwargs["dtype"] = resolve_dtype(dtype, xp=xp)
279
+ return xp.asarray(x, **kwargs)
280
+
281
+
282
+ def determine_backend_name(
283
+ x: Array | None = None, xp: Any | None = None
284
+ ) -> str:
285
+ """Determine the backend name from an array or array API module.
286
+
287
+ Parameters
288
+ ----------
289
+ x : Array or None
290
+ The array to infer the backend from. If None, xp must be provided.
291
+ xp : Any or None
292
+ The array API module to infer the backend from. If None, x must be provided.
293
+
294
+ Returns
295
+ -------
296
+ str
297
+ The name of the backend. If the backend cannot be determined, returns "unknown".
298
+ """
299
+ if x is not None:
300
+ xp = array_namespace(x)
301
+ if xp is None:
302
+ raise ValueError(
303
+ "Either x or xp must be provided to determine backend."
304
+ )
305
+ for name, is_namespace_fn in IS_NAMESPACE_FUNCTIONS.items():
306
+ if is_namespace_fn(xp):
307
+ return name
308
+ return "unknown"
254
309
 
255
310
 
256
311
  def resolve_dtype(dtype: Any | str | None, xp: Any) -> Any | None:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: aspire-inference
3
- Version: 0.1.0a9
3
+ Version: 0.1.0a10
4
4
  Summary: Accelerate Sequential Posterior Inference via REuse
5
5
  Author-email: "Michael J. Williams" <michaeljw1@googlemail.com>
6
6
  License: MIT
@@ -26,7 +26,8 @@ Requires-Dist: torch; extra == "torch"
26
26
  Requires-Dist: zuko; extra == "torch"
27
27
  Requires-Dist: tqdm; extra == "torch"
28
28
  Provides-Extra: minipcn
29
- Requires-Dist: minipcn; extra == "minipcn"
29
+ Requires-Dist: minipcn[array-api]>=0.2.0a3; extra == "minipcn"
30
+ Requires-Dist: orng; extra == "minipcn"
30
31
  Provides-Extra: emcee
31
32
  Requires-Dist: emcee; extra == "emcee"
32
33
  Provides-Extra: blackjax
@@ -49,10 +50,28 @@ aspire is a framework for reusing existing posterior samples to obtain new resul
49
50
 
50
51
  ## Installation
51
52
 
52
- aspire can be installed from PyPI using `pip`
53
+ aspire can be installed from PyPI using `pip`. You also need to install
54
+ one of the backends for the normalizing flows, either `torch` or `jax`.
55
+ We also recommend installing `minipcn` if using the `smc` sampler:
56
+
57
+
58
+ **Torch**
59
+
60
+ We recommend installing `torch` manually to ensure correct CPU/CUDA versions are
61
+ installed. See the [PyTorch installation instructions](https://pytorch.org/)
62
+ for more details.
63
+
64
+ ```
65
+ pip install aspire-inference[torch,minipcn]
66
+ ```
67
+
68
+ **JAX**
69
+
70
+ We recommend installing `jax` manually to ensure the correct GPU/CUDA versions
71
+ are installed. See the [JAX installation instructions](https://docs.jax.dev/en/latest/installation.html) for more details.
53
72
 
54
73
  ```
55
- pip install aspire-inference
74
+ pip install aspire-inference[jax,minipcn]
56
75
  ```
57
76
 
58
77
  **Important:** the name of `aspire` on PyPI is `aspire-inference` but once installed
@@ -2,9 +2,9 @@ aspire/__init__.py,sha256=B2IETvlpB0oBh57prRYLdi8jB5yFGw8qVviGdf1NcnE,409
2
2
  aspire/aspire.py,sha256=M5o-QxLthE_5daa1trgUfWxPz-g4rmpEUKimKosw4lw,17400
3
3
  aspire/history.py,sha256=l_j-riZKbTWK7Wz9zvvD_mTk9psNCKItiveYhr_pYv8,4313
4
4
  aspire/plot.py,sha256=oXwUDOb_953_ADm2KLk41JIfpE3JeiiQiSYKvUVwLqw,1423
5
- aspire/samples.py,sha256=nVQULOr19lQVyGhitI8EgDdCDGx9sExPomQBJrV4rTc,19237
5
+ aspire/samples.py,sha256=z5x5hpWuVFH1hYhltmROAe8pbWxGD2UvHi3vcc132dg,19399
6
6
  aspire/transforms.py,sha256=CHrfPQHEyHQ9I0WWiAgUplWwxYypZsD7uCYIHUbSFtY,24974
7
- aspire/utils.py,sha256=pj8O0chqfP6VS8bpW0wCw8W0P5JNQKvWRz1Rg9AYIhg,22525
7
+ aspire/utils.py,sha256=eKGmchpuoNL15Xbu-AGoeZ00PcQEykEQiDZMnnRyV6A,24234
8
8
  aspire/flows/__init__.py,sha256=GUZToPVNJoTwULpbeW10UijfQukNrILoAQ_ubeq7G3w,2110
9
9
  aspire/flows/base.py,sha256=5UWKAiXDXLJ6Sg6a380ajLrGFaZSQyOnFEihQiiA4ko,2237
10
10
  aspire/flows/jax/__init__.py,sha256=7cmiY_MbEC8RDA8Cmi8HVnNJm0sqFKlBsDethdsy5lA,52
@@ -13,16 +13,16 @@ aspire/flows/jax/utils.py,sha256=5T6UrgpARG9VywC9qmTl45LjyZWuEdkW3XUladE6xJE,151
13
13
  aspire/flows/torch/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
14
  aspire/flows/torch/flows.py,sha256=0_YkiMT49QolyQnEFsh28tfKLnURVF0Z6aTnaWLIUDI,11672
15
15
  aspire/samplers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
16
- aspire/samplers/base.py,sha256=8slvgOBnacUrHXCVDAqo-3IZ_LB7-dS8wdMP55MI43Y,2907
16
+ aspire/samplers/base.py,sha256=VEHawyVA33jXHMo63p5hBHkp9k2qxU_bOxh5iaZSXew,3011
17
17
  aspire/samplers/importance.py,sha256=opn_jY-V8snUz0JztLBtnaTT3WfrZ5OSpHBV5WAuM3M,676
18
18
  aspire/samplers/mcmc.py,sha256=ihHgzqvSoy1oxdFBjyqNUbCuRX7CqWjlshCUZcgEL5E,5151
19
19
  aspire/samplers/smc/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
20
20
  aspire/samplers/smc/base.py,sha256=66f_ORUvcKRqMIW35qjhUc-c0PFuY87lJa91MpSaTZI,10729
21
21
  aspire/samplers/smc/blackjax.py,sha256=IcTguAETiPmgFofmVW2GN40P5HBIxkmyd2VR8AU8f4k,12115
22
22
  aspire/samplers/smc/emcee.py,sha256=Wm0vvAlCcRhJMBt7_fU2ZnjDb8SN8jgUOTXLzNstRpA,2516
23
- aspire/samplers/smc/minipcn.py,sha256=ju1gcgyKHjodLEACPdL3eXA9ai8ZJ9_LwitD_Gmf1Rc,2765
24
- aspire_inference-0.1.0a9.dist-info/licenses/LICENSE,sha256=DN-eRtBfS9dZyT0Ds0Mdn2Y4nb-ZQ7h71vpASYBm5k4,1076
25
- aspire_inference-0.1.0a9.dist-info/METADATA,sha256=DCcTGComTNt9UFlbUwGPf58itBX6Y-1-edVys0pb2RQ,3187
26
- aspire_inference-0.1.0a9.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
27
- aspire_inference-0.1.0a9.dist-info/top_level.txt,sha256=9FRIYEl2xwVFG7jSOBHsElHQ0y3_4fq01Cf4_OyMQn8,7
28
- aspire_inference-0.1.0a9.dist-info/RECORD,,
23
+ aspire/samplers/smc/minipcn.py,sha256=iQUBBwHZ_D4CxNjARMngklRvx6yTlEDKdeidyYCgqM4,3003
24
+ aspire_inference-0.1.0a10.dist-info/licenses/LICENSE,sha256=DN-eRtBfS9dZyT0Ds0Mdn2Y4nb-ZQ7h71vpASYBm5k4,1076
25
+ aspire_inference-0.1.0a10.dist-info/METADATA,sha256=SdvlXKjQn0uJQPpWzoZKAH3oMJSKpZnvlUzxPsIwNlY,3869
26
+ aspire_inference-0.1.0a10.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
27
+ aspire_inference-0.1.0a10.dist-info/top_level.txt,sha256=9FRIYEl2xwVFG7jSOBHsElHQ0y3_4fq01Cf4_OyMQn8,7
28
+ aspire_inference-0.1.0a10.dist-info/RECORD,,