pymc-extras 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pymc_extras/__init__.py +29 -0
- pymc_extras/distributions/__init__.py +40 -0
- pymc_extras/distributions/continuous.py +351 -0
- pymc_extras/distributions/discrete.py +399 -0
- pymc_extras/distributions/histogram_utils.py +163 -0
- pymc_extras/distributions/multivariate/__init__.py +3 -0
- pymc_extras/distributions/multivariate/r2d2m2cp.py +446 -0
- pymc_extras/distributions/timeseries.py +356 -0
- pymc_extras/gp/__init__.py +18 -0
- pymc_extras/gp/latent_approx.py +183 -0
- pymc_extras/inference/__init__.py +18 -0
- pymc_extras/inference/find_map.py +431 -0
- pymc_extras/inference/fit.py +44 -0
- pymc_extras/inference/laplace.py +570 -0
- pymc_extras/inference/pathfinder.py +134 -0
- pymc_extras/inference/smc/__init__.py +13 -0
- pymc_extras/inference/smc/sampling.py +451 -0
- pymc_extras/linearmodel.py +130 -0
- pymc_extras/model/__init__.py +0 -0
- pymc_extras/model/marginal/__init__.py +0 -0
- pymc_extras/model/marginal/distributions.py +276 -0
- pymc_extras/model/marginal/graph_analysis.py +372 -0
- pymc_extras/model/marginal/marginal_model.py +595 -0
- pymc_extras/model/model_api.py +56 -0
- pymc_extras/model/transforms/__init__.py +0 -0
- pymc_extras/model/transforms/autoreparam.py +434 -0
- pymc_extras/model_builder.py +759 -0
- pymc_extras/preprocessing/__init__.py +0 -0
- pymc_extras/preprocessing/standard_scaler.py +17 -0
- pymc_extras/printing.py +182 -0
- pymc_extras/statespace/__init__.py +13 -0
- pymc_extras/statespace/core/__init__.py +7 -0
- pymc_extras/statespace/core/compile.py +48 -0
- pymc_extras/statespace/core/representation.py +438 -0
- pymc_extras/statespace/core/statespace.py +2268 -0
- pymc_extras/statespace/filters/__init__.py +15 -0
- pymc_extras/statespace/filters/distributions.py +453 -0
- pymc_extras/statespace/filters/kalman_filter.py +820 -0
- pymc_extras/statespace/filters/kalman_smoother.py +126 -0
- pymc_extras/statespace/filters/utilities.py +59 -0
- pymc_extras/statespace/models/ETS.py +670 -0
- pymc_extras/statespace/models/SARIMAX.py +536 -0
- pymc_extras/statespace/models/VARMAX.py +393 -0
- pymc_extras/statespace/models/__init__.py +6 -0
- pymc_extras/statespace/models/structural.py +1651 -0
- pymc_extras/statespace/models/utilities.py +387 -0
- pymc_extras/statespace/utils/__init__.py +0 -0
- pymc_extras/statespace/utils/constants.py +74 -0
- pymc_extras/statespace/utils/coord_tools.py +0 -0
- pymc_extras/statespace/utils/data_tools.py +182 -0
- pymc_extras/utils/__init__.py +23 -0
- pymc_extras/utils/linear_cg.py +290 -0
- pymc_extras/utils/pivoted_cholesky.py +69 -0
- pymc_extras/utils/prior.py +200 -0
- pymc_extras/utils/spline.py +131 -0
- pymc_extras/version.py +11 -0
- pymc_extras/version.txt +1 -0
- pymc_extras-0.2.0.dist-info/LICENSE +212 -0
- pymc_extras-0.2.0.dist-info/METADATA +99 -0
- pymc_extras-0.2.0.dist-info/RECORD +101 -0
- pymc_extras-0.2.0.dist-info/WHEEL +5 -0
- pymc_extras-0.2.0.dist-info/top_level.txt +2 -0
- tests/__init__.py +13 -0
- tests/distributions/__init__.py +19 -0
- tests/distributions/test_continuous.py +185 -0
- tests/distributions/test_discrete.py +210 -0
- tests/distributions/test_discrete_markov_chain.py +258 -0
- tests/distributions/test_multivariate.py +304 -0
- tests/model/__init__.py +0 -0
- tests/model/marginal/__init__.py +0 -0
- tests/model/marginal/test_distributions.py +131 -0
- tests/model/marginal/test_graph_analysis.py +182 -0
- tests/model/marginal/test_marginal_model.py +867 -0
- tests/model/test_model_api.py +29 -0
- tests/statespace/__init__.py +0 -0
- tests/statespace/test_ETS.py +411 -0
- tests/statespace/test_SARIMAX.py +405 -0
- tests/statespace/test_VARMAX.py +184 -0
- tests/statespace/test_coord_assignment.py +116 -0
- tests/statespace/test_distributions.py +270 -0
- tests/statespace/test_kalman_filter.py +326 -0
- tests/statespace/test_representation.py +175 -0
- tests/statespace/test_statespace.py +818 -0
- tests/statespace/test_statespace_JAX.py +156 -0
- tests/statespace/test_structural.py +829 -0
- tests/statespace/utilities/__init__.py +0 -0
- tests/statespace/utilities/shared_fixtures.py +9 -0
- tests/statespace/utilities/statsmodel_local_level.py +42 -0
- tests/statespace/utilities/test_helpers.py +310 -0
- tests/test_blackjax_smc.py +222 -0
- tests/test_find_map.py +98 -0
- tests/test_histogram_approximation.py +109 -0
- tests/test_laplace.py +238 -0
- tests/test_linearmodel.py +208 -0
- tests/test_model_builder.py +306 -0
- tests/test_pathfinder.py +45 -0
- tests/test_pivoted_cholesky.py +24 -0
- tests/test_printing.py +98 -0
- tests/test_prior_from_trace.py +172 -0
- tests/test_splines.py +77 -0
- tests/utils.py +31 -0
|
@@ -0,0 +1,820 @@
|
|
|
1
|
+
from abc import ABC
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import pytensor
|
|
5
|
+
import pytensor.tensor as pt
|
|
6
|
+
|
|
7
|
+
from pymc.pytensorf import constant_fold
|
|
8
|
+
from pytensor.compile.mode import get_mode
|
|
9
|
+
from pytensor.graph.basic import Variable
|
|
10
|
+
from pytensor.raise_op import Assert
|
|
11
|
+
from pytensor.tensor import TensorVariable
|
|
12
|
+
from pytensor.tensor.slinalg import solve_triangular
|
|
13
|
+
|
|
14
|
+
from pymc_extras.statespace.filters.utilities import (
|
|
15
|
+
quad_form_sym,
|
|
16
|
+
split_vars_into_seq_and_nonseq,
|
|
17
|
+
stabilize,
|
|
18
|
+
)
|
|
19
|
+
from pymc_extras.statespace.utils.constants import JITTER_DEFAULT, MISSING_FILL
|
|
20
|
+
|
|
21
|
+
# log(2 * pi) in float64: the per-dimension normalising term of the Gaussian log-density.
MVN_CONST = pt.log(2 * pt.constant(np.pi, dtype="float64"))

