pymc-extras 0.2.4__py3-none-any.whl → 0.2.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. pymc_extras/__init__.py +6 -4
  2. pymc_extras/distributions/__init__.py +2 -0
  3. pymc_extras/distributions/continuous.py +3 -2
  4. pymc_extras/distributions/discrete.py +3 -1
  5. pymc_extras/distributions/transforms/__init__.py +3 -0
  6. pymc_extras/distributions/transforms/partial_order.py +227 -0
  7. pymc_extras/inference/__init__.py +4 -2
  8. pymc_extras/inference/find_map.py +62 -17
  9. pymc_extras/inference/fit.py +6 -4
  10. pymc_extras/inference/laplace.py +14 -8
  11. pymc_extras/inference/pathfinder/lbfgs.py +49 -13
  12. pymc_extras/inference/pathfinder/pathfinder.py +89 -103
  13. pymc_extras/statespace/core/statespace.py +191 -52
  14. pymc_extras/statespace/filters/distributions.py +15 -16
  15. pymc_extras/statespace/filters/kalman_filter.py +1 -18
  16. pymc_extras/statespace/filters/kalman_smoother.py +2 -6
  17. pymc_extras/statespace/models/ETS.py +10 -0
  18. pymc_extras/statespace/models/SARIMAX.py +26 -5
  19. pymc_extras/statespace/models/VARMAX.py +12 -2
  20. pymc_extras/statespace/models/structural.py +18 -5
  21. pymc_extras/statespace/utils/data_tools.py +24 -9
  22. pymc_extras-0.2.6.dist-info/METADATA +318 -0
  23. pymc_extras-0.2.6.dist-info/RECORD +65 -0
  24. {pymc_extras-0.2.4.dist-info → pymc_extras-0.2.6.dist-info}/WHEEL +1 -2
  25. pymc_extras/version.py +0 -11
  26. pymc_extras/version.txt +0 -1
  27. pymc_extras-0.2.4.dist-info/METADATA +0 -110
  28. pymc_extras-0.2.4.dist-info/RECORD +0 -105
  29. pymc_extras-0.2.4.dist-info/top_level.txt +0 -2
  30. tests/__init__.py +0 -13
  31. tests/distributions/__init__.py +0 -19
  32. tests/distributions/test_continuous.py +0 -185
  33. tests/distributions/test_discrete.py +0 -210
  34. tests/distributions/test_discrete_markov_chain.py +0 -258
  35. tests/distributions/test_multivariate.py +0 -304
  36. tests/model/__init__.py +0 -0
  37. tests/model/marginal/__init__.py +0 -0
  38. tests/model/marginal/test_distributions.py +0 -132
  39. tests/model/marginal/test_graph_analysis.py +0 -182
  40. tests/model/marginal/test_marginal_model.py +0 -967
  41. tests/model/test_model_api.py +0 -38
  42. tests/statespace/__init__.py +0 -0
  43. tests/statespace/test_ETS.py +0 -411
  44. tests/statespace/test_SARIMAX.py +0 -405
  45. tests/statespace/test_VARMAX.py +0 -184
  46. tests/statespace/test_coord_assignment.py +0 -116
  47. tests/statespace/test_distributions.py +0 -270
  48. tests/statespace/test_kalman_filter.py +0 -326
  49. tests/statespace/test_representation.py +0 -175
  50. tests/statespace/test_statespace.py +0 -872
  51. tests/statespace/test_statespace_JAX.py +0 -156
  52. tests/statespace/test_structural.py +0 -836
  53. tests/statespace/utilities/__init__.py +0 -0
  54. tests/statespace/utilities/shared_fixtures.py +0 -9
  55. tests/statespace/utilities/statsmodel_local_level.py +0 -42
  56. tests/statespace/utilities/test_helpers.py +0 -310
  57. tests/test_blackjax_smc.py +0 -222
  58. tests/test_find_map.py +0 -103
  59. tests/test_histogram_approximation.py +0 -109
  60. tests/test_laplace.py +0 -265
  61. tests/test_linearmodel.py +0 -208
  62. tests/test_model_builder.py +0 -306
  63. tests/test_pathfinder.py +0 -203
  64. tests/test_pivoted_cholesky.py +0 -24
  65. tests/test_printing.py +0 -98
  66. tests/test_prior_from_trace.py +0 -172
  67. tests/test_splines.py +0 -77
  68. tests/utils.py +0 -0
  69. {pymc_extras-0.2.4.dist-info → pymc_extras-0.2.6.dist-info/licenses}/LICENSE +0 -0
