pymc-extras 0.2.4__py3-none-any.whl → 0.2.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. pymc_extras/__init__.py +6 -4
  2. pymc_extras/distributions/__init__.py +2 -0
  3. pymc_extras/distributions/continuous.py +3 -2
  4. pymc_extras/distributions/discrete.py +3 -1
  5. pymc_extras/distributions/transforms/__init__.py +3 -0
  6. pymc_extras/distributions/transforms/partial_order.py +227 -0
  7. pymc_extras/inference/__init__.py +4 -2
  8. pymc_extras/inference/find_map.py +62 -17
  9. pymc_extras/inference/fit.py +6 -4
  10. pymc_extras/inference/laplace.py +14 -8
  11. pymc_extras/inference/pathfinder/lbfgs.py +49 -13
  12. pymc_extras/inference/pathfinder/pathfinder.py +89 -103
  13. pymc_extras/statespace/core/statespace.py +191 -52
  14. pymc_extras/statespace/filters/distributions.py +15 -16
  15. pymc_extras/statespace/filters/kalman_filter.py +1 -18
  16. pymc_extras/statespace/filters/kalman_smoother.py +2 -6
  17. pymc_extras/statespace/models/ETS.py +10 -0
  18. pymc_extras/statespace/models/SARIMAX.py +26 -5
  19. pymc_extras/statespace/models/VARMAX.py +12 -2
  20. pymc_extras/statespace/models/structural.py +18 -5
  21. pymc_extras/statespace/utils/data_tools.py +24 -9
  22. pymc_extras-0.2.6.dist-info/METADATA +318 -0
  23. pymc_extras-0.2.6.dist-info/RECORD +65 -0
  24. {pymc_extras-0.2.4.dist-info → pymc_extras-0.2.6.dist-info}/WHEEL +1 -2
  25. pymc_extras/version.py +0 -11
  26. pymc_extras/version.txt +0 -1
  27. pymc_extras-0.2.4.dist-info/METADATA +0 -110
  28. pymc_extras-0.2.4.dist-info/RECORD +0 -105
  29. pymc_extras-0.2.4.dist-info/top_level.txt +0 -2
  30. tests/__init__.py +0 -13
  31. tests/distributions/__init__.py +0 -19
  32. tests/distributions/test_continuous.py +0 -185
  33. tests/distributions/test_discrete.py +0 -210
  34. tests/distributions/test_discrete_markov_chain.py +0 -258
  35. tests/distributions/test_multivariate.py +0 -304
  36. tests/model/__init__.py +0 -0
  37. tests/model/marginal/__init__.py +0 -0
  38. tests/model/marginal/test_distributions.py +0 -132
  39. tests/model/marginal/test_graph_analysis.py +0 -182
  40. tests/model/marginal/test_marginal_model.py +0 -967
  41. tests/model/test_model_api.py +0 -38
  42. tests/statespace/__init__.py +0 -0
  43. tests/statespace/test_ETS.py +0 -411
  44. tests/statespace/test_SARIMAX.py +0 -405
  45. tests/statespace/test_VARMAX.py +0 -184
  46. tests/statespace/test_coord_assignment.py +0 -116
  47. tests/statespace/test_distributions.py +0 -270
  48. tests/statespace/test_kalman_filter.py +0 -326
  49. tests/statespace/test_representation.py +0 -175
  50. tests/statespace/test_statespace.py +0 -872
  51. tests/statespace/test_statespace_JAX.py +0 -156
  52. tests/statespace/test_structural.py +0 -836
  53. tests/statespace/utilities/__init__.py +0 -0
  54. tests/statespace/utilities/shared_fixtures.py +0 -9
  55. tests/statespace/utilities/statsmodel_local_level.py +0 -42
  56. tests/statespace/utilities/test_helpers.py +0 -310
  57. tests/test_blackjax_smc.py +0 -222
  58. tests/test_find_map.py +0 -103
  59. tests/test_histogram_approximation.py +0 -109
  60. tests/test_laplace.py +0 -265
  61. tests/test_linearmodel.py +0 -208
  62. tests/test_model_builder.py +0 -306
  63. tests/test_pathfinder.py +0 -203
  64. tests/test_pivoted_cholesky.py +0 -24
  65. tests/test_printing.py +0 -98
  66. tests/test_prior_from_trace.py +0 -172
  67. tests/test_splines.py +0 -77
  68. tests/utils.py +0 -0
  69. {pymc_extras-0.2.4.dist-info → pymc_extras-0.2.6.dist-info/licenses}/LICENSE +0 -0
