pymc-extras 0.2.4__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,26 +1,28 @@
1
- pymc_extras/__init__.py,sha256=IFIEZdPX_Ugq57Bu7jlyrJLpKng-P0FBAAAzl2pFXLE,1266
1
+ pymc_extras/__init__.py,sha256=lYGf9TcwUHROIElkX7Epnb7-IppcmiSEYuxdtRzqS3s,1195
2
2
  pymc_extras/linearmodel.py,sha256=6eitl15Ec15mSZu7zoHZ7Wwy4U1DPwqfAgwEt6ILeIc,3920
3
3
  pymc_extras/model_builder.py,sha256=sAw77fxdiy046BvDPjocuMlbJ0Efj-CDAGtmcwYmoG0,26361
4
4
  pymc_extras/printing.py,sha256=G8mj9dRd6i0PcsbcEWZm56ek6V8mmil78RI4MUhywBs,6506
5
5
  pymc_extras/version.py,sha256=VxPGCBzhtSegu-Jp5cjzn0n4DGU0wuPUh-KyZKB6uPM,240
6
- pymc_extras/version.txt,sha256=FyW_plJMUmXnwXHPBlaEF9OblH__ScJC8DhZR5yCM0s,6
7
- pymc_extras/distributions/__init__.py,sha256=gTX7tvX8NcgP7V72URV7GeqF1aAEjGVbuW8LMxhXceY,1295
6
+ pymc_extras/version.txt,sha256=6Vn3UOktu3YUriislvCjcnLK7YHYu7dYeRr3v7thBqA,6
7
+ pymc_extras/distributions/__init__.py,sha256=fDbrBt9mxEVp2CDPwnyCW3oiutzZ0PduB8EUH3fUrjI,1377
8
8
  pymc_extras/distributions/continuous.py,sha256=z-nvQgGncYISdRY8cWsa-56V0bQGq70jYwU-i8VZ0Uk,11253
9
9
  pymc_extras/distributions/discrete.py,sha256=vrARNuiQAEXrs7yQgImV1PO8AV1uyEC_LBhr6F9IcOg,13032
10
10
  pymc_extras/distributions/histogram_utils.py,sha256=5RTvlGCUrp2qzshrchmPyWxjhs6RIYL62SMikjDM1TU,5814
11
11
  pymc_extras/distributions/timeseries.py,sha256=M5MZ-nik_tgkaoZ1hdUGEZ9g04DQyVLwszVJqSKwNcY,12719
12
12
  pymc_extras/distributions/multivariate/__init__.py,sha256=E8OeLW9tTotCbrUjEo4um76-_WQD56PehsPzkKmhfyA,93
13
13
  pymc_extras/distributions/multivariate/r2d2m2cp.py,sha256=bUj9bB-hQi6CpaJfvJjgNPi727uTbvAdxl9fm1zNBqY,16005
14
+ pymc_extras/distributions/transforms/__init__.py,sha256=FUp2vyRE6_2eUcQ_FVt5Dn0-vy5I-puV-Kz13-QtLNc,104
15
+ pymc_extras/distributions/transforms/partial_order.py,sha256=oEZlc9WgnGR46uFEjLzKEUxlhzIo2vrUUbBE3vYrsfQ,8404
14
16
  pymc_extras/gp/__init__.py,sha256=sFHw2y3lEl5tG_FDQHZUonQ_k0DF1JRf0Rp8dpHmge0,745
15
17
  pymc_extras/gp/latent_approx.py,sha256=cDEMM6H1BL2qyKg7BZU-ISrKn2HJe7hDaM4Y8GgQDf4,6682
16
- pymc_extras/inference/__init__.py,sha256=5cXpaQQnW0mJJ3x8wSxmYu63l--Xab5D_gMtjA6Q3uU,666
18
+ pymc_extras/inference/__init__.py,sha256=UH6S0bGfQKKyTSuqf7yezdy9PeE2bDU8U1v4eIRv4ZI,887
17
19
  pymc_extras/inference/find_map.py,sha256=vl5l0ei48PnX-uTuHVTr-9QpCEHc8xog-KK6sOnJ8LU,16513
