pymc-extras: pymc_extras-0.2.4-py3-none-any.whl → pymc_extras-0.2.6-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. pymc_extras/__init__.py +6 -4
  2. pymc_extras/distributions/__init__.py +2 -0
  3. pymc_extras/distributions/continuous.py +3 -2
  4. pymc_extras/distributions/discrete.py +3 -1
  5. pymc_extras/distributions/transforms/__init__.py +3 -0
  6. pymc_extras/distributions/transforms/partial_order.py +227 -0
  7. pymc_extras/inference/__init__.py +4 -2
  8. pymc_extras/inference/find_map.py +62 -17
  9. pymc_extras/inference/fit.py +6 -4
  10. pymc_extras/inference/laplace.py +14 -8
  11. pymc_extras/inference/pathfinder/lbfgs.py +49 -13
  12. pymc_extras/inference/pathfinder/pathfinder.py +89 -103
  13. pymc_extras/statespace/core/statespace.py +191 -52
  14. pymc_extras/statespace/filters/distributions.py +15 -16
  15. pymc_extras/statespace/filters/kalman_filter.py +1 -18
  16. pymc_extras/statespace/filters/kalman_smoother.py +2 -6
  17. pymc_extras/statespace/models/ETS.py +10 -0
  18. pymc_extras/statespace/models/SARIMAX.py +26 -5
  19. pymc_extras/statespace/models/VARMAX.py +12 -2
  20. pymc_extras/statespace/models/structural.py +18 -5
  21. pymc_extras/statespace/utils/data_tools.py +24 -9
  22. pymc_extras-0.2.6.dist-info/METADATA +318 -0
  23. pymc_extras-0.2.6.dist-info/RECORD +65 -0
  24. {pymc_extras-0.2.4.dist-info → pymc_extras-0.2.6.dist-info}/WHEEL +1 -2
  25. pymc_extras/version.py +0 -11
  26. pymc_extras/version.txt +0 -1
  27. pymc_extras-0.2.4.dist-info/METADATA +0 -110
  28. pymc_extras-0.2.4.dist-info/RECORD +0 -105
  29. pymc_extras-0.2.4.dist-info/top_level.txt +0 -2
  30. tests/__init__.py +0 -13
  31. tests/distributions/__init__.py +0 -19
  32. tests/distributions/test_continuous.py +0 -185
  33. tests/distributions/test_discrete.py +0 -210
  34. tests/distributions/test_discrete_markov_chain.py +0 -258
  35. tests/distributions/test_multivariate.py +0 -304
  36. tests/model/__init__.py +0 -0
  37. tests/model/marginal/__init__.py +0 -0
  38. tests/model/marginal/test_distributions.py +0 -132
  39. tests/model/marginal/test_graph_analysis.py +0 -182
  40. tests/model/marginal/test_marginal_model.py +0 -967
  41. tests/model/test_model_api.py +0 -38
  42. tests/statespace/__init__.py +0 -0
  43. tests/statespace/test_ETS.py +0 -411
  44. tests/statespace/test_SARIMAX.py +0 -405
  45. tests/statespace/test_VARMAX.py +0 -184
  46. tests/statespace/test_coord_assignment.py +0 -116
  47. tests/statespace/test_distributions.py +0 -270
  48. tests/statespace/test_kalman_filter.py +0 -326
  49. tests/statespace/test_representation.py +0 -175
  50. tests/statespace/test_statespace.py +0 -872
  51. tests/statespace/test_statespace_JAX.py +0 -156
  52. tests/statespace/test_structural.py +0 -836
  53. tests/statespace/utilities/__init__.py +0 -0
  54. tests/statespace/utilities/shared_fixtures.py +0 -9
  55. tests/statespace/utilities/statsmodel_local_level.py +0 -42
  56. tests/statespace/utilities/test_helpers.py +0 -310
  57. tests/test_blackjax_smc.py +0 -222
  58. tests/test_find_map.py +0 -103
  59. tests/test_histogram_approximation.py +0 -109
  60. tests/test_laplace.py +0 -265
  61. tests/test_linearmodel.py +0 -208
  62. tests/test_model_builder.py +0 -306
  63. tests/test_pathfinder.py +0 -203
  64. tests/test_pivoted_cholesky.py +0 -24
  65. tests/test_printing.py +0 -98
  66. tests/test_prior_from_trace.py +0 -172
  67. tests/test_splines.py +0 -77
  68. tests/utils.py +0 -0
  69. {pymc_extras-0.2.4.dist-info → pymc_extras-0.2.6.dist-info/licenses}/LICENSE +0 -0
