pymc-extras 0.2.4__py3-none-any.whl → 0.2.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. pymc_extras/__init__.py +6 -4
  2. pymc_extras/distributions/__init__.py +2 -0
  3. pymc_extras/distributions/continuous.py +3 -2
  4. pymc_extras/distributions/discrete.py +3 -1
  5. pymc_extras/distributions/transforms/__init__.py +3 -0
  6. pymc_extras/distributions/transforms/partial_order.py +227 -0
  7. pymc_extras/inference/__init__.py +4 -2
  8. pymc_extras/inference/find_map.py +62 -17
  9. pymc_extras/inference/fit.py +6 -4
  10. pymc_extras/inference/laplace.py +14 -8
  11. pymc_extras/inference/pathfinder/lbfgs.py +49 -13
  12. pymc_extras/inference/pathfinder/pathfinder.py +89 -103
  13. pymc_extras/statespace/core/statespace.py +191 -52
  14. pymc_extras/statespace/filters/distributions.py +15 -16
  15. pymc_extras/statespace/filters/kalman_filter.py +1 -18
  16. pymc_extras/statespace/filters/kalman_smoother.py +2 -6
  17. pymc_extras/statespace/models/ETS.py +10 -0
  18. pymc_extras/statespace/models/SARIMAX.py +26 -5
  19. pymc_extras/statespace/models/VARMAX.py +12 -2
  20. pymc_extras/statespace/models/structural.py +18 -5
  21. pymc_extras/statespace/utils/data_tools.py +24 -9
  22. pymc_extras-0.2.6.dist-info/METADATA +318 -0
  23. pymc_extras-0.2.6.dist-info/RECORD +65 -0
  24. {pymc_extras-0.2.4.dist-info → pymc_extras-0.2.6.dist-info}/WHEEL +1 -2
  25. pymc_extras/version.py +0 -11
  26. pymc_extras/version.txt +0 -1
  27. pymc_extras-0.2.4.dist-info/METADATA +0 -110
  28. pymc_extras-0.2.4.dist-info/RECORD +0 -105
  29. pymc_extras-0.2.4.dist-info/top_level.txt +0 -2
  30. tests/__init__.py +0 -13
  31. tests/distributions/__init__.py +0 -19
  32. tests/distributions/test_continuous.py +0 -185
  33. tests/distributions/test_discrete.py +0 -210
  34. tests/distributions/test_discrete_markov_chain.py +0 -258
  35. tests/distributions/test_multivariate.py +0 -304
  36. tests/model/__init__.py +0 -0
  37. tests/model/marginal/__init__.py +0 -0
  38. tests/model/marginal/test_distributions.py +0 -132
  39. tests/model/marginal/test_graph_analysis.py +0 -182
  40. tests/model/marginal/test_marginal_model.py +0 -967
  41. tests/model/test_model_api.py +0 -38
  42. tests/statespace/__init__.py +0 -0
  43. tests/statespace/test_ETS.py +0 -411
  44. tests/statespace/test_SARIMAX.py +0 -405
  45. tests/statespace/test_VARMAX.py +0 -184
  46. tests/statespace/test_coord_assignment.py +0 -116
  47. tests/statespace/test_distributions.py +0 -270
  48. tests/statespace/test_kalman_filter.py +0 -326
  49. tests/statespace/test_representation.py +0 -175
  50. tests/statespace/test_statespace.py +0 -872
  51. tests/statespace/test_statespace_JAX.py +0 -156
  52. tests/statespace/test_structural.py +0 -836
  53. tests/statespace/utilities/__init__.py +0 -0
  54. tests/statespace/utilities/shared_fixtures.py +0 -9
  55. tests/statespace/utilities/statsmodel_local_level.py +0 -42
  56. tests/statespace/utilities/test_helpers.py +0 -310
  57. tests/test_blackjax_smc.py +0 -222
  58. tests/test_find_map.py +0 -103
  59. tests/test_histogram_approximation.py +0 -109
  60. tests/test_laplace.py +0 -265
  61. tests/test_linearmodel.py +0 -208
  62. tests/test_model_builder.py +0 -306
  63. tests/test_pathfinder.py +0 -203
  64. tests/test_pivoted_cholesky.py +0 -24
  65. tests/test_printing.py +0 -98
  66. tests/test_prior_from_trace.py +0 -172
  67. tests/test_splines.py +0 -77
  68. tests/utils.py +0 -0
  69. {pymc_extras-0.2.4.dist-info → pymc_extras-0.2.6.dist-info/licenses}/LICENSE +0 -0
