pymc-extras 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -19,7 +19,8 @@ from pymc.model.fgraph import (
19
19
  model_free_rv,
20
20
  model_from_fgraph,
21
21
  )
22
- from pymc.pytensorf import collect_default_updates, compile_pymc, constant_fold, toposort_replace
22
+ from pymc.pytensorf import collect_default_updates, constant_fold, toposort_replace
23
+ from pymc.pytensorf import compile as compile_pymc
23
24
  from pymc.util import RandomState, _get_seeds_per_chain
24
25
  from pytensor import In, Out
25
26
  from pytensor.compile import SharedVariable
@@ -1,6 +1,9 @@
1
1
  from functools import wraps
2
+ from inspect import signature
2
3
 
3
- from pymc import Model
4
+ import pytensor.tensor as pt
5
+
6
+ from pymc import Data, Model
4
7
 
5
8
 
6
9
  def as_model(*model_args, **model_kwargs):
@@ -9,6 +12,8 @@ def as_model(*model_args, **model_kwargs):
9
12
  This removes all need to think about context managers and lets you separate creating a generative model from using the model.
10
13
  Additionally, a coords argument is added to the function so coords can be changed during function invocation
11
14
 
15
+ All parameters are wrapped with a `pm.Data` object if the underlying type of the data supports it.
16
+
12
17
  Adapted from `Rob Zinkov's blog post <https://www.zinkov.com/posts/2023-alternative-frontends-pymc/>`_ and inspired by the `sampled <https://github.com/colcarroll/sampled>`_ decorator for PyMC3.
13
18
 
14
19
  Examples
@@ -47,8 +52,19 @@ def as_model(*model_args, **model_kwargs):
47
52
  @wraps(f)
48
53
  def make_model(*args, **kwargs):
49
54
  coords = model_kwargs.pop("coords", {}) | kwargs.pop("coords", {})
55
+ sig = signature(f)
56
+ ba = sig.bind(*args, **kwargs)
57
+ ba.apply_defaults()
58
+
50
59
  with Model(*model_args, coords=coords, **model_kwargs) as m:
51
- f(*args, **kwargs)
60
+ for name, v in ba.arguments.items():
61
+ # Only wrap pm.Data around values pytensor can process
62
+ try:
63
+ _ = pt.as_tensor_variable(v)
64
+ ba.arguments[name] = Data(name, v)
65
+ except (NotImplementedError, TypeError, ValueError):
66
+ pass
67
+ f(*ba.args, **ba.kwargs)
52
68
  return m
53
69
 
54
70
  return make_model
