aspire-inference 0.1.0a11__tar.gz → 0.1.0a13__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/PKG-INFO +4 -2
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/aspire_inference.egg-info/PKG-INFO +4 -2
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/aspire_inference.egg-info/requires.txt +7 -1
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/pyproject.toml +4 -2
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/src/aspire/aspire.py +4 -2
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/src/aspire/samplers/mcmc.py +2 -3
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/src/aspire/samplers/smc/base.py +6 -2
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/src/aspire/utils.py +21 -0
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/tests/test_utils.py +17 -1
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/.github/workflows/lint.yml +0 -0
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/.github/workflows/publish.yml +0 -0
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/.github/workflows/tests.yml +0 -0
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/.gitignore +0 -0
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/.pre-commit-config.yaml +0 -0
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/LICENSE +0 -0
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/README.md +0 -0
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/aspire_inference.egg-info/SOURCES.txt +0 -0
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/aspire_inference.egg-info/dependency_links.txt +0 -0
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/aspire_inference.egg-info/top_level.txt +0 -0
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/docs/Makefile +0 -0
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/docs/checkpointing.rst +0 -0
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/docs/conf.py +0 -0
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/docs/entry_points.rst +0 -0
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/docs/examples.rst +0 -0
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/docs/index.rst +0 -0
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/docs/installation.rst +0 -0
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/docs/multiprocessing.rst +0 -0
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/docs/recipes.rst +0 -0
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/docs/requirements.txt +0 -0
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/docs/user_guide.rst +0 -0
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/examples/basic_example.py +0 -0
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/examples/blackjax_smc_example.py +0 -0
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/examples/smc_example.py +0 -0
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/readthedocs.yml +0 -0
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/setup.cfg +0 -0
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/src/aspire/__init__.py +0 -0
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/src/aspire/flows/__init__.py +0 -0
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/src/aspire/flows/base.py +0 -0
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/src/aspire/flows/jax/__init__.py +0 -0
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/src/aspire/flows/jax/flows.py +0 -0
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/src/aspire/flows/jax/utils.py +0 -0
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/src/aspire/flows/torch/__init__.py +0 -0
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/src/aspire/flows/torch/flows.py +0 -0
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/src/aspire/history.py +0 -0
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/src/aspire/plot.py +0 -0
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/src/aspire/samplers/__init__.py +0 -0
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/src/aspire/samplers/base.py +0 -0
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/src/aspire/samplers/importance.py +0 -0
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/src/aspire/samplers/smc/__init__.py +0 -0
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/src/aspire/samplers/smc/blackjax.py +0 -0
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/src/aspire/samplers/smc/emcee.py +0 -0
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/src/aspire/samplers/smc/minipcn.py +0 -0
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/src/aspire/samples.py +0 -0
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/src/aspire/transforms.py +0 -0
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/tests/conftest.py +0 -0
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/tests/integration_tests/conftest.py +0 -0
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/tests/integration_tests/test_checkpointing.py +0 -0
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/tests/integration_tests/test_integration.py +0 -0
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/tests/test_flows/test_flows_core.py +0 -0
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/tests/test_flows/test_jax_flows/test_flowjax_flows.py +0 -0
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/tests/test_flows/test_torch_flows/test_zuko_flows.py +0 -0
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/tests/test_samples.py +0 -0
- {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/tests/test_transforms.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: aspire-inference
|
|
3
|
-
Version: 0.1.0a11
|
|
3
|
+
Version: 0.1.0a13
|
|
4
4
|
Summary: Accelerate Sequential Posterior Inference via REuse
|
|
5
5
|
Author-email: "Michael J. Williams" <michaeljw1@googlemail.com>
|
|
6
6
|
License: MIT
|
|
@@ -19,7 +19,8 @@ Provides-Extra: scipy
|
|
|
19
19
|
Requires-Dist: scipy; extra == "scipy"
|
|
20
20
|
Provides-Extra: jax
|
|
21
21
|
Requires-Dist: jax; extra == "jax"
|
|
22
|
-
Requires-Dist: jaxlib; extra == "jax"
|
|
22
|
+
Requires-Dist: jaxlib<0.8.2; python_version < "3.12" and extra == "jax"
|
|
23
|
+
Requires-Dist: jaxlib; python_version >= "3.12" and extra == "jax"
|
|
23
24
|
Requires-Dist: flowjax; extra == "jax"
|
|
24
25
|
Provides-Extra: torch
|
|
25
26
|
Requires-Dist: torch; extra == "torch"
|
|
@@ -32,6 +33,7 @@ Provides-Extra: emcee
|
|
|
32
33
|
Requires-Dist: emcee; extra == "emcee"
|
|
33
34
|
Provides-Extra: blackjax
|
|
34
35
|
Requires-Dist: blackjax; extra == "blackjax"
|
|
36
|
+
Requires-Dist: fastprogress<1.1.0; extra == "blackjax"
|
|
35
37
|
Provides-Extra: test
|
|
36
38
|
Requires-Dist: pytest; extra == "test"
|
|
37
39
|
Requires-Dist: pytest-requires; extra == "test"
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: aspire-inference
|
|
3
|
-
Version: 0.1.0a11
|
|
3
|
+
Version: 0.1.0a13
|
|
4
4
|
Summary: Accelerate Sequential Posterior Inference via REuse
|
|
5
5
|
Author-email: "Michael J. Williams" <michaeljw1@googlemail.com>
|
|
6
6
|
License: MIT
|
|
@@ -19,7 +19,8 @@ Provides-Extra: scipy
|
|
|
19
19
|
Requires-Dist: scipy; extra == "scipy"
|
|
20
20
|
Provides-Extra: jax
|
|
21
21
|
Requires-Dist: jax; extra == "jax"
|
|
22
|
-
Requires-Dist: jaxlib; extra == "jax"
|
|
22
|
+
Requires-Dist: jaxlib<0.8.2; python_version < "3.12" and extra == "jax"
|
|
23
|
+
Requires-Dist: jaxlib; python_version >= "3.12" and extra == "jax"
|
|
23
24
|
Requires-Dist: flowjax; extra == "jax"
|
|
24
25
|
Provides-Extra: torch
|
|
25
26
|
Requires-Dist: torch; extra == "torch"
|
|
@@ -32,6 +33,7 @@ Provides-Extra: emcee
|
|
|
32
33
|
Requires-Dist: emcee; extra == "emcee"
|
|
33
34
|
Provides-Extra: blackjax
|
|
34
35
|
Requires-Dist: blackjax; extra == "blackjax"
|
|
36
|
+
Requires-Dist: fastprogress<1.1.0; extra == "blackjax"
|
|
35
37
|
Provides-Extra: test
|
|
36
38
|
Requires-Dist: pytest; extra == "test"
|
|
37
39
|
Requires-Dist: pytest-requires; extra == "test"
|
{aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/aspire_inference.egg-info/requires.txt
RENAMED
|
@@ -6,15 +6,21 @@ h5py
|
|
|
6
6
|
|
|
7
7
|
[blackjax]
|
|
8
8
|
blackjax
|
|
9
|
+
fastprogress<1.1.0
|
|
9
10
|
|
|
10
11
|
[emcee]
|
|
11
12
|
emcee
|
|
12
13
|
|
|
13
14
|
[jax]
|
|
14
15
|
jax
|
|
15
|
-
jaxlib
|
|
16
16
|
flowjax
|
|
17
17
|
|
|
18
|
+
[jax:python_version < "3.12"]
|
|
19
|
+
jaxlib<0.8.2
|
|
20
|
+
|
|
21
|
+
[jax:python_version >= "3.12"]
|
|
22
|
+
jaxlib
|
|
23
|
+
|
|
18
24
|
[minipcn]
|
|
19
25
|
minipcn[array-api]>=0.2.0a3
|
|
20
26
|
orng
|
|
@@ -30,8 +30,9 @@ scipy = [
|
|
|
30
30
|
]
|
|
31
31
|
jax = [
|
|
32
32
|
"jax",
|
|
33
|
-
"jaxlib",
|
|
34
|
-
"flowjax",
|
|
33
|
+
"jaxlib<0.8.2; python_version < '3.12'",
|
|
34
|
+
"jaxlib; python_version >= '3.12'",
|
|
35
|
+
"flowjax",
|
|
35
36
|
]
|
|
36
37
|
torch = [
|
|
37
38
|
"torch",
|
|
@@ -47,6 +48,7 @@ emcee = [
|
|
|
47
48
|
]
|
|
48
49
|
blackjax = [
|
|
49
50
|
"blackjax",
|
|
51
|
+
"fastprogress<1.1.0",
|
|
50
52
|
]
|
|
51
53
|
test = [
|
|
52
54
|
"pytest",
|
|
@@ -20,6 +20,7 @@ from .transforms import (
|
|
|
20
20
|
)
|
|
21
21
|
from .utils import (
|
|
22
22
|
AspireFile,
|
|
23
|
+
function_id,
|
|
23
24
|
load_from_h5_file,
|
|
24
25
|
recursively_save_to_h5_file,
|
|
25
26
|
resolve_xp,
|
|
@@ -353,6 +354,7 @@ class Aspire:
|
|
|
353
354
|
xp=self.xp,
|
|
354
355
|
dtype=self.dtype,
|
|
355
356
|
preconditioning_transform=transform,
|
|
357
|
+
parameters=self.parameters,
|
|
356
358
|
**kwargs,
|
|
357
359
|
)
|
|
358
360
|
return sampler
|
|
@@ -687,8 +689,8 @@ class Aspire:
|
|
|
687
689
|
method of the sampler.
|
|
688
690
|
"""
|
|
689
691
|
config = {
|
|
690
|
-
"log_likelihood": self.log_likelihood
|
|
691
|
-
"log_prior": self.log_prior
|
|
692
|
+
"log_likelihood": function_id(self.log_likelihood),
|
|
693
|
+
"log_prior": function_id(self.log_prior),
|
|
692
694
|
"dims": self.dims,
|
|
693
695
|
"parameters": self.parameters,
|
|
694
696
|
"periodic_parameters": self.periodic_parameters,
|
|
@@ -13,14 +13,13 @@ class MCMCSampler(Sampler):
|
|
|
13
13
|
n_samples_drawn = 0
|
|
14
14
|
samples = None
|
|
15
15
|
while n_samples_drawn < n_samples:
|
|
16
|
-
|
|
17
|
-
x, log_q = self.prior_flow.sample_and_log_prob(n_to_draw)
|
|
16
|
+
x, log_q = self.prior_flow.sample_and_log_prob(n_samples)
|
|
18
17
|
new_samples = Samples(x, xp=self.xp, log_q=log_q, dtype=self.dtype)
|
|
19
18
|
new_samples.log_prior = new_samples.array_to_namespace(
|
|
20
19
|
self.log_prior(new_samples)
|
|
21
20
|
)
|
|
22
21
|
valid = self.xp.isfinite(new_samples.log_prior)
|
|
23
|
-
n_valid = self.xp.sum(valid)
|
|
22
|
+
n_valid = int(self.xp.sum(valid))
|
|
24
23
|
if n_valid > 0:
|
|
25
24
|
if samples is None:
|
|
26
25
|
samples = new_samples[valid]
|
|
@@ -95,6 +95,7 @@ class SMCSampler(MCMCSampler):
|
|
|
95
95
|
beta: float,
|
|
96
96
|
beta_step: float,
|
|
97
97
|
min_step: float,
|
|
98
|
+
beta_tolerance: float = 1e-6,
|
|
98
99
|
) -> tuple[float, float]:
|
|
99
100
|
"""Determine the next beta value.
|
|
100
101
|
|
|
@@ -108,6 +109,8 @@ class SMCSampler(MCMCSampler):
|
|
|
108
109
|
The fixed beta step size if not adaptive.
|
|
109
110
|
min_step : float
|
|
110
111
|
The minimum beta step size.
|
|
112
|
+
beta_tolerance : float
|
|
113
|
+
Tolerance when checking for beta convergence.
|
|
111
114
|
|
|
112
115
|
Returns
|
|
113
116
|
-------
|
|
@@ -124,14 +127,13 @@ class SMCSampler(MCMCSampler):
|
|
|
124
127
|
beta_prev = beta
|
|
125
128
|
beta_min = beta_prev
|
|
126
129
|
beta_max = 1.0
|
|
127
|
-
tol = 1e-5
|
|
128
130
|
eff_beta_max = effective_sample_size(
|
|
129
131
|
samples.log_weights(beta_max)
|
|
130
132
|
) / len(samples)
|
|
131
133
|
if eff_beta_max >= self.current_target_efficiency(beta_prev):
|
|
132
134
|
beta_min = 1.0
|
|
133
135
|
target_eff = self.current_target_efficiency(beta_prev)
|
|
134
|
-
while beta_max - beta_min > tol:
|
|
136
|
+
while beta_max - beta_min > beta_tolerance:
|
|
135
137
|
beta_try = 0.5 * (beta_max + beta_min)
|
|
136
138
|
eff = effective_sample_size(
|
|
137
139
|
samples.log_weights(beta_try)
|
|
@@ -163,6 +165,7 @@ class SMCSampler(MCMCSampler):
|
|
|
163
165
|
checkpoint_every: int | None = None,
|
|
164
166
|
checkpoint_file_path: str | None = None,
|
|
165
167
|
resume_from: str | bytes | dict | None = None,
|
|
168
|
+
beta_tolerance: float = 1e-6,
|
|
166
169
|
) -> SMCSamples:
|
|
167
170
|
resumed = resume_from is not None
|
|
168
171
|
if resumed:
|
|
@@ -249,6 +252,7 @@ class SMCSampler(MCMCSampler):
|
|
|
249
252
|
beta,
|
|
250
253
|
beta_step,
|
|
251
254
|
min_step,
|
|
255
|
+
beta_tolerance=beta_tolerance,
|
|
252
256
|
)
|
|
253
257
|
self.history.eff_target.append(
|
|
254
258
|
self.current_target_efficiency(beta)
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import functools
|
|
3
4
|
import inspect
|
|
4
5
|
import logging
|
|
5
6
|
import pickle
|
|
@@ -911,3 +912,23 @@ def track_calls(wrapped=None):
|
|
|
911
912
|
return wrapped_func(*args, **kwargs)
|
|
912
913
|
|
|
913
914
|
return wrapper(wrapped) if wrapped else wrapper
|
|
915
|
+
|
|
916
|
+
|
|
917
|
+
def function_id(fn: Any) -> str:
|
|
918
|
+
"""Get a unique identifier for a function.
|
|
919
|
+
|
|
920
|
+
Parameters
|
|
921
|
+
----------
|
|
922
|
+
fn : Any
|
|
923
|
+
The function to get the identifier for.
|
|
924
|
+
|
|
925
|
+
Returns
|
|
926
|
+
-------
|
|
927
|
+
str
|
|
928
|
+
The unique identifier for the function.
|
|
929
|
+
"""
|
|
930
|
+
if isinstance(fn, functools.partial):
|
|
931
|
+
base = fn.func
|
|
932
|
+
else:
|
|
933
|
+
base = fn
|
|
934
|
+
return f"{base.__module__}:{getattr(base, '__qualname__', type(base).__name__)}"
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import functools
|
|
1
2
|
import pickle
|
|
2
3
|
|
|
3
4
|
import array_api_compat.numpy as np_xp
|
|
@@ -6,7 +7,7 @@ import h5py
|
|
|
6
7
|
import jax.numpy as jnp
|
|
7
8
|
import pytest
|
|
8
9
|
|
|
9
|
-
from aspire.utils import convert_dtype, dump_state, resolve_dtype
|
|
10
|
+
from aspire.utils import convert_dtype, dump_state, function_id, resolve_dtype
|
|
10
11
|
|
|
11
12
|
|
|
12
13
|
def _dtype_name(dtype):
|
|
@@ -85,3 +86,18 @@ def test_dump_state_round_trip(tmp_path):
|
|
|
85
86
|
stored = fp["checkpoints"]["state"][...]
|
|
86
87
|
restored = pickle.loads(stored.tobytes())
|
|
87
88
|
assert restored == state
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
# Define a simple function for testing
|
|
92
|
+
def _foo(x):
|
|
93
|
+
return x
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
@pytest.mark.parametrize(
|
|
97
|
+
"fn", [lambda x: x, functools.partial(lambda x, y: x, y=1), _foo]
|
|
98
|
+
)
|
|
99
|
+
def test_function_id(fn):
|
|
100
|
+
fn_id = function_id(fn)
|
|
101
|
+
assert isinstance(fn_id, str)
|
|
102
|
+
# Calling again should give the same result
|
|
103
|
+
assert fn_id == function_id(fn)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/aspire_inference.egg-info/SOURCES.txt
RENAMED
|
File without changes
|
|
File without changes
|
{aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/aspire_inference.egg-info/top_level.txt
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/tests/integration_tests/test_integration.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|