aspire-inference 0.1.0a12__tar.gz → 0.1.0a13__tar.gz

This diff compares the contents of two publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
Files changed (63)
  1. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/PKG-INFO +4 -2
  2. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/aspire_inference.egg-info/PKG-INFO +4 -2
  3. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/aspire_inference.egg-info/requires.txt +7 -1
  4. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/pyproject.toml +4 -2
  5. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/src/aspire/aspire.py +1 -0
  6. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/src/aspire/samplers/mcmc.py +2 -3
  7. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/src/aspire/samplers/smc/base.py +6 -2
  8. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/.github/workflows/lint.yml +0 -0
  9. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/.github/workflows/publish.yml +0 -0
  10. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/.github/workflows/tests.yml +0 -0
  11. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/.gitignore +0 -0
  12. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/.pre-commit-config.yaml +0 -0
  13. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/LICENSE +0 -0
  14. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/README.md +0 -0
  15. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/aspire_inference.egg-info/SOURCES.txt +0 -0
  16. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/aspire_inference.egg-info/dependency_links.txt +0 -0
  17. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/aspire_inference.egg-info/top_level.txt +0 -0
  18. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/docs/Makefile +0 -0
  19. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/docs/checkpointing.rst +0 -0
  20. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/docs/conf.py +0 -0
  21. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/docs/entry_points.rst +0 -0
  22. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/docs/examples.rst +0 -0
  23. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/docs/index.rst +0 -0
  24. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/docs/installation.rst +0 -0
  25. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/docs/multiprocessing.rst +0 -0
  26. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/docs/recipes.rst +0 -0
  27. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/docs/requirements.txt +0 -0
  28. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/docs/user_guide.rst +0 -0
  29. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/examples/basic_example.py +0 -0
  30. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/examples/blackjax_smc_example.py +0 -0
  31. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/examples/smc_example.py +0 -0
  32. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/readthedocs.yml +0 -0
  33. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/setup.cfg +0 -0
  34. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/src/aspire/__init__.py +0 -0
  35. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/src/aspire/flows/__init__.py +0 -0
  36. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/src/aspire/flows/base.py +0 -0
  37. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/src/aspire/flows/jax/__init__.py +0 -0
  38. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/src/aspire/flows/jax/flows.py +0 -0
  39. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/src/aspire/flows/jax/utils.py +0 -0
  40. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/src/aspire/flows/torch/__init__.py +0 -0
  41. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/src/aspire/flows/torch/flows.py +0 -0
  42. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/src/aspire/history.py +0 -0
  43. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/src/aspire/plot.py +0 -0
  44. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/src/aspire/samplers/__init__.py +0 -0
  45. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/src/aspire/samplers/base.py +0 -0
  46. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/src/aspire/samplers/importance.py +0 -0
  47. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/src/aspire/samplers/smc/__init__.py +0 -0
  48. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/src/aspire/samplers/smc/blackjax.py +0 -0
  49. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/src/aspire/samplers/smc/emcee.py +0 -0
  50. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/src/aspire/samplers/smc/minipcn.py +0 -0
  51. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/src/aspire/samples.py +0 -0
  52. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/src/aspire/transforms.py +0 -0
  53. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/src/aspire/utils.py +0 -0
  54. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/tests/conftest.py +0 -0
  55. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/tests/integration_tests/conftest.py +0 -0
  56. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/tests/integration_tests/test_checkpointing.py +0 -0
  57. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/tests/integration_tests/test_integration.py +0 -0
  58. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/tests/test_flows/test_flows_core.py +0 -0
  59. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/tests/test_flows/test_jax_flows/test_flowjax_flows.py +0 -0
  60. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/tests/test_flows/test_torch_flows/test_zuko_flows.py +0 -0
  61. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/tests/test_samples.py +0 -0
  62. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/tests/test_transforms.py +0 -0
  63. {aspire_inference-0.1.0a12 → aspire_inference-0.1.0a13}/tests/test_utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: aspire-inference
