pymc-extras 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pymc_extras/__init__.py +29 -0
- pymc_extras/distributions/__init__.py +40 -0
- pymc_extras/distributions/continuous.py +351 -0
- pymc_extras/distributions/discrete.py +399 -0
- pymc_extras/distributions/histogram_utils.py +163 -0
- pymc_extras/distributions/multivariate/__init__.py +3 -0
- pymc_extras/distributions/multivariate/r2d2m2cp.py +446 -0
- pymc_extras/distributions/timeseries.py +356 -0
- pymc_extras/gp/__init__.py +18 -0
- pymc_extras/gp/latent_approx.py +183 -0
- pymc_extras/inference/__init__.py +18 -0
- pymc_extras/inference/find_map.py +431 -0
- pymc_extras/inference/fit.py +44 -0
- pymc_extras/inference/laplace.py +570 -0
- pymc_extras/inference/pathfinder.py +134 -0
- pymc_extras/inference/smc/__init__.py +13 -0
- pymc_extras/inference/smc/sampling.py +451 -0
- pymc_extras/linearmodel.py +130 -0
- pymc_extras/model/__init__.py +0 -0
- pymc_extras/model/marginal/__init__.py +0 -0
- pymc_extras/model/marginal/distributions.py +276 -0
- pymc_extras/model/marginal/graph_analysis.py +372 -0
- pymc_extras/model/marginal/marginal_model.py +595 -0
- pymc_extras/model/model_api.py +56 -0
- pymc_extras/model/transforms/__init__.py +0 -0
- pymc_extras/model/transforms/autoreparam.py +434 -0
- pymc_extras/model_builder.py +759 -0
- pymc_extras/preprocessing/__init__.py +0 -0
- pymc_extras/preprocessing/standard_scaler.py +17 -0
- pymc_extras/printing.py +182 -0
- pymc_extras/statespace/__init__.py +13 -0
- pymc_extras/statespace/core/__init__.py +7 -0
- pymc_extras/statespace/core/compile.py +48 -0
- pymc_extras/statespace/core/representation.py +438 -0
- pymc_extras/statespace/core/statespace.py +2268 -0
- pymc_extras/statespace/filters/__init__.py +15 -0
- pymc_extras/statespace/filters/distributions.py +453 -0
- pymc_extras/statespace/filters/kalman_filter.py +820 -0
- pymc_extras/statespace/filters/kalman_smoother.py +126 -0
- pymc_extras/statespace/filters/utilities.py +59 -0
- pymc_extras/statespace/models/ETS.py +670 -0
- pymc_extras/statespace/models/SARIMAX.py +536 -0
- pymc_extras/statespace/models/VARMAX.py +393 -0
- pymc_extras/statespace/models/__init__.py +6 -0
- pymc_extras/statespace/models/structural.py +1651 -0
- pymc_extras/statespace/models/utilities.py +387 -0
- pymc_extras/statespace/utils/__init__.py +0 -0
- pymc_extras/statespace/utils/constants.py +74 -0
- pymc_extras/statespace/utils/coord_tools.py +0 -0
- pymc_extras/statespace/utils/data_tools.py +182 -0
- pymc_extras/utils/__init__.py +23 -0
- pymc_extras/utils/linear_cg.py +290 -0
- pymc_extras/utils/pivoted_cholesky.py +69 -0
- pymc_extras/utils/prior.py +200 -0
- pymc_extras/utils/spline.py +131 -0
- pymc_extras/version.py +11 -0
- pymc_extras/version.txt +1 -0
- pymc_extras-0.2.0.dist-info/LICENSE +212 -0
- pymc_extras-0.2.0.dist-info/METADATA +99 -0
- pymc_extras-0.2.0.dist-info/RECORD +101 -0
- pymc_extras-0.2.0.dist-info/WHEEL +5 -0
- pymc_extras-0.2.0.dist-info/top_level.txt +2 -0
- tests/__init__.py +13 -0
- tests/distributions/__init__.py +19 -0
- tests/distributions/test_continuous.py +185 -0
- tests/distributions/test_discrete.py +210 -0
- tests/distributions/test_discrete_markov_chain.py +258 -0
- tests/distributions/test_multivariate.py +304 -0
- tests/model/__init__.py +0 -0
- tests/model/marginal/__init__.py +0 -0
- tests/model/marginal/test_distributions.py +131 -0
- tests/model/marginal/test_graph_analysis.py +182 -0
- tests/model/marginal/test_marginal_model.py +867 -0
- tests/model/test_model_api.py +29 -0
- tests/statespace/__init__.py +0 -0
- tests/statespace/test_ETS.py +411 -0
- tests/statespace/test_SARIMAX.py +405 -0
- tests/statespace/test_VARMAX.py +184 -0
- tests/statespace/test_coord_assignment.py +116 -0
- tests/statespace/test_distributions.py +270 -0
- tests/statespace/test_kalman_filter.py +326 -0
- tests/statespace/test_representation.py +175 -0
- tests/statespace/test_statespace.py +818 -0
- tests/statespace/test_statespace_JAX.py +156 -0
- tests/statespace/test_structural.py +829 -0
- tests/statespace/utilities/__init__.py +0 -0
- tests/statespace/utilities/shared_fixtures.py +9 -0
- tests/statespace/utilities/statsmodel_local_level.py +42 -0
- tests/statespace/utilities/test_helpers.py +310 -0
- tests/test_blackjax_smc.py +222 -0
- tests/test_find_map.py +98 -0
- tests/test_histogram_approximation.py +109 -0
- tests/test_laplace.py +238 -0
- tests/test_linearmodel.py +208 -0
- tests/test_model_builder.py +306 -0
- tests/test_pathfinder.py +45 -0
- tests/test_pivoted_cholesky.py +24 -0
- tests/test_printing.py +98 -0
- tests/test_prior_from_trace.py +172 -0
- tests/test_splines.py +77 -0
- tests/utils.py +31 -0
|
@@ -0,0 +1,438 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import pytensor
|
|
5
|
+
import pytensor.tensor as pt
|
|
6
|
+
|
|
7
|
+
from pymc_extras.statespace.utils.constants import (
|
|
8
|
+
NEVER_TIME_VARYING,
|
|
9
|
+
VECTOR_VALUED,
|
|
10
|
+
)
|
|
11
|
+
|
|
12
|
+
# Float dtype taken from the pytensor configuration; used below whenever numpy
# arrays are turned into tensors.
floatX = pytensor.config.floatX
# Type of the keys accepted by PytensorRepresentation indexing: either a matrix
# name, or a tuple whose first entry is a matrix name followed by index entries.
KeyLike = tuple[str | int, ...] | str
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class PytensorRepresentation:
|
|
17
|
+
r"""
|
|
18
|
+
Core class to hold all objects required by linear gaussian statespace models
|
|
19
|
+
|
|
20
|
+
Notation for the linear statespace model is taken from [1], while the specific implementation is adapted from
|
|
21
|
+
the statsmodels implementation: https://github.com/statsmodels/statsmodels/blob/main/statsmodels/tsa/statespace/representation.py
|
|
22
|
+
described in [2].
|
|
23
|
+
|
|
24
|
+
Parameters
|
|
25
|
+
----------
|
|
26
|
+
k_endog: int
|
|
27
|
+
Number of observed states (called "endogenous states" in statsmodels)
|
|
28
|
+
k_states: int
|
|
29
|
+
Number of hidden states
|
|
30
|
+
k_posdef: int
|
|
31
|
+
Number of states that have exogenous shocks; also the rank of the selection matrix R.
|
|
32
|
+
design: ArrayLike, optional
|
|
33
|
+
Design matrix, denoted 'Z' in [1].
|
|
34
|
+
obs_intercept: ArrayLike, optional
|
|
35
|
+
Constant vector in the observation equation, denoted 'd' in [1]. Currently
|
|
36
|
+
not used.
|
|
37
|
+
obs_cov: ArrayLike, optional
|
|
38
|
+
Covariance matrix for multivariate-normal errors in the observation equation. Denoted 'H' in
|
|
39
|
+
[1].
|
|
40
|
+
transition: ArrayLike, optional
|
|
41
|
+
Transition equation that updates the hidden state between time-steps. Denoted 'T' in [1].
|
|
42
|
+
state_intercept: ArrayLike, optional
|
|
43
|
+
Constant vector in the state transition equation, denoted 'c' in [1]. Currently not used.
|
|
44
|
+
selection: ArrayLike, optional
|
|
45
|
+
Selection matrix that matches shocks to hidden states, denoted 'R' in [1]. This is the identity
|
|
46
|
+
matrix when k_posdef = k_states.
|
|
47
|
+
state_cov: ArrayLike, optional
|
|
48
|
+
Covariance matrix for state equations, denoted 'Q' in [1]. Null matrix when there is no state innovation
|
|
49
|
+
noise.
|
|
50
|
+
initial_state: ArrayLike, optional
|
|
51
|
+
Experimental setting to allow for Bayesian estimation of the initial state, denoted `alpha_0` in [1]. Defaults to zeros.
|
|
52
|
+
It should potentially be removed in favor of the closed-form diffuse initialization.
|
|
53
|
+
initial_state_cov: ArrayLike, optional
|
|
54
|
+
Experimental setting to allow for Bayesian estimation of the initial state covariance, denoted `P_0` in [1]. Defaults to zeros.
|
|
55
|
+
It should potentially be removed in favor of the closed-form diffuse initialization.
|
|
56
|
+
|
|
57
|
+
Notes
|
|
58
|
+
-----
|
|
59
|
+
A linear statespace system is defined by two equations:
|
|
60
|
+
|
|
61
|
+
.. math::
|
|
62
|
+
\begin{align}
|
|
63
|
+
x_t &= T_t x_{t-1} + c_t + R_t \varepsilon_t \tag{1} \\
|
|
64
|
+
y_t &= Z_t x_t + d_t + \eta_t \tag{2} \\
|
|
65
|
+
\end{align}
|
|
66
|
+
|
|
67
|
+
Where :math:`\{x_t\}_{t=0}^T` is a trajectory of hidden states, and :math:`\{y_t\}_{t=0}^T` is a trajectory of
|
|
68
|
+
observable states. Equation 1 is known as the "state transition equation", which describes how the system evolves
|
|
69
|
+
over time. Equation 2 is the "observation equation", and maps the latent state processes to observed data.
|
|
70
|
+
The system is Gaussian when the innovations, :math:`\varepsilon_t`, and the measurement errors, :math:`\eta_t`,
|
|
71
|
+
are normally distributed. The definition is completed by specification of these distributions, as
|
|
72
|
+
well as an initial state distribution:
|
|
73
|
+
|
|
74
|
+
.. math::
|
|
75
|
+
\begin{align}
|
|
76
|
+
\varepsilon_t &\sim N(0, Q_t) \tag{3} \\
|
|
77
|
+
\eta_t &\sim N(0, H_t) \tag{4} \\
|
|
78
|
+
x_0 &\sim N(\bar{x}_0, P_0) \tag{5}
|
|
79
|
+
\end{align}
|
|
80
|
+
|
|
81
|
+
The 9 matrices that form equations 1 to 5 are summarized in the table below. We call :math:`N` the number of
|
|
82
|
+
observations, :math:`m` the number of hidden states, :math:`p` the number of observed states, and :math:`r` the
|
|
83
|
+
number of innovations.
|
|
84
|
+
|
|
85
|
+
+-----------------------------------+-------------------+-----------------------+
|
|
86
|
+
| Name | Symbol | Shape |
|
|
87
|
+
+===================================+===================+=======================+
|
|
88
|
+
| Initial hidden state mean | :math:`x_0` | :math:`m \times 1` |
|
|
89
|
+
+-----------------------------------+-------------------+-----------------------+
|
|
90
|
+
| Initial hidden state covariance | :math:`P_0` | :math:`m \times m` |
|
|
91
|
+
+-----------------------------------+-------------------+-----------------------+
|
|
92
|
+
| Hidden state vector intercept | :math:`c_t` | :math:`m \times 1` |
|
|
93
|
+
+-----------------------------------+-------------------+-----------------------+
|
|
94
|
+
| Observed state vector intercept | :math:`d_t` | :math:`p \times 1` |
|
|
95
|
+
+-----------------------------------+-------------------+-----------------------+
|
|
96
|
+
| Transition matrix | :math:`T_t` | :math:`m \times m` |
|
|
97
|
+
+-----------------------------------+-------------------+-----------------------+
|
|
98
|
+
| Design matrix | :math:`Z_t` | :math:`p \times m` |
|
|
99
|
+
+-----------------------------------+-------------------+-----------------------+
|
|
100
|
+
| Selection matrix | :math:`R_t` | :math:`m \times r` |
|
|
101
|
+
+-----------------------------------+-------------------+-----------------------+
|
|
102
|
+
| Observation noise covariance | :math:`H_t` | :math:`p \times p` |
|
|
103
|
+
+-----------------------------------+-------------------+-----------------------+
|
|
104
|
+
| Hidden state innovation covariance| :math:`Q_t` | :math:`r \times r` |
|
|
105
|
+
+-----------------------------------+-------------------+-----------------------+
|
|
106
|
+
|
|
107
|
+
The shapes listed above are the core shapes, but in the general case all of these matrices (except for :math:`x_0`
|
|
108
|
+
and :math:`P_0`) can be time varying. In this case, a time dimension of shape :math:`N`, equal to the number of
|
|
109
|
+
observations, can be added.
|
|
110
|
+
|
|
111
|
+
.. warning:: The time dimension is used as a batch dimension during Kalman filtering, and must thus **always**
|
|
112
|
+
be the **leftmost** dimension.
|
|
113
|
+
|
|
114
|
+
The purpose of this class is to store these matrices, as well as to allow users to easily index into them. Matrices
|
|
115
|
+
are stored as pytensor ``TensorVariables`` of known shape. Shapes are always accessible via the ``.type.shape``
|
|
116
|
+
attribute, which should never be ``None``. Matrices can be accessed via normal numpy array slicing after first
|
|
117
|
+
indexing by the name of the desired array. The time dimension is stored on the far left, and is automatically
|
|
118
|
+
sliced away unless specifically requested by the user. See the examples for details.
|
|
119
|
+
|
|
120
|
+
Examples
|
|
121
|
+
--------
|
|
122
|
+
.. code:: python
|
|
123
|
+
|
|
124
|
+
from pymc_extras.statespace.core.representation import PytensorRepresentation
|
|
125
|
+
ssm = PytensorRepresentation(k_endog=1, k_states=3, k_posdef=1)
|
|
126
|
+
|
|
127
|
+
# Access matrices by their names
|
|
128
|
+
print(ssm['transition'].type.shape)
|
|
129
|
+
>>> (3, 3)
|
|
130
|
+
|
|
131
|
+
# Slice a matrix
|
|
132
|
+
print(ssm['obs_cov', 0, 0].eval())
|
|
133
|
+
>>> 0.0
|
|
134
|
+
|
|
135
|
+
# Set elements in a slice of a matrix
|
|
136
|
+
ssm['design', 0, 0] = 1
|
|
137
|
+
print(ssm['design'].eval())
|
|
138
|
+
>>> np.array([[1, 0, 0]])
|
|
139
|
+
|
|
140
|
+
# Setting an entire matrix is also permitted. If you set a time dimension, it must be the first dimension, and
|
|
141
|
+
# the "core" dimensions must agree with those set when the ssm object was instantiated.
|
|
142
|
+
ssm['obs_intercept'] = np.arange(10).reshape(10, 1) # 10 timesteps
|
|
143
|
+
print(ssm['obs_intercept'].eval())
|
|
144
|
+
>>> np.array([[0.], [1.], [2.], [3.], [4.], [5.], [6.], [7.], [8.], [9.]])
|
|
145
|
+
|
|
146
|
+
References
|
|
147
|
+
----------
|
|
148
|
+
.. [1] Durbin, James, and Siem Jan Koopman. 2012.
|
|
149
|
+
Time Series Analysis by State Space Methods: Second Edition.
|
|
150
|
+
Oxford University Press.
|
|
151
|
+
.. [2] Fulton, Chad. "Estimating time series models by state space methods in Python: Statsmodels." (2015).
|
|
152
|
+
http://www.chadfulton.com/files/fulton_statsmodels_2017_v1.pdf
|
|
153
|
+
"""
|
|
154
|
+
|
|
155
|
+
# Fixed attribute set: the three size parameters, the ``shapes`` bookkeeping
# dict, and one attribute per state space matrix (same names as the keys of
# ``shapes``).
__slots__ = (
    "k_endog",
    "k_states",
    "k_posdef",
    "shapes",
    "design",
    "obs_intercept",
    "obs_cov",
    "transition",
    "state_intercept",
    "selection",
    "state_cov",
    "initial_state",
    "initial_state_cov",
)
|
|
170
|
+
|
|
171
|
+
def __init__(
    self,
    k_endog: int,
    k_states: int,
    k_posdef: int,
    design: np.ndarray | None = None,
    obs_intercept: np.ndarray | None = None,
    obs_cov=None,
    transition=None,
    state_intercept=None,
    selection=None,
    state_cov=None,
    initial_state=None,
    initial_state_cov=None,
) -> None:
    """
    Store the model dimensions and build one tensor per state space matrix.

    Matrices passed as numpy arrays are converted with ``_numpy_to_pytensor``; any other
    non-``None`` value is passed through ``_check_provided_tensor``. Matrices left as ``None``
    are initialized to zeros of the default shape recorded in ``self.shapes``.

    Parameters
    ----------
    k_endog: int
        Number of observed states.
    k_states: int
        Number of hidden states.
    k_posdef: int
        Number of shocks. If ``None``, ``k_states`` is used instead.
    design, obs_intercept, obs_cov, transition, state_intercept, selection, state_cov, \
initial_state, initial_state_cov: optional
        Initial values for the corresponding matrices.
    """
    self.k_states = k_states
    self.k_endog = k_endog
    self.k_posdef = k_posdef if k_posdef is not None else k_states

    # The first dimension is for time varying matrices; it defaults to a dummy of length 1.
    self.shapes = {
        "design": (1, self.k_endog, self.k_states),
        "obs_intercept": (1, self.k_endog),
        "obs_cov": (1, self.k_endog, self.k_endog),
        "transition": (1, self.k_states, self.k_states),
        "state_intercept": (1, self.k_states),
        "selection": (1, self.k_states, self.k_posdef),
        "state_cov": (1, self.k_posdef, self.k_posdef),
        # These are never time varying, so they don't have a dummy first dimension
        "initial_state": (self.k_states,),
        "initial_state_cov": (self.k_states, self.k_states),
    }

    # Explicit name -> argument mapping (instead of inspecting ``locals()``), so the lookup
    # cannot be affected by other local variables of this function.
    provided = {
        "design": design,
        "obs_intercept": obs_intercept,
        "obs_cov": obs_cov,
        "transition": transition,
        "state_intercept": state_intercept,
        "selection": selection,
        "state_cov": state_cov,
        "initial_state": initial_state,
        "initial_state_cov": initial_state_cov,
    }

    # Iterate over a snapshot because entries of ``self.shapes`` are updated in the loop.
    for name, default_shape in list(self.shapes.items()):
        matrix = provided[name]
        if matrix is None:
            matrix = pt.as_tensor_variable(
                np.zeros(default_shape, dtype=floatX), name=name, ndim=len(default_shape)
            )
        elif isinstance(matrix, np.ndarray):
            matrix = self._numpy_to_pytensor(name, matrix)
        else:
            matrix = self._check_provided_tensor(name, matrix)

        # Record the actual length of the time dimension. Without this, a time-varying matrix
        # given to the constructor would leave a stale leading 1 in ``self.shapes``, and a
        # later sliced assignment would pin the wrong shape on the matrix.
        if name not in NEVER_TIME_VARYING:
            self.shapes[name] = (matrix.type.shape[0], *default_shape[1:])

        setattr(self, name, matrix)