@@ -69,9 +69,9 @@ class _LinearGaussianStateSpace(Continuous):
69
69
  H,
70
70
  Q,
71
71
  steps=None,
72
- mode=None,
73
72
  sequence_names=None,
74
73
  append_x0=True,
74
+ method="svd",
75
75
  **kwargs,
76
76
  ):
77
77
  # Ignore dims in support shape because they are just passed along to the "observed" and "latent" distributions
@@ -97,9 +97,9 @@ class _LinearGaussianStateSpace(Continuous):
97
97
  H,
98
98
  Q,
99
99
  steps=steps,
100
- mode=mode,
101
100
  sequence_names=sequence_names,
102
101
  append_x0=append_x0,
102
+ method=method,
103
103
  **kwargs,
104
104
  )
105
105
 
@@ -116,9 +116,9 @@ class _LinearGaussianStateSpace(Continuous):
116
116
  H,
117
117
  Q,
118
118
  steps=None,
119
- mode=None,
120
119
  sequence_names=None,
121
120
  append_x0=True,
121
+ method="svd",
122
122
  **kwargs,
123
123
  ):
124
124
  steps = get_support_shape_1d(
@@ -132,9 +132,9 @@ class _LinearGaussianStateSpace(Continuous):
132
132
 
133
133
  return super().dist(
134
134
  [a0, P0, c, d, T, Z, R, H, Q, steps],
135
- mode=mode,
136
135
  sequence_names=sequence_names,
137
136
  append_x0=append_x0,
137
+ method=method,
138
138
  **kwargs,
139
139
  )
140
140
 
@@ -152,9 +152,9 @@ class _LinearGaussianStateSpace(Continuous):
152
152
  Q,
153
153
  steps,
154
154
  size=None,
155
- mode=None,
156
155
  sequence_names=None,
157
156
  append_x0=True,
157
+ method="svd",
158
158
  ):
159
159
  if sequence_names is None:
160
160
  sequence_names = []
@@ -205,10 +205,10 @@ class _LinearGaussianStateSpace(Continuous):
205
205
  a = state[:k]
206
206
 
207
207
  middle_rng, a_innovation = pm.MvNormal.dist(
208
- mu=0, cov=Q, rng=rng, method="svd"
208
+ mu=0, cov=Q, rng=rng, method=method
209
209
  ).owner.outputs
210
210
  next_rng, y_innovation = pm.MvNormal.dist(
211
- mu=0, cov=H, rng=middle_rng, method="svd"
211
+ mu=0, cov=H, rng=middle_rng, method=method
212
212
  ).owner.outputs
213
213
 
214
214
  a_mu = c + T @ a
@@ -224,8 +224,8 @@ class _LinearGaussianStateSpace(Continuous):
224
224
  Z_init = Z_ if Z_ in non_sequences else Z_[0]
225
225
  H_init = H_ if H_ in non_sequences else H_[0]
226
226
 
227
- init_x_ = pm.MvNormal.dist(a0_, P0_, rng=rng, method="svd")
228
- init_y_ = pm.MvNormal.dist(Z_init @ init_x_, H_init, rng=rng, method="svd")
227
+ init_x_ = pm.MvNormal.dist(a0_, P0_, rng=rng, method=method)
228
+ init_y_ = pm.MvNormal.dist(Z_init @ init_x_, H_init, rng=rng, method=method)
229
229
 
