pymc-extras 0.4.0__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. pymc_extras/distributions/__init__.py +5 -5
  2. pymc_extras/distributions/histogram_utils.py +1 -1
  3. pymc_extras/inference/__init__.py +8 -1
  4. pymc_extras/inference/dadvi/__init__.py +0 -0
  5. pymc_extras/inference/dadvi/dadvi.py +261 -0
  6. pymc_extras/inference/fit.py +5 -0
  7. pymc_extras/inference/laplace_approx/find_map.py +16 -8
  8. pymc_extras/inference/laplace_approx/idata.py +5 -2
  9. pymc_extras/inference/laplace_approx/laplace.py +1 -0
  10. pymc_extras/printing.py +1 -1
  11. pymc_extras/statespace/__init__.py +4 -4
  12. pymc_extras/statespace/core/__init__.py +1 -1
  13. pymc_extras/statespace/core/representation.py +8 -8
  14. pymc_extras/statespace/core/statespace.py +94 -23
  15. pymc_extras/statespace/filters/__init__.py +3 -3
  16. pymc_extras/statespace/filters/kalman_filter.py +16 -11
  17. pymc_extras/statespace/models/DFM.py +849 -0
  18. pymc_extras/statespace/models/SARIMAX.py +138 -74
  19. pymc_extras/statespace/models/VARMAX.py +248 -57
  20. pymc_extras/statespace/models/__init__.py +2 -2
  21. pymc_extras/statespace/models/structural/__init__.py +4 -4
  22. pymc_extras/statespace/models/structural/components/autoregressive.py +49 -24
  23. pymc_extras/statespace/models/structural/components/cycle.py +48 -28
  24. pymc_extras/statespace/models/structural/components/level_trend.py +61 -29
  25. pymc_extras/statespace/models/structural/components/measurement_error.py +22 -5
  26. pymc_extras/statespace/models/structural/components/regression.py +47 -18
  27. pymc_extras/statespace/models/structural/components/seasonality.py +278 -95
  28. pymc_extras/statespace/models/structural/core.py +27 -8
  29. pymc_extras/statespace/utils/constants.py +19 -14
  30. pymc_extras/statespace/utils/data_tools.py +1 -1
  31. {pymc_extras-0.4.0.dist-info → pymc_extras-0.5.0.dist-info}/METADATA +1 -1
  32. {pymc_extras-0.4.0.dist-info → pymc_extras-0.5.0.dist-info}/RECORD +34 -31
  33. {pymc_extras-0.4.0.dist-info → pymc_extras-0.5.0.dist-info}/WHEEL +0 -0
  34. {pymc_extras-0.4.0.dist-info → pymc_extras-0.5.0.dist-info}/licenses/LICENSE +0 -0
@@ -7,9 +7,9 @@ from pymc_extras.statespace.filters.kalman_filter import (
7
7
  from pymc_extras.statespace.filters.kalman_smoother import KalmanSmoother
8
8
 
9
9
  __all__ = [
10
- "StandardFilter",
11
- "UnivariateFilter",
12
10
  "KalmanSmoother",
13
- "SquareRootFilter",
14
11
  "LinearGaussianStateSpace",
12
+ "SquareRootFilter",
13
+ "StandardFilter",
14
+ "UnivariateFilter",
15
15
  ]
@@ -15,10 +15,15 @@ from pymc_extras.statespace.filters.utilities import (
15
15
  split_vars_into_seq_and_nonseq,
16
16
  stabilize,
17
17
  )
18
- from pymc_extras.statespace.utils.constants import JITTER_DEFAULT, MISSING_FILL
18
+ from pymc_extras.statespace.utils.constants import (
19
+ FILTER_OUTPUT_NAMES,
20
+ JITTER_DEFAULT,
21
+ MATRIX_NAMES,
22
+ MISSING_FILL,
23
+ )
19
24
 
20
25
  MVN_CONST = pt.log(2 * pt.constant(np.pi, dtype="float64"))
21
- PARAM_NAMES = ["c", "d", "T", "Z", "R", "H", "Q"]
26
+ PARAM_NAMES = MATRIX_NAMES[2:]
22
27
 
23
28
  assert_time_varying_dim_correct = Assert(
24
29
  "The first dimension of a time varying matrix (the time dimension) must be "
@@ -119,7 +124,7 @@ class BaseFilter(ABC):
119
124
  # There are always two outputs_info wedged between the seqs and non_seqs
120
125
  seqs, (a0, P0), non_seqs = args[:n_seq], args[n_seq : n_seq + 2], args[n_seq + 2 :]
121
126
  return_ordered = []
122
- for name in ["c", "d", "T", "Z", "R", "H", "Q"]:
127
+ for name in PARAM_NAMES:
123
128
  if name in self.seq_names:
124
129
  idx = self.seq_names.index(name)
125
130
  return_ordered.append(seqs[idx])
@@ -253,28 +258,28 @@ class BaseFilter(ABC):
253
258
  )
254
259
 
255
260
  filtered_states = pt.specify_shape(filtered_states, (n, self.n_states))
256
- filtered_states.name = "filtered_states"
261
+ filtered_states.name = FILTER_OUTPUT_NAMES[0]
257
262
 
258
263
  predicted_states = pt.specify_shape(predicted_states, (n, self.n_states))
259
- predicted_states.name = "predicted_states"
260
-
261
- observed_states = pt.specify_shape(observed_states, (n, self.n_endog))
262
- observed_states.name = "observed_states"
264
+ predicted_states.name = FILTER_OUTPUT_NAMES[1]
263
265
 
264
266
  filtered_covariances = pt.specify_shape(
265
267
  filtered_covariances, (n, self.n_states, self.n_states)
266
268
  )
267
- filtered_covariances.name = "filtered_covariances"
269
+ filtered_covariances.name = FILTER_OUTPUT_NAMES[2]
268
270
 
269
271
  predicted_covariances = pt.specify_shape(
270
272
  predicted_covariances, (n, self.n_states, self.n_states)
271
273
  )
272
- predicted_covariances.name = "predicted_covariances"
274
+ predicted_covariances.name = FILTER_OUTPUT_NAMES[3]
275
+
276
+ observed_states = pt.specify_shape(observed_states, (n, self.n_endog))
277
+ observed_states.name = FILTER_OUTPUT_NAMES[4]
273
278
 
274
279
  observed_covariances = pt.specify_shape(
275
280
  observed_covariances, (n, self.n_endog, self.n_endog)
276
281
  )
277
- observed_covariances.name = "observed_covariances"
282
+ observed_covariances.name = FILTER_OUTPUT_NAMES[5]
278
283
 
279
284
  loglike_obs = pt.specify_shape(loglike_obs.squeeze(), (n,))
280
285
  loglike_obs.name = "loglike_obs"