tests/test_model_builder.py DELETED
@@ -1,306 +0,0 @@
1
- # Copyright 2023 The PyMC Developers
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import hashlib
16
- import json
17
- import sys
18
- import tempfile
19
-
20
- import numpy as np
21
- import pandas as pd
22
- import pymc as pm
23
- import pytest
24
-
25
- from pymc_extras.model_builder import ModelBuilder
26
-
27
-
28
- @pytest.fixture(scope="module")
29
- def toy_X():
30
- x = np.linspace(start=0, stop=1, num=100)
31
- X = pd.DataFrame({"input": x})
32
- return X
33
-
34
-
35
- @pytest.fixture(scope="module")
36
- def toy_y(toy_X):
37
- y = 5 * toy_X["input"] + 3
38
- y = y + np.random.normal(0, 1, size=len(toy_X))
39
- y = pd.Series(y, name="output")
40
- return y
41
-
42
-
43
- def get_unfitted_model_instance(X, y):
44
- """Creates an unfitted model instance to which idata can be copied in
45
- and then used as a fitted model instance. That way a fitted model
46
- can be used multiple times without having to run `fit` multiple times."""
47
- sampler_config = {
48
- "draws": 20,
49
- "tune": 10,
50
- "chains": 2,
51
- "target_accept": 0.95,
52
- }
53
- model_config = {
54
- "a": {"loc": 0, "scale": 10, "dims": ("numbers",)},
55
- "b": {"loc": 0, "scale": 10},
56
- "obs_error": 2,
57
- }
58
- model = test_ModelBuilder(
59
- model_config=model_config, sampler_config=sampler_config, test_parameter="test_paramter"
60
- )
61
- # Do the things that `model.fit` does except sample to create idata.
62
- model._generate_and_preprocess_model_data(X, y.values.flatten())
63
- model.build_model(X, y)
64
- return model
65
-
66
-
67
- @pytest.fixture(scope="module")
68
- def fitted_model_instance_base(toy_X, toy_y):
69
- """Because fitting takes a relatively long time, this is intended to
70
- be used only once and then have new instances created and fit data patched in
71
- for tests that use a fitted model instance. Tests should use
72
- `fitted_model_instance` instead of this."""
73
- model = get_unfitted_model_instance(toy_X, toy_y)
74
- model.fit(toy_X, toy_y)
75
- return model
76
-
77
-
78
- @pytest.fixture
79
- def fitted_model_instance(toy_X, toy_y, fitted_model_instance_base):
80
- """Get a fitted model instance. A new instance is created and fit data is
81
- patched in, so tests using this fixture can modify the model object without
82
- affecting other tests."""
83
- model = get_unfitted_model_instance(toy_X, toy_y)
84
- model.idata = fitted_model_instance_base.idata.copy()
85
- return model
86
-
87
-
88
- class test_ModelBuilder(ModelBuilder):
89
- def __init__(self, model_config=None, sampler_config=None, test_parameter=None):
90
- self.test_parameter = test_parameter
91
- super().__init__(model_config=model_config, sampler_config=sampler_config)
92
-
93
- _model_type = "test_model"
94
- version = "0.1"
95
-
96
- def build_model(self, X: pd.DataFrame, y: pd.Series, model_config=None):
97
- coords = {"numbers": np.arange(len(X))}
98
- self.generate_and_preprocess_model_data(X, y)
99
- with pm.Model(coords=coords) as self.model:
100
- if model_config is None:
101
- model_config = self.model_config
102
- x = pm.Data("x", self.X["input"].values)
103
- y_data = pm.Data("y_data", self.y)
104
-
105
- # prior parameters
106
- a_loc = model_config["a"]["loc"]
107
- a_scale = model_config["a"]["scale"]
108
- b_loc = model_config["b"]["loc"]
109
- b_scale = model_config["b"]["scale"]
110
- obs_error = model_config["obs_error"]
111
-
112
- # priors
113
- a = pm.Normal("a", a_loc, sigma=a_scale, dims=model_config["a"]["dims"])
114
- b = pm.Normal("b", b_loc, sigma=b_scale)
115
- obs_error = pm.HalfNormal("σ_model_fmc", obs_error)
116
-
117
- # observed data
118
- output = pm.Normal("output", a + b * x, obs_error, shape=x.shape, observed=y_data)
119
-
120
- def _save_input_params(self, idata):
121
- idata.attrs["test_paramter"] = json.dumps(self.test_parameter)
122
-
123
- @property
124
- def output_var(self):
125
- return "output"
126
-
127
- def _data_setter(self, x: pd.Series, y: pd.Series = None):
128
- with self.model:
129
- pm.set_data({"x": x.values})
130
- if y is not None:
131
- pm.set_data({"y_data": y.values})
132
-
133
- @property
134
- def _serializable_model_config(self):
135
- return self.model_config
136
-
137
- def generate_and_preprocess_model_data(self, X: pd.DataFrame, y: pd.Series):
138
- self.X = X
139
- self.y = y
140
-
141
- @staticmethod
142
- def get_default_model_config() -> dict:
143
- return {
144
- "a": {"loc": 0, "scale": 10, "dims": ("numbers",)},
145
- "b": {"loc": 0, "scale": 10},
146
- "obs_error": 2,
147
- }
148
-
149
- def _generate_and_preprocess_model_data(
150
- self, X: pd.DataFrame | pd.Series, y: pd.Series
151
- ) -> None:
152
- self.X = X
153
- self.y = y
154
-
155
- @staticmethod
156
- def get_default_sampler_config() -> dict:
157
- return {
158
- "draws": 10,
159
- "tune": 10,
160
- "chains": 3,
161
- "target_accept": 0.95,
162
- }
163
-
164
-
165
- def test_save_input_params(fitted_model_instance):
166
- assert fitted_model_instance.idata.attrs["test_paramter"] == '"test_paramter"'
167
-
168
-
169
- @pytest.mark.skipif(
170
- sys.platform == "win32", reason="Permissions for temp files not granted on windows CI."
171
- )
172
- def test_save_load(fitted_model_instance):
173
- temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False)
174
- fitted_model_instance.save(temp.name)
175
- test_builder2 = test_ModelBuilder.load(temp.name)
176
- assert fitted_model_instance.idata.groups() == test_builder2.idata.groups()
177
- assert fitted_model_instance.id == test_builder2.id
178
- x_pred = np.random.uniform(low=0, high=1, size=100)
179
- prediction_data = pd.DataFrame({"input": x_pred})
180
- pred1 = fitted_model_instance.predict(prediction_data["input"])
181
- pred2 = test_builder2.predict(prediction_data["input"])
182
- assert pred1.shape == pred2.shape
183
- temp.close()
184
-
185
-
186
- def test_initial_build_and_fit(fitted_model_instance, check_idata=True) -> ModelBuilder:
187
- if check_idata:
188
- assert fitted_model_instance.idata is not None
189
- assert "posterior" in fitted_model_instance.idata.groups()
190
-
191
-
192
- def test_save_without_fit_raises_runtime_error():
193
- model_builder = test_ModelBuilder()
194
- with pytest.raises(RuntimeError):
195
- model_builder.save("saved_model")
196
-
197
-
198
- def test_empty_sampler_config_fit(toy_X, toy_y):
199
- sampler_config = {}
200
- model_builder = test_ModelBuilder(sampler_config=sampler_config)
201
- model_builder.idata = model_builder.fit(X=toy_X, y=toy_y)
202
- assert model_builder.idata is not None
203
- assert "posterior" in model_builder.idata.groups()
204
-
205
-
206
- def test_fit(fitted_model_instance):
207
- prediction_data = pd.DataFrame({"input": np.random.uniform(low=0, high=1, size=100)})
208
- pred = fitted_model_instance.predict(prediction_data["input"])
209
- post_pred = fitted_model_instance.sample_posterior_predictive(
210
- prediction_data["input"], extend_idata=True, combined=True
211
- )
212
- post_pred[fitted_model_instance.output_var].shape[0] == prediction_data.input.shape
213
-
214
-
215
- def test_fit_no_y(toy_X):
216
- model_builder = test_ModelBuilder()
217
- model_builder.idata = model_builder.fit(X=toy_X, chains=1, tune=1, draws=1)
218
- assert model_builder.model is not None
219
- assert model_builder.idata is not None
220
- assert "posterior" in model_builder.idata.groups()
221
-
222
-
223
- def test_predict(fitted_model_instance):
224
- x_pred = np.random.uniform(low=0, high=1, size=100)
225
- prediction_data = pd.DataFrame({"input": x_pred})
226
- pred = fitted_model_instance.predict(prediction_data["input"])
227
- # Perform elementwise comparison using numpy
228
- assert isinstance(pred, np.ndarray)
229
- assert len(pred) > 0
230
-
231
-
232
- @pytest.mark.parametrize("combined", [True, False])
233
- def test_sample_posterior_predictive(fitted_model_instance, combined):
234
- n_pred = 100
235
- x_pred = np.random.uniform(low=0, high=1, size=n_pred)
236
- prediction_data = pd.DataFrame({"input": x_pred})
237
- pred = fitted_model_instance.sample_posterior_predictive(
238
- prediction_data["input"], combined=combined, extend_idata=True
239
- )
240
- chains = fitted_model_instance.idata.sample_stats.sizes["chain"]
241
- draws = fitted_model_instance.idata.sample_stats.sizes["draw"]
242
- expected_shape = (n_pred, chains * draws) if combined else (chains, draws, n_pred)
243
- assert pred[fitted_model_instance.output_var].shape == expected_shape
244
- assert np.issubdtype(pred[fitted_model_instance.output_var].dtype, np.floating)
245
-
246
-
247
- @pytest.mark.parametrize("group", ["prior_predictive", "posterior_predictive"])
248
- @pytest.mark.parametrize("extend_idata", [True, False])
249
- def test_sample_xxx_extend_idata_param(fitted_model_instance, group, extend_idata):
250
- output_var = fitted_model_instance.output_var
251
- idata_prev = fitted_model_instance.idata[group][output_var]
252
-
253
- # Since coordinates are provided, the dimension must match
254
- n_pred = 100 # Must match toy_x
255
- x_pred = np.random.uniform(0, 1, n_pred)
256
-
257
- prediction_data = pd.DataFrame({"input": x_pred})
258
- if group == "prior_predictive":
259
- prediction_method = fitted_model_instance.sample_prior_predictive
260
- else: # group == "posterior_predictive":
261
- prediction_method = fitted_model_instance.sample_posterior_predictive
262
-
263
- pred = prediction_method(prediction_data["input"], combined=False, extend_idata=extend_idata)
264
-
265
- pred_unstacked = pred[output_var].values
266
- idata_now = fitted_model_instance.idata[group][output_var].values
267
-
268
- if extend_idata:
269
- # After sampling, data in the model should be the same as the predictions
270
- np.testing.assert_array_equal(idata_now, pred_unstacked)
271
- # Data in the model should NOT be the same as before
272
- if idata_now.shape == idata_prev.values.shape:
273
- assert np.sum(np.abs(idata_now - idata_prev.values) < 1e-5) <= 2
274
- else:
275
- # After sampling, data in the model should be the same as it was before
276
- np.testing.assert_array_equal(idata_now, idata_prev.values)
277
- # Data in the model should NOT be the same as the predictions
278
- if idata_now.shape == pred_unstacked.shape:
279
- assert np.sum(np.abs(idata_now - pred_unstacked) < 1e-5) <= 2
280
-
281
-
282
- def test_model_config_formatting():
283
- model_config = {
284
- "a": {
285
- "loc": [0, 0],
286
- "scale": 10,
287
- "dims": [
288
- "x",
289
- ],
290
- },
291
- }
292
- model_builder = test_ModelBuilder()
293
- converted_model_config = model_builder._model_config_formatting(model_config)
294
- np.testing.assert_equal(converted_model_config["a"]["dims"], ("x",))
295
- np.testing.assert_equal(converted_model_config["a"]["loc"], np.array([0, 0]))
296
-
297
-
298
- def test_id():
299
- model_builder = test_ModelBuilder()
300
- expected_id = hashlib.sha256(
301
- str(model_builder.model_config.values()).encode()
302
- + model_builder.version.encode()
303
- + model_builder._model_type.encode()
304
- ).hexdigest()[:16]
305
-
306
- assert model_builder.id == expected_id
tests/test_pathfinder.py DELETED
@@ -1,203 +0,0 @@
1
- # Copyright 2022 The PyMC Developers
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import sys
16
-
17
- import numpy as np
18
- import pymc as pm
19
- import pytest
20
-
21
- pytestmark = pytest.mark.filterwarnings("ignore:compile_pymc was renamed to compile:FutureWarning")
22
-
23
- import pymc_extras as pmx
24
-
25
-
26
- def eight_schools_model() -> pm.Model:
27
- J = 8
28
- y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
29
- sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])
30
-
31
- with pm.Model() as model:
32
- mu = pm.Normal("mu", mu=0.0, sigma=10.0)
33
- tau = pm.HalfCauchy("tau", 5.0)
34
-
35
- theta = pm.Normal("theta", mu=0, sigma=1, shape=J)
36
- obs = pm.Normal("obs", mu=mu + tau * theta, sigma=sigma, shape=J, observed=y)
37
-
38
- return model
39
-
40
-
41
- @pytest.fixture
42
- def reference_idata():
43
- model = eight_schools_model()
44
- with model:
45
- idata = pmx.fit(
46
- method="pathfinder",
47
- num_paths=10,
48
- jitter=12.0,
49
- random_seed=41,
50
- inference_backend="pymc",
51
- )
52
- return idata
53
-
54
-
55
- @pytest.mark.parametrize("inference_backend", ["pymc", "blackjax"])
56
- def test_pathfinder(inference_backend, reference_idata):
57
- if inference_backend == "blackjax" and sys.platform == "win32":
58
- pytest.skip("JAX not supported on windows")
59
-
60
- if inference_backend == "blackjax":
61
- model = eight_schools_model()
62
- with model:
63
- idata = pmx.fit(
64
- method="pathfinder",
65
- num_paths=10,
66
- jitter=12.0,
67
- random_seed=41,
68
- inference_backend=inference_backend,
69
- )
70
- else:
71
- idata = reference_idata
72
- np.testing.assert_allclose(idata.posterior["mu"].mean(), 5.0, atol=0.95)
73
- np.testing.assert_allclose(idata.posterior["tau"].mean(), 4.15, atol=1.35)
74
-
75
- assert idata.posterior["mu"].shape == (1, 1000)
76
- assert idata.posterior["tau"].shape == (1, 1000)
77
- assert idata.posterior["theta"].shape == (1, 1000, 8)
78
-
79
-
80
- @pytest.mark.parametrize("concurrent", ["thread", "process"])
81
- def test_concurrent_results(reference_idata, concurrent):
82
- model = eight_schools_model()
83
- with model:
84
- idata_conc = pmx.fit(
85
- method="pathfinder",
86
- num_paths=10,
87
- jitter=12.0,
88
- random_seed=41,
89
- inference_backend="pymc",
90
- concurrent=concurrent,
91
- )
92
-
93
- np.testing.assert_allclose(
94
- reference_idata.posterior.mu.data.mean(),
95
- idata_conc.posterior.mu.data.mean(),
96
- atol=0.4,
97
- )
98
-
99
- np.testing.assert_allclose(
100
- reference_idata.posterior.tau.data.mean(),
101
- idata_conc.posterior.tau.data.mean(),
102
- atol=0.4,
103
- )
104
-
105
-
106
- def test_seed(reference_idata):
107
- model = eight_schools_model()
108
- with model:
109
- idata_41 = pmx.fit(
110
- method="pathfinder",
111
- num_paths=4,
112
- jitter=10.0,
113
- random_seed=41,
114
- inference_backend="pymc",
115
- )
116
-
117
- idata_123 = pmx.fit(
118
- method="pathfinder",
119
- num_paths=4,
120
- jitter=10.0,
121
- random_seed=123,
122
- inference_backend="pymc",
123
- )
124
-
125
- assert not np.allclose(idata_41.posterior.mu.data.mean(), idata_123.posterior.mu.data.mean())
126
-
127
- assert np.allclose(idata_41.posterior.mu.data.mean(), idata_41.posterior.mu.data.mean())
128
-
129
-
130
- def test_bfgs_sample():
131
- import pytensor.tensor as pt
132
-
133
- from pymc_extras.inference.pathfinder.pathfinder import (
134
- alpha_recover,
135
- bfgs_sample,
136
- inverse_hessian_factors,
137
- )
138
-
139
- """test BFGS sampling"""
140
- Lp1, N = 8, 10
141
- L = Lp1 - 1
142
- J = 6
143
- num_samples = 1000
144
-
145
- # mock data
146
- x_data = np.random.randn(Lp1, N)
147
- g_data = np.random.randn(Lp1, N)
148
-
149
- # get factors
150
- x_full = pt.as_tensor(x_data, dtype="float64")
151
- g_full = pt.as_tensor(g_data, dtype="float64")
152
- epsilon = 1e-11
153
-
154
- x = x_full[1:]
155
- g = g_full[1:]
156
- alpha, S, Z, update_mask = alpha_recover(x_full, g_full, epsilon)
157
- beta, gamma = inverse_hessian_factors(alpha, S, Z, update_mask, J)
158
-
159
- # sample
160
- phi, logq = bfgs_sample(
161
- num_samples=num_samples,
162
- x=x,
163
- g=g,
164
- alpha=alpha,
165
- beta=beta,
166
- gamma=gamma,
167
- )
168
-
169
- # check shapes
170
- assert beta.eval().shape == (L, N, 2 * J)
171
- assert gamma.eval().shape == (L, 2 * J, 2 * J)
172
- assert phi.eval().shape == (L, num_samples, N)
173
- assert logq.eval().shape == (L, num_samples)
174
-
175
-
176
- @pytest.mark.parametrize("importance_sampling", ["psis", "psir", "identity", None])
177
- def test_pathfinder_importance_sampling(importance_sampling):
178
- model = eight_schools_model()
179
-
180
- num_paths = 4
181
- num_draws_per_path = 300
182
- num_draws = 750
183
-
184
- with model:
185
- idata = pmx.fit(
186
- method="pathfinder",
187
- num_paths=num_paths,
188
- num_draws_per_path=num_draws_per_path,
189
- num_draws=num_draws,
190
- maxiter=5,
191
- random_seed=41,
192
- inference_backend="pymc",
193
- importance_sampling=importance_sampling,
194
- )
195
-
196
- if importance_sampling is None:
197
- assert idata.posterior["mu"].shape == (num_paths, num_draws_per_path)
198
- assert idata.posterior["tau"].shape == (num_paths, num_draws_per_path)
199
- assert idata.posterior["theta"].shape == (num_paths, num_draws_per_path, 8)
200
- else:
201
- assert idata.posterior["mu"].shape == (1, num_draws)
202
- assert idata.posterior["tau"].shape == (1, num_draws)
203
- assert idata.posterior["theta"].shape == (1, num_draws, 8)
tests/test_pivoted_cholesky.py DELETED
@@ -1,24 +0,0 @@
1
- # try:
2
- # import gpytorch
3
- # import torch
4
- # except ImportError as e:
5
- # # print(
6
- # # f"Please install Pytorch and GPyTorch to use this pivoted Cholesky implementation. Error {e}"
7
- # # )
8
- # pass
9
- # import numpy as np
10
- #
11
- # import pymc_extras as pmx
12
- #
13
- #
14
- # def test_match_gpytorch_linearcg_output():
15
- # N = 10
16
- # rank = 5
17
- # np.random.seed(1234) # nans with seed 1234
18
- # K = np.random.randn(N, N)
19
- # K = K @ K.T + N * np.eye(N)
20
- # K_torch = torch.from_numpy(K)
21
- #
22
- # L_gpt = gpytorch.pivoted_cholesky(K_torch, rank=rank, error_tol=1e-3)
23
- # L_np, _ = pmx.utils.pivoted_cholesky(K, max_iter=rank, error_tol=1e-3)
24
- # assert np.allclose(L_gpt, L_np.T)
tests/test_printing.py DELETED
@@ -1,98 +0,0 @@
1
- import numpy as np
2
- import pymc as pm
3
-
4
- from rich.console import Console
5
-
6
- from pymc_extras.printing import model_table
7
-
8
-
9
- def get_text(table) -> str:
10
- console = Console(width=80)
11
- with console.capture() as capture:
12
- console.print(table)
13
- return capture.get()
14
-
15
-
16
- def test_model_table():
17
- with pm.Model(coords={"trial": range(6), "subject": range(20)}) as model:
18
- x_data = pm.Data("x_data", np.random.normal(size=(6, 20)), dims=("trial", "subject"))
19
- y_data = pm.Data("y_data", np.random.normal(size=(6, 20)), dims=("trial", "subject"))
20
-
21
- mu = pm.Normal("mu", mu=0, sigma=1)
22
- sigma = pm.HalfNormal("sigma", sigma=1)
23
- global_intercept = pm.Normal("global_intercept", mu=0, sigma=1)
24
- intercept_subject = pm.Normal("intercept_subject", mu=0, sigma=1, shape=(20, 1))
25
- beta_subject = pm.Normal("beta_subject", mu=mu, sigma=sigma, dims="subject")
26
-
27
- mu_trial = pm.Deterministic(
28
- "mu_trial",
29
- global_intercept.squeeze() + intercept_subject + beta_subject * x_data,
30
- dims=["trial", "subject"],
31
- )
32
- noise = pm.Exponential("noise", lam=1)
33
- y = pm.Normal("y", mu=mu_trial, sigma=noise, observed=y_data, dims=("trial", "subject"))
34
-
35
- pm.Potential("beta_subject_penalty", -pm.math.abs(beta_subject), dims="subject")
36
-
37
- table_txt = get_text(model_table(model))
38
- expected = """ Variable Expression Dimensions
39
- ────────────────────────────────────────────────────────────────────────────────
40
- x_data = Data trial[6] × subject[20]
41
- y_data = Data trial[6] × subject[20]
42
-
43
- mu ~ Normal(0, 1)
44
- sigma ~ HalfNormal(0, 1)
45
- global_intercept ~ Normal(0, 1)
46
- intercept_subject ~ Normal(0, 1) [20, 1]
47
- beta_subject ~ Normal(mu, sigma) subject[20]
48
- noise ~ Exponential(f())
49
- Parameter count = 44
50
-
51
- mu_trial = f(intercept_subject, trial[6] × subject[20]
52
- beta_subject,
53
- global_intercept)
54
-
55
- beta_subject_penalty = Potential(f(beta_subject)) subject[20]
56
-
57
- y ~ Normal(mu_trial, noise) trial[6] × subject[20]
58
- """
59
- assert [s.strip() for s in table_txt.splitlines()] == [s.strip() for s in expected.splitlines()]
60
-
61
- table_txt = get_text(model_table(model, split_groups=False))
62
- expected = """ Variable Expression Dimensions
63
- ────────────────────────────────────────────────────────────────────────────────
64
- x_data = Data trial[6] × subject[20]
65
- y_data = Data trial[6] × subject[20]
66
- mu ~ Normal(0, 1)
67
- sigma ~ HalfNormal(0, 1)
68
- global_intercept ~ Normal(0, 1)
69
- intercept_subject ~ Normal(0, 1) [20, 1]
70
- beta_subject ~ Normal(mu, sigma) subject[20]
71
- mu_trial = f(intercept_subject, trial[6] × subject[20]
72
- beta_subject,
73
- global_intercept)
74
- noise ~ Exponential(f())
75
- y ~ Normal(mu_trial, noise) trial[6] × subject[20]
76
- beta_subject_penalty = Potential(f(beta_subject)) subject[20]
77
- Parameter count = 44
78
- """
79
- assert [s.strip() for s in table_txt.splitlines()] == [s.strip() for s in expected.splitlines()]
80
-
81
- table_txt = get_text(
82
- model_table(model, split_groups=False, truncate_deterministic=30, parameter_count=False)
83
- )
84
- expected = """ Variable Expression Dimensions
85
- ────────────────────────────────────────────────────────────────────────────
86
- x_data = Data trial[6] × subject[20]
87
- y_data = Data trial[6] × subject[20]
88
- mu ~ Normal(0, 1)
89
- sigma ~ HalfNormal(0, 1)
90
- global_intercept ~ Normal(0, 1)
91
- intercept_subject ~ Normal(0, 1) [20, 1]
92
- beta_subject ~ Normal(mu, sigma) subject[20]
93
- mu_trial = f(intercept_subject, ...) trial[6] × subject[20]
94
- noise ~ Exponential(f())
95
- y ~ Normal(mu_trial, noise) trial[6] × subject[20]
96
- beta_subject_penalty = Potential(f(beta_subject)) subject[20]
97
- """
98
- assert [s.strip() for s in table_txt.splitlines()] == [s.strip() for s in expected.splitlines()]