aspire-inference 0.1.0a11__py3-none-any.whl → 0.1.0a13__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
aspire/aspire.py CHANGED
@@ -20,6 +20,7 @@ from .transforms import (
20
20
  )
21
21
  from .utils import (
22
22
  AspireFile,
23
+ function_id,
23
24
  load_from_h5_file,
24
25
  recursively_save_to_h5_file,
25
26
  resolve_xp,
@@ -353,6 +354,7 @@ class Aspire:
353
354
  xp=self.xp,
354
355
  dtype=self.dtype,
355
356
  preconditioning_transform=transform,
357
+ parameters=self.parameters,
356
358
  **kwargs,
357
359
  )
358
360
  return sampler
@@ -687,8 +689,8 @@ class Aspire:
687
689
  method of the sampler.
688
690
  """
689
691
  config = {
690
- "log_likelihood": self.log_likelihood.__name__,
691
- "log_prior": self.log_prior.__name__,
692
+ "log_likelihood": function_id(self.log_likelihood),
693
+ "log_prior": function_id(self.log_prior),
692
694
  "dims": self.dims,
693
695
  "parameters": self.parameters,
694
696
  "periodic_parameters": self.periodic_parameters,
aspire/samplers/mcmc.py CHANGED
@@ -13,14 +13,13 @@ class MCMCSampler(Sampler):
13
13
  n_samples_drawn = 0
14
14
  samples = None
15
15
  while n_samples_drawn < n_samples:
16
- n_to_draw = n_samples - n_samples_drawn
17
- x, log_q = self.prior_flow.sample_and_log_prob(n_to_draw)
16
+ x, log_q = self.prior_flow.sample_and_log_prob(n_samples)
18
17
  new_samples = Samples(x, xp=self.xp, log_q=log_q, dtype=self.dtype)
19
18
  new_samples.log_prior = new_samples.array_to_namespace(
20
19
  self.log_prior(new_samples)
21
20
  )
22
21
  valid = self.xp.isfinite(new_samples.log_prior)
23
- n_valid = self.xp.sum(valid)
22
+ n_valid = int(self.xp.sum(valid))
24
23
  if n_valid > 0:
25
24
  if samples is None:
26
25
  samples = new_samples[valid]
@@ -95,6 +95,7 @@ class SMCSampler(MCMCSampler):
95
95
  beta: float,
96
96
  beta_step: float,
97
97
  min_step: float,
98
+ beta_tolerance: float = 1e-6,
98
99
  ) -> tuple[float, float]:
99
100
  """Determine the next beta value.
100
101
 
@@ -108,6 +109,8 @@ class SMCSampler(MCMCSampler):
108
109
  The fixed beta step size if not adaptive.
109
110
  min_step : float
110
111
  The minimum beta step size.
112
+ beta_tolerance : float
113
+ Tolerance when checking for beta convergence.
111
114
 
112
115
  Returns
113
116
  -------
@@ -124,14 +127,13 @@ class SMCSampler(MCMCSampler):
124
127
  beta_prev = beta
125
128
  beta_min = beta_prev
126
129
  beta_max = 1.0
127
- tol = 1e-5
128
130
  eff_beta_max = effective_sample_size(
129
131
  samples.log_weights(beta_max)
130
132
  ) / len(samples)
131
133
  if eff_beta_max >= self.current_target_efficiency(beta_prev):
132
134
  beta_min = 1.0
133
135
  target_eff = self.current_target_efficiency(beta_prev)
134
- while beta_max - beta_min > tol:
136
+ while beta_max - beta_min > beta_tolerance:
135
137
  beta_try = 0.5 * (beta_max + beta_min)
136
138
  eff = effective_sample_size(
137
139
  samples.log_weights(beta_try)
@@ -163,6 +165,7 @@ class SMCSampler(MCMCSampler):
163
165
  checkpoint_every: int | None = None,
164
166
  checkpoint_file_path: str | None = None,
165
167
  resume_from: str | bytes | dict | None = None,
168
+ beta_tolerance: float = 1e-6,
166
169
  ) -> SMCSamples:
