pymc-extras 0.2.5__py3-none-any.whl → 0.2.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. pymc_extras/__init__.py +5 -1
  2. pymc_extras/distributions/continuous.py +3 -2
  3. pymc_extras/distributions/discrete.py +3 -1
  4. pymc_extras/inference/find_map.py +62 -17
  5. pymc_extras/inference/laplace.py +10 -7
  6. pymc_extras/statespace/core/statespace.py +191 -52
  7. pymc_extras/statespace/filters/distributions.py +15 -16
  8. pymc_extras/statespace/filters/kalman_filter.py +1 -18
  9. pymc_extras/statespace/filters/kalman_smoother.py +2 -6
  10. pymc_extras/statespace/models/ETS.py +10 -0
  11. pymc_extras/statespace/models/SARIMAX.py +26 -5
  12. pymc_extras/statespace/models/VARMAX.py +12 -2
  13. pymc_extras/statespace/models/structural.py +18 -5
  14. pymc_extras-0.2.6.dist-info/METADATA +318 -0
  15. pymc_extras-0.2.6.dist-info/RECORD +65 -0
  16. {pymc_extras-0.2.5.dist-info → pymc_extras-0.2.6.dist-info}/WHEEL +1 -2
  17. pymc_extras/version.py +0 -11
  18. pymc_extras/version.txt +0 -1
  19. pymc_extras-0.2.5.dist-info/METADATA +0 -112
  20. pymc_extras-0.2.5.dist-info/RECORD +0 -108
  21. pymc_extras-0.2.5.dist-info/top_level.txt +0 -2
  22. tests/__init__.py +0 -13
  23. tests/distributions/__init__.py +0 -19
  24. tests/distributions/test_continuous.py +0 -185
  25. tests/distributions/test_discrete.py +0 -210
  26. tests/distributions/test_discrete_markov_chain.py +0 -258
  27. tests/distributions/test_multivariate.py +0 -304
  28. tests/distributions/test_transform.py +0 -77
  29. tests/model/__init__.py +0 -0
  30. tests/model/marginal/__init__.py +0 -0
  31. tests/model/marginal/test_distributions.py +0 -132
  32. tests/model/marginal/test_graph_analysis.py +0 -182
  33. tests/model/marginal/test_marginal_model.py +0 -967
  34. tests/model/test_model_api.py +0 -38
  35. tests/statespace/__init__.py +0 -0
  36. tests/statespace/test_ETS.py +0 -411
  37. tests/statespace/test_SARIMAX.py +0 -405
  38. tests/statespace/test_VARMAX.py +0 -184
  39. tests/statespace/test_coord_assignment.py +0 -181
  40. tests/statespace/test_distributions.py +0 -270
  41. tests/statespace/test_kalman_filter.py +0 -326
  42. tests/statespace/test_representation.py +0 -175
  43. tests/statespace/test_statespace.py +0 -872
  44. tests/statespace/test_statespace_JAX.py +0 -156
  45. tests/statespace/test_structural.py +0 -836
  46. tests/statespace/utilities/__init__.py +0 -0
  47. tests/statespace/utilities/shared_fixtures.py +0 -9
  48. tests/statespace/utilities/statsmodel_local_level.py +0 -42
  49. tests/statespace/utilities/test_helpers.py +0 -310
  50. tests/test_blackjax_smc.py +0 -222
  51. tests/test_find_map.py +0 -103
  52. tests/test_histogram_approximation.py +0 -109
  53. tests/test_laplace.py +0 -281
  54. tests/test_linearmodel.py +0 -208
  55. tests/test_model_builder.py +0 -306
  56. tests/test_pathfinder.py +0 -297
  57. tests/test_pivoted_cholesky.py +0 -24
  58. tests/test_printing.py +0 -98
  59. tests/test_prior_from_trace.py +0 -172
  60. tests/test_splines.py +0 -77
  61. tests/utils.py +0 -0
  62. {pymc_extras-0.2.5.dist-info → pymc_extras-0.2.6.dist-info}/licenses/LICENSE +0 -0
