pymc-extras 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1071,6 +1071,12 @@ class TimeSeasonality(Component):
1071
1071
 
1072
1072
  If None, states will be numbered ``[State_0, ..., State_s]``
1073
1073
 
1074
+ remove_first_state: bool, default True
1075
+ If True, the first state will be removed from the model. This is done because there are only n-1 degrees of
1076
+ freedom in the seasonal component, and one state is not identified. If False, the first state will be
1077
+ included in the model, but it will not be identified -- you will need to handle this in the priors (e.g. with
1078
+ ZeroSumNormal).
1079
+
1074
1080
  Notes
1075
1081
  -----
1076
1082
  A seasonal effect is any pattern that repeats every fixed interval. Although there are many possible ways to
@@ -1163,7 +1169,7 @@ class TimeSeasonality(Component):
1163
1169
  innovations: bool = True,
1164
1170
  name: str | None = None,
1165
1171
  state_names: list | None = None,
1166
- pop_state: bool = True,
1172
+ remove_first_state: bool = True,
1167
1173
  ):
1168
1174
  if name is None:
1169
1175
  name = f"Seasonal[s={season_length}]"
@@ -1176,14 +1182,15 @@ class TimeSeasonality(Component):
1176
1182
  )
1177
1183
  state_names = state_names.copy()
1178
1184
  self.innovations = innovations
1179
- self.pop_state = pop_state
1185
+ self.remove_first_state = remove_first_state
1180
1186
 
1181
- if self.pop_state:
1187
+ if self.remove_first_state:
1182
1188
  # In traditional models, the first state isn't identified, so we can help out the user by automatically
1183
1189
  # discarding it.
1184
1190
  # TODO: Can this be stashed and reconstructed automatically somehow?
1185
1191
  state_names.pop(0)
1186
- k_states = season_length - 1
1192
+
1193
+ k_states = season_length - int(self.remove_first_state)
1187
1194
 