167
170
  resumed = resume_from is not None
168
171
  if resumed:
@@ -249,6 +252,7 @@ class SMCSampler(MCMCSampler):
249
252
  beta,
250
253
  beta_step,
251
254
  min_step,
255
+ beta_tolerance=beta_tolerance,
252
256
  )
253
257
  self.history.eff_target.append(
254
258
  self.current_target_efficiency(beta)
aspire/utils.py CHANGED
@@ -1,5 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import functools
3
4
  import inspect
4
5
  import logging
5
6
  import pickle
@@ -911,3 +912,23 @@ def track_calls(wrapped=None):
911
912
  return wrapped_func(*args, **kwargs)
912
913
 
913
914
  return wrapper(wrapped) if wrapped else wrapper
915
+
916
+
917
+ def function_id(fn: Any) -> str:
918
+ """Get a unique identifier for a function.
919
+
920
+ Parameters
921
+ ----------
922
+ fn : Any
923
+ The function to get the identifier for.
924
+
925
+ Returns
926
+ -------
927
+ str
928
+ The unique identifier for the function.
929
+ """
930
+ if isinstance(fn, functools.partial):
931
+ base = fn.func
932
+ else:
933
+ base = fn
934
+ return f"{base.__module__}:{getattr(base, '__qualname__', type(base).__name__)}"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: aspire-inference
3
- Version: 0.1.0a11
3
+ Version: 0.1.0a13
4
4
  Summary: Accelerate Sequential Posterior Inference via REuse
5
5
  Author-email: "Michael J. Williams" <michaeljw1@googlemail.com>
6
6
  License: MIT
@@ -19,7 +19,8 @@ Provides-Extra: scipy
19
19
  Requires-Dist: scipy; extra == "scipy"
20
20
  Provides-Extra: jax
21
21
  Requires-Dist: jax; extra == "jax"
22
- Requires-Dist: jaxlib; extra == "jax"
22
+ Requires-Dist: jaxlib<0.8.2; python_version < "3.12" and extra == "jax"
23
+ Requires-Dist: jaxlib; python_version >= "3.12" and extra == "jax"
23
24
  Requires-Dist: flowjax; extra == "jax"
24
25
  Provides-Extra: torch
25
26
  Requires-Dist: torch; extra == "torch"
@@ -32,6 +33,7 @@ Provides-Extra: emcee
32
33
  Requires-Dist: emcee; extra == "emcee"
33
34
  Provides-Extra: blackjax
34
35
  Requires-Dist: blackjax; extra == "blackjax"
36
+ Requires-Dist: fastprogress<1.1.0; extra == "blackjax"
35
37
  Provides-Extra: test
36
38
  Requires-Dist: pytest; extra == "test"
37
39
  Requires-Dist: pytest-requires; extra == "test"
@@ -1,10 +1,10 @@
1
1
  aspire/__init__.py,sha256=B2IETvlpB0oBh57prRYLdi8jB5yFGw8qVviGdf1NcnE,409
2
- aspire/aspire.py,sha256=lr0bD5GDWdlAGfODGzj4BoELKUF5HYAMb8yYGvRR_y0,30860
2
+ aspire/aspire.py,sha256=EelzGme70nrR_iEeT_HTB4feloYXzW2WADZUwOcREpo,30925
3
3
  aspire/history.py,sha256=l_j-riZKbTWK7Wz9zvvD_mTk9psNCKItiveYhr_pYv8,4313
4
4
  aspire/plot.py,sha256=oXwUDOb_953_ADm2KLk41JIfpE3JeiiQiSYKvUVwLqw,1423
5
5
  aspire/samples.py,sha256=v7y8DkirUCHOJbCE-o9y2K7xzU2HicIo_O0CdFhLgXE,19478
6
6
  aspire/transforms.py,sha256=CHrfPQHEyHQ9I0WWiAgUplWwxYypZsD7uCYIHUbSFtY,24974
7
- aspire/utils.py,sha256=87avRkTce9QvELpIcqlKauSoUZSYi1fqe1asC97TzqA,26947
7
+ aspire/utils.py,sha256=sIONKn3gT7i3hVdlK9bRWy_I79rdk0QPkXTA4O1FlCI,27405
8
8
  aspire/flows/__init__.py,sha256=GUZToPVNJoTwULpbeW10UijfQukNrILoAQ_ubeq7G3w,2110
9
9
  aspire/flows/base.py,sha256=5UWKAiXDXLJ6Sg6a380ajLrGFaZSQyOnFEihQiiA4ko,2237
10
10
  aspire/flows/jax/__init__.py,sha256=7cmiY_MbEC8RDA8Cmi8HVnNJm0sqFKlBsDethdsy5lA,52
@@ -15,14 +15,14 @@ aspire/flows/torch/flows.py,sha256=QcQOcFZEsLWHPwbQUFGOFdfEslyc59Vf_UEsS0xAGPo,1
15
15
  aspire/samplers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
16
16
  aspire/samplers/base.py,sha256=ygqrvqSedWSb0cz8DQ_MHokOOxi6aBRdHxf_qoEPwUE,8243
17
17
  aspire/samplers/importance.py,sha256=opn_jY-V8snUz0JztLBtnaTT3WfrZ5OSpHBV5WAuM3M,676
18
- aspire/samplers/mcmc.py,sha256=ihHgzqvSoy1oxdFBjyqNUbCuRX7CqWjlshCUZcgEL5E,5151
18
+ aspire/samplers/mcmc.py,sha256=vLLuXUJfNH9QpkRJTLdJB5RylKRy-3cHNPe_eyi1WQE,5104
19
19
  aspire/samplers/smc/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
20
- aspire/samplers/smc/base.py,sha256=40A9yVuKS1F8cPzbfVQ9rNk3y07mnkfbuyRDIh_fy5A,14122
20
+ aspire/samplers/smc/base.py,sha256=-WTUWLsODsbOQcXjnVQW1-k75xbQ6H31vpCBP8EwQEE,14326
21
21
  aspire/samplers/smc/blackjax.py,sha256=2riWDSRmpL5lGmnhNtdieiRs0oYC6XZA2X-nVlQaqpE,12490
22
22
  aspire/samplers/smc/emcee.py,sha256=4CI9GvH69FCoLiFBbKKYwYocYyiM95IijC5EvrcAmUo,2891
23
23
  aspire/samplers/smc/minipcn.py,sha256=IJ5466VvARd4qZCWXXl-l3BPaKW1AgcwmbP3ISL2bto,3368
24
- aspire_inference-0.1.0a11.dist-info/licenses/LICENSE,sha256=DN-eRtBfS9dZyT0Ds0Mdn2Y4nb-ZQ7h71vpASYBm5k4,1076
25
- aspire_inference-0.1.0a11.dist-info/METADATA,sha256=lLd2d5HR-t942wKLyYbdJ1DL9CKl7tkpBul_vX8DU4M,3869
26
- aspire_inference-0.1.0a11.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
27
- aspire_inference-0.1.0a11.dist-info/top_level.txt,sha256=9FRIYEl2xwVFG7jSOBHsElHQ0y3_4fq01Cf4_OyMQn8,7
28
- aspire_inference-0.1.0a11.dist-info/RECORD,,
24
+ aspire_inference-0.1.0a13.dist-info/licenses/LICENSE,sha256=DN-eRtBfS9dZyT0Ds0Mdn2Y4nb-ZQ7h71vpASYBm5k4,1076
25
+ aspire_inference-0.1.0a13.dist-info/METADATA,sha256=l0CpbJt7o8zFQGRXM6lKxDQXsndXk2eaAZBPejfwkUQ,4025
26
+ aspire_inference-0.1.0a13.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
27
+ aspire_inference-0.1.0a13.dist-info/top_level.txt,sha256=9FRIYEl2xwVFG7jSOBHsElHQ0y3_4fq01Cf4_OyMQn8,7
28
+ aspire_inference-0.1.0a13.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (80.9.0)
2
+ Generator: setuptools (80.10.2)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5