pymc-extras 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31) hide show
  1. pymc_extras/distributions/timeseries.py +10 -10
  2. pymc_extras/inference/dadvi/dadvi.py +14 -83
  3. pymc_extras/inference/laplace_approx/laplace.py +187 -159
  4. pymc_extras/inference/pathfinder/pathfinder.py +12 -7
  5. pymc_extras/inference/smc/sampling.py +2 -2
  6. pymc_extras/model/marginal/distributions.py +4 -2
  7. pymc_extras/model/marginal/marginal_model.py +12 -2
  8. pymc_extras/prior.py +3 -3
  9. pymc_extras/statespace/core/properties.py +276 -0
  10. pymc_extras/statespace/core/statespace.py +182 -45
  11. pymc_extras/statespace/filters/distributions.py +19 -34
  12. pymc_extras/statespace/filters/kalman_filter.py +13 -12
  13. pymc_extras/statespace/filters/kalman_smoother.py +2 -2
  14. pymc_extras/statespace/models/DFM.py +179 -168
  15. pymc_extras/statespace/models/ETS.py +177 -151
  16. pymc_extras/statespace/models/SARIMAX.py +149 -152
  17. pymc_extras/statespace/models/VARMAX.py +134 -145
  18. pymc_extras/statespace/models/__init__.py +8 -1
  19. pymc_extras/statespace/models/structural/__init__.py +30 -8
  20. pymc_extras/statespace/models/structural/components/autoregressive.py +87 -45
  21. pymc_extras/statespace/models/structural/components/cycle.py +119 -80
  22. pymc_extras/statespace/models/structural/components/level_trend.py +95 -42
  23. pymc_extras/statespace/models/structural/components/measurement_error.py +27 -17
  24. pymc_extras/statespace/models/structural/components/regression.py +105 -68
  25. pymc_extras/statespace/models/structural/components/seasonality.py +138 -100
  26. pymc_extras/statespace/models/structural/core.py +397 -286
  27. pymc_extras/statespace/models/utilities.py +5 -20
  28. {pymc_extras-0.6.0.dist-info → pymc_extras-0.8.0.dist-info}/METADATA +4 -4
  29. {pymc_extras-0.6.0.dist-info → pymc_extras-0.8.0.dist-info}/RECORD +31 -30
  30. {pymc_extras-0.6.0.dist-info → pymc_extras-0.8.0.dist-info}/WHEEL +0 -0
  31. {pymc_extras-0.6.0.dist-info → pymc_extras-0.8.0.dist-info}/licenses/LICENSE +0 -0
@@ -19,6 +19,22 @@ from rich.box import SIMPLE_HEAD
19
19
  from rich.console import Console
20
20
  from rich.table import Table
21
21
 
22
+ from pymc_extras.statespace.core.properties import (
23
+ Coord,
24
+ CoordInfo,
25
+ Data,
26
+ DataInfo,
27
+ Parameter,
28
+ ParameterInfo,
29
+ Shock,
30
+ ShockInfo,
31
+ State,
32
+ StateInfo,
33
+ SymbolicData,
34
+ SymbolicDataInfo,
35
+ SymbolicVariable,
36
+ SymbolicVariableInfo,
37
+ )
22
38
  from pymc_extras.statespace.core.representation import PytensorRepresentation