3
- Version: 0.1.0a12
3
+ Version: 0.1.0a13
4
4
  Summary: Accelerate Sequential Posterior Inference via REuse
5
5
  Author-email: "Michael J. Williams" <michaeljw1@googlemail.com>
6
6
  License: MIT
@@ -19,7 +19,8 @@ Provides-Extra: scipy
19
19
  Requires-Dist: scipy; extra == "scipy"
20
20
  Provides-Extra: jax
21
21
  Requires-Dist: jax; extra == "jax"
22
- Requires-Dist: jaxlib; extra == "jax"
22
+ Requires-Dist: jaxlib<0.8.2; python_version < "3.12" and extra == "jax"
23
+ Requires-Dist: jaxlib; python_version >= "3.12" and extra == "jax"
23
24
  Requires-Dist: flowjax; extra == "jax"
24
25
  Provides-Extra: torch
25
26
  Requires-Dist: torch; extra == "torch"
@@ -32,6 +33,7 @@ Provides-Extra: emcee
32
33
  Requires-Dist: emcee; extra == "emcee"
33
34
  Provides-Extra: blackjax
34
35
  Requires-Dist: blackjax; extra == "blackjax"
36
+ Requires-Dist: fastprogress<1.1.0; extra == "blackjax"
35
37
  Provides-Extra: test
36
38
  Requires-Dist: pytest; extra == "test"
37
39
  Requires-Dist: pytest-requires; extra == "test"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: aspire-inference
3
- Version: 0.1.0a12
3
+ Version: 0.1.0a13
4
4
  Summary: Accelerate Sequential Posterior Inference via REuse
5
5
  Author-email: "Michael J. Williams" <michaeljw1@googlemail.com>
6
6
  License: MIT
@@ -19,7 +19,8 @@ Provides-Extra: scipy
19
19
  Requires-Dist: scipy; extra == "scipy"
20
20
  Provides-Extra: jax
21
21
  Requires-Dist: jax; extra == "jax"
22
- Requires-Dist: jaxlib; extra == "jax"
22
+ Requires-Dist: jaxlib<0.8.2; python_version < "3.12" and extra == "jax"
23
+ Requires-Dist: jaxlib; python_version >= "3.12" and extra == "jax"
23
24
  Requires-Dist: flowjax; extra == "jax"
24
25
  Provides-Extra: torch
25
26
  Requires-Dist: torch; extra == "torch"
@@ -32,6 +33,7 @@ Provides-Extra: emcee
32
33
  Requires-Dist: emcee; extra == "emcee"
33
34
  Provides-Extra: blackjax
34
35
  Requires-Dist: blackjax; extra == "blackjax"
36
+ Requires-Dist: fastprogress<1.1.0; extra == "blackjax"
35
37
  Provides-Extra: test
36
38
  Requires-Dist: pytest; extra == "test"
37
39
  Requires-Dist: pytest-requires; extra == "test"
@@ -6,15 +6,21 @@ h5py
6
6
 
7
7
  [blackjax]
8
8
  blackjax
9
+ fastprogress<1.1.0
9
10
 
10
11
  [emcee]
11
12
  emcee
12
13
 
13
14
  [jax]
14
15
  jax
15
- jaxlib
16
16
  flowjax
17
17
 
18
+ [jax:python_version < "3.12"]
19
+ jaxlib<0.8.2
20
+
21
+ [jax:python_version >= "3.12"]
22
+ jaxlib
23
+
18
24
  [minipcn]
19
25
  minipcn[array-api]>=0.2.0a3
20
26
  orng
@@ -30,8 +30,9 @@ scipy = [
30
30
  ]
31
31
  jax = [
32
32
  "jax",
33
- "jaxlib",
34
- "flowjax"
33
+ "jaxlib<0.8.2; python_version < '3.12'",
34
+ "jaxlib; python_version >= '3.12'",
35
+ "flowjax",
35
36
  ]