230
230
  init_dist_ = pt.concatenate([init_x_, init_y_], axis=0)
231
231
 
@@ -235,7 +235,6 @@ class _LinearGaussianStateSpace(Continuous):
235
235
  sequences=None if len(sequences) == 0 else sequences,
236
236
  non_sequences=[*non_sequences, rng],
237
237
  n_steps=steps,
238
- mode=mode,
239
238
  strict=True,
240
239
  )
241
240
 
@@ -279,8 +278,8 @@ class LinearGaussianStateSpace(Continuous):
279
278
  steps,
280
279
  k_endog=None,
281
280
  sequence_names=None,
282
- mode=None,
283
281
  append_x0=True,
282
+ method="svd",
284
283
  **kwargs,
285
284
  ):
286
285
  dims = kwargs.pop("dims", None)
@@ -307,9 +306,9 @@ class LinearGaussianStateSpace(Continuous):
307
306
  H,
308
307
  Q,
309
308
  steps=steps,
310
- mode=mode,
311
309
  sequence_names=sequence_names,
312
310
  append_x0=append_x0,
311
+ method=method,
313
312
  **kwargs,
314
313
  )
315
314
  latent_obs_combined = pt.specify_shape(latent_obs_combined, (steps + int(append_x0), None))
@@ -368,11 +367,11 @@ class SequenceMvNormal(Continuous):
368
367
  return super().__new__(cls, *args, **kwargs)
369
368
 
370
369
  @classmethod
371
- def dist(cls, mus, covs, logp, **kwargs):
372
- return super().dist([mus, covs, logp], **kwargs)
370
+ def dist(cls, mus, covs, logp, method="svd", **kwargs):
371
+ return super().dist([mus, covs, logp], method=method, **kwargs)
373
372
 
374
373
  @classmethod
375
- def rv_op(cls, mus, covs, logp, size=None):
374
+ def rv_op(cls, mus, covs, logp, method="svd", size=None):
376
375
  # Batch dimensions (if any) will be on the far left, but scan requires time to be there instead
377
376
  if mus.ndim > 2:
378
377
  mus = pt.moveaxis(mus, -2, 0)
@@ -385,7 +384,7 @@ class SequenceMvNormal(Continuous):
385
384
  rng = pytensor.shared(np.random.default_rng())
386
385
 
387
386
  def step(mu, cov, rng):
388
- new_rng, mvn = pm.MvNormal.dist(mu=mu, cov=cov, rng=rng, method="svd").owner.outputs
387
+ new_rng, mvn = pm.MvNormal.dist(mu=mu, cov=cov, rng=rng, method=method).owner.outputs
389
388
  return mvn, {rng: new_rng}
390
389
 