tests/test_prior_from_trace.py DELETED
@@ -1,172 +0,0 @@
1
- # Copyright 2022 The PyMC Developers
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
-
16
- import arviz as az
17
- import numpy as np
18
- import pymc as pm
19
- import pytest
20
-
21
- from pymc.distributions import transforms
22
-
23
- import pymc_extras as pmx
24
-
25
-
26
@pytest.mark.parametrize(
    "case",
    [
        (("a", dict(name="b")), dict(name="b", transform=None, dims=None)),
        (("a", None), dict(name="a", transform=None, dims=None)),
        (("a", transforms.log), dict(name="a", transform=transforms.log, dims=None)),
        (
            ("a", dict(transform=transforms.log)),
            dict(name="a", transform=transforms.log, dims=None),
        ),
        (("a", dict(name="b", dims="test")), dict(name="b", transform=None, dims="test")),
        (("a", ("test",)), dict(name="a", transform=None, dims=("test",))),
    ],
)
def test_parsing_arguments(case):
    """Check that ``_arg_to_param_cfg`` normalizes each supported argument form.

    Each case is ``((var_name, user_spec), expected_cfg)``. The cases cover a
    ``user_spec`` that is ``None``, a transform, a dims tuple, or a dict of
    ``name``/``dims``/``transform`` overrides. Each distinct form appears once,
    because duplicated parametrize entries only re-run identical work.
    """
    inp, expected = case
    assert pmx.utils.prior._arg_to_param_cfg(*inp) == expected
-
46
-
47
@pytest.fixture
def coords():
    """Model coordinates: a 3-element ``test`` dim and a 4-element ``simplex`` dim."""
    return {"test": range(3), "simplex": range(4)}
-
51
-
52
@pytest.fixture(
    params=[
        # var_names plus keyword overrides: a plain rename ("a" -> "d"), a log
        # transform with dims, and a simplex transform with dims.
        [
            ("t",),
            {
                "a": "d",
                "b": {"transform": transforms.log, "dims": ("test",)},
                "c": {"transform": transforms.simplex, "dims": ("simplex",)},
            },
        ],
        # var_names only, without any keyword overrides.
        [("t",), {}],
    ]
)
def user_param_cfg(request):
    """User-facing ``[var_names, overrides]`` pairs, as passed to ``prior_from_idata``."""
    return request.param
-
68
-
69
@pytest.fixture
def param_cfg(user_param_cfg):
    """Normalized per-variable configuration produced by ``_parse_args``."""
    var_names, overrides = user_param_cfg
    return pmx.utils.prior._parse_args(var_names, **overrides)
-
73
-
74
@pytest.fixture
def transformed_data(param_cfg, coords):
    """Generate random draws in each variable's transformed space.

    Every array has shape ``(4, 100, *event_shape)``, where ``event_shape`` is
    computed as follows:

    * With no ``dims``, it is empty.
    * Otherwise it comes from the coordinate lengths of the variable's dims.
    * When a transform is configured, the shape is then passed through that
      transform's forward map.

    A seeded generator keeps the fixture reproducible across runs.
    """
    rng = np.random.default_rng(20220101)
    data = {}
    for name, cfg in param_cfg.items():
        dims = cfg["dims"]
        if dims is None:
            event_shape = ()
        else:
            # _parse_args can keep a single dim as a bare string; iterating it
            # directly would walk its characters instead of the dim name.
            if isinstance(dims, str):
                dims = (dims,)
            event_shape = tuple(len(coords[dim]) for dim in dims)
            if cfg["transform"] is not None:
                # A transform may change the event shape (presumably the simplex
                # transform maps K values to K-1), so ask it for the new shape.
                sample = rng.standard_normal(event_shape)
                event_shape = tuple(cfg["transform"].forward(sample).shape.eval())
        data[name] = rng.standard_normal((4, 100, *event_shape))
    return data
