pymc-extras 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101)
  1. pymc_extras/__init__.py +29 -0
  2. pymc_extras/distributions/__init__.py +40 -0
  3. pymc_extras/distributions/continuous.py +351 -0
  4. pymc_extras/distributions/discrete.py +399 -0
  5. pymc_extras/distributions/histogram_utils.py +163 -0
  6. pymc_extras/distributions/multivariate/__init__.py +3 -0
  7. pymc_extras/distributions/multivariate/r2d2m2cp.py +446 -0
  8. pymc_extras/distributions/timeseries.py +356 -0
  9. pymc_extras/gp/__init__.py +18 -0
  10. pymc_extras/gp/latent_approx.py +183 -0
  11. pymc_extras/inference/__init__.py +18 -0
  12. pymc_extras/inference/find_map.py +431 -0
  13. pymc_extras/inference/fit.py +44 -0
  14. pymc_extras/inference/laplace.py +570 -0
  15. pymc_extras/inference/pathfinder.py +134 -0
  16. pymc_extras/inference/smc/__init__.py +13 -0
  17. pymc_extras/inference/smc/sampling.py +451 -0
  18. pymc_extras/linearmodel.py +130 -0
  19. pymc_extras/model/__init__.py +0 -0
  20. pymc_extras/model/marginal/__init__.py +0 -0
  21. pymc_extras/model/marginal/distributions.py +276 -0
  22. pymc_extras/model/marginal/graph_analysis.py +372 -0
  23. pymc_extras/model/marginal/marginal_model.py +595 -0
  24. pymc_extras/model/model_api.py +56 -0
  25. pymc_extras/model/transforms/__init__.py +0 -0
  26. pymc_extras/model/transforms/autoreparam.py +434 -0
  27. pymc_extras/model_builder.py +759 -0
  28. pymc_extras/preprocessing/__init__.py +0 -0
  29. pymc_extras/preprocessing/standard_scaler.py +17 -0
  30. pymc_extras/printing.py +182 -0
  31. pymc_extras/statespace/__init__.py +13 -0
  32. pymc_extras/statespace/core/__init__.py +7 -0
  33. pymc_extras/statespace/core/compile.py +48 -0
  34. pymc_extras/statespace/core/representation.py +438 -0
  35. pymc_extras/statespace/core/statespace.py +2268 -0
  36. pymc_extras/statespace/filters/__init__.py +15 -0
  37. pymc_extras/statespace/filters/distributions.py +453 -0
  38. pymc_extras/statespace/filters/kalman_filter.py +820 -0
  39. pymc_extras/statespace/filters/kalman_smoother.py +126 -0
  40. pymc_extras/statespace/filters/utilities.py +59 -0
  41. pymc_extras/statespace/models/ETS.py +670 -0
  42. pymc_extras/statespace/models/SARIMAX.py +536 -0
  43. pymc_extras/statespace/models/VARMAX.py +393 -0
  44. pymc_extras/statespace/models/__init__.py +6 -0
  45. pymc_extras/statespace/models/structural.py +1651 -0
  46. pymc_extras/statespace/models/utilities.py +387 -0
  47. pymc_extras/statespace/utils/__init__.py +0 -0
  48. pymc_extras/statespace/utils/constants.py +74 -0
  49. pymc_extras/statespace/utils/coord_tools.py +0 -0
  50. pymc_extras/statespace/utils/data_tools.py +182 -0
  51. pymc_extras/utils/__init__.py +23 -0
  52. pymc_extras/utils/linear_cg.py +290 -0
  53. pymc_extras/utils/pivoted_cholesky.py +69 -0
  54. pymc_extras/utils/prior.py +200 -0
  55. pymc_extras/utils/spline.py +131 -0
  56. pymc_extras/version.py +11 -0
  57. pymc_extras/version.txt +1 -0
  58. pymc_extras-0.2.0.dist-info/LICENSE +212 -0
  59. pymc_extras-0.2.0.dist-info/METADATA +99 -0
  60. pymc_extras-0.2.0.dist-info/RECORD +101 -0
  61. pymc_extras-0.2.0.dist-info/WHEEL +5 -0
  62. pymc_extras-0.2.0.dist-info/top_level.txt +2 -0
  63. tests/__init__.py +13 -0
  64. tests/distributions/__init__.py +19 -0
  65. tests/distributions/test_continuous.py +185 -0
  66. tests/distributions/test_discrete.py +210 -0
  67. tests/distributions/test_discrete_markov_chain.py +258 -0
  68. tests/distributions/test_multivariate.py +304 -0
  69. tests/model/__init__.py +0 -0
  70. tests/model/marginal/__init__.py +0 -0
  71. tests/model/marginal/test_distributions.py +131 -0
  72. tests/model/marginal/test_graph_analysis.py +182 -0
  73. tests/model/marginal/test_marginal_model.py +867 -0
  74. tests/model/test_model_api.py +29 -0
  75. tests/statespace/__init__.py +0 -0
  76. tests/statespace/test_ETS.py +411 -0
  77. tests/statespace/test_SARIMAX.py +405 -0
  78. tests/statespace/test_VARMAX.py +184 -0
  79. tests/statespace/test_coord_assignment.py +116 -0
  80. tests/statespace/test_distributions.py +270 -0
  81. tests/statespace/test_kalman_filter.py +326 -0
  82. tests/statespace/test_representation.py +175 -0
  83. tests/statespace/test_statespace.py +818 -0
  84. tests/statespace/test_statespace_JAX.py +156 -0
  85. tests/statespace/test_structural.py +829 -0
  86. tests/statespace/utilities/__init__.py +0 -0
  87. tests/statespace/utilities/shared_fixtures.py +9 -0
  88. tests/statespace/utilities/statsmodel_local_level.py +42 -0
  89. tests/statespace/utilities/test_helpers.py +310 -0
  90. tests/test_blackjax_smc.py +222 -0
  91. tests/test_find_map.py +98 -0
  92. tests/test_histogram_approximation.py +109 -0
  93. tests/test_laplace.py +238 -0
  94. tests/test_linearmodel.py +208 -0
  95. tests/test_model_builder.py +306 -0
  96. tests/test_pathfinder.py +45 -0
  97. tests/test_pivoted_cholesky.py +24 -0
  98. tests/test_printing.py +98 -0
  99. tests/test_prior_from_trace.py +172 -0
  100. tests/test_splines.py +77 -0
  101. tests/utils.py +31 -0