|
|
220
|
+
|
|
221
|
+
def _validate_key(self, key: KeyLike) -> None:
    """Raise ``IndexError`` unless ``key`` is one of the matrix names stored in ``self.shapes``."""
    if key in self.shapes:
        return

    raise IndexError(f"{key} is an invalid state space matrix name")
|
|
224
|
+
|
|
225
|
+
def _update_shape(self, key: KeyLike, value: np.ndarray | pt.Variable) -> None:
    """
    Record the shape of ``value`` as the new shape of matrix ``key``.

    Parameters
    ----------
    key: str
        Name of the matrix whose entry in ``self.shapes`` is updated.
    value: np.ndarray or pt.Variable
        New value of the matrix. The static ``type.shape`` is used for tensors, ``.shape``
        otherwise.

    Raises
    ------
    ValueError
        If the trailing core dimensions of ``value`` (one for vector-valued matrices, two
        otherwise) differ from the ones currently stored, including when ``value`` has fewer
        dimensions than the core shape.
    """
    if isinstance(value, pt.TensorConstant | pt.TensorVariable):
        shape = tuple(value.type.shape)
    else:
        shape = tuple(value.shape)

    old_shape = self.shapes[key]
    ndim_core = 1 if key in VECTOR_VALUED else 2

    # Compare whole tuples rather than zipping: zip silently truncates, which let values with
    # too few dimensions (or scalars) through and then overwrote the stored shape with garbage.
    if shape[-ndim_core:] != tuple(old_shape[-ndim_core:]):
        raise ValueError(
            f"The last two dimensions of {key} must be {old_shape[-ndim_core:]}, found {shape[-ndim_core:]}"
        )

    # Add time dimension dummy if none present
    if key not in NEVER_TIME_VARYING and len(shape) == ndim_core:
        shape = (1, *shape)

    self.shapes[key] = shape
