pymc-extras 0.2.4__py3-none-any.whl → 0.2.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pymc_extras/__init__.py +6 -4
- pymc_extras/distributions/__init__.py +2 -0
- pymc_extras/distributions/continuous.py +3 -2
- pymc_extras/distributions/discrete.py +3 -1
- pymc_extras/distributions/transforms/__init__.py +3 -0
- pymc_extras/distributions/transforms/partial_order.py +227 -0
- pymc_extras/inference/__init__.py +4 -2
- pymc_extras/inference/find_map.py +62 -17
- pymc_extras/inference/fit.py +6 -4
- pymc_extras/inference/laplace.py +14 -8
- pymc_extras/inference/pathfinder/lbfgs.py +49 -13
- pymc_extras/inference/pathfinder/pathfinder.py +89 -103
- pymc_extras/statespace/core/statespace.py +191 -52
- pymc_extras/statespace/filters/distributions.py +15 -16
- pymc_extras/statespace/filters/kalman_filter.py +1 -18
- pymc_extras/statespace/filters/kalman_smoother.py +2 -6
- pymc_extras/statespace/models/ETS.py +10 -0
- pymc_extras/statespace/models/SARIMAX.py +26 -5
- pymc_extras/statespace/models/VARMAX.py +12 -2
- pymc_extras/statespace/models/structural.py +18 -5
- pymc_extras/statespace/utils/data_tools.py +24 -9
- pymc_extras-0.2.6.dist-info/METADATA +318 -0
- pymc_extras-0.2.6.dist-info/RECORD +65 -0
- {pymc_extras-0.2.4.dist-info → pymc_extras-0.2.6.dist-info}/WHEEL +1 -2
- pymc_extras/version.py +0 -11
- pymc_extras/version.txt +0 -1
- pymc_extras-0.2.4.dist-info/METADATA +0 -110
- pymc_extras-0.2.4.dist-info/RECORD +0 -105
- pymc_extras-0.2.4.dist-info/top_level.txt +0 -2
- tests/__init__.py +0 -13
- tests/distributions/__init__.py +0 -19
- tests/distributions/test_continuous.py +0 -185
- tests/distributions/test_discrete.py +0 -210
- tests/distributions/test_discrete_markov_chain.py +0 -258
- tests/distributions/test_multivariate.py +0 -304
- tests/model/__init__.py +0 -0
- tests/model/marginal/__init__.py +0 -0
- tests/model/marginal/test_distributions.py +0 -132
- tests/model/marginal/test_graph_analysis.py +0 -182
- tests/model/marginal/test_marginal_model.py +0 -967
- tests/model/test_model_api.py +0 -38
- tests/statespace/__init__.py +0 -0
- tests/statespace/test_ETS.py +0 -411
- tests/statespace/test_SARIMAX.py +0 -405
- tests/statespace/test_VARMAX.py +0 -184
- tests/statespace/test_coord_assignment.py +0 -116
- tests/statespace/test_distributions.py +0 -270
- tests/statespace/test_kalman_filter.py +0 -326
- tests/statespace/test_representation.py +0 -175
- tests/statespace/test_statespace.py +0 -872
- tests/statespace/test_statespace_JAX.py +0 -156
- tests/statespace/test_structural.py +0 -836
- tests/statespace/utilities/__init__.py +0 -0
- tests/statespace/utilities/shared_fixtures.py +0 -9
- tests/statespace/utilities/statsmodel_local_level.py +0 -42
- tests/statespace/utilities/test_helpers.py +0 -310
- tests/test_blackjax_smc.py +0 -222
- tests/test_find_map.py +0 -103
- tests/test_histogram_approximation.py +0 -109
- tests/test_laplace.py +0 -265
- tests/test_linearmodel.py +0 -208
- tests/test_model_builder.py +0 -306
- tests/test_pathfinder.py +0 -203
- tests/test_pivoted_cholesky.py +0 -24
- tests/test_printing.py +0 -98
- tests/test_prior_from_trace.py +0 -172
- tests/test_splines.py +0 -77
- tests/utils.py +0 -0
- {pymc_extras-0.2.4.dist-info → pymc_extras-0.2.6.dist-info/licenses}/LICENSE +0 -0
|
@@ -1,270 +0,0 @@
|
|
|
1
|
-
import numpy as np
|
|
2
|
-
import pymc as pm
|
|
3
|
-
import pytensor
|
|
4
|
-
import pytensor.tensor as pt
|
|
5
|
-
import pytest
|
|
6
|
-
|
|
7
|
-
from numpy.testing import assert_allclose
|
|
8
|
-
from scipy.stats import multivariate_normal
|
|
9
|
-
|
|
10
|
-
from pymc_extras.statespace import structural
|
|
11
|
-
from pymc_extras.statespace.filters.distributions import (
|
|
12
|
-
LinearGaussianStateSpace,
|
|
13
|
-
SequenceMvNormal,
|
|
14
|
-
_LinearGaussianStateSpace,
|
|
15
|
-
)
|
|
16
|
-
from pymc_extras.statespace.utils.constants import (
|
|
17
|
-
ALL_STATE_DIM,
|
|
18
|
-
OBS_STATE_DIM,
|
|
19
|
-
TIME_DIM,
|
|
20
|
-
)
|
|
21
|
-
from tests.statespace.utilities.shared_fixtures import ( # pylint: disable=unused-import
|
|
22
|
-
rng,
|
|
23
|
-
)
|
|
24
|
-
from tests.statespace.utilities.test_helpers import (
|
|
25
|
-
delete_rvs_from_model,
|
|
26
|
-
fast_eval,
|
|
27
|
-
load_nile_test_data,
|
|
28
|
-
)
|
|
29
|
-
|
|
30
|
-
floatX = pytensor.config.floatX

