pymc-extras 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101) hide show
  1. pymc_extras/__init__.py +29 -0
  2. pymc_extras/distributions/__init__.py +40 -0
  3. pymc_extras/distributions/continuous.py +351 -0
  4. pymc_extras/distributions/discrete.py +399 -0
  5. pymc_extras/distributions/histogram_utils.py +163 -0
  6. pymc_extras/distributions/multivariate/__init__.py +3 -0
  7. pymc_extras/distributions/multivariate/r2d2m2cp.py +446 -0
  8. pymc_extras/distributions/timeseries.py +356 -0
  9. pymc_extras/gp/__init__.py +18 -0
  10. pymc_extras/gp/latent_approx.py +183 -0
  11. pymc_extras/inference/__init__.py +18 -0
  12. pymc_extras/inference/find_map.py +431 -0
  13. pymc_extras/inference/fit.py +44 -0
  14. pymc_extras/inference/laplace.py +570 -0
  15. pymc_extras/inference/pathfinder.py +134 -0
  16. pymc_extras/inference/smc/__init__.py +13 -0
  17. pymc_extras/inference/smc/sampling.py +451 -0
  18. pymc_extras/linearmodel.py +130 -0
  19. pymc_extras/model/__init__.py +0 -0
  20. pymc_extras/model/marginal/__init__.py +0 -0
  21. pymc_extras/model/marginal/distributions.py +276 -0
  22. pymc_extras/model/marginal/graph_analysis.py +372 -0
  23. pymc_extras/model/marginal/marginal_model.py +595 -0
  24. pymc_extras/model/model_api.py +56 -0
  25. pymc_extras/model/transforms/__init__.py +0 -0
  26. pymc_extras/model/transforms/autoreparam.py +434 -0
  27. pymc_extras/model_builder.py +759 -0
  28. pymc_extras/preprocessing/__init__.py +0 -0
  29. pymc_extras/preprocessing/standard_scaler.py +17 -0
  30. pymc_extras/printing.py +182 -0
  31. pymc_extras/statespace/__init__.py +13 -0
  32. pymc_extras/statespace/core/__init__.py +7 -0
  33. pymc_extras/statespace/core/compile.py +48 -0
  34. pymc_extras/statespace/core/representation.py +438 -0
  35. pymc_extras/statespace/core/statespace.py +2268 -0
  36. pymc_extras/statespace/filters/__init__.py +15 -0
  37. pymc_extras/statespace/filters/distributions.py +453 -0
  38. pymc_extras/statespace/filters/kalman_filter.py +820 -0
  39. pymc_extras/statespace/filters/kalman_smoother.py +126 -0
  40. pymc_extras/statespace/filters/utilities.py +59 -0
  41. pymc_extras/statespace/models/ETS.py +670 -0
  42. pymc_extras/statespace/models/SARIMAX.py +536 -0
  43. pymc_extras/statespace/models/VARMAX.py +393 -0
  44. pymc_extras/statespace/models/__init__.py +6 -0
  45. pymc_extras/statespace/models/structural.py +1651 -0
  46. pymc_extras/statespace/models/utilities.py +387 -0
  47. pymc_extras/statespace/utils/__init__.py +0 -0
  48. pymc_extras/statespace/utils/constants.py +74 -0
  49. pymc_extras/statespace/utils/coord_tools.py +0 -0
  50. pymc_extras/statespace/utils/data_tools.py +182 -0
  51. pymc_extras/utils/__init__.py +23 -0
  52. pymc_extras/utils/linear_cg.py +290 -0
  53. pymc_extras/utils/pivoted_cholesky.py +69 -0
  54. pymc_extras/utils/prior.py +200 -0
  55. pymc_extras/utils/spline.py +131 -0
  56. pymc_extras/version.py +11 -0
  57. pymc_extras/version.txt +1 -0
  58. pymc_extras-0.2.0.dist-info/LICENSE +212 -0
  59. pymc_extras-0.2.0.dist-info/METADATA +99 -0
  60. pymc_extras-0.2.0.dist-info/RECORD +101 -0
  61. pymc_extras-0.2.0.dist-info/WHEEL +5 -0
  62. pymc_extras-0.2.0.dist-info/top_level.txt +2 -0
  63. tests/__init__.py +13 -0
  64. tests/distributions/__init__.py +19 -0
  65. tests/distributions/test_continuous.py +185 -0
  66. tests/distributions/test_discrete.py +210 -0
  67. tests/distributions/test_discrete_markov_chain.py +258 -0
  68. tests/distributions/test_multivariate.py +304 -0
  69. tests/model/__init__.py +0 -0
  70. tests/model/marginal/__init__.py +0 -0
  71. tests/model/marginal/test_distributions.py +131 -0
  72. tests/model/marginal/test_graph_analysis.py +182 -0
  73. tests/model/marginal/test_marginal_model.py +867 -0
  74. tests/model/test_model_api.py +29 -0
  75. tests/statespace/__init__.py +0 -0
  76. tests/statespace/test_ETS.py +411 -0
  77. tests/statespace/test_SARIMAX.py +405 -0
  78. tests/statespace/test_VARMAX.py +184 -0
  79. tests/statespace/test_coord_assignment.py +116 -0
  80. tests/statespace/test_distributions.py +270 -0
  81. tests/statespace/test_kalman_filter.py +326 -0
  82. tests/statespace/test_representation.py +175 -0
  83. tests/statespace/test_statespace.py +818 -0
  84. tests/statespace/test_statespace_JAX.py +156 -0
  85. tests/statespace/test_structural.py +829 -0
  86. tests/statespace/utilities/__init__.py +0 -0
  87. tests/statespace/utilities/shared_fixtures.py +9 -0
  88. tests/statespace/utilities/statsmodel_local_level.py +42 -0
  89. tests/statespace/utilities/test_helpers.py +310 -0
  90. tests/test_blackjax_smc.py +222 -0
  91. tests/test_find_map.py +98 -0
  92. tests/test_histogram_approximation.py +109 -0
  93. tests/test_laplace.py +238 -0
  94. tests/test_linearmodel.py +208 -0
  95. tests/test_model_builder.py +306 -0
  96. tests/test_pathfinder.py +45 -0
  97. tests/test_pivoted_cholesky.py +24 -0
  98. tests/test_printing.py +98 -0
  99. tests/test_prior_from_trace.py +172 -0
  100. tests/test_splines.py +77 -0
  101. tests/utils.py +31 -0
