pymc-extras 0.2.7__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33) hide show
  1. pymc_extras/inference/__init__.py +2 -2
  2. pymc_extras/inference/fit.py +1 -1
  3. pymc_extras/inference/laplace_approx/__init__.py +0 -0
  4. pymc_extras/inference/laplace_approx/find_map.py +354 -0
  5. pymc_extras/inference/laplace_approx/idata.py +393 -0
  6. pymc_extras/inference/laplace_approx/laplace.py +453 -0
  7. pymc_extras/inference/laplace_approx/scipy_interface.py +242 -0
  8. pymc_extras/inference/pathfinder/pathfinder.py +3 -4
  9. pymc_extras/linearmodel.py +3 -1
  10. pymc_extras/model/marginal/graph_analysis.py +4 -0
  11. pymc_extras/prior.py +38 -6
  12. pymc_extras/statespace/core/statespace.py +78 -52
  13. pymc_extras/statespace/filters/kalman_smoother.py +1 -1
  14. pymc_extras/statespace/models/structural/__init__.py +21 -0
  15. pymc_extras/statespace/models/structural/components/__init__.py +0 -0
  16. pymc_extras/statespace/models/structural/components/autoregressive.py +188 -0
  17. pymc_extras/statespace/models/structural/components/cycle.py +305 -0
  18. pymc_extras/statespace/models/structural/components/level_trend.py +257 -0
  19. pymc_extras/statespace/models/structural/components/measurement_error.py +137 -0
  20. pymc_extras/statespace/models/structural/components/regression.py +228 -0
  21. pymc_extras/statespace/models/structural/components/seasonality.py +445 -0
  22. pymc_extras/statespace/models/structural/core.py +900 -0
  23. pymc_extras/statespace/models/structural/utils.py +16 -0
  24. pymc_extras/statespace/models/utilities.py +285 -0
  25. pymc_extras/statespace/utils/constants.py +4 -4
  26. pymc_extras/statespace/utils/data_tools.py +3 -2
  27. {pymc_extras-0.2.7.dist-info → pymc_extras-0.4.0.dist-info}/METADATA +6 -6
  28. {pymc_extras-0.2.7.dist-info → pymc_extras-0.4.0.dist-info}/RECORD +30 -18
  29. pymc_extras/inference/find_map.py +0 -496
  30. pymc_extras/inference/laplace.py +0 -583
  31. pymc_extras/statespace/models/structural.py +0 -1679
  32. {pymc_extras-0.2.7.dist-info → pymc_extras-0.4.0.dist-info}/WHEEL +0 -0
  33. {pymc_extras-0.2.7.dist-info → pymc_extras-0.4.0.dist-info}/licenses/LICENSE +0 -0
@@ -38,16 +38,15 @@ from pymc.blocking import DictToArrayBijection, RaveledVars
38
38
  from pymc.initial_point import make_initial_point_fn
39
39
  from pymc.model import modelcontext
40
40
  from pymc.model.core import Point
41
+ from pymc.progress_bar import CustomProgress, default_progress_theme
41
42
  from pymc.pytensorf import (
42
43
  compile,
43
44
  find_rng_nodes,
44
45
  reseed_rngs,
45
46
  )
46
47
  from pymc.util import (
47
- CustomProgress,
48
48
  RandomSeed,
49
49
  _get_seeds_per_chain,
50
- default_progress_theme,
51
50
  get_default_varnames,
52
51
  )
53
52
  from pytensor.compile.function.types import Function
@@ -63,7 +62,7 @@ from rich.text import Text
63
62
  # TODO: change to typing.Self after Python versions greater than 3.10
64
63
  from typing_extensions import Self
65
64
 
66
- from pymc_extras.inference.laplace import add_data_to_inferencedata
65
+ from pymc_extras.inference.laplace_approx.idata import add_data_to_inference_data
67
66
  from pymc_extras.inference.pathfinder.importance_sampling import (
68
67
  importance_sampling as _importance_sampling,
69
68
  )
