pymc-extras 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101) hide show
  1. pymc_extras/__init__.py +29 -0
  2. pymc_extras/distributions/__init__.py +40 -0
  3. pymc_extras/distributions/continuous.py +351 -0
  4. pymc_extras/distributions/discrete.py +399 -0
  5. pymc_extras/distributions/histogram_utils.py +163 -0
  6. pymc_extras/distributions/multivariate/__init__.py +3 -0
  7. pymc_extras/distributions/multivariate/r2d2m2cp.py +446 -0
  8. pymc_extras/distributions/timeseries.py +356 -0
  9. pymc_extras/gp/__init__.py +18 -0
  10. pymc_extras/gp/latent_approx.py +183 -0
  11. pymc_extras/inference/__init__.py +18 -0
  12. pymc_extras/inference/find_map.py +431 -0
  13. pymc_extras/inference/fit.py +44 -0
  14. pymc_extras/inference/laplace.py +570 -0
  15. pymc_extras/inference/pathfinder.py +134 -0
  16. pymc_extras/inference/smc/__init__.py +13 -0
  17. pymc_extras/inference/smc/sampling.py +451 -0
  18. pymc_extras/linearmodel.py +130 -0
  19. pymc_extras/model/__init__.py +0 -0
  20. pymc_extras/model/marginal/__init__.py +0 -0
  21. pymc_extras/model/marginal/distributions.py +276 -0
  22. pymc_extras/model/marginal/graph_analysis.py +372 -0
  23. pymc_extras/model/marginal/marginal_model.py +595 -0
  24. pymc_extras/model/model_api.py +56 -0
  25. pymc_extras/model/transforms/__init__.py +0 -0
  26. pymc_extras/model/transforms/autoreparam.py +434 -0
  27. pymc_extras/model_builder.py +759 -0
  28. pymc_extras/preprocessing/__init__.py +0 -0
  29. pymc_extras/preprocessing/standard_scaler.py +17 -0
  30. pymc_extras/printing.py +182 -0
  31. pymc_extras/statespace/__init__.py +13 -0
  32. pymc_extras/statespace/core/__init__.py +7 -0
  33. pymc_extras/statespace/core/compile.py +48 -0
  34. pymc_extras/statespace/core/representation.py +438 -0
  35. pymc_extras/statespace/core/statespace.py +2268 -0
  36. pymc_extras/statespace/filters/__init__.py +15 -0
  37. pymc_extras/statespace/filters/distributions.py +453 -0
  38. pymc_extras/statespace/filters/kalman_filter.py +820 -0
  39. pymc_extras/statespace/filters/kalman_smoother.py +126 -0
  40. pymc_extras/statespace/filters/utilities.py +59 -0
  41. pymc_extras/statespace/models/ETS.py +670 -0
  42. pymc_extras/statespace/models/SARIMAX.py +536 -0
  43. pymc_extras/statespace/models/VARMAX.py +393 -0
  44. pymc_extras/statespace/models/__init__.py +6 -0
  45. pymc_extras/statespace/models/structural.py +1651 -0
  46. pymc_extras/statespace/models/utilities.py +387 -0
  47. pymc_extras/statespace/utils/__init__.py +0 -0
  48. pymc_extras/statespace/utils/constants.py +74 -0
  49. pymc_extras/statespace/utils/coord_tools.py +0 -0
  50. pymc_extras/statespace/utils/data_tools.py +182 -0
  51. pymc_extras/utils/__init__.py +23 -0
  52. pymc_extras/utils/linear_cg.py +290 -0
  53. pymc_extras/utils/pivoted_cholesky.py +69 -0
  54. pymc_extras/utils/prior.py +200 -0
  55. pymc_extras/utils/spline.py +131 -0
  56. pymc_extras/version.py +11 -0
  57. pymc_extras/version.txt +1 -0
  58. pymc_extras-0.2.0.dist-info/LICENSE +212 -0
  59. pymc_extras-0.2.0.dist-info/METADATA +99 -0
  60. pymc_extras-0.2.0.dist-info/RECORD +101 -0
  61. pymc_extras-0.2.0.dist-info/WHEEL +5 -0
  62. pymc_extras-0.2.0.dist-info/top_level.txt +2 -0
  63. tests/__init__.py +13 -0
  64. tests/distributions/__init__.py +19 -0
  65. tests/distributions/test_continuous.py +185 -0
  66. tests/distributions/test_discrete.py +210 -0
  67. tests/distributions/test_discrete_markov_chain.py +258 -0
  68. tests/distributions/test_multivariate.py +304 -0
  69. tests/model/__init__.py +0 -0
  70. tests/model/marginal/__init__.py +0 -0
  71. tests/model/marginal/test_distributions.py +131 -0
  72. tests/model/marginal/test_graph_analysis.py +182 -0
  73. tests/model/marginal/test_marginal_model.py +867 -0
  74. tests/model/test_model_api.py +29 -0
  75. tests/statespace/__init__.py +0 -0
  76. tests/statespace/test_ETS.py +411 -0
  77. tests/statespace/test_SARIMAX.py +405 -0
  78. tests/statespace/test_VARMAX.py +184 -0
  79. tests/statespace/test_coord_assignment.py +116 -0
  80. tests/statespace/test_distributions.py +270 -0
  81. tests/statespace/test_kalman_filter.py +326 -0
  82. tests/statespace/test_representation.py +175 -0
  83. tests/statespace/test_statespace.py +818 -0
  84. tests/statespace/test_statespace_JAX.py +156 -0
  85. tests/statespace/test_structural.py +829 -0
  86. tests/statespace/utilities/__init__.py +0 -0
  87. tests/statespace/utilities/shared_fixtures.py +9 -0
  88. tests/statespace/utilities/statsmodel_local_level.py +42 -0
  89. tests/statespace/utilities/test_helpers.py +310 -0
  90. tests/test_blackjax_smc.py +222 -0
  91. tests/test_find_map.py +98 -0
  92. tests/test_histogram_approximation.py +109 -0
  93. tests/test_laplace.py +238 -0
  94. tests/test_linearmodel.py +208 -0
  95. tests/test_model_builder.py +306 -0
  96. tests/test_pathfinder.py +45 -0
  97. tests/test_pivoted_cholesky.py +24 -0
  98. tests/test_printing.py +98 -0
  99. tests/test_prior_from_trace.py +172 -0
  100. tests/test_splines.py +77 -0
  101. tests/utils.py +31 -0