391
390
  mvn_seq, updates = pytensor.scan(
@@ -5,7 +5,6 @@ import pytensor
5
5
  import pytensor.tensor as pt
6
6
 
7
7
  from pymc.pytensorf import constant_fold
8
- from pytensor.compile.mode import get_mode
9
8
  from pytensor.graph.basic import Variable
10
9
  from pytensor.raise_op import Assert
11
10
  from pytensor.tensor import TensorVariable
@@ -28,15 +27,10 @@ assert_time_varying_dim_correct = Assert(
28
27
 
29
28
 
30
29
  class BaseFilter(ABC):
31
- def __init__(self, mode=None):
30
+ def __init__(self):
32
31
  """
33
32
  Kalman Filter.
34
33
 
35
- Parameters
36
- ----------
37
- mode : str, optional
38
- The mode used for Pytensor compilation. Defaults to None.
39
-
40
34
  Notes
41
35
  -----
42
36
  The BaseFilter class is an abstract base class (ABC) for implementing kalman filters.
@@ -44,9 +38,6 @@ class BaseFilter(ABC):
44
38
 
45
39
  Attributes
46
40
  ----------
47
- mode : str or None
48
- The mode used for Pytensor compilation.
49
-
50
41
  seq_names : list[str]
51
42
  A list of name representing time-varying statespace matrices. That is, inputs that will need to be
52
43
  provided to the `sequences` argument of `pytensor.scan`
@@ -56,7 +47,6 @@ class BaseFilter(ABC):
56
47
  to the `non_sequences` argument of `pytensor.scan`
57
48
  """
58
49
 
59
- self.mode: str = mode
60
50
  self.seq_names: list[str] = []
61
51
  self.non_seq_names: list[str] = []
62
52
 
@@ -153,7 +143,6 @@ class BaseFilter(ABC):
153
143
  R,
154
144
  H,
155
145
  Q,
156
- mode=None,
157
146
  return_updates=False,
158
147
  missing_fill_value=None,
159
148
  cov_jitter=None,
@@ -166,9 +155,6 @@ class BaseFilter(ABC):
166
155
  data : TensorVariable
167
156
  Data to be filtered
168
157
 
169
- mode : optional, str
170
- Pytensor compile mode, passed to pytensor.scan
171
-
172
158
  return_updates: bool, default False
173
159
  Whether to return updates associated with the pytensor scan. Should only be requried to debug pruposes.
174
160
 
@@ -199,7 +185,6 @@ class BaseFilter(ABC):
199
185
  if cov_jitter is None:
200
186
  cov_jitter = JITTER_DEFAULT
201
187
 
202
- self.mode = mode
203
188
  self.missing_fill_value = missing_fill_value
204
189
  self.cov_jitter = cov_jitter
205
190
 
@@ -227,7 +212,6 @@ class BaseFilter(ABC):
227
212
  outputs_info=[None, a0, None, None, P0, None, None],
228
213
  non_sequences=non_sequences,
229
214
  name="forward_kalman_pass",
230
- mode=get_mode(self.mode),
231
215
  strict=False,
232
216
  )
233
217
 
@@ -800,7 +784,6 @@ class UnivariateFilter(BaseFilter):
800
784
  self._univariate_inner_filter_step,
801
785
  sequences=[y_masked, Z_masked, d, pt.diag(H_masked), nan_mask],
802
786
  outputs_info=[a, P, None, None, None],
803
- mode=get_mode(self.mode),
804
787
  name="univariate_inner_scan",
805
788
  )
806
789
 
@@ -1,7 +1,6 @@
1
1
  import pytensor
2
2
  import pytensor.tensor as pt
3
3
 
4
- from pytensor.compile import get_mode
5
4
  from pytensor.tensor.nlinalg import matrix_dot
6
5
 