|
|
246
|
+
|
|
247
|
+
def _add_time_dim_to_slice(
    self, name: str, slice_: list[int] | tuple[int], n_dim: int
) -> tuple[int | slice, ...]:
    """
    Turn a user-supplied index into a full index tuple for the stored matrix.

    Parameters
    ----------
    name: str
        Name of the matrix being indexed.
    slice_: list or tuple
        Index entries given by the user after the matrix name.
    n_dim: int
        Number of dimensions of the stored matrix.

    Returns
    -------
    tuple
        ``slice_`` as a tuple. For matrices that can be time varying, when fewer than ``n_dim``
        entries are given, index 0 is prepended for the time dimension and full slices are
        appended for the remaining dimensions.

    Raises
    ------
    IndexError
        If more than ``n_dim`` entries are given for a matrix that can be time varying.
    """
    # Always hand back a tuple: indexing a tensor with a *list* of integers is advanced
    # indexing (selects several rows), not element access.
    slice_ = tuple(slice_)

    # Case 1: There is never a time dim. No changes needed.
    if name in NEVER_TIME_VARYING:
        return slice_

    # Case 2: The matrix has a time dim, and it was requested. No changes needed.
    if len(slice_) == n_dim:
        return slice_

    # Previously this fell through and returned None, so ``matrix[None]`` silently added an axis.
    if len(slice_) > n_dim:
        raise IndexError(
            f"Too many indices for {name}: got {len(slice_)}, but the matrix has {n_dim} dimensions"
        )

    # Case 3: No time index was given. Select the first entry of the time dimension and pad the
    # remaining dimensions with full slices.
    n_omitted = n_dim - len(slice_) - 1
    return (0, *slice_, *((slice(None, None, None),) * n_omitted))