@@ -0,0 +1,387 @@
1
+ import numpy as np
2
+ import pytensor.tensor as pt
3
+
4
+ from pymc_extras.statespace.utils.constants import (
5
+ ALL_STATE_AUX_DIM,
6
+ ALL_STATE_DIM,
7
+ LONG_MATRIX_NAMES,
8
+ MATRIX_NAMES,
9
+ OBS_STATE_AUX_DIM,
10
+ OBS_STATE_DIM,
11
+ SHOCK_AUX_DIM,
12
+ SHOCK_DIM,
13
+ VECTOR_VALUED,
14
+ )
15
+
16
+
17
def make_default_coords(ss_mod):
    """
    Build the default coordinate mapping for a statespace model.

    Each base dimension (states, observed states, shocks) is paired with an
    "aux" dimension carrying the same labels, so that square matrices such as
    covariances can name both of their axes.

    Parameters
    ----------
    ss_mod
        A statespace model exposing ``state_names``, ``observed_states`` and
        ``shock_names`` attributes.

    Returns
    -------
    dict
        Mapping from dimension name to coordinate labels.
    """
    states = ss_mod.state_names
    observed = ss_mod.observed_states
    shocks = ss_mod.shock_names

    return {
        ALL_STATE_DIM: states,
        ALL_STATE_AUX_DIM: states,
        OBS_STATE_DIM: observed,
        OBS_STATE_AUX_DIM: observed,
        SHOCK_DIM: shocks,
        SHOCK_AUX_DIM: shocks,
    }