18
- pymc_extras/inference/fit.py,sha256=S9R48dh74s6K0MC9Iys4NAwVjP6rVRfx6SF-kPiR70E,1165
19
- pymc_extras/inference/laplace.py,sha256=uOZGp8ssQuhvCHV_Y_v3icsr4rhcYgr_qlr9dS7pcSM,21761
20
+ pymc_extras/inference/fit.py,sha256=oe20RAajImZ-VD9Ucbzri8Bof4Y2KHNhNRG19v9O3lI,1336
21
+ pymc_extras/inference/laplace.py,sha256=cqarAdbFaOH74AkPUF4c7c_Hswa5mqmhgHpsgrkebHY,21860
20
22
  pymc_extras/inference/pathfinder/__init__.py,sha256=FhAYrCWNx_dCrynEdjg2CZ9tIinvcVLBm67pNx_Y3kA,101
21
23
  pymc_extras/inference/pathfinder/importance_sampling.py,sha256=NwxepXOFit3cA5zEebniKdlnJ1rZWg56aMlH4MEOcG4,6264
22
- pymc_extras/inference/pathfinder/lbfgs.py,sha256=P0UIOVtspdLzDU6alK-y91qzVAzXjYAXPuGmZ1nRqMo,5715
23
- pymc_extras/inference/pathfinder/pathfinder.py,sha256=baw8NUN4hdylM0o4JpCh32xxig-fNFLjh_W9qsvvmM0,64495
24
+ pymc_extras/inference/pathfinder/lbfgs.py,sha256=GOoJBil5Kft_iFwGNUGKSeqzI5x_shA4KQWDwgGuQtQ,7110
25
+ pymc_extras/inference/pathfinder/pathfinder.py,sha256=GW04HQurj_3Nlo1C6_K2tEIeigo8x0buV3FqDLA88PQ,64439
24
26
  pymc_extras/inference/smc/__init__.py,sha256=wyaT4NJl1YsSQRLiDy-i0Jq3CbJZ2BQd4nnCk-dIngY,603
25
27
  pymc_extras/inference/smc/sampling.py,sha256=AYwmKqGoV6pBtKnh9SUbBKbN7VcoFgb3MmNWV7SivMA,15365
26
28
  pymc_extras/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -52,21 +54,22 @@ pymc_extras/statespace/models/utilities.py,sha256=G9GuHKsghmIYOlfkPtvxBWF-FZY5-5
52
54
  pymc_extras/statespace/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
53
55
  pymc_extras/statespace/utils/constants.py,sha256=Kf6j75ABaDQeRODxKQ76wTUQV4F5sTjn1KBcZgCQx20,2403
54
56
  pymc_extras/statespace/utils/coord_tools.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
55
- pymc_extras/statespace/utils/data_tools.py,sha256=caanvrxDu9g-dEKff2bbmaVTs6-71kkSoYIiiSUXhw4,5985
57
+ pymc_extras/statespace/utils/data_tools.py,sha256=01sz6XDtLYK9I5xghxYpD-PuDzGXv9D-wFGfTV6FGEw,6566
56
58
  pymc_extras/utils/__init__.py,sha256=yxI9cJ7fCtVQS0GFw0y6mDGZIQZiK53vm3UNKqIuGSk,758
57
59
  pymc_extras/utils/linear_cg.py,sha256=KkXhuimFsrKtNd_0By2ApxQQQNm5FdBtmDQJOVbLYkA,10056
58
60
  pymc_extras/utils/model_equivalence.py,sha256=8QIftID2HDxD659i0RXHazQ-l2Q5YegCRLcDqb2p9Pc,2187
59
61
  pymc_extras/utils/pivoted_cholesky.py,sha256=QtnjP0pAl9b77fLAu-semwT4_9dcoiqx3dz1xKGBjMk,1871
