pymc-extras 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101) hide show
  1. pymc_extras/__init__.py +29 -0
  2. pymc_extras/distributions/__init__.py +40 -0
  3. pymc_extras/distributions/continuous.py +351 -0
  4. pymc_extras/distributions/discrete.py +399 -0
  5. pymc_extras/distributions/histogram_utils.py +163 -0
  6. pymc_extras/distributions/multivariate/__init__.py +3 -0
  7. pymc_extras/distributions/multivariate/r2d2m2cp.py +446 -0
  8. pymc_extras/distributions/timeseries.py +356 -0
  9. pymc_extras/gp/__init__.py +18 -0
  10. pymc_extras/gp/latent_approx.py +183 -0
  11. pymc_extras/inference/__init__.py +18 -0
  12. pymc_extras/inference/find_map.py +431 -0
  13. pymc_extras/inference/fit.py +44 -0
  14. pymc_extras/inference/laplace.py +570 -0
  15. pymc_extras/inference/pathfinder.py +134 -0
  16. pymc_extras/inference/smc/__init__.py +13 -0
  17. pymc_extras/inference/smc/sampling.py +451 -0
  18. pymc_extras/linearmodel.py +130 -0
  19. pymc_extras/model/__init__.py +0 -0
  20. pymc_extras/model/marginal/__init__.py +0 -0
  21. pymc_extras/model/marginal/distributions.py +276 -0
  22. pymc_extras/model/marginal/graph_analysis.py +372 -0
  23. pymc_extras/model/marginal/marginal_model.py +595 -0
  24. pymc_extras/model/model_api.py +56 -0
  25. pymc_extras/model/transforms/__init__.py +0 -0
  26. pymc_extras/model/transforms/autoreparam.py +434 -0
  27. pymc_extras/model_builder.py +759 -0
  28. pymc_extras/preprocessing/__init__.py +0 -0
  29. pymc_extras/preprocessing/standard_scaler.py +17 -0
  30. pymc_extras/printing.py +182 -0
  31. pymc_extras/statespace/__init__.py +13 -0
  32. pymc_extras/statespace/core/__init__.py +7 -0
  33. pymc_extras/statespace/core/compile.py +48 -0
  34. pymc_extras/statespace/core/representation.py +438 -0
  35. pymc_extras/statespace/core/statespace.py +2268 -0
  36. pymc_extras/statespace/filters/__init__.py +15 -0
  37. pymc_extras/statespace/filters/distributions.py +453 -0
  38. pymc_extras/statespace/filters/kalman_filter.py +820 -0
  39. pymc_extras/statespace/filters/kalman_smoother.py +126 -0
  40. pymc_extras/statespace/filters/utilities.py +59 -0
  41. pymc_extras/statespace/models/ETS.py +670 -0
  42. pymc_extras/statespace/models/SARIMAX.py +536 -0
  43. pymc_extras/statespace/models/VARMAX.py +393 -0
  44. pymc_extras/statespace/models/__init__.py +6 -0
  45. pymc_extras/statespace/models/structural.py +1651 -0
  46. pymc_extras/statespace/models/utilities.py +387 -0
  47. pymc_extras/statespace/utils/__init__.py +0 -0
  48. pymc_extras/statespace/utils/constants.py +74 -0
  49. pymc_extras/statespace/utils/coord_tools.py +0 -0
  50. pymc_extras/statespace/utils/data_tools.py +182 -0
  51. pymc_extras/utils/__init__.py +23 -0
  52. pymc_extras/utils/linear_cg.py +290 -0
  53. pymc_extras/utils/pivoted_cholesky.py +69 -0
  54. pymc_extras/utils/prior.py +200 -0
  55. pymc_extras/utils/spline.py +131 -0
  56. pymc_extras/version.py +11 -0
  57. pymc_extras/version.txt +1 -0
  58. pymc_extras-0.2.0.dist-info/LICENSE +212 -0
  59. pymc_extras-0.2.0.dist-info/METADATA +99 -0
  60. pymc_extras-0.2.0.dist-info/RECORD +101 -0
  61. pymc_extras-0.2.0.dist-info/WHEEL +5 -0
  62. pymc_extras-0.2.0.dist-info/top_level.txt +2 -0
  63. tests/__init__.py +13 -0
  64. tests/distributions/__init__.py +19 -0
  65. tests/distributions/test_continuous.py +185 -0
  66. tests/distributions/test_discrete.py +210 -0
  67. tests/distributions/test_discrete_markov_chain.py +258 -0
  68. tests/distributions/test_multivariate.py +304 -0
  69. tests/model/__init__.py +0 -0
  70. tests/model/marginal/__init__.py +0 -0
  71. tests/model/marginal/test_distributions.py +131 -0
  72. tests/model/marginal/test_graph_analysis.py +182 -0
  73. tests/model/marginal/test_marginal_model.py +867 -0
  74. tests/model/test_model_api.py +29 -0
  75. tests/statespace/__init__.py +0 -0
  76. tests/statespace/test_ETS.py +411 -0
  77. tests/statespace/test_SARIMAX.py +405 -0
  78. tests/statespace/test_VARMAX.py +184 -0
  79. tests/statespace/test_coord_assignment.py +116 -0
  80. tests/statespace/test_distributions.py +270 -0
  81. tests/statespace/test_kalman_filter.py +326 -0
  82. tests/statespace/test_representation.py +175 -0
  83. tests/statespace/test_statespace.py +818 -0
  84. tests/statespace/test_statespace_JAX.py +156 -0
  85. tests/statespace/test_structural.py +829 -0
  86. tests/statespace/utilities/__init__.py +0 -0
  87. tests/statespace/utilities/shared_fixtures.py +9 -0
  88. tests/statespace/utilities/statsmodel_local_level.py +42 -0
  89. tests/statespace/utilities/test_helpers.py +310 -0
  90. tests/test_blackjax_smc.py +222 -0
  91. tests/test_find_map.py +98 -0
  92. tests/test_histogram_approximation.py +109 -0
  93. tests/test_laplace.py +238 -0
  94. tests/test_linearmodel.py +208 -0
  95. tests/test_model_builder.py +306 -0
  96. tests/test_pathfinder.py +45 -0
  97. tests/test_pivoted_cholesky.py +24 -0
  98. tests/test_printing.py +98 -0
  99. tests/test_prior_from_trace.py +172 -0
  100. tests/test_splines.py +77 -0
  101. tests/utils.py +31 -0