28
+
29
+
30
def cleanup_states(states: list[str]) -> list[str]:
    """
    Strip meaningless symbols from Harvey-representation state names.

    Parameters
    ----------
    states: list of str
        State names generated by make_harvey_state_names

    Returns
    -------
    list of str
        State names with trivial markers removed.

    Notes
    -----
    ``make_harvey_state_names`` produces some redundant notation: ``L0.state``
    is just ``state`` (a zero lag), and ``Dk^1`` is just ``Dk`` (a single
    application of the difference operator). Similarly ``^0`` and ``D0`` mark
    zero-order operations and carry no information.
    """
    # Tokens that denote trivial (identity) operations on a state.
    redundant_tokens = ("^1", "^0", "L0", "D0")

    cleaned = []
    for name in states:
        for token in redundant_tokens:
            name = name.replace(token, "")
        cleaned.append(name)
    return cleaned
59
+
60
+
61
def make_harvey_state_names(p: int, d: int, q: int, P: int, D: int, Q: int, S: int) -> list[str]:
    """
    Generate informative names for the SARIMA states in the Harvey representation

    Parameters
    ----------
    p: int
        AR order
    d: int
        Number of ARIMA differences
    q: int
        MA order
    P: int
        Seasonal AR order
    D: int
        Number of seasonal differences
    Q: int
        Seasonal MA order
    S: int
        Seasonal length

    Returns
    -------
    list of str
        List of state names.

    Notes
    -----
    The Harvey state is not fully interpretable, but naming the states helps
    users understand what comes back from the statespace — in particular how
    ordinary and seasonal differences are threaded through the model.
    """
    n_dynamic_states = max(p + P * S, q + Q * S + 1)
    any_differencing = (d + D) > 0

    # The raw data is always the first state.
    names = ["data"]

    # Sequential ARIMA differences: D1^1.data, D1^2.data, ... one fewer than
    # the total number needed to reach the "star" state.
    n_arima_diffs = d + int(D > 0)
    names += [f"D1^{k}.data" for k in range(1, n_arima_diffs)]

    # Seasonal differencing: lag the current differenced state S - 1 times,
    # then apply one seasonal difference; repeat D times.
    arma_prefix = f"D{int(n_arima_diffs > 1)}^{n_arima_diffs - 1}"
    current = arma_prefix
    for i in range(D):
        names += [f"L{lag}{current}.data" for lag in range(1, S)]
        current = f"{arma_prefix}D{S}^{i + 1}"
        if i < D - 1:
            names.append(f"{current}.data")

    # If any differencing happened, the fully-differenced series gets its own state.
    if any_differencing:
        names.append("data_star")

    # States carrying the ARMA dynamics have no obvious interpretation; name
    # them generically, carrying the _star suffix when dynamics act on data_star.
    suffix = "_star" if "star" in names[-1] else ""
    names += [f"state{suffix}_{k}" for k in range(1, n_dynamic_states)]

    return cleanup_states(names)