@@ -1759,6 +1758,6 @@ def fit_pathfinder(
1759
1758
  importance_sampling=importance_sampling,
1760
1759
  )
1761
1760
 
1762
- idata = add_data_to_inferencedata(idata, progressbar, model, compile_kwargs)
1761
+ idata = add_data_to_inference_data(idata, progressbar, model, compile_kwargs)
1763
1762
 
1764
1763
  return idata
@@ -2,10 +2,12 @@ import numpy as np
2
2
  import pandas as pd
3
3
  import pymc as pm
4
4
 
5
+ from sklearn.base import BaseEstimator
6
+
5
7
  from pymc_extras.model_builder import ModelBuilder
6
8
 
7
9
 
8
- class LinearModel(ModelBuilder):
10
+ class LinearModel(ModelBuilder, BaseEstimator):
9
11
  def __init__(
10
12
  self, model_config: dict | None = None, sampler_config: dict | None = None, nsamples=100
11
13
  ):
@@ -5,6 +5,7 @@ from itertools import zip_longest
5
5
 
6
6
  from pymc import SymbolicRandomVariable
7
7
  from pymc.model.fgraph import ModelVar
8
+ from pymc.variational.minibatch_rv import MinibatchRandomVariable
8
9
  from pytensor.graph import Variable, ancestors
9
10
  from pytensor.graph.basic import io_toposort
10
11
  from pytensor.tensor import TensorType, TensorVariable
@@ -313,6 +314,9 @@ def _subgraph_batch_dim_connection(var_dims: VAR_DIMS, input_vars, output_vars)
313
314
 
314
315
  var_dims[node.outputs[0]] = output_dims
315
316
 
317
+ elif isinstance(node.op, MinibatchRandomVariable):
318
+ var_dims[node.outputs[0]] = inputs_dims[0]
319
+
316
320
  else:
317
321
  raise NotImplementedError(f"Marginalization through operation {node} not supported.")
318
322
 
pymc_extras/prior.py CHANGED
@@ -84,6 +84,7 @@ from __future__ import annotations
84
84
  import copy
85
85
 
86
86
  from collections.abc import Callable
87
+ from functools import partial
87
88
  from inspect import signature
88
89
  from typing import Any, Protocol, runtime_checkable
89
90
 
@@ -278,7 +279,7 @@ class VariableFactory(Protocol):
278
279
  def sample_prior(
279
280
  factory: VariableFactory,
280
281
  coords=None,
281
- name: str = "var",
282
+ name: str = "variable",
282
283
  wrap: bool = False,
283
284
  **sample_prior_predictive_kwargs,
284
285
  ) -> xr.Dataset:
@@ -292,7 +293,7 @@ def sample_prior(
292
293
  The coordinates for the variable, by default None.
293
294
  Only required if the dims are specified.
294
295
  name : str, optional
295
- The name of the variable, by default "var".
296
+ The name of the variable, by default "variable".
296
297
  wrap : bool, optional
297
298
  Whether to wrap the variable in a `pm.Deterministic` node, by default False.
298
299
  sample_prior_predictive_kwargs : dict
@@ -362,7 +363,7 @@ class Prior:
362
363
 
363
364
  - `preliz` attribute to get the equivalent distribution in `preliz`
364
365
  - `sample_prior` method to sample from the prior
365
- - `graph` get a dummy model graph with the distribution
366
+ - `to_graph` get a dummy model graph with the distribution
366
367
  - `constrain` to shift the distribution to a different range
367
368
 
368
369
  Parameters
@@ -900,7 +901,7 @@ class Prior:
900
901
  def sample_prior(
901
902
  self,
902
903
  coords=None,
903
- name: str = "var",
904
+ name: str = "variable",
904
905
  **sample_prior_predictive_kwargs,
905
906
  ) -> xr.Dataset:
906
907
  """Sample the prior distribution for the variable.
@@ -911,7 +912,7 @@ class Prior:
911
912
  The coordinates for the variable, by default None.
912
913
  Only required if the dims are specified.
913
914
  name : str, optional
914
- The name of the variable, by default "var".
915
+ The name of the variable, by default "variable".
915
916
  sample_prior_predictive_kwargs : dict
916
917
  Additional arguments to pass to `pm.sample_prior_predictive`.
917
918
 
@@ -1175,7 +1176,7 @@ class Censored:
1175
1176
  """Create a censored distribution from a dictionary."""
1176
1177
  data = data["data"]
1177
1178
  return cls( # type: ignore
1178
- distribution=Prior.from_dict(data["dist"]),
1179
+ distribution=deserialize(data["dist"]),
1179
1180
  lower=data["lower"],
1180
1181
  upper=data["upper"],
1181
1182
  )
@@ -1354,3 +1355,34 @@ def _is_censored_type(data: dict) -> bool:
1354
1355
 
1355
1356
  register_deserialization(is_type=_is_prior_type, deserialize=Prior.from_dict)
1356
1357
  register_deserialization(is_type=_is_censored_type, deserialize=Censored.from_dict)
1358
+
1359
+
1360
+ def __getattr__(name: str):
1361
+ """Get Prior class through the module.
1362
+
1363
+ Examples
1364
+ --------
1365
+ Create a normal distribution.
1366
+
1367
+ .. code-block:: python
1368
+
1369
+ from pymc_extras.prior import Normal
1370
+
1371
+ dist = Normal(mu=1, sigma=2)
1372
+
1373
+ Create a hierarchical normal distribution.
1374
+
1375
+ .. code-block:: python
1376
+
1377
+ import pymc_extras.prior as pr
1378
+
1379
+ dist = pr.Normal(mu=pr.Normal(), sigma=pr.HalfNormal(), dims="channel")
1380
+ samples = dist.sample_prior(coords={"channel": ["C1", "C2", "C3"]})
1381
+
1382
+ """
1383
+ # Protect against doctest
1384
+ if name == "__wrapped__":
1385
+ return
1386
+
1387
+ _get_pymc_distribution(name)
1388
+ return partial(Prior, distribution=name)
@@ -2047,6 +2047,69 @@ class PyMCStateSpace:
2047
2047
 
2048
2048
  return scenario
2049
2049
 
2050
+ def _build_forecast_model(
2051
+ self, time_index, t0, forecast_index, scenario, filter_output, mvn_method
2052
+ ):
2053
+ filter_time_dim = TIME_DIM
2054
+ temp_coords = self._fit_coords.copy()
2055
+
2056
+ dims = None
2057
+ if all([dim in temp_coords for dim in [filter_time_dim, ALL_STATE_DIM, OBS_STATE_DIM]]):
2058
+ dims = [TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM]
2059
+
2060
+ t0_idx = np.flatnonzero(time_index == t0)[0]
2061
+
2062
+ temp_coords["data_time"] = time_index
2063
+ temp_coords[TIME_DIM] = forecast_index
2064
+
2065
+ mu_dims, cov_dims = None, None
2066
+ if all([dim in self._fit_coords for dim in [TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM]]):
2067
+ mu_dims = ["data_time", ALL_STATE_DIM]
2068
+ cov_dims = ["data_time", ALL_STATE_DIM, ALL_STATE_AUX_DIM]
2069
+
2070
+ with pm.Model(coords=temp_coords) as forecast_model:
2071
+ (_, _, *matrices), grouped_outputs = self._kalman_filter_outputs_from_dummy_graph(
2072
+ data_dims=["data_time", OBS_STATE_DIM],
2073
+ )
2074
+
2075
+ group_idx = FILTER_OUTPUT_TYPES.index(filter_output)
2076
+ mu, cov = grouped_outputs[group_idx]
2077
+
2078
+ sub_dict = {
2079
+ data_var: pt.as_tensor_variable(data_var.get_value(), name="data")
2080
+ for data_var in forecast_model.data_vars
2081
+ }
2082
+
2083
+ missing_data_vars = np.setdiff1d(
2084
+ ar1=[*self.data_names, "data"], ar2=[k.name for k, _ in sub_dict.items()]
2085
+ )
2086
+ if missing_data_vars.size > 0:
2087
+ raise ValueError(f"{missing_data_vars} data used for fitting not found!")
2088
+
2089
+ mu_frozen, cov_frozen = graph_replace([mu, cov], replace=sub_dict, strict=True)
2090
+
2091
+ x0 = pm.Deterministic(
2092
+ "x0_slice", mu_frozen[t0_idx], dims=mu_dims[1:] if mu_dims is not None else None
2093
+ )
2094
+ P0 = pm.Deterministic(
2095
+ "P0_slice", cov_frozen[t0_idx], dims=cov_dims[1:] if cov_dims is not None else None
2096
+ )
2097
+
2098
+ _ = LinearGaussianStateSpace(
2099
+ "forecast",
2100
+ x0,
2101
+ P0,
2102
+ *matrices,
2103
+ steps=len(forecast_index),
2104
+ dims=dims,
2105
+ sequence_names=self.kalman_filter.seq_names,
2106
+ k_endog=self.k_endog,
2107
+ append_x0=False,
2108
+ method=mvn_method,
2109
+ )
2110
+
2111
+ return forecast_model
2112
+
2050
2113
  def forecast(
2051
2114
  self,
2052
2115
  idata: InferenceData,
@@ -2139,8 +2202,6 @@ class PyMCStateSpace:
2139
2202
  the latent state trajectories: `y[t] = Z @ x[t] + nu[t]`, where `nu ~ N(0, H)`.
2140
2203
 
2141
2204
  """
2142
- filter_time_dim = TIME_DIM
2143
-
2144
2205
  _validate_filter_arg(filter_output)
2145
2206
 
2146
2207
  compile_kwargs = kwargs.pop("compile_kwargs", {})
@@ -2185,58 +2246,23 @@ class PyMCStateSpace:
2185
2246
  use_scenario_index=use_scenario_index,
2186
2247
  )