-
89
-
90
@pytest.fixture
def idata(transformed_data, param_cfg):
    """Build InferenceData from ``transformed_data``, mapped back through each transform.

    Variables without a transform are stored unchanged. The result is built with
    ``az.convert_to_inference_data``, so the draws land in its posterior group.
    """
    posterior = {}
    for name, draws in transformed_data.items():
        transform = param_cfg[name]["transform"]
        values = draws if transform is None else transform.backward(draws).eval()
        # Guard the fixture itself: a backward transform that produced NaNs
        # would make every downstream test meaningless.
        assert not np.isnan(values).any()
        posterior[name] = values
    return az.convert_to_inference_data(posterior)
-
103
-
104
def test_idata_for_tests(idata, param_cfg):
    """Sanity-check the synthetic fixture: one posterior variable per config entry, 4 chains x 100 draws."""
    posterior = idata.posterior
    assert set(posterior.keys()) == set(param_cfg)
    for dim, expected_len in (("chain", 4), ("draw", 100)):
        assert len(posterior.coords[dim]) == expected_len
-
109
-
110
def test_args_compose():
    """``_parse_args`` accepts every argument form and fills in name, dims and transform."""
    expected = {
        "a": {"name": "a", "dims": None, "transform": None},
        "b": {"name": "b", "dims": ("test",), "transform": None},
        "c": {"name": "c", "dims": None, "transform": transforms.log},
        "d": {"name": "e", "dims": None, "transform": None},
        "f": {"name": "f", "dims": "test", "transform": None},
        "g": {"name": "h", "dims": "test", "transform": transforms.log},
    }
    parsed = pmx.utils.prior._parse_args(
        var_names=["a"],
        b=("test",),
        c=transforms.log,
        d="e",
        f={"dims": "test"},
        g={"name": "h", "dims": "test", "transform": transforms.log},
    )
    assert parsed == expected
-
128
-
129
def test_transform_idata(transformed_data, idata, param_cfg):
    """Flattened data has one column per element of every variable's transformed event shape."""
    flat_info = pmx.utils.prior._flatten(idata, **param_cfg)
    n_columns = sum(int(np.prod(arr.shape[2:])) for arr in transformed_data.values())
    assert flat_info["data"].shape[1] == n_columns

    info = flat_info["info"]
    assert len(info) == len(param_cfg)
    first_entry = info[0]
    assert "sinfo" in first_entry
    assert "vinfo" in first_entry
-
139
-
140
@pytest.fixture
def flat_info(idata, param_cfg):
    """Output of ``_flatten`` on the synthetic idata, shared by the downstream tests."""
    flattened = pmx.utils.prior._flatten(idata, **param_cfg)
    return flattened
-
144
-
145
def test_mean_chol(flat_info):
    """``_mean_chol`` returns a length-D mean and a D x D matrix (presumably the Cholesky factor)."""
    samples = flat_info["data"]
    n_columns = samples.shape[1]
    mean, chol = pmx.utils.prior._mean_chol(samples)
    assert mean.shape == (n_columns,)
    assert chol.shape == (n_columns, n_columns)
-
150
-
151
def test_mvn_prior_from_flat_info(flat_info, coords, param_cfg):
    """Building the MVN prior registers ``trace_prior_`` plus one named variable per config entry."""
    with pm.Model(coords=coords) as model:
        # Only the side effect on the model matters here; the return value is unused.
        pmx.utils.prior._mvn_prior_from_flat_info("trace_prior_", flat_info)
        # Smoke test: the resulting model must be forward-samplable.
        pm.sample_prior_predictive(1)
    names = [p["name"] for p in param_cfg.values()]
    assert set(model.named_vars) == {"trace_prior_", *names}