128
+
129
+
130
def make_SARIMA_transition_matrix(
    p: int, d: int, q: int, P: int, D: int, Q: int, S: int
) -> np.ndarray:
    r"""
    Make the transition matrix for a SARIMA model in the Harvey representation.

    Parameters
    ----------
    p: int
        AR order
    d: int
        Number of ARIMA differences
    q: int
        MA order
    P: int
        Seasonal AR order
    D: int
        Number of seasonal differences
    Q: int
        Seasonal MA order
    S: int
        Seasonal length

    Returns
    -------
    np.ndarray
        Transition matrix for a SARIMA model of order (p,d,q)x(P,D,Q,S).

    Notes
    -----
    The matrix splits into two parts. The bottom ``max(p + P*S, q + Q*S + 1)``
    rows are a shifted identity (companion structure) that "rolls" the lagged
    dynamic states forward each step. The top ``S*D + d`` rows implement the
    differencing bookkeeping: an upper triangle of ones accumulates the
    sequential ARIMA differences, shifted-identity blocks roll the lagged
    seasonal states, and extra ones recover the level (and intermediate
    difference) states from the fully-differenced state :math:`x_t^\star`
    plus the stored seasonal lags. For example, with ``d=0, D=1, S=4`` the
    differencing block is

    .. math::
        \begin{bmatrix} 0 & 0 & 0 & 1 & 1 \\
                        1 & 0 & 0 & 0 & 0 \\
                        0 & 1 & 0 & 0 & 0 \\
                        0 & 0 & 1 & 0 & 0 \\
                        0 & 0 & 0 & 0 & 0 \end{bmatrix}
    """
    n_diffs = S * D + d
    k_lags = max(p + P * S, q + Q * S + 1)
    k_states = k_lags + n_diffs

    T = np.zeros((k_states, k_states))

    # Upper triangle of ones accumulates the sequential ARIMA differences.
    T[np.triu_indices(d)] = 1

    # Columns holding the "recovery" ones: after the (d, d) triangle there is a
    # one every S columns (the last lag of each seasonal block), plus one extra
    # column for x_star itself.
    base_cols = d + S + np.arange(D) * S - 1
    if base_cols.size:
        base_cols = np.r_[base_cols, base_cols[-1] + 1]

    # The first d rows (ARIMA differences) carry ones in every recovery column.
    cols = np.tile(base_cols, d)
    rows = np.arange(d).repeat(D + 1)

    # Seasonal-difference rows appear every S rows below row d; each successive
    # row drops one column from the left of base_cols.
    for i in range(D):
        remaining = base_cols[i:]
        cols = np.r_[cols, remaining]
        rows = np.r_[rows, np.full(remaining.size, d + S * i)]

    if D == 0 and d > 0:
        # Special case: with only ARIMA differences, the recovery ones collapse
        # to a single column at position [:d, d].
        rows = np.arange(d)
        cols = np.full(d, d)
    T[rows, cols] = 1

    if S > 0:
        # Shifted identity that rolls the seasonal lag states forward...
        roll_rows, roll_cols = np.diag_indices(S * D)
        roll_rows = roll_rows + d + 1
        roll_cols = roll_cols + d
        T[roll_rows, roll_cols] = 1
        # ...with a break (zero) after every run of S - 1 ones, so blocks
        # belonging to different seasonal differences stay separate.
        T[roll_rows[S - 1 :: S], roll_cols[S - 1 :: S]] = 0

    # Bottom part: companion structure rolling the dynamic (x_star) states.
    shift_rows, shift_cols = np.diag_indices(k_lags - 1)
    T[shift_rows + n_diffs, shift_cols + n_diffs + 1] = 1

    return T
314
+
315
+
316
def conform_time_varying_and_time_invariant_matrices(A, B):
    """
    Adjust either A or B to conform to the other in the time dimension

    When building a structural model from components, one component may have
    time-varying statespace matrices while the other does not. The pair cannot
    be concatenated or block-diagonalized until the time-invariant matrix
    (stored with a leading time dimension of length 1) is broadcast to the
    other's time length. This function detects that situation and expands the
    shorter matrix if needed.

    Parameters
    ----------
    A: pt.TensorVariable
        An anonymous statespace matrix
    B: pt.TensorVariable
        An anonymous statespace matrix

    Returns
    -------
    (A, B): Tuple of pt.TensorVariable
        A and B, with one or neither expanded to have a time dimension.
    """
    if A.name == B.name:
        name = A.name
    else:
        known_names = MATRIX_NAMES + LONG_MATRIX_NAMES
        if A.name not in known_names and B.name not in known_names:
            raise ValueError(
                "At least one matrix passed to conform_time_varying_and_time_invariant_matrices should be a "
                "statespace matrix"
            )
        name = A.name if A.name in known_names else B.name

    # Vector-valued quantities are 2d when time-varying; matrices are 3d.
    expected_ndim = 3 - int(name in VECTOR_VALUED)

    # If either input lacks the leading time dimension entirely, nothing to do.
    if A.ndim != expected_ndim or B.ndim != expected_ndim:
        return A, B

    T_A, *A_core_shape = A.type.shape
    T_B, *B_core_shape = B.type.shape

    if T_A == T_B:
        return A, B

    if T_A == 1:
        expanded = pt.repeat(A, B.shape[0], axis=0)
        expanded = pt.specify_shape(expanded, (T_B, *A_core_shape))
        expanded.name = A.name
        return expanded, B

    if T_B == 1:
        expanded = pt.repeat(B, A.shape[0], axis=0)
        expanded = pt.specify_shape(expanded, (T_A, *B_core_shape))
        expanded.name = B.name
        return A, expanded

    return A, B
