pymc-extras 0.2.4__py3-none-any.whl → 0.2.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pymc_extras/__init__.py +6 -4
- pymc_extras/distributions/__init__.py +2 -0
- pymc_extras/distributions/continuous.py +3 -2
- pymc_extras/distributions/discrete.py +3 -1
- pymc_extras/distributions/transforms/__init__.py +3 -0
- pymc_extras/distributions/transforms/partial_order.py +227 -0
- pymc_extras/inference/__init__.py +4 -2
- pymc_extras/inference/find_map.py +62 -17
- pymc_extras/inference/fit.py +6 -4
- pymc_extras/inference/laplace.py +14 -8
- pymc_extras/inference/pathfinder/lbfgs.py +49 -13
- pymc_extras/inference/pathfinder/pathfinder.py +89 -103
- pymc_extras/statespace/core/statespace.py +191 -52
- pymc_extras/statespace/filters/distributions.py +15 -16
- pymc_extras/statespace/filters/kalman_filter.py +1 -18
- pymc_extras/statespace/filters/kalman_smoother.py +2 -6
- pymc_extras/statespace/models/ETS.py +10 -0
- pymc_extras/statespace/models/SARIMAX.py +26 -5
- pymc_extras/statespace/models/VARMAX.py +12 -2
- pymc_extras/statespace/models/structural.py +18 -5
- pymc_extras/statespace/utils/data_tools.py +24 -9
- pymc_extras-0.2.6.dist-info/METADATA +318 -0
- pymc_extras-0.2.6.dist-info/RECORD +65 -0
- {pymc_extras-0.2.4.dist-info → pymc_extras-0.2.6.dist-info}/WHEEL +1 -2
- pymc_extras/version.py +0 -11
- pymc_extras/version.txt +0 -1
- pymc_extras-0.2.4.dist-info/METADATA +0 -110
- pymc_extras-0.2.4.dist-info/RECORD +0 -105
- pymc_extras-0.2.4.dist-info/top_level.txt +0 -2
- tests/__init__.py +0 -13
- tests/distributions/__init__.py +0 -19
- tests/distributions/test_continuous.py +0 -185
- tests/distributions/test_discrete.py +0 -210
- tests/distributions/test_discrete_markov_chain.py +0 -258
- tests/distributions/test_multivariate.py +0 -304
- tests/model/__init__.py +0 -0
- tests/model/marginal/__init__.py +0 -0
- tests/model/marginal/test_distributions.py +0 -132
- tests/model/marginal/test_graph_analysis.py +0 -182
- tests/model/marginal/test_marginal_model.py +0 -967
- tests/model/test_model_api.py +0 -38
- tests/statespace/__init__.py +0 -0
- tests/statespace/test_ETS.py +0 -411
- tests/statespace/test_SARIMAX.py +0 -405
- tests/statespace/test_VARMAX.py +0 -184
- tests/statespace/test_coord_assignment.py +0 -116
- tests/statespace/test_distributions.py +0 -270
- tests/statespace/test_kalman_filter.py +0 -326
- tests/statespace/test_representation.py +0 -175
- tests/statespace/test_statespace.py +0 -872
- tests/statespace/test_statespace_JAX.py +0 -156
- tests/statespace/test_structural.py +0 -836
- tests/statespace/utilities/__init__.py +0 -0
- tests/statespace/utilities/shared_fixtures.py +0 -9
- tests/statespace/utilities/statsmodel_local_level.py +0 -42
- tests/statespace/utilities/test_helpers.py +0 -310
- tests/test_blackjax_smc.py +0 -222
- tests/test_find_map.py +0 -103
- tests/test_histogram_approximation.py +0 -109
- tests/test_laplace.py +0 -265
- tests/test_linearmodel.py +0 -208
- tests/test_model_builder.py +0 -306
- tests/test_pathfinder.py +0 -203
- tests/test_pivoted_cholesky.py +0 -24
- tests/test_printing.py +0 -98
- tests/test_prior_from_trace.py +0 -172
- tests/test_splines.py +0 -77
- tests/utils.py +0 -0
- {pymc_extras-0.2.4.dist-info → pymc_extras-0.2.6.dist-info/licenses}/LICENSE +0 -0
File without changes