@@ -0,0 +1,109 @@
1
+ # Copyright 2022 The PyMC Developers
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import numpy as np
17
+ import pymc as pm
18
+ import pytest
19
+
20
+ import pymc_extras as pmx
21
+
22
+
23
@pytest.mark.parametrize("use_dask", [True, False], ids="dask={}".format)
@pytest.mark.parametrize("zero_inflation", [True, False], ids="ZI={}".format)
@pytest.mark.parametrize("ndims", [1, 2], ids="ndims={}".format)
def test_histogram_init_cont(use_dask, zero_inflation, ndims):
    """Quantile histogram of continuous data: check keys, dtype, shapes, counts.

    BUG FIX: the final zero-inflation check was a bare expression
    ``(histogram["count"][0] == 100).all()`` with no ``assert``, so it was
    evaluated and discarded and never verified anything. It is now asserted.
    """
    # (10000,) when ndims == 1, (10000, 2) when ndims == 2.
    data = np.random.randn(*(10000, *(2,) * (ndims - 1)))
    if zero_inflation:
        # Fold to non-negatives and plant exactly 100 zeros for the zero bin.
        data = abs(data)
        data[:100] = 0
    if use_dask:
        dask = pytest.importorskip("dask")
        dask_df = pytest.importorskip("dask.dataframe")
        data = dask_df.from_array(data)
    histogram = pmx.distributions.histogram_utils.quantile_histogram(
        data, n_quantiles=100, zero_inflation=zero_inflation
    )
    if use_dask:
        (histogram,) = dask.compute(histogram)
    assert isinstance(histogram, dict)
    assert isinstance(histogram["mid"], np.ndarray)
    assert np.issubdtype(histogram["mid"].dtype, np.floating)
    # 99 interior quantile bins, plus one dedicated zero bin when inflated.
    size = 99 + zero_inflation
    assert histogram["mid"].shape == (size,) + (1,) * len(data.shape[1:])
    assert histogram["lower"].shape == (size,) + (1,) * len(data.shape[1:])
    assert histogram["upper"].shape == (size,) + (1,) * len(data.shape[1:])
    assert histogram["count"].shape == (size,) + data.shape[1:]
    assert (histogram["count"].sum(0) == 10000).all()
    if zero_inflation:
        # Previously a no-op statement: the zero bin must hold the 100 planted zeros.
        assert (histogram["count"][0] == 100).all()