2187
2248
  scenario = self._finalize_scenario_initialization(scenario, forecast_index)
2188
- temp_coords = self._fit_coords.copy()
2189
-
2190
- dims = None
2191
- if all([dim in temp_coords for dim in [filter_time_dim, ALL_STATE_DIM, OBS_STATE_DIM]]):
2192
- dims = [TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM]
2193
-
2194
- t0_idx = np.flatnonzero(time_index == t0)[0]
2195
-
2196
- temp_coords["data_time"] = time_index
2197
- temp_coords[TIME_DIM] = forecast_index
2198
-
2199
- mu_dims, cov_dims = None, None
2200
- if all([dim in self._fit_coords for dim in [TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM]]):
2201
- mu_dims = ["data_time", ALL_STATE_DIM]
2202
- cov_dims = ["data_time", ALL_STATE_DIM, ALL_STATE_AUX_DIM]
2203
-
2204
- with pm.Model(coords=temp_coords) as forecast_model:
2205
- (_, _, *matrices), grouped_outputs = self._kalman_filter_outputs_from_dummy_graph(
2206
- scenario=scenario,
2207
- data_dims=["data_time", OBS_STATE_DIM],
2208
- )
2209
-
2210
- for name in self.data_names:
2211
- if name in scenario.keys():
2212
- pm.set_data(
2213
- {"data": np.zeros((len(forecast_index), self.k_endog))},
2214
- coords={"data_time": np.arange(len(forecast_index))},
2215
- )
2216
- break
2217
2249
 