@@ -1,156 +0,0 @@
1
- import warnings
2
-
3
- import numpy as np
4
- import pymc as pm
5
- import pytensor
6
- import pytensor.tensor as pt
7
- import pytest
8
-
9
- from pymc.model.transform.optimization import freeze_dims_and_data
10
-
11
- from pymc_extras.statespace.utils.constants import (
12
- FILTER_OUTPUT_NAMES,
13
- MATRIX_NAMES,
14
- SMOOTHER_OUTPUT_NAMES,
15
- )
16
- from tests.statespace.test_statespace import ( # pylint: disable=unused-import
17
- exog_ss_mod,
18
- ss_mod,
19
- )
20
- from tests.statespace.utilities.shared_fixtures import ( # pylint: disable=unused-import
21
- rng,
22
- )
23
- from tests.statespace.utilities.test_helpers import load_nile_test_data
24
-
25
# Every test in this module targets the JAX backend and samples with NumPyro,
# so the whole module is skipped when either optional dependency is missing.
pytest.importorskip("jax")
pytest.importorskip("numpyro")

floatX = pytensor.config.floatX

# Dataset loaded once at import time and reused by the fixtures.
nile = load_nile_test_data()

# Names of all statespace matrices plus Kalman filter and smoother outputs;
# used to parametrize the NaN checks over sampled groups.
ALL_SAMPLE_OUTPUTS = (
    MATRIX_NAMES
    + FILTER_OUTPUT_NAMES
    + SMOOTHER_OUTPUT_NAMES
)
33
-
34
@pytest.fixture(scope="session")
def pymc_mod(ss_mod):
    """Build a PyMC model around ``ss_mod`` whose statespace graph uses JAX mode.

    The model declares ``rho ~ Beta(1, 1)`` and the deterministic
    ``zeta = 1 - rho``. It builds the statespace graph on ``nile`` with
    ``save_kalman_filter_outputs_in_idata=True``. It then registers every
    matrix returned by ``ss_mod.unpack_statespace()`` as a named Deterministic.
    """
    matrix_names = ("x0", "P0", "c", "d", "T", "Z", "R", "H", "Q")

    with pm.Model(coords=ss_mod.coords) as model:
        rho = pm.Beta("rho", 1, 1)
        pm.Deterministic("zeta", 1 - rho)

        ss_mod.build_statespace_graph(
            data=nile,
            mode="JAX",
            save_kalman_filter_outputs_in_idata=True,
        )

        # Pair each unpacked matrix with its name (zip stops at the shorter
        # sequence) and record it in the model under that name.
        statespace_matrices = ss_mod.unpack_statespace()
        for matrix_name, matrix in zip(matrix_names, statespace_matrices):
            pm.Deterministic(matrix_name, matrix)

    return model
48
-
49
-
50
@pytest.fixture(scope="session")
def exog_pymc_mod(exog_ss_mod, rng):
    """Build a JAX-mode PyMC model around the exogenous statespace model ``exog_ss_mod``.

    A (100, 1) observed array and a (100, 3) exogenous array are drawn from
    ``rng`` and cast to ``floatX``. The exogenous array is registered as the
    ``data_exog`` data container. The model declares the variables
    ``initial_trend``, ``P0_sigma``, ``P0``, ``beta_exog`` and ``sigma_trend``
    and then builds the statespace graph on the observed array.
    """
    n_timesteps = 100
    # Draw order matters for reproducibility with a shared rng:
    # observations first, then the exogenous regressors.
    observed = rng.normal(size=(n_timesteps, 1)).astype(floatX)
    exog_values = rng.normal(size=(n_timesteps, 3)).astype(floatX)

    with pm.Model(coords=exog_ss_mod.coords) as model:
        pm.Data("data_exog", exog_values)

        pm.Normal("initial_trend", dims=["trend_state"])
        P0_sigma = pm.Exponential("P0_sigma", 1)
        pm.Deterministic(
            "P0",
            pt.eye(exog_ss_mod.k_states) * P0_sigma,
            dims=["state", "state_aux"],
        )
        pm.Normal("beta_exog", dims=["exog_state"])
        pm.Exponential("sigma_trend", 1, dims=["trend_shock"])

        exog_ss_mod.build_statespace_graph(observed, mode="JAX")

    return model
68
-
69
-
70
@pytest.fixture(scope="session")
def idata(pymc_mod, rng):
    """Sample a tiny posterior and prior from ``pymc_mod`` and return them combined.

    The posterior comes from NUTS via NumPyro: 10 draws, 1 tuning step,
    1 chain. The prior is 10 prior-predictive samples compiled in JAX mode,
    taken inside the model returned by ``freeze_dims_and_data``. All warnings
    raised while sampling are ignored. The prior groups are added to the
    posterior InferenceData with ``extend``.
    """
    nuts_kwargs = {
        "draws": 10,
        "tune": 1,
        "chains": 1,
        "random_seed": rng,
        "nuts_sampler": "numpyro",
        "progressbar": False,
    }
    prior_kwargs = {
        "samples": 10,
        "random_seed": rng,
        "compile_kwargs": {"mode": "JAX"},
    }

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        with pymc_mod:
            trace = pm.sample(**nuts_kwargs)
        # Use the model produced by freeze_dims_and_data as the sampling context.
        with freeze_dims_and_data(pymc_mod):
            prior_trace = pm.sample_prior_predictive(**prior_kwargs)

    trace.extend(prior_trace)
    return trace
