pymc-extras 0.2.4__py3-none-any.whl → 0.2.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pymc_extras/__init__.py +6 -4
- pymc_extras/distributions/__init__.py +2 -0
- pymc_extras/distributions/continuous.py +3 -2
- pymc_extras/distributions/discrete.py +3 -1
- pymc_extras/distributions/transforms/__init__.py +3 -0
- pymc_extras/distributions/transforms/partial_order.py +227 -0
- pymc_extras/inference/__init__.py +4 -2
- pymc_extras/inference/find_map.py +62 -17
- pymc_extras/inference/fit.py +6 -4
- pymc_extras/inference/laplace.py +14 -8
- pymc_extras/inference/pathfinder/lbfgs.py +49 -13
- pymc_extras/inference/pathfinder/pathfinder.py +89 -103
- pymc_extras/statespace/core/statespace.py +191 -52
- pymc_extras/statespace/filters/distributions.py +15 -16
- pymc_extras/statespace/filters/kalman_filter.py +1 -18
- pymc_extras/statespace/filters/kalman_smoother.py +2 -6
- pymc_extras/statespace/models/ETS.py +10 -0
- pymc_extras/statespace/models/SARIMAX.py +26 -5
- pymc_extras/statespace/models/VARMAX.py +12 -2
- pymc_extras/statespace/models/structural.py +18 -5
- pymc_extras/statespace/utils/data_tools.py +24 -9
- pymc_extras-0.2.6.dist-info/METADATA +318 -0
- pymc_extras-0.2.6.dist-info/RECORD +65 -0
- {pymc_extras-0.2.4.dist-info → pymc_extras-0.2.6.dist-info}/WHEEL +1 -2
- pymc_extras/version.py +0 -11
- pymc_extras/version.txt +0 -1
- pymc_extras-0.2.4.dist-info/METADATA +0 -110
- pymc_extras-0.2.4.dist-info/RECORD +0 -105
- pymc_extras-0.2.4.dist-info/top_level.txt +0 -2
- tests/__init__.py +0 -13
- tests/distributions/__init__.py +0 -19
- tests/distributions/test_continuous.py +0 -185
- tests/distributions/test_discrete.py +0 -210
- tests/distributions/test_discrete_markov_chain.py +0 -258
- tests/distributions/test_multivariate.py +0 -304
- tests/model/__init__.py +0 -0
- tests/model/marginal/__init__.py +0 -0
- tests/model/marginal/test_distributions.py +0 -132
- tests/model/marginal/test_graph_analysis.py +0 -182
- tests/model/marginal/test_marginal_model.py +0 -967
- tests/model/test_model_api.py +0 -38
- tests/statespace/__init__.py +0 -0
- tests/statespace/test_ETS.py +0 -411
- tests/statespace/test_SARIMAX.py +0 -405
- tests/statespace/test_VARMAX.py +0 -184
- tests/statespace/test_coord_assignment.py +0 -116
- tests/statespace/test_distributions.py +0 -270
- tests/statespace/test_kalman_filter.py +0 -326
- tests/statespace/test_representation.py +0 -175
- tests/statespace/test_statespace.py +0 -872
- tests/statespace/test_statespace_JAX.py +0 -156
- tests/statespace/test_structural.py +0 -836
- tests/statespace/utilities/__init__.py +0 -0
- tests/statespace/utilities/shared_fixtures.py +0 -9
- tests/statespace/utilities/statsmodel_local_level.py +0 -42
- tests/statespace/utilities/test_helpers.py +0 -310
- tests/test_blackjax_smc.py +0 -222
- tests/test_find_map.py +0 -103
- tests/test_histogram_approximation.py +0 -109
- tests/test_laplace.py +0 -265
- tests/test_linearmodel.py +0 -208
- tests/test_model_builder.py +0 -306
- tests/test_pathfinder.py +0 -203
- tests/test_pivoted_cholesky.py +0 -24
- tests/test_printing.py +0 -98
- tests/test_prior_from_trace.py +0 -172
- tests/test_splines.py +0 -77
- tests/utils.py +0 -0
- {pymc_extras-0.2.4.dist-info → pymc_extras-0.2.6.dist-info/licenses}/LICENSE +0 -0
tests/distributions/test_multivariate.py
DELETED
@@ -1,304 +0,0 @@
-import numpy as np
-import pymc as pm
-import pytensor
-import pytest
-
-import pymc_extras as pmx
-
-
-class TestR2D2M2CP:
-    @pytest.fixture(autouse=True)
-    def fast_compile(self):
-        with pytensor.config.change_flags(mode="FAST_COMPILE", exception_verbosity="high"):
-            yield
-
-    @pytest.fixture(autouse=True)
-    def model(self):
-        # every method is within a model
-        with pm.Model() as model:
-            yield model
-
-    @pytest.fixture(params=[True, False], ids=["centered", "non-centered"])
-    def centered(self, request):
-        return request.param
-
-    @pytest.fixture(params=[["a"], ["a", "b"], ["one"]])
-    def dims(self, model: pm.Model, request):
-        for i, c in enumerate(request.param):
-            if c == "one":
-                model.add_coord(c, range(1))
-            else:
-                model.add_coord(c, range((i + 2) ** 2))
-        return request.param
-
-    @pytest.fixture
-    def input_shape(self, dims, model):
-        return [int(model.dim_lengths[d].eval()) for d in dims]
-
-    @pytest.fixture
-    def output_shape(self, dims, model):
-        *hierarchy, _ = dims
-        return [int(model.dim_lengths[d].eval()) for d in hierarchy]
-
-    @pytest.fixture
-    def input_std(self, input_shape):
-        return np.ones(input_shape)
-
-    @pytest.fixture
-    def output_std(self, output_shape):
-        return np.ones(output_shape)
-
-    @pytest.fixture
-    def r2(self):
-        return 0.8
-
-    @pytest.fixture(params=[None, 0.1], ids=["r2-std", "no-r2-std"])
-    def r2_std(self, request):
-        return request.param
-
-    @pytest.fixture(params=["true", "false", "limit-1", "limit-0", "limit-all"])
-    def positive_probs(self, input_std, request):
-        if request.param == "true":
-            return np.full_like(input_std, 0.5)
-        elif request.param == "false":
-            return 0.5
-        elif request.param == "limit-1":
-            ret = np.full_like(input_std, 0.5)
-            ret[..., 0] = 1
-            return ret
-        elif request.param == "limit-0":
-            ret = np.full_like(input_std, 0.5)
-            ret[..., 0] = 0
-            return ret
-        elif request.param == "limit-all":
-            return np.full_like(input_std, 0)
-
-    @pytest.fixture(params=[True, False], ids=["probs-std", "no-probs-std"])
-    def positive_probs_std(self, positive_probs, request):
-        if request.param:
-            std = np.full_like(positive_probs, 0.1)
-            std[positive_probs == 0] = 0
-            std[positive_probs == 1] = 0
-            return std
-        else:
-            return None
-
-    @pytest.fixture(params=[None, "importance", "explained"])
-    def phi_args_base(self, request, input_shape):
-        if input_shape[-1] < 2 and request.param is not None:
-            pytest.skip("not compatible")
-        elif request.param is None:
-            return {}
-        elif request.param == "importance":
-            return {"variables_importance": np.full(input_shape, 2)}
-        else:
-            val = np.full(input_shape, 2)
-            return {"variance_explained": val / val.sum(-1, keepdims=True)}
-
-    @pytest.fixture(params=["concentration", "no-concentration"])
-    def phi_args(self, request, phi_args_base):
-        if request.param == "concentration":
-            phi_args_base["importance_concentration"] = 10
-        return phi_args_base
-
-    def test_init_r2(
-        self,
-        dims,
-        input_std,
-        output_std,
-        r2,
-        r2_std,
-        model: pm.Model,
-    ):
-        eps, beta = pmx.distributions.R2D2M2CP(
-            "beta",
-            output_std,
-            input_std,
-            dims=dims,
-            r2=r2,
-            r2_std=r2_std,
-        )
-        assert not np.isnan(beta.eval()).any()
-        assert eps.eval().shape == output_std.shape
-        assert beta.eval().shape == input_std.shape
-        # r2 rv is only created if r2 std is not None
-        assert "beta" in model.named_vars
-        assert ("beta::r2" in model.named_vars) == (r2_std is not None), set(model.named_vars)
-        # phi is only created if variable importance is not None and there is more than one var
-        assert np.isfinite(model.compile_logp()(model.initial_point()))
-
-    def test_init_importance(
-        self,
-        dims,
-        centered,
-        input_std,
-        output_std,
-        phi_args,
-        model: pm.Model,
-    ):
-        eps, beta = pmx.distributions.R2D2M2CP(
-            "beta",
-            output_std,
-            input_std,
-            dims=dims,
-            r2=1,
-            centered=centered,
-            **phi_args,
-        )
-        assert not np.isnan(beta.eval()).any()
-        assert eps.eval().shape == output_std.shape
-        assert beta.eval().shape == input_std.shape
-        # r2 rv is only created if r2 std is not None
-        assert "beta" in model.named_vars
-        # phi is only created if variable importance is not None and there is more than one var
-        assert ("beta::phi" in model.named_vars) == (
-            "variables_importance" in phi_args or "importance_concentration" in phi_args
-        ), set(model.named_vars)
-        assert np.isfinite(model.compile_logp()(model.initial_point()))
-
-    def test_init_positive_probs(
-        self,
-        dims,
-        centered,
-        input_std,
-        output_std,
-        positive_probs,
-        positive_probs_std,
-        model: pm.Model,
-    ):
-        eps, beta = pmx.distributions.R2D2M2CP(
-            "beta",
-            output_std,
-            input_std,
-            dims=dims,
-            r2=1.0,
-            centered=centered,
-            positive_probs_std=positive_probs_std,
-            positive_probs=positive_probs,
-        )
-        assert not np.isnan(beta.eval()).any()
-        assert eps.eval().shape == output_std.shape
-        assert beta.eval().shape == input_std.shape
-        # r2 rv is only created if r2 std is not None
-        assert "beta" in model.named_vars
-        # phi is only created if variable importance is not None and there is more than one var
-        assert ("beta::psi" in model.named_vars) == (
-            positive_probs_std is not None and positive_probs_std.any()
-        ), set(model.named_vars)
-        assert np.isfinite(sum(model.point_logps().values()))
-
-    def test_failing_importance(self, dims, input_shape, output_std, input_std):
-        if input_shape[-1] < 2:
-            with pytest.raises(TypeError, match="less than two variables"):
-                pmx.distributions.R2D2M2CP(
-                    "beta",
-                    output_std,
-                    input_std,
-                    dims=dims,
-                    r2=0.8,
-                    variables_importance=abs(input_std),
-                )
-        else:
-            pmx.distributions.R2D2M2CP(
-                "beta",
-                output_std,
-                input_std,
-                dims=dims,
-                r2=0.8,
-                variables_importance=abs(input_std),
-            )
-
-    def test_failing_variance_explained(self, dims, input_shape, output_std, input_std):
-        if input_shape[-1] < 2:
-            with pytest.raises(TypeError, match="less than two variables"):
-                pmx.distributions.R2D2M2CP(
-                    "beta",
-                    output_std,
-                    input_std,
-                    dims=dims,
-                    r2=0.8,
-                    variance_explained=abs(input_std),
-                )
-        else:
-            pmx.distributions.R2D2M2CP(
-                "beta", output_std, input_std, dims=dims, r2=0.8, variance_explained=abs(input_std)
-            )
-
-    def test_failing_mutual_exclusive(self, model: pm.Model):
-        with pytest.raises(TypeError, match="variable importance with variance explained"):
-            with model:
-                model.add_coord("a", range(2))
-                pmx.distributions.R2D2M2CP(
-                    "beta",
-                    1,
-                    [1, 1],
-                    dims="a",
-                    r2=0.8,
-                    variance_explained=[0.5, 0.5],
-                    variables_importance=[1, 1],
-                )
-
-    def test_limit_case_requires_std_0(self, model: pm.Model):
-        model.add_coord("a", range(2))
-        with pytest.raises(ValueError, match="Can't have both positive_probs"):
-            pmx.distributions.R2D2M2CP(
-                "beta",
-                1,
-                [1, 1],
-                dims="a",
-                r2=0.8,
-                positive_probs=[0.5, 0],
-                positive_probs_std=[0.3, 0.1],
-            )
-        with pytest.raises(ValueError, match="Can't have both positive_probs"):
-            pmx.distributions.R2D2M2CP(
-                "beta",
-                1,
-                [1, 1],
-                dims="a",
-                r2=0.8,
-                positive_probs=[0.5, 1],
-                positive_probs_std=[0.3, 0.1],
-            )
-
-    def test_limit_case_creates_masked_vars(self, model: pm.Model, centered: bool):
-        model.add_coord("a", range(2))
-        pmx.distributions.R2D2M2CP(
-            "beta0",
-            1,
-            [1, 1],
-            dims="a",
-            r2=0.8,
-            positive_probs=[0.5, 1],
-            positive_probs_std=[0.3, 0],
-            centered=centered,
-        )
-        pmx.distributions.R2D2M2CP(
-            "beta1",
-            1,
-            [1, 1],
-            dims="a",
-            r2=0.8,
-            positive_probs=[0.5, 0],
-            positive_probs_std=[0.3, 0],
-            centered=centered,
-        )
-        if not centered:
-            assert "beta0::raw::masked" in model.named_vars, model.named_vars
-            assert "beta1::raw::masked" in model.named_vars, model.named_vars
-        else:
-            assert "beta0::masked" in model.named_vars, model.named_vars
-            assert "beta1::masked" in model.named_vars, model.named_vars
-        assert "beta1::psi::masked" in model.named_vars
-        assert "beta0::psi::masked" in model.named_vars
-
-    def test_zero_length_rvs_not_created(self, model: pm.Model):
-        model.add_coord("a", range(2))
-        # deterministic case which should not have any new variables
-        b = pmx.distributions.R2D2M2CP("b1", 1, [1, 1], r2=0.5, positive_probs=[1, 1], dims="a")
-        assert not model.free_RVs, model.free_RVs
-
-        b = pmx.distributions.R2D2M2CP(
-            "b2", 1, [1, 1], r2=0.5, positive_probs=[1, 1], positive_probs_std=[0, 0], dims="a"
-        )
-        assert not model.free_RVs, model.free_RVs
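For orientation, the removed tests above exercise the R2D2M2CP prior constructor. Below is a minimal usage sketch (not part of the diff) following the same call pattern; the coordinate name "covariate" and all numeric values are illustrative only.

import numpy as np
import pymc as pm
import pymc_extras as pmx

# Hypothetical coordinate name and sizes, chosen only to mirror the removed tests.
with pm.Model(coords={"covariate": range(3)}) as model:
    # R2D2M2CP returns the residual scale (eps) and the regression coefficients (beta).
    eps, beta = pmx.distributions.R2D2M2CP(
        "beta",
        1.0,             # output standard deviation
        np.ones(3),      # per-covariate input standard deviations
        dims="covariate",
        r2=0.8,          # prior belief about the coefficient of determination
        r2_std=0.1,      # optional; per the tests, a "beta::r2" variable is created only when this is given
    )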
tests/model/__init__.py
DELETED
File without changes

tests/model/marginal/__init__.py
DELETED
File without changes
tests/model/marginal/test_distributions.py
DELETED
@@ -1,132 +0,0 @@
-import numpy as np
-import pymc as pm
-import pytest
-
-from pymc.logprob.abstract import _logprob
-from pytensor import tensor as pt
-from scipy.stats import norm
-
-from pymc_extras import marginalize
-from pymc_extras.distributions import DiscreteMarkovChain
-from pymc_extras.model.marginal.distributions import MarginalFiniteDiscreteRV
-
-
-def test_marginalized_bernoulli_logp():
-    """Test logp of IR TestFiniteMarginalDiscreteRV directly"""
-    mu = pt.vector("mu")
-
-    idx = pm.Bernoulli.dist(0.7, name="idx")
-    y = pm.Normal.dist(mu=mu[idx], sigma=1.0, name="y")
-    marginal_rv_node = MarginalFiniteDiscreteRV(
-        [mu],
-        [idx, y],
-        dims_connections=(((),),),
-        dims=(),
-    )(mu)[0].owner
-
-    y_vv = y.clone()
-    (logp,) = _logprob(
-        marginal_rv_node.op,
-        (y_vv,),
-        *marginal_rv_node.inputs,
-    )
-
-    ref_logp = pm.logp(pm.NormalMixture.dist(w=[0.3, 0.7], mu=mu, sigma=1.0), y_vv)
-    np.testing.assert_almost_equal(
-        logp.eval({mu: [-1, 1], y_vv: 2}),
-        ref_logp.eval({mu: [-1, 1], y_vv: 2}),
-    )
-
-
-@pytest.mark.parametrize("batch_chain", (False, True), ids=lambda x: f"batch_chain={x}")
-@pytest.mark.parametrize("batch_emission", (False, True), ids=lambda x: f"batch_emission={x}")
-def test_marginalized_hmm_normal_emission(batch_chain, batch_emission):
-    if batch_chain and not batch_emission:
-        pytest.skip("Redundant implicit combination")
-
-    with pm.Model() as m:
-        P = [[0, 1], [1, 0]]
-        init_dist = pm.Categorical.dist(p=[1, 0])
-        chain = DiscreteMarkovChain(
-            "chain", P=P, init_dist=init_dist, steps=3, shape=(3, 4) if batch_chain else None
-        )
-        emission = pm.Normal(
-            "emission", mu=chain * 2 - 1, sigma=1e-1, shape=(3, 4) if batch_emission else None
-        )
-
-    marginal_m = marginalize(m, [chain])
-    logp_fn = marginal_m.compile_logp()
-
-    test_value = np.array([-1, 1, -1, 1])
-    expected_logp = pm.logp(pm.Normal.dist(0, 1e-1), np.zeros_like(test_value)).sum().eval()
-    if batch_emission:
-        test_value = np.broadcast_to(test_value, (3, 4))
-        expected_logp *= 3
-    np.testing.assert_allclose(logp_fn({"emission": test_value}), expected_logp)
-
-
-@pytest.mark.parametrize(
-    "categorical_emission",
-    [False, True],
-)
-def test_marginalized_hmm_categorical_emission(categorical_emission):
-    """Example adapted from https://www.youtube.com/watch?v=9-sPm4CfcD0"""
-    with pm.Model() as m:
-        P = np.array([[0.5, 0.5], [0.3, 0.7]])
-        init_dist = pm.Categorical.dist(p=[0.375, 0.625])
-        chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, steps=2)
-        if categorical_emission:
-            emission = pm.Categorical("emission", p=pt.constant([[0.8, 0.2], [0.4, 0.6]])[chain])
-        else:
-            emission = pm.Bernoulli("emission", p=pt.where(pt.eq(chain, 0), 0.2, 0.6))
-    marginal_m = marginalize(m, [chain])
-
-    test_value = np.array([0, 0, 1])
-    expected_logp = np.log(0.1344)  # Shown at the 10m22s mark in the video
-    logp_fn = marginal_m.compile_logp()
-    np.testing.assert_allclose(logp_fn({"emission": test_value}), expected_logp)
-
-
-@pytest.mark.parametrize("batch_chain", (False, True))
-@pytest.mark.parametrize("batch_emission1", (False, True))
-@pytest.mark.parametrize("batch_emission2", (False, True))
-def test_marginalized_hmm_multiple_emissions(batch_chain, batch_emission1, batch_emission2):
-    chain_shape = (3, 1, 4) if batch_chain else (4,)
-    emission1_shape = (
-        (2, *reversed(chain_shape)) if batch_emission1 else tuple(reversed(chain_shape))
-    )
-    emission2_shape = (*chain_shape, 2) if batch_emission2 else chain_shape
-    with pm.Model() as m:
-        P = [[0, 1], [1, 0]]
-        init_dist = pm.Categorical.dist(p=[1, 0])
-        chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, shape=chain_shape)
-        emission_1 = pm.Normal(
-            "emission_1", mu=(chain * 2 - 1).T, sigma=1e-1, shape=emission1_shape
-        )
-
-        emission2_mu = (1 - chain) * 2 - 1
-        if batch_emission2:
-            emission2_mu = emission2_mu[..., None]
-        emission_2 = pm.Normal("emission_2", mu=emission2_mu, sigma=1e-1, shape=emission2_shape)
-
-    marginal_m = marginalize(m, [chain])
-
-    with pytest.warns(UserWarning, match="multiple dependent variables"):
-        logp_fn = marginal_m.compile_logp(sum=False)
-
-    test_value = np.array([-1, 1, -1, 1])
-    multiplier = 2 + batch_emission1 + batch_emission2
-    if batch_chain:
-        multiplier *= 3
-    expected_logp = norm.logpdf(np.zeros_like(test_value), 0, 1e-1).sum() * multiplier
-
-    test_value = np.broadcast_to(test_value, chain_shape)
-    test_value_emission1 = np.broadcast_to(test_value.T, emission1_shape)
-    if batch_emission2:
-        test_value_emission2 = np.broadcast_to(-test_value[..., None], emission2_shape)
-    else:
-        test_value_emission2 = np.broadcast_to(-test_value, emission2_shape)
-    test_point = {"emission_1": test_value_emission1, "emission_2": test_value_emission2}
-    res_logp, dummy_logp = logp_fn(test_point)
-    assert res_logp.shape == ((1, 3) if batch_chain else ())
-    np.testing.assert_allclose(res_logp.sum(), expected_logp)
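For orientation, the removed HMM tests above rely on marginalize() to sum a DiscreteMarkovChain out of a model. Below is a minimal sketch (not part of the diff) of that workflow, with the transition, initial, and emission probabilities mirroring test_marginalized_hmm_categorical_emission.

import numpy as np
import pymc as pm
import pytensor.tensor as pt

from pymc_extras import marginalize
from pymc_extras.distributions import DiscreteMarkovChain

with pm.Model() as m:
    P = np.array([[0.5, 0.5], [0.3, 0.7]])              # transition probabilities
    init_dist = pm.Categorical.dist(p=[0.375, 0.625])   # initial state distribution
    chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, steps=2)
    emission = pm.Bernoulli("emission", p=pt.where(pt.eq(chain, 0), 0.2, 0.6))

# marginalize returns a new model with the discrete chain summed out, so the
# log-probability of the emissions can be evaluated without the latent states.
marginal_m = marginalize(m, [chain])
logp_fn = marginal_m.compile_logp()
logp_fn({"emission": np.array([0, 0, 1])})  # approx. log(0.1344), per the removed test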
tests/model/marginal/test_graph_analysis.py
DELETED
@@ -1,182 +0,0 @@
-import pytensor.tensor as pt
-import pytest
-
-from pymc.distributions import CustomDist
-from pytensor.tensor.type_other import NoneTypeT
-
-from pymc_extras.model.marginal.graph_analysis import (
-    is_conditional_dependent,
-    subgraph_batch_dim_connection,
-)
-
-
-def test_is_conditional_dependent_static_shape():
-    """Test that we don't consider dependencies through "constant" shape Ops"""
-    x1 = pt.matrix("x1", shape=(None, 5))
-    y1 = pt.random.normal(size=pt.shape(x1))
-    assert is_conditional_dependent(y1, x1, [x1, y1])
-
-    x2 = pt.matrix("x2", shape=(9, 5))
-    y2 = pt.random.normal(size=pt.shape(x2))
-    assert not is_conditional_dependent(y2, x2, [x2, y2])
-
-
-class TestSubgraphBatchDimConnection:
-    def test_dimshuffle(self):
-        inp = pt.tensor(shape=(5, 1, 4, 3))
-        out1 = pt.matrix_transpose(inp)
-        out2 = pt.expand_dims(inp, 1)
-        out3 = pt.squeeze(inp)
-        [dims1, dims2, dims3] = subgraph_batch_dim_connection(inp, [out1, out2, out3])
-        assert dims1 == (0, 1, 3, 2)
-        assert dims2 == (0, None, 1, 2, 3)
-        assert dims3 == (0, 2, 3)
-
-    def test_careduce(self):
-        inp = pt.tensor(shape=(4, 3, 2))
-
-        out = pt.sum(inp[:, None], axis=(1,))
-        [dims] = subgraph_batch_dim_connection(inp, [out])
-        assert dims == (0, 1, 2)
-
-        invalid_out = pt.sum(inp, axis=(1,))
-        with pytest.raises(ValueError, match="Use of known dimensions"):
-            subgraph_batch_dim_connection(inp, [invalid_out])
-
-    def test_subtensor(self):
-        inp = pt.tensor(shape=(4, 3, 2))
-
-        invalid_out = inp[0, :1]
-        with pytest.raises(
-            ValueError,
-            match="Partial slicing or indexing of known dimensions not supported",
-        ):
-            subgraph_batch_dim_connection(inp, [invalid_out])
-
-        # If we are selecting dummy / unknown dimensions that's fine
-        valid_out = pt.expand_dims(inp, (0, 1))[0, :1]
-        [dims] = subgraph_batch_dim_connection(inp, [valid_out])
-        assert dims == (None, 0, 1, 2)
-
-    def test_advanced_subtensor_value(self):
-        inp = pt.tensor(shape=(2, 4))
-        intermediate_out = inp[:, None, :, None] + pt.zeros((2, 3, 4, 5))
-
-        # Index on an unlabled dim introduced by broadcasting with zeros
-        out = intermediate_out[:, [0, 0, 1, 2]]
-        [dims] = subgraph_batch_dim_connection(inp, [out])
-        assert dims == (0, None, 1, None)
-
-        # Indexing that introduces more dimensions
-        out = intermediate_out[:, [[0, 0], [1, 2]], :]
-        [dims] = subgraph_batch_dim_connection(inp, [out])
-        assert dims == (0, None, None, 1, None)
-
-        # Special case where advanced dims are moved to the front of the output
-        out = intermediate_out[:, [0, 0, 1, 2], :, 0]
-        [dims] = subgraph_batch_dim_connection(inp, [out])
-        assert dims == (None, 0, 1)
-
-        # Indexing on a labeled dim fails
-        out = intermediate_out[:, :, [0, 0, 1, 2]]
-        with pytest.raises(ValueError, match="Partial slicing or advanced integer indexing"):
-            subgraph_batch_dim_connection(inp, [out])
-
-    def test_advanced_subtensor_key(self):
-        inp = pt.tensor(shape=(5, 5), dtype=int)
-        base = pt.zeros((2, 3, 4))
-
-        out = base[inp]
-        [dims] = subgraph_batch_dim_connection(inp, [out])
-        assert dims == (0, 1, None, None)
-
-        out = base[:, :, inp]
-        [dims] = subgraph_batch_dim_connection(inp, [out])
-        assert dims == (
-            None,
-            None,
-            0,
-            1,
-        )
-
-        out = base[1:, 0, inp]
-        [dims] = subgraph_batch_dim_connection(inp, [out])
-        assert dims == (None, 0, 1)
-
-        # Special case where advanced dims are moved to the front of the output
-        out = base[0, :, inp]
-        [dims] = subgraph_batch_dim_connection(inp, [out])
-        assert dims == (0, 1, None)
-
-        # Mix keys dimensions
-        out = base[:, inp, inp.T]
-        with pytest.raises(ValueError, match="Different known dimensions mixed via broadcasting"):
-            subgraph_batch_dim_connection(inp, [out])
-
-    def test_elemwise(self):
-        inp = pt.tensor(shape=(5, 5))
-
-        out = inp + inp
-        [dims] = subgraph_batch_dim_connection(inp, [out])
-        assert dims == (0, 1)
-
-        out = inp + inp.T
-        with pytest.raises(ValueError, match="Different known dimensions mixed via broadcasting"):
-            subgraph_batch_dim_connection(inp, [out])
-
-        out = inp[None, :, None, :] + inp[:, None, :, None]
-        with pytest.raises(
-            ValueError, match="Same known dimension used in different axis after broadcasting"
-        ):
-            subgraph_batch_dim_connection(inp, [out])
-
-    def test_blockwise(self):
-        inp = pt.tensor(shape=(5, 4))
-
-        invalid_out = inp @ pt.ones((4, 3))
-        with pytest.raises(ValueError, match="Use of known dimensions"):
-            subgraph_batch_dim_connection(inp, [invalid_out])
-
-        out = (inp[:, :, None, None] + pt.zeros((2, 3))) @ pt.ones((2, 3))
-        [dims] = subgraph_batch_dim_connection(inp, [out])
-        assert dims == (0, 1, None, None)
-
-    def test_random_variable(self):
-        inp = pt.tensor(shape=(5, 4, 3))
-
-        out1 = pt.random.normal(loc=inp)
-        out2 = pt.random.categorical(p=inp[..., None])
-        out3 = pt.random.multivariate_normal(mean=inp[..., None], cov=pt.eye(1))
-        [dims1, dims2, dims3] = subgraph_batch_dim_connection(inp, [out1, out2, out3])
-        assert dims1 == (0, 1, 2)
-        assert dims2 == (0, 1, 2)
-        assert dims3 == (0, 1, 2, None)
-
-        invalid_out = pt.random.categorical(p=inp)
-        with pytest.raises(ValueError, match="Use of known dimensions"):
-            subgraph_batch_dim_connection(inp, [invalid_out])
-
-        invalid_out = pt.random.multivariate_normal(mean=inp, cov=pt.eye(3))
-        with pytest.raises(ValueError, match="Use of known dimensions"):
-            subgraph_batch_dim_connection(inp, [invalid_out])
-
-    def test_symbolic_random_variable(self):
-        inp = pt.tensor(shape=(4, 3, 2))
-
-        # Test univariate
-        out = CustomDist.dist(
-            inp,
-            dist=lambda mu, size: pt.random.normal(loc=mu, size=size),
-        )
-        [dims] = subgraph_batch_dim_connection(inp, [out])
-        assert dims == (0, 1, 2)
-
-        # Test multivariate
-        def dist(mu, size):
-            if isinstance(size.type, NoneTypeT):
-                size = mu.shape
-            return pt.random.normal(loc=mu[..., None], size=(*size, 2))
-
-        out = CustomDist.dist(inp, dist=dist, size=(4, 3, 2), signature="()->(2)")
-        [dims] = subgraph_batch_dim_connection(inp, [out])
-        assert dims == (0, 1, 2, None)