pymc-extras 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to its public registry. It is provided for informational purposes only.
- pymc_extras/__init__.py +5 -1
- pymc_extras/distributions/timeseries.py +1 -1
- pymc_extras/model/marginal/distributions.py +100 -3
- pymc_extras/model/marginal/graph_analysis.py +8 -9
- pymc_extras/model/marginal/marginal_model.py +437 -424
- pymc_extras/statespace/models/structural.py +21 -6
- pymc_extras/utils/model_equivalence.py +66 -0
- pymc_extras/version.txt +1 -1
- {pymc_extras-0.2.0.dist-info → pymc_extras-0.2.1.dist-info}/METADATA +3 -4
- {pymc_extras-0.2.0.dist-info → pymc_extras-0.2.1.dist-info}/RECORD +18 -17
- {pymc_extras-0.2.0.dist-info → pymc_extras-0.2.1.dist-info}/WHEEL +1 -1
- tests/model/marginal/test_distributions.py +12 -11
- tests/model/marginal/test_marginal_model.py +301 -201
- tests/statespace/test_structural.py +10 -3
- tests/test_pivoted_cholesky.py +1 -1
- tests/utils.py +0 -31
- {pymc_extras-0.2.0.dist-info → pymc_extras-0.2.1.dist-info}/LICENSE +0 -0
- {pymc_extras-0.2.0.dist-info → pymc_extras-0.2.1.dist-info}/top_level.txt +0 -0
@@ -1071,6 +1071,12 @@ class TimeSeasonality(Component):
 
         If None, states will be numbered ``[State_0, ..., State_s]``
 
+    remove_first_state: bool, default True
+        If True, the first state will be removed from the model. This is done because there are only n-1 degrees of
+        freedom in the seasonal component, and one state is not identified. If False, the first state will be
+        included in the model, but it will not be identified -- you will need to handle this in the priors (e.g. with
+        ZeroSumNormal).
+
     Notes
     -----
     A seasonal effect is any pattern that repeats every fixed interval. Although there are many possible ways to
@@ -1163,7 +1169,7 @@ class TimeSeasonality(Component):
         innovations: bool = True,
         name: str | None = None,
         state_names: list | None = None,
-
+        remove_first_state: bool = True,
     ):
         if name is None:
             name = f"Seasonal[s={season_length}]"
@@ -1176,14 +1182,15 @@ class TimeSeasonality(Component):
         )
         state_names = state_names.copy()
         self.innovations = innovations
-        self.
+        self.remove_first_state = remove_first_state
 
-        if self.
+        if self.remove_first_state:
             # In traditional models, the first state isn't identified, so we can help out the user by automatically
             # discarding it.
             # TODO: Can this be stashed and reconstructed automatically somehow?
             state_names.pop(0)
-
+
+        k_states = season_length - int(self.remove_first_state)
 
         super().__init__(
             name=name,
@@ -1218,8 +1225,16 @@ class TimeSeasonality(Component):
         self.shock_names = [f"{self.name}"]
 
     def make_symbolic_graph(self) -> None:
-
-
+        if self.remove_first_state:
+            # In this case, parameters are normalized to sum to zero, so the current state is the negative sum of
+            # all previous states.
+            T = np.eye(self.k_states, k=-1)
+            T[0, :] = -1
+        else:
+            # In this case we assume the user to be responsible for ensuring the states sum to zero, so T is just a
+            # circulant matrix that cycles between the states.
+            T = np.eye(self.k_states, k=1)
+            T[-1, 0] = 1
 
         self.ssm["transition", :, :] = T
         self.ssm["design", 0, 0] = 1
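The effect of the new `remove_first_state` flag is easiest to see with the two transition matrices written out. Below is a minimal numpy sketch of the matrices built in the branches above; the `season_length = 4` value and the example state vectors are purely illustrative and not part of the package:

```python
import numpy as np

season_length = 4

# remove_first_state=True (the default): k_states = season_length - 1, and the
# next seasonal state is minus the sum of the previous ones, so the seasonal
# effects sum to zero by construction.
T_reduced = np.eye(season_length - 1, k=-1)
T_reduced[0, :] = -1

# remove_first_state=False: k_states = season_length, and T simply rotates
# through the states; the zero-sum constraint is left to the prior
# (e.g. a ZeroSumNormal on the initial seasonal states).
T_full = np.eye(season_length, k=1)
T_full[-1, 0] = 1

print(T_reduced @ np.array([3.0, -1.0, -2.0]))    # [ 0.  3. -1.]: new state = -(3 - 1 - 2)
print(T_full @ np.array([3.0, -1.0, -2.0, 0.0]))  # [-1. -2.  0.  3.]: pure rotation
```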
pymc_extras/utils/model_equivalence.py ADDED
@@ -0,0 +1,66 @@
+from collections.abc import Sequence
+
+from pymc.model.core import Model
+from pymc.model.fgraph import fgraph_from_model
+from pytensor import Variable
+from pytensor.compile import SharedVariable
+from pytensor.graph import Constant, graph_inputs
+from pytensor.graph.basic import equal_computations
+from pytensor.tensor.random.type import RandomType
+
+
+def equal_computations_up_to_root(
+    xs: Sequence[Variable], ys: Sequence[Variable], ignore_rng_values=True
+) -> bool:
+    # Check if graphs are equivalent even if root variables have distinct identities
+
+    x_graph_inputs = [var for var in graph_inputs(xs) if not isinstance(var, Constant)]
+    y_graph_inputs = [var for var in graph_inputs(ys) if not isinstance(var, Constant)]
+    if len(x_graph_inputs) != len(y_graph_inputs):
+        return False
+    for x, y in zip(x_graph_inputs, y_graph_inputs):
+        if x.type != y.type:
+            return False
+        if x.name != y.name:
+            return False
+        if isinstance(x, SharedVariable):
+            # if not isinstance(y, SharedVariable):
+            #     return False
+            if isinstance(x.type, RandomType) and ignore_rng_values:
+                continue
+            if not x.type.values_eq(x.get_value(), y.get_value()):
+                return False
+
+    return equal_computations(xs, ys, in_xs=x_graph_inputs, in_ys=y_graph_inputs)
+
+
+def equivalent_models(model1: Model, model2: Model) -> bool:
+    """Check whether two PyMC models are equivalent.
+
+    Examples
+    --------
+
+    .. code-block:: python
+
+        import pymc as pm
+        from pymc_extras.utils.model_equivalence import equivalent_models
+
+        with pm.Model() as m1:
+            x = pm.Normal("x")
+            y = pm.Normal("y", x)
+
+        with pm.Model() as m2:
+            x = pm.Normal("x")
+            y = pm.Normal("y", x + 1)
+
+        with pm.Model() as m3:
+            x = pm.Normal("x")
+            y = pm.Normal("y", x)
+
+        assert not equivalent_models(m1, m2)
+        assert equivalent_models(m1, m3)
+
+    """
+    fgraph1, _ = fgraph_from_model(model1)
+    fgraph2, _ = fgraph_from_model(model2)
+    return equal_computations_up_to_root(fgraph1.outputs, fgraph2.outputs)
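The docstring example above covers whole models; the lower-level helper can also be exercised directly on PyTensor graphs, since it matches root variables by position, type, and name rather than by identity. A minimal sketch with toy variables (not part of the package):

```python
import pytensor.tensor as pt

from pymc_extras.utils.model_equivalence import equal_computations_up_to_root

# Two graphs built from distinct root variables that share the same type and name
x1 = pt.scalar("x")
x2 = pt.scalar("x")

print(equal_computations_up_to_root([x1 + 1], [x2 + 1]))  # True: same computation
print(equal_computations_up_to_root([x1 + 1], [x2 + 2]))  # False: different constant
```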
pymc_extras/version.txt CHANGED
@@ -1 +1 @@
-0.2.0
+0.2.1
{pymc_extras-0.2.0.dist-info → pymc_extras-0.2.1.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: pymc-extras
-Version: 0.2.0
+Version: 0.2.1
 Summary: A home for new additions to PyMC, which may include unusual probability distribitions, advanced model fitting algorithms, or any code that may be inappropriate to include in the pymc repository, but may want to be made available to users.
 Home-page: http://github.com/pymc-devs/pymc-extras
 Maintainer: PyMC Developers
@@ -63,10 +63,9 @@ import pymc as pm
 import pymc_extras as pmx
 
 with pm.Model():
+    alpha = pmx.ParabolicFractal('alpha', b=1, c=1)
 
-
-
-    ...
+    ...
 
 ```
 
{pymc_extras-0.2.0.dist-info → pymc_extras-0.2.1.dist-info}/RECORD CHANGED
@@ -1,14 +1,14 @@
-pymc_extras/__init__.py,sha256=
+pymc_extras/__init__.py,sha256=URh185f6b1xp2Taj2W2NJuW_hErKufBcLeQ0WDCyaNk,1160
 pymc_extras/linearmodel.py,sha256=6eitl15Ec15mSZu7zoHZ7Wwy4U1DPwqfAgwEt6ILeIc,3920
 pymc_extras/model_builder.py,sha256=sAw77fxdiy046BvDPjocuMlbJ0Efj-CDAGtmcwYmoG0,26361
 pymc_extras/printing.py,sha256=G8mj9dRd6i0PcsbcEWZm56ek6V8mmil78RI4MUhywBs,6506
 pymc_extras/version.py,sha256=VxPGCBzhtSegu-Jp5cjzn0n4DGU0wuPUh-KyZKB6uPM,240
-pymc_extras/version.txt,sha256=
+pymc_extras/version.txt,sha256=cQFcl5zLD8igvnygroMEarBFzcLI-qCfsvD35ED5tKY,6
 pymc_extras/distributions/__init__.py,sha256=gTX7tvX8NcgP7V72URV7GeqF1aAEjGVbuW8LMxhXceY,1295
 pymc_extras/distributions/continuous.py,sha256=z-nvQgGncYISdRY8cWsa-56V0bQGq70jYwU-i8VZ0Uk,11253
 pymc_extras/distributions/discrete.py,sha256=vrARNuiQAEXrs7yQgImV1PO8AV1uyEC_LBhr6F9IcOg,13032
 pymc_extras/distributions/histogram_utils.py,sha256=5RTvlGCUrp2qzshrchmPyWxjhs6RIYL62SMikjDM1TU,5814
-pymc_extras/distributions/timeseries.py,sha256=
+pymc_extras/distributions/timeseries.py,sha256=M5MZ-nik_tgkaoZ1hdUGEZ9g04DQyVLwszVJqSKwNcY,12719
 pymc_extras/distributions/multivariate/__init__.py,sha256=E8OeLW9tTotCbrUjEo4um76-_WQD56PehsPzkKmhfyA,93
 pymc_extras/distributions/multivariate/r2d2m2cp.py,sha256=bUj9bB-hQi6CpaJfvJjgNPi727uTbvAdxl9fm1zNBqY,16005
 pymc_extras/gp/__init__.py,sha256=sFHw2y3lEl5tG_FDQHZUonQ_k0DF1JRf0Rp8dpHmge0,745
@@ -23,9 +23,9 @@ pymc_extras/inference/smc/sampling.py,sha256=AYwmKqGoV6pBtKnh9SUbBKbN7VcoFgb3MmN
 pymc_extras/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 pymc_extras/model/model_api.py,sha256=_r6rYQG1tt9Z95QU-jVyHqZ1rs-u7sFMO5HJ5unDV5A,1750
 pymc_extras/model/marginal/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-pymc_extras/model/marginal/distributions.py,sha256=
-pymc_extras/model/marginal/graph_analysis.py,sha256=
-pymc_extras/model/marginal/marginal_model.py,sha256=
+pymc_extras/model/marginal/distributions.py,sha256=iM1yT7_BmivgUSloQPKE2QXGPgjvLqDMY_OTBGsdAWg,15563
+pymc_extras/model/marginal/graph_analysis.py,sha256=0hWUH_PjfpgneQ3NaT__pWHS1fh50zNbI86kH4Nub0E,15693
+pymc_extras/model/marginal/marginal_model.py,sha256=oNsiSWHjOPCTDxNEivEILLP_cOuBarm29Gr2p6hWHIM,23594
 pymc_extras/model/transforms/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 pymc_extras/model/transforms/autoreparam.py,sha256=_NltGWmNqi_X9sHCqAvWcBveLTPxVy11-wENFTcN6kk,12377
 pymc_extras/preprocessing/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -44,7 +44,7 @@ pymc_extras/statespace/models/ETS.py,sha256=o039M-6aCxyMXbbKvUeNVZhheCKvvNIAmuj0
 pymc_extras/statespace/models/SARIMAX.py,sha256=SX0eiSK1pOt4dHBjWzBqVpRz67pBGLN5pQQgXcOiOgY,21607
 pymc_extras/statespace/models/VARMAX.py,sha256=xkIuftNc_5NHFpqZalExni99-1kovnzm5OjMIDNgaxY,15989
 pymc_extras/statespace/models/__init__.py,sha256=U79b8rTHBNijVvvGOd43nLu4PCloPUH1rwlN87-n88c,317
-pymc_extras/statespace/models/structural.py,sha256=
+pymc_extras/statespace/models/structural.py,sha256=sep9pesJdRN4X8Bea6_RhO3112uWOZRuYRxO6ibl_OA,63943
 pymc_extras/statespace/models/utilities.py,sha256=G9GuHKsghmIYOlfkPtvxBWF-FZY5-5JI1fJQM8N7EnE,15373
 pymc_extras/statespace/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 pymc_extras/statespace/utils/constants.py,sha256=Kf6j75ABaDQeRODxKQ76wTUQV4F5sTjn1KBcZgCQx20,2403
@@ -52,6 +52,7 @@ pymc_extras/statespace/utils/coord_tools.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm
 pymc_extras/statespace/utils/data_tools.py,sha256=caanvrxDu9g-dEKff2bbmaVTs6-71kkSoYIiiSUXhw4,5985
 pymc_extras/utils/__init__.py,sha256=yxI9cJ7fCtVQS0GFw0y6mDGZIQZiK53vm3UNKqIuGSk,758
 pymc_extras/utils/linear_cg.py,sha256=KkXhuimFsrKtNd_0By2ApxQQQNm5FdBtmDQJOVbLYkA,10056
+pymc_extras/utils/model_equivalence.py,sha256=8QIftID2HDxD659i0RXHazQ-l2Q5YegCRLcDqb2p9Pc,2187
 pymc_extras/utils/pivoted_cholesky.py,sha256=QtnjP0pAl9b77fLAu-semwT4_9dcoiqx3dz1xKGBjMk,1871
 pymc_extras/utils/prior.py,sha256=QlWVr7uKIK9VncBw7Fz3YgaASKGDfqpORZHc-vz_9gQ,6841
 pymc_extras/utils/spline.py,sha256=qGq0gcoMG5dpdazKFzG0RXkkCWP8ADPPXN-653-oFn4,4820
@@ -63,11 +64,11 @@ tests/test_laplace.py,sha256=5ioEyP6AzmMszrtQRz0KWTsCCU35SEhSOdBcYfYzptE,8228
 tests/test_linearmodel.py,sha256=iB8ApNqIX9_nUHoo-Tm51xuPdrva5t4VLLut6qXB5Ao,6906
 tests/test_model_builder.py,sha256=QiINEihBR9rx8xM4Nqlg4urZKoyo58aTKDtxl9SJF1s,11249
 tests/test_pathfinder.py,sha256=FBm0ge6rje5jz9_10h_247E70aKCpkbu1jmzrR7Ar8A,1726
-tests/test_pivoted_cholesky.py,sha256=
+tests/test_pivoted_cholesky.py,sha256=PuMdMSCzO4KdQWpUF4SEBeuH_qsINCIH8TYtmmJ1NKo,692
 tests/test_printing.py,sha256=HnvwwjrjBuxXFAJdyU0K_lvKGLgh4nzHAnhsIUpenbY,5211
 tests/test_prior_from_trace.py,sha256=HOzR3l98pl7TEJquo_kSugED4wBTgHo4-8lgnpmacs8,5516
 tests/test_splines.py,sha256=xSZi4hqqReN1H8LHr0xjDmpomhDQm8auIsWQjFOyjbM,2608
-tests/utils.py,sha256=
+tests/utils.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 tests/distributions/__init__.py,sha256=jt-oloszTLNFwi9AgU3M4m6xKQ8xpQE338rmmaMZcMs,795
 tests/distributions/test_continuous.py,sha256=1-bu-IP6RgLUJnuPYpOD8ZS1ahYbKtsJ9oflBfqCaFo,5477
 tests/distributions/test_discrete.py,sha256=CjjaUpppsvQ6zLzV15ZsbwNOKrDmEdz4VWcleoCXUi0,7776
@@ -76,9 +77,9 @@ tests/distributions/test_multivariate.py,sha256=LBvBuoT_3rzi8rR38b8L441Y-9Ff0cIX
 tests/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 tests/model/test_model_api.py,sha256=SiOMA1NpyQKJ7stYI1ms8ksDPU81lVo8wS8hbqiik-U,776
 tests/model/marginal/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-tests/model/marginal/test_distributions.py,sha256=
+tests/model/marginal/test_distributions.py,sha256=p5f73g4ogxYkdZaBndZV_1ra8TCppXiRlUpaaTwEe-M,5195
 tests/model/marginal/test_graph_analysis.py,sha256=raoj41NusMOj1zzPCrxrlQODqX6Ey8Ft_o32pNTe5qg,6712
-tests/model/marginal/test_marginal_model.py,sha256=
+tests/model/marginal/test_marginal_model.py,sha256=uOmARalkdWq3sDbnJQ0KjiLwviqauZOAnafmYS_Cnd8,35475
 tests/statespace/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 tests/statespace/test_ETS.py,sha256=IPg3uQ7xEGqDMEHu993vtUTV7r-uNAxmw23sr5MVGfQ,15582
 tests/statespace/test_SARIMAX.py,sha256=1BYNOm9aSHnpn-qbpe3YsQVH8m-mXcp_gvKgWhWn1W4,12948
@@ -89,13 +90,13 @@ tests/statespace/test_kalman_filter.py,sha256=s2n62FzXl9elU_uqaMNaEaexUfq3SXe3_Y
 tests/statespace/test_representation.py,sha256=1KAJY4ZaVhb1WdAJLx2UYSXuVYsMNWX98gEDF7P0B4s,6210
 tests/statespace/test_statespace.py,sha256=8ZLLQaxlP5UEJnIMYyIzzAODCxMxs6E5I1hLu2HCdqo,28866
 tests/statespace/test_statespace_JAX.py,sha256=hZOc6xxYdVeATPCKmcHMLOVcuvdzGRzgQQ4RrDenwk8,5279
-tests/statespace/test_structural.py,sha256=
+tests/statespace/test_structural.py,sha256=HD8OaGbjuH4y3xv_uG-R1xLZpPpcb4-3dbcTeb_imLY,29306
 tests/statespace/utilities/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 tests/statespace/utilities/shared_fixtures.py,sha256=SNw8Bvj1Yw11TxAW6n20Bq0B8oaYtVTiFFEVNH_wnp4,164
 tests/statespace/utilities/statsmodel_local_level.py,sha256=SQAzaYaSDwiVhUQ1iWjt4MgfAd54RuzVtnslIs3xdS8,1225
 tests/statespace/utilities/test_helpers.py,sha256=oH24a6Q45NFFFI3Kx9mhKbxsCvo9ErCorKFoTjDB3-4,9159
-pymc_extras-0.2.
-pymc_extras-0.2.
-pymc_extras-0.2.
-pymc_extras-0.2.
-pymc_extras-0.2.
+pymc_extras-0.2.1.dist-info/LICENSE,sha256=WjiLhUKEysJvy5e9jk6WwFv9tmAPtnov1uJ6gcH1kIs,11720
+pymc_extras-0.2.1.dist-info/METADATA,sha256=pT1MOjFxsX6lc0q_D3J-2jNW6UaRkCq_0kJemgG4DGU,4894
+pymc_extras-0.2.1.dist-info/WHEEL,sha256=A3WOREP4zgxI0fKrHUG8DC8013e3dK3n7a6HDbcEIwE,91
+pymc_extras-0.2.1.dist-info/top_level.txt,sha256=D6RkgBiXiZCel0nvsYg_zYEoT_VuwocyIY98EMaulj0,18
+pymc_extras-0.2.1.dist-info/RECORD,,
tests/model/marginal/test_distributions.py CHANGED
@@ -6,7 +6,7 @@ from pymc.logprob.abstract import _logprob
 from pytensor import tensor as pt
 from scipy.stats import norm
 
-from pymc_extras import
+from pymc_extras import marginalize
 from pymc_extras.distributions import DiscreteMarkovChain
 from pymc_extras.model.marginal.distributions import MarginalFiniteDiscreteRV
 
@@ -21,6 +21,7 @@ def test_marginalized_bernoulli_logp():
         [mu],
         [idx, y],
         dims_connections=(((),),),
+        dims=(),
     )(mu)[0].owner
 
     y_vv = y.clone()
@@ -43,7 +44,7 @@ def test_marginalized_hmm_normal_emission(batch_chain, batch_emission):
     if batch_chain and not batch_emission:
         pytest.skip("Redundant implicit combination")
 
-    with
+    with pm.Model() as m:
         P = [[0, 1], [1, 0]]
         init_dist = pm.Categorical.dist(p=[1, 0])
         chain = DiscreteMarkovChain(
@@ -53,8 +54,8 @@
             "emission", mu=chain * 2 - 1, sigma=1e-1, shape=(3, 4) if batch_emission else None
         )
 
-
-    logp_fn =
+    marginal_m = marginalize(m, [chain])
+    logp_fn = marginal_m.compile_logp()
 
     test_value = np.array([-1, 1, -1, 1])
     expected_logp = pm.logp(pm.Normal.dist(0, 1e-1), np.zeros_like(test_value)).sum().eval()
@@ -70,7 +71,7 @@
 )
 def test_marginalized_hmm_categorical_emission(categorical_emission):
     """Example adapted from https://www.youtube.com/watch?v=9-sPm4CfcD0"""
-    with
+    with pm.Model() as m:
         P = np.array([[0.5, 0.5], [0.3, 0.7]])
         init_dist = pm.Categorical.dist(p=[0.375, 0.625])
         chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, steps=2)
@@ -78,11 +79,11 @@
             emission = pm.Categorical("emission", p=pt.constant([[0.8, 0.2], [0.4, 0.6]])[chain])
         else:
             emission = pm.Bernoulli("emission", p=pt.where(pt.eq(chain, 0), 0.2, 0.6))
-
+    marginal_m = marginalize(m, [chain])
 
     test_value = np.array([0, 0, 1])
     expected_logp = np.log(0.1344)  # Shown at the 10m22s mark in the video
-    logp_fn =
+    logp_fn = marginal_m.compile_logp()
     np.testing.assert_allclose(logp_fn({"emission": test_value}), expected_logp)
 
 
@@ -95,7 +96,7 @@ def test_marginalized_hmm_multiple_emissions(batch_chain, batch_emission1, batch
         (2, *reversed(chain_shape)) if batch_emission1 else tuple(reversed(chain_shape))
     )
     emission2_shape = (*chain_shape, 2) if batch_emission2 else chain_shape
-    with
+    with pm.Model() as m:
         P = [[0, 1], [1, 0]]
         init_dist = pm.Categorical.dist(p=[1, 0])
         chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, shape=chain_shape)
@@ -108,10 +109,10 @@
             emission2_mu = emission2_mu[..., None]
         emission_2 = pm.Normal("emission_2", mu=emission2_mu, sigma=1e-1, shape=emission2_shape)
 
-
-    m.marginalize([chain])
+    marginal_m = marginalize(m, [chain])
 
-
+    with pytest.warns(UserWarning, match="multiple dependent variables"):
+        logp_fn = marginal_m.compile_logp(sum=False)
 
     test_value = np.array([-1, 1, -1, 1])
     multiplier = 2 + batch_emission1 + batch_emission2