|
|
263
|
+
|
|
264
|
+
@staticmethod
def _validate_key_and_get_type(key: KeyLike) -> type[str]:
    """
    Return ``type(key)``, after checking that a tuple key starts with a string.

    Raises
    ------
    IndexError
        If ``key`` is a tuple whose first element is not a string.
    """
    starts_with_non_string = isinstance(key, tuple) and not isinstance(key[0], str)
    if starts_with_non_string:
        raise IndexError("First index must the name of a valid state space matrix.")

    key_type = type(key)
    return key_type
|
|
270
|
+
|
|
271
|
+
def _validate_matrix_shape(self, name: str, X: np.ndarray | pt.TensorVariable) -> None:
    """
    Check that ``X`` has an acceptable number of dimensions for matrix ``name``.

    Matrices that are never time varying must have exactly their core number of dimensions
    (1 for vectors, 2 otherwise). Other matrices may carry one extra leading time dimension;
    when they do, the remaining dimensions are compared with the core shape stored in
    ``self.shapes``. Core dimensions are not compared when no time dimension is present.

    Parameters
    ----------
    name: str
        Name of the matrix ``X`` is meant for.
    X: np.ndarray or pt.TensorVariable
        Candidate value. Static shapes (``type.shape``) are used for tensors.

    Raises
    ------
    ValueError
        If the number of dimensions is not allowed, or the core shape of a time-varying value
        does not match the expected one.
    """
    _, *core_dims = self.shapes[name]
    expected_shape = tuple(core_dims)

    # Static shape as a plain tuple. The symbolic ``X.shape`` of a tensor must not be used in
    # comparisons: it never compares equal to a tuple, so every tensor would be rejected.
    shape = tuple(X.shape if isinstance(X, np.ndarray) else X.type.shape)

    is_vector = name in VECTOR_VALUED

    if name in NEVER_TIME_VARYING:
        required_ndim = 1 if is_vector else 2
        if X.ndim != required_ndim:
            raise ValueError(
                f"Array provided for {name} has {X.ndim} dimensions, but it must have exactly {required_ndim}."
            )
        return

    if is_vector:
        if X.ndim not in [1, 2]:
            raise ValueError(
                f"Array provided for {name} has {X.ndim} dimensions, "
                f"expecting 1 (static) or 2 (time-varying)"
            )

        # Time varying vector case, check only the static shapes
        if X.ndim == 2 and shape[1:] != expected_shape:
            raise ValueError(
                f"Last dimension of array provided for {name} has shape {shape[1]}, "
                f"expected {expected_shape}"
            )
        return

    if X.ndim not in [2, 3]:
        raise ValueError(
            f"Array provided for {name} has {X.ndim} dimensions, "
            f"expecting 2 (static) or 3 (time-varying)"
        )

    # Time varying matrix case, check only the static shapes
    if X.ndim == 3 and shape[1:] != expected_shape:
        raise ValueError(
            f"Last two dimensions of array provided for {name} have shapes {shape[1:]}, "
            f"expected {expected_shape}"
        )

    # TODO: Think of another way to validate the length of the time dimension of time-varying
    #  matrices, given that the data is not known when the PytensorRepresentation is created.