51
+
52
+
53
@pytest.mark.parametrize("use_dask", [True, False], ids="dask={}".format)
@pytest.mark.parametrize("min_count", [None, 5], ids="min_count={}".format)
@pytest.mark.parametrize("ndims", [1, 2], ids="ndims={}".format)
def test_histogram_init_discrete(use_dask, min_count, ndims):
    """Discrete histogram of integer data: check keys, dtype, shapes, counts."""
    shape = (10000,) + (2,) * (ndims - 1)
    data = np.random.randint(0, 100, size=shape)
    uniques, counts = np.unique(data, return_counts=True)
    if use_dask:
        dask = pytest.importorskip("dask")
        dask_df = pytest.importorskip("dask.dataframe")
        data = dask_df.from_array(data)
    histogram = pmx.distributions.histogram_utils.discrete_histogram(data, min_count=min_count)
    if use_dask:
        (histogram,) = dask.compute(histogram)
    assert isinstance(histogram, dict)
    assert isinstance(histogram["mid"], np.ndarray)
    assert np.issubdtype(histogram["mid"].dtype, np.integer)
    # With a threshold, only values observed at least min_count times keep a bin.
    size = len(uniques) if min_count is None else int((counts >= min_count).sum())
    batch_shape = data.shape[1:]
    assert histogram["mid"].shape == (size,) + (1,) * len(batch_shape)
    assert histogram["count"].shape == (size,) + batch_shape
    if not min_count:
        # Without filtering, every observation lands in exactly one bin.
        assert (histogram["count"].sum(0) == 10000).all()
77
+
78
+
79
@pytest.mark.parametrize("use_dask", [True, False], ids="dask={}".format)
@pytest.mark.parametrize("ndims", [1, 2], ids="ndims={}".format)
def test_histogram_approx_cont(use_dask, ndims):
    """Smoke test: histogram_approximation of a Normal likelihood samples cleanly."""
    observed = np.random.randn(*(10000, *(2,) * (ndims - 1)))
    if use_dask:
        dask = pytest.importorskip("dask")
        dask_df = pytest.importorskip("dask.dataframe")
        observed = dask_df.from_array(observed)
    with pm.Model():
        m = pm.Normal("m")
        s = pm.HalfNormal("s", size=2 if ndims > 1 else 1)
        pot = pmx.distributions.histogram_utils.histogram_approximation(
            "histogram_potential", pm.Normal.dist(m, s), observed=observed, n_quantiles=1000
        )
        trace = pm.sample(10, tune=0)  # very fast
94
+
95
+
96
@pytest.mark.parametrize("use_dask", [True, False], ids="dask={}".format)
@pytest.mark.parametrize("ndims", [1, 2], ids="ndims={}".format)
def test_histogram_approx_discrete(use_dask, ndims):
    """Smoke test: histogram_approximation of a Poisson likelihood samples cleanly.

    Consistency fix: the ``use_dask`` parametrize was missing the
    ``ids="dask={}".format`` labelling used by every other test in this module,
    producing inconsistent test ids in pytest output.
    """
    data = np.random.randint(0, 100, size=(10000, *(2,) * (ndims - 1)))
    if use_dask:
        dask = pytest.importorskip("dask")
        dask_df = pytest.importorskip("dask.dataframe")
        data = dask_df.from_array(data)
    with pm.Model():
        s = pm.Exponential("s", 1.0, size=2 if ndims > 1 else 1)
        pot = pmx.distributions.histogram_utils.histogram_approximation(
            "histogram_potential", pm.Poisson.dist(s), observed=data, min_count=10
        )
        trace = pm.sample(10, tune=0)  # very fast
tests/test_laplace.py ADDED
@@ -0,0 +1,238 @@
1
+ # Copyright 2024 The PyMC Developers
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import numpy as np
17
+ import pymc as pm
18
+ import pytest
19
+
20
+ import pymc_extras as pmx
21
+
22
+ from pymc_extras.inference.find_map import find_MAP
23
+ from pymc_extras.inference.laplace import (
24
+ fit_laplace,
25
+ fit_mvn_to_MAP,
26
+ sample_laplace_posterior,
27
+ )
28
+
29
+
30
@pytest.fixture(scope="session")
def rng():
    """Session-wide RNG, seeded deterministically from the module name."""
    return np.random.default_rng(sum(map(ord, "test_laplace")))
