pymc-extras 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -6,11 +6,9 @@ import pytensor.tensor as pt
6
6
  from pymc import intX
7
7
  from pymc.distributions.dist_math import check_parameters
8
8
  from pymc.distributions.distribution import Continuous, SymbolicRandomVariable
9
- from pymc.distributions.multivariate import MvNormal
10
9
  from pymc.distributions.shape_utils import get_support_shape_1d
11
10
  from pymc.logprob.abstract import _logprob
12
11
  from pytensor.graph.basic import Node
13
- from pytensor.tensor.random.basic import MvNormalRV
14
12
 
15
13
  floatX = pytensor.config.floatX
16
14
  COV_ZERO_TOL = 0
@@ -49,44 +47,6 @@ def make_signature(sequence_names):
49
47
  return f"{signature},[rng]->[rng],({time},{state_and_obs})"
50
48
 
51
49
 
52
- class MvNormalSVDRV(MvNormalRV):
53
- name = "multivariate_normal"
54
- signature = "(n),(n,n)->(n)"
55
- dtype = "floatX"
56
- _print_name = ("MultivariateNormal", "\\operatorname{MultivariateNormal}")
57
-
58
-
59
- class MvNormalSVD(MvNormal):
60
- """Dummy distribution intended to be rewritten into a JAX multivariate_normal with method="svd".
61
-
62
- A JAX MvNormal robust to low-rank covariance matrices
63
- """
64
-
65
- rv_op = MvNormalSVDRV()
66
-
67
-
68
- try:
69
- import jax.random
70
-
71
- from pytensor.link.jax.dispatch.random import jax_sample_fn
72
-
73
- @jax_sample_fn.register(MvNormalSVDRV)
74
- def jax_sample_fn_mvnormal_svd(op, node):
75
- def sample_fn(rng, size, dtype, *parameters):
76
- rng_key = rng["jax_state"]
77
- rng_key, sampling_key = jax.random.split(rng_key, 2)
78
- sample = jax.random.multivariate_normal(
79
- sampling_key, *parameters, shape=size, dtype=dtype, method="svd"
80
- )
81
- rng["jax_state"] = rng_key
82
- return (rng, sample)
83
-
84
- return sample_fn
85
-
86
- except ImportError:
87
- pass
88
-
89
-
90
50
  class LinearGaussianStateSpaceRV(SymbolicRandomVariable):
91
51
  default_output = 1
92
52
  _print_name = ("LinearGuassianStateSpace", "\\operatorname{LinearGuassianStateSpace}")
@@ -244,8 +204,12 @@ class _LinearGaussianStateSpace(Continuous):
244
204
  k = T.shape[0]
245
205
  a = state[:k]
246
206
 
247
- middle_rng, a_innovation = MvNormalSVD.dist(mu=0, cov=Q, rng=rng).owner.outputs
248
- next_rng, y_innovation = MvNormalSVD.dist(mu=0, cov=H, rng=middle_rng).owner.outputs
207
+ middle_rng, a_innovation = pm.MvNormal.dist(
208
+ mu=0, cov=Q, rng=rng, method="svd"
209
+ ).owner.outputs
210
+ next_rng, y_innovation = pm.MvNormal.dist(
211
+ mu=0, cov=H, rng=middle_rng, method="svd"
212
+ ).owner.outputs
249
213
 
250
214
  a_mu = c + T @ a
251
215
  a_next = a_mu + R @ a_innovation
@@ -260,8 +224,8 @@ class _LinearGaussianStateSpace(Continuous):
260
224
  Z_init = Z_ if Z_ in non_sequences else Z_[0]
261
225
  H_init = H_ if H_ in non_sequences else H_[0]
262
226
 
263
- init_x_ = MvNormalSVD.dist(a0_, P0_, rng=rng)
264
- init_y_ = MvNormalSVD.dist(Z_init @ init_x_, H_init, rng=rng)
227
+ init_x_ = pm.MvNormal.dist(a0_, P0_, rng=rng, method="svd")
228
+ init_y_ = pm.MvNormal.dist(Z_init @ init_x_, H_init, rng=rng, method="svd")
265
229
 
266
230
  init_dist_ = pt.concatenate([init_x_, init_y_], axis=0)
267
231
 