# Canonical ordering of the statespace matrices consumed by the filters.
PARAM_NAMES = "c d T Z R H Q".split()

# Assert Op used to check that time-varying matrices share the data's leading (time) dimension.
assert_time_varying_dim_correct = Assert(
    "The first dimension of a time varying matrix (the time dimension) must be equal to the "
    "first dimension of the data (the time dimension)."
)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class BaseFilter(ABC):
    def __init__(self, mode=None):
        """
        Kalman Filter.

        Parameters
        ----------
        mode : str, optional
            The mode used for Pytensor compilation. Defaults to None.

        Notes
        -----
        The BaseFilter class is an abstract base class (ABC) for implementing kalman filters.
        It defines common attributes and methods used by kalman filter implementations.

        Attributes
        ----------
        mode : str or None
            The mode used for Pytensor compilation.

        seq_names : list[str]
            A list of names representing time-varying statespace matrices. That is, inputs that will need to be
            provided to the `sequences` argument of `pytensor.scan`

        non_seq_names : list[str]
            A list of names representing static statespace matrices. That is, inputs that will need to be provided
            to the `non_sequences` argument of `pytensor.scan`
        """

        self.mode: str | None = mode
        self.seq_names: list[str] = []
        self.non_seq_names: list[str] = []

        # Model dimensions are unknown until ``build_graph`` inspects the shapes of R and Z.
        self.n_states = None
        self.n_posdef = None
        self.n_shocks = None
        self.n_endog = None

        self.missing_fill_value: float | None = None
        self.cov_jitter = None

    def check_params(self, data, a0, P0, c, d, T, Z, R, H, Q):
        """
        Apply any checks on validity of inputs. For most filters this is just the identity function.
        """
        return data, a0, P0, c, d, T, Z, R, H, Q

    @staticmethod
    def add_check_on_time_varying_shapes(
        data: TensorVariable, sequence_params: list[TensorVariable]
    ) -> list[Variable]:
        """
        Insert a check that time-varying matrices match the data shape to the computational graph.

        If any matrices are time-varying, they need to have the same length as the data. This function wraps each
        element of `sequence_params` in an assert `Op` that makes sure all inputs have the correct shape.

        Parameters
        ----------
        data : TensorVariable
            The tensor representing the data.

        sequence_params : list[TensorVariable]
            A list of tensors to be provided to `pytensor.scan` as `sequences`.

        Returns
        -------
        list[TensorVariable]
            A list of tensors wrapped in an `Assert` `Op` that checks the shape of the 0th dimension on each is equal
            to the shape of the 0th dimension on the data.
        """
        # TODO: The PytensorRepresentation object puts the time dimension last, should the reshaping happen here in
        #  the Kalman filter, or in the StateSpaceModel, before passing into the KF?

        params_with_assert = [
            assert_time_varying_dim_correct(param, pt.eq(param.shape[0], data.shape[0]))
            for param in sequence_params
        ]

        return params_with_assert

    def unpack_args(self, args) -> tuple:
        """
        Restore the standard argument order inside the scan inner function.

        The order of inputs to the inner scan function is not known, since some, all, or none of the input matrices
        can be time varying. Scan feeds the inner function its arguments in the order: sequences, outputs_info,
        non-sequences. This function works out which matrices are where, and returns the standardized order expected
        by the kalman_step function.

        The standard order is: y, a0, P0, c, d, T, Z, R, H, Q
        """
        args = list(args)
        n_seq = len(self.seq_names)

        # If there are no sequence parameters (all params are static), the only sequence is the data, so the
        # arguments already arrive as y, a0, P0, c, d, T, Z, R, H, Q.
        if n_seq == 0:
            return tuple(args)

        # The first arg is always y
        y = args.pop(0)

        # There are always two outputs_info (the recurrent a and P) wedged between the seqs and non_seqs
        seqs, (a0, P0), non_seqs = args[:n_seq], args[n_seq : n_seq + 2], args[n_seq + 2 :]

        return_ordered = []
        for name in PARAM_NAMES:
            if name in self.seq_names:
                return_ordered.append(seqs[self.seq_names.index(name)])
            else:
                return_ordered.append(non_seqs[self.non_seq_names.index(name)])

        c, d, T, Z, R, H, Q = return_ordered

        return y, a0, P0, c, d, T, Z, R, H, Q

    def build_graph(
        self,
        data,
        a0,
        P0,
        c,
        d,
        T,
        Z,
        R,
        H,
        Q,
        mode=None,
        return_updates=False,
        missing_fill_value=None,
        cov_jitter=None,
    ) -> list[TensorVariable] | tuple[list[TensorVariable], dict]:
        """
        Construct the computation graph for the Kalman filter. See [1] for details.

        Parameters
        ----------
        data : TensorVariable
            Data to be filtered, with time along the first dimension.

        a0, P0 : TensorVariable
            Initial hidden state mean and covariance. These are treated as the prediction for the first time step.

        c, d, T, Z, R, H, Q : TensorVariable
            Statespace matrices: hidden state intercept, observation intercept, transition matrix, design matrix,
            selection matrix, observation noise covariance, and state innovation covariance. Any of them may be
            time-varying, in which case its first dimension must match the first dimension of ``data``.

        mode : optional, str
            Pytensor compile mode, passed to pytensor.scan

        return_updates: bool, default False
            Whether to return updates associated with the pytensor scan. Should only be required for debugging
            purposes.

        missing_fill_value: float, default -9999
            Fill value used to mark missing values. Used to avoid PyMC's automatic interpolation, which conflicts with
            the Kalman filter's hidden state inference. Change if your data happens to have legitimate values of -9999

        cov_jitter: float, default 1e-8 or 1e-6 if pytensor.config.floatX is float32
            The Kalman filter is known to be numerically unstable, especially at half precision. This value is added to
            the diagonal of every covariance matrix -- predicted, filtered, and smoothed -- at every step, to ensure
            all matrices are strictly positive definite.

            Obviously, if this can be zero, that's best. In general:
                - Having measurement error makes Kalman Filters more robust. A large source of numerical errors come
                  from the Filtered and Smoothed matrices having a zero in the (0, 0) position, which always occurs
                  when there is no measurement error.

                - The Univariate Filter is more robust than other filters, and can tolerate a lower jitter value

        Returns
        -------
        list[TensorVariable] or tuple[list[TensorVariable], dict]
            filtered_states, predicted_states, observed_states, filtered_covariances, predicted_covariances,
            observed_covariances, loglike_obs. If ``return_updates`` is True, the scan updates are returned as well.

        References
        ----------
        .. [1] Koopman, Siem Jan, Neil Shephard, and Jurgen A. Doornik. 1999.
            Statistical Algorithms for Models in State Space Using SsfPack 2.2.
            Econometrics Journal 2 (1): 107-60. doi:10.1111/1368-423X.00023.
        """
        if missing_fill_value is None:
            missing_fill_value = MISSING_FILL
        if cov_jitter is None:
            cov_jitter = JITTER_DEFAULT

        self.mode = mode
        self.missing_fill_value = missing_fill_value
        self.cov_jitter = cov_jitter

        # Resolve model dimensions statically when possible; otherwise they remain symbolic.
        [R_shape] = constant_fold([R.shape], raise_not_constant=False)
        [Z_shape] = constant_fold([Z.shape], raise_not_constant=False)

        self.n_states, self.n_shocks = R_shape[-2:]
        self.n_endog = Z_shape[-2]

        data, a0, P0, *params = self.check_params(data, a0, P0, c, d, T, Z, R, H, Q)

        sequences, non_sequences, seq_names, non_seq_names = split_vars_into_seq_and_nonseq(
            params, PARAM_NAMES
        )

        self.seq_names = seq_names
        self.non_seq_names = non_seq_names

        if len(sequences) > 0:
            sequences = self.add_check_on_time_varying_shapes(data, sequences)

        # outputs_info positions must match the output order of kalman_step:
        # (a_filtered, a_hat, obs_mu, P_filtered, P_hat, obs_cov, ll) -> a_hat and P_hat are recurrent.
        results, updates = pytensor.scan(
            self.kalman_step,
            sequences=[data, *sequences],
            outputs_info=[None, a0, None, None, P0, None, None],
            non_sequences=non_sequences,
            name="forward_kalman_pass",
            mode=get_mode(self.mode),
            strict=False,
        )

        filter_results = self._postprocess_scan_results(results, a0, P0, n=data.type.shape[0])

        if return_updates:
            return filter_results, updates
        return filter_results

    def _postprocess_scan_results(self, results, a0, P0, n) -> list[TensorVariable]:
        """
        Transform the values returned by the Kalman Filter scan into a form expected by users. In particular:

        1. Prepend the initial state and covariance matrix to their respective Kalman predictions. This matches the
           output returned by Statsmodels state space models.

        2. Discard the last state and covariance matrix from the Kalman predictions. This is because the kalman filter
           starts with the (random variable) initial state x0, and treats it as a predicted state. The first step (t=0)
           will filter x0 to make filtered_states[0], then do a predict step to make predicted_states[1]. This means
           the last step (t=T) predicted state will be a *forecast* for T+1. If the user wants this forecast, they
           should use the forecast method.

        3. Squeeze away extra dimensions from the likelihoods, and pin the static shapes and names of all outputs.
        """
        (
            filtered_states,
            predicted_states,
            observed_states,
            filtered_covariances,
            predicted_covariances,
            observed_covariances,
            loglike_obs,
        ) = results

        predicted_states = pt.concatenate(
            [pt.expand_dims(a0, axis=(0,)), predicted_states[:-1]], axis=0
        )
        predicted_covariances = pt.concatenate(
            [pt.expand_dims(P0, axis=(0,)), predicted_covariances[:-1]], axis=0
        )

        filtered_states = pt.specify_shape(filtered_states, (n, self.n_states))
        filtered_states.name = "filtered_states"

        predicted_states = pt.specify_shape(predicted_states, (n, self.n_states))
        predicted_states.name = "predicted_states"

        observed_states = pt.specify_shape(observed_states, (n, self.n_endog))
        observed_states.name = "observed_states"

        filtered_covariances = pt.specify_shape(
            filtered_covariances, (n, self.n_states, self.n_states)
        )
        filtered_covariances.name = "filtered_covariances"

        predicted_covariances = pt.specify_shape(
            predicted_covariances, (n, self.n_states, self.n_states)
        )
        predicted_covariances.name = "predicted_covariances"

        observed_covariances = pt.specify_shape(
            observed_covariances, (n, self.n_endog, self.n_endog)
        )
        observed_covariances.name = "observed_covariances"

        loglike_obs = pt.specify_shape(loglike_obs.squeeze(), (n,))
        loglike_obs.name = "loglike_obs"

        filter_results = [
            filtered_states,
            predicted_states,
            observed_states,
            filtered_covariances,
            predicted_covariances,
            observed_covariances,
            loglike_obs,
        ]

        return filter_results

    def handle_missing_values(
        self, y, Z, H
    ) -> tuple[TensorVariable, TensorVariable, TensorVariable, TensorVariable]:
        """
        Handle missing values in the observation data `y`

        Adjusts the design matrix `Z` and the observation noise covariance matrix `H` by zeroing the rows and columns
        associated with the data that is not observed at this iteration. Missing values are replaced with zeros to
        prevent propagating NaNs through the computation.

        Return a binary flag tensor `all_nan_flag`, indicating if all values in the observation data are missing. This
        flag is used for numerical adjustments in the update method.

        Parameters
        ----------
        y : TensorVariable
            The observation data at time t.
        Z : TensorVariable
            The design matrix.
        H : TensorVariable
            The observation noise covariance matrix.

        Returns
        -------
        y_masked : TensorVariable
            Observation vector with missing values replaced by zeros.

        Z_masked: TensorVariable
            Design matrix adjusted to exclude the missing states from the information set of observed variables in the
            update step

        H_masked: TensorVariable
            Noise covariance matrix with the rows and columns of missing observations zeroed. It stays symmetric.

        all_nan_flag: TensorVariable
            1 if the entire observation vector is missing, 0 otherwise (cast to floatX).

        References
        ----------
        .. [1] Durbin, J., and S. J. Koopman. Time Series Analysis by State Space Methods.
               2nd ed, Oxford University Press, 2012.
        """
        nan_mask = pt.or_(pt.isnan(y), pt.eq(y, self.missing_fill_value))
        all_nan_flag = pt.all(nan_mask).astype(pytensor.config.floatX)

        # W is a 0/1 diagonal selector of the observed entries.
        W = pt.diag(pt.bitwise_not(nan_mask).astype(pytensor.config.floatX))

        Z_masked = W.dot(Z)
        # Mask H on both sides: the covariance of W y is W H W^T. Masking only the rows (W H) leaves
        # cross-covariances with missing entries in the columns. That makes F non-symmetric whenever H has
        # off-diagonal terms and an observation is partially missing. For diagonal H the two forms agree.
        H_masked = W.dot(H).dot(W)
        y_masked = pt.set_subtensor(y[nan_mask], 0.0)

        return y_masked, Z_masked, H_masked, all_nan_flag

    @staticmethod
    def predict(a, P, c, T, R, Q) -> tuple[TensorVariable, TensorVariable]:
        r"""
        Perform the prediction step of the Kalman filter.

        This function computes the one-step forecast of the hidden states and the covariance matrix of the forecasted
        states, based on the current state estimates and model parameters. Both covariance terms are built with
        ``quad_form_sym``, which is intended to keep the predicted covariance symmetric. The prediction equations
        are:

        .. math::

            \begin{align}
            a_{t+1 | t} &= T_t a_{t | t} + c_t \\
            P_{t+1 | t} &= T_t P_{t | t} T_t^T + R_t Q_t R_t^T
            \end{align}


        Parameters
        ----------
        a : TensorVariable
            The current state vector estimate computed by the update step, a[t | t].
        P : TensorVariable
            The current covariance matrix estimate computed by the update step, P[t | t].
        c : TensorVariable
            The hidden state intercept/bias vector.
        T : TensorVariable
            The state transition matrix.
        R : TensorVariable
            The selection matrix.
        Q : TensorVariable
            The state innovation covariance matrix.

        Returns
        -------
        a_hat : TensorVariable
            One-step forecast of the hidden states
        P_hat : TensorVariable
            Covariance matrix of the forecasted hidden states

        References
        ----------
        .. [1] Durbin, J., and S. J. Koopman. Time Series Analysis by State Space Methods.
               2nd ed, Oxford University Press, 2012.
        """
        a_hat = T.dot(a) + c
        P_hat = quad_form_sym(T, P) + quad_form_sym(R, Q)

        return a_hat, P_hat

    @staticmethod
    def update(
        a, P, y, d, Z, H, all_nan_flag
    ) -> tuple[TensorVariable, TensorVariable, TensorVariable, TensorVariable, TensorVariable]:
        r"""
        Perform the update step of the Kalman filter. Must be implemented by subclasses.

        This function updates the state vector and covariance matrix estimates based on the current observation data,
        previous predictions, and model parameters. The filtering equations are:

        .. math::

            \begin{align}
            \hat{y}_t &= Z_t a_{t | t-1} + d_t \\
            v_t &= y_t - \hat{y}_t \\
            F_t &= Z_t P_{t | t-1} Z_t^T + H_t \\
            a_{t|t} &= a_{t | t-1} + P_{t | t-1} Z_t^T F_t^{-1} v_t \\
            P_{t|t} &= P_{t | t-1} - P_{t | t-1} Z_t^T F_t^{-1} Z_t P_{t | t-1}
            \end{align}


        Parameters
        ----------
        a : TensorVariable
            The current state vector estimate, conditioned on information up to time t-1.
        P : TensorVariable
            The current covariance matrix estimate, conditioned on information up to time t-1.
        y : TensorVariable
            The observation data at time t.
        d : TensorVariable
            The matrix d.
        Z : TensorVariable
            The matrix Z.
        H : TensorVariable
            The matrix H.
        all_nan_flag : TensorVariable
            A binary flag tensor indicating whether all values in the observation data are missing.

        Returns
        -------
        tuple[TensorVariable, TensorVariable, TensorVariable, TensorVariable, TensorVariable]
            A tuple containing the updated state vector `a_filtered`, the updated covariance matrix `P_filtered`, the
            predicted observation `obs_mu`, the predicted observation covariance matrix `obs_cov`, and the
            log-likelihood `ll`.

        Raises
        ------
        NotImplementedError
            Always; concrete filters override this method.
        """
        raise NotImplementedError

    def kalman_step(self, *args) -> tuple:
        r"""
        Perform a single iteration of the Kalman filter.

        Each iteration is composed of two steps: an update step and a prediction step. The timing convention follows
        [1], in which the initial state and covariance estimates a0 and P0 are taken to be predictions. As a result,
        the update step is applied first. The update step computes:

        .. math::

            \begin{align}
            \hat{y}_t &= Z_t a_{t | t-1} + d_t \\
            v_t &= y_t - \hat{y}_t \\
            F_t &= Z_t P_{t | t-1} Z_t^T + H_t \\
            a_{t|t} &= a_{t | t-1} + P_{t | t-1} Z_t^T F_t^{-1} v_t \\
            P_{t|t} &= P_{t | t-1} - P_{t | t-1} Z_t^T F_t^{-1} Z_t P_{t | t-1}
            \end{align}

        Where the quantities :math:`a_{t|t}` and :math:`P_{t|t}` are the best linear estimates of the hidden states
        at time t, incorporating all information up to and including the observation :math:`y_t`. After the update
        step, new one-step forecasts of the hidden states can be obtained by applying the model transition dynamics
        in the prediction step:

        .. math::

            \begin{align}
            a_{t+1 | t} &= T_t a_{t | t} + c_t \\
            P_{t+1 | t} &= T_t P_{t | t} T_t^T + R_t Q_t R_t^T
            \end{align}

        Recursive application of these two steps results in the best linear estimate of the hidden states, including
        missing values and observations subject to measurement error.

        Parameters
        ----------
        Kalman filter inputs:
            y, a, P, c, d, T, Z, R, H, Q. They arrive in scan order and are reordered by ``unpack_args``.
            See the docstring for the kalman filter class for details.

        Returns
        -------
        a_filtered : TensorVariable
            Best linear estimate of hidden states given all information up to and including the present
            observation, a[t | t].

        a_hat: TensorVariable
            One-step forecast of next-period hidden states given all information up to and including the present
            observation, a[t+1 | t]

        obs_mu: TensorVariable
            Estimates of the current observation given all information available prior to the current state,
            d + Z @ a[t | t-1]

        P_filtered: TensorVariable
            Best linear estimate of the covariance between hidden states, given all information up to and including
            the present observation, P[t | t]

        P_hat: TensorVariable
            Covariance between the one-step forecasted hidden states given all information up to and including the
            present observation, P[t+1 | t]

        obs_cov: TensorVariable
            Covariance between estimated present observations, given all information available prior to the current
            state, Z @ P[t | t-1] @ Z.T + H

        ll: float
            Log-likelihood of the time t observation vector under the multivariate normal distribution parameterized by
            `obs_mu` and `obs_cov`

        References
        ----------
        .. [1] Durbin, J., and S. J. Koopman. Time Series Analysis by State Space Methods.
               2nd ed, Oxford University Press, 2012.
        """
        y, a, P, c, d, T, Z, R, H, Q = self.unpack_args(args)
        y_masked, Z_masked, H_masked, all_nan_flag = self.handle_missing_values(y, Z, H)

        a_filtered, P_filtered, obs_mu, obs_cov, ll = self.update(
            y=y_masked, a=a, d=d, P=P, Z=Z_masked, H=H_masked, all_nan_flag=all_nan_flag
        )

        # Jitter the filtered covariance (see the ``cov_jitter`` description in build_graph) before propagating it.
        P_filtered = stabilize(P_filtered, self.cov_jitter)

        a_hat, P_hat = self.predict(a=a_filtered, P=P_filtered, c=c, T=T, R=R, Q=Q)

        # Order must match ``outputs_info`` in build_graph: a_hat (index 1) and P_hat (index 4) are recurrent.
        outputs = (a_filtered, a_hat, obs_mu, P_filtered, P_hat, obs_cov, ll)

        return outputs