34
+
35
+
36
@pytest.mark.filterwarnings(
    "ignore:hessian will stop negating the output in a future version of PyMC.\n"
    + "To suppress this warning set `negate_output=False`:FutureWarning",
)
def test_laplace():
    """Laplace fit recovers the closed-form MAP and covariance of BDA3, sec. 4.1.

    Example originates from Bayesian Data Analysis, 3rd Edition, by Andrew
    Gelman, John Carlin, Hal Stern, David Dunson, Aki Vehtari, and Donald Rubin.
    """
    y = np.array([2642, 3503, 4358], dtype=np.float64)
    n = y.size
    draws = 100000

    with pm.Model() as m:
        mu = pm.Uniform("mu", -10000, 10000)
        logsigma = pm.Uniform("logsigma", 1, 100)

        yobs = pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y)
        free_vars = [mu, logsigma]

        idata = pmx.fit(
            method="laplace", optimize_method="trust-ncg", draws=draws, random_seed=173300, chains=1
        )

    assert idata.posterior["mu"].shape == (1, draws)
    assert idata.posterior["logsigma"].shape == (1, draws)
    assert idata.observed_data["y"].shape == (n,)
    assert idata.fit["mean_vector"].shape == (len(free_vars),)
    assert idata.fit["covariance_matrix"].shape == (len(free_vars), len(free_vars))

    # Analytic MAP and posterior covariance for this model (BDA3 section 4.1).
    expected_map = [y.mean(), np.log(y.std())]
    expected_cov = np.array([[y.var() / n, 0], [0, 1 / (2 * n)]])

    np.testing.assert_allclose(idata.fit["mean_vector"].values, expected_map)
    np.testing.assert_allclose(idata.fit["covariance_matrix"].values, expected_cov, atol=1e-4)
72
+
73
+
74
def test_laplace_only_fit():
    """Laplace fit via the JAX backend matches the BDA3 closed forms.

    Example originates from Bayesian Data Analysis, 3rd Edition, by Andrew
    Gelman, John Carlin, Hal Stern, David Dunson, Aki Vehtari, and Donald
    Rubin. See section 4.1.
    """
    y = np.array([2642, 3503, 4358], dtype=np.float64)
    n = y.size

    with pm.Model() as m:
        # logsigma is declared first, so it leads the fit output ordering below.
        logsigma = pm.Uniform("logsigma", 1, 100)
        mu = pm.Uniform("mu", -10000, 10000)
        yobs = pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y)
        free_vars = [mu, logsigma]

        idata = pmx.fit(
            method="laplace",
            optimize_method="BFGS",
            progressbar=True,
            gradient_backend="jax",
            compile_kwargs={"mode": "JAX"},
            optimizer_kwargs=dict(maxiter=100_000, gtol=1e-100),
            random_seed=173300,
        )

    assert idata.fit["mean_vector"].shape == (len(free_vars),)
    assert idata.fit["covariance_matrix"].shape == (len(free_vars), len(free_vars))

    # Same closed forms as test_laplace, permuted to (logsigma, mu) order.
    expected_map = [np.log(y.std()), y.mean()]
    expected_cov = np.array([[1 / (2 * n), 0], [0, y.var() / n]])

    np.testing.assert_allclose(idata.fit["mean_vector"].values, expected_map)
    np.testing.assert_allclose(idata.fit["covariance_matrix"].values, expected_cov, atol=1e-4)