2218
- group_idx = FILTER_OUTPUT_TYPES.index(filter_output)
2219
- mu, cov = grouped_outputs[group_idx]
2220
-
2221
- x0 = pm.Deterministic(
2222
- "x0_slice", mu[t0_idx], dims=mu_dims[1:] if mu_dims is not None else None
2223
- )
2224
- P0 = pm.Deterministic(
2225
- "P0_slice", cov[t0_idx], dims=cov_dims[1:] if cov_dims is not None else None
2226
- )
2250
+ forecast_model = self._build_forecast_model(
2251
+ time_index=time_index,
2252
+ t0=t0,
2253
+ forecast_index=forecast_index,
2254
+ scenario=scenario,
2255
+ filter_output=filter_output,
2256
+ mvn_method=mvn_method,
2257
+ )
2227
2258
 
2228
- _ = LinearGaussianStateSpace(
2229
- "forecast",
2230
- x0,
2231
- P0,
2232
- *matrices,
2233
- steps=len(forecast_index),
2234
- dims=dims,
2235
- sequence_names=self.kalman_filter.seq_names,
2236
- k_endog=self.k_endog,
2237
- append_x0=False,
2238
- method=mvn_method,
2239
- )
2259
+ with forecast_model:
2260
+ if scenario is not None:
2261
+ dummy_obs_data = np.zeros((len(forecast_index), self.k_endog))
2262
+ pm.set_data(
2263
+ scenario | {"data": dummy_obs_data},
2264
+ coords={"data_time": np.arange(len(forecast_index))},
2265
+ )
2240
2266
 