90
-
91
-
92
@pytest.fixture(scope="session")
def idata_exog(exog_pymc_mod, rng):
    """Sample a tiny posterior and prior from ``exog_pymc_mod`` and return them combined.

    The posterior comes from NUTS via NumPyro: 10 draws, 1 tuning step,
    1 chain. The prior is 10 prior-predictive samples compiled in JAX mode,
    taken inside the model that ``freeze_dims_and_data`` returns for
    ``exog_pymc_mod``. All warnings raised while sampling are ignored. The
    prior groups are added to the posterior InferenceData with ``extend``.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        with exog_pymc_mod:
            idata = pm.sample(
                draws=10,
                tune=1,
                chains=1,
                random_seed=rng,
                nuts_sampler="numpyro",
                progressbar=False,
            )
        # Freeze the same model that was just sampled. ``pymc_mod`` is not a
        # parameter of this fixture, so naming it here would resolve to the
        # module-level fixture function rather than a PyMC Model.
        with freeze_dims_and_data(exog_pymc_mod):
            idata_prior = pm.sample_prior_predictive(
                samples=10, random_seed=rng, compile_kwargs={"mode": "JAX"}
            )

    idata.extend(idata_prior)
    return idata
113
-
114
-
115
@pytest.mark.parametrize("group", ["posterior", "prior"])
@pytest.mark.parametrize("matrix", ALL_SAMPLE_OUTPUTS)
def test_no_nans_in_sampling_output(ss_mod, group, matrix, idata):
    """The sampled values of ``matrix`` in ``idata[group]`` contain no NaNs."""
    sampled = idata[group][matrix].values
    assert not np.isnan(sampled).any()
119
-
120
-
121
@pytest.mark.parametrize("group", ["prior", "posterior"])
@pytest.mark.parametrize("kind", ["conditional", "unconditional"])
def test_sampling_methods(group, kind, ss_mod, idata, rng):
    """``ss_mod.sample_{kind}_{group}`` runs in JAX mode and returns NaN-free groups.

    Conditional sampling must produce ``{filtered,predicted,smoothed}_{group}``
    groups, each with a ``_observed`` counterpart. Unconditional sampling must
    produce ``{group}_latent`` and ``{group}_observed``.
    """
    assert ss_mod._fit_mode == "JAX"

    sampler = getattr(ss_mod, f"sample_{kind}_{group}")
    with pytest.warns(UserWarning, match="The RandomType SharedVariables"):
        test_idata = sampler(idata, random_seed=rng)

    def assert_no_nans(key):
        assert not np.isnan(test_idata[key].values).any()

    if kind == "conditional":
        for output in ("filtered", "predicted", "smoothed"):
            key = f"{output}_{group}"
            assert key in test_idata
            assert_no_nans(key)
            assert_no_nans(f"{key}_observed")
    elif kind == "unconditional":
        for output in ("latent", "observed"):
            key = f"{group}_{output}"
            assert key in test_idata
            assert_no_nans(key)
140
-
141
-
142
@pytest.mark.parametrize("filter_output", ["predicted", "filtered", "smoothed"])
def test_forecast(filter_output, ss_mod, idata, rng):
    """Forecast 10 periods from the last posterior time point and check the output.

    Checks the length of the forecast time coordinate, the dims of the latent
    and observed forecasts, and that neither forecast contains NaNs.
    """
    n_periods = 10
    last_time = idata.posterior.coords["time"].values[-1]

    with pytest.warns(UserWarning, match="The RandomType SharedVariables"):
        forecast_idata = ss_mod.forecast(
            idata,
            start=last_time,
            periods=n_periods,
            filter_output=filter_output,
            random_seed=rng,
        )

    assert forecast_idata.coords["time"].values.shape == (n_periods,)
    assert forecast_idata.forecast_latent.dims == ("chain", "draw", "time", "state")
    assert forecast_idata.forecast_observed.dims == ("chain", "draw", "time", "observed_state")

    for forecast in (forecast_idata.forecast_latent, forecast_idata.forecast_observed):
        assert not np.isnan(forecast.values).any()