aspire-inference 0.1.0a11__tar.gz → 0.1.0a13__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63) hide show
  1. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/PKG-INFO +4 -2
  2. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/aspire_inference.egg-info/PKG-INFO +4 -2
  3. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/aspire_inference.egg-info/requires.txt +7 -1
  4. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/pyproject.toml +4 -2
  5. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/src/aspire/aspire.py +4 -2
  6. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/src/aspire/samplers/mcmc.py +2 -3
  7. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/src/aspire/samplers/smc/base.py +6 -2
  8. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/src/aspire/utils.py +21 -0
  9. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/tests/test_utils.py +17 -1
  10. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/.github/workflows/lint.yml +0 -0
  11. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/.github/workflows/publish.yml +0 -0
  12. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/.github/workflows/tests.yml +0 -0
  13. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/.gitignore +0 -0
  14. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/.pre-commit-config.yaml +0 -0
  15. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/LICENSE +0 -0
  16. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/README.md +0 -0
  17. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/aspire_inference.egg-info/SOURCES.txt +0 -0
  18. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/aspire_inference.egg-info/dependency_links.txt +0 -0
  19. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/aspire_inference.egg-info/top_level.txt +0 -0
  20. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/docs/Makefile +0 -0
  21. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/docs/checkpointing.rst +0 -0
  22. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/docs/conf.py +0 -0
  23. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/docs/entry_points.rst +0 -0
  24. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/docs/examples.rst +0 -0
  25. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/docs/index.rst +0 -0
  26. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/docs/installation.rst +0 -0
  27. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/docs/multiprocessing.rst +0 -0
  28. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/docs/recipes.rst +0 -0
  29. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/docs/requirements.txt +0 -0
  30. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/docs/user_guide.rst +0 -0
  31. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/examples/basic_example.py +0 -0
  32. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/examples/blackjax_smc_example.py +0 -0
  33. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/examples/smc_example.py +0 -0
  34. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/readthedocs.yml +0 -0
  35. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/setup.cfg +0 -0
  36. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/src/aspire/__init__.py +0 -0
  37. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/src/aspire/flows/__init__.py +0 -0
  38. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/src/aspire/flows/base.py +0 -0
  39. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/src/aspire/flows/jax/__init__.py +0 -0
  40. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/src/aspire/flows/jax/flows.py +0 -0
  41. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/src/aspire/flows/jax/utils.py +0 -0
  42. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/src/aspire/flows/torch/__init__.py +0 -0
  43. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/src/aspire/flows/torch/flows.py +0 -0
  44. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/src/aspire/history.py +0 -0
  45. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/src/aspire/plot.py +0 -0
  46. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/src/aspire/samplers/__init__.py +0 -0
  47. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/src/aspire/samplers/base.py +0 -0
  48. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/src/aspire/samplers/importance.py +0 -0
  49. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/src/aspire/samplers/smc/__init__.py +0 -0
  50. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/src/aspire/samplers/smc/blackjax.py +0 -0
  51. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/src/aspire/samplers/smc/emcee.py +0 -0
  52. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/src/aspire/samplers/smc/minipcn.py +0 -0
  53. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/src/aspire/samples.py +0 -0
  54. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/src/aspire/transforms.py +0 -0
  55. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/tests/conftest.py +0 -0
  56. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/tests/integration_tests/conftest.py +0 -0
  57. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/tests/integration_tests/test_checkpointing.py +0 -0
  58. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/tests/integration_tests/test_integration.py +0 -0
  59. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/tests/test_flows/test_flows_core.py +0 -0
  60. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/tests/test_flows/test_jax_flows/test_flowjax_flows.py +0 -0
  61. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/tests/test_flows/test_torch_flows/test_zuko_flows.py +0 -0
  62. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/tests/test_samples.py +0 -0
  63. {aspire_inference-0.1.0a11 → aspire_inference-0.1.0a13}/tests/test_transforms.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: aspire-inference