|
|
329
|
+
|
|
330
|
+
def _check_provided_tensor(self, name: str, X: pt.TensorVariable) -> pt.TensorVariable:
    """
    Validate a user-provided tensor and add a dummy time dimension when it has none.

    Parameters
    ----------
    name: str
        Name of the matrix ``X`` is meant for.
    X: pt.TensorVariable
        Tensor supplied by the user.

    Returns
    -------
    pt.TensorVariable
        ``X`` unchanged if the matrix is never time varying or already has a time dimension;
        otherwise ``X`` with a leading axis of length 1 and the stored core shape specified.

    Raises
    ------
    ValueError
        Propagated from ``_validate_matrix_shape``.
    """
    self._validate_matrix_shape(name, X)

    if name in NEVER_TIME_VARYING:
        return X

    # Only a value with exactly the core number of dimensions lacks a time axis. The previous
    # check expanded every 2d input, including (time, k) values for vector-valued matrices.
    ndim_core = 1 if name in VECTOR_VALUED else 2
    if X.ndim == ndim_core:
        X = pt.expand_dims(X, (0,))
        # The axis just added has length 1 by construction; core dims come from ``self.shapes``.
        X = pt.specify_shape(X, (1, *self.shapes[name][1:]))

    return X
|
|
342
|
+
|
|
343
|
+
def _numpy_to_pytensor(self, name: str, X: np.ndarray) -> pt.TensorVariable:
    """
    Convert a numpy array into a named tensor of dtype ``floatX``.

    The array is copied and validated with ``_validate_matrix_shape``. For matrices that can
    be time varying, a leading axis of length 1 is added when the array only has its core
    dimensions (1 for vector-valued matrices, 2 otherwise).
    """
    array = X.copy()
    self._validate_matrix_shape(name, array)

    core_ndim = 1 if name in VECTOR_VALUED else 2
    lacks_time_dim = name not in NEVER_TIME_VARYING and array.ndim == core_ndim
    if lacks_time_dim:
        array = array[None, ...]

    return pt.as_tensor(array, name=name, dtype=floatX)
