aspire-inference 0.1.0a9__tar.gz → 0.1.0a10__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/PKG-INFO +23 -4
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/README.md +20 -2
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/aspire_inference.egg-info/PKG-INFO +23 -4
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/aspire_inference.egg-info/requires.txt +2 -1
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/docs/conf.py +1 -1
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/docs/index.rst +8 -0
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/docs/installation.rst +5 -1
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/pyproject.toml +2 -1
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/src/aspire/samplers/base.py +6 -2
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/src/aspire/samplers/smc/minipcn.py +16 -4
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/src/aspire/samples.py +10 -6
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/src/aspire/utils.py +59 -4
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/.github/workflows/lint.yml +0 -0
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/.github/workflows/publish.yml +0 -0
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/.github/workflows/tests.yml +0 -0
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/.gitignore +0 -0
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/.pre-commit-config.yaml +0 -0
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/LICENSE +0 -0
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/aspire_inference.egg-info/SOURCES.txt +0 -0
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/aspire_inference.egg-info/dependency_links.txt +0 -0
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/aspire_inference.egg-info/top_level.txt +0 -0
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/docs/Makefile +0 -0
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/docs/entry_points.rst +0 -0
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/docs/examples.rst +0 -0
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/docs/multiprocessing.rst +0 -0
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/docs/recipes.rst +0 -0
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/docs/requirements.txt +0 -0
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/docs/user_guide.rst +0 -0
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/examples/basic_example.py +0 -0
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/examples/blackjax_smc_example.py +0 -0
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/examples/smc_example.py +0 -0
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/readthedocs.yml +0 -0
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/setup.cfg +0 -0
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/src/aspire/__init__.py +0 -0
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/src/aspire/aspire.py +0 -0
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/src/aspire/flows/__init__.py +0 -0
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/src/aspire/flows/base.py +0 -0
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/src/aspire/flows/jax/__init__.py +0 -0
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/src/aspire/flows/jax/flows.py +0 -0
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/src/aspire/flows/jax/utils.py +0 -0
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/src/aspire/flows/torch/__init__.py +0 -0
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/src/aspire/flows/torch/flows.py +0 -0
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/src/aspire/history.py +0 -0
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/src/aspire/plot.py +0 -0
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/src/aspire/samplers/__init__.py +0 -0
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/src/aspire/samplers/importance.py +0 -0
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/src/aspire/samplers/mcmc.py +0 -0
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/src/aspire/samplers/smc/__init__.py +0 -0
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/src/aspire/samplers/smc/base.py +0 -0
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/src/aspire/samplers/smc/blackjax.py +0 -0
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/src/aspire/samplers/smc/emcee.py +0 -0
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/src/aspire/transforms.py +0 -0
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/tests/conftest.py +0 -0
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/tests/integration_tests/conftest.py +0 -0
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/tests/integration_tests/test_integration.py +0 -0
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/tests/test_flows/test_flows_core.py +0 -0
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/tests/test_flows/test_jax_flows/test_flowjax_flows.py +0 -0
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/tests/test_flows/test_torch_flows/test_zuko_flows.py +0 -0
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/tests/test_samples.py +0 -0
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/tests/test_transforms.py +0 -0
- {aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/tests/test_utils.py +0 -0
{aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/PKG-INFO

````diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: aspire-inference
-Version: 0.1.0a9
+Version: 0.1.0a10
 Summary: Accelerate Sequential Posterior Inference via REuse
 Author-email: "Michael J. Williams" <michaeljw1@googlemail.com>
 License: MIT
@@ -26,7 +26,8 @@ Requires-Dist: torch; extra == "torch"
 Requires-Dist: zuko; extra == "torch"
 Requires-Dist: tqdm; extra == "torch"
 Provides-Extra: minipcn
-Requires-Dist: minipcn; extra == "minipcn"
+Requires-Dist: minipcn[array-api]>=0.2.0a3; extra == "minipcn"
+Requires-Dist: orng; extra == "minipcn"
 Provides-Extra: emcee
 Requires-Dist: emcee; extra == "emcee"
 Provides-Extra: blackjax
@@ -49,10 +50,28 @@ aspire is a framework for reusing existing posterior samples to obtain new resul
 
 ## Installation
 
-aspire can be installed from PyPI using `pip`
+aspire can be installed from PyPI using `pip`. By default, you need to install
+one of the backends for the normalizing flows, either `torch` or `jax`.
+We also recommend installing `minipcn` if using the `smc` sampler:
+
+
+**Torch**
+
+We recommend installing `torch` manually to ensure correct CPU/CUDA versions are
+installed. See the [PyTorch installation instructions](https://pytorch.org/)
+for more details.
+
+```
+pip install aspire-inference[torch,minipcn]
+```
+
+**Jax**:
+
+We recommend install `jax` manually to ensure the correct GPU/CUDA versions
+are installed. See the [jax documentation for details](https://docs.jax.dev/en/latest/installation.html)
 
 ```
-pip install aspire-inference
+pip install aspire-inference[jax,minipcn]
 ```
 
 **Important:** the name of `aspire` on PyPI is `aspire-inference` but once installed
````
{aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/README.md

````diff
@@ -10,10 +10,28 @@ aspire is a framework for reusing existing posterior samples to obtain new resul
 
 ## Installation
 
-aspire can be installed from PyPI using `pip`
+aspire can be installed from PyPI using `pip`. By default, you need to install
+one of the backends for the normalizing flows, either `torch` or `jax`.
+We also recommend installing `minipcn` if using the `smc` sampler:
+
+
+**Torch**
+
+We recommend installing `torch` manually to ensure correct CPU/CUDA versions are
+installed. See the [PyTorch installation instructions](https://pytorch.org/)
+for more details.
+
+```
+pip install aspire-inference[torch,minipcn]
+```
+
+**Jax**:
+
+We recommend install `jax` manually to ensure the correct GPU/CUDA versions
+are installed. See the [jax documentation for details](https://docs.jax.dev/en/latest/installation.html)
 
 ```
-pip install aspire-inference
+pip install aspire-inference[jax,minipcn]
 ```
 
 **Important:** the name of `aspire` on PyPI is `aspire-inference` but once installed
````
{aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/aspire_inference.egg-info/PKG-INFO

````diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: aspire-inference
-Version: 0.1.0a9
+Version: 0.1.0a10
 Summary: Accelerate Sequential Posterior Inference via REuse
 Author-email: "Michael J. Williams" <michaeljw1@googlemail.com>
 License: MIT
@@ -26,7 +26,8 @@ Requires-Dist: torch; extra == "torch"
 Requires-Dist: zuko; extra == "torch"
 Requires-Dist: tqdm; extra == "torch"
 Provides-Extra: minipcn
-Requires-Dist: minipcn; extra == "minipcn"
+Requires-Dist: minipcn[array-api]>=0.2.0a3; extra == "minipcn"
+Requires-Dist: orng; extra == "minipcn"
 Provides-Extra: emcee
 Requires-Dist: emcee; extra == "emcee"
 Provides-Extra: blackjax
@@ -49,10 +50,28 @@ aspire is a framework for reusing existing posterior samples to obtain new resul
 
 ## Installation
 
-aspire can be installed from PyPI using `pip`
+aspire can be installed from PyPI using `pip`. By default, you need to install
+one of the backends for the normalizing flows, either `torch` or `jax`.
+We also recommend installing `minipcn` if using the `smc` sampler:
+
+
+**Torch**
+
+We recommend installing `torch` manually to ensure correct CPU/CUDA versions are
+installed. See the [PyTorch installation instructions](https://pytorch.org/)
+for more details.
+
+```
+pip install aspire-inference[torch,minipcn]
+```
+
+**Jax**:
+
+We recommend install `jax` manually to ensure the correct GPU/CUDA versions
+are installed. See the [jax documentation for details](https://docs.jax.dev/en/latest/installation.html)
 
 ```
-pip install aspire-inference
+pip install aspire-inference[jax,minipcn]
 ```
 
 **Important:** the name of `aspire` on PyPI is `aspire-inference` but once installed
````
{aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/docs/index.rst

````diff
@@ -61,3 +61,11 @@ examples, and the complete API reference.
    multiprocessing
    examples
    entry_points
+   API Reference </autoapi/aspire/index>
+
+
+.. toctree::
+   :maxdepth: 2
+   :caption: Related Packages
+
+   aspire-bilby <https://aspire.readthedocs.io/projects/aspire-bilby/en/latest/>
````
{aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/docs/installation.rst

````diff
@@ -14,7 +14,11 @@ Install the library from PyPI (note the published name):
 
    $ python -m pip install aspire-inference
 
-The installed distribution exposes the ``aspire`` import namespace.
+The installed distribution exposes the ``aspire`` import namespace. By default,
+this doesn't include any optional dependencies beyond the core ones listed above.
+We recommend installing with at least one backend for normalizing flows, e.g.
+``torch`` (PyTorch + ``zuko``) or ``jax`` (JAX + ``flowjax``).
+and optionally the ``minipcn`` SMC kernel.
 
 Optional extras
 ---------------
````
{aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/src/aspire/samplers/base.py

````diff
@@ -4,7 +4,7 @@ from typing import Any, Callable
 from ..flows.base import Flow
 from ..samples import Samples
 from ..transforms import IdentityTransform
-from ..utils import track_calls
+from ..utils import asarray, track_calls
 
 logger = logging.getLogger(__name__)
 
@@ -56,7 +56,11 @@ class Sampler:
 
     def fit_preconditioning_transform(self, x):
         """Fit the data transform to the data."""
-        x =
+        x = asarray(
+            x,
+            xp=self.preconditioning_transform.xp,
+            dtype=self.preconditioning_transform.dtype,
+        )
         return self.preconditioning_transform.fit(x)
 
     @track_calls
````
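In practice this means `fit_preconditioning_transform` now coerces its input into the transform's own array namespace and dtype before fitting. A minimal sketch of that contract, using a hypothetical stand-in transform and assuming `resolve_dtype` accepts dtype names given as strings (as its signature suggests):

```python
# Illustration only: a stand-in with the `xp` and `dtype` attributes that
# fit_preconditioning_transform reads; the real transforms live in aspire.transforms.
import numpy as np

from aspire.utils import asarray


class DummyTransform:
    xp = np            # array namespace the transform works in
    dtype = "float64"  # dtype the transform expects

    def fit(self, x):
        return x.mean(axis=0)


transform = DummyTransform()
x = [[1.0, 2.0], [3.0, 4.0]]  # e.g. a plain nested list from user code
# Same coercion as in the diff: match the transform's namespace and dtype.
x = asarray(x, xp=transform.xp, dtype=transform.dtype)
print(transform.fit(x))  # [2. 3.]
```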
{aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/src/aspire/samplers/smc/minipcn.py

````diff
@@ -3,7 +3,11 @@ from functools import partial
 import numpy as np
 
 from ...samples import SMCSamples
-from ...utils import
+from ...utils import (
+    asarray,
+    determine_backend_name,
+    track_calls,
+)
 from .base import NumpySMCSampler
 
 
@@ -13,7 +17,7 @@ class MiniPCNSMC(NumpySMCSampler):
     rng = None
 
     def log_prob(self, x, beta=None):
-        return
+        return super().log_prob(x, beta)
 
     @track_calls
     def sample(
@@ -29,11 +33,14 @@ class MiniPCNSMC(NumpySMCSampler):
         sampler_kwargs: dict | None = None,
         rng: np.random.Generator | None = None,
     ):
+        from orng import ArrayRNG
+
         self.sampler_kwargs = sampler_kwargs or {}
         self.sampler_kwargs.setdefault("n_steps", 5 * self.dims)
         self.sampler_kwargs.setdefault("target_acceptance_rate", 0.234)
         self.sampler_kwargs.setdefault("step_fn", "tpcn")
-        self.
+        self.backend_str = determine_backend_name(xp=self.xp)
+        self.rng = rng or ArrayRNG(backend=self.backend_str)
         return super().sample(
             n_samples,
             n_steps=n_steps,
@@ -58,9 +65,14 @@ class MiniPCNSMC(NumpySMCSampler):
            target_acceptance_rate=self.sampler_kwargs[
                "target_acceptance_rate"
            ],
+           xp=self.xp,
        )
        # Map to transformed dimension for sampling
-       z =
+       z = asarray(
+           self.fit_preconditioning_transform(particles.x),
+           xp=self.xp,
+           dtype=self.dtype,
+       )
        chain, history = sampler.sample(
            z,
            n_steps=n_steps or self.sampler_kwargs["n_steps"],
````
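The upshot for the minipcn kernel: instead of assuming a NumPy `Generator`, it now derives a backend name from the sampler's array namespace and builds an `orng.ArrayRNG` for it. A small sketch of that wiring, assuming only what the diff shows about `ArrayRNG` (that it accepts a `backend` keyword):

```python
# Sketch of the new RNG wiring, exercised with NumPy as the array namespace.
import numpy as np

from aspire.utils import determine_backend_name
from orng import ArrayRNG

xp = np                                      # stands in for the sampler's self.xp
backend_str = determine_backend_name(xp=xp)  # -> "numpy"
rng = ArrayRNG(backend=backend_str)          # replaces a bare np.random.Generator
print(backend_str)
```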
{aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/src/aspire/samples.py

````diff
@@ -425,19 +425,23 @@ class Samples(BaseSamples):
 
     def to_namespace(self, xp):
         return self.__class__(
-            x=asarray(self.x, xp),
+            x=asarray(self.x, xp, dtype=self.dtype),
             parameters=self.parameters,
-            log_likelihood=asarray(self.log_likelihood, xp)
+            log_likelihood=asarray(self.log_likelihood, xp, dtype=self.dtype)
             if self.log_likelihood is not None
             else None,
-            log_prior=asarray(self.log_prior, xp)
+            log_prior=asarray(self.log_prior, xp, dtype=self.dtype)
             if self.log_prior is not None
             else None,
-            log_q=asarray(self.log_q, xp
-
+            log_q=asarray(self.log_q, xp, dtype=self.dtype)
+            if self.log_q is not None
+            else None,
+            log_evidence=asarray(self.log_evidence, xp, dtype=self.dtype)
             if self.log_evidence is not None
             else None,
-            log_evidence_error=asarray(
+            log_evidence_error=asarray(
+                self.log_evidence_error, xp, dtype=self.dtype
+            )
             if self.log_evidence_error is not None
             else None,
         )
````
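The effect of this change is that `to_namespace` now pins every converted array to the samples' existing dtype rather than the target backend's default. A rough usage sketch; the `Samples` constructor arguments shown here are inferred from the diff, and the conversion assumes `resolve_dtype` maps NumPy dtypes to their torch equivalents:

```python
# Rough sketch only; constructor arguments are assumptions based on the diff.
import numpy as np
import torch

from aspire.samples import Samples

samples = Samples(
    x=np.random.randn(100, 2).astype(np.float32),
    parameters=["a", "b"],
)
# Arrays are converted via asarray(..., xp, dtype=self.dtype) as in the diff.
converted = samples.to_namespace(torch)
print(type(converted.x), converted.x.dtype)
```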
{aspire_inference-0.1.0a9 → aspire_inference-0.1.0a10}/src/aspire/utils.py

````diff
@@ -12,7 +12,13 @@ import h5py
 import wrapt
 from array_api_compat import (
     array_namespace,
+    is_cupy_namespace,
+    is_dask_namespace,
     is_jax_array,
+    is_jax_namespace,
+    is_ndonnx_namespace,
+    is_numpy_namespace,
+    is_pydata_sparse_namespace,
     is_torch_array,
     is_torch_namespace,
     to_device,
@@ -28,6 +34,17 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+IS_NAMESPACE_FUNCTIONS = {
+    "numpy": is_numpy_namespace,
+    "torch": is_torch_namespace,
+    "jax": is_jax_namespace,
+    "cupy": is_cupy_namespace,
+    "dask": is_dask_namespace,
+    "pydata_sparse": is_pydata_sparse_namespace,
+    "ndonnx": is_ndonnx_namespace,
+}
+
+
 def configure_logger(
     log_level: str | int = "INFO",
     additional_loggers: list[str] = None,
@@ -234,7 +251,7 @@ def to_numpy(x: Array, **kwargs) -> np.ndarray:
     return np.asarray(x, **kwargs)
 
 
-def asarray(x, xp: Any = None, **kwargs) -> Array:
+def asarray(x, xp: Any = None, dtype: Any | None = None, **kwargs) -> Array:
     """Convert an array to the specified array API.
 
     Parameters
@@ -244,13 +261,51 @@ def asarray(x, xp: Any = None, **kwargs) -> Array:
     xp : Any
         The array API to use for the conversion. If None, the array API
         is inferred from the input array.
+    dtype : Any | str | None
+        The dtype to use for the conversion. If None, the dtype is not changed.
     kwargs : dict
         Additional keyword arguments to pass to xp.asarray.
     """
+    # Handle DLPack conversion from JAX to PyTorch to avoid shape issues when
+    # passing JAX arrays directly to torch.asarray.
     if is_jax_array(x) and is_torch_namespace(xp):
-
-
-
+        tensor = xp.utils.dlpack.from_dlpack(x)
+        if dtype is not None:
+            tensor = tensor.to(resolve_dtype(dtype, xp=xp))
+        return tensor
+
+    if dtype is not None:
+        kwargs["dtype"] = resolve_dtype(dtype, xp=xp)
+    return xp.asarray(x, **kwargs)
+
+
+def determine_backend_name(
+    x: Array | None = None, xp: Any | None = None
+) -> str:
+    """Determine the backend name from an array or array API module.
+
+    Parameters
+    ----------
+    x : Array or None
+        The array to infer the backend from. If None, xp must be provided.
+    xp : Any or None
+        The array API module to infer the backend from. If None, x must be provided.
+
+    Returns
+    -------
+    str
+        The name of the backend. If the backend cannot be determined, returns "unknown".
+    """
+    if x is not None:
+        xp = array_namespace(x)
+    if xp is None:
+        raise ValueError(
+            "Either x or xp must be provided to determine backend."
+        )
+    for name, is_namespace_fn in IS_NAMESPACE_FUNCTIONS.items():
+        if is_namespace_fn(xp):
+            return name
+    return "unknown"
 
 
 def resolve_dtype(dtype: Any | str | None, xp: Any) -> Any | None:
````
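Taken together, the new helpers behave roughly as below (a sketch exercised with NumPy only, so it runs without the optional torch/jax backends; the string dtype assumes `resolve_dtype` accepts dtype names, as its signature suggests):

```python
# Minimal usage sketch of the helpers added in utils.py.
import numpy as np

from aspire.utils import asarray, determine_backend_name

arr = asarray([1, 2, 3], xp=np, dtype="float64")  # converted via np.asarray
print(arr.dtype)                                  # float64

print(determine_backend_name(x=np.ones(3)))  # "numpy"
print(determine_backend_name(xp=np))         # "numpy"
```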