pymc-extras 0.2.5__py3-none-any.whl → 0.2.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62) hide show
  1. pymc_extras/__init__.py +5 -1
  2. pymc_extras/distributions/continuous.py +3 -2
  3. pymc_extras/distributions/discrete.py +3 -1
  4. pymc_extras/inference/find_map.py +62 -17
  5. pymc_extras/inference/laplace.py +10 -7
  6. pymc_extras/statespace/core/statespace.py +191 -52
  7. pymc_extras/statespace/filters/distributions.py +15 -16
  8. pymc_extras/statespace/filters/kalman_filter.py +1 -18
  9. pymc_extras/statespace/filters/kalman_smoother.py +2 -6
  10. pymc_extras/statespace/models/ETS.py +10 -0
  11. pymc_extras/statespace/models/SARIMAX.py +26 -5
  12. pymc_extras/statespace/models/VARMAX.py +12 -2
  13. pymc_extras/statespace/models/structural.py +18 -5
  14. pymc_extras-0.2.6.dist-info/METADATA +318 -0
  15. pymc_extras-0.2.6.dist-info/RECORD +65 -0
  16. {pymc_extras-0.2.5.dist-info → pymc_extras-0.2.6.dist-info}/WHEEL +1 -2
  17. pymc_extras/version.py +0 -11
  18. pymc_extras/version.txt +0 -1
  19. pymc_extras-0.2.5.dist-info/METADATA +0 -112
  20. pymc_extras-0.2.5.dist-info/RECORD +0 -108
  21. pymc_extras-0.2.5.dist-info/top_level.txt +0 -2
  22. tests/__init__.py +0 -13
  23. tests/distributions/__init__.py +0 -19
  24. tests/distributions/test_continuous.py +0 -185
  25. tests/distributions/test_discrete.py +0 -210
  26. tests/distributions/test_discrete_markov_chain.py +0 -258
  27. tests/distributions/test_multivariate.py +0 -304
  28. tests/distributions/test_transform.py +0 -77
  29. tests/model/__init__.py +0 -0
  30. tests/model/marginal/__init__.py +0 -0
  31. tests/model/marginal/test_distributions.py +0 -132
  32. tests/model/marginal/test_graph_analysis.py +0 -182
  33. tests/model/marginal/test_marginal_model.py +0 -967
  34. tests/model/test_model_api.py +0 -38
  35. tests/statespace/__init__.py +0 -0
  36. tests/statespace/test_ETS.py +0 -411
  37. tests/statespace/test_SARIMAX.py +0 -405
  38. tests/statespace/test_VARMAX.py +0 -184
  39. tests/statespace/test_coord_assignment.py +0 -181
  40. tests/statespace/test_distributions.py +0 -270
  41. tests/statespace/test_kalman_filter.py +0 -326
  42. tests/statespace/test_representation.py +0 -175
  43. tests/statespace/test_statespace.py +0 -872
  44. tests/statespace/test_statespace_JAX.py +0 -156
  45. tests/statespace/test_structural.py +0 -836
  46. tests/statespace/utilities/__init__.py +0 -0
  47. tests/statespace/utilities/shared_fixtures.py +0 -9
  48. tests/statespace/utilities/statsmodel_local_level.py +0 -42
  49. tests/statespace/utilities/test_helpers.py +0 -310
  50. tests/test_blackjax_smc.py +0 -222
  51. tests/test_find_map.py +0 -103
  52. tests/test_histogram_approximation.py +0 -109
  53. tests/test_laplace.py +0 -281
  54. tests/test_linearmodel.py +0 -208
  55. tests/test_model_builder.py +0 -306
  56. tests/test_pathfinder.py +0 -297
  57. tests/test_pivoted_cholesky.py +0 -24
  58. tests/test_printing.py +0 -98
  59. tests/test_prior_from_trace.py +0 -172
  60. tests/test_splines.py +0 -77
  61. tests/utils.py +0 -0
  62. {pymc_extras-0.2.5.dist-info → pymc_extras-0.2.6.dist-info}/licenses/LICENSE +0 -0