7
6
  from pymc_extras.statespace.filters.utilities import (
@@ -18,8 +17,7 @@ class KalmanSmoother:
18
17
 
19
18
  """
20
19
 
21
- def __init__(self, mode: str | None = None):
22
- self.mode = mode
20
+ def __init__(self):
23
21
  self.cov_jitter = JITTER_DEFAULT
24
22
  self.seq_names = []
25
23
  self.non_seq_names = []
@@ -64,9 +62,8 @@ class KalmanSmoother:
64
62
  return a, P, a_smooth, P_smooth, T, R, Q
65
63
 
66
64
  def build_graph(
67
- self, T, R, Q, filtered_states, filtered_covariances, mode=None, cov_jitter=JITTER_DEFAULT
65
+ self, T, R, Q, filtered_states, filtered_covariances, cov_jitter=JITTER_DEFAULT
68
66
  ):
69
- self.mode = mode
70
67
  self.cov_jitter = cov_jitter
71
68
 
72
69
  n, k = filtered_states.type.shape
@@ -88,7 +85,6 @@ class KalmanSmoother:
88
85
  non_sequences=non_sequences,
89
86
  go_backwards=True,
90
87
  name="kalman_smoother",
91
- mode=get_mode(self.mode),
92
88
  )
93
89
 
94
90
  smoothed_states, smoothed_covariances = smoother_result
@@ -5,6 +5,7 @@ import numpy as np
5
5
  import pytensor.tensor as pt
6
6
 
7
7
  from pytensor import graph_replace
8
+ from pytensor.compile.mode import Mode
8
9
  from pytensor.tensor.slinalg import solve_discrete_lyapunov
9
10
 
10
11
  from pymc_extras.statespace.core.statespace import PyMCStateSpace, floatX
@@ -35,6 +36,7 @@ class BayesianETS(PyMCStateSpace):
35
36
  initialization_dampening: float = 0.8,
36
37
  filter_type: str = "standard",
37
38
  verbose: bool = True,
39
+ mode: str | Mode | None = None,
38
40
  ):
39
41
  r"""
40
42
  Exponential Smoothing State Space Model
@@ -212,6 +214,13 @@ class BayesianETS(PyMCStateSpace):
212
214
  and "cholesky". See the docs for kalman filters for more details.
213
215
  verbose: bool, default True
214
216
  If true, a message will be logged to the terminal explaining the variable names, dimensions, and supports.
217
+ mode: str or Mode, optional
218
+ Pytensor compile mode, used in auxiliary sampling methods such as ``sample_conditional_posterior`` and
219
+ ``forecast``. The mode does **not** effect calls to ``pm.sample``.
220
+
221
+ Regardless of whether a mode is specified, it can always be overwritten via the ``compile_kwargs`` argument
222
+ to all sampling methods.
223
+
215
224
 
216
225
  References
217
226
  ----------
@@ -284,6 +293,7 @@ class BayesianETS(PyMCStateSpace):
284
293
  filter_type,
285
294
  verbose=verbose,
286
295
  measurement_error=measurement_error,
296
+ mode=mode,
287
297
  )
288
298
 
289
299
  @property
@@ -4,6 +4,7 @@ from typing import Any
4
4
  import numpy as np
5
5
  import pytensor.tensor as pt
6
6
 
7
+ from pytensor.compile.mode import Mode
7
8
  from pytensor.tensor.slinalg import solve_discrete_lyapunov
8
9
 
9
10
  from pymc_extras.statespace.core.statespace import PyMCStateSpace, floatX
@@ -91,6 +92,13 @@ class BayesianSARIMA(PyMCStateSpace):
91
92
  verbose: bool, default True
92
93
  If true, a message will be logged to the terminal explaining the variable names, dimensions, and supports.
93
94
 
95
+ mode: str or Mode, optional
96
+ Pytensor compile mode, used in auxiliary sampling methods such as ``sample_conditional_posterior`` and
97
+ ``forecast``. The mode does **not** effect calls to ``pm.sample``.
98
+
99
+ Regardless of whether a mode is specified, it can always be overwritten via the ``compile_kwargs`` argument
100
+ to all sampling methods.
101
+
94
102
  Notes
95
103
  -----
96
104
  The ARIMAX model is a univariate time series model that posits the future evolution of a stationary time series will
@@ -158,7 +166,7 @@ class BayesianSARIMA(PyMCStateSpace):
158
166
  rho = pm.Beta("ar_params", alpha=5, beta=1, dims=ss_mod.param_dims["ar_params"])
159
167
  theta = pm.Normal("ma_params", mu=0.0, sigma=0.5, dims=ss_mod.param_dims["ma_params"])
160
168
 
161
- ss_mod.build_statespace_graph(df, mode="JAX")
169
+ ss_mod.build_statespace_graph(df)
162
170
  idata = pm.sample(nuts_sampler='numpyro')
163
171
 
164
172
  References
@@ -180,7 +188,21 @@ class BayesianSARIMA(PyMCStateSpace):
180
188
  state_structure: str = "fast",
181
189
  measurement_error: bool = False,
182
190
  verbose=True,
191
+ mode: str | Mode | None = None,
183
192
  ):