Deleted file: tests/statespace/test_statespace_JAX.py (all 156 lines removed)
@@ -1,156 +0,0 @@
1
- import warnings
2
-
3
- import numpy as np
4
- import pymc as pm
5
- import pytensor
6
- import pytensor.tensor as pt
7
- import pytest
8
-
9
- from pymc.model.transform.optimization import freeze_dims_and_data
10
-
11
- from pymc_extras.statespace.utils.constants import (
12
- FILTER_OUTPUT_NAMES,
13
- MATRIX_NAMES,
14
- SMOOTHER_OUTPUT_NAMES,
15
- )
16
- from tests.statespace.test_statespace import ( # pylint: disable=unused-import
17
- exog_ss_mod,
18
- ss_mod,
19
- )
20
- from tests.statespace.utilities.shared_fixtures import ( # pylint: disable=unused-import
21
- rng,
22
- )
23
- from tests.statespace.utilities.test_helpers import load_nile_test_data
24
-
25
# This module needs both optional backends: the fixtures compile the statespace
# graph with mode="JAX" and sample with nuts_sampler="numpyro".
pytest.importorskip("jax")
pytest.importorskip("numpyro")


# Float dtype configured for pytensor; random test arrays are cast to it.
floatX = pytensor.config.floatX

# Observed series for the `pymc_mod` fixture, loaded via the shared test helper.
nile = load_nile_test_data()

# Every name whose sampled values are checked for NaNs: the statespace
# matrices followed by the Kalman filter and smoother outputs.
ALL_SAMPLE_OUTPUTS = MATRIX_NAMES + FILTER_OUTPUT_NAMES + SMOOTHER_OUTPUT_NAMES
32
-
33
-
34
@pytest.fixture(scope="session")
def pymc_mod(ss_mod):
    """Build a session-scoped PyMC model around ``ss_mod`` in JAX mode.

    The model defines a Beta(1, 1) variable ``rho`` and a deterministic
    ``zeta = 1 - rho``. It builds the statespace graph on the ``nile`` data and
    keeps the Kalman filter outputs in the inference data. Each matrix returned
    by ``ss_mod.unpack_statespace()`` is registered as a Deterministic so that
    it appears in the sampling output.
    """
    matrix_names = ("x0", "P0", "c", "d", "T", "Z", "R", "H", "Q")

    with pm.Model(coords=ss_mod.coords) as model:
        rho = pm.Beta("rho", 1, 1)
        pm.Deterministic("zeta", 1 - rho)

        ss_mod.build_statespace_graph(
            data=nile, mode="JAX", save_kalman_filter_outputs_in_idata=True
        )

        statespace_matrices = ss_mod.unpack_statespace()
        for matrix_name, matrix in zip(matrix_names, statespace_matrices):
            pm.Deterministic(matrix_name, matrix)

    return model
48
-
49
-
50
@pytest.fixture(scope="session")
def exog_pymc_mod(exog_ss_mod, rng):
    """Build a session-scoped PyMC model around ``exog_ss_mod`` with exogenous data.

    The fixture draws a (100, 1) observation array and a (100, 3) regressor
    matrix from ``rng``, then builds the statespace graph in JAX mode. Several
    model variables below are created only for their names. They are
    presumably looked up by name when ``build_statespace_graph`` runs (TODO
    confirm), so they are registered but not bound to locals.
    """
    # Keep this draw order: both arrays come from the same shared session RNG.
    observed = rng.normal(size=(100, 1)).astype(floatX)
    regressors = rng.normal(size=(100, 3)).astype(floatX)

    with pm.Model(coords=exog_ss_mod.coords) as model:
        pm.Data("data_exog", regressors)
        pm.Normal("initial_trend", dims=["trend_state"])

        P0_sigma = pm.Exponential("P0_sigma", 1)
        pm.Deterministic(
            "P0", pt.eye(exog_ss_mod.k_states) * P0_sigma, dims=["state", "state_aux"]
        )

        pm.Normal("beta_exog", dims=["exog_state"])
        pm.Exponential("sigma_trend", 1, dims=["trend_shock"])

        exog_ss_mod.build_statespace_graph(observed, mode="JAX")

    return model
