pymc-extras 0.2.5__py3-none-any.whl → 0.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. pymc_extras/__init__.py +5 -1
  2. pymc_extras/deserialize.py +224 -0
  3. pymc_extras/distributions/continuous.py +3 -2
  4. pymc_extras/distributions/discrete.py +3 -1
  5. pymc_extras/inference/find_map.py +62 -17
  6. pymc_extras/inference/laplace.py +10 -7
  7. pymc_extras/prior.py +1356 -0
  8. pymc_extras/statespace/core/statespace.py +191 -52
  9. pymc_extras/statespace/filters/distributions.py +15 -16
  10. pymc_extras/statespace/filters/kalman_filter.py +1 -18
  11. pymc_extras/statespace/filters/kalman_smoother.py +2 -6
  12. pymc_extras/statespace/models/ETS.py +10 -0
  13. pymc_extras/statespace/models/SARIMAX.py +26 -5
  14. pymc_extras/statespace/models/VARMAX.py +12 -2
  15. pymc_extras/statespace/models/structural.py +18 -5
  16. pymc_extras-0.2.7.dist-info/METADATA +321 -0
  17. pymc_extras-0.2.7.dist-info/RECORD +66 -0
  18. {pymc_extras-0.2.5.dist-info → pymc_extras-0.2.7.dist-info}/WHEEL +1 -2
  19. pymc_extras/utils/pivoted_cholesky.py +0 -69
  20. pymc_extras/version.py +0 -11
  21. pymc_extras/version.txt +0 -1
  22. pymc_extras-0.2.5.dist-info/METADATA +0 -112
  23. pymc_extras-0.2.5.dist-info/RECORD +0 -108
  24. pymc_extras-0.2.5.dist-info/top_level.txt +0 -2
  25. tests/__init__.py +0 -13
  26. tests/distributions/__init__.py +0 -19
  27. tests/distributions/test_continuous.py +0 -185
  28. tests/distributions/test_discrete.py +0 -210
  29. tests/distributions/test_discrete_markov_chain.py +0 -258
  30. tests/distributions/test_multivariate.py +0 -304
  31. tests/distributions/test_transform.py +0 -77
  32. tests/model/__init__.py +0 -0
  33. tests/model/marginal/__init__.py +0 -0
  34. tests/model/marginal/test_distributions.py +0 -132
  35. tests/model/marginal/test_graph_analysis.py +0 -182
  36. tests/model/marginal/test_marginal_model.py +0 -967
  37. tests/model/test_model_api.py +0 -38
  38. tests/statespace/__init__.py +0 -0
  39. tests/statespace/test_ETS.py +0 -411
  40. tests/statespace/test_SARIMAX.py +0 -405
  41. tests/statespace/test_VARMAX.py +0 -184
  42. tests/statespace/test_coord_assignment.py +0 -181
  43. tests/statespace/test_distributions.py +0 -270
  44. tests/statespace/test_kalman_filter.py +0 -326
  45. tests/statespace/test_representation.py +0 -175
  46. tests/statespace/test_statespace.py +0 -872
  47. tests/statespace/test_statespace_JAX.py +0 -156
  48. tests/statespace/test_structural.py +0 -836
  49. tests/statespace/utilities/__init__.py +0 -0
  50. tests/statespace/utilities/shared_fixtures.py +0 -9
  51. tests/statespace/utilities/statsmodel_local_level.py +0 -42
  52. tests/statespace/utilities/test_helpers.py +0 -310
  53. tests/test_blackjax_smc.py +0 -222
  54. tests/test_find_map.py +0 -103
  55. tests/test_histogram_approximation.py +0 -109
  56. tests/test_laplace.py +0 -281
  57. tests/test_linearmodel.py +0 -208
  58. tests/test_model_builder.py +0 -306
  59. tests/test_pathfinder.py +0 -297
  60. tests/test_pivoted_cholesky.py +0 -24
  61. tests/test_printing.py +0 -98
  62. tests/test_prior_from_trace.py +0 -172
  63. tests/test_splines.py +0 -77
  64. tests/utils.py +0 -0
  65. {pymc_extras-0.2.5.dist-info → pymc_extras-0.2.7.dist-info}/licenses/LICENSE +0 -0
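The headline additions in 0.2.7 are the new pymc_extras/prior.py and pymc_extras/deserialize.py modules listed above, which appear to port the declarative Prior container and its serialization registry from pymc-marketing. The sketch below is a hypothetical usage example written against that upstream API rather than code taken from this diff, so the names and signatures shown (Prior, create_variable, to_dict, deserialize) are assumptions and may differ in detail:

    import pymc as pm

    from pymc_extras.deserialize import deserialize
    from pymc_extras.prior import Prior

    # Declare a prior by distribution name and parameters (assumed API).
    sigma_prior = Prior("HalfNormal", sigma=1)

    # Realize it as a PyMC random variable inside a model context.
    with pm.Model():
        sigma = sigma_prior.create_variable("sigma")

    # Priors are assumed to round-trip through plain dictionaries;
    # deserialize() rebuilds the object from its dictionary form.
    restored = deserialize(sigma_prior.to_dict())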
tests/test_model_builder.py DELETED
@@ -1,306 +0,0 @@
- # Copyright 2023 The PyMC Developers
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- import hashlib
- import json
- import sys
- import tempfile
-
- import numpy as np
- import pandas as pd
- import pymc as pm
- import pytest
-
- from pymc_extras.model_builder import ModelBuilder
-
-
- @pytest.fixture(scope="module")
- def toy_X():
-     x = np.linspace(start=0, stop=1, num=100)
-     X = pd.DataFrame({"input": x})
-     return X
-
-
- @pytest.fixture(scope="module")
- def toy_y(toy_X):
-     y = 5 * toy_X["input"] + 3
-     y = y + np.random.normal(0, 1, size=len(toy_X))
-     y = pd.Series(y, name="output")
-     return y
-
-
- def get_unfitted_model_instance(X, y):
-     """Creates an unfitted model instance to which idata can be copied in
-     and then used as a fitted model instance. That way a fitted model
-     can be used multiple times without having to run `fit` multiple times."""
-     sampler_config = {
-         "draws": 20,
-         "tune": 10,
-         "chains": 2,
-         "target_accept": 0.95,
-     }
-     model_config = {
-         "a": {"loc": 0, "scale": 10, "dims": ("numbers",)},
-         "b": {"loc": 0, "scale": 10},
-         "obs_error": 2,
-     }
-     model = test_ModelBuilder(
-         model_config=model_config, sampler_config=sampler_config, test_parameter="test_paramter"
-     )
-     # Do the things that `model.fit` does except sample to create idata.
-     model._generate_and_preprocess_model_data(X, y.values.flatten())
-     model.build_model(X, y)
-     return model
-
-
- @pytest.fixture(scope="module")
- def fitted_model_instance_base(toy_X, toy_y):
-     """Because fitting takes a relatively long time, this is intended to
-     be used only once and then have new instances created and fit data patched in
-     for tests that use a fitted model instance. Tests should use
-     `fitted_model_instance` instead of this."""
-     model = get_unfitted_model_instance(toy_X, toy_y)
-     model.fit(toy_X, toy_y)
-     return model
-
-
- @pytest.fixture
- def fitted_model_instance(toy_X, toy_y, fitted_model_instance_base):
-     """Get a fitted model instance. A new instance is created and fit data is
-     patched in, so tests using this fixture can modify the model object without
-     affecting other tests."""
-     model = get_unfitted_model_instance(toy_X, toy_y)
-     model.idata = fitted_model_instance_base.idata.copy()
-     return model
-
-
- class test_ModelBuilder(ModelBuilder):
-     def __init__(self, model_config=None, sampler_config=None, test_parameter=None):
-         self.test_parameter = test_parameter
-         super().__init__(model_config=model_config, sampler_config=sampler_config)
-
-     _model_type = "test_model"
-     version = "0.1"
-
-     def build_model(self, X: pd.DataFrame, y: pd.Series, model_config=None):
-         coords = {"numbers": np.arange(len(X))}
-         self.generate_and_preprocess_model_data(X, y)
-         with pm.Model(coords=coords) as self.model:
-             if model_config is None:
-                 model_config = self.model_config
-             x = pm.Data("x", self.X["input"].values)
-             y_data = pm.Data("y_data", self.y)
-
-             # prior parameters
-             a_loc = model_config["a"]["loc"]
-             a_scale = model_config["a"]["scale"]
-             b_loc = model_config["b"]["loc"]
-             b_scale = model_config["b"]["scale"]
-             obs_error = model_config["obs_error"]
-
-             # priors
-             a = pm.Normal("a", a_loc, sigma=a_scale, dims=model_config["a"]["dims"])
-             b = pm.Normal("b", b_loc, sigma=b_scale)
-             obs_error = pm.HalfNormal("σ_model_fmc", obs_error)
-
-             # observed data
-             output = pm.Normal("output", a + b * x, obs_error, shape=x.shape, observed=y_data)
-
-     def _save_input_params(self, idata):
-         idata.attrs["test_paramter"] = json.dumps(self.test_parameter)
-
-     @property
-     def output_var(self):
-         return "output"
-
-     def _data_setter(self, x: pd.Series, y: pd.Series = None):
-         with self.model:
-             pm.set_data({"x": x.values})
-             if y is not None:
-                 pm.set_data({"y_data": y.values})
-
-     @property
-     def _serializable_model_config(self):
-         return self.model_config
-
-     def generate_and_preprocess_model_data(self, X: pd.DataFrame, y: pd.Series):
-         self.X = X
-         self.y = y
-
-     @staticmethod
-     def get_default_model_config() -> dict:
-         return {
-             "a": {"loc": 0, "scale": 10, "dims": ("numbers",)},
-             "b": {"loc": 0, "scale": 10},
-             "obs_error": 2,
-         }
-
-     def _generate_and_preprocess_model_data(
-         self, X: pd.DataFrame | pd.Series, y: pd.Series
-     ) -> None:
-         self.X = X
-         self.y = y
-
-     @staticmethod
-     def get_default_sampler_config() -> dict:
-         return {
-             "draws": 10,
-             "tune": 10,
-             "chains": 3,
-             "target_accept": 0.95,
-         }
-
-
- def test_save_input_params(fitted_model_instance):
-     assert fitted_model_instance.idata.attrs["test_paramter"] == '"test_paramter"'
-
-
- @pytest.mark.skipif(
-     sys.platform == "win32", reason="Permissions for temp files not granted on windows CI."
- )
- def test_save_load(fitted_model_instance):
-     temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False)
-     fitted_model_instance.save(temp.name)
-     test_builder2 = test_ModelBuilder.load(temp.name)
-     assert fitted_model_instance.idata.groups() == test_builder2.idata.groups()
-     assert fitted_model_instance.id == test_builder2.id
-     x_pred = np.random.uniform(low=0, high=1, size=100)
-     prediction_data = pd.DataFrame({"input": x_pred})
-     pred1 = fitted_model_instance.predict(prediction_data["input"])
-     pred2 = test_builder2.predict(prediction_data["input"])
-     assert pred1.shape == pred2.shape
-     temp.close()
-
-
- def test_initial_build_and_fit(fitted_model_instance, check_idata=True) -> ModelBuilder:
-     if check_idata:
-         assert fitted_model_instance.idata is not None
-         assert "posterior" in fitted_model_instance.idata.groups()
-
-
- def test_save_without_fit_raises_runtime_error():
-     model_builder = test_ModelBuilder()
-     with pytest.raises(RuntimeError):
-         model_builder.save("saved_model")
-
-
- def test_empty_sampler_config_fit(toy_X, toy_y):
-     sampler_config = {}
-     model_builder = test_ModelBuilder(sampler_config=sampler_config)
-     model_builder.idata = model_builder.fit(X=toy_X, y=toy_y)
-     assert model_builder.idata is not None
-     assert "posterior" in model_builder.idata.groups()
-
-
- def test_fit(fitted_model_instance):
-     prediction_data = pd.DataFrame({"input": np.random.uniform(low=0, high=1, size=100)})
-     pred = fitted_model_instance.predict(prediction_data["input"])
-     post_pred = fitted_model_instance.sample_posterior_predictive(
-         prediction_data["input"], extend_idata=True, combined=True
-     )
-     post_pred[fitted_model_instance.output_var].shape[0] == prediction_data.input.shape
-
-
- def test_fit_no_y(toy_X):
-     model_builder = test_ModelBuilder()
-     model_builder.idata = model_builder.fit(X=toy_X, chains=1, tune=1, draws=1)
-     assert model_builder.model is not None
-     assert model_builder.idata is not None
-     assert "posterior" in model_builder.idata.groups()
-
-
- def test_predict(fitted_model_instance):
-     x_pred = np.random.uniform(low=0, high=1, size=100)
-     prediction_data = pd.DataFrame({"input": x_pred})
-     pred = fitted_model_instance.predict(prediction_data["input"])
-     # Perform elementwise comparison using numpy
-     assert isinstance(pred, np.ndarray)
-     assert len(pred) > 0
-
-
- @pytest.mark.parametrize("combined", [True, False])
- def test_sample_posterior_predictive(fitted_model_instance, combined):
-     n_pred = 100
-     x_pred = np.random.uniform(low=0, high=1, size=n_pred)
-     prediction_data = pd.DataFrame({"input": x_pred})
-     pred = fitted_model_instance.sample_posterior_predictive(
-         prediction_data["input"], combined=combined, extend_idata=True
-     )
-     chains = fitted_model_instance.idata.sample_stats.sizes["chain"]
-     draws = fitted_model_instance.idata.sample_stats.sizes["draw"]
-     expected_shape = (n_pred, chains * draws) if combined else (chains, draws, n_pred)
-     assert pred[fitted_model_instance.output_var].shape == expected_shape
-     assert np.issubdtype(pred[fitted_model_instance.output_var].dtype, np.floating)
-
-
- @pytest.mark.parametrize("group", ["prior_predictive", "posterior_predictive"])
- @pytest.mark.parametrize("extend_idata", [True, False])
- def test_sample_xxx_extend_idata_param(fitted_model_instance, group, extend_idata):
-     output_var = fitted_model_instance.output_var
-     idata_prev = fitted_model_instance.idata[group][output_var]
-
-     # Since coordinates are provided, the dimension must match
-     n_pred = 100 # Must match toy_x
-     x_pred = np.random.uniform(0, 1, n_pred)
-
-     prediction_data = pd.DataFrame({"input": x_pred})
-     if group == "prior_predictive":
-         prediction_method = fitted_model_instance.sample_prior_predictive
-     else: # group == "posterior_predictive":
-         prediction_method = fitted_model_instance.sample_posterior_predictive
-
-     pred = prediction_method(prediction_data["input"], combined=False, extend_idata=extend_idata)
-
-     pred_unstacked = pred[output_var].values
-     idata_now = fitted_model_instance.idata[group][output_var].values
-
-     if extend_idata:
-         # After sampling, data in the model should be the same as the predictions
-         np.testing.assert_array_equal(idata_now, pred_unstacked)
-         # Data in the model should NOT be the same as before
-         if idata_now.shape == idata_prev.values.shape:
-             assert np.sum(np.abs(idata_now - idata_prev.values) < 1e-5) <= 2
-     else:
-         # After sampling, data in the model should be the same as it was before
-         np.testing.assert_array_equal(idata_now, idata_prev.values)
-         # Data in the model should NOT be the same as the predictions
-         if idata_now.shape == pred_unstacked.shape:
-             assert np.sum(np.abs(idata_now - pred_unstacked) < 1e-5) <= 2
-
-
- def test_model_config_formatting():
-     model_config = {
-         "a": {
-             "loc": [0, 0],
-             "scale": 10,
-             "dims": [
-                 "x",
-             ],
-         },
-     }
-     model_builder = test_ModelBuilder()
-     converted_model_config = model_builder._model_config_formatting(model_config)
-     np.testing.assert_equal(converted_model_config["a"]["dims"], ("x",))
-     np.testing.assert_equal(converted_model_config["a"]["loc"], np.array([0, 0]))
-
-
- def test_id():
-     model_builder = test_ModelBuilder()
-     expected_id = hashlib.sha256(
-         str(model_builder.model_config.values()).encode()
-         + model_builder.version.encode()
-         + model_builder._model_type.encode()
-     ).hexdigest()[:16]
-
-     assert model_builder.id == expected_id
tests/test_pathfinder.py DELETED
@@ -1,297 +0,0 @@
- # Copyright 2022 The PyMC Developers
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- import re
- import sys
-
- import numpy as np
- import pymc as pm
- import pytensor.tensor as pt
- import pytest
-
- import pymc_extras as pmx
-
-
- def eight_schools_model() -> pm.Model:
-     J = 8
-     y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
-     sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])
-
-     with pm.Model() as model:
-         mu = pm.Normal("mu", mu=0.0, sigma=10.0)
-         tau = pm.HalfCauchy("tau", 5.0)
-
-         theta = pm.Normal("theta", mu=0, sigma=1, shape=J)
-         obs = pm.Normal("obs", mu=mu + tau * theta, sigma=sigma, shape=J, observed=y)
-
-     return model
-
-
- @pytest.fixture
- def reference_idata():
-     model = eight_schools_model()
-     with model:
-         idata = pmx.fit(
-             method="pathfinder",
-             num_paths=10,
-             jitter=12.0,
-             random_seed=41,
-             inference_backend="pymc",
-         )
-     return idata
-
-
- def unstable_lbfgs_update_mask_model() -> pm.Model:
-     # data and model from: https://github.com/pymc-devs/pymc-extras/issues/445
-     # this scenario made LBFGS struggle leading to a lot of rejected iterations, (result.nit being moderate, but only history.count <= 1).
-     # this scenario is used to test that the LBFGS history manager is rejecting iterations as expected and PF can run to completion.
-
-     # fmt: off
-     inp = np.array([0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 2, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 2, 0, 1, 0, 0, 0, 0, 1, 1, 1, 2, 0, 1, 2, 1, 0, 1, 0, 1, 0, 1, 0])
-
-     res = np.array([[0,0,1,0,0],[0,0,1,0,0],[0,0,1,0,0],[0,1,0,0,0],[0,0,1,0,0],[0,1,0,0,0],[0,0,0,1,0],[0,0,0,1,0],[1,0,0,0,0],[0,1,0,0,0],[0,0,1,0,0],[1,0,0,0,0],[0,0,1,0,0],[0,1,0,0,0],[0,0,0,1,0],[0,0,1,0,0],[0,0,1,0,0],[0,0,0,1,0],[0,0,0,1,0],[0,1,0,0,0],[0,0,0,1,0],[0,0,1,0,0],[0,1,0,0,0],[1,0,0,0,0],[1,0,0,0,0],[0,0,1,0,0],[0,0,1,0,0],[0,0,1,0,0],[0,0,1,0,0],[0,1,0,0,0],[0,1,0,0,0],[0,0,0,1,0],[0,0,1,0,0],[0,1,0,0,0],[0,0,1,0,0],[0,0,0,1,0],[0,0,1,0,0],[1,0,0,0,0],[1,0,0,0,0],[0,0,1,0,0],[0,0,0,1,0],[0,0,0,1,0],[1,0,0,0,0],[1,0,0,0,0],[0,1,0,0,0],[1,0,0,0,0],[0,0,1,0,0],[0,0,1,0,0],[1,0,0,0,0],[0,0,0,1,0]])
-     # fmt: on
-
-     n_ordered = res.shape[1]
-     coords = {
-         "obs": np.arange(len(inp)),
-         "inp": np.arange(max(inp) + 1),
-         "outp": np.arange(res.shape[1]),
-     }
-     with pm.Model(coords=coords) as mdl:
-         mu = pm.Normal("intercept", sigma=3.5)[None]
-
-         offset = pm.Normal(
-             "offset", dims=("inp"), transform=pm.distributions.transforms.ZeroSumTransform([0])
-         )
-
-         scale = 3.5 * pm.HalfStudentT("scale", nu=5)
-         mu += (scale * offset)[inp]
-
-         phi_delta = pm.Dirichlet("phi_diffs", [1.0] * (n_ordered - 1))
-         phi = pt.concatenate([[0], pt.cumsum(phi_delta)])
-         s_mu = pm.Normal(
-             "stereotype_intercept",
-             size=n_ordered,
-             transform=pm.distributions.transforms.ZeroSumTransform([-1]),
-         )
-         fprobs = pm.math.softmax(s_mu[None, :] + phi[None, :] * mu[:, None], axis=-1)
-
-         pm.Multinomial("y_res", p=fprobs, n=np.ones(len(inp)), observed=res, dims=("obs", "outp"))
-
-     return mdl
-
-
- @pytest.mark.parametrize("jitter", [12.0, 500.0, 1000.0])
- def test_unstable_lbfgs_update_mask(capsys, jitter):
-     model = unstable_lbfgs_update_mask_model()
-
-     if jitter < 1000:
-         with model:
-             idata = pmx.fit(
-                 method="pathfinder",
-                 jitter=jitter,
-                 random_seed=4,
-             )
-         out, err = capsys.readouterr()
-         status_pattern = [
-             r"INIT_FAILED_LOW_UPDATE_PCT\s+\d+",
-             r"LOW_UPDATE_PCT\s+\d+",
-             r"LBFGS_FAILED\s+\d+",
-             r"SUCCESS\s+\d+",
-         ]
-         for pattern in status_pattern:
-             assert re.search(pattern, out) is not None
-
-     else:
-         with pytest.raises(ValueError, match="All paths failed"):
-             with model:
-                 idata = pmx.fit(
-                     method="pathfinder",
-                     jitter=1000,
-                     random_seed=2,
-                     num_paths=4,
-                 )
-         out, err = capsys.readouterr()
-
-         status_pattern = [
-             r"INIT_FAILED_LOW_UPDATE_PCT\s+2",
-             r"LOW_UPDATE_PCT\s+2",
-             r"LBFGS_FAILED\s+4",
-         ]
-         for pattern in status_pattern:
-             assert re.search(pattern, out) is not None
-
-
- @pytest.mark.parametrize("inference_backend", ["pymc", "blackjax"])
- @pytest.mark.filterwarnings("ignore:JAXopt is no longer maintained.:DeprecationWarning")
- def test_pathfinder(inference_backend, reference_idata):
-     if inference_backend == "blackjax" and sys.platform == "win32":
-         pytest.skip("JAX not supported on windows")
-
-     if inference_backend == "blackjax":
-         model = eight_schools_model()
-         with model:
-             idata = pmx.fit(
-                 method="pathfinder",
-                 num_paths=10,
-                 jitter=12.0,
-                 random_seed=41,
-                 inference_backend=inference_backend,
-             )
-     else:
-         idata = reference_idata
-         np.testing.assert_allclose(idata.posterior["mu"].mean(), 5.0, atol=0.95)
-         np.testing.assert_allclose(idata.posterior["tau"].mean(), 4.15, atol=1.35)
-
-     assert idata.posterior["mu"].shape == (1, 1000)
-     assert idata.posterior["tau"].shape == (1, 1000)
-     assert idata.posterior["theta"].shape == (1, 1000, 8)
-
-
- @pytest.mark.parametrize("concurrent", ["thread", "process"])
- def test_concurrent_results(reference_idata, concurrent):
-     model = eight_schools_model()
-     with model:
-         idata_conc = pmx.fit(
-             method="pathfinder",
-             num_paths=10,
-             jitter=12.0,
-             random_seed=41,
-             inference_backend="pymc",
-             concurrent=concurrent,
-         )
-
-     np.testing.assert_allclose(
-         reference_idata.posterior.mu.data.mean(),
-         idata_conc.posterior.mu.data.mean(),
-         atol=0.4,
-     )
-
-     np.testing.assert_allclose(
-         reference_idata.posterior.tau.data.mean(),
-         idata_conc.posterior.tau.data.mean(),
-         atol=0.4,
-     )
-
-
- def test_seed(reference_idata):
-     model = eight_schools_model()
-     with model:
-         idata_41 = pmx.fit(
-             method="pathfinder",
-             num_paths=4,
-             jitter=10.0,
-             random_seed=41,
-             inference_backend="pymc",
-         )
-
-         idata_123 = pmx.fit(
-             method="pathfinder",
-             num_paths=4,
-             jitter=10.0,
-             random_seed=123,
-             inference_backend="pymc",
-         )
-
-     assert not np.allclose(idata_41.posterior.mu.data.mean(), idata_123.posterior.mu.data.mean())
-
-     assert np.allclose(idata_41.posterior.mu.data.mean(), idata_41.posterior.mu.data.mean())
-
-
- def test_bfgs_sample():
-     import pytensor.tensor as pt
-
-     from pymc_extras.inference.pathfinder.pathfinder import (
-         alpha_recover,
-         bfgs_sample,
-         inverse_hessian_factors,
-     )
-
-     """test BFGS sampling"""
-     Lp1, N = 8, 10
-     L = Lp1 - 1
-     J = 6
-     num_samples = 1000
-
-     # mock data
-     x_data = np.random.randn(Lp1, N)
-     g_data = np.random.randn(Lp1, N)
-
-     # get factors
-     x_full = pt.as_tensor(x_data, dtype="float64")
-     g_full = pt.as_tensor(g_data, dtype="float64")
-
-     x = x_full[1:]
-     g = g_full[1:]
-     alpha, s, z = alpha_recover(x_full, g_full)
-     beta, gamma = inverse_hessian_factors(alpha, s, z, J)
-
-     # sample
-     phi, logq = bfgs_sample(
-         num_samples=num_samples,
-         x=x,
-         g=g,
-         alpha=alpha,
-         beta=beta,
-         gamma=gamma,
-     )
-
-     # check shapes
-     assert beta.eval().shape == (L, N, 2 * J)
-     assert gamma.eval().shape == (L, 2 * J, 2 * J)
-     assert all(phi.shape.eval() == (L, num_samples, N))
-     assert all(logq.shape.eval() == (L, num_samples))
-
-
- @pytest.mark.parametrize("importance_sampling", ["psis", "psir", "identity", None])
- def test_pathfinder_importance_sampling(importance_sampling):
-     model = eight_schools_model()
-
-     num_paths = 4
-     num_draws_per_path = 300
-     num_draws = 750
-
-     with model:
-         idata = pmx.fit(
-             method="pathfinder",
-             num_paths=num_paths,
-             num_draws_per_path=num_draws_per_path,
-             num_draws=num_draws,
-             maxiter=5,
-             random_seed=41,
-             inference_backend="pymc",
-             importance_sampling=importance_sampling,
-         )
-
-     if importance_sampling is None:
-         assert idata.posterior["mu"].shape == (num_paths, num_draws_per_path)
-         assert idata.posterior["tau"].shape == (num_paths, num_draws_per_path)
-         assert idata.posterior["theta"].shape == (num_paths, num_draws_per_path, 8)
-     else:
-         assert idata.posterior["mu"].shape == (1, num_draws)
-         assert idata.posterior["tau"].shape == (1, num_draws)
-         assert idata.posterior["theta"].shape == (1, num_draws, 8)
-
-
- def test_pathfinder_initvals():
-     # Run a model with an ordered transform that will fail unless initvals are in place
-     with pm.Model() as mdl:
-         pm.Normal("ordered", size=10, transform=pm.distributions.transforms.ordered)
-         idata = pmx.fit_pathfinder(initvals={"ordered": np.linspace(0, 1, 10)})
-
-     # Check that the samples are ordered to make sure transform was applied
-     assert np.all(
-         idata.posterior["ordered"][..., 1:].values > idata.posterior["ordered"][..., :-1].values
-     )
tests/test_pivoted_cholesky.py DELETED
@@ -1,24 +0,0 @@
- # try:
- #     import gpytorch
- #     import torch
- # except ImportError as e:
- #     # print(
- #     #     f"Please install Pytorch and GPyTorch to use this pivoted Cholesky implementation. Error {e}"
- #     # )
- #     pass
- # import numpy as np
- #
- # import pymc_extras as pmx
- #
- #
- # def test_match_gpytorch_linearcg_output():
- #     N = 10
- #     rank = 5
- #     np.random.seed(1234) # nans with seed 1234
- #     K = np.random.randn(N, N)
- #     K = K @ K.T + N * np.eye(N)
- #     K_torch = torch.from_numpy(K)
- #
- #     L_gpt = gpytorch.pivoted_cholesky(K_torch, rank=rank, error_tol=1e-3)
- #     L_np, _ = pmx.utils.pivoted_cholesky(K, max_iter=rank, error_tol=1e-3)
- #     assert np.allclose(L_gpt, L_np.T)