pymc-extras 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tests/test_pathfinder.py CHANGED
@@ -12,14 +12,14 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
+ import re
15
16
  import sys
16
17
 
17
18
  import numpy as np
18
19
  import pymc as pm
20
+ import pytensor.tensor as pt
19
21
  import pytest
20
22
 
21
- pytestmark = pytest.mark.filterwarnings("ignore:compile_pymc was renamed to compile:FutureWarning")
22
-
23
23
  import pymc_extras as pmx
24
24
 
25
25
 
@@ -44,15 +44,98 @@ def reference_idata():
44
44
  with model:
45
45
  idata = pmx.fit(
46
46
  method="pathfinder",
47
- num_paths=50,
48
- jitter=10.0,
47
+ num_paths=10,
48
+ jitter=12.0,
49
49
  random_seed=41,
50
50
  inference_backend="pymc",
51
51
  )
52
52
  return idata
53
53
 
54
54
 
55
+ def unstable_lbfgs_update_mask_model() -> pm.Model:
56
+ # data and model from: https://github.com/pymc-devs/pymc-extras/issues/445
57
+ # This scenario made LBFGS struggle, leading to many rejected iterations (result.nit was moderate, but history.count was only <= 1).
58
+ # This scenario is used to test that the LBFGS history manager rejects iterations as expected and that Pathfinder (PF) can run to completion.
59
+
60
+ # fmt: off
61
+ inp = np.array([0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 2, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 2, 0, 1, 0, 0, 0, 0, 1, 1, 1, 2, 0, 1, 2, 1, 0, 1, 0, 1, 0, 1, 0])
62
+
63
+ res = np.array([[0,0,1,0,0],[0,0,1,0,0],[0,0,1,0,0],[0,1,0,0,0],[0,0,1,0,0],[0,1,0,0,0],[0,0,0,1,0],[0,0,0,1,0],[1,0,0,0,0],[0,1,0,0,0],[0,0,1,0,0],[1,0,0,0,0],[0,0,1,0,0],[0,1,0,0,0],[0,0,0,1,0],[0,0,1,0,0],[0,0,1,0,0],[0,0,0,1,0],[0,0,0,1,0],[0,1,0,0,0],[0,0,0,1,0],[0,0,1,0,0],[0,1,0,0,0],[1,0,0,0,0],[1,0,0,0,0],[0,0,1,0,0],[0,0,1,0,0],[0,0,1,0,0],[0,0,1,0,0],[0,1,0,0,0],[0,1,0,0,0],[0,0,0,1,0],[0,0,1,0,0],[0,1,0,0,0],[0,0,1,0,0],[0,0,0,1,0],[0,0,1,0,0],[1,0,0,0,0],[1,0,0,0,0],[0,0,1,0,0],[0,0,0,1,0],[0,0,0,1,0],[1,0,0,0,0],[1,0,0,0,0],[0,1,0,0,0],[1,0,0,0,0],[0,0,1,0,0],[0,0,1,0,0],[1,0,0,0,0],[0,0,0,1,0]])
64
+ # fmt: on
65
+
66
+ n_ordered = res.shape[1]
67
+ coords = {
68
+ "obs": np.arange(len(inp)),
69
+ "inp": np.arange(max(inp) + 1),
70
+ "outp": np.arange(res.shape[1]),
71
+ }
72
+ with pm.Model(coords=coords) as mdl:
73
+ mu = pm.Normal("intercept", sigma=3.5)[None]
74
+
75
+ offset = pm.Normal(
76
+ "offset", dims=("inp"), transform=pm.distributions.transforms.ZeroSumTransform([0])
77
+ )
78
+
79
+ scale = 3.5 * pm.HalfStudentT("scale", nu=5)
80
+ mu += (scale * offset)[inp]
81
+
82
+ phi_delta = pm.Dirichlet("phi_diffs", [1.0] * (n_ordered - 1))
83
+ phi = pt.concatenate([[0], pt.cumsum(phi_delta)])
84
+ s_mu = pm.Normal(
85
+ "stereotype_intercept",
86
+ size=n_ordered,
87
+ transform=pm.distributions.transforms.ZeroSumTransform([-1]),
88
+ )
89
+ fprobs = pm.math.softmax(s_mu[None, :] + phi[None, :] * mu[:, None], axis=-1)
90
+
91
+ pm.Multinomial("y_res", p=fprobs, n=np.ones(len(inp)), observed=res, dims=("obs", "outp"))
92
+
93
+ return mdl
94
+
95
+
96
+ @pytest.mark.parametrize("jitter", [12.0, 500.0, 1000.0])
97
+ def test_unstable_lbfgs_update_mask(capsys, jitter):
98
+ model = unstable_lbfgs_update_mask_model()
99
+
100
+ if jitter < 1000:
101
+ with model:
102
+ idata = pmx.fit(
103
+ method="pathfinder",
104
+ jitter=jitter,
105
+ random_seed=4,
106
+ )
107
+ out, err = capsys.readouterr()
108
+ status_pattern = [
109
+ r"INIT_FAILED_LOW_UPDATE_PCT\s+\d+",
110
+ r"LOW_UPDATE_PCT\s+\d+",
111
+ r"LBFGS_FAILED\s+\d+",
112
+ r"SUCCESS\s+\d+",
113
+ ]
114
+ for pattern in status_pattern:
115
+ assert re.search(pattern, out) is not None
116
+
117
+ else:
118
+ with pytest.raises(ValueError, match="All paths failed"):
119
+ with model:
120
+ idata = pmx.fit(
121
+ method="pathfinder",
122
+ jitter=1000,
123
+ random_seed=2,
124
+ num_paths=4,
125
+ )
126
+ out, err = capsys.readouterr()
127
+
128
+ status_pattern = [
129
+ r"INIT_FAILED_LOW_UPDATE_PCT\s+2",
130
+ r"LOW_UPDATE_PCT\s+2",
131
+ r"LBFGS_FAILED\s+4",
132
+ ]
133
+ for pattern in status_pattern:
134
+ assert re.search(pattern, out) is not None
135
+
136
+
55
137
  @pytest.mark.parametrize("inference_backend", ["pymc", "blackjax"])