@@ -30,7 +30,7 @@ def compile_statespace(
30
30
 
31
31
  inputs = list(pytensor.graph.basic.explicit_graph_inputs(outputs))
32
32
 
33
- _f = pm.compile_pymc(inputs, outputs, on_unused_input="ignore", **compile_kwargs)
33
+ _f = pm.compile(inputs, outputs, on_unused_input="ignore", **compile_kwargs)
34
34
 
35
35
  def f(*, draws=1, **params):
36
36
  if isinstance(steps, pt.Variable):
@@ -15,6 +15,9 @@ from pymc.model.transform.optimization import freeze_dims_and_data
15
15
  from pymc.util import RandomState
16
16
  from pytensor import Variable, graph_replace
17
17
  from pytensor.compile import get_mode
18
+ from rich.box import SIMPLE_HEAD
19
+ from rich.console import Console
20
+ from rich.table import Table
18
21
 
19
22
  from pymc_extras.statespace.core.representation import PytensorRepresentation
20
23
  from pymc_extras.statespace.filters import (
@@ -254,53 +257,72 @@ class PyMCStateSpace:
254
257
  self.kalman_smoother = KalmanSmoother()
255
258
  self.make_symbolic_graph()
256
259
 
257
- if verbose:
258
- # These are split into separate try-except blocks, because it will be quite rare of models to implement
259
- # _print_data_requirements, but we still want to print the prior requirements.
260
- try:
261
- self._print_prior_requirements()
262
- except NotImplementedError:
263
- pass
264
- try:
265
- self._print_data_requirements()
266
- except NotImplementedError:
267
- pass
268
-
269
- def _print_prior_requirements(self) -> None:
270
- """
271
- Prints a short report to the terminal about the priors needed for the model, including their names,
260
+ self.requirement_table = None
261
+ self._populate_prior_requirements()
262
+ self._populate_data_requirements()
263
+
264
+ if verbose and self.requirement_table:
265
+ console = Console()
266
+ console.print(self.requirement_table)
267
+
268
+ def _populate_prior_requirements(self) -> None:
269
+ """
270
+ Add requirements about priors needed for the model to a rich table, including their names,
272
271
  shapes, named dimensions, and any parameter constraints.
273
272
  """
274
- out = ""
275
- for param, info in self.param_info.items():
276
- out += f'\t{param} -- shape: {info["shape"]}, constraints: {info["constraints"]}, dims: {info["dims"]}\n'
277
- out = out.rstrip()
273
+ # Check that the param_info class is implemented, and also that it's a dictionary. We can't proceed if either
274
+ # is not true.
275
+ try:
276
+ if not isinstance(self.param_info, dict):
277
+ return
278
+ except NotImplementedError:
279
+ return
278
280
 
279
- _log.info(
280
- "The following parameters should be assigned priors inside a PyMC "
281
- f"model block: \n"
282
- f"{out}"
283
- )
281
+ if self.requirement_table is None:
282
+ self._initialize_requirement_table()
283
+
284
+ for param, info in self.param_info.items():
285
+ self.requirement_table.add_row(
286
+ param, str(info["shape"]), info["constraints"], str(info["dims"])
287
+ )
284
288
 
285
- def _print_data_requirements(self) -> None:
289
+ def _populate_data_requirements(self) -> None:
286
290
  """
287
- Prints a short report to the terminal about the data needed for the model, including their names, shapes,
288
- and named dimensions.
291
+ Add requirements about the data needed for the model, including their names, shapes, and named dimensions.
289
292
  """
290
- if not self.data_info:
293
+ try:
294
+ if not isinstance(self.data_info, dict):
295
+ return
296
+ except NotImplementedError:
291
297
  return
292
298
 
293
- out = ""
299
+ if self.requirement_table is None:
300
+ self._initialize_requirement_table()
301
+ else:
302
+ self.requirement_table.add_section()
303
+
294
304
  for data, info in self.data_info.items():
295
- out += f'\t{data} -- shape: {info["shape"]}, dims: {info["dims"]}\n'
296
- out = out.rstrip()
305
+ self.requirement_table.add_row(data, str(info["shape"]), "pm.Data", str(info["dims"]))
306
+
307
+ def _initialize_requirement_table(self) -> None:
308
+ self.requirement_table = Table(
309
+ show_header=True,
310
+ show_edge=True,
311
+ box=SIMPLE_HEAD,
312
+ highlight=True,
313
+ )
297
314
 
298
- _log.info(
299
- "The following Data variables should be assigned to the model inside a PyMC "
300
- f"model block: \n"
301
- f"{out}"
315
+ self.requirement_table.title = "Model Requirements"
316
+ self.requirement_table.caption = (
317
+ "These parameters should be assigned priors inside a PyMC model block before "
318
+ "calling the build_statespace_graph method."
302
319
  )
303
320
 
321
+ self.requirement_table.add_column("Variable", justify="left")
322
+ self.requirement_table.add_column("Shape", justify="left")
323
+ self.requirement_table.add_column("Constraints", justify="left")
324
+ self.requirement_table.add_column("Dimensions", justify="right")
325
+
304
326
  def _unpack_statespace_with_placeholders(
305
327
  self,
306
328
  ) -> tuple[
@@ -961,10 +983,31 @@ class PyMCStateSpace:
961
983
  list[pm.Flat]
962
984
  A list of pm.Flat variables representing all parameters estimated by the model.
963
985
  """
986
+
987
+ def infer_variable_shape(name):
988
+ shape = self._name_to_variable[name].type.shape
989
+ if not any(dim is None for dim in shape):
990
+ return shape
991
+
992
+ dim_names = self._fit_dims.get(name, None)
993
+ if dim_names is None:
994
+ raise ValueError(
995
+ f"Could not infer shape for {name}, because it was not given coords during model"
996
+ f"fitting"
997
+ )
998
+
999
+ shape_from_coords = tuple([len(self._fit_coords[dim]) for dim in dim_names])
1000
+ return tuple(
1001
+ [
1002
+ shape[i] if shape[i] is not None else shape_from_coords[i]
1003
+ for i in range(len(shape))
1004
+ ]
1005
+ )
1006
+
964
1007
  for name in self.param_names:
965
1008
  pm.Flat(
966
1009
  name,
967
- shape=self._name_to_variable[name].type.shape,
1010
+ shape=infer_variable_shape(name),
968
1011
  dims=self._fit_dims.get(name, None),
969
1012
  )
970
1013
 
pymc_extras/version.txt CHANGED
@@ -1 +1 @@
1
- 0.2.1
1
+ 0.2.3
@@ -1,11 +1,11 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.2
2
2
  Name: pymc-extras
3
- Version: 0.2.1
3
+ Version: 0.2.3
4
4
  Summary: A home for new additions to PyMC, which may include unusual probability distribitions, advanced model fitting algorithms, or any code that may be inappropriate to include in the pymc repository, but may want to be made available to users.
5
5
  Home-page: http://github.com/pymc-devs/pymc-extras
6
6
  Maintainer: PyMC Developers
7
7
  Maintainer-email: pymc.devs@gmail.com
8
- License: Apache License, Version 2.0
8
+ License: Apache-2.0
9
9
  Classifier: Development Status :: 5 - Production/Stable
10
10
  Classifier: Programming Language :: Python
11
11
  Classifier: Programming Language :: Python :: 3
@@ -20,8 +20,9 @@ Classifier: Operating System :: OS Independent
20
20
  Requires-Python: >=3.10
21
21
  Description-Content-Type: text/markdown
22
22
  License-File: LICENSE
23
- Requires-Dist: pymc>=5.19.1
23
+ Requires-Dist: pymc>=5.20
24
24
  Requires-Dist: scikit-learn
25
+ Requires-Dist: better-optimize
25
26
  Provides-Extra: dask-histogram
26
27
  Requires-Dist: dask[complete]; extra == "dask-histogram"
27
28
  Requires-Dist: xhistogram; extra == "dask-histogram"
@@ -34,6 +35,17 @@ Provides-Extra: dev
34
35
  Requires-Dist: dask[all]; extra == "dev"
35
36
  Requires-Dist: blackjax; extra == "dev"
36
37
  Requires-Dist: statsmodels; extra == "dev"
38
+ Dynamic: classifier
39
+ Dynamic: description
40
+ Dynamic: description-content-type
41
+ Dynamic: home-page
42
+ Dynamic: license
43
+ Dynamic: maintainer
44
+ Dynamic: maintainer-email
45
+ Dynamic: provides-extra
46
+ Dynamic: requires-dist
47
+ Dynamic: requires-python
48
+ Dynamic: summary
37
49
 
38
50
  # Welcome to `pymc-extras`
39
51
  <a href="https://gitpod.io/#https://github.com/pymc-devs/pymc-extras">
@@ -1,9 +1,9 @@
1
- pymc_extras/__init__.py,sha256=URh185f6b1xp2Taj2W2NJuW_hErKufBcLeQ0WDCyaNk,1160
1
+ pymc_extras/__init__.py,sha256=IFIEZdPX_Ugq57Bu7jlyrJLpKng-P0FBAAAzl2pFXLE,1266
2
2
  pymc_extras/linearmodel.py,sha256=6eitl15Ec15mSZu7zoHZ7Wwy4U1DPwqfAgwEt6ILeIc,3920
3
3
  pymc_extras/model_builder.py,sha256=sAw77fxdiy046BvDPjocuMlbJ0Efj-CDAGtmcwYmoG0,26361
4
4
  pymc_extras/printing.py,sha256=G8mj9dRd6i0PcsbcEWZm56ek6V8mmil78RI4MUhywBs,6506
5
5
  pymc_extras/version.py,sha256=VxPGCBzhtSegu-Jp5cjzn0n4DGU0wuPUh-KyZKB6uPM,240
6
- pymc_extras/version.txt,sha256=cQFcl5zLD8igvnygroMEarBFzcLI-qCfsvD35ED5tKY,6
6
+ pymc_extras/version.txt,sha256=OrlMBNJJhvOvKIuhzaLAu928Wonf8JcYKAX1RXjh6nU,6
7
7
  pymc_extras/distributions/__init__.py,sha256=gTX7tvX8NcgP7V72URV7GeqF1aAEjGVbuW8LMxhXceY,1295
8
8
  pymc_extras/distributions/continuous.py,sha256=z-nvQgGncYISdRY8cWsa-56V0bQGq70jYwU-i8VZ0Uk,11253
9
9
  pymc_extras/distributions/discrete.py,sha256=vrARNuiQAEXrs7yQgImV1PO8AV1uyEC_LBhr6F9IcOg,13032
@@ -14,27 +14,30 @@ pymc_extras/distributions/multivariate/r2d2m2cp.py,sha256=bUj9bB-hQi6CpaJfvJjgNP
14
14
  pymc_extras/gp/__init__.py,sha256=sFHw2y3lEl5tG_FDQHZUonQ_k0DF1JRf0Rp8dpHmge0,745
15
15
  pymc_extras/gp/latent_approx.py,sha256=cDEMM6H1BL2qyKg7BZU-ISrKn2HJe7hDaM4Y8GgQDf4,6682
16
16
  pymc_extras/inference/__init__.py,sha256=5cXpaQQnW0mJJ3x8wSxmYu63l--Xab5D_gMtjA6Q3uU,666
17
- pymc_extras/inference/find_map.py,sha256=T0uO8prUI5aBNuR1AN8fbA4cHmLRQLXznwJrfxfe7CA,15723
18
- pymc_extras/inference/fit.py,sha256=NFEpUaYLJAmDRP1WIPymgnEcXUofkoURYHbEdiTivzQ,1313
19
- pymc_extras/inference/laplace.py,sha256=OglOvnxfHLe0VXxBC1-ddVzADR9zgGxUPScM6P6FYo8,21163
20
- pymc_extras/inference/pathfinder.py,sha256=cmzR2OZCfkdTipT-8pmLuF-MHmLzxotsYlezOWBUM4U,4171
17
+ pymc_extras/inference/find_map.py,sha256=vl5l0ei48PnX-uTuHVTr-9QpCEHc8xog-KK6sOnJ8LU,16513
18
+ pymc_extras/inference/fit.py,sha256=S9R48dh74s6K0MC9Iys4NAwVjP6rVRfx6SF-kPiR70E,1165
19
+ pymc_extras/inference/laplace.py,sha256=uOZGp8ssQuhvCHV_Y_v3icsr4rhcYgr_qlr9dS7pcSM,21761
20
+ pymc_extras/inference/pathfinder/__init__.py,sha256=FhAYrCWNx_dCrynEdjg2CZ9tIinvcVLBm67pNx_Y3kA,101
21
+ pymc_extras/inference/pathfinder/importance_sampling.py,sha256=VvmuaE3aw_Mo3tMwswfF0rqe19mnhOCpzIScaJzjA1Y,6159
22
+ pymc_extras/inference/pathfinder/lbfgs.py,sha256=P0UIOVtspdLzDU6alK-y91qzVAzXjYAXPuGmZ1nRqMo,5715
23
+ pymc_extras/inference/pathfinder/pathfinder.py,sha256=fomZ5voVcWxvhWpeIZV7IHGIJCasT1g0ivC4dC3-0GM,63694
21
24
  pymc_extras/inference/smc/__init__.py,sha256=wyaT4NJl1YsSQRLiDy-i0Jq3CbJZ2BQd4nnCk-dIngY,603
22
25
  pymc_extras/inference/smc/sampling.py,sha256=AYwmKqGoV6pBtKnh9SUbBKbN7VcoFgb3MmNWV7SivMA,15365
23
26
  pymc_extras/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
24
- pymc_extras/model/model_api.py,sha256=_r6rYQG1tt9Z95QU-jVyHqZ1rs-u7sFMO5HJ5unDV5A,1750
27
+ pymc_extras/model/model_api.py,sha256=UHMfQXxWBujeSiUySU0fDUC5Sd_BjT8FoVz3iBxQH_4,2400
25
28
  pymc_extras/model/marginal/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
26
29
  pymc_extras/model/marginal/distributions.py,sha256=iM1yT7_BmivgUSloQPKE2QXGPgjvLqDMY_OTBGsdAWg,15563
27
30
  pymc_extras/model/marginal/graph_analysis.py,sha256=0hWUH_PjfpgneQ3NaT__pWHS1fh50zNbI86kH4Nub0E,15693
28
- pymc_extras/model/marginal/marginal_model.py,sha256=oNsiSWHjOPCTDxNEivEILLP_cOuBarm29Gr2p6hWHIM,23594
31
+ pymc_extras/model/marginal/marginal_model.py,sha256=oIdikaSnefCkyMxmzAe222qGXNucxZpHYk7548fK6iA,23631
29
32
  pymc_extras/model/transforms/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
30
33
  pymc_extras/model/transforms/autoreparam.py,sha256=_NltGWmNqi_X9sHCqAvWcBveLTPxVy11-wENFTcN6kk,12377
31
34
  pymc_extras/preprocessing/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
32
35
  pymc_extras/preprocessing/standard_scaler.py,sha256=Vajp33ma6OkwlU54JYtSS8urHbMJ3CRiRFxZpvFNuus,600
33
36
  pymc_extras/statespace/__init__.py,sha256=0MtZj7yT6jcyERvITnn-nkhyY8fO6Za4_vV53CF6ND0,429
34
37
  pymc_extras/statespace/core/__init__.py,sha256=huHEiXAm8zV2MZyZ8GBHp6q7_fnWqveM7lC6ilpb3iE,309
35
- pymc_extras/statespace/core/compile.py,sha256=1c8Q9D9zeUe7F0z7CH6q1C6ZuLg2_imgk8RoE_KMaFI,1608
38
+ pymc_extras/statespace/core/compile.py,sha256=9FZfE8Bi3VfElxujfOIKRVvmyL9M5R0WfNEqPc5kbVQ,1603
36
39
  pymc_extras/statespace/core/representation.py,sha256=DwNIun6wdeEA20oWBx5M4govyWTf5JI87aGQ_E6Mb4U,18956
37
- pymc_extras/statespace/core/statespace.py,sha256=ZElRm9wJvIGG4Pw-3qiQpBkHXRDqS6pfRyuGrBBcZ2Y,95270
40
+ pymc_extras/statespace/core/statespace.py,sha256=K_WVnWKlI6sR2kgriq9sctQVvwXCeAirm14TthDpmRM,96860
38
41
  pymc_extras/statespace/filters/__init__.py,sha256=N9Q4D0gAq_ZtT-GtrqiX1HkSg6Orv7o1TbrWUtnbTJE,420
39
42
  pymc_extras/statespace/filters/distributions.py,sha256=-9j__vRqL5hKyYFnQr5HKHA5kEFzwiuSccH4mslTOuQ,12900
40
43
  pymc_extras/statespace/filters/kalman_filter.py,sha256=HELC3aK4k8EdWlUAk5_F7y7YkIz-Xi_0j2AwRgAXgcc,31949
@@ -58,12 +61,12 @@ pymc_extras/utils/prior.py,sha256=QlWVr7uKIK9VncBw7Fz3YgaASKGDfqpORZHc-vz_9gQ,68
58
61
  pymc_extras/utils/spline.py,sha256=qGq0gcoMG5dpdazKFzG0RXkkCWP8ADPPXN-653-oFn4,4820
59
62
  tests/__init__.py,sha256=-ree9OWVCyTeXLR944OWjrQX2os15HXrRNkhJ7QdRjc,603
60
63
  tests/test_blackjax_smc.py,sha256=jcNgcMBxaKyPg9UvHnWQtwoL79LXlSpZfALe3RGEZnQ,7233
61
- tests/test_find_map.py,sha256=iAphukWw7cBiJXX5KI-veATeinqbgFSn2IEYfvYPeYU,3069
64
+ tests/test_find_map.py,sha256=B8ThnXNyfTQeem24QaLoTitFrsxKoq2VQINUdOwzna0,3379
62
65
  tests/test_histogram_approximation.py,sha256=w-xb2Rr0Qft6sm6F3BTmXXnpuqyefC1SUL6YxzqA5X4,4674
63
- tests/test_laplace.py,sha256=5ioEyP6AzmMszrtQRz0KWTsCCU35SEhSOdBcYfYzptE,8228
66
+ tests/test_laplace.py,sha256=u4o-0y4v1emaTMYr_rOyL_EKY_bQIz0DUXFuwuDbfNg,9314
64
67
  tests/test_linearmodel.py,sha256=iB8ApNqIX9_nUHoo-Tm51xuPdrva5t4VLLut6qXB5Ao,6906
65
68
  tests/test_model_builder.py,sha256=QiINEihBR9rx8xM4Nqlg4urZKoyo58aTKDtxl9SJF1s,11249
66
- tests/test_pathfinder.py,sha256=FBm0ge6rje5jz9_10h_247E70aKCpkbu1jmzrR7Ar8A,1726
69
+ tests/test_pathfinder.py,sha256=GnSbZJ9QuFW9UVbkWaVgMVqQZTCttOyz_rSflxhQ-EA,4955
67
70
  tests/test_pivoted_cholesky.py,sha256=PuMdMSCzO4KdQWpUF4SEBeuH_qsINCIH8TYtmmJ1NKo,692
68
71
  tests/test_printing.py,sha256=HnvwwjrjBuxXFAJdyU0K_lvKGLgh4nzHAnhsIUpenbY,5211
69
72
  tests/test_prior_from_trace.py,sha256=HOzR3l98pl7TEJquo_kSugED4wBTgHo4-8lgnpmacs8,5516
@@ -75,7 +78,7 @@ tests/distributions/test_discrete.py,sha256=CjjaUpppsvQ6zLzV15ZsbwNOKrDmEdz4VWcl
75
78
  tests/distributions/test_discrete_markov_chain.py,sha256=8RCHZXSB8IWjniuKaGGlM_iTWGmdrcOqginxmrAeEJg,9212
76
79
  tests/distributions/test_multivariate.py,sha256=LBvBuoT_3rzi8rR38b8L441Y-9Ff0cIXeRBKiEn6kjs,10452
77
80
  tests/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
78
- tests/model/test_model_api.py,sha256=SiOMA1NpyQKJ7stYI1ms8ksDPU81lVo8wS8hbqiik-U,776
81
+ tests/model/test_model_api.py,sha256=FJvMTmexovRELZOUcUyk-6Vwk9qSiH7hIFoiArgl5mk,1040
79
82
  tests/model/marginal/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
80
83
  tests/model/marginal/test_distributions.py,sha256=p5f73g4ogxYkdZaBndZV_1ra8TCppXiRlUpaaTwEe-M,5195
81
84
  tests/model/marginal/test_graph_analysis.py,sha256=raoj41NusMOj1zzPCrxrlQODqX6Ey8Ft_o32pNTe5qg,6712
@@ -88,15 +91,15 @@ tests/statespace/test_coord_assignment.py,sha256=2GBm46-0eI4QNh4bvp3D7az58stcA5Z
88
91
  tests/statespace/test_distributions.py,sha256=WQ_ROyd-PL3cimXTyEtyVaMEVtS7Hue2Z0lN7UnGDyo,9122
89
92
  tests/statespace/test_kalman_filter.py,sha256=s2n62FzXl9elU_uqaMNaEaexUfq3SXe3_YvQ2lM6hiQ,11600
90
93
  tests/statespace/test_representation.py,sha256=1KAJY4ZaVhb1WdAJLx2UYSXuVYsMNWX98gEDF7P0B4s,6210
91
- tests/statespace/test_statespace.py,sha256=8ZLLQaxlP5UEJnIMYyIzzAODCxMxs6E5I1hLu2HCdqo,28866
94
+ tests/statespace/test_statespace.py,sha256=JoupFFpG8PmpB_NFV471IuTmyXhEd6_vOISwVCRrBBM,30570
92
95
  tests/statespace/test_statespace_JAX.py,sha256=hZOc6xxYdVeATPCKmcHMLOVcuvdzGRzgQQ4RrDenwk8,5279
93
96
  tests/statespace/test_structural.py,sha256=HD8OaGbjuH4y3xv_uG-R1xLZpPpcb4-3dbcTeb_imLY,29306
94
97
  tests/statespace/utilities/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
95
98
  tests/statespace/utilities/shared_fixtures.py,sha256=SNw8Bvj1Yw11TxAW6n20Bq0B8oaYtVTiFFEVNH_wnp4,164
96
99
  tests/statespace/utilities/statsmodel_local_level.py,sha256=SQAzaYaSDwiVhUQ1iWjt4MgfAd54RuzVtnslIs3xdS8,1225
97
100
  tests/statespace/utilities/test_helpers.py,sha256=oH24a6Q45NFFFI3Kx9mhKbxsCvo9ErCorKFoTjDB3-4,9159
98
- pymc_extras-0.2.1.dist-info/LICENSE,sha256=WjiLhUKEysJvy5e9jk6WwFv9tmAPtnov1uJ6gcH1kIs,11720
99
- pymc_extras-0.2.1.dist-info/METADATA,sha256=pT1MOjFxsX6lc0q_D3J-2jNW6UaRkCq_0kJemgG4DGU,4894
100
- pymc_extras-0.2.1.dist-info/WHEEL,sha256=A3WOREP4zgxI0fKrHUG8DC8013e3dK3n7a6HDbcEIwE,91
101
- pymc_extras-0.2.1.dist-info/top_level.txt,sha256=D6RkgBiXiZCel0nvsYg_zYEoT_VuwocyIY98EMaulj0,18
102
- pymc_extras-0.2.1.dist-info/RECORD,,
101
+ pymc_extras-0.2.3.dist-info/LICENSE,sha256=WjiLhUKEysJvy5e9jk6WwFv9tmAPtnov1uJ6gcH1kIs,11720
102
+ pymc_extras-0.2.3.dist-info/METADATA,sha256=ZTiMM7hvVRF3O_liRu4Aea_EuxJc4vHfTD2CbRRQrcU,5152
103
+ pymc_extras-0.2.3.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
104
+ pymc_extras-0.2.3.dist-info/top_level.txt,sha256=D6RkgBiXiZCel0nvsYg_zYEoT_VuwocyIY98EMaulj0,18
105
+ pymc_extras-0.2.3.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.7.0)
2
+ Generator: setuptools (75.8.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -25,5 +25,14 @@ def test_logp():
25
25
 
26
26
  mw2 = model_wrapped2(coords=coords)
27
27
 
28
+ @pmx.as_model()
29
+ def model_wrapped3(mu):
30
+ pm.Normal("x", mu, 1.0, dims="obs")
31
+
32
+ mw3 = model_wrapped3(0.0, coords=coords)
33
+ mw4 = model_wrapped3(np.array([np.nan]), coords=coords)
34
+
28
35
  np.testing.assert_equal(model.point_logps(), mw.point_logps())
29
36
  np.testing.assert_equal(mw.point_logps(), mw2.point_logps())
37
+ assert mw3["mu"] in mw3.data_vars
38
+ assert "mu" not in mw4
@@ -1,3 +1,4 @@
1
+ from collections.abc import Sequence
1
2
  from functools import partial
2
3
 
3
4
  import numpy as np
@@ -349,6 +350,59 @@ def test_sampling_methods(group, kind, ss_mod, idata, rng):
349
350
  assert not np.any(np.isnan(test_idata[f"{group}_{output}"].values))
350
351
 
351
352
 
353
+ @pytest.mark.filterwarnings("ignore:Provided data contains missing values")
354
+ def test_sample_conditional_with_time_varying():
355
+ class TVCovariance(PyMCStateSpace):
356
+ def __init__(self):
357
+ super().__init__(k_states=1, k_endog=1, k_posdef=1)
358
+
359
+ def make_symbolic_graph(self) -> None:
360
+ self.ssm["transition", 0, 0] = 1.0
361
+
362
+ self.ssm["design", 0, 0] = 1.0
363
+
364
+ sigma_cov = self.make_and_register_variable("sigma_cov", (None,))
365
+ self.ssm["state_cov"] = sigma_cov[:, None, None] ** 2
366
+
367
+ @property
368
+ def param_names(self) -> list[str]:
369
+ return ["sigma_cov"]
370
+
371
+ @property
372
+ def coords(self) -> dict[str, Sequence[str]]:
373
+ return make_default_coords(self)
374
+
375
+ @property
376
+ def state_names(self) -> list[str]:
377
+ return ["level"]
378
+
379
+ @property
380
+ def observed_states(self) -> list[str]:
381
+ return ["level"]
382
+
383
+ @property
384
+ def shock_names(self) -> list[str]:
385
+ return ["level"]
386
+
387
+ ss_mod = TVCovariance()
388
+ empty_data = pd.DataFrame(
389
+ np.nan, index=pd.date_range("2020-01-01", periods=100, freq="D"), columns=["data"]
390
+ )
391
+
392
+ coords = ss_mod.coords
393
+ coords["time"] = empty_data.index
394
+ with pm.Model(coords=coords) as mod:
395
+ log_sigma_cov = pm.Normal("log_sigma_cov", mu=0, sigma=0.1, dims=["time"])
396
+ pm.Deterministic("sigma_cov", pm.math.exp(log_sigma_cov.cumsum()), dims=["time"])
397
+
398
+ ss_mod.build_statespace_graph(data=empty_data)
399
+
400
+ prior = pm.sample_prior_predictive(10)
401
+
402
+ ss_mod.sample_unconditional_prior(prior)
403
+ ss_mod.sample_conditional_prior(prior)
404
+
405
+
352
406
  def _make_time_idx(mod, use_datetime_index=True):
353
407
  if use_datetime_index:
354
408
  mod._fit_coords["time"] = nile.index
tests/test_find_map.py CHANGED
@@ -54,24 +54,28 @@ def test_jax_functions_from_graph(gradient_backend: GradientBackend):
54
54
 
55
55
 
56
56
  @pytest.mark.parametrize(
57
- "method, use_grad, use_hess",
57
+ "method, use_grad, use_hess, use_hessp",
58
58
  [
59
- ("nelder-mead", False, False),
60
- ("powell", False, False),
61
- ("CG", True, False),
62
- ("BFGS", True, False),
63
- ("L-BFGS-B", True, False),
64
- ("TNC", True, False),
65
- ("SLSQP", True, False),
66
- ("dogleg", True, True),
67
- ("trust-ncg", True, True),
68
- ("trust-exact", True, True),
69
- ("trust-krylov", True, True),
70
- ("trust-constr", True, True),
59
+ ("nelder-mead", False, False, False),
60
+ ("powell", False, False, False),
61
+ ("CG", True, False, False),
62
+ ("BFGS", True, False, False),
63
+ ("L-BFGS-B", True, False, False),
64
+ ("TNC", True, False, False),
65
+ ("SLSQP", True, False, False),
66
+ ("dogleg", True, True, False),
67
+ ("Newton-CG", True, True, False),
68
+ ("Newton-CG", True, False, True),
69
+ ("trust-ncg", True, True, False),
70
+ ("trust-ncg", True, False, True),
71
+ ("trust-exact", True, True, False),
72
+ ("trust-krylov", True, True, False),
73
+ ("trust-krylov", True, False, True),
74
+ ("trust-constr", True, True, False),
71
75
  ],
72
76
  )
73
77
  @pytest.mark.parametrize("gradient_backend", ["jax", "pytensor"], ids=str)
74
- def test_JAX_map(method, use_grad, use_hess, gradient_backend: GradientBackend, rng):
78
+ def test_JAX_map(method, use_grad, use_hess, use_hessp, gradient_backend: GradientBackend, rng):
75
79
  extra_kwargs = {}
76
80
  if method == "dogleg":
77
81
  # HACK -- dogleg requires that the hessian of the objective function is PSD, so we have to pick a point
@@ -88,6 +92,7 @@ def test_JAX_map(method, use_grad, use_hess, gradient_backend: GradientBackend,
88
92
  **extra_kwargs,
89
93
  use_grad=use_grad,
90
94
  use_hess=use_hess,
95
+ use_hessp=use_hessp,
91
96
  progressbar=False,
92
97
  gradient_backend=gradient_backend,
93
98
  compile_kwargs={"mode": "JAX"},
tests/test_laplace.py CHANGED
@@ -19,10 +19,10 @@ import pytest
19
19
 
20
20
  import pymc_extras as pmx
21
21
 
22
- from pymc_extras.inference.find_map import find_MAP
22
+ from pymc_extras.inference.find_map import GradientBackend, find_MAP
23
23
  from pymc_extras.inference.laplace import (
24
24
  fit_laplace,
25
- fit_mvn_to_MAP,
25
+ fit_mvn_at_MAP,
26
26
  sample_laplace_posterior,
27
27
  )
28
28
 
@@ -37,7 +37,11 @@ def rng():
37
37
  "ignore:hessian will stop negating the output in a future version of PyMC.\n"
38
38
  + "To suppress this warning set `negate_output=False`:FutureWarning",
39
39
  )
40
- def test_laplace():
40
+ @pytest.mark.parametrize(
41
+ "mode, gradient_backend",
42
+ [(None, "pytensor"), ("NUMBA", "pytensor"), ("JAX", "jax"), ("JAX", "pytensor")],
43
+ )
44
+ def test_laplace(mode, gradient_backend: GradientBackend):
41
45
  # Example originates from Bayesian Data Analyses, 3rd Edition
42
46
  # By Andrew Gelman, John Carlin, Hal Stern, David Dunson,
43
47
  # Aki Vehtari, and Donald Rubin.
@@ -55,7 +59,13 @@ def test_laplace():
55
59
  vars = [mu, logsigma]
56
60
 
57
61
  idata = pmx.fit(
58
- method="laplace", optimize_method="trust-ncg", draws=draws, random_seed=173300, chains=1
62
+ method="laplace",
63
+ optimize_method="trust-ncg",
64
+ draws=draws,
65
+ random_seed=173300,
66
+ chains=1,
67
+ compile_kwargs={"mode": mode},
68
+ gradient_backend=gradient_backend,
59
69
  )
60
70
 
61
71
  assert idata.posterior["mu"].shape == (1, draws)
@@ -71,7 +81,11 @@ def test_laplace():
71
81
  np.testing.assert_allclose(idata.fit["covariance_matrix"].values, bda_cov, atol=1e-4)
72
82
 
73
83
 
74
- def test_laplace_only_fit():
84
+ @pytest.mark.parametrize(
85
+ "mode, gradient_backend",
86
+ [(None, "pytensor"), ("NUMBA", "pytensor"), ("JAX", "jax"), ("JAX", "pytensor")],
87
+ )
88
+ def test_laplace_only_fit(mode, gradient_backend: GradientBackend):
75
89
  # Example originates from Bayesian Data Analyses, 3rd Edition
76
90
  # By Andrew Gelman, John Carlin, Hal Stern, David Dunson,
77
91
  # Aki Vehtari, and Donald Rubin.
@@ -90,8 +104,8 @@ def test_laplace_only_fit():
90
104
  method="laplace",
91
105
  optimize_method="BFGS",
92
106
  progressbar=True,
93
- gradient_backend="jax",
94
- compile_kwargs={"mode": "JAX"},
107
+ gradient_backend=gradient_backend,
108
+ compile_kwargs={"mode": mode},
95
109
  optimizer_kwargs=dict(maxiter=100_000, gtol=1e-100),
96
110
  random_seed=173300,
97
111
  )
@@ -111,8 +125,11 @@ def test_laplace_only_fit():
111
125
  [True, False],
112
126
  ids=["transformed", "untransformed"],
113
127
  )
