pymc-extras 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,9 @@
1
1
  from functools import wraps
2
+ from inspect import signature
2
3
 
3
- from pymc import Model
4
+ import pytensor.tensor as pt
5
+
6
+ from pymc import Data, Model
4
7
 
5
8
 
6
9
  def as_model(*model_args, **model_kwargs):
@@ -9,6 +12,8 @@ def as_model(*model_args, **model_kwargs):
9
12
  This removes all need to think about context managers and lets you separate creating a generative model from using the model.
10
13
  Additionally, a coords argument is added to the function so coords can be changed during function invocation
11
14
 
15
+ All parameters are wrapped with a `pm.Data` object if the underlying type of the data supports it.
16
+
12
17
  Adapted from `Rob Zinkov's blog post <https://www.zinkov.com/posts/2023-alternative-frontends-pymc/>`_ and inspired by the `sampled <https://github.com/colcarroll/sampled>`_ decorator for PyMC3.
13
18
 
14
19
  Examples
@@ -47,8 +52,19 @@ def as_model(*model_args, **model_kwargs):
47
52
  @wraps(f)
48
53
  def make_model(*args, **kwargs):
49
54
  coords = model_kwargs.pop("coords", {}) | kwargs.pop("coords", {})
55
+ sig = signature(f)
56
+ ba = sig.bind(*args, **kwargs)
57
+ ba.apply_defaults()
58
+
50
59
  with Model(*model_args, coords=coords, **model_kwargs) as m:
51
- f(*args, **kwargs)
60
+ for name, v in ba.arguments.items():
61
+ # Only wrap pm.Data around values pytensor can process
62
+ try:
63
+ _ = pt.as_tensor_variable(v)
64
+ ba.arguments[name] = Data(name, v)
65
+ except (NotImplementedError, TypeError, ValueError):
66
+ pass
67
+ f(*ba.args, **ba.kwargs)
52
68
  return m
53
69
 
54
70
  return make_model
@@ -15,6 +15,9 @@ from pymc.model.transform.optimization import freeze_dims_and_data
15
15
  from pymc.util import RandomState
16
16
  from pytensor import Variable, graph_replace
17
17
  from pytensor.compile import get_mode
18
+ from rich.box import SIMPLE_HEAD
19
+ from rich.console import Console
20
+ from rich.table import Table
18
21
 
19
22
  from pymc_extras.statespace.core.representation import PytensorRepresentation