# TODO: These tolerances are loose because the Kalman filters do a lot of covariance-matrix
# stabilization internally. Tighten them once that stabilization is improved.
_USING_DOUBLE_PRECISION = floatX.endswith("64")
ATOL = 1e-5 if _USING_DOUBLE_PRECISION else 1e-4
RTOL = 1e-5 if _USING_DOUBLE_PRECISION else 1e-4

# Values passed as ``filter_type`` to ``build`` in ``test_loglike_vectors_agree``.
filter_names = ["standard", "cholesky", "univariate"]
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
@pytest.fixture(scope="session")
def data():
    """Load the test dataset once per session via ``load_nile_test_data``."""
    nile_data = load_nile_test_data()
    return nile_data
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
@pytest.fixture(scope="session")
def pymc_model(data):
    """Session-wide PyMC model holding the observed data and the statespace priors.

    Registers a ``pm.Data`` container named ``"data"`` (built from ``data.values``) plus the
    ``P0_diag``, ``P0``, ``initial_trend`` and ``sigma_trend`` variables. Only registration in
    the model matters, so the Python handles are not kept.
    """
    model = pm.Model()
    with model:
        pm.Data("data", data.values)
        diag_entries = pm.Exponential("P0_diag", 1, shape=(2,))
        pm.Deterministic("P0", pt.diag(diag_entries))
        pm.Normal("initial_trend", shape=(2,))
        pm.Exponential("sigma_trend", 1, shape=(2,))

    return model
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
@pytest.fixture(scope="session")
def pymc_model_2(data):
    """Session-wide PyMC model with statespace coords and a measurement-error prior.

    The model has 101 time coordinates, two states (``level``, ``trend``) and one observed
    state (``level``). It registers ``P0_diag``, ``P0``, ``initial_trend``, ``sigma_trend`` and
    ``sigma_error``. The ``data`` fixture is requested but not used in the body.
    """
    coords = {
        ALL_STATE_DIM: ["level", "trend"],
        OBS_STATE_DIM: ["level"],
        TIME_DIM: np.arange(101, dtype="int"),
    }

    with pm.Model(coords=coords) as model:
        diag_entries = pm.Exponential("P0_diag", 1, shape=(2,))
        pm.Deterministic("P0", pt.diag(diag_entries))
        pm.Normal("initial_trend", shape=(2,))
        pm.Exponential("sigma_trend", 1, shape=(2,))
        pm.Exponential("sigma_error", 1)

    return model
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
@pytest.fixture(scope="session")
def ss_mod_me():
    """Order-2 ``LevelTrendComponent`` plus a ``MeasurementError`` component, built on ``"data"``."""
    components = structural.LevelTrendComponent(order=2)
    components += structural.MeasurementError(name="error")
    return components.build("data", verbose=False)
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
@pytest.fixture(scope="session")
def ss_mod_no_me():
    """Order-2 ``LevelTrendComponent`` built on ``"data"`` without measurement error."""
    return structural.LevelTrendComponent(order=2).build("data", verbose=False)
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
@pytest.mark.parametrize("kfilter", filter_names, ids=filter_names)
def test_loglike_vectors_agree(kfilter, pymc_model):
    """Per-step log-likelihoods from each filter type match scipy's multivariate normal logpdf."""
    # TODO: This test might be flakey, I've gotten random failures
    ss_mod = structural.LevelTrendComponent(order=2).build(
        "data", verbose=False, filter_type=kfilter
    )
    with pymc_model:
        ss_mod._insert_random_variables()
        matrices = ss_mod.unpack_statespace()

    # build_graph returns 7 outputs; only the observation mean/cov and log-likelihood are needed.
    filter_outputs = ss_mod.kalman_filter.build_graph(pymc_model["data"], *matrices)
    _, _, obs_mu, _, _, obs_cov, ll = filter_outputs

    test_ll = fast_eval(ll)

    # TODO: BUG: Why does fast eval end up with a 2d output when filter is "single"?
    obs_mu_np = obs_mu.eval()
    obs_cov_np = fast_eval(obs_cov)
    data_np = fast_eval(pymc_model["data"])

    reference_lls = np.array(
        [
            multivariate_normal.logpdf(y, mean=mu, cov=cov)
            for y, mu, cov in zip(data_np, obs_mu_np, obs_cov_np)
        ]
    )
    assert_allclose(test_ll, reference_lls.ravel(), atol=ATOL, rtol=RTOL)
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
def test_sequence_mvn_distribution():
    """SequenceMvNormal's output shape equals the mean-sequence shape, with or without a batch dim."""
    cases = [
        # Base Case: (mu shape, cov shape, logp shape)
        ((100, 3), (100, 3, 3), (100,)),
        # With batch dimension
        ((10, 100, 3), (10, 100, 3, 3), (10, 100)),
    ]

    for mu_shape, cov_shape, logp_shape in cases:
        mu_sequence = pt.tensor("mu_sequence", shape=mu_shape)
        cov_sequence = pt.tensor("cov_sequence", shape=cov_shape)
        logp = pt.tensor("logp", shape=logp_shape)

        dist = SequenceMvNormal.dist(mu_sequence, cov_sequence, logp)
        assert dist.type.shape == mu_shape
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
@pytest.mark.parametrize("output_name", ["states_latent", "states_observed"])
def test_lgss_distribution_from_steps(output_name, ss_mod_me, pymc_model_2):
    """Prior draws of an LGSS built from ``steps=100`` have 101 time points and contain no NaNs."""
    with pymc_model_2:
        ss_mod_me._insert_random_variables()
        statespace_matrices = ss_mod_me.unpack_statespace()

        # Unpacking into two names asserts the call returns exactly two variables.
        # pylint: disable=unpacking-non-sequence
        _latent, _observed = LinearGaussianStateSpace("states", *statespace_matrices, steps=100)
        # pylint: enable=unpacking-non-sequence

        prior_idata = pm.sample_prior_predictive(draws=10)
        delete_rvs_from_model(["states_latent", "states_observed", "states_combined"])

    prior = prior_idata.prior
    assert prior.coords["states_latent_dim_0"].shape == (101,)
    assert not np.any(np.isnan(prior[output_name].values))
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
@pytest.mark.parametrize("output_name", ["states_latent", "states_observed"])
def test_lgss_distribution_with_dims(output_name, ss_mod_me, pymc_model_2):
    """Dims passed to LinearGaussianStateSpace appear as coords on the sampled prior."""
    with pymc_model_2:
        ss_mod_me._insert_random_variables()
        statespace_matrices = ss_mod_me.unpack_statespace()

        # pylint: disable=unpacking-non-sequence
        _latent, _observed = LinearGaussianStateSpace(
            "states",
            *statespace_matrices,
            steps=100,
            dims=[TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM],
            sequence_names=[],
            k_endog=ss_mod_me.k_endog,
        )
        # pylint: enable=unpacking-non-sequence
        prior_idata = pm.sample_prior_predictive(draws=10)
        delete_rvs_from_model(["states_latent", "states_observed", "states_combined"])

    prior = prior_idata.prior
    assert prior.coords["time"].shape == (101,)
    assert all(dim in prior.states_latent.coords.keys() for dim in [TIME_DIM, ALL_STATE_DIM])
    assert all(dim in prior.states_observed.coords.keys() for dim in [TIME_DIM, OBS_STATE_DIM])
    assert not np.any(np.isnan(prior[output_name].values))
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
@pytest.mark.parametrize("output_name", ["states_latent", "states_observed"])
def test_lgss_with_time_varying_inputs(output_name, rng):
    """LGSS sampling works when ``d`` and ``Z`` are passed as sequences (level/trend + regression)."""
    X = rng.random(size=(10, 3), dtype=floatX)
    components = structural.LevelTrendComponent() + structural.RegressionComponent(
        name="exog", k_exog=3
    )
    mod = components.build("data", verbose=False)

    coords = {
        ALL_STATE_DIM: ["level", "trend", "beta_1", "beta_2", "beta_3"],
        OBS_STATE_DIM: ["level"],
        TIME_DIM: np.arange(10, dtype="int"),
    }

    with pm.Model(coords=coords):
        # Variables are registered by name; the Python handles are not needed.
        pm.Data("data_exog", X)
        diag_entries = pm.Exponential("P0_diag", 1, shape=(mod.k_states,))
        pm.Deterministic("P0", pt.diag(diag_entries))
        pm.Normal("initial_trend", shape=(2,))
        pm.Exponential("sigma_trend", 1, shape=(2,))
        pm.Normal("beta_exog", shape=(3,))

        mod._insert_random_variables()
        mod._insert_data_variables()
        statespace_matrices = mod.unpack_statespace()

        # pylint: disable=unpacking-non-sequence
        _latent, _observed = LinearGaussianStateSpace(
            "states",
            *statespace_matrices,
            steps=9,
            sequence_names=["d", "Z"],
            dims=[TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM],
        )
        # pylint: enable=unpacking-non-sequence
        prior_idata = pm.sample_prior_predictive(draws=10)

    prior = prior_idata.prior
    assert prior.coords["time"].shape == (10,)
    assert all(dim in prior.states_latent.coords.keys() for dim in [TIME_DIM, ALL_STATE_DIM])
    assert all(dim in prior.states_observed.coords.keys() for dim in [TIME_DIM, OBS_STATE_DIM])
    assert not np.any(np.isnan(prior[output_name].values))
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
def test_lgss_signature():
    """The LGSS op signature gains a time dim for inputs listed in ``sequence_names``."""
    # Base case: every input is static
    x0 = pt.tensor("x0", shape=(None,))
    P0 = pt.tensor("P0", shape=(None, None))
    c = pt.tensor("c", shape=(None,))
    d = pt.tensor("d", shape=(None,))
    T, Z, R, H, Q = (pt.tensor(name, shape=(None, None)) for name in ("T", "Z", "R", "H", "Q"))

    lgss_op = _LinearGaussianStateSpace.dist(x0, P0, c, d, T, Z, R, H, Q, steps=100).owner.op
    assert (
        lgss_op.extended_signature
        == "(s),(s,s),(s),(p),(s,s),(p,s),(s,r),(p,p),(r,r),[rng]->[rng],(t,n)"
    )
    assert lgss_op.ndim_supp == 2
    assert lgss_op.ndims_params == [1, 2, 1, 1, 2, 2, 2, 2, 2]

    # Case with time-varying matrices: T gets an extra leading dimension
    T = pt.tensor("T", shape=(None, None, None))
    lgss_op = _LinearGaussianStateSpace.dist(
        x0, P0, c, d, T, Z, R, H, Q, steps=100, sequence_names=["T"]
    ).owner.op
    assert (
        lgss_op.extended_signature
        == "(s),(s,s),(s),(p),(t,s,s),(p,s),(s,r),(p,p),(r,r),[rng]->[rng],(t,n)"
    )
    assert lgss_op.ndim_supp == 2
    assert lgss_op.ndims_params == [1, 2, 1, 1, 3, 2, 2, 2, 2]