60
62
  pymc_extras/utils/prior.py,sha256=QlWVr7uKIK9VncBw7Fz3YgaASKGDfqpORZHc-vz_9gQ,6841
61
63
  pymc_extras/utils/spline.py,sha256=qGq0gcoMG5dpdazKFzG0RXkkCWP8ADPPXN-653-oFn4,4820
64
+ pymc_extras-0.2.5.dist-info/licenses/LICENSE,sha256=WjiLhUKEysJvy5e9jk6WwFv9tmAPtnov1uJ6gcH1kIs,11720
62
65
  tests/__init__.py,sha256=-ree9OWVCyTeXLR944OWjrQX2os15HXrRNkhJ7QdRjc,603
63
66
  tests/test_blackjax_smc.py,sha256=jcNgcMBxaKyPg9UvHnWQtwoL79LXlSpZfALe3RGEZnQ,7233
64
67
  tests/test_find_map.py,sha256=B8ThnXNyfTQeem24QaLoTitFrsxKoq2VQINUdOwzna0,3379
65
68
  tests/test_histogram_approximation.py,sha256=w-xb2Rr0Qft6sm6F3BTmXXnpuqyefC1SUL6YxzqA5X4,4674
66
- tests/test_laplace.py,sha256=u4o-0y4v1emaTMYr_rOyL_EKY_bQIz0DUXFuwuDbfNg,9314
69
+ tests/test_laplace.py,sha256=fArHjwMR7x98K-gZLvrvb3AwNZ7_fo7E0A4SJyt4EGU,9843
67
70
  tests/test_linearmodel.py,sha256=iB8ApNqIX9_nUHoo-Tm51xuPdrva5t4VLLut6qXB5Ao,6906
68
71
  tests/test_model_builder.py,sha256=QiINEihBR9rx8xM4Nqlg4urZKoyo58aTKDtxl9SJF1s,11249
69
- tests/test_pathfinder.py,sha256=-ekjetUSWnNRe7YausDvD00Cqh0zpBW3xn5z1hJ37MI,6027
72
+ tests/test_pathfinder.py,sha256=vlMI1p2Ja5X4QIaSV4h6U41I303rEppfO0JqE3xe1Rs,10023
70
73
  tests/test_pivoted_cholesky.py,sha256=PuMdMSCzO4KdQWpUF4SEBeuH_qsINCIH8TYtmmJ1NKo,692
71
74
  tests/test_printing.py,sha256=HnvwwjrjBuxXFAJdyU0K_lvKGLgh4nzHAnhsIUpenbY,5211
72
75
  tests/test_prior_from_trace.py,sha256=HOzR3l98pl7TEJquo_kSugED4wBTgHo4-8lgnpmacs8,5516
@@ -77,6 +80,7 @@ tests/distributions/test_continuous.py,sha256=1-bu-IP6RgLUJnuPYpOD8ZS1ahYbKtsJ9o
77
80
  tests/distributions/test_discrete.py,sha256=CjjaUpppsvQ6zLzV15ZsbwNOKrDmEdz4VWcleoCXUi0,7776
78
81
  tests/distributions/test_discrete_markov_chain.py,sha256=8RCHZXSB8IWjniuKaGGlM_iTWGmdrcOqginxmrAeEJg,9212
79
82
  tests/distributions/test_multivariate.py,sha256=LBvBuoT_3rzi8rR38b8L441Y-9Ff0cIXeRBKiEn6kjs,10452
83
+ tests/distributions/test_transform.py,sha256=QM9sSQ5eSbuT2pM76nUMWqb-tQa7DGZbT9uwFDqIRUk,2672
80
84
  tests/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
81
85
  tests/model/test_model_api.py,sha256=FJvMTmexovRELZOUcUyk-6Vwk9qSiH7hIFoiArgl5mk,1040
82
86
  tests/model/marginal/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -87,7 +91,7 @@ tests/statespace/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,