|
|
542
|
+
|
|
543
|
+
|
|
544
|
+
class StandardFilter(BaseFilter):
    """
    Basic Kalman Filter
    """

    def update(self, a, P, y, d, Z, H, all_nan_flag):
        """
        Compute one-step forecasts for observed states conditioned on information up to, but not including, the current
        timestep, `y_hat`, along with the forecast covariance matrix, `F`. Marginalize over observed states to obtain
        the best linear estimate of the unobserved states, `a_filtered`, as well as the associated covariance matrix,
        `P_filtered`, conditioned on all information, up to and including the present.

        Derivation of the Kalman filter, along with a deeper discussion of the computational elements, can be found in
        [1].

        Parameters
        ----------
        a : TensorVariable
            The current state vector estimate, conditioned on information up to time t-1.

        P : TensorVariable
            The current covariance matrix estimate, conditioned on information up to time t-1.

        y : TensorVariable
            Observations at time t.

        d : TensorVariable
            Observed state bias term.

        Z : TensorVariable
            Linear map between unobserved and observed states.

        H : TensorVariable
            Observation noise covariance matrix

        all_nan_flag : TensorVariable
            A flag indicating whether all elements in the data `y` are NaNs.

        Returns
        -------
        tuple[TensorVariable, TensorVariable, TensorVariable, TensorVariable, TensorVariable]
            A tuple containing the updated state vector `a_filtered`, the updated covariance matrix `P_filtered`,
            the one-step forecast mean `y_hat`, one-step forecast covariance matrix `F`, and the log-likelihood of
            the data, given the one-step forecasts, `ll`.

        Notes
        -----
        The Gaussian normalising constant is counted once per entry of `y`. For entries that were masked as missing,
        the jittered diagonal of `F` also contributes to the log-determinant. The per-step log-likelihood is
        therefore only exact when `y` is fully observed; it is 0 when every entry is missing.

        References
        ----------
        .. [1] Durbin, J., and S. J. Koopman. Time Series Analysis by State Space Methods.
               2nd ed, Oxford University Press, 2012.
        """
        y_hat = d + Z.dot(a)
        v = y - y_hat

        PZT = P.dot(Z.T)
        F = Z.dot(PZT) + stabilize(H, self.cov_jitter)

        # Kalman gain K = P Z^T F^{-1}, obtained by a solve against F^T instead of forming F^{-1} explicitly.
        K = pt.linalg.solve(F.T, PZT.T, assume_a="pos", check_finite=False).T
        I_KZ = pt.eye(self.n_states) - K.dot(Z)

        a_filtered = a + K.dot(v)
        # Joseph-form update (I - KZ) P (I - KZ)^T + K H K^T, assuming quad_form_sym(A, B) computes A @ B @ A.T.
        P_filtered = quad_form_sym(I_KZ, P) + quad_form_sym(K, H)

        F_inv_v = pt.linalg.solve(F, v, assume_a="pos", check_finite=False)
        inner_term = v.T @ F_inv_v

        F_logdet = pt.log(pt.linalg.det(F))

        # A k-dimensional Gaussian log-density carries k * log(2 * pi), not a single log(2 * pi).
        n_obs = y.shape[0]

        ll = pt.switch(
            all_nan_flag,
            0.0,
            -0.5 * (n_obs * MVN_CONST + F_logdet + inner_term).ravel()[0],
        )

        return a_filtered, P_filtered, y_hat, F, ll