36
37
  torch = [
37
38
  "torch",
@@ -47,6 +48,7 @@ emcee = [
47
48
  ]
48
49
  blackjax = [
49
50
  "blackjax",
51
+ "fastprogress<1.1.0",
50
52
  ]
51
53
  test = [
52
54
  "pytest",
@@ -354,6 +354,7 @@ class Aspire:
354
354
  xp=self.xp,
355
355
  dtype=self.dtype,
356
356
  preconditioning_transform=transform,
357
+ parameters=self.parameters,
357
358
  **kwargs,
358
359
  )
359
360
  return sampler
@@ -13,14 +13,13 @@ class MCMCSampler(Sampler):
13
13
  n_samples_drawn = 0
14
14
  samples = None
15
15
  while n_samples_drawn < n_samples:
16
- n_to_draw = n_samples - n_samples_drawn
17
- x, log_q = self.prior_flow.sample_and_log_prob(n_to_draw)
16
+ x, log_q = self.prior_flow.sample_and_log_prob(n_samples)
18
17
  new_samples = Samples(x, xp=self.xp, log_q=log_q, dtype=self.dtype)
19
18
  new_samples.log_prior = new_samples.array_to_namespace(
20
19
  self.log_prior(new_samples)
21
20
  )
22
21
  valid = self.xp.isfinite(new_samples.log_prior)
23
- n_valid = self.xp.sum(valid)
22
+ n_valid = int(self.xp.sum(valid))
24
23
  if n_valid > 0:
25
24
  if samples is None:
26
25
  samples = new_samples[valid]
@@ -95,6 +95,7 @@ class SMCSampler(MCMCSampler):
95
95
  beta: float,
96
96
  beta_step: float,
97
97
  min_step: float,
98
+ beta_tolerance: float = 1e-6,
98
99
  ) -> tuple[float, float]:
99
100
  """Determine the next beta value.
100
101
 
@@ -108,6 +109,8 @@ class SMCSampler(MCMCSampler):
108
109
  The fixed beta step size if not adaptive.
109
110
  min_step : float
110
111
  The minimum beta step size.
112
+ beta_tolerance : float
113
+ Tolerance when checking for beta convergence.
111
114
 
112
115
  Returns
113
116
  -------
@@ -124,14 +127,13 @@ class SMCSampler(MCMCSampler):
124
127
  beta_prev = beta
125
128
  beta_min = beta_prev
126
129
  beta_max = 1.0
127
- tol = 1e-5
128
130
  eff_beta_max = effective_sample_size(
129
131
  samples.log_weights(beta_max)
130
132
  ) / len(samples)
131
133
  if eff_beta_max >= self.current_target_efficiency(beta_prev):
132
134
  beta_min = 1.0
133
135
  target_eff = self.current_target_efficiency(beta_prev)
134
- while beta_max - beta_min > tol:
136
+ while beta_max - beta_min > beta_tolerance:
135
137
  beta_try = 0.5 * (beta_max + beta_min)
136
138
  eff = effective_sample_size(
137
139
  samples.log_weights(beta_try)
@@ -163,6 +165,7 @@ class SMCSampler(MCMCSampler):
163
165
  checkpoint_every: int | None = None,
164
166
  checkpoint_file_path: str | None = None,
165
167
  resume_from: str | bytes | dict | None = None,
168
+ beta_tolerance: float = 1e-6,
166
169
  ) -> SMCSamples:
167
170
  resumed = resume_from is not None
168
171
  if resumed:
@@ -249,6 +252,7 @@ class SMCSampler(MCMCSampler):
249
252
  beta,
250
253
  beta_step,
251
254
  min_step,
255
+ beta_tolerance=beta_tolerance,
252
256
  )
253
257
  self.history.eff_target.append(
254
258
  self.current_target_efficiency(beta)