20
23
  from pymc_extras.statespace.filters import (
@@ -254,53 +257,72 @@ class PyMCStateSpace:
254
257
  self.kalman_smoother = KalmanSmoother()
255
258
  self.make_symbolic_graph()
256
259
 
257
- if verbose:
258
- # These are split into separate try-except blocks, because it will be quite rare of models to implement
259
- # _print_data_requirements, but we still want to print the prior requirements.
260
- try:
261
- self._print_prior_requirements()
262
- except NotImplementedError:
263
- pass
264
- try:
265
- self._print_data_requirements()
266
- except NotImplementedError:
267
- pass
268
-
269
- def _print_prior_requirements(self) -> None:
270
- """
271
- Prints a short report to the terminal about the priors needed for the model, including their names,
260
+ self.requirement_table = None
261
+ self._populate_prior_requirements()
262
+ self._populate_data_requirements()
263
+
264
+ if verbose and self.requirement_table:
265
+ console = Console()
266
+ console.print(self.requirement_table)
267
+
268
+ def _populate_prior_requirements(self) -> None:
269
+ """
270
+ Add requirements about priors needed for the model to a rich table, including their names,
272
271
  shapes, named dimensions, and any parameter constraints.
273
272
  """
274
- out = ""
275
- for param, info in self.param_info.items():
276
- out += f'\t{param} -- shape: {info["shape"]}, constraints: {info["constraints"]}, dims: {info["dims"]}\n'
277
- out = out.rstrip()
273
+ # Check that the param_info class is implemented, and also that it's a dictionary. We can't proceed if either
274
+ # is not true.
275
+ try:
276
+ if not isinstance(self.param_info, dict):
277
+ return
278
+ except NotImplementedError:
279
+ return
278
280
 
279
- _log.info(
280
- "The following parameters should be assigned priors inside a PyMC "
281
- f"model block: \n"
282
- f"{out}"
283
- )
281
+ if self.requirement_table is None:
282
+ self._initialize_requirement_table()
283
+
284
+ for param, info in self.param_info.items():
285
+ self.requirement_table.add_row(
286
+ param, str(info["shape"]), info["constraints"], str(info["dims"])
287
+ )
284
288
 
285
- def _print_data_requirements(self) -> None:
289
+ def _populate_data_requirements(self) -> None:
286
290
  """
287
- Prints a short report to the terminal about the data needed for the model, including their names, shapes,
288
- and named dimensions.
291
+ Add requirements about the data needed for the model, including their names, shapes, and named dimensions.
289
292
  """
290
- if not self.data_info:
293
+ try:
294
+ if not isinstance(self.data_info, dict):
295
+ return
296
+ except NotImplementedError:
291
297
  return
292
298
 
293
- out = ""
299
+ if self.requirement_table is None:
300
+ self._initialize_requirement_table()
301
+ else:
302
+ self.requirement_table.add_section()
303
+
294
304
  for data, info in self.data_info.items():
295
- out += f'\t{data} -- shape: {info["shape"]}, dims: {info["dims"]}\n'
296
- out = out.rstrip()
305
+ self.requirement_table.add_row(data, str(info["shape"]), "pm.Data", str(info["dims"]))
306
+
307
+ def _initialize_requirement_table(self) -> None:
308
+ self.requirement_table = Table(
309
+ show_header=True,
310
+ show_edge=True,
311
+ box=SIMPLE_HEAD,
312
+ highlight=True,
313
+ )
297
314
 
298
- _log.info(
299
- "The following Data variables should be assigned to the model inside a PyMC "
300
- f"model block: \n"
301
- f"{out}"
315
+ self.requirement_table.title = "Model Requirements"
316
+ self.requirement_table.caption = (
317
+ "These parameters should be assigned priors inside a PyMC model block before "
318
+ "calling the build_statespace_graph method."
302
319
  )
303
320
 
321
+ self.requirement_table.add_column("Variable", justify="left")
322
+ self.requirement_table.add_column("Shape", justify="left")
323
+ self.requirement_table.add_column("Constraints", justify="left")
324
+ self.requirement_table.add_column("Dimensions", justify="right")
325
+
304
326
  def _unpack_statespace_with_placeholders(
305
327
  self,
306
328
  ) -> tuple[
@@ -961,10 +983,31 @@ class PyMCStateSpace:
961
983
  list[pm.Flat]
962
984
  A list of pm.Flat variables representing all parameters estimated by the model.
963
985
  """
986
+
987
+ def infer_variable_shape(name):
988
+ shape = self._name_to_variable[name].type.shape
989
+ if not any(dim is None for dim in shape):
990
+ return shape
991
+
992
+ dim_names = self._fit_dims.get(name, None)
993
+ if dim_names is None:
994
+ raise ValueError(
995
+ f"Could not infer shape for {name}, because it was not given coords during model "
996
+ f"fitting"
997
+ )
998
+
999
+ shape_from_coords = tuple([len(self._fit_coords[dim]) for dim in dim_names])
1000
+ return tuple(
1001
+ [
1002
+ shape[i] if shape[i] is not None else shape_from_coords[i]
1003
+ for i in range(len(shape))
1004
+ ]
1005
+ )
1006
+
964
1007
  for name in self.param_names:
965
1008
  pm.Flat(
966
1009
  name,
967
- shape=self._name_to_variable[name].type.shape,
1010
+ shape=infer_variable_shape(name),
968
1011
  dims=self._fit_dims.get(name, None),
969
1012
  )
970
1013
 
@@ -1071,6 +1071,12 @@ class TimeSeasonality(Component):
1071
1071
 
1072
1072
  If None, states will be numbered ``[State_0, ..., State_s]``
1073
1073
 
1074
+ remove_first_state: bool, default True
1075
+ If True, the first state will be removed from the model. This is done because there are only n-1 degrees of
1076
+ freedom in the seasonal component, and one state is not identified. If False, the first state will be
1077
+ included in the model, but it will not be identified -- you will need to handle this in the priors (e.g. with
1078
+ ZeroSumNormal).
1079
+
1074
1080
  Notes
1075
1081
  -----
1076
1082
  A seasonal effect is any pattern that repeats every fixed interval. Although there are many possible ways to
@@ -1163,7 +1169,7 @@ class TimeSeasonality(Component):
1163
1169
  innovations: bool = True,
1164
1170
  name: str | None = None,
1165
1171
  state_names: list | None = None,
1166
- pop_state: bool = True,
1172
+ remove_first_state: bool = True,
1167
1173
  ):
1168
1174
  if name is None:
1169
1175
  name = f"Seasonal[s={season_length}]"
@@ -1176,14 +1182,15 @@ class TimeSeasonality(Component):
1176
1182
  )
1177
1183
  state_names = state_names.copy()
1178
1184
  self.innovations = innovations
1179
- self.pop_state = pop_state
1185
+ self.remove_first_state = remove_first_state
1180
1186
 
1181
- if self.pop_state:
1187
+ if self.remove_first_state:
1182
1188
  # In traditional models, the first state isn't identified, so we can help out the user by automatically
1183
1189
  # discarding it.
1184
1190
  # TODO: Can this be stashed and reconstructed automatically somehow?
1185
1191
  state_names.pop(0)
1186
- k_states = season_length - 1
1192
+
1193
+ k_states = season_length - int(self.remove_first_state)
1187
1194
 
