pymc-extras 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pymc_extras/__init__.py +29 -0
- pymc_extras/distributions/__init__.py +40 -0
- pymc_extras/distributions/continuous.py +351 -0
- pymc_extras/distributions/discrete.py +399 -0
- pymc_extras/distributions/histogram_utils.py +163 -0
- pymc_extras/distributions/multivariate/__init__.py +3 -0
- pymc_extras/distributions/multivariate/r2d2m2cp.py +446 -0
- pymc_extras/distributions/timeseries.py +356 -0
- pymc_extras/gp/__init__.py +18 -0
- pymc_extras/gp/latent_approx.py +183 -0
- pymc_extras/inference/__init__.py +18 -0
- pymc_extras/inference/find_map.py +431 -0
- pymc_extras/inference/fit.py +44 -0
- pymc_extras/inference/laplace.py +570 -0
- pymc_extras/inference/pathfinder.py +134 -0
- pymc_extras/inference/smc/__init__.py +13 -0
- pymc_extras/inference/smc/sampling.py +451 -0
- pymc_extras/linearmodel.py +130 -0
- pymc_extras/model/__init__.py +0 -0
- pymc_extras/model/marginal/__init__.py +0 -0
- pymc_extras/model/marginal/distributions.py +276 -0
- pymc_extras/model/marginal/graph_analysis.py +372 -0
- pymc_extras/model/marginal/marginal_model.py +595 -0
- pymc_extras/model/model_api.py +56 -0
- pymc_extras/model/transforms/__init__.py +0 -0
- pymc_extras/model/transforms/autoreparam.py +434 -0
- pymc_extras/model_builder.py +759 -0
- pymc_extras/preprocessing/__init__.py +0 -0
- pymc_extras/preprocessing/standard_scaler.py +17 -0
- pymc_extras/printing.py +182 -0
- pymc_extras/statespace/__init__.py +13 -0
- pymc_extras/statespace/core/__init__.py +7 -0
- pymc_extras/statespace/core/compile.py +48 -0
- pymc_extras/statespace/core/representation.py +438 -0
- pymc_extras/statespace/core/statespace.py +2268 -0
- pymc_extras/statespace/filters/__init__.py +15 -0
- pymc_extras/statespace/filters/distributions.py +453 -0
- pymc_extras/statespace/filters/kalman_filter.py +820 -0
- pymc_extras/statespace/filters/kalman_smoother.py +126 -0
- pymc_extras/statespace/filters/utilities.py +59 -0
- pymc_extras/statespace/models/ETS.py +670 -0
- pymc_extras/statespace/models/SARIMAX.py +536 -0
- pymc_extras/statespace/models/VARMAX.py +393 -0
- pymc_extras/statespace/models/__init__.py +6 -0
- pymc_extras/statespace/models/structural.py +1651 -0
- pymc_extras/statespace/models/utilities.py +387 -0
- pymc_extras/statespace/utils/__init__.py +0 -0
- pymc_extras/statespace/utils/constants.py +74 -0
- pymc_extras/statespace/utils/coord_tools.py +0 -0
- pymc_extras/statespace/utils/data_tools.py +182 -0
- pymc_extras/utils/__init__.py +23 -0
- pymc_extras/utils/linear_cg.py +290 -0
- pymc_extras/utils/pivoted_cholesky.py +69 -0
- pymc_extras/utils/prior.py +200 -0
- pymc_extras/utils/spline.py +131 -0
- pymc_extras/version.py +11 -0
- pymc_extras/version.txt +1 -0
- pymc_extras-0.2.0.dist-info/LICENSE +212 -0
- pymc_extras-0.2.0.dist-info/METADATA +99 -0
- pymc_extras-0.2.0.dist-info/RECORD +101 -0
- pymc_extras-0.2.0.dist-info/WHEEL +5 -0
- pymc_extras-0.2.0.dist-info/top_level.txt +2 -0
- tests/__init__.py +13 -0
- tests/distributions/__init__.py +19 -0
- tests/distributions/test_continuous.py +185 -0
- tests/distributions/test_discrete.py +210 -0
- tests/distributions/test_discrete_markov_chain.py +258 -0
- tests/distributions/test_multivariate.py +304 -0
- tests/model/__init__.py +0 -0
- tests/model/marginal/__init__.py +0 -0
- tests/model/marginal/test_distributions.py +131 -0
- tests/model/marginal/test_graph_analysis.py +182 -0
- tests/model/marginal/test_marginal_model.py +867 -0
- tests/model/test_model_api.py +29 -0
- tests/statespace/__init__.py +0 -0
- tests/statespace/test_ETS.py +411 -0
- tests/statespace/test_SARIMAX.py +405 -0
- tests/statespace/test_VARMAX.py +184 -0
- tests/statespace/test_coord_assignment.py +116 -0
- tests/statespace/test_distributions.py +270 -0
- tests/statespace/test_kalman_filter.py +326 -0
- tests/statespace/test_representation.py +175 -0
- tests/statespace/test_statespace.py +818 -0
- tests/statespace/test_statespace_JAX.py +156 -0
- tests/statespace/test_structural.py +829 -0
- tests/statespace/utilities/__init__.py +0 -0
- tests/statespace/utilities/shared_fixtures.py +9 -0
- tests/statespace/utilities/statsmodel_local_level.py +42 -0
- tests/statespace/utilities/test_helpers.py +310 -0
- tests/test_blackjax_smc.py +222 -0
- tests/test_find_map.py +98 -0
- tests/test_histogram_approximation.py +109 -0
- tests/test_laplace.py +238 -0
- tests/test_linearmodel.py +208 -0
- tests/test_model_builder.py +306 -0
- tests/test_pathfinder.py +45 -0
- tests/test_pivoted_cholesky.py +24 -0
- tests/test_printing.py +98 -0
- tests/test_prior_from_trace.py +172 -0
- tests/test_splines.py +77 -0
- tests/utils.py +31 -0
|
@@ -0,0 +1,387 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import pytensor.tensor as pt
|
|
3
|
+
|
|
4
|
+
from pymc_extras.statespace.utils.constants import (
|
|
5
|
+
ALL_STATE_AUX_DIM,
|
|
6
|
+
ALL_STATE_DIM,
|
|
7
|
+
LONG_MATRIX_NAMES,
|
|
8
|
+
MATRIX_NAMES,
|
|
9
|
+
OBS_STATE_AUX_DIM,
|
|
10
|
+
OBS_STATE_DIM,
|
|
11
|
+
SHOCK_AUX_DIM,
|
|
12
|
+
SHOCK_DIM,
|
|
13
|
+
VECTOR_VALUED,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def make_default_coords(ss_mod):
    """
    Build the default coordinate mapping for a statespace model.

    Each dimension name is mapped to the corresponding label list taken from
    the model. The "aux" dimensions reuse the same labels, because they index
    the second axis of square matrices (covariances, transition matrix).

    Parameters
    ----------
    ss_mod:
        A statespace model exposing ``state_names``, ``observed_states`` and
        ``shock_names`` attributes.

    Returns
    -------
    dict
        Mapping from dimension name to coordinate labels.
    """
    state_labels = ss_mod.state_names
    observed_labels = ss_mod.observed_states
    shock_labels = ss_mod.shock_names

    return {
        ALL_STATE_DIM: state_labels,
        ALL_STATE_AUX_DIM: state_labels,
        OBS_STATE_DIM: observed_labels,
        OBS_STATE_AUX_DIM: observed_labels,
        SHOCK_DIM: shock_labels,
        SHOCK_AUX_DIM: shock_labels,
    }
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def cleanup_states(states: list[str]) -> list[str]:
    """
    Strip meaningless symbols out of generated state names.

    Parameters
    ----------
    states, list of str
        State names generated by make_harvey_state_names

    Returns
    -------
    states, list of str
        State names for the Harvey statespace representation, with meaningless terms removed

    Notes
    -----
    ``make_harvey_state_names`` emits some "meaningless" tokens: lags are written
    ``L{i}.state``, so ``L0.state`` is really just ``state``; repeated differences are
    written ``Dk^i``, so ``Dk^1`` is just ``Dk``. These tokens are simply deleted.
    """
    # Tokens that carry no information and are removed wherever they appear.
    noise_tokens = ("^1", "^0", "L0", "D0")

    cleaned = []
    for name in states:
        for token in noise_tokens:
            name = name.replace(token, "")
        cleaned.append(name)

    return cleaned
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def make_harvey_state_names(p: int, d: int, q: int, P: int, D: int, Q: int, S: int) -> list[str]:
    """
    Generate informative names for the SARIMA states in the Harvey representation

    Parameters
    ----------
    p: int
        AR order
    d: int
        Number of ARIMA differences
    q: int
        MA order
    P: int
        Seasonal AR order
    D: int
        Number of seasonal differences
    Q: int
        Seasonal MA order
    S: int
        Seasonal length

    Returns
    -------
    state_names, list of str
        List of state names

    Notes
    -----
    The Harvey state is not particularly interpretable, but it's also not totally opaque. This helper function makes
    a list of state names that can help users understand what they are getting back from the statespace. In particular,
    it is helpful to know how differences and seasonal differences are incorporated into the model.
    """
    # Number of "dynamics" states in the Harvey representation.
    k_lags = max(p + P * S, q + Q * S + 1)
    has_diff = (d + D) > 0

    # First state is always data
    states = ["data"]

    # Differencing operations
    # The goal here is to get down to "data_star", the state that actually has the SARIMA dynamics applied to it.
    # To get there, first the data needs to be differenced d-1 times.
    # d_size counts the ARIMA-style differences tracked in the state vector; one extra is added
    # when seasonal differencing is present, since the seasonal blocks are built on top of it.
    d_size = d + int(D > 0)
    # range(d_size)[:-1] yields d_size - 1 indices, i.e. one name per intermediate difference.
    states.extend([f"D1^{(i + 1)}.data" for i in range(d_size)[:-1]])

    # Next, if there are seasonal differences, we need to lag the ARIMA differenced state S times, then seasonal
    # difference it. This procedure is done D-1 times.

    # arma_diff / season_diff hold [difference-length, power] pairs used to render names like "D1^2D4^1".
    arma_diff = [int(d_size > 1), d_size - 1]
    season_diff = [S, 0]
    curr_state = f"D{arma_diff[0]}^{arma_diff[1]}"
    for i in range(D):
        # S - 1 lag states of the current (possibly already differenced) series.
        states.extend([f"L{j + 1}{curr_state}.data" for j in range(S - 1)])
        season_diff[1] += 1
        curr_state = f"D{arma_diff[0]}^{arma_diff[1]}D{season_diff[0]}^{season_diff[1]}"
        # The final seasonal difference is "data_star" itself and is named below instead.
        if i != (D - 1):
            states.append(f"{curr_state}.data")

    # Now we are at data_star. If we did any differencing, add it in.
    if has_diff:
        states.append("data_star")

    # Next, we add the time series dynamics states. These don't have an immediately obvious interpretation, so just
    # call them "state_1" .., "state_n" (suffixed with "_star" when they act on the differenced series).
    suffix = "_star" if "star" in states[-1] else ""
    states.extend([f"state{suffix}_{i + 1}" for i in range(k_lags - 1)])

    # Remove no-op tokens such as "^1", "L0", "D0" from the generated names.
    states = cleanup_states(states)

    return states
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def make_SARIMA_transition_matrix(
    p: int, d: int, q: int, P: int, D: int, Q: int, S: int
) -> np.ndarray:
    r"""
    Make the transition matrix for a SARIMA model

    Parameters
    ----------
    p: int
        AR order
    d: int
        Number of ARIMA differences
    q: int
        MA order
    P: int
        Seasonal AR order
    D: int
        Number of seasonal differences
    Q: int
        Seasonal MA order
    S: int
        Seasonal length

    Returns
    -------
    T, ndarray
        The transition matrix associated with a SARIMA model of order (p,d,q)x(P,D,Q,S)

    Notes
    -----
    The transition matrix for the SARIMA model has a bunch of structure in it, especially when differences are included
    in the statespace model. This function will always assume the state space matrix is in the Harvey representation.

    Given this representation, the matrix can be divided into a bottom part and a top part. The top part has (S * D) + d
    rows, and is associated with the differencing operations. The bottom part has max(P*S+p, Q*S+q+1) rows, and is
    responsible for the actual time series dynamics.

    The bottom part of the matrix is quite simple, it is just a shifted identity matrix (called a "companion matrix"),
    responsible for "rolling" the states, so that at each transition, the value for :math:`x_{t-3}` becomes the value
    for :math:`x_{t-2}`, and so on.

    The top part is quite complex. The goal of this part of the matrix is to transform the raw data state, :math:`x_t`,
    into a stationary state, :math:`x_t^\star`, via the application of differencing operations,
    :math:`\Delta x_t = x_t - x_{t-1}`. For ARIMA differences (the little ``d``), this is quite simple. Sequential
    differences are represented as an upper-triangular matrix of ones. To see this, consider an example where ``d=3``,
    so that:

    .. math::

        \begin{align}
            x_t^\star &= \Delta^3 x_t \\
            &= \Delta^2 (x_t - x_{t-1}) \\
            &= \Delta (x_t - 2x_{t-1} + x_{t-2}) \\
            &= x_t - x_{t-1} - 2x_{t-1} + 2x_{t-2} + x_{t-2} - x_{t-3} \\
            &= x_t - 3x_{t-1} + 3x_{t-2} - x_{t-3}
        \end{align}

    If you choose a state vector :math:`\begin{bmatrix}x_t & \Delta x_t & \Delta^2 x_t & x_t^\star \end{bmatrix}^T`,
    you will find that:

    .. math::
        \begin{bmatrix}x_t \\ \Delta x_t \\ \Delta^2 x_t \\ x_t^\star \end{bmatrix} =
        \begin{bmatrix} 1 & 1 & 1 & 1 \\
                        0 & 1 & 1 & 1 \\
                        0 & 0 & 1 & 1 \\
                        0 & 0 & 0 & 1
        \end{bmatrix}
        \begin{bmatrix} x_{t-1} \\ \Delta x_{t-1} \\ \Delta^2 x_{t-1} \\ x_{t-1}^\star \end{bmatrix}

    Next are the seasonal differences. The highest seasonal difference stored in the states is one less than the
    seasonal difference order, ``D``. That is, if ``D = 1, S = 4``, there will be states :math:`x_{t-1}, x_{t-2},
    x_{t-3}, x_{t-4}, x_t^\star`, with :math:`x_t^\star = \Delta_4 x_t = x_t - x_{t-4}`. The level state can be
    recovered by adding :math:`x_t^\star + x_{t-4}`. To accomplish all of this, two things need to be inserted into the
    transition matrix:

    1. A shifted identity matrix to "roll" the lagged states forward each transition, and
    2. A pair of 1's to recover the level state by adding the last 2 states (:math:`x_t^\star + x_{t-4}`)

    Keeping the example of ``D = 1, S = 4``, the block that handles the seasonal difference will look like this:

    .. math::
        \begin{bmatrix} 0 & 0 & 0 & 1 & 1 \\
                        1 & 0 & 0 & 0 & 0 \\
                        0 & 1 & 0 & 0 & 0 \\
                        0 & 0 & 1 & 0 & 0 \\
                        0 & 0 & 0 & 0 & 0 \end{bmatrix}

    In the presence of higher order seasonal differences, there needs to be one block per difference. And the level
    state is recovered by adding together the last state from each block. For example, if ``D = 2, S = 4``, the states
    will be :math:`x_{t-1}, x_{t-2}, x_{t-3}, x_{t-4}, \Delta_4 x_{t-1}, \Delta_4 x_{t-2}, \Delta_4 x_{t-3},
    \Delta_4 x_{t-4}, x_t^\star`, with :math:`x_t^\star = \Delta_4^2 = \Delta_4(x_t - x_{t-4}) = x_t - 2 x_{t-4} +
    x_{t-8}`. To recover the level state, we need :math:`x_t = x_t^\star + \Delta_4 x_{t-4} + x_{t-4}`. In addition,
    to recover :math:`\Delta_4 x_t`, we have to compute :math:`\Delta_4 x_t = x_t^\star + \Delta_4 x_{t-4} =
    \Delta_4(x_t - x_{t-4}) + \Delta_4 x_{t-4} = \Delta_4 x_t`. The block of the transition matrix associated with all
    this is thus:

    .. math::
        \begin{bmatrix} 0 & 0 & 0 & 1 & 0 & 0 & 0 & 1 & 1 \\
                        1 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 \\
                        0 & 1 & 0 & 0 & 0 & 0 & 0 & 0 & 0 \\
                        0 & 0 & 1 & 0 & 0 & 0 & 0 & 0 & 0 \\
                        0 & 0 & 0 & 0 & 0 & 0 & 0 & 1 & 1 \\
                        0 & 0 & 0 & 0 & 1 & 0 & 0 & 0 & 0 \\
                        0 & 0 & 0 & 0 & 0 & 1 & 0 & 0 & 0 \\
                        0 & 0 & 0 & 0 & 0 & 0 & 1 & 0 & 0 \\
                        0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 \end{bmatrix}

    When ARIMA differences and seasonal differences are mixed, the seasonal differences will be written in terms of the
    highest ARIMA difference order, and recovery of the level state will require the use of all the ARIMA differences,
    as well as the seasonal differences. In addition, the seasonal differences are needed to back out the ARIMA
    differences from :math:`x_t^\star`. Here is the differencing block for a SARIMA(0,2,0)x(0,2,0,4) -- the identities
    of the states are left as an exercise for the motivated reader:

    .. math::
        \begin{bmatrix}
            1 & 1 & 0 & 0 & 0 & 1 & 0 & 0 & 0 & 1 & 1 \\
            0 & 1 & 0 & 0 & 0 & 1 & 0 & 0 & 0 & 1 & 1 \\
            0 & 0 & 0 & 0 & 0 & 1 & 0 & 0 & 0 & 1 & 1 \\
            0 & 0 & 1 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 \\
            0 & 0 & 0 & 1 & 0 & 0 & 0 & 0 & 0 & 0 & 0 \\
            0 & 0 & 0 & 0 & 1 & 0 & 0 & 0 & 0 & 0 & 0 \\
            0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 1 & 1 \\
            0 & 0 & 0 & 0 & 0 & 0 & 1 & 0 & 0 & 0 & 0 \\
            0 & 0 & 0 & 0 & 0 & 0 & 0 & 1 & 0 & 0 & 0 \\
            0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 1 & 0 & 0 \\
            0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 \end{bmatrix}
    """
    # Number of differencing rows (top block) and dynamics rows (bottom block).
    n_diffs = S * D + d
    k_lags = max(p + P * S, q + Q * S + 1)
    k_states = k_lags + n_diffs

    # Top Part
    # ARIMA differences: upper triangle of ones in the leading d x d corner.
    T = np.zeros((k_states, k_states))
    diff_idx = np.triu_indices(d)
    T[diff_idx] = 1

    # Adjustment factors for difference states. All of the difference states are computed relative to x_t_star using
    # combinations of states, so there's a lot of "backing out" that needs to happen here. The columns are the more
    # straightforward part. After the (d,d) upper triangle of 1s for the ARIMA lags, there will be (S - 1) zeros,
    # and then a 1. In addition, there is an extra column of 1s at position n_diffs + 1, corresponding to x_star itself.

    # This will slowly taper down, but first we build the "full" set of column indices with values.
    # When D > 0, base_col_idx ends up with D + 1 entries (one per seasonal block, plus x_star's column).
    base_col_idx = d + S + np.arange(D) * S - 1
    if len(base_col_idx) > 0:
        base_col_idx = np.r_[base_col_idx, base_col_idx[-1] + 1]

    # The first d rows -- associated with the ARIMA differences -- will have 1s in all columns.
    col_idx = np.tile(base_col_idx, d)
    row_idx = np.arange(d).repeat(D + 1)

    # Next, if there are seasonal differences, there will be more rows, with the columns slowly dropping off:
    # iteration i places ones on row (d + S * i), using base_col_idx with the i leftmost columns dropped.
    # NOTE(review): an earlier comment claimed each row drops *two* columns and that only (D - 1) rows are
    # written; the code drops one column per iteration and loops D times -- confirm intent against tests.
    for i in range(D):
        n = len(base_col_idx[i:])
        col_idx = np.r_[col_idx, base_col_idx[i:]]
        row_idx = np.r_[row_idx, np.full(n, d + S * i)]

    if D == 0 and d > 0:
        # Special case: If there are *only* ARIMA lags, there still needs to be a single column of 1s at position
        # [:d, d]
        row_idx = np.arange(d)
        col_idx = np.full(d, d)
    T[row_idx, col_idx] = 1

    if S > 0:
        # "Rolling" indices for seasonal differences: shifted identity that advances each lag by one step.
        (row_roll_idx, col_roll_idx) = np.diag_indices(S * D)
        row_roll_idx = row_roll_idx + d + 1
        col_roll_idx = col_roll_idx + d

        # Rolling indices have a zero after every diagonal of length S-1, so blocks do not bleed into each other.
        T[row_roll_idx, col_roll_idx] = 1
        zero_idx = row_roll_idx[S - 1 :: S], col_roll_idx[S - 1 :: S]
        T[zero_idx] = 0

    # Bottom part
    # Rolling indices for the "compute" states, x_star: a (k_lags - 1) shifted identity (companion matrix).
    star_roll_row, star_roll_col = np.diag_indices(k_lags - 1)
    star_roll_row = star_roll_row + n_diffs
    star_roll_col = star_roll_col + n_diffs + 1

    T[star_roll_row, star_roll_col] = 1

    return T
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
def conform_time_varying_and_time_invariant_matrices(A, B):
    """
    Adjust either A or B to conform to the other in the time dimension

    When a structural model is assembled from components, one component may carry
    time-varying statespace matrices while the other does not. Before the pair can be
    concatenated or block-diagonalized, the time-invariant one (leading dim of size 1)
    must be repeated out to the other's time length. This function performs that
    expansion when exactly one of the two needs it; otherwise it returns the inputs
    unchanged.

    Parameters
    ----------
    A: pt.TensorVariable
        An anonymous statespace matrix
    B: pt.TensorVariable
        An anonymous statespace matrix

    Returns
    -------
    (A, B): Tuple of pt.TensorVariable
        A and B, with one or neither expanded to have a time dimension.

    Raises
    ------
    ValueError
        If neither input is named after a known statespace matrix.
    """
    known_names = MATRIX_NAMES + LONG_MATRIX_NAMES

    # Figure out which statespace matrix we are dealing with; its name determines
    # whether a time-varying version is 2d (vector-valued) or 3d (matrix-valued).
    if A.name == B.name:
        name = A.name
    elif A.name in known_names:
        name = A.name
    elif B.name in known_names:
        name = B.name
    else:
        raise ValueError(
            "At least one matrix passed to conform_time_varying_and_time_invariant_matrices should be a "
            "statespace matrix"
        )

    time_varying_ndim = 3 - int(name in VECTOR_VALUED)

    # Nothing to do unless both inputs already carry a leading time dimension
    # (time-invariant matrices are assumed to use a broadcastable length-1 time axis).
    if A.ndim != time_varying_ndim or B.ndim != time_varying_ndim:
        return A, B

    T_A, *A_core_shape = A.type.shape
    T_B, *B_core_shape = B.type.shape

    if T_A == T_B:
        return A, B

    if T_A == 1:
        # A is time-invariant: repeat it along time to match B.
        expanded = pt.repeat(A, B.shape[0], axis=0)
        expanded = pt.specify_shape(expanded, (T_B, *tuple(A_core_shape)))
        expanded.name = A.name
        return expanded, B

    if T_B == 1:
        # B is time-invariant: repeat it along time to match A.
        expanded = pt.repeat(B, A.shape[0], axis=0)
        expanded = pt.specify_shape(expanded, (T_A, *tuple(B_core_shape)))
        expanded.name = B.name
        return A, expanded

    return A, B
|
|
375
|
+
|
|
376
|
+
|
|
377
|
+
def get_exog_dims_from_idata(exog_name, idata):
    """
    Look up the dims associated with an exogenous variable in an InferenceData.

    The posterior group is checked first (dropping the leading two dims, which
    are chain and draw), then ``constant_data``, then ``mutable_data``. Returns
    None when the variable is found in none of them.
    """
    if exog_name in idata.posterior.data_vars:
        # Posterior variables are indexed (chain, draw, *user_dims); keep only the user dims.
        return idata.posterior[exog_name].dims[2:]

    if exog_name in getattr(idata, "constant_data", []):
        return idata.constant_data[exog_name].dims

    if exog_name in getattr(idata, "mutable_data", []):
        return idata.mutable_data[exog_name].dims

    return None
|
|
File without changes
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
import pytensor
|
|
2
|
+
|
|
3
|
+
# ---------------------------------------------------------------------------
# Coordinate (dimension) names used throughout the statespace module.
# The "*_AUX" variants label the second axis of square matrices so that
# xarray/PyMC coords stay unique.
# ---------------------------------------------------------------------------
ALL_STATE_DIM = "state"
ALL_STATE_AUX_DIM = "state_aux"
OBS_STATE_DIM = "observed_state"
OBS_STATE_AUX_DIM = "observed_state_aux"
SHOCK_DIM = "shock"
SHOCK_AUX_DIM = "shock_aux"
TIME_DIM = "time"
AR_PARAM_DIM = "ar_lag"
MA_PARAM_DIM = "ma_lag"
SEASONAL_AR_PARAM_DIM = "seasonal_ar_lag"
SEASONAL_MA_PARAM_DIM = "seasonal_ma_lag"
ETS_SEASONAL_DIM = "seasonal_lag"

# Matrices that never acquire a leading time dimension, and those stored as
# vectors (2d when time-varying) rather than matrices (3d when time-varying).
NEVER_TIME_VARYING = ["initial_state", "initial_state_cov", "a0", "P0"]
VECTOR_VALUED = ["initial_state", "state_intercept", "obs_intercept", "a0", "c", "d"]

# Sentinel used to fill NaN entries before Kalman filtering.
MISSING_FILL = -9999.0
# Numerical jitter added to covariance diagonals; looser in single precision.
JITTER_DEFAULT = 1e-8 if pytensor.config.floatX.endswith("64") else 1e-6

FILTER_OUTPUT_TYPES = ["filtered", "predicted", "smoothed"]

# Short (Harvey-style) and long (statsmodels-style) names for the statespace
# matrices, kept in matching order so the zip-built mappings below line up.
MATRIX_NAMES = ["x0", "P0", "c", "d", "T", "Z", "R", "H", "Q"]
LONG_MATRIX_NAMES = [
    "initial_state",
    "initial_state_cov",
    "state_intercept",
    "obs_intercept",
    "transition",
    "design",
    "selection",
    "obs_cov",
    "state_cov",
]

SHORT_NAME_TO_LONG = dict(zip(MATRIX_NAMES, LONG_MATRIX_NAMES))
LONG_NAME_TO_SHORT = dict(zip(LONG_MATRIX_NAMES, MATRIX_NAMES))

FILTER_OUTPUT_NAMES = [
    "filtered_state",
    "predicted_state",
    "filtered_covariance",
    "predicted_covariance",
]

SMOOTHER_OUTPUT_NAMES = ["smoothed_state", "smoothed_covariance"]
OBSERVED_OUTPUT_NAMES = ["predicted_observed_state", "predicted_observed_covariance"]

# Default dims attached to each statespace matrix when registered with PyMC.
MATRIX_DIMS = {
    "x0": (ALL_STATE_DIM,),
    "P0": (ALL_STATE_DIM, ALL_STATE_AUX_DIM),
    "c": (ALL_STATE_DIM,),
    "d": (OBS_STATE_DIM,),
    "T": (ALL_STATE_DIM, ALL_STATE_AUX_DIM),
    "Z": (OBS_STATE_DIM, ALL_STATE_DIM),
    "R": (ALL_STATE_DIM, SHOCK_DIM),
    "H": (OBS_STATE_DIM, OBS_STATE_AUX_DIM),
    "Q": (SHOCK_DIM, SHOCK_AUX_DIM),
}

# Default dims attached to Kalman filter/smoother outputs.
FILTER_OUTPUT_DIMS = {
    "filtered_state": (TIME_DIM, ALL_STATE_DIM),
    "smoothed_state": (TIME_DIM, ALL_STATE_DIM),
    "predicted_state": (TIME_DIM, ALL_STATE_DIM),
    "filtered_covariance": (TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM),
    "smoothed_covariance": (TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM),
    "predicted_covariance": (TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM),
    "predicted_observed_state": (TIME_DIM, OBS_STATE_DIM),
    "predicted_observed_covariance": (TIME_DIM, OBS_STATE_DIM, OBS_STATE_AUX_DIM),
}

# Names of successive time derivatives for structural "level and trend" components.
POSITION_DERIVATIVE_NAMES = ["level", "trend", "acceleration", "jerk", "snap", "crackle", "pop"]
SARIMAX_STATE_STRUCTURES = ["fast", "interpretable"]
|
|
File without changes
|
|
@@ -0,0 +1,182 @@
|
|
|
1
|
+
import warnings
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import pandas as pd
|
|
5
|
+
import pymc as pm
|
|
6
|
+
import pytensor
|
|
7
|
+
import pytensor.tensor as pt
|
|
8
|
+
|
|
9
|
+
from pymc import ImputationWarning, modelcontext
|
|
10
|
+
from pytensor.tensor.sharedvar import TensorSharedVariable
|
|
11
|
+
|
|
12
|
+
from pymc_extras.statespace.utils.constants import (
|
|
13
|
+
MISSING_FILL,
|
|
14
|
+
OBS_STATE_DIM,
|
|
15
|
+
TIME_DIM,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
# Warning messages reused by the preprocessing helpers below.
NO_TIME_INDEX_WARNING = (
    "No time index found on the supplied data. A simple range index will be automatically "
    "generated."
)
# Fixed typo: "was specific" -> "was specified"
NO_FREQ_INFO_WARNING = "No frequency was specified on the data's DateTimeIndex."
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def get_data_dims(data):
    """
    Fetch the dims registered for a data variable on the model currently on the
    context stack.

    Returns None when ``data`` is not a (shared) tensor variable, when it has no
    name, or when no dims are registered for that name.
    """
    if not isinstance(data, pt.TensorVariable | TensorSharedVariable):
        return None

    name = getattr(data, "name", None)
    if not name:
        return None

    active_model = modelcontext(None)
    return active_model.named_vars_to_dims.get(name, None)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def _validate_data_shape(data_shape, n_obs, obs_coords=None, check_col_names=False, col_names=None):
|
|
39
|
+
if col_names is None:
|
|
40
|
+
col_names = []
|
|
41
|
+
|
|
42
|
+
if len(data_shape) != 2:
|
|
43
|
+
raise ValueError("Data must be a 2d matrix")
|
|
44
|
+
|
|
45
|
+
if data_shape[-1] != n_obs:
|
|
46
|
+
raise ValueError(
|
|
47
|
+
f"Shape of data does not match model output. Expected {n_obs} columns, "
|
|
48
|
+
f"found {data_shape[-1]}."
|
|
49
|
+
)
|
|
50
|
+
|
|
51
|
+
if check_col_names:
|
|
52
|
+
missing_cols = set(obs_coords) - set(col_names)
|
|
53
|
+
if len(missing_cols) > 0:
|
|
54
|
+
raise ValueError(
|
|
55
|
+
"Columns of DataFrame provided as data do not match state names. The following states were"
|
|
56
|
+
f'not found: {", ".join(missing_cols)}. This may result in unexpected results in complex'
|
|
57
|
+
f"statespace models"
|
|
58
|
+
)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def preprocess_tensor_data(data, n_obs, obs_coords=None):
    """
    Evaluate symbolic tensor data and build a default integer time index.

    Tensor inputs carry no time index, so a simple range index is generated.
    Returns the evaluated values and that index.
    """
    shape = data.shape.eval()
    _validate_data_shape(shape, n_obs, obs_coords)

    # NOTE(review): the warning is gated on obs_coords being supplied -- presumably
    # because only then does the user care about coordinate labels; confirm intent.
    if obs_coords is not None:
        warnings.warn(NO_TIME_INDEX_WARNING)

    index = np.arange(shape[0], dtype="int")
    return data.eval(), index
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def preprocess_numpy_data(data, n_obs, obs_coords=None):
    """
    Validate a numpy array of observed data and build a default integer time index.

    Numpy inputs carry no time index, so a simple range index is generated.
    Returns the (unmodified) values and that index.
    """
    _validate_data_shape(data.shape, n_obs, obs_coords)

    if obs_coords is not None:
        warnings.warn(NO_TIME_INDEX_WARNING)

    return data, np.arange(data.shape[0], dtype="int")
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def preprocess_pandas_data(data, n_obs, obs_coords=None, check_column_names=False):
    """
    Validate pandas data and extract its values together with a usable time index.

    A Series is first promoted to a one-column DataFrame (unnamed series receive the
    default name "data"). RangeIndex data is delegated to ``preprocess_numpy_data``;
    DatetimeIndex data keeps its own index, inferring a frequency if one is missing.

    Raises
    ------
    IndexError
        If the index is neither a RangeIndex nor a DatetimeIndex.
    """
    if isinstance(data, pd.Series):
        if data.name is None:
            data.name = "data"
        data = data.to_frame()

    _validate_data_shape(data.shape, n_obs, obs_coords, check_column_names, data.columns)

    if isinstance(data.index, pd.RangeIndex):
        if obs_coords is not None:
            warnings.warn(NO_TIME_INDEX_WARNING)
        return preprocess_numpy_data(data.values, n_obs, obs_coords)

    if isinstance(data.index, pd.DatetimeIndex):
        if data.index.freq is None:
            # Fall back on pandas' inferred frequency when none was set explicitly.
            warnings.warn(NO_FREQ_INFO_WARNING)
            data.index.freq = data.index.inferred_freq

        return data.values, data.index

    raise IndexError(
        f"Expected pd.DatetimeIndex or pd.RangeIndex on data, found {type(data.index)}"
    )
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def add_data_to_active_model(values, index, data_dims=None):
    """
    Register observed data on the PyMC model currently on the context stack.

    Ensures the model has a time coordinate matching ``index``: the coord is added
    when absent, filled in when registered as None, and a ValueError is raised when
    it disagrees with ``index``. Returns the ``pm.Data`` container named "data".
    """
    model = modelcontext(None)

    if data_dims is None:
        data_dims = [TIME_DIM, OBS_STATE_DIM]
    time_dim = data_dims[0]

    if time_dim not in model.coords:
        model.add_coord(time_dim, index)
    elif model.coords[time_dim] is None:
        model.coords.update({time_dim: index})
    elif not np.array_equal(model.coords[time_dim], tuple(index)):
        raise ValueError(
            "Provided data has a different time index than the model. Please ensure that the time values "
            "set on coords matches that of the exogenous data."
        )

    # If the data has just one column, we need to specify the shape as (None, 1), or else the JAX backend will
    # raise a broadcasting error.
    shape = (None, 1) if values.shape[-1] == 1 else None

    return pm.Data("data", values, dims=data_dims, shape=shape)
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def mask_missing_values_in_data(values, missing_fill_value=None):
    """
    Replace NaN/inf entries in ``values`` with a sentinel and report where they were.

    Returns the filled array and the boolean mask of missing positions. When missing
    values are present, raises if the sentinel already appears in the data (the two
    would be indistinguishable), and warns that the gaps will be imputed as hidden
    states during Kalman filtering.
    """
    fill = MISSING_FILL if missing_fill_value is None else missing_fill_value

    masked = np.ma.masked_invalid(values)
    filled = masked.filled(fill)
    nan_mask = masked.mask

    if np.any(nan_mask):
        # The sentinel must not collide with genuine data values.
        if np.any(values == fill):
            raise ValueError(
                f"Provided data contains the value {fill}, which is used as a missing value marker. "
                f"Please manually change the missing_fill_value to avoid this collision."
            )

        warnings.warn(
            "Provided data contains missing values and"
            " will be automatically imputed as hidden states"
            " during Kalman filtering.",
            ImputationWarning,
        )

    return filled, nan_mask
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def register_data_with_pymc(
    data, n_obs, obs_coords, register_data=True, missing_fill_value=None, data_dims=None
):
    """
    Normalize user-supplied observed data and expose it to PyMC.

    Dispatches preprocessing on the input type (symbolic tensor, numpy array, or
    pandas object), masks missing values with a sentinel, then either registers the
    result on the active model as a ``pm.Data`` container or wraps it in a shared
    variable named "data". Returns the resulting variable and the missing-value mask.
    """
    if isinstance(data, pt.TensorVariable | TensorSharedVariable):
        values, index = preprocess_tensor_data(data, n_obs, obs_coords)
    elif isinstance(data, np.ndarray):
        values, index = preprocess_numpy_data(data, n_obs, obs_coords)
    elif isinstance(data, pd.DataFrame | pd.Series):
        values, index = preprocess_pandas_data(data, n_obs, obs_coords)
    else:
        raise ValueError("Data should be one of pytensor tensor, numpy array, or pandas dataframe")

    filled_values, nan_mask = mask_missing_values_in_data(values, missing_fill_value)

    if register_data:
        model_data = add_data_to_active_model(filled_values, index, data_dims)
    else:
        model_data = pytensor.shared(filled_values, name="data")

    return model_data, nan_mask
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
# Copyright 2022 The PyMC Developers
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
from pymc_extras.utils import prior, spline
|
|
17
|
+
from pymc_extras.utils.linear_cg import linear_cg
|
|
18
|
+
|
|
19
|
+
# Public API of pymc_extras.utils: the linear_cg solver plus the re-exported
# prior and spline submodules.
__all__ = (
    "linear_cg",
    "prior",
    "spline",
)
|