pymc-extras 0.2.4__py3-none-any.whl → 0.2.6__py3-none-any.whl

This diff shows the contents of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only and reflects the changes between the two versions.
Files changed (69)
  1. pymc_extras/__init__.py +6 -4
  2. pymc_extras/distributions/__init__.py +2 -0
  3. pymc_extras/distributions/continuous.py +3 -2
  4. pymc_extras/distributions/discrete.py +3 -1
  5. pymc_extras/distributions/transforms/__init__.py +3 -0
  6. pymc_extras/distributions/transforms/partial_order.py +227 -0
  7. pymc_extras/inference/__init__.py +4 -2
  8. pymc_extras/inference/find_map.py +62 -17
  9. pymc_extras/inference/fit.py +6 -4
  10. pymc_extras/inference/laplace.py +14 -8
  11. pymc_extras/inference/pathfinder/lbfgs.py +49 -13
  12. pymc_extras/inference/pathfinder/pathfinder.py +89 -103
  13. pymc_extras/statespace/core/statespace.py +191 -52
  14. pymc_extras/statespace/filters/distributions.py +15 -16
  15. pymc_extras/statespace/filters/kalman_filter.py +1 -18
  16. pymc_extras/statespace/filters/kalman_smoother.py +2 -6
  17. pymc_extras/statespace/models/ETS.py +10 -0
  18. pymc_extras/statespace/models/SARIMAX.py +26 -5
  19. pymc_extras/statespace/models/VARMAX.py +12 -2
  20. pymc_extras/statespace/models/structural.py +18 -5
  21. pymc_extras/statespace/utils/data_tools.py +24 -9
  22. pymc_extras-0.2.6.dist-info/METADATA +318 -0
  23. pymc_extras-0.2.6.dist-info/RECORD +65 -0
  24. {pymc_extras-0.2.4.dist-info → pymc_extras-0.2.6.dist-info}/WHEEL +1 -2
  25. pymc_extras/version.py +0 -11
  26. pymc_extras/version.txt +0 -1
  27. pymc_extras-0.2.4.dist-info/METADATA +0 -110
  28. pymc_extras-0.2.4.dist-info/RECORD +0 -105
  29. pymc_extras-0.2.4.dist-info/top_level.txt +0 -2
  30. tests/__init__.py +0 -13
  31. tests/distributions/__init__.py +0 -19
  32. tests/distributions/test_continuous.py +0 -185
  33. tests/distributions/test_discrete.py +0 -210
  34. tests/distributions/test_discrete_markov_chain.py +0 -258
  35. tests/distributions/test_multivariate.py +0 -304
  36. tests/model/__init__.py +0 -0
  37. tests/model/marginal/__init__.py +0 -0
  38. tests/model/marginal/test_distributions.py +0 -132
  39. tests/model/marginal/test_graph_analysis.py +0 -182
  40. tests/model/marginal/test_marginal_model.py +0 -967
  41. tests/model/test_model_api.py +0 -38
  42. tests/statespace/__init__.py +0 -0
  43. tests/statespace/test_ETS.py +0 -411
  44. tests/statespace/test_SARIMAX.py +0 -405
  45. tests/statespace/test_VARMAX.py +0 -184
  46. tests/statespace/test_coord_assignment.py +0 -116
  47. tests/statespace/test_distributions.py +0 -270
  48. tests/statespace/test_kalman_filter.py +0 -326
  49. tests/statespace/test_representation.py +0 -175
  50. tests/statespace/test_statespace.py +0 -872
  51. tests/statespace/test_statespace_JAX.py +0 -156
  52. tests/statespace/test_structural.py +0 -836
  53. tests/statespace/utilities/__init__.py +0 -0
  54. tests/statespace/utilities/shared_fixtures.py +0 -9
  55. tests/statespace/utilities/statsmodel_local_level.py +0 -42
  56. tests/statespace/utilities/test_helpers.py +0 -310
  57. tests/test_blackjax_smc.py +0 -222
  58. tests/test_find_map.py +0 -103
  59. tests/test_histogram_approximation.py +0 -109
  60. tests/test_laplace.py +0 -265
  61. tests/test_linearmodel.py +0 -208
  62. tests/test_model_builder.py +0 -306
  63. tests/test_pathfinder.py +0 -203
  64. tests/test_pivoted_cholesky.py +0 -24
  65. tests/test_printing.py +0 -98
  66. tests/test_prior_from_trace.py +0 -172
  67. tests/test_splines.py +0 -77
  68. tests/utils.py +0 -0
  69. {pymc_extras-0.2.4.dist-info → pymc_extras-0.2.6.dist-info/licenses}/LICENSE +0 -0
tests/statespace/test_distributions.py (deleted)
@@ -1,270 +0,0 @@
- import numpy as np
- import pymc as pm
- import pytensor
- import pytensor.tensor as pt
- import pytest
-
- from numpy.testing import assert_allclose
- from scipy.stats import multivariate_normal
-
- from pymc_extras.statespace import structural
- from pymc_extras.statespace.filters.distributions import (
-     LinearGaussianStateSpace,
-     SequenceMvNormal,
-     _LinearGaussianStateSpace,
- )
- from pymc_extras.statespace.utils.constants import (
-     ALL_STATE_DIM,
-     OBS_STATE_DIM,
-     TIME_DIM,
- )
- from tests.statespace.utilities.shared_fixtures import (  # pylint: disable=unused-import
-     rng,
- )
- from tests.statespace.utilities.test_helpers import (
-     delete_rvs_from_model,
-     fast_eval,
-     load_nile_test_data,
- )
-
- floatX = pytensor.config.floatX
-
- # TODO: These are pretty loose because of all the stabilizing of covariance matrices that is done inside the kalman
- # filters. When that is improved, this should be tightened.
- ATOL = 1e-5 if floatX.endswith("64") else 1e-4
- RTOL = 1e-5 if floatX.endswith("64") else 1e-4
-
- filter_names = [
-     "standard",
-     "cholesky",
-     "univariate",
- ]
-
-
- @pytest.fixture(scope="session")
- def data():
-     return load_nile_test_data()
-
-
- @pytest.fixture(scope="session")
- def pymc_model(data):
-     with pm.Model() as mod:
-         data = pm.Data("data", data.values)
-         P0_diag = pm.Exponential("P0_diag", 1, shape=(2,))
-         P0 = pm.Deterministic("P0", pt.diag(P0_diag))
-         initial_trend = pm.Normal("initial_trend", shape=(2,))
-         sigma_trend = pm.Exponential("sigma_trend", 1, shape=(2,))
-
-     return mod
-
-
- @pytest.fixture(scope="session")
- def pymc_model_2(data):
-     coords = {
-         ALL_STATE_DIM: ["level", "trend"],
-         OBS_STATE_DIM: ["level"],
-         TIME_DIM: np.arange(101, dtype="int"),
-     }
-
-     with pm.Model(coords=coords) as mod:
-         P0_diag = pm.Exponential("P0_diag", 1, shape=(2,))
-         P0 = pm.Deterministic("P0", pt.diag(P0_diag))
-         initial_trend = pm.Normal("initial_trend", shape=(2,))
-         sigma_trend = pm.Exponential("sigma_trend", 1, shape=(2,))
-         sigma_me = pm.Exponential("sigma_error", 1)
-
-     return mod
-
-
- @pytest.fixture(scope="session")
- def ss_mod_me():
-     ss_mod = structural.LevelTrendComponent(order=2)
-     ss_mod += structural.MeasurementError(name="error")
-     ss_mod = ss_mod.build("data", verbose=False)
-
-     return ss_mod
-
-
- @pytest.fixture(scope="session")
- def ss_mod_no_me():
-     ss_mod = structural.LevelTrendComponent(order=2)
-     ss_mod = ss_mod.build("data", verbose=False)
-
-     return ss_mod
-
-
- @pytest.mark.parametrize("kfilter", filter_names, ids=filter_names)
- def test_loglike_vectors_agree(kfilter, pymc_model):
-     # TODO: This test might be flakey, I've gotten random failures
-     ss_mod = structural.LevelTrendComponent(order=2).build(
-         "data", verbose=False, filter_type=kfilter
-     )
-     with pymc_model:
-         ss_mod._insert_random_variables()
-         matrices = ss_mod.unpack_statespace()
-
-         filter_outputs = ss_mod.kalman_filter.build_graph(pymc_model["data"], *matrices)
-         filter_mus, pred_mus, obs_mu, filter_covs, pred_covs, obs_cov, ll = filter_outputs
-
-     test_ll = fast_eval(ll)
-
-     # TODO: BUG: Why does fast eval end up with a 2d output when filter is "single"?
-     obs_mu_np = obs_mu.eval()
-     obs_cov_np = fast_eval(obs_cov)
-     data_np = fast_eval(pymc_model["data"])
-
-     scipy_lls = []
-     for y, mu, cov in zip(data_np, obs_mu_np, obs_cov_np):
-         scipy_lls.append(multivariate_normal.logpdf(y, mean=mu, cov=cov))
-     assert_allclose(test_ll, np.array(scipy_lls).ravel(), atol=ATOL, rtol=RTOL)
-
-
- def test_sequence_mvn_distribution():
-     # Base Case
-     mu_sequence = pt.tensor("mu_sequence", shape=(100, 3))
-     cov_sequence = pt.tensor("cov_sequence", shape=(100, 3, 3))
-     logp = pt.tensor("logp", shape=(100,))
-
-     dist = SequenceMvNormal.dist(mu_sequence, cov_sequence, logp)
-     assert dist.type.shape == (100, 3)
-
-     # With batch dimension
-     mu_sequence = pt.tensor("mu_sequence", shape=(10, 100, 3))
-     cov_sequence = pt.tensor("cov_sequence", shape=(10, 100, 3, 3))
-     logp = pt.tensor(
-         "logp",
-         shape=(
-             10,
-             100,
-         ),
-     )
-
-     dist = SequenceMvNormal.dist(mu_sequence, cov_sequence, logp)
-     assert dist.type.shape == (10, 100, 3)
-
-
- @pytest.mark.parametrize("output_name", ["states_latent", "states_observed"])
- def test_lgss_distribution_from_steps(output_name, ss_mod_me, pymc_model_2):
-     with pymc_model_2:
-         ss_mod_me._insert_random_variables()
-         matrices = ss_mod_me.unpack_statespace()
-
-         # pylint: disable=unpacking-non-sequence
-         latent_states, obs_states = LinearGaussianStateSpace("states", *matrices, steps=100)
-         # pylint: enable=unpacking-non-sequence
-
-         idata = pm.sample_prior_predictive(draws=10)
-         delete_rvs_from_model(["states_latent", "states_observed", "states_combined"])
-
-     assert idata.prior.coords["states_latent_dim_0"].shape == (101,)
-     assert not np.any(np.isnan(idata.prior[output_name].values))
-
-
- @pytest.mark.parametrize("output_name", ["states_latent", "states_observed"])
- def test_lgss_distribution_with_dims(output_name, ss_mod_me, pymc_model_2):
-     with pymc_model_2:
-         ss_mod_me._insert_random_variables()
-         matrices = ss_mod_me.unpack_statespace()
-
-         # pylint: disable=unpacking-non-sequence
-         latent_states, obs_states = LinearGaussianStateSpace(
-             "states",
-             *matrices,
-             steps=100,
-             dims=[TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM],
-             sequence_names=[],
-             k_endog=ss_mod_me.k_endog,
-         )
-         # pylint: enable=unpacking-non-sequence
-         idata = pm.sample_prior_predictive(draws=10)
-         delete_rvs_from_model(["states_latent", "states_observed", "states_combined"])
-
-     assert idata.prior.coords["time"].shape == (101,)
-     assert all(
-         [dim in idata.prior.states_latent.coords.keys() for dim in [TIME_DIM, ALL_STATE_DIM]]
-     )
-     assert all(
-         [dim in idata.prior.states_observed.coords.keys() for dim in [TIME_DIM, OBS_STATE_DIM]]
-     )
-     assert not np.any(np.isnan(idata.prior[output_name].values))
-
-
- @pytest.mark.parametrize("output_name", ["states_latent", "states_observed"])
- def test_lgss_with_time_varying_inputs(output_name, rng):
-     X = rng.random(size=(10, 3), dtype=floatX)
-     ss_mod = structural.LevelTrendComponent() + structural.RegressionComponent(
-         name="exog", k_exog=3
-     )
-     mod = ss_mod.build("data", verbose=False)
-
-     coords = {
-         ALL_STATE_DIM: ["level", "trend", "beta_1", "beta_2", "beta_3"],
-         OBS_STATE_DIM: ["level"],
-         TIME_DIM: np.arange(10, dtype="int"),
-     }
-
-     with pm.Model(coords=coords):
-         exog_data = pm.Data("data_exog", X)
-         P0_diag = pm.Exponential("P0_diag", 1, shape=(mod.k_states,))
-         P0 = pm.Deterministic("P0", pt.diag(P0_diag))
-         initial_trend = pm.Normal("initial_trend", shape=(2,))
-         sigma_trend = pm.Exponential("sigma_trend", 1, shape=(2,))
-         beta_exog = pm.Normal("beta_exog", shape=(3,))
-
-         mod._insert_random_variables()
-         mod._insert_data_variables()
-         matrices = mod.unpack_statespace()
-
-         # pylint: disable=unpacking-non-sequence
-         latent_states, obs_states = LinearGaussianStateSpace(
-             "states",
-             *matrices,
-             steps=9,
-             sequence_names=["d", "Z"],
-             dims=[TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM],
-         )
-         # pylint: enable=unpacking-non-sequence
-         idata = pm.sample_prior_predictive(draws=10)
-
-     assert idata.prior.coords["time"].shape == (10,)
-     assert all(
-         [dim in idata.prior.states_latent.coords.keys() for dim in [TIME_DIM, ALL_STATE_DIM]]
-     )
-     assert all(
-         [dim in idata.prior.states_observed.coords.keys() for dim in [TIME_DIM, OBS_STATE_DIM]]
-     )
-     assert not np.any(np.isnan(idata.prior[output_name].values))
-
-
- def test_lgss_signature():
-     # Base case
-     x0 = pt.tensor("x0", shape=(None,))
-     P0 = pt.tensor("P0", shape=(None, None))
-     c = pt.tensor("c", shape=(None,))
-     d = pt.tensor("d", shape=(None,))
-     T = pt.tensor("T", shape=(None, None))
-     Z = pt.tensor("Z", shape=(None, None))
-     R = pt.tensor("R", shape=(None, None))
-     H = pt.tensor("H", shape=(None, None))
-     Q = pt.tensor("Q", shape=(None, None))
-
-     lgss = _LinearGaussianStateSpace.dist(x0, P0, c, d, T, Z, R, H, Q, steps=100)
-     assert (
-         lgss.owner.op.extended_signature
-         == "(s),(s,s),(s),(p),(s,s),(p,s),(s,r),(p,p),(r,r),[rng]->[rng],(t,n)"
-     )
-     assert lgss.owner.op.ndim_supp == 2
-     assert lgss.owner.op.ndims_params == [1, 2, 1, 1, 2, 2, 2, 2, 2]
-
-     # Case with time-varying matrices
-     T = pt.tensor("T", shape=(None, None, None))
-     lgss = _LinearGaussianStateSpace.dist(
-         x0, P0, c, d, T, Z, R, H, Q, steps=100, sequence_names=["T"]
-     )
-
-     assert (
-         lgss.owner.op.extended_signature
-         == "(s),(s,s),(s),(p),(t,s,s),(p,s),(s,r),(p,p),(r,r),[rng]->[rng],(t,n)"
-     )
-     assert lgss.owner.op.ndim_supp == 2
-     assert lgss.owner.op.ndims_params == [1, 2, 1, 1, 3, 2, 2, 2, 2]
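The deleted tests above exercise prior-predictive sampling through LinearGaussianStateSpace. For reference, this is a minimal sketch of that pattern, assembled from the fixtures and test bodies in the removed file (the priors and steps=100 value are taken from the code above, not a prescribed recipe):

import pymc as pm
import pytensor.tensor as pt

from pymc_extras.statespace import structural
from pymc_extras.statespace.filters.distributions import LinearGaussianStateSpace

# Level-trend model with measurement error, as in the ss_mod_me fixture
ss_mod = structural.LevelTrendComponent(order=2)
ss_mod += structural.MeasurementError(name="error")
ss_mod = ss_mod.build("data", verbose=False)

with pm.Model() as m:
    # Priors mirroring the pymc_model_2 fixture
    P0_diag = pm.Exponential("P0_diag", 1, shape=(2,))
    P0 = pm.Deterministic("P0", pt.diag(P0_diag))
    initial_trend = pm.Normal("initial_trend", shape=(2,))
    sigma_trend = pm.Exponential("sigma_trend", 1, shape=(2,))
    sigma_me = pm.Exponential("sigma_error", 1)

    # Map the priors onto the statespace matrices, then sample the prior
    ss_mod._insert_random_variables()
    matrices = ss_mod.unpack_statespace()

    # steps=100 gives 101 time points: the initial state plus 100 transitions
    latent_states, obs_states = LinearGaussianStateSpace("states", *matrices, steps=100)
    idata = pm.sample_prior_predictive(draws=10)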
tests/statespace/test_kalman_filter.py (deleted)
@@ -1,326 +0,0 @@
- import numpy as np
- import pytensor
- import pytensor.tensor as pt
- import pytest
-
- from numpy.testing import assert_allclose, assert_array_less
-
- from pymc_extras.statespace.filters import (
-     KalmanSmoother,
-     SquareRootFilter,
-     StandardFilter,
-     UnivariateFilter,
- )
- from pymc_extras.statespace.filters.kalman_filter import BaseFilter
- from tests.statespace.utilities.shared_fixtures import (  # pylint: disable=unused-import
-     rng,
- )
- from tests.statespace.utilities.test_helpers import (
-     get_expected_shape,
-     get_sm_state_from_output_name,
-     initialize_filter,
-     make_test_inputs,
-     nile_test_test_helper,
- )
-
- floatX = pytensor.config.floatX
-
- # TODO: These are pretty loose because of all the stabilizing of covariance matrices that is done inside the kalman
- # filters. When that is improved, this should be tightened.
- ATOL = 1e-6 if floatX.endswith("64") else 1e-3
- RTOL = 1e-6 if floatX.endswith("64") else 1e-3
-
- standard_inout = initialize_filter(StandardFilter())
- cholesky_inout = initialize_filter(SquareRootFilter())
- univariate_inout = initialize_filter(UnivariateFilter())
-
- f_standard = pytensor.function(*standard_inout, on_unused_input="ignore")
- f_cholesky = pytensor.function(*cholesky_inout, on_unused_input="ignore")
- f_univariate = pytensor.function(*univariate_inout, on_unused_input="ignore")
-
- filter_funcs = [f_standard, f_cholesky, f_univariate]
-
- filter_names = [
-     "StandardFilter",
-     "CholeskyFilter",
-     "UnivariateFilter",
- ]
-
- output_names = [
-     "filtered_states",
-     "predicted_states",
-     "smoothed_states",
-     "filtered_covs",
-     "predicted_covs",
-     "smoothed_covs",
-     "log_likelihood",
-     "ll_obs",
- ]
-
-
- def test_base_class_update_raises():
-     filter = BaseFilter()
-     inputs = [None] * 7
-     with pytest.raises(NotImplementedError):
-         filter.update(*inputs)
-
-
- @pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names)
- def test_output_shapes_one_state_one_observed(filter_func, rng):
-     p, m, r, n = 1, 1, 1, 10
-     inputs = make_test_inputs(p, m, r, n, rng)
-     outputs = filter_func(*inputs)
-
-     for output_idx, name in enumerate(output_names):
-         expected_output = get_expected_shape(name, p, m, r, n)
-         assert (
-             outputs[output_idx].shape == expected_output
-         ), f"Shape of {name} does not match expected"
-
-
- @pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names)
- def test_output_shapes_when_all_states_are_stochastic(filter_func, rng):
-     p, m, r, n = 1, 2, 2, 10
-     inputs = make_test_inputs(p, m, r, n, rng)
-
-     outputs = filter_func(*inputs)
-     for output_idx, name in enumerate(output_names):
-         expected_output = get_expected_shape(name, p, m, r, n)
-         assert (
-             outputs[output_idx].shape == expected_output
-         ), f"Shape of {name} does not match expected"
-
-
- @pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names)
- def test_output_shapes_when_some_states_are_deterministic(filter_func, rng):
-     p, m, r, n = 1, 5, 2, 10
-     inputs = make_test_inputs(p, m, r, n, rng)
-
-     outputs = filter_func(*inputs)
-     for output_idx, name in enumerate(output_names):
-         expected_output = get_expected_shape(name, p, m, r, n)
-         assert (
-             outputs[output_idx].shape == expected_output
-         ), f"Shape of {name} does not match expected"
-
-
- @pytest.fixture
- def f_standard_nd():
-     ksmoother = KalmanSmoother()
-     data = pt.tensor(name="data", dtype=floatX, shape=(None, None))
-     a0 = pt.vector(name="a0", dtype=floatX)
-     P0 = pt.matrix(name="P0", dtype=floatX)
-     c = pt.vector(name="c", dtype=floatX)
-     d = pt.vector(name="d", dtype=floatX)
-     Q = pt.tensor(name="Q", dtype=floatX, shape=(None, None, None))
-     H = pt.tensor(name="H", dtype=floatX, shape=(None, None, None))
-     T = pt.tensor(name="T", dtype=floatX, shape=(None, None, None))
-     R = pt.tensor(name="R", dtype=floatX, shape=(None, None, None))
-     Z = pt.tensor(name="Z", dtype=floatX, shape=(None, None, None))
-
-     inputs = [data, a0, P0, c, d, T, Z, R, H, Q]
-
-     (
-         filtered_states,
-         predicted_states,
-         observed_states,
-         filtered_covs,
-         predicted_covs,
-         observed_covs,
-         ll_obs,
-     ) = StandardFilter().build_graph(*inputs)
-
-     smoothed_states, smoothed_covs = ksmoother.build_graph(T, R, Q, filtered_states, filtered_covs)
-
-     outputs = [
-         filtered_states,
-         predicted_states,
-         smoothed_states,
-         filtered_covs,
-         predicted_covs,
-         smoothed_covs,
-         ll_obs.sum(),
-         ll_obs,
-     ]
-
-     f_standard = pytensor.function(inputs, outputs)
-
-     return f_standard
-
-
- def test_output_shapes_with_time_varying_matrices(f_standard_nd, rng):
-     p, m, r, n = 1, 5, 2, 10
-     data, a0, P0, c, d, T, Z, R, H, Q = make_test_inputs(p, m, r, n, rng)
-     T = np.concatenate([np.expand_dims(T, 0)] * n, axis=0)
-     Z = np.concatenate([np.expand_dims(Z, 0)] * n, axis=0)
-     R = np.concatenate([np.expand_dims(R, 0)] * n, axis=0)
-     H = np.concatenate([np.expand_dims(H, 0)] * n, axis=0)
-     Q = np.concatenate([np.expand_dims(Q, 0)] * n, axis=0)
-
-     outputs = f_standard_nd(data, a0, P0, c, d, T, Z, R, H, Q)
-
-     for output_idx, name in enumerate(output_names):
-         expected_output = get_expected_shape(name, p, m, r, n)
-         assert (
-             outputs[output_idx].shape == expected_output
-         ), f"Shape of {name} does not match expected"
-
-
- @pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names)
- def test_output_with_deterministic_observation_equation(filter_func, rng):
-     p, m, r, n = 1, 5, 1, 10
-     inputs = make_test_inputs(p, m, r, n, rng)
-
-     outputs = filter_func(*inputs)
-
-     for output_idx, name in enumerate(output_names):
-         expected_output = get_expected_shape(name, p, m, r, n)
-         assert (
-             outputs[output_idx].shape == expected_output
-         ), f"Shape of {name} does not match expected"
-
-
- @pytest.mark.parametrize(
-     ("filter_func", "filter_name"), zip(filter_funcs, filter_names), ids=filter_names
- )
- def test_output_with_multiple_observed(filter_func, filter_name, rng):
-     p, m, r, n = 5, 5, 1, 10
-     inputs = make_test_inputs(p, m, r, n, rng)
-
-     outputs = filter_func(*inputs)
-     for output_idx, name in enumerate(output_names):
-         expected_output = get_expected_shape(name, p, m, r, n)
-         assert (
-             outputs[output_idx].shape == expected_output
-         ), f"Shape of {name} does not match expected"
-
-
- @pytest.mark.parametrize(
-     ("filter_func", "filter_name"), zip(filter_funcs, filter_names), ids=filter_names
- )
- @pytest.mark.parametrize("p", [1, 5], ids=["univariate (p=1)", "multivariate (p=5)"])
- def test_missing_data(filter_func, filter_name, p, rng):
-     m, r, n = 5, 1, 10
-     inputs = make_test_inputs(p, m, r, n, rng, missing_data=1)
-
-     outputs = filter_func(*inputs)
-     for output_idx, name in enumerate(output_names):
-         expected_output = get_expected_shape(name, p, m, r, n)
-         assert (
-             outputs[output_idx].shape == expected_output
-         ), f"Shape of {name} does not match expected"
-
-
- @pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names)
- @pytest.mark.parametrize("output_idx", [(0, 2), (3, 5)], ids=["smoothed_states", "smoothed_covs"])
- def test_last_smoother_is_last_filtered(filter_func, output_idx, rng):
-     p, m, r, n = 1, 5, 1, 10
-     inputs = make_test_inputs(p, m, r, n, rng)
-     outputs = filter_func(*inputs)
-
-     filtered = outputs[output_idx[0]]
-     smoothed = outputs[output_idx[1]]
-
-     assert_allclose(filtered[-1], smoothed[-1])
-
-
- @pytest.mark.parametrize(
-     "filter_func, filter_name", zip(filter_funcs, filter_names), ids=filter_names
- )
- @pytest.mark.parametrize("n_missing", [0, 5], ids=["n_missing=0", "n_missing=5"])
- @pytest.mark.skipif(floatX == "float32", reason="Tests are too sensitive for float32")
- def test_filters_match_statsmodel_output(filter_func, filter_name, n_missing, rng):
-     fit_sm_mod, [data, a0, P0, c, d, T, Z, R, H, Q] = nile_test_test_helper(rng, n_missing)
-     if filter_name == "CholeskyFilter":
-         P0 = np.linalg.cholesky(P0)
-     inputs = [data, a0, P0, c, d, T, Z, R, H, Q]
-     outputs = filter_func(*inputs)
-
-     for output_idx, name in enumerate(output_names):
-         ref_val = get_sm_state_from_output_name(fit_sm_mod, name)
-         val_to_test = outputs[output_idx].squeeze()
-
-         if name == "smoothed_covs":
-             # TODO: The smoothed covariance matrices have large errors (1e-2) ONLY in the first few states -- no idea why.
-             assert_allclose(
-                 val_to_test[5:],
-                 ref_val[5:],
-                 atol=ATOL,
-                 rtol=RTOL,
-                 err_msg=f"{name} does not match statsmodels",
-             )
-         elif name.startswith("predicted"):
-             # statsmodels doesn't throw away the T+1 forecast in the predicted states like we do
-             assert_allclose(
-                 val_to_test,
-                 ref_val[:-1],
-                 atol=ATOL,
-                 rtol=RTOL,
-                 err_msg=f"{name} does not match statsmodels",
-             )
-         else:
-             # Need atol = 1e-7 for smoother tests to pass
-             assert_allclose(
-                 val_to_test,
-                 ref_val,
-                 atol=ATOL,
-                 rtol=RTOL,
-                 err_msg=f"{name} does not match statsmodels",
-             )
-
-
- @pytest.mark.parametrize(
-     "filter_func, filter_name", zip(filter_funcs[:-1], filter_names[:-1]), ids=filter_names[:-1]
- )
- @pytest.mark.parametrize("n_missing", [0, 5], ids=["n_missing=0", "n_missing=5"])
- @pytest.mark.parametrize("obs_noise", [True, False])
- def test_all_covariance_matrices_are_PSD(filter_func, filter_name, n_missing, obs_noise, rng):
-     if (floatX == "float32") & (filter_name == "UnivariateFilter"):
-         # TODO: These tests all pass locally for me with float32 but they fail on the CI, so i'm just disabling them.
-         pytest.skip("Univariate filter not stable at half precision without measurement error")
-
-     fit_sm_mod, [data, a0, P0, c, d, T, Z, R, H, Q] = nile_test_test_helper(rng, n_missing)
-     if filter_name == "CholeskyFilter":
-         P0 = np.linalg.cholesky(P0)
-
-     H *= int(obs_noise)
-     inputs = [data, a0, P0, c, d, T, Z, R, H, Q]
-     outputs = filter_func(*inputs)
-
-     for output_idx, name in zip([3, 4, 5], output_names[3:-2]):
-         cov_stack = outputs[output_idx]
-         w, v = np.linalg.eig(cov_stack)
-
-         assert_array_less(0, w, err_msg=f"Smallest eigenvalue of {name}: {min(w.ravel())}")
-         assert_allclose(
-             cov_stack,
-             np.swapaxes(cov_stack, -2, -1),
-             rtol=RTOL,
-             atol=ATOL,
-             err_msg=f"{name} is not symmetrical",
-         )
-
-
- @pytest.mark.parametrize(
-     "filter",
-     [StandardFilter, SquareRootFilter],
-     ids=["standard", "cholesky"],
- )
- def test_kalman_filter_jax(filter):
-     pytest.importorskip("jax")
-     from pymc.sampling.jax import get_jaxified_graph
-
-     # TODO: Add UnivariateFilter to test; need to figure out the broadcasting issue when 2nd data dim is defined
-
-     p, m, r, n = 1, 5, 1, 10
-     inputs, outputs = initialize_filter(filter(), mode="JAX", p=p, m=m, r=r, n=n)
-     inputs_np = make_test_inputs(p, m, r, n, rng)
-
-     f_jax = get_jaxified_graph(inputs, outputs)
-     f_pt = pytensor.function(inputs, outputs, mode="FAST_COMPILE")
-
-     jax_outputs = f_jax(*inputs_np)
-     pt_outputs = f_pt(*inputs_np)
-
-     for name, jax_res, pt_res in zip(output_names, jax_outputs, pt_outputs):
-         assert_allclose(jax_res, pt_res, atol=ATOL, rtol=RTOL, err_msg=f"{name} failed!")
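For reference, the compilation pattern used throughout the removed file, condensed from the f_standard_nd fixture above (a sketch drawn from the deleted test code, not package documentation; the time axis leads on the time-varying matrices):

import pytensor
import pytensor.tensor as pt

from pymc_extras.statespace.filters import KalmanSmoother, StandardFilter

floatX = pytensor.config.floatX

# Symbolic inputs: observed data plus the state-space matrices, with a
# leading time dimension on the time-varying matrices
data = pt.tensor(name="data", dtype=floatX, shape=(None, None))
a0 = pt.vector(name="a0", dtype=floatX)
P0 = pt.matrix(name="P0", dtype=floatX)
c = pt.vector(name="c", dtype=floatX)
d = pt.vector(name="d", dtype=floatX)
T, Z, R, H, Q = (
    pt.tensor(name=name, dtype=floatX, shape=(None, None, None)) for name in "TZRHQ"
)

inputs = [data, a0, P0, c, d, T, Z, R, H, Q]

# The filter graph yields filtered/predicted/observed means and covariances,
# plus the per-observation log-likelihood
(
    filtered_states,
    predicted_states,
    observed_states,
    filtered_covs,
    predicted_covs,
    observed_covs,
    ll_obs,
) = StandardFilter().build_graph(*inputs)

# The smoother runs backward over the filtered means and covariances
smoothed_states, smoothed_covs = KalmanSmoother().build_graph(
    T, R, Q, filtered_states, filtered_covs
)

f = pytensor.function(inputs, [smoothed_states, smoothed_covs, ll_obs.sum()])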