tests/statespace/utilities/statsmodel_local_level.py
DELETED
@@ -1,42 +0,0 @@
-import numpy as np
-import statsmodels.api as sm
-
-
-class LocalLinearTrend(sm.tsa.statespace.MLEModel):
-    def __init__(self, endog, **kwargs):
-        # Model order
-        k_states = k_posdef = 2
-
-        # Initialize the statespace
-        super().__init__(endog, k_states=k_states, k_posdef=k_posdef, **kwargs)
-
-        # Initialize the matrices
-        self.ssm["design"] = np.array([1, 0])
-        self.ssm["transition"] = np.array([[1, 1], [0, 1]])
-        self.ssm["selection"] = np.eye(k_states)
-
-        # Cache some indices
-        self._state_cov_idx = ("state_cov", *np.diag_indices(k_posdef))
-
-    @property
-    def param_names(self):
-        return ["sigma2.measurement", "sigma2.level", "sigma2.trend"]
-
-    @property
-    def start_params(self):
-        return [np.std(self.endog)] * 3
-
-    def transform_params(self, unconstrained):
-        return unconstrained**2
-
-    def untransform_params(self, constrained):
-        return constrained**0.5
-
-    def update(self, params, *args, **kwargs):
-        params = super().update(params, *args, **kwargs)
-
-        # Observation covariance
-        self.ssm["obs_cov", 0, 0] = params[0]
-
-        # State covariance
-        self.ssm[self._state_cov_idx] = params[1:]
tests/statespace/utilities/test_helpers.py
DELETED

