pymc-extras 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pymc_extras/__init__.py +1 -3
- pymc_extras/distributions/__init__.py +2 -0
- pymc_extras/distributions/transforms/__init__.py +3 -0
- pymc_extras/distributions/transforms/partial_order.py +227 -0
- pymc_extras/inference/__init__.py +4 -2
- pymc_extras/inference/fit.py +6 -4
- pymc_extras/inference/laplace.py +4 -1
- pymc_extras/inference/pathfinder/importance_sampling.py +23 -17
- pymc_extras/inference/pathfinder/lbfgs.py +49 -13
- pymc_extras/inference/pathfinder/pathfinder.py +136 -118
- pymc_extras/statespace/core/statespace.py +5 -4
- pymc_extras/statespace/filters/distributions.py +9 -45
- pymc_extras/statespace/utils/data_tools.py +24 -9
- pymc_extras/version.txt +1 -1
- {pymc_extras-0.2.3.dist-info → pymc_extras-0.2.5.dist-info}/METADATA +5 -3
- {pymc_extras-0.2.3.dist-info → pymc_extras-0.2.5.dist-info}/RECORD +23 -20
- {pymc_extras-0.2.3.dist-info → pymc_extras-0.2.5.dist-info}/WHEEL +1 -1
- tests/distributions/test_transform.py +77 -0
- tests/statespace/test_coord_assignment.py +65 -0
- tests/test_laplace.py +16 -0
- tests/test_pathfinder.py +141 -17
- {pymc_extras-0.2.3.dist-info → pymc_extras-0.2.5.dist-info/licenses}/LICENSE +0 -0
- {pymc_extras-0.2.3.dist-info → pymc_extras-0.2.5.dist-info}/top_level.txt +0 -0
|
@@ -6,11 +6,9 @@ import pytensor.tensor as pt
|
|
|
6
6
|
from pymc import intX
|
|
7
7
|
from pymc.distributions.dist_math import check_parameters
|
|
8
8
|
from pymc.distributions.distribution import Continuous, SymbolicRandomVariable
|
|
9
|
-
from pymc.distributions.multivariate import MvNormal
|
|
10
9
|
from pymc.distributions.shape_utils import get_support_shape_1d
|
|
11
10
|
from pymc.logprob.abstract import _logprob
|
|
12
11
|
from pytensor.graph.basic import Node
|
|
13
|
-
from pytensor.tensor.random.basic import MvNormalRV
|
|
14
12
|
|
|
15
13
|
floatX = pytensor.config.floatX
|
|
16
14
|
COV_ZERO_TOL = 0
|
|
@@ -49,44 +47,6 @@ def make_signature(sequence_names):
|
|
|
49
47
|
return f"{signature},[rng]->[rng],({time},{state_and_obs})"
|
|
50
48
|
|
|
51
49
|
|
|
52
|
-
class MvNormalSVDRV(MvNormalRV):
|
|
53
|
-
name = "multivariate_normal"
|
|
54
|
-
signature = "(n),(n,n)->(n)"
|
|
55
|
-
dtype = "floatX"
|
|
56
|
-
_print_name = ("MultivariateNormal", "\\operatorname{MultivariateNormal}")
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
class MvNormalSVD(MvNormal):
|
|
60
|
-
"""Dummy distribution intended to be rewritten into a JAX multivariate_normal with method="svd".
|
|
61
|
-
|
|
62
|
-
A JAX MvNormal robust to low-rank covariance matrices
|
|
63
|
-
"""
|
|
64
|
-
|
|
65
|
-
rv_op = MvNormalSVDRV()
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
try:
|
|
69
|
-
import jax.random
|
|
70
|
-
|
|
71
|
-
from pytensor.link.jax.dispatch.random import jax_sample_fn
|
|
72
|
-
|
|
73
|
-
@jax_sample_fn.register(MvNormalSVDRV)
|
|
74
|
-
def jax_sample_fn_mvnormal_svd(op, node):
|
|
75
|
-
def sample_fn(rng, size, dtype, *parameters):
|
|
76
|
-
rng_key = rng["jax_state"]
|
|
77
|
-
rng_key, sampling_key = jax.random.split(rng_key, 2)
|
|
78
|
-
sample = jax.random.multivariate_normal(
|
|
79
|
-
sampling_key, *parameters, shape=size, dtype=dtype, method="svd"
|
|
80
|
-
)
|
|
81
|
-
rng["jax_state"] = rng_key
|
|
82
|
-
return (rng, sample)
|
|
83
|
-
|
|
84
|
-
return sample_fn
|
|
85
|
-
|
|
86
|
-
except ImportError:
|
|
87
|
-
pass
|
|
88
|
-
|
|
89
|
-
|
|
90
50
|
class LinearGaussianStateSpaceRV(SymbolicRandomVariable):
|
|
91
51
|
default_output = 1
|
|
92
52
|
_print_name = ("LinearGuassianStateSpace", "\\operatorname{LinearGuassianStateSpace}")
|
|
@@ -244,8 +204,12 @@ class _LinearGaussianStateSpace(Continuous):
|
|
|
244
204
|
k = T.shape[0]
|
|
245
205
|
a = state[:k]
|
|
246
206
|
|
|
247
|
-
middle_rng, a_innovation =
|
|
248
|
-
|
|
207
|
+
middle_rng, a_innovation = pm.MvNormal.dist(
|
|
208
|
+
mu=0, cov=Q, rng=rng, method="svd"
|
|
209
|
+
).owner.outputs
|
|
210
|
+
next_rng, y_innovation = pm.MvNormal.dist(
|
|
211
|
+
mu=0, cov=H, rng=middle_rng, method="svd"
|
|
212
|
+
).owner.outputs
|
|
249
213
|
|
|
250
214
|
a_mu = c + T @ a
|
|
251
215
|
a_next = a_mu + R @ a_innovation
|
|
@@ -260,8 +224,8 @@ class _LinearGaussianStateSpace(Continuous):
|
|
|
260
224
|
Z_init = Z_ if Z_ in non_sequences else Z_[0]
|
|
261
225
|
H_init = H_ if H_ in non_sequences else H_[0]
|
|
262
226
|
|
|
263
|
-
init_x_ =
|
|
264
|
-
init_y_ =
|
|
227
|
+
init_x_ = pm.MvNormal.dist(a0_, P0_, rng=rng, method="svd")
|
|
228
|
+
init_y_ = pm.MvNormal.dist(Z_init @ init_x_, H_init, rng=rng, method="svd")
|
|
265
229
|
|
|
266
230
|
init_dist_ = pt.concatenate([init_x_, init_y_], axis=0)
|
|
267
231
|
|
|
@@ -421,7 +385,7 @@ class SequenceMvNormal(Continuous):
|
|
|
421
385
|
rng = pytensor.shared(np.random.default_rng())
|
|
422
386
|
|
|
423
387
|
def step(mu, cov, rng):
|
|
424
|
-
new_rng, mvn =
|
|
388
|
+
new_rng, mvn = pm.MvNormal.dist(mu=mu, cov=cov, rng=rng, method="svd").owner.outputs
|
|
425
389
|
return mvn, {rng: new_rng}
|
|
426
390
|
|
|
427
391
|
mvn_seq, updates = pytensor.scan(
|
|
@@ -87,12 +87,7 @@ def preprocess_pandas_data(data, n_obs, obs_coords=None, check_column_names=Fals
|
|
|
87
87
|
col_names = data.columns
|
|
88
88
|
_validate_data_shape(data.shape, n_obs, obs_coords, check_column_names, col_names)
|
|
89
89
|
|
|
90
|
-
if isinstance(data.index, pd.RangeIndex):
|
|
91
|
-
if obs_coords is not None:
|
|
92
|
-
warnings.warn(NO_TIME_INDEX_WARNING)
|
|
93
|
-
return preprocess_numpy_data(data.values, n_obs, obs_coords)
|
|
94
|
-
|
|
95
|
-
elif isinstance(data.index, pd.DatetimeIndex):
|
|
90
|
+
if isinstance(data.index, pd.DatetimeIndex):
|
|
96
91
|
if data.index.freq is None:
|
|
97
92
|
warnings.warn(NO_FREQ_INFO_WARNING)
|
|
98
93
|
data.index.freq = data.index.inferred_freq
|
|
@@ -100,10 +95,30 @@ def preprocess_pandas_data(data, n_obs, obs_coords=None, check_column_names=Fals
|
|
|
100
95
|
index = data.index
|
|
101
96
|
return data.values, index
|
|
102
97
|
|
|
98
|
+
elif isinstance(data.index, pd.RangeIndex):
|
|
99
|
+
if obs_coords is not None:
|
|
100
|
+
warnings.warn(NO_TIME_INDEX_WARNING)
|
|
101
|
+
return preprocess_numpy_data(data.values, n_obs, obs_coords)
|
|
102
|
+
|
|
103
|
+
elif isinstance(data.index, pd.MultiIndex):
|
|
104
|
+
if obs_coords is not None:
|
|
105
|
+
warnings.warn(NO_TIME_INDEX_WARNING)
|
|
106
|
+
|
|
107
|
+
raise NotImplementedError("MultiIndex panel data is not currently supported.")
|
|
108
|
+
|
|
103
109
|
else:
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
110
|
+
if obs_coords is not None:
|
|
111
|
+
warnings.warn(NO_TIME_INDEX_WARNING)
|
|
112
|
+
|
|
113
|
+
index = data.index
|
|
114
|
+
if not np.issubdtype(index.dtype, np.integer):
|
|
115
|
+
raise IndexError("Provided index is not an integer index.")
|
|
116
|
+
|
|
117
|
+
index_diff = index.to_series().diff().dropna().values
|
|
118
|
+
if not (index_diff == 1).all():
|
|
119
|
+
raise IndexError("Provided index is not monotonic increasing.")
|
|
120
|
+
|
|
121
|
+
return preprocess_numpy_data(data.values, n_obs, obs_coords)
|
|
107
122
|
|
|
108
123
|
|
|
109
124
|
def add_data_to_active_model(values, index, data_dims=None):
|
pymc_extras/version.txt
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
0.2.3
|
|
1
|
+
0.2.5
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
2
|
Name: pymc-extras
|
|
3
|
-
Version: 0.2.3
|
|
3
|
+
Version: 0.2.5
|
|
4
4
|
Summary: A home for new additions to PyMC, which may include unusual probability distribitions, advanced model fitting algorithms, or any code that may be inappropriate to include in the pymc repository, but may want to be made available to users.
|
|
5
5
|
Home-page: http://github.com/pymc-devs/pymc-extras
|
|
6
6
|
Maintainer: PyMC Developers
|
|
@@ -12,6 +12,7 @@ Classifier: Programming Language :: Python :: 3
|
|
|
12
12
|
Classifier: Programming Language :: Python :: 3.10
|
|
13
13
|
Classifier: Programming Language :: Python :: 3.11
|
|
14
14
|
Classifier: Programming Language :: Python :: 3.12
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
15
16
|
Classifier: License :: OSI Approved :: Apache Software License
|
|
16
17
|
Classifier: Intended Audience :: Science/Research
|
|
17
18
|
Classifier: Topic :: Scientific/Engineering
|
|
@@ -20,7 +21,7 @@ Classifier: Operating System :: OS Independent
|
|
|
20
21
|
Requires-Python: >=3.10
|
|
21
22
|
Description-Content-Type: text/markdown
|
|
22
23
|
License-File: LICENSE
|
|
23
|
-
Requires-Dist: pymc>=5.
|
|
24
|
+
Requires-Dist: pymc>=5.21.1
|
|
24
25
|
Requires-Dist: scikit-learn
|
|
25
26
|
Requires-Dist: better-optimize
|
|
26
27
|
Provides-Extra: dask-histogram
|
|
@@ -40,6 +41,7 @@ Dynamic: description
|
|
|
40
41
|
Dynamic: description-content-type
|
|
41
42
|
Dynamic: home-page
|
|
42
43
|
Dynamic: license
|
|
44
|
+
Dynamic: license-file
|
|
43
45
|
Dynamic: maintainer
|
|
44
46
|
Dynamic: maintainer-email
|
|
45
47
|
Dynamic: provides-extra
|
|
@@ -1,26 +1,28 @@
|
|
|
1
|
-
pymc_extras/__init__.py,sha256=
|
|
1
|
+
pymc_extras/__init__.py,sha256=lYGf9TcwUHROIElkX7Epnb7-IppcmiSEYuxdtRzqS3s,1195
|
|
2
2
|
pymc_extras/linearmodel.py,sha256=6eitl15Ec15mSZu7zoHZ7Wwy4U1DPwqfAgwEt6ILeIc,3920
|
|
3
3
|
pymc_extras/model_builder.py,sha256=sAw77fxdiy046BvDPjocuMlbJ0Efj-CDAGtmcwYmoG0,26361
|
|
4
4
|
pymc_extras/printing.py,sha256=G8mj9dRd6i0PcsbcEWZm56ek6V8mmil78RI4MUhywBs,6506
|
|
5
5
|
pymc_extras/version.py,sha256=VxPGCBzhtSegu-Jp5cjzn0n4DGU0wuPUh-KyZKB6uPM,240
|
|
6
|
-
pymc_extras/version.txt,sha256=
|
|
7
|
-
pymc_extras/distributions/__init__.py,sha256=
|
|
6
|
+
pymc_extras/version.txt,sha256=6Vn3UOktu3YUriislvCjcnLK7YHYu7dYeRr3v7thBqA,6
|
|
7
|
+
pymc_extras/distributions/__init__.py,sha256=fDbrBt9mxEVp2CDPwnyCW3oiutzZ0PduB8EUH3fUrjI,1377
|
|
8
8
|
pymc_extras/distributions/continuous.py,sha256=z-nvQgGncYISdRY8cWsa-56V0bQGq70jYwU-i8VZ0Uk,11253
|
|
9
9
|
pymc_extras/distributions/discrete.py,sha256=vrARNuiQAEXrs7yQgImV1PO8AV1uyEC_LBhr6F9IcOg,13032
|
|
10
10
|
pymc_extras/distributions/histogram_utils.py,sha256=5RTvlGCUrp2qzshrchmPyWxjhs6RIYL62SMikjDM1TU,5814
|
|
11
11
|
pymc_extras/distributions/timeseries.py,sha256=M5MZ-nik_tgkaoZ1hdUGEZ9g04DQyVLwszVJqSKwNcY,12719
|
|
12
12
|
pymc_extras/distributions/multivariate/__init__.py,sha256=E8OeLW9tTotCbrUjEo4um76-_WQD56PehsPzkKmhfyA,93
|
|
13
13
|
pymc_extras/distributions/multivariate/r2d2m2cp.py,sha256=bUj9bB-hQi6CpaJfvJjgNPi727uTbvAdxl9fm1zNBqY,16005
|
|
14
|
+
pymc_extras/distributions/transforms/__init__.py,sha256=FUp2vyRE6_2eUcQ_FVt5Dn0-vy5I-puV-Kz13-QtLNc,104
|
|
15
|
+
pymc_extras/distributions/transforms/partial_order.py,sha256=oEZlc9WgnGR46uFEjLzKEUxlhzIo2vrUUbBE3vYrsfQ,8404
|
|
14
16
|
pymc_extras/gp/__init__.py,sha256=sFHw2y3lEl5tG_FDQHZUonQ_k0DF1JRf0Rp8dpHmge0,745
|
|
15
17
|
pymc_extras/gp/latent_approx.py,sha256=cDEMM6H1BL2qyKg7BZU-ISrKn2HJe7hDaM4Y8GgQDf4,6682
|
|
16
|
-
pymc_extras/inference/__init__.py,sha256=
|
|
18
|
+
pymc_extras/inference/__init__.py,sha256=UH6S0bGfQKKyTSuqf7yezdy9PeE2bDU8U1v4eIRv4ZI,887
|
|
17
19
|
pymc_extras/inference/find_map.py,sha256=vl5l0ei48PnX-uTuHVTr-9QpCEHc8xog-KK6sOnJ8LU,16513
|
|
18
|
-
pymc_extras/inference/fit.py,sha256=
|
|
19
|
-
pymc_extras/inference/laplace.py,sha256=
|
|
20
|
+
pymc_extras/inference/fit.py,sha256=oe20RAajImZ-VD9Ucbzri8Bof4Y2KHNhNRG19v9O3lI,1336
|
|
21
|
+
pymc_extras/inference/laplace.py,sha256=cqarAdbFaOH74AkPUF4c7c_Hswa5mqmhgHpsgrkebHY,21860
|
|
20
22
|
pymc_extras/inference/pathfinder/__init__.py,sha256=FhAYrCWNx_dCrynEdjg2CZ9tIinvcVLBm67pNx_Y3kA,101
|
|
21
|
-
pymc_extras/inference/pathfinder/importance_sampling.py,sha256=
|
|
22
|
-
pymc_extras/inference/pathfinder/lbfgs.py,sha256=
|
|
23
|
-
pymc_extras/inference/pathfinder/pathfinder.py,sha256=
|
|
23
|
+
pymc_extras/inference/pathfinder/importance_sampling.py,sha256=NwxepXOFit3cA5zEebniKdlnJ1rZWg56aMlH4MEOcG4,6264
|
|
24
|
+
pymc_extras/inference/pathfinder/lbfgs.py,sha256=GOoJBil5Kft_iFwGNUGKSeqzI5x_shA4KQWDwgGuQtQ,7110
|
|
25
|
+
pymc_extras/inference/pathfinder/pathfinder.py,sha256=GW04HQurj_3Nlo1C6_K2tEIeigo8x0buV3FqDLA88PQ,64439
|
|
24
26
|
pymc_extras/inference/smc/__init__.py,sha256=wyaT4NJl1YsSQRLiDy-i0Jq3CbJZ2BQd4nnCk-dIngY,603
|
|
25
27
|
pymc_extras/inference/smc/sampling.py,sha256=AYwmKqGoV6pBtKnh9SUbBKbN7VcoFgb3MmNWV7SivMA,15365
|
|
26
28
|
pymc_extras/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
@@ -37,9 +39,9 @@ pymc_extras/statespace/__init__.py,sha256=0MtZj7yT6jcyERvITnn-nkhyY8fO6Za4_vV53C
|
|
|
37
39
|
pymc_extras/statespace/core/__init__.py,sha256=huHEiXAm8zV2MZyZ8GBHp6q7_fnWqveM7lC6ilpb3iE,309
|
|
38
40
|
pymc_extras/statespace/core/compile.py,sha256=9FZfE8Bi3VfElxujfOIKRVvmyL9M5R0WfNEqPc5kbVQ,1603
|
|
39
41
|
pymc_extras/statespace/core/representation.py,sha256=DwNIun6wdeEA20oWBx5M4govyWTf5JI87aGQ_E6Mb4U,18956
|
|
40
|
-
pymc_extras/statespace/core/statespace.py,sha256=
|
|
42
|
+
pymc_extras/statespace/core/statespace.py,sha256=Tx-821UNNLqsZgHzRmwaQ6s-agp_OthqSsbfwDpA1o0,96927
|
|
41
43
|
pymc_extras/statespace/filters/__init__.py,sha256=N9Q4D0gAq_ZtT-GtrqiX1HkSg6Orv7o1TbrWUtnbTJE,420
|
|
42
|
-
pymc_extras/statespace/filters/distributions.py,sha256
|
|
44
|
+
pymc_extras/statespace/filters/distributions.py,sha256=ejimTFLgBFZMEznxY5zh6u4Vrqij60i0k2_sxdPcZ3A,11878
|
|
43
45
|
pymc_extras/statespace/filters/kalman_filter.py,sha256=HELC3aK4k8EdWlUAk5_F7y7YkIz-Xi_0j2AwRgAXgcc,31949
|
|
44
46
|
pymc_extras/statespace/filters/kalman_smoother.py,sha256=ypH9K_88nfJ5K2Cq737aWL3p8v4UfI7MxnYs54WPdDs,4329
|
|
45
47
|
pymc_extras/statespace/filters/utilities.py,sha256=iwdaYnO1cO06t_XUjLLRmqb8vwzzVH6Nx1iyZcbJL2k,1584
|
|
@@ -52,21 +54,22 @@ pymc_extras/statespace/models/utilities.py,sha256=G9GuHKsghmIYOlfkPtvxBWF-FZY5-5
|
|
|
52
54
|
pymc_extras/statespace/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
53
55
|
pymc_extras/statespace/utils/constants.py,sha256=Kf6j75ABaDQeRODxKQ76wTUQV4F5sTjn1KBcZgCQx20,2403
|
|
54
56
|
pymc_extras/statespace/utils/coord_tools.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
55
|
-
pymc_extras/statespace/utils/data_tools.py,sha256=
|
|
57
|
+
pymc_extras/statespace/utils/data_tools.py,sha256=01sz6XDtLYK9I5xghxYpD-PuDzGXv9D-wFGfTV6FGEw,6566
|
|
56
58
|
pymc_extras/utils/__init__.py,sha256=yxI9cJ7fCtVQS0GFw0y6mDGZIQZiK53vm3UNKqIuGSk,758
|
|
57
59
|
pymc_extras/utils/linear_cg.py,sha256=KkXhuimFsrKtNd_0By2ApxQQQNm5FdBtmDQJOVbLYkA,10056
|
|
58
60
|
pymc_extras/utils/model_equivalence.py,sha256=8QIftID2HDxD659i0RXHazQ-l2Q5YegCRLcDqb2p9Pc,2187
|
|
59
61
|
pymc_extras/utils/pivoted_cholesky.py,sha256=QtnjP0pAl9b77fLAu-semwT4_9dcoiqx3dz1xKGBjMk,1871
|
|
60
62
|
pymc_extras/utils/prior.py,sha256=QlWVr7uKIK9VncBw7Fz3YgaASKGDfqpORZHc-vz_9gQ,6841
|
|
61
63
|
pymc_extras/utils/spline.py,sha256=qGq0gcoMG5dpdazKFzG0RXkkCWP8ADPPXN-653-oFn4,4820
|
|
64
|
+
pymc_extras-0.2.5.dist-info/licenses/LICENSE,sha256=WjiLhUKEysJvy5e9jk6WwFv9tmAPtnov1uJ6gcH1kIs,11720
|
|
62
65
|
tests/__init__.py,sha256=-ree9OWVCyTeXLR944OWjrQX2os15HXrRNkhJ7QdRjc,603
|
|
63
66
|
tests/test_blackjax_smc.py,sha256=jcNgcMBxaKyPg9UvHnWQtwoL79LXlSpZfALe3RGEZnQ,7233
|
|
64
67
|
tests/test_find_map.py,sha256=B8ThnXNyfTQeem24QaLoTitFrsxKoq2VQINUdOwzna0,3379
|
|
65
68
|
tests/test_histogram_approximation.py,sha256=w-xb2Rr0Qft6sm6F3BTmXXnpuqyefC1SUL6YxzqA5X4,4674
|
|
66
|
-
tests/test_laplace.py,sha256=
|
|
69
|
+
tests/test_laplace.py,sha256=fArHjwMR7x98K-gZLvrvb3AwNZ7_fo7E0A4SJyt4EGU,9843
|
|
67
70
|
tests/test_linearmodel.py,sha256=iB8ApNqIX9_nUHoo-Tm51xuPdrva5t4VLLut6qXB5Ao,6906
|
|
68
71
|
tests/test_model_builder.py,sha256=QiINEihBR9rx8xM4Nqlg4urZKoyo58aTKDtxl9SJF1s,11249
|
|
69
|
-
tests/test_pathfinder.py,sha256=
|
|
72
|
+
tests/test_pathfinder.py,sha256=vlMI1p2Ja5X4QIaSV4h6U41I303rEppfO0JqE3xe1Rs,10023
|
|
70
73
|
tests/test_pivoted_cholesky.py,sha256=PuMdMSCzO4KdQWpUF4SEBeuH_qsINCIH8TYtmmJ1NKo,692
|
|
71
74
|
tests/test_printing.py,sha256=HnvwwjrjBuxXFAJdyU0K_lvKGLgh4nzHAnhsIUpenbY,5211
|
|
72
75
|
tests/test_prior_from_trace.py,sha256=HOzR3l98pl7TEJquo_kSugED4wBTgHo4-8lgnpmacs8,5516
|
|
@@ -77,6 +80,7 @@ tests/distributions/test_continuous.py,sha256=1-bu-IP6RgLUJnuPYpOD8ZS1ahYbKtsJ9o
|
|
|
77
80
|
tests/distributions/test_discrete.py,sha256=CjjaUpppsvQ6zLzV15ZsbwNOKrDmEdz4VWcleoCXUi0,7776
|
|
78
81
|
tests/distributions/test_discrete_markov_chain.py,sha256=8RCHZXSB8IWjniuKaGGlM_iTWGmdrcOqginxmrAeEJg,9212
|
|
79
82
|
tests/distributions/test_multivariate.py,sha256=LBvBuoT_3rzi8rR38b8L441Y-9Ff0cIXeRBKiEn6kjs,10452
|
|
83
|
+
tests/distributions/test_transform.py,sha256=QM9sSQ5eSbuT2pM76nUMWqb-tQa7DGZbT9uwFDqIRUk,2672
|
|
80
84
|
tests/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
81
85
|
tests/model/test_model_api.py,sha256=FJvMTmexovRELZOUcUyk-6Vwk9qSiH7hIFoiArgl5mk,1040
|
|
82
86
|
tests/model/marginal/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
@@ -87,7 +91,7 @@ tests/statespace/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,
|
|
|
87
91
|
tests/statespace/test_ETS.py,sha256=IPg3uQ7xEGqDMEHu993vtUTV7r-uNAxmw23sr5MVGfQ,15582
|
|
88
92
|
tests/statespace/test_SARIMAX.py,sha256=1BYNOm9aSHnpn-qbpe3YsQVH8m-mXcp_gvKgWhWn1W4,12948
|
|
89
93
|
tests/statespace/test_VARMAX.py,sha256=rJnea9_WEGo9I0iv2eaSbwwFQv0tlIjpN7KAE0eQewU,6336
|
|
90
|
-
tests/statespace/test_coord_assignment.py,sha256=
|
|
94
|
+
tests/statespace/test_coord_assignment.py,sha256=2Mo5196ibkBTwscE7kqQoUsgQphdaagVkOccDi7D4RI,5980
|
|
91
95
|
tests/statespace/test_distributions.py,sha256=WQ_ROyd-PL3cimXTyEtyVaMEVtS7Hue2Z0lN7UnGDyo,9122
|
|
92
96
|
tests/statespace/test_kalman_filter.py,sha256=s2n62FzXl9elU_uqaMNaEaexUfq3SXe3_YvQ2lM6hiQ,11600
|
|
93
97
|
tests/statespace/test_representation.py,sha256=1KAJY4ZaVhb1WdAJLx2UYSXuVYsMNWX98gEDF7P0B4s,6210
|
|
@@ -98,8 +102,7 @@ tests/statespace/utilities/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJ
|
|
|
98
102
|
tests/statespace/utilities/shared_fixtures.py,sha256=SNw8Bvj1Yw11TxAW6n20Bq0B8oaYtVTiFFEVNH_wnp4,164
|
|
99
103
|
tests/statespace/utilities/statsmodel_local_level.py,sha256=SQAzaYaSDwiVhUQ1iWjt4MgfAd54RuzVtnslIs3xdS8,1225
|
|
100
104
|
tests/statespace/utilities/test_helpers.py,sha256=oH24a6Q45NFFFI3Kx9mhKbxsCvo9ErCorKFoTjDB3-4,9159
|
|
101
|
-
pymc_extras-0.2.
|
|
102
|
-
pymc_extras-0.2.
|
|
103
|
-
pymc_extras-0.2.
|
|
104
|
-
pymc_extras-0.2.
|
|
105
|
-
pymc_extras-0.2.3.dist-info/RECORD,,
|
|
105
|
+
pymc_extras-0.2.5.dist-info/METADATA,sha256=02v5liTQQ55sV8xeFl5EjFpwbOKSYKG6g5lE_4htpBo,5227
|
|
106
|
+
pymc_extras-0.2.5.dist-info/WHEEL,sha256=lTU6B6eIfYoiQJTZNc-fyaR6BpL6ehTzU3xGYxn2n8k,91
|
|
107
|
+
pymc_extras-0.2.5.dist-info/top_level.txt,sha256=D6RkgBiXiZCel0nvsYg_zYEoT_VuwocyIY98EMaulj0,18
|
|
108
|
+
pymc_extras-0.2.5.dist-info/RECORD,,
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
# Copyright 2025 The PyMC Developers
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
import numpy as np
|
|
15
|
+
import pymc as pm
|
|
16
|
+
|
|
17
|
+
from pymc_extras.distributions.transforms import PartialOrder
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class TestPartialOrder:
|
|
21
|
+
adj_mats = np.array(
|
|
22
|
+
[
|
|
23
|
+
# 0 < {1, 2} < 3
|
|
24
|
+
[[0, 1, 1, 0], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 0]],
|
|
25
|
+
# 1 < 0 < 3 < 2
|
|
26
|
+
[[0, 0, 0, 1], [1, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 0]],
|
|
27
|
+
]
|
|
28
|
+
)
|
|
29
|
+
|
|
30
|
+
valid_values = np.array([[0, 2, 1, 3], [1, 0, 3, 2]], dtype=float)
|
|
31
|
+
|
|
32
|
+
# Test that forward and backward are inverses of each other
|
|
33
|
+
# And that it works when extra dimensions are added in data
|
|
34
|
+
def test_forward_backward_dimensionality(self):
|
|
35
|
+
po = PartialOrder(self.adj_mats)
|
|
36
|
+
po0 = PartialOrder(self.adj_mats[0])
|
|
37
|
+
vv = self.valid_values
|
|
38
|
+
vv0 = self.valid_values[0]
|
|
39
|
+
|
|
40
|
+
testsets = [
|
|
41
|
+
(vv, po),
|
|
42
|
+
(po.initvals(), po),
|
|
43
|
+
(vv0, po0),
|
|
44
|
+
(po0.initvals(), po0),
|
|
45
|
+
(np.tile(vv0, (2, 1)), po0),
|
|
46
|
+
(np.tile(vv0, (2, 3, 2, 1)), po0),
|
|
47
|
+
(np.tile(vv, (2, 3, 2, 1, 1)), po),
|
|
48
|
+
]
|
|
49
|
+
|
|
50
|
+
for vv, po in testsets:
|
|
51
|
+
fw = po.forward(vv)
|
|
52
|
+
bw = po.backward(fw)
|
|
53
|
+
np.testing.assert_allclose(bw.eval(), vv)
|
|
54
|
+
|
|
55
|
+
def test_sample_model(self):
|
|
56
|
+
po = PartialOrder(self.adj_mats)
|
|
57
|
+
with pm.Model() as model:
|
|
58
|
+
x = pm.Normal(
|
|
59
|
+
"x",
|
|
60
|
+
size=(3, 2, 4),
|
|
61
|
+
transform=po,
|
|
62
|
+
initval=po.initvals(shape=(3, 2, 4), lower=-1, upper=1),
|
|
63
|
+
)
|
|
64
|
+
idata = pm.sample()
|
|
65
|
+
|
|
66
|
+
# Check that the order constraints are satisfied
|
|
67
|
+
# Move chain, draw and "3" dimensions to the back
|
|
68
|
+
xvs = idata.posterior.x.values.transpose(3, 4, 0, 1, 2)
|
|
69
|
+
x0 = xvs[0] # 0 < {1, 2} < 3
|
|
70
|
+
assert (
|
|
71
|
+
(x0[0] < x0[1]).all()
|
|
72
|
+
and (x0[0] < x0[2]).all()
|
|
73
|
+
and (x0[1] < x0[3]).all()
|
|
74
|
+
and (x0[2] < x0[3]).all()
|
|
75
|
+
)
|
|
76
|
+
x1 = xvs[1] # 1 < 0 < 3 < 2
|
|
77
|
+
assert (x1[1] < x1[0]).all() and (x1[0] < x1[3]).all() and (x1[3] < x1[2]).all()
|
|
@@ -8,6 +8,7 @@ import pytensor.tensor as pt
|
|
|
8
8
|
import pytest
|
|
9
9
|
|
|
10
10
|
from pymc_extras.statespace.models import structural
|
|
11
|
+
from pymc_extras.statespace.models.structural import LevelTrendComponent
|
|
11
12
|
from pymc_extras.statespace.utils.constants import (
|
|
12
13
|
FILTER_OUTPUT_DIMS,
|
|
13
14
|
FILTER_OUTPUT_NAMES,
|
|
@@ -114,3 +115,67 @@ def test_data_index_is_coord(f, warning, create_model):
|
|
|
114
115
|
with warning:
|
|
115
116
|
pymc_model = create_model(f)
|
|
116
117
|
assert TIME_DIM in pymc_model.coords
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def make_model(index):
|
|
121
|
+
n = len(index)
|
|
122
|
+
a = pd.DataFrame(index=index, columns=["A", "B", "C", "D"], data=np.arange(n * 4).reshape(n, 4))
|
|
123
|
+
|
|
124
|
+
mod = LevelTrendComponent(order=2, innovations_order=[0, 1])
|
|
125
|
+
ss_mod = mod.build(name="a", verbose=False)
|
|
126
|
+
|
|
127
|
+
initial_trend_dims, sigma_trend_dims, P0_dims = ss_mod.param_dims.values()
|
|
128
|
+
coords = ss_mod.coords
|
|
129
|
+
|
|
130
|
+
with pm.Model(coords=coords) as model:
|
|
131
|
+
P0_diag = pm.Gamma("P0_diag", alpha=5, beta=5)
|
|
132
|
+
P0 = pm.Deterministic("P0", pt.eye(ss_mod.k_states) * P0_diag, dims=P0_dims)
|
|
133
|
+
|
|
134
|
+
initial_trend = pm.Normal("initial_trend", dims=initial_trend_dims)
|
|
135
|
+
sigma_trend = pm.Gamma("sigma_trend", alpha=2, beta=50, dims=sigma_trend_dims)
|
|
136
|
+
|
|
137
|
+
with pytest.warns(UserWarning, match="No time index found on the supplied data"):
|
|
138
|
+
ss_mod.build_statespace_graph(
|
|
139
|
+
a["A"],
|
|
140
|
+
mode="JAX",
|
|
141
|
+
)
|
|
142
|
+
return model
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def test_integer_index():
|
|
146
|
+
index = np.arange(8).astype(int)
|
|
147
|
+
model = make_model(index)
|
|
148
|
+
assert TIME_DIM in model.coords
|
|
149
|
+
np.testing.assert_allclose(model.coords[TIME_DIM], index)
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def test_float_index_raises():
|
|
153
|
+
index = np.linspace(0, 1, 8)
|
|
154
|
+
|
|
155
|
+
with pytest.raises(IndexError, match="Provided index is not an integer index"):
|
|
156
|
+
make_model(index)
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def test_non_strictly_monotone_index_raises():
|
|
160
|
+
# Decreases
|
|
161
|
+
index = [0, 1, 2, 1, 2, 3]
|
|
162
|
+
with pytest.raises(IndexError, match="Provided index is not monotonic increasing"):
|
|
163
|
+
make_model(index)
|
|
164
|
+
|
|
165
|
+
# Has gaps
|
|
166
|
+
index = [0, 1, 2, 3, 5, 6]
|
|
167
|
+
with pytest.raises(IndexError, match="Provided index is not monotonic increasing"):
|
|
168
|
+
make_model(index)
|
|
169
|
+
|
|
170
|
+
# Has duplicates
|
|
171
|
+
index = [0, 1, 1, 2, 3, 4]
|
|
172
|
+
with pytest.raises(IndexError, match="Provided index is not monotonic increasing"):
|
|
173
|
+
make_model(index)
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def test_multiindex_raises():
|
|
177
|
+
index = pd.MultiIndex.from_tuples([(0, 0), (1, 1), (2, 2), (3, 3)])
|
|
178
|
+
with pytest.raises(
|
|
179
|
+
NotImplementedError, match="MultiIndex panel data is not currently supported"
|
|
180
|
+
):
|
|
181
|
+
make_model(index)
|
tests/test_laplace.py
CHANGED
|
@@ -263,3 +263,19 @@ def test_fit_laplace(fit_in_unconstrained_space, mode, gradient_backend: Gradien
|
|
|
263
263
|
else:
|
|
264
264
|
assert idata.fit.rows.values.tolist() == ["mu", "sigma"]
|
|
265
265
|
np.testing.assert_allclose(idata.fit.mean_vector.values, np.array([3.0, 1.5]), atol=0.1)
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
def test_laplace_scalar():
|
|
269
|
+
# Example model from Statistical Rethinking
|
|
270
|
+
data = np.array([0, 0, 0, 1, 1, 1, 1, 1, 1])
|
|
271
|
+
|
|
272
|
+
with pm.Model():
|
|
273
|
+
p = pm.Uniform("p", 0, 1)
|
|
274
|
+
w = pm.Binomial("w", n=len(data), p=p, observed=data.sum())
|
|
275
|
+
|
|
276
|
+
idata_laplace = pmx.fit_laplace(progressbar=False)
|
|
277
|
+
|
|
278
|
+
assert idata_laplace.fit.mean_vector.shape == (1,)
|
|
279
|
+
assert idata_laplace.fit.covariance_matrix.shape == (1, 1)
|
|
280
|
+
|
|
281
|
+
np.testing.assert_allclose(idata_laplace.fit.mean_vector.values.item(), data.mean(), atol=0.1)
|