114
- @pytest.mark.parametrize("mode", ["JAX", None], ids=["jax", "pytensor"])
115
- def test_fit_laplace_coords(rng, transform_samples, mode):
128
+ @pytest.mark.parametrize(
129
+ "mode, gradient_backend",
130
+ [(None, "pytensor"), ("NUMBA", "pytensor"), ("JAX", "jax"), ("JAX", "pytensor")],
131
+ )
132
+ def test_fit_laplace_coords(rng, transform_samples, mode, gradient_backend: GradientBackend):
116
133
  coords = {"city": ["A", "B", "C"], "obs_idx": np.arange(100)}
117
134
  with pm.Model(coords=coords) as model:
118
135
  mu = pm.Normal("mu", mu=3, sigma=0.5, dims=["city"])
@@ -131,13 +148,13 @@ def test_fit_laplace_coords(rng, transform_samples, mode):
131
148
  use_hessp=True,
132
149
  progressbar=False,
133
150
  compile_kwargs=dict(mode=mode),
134
- gradient_backend="jax" if mode == "JAX" else "pytensor",
151
+ gradient_backend=gradient_backend,
135
152
  )
136
153
 
137
154
  for value in optimized_point.values():
138
155
  assert value.shape == (3,)
139
156
 
140
- mu, H_inv = fit_mvn_to_MAP(
157
+ mu, H_inv = fit_mvn_at_MAP(
141
158
  optimized_point=optimized_point,
142
159
  model=model,
143
160
  transform_samples=transform_samples,
@@ -163,7 +180,11 @@ def test_fit_laplace_coords(rng, transform_samples, mode):
163
180
  ]
164
181
 
165
182
 
166
- def test_fit_laplace_ragged_coords(rng):
183
+ @pytest.mark.parametrize(
184
+ "mode, gradient_backend",
185
+ [(None, "pytensor"), ("NUMBA", "pytensor"), ("JAX", "jax"), ("JAX", "pytensor")],
186
+ )
187
+ def test_fit_laplace_ragged_coords(mode, gradient_backend: GradientBackend, rng):
167
188
  coords = {"city": ["A", "B", "C"], "feature": [0, 1], "obs_idx": np.arange(100)}
168
189
  with pm.Model(coords=coords) as ragged_dim_model:
169
190
  X = pm.Data("X", np.ones((100, 2)), dims=["obs_idx", "feature"])
@@ -188,8 +209,8 @@ def test_fit_laplace_ragged_coords(rng):
188
209
  progressbar=False,
189
210
  use_grad=True,
190
211
  use_hessp=True,
191
- gradient_backend="jax",
192
- compile_kwargs={"mode": "JAX"},
212
+ gradient_backend=gradient_backend,
213
+ compile_kwargs={"mode": mode},
193
214
  )