|
|
@@ -1,326 +0,0 @@
|
|
|
1
|
-
import numpy as np
|
|
2
|
-
import pytensor
|
|
3
|
-
import pytensor.tensor as pt
|
|
4
|
-
import pytest
|
|
5
|
-
|
|
6
|
-
from numpy.testing import assert_allclose, assert_array_less
|
|
7
|
-
|
|
8
|
-
from pymc_extras.statespace.filters import (
|
|
9
|
-
KalmanSmoother,
|
|
10
|
-
SquareRootFilter,
|
|
11
|
-
StandardFilter,
|
|
12
|
-
UnivariateFilter,
|
|
13
|
-
)
|
|
14
|
-
from pymc_extras.statespace.filters.kalman_filter import BaseFilter
|
|
15
|
-
from tests.statespace.utilities.shared_fixtures import ( # pylint: disable=unused-import
|
|
16
|
-
rng,
|
|
17
|
-
)
|
|
18
|
-
from tests.statespace.utilities.test_helpers import (
|
|
19
|
-
get_expected_shape,
|
|
20
|
-
get_sm_state_from_output_name,
|
|
21
|
-
initialize_filter,
|
|
22
|
-
make_test_inputs,
|
|
23
|
-
nile_test_test_helper,
|
|
24
|
-
)
|
|
25
|
-
|
|
26
|
-
floatX = pytensor.config.floatX