2241
2267
  forecast_model.rvs_to_initial_values = {
2242
2268
  k: None for k in forecast_model.rvs_to_initial_values.keys()
@@ -105,7 +105,7 @@ class KalmanSmoother:
105
105
  a_hat, P_hat = self.predict(a, P, T, R, Q)
106
106
 
107
107
  # Use pinv, otherwise P_hat is singular when there is missing data
108
- smoother_gain = matrix_dot(pt.linalg.pinv(P_hat), T, P).T
108
+ smoother_gain = matrix_dot(pt.linalg.pinv(P_hat, hermitian=True), T, P).T
109
109
  a_smooth_next = a + smoother_gain @ (a_smooth - a_hat)
110
110
 
111
111
  P_smooth_next = P + quad_form_sym(smoother_gain, P_smooth - P_hat)
@@ -0,0 +1,21 @@
1
+ from pymc_extras.statespace.models.structural.components.autoregressive import (
2
+ AutoregressiveComponent,
3
+ )
4
+ from pymc_extras.statespace.models.structural.components.cycle import CycleComponent
5
+ from pymc_extras.statespace.models.structural.components.level_trend import LevelTrendComponent
6
+ from pymc_extras.statespace.models.structural.components.measurement_error import MeasurementError
7
+ from pymc_extras.statespace.models.structural.components.regression import RegressionComponent
8
+ from pymc_extras.statespace.models.structural.components.seasonality import (
9
+ FrequencySeasonality,
10
+ TimeSeasonality,
11
+ )
12
+
13
+ __all__ = [
14
+ "LevelTrendComponent",
15
+ "MeasurementError",
16
+ "AutoregressiveComponent",
17
+ "TimeSeasonality",
18
+ "FrequencySeasonality",
19
+ "RegressionComponent",
20
+ "CycleComponent",
21
+ ]
@@ -0,0 +1,188 @@
1
+ import numpy as np
2
+ import pytensor.tensor as pt
3
+
4
+ from pymc_extras.statespace.models.structural.core import Component
5
+ from pymc_extras.statespace.models.structural.utils import order_to_mask
6
+ from pymc_extras.statespace.utils.constants import AR_PARAM_DIM
7
+
8
+
9
+ class AutoregressiveComponent(Component):
10
+ r"""
11
+ Autoregressive timeseries component
12
+
13
+ Parameters
14
+ ----------
15
+ order: int or sequence of int
16
+
17
+ If int, the number of lags to include in the model.
18
+ If a sequence, an array-like of zeros and ones indicating which lags to include in the model.
19
+
20
+ name: str, default "auto_regressive"
21
+ A name for this autoregressive component. Used to label dimensions and coordinates.
22
+
23
+ observed_state_names: list[str] | None, default None
24
+ List of strings for observed state labels. If None, defaults to ["data"].
25
+
26
+ Notes
27
+ -----
28
+ An autoregressive component can be thought of as a way of introducing serially correlated errors into the model.
29
+ The process is modeled:
30
+
31
+ .. math::
32
+ x_t = \sum_{i=1}^p \rho_i x_{t-i}
33
+
34
+ Where ``p``, the number of autoregressive terms to model, is the order of the process. By default, all lags up to
35
+ ``p`` are included in the model. To disable lags, pass a list of zeros and ones to the ``order`` argument. For
36
+ example, ``order=[1, 1, 0, 1]`` would become:
37
+
38
+ .. math::
39
+ x_t = \rho_1 x_{t-1} + \rho_2 x_{t-2} + \rho_4 x_{t-4}
40
+
41
+ The coefficient :math:`\rho_3` has been constrained to zero.
42
+
43
+ .. warning:: This class is meant to be used as a component in a structural time series model. For modeling of
44
+ stationary processes with ARIMA, use ``statespace.BayesianSARIMA``.
45
+
46
+ Examples
47
+ --------
48
+ Model a timeseries as an AR(2) process with non-zero mean:
49
+
50
+ .. code:: python
51
+
52
+ from pymc_extras.statespace import structural as st
53
+ import pymc as pm
54
+ import pytensor.tensor as pt
55
+
56
+ trend = st.LevelTrendComponent(order=1, innovations_order=0)
57
+ ar = st.AutoregressiveComponent(2)
58
+ ss_mod = (trend + ar).build()
59
+
60
+ with pm.Model(coords=ss_mod.coords) as model:
61
+ P0 = pm.Deterministic('P0', pt.eye(ss_mod.k_states) * 10, dims=ss_mod.param_dims['P0'])
62
+ initial_trend = pm.Normal('initial_trend', sigma=10, dims=ss_mod.param_dims['initial_trend'])
63
+ ar_params = pm.Normal('ar_params', dims=ss_mod.param_dims['ar_params'])
64
+ sigma_ar = pm.Exponential('sigma_ar', 1, dims=ss_mod.param_dims['sigma_ar'])
65
+
66
+ ss_mod.build_statespace_graph(data)
67
+ idata = pm.sample(nuts_sampler='numpyro')
68
+
69
+ """
70
+
71
+ def __init__(
72
+ self,
73
+ order: int = 1,
74
+ name: str = "auto_regressive",
75
+ observed_state_names: list[str] | None = None,
76
+ ):
77
+ if observed_state_names is None:
78
+ observed_state_names = ["data"]
79
+
80
+ k_posdef = k_endog = len(observed_state_names)
81
+
82
+ order = order_to_mask(order)
83
+ ar_lags = np.flatnonzero(order).ravel().astype(int) + 1
84
+ k_states = len(order)
85
+
86
+ self.order = order
87
+ self.ar_lags = ar_lags
88
+
89
+ super().__init__(
90
+ name=name,
91
+ k_endog=k_endog,
92
+ k_states=k_states * k_endog,
93
+ k_posdef=k_posdef,
94
+ measurement_error=True,
95
+ combine_hidden_states=True,
96
+ observed_state_names=observed_state_names,
97
+ obs_state_idxs=np.tile(np.r_[[1.0], np.zeros(k_states - 1)], k_endog),
98
+ )
99
+
100
+ def populate_component_properties(self):
101
+ k_states = self.k_states // self.k_endog # this is also the number of AR lags
102
+
103
+ self.state_names = [
104
+ f"L{i + 1}[{state_name}]"
105
+ for state_name in self.observed_state_names
106
+ for i in range(k_states)
107
+ ]
108
+
109
+ self.shock_names = [f"{self.name}[{obs_name}]" for obs_name in self.observed_state_names]
110
+ self.param_names = [f"params_{self.name}", f"sigma_{self.name}"]
111
+ self.param_dims = {f"params_{self.name}": (f"lag_{self.name}",)}
112
+ self.coords = {f"lag_{self.name}": self.ar_lags.tolist()}
113
+
114
+ if self.k_endog > 1:
115
+ self.param_dims[f"params_{self.name}"] = (
116
+ f"endog_{self.name}",
117
+ AR_PARAM_DIM,
118
+ )
119
+ self.param_dims[f"sigma_{self.name}"] = (f"endog_{self.name}",)
120
+
121
+ self.coords[f"endog_{self.name}"] = self.observed_state_names
122
+
123
+ self.param_info = {
124
+ f"params_{self.name}": {
125
+ "shape": (k_states,) if self.k_endog == 1 else (self.k_endog, k_states),
126
+ "constraints": None,
127
+ "dims": (AR_PARAM_DIM,)
128
+ if self.k_endog == 1
129
+ else (
130
+ f"endog_{self.name}",
131
+ f"lag_{self.name}",
132
+ ),
133
+ },
134
+ f"sigma_{self.name}": {
135
+ "shape": () if self.k_endog == 1 else (self.k_endog,),
136
+ "constraints": "Positive",
137
+ "dims": None if self.k_endog == 1 else (f"endog_{self.name}",),
138
+ },
139
+ }
140
+
141
+ def make_symbolic_graph(self) -> None:
142
+ k_endog = self.k_endog
143
+ k_states = self.k_states // k_endog
144
+ k_posdef = self.k_posdef
145
+
146
+ k_nonzero = int(sum(self.order))
147
+ ar_params = self.make_and_register_variable(
148
+ f"params_{self.name}", shape=(k_nonzero,) if k_endog == 1 else (k_endog, k_nonzero)
149
+ )
150
+ sigma_ar = self.make_and_register_variable(
151
+ f"sigma_{self.name}", shape=() if k_endog == 1 else (k_endog,)
152
+ )
153
+
154
+ if k_endog == 1:
155
+ T = pt.eye(k_states, k=-1)
156
+ ar_idx = (np.zeros(k_nonzero, dtype="int"), np.nonzero(self.order)[0])
157
+ T = T[ar_idx].set(ar_params)
158
+
159
+ else:
160
+ transition_matrices = []
161
+
162
+ for i in range(k_endog):
163
+ T = pt.eye(k_states, k=-1)
164
+ ar_idx = (np.zeros(k_nonzero, dtype="int"), np.nonzero(self.order)[0])
165
+ T = T[ar_idx].set(ar_params[i])
166
+ transition_matrices.append(T)
167
+ T = pt.specify_shape(
168
+ pt.linalg.block_diag(*transition_matrices), (self.k_states, self.k_states)
169
+ )
170
+
171
+ self.ssm["transition", :, :] = T
172
+
173
+ R = np.eye(k_states)
174
+ R_mask = np.full((k_states), False)
175
+ R_mask[0] = True
176
+ R = R[:, R_mask]
177
+
178
+ self.ssm["selection", :, :] = pt.specify_shape(
179
+ pt.linalg.block_diag(*[R for _ in range(k_endog)]), (self.k_states, self.k_posdef)
180
+ )
181
+
182
+ Z = pt.zeros((1, k_states))[0, 0].set(1.0)
183
+ self.ssm["design", :, :] = pt.specify_shape(
184
+ pt.linalg.block_diag(*[Z for _ in range(k_endog)]), (self.k_endog, self.k_states)
185
+ )
186
+
187
+ cov_idx = ("state_cov", *np.diag_indices(k_posdef))
188
+ self.ssm[cov_idx] = sigma_ar**2