107
+
108
+
109
@pytest.mark.parametrize(
    "transform_samples",
    [True, False],
    ids=["transformed", "untransformed"],
)
@pytest.mark.parametrize("mode", ["JAX", None], ids=["jax", "pytensor"])
def test_fit_laplace_coords(rng, transform_samples, mode):
    """find_MAP -> fit_mvn_to_MAP -> sample_laplace_posterior keeps coord labels."""
    coords = {"city": ["A", "B", "C"], "obs_idx": np.arange(100)}
    with pm.Model(coords=coords) as model:
        mu = pm.Normal("mu", mu=3, sigma=0.5, dims=["city"])
        sigma = pm.Exponential("sigma", 1, dims=["city"])
        obs = pm.Normal(
            "obs",
            mu=mu,
            sigma=sigma,
            observed=rng.normal(loc=3, scale=1.5, size=(100, 3)),
            dims=["obs_idx", "city"],
        )

        map_point = find_MAP(
            method="trust-ncg",
            use_grad=True,
            use_hessp=True,
            progressbar=False,
            compile_kwargs=dict(mode=mode),
            gradient_backend="jax" if mode == "JAX" else "pytensor",
        )

    # One optimized entry per city for each free variable.
    for value in map_point.values():
        assert value.shape == (3,)

    mu, H_inv = fit_mvn_to_MAP(
        optimized_point=map_point,
        model=model,
        transform_samples=transform_samples,
    )

    idata = sample_laplace_posterior(
        mu=mu, H_inv=H_inv, model=model, transform_samples=transform_samples
    )

    np.testing.assert_allclose(np.mean(idata.posterior.mu, axis=1), np.full((2, 3), 3), atol=0.5)
    np.testing.assert_allclose(
        np.mean(idata.posterior.sigma, axis=1), np.full((2, 3), 1.5), atol=0.3
    )

    # When fitting in the untransformed space, sigma keeps its log-space label.
    suffix = "_log__" if transform_samples else ""
    expected_rows = [f"mu[{city}]" for city in "ABC"]
    expected_rows += [f"sigma{suffix}[{city}]" for city in "ABC"]
    assert idata.fit.rows.values.tolist() == expected_rows
164
+
165
+
166
def test_fit_laplace_ragged_coords(rng):
    """fit_laplace unravels a 2d (city x feature) parameter block correctly."""
    coords = {"city": ["A", "B", "C"], "feature": [0, 1], "obs_idx": np.arange(100)}
    with pm.Model(coords=coords) as ragged_dim_model:
        X = pm.Data("X", np.ones((100, 2)), dims=["obs_idx", "feature"])
        beta = pm.Normal(
            "beta", mu=[[-100.0, 100.0], [-100.0, 100.0], [-100.0, 100.0]], dims=["city", "feature"]
        )
        mu = pm.Deterministic(
            "mu", (X[:, None, :] * beta[None]).sum(axis=-1), dims=["obs_idx", "city"]
        )
        sigma = pm.Normal("sigma", mu=1.5, sigma=0.5, dims=["city"])

        obs = pm.Normal(
            "obs",
            mu=mu,
            sigma=sigma,
            observed=rng.normal(loc=3, scale=1.5, size=(100, 3)),
            dims=["obs_idx", "city"],
        )

        idata = fit_laplace(
            optimize_method="Newton-CG",
            progressbar=False,
            use_grad=True,
            use_hessp=True,
            gradient_backend="jax",
            compile_kwargs={"mode": "JAX"},
        )

    posterior = idata["posterior"]
    assert posterior.beta.shape[-2:] == (3, 2)
    assert posterior.sigma.shape[-1:] == (3,)

    # Check that everything got unraveled correctly -- feature 0 should be
    # strictly negative, feature 1 strictly positive.
    assert (posterior.beta.sel(feature=0).to_numpy() < 0).all()
    assert (posterior.beta.sel(feature=1).to_numpy() > 0).all()
202
+
203
+
204
@pytest.mark.parametrize(
    "fit_in_unconstrained_space",
    [True, False],
    ids=["transformed", "untransformed"],
)
def test_fit_laplace(fit_in_unconstrained_space):
    """End-to-end fit_laplace recovers (mu, sigma) of a simple Normal model."""
    with pm.Model() as simp_model:
        mu = pm.Normal("mu", mu=3, sigma=0.5)
        sigma = pm.Exponential("sigma", 1)
        obs = pm.Normal(
            "obs",
            mu=mu,
            sigma=sigma,
            observed=np.random.default_rng().normal(loc=3, scale=1.5, size=(10000,)),
        )

        idata = fit_laplace(
            optimize_method="trust-ncg",
            use_grad=True,
            use_hessp=True,
            fit_in_unconstrained_space=fit_in_unconstrained_space,
            optimizer_kwargs=dict(maxiter=100_000, tol=1e-100),
        )

    np.testing.assert_allclose(np.mean(idata.posterior.mu, axis=1), np.full((2,), 3), atol=0.1)
    np.testing.assert_allclose(
        np.mean(idata.posterior.sigma, axis=1), np.full((2,), 1.5), atol=0.1
    )

    if fit_in_unconstrained_space:
        # sigma is reported on the log scale: log(1.5) ~ 0.4.
        assert idata.fit.rows.values.tolist() == ["mu", "sigma_log__"]
        np.testing.assert_allclose(idata.fit.mean_vector.values, np.array([3.0, 0.4]), atol=0.1)
    else:
        assert idata.fit.rows.values.tolist() == ["mu", "sigma"]
        np.testing.assert_allclose(idata.fit.mean_vector.values, np.array([3.0, 1.5]), atol=0.1)