tests/test_model_builder.py ADDED
@@ -0,0 +1,306 @@
+ # Copyright 2023 The PyMC Developers
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #   http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import hashlib
+ import json
+ import sys
+ import tempfile
+
+ import numpy as np
+ import pandas as pd
+ import pymc as pm
+ import pytest
+
+ from pymc_extras.model_builder import ModelBuilder
+
+
+ @pytest.fixture(scope="module")
+ def toy_X():
+     x = np.linspace(start=0, stop=1, num=100)
+     X = pd.DataFrame({"input": x})
+     return X
+
+
+ @pytest.fixture(scope="module")
+ def toy_y(toy_X):
+     y = 5 * toy_X["input"] + 3
+     y = y + np.random.normal(0, 1, size=len(toy_X))
+     y = pd.Series(y, name="output")
+     return y
+
+
+ def get_unfitted_model_instance(X, y):
+     """Creates an unfitted model instance to which idata can be copied in
+     and then used as a fitted model instance. That way a fitted model
+     can be used multiple times without having to run `fit` multiple times."""
+     sampler_config = {
+         "draws": 20,
+         "tune": 10,
+         "chains": 2,
+         "target_accept": 0.95,
+     }
+     model_config = {
+         "a": {"loc": 0, "scale": 10, "dims": ("numbers",)},
+         "b": {"loc": 0, "scale": 10},
+         "obs_error": 2,
+     }
+     model = test_ModelBuilder(
+         model_config=model_config, sampler_config=sampler_config, test_parameter="test_parameter"
+     )
+     # Do the things that `model.fit` does except sample to create idata.
+     model._generate_and_preprocess_model_data(X, y.values.flatten())
+     model.build_model(X, y)
+     return model
+
+
+ @pytest.fixture(scope="module")
+ def fitted_model_instance_base(toy_X, toy_y):
+     """Because fitting takes a relatively long time, this is intended to
+     be used only once and then have new instances created and fit data patched in
+     for tests that use a fitted model instance. Tests should use
+     `fitted_model_instance` instead of this."""
+     model = get_unfitted_model_instance(toy_X, toy_y)
+     model.fit(toy_X, toy_y)
+     return model
+
+
+ @pytest.fixture
+ def fitted_model_instance(toy_X, toy_y, fitted_model_instance_base):
+     """Get a fitted model instance. A new instance is created and fit data is
+     patched in, so tests using this fixture can modify the model object without
+     affecting other tests."""
+     model = get_unfitted_model_instance(toy_X, toy_y)
+     model.idata = fitted_model_instance_base.idata.copy()
+     return model
+
+
+ class test_ModelBuilder(ModelBuilder):
+     def __init__(self, model_config=None, sampler_config=None, test_parameter=None):
+         self.test_parameter = test_parameter
+         super().__init__(model_config=model_config, sampler_config=sampler_config)
+
+     _model_type = "test_model"
+     version = "0.1"
+
+     def build_model(self, X: pd.DataFrame, y: pd.Series, model_config=None):
+         coords = {"numbers": np.arange(len(X))}
+         self.generate_and_preprocess_model_data(X, y)
+         with pm.Model(coords=coords) as self.model:
+             if model_config is None:
+                 model_config = self.model_config
+             x = pm.Data("x", self.X["input"].values)
+             y_data = pm.Data("y_data", self.y)
+
+             # prior parameters
+             a_loc = model_config["a"]["loc"]
+             a_scale = model_config["a"]["scale"]
+             b_loc = model_config["b"]["loc"]
+             b_scale = model_config["b"]["scale"]
+             obs_error = model_config["obs_error"]
+
+             # priors
+             a = pm.Normal("a", a_loc, sigma=a_scale, dims=model_config["a"]["dims"])
+             b = pm.Normal("b", b_loc, sigma=b_scale)
+             obs_error = pm.HalfNormal("σ_model_fmc", obs_error)
+
+             # observed data
+             output = pm.Normal("output", a + b * x, obs_error, shape=x.shape, observed=y_data)
+
+     def _save_input_params(self, idata):
+         idata.attrs["test_parameter"] = json.dumps(self.test_parameter)
+
+     @property
+     def output_var(self):
+         return "output"
+
+     def _data_setter(self, x: pd.Series, y: pd.Series = None):
+         with self.model:
+             pm.set_data({"x": x.values})
+             if y is not None:
+                 pm.set_data({"y_data": y.values})
+
+     @property
+     def _serializable_model_config(self):
+         return self.model_config
+
+     def generate_and_preprocess_model_data(self, X: pd.DataFrame, y: pd.Series):
+         self.X = X
+         self.y = y
+
+     @staticmethod
+     def get_default_model_config() -> dict:
+         return {
+             "a": {"loc": 0, "scale": 10, "dims": ("numbers",)},
+             "b": {"loc": 0, "scale": 10},
+             "obs_error": 2,
+         }
+
+     def _generate_and_preprocess_model_data(
+         self, X: pd.DataFrame | pd.Series, y: pd.Series
+     ) -> None:
+         self.X = X
+         self.y = y
+
+     @staticmethod
+     def get_default_sampler_config() -> dict:
+         return {
+             "draws": 10,
+             "tune": 10,
+             "chains": 3,
+             "target_accept": 0.95,
+         }
+
+
+ def test_save_input_params(fitted_model_instance):
+     assert fitted_model_instance.idata.attrs["test_parameter"] == '"test_parameter"'
+
+
+ @pytest.mark.skipif(
+     sys.platform == "win32", reason="Permissions for temp files not granted on windows CI."
+ )
+ def test_save_load(fitted_model_instance):
+     temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False)
+     fitted_model_instance.save(temp.name)
+     test_builder2 = test_ModelBuilder.load(temp.name)
+     assert fitted_model_instance.idata.groups() == test_builder2.idata.groups()
+     assert fitted_model_instance.id == test_builder2.id
+     x_pred = np.random.uniform(low=0, high=1, size=100)
+     prediction_data = pd.DataFrame({"input": x_pred})
+     pred1 = fitted_model_instance.predict(prediction_data["input"])
+     pred2 = test_builder2.predict(prediction_data["input"])
+     assert pred1.shape == pred2.shape
+     temp.close()
+
+
+ def test_initial_build_and_fit(fitted_model_instance, check_idata=True) -> ModelBuilder:
+     if check_idata:
+         assert fitted_model_instance.idata is not None
+         assert "posterior" in fitted_model_instance.idata.groups()
+
+
+ def test_save_without_fit_raises_runtime_error():
+     model_builder = test_ModelBuilder()
+     with pytest.raises(RuntimeError):
+         model_builder.save("saved_model")
+
+
+ def test_empty_sampler_config_fit(toy_X, toy_y):
+     sampler_config = {}
+     model_builder = test_ModelBuilder(sampler_config=sampler_config)
+     model_builder.idata = model_builder.fit(X=toy_X, y=toy_y)
+     assert model_builder.idata is not None
+     assert "posterior" in model_builder.idata.groups()
+
+
+ def test_fit(fitted_model_instance):
+     prediction_data = pd.DataFrame({"input": np.random.uniform(low=0, high=1, size=100)})
+     pred = fitted_model_instance.predict(prediction_data["input"])
+     post_pred = fitted_model_instance.sample_posterior_predictive(
+         prediction_data["input"], extend_idata=True, combined=True
+     )
+     assert post_pred[fitted_model_instance.output_var].shape[0] == prediction_data.input.shape[0]
+
+
+ def test_fit_no_y(toy_X):
+     model_builder = test_ModelBuilder()
+     model_builder.idata = model_builder.fit(X=toy_X, chains=1, tune=1, draws=1)
+     assert model_builder.model is not None
+     assert model_builder.idata is not None
+     assert "posterior" in model_builder.idata.groups()
+
+
+ def test_predict(fitted_model_instance):
+     x_pred = np.random.uniform(low=0, high=1, size=100)
+     prediction_data = pd.DataFrame({"input": x_pred})
+     pred = fitted_model_instance.predict(prediction_data["input"])
+     # Perform elementwise comparison using numpy
+     assert isinstance(pred, np.ndarray)
+     assert len(pred) > 0
+
+
+ @pytest.mark.parametrize("combined", [True, False])
+ def test_sample_posterior_predictive(fitted_model_instance, combined):
+     n_pred = 100
+     x_pred = np.random.uniform(low=0, high=1, size=n_pred)
+     prediction_data = pd.DataFrame({"input": x_pred})
+     pred = fitted_model_instance.sample_posterior_predictive(
+         prediction_data["input"], combined=combined, extend_idata=True
+     )
+     chains = fitted_model_instance.idata.sample_stats.sizes["chain"]
+     draws = fitted_model_instance.idata.sample_stats.sizes["draw"]
+     expected_shape = (n_pred, chains * draws) if combined else (chains, draws, n_pred)
+     assert pred[fitted_model_instance.output_var].shape == expected_shape
+     assert np.issubdtype(pred[fitted_model_instance.output_var].dtype, np.floating)
+
+
+ @pytest.mark.parametrize("group", ["prior_predictive", "posterior_predictive"])
+ @pytest.mark.parametrize("extend_idata", [True, False])
+ def test_sample_xxx_extend_idata_param(fitted_model_instance, group, extend_idata):
+     output_var = fitted_model_instance.output_var
+     idata_prev = fitted_model_instance.idata[group][output_var]
+
+     # Since coordinates are provided, the dimension must match
+     n_pred = 100  # Must match toy_X
+     x_pred = np.random.uniform(0, 1, n_pred)
+
+     prediction_data = pd.DataFrame({"input": x_pred})
+     if group == "prior_predictive":
+         prediction_method = fitted_model_instance.sample_prior_predictive
+     else:  # group == "posterior_predictive"
+         prediction_method = fitted_model_instance.sample_posterior_predictive
+
+     pred = prediction_method(prediction_data["input"], combined=False, extend_idata=extend_idata)
+
+     pred_unstacked = pred[output_var].values
+     idata_now = fitted_model_instance.idata[group][output_var].values
+
+     if extend_idata:
+         # After sampling, data in the model should be the same as the predictions
+         np.testing.assert_array_equal(idata_now, pred_unstacked)
+         # Data in the model should NOT be the same as before
+         if idata_now.shape == idata_prev.values.shape:
+             assert np.sum(np.abs(idata_now - idata_prev.values) < 1e-5) <= 2
+     else:
+         # After sampling, data in the model should be the same as it was before
+         np.testing.assert_array_equal(idata_now, idata_prev.values)
+         # Data in the model should NOT be the same as the predictions
+         if idata_now.shape == pred_unstacked.shape:
+             assert np.sum(np.abs(idata_now - pred_unstacked) < 1e-5) <= 2
+
+
+ def test_model_config_formatting():
+     model_config = {
+         "a": {
+             "loc": [0, 0],
+             "scale": 10,
+             "dims": [
+                 "x",
+             ],
+         },
+     }
+     model_builder = test_ModelBuilder()
+     converted_model_config = model_builder._model_config_formatting(model_config)
+     np.testing.assert_equal(converted_model_config["a"]["dims"], ("x",))
+     np.testing.assert_equal(converted_model_config["a"]["loc"], np.array([0, 0]))
+
+
+ def test_id():
+     model_builder = test_ModelBuilder()
+     expected_id = hashlib.sha256(
+         str(model_builder.model_config.values()).encode()
+         + model_builder.version.encode()
+         + model_builder._model_type.encode()
+     ).hexdigest()[:16]
+
+     assert model_builder.id == expected_id
tests/test_pathfinder.py ADDED
@@ -0,0 +1,45 @@
+ # Copyright 2022 The PyMC Developers
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #   http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import sys
+
+ import numpy as np
+ import pymc as pm
+ import pytest
+
+ import pymc_extras as pmx
+
+
+ @pytest.mark.skipif(sys.platform == "win32", reason="JAX not supported on windows.")
+ def test_pathfinder():
+     # Data of the Eight Schools Model
+     J = 8
+     y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
+     sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])
+
+     with pm.Model() as model:
+         mu = pm.Normal("mu", mu=0.0, sigma=10.0)
+         tau = pm.HalfCauchy("tau", 5.0)
+
+         theta = pm.Normal("theta", mu=0, sigma=1, shape=J)
+         obs = pm.Normal("obs", mu=mu + tau * theta, sigma=sigma, shape=J, observed=y)
+
+         idata = pmx.fit(method="pathfinder", random_seed=41)
+
+     assert idata.posterior["mu"].shape == (1, 1000)
+     assert idata.posterior["tau"].shape == (1, 1000)
+     assert idata.posterior["theta"].shape == (1, 1000, 8)
+     # FIXME: pathfinder doesn't find a reasonable mean! Fix bug or choose a model pathfinder can handle
+     # np.testing.assert_allclose(idata.posterior["mu"].mean(), 5.0)
+     np.testing.assert_allclose(idata.posterior["tau"].mean(), 4.15, atol=0.5)
tests/test_pivoted_cholesky.py ADDED
@@ -0,0 +1,24 @@
+ # try:
+ #     import gpytorch
+ #     import torch
+ # except ImportError as e:
+ #     # print(
+ #     #     f"Please install Pytorch and GPyTorch to use this pivoted Cholesky implementation. Error {e}"
+ #     # )
+ #     pass
+ # import numpy as np
+ #
+ # import pymc_experimental as pmx
+ #
+ #
+ # def test_match_gpytorch_linearcg_output():
+ #     N = 10
+ #     rank = 5
+ #     np.random.seed(1234)  # nans with seed 1234
+ #     K = np.random.randn(N, N)
+ #     K = K @ K.T + N * np.eye(N)
+ #     K_torch = torch.from_numpy(K)
+ #
+ #     L_gpt = gpytorch.pivoted_cholesky(K_torch, rank=rank, error_tol=1e-3)
+ #     L_np, _ = pmx.utils.pivoted_cholesky(K, max_iter=rank, error_tol=1e-3)
+ #     assert np.allclose(L_gpt, L_np.T)
tests/test_printing.py ADDED
@@ -0,0 +1,98 @@
+ import numpy as np
+ import pymc as pm
+
+ from rich.console import Console
+
+ from pymc_extras.printing import model_table
+
+
+ def get_text(table) -> str:
+     console = Console(width=80)
+     with console.capture() as capture:
+         console.print(table)
+     return capture.get()
+
+
+ def test_model_table():
+     with pm.Model(coords={"trial": range(6), "subject": range(20)}) as model:
+         x_data = pm.Data("x_data", np.random.normal(size=(6, 20)), dims=("trial", "subject"))
+         y_data = pm.Data("y_data", np.random.normal(size=(6, 20)), dims=("trial", "subject"))
+
+         mu = pm.Normal("mu", mu=0, sigma=1)
+         sigma = pm.HalfNormal("sigma", sigma=1)
+         global_intercept = pm.Normal("global_intercept", mu=0, sigma=1)
+         intercept_subject = pm.Normal("intercept_subject", mu=0, sigma=1, shape=(20, 1))
+         beta_subject = pm.Normal("beta_subject", mu=mu, sigma=sigma, dims="subject")
+
+         mu_trial = pm.Deterministic(
+             "mu_trial",
+             global_intercept.squeeze() + intercept_subject + beta_subject * x_data,
+             dims=["trial", "subject"],
+         )
+         noise = pm.Exponential("noise", lam=1)
+         y = pm.Normal("y", mu=mu_trial, sigma=noise, observed=y_data, dims=("trial", "subject"))
+
+         pm.Potential("beta_subject_penalty", -pm.math.abs(beta_subject), dims="subject")
+
+     table_txt = get_text(model_table(model))
+     expected = """ Variable Expression Dimensions
+ ────────────────────────────────────────────────────────────────────────────────
+ x_data = Data trial[6] × subject[20]
+ y_data = Data trial[6] × subject[20]
+
+ mu ~ Normal(0, 1)
+ sigma ~ HalfNormal(0, 1)
+ global_intercept ~ Normal(0, 1)
+ intercept_subject ~ Normal(0, 1) [20, 1]
+ beta_subject ~ Normal(mu, sigma) subject[20]
+ noise ~ Exponential(f())
+ Parameter count = 44
+
+ mu_trial = f(intercept_subject, trial[6] × subject[20]
+ beta_subject,
+ global_intercept)
+
+ beta_subject_penalty = Potential(f(beta_subject)) subject[20]
+
+ y ~ Normal(mu_trial, noise) trial[6] × subject[20]
+ """
+     assert [s.strip() for s in table_txt.splitlines()] == [s.strip() for s in expected.splitlines()]
+
+     table_txt = get_text(model_table(model, split_groups=False))
+     expected = """ Variable Expression Dimensions
+ ────────────────────────────────────────────────────────────────────────────────
+ x_data = Data trial[6] × subject[20]
+ y_data = Data trial[6] × subject[20]
+ mu ~ Normal(0, 1)
+ sigma ~ HalfNormal(0, 1)
+ global_intercept ~ Normal(0, 1)
+ intercept_subject ~ Normal(0, 1) [20, 1]
+ beta_subject ~ Normal(mu, sigma) subject[20]
+ mu_trial = f(intercept_subject, trial[6] × subject[20]
+ beta_subject,
+ global_intercept)
+ noise ~ Exponential(f())
+ y ~ Normal(mu_trial, noise) trial[6] × subject[20]
+ beta_subject_penalty = Potential(f(beta_subject)) subject[20]
+ Parameter count = 44
+ """
+     assert [s.strip() for s in table_txt.splitlines()] == [s.strip() for s in expected.splitlines()]
+
+     table_txt = get_text(
+         model_table(model, split_groups=False, truncate_deterministic=30, parameter_count=False)
+     )
+     expected = """ Variable Expression Dimensions
+ ────────────────────────────────────────────────────────────────────────────
+ x_data = Data trial[6] × subject[20]
+ y_data = Data trial[6] × subject[20]
+ mu ~ Normal(0, 1)
+ sigma ~ HalfNormal(0, 1)
+ global_intercept ~ Normal(0, 1)
+ intercept_subject ~ Normal(0, 1) [20, 1]
+ beta_subject ~ Normal(mu, sigma) subject[20]
+ mu_trial = f(intercept_subject, ...) trial[6] × subject[20]
+ noise ~ Exponential(f())
+ y ~ Normal(mu_trial, noise) trial[6] × subject[20]
+ beta_subject_penalty = Potential(f(beta_subject)) subject[20]
+ """
+     assert [s.strip() for s in table_txt.splitlines()] == [s.strip() for s in expected.splitlines()]
tests/test_prior_from_trace.py ADDED
@@ -0,0 +1,172 @@
+ # Copyright 2022 The PyMC Developers
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #   http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ import arviz as az
+ import numpy as np
+ import pymc as pm
+ import pytest
+
+ from pymc.distributions import transforms
+
+ import pymc_extras as pmx
+
+
+ @pytest.mark.parametrize(
+     "case",
+     [
+         (("a", dict(name="b")), dict(name="b", transform=None, dims=None)),
+         (("a", None), dict(name="a", transform=None, dims=None)),
+         (("a", transforms.log), dict(name="a", transform=transforms.log, dims=None)),
+         (
+             ("a", dict(transform=transforms.log)),
+             dict(name="a", transform=transforms.log, dims=None),
+         ),
+         (("a", dict(name="b")), dict(name="b", transform=None, dims=None)),
+         (("a", dict(name="b", dims="test")), dict(name="b", transform=None, dims="test")),
+         (("a", ("test",)), dict(name="a", transform=None, dims=("test",))),
+     ],
+ )
+ def test_parsing_arguments(case):
+     inp, out = case
+     test = pmx.utils.prior._arg_to_param_cfg(*inp)
+     assert test == out
+
+
+ @pytest.fixture
+ def coords():
+     return dict(test=range(3), simplex=range(4))
+
+
+ @pytest.fixture(
+     params=[
+         [
+             ("t",),
+             dict(
+                 a="d",
+                 b=dict(transform=transforms.log, dims=("test",)),
+                 c=dict(transform=transforms.simplex, dims=("simplex",)),
+             ),
+         ],
+         [("t",), dict()],
+     ]
+ )
+ def user_param_cfg(request):
+     return request.param
+
+
+ @pytest.fixture
+ def param_cfg(user_param_cfg):
+     return pmx.utils.prior._parse_args(user_param_cfg[0], **user_param_cfg[1])
+
+
+ @pytest.fixture
+ def transformed_data(param_cfg, coords):
+     vars = dict()
+     for k, cfg in param_cfg.items():
+         if cfg["dims"] is not None:
+             extra_dims = [len(coords[d]) for d in cfg["dims"]]
+             if cfg["transform"] is not None:
+                 t = np.random.randn(*extra_dims)
+                 extra_dims = tuple(cfg["transform"].forward(t).shape.eval())
+         else:
+             extra_dims = []
+         orig = np.random.randn(4, 100, *extra_dims)
+         vars[k] = orig
+     return vars
+
+
+ @pytest.fixture
+ def idata(transformed_data, param_cfg):
+     vars = dict()
+     for k, orig in transformed_data.items():
+         cfg = param_cfg[k]
+         if cfg["transform"] is not None:
+             var = cfg["transform"].backward(orig).eval()
+         else:
+             var = orig
+         assert not np.isnan(var).any()
+         vars[k] = var
+     return az.convert_to_inference_data(vars)
+
+
+ def test_idata_for_tests(idata, param_cfg):
+     assert set(idata.posterior.keys()) == set(param_cfg)
+     assert len(idata.posterior.coords["chain"]) == 4
+     assert len(idata.posterior.coords["draw"]) == 100
+
+
+ def test_args_compose():
+     cfg = pmx.utils.prior._parse_args(
+         var_names=["a"],
+         b=("test",),
+         c=transforms.log,
+         d="e",
+         f=dict(dims="test"),
+         g=dict(name="h", dims="test", transform=transforms.log),
+     )
+     assert cfg == dict(
+         a=dict(name="a", dims=None, transform=None),
+         b=dict(name="b", dims=("test",), transform=None),
+         c=dict(name="c", dims=None, transform=transforms.log),
+         d=dict(name="e", dims=None, transform=None),
+         f=dict(name="f", dims="test", transform=None),
+         g=dict(name="h", dims="test", transform=transforms.log),
+     )
+
+
+ def test_transform_idata(transformed_data, idata, param_cfg):
+     flat_info = pmx.utils.prior._flatten(idata, **param_cfg)
+     expected_shape = 0
+     for v in transformed_data.values():
+         expected_shape += int(np.prod(v.shape[2:]))
+     assert flat_info["data"].shape[1] == expected_shape
+     assert len(flat_info["info"]) == len(param_cfg)
+     assert "sinfo" in flat_info["info"][0]
+     assert "vinfo" in flat_info["info"][0]
+
+
+ @pytest.fixture
+ def flat_info(idata, param_cfg):
+     return pmx.utils.prior._flatten(idata, **param_cfg)
+
+
+ def test_mean_chol(flat_info):
+     mean, chol = pmx.utils.prior._mean_chol(flat_info["data"])
+     assert mean.shape == (flat_info["data"].shape[1],)
+     assert chol.shape == (flat_info["data"].shape[1],) * 2
+
+
+ def test_mvn_prior_from_flat_info(flat_info, coords, param_cfg):
+     with pm.Model(coords=coords) as model:
+         priors = pmx.utils.prior._mvn_prior_from_flat_info("trace_prior_", flat_info)
+         test_prior = pm.sample_prior_predictive(1)
+     names = [p["name"] for p in param_cfg.values()]
+     assert set(model.named_vars) == {"trace_prior_", *names}
+
+
+ def test_prior_from_idata(idata, user_param_cfg, coords, param_cfg):
+     with pm.Model(coords=coords) as model:
+         priors = pmx.utils.prior.prior_from_idata(
+             idata, var_names=user_param_cfg[0], **user_param_cfg[1]
+         )
+         test_prior = pm.sample_prior_predictive(1)
+     names = [p["name"] for p in param_cfg.values()]
+     assert set(model.named_vars) == {"trace_prior_", *names}
+
+
+ def test_empty(idata, coords):
+     with pm.Model(coords=coords):
+         priors = pmx.utils.prior.prior_from_idata(idata)
+     assert not priors