aspire-inference 0.1.0a12__py3-none-any.whl → 0.1.0a13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aspire/aspire.py +1 -0
- aspire/samplers/mcmc.py +2 -3
- aspire/samplers/smc/base.py +6 -2
- {aspire_inference-0.1.0a12.dist-info → aspire_inference-0.1.0a13.dist-info}/METADATA +4 -2
- {aspire_inference-0.1.0a12.dist-info → aspire_inference-0.1.0a13.dist-info}/RECORD +8 -8
- {aspire_inference-0.1.0a12.dist-info → aspire_inference-0.1.0a13.dist-info}/WHEEL +1 -1
- {aspire_inference-0.1.0a12.dist-info → aspire_inference-0.1.0a13.dist-info}/licenses/LICENSE +0 -0
- {aspire_inference-0.1.0a12.dist-info → aspire_inference-0.1.0a13.dist-info}/top_level.txt +0 -0
aspire/aspire.py
CHANGED
aspire/samplers/mcmc.py
CHANGED
|
@@ -13,14 +13,13 @@ class MCMCSampler(Sampler):
|
|
|
13
13
|
n_samples_drawn = 0
|
|
14
14
|
samples = None
|
|
15
15
|
while n_samples_drawn < n_samples:
|
|
16
|
-
|
|
17
|
-
x, log_q = self.prior_flow.sample_and_log_prob(n_to_draw)
|
|
16
|
+
x, log_q = self.prior_flow.sample_and_log_prob(n_samples)
|
|
18
17
|
new_samples = Samples(x, xp=self.xp, log_q=log_q, dtype=self.dtype)
|
|
19
18
|
new_samples.log_prior = new_samples.array_to_namespace(
|
|
20
19
|
self.log_prior(new_samples)
|
|
21
20
|
)
|
|
22
21
|
valid = self.xp.isfinite(new_samples.log_prior)
|
|
23
|
-
n_valid = self.xp.sum(valid)
|
|
22
|
+
n_valid = int(self.xp.sum(valid))
|
|
24
23
|
if n_valid > 0:
|
|
25
24
|
if samples is None:
|
|
26
25
|
samples = new_samples[valid]
|
aspire/samplers/smc/base.py
CHANGED
|
@@ -95,6 +95,7 @@ class SMCSampler(MCMCSampler):
|
|
|
95
95
|
beta: float,
|
|
96
96
|
beta_step: float,
|
|
97
97
|
min_step: float,
|
|
98
|
+
beta_tolerance: float = 1e-6,
|
|
98
99
|
) -> tuple[float, float]:
|
|
99
100
|
"""Determine the next beta value.
|
|
100
101
|
|
|
@@ -108,6 +109,8 @@ class SMCSampler(MCMCSampler):
|
|
|
108
109
|
The fixed beta step size if not adaptive.
|
|
109
110
|
min_step : float
|
|
110
111
|
The minimum beta step size.
|
|
112
|
+
beta_tolerance : float
|
|
113
|
+
Tolerance when checking for beta convergence.
|
|
111
114
|
|
|
112
115
|
Returns
|
|
113
116
|
-------
|
|
@@ -124,14 +127,13 @@ class SMCSampler(MCMCSampler):
|
|
|
124
127
|
beta_prev = beta
|
|
125
128
|
beta_min = beta_prev
|
|
126
129
|
beta_max = 1.0
|
|
127
|
-
tol = 1e-5
|
|
128
130
|
eff_beta_max = effective_sample_size(
|
|
129
131
|
samples.log_weights(beta_max)
|
|
130
132
|
) / len(samples)
|
|
131
133
|
if eff_beta_max >= self.current_target_efficiency(beta_prev):
|
|
132
134
|
beta_min = 1.0
|
|
133
135
|
target_eff = self.current_target_efficiency(beta_prev)
|
|
134
|
-
while beta_max - beta_min >
|
|
136
|
+
while beta_max - beta_min > beta_tolerance:
|
|
135
137
|
beta_try = 0.5 * (beta_max + beta_min)
|
|
136
138
|
eff = effective_sample_size(
|
|
137
139
|
samples.log_weights(beta_try)
|
|
@@ -163,6 +165,7 @@ class SMCSampler(MCMCSampler):
|
|
|
163
165
|
checkpoint_every: int | None = None,
|
|
164
166
|
checkpoint_file_path: str | None = None,
|
|
165
167
|
resume_from: str | bytes | dict | None = None,
|
|
168
|
+
beta_tolerance: float = 1e-6,
|
|
166
169
|
) -> SMCSamples:
|
|
167
170
|
resumed = resume_from is not None
|
|
168
171
|
if resumed:
|
|
@@ -249,6 +252,7 @@ class SMCSampler(MCMCSampler):
|
|
|
249
252
|
beta,
|
|
250
253
|
beta_step,
|
|
251
254
|
min_step,
|
|
255
|
+
beta_tolerance=beta_tolerance,
|
|
252
256
|
)
|
|
253
257
|
self.history.eff_target.append(
|
|
254
258
|
self.current_target_efficiency(beta)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: aspire-inference
|
|
3
|
-
Version: 0.1.
|
|
3
|
+
Version: 0.1.0a13
|
|
4
4
|
Summary: Accelerate Sequential Posterior Inference via REuse
|
|
5
5
|
Author-email: "Michael J. Williams" <michaeljw1@googlemail.com>
|
|
6
6
|
License: MIT
|
|
@@ -19,7 +19,8 @@ Provides-Extra: scipy
|
|
|
19
19
|
Requires-Dist: scipy; extra == "scipy"
|
|
20
20
|
Provides-Extra: jax
|
|
21
21
|
Requires-Dist: jax; extra == "jax"
|
|
22
|
-
Requires-Dist: jaxlib; extra == "jax"
|
|
22
|
+
Requires-Dist: jaxlib<0.8.2; python_version < "3.12" and extra == "jax"
|
|
23
|
+
Requires-Dist: jaxlib; python_version >= "3.12" and extra == "jax"
|
|
23
24
|
Requires-Dist: flowjax; extra == "jax"
|
|
24
25
|
Provides-Extra: torch
|
|
25
26
|
Requires-Dist: torch; extra == "torch"
|
|
@@ -32,6 +33,7 @@ Provides-Extra: emcee
|
|
|
32
33
|
Requires-Dist: emcee; extra == "emcee"
|
|
33
34
|
Provides-Extra: blackjax
|
|
34
35
|
Requires-Dist: blackjax; extra == "blackjax"
|
|
36
|
+
Requires-Dist: fastprogress<1.1.0; extra == "blackjax"
|
|
35
37
|
Provides-Extra: test
|
|
36
38
|
Requires-Dist: pytest; extra == "test"
|
|
37
39
|
Requires-Dist: pytest-requires; extra == "test"
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
aspire/__init__.py,sha256=B2IETvlpB0oBh57prRYLdi8jB5yFGw8qVviGdf1NcnE,409
|
|
2
|
-
aspire/aspire.py,sha256=
|
|
2
|
+
aspire/aspire.py,sha256=EelzGme70nrR_iEeT_HTB4feloYXzW2WADZUwOcREpo,30925
|
|
3
3
|
aspire/history.py,sha256=l_j-riZKbTWK7Wz9zvvD_mTk9psNCKItiveYhr_pYv8,4313
|
|
4
4
|
aspire/plot.py,sha256=oXwUDOb_953_ADm2KLk41JIfpE3JeiiQiSYKvUVwLqw,1423
|
|
5
5
|
aspire/samples.py,sha256=v7y8DkirUCHOJbCE-o9y2K7xzU2HicIo_O0CdFhLgXE,19478
|
|
@@ -15,14 +15,14 @@ aspire/flows/torch/flows.py,sha256=QcQOcFZEsLWHPwbQUFGOFdfEslyc59Vf_UEsS0xAGPo,1
|
|
|
15
15
|
aspire/samplers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
16
16
|
aspire/samplers/base.py,sha256=ygqrvqSedWSb0cz8DQ_MHokOOxi6aBRdHxf_qoEPwUE,8243
|
|
17
17
|
aspire/samplers/importance.py,sha256=opn_jY-V8snUz0JztLBtnaTT3WfrZ5OSpHBV5WAuM3M,676
|
|
18
|
-
aspire/samplers/mcmc.py,sha256=
|
|
18
|
+
aspire/samplers/mcmc.py,sha256=vLLuXUJfNH9QpkRJTLdJB5RylKRy-3cHNPe_eyi1WQE,5104
|
|
19
19
|
aspire/samplers/smc/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
20
|
-
aspire/samplers/smc/base.py,sha256
|
|
20
|
+
aspire/samplers/smc/base.py,sha256=-WTUWLsODsbOQcXjnVQW1-k75xbQ6H31vpCBP8EwQEE,14326
|
|
21
21
|
aspire/samplers/smc/blackjax.py,sha256=2riWDSRmpL5lGmnhNtdieiRs0oYC6XZA2X-nVlQaqpE,12490
|
|
22
22
|
aspire/samplers/smc/emcee.py,sha256=4CI9GvH69FCoLiFBbKKYwYocYyiM95IijC5EvrcAmUo,2891
|
|
23
23
|
aspire/samplers/smc/minipcn.py,sha256=IJ5466VvARd4qZCWXXl-l3BPaKW1AgcwmbP3ISL2bto,3368
|
|
24
|
-
aspire_inference-0.1.
|
|
25
|
-
aspire_inference-0.1.
|
|
26
|
-
aspire_inference-0.1.
|
|
27
|
-
aspire_inference-0.1.
|
|
28
|
-
aspire_inference-0.1.
|
|
24
|
+
aspire_inference-0.1.0a13.dist-info/licenses/LICENSE,sha256=DN-eRtBfS9dZyT0Ds0Mdn2Y4nb-ZQ7h71vpASYBm5k4,1076
|
|
25
|
+
aspire_inference-0.1.0a13.dist-info/METADATA,sha256=l0CpbJt7o8zFQGRXM6lKxDQXsndXk2eaAZBPejfwkUQ,4025
|
|
26
|
+
aspire_inference-0.1.0a13.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
|
|
27
|
+
aspire_inference-0.1.0a13.dist-info/top_level.txt,sha256=9FRIYEl2xwVFG7jSOBHsElHQ0y3_4fq01Cf4_OyMQn8,7
|
|
28
|
+
aspire_inference-0.1.0a13.dist-info/RECORD,,
|
{aspire_inference-0.1.0a12.dist-info → aspire_inference-0.1.0a13.dist-info}/licenses/LICENSE
RENAMED
|
File without changes
|
|
File without changes
|