@@ -0,0 +1,208 @@
1
+ # Copyright 2023 The PyMC Developers
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import hashlib
17
+ import sys
18
+ import tempfile
19
+
20
+ import numpy as np
21
+ import pandas as pd
22
+ import pytest
23
+ import xarray as xr
24
+
25
+ from pymc_extras.linearmodel import LinearModel
26
+
27
# Optional scikit-learn integration: pipeline tests are skipped when the
# package is not importable.
sklearn_available = True
try:
    from sklearn import set_config
    from sklearn.compose import TransformedTargetRegressor
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import StandardScaler

    set_config(transform_output="pandas")
except ImportError:
    sklearn_available = False
37
+
38
+
39
@pytest.fixture(scope="module")
def toy_actual_params():
    """Ground-truth parameters used to generate the toy regression data."""
    return {"intercept": 3, "slope": 5, "obs_error": 0.5}
46
+
47
+
48
@pytest.fixture(scope="module")
def toy_X():
    """100 evenly spaced inputs on [0, 1], as a single-column DataFrame."""
    return pd.DataFrame({"input": np.linspace(start=0, stop=1, num=100)})
53
+
54
+
55
@pytest.fixture(scope="module")
def toy_y(toy_X, toy_actual_params):
    """Noisy linear response generated from the ground-truth parameters."""
    clean = toy_actual_params["slope"] * toy_X["input"] + toy_actual_params["intercept"]
    rng = np.random.default_rng(427)
    noisy = clean + rng.normal(0, toy_actual_params["obs_error"], size=len(toy_X))
    return pd.Series(noisy, name="output")
62
+
63
+
64
@pytest.fixture(scope="module")
def fitted_linear_model_instance(toy_X, toy_y):
    """A LinearModel fitted quickly (few draws) on the toy data."""
    model = LinearModel(
        sampler_config={
            "draws": 50,
            "tune": 30,
            "chains": 2,
            "target_accept": 0.95,
        }
    )
    model.fit(toy_X, toy_y, random_seed=312)
    return model
75
+
76
+
77
@pytest.mark.skipif(
    sys.platform == "win32", reason="Permissions for temp files not granted on windows CI."
)
def test_save_load(fitted_linear_model_instance):
    """A model round-tripped through save/load predicts identically."""
    model = fitted_linear_model_instance
    temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False)
    model.save(temp.name)
    reloaded = LinearModel.load(temp.name)
    assert model.idata.groups() == reloaded.idata.groups()

    X_pred = pd.DataFrame({"input": np.random.uniform(low=0, high=1, size=100)})
    # Same seed on both instances -> predictions should be identical.
    np.testing.assert_array_equal(
        model.predict(X_pred, random_seed=423),
        reloaded.predict(X_pred, random_seed=423),
    )
    temp.close()
93
+
94
+
95
def test_save_without_fit_raises_runtime_error(toy_X, toy_y):
    """Saving an unfitted model must raise RuntimeError."""
    with pytest.raises(RuntimeError):
        LinearModel().save("saved_model")
99
+
100
+
101
def test_fit(fitted_linear_model_instance):
    """predict returns an ndarray, predict_posterior an xr.DataArray, both per-row."""
    model = fitted_linear_model_instance
    new_X_pred = pd.DataFrame({"input": np.random.uniform(low=0, high=1, size=100)})

    point_pred = model.predict(new_X_pred)
    assert len(new_X_pred) == len(point_pred)
    assert isinstance(point_pred, np.ndarray)

    post_pred = model.predict_posterior(new_X_pred)
    assert len(new_X_pred) == len(post_pred)
    assert isinstance(post_pred, xr.DataArray)
