pymc-extras 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pymc_extras/__init__.py +29 -0
- pymc_extras/distributions/__init__.py +40 -0
- pymc_extras/distributions/continuous.py +351 -0
- pymc_extras/distributions/discrete.py +399 -0
- pymc_extras/distributions/histogram_utils.py +163 -0
- pymc_extras/distributions/multivariate/__init__.py +3 -0
- pymc_extras/distributions/multivariate/r2d2m2cp.py +446 -0
- pymc_extras/distributions/timeseries.py +356 -0
- pymc_extras/gp/__init__.py +18 -0
- pymc_extras/gp/latent_approx.py +183 -0
- pymc_extras/inference/__init__.py +18 -0
- pymc_extras/inference/find_map.py +431 -0
- pymc_extras/inference/fit.py +44 -0
- pymc_extras/inference/laplace.py +570 -0
- pymc_extras/inference/pathfinder.py +134 -0
- pymc_extras/inference/smc/__init__.py +13 -0
- pymc_extras/inference/smc/sampling.py +451 -0
- pymc_extras/linearmodel.py +130 -0
- pymc_extras/model/__init__.py +0 -0
- pymc_extras/model/marginal/__init__.py +0 -0
- pymc_extras/model/marginal/distributions.py +276 -0
- pymc_extras/model/marginal/graph_analysis.py +372 -0
- pymc_extras/model/marginal/marginal_model.py +595 -0
- pymc_extras/model/model_api.py +56 -0
- pymc_extras/model/transforms/__init__.py +0 -0
- pymc_extras/model/transforms/autoreparam.py +434 -0
- pymc_extras/model_builder.py +759 -0
- pymc_extras/preprocessing/__init__.py +0 -0
- pymc_extras/preprocessing/standard_scaler.py +17 -0
- pymc_extras/printing.py +182 -0
- pymc_extras/statespace/__init__.py +13 -0
- pymc_extras/statespace/core/__init__.py +7 -0
- pymc_extras/statespace/core/compile.py +48 -0
- pymc_extras/statespace/core/representation.py +438 -0
- pymc_extras/statespace/core/statespace.py +2268 -0
- pymc_extras/statespace/filters/__init__.py +15 -0
- pymc_extras/statespace/filters/distributions.py +453 -0
- pymc_extras/statespace/filters/kalman_filter.py +820 -0
- pymc_extras/statespace/filters/kalman_smoother.py +126 -0
- pymc_extras/statespace/filters/utilities.py +59 -0
- pymc_extras/statespace/models/ETS.py +670 -0
- pymc_extras/statespace/models/SARIMAX.py +536 -0
- pymc_extras/statespace/models/VARMAX.py +393 -0
- pymc_extras/statespace/models/__init__.py +6 -0
- pymc_extras/statespace/models/structural.py +1651 -0
- pymc_extras/statespace/models/utilities.py +387 -0
- pymc_extras/statespace/utils/__init__.py +0 -0
- pymc_extras/statespace/utils/constants.py +74 -0
- pymc_extras/statespace/utils/coord_tools.py +0 -0
- pymc_extras/statespace/utils/data_tools.py +182 -0
- pymc_extras/utils/__init__.py +23 -0
- pymc_extras/utils/linear_cg.py +290 -0
- pymc_extras/utils/pivoted_cholesky.py +69 -0
- pymc_extras/utils/prior.py +200 -0
- pymc_extras/utils/spline.py +131 -0
- pymc_extras/version.py +11 -0
- pymc_extras/version.txt +1 -0
- pymc_extras-0.2.0.dist-info/LICENSE +212 -0
- pymc_extras-0.2.0.dist-info/METADATA +99 -0
- pymc_extras-0.2.0.dist-info/RECORD +101 -0
- pymc_extras-0.2.0.dist-info/WHEEL +5 -0
- pymc_extras-0.2.0.dist-info/top_level.txt +2 -0
- tests/__init__.py +13 -0
- tests/distributions/__init__.py +19 -0
- tests/distributions/test_continuous.py +185 -0
- tests/distributions/test_discrete.py +210 -0
- tests/distributions/test_discrete_markov_chain.py +258 -0
- tests/distributions/test_multivariate.py +304 -0
- tests/model/__init__.py +0 -0
- tests/model/marginal/__init__.py +0 -0
- tests/model/marginal/test_distributions.py +131 -0
- tests/model/marginal/test_graph_analysis.py +182 -0
- tests/model/marginal/test_marginal_model.py +867 -0
- tests/model/test_model_api.py +29 -0
- tests/statespace/__init__.py +0 -0
- tests/statespace/test_ETS.py +411 -0
- tests/statespace/test_SARIMAX.py +405 -0
- tests/statespace/test_VARMAX.py +184 -0
- tests/statespace/test_coord_assignment.py +116 -0
- tests/statespace/test_distributions.py +270 -0
- tests/statespace/test_kalman_filter.py +326 -0
- tests/statespace/test_representation.py +175 -0
- tests/statespace/test_statespace.py +818 -0
- tests/statespace/test_statespace_JAX.py +156 -0
- tests/statespace/test_structural.py +829 -0
- tests/statespace/utilities/__init__.py +0 -0
- tests/statespace/utilities/shared_fixtures.py +9 -0
- tests/statespace/utilities/statsmodel_local_level.py +42 -0
- tests/statespace/utilities/test_helpers.py +310 -0
- tests/test_blackjax_smc.py +222 -0
- tests/test_find_map.py +98 -0
- tests/test_histogram_approximation.py +109 -0
- tests/test_laplace.py +238 -0
- tests/test_linearmodel.py +208 -0
- tests/test_model_builder.py +306 -0
- tests/test_pathfinder.py +45 -0
- tests/test_pivoted_cholesky.py +24 -0
- tests/test_printing.py +98 -0
- tests/test_prior_from_trace.py +172 -0
- tests/test_splines.py +77 -0
- tests/utils.py +31 -0
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
import pytensor
|
|
2
|
+
import pytensor.tensor as pt
|
|
3
|
+
|
|
4
|
+
from pytensor.compile import get_mode
|
|
5
|
+
from pytensor.tensor.nlinalg import matrix_dot
|
|
6
|
+
|
|
7
|
+
from pymc_extras.statespace.filters.utilities import (
|
|
8
|
+
quad_form_sym,
|
|
9
|
+
split_vars_into_seq_and_nonseq,
|
|
10
|
+
stabilize,
|
|
11
|
+
)
|
|
12
|
+
from pymc_extras.statespace.utils.constants import JITTER_DEFAULT
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class KalmanSmoother:
    """
    Kalman Smoother

    Builds a symbolic backward pass (a ``pytensor.scan`` run with ``go_backwards=True``) over the
    filtered states and covariances produced by a Kalman filter. The recursion is started from the
    last filtered state/covariance, and every step combines the filtered moments at one time step
    with the smoothed moments of the following time step.

    Parameters
    ----------
    mode : str, optional
        Compilation mode handed to ``pytensor.compile.get_mode`` when the scan is built. It is used
        by ``build_graph`` unless that method is given an explicit ``mode`` of its own.
    """

    def __init__(self, mode: str | None = None):
        self.mode = mode
        self.cov_jitter = JITTER_DEFAULT

        # Names ("T", "R", "Q") of the matrices passed to scan as sequences / non-sequences.
        # Filled in by build_graph and read back by unpack_args inside the scan step.
        self.seq_names = []
        self.non_seq_names = []

    def unpack_args(self, args):
        """
        Put the arguments received by the inner scan function into a standard order.

        The order of inputs to the inner scan function is not known in advance, since some, all, or
        none of the input matrices can be time varying. Arguments are fed to the inner function in
        the order: sequences, outputs_info, non-sequences. This function works out which matrices
        are where, and returns them in the order expected by ``smoother_step``.

        The standard order is: a, P, a_smooth, P_smooth, T, R, Q

        Parameters
        ----------
        args : sequence
            Positional arguments exactly as received by the scan step function.

        Returns
        -------
        sequence
            ``a, P, a_smooth, P_smooth, T, R, Q``. A list when no matrix is time varying,
            otherwise a tuple; callers are expected to unpack it.
        """
        args = list(args)
        n_seq = len(self.seq_names)

        # If there are no sequence parameters (all params are static),
        # no changes are needed, params will be in order.
        if n_seq == 0:
            return args

        # The first two args are always a and P
        a, P, *rest = args

        # There are always two outputs_info wedged between the seqs and non_seqs
        seqs = rest[:n_seq]
        a_smooth, P_smooth = rest[n_seq : n_seq + 2]
        non_seqs = rest[n_seq + 2 :]

        ordered = []
        for name in ("T", "R", "Q"):
            if name in self.seq_names:
                ordered.append(seqs[self.seq_names.index(name)])
            else:
                ordered.append(non_seqs[self.non_seq_names.index(name)])
        T, R, Q = ordered

        return a, P, a_smooth, P_smooth, T, R, Q

    def build_graph(
        self, T, R, Q, filtered_states, filtered_covariances, mode=None, cov_jitter=JITTER_DEFAULT
    ):
        """
        Construct the symbolic graph of smoothed states and covariances.

        Parameters
        ----------
        T, R, Q : TensorVariable
            Matrices used by ``predict``: ``T`` maps the state forward and ``R``, ``Q`` enter the
            predicted covariance as ``R Q R^T``. Each one is passed to scan as a sequence or as a
            non-sequence depending on ``split_vars_into_seq_and_nonseq``.
        filtered_states : TensorVariable
            Filtered state means; its static type shape is unpacked as ``(n, k)``.
        filtered_covariances : TensorVariable
            Filtered state covariances; the last entry is given the shape ``(k, k)``.
        mode : str, optional
            Compilation mode for the scan. When ``None``, the mode given to the constructor is
            kept instead of being discarded.
        cov_jitter : float
            Value added to the diagonal of covariance matrices by ``stabilize``.

        Returns
        -------
        smoothed_states : TensorVariable
            Smoothed state means in forward time order; the final entry is the last filtered state.
        smoothed_covariances : TensorVariable
            Smoothed state covariances in forward time order; the final entry is the last filtered
            covariance.
        """
        # Only replace the constructor's mode when the caller actually supplies one.
        if mode is not None:
            self.mode = mode
        self.cov_jitter = cov_jitter

        _, k = filtered_states.type.shape

        # The backward recursion starts from the final filtered moments.
        a_last = pt.specify_shape(filtered_states[-1], (k,))
        P_last = pt.specify_shape(filtered_covariances[-1], (k, k))

        sequences, non_sequences, seq_names, non_seq_names = split_vars_into_seq_and_nonseq(
            [T, R, Q], ["T", "R", "Q"]
        )

        self.seq_names = seq_names
        self.non_seq_names = non_seq_names

        # NOTE(review): time-varying T/R/Q are passed at full length while the filtered moments
        # drop their last step -- confirm scan pairs them up as intended when going backwards.
        smoother_result, _ = pytensor.scan(
            self.smoother_step,
            sequences=[filtered_states[:-1], filtered_covariances[:-1], *sequences],
            outputs_info=[a_last, P_last],
            non_sequences=non_sequences,
            go_backwards=True,
            name="kalman_smoother",
            mode=get_mode(self.mode),
        )

        smoothed_states, smoothed_covariances = smoother_result

        # Scan ran backwards in time: flip the results and append the starting values at the end.
        smoothed_states = pt.concatenate(
            [smoothed_states[::-1], pt.expand_dims(a_last, axis=(0,))], axis=0
        )
        smoothed_covariances = pt.concatenate(
            [smoothed_covariances[::-1], pt.expand_dims(P_last, axis=(0,))], axis=0
        )

        smoothed_states.name = "smoothed_states"
        smoothed_covariances.name = "smoothed_covariances"

        return smoothed_states, smoothed_covariances

    def smoother_step(self, *args):
        """
        Inner scan function: perform one step of the backward smoothing recursion.

        Parameters
        ----------
        *args
            Arguments in scan order (sequences, outputs_info, non-sequences); they are reordered
            by ``unpack_args`` into ``a, P, a_smooth, P_smooth, T, R, Q``.

        Returns
        -------
        a_smooth_next : TensorVariable
            Smoothed state mean for the current step.
        P_smooth_next : TensorVariable
            Smoothed state covariance for the current step, with the same static shape as the
            incoming ``P_smooth``.
        """
        a, P, a_smooth, P_smooth, T, R, Q = self.unpack_args(args)
        a_hat, P_hat = self.predict(a, P, T, R, Q)

        # Use pinv, otherwise P_hat is singular when there is missing data
        smoother_gain = matrix_dot(pt.linalg.pinv(P_hat), T, P).T
        a_smooth_next = a + smoother_gain @ (a_smooth - a_hat)

        P_smooth_next = P + quad_form_sym(smoother_gain, P_smooth - P_hat)

        # Add the configured jitter exactly once. A second stabilize() call here would add the
        # module default on top and ignore self.cov_jitter.
        P_smooth_next = stabilize(P_smooth_next, self.cov_jitter)
        P_smooth_next = pt.specify_shape(P_smooth_next, P_smooth.type.shape)

        return a_smooth_next, P_smooth_next

    def predict(self, a, P, T, R, Q):
        """
        Propagate a state mean and covariance one step forward.

        Returns
        -------
        a_hat : TensorVariable
            ``T @ a``.
        P_hat : TensorVariable
            ``T P T^T + R Q R^T`` (each term symmetrized by ``quad_form_sym``) with
            ``self.cov_jitter`` added to the diagonal.
        """
        a_hat = T.dot(a)
        P_hat = quad_form_sym(T, P) + quad_form_sym(R, Q)
        P_hat = stabilize(P_hat, self.cov_jitter)

        return a_hat, P_hat
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
import pytensor.tensor as pt
|
|
2
|
+
|
|
3
|
+
from pytensor.tensor.nlinalg import matrix_dot
|
|
4
|
+
|
|
5
|
+
from pymc_extras.statespace.utils.constants import JITTER_DEFAULT, NEVER_TIME_VARYING, VECTOR_VALUED
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def decide_if_x_time_varies(x, name):
    """
    Report whether ``x`` carries an extra leading (time) dimension.

    Parameters
    ----------
    x
        Object exposing an ``ndim`` attribute.
    name : str
        Name of the quantity. Names in ``NEVER_TIME_VARYING`` are always reported as static
        (without looking at ``x``); names in ``VECTOR_VALUED`` are treated as vectors, everything
        else as a matrix.

    Returns
    -------
    bool
        ``True`` for a 2d vector or a 3d matrix, ``False`` for a 1d vector or a 2d matrix.

    Raises
    ------
    ValueError
        If the number of dimensions is neither the static nor the time varying one.
    """
    if name in NEVER_TIME_VARYING:
        return False

    ndim = x.ndim

    if name in VECTOR_VALUED:
        # Vectors: 1d is static, 2d is time varying.
        if ndim == 1:
            return False
        if ndim == 2:
            return True
        raise ValueError(
            f"Vector {name} has {ndim} dimensions; it should have either 1 (static),"
            f" or 2 (time varying )"
        )

    # Matrices: 2d is static, 3d is time varying.
    if ndim == 2:
        return False
    if ndim == 3:
        return True
    raise ValueError(
        f"Matrix {name} has {ndim} dimensions; it should have either"
        f" 2 (static), or 3 (time varying)."
    )
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def split_vars_into_seq_and_nonseq(params, param_names):
    """
    Partition inputs into time varying ones and static ones, as scan requires.

    Parameters
    ----------
    params : iterable
        The variables to classify.
    param_names : iterable of str
        Name of each variable, paired positionally with ``params``.

    Returns
    -------
    sequences : list
        Variables that ``decide_if_x_time_varies`` reports as time varying.
    non_sequences : list
        All remaining variables.
    seq_names : list of str
        Names matching ``sequences``, in the same order.
    non_seq_names : list of str
        Names matching ``non_sequences``, in the same order.
    """
    sequences, seq_names = [], []
    non_sequences, non_seq_names = [], []

    for param, name in zip(params, param_names):
        # Pick the pair of lists this variable belongs to, then record value and name together.
        if decide_if_x_time_varies(param, name):
            values, labels = sequences, seq_names
        else:
            values, labels = non_sequences, non_seq_names
        values.append(param)
        labels.append(name)

    return sequences, non_sequences, seq_names, non_seq_names
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def stabilize(cov, jitter=JITTER_DEFAULT):
    """
    Return ``cov`` with ``jitter`` added to every diagonal entry.

    The offset keeps the diagonal away from zero; ``cov`` itself is not modified.
    """
    diagonal_offset = pt.identity_like(cov) * jitter
    return cov + diagonal_offset
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def quad_form_sym(A, B):
    """
    Compute the quadratic form ``A @ B @ A.T`` and symmetrize the result.

    The product is averaged with its own transpose, so the returned matrix is symmetric by
    construction.
    """
    sandwich = matrix_dot(A, B, A.T)
    symmetric = 0.5 * (sandwich + sandwich.T)
    return symmetric
|