@@ -0,0 +1,438 @@
1
+ import copy
2
+
3
+ import numpy as np
4
+ import pytensor
5
+ import pytensor.tensor as pt
6
+
7
+ from pymc_extras.statespace.utils.constants import (
8
+ NEVER_TIME_VARYING,
9
+ VECTOR_VALUED,
10
+ )
11
+
12
# Float dtype configured for pytensor; default (zero) matrices and numpy-provided matrices are created with it.
floatX = pytensor.config.floatX
# Index key for PytensorRepresentation: either a bare matrix name, or a tuple of (matrix name, *indices).
KeyLike = str | tuple[str | int, ...]
14
+
15
+
16
class PytensorRepresentation:
    r"""
    Core class to hold all objects required by linear gaussian statespace models

    Notation for the linear statespace model is taken from [1], while the specific implementation is adapted from
    the statsmodels implementation: https://github.com/statsmodels/statsmodels/blob/main/statsmodels/tsa/statespace/representation.py
    described in [2].

    Parameters
    ----------
    k_endog: int
        Number of observed states (called "endogenous states" in statsmodels)
    k_states: int
        Number of hidden states
    k_posdef: int
        Number of states that have exogenous shocks; also the rank of the selection matrix R. If None, it
        defaults to ``k_states``.
    design: ArrayLike, optional
        Design matrix, denoted 'Z' in [1].
    obs_intercept: ArrayLike, optional
        Constant vector in the observation equation, denoted 'd' in [1].
    obs_cov: ArrayLike, optional
        Covariance matrix for multivariate-normal errors in the observation equation. Denoted 'H' in
        [1].
    transition: ArrayLike, optional
        Transition equation that updates the hidden state between time-steps. Denoted 'T' in [1].
    state_intercept: ArrayLike, optional
        Constant vector in the state transition equation, denoted 'c' in [1].
    selection: ArrayLike, optional
        Selection matrix that matches shocks to hidden states, denoted 'R' in [1]. This is typically the
        identity matrix when k_posdef = k_states.
    state_cov: ArrayLike, optional
        Covariance matrix of the state innovations, denoted 'Q' in [1].
    initial_state: ArrayLike, optional
        Experimental setting to allow for Bayesian estimation of the initial state, denoted `alpha_0` in [1].
        It should potentially be removed in favor of the closed-form diffuse initialization.
    initial_state_cov: ArrayLike, optional
        Experimental setting to allow for Bayesian estimation of the initial state, denoted `P_0` in [1].
        It should potentially be removed in favor of the closed-form diffuse initialization.

    Any matrix that is not provided is initialized to an array of zeros of the appropriate shape.

    Notes
    -----
    A linear statespace system is defined by two equations:

    .. math::
        \begin{align}
            x_t &= T_t x_{t-1} + c_t + R_t \varepsilon_t \tag{1} \\
            y_t &= Z_t x_t + d_t + \eta_t \tag{2} \\
        \end{align}

    Where :math:`\{x_t\}_{t=0}^T` is a trajectory of hidden states, and :math:`\{y_t\}_{t=0}^T` is a trajectory of
    observable states. Equation 1 is known as the "state transition equation", which describes how the system evolves
    over time. Equation 2 is the "observation equation", and maps the latent state processes to observed data.
    The system is Gaussian when the innovations, :math:`\varepsilon_t`, and the measurement errors, :math:`\eta_t`,
    are normally distributed. The definition is completed by specification of these distributions, as
    well as an initial state distribution:

    .. math::
        \begin{align}
            \varepsilon_t &\sim N(0, Q_t) \tag{3} \\
            \eta_t &\sim N(0, H_t) \tag{4} \\
            x_0 &\sim N(\bar{x}_0, P_0) \tag{5}
        \end{align}

    The 9 matrices that form equations 1 to 5 are summarized in the table below. We call :math:`N` the number of
    observations, :math:`m` the number of hidden states, :math:`p` the number of observed states, and :math:`r` the
    number of innovations.

    +-----------------------------------+-------------------+-----------------------+
    | Name                              | Symbol            | Shape                 |
    +===================================+===================+=======================+
    | Initial hidden state mean         | :math:`x_0`       | :math:`m \times 1`    |
    +-----------------------------------+-------------------+-----------------------+
    | Initial hidden state covariance   | :math:`P_0`       | :math:`m \times m`    |
    +-----------------------------------+-------------------+-----------------------+
    | Hidden state vector intercept     | :math:`c_t`       | :math:`m \times 1`    |
    +-----------------------------------+-------------------+-----------------------+
    | Observed state vector intercept   | :math:`d_t`       | :math:`p \times 1`    |
    +-----------------------------------+-------------------+-----------------------+
    | Transition matrix                 | :math:`T_t`       | :math:`m \times m`    |
    +-----------------------------------+-------------------+-----------------------+
    | Design matrix                     | :math:`Z_t`       | :math:`p \times m`    |
    +-----------------------------------+-------------------+-----------------------+
    | Selection matrix                  | :math:`R_t`       | :math:`m \times r`    |
    +-----------------------------------+-------------------+-----------------------+
    | Observation noise covariance      | :math:`H_t`       | :math:`p \times p`    |
    +-----------------------------------+-------------------+-----------------------+
    | Hidden state innovation covariance| :math:`Q_t`       | :math:`r \times r`    |
    +-----------------------------------+-------------------+-----------------------+

    The shapes listed above are the core shapes, but in the general case all of these matrices (except for :math:`x_0`
    and :math:`P_0`) can be time varying. In this case, a time dimension of shape :math:`n`, equal to the number of
    observations, can be added.

    .. warning:: The time dimension is used as a batch dimension during kalman filtering, and must thus **always**
                 be the **leftmost** dimension.

    The purpose of this class is to store these matrices, as well as to allow users to easily index into them. Matrices
    are stored as pytensor ``TensorVariables`` of known shape. Shapes are always accessible via the ``.type.shape``
    attribute, whose core (non-time) dimensions should never be ``None``. Matrices can be accessed via normal numpy
    array slicing after first indexing by the name of the desired array. The time dimension is stored on the far left,
    and is automatically sliced away unless specifically requested by the user. When a tuple key omits the time
    dimension, index 0 of the time dimension is used. See the examples for details.

    Examples
    --------
    .. code:: python

        from pymc_extras.statespace.core.representation import PytensorRepresentation
        ssm = PytensorRepresentation(k_endog=1, k_states=3, k_posdef=1)

        # Access matrices by their names
        print(ssm['transition'].type.shape)
        >>> (3, 3)

        # Slice a matrix
        print(ssm['obs_cov', 0, 0].eval())
        >>> 0.0

        # Set elements in a slice of a matrix
        ssm['design', 0, 0] = 1
        print(ssm['design'].eval())
        >>> np.array([[1., 0., 0.]])

        # Setting an entire matrix is also permitted. If you set a time dimension, it must be the first dimension, and
        # the "core" dimensions must agree with those set when the ssm object was instantiated.
        ssm['obs_intercept'] = np.arange(10).reshape(10, 1)  # 10 timesteps
        print(ssm['obs_intercept'].eval())
        >>> np.array([[0.], [1.], [2.], [3.], [4.], [5.], [6.], [7.], [8.], [9.]])

    References
    ----------
    .. [1] Durbin, James, and Siem Jan Koopman. 2012.
        Time Series Analysis by State Space Methods: Second Edition.
        Oxford University Press.
    .. [2] Fulton, Chad. "Estimating time series models by state space methods in Python: Statsmodels." (2015).
       http://www.chadfulton.com/files/fulton_statsmodels_2017_v1.pdf
    """

    __slots__ = (
        "k_endog",
        "k_states",
        "k_posdef",
        "shapes",
        "design",
        "obs_intercept",
        "obs_cov",
        "transition",
        "state_intercept",
        "selection",
        "state_cov",
        "initial_state",
        "initial_state_cov",
    )

    def __init__(
        self,
        k_endog: int,
        k_states: int,
        k_posdef: int,
        design: np.ndarray | None = None,
        obs_intercept: np.ndarray | None = None,
        obs_cov=None,
        transition=None,
        state_intercept=None,
        selection=None,
        state_cov=None,
        initial_state=None,
        initial_state_cov=None,
    ) -> None:
        """Store the system dimensions and build (or validate) every statespace matrix."""
        self.k_states = k_states
        self.k_endog = k_endog
        self.k_posdef = k_posdef if k_posdef is not None else k_states

        # Default shapes. Matrices that may be time varying carry a leading dummy time dimension of length 1;
        # it is replaced by the real time dimension when a time-varying matrix is provided.
        self.shapes = {
            "design": (1, self.k_endog, self.k_states),
            "obs_intercept": (1, self.k_endog),
            "obs_cov": (1, self.k_endog, self.k_endog),
            "transition": (1, self.k_states, self.k_states),
            "state_intercept": (1, self.k_states),
            "selection": (1, self.k_states, self.k_posdef),
            "state_cov": (1, self.k_posdef, self.k_posdef),
            # These are never time varying, so they don't have a dummy first dimension
            "initial_state": (self.k_states,),
            "initial_state_cov": (self.k_states, self.k_states),
        }

        provided = {
            "design": design,
            "obs_intercept": obs_intercept,
            "obs_cov": obs_cov,
            "transition": transition,
            "state_intercept": state_intercept,
            "selection": selection,
            "state_cov": state_cov,
            "initial_state": initial_state,
            "initial_state_cov": initial_state_cov,
        }

        # Iterate over a snapshot because self.shapes may be updated inside the loop.
        for name, default_shape in list(self.shapes.items()):
            value = provided[name]
            if value is None:
                matrix = pt.as_tensor_variable(
                    np.zeros(default_shape, dtype=floatX), name=name, ndim=len(default_shape)
                )
            else:
                if isinstance(value, np.ndarray):
                    matrix = self._numpy_to_pytensor(name, value)
                else:
                    matrix = self._check_provided_tensor(name, value)

                # Record the real leading time dimension of time-varying inputs (it stays 1 for static inputs).
                # Sliced assignments call specify_shape with self.shapes, so it must agree with the stored matrix.
                if name not in NEVER_TIME_VARYING:
                    self.shapes[name] = (matrix.type.shape[0], *default_shape[1:])

            setattr(self, name, matrix)

    def _validate_key(self, key: KeyLike) -> None:
        """Raise IndexError if ``key`` is not the name of one of the statespace matrices."""
        if key not in self.shapes:
            raise IndexError(f"{key} is an invalid state space matrix name")

    def _update_shape(self, key: KeyLike, value: np.ndarray | pt.Variable) -> None:
        """
        Check the core dimensions of ``value`` against the stored shape of ``key``, then store ``value``'s shape.

        A dummy time dimension is added to the stored shape if ``value`` lacks one and ``key`` can be time varying.

        Raises
        ------
        ValueError
            If the trailing (core) dimensions of ``value`` differ from the stored ones, including when ``value``
            has fewer dimensions than the core shape.
        """
        if isinstance(value, pt.TensorConstant | pt.TensorVariable):
            shape = value.type.shape
        else:
            shape = value.shape
        shape = tuple(shape)

        old_shape = self.shapes[key]
        ndim_core = 1 if key in VECTOR_VALUED else 2
        found_core = shape[-ndim_core:]
        expected_core = tuple(old_shape[-ndim_core:])

        # Compare whole tuples (not a truncating zip) so that a rank mismatch is also caught.
        if found_core != expected_core:
            which = "dimension" if ndim_core == 1 else "two dimensions"
            raise ValueError(f"The last {which} of {key} must be {expected_core}, found {found_core}")

        # Add time dimension dummy if none present
        if key not in NEVER_TIME_VARYING and len(shape) == ndim_core:
            shape = (1, *shape)

        self.shapes[key] = shape

    def _add_time_dim_to_slice(
        self, name: str, slice_: list[int] | tuple[int], n_dim: int
    ) -> tuple[int | slice, ...]:
        """
        Convert user indices for matrix ``name`` into a tuple index over the stored (possibly time-batched) matrix.

        If the user omitted the time dimension, index 0 of the time dimension is prepended and any missing trailing
        dimensions are filled with full slices.

        Returns
        -------
        tuple
            Always a tuple. A list index would be interpreted by pytensor as advanced (fancy) indexing.

        Raises
        ------
        IndexError
            If more indices than ``n_dim`` are given for a matrix that can be time varying.
        """
        slice_ = tuple(slice_)

        # Case 1: There is never a time dim. No changes needed.
        if name in NEVER_TIME_VARYING:
            return slice_

        # Case 2: The matrix has a time dim, and it was requested. No changes needed.
        if len(slice_) == n_dim:
            return slice_

        # Case 3: The time dim was not requested. Select the first time step and fill omitted trailing dims.
        if len(slice_) < n_dim:
            n_omitted = n_dim - len(slice_) - 1
            return (0, *slice_) + (slice(None, None, None),) * n_omitted

        raise IndexError(
            f"Too many indices for {name}: got {len(slice_)}, but it has only {n_dim} dimensions."
        )

    @staticmethod
    def _validate_key_and_get_type(key: KeyLike) -> type[str]:
        """Return ``type(key)``; raise IndexError if a tuple key does not start with a matrix name."""
        if isinstance(key, tuple) and (len(key) == 0 or not isinstance(key[0], str)):
            raise IndexError("First index must the name of a valid state space matrix.")

        return type(key)

    def _validate_matrix_shape(self, name: str, X: np.ndarray | pt.TensorVariable) -> None:
        """
        Check that a candidate value for matrix ``name`` has an acceptable number of dimensions.

        Matrices that are never time varying must have exactly their core number of dimensions. Other matrices may
        have their core number of dimensions (static) or one extra leading time dimension (time varying). In the
        time-varying case the trailing (core) dimensions must equal the expected core shape exactly.

        Raises
        ------
        ValueError
            If the number of dimensions is wrong, or if the core shape of a time-varying value is wrong.
        """
        is_vector = name in VECTOR_VALUED
        ndim_core = 1 if is_vector else 2
        expected_shape = tuple(self.shapes[name][-ndim_core:])
        # Use the static shape for pytensor inputs: ``X.shape`` is symbolic there, and comparing a symbolic
        # variable to a tuple with ``!=`` is always True.
        shape = tuple(X.shape) if isinstance(X, np.ndarray) else tuple(X.type.shape)

        if name in NEVER_TIME_VARYING:
            if X.ndim != ndim_core:
                raise ValueError(
                    f"Array provided for {name} has {X.ndim} dimensions, but it must have exactly {ndim_core}."
                )
            return

        if X.ndim not in (ndim_core, ndim_core + 1):
            raise ValueError(
                f"Array provided for {name} has {X.ndim} dimensions, "
                f"expecting {ndim_core} (static) or {ndim_core + 1} (time-varying)"
            )

        # Time-varying case: only the core (non-time) dimensions can be checked here.
        # TODO: validate the length of the time dimension; the data length is not known to this class.
        if X.ndim == ndim_core + 1 and shape[1:] != expected_shape:
            if is_vector:
                raise ValueError(
                    f"Last dimension of array provided for {name} has shape {shape[1]}, "
                    f"expected {expected_shape}"
                )
            raise ValueError(
                f"Last two dimensions of array provided for {name} have shapes {shape[1:]}, "
                f"expected {expected_shape}"
            )

    def _check_provided_tensor(self, name: str, X: pt.TensorVariable) -> pt.TensorVariable:
        """
        Validate a pytensor value for matrix ``name`` and add the dummy time dimension if the value is static.

        Static values of matrices that can be time varying are expanded to ``(1, *core_shape)`` and pinned with
        ``specify_shape``. Time-varying values, and matrices that are never time varying, are returned unchanged.
        """
        self._validate_matrix_shape(name, X)
        if name in NEVER_TIME_VARYING:
            return X

        ndim_core = 1 if name in VECTOR_VALUED else 2
        if X.ndim == ndim_core:
            # Build the target shape from the core dims only: self.shapes[name] may currently hold the time
            # dimension of a previously assigned time-varying matrix.
            X = pt.expand_dims(X, (0,))
            X = pt.specify_shape(X, (1, *self.shapes[name][-ndim_core:]))

        return X

    def _numpy_to_pytensor(self, name: str, X: np.ndarray) -> pt.TensorVariable:
        """Validate a numpy value for ``name``, add the dummy time dim if static, and wrap it as a floatX tensor."""
        X = X.copy()
        self._validate_matrix_shape(name, X)

        # Add a time dimension if one isn't provided
        if name not in NEVER_TIME_VARYING:
            ndim_core = 1 if name in VECTOR_VALUED else 2
            if X.ndim == ndim_core:
                X = X[None, ...]

        return pt.as_tensor(X, name=name, dtype=floatX)

    def _drop_dummy_time_dim(self, name: str, matrix: pt.TensorVariable) -> pt.TensorVariable:
        """Return time step 0 of ``matrix`` with its core shape pinned and its name set to ``name``."""
        X = matrix[(0,) + (slice(None),) * (matrix.ndim - 1)]
        X = pt.specify_shape(X, self.shapes[name][1:])
        X.name = name
        return X

    def __getitem__(self, key: KeyLike) -> pt.TensorVariable:
        """
        Return a whole matrix (``ssm[name]``) or a slice of one (``ssm[name, *indices]``).

        A bare name returns the matrix without its dummy time dimension. If the matrix is time varying, the full
        matrix is returned. With a tuple key, a full slice over every dimension (including time) returns the stored
        matrix unchanged. Otherwise an omitted time dimension means time step 0.
        """
        _type = self._validate_key_and_get_type(key)

        # Case 1: user asked for an entire matrix by name
        if _type is str:
            self._validate_key(key)
            matrix = getattr(self, key)

            # Slice away the time dimension if it's a dummy; otherwise (never time varying, or genuinely
            # time varying) return everything.
            if key not in NEVER_TIME_VARYING and matrix.type.shape[0] == 1:
                return self._drop_dummy_time_dim(key, matrix)
            return matrix

        # Case 2: user asked for a particular matrix and some slices of it
        if _type is tuple:
            name, *indices = key
            slice_ = tuple(indices)
            self._validate_key(name)
            matrix = getattr(self, name)
            full = slice(None, None, None)

            # Case 2a: The user asked for the whole matrix, with time dummies. Return it without slicing anything.
            if slice_ == (full,) * matrix.ndim:
                return matrix

            # Case 2b: The user asked for the whole matrix except the time dimension. Only meaningful for matrices
            # that carry a time dimension; for the others, fall through to ordinary numpy-style slicing.
            if name not in NEVER_TIME_VARYING and slice_ == (full,) * (matrix.ndim - 1):
                return self._drop_dummy_time_dim(name, matrix)

            # Case 2c: User asked for an arbitrary sub-matrix.
            return matrix[self._add_time_dim_to_slice(name, slice_, matrix.ndim)]

        # Case 3: There is only one slice index, but it's not a string
        raise IndexError("First index must the name of a valid state space matrix.")

    def __setitem__(self, key: KeyLike, value: float | int | np.ndarray | pt.Variable) -> None:
        """
        Replace a whole matrix (``ssm[name] = value``) or a slice of one (``ssm[name, *indices] = value``).

        Whole-matrix values are validated and given a dummy time dimension if static, exactly as in ``__init__``.
        Plain Python scalars, lists, and tuples are converted with numpy first. Sliced assignments follow the
        indexing rules of ``__getitem__``: an omitted time dimension means time step 0.

        Raises
        ------
        ValueError
            If the value's shape does not match the matrix's core shape.
        IndexError
            If the matrix name is unknown, too many indices are given, or the key is neither a string nor a tuple.
        """
        _type = type(key)

        # Case 1: key is a string: we are setting an entire matrix.
        if _type is str:
            self._validate_key(key)
            if isinstance(value, np.ndarray | list | tuple | int | float):
                value = self._numpy_to_pytensor(key, np.asarray(value))
            else:
                # Same treatment as constructor inputs, so the stored matrix always carries the time dimension
                # that self.shapes describes.
                value = self._check_provided_tensor(key, value)
                value.name = key

            # Validate (and record) the shape before storing, so a failed check leaves the object unchanged.
            self._update_shape(key, value)
            setattr(self, key, value)

        # Case 2: key is a string plus a slice: we are setting a subset of a matrix
        elif _type is tuple:
            name, *indices = key
            self._validate_key(name)

            matrix = getattr(self, name)
            slice_ = self._add_time_dim_to_slice(name, indices, matrix.ndim)
            matrix = pt.set_subtensor(matrix[slice_], value)
            matrix = pt.specify_shape(matrix, self.shapes[name])
            matrix.name = name

            setattr(self, name, matrix)

        else:
            raise IndexError("First index must the name of a valid state space matrix.")

    def copy(self):
        """Return a shallow copy that shares the (immutable) tensors but owns an independent ``shapes`` dict."""
        new = copy.copy(self)
        # The shapes dict is mutated by _update_shape; without this, changes to the copy would leak into self.
        new.shapes = dict(self.shapes)
        return new