193
+ """
194
+
195
+ Parameters
196
+ ----------
197
+ order
198
+ seasonal_order
199
+ stationary_initialization
200
+ filter_type
201
+ state_structure
202
+ measurement_error
203
+ verbose
204
+ mode
205
+ """
184
206
  # Model order
185
207
  self.p, self.d, self.q = order
186
208
  if seasonal_order is None:
@@ -228,6 +250,7 @@ class BayesianSARIMA(PyMCStateSpace):
228
250
  filter_type,
229
251
  verbose=verbose,
230
252
  measurement_error=measurement_error,
253
+ mode=mode,
231
254
  )
232
255
 
233
256
  @property
@@ -366,7 +389,7 @@ class BayesianSARIMA(PyMCStateSpace):
366
389
 
367
390
  return coords
368
391
 
369
- def _stationary_initialization(self, mode=None):
392
+ def _stationary_initialization(self):
370
393
  # Solve for matrix quadratic for P0
371
394
  T = self.ssm["transition"]
372
395
  R = self.ssm["selection"]
@@ -374,9 +397,7 @@ class BayesianSARIMA(PyMCStateSpace):
374
397
  c = self.ssm["state_intercept"]
375
398
 
376
399
  x0 = pt.linalg.solve(pt.identity_like(T) - T, c, assume_a="gen", check_finite=True)
377
-
378
- method = "direct" if (self.k_states < 5) or (mode == "JAX") else "bilinear"
379
- P0 = solve_discrete_lyapunov(T, pt.linalg.matrix_dot(R, Q, R.T), method=method)
400
+ P0 = solve_discrete_lyapunov(T, pt.linalg.matrix_dot(R, Q, R.T), method="bilinear")
380
401
 
381
402
  return x0, P0
382
403
 
@@ -5,6 +5,7 @@ import numpy as np
5
5
  import pytensor
6
6
  import pytensor.tensor as pt
7
7
 
8
+ from pytensor.compile.mode import Mode
8
9
  from pytensor.tensor.slinalg import solve_discrete_lyapunov
9
10
 
10
11
  from pymc_extras.statespace.core.statespace import PyMCStateSpace
@@ -72,6 +73,13 @@ class BayesianVARMAX(PyMCStateSpace):
72
73
  verbose: bool, default True
73
74
  If true, a message will be logged to the terminal explaining the variable names, dimensions, and supports.
74
75
 
76
+ mode: str or Mode, optional
77
+ Pytensor compile mode, used in auxiliary sampling methods such as ``sample_conditional_posterior`` and
78
+ ``forecast``. The mode does **not** effect calls to ``pm.sample``.
79
+
80
+ Regardless of whether a mode is specified, it can always be overwritten via the ``compile_kwargs`` argument
81
+ to all sampling methods.
82
+
75
83
  Notes
76
84
  -----
77
85
  The VARMA model is a multivariate extension of the SARIMAX model. Given a set of timeseries :math:`\{x_t\}_{t=0}^T`,
@@ -135,7 +143,7 @@ class BayesianVARMAX(PyMCStateSpace):
135
143
  ar_params = pm.Normal("ar_params", mu=0, sigma=1, dims=ar_dims)
136
144
  state_cov = pm.Deterministic("state_cov", state_chol @ state_chol.T, dims=state_cov_dims)
137
145
 
138
- bvar_mod.build_statespace_graph(data, mode="JAX")
146
+ bvar_mod.build_statespace_graph(data)
139
147
  idata = pm.sample(nuts_sampler="numpyro")