68
-
69
-
70
@pytest.fixture(scope="session")
def idata(pymc_mod, rng):
    """Sample posterior and prior draws from ``pymc_mod``.

    The fixture runs a very short numpyro NUTS run (1 chain, 1 tuning step,
    10 draws) and 10 prior-predictive draws compiled in JAX mode. Warnings are
    suppressed for both calls. The prior groups are then merged into the
    posterior InferenceData, which is returned.
    """
    sample_kwargs = dict(
        draws=10,
        tune=1,
        chains=1,
        random_seed=rng,
        nuts_sampler="numpyro",
        progressbar=False,
    )
    prior_kwargs = dict(samples=10, random_seed=rng, compile_kwargs={"mode": "JAX"})

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        with pymc_mod:
            posterior_idata = pm.sample(**sample_kwargs)

        # Prior-predictive sampling runs on a copy of the model with dims and
        # data frozen (presumably to give JAX static shapes — confirm).
        with freeze_dims_and_data(pymc_mod):
            prior_idata = pm.sample_prior_predictive(**prior_kwargs)

    posterior_idata.extend(prior_idata)
    return posterior_idata
90
-
91
-
92
@pytest.fixture(scope="session")
def idata_exog(exog_pymc_mod, rng):
    """Sample posterior and prior draws from ``exog_pymc_mod``.

    The fixture runs a very short numpyro NUTS run (1 chain, 1 tuning step,
    10 draws) and 10 prior-predictive draws compiled in JAX mode. Both calls
    operate on the exogenous-regressor model, with warnings suppressed. The
    prior groups are merged into the posterior InferenceData, which is returned.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        with exog_pymc_mod:
            idata = pm.sample(
                draws=10,
                tune=1,
                chains=1,
                random_seed=rng,
                nuts_sampler="numpyro",
                progressbar=False,
            )

        # This must be the exogenous model passed into this fixture. The bare
        # name ``pymc_mod`` at module scope is the other fixture's function
        # object, not a PyMC model.
        with freeze_dims_and_data(exog_pymc_mod):
            idata_prior = pm.sample_prior_predictive(
                samples=10, random_seed=rng, compile_kwargs={"mode": "JAX"}
            )

    idata.extend(idata_prior)
    return idata
113
-
114
-
115
@pytest.mark.parametrize("group", ["posterior", "prior"])
@pytest.mark.parametrize("matrix", ALL_SAMPLE_OUTPUTS)
def test_no_nans_in_sampling_output(ss_mod, group, matrix, idata):
    """Check that each recorded matrix and filter/smoother output has no NaNs in ``group``."""
    sampled_values = idata[group][matrix].values
    assert not np.isnan(sampled_values).any()
119
-
120
-
121
@pytest.mark.parametrize("group", ["prior", "posterior"])
@pytest.mark.parametrize("kind", ["conditional", "unconditional"])
def test_sampling_methods(group, kind, ss_mod, idata, rng):
    """Run ``ss_mod.sample_{kind}_{group}`` in JAX mode and check its outputs.

    Conditional sampling must produce filtered, predicted and smoothed groups,
    plus their ``_observed`` counterparts, all free of NaNs. Unconditional
    sampling must produce NaN-free latent and observed groups.
    """
    assert ss_mod._fit_mode == "JAX"

    sampler = getattr(ss_mod, f"sample_{kind}_{group}")
    with pytest.warns(UserWarning, match="The RandomType SharedVariables"):
        sampled = sampler(idata, random_seed=rng)

    if kind == "conditional":
        for output in ("filtered", "predicted", "smoothed"):
            base_name = f"{output}_{group}"
            assert base_name in sampled
            for name in (base_name, f"{base_name}_observed"):
                assert not np.isnan(sampled[name].values).any()
    elif kind == "unconditional":
        for output in ("latent", "observed"):
            name = f"{group}_{output}"
            assert name in sampled
            assert not np.isnan(sampled[name].values).any()
140
-
141
-
142
@pytest.mark.parametrize("filter_output", ["predicted", "filtered", "smoothed"])
def test_forecast(filter_output, ss_mod, idata, rng):
    """Forecast 10 periods starting at the last posterior time index.

    The test checks the forecast time coordinate length, the dims of the latent
    and observed forecasts, and that neither contains NaNs.
    """
    n_periods = 10
    last_time = idata.posterior.coords["time"].values[-1]

    with pytest.warns(UserWarning, match="The RandomType SharedVariables"):
        forecast_idata = ss_mod.forecast(
            idata,
            start=last_time,
            periods=n_periods,
            filter_output=filter_output,
            random_seed=rng,
        )

    assert forecast_idata.coords["time"].values.shape == (n_periods,)

    latent = forecast_idata.forecast_latent
    observed = forecast_idata.forecast_observed
    assert latent.dims == ("chain", "draw", "time", "state")
    assert observed.dims == ("chain", "draw", "time", "observed_state")

    assert not np.isnan(latent.values).any()
    assert not np.isnan(observed.values).any()