# TODO: These tolerances are loose because the Kalman filters do a lot of covariance-matrix
# stabilization internally. Tighten them once that stabilization is improved.
_USING_DOUBLE_PRECISION = floatX.endswith("64")
ATOL = 1e-6 if _USING_DOUBLE_PRECISION else 1e-3
RTOL = 1e-6 if _USING_DOUBLE_PRECISION else 1e-3

# (inputs, outputs) graph pairs for each filter, built by ``initialize_filter``.
standard_inout = initialize_filter(StandardFilter())
cholesky_inout = initialize_filter(SquareRootFilter())
univariate_inout = initialize_filter(UnivariateFilter())

# Compiled filter functions; inputs a graph does not use are ignored rather than rejected.
f_standard, f_cholesky, f_univariate = (
    pytensor.function(*inout, on_unused_input="ignore")
    for inout in (standard_inout, cholesky_inout, univariate_inout)
)

# Compiled functions and their display names, in matching order (used as parametrize ids).
filter_funcs = [f_standard, f_cholesky, f_univariate]
filter_names = ["StandardFilter", "CholeskyFilter", "UnivariateFilter"]

# Names of the compiled outputs, indexed positionally by the tests below.
output_names = [
    "filtered_states",
    "predicted_states",
    "smoothed_states",
    "filtered_covs",
    "predicted_covs",
    "smoothed_covs",
    "log_likelihood",
    "ll_obs",
]
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
def test_base_class_update_raises():
    """``BaseFilter.update`` must raise ``NotImplementedError``; subclasses provide it."""
    # Named ``base_filter`` rather than ``filter`` to avoid shadowing the builtin.
    base_filter = BaseFilter()
    # Placeholder arguments only; the call is expected to raise regardless of their values.
    update_args = [None] * 7
    with pytest.raises(NotImplementedError):
        base_filter.update(*update_args)
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
@pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names)
def test_output_shapes_one_state_one_observed(filter_func, rng):
    """Every filter output has the expected shape for one state and one observed series."""
    # p observed series, m states, r shocks, n time steps
    p, m, r, n = 1, 1, 1, 10
    results = filter_func(*make_test_inputs(p, m, r, n, rng))

    for idx, name in enumerate(output_names):
        expected_shape = get_expected_shape(name, p, m, r, n)
        actual_shape = results[idx].shape
        assert actual_shape == expected_shape, f"Shape of {name} does not match expected"
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
@pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names)
def test_output_shapes_when_all_states_are_stochastic(filter_func, rng):
    """Every filter output has the expected shape when each state has its own shock (m == r)."""
    # p observed series, m states, r shocks, n time steps
    p, m, r, n = 1, 2, 2, 10
    results = filter_func(*make_test_inputs(p, m, r, n, rng))

    for idx, name in enumerate(output_names):
        expected_shape = get_expected_shape(name, p, m, r, n)
        actual_shape = results[idx].shape
        assert actual_shape == expected_shape, f"Shape of {name} does not match expected"
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
@pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names)
def test_output_shapes_when_some_states_are_deterministic(filter_func, rng):
    """Every filter output has the expected shape when there are fewer shocks than states."""
    # p observed series, m states, r shocks, n time steps
    p, m, r, n = 1, 5, 2, 10
    results = filter_func(*make_test_inputs(p, m, r, n, rng))

    for idx, name in enumerate(output_names):
        expected_shape = get_expected_shape(name, p, m, r, n)
        actual_shape = results[idx].shape
        assert actual_shape == expected_shape, f"Shape of {name} does not match expected"
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
@pytest.fixture
def f_standard_nd():
    """Compile a StandardFilter + KalmanSmoother graph whose T, Z, R, H and Q are 3-d tensors.

    The returned function takes ``(data, a0, P0, c, d, T, Z, R, H, Q)`` and returns outputs in
    the same order as ``output_names``.
    """
    smoother = KalmanSmoother()

    data = pt.tensor(name="data", dtype=floatX, shape=(None, None))
    a0 = pt.vector(name="a0", dtype=floatX)
    P0 = pt.matrix(name="P0", dtype=floatX)
    c = pt.vector(name="c", dtype=floatX)
    d = pt.vector(name="d", dtype=floatX)
    Q, H, T, R, Z = (
        pt.tensor(name=name, dtype=floatX, shape=(None, None, None))
        for name in ("Q", "H", "T", "R", "Z")
    )
    graph_inputs = [data, a0, P0, c, d, T, Z, R, H, Q]

    # The observed states/covs (3rd and 6th outputs) are not needed here.
    filtered_states, predicted_states, _, filtered_covs, predicted_covs, _, ll_obs = (
        StandardFilter().build_graph(*graph_inputs)
    )
    smoothed_states, smoothed_covs = smoother.build_graph(T, R, Q, filtered_states, filtered_covs)

    graph_outputs = [
        filtered_states,
        predicted_states,
        smoothed_states,
        filtered_covs,
        predicted_covs,
        smoothed_covs,
        ll_obs.sum(),
        ll_obs,
    ]
    return pytensor.function(graph_inputs, graph_outputs)
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
def test_output_shapes_with_time_varying_matrices(f_standard_nd, rng):
    """Output shapes are unchanged when T, Z, R, H and Q are supplied once per time step."""
    p, m, r, n = 1, 5, 2, 10
    data, a0, P0, c, d, T, Z, R, H, Q = make_test_inputs(p, m, r, n, rng)
    # Repeat each system matrix n times along a new leading (time) axis.
    T, Z, R, H, Q = (np.stack([matrix] * n, axis=0) for matrix in (T, Z, R, H, Q))

    results = f_standard_nd(data, a0, P0, c, d, T, Z, R, H, Q)

    for idx, name in enumerate(output_names):
        expected_shape = get_expected_shape(name, p, m, r, n)
        actual_shape = results[idx].shape
        assert actual_shape == expected_shape, f"Shape of {name} does not match expected"
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
@pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names)
def test_output_with_deterministic_observation_equation(filter_func, rng):
    """Every filter output has the expected shape with 5 states and a single shock."""
    # p observed series, m states, r shocks, n time steps
    p, m, r, n = 1, 5, 1, 10
    results = filter_func(*make_test_inputs(p, m, r, n, rng))

    for idx, name in enumerate(output_names):
        expected_shape = get_expected_shape(name, p, m, r, n)
        actual_shape = results[idx].shape
        assert actual_shape == expected_shape, f"Shape of {name} does not match expected"
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
@pytest.mark.parametrize(
    ("filter_func", "filter_name"), zip(filter_funcs, filter_names), ids=filter_names
)
def test_output_with_multiple_observed(filter_func, filter_name, rng):
    """Every filter output has the expected shape with five observed series."""
    # filter_name is supplied by the parametrization but unused here.
    p, m, r, n = 5, 5, 1, 10
    results = filter_func(*make_test_inputs(p, m, r, n, rng))

    for idx, name in enumerate(output_names):
        expected_shape = get_expected_shape(name, p, m, r, n)
        actual_shape = results[idx].shape
        assert actual_shape == expected_shape, f"Shape of {name} does not match expected"
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
@pytest.mark.parametrize(
    ("filter_func", "filter_name"), zip(filter_funcs, filter_names), ids=filter_names
)
@pytest.mark.parametrize("p", [1, 5], ids=["univariate (p=1)", "multivariate (p=5)"])
def test_missing_data(filter_func, filter_name, p, rng):
    """Output shapes are as expected when the data contain missing values."""
    m, r, n = 5, 1, 10
    results = filter_func(*make_test_inputs(p, m, r, n, rng, missing_data=1))

    for idx, name in enumerate(output_names):
        expected_shape = get_expected_shape(name, p, m, r, n)
        actual_shape = results[idx].shape
        assert actual_shape == expected_shape, f"Shape of {name} does not match expected"
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
@pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names)
@pytest.mark.parametrize("output_idx", [(0, 2), (3, 5)], ids=["smoothed_states", "smoothed_covs"])
def test_last_smoother_is_last_filtered(filter_func, output_idx, rng):
    """The final smoothed value equals the final filtered value (states and covariances)."""
    p, m, r, n = 1, 5, 1, 10
    results = filter_func(*make_test_inputs(p, m, r, n, rng))

    # output_idx holds (filtered position, smoothed position) in the output list.
    filtered_pos, smoothed_pos = output_idx
    assert_allclose(results[filtered_pos][-1], results[smoothed_pos][-1])
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
@pytest.mark.parametrize(
    "filter_func, filter_name", zip(filter_funcs, filter_names), ids=filter_names
)
@pytest.mark.parametrize("n_missing", [0, 5], ids=["n_missing=0", "n_missing=5"])
@pytest.mark.skipif(floatX == "float32", reason="Tests are too sensitive for float32")
def test_filters_match_statsmodel_output(filter_func, filter_name, n_missing, rng):
    """Each filter reproduces the reference statsmodels results from ``nile_test_test_helper``."""
    fit_sm_mod, [data, a0, P0, c, d, T, Z, R, H, Q] = nile_test_test_helper(rng, n_missing)
    if filter_name == "CholeskyFilter":
        # The square-root filter is fed the Cholesky factor of P0 instead of P0.
        P0 = np.linalg.cholesky(P0)
    results = filter_func(data, a0, P0, c, d, T, Z, R, H, Q)

    for idx, name in enumerate(output_names):
        expected = get_sm_state_from_output_name(fit_sm_mod, name)
        actual = results[idx].squeeze()

        if name == "smoothed_covs":
            # TODO: The smoothed covariance matrices have large errors (1e-2) ONLY in the first
            # few states -- no idea why. Compare from the 6th entry onward.
            actual, expected = actual[5:], expected[5:]
        elif name.startswith("predicted"):
            # statsmodels doesn't throw away the T+1 forecast in the predicted states like we do
            expected = expected[:-1]

        assert_allclose(
            actual,
            expected,
            atol=ATOL,
            rtol=RTOL,
            err_msg=f"{name} does not match statsmodels",
        )
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
@pytest.mark.parametrize(
    "filter_func, filter_name", zip(filter_funcs[:-1], filter_names[:-1]), ids=filter_names[:-1]
)
@pytest.mark.parametrize("n_missing", [0, 5], ids=["n_missing=0", "n_missing=5"])
@pytest.mark.parametrize("obs_noise", [True, False])
def test_all_covariance_matrices_are_PSD(filter_func, filter_name, n_missing, obs_noise, rng):
    """Filtered, predicted and smoothed covariances have positive eigenvalues and are symmetric.

    The parametrization slices off the last entry of ``filter_funcs`` (the univariate filter),
    so only the standard and Cholesky filters run, with and without observation noise.
    """
    # Plain boolean ``and`` instead of bitwise ``&``.
    if floatX == "float32" and filter_name == "UnivariateFilter":
        # NOTE(review): unreachable while the parametrization excludes the univariate filter;
        # kept as a guard in case it is added back.
        # TODO: These tests all pass locally for me with float32 but they fail on the CI, so i'm just disabling them.
        pytest.skip("Univariate filter not stable at half precision without measurement error")

    fit_sm_mod, [data, a0, P0, c, d, T, Z, R, H, Q] = nile_test_test_helper(rng, n_missing)
    if filter_name == "CholeskyFilter":
        P0 = np.linalg.cholesky(P0)

    # Multiplying by 0 zeroes the observation-noise covariance when obs_noise is False.
    H *= int(obs_noise)
    outputs = filter_func(data, a0, P0, c, d, T, Z, R, H, Q)

    # Positions 3, 4, 5 are filtered_covs, predicted_covs and smoothed_covs.
    for output_idx, name in zip([3, 4, 5], output_names[3:-2]):
        cov_stack = outputs[output_idx]
        # Only eigenvalues are checked, so skip computing eigenvectors (eig -> eigvals).
        eigenvalues = np.linalg.eigvals(cov_stack)

        assert_array_less(
            0, eigenvalues, err_msg=f"Smallest eigenvalue of {name}: {min(eigenvalues.ravel())}"
        )
        assert_allclose(
            cov_stack,
            np.swapaxes(cov_stack, -2, -1),
            rtol=RTOL,
            atol=ATOL,
            err_msg=f"{name} is not symmetrical",
        )
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
@pytest.mark.parametrize(
    "filter_cls",
    [StandardFilter, SquareRootFilter],
    ids=["standard", "cholesky"],
)
def test_kalman_filter_jax(filter_cls, rng):
    """JAX-compiled filter outputs match the same graph compiled with pytensor (FAST_COMPILE).

    ``rng`` must be requested as a fixture argument. The module-level ``rng`` name is the
    fixture function imported from ``shared_fixtures``, not a random generator, so using it
    directly would pass the wrong object to ``make_test_inputs``. The parameter is named
    ``filter_cls`` to avoid shadowing the ``filter`` builtin.
    """
    pytest.importorskip("jax")
    from pymc.sampling.jax import get_jaxified_graph

    # TODO: Add UnivariateFilter to test; need to figure out the broadcasting issue when 2nd data dim is defined

    p, m, r, n = 1, 5, 1, 10
    inputs, outputs = initialize_filter(filter_cls(), mode="JAX", p=p, m=m, r=r, n=n)
    inputs_np = make_test_inputs(p, m, r, n, rng)

    f_jax = get_jaxified_graph(inputs, outputs)
    f_pt = pytensor.function(inputs, outputs, mode="FAST_COMPILE")

    jax_outputs = f_jax(*inputs_np)
    pt_outputs = f_pt(*inputs_np)

    for name, jax_res, pt_res in zip(output_names, jax_outputs, pt_outputs):
        assert_allclose(jax_res, pt_res, atol=ATOL, rtol=RTOL, err_msg=f"{name} failed!")
|