pymc-extras 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pymc_extras/__init__.py +1 -3
- pymc_extras/distributions/__init__.py +2 -0
- pymc_extras/distributions/transforms/__init__.py +3 -0
- pymc_extras/distributions/transforms/partial_order.py +227 -0
- pymc_extras/inference/__init__.py +4 -2
- pymc_extras/inference/fit.py +6 -4
- pymc_extras/inference/laplace.py +4 -1
- pymc_extras/inference/pathfinder/importance_sampling.py +23 -17
- pymc_extras/inference/pathfinder/lbfgs.py +49 -13
- pymc_extras/inference/pathfinder/pathfinder.py +136 -118
- pymc_extras/statespace/core/statespace.py +5 -4
- pymc_extras/statespace/filters/distributions.py +9 -45
- pymc_extras/statespace/utils/data_tools.py +24 -9
- pymc_extras/version.txt +1 -1
- {pymc_extras-0.2.3.dist-info → pymc_extras-0.2.5.dist-info}/METADATA +5 -3
- {pymc_extras-0.2.3.dist-info → pymc_extras-0.2.5.dist-info}/RECORD +23 -20
- {pymc_extras-0.2.3.dist-info → pymc_extras-0.2.5.dist-info}/WHEEL +1 -1
- tests/distributions/test_transform.py +77 -0
- tests/statespace/test_coord_assignment.py +65 -0
- tests/test_laplace.py +16 -0
- tests/test_pathfinder.py +141 -17
- {pymc_extras-0.2.3.dist-info → pymc_extras-0.2.5.dist-info/licenses}/LICENSE +0 -0
- {pymc_extras-0.2.3.dist-info → pymc_extras-0.2.5.dist-info}/top_level.txt +0 -0
pymc_extras/inference/pathfinder/pathfinder.py
@@ -12,22 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+
 import collections
 import logging
 import time
-import warnings as _warnings

 from collections import Counter
 from collections.abc import Callable, Iterator
 from dataclasses import asdict, dataclass, field, replace
 from enum import Enum, auto
-from importlib.util import find_spec
 from typing import Literal, TypeAlias

 import arviz as az
-import blackjax
 import filelock
-import jax
 import numpy as np
 import pymc as pm
 import pytensor
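A recurring change in this release: `blackjax` and `jax` are no longer imported at module scope, and later hunks re-import them inside the functions that use them. A minimal sketch of the pattern (the function name here is hypothetical, not from the package):

```python
def _blackjax_backend():
    # Deferred imports: `import pymc_extras` no longer requires these
    # optional dependencies; the ImportError surfaces only if the
    # blackjax backend is actually selected.
    import blackjax
    import jax

    return blackjax, jax
```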
@@ -42,11 +39,10 @@ from pymc.initial_point import make_initial_point_fn
 from pymc.model import modelcontext
 from pymc.model.core import Point
 from pymc.pytensorf import (
-
+    compile,
     find_rng_nodes,
     reseed_rngs,
 )
-from pymc.sampling.jax import get_jaxified_graph
 from pymc.util import (
     CustomProgress,
     RandomSeed,
@@ -60,12 +56,14 @@ from pytensor.graph import Apply, Op, vectorize_graph
 from pytensor.tensor import TensorConstant, TensorVariable
 from rich.console import Console, Group
 from rich.padding import Padding
+from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn
 from rich.table import Table
 from rich.text import Text

 # TODO: change to typing.Self after Python versions greater than 3.10
 from typing_extensions import Self

+from pymc_extras.inference.laplace import add_data_to_inferencedata
 from pymc_extras.inference.pathfinder.importance_sampling import (
     importance_sampling as _importance_sampling,
 )
@@ -77,9 +75,6 @@ from pymc_extras.inference.pathfinder.lbfgs import (
 )

 logger = logging.getLogger(__name__)
-_warnings.filterwarnings(
-    "ignore", category=FutureWarning, message="compile_pymc was renamed to compile"
-)

 REGULARISATION_TERM = 1e-8
 DEFAULT_LINKER = "cvm_nogc"
@@ -104,6 +99,8 @@ def get_jaxified_logp_of_ravel_inputs(model: Model, jacobian: bool = True) -> Ca
         A JAX function that computes the log-probability of a PyMC model with ravelled inputs.
     """

+    from pymc.sampling.jax import get_jaxified_graph
+
     # TODO: JAX: test if we should get jaxified graph of dlogp as well
     new_logprob, new_input = pm.pytensorf.join_nonshared_inputs(
         model.initial_point(), (model.logp(jacobian=jacobian),), model.value_vars, ()
@@ -143,7 +140,7 @@ def get_logp_dlogp_of_ravel_inputs(
         [model.logp(jacobian=jacobian), model.dlogp(jacobian=jacobian)],
         model.value_vars,
     )
-    logp_dlogp_fn =
+    logp_dlogp_fn = compile([inputs], (logP, dlogP), **compile_kwargs)
     logp_dlogp_fn.trust_input = True

     return logp_dlogp_fn
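`pymc.pytensorf.compile` is the upstream rename of `compile_pymc`, which is why the `FutureWarning` filter above could be dropped. A minimal sketch of the call shape used here (toy graph; the inputs are assumptions for illustration):

```python
import numpy as np
import pytensor.tensor as pt
from pymc.pytensorf import compile  # renamed from compile_pymc upstream

x = pt.vector("x")
f = compile([x], [x.sum(), x * 2.0])  # same call shape as logp_dlogp_fn above
f.trust_input = True  # skip input validation, as the pathfinder code does
print(f(np.array([1.0, 2.0, 3.0])))  # [array(6.), array([2., 4., 6.])]
```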
@@ -155,7 +152,7 @@ def convert_flat_trace_to_idata(
     postprocessing_backend: Literal["cpu", "gpu"] = "cpu",
     inference_backend: Literal["pymc", "blackjax"] = "pymc",
     model: Model | None = None,
-    importance_sampling: Literal["psis", "psir", "identity"
+    importance_sampling: Literal["psis", "psir", "identity"] | None = "psis",
 ) -> az.InferenceData:
     """convert flattened samples to arviz InferenceData format.

@@ -180,7 +177,7 @@ def convert_flat_trace_to_idata(
         arviz inference data object
     """

-    if importance_sampling
+    if importance_sampling is None:
         # samples.ndim == 3 in this case, otherwise ndim == 2
         num_paths, num_pdraws, N = samples.shape
         samples = samples.reshape(-1, N)
@@ -219,10 +216,14 @@ def convert_flat_trace_to_idata(
     fn.trust_input = True
     result = fn(*list(trace.values()))

-    if importance_sampling
+    if importance_sampling is None:
         result = [res.reshape(num_paths, num_pdraws, *res.shape[2:]) for res in result]

     elif inference_backend == "blackjax":
+        import jax
+
+        from pymc.sampling.jax import get_jaxified_graph
+
         jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
         result = jax.vmap(jax.vmap(jax_fn))(
             *jax.device_put(list(trace.values()), jax.devices(postprocessing_backend)[0])
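The nested `jax.vmap` above maps the jaxified graph over both batch axes at once. A toy sketch of that shape behaviour (sizes are assumptions for illustration):

```python
import jax
import jax.numpy as jnp

def f(params):  # acts on a single draw, shape (N,)
    return jnp.sum(params ** 2)

batched = jax.vmap(jax.vmap(f))  # outer vmap: paths, inner vmap: draws
draws = jnp.ones((4, 100, 3))    # (num_paths, num_draws, N)
print(batched(draws).shape)      # (4, 100)
```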
@@ -236,8 +237,8 @@ def convert_flat_trace_to_idata(


 def alpha_recover(
-    x: TensorVariable, g: TensorVariable
-) -> tuple[TensorVariable, TensorVariable, TensorVariable
+    x: TensorVariable, g: TensorVariable
+) -> tuple[TensorVariable, TensorVariable, TensorVariable]:
     """compute the diagonal elements of the inverse Hessian at each iterations of L-BFGS and filter updates.

     Parameters
@@ -246,9 +247,6 @@ def alpha_recover(
         position array, shape (L+1, N)
     g : TensorVariable
         gradient array, shape (L+1, N)
-    epsilon : float
-        threshold for filtering updates based on inner product of position
-        and gradient differences

     Returns
     -------
@@ -258,15 +256,13 @@ def alpha_recover(
         position differences, shape (L, N)
     z : TensorVariable
         gradient differences, shape (L, N)
-    update_mask : TensorVariable
-        mask for filtering updates, shape (L,)

     Notes
     -----
     shapes: L=batch_size, N=num_params
     """

-    def compute_alpha_l(
+    def compute_alpha_l(s_l, z_l, alpha_lm1) -> TensorVariable:
         # alpha_lm1: (N,)
         # s_l: (N,)
         # z_l: (N,)
@@ -280,43 +276,28 @@ def alpha_recover(
         ) # fmt:off
         return 1.0 / inv_alpha_l

-    def return_alpha_lm1(alpha_lm1, s_l, z_l) -> TensorVariable:
-        return alpha_lm1[-1]
-
-    def scan_body(update_mask_l, s_l, z_l, alpha_lm1) -> TensorVariable:
-        return pt.switch(
-            update_mask_l,
-            compute_alpha_l(alpha_lm1, s_l, z_l),
-            return_alpha_lm1(alpha_lm1, s_l, z_l),
-        )
-
     Lp1, N = x.shape
     s = pt.diff(x, axis=0)
     z = pt.diff(g, axis=0)
     alpha_l_init = pt.ones(N)
-    sz = (s * z).sum(axis=-1)
-    # update_mask = sz > epsilon * pt.linalg.norm(z, axis=-1)
-    # pt.linalg.norm does not work with JAX!!
-    update_mask = sz > epsilon * pt.sqrt(pt.sum(z**2, axis=-1))

     alpha, _ = pytensor.scan(
-        fn=
+        fn=compute_alpha_l,
         outputs_info=alpha_l_init,
-        sequences=[
+        sequences=[s, z],
         n_steps=Lp1 - 1,
         allow_gc=False,
     )

     # assert np.all(alpha.eval() > 0), "alpha cannot be negative"
-    # alpha: (L, N)
-    return alpha, s, z
+    # alpha: (L, N)
+    return alpha, s, z


 def inverse_hessian_factors(
     alpha: TensorVariable,
     s: TensorVariable,
     z: TensorVariable,
-    update_mask: TensorVariable,
     J: TensorConstant,
 ) -> tuple[TensorVariable, TensorVariable]:
     """compute the inverse hessian factors for the BFGS approximation.
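With `update_mask` gone (filtering now happens inside `LBFGS` itself; see the `epsilon` argument moving to the `LBFGS` constructor below), the scan body reduces to the bare recurrence. A toy `pytensor.scan` with the same wiring as `compute_alpha_l` (the step function is a stand-in, not the real alpha update):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

s = pt.matrix("s")  # (L, N) position differences
z = pt.matrix("z")  # (L, N) gradient differences

def step(s_l, z_l, carry_lm1):
    # scan passes the sequence slices first, then the carried output,
    # exactly the argument order compute_alpha_l expects
    return carry_lm1 * (s_l * z_l).sum() / (z_l**2).sum()

carry0 = pt.ones(3)
out, _ = pytensor.scan(fn=step, sequences=[s, z], outputs_info=carry0, allow_gc=False)
f = pytensor.function([s, z], out)
print(f(np.ones((5, 3)), np.ones((5, 3))).shape)  # (5, 3)
```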
@@ -329,8 +310,6 @@ def inverse_hessian_factors(
         position differences, shape (L, N)
     z : TensorVariable
         gradient differences, shape (L, N)
-    update_mask : TensorVariable
-        mask for filtering updates, shape (L,)
     J : TensorConstant
         history size for L-BFGS

@@ -349,30 +328,19 @@ def inverse_hessian_factors(
     # NOTE: get_chi_matrix_1 is a modified version of get_chi_matrix_2 to closely follow Zhang et al., (2022)
     # NOTE: get_chi_matrix_2 is from blackjax which MAYBE incorrectly implemented

-    def get_chi_matrix_1(
-        diff: TensorVariable, update_mask: TensorVariable, J: TensorConstant
-    ) -> TensorVariable:
+    def get_chi_matrix_1(diff: TensorVariable, J: TensorConstant) -> TensorVariable:
         L, N = diff.shape
         j_last = pt.as_tensor(J - 1)  # since indexing starts at 0

-        def chi_update(
+        def chi_update(diff_l, chi_lm1) -> TensorVariable:
             chi_l = pt.roll(chi_lm1, -1, axis=0)
             return pt.set_subtensor(chi_l[j_last], diff_l)

-        def no_op(chi_lm1, diff_l) -> TensorVariable:
-            return chi_lm1
-
-        def scan_body(update_mask_l, diff_l, chi_lm1) -> TensorVariable:
-            return pt.switch(update_mask_l, chi_update(chi_lm1, diff_l), no_op(chi_lm1, diff_l))
-
         chi_init = pt.zeros((J, N))
         chi_mat, _ = pytensor.scan(
-            fn=
+            fn=chi_update,
             outputs_info=chi_init,
-            sequences=[
-                update_mask,
-                diff,
-            ],
+            sequences=[diff],
             allow_gc=False,
         )

@@ -381,19 +349,15 @@ def inverse_hessian_factors(
         # (L, N, J)
         return chi_mat

-    def get_chi_matrix_2(
-        diff: TensorVariable, update_mask: TensorVariable, J: TensorConstant
-    ) -> TensorVariable:
+    def get_chi_matrix_2(diff: TensorVariable, J: TensorConstant) -> TensorVariable:
         L, N = diff.shape

-        diff_masked = update_mask[:, None] * diff
-
         # diff_padded: (L+J, N)
         pad_width = pt.zeros(shape=(2, 2), dtype="int32")
-        pad_width = pt.set_subtensor(pad_width[0, 0], J)
-        diff_padded = pt.pad(
+        pad_width = pt.set_subtensor(pad_width[0, 0], J - 1)
+        diff_padded = pt.pad(diff, pad_width, mode="constant")

-        index = pt.arange(L)[
+        index = pt.arange(L)[..., None] + pt.arange(J)[None, ...]
         index = index.reshape((L, J))

         chi_mat = pt.matrix_transpose(diff_padded[index])
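The rewritten `get_chi_matrix_2` replaces the mask multiply with a pad-and-gather: pad `J - 1` zero rows in front, then index overlapping windows. A NumPy sketch of the same trick (shapes assumed small for illustration):

```python
import numpy as np

L, N, J = 5, 2, 3
diff = np.arange(L * N, dtype=float).reshape(L, N)

diff_padded = np.pad(diff, ((J - 1, 0), (0, 0)))        # (L+J-1, N), zeros in front
index = np.arange(L)[:, None] + np.arange(J)[None, :]   # (L, J): window l-J+1 .. l
windows = diff_padded[index]                            # (L, J, N)
chi_mat = np.swapaxes(windows, -1, -2)                  # (L, N, J)

# row l holds the last J rows of `diff` ending at l, zero-filled at the start
assert np.allclose(chi_mat[0, :, -1], diff[0])
assert np.allclose(chi_mat[0, :, :-1], 0.0)
```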
@@ -402,8 +366,10 @@ def inverse_hessian_factors(
         return chi_mat

     L, N = alpha.shape
-
-
+
+    # changed to get_chi_matrix_2 after removing update_mask
+    S = get_chi_matrix_2(s, J)
+    Z = get_chi_matrix_2(z, J)

     # E: (L, J, J)
     Ij = pt.eye(J)[None, ...]
@@ -488,6 +454,7 @@ def bfgs_sample_dense(

     N = x.shape[-1]
     IdN = pt.eye(N)[None, ...]
+    IdN += IdN * REGULARISATION_TERM

     # inverse Hessian
     H_inv = (
@@ -503,7 +470,10 @@ def bfgs_sample_dense(

     logdet = 2.0 * pt.sum(pt.log(pt.abs(pt.diagonal(Lchol, axis1=-2, axis2=-1))), axis=-1)

-    mu = x - pt.
+    # mu = x - pt.einsum("ijk,ik->ij", H_inv, g)  # causes error: Multiple destroyers of g
+
+    batched_dot = pt.vectorize(pt.dot, signature="(ijk),(ilk)->(ij)")
+    mu = x - batched_dot(H_inv, pt.matrix_transpose(g[..., None]))

     phi = pt.matrix_transpose(
         # (L, N, 1)
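`pt.einsum("ijk,ik->ij", H_inv, g)` is the natural way to write `mu`, but per the inline comment it trips PyTensor's destroy-map checks ("Multiple destroyers of g"), hence the vectorized `pt.dot`. Numerically both are the same batched matrix-vector product, as this NumPy check shows (shapes are assumptions for illustration):

```python
import numpy as np

L, N = 4, 3
x, g = np.random.rand(L, N), np.random.rand(L, N)
H_inv = np.random.rand(L, N, N)

mu_einsum = x - np.einsum("ijk,ik->ij", H_inv, g)   # what the comment documents
mu_matmul = x - (H_inv @ g[..., None]).squeeze(-1)  # what batched_dot computes
assert np.allclose(mu_einsum, mu_matmul)
```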
@@ -564,23 +534,28 @@ def bfgs_sample_sparse(
     # qr_input: (L, N, 2J)
     qr_input = inv_sqrt_alpha_diag @ beta
     (Q, R), _ = pytensor.scan(fn=pt.nlinalg.qr, sequences=[qr_input], allow_gc=False)
+
     IdN = pt.eye(R.shape[1])[None, ...]
+    IdN += IdN * REGULARISATION_TERM
+
     Lchol_input = IdN + R @ gamma @ pt.matrix_transpose(R)

+    # TODO: make robust Lchol calcs more robust, ie. try exceptions, increase REGULARISATION_TERM if non-finite exists
     Lchol = pt.linalg.cholesky(Lchol_input, lower=False, check_finite=False, on_error="nan")

     logdet = 2.0 * pt.sum(pt.log(pt.abs(pt.diagonal(Lchol, axis1=-2, axis2=-1))), axis=-1)
     logdet += pt.sum(pt.log(alpha), axis=-1)

+    # inverse Hessian
+    # (L, N, N) + (L, N, 2J), (L, 2J, 2J), (L, 2J, N) -> (L, N, N)
+    H_inv = alpha_diag + (beta @ gamma @ pt.matrix_transpose(beta))
+
     # NOTE: changed the sign from "x + " to "x -" of the expression to match Stan which differs from Zhang et al., (2022). same for dense version.
-
-
-
-
-
-        # (L, N, N), (L, N) -> (L, N)
-        + pt.batched_dot((beta @ gamma @ pt.matrix_transpose(beta)), g)
-    )
+
+    # mu = x - pt.einsum("ijk,ik->ij", H_inv, g)  # causes error: Multiple destroyers of g
+
+    batched_dot = pt.vectorize(pt.dot, signature="(ijk),(ilk)->(ij)")
+    mu = x - batched_dot(H_inv, pt.matrix_transpose(g[..., None]))

     phi = pt.matrix_transpose(
         # (L, N, 1)
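The sparse path now materializes `H_inv = diag(alpha) + beta @ gamma @ beta^T` explicitly before the draw. A NumPy shape check of that low-rank-plus-diagonal structure (sizes assumed; `2J << N` is the point of the sparse variant):

```python
import numpy as np

L, N, J = 4, 10, 2
alpha = np.random.rand(L, N)
beta = np.random.rand(L, N, 2 * J)
gamma = np.random.rand(L, 2 * J, 2 * J)

alpha_diag = alpha[:, :, None] * np.eye(N)  # (L, N, N), alpha on each diagonal
H_inv = alpha_diag + beta @ gamma @ np.swapaxes(beta, -1, -2)
assert H_inv.shape == (L, N, N)
```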
@@ -588,8 +563,6 @@ def bfgs_sample_sparse(
         # (L, N, N), (L, N, M) -> (L, N, M)
         + sqrt_alpha_diag
         @ (
-            # (L, N, 2J), (L, 2J, M) -> (L, N, M)
-            # intermediate calcs below
             # (L, N, 2J), (L, 2J, 2J) -> (L, N, 2J)
             (Q @ (Lchol - IdN))
             # (L, 2J, N), (L, N, M) -> (L, 2J, M)
@@ -777,7 +750,6 @@ def make_pathfinder_body(
     num_draws: int,
     maxcor: int,
     num_elbo_draws: int,
-    epsilon: float,
     **compile_kwargs: dict,
 ) -> Function:
     """
@@ -793,8 +765,6 @@ def make_pathfinder_body(
         The maximum number of iterations for the L-BFGS algorithm.
     num_elbo_draws : int
         The number of draws for the Evidence Lower Bound (ELBO) estimation.
-    epsilon : float
-        The value used to filter out large changes in the direction of the update gradient at each iteration l in L. Iteration l is only accepted if delta_theta[l] * delta_grad[l] > epsilon * L2_norm(delta_grad[l]) for each l in L.
     compile_kwargs : dict
         Additional keyword arguments for the PyTensor compiler.

@@ -819,11 +789,10 @@ def make_pathfinder_body(

     num_draws = pt.constant(num_draws, "num_draws", dtype="int32")
     num_elbo_draws = pt.constant(num_elbo_draws, "num_elbo_draws", dtype="int32")
-    epsilon = pt.constant(epsilon, "epsilon", dtype="float64")
     maxcor = pt.constant(maxcor, "maxcor", dtype="int32")

-    alpha, s, z
-    beta, gamma = inverse_hessian_factors(alpha, s, z,
+    alpha, s, z = alpha_recover(x_full, g_full)
+    beta, gamma = inverse_hessian_factors(alpha, s, z, J=maxcor)

     # ignore initial point - x, g: (L, N)
     x = x_full[1:]
@@ -854,7 +823,7 @@ def make_pathfinder_body(

     # return psi, logP_psi, logQ_psi, elbo_argmax

-    pathfinder_body_fn =
+    pathfinder_body_fn = compile(
         [x_full, g_full],
         [psi, logP_psi, logQ_psi, elbo_argmax],
         **compile_kwargs,
@@ -933,11 +902,11 @@ def make_single_pathfinder_fn(
     x_base = DictToArrayBijection.map(ip).data

     # lbfgs
-    lbfgs = LBFGS(neg_logp_dlogp_func, maxcor, maxiter, ftol, gtol, maxls)
+    lbfgs = LBFGS(neg_logp_dlogp_func, maxcor, maxiter, ftol, gtol, maxls, epsilon)

     # pathfinder body
     pathfinder_body_fn = make_pathfinder_body(
-        logp_func, num_draws, maxcor, num_elbo_draws,
+        logp_func, num_draws, maxcor, num_elbo_draws, **compile_kwargs
     )
     rngs = find_rng_nodes(pathfinder_body_fn.maker.fgraph.outputs)

@@ -949,8 +918,8 @@ def make_single_pathfinder_fn(
         x0 = x_base + jitter_value
         x, g, lbfgs_niter, lbfgs_status = lbfgs.minimize(x0)

-        if lbfgs_status
-            raise LBFGSInitFailed()
+        if lbfgs_status in {LBFGSStatus.INIT_FAILED, LBFGSStatus.INIT_FAILED_LOW_UPDATE_PCT}:
+            raise LBFGSInitFailed(lbfgs_status)
         elif lbfgs_status == LBFGSStatus.LBFGS_FAILED:
             raise LBFGSException()

@@ -1188,7 +1157,7 @@ class MultiPathfinderResult:
     elbo_argmax: NDArray | None = None
     lbfgs_status: Counter = field(default_factory=Counter)
     path_status: Counter = field(default_factory=Counter)
-    importance_sampling: str = "
+    importance_sampling: str | None = "psis"
     warnings: list[str] = field(default_factory=list)
     pareto_k: float | None = None

@@ -1257,7 +1226,7 @@ class MultiPathfinderResult:
     def with_importance_sampling(
         self,
         num_draws: int,
-        method: Literal["psis", "psir", "identity"
+        method: Literal["psis", "psir", "identity"] | None,
         random_seed: int | None = None,
     ) -> Self:
         """perform importance sampling"""
@@ -1388,15 +1357,16 @@ def _get_status_warning(mpr: MultiPathfinderResult) -> list[str]:
     warnings = []

     lbfgs_status_message = {
-        LBFGSStatus.MAX_ITER_REACHED: "LBFGS maximum number of iterations reached. Consider increasing maxiter if this occurence is high relative to the number of paths.",
-        LBFGSStatus.INIT_FAILED: "LBFGS failed to
-        LBFGSStatus.
+        LBFGSStatus.MAX_ITER_REACHED: "MAX_ITER_REACHED: LBFGS maximum number of iterations reached. Consider increasing maxiter if this occurence is high relative to the number of paths.",
+        LBFGSStatus.INIT_FAILED: "INIT_FAILED: LBFGS failed to initialize. Consider reparameterizing the model or reducing jitter if this occurence is high relative to the number of paths.",
+        LBFGSStatus.NON_FINITE: "NON_FINITE: LBFGS objective function produced inf or nan at the last iteration. Consider reparameterizing the model or adjusting the pathfinder arguments if this occurence is high relative to the number of paths.",
+        LBFGSStatus.LOW_UPDATE_PCT: "LOW_UPDATE_PCT: Majority of LBFGS iterations were not accepted due to the either: (1) LBFGS function or gradient values containing too many inf or nan values or (2) gradient changes being significantly large, set by epsilon. Consider reparameterizing the model, adjusting initvals or jitter or other pathfinder arguments if this occurence is high relative to the number of paths.",
+        LBFGSStatus.INIT_FAILED_LOW_UPDATE_PCT: "INIT_FAILED_LOW_UPDATE_PCT: LBFGS failed to initialize due to the either: (1) LBFGS function or gradient values containing too many inf or nan values or (2) gradient changes being significantly large, set by epsilon. Consider reparameterizing the model, adjusting initvals or jitter or other pathfinder arguments if this occurence is high relative to the number of paths.",
     }

     path_status_message = {
-        PathStatus.ELBO_ARGMAX_AT_ZERO: "ELBO argmax at zero refers to the first iteration during LBFGS. A high occurrence suggests the model's default initial point + jitter
-        PathStatus.
-        PathStatus.INVALID_LOGQ: "Invalid logQ values occur when a path's logQ values are not finite. The failed path is not included in samples when importance sampling is used. Consider reparameterizing the model or adjusting the pathfinder arguments if this occurence is high relative to the number of paths.",
+        PathStatus.ELBO_ARGMAX_AT_ZERO: "ELBO_ARGMAX_AT_ZERO: ELBO argmax at zero refers to the first iteration during LBFGS. A high occurrence suggests the model's default initial point + jitter values are concentrated in high-density regions in the target distribution and may result in poor exploration of the parameter space. Consider increasing jitter if this occurrence is high relative to the number of paths.",
+        PathStatus.INVALID_LOGQ: "INVALID_LOGQ: Invalid logQ values occur when a path's logQ values are not finite. The failed path is not included in samples when importance sampling is used. Consider reparameterizing the model or adjusting the pathfinder arguments if this occurence is high relative to the number of paths.",
     }

     for lbfgs_status in mpr.lbfgs_status:
@@ -1423,7 +1393,7 @@ def multipath_pathfinder(
     num_elbo_draws: int,
     jitter: float,
     epsilon: float,
-    importance_sampling: Literal["psis", "psir", "identity"
+    importance_sampling: Literal["psis", "psir", "identity"] | None,
     progressbar: bool,
     concurrent: Literal["thread", "process"] | None,
     random_seed: RandomSeed,
@@ -1459,8 +1429,14 @@ def multipath_pathfinder(
         Amount of jitter to apply to initial points (default is 2.0). Note that Pathfinder may be highly sensitive to the jitter value. It is recommended to increase num_paths when increasing the jitter value.
     epsilon: float
         value used to filter out large changes in the direction of the update gradient at each iteration l in L. Iteration l is only accepted if delta_theta[l] * delta_grad[l] > epsilon * L2_norm(delta_grad[l]) for each l in L. (default is 1e-8).
-    importance_sampling : str, optional
-
+    importance_sampling : str, None, optional
+        Method to apply sampling based on log importance weights (logP - logQ).
+        "psis" : Pareto Smoothed Importance Sampling (default)
+            Recommended for more stable results.
+        "psir" : Pareto Smoothed Importance Resampling
+            Less stable than PSIS.
+        "identity" : Applies log importance weights directly without resampling.
+        None : No importance sampling weights. Returns raw samples of size (num_paths, num_draws_per_path, N) where N is number of model parameters. Other methods return samples of size (num_draws, N).
     progressbar : bool, optional
         Whether to display a progress bar (default is False). Setting this to True will likely increase the computation time.
     random_seed : RandomSeed, optional
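A hedged usage sketch of the new `None` option through the public `pymc_extras.fit` entry point (toy model; keyword passthrough from `fit` to the pathfinder backend is assumed):

```python
import numpy as np
import pymc as pm
import pymc_extras as pmx

with pm.Model():
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("y", mu, 1.0, observed=np.random.randn(50))

    idata = pmx.fit(
        method="pathfinder",
        num_paths=4,
        num_draws=400,
        importance_sampling=None,  # keep raw (num_paths, num_draws_per_path, N) draws
        random_seed=0,
    )
```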
@@ -1482,12 +1458,6 @@ def multipath_pathfinder(
         The result containing samples and other information from the Multi-Path Pathfinder algorithm.
     """

-    valid_importance_sampling = ["psis", "psir", "identity", "none", None]
-    if importance_sampling is None:
-        importance_sampling = "none"
-    if importance_sampling.lower() not in valid_importance_sampling:
-        raise ValueError(f"Invalid importance sampling method: {importance_sampling}")
-
     *path_seeds, choice_seed = _get_seeds_per_chain(random_seed, num_paths + 1)

     pathfinder_config = PathfinderConfig(
@@ -1521,12 +1491,20 @@ def multipath_pathfinder(
     results = []
     compute_start = time.time()
     try:
-
+        desc = f"Paths Complete: {{path_idx}}/{num_paths}"
+        progress = CustomProgress(
+            "[progress.description]{task.description}",
+            BarColumn(),
+            "[progress.percentage]{task.percentage:>3.0f}%",
+            TimeRemainingColumn(),
+            TextColumn("/"),
+            TimeElapsedColumn(),
             console=Console(theme=default_progress_theme),
             disable=not progressbar,
-        )
-
-
+        )
+        with progress:
+            task = progress.add_task(desc.format(path_idx=0), completed=0, total=num_paths)
+            for path_idx, result in enumerate(generator, start=1):
                 try:
                     if isinstance(result, Exception):
                         raise result
@@ -1552,7 +1530,15 @@ def multipath_pathfinder(
                             lbfgs_status=LBFGSStatus.LBFGS_FAILED,
                         )
                     )
-
+                finally:
+                    # TODO: display LBFGS and Path Status in real time
+                    progress.update(
+                        task,
+                        description=desc.format(path_idx=path_idx),
+                        completed=path_idx,
+                    )
+            # Ensure the progress bar visually reaches 100% and shows 'Completed'
+            progress.update(task, completed=num_paths, description="Completed")
     except (KeyboardInterrupt, StopIteration) as e:
         # if exception is raised here, MultiPathfinderResult will collect all the successful results and report the results. User is free to abort the process earlier and the results will still be collected and return az.InferenceData.
         if isinstance(e, StopIteration):
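`CustomProgress` is pymc's `rich.progress.Progress` subclass; the column layout added here renders as `description | bar | % | remaining / elapsed`. A plain `rich` sketch of the same layout (loop body is a stand-in for one pathfinder path):

```python
import time

from rich.progress import BarColumn, Progress, TextColumn, TimeElapsedColumn, TimeRemainingColumn

num_paths = 4
with Progress(
    "[progress.description]{task.description}",
    BarColumn(),
    "[progress.percentage]{task.percentage:>3.0f}%",
    TimeRemainingColumn(),
    TextColumn("/"),
    TimeElapsedColumn(),
) as progress:
    task = progress.add_task(f"Paths Complete: 0/{num_paths}", total=num_paths)
    for path_idx in range(1, num_paths + 1):
        time.sleep(0.1)  # stand-in for fitting one path
        progress.update(task, description=f"Paths Complete: {path_idx}/{num_paths}", completed=path_idx)
```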
@@ -1602,11 +1588,11 @@ def fit_pathfinder(
     maxiter: int = 1000,  # L^max
     ftol: float = 1e-5,
     gtol: float = 1e-8,
-    maxls=1000,
+    maxls: int = 1000,
     num_elbo_draws: int = 10,  # K
     jitter: float = 2.0,
     epsilon: float = 1e-8,
-    importance_sampling: Literal["psis", "psir", "identity"
+    importance_sampling: Literal["psis", "psir", "identity"] | None = "psis",
     progressbar: bool = True,
     concurrent: Literal["thread", "process"] | None = None,
     random_seed: RandomSeed | None = None,
@@ -1614,6 +1600,7 @@ def fit_pathfinder(
     inference_backend: Literal["pymc", "blackjax"] = "pymc",
     pathfinder_kwargs: dict = {},
     compile_kwargs: dict = {},
+    initvals: dict | None = None,
 ) -> az.InferenceData:
     """
     Fit the Pathfinder Variational Inference algorithm.
@@ -1646,8 +1633,15 @@ def fit_pathfinder(
         Amount of jitter to apply to initial points (default is 2.0). Note that Pathfinder may be highly sensitive to the jitter value. It is recommended to increase num_paths when increasing the jitter value.
     epsilon: float
         value used to filter out large changes in the direction of the update gradient at each iteration l in L. Iteration l is only accepted if delta_theta[l] * delta_grad[l] > epsilon * L2_norm(delta_grad[l]) for each l in L. (default is 1e-8).
-    importance_sampling : str, optional
-
+    importance_sampling : str, None, optional
+        Method to apply sampling based on log importance weights (logP - logQ).
+        Options are:
+
+        - "psis" : Pareto Smoothed Importance Sampling (default). Usually most stable.
+        - "psir" : Pareto Smoothed Importance Resampling. Less stable than PSIS.
+        - "identity" : Applies log importance weights directly without resampling.
+        - None : No importance sampling weights. Returns raw samples of size (num_paths, num_draws_per_path, N) where N is number of model parameters. Other methods return samples of size (num_draws, N).
+
     progressbar : bool, optional
         Whether to display a progress bar (default is True). Setting this to False will likely reduce the computation time.
     random_seed : RandomSeed, optional
@@ -1662,10 +1656,13 @@ def fit_pathfinder(
         Additional keyword arguments for the Pathfinder algorithm.
     compile_kwargs
         Additional keyword arguments for the PyTensor compiler. If not provided, the default linker is "cvm_nogc".
+    initvals: dict | None = None
+        Initial values for the model parameters, as str:ndarray key-value pairs. Paritial initialization is permitted.
+        If None, the model's default initial values are used.

     Returns
     -------
-    arviz.InferenceData
+    :class:`~arviz.InferenceData`
         The inference data containing the results of the Pathfinder algorithm.

     References
@@ -1674,6 +1671,23 @@ def fit_pathfinder(
     """

     model = modelcontext(model)
+
+    if initvals is not None:
+        model = pm.model.fgraph.clone_model(model)  # Create a clone of the model
+        for (
+            rv_name,
+            ivals,
+        ) in initvals.items():  # Set the initial values for the variables in the clone
+            model.set_initval(model.named_vars[rv_name], ivals)
+
+    valid_importance_sampling = {"psis", "psir", "identity", None}
+
+    if importance_sampling is not None:
+        importance_sampling = importance_sampling.lower()
+
+    if importance_sampling not in valid_importance_sampling:
+        raise ValueError(f"Invalid importance sampling method: {importance_sampling}")
+
     N = DictToArrayBijection.map(model.initial_point()).data.shape[0]

     if maxcor is None:
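A hedged sketch of the new `initvals` argument (toy model; partial initialization by variable name, with passthrough from `pmx.fit` assumed):

```python
import numpy as np
import pymc as pm
import pymc_extras as pmx

with pm.Model():
    mu = pm.Normal("mu", 0.0, 10.0)
    sigma = pm.HalfNormal("sigma", 1.0)
    pm.Normal("y", mu, sigma, observed=np.random.randn(30))

    idata = pmx.fit(
        method="pathfinder",
        initvals={"mu": np.array(0.5)},  # sigma keeps its default initial value
        random_seed=0,
    )
```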
@@ -1704,8 +1718,9 @@ def fit_pathfinder(
         )
         pathfinder_samples = mp_result.samples
     elif inference_backend == "blackjax":
-
-
+        import blackjax
+        import jax
+
         if version.parse(blackjax.__version__).major < 1:
             raise ImportError("fit_pathfinder requires blackjax 1.0 or above")

@@ -1743,4 +1758,7 @@ def fit_pathfinder(
         model=model,
         importance_sampling=importance_sampling,
     )
+
+    idata = add_data_to_inferencedata(idata, progressbar, model, compile_kwargs)
+
     return idata

pymc_extras/statespace/core/statespace.py
@@ -28,7 +28,6 @@ from pymc_extras.statespace.filters import (
 )
 from pymc_extras.statespace.filters.distributions import (
     LinearGaussianStateSpace,
-    MvNormalSVD,
     SequenceMvNormal,
 )
 from pymc_extras.statespace.filters.utilities import stabilize
@@ -707,7 +706,7 @@ class PyMCStateSpace:
         with pymc_model:
             for param_name in self.param_names:
                 param = getattr(pymc_model, param_name, None)
-                if param:
+                if param is not None:
                     found_params.append(param.name)

         missing_params = list(set(self.param_names) - set(found_params))
@@ -746,7 +745,7 @@ class PyMCStateSpace:
         with pymc_model:
             for data_name in data_names:
                 data = getattr(pymc_model, data_name, None)
-                if data:
+                if data is not None:
                     found_data.append(data.name)

         missing_data = list(set(data_names) - set(found_data))
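Both `if param:`/`if data:` fixes matter because PyMC model attributes are symbolic variables: their truth value is not a cheap emptiness check, and coercing one to `bool` raises, so `is not None` is the only safe existence test. A minimal demonstration:

```python
import pytensor.tensor as pt

x = pt.vector("x")
try:
    bool(x)  # symbolic variables refuse boolean coercion
except TypeError as err:
    print(err)
# hence the change to an explicit `is not None` check
```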
@@ -2233,7 +2232,9 @@ class PyMCStateSpace:
         if shock_trajectory is None:
             shock_trajectory = pt.zeros((n_steps, self.k_posdef))
             if Q is not None:
-                init_shock =
+                init_shock = pm.MvNormal(
+                    "initial_shock", mu=0, cov=Q, dims=[SHOCK_DIM], method="svd"
+                )
             else:
                 init_shock = pm.Deterministic(
                     "initial_shock",