aspire-inference 0.1.0a12__py3-none-any.whl → 0.1.0a13__py3-none-any.whl

This diff shows the changes between two publicly available versions of a package that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
aspire/aspire.py CHANGED
@@ -354,6 +354,7 @@ class Aspire:
354
354
  xp=self.xp,
355
355
  dtype=self.dtype,
356
356
  preconditioning_transform=transform,
357
+ parameters=self.parameters,
357
358
  **kwargs,
358
359
  )
359
360
  return sampler
aspire/samplers/mcmc.py CHANGED
@@ -13,14 +13,13 @@ class MCMCSampler(Sampler):
13
13
  n_samples_drawn = 0
14
14
  samples = None
15
15
  while n_samples_drawn < n_samples:
16
- n_to_draw = n_samples - n_samples_drawn
17
- x, log_q = self.prior_flow.sample_and_log_prob(n_to_draw)
16
+ x, log_q = self.prior_flow.sample_and_log_prob(n_samples)
18
17
  new_samples = Samples(x, xp=self.xp, log_q=log_q, dtype=self.dtype)
19
18
  new_samples.log_prior = new_samples.array_to_namespace(
20
19
  self.log_prior(new_samples)
21
20
  )
22
21
  valid = self.xp.isfinite(new_samples.log_prior)
23
- n_valid = self.xp.sum(valid)
22
+ n_valid = int(self.xp.sum(valid))
24
23
  if n_valid > 0:
25
24
  if samples is None:
26
25
  samples = new_samples[valid]
@@ -95,6 +95,7 @@ class SMCSampler(MCMCSampler):
95
95
  beta: float,
96
96
  beta_step: float,
97
97
  min_step: float,
98
+ beta_tolerance: float = 1e-6,
98
99
  ) -> tuple[float, float]:
99
100
  """Determine the next beta value.
100
101
 
@@ -108,6 +109,8 @@ class SMCSampler(MCMCSampler):
108
109
  The fixed beta step size if not adaptive.
109
110
  min_step : float
110
111
  The minimum beta step size.
112
+ beta_tolerance : float
113
+ Tolerance when checking for beta convergence.
111
114
 
112
115
  Returns
113
116
  -------
@@ -124,14 +127,13 @@ class SMCSampler(MCMCSampler):
124
127
  beta_prev = beta
125
128
  beta_min = beta_prev
126
129
  beta_max = 1.0
127
- tol = 1e-5
128
130
  eff_beta_max = effective_sample_size(
129
131
  samples.log_weights(beta_max)
130
132
  ) / len(samples)
131
133
  if eff_beta_max >= self.current_target_efficiency(beta_prev):
132
134
  beta_min = 1.0
133
135
  target_eff = self.current_target_efficiency(beta_prev)
134
- while beta_max - beta_min > tol:
136
+ while beta_max - beta_min > beta_tolerance:
135
137
  beta_try = 0.5 * (beta_max + beta_min)
136
138
  eff = effective_sample_size(
137
139
  samples.log_weights(beta_try)
@@ -163,6 +165,7 @@ class SMCSampler(MCMCSampler):
163
165
  checkpoint_every: int | None = None,
164
166
  checkpoint_file_path: str | None = None,
165
167
  resume_from: str | bytes | dict | None = None,
168
+ beta_tolerance: float = 1e-6,
166
169
  ) -> SMCSamples:
167
170
  resumed = resume_from is not None
168
171
  if resumed:
@@ -249,6 +252,7 @@ class SMCSampler(MCMCSampler):
249
252
  beta,
250
253
  beta_step,
251
254
  min_step,
255
+ beta_tolerance=beta_tolerance,
252
256
  )
253
257
  self.history.eff_target.append(
254
258
  self.current_target_efficiency(beta)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: aspire-inference
3
- Version: 0.1.0a12
3
+ Version: 0.1.0a13
4
4
  Summary: Accelerate Sequential Posterior Inference via REuse
5
5
  Author-email: "Michael J. Williams" <michaeljw1@googlemail.com>
6
6
  License: MIT
@@ -19,7 +19,8 @@ Provides-Extra: scipy
19
19
  Requires-Dist: scipy; extra == "scipy"
20
20
  Provides-Extra: jax
21
21
  Requires-Dist: jax; extra == "jax"
22
- Requires-Dist: jaxlib; extra == "jax"
22
+ Requires-Dist: jaxlib<0.8.2; python_version < "3.12" and extra == "jax"
23
+ Requires-Dist: jaxlib; python_version >= "3.12" and extra == "jax"
23
24
  Requires-Dist: flowjax; extra == "jax"
24
25
  Provides-Extra: torch
25
26
  Requires-Dist: torch; extra == "torch"
@@ -32,6 +33,7 @@ Provides-Extra: emcee
32
33
  Requires-Dist: emcee; extra == "emcee"
33
34
  Provides-Extra: blackjax
34
35
  Requires-Dist: blackjax; extra == "blackjax"
36
+ Requires-Dist: fastprogress<1.1.0; extra == "blackjax"
35
37
  Provides-Extra: test
36
38
  Requires-Dist: pytest; extra == "test"
37
39
  Requires-Dist: pytest-requires; extra == "test"
@@ -1,5 +1,5 @@
1
1
  aspire/__init__.py,sha256=B2IETvlpB0oBh57prRYLdi8jB5yFGw8qVviGdf1NcnE,409
2
- aspire/aspire.py,sha256=7DDRpwMezJABzX3AyHamRf8hjLAEeqCtg-_s5qSRjg0,30885
2
+ aspire/aspire.py,sha256=EelzGme70nrR_iEeT_HTB4feloYXzW2WADZUwOcREpo,30925
3
3
  aspire/history.py,sha256=l_j-riZKbTWK7Wz9zvvD_mTk9psNCKItiveYhr_pYv8,4313
4
4
  aspire/plot.py,sha256=oXwUDOb_953_ADm2KLk41JIfpE3JeiiQiSYKvUVwLqw,1423
5
5
  aspire/samples.py,sha256=v7y8DkirUCHOJbCE-o9y2K7xzU2HicIo_O0CdFhLgXE,19478
@@ -15,14 +15,14 @@ aspire/flows/torch/flows.py,sha256=QcQOcFZEsLWHPwbQUFGOFdfEslyc59Vf_UEsS0xAGPo,1
15
15
  aspire/samplers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
16
16
  aspire/samplers/base.py,sha256=ygqrvqSedWSb0cz8DQ_MHokOOxi6aBRdHxf_qoEPwUE,8243
17
17
  aspire/samplers/importance.py,sha256=opn_jY-V8snUz0JztLBtnaTT3WfrZ5OSpHBV5WAuM3M,676
18
- aspire/samplers/mcmc.py,sha256=ihHgzqvSoy1oxdFBjyqNUbCuRX7CqWjlshCUZcgEL5E,5151
18
+ aspire/samplers/mcmc.py,sha256=vLLuXUJfNH9QpkRJTLdJB5RylKRy-3cHNPe_eyi1WQE,5104
19
19
  aspire/samplers/smc/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
20
- aspire/samplers/smc/base.py,sha256=40A9yVuKS1F8cPzbfVQ9rNk3y07mnkfbuyRDIh_fy5A,14122
20
+ aspire/samplers/smc/base.py,sha256=-WTUWLsODsbOQcXjnVQW1-k75xbQ6H31vpCBP8EwQEE,14326
21
21
  aspire/samplers/smc/blackjax.py,sha256=2riWDSRmpL5lGmnhNtdieiRs0oYC6XZA2X-nVlQaqpE,12490
22
22
  aspire/samplers/smc/emcee.py,sha256=4CI9GvH69FCoLiFBbKKYwYocYyiM95IijC5EvrcAmUo,2891
23
23
  aspire/samplers/smc/minipcn.py,sha256=IJ5466VvARd4qZCWXXl-l3BPaKW1AgcwmbP3ISL2bto,3368
24
- aspire_inference-0.1.0a12.dist-info/licenses/LICENSE,sha256=DN-eRtBfS9dZyT0Ds0Mdn2Y4nb-ZQ7h71vpASYBm5k4,1076
25
- aspire_inference-0.1.0a12.dist-info/METADATA,sha256=snzuBueTUZazIqKt6yEYfI4JdY3QVhY0C-vxl7Urauw,3869
26
- aspire_inference-0.1.0a12.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
27
- aspire_inference-0.1.0a12.dist-info/top_level.txt,sha256=9FRIYEl2xwVFG7jSOBHsElHQ0y3_4fq01Cf4_OyMQn8,7
28
- aspire_inference-0.1.0a12.dist-info/RECORD,,
24
+ aspire_inference-0.1.0a13.dist-info/licenses/LICENSE,sha256=DN-eRtBfS9dZyT0Ds0Mdn2Y4nb-ZQ7h71vpASYBm5k4,1076
25
+ aspire_inference-0.1.0a13.dist-info/METADATA,sha256=l0CpbJt7o8zFQGRXM6lKxDQXsndXk2eaAZBPejfwkUQ,4025
26
+ aspire_inference-0.1.0a13.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
27
+ aspire_inference-0.1.0a13.dist-info/top_level.txt,sha256=9FRIYEl2xwVFG7jSOBHsElHQ0y3_4fq01Cf4_OyMQn8,7
28
+ aspire_inference-0.1.0a13.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (80.9.0)
2
+ Generator: setuptools (80.10.2)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5