1188
1195
  super().__init__(
1189
1196
  name=name,
@@ -1218,8 +1225,16 @@ class TimeSeasonality(Component):
1218
1225
  self.shock_names = [f"{self.name}"]
1219
1226
 
1220
1227
  def make_symbolic_graph(self) -> None:
1221
- T = np.eye(self.k_states, k=-1)
1222
- T[0, :] = -1
1228
+ if self.remove_first_state:
1229
+ # In this case, parameters are normalized to sum to zero, so the current state is the negative sum of
1230
+ # all previous states.
1231
+ T = np.eye(self.k_states, k=-1)
1232
+ T[0, :] = -1
1233
+ else:
1234
+ # In this case we assume the user to be responsible for ensuring the states sum to zero, so T is just a
1235
+ # circulant matrix that cycles between the states.
1236
+ T = np.eye(self.k_states, k=1)
1237
+ T[-1, 0] = 1
1223
1238
 
1224
1239
  self.ssm["transition", :, :] = T
1225
1240
  self.ssm["design", 0, 0] = 1
@@ -0,0 +1,66 @@
1
+ from collections.abc import Sequence
2
+
3
+ from pymc.model.core import Model
4
+ from pymc.model.fgraph import fgraph_from_model
5
+ from pytensor import Variable
6
+ from pytensor.compile import SharedVariable
7
+ from pytensor.graph import Constant, graph_inputs
8
+ from pytensor.graph.basic import equal_computations
9
+ from pytensor.tensor.random.type import RandomType
10
+
11
+
12
+ def equal_computations_up_to_root(
13
+ xs: Sequence[Variable], ys: Sequence[Variable], ignore_rng_values=True
14
+ ) -> bool:
15
+ # Check if graphs are equivalent even if root variables have distinct identities
16
+
17
+ x_graph_inputs = [var for var in graph_inputs(xs) if not isinstance(var, Constant)]
18
+ y_graph_inputs = [var for var in graph_inputs(ys) if not isinstance(var, Constant)]
19
+ if len(x_graph_inputs) != len(y_graph_inputs):
20
+ return False
21
+ for x, y in zip(x_graph_inputs, y_graph_inputs):
22
+ if x.type != y.type:
23
+ return False
24
+ if x.name != y.name:
25
+ return False
26
+ if isinstance(x, SharedVariable):
27
+ # if not isinstance(y, SharedVariable):
28
+ # return False
29
+ if isinstance(x.type, RandomType) and ignore_rng_values:
30
+ continue
31
+ if not x.type.values_eq(x.get_value(), y.get_value()):
32
+ return False
33
+
34
+ return equal_computations(xs, ys, in_xs=x_graph_inputs, in_ys=y_graph_inputs)
35
+
36
+
37
+ def equivalent_models(model1: Model, model2: Model) -> bool:
38
+ """Check whether two PyMC models are equivalent.
39
+
40
+ Examples
41
+ --------
42
+
43
+ .. code-block:: python
44
+
45
+ import pymc as pm
46
+ from pymc_extras.utils.model_equivalence import equivalent_models
47
+
48
+ with pm.Model() as m1:
49
+ x = pm.Normal("x")
50
+ y = pm.Normal("y", x)
51
+
52
+ with pm.Model() as m2:
53
+ x = pm.Normal("x")
54
+ y = pm.Normal("y", x + 1)
55
+
56
+ with pm.Model() as m3:
57
+ x = pm.Normal("x")
58
+ y = pm.Normal("y", x)
59
+
60
+ assert not equivalent_models(m1, m2)
61
+ assert equivalent_models(m1, m3)
62
+
63
+ """
64
+ fgraph1, _ = fgraph_from_model(model1)
65
+ fgraph2, _ = fgraph_from_model(model2)
66
+ return equal_computations_up_to_root(fgraph1.outputs, fgraph2.outputs)
pymc_extras/version.txt CHANGED
@@ -1 +1 @@
1
- 0.2.0
1
+ 0.2.1
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: pymc-extras
3
- Version: 0.2.0
3
+ Version: 0.2.1
4
4
  Summary: A home for new additions to PyMC, which may include unusual probability distribitions, advanced model fitting algorithms, or any code that may be inappropriate to include in the pymc repository, but may want to be made available to users.
5
5
  Home-page: http://github.com/pymc-devs/pymc-extras
6
6
  Maintainer: PyMC Developers
@@ -63,10 +63,9 @@ import pymc as pm
63
63
  import pymc_extras as pmx
64
64
 
65
65
  with pm.Model():
66
+ alpha = pmx.ParabolicFractal('alpha', b=1, c=1)
66
67
 
67
- alpha = pmx.ParabolicFractal('alpha', b=1, c=1)
68
-
69
- ...
68
+ ...
70
69
 
71
70
  ```
72
71
 
@@ -1,14 +1,14 @@
1
- pymc_extras/__init__.py,sha256=WpDTZvLhxFg_t9gOE_wOSsswoYGIZpllsJH-_yOLEYI,1124
1
+ pymc_extras/__init__.py,sha256=URh185f6b1xp2Taj2W2NJuW_hErKufBcLeQ0WDCyaNk,1160
2
2
  pymc_extras/linearmodel.py,sha256=6eitl15Ec15mSZu7zoHZ7Wwy4U1DPwqfAgwEt6ILeIc,3920
3
3
  pymc_extras/model_builder.py,sha256=sAw77fxdiy046BvDPjocuMlbJ0Efj-CDAGtmcwYmoG0,26361
4
4
  pymc_extras/printing.py,sha256=G8mj9dRd6i0PcsbcEWZm56ek6V8mmil78RI4MUhywBs,6506
5
5
  pymc_extras/version.py,sha256=VxPGCBzhtSegu-Jp5cjzn0n4DGU0wuPUh-KyZKB6uPM,240
6
- pymc_extras/version.txt,sha256=H5MN0fEzwfl6lP46y42zQ3LPTAH_2ys_9Mpy-UlBIek,6
6
+ pymc_extras/version.txt,sha256=cQFcl5zLD8igvnygroMEarBFzcLI-qCfsvD35ED5tKY,6
7
7
  pymc_extras/distributions/__init__.py,sha256=gTX7tvX8NcgP7V72URV7GeqF1aAEjGVbuW8LMxhXceY,1295
8
8
  pymc_extras/distributions/continuous.py,sha256=z-nvQgGncYISdRY8cWsa-56V0bQGq70jYwU-i8VZ0Uk,11253
9
9
  pymc_extras/distributions/discrete.py,sha256=vrARNuiQAEXrs7yQgImV1PO8AV1uyEC_LBhr6F9IcOg,13032
10
10
  pymc_extras/distributions/histogram_utils.py,sha256=5RTvlGCUrp2qzshrchmPyWxjhs6RIYL62SMikjDM1TU,5814
11
- pymc_extras/distributions/timeseries.py,sha256=EJxWOrfuQlODwPN13Udgy2ras6vQKS0Ebus0pUuduaA,12680
11
+ pymc_extras/distributions/timeseries.py,sha256=M5MZ-nik_tgkaoZ1hdUGEZ9g04DQyVLwszVJqSKwNcY,12719
12
12
  pymc_extras/distributions/multivariate/__init__.py,sha256=E8OeLW9tTotCbrUjEo4um76-_WQD56PehsPzkKmhfyA,93
13
13
  pymc_extras/distributions/multivariate/r2d2m2cp.py,sha256=bUj9bB-hQi6CpaJfvJjgNPi727uTbvAdxl9fm1zNBqY,16005
14
14
  pymc_extras/gp/__init__.py,sha256=sFHw2y3lEl5tG_FDQHZUonQ_k0DF1JRf0Rp8dpHmge0,745
@@ -23,9 +23,9 @@ pymc_extras/inference/smc/sampling.py,sha256=AYwmKqGoV6pBtKnh9SUbBKbN7VcoFgb3MmN
23
23
  pymc_extras/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
24
24
  pymc_extras/model/model_api.py,sha256=_r6rYQG1tt9Z95QU-jVyHqZ1rs-u7sFMO5HJ5unDV5A,1750
25
25
  pymc_extras/model/marginal/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
26
- pymc_extras/model/marginal/distributions.py,sha256=hytO-CJqMXuvIfJhdfKNlf9tkblOwCp8hAWTiIS0cyU,12088
27
- pymc_extras/model/marginal/graph_analysis.py,sha256=LNx5N8ZE7d8Sq3BnlUYHrrwnxTLJTvehf8xYu95yrb8,15699
28
- pymc_extras/model/marginal/marginal_model.py,sha256=5DbhjlOAwY6JSMJUhUAPXPpybXQU0x7MzOGd3eXACYo,24854
26
+ pymc_extras/model/marginal/distributions.py,sha256=iM1yT7_BmivgUSloQPKE2QXGPgjvLqDMY_OTBGsdAWg,15563
27
+ pymc_extras/model/marginal/graph_analysis.py,sha256=0hWUH_PjfpgneQ3NaT__pWHS1fh50zNbI86kH4Nub0E,15693
28
+ pymc_extras/model/marginal/marginal_model.py,sha256=oNsiSWHjOPCTDxNEivEILLP_cOuBarm29Gr2p6hWHIM,23594
29
29
  pymc_extras/model/transforms/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
30
30
  pymc_extras/model/transforms/autoreparam.py,sha256=_NltGWmNqi_X9sHCqAvWcBveLTPxVy11-wENFTcN6kk,12377
31
31
  pymc_extras/preprocessing/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -44,7 +44,7 @@ pymc_extras/statespace/models/ETS.py,sha256=o039M-6aCxyMXbbKvUeNVZhheCKvvNIAmuj0
44
44
  pymc_extras/statespace/models/SARIMAX.py,sha256=SX0eiSK1pOt4dHBjWzBqVpRz67pBGLN5pQQgXcOiOgY,21607
45
45
  pymc_extras/statespace/models/VARMAX.py,sha256=xkIuftNc_5NHFpqZalExni99-1kovnzm5OjMIDNgaxY,15989
46
46
  pymc_extras/statespace/models/__init__.py,sha256=U79b8rTHBNijVvvGOd43nLu4PCloPUH1rwlN87-n88c,317
47
- pymc_extras/statespace/models/structural.py,sha256=W5FmImZyvHGxpFCYczfi6IVIjXDQkzaFjVKZSa6CiW8,63017
47
+ pymc_extras/statespace/models/structural.py,sha256=sep9pesJdRN4X8Bea6_RhO3112uWOZRuYRxO6ibl_OA,63943
48
48
  pymc_extras/statespace/models/utilities.py,sha256=G9GuHKsghmIYOlfkPtvxBWF-FZY5-5JI1fJQM8N7EnE,15373
49
49
  pymc_extras/statespace/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
50
50
  pymc_extras/statespace/utils/constants.py,sha256=Kf6j75ABaDQeRODxKQ76wTUQV4F5sTjn1KBcZgCQx20,2403
@@ -52,6 +52,7 @@ pymc_extras/statespace/utils/coord_tools.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm
52
52
  pymc_extras/statespace/utils/data_tools.py,sha256=caanvrxDu9g-dEKff2bbmaVTs6-71kkSoYIiiSUXhw4,5985
53
53
  pymc_extras/utils/__init__.py,sha256=yxI9cJ7fCtVQS0GFw0y6mDGZIQZiK53vm3UNKqIuGSk,758
54
54
  pymc_extras/utils/linear_cg.py,sha256=KkXhuimFsrKtNd_0By2ApxQQQNm5FdBtmDQJOVbLYkA,10056
55
+ pymc_extras/utils/model_equivalence.py,sha256=8QIftID2HDxD659i0RXHazQ-l2Q5YegCRLcDqb2p9Pc,2187
55
56
  pymc_extras/utils/pivoted_cholesky.py,sha256=QtnjP0pAl9b77fLAu-semwT4_9dcoiqx3dz1xKGBjMk,1871
56
57
  pymc_extras/utils/prior.py,sha256=QlWVr7uKIK9VncBw7Fz3YgaASKGDfqpORZHc-vz_9gQ,6841
57
58
  pymc_extras/utils/spline.py,sha256=qGq0gcoMG5dpdazKFzG0RXkkCWP8ADPPXN-653-oFn4,4820
@@ -63,11 +64,11 @@ tests/test_laplace.py,sha256=5ioEyP6AzmMszrtQRz0KWTsCCU35SEhSOdBcYfYzptE,8228
63
64
  tests/test_linearmodel.py,sha256=iB8ApNqIX9_nUHoo-Tm51xuPdrva5t4VLLut6qXB5Ao,6906
64
65
  tests/test_model_builder.py,sha256=QiINEihBR9rx8xM4Nqlg4urZKoyo58aTKDtxl9SJF1s,11249
65
66
  tests/test_pathfinder.py,sha256=FBm0ge6rje5jz9_10h_247E70aKCpkbu1jmzrR7Ar8A,1726
66
- tests/test_pivoted_cholesky.py,sha256=7_thrb90_an_S3boYr0mu4NNOhjiI6AaZ2ADn53sBX8,698
67
+ tests/test_pivoted_cholesky.py,sha256=PuMdMSCzO4KdQWpUF4SEBeuH_qsINCIH8TYtmmJ1NKo,692
67
68
  tests/test_printing.py,sha256=HnvwwjrjBuxXFAJdyU0K_lvKGLgh4nzHAnhsIUpenbY,5211
68
69
  tests/test_prior_from_trace.py,sha256=HOzR3l98pl7TEJquo_kSugED4wBTgHo4-8lgnpmacs8,5516
69
70
  tests/test_splines.py,sha256=xSZi4hqqReN1H8LHr0xjDmpomhDQm8auIsWQjFOyjbM,2608
70
- tests/utils.py,sha256=cRPe0ovsexxOQ6dK94xao_Kv5qcPrqtFWOFXBqubHqY,1257
71
+ tests/utils.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
71
72
  tests/distributions/__init__.py,sha256=jt-oloszTLNFwi9AgU3M4m6xKQ8xpQE338rmmaMZcMs,795
72
73
  tests/distributions/test_continuous.py,sha256=1-bu-IP6RgLUJnuPYpOD8ZS1ahYbKtsJ9oflBfqCaFo,5477
73
74
  tests/distributions/test_discrete.py,sha256=CjjaUpppsvQ6zLzV15ZsbwNOKrDmEdz4VWcleoCXUi0,7776
@@ -76,9 +77,9 @@ tests/distributions/test_multivariate.py,sha256=LBvBuoT_3rzi8rR38b8L441Y-9Ff0cIX
76
77
  tests/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
77
78
  tests/model/test_model_api.py,sha256=SiOMA1NpyQKJ7stYI1ms8ksDPU81lVo8wS8hbqiik-U,776
78
79
  tests/model/marginal/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
79
- tests/model/marginal/test_distributions.py,sha256=UWoXi1h-e0QG0fLxPXkUj2LXzNb5uhoHicIBIVDIhL4,5126
80
+ tests/model/marginal/test_distributions.py,sha256=p5f73g4ogxYkdZaBndZV_1ra8TCppXiRlUpaaTwEe-M,5195
80
81
  tests/model/marginal/test_graph_analysis.py,sha256=raoj41NusMOj1zzPCrxrlQODqX6Ey8Ft_o32pNTe5qg,6712
81
- tests/model/marginal/test_marginal_model.py,sha256=IWmF_XkbQeI15H0uepi1zET1Zffj6qVgAAdAiUoxrhA,31009
82
+ tests/model/marginal/test_marginal_model.py,sha256=uOmARalkdWq3sDbnJQ0KjiLwviqauZOAnafmYS_Cnd8,35475
82
83
  tests/statespace/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
83
84
  tests/statespace/test_ETS.py,sha256=IPg3uQ7xEGqDMEHu993vtUTV7r-uNAxmw23sr5MVGfQ,15582
84
85
  tests/statespace/test_SARIMAX.py,sha256=1BYNOm9aSHnpn-qbpe3YsQVH8m-mXcp_gvKgWhWn1W4,12948
@@ -89,13 +90,13 @@ tests/statespace/test_kalman_filter.py,sha256=s2n62FzXl9elU_uqaMNaEaexUfq3SXe3_Y
89
90
  tests/statespace/test_representation.py,sha256=1KAJY4ZaVhb1WdAJLx2UYSXuVYsMNWX98gEDF7P0B4s,6210
90
91
  tests/statespace/test_statespace.py,sha256=8ZLLQaxlP5UEJnIMYyIzzAODCxMxs6E5I1hLu2HCdqo,28866
91
92
  tests/statespace/test_statespace_JAX.py,sha256=hZOc6xxYdVeATPCKmcHMLOVcuvdzGRzgQQ4RrDenwk8,5279
92
- tests/statespace/test_structural.py,sha256=IN6OQuq7bZVCDiws3Yhsa4IoPyfLaONzgzdvImp0Zcc,29036
93
+ tests/statespace/test_structural.py,sha256=HD8OaGbjuH4y3xv_uG-R1xLZpPpcb4-3dbcTeb_imLY,29306
93
94
  tests/statespace/utilities/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
94
95
  tests/statespace/utilities/shared_fixtures.py,sha256=SNw8Bvj1Yw11TxAW6n20Bq0B8oaYtVTiFFEVNH_wnp4,164
95
96
  tests/statespace/utilities/statsmodel_local_level.py,sha256=SQAzaYaSDwiVhUQ1iWjt4MgfAd54RuzVtnslIs3xdS8,1225
96
97
  tests/statespace/utilities/test_helpers.py,sha256=oH24a6Q45NFFFI3Kx9mhKbxsCvo9ErCorKFoTjDB3-4,9159
97
- pymc_extras-0.2.0.dist-info/LICENSE,sha256=WjiLhUKEysJvy5e9jk6WwFv9tmAPtnov1uJ6gcH1kIs,11720
98
- pymc_extras-0.2.0.dist-info/METADATA,sha256=DpLFVnYDJouXKp6hsIfTtE6rBgJfcLvRdq5CsEKaNVQ,4899
99
- pymc_extras-0.2.0.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
100
- pymc_extras-0.2.0.dist-info/top_level.txt,sha256=D6RkgBiXiZCel0nvsYg_zYEoT_VuwocyIY98EMaulj0,18
101
- pymc_extras-0.2.0.dist-info/RECORD,,
98
+ pymc_extras-0.2.1.dist-info/LICENSE,sha256=WjiLhUKEysJvy5e9jk6WwFv9tmAPtnov1uJ6gcH1kIs,11720
99
+ pymc_extras-0.2.1.dist-info/METADATA,sha256=pT1MOjFxsX6lc0q_D3J-2jNW6UaRkCq_0kJemgG4DGU,4894
100
+ pymc_extras-0.2.1.dist-info/WHEEL,sha256=A3WOREP4zgxI0fKrHUG8DC8013e3dK3n7a6HDbcEIwE,91
101
+ pymc_extras-0.2.1.dist-info/top_level.txt,sha256=D6RkgBiXiZCel0nvsYg_zYEoT_VuwocyIY98EMaulj0,18
102
+ pymc_extras-0.2.1.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.6.0)
2
+ Generator: setuptools (75.7.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -6,7 +6,7 @@ from pymc.logprob.abstract import _logprob
6
6
  from pytensor import tensor as pt
7
7
  from scipy.stats import norm
8
8
 
9
- from pymc_extras import MarginalModel
9
+ from pymc_extras import marginalize
10
10
  from pymc_extras.distributions import DiscreteMarkovChain
11
11
  from pymc_extras.model.marginal.distributions import MarginalFiniteDiscreteRV
12
12
 
@@ -21,6 +21,7 @@ def test_marginalized_bernoulli_logp():
21
21
  [mu],
22
22
  [idx, y],
23
23
  dims_connections=(((),),),
24
+ dims=(),
24
25
  )(mu)[0].owner
25
26
 
26
27
  y_vv = y.clone()
@@ -43,7 +44,7 @@ def test_marginalized_hmm_normal_emission(batch_chain, batch_emission):
43
44
  if batch_chain and not batch_emission:
44
45
  pytest.skip("Redundant implicit combination")
45
46
 
46
- with MarginalModel() as m:
47
+ with pm.Model() as m:
47
48
  P = [[0, 1], [1, 0]]
48
49
  init_dist = pm.Categorical.dist(p=[1, 0])
49
50
  chain = DiscreteMarkovChain(
@@ -53,8 +54,8 @@ def test_marginalized_hmm_normal_emission(batch_chain, batch_emission):
53
54
  "emission", mu=chain * 2 - 1, sigma=1e-1, shape=(3, 4) if batch_emission else None
54
55
  )
55
56
 
56
- m.marginalize([chain])
57
- logp_fn = m.compile_logp()
57
+ marginal_m = marginalize(m, [chain])
58
+ logp_fn = marginal_m.compile_logp()
58
59
 
59
60
  test_value = np.array([-1, 1, -1, 1])
60
61
  expected_logp = pm.logp(pm.Normal.dist(0, 1e-1), np.zeros_like(test_value)).sum().eval()
@@ -70,7 +71,7 @@ def test_marginalized_hmm_normal_emission(batch_chain, batch_emission):
70
71
  )
71
72
  def test_marginalized_hmm_categorical_emission(categorical_emission):
72
73
  """Example adapted from https://www.youtube.com/watch?v=9-sPm4CfcD0"""
73
- with MarginalModel() as m:
74
+ with pm.Model() as m:
74
75
  P = np.array([[0.5, 0.5], [0.3, 0.7]])
75
76
  init_dist = pm.Categorical.dist(p=[0.375, 0.625])
76
77
  chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, steps=2)
@@ -78,11 +79,11 @@ def test_marginalized_hmm_categorical_emission(categorical_emission):
78
79
  emission = pm.Categorical("emission", p=pt.constant([[0.8, 0.2], [0.4, 0.6]])[chain])
79
80
  else:
80
81
  emission = pm.Bernoulli("emission", p=pt.where(pt.eq(chain, 0), 0.2, 0.6))
81
- m.marginalize([chain])
82
+ marginal_m = marginalize(m, [chain])
82
83
 
83
84
  test_value = np.array([0, 0, 1])
84
85
  expected_logp = np.log(0.1344) # Shown at the 10m22s mark in the video
85
- logp_fn = m.compile_logp()
86
+ logp_fn = marginal_m.compile_logp()
86
87
  np.testing.assert_allclose(logp_fn({"emission": test_value}), expected_logp)
87
88
 
88
89
 
@@ -95,7 +96,7 @@ def test_marginalized_hmm_multiple_emissions(batch_chain, batch_emission1, batch
95
96
  (2, *reversed(chain_shape)) if batch_emission1 else tuple(reversed(chain_shape))
96
97
  )
97
98
  emission2_shape = (*chain_shape, 2) if batch_emission2 else chain_shape
98
- with MarginalModel() as m:
99
+ with pm.Model() as m:
99
100
  P = [[0, 1], [1, 0]]
100
101
  init_dist = pm.Categorical.dist(p=[1, 0])
101
102
  chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, shape=chain_shape)
@@ -108,10 +109,10 @@ def test_marginalized_hmm_multiple_emissions(batch_chain, batch_emission1, batch
108
109
  emission2_mu = emission2_mu[..., None]
109
110
  emission_2 = pm.Normal("emission_2", mu=emission2_mu, sigma=1e-1, shape=emission2_shape)
110
111
 
111
- with pytest.warns(UserWarning, match="multiple dependent variables"):
112
- m.marginalize([chain])
112
+ marginal_m = marginalize(m, [chain])
113
113
 
114
- logp_fn = m.compile_logp(sum=False)
114
+ with pytest.warns(UserWarning, match="multiple dependent variables"):
115
+ logp_fn = marginal_m.compile_logp(sum=False)
115
116
 
116
117
  test_value = np.array([-1, 1, -1, 1])
117
118
  multiplier = 2 + batch_emission1 + batch_emission2