pymc-extras 0.2.4__py3-none-any.whl → 0.2.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pymc_extras/__init__.py +1 -3
- pymc_extras/distributions/__init__.py +2 -0
- pymc_extras/distributions/transforms/__init__.py +3 -0
- pymc_extras/distributions/transforms/partial_order.py +227 -0
- pymc_extras/inference/__init__.py +4 -2
- pymc_extras/inference/fit.py +6 -4
- pymc_extras/inference/laplace.py +4 -1
- pymc_extras/inference/pathfinder/lbfgs.py +49 -13
- pymc_extras/inference/pathfinder/pathfinder.py +89 -103
- pymc_extras/statespace/utils/data_tools.py +24 -9
- pymc_extras/version.txt +1 -1
- {pymc_extras-0.2.4.dist-info → pymc_extras-0.2.5.dist-info}/METADATA +4 -2
- {pymc_extras-0.2.4.dist-info → pymc_extras-0.2.5.dist-info}/RECORD +20 -17
- {pymc_extras-0.2.4.dist-info → pymc_extras-0.2.5.dist-info}/WHEEL +1 -1
- tests/distributions/test_transform.py +77 -0
- tests/statespace/test_coord_assignment.py +65 -0
- tests/test_laplace.py +16 -0
- tests/test_pathfinder.py +101 -7
- {pymc_extras-0.2.4.dist-info → pymc_extras-0.2.5.dist-info/licenses}/LICENSE +0 -0
- {pymc_extras-0.2.4.dist-info → pymc_extras-0.2.5.dist-info}/top_level.txt +0 -0
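The two user-facing additions visible in the test diffs below are the new `PartialOrder` transform and an `initvals` argument for the Pathfinder fit. A minimal usage sketch assembled from those tests (an illustration only, not official documentation; shapes and defaults are assumptions):

```python
import numpy as np
import pymc as pm
import pymc_extras as pmx

from pymc_extras.distributions.transforms import PartialOrder

# adj[i, j] = 1 encodes the constraint value[i] < value[j]; this matrix says 0 < {1, 2} < 3
adj = np.array([[0, 1, 1, 0], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 0]])
po = PartialOrder(adj)

with pm.Model():
    # x is sampled subject to the partial-order constraint along its last axis
    pm.Normal("x", size=4, transform=po, initval=po.initvals())
    idata = pm.sample()

with pm.Model():
    pm.Normal("ordered", size=10, transform=pm.distributions.transforms.ordered)
    # exercised by the new test_pathfinder_initvals below: initial values that satisfy the transform
    idata_pf = pmx.fit_pathfinder(initvals={"ordered": np.linspace(0, 1, 10)})
```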
{pymc_extras-0.2.4.dist-info → pymc_extras-0.2.5.dist-info}/RECORD CHANGED

@@ -1,26 +1,28 @@
-pymc_extras/__init__.py,sha256=
+pymc_extras/__init__.py,sha256=lYGf9TcwUHROIElkX7Epnb7-IppcmiSEYuxdtRzqS3s,1195
 pymc_extras/linearmodel.py,sha256=6eitl15Ec15mSZu7zoHZ7Wwy4U1DPwqfAgwEt6ILeIc,3920
 pymc_extras/model_builder.py,sha256=sAw77fxdiy046BvDPjocuMlbJ0Efj-CDAGtmcwYmoG0,26361
 pymc_extras/printing.py,sha256=G8mj9dRd6i0PcsbcEWZm56ek6V8mmil78RI4MUhywBs,6506
 pymc_extras/version.py,sha256=VxPGCBzhtSegu-Jp5cjzn0n4DGU0wuPUh-KyZKB6uPM,240
-pymc_extras/version.txt,sha256=
-pymc_extras/distributions/__init__.py,sha256=
+pymc_extras/version.txt,sha256=6Vn3UOktu3YUriislvCjcnLK7YHYu7dYeRr3v7thBqA,6
+pymc_extras/distributions/__init__.py,sha256=fDbrBt9mxEVp2CDPwnyCW3oiutzZ0PduB8EUH3fUrjI,1377
 pymc_extras/distributions/continuous.py,sha256=z-nvQgGncYISdRY8cWsa-56V0bQGq70jYwU-i8VZ0Uk,11253
 pymc_extras/distributions/discrete.py,sha256=vrARNuiQAEXrs7yQgImV1PO8AV1uyEC_LBhr6F9IcOg,13032
 pymc_extras/distributions/histogram_utils.py,sha256=5RTvlGCUrp2qzshrchmPyWxjhs6RIYL62SMikjDM1TU,5814
 pymc_extras/distributions/timeseries.py,sha256=M5MZ-nik_tgkaoZ1hdUGEZ9g04DQyVLwszVJqSKwNcY,12719
 pymc_extras/distributions/multivariate/__init__.py,sha256=E8OeLW9tTotCbrUjEo4um76-_WQD56PehsPzkKmhfyA,93
 pymc_extras/distributions/multivariate/r2d2m2cp.py,sha256=bUj9bB-hQi6CpaJfvJjgNPi727uTbvAdxl9fm1zNBqY,16005
+pymc_extras/distributions/transforms/__init__.py,sha256=FUp2vyRE6_2eUcQ_FVt5Dn0-vy5I-puV-Kz13-QtLNc,104
+pymc_extras/distributions/transforms/partial_order.py,sha256=oEZlc9WgnGR46uFEjLzKEUxlhzIo2vrUUbBE3vYrsfQ,8404
 pymc_extras/gp/__init__.py,sha256=sFHw2y3lEl5tG_FDQHZUonQ_k0DF1JRf0Rp8dpHmge0,745
 pymc_extras/gp/latent_approx.py,sha256=cDEMM6H1BL2qyKg7BZU-ISrKn2HJe7hDaM4Y8GgQDf4,6682
-pymc_extras/inference/__init__.py,sha256=
+pymc_extras/inference/__init__.py,sha256=UH6S0bGfQKKyTSuqf7yezdy9PeE2bDU8U1v4eIRv4ZI,887
 pymc_extras/inference/find_map.py,sha256=vl5l0ei48PnX-uTuHVTr-9QpCEHc8xog-KK6sOnJ8LU,16513
-pymc_extras/inference/fit.py,sha256=
-pymc_extras/inference/laplace.py,sha256=
+pymc_extras/inference/fit.py,sha256=oe20RAajImZ-VD9Ucbzri8Bof4Y2KHNhNRG19v9O3lI,1336
+pymc_extras/inference/laplace.py,sha256=cqarAdbFaOH74AkPUF4c7c_Hswa5mqmhgHpsgrkebHY,21860
 pymc_extras/inference/pathfinder/__init__.py,sha256=FhAYrCWNx_dCrynEdjg2CZ9tIinvcVLBm67pNx_Y3kA,101
 pymc_extras/inference/pathfinder/importance_sampling.py,sha256=NwxepXOFit3cA5zEebniKdlnJ1rZWg56aMlH4MEOcG4,6264
-pymc_extras/inference/pathfinder/lbfgs.py,sha256=
-pymc_extras/inference/pathfinder/pathfinder.py,sha256=
+pymc_extras/inference/pathfinder/lbfgs.py,sha256=GOoJBil5Kft_iFwGNUGKSeqzI5x_shA4KQWDwgGuQtQ,7110
+pymc_extras/inference/pathfinder/pathfinder.py,sha256=GW04HQurj_3Nlo1C6_K2tEIeigo8x0buV3FqDLA88PQ,64439
 pymc_extras/inference/smc/__init__.py,sha256=wyaT4NJl1YsSQRLiDy-i0Jq3CbJZ2BQd4nnCk-dIngY,603
 pymc_extras/inference/smc/sampling.py,sha256=AYwmKqGoV6pBtKnh9SUbBKbN7VcoFgb3MmNWV7SivMA,15365
 pymc_extras/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -52,21 +54,22 @@ pymc_extras/statespace/models/utilities.py,sha256=G9GuHKsghmIYOlfkPtvxBWF-FZY5-5
 pymc_extras/statespace/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 pymc_extras/statespace/utils/constants.py,sha256=Kf6j75ABaDQeRODxKQ76wTUQV4F5sTjn1KBcZgCQx20,2403
 pymc_extras/statespace/utils/coord_tools.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-pymc_extras/statespace/utils/data_tools.py,sha256=
+pymc_extras/statespace/utils/data_tools.py,sha256=01sz6XDtLYK9I5xghxYpD-PuDzGXv9D-wFGfTV6FGEw,6566
 pymc_extras/utils/__init__.py,sha256=yxI9cJ7fCtVQS0GFw0y6mDGZIQZiK53vm3UNKqIuGSk,758
 pymc_extras/utils/linear_cg.py,sha256=KkXhuimFsrKtNd_0By2ApxQQQNm5FdBtmDQJOVbLYkA,10056
 pymc_extras/utils/model_equivalence.py,sha256=8QIftID2HDxD659i0RXHazQ-l2Q5YegCRLcDqb2p9Pc,2187
 pymc_extras/utils/pivoted_cholesky.py,sha256=QtnjP0pAl9b77fLAu-semwT4_9dcoiqx3dz1xKGBjMk,1871
 pymc_extras/utils/prior.py,sha256=QlWVr7uKIK9VncBw7Fz3YgaASKGDfqpORZHc-vz_9gQ,6841
 pymc_extras/utils/spline.py,sha256=qGq0gcoMG5dpdazKFzG0RXkkCWP8ADPPXN-653-oFn4,4820
+pymc_extras-0.2.5.dist-info/licenses/LICENSE,sha256=WjiLhUKEysJvy5e9jk6WwFv9tmAPtnov1uJ6gcH1kIs,11720
 tests/__init__.py,sha256=-ree9OWVCyTeXLR944OWjrQX2os15HXrRNkhJ7QdRjc,603
 tests/test_blackjax_smc.py,sha256=jcNgcMBxaKyPg9UvHnWQtwoL79LXlSpZfALe3RGEZnQ,7233
 tests/test_find_map.py,sha256=B8ThnXNyfTQeem24QaLoTitFrsxKoq2VQINUdOwzna0,3379
 tests/test_histogram_approximation.py,sha256=w-xb2Rr0Qft6sm6F3BTmXXnpuqyefC1SUL6YxzqA5X4,4674
-tests/test_laplace.py,sha256=
+tests/test_laplace.py,sha256=fArHjwMR7x98K-gZLvrvb3AwNZ7_fo7E0A4SJyt4EGU,9843
 tests/test_linearmodel.py,sha256=iB8ApNqIX9_nUHoo-Tm51xuPdrva5t4VLLut6qXB5Ao,6906
 tests/test_model_builder.py,sha256=QiINEihBR9rx8xM4Nqlg4urZKoyo58aTKDtxl9SJF1s,11249
-tests/test_pathfinder.py,sha256
+tests/test_pathfinder.py,sha256=vlMI1p2Ja5X4QIaSV4h6U41I303rEppfO0JqE3xe1Rs,10023
 tests/test_pivoted_cholesky.py,sha256=PuMdMSCzO4KdQWpUF4SEBeuH_qsINCIH8TYtmmJ1NKo,692
 tests/test_printing.py,sha256=HnvwwjrjBuxXFAJdyU0K_lvKGLgh4nzHAnhsIUpenbY,5211
 tests/test_prior_from_trace.py,sha256=HOzR3l98pl7TEJquo_kSugED4wBTgHo4-8lgnpmacs8,5516
@@ -77,6 +80,7 @@ tests/distributions/test_continuous.py,sha256=1-bu-IP6RgLUJnuPYpOD8ZS1ahYbKtsJ9o
 tests/distributions/test_discrete.py,sha256=CjjaUpppsvQ6zLzV15ZsbwNOKrDmEdz4VWcleoCXUi0,7776
 tests/distributions/test_discrete_markov_chain.py,sha256=8RCHZXSB8IWjniuKaGGlM_iTWGmdrcOqginxmrAeEJg,9212
 tests/distributions/test_multivariate.py,sha256=LBvBuoT_3rzi8rR38b8L441Y-9Ff0cIXeRBKiEn6kjs,10452
+tests/distributions/test_transform.py,sha256=QM9sSQ5eSbuT2pM76nUMWqb-tQa7DGZbT9uwFDqIRUk,2672
 tests/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 tests/model/test_model_api.py,sha256=FJvMTmexovRELZOUcUyk-6Vwk9qSiH7hIFoiArgl5mk,1040
 tests/model/marginal/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -87,7 +91,7 @@ tests/statespace/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,
 tests/statespace/test_ETS.py,sha256=IPg3uQ7xEGqDMEHu993vtUTV7r-uNAxmw23sr5MVGfQ,15582
 tests/statespace/test_SARIMAX.py,sha256=1BYNOm9aSHnpn-qbpe3YsQVH8m-mXcp_gvKgWhWn1W4,12948
 tests/statespace/test_VARMAX.py,sha256=rJnea9_WEGo9I0iv2eaSbwwFQv0tlIjpN7KAE0eQewU,6336
-tests/statespace/test_coord_assignment.py,sha256=
+tests/statespace/test_coord_assignment.py,sha256=2Mo5196ibkBTwscE7kqQoUsgQphdaagVkOccDi7D4RI,5980
 tests/statespace/test_distributions.py,sha256=WQ_ROyd-PL3cimXTyEtyVaMEVtS7Hue2Z0lN7UnGDyo,9122
 tests/statespace/test_kalman_filter.py,sha256=s2n62FzXl9elU_uqaMNaEaexUfq3SXe3_YvQ2lM6hiQ,11600
 tests/statespace/test_representation.py,sha256=1KAJY4ZaVhb1WdAJLx2UYSXuVYsMNWX98gEDF7P0B4s,6210
@@ -98,8 +102,7 @@ tests/statespace/utilities/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJ
 tests/statespace/utilities/shared_fixtures.py,sha256=SNw8Bvj1Yw11TxAW6n20Bq0B8oaYtVTiFFEVNH_wnp4,164
 tests/statespace/utilities/statsmodel_local_level.py,sha256=SQAzaYaSDwiVhUQ1iWjt4MgfAd54RuzVtnslIs3xdS8,1225
 tests/statespace/utilities/test_helpers.py,sha256=oH24a6Q45NFFFI3Kx9mhKbxsCvo9ErCorKFoTjDB3-4,9159
-pymc_extras-0.2.
-pymc_extras-0.2.
-pymc_extras-0.2.
-pymc_extras-0.2.
-pymc_extras-0.2.4.dist-info/RECORD,,
+pymc_extras-0.2.5.dist-info/METADATA,sha256=02v5liTQQ55sV8xeFl5EjFpwbOKSYKG6g5lE_4htpBo,5227
+pymc_extras-0.2.5.dist-info/WHEEL,sha256=lTU6B6eIfYoiQJTZNc-fyaR6BpL6ehTzU3xGYxn2n8k,91
+pymc_extras-0.2.5.dist-info/top_level.txt,sha256=D6RkgBiXiZCel0nvsYg_zYEoT_VuwocyIY98EMaulj0,18
+pymc_extras-0.2.5.dist-info/RECORD,,
tests/distributions/test_transform.py ADDED

@@ -0,0 +1,77 @@
+# Copyright 2025 The PyMC Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import pymc as pm
+
+from pymc_extras.distributions.transforms import PartialOrder
+
+
+class TestPartialOrder:
+    adj_mats = np.array(
+        [
+            # 0 < {1, 2} < 3
+            [[0, 1, 1, 0], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 0]],
+            # 1 < 0 < 3 < 2
+            [[0, 0, 0, 1], [1, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 0]],
+        ]
+    )
+
+    valid_values = np.array([[0, 2, 1, 3], [1, 0, 3, 2]], dtype=float)
+
+    # Test that forward and backward are inverses of eachother
+    # And that it works when extra dimensions are added in data
+    def test_forward_backward_dimensionality(self):
+        po = PartialOrder(self.adj_mats)
+        po0 = PartialOrder(self.adj_mats[0])
+        vv = self.valid_values
+        vv0 = self.valid_values[0]
+
+        testsets = [
+            (vv, po),
+            (po.initvals(), po),
+            (vv0, po0),
+            (po0.initvals(), po0),
+            (np.tile(vv0, (2, 1)), po0),
+            (np.tile(vv0, (2, 3, 2, 1)), po0),
+            (np.tile(vv, (2, 3, 2, 1, 1)), po),
+        ]
+
+        for vv, po in testsets:
+            fw = po.forward(vv)
+            bw = po.backward(fw)
+            np.testing.assert_allclose(bw.eval(), vv)
+
+    def test_sample_model(self):
+        po = PartialOrder(self.adj_mats)
+        with pm.Model() as model:
+            x = pm.Normal(
+                "x",
+                size=(3, 2, 4),
+                transform=po,
+                initval=po.initvals(shape=(3, 2, 4), lower=-1, upper=1),
+            )
+            idata = pm.sample()
+
+        # Check that the order constraints are satisfied
+        # Move chain, draw and "3" dimensions to the back
+        xvs = idata.posterior.x.values.transpose(3, 4, 0, 1, 2)
+        x0 = xvs[0]  # 0 < {1, 2} < 3
+        assert (
+            (x0[0] < x0[1]).all()
+            and (x0[0] < x0[2]).all()
+            and (x0[1] < x0[3]).all()
+            and (x0[2] < x0[3]).all()
+        )
+        x1 = xvs[1]  # 1 < 0 < 3 < 2
+        assert (x1[1] < x1[0]).all() and (x1[0] < x1[3]).all() and (x1[3] < x1[2]).all()
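For readers less familiar with the adjacency-matrix convention used in these tests, here is a small standalone check (not part of pymc-extras) that mirrors the order assertions in `test_sample_model`: `adj[i, j] = 1` means the value at position `i` must be smaller than the value at position `j`.

```python
import numpy as np

def respects_partial_order(values, adj):
    # Every edge adj[i, j] = 1 demands values[i] < values[j]
    i, j = np.nonzero(adj)
    return bool(np.all(values[i] < values[j]))

adj = np.array([[0, 1, 1, 0], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 0]])  # 0 < {1, 2} < 3
assert respects_partial_order(np.array([0.0, 2.0, 1.0, 3.0]), adj)        # matches valid_values[0]
assert not respects_partial_order(np.array([3.0, 2.0, 1.0, 0.0]), adj)    # fully reversed order
```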
tests/statespace/test_coord_assignment.py CHANGED

@@ -8,6 +8,7 @@ import pytensor.tensor as pt
 import pytest
 
 from pymc_extras.statespace.models import structural
+from pymc_extras.statespace.models.structural import LevelTrendComponent
 from pymc_extras.statespace.utils.constants import (
     FILTER_OUTPUT_DIMS,
     FILTER_OUTPUT_NAMES,
@@ -114,3 +115,67 @@ def test_data_index_is_coord(f, warning, create_model):
     with warning:
         pymc_model = create_model(f)
     assert TIME_DIM in pymc_model.coords
+
+
+def make_model(index):
+    n = len(index)
+    a = pd.DataFrame(index=index, columns=["A", "B", "C", "D"], data=np.arange(n * 4).reshape(n, 4))
+
+    mod = LevelTrendComponent(order=2, innovations_order=[0, 1])
+    ss_mod = mod.build(name="a", verbose=False)
+
+    initial_trend_dims, sigma_trend_dims, P0_dims = ss_mod.param_dims.values()
+    coords = ss_mod.coords
+
+    with pm.Model(coords=coords) as model:
+        P0_diag = pm.Gamma("P0_diag", alpha=5, beta=5)
+        P0 = pm.Deterministic("P0", pt.eye(ss_mod.k_states) * P0_diag, dims=P0_dims)
+
+        initial_trend = pm.Normal("initial_trend", dims=initial_trend_dims)
+        sigma_trend = pm.Gamma("sigma_trend", alpha=2, beta=50, dims=sigma_trend_dims)
+
+        with pytest.warns(UserWarning, match="No time index found on the supplied data"):
+            ss_mod.build_statespace_graph(
+                a["A"],
+                mode="JAX",
+            )
+    return model
+
+
+def test_integer_index():
+    index = np.arange(8).astype(int)
+    model = make_model(index)
+    assert TIME_DIM in model.coords
+    np.testing.assert_allclose(model.coords[TIME_DIM], index)
+
+
+def test_float_index_raises():
+    index = np.linspace(0, 1, 8)
+
+    with pytest.raises(IndexError, match="Provided index is not an integer index"):
+        make_model(index)
+
+
+def test_non_strictly_monotone_index_raises():
+    # Decreases
+    index = [0, 1, 2, 1, 2, 3]
+    with pytest.raises(IndexError, match="Provided index is not monotonic increasing"):
+        make_model(index)
+
+    # Has gaps
+    index = [0, 1, 2, 3, 5, 6]
+    with pytest.raises(IndexError, match="Provided index is not monotonic increasing"):
+        make_model(index)
+
+    # Has duplicates
+    index = [0, 1, 1, 2, 3, 4]
+    with pytest.raises(IndexError, match="Provided index is not monotonic increasing"):
+        make_model(index)
+
+
+def test_multiindex_raises():
+    index = pd.MultiIndex.from_tuples([(0, 0), (1, 1), (2, 2), (3, 3)])
+    with pytest.raises(
+        NotImplementedError, match="MultiIndex panel data is not currently supported"
+    ):
+        make_model(index)
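The new tests above pin down how the statespace data tools validate a plain (non-datetime) index: it must be an integer index with no decreases, gaps, or duplicates, and MultiIndex panel data is rejected. A standalone sketch of the implied rule, reconstructed from the error messages in the tests rather than from the actual `data_tools` implementation:

```python
import numpy as np
import pandas as pd

def index_is_acceptable(index: pd.Index) -> bool:
    """Assumed acceptance rule, inferred from the test error messages above."""
    if isinstance(index, pd.MultiIndex):
        return False  # "MultiIndex panel data is not currently supported"
    if not pd.api.types.is_integer_dtype(index):
        return False  # "Provided index is not an integer index"
    steps = np.diff(np.asarray(index))
    return bool((steps == 1).all())  # must increase by 1: no gaps, duplicates, or decreases

assert index_is_acceptable(pd.RangeIndex(8))
assert not index_is_acceptable(pd.Index(np.linspace(0, 1, 8)))  # float index
assert not index_is_acceptable(pd.Index([0, 1, 2, 3, 5, 6]))    # gap between 3 and 5
```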
tests/test_laplace.py CHANGED

@@ -263,3 +263,19 @@ def test_fit_laplace(fit_in_unconstrained_space, mode, gradient_backend: Gradien
     else:
         assert idata.fit.rows.values.tolist() == ["mu", "sigma"]
         np.testing.assert_allclose(idata.fit.mean_vector.values, np.array([3.0, 1.5]), atol=0.1)
+
+
+def test_laplace_scalar():
+    # Example model from Statistical Rethinking
+    data = np.array([0, 0, 0, 1, 1, 1, 1, 1, 1])
+
+    with pm.Model():
+        p = pm.Uniform("p", 0, 1)
+        w = pm.Binomial("w", n=len(data), p=p, observed=data.sum())
+
+        idata_laplace = pmx.fit_laplace(progressbar=False)
+
+    assert idata_laplace.fit.mean_vector.shape == (1,)
+    assert idata_laplace.fit.covariance_matrix.shape == (1, 1)
+
+    np.testing.assert_allclose(idata_laplace.fit.mean_vector.values.item(), data.mean(), atol=0.1)
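The expected value in `test_laplace_scalar` can be checked by hand: with a flat Uniform(0, 1) prior the posterior mode of `p` is the sample proportion w/n, which for 0/1 data is exactly `data.mean()`, and the Laplace standard deviation follows from the curvature of the log-likelihood at that mode. A quick closed-form check (illustration only, not library code; `fit_laplace` itself works numerically and may optimize in transformed space, hence the loose tolerance in the test):

```python
import numpy as np

data = np.array([0, 0, 0, 1, 1, 1, 1, 1, 1])
n, w = len(data), data.sum()

p_map = w / n  # posterior mode under a flat prior: 6/9 == data.mean()

# Negative second derivative of the Binomial log-likelihood at the mode
curvature = w / p_map**2 + (n - w) / (1 - p_map) ** 2
laplace_sd = 1 / np.sqrt(curvature)  # roughly 0.157

assert np.isclose(p_map, data.mean())
```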
tests/test_pathfinder.py CHANGED

@@ -12,14 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import re
 import sys
 
 import numpy as np
 import pymc as pm
+import pytensor.tensor as pt
 import pytest
 
-pytestmark = pytest.mark.filterwarnings("ignore:compile_pymc was renamed to compile:FutureWarning")
-
 import pymc_extras as pmx
 
 
@@ -52,7 +52,90 @@ def reference_idata():
     return idata
 
 
+def unstable_lbfgs_update_mask_model() -> pm.Model:
+    # data and model from: https://github.com/pymc-devs/pymc-extras/issues/445
+    # this scenario made LBFGS struggle leading to a lot of rejected iterations, (result.nit being moderate, but only history.count <= 1).
+    # this scenario is used to test that the LBFGS history manager is rejecting iterations as expected and PF can run to completion.
+
+    # fmt: off
+    inp = np.array([0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 2, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 2, 0, 1, 0, 0, 0, 0, 1, 1, 1, 2, 0, 1, 2, 1, 0, 1, 0, 1, 0, 1, 0])
+
+    res = np.array([[0,0,1,0,0],[0,0,1,0,0],[0,0,1,0,0],[0,1,0,0,0],[0,0,1,0,0],[0,1,0,0,0],[0,0,0,1,0],[0,0,0,1,0],[1,0,0,0,0],[0,1,0,0,0],[0,0,1,0,0],[1,0,0,0,0],[0,0,1,0,0],[0,1,0,0,0],[0,0,0,1,0],[0,0,1,0,0],[0,0,1,0,0],[0,0,0,1,0],[0,0,0,1,0],[0,1,0,0,0],[0,0,0,1,0],[0,0,1,0,0],[0,1,0,0,0],[1,0,0,0,0],[1,0,0,0,0],[0,0,1,0,0],[0,0,1,0,0],[0,0,1,0,0],[0,0,1,0,0],[0,1,0,0,0],[0,1,0,0,0],[0,0,0,1,0],[0,0,1,0,0],[0,1,0,0,0],[0,0,1,0,0],[0,0,0,1,0],[0,0,1,0,0],[1,0,0,0,0],[1,0,0,0,0],[0,0,1,0,0],[0,0,0,1,0],[0,0,0,1,0],[1,0,0,0,0],[1,0,0,0,0],[0,1,0,0,0],[1,0,0,0,0],[0,0,1,0,0],[0,0,1,0,0],[1,0,0,0,0],[0,0,0,1,0]])
+    # fmt: on
+
+    n_ordered = res.shape[1]
+    coords = {
+        "obs": np.arange(len(inp)),
+        "inp": np.arange(max(inp) + 1),
+        "outp": np.arange(res.shape[1]),
+    }
+    with pm.Model(coords=coords) as mdl:
+        mu = pm.Normal("intercept", sigma=3.5)[None]
+
+        offset = pm.Normal(
+            "offset", dims=("inp"), transform=pm.distributions.transforms.ZeroSumTransform([0])
+        )
+
+        scale = 3.5 * pm.HalfStudentT("scale", nu=5)
+        mu += (scale * offset)[inp]
+
+        phi_delta = pm.Dirichlet("phi_diffs", [1.0] * (n_ordered - 1))
+        phi = pt.concatenate([[0], pt.cumsum(phi_delta)])
+        s_mu = pm.Normal(
+            "stereotype_intercept",
+            size=n_ordered,
+            transform=pm.distributions.transforms.ZeroSumTransform([-1]),
+        )
+        fprobs = pm.math.softmax(s_mu[None, :] + phi[None, :] * mu[:, None], axis=-1)
+
+        pm.Multinomial("y_res", p=fprobs, n=np.ones(len(inp)), observed=res, dims=("obs", "outp"))
+
+    return mdl
+
+
+@pytest.mark.parametrize("jitter", [12.0, 500.0, 1000.0])
+def test_unstable_lbfgs_update_mask(capsys, jitter):
+    model = unstable_lbfgs_update_mask_model()
+
+    if jitter < 1000:
+        with model:
+            idata = pmx.fit(
+                method="pathfinder",
+                jitter=jitter,
+                random_seed=4,
+            )
+        out, err = capsys.readouterr()
+        status_pattern = [
+            r"INIT_FAILED_LOW_UPDATE_PCT\s+\d+",
+            r"LOW_UPDATE_PCT\s+\d+",
+            r"LBFGS_FAILED\s+\d+",
+            r"SUCCESS\s+\d+",
+        ]
+        for pattern in status_pattern:
+            assert re.search(pattern, out) is not None
+
+    else:
+        with pytest.raises(ValueError, match="All paths failed"):
+            with model:
+                idata = pmx.fit(
+                    method="pathfinder",
+                    jitter=1000,
+                    random_seed=2,
+                    num_paths=4,
+                )
+        out, err = capsys.readouterr()
+
+        status_pattern = [
+            r"INIT_FAILED_LOW_UPDATE_PCT\s+2",
+            r"LOW_UPDATE_PCT\s+2",
+            r"LBFGS_FAILED\s+4",
+        ]
+        for pattern in status_pattern:
+            assert re.search(pattern, out) is not None
+
+
 @pytest.mark.parametrize("inference_backend", ["pymc", "blackjax"])
+@pytest.mark.filterwarnings("ignore:JAXopt is no longer maintained.:DeprecationWarning")
 def test_pathfinder(inference_backend, reference_idata):
     if inference_backend == "blackjax" and sys.platform == "win32":
         pytest.skip("JAX not supported on windows")
@@ -149,12 +232,11 @@ def test_bfgs_sample():
     # get factors
     x_full = pt.as_tensor(x_data, dtype="float64")
     g_full = pt.as_tensor(g_data, dtype="float64")
-    epsilon = 1e-11
 
     x = x_full[1:]
     g = g_full[1:]
-    alpha,
-    beta, gamma = inverse_hessian_factors(alpha,
+    alpha, s, z = alpha_recover(x_full, g_full)
+    beta, gamma = inverse_hessian_factors(alpha, s, z, J)
 
     # sample
     phi, logq = bfgs_sample(
@@ -169,8 +251,8 @@ def test_bfgs_sample():
     # check shapes
     assert beta.eval().shape == (L, N, 2 * J)
     assert gamma.eval().shape == (L, 2 * J, 2 * J)
-    assert phi.eval()
-    assert logq.eval()
+    assert all(phi.shape.eval() == (L, num_samples, N))
+    assert all(logq.shape.eval() == (L, num_samples))
 
 
 @pytest.mark.parametrize("importance_sampling", ["psis", "psir", "identity", None])
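For context on the shape assertions in `test_bfgs_sample` (general background on the Pathfinder algorithm, not something stated in this diff): the local covariance estimate recovered from the L-BFGS history is stored in diagonal-plus-low-rank form, which is what `alpha_recover` and `inverse_hessian_factors` supply for each of the L retained iterates:

$$
\Sigma \;\approx\; \operatorname{diag}(\alpha) + \beta\,\gamma\,\beta^{\top},
\qquad
\alpha \in \mathbb{R}^{N},\quad
\beta \in \mathbb{R}^{N \times 2J},\quad
\gamma \in \mathbb{R}^{2J \times 2J},
$$

with N the number of model parameters and J the history size, matching the `(L, N, 2 * J)` and `(L, 2 * J, 2 * J)` shapes checked above.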
@@ -201,3 +283,15 @@ def test_pathfinder_importance_sampling(importance_sampling):
     assert idata.posterior["mu"].shape == (1, num_draws)
     assert idata.posterior["tau"].shape == (1, num_draws)
     assert idata.posterior["theta"].shape == (1, num_draws, 8)
+
+
+def test_pathfinder_initvals():
+    # Run a model with an ordered transform that will fail unless initvals are in place
+    with pm.Model() as mdl:
+        pm.Normal("ordered", size=10, transform=pm.distributions.transforms.ordered)
+        idata = pmx.fit_pathfinder(initvals={"ordered": np.linspace(0, 1, 10)})
+
+    # Check that the samples are ordered to make sure transform was applied
+    assert np.all(
+        idata.posterior["ordered"][..., 1:].values > idata.posterior["ordered"][..., :-1].values
+    )