3
- Version: 0.1.0a11
3
+ Version: 0.1.0a13
4
4
  Summary: Accelerate Sequential Posterior Inference via REuse
5
5
  Author-email: "Michael J. Williams" <michaeljw1@googlemail.com>
6
6
  License: MIT
@@ -19,7 +19,8 @@ Provides-Extra: scipy
19
19
  Requires-Dist: scipy; extra == "scipy"
20
20
  Provides-Extra: jax
21
21
  Requires-Dist: jax; extra == "jax"
22
- Requires-Dist: jaxlib; extra == "jax"
22
+ Requires-Dist: jaxlib<0.8.2; python_version < "3.12" and extra == "jax"
23
+ Requires-Dist: jaxlib; python_version >= "3.12" and extra == "jax"
23
24
  Requires-Dist: flowjax; extra == "jax"
24
25
  Provides-Extra: torch
25
26
  Requires-Dist: torch; extra == "torch"
@@ -32,6 +33,7 @@ Provides-Extra: emcee
32
33
  Requires-Dist: emcee; extra == "emcee"
33
34
  Provides-Extra: blackjax
34
35
  Requires-Dist: blackjax; extra == "blackjax"
36
+ Requires-Dist: fastprogress<1.1.0; extra == "blackjax"
35
37
  Provides-Extra: test
36
38
  Requires-Dist: pytest; extra == "test"
37
39
  Requires-Dist: pytest-requires; extra == "test"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: aspire-inference
3
- Version: 0.1.0a11
3
+ Version: 0.1.0a13
4
4
  Summary: Accelerate Sequential Posterior Inference via REuse
5
5
  Author-email: "Michael J. Williams" <michaeljw1@googlemail.com>
6
6
  License: MIT
@@ -19,7 +19,8 @@ Provides-Extra: scipy
19
19
  Requires-Dist: scipy; extra == "scipy"
20
20
  Provides-Extra: jax
21
21
  Requires-Dist: jax; extra == "jax"
22
- Requires-Dist: jaxlib; extra == "jax"
22
+ Requires-Dist: jaxlib<0.8.2; python_version < "3.12" and extra == "jax"
23
+ Requires-Dist: jaxlib; python_version >= "3.12" and extra == "jax"
23
24
  Requires-Dist: flowjax; extra == "jax"
24
25
  Provides-Extra: torch
25
26
  Requires-Dist: torch; extra == "torch"
@@ -32,6 +33,7 @@ Provides-Extra: emcee
32
33
  Requires-Dist: emcee; extra == "emcee"
33
34
  Provides-Extra: blackjax
34
35
  Requires-Dist: blackjax; extra == "blackjax"
36
+ Requires-Dist: fastprogress<1.1.0; extra == "blackjax"
35
37
  Provides-Extra: test
36
38
  Requires-Dist: pytest; extra == "test"
37
39
  Requires-Dist: pytest-requires; extra == "test"
@@ -6,15 +6,21 @@ h5py
6
6
 
7
7
  [blackjax]
8
8
  blackjax
9
+ fastprogress<1.1.0
9
10
 
10
11
  [emcee]
11
12
  emcee
12
13
 
13
14
  [jax]
14
15
  jax
15
- jaxlib
16
16
  flowjax
17
17
 
18
+ [jax:python_version < "3.12"]
19
+ jaxlib<0.8.2
20
+
21
+ [jax:python_version >= "3.12"]
22
+ jaxlib
23
+
18
24
  [minipcn]
19
25
  minipcn[array-api]>=0.2.0a3
20
26
  orng
@@ -30,8 +30,9 @@ scipy = [
30
30
  ]
31
31
  jax = [
32
32
  "jax",
33
- "jaxlib",
34
- "flowjax"
33
+ "jaxlib<0.8.2; python_version < '3.12'",
34
+ "jaxlib; python_version >= '3.12'",
35
+ "flowjax",
35
36
  ]