|
|
356
|
+
|
|
357
|
+
def __getitem__(self, key: KeyLike) -> pt.TensorVariable:
    """
    Return a stored matrix, or a slice of it.

    Parameters
    ----------
    key: str or tuple
        Either the name of a matrix, or a tuple made of a matrix name followed by index entries.

    Returns
    -------
    pt.TensorVariable
        For a name: the stored matrix, with the time dimension sliced away when its static
        length is 1 and the matrix can be time varying. For a tuple: the requested sub-tensor;
        when the time index is omitted, the first entry of the time dimension is used.

    Raises
    ------
    IndexError
        If the first element of ``key`` is not a string, or is not a known matrix name.
    """
    _type = self._validate_key_and_get_type(key)

    # Case 1: user asked for an entire matrix by name
    if _type is str:
        self._validate_key(key)
        matrix = getattr(self, key)

        # Slice away the time dimension if it's a dummy
        if key not in NEVER_TIME_VARYING and matrix.type.shape[0] == 1:
            X = matrix[(0,) + (slice(None),) * (matrix.ndim - 1)]
            X = pt.specify_shape(X, self.shapes[key][1:])
            X.name = key
            return X

        # Never time varying, or genuinely time varying: return everything
        return matrix

    # Case 2: user asked for a particular matrix and some slices of it
    elif _type is tuple:
        name, *index = key
        index = tuple(index)
        self._validate_key(name)

        matrix = getattr(self, name)
        full_slice = slice(None, None, None)

        # Case 2a: The user asked for the whole matrix, with time dummies. Return the whole thing
        # without slicing anything away
        if index == (full_slice,) * matrix.ndim:
            return matrix

        # Case 2b: The user asked for the whole matrix except time dummies. Act like case 1.
        # Matrices that never have a time dimension are excluded: for them, dropping index 0
        # would return the first row/element instead of the matrix.
        if name not in NEVER_TIME_VARYING and index == (full_slice,) * (matrix.ndim - 1):
            X = matrix[(0,) + (slice(None),) * (matrix.ndim - 1)]
            X = pt.specify_shape(X, self.shapes[name][1:])
            X.name = name
            return X

        # Case 2c: User asked for an arbitrary sub-matrix. Give it back -- nothing else to be done
        index = self._add_time_dim_to_slice(name, index, matrix.ndim)
        return matrix[index]

    # Case 3: There is only one slice index, but it's not a string
    else:
        raise IndexError("First index must the name of a valid state space matrix.")
