pymc-extras 0.2.4__py3-none-any.whl → 0.2.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pymc_extras/__init__.py +6 -4
- pymc_extras/distributions/__init__.py +2 -0
- pymc_extras/distributions/continuous.py +3 -2
- pymc_extras/distributions/discrete.py +3 -1
- pymc_extras/distributions/transforms/__init__.py +3 -0
- pymc_extras/distributions/transforms/partial_order.py +227 -0
- pymc_extras/inference/__init__.py +4 -2
- pymc_extras/inference/find_map.py +62 -17
- pymc_extras/inference/fit.py +6 -4
- pymc_extras/inference/laplace.py +14 -8
- pymc_extras/inference/pathfinder/lbfgs.py +49 -13
- pymc_extras/inference/pathfinder/pathfinder.py +89 -103
- pymc_extras/statespace/core/statespace.py +191 -52
- pymc_extras/statespace/filters/distributions.py +15 -16
- pymc_extras/statespace/filters/kalman_filter.py +1 -18
- pymc_extras/statespace/filters/kalman_smoother.py +2 -6
- pymc_extras/statespace/models/ETS.py +10 -0
- pymc_extras/statespace/models/SARIMAX.py +26 -5
- pymc_extras/statespace/models/VARMAX.py +12 -2
- pymc_extras/statespace/models/structural.py +18 -5
- pymc_extras/statespace/utils/data_tools.py +24 -9
- pymc_extras-0.2.6.dist-info/METADATA +318 -0
- pymc_extras-0.2.6.dist-info/RECORD +65 -0
- {pymc_extras-0.2.4.dist-info → pymc_extras-0.2.6.dist-info}/WHEEL +1 -2
- pymc_extras/version.py +0 -11
- pymc_extras/version.txt +0 -1
- pymc_extras-0.2.4.dist-info/METADATA +0 -110
- pymc_extras-0.2.4.dist-info/RECORD +0 -105
- pymc_extras-0.2.4.dist-info/top_level.txt +0 -2
- tests/__init__.py +0 -13
- tests/distributions/__init__.py +0 -19
- tests/distributions/test_continuous.py +0 -185
- tests/distributions/test_discrete.py +0 -210
- tests/distributions/test_discrete_markov_chain.py +0 -258
- tests/distributions/test_multivariate.py +0 -304
- tests/model/__init__.py +0 -0
- tests/model/marginal/__init__.py +0 -0
- tests/model/marginal/test_distributions.py +0 -132
- tests/model/marginal/test_graph_analysis.py +0 -182
- tests/model/marginal/test_marginal_model.py +0 -967
- tests/model/test_model_api.py +0 -38
- tests/statespace/__init__.py +0 -0
- tests/statespace/test_ETS.py +0 -411
- tests/statespace/test_SARIMAX.py +0 -405
- tests/statespace/test_VARMAX.py +0 -184
- tests/statespace/test_coord_assignment.py +0 -116
- tests/statespace/test_distributions.py +0 -270
- tests/statespace/test_kalman_filter.py +0 -326
- tests/statespace/test_representation.py +0 -175
- tests/statespace/test_statespace.py +0 -872
- tests/statespace/test_statespace_JAX.py +0 -156
- tests/statespace/test_structural.py +0 -836
- tests/statespace/utilities/__init__.py +0 -0
- tests/statespace/utilities/shared_fixtures.py +0 -9
- tests/statespace/utilities/statsmodel_local_level.py +0 -42
- tests/statespace/utilities/test_helpers.py +0 -310
- tests/test_blackjax_smc.py +0 -222
- tests/test_find_map.py +0 -103
- tests/test_histogram_approximation.py +0 -109
- tests/test_laplace.py +0 -265
- tests/test_linearmodel.py +0 -208
- tests/test_model_builder.py +0 -306
- tests/test_pathfinder.py +0 -203
- tests/test_pivoted_cholesky.py +0 -24
- tests/test_printing.py +0 -98
- tests/test_prior_from_trace.py +0 -172
- tests/test_splines.py +0 -77
- tests/utils.py +0 -0
- {pymc_extras-0.2.4.dist-info → pymc_extras-0.2.6.dist-info/licenses}/LICENSE +0 -0
pymc_extras/inference/pathfinder/lbfgs.py

@@ -37,11 +37,14 @@ class LBFGSHistoryManager:
         initial position
     maxiter : int
         maximum number of iterations to store
+    epsilon : float
+        tolerance for lbfgs update
     """

     value_grad_fn: Callable[[NDArray[np.float64]], tuple[np.float64, NDArray[np.float64]]]
     x0: NDArray[np.float64]
     maxiter: int
+    epsilon: float
     x_history: NDArray[np.float64] = field(init=False)
     g_history: NDArray[np.float64] = field(init=False)
     count: int = field(init=False)
@@ -52,7 +55,7 @@ class LBFGSHistoryManager:
         self.count = 0

         value, grad = self.value_grad_fn(self.x0)
-        if np.all(np.isfinite(grad)) and np.isfinite(value):
+        if self.entry_condition_met(self.x0, value, grad):
             self.add_entry(self.x0, grad)

     def add_entry(self, x: NDArray[np.float64], g: NDArray[np.float64]) -> None:
@@ -75,18 +78,39 @@ class LBFGSHistoryManager:
             x=self.x_history[: self.count], g=self.g_history[: self.count], count=self.count
         )

+    def entry_condition_met(self, x, value, grad) -> bool:
+        """Checks if the LBFGS iteration should continue."""
+
+        if np.all(np.isfinite(grad)) and np.isfinite(value) and (self.count < self.maxiter + 1):
+            if self.count == 0:
+                return True
+            else:
+                s = x - self.x_history[self.count - 1]
+                z = grad - self.g_history[self.count - 1]
+                sz = (s * z).sum(axis=-1)
+                update = sz > self.epsilon * np.sqrt(np.sum(z**2, axis=-1))
+
+                if update:
+                    return True
+                else:
+                    return False
+        else:
+            return False
+
     def __call__(self, x: NDArray[np.float64]) -> None:
         value, grad = self.value_grad_fn(x)
-        if np.all(np.isfinite(grad)) and np.isfinite(value):
+        if self.entry_condition_met(x, value, grad):
             self.add_entry(x, grad)


 class LBFGSStatus(Enum):
     CONVERGED = auto()
     MAX_ITER_REACHED = auto()
-
+    NON_FINITE = auto()
+    LOW_UPDATE_PCT = auto()
     # Statuses that lead to Exceptions:
     INIT_FAILED = auto()
+    INIT_FAILED_LOW_UPDATE_PCT = auto()
     LBFGS_FAILED = auto()

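Note: the new `entry_condition_met` centralises the acceptance test that was previously inlined in `__post_init__` and `__call__`: an iterate is stored only when the value and gradient are finite, the history buffer has room, and the curvature proxy s·z clears the `epsilon`-scaled gradient-change norm. A minimal NumPy sketch of that curvature test (the inputs here are illustrative, not taken from the diff):

```python
import numpy as np

def accept_update(s: np.ndarray, z: np.ndarray, epsilon: float = 1e-8) -> bool:
    # Mirrors the test in entry_condition_met: accept an iterate only if
    # s.z > epsilon * ||z||_2, i.e. the step carries usable curvature info.
    sz = (s * z).sum(axis=-1)
    return bool(sz > epsilon * np.sqrt(np.sum(z**2, axis=-1)))

s = np.array([0.5, -0.2])      # position difference between consecutive iterates
z_ok = np.array([0.4, -0.1])   # gradient difference; s.z = 0.22 -> accepted
z_bad = np.array([-0.4, 0.1])  # s.z = -0.22 -> rejected (no positive curvature)

assert accept_update(s, z_ok) and not accept_update(s, z_bad)
```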
@@ -101,8 +125,8 @@ class LBFGSException(Exception):
 class LBFGSInitFailed(LBFGSException):
     DEFAULT_MESSAGE = "LBFGS failed to initialise."

-    def __init__(self, message=None):
-        super().__init__(message or self.DEFAULT_MESSAGE)
+    def __init__(self, status: LBFGSStatus, message=None):
+        super().__init__(message or self.DEFAULT_MESSAGE, status)


 class LBFGS:
@@ -122,10 +146,12 @@ class LBFGS:
         gradient tolerance for convergence, defaults to 1e-8
     maxls : int, optional
         maximum number of line search steps, defaults to 1000
+    epsilon : float, optional
+        tolerance for lbfgs update, defaults to 1e-8
     """

     def __init__(
-        self, value_grad_fn, maxcor, maxiter=1000, ftol=1e-5, gtol=1e-8, maxls=1000
+        self, value_grad_fn, maxcor, maxiter=1000, ftol=1e-5, gtol=1e-8, maxls=1000, epsilon=1e-8
     ) -> None:
         self.value_grad_fn = value_grad_fn
         self.maxcor = maxcor

@@ -133,6 +159,7 @@ class LBFGS:
         self.ftol = ftol
         self.gtol = gtol
         self.maxls = maxls
+        self.epsilon = epsilon

     def minimize(self, x0) -> tuple[NDArray, NDArray, int, LBFGSStatus]:
         """minimizes objective function starting from initial position.

@@ -157,7 +184,7 @@ class LBFGS:
         x0 = np.array(x0, dtype=np.float64)

         history_manager = LBFGSHistoryManager(
-            value_grad_fn=self.value_grad_fn, x0=x0, maxiter=self.maxiter
+            value_grad_fn=self.value_grad_fn, x0=x0, maxiter=self.maxiter, epsilon=self.epsilon
         )

         result = minimize(
@@ -177,13 +204,22 @@ class LBFGS:
         history = history_manager.get_history()

         # warnings and suggestions for LBFGSStatus are displayed at the end
-        if
-
-
-
+        # threshold determining if the number of lbfgs updates is low compared to iterations
+        low_update_threshold = 3
+
+        if history.count <= 1:  # triggers LBFGSInitFailed
+            if result.nit < low_update_threshold:
                 lbfgs_status = LBFGSStatus.INIT_FAILED
-
-        lbfgs_status = LBFGSStatus.
+            else:
+                lbfgs_status = LBFGSStatus.INIT_FAILED_LOW_UPDATE_PCT
+        elif result.status == 1:
+            # (result.nit > maxiter) or (result.nit > maxls)
+            lbfgs_status = LBFGSStatus.MAX_ITER_REACHED
+        elif result.status == 2:
+            # precision loss resulting to inf or nan
+            lbfgs_status = LBFGSStatus.NON_FINITE
+        elif history.count * low_update_threshold < result.nit:
+            lbfgs_status = LBFGSStatus.LOW_UPDATE_PCT
         else:
             lbfgs_status = LBFGSStatus.CONVERGED

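Note: the rewritten status mapping combines scipy's `result.status` (which the new comments gloss as 1 = iteration/line-search limit, 2 = precision loss producing inf or nan) with how many iterates survived the history filter. The decision table, restated as a standalone sketch with hypothetical inputs:

```python
LOW_UPDATE_THRESHOLD = 3  # same constant as low_update_threshold in the diff

def classify(history_count: int, scipy_status: int, nit: int) -> str:
    # history_count: iterates accepted by LBFGSHistoryManager
    # scipy_status/nit: status code and iteration count from scipy.optimize.minimize
    if history_count <= 1:  # nothing usable beyond the initial point
        return "INIT_FAILED" if nit < LOW_UPDATE_THRESHOLD else "INIT_FAILED_LOW_UPDATE_PCT"
    elif scipy_status == 1:
        return "MAX_ITER_REACHED"
    elif scipy_status == 2:
        return "NON_FINITE"
    elif history_count * LOW_UPDATE_THRESHOLD < nit:
        return "LOW_UPDATE_PCT"  # most iterations were filtered out of the history
    return "CONVERGED"

assert classify(history_count=1, scipy_status=0, nit=2) == "INIT_FAILED"
assert classify(history_count=2, scipy_status=0, nit=10) == "LOW_UPDATE_PCT"
assert classify(history_count=9, scipy_status=0, nit=12) == "CONVERGED"
```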
pymc_extras/inference/pathfinder/pathfinder.py

@@ -12,22 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+
 import collections
 import logging
 import time
-import warnings as _warnings

 from collections import Counter
 from collections.abc import Callable, Iterator
 from dataclasses import asdict, dataclass, field, replace
 from enum import Enum, auto
-from importlib.util import find_spec
 from typing import Literal, TypeAlias

 import arviz as az
-import blackjax
 import filelock
-import jax
 import numpy as np
 import pymc as pm
 import pytensor

@@ -42,11 +39,10 @@ from pymc.initial_point import make_initial_point_fn
 from pymc.model import modelcontext
 from pymc.model.core import Point
 from pymc.pytensorf import (
-    compile_pymc,
+    compile,
     find_rng_nodes,
     reseed_rngs,
 )
-from pymc.sampling.jax import get_jaxified_graph
 from pymc.util import (
     CustomProgress,
     RandomSeed,

@@ -67,6 +63,7 @@ from rich.text import Text
 # TODO: change to typing.Self after Python versions greater than 3.10
 from typing_extensions import Self

+from pymc_extras.inference.laplace import add_data_to_inferencedata
 from pymc_extras.inference.pathfinder.importance_sampling import (
     importance_sampling as _importance_sampling,
 )

@@ -78,9 +75,6 @@ from pymc_extras.inference.pathfinder.lbfgs import (
 )

 logger = logging.getLogger(__name__)
-_warnings.filterwarnings(
-    "ignore", category=FutureWarning, message="compile_pymc was renamed to compile"
-)

 REGULARISATION_TERM = 1e-8
 DEFAULT_LINKER = "cvm_nogc"
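Note: the import hunks above defer `blackjax`, `jax`, and `get_jaxified_graph` to the code paths that actually use them, so the module imports cleanly without the JAX stack installed; the module-level FutureWarning filter also becomes unnecessary once `compile` is imported under its new name. The pattern, sketched generically (the function name is illustrative, not from the diff):

```python
def run_blackjax_backend(*args, **kwargs):
    # Deferred imports: only users of the blackjax backend pay the import
    # cost, and a missing JAX install fails here rather than at module import.
    import blackjax
    import jax

    ...
```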
@@ -105,6 +99,8 @@ def get_jaxified_logp_of_ravel_inputs(model: Model, jacobian: bool = True) -> Ca
         A JAX function that computes the log-probability of a PyMC model with ravelled inputs.
     """

+    from pymc.sampling.jax import get_jaxified_graph
+
     # TODO: JAX: test if we should get jaxified graph of dlogp as well
     new_logprob, new_input = pm.pytensorf.join_nonshared_inputs(
         model.initial_point(), (model.logp(jacobian=jacobian),), model.value_vars, ()

@@ -144,7 +140,7 @@ def get_logp_dlogp_of_ravel_inputs(
         [model.logp(jacobian=jacobian), model.dlogp(jacobian=jacobian)],
         model.value_vars,
     )
-    logp_dlogp_fn = compile_pymc([inputs], (logP, dlogP), **compile_kwargs)
+    logp_dlogp_fn = compile([inputs], (logP, dlogP), **compile_kwargs)
     logp_dlogp_fn.trust_input = True

     return logp_dlogp_fn
@@ -224,6 +220,10 @@ def convert_flat_trace_to_idata(
         result = [res.reshape(num_paths, num_pdraws, *res.shape[2:]) for res in result]

     elif inference_backend == "blackjax":
+        import jax
+
+        from pymc.sampling.jax import get_jaxified_graph
+
         jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
         result = jax.vmap(jax.vmap(jax_fn))(
             *jax.device_put(list(trace.values()), jax.devices(postprocessing_backend)[0])
@@ -237,8 +237,8 @@ def convert_flat_trace_to_idata(


 def alpha_recover(
-    x: TensorVariable, g: TensorVariable, epsilon: float
-) -> tuple[TensorVariable, TensorVariable, TensorVariable, TensorVariable]:
+    x: TensorVariable, g: TensorVariable
+) -> tuple[TensorVariable, TensorVariable, TensorVariable]:
     """compute the diagonal elements of the inverse Hessian at each iterations of L-BFGS and filter updates.

     Parameters

@@ -247,9 +247,6 @@ def alpha_recover(
         position array, shape (L+1, N)
     g : TensorVariable
         gradient array, shape (L+1, N)
-    epsilon : float
-        threshold for filtering updates based on inner product of position
-        and gradient differences

     Returns
     -------

@@ -259,15 +256,13 @@ def alpha_recover(
         position differences, shape (L, N)
     z : TensorVariable
         gradient differences, shape (L, N)
-    update_mask : TensorVariable
-        mask for filtering updates, shape (L,)

     Notes
     -----
     shapes: L=batch_size, N=num_params
     """

-    def compute_alpha_l(alpha_lm1, s_l, z_l) -> TensorVariable:
+    def compute_alpha_l(s_l, z_l, alpha_lm1) -> TensorVariable:
         # alpha_lm1: (N,)
         # s_l: (N,)
         # z_l: (N,)
@@ -281,43 +276,28 @@ def alpha_recover(
         )  # fmt:off
         return 1.0 / inv_alpha_l

-    def return_alpha_lm1(alpha_lm1, s_l, z_l) -> TensorVariable:
-        return alpha_lm1[-1]
-
-    def scan_body(update_mask_l, s_l, z_l, alpha_lm1) -> TensorVariable:
-        return pt.switch(
-            update_mask_l,
-            compute_alpha_l(alpha_lm1, s_l, z_l),
-            return_alpha_lm1(alpha_lm1, s_l, z_l),
-        )
-
     Lp1, N = x.shape
     s = pt.diff(x, axis=0)
     z = pt.diff(g, axis=0)
     alpha_l_init = pt.ones(N)
-    sz = (s * z).sum(axis=-1)
-    # update_mask = sz > epsilon * pt.linalg.norm(z, axis=-1)
-    # pt.linalg.norm does not work with JAX!!
-    update_mask = sz > epsilon * pt.sqrt(pt.sum(z**2, axis=-1))

     alpha, _ = pytensor.scan(
-        fn=scan_body,
+        fn=compute_alpha_l,
         outputs_info=alpha_l_init,
-        sequences=[update_mask, s, z],
+        sequences=[s, z],
         n_steps=Lp1 - 1,
         allow_gc=False,
     )

     # assert np.all(alpha.eval() > 0), "alpha cannot be negative"
-    # alpha: (L, N)
-    return alpha, s, z, update_mask
+    # alpha: (L, N)
+    return alpha, s, z


 def inverse_hessian_factors(
     alpha: TensorVariable,
     s: TensorVariable,
     z: TensorVariable,
-    update_mask: TensorVariable,
     J: TensorConstant,
 ) -> tuple[TensorVariable, TensorVariable]:
     """compute the inverse hessian factors for the BFGS approximation.
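Note: with `update_mask` gone, the `pt.switch`-based `scan_body` wrapper disappears and `pytensor.scan` steps `compute_alpha_l` directly over the difference sequences (sequence arguments come first, the carried state last). A toy scan with the same wiring (the recurrence itself is a stand-in, not the actual alpha update):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

s = pt.matrix("s")  # (L, N) position differences
z = pt.matrix("z")  # (L, N) gradient differences

def step(s_l, z_l, alpha_lm1):
    # same argument order as compute_alpha_l: sequences first, carried state last
    return alpha_lm1 + s_l * z_l  # placeholder recurrence

alpha, _ = pytensor.scan(fn=step, sequences=[s, z], outputs_info=pt.ones(s.shape[1]))
f = pytensor.function([s, z], alpha)

ones = np.ones((3, 2), dtype=pytensor.config.floatX)
print(f(ones, ones))  # [[2. 2.] [3. 3.] [4. 4.]] -- one carried state per row
```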
@@ -330,8 +310,6 @@ def inverse_hessian_factors(
         position differences, shape (L, N)
     z : TensorVariable
         gradient differences, shape (L, N)
-    update_mask : TensorVariable
-        mask for filtering updates, shape (L,)
     J : TensorConstant
         history size for L-BFGS

|
|
|
350
328
|
# NOTE: get_chi_matrix_1 is a modified version of get_chi_matrix_2 to closely follow Zhang et al., (2022)
|
|
351
329
|
# NOTE: get_chi_matrix_2 is from blackjax which MAYBE incorrectly implemented
|
|
352
330
|
|
|
353
|
-
def get_chi_matrix_1(
|
|
354
|
-
diff: TensorVariable, update_mask: TensorVariable, J: TensorConstant
|
|
355
|
-
) -> TensorVariable:
|
|
331
|
+
def get_chi_matrix_1(diff: TensorVariable, J: TensorConstant) -> TensorVariable:
|
|
356
332
|
L, N = diff.shape
|
|
357
333
|
j_last = pt.as_tensor(J - 1) # since indexing starts at 0
|
|
358
334
|
|
|
359
|
-
def chi_update(
|
|
335
|
+
def chi_update(diff_l, chi_lm1) -> TensorVariable:
|
|
360
336
|
chi_l = pt.roll(chi_lm1, -1, axis=0)
|
|
361
337
|
return pt.set_subtensor(chi_l[j_last], diff_l)
|
|
362
338
|
|
|
363
|
-
def no_op(chi_lm1, diff_l) -> TensorVariable:
|
|
364
|
-
return chi_lm1
|
|
365
|
-
|
|
366
|
-
def scan_body(update_mask_l, diff_l, chi_lm1) -> TensorVariable:
|
|
367
|
-
return pt.switch(update_mask_l, chi_update(chi_lm1, diff_l), no_op(chi_lm1, diff_l))
|
|
368
|
-
|
|
369
339
|
chi_init = pt.zeros((J, N))
|
|
370
340
|
chi_mat, _ = pytensor.scan(
|
|
371
|
-
fn=
|
|
341
|
+
fn=chi_update,
|
|
372
342
|
outputs_info=chi_init,
|
|
373
|
-
sequences=[
|
|
374
|
-
update_mask,
|
|
375
|
-
diff,
|
|
376
|
-
],
|
|
343
|
+
sequences=[diff],
|
|
377
344
|
allow_gc=False,
|
|
378
345
|
)
|
|
379
346
|
|
|
@@ -382,19 +349,15 @@ def inverse_hessian_factors(
|
|
|
382
349
|
# (L, N, J)
|
|
383
350
|
return chi_mat
|
|
384
351
|
|
|
385
|
-
def get_chi_matrix_2(
|
|
386
|
-
diff: TensorVariable, update_mask: TensorVariable, J: TensorConstant
|
|
387
|
-
) -> TensorVariable:
|
|
352
|
+
def get_chi_matrix_2(diff: TensorVariable, J: TensorConstant) -> TensorVariable:
|
|
388
353
|
L, N = diff.shape
|
|
389
354
|
|
|
390
|
-
diff_masked = update_mask[:, None] * diff
|
|
391
|
-
|
|
392
355
|
# diff_padded: (L+J, N)
|
|
393
356
|
pad_width = pt.zeros(shape=(2, 2), dtype="int32")
|
|
394
|
-
pad_width = pt.set_subtensor(pad_width[0, 0], J)
|
|
395
|
-
diff_padded = pt.pad(
|
|
357
|
+
pad_width = pt.set_subtensor(pad_width[0, 0], J - 1)
|
|
358
|
+
diff_padded = pt.pad(diff, pad_width, mode="constant")
|
|
396
359
|
|
|
397
|
-
index = pt.arange(L)[
|
|
360
|
+
index = pt.arange(L)[..., None] + pt.arange(J)[None, ...]
|
|
398
361
|
index = index.reshape((L, J))
|
|
399
362
|
|
|
400
363
|
chi_mat = pt.matrix_transpose(diff_padded[index])
|
|
@@ -403,8 +366,10 @@ def inverse_hessian_factors(
|
|
|
403
366
|
return chi_mat
|
|
404
367
|
|
|
405
368
|
L, N = alpha.shape
|
|
406
|
-
|
|
407
|
-
|
|
369
|
+
|
|
370
|
+
# changed to get_chi_matrix_2 after removing update_mask
|
|
371
|
+
S = get_chi_matrix_2(s, J)
|
|
372
|
+
Z = get_chi_matrix_2(z, J)
|
|
408
373
|
|
|
409
374
|
# E: (L, J, J)
|
|
410
375
|
Ij = pt.eye(J)[None, ...]
|
|
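Note: `get_chi_matrix_2` now pads the (L, N) difference array with J-1 leading zero rows (previously J, with a mask applied first) and gathers a length-J window per iteration via index arithmetic. The indexing trick in plain NumPy:

```python
import numpy as np

L, N, J = 4, 2, 3
diff = np.arange(1, L + 1, dtype=float)[:, None] * np.ones((L, N))  # rows 1..L

# J-1 zero rows in front so the earliest windows are zero-padded
diff_padded = np.pad(diff, ((J - 1, 0), (0, 0)), mode="constant")

# index[l, j] = l + j selects rows (l, ..., l+J-1) of the padded array,
# i.e. the last J differences ending at iteration l of the original array
index = np.arange(L)[:, None] + np.arange(J)[None, :]

chi = diff_padded[index]             # (L, J, N)
chi_mat = np.swapaxes(chi, -1, -2)   # (L, N, J), like pt.matrix_transpose

print(chi[0, :, 0])  # [0. 0. 1.] -- first window is mostly padding
print(chi[3, :, 0])  # [2. 3. 4.] -- last window holds the final J rows
```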
@@ -489,6 +454,7 @@ def bfgs_sample_dense(

     N = x.shape[-1]
     IdN = pt.eye(N)[None, ...]
+    IdN += IdN * REGULARISATION_TERM

     # inverse Hessian
     H_inv = (

@@ -504,7 +470,10 @@ def bfgs_sample_dense(

     logdet = 2.0 * pt.sum(pt.log(pt.abs(pt.diagonal(Lchol, axis1=-2, axis2=-1))), axis=-1)

-    mu = x - pt.einsum("ijk,ik->ij", H_inv, g)
+    # mu = x - pt.einsum("ijk,ik->ij", H_inv, g)  # causes error: Multiple destroyers of g
+
+    batched_dot = pt.vectorize(pt.dot, signature="(ijk),(ilk)->(ij)")
+    mu = x - batched_dot(H_inv, pt.matrix_transpose(g[..., None]))

     phi = pt.matrix_transpose(
         # (L, N, 1)
@@ -565,23 +534,28 @@ def bfgs_sample_sparse(
     # qr_input: (L, N, 2J)
     qr_input = inv_sqrt_alpha_diag @ beta
     (Q, R), _ = pytensor.scan(fn=pt.nlinalg.qr, sequences=[qr_input], allow_gc=False)
+
     IdN = pt.eye(R.shape[1])[None, ...]
+    IdN += IdN * REGULARISATION_TERM
+
     Lchol_input = IdN + R @ gamma @ pt.matrix_transpose(R)

+    # TODO: make robust Lchol calcs more robust, ie. try exceptions, increase REGULARISATION_TERM if non-finite exists
     Lchol = pt.linalg.cholesky(Lchol_input, lower=False, check_finite=False, on_error="nan")

     logdet = 2.0 * pt.sum(pt.log(pt.abs(pt.diagonal(Lchol, axis1=-2, axis2=-1))), axis=-1)
     logdet += pt.sum(pt.log(alpha), axis=-1)

+    # inverse Hessian
+    # (L, N, N) + (L, N, 2J), (L, 2J, 2J), (L, 2J, N) -> (L, N, N)
+    H_inv = alpha_diag + (beta @ gamma @ pt.matrix_transpose(beta))
+
     # NOTE: changed the sign from "x + " to "x -" of the expression to match Stan which differs from Zhang et al., (2022). same for dense version.
-
-
-
-
-
-        # (L, N, N), (L, N) -> (L, N)
-        + pt.batched_dot((beta @ gamma @ pt.matrix_transpose(beta)), g)
-    )
+
+    # mu = x - pt.einsum("ijk,ik->ij", H_inv, g)  # causes error: Multiple destroyers of g
+
+    batched_dot = pt.vectorize(pt.dot, signature="(ijk),(ilk)->(ij)")
+    mu = x - batched_dot(H_inv, pt.matrix_transpose(g[..., None]))

     phi = pt.matrix_transpose(
         # (L, N, 1)

@@ -589,8 +563,6 @@ def bfgs_sample_sparse(
         # (L, N, N), (L, N, M) -> (L, N, M)
         + sqrt_alpha_diag
         @ (
-            # (L, N, 2J), (L, 2J, M) -> (L, N, M)
-            # intermediate calcs below
             # (L, N, 2J), (L, 2J, 2J) -> (L, N, 2J)
             (Q @ (Lchol - IdN))
             # (L, 2J, N), (L, N, M) -> (L, 2J, M)
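Note: both samplers now build `H_inv` explicitly and compute `mu` through a `pt.vectorize`d dot because, per the retained comment, the `pt.einsum` form raises a "Multiple destroyers of g" error; the small `REGULARISATION_TERM` folded into the identity guards the Cholesky factorisation. Numerically the intended quantity is just a batched matrix-vector product, shown here in NumPy:

```python
import numpy as np

L, N = 2, 3
rng = np.random.default_rng(0)
H_inv = rng.normal(size=(L, N, N))  # per-iteration inverse-Hessian estimates
x = rng.normal(size=(L, N))         # positions
g = rng.normal(size=(L, N))         # gradients

# mu = x - H_inv @ g, batched over the leading L axis
mu_einsum = x - np.einsum("lij,lj->li", H_inv, g)

# equivalent explicit batched matrix-vector product, in the spirit of
# batched_dot(H_inv, matrix_transpose(g[..., None])) from the diff
mu_batched = x - np.squeeze(H_inv @ g[..., None], axis=-1)

assert np.allclose(mu_einsum, mu_batched)
```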
@@ -778,7 +750,6 @@ def make_pathfinder_body(
     num_draws: int,
     maxcor: int,
     num_elbo_draws: int,
-    epsilon: float,
     **compile_kwargs: dict,
 ) -> Function:
     """

@@ -794,8 +765,6 @@ def make_pathfinder_body(
         The maximum number of iterations for the L-BFGS algorithm.
     num_elbo_draws : int
         The number of draws for the Evidence Lower Bound (ELBO) estimation.
-    epsilon : float
-        The value used to filter out large changes in the direction of the update gradient at each iteration l in L. Iteration l is only accepted if delta_theta[l] * delta_grad[l] > epsilon * L2_norm(delta_grad[l]) for each l in L.
     compile_kwargs : dict
         Additional keyword arguments for the PyTensor compiler.

@@ -820,11 +789,10 @@ def make_pathfinder_body(

     num_draws = pt.constant(num_draws, "num_draws", dtype="int32")
     num_elbo_draws = pt.constant(num_elbo_draws, "num_elbo_draws", dtype="int32")
-    epsilon = pt.constant(epsilon, "epsilon", dtype="float64")
     maxcor = pt.constant(maxcor, "maxcor", dtype="int32")

-    alpha, s, z, update_mask = alpha_recover(x_full, g_full, epsilon=epsilon)
-    beta, gamma = inverse_hessian_factors(alpha, s, z, update_mask, J=maxcor)
+    alpha, s, z = alpha_recover(x_full, g_full)
+    beta, gamma = inverse_hessian_factors(alpha, s, z, J=maxcor)

     # ignore initial point - x, g: (L, N)
     x = x_full[1:]

@@ -855,7 +823,7 @@ def make_pathfinder_body(

     # return psi, logP_psi, logQ_psi, elbo_argmax

-    pathfinder_body_fn = compile_pymc(
+    pathfinder_body_fn = compile(
         [x_full, g_full],
         [psi, logP_psi, logQ_psi, elbo_argmax],
         **compile_kwargs,
@@ -934,11 +902,11 @@ def make_single_pathfinder_fn(
     x_base = DictToArrayBijection.map(ip).data

     # lbfgs
-    lbfgs = LBFGS(neg_logp_dlogp_func, maxcor, maxiter, ftol, gtol, maxls)
+    lbfgs = LBFGS(neg_logp_dlogp_func, maxcor, maxiter, ftol, gtol, maxls, epsilon)

     # pathfinder body
     pathfinder_body_fn = make_pathfinder_body(
-        logp_func, num_draws, maxcor, num_elbo_draws, epsilon, **compile_kwargs
+        logp_func, num_draws, maxcor, num_elbo_draws, **compile_kwargs
     )
     rngs = find_rng_nodes(pathfinder_body_fn.maker.fgraph.outputs)

@@ -950,8 +918,8 @@ def make_single_pathfinder_fn(
         x0 = x_base + jitter_value
         x, g, lbfgs_niter, lbfgs_status = lbfgs.minimize(x0)

-        if lbfgs_status == LBFGSStatus.INIT_FAILED:
-            raise LBFGSInitFailed()
+        if lbfgs_status in {LBFGSStatus.INIT_FAILED, LBFGSStatus.INIT_FAILED_LOW_UPDATE_PCT}:
+            raise LBFGSInitFailed(lbfgs_status)
         elif lbfgs_status == LBFGSStatus.LBFGS_FAILED:
             raise LBFGSException()

@@ -1389,15 +1357,16 @@ def _get_status_warning(mpr: MultiPathfinderResult) -> list[str]:
     warnings = []

     lbfgs_status_message = {
-        LBFGSStatus.MAX_ITER_REACHED: "LBFGS maximum number of iterations reached. Consider increasing maxiter if this occurence is high relative to the number of paths.",
-        LBFGSStatus.INIT_FAILED: "LBFGS failed to
-        LBFGSStatus.
+        LBFGSStatus.MAX_ITER_REACHED: "MAX_ITER_REACHED: LBFGS maximum number of iterations reached. Consider increasing maxiter if this occurence is high relative to the number of paths.",
+        LBFGSStatus.INIT_FAILED: "INIT_FAILED: LBFGS failed to initialize. Consider reparameterizing the model or reducing jitter if this occurence is high relative to the number of paths.",
+        LBFGSStatus.NON_FINITE: "NON_FINITE: LBFGS objective function produced inf or nan at the last iteration. Consider reparameterizing the model or adjusting the pathfinder arguments if this occurence is high relative to the number of paths.",
+        LBFGSStatus.LOW_UPDATE_PCT: "LOW_UPDATE_PCT: Majority of LBFGS iterations were not accepted due to the either: (1) LBFGS function or gradient values containing too many inf or nan values or (2) gradient changes being significantly large, set by epsilon. Consider reparameterizing the model, adjusting initvals or jitter or other pathfinder arguments if this occurence is high relative to the number of paths.",
+        LBFGSStatus.INIT_FAILED_LOW_UPDATE_PCT: "INIT_FAILED_LOW_UPDATE_PCT: LBFGS failed to initialize due to the either: (1) LBFGS function or gradient values containing too many inf or nan values or (2) gradient changes being significantly large, set by epsilon. Consider reparameterizing the model, adjusting initvals or jitter or other pathfinder arguments if this occurence is high relative to the number of paths.",
     }

     path_status_message = {
-        PathStatus.ELBO_ARGMAX_AT_ZERO: "ELBO argmax at zero refers to the first iteration during LBFGS. A high occurrence suggests the model's default initial point + jitter
-        PathStatus.
-        PathStatus.INVALID_LOGQ: "Invalid logQ values occur when a path's logQ values are not finite. The failed path is not included in samples when importance sampling is used. Consider reparameterizing the model or adjusting the pathfinder arguments if this occurence is high relative to the number of paths.",
+        PathStatus.ELBO_ARGMAX_AT_ZERO: "ELBO_ARGMAX_AT_ZERO: ELBO argmax at zero refers to the first iteration during LBFGS. A high occurrence suggests the model's default initial point + jitter values are concentrated in high-density regions in the target distribution and may result in poor exploration of the parameter space. Consider increasing jitter if this occurrence is high relative to the number of paths.",
+        PathStatus.INVALID_LOGQ: "INVALID_LOGQ: Invalid logQ values occur when a path's logQ values are not finite. The failed path is not included in samples when importance sampling is used. Consider reparameterizing the model or adjusting the pathfinder arguments if this occurence is high relative to the number of paths.",
     }

     for lbfgs_status in mpr.lbfgs_status:
@@ -1567,8 +1536,9 @@ def multipath_pathfinder(
                     task,
                     description=desc.format(path_idx=path_idx),
                     completed=path_idx,
-                    refresh=True,
                 )
+            # Ensure the progress bar visually reaches 100% and shows 'Completed'
+            progress.update(task, completed=num_paths, description="Completed")
         except (KeyboardInterrupt, StopIteration) as e:
             # if exception is raised here, MultiPathfinderResult will collect all the successful results and report the results. User is free to abort the process earlier and the results will still be collected and return az.InferenceData.
             if isinstance(e, StopIteration):

@@ -1618,7 +1588,7 @@ def fit_pathfinder(
     maxiter: int = 1000,  # L^max
     ftol: float = 1e-5,
     gtol: float = 1e-8,
-    maxls=1000,
+    maxls: int = 1000,
     num_elbo_draws: int = 10,  # K
     jitter: float = 2.0,
     epsilon: float = 1e-8,
@@ -1630,6 +1600,7 @@ def fit_pathfinder(
     inference_backend: Literal["pymc", "blackjax"] = "pymc",
     pathfinder_kwargs: dict = {},
     compile_kwargs: dict = {},
+    initvals: dict | None = None,
 ) -> az.InferenceData:
     """
     Fit the Pathfinder Variational Inference algorithm.

@@ -1665,12 +1636,12 @@ def fit_pathfinder(
     importance_sampling : str, None, optional
         Method to apply sampling based on log importance weights (logP - logQ).
         Options are:
-
-
-            "psir" : Pareto Smoothed Importance Resampling
-
-
-
+
+        - "psis" : Pareto Smoothed Importance Sampling (default). Usually most stable.
+        - "psir" : Pareto Smoothed Importance Resampling. Less stable than PSIS.
+        - "identity" : Applies log importance weights directly without resampling.
+        - None : No importance sampling weights. Returns raw samples of size (num_paths, num_draws_per_path, N) where N is number of model parameters. Other methods return samples of size (num_draws, N).
+
     progressbar : bool, optional
         Whether to display a progress bar (default is True). Setting this to False will likely reduce the computation time.
     random_seed : RandomSeed, optional
@@ -1685,10 +1656,13 @@ def fit_pathfinder(
         Additional keyword arguments for the Pathfinder algorithm.
     compile_kwargs
         Additional keyword arguments for the PyTensor compiler. If not provided, the default linker is "cvm_nogc".
+    initvals: dict | None = None
+        Initial values for the model parameters, as str:ndarray key-value pairs. Paritial initialization is permitted.
+        If None, the model's default initial values are used.

     Returns
     -------
-    arviz.InferenceData
+    :class:`~arviz.InferenceData`
         The inference data containing the results of the Pathfinder algorithm.

     References

@@ -1698,6 +1672,14 @@ def fit_pathfinder(

     model = modelcontext(model)

+    if initvals is not None:
+        model = pm.model.fgraph.clone_model(model)  # Create a clone of the model
+        for (
+            rv_name,
+            ivals,
+        ) in initvals.items():  # Set the initial values for the variables in the clone
+            model.set_initval(model.named_vars[rv_name], ivals)
+
     valid_importance_sampling = {"psis", "psir", "identity", None}

     if importance_sampling is not None:
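Note: the new `initvals` argument clones the model before overriding initial values, so the caller's model object is left untouched, and partial dictionaries are allowed. A usage sketch (the model and values here are illustrative):

```python
import numpy as np
import pymc as pm

from pymc_extras.inference.pathfinder.pathfinder import fit_pathfinder

observed = np.random.default_rng(1).normal(3.0, 1.0, size=100)

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 10.0)
    sigma = pm.HalfNormal("sigma", 5.0)
    pm.Normal("y", mu, sigma, observed=observed)

# Partial initialization: only mu is pinned; sigma keeps its default initval.
idata = fit_pathfinder(model=model, initvals={"mu": np.array(3.0)}, random_seed=41)
```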
@@ -1736,8 +1718,9 @@ def fit_pathfinder(
         )
         pathfinder_samples = mp_result.samples
     elif inference_backend == "blackjax":
-
-
+        import blackjax
+        import jax
+
         if version.parse(blackjax.__version__).major < 1:
             raise ImportError("fit_pathfinder requires blackjax 1.0 or above")

@@ -1775,4 +1758,7 @@ def fit_pathfinder(
         model=model,
         importance_sampling=importance_sampling,
     )
+
+    idata = add_data_to_inferencedata(idata, progressbar, model, compile_kwargs)
+
     return idata
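Note: the final hunk routes the result through `add_data_to_inferencedata` (imported from the Laplace module), so the returned `InferenceData` should also carry the model's data groups alongside the posterior. Continuing the sketch above:

```python
idata = fit_pathfinder(model=model, random_seed=42)
print(idata.groups())  # expect "posterior" plus observed/constant data groups
```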