pymc-extras 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pymc_extras/distributions/timeseries.py +10 -10
- pymc_extras/inference/dadvi/dadvi.py +14 -83
- pymc_extras/inference/laplace_approx/laplace.py +187 -159
- pymc_extras/inference/pathfinder/pathfinder.py +12 -7
- pymc_extras/inference/smc/sampling.py +2 -2
- pymc_extras/model/marginal/distributions.py +4 -2
- pymc_extras/model/marginal/marginal_model.py +12 -2
- pymc_extras/prior.py +3 -3
- pymc_extras/statespace/core/properties.py +276 -0
- pymc_extras/statespace/core/statespace.py +182 -45
- pymc_extras/statespace/filters/distributions.py +19 -34
- pymc_extras/statespace/filters/kalman_filter.py +13 -12
- pymc_extras/statespace/filters/kalman_smoother.py +2 -2
- pymc_extras/statespace/models/DFM.py +179 -168
- pymc_extras/statespace/models/ETS.py +177 -151
- pymc_extras/statespace/models/SARIMAX.py +149 -152
- pymc_extras/statespace/models/VARMAX.py +134 -145
- pymc_extras/statespace/models/__init__.py +8 -1
- pymc_extras/statespace/models/structural/__init__.py +30 -8
- pymc_extras/statespace/models/structural/components/autoregressive.py +87 -45
- pymc_extras/statespace/models/structural/components/cycle.py +119 -80
- pymc_extras/statespace/models/structural/components/level_trend.py +95 -42
- pymc_extras/statespace/models/structural/components/measurement_error.py +27 -17
- pymc_extras/statespace/models/structural/components/regression.py +105 -68
- pymc_extras/statespace/models/structural/components/seasonality.py +138 -100
- pymc_extras/statespace/models/structural/core.py +397 -286
- pymc_extras/statespace/models/utilities.py +5 -20
- {pymc_extras-0.6.0.dist-info → pymc_extras-0.8.0.dist-info}/METADATA +4 -4
- {pymc_extras-0.6.0.dist-info → pymc_extras-0.8.0.dist-info}/RECORD +31 -30
- {pymc_extras-0.6.0.dist-info → pymc_extras-0.8.0.dist-info}/WHEEL +0 -0
- {pymc_extras-0.6.0.dist-info → pymc_extras-0.8.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -19,6 +19,22 @@ from rich.box import SIMPLE_HEAD
|
|
|
19
19
|
from rich.console import Console
|
|
20
20
|
from rich.table import Table
|
|
21
21
|
|
|
22
|
+
from pymc_extras.statespace.core.properties import (
|
|
23
|
+
Coord,
|
|
24
|
+
CoordInfo,
|
|
25
|
+
Data,
|
|
26
|
+
DataInfo,
|
|
27
|
+
Parameter,
|
|
28
|
+
ParameterInfo,
|
|
29
|
+
Shock,
|
|
30
|
+
ShockInfo,
|
|
31
|
+
State,
|
|
32
|
+
StateInfo,
|
|
33
|
+
SymbolicData,
|
|
34
|
+
SymbolicDataInfo,
|
|
35
|
+
SymbolicVariable,
|
|
36
|
+
SymbolicVariableInfo,
|
|
37
|
+
)
|
|
22
38
|
from pymc_extras.statespace.core.representation import PytensorRepresentation
|
|
23
39
|
from pymc_extras.statespace.filters import (
|
|
24
40
|
KalmanSmoother,
|
|
@@ -69,6 +85,22 @@ def _verify_group(group):
|
|
|
69
85
|
raise ValueError(f'Argument "group" must be one of "prior" or "posterior", found {group}')
|
|
70
86
|
|
|
71
87
|
|
|
88
|
+
def _validate_property(props, property_name, expected_type):
|
|
89
|
+
if isinstance(props, expected_type) or props is None:
|
|
90
|
+
return
|
|
91
|
+
elif not isinstance(props, tuple | list):
|
|
92
|
+
raise TypeError(
|
|
93
|
+
f"The {property_name} property must be a {expected_type.__name__} or a "
|
|
94
|
+
f"list/tuple of {expected_type.__name__} instances."
|
|
95
|
+
)
|
|
96
|
+
|
|
97
|
+
if not all(isinstance(prop, expected_type) for prop in props):
|
|
98
|
+
raise TypeError(
|
|
99
|
+
f"All elements of the {property_name} property must be instances of "
|
|
100
|
+
f"{expected_type.__name__}."
|
|
101
|
+
)
|
|
102
|
+
|
|
103
|
+
|
|
72
104
|
class PyMCStateSpace:
|
|
73
105
|
r"""
|
|
74
106
|
Base class for Linear Gaussian Statespace models in PyMC.
|
|
@@ -236,8 +268,8 @@ class PyMCStateSpace:
|
|
|
236
268
|
self._fit_exog_data: dict[str, dict] = {}
|
|
237
269
|
|
|
238
270
|
self._needs_exog_data = None
|
|
239
|
-
self.
|
|
240
|
-
self.
|
|
271
|
+
self._tensor_variable_info = SymbolicVariableInfo()
|
|
272
|
+
self._tensor_data_info = SymbolicDataInfo()
|
|
241
273
|
|
|
242
274
|
self.k_endog = k_endog
|
|
243
275
|
self.k_states = k_states
|
|
@@ -245,6 +277,8 @@ class PyMCStateSpace:
|
|
|
245
277
|
self.measurement_error = measurement_error
|
|
246
278
|
self.mode = mode
|
|
247
279
|
|
|
280
|
+
self._populate_properties()
|
|
281
|
+
|
|
248
282
|
# All models contain a state space representation and a Kalman filter
|
|
249
283
|
self.ssm = PytensorRepresentation(k_endog, k_states, k_posdef)
|
|
250
284
|
|
|
@@ -271,17 +305,107 @@ class PyMCStateSpace:
|
|
|
271
305
|
console = Console()
|
|
272
306
|
console.print(self.requirement_table)
|
|
273
307
|
|
|
308
|
+
def _populate_properties(self) -> None:
|
|
309
|
+
self._set_parameters()
|
|
310
|
+
self._set_states()
|
|
311
|
+
self._set_shocks()
|
|
312
|
+
self._set_coords()
|
|
313
|
+
self._set_data_info()
|
|
314
|
+
|
|
315
|
+
def set_parameters(self) -> Parameter | tuple[Parameter, ...] | None:
|
|
316
|
+
"""
|
|
317
|
+
Provides parameter metadata to the model.
|
|
318
|
+
|
|
319
|
+
Optional. Default implementation sets an empty ParameterInfo.
|
|
320
|
+
Child classes can override to define model parameters.
|
|
321
|
+
"""
|
|
322
|
+
return
|
|
323
|
+
|
|
324
|
+
def _set_parameters(self) -> None:
|
|
325
|
+
param_list = self.set_parameters()
|
|
326
|
+
self._param_info = ParameterInfo(parameters=param_list)
|
|
327
|
+
|
|
328
|
+
def set_states(self) -> State | tuple[State, ...] | None:
|
|
329
|
+
"""
|
|
330
|
+
Provide state metadata to the model.
|
|
331
|
+
|
|
332
|
+
Optional. Default implementation creates generic state names (state_0, state_1, ...).
|
|
333
|
+
Child classes can override to provide meaningful state names.
|
|
334
|
+
"""
|
|
335
|
+
hidden_states = [
|
|
336
|
+
State(name=f"hidden_state_{name}", observed=False) for name in range(self.k_states or 0)
|
|
337
|
+
]
|
|
338
|
+
observed_states = [
|
|
339
|
+
State(name=f"observed_state_{name}", observed=True) for name in range(self.k_endog or 0)
|
|
340
|
+
]
|
|
341
|
+
return *hidden_states, *observed_states
|
|
342
|
+
|
|
343
|
+
def _set_states(self) -> None:
|
|
344
|
+
states = self.set_states()
|
|
345
|
+
_validate_property(states, "states", State)
|
|
346
|
+
|
|
347
|
+
if isinstance(states, State):
|
|
348
|
+
states = (states,)
|
|
349
|
+
|
|
350
|
+
self._state_info = StateInfo(states=states)
|
|
351
|
+
|
|
352
|
+
def set_shocks(self) -> Shock | tuple[Shock, ...] | None:
|
|
353
|
+
"""
|
|
354
|
+
Provide shock metadata to the model.
|
|
355
|
+
|
|
356
|
+
Optional. Default implementation creates generic shock names (shock_0, shock_1, ...).
|
|
357
|
+
Child classes can override to provide meaningful shock names.
|
|
358
|
+
"""
|
|
359
|
+
return tuple(Shock(name=f"shock_{i}") for i in range(self.k_posdef))
|
|
360
|
+
|
|
361
|
+
def _set_shocks(self) -> None:
|
|
362
|
+
shocks = self.set_shocks()
|
|
363
|
+
_validate_property(shocks, "shocks", Shock)
|
|
364
|
+
if isinstance(shocks, Shock):
|
|
365
|
+
shocks = (shocks,)
|
|
366
|
+
self._shock_info = ShockInfo(shocks=shocks)
|
|
367
|
+
|
|
368
|
+
def default_coords(self) -> tuple[Coord, ...]:
|
|
369
|
+
return CoordInfo.default_coords_from_model(self).items
|
|
370
|
+
|
|
371
|
+
def set_coords(self) -> Coord | tuple[Coord, ...] | None:
|
|
372
|
+
"""
|
|
373
|
+
Provide coordinates to the model.
|
|
374
|
+
|
|
375
|
+
Optional. Default implementation sets an empty CoordInfo.
|
|
376
|
+
Child classes can override to provide model-specific coordinates.
|
|
377
|
+
"""
|
|
378
|
+
return self.default_coords()
|
|
379
|
+
|
|
380
|
+
def _set_coords(self) -> None:
|
|
381
|
+
coords = self.set_coords()
|
|
382
|
+
_validate_property(coords, "coords", Coord)
|
|
383
|
+
if isinstance(coords, Coord):
|
|
384
|
+
coords = (coords,)
|
|
385
|
+
self._coords_info = CoordInfo(coords=coords)
|
|
386
|
+
|
|
387
|
+
def set_data_info(self) -> Data | tuple[Data, ...] | None:
|
|
388
|
+
"""
|
|
389
|
+
Provide data_info metadata to the model.
|
|
390
|
+
|
|
391
|
+
Optional. Default implementation sets an empty DataInfo.
|
|
392
|
+
Child classes can override if the model requires exogenous data.
|
|
393
|
+
"""
|
|
394
|
+
return
|
|
395
|
+
|
|
396
|
+
def _set_data_info(self) -> None:
|
|
397
|
+
data_info = self.set_data_info()
|
|
398
|
+
_validate_property(data_info, "data_info", Data)
|
|
399
|
+
if isinstance(data_info, Data):
|
|
400
|
+
data_info = (data_info,)
|
|
401
|
+
self._data_info = DataInfo(data=data_info)
|
|
402
|
+
|
|
274
403
|
def _populate_prior_requirements(self) -> None:
|
|
275
404
|
"""
|
|
276
405
|
Add requirements about priors needed for the model to a rich table, including their names,
|
|
277
406
|
shapes, named dimensions, and any parameter constraints.
|
|
278
407
|
"""
|
|
279
|
-
|
|
280
|
-
# is not true.
|
|
281
|
-
try:
|
|
282
|
-
if not isinstance(self.param_info, dict):
|
|
283
|
-
return
|
|
284
|
-
except NotImplementedError:
|
|
408
|
+
if not self.param_info:
|
|
285
409
|
return
|
|
286
410
|
|
|
287
411
|
if self.requirement_table is None:
|
|
@@ -296,10 +420,7 @@ class PyMCStateSpace:
|
|
|
296
420
|
"""
|
|
297
421
|
Add requirements about the data needed for the model, including their names, shapes, and named dimensions.
|
|
298
422
|
"""
|
|
299
|
-
|
|
300
|
-
if not isinstance(self.data_info, dict):
|
|
301
|
-
return
|
|
302
|
-
except NotImplementedError:
|
|
423
|
+
if not self.data_info:
|
|
303
424
|
return
|
|
304
425
|
|
|
305
426
|
if self.requirement_table is None:
|
|
@@ -374,24 +495,24 @@ class PyMCStateSpace:
|
|
|
374
495
|
return self.subbed_ssm
|
|
375
496
|
|
|
376
497
|
@property
|
|
377
|
-
def param_names(self) ->
|
|
498
|
+
def param_names(self) -> tuple[str, ...]:
|
|
378
499
|
"""
|
|
379
500
|
Names of model parameters
|
|
380
501
|
|
|
381
502
|
A list of all parameters expected by the model. Each parameter will be sought inside the active PyMC model
|
|
382
503
|
context when ``build_statespace_graph`` is invoked.
|
|
383
504
|
"""
|
|
384
|
-
|
|
505
|
+
return self._param_info.names
|
|
385
506
|
|
|
386
507
|
@property
|
|
387
|
-
def data_names(self) ->
|
|
508
|
+
def data_names(self) -> tuple[str, ...]:
|
|
388
509
|
"""
|
|
389
510
|
Names of data variables expected by the model.
|
|
390
511
|
|
|
391
512
|
This does not include the observed data series, which is automatically handled by PyMC. This property only
|
|
392
513
|
needs to be implemented for models that expect exogenous data.
|
|
393
514
|
"""
|
|
394
|
-
return
|
|
515
|
+
return self._data_info.exogenous_names
|
|
395
516
|
|
|
396
517
|
@property
|
|
397
518
|
def param_info(self) -> dict[str, dict[str, Any]]:
|
|
@@ -406,7 +527,7 @@ class PyMCStateSpace:
|
|
|
406
527
|
positive semi-definite, etc)
|
|
407
528
|
* key: "dims", value: tuple of strings
|
|
408
529
|
"""
|
|
409
|
-
|
|
530
|
+
return self._param_info.to_dict()
|
|
410
531
|
|
|
411
532
|
@property
|
|
412
533
|
def data_info(self) -> dict[str, dict[str, Any]]:
|
|
@@ -419,31 +540,30 @@ class PyMCStateSpace:
|
|
|
419
540
|
* key: "shape", value: a tuple of integers
|
|
420
541
|
* key: "dims", value: tuple of strings
|
|
421
542
|
"""
|
|
422
|
-
|
|
543
|
+
return self._data_info.to_dict()
|
|
423
544
|
|
|
424
545
|
@property
|
|
425
|
-
def state_names(self) ->
|
|
546
|
+
def state_names(self) -> tuple[str, ...]:
|
|
426
547
|
"""
|
|
427
548
|
A k_states length list of strings, associated with the model's hidden states
|
|
428
549
|
|
|
429
550
|
"""
|
|
430
|
-
|
|
431
|
-
raise NotImplementedError("The state_names property has not been implemented!")
|
|
551
|
+
return self._state_info.unobserved_state_names
|
|
432
552
|
|
|
433
553
|
@property
|
|
434
|
-
def observed_states(self) ->
|
|
554
|
+
def observed_states(self) -> tuple[str, ...]:
|
|
435
555
|
"""
|
|
436
556
|
A k_endog length list of strings, associated with the model's observed states
|
|
437
557
|
"""
|
|
438
|
-
|
|
558
|
+
return self._state_info.observed_state_names
|
|
439
559
|
|
|
440
560
|
@property
|
|
441
|
-
def shock_names(self) ->
|
|
561
|
+
def shock_names(self) -> tuple[str, ...]:
|
|
442
562
|
"""
|
|
443
563
|
A k_posdef length list of strings, associated with the model's shock processes
|
|
444
564
|
|
|
445
565
|
"""
|
|
446
|
-
|
|
566
|
+
return self._shock_info.names
|
|
447
567
|
|
|
448
568
|
@property
|
|
449
569
|
def default_priors(self) -> dict[str, Callable]:
|
|
@@ -464,10 +584,10 @@ class PyMCStateSpace:
|
|
|
464
584
|
should come from the default names defined in ``statespace.utils.constants`` for them to be detected by
|
|
465
585
|
sampling methods.
|
|
466
586
|
"""
|
|
467
|
-
|
|
587
|
+
return self._coords_info.to_dict()
|
|
468
588
|
|
|
469
589
|
@property
|
|
470
|
-
def param_dims(self) -> dict[str,
|
|
590
|
+
def param_dims(self) -> dict[str, tuple[str, ...]]:
|
|
471
591
|
"""
|
|
472
592
|
Dictionary of named dimensions for each model parameter
|
|
473
593
|
|
|
@@ -475,8 +595,17 @@ class PyMCStateSpace:
|
|
|
475
595
|
PyMC random variable. Dimensions should come from the default names defined in ``statespace.utils.constants``
|
|
476
596
|
for them to be detected by sampling methods.
|
|
477
597
|
|
|
598
|
+
Note: Scalar parameters (with dims=None) are not included in this dictionary.
|
|
478
599
|
"""
|
|
479
|
-
|
|
600
|
+
return {param.name: param.dims for param in self._param_info if param.dims is not None}
|
|
601
|
+
|
|
602
|
+
@property
|
|
603
|
+
def _name_to_variable(self):
|
|
604
|
+
return self._tensor_variable_info.to_dict()
|
|
605
|
+
|
|
606
|
+
@property
|
|
607
|
+
def _name_to_data(self):
|
|
608
|
+
return self._tensor_data_info.to_dict()
|
|
480
609
|
|
|
481
610
|
def add_default_priors(self) -> None:
|
|
482
611
|
"""
|
|
@@ -520,14 +649,15 @@ class PyMCStateSpace:
|
|
|
520
649
|
f"parameters."
|
|
521
650
|
)
|
|
522
651
|
|
|
523
|
-
if name in self.
|
|
652
|
+
if name in self._tensor_variable_info:
|
|
524
653
|
raise ValueError(
|
|
525
654
|
f"{name} is already a registered placeholder variable with shape "
|
|
526
|
-
f"{self.
|
|
655
|
+
f"{self._tensor_variable_info[name].type.shape}"
|
|
527
656
|
)
|
|
528
657
|
|
|
529
658
|
placeholder = pt.tensor(name, shape=shape, dtype=dtype)
|
|
530
|
-
|
|
659
|
+
tensor_var = SymbolicVariable(name=name, symbolic_variable=placeholder)
|
|
660
|
+
self._tensor_variable_info = self._tensor_variable_info.add(tensor_var)
|
|
531
661
|
return placeholder
|
|
532
662
|
|
|
533
663
|
def make_and_register_data(
|
|
@@ -559,14 +689,15 @@ class PyMCStateSpace:
|
|
|
559
689
|
f"parameters."
|
|
560
690
|
)
|
|
561
691
|
|
|
562
|
-
if name in self.
|
|
692
|
+
if name in self._tensor_data_info:
|
|
563
693
|
raise ValueError(
|
|
564
694
|
f"{name} is already a registered placeholder variable with shape "
|
|
565
|
-
f"{self.
|
|
695
|
+
f"{self._tensor_data_info[name].type.shape}"
|
|
566
696
|
)
|
|
567
697
|
|
|
568
698
|
placeholder = pt.tensor(name, shape=shape, dtype=dtype)
|
|
569
|
-
|
|
699
|
+
tensor_data = SymbolicData(name=name, symbolic_data=placeholder)
|
|
700
|
+
self._tensor_data_info = self._tensor_data_info.add(tensor_data)
|
|
570
701
|
return placeholder
|
|
571
702
|
|
|
572
703
|
def make_symbolic_graph(self) -> None:
|
|
@@ -741,10 +872,8 @@ class PyMCStateSpace:
|
|
|
741
872
|
|
|
742
873
|
Only used when models require exogenous data. The observed data is not added to the model using this method!
|
|
743
874
|
"""
|
|
744
|
-
|
|
745
|
-
|
|
746
|
-
data_names = self.data_names
|
|
747
|
-
except NotImplementedError:
|
|
875
|
+
data_names = self.data_names
|
|
876
|
+
if not data_names:
|
|
748
877
|
return
|
|
749
878
|
|
|
750
879
|
pymc_model = modelcontext(None)
|
|
@@ -892,7 +1021,6 @@ class PyMCStateSpace:
|
|
|
892
1021
|
.. deprecated:: 0.2.5
|
|
893
1022
|
The `mode` argument is deprecated and will be removed in a future version. Pass ``mode`` to the
|
|
894
1023
|
model constructor, or manually specify ``compile_kwargs`` in sampling functions instead.
|
|
895
|
-
|
|
896
1024
|
"""
|
|
897
1025
|
if mode is not None:
|
|
898
1026
|
warnings.warn(
|
|
@@ -1430,7 +1558,11 @@ class PyMCStateSpace:
|
|
|
1430
1558
|
"""
|
|
1431
1559
|
|
|
1432
1560
|
return self._sample_conditional(
|
|
1433
|
-
idata=idata,
|
|
1561
|
+
idata=idata,
|
|
1562
|
+
group="prior",
|
|
1563
|
+
random_seed=random_seed,
|
|
1564
|
+
mvn_method=mvn_method,
|
|
1565
|
+
**kwargs,
|
|
1434
1566
|
)
|
|
1435
1567
|
|
|
1436
1568
|
def sample_conditional_posterior(
|
|
@@ -1473,7 +1605,11 @@ class PyMCStateSpace:
|
|
|
1473
1605
|
"""
|
|
1474
1606
|
|
|
1475
1607
|
return self._sample_conditional(
|
|
1476
|
-
idata=idata,
|
|
1608
|
+
idata=idata,
|
|
1609
|
+
group="posterior",
|
|
1610
|
+
random_seed=random_seed,
|
|
1611
|
+
mvn_method=mvn_method,
|
|
1612
|
+
**kwargs,
|
|
1477
1613
|
)
|
|
1478
1614
|
|
|
1479
1615
|
def sample_unconditional_prior(
|
|
@@ -2080,11 +2216,11 @@ class PyMCStateSpace:
|
|
|
2080
2216
|
forecast_index: pd.RangeIndex | pd.DatetimeIndex,
|
|
2081
2217
|
name=None,
|
|
2082
2218
|
):
|
|
2083
|
-
|
|
2084
|
-
var_to_dims = {key: info["dims"][1:] for key, info in self.data_info.items()}
|
|
2085
|
-
except NotImplementedError:
|
|
2219
|
+
if not self.data_info:
|
|
2086
2220
|
return scenario
|
|
2087
2221
|
|
|
2222
|
+
var_to_dims = {key: info["dims"][1:] for key, info in self.data_info.items()}
|
|
2223
|
+
|
|
2088
2224
|
if any(len(dims) > 1 for dims in var_to_dims.values()):
|
|
2089
2225
|
raise NotImplementedError(">2d exogenous data is not yet supported.")
|
|
2090
2226
|
coords = {
|
|
@@ -2500,13 +2636,14 @@ class PyMCStateSpace:
|
|
|
2500
2636
|
next_x = c + T @ x + R @ shock
|
|
2501
2637
|
return next_x
|
|
2502
2638
|
|
|
2503
|
-
irf
|
|
2639
|
+
irf = pytensor.scan(
|
|
2504
2640
|
irf_step,
|
|
2505
2641
|
sequences=[shock_trajectory],
|
|
2506
2642
|
outputs_info=[x0],
|
|
2507
2643
|
non_sequences=[c, T, R],
|
|
2508
2644
|
n_steps=n_steps,
|
|
2509
2645
|
strict=True,
|
|
2646
|
+
return_updates=False,
|
|
2510
2647
|
)
|
|
2511
2648
|
|
|
2512
2649
|
pm.Deterministic("irf", irf, dims=[TIME_DIM, ALL_STATE_DIM])
|
|
@@ -1,4 +1,3 @@
|
|
|
1
|
-
import numpy as np
|
|
2
1
|
import pymc as pm
|
|
3
2
|
import pytensor
|
|
4
3
|
import pytensor.tensor as pt
|
|
@@ -8,7 +7,9 @@ from pymc.distributions.dist_math import check_parameters
|
|
|
8
7
|
from pymc.distributions.distribution import Continuous, SymbolicRandomVariable
|
|
9
8
|
from pymc.distributions.shape_utils import get_support_shape_1d
|
|
10
9
|
from pymc.logprob.abstract import _logprob
|
|
10
|
+
from pymc.pytensorf import normalize_rng_param
|
|
11
11
|
from pytensor.graph.basic import Node
|
|
12
|
+
from pytensor.tensor.random import multivariate_normal
|
|
12
13
|
|
|
13
14
|
floatX = pytensor.config.floatX
|
|
14
15
|
COV_ZERO_TOL = 0
|
|
@@ -152,6 +153,7 @@ class _LinearGaussianStateSpace(Continuous):
|
|
|
152
153
|
Q,
|
|
153
154
|
steps,
|
|
154
155
|
size=None,
|
|
156
|
+
rng=None,
|
|
155
157
|
sequence_names=None,
|
|
156
158
|
append_x0=True,
|
|
157
159
|
method="svd",
|
|
@@ -178,7 +180,7 @@ class _LinearGaussianStateSpace(Continuous):
|
|
|
178
180
|
]
|
|
179
181
|
non_sequences = [x for x in [c_, d_, T_, Z_, R_, H_, Q_] if x not in sequences]
|
|
180
182
|
|
|
181
|
-
rng =
|
|
183
|
+
rng = normalize_rng_param(rng)
|
|
182
184
|
|
|
183
185
|
def sort_args(args):
|
|
184
186
|
sorted_args = []
|
|
@@ -197,10 +199,9 @@ class _LinearGaussianStateSpace(Continuous):
|
|
|
197
199
|
n_seq = len(sequence_names)
|
|
198
200
|
|
|
199
201
|
def step_fn(*args):
|
|
200
|
-
seqs, state, non_seqs = args[:n_seq], args[n_seq
|
|
201
|
-
non_seqs, rng = non_seqs[:-1], non_seqs[-1]
|
|
202
|
+
seqs, (rng, state, *non_seqs) = args[:n_seq], args[n_seq:]
|
|
202
203
|
|
|
203
|
-
c, d, T, Z, R, H, Q = sort_args(seqs
|
|
204
|
+
c, d, T, Z, R, H, Q = sort_args((*seqs, *non_seqs))
|
|
204
205
|
k = T.shape[0]
|
|
205
206
|
a = state[:k]
|
|
206
207
|
|
|
@@ -219,7 +220,7 @@ class _LinearGaussianStateSpace(Continuous):
|
|
|
219
220
|
|
|
220
221
|
next_state = pt.concatenate([a_next, y_next], axis=0)
|
|
221
222
|
|
|
222
|
-
return
|
|
223
|
+
return next_rng, next_state
|
|
223
224
|
|
|
224
225
|
Z_init = Z_ if Z_ in non_sequences else Z_[0]
|
|
225
226
|
H_init = H_ if H_ in non_sequences else H_[0]
|
|
@@ -229,13 +230,14 @@ class _LinearGaussianStateSpace(Continuous):
|
|
|
229
230
|
|
|
230
231
|
init_dist_ = pt.concatenate([init_x_, init_y_], axis=0)
|
|
231
232
|
|
|
232
|
-
|
|
233
|
+
ss_rng, statespace = pytensor.scan(
|
|
233
234
|
step_fn,
|
|
234
|
-
outputs_info=[init_dist_],
|
|
235
|
+
outputs_info=[rng, init_dist_],
|
|
235
236
|
sequences=None if len(sequences) == 0 else sequences,
|
|
236
|
-
non_sequences=[*non_sequences
|
|
237
|
+
non_sequences=[*non_sequences],
|
|
237
238
|
n_steps=steps,
|
|
238
239
|
strict=True,
|
|
240
|
+
return_updates=False,
|
|
239
241
|
)
|
|
240
242
|
|
|
241
243
|
if append_x0:
|
|
@@ -245,7 +247,6 @@ class _LinearGaussianStateSpace(Continuous):
|
|
|
245
247
|
statespace_ = statespace
|
|
246
248
|
statespace_ = pt.specify_shape(statespace_, (steps, None))
|
|
247
249
|
|
|
248
|
-
(ss_rng,) = tuple(updates.values())
|
|
249
250
|
linear_gaussian_ss_op = LinearGaussianStateSpaceRV(
|
|
250
251
|
inputs=[a0_, P0_, c_, d_, T_, Z_, R_, H_, Q_, steps, rng],
|
|
251
252
|
outputs=[ss_rng, statespace_],
|
|
@@ -368,41 +369,25 @@ class SequenceMvNormal(Continuous):
|
|
|
368
369
|
|
|
369
370
|
@classmethod
|
|
370
371
|
def dist(cls, mus, covs, logp, method="svd", **kwargs):
|
|
372
|
+
mus, covs, logp = map(pt.as_tensor_variable, (mus, covs, logp))
|
|
371
373
|
return super().dist([mus, covs, logp], method=method, **kwargs)
|
|
372
374
|
|
|
373
375
|
@classmethod
|
|
374
|
-
def rv_op(cls, mus, covs, logp, method="svd", size=None):
|
|
375
|
-
|
|
376
|
-
if mus.ndim > 2:
|
|
377
|
-
mus = pt.moveaxis(mus, -2, 0)
|
|
378
|
-
if covs.ndim > 3:
|
|
379
|
-
covs = pt.moveaxis(covs, -3, 0)
|
|
380
|
-
|
|
381
|
-
mus_, covs_ = mus.type(), covs.type()
|
|
382
|
-
|
|
376
|
+
def rv_op(cls, mus, covs, logp, method="svd", size=None, rng=None):
|
|
377
|
+
rng = normalize_rng_param(rng)
|
|
383
378
|
logp_ = logp.type()
|
|
384
|
-
rng = pytensor.shared(np.random.default_rng())
|
|
385
379
|
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
mvn_seq, updates = pytensor.scan(
|
|
391
|
-
step, sequences=[mus_, covs_], non_sequences=[rng], strict=True, n_steps=mus_.shape[0]
|
|
392
|
-
)
|
|
393
|
-
mvn_seq = pt.specify_shape(mvn_seq, mus.type.shape)
|
|
394
|
-
|
|
395
|
-
# Move time axis back to position -2 so batches are on the left
|
|
396
|
-
if mvn_seq.ndim > 2:
|
|
397
|
-
mvn_seq = pt.moveaxis(mvn_seq, 0, -2)
|
|
398
|
-
|
|
399
|
-
(seq_mvn_rng,) = tuple(updates.values())
|
|
380
|
+
mus_, covs_ = mus.type(), covs.type()
|
|
381
|
+
seq_mvn_rng, mvn_seq = multivariate_normal(
|
|
382
|
+
mean=mus_, cov=covs_, rng=rng, method=method
|
|
383
|
+
).owner.outputs
|
|
400
384
|
|
|
401
385
|
mvn_seq_op = KalmanFilterRV(
|
|
402
386
|
inputs=[mus_, covs_, logp_, rng], outputs=[seq_mvn_rng, mvn_seq], ndim_supp=2
|
|
403
387
|
)
|
|
404
388
|
|
|
405
389
|
mvn_seq = mvn_seq_op(mus, covs, logp, rng)
|
|
390
|
+
|
|
406
391
|
return mvn_seq
|
|
407
392
|
|
|
408
393
|
|
|
@@ -8,7 +8,7 @@ from pymc.pytensorf import constant_fold
|
|
|
8
8
|
from pytensor.graph.basic import Variable
|
|
9
9
|
from pytensor.raise_op import Assert
|
|
10
10
|
from pytensor.tensor import TensorVariable
|
|
11
|
-
from pytensor.tensor.
|
|
11
|
+
from pytensor.tensor.linalg import solve_triangular
|
|
12
12
|
|
|
13
13
|
from pymc_extras.statespace.filters.utilities import (
|
|
14
14
|
quad_form_sym,
|
|
@@ -148,10 +148,9 @@ class BaseFilter(ABC):
|
|
|
148
148
|
R,
|
|
149
149
|
H,
|
|
150
150
|
Q,
|
|
151
|
-
return_updates=False,
|
|
152
151
|
missing_fill_value=None,
|
|
153
152
|
cov_jitter=None,
|
|
154
|
-
) -> list[TensorVariable]
|
|
153
|
+
) -> list[TensorVariable]:
|
|
155
154
|
"""
|
|
156
155
|
Construct the computation graph for the Kalman filter. See [1] for details.
|
|
157
156
|
|
|
@@ -211,20 +210,17 @@ class BaseFilter(ABC):
|
|
|
211
210
|
if len(sequences) > 0:
|
|
212
211
|
sequences = self.add_check_on_time_varying_shapes(data, sequences)
|
|
213
212
|
|
|
214
|
-
results
|
|
213
|
+
results = pytensor.scan(
|
|
215
214
|
self.kalman_step,
|
|
216
215
|
sequences=[data, *sequences],
|
|
217
216
|
outputs_info=[None, a0, None, None, P0, None, None],
|
|
218
217
|
non_sequences=non_sequences,
|
|
219
218
|
name="forward_kalman_pass",
|
|
220
219
|
strict=False,
|
|
220
|
+
return_updates=False,
|
|
221
221
|
)
|
|
222
222
|
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
if return_updates:
|
|
226
|
-
return filter_results, updates
|
|
227
|
-
return filter_results
|
|
223
|
+
return self._postprocess_scan_results(results, a0, P0, n=data.type.shape[0])
|
|
228
224
|
|
|
229
225
|
def _postprocess_scan_results(self, results, a0, P0, n) -> list[TensorVariable]:
|
|
230
226
|
"""
|
|
@@ -652,7 +648,9 @@ class SquareRootFilter(BaseFilter):
|
|
|
652
648
|
y_hat = Z.dot(a) + d
|
|
653
649
|
v = y - y_hat
|
|
654
650
|
|
|
655
|
-
H_chol = pytensor.ifelse(
|
|
651
|
+
H_chol = pytensor.ifelse(
|
|
652
|
+
pt.all(pt.eq(H, 0.0)), H, pt.linalg.cholesky(H, lower=True, on_error="nan")
|
|
653
|
+
)
|
|
656
654
|
|
|
657
655
|
# The following notation comes from https://ipnpr.jpl.nasa.gov/progress_report/42-233/42-233A.pdf
|
|
658
656
|
# Construct upper-triangular block matrix A = [[chol(H), Z @ L_pred],
|
|
@@ -694,8 +692,10 @@ class SquareRootFilter(BaseFilter):
|
|
|
694
692
|
"""
|
|
695
693
|
return [a, P_chol, pt.zeros(())]
|
|
696
694
|
|
|
695
|
+
degenerate = pt.eq(all_nan_flag, 1.0)
|
|
696
|
+
F_chol = pytensor.ifelse(degenerate, pt.eye(*F_chol.shape), F_chol)
|
|
697
697
|
[a_filtered, P_chol_filtered, ll] = pytensor.ifelse(
|
|
698
|
-
|
|
698
|
+
degenerate,
|
|
699
699
|
compute_degenerate(P_chol_filtered, F_chol, K_F_chol, v),
|
|
700
700
|
compute_non_degenerate(P_chol_filtered, F_chol, K_F_chol, v),
|
|
701
701
|
)
|
|
@@ -786,11 +786,12 @@ class UnivariateFilter(BaseFilter):
|
|
|
786
786
|
H_masked = W.dot(H)
|
|
787
787
|
y_masked = pt.set_subtensor(y[nan_mask], 0.0)
|
|
788
788
|
|
|
789
|
-
result
|
|
789
|
+
result = pytensor.scan(
|
|
790
790
|
self._univariate_inner_filter_step,
|
|
791
791
|
sequences=[y_masked, Z_masked, d, pt.diag(H_masked), nan_mask],
|
|
792
792
|
outputs_info=[a, P, None, None, None],
|
|
793
793
|
name="univariate_inner_scan",
|
|
794
|
+
return_updates=False,
|
|
794
795
|
)
|
|
795
796
|
|
|
796
797
|
a_filtered, P_filtered, obs_mu, obs_cov, ll_inner = result
|
|
@@ -76,16 +76,16 @@ class KalmanSmoother:
|
|
|
76
76
|
self.seq_names = seq_names
|
|
77
77
|
self.non_seq_names = non_seq_names
|
|
78
78
|
|
|
79
|
-
|
|
79
|
+
smoothed_states, smoothed_covariances = pytensor.scan(
|
|
80
80
|
self.smoother_step,
|
|
81
81
|
sequences=[filtered_states[:-1], filtered_covariances[:-1], *sequences],
|
|
82
82
|
outputs_info=[a_last, P_last],
|
|
83
83
|
non_sequences=non_sequences,
|
|
84
84
|
go_backwards=True,
|
|
85
85
|
name="kalman_smoother",
|
|
86
|
+
return_updates=False,
|
|
86
87
|
)
|
|
87
88
|
|
|
88
|
-
smoothed_states, smoothed_covariances = smoother_result
|
|
89
89
|
smoothed_states = pt.concatenate(
|
|
90
90
|
[smoothed_states[::-1], pt.expand_dims(a_last, axis=(0,))], axis=0
|
|
91
91
|
)
|