pymc-extras 0.2.5__py3-none-any.whl → 0.2.6__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, exactly as they appear in the public registry to which they were published. It is provided for informational purposes only.
Files changed (62)
  1. pymc_extras/__init__.py +5 -1
  2. pymc_extras/distributions/continuous.py +3 -2
  3. pymc_extras/distributions/discrete.py +3 -1
  4. pymc_extras/inference/find_map.py +62 -17
  5. pymc_extras/inference/laplace.py +10 -7
  6. pymc_extras/statespace/core/statespace.py +191 -52
  7. pymc_extras/statespace/filters/distributions.py +15 -16
  8. pymc_extras/statespace/filters/kalman_filter.py +1 -18
  9. pymc_extras/statespace/filters/kalman_smoother.py +2 -6
  10. pymc_extras/statespace/models/ETS.py +10 -0
  11. pymc_extras/statespace/models/SARIMAX.py +26 -5
  12. pymc_extras/statespace/models/VARMAX.py +12 -2
  13. pymc_extras/statespace/models/structural.py +18 -5
  14. pymc_extras-0.2.6.dist-info/METADATA +318 -0
  15. pymc_extras-0.2.6.dist-info/RECORD +65 -0
  16. {pymc_extras-0.2.5.dist-info → pymc_extras-0.2.6.dist-info}/WHEEL +1 -2
  17. pymc_extras/version.py +0 -11
  18. pymc_extras/version.txt +0 -1
  19. pymc_extras-0.2.5.dist-info/METADATA +0 -112
  20. pymc_extras-0.2.5.dist-info/RECORD +0 -108
  21. pymc_extras-0.2.5.dist-info/top_level.txt +0 -2
  22. tests/__init__.py +0 -13
  23. tests/distributions/__init__.py +0 -19
  24. tests/distributions/test_continuous.py +0 -185
  25. tests/distributions/test_discrete.py +0 -210
  26. tests/distributions/test_discrete_markov_chain.py +0 -258
  27. tests/distributions/test_multivariate.py +0 -304
  28. tests/distributions/test_transform.py +0 -77
  29. tests/model/__init__.py +0 -0
  30. tests/model/marginal/__init__.py +0 -0
  31. tests/model/marginal/test_distributions.py +0 -132
  32. tests/model/marginal/test_graph_analysis.py +0 -182
  33. tests/model/marginal/test_marginal_model.py +0 -967
  34. tests/model/test_model_api.py +0 -38
  35. tests/statespace/__init__.py +0 -0
  36. tests/statespace/test_ETS.py +0 -411
  37. tests/statespace/test_SARIMAX.py +0 -405
  38. tests/statespace/test_VARMAX.py +0 -184
  39. tests/statespace/test_coord_assignment.py +0 -181
  40. tests/statespace/test_distributions.py +0 -270
  41. tests/statespace/test_kalman_filter.py +0 -326
  42. tests/statespace/test_representation.py +0 -175
  43. tests/statespace/test_statespace.py +0 -872
  44. tests/statespace/test_statespace_JAX.py +0 -156
  45. tests/statespace/test_structural.py +0 -836
  46. tests/statespace/utilities/__init__.py +0 -0
  47. tests/statespace/utilities/shared_fixtures.py +0 -9
  48. tests/statespace/utilities/statsmodel_local_level.py +0 -42
  49. tests/statespace/utilities/test_helpers.py +0 -310
  50. tests/test_blackjax_smc.py +0 -222
  51. tests/test_find_map.py +0 -103
  52. tests/test_histogram_approximation.py +0 -109
  53. tests/test_laplace.py +0 -281
  54. tests/test_linearmodel.py +0 -208
  55. tests/test_model_builder.py +0 -306
  56. tests/test_pathfinder.py +0 -297
  57. tests/test_pivoted_cholesky.py +0 -24
  58. tests/test_printing.py +0 -98
  59. tests/test_prior_from_trace.py +0 -172
  60. tests/test_splines.py +0 -77
  61. tests/utils.py +0 -0
  62. {pymc_extras-0.2.5.dist-info → pymc_extras-0.2.6.dist-info}/licenses/LICENSE +0 -0
tests/statespace/utilities/shared_fixtures.py DELETED
@@ -1,9 +0,0 @@
- import numpy as np
- import pytest
-
- TEST_SEED = sum(map(ord, "statespace"))
-
-
- @pytest.fixture(scope="session")
- def rng():
-     return np.random.default_rng(TEST_SEED)
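A note on the deleted fixture above: pytest injects fixtures by argument name, so every test in a session that declares an `rng` parameter shares this one generator, deterministically seeded from the string "statespace". A minimal consuming test would look like this (hypothetical illustration, not part of the deleted suite):

    def test_uses_shared_rng(rng):
        # pytest resolves `rng` to the session-scoped fixture above; the fixed
        # TEST_SEED makes full-suite runs reproducible
        draws = rng.normal(size=5)
        assert draws.shape == (5,)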
tests/statespace/utilities/statsmodel_local_level.py DELETED
@@ -1,42 +0,0 @@
- import numpy as np
- import statsmodels.api as sm
-
-
- class LocalLinearTrend(sm.tsa.statespace.MLEModel):
-     def __init__(self, endog, **kwargs):
-         # Model order
-         k_states = k_posdef = 2
-
-         # Initialize the statespace
-         super().__init__(endog, k_states=k_states, k_posdef=k_posdef, **kwargs)
-
-         # Initialize the matrices
-         self.ssm["design"] = np.array([1, 0])
-         self.ssm["transition"] = np.array([[1, 1], [0, 1]])
-         self.ssm["selection"] = np.eye(k_states)
-
-         # Cache some indices
-         self._state_cov_idx = ("state_cov", *np.diag_indices(k_posdef))
-
-     @property
-     def param_names(self):
-         return ["sigma2.measurement", "sigma2.level", "sigma2.trend"]
-
-     @property
-     def start_params(self):
-         return [np.std(self.endog)] * 3
-
-     def transform_params(self, unconstrained):
-         return unconstrained**2
-
-     def untransform_params(self, constrained):
-         return constrained**0.5
-
-     def update(self, params, *args, **kwargs):
-         params = super().update(params, *args, **kwargs)
-
-         # Observation covariance
-         self.ssm["obs_cov", 0, 0] = params[0]
-
-         # State covariance
-         self.ssm[self._state_cov_idx] = params[1:]
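The deleted helper above is a standard `sm.tsa.statespace.MLEModel` subclass: it fixes the design, transition, and selection matrices of a local linear trend and maps three variance parameters into the observation and state covariances in `update`. For reference, a minimal sketch of fitting it through the usual statsmodels API (illustrative random-walk data; the test suite instead pins the variances with `fit_constrained`, as seen in `test_helpers.py` below):

    import numpy as np

    rng = np.random.default_rng(0)
    endog = np.cumsum(rng.normal(size=200))  # a random-walk-like toy series

    # LocalLinearTrend is the class from the deleted file above
    model = LocalLinearTrend(endog)
    result = model.fit(disp=False)  # maximum likelihood via the Kalman filter
    print(result.params)  # sigma2.measurement, sigma2.level, sigma2.trend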
tests/statespace/utilities/test_helpers.py DELETED
@@ -1,310 +0,0 @@
- import numpy as np
- import pandas as pd
- import pytensor
- import pytensor.tensor as pt
- import statsmodels.api as sm
-
- from numpy.testing import assert_allclose
- from pymc import modelcontext
-
- from pymc_extras.statespace.filters.kalman_smoother import KalmanSmoother
- from pymc_extras.statespace.utils.constants import (
-     MATRIX_NAMES,
-     SHORT_NAME_TO_LONG,
- )
- from tests.statespace.utilities.statsmodel_local_level import LocalLinearTrend
-
- floatX = pytensor.config.floatX
-
-
- def load_nile_test_data():
-     from importlib.metadata import version
-
-     nile = pd.read_csv("tests/statespace/test_data/nile.csv", dtype={"x": floatX})
-     major, minor, rev = map(int, version("pandas").split("."))
-     if major >= 2 and minor >= 2 and rev >= 0:
-         freq_str = "YS-JAN"
-     else:
-         freq_str = "AS-JAN"
-     nile.index = pd.date_range(start="1871-01-01", end="1970-01-01", freq=freq_str)
-     nile.rename(columns={"x": "height"}, inplace=True)
-     nile = (nile - nile.mean()) / nile.std()
-     nile = nile.astype(floatX)
-
-     return nile
-
-
- def initialize_filter(kfilter, mode=None, p=None, m=None, r=None, n=None):
-     ksmoother = KalmanSmoother()
-     data = pt.tensor(name="data", dtype=floatX, shape=(n, p))
-     a0 = pt.tensor(name="x0", dtype=floatX, shape=(m,))
-     P0 = pt.tensor(name="P0", dtype=floatX, shape=(m, m))
-     c = pt.tensor(name="c", dtype=floatX, shape=(m,))
-     d = pt.tensor(name="d", dtype=floatX, shape=(p,))
-     Q = pt.tensor(name="Q", dtype=floatX, shape=(r, r))
-     H = pt.tensor(name="H", dtype=floatX, shape=(p, p))
-     T = pt.tensor(name="T", dtype=floatX, shape=(m, m))
-     R = pt.tensor(name="R", dtype=floatX, shape=(m, r))
-     Z = pt.tensor(name="Z", dtype=floatX, shape=(p, m))
-
-     inputs = [data, a0, P0, c, d, T, Z, R, H, Q]
-
-     (
-         filtered_states,
-         predicted_states,
-         observed_states,
-         filtered_covs,
-         predicted_covs,
-         observed_covs,
-         ll_obs,
-     ) = kfilter.build_graph(*inputs, mode=mode)
-
-     smoothed_states, smoothed_covs = ksmoother.build_graph(T, R, Q, filtered_states, filtered_covs)
-
-     outputs = [
-         filtered_states,
-         predicted_states,
-         smoothed_states,
-         filtered_covs,
-         predicted_covs,
-         smoothed_covs,
-         ll_obs.sum(),
-         ll_obs,
-     ]
-
-     return inputs, outputs
-
-
- def add_missing_data(data, n_missing, rng):
-     n = data.shape[0]
-     missing_idx = rng.choice(n, n_missing, replace=False)
-     data[missing_idx] = np.nan
-
-     return data
-
-
- def make_test_inputs(p, m, r, n, rng, missing_data=None, H_is_zero=False):
-     data = np.arange(n * p, dtype=floatX).reshape(-1, p)
-     if missing_data is not None:
-         data = add_missing_data(data, missing_data, rng)
-
-     a0 = np.zeros(m, dtype=floatX)
-     P0 = np.eye(m, dtype=floatX)
-     c = np.zeros(m, dtype=floatX)
-     d = np.zeros(p, dtype=floatX)
-     Q = np.eye(r, dtype=floatX)
-     H = np.zeros((p, p), dtype=floatX) if H_is_zero else np.eye(p, dtype=floatX)
-     T = np.eye(m, k=-1, dtype=floatX)
-     T[0, :] = 1 / m
-     R = np.eye(m, dtype=floatX)[:, :r]
-     Z = np.eye(m, dtype=floatX)[:p, :]
-
-     data, a0, P0, c, d, T, Z, R, H, Q = map(
-         np.ascontiguousarray, [data, a0, P0, c, d, T, Z, R, H, Q]
-     )
-
-     return data, a0, P0, c, d, T, Z, R, H, Q
-
-
- def get_expected_shape(name, p, m, r, n):
-     if name == "log_likelihood":
-         return ()
-     elif name == "ll_obs":
-         return (n,)
-     filter_type, variable = name.split("_")
-     if variable == "states":
-         return n, m
-     if variable == "covs":
-         return n, m, m
-
-
- def get_sm_state_from_output_name(res, name):
-     if name == "log_likelihood":
-         return res.llf
-     elif name == "ll_obs":
-         return res.llf_obs
-
-     filter_type, variable = name.split("_")
-     sm_states = getattr(res, "states")
-
-     if variable == "states":
-         return getattr(sm_states, filter_type)
-     if variable == "covs":
-         m = res.filter_results.k_states
-         # remove the "s" from "covs"
-         return getattr(sm_states, name[:-1]).reshape(-1, m, m)
-
-
- def nile_test_test_helper(rng, n_missing=0):
-     a0 = np.zeros(2, dtype=floatX)
-     P0 = np.eye(2, dtype=floatX) * 1e6
-     c = np.zeros(2, dtype=floatX)
-     d = np.zeros(1, dtype=floatX)
-     Q = np.eye(2, dtype=floatX) * np.array([0.5, 0.01], dtype=floatX)
-     H = np.eye(1, dtype=floatX) * 0.8
-     T = np.array([[1.0, 1.0], [0.0, 1.0]], dtype=floatX)
-     R = np.eye(2, dtype=floatX)
-     Z = np.array([[1.0, 0.0]], dtype=floatX)
-
-     data = load_nile_test_data().values
-     if n_missing > 0:
-         data = add_missing_data(data, n_missing, rng)
-
-     sm_model = LocalLinearTrend(
-         endog=data,
-         initialization="known",
-         initial_state_cov=P0,
-         initial_state=a0.ravel(),
-     )
-
-     res = sm_model.fit_constrained(
-         constraints={
-             "sigma2.measurement": 0.8,
-             "sigma2.level": 0.5,
-             "sigma2.trend": 0.01,
-         }
-     )
-
-     inputs = [data, a0, P0, c, d, T, Z, R, H, Q]
-
-     return res, inputs
-
-
- def fast_eval(var):
-     return pytensor.function([], var, mode="FAST_COMPILE")()
-
-
- def delete_rvs_from_model(rv_names: list[str]) -> None:
-     """Remove all model mappings referring to rv
-
-     This can be used to "delete" an RV from a model
-     """
-     mod = modelcontext(None)
-     all_rvs = mod.basic_RVs + mod.deterministics
-     all_rv_names = [x.name for x in all_rvs]
-
-     for name in rv_names:
-         assert name in all_rv_names, f"{name} is not part of the Model: {all_rv_names}"
-
-         rv_idx = all_rv_names.index(name)
-         rv = all_rvs[rv_idx]
-
-         mod.named_vars.pop(name)
-         if name in mod.named_vars_to_dims:
-             mod.named_vars_to_dims.pop(name)
-
-         if rv in mod.deterministics:
-             mod.deterministics.remove(rv)
-             continue
-
-         value = mod.rvs_to_values.pop(rv)
-         mod.values_to_rvs.pop(value)
-         mod.rvs_to_transforms.pop(rv)
-         if rv in mod.free_RVs:
-             mod.free_RVs.remove(rv)
-             mod.rvs_to_initial_values.pop(rv)
-         else:
-             mod.observed_RVs.remove(rv)
-
-
- def unpack_statespace(ssm):
-     return [ssm[SHORT_NAME_TO_LONG[x]] for x in MATRIX_NAMES]
-
-
- def unpack_symbolic_matrices_with_params(mod, param_dict, data_dict=None, mode="FAST_COMPILE"):
-     inputs = list(mod._name_to_variable.values())
-     if data_dict is not None:
-         inputs += list(mod._name_to_data.values())
-     else:
-         data_dict = {}
-
-     f_matrices = pytensor.function(
-         inputs,
-         unpack_statespace(mod.ssm),
-         on_unused_input="raise",
-         mode=mode,
-     )
-
-     x0, P0, c, d, T, Z, R, H, Q = f_matrices(**param_dict, **data_dict)
-
-     return x0, P0, c, d, T, Z, R, H, Q
-
-
- def simulate_from_numpy_model(mod, rng, param_dict, data_dict=None, steps=100):
-     """
-     Helper function to visualize the components outside of a PyMC model context
-     """
-     x0, P0, c, d, T, Z, R, H, Q = unpack_symbolic_matrices_with_params(mod, param_dict, data_dict)
-     k_states = mod.k_states
-     k_posdef = mod.k_posdef
-
-     x = np.zeros((steps, k_states))
-     y = np.zeros(steps)
-
-     x[0] = x0
-     y[0] = (Z @ x0).squeeze() if Z.ndim == 2 else (Z[0] @ x0).squeeze()
-
-     if not np.allclose(H, 0):
-         y[0] += rng.multivariate_normal(mean=np.zeros(1), cov=H).squeeze()
-
-     for t in range(1, steps):
-         if k_posdef > 0:
-             shock = rng.multivariate_normal(mean=np.zeros(k_posdef), cov=Q)
-             innov = R @ shock
-         else:
-             innov = 0
-
-         if not np.allclose(H, 0):
-             error = rng.multivariate_normal(mean=np.zeros(1), cov=H)
-         else:
-             error = 0
-
-         x[t] = c + T @ x[t - 1] + innov
-         if Z.ndim == 2:
-             y[t] = (d + Z @ x[t] + error).squeeze()
-         else:
-             y[t] = (d + Z[t] @ x[t] + error).squeeze()
-
-     return x, y
-
-
- def assert_pattern_repeats(y, T, atol, rtol):
-     val = np.diff(y.reshape(-1, T), axis=0)
-     if floatX.endswith("64"):
-         # Round this before going into the test, otherwise it behaves poorly (atol = inf)
-         n_digits = len(str(1 / atol))
-         val = np.round(val, n_digits)
-
-     assert_allclose(
-         val,
-         0,
-         err_msg="seasonal pattern does not repeat",
-         atol=atol,
-         rtol=rtol,
-     )
-
-
- def make_stationary_params(data, p, d, q, P, D, Q, S):
-     sm_sarimax = sm.tsa.SARIMAX(data, order=(p, d, q), seasonal_order=(P, D, Q, S))
-     res = sm_sarimax.fit(disp=False)
-
-     param_dict = dict(ar_params=[], ma_params=[], seasonal_ar_params=[], seasonal_ma_params=[])
-
-     for name, param in zip(res.param_names, res.params):
-         if name.startswith("ar.S"):
-             param_dict["seasonal_ar_params"].append(param)
-         elif name.startswith("ma.S"):
-             param_dict["seasonal_ma_params"].append(param)
-         elif name.startswith("ar."):
-             param_dict["ar_params"].append(param)
-         elif name.startswith("ma."):
-             param_dict["ma_params"].append(param)
-         else:
-             param_dict["sigma_state"] = param
-
-     param_dict = {
-         k: np.array(v, dtype=floatX)
-         for k, v in param_dict.items()
-         if isinstance(v, float) or len(v) > 0
-     }
-     return param_dict
tests/test_blackjax_smc.py DELETED
@@ -1,222 +0,0 @@
- # Copyright 2023 The PyMC Developers
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #   http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import numpy as np
- import pymc as pm
- import pytensor.tensor as pt
- import pytest
- import scipy
-
- from numpy import dtype
- from xarray.core.utils import Frozen
-
- jax = pytest.importorskip("jax")
- pytest.importorskip("blackjax")
-
- from pymc_extras.inference.smc.sampling import (
-     arviz_from_particles,
-     blackjax_particles_from_pymc_population,
-     get_jaxified_loglikelihood,
-     get_jaxified_logprior,
-     sample_smc_blackjax,
- )
-
-
- def two_gaussians_model():
-     n = 4
-     mu1 = np.ones(n) * 0.5
-     mu2 = -mu1
-
-     stdev = 0.1
-     sigma = np.power(stdev, 2) * np.eye(n)
-     isigma = np.linalg.inv(sigma)
-     dsigma = np.linalg.det(sigma)
-
-     w1 = stdev
-     w2 = 1 - stdev
-
-     def two_gaussians(x):
-         """
-         Mixture of gaussians likelihood
-         """
-         log_like1 = (
-             -0.5 * n * pt.log(2 * np.pi)
-             - 0.5 * pt.log(dsigma)
-             - 0.5 * (x - mu1).T.dot(isigma).dot(x - mu1)
-         )
-         log_like2 = (
-             -0.5 * n * pt.log(2 * np.pi)
-             - 0.5 * pt.log(dsigma)
-             - 0.5 * (x - mu2).T.dot(isigma).dot(x - mu2)
-         )
-         return pt.log(w1 * pt.exp(log_like1) + w2 * pt.exp(log_like2))
-
-     with pm.Model() as m:
-         X = pm.Uniform("X", lower=-2, upper=2.0, shape=n)
-         llk = pm.Potential("muh", two_gaussians(X))
-
-     return m, mu1
-
-
- def fast_model():
-     with pm.Model() as m:
-         x = pm.Normal("x", 0, 1)
-         y = pm.Normal("y", x, 1, observed=0)
-     return m
-
-
- @pytest.mark.parametrize(
-     "kernel, check_for_integration_steps, inner_kernel_params",
-     [
-         ("HMC", True, {"step_size": 0.1, "integration_steps": 11}),
-         ("NUTS", False, {"step_size": 0.1}),
-     ],
- )
- def test_sample_smc_blackjax(kernel, check_for_integration_steps, inner_kernel_params):
-     """
-     When running the two gaussians model
-     with BlackJax SMC, we sample it correctly,
-     the shape of a posterior variable is (1, particles, dimension)
-     and the inference_data has the right attributes.
-
-     """
-     model, muref = two_gaussians_model()
-     iterations_to_diagnose = 2
-     n_particles = 1000
-     with model:
-         inference_data = sample_smc_blackjax(
-             n_particles=n_particles,
-             kernel=kernel,
-             inner_kernel_params=inner_kernel_params,
-             iterations_to_diagnose=iterations_to_diagnose,
-         )
-
-     x = inference_data.posterior["X"]
-
-     assert x.to_numpy().shape == (1, n_particles, 4)
-     mu1d = np.abs(x).mean(axis=0).mean(axis=0)
-     np.testing.assert_allclose(muref, mu1d, rtol=0.0, atol=0.03)
-
-     for attribute, value in [
-         ("particles", n_particles),
-         ("step_size", 0.1),
-         ("num_mcmc_steps", 10),
-         ("iterations_to_diagnose", iterations_to_diagnose),
-         ("sampler", f"Blackjax SMC with {kernel} kernel"),
-     ]:
-         assert inference_data.posterior.attrs[attribute] == value
-
-     for diagnostic in ["lambda_evolution", "log_likelihood_increments"]:
-         assert inference_data.posterior.attrs[diagnostic].shape == (iterations_to_diagnose,)
-
-     for diagnostic in ["ancestors_evolution", "weights_evolution"]:
-         assert inference_data.posterior.attrs[diagnostic].shape == (
-             iterations_to_diagnose,
-             n_particles,
-         )
-
-     for attribute in ["running_time_seconds", "iterations"]:
-         assert attribute in inference_data.posterior.attrs
-
-     if check_for_integration_steps:
-         assert inference_data.posterior.attrs["integration_steps"] == 11
-
-
- def test_blackjax_particles_from_pymc_population_univariate():
-     model = fast_model()
-     population = {"x": np.array([2, 3, 4])}
-     blackjax_particles = blackjax_particles_from_pymc_population(model, population)
-     jax.tree.map(np.testing.assert_allclose, blackjax_particles, [np.array([[2], [3], [4]])])
-
-
- def test_blackjax_particles_from_pymc_population_multivariate():
-     with pm.Model() as model:
-         x = pm.Normal("x", 0, 1)
-         z = pm.Normal("z", 0, 1)
-         y = pm.Normal("y", x + z, 1, observed=0)
-
-     population = {"x": np.array([0.34614613, 1.09163261, -0.44526825]), "z": np.array([1, 2, 3])}
-     blackjax_particles = blackjax_particles_from_pymc_population(model, population)
-     jax.tree.map(
-         np.testing.assert_allclose,
-         blackjax_particles,
-         [np.array([[0.34614613], [1.09163261], [-0.44526825]]), np.array([[1], [2], [3]])],
-     )
-
-
- def simple_multivariable_model():
-     """
-     A simple model that has a multivariate variable,
-     and has more than one variable (multivariable)
-     """
-     with pm.Model() as model:
-         x = pm.Normal("x", 0, 1, shape=2)
-         z = pm.Normal("z", 0, 1)
-         y = pm.Normal("y", z, 1, observed=0)
-     return model
-
-
- def test_blackjax_particles_from_pymc_population_multivariable():
-     model = simple_multivariable_model()
-     population = {"x": np.array([[2, 3], [5, 6], [7, 9]]), "z": np.array([11, 12, 13])}
-     blackjax_particles = blackjax_particles_from_pymc_population(model, population)
-
-     jax.tree.map(
-         np.testing.assert_allclose,
-         blackjax_particles,
-         [np.array([[2, 3], [5, 6], [7, 9]]), np.array([[11], [12], [13]])],
-     )
-
-
- def test_arviz_from_particles():
-     model = simple_multivariable_model()
-     particles = [np.array([[2, 3], [5, 6], [7, 9]]), np.array([[11], [12], [13]])]
-     with model:
-         inference_data = arviz_from_particles(model, particles)
-
-     assert inference_data.posterior.sizes == Frozen({"chain": 1, "draw": 3, "x_dim_0": 2})
-     assert inference_data.posterior.data_vars.dtypes == Frozen(
-         {"x": dtype("float64"), "z": dtype("float64")}
-     )
-
-
- def test_get_jaxified_logprior():
-     """
-     Given a model with a Normal prior
-     for a RV, the jaxified logprior
-     indeed calculates that number,
-     and can be jax.vmap'ed
-     """
-     logprior = get_jaxified_logprior(fast_model())
-     for point in [-0.5, 0.0, 0.5]:
-         jax.tree.map(
-             np.testing.assert_allclose,
-             jax.vmap(logprior)([np.array([point])]),
-             np.log(scipy.stats.norm(0, 1).pdf(point)),
-         )
-
-
- def test_get_jaxified_loglikelihood():
-     """
-     Given a model with a Normal likelihood, a single observation
-     0 and std=1, the only free parameter of that function is the mean.
-     When computing the loglikelihood,
-     Then the function can be jax.vmap'ed, and the calculation matches the likelihood.
-     """
-     loglikelihood = get_jaxified_loglikelihood(fast_model())
-     for point in [-0.5, 0.0, 0.5]:
-         jax.tree.map(
-             np.testing.assert_allclose,
-             jax.vmap(loglikelihood)([np.array([point])]),
-             np.log(scipy.stats.norm(point, 1).pdf(0)),
-         )
tests/test_find_map.py DELETED
@@ -1,103 +0,0 @@
- import numpy as np
- import pymc as pm
- import pytensor.tensor as pt
- import pytest
-
- from pymc_extras.inference.find_map import (
-     GradientBackend,
-     find_MAP,
-     scipy_optimize_funcs_from_loss,
- )
-
- pytest.importorskip("jax")
-
-
- @pytest.fixture(scope="session")
- def rng():
-     seed = sum(map(ord, "test_fit_map"))
-     return np.random.default_rng(seed)
-
-
- @pytest.mark.parametrize("gradient_backend", ["jax", "pytensor"], ids=str)
- def test_jax_functions_from_graph(gradient_backend: GradientBackend):
-     x = pt.tensor("x", shape=(2,))
-
-     def compute_z(x):
-         z1 = x[0] ** 2 + 2
-         z2 = x[0] * x[1] + 3
-         return z1, z2
-
-     z = pt.stack(compute_z(x))
-     f_loss, f_hess, f_hessp = scipy_optimize_funcs_from_loss(
-         loss=z.sum(),
-         inputs=[x],
-         initial_point_dict={"x": np.array([1.0, 2.0])},
-         use_grad=True,
-         use_hess=True,
-         use_hessp=True,
-         gradient_backend=gradient_backend,
-         compile_kwargs=dict(mode="JAX"),
-     )
-
-     x_val = np.array([1.0, 2.0])
-     expected_z = sum(compute_z(x_val))
-
-     z_jax, grad_val = f_loss(x_val)
-     np.testing.assert_allclose(z_jax, expected_z)
-     np.testing.assert_allclose(grad_val.squeeze(), np.array([2 * x_val[0] + x_val[1], x_val[0]]))
-
-     hess_val = np.array(f_hess(x_val))
-     np.testing.assert_allclose(hess_val.squeeze(), np.array([[2, 1], [1, 0]]))
-
-     hessp_val = np.array(f_hessp(x_val, np.array([1.0, 0.0])))
-     np.testing.assert_allclose(hessp_val.squeeze(), np.array([2, 1]))
-
-
- @pytest.mark.parametrize(
-     "method, use_grad, use_hess, use_hessp",
-     [
-         ("nelder-mead", False, False, False),
-         ("powell", False, False, False),
-         ("CG", True, False, False),
-         ("BFGS", True, False, False),
-         ("L-BFGS-B", True, False, False),
-         ("TNC", True, False, False),
-         ("SLSQP", True, False, False),
-         ("dogleg", True, True, False),
-         ("Newton-CG", True, True, False),
-         ("Newton-CG", True, False, True),
-         ("trust-ncg", True, True, False),
-         ("trust-ncg", True, False, True),
-         ("trust-exact", True, True, False),
-         ("trust-krylov", True, True, False),
-         ("trust-krylov", True, False, True),
-         ("trust-constr", True, True, False),
-     ],
- )
- @pytest.mark.parametrize("gradient_backend", ["jax", "pytensor"], ids=str)
- def test_JAX_map(method, use_grad, use_hess, use_hessp, gradient_backend: GradientBackend, rng):
-     extra_kwargs = {}
-     if method == "dogleg":
-         # HACK -- dogleg requires that the hessian of the objective function is PSD, so we have to pick a point
-         # where this is true
-         extra_kwargs = {"initvals": {"mu": 2, "sigma_log__": 1}}
-
-     with pm.Model() as m:
-         mu = pm.Normal("mu")
-         sigma = pm.Exponential("sigma", 1)
-         pm.Normal("y_hat", mu=mu, sigma=sigma, observed=rng.normal(loc=3, scale=1.5, size=100))
-
-         optimized_point = find_MAP(
-             method=method,
-             **extra_kwargs,
-             use_grad=use_grad,
-             use_hess=use_hess,
-             use_hessp=use_hessp,
-             progressbar=False,
-             gradient_backend=gradient_backend,
-             compile_kwargs={"mode": "JAX"},
-         )
-     mu_hat, log_sigma_hat = optimized_point["mu"], optimized_point["sigma_log__"]
-
-     assert np.isclose(mu_hat, 3, atol=0.5)
-     assert np.isclose(np.exp(log_sigma_hat), 1.5, atol=0.5)
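For reference, `find_MAP` is the public entry point exercised by the deleted test above and can be called the same way outside the test harness. A minimal sketch using the default PyTensor compilation path (toy model and data; option names taken from the test):

    import numpy as np
    import pymc as pm

    from pymc_extras.inference.find_map import find_MAP

    rng = np.random.default_rng(0)

    with pm.Model():
        mu = pm.Normal("mu")
        sigma = pm.Exponential("sigma", 1)
        pm.Normal("y_hat", mu=mu, sigma=sigma, observed=rng.normal(loc=3, scale=1.5, size=100))

        # Returns a dict of point estimates keyed by variable name, including
        # transformed values such as optimized_point["sigma_log__"]
        optimized_point = find_MAP(method="L-BFGS-B", use_grad=True, progressbar=False)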