|
|
408
|
+
|
|
409
|
+
def __setitem__(self, key: KeyLike, value: float | int | np.ndarray | pt.Variable) -> None:
    """
    Replace a stored matrix, or set a slice of it.

    Parameters
    ----------
    key: str or tuple
        Either the name of a matrix (the whole matrix is replaced), or a tuple made of a matrix
        name followed by index entries (only that part is set).
    value: float, int, np.ndarray or pt.Variable
        New value. Numpy arrays given for a whole matrix are validated and converted to
        tensors; other values given for a whole matrix are stored as they are and renamed to
        ``key``.

    Raises
    ------
    IndexError
        If the matrix name is unknown, or ``key`` is neither a string nor a tuple.
    ValueError
        If the shape of a whole-matrix ``value`` is incompatible with the stored core shape.
    """
    _type = type(key)

    # Case 1: key is a string: we are setting an entire matrix.
    if _type is str:
        self._validate_key(key)
        if isinstance(value, np.ndarray):
            value = self._numpy_to_pytensor(key, value)

        # Validate and record the shape before touching any state, so that a rejected value
        # leaves both the stored matrix and the caller's tensor unchanged.
        # NOTE(review): plain Python scalars are not supported on this path even though the
        # annotation lists them -- they have neither ``.shape`` nor ``.name``.
        self._update_shape(key, value)
        value.name = key
        setattr(self, key, value)

    # Case 2: key is a string plus a slice: we are setting a subset of a matrix
    elif _type is tuple:
        name, *index = key
        self._validate_key(name)

        matrix = getattr(self, name)

        # Pass a tuple: a list of integers would be interpreted as advanced indexing and
        # select whole rows instead of a single element.
        index = self._add_time_dim_to_slice(name, tuple(index), matrix.ndim)
        matrix = pt.set_subtensor(matrix[index], value)
        matrix = pt.specify_shape(matrix, self.shapes[name])
        matrix.name = name

        setattr(self, name, matrix)

    # Case 3: anything else used to be ignored silently; reject it like __getitem__ does.
    else:
        raise IndexError("First index must the name of a valid state space matrix.")
|
|
436
|
+
|
|
437
|
+
def copy(self):
    """
    Return a copy of this object that can be modified without affecting the original.

    The tensors are shared between the two objects, but the ``shapes`` dictionary is
    duplicated: it is updated in place whenever a matrix is replaced, so sharing it would let
    changes made on the copy alter the shapes recorded on the original.
    """
    # Imported under an alias so the call cannot be confused with this method's own name.
    from copy import copy as _shallow_copy

    duplicate = _shallow_copy(self)
    duplicate.shapes = dict(self.shapes)
    return duplicate
|