138
+ @pytest.mark.filterwarnings("ignore:JAXopt is no longer maintained.:DeprecationWarning")
56
139
  def test_pathfinder(inference_backend, reference_idata):
57
140
  if inference_backend == "blackjax" and sys.platform == "win32":
58
141
  pytest.skip("JAX not supported on windows")
@@ -62,15 +145,15 @@ def test_pathfinder(inference_backend, reference_idata):
62
145
  with model:
63
146
  idata = pmx.fit(
64
147
  method="pathfinder",
65
- num_paths=50,
66
- jitter=10.0,
148
+ num_paths=10,
149
+ jitter=12.0,
67
150
  random_seed=41,
68
151
  inference_backend=inference_backend,
69
152
  )
70
153
  else:
71
154
  idata = reference_idata
72
- np.testing.assert_allclose(idata.posterior["mu"].mean(), 5.0, atol=1.6)
73
- np.testing.assert_allclose(idata.posterior["tau"].mean(), 4.15, atol=1.5)
155
+ np.testing.assert_allclose(idata.posterior["mu"].mean(), 5.0, atol=0.95)
156
+ np.testing.assert_allclose(idata.posterior["tau"].mean(), 4.15, atol=1.35)
74
157
 
75
158
  assert idata.posterior["mu"].shape == (1, 1000)
76
159
  assert idata.posterior["tau"].shape == (1, 1000)
@@ -83,8 +166,8 @@ def test_concurrent_results(reference_idata, concurrent):
83
166
  with model:
84
167
  idata_conc = pmx.fit(
85
168
  method="pathfinder",
86
- num_paths=50,
87
- jitter=10.0,
169
+ num_paths=10,
170
+ jitter=12.0,
88
171
  random_seed=41,
89
172
  inference_backend="pymc",
90
173
  concurrent=concurrent,
@@ -108,7 +191,7 @@ def test_seed(reference_idata):
108
191
  with model:
109
192
  idata_41 = pmx.fit(
110
193
  method="pathfinder",
111
- num_paths=50,
194
+ num_paths=4,
112
195
  jitter=10.0,
113
196
  random_seed=41,
114
197
  inference_backend="pymc",
@@ -116,7 +199,7 @@ def test_seed(reference_idata):
116
199
 
117
200
  idata_123 = pmx.fit(
118
201
  method="pathfinder",
119
- num_paths=50,
202
+ num_paths=4,
120
203
  jitter=10.0,
121
204
  random_seed=123,
122
205
  inference_backend="pymc",
@@ -149,12 +232,11 @@ def test_bfgs_sample():
149
232
  # get factors
150
233
  x_full = pt.as_tensor(x_data, dtype="float64")
151
234
  g_full = pt.as_tensor(g_data, dtype="float64")
152
- epsilon = 1e-11
153
235
 
154
236
  x = x_full[1:]
155
237
  g = g_full[1:]
156
- alpha, S, Z, update_mask = alpha_recover(x_full, g_full, epsilon)
157
- beta, gamma = inverse_hessian_factors(alpha, S, Z, update_mask, J)
238
+ alpha, s, z = alpha_recover(x_full, g_full)
239
+ beta, gamma = inverse_hessian_factors(alpha, s, z, J)
158
240
 
159
241
  # sample
160
242
  phi, logq = bfgs_sample(
@@ -169,5 +251,47 @@ def test_bfgs_sample():
169
251
  # check shapes
170
252
  assert beta.eval().shape == (L, N, 2 * J)
171
253
  assert gamma.eval().shape == (L, 2 * J, 2 * J)
172
- assert phi.eval().shape == (L, num_samples, N)
173
- assert logq.eval().shape == (L, num_samples)
254
+ assert all(phi.shape.eval() == (L, num_samples, N))
255
+ assert all(logq.shape.eval() == (L, num_samples))
256
+
257
+
258
+ @pytest.mark.parametrize("importance_sampling", ["psis", "psir", "identity", None])
259
+ def test_pathfinder_importance_sampling(importance_sampling):
260
+ model = eight_schools_model()
261
+
262
+ num_paths = 4
263
+ num_draws_per_path = 300
264
+ num_draws = 750
265
+
266
+ with model:
267
+ idata = pmx.fit(
268
+ method="pathfinder",
269
+ num_paths=num_paths,
270
+ num_draws_per_path=num_draws_per_path,
271
+ num_draws=num_draws,
272
+ maxiter=5,
273
+ random_seed=41,
274
+ inference_backend="pymc",
275
+ importance_sampling=importance_sampling,
276
+ )
277
+
278
+ if importance_sampling is None:
279
+ assert idata.posterior["mu"].shape == (num_paths, num_draws_per_path)
280
+ assert idata.posterior["tau"].shape == (num_paths, num_draws_per_path)
281
+ assert idata.posterior["theta"].shape == (num_paths, num_draws_per_path, 8)
282
+ else:
283
+ assert idata.posterior["mu"].shape == (1, num_draws)
284
+ assert idata.posterior["tau"].shape == (1, num_draws)
285
+ assert idata.posterior["theta"].shape == (1, num_draws, 8)
286
+
287
+
288
+ def test_pathfinder_initvals():
289
+ # Run a model with an ordered transform that will fail unless initvals are in place
290
+ with pm.Model() as mdl:
291
+ pm.Normal("ordered", size=10, transform=pm.distributions.transforms.ordered)
292
+ idata = pmx.fit_pathfinder(initvals={"ordered": np.linspace(0, 1, 10)})
293
+
294
+ # Check that the samples are ordered to make sure transform was applied
295
+ assert np.all(
296
+ idata.posterior["ordered"][..., 1:].values > idata.posterior["ordered"][..., :-1].values
297
+ )