tests/test_printing.py DELETED
@@ -1,98 +0,0 @@
1
- import numpy as np
2
- import pymc as pm
3
-
4
- from rich.console import Console
5
-
6
- from pymc_extras.printing import model_table
7
-
8
-
9
- def get_text(table) -> str:
10
- console = Console(width=80)
11
- with console.capture() as capture:
12
- console.print(table)
13
- return capture.get()
14
-
15
-
16
- def test_model_table():
17
- with pm.Model(coords={"trial": range(6), "subject": range(20)}) as model:
18
- x_data = pm.Data("x_data", np.random.normal(size=(6, 20)), dims=("trial", "subject"))
19
- y_data = pm.Data("y_data", np.random.normal(size=(6, 20)), dims=("trial", "subject"))
20
-
21
- mu = pm.Normal("mu", mu=0, sigma=1)
22
- sigma = pm.HalfNormal("sigma", sigma=1)
23
- global_intercept = pm.Normal("global_intercept", mu=0, sigma=1)
24
- intercept_subject = pm.Normal("intercept_subject", mu=0, sigma=1, shape=(20, 1))
25
- beta_subject = pm.Normal("beta_subject", mu=mu, sigma=sigma, dims="subject")
26
-
27
- mu_trial = pm.Deterministic(
28
- "mu_trial",
29
- global_intercept.squeeze() + intercept_subject + beta_subject * x_data,
30
- dims=["trial", "subject"],
31
- )
32
- noise = pm.Exponential("noise", lam=1)
33
- y = pm.Normal("y", mu=mu_trial, sigma=noise, observed=y_data, dims=("trial", "subject"))
34
-
35
- pm.Potential("beta_subject_penalty", -pm.math.abs(beta_subject), dims="subject")
36
-
37
- table_txt = get_text(model_table(model))
38
- expected = """ Variable Expression Dimensions
39
- ────────────────────────────────────────────────────────────────────────────────
40
- x_data = Data trial[6] × subject[20]
41
- y_data = Data trial[6] × subject[20]
42
-
43
- mu ~ Normal(0, 1)
44
- sigma ~ HalfNormal(0, 1)
45
- global_intercept ~ Normal(0, 1)
46
- intercept_subject ~ Normal(0, 1) [20, 1]
47
- beta_subject ~ Normal(mu, sigma) subject[20]
48
- noise ~ Exponential(f())
49
- Parameter count = 44
50
-
51
- mu_trial = f(intercept_subject, trial[6] × subject[20]
52
- beta_subject,
53
- global_intercept)
54
-
55
- beta_subject_penalty = Potential(f(beta_subject)) subject[20]
56
-
57
- y ~ Normal(mu_trial, noise) trial[6] × subject[20]
58
- """
59
- assert [s.strip() for s in table_txt.splitlines()] == [s.strip() for s in expected.splitlines()]
60
-
61
- table_txt = get_text(model_table(model, split_groups=False))
62
- expected = """ Variable Expression Dimensions
63
- ────────────────────────────────────────────────────────────────────────────────
64
- x_data = Data trial[6] × subject[20]
65
- y_data = Data trial[6] × subject[20]
66
- mu ~ Normal(0, 1)
67
- sigma ~ HalfNormal(0, 1)
68
- global_intercept ~ Normal(0, 1)
69
- intercept_subject ~ Normal(0, 1) [20, 1]
70
- beta_subject ~ Normal(mu, sigma) subject[20]
71
- mu_trial = f(intercept_subject, trial[6] × subject[20]
72
- beta_subject,
73
- global_intercept)
74
- noise ~ Exponential(f())
75
- y ~ Normal(mu_trial, noise) trial[6] × subject[20]
76
- beta_subject_penalty = Potential(f(beta_subject)) subject[20]
77
- Parameter count = 44
78
- """
79
- assert [s.strip() for s in table_txt.splitlines()] == [s.strip() for s in expected.splitlines()]
80
-
81
- table_txt = get_text(
82
- model_table(model, split_groups=False, truncate_deterministic=30, parameter_count=False)
83
- )
84
- expected = """ Variable Expression Dimensions
85
- ────────────────────────────────────────────────────────────────────────────
86
- x_data = Data trial[6] × subject[20]
87
- y_data = Data trial[6] × subject[20]
88
- mu ~ Normal(0, 1)
89
- sigma ~ HalfNormal(0, 1)
90
- global_intercept ~ Normal(0, 1)
91
- intercept_subject ~ Normal(0, 1) [20, 1]
92
- beta_subject ~ Normal(mu, sigma) subject[20]
93
- mu_trial = f(intercept_subject, ...) trial[6] × subject[20]
94
- noise ~ Exponential(f())
95
- y ~ Normal(mu_trial, noise) trial[6] × subject[20]
96
- beta_subject_penalty = Potential(f(beta_subject)) subject[20]
97
- """
98
- assert [s.strip() for s in table_txt.splitlines()] == [s.strip() for s in expected.splitlines()]
@@ -1,172 +0,0 @@
1
- # Copyright 2022 The PyMC Developers
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
-
16
- import arviz as az
17
- import numpy as np
18
- import pymc as pm
19
- import pytest
20
-
21
- from pymc.distributions import transforms
22
-
23
- import pymc_extras as pmx
24
-
25
-
26
- @pytest.mark.parametrize(
27
- "case",
28
- [
29
- (("a", dict(name="b")), dict(name="b", transform=None, dims=None)),
30
- (("a", None), dict(name="a", transform=None, dims=None)),
31
- (("a", transforms.log), dict(name="a", transform=transforms.log, dims=None)),
32
- (
33
- ("a", dict(transform=transforms.log)),
34
- dict(name="a", transform=transforms.log, dims=None),
35
- ),
36
- (("a", dict(name="b")), dict(name="b", transform=None, dims=None)),
37
- (("a", dict(name="b", dims="test")), dict(name="b", transform=None, dims="test")),
38
- (("a", ("test",)), dict(name="a", transform=None, dims=("test",))),
39
- ],
40
- )
41
- def test_parsing_arguments(case):
42
- inp, out = case
43
- test = pmx.utils.prior._arg_to_param_cfg(*inp)
44
- assert test == out
45
-
46
-
47
- @pytest.fixture
48
- def coords():
49
- return dict(test=range(3), simplex=range(4))
50
-
51
-
52
- @pytest.fixture(
53
- params=[
54
- [
55
- ("t",),
56
- dict(
57
- a="d",
58
- b=dict(transform=transforms.log, dims=("test",)),
59
- c=dict(transform=transforms.simplex, dims=("simplex",)),
60
- ),
61
- ],
62
- [("t",), dict()],
63
- ]
64
- )
65
- def user_param_cfg(request):
66
- return request.param
67
-
68
-
69
- @pytest.fixture
70
- def param_cfg(user_param_cfg):
71
- return pmx.utils.prior._parse_args(user_param_cfg[0], **user_param_cfg[1])
72
-
73
-
74
- @pytest.fixture
75
- def transformed_data(param_cfg, coords):
76
- vars = dict()
77
- for k, cfg in param_cfg.items():
78
- if cfg["dims"] is not None:
79
- extra_dims = [len(coords[d]) for d in cfg["dims"]]
80
- if cfg["transform"] is not None:
81
- t = np.random.randn(*extra_dims)
82
- extra_dims = tuple(cfg["transform"].forward(t).shape.eval())
83
- else:
84
- extra_dims = []
85
- orig = np.random.randn(4, 100, *extra_dims)
86
- vars[k] = orig
87
- return vars
88
-
89
-
90
- @pytest.fixture
91
- def idata(transformed_data, param_cfg):
92
- vars = dict()
93
- for k, orig in transformed_data.items():
94
- cfg = param_cfg[k]
95
- if cfg["transform"] is not None:
96
- var = cfg["transform"].backward(orig).eval()
97
- else:
98
- var = orig
99
- assert not np.isnan(var).any()
100
- vars[k] = var
101
- return az.convert_to_inference_data(vars)
102
-
103
-
104
- def test_idata_for_tests(idata, param_cfg):
105
- assert set(idata.posterior.keys()) == set(param_cfg)
106
- assert len(idata.posterior.coords["chain"]) == 4
107
- assert len(idata.posterior.coords["draw"]) == 100
108
-
109
-
110
- def test_args_compose():
111
- cfg = pmx.utils.prior._parse_args(
112
- var_names=["a"],
113
- b=("test",),
114
- c=transforms.log,
115
- d="e",
116
- f=dict(dims="test"),
117
- g=dict(name="h", dims="test", transform=transforms.log),
118
- )
119
- assert cfg == dict(
120
- a=dict(name="a", dims=None, transform=None),
121
- b=dict(name="b", dims=("test",), transform=None),
122
- c=dict(name="c", dims=None, transform=transforms.log),
123
- d=dict(name="e", dims=None, transform=None),
124
- f=dict(name="f", dims="test", transform=None),
125
- g=dict(name="h", dims="test", transform=transforms.log),
126
- )
127
-
128
-
129
- def test_transform_idata(transformed_data, idata, param_cfg):
130
- flat_info = pmx.utils.prior._flatten(idata, **param_cfg)
131
- expected_shape = 0
132
- for v in transformed_data.values():
133
- expected_shape += int(np.prod(v.shape[2:]))
134
- assert flat_info["data"].shape[1] == expected_shape
135
- assert len(flat_info["info"]) == len(param_cfg)
136
- assert "sinfo" in flat_info["info"][0]
137
- assert "vinfo" in flat_info["info"][0]
138
-
139
-
140
- @pytest.fixture
141
- def flat_info(idata, param_cfg):
142
- return pmx.utils.prior._flatten(idata, **param_cfg)
143
-
144
-
145
- def test_mean_chol(flat_info):
146
- mean, chol = pmx.utils.prior._mean_chol(flat_info["data"])
147
- assert mean.shape == (flat_info["data"].shape[1],)
148
- assert chol.shape == (flat_info["data"].shape[1],) * 2
149
-
150
-
151
- def test_mvn_prior_from_flat_info(flat_info, coords, param_cfg):
152
- with pm.Model(coords=coords) as model:
153
- priors = pmx.utils.prior._mvn_prior_from_flat_info("trace_prior_", flat_info)
154
- test_prior = pm.sample_prior_predictive(1)
155
- names = [p["name"] for p in param_cfg.values()]
156
- assert set(model.named_vars) == {"trace_prior_", *names}
157
-
158
-
159
- def test_prior_from_idata(idata, user_param_cfg, coords, param_cfg):
160
- with pm.Model(coords=coords) as model:
161
- priors = pmx.utils.prior.prior_from_idata(
162
- idata, var_names=user_param_cfg[0], **user_param_cfg[1]
163
- )
164
- test_prior = pm.sample_prior_predictive(1)
165
- names = [p["name"] for p in param_cfg.values()]
166
- assert set(model.named_vars) == {"trace_prior_", *names}
167
-
168
-
169
- def test_empty(idata, coords):
170
- with pm.Model(coords=coords):
171
- priors = pmx.utils.prior.prior_from_idata(idata)
172
- assert not priors
tests/test_splines.py DELETED
@@ -1,77 +0,0 @@
1
- # Copyright 2022 The PyMC Developers
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
-
16
- import numpy as np
17
- import pytensor.tensor as pt
18
- import pytest
19
-
20
- from pytensor.sparse import SparseTensorType
21
-
22
- import pymc_extras as pmx
23
-
24
-
25
- @pytest.mark.parametrize("dtype", [np.float32, np.float64])
26
- @pytest.mark.parametrize("sparse", [True, False])
27
- def test_spline_construction(dtype, sparse):
28
- x = np.linspace(0, 1, 20, dtype=dtype)
29
- np_out = pmx.utils.spline.numpy_bspline_basis(x, 10, 3)
30
- assert np_out.shape == (20, 10)
31
- assert np_out.dtype == dtype
32
- spline_op = pmx.utils.spline.BSplineBasis(sparse=sparse)
33
- out = spline_op(x, pt.constant(10), pt.constant(3))
34
- if not sparse:
35
- assert isinstance(out.type, pt.TensorType)
36
- else:
37
- assert isinstance(out.type, SparseTensorType)
38
- B = out.eval()
39
- if not sparse:
40
- np.testing.assert_allclose(B, np_out)
41
- else:
42
- np.testing.assert_allclose(B.todense(), np_out)
43
- assert B.shape == (20, 10)
44
-
45
-
46
- @pytest.mark.parametrize("shape", [(100,), (100, 5)])
47
- @pytest.mark.parametrize("sparse", [True, False])
48
- @pytest.mark.parametrize("points", [dict(n=1001), dict(eval_points=np.linspace(0, 1, 1001))])
49
- def test_interpolation_api(shape, sparse, points):
50
- x = np.random.randn(*shape)
51
- yt = pmx.utils.spline.bspline_interpolation(x, **points, sparse=sparse)
52
- y = yt.eval()
53
- assert y.shape == (1001, *shape[1:])
54
-
55
-
56
- @pytest.mark.parametrize(
57
- "params",
58
- [
59
- (dict(sparse="foo", n=100, degree=1), TypeError, "sparse should be True or False"),
60
- (dict(n=100, degree=0.5), TypeError, "degree should be integer"),
61
- (
62
- dict(n=100, eval_points=np.linspace(0, 1), degree=1),
63
- ValueError,
64
- "Please provide one of n or eval_points",
65
- ),
66
- (
67
- dict(degree=1),
68
- ValueError,
69
- "Please provide one of n or eval_points",
70
- ),
71
- ],
72
- )
73
- def test_bad_calls(params):
74
- kw, E, err = params
75
- x = np.random.randn(10)
76
- with pytest.raises(E, match=err):
77
- pmx.utils.spline.bspline_interpolation(x, **kw)
tests/utils.py DELETED
File without changes