375
+
376
+
377
def get_exog_dims_from_idata(exog_name, idata):
    """
    Look up the dims of an exogenous variable inside an InferenceData object.

    Searches the posterior group first (dropping the leading chain/draw dims),
    then the constant_data and mutable_data groups if they exist.

    Parameters
    ----------
    exog_name: str
        Name of the exogenous variable.
    idata: arviz.InferenceData
        Trace to search.

    Returns
    -------
    tuple of str or None
        The variable's dims, or None if the variable is not found.
    """
    posterior = idata.posterior
    if exog_name in posterior.data_vars:
        # Drop the first two dims, (chain, draw).
        return posterior[exog_name].dims[2:]

    if exog_name in getattr(idata, "constant_data", []):
        return idata.constant_data[exog_name].dims

    if exog_name in getattr(idata, "mutable_data", []):
        return idata.mutable_data[exog_name].dims

    return None
File without changes
@@ -0,0 +1,74 @@
1
import pytensor

# Coordinate (dim) names shared across the statespace module. Each base dim has
# an "_aux" twin carrying the same labels, so square matrices (covariances,
# transition matrices) can name both of their axes.
ALL_STATE_DIM = "state"
ALL_STATE_AUX_DIM = "state_aux"
OBS_STATE_DIM = "observed_state"
OBS_STATE_AUX_DIM = "observed_state_aux"
SHOCK_DIM = "shock"
SHOCK_AUX_DIM = "shock_aux"
TIME_DIM = "time"
AR_PARAM_DIM = "ar_lag"
MA_PARAM_DIM = "ma_lag"
SEASONAL_AR_PARAM_DIM = "seasonal_ar_lag"
SEASONAL_MA_PARAM_DIM = "seasonal_ma_lag"
ETS_SEASONAL_DIM = "seasonal_lag"

# Statespace quantities that never gain a leading time dimension.
NEVER_TIME_VARYING = ["initial_state", "initial_state_cov", "a0", "P0"]
# Statespace quantities that are vectors rather than matrices (one core dim).
VECTOR_VALUED = ["initial_state", "state_intercept", "obs_intercept", "a0", "c", "d"]

# Sentinel written in place of missing observations before Kalman filtering.
MISSING_FILL = -9999.0
# Diagonal jitter for numerical stability; looser in single precision.
JITTER_DEFAULT = 1e-8 if pytensor.config.floatX.endswith("64") else 1e-6

FILTER_OUTPUT_TYPES = ["filtered", "predicted", "smoothed"]

# Short and long matrix names are index-aligned; the zips below rely on this
# to build the name-translation maps.
MATRIX_NAMES = ["x0", "P0", "c", "d", "T", "Z", "R", "H", "Q"]
LONG_MATRIX_NAMES = [
    "initial_state",
    "initial_state_cov",
    "state_intercept",
    "obs_intercept",
    "transition",
    "design",
    "selection",
    "obs_cov",
    "state_cov",
]

SHORT_NAME_TO_LONG = dict(zip(MATRIX_NAMES, LONG_MATRIX_NAMES))
LONG_NAME_TO_SHORT = dict(zip(LONG_MATRIX_NAMES, MATRIX_NAMES))

FILTER_OUTPUT_NAMES = [
    "filtered_state",
    "predicted_state",
    "filtered_covariance",
    "predicted_covariance",
]

SMOOTHER_OUTPUT_NAMES = ["smoothed_state", "smoothed_covariance"]
OBSERVED_OUTPUT_NAMES = ["predicted_observed_state", "predicted_observed_covariance"]

# Default dims attached to each statespace matrix when registered with PyMC.
MATRIX_DIMS = {
    "x0": (ALL_STATE_DIM,),
    "P0": (ALL_STATE_DIM, ALL_STATE_AUX_DIM),
    "c": (ALL_STATE_DIM,),
    "d": (OBS_STATE_DIM,),
    "T": (ALL_STATE_DIM, ALL_STATE_AUX_DIM),
    "Z": (OBS_STATE_DIM, ALL_STATE_DIM),
    "R": (ALL_STATE_DIM, SHOCK_DIM),
    "H": (OBS_STATE_DIM, OBS_STATE_AUX_DIM),
    "Q": (SHOCK_DIM, SHOCK_AUX_DIM),
}