@@ -1,310 +0,0 @@
-import numpy as np
-import pandas as pd
-import pytensor
-import pytensor.tensor as pt
-import statsmodels.api as sm
-
-from numpy.testing import assert_allclose
-from pymc import modelcontext
-
-from pymc_extras.statespace.filters.kalman_smoother import KalmanSmoother
-from pymc_extras.statespace.utils.constants import (
-    MATRIX_NAMES,
-    SHORT_NAME_TO_LONG,
-)
-from tests.statespace.utilities.statsmodel_local_level import LocalLinearTrend
-
-floatX = pytensor.config.floatX
-
-
-def load_nile_test_data():
-    from importlib.metadata import version
-
-    nile = pd.read_csv("tests/statespace/test_data/nile.csv", dtype={"x": floatX})
-    major, minor, rev = map(int, version("pandas").split("."))
-    if major >= 2 and minor >= 2 and rev >= 0:
-        freq_str = "YS-JAN"
-    else:
-        freq_str = "AS-JAN"
-    nile.index = pd.date_range(start="1871-01-01", end="1970-01-01", freq=freq_str)
-    nile.rename(columns={"x": "height"}, inplace=True)
-    nile = (nile - nile.mean()) / nile.std()
-    nile = nile.astype(floatX)
-
-    return nile
-
-
-def initialize_filter(kfilter, mode=None, p=None, m=None, r=None, n=None):
-    ksmoother = KalmanSmoother()
-    data = pt.tensor(name="data", dtype=floatX, shape=(n, p))
-    a0 = pt.tensor(name="x0", dtype=floatX, shape=(m,))
-    P0 = pt.tensor(name="P0", dtype=floatX, shape=(m, m))
-    c = pt.tensor(name="c", dtype=floatX, shape=(m,))
-    d = pt.tensor(name="d", dtype=floatX, shape=(p,))
-    Q = pt.tensor(name="Q", dtype=floatX, shape=(r, r))
-    H = pt.tensor(name="H", dtype=floatX, shape=(p, p))
-    T = pt.tensor(name="T", dtype=floatX, shape=(m, m))
-    R = pt.tensor(name="R", dtype=floatX, shape=(m, r))
-    Z = pt.tensor(name="Z", dtype=floatX, shape=(p, m))
-
-    inputs = [data, a0, P0, c, d, T, Z, R, H, Q]
-
-    (
-        filtered_states,
-        predicted_states,
-        observed_states,
-        filtered_covs,
-        predicted_covs,
-        observed_covs,
-        ll_obs,
-    ) = kfilter.build_graph(*inputs, mode=mode)
-
-    smoothed_states, smoothed_covs = ksmoother.build_graph(T, R, Q, filtered_states, filtered_covs)
-
-    outputs = [
-        filtered_states,
-        predicted_states,
-        smoothed_states,
-        filtered_covs,
-        predicted_covs,
-        smoothed_covs,
-        ll_obs.sum(),
-        ll_obs,
-    ]
-
-    return inputs, outputs
-
-
-def add_missing_data(data, n_missing, rng):
-    n = data.shape[0]
-    missing_idx = rng.choice(n, n_missing, replace=False)
-    data[missing_idx] = np.nan
-
-    return data
-
-
-def make_test_inputs(p, m, r, n, rng, missing_data=None, H_is_zero=False):
-    data = np.arange(n * p, dtype=floatX).reshape(-1, p)
-    if missing_data is not None:
-        data = add_missing_data(data, missing_data, rng)
-
-    a0 = np.zeros(m, dtype=floatX)
-    P0 = np.eye(m, dtype=floatX)
-    c = np.zeros(m, dtype=floatX)
-    d = np.zeros(p, dtype=floatX)
-    Q = np.eye(r, dtype=floatX)
-    H = np.zeros((p, p), dtype=floatX) if H_is_zero else np.eye(p, dtype=floatX)
-    T = np.eye(m, k=-1, dtype=floatX)
-    T[0, :] = 1 / m
-    R = np.eye(m, dtype=floatX)[:, :r]
-    Z = np.eye(m, dtype=floatX)[:p, :]
-
-    data, a0, P0, c, d, T, Z, R, H, Q = map(
-        np.ascontiguousarray, [data, a0, P0, c, d, T, Z, R, H, Q]
-    )
-
-    return data, a0, P0, c, d, T, Z, R, H, Q
-
-
-def get_expected_shape(name, p, m, r, n):
-    if name == "log_likelihood":
-        return ()
-    elif name == "ll_obs":
-        return (n,)
-    filter_type, variable = name.split("_")
-    if variable == "states":
-        return n, m
-    if variable == "covs":
-        return n, m, m
-
-
-def get_sm_state_from_output_name(res, name):
-    if name == "log_likelihood":
-        return res.llf
-    elif name == "ll_obs":
-        return res.llf_obs
-
-    filter_type, variable = name.split("_")
-    sm_states = getattr(res, "states")
-
-    if variable == "states":
-        return getattr(sm_states, filter_type)
-    if variable == "covs":
-        m = res.filter_results.k_states
-        # remove the "s" from "covs"
-        return getattr(sm_states, name[:-1]).reshape(-1, m, m)
-
-
-def nile_test_test_helper(rng, n_missing=0):
-    a0 = np.zeros(2, dtype=floatX)
-    P0 = np.eye(2, dtype=floatX) * 1e6
-    c = np.zeros(2, dtype=floatX)
-    d = np.zeros(1, dtype=floatX)
-    Q = np.eye(2, dtype=floatX) * np.array([0.5, 0.01], dtype=floatX)
-    H = np.eye(1, dtype=floatX) * 0.8
-    T = np.array([[1.0, 1.0], [0.0, 1.0]], dtype=floatX)
-    R = np.eye(2, dtype=floatX)
-    Z = np.array([[1.0, 0.0]], dtype=floatX)
-
-    data = load_nile_test_data().values
-    if n_missing > 0:
-        data = add_missing_data(data, n_missing, rng)
-
-    sm_model = LocalLinearTrend(
-        endog=data,
-        initialization="known",
-        initial_state_cov=P0,
-        initial_state=a0.ravel(),
-    )
-
-    res = sm_model.fit_constrained(
-        constraints={
-            "sigma2.measurement": 0.8,
-            "sigma2.level": 0.5,
-            "sigma2.trend": 0.01,
-        }
-    )
-
-    inputs = [data, a0, P0, c, d, T, Z, R, H, Q]
-
-    return res, inputs
-
-
-def fast_eval(var):
-    return pytensor.function([], var, mode="FAST_COMPILE")()
-
-
-def delete_rvs_from_model(rv_names: list[str]) -> None:
-    """Remove all model mappings referring to rv
-
-    This can be used to "delete" an RV from a model
-    """
-    mod = modelcontext(None)
-    all_rvs = mod.basic_RVs + mod.deterministics
-    all_rv_names = [x.name for x in all_rvs]
-
-    for name in rv_names:
-        assert name in all_rv_names, f"{name} is not part of the Model: {all_rv_names}"
-
-        rv_idx = all_rv_names.index(name)
-        rv = all_rvs[rv_idx]
-
-        mod.named_vars.pop(name)
-        if name in mod.named_vars_to_dims:
-            mod.named_vars_to_dims.pop(name)
-
-        if rv in mod.deterministics:
-            mod.deterministics.remove(rv)
-            continue
-
-        value = mod.rvs_to_values.pop(rv)
-        mod.values_to_rvs.pop(value)
-        mod.rvs_to_transforms.pop(rv)
-        if rv in mod.free_RVs:
-            mod.free_RVs.remove(rv)
-            mod.rvs_to_initial_values.pop(rv)
-        else:
-            mod.observed_RVs.remove(rv)
-
-
-def unpack_statespace(ssm):
-    return [ssm[SHORT_NAME_TO_LONG[x]] for x in MATRIX_NAMES]
-
-
-def unpack_symbolic_matrices_with_params(mod, param_dict, data_dict=None, mode="FAST_COMPILE"):
-    inputs = list(mod._name_to_variable.values())
-    if data_dict is not None:
-        inputs += list(mod._name_to_data.values())
-    else:
-        data_dict = {}
-
-    f_matrices = pytensor.function(
-        inputs,
-        unpack_statespace(mod.ssm),
-        on_unused_input="raise",
-        mode=mode,
-    )
-
-    x0, P0, c, d, T, Z, R, H, Q = f_matrices(**param_dict, **data_dict)
-
-    return x0, P0, c, d, T, Z, R, H, Q
-
-
-def simulate_from_numpy_model(mod, rng, param_dict, data_dict=None, steps=100):
-    """
-    Helper function to visualize the components outside of a PyMC model context
-    """
-    x0, P0, c, d, T, Z, R, H, Q = unpack_symbolic_matrices_with_params(mod, param_dict, data_dict)
-    k_states = mod.k_states
-    k_posdef = mod.k_posdef
-
-    x = np.zeros((steps, k_states))
-    y = np.zeros(steps)
-
-    x[0] = x0
-    y[0] = (Z @ x0).squeeze() if Z.ndim == 2 else (Z[0] @ x0).squeeze()
-
-    if not np.allclose(H, 0):
-        y[0] += rng.multivariate_normal(mean=np.zeros(1), cov=H).squeeze()
-
-    for t in range(1, steps):
-        if k_posdef > 0:
-            shock = rng.multivariate_normal(mean=np.zeros(k_posdef), cov=Q)
-            innov = R @ shock
-        else:
-            innov = 0
-
-        if not np.allclose(H, 0):
-            error = rng.multivariate_normal(mean=np.zeros(1), cov=H)
-        else:
-            error = 0
-
-        x[t] = c + T @ x[t - 1] + innov
-        if Z.ndim == 2:
-            y[t] = (d + Z @ x[t] + error).squeeze()
-        else:
-            y[t] = (d + Z[t] @ x[t] + error).squeeze()
-
-    return x, y
-
-
-def assert_pattern_repeats(y, T, atol, rtol):
-    val = np.diff(y.reshape(-1, T), axis=0)
-    if floatX.endswith("64"):
-        # Round this before going into the test, otherwise it behaves poorly (atol = inf)
-        n_digits = len(str(1 / atol))
-        val = np.round(val, n_digits)
-
-    assert_allclose(
-        val,
-        0,
-        err_msg="seasonal pattern does not repeat",
-        atol=atol,
-        rtol=rtol,
-    )
-
-
-def make_stationary_params(data, p, d, q, P, D, Q, S):
-    sm_sarimax = sm.tsa.SARIMAX(data, order=(p, d, q), seasonal_order=(P, D, Q, S))
-    res = sm_sarimax.fit(disp=False)
-
-    param_dict = dict(ar_params=[], ma_params=[], seasonal_ar_params=[], seasonal_ma_params=[])
-
-    for name, param in zip(res.param_names, res.params):
-        if name.startswith("ar.S"):
-            param_dict["seasonal_ar_params"].append(param)
-        elif name.startswith("ma.S"):
-            param_dict["seasonal_ma_params"].append(param)
-        elif name.startswith("ar."):
-            param_dict["ar_params"].append(param)
-        elif name.startswith("ma."):
-            param_dict["ma_params"].append(param)
-        else:
-            param_dict["sigma_state"] = param
-
-    param_dict = {
-        k: np.array(v, dtype=floatX)
-        for k, v in param_dict.items()
-        if isinstance(v, float) or len(v) > 0
-    }
-    return param_dict
tests/test_blackjax_smc.py
DELETED
@@ -1,222 +0,0 @@
-# Copyright 2023 The PyMC Developers
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import numpy as np
-import pymc as pm
-import pytensor.tensor as pt
-import pytest
-import scipy
-
-from numpy import dtype
-from xarray.core.utils import Frozen
-
-jax = pytest.importorskip("jax")
-pytest.importorskip("blackjax")
-
-from pymc_extras.inference.smc.sampling import (
-    arviz_from_particles,
-    blackjax_particles_from_pymc_population,
-    get_jaxified_loglikelihood,
-    get_jaxified_logprior,
-    sample_smc_blackjax,
-)
-
-
-def two_gaussians_model():
-    n = 4
-    mu1 = np.ones(n) * 0.5
-    mu2 = -mu1
-
-    stdev = 0.1
-    sigma = np.power(stdev, 2) * np.eye(n)
-    isigma = np.linalg.inv(sigma)
-    dsigma = np.linalg.det(sigma)
-
-    w1 = stdev
-    w2 = 1 - stdev
-
-    def two_gaussians(x):
-        """
-        Mixture of gaussians likelihood
-        """
-        log_like1 = (
-            -0.5 * n * pt.log(2 * np.pi)
-            - 0.5 * pt.log(dsigma)
-            - 0.5 * (x - mu1).T.dot(isigma).dot(x - mu1)
-        )
-        log_like2 = (
-            -0.5 * n * pt.log(2 * np.pi)
-            - 0.5 * pt.log(dsigma)
-            - 0.5 * (x - mu2).T.dot(isigma).dot(x - mu2)
-        )
-        return pt.log(w1 * pt.exp(log_like1) + w2 * pt.exp(log_like2))
-
-    with pm.Model() as m:
-        X = pm.Uniform("X", lower=-2, upper=2.0, shape=n)
-        llk = pm.Potential("muh", two_gaussians(X))
-
-    return m, mu1
-
-
-def fast_model():
-    with pm.Model() as m:
-        x = pm.Normal("x", 0, 1)
-        y = pm.Normal("y", x, 1, observed=0)
-    return m
-
-
-@pytest.mark.parametrize(
-    "kernel, check_for_integration_steps, inner_kernel_params",
-    [
-        ("HMC", True, {"step_size": 0.1, "integration_steps": 11}),
-        ("NUTS", False, {"step_size": 0.1}),
-    ],
-)
-def test_sample_smc_blackjax(kernel, check_for_integration_steps, inner_kernel_params):
-    """
-    When running the two gaussians model
-    with BlackJax SMC, we sample them correctly,
-    the shape of a posterior variable is (1, particles, dimension)
-    and the inference_data has the right attributes.
-
-    """
-    model, muref = two_gaussians_model()
-    iterations_to_diagnose = 2
-    n_particles = 1000
-    with model:
-        inference_data = sample_smc_blackjax(
-            n_particles=n_particles,
-            kernel=kernel,
-            inner_kernel_params=inner_kernel_params,
-            iterations_to_diagnose=iterations_to_diagnose,
-        )
-
-    x = inference_data.posterior["X"]
-
-    assert x.to_numpy().shape == (1, n_particles, 4)
-    mu1d = np.abs(x).mean(axis=0).mean(axis=0)
-    np.testing.assert_allclose(muref, mu1d, rtol=0.0, atol=0.03)
-
-    for attribute, value in [
-        ("particles", n_particles),
-        ("step_size", 0.1),
-        ("num_mcmc_steps", 10),
-        ("iterations_to_diagnose", iterations_to_diagnose),
-        ("sampler", f"Blackjax SMC with {kernel} kernel"),
-    ]:
-        assert inference_data.posterior.attrs[attribute] == value
-
-    for diagnostic in ["lambda_evolution", "log_likelihood_increments"]:
-        assert inference_data.posterior.attrs[diagnostic].shape == (iterations_to_diagnose,)
-
-    for diagnostic in ["ancestors_evolution", "weights_evolution"]:
-        assert inference_data.posterior.attrs[diagnostic].shape == (
-            iterations_to_diagnose,
-            n_particles,
-        )
-
-    for attribute in ["running_time_seconds", "iterations"]:
-        assert attribute in inference_data.posterior.attrs
-
-    if check_for_integration_steps:
-        assert inference_data.posterior.attrs["integration_steps"] == 11
-
-
-def test_blackjax_particles_from_pymc_population_univariate():
-    model = fast_model()
-    population = {"x": np.array([2, 3, 4])}
-    blackjax_particles = blackjax_particles_from_pymc_population(model, population)
-    jax.tree.map(np.testing.assert_allclose, blackjax_particles, [np.array([[2], [3], [4]])])
-
-
-def test_blackjax_particles_from_pymc_population_multivariate():
-    with pm.Model() as model:
-        x = pm.Normal("x", 0, 1)
-        z = pm.Normal("z", 0, 1)
-        y = pm.Normal("y", x + z, 1, observed=0)
-
-    population = {"x": np.array([0.34614613, 1.09163261, -0.44526825]), "z": np.array([1, 2, 3])}
-    blackjax_particles = blackjax_particles_from_pymc_population(model, population)
-    jax.tree.map(
-        np.testing.assert_allclose,
-        blackjax_particles,
-        [np.array([[0.34614613], [1.09163261], [-0.44526825]]), np.array([[1], [2], [3]])],
-    )
-
-
-def simple_multivariable_model():
-    """
-    A simple model that has a multivariate variable,
-    and has more than one variable (multivariable)
-    """
-    with pm.Model() as model:
-        x = pm.Normal("x", 0, 1, shape=2)
-        z = pm.Normal("z", 0, 1)
-        y = pm.Normal("y", z, 1, observed=0)
-    return model
-
-
-def test_blackjax_particles_from_pymc_population_multivariable():
-    model = simple_multivariable_model()
-    population = {"x": np.array([[2, 3], [5, 6], [7, 9]]), "z": np.array([11, 12, 13])}
-    blackjax_particles = blackjax_particles_from_pymc_population(model, population)
-
-    jax.tree.map(
-        np.testing.assert_allclose,
-        blackjax_particles,
-        [np.array([[2, 3], [5, 6], [7, 9]]), np.array([[11], [12], [13]])],
-    )
-
-
-def test_arviz_from_particles():
-    model = simple_multivariable_model()
-    particles = [np.array([[2, 3], [5, 6], [7, 9]]), np.array([[11], [12], [13]])]
-    with model:
-        inference_data = arviz_from_particles(model, particles)
-
-    assert inference_data.posterior.sizes == Frozen({"chain": 1, "draw": 3, "x_dim_0": 2})
-    assert inference_data.posterior.data_vars.dtypes == Frozen(
-        {"x": dtype("float64"), "z": dtype("float64")}
-    )
-
-
-def test_get_jaxified_logprior():
-    """
-    Given a model with a Normal prior
-    for a RV, the jaxified logprior
-    indeed calculates that number,
-    and can be jax.vmap'ed
-    """
-    logprior = get_jaxified_logprior(fast_model())
-    for point in [-0.5, 0.0, 0.5]:
-        jax.tree.map(
-            np.testing.assert_allclose,
-            jax.vmap(logprior)([np.array([point])]),
-            np.log(scipy.stats.norm(0, 1).pdf(point)),
-        )
-
-
-def test_get_jaxified_loglikelihood():
-    """
-    Given a model with a Normal Likelihood, a single observation
-    0 and std=1, the only free parameter of that function is the mean.
-    When computing the loglikelihood
-    Then the function can be jax.vmap'ed, and the calculation matches the likelihood.
-    """
-    loglikelihood = get_jaxified_loglikelihood(fast_model())
-    for point in [-0.5, 0.0, 0.5]:
-        jax.tree.map(
-            np.testing.assert_allclose,
-            jax.vmap(loglikelihood)([np.array([point])]),
-            np.log(scipy.stats.norm(point, 1).pdf(0)),
-        )
tests/test_find_map.py
DELETED
@@ -1,103 +0,0 @@
-import numpy as np
-import pymc as pm
-import pytensor.tensor as pt
-import pytest
-
-from pymc_extras.inference.find_map import (
-    GradientBackend,
-    find_MAP,
-    scipy_optimize_funcs_from_loss,
-)
-
-pytest.importorskip("jax")
-
-
-@pytest.fixture(scope="session")
-def rng():
-    seed = sum(map(ord, "test_fit_map"))
-    return np.random.default_rng(seed)
-
-
-@pytest.mark.parametrize("gradient_backend", ["jax", "pytensor"], ids=str)
-def test_jax_functions_from_graph(gradient_backend: GradientBackend):
-    x = pt.tensor("x", shape=(2,))
-
-    def compute_z(x):
-        z1 = x[0] ** 2 + 2
-        z2 = x[0] * x[1] + 3
-        return z1, z2
-
-    z = pt.stack(compute_z(x))
-    f_loss, f_hess, f_hessp = scipy_optimize_funcs_from_loss(
-        loss=z.sum(),
-        inputs=[x],
-        initial_point_dict={"x": np.array([1.0, 2.0])},
-        use_grad=True,
-        use_hess=True,
-        use_hessp=True,
-        gradient_backend=gradient_backend,
-        compile_kwargs=dict(mode="JAX"),
-    )
-
-    x_val = np.array([1.0, 2.0])
-    expected_z = sum(compute_z(x_val))
-
-    z_jax, grad_val = f_loss(x_val)
-    np.testing.assert_allclose(z_jax, expected_z)
-    np.testing.assert_allclose(grad_val.squeeze(), np.array([2 * x_val[0] + x_val[1], x_val[0]]))
-
-    hess_val = np.array(f_hess(x_val))
-    np.testing.assert_allclose(hess_val.squeeze(), np.array([[2, 1], [1, 0]]))
-
-    hessp_val = np.array(f_hessp(x_val, np.array([1.0, 0.0])))
-    np.testing.assert_allclose(hessp_val.squeeze(), np.array([2, 1]))
-
-
-@pytest.mark.parametrize(
-    "method, use_grad, use_hess, use_hessp",
-    [
-        ("nelder-mead", False, False, False),
-        ("powell", False, False, False),
-        ("CG", True, False, False),
-        ("BFGS", True, False, False),
-        ("L-BFGS-B", True, False, False),
-        ("TNC", True, False, False),
-        ("SLSQP", True, False, False),
-        ("dogleg", True, True, False),
-        ("Newton-CG", True, True, False),
-        ("Newton-CG", True, False, True),
-        ("trust-ncg", True, True, False),
-        ("trust-ncg", True, False, True),
-        ("trust-exact", True, True, False),
-        ("trust-krylov", True, True, False),
-        ("trust-krylov", True, False, True),
-        ("trust-constr", True, True, False),
-    ],
-)
-@pytest.mark.parametrize("gradient_backend", ["jax", "pytensor"], ids=str)
-def test_JAX_map(method, use_grad, use_hess, use_hessp, gradient_backend: GradientBackend, rng):
-    extra_kwargs = {}
-    if method == "dogleg":
-        # HACK -- dogleg requires that the hessian of the objective function is PSD, so we have to pick a point
-        # where this is true
-        extra_kwargs = {"initvals": {"mu": 2, "sigma_log__": 1}}
-
-    with pm.Model() as m:
-        mu = pm.Normal("mu")
-        sigma = pm.Exponential("sigma", 1)
-        pm.Normal("y_hat", mu=mu, sigma=sigma, observed=rng.normal(loc=3, scale=1.5, size=100))
-
-        optimized_point = find_MAP(
-            method=method,
-            **extra_kwargs,
-            use_grad=use_grad,
-            use_hess=use_hess,
-            use_hessp=use_hessp,
-            progressbar=False,
-            gradient_backend=gradient_backend,
-            compile_kwargs={"mode": "JAX"},
-        )
-    mu_hat, log_sigma_hat = optimized_point["mu"], optimized_point["sigma_log__"]
-
-    assert np.isclose(mu_hat, 3, atol=0.5)
-    assert np.isclose(np.exp(log_sigma_hat), 1.5, atol=0.5)