87
91
  tests/statespace/test_ETS.py,sha256=IPg3uQ7xEGqDMEHu993vtUTV7r-uNAxmw23sr5MVGfQ,15582
88
92
  tests/statespace/test_SARIMAX.py,sha256=1BYNOm9aSHnpn-qbpe3YsQVH8m-mXcp_gvKgWhWn1W4,12948
89
93
  tests/statespace/test_VARMAX.py,sha256=rJnea9_WEGo9I0iv2eaSbwwFQv0tlIjpN7KAE0eQewU,6336
90
- tests/statespace/test_coord_assignment.py,sha256=2GBm46-0eI4QNh4bvp3D7az58stcA5Zo6VgOo_JkCig,3821
94
+ tests/statespace/test_coord_assignment.py,sha256=2Mo5196ibkBTwscE7kqQoUsgQphdaagVkOccDi7D4RI,5980
91
95
  tests/statespace/test_distributions.py,sha256=WQ_ROyd-PL3cimXTyEtyVaMEVtS7Hue2Z0lN7UnGDyo,9122
92
96
  tests/statespace/test_kalman_filter.py,sha256=s2n62FzXl9elU_uqaMNaEaexUfq3SXe3_YvQ2lM6hiQ,11600
93
97
  tests/statespace/test_representation.py,sha256=1KAJY4ZaVhb1WdAJLx2UYSXuVYsMNWX98gEDF7P0B4s,6210
@@ -98,8 +102,7 @@ tests/statespace/utilities/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJ
98
102
  tests/statespace/utilities/shared_fixtures.py,sha256=SNw8Bvj1Yw11TxAW6n20Bq0B8oaYtVTiFFEVNH_wnp4,164
99
103
  tests/statespace/utilities/statsmodel_local_level.py,sha256=SQAzaYaSDwiVhUQ1iWjt4MgfAd54RuzVtnslIs3xdS8,1225
100
104
  tests/statespace/utilities/test_helpers.py,sha256=oH24a6Q45NFFFI3Kx9mhKbxsCvo9ErCorKFoTjDB3-4,9159