|
|
618
|
+
|
|
619
|
+
|
|
620
|
+
class SquareRootFilter(BaseFilter):
    """
    Kalman filter with Cholesky factorization

    Kalman filter implementation that carries the lower Cholesky factor of the state covariance matrix through the
    recursion instead of the covariance matrix itself. Covariance factors are propagated with QR decompositions, and
    pt.solve_triangular is used to (attempt to) speed up inversion of the observation covariance matrix `F`.

    ``_postprocess_scan_results`` squares the factors back into covariance matrices, so the outputs have the same
    meaning as those of the other filters.
    """

    def predict(self, a, P, c, T, R, Q):
        """
        Compute one-step forecasts for the hidden states conditioned on information up to, but not including, the
        current timestep, `a_hat`, along with the Cholesky factor of the forecast covariance matrix, `P_chol_hat`.

        .. warning::
            Very important -- In this function, $P$ is the **cholesky factor** of the covariance matrix, not the
            covariance matrix itself. The name `P` is kept for consistency with the superclass.

        Notes
        -----
        The forecast covariance is ``T P T' + R Q R' = M' M`` with ``M = [T @ P_chol, R @ Q_chol]'``. Writing the QR
        decomposition ``M = Q_M U`` gives ``M' M = U' U``, so the transpose of the triangular factor ``U`` is a lower
        Cholesky factor of the forecast covariance (its diagonal is not guaranteed to be positive).
        """
        # Rename P to P_chol for clarity
        P_chol = P

        a_hat = T.dot(a) + c
        Q_chol = pt.linalg.cholesky(Q, lower=True)

        M = pt.horizontal_stack(T @ P_chol, R @ Q_chol).T
        R_decomp = pt.linalg.qr(M, mode="r")
        P_chol_hat = R_decomp[: self.n_states, : self.n_states].T

        return a_hat, P_chol_hat

    def update(self, a, P, y, d, Z, H, all_nan_flag):
        """
        Compute posterior estimates of the hidden state distributions conditioned on the observed data, up to and
        including the present timestep. Also compute the log-likelihood of the data given the one-step forecasts.

        .. warning::
            Very important -- In this function, $P$ is the **cholesky factor** of the covariance matrix, not the
            covariance matrix itself. The name `P` is kept for consistency with the superclass.

        Parameters
        ----------
        a : TensorVariable
            State mean before conditioning on `y`.
        P : TensorVariable
            Lower Cholesky factor of the state covariance before conditioning on `y`.
        y : TensorVariable
            Observations at time t.
        d : TensorVariable
            Observed state bias term.
        Z : TensorVariable
            Linear map between unobserved and observed states.
        H : TensorVariable
            Observation noise covariance matrix.
        all_nan_flag : TensorVariable
            Flag equal to 1.0 when every element of `y` is missing; the update is then skipped entirely.

        Returns
        -------
        tuple[TensorVariable, TensorVariable, TensorVariable, TensorVariable, TensorVariable]
            The filtered state mean `a_filtered`, the Cholesky factor of the filtered covariance `P_chol_filtered`,
            the one-step forecast mean `y_hat`, the Cholesky factor of the one-step forecast covariance `F_chol`
            (its diagonal may be negative), and the log-likelihood `ll`.
        """

        # Rename P to P_chol for clarity
        P_chol = P

        y_hat = Z.dot(a) + d
        v = y - y_hat

        # An all-zero H (e.g. no measurement error) has no Cholesky factor, but zero is its own square root.
        H_chol = pytensor.ifelse(pt.all(pt.eq(H, 0.0)), H, pt.linalg.cholesky(H, lower=True))

        # The following notation comes from https://ipnpr.jpl.nasa.gov/progress_report/42-233/42-233A.pdf
        # Construct upper-triangular block matrix A = [[chol(H), Z @ L_pred],
        #                                              [0,       L_pred]]
        # Then A A^T = [[F, Z P], [P Z^T, P]]. The transposed R factor of a QR decomposition of A^T is a lower
        # triangular B with B B^T = A A^T, which gives:
        # Structure of B = [[chol(F),     0               ],
        #                   [K @ chol(F), chol(P_filtered)]]
        zeros = pt.zeros((self.n_states, self.n_endog))
        upper = pt.horizontal_stack(H_chol, Z @ P_chol)
        lower = pt.horizontal_stack(zeros, P_chol)
        A_T = pt.vertical_stack(upper, lower)
        B = pt.linalg.qr(A_T.T, mode="r").T

        F_chol = B[: self.n_endog, : self.n_endog]
        K_F_chol = B[self.n_endog :, : self.n_endog]
        P_chol_filtered = B[self.n_endog :, self.n_endog :]

        def compute_non_degenerate(P_chol_filtered, F_chol, K_F_chol, v):
            # Whitened innovations w = chol(F)^{-1} v. Since F^{-1} = chol(F)^{-T} chol(F)^{-1}:
            #   K v        = (K chol(F)) w
            #   v' F^-1 v  = w' w
            # (Applying chol(F)^{-1} twice would give v' chol(F)^{-2} v, which is wrong whenever n_endog > 1.)
            w = solve_triangular(F_chol, v, lower=True)
            a_filtered = a + K_F_chol @ w

            loss = (w.T @ w).ravel()

            # abs necessary because we're not guaranteed a positive diagonal from the QR decomposition
            logdet = 2 * pt.log(pt.abs(pt.diag(F_chol))).sum()

            # Gaussian log-density -0.5 * (k * MVN_CONST + log|F| + v' F^{-1} v). Only the constant scales with the
            # number of observed series; the log-determinant already covers all of F.
            ll = -0.5 * (self.n_endog * MVN_CONST + logdet + loss)[0]

            return [a_filtered, P_chol_filtered, ll]

        def compute_degenerate(P_chol_filtered, F_chol, K_F_chol, v):
            """
            If F is zero (usually because there were no observations this period), then we want:
            K = 0, a = a, P = P, ll = 0
            """
            return [a, P_chol, pt.zeros(())]

        [a_filtered, P_chol_filtered, ll] = pytensor.ifelse(
            pt.eq(all_nan_flag, 1.0),
            compute_degenerate(P_chol_filtered, F_chol, K_F_chol, v),
            compute_non_degenerate(P_chol_filtered, F_chol, K_F_chol, v),
        )

        a_filtered = pt.specify_shape(a_filtered, (self.n_states,))
        P_chol_filtered = pt.specify_shape(P_chol_filtered, (self.n_states, self.n_states))

        return a_filtered, P_chol_filtered, y_hat, F_chol, ll

    def _postprocess_scan_results(self, results, a0, P0, n) -> list[TensorVariable]:
        """
        Convert the Cholesky factor of the covariance matrix back to the covariance matrix itself.

        Parameters
        ----------
        results : list[TensorVariable]
            Raw scan outputs. The three covariance entries hold Cholesky factors.
        a0, P0 : TensorVariable
            Initial state mean and covariance, forwarded to the superclass.
        n : int or TensorVariable
            Number of timesteps; used as the leading dimension of the squared covariance sequences.

        Returns
        -------
        list[TensorVariable]
            The same seven outputs as the superclass, with every covariance entry given as a full matrix L @ L.T.
        """
        results = super()._postprocess_scan_results(results, a0, P0, n)
        (
            filtered_states,
            predicted_states,
            observed_states,
            filtered_covariances_cholesky,
            predicted_covariances_cholesky,
            observed_covariances_cholesky,
            loglike_obs,
        ) = results

        def square_sequence(L, k):
            # Batched L @ L^T over the leading (time) axis.
            # NOTE(review): `L.copy()` presumably avoids handing einsum the same variable twice -- confirm.
            X = pt.einsum("...ij,...kj->...ik", L, L.copy())
            X = pt.specify_shape(X, (n, k, k))
            return X

        filtered_covariances = square_sequence(filtered_covariances_cholesky, k=self.n_states)
        predicted_covariances = square_sequence(predicted_covariances_cholesky, k=self.n_states)
        observed_covariances = square_sequence(observed_covariances_cholesky, k=self.n_endog)

        return [
            filtered_states,
            predicted_states,
            observed_states,
            filtered_covariances,
            predicted_covariances,
            observed_covariances,
            loglike_obs,
        ]