140
148
  """
141
149
 
@@ -147,7 +155,8 @@ class BayesianVARMAX(PyMCStateSpace):
147
155
  stationary_initialization: bool = False,
148
156
  filter_type: str = "standard",
149
157
  measurement_error: bool = False,
150
- verbose=True,
158
+ verbose: bool = True,
159
+ mode: str | Mode | None = None,
151
160
  ):
152
161
  if (endog_names is None) and (k_endog is None):
153
162
  raise ValueError("Must specify either endog_names or k_endog")
@@ -174,6 +183,7 @@ class BayesianVARMAX(PyMCStateSpace):
174
183
  filter_type,
175
184
  verbose=verbose,
176
185
  measurement_error=measurement_error,
186
+ mode=mode,
177
187
  )
178
188
 
179
189
  # Save counts of the number of parameters in each category
@@ -12,6 +12,7 @@ import pytensor.tensor as pt
12
12
  import xarray as xr
13
13
 
14
14
  from pytensor import Variable
15
+ from pytensor.compile.mode import Mode
15
16
 
16
17
  from pymc_extras.statespace.core import PytensorRepresentation
17
18
  from pymc_extras.statespace.core.statespace import PyMCStateSpace
@@ -81,6 +82,7 @@ class StructuralTimeSeries(PyMCStateSpace):
81
82
  name: str | None = None,
82
83
  verbose: bool = True,
83
84
  filter_type: str = "standard",
85
+ mode: str | Mode | None = None,
84
86
  ):
85
87
  # Add the initial state covariance to the parameters
86
88
  if name is None:
@@ -112,6 +114,7 @@ class StructuralTimeSeries(PyMCStateSpace):
112
114
  filter_type=filter_type,
113
115
  verbose=verbose,
114
116
  measurement_error=measurement_error,
117
+ mode=mode,
115
118
  )
116
119
  self.ssm = ssm.copy()
117
120
 
@@ -644,7 +647,9 @@ class Component(ABC):
644
647
 
645
648
  return new_comp
646
649
 
647
- def build(self, name=None, filter_type="standard", verbose=True):
650
+ def build(
651
+ self, name=None, filter_type="standard", verbose=True, mode: str | Mode | None = None
652
+ ):
648
653
  """
649
654
  Build a StructuralTimeSeries statespace model from the current component(s)
650
655
 
@@ -660,6 +665,13 @@ class Component(ABC):
660
665
  verbose : bool, optional
661
666
  If True, displays information about the initialized model. Defaults to True.
662
667
 
668
+ mode: str or Mode, optional
669
+ Pytensor compile mode, used in auxiliary sampling methods such as ``sample_conditional_posterior`` and
670
+ ``forecast``. The mode does **not** effect calls to ``pm.sample``.
671
+
672
+ Regardless of whether a mode is specified, it can always be overwritten via the ``compile_kwargs`` argument
673
+ to all sampling methods.
674
+
663
675
  Returns
664
676
  -------
665
677
  PyMCStateSpace
@@ -685,6 +697,7 @@ class Component(ABC):
685
697
  name_to_data=self._name_to_data,
686
698
  filter_type=filter_type,
687
699
  verbose=verbose,
700
+ mode=mode,
688
701
  )
689
702
 
690
703
 
@@ -908,7 +921,7 @@ class MeasurementError(Component):
908
921
  intitial_trend = pm.Normal('initial_trend', sigma=10, dims=ss_mod.param_dims['initial_trend'])
909
922
  sigma_obs = pm.Exponential('sigma_obs', 1, dims=ss_mod.param_dims['sigma_obs'])
910
923
 
911
- ss_mod.build_statespace_graph(data, mode='JAX')
924
+ ss_mod.build_statespace_graph(data)
912
925
  idata = pm.sample(nuts_sampler='numpyro')
913
926
  """
914
927
 
@@ -991,7 +1004,7 @@ class AutoregressiveComponent(Component):
991
1004
  ar_params = pm.Normal('ar_params', dims=ss_mod.param_dims['ar_params'])
992
1005
  sigma_ar = pm.Exponential('sigma_ar', 1, dims=ss_mod.param_dims['sigma_ar'])
993
1006
 
994
- ss_mod.build_statespace_graph(data, mode='JAX')
1007
+ ss_mod.build_statespace_graph(data)
995
1008
  idata = pm.sample(nuts_sampler='numpyro')
996
1009
 
997
1010
  """