101
- pymc_extras-0.2.4.dist-info/LICENSE,sha256=WjiLhUKEysJvy5e9jk6WwFv9tmAPtnov1uJ6gcH1kIs,11720
102
- pymc_extras-0.2.4.dist-info/METADATA,sha256=ozmK251JzJsJLI9yx8NFhhVCgOy5nfcfSfE5IfTP3ok,5154
103
- pymc_extras-0.2.4.dist-info/WHEEL,sha256=52BFRY2Up02UkjOa29eZOS2VxUrpPORXg1pkohGGUS8,91
104
- pymc_extras-0.2.4.dist-info/top_level.txt,sha256=D6RkgBiXiZCel0nvsYg_zYEoT_VuwocyIY98EMaulj0,18
105
- pymc_extras-0.2.4.dist-info/RECORD,,
105
+ pymc_extras-0.2.5.dist-info/METADATA,sha256=02v5liTQQ55sV8xeFl5EjFpwbOKSYKG6g5lE_4htpBo,5227
106
+ pymc_extras-0.2.5.dist-info/WHEEL,sha256=lTU6B6eIfYoiQJTZNc-fyaR6BpL6ehTzU3xGYxn2n8k,91
107
+ pymc_extras-0.2.5.dist-info/top_level.txt,sha256=D6RkgBiXiZCel0nvsYg_zYEoT_VuwocyIY98EMaulj0,18
108
+ pymc_extras-0.2.5.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (76.0.0)
2
+ Generator: setuptools (78.1.1)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -0,0 +1,77 @@
1
+ # Copyright 2025 The PyMC Developers
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import numpy as np
15
+ import pymc as pm
16
+
17
+ from pymc_extras.distributions.transforms import PartialOrder
18
+
19
+
20
+ class TestPartialOrder:
21
+ adj_mats = np.array(
22
+ [
23
+ # 0 < {1, 2} < 3
24
+ [[0, 1, 1, 0], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 0]],
25
+ # 1 < 0 < 3 < 2
26
+ [[0, 0, 0, 1], [1, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 0]],
27
+ ]
28
+ )
29
+
30
+ valid_values = np.array([[0, 2, 1, 3], [1, 0, 3, 2]], dtype=float)
31
+
32
+ # Test that forward and backward are inverses of eachother
33
+ # And that it works when extra dimensions are added in data
34
+ def test_forward_backward_dimensionality(self):
35
+ po = PartialOrder(self.adj_mats)
36
+ po0 = PartialOrder(self.adj_mats[0])
37
+ vv = self.valid_values
38
+ vv0 = self.valid_values[0]
39
+
40
+ testsets = [
41
+ (vv, po),
42
+ (po.initvals(), po),
43
+ (vv0, po0),
44
+ (po0.initvals(), po0),
45
+ (np.tile(vv0, (2, 1)), po0),
46
+ (np.tile(vv0, (2, 3, 2, 1)), po0),
47
+ (np.tile(vv, (2, 3, 2, 1, 1)), po),
48
+ ]
49
+
50
+ for vv, po in testsets:
51
+ fw = po.forward(vv)
52
+ bw = po.backward(fw)
53
+ np.testing.assert_allclose(bw.eval(), vv)
54
+
55
+ def test_sample_model(self):
56
+ po = PartialOrder(self.adj_mats)
57
+ with pm.Model() as model:
58
+ x = pm.Normal(
59
+ "x",
60
+ size=(3, 2, 4),
61
+ transform=po,
62
+ initval=po.initvals(shape=(3, 2, 4), lower=-1, upper=1),
63
+ )
64
+ idata = pm.sample()
65
+
66
+ # Check that the order constraints are satisfied
67
+ # Move chain, draw and "3" dimensions to the back
68
+ xvs = idata.posterior.x.values.transpose(3, 4, 0, 1, 2)
69
+ x0 = xvs[0] # 0 < {1, 2} < 3
70
+ assert (
71
+ (x0[0] < x0[1]).all()
72
+ and (x0[0] < x0[2]).all()
73
+ and (x0[1] < x0[3]).all()
74
+ and (x0[2] < x0[3]).all()
75
+ )
76
+ x1 = xvs[1] # 1 < 0 < 3 < 2
77
+ assert (x1[1] < x1[0]).all() and (x1[0] < x1[3]).all() and (x1[3] < x1[2]).all()
@@ -8,6 +8,7 @@ import pytensor.tensor as pt
8
8
  import pytest
9
9
 
10
10
  from pymc_extras.statespace.models import structural
11
+ from pymc_extras.statespace.models.structural import LevelTrendComponent
11
12
  from pymc_extras.statespace.utils.constants import (
12
13
  FILTER_OUTPUT_DIMS,
13
14
  FILTER_OUTPUT_NAMES,
@@ -114,3 +115,67 @@ def test_data_index_is_coord(f, warning, create_model):
114
115
  with warning:
115
116
  pymc_model = create_model(f)
116
117
  assert TIME_DIM in pymc_model.coords
118
+
119
+
120
+ def make_model(index):
121
+ n = len(index)
122
+ a = pd.DataFrame(index=index, columns=["A", "B", "C", "D"], data=np.arange(n * 4).reshape(n, 4))
123
+
124
+ mod = LevelTrendComponent(order=2, innovations_order=[0, 1])
125
+ ss_mod = mod.build(name="a", verbose=False)
126
+
127
+ initial_trend_dims, sigma_trend_dims, P0_dims = ss_mod.param_dims.values()
128
+ coords = ss_mod.coords
129
+
130
+ with pm.Model(coords=coords) as model:
131
+ P0_diag = pm.Gamma("P0_diag", alpha=5, beta=5)
132
+ P0 = pm.Deterministic("P0", pt.eye(ss_mod.k_states) * P0_diag, dims=P0_dims)
133
+
134
+ initial_trend = pm.Normal("initial_trend", dims=initial_trend_dims)
135
+ sigma_trend = pm.Gamma("sigma_trend", alpha=2, beta=50, dims=sigma_trend_dims)
136
+
137
+ with pytest.warns(UserWarning, match="No time index found on the supplied data"):
138
+ ss_mod.build_statespace_graph(
139
+ a["A"],
140
+ mode="JAX",
141
+ )
142
+ return model
143
+
144
+
145
+ def test_integer_index():
146
+ index = np.arange(8).astype(int)
147
+ model = make_model(index)
148
+ assert TIME_DIM in model.coords
149
+ np.testing.assert_allclose(model.coords[TIME_DIM], index)
150
+
151
+
152
+ def test_float_index_raises():
153
+ index = np.linspace(0, 1, 8)
154
+
155
+ with pytest.raises(IndexError, match="Provided index is not an integer index"):
156
+ make_model(index)
157
+
158
+
159
+ def test_non_strictly_monotone_index_raises():
160
+ # Decreases
161
+ index = [0, 1, 2, 1, 2, 3]
162
+ with pytest.raises(IndexError, match="Provided index is not monotonic increasing"):
163
+ make_model(index)
164
+
165
+ # Has gaps
166
+ index = [0, 1, 2, 3, 5, 6]
167
+ with pytest.raises(IndexError, match="Provided index is not monotonic increasing"):
168
+ make_model(index)
169
+
170
+ # Has duplicates
171
+ index = [0, 1, 1, 2, 3, 4]
172
+ with pytest.raises(IndexError, match="Provided index is not monotonic increasing"):
173
+ make_model(index)
174
+
175
+
176
+ def test_multiindex_raises():
177
+ index = pd.MultiIndex.from_tuples([(0, 0), (1, 1), (2, 2), (3, 3)])
178
+ with pytest.raises(
179
+ NotImplementedError, match="MultiIndex panel data is not currently supported"
180
+ ):
181
+ make_model(index)
tests/test_laplace.py CHANGED
@@ -263,3 +263,19 @@ def test_fit_laplace(fit_in_unconstrained_space, mode, gradient_backend: Gradien
263
263
  else:
264
264
  assert idata.fit.rows.values.tolist() == ["mu", "sigma"]
265
265
  np.testing.assert_allclose(idata.fit.mean_vector.values, np.array([3.0, 1.5]), atol=0.1)
266
+
267
+
268
+ def test_laplace_scalar():
269
+ # Example model from Statistical Rethinking
270
+ data = np.array([0, 0, 0, 1, 1, 1, 1, 1, 1])
271
+
272
+ with pm.Model():
273
+ p = pm.Uniform("p", 0, 1)
274
+ w = pm.Binomial("w", n=len(data), p=p, observed=data.sum())
275
+
276
+ idata_laplace = pmx.fit_laplace(progressbar=False)
277
+
278
+ assert idata_laplace.fit.mean_vector.shape == (1,)
279
+ assert idata_laplace.fit.covariance_matrix.shape == (1, 1)
280
+
281
+ np.testing.assert_allclose(idata_laplace.fit.mean_vector.values.item(), data.mean(), atol=0.1)
tests/test_pathfinder.py CHANGED
@@ -12,14 +12,14 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
+ import re
15
16
  import sys
16
17
 
17
18
  import numpy as np
18
19
  import pymc as pm
20
+ import pytensor.tensor as pt
19
21
  import pytest
20
22
 
21
- pytestmark = pytest.mark.filterwarnings("ignore:compile_pymc was renamed to compile:FutureWarning")
22
-
23
23
  import pymc_extras as pmx
24
24
 
25
25
 
@@ -52,7 +52,90 @@ def reference_idata():
52
52
  return idata
53
53
 
54
54
 
55
+ def unstable_lbfgs_update_mask_model() -> pm.Model:
56
+ # data and model from: https://github.com/pymc-devs/pymc-extras/issues/445
57
+ # this scenario made LBFGS struggle leading to a lot of rejected iterations, (result.nit being moderate, but only history.count <= 1).
58
+ # this scenario is used to test that the LBFGS history manager is rejecting iterations as expected and PF can run to completion.
59
+
60
+ # fmt: off
61
+ inp = np.array([0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 2, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 2, 0, 1, 0, 0, 0, 0, 1, 1, 1, 2, 0, 1, 2, 1, 0, 1, 0, 1, 0, 1, 0])
62
+
63
+ res = np.array([[0,0,1,0,0],[0,0,1,0,0],[0,0,1,0,0],[0,1,0,0,0],[0,0,1,0,0],[0,1,0,0,0],[0,0,0,1,0],[0,0,0,1,0],[1,0,0,0,0],[0,1,0,0,0],[0,0,1,0,0],[1,0,0,0,0],[0,0,1,0,0],[0,1,0,0,0],[0,0,0,1,0],[0,0,1,0,0],[0,0,1,0,0],[0,0,0,1,0],[0,0,0,1,0],[0,1,0,0,0],[0,0,0,1,0],[0,0,1,0,0],[0,1,0,0,0],[1,0,0,0,0],[1,0,0,0,0],[0,0,1,0,0],[0,0,1,0,0],[0,0,1,0,0],[0,0,1,0,0],[0,1,0,0,0],[0,1,0,0,0],[0,0,0,1,0],[0,0,1,0,0],[0,1,0,0,0],[0,0,1,0,0],[0,0,0,1,0],[0,0,1,0,0],[1,0,0,0,0],[1,0,0,0,0],[0,0,1,0,0],[0,0,0,1,0],[0,0,0,1,0],[1,0,0,0,0],[1,0,0,0,0],[0,1,0,0,0],[1,0,0,0,0],[0,0,1,0,0],[0,0,1,0,0],[1,0,0,0,0],[0,0,0,1,0]])
64
+ # fmt: on
65
+
66
+ n_ordered = res.shape[1]
67
+ coords = {
68
+ "obs": np.arange(len(inp)),
69
+ "inp": np.arange(max(inp) + 1),
70
+ "outp": np.arange(res.shape[1]),
71
+ }
72
+ with pm.Model(coords=coords) as mdl:
73
+ mu = pm.Normal("intercept", sigma=3.5)[None]
74
+
75
+ offset = pm.Normal(
76
+ "offset", dims=("inp"), transform=pm.distributions.transforms.ZeroSumTransform([0])
77
+ )
78
+
79
+ scale = 3.5 * pm.HalfStudentT("scale", nu=5)
80
+ mu += (scale * offset)[inp]
81
+
82
+ phi_delta = pm.Dirichlet("phi_diffs", [1.0] * (n_ordered - 1))
83
+ phi = pt.concatenate([[0], pt.cumsum(phi_delta)])
84
+ s_mu = pm.Normal(
85
+ "stereotype_intercept",
86
+ size=n_ordered,
87
+ transform=pm.distributions.transforms.ZeroSumTransform([-1]),
88
+ )
89
+ fprobs = pm.math.softmax(s_mu[None, :] + phi[None, :] * mu[:, None], axis=-1)
90
+
91
+ pm.Multinomial("y_res", p=fprobs, n=np.ones(len(inp)), observed=res, dims=("obs", "outp"))
92
+
93
+ return mdl
94
+
95
+
96
+ @pytest.mark.parametrize("jitter", [12.0, 500.0, 1000.0])
97
+ def test_unstable_lbfgs_update_mask(capsys, jitter):
98
+ model = unstable_lbfgs_update_mask_model()
99
+
100
+ if jitter < 1000:
101
+ with model:
102
+ idata = pmx.fit(
103
+ method="pathfinder",
104
+ jitter=jitter,
105
+ random_seed=4,
106
+ )
107
+ out, err = capsys.readouterr()
108
+ status_pattern = [
109
+ r"INIT_FAILED_LOW_UPDATE_PCT\s+\d+",
110
+ r"LOW_UPDATE_PCT\s+\d+",
111
+ r"LBFGS_FAILED\s+\d+",
112
+ r"SUCCESS\s+\d+",
113
+ ]
114
+ for pattern in status_pattern:
115
+ assert re.search(pattern, out) is not None
116
+
117
+ else:
118
+ with pytest.raises(ValueError, match="All paths failed"):
119
+ with model:
120
+ idata = pmx.fit(
121
+ method="pathfinder",
122
+ jitter=1000,
123
+ random_seed=2,
124
+ num_paths=4,
125
+ )
126
+ out, err = capsys.readouterr()
127
+
128
+ status_pattern = [
129
+ r"INIT_FAILED_LOW_UPDATE_PCT\s+2",
130
+ r"LOW_UPDATE_PCT\s+2",
131
+ r"LBFGS_FAILED\s+4",
132
+ ]
133
+ for pattern in status_pattern:
134
+ assert re.search(pattern, out) is not None
135
+
136
+
55
137
  @pytest.mark.parametrize("inference_backend", ["pymc", "blackjax"])
138
+ @pytest.mark.filterwarnings("ignore:JAXopt is no longer maintained.:DeprecationWarning")
56
139
  def test_pathfinder(inference_backend, reference_idata):
57
140
  if inference_backend == "blackjax" and sys.platform == "win32":
58
141
  pytest.skip("JAX not supported on windows")
@@ -149,12 +232,11 @@ def test_bfgs_sample():
149
232
  # get factors
150
233
  x_full = pt.as_tensor(x_data, dtype="float64")
151
234
  g_full = pt.as_tensor(g_data, dtype="float64")
152
- epsilon = 1e-11
153
235
 
154
236
  x = x_full[1:]
155
237
  g = g_full[1:]
156
- alpha, S, Z, update_mask = alpha_recover(x_full, g_full, epsilon)
157
- beta, gamma = inverse_hessian_factors(alpha, S, Z, update_mask, J)
238
+ alpha, s, z = alpha_recover(x_full, g_full)
239
+ beta, gamma = inverse_hessian_factors(alpha, s, z, J)
158
240
 
159
241
  # sample
160
242
  phi, logq = bfgs_sample(
@@ -169,8 +251,8 @@ def test_bfgs_sample():
169
251
  # check shapes
170
252
  assert beta.eval().shape == (L, N, 2 * J)
171
253
  assert gamma.eval().shape == (L, 2 * J, 2 * J)
172
- assert phi.eval().shape == (L, num_samples, N)
173
- assert logq.eval().shape == (L, num_samples)
254
+ assert all(phi.shape.eval() == (L, num_samples, N))
255
+ assert all(logq.shape.eval() == (L, num_samples))
174
256
 
175
257
 
176
258
  @pytest.mark.parametrize("importance_sampling", ["psis", "psir", "identity", None])
@@ -201,3 +283,15 @@ def test_pathfinder_importance_sampling(importance_sampling):
201
283
  assert idata.posterior["mu"].shape == (1, num_draws)
202
284
  assert idata.posterior["tau"].shape == (1, num_draws)
203
285
  assert idata.posterior["theta"].shape == (1, num_draws, 8)
286
+
287
+
288
+ def test_pathfinder_initvals():
289
+ # Run a model with an ordered transform that will fail unless initvals are in place
290
+ with pm.Model() as mdl:
291
+ pm.Normal("ordered", size=10, transform=pm.distributions.transforms.ordered)
292
+ idata = pmx.fit_pathfinder(initvals={"ordered": np.linspace(0, 1, 10)})
293
+
294
+ # Check that the samples are ordered to make sure transform was applied
295
+ assert np.all(
296
+ idata.posterior["ordered"][..., 1:].values > idata.posterior["ordered"][..., :-1].values
297
+ )