112
+
113
+
114
def test_parameter_fit(toy_X, toy_y, toy_actual_params):
    """Check that the fit model recovered the data-generating parameters."""
    # Fit with enough samples that posterior means are close to the truth.
    model = LinearModel(
        sampler_config={
            "draws": 500,
            "tune": 300,
            "chains": 2,
            "target_accept": 0.95,
        }
    )
    model.fit(toy_X, toy_y, random_seed=312)
    posterior_means = model.idata.posterior.mean()
    np.testing.assert_allclose(
        posterior_means["intercept"], toy_actual_params["intercept"], rtol=0.1
    )
    np.testing.assert_allclose(posterior_means["slope"], toy_actual_params["slope"], rtol=0.1)
    np.testing.assert_allclose(
        posterior_means["σ_model_fmc"], toy_actual_params["obs_error"], rtol=0.1
    )
129
+
130
+
131
def test_predict(fitted_linear_model_instance):
    """predict returns one float per input row."""
    model = fitted_linear_model_instance
    X_pred = pd.DataFrame({"input": np.random.uniform(low=0, high=1, size=100)})
    predictions = model.predict(X_pred)
    assert len(X_pred) == len(predictions)
    assert np.issubdtype(predictions.dtype, np.floating)
137
+
138
+
139
@pytest.mark.parametrize("combined", [True, False])
def test_predict_posterior(fitted_linear_model_instance, combined):
    """Posterior predictions keep (chain, draw) dims unless combined."""
    model = fitted_linear_model_instance
    n_pred = 150
    X_pred = pd.DataFrame({"input": np.random.uniform(low=0, high=1, size=n_pred)})
    pred = model.predict_posterior(X_pred, combined=combined)
    chains = model.idata.sample_stats.sizes["chain"]
    draws = model.idata.sample_stats.sizes["draw"]
    if combined:
        expected_shape = (n_pred, chains * draws)
    else:
        expected_shape = (chains, draws, n_pred)
    assert pred.shape == expected_shape
    assert np.issubdtype(pred.dtype, np.floating)
    # TODO: check that extend_idata has the expected effect
151
+
152
+
153
@pytest.mark.parametrize("samples", [None, 300])
@pytest.mark.parametrize("combined", [True, False])
def test_sample_prior_predictive(samples, combined, toy_X, toy_y):
    """Prior predictive draws have the expected (chain, draw, obs) shape."""
    model = LinearModel()
    prior_pred = model.sample_prior_predictive(toy_X, toy_y, samples, combined=combined)[
        model.output_var
    ]
    # samples=None falls back to the sampler config's draw count.
    draws = samples if samples is not None else model.sampler_config["draws"]
    chains = 1
    if combined:
        expected_shape = (len(toy_X), chains * draws)
    else:
        expected_shape = (chains, draws, len(toy_X))
    assert prior_pred.shape == expected_shape
    # TODO: check that extend_idata has the expected effect
165
+
166
+
167
def test_id():
    """Model id is a stable truncated sha256 of config values, version, and type."""
    model_config = {
        "intercept": {"loc": 0, "scale": 10},
        "slope": {"loc": 0, "scale": 10},
        "obs_error": 2,
    }
    sampler_config = {
        "draws": 1_000,
        "tune": 1_000,
        "chains": 3,
        "target_accept": 0.95,
    }
    model = LinearModel(model_config=model_config, sampler_config=sampler_config)

    hasher = hashlib.sha256(
        str(model_config.values()).encode() + model.version.encode() + model._model_type.encode()
    )
    assert model.id == hasher.hexdigest()[:16]
186
+
187
+
188
@pytest.mark.skipif(not sklearn_available, reason="scikit-learn package is not available.")
def test_pipeline_integration(toy_X, toy_y):
    """LinearModel fits and predicts inside an sklearn Pipeline with X/y scaling."""
    model_config = {
        "intercept": {"loc": 0, "scale": 2},
        "slope": {"loc": 0, "scale": 2},
        "obs_error": 1,
        "default_output_var": "y_hat",
    }
    regressor = TransformedTargetRegressor(
        LinearModel(model_config), transformer=StandardScaler()
    )
    pipeline = Pipeline(
        [
            ("input_scaling", StandardScaler()),
            ("linear_model", regressor),
        ]
    )
    pipeline.fit(toy_X, toy_y)

    X_pred = pd.DataFrame({"input": np.random.uniform(low=0, high=1, size=100)})
    pipeline.predict(X_pred)