@@ -1153,7 +1166,7 @@ class TimeSeasonality(Component):
1153
1166
  intitial_trend = pm.Deterministic('initial_trend', pt.zeros(1), dims=ss_mod.param_dims['initial_trend'])
1154
1167
  annual_coefs = pm.Normal('annual_coefs', sigma=1e-2, dims=ss_mod.param_dims['annual_coefs'])
1155
1168
  trend_sigmas = pm.HalfNormal('trend_sigmas', sigma=1e-6, dims=ss_mod.param_dims['trend_sigmas'])
1156
- ss_mod.build_statespace_graph(data, mode='JAX')
1169
+ ss_mod.build_statespace_graph(data)
1157
1170
  idata = pm.sample(nuts_sampler='numpyro')
1158
1171
 
1159
1172
  References
@@ -1451,7 +1464,7 @@ class CycleComponent(Component):
1451
1464
  cycle_length = pm.Uniform('business_cycle_length', lower=6, upper=12)
1452
1465
 
1453
1466
  sigma_cycle = pm.HalfNormal('sigma_business_cycle', sigma=1)
1454
- ss_mod.build_statespace_graph(data, mode='JAX')
1467
+ ss_mod.build_statespace_graph(data)
1455
1468
 
1456
1469
  idata = pm.sample(nuts_sampler='numpyro')
1457
1470
 
@@ -87,12 +87,7 @@ def preprocess_pandas_data(data, n_obs, obs_coords=None, check_column_names=Fals
87
87
  col_names = data.columns
88
88
  _validate_data_shape(data.shape, n_obs, obs_coords, check_column_names, col_names)
89
89
 
90
- if isinstance(data.index, pd.RangeIndex):
91
- if obs_coords is not None:
92
- warnings.warn(NO_TIME_INDEX_WARNING)
93
- return preprocess_numpy_data(data.values, n_obs, obs_coords)
94
-
95
- elif isinstance(data.index, pd.DatetimeIndex):
90
+ if isinstance(data.index, pd.DatetimeIndex):
96
91
  if data.index.freq is None:
97
92
  warnings.warn(NO_FREQ_INFO_WARNING)
98
93
  data.index.freq = data.index.inferred_freq
@@ -100,10 +95,30 @@ def preprocess_pandas_data(data, n_obs, obs_coords=None, check_column_names=Fals
100
95
  index = data.index
101
96
  return data.values, index
102
97
 
98
+ elif isinstance(data.index, pd.RangeIndex):
99
+ if obs_coords is not None:
100
+ warnings.warn(NO_TIME_INDEX_WARNING)
101
+ return preprocess_numpy_data(data.values, n_obs, obs_coords)
102
+
103
+ elif isinstance(data.index, pd.MultiIndex):
104
+ if obs_coords is not None:
105
+ warnings.warn(NO_TIME_INDEX_WARNING)
106
+
107
+ raise NotImplementedError("MultiIndex panel data is not currently supported.")
108
+
103
109
  else:
104
- raise IndexError(
105
- f"Expected pd.DatetimeIndex or pd.RangeIndex on data, found {type(data.index)}"
106
- )
110
+ if obs_coords is not None:
111
+ warnings.warn(NO_TIME_INDEX_WARNING)
112
+
113
+ index = data.index
114
+ if not np.issubdtype(index.dtype, np.integer):
115
+ raise IndexError("Provided index is not an integer index.")
116
+
117
+ index_diff = index.to_series().diff().dropna().values
118
+ if not (index_diff == 1).all():
119
+ raise IndexError("Provided index is not monotonic increasing.")
120
+
121
+ return preprocess_numpy_data(data.values, n_obs, obs_coords)
107
122
 
108
123
 
109
124
  def add_data_to_active_model(values, index, data_dims=None):