pymc-extras 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl
This diff compares publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- pymc_extras/__init__.py +1 -3
- pymc_extras/distributions/__init__.py +2 -0
- pymc_extras/distributions/transforms/__init__.py +3 -0
- pymc_extras/distributions/transforms/partial_order.py +227 -0
- pymc_extras/inference/__init__.py +4 -2
- pymc_extras/inference/fit.py +6 -4
- pymc_extras/inference/laplace.py +4 -1
- pymc_extras/inference/pathfinder/importance_sampling.py +23 -17
- pymc_extras/inference/pathfinder/lbfgs.py +49 -13
- pymc_extras/inference/pathfinder/pathfinder.py +136 -118
- pymc_extras/statespace/core/statespace.py +5 -4
- pymc_extras/statespace/filters/distributions.py +9 -45
- pymc_extras/statespace/utils/data_tools.py +24 -9
- pymc_extras/version.txt +1 -1
- {pymc_extras-0.2.3.dist-info → pymc_extras-0.2.5.dist-info}/METADATA +5 -3
- {pymc_extras-0.2.3.dist-info → pymc_extras-0.2.5.dist-info}/RECORD +23 -20
- {pymc_extras-0.2.3.dist-info → pymc_extras-0.2.5.dist-info}/WHEEL +1 -1
- tests/distributions/test_transform.py +77 -0
- tests/statespace/test_coord_assignment.py +65 -0
- tests/test_laplace.py +16 -0
- tests/test_pathfinder.py +141 -17
- {pymc_extras-0.2.3.dist-info → pymc_extras-0.2.5.dist-info/licenses}/LICENSE +0 -0
- {pymc_extras-0.2.3.dist-info → pymc_extras-0.2.5.dist-info}/top_level.txt +0 -0
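Most of the change sits in the pathfinder modules, and the updated test suite below drives them through pmx.fit(method="pathfinder", ...) with explicit num_paths, jitter, and importance_sampling settings. A minimal sketch of that call pattern, assuming pymc-extras 0.2.5 and using a toy stand-in for the eight_schools_model fixture referenced by the tests (the model and the commented parameter descriptions are illustrative, not taken from the package):

import numpy as np
import pymc as pm
import pymc_extras as pmx

# Toy stand-in for the eight_schools_model fixture used in the tests below.
with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 5.0)
    tau = pm.HalfNormal("tau", 5.0)
    theta = pm.Normal("theta", mu, tau, shape=8)
    pm.Normal("obs", theta, 1.0, observed=np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]))

    idata = pmx.fit(
        method="pathfinder",
        num_paths=4,                 # number of single-path pathfinder runs
        num_draws=750,               # total draws kept after importance sampling
        jitter=12.0,                 # spread of the jittered starting points
        importance_sampling="psis",  # or "psir", "identity", None (per the new test)
        random_seed=41,
        inference_backend="pymc",
    )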
tests/test_pathfinder.py
CHANGED
@@ -12,14 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import re
 import sys

 import numpy as np
 import pymc as pm
+import pytensor.tensor as pt
 import pytest

-pytestmark = pytest.mark.filterwarnings("ignore:compile_pymc was renamed to compile:FutureWarning")
-
 import pymc_extras as pmx


@@ -44,15 +44,98 @@ def reference_idata():
     with model:
         idata = pmx.fit(
             method="pathfinder",
-            num_paths=
-            jitter=
+            num_paths=10,
+            jitter=12.0,
             random_seed=41,
             inference_backend="pymc",
         )
     return idata


+def unstable_lbfgs_update_mask_model() -> pm.Model:
+    # data and model from: https://github.com/pymc-devs/pymc-extras/issues/445
+    # this scenario made LBFGS struggle leading to a lot of rejected iterations, (result.nit being moderate, but only history.count <= 1).
+    # this scenario is used to test that the LBFGS history manager is rejecting iterations as expected and PF can run to completion.
+
+    # fmt: off
+    inp = np.array([0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 2, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 2, 0, 1, 0, 0, 0, 0, 1, 1, 1, 2, 0, 1, 2, 1, 0, 1, 0, 1, 0, 1, 0])
+
+    res = np.array([[0,0,1,0,0],[0,0,1,0,0],[0,0,1,0,0],[0,1,0,0,0],[0,0,1,0,0],[0,1,0,0,0],[0,0,0,1,0],[0,0,0,1,0],[1,0,0,0,0],[0,1,0,0,0],[0,0,1,0,0],[1,0,0,0,0],[0,0,1,0,0],[0,1,0,0,0],[0,0,0,1,0],[0,0,1,0,0],[0,0,1,0,0],[0,0,0,1,0],[0,0,0,1,0],[0,1,0,0,0],[0,0,0,1,0],[0,0,1,0,0],[0,1,0,0,0],[1,0,0,0,0],[1,0,0,0,0],[0,0,1,0,0],[0,0,1,0,0],[0,0,1,0,0],[0,0,1,0,0],[0,1,0,0,0],[0,1,0,0,0],[0,0,0,1,0],[0,0,1,0,0],[0,1,0,0,0],[0,0,1,0,0],[0,0,0,1,0],[0,0,1,0,0],[1,0,0,0,0],[1,0,0,0,0],[0,0,1,0,0],[0,0,0,1,0],[0,0,0,1,0],[1,0,0,0,0],[1,0,0,0,0],[0,1,0,0,0],[1,0,0,0,0],[0,0,1,0,0],[0,0,1,0,0],[1,0,0,0,0],[0,0,0,1,0]])
+    # fmt: on
+
+    n_ordered = res.shape[1]
+    coords = {
+        "obs": np.arange(len(inp)),
+        "inp": np.arange(max(inp) + 1),
+        "outp": np.arange(res.shape[1]),
+    }
+    with pm.Model(coords=coords) as mdl:
+        mu = pm.Normal("intercept", sigma=3.5)[None]
+
+        offset = pm.Normal(
+            "offset", dims=("inp"), transform=pm.distributions.transforms.ZeroSumTransform([0])
+        )
+
+        scale = 3.5 * pm.HalfStudentT("scale", nu=5)
+        mu += (scale * offset)[inp]
+
+        phi_delta = pm.Dirichlet("phi_diffs", [1.0] * (n_ordered - 1))
+        phi = pt.concatenate([[0], pt.cumsum(phi_delta)])
+        s_mu = pm.Normal(
+            "stereotype_intercept",
+            size=n_ordered,
+            transform=pm.distributions.transforms.ZeroSumTransform([-1]),
+        )
+        fprobs = pm.math.softmax(s_mu[None, :] + phi[None, :] * mu[:, None], axis=-1)
+
+        pm.Multinomial("y_res", p=fprobs, n=np.ones(len(inp)), observed=res, dims=("obs", "outp"))
+
+    return mdl
+
+
+@pytest.mark.parametrize("jitter", [12.0, 500.0, 1000.0])
+def test_unstable_lbfgs_update_mask(capsys, jitter):
+    model = unstable_lbfgs_update_mask_model()
+
+    if jitter < 1000:
+        with model:
+            idata = pmx.fit(
+                method="pathfinder",
+                jitter=jitter,
+                random_seed=4,
+            )
+        out, err = capsys.readouterr()
+        status_pattern = [
+            r"INIT_FAILED_LOW_UPDATE_PCT\s+\d+",
+            r"LOW_UPDATE_PCT\s+\d+",
+            r"LBFGS_FAILED\s+\d+",
+            r"SUCCESS\s+\d+",
+        ]
+        for pattern in status_pattern:
+            assert re.search(pattern, out) is not None
+
+    else:
+        with pytest.raises(ValueError, match="All paths failed"):
+            with model:
+                idata = pmx.fit(
+                    method="pathfinder",
+                    jitter=1000,
+                    random_seed=2,
+                    num_paths=4,
+                )
+        out, err = capsys.readouterr()
+
+        status_pattern = [
+            r"INIT_FAILED_LOW_UPDATE_PCT\s+2",
+            r"LOW_UPDATE_PCT\s+2",
+            r"LBFGS_FAILED\s+4",
+        ]
+        for pattern in status_pattern:
+            assert re.search(pattern, out) is not None
+
+
 @pytest.mark.parametrize("inference_backend", ["pymc", "blackjax"])
+@pytest.mark.filterwarnings("ignore:JAXopt is no longer maintained.:DeprecationWarning")
 def test_pathfinder(inference_backend, reference_idata):
     if inference_backend == "blackjax" and sys.platform == "win32":
         pytest.skip("JAX not supported on windows")
@@ -62,15 +145,15 @@ def test_pathfinder(inference_backend, reference_idata):
         with model:
             idata = pmx.fit(
                 method="pathfinder",
-                num_paths=
-                jitter=
+                num_paths=10,
+                jitter=12.0,
                 random_seed=41,
                 inference_backend=inference_backend,
             )
     else:
         idata = reference_idata
-    np.testing.assert_allclose(idata.posterior["mu"].mean(), 5.0, atol=
-    np.testing.assert_allclose(idata.posterior["tau"].mean(), 4.15, atol=1.
+    np.testing.assert_allclose(idata.posterior["mu"].mean(), 5.0, atol=0.95)
+    np.testing.assert_allclose(idata.posterior["tau"].mean(), 4.15, atol=1.35)

     assert idata.posterior["mu"].shape == (1, 1000)
     assert idata.posterior["tau"].shape == (1, 1000)
@@ -83,8 +166,8 @@ def test_concurrent_results(reference_idata, concurrent):
     with model:
         idata_conc = pmx.fit(
             method="pathfinder",
-            num_paths=
-            jitter=
+            num_paths=10,
+            jitter=12.0,
             random_seed=41,
             inference_backend="pymc",
             concurrent=concurrent,
@@ -108,7 +191,7 @@ def test_seed(reference_idata):
     with model:
         idata_41 = pmx.fit(
             method="pathfinder",
-            num_paths=
+            num_paths=4,
             jitter=10.0,
             random_seed=41,
             inference_backend="pymc",
@@ -116,7 +199,7 @@

         idata_123 = pmx.fit(
             method="pathfinder",
-            num_paths=
+            num_paths=4,
             jitter=10.0,
             random_seed=123,
             inference_backend="pymc",
@@ -149,12 +232,11 @@ def test_bfgs_sample():
     # get factors
     x_full = pt.as_tensor(x_data, dtype="float64")
     g_full = pt.as_tensor(g_data, dtype="float64")
-    epsilon = 1e-11

     x = x_full[1:]
     g = g_full[1:]
-    alpha,
-    beta, gamma = inverse_hessian_factors(alpha,
+    alpha, s, z = alpha_recover(x_full, g_full)
+    beta, gamma = inverse_hessian_factors(alpha, s, z, J)

     # sample
     phi, logq = bfgs_sample(
@@ -169,5 +251,47 @@
     # check shapes
     assert beta.eval().shape == (L, N, 2 * J)
     assert gamma.eval().shape == (L, 2 * J, 2 * J)
-    assert phi.eval()
-    assert logq.eval()
+    assert all(phi.shape.eval() == (L, num_samples, N))
+    assert all(logq.shape.eval() == (L, num_samples))
+
+
+@pytest.mark.parametrize("importance_sampling", ["psis", "psir", "identity", None])
+def test_pathfinder_importance_sampling(importance_sampling):
+    model = eight_schools_model()
+
+    num_paths = 4
+    num_draws_per_path = 300
+    num_draws = 750
+
+    with model:
+        idata = pmx.fit(
+            method="pathfinder",
+            num_paths=num_paths,
+            num_draws_per_path=num_draws_per_path,
+            num_draws=num_draws,
+            maxiter=5,
+            random_seed=41,
+            inference_backend="pymc",
+            importance_sampling=importance_sampling,
+        )
+
+    if importance_sampling is None:
+        assert idata.posterior["mu"].shape == (num_paths, num_draws_per_path)
+        assert idata.posterior["tau"].shape == (num_paths, num_draws_per_path)
+        assert idata.posterior["theta"].shape == (num_paths, num_draws_per_path, 8)
+    else:
+        assert idata.posterior["mu"].shape == (1, num_draws)
+        assert idata.posterior["tau"].shape == (1, num_draws)
+        assert idata.posterior["theta"].shape == (1, num_draws, 8)
+
+
+def test_pathfinder_initvals():
+    # Run a model with an ordered transform that will fail unless initvals are in place
+    with pm.Model() as mdl:
+        pm.Normal("ordered", size=10, transform=pm.distributions.transforms.ordered)
+        idata = pmx.fit_pathfinder(initvals={"ordered": np.linspace(0, 1, 10)})
+
+    # Check that the samples are ordered to make sure transform was applied
+    assert np.all(
+        idata.posterior["ordered"][..., 1:].values > idata.posterior["ordered"][..., :-1].values
+    )
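The new test_pathfinder_initvals above covers the initial-value support added to the pathfinder entry point; a standalone sketch of the same usage, assuming pymc-extras 0.2.5 (the final assertion is an illustrative check, not part of the package):

import numpy as np
import pymc as pm
import pymc_extras as pmx

# An ordered transform needs strictly increasing starting values, so they are
# passed explicitly via initvals (mirrors test_pathfinder_initvals above).
with pm.Model():
    pm.Normal("ordered", size=10, transform=pm.distributions.transforms.ordered)
    idata = pmx.fit_pathfinder(initvals={"ordered": np.linspace(0, 1, 10)})

# Draws stay sorted along the last axis when the transform has been applied.
assert np.all(np.diff(idata.posterior["ordered"].values, axis=-1) > 0)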
{pymc_extras-0.2.3.dist-info → pymc_extras-0.2.5.dist-info/licenses}/LICENSE
File without changes
{pymc_extras-0.2.3.dist-info → pymc_extras-0.2.5.dist-info}/top_level.txt
File without changes