pymc_extras-0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101)
  1. pymc_extras/__init__.py +29 -0
  2. pymc_extras/distributions/__init__.py +40 -0
  3. pymc_extras/distributions/continuous.py +351 -0
  4. pymc_extras/distributions/discrete.py +399 -0
  5. pymc_extras/distributions/histogram_utils.py +163 -0
  6. pymc_extras/distributions/multivariate/__init__.py +3 -0
  7. pymc_extras/distributions/multivariate/r2d2m2cp.py +446 -0
  8. pymc_extras/distributions/timeseries.py +356 -0
  9. pymc_extras/gp/__init__.py +18 -0
  10. pymc_extras/gp/latent_approx.py +183 -0
  11. pymc_extras/inference/__init__.py +18 -0
  12. pymc_extras/inference/find_map.py +431 -0
  13. pymc_extras/inference/fit.py +44 -0
  14. pymc_extras/inference/laplace.py +570 -0
  15. pymc_extras/inference/pathfinder.py +134 -0
  16. pymc_extras/inference/smc/__init__.py +13 -0
  17. pymc_extras/inference/smc/sampling.py +451 -0
  18. pymc_extras/linearmodel.py +130 -0
  19. pymc_extras/model/__init__.py +0 -0
  20. pymc_extras/model/marginal/__init__.py +0 -0
  21. pymc_extras/model/marginal/distributions.py +276 -0
  22. pymc_extras/model/marginal/graph_analysis.py +372 -0
  23. pymc_extras/model/marginal/marginal_model.py +595 -0
  24. pymc_extras/model/model_api.py +56 -0
  25. pymc_extras/model/transforms/__init__.py +0 -0
  26. pymc_extras/model/transforms/autoreparam.py +434 -0
  27. pymc_extras/model_builder.py +759 -0
  28. pymc_extras/preprocessing/__init__.py +0 -0
  29. pymc_extras/preprocessing/standard_scaler.py +17 -0
  30. pymc_extras/printing.py +182 -0
  31. pymc_extras/statespace/__init__.py +13 -0
  32. pymc_extras/statespace/core/__init__.py +7 -0
  33. pymc_extras/statespace/core/compile.py +48 -0
  34. pymc_extras/statespace/core/representation.py +438 -0
  35. pymc_extras/statespace/core/statespace.py +2268 -0
  36. pymc_extras/statespace/filters/__init__.py +15 -0
  37. pymc_extras/statespace/filters/distributions.py +453 -0
  38. pymc_extras/statespace/filters/kalman_filter.py +820 -0
  39. pymc_extras/statespace/filters/kalman_smoother.py +126 -0
  40. pymc_extras/statespace/filters/utilities.py +59 -0
  41. pymc_extras/statespace/models/ETS.py +670 -0
  42. pymc_extras/statespace/models/SARIMAX.py +536 -0
  43. pymc_extras/statespace/models/VARMAX.py +393 -0
  44. pymc_extras/statespace/models/__init__.py +6 -0
  45. pymc_extras/statespace/models/structural.py +1651 -0
  46. pymc_extras/statespace/models/utilities.py +387 -0
  47. pymc_extras/statespace/utils/__init__.py +0 -0
  48. pymc_extras/statespace/utils/constants.py +74 -0
  49. pymc_extras/statespace/utils/coord_tools.py +0 -0
  50. pymc_extras/statespace/utils/data_tools.py +182 -0
  51. pymc_extras/utils/__init__.py +23 -0
  52. pymc_extras/utils/linear_cg.py +290 -0
  53. pymc_extras/utils/pivoted_cholesky.py +69 -0
  54. pymc_extras/utils/prior.py +200 -0
  55. pymc_extras/utils/spline.py +131 -0
  56. pymc_extras/version.py +11 -0
  57. pymc_extras/version.txt +1 -0
  58. pymc_extras-0.2.0.dist-info/LICENSE +212 -0
  59. pymc_extras-0.2.0.dist-info/METADATA +99 -0
  60. pymc_extras-0.2.0.dist-info/RECORD +101 -0
  61. pymc_extras-0.2.0.dist-info/WHEEL +5 -0
  62. pymc_extras-0.2.0.dist-info/top_level.txt +2 -0
  63. tests/__init__.py +13 -0
  64. tests/distributions/__init__.py +19 -0
  65. tests/distributions/test_continuous.py +185 -0
  66. tests/distributions/test_discrete.py +210 -0
  67. tests/distributions/test_discrete_markov_chain.py +258 -0
  68. tests/distributions/test_multivariate.py +304 -0
  69. tests/model/__init__.py +0 -0
  70. tests/model/marginal/__init__.py +0 -0
  71. tests/model/marginal/test_distributions.py +131 -0
  72. tests/model/marginal/test_graph_analysis.py +182 -0
  73. tests/model/marginal/test_marginal_model.py +867 -0
  74. tests/model/test_model_api.py +29 -0
  75. tests/statespace/__init__.py +0 -0
  76. tests/statespace/test_ETS.py +411 -0
  77. tests/statespace/test_SARIMAX.py +405 -0
  78. tests/statespace/test_VARMAX.py +184 -0
  79. tests/statespace/test_coord_assignment.py +116 -0
  80. tests/statespace/test_distributions.py +270 -0
  81. tests/statespace/test_kalman_filter.py +326 -0
  82. tests/statespace/test_representation.py +175 -0
  83. tests/statespace/test_statespace.py +818 -0
  84. tests/statespace/test_statespace_JAX.py +156 -0
  85. tests/statespace/test_structural.py +829 -0
  86. tests/statespace/utilities/__init__.py +0 -0
  87. tests/statespace/utilities/shared_fixtures.py +9 -0
  88. tests/statespace/utilities/statsmodel_local_level.py +42 -0
  89. tests/statespace/utilities/test_helpers.py +310 -0
  90. tests/test_blackjax_smc.py +222 -0
  91. tests/test_find_map.py +98 -0
  92. tests/test_histogram_approximation.py +109 -0
  93. tests/test_laplace.py +238 -0
  94. tests/test_linearmodel.py +208 -0
  95. tests/test_model_builder.py +306 -0
  96. tests/test_pathfinder.py +45 -0
  97. tests/test_pivoted_cholesky.py +24 -0
  98. tests/test_printing.py +98 -0
  99. tests/test_prior_from_trace.py +172 -0
  100. tests/test_splines.py +77 -0
  101. tests/utils.py +31 -0
pymc_extras/statespace/filters/kalman_filter.py (new file)
@@ -0,0 +1,820 @@
from abc import ABC

import numpy as np
import pytensor
import pytensor.tensor as pt

from pymc.pytensorf import constant_fold
from pytensor.compile.mode import get_mode
from pytensor.graph.basic import Variable
from pytensor.raise_op import Assert
from pytensor.tensor import TensorVariable
from pytensor.tensor.slinalg import solve_triangular

from pymc_extras.statespace.filters.utilities import (
    quad_form_sym,
    split_vars_into_seq_and_nonseq,
    stabilize,
)
from pymc_extras.statespace.utils.constants import JITTER_DEFAULT, MISSING_FILL

MVN_CONST = pt.log(2 * pt.constant(np.pi, dtype="float64"))
PARAM_NAMES = ["c", "d", "T", "Z", "R", "H", "Q"]

assert_time_varying_dim_correct = Assert(
    "The first dimension of a time varying matrix (the time dimension) must be "
    "equal to the first dimension of the data (the time dimension)."
)


class BaseFilter(ABC):
    def __init__(self, mode=None):
        """
        Kalman Filter.

        Parameters
        ----------
        mode : str, optional
            The mode used for Pytensor compilation. Defaults to None.

        Notes
        -----
        The BaseFilter class is an abstract base class (ABC) for implementing Kalman filters.
        It defines common attributes and methods used by Kalman filter implementations.

        Attributes
        ----------
        mode : str or None
            The mode used for Pytensor compilation.

        seq_names : list[str]
            A list of names representing time-varying statespace matrices. That is, inputs that will need to be
            provided to the `sequences` argument of `pytensor.scan`.

        non_seq_names : list[str]
            A list of names representing static statespace matrices. That is, inputs that will need to be provided
            to the `non_sequences` argument of `pytensor.scan`.
        """

        self.mode: str = mode
        self.seq_names: list[str] = []
        self.non_seq_names: list[str] = []

        self.n_states = None
        self.n_posdef = None
        self.n_endog = None

        self.missing_fill_value: float | None = None
        self.cov_jitter = None

    def check_params(self, data, a0, P0, c, d, T, Z, R, H, Q):
        """
        Apply any checks on validity of inputs. For most filters this is just the identity function.
        """
        return data, a0, P0, c, d, T, Z, R, H, Q

    @staticmethod
    def add_check_on_time_varying_shapes(
        data: TensorVariable, sequence_params: list[TensorVariable]
    ) -> list[Variable]:
        """
        Insert a check that time-varying matrices match the data shape to the computational graph.

        If any matrices are time-varying, they need to have the same length as the data. This function wraps each
        element of `sequence_params` in an `Assert` `Op` that makes sure all inputs have the correct shape.

        Parameters
        ----------
        data : TensorVariable
            The tensor representing the data.

        sequence_params : list[TensorVariable]
            A list of tensors to be provided to `pytensor.scan` as `sequences`.

        Returns
        -------
        list[TensorVariable]
            A list of tensors wrapped in an `Assert` `Op` that checks the shape of the 0th dimension on each is equal
            to the shape of the 0th dimension on the data.
        """
        # TODO: The PytensorRepresentation object puts the time dimension last, should the reshaping happen here in
        #  the Kalman filter, or in the StateSpaceModel, before passing into the KF?

        params_with_assert = [
            assert_time_varying_dim_correct(param, pt.eq(param.shape[0], data.shape[0]))
            for param in sequence_params
        ]

        return params_with_assert

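For context, the snippet below is an editor's illustrative sketch (not part of the package) of how an `Assert` Op like `assert_time_varying_dim_correct` behaves: the wrapped tensor evaluates normally when the condition holds and raises at runtime when it does not. The variable names are made up for the example.

# --- editor's illustrative sketch, not part of kalman_filter.py ---
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.raise_op import Assert

data = pt.matrix("data")                      # (time, k_endog)
T_seq = pt.tensor3("T_seq")                   # hypothetical time-varying transition matrices
check = Assert("time dimension of T_seq must match the data")
T_checked = check(T_seq, pt.eq(T_seq.shape[0], data.shape[0]))

f = pytensor.function([data, T_seq], T_checked)
f(np.zeros((5, 1)), np.zeros((5, 2, 2)))      # passes: both have 5 time steps
# f(np.zeros((5, 1)), np.zeros((4, 2, 2)))    # would raise: mismatched time dimensions
# --- end sketch ---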
    def unpack_args(self, args) -> tuple:
        """
        The order of inputs to the inner scan function is not known in advance, since some, all, or none of the input
        matrices can be time varying. The order in which arguments are fed to the inner function is: sequences,
        outputs_info, non-sequences. This function works out which matrices are where, and returns them in the
        standardized order expected by the kalman_step function.

        The standard order is: y, a0, P0, c, d, T, Z, R, H, Q
        """
        # If there are no sequence parameters (all params are static),
        # no changes are needed, params will be in order.
        args = list(args)
        n_seq = len(self.seq_names)
        if n_seq == 0:
            return tuple(args)

        # The first arg is always y
        y = args.pop(0)

        # There are always two outputs_info wedged between the seqs and non_seqs
        seqs, (a0, P0), non_seqs = args[:n_seq], args[n_seq : n_seq + 2], args[n_seq + 2 :]
        return_ordered = []
        for name in ["c", "d", "T", "Z", "R", "H", "Q"]:
            if name in self.seq_names:
                idx = self.seq_names.index(name)
                return_ordered.append(seqs[idx])
            else:
                idx = self.non_seq_names.index(name)
                return_ordered.append(non_seqs[idx])

        c, d, T, Z, R, H, Q = return_ordered

        return y, a0, P0, c, d, T, Z, R, H, Q

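The argument ordering that `unpack_args` undoes is a property of `pytensor.scan` itself. The toy recursion below (an editor's sketch, not package code) shows scan calling its step function with sequences first, then the recurrent outputs_info state, then non_sequences.

# --- editor's illustrative sketch, not part of kalman_filter.py ---
import pytensor
import pytensor.tensor as pt

y_seq = pt.vector("y_seq")          # plays the role of the data sequence
decay = pt.scalar("decay")          # a static, non-sequence parameter


def step(y_t, a_prev, decay_):
    # scan always calls step(*sequences_t, *recurrent_states, *non_sequences)
    return decay_ * a_prev + y_t


states, _ = pytensor.scan(
    step,
    sequences=[y_seq],
    outputs_info=[pt.zeros(())],
    non_sequences=[decay],
)
f = pytensor.function([y_seq, decay], states)
print(f([1.0, 2.0, 3.0], 0.5))      # [1.0, 2.5, 4.25]
# --- end sketch ---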
    def build_graph(
        self,
        data,
        a0,
        P0,
        c,
        d,
        T,
        Z,
        R,
        H,
        Q,
        mode=None,
        return_updates=False,
        missing_fill_value=None,
        cov_jitter=None,
    ) -> list[TensorVariable] | tuple[list[TensorVariable], dict]:
        """
        Construct the computation graph for the Kalman filter. See [1] for details.

        Parameters
        ----------
        data : TensorVariable
            Data to be filtered.

        mode : str, optional
            Pytensor compile mode, passed to pytensor.scan.

        return_updates : bool, default False
            Whether to return updates associated with the pytensor scan. Should only be required for debugging
            purposes.

        missing_fill_value : float, default -9999
            Fill value used to mark missing values. Used to avoid PyMC's automatic interpolation, which conflicts
            with the Kalman filter's hidden state inference. Change if your data happens to have legitimate values
            of -9999.

        cov_jitter : float, default 1e-8, or 1e-6 if pytensor.config.floatX is float32
            The Kalman filter is known to be numerically unstable, especially at half precision. This value is added to
            the diagonal of every covariance matrix -- predicted, filtered, and smoothed -- at every step, to ensure
            all matrices are strictly positive semi-definite.

            Obviously, if this can be zero, that's best. In general:
                - Having measurement error makes Kalman filters more robust. A large source of numerical errors comes
                  from the filtered and smoothed matrices having a zero in the (0, 0) position, which always occurs
                  when there is no measurement error.

                - The univariate filter is more robust than other filters, and can tolerate a lower jitter value.

        References
        ----------
        .. [1] Koopman, Siem Jan, Neil Shephard, and Jurgen A. Doornik. 1999.
           Statistical Algorithms for Models in State Space Using SsfPack 2.2.
           Econometrics Journal 2 (1): 107-60. doi:10.1111/1368-423X.00023.
        """
        if missing_fill_value is None:
            missing_fill_value = MISSING_FILL
        if cov_jitter is None:
            cov_jitter = JITTER_DEFAULT

        self.mode = mode
        self.missing_fill_value = missing_fill_value
        self.cov_jitter = cov_jitter

        [R_shape] = constant_fold([R.shape], raise_not_constant=False)
        [Z_shape] = constant_fold([Z.shape], raise_not_constant=False)

        self.n_states, self.n_shocks = R_shape[-2:]
        self.n_endog = Z_shape[-2]

        data, a0, P0, *params = self.check_params(data, a0, P0, c, d, T, Z, R, H, Q)

        sequences, non_sequences, seq_names, non_seq_names = split_vars_into_seq_and_nonseq(
            params, PARAM_NAMES
        )

        self.seq_names = seq_names
        self.non_seq_names = non_seq_names

        if len(sequences) > 0:
            sequences = self.add_check_on_time_varying_shapes(data, sequences)

        results, updates = pytensor.scan(
            self.kalman_step,
            sequences=[data, *sequences],
            outputs_info=[None, a0, None, None, P0, None, None],
            non_sequences=non_sequences,
            name="forward_kalman_pass",
            mode=get_mode(self.mode),
            strict=False,
        )

        filter_results = self._postprocess_scan_results(results, a0, P0, n=data.type.shape[0])

        if return_updates:
            return filter_results, updates
        return filter_results

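As a usage illustration (an editor's sketch based on the signature above, not an official example), `build_graph` can be called directly with symbolic statespace matrices; the statespace models in this package normally do this wiring for you. The shapes and matrix values below are invented for the example.

# --- editor's illustrative sketch, not part of kalman_filter.py ---
import numpy as np
import pytensor.tensor as pt
from pymc_extras.statespace.filters.kalman_filter import StandardFilter

k_states, k_endog = 2, 1
data = pt.matrix("data")                                  # (n_timesteps, k_endog)
a0 = pt.zeros(k_states)
P0 = pt.eye(k_states)
c = pt.zeros(k_states)
d = pt.zeros(k_endog)
T = pt.as_tensor(np.array([[1.0, 1.0], [0.0, 1.0]]))      # local linear trend transition
Z = pt.as_tensor(np.array([[1.0, 0.0]]))
R = pt.eye(k_states)
H = pt.eye(k_endog) * 0.5
Q = pt.eye(k_states) * 0.1

outputs = StandardFilter().build_graph(data, a0, P0, c, d, T, Z, R, H, Q)
(filtered_states, predicted_states, observed_states,
 filtered_covs, predicted_covs, observed_covs, loglike_obs) = outputs
# --- end sketch ---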
    def _postprocess_scan_results(self, results, a0, P0, n) -> list[TensorVariable]:
        """
        Transform the values returned by the Kalman Filter scan into a form expected by users. In particular:

        1. Append the initial state and covariance matrix to their respective Kalman predictions. This matches the
           output returned by Statsmodels state space models.

        2. Discard the last state and covariance matrix from the Kalman predictions. This is because the Kalman filter
           starts with the (random variable) initial state x0, and treats it as a predicted state. The first step (t=0)
           will filter x0 to make filtered_states[0], then do a predict step to make predicted_states[1]. This means
           the last step (t=T) predicted state will be a *forecast* for T+1. If the user wants this forecast, they
           should use the forecast method.

        3. Squeeze away extra dimensions from the filtered and predicted states, as well as the likelihoods.
        """
        (
            filtered_states,
            predicted_states,
            observed_states,
            filtered_covariances,
            predicted_covariances,
            observed_covariances,
            loglike_obs,
        ) = results

        predicted_states = pt.concatenate(
            [pt.expand_dims(a0, axis=(0,)), predicted_states[:-1]], axis=0
        )
        predicted_covariances = pt.concatenate(
            [pt.expand_dims(P0, axis=(0,)), predicted_covariances[:-1]], axis=0
        )

        filtered_states = pt.specify_shape(filtered_states, (n, self.n_states))
        filtered_states.name = "filtered_states"

        predicted_states = pt.specify_shape(predicted_states, (n, self.n_states))
        predicted_states.name = "predicted_states"

        observed_states = pt.specify_shape(observed_states, (n, self.n_endog))
        observed_states.name = "observed_states"

        filtered_covariances = pt.specify_shape(
            filtered_covariances, (n, self.n_states, self.n_states)
        )
        filtered_covariances.name = "filtered_covariances"

        predicted_covariances = pt.specify_shape(
            predicted_covariances, (n, self.n_states, self.n_states)
        )
        predicted_covariances.name = "predicted_covariances"

        observed_covariances = pt.specify_shape(
            observed_covariances, (n, self.n_endog, self.n_endog)
        )
        observed_covariances.name = "observed_covariances"

        loglike_obs = pt.specify_shape(loglike_obs.squeeze(), (n,))
        loglike_obs.name = "loglike_obs"

        filter_results = [
            filtered_states,
            predicted_states,
            observed_states,
            filtered_covariances,
            predicted_covariances,
            observed_covariances,
            loglike_obs,
        ]

        return filter_results

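The prepend-and-drop step described in point 2 above is easiest to see with plain arrays; the following is an editor's sketch of the same re-alignment, with made-up values.

# --- editor's illustrative sketch, not part of kalman_filter.py ---
import numpy as np

scan_predicted = np.arange(1.0, 6.0)[:, None]   # scan output: predictions for t = 1 .. 5
a0 = np.zeros((1, 1))                            # initial state, treated as the t = 0 prediction
predicted = np.concatenate([a0, scan_predicted[:-1]], axis=0)
# `predicted` now aligns with the data: a[0], a[1|0], ..., a[4|3].
# The final scan entry (a forecast for t = 5) is dropped, as described above.
# --- end sketch ---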
    def handle_missing_values(
        self, y, Z, H
    ) -> tuple[TensorVariable, TensorVariable, TensorVariable, float]:
        """
        Handle missing values in the observation data `y`.

        Adjusts the design matrix `Z` and the observation noise covariance matrix `H` by removing rows and/or columns
        associated with the data that is not observed at this iteration. Missing values are replaced with zeros to
        prevent propagating NaNs through the computation.

        Returns a binary flag tensor `all_nan_flag`, indicating whether all values in the observation data are missing.
        This flag is used for numerical adjustments in the update method.

        Parameters
        ----------
        y : TensorVariable
            The observation data at time t.
        Z : TensorVariable
            The design matrix.
        H : TensorVariable
            The observation noise covariance matrix.

        Returns
        -------
        y_masked : TensorVariable
            Observation vector with missing values replaced by zeros.

        Z_masked : TensorVariable
            Design matrix adjusted to exclude the missing states from the information set of observed variables in
            the update step.

        H_masked : TensorVariable
            Noise covariance matrix, adjusted to exclude the missing states.

        all_nan_flag : float
            1.0 if the entire observation vector is missing, otherwise 0.0.

        References
        ----------
        .. [1] Durbin, J., and S. J. Koopman. Time Series Analysis by State Space Methods.
           2nd ed, Oxford University Press, 2012.
        """
        nan_mask = pt.or_(pt.isnan(y), pt.eq(y, self.missing_fill_value))
        all_nan_flag = pt.all(nan_mask).astype(pytensor.config.floatX)
        W = pt.diag(pt.bitwise_not(nan_mask).astype(pytensor.config.floatX))

        Z_masked = W.dot(Z)
        H_masked = W.dot(H)
        y_masked = pt.set_subtensor(y[nan_mask], 0.0)

        return y_masked, Z_masked, H_masked, all_nan_flag

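The masking trick above zeroes out the rows of `Z` and `H` tied to unobserved series, so those series contribute nothing to the update. An editor's sketch with concrete numbers:

# --- editor's illustrative sketch, not part of kalman_filter.py ---
import numpy as np

y = np.array([1.2, np.nan])                # second series is missing at this time step
nan_mask = np.isnan(y)
W = np.diag((~nan_mask).astype(float))     # zero out the missing row
Z = np.array([[1.0, 0.0], [0.0, 1.0]])
H = np.eye(2) * 0.5

Z_masked = W @ Z                           # second row becomes zeros
H_masked = W @ H
y_masked = np.where(nan_mask, 0.0, y)      # NaNs replaced by 0 so nothing propagates
# --- end sketch ---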
    @staticmethod
    def predict(a, P, c, T, R, Q) -> tuple[TensorVariable, TensorVariable]:
        r"""
        Perform the prediction step of the Kalman filter.

        This function computes the one-step forecast of the hidden states and the covariance matrix of the forecasted
        states, based on the current state estimates and model parameters. For computational stability, the estimated
        covariance matrix is forced to be symmetric by averaging it with its own transpose. The prediction equations
        are:

        .. math::

            \begin{align}
            a_{t+1 | t} &= T_t a_{t | t} + c_t \\
            P_{t+1 | t} &= T_t P_{t | t} T_t^T + R_t Q_t R_t^T
            \end{align}

        Parameters
        ----------
        a : TensorVariable
            The current state vector estimate computed by the update step, a[t | t].
        P : TensorVariable
            The current covariance matrix estimate computed by the update step, P[t | t].
        c : TensorVariable
            The hidden state intercept/bias vector.
        T : TensorVariable
            The state transition matrix.
        R : TensorVariable
            The selection matrix.
        Q : TensorVariable
            The state innovation covariance matrix.

        Returns
        -------
        a_hat : TensorVariable
            One-step forecast of the hidden states.
        P_hat : TensorVariable
            Covariance matrix of the forecasted hidden states.

        References
        ----------
        .. [1] Durbin, J., and S. J. Koopman. Time Series Analysis by State Space Methods.
           2nd ed, Oxford University Press, 2012.
        """
        a_hat = T.dot(a) + c
        P_hat = quad_form_sym(T, P) + quad_form_sym(R, Q)

        return a_hat, P_hat

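A numpy rendering of the prediction equations (editor's sketch, with made-up matrices; `quad_form_sym(A, X)` is written out as the symmetric quadratic form A @ X @ A.T that the docstring equations describe):

# --- editor's illustrative sketch, not part of kalman_filter.py ---
import numpy as np

a = np.array([0.5, 0.1])                   # a[t | t]
P = np.eye(2) * 0.2                        # P[t | t]
c = np.zeros(2)
T = np.array([[1.0, 1.0], [0.0, 1.0]])
R = np.eye(2)
Q = np.eye(2) * 0.05

a_hat = T @ a + c                          # a[t+1 | t]
P_hat = T @ P @ T.T + R @ Q @ R.T          # P[t+1 | t]
# --- end sketch ---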
    @staticmethod
    def update(
        a, P, y, d, Z, H, all_nan_flag
    ) -> tuple[TensorVariable, TensorVariable, TensorVariable, TensorVariable, TensorVariable]:
        r"""
        Perform the update step of the Kalman filter.

        This function updates the state vector and covariance matrix estimates based on the current observation data,
        previous predictions, and model parameters. The filtering equations are:

        .. math::

            \begin{align}
            \hat{y}_t &= Z_t a_{t | t-1} + d_t \\
            v_t &= y_t - \hat{y}_t \\
            F_t &= Z_t P_{t | t-1} Z_t^T + H_t \\
            a_{t|t} &= a_{t | t-1} + P_{t | t-1} Z_t^T F_t^{-1} v_t \\
            P_{t|t} &= P_{t | t-1} - P_{t | t-1} Z_t^T F_t^{-1} Z_t P_{t | t-1}
            \end{align}

        Parameters
        ----------
        a : TensorVariable
            The current state vector estimate, conditioned on information up to time t-1.
        P : TensorVariable
            The current covariance matrix estimate, conditioned on information up to time t-1.
        y : TensorVariable
            The observation data at time t.
        d : TensorVariable
            The observation bias vector d.
        Z : TensorVariable
            The design matrix Z.
        H : TensorVariable
            The observation noise covariance matrix H.
        all_nan_flag : TensorVariable
            A binary flag tensor indicating whether all values in the observation data are missing.

        Returns
        -------
        tuple[TensorVariable, TensorVariable, TensorVariable, TensorVariable, TensorVariable]
            A tuple containing the updated state vector `a_filtered`, the updated covariance matrix `P_filtered`, the
            predicted observation `obs_mu`, the predicted observation covariance matrix `obs_cov`, and the
            log-likelihood `ll`.
        """
        raise NotImplementedError

    def kalman_step(self, *args) -> tuple:
        r"""
        Performs a single iteration of the Kalman filter, which is composed of two steps: an update step and a
        prediction step. The timing convention follows [1], in which initial state and covariance estimates a0 and P0
        are taken to be predictions. As a result, the update step is applied first. The update step computes:

        .. math::

            \begin{align}
            \hat{y}_t &= Z_t a_{t | t-1} \\
            v_t &= y_t - \hat{y}_t \\
            F_t &= Z_t P_{t | t-1} Z_t^T + H_t \\
            a_{t|t} &= a_{t | t-1} + P_{t | t-1} Z_t^T F_t^{-1} v_t \\
            P_{t|t} &= P_{t | t-1} - P_{t | t-1} Z_t^T F_t^{-1} Z_t P_{t | t-1}
            \end{align}

        Where the quantities :math:`a_{t|t}` and :math:`P_{t|t}` are the best linear estimates of the hidden states
        at time t, incorporating all information up to and including the observation :math:`y_t`. After the update
        step, new one-step forecasts of the hidden states can be obtained by applying the model transition dynamics in
        the prediction step:

        .. math::

            \begin{align}
            a_{t+1 | t} &= T_t a_{t | t} \\
            P_{t+1 | t} &= T_t P_{t | t} T_t^T + R_t Q_t R_t^T
            \end{align}

        Recursive application of these two steps results in the best linear estimate of the hidden states, including
        missing values and observations subject to measurement error.

        Parameters
        ----------
        Kalman filter inputs:
            y, a, P, c, d, T, Z, R, H, Q. See the docstring for the Kalman filter class for details.

        Returns
        -------
        a_filtered : TensorVariable
            Best linear estimate of hidden states given all information up to and including the present
            observation, a[t | t].

        a_hat : TensorVariable
            One-step forecast of next-period hidden states given all information up to and including the present
            observation, a[t+1 | t].

        obs_mu : TensorVariable
            Estimate of the current observation given all information available prior to the current state,
            d + Z @ a[t | t-1].

        P_filtered : TensorVariable
            Best linear estimate of the covariance between hidden states, given all information up to and including
            the present observation, P[t | t].

        P_hat : TensorVariable
            Covariance between the one-step forecasted hidden states given all information up to and including the
            present observation, P[t+1 | t].

        obs_cov : TensorVariable
            Covariance between estimated present observations, given all information available prior to the current
            state, Z @ P[t | t-1] @ Z.T + H.

        ll : float
            Likelihood of the time t observation vector under the multivariate normal distribution parameterized by
            `obs_mu` and `obs_cov`.

        References
        ----------
        .. [1] Durbin, J., and S. J. Koopman. Time Series Analysis by State Space Methods.
           2nd ed, Oxford University Press, 2012.
        """
        y, a, P, c, d, T, Z, R, H, Q = self.unpack_args(args)
        y_masked, Z_masked, H_masked, all_nan_flag = self.handle_missing_values(y, Z, H)

        a_filtered, P_filtered, obs_mu, obs_cov, ll = self.update(
            y=y_masked, a=a, d=d, P=P, Z=Z_masked, H=H_masked, all_nan_flag=all_nan_flag
        )

        P_filtered = stabilize(P_filtered, self.cov_jitter)

        a_hat, P_hat = self.predict(a=a_filtered, P=P_filtered, c=c, T=T, R=R, Q=Q)
        outputs = (a_filtered, a_hat, obs_mu, P_filtered, P_hat, obs_cov, ll)

        return outputs


class StandardFilter(BaseFilter):
    """
    Basic Kalman Filter
    """

    def update(self, a, P, y, d, Z, H, all_nan_flag):
        """
        Compute one-step forecasts for observed states conditioned on information up to, but not including, the current
        timestep, `y_hat`, along with the forecast covariance matrix, `F`. Marginalize over observed states to obtain
        the best linear estimate of the unobserved states, `a_filtered`, as well as the associated covariance matrix,
        `P_filtered`, conditioned on all information up to and including the present.

        Derivation of the Kalman filter, along with a deeper discussion of the computational elements, can be found in
        [1].

        Parameters
        ----------
        a : TensorVariable
            The current state vector estimate, conditioned on information up to time t-1.

        P : TensorVariable
            The current covariance matrix estimate, conditioned on information up to time t-1.

        y : TensorVariable
            Observations at time t.

        d : TensorVariable
            Observed state bias term.

        Z : TensorVariable
            Linear map between unobserved and observed states.

        H : TensorVariable
            Observation noise covariance matrix.

        all_nan_flag : TensorVariable
            A flag indicating whether all elements in the data `y` are NaNs.

        Returns
        -------
        tuple[TensorVariable, TensorVariable, TensorVariable, TensorVariable, float]
            A tuple containing the updated state vector `a_filtered`, the updated covariance matrix `P_filtered`,
            the one-step forecast mean `y_hat`, the one-step forecast covariance matrix `F`, and the log-likelihood of
            the data, given the one-step forecasts, `ll`.

        References
        ----------
        .. [1] Durbin, J., and S. J. Koopman. Time Series Analysis by State Space Methods.
           2nd ed, Oxford University Press, 2012.
        """
        y_hat = d + Z.dot(a)
        v = y - y_hat

        PZT = P.dot(Z.T)
        F = Z.dot(PZT) + stabilize(H, self.cov_jitter)

        K = pt.linalg.solve(F.T, PZT.T, assume_a="pos", check_finite=False).T
        I_KZ = pt.eye(self.n_states) - K.dot(Z)

        a_filtered = a + K.dot(v)
        P_filtered = quad_form_sym(I_KZ, P) + quad_form_sym(K, H)

        F_inv_v = pt.linalg.solve(F, v, assume_a="pos", check_finite=False)
        inner_term = v.T @ F_inv_v

        F_logdet = pt.log(pt.linalg.det(F))

        ll = pt.switch(
            all_nan_flag,
            0.0,
            -0.5 * (MVN_CONST + F_logdet + inner_term).ravel()[0],
        )

        return a_filtered, P_filtered, y_hat, F, ll


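The same update equations in numpy, for a two-state, one-observation model (an editor's sketch with made-up values; the Joseph-form covariance update mirrors the two `quad_form_sym` terms above):

# --- editor's illustrative sketch, not part of kalman_filter.py ---
import numpy as np

a = np.array([0.5, 0.1])                   # a[t | t-1]
P = np.eye(2) * 0.2                        # P[t | t-1]
d = np.zeros(1)
Z = np.array([[1.0, 0.0]])
H = np.eye(1) * 0.5
y = np.array([0.8])

y_hat = d + Z @ a
v = y - y_hat
F = Z @ P @ Z.T + H
K = P @ Z.T @ np.linalg.inv(F)
I_KZ = np.eye(2) - K @ Z

a_filtered = a + K @ v
P_filtered = I_KZ @ P @ I_KZ.T + K @ H @ K.T        # Joseph form, always symmetric
ll = -0.5 * (np.log(2 * np.pi) + np.log(np.linalg.det(F)) + v @ np.linalg.solve(F, v))
# --- end sketch ---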
class SquareRootFilter(BaseFilter):
    """
    Kalman filter with Cholesky factorization

    Kalman filter implementation using a Cholesky factorization plus pt.solve_triangular in an attempt to speed up
    inversion of the observation covariance matrix `F`.

    """

    def predict(self, a, P, c, T, R, Q):
        """
        Compute one-step forecasts for the hidden states conditioned on information up to, but not including, the
        current timestep, `a_hat`, along with the forecast covariance matrix, `P_hat`.

        .. warning::
            Very important -- In this function, $P$ is the **Cholesky factor** of the covariance matrix, not the
            covariance matrix itself. The name `P` is kept for consistency with the superclass.
        """
        # Rename P to P_chol for clarity
        P_chol = P

        a_hat = T.dot(a) + c
        Q_chol = pt.linalg.cholesky(Q, lower=True)

        M = pt.horizontal_stack(T @ P_chol, R @ Q_chol).T
        R_decomp = pt.linalg.qr(M, mode="r")
        P_chol_hat = R_decomp[: self.n_states, : self.n_states].T

        return a_hat, P_chol_hat

    def update(self, a, P, y, d, Z, H, all_nan_flag):
        """
        Compute posterior estimates of the hidden state distributions conditioned on the observed data, up to and
        including the present timestep. Also compute the log-likelihood of the data given the one-step forecasts.

        .. warning::
            Very important -- In this function, $P$ is the **Cholesky factor** of the covariance matrix, not the
            covariance matrix itself. The name `P` is kept for consistency with the superclass.
        """

        # Rename P to P_chol for clarity
        P_chol = P

        y_hat = Z.dot(a) + d
        v = y - y_hat

        H_chol = pytensor.ifelse(pt.all(pt.eq(H, 0.0)), H, pt.linalg.cholesky(H, lower=True))

        # The following notation comes from https://ipnpr.jpl.nasa.gov/progress_report/42-233/42-233A.pdf
        # Construct upper-triangular block matrix A = [[chol(H), Z @ L_pred],
        #                                              [0,       L_pred    ]]
        # The Schur decomposition of this matrix will be B (upper triangular). We are
        # more interested in B^T:
        # Structure of B^T = [[chol(F),     0               ],
        #                     [K @ chol(F), chol(P_filtered)]]
        zeros = pt.zeros((self.n_states, self.n_endog))
        upper = pt.horizontal_stack(H_chol, Z @ P_chol)
        lower = pt.horizontal_stack(zeros, P_chol)
        A_T = pt.vertical_stack(upper, lower)
        B = pt.linalg.qr(A_T.T, mode="r").T

        F_chol = B[: self.n_endog, : self.n_endog]
        K_F_chol = B[self.n_endog :, : self.n_endog]
        P_chol_filtered = B[self.n_endog :, self.n_endog :]

        def compute_non_degenerate(P_chol_filtered, F_chol, K_F_chol, v):
            a_filtered = a + K_F_chol @ solve_triangular(F_chol, v, lower=True)

            inner_term = solve_triangular(
                F_chol, solve_triangular(F_chol, v, lower=True), lower=True
            )
            loss = (v.T @ inner_term).ravel()

            # abs necessary because we're not guaranteed a positive diagonal from the Schur decomposition
            logdet = 2 * pt.log(pt.abs(pt.diag(F_chol))).sum()

            ll = -0.5 * (self.n_endog * (MVN_CONST + logdet) + loss)[0]

            return [a_filtered, P_chol_filtered, ll]

        def compute_degenerate(P_chol_filtered, F_chol, K_F_chol, v):
            """
            If F is zero (usually because there were no observations this period), then we want:
            K = 0, a = a, P = P, ll = 0
            """
            return [a, P_chol, pt.zeros(())]

        [a_filtered, P_chol_filtered, ll] = pytensor.ifelse(
            pt.eq(all_nan_flag, 1.0),
            compute_degenerate(P_chol_filtered, F_chol, K_F_chol, v),
            compute_non_degenerate(P_chol_filtered, F_chol, K_F_chol, v),
        )

        a_filtered = pt.specify_shape(a_filtered, (self.n_states,))
        P_chol_filtered = pt.specify_shape(P_chol_filtered, (self.n_states, self.n_states))

        return a_filtered, P_chol_filtered, y_hat, F_chol, ll

    def _postprocess_scan_results(self, results, a0, P0, n) -> list[TensorVariable]:
        """
        Convert the Cholesky factors of the covariance matrices back to the covariance matrices themselves.
        """
        results = super()._postprocess_scan_results(results, a0, P0, n)
        (
            filtered_states,
            predicted_states,
            observed_states,
            filtered_covariances_cholesky,
            predicted_covariances_cholesky,
            observed_covariances_cholesky,
            loglike_obs,
        ) = results

        def square_sequence(L, k):
            X = pt.einsum("...ij,...kj->...ik", L, L.copy())
            X = pt.specify_shape(X, (n, k, k))
            return X

        filtered_covariances = square_sequence(filtered_covariances_cholesky, k=self.n_states)
        predicted_covariances = square_sequence(predicted_covariances_cholesky, k=self.n_states)
        observed_covariances = square_sequence(observed_covariances_cholesky, k=self.n_endog)

        return [
            filtered_states,
            predicted_states,
            observed_states,
            filtered_covariances,
            predicted_covariances,
            observed_covariances,
            loglike_obs,
        ]


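Why the QR step in `predict` reproduces the usual covariance prediction: if M stacks the transposed blocks T @ L_P and R @ L_Q, then M.T @ M equals T P T^T + R Q R^T, so the R factor of M's QR decomposition is (up to signs) the transposed Cholesky factor of the predicted covariance. An editor's numpy sketch, with made-up matrices:

# --- editor's illustrative sketch, not part of kalman_filter.py ---
import numpy as np

T = np.array([[1.0, 1.0], [0.0, 1.0]])
R = np.eye(2)
P = np.eye(2) * 0.2
Q = np.eye(2) * 0.05

L_P = np.linalg.cholesky(P)
L_Q = np.linalg.cholesky(Q)

M = np.hstack([T @ L_P, R @ L_Q]).T        # same stacking as in SquareRootFilter.predict
R_qr = np.linalg.qr(M, mode="r")           # upper-triangular factor

P_hat_from_qr = R_qr.T @ R_qr
P_hat_direct = T @ P @ T.T + R @ Q @ R.T
assert np.allclose(P_hat_from_qr, P_hat_direct)
# --- end sketch ---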
class UnivariateFilter(BaseFilter):
    """
    The univariate Kalman filter, described in [1], section 6.4.2, avoids inversion of the F matrix, as well as two
    matrix multiplications, at the cost of an additional loop. Note that the name doesn't mean there's only one
    observed time series. This is called univariate because it updates the state mean and covariance matrices one
    variable at a time, using an inner-inner loop.

    This is useful when states are perfectly observed, because the F matrix can easily become degenerate in these
    cases.

    References
    ----------
    .. [1] Durbin, J., and S. J. Koopman. Time Series Analysis by State Space Methods.
       2nd ed, Oxford University Press, 2012.

    """

    def _univariate_inner_filter_step(self, y, Z_row, d_row, sigma_H, nan_flag, a, P):
        y_hat = Z_row.dot(a) + d_row
        v = y - y_hat

        PZT = P.dot(Z_row.T)
        F = Z_row.dot(PZT) + sigma_H

        # Set the zero flag for F first, then jitter it to avoid a divide-by-zero NaN later
        F_zero_flag = pt.or_(pt.eq(F, 0), nan_flag)
        F = F + self.cov_jitter

        # If F is zero (implies y is NaN or another degenerate case), then we want:
        # K = 0, a = a, P = P, ll = 0
        K = PZT / F * (1 - F_zero_flag)

        a_filtered = a + K * v
        P_filtered = P - pt.outer(K, K) * F

        ll_inner = pt.switch(F_zero_flag, 0.0, pt.log(F) + v**2 / F)

        return a_filtered, P_filtered, pt.atleast_1d(y_hat), pt.atleast_2d(F), ll_inner

    def kalman_step(self, y, a, P, c, d, T, Z, R, H, Q):
        nan_mask = pt.isnan(y)

        W = pt.set_subtensor(pt.eye(y.shape[0])[nan_mask, nan_mask], 0.0)
        Z_masked = W.dot(Z)
        H_masked = W.dot(H)
        y_masked = pt.set_subtensor(y[nan_mask], 0.0)

        result, updates = pytensor.scan(
            self._univariate_inner_filter_step,
            sequences=[y_masked, Z_masked, d, pt.diag(H_masked), nan_mask],
            outputs_info=[a, P, None, None, None],
            mode=get_mode(self.mode),
            name="univariate_inner_scan",
        )

        a_filtered, P_filtered, obs_mu, obs_cov, ll_inner = result
        a_filtered, P_filtered, obs_mu, obs_cov = (
            a_filtered[-1],
            P_filtered[-1],
            obs_mu[-1],
            obs_cov[-1],
        )

        P_filtered = stabilize(0.5 * (P_filtered + P_filtered.T), self.cov_jitter)
        a_hat, P_hat = self.predict(a=a_filtered, P=P_filtered, c=c, T=T, R=R, Q=Q)

        ll = -0.5 * ((pt.neq(ll_inner, 0).sum()) * MVN_CONST + ll_inner.sum())

        return a_filtered, a_hat, obs_mu, P_filtered, P_hat, obs_cov, ll
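
To see the "one observation at a time" idea in `_univariate_inner_filter_step`, here is an editor's numpy sketch of a single time step handled sequentially, so that only scalar divisions by F are ever needed. Only the diagonal of H is used, matching the `pt.diag(H_masked)` sequence in the inner scan above; all values are invented for the example.

# --- editor's illustrative sketch, not part of kalman_filter.py ---
import numpy as np

a = np.array([0.5, 0.1])                   # a[t | t-1]
P = np.eye(2) * 0.2                        # P[t | t-1]
d = np.zeros(2)
Z = np.eye(2)
H = np.diag([0.5, 0.3])
y = np.array([0.8, 0.05])

for i in range(y.shape[0]):                # update one observed series at a time
    z_row = Z[i]
    v = y[i] - (z_row @ a + d[i])
    PZT = P @ z_row
    F = z_row @ PZT + H[i, i]              # scalar, so no matrix inversion is needed
    K = PZT / F
    a = a + K * v
    P = P - np.outer(K, K) * F
# --- end sketch ---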