# Dims of each Kalman filter / smoother output variable.
FILTER_OUTPUT_DIMS = {
    "filtered_state": (TIME_DIM, ALL_STATE_DIM),
    "smoothed_state": (TIME_DIM, ALL_STATE_DIM),
    "predicted_state": (TIME_DIM, ALL_STATE_DIM),
    "filtered_covariance": (TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM),
    "smoothed_covariance": (TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM),
    "predicted_covariance": (TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM),
    "predicted_observed_state": (TIME_DIM, OBS_STATE_DIM),
    "predicted_observed_covariance": (TIME_DIM, OBS_STATE_DIM, OBS_STATE_AUX_DIM),
}

# Names for successive time derivatives of a level component, in order.
POSITION_DERIVATIVE_NAMES = ["level", "trend", "acceleration", "jerk", "snap", "crackle", "pop"]
SARIMAX_STATE_STRUCTURES = ["fast", "interpretable"]
File without changes
@@ -0,0 +1,182 @@
1
+ import warnings
2
+
3
+ import numpy as np
4
+ import pandas as pd
5
+ import pymc as pm
6
+ import pytensor
7
+ import pytensor.tensor as pt
8
+
9
+ from pymc import ImputationWarning, modelcontext
10
+ from pytensor.tensor.sharedvar import TensorSharedVariable
11
+
12
+ from pymc_extras.statespace.utils.constants import (
13
+ MISSING_FILL,
14
+ OBS_STATE_DIM,
15
+ TIME_DIM,
16
+ )
17
+
18
# User-facing warning messages emitted while preprocessing observed data.
NO_TIME_INDEX_WARNING = (
    "No time index found on the supplied data. A simple range index will be automatically "
    "generated."
)
# Fix: message previously read "was specific" (typo) instead of "was specified".
NO_FREQ_INFO_WARNING = "No frequency was specified on the data's DateTimeIndex."
23
+
24
+
25
def get_data_dims(data):
    """
    Return the dims registered for a data variable in the active PyMC model.

    Parameters
    ----------
    data
        Candidate data object; only named pytensor tensors/shared variables
        can have registered dims.

    Returns
    -------
    dims or None
        The dims recorded in the model's ``named_vars_to_dims`` mapping, or
        None when the input is not a named tensor or has no registered dims.
    """
    # Only tensor variables can be registered with a PyMC model.
    if not isinstance(data, pt.TensorVariable | TensorSharedVariable):
        return None

    name = getattr(data, "name", None)
    if not name:
        return None

    # Requires an active model context.
    model = modelcontext(None)
    return model.named_vars_to_dims.get(name, None)
36
+
37
+
38
+ def _validate_data_shape(data_shape, n_obs, obs_coords=None, check_col_names=False, col_names=None):
39
+ if col_names is None:
40
+ col_names = []
41
+
42
+ if len(data_shape) != 2:
43
+ raise ValueError("Data must be a 2d matrix")
44
+
45
+ if data_shape[-1] != n_obs:
46
+ raise ValueError(
47
+ f"Shape of data does not match model output. Expected {n_obs} columns, "
48
+ f"found {data_shape[-1]}."
49
+ )
50
+
51
+ if check_col_names:
52
+ missing_cols = set(obs_coords) - set(col_names)
53
+ if len(missing_cols) > 0:
54
+ raise ValueError(
55
+ "Columns of DataFrame provided as data do not match state names. The following states were"
56
+ f'not found: {", ".join(missing_cols)}. This may result in unexpected results in complex'
57
+ f"statespace models"
58
+ )
59
+
60
+
61
def preprocess_tensor_data(data, n_obs, obs_coords=None):
    """
    Convert a pytensor tensor of observed data to (values, index).

    Parameters
    ----------
    data: pt.TensorVariable or shared variable
        Observed data; must evaluate to a 2d array with ``n_obs`` columns.
    n_obs: int
        Number of observed states expected by the model.
    obs_coords: sequence, optional
        Names of the observed states. If provided, a warning is raised because
        a raw tensor carries no time index that could be matched to coords.

    Returns
    -------
    (values, index): tuple of np.ndarray and np.ndarray
        The evaluated data and a simple integer range index.
    """
    shape = data.shape.eval()
    _validate_data_shape(shape, n_obs, obs_coords)

    if obs_coords is not None:
        # Tensors carry no time index, so a range index is generated instead.
        warnings.warn(NO_TIME_INDEX_WARNING)

    return data.eval(), np.arange(shape[0], dtype="int")
