pymc-extras 0.2.1__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pymc_extras/model/model_api.py CHANGED
@@ -1,6 +1,9 @@
 from functools import wraps
+from inspect import signature
 
-from pymc import Model
+import pytensor.tensor as pt
+
+from pymc import Data, Model
 
 
 def as_model(*model_args, **model_kwargs):
@@ -9,6 +12,8 @@ def as_model(*model_args, **model_kwargs):
     This removes all need to think about context managers and lets you separate creating a generative model from using the model.
     Additionally, a coords argument is added to the function so coords can be changed during function invocation
 
+    All parameters are wrapped with a `pm.Data` object if the underlying type of the data supports it.
+
     Adapted from `Rob Zinkov's blog post <https://www.zinkov.com/posts/2023-alternative-frontends-pymc/>`_ and inspired by the `sampled <https://github.com/colcarroll/sampled>`_ decorator for PyMC3.
 
     Examples
@@ -47,8 +52,19 @@ def as_model(*model_args, **model_kwargs):
        @wraps(f)
        def make_model(*args, **kwargs):
            coords = model_kwargs.pop("coords", {}) | kwargs.pop("coords", {})
+           sig = signature(f)
+           ba = sig.bind(*args, **kwargs)
+           ba.apply_defaults()
+
            with Model(*model_args, coords=coords, **model_kwargs) as m:
-               f(*args, **kwargs)
+               for name, v in ba.arguments.items():
+                   # Only wrap pm.Data around values pytensor can process
+                   try:
+                       _ = pt.as_tensor_variable(v)
+                       ba.arguments[name] = Data(name, v)
+                   except (NotImplementedError, TypeError, ValueError):
+                       pass
+               f(*ba.args, **ba.kwargs)
            return m
 
        return make_model
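
The practical effect of this change: arguments passed to an `as_model`-decorated function are now registered on the returned model as `pm.Data` containers whenever pytensor can convert them. A minimal sketch of the new behaviour (not part of the diff; the model and argument names are hypothetical):

```python
import numpy as np
import pymc as pm
import pymc_extras as pmx

# Hypothetical model: `x` and `sigma` are plain function arguments, but because
# pytensor can convert them, as_model wraps each in a pm.Data container named
# after the argument.
@pmx.as_model(coords={"obs": ["a", "b", "c"]})
def linear_model(x, sigma):
    pm.Normal("y", mu=2.0 * x, sigma=sigma, dims="obs")

m = linear_model(np.array([1.0, 2.0, 3.0]), 1.0)
# Under the wrapping behaviour sketched above, both arguments should appear
# as data variables on the model.
print([v.name for v in m.data_vars])  # expected: ['x', 'sigma']
```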
pymc_extras/statespace/core/statespace.py CHANGED
@@ -15,6 +15,9 @@ from pymc.model.transform.optimization import freeze_dims_and_data
 from pymc.util import RandomState
 from pytensor import Variable, graph_replace
 from pytensor.compile import get_mode
+from rich.box import SIMPLE_HEAD
+from rich.console import Console
+from rich.table import Table
 
 from pymc_extras.statespace.core.representation import PytensorRepresentation
 from pymc_extras.statespace.filters import (
@@ -254,53 +257,72 @@ class PyMCStateSpace:
         self.kalman_smoother = KalmanSmoother()
         self.make_symbolic_graph()
 
-        if verbose:
-            # These are split into separate try-except blocks, because it will be quite rare of models to implement
-            # _print_data_requirements, but we still want to print the prior requirements.
-            try:
-                self._print_prior_requirements()
-            except NotImplementedError:
-                pass
-            try:
-                self._print_data_requirements()
-            except NotImplementedError:
-                pass
-
-    def _print_prior_requirements(self) -> None:
-        """
-        Prints a short report to the terminal about the priors needed for the model, including their names,
+        self.requirement_table = None
+        self._populate_prior_requirements()
+        self._populate_data_requirements()
+
+        if verbose and self.requirement_table:
+            console = Console()
+            console.print(self.requirement_table)
+
+    def _populate_prior_requirements(self) -> None:
+        """
+        Add requirements about priors needed for the model to a rich table, including their names,
         shapes, named dimensions, and any parameter constraints.
         """
-        out = ""
-        for param, info in self.param_info.items():
-            out += f'\t{param} -- shape: {info["shape"]}, constraints: {info["constraints"]}, dims: {info["dims"]}\n'
-        out = out.rstrip()
+        # Check that the param_info class is implemented, and also that it's a dictionary. We can't proceed if either
+        # is not true.
+        try:
+            if not isinstance(self.param_info, dict):
+                return
+        except NotImplementedError:
+            return
 
-        _log.info(
-            "The following parameters should be assigned priors inside a PyMC "
-            f"model block: \n"
-            f"{out}"
-        )
+        if self.requirement_table is None:
+            self._initialize_requirement_table()
+
+        for param, info in self.param_info.items():
+            self.requirement_table.add_row(
+                param, str(info["shape"]), info["constraints"], str(info["dims"])
+            )
 
-    def _print_data_requirements(self) -> None:
+    def _populate_data_requirements(self) -> None:
        """
-        Prints a short report to the terminal about the data needed for the model, including their names, shapes,
-        and named dimensions.
+        Add requirements about the data needed for the model, including their names, shapes, and named dimensions.
        """
-        if not self.data_info:
+        try:
+            if not isinstance(self.data_info, dict):
+                return
+        except NotImplementedError:
            return
 
-        out = ""
+        if self.requirement_table is None:
+            self._initialize_requirement_table()
+        else:
+            self.requirement_table.add_section()
+
        for data, info in self.data_info.items():
-            out += f'\t{data} -- shape: {info["shape"]}, dims: {info["dims"]}\n'
-        out = out.rstrip()
+            self.requirement_table.add_row(data, str(info["shape"]), "pm.Data", str(info["dims"]))
+
+    def _initialize_requirement_table(self) -> None:
+        self.requirement_table = Table(
+            show_header=True,
+            show_edge=True,
+            box=SIMPLE_HEAD,
+            highlight=True,
+        )
 
-        _log.info(
-            "The following Data variables should be assigned to the model inside a PyMC "
-            f"model block: \n"
-            f"{out}"
+        self.requirement_table.title = "Model Requirements"
+        self.requirement_table.caption = (
+            "These parameters should be assigned priors inside a PyMC model block before "
+            "calling the build_statespace_graph method."
        )
 
+        self.requirement_table.add_column("Variable", justify="left")
+        self.requirement_table.add_column("Shape", justify="left")
+        self.requirement_table.add_column("Constraints", justify="left")
+        self.requirement_table.add_column("Dimensions", justify="right")
+
    def _unpack_statespace_with_placeholders(
        self,
    ) -> tuple[
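
For reference, here is a small standalone sketch of the table this refactor produces: same columns, title, caption, and section separator, rendered with rich. The rows are hypothetical placeholders, not values emitted by the diffed code:

```python
from rich.box import SIMPLE_HEAD
from rich.console import Console
from rich.table import Table

# Build a requirements table with the same structure as the statespace refactor.
table = Table(show_header=True, show_edge=True, box=SIMPLE_HEAD, highlight=True)
table.title = "Model Requirements"
table.caption = (
    "These parameters should be assigned priors inside a PyMC model block before "
    "calling the build_statespace_graph method."
)
for col, justify in [
    ("Variable", "left"),
    ("Shape", "left"),
    ("Constraints", "left"),
    ("Dimensions", "right"),
]:
    table.add_column(col, justify=justify)

# Hypothetical prior requirement row, then a separate section for data requirements.
table.add_row("sigma_obs", "(1,)", "Positive", "('observed_state',)")
table.add_section()
table.add_row("data", "(None, 1)", "pm.Data", "('time', 'observed_state')")

Console().print(table)
```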
@@ -961,10 +983,31 @@ class PyMCStateSpace:
         list[pm.Flat]
             A list of pm.Flat variables representing all parameters estimated by the model.
         """
+
+        def infer_variable_shape(name):
+            shape = self._name_to_variable[name].type.shape
+            if not any(dim is None for dim in shape):
+                return shape
+
+            dim_names = self._fit_dims.get(name, None)
+            if dim_names is None:
+                raise ValueError(
+                    f"Could not infer shape for {name}, because it was not given coords during model"
+                    f"fitting"
+                )
+
+            shape_from_coords = tuple([len(self._fit_coords[dim]) for dim in dim_names])
+            return tuple(
+                [
+                    shape[i] if shape[i] is not None else shape_from_coords[i]
+                    for i in range(len(shape))
+                ]
+            )
+
        for name in self.param_names:
            pm.Flat(
                name,
-                shape=self._name_to_variable[name].type.shape,
+                shape=infer_variable_shape(name),
                dims=self._fit_dims.get(name, None),
            )
 
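
The helper above fills in any static dimension that is still `None` using the lengths of the coords recorded at fit time. A standalone sketch of that rule, with hypothetical variable and coordinate names:

```python
# Hypothetical coords and dims recorded during model fitting.
fit_coords = {"time": range(100), "state": ["level", "trend"]}
fit_dims = {"x0": ("state",), "P0": ("state", "state")}

def infer_shape(name, static_shape):
    # Dimensions that are already known statically are kept; None entries are
    # looked up through the variable's dims and the lengths of the fitted coords.
    if not any(dim is None for dim in static_shape):
        return static_shape
    dims = fit_dims[name]
    return tuple(
        s if s is not None else len(fit_coords[d])
        for s, d in zip(static_shape, dims)
    )

print(infer_shape("x0", (None,)))    # (2,)
print(infer_shape("P0", (2, None)))  # (2, 2)
```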
pymc_extras/version.txt CHANGED
@@ -1 +1 @@
-0.2.1
+0.2.2
@@ -1,6 +1,6 @@
-Metadata-Version: 2.1
+Metadata-Version: 2.2
 Name: pymc-extras
-Version: 0.2.1
+Version: 0.2.2
 Summary: A home for new additions to PyMC, which may include unusual probability distribitions, advanced model fitting algorithms, or any code that may be inappropriate to include in the pymc repository, but may want to be made available to users.
 Home-page: http://github.com/pymc-devs/pymc-extras
 Maintainer: PyMC Developers
@@ -34,6 +34,17 @@ Provides-Extra: dev
 Requires-Dist: dask[all]; extra == "dev"
 Requires-Dist: blackjax; extra == "dev"
 Requires-Dist: statsmodels; extra == "dev"
+Dynamic: classifier
+Dynamic: description
+Dynamic: description-content-type
+Dynamic: home-page
+Dynamic: license
+Dynamic: maintainer
+Dynamic: maintainer-email
+Dynamic: provides-extra
+Dynamic: requires-dist
+Dynamic: requires-python
+Dynamic: summary
 
 # Welcome to `pymc-extras`
 <a href="https://gitpod.io/#https://github.com/pymc-devs/pymc-extras">
@@ -3,7 +3,7 @@ pymc_extras/linearmodel.py,sha256=6eitl15Ec15mSZu7zoHZ7Wwy4U1DPwqfAgwEt6ILeIc,39
 pymc_extras/model_builder.py,sha256=sAw77fxdiy046BvDPjocuMlbJ0Efj-CDAGtmcwYmoG0,26361
 pymc_extras/printing.py,sha256=G8mj9dRd6i0PcsbcEWZm56ek6V8mmil78RI4MUhywBs,6506
 pymc_extras/version.py,sha256=VxPGCBzhtSegu-Jp5cjzn0n4DGU0wuPUh-KyZKB6uPM,240
-pymc_extras/version.txt,sha256=cQFcl5zLD8igvnygroMEarBFzcLI-qCfsvD35ED5tKY,6
+pymc_extras/version.txt,sha256=mY9riH7Xpu9E6EIQ0CN7cVvtQupyVPSNxBiv7ApIQQM,6
 pymc_extras/distributions/__init__.py,sha256=gTX7tvX8NcgP7V72URV7GeqF1aAEjGVbuW8LMxhXceY,1295
 pymc_extras/distributions/continuous.py,sha256=z-nvQgGncYISdRY8cWsa-56V0bQGq70jYwU-i8VZ0Uk,11253
 pymc_extras/distributions/discrete.py,sha256=vrARNuiQAEXrs7yQgImV1PO8AV1uyEC_LBhr6F9IcOg,13032
@@ -15,13 +15,16 @@ pymc_extras/gp/__init__.py,sha256=sFHw2y3lEl5tG_FDQHZUonQ_k0DF1JRf0Rp8dpHmge0,74
 pymc_extras/gp/latent_approx.py,sha256=cDEMM6H1BL2qyKg7BZU-ISrKn2HJe7hDaM4Y8GgQDf4,6682
 pymc_extras/inference/__init__.py,sha256=5cXpaQQnW0mJJ3x8wSxmYu63l--Xab5D_gMtjA6Q3uU,666
 pymc_extras/inference/find_map.py,sha256=T0uO8prUI5aBNuR1AN8fbA4cHmLRQLXznwJrfxfe7CA,15723
-pymc_extras/inference/fit.py,sha256=NFEpUaYLJAmDRP1WIPymgnEcXUofkoURYHbEdiTivzQ,1313
+pymc_extras/inference/fit.py,sha256=S9R48dh74s6K0MC9Iys4NAwVjP6rVRfx6SF-kPiR70E,1165
 pymc_extras/inference/laplace.py,sha256=OglOvnxfHLe0VXxBC1-ddVzADR9zgGxUPScM6P6FYo8,21163
-pymc_extras/inference/pathfinder.py,sha256=cmzR2OZCfkdTipT-8pmLuF-MHmLzxotsYlezOWBUM4U,4171
+pymc_extras/inference/pathfinder/__init__.py,sha256=FhAYrCWNx_dCrynEdjg2CZ9tIinvcVLBm67pNx_Y3kA,101
+pymc_extras/inference/pathfinder/importance_sampling.py,sha256=VvmuaE3aw_Mo3tMwswfF0rqe19mnhOCpzIScaJzjA1Y,6159
+pymc_extras/inference/pathfinder/lbfgs.py,sha256=P0UIOVtspdLzDU6alK-y91qzVAzXjYAXPuGmZ1nRqMo,5715
+pymc_extras/inference/pathfinder/pathfinder.py,sha256=fomZ5voVcWxvhWpeIZV7IHGIJCasT1g0ivC4dC3-0GM,63694
 pymc_extras/inference/smc/__init__.py,sha256=wyaT4NJl1YsSQRLiDy-i0Jq3CbJZ2BQd4nnCk-dIngY,603
 pymc_extras/inference/smc/sampling.py,sha256=AYwmKqGoV6pBtKnh9SUbBKbN7VcoFgb3MmNWV7SivMA,15365
 pymc_extras/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-pymc_extras/model/model_api.py,sha256=_r6rYQG1tt9Z95QU-jVyHqZ1rs-u7sFMO5HJ5unDV5A,1750
+pymc_extras/model/model_api.py,sha256=UHMfQXxWBujeSiUySU0fDUC5Sd_BjT8FoVz3iBxQH_4,2400
 pymc_extras/model/marginal/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 pymc_extras/model/marginal/distributions.py,sha256=iM1yT7_BmivgUSloQPKE2QXGPgjvLqDMY_OTBGsdAWg,15563
 pymc_extras/model/marginal/graph_analysis.py,sha256=0hWUH_PjfpgneQ3NaT__pWHS1fh50zNbI86kH4Nub0E,15693
@@ -34,7 +37,7 @@ pymc_extras/statespace/__init__.py,sha256=0MtZj7yT6jcyERvITnn-nkhyY8fO6Za4_vV53C
 pymc_extras/statespace/core/__init__.py,sha256=huHEiXAm8zV2MZyZ8GBHp6q7_fnWqveM7lC6ilpb3iE,309
 pymc_extras/statespace/core/compile.py,sha256=1c8Q9D9zeUe7F0z7CH6q1C6ZuLg2_imgk8RoE_KMaFI,1608
 pymc_extras/statespace/core/representation.py,sha256=DwNIun6wdeEA20oWBx5M4govyWTf5JI87aGQ_E6Mb4U,18956
-pymc_extras/statespace/core/statespace.py,sha256=ZElRm9wJvIGG4Pw-3qiQpBkHXRDqS6pfRyuGrBBcZ2Y,95270
+pymc_extras/statespace/core/statespace.py,sha256=K_WVnWKlI6sR2kgriq9sctQVvwXCeAirm14TthDpmRM,96860
 pymc_extras/statespace/filters/__init__.py,sha256=N9Q4D0gAq_ZtT-GtrqiX1HkSg6Orv7o1TbrWUtnbTJE,420
 pymc_extras/statespace/filters/distributions.py,sha256=-9j__vRqL5hKyYFnQr5HKHA5kEFzwiuSccH4mslTOuQ,12900
 pymc_extras/statespace/filters/kalman_filter.py,sha256=HELC3aK4k8EdWlUAk5_F7y7YkIz-Xi_0j2AwRgAXgcc,31949
@@ -63,7 +66,7 @@ tests/test_histogram_approximation.py,sha256=w-xb2Rr0Qft6sm6F3BTmXXnpuqyefC1SUL6
 tests/test_laplace.py,sha256=5ioEyP6AzmMszrtQRz0KWTsCCU35SEhSOdBcYfYzptE,8228
 tests/test_linearmodel.py,sha256=iB8ApNqIX9_nUHoo-Tm51xuPdrva5t4VLLut6qXB5Ao,6906
 tests/test_model_builder.py,sha256=QiINEihBR9rx8xM4Nqlg4urZKoyo58aTKDtxl9SJF1s,11249
-tests/test_pathfinder.py,sha256=FBm0ge6rje5jz9_10h_247E70aKCpkbu1jmzrR7Ar8A,1726
+tests/test_pathfinder.py,sha256=GnSbZJ9QuFW9UVbkWaVgMVqQZTCttOyz_rSflxhQ-EA,4955
 tests/test_pivoted_cholesky.py,sha256=PuMdMSCzO4KdQWpUF4SEBeuH_qsINCIH8TYtmmJ1NKo,692
 tests/test_printing.py,sha256=HnvwwjrjBuxXFAJdyU0K_lvKGLgh4nzHAnhsIUpenbY,5211
 tests/test_prior_from_trace.py,sha256=HOzR3l98pl7TEJquo_kSugED4wBTgHo4-8lgnpmacs8,5516
@@ -75,7 +78,7 @@ tests/distributions/test_discrete.py,sha256=CjjaUpppsvQ6zLzV15ZsbwNOKrDmEdz4VWcl
 tests/distributions/test_discrete_markov_chain.py,sha256=8RCHZXSB8IWjniuKaGGlM_iTWGmdrcOqginxmrAeEJg,9212
 tests/distributions/test_multivariate.py,sha256=LBvBuoT_3rzi8rR38b8L441Y-9Ff0cIXeRBKiEn6kjs,10452
 tests/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-tests/model/test_model_api.py,sha256=SiOMA1NpyQKJ7stYI1ms8ksDPU81lVo8wS8hbqiik-U,776
+tests/model/test_model_api.py,sha256=FJvMTmexovRELZOUcUyk-6Vwk9qSiH7hIFoiArgl5mk,1040
 tests/model/marginal/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 tests/model/marginal/test_distributions.py,sha256=p5f73g4ogxYkdZaBndZV_1ra8TCppXiRlUpaaTwEe-M,5195
 tests/model/marginal/test_graph_analysis.py,sha256=raoj41NusMOj1zzPCrxrlQODqX6Ey8Ft_o32pNTe5qg,6712
@@ -88,15 +91,15 @@ tests/statespace/test_coord_assignment.py,sha256=2GBm46-0eI4QNh4bvp3D7az58stcA5Z
 tests/statespace/test_distributions.py,sha256=WQ_ROyd-PL3cimXTyEtyVaMEVtS7Hue2Z0lN7UnGDyo,9122
 tests/statespace/test_kalman_filter.py,sha256=s2n62FzXl9elU_uqaMNaEaexUfq3SXe3_YvQ2lM6hiQ,11600
 tests/statespace/test_representation.py,sha256=1KAJY4ZaVhb1WdAJLx2UYSXuVYsMNWX98gEDF7P0B4s,6210
-tests/statespace/test_statespace.py,sha256=8ZLLQaxlP5UEJnIMYyIzzAODCxMxs6E5I1hLu2HCdqo,28866
+tests/statespace/test_statespace.py,sha256=JoupFFpG8PmpB_NFV471IuTmyXhEd6_vOISwVCRrBBM,30570
 tests/statespace/test_statespace_JAX.py,sha256=hZOc6xxYdVeATPCKmcHMLOVcuvdzGRzgQQ4RrDenwk8,5279
 tests/statespace/test_structural.py,sha256=HD8OaGbjuH4y3xv_uG-R1xLZpPpcb4-3dbcTeb_imLY,29306
 tests/statespace/utilities/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 tests/statespace/utilities/shared_fixtures.py,sha256=SNw8Bvj1Yw11TxAW6n20Bq0B8oaYtVTiFFEVNH_wnp4,164
 tests/statespace/utilities/statsmodel_local_level.py,sha256=SQAzaYaSDwiVhUQ1iWjt4MgfAd54RuzVtnslIs3xdS8,1225
 tests/statespace/utilities/test_helpers.py,sha256=oH24a6Q45NFFFI3Kx9mhKbxsCvo9ErCorKFoTjDB3-4,9159
-pymc_extras-0.2.1.dist-info/LICENSE,sha256=WjiLhUKEysJvy5e9jk6WwFv9tmAPtnov1uJ6gcH1kIs,11720
-pymc_extras-0.2.1.dist-info/METADATA,sha256=pT1MOjFxsX6lc0q_D3J-2jNW6UaRkCq_0kJemgG4DGU,4894
-pymc_extras-0.2.1.dist-info/WHEEL,sha256=A3WOREP4zgxI0fKrHUG8DC8013e3dK3n7a6HDbcEIwE,91
-pymc_extras-0.2.1.dist-info/top_level.txt,sha256=D6RkgBiXiZCel0nvsYg_zYEoT_VuwocyIY98EMaulj0,18
-pymc_extras-0.2.1.dist-info/RECORD,,
+pymc_extras-0.2.2.dist-info/LICENSE,sha256=WjiLhUKEysJvy5e9jk6WwFv9tmAPtnov1uJ6gcH1kIs,11720
+pymc_extras-0.2.2.dist-info/METADATA,sha256=9k60kKKNzr7E24gACpTBNOWj-tRTNOiujSZOFD89G5c,5140
+pymc_extras-0.2.2.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
+pymc_extras-0.2.2.dist-info/top_level.txt,sha256=D6RkgBiXiZCel0nvsYg_zYEoT_VuwocyIY98EMaulj0,18
+pymc_extras-0.2.2.dist-info/RECORD,,
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: setuptools (75.7.0)
+Generator: setuptools (75.8.0)
 Root-Is-Purelib: true
 Tag: py3-none-any
 
tests/model/test_model_api.py CHANGED
@@ -25,5 +25,14 @@ def test_logp():
 
     mw2 = model_wrapped2(coords=coords)
 
+    @pmx.as_model()
+    def model_wrapped3(mu):
+        pm.Normal("x", mu, 1.0, dims="obs")
+
+    mw3 = model_wrapped3(0.0, coords=coords)
+    mw4 = model_wrapped3(np.array([np.nan]), coords=coords)
+
     np.testing.assert_equal(model.point_logps(), mw.point_logps())
     np.testing.assert_equal(mw.point_logps(), mw2.point_logps())
+    assert mw3["mu"] in mw3.data_vars
+    assert "mu" not in mw4
tests/statespace/test_statespace.py CHANGED
@@ -1,3 +1,4 @@
+from collections.abc import Sequence
 from functools import partial
 
 import numpy as np
@@ -349,6 +350,59 @@ def test_sampling_methods(group, kind, ss_mod, idata, rng):
         assert not np.any(np.isnan(test_idata[f"{group}_{output}"].values))
 
 
+@pytest.mark.filterwarnings("ignore:Provided data contains missing values")
+def test_sample_conditional_with_time_varying():
+    class TVCovariance(PyMCStateSpace):
+        def __init__(self):
+            super().__init__(k_states=1, k_endog=1, k_posdef=1)
+
+        def make_symbolic_graph(self) -> None:
+            self.ssm["transition", 0, 0] = 1.0
+
+            self.ssm["design", 0, 0] = 1.0
+
+            sigma_cov = self.make_and_register_variable("sigma_cov", (None,))
+            self.ssm["state_cov"] = sigma_cov[:, None, None] ** 2
+
+        @property
+        def param_names(self) -> list[str]:
+            return ["sigma_cov"]
+
+        @property
+        def coords(self) -> dict[str, Sequence[str]]:
+            return make_default_coords(self)
+
+        @property
+        def state_names(self) -> list[str]:
+            return ["level"]
+
+        @property
+        def observed_states(self) -> list[str]:
+            return ["level"]
+
+        @property
+        def shock_names(self) -> list[str]:
+            return ["level"]
+
+    ss_mod = TVCovariance()
+    empty_data = pd.DataFrame(
+        np.nan, index=pd.date_range("2020-01-01", periods=100, freq="D"), columns=["data"]
+    )
+
+    coords = ss_mod.coords
+    coords["time"] = empty_data.index
+    with pm.Model(coords=coords) as mod:
+        log_sigma_cov = pm.Normal("log_sigma_cov", mu=0, sigma=0.1, dims=["time"])
+        pm.Deterministic("sigma_cov", pm.math.exp(log_sigma_cov.cumsum()), dims=["time"])
+
+        ss_mod.build_statespace_graph(data=empty_data)
+
+        prior = pm.sample_prior_predictive(10)
+
+    ss_mod.sample_unconditional_prior(prior)
+    ss_mod.sample_conditional_prior(prior)
+
+
 def _make_time_idx(mod, use_datetime_index=True):
     if use_datetime_index:
         mod._fit_coords["time"] = nile.index
tests/test_pathfinder.py CHANGED
@@ -18,12 +18,12 @@ import numpy as np
 import pymc as pm
 import pytest
 
+pytestmark = pytest.mark.filterwarnings("ignore:compile_pymc was renamed to compile:FutureWarning")
+
 import pymc_extras as pmx
 
 
-@pytest.mark.skipif(sys.platform == "win32", reason="JAX not supported on windows.")
-def test_pathfinder():
-    # Data of the Eight Schools Model
+def eight_schools_model() -> pm.Model:
     J = 8
     y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
     sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])
@@ -35,11 +35,139 @@ def test_pathfinder():
         theta = pm.Normal("theta", mu=0, sigma=1, shape=J)
         obs = pm.Normal("obs", mu=mu + tau * theta, sigma=sigma, shape=J, observed=y)
 
-    idata = pmx.fit(method="pathfinder", random_seed=41)
+    return model
+
+
+@pytest.fixture
+def reference_idata():
+    model = eight_schools_model()
+    with model:
+        idata = pmx.fit(
+            method="pathfinder",
+            num_paths=50,
+            jitter=10.0,
+            random_seed=41,
+            inference_backend="pymc",
+        )
+    return idata
+
+
+@pytest.mark.parametrize("inference_backend", ["pymc", "blackjax"])
+def test_pathfinder(inference_backend, reference_idata):
+    if inference_backend == "blackjax" and sys.platform == "win32":
+        pytest.skip("JAX not supported on windows")
+
+    if inference_backend == "blackjax":
+        model = eight_schools_model()
+        with model:
+            idata = pmx.fit(
+                method="pathfinder",
+                num_paths=50,
+                jitter=10.0,
+                random_seed=41,
+                inference_backend=inference_backend,
+            )
+    else:
+        idata = reference_idata
+    np.testing.assert_allclose(idata.posterior["mu"].mean(), 5.0, atol=1.6)
+    np.testing.assert_allclose(idata.posterior["tau"].mean(), 4.15, atol=1.5)
 
    assert idata.posterior["mu"].shape == (1, 1000)
    assert idata.posterior["tau"].shape == (1, 1000)
    assert idata.posterior["theta"].shape == (1, 1000, 8)
-    # FIXME: pathfinder doesn't find a reasonable mean! Fix bug or choose model pathfinder can handle
-    # np.testing.assert_allclose(idata.posterior["mu"].mean(), 5.0)
-    np.testing.assert_allclose(idata.posterior["tau"].mean(), 4.15, atol=0.5)
+
+
+@pytest.mark.parametrize("concurrent", ["thread", "process"])
+def test_concurrent_results(reference_idata, concurrent):
+    model = eight_schools_model()
+    with model:
+        idata_conc = pmx.fit(
+            method="pathfinder",
+            num_paths=50,
+            jitter=10.0,
+            random_seed=41,
+            inference_backend="pymc",
+            concurrent=concurrent,
+        )
+
+    np.testing.assert_allclose(
+        reference_idata.posterior.mu.data.mean(),
+        idata_conc.posterior.mu.data.mean(),
+        atol=0.4,
+    )
+
+    np.testing.assert_allclose(
+        reference_idata.posterior.tau.data.mean(),
+        idata_conc.posterior.tau.data.mean(),
+        atol=0.4,
+    )
+
+
+def test_seed(reference_idata):
+    model = eight_schools_model()
+    with model:
+        idata_41 = pmx.fit(
+            method="pathfinder",
+            num_paths=50,
+            jitter=10.0,
+            random_seed=41,
+            inference_backend="pymc",
+        )
+
+        idata_123 = pmx.fit(
+            method="pathfinder",
+            num_paths=50,
+            jitter=10.0,
+            random_seed=123,
+            inference_backend="pymc",
+        )
+
+    assert not np.allclose(idata_41.posterior.mu.data.mean(), idata_123.posterior.mu.data.mean())
+
+    assert np.allclose(idata_41.posterior.mu.data.mean(), idata_41.posterior.mu.data.mean())
+
+
+def test_bfgs_sample():
+    import pytensor.tensor as pt
+
+    from pymc_extras.inference.pathfinder.pathfinder import (
+        alpha_recover,
+        bfgs_sample,
+        inverse_hessian_factors,
+    )
+
+    """test BFGS sampling"""
+    Lp1, N = 8, 10
+    L = Lp1 - 1
+    J = 6
+    num_samples = 1000
+
+    # mock data
+    x_data = np.random.randn(Lp1, N)
+    g_data = np.random.randn(Lp1, N)
+
+    # get factors
+    x_full = pt.as_tensor(x_data, dtype="float64")
+    g_full = pt.as_tensor(g_data, dtype="float64")
+    epsilon = 1e-11
+
+    x = x_full[1:]
+    g = g_full[1:]
+    alpha, S, Z, update_mask = alpha_recover(x_full, g_full, epsilon)
+    beta, gamma = inverse_hessian_factors(alpha, S, Z, update_mask, J)
+
+    # sample
+    phi, logq = bfgs_sample(
+        num_samples=num_samples,
+        x=x,
+        g=g,
+        alpha=alpha,
+        beta=beta,
+        gamma=gamma,
+    )
+
+    # check shapes
+    assert beta.eval().shape == (L, N, 2 * J)
+    assert gamma.eval().shape == (L, 2 * J, 2 * J)
+    assert phi.eval().shape == (L, num_samples, N)
+    assert logq.eval().shape == (L, num_samples)