23
39
  from pymc_extras.statespace.filters import (
24
40
  KalmanSmoother,
@@ -69,6 +85,22 @@ def _verify_group(group):
69
85
  raise ValueError(f'Argument "group" must be one of "prior" or "posterior", found {group}')
70
86
 
71
87
 
88
+ def _validate_property(props, property_name, expected_type):
89
+ if isinstance(props, expected_type) or props is None:
90
+ return
91
+ elif not isinstance(props, tuple | list):
92
+ raise TypeError(
93
+ f"The {property_name} property must be a {expected_type.__name__} or a "
94
+ f"list/tuple of {expected_type.__name__} instances."
95
+ )
96
+
97
+ if not all(isinstance(prop, expected_type) for prop in props):
98
+ raise TypeError(
99
+ f"All elements of the {property_name} property must be instances of "
100
+ f"{expected_type.__name__}."
101
+ )
102
+
103
+
72
104
  class PyMCStateSpace:
73
105
  r"""
74
106
  Base class for Linear Gaussian Statespace models in PyMC.
@@ -236,8 +268,8 @@ class PyMCStateSpace:
236
268
  self._fit_exog_data: dict[str, dict] = {}
237
269
 
238
270
  self._needs_exog_data = None
239
- self._name_to_variable = {}
240
- self._name_to_data = {}
271
+ self._tensor_variable_info = SymbolicVariableInfo()
272
+ self._tensor_data_info = SymbolicDataInfo()
241
273
 
242
274
  self.k_endog = k_endog
243
275
  self.k_states = k_states
@@ -245,6 +277,8 @@ class PyMCStateSpace:
245
277
  self.measurement_error = measurement_error
246
278
  self.mode = mode
247
279
 
280
+ self._populate_properties()
281
+
248
282
  # All models contain a state space representation and a Kalman filter
249
283
  self.ssm = PytensorRepresentation(k_endog, k_states, k_posdef)
250
284
 
@@ -271,17 +305,107 @@ class PyMCStateSpace:
271
305
  console = Console()
272
306
  console.print(self.requirement_table)
273
307
 
308
+ def _populate_properties(self) -> None:
309
+ self._set_parameters()
310
+ self._set_states()
311
+ self._set_shocks()
312
+ self._set_coords()
313
+ self._set_data_info()
314
+
315
+ def set_parameters(self) -> Parameter | tuple[Parameter, ...] | None:
316
+ """
317
+ Provides parameter metadata to the model.
318
+
319
+ Optional. Default implementation sets an empty ParameterInfo.
320
+ Child classes can override to define model parameters.
321
+ """
322
+ return
323
+
324
+ def _set_parameters(self) -> None:
325
+ param_list = self.set_parameters()
326
+ self._param_info = ParameterInfo(parameters=param_list)
327
+
328
+ def set_states(self) -> State | tuple[State, ...] | None:
329
+ """
330
+ Provide state metadata to the model.
331
+
332
+ Optional. Default implementation creates generic state names (state_0, state_1, ...).
333
+ Child classes can override to provide meaningful state names.
334
+ """
335
+ hidden_states = [
336
+ State(name=f"hidden_state_{name}", observed=False) for name in range(self.k_states or 0)
337
+ ]
338
+ observed_states = [
339
+ State(name=f"observed_state_{name}", observed=True) for name in range(self.k_endog or 0)
340
+ ]
341
+ return *hidden_states, *observed_states
342
+
343
+ def _set_states(self) -> None:
344
+ states = self.set_states()
345
+ _validate_property(states, "states", State)
346
+
347
+ if isinstance(states, State):
348
+ states = (states,)
349
+
350
+ self._state_info = StateInfo(states=states)
351
+
352
+ def set_shocks(self) -> Shock | tuple[Shock, ...] | None:
353
+ """
354
+ Provide shock metadata to the model.
355
+
356
+ Optional. Default implementation creates generic shock names (shock_0, shock_1, ...).
357
+ Child classes can override to provide meaningful shock names.
358
+ """
359
+ return tuple(Shock(name=f"shock_{i}") for i in range(self.k_posdef))
360
+
361
+ def _set_shocks(self) -> None:
362
+ shocks = self.set_shocks()
363
+ _validate_property(shocks, "shocks", Shock)
364
+ if isinstance(shocks, Shock):
365
+ shocks = (shocks,)
366
+ self._shock_info = ShockInfo(shocks=shocks)
367
+
368
+ def default_coords(self) -> tuple[Coord, ...]:
369
+ return CoordInfo.default_coords_from_model(self).items
370
+
371
+ def set_coords(self) -> Coord | tuple[Coord, ...] | None:
372
+ """
373
+ Provide coordinates to the model.
374
+
375
+ Optional. Default implementation sets an empty CoordInfo.
376
+ Child classes can override to provide model-specific coordinates.
377
+ """
378
+ return self.default_coords()
379
+
380
+ def _set_coords(self) -> None:
381
+ coords = self.set_coords()
382
+ _validate_property(coords, "coords", Coord)
383
+ if isinstance(coords, Coord):
384
+ coords = (coords,)
385
+ self._coords_info = CoordInfo(coords=coords)
386
+
387
+ def set_data_info(self) -> Data | tuple[Data, ...] | None:
388
+ """
389
+ Provide data_info metadata to the model.
390
+
391
+ Optional. Default implementation sets an empty DataInfo.
392
+ Child classes can override if the model requires exogenous data.
393
+ """
394
+ return
395
+
396
+ def _set_data_info(self) -> None:
397
+ data_info = self.set_data_info()
398
+ _validate_property(data_info, "data_info", Data)
399
+ if isinstance(data_info, Data):
400
+ data_info = (data_info,)
401
+ self._data_info = DataInfo(data=data_info)
402
+
274
403
  def _populate_prior_requirements(self) -> None:
275
404
  """
276
405
  Add requirements about priors needed for the model to a rich table, including their names,
277
406
  shapes, named dimensions, and any parameter constraints.
278
407
  """
279
- # Check that the param_info class is implemented, and also that it's a dictionary. We can't proceed if either
280
- # is not true.
281
- try:
282
- if not isinstance(self.param_info, dict):
283
- return
284
- except NotImplementedError:
408
+ if not self.param_info:
285
409
  return
286
410
 
287
411
  if self.requirement_table is None:
@@ -296,10 +420,7 @@ class PyMCStateSpace:
296
420
  """
297
421
  Add requirements about the data needed for the model, including their names, shapes, and named dimensions.
298
422
  """
299
- try:
300
- if not isinstance(self.data_info, dict):
301
- return
302
- except NotImplementedError:
423
+ if not self.data_info:
303
424
  return
304
425
 
305
426
  if self.requirement_table is None:
@@ -374,24 +495,24 @@ class PyMCStateSpace:
374
495
  return self.subbed_ssm
375
496
 
376
497
  @property
377
- def param_names(self) -> list[str]:
498
+ def param_names(self) -> tuple[str, ...]:
378
499
  """
379
500
  Names of model parameters
380
501
 
381
502
  A list of all parameters expected by the model. Each parameter will be sought inside the active PyMC model
382
503
  context when ``build_statespace_graph`` is invoked.
383
504
  """
384
- raise NotImplementedError("The param_names property has not been implemented!")
505
+ return self._param_info.names
385
506
 
386
507
  @property
387
- def data_names(self) -> list[str]:
508
+ def data_names(self) -> tuple[str, ...]:
388
509
  """
389
510
  Names of data variables expected by the model.
390
511
 
391
512
  This does not include the observed data series, which is automatically handled by PyMC. This property only
392
513
  needs to be implemented for models that expect exogenous data.
393
514
  """
394
- return []
515
+ return self._data_info.exogenous_names
395
516
 
396
517
  @property
397
518
  def param_info(self) -> dict[str, dict[str, Any]]:
@@ -406,7 +527,7 @@ class PyMCStateSpace:
406
527
  positive semi-definite, etc)
407
528
  * key: "dims", value: tuple of strings
408
529
  """
409
- raise NotImplementedError("The params_info property has not been implemented!")
530
+ return self._param_info.to_dict()
410
531
 
411
532
  @property
412
533
  def data_info(self) -> dict[str, dict[str, Any]]:
@@ -419,31 +540,30 @@ class PyMCStateSpace:
419
540
  * key: "shape", value: a tuple of integers
420
541
  * key: "dims", value: tuple of strings
421
542
  """
422
- raise NotImplementedError("The data_info property has not been implemented!")
543
+ return self._data_info.to_dict()
423
544
 
424
545
  @property
425
- def state_names(self) -> list[str]:
546
+ def state_names(self) -> tuple[str, ...]:
426
547
  """
427
548
  A k_states length list of strings, associated with the model's hidden states
428
549
 
429
550
  """
430
-
431
- raise NotImplementedError("The state_names property has not been implemented!")
551
+ return self._state_info.unobserved_state_names
432
552
 
433
553
  @property
434
- def observed_states(self) -> list[str]:
554
+ def observed_states(self) -> tuple[str, ...]:
435
555
  """
436
556
  A k_endog length list of strings, associated with the model's observed states
437
557
  """
438
- raise NotImplementedError("The observed_states property has not been implemented!")
558
+ return self._state_info.observed_state_names
439
559
 
440
560
  @property
441
- def shock_names(self) -> list[str]:
561
+ def shock_names(self) -> tuple[str, ...]:
442
562
  """
443
563
  A k_posdef length list of strings, associated with the model's shock processes
444
564
 
445
565
  """
446
- raise NotImplementedError("The shock_names property has not been implemented!")
566
+ return self._shock_info.names
447
567
 
448
568
  @property
449
569
  def default_priors(self) -> dict[str, Callable]:
@@ -464,10 +584,10 @@ class PyMCStateSpace:
464
584
  should come from the default names defined in ``statespace.utils.constants`` for them to be detected by
465
585
  sampling methods.
466
586
  """
467
- raise NotImplementedError("The coords property has not been implemented!")
587
+ return self._coords_info.to_dict()
468
588
 
469
589
  @property
470
- def param_dims(self) -> dict[str, Sequence[str]]:
590
+ def param_dims(self) -> dict[str, tuple[str, ...]]:
471
591
  """
472
592
  Dictionary of named dimensions for each model parameter
473
593
 
@@ -475,8 +595,17 @@ class PyMCStateSpace:
475
595
  PyMC random variable. Dimensions should come from the default names defined in ``statespace.utils.constants``
476
596
  for them to be detected by sampling methods.
477
597
 
598
+ Note: Scalar parameters (with dims=None) are not included in this dictionary.
478
599
  """
479
- raise NotImplementedError("The param_dims property has not been implemented!")
600
+ return {param.name: param.dims for param in self._param_info if param.dims is not None}
601
+
602
+ @property
603
+ def _name_to_variable(self):
604
+ return self._tensor_variable_info.to_dict()
605
+
606
+ @property
607
+ def _name_to_data(self):
608
+ return self._tensor_data_info.to_dict()
480
609
 
481
610
  def add_default_priors(self) -> None:
482
611
  """
@@ -520,14 +649,15 @@ class PyMCStateSpace:
520
649
  f"parameters."
521
650
  )
522
651
 
523
- if name in self._name_to_variable.keys():
652
+ if name in self._tensor_variable_info:
524
653
  raise ValueError(
525
654
  f"{name} is already a registered placeholder variable with shape "
526
- f"{self._name_to_variable[name].type.shape}"
655
+ f"{self._tensor_variable_info[name].type.shape}"
527
656
  )
528
657
 
529
658
  placeholder = pt.tensor(name, shape=shape, dtype=dtype)
530
- self._name_to_variable[name] = placeholder
659
+ tensor_var = SymbolicVariable(name=name, symbolic_variable=placeholder)
660
+ self._tensor_variable_info = self._tensor_variable_info.add(tensor_var)
531
661
  return placeholder
532
662
 
533
663
  def make_and_register_data(
@@ -559,14 +689,15 @@ class PyMCStateSpace:
559
689
  f"parameters."
560
690
  )
561
691
 
562
- if name in self._name_to_data.keys():
692
+ if name in self._tensor_data_info:
563
693
  raise ValueError(
564
694
  f"{name} is already a registered placeholder variable with shape "
565
- f"{self._name_to_data[name].type.shape}"
695
+ f"{self._tensor_data_info[name].type.shape}"
566
696
  )
567
697
 
568
698
  placeholder = pt.tensor(name, shape=shape, dtype=dtype)
569
- self._name_to_data[name] = placeholder
699
+ tensor_data = SymbolicData(name=name, symbolic_data=placeholder)
700
+ self._tensor_data_info = self._tensor_data_info.add(tensor_data)
570
701
  return placeholder
571
702
 
572
703
  def make_symbolic_graph(self) -> None:
@@ -741,10 +872,8 @@ class PyMCStateSpace:
741
872
 
742
873
  Only used when models require exogenous data. The observed data is not added to the model using this method!
743
874
  """
744
-
745
- try:
746
- data_names = self.data_names
747
- except NotImplementedError:
875
+ data_names = self.data_names
876
+ if not data_names:
748
877
  return
749
878
 
750
879
  pymc_model = modelcontext(None)
@@ -892,7 +1021,6 @@ class PyMCStateSpace:
892
1021
  .. deprecated:: 0.2.5
893
1022
  The `mode` argument is deprecated and will be removed in a future version. Pass ``mode`` to the
894
1023
  model constructor, or manually specify ``compile_kwargs`` in sampling functions instead.
895
-
896
1024
  """
897
1025
  if mode is not None:
898
1026
  warnings.warn(
@@ -1430,7 +1558,11 @@ class PyMCStateSpace:
1430
1558
  """
1431
1559
 
1432
1560
  return self._sample_conditional(
1433
- idata=idata, group="prior", random_seed=random_seed, mvn_method=mvn_method, **kwargs
1561
+ idata=idata,
1562
+ group="prior",
1563
+ random_seed=random_seed,
1564
+ mvn_method=mvn_method,
1565
+ **kwargs,
1434
1566
  )
1435
1567
 
1436
1568
  def sample_conditional_posterior(
@@ -1473,7 +1605,11 @@ class PyMCStateSpace:
1473
1605
  """
1474
1606
 
1475
1607
  return self._sample_conditional(
1476
- idata=idata, group="posterior", random_seed=random_seed, mvn_method=mvn_method, **kwargs
1608
+ idata=idata,
1609
+ group="posterior",
1610
+ random_seed=random_seed,
1611
+ mvn_method=mvn_method,
1612
+ **kwargs,
1477
1613
  )
1478
1614
 
1479
1615
  def sample_unconditional_prior(
@@ -2080,11 +2216,11 @@ class PyMCStateSpace:
2080
2216
  forecast_index: pd.RangeIndex | pd.DatetimeIndex,
2081
2217
  name=None,
2082
2218
  ):
2083
- try:
2084
- var_to_dims = {key: info["dims"][1:] for key, info in self.data_info.items()}
2085
- except NotImplementedError:
2219
+ if not self.data_info:
2086
2220
  return scenario
2087
2221
 
2222
+ var_to_dims = {key: info["dims"][1:] for key, info in self.data_info.items()}
2223
+
2088
2224
  if any(len(dims) > 1 for dims in var_to_dims.values()):
2089
2225
  raise NotImplementedError(">2d exogenous data is not yet supported.")
2090
2226
  coords = {
@@ -2500,13 +2636,14 @@ class PyMCStateSpace:
2500
2636
  next_x = c + T @ x + R @ shock
2501
2637
  return next_x
2502
2638
 
2503
- irf, updates = pytensor.scan(
2639
+ irf = pytensor.scan(
2504
2640
  irf_step,
2505
2641
  sequences=[shock_trajectory],
2506
2642
  outputs_info=[x0],
2507
2643
  non_sequences=[c, T, R],
2508
2644
  n_steps=n_steps,
2509
2645
  strict=True,
2646
+ return_updates=False,
2510
2647
  )
2511
2648
 
2512
2649
  pm.Deterministic("irf", irf, dims=[TIME_DIM, ALL_STATE_DIM])
@@ -1,4 +1,3 @@
1
- import numpy as np
2
1
  import pymc as pm
3
2
  import pytensor
4
3
  import pytensor.tensor as pt
@@ -8,7 +7,9 @@ from pymc.distributions.dist_math import check_parameters
8
7
  from pymc.distributions.distribution import Continuous, SymbolicRandomVariable
9
8
  from pymc.distributions.shape_utils import get_support_shape_1d
10
9
  from pymc.logprob.abstract import _logprob
10
+ from pymc.pytensorf import normalize_rng_param
11
11
  from pytensor.graph.basic import Node
12
+ from pytensor.tensor.random import multivariate_normal
12
13
 
13
14
  floatX = pytensor.config.floatX
14
15
  COV_ZERO_TOL = 0
@@ -152,6 +153,7 @@ class _LinearGaussianStateSpace(Continuous):
152
153
  Q,
153
154
  steps,
154
155
  size=None,
156
+ rng=None,
155
157
  sequence_names=None,
156
158
  append_x0=True,
157
159
  method="svd",
@@ -178,7 +180,7 @@ class _LinearGaussianStateSpace(Continuous):
178
180
  ]
179
181
  non_sequences = [x for x in [c_, d_, T_, Z_, R_, H_, Q_] if x not in sequences]
180
182
 
181
- rng = pytensor.shared(np.random.default_rng())
183
+ rng = normalize_rng_param(rng)
182
184
 
183
185
  def sort_args(args):
184
186
  sorted_args = []
@@ -197,10 +199,9 @@ class _LinearGaussianStateSpace(Continuous):
197
199
  n_seq = len(sequence_names)
198
200
 
199
201
  def step_fn(*args):
200
- seqs, state, non_seqs = args[:n_seq], args[n_seq], args[n_seq + 1 :]
201
- non_seqs, rng = non_seqs[:-1], non_seqs[-1]
202
+ seqs, (rng, state, *non_seqs) = args[:n_seq], args[n_seq:]
202
203
 
203
- c, d, T, Z, R, H, Q = sort_args(seqs + non_seqs)
204
+ c, d, T, Z, R, H, Q = sort_args((*seqs, *non_seqs))
204
205
  k = T.shape[0]
205
206
  a = state[:k]
206
207
 
@@ -219,7 +220,7 @@ class _LinearGaussianStateSpace(Continuous):
219
220
 
220
221
  next_state = pt.concatenate([a_next, y_next], axis=0)
221
222
 
222
- return next_state, {rng: next_rng}
223
+ return next_rng, next_state
223
224
 
224
225
  Z_init = Z_ if Z_ in non_sequences else Z_[0]
225
226
  H_init = H_ if H_ in non_sequences else H_[0]
@@ -229,13 +230,14 @@ class _LinearGaussianStateSpace(Continuous):
229
230
 
230
231
  init_dist_ = pt.concatenate([init_x_, init_y_], axis=0)
231
232
 
232
- statespace, updates = pytensor.scan(
233
+ ss_rng, statespace = pytensor.scan(
233
234
  step_fn,
234
- outputs_info=[init_dist_],
235
+ outputs_info=[rng, init_dist_],
235
236
  sequences=None if len(sequences) == 0 else sequences,
236
- non_sequences=[*non_sequences, rng],
237
+ non_sequences=[*non_sequences],
237
238
  n_steps=steps,
238
239
  strict=True,
240
+ return_updates=False,
239
241
  )
240
242
 
241
243
  if append_x0:
@@ -245,7 +247,6 @@ class _LinearGaussianStateSpace(Continuous):
245
247
  statespace_ = statespace
246
248
  statespace_ = pt.specify_shape(statespace_, (steps, None))
247
249
 
248
- (ss_rng,) = tuple(updates.values())
249
250
  linear_gaussian_ss_op = LinearGaussianStateSpaceRV(
250
251
  inputs=[a0_, P0_, c_, d_, T_, Z_, R_, H_, Q_, steps, rng],
251
252
  outputs=[ss_rng, statespace_],
@@ -368,41 +369,25 @@ class SequenceMvNormal(Continuous):
368
369
 
369
370
  @classmethod
370
371
  def dist(cls, mus, covs, logp, method="svd", **kwargs):
372
+ mus, covs, logp = map(pt.as_tensor_variable, (mus, covs, logp))
371
373
  return super().dist([mus, covs, logp], method=method, **kwargs)
372
374
 
373
375
  @classmethod
374
- def rv_op(cls, mus, covs, logp, method="svd", size=None):
375
- # Batch dimensions (if any) will be on the far left, but scan requires time to be there instead
376
- if mus.ndim > 2:
377
- mus = pt.moveaxis(mus, -2, 0)
378
- if covs.ndim > 3:
379
- covs = pt.moveaxis(covs, -3, 0)
380
-
381
- mus_, covs_ = mus.type(), covs.type()
382
-
376
+ def rv_op(cls, mus, covs, logp, method="svd", size=None, rng=None):
377
+ rng = normalize_rng_param(rng)
383
378
  logp_ = logp.type()
384
- rng = pytensor.shared(np.random.default_rng())
385
379
 
386
- def step(mu, cov, rng):
387
- new_rng, mvn = pm.MvNormal.dist(mu=mu, cov=cov, rng=rng, method=method).owner.outputs
388
- return mvn, {rng: new_rng}
389
-
390
- mvn_seq, updates = pytensor.scan(
391
- step, sequences=[mus_, covs_], non_sequences=[rng], strict=True, n_steps=mus_.shape[0]
392
- )
393
- mvn_seq = pt.specify_shape(mvn_seq, mus.type.shape)
394
-
395
- # Move time axis back to position -2 so batches are on the left
396
- if mvn_seq.ndim > 2:
397
- mvn_seq = pt.moveaxis(mvn_seq, 0, -2)
398
-
399
- (seq_mvn_rng,) = tuple(updates.values())
380
+ mus_, covs_ = mus.type(), covs.type()
381
+ seq_mvn_rng, mvn_seq = multivariate_normal(
382
+ mean=mus_, cov=covs_, rng=rng, method=method
383
+ ).owner.outputs
400
384
 
401
385
  mvn_seq_op = KalmanFilterRV(
402
386
  inputs=[mus_, covs_, logp_, rng], outputs=[seq_mvn_rng, mvn_seq], ndim_supp=2
403
387
  )
404
388
 
405
389
  mvn_seq = mvn_seq_op(mus, covs, logp, rng)
390
+
406
391
  return mvn_seq
407
392
 
408
393
 
@@ -8,7 +8,7 @@ from pymc.pytensorf import constant_fold
8
8
  from pytensor.graph.basic import Variable
9
9
  from pytensor.raise_op import Assert
10
10
  from pytensor.tensor import TensorVariable
11
- from pytensor.tensor.slinalg import solve_triangular
11
+ from pytensor.tensor.linalg import solve_triangular
12
12
 
13
13
  from pymc_extras.statespace.filters.utilities import (
14
14
  quad_form_sym,
@@ -148,10 +148,9 @@ class BaseFilter(ABC):
148
148
  R,
149
149
  H,
150
150
  Q,
151
- return_updates=False,
152
151
  missing_fill_value=None,
153
152
  cov_jitter=None,
154
- ) -> list[TensorVariable] | tuple[list[TensorVariable], dict]:
153
+ ) -> list[TensorVariable]:
155
154
  """
156
155
  Construct the computation graph for the Kalman filter. See [1] for details.
157
156
 
@@ -211,20 +210,17 @@ class BaseFilter(ABC):
211
210
  if len(sequences) > 0:
212
211
  sequences = self.add_check_on_time_varying_shapes(data, sequences)
213
212
 
214
- results, updates = pytensor.scan(
213
+ results = pytensor.scan(
215
214
  self.kalman_step,
216
215
  sequences=[data, *sequences],
217
216
  outputs_info=[None, a0, None, None, P0, None, None],
218
217
  non_sequences=non_sequences,
219
218
  name="forward_kalman_pass",
220
219
  strict=False,
220
+ return_updates=False,
221
221
  )
222
222
 
223
- filter_results = self._postprocess_scan_results(results, a0, P0, n=data.type.shape[0])
224
-
225
- if return_updates:
226
- return filter_results, updates
227
- return filter_results
223
+ return self._postprocess_scan_results(results, a0, P0, n=data.type.shape[0])
228
224
 
229
225
  def _postprocess_scan_results(self, results, a0, P0, n) -> list[TensorVariable]:
230
226
  """
@@ -652,7 +648,9 @@ class SquareRootFilter(BaseFilter):
652
648
  y_hat = Z.dot(a) + d
653
649
  v = y - y_hat
654
650
 
655
- H_chol = pytensor.ifelse(pt.all(pt.eq(H, 0.0)), H, pt.linalg.cholesky(H, lower=True))
651
+ H_chol = pytensor.ifelse(
652
+ pt.all(pt.eq(H, 0.0)), H, pt.linalg.cholesky(H, lower=True, on_error="nan")
653
+ )
656
654
 
657
655
  # The following notation comes from https://ipnpr.jpl.nasa.gov/progress_report/42-233/42-233A.pdf
658
656
  # Construct upper-triangular block matrix A = [[chol(H), Z @ L_pred],
@@ -694,8 +692,10 @@ class SquareRootFilter(BaseFilter):
694
692
  """
695
693
  return [a, P_chol, pt.zeros(())]
696
694
 
695
+ degenerate = pt.eq(all_nan_flag, 1.0)
696
+ F_chol = pytensor.ifelse(degenerate, pt.eye(*F_chol.shape), F_chol)
697
697
  [a_filtered, P_chol_filtered, ll] = pytensor.ifelse(
698
- pt.eq(all_nan_flag, 1.0),
698
+ degenerate,
699
699
  compute_degenerate(P_chol_filtered, F_chol, K_F_chol, v),
700
700
  compute_non_degenerate(P_chol_filtered, F_chol, K_F_chol, v),
701
701
  )
@@ -786,11 +786,12 @@ class UnivariateFilter(BaseFilter):
786
786
  H_masked = W.dot(H)
787
787
  y_masked = pt.set_subtensor(y[nan_mask], 0.0)
788
788
 
789
- result, updates = pytensor.scan(
789
+ result = pytensor.scan(
790
790
  self._univariate_inner_filter_step,
791
791
  sequences=[y_masked, Z_masked, d, pt.diag(H_masked), nan_mask],
792
792
  outputs_info=[a, P, None, None, None],
793
793
  name="univariate_inner_scan",
794
+ return_updates=False,
794
795
  )
795
796
 
796
797
  a_filtered, P_filtered, obs_mu, obs_cov, ll_inner = result
@@ -76,16 +76,16 @@ class KalmanSmoother:
76
76
  self.seq_names = seq_names
77
77
  self.non_seq_names = non_seq_names
78
78
 
79
- smoother_result, updates = pytensor.scan(
79
+ smoothed_states, smoothed_covariances = pytensor.scan(
80
80
  self.smoother_step,
81
81
  sequences=[filtered_states[:-1], filtered_covariances[:-1], *sequences],
82
82
  outputs_info=[a_last, P_last],
83
83
  non_sequences=non_sequences,
84
84
  go_backwards=True,
85
85
  name="kalman_smoother",
86
+ return_updates=False,
86
87
  )
87
88
 
88
- smoothed_states, smoothed_covariances = smoother_result
89
89
  smoothed_states = pt.concatenate(
90
90
  [smoothed_states[::-1], pt.expand_dims(a_last, axis=(0,))], axis=0
91
91
  )