69
+
70
+
71
def preprocess_numpy_data(data, n_obs, obs_coords=None):
    """
    Convert a numpy array of observed data to (values, index).

    Parameters
    ----------
    data: np.ndarray
        Observed data; must be 2d with ``n_obs`` columns.
    n_obs: int
        Number of observed states expected by the model.
    obs_coords: sequence, optional
        Names of the observed states. If provided, a warning is raised because
        a raw array carries no time index that could be matched to coords.

    Returns
    -------
    (values, index): tuple of np.ndarray and np.ndarray
        The data and a simple integer range index.
    """
    _validate_data_shape(data.shape, n_obs, obs_coords)

    if obs_coords is not None:
        # Arrays carry no time index, so a range index is generated instead.
        warnings.warn(NO_TIME_INDEX_WARNING)

    return data, np.arange(data.shape[0], dtype="int")
79
+
80
+
81
def preprocess_pandas_data(data, n_obs, obs_coords=None, check_column_names=False):
    """
    Convert a pandas Series or DataFrame of observed data to (values, index).

    Parameters
    ----------
    data: pd.Series or pd.DataFrame
        Observed data. A Series is promoted to a one-column DataFrame; an
        unnamed Series is given the default name "data".
    n_obs: int
        Number of observed states expected by the model.
    obs_coords: sequence, optional
        Names of the observed states, used for column validation and for the
        missing-time-index warning.
    check_column_names: bool
        If True, validate that the DataFrame columns match ``obs_coords``.

    Returns
    -------
    (values, index): tuple of np.ndarray and index
        The raw values and the associated time index (a DatetimeIndex, or an
        integer range index when the data has a RangeIndex).

    Raises
    ------
    IndexError
        If the data's index is neither a RangeIndex nor a DatetimeIndex.
    """
    if isinstance(data, pd.Series):
        if data.name is None:
            data.name = "data"
        data = data.to_frame()

    col_names = data.columns
    _validate_data_shape(data.shape, n_obs, obs_coords, check_column_names, col_names)

    if isinstance(data.index, pd.RangeIndex):
        # Fix: NO_TIME_INDEX_WARNING was previously emitted here AND again inside
        # preprocess_numpy_data, producing a duplicate warning; warn only once.
        return preprocess_numpy_data(data.values, n_obs, obs_coords)

    elif isinstance(data.index, pd.DatetimeIndex):
        if data.index.freq is None:
            # Try to recover the frequency from the index spacing.
            warnings.warn(NO_FREQ_INFO_WARNING)
            data.index.freq = data.index.inferred_freq

        return data.values, data.index

    else:
        raise IndexError(
            f"Expected pd.DatetimeIndex or pd.RangeIndex on data, found {type(data.index)}"
        )
107
+
108
+
109
def add_data_to_active_model(values, index, data_dims=None):
    """
    Register observed data with the PyMC model on the current context stack.

    Ensures the time coordinate exists on the model (adding or filling it in
    from ``index`` if needed) and wraps the values in a ``pm.Data`` container
    named "data".

    Parameters
    ----------
    values: np.ndarray
        2d array of observed values.
    index: array-like
        Time index associated with the rows of ``values``.
    data_dims: sequence of str, optional
        Dims for the data variable; the first entry is the time dim. Defaults
        to (TIME_DIM, OBS_STATE_DIM).

    Returns
    -------
    pm.Data
        The registered data variable.

    Raises
    ------
    ValueError
        If the model already has a time coordinate that disagrees with ``index``.
    """
    model = modelcontext(None)
    if data_dims is None:
        data_dims = [TIME_DIM, OBS_STATE_DIM]
    time_dim = data_dims[0]

    if time_dim not in model.coords:
        model.add_coord(time_dim, index)
    else:
        existing = model.coords[time_dim]
        if existing is None:
            model.coords.update({time_dim: index})
        elif not np.array_equal(existing, tuple(index)):
            raise ValueError(
                "Provided data has a different time index than the model. Please ensure that the time values "
                "set on coords matches that of the exogenous data."
            )

    # A one-column matrix needs static shape (None, 1), or the JAX backend
    # raises a broadcasting error.
    data_shape = (None, 1) if values.shape[-1] == 1 else None

    return pm.Data("data", values, dims=data_dims, shape=data_shape)