1188
1195
  super().__init__(
1189
1196
  name=name,
@@ -1218,8 +1225,16 @@ class TimeSeasonality(Component):
1218
1225
  self.shock_names = [f"{self.name}"]
1219
1226
 
1220
1227
  def make_symbolic_graph(self) -> None:
1221
- T = np.eye(self.k_states, k=-1)
1222
- T[0, :] = -1
1228
+ if self.remove_first_state:
1229
+ # In this case, parameters are normalized to sum to zero, so the current state is the negative sum of
1230
+ # all previous states.
1231
+ T = np.eye(self.k_states, k=-1)
1232
+ T[0, :] = -1
1233
+ else:
1234
+ # In this case we assume the user to be responsible for ensuring the states sum to zero, so T is just a
1235
+ # circulant matrix that cycles between the states.
1236
+ T = np.eye(self.k_states, k=1)
1237
+ T[-1, 0] = 1
1223
1238
 
1224
1239
  self.ssm["transition", :, :] = T
1225
1240
  self.ssm["design", 0, 0] = 1
@@ -0,0 +1,66 @@
1
+ from collections.abc import Sequence
2
+
3
+ from pymc.model.core import Model
4
+ from pymc.model.fgraph import fgraph_from_model
5
+ from pytensor import Variable
6
+ from pytensor.compile import SharedVariable
7
+ from pytensor.graph import Constant, graph_inputs
8
+ from pytensor.graph.basic import equal_computations
9
+ from pytensor.tensor.random.type import RandomType
10
+
11
+
12
+ def equal_computations_up_to_root(
13
+ xs: Sequence[Variable], ys: Sequence[Variable], ignore_rng_values=True
14
+ ) -> bool:
15
+ # Check if graphs are equivalent even if root variables have distinct identities
16
+
17
+ x_graph_inputs = [var for var in graph_inputs(xs) if not isinstance(var, Constant)]
18
+ y_graph_inputs = [var for var in graph_inputs(ys) if not isinstance(var, Constant)]
19
+ if len(x_graph_inputs) != len(y_graph_inputs):
20
+ return False
21
+ for x, y in zip(x_graph_inputs, y_graph_inputs):
22
+ if x.type != y.type:
23
+ return False
24
+ if x.name != y.name:
25
+ return False
26
+ if isinstance(x, SharedVariable):
27
+ # if not isinstance(y, SharedVariable):
28
+ # return False
29
+ if isinstance(x.type, RandomType) and ignore_rng_values:
30
+ continue
31
+ if not x.type.values_eq(x.get_value(), y.get_value()):
32
+ return False
33
+
34
+ return equal_computations(xs, ys, in_xs=x_graph_inputs, in_ys=y_graph_inputs)
35
+
36
+
37
+ def equivalent_models(model1: Model, model2: Model) -> bool:
38
+ """Check whether two PyMC models are equivalent.
39
+
40
+ Examples
41
+ --------
42
+
43
+ .. code-block:: python
44
+
45
+ import pymc as pm
46
+ from pymc_extras.utils.model_equivalence import equivalent_models
47
+
48
+ with pm.Model() as m1:
49
+ x = pm.Normal("x")
50
+ y = pm.Normal("y", x)
51
+
52
+ with pm.Model() as m2:
53
+ x = pm.Normal("x")
54
+ y = pm.Normal("y", x + 1)
55
+
56
+ with pm.Model() as m3:
57
+ x = pm.Normal("x")
58
+ y = pm.Normal("y", x)
59
+
60
+ assert not equivalent_models(m1, m2)
61
+ assert equivalent_models(m1, m3)
62
+
63
+ """
64
+ fgraph1, _ = fgraph_from_model(model1)
65
+ fgraph2, _ = fgraph_from_model(model2)
66
+ return equal_computations_up_to_root(fgraph1.outputs, fgraph2.outputs)
pymc_extras/version.txt CHANGED
@@ -1 +1 @@
1
- 0.2.0
1
+ 0.2.2
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.2
2
2
  Name: pymc-extras
3
- Version: 0.2.0
3
+ Version: 0.2.2
4
4
  Summary: A home for new additions to PyMC, which may include unusual probability distributions, advanced model fitting algorithms, or any code that may be inappropriate to include in the pymc repository, but may want to be made available to users.
5
5
  Home-page: http://github.com/pymc-devs/pymc-extras
6
6
  Maintainer: PyMC Developers
@@ -34,6 +34,17 @@ Provides-Extra: dev
34
34
  Requires-Dist: dask[all]; extra == "dev"
35
35
  Requires-Dist: blackjax; extra == "dev"
36
36
  Requires-Dist: statsmodels; extra == "dev"
37
+ Dynamic: classifier
38
+ Dynamic: description
39
+ Dynamic: description-content-type
40
+ Dynamic: home-page
41
+ Dynamic: license
42
+ Dynamic: maintainer
43
+ Dynamic: maintainer-email
44
+ Dynamic: provides-extra
45
+ Dynamic: requires-dist
46
+ Dynamic: requires-python
47
+ Dynamic: summary
37
48
 
38
49
  # Welcome to `pymc-extras`
39
50
  <a href="https://gitpod.io/#https://github.com/pymc-devs/pymc-extras">
@@ -63,10 +74,9 @@ import pymc as pm
63
74
  import pymc_extras as pmx
64
75
 
65
76
  with pm.Model():
77
+ alpha = pmx.ParabolicFractal('alpha', b=1, c=1)
66
78
 
67
- alpha = pmx.ParabolicFractal('alpha', b=1, c=1)
68
-
69
- ...
79
+ ...
70
80
 
71
81
  ```
72
82
 
@@ -1,31 +1,34 @@
1
- pymc_extras/__init__.py,sha256=WpDTZvLhxFg_t9gOE_wOSsswoYGIZpllsJH-_yOLEYI,1124
1
+ pymc_extras/__init__.py,sha256=URh185f6b1xp2Taj2W2NJuW_hErKufBcLeQ0WDCyaNk,1160
2
2
  pymc_extras/linearmodel.py,sha256=6eitl15Ec15mSZu7zoHZ7Wwy4U1DPwqfAgwEt6ILeIc,3920
3
3
  pymc_extras/model_builder.py,sha256=sAw77fxdiy046BvDPjocuMlbJ0Efj-CDAGtmcwYmoG0,26361
4
4
  pymc_extras/printing.py,sha256=G8mj9dRd6i0PcsbcEWZm56ek6V8mmil78RI4MUhywBs,6506
5
5
  pymc_extras/version.py,sha256=VxPGCBzhtSegu-Jp5cjzn0n4DGU0wuPUh-KyZKB6uPM,240
6
- pymc_extras/version.txt,sha256=H5MN0fEzwfl6lP46y42zQ3LPTAH_2ys_9Mpy-UlBIek,6
6
+ pymc_extras/version.txt,sha256=mY9riH7Xpu9E6EIQ0CN7cVvtQupyVPSNxBiv7ApIQQM,6
7
7
  pymc_extras/distributions/__init__.py,sha256=gTX7tvX8NcgP7V72URV7GeqF1aAEjGVbuW8LMxhXceY,1295
8
8
  pymc_extras/distributions/continuous.py,sha256=z-nvQgGncYISdRY8cWsa-56V0bQGq70jYwU-i8VZ0Uk,11253
9
9
  pymc_extras/distributions/discrete.py,sha256=vrARNuiQAEXrs7yQgImV1PO8AV1uyEC_LBhr6F9IcOg,13032
10
10
  pymc_extras/distributions/histogram_utils.py,sha256=5RTvlGCUrp2qzshrchmPyWxjhs6RIYL62SMikjDM1TU,5814
11
- pymc_extras/distributions/timeseries.py,sha256=EJxWOrfuQlODwPN13Udgy2ras6vQKS0Ebus0pUuduaA,12680
11
+ pymc_extras/distributions/timeseries.py,sha256=M5MZ-nik_tgkaoZ1hdUGEZ9g04DQyVLwszVJqSKwNcY,12719
12
12
  pymc_extras/distributions/multivariate/__init__.py,sha256=E8OeLW9tTotCbrUjEo4um76-_WQD56PehsPzkKmhfyA,93
13
13
  pymc_extras/distributions/multivariate/r2d2m2cp.py,sha256=bUj9bB-hQi6CpaJfvJjgNPi727uTbvAdxl9fm1zNBqY,16005
14
14
  pymc_extras/gp/__init__.py,sha256=sFHw2y3lEl5tG_FDQHZUonQ_k0DF1JRf0Rp8dpHmge0,745
15
15
  pymc_extras/gp/latent_approx.py,sha256=cDEMM6H1BL2qyKg7BZU-ISrKn2HJe7hDaM4Y8GgQDf4,6682
16
16
  pymc_extras/inference/__init__.py,sha256=5cXpaQQnW0mJJ3x8wSxmYu63l--Xab5D_gMtjA6Q3uU,666
17
17
  pymc_extras/inference/find_map.py,sha256=T0uO8prUI5aBNuR1AN8fbA4cHmLRQLXznwJrfxfe7CA,15723
18
- pymc_extras/inference/fit.py,sha256=NFEpUaYLJAmDRP1WIPymgnEcXUofkoURYHbEdiTivzQ,1313
18
+ pymc_extras/inference/fit.py,sha256=S9R48dh74s6K0MC9Iys4NAwVjP6rVRfx6SF-kPiR70E,1165
19
19
  pymc_extras/inference/laplace.py,sha256=OglOvnxfHLe0VXxBC1-ddVzADR9zgGxUPScM6P6FYo8,21163
20
- pymc_extras/inference/pathfinder.py,sha256=cmzR2OZCfkdTipT-8pmLuF-MHmLzxotsYlezOWBUM4U,4171
20
+ pymc_extras/inference/pathfinder/__init__.py,sha256=FhAYrCWNx_dCrynEdjg2CZ9tIinvcVLBm67pNx_Y3kA,101
21
+ pymc_extras/inference/pathfinder/importance_sampling.py,sha256=VvmuaE3aw_Mo3tMwswfF0rqe19mnhOCpzIScaJzjA1Y,6159
22
+ pymc_extras/inference/pathfinder/lbfgs.py,sha256=P0UIOVtspdLzDU6alK-y91qzVAzXjYAXPuGmZ1nRqMo,5715
23
+ pymc_extras/inference/pathfinder/pathfinder.py,sha256=fomZ5voVcWxvhWpeIZV7IHGIJCasT1g0ivC4dC3-0GM,63694
21
24
  pymc_extras/inference/smc/__init__.py,sha256=wyaT4NJl1YsSQRLiDy-i0Jq3CbJZ2BQd4nnCk-dIngY,603
22
25
  pymc_extras/inference/smc/sampling.py,sha256=AYwmKqGoV6pBtKnh9SUbBKbN7VcoFgb3MmNWV7SivMA,15365
23
26
  pymc_extras/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
24
- pymc_extras/model/model_api.py,sha256=_r6rYQG1tt9Z95QU-jVyHqZ1rs-u7sFMO5HJ5unDV5A,1750
27
+ pymc_extras/model/model_api.py,sha256=UHMfQXxWBujeSiUySU0fDUC5Sd_BjT8FoVz3iBxQH_4,2400
25
28
  pymc_extras/model/marginal/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
26
- pymc_extras/model/marginal/distributions.py,sha256=hytO-CJqMXuvIfJhdfKNlf9tkblOwCp8hAWTiIS0cyU,12088
27
- pymc_extras/model/marginal/graph_analysis.py,sha256=LNx5N8ZE7d8Sq3BnlUYHrrwnxTLJTvehf8xYu95yrb8,15699
28
- pymc_extras/model/marginal/marginal_model.py,sha256=5DbhjlOAwY6JSMJUhUAPXPpybXQU0x7MzOGd3eXACYo,24854
29
+ pymc_extras/model/marginal/distributions.py,sha256=iM1yT7_BmivgUSloQPKE2QXGPgjvLqDMY_OTBGsdAWg,15563
30
+ pymc_extras/model/marginal/graph_analysis.py,sha256=0hWUH_PjfpgneQ3NaT__pWHS1fh50zNbI86kH4Nub0E,15693
31
+ pymc_extras/model/marginal/marginal_model.py,sha256=oNsiSWHjOPCTDxNEivEILLP_cOuBarm29Gr2p6hWHIM,23594
29
32
  pymc_extras/model/transforms/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
30
33
  pymc_extras/model/transforms/autoreparam.py,sha256=_NltGWmNqi_X9sHCqAvWcBveLTPxVy11-wENFTcN6kk,12377
31
34
  pymc_extras/preprocessing/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -34,7 +37,7 @@ pymc_extras/statespace/__init__.py,sha256=0MtZj7yT6jcyERvITnn-nkhyY8fO6Za4_vV53C
34
37
  pymc_extras/statespace/core/__init__.py,sha256=huHEiXAm8zV2MZyZ8GBHp6q7_fnWqveM7lC6ilpb3iE,309
35
38
  pymc_extras/statespace/core/compile.py,sha256=1c8Q9D9zeUe7F0z7CH6q1C6ZuLg2_imgk8RoE_KMaFI,1608
36
39
  pymc_extras/statespace/core/representation.py,sha256=DwNIun6wdeEA20oWBx5M4govyWTf5JI87aGQ_E6Mb4U,18956
37
- pymc_extras/statespace/core/statespace.py,sha256=ZElRm9wJvIGG4Pw-3qiQpBkHXRDqS6pfRyuGrBBcZ2Y,95270
40
+ pymc_extras/statespace/core/statespace.py,sha256=K_WVnWKlI6sR2kgriq9sctQVvwXCeAirm14TthDpmRM,96860
38
41
  pymc_extras/statespace/filters/__init__.py,sha256=N9Q4D0gAq_ZtT-GtrqiX1HkSg6Orv7o1TbrWUtnbTJE,420
39
42
  pymc_extras/statespace/filters/distributions.py,sha256=-9j__vRqL5hKyYFnQr5HKHA5kEFzwiuSccH4mslTOuQ,12900
40
43
  pymc_extras/statespace/filters/kalman_filter.py,sha256=HELC3aK4k8EdWlUAk5_F7y7YkIz-Xi_0j2AwRgAXgcc,31949
@@ -44,7 +47,7 @@ pymc_extras/statespace/models/ETS.py,sha256=o039M-6aCxyMXbbKvUeNVZhheCKvvNIAmuj0
44
47
  pymc_extras/statespace/models/SARIMAX.py,sha256=SX0eiSK1pOt4dHBjWzBqVpRz67pBGLN5pQQgXcOiOgY,21607
45
48
  pymc_extras/statespace/models/VARMAX.py,sha256=xkIuftNc_5NHFpqZalExni99-1kovnzm5OjMIDNgaxY,15989
46
49
  pymc_extras/statespace/models/__init__.py,sha256=U79b8rTHBNijVvvGOd43nLu4PCloPUH1rwlN87-n88c,317
47
- pymc_extras/statespace/models/structural.py,sha256=W5FmImZyvHGxpFCYczfi6IVIjXDQkzaFjVKZSa6CiW8,63017
50
+ pymc_extras/statespace/models/structural.py,sha256=sep9pesJdRN4X8Bea6_RhO3112uWOZRuYRxO6ibl_OA,63943
48
51
  pymc_extras/statespace/models/utilities.py,sha256=G9GuHKsghmIYOlfkPtvxBWF-FZY5-5JI1fJQM8N7EnE,15373
49
52
  pymc_extras/statespace/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
50
53
  pymc_extras/statespace/utils/constants.py,sha256=Kf6j75ABaDQeRODxKQ76wTUQV4F5sTjn1KBcZgCQx20,2403
@@ -52,6 +55,7 @@ pymc_extras/statespace/utils/coord_tools.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm
52
55
  pymc_extras/statespace/utils/data_tools.py,sha256=caanvrxDu9g-dEKff2bbmaVTs6-71kkSoYIiiSUXhw4,5985
53
56
  pymc_extras/utils/__init__.py,sha256=yxI9cJ7fCtVQS0GFw0y6mDGZIQZiK53vm3UNKqIuGSk,758
54
57
  pymc_extras/utils/linear_cg.py,sha256=KkXhuimFsrKtNd_0By2ApxQQQNm5FdBtmDQJOVbLYkA,10056
58
+ pymc_extras/utils/model_equivalence.py,sha256=8QIftID2HDxD659i0RXHazQ-l2Q5YegCRLcDqb2p9Pc,2187
55
59
  pymc_extras/utils/pivoted_cholesky.py,sha256=QtnjP0pAl9b77fLAu-semwT4_9dcoiqx3dz1xKGBjMk,1871
56
60
  pymc_extras/utils/prior.py,sha256=QlWVr7uKIK9VncBw7Fz3YgaASKGDfqpORZHc-vz_9gQ,6841
57
61
  pymc_extras/utils/spline.py,sha256=qGq0gcoMG5dpdazKFzG0RXkkCWP8ADPPXN-653-oFn4,4820
@@ -62,23 +66,23 @@ tests/test_histogram_approximation.py,sha256=w-xb2Rr0Qft6sm6F3BTmXXnpuqyefC1SUL6
62
66
  tests/test_laplace.py,sha256=5ioEyP6AzmMszrtQRz0KWTsCCU35SEhSOdBcYfYzptE,8228
63
67
  tests/test_linearmodel.py,sha256=iB8ApNqIX9_nUHoo-Tm51xuPdrva5t4VLLut6qXB5Ao,6906
64
68
  tests/test_model_builder.py,sha256=QiINEihBR9rx8xM4Nqlg4urZKoyo58aTKDtxl9SJF1s,11249
65
- tests/test_pathfinder.py,sha256=FBm0ge6rje5jz9_10h_247E70aKCpkbu1jmzrR7Ar8A,1726
66
- tests/test_pivoted_cholesky.py,sha256=7_thrb90_an_S3boYr0mu4NNOhjiI6AaZ2ADn53sBX8,698
69
+ tests/test_pathfinder.py,sha256=GnSbZJ9QuFW9UVbkWaVgMVqQZTCttOyz_rSflxhQ-EA,4955
70
+ tests/test_pivoted_cholesky.py,sha256=PuMdMSCzO4KdQWpUF4SEBeuH_qsINCIH8TYtmmJ1NKo,692
67
71
  tests/test_printing.py,sha256=HnvwwjrjBuxXFAJdyU0K_lvKGLgh4nzHAnhsIUpenbY,5211
68
72
  tests/test_prior_from_trace.py,sha256=HOzR3l98pl7TEJquo_kSugED4wBTgHo4-8lgnpmacs8,5516
69
73
  tests/test_splines.py,sha256=xSZi4hqqReN1H8LHr0xjDmpomhDQm8auIsWQjFOyjbM,2608
70
- tests/utils.py,sha256=cRPe0ovsexxOQ6dK94xao_Kv5qcPrqtFWOFXBqubHqY,1257
74
+ tests/utils.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
71
75
  tests/distributions/__init__.py,sha256=jt-oloszTLNFwi9AgU3M4m6xKQ8xpQE338rmmaMZcMs,795
72
76
  tests/distributions/test_continuous.py,sha256=1-bu-IP6RgLUJnuPYpOD8ZS1ahYbKtsJ9oflBfqCaFo,5477
73
77
  tests/distributions/test_discrete.py,sha256=CjjaUpppsvQ6zLzV15ZsbwNOKrDmEdz4VWcleoCXUi0,7776
74
78
  tests/distributions/test_discrete_markov_chain.py,sha256=8RCHZXSB8IWjniuKaGGlM_iTWGmdrcOqginxmrAeEJg,9212
75
79
  tests/distributions/test_multivariate.py,sha256=LBvBuoT_3rzi8rR38b8L441Y-9Ff0cIXeRBKiEn6kjs,10452
76
80
  tests/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
77
- tests/model/test_model_api.py,sha256=SiOMA1NpyQKJ7stYI1ms8ksDPU81lVo8wS8hbqiik-U,776
81
+ tests/model/test_model_api.py,sha256=FJvMTmexovRELZOUcUyk-6Vwk9qSiH7hIFoiArgl5mk,1040
78
82
  tests/model/marginal/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
79
- tests/model/marginal/test_distributions.py,sha256=UWoXi1h-e0QG0fLxPXkUj2LXzNb5uhoHicIBIVDIhL4,5126
83
+ tests/model/marginal/test_distributions.py,sha256=p5f73g4ogxYkdZaBndZV_1ra8TCppXiRlUpaaTwEe-M,5195
80
84
  tests/model/marginal/test_graph_analysis.py,sha256=raoj41NusMOj1zzPCrxrlQODqX6Ey8Ft_o32pNTe5qg,6712
81
- tests/model/marginal/test_marginal_model.py,sha256=IWmF_XkbQeI15H0uepi1zET1Zffj6qVgAAdAiUoxrhA,31009
85
+ tests/model/marginal/test_marginal_model.py,sha256=uOmARalkdWq3sDbnJQ0KjiLwviqauZOAnafmYS_Cnd8,35475
82
86
  tests/statespace/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
83
87
  tests/statespace/test_ETS.py,sha256=IPg3uQ7xEGqDMEHu993vtUTV7r-uNAxmw23sr5MVGfQ,15582
84
88
  tests/statespace/test_SARIMAX.py,sha256=1BYNOm9aSHnpn-qbpe3YsQVH8m-mXcp_gvKgWhWn1W4,12948
@@ -87,15 +91,15 @@ tests/statespace/test_coord_assignment.py,sha256=2GBm46-0eI4QNh4bvp3D7az58stcA5Z
87
91
  tests/statespace/test_distributions.py,sha256=WQ_ROyd-PL3cimXTyEtyVaMEVtS7Hue2Z0lN7UnGDyo,9122
88
92
  tests/statespace/test_kalman_filter.py,sha256=s2n62FzXl9elU_uqaMNaEaexUfq3SXe3_YvQ2lM6hiQ,11600
89
93
  tests/statespace/test_representation.py,sha256=1KAJY4ZaVhb1WdAJLx2UYSXuVYsMNWX98gEDF7P0B4s,6210
90
- tests/statespace/test_statespace.py,sha256=8ZLLQaxlP5UEJnIMYyIzzAODCxMxs6E5I1hLu2HCdqo,28866
94
+ tests/statespace/test_statespace.py,sha256=JoupFFpG8PmpB_NFV471IuTmyXhEd6_vOISwVCRrBBM,30570
91
95
  tests/statespace/test_statespace_JAX.py,sha256=hZOc6xxYdVeATPCKmcHMLOVcuvdzGRzgQQ4RrDenwk8,5279
92
- tests/statespace/test_structural.py,sha256=IN6OQuq7bZVCDiws3Yhsa4IoPyfLaONzgzdvImp0Zcc,29036
96
+ tests/statespace/test_structural.py,sha256=HD8OaGbjuH4y3xv_uG-R1xLZpPpcb4-3dbcTeb_imLY,29306
93
97
  tests/statespace/utilities/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
94
98
  tests/statespace/utilities/shared_fixtures.py,sha256=SNw8Bvj1Yw11TxAW6n20Bq0B8oaYtVTiFFEVNH_wnp4,164
95
99
  tests/statespace/utilities/statsmodel_local_level.py,sha256=SQAzaYaSDwiVhUQ1iWjt4MgfAd54RuzVtnslIs3xdS8,1225
96
100
  tests/statespace/utilities/test_helpers.py,sha256=oH24a6Q45NFFFI3Kx9mhKbxsCvo9ErCorKFoTjDB3-4,9159
97
- pymc_extras-0.2.0.dist-info/LICENSE,sha256=WjiLhUKEysJvy5e9jk6WwFv9tmAPtnov1uJ6gcH1kIs,11720
98
- pymc_extras-0.2.0.dist-info/METADATA,sha256=DpLFVnYDJouXKp6hsIfTtE6rBgJfcLvRdq5CsEKaNVQ,4899
99
- pymc_extras-0.2.0.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
100
- pymc_extras-0.2.0.dist-info/top_level.txt,sha256=D6RkgBiXiZCel0nvsYg_zYEoT_VuwocyIY98EMaulj0,18
101
- pymc_extras-0.2.0.dist-info/RECORD,,
101
+ pymc_extras-0.2.2.dist-info/LICENSE,sha256=WjiLhUKEysJvy5e9jk6WwFv9tmAPtnov1uJ6gcH1kIs,11720
102
+ pymc_extras-0.2.2.dist-info/METADATA,sha256=9k60kKKNzr7E24gACpTBNOWj-tRTNOiujSZOFD89G5c,5140
103
+ pymc_extras-0.2.2.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
104
+ pymc_extras-0.2.2.dist-info/top_level.txt,sha256=D6RkgBiXiZCel0nvsYg_zYEoT_VuwocyIY98EMaulj0,18
105
+ pymc_extras-0.2.2.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.6.0)
2
+ Generator: setuptools (75.8.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -6,7 +6,7 @@ from pymc.logprob.abstract import _logprob
6
6
  from pytensor import tensor as pt
7
7
  from scipy.stats import norm
8
8
 
9
- from pymc_extras import MarginalModel
9
+ from pymc_extras import marginalize
10
10
  from pymc_extras.distributions import DiscreteMarkovChain
11
11
  from pymc_extras.model.marginal.distributions import MarginalFiniteDiscreteRV
12
12
 
@@ -21,6 +21,7 @@ def test_marginalized_bernoulli_logp():
21
21
  [mu],
22
22
  [idx, y],
23
23
  dims_connections=(((),),),
24
+ dims=(),
24
25
  )(mu)[0].owner
25
26
 
26
27
  y_vv = y.clone()
@@ -43,7 +44,7 @@ def test_marginalized_hmm_normal_emission(batch_chain, batch_emission):
43
44
  if batch_chain and not batch_emission:
44
45
  pytest.skip("Redundant implicit combination")
45
46
 
46
- with MarginalModel() as m:
47
+ with pm.Model() as m:
47
48
  P = [[0, 1], [1, 0]]
48
49
  init_dist = pm.Categorical.dist(p=[1, 0])
49
50
  chain = DiscreteMarkovChain(
@@ -53,8 +54,8 @@ def test_marginalized_hmm_normal_emission(batch_chain, batch_emission):
53
54
  "emission", mu=chain * 2 - 1, sigma=1e-1, shape=(3, 4) if batch_emission else None
54
55
  )
55
56
 
56
- m.marginalize([chain])
57
- logp_fn = m.compile_logp()
57
+ marginal_m = marginalize(m, [chain])
58
+ logp_fn = marginal_m.compile_logp()
58
59
 
59
60
  test_value = np.array([-1, 1, -1, 1])
60
61
  expected_logp = pm.logp(pm.Normal.dist(0, 1e-1), np.zeros_like(test_value)).sum().eval()
@@ -70,7 +71,7 @@ def test_marginalized_hmm_normal_emission(batch_chain, batch_emission):
70
71
  )
71
72
  def test_marginalized_hmm_categorical_emission(categorical_emission):
72
73
  """Example adapted from https://www.youtube.com/watch?v=9-sPm4CfcD0"""
73
- with MarginalModel() as m:
74
+ with pm.Model() as m:
74
75
  P = np.array([[0.5, 0.5], [0.3, 0.7]])
75
76
  init_dist = pm.Categorical.dist(p=[0.375, 0.625])
76
77
  chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, steps=2)
@@ -78,11 +79,11 @@ def test_marginalized_hmm_categorical_emission(categorical_emission):
78
79
  emission = pm.Categorical("emission", p=pt.constant([[0.8, 0.2], [0.4, 0.6]])[chain])
79
80
  else:
80
81
  emission = pm.Bernoulli("emission", p=pt.where(pt.eq(chain, 0), 0.2, 0.6))
81
- m.marginalize([chain])
82
+ marginal_m = marginalize(m, [chain])
82
83
 
83
84
  test_value = np.array([0, 0, 1])
84
85
  expected_logp = np.log(0.1344) # Shown at the 10m22s mark in the video
85
- logp_fn = m.compile_logp()
86
+ logp_fn = marginal_m.compile_logp()
86
87
  np.testing.assert_allclose(logp_fn({"emission": test_value}), expected_logp)
87
88
 
88
89
 
@@ -95,7 +96,7 @@ def test_marginalized_hmm_multiple_emissions(batch_chain, batch_emission1, batch
95
96
  (2, *reversed(chain_shape)) if batch_emission1 else tuple(reversed(chain_shape))
96
97
  )
97
98
  emission2_shape = (*chain_shape, 2) if batch_emission2 else chain_shape
98
- with MarginalModel() as m:
99
+ with pm.Model() as m:
99
100
  P = [[0, 1], [1, 0]]
100
101
  init_dist = pm.Categorical.dist(p=[1, 0])
101
102
  chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, shape=chain_shape)
@@ -108,10 +109,10 @@ def test_marginalized_hmm_multiple_emissions(batch_chain, batch_emission1, batch
108
109
  emission2_mu = emission2_mu[..., None]
109
110
  emission_2 = pm.Normal("emission_2", mu=emission2_mu, sigma=1e-1, shape=emission2_shape)
110
111
 
111
- with pytest.warns(UserWarning, match="multiple dependent variables"):
112
- m.marginalize([chain])
112
+ marginal_m = marginalize(m, [chain])
113
113
 
114
- logp_fn = m.compile_logp(sum=False)
114
+ with pytest.warns(UserWarning, match="multiple dependent variables"):
115
+ logp_fn = marginal_m.compile_logp(sum=False)
115
116
 
116
117
  test_value = np.array([-1, 1, -1, 1])
117
118
  multiplier = 2 + batch_emission1 + batch_emission2