pymc-extras 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101)
  1. pymc_extras/__init__.py +29 -0
  2. pymc_extras/distributions/__init__.py +40 -0
  3. pymc_extras/distributions/continuous.py +351 -0
  4. pymc_extras/distributions/discrete.py +399 -0
  5. pymc_extras/distributions/histogram_utils.py +163 -0
  6. pymc_extras/distributions/multivariate/__init__.py +3 -0
  7. pymc_extras/distributions/multivariate/r2d2m2cp.py +446 -0
  8. pymc_extras/distributions/timeseries.py +356 -0
  9. pymc_extras/gp/__init__.py +18 -0
  10. pymc_extras/gp/latent_approx.py +183 -0
  11. pymc_extras/inference/__init__.py +18 -0
  12. pymc_extras/inference/find_map.py +431 -0
  13. pymc_extras/inference/fit.py +44 -0
  14. pymc_extras/inference/laplace.py +570 -0
  15. pymc_extras/inference/pathfinder.py +134 -0
  16. pymc_extras/inference/smc/__init__.py +13 -0
  17. pymc_extras/inference/smc/sampling.py +451 -0
  18. pymc_extras/linearmodel.py +130 -0
  19. pymc_extras/model/__init__.py +0 -0
  20. pymc_extras/model/marginal/__init__.py +0 -0
  21. pymc_extras/model/marginal/distributions.py +276 -0
  22. pymc_extras/model/marginal/graph_analysis.py +372 -0
  23. pymc_extras/model/marginal/marginal_model.py +595 -0
  24. pymc_extras/model/model_api.py +56 -0
  25. pymc_extras/model/transforms/__init__.py +0 -0
  26. pymc_extras/model/transforms/autoreparam.py +434 -0
  27. pymc_extras/model_builder.py +759 -0
  28. pymc_extras/preprocessing/__init__.py +0 -0
  29. pymc_extras/preprocessing/standard_scaler.py +17 -0
  30. pymc_extras/printing.py +182 -0
  31. pymc_extras/statespace/__init__.py +13 -0
  32. pymc_extras/statespace/core/__init__.py +7 -0
  33. pymc_extras/statespace/core/compile.py +48 -0
  34. pymc_extras/statespace/core/representation.py +438 -0
  35. pymc_extras/statespace/core/statespace.py +2268 -0
  36. pymc_extras/statespace/filters/__init__.py +15 -0
  37. pymc_extras/statespace/filters/distributions.py +453 -0
  38. pymc_extras/statespace/filters/kalman_filter.py +820 -0
  39. pymc_extras/statespace/filters/kalman_smoother.py +126 -0
  40. pymc_extras/statespace/filters/utilities.py +59 -0
  41. pymc_extras/statespace/models/ETS.py +670 -0
  42. pymc_extras/statespace/models/SARIMAX.py +536 -0
  43. pymc_extras/statespace/models/VARMAX.py +393 -0
  44. pymc_extras/statespace/models/__init__.py +6 -0
  45. pymc_extras/statespace/models/structural.py +1651 -0
  46. pymc_extras/statespace/models/utilities.py +387 -0
  47. pymc_extras/statespace/utils/__init__.py +0 -0
  48. pymc_extras/statespace/utils/constants.py +74 -0
  49. pymc_extras/statespace/utils/coord_tools.py +0 -0
  50. pymc_extras/statespace/utils/data_tools.py +182 -0
  51. pymc_extras/utils/__init__.py +23 -0
  52. pymc_extras/utils/linear_cg.py +290 -0
  53. pymc_extras/utils/pivoted_cholesky.py +69 -0
  54. pymc_extras/utils/prior.py +200 -0
  55. pymc_extras/utils/spline.py +131 -0
  56. pymc_extras/version.py +11 -0
  57. pymc_extras/version.txt +1 -0
  58. pymc_extras-0.2.0.dist-info/LICENSE +212 -0
  59. pymc_extras-0.2.0.dist-info/METADATA +99 -0
  60. pymc_extras-0.2.0.dist-info/RECORD +101 -0
  61. pymc_extras-0.2.0.dist-info/WHEEL +5 -0
  62. pymc_extras-0.2.0.dist-info/top_level.txt +2 -0
  63. tests/__init__.py +13 -0
  64. tests/distributions/__init__.py +19 -0
  65. tests/distributions/test_continuous.py +185 -0
  66. tests/distributions/test_discrete.py +210 -0
  67. tests/distributions/test_discrete_markov_chain.py +258 -0
  68. tests/distributions/test_multivariate.py +304 -0
  69. tests/model/__init__.py +0 -0
  70. tests/model/marginal/__init__.py +0 -0
  71. tests/model/marginal/test_distributions.py +131 -0
  72. tests/model/marginal/test_graph_analysis.py +182 -0
  73. tests/model/marginal/test_marginal_model.py +867 -0
  74. tests/model/test_model_api.py +29 -0
  75. tests/statespace/__init__.py +0 -0
  76. tests/statespace/test_ETS.py +411 -0
  77. tests/statespace/test_SARIMAX.py +405 -0
  78. tests/statespace/test_VARMAX.py +184 -0
  79. tests/statespace/test_coord_assignment.py +116 -0
  80. tests/statespace/test_distributions.py +270 -0
  81. tests/statespace/test_kalman_filter.py +326 -0
  82. tests/statespace/test_representation.py +175 -0
  83. tests/statespace/test_statespace.py +818 -0
  84. tests/statespace/test_statespace_JAX.py +156 -0
  85. tests/statespace/test_structural.py +829 -0
  86. tests/statespace/utilities/__init__.py +0 -0
  87. tests/statespace/utilities/shared_fixtures.py +9 -0
  88. tests/statespace/utilities/statsmodel_local_level.py +42 -0
  89. tests/statespace/utilities/test_helpers.py +310 -0
  90. tests/test_blackjax_smc.py +222 -0
  91. tests/test_find_map.py +98 -0
  92. tests/test_histogram_approximation.py +109 -0
  93. tests/test_laplace.py +238 -0
  94. tests/test_linearmodel.py +208 -0
  95. tests/test_model_builder.py +306 -0
  96. tests/test_pathfinder.py +45 -0
  97. tests/test_pivoted_cholesky.py +24 -0
  98. tests/test_printing.py +98 -0
  99. tests/test_prior_from_trace.py +172 -0
  100. tests/test_splines.py +77 -0
  101. tests/utils.py +31 -0
tests/distributions/test_multivariate.py
@@ -0,0 +1,304 @@
+ import numpy as np
+ import pymc as pm
+ import pytensor
+ import pytest
+
+ import pymc_extras as pmx
+
+
+ class TestR2D2M2CP:
+     @pytest.fixture(autouse=True)
+     def fast_compile(self):
+         with pytensor.config.change_flags(mode="FAST_COMPILE", exception_verbosity="high"):
+             yield
+
+     @pytest.fixture(autouse=True)
+     def model(self):
+         # every test method runs inside a model context
+         with pm.Model() as model:
+             yield model
+
+     @pytest.fixture(params=[True, False], ids=["centered", "non-centered"])
+     def centered(self, request):
+         return request.param
+
+     @pytest.fixture(params=[["a"], ["a", "b"], ["one"]])
+     def dims(self, model: pm.Model, request):
+         for i, c in enumerate(request.param):
+             if c == "one":
+                 model.add_coord(c, range(1))
+             else:
+                 model.add_coord(c, range((i + 2) ** 2))
+         return request.param
+
+     @pytest.fixture
+     def input_shape(self, dims, model):
+         return [int(model.dim_lengths[d].eval()) for d in dims]
+
+     @pytest.fixture
+     def output_shape(self, dims, model):
+         *hierarchy, _ = dims
+         return [int(model.dim_lengths[d].eval()) for d in hierarchy]
+
+     @pytest.fixture
+     def input_std(self, input_shape):
+         return np.ones(input_shape)
+
+     @pytest.fixture
+     def output_std(self, output_shape):
+         return np.ones(output_shape)
+
+     @pytest.fixture
+     def r2(self):
+         return 0.8
+
+     @pytest.fixture(params=[None, 0.1], ids=["no-r2-std", "r2-std"])
+     def r2_std(self, request):
+         return request.param
+
+     @pytest.fixture(params=["true", "false", "limit-1", "limit-0", "limit-all"])
+     def positive_probs(self, input_std, request):
+         if request.param == "true":
+             return np.full_like(input_std, 0.5)
+         elif request.param == "false":
+             return 0.5
+         elif request.param == "limit-1":
+             ret = np.full_like(input_std, 0.5)
+             ret[..., 0] = 1
+             return ret
+         elif request.param == "limit-0":
+             ret = np.full_like(input_std, 0.5)
+             ret[..., 0] = 0
+             return ret
+         elif request.param == "limit-all":
+             return np.full_like(input_std, 0)
+
+     @pytest.fixture(params=[True, False], ids=["probs-std", "no-probs-std"])
+     def positive_probs_std(self, positive_probs, request):
+         if request.param:
+             std = np.full_like(positive_probs, 0.1)
+             std[positive_probs == 0] = 0
+             std[positive_probs == 1] = 0
+             return std
+         else:
+             return None
+
+     @pytest.fixture(params=[None, "importance", "explained"])
+     def phi_args_base(self, request, input_shape):
+         if input_shape[-1] < 2 and request.param is not None:
+             pytest.skip("not compatible")
+         elif request.param is None:
+             return {}
+         elif request.param == "importance":
+             return {"variables_importance": np.full(input_shape, 2)}
+         else:
+             val = np.full(input_shape, 2)
+             return {"variance_explained": val / val.sum(-1, keepdims=True)}
+
+     @pytest.fixture(params=["concentration", "no-concentration"])
+     def phi_args(self, request, phi_args_base):
+         if request.param == "concentration":
+             phi_args_base["importance_concentration"] = 10
+         return phi_args_base
+
+     def test_init_r2(
+         self,
+         dims,
+         input_std,
+         output_std,
+         r2,
+         r2_std,
+         model: pm.Model,
+     ):
+         eps, beta = pmx.distributions.R2D2M2CP(
+             "beta",
+             output_std,
+             input_std,
+             dims=dims,
+             r2=r2,
+             r2_std=r2_std,
+         )
+         assert not np.isnan(beta.eval()).any()
+         assert eps.eval().shape == output_std.shape
+         assert beta.eval().shape == input_std.shape
+         assert "beta" in model.named_vars
+         # the r2 RV is only created if r2_std is not None
+         assert ("beta::r2" in model.named_vars) == (r2_std is not None), set(model.named_vars)
+         assert np.isfinite(model.compile_logp()(model.initial_point()))
+
+     def test_init_importance(
+         self,
+         dims,
+         centered,
+         input_std,
+         output_std,
+         phi_args,
+         model: pm.Model,
+     ):
+         eps, beta = pmx.distributions.R2D2M2CP(
+             "beta",
+             output_std,
+             input_std,
+             dims=dims,
+             r2=1,
+             centered=centered,
+             **phi_args,
+         )
+         assert not np.isnan(beta.eval()).any()
+         assert eps.eval().shape == output_std.shape
+         assert beta.eval().shape == input_std.shape
+         assert "beta" in model.named_vars
+         # phi is only created if variable importance is given and there is more than one variable
+         assert ("beta::phi" in model.named_vars) == (
+             "variables_importance" in phi_args or "importance_concentration" in phi_args
+         ), set(model.named_vars)
+         assert np.isfinite(model.compile_logp()(model.initial_point()))
+
+     def test_init_positive_probs(
+         self,
+         dims,
+         centered,
+         input_std,
+         output_std,
+         positive_probs,
+         positive_probs_std,
+         model: pm.Model,
+     ):
+         eps, beta = pmx.distributions.R2D2M2CP(
+             "beta",
+             output_std,
+             input_std,
+             dims=dims,
+             r2=1.0,
+             centered=centered,
+             positive_probs_std=positive_probs_std,
+             positive_probs=positive_probs,
+         )
+         assert not np.isnan(beta.eval()).any()
+         assert eps.eval().shape == output_std.shape
+         assert beta.eval().shape == input_std.shape
+         assert "beta" in model.named_vars
+         # psi is only created if positive_probs_std is given and not all zeros
+         assert ("beta::psi" in model.named_vars) == (
+             positive_probs_std is not None and positive_probs_std.any()
+         ), set(model.named_vars)
+         assert np.isfinite(sum(model.point_logps().values()))
+
+     def test_failing_importance(self, dims, input_shape, output_std, input_std):
+         if input_shape[-1] < 2:
+             with pytest.raises(TypeError, match="less than two variables"):
+                 pmx.distributions.R2D2M2CP(
+                     "beta",
+                     output_std,
+                     input_std,
+                     dims=dims,
+                     r2=0.8,
+                     variables_importance=abs(input_std),
+                 )
+         else:
+             pmx.distributions.R2D2M2CP(
+                 "beta",
+                 output_std,
+                 input_std,
+                 dims=dims,
+                 r2=0.8,
+                 variables_importance=abs(input_std),
+             )
+
+     def test_failing_variance_explained(self, dims, input_shape, output_std, input_std):
+         if input_shape[-1] < 2:
+             with pytest.raises(TypeError, match="less than two variables"):
+                 pmx.distributions.R2D2M2CP(
+                     "beta",
+                     output_std,
+                     input_std,
+                     dims=dims,
+                     r2=0.8,
+                     variance_explained=abs(input_std),
+                 )
+         else:
+             pmx.distributions.R2D2M2CP(
+                 "beta", output_std, input_std, dims=dims, r2=0.8, variance_explained=abs(input_std)
+             )
+
+     def test_failing_mutual_exclusive(self, model: pm.Model):
+         with pytest.raises(TypeError, match="variable importance with variance explained"):
+             with model:
+                 model.add_coord("a", range(2))
+                 pmx.distributions.R2D2M2CP(
+                     "beta",
+                     1,
+                     [1, 1],
+                     dims="a",
+                     r2=0.8,
+                     variance_explained=[0.5, 0.5],
+                     variables_importance=[1, 1],
+                 )
+
+     def test_limit_case_requires_std_0(self, model: pm.Model):
+         model.add_coord("a", range(2))
+         with pytest.raises(ValueError, match="Can't have both positive_probs"):
+             pmx.distributions.R2D2M2CP(
+                 "beta",
+                 1,
+                 [1, 1],
+                 dims="a",
+                 r2=0.8,
+                 positive_probs=[0.5, 0],
+                 positive_probs_std=[0.3, 0.1],
+             )
+         with pytest.raises(ValueError, match="Can't have both positive_probs"):
+             pmx.distributions.R2D2M2CP(
+                 "beta",
+                 1,
+                 [1, 1],
+                 dims="a",
+                 r2=0.8,
+                 positive_probs=[0.5, 1],
+                 positive_probs_std=[0.3, 0.1],
+             )
+
+     def test_limit_case_creates_masked_vars(self, model: pm.Model, centered: bool):
+         model.add_coord("a", range(2))
+         pmx.distributions.R2D2M2CP(
+             "beta0",
+             1,
+             [1, 1],
+             dims="a",
+             r2=0.8,
+             positive_probs=[0.5, 1],
+             positive_probs_std=[0.3, 0],
+             centered=centered,
+         )
+         pmx.distributions.R2D2M2CP(
+             "beta1",
+             1,
+             [1, 1],
+             dims="a",
+             r2=0.8,
+             positive_probs=[0.5, 0],
+             positive_probs_std=[0.3, 0],
+             centered=centered,
+         )
+         if not centered:
+             assert "beta0::raw::masked" in model.named_vars, model.named_vars
+             assert "beta1::raw::masked" in model.named_vars, model.named_vars
+         else:
+             assert "beta0::masked" in model.named_vars, model.named_vars
+             assert "beta1::masked" in model.named_vars, model.named_vars
+         assert "beta1::psi::masked" in model.named_vars
+         assert "beta0::psi::masked" in model.named_vars
+
+     def test_zero_length_rvs_not_created(self, model: pm.Model):
+         model.add_coord("a", range(2))
+         # deterministic case, which should not create any new free variables
+         b = pmx.distributions.R2D2M2CP("b1", 1, [1, 1], r2=0.5, positive_probs=[1, 1], dims="a")
+         assert not model.free_RVs, model.free_RVs
+
+         b = pmx.distributions.R2D2M2CP(
+             "b2", 1, [1, 1], r2=0.5, positive_probs=[1, 1], positive_probs_std=[0, 0], dims="a"
+         )
+         assert not model.free_RVs, model.free_RVs
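The construct exercised above can be used directly in a regression model. A minimal sketch for orientation, in which the data (X, y), the coordinate names, and all hyperparameter values are hypothetical rather than taken from the package; the R2D2M2CP call signature mirrors the tests:

import numpy as np
import pymc as pm

import pymc_extras as pmx

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3))  # hypothetical design matrix
y = X @ np.array([0.5, -0.3, 0.0]) + 0.1 * rng.normal(size=100)

with pm.Model(coords={"variables": ["x1", "x2", "x3"]}) as model:
    # R2D2M2CP returns the residual std (eps) and the coefficient vector (beta)
    eps, beta = pmx.distributions.R2D2M2CP(
        "beta",
        y.std(),     # output_std
        X.std(0),    # input_std, one entry per predictor
        dims="variables",
        r2=0.8,      # prior belief about the model's R2
        r2_std=0.2,  # uncertainty around that belief
    )
    pm.Normal("obs", mu=X @ beta, sigma=eps, observed=y)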
tests/model/__init__.py: file without changes (new empty file)
tests/model/marginal/__init__.py: file without changes (new empty file)
tests/model/marginal/test_distributions.py
@@ -0,0 +1,131 @@
+ import numpy as np
+ import pymc as pm
+ import pytest
+
+ from pymc.logprob.abstract import _logprob
+ from pytensor import tensor as pt
+ from scipy.stats import norm
+
+ from pymc_extras import MarginalModel
+ from pymc_extras.distributions import DiscreteMarkovChain
+ from pymc_extras.model.marginal.distributions import MarginalFiniteDiscreteRV
+
+
+ def test_marginalized_bernoulli_logp():
+     """Test the logp of the IR MarginalFiniteDiscreteRV directly"""
+     mu = pt.vector("mu")
+
+     idx = pm.Bernoulli.dist(0.7, name="idx")
+     y = pm.Normal.dist(mu=mu[idx], sigma=1.0, name="y")
+     marginal_rv_node = MarginalFiniteDiscreteRV(
+         [mu],
+         [idx, y],
+         dims_connections=(((),),),
+     )(mu)[0].owner
+
+     y_vv = y.clone()
+     (logp,) = _logprob(
+         marginal_rv_node.op,
+         (y_vv,),
+         *marginal_rv_node.inputs,
+     )
+
+     ref_logp = pm.logp(pm.NormalMixture.dist(w=[0.3, 0.7], mu=mu, sigma=1.0), y_vv)
+     np.testing.assert_almost_equal(
+         logp.eval({mu: [-1, 1], y_vv: 2}),
+         ref_logp.eval({mu: [-1, 1], y_vv: 2}),
+     )
+
+
+ @pytest.mark.parametrize("batch_chain", (False, True), ids=lambda x: f"batch_chain={x}")
+ @pytest.mark.parametrize("batch_emission", (False, True), ids=lambda x: f"batch_emission={x}")
+ def test_marginalized_hmm_normal_emission(batch_chain, batch_emission):
+     if batch_chain and not batch_emission:
+         pytest.skip("Redundant implicit combination")
+
+     with MarginalModel() as m:
+         P = [[0, 1], [1, 0]]
+         init_dist = pm.Categorical.dist(p=[1, 0])
+         chain = DiscreteMarkovChain(
+             "chain", P=P, init_dist=init_dist, steps=3, shape=(3, 4) if batch_chain else None
+         )
+         emission = pm.Normal(
+             "emission", mu=chain * 2 - 1, sigma=1e-1, shape=(3, 4) if batch_emission else None
+         )
+
+         m.marginalize([chain])
+         logp_fn = m.compile_logp()
+
+     test_value = np.array([-1, 1, -1, 1])
+     expected_logp = pm.logp(pm.Normal.dist(0, 1e-1), np.zeros_like(test_value)).sum().eval()
+     if batch_emission:
+         test_value = np.broadcast_to(test_value, (3, 4))
+         expected_logp *= 3
+     np.testing.assert_allclose(logp_fn({"emission": test_value}), expected_logp)
+
+
+ @pytest.mark.parametrize(
+     "categorical_emission",
+     [False, True],
+ )
+ def test_marginalized_hmm_categorical_emission(categorical_emission):
+     """Example adapted from https://www.youtube.com/watch?v=9-sPm4CfcD0"""
+     with MarginalModel() as m:
+         P = np.array([[0.5, 0.5], [0.3, 0.7]])
+         init_dist = pm.Categorical.dist(p=[0.375, 0.625])
+         chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, steps=2)
+         if categorical_emission:
+             emission = pm.Categorical("emission", p=pt.constant([[0.8, 0.2], [0.4, 0.6]])[chain])
+         else:
+             emission = pm.Bernoulli("emission", p=pt.where(pt.eq(chain, 0), 0.2, 0.6))
+         m.marginalize([chain])
+
+     test_value = np.array([0, 0, 1])
+     expected_logp = np.log(0.1344)  # Shown at the 10m22s mark in the video
+     logp_fn = m.compile_logp()
+     np.testing.assert_allclose(logp_fn({"emission": test_value}), expected_logp)
+
+
+ @pytest.mark.parametrize("batch_chain", (False, True))
+ @pytest.mark.parametrize("batch_emission1", (False, True))
+ @pytest.mark.parametrize("batch_emission2", (False, True))
+ def test_marginalized_hmm_multiple_emissions(batch_chain, batch_emission1, batch_emission2):
+     chain_shape = (3, 1, 4) if batch_chain else (4,)
+     emission1_shape = (
+         (2, *reversed(chain_shape)) if batch_emission1 else tuple(reversed(chain_shape))
+     )
+     emission2_shape = (*chain_shape, 2) if batch_emission2 else chain_shape
+     with MarginalModel() as m:
+         P = [[0, 1], [1, 0]]
+         init_dist = pm.Categorical.dist(p=[1, 0])
+         chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, shape=chain_shape)
+         emission_1 = pm.Normal(
+             "emission_1", mu=(chain * 2 - 1).T, sigma=1e-1, shape=emission1_shape
+         )
+
+         emission2_mu = (1 - chain) * 2 - 1
+         if batch_emission2:
+             emission2_mu = emission2_mu[..., None]
+         emission_2 = pm.Normal("emission_2", mu=emission2_mu, sigma=1e-1, shape=emission2_shape)
+
+     with pytest.warns(UserWarning, match="multiple dependent variables"):
+         m.marginalize([chain])
+
+     logp_fn = m.compile_logp(sum=False)
+
+     test_value = np.array([-1, 1, -1, 1])
+     multiplier = 2 + batch_emission1 + batch_emission2
+     if batch_chain:
+         multiplier *= 3
+     expected_logp = norm.logpdf(np.zeros_like(test_value), 0, 1e-1).sum() * multiplier
+
+     test_value = np.broadcast_to(test_value, chain_shape)
+     test_value_emission1 = np.broadcast_to(test_value.T, emission1_shape)
+     if batch_emission2:
+         test_value_emission2 = np.broadcast_to(-test_value[..., None], emission2_shape)
+     else:
+         test_value_emission2 = np.broadcast_to(-test_value, emission2_shape)
+     test_point = {"emission_1": test_value_emission1, "emission_2": test_value_emission2}
+     res_logp, dummy_logp = logp_fn(test_point)
+     assert res_logp.shape == ((1, 3) if batch_chain else ())
+     np.testing.assert_allclose(res_logp.sum(), expected_logp)
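As a usage-level complement to the logp checks above, a minimal sketch of the MarginalModel workflow; the observed values and the sampling call are illustrative only:

import pymc as pm

from pymc_extras import MarginalModel

with MarginalModel() as m:
    idx = pm.Bernoulli("idx", 0.7)
    mu = pm.Normal("mu", mu=[-1.0, 1.0], sigma=1.0)
    y = pm.Normal("y", mu=mu[idx], sigma=1.0, observed=[0.9])
    # Summing the finite discrete idx out of the logp leaves a model
    # with only continuous free variables, which NUTS can sample directly
    m.marginalize([idx])
    idata = pm.sample()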
tests/model/marginal/test_graph_analysis.py
@@ -0,0 +1,182 @@
+ import pytensor.tensor as pt
+ import pytest
+
+ from pymc.distributions import CustomDist
+ from pytensor.tensor.type_other import NoneTypeT
+
+ from pymc_extras.model.marginal.graph_analysis import (
+     is_conditional_dependent,
+     subgraph_batch_dim_connection,
+ )
+
+
+ def test_is_conditional_dependent_static_shape():
+     """Test that we don't consider dependencies through "constant" shape Ops"""
+     x1 = pt.matrix("x1", shape=(None, 5))
+     y1 = pt.random.normal(size=pt.shape(x1))
+     assert is_conditional_dependent(y1, x1, [x1, y1])
+
+     x2 = pt.matrix("x2", shape=(9, 5))
+     y2 = pt.random.normal(size=pt.shape(x2))
+     assert not is_conditional_dependent(y2, x2, [x2, y2])
+
+
+ class TestSubgraphBatchDimConnection:
+     def test_dimshuffle(self):
+         inp = pt.tensor(shape=(5, 1, 4, 3))
+         out1 = pt.matrix_transpose(inp)
+         out2 = pt.expand_dims(inp, 1)
+         out3 = pt.squeeze(inp)
+         [dims1, dims2, dims3] = subgraph_batch_dim_connection(inp, [out1, out2, out3])
+         assert dims1 == (0, 1, 3, 2)
+         assert dims2 == (0, None, 1, 2, 3)
+         assert dims3 == (0, 2, 3)
+
+     def test_careduce(self):
+         inp = pt.tensor(shape=(4, 3, 2))
+
+         out = pt.sum(inp[:, None], axis=(1,))
+         [dims] = subgraph_batch_dim_connection(inp, [out])
+         assert dims == (0, 1, 2)
+
+         invalid_out = pt.sum(inp, axis=(1,))
+         with pytest.raises(ValueError, match="Use of known dimensions"):
+             subgraph_batch_dim_connection(inp, [invalid_out])
+
+     def test_subtensor(self):
+         inp = pt.tensor(shape=(4, 3, 2))
+
+         invalid_out = inp[0, :1]
+         with pytest.raises(
+             ValueError,
+             match="Partial slicing or indexing of known dimensions not supported",
+         ):
+             subgraph_batch_dim_connection(inp, [invalid_out])
+
+         # Selecting dummy / unknown dimensions is fine
+         valid_out = pt.expand_dims(inp, (0, 1))[0, :1]
+         [dims] = subgraph_batch_dim_connection(inp, [valid_out])
+         assert dims == (None, 0, 1, 2)
+
+     def test_advanced_subtensor_value(self):
+         inp = pt.tensor(shape=(2, 4))
+         intermediate_out = inp[:, None, :, None] + pt.zeros((2, 3, 4, 5))
+
+         # Index on an unlabeled dim introduced by broadcasting with zeros
+         out = intermediate_out[:, [0, 0, 1, 2]]
+         [dims] = subgraph_batch_dim_connection(inp, [out])
+         assert dims == (0, None, 1, None)
+
+         # Indexing that introduces more dimensions
+         out = intermediate_out[:, [[0, 0], [1, 2]], :]
+         [dims] = subgraph_batch_dim_connection(inp, [out])
+         assert dims == (0, None, None, 1, None)
+
+         # Special case where advanced dims are moved to the front of the output
+         out = intermediate_out[:, [0, 0, 1, 2], :, 0]
+         [dims] = subgraph_batch_dim_connection(inp, [out])
+         assert dims == (None, 0, 1)
+
+         # Indexing on a labeled dim fails
+         out = intermediate_out[:, :, [0, 0, 1, 2]]
+         with pytest.raises(ValueError, match="Partial slicing or advanced integer indexing"):
+             subgraph_batch_dim_connection(inp, [out])
+
+     def test_advanced_subtensor_key(self):
+         inp = pt.tensor(shape=(5, 5), dtype=int)
+         base = pt.zeros((2, 3, 4))
+
+         out = base[inp]
+         [dims] = subgraph_batch_dim_connection(inp, [out])
+         assert dims == (0, 1, None, None)
+
+         out = base[:, :, inp]
+         [dims] = subgraph_batch_dim_connection(inp, [out])
+         assert dims == (None, None, 0, 1)
+
+         out = base[1:, 0, inp]
+         [dims] = subgraph_batch_dim_connection(inp, [out])
+         assert dims == (None, 0, 1)
+
+         # Special case where advanced dims are moved to the front of the output
+         out = base[0, :, inp]
+         [dims] = subgraph_batch_dim_connection(inp, [out])
+         assert dims == (0, 1, None)
+
+         # Mixing known dimensions across keys fails
+         out = base[:, inp, inp.T]
+         with pytest.raises(ValueError, match="Different known dimensions mixed via broadcasting"):
+             subgraph_batch_dim_connection(inp, [out])
+
+     def test_elemwise(self):
+         inp = pt.tensor(shape=(5, 5))
+
+         out = inp + inp
+         [dims] = subgraph_batch_dim_connection(inp, [out])
+         assert dims == (0, 1)
+
+         out = inp + inp.T
+         with pytest.raises(ValueError, match="Different known dimensions mixed via broadcasting"):
+             subgraph_batch_dim_connection(inp, [out])
+
+         out = inp[None, :, None, :] + inp[:, None, :, None]
+         with pytest.raises(
+             ValueError, match="Same known dimension used in different axis after broadcasting"
+         ):
+             subgraph_batch_dim_connection(inp, [out])
+
+     def test_blockwise(self):
+         inp = pt.tensor(shape=(5, 4))
+
+         invalid_out = inp @ pt.ones((4, 3))
+         with pytest.raises(ValueError, match="Use of known dimensions"):
+             subgraph_batch_dim_connection(inp, [invalid_out])
+
+         out = (inp[:, :, None, None] + pt.zeros((2, 3))) @ pt.ones((2, 3))
+         [dims] = subgraph_batch_dim_connection(inp, [out])
+         assert dims == (0, 1, None, None)
+
+     def test_random_variable(self):
+         inp = pt.tensor(shape=(5, 4, 3))
+
+         out1 = pt.random.normal(loc=inp)
+         out2 = pt.random.categorical(p=inp[..., None])
+         out3 = pt.random.multivariate_normal(mean=inp[..., None], cov=pt.eye(1))
+         [dims1, dims2, dims3] = subgraph_batch_dim_connection(inp, [out1, out2, out3])
+         assert dims1 == (0, 1, 2)
+         assert dims2 == (0, 1, 2)
+         assert dims3 == (0, 1, 2, None)
+
+         invalid_out = pt.random.categorical(p=inp)
+         with pytest.raises(ValueError, match="Use of known dimensions"):
+             subgraph_batch_dim_connection(inp, [invalid_out])
+
+         invalid_out = pt.random.multivariate_normal(mean=inp, cov=pt.eye(3))
+         with pytest.raises(ValueError, match="Use of known dimensions"):
+             subgraph_batch_dim_connection(inp, [invalid_out])
+
+     def test_symbolic_random_variable(self):
+         inp = pt.tensor(shape=(4, 3, 2))
+
+         # Test univariate
+         out = CustomDist.dist(
+             inp,
+             dist=lambda mu, size: pt.random.normal(loc=mu, size=size),
+         )
+         [dims] = subgraph_batch_dim_connection(inp, [out])
+         assert dims == (0, 1, 2)
+
+         # Test multivariate
+         def dist(mu, size):
+             if isinstance(size.type, NoneTypeT):
+                 size = mu.shape
+             return pt.random.normal(loc=mu[..., None], size=(*size, 2))
+
+         out = CustomDist.dist(inp, dist=dist, size=(4, 3, 2), signature="()->(2)")
+         [dims] = subgraph_batch_dim_connection(inp, [out])
+         assert dims == (0, 1, 2, None)
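For reference, the convention asserted throughout this class: subgraph_batch_dim_connection(input, outputs) returns one tuple per output, mapping each output axis to the input axis whose batch dimension it carries, with None marking output axes that do not correspond to any input dimension. A minimal sketch restating one case from test_dimshuffle above:

import pytensor.tensor as pt

from pymc_extras.model.marginal.graph_analysis import subgraph_batch_dim_connection

inp = pt.tensor(shape=(5, 1, 4, 3))
[dims] = subgraph_batch_dim_connection(inp, [pt.expand_dims(inp, 1)])
# The inserted axis is unknown (None); the remaining axes map back to inp's axes 0-3
assert dims == (0, None, 1, 2, 3)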