194
215
 
195
216
  assert idata["posterior"].beta.shape[-2:] == (3, 2)
@@ -206,7 +227,11 @@ def test_fit_laplace_ragged_coords(rng):
206
227
  [True, False],
207
228
  ids=["transformed", "untransformed"],
208
229
  )
209
- def test_fit_laplace(fit_in_unconstrained_space):
230
+ @pytest.mark.parametrize(
231
+ "mode, gradient_backend",
232
+ [(None, "pytensor"), ("NUMBA", "pytensor"), ("JAX", "jax"), ("JAX", "pytensor")],
233
+ )
234
+ def test_fit_laplace(fit_in_unconstrained_space, mode, gradient_backend: GradientBackend):
210
235
  with pm.Model() as simp_model:
211
236
  mu = pm.Normal("mu", mu=3, sigma=0.5)
212
237
  sigma = pm.Exponential("sigma", 1)
@@ -223,6 +248,8 @@ def test_fit_laplace(fit_in_unconstrained_space):
223
248
  use_hessp=True,
224
249
  fit_in_unconstrained_space=fit_in_unconstrained_space,
225
250
  optimizer_kwargs=dict(maxiter=100_000, tol=1e-100),
251
+ compile_kwargs={"mode": mode},
252
+ gradient_backend=gradient_backend,
226
253
  )
227
254
 
228
255
  np.testing.assert_allclose(np.mean(idata.posterior.mu, axis=1), np.full((2,), 3), atol=0.1)