-
158
-
159
def test_prior_from_idata(idata, user_param_cfg, coords, param_cfg):
    """The public entry point registers ``trace_prior_`` plus one variable per configured name."""
    var_names, overrides = user_param_cfg
    with pm.Model(coords=coords) as model:
        # Only the side effect on the model matters here; the return value is unused.
        pmx.utils.prior.prior_from_idata(idata, var_names=var_names, **overrides)
        # Smoke test: the resulting model must be forward-samplable.
        pm.sample_prior_predictive(1)
    names = [p["name"] for p in param_cfg.values()]
    assert set(model.named_vars) == {"trace_prior_", *names}
-
168
-
169
def test_empty(idata, coords):
    """Calling ``prior_from_idata`` without selecting any variables yields no priors."""
    with pm.Model(coords=coords):
        result = pmx.utils.prior.prior_from_idata(idata)
    assert not result
tests/test_splines.py DELETED
@@ -1,77 +0,0 @@
1
- # Copyright 2022 The PyMC Developers
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
-
16
- import numpy as np
17
- import pytensor.tensor as pt
18
- import pytest
19
-
20
- from pytensor.sparse import SparseTensorType
21
-
22
- import pymc_extras as pmx
23
-
24
-
25
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
@pytest.mark.parametrize("sparse", [True, False])
def test_spline_construction(dtype, sparse):
    """The NumPy reference basis and the ``BSplineBasis`` op agree, for dense and sparse output."""
    grid = np.linspace(0, 1, 20, dtype=dtype)
    reference = pmx.utils.spline.numpy_bspline_basis(grid, 10, 3)
    assert reference.shape == (20, 10)
    assert reference.dtype == dtype

    basis_op = pmx.utils.spline.BSplineBasis(sparse=sparse)
    symbolic = basis_op(grid, pt.constant(10), pt.constant(3))
    assert isinstance(symbolic.type, SparseTensorType if sparse else pt.TensorType)

    evaluated = symbolic.eval()
    # Sparse output has to be densified before an elementwise comparison.
    np.testing.assert_allclose(evaluated.todense() if sparse else evaluated, reference)
    assert evaluated.shape == (20, 10)
-
45
-
46
@pytest.mark.parametrize("shape", [(100,), (100, 5)])
@pytest.mark.parametrize("sparse", [True, False])
@pytest.mark.parametrize("points", [dict(n=1001), dict(eval_points=np.linspace(0, 1, 1001))])
def test_interpolation_api(shape, sparse, points):
    """Interpolation resamples the leading axis to 1001 points and keeps any trailing dims."""
    samples = np.random.randn(*shape)
    interpolated = pmx.utils.spline.bspline_interpolation(samples, sparse=sparse, **points).eval()
    _, *trailing = shape
    assert interpolated.shape == (1001, *trailing)
-
55
-
56
@pytest.mark.parametrize(
    "params",
    [
        ({"sparse": "foo", "n": 100, "degree": 1}, TypeError, "sparse should be True or False"),
        ({"n": 100, "degree": 0.5}, TypeError, "degree should be integer"),
        (
            {"n": 100, "eval_points": np.linspace(0, 1), "degree": 1},
            ValueError,
            "Please provide one of n or eval_points",
        ),
        ({"degree": 1}, ValueError, "Please provide one of n or eval_points"),
    ],
)
def test_bad_calls(params):
    """Invalid argument combinations raise the expected exception type and message."""
    kwargs, expected_exc, message = params
    samples = np.random.randn(10)
    with pytest.raises(expected_exc, match=message):
        pmx.utils.spline.bspline_interpolation(samples, **kwargs)
tests/utils.py DELETED
File without changes