36
37
  torch = [
37
38
  "torch",
@@ -47,6 +48,7 @@ emcee = [
47
48
  ]
48
49
  blackjax = [
49
50
  "blackjax",
51
+ "fastprogress<1.1.0",
50
52
  ]
51
53
  test = [
52
54
  "pytest",
@@ -20,6 +20,7 @@ from .transforms import (
20
20
  )
21
21
  from .utils import (
22
22
  AspireFile,
23
+ function_id,
23
24
  load_from_h5_file,
24
25
  recursively_save_to_h5_file,
25
26
  resolve_xp,
@@ -353,6 +354,7 @@ class Aspire:
353
354
  xp=self.xp,
354
355
  dtype=self.dtype,
355
356
  preconditioning_transform=transform,
357
+ parameters=self.parameters,
356
358
  **kwargs,
357
359
  )
358
360
  return sampler
@@ -687,8 +689,8 @@ class Aspire:
687
689
  method of the sampler.
688
690
  """
689
691
  config = {
690
- "log_likelihood": self.log_likelihood.__name__,
691
- "log_prior": self.log_prior.__name__,
692
+ "log_likelihood": function_id(self.log_likelihood),
693
+ "log_prior": function_id(self.log_prior),
692
694
  "dims": self.dims,
693
695
  "parameters": self.parameters,
694
696
  "periodic_parameters": self.periodic_parameters,
@@ -13,14 +13,13 @@ class MCMCSampler(Sampler):
13
13
  n_samples_drawn = 0
14
14
  samples = None
15
15
  while n_samples_drawn < n_samples:
16
- n_to_draw = n_samples - n_samples_drawn
17
- x, log_q = self.prior_flow.sample_and_log_prob(n_to_draw)
16
+ x, log_q = self.prior_flow.sample_and_log_prob(n_samples)
18
17
  new_samples = Samples(x, xp=self.xp, log_q=log_q, dtype=self.dtype)
19
18
  new_samples.log_prior = new_samples.array_to_namespace(
20
19
  self.log_prior(new_samples)
21
20
  )
22
21
  valid = self.xp.isfinite(new_samples.log_prior)
23
- n_valid = self.xp.sum(valid)
22
+ n_valid = int(self.xp.sum(valid))
24
23
  if n_valid > 0:
25
24
  if samples is None:
26
25
  samples = new_samples[valid]
@@ -95,6 +95,7 @@ class SMCSampler(MCMCSampler):
95
95
  beta: float,
96
96
  beta_step: float,
97
97
  min_step: float,
98
+ beta_tolerance: float = 1e-6,
98
99
  ) -> tuple[float, float]:
99
100
  """Determine the next beta value.
100
101
 
@@ -108,6 +109,8 @@ class SMCSampler(MCMCSampler):
108
109
  The fixed beta step size if not adaptive.
109
110
  min_step : float
110
111
  The minimum beta step size.
112
+ beta_tolerance : float
113
+ Tolerance when checking for beta convergence.
111
114
 
112
115
  Returns
113
116
  -------
@@ -124,14 +127,13 @@ class SMCSampler(MCMCSampler):
124
127
  beta_prev = beta
125
128
  beta_min = beta_prev
126
129
  beta_max = 1.0
127
- tol = 1e-5
128
130
  eff_beta_max = effective_sample_size(
129
131
  samples.log_weights(beta_max)
130
132
  ) / len(samples)
131
133
  if eff_beta_max >= self.current_target_efficiency(beta_prev):
132
134
  beta_min = 1.0
133
135
  target_eff = self.current_target_efficiency(beta_prev)
134
- while beta_max - beta_min > tol:
136
+ while beta_max - beta_min > beta_tolerance:
135
137
  beta_try = 0.5 * (beta_max + beta_min)
136
138
  eff = effective_sample_size(
137
139
  samples.log_weights(beta_try)
@@ -163,6 +165,7 @@ class SMCSampler(MCMCSampler):
163
165
  checkpoint_every: int | None = None,
164
166
  checkpoint_file_path: str | None = None,
165
167
  resume_from: str | bytes | dict | None = None,
168
+ beta_tolerance: float = 1e-6,
166
169
  ) -> SMCSamples:
167
170
  resumed = resume_from is not None
168
171
  if resumed:
@@ -249,6 +252,7 @@ class SMCSampler(MCMCSampler):
249
252
  beta,
250
253
  beta_step,
251
254
  min_step,
255
+ beta_tolerance=beta_tolerance,
252
256
  )
253
257
  self.history.eff_target.append(
254
258
  self.current_target_efficiency(beta)
@@ -1,5 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import functools
3
4
  import inspect
4
5
  import logging
5
6
  import pickle
@@ -911,3 +912,23 @@ def track_calls(wrapped=None):
911
912
  return wrapped_func(*args, **kwargs)
912
913
 
913
914
  return wrapper(wrapped) if wrapped else wrapper
915
+
916
+
917
def function_id(fn: Any) -> str:
    """Get a stable, human-readable identifier for a callable.

    The identifier has the form ``"<module>:<qualname>"``.

    Parameters
    ----------
    fn : Any
        The callable to identify. ``functools.partial`` objects are
        unwrapped and identified by the callable they wrap. Objects that do
        not expose their own ``__qualname__`` (e.g. instances of a class
        defining ``__call__``) are identified by their type's name, and
        objects without a usable ``__module__`` fall back to their type's
        module.

    Returns
    -------
    str
        The identifier for the callable.

    Notes
    -----
    The identifier is not guaranteed to be unique: arguments bound by a
    ``functools.partial`` are ignored, and distinct lambdas defined in the
    same scope share the same ``__qualname__`` (``"<lambda>"``).
    """
    base = fn.func if isinstance(fn, functools.partial) else fn
    # Some callables (e.g. certain C-level bound methods) report
    # ``__module__`` as None or do not expose it at all; fall back to the
    # type's attributes instead of raising or emitting "None:...".
    module = getattr(base, "__module__", None) or type(base).__module__
    qualname = getattr(base, "__qualname__", None) or type(base).__name__
    return f"{module}:{qualname}"
@@ -1,3 +1,4 @@
1
+ import functools
1
2
  import pickle
2
3
 
3
4
  import array_api_compat.numpy as np_xp
@@ -6,7 +7,7 @@ import h5py
6
7
  import jax.numpy as jnp
7
8
  import pytest
8
9
 
9
- from aspire.utils import convert_dtype, dump_state, resolve_dtype
10
+ from aspire.utils import convert_dtype, dump_state, function_id, resolve_dtype
10
11
 
11
12
 
12
13
  def _dtype_name(dtype):
@@ -85,3 +86,18 @@ def test_dump_state_round_trip(tmp_path):
85
86
  stored = fp["checkpoints"]["state"][...]
86
87
  restored = pickle.loads(stored.tobytes())
87
88
  assert restored == state
89
+
90
+
91
# Simple module-level function used as a fixture for ``function_id``.
def _foo(x):
    return x


class _CallableObject:
    """Callable instance; instances have no ``__qualname__`` of their own."""

    def __call__(self, x):
        return x


@pytest.mark.parametrize(
    "fn", [lambda x: x, functools.partial(lambda x, y: x, y=1), _foo]
)
def test_function_id(fn):
    """``function_id`` returns a deterministic ``module:qualname`` string."""
    fn_id = function_id(fn)
    assert isinstance(fn_id, str)
    # Calling again should give the same result
    assert fn_id == function_id(fn)
    # Partials are identified by the function they wrap.
    base = fn.func if isinstance(fn, functools.partial) else fn
    assert fn_id == f"{base.__module__}:{base.__qualname__}"


def test_function_id_module_function():
    """A module-level function is identified by its module and name."""
    assert function_id(_foo) == f"{__name__}:_foo"
    assert function_id(functools.partial(_foo)) == function_id(_foo)


def test_function_id_callable_instance():
    """Callable instances fall back to their type's name."""
    assert function_id(_CallableObject()) == f"{__name__}:_CallableObject"