|
|
751
|
+
|
|
752
|
+
|
|
753
|
+
class UnivariateFilter(BaseFilter):
    """
    The univariate kalman filter, described in [1], section 6.4.2, avoids inversion of the F matrix, as well as two
    matrix multiplications, at the cost of an additional loop. Note that the name doesn't mean there's only one
    observed time series. This is called univariate because it updates the state mean and covariance matrices one
    variable at a time, using an inner-inner loop.

    This is useful when states are perfectly observed, because the F matrix can easily become degenerate in these cases.

    Only the diagonal of the observation noise covariance `H` enters the inner recursion; off-diagonal measurement
    error covariances are ignored by the filtering step.

    References
    ----------
    .. [1] Durbin, J., and S. J. Koopman. Time Series Analysis by State Space Methods.
           2nd ed, Oxford University Press, 2012.

    """

    def _univariate_inner_filter_step(self, y, Z_row, d_row, sigma_H, nan_flag, a, P):
        """
        Condition the state mean and covariance on a single observed series.

        Used as the body of the inner ``pytensor.scan`` in ``kalman_step``, which runs it once per observed series.
        The returned ``a_filtered`` and ``P_filtered`` are fed back in as ``a`` and ``P`` for the next series.

        Parameters
        ----------
        y : TensorVariable
            Observation for this series (missing values already replaced by 0).
        Z_row : TensorVariable
            Row of the (masked) design matrix for this series.
        d_row : TensorVariable
            Observation bias for this series.
        sigma_H : TensorVariable
            Diagonal entry of the (masked) observation noise covariance for this series.
        nan_flag : TensorVariable
            True when this observation is missing.
        a, P : TensorVariable
            Current state mean and covariance.

        Returns
        -------
        tuple
            ``a_filtered``, ``P_filtered``, the scalar forecast ``y_hat`` as a 1-vector, its variance ``F`` as a 1x1
            matrix, and ``ll_inner = log(F) + v**2 / F``. ``ll_inner`` excludes the constant and the -0.5 factor, and
            is 0 when the step is skipped.
        """
        y_hat = Z_row.dot(a) + d_row
        v = y - y_hat

        PZT = P.dot(Z_row.T)
        F = Z_row.dot(PZT) + sigma_H

        # Set the zero flag for F first, then jitter it to avoid a divide-by-zero NaN later
        F_zero_flag = pt.or_(pt.eq(F, 0), nan_flag)
        F = F + self.cov_jitter

        # If F is zero (implies y is NAN or another degenerate case), then we want:
        # K = 0, a = a, P = P, ll = 0
        K = PZT / F * (1 - F_zero_flag)

        a_filtered = a + K * v
        P_filtered = P - pt.outer(K, K) * F

        ll_inner = pt.switch(F_zero_flag, 0.0, pt.log(F) + v**2 / F)

        return a_filtered, P_filtered, pt.atleast_1d(y_hat), pt.atleast_2d(F), ll_inner

    def kalman_step(self, y, a, P, c, d, T, Z, R, H, Q):
        """
        Run one full Kalman filter step (update then predict) using the univariate treatment.

        Parameters
        ----------
        y : TensorVariable
            Observation vector at time t; NaN entries are treated as missing.
        a, P : TensorVariable
            State mean and covariance before conditioning on `y`.
        c, d, T, Z, R, H, Q : TensorVariable
            Statespace system matrices.

        Returns
        -------
        tuple
            ``(a_filtered, a_hat, obs_mu, P_filtered, P_hat, obs_cov, ll)``. ``obs_mu`` and ``obs_cov`` are the
            one-step forecast mean and covariance of the full observation vector, computed from `a` and `P`.
        """
        nan_mask = pt.isnan(y)

        # Zero the rows of Z and H (and the entries of y) belonging to missing observations, so those series
        # contribute nothing to the update.
        W = pt.set_subtensor(pt.eye(y.shape[0])[nan_mask, nan_mask], 0.0)
        Z_masked = W.dot(Z)
        H_masked = W.dot(H)
        y_masked = pt.set_subtensor(y[nan_mask], 0.0)

        # One-step forecast of the whole observation vector, from the state *before* any series is assimilated.
        # This mirrors the forecast mean/covariance reported by the multivariate filters. The scan's per-series
        # forecasts cannot be used here: each is conditioned on the earlier series of the same period and is scalar.
        obs_mu = Z_masked.dot(a) + d
        obs_cov = Z_masked.dot(P.dot(Z_masked.T)) + stabilize(H_masked, self.cov_jitter)

        result, _ = pytensor.scan(
            self._univariate_inner_filter_step,
            sequences=[y_masked, Z_masked, d, pt.diag(H_masked), nan_mask],
            outputs_info=[a, P, None, None, None],
            mode=get_mode(self.mode),
            name="univariate_inner_scan",
        )

        a_filtered, P_filtered, _, _, ll_inner = result
        a_filtered = a_filtered[-1]
        P_filtered = P_filtered[-1]

        P_filtered = stabilize(0.5 * (P_filtered + P_filtered.T), self.cov_jitter)
        a_hat, P_hat = self.predict(a=a_filtered, P=P_filtered, c=c, T=T, R=R, Q=Q)

        # One MVN_CONST term per assimilated series. Skipped series have ll_inner exactly 0, so they are counted
        # via the nonzero entries. NOTE(review): an observed series whose ll_inner is exactly 0 by coincidence
        # would also be left out.
        ll = -0.5 * ((pt.neq(ll_inner, 0).sum()) * MVN_CONST + ll_inner.sum())

        return a_filtered, a_hat, obs_mu, P_filtered, P_hat, obs_cov, ll
|