136
+
137
+
138
def mask_missing_values_in_data(values, missing_fill_value=None):
    """
    Replace missing (NaN/inf) observations with a fill value and return a mask.

    Parameters
    ----------
    values: np.ndarray
        Observed data, possibly containing NaN or infinite entries.
    missing_fill_value: float, optional
        Value written in place of missing entries. Defaults to MISSING_FILL.

    Returns
    -------
    (filled_values, nan_mask): tuple of np.ndarray
        The data with missing entries replaced, and the boolean mask of
        missing positions.

    Raises
    ------
    ValueError
        If the data already contains ``missing_fill_value``, which would make
        real observations indistinguishable from imputed ones.
    """
    if missing_fill_value is None:
        missing_fill_value = MISSING_FILL

    masked = np.ma.masked_invalid(values)
    filled_values = masked.filled(missing_fill_value)
    nan_mask = masked.mask

    if np.any(nan_mask):
        # Guard against collisions between real data and the fill sentinel.
        if np.any(values == missing_fill_value):
            raise ValueError(
                f"Provided data contains the value {missing_fill_value}, which is used as a missing value marker. "
                f"Please manually change the missing_fill_value to avoid this collision."
            )

        warnings.warn(
            "Provided data contains missing values and"
            " will be automatically imputed as hidden states"
            " during Kalman filtering.",
            ImputationWarning,
        )

    return filled_values, nan_mask
162
+
163
+
164
def register_data_with_pymc(
    data, n_obs, obs_coords, register_data=True, missing_fill_value=None, data_dims=None
):
    """
    Preprocess observed data of any supported type and expose it to the model.

    Dispatches on the input type (pytensor tensor, numpy array, or pandas
    Series/DataFrame) to extract values and a time index, fills missing
    observations, and registers the result either as a ``pm.Data`` variable on
    the active model or as a plain shared variable.

    Parameters
    ----------
    data: pt.TensorVariable, np.ndarray, pd.Series, or pd.DataFrame
        Observed data.
    n_obs: int
        Number of observed states expected by the model.
    obs_coords: sequence
        Names of the observed states.
    register_data: bool
        If True, add the data to the active PyMC model; otherwise wrap it in a
        pytensor shared variable.
    missing_fill_value: float, optional
        Sentinel used to fill missing observations.
    data_dims: sequence of str, optional
        Dims to attach when registering with the model.

    Returns
    -------
    (data, nan_mask): tuple
        The registered data variable and the boolean mask of missing entries.

    Raises
    ------
    ValueError
        If the data is not one of the supported types.
    """
    if isinstance(data, pt.TensorVariable | TensorSharedVariable):
        values, index = preprocess_tensor_data(data, n_obs, obs_coords)
    elif isinstance(data, np.ndarray):
        values, index = preprocess_numpy_data(data, n_obs, obs_coords)
    elif isinstance(data, pd.DataFrame | pd.Series):
        values, index = preprocess_pandas_data(data, n_obs, obs_coords)
    else:
        raise ValueError("Data should be one of pytensor tensor, numpy array, or pandas dataframe")

    filled_values, nan_mask = mask_missing_values_in_data(values, missing_fill_value)

    if register_data:
        return add_data_to_active_model(filled_values, index, data_dims), nan_mask

    return pytensor.shared(filled_values, name="data"), nan_mask
@@ -0,0 +1,23 @@
1
+ # Copyright 2022 The PyMC Developers
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from pymc_extras.utils import prior, spline
17
+ from pymc_extras.utils.linear_cg import linear_cg
18
+
19
+ __all__ = (
20
+ "linear_cg",
21
+ "prior",
22
+ "spline",
23
+ )