pymc-extras 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. pymc_extras/deserialize.py +10 -4
  2. pymc_extras/distributions/continuous.py +1 -1
  3. pymc_extras/distributions/histogram_utils.py +6 -4
  4. pymc_extras/distributions/multivariate/r2d2m2cp.py +4 -3
  5. pymc_extras/distributions/timeseries.py +14 -12
  6. pymc_extras/inference/dadvi/dadvi.py +149 -128
  7. pymc_extras/inference/laplace_approx/find_map.py +16 -39
  8. pymc_extras/inference/laplace_approx/idata.py +22 -4
  9. pymc_extras/inference/laplace_approx/laplace.py +196 -151
  10. pymc_extras/inference/laplace_approx/scipy_interface.py +47 -7
  11. pymc_extras/inference/pathfinder/idata.py +517 -0
  12. pymc_extras/inference/pathfinder/pathfinder.py +71 -12
  13. pymc_extras/inference/smc/sampling.py +2 -2
  14. pymc_extras/model/marginal/distributions.py +4 -2
  15. pymc_extras/model/marginal/graph_analysis.py +2 -2
  16. pymc_extras/model/marginal/marginal_model.py +12 -2
  17. pymc_extras/model_builder.py +9 -4
  18. pymc_extras/prior.py +203 -8
  19. pymc_extras/statespace/core/compile.py +1 -1
  20. pymc_extras/statespace/core/statespace.py +2 -1
  21. pymc_extras/statespace/filters/distributions.py +15 -13
  22. pymc_extras/statespace/filters/kalman_filter.py +24 -22
  23. pymc_extras/statespace/filters/kalman_smoother.py +3 -5
  24. pymc_extras/statespace/filters/utilities.py +2 -5
  25. pymc_extras/statespace/models/DFM.py +12 -27
  26. pymc_extras/statespace/models/ETS.py +190 -198
  27. pymc_extras/statespace/models/SARIMAX.py +5 -17
  28. pymc_extras/statespace/models/VARMAX.py +15 -67
  29. pymc_extras/statespace/models/structural/components/autoregressive.py +4 -4
  30. pymc_extras/statespace/models/structural/components/regression.py +4 -26
  31. pymc_extras/statespace/models/utilities.py +7 -0
  32. pymc_extras/utils/model_equivalence.py +2 -2
  33. pymc_extras/utils/prior.py +10 -14
  34. pymc_extras/utils/spline.py +4 -10
  35. {pymc_extras-0.5.0.dist-info → pymc_extras-0.7.0.dist-info}/METADATA +4 -4
  36. {pymc_extras-0.5.0.dist-info → pymc_extras-0.7.0.dist-info}/RECORD +38 -37
  37. {pymc_extras-0.5.0.dist-info → pymc_extras-0.7.0.dist-info}/WHEEL +1 -1
  38. {pymc_extras-0.5.0.dist-info → pymc_extras-0.7.0.dist-info}/licenses/LICENSE +0 -0
@@ -197,10 +197,9 @@ class _LinearGaussianStateSpace(Continuous):
197
197
  n_seq = len(sequence_names)
198
198
 
199
199
  def step_fn(*args):
200
- seqs, state, non_seqs = args[:n_seq], args[n_seq], args[n_seq + 1 :]
201
- non_seqs, rng = non_seqs[:-1], non_seqs[-1]
200
+ seqs, (rng, state, *non_seqs) = args[:n_seq], args[n_seq:]
202
201
 
203
- c, d, T, Z, R, H, Q = sort_args(seqs + non_seqs)
202
+ c, d, T, Z, R, H, Q = sort_args((*seqs, *non_seqs))
204
203
  k = T.shape[0]
205
204
  a = state[:k]
206
205
 
@@ -219,7 +218,7 @@ class _LinearGaussianStateSpace(Continuous):
219
218
 
220
219
  next_state = pt.concatenate([a_next, y_next], axis=0)
221
220
 
222
- return next_state, {rng: next_rng}
221
+ return next_rng, next_state
223
222
 
224
223
  Z_init = Z_ if Z_ in non_sequences else Z_[0]
225
224
  H_init = H_ if H_ in non_sequences else H_[0]
@@ -229,13 +228,14 @@ class _LinearGaussianStateSpace(Continuous):
229
228
 
230
229
  init_dist_ = pt.concatenate([init_x_, init_y_], axis=0)
231
230
 
232
- statespace, updates = pytensor.scan(
231
+ ss_rng, statespace = pytensor.scan(
233
232
  step_fn,
234
- outputs_info=[init_dist_],
233
+ outputs_info=[rng, init_dist_],
235
234
  sequences=None if len(sequences) == 0 else sequences,
236
- non_sequences=[*non_sequences, rng],
235
+ non_sequences=[*non_sequences],
237
236
  n_steps=steps,
238
237
  strict=True,
238
+ return_updates=False,
239
239
  )
240
240
 
241
241
  if append_x0:
@@ -245,7 +245,6 @@ class _LinearGaussianStateSpace(Continuous):
245
245
  statespace_ = statespace
246
246
  statespace_ = pt.specify_shape(statespace_, (steps, None))
247
247
 
248
- (ss_rng,) = tuple(updates.values())
249
248
  linear_gaussian_ss_op = LinearGaussianStateSpaceRV(
250
249
  inputs=[a0_, P0_, c_, d_, T_, Z_, R_, H_, Q_, steps, rng],
251
250
  outputs=[ss_rng, statespace_],
@@ -385,10 +384,15 @@ class SequenceMvNormal(Continuous):
385
384
 
386
385
  def step(mu, cov, rng):
387
386
  new_rng, mvn = pm.MvNormal.dist(mu=mu, cov=cov, rng=rng, method=method).owner.outputs
388
- return mvn, {rng: new_rng}
387
+ return new_rng, mvn
389
388
 
390
- mvn_seq, updates = pytensor.scan(
391
- step, sequences=[mus_, covs_], non_sequences=[rng], strict=True, n_steps=mus_.shape[0]
389
+ seq_mvn_rng, mvn_seq = pytensor.scan(
390
+ step,
391
+ sequences=[mus_, covs_],
392
+ outputs_info=[rng, None],
393
+ strict=True,
394
+ n_steps=mus_.shape[0],
395
+ return_updates=False,
392
396
  )
393
397
  mvn_seq = pt.specify_shape(mvn_seq, mus.type.shape)
394
398
 
@@ -396,8 +400,6 @@ class SequenceMvNormal(Continuous):
396
400
  if mvn_seq.ndim > 2:
397
401
  mvn_seq = pt.moveaxis(mvn_seq, 0, -2)
398
402
 
399
- (seq_mvn_rng,) = tuple(updates.values())
400
-
401
403
  mvn_seq_op = KalmanFilterRV(
402
404
  inputs=[mus_, covs_, logp_, rng], outputs=[seq_mvn_rng, mvn_seq], ndim_supp=2
403
405
  )
@@ -148,10 +148,9 @@ class BaseFilter(ABC):
148
148
  R,
149
149
  H,
150
150
  Q,
151
- return_updates=False,
152
151
  missing_fill_value=None,
153
152
  cov_jitter=None,
154
- ) -> list[TensorVariable] | tuple[list[TensorVariable], dict]:
153
+ ) -> list[TensorVariable]:
155
154
  """
156
155
  Construct the computation graph for the Kalman filter. See [1] for details.
157
156
 
@@ -200,7 +199,7 @@ class BaseFilter(ABC):
200
199
  self.n_endog = Z_shape[-2]
201
200
 
202
201
  data, a0, P0, *params = self.check_params(data, a0, P0, c, d, T, Z, R, H, Q)
203
-
202
+ data = pt.specify_shape(data, (data.type.shape[0], self.n_endog))
204
203
  sequences, non_sequences, seq_names, non_seq_names = split_vars_into_seq_and_nonseq(
205
204
  params, PARAM_NAMES
206
205
  )
@@ -211,20 +210,17 @@ class BaseFilter(ABC):
211
210
  if len(sequences) > 0:
212
211
  sequences = self.add_check_on_time_varying_shapes(data, sequences)
213
212
 
214
- results, updates = pytensor.scan(
213
+ results = pytensor.scan(
215
214
  self.kalman_step,
216
215
  sequences=[data, *sequences],
217
216
  outputs_info=[None, a0, None, None, P0, None, None],
218
217
  non_sequences=non_sequences,
219
218
  name="forward_kalman_pass",
220
219
  strict=False,
220
+ return_updates=False,
221
221
  )
222
222
 
223
- filter_results = self._postprocess_scan_results(results, a0, P0, n=data.type.shape[0])
224
-
225
- if return_updates:
226
- return filter_results, updates
227
- return filter_results
223
+ return self._postprocess_scan_results(results, a0, P0, n=data.type.shape[0])
228
224
 
229
225
  def _postprocess_scan_results(self, results, a0, P0, n) -> list[TensorVariable]:
230
226
  """
@@ -393,7 +389,7 @@ class BaseFilter(ABC):
393
389
  .. [1] Durbin, J., and S. J. Koopman. Time Series Analysis by State Space Methods.
394
390
  2nd ed, Oxford University Press, 2012.
395
391
  """
396
- a_hat = T.dot(a) + c
392
+ a_hat = T @ a + c
397
393
  P_hat = quad_form_sym(T, P) + quad_form_sym(R, Q)
398
394
 
399
395
  return a_hat, P_hat
@@ -580,16 +576,16 @@ class StandardFilter(BaseFilter):
580
576
  .. [1] Durbin, J., and S. J. Koopman. Time Series Analysis by State Space Methods.
581
577
  2nd ed, Oxford University Press, 2012.
582
578
  """
583
- y_hat = d + Z.dot(a)
579
+ y_hat = d + Z @ a
584
580
  v = y - y_hat
585
581
 
586
- PZT = P.dot(Z.T)
582
+ PZT = P.dot(Z.mT)
587
583
  F = Z.dot(PZT) + stabilize(H, self.cov_jitter)
588
584
 
589
- K = pt.linalg.solve(F.T, PZT.T, assume_a="pos", check_finite=False).T
585
+ K = pt.linalg.solve(F.mT, PZT.mT, assume_a="pos", check_finite=False).mT
590
586
  I_KZ = pt.eye(self.n_states) - K.dot(Z)
591
587
 
592
- a_filtered = a + K.dot(v)
588
+ a_filtered = a + K @ v
593
589
  P_filtered = quad_form_sym(I_KZ, P) + quad_form_sym(K, H)
594
590
 
595
591
  F_inv_v = pt.linalg.solve(F, v, assume_a="pos", check_finite=False)
@@ -630,9 +626,9 @@ class SquareRootFilter(BaseFilter):
630
626
  a_hat = T.dot(a) + c
631
627
  Q_chol = pt.linalg.cholesky(Q, lower=True)
632
628
 
633
- M = pt.horizontal_stack(T @ P_chol, R @ Q_chol).T
629
+ M = pt.horizontal_stack(T @ P_chol, R @ Q_chol).mT
634
630
  R_decomp = pt.linalg.qr(M, mode="r")
635
- P_chol_hat = R_decomp[: self.n_states, : self.n_states].T
631
+ P_chol_hat = R_decomp[..., : self.n_states, : self.n_states].mT
636
632
 
637
633
  return a_hat, P_chol_hat
638
634
 
@@ -652,20 +648,22 @@ class SquareRootFilter(BaseFilter):
652
648
  y_hat = Z.dot(a) + d
653
649
  v = y - y_hat
654
650
 
655
- H_chol = pytensor.ifelse(pt.all(pt.eq(H, 0.0)), H, pt.linalg.cholesky(H, lower=True))
651
+ H_chol = pytensor.ifelse(
652
+ pt.all(pt.eq(H, 0.0)), H, pt.linalg.cholesky(H, lower=True, on_error="nan")
653
+ )
656
654
 
657
655
  # The following notation comes from https://ipnpr.jpl.nasa.gov/progress_report/42-233/42-233A.pdf
658
656
  # Construct upper-triangular block matrix A = [[chol(H), Z @ L_pred],
659
657
  # [0, L_pred]]
660
658
  # The Schur decomposition of this matrix will be B (upper triangular). We are
661
- # more insterested in B^T:
659
+ # more interested in B^T:
662
660
  # Structure of B^T = [[chol(F), 0 ],
663
661
  # [K @ chol(F), chol(P_filtered)]
664
662
  zeros = pt.zeros((self.n_states, self.n_endog))
665
663
  upper = pt.horizontal_stack(H_chol, Z @ P_chol)
666
664
  lower = pt.horizontal_stack(zeros, P_chol)
667
665
  A_T = pt.vertical_stack(upper, lower)
668
- B = pt.linalg.qr(A_T.T, mode="r").T
666
+ B = pt.linalg.qr(A_T.mT, mode="r").mT
669
667
 
670
668
  F_chol = B[: self.n_endog, : self.n_endog]
671
669
  K_F_chol = B[self.n_endog :, : self.n_endog]
@@ -677,6 +675,7 @@ class SquareRootFilter(BaseFilter):
677
675
  inner_term = solve_triangular(
678
676
  F_chol, solve_triangular(F_chol, v, lower=True), lower=True
679
677
  )
678
+
680
679
  loss = (v.T @ inner_term).ravel()
681
680
 
682
681
  # abs necessary because we're not guaranteed a positive diagonal from the schur decomposition
@@ -693,8 +692,10 @@ class SquareRootFilter(BaseFilter):
693
692
  """
694
693
  return [a, P_chol, pt.zeros(())]
695
694
 
695
+ degenerate = pt.eq(all_nan_flag, 1.0)
696
+ F_chol = pytensor.ifelse(degenerate, pt.eye(*F_chol.shape), F_chol)
696
697
  [a_filtered, P_chol_filtered, ll] = pytensor.ifelse(
697
- pt.eq(all_nan_flag, 1.0),
698
+ degenerate,
698
699
  compute_degenerate(P_chol_filtered, F_chol, K_F_chol, v),
699
700
  compute_non_degenerate(P_chol_filtered, F_chol, K_F_chol, v),
700
701
  )
@@ -785,11 +786,12 @@ class UnivariateFilter(BaseFilter):
785
786
  H_masked = W.dot(H)
786
787
  y_masked = pt.set_subtensor(y[nan_mask], 0.0)
787
788
 
788
- result, updates = pytensor.scan(
789
+ result = pytensor.scan(
789
790
  self._univariate_inner_filter_step,
790
791
  sequences=[y_masked, Z_masked, d, pt.diag(H_masked), nan_mask],
791
792
  outputs_info=[a, P, None, None, None],
792
793
  name="univariate_inner_scan",
794
+ return_updates=False,
793
795
  )
794
796
 
795
797
  a_filtered, P_filtered, obs_mu, obs_cov, ll_inner = result
@@ -800,7 +802,7 @@ class UnivariateFilter(BaseFilter):
800
802
  obs_cov[-1],
801
803
  )
802
804
 
803
- P_filtered = stabilize(0.5 * (P_filtered + P_filtered.T), self.cov_jitter)
805
+ P_filtered = stabilize(0.5 * (P_filtered + P_filtered.mT), self.cov_jitter)
804
806
  a_hat, P_hat = self.predict(a=a_filtered, P=P_filtered, c=c, T=T, R=R, Q=Q)
805
807
 
806
808
  ll = -0.5 * ((pt.neq(ll_inner, 0).sum()) * MVN_CONST + ll_inner.sum())
@@ -1,8 +1,6 @@
1
1
  import pytensor
2
2
  import pytensor.tensor as pt
3
3
 
4
- from pytensor.tensor.nlinalg import matrix_dot
5
-
6
4
  from pymc_extras.statespace.filters.utilities import (
7
5
  quad_form_sym,
8
6
  split_vars_into_seq_and_nonseq,
@@ -78,16 +76,16 @@ class KalmanSmoother:
78
76
  self.seq_names = seq_names
79
77
  self.non_seq_names = non_seq_names
80
78
 
81
- smoother_result, updates = pytensor.scan(
79
+ smoothed_states, smoothed_covariances = pytensor.scan(
82
80
  self.smoother_step,
83
81
  sequences=[filtered_states[:-1], filtered_covariances[:-1], *sequences],
84
82
  outputs_info=[a_last, P_last],
85
83
  non_sequences=non_sequences,
86
84
  go_backwards=True,
87
85
  name="kalman_smoother",
86
+ return_updates=False,
88
87
  )
89
88
 
90
- smoothed_states, smoothed_covariances = smoother_result
91
89
  smoothed_states = pt.concatenate(
92
90
  [smoothed_states[::-1], pt.expand_dims(a_last, axis=(0,))], axis=0
93
91
  )
@@ -105,7 +103,7 @@ class KalmanSmoother:
105
103
  a_hat, P_hat = self.predict(a, P, T, R, Q)
106
104
 
107
105
  # Use pinv, otherwise P_hat is singular when there is missing data
108
- smoother_gain = matrix_dot(pt.linalg.pinv(P_hat, hermitian=True), T, P).T
106
+ smoother_gain = (pt.linalg.pinv(P_hat, hermitian=True) @ T @ P).mT
109
107
  a_smooth_next = a + smoother_gain @ (a_smooth - a_hat)
110
108
 
111
109
  P_smooth_next = P + quad_form_sym(smoother_gain, P_smooth - P_hat)
@@ -1,7 +1,5 @@
1
1
  import pytensor.tensor as pt
2
2
 
3
- from pytensor.tensor.nlinalg import matrix_dot
4
-
5
3
  from pymc_extras.statespace.utils.constants import JITTER_DEFAULT, NEVER_TIME_VARYING, VECTOR_VALUED
6
4
 
7
5
 
@@ -48,12 +46,11 @@ def split_vars_into_seq_and_nonseq(params, param_names):
48
46
 
49
47
 
50
48
  def stabilize(cov, jitter=JITTER_DEFAULT):
51
- # Ensure diagonal is non-zero
52
49
  cov = cov + pt.identity_like(cov) * jitter
53
50
 
54
51
  return cov
55
52
 
56
53
 
57
54
  def quad_form_sym(A, B):
58
- out = matrix_dot(A, B, A.T)
59
- return 0.5 * (out + out.T)
55
+ out = A @ B @ A.mT
56
+ return 0.5 * (out + out.mT)
@@ -5,7 +5,7 @@ import pytensor
5
5
  import pytensor.tensor as pt
6
6
 
7
7
  from pymc_extras.statespace.core.statespace import PyMCStateSpace
8
- from pymc_extras.statespace.models.utilities import make_default_coords
8
+ from pymc_extras.statespace.models.utilities import make_default_coords, validate_names
9
9
  from pymc_extras.statespace.utils.constants import (
10
10
  ALL_STATE_AUX_DIM,
11
11
  ALL_STATE_DIM,
@@ -224,9 +224,7 @@ class BayesianDynamicFactor(PyMCStateSpace):
224
224
  self,
225
225
  k_factors: int,
226
226
  factor_order: int,
227
- k_endog: int | None = None,
228
227
  endog_names: Sequence[str] | None = None,
229
- k_exog: int | None = None,
230
228
  exog_names: Sequence[str] | None = None,
231
229
  shared_exog_states: bool = False,
232
230
  exog_innovations: bool = False,
@@ -249,19 +247,11 @@ class BayesianDynamicFactor(PyMCStateSpace):
249
247
  and are modeled as a white noise process, i.e., :math:`f_t = \varepsilon_{f,t}`.
250
248
  Therefore, the state vector will include one state per factor and "factor_ar" will not exist.
251
249
 
252
- k_endog : int, optional
253
- Number of observed time series. If not provided, the number of observed series will be inferred from `endog_names`.
254
- At least one of `k_endog` or `endog_names` must be provided.
255
-
256
250
  endog_names : list of str, optional
257
- Names of the observed time series. If not provided, default names will be generated as `endog_1`, `endog_2`, ..., `endog_k` based on `k_endog`.
258
- At least one of `k_endog` or `endog_names` must be provided.
259
-
260
- k_exog : int, optional
261
- Number of exogenous variables. If not provided, the model will not have exogenous variables.
251
+ Names of the observed time series.
262
252
 
263
253
  exog_names : Sequence[str], optional
264
- Names of the exogenous variables. If not provided, but `k_exog` is specified, default names will be generated as `exog_1`, `exog_2`, ..., `exog_k`.
254
+ Names of the exogenous variables.
265
255
 
266
256
  shared_exog_states: bool, optional
267
257
  Whether exogenous latent states are shared across the observed states. If True, there will be only one set of exogenous latent
@@ -289,13 +279,8 @@ class BayesianDynamicFactor(PyMCStateSpace):
289
279
 
290
280
  """
291
281
 
292
- if k_endog is None and endog_names is None:
293
- raise ValueError("Either k_endog or endog_names must be provided.")
294
- if k_endog is None:
295
- k_endog = len(endog_names)
296
- if endog_names is None:
297
- endog_names = [f"endog_{i}" for i in range(k_endog)]
298
-
282
+ validate_names(endog_names, var_name="endog_names", optional=False)
283
+ k_endog = len(endog_names)
299
284
  self.endog_names = endog_names
300
285
  self.k_endog = k_endog
301
286
  self.k_factors = k_factors
@@ -304,17 +289,17 @@ class BayesianDynamicFactor(PyMCStateSpace):
304
289
  self.error_var = error_var
305
290
  self.error_cov_type = error_cov_type
306
291
 
307
- if k_exog is None and exog_names is None:
308
- self.k_exog = 0
309
- else:
292
+ if exog_names is not None:
310
293
  self.shared_exog_states = shared_exog_states
311
294
  self.exog_innovations = exog_innovations
312
- if k_exog is None:
313
- k_exog = len(exog_names) if exog_names is not None else 0
314
- elif exog_names is None:
315
- exog_names = [f"exog_{i}" for i in range(k_exog)] if k_exog > 0 else None
295
+ validate_names(
296
+ exog_names, var_name="exog_names", optional=True
297
+ ) # Not sure if this adds anything
298
+ k_exog = len(exog_names)
316
299
  self.k_exog = k_exog
317
300
  self.exog_names = exog_names
301
+ else:
302
+ self.k_exog = 0
318
303
 
319
304
  self.k_exog_states = self.k_exog * self.k_endog if not shared_exog_states else self.k_exog
320
305
  self.exog_flag = self.k_exog > 0