@@ -421,7 +385,7 @@ class SequenceMvNormal(Continuous):
421
385
  rng = pytensor.shared(np.random.default_rng())
422
386
 
423
387
  def step(mu, cov, rng):
424
- new_rng, mvn = MvNormalSVD.dist(mu=mu, cov=cov, rng=rng).owner.outputs
388
+ new_rng, mvn = pm.MvNormal.dist(mu=mu, cov=cov, rng=rng, method="svd").owner.outputs
425
389
  return mvn, {rng: new_rng}
426
390
 
427
391
  mvn_seq, updates = pytensor.scan(
@@ -87,12 +87,7 @@ def preprocess_pandas_data(data, n_obs, obs_coords=None, check_column_names=Fals
87
87
  col_names = data.columns
88
88
  _validate_data_shape(data.shape, n_obs, obs_coords, check_column_names, col_names)
89
89
 
90
- if isinstance(data.index, pd.RangeIndex):
91
- if obs_coords is not None:
92
- warnings.warn(NO_TIME_INDEX_WARNING)
93
- return preprocess_numpy_data(data.values, n_obs, obs_coords)
94
-
95
- elif isinstance(data.index, pd.DatetimeIndex):
90
+ if isinstance(data.index, pd.DatetimeIndex):
96
91
  if data.index.freq is None:
97
92
  warnings.warn(NO_FREQ_INFO_WARNING)
98
93
  data.index.freq = data.index.inferred_freq
@@ -100,10 +95,30 @@ def preprocess_pandas_data(data, n_obs, obs_coords=None, check_column_names=Fals
100
95
  index = data.index
101
96
  return data.values, index
102
97
 
98
+ elif isinstance(data.index, pd.RangeIndex):
99
+ if obs_coords is not None:
100
+ warnings.warn(NO_TIME_INDEX_WARNING)
101
+ return preprocess_numpy_data(data.values, n_obs, obs_coords)
102
+
103
+ elif isinstance(data.index, pd.MultiIndex):
104
+ if obs_coords is not None:
105
+ warnings.warn(NO_TIME_INDEX_WARNING)
106
+
107
+ raise NotImplementedError("MultiIndex panel data is not currently supported.")
108
+
103
109
  else:
104
- raise IndexError(
105
- f"Expected pd.DatetimeIndex or pd.RangeIndex on data, found {type(data.index)}"
106
- )
110
+ if obs_coords is not None:
111
+ warnings.warn(NO_TIME_INDEX_WARNING)
112
+
113
+ index = data.index
114
+ if not np.issubdtype(index.dtype, np.integer):
115
+ raise IndexError("Provided index is not an integer index.")
116
+
117
+ index_diff = index.to_series().diff().dropna().values
118
+ if not (index_diff == 1).all():
119
+ raise IndexError("Provided index is not monotonic increasing.")
120
+
121
+ return preprocess_numpy_data(data.values, n_obs, obs_coords)
107
122
 
108
123
 
109
124
  def add_data_to_active_model(values, index, data_dims=None):
pymc_extras/version.txt CHANGED
@@ -1 +1 @@
1
- 0.2.3
1
+ 0.2.5
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.2
1
+ Metadata-Version: 2.4
2
2
  Name: pymc-extras
3
- Version: 0.2.3
3
+ Version: 0.2.5
4
4
  Summary: A home for new additions to PyMC, which may include unusual probability distributions, advanced model fitting algorithms, or any code that may be inappropriate to include in the pymc repository, but may want to be made available to users.
5
5
  Home-page: http://github.com/pymc-devs/pymc-extras
6
6
  Maintainer: PyMC Developers
@@ -12,6 +12,7 @@ Classifier: Programming Language :: Python :: 3
12
12
  Classifier: Programming Language :: Python :: 3.10
13
13
  Classifier: Programming Language :: Python :: 3.11
14
14
  Classifier: Programming Language :: Python :: 3.12
15
+ Classifier: Programming Language :: Python :: 3.13
15
16
  Classifier: License :: OSI Approved :: Apache Software License
16
17
  Classifier: Intended Audience :: Science/Research
17
18
  Classifier: Topic :: Scientific/Engineering
@@ -20,7 +21,7 @@ Classifier: Operating System :: OS Independent
20
21
  Requires-Python: >=3.10
21
22
  Description-Content-Type: text/markdown
22
23
  License-File: LICENSE
23
- Requires-Dist: pymc>=5.20
24
+ Requires-Dist: pymc>=5.21.1
24
25
  Requires-Dist: scikit-learn
25
26
  Requires-Dist: better-optimize
26
27
  Provides-Extra: dask-histogram
@@ -40,6 +41,7 @@ Dynamic: description
40
41
  Dynamic: description-content-type
41
42
  Dynamic: home-page
42
43
  Dynamic: license
44
+ Dynamic: license-file
43
45
  Dynamic: maintainer
44
46
  Dynamic: maintainer-email
45
47
  Dynamic: provides-extra
@@ -1,26 +1,28 @@
1
- pymc_extras/__init__.py,sha256=IFIEZdPX_Ugq57Bu7jlyrJLpKng-P0FBAAAzl2pFXLE,1266
1
+ pymc_extras/__init__.py,sha256=lYGf9TcwUHROIElkX7Epnb7-IppcmiSEYuxdtRzqS3s,1195
2
2
  pymc_extras/linearmodel.py,sha256=6eitl15Ec15mSZu7zoHZ7Wwy4U1DPwqfAgwEt6ILeIc,3920
3
3
  pymc_extras/model_builder.py,sha256=sAw77fxdiy046BvDPjocuMlbJ0Efj-CDAGtmcwYmoG0,26361
4
4
  pymc_extras/printing.py,sha256=G8mj9dRd6i0PcsbcEWZm56ek6V8mmil78RI4MUhywBs,6506
5
5
  pymc_extras/version.py,sha256=VxPGCBzhtSegu-Jp5cjzn0n4DGU0wuPUh-KyZKB6uPM,240
6
- pymc_extras/version.txt,sha256=OrlMBNJJhvOvKIuhzaLAu928Wonf8JcYKAX1RXjh6nU,6
7
- pymc_extras/distributions/__init__.py,sha256=gTX7tvX8NcgP7V72URV7GeqF1aAEjGVbuW8LMxhXceY,1295
6
+ pymc_extras/version.txt,sha256=6Vn3UOktu3YUriislvCjcnLK7YHYu7dYeRr3v7thBqA,6
7
+ pymc_extras/distributions/__init__.py,sha256=fDbrBt9mxEVp2CDPwnyCW3oiutzZ0PduB8EUH3fUrjI,1377
8
8
  pymc_extras/distributions/continuous.py,sha256=z-nvQgGncYISdRY8cWsa-56V0bQGq70jYwU-i8VZ0Uk,11253
9
9
  pymc_extras/distributions/discrete.py,sha256=vrARNuiQAEXrs7yQgImV1PO8AV1uyEC_LBhr6F9IcOg,13032
10
10
  pymc_extras/distributions/histogram_utils.py,sha256=5RTvlGCUrp2qzshrchmPyWxjhs6RIYL62SMikjDM1TU,5814
11
11
  pymc_extras/distributions/timeseries.py,sha256=M5MZ-nik_tgkaoZ1hdUGEZ9g04DQyVLwszVJqSKwNcY,12719
12
12
  pymc_extras/distributions/multivariate/__init__.py,sha256=E8OeLW9tTotCbrUjEo4um76-_WQD56PehsPzkKmhfyA,93
13
13
  pymc_extras/distributions/multivariate/r2d2m2cp.py,sha256=bUj9bB-hQi6CpaJfvJjgNPi727uTbvAdxl9fm1zNBqY,16005
14
+ pymc_extras/distributions/transforms/__init__.py,sha256=FUp2vyRE6_2eUcQ_FVt5Dn0-vy5I-puV-Kz13-QtLNc,104
15
+ pymc_extras/distributions/transforms/partial_order.py,sha256=oEZlc9WgnGR46uFEjLzKEUxlhzIo2vrUUbBE3vYrsfQ,8404
14
16
  pymc_extras/gp/__init__.py,sha256=sFHw2y3lEl5tG_FDQHZUonQ_k0DF1JRf0Rp8dpHmge0,745
15
17
  pymc_extras/gp/latent_approx.py,sha256=cDEMM6H1BL2qyKg7BZU-ISrKn2HJe7hDaM4Y8GgQDf4,6682
16
- pymc_extras/inference/__init__.py,sha256=5cXpaQQnW0mJJ3x8wSxmYu63l--Xab5D_gMtjA6Q3uU,666
18
+ pymc_extras/inference/__init__.py,sha256=UH6S0bGfQKKyTSuqf7yezdy9PeE2bDU8U1v4eIRv4ZI,887
17
19
  pymc_extras/inference/find_map.py,sha256=vl5l0ei48PnX-uTuHVTr-9QpCEHc8xog-KK6sOnJ8LU,16513
18
- pymc_extras/inference/fit.py,sha256=S9R48dh74s6K0MC9Iys4NAwVjP6rVRfx6SF-kPiR70E,1165
19
- pymc_extras/inference/laplace.py,sha256=uOZGp8ssQuhvCHV_Y_v3icsr4rhcYgr_qlr9dS7pcSM,21761
20
+ pymc_extras/inference/fit.py,sha256=oe20RAajImZ-VD9Ucbzri8Bof4Y2KHNhNRG19v9O3lI,1336
21
+ pymc_extras/inference/laplace.py,sha256=cqarAdbFaOH74AkPUF4c7c_Hswa5mqmhgHpsgrkebHY,21860
20
22
  pymc_extras/inference/pathfinder/__init__.py,sha256=FhAYrCWNx_dCrynEdjg2CZ9tIinvcVLBm67pNx_Y3kA,101
21
- pymc_extras/inference/pathfinder/importance_sampling.py,sha256=VvmuaE3aw_Mo3tMwswfF0rqe19mnhOCpzIScaJzjA1Y,6159
22
- pymc_extras/inference/pathfinder/lbfgs.py,sha256=P0UIOVtspdLzDU6alK-y91qzVAzXjYAXPuGmZ1nRqMo,5715
23
- pymc_extras/inference/pathfinder/pathfinder.py,sha256=fomZ5voVcWxvhWpeIZV7IHGIJCasT1g0ivC4dC3-0GM,63694
23
+ pymc_extras/inference/pathfinder/importance_sampling.py,sha256=NwxepXOFit3cA5zEebniKdlnJ1rZWg56aMlH4MEOcG4,6264
24
+ pymc_extras/inference/pathfinder/lbfgs.py,sha256=GOoJBil5Kft_iFwGNUGKSeqzI5x_shA4KQWDwgGuQtQ,7110
25
+ pymc_extras/inference/pathfinder/pathfinder.py,sha256=GW04HQurj_3Nlo1C6_K2tEIeigo8x0buV3FqDLA88PQ,64439
24
26
  pymc_extras/inference/smc/__init__.py,sha256=wyaT4NJl1YsSQRLiDy-i0Jq3CbJZ2BQd4nnCk-dIngY,603
25
27
  pymc_extras/inference/smc/sampling.py,sha256=AYwmKqGoV6pBtKnh9SUbBKbN7VcoFgb3MmNWV7SivMA,15365
26
28
  pymc_extras/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -37,9 +39,9 @@ pymc_extras/statespace/__init__.py,sha256=0MtZj7yT6jcyERvITnn-nkhyY8fO6Za4_vV53C
37
39
  pymc_extras/statespace/core/__init__.py,sha256=huHEiXAm8zV2MZyZ8GBHp6q7_fnWqveM7lC6ilpb3iE,309
38
40
  pymc_extras/statespace/core/compile.py,sha256=9FZfE8Bi3VfElxujfOIKRVvmyL9M5R0WfNEqPc5kbVQ,1603
39
41
  pymc_extras/statespace/core/representation.py,sha256=DwNIun6wdeEA20oWBx5M4govyWTf5JI87aGQ_E6Mb4U,18956
40
- pymc_extras/statespace/core/statespace.py,sha256=K_WVnWKlI6sR2kgriq9sctQVvwXCeAirm14TthDpmRM,96860
42
+ pymc_extras/statespace/core/statespace.py,sha256=Tx-821UNNLqsZgHzRmwaQ6s-agp_OthqSsbfwDpA1o0,96927
41
43
  pymc_extras/statespace/filters/__init__.py,sha256=N9Q4D0gAq_ZtT-GtrqiX1HkSg6Orv7o1TbrWUtnbTJE,420
42
- pymc_extras/statespace/filters/distributions.py,sha256=-9j__vRqL5hKyYFnQr5HKHA5kEFzwiuSccH4mslTOuQ,12900
44
+ pymc_extras/statespace/filters/distributions.py,sha256=ejimTFLgBFZMEznxY5zh6u4Vrqij60i0k2_sxdPcZ3A,11878
43
45
  pymc_extras/statespace/filters/kalman_filter.py,sha256=HELC3aK4k8EdWlUAk5_F7y7YkIz-Xi_0j2AwRgAXgcc,31949
44
46
  pymc_extras/statespace/filters/kalman_smoother.py,sha256=ypH9K_88nfJ5K2Cq737aWL3p8v4UfI7MxnYs54WPdDs,4329
45
47
  pymc_extras/statespace/filters/utilities.py,sha256=iwdaYnO1cO06t_XUjLLRmqb8vwzzVH6Nx1iyZcbJL2k,1584
@@ -52,21 +54,22 @@ pymc_extras/statespace/models/utilities.py,sha256=G9GuHKsghmIYOlfkPtvxBWF-FZY5-5
52
54
  pymc_extras/statespace/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
53
55
  pymc_extras/statespace/utils/constants.py,sha256=Kf6j75ABaDQeRODxKQ76wTUQV4F5sTjn1KBcZgCQx20,2403
54
56
  pymc_extras/statespace/utils/coord_tools.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
55
- pymc_extras/statespace/utils/data_tools.py,sha256=caanvrxDu9g-dEKff2bbmaVTs6-71kkSoYIiiSUXhw4,5985
57
+ pymc_extras/statespace/utils/data_tools.py,sha256=01sz6XDtLYK9I5xghxYpD-PuDzGXv9D-wFGfTV6FGEw,6566
56
58
  pymc_extras/utils/__init__.py,sha256=yxI9cJ7fCtVQS0GFw0y6mDGZIQZiK53vm3UNKqIuGSk,758
57
59
  pymc_extras/utils/linear_cg.py,sha256=KkXhuimFsrKtNd_0By2ApxQQQNm5FdBtmDQJOVbLYkA,10056
58
60
  pymc_extras/utils/model_equivalence.py,sha256=8QIftID2HDxD659i0RXHazQ-l2Q5YegCRLcDqb2p9Pc,2187
59
61
  pymc_extras/utils/pivoted_cholesky.py,sha256=QtnjP0pAl9b77fLAu-semwT4_9dcoiqx3dz1xKGBjMk,1871
60
62
  pymc_extras/utils/prior.py,sha256=QlWVr7uKIK9VncBw7Fz3YgaASKGDfqpORZHc-vz_9gQ,6841
61
63
  pymc_extras/utils/spline.py,sha256=qGq0gcoMG5dpdazKFzG0RXkkCWP8ADPPXN-653-oFn4,4820
64
+ pymc_extras-0.2.5.dist-info/licenses/LICENSE,sha256=WjiLhUKEysJvy5e9jk6WwFv9tmAPtnov1uJ6gcH1kIs,11720
62
65
  tests/__init__.py,sha256=-ree9OWVCyTeXLR944OWjrQX2os15HXrRNkhJ7QdRjc,603
63
66
  tests/test_blackjax_smc.py,sha256=jcNgcMBxaKyPg9UvHnWQtwoL79LXlSpZfALe3RGEZnQ,7233
64
67
  tests/test_find_map.py,sha256=B8ThnXNyfTQeem24QaLoTitFrsxKoq2VQINUdOwzna0,3379
65
68
  tests/test_histogram_approximation.py,sha256=w-xb2Rr0Qft6sm6F3BTmXXnpuqyefC1SUL6YxzqA5X4,4674
66
- tests/test_laplace.py,sha256=u4o-0y4v1emaTMYr_rOyL_EKY_bQIz0DUXFuwuDbfNg,9314
69
+ tests/test_laplace.py,sha256=fArHjwMR7x98K-gZLvrvb3AwNZ7_fo7E0A4SJyt4EGU,9843
67
70
  tests/test_linearmodel.py,sha256=iB8ApNqIX9_nUHoo-Tm51xuPdrva5t4VLLut6qXB5Ao,6906
68
71
  tests/test_model_builder.py,sha256=QiINEihBR9rx8xM4Nqlg4urZKoyo58aTKDtxl9SJF1s,11249
69
- tests/test_pathfinder.py,sha256=GnSbZJ9QuFW9UVbkWaVgMVqQZTCttOyz_rSflxhQ-EA,4955
72
+ tests/test_pathfinder.py,sha256=vlMI1p2Ja5X4QIaSV4h6U41I303rEppfO0JqE3xe1Rs,10023
70
73
  tests/test_pivoted_cholesky.py,sha256=PuMdMSCzO4KdQWpUF4SEBeuH_qsINCIH8TYtmmJ1NKo,692
71
74
  tests/test_printing.py,sha256=HnvwwjrjBuxXFAJdyU0K_lvKGLgh4nzHAnhsIUpenbY,5211
72
75
  tests/test_prior_from_trace.py,sha256=HOzR3l98pl7TEJquo_kSugED4wBTgHo4-8lgnpmacs8,5516
@@ -77,6 +80,7 @@ tests/distributions/test_continuous.py,sha256=1-bu-IP6RgLUJnuPYpOD8ZS1ahYbKtsJ9o
77
80
  tests/distributions/test_discrete.py,sha256=CjjaUpppsvQ6zLzV15ZsbwNOKrDmEdz4VWcleoCXUi0,7776
78
81
  tests/distributions/test_discrete_markov_chain.py,sha256=8RCHZXSB8IWjniuKaGGlM_iTWGmdrcOqginxmrAeEJg,9212
79
82
  tests/distributions/test_multivariate.py,sha256=LBvBuoT_3rzi8rR38b8L441Y-9Ff0cIXeRBKiEn6kjs,10452
83
+ tests/distributions/test_transform.py,sha256=QM9sSQ5eSbuT2pM76nUMWqb-tQa7DGZbT9uwFDqIRUk,2672
80
84
  tests/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
81
85
  tests/model/test_model_api.py,sha256=FJvMTmexovRELZOUcUyk-6Vwk9qSiH7hIFoiArgl5mk,1040
82
86
  tests/model/marginal/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -87,7 +91,7 @@ tests/statespace/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,
87
91
  tests/statespace/test_ETS.py,sha256=IPg3uQ7xEGqDMEHu993vtUTV7r-uNAxmw23sr5MVGfQ,15582
88
92
  tests/statespace/test_SARIMAX.py,sha256=1BYNOm9aSHnpn-qbpe3YsQVH8m-mXcp_gvKgWhWn1W4,12948
89
93
  tests/statespace/test_VARMAX.py,sha256=rJnea9_WEGo9I0iv2eaSbwwFQv0tlIjpN7KAE0eQewU,6336
90
- tests/statespace/test_coord_assignment.py,sha256=2GBm46-0eI4QNh4bvp3D7az58stcA5Zo6VgOo_JkCig,3821
94
+ tests/statespace/test_coord_assignment.py,sha256=2Mo5196ibkBTwscE7kqQoUsgQphdaagVkOccDi7D4RI,5980
91
95
  tests/statespace/test_distributions.py,sha256=WQ_ROyd-PL3cimXTyEtyVaMEVtS7Hue2Z0lN7UnGDyo,9122
92
96
  tests/statespace/test_kalman_filter.py,sha256=s2n62FzXl9elU_uqaMNaEaexUfq3SXe3_YvQ2lM6hiQ,11600
93
97
  tests/statespace/test_representation.py,sha256=1KAJY4ZaVhb1WdAJLx2UYSXuVYsMNWX98gEDF7P0B4s,6210
@@ -98,8 +102,7 @@ tests/statespace/utilities/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJ
98
102
  tests/statespace/utilities/shared_fixtures.py,sha256=SNw8Bvj1Yw11TxAW6n20Bq0B8oaYtVTiFFEVNH_wnp4,164
99
103
  tests/statespace/utilities/statsmodel_local_level.py,sha256=SQAzaYaSDwiVhUQ1iWjt4MgfAd54RuzVtnslIs3xdS8,1225
100
104
  tests/statespace/utilities/test_helpers.py,sha256=oH24a6Q45NFFFI3Kx9mhKbxsCvo9ErCorKFoTjDB3-4,9159
101
- pymc_extras-0.2.3.dist-info/LICENSE,sha256=WjiLhUKEysJvy5e9jk6WwFv9tmAPtnov1uJ6gcH1kIs,11720
102
- pymc_extras-0.2.3.dist-info/METADATA,sha256=ZTiMM7hvVRF3O_liRu4Aea_EuxJc4vHfTD2CbRRQrcU,5152
103
- pymc_extras-0.2.3.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
104
- pymc_extras-0.2.3.dist-info/top_level.txt,sha256=D6RkgBiXiZCel0nvsYg_zYEoT_VuwocyIY98EMaulj0,18
105
- pymc_extras-0.2.3.dist-info/RECORD,,
105
+ pymc_extras-0.2.5.dist-info/METADATA,sha256=02v5liTQQ55sV8xeFl5EjFpwbOKSYKG6g5lE_4htpBo,5227
106
+ pymc_extras-0.2.5.dist-info/WHEEL,sha256=lTU6B6eIfYoiQJTZNc-fyaR6BpL6ehTzU3xGYxn2n8k,91
107
+ pymc_extras-0.2.5.dist-info/top_level.txt,sha256=D6RkgBiXiZCel0nvsYg_zYEoT_VuwocyIY98EMaulj0,18
108
+ pymc_extras-0.2.5.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.8.0)
2
+ Generator: setuptools (78.1.1)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -0,0 +1,77 @@
1
+ # Copyright 2025 The PyMC Developers
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import numpy as np
15
+ import pymc as pm
16
+
17
+ from pymc_extras.distributions.transforms import PartialOrder
18
+
19
+
20
+ class TestPartialOrder:
21
+ adj_mats = np.array(
22
+ [
23
+ # 0 < {1, 2} < 3
24
+ [[0, 1, 1, 0], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 0]],
25
+ # 1 < 0 < 3 < 2
26
+ [[0, 0, 0, 1], [1, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 0]],
27
+ ]
28
+ )
29
+
30
+ valid_values = np.array([[0, 2, 1, 3], [1, 0, 3, 2]], dtype=float)
31
+
32
+ # Test that forward and backward are inverses of each other
33
+ # And that it works when extra dimensions are added in data
34
+ def test_forward_backward_dimensionality(self):
35
+ po = PartialOrder(self.adj_mats)
36
+ po0 = PartialOrder(self.adj_mats[0])
37
+ vv = self.valid_values
38
+ vv0 = self.valid_values[0]
39
+
40
+ testsets = [
41
+ (vv, po),
42
+ (po.initvals(), po),
43
+ (vv0, po0),
44
+ (po0.initvals(), po0),
45
+ (np.tile(vv0, (2, 1)), po0),
46
+ (np.tile(vv0, (2, 3, 2, 1)), po0),
47
+ (np.tile(vv, (2, 3, 2, 1, 1)), po),
48
+ ]
49
+
50
+ for vv, po in testsets:
51
+ fw = po.forward(vv)
52
+ bw = po.backward(fw)
53
+ np.testing.assert_allclose(bw.eval(), vv)
54
+
55
+ def test_sample_model(self):
56
+ po = PartialOrder(self.adj_mats)
57
+ with pm.Model() as model:
58
+ x = pm.Normal(
59
+ "x",
60
+ size=(3, 2, 4),
61
+ transform=po,
62
+ initval=po.initvals(shape=(3, 2, 4), lower=-1, upper=1),
63
+ )
64
+ idata = pm.sample()
65
+
66
+ # Check that the order constraints are satisfied
67
+ # Move chain, draw and "3" dimensions to the back
68
+ xvs = idata.posterior.x.values.transpose(3, 4, 0, 1, 2)
69
+ x0 = xvs[0] # 0 < {1, 2} < 3
70
+ assert (
71
+ (x0[0] < x0[1]).all()
72
+ and (x0[0] < x0[2]).all()
73
+ and (x0[1] < x0[3]).all()
74
+ and (x0[2] < x0[3]).all()
75
+ )
76
+ x1 = xvs[1] # 1 < 0 < 3 < 2
77
+ assert (x1[1] < x1[0]).all() and (x1[0] < x1[3]).all() and (x1[3] < x1[2]).all()
@@ -8,6 +8,7 @@ import pytensor.tensor as pt
8
8
  import pytest
9
9
 
10
10
  from pymc_extras.statespace.models import structural
11
+ from pymc_extras.statespace.models.structural import LevelTrendComponent
11
12
  from pymc_extras.statespace.utils.constants import (
12
13
  FILTER_OUTPUT_DIMS,
13
14
  FILTER_OUTPUT_NAMES,
@@ -114,3 +115,67 @@ def test_data_index_is_coord(f, warning, create_model):
114
115
  with warning:
115
116
  pymc_model = create_model(f)
116
117
  assert TIME_DIM in pymc_model.coords
118
+
119
+
120
+ def make_model(index):
121
+ n = len(index)
122
+ a = pd.DataFrame(index=index, columns=["A", "B", "C", "D"], data=np.arange(n * 4).reshape(n, 4))
123
+
124
+ mod = LevelTrendComponent(order=2, innovations_order=[0, 1])
125
+ ss_mod = mod.build(name="a", verbose=False)
126
+
127
+ initial_trend_dims, sigma_trend_dims, P0_dims = ss_mod.param_dims.values()
128
+ coords = ss_mod.coords
129
+
130
+ with pm.Model(coords=coords) as model:
131
+ P0_diag = pm.Gamma("P0_diag", alpha=5, beta=5)
132
+ P0 = pm.Deterministic("P0", pt.eye(ss_mod.k_states) * P0_diag, dims=P0_dims)
133
+
134
+ initial_trend = pm.Normal("initial_trend", dims=initial_trend_dims)
135
+ sigma_trend = pm.Gamma("sigma_trend", alpha=2, beta=50, dims=sigma_trend_dims)
136
+
137
+ with pytest.warns(UserWarning, match="No time index found on the supplied data"):
138
+ ss_mod.build_statespace_graph(
139
+ a["A"],
140
+ mode="JAX",
141
+ )
142
+ return model
143
+
144
+
145
+ def test_integer_index():
146
+ index = np.arange(8).astype(int)
147
+ model = make_model(index)
148
+ assert TIME_DIM in model.coords
149
+ np.testing.assert_allclose(model.coords[TIME_DIM], index)
150
+
151
+
152
+ def test_float_index_raises():
153
+ index = np.linspace(0, 1, 8)
154
+
155
+ with pytest.raises(IndexError, match="Provided index is not an integer index"):
156
+ make_model(index)
157
+
158
+
159
+ def test_non_strictly_monotone_index_raises():
160
+ # Decreases
161
+ index = [0, 1, 2, 1, 2, 3]
162
+ with pytest.raises(IndexError, match="Provided index is not monotonic increasing"):
163
+ make_model(index)
164
+
165
+ # Has gaps
166
+ index = [0, 1, 2, 3, 5, 6]
167
+ with pytest.raises(IndexError, match="Provided index is not monotonic increasing"):
168
+ make_model(index)
169
+
170
+ # Has duplicates
171
+ index = [0, 1, 1, 2, 3, 4]
172
+ with pytest.raises(IndexError, match="Provided index is not monotonic increasing"):
173
+ make_model(index)
174
+
175
+
176
+ def test_multiindex_raises():
177
+ index = pd.MultiIndex.from_tuples([(0, 0), (1, 1), (2, 2), (3, 3)])
178
+ with pytest.raises(
179
+ NotImplementedError, match="MultiIndex panel data is not currently supported"
180
+ ):
181
+ make_model(index)
tests/test_laplace.py CHANGED
@@ -263,3 +263,19 @@ def test_fit_laplace(fit_in_unconstrained_space, mode, gradient_backend: Gradien
263
263
  else:
264
264
  assert idata.fit.rows.values.tolist() == ["mu", "sigma"]
265
265
  np.testing.assert_allclose(idata.fit.mean_vector.values, np.array([3.0, 1.5]), atol=0.1)
266
+
267
+
268
+ def test_laplace_scalar():
269
+ # Example model from Statistical Rethinking
270
+ data = np.array([0, 0, 0, 1, 1, 1, 1, 1, 1])
271
+
272
+ with pm.Model():
273
+ p = pm.Uniform("p", 0, 1)
274
+ w = pm.Binomial("w", n=len(data), p=p, observed=data.sum())
275
+
276
+ idata_laplace = pmx.fit_laplace(progressbar=False)
277
+
278
+ assert idata_laplace.fit.mean_vector.shape == (1,)
279
+ assert idata_laplace.fit.covariance_matrix.shape == (1, 1)
280
+
281
+ np.testing.assert_allclose(idata_laplace.fit.mean_vector.values.item(), data.mean(), atol=0.1)