pymc-extras 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1746 @@
1
+ # Copyright 2022 The PyMC Developers
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import collections
16
+ import logging
17
+ import time
18
+ import warnings as _warnings
19
+
20
+ from collections import Counter
21
+ from collections.abc import Callable, Iterator
22
+ from dataclasses import asdict, dataclass, field, replace
23
+ from enum import Enum, auto
24
+ from importlib.util import find_spec
25
+ from typing import Literal, TypeAlias
26
+
27
+ import arviz as az
28
+ import blackjax
29
+ import filelock
30
+ import jax
31
+ import numpy as np
32
+ import pymc as pm
33
+ import pytensor
34
+ import pytensor.tensor as pt
35
+
36
+ from numpy.typing import NDArray
37
+ from packaging import version
38
+ from pymc import Model
39
+ from pymc.backends.arviz import coords_and_dims_for_inferencedata
40
+ from pymc.blocking import DictToArrayBijection, RaveledVars
41
+ from pymc.initial_point import make_initial_point_fn
42
+ from pymc.model import modelcontext
43
+ from pymc.model.core import Point
44
+ from pymc.pytensorf import (
45
+ compile_pymc,
46
+ find_rng_nodes,
47
+ reseed_rngs,
48
+ )
49
+ from pymc.sampling.jax import get_jaxified_graph
50
+ from pymc.util import (
51
+ CustomProgress,
52
+ RandomSeed,
53
+ _get_seeds_per_chain,
54
+ default_progress_theme,
55
+ get_default_varnames,
56
+ )
57
+ from pytensor.compile.function.types import Function
58
+ from pytensor.compile.mode import FAST_COMPILE, Mode
59
+ from pytensor.graph import Apply, Op, vectorize_graph
60
+ from pytensor.tensor import TensorConstant, TensorVariable
61
+ from rich.console import Console, Group
62
+ from rich.padding import Padding
63
+ from rich.table import Table
64
+ from rich.text import Text
65
+
66
+ # TODO: switch to typing.Self once the minimum supported Python version is greater than 3.10
67
+ from typing_extensions import Self
68
+
69
+ from pymc_extras.inference.pathfinder.importance_sampling import (
70
+ importance_sampling as _importance_sampling,
71
+ )
72
+ from pymc_extras.inference.pathfinder.lbfgs import (
73
+ LBFGS,
74
+ LBFGSException,
75
+ LBFGSInitFailed,
76
+ LBFGSStatus,
77
+ )
78
+
79
+ logger = logging.getLogger(__name__)
80
+ _warnings.filterwarnings(
81
+ "ignore", category=FutureWarning, message="compile_pymc was renamed to compile"
82
+ )
83
+
84
+ REGULARISATION_TERM = 1e-8
85
+ DEFAULT_LINKER = "cvm_nogc"
86
+
87
+ SinglePathfinderFn: TypeAlias = Callable[[int], "PathfinderResult"]
88
+
89
+
90
+ def get_jaxified_logp_of_ravel_inputs(model: Model, jacobian: bool = True) -> Callable:
91
+ """
92
+ Get a JAX function that computes the log-probability of a PyMC model with ravelled inputs.
93
+
94
+ Parameters
95
+ ----------
96
+ model : Model
97
+ PyMC model to compute log-probability and gradient.
98
+ jacobian : bool, optional
99
+ Whether to include the Jacobian in the log-probability computation (default True). Setting to False (not recommended) may result in very high values for the Pareto k diagnostic.
100
+
101
+ Returns
102
+ -------
103
+ Function
104
+ A JAX function that computes the log-probability of a PyMC model with ravelled inputs.
105
+ """
106
+
107
+ # TODO: JAX: test if we should get jaxified graph of dlogp as well
108
+ new_logprob, new_input = pm.pytensorf.join_nonshared_inputs(
109
+ model.initial_point(), (model.logp(jacobian=jacobian),), model.value_vars, ()
110
+ )
111
+
112
+ logp_func_list = get_jaxified_graph([new_input], new_logprob)
113
+
114
+ def logp_func(x):
115
+ return logp_func_list(x)[0]
116
+
117
+ return logp_func
118
+
119
+
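A minimal usage sketch (the two-parameter model below is illustrative, not from the package):

```python
import numpy as np
import pymc as pm

with pm.Model() as m:
    mu = pm.Normal("mu")
    sigma = pm.HalfNormal("sigma")
    pm.Normal("y", mu, sigma, observed=np.random.default_rng(0).normal(size=10))

logp = get_jaxified_logp_of_ravel_inputs(m)
# the ravelled input concatenates the transformed free variables, here [mu, log(sigma)]
print(logp(np.zeros(2)))
```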
120
+ def get_logp_dlogp_of_ravel_inputs(
121
+ model: Model, jacobian: bool = True, **compile_kwargs
122
+ ) -> Function:
123
+ """
124
+ Get the log-probability and its gradient for a PyMC model with ravelled inputs.
125
+
126
+ Parameters
127
+ ----------
128
+ model : Model
129
+ PyMC model to compute log-probability and gradient.
130
+ jacobian : bool, optional
131
+ Whether to include the Jacobian in the log-probability computation (default True). Setting to False (not recommended) may result in very high values for the Pareto k diagnostic.
132
+ **compile_kwargs : dict
133
+ Additional keyword arguments to pass to the compile function.
134
+
135
+ Returns
136
+ -------
137
+ Function
138
+ A compiled PyTensor function that computes the log-probability and its gradient given ravelled inputs.
139
+ """
140
+
141
+ (logP, dlogP), inputs = pm.pytensorf.join_nonshared_inputs(
142
+ model.initial_point(),
143
+ [model.logp(jacobian=jacobian), model.dlogp(jacobian=jacobian)],
144
+ model.value_vars,
145
+ )
146
+ logp_dlogp_fn = compile_pymc([inputs], (logP, dlogP), **compile_kwargs)
147
+ logp_dlogp_fn.trust_input = True
148
+
149
+ return logp_dlogp_fn
150
+
151
+
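A quick sketch of the compiled function, reusing the illustrative model `m` from the sketch above:

```python
logp_dlogp = get_logp_dlogp_of_ravel_inputs(m)
val, grad = logp_dlogp(np.zeros(2))
print(val, grad.shape)  # scalar log-probability and a gradient of shape (2,)
```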
152
+ def convert_flat_trace_to_idata(
153
+ samples: NDArray,
154
+ include_transformed: bool = False,
155
+ postprocessing_backend: Literal["cpu", "gpu"] = "cpu",
156
+ inference_backend: Literal["pymc", "blackjax"] = "pymc",
157
+ model: Model | None = None,
158
+ importance_sampling: Literal["psis", "psir", "identity", "none"] = "psis",
159
+ ) -> az.InferenceData:
160
+ """convert flattened samples to arviz InferenceData format.
161
+
162
+ Parameters
163
+ ----------
164
+ samples : NDArray
165
+ flattened samples
166
+ include_transformed : bool
167
+ whether to include transformed variables
168
+ postprocessing_backend : str
169
+ backend for postprocessing transformations, either "cpu" or "gpu"
170
+ inference_backend : str
171
+ backend for inference, either "pymc" or "blackjax"
172
+ model : Model | None
173
+ pymc model for variable transformations
174
+ importance_sampling : str
175
+ importance sampling method used, affects input samples shape
176
+
177
+ Returns
178
+ -------
179
+ InferenceData
180
+ arviz inference data object
181
+ """
182
+
183
+ if importance_sampling == "none":
184
+ # samples.ndim == 3 in this case, otherwise ndim == 2
185
+ num_paths, num_pdraws, N = samples.shape
186
+ samples = samples.reshape(-1, N)
187
+
188
+ model = modelcontext(model)
189
+ ip = model.initial_point()
190
+ ip_point_map_info = DictToArrayBijection.map(ip).point_map_info
191
+ trace = collections.defaultdict(list)
192
+ for sample in samples:
193
+ raveld_vars = RaveledVars(sample, ip_point_map_info)
194
+ point = DictToArrayBijection.rmap(raveld_vars, ip)
195
+ for p, v in point.items():
196
+ trace[p].append(v.tolist())
197
+
198
+ trace = {k: np.asarray(v)[None, ...] for k, v in trace.items()}
199
+
200
+ var_names = model.unobserved_value_vars
201
+ vars_to_sample = list(get_default_varnames(var_names, include_transformed=include_transformed))
202
+ logger.info("Transforming variables...")
203
+
204
+ if inference_backend == "pymc":
205
+ new_shapes = [v.ndim * (None,) for v in trace.values()]
206
+ replace = {
207
+ var: pt.tensor(dtype="float64", shape=new_shapes[i])
208
+ for i, var in enumerate(model.value_vars)
209
+ }
210
+
211
+ outputs = vectorize_graph(vars_to_sample, replace=replace)
212
+
213
+ fn = pytensor.function(
214
+ inputs=[*list(replace.values())],
215
+ outputs=outputs,
216
+ mode=FAST_COMPILE,
217
+ on_unused_input="ignore",
218
+ )
219
+ fn.trust_input = True
220
+ result = fn(*list(trace.values()))
221
+
222
+ if importance_sampling == "none":
223
+ result = [res.reshape(num_paths, num_pdraws, *res.shape[2:]) for res in result]
224
+
225
+ elif inference_backend == "blackjax":
226
+ jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
227
+ result = jax.vmap(jax.vmap(jax_fn))(
228
+ *jax.device_put(list(trace.values()), jax.devices(postprocessing_backend)[0])
229
+ )
230
+
231
+ trace = {v.name: r for v, r in zip(vars_to_sample, result)}
232
+ coords, dims = coords_and_dims_for_inferencedata(model)
233
+ idata = az.from_dict(trace, dims=dims, coords=coords)
234
+
235
+ return idata
236
+
237
+
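A hedged usage sketch with synthetic ravelled draws for the illustrative model `m` (shape (num_samples, N), as expected when importance sampling is applied):

```python
flat = np.random.default_rng(1).normal(size=(100, 2))  # synthetic draws, not a real posterior
idata = convert_flat_trace_to_idata(flat, model=m)
print(idata.posterior)  # chain=1, draw=100, with mu and sigma back-transformed
```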
238
+ def alpha_recover(
239
+ x: TensorVariable, g: TensorVariable, epsilon: TensorVariable
240
+ ) -> tuple[TensorVariable, TensorVariable, TensorVariable, TensorVariable]:
241
+ """compute the diagonal elements of the inverse Hessian at each iterations of L-BFGS and filter updates.
242
+
243
+ Parameters
244
+ ----------
245
+ x : TensorVariable
246
+ position array, shape (L+1, N)
247
+ g : TensorVariable
248
+ gradient array, shape (L+1, N)
249
+ epsilon : float
250
+ threshold for filtering updates based on inner product of position
251
+ and gradient differences
252
+
253
+ Returns
254
+ -------
255
+ alpha : TensorVariable
256
+ diagonal elements of the inverse Hessian at each iteration of L-BFGS, shape (L, N)
257
+ s : TensorVariable
258
+ position differences, shape (L, N)
259
+ z : TensorVariable
260
+ gradient differences, shape (L, N)
261
+ update_mask : TensorVariable
262
+ mask for filtering updates, shape (L,)
263
+
264
+ Notes
265
+ -----
266
+ shapes: L=batch_size, N=num_params
267
+ """
268
+
269
+ def compute_alpha_l(alpha_lm1, s_l, z_l) -> TensorVariable:
270
+ # alpha_lm1: (N,)
271
+ # s_l: (N,)
272
+ # z_l: (N,)
273
+ a = z_l.T @ pt.diag(alpha_lm1) @ z_l
274
+ b = z_l.T @ s_l
275
+ c = s_l.T @ pt.diag(1.0 / alpha_lm1) @ s_l
276
+ inv_alpha_l = (
277
+ a / (b * alpha_lm1)
278
+ + z_l ** 2 / b
279
+ - (a * s_l ** 2) / (b * c * alpha_lm1**2)
280
+ ) # fmt:off
281
+ return 1.0 / inv_alpha_l
282
+
283
+ def return_alpha_lm1(alpha_lm1, s_l, z_l) -> TensorVariable:
284
+ return alpha_lm1[-1]
285
+
286
+ def scan_body(update_mask_l, s_l, z_l, alpha_lm1) -> TensorVariable:
287
+ return pt.switch(
288
+ update_mask_l,
289
+ compute_alpha_l(alpha_lm1, s_l, z_l),
290
+ return_alpha_lm1(alpha_lm1, s_l, z_l),
291
+ )
292
+
293
+ Lp1, N = x.shape
294
+ s = pt.diff(x, axis=0)
295
+ z = pt.diff(g, axis=0)
296
+ alpha_l_init = pt.ones(N)
297
+ sz = (s * z).sum(axis=-1)
298
+ # update_mask = sz > epsilon * pt.linalg.norm(z, axis=-1)
299
+ # pt.linalg.norm does not work with JAX!!
300
+ update_mask = sz > epsilon * pt.sqrt(pt.sum(z**2, axis=-1))
301
+
302
+ alpha, _ = pytensor.scan(
303
+ fn=scan_body,
304
+ outputs_info=alpha_l_init,
305
+ sequences=[update_mask, s, z],
306
+ n_steps=Lp1 - 1,
307
+ allow_gc=False,
308
+ )
309
+
310
+ # assert np.all(alpha.eval() > 0), "alpha cannot be negative"
311
+ # alpha: (L, N), update_mask: (L,)
312
+ return alpha, s, z, update_mask
313
+
314
+
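A small numeric check on a toy quadratic target f(x) = 0.5 * x @ x, whose gradient equals x, so the true inverse-Hessian diagonal is 1 everywhere (the optimisation history below is synthetic):

```python
import numpy as np
import pytensor.tensor as pt

x_hist = np.array([[2.0, -2.0], [1.0, -1.0], [0.5, -0.5]])  # positions, (L+1, N)
g_hist = x_hist.copy()                                      # grad of 0.5 * x @ x is x
x_, g_ = pt.matrix("x"), pt.matrix("g")
alpha, s, z, mask = alpha_recover(x_, g_, epsilon=pt.constant(1e-11))
print(alpha.eval({x_: x_hist, g_: g_hist}))  # stays at ~1.0 for a unit Hessian
```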
315
+ def inverse_hessian_factors(
316
+ alpha: TensorVariable,
317
+ s: TensorVariable,
318
+ z: TensorVariable,
319
+ update_mask: TensorVariable,
320
+ J: TensorConstant,
321
+ ) -> tuple[TensorVariable, TensorVariable]:
322
+ """compute the inverse hessian factors for the BFGS approximation.
323
+
324
+ Parameters
325
+ ----------
326
+ alpha : TensorVariable
327
+ diagonal scaling matrix, shape (L, N)
328
+ s : TensorVariable
329
+ position differences, shape (L, N)
330
+ z : TensorVariable
331
+ gradient differences, shape (L, N)
332
+ update_mask : TensorVariable
333
+ mask for filtering updates, shape (L,)
334
+ J : TensorConstant
335
+ history size for L-BFGS
336
+
337
+ Returns
338
+ -------
339
+ beta : TensorVariable
340
+ low-rank update matrix, shape (L, N, 2J)
341
+ gamma : TensorVariable
342
+ low-rank update matrix, shape (L, 2J, 2J)
343
+
344
+ Notes
345
+ -----
346
+ shapes: L=batch_size, N=num_params, J=history_size
347
+ """
348
+
349
+ # NOTE: get_chi_matrix_1 is a modified version of get_chi_matrix_2 that closely follows Zhang et al. (2022)
350
+ # NOTE: get_chi_matrix_2 comes from blackjax and may be incorrectly implemented
351
+
352
+ def get_chi_matrix_1(
353
+ diff: TensorVariable, update_mask: TensorVariable, J: TensorConstant
354
+ ) -> TensorVariable:
355
+ L, N = diff.shape
356
+ j_last = pt.as_tensor(J - 1) # since indexing starts at 0
357
+
358
+ def chi_update(chi_lm1, diff_l) -> TensorVariable:
359
+ chi_l = pt.roll(chi_lm1, -1, axis=0)
360
+ return pt.set_subtensor(chi_l[j_last], diff_l)
361
+
362
+ def no_op(chi_lm1, diff_l) -> TensorVariable:
363
+ return chi_lm1
364
+
365
+ def scan_body(update_mask_l, diff_l, chi_lm1) -> TensorVariable:
366
+ return pt.switch(update_mask_l, chi_update(chi_lm1, diff_l), no_op(chi_lm1, diff_l))
367
+
368
+ chi_init = pt.zeros((J, N))
369
+ chi_mat, _ = pytensor.scan(
370
+ fn=scan_body,
371
+ outputs_info=chi_init,
372
+ sequences=[
373
+ update_mask,
374
+ diff,
375
+ ],
376
+ allow_gc=False,
377
+ )
378
+
379
+ chi_mat = pt.matrix_transpose(chi_mat)
380
+
381
+ # (L, N, J)
382
+ return chi_mat
383
+
384
+ def get_chi_matrix_2(
385
+ diff: TensorVariable, update_mask: TensorVariable, J: TensorConstant
386
+ ) -> TensorVariable:
387
+ L, N = diff.shape
388
+
389
+ diff_masked = update_mask[:, None] * diff
390
+
391
+ # diff_padded: (L+J, N)
392
+ pad_width = pt.zeros(shape=(2, 2), dtype="int32")
393
+ pad_width = pt.set_subtensor(pad_width[0, 0], J)
394
+ diff_padded = pt.pad(diff_masked, pad_width, mode="constant")
395
+
396
+ index = pt.arange(L)[:, None] + pt.arange(J)[None, :]
397
+ index = index.reshape((L, J))
398
+
399
+ chi_mat = pt.matrix_transpose(diff_padded[index])
400
+
401
+ # (L, N, J)
402
+ return chi_mat
403
+
404
+ L, N = alpha.shape
405
+ S = get_chi_matrix_1(s, update_mask, J)
406
+ Z = get_chi_matrix_1(z, update_mask, J)
407
+
408
+ # E: (L, J, J)
409
+ Ij = pt.eye(J)[None, ...]
410
+ E = pt.triu(pt.matrix_transpose(S) @ Z)
411
+ E += Ij * REGULARISATION_TERM
412
+
413
+ # eta: (L, J)
414
+ eta = pt.diagonal(E, axis1=-2, axis2=-1)
415
+
416
+ # beta: (L, N, 2J)
417
+ alpha_diag, _ = pytensor.scan(lambda a: pt.diag(a), sequences=[alpha])
418
+ beta = pt.concatenate([alpha_diag @ Z, S], axis=-1)
419
+
420
+ # using solve is more performant and numerically precise than computing the inverse: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.linalg.inv.html
421
+
422
+ # E_inv: (L, J, J)
423
+ E_inv = pt.slinalg.solve_triangular(E, Ij, check_finite=False)
424
+ eta_diag, _ = pytensor.scan(pt.diag, sequences=[eta])
425
+
426
+ # block_dd: (L, J, J)
427
+ block_dd = (
428
+ pt.matrix_transpose(E_inv) @ (eta_diag + pt.matrix_transpose(Z) @ alpha_diag @ Z) @ E_inv
429
+ )
430
+
431
+ # (L, J, 2J)
432
+ gamma_top = pt.concatenate([pt.zeros((L, J, J)), -E_inv], axis=-1)
433
+
434
+ # (L, J, 2J)
435
+ gamma_bottom = pt.concatenate([-pt.matrix_transpose(E_inv), block_dd], axis=-1)
436
+
437
+ # (L, 2J, 2J)
438
+ gamma = pt.concatenate([gamma_top, gamma_bottom], axis=1)
439
+
440
+ return beta, gamma
441
+
442
+
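Continuing the toy history above, the factors come out with the documented shapes (a history size J of 2 is assumed here):

```python
J = pt.constant(2, dtype="int32")
beta, gamma = inverse_hessian_factors(alpha, s, z, mask, J=J)
print(beta.eval({x_: x_hist, g_: g_hist}).shape)   # (L, N, 2J) = (2, 2, 4)
print(gamma.eval({x_: x_hist, g_: g_hist}).shape)  # (L, 2J, 2J) = (2, 4, 4)
```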
443
+ def bfgs_sample_dense(
444
+ x: TensorVariable,
445
+ g: TensorVariable,
446
+ alpha: TensorVariable,
447
+ beta: TensorVariable,
448
+ gamma: TensorVariable,
449
+ alpha_diag: TensorVariable,
450
+ inv_sqrt_alpha_diag: TensorVariable,
451
+ sqrt_alpha_diag: TensorVariable,
452
+ u: TensorVariable,
453
+ ) -> tuple[TensorVariable, TensorVariable]:
454
+ """sample from the BFGS approximation using dense matrix operations.
455
+
456
+ Parameters
457
+ ----------
458
+ x : TensorVariable
459
+ position array, shape (L, N)
460
+ g : TensorVariable
461
+ gradient array, shape (L, N)
462
+ alpha : TensorVariable
463
+ diagonal scaling matrix, shape (L, N)
464
+ beta : TensorVariable
465
+ low-rank update matrix, shape (L, N, 2J)
466
+ gamma : TensorVariable
467
+ low-rank update matrix, shape (L, 2J, 2J)
468
+ alpha_diag : TensorVariable
469
+ diagonal matrix of alpha, shape (L, N, N)
470
+ inv_sqrt_alpha_diag : TensorVariable
471
+ inverse sqrt of alpha diagonal, shape (L, N, N)
472
+ sqrt_alpha_diag : TensorVariable
473
+ sqrt of alpha diagonal, shape (L, N, N)
474
+ u : TensorVariable
475
+ random normal samples, shape (L, M, N)
476
+
477
+ Returns
478
+ -------
479
+ phi : TensorVariable
480
+ samples from the approximation, shape (L, M, N)
481
+ logdet : TensorVariable
482
+ log determinant of covariance, shape (L,)
483
+
484
+ Notes
485
+ -----
486
+ shapes: L=batch_size, N=num_params, J=history_size, M=num_samples
487
+ """
488
+
489
+ N = x.shape[-1]
490
+ IdN = pt.eye(N)[None, ...]
491
+
492
+ # inverse Hessian
493
+ H_inv = (
494
+ sqrt_alpha_diag
495
+ @ (
496
+ IdN
497
+ + inv_sqrt_alpha_diag @ beta @ gamma @ pt.matrix_transpose(beta) @ inv_sqrt_alpha_diag
498
+ )
499
+ @ sqrt_alpha_diag
500
+ )
501
+
502
+ Lchol = pt.linalg.cholesky(H_inv, lower=False, check_finite=False, on_error="nan")
503
+
504
+ logdet = 2.0 * pt.sum(pt.log(pt.abs(pt.diagonal(Lchol, axis1=-2, axis2=-1))), axis=-1)
505
+
506
+ mu = x - pt.batched_dot(H_inv, g)
507
+
508
+ phi = pt.matrix_transpose(
509
+ # (L, N, 1)
510
+ mu[..., None]
511
+ # (L, N, M)
512
+ + Lchol @ pt.matrix_transpose(u)
513
+ ) # fmt: off
514
+
515
+ return phi, logdet
516
+
517
+
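In matrix notation, the dense branch above assembles the local Gaussian as

$$
\Sigma \;=\; \alpha^{1/2}\Bigl(I_N + \alpha^{-1/2}\,\beta\,\gamma\,\beta^{\top}\,\alpha^{-1/2}\Bigr)\,\alpha^{1/2},
\qquad
\mu \;=\; x - \Sigma g,
\qquad
\phi \;=\; \mu + \operatorname{chol}(\Sigma)\,u,
$$

with $\log\lvert\Sigma\rvert = 2\sum_i \log\lvert\operatorname{chol}(\Sigma)_{ii}\rvert$ read off the Cholesky diagonal, where $\alpha^{1/2}$ denotes $\operatorname{diag}(\sqrt{\alpha})$.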
518
+ def bfgs_sample_sparse(
519
+ x: TensorVariable,
520
+ g: TensorVariable,
521
+ alpha: TensorVariable,
522
+ beta: TensorVariable,
523
+ gamma: TensorVariable,
524
+ alpha_diag: TensorVariable,
525
+ inv_sqrt_alpha_diag: TensorVariable,
526
+ sqrt_alpha_diag: TensorVariable,
527
+ u: TensorVariable,
528
+ ) -> tuple[TensorVariable, TensorVariable]:
529
+ """sample from the BFGS approximation using sparse matrix operations.
530
+
531
+ Parameters
532
+ ----------
533
+ x : TensorVariable
534
+ position array, shape (L, N)
535
+ g : TensorVariable
536
+ gradient array, shape (L, N)
537
+ alpha : TensorVariable
538
+ diagonal scaling matrix, shape (L, N)
539
+ beta : TensorVariable
540
+ low-rank update matrix, shape (L, N, 2J)
541
+ gamma : TensorVariable
542
+ low-rank update matrix, shape (L, 2J, 2J)
543
+ alpha_diag : TensorVariable
544
+ diagonal matrix of alpha, shape (L, N, N)
545
+ inv_sqrt_alpha_diag : TensorVariable
546
+ inverse sqrt of alpha diagonal, shape (L, N, N)
547
+ sqrt_alpha_diag : TensorVariable
548
+ sqrt of alpha diagonal, shape (L, N, N)
549
+ u : TensorVariable
550
+ random normal samples, shape (L, M, N)
551
+
552
+ Returns
553
+ -------
554
+ phi : TensorVariable
555
+ samples from the approximation, shape (L, M, N)
556
+ logdet : TensorVariable
557
+ log determinant of covariance, shape (L,)
558
+
559
+ Notes
560
+ -----
561
+ shapes: L=batch_size, N=num_params, J=history_size, M=num_samples
562
+ """
563
+
564
+ # qr_input: (L, N, 2J)
565
+ qr_input = inv_sqrt_alpha_diag @ beta
566
+ (Q, R), _ = pytensor.scan(fn=pt.nlinalg.qr, sequences=[qr_input], allow_gc=False)
567
+ IdN = pt.eye(R.shape[1])[None, ...]
568
+ Lchol_input = IdN + R @ gamma @ pt.matrix_transpose(R)
569
+
570
+ Lchol = pt.linalg.cholesky(Lchol_input, lower=False, check_finite=False, on_error="nan")
571
+
572
+ logdet = 2.0 * pt.sum(pt.log(pt.abs(pt.diagonal(Lchol, axis1=-2, axis2=-1))), axis=-1)
573
+ logdet += pt.sum(pt.log(alpha), axis=-1)
574
+
575
+ # NOTE: the sign of the expression was changed from "x +" to "x -" to match Stan, which differs from Zhang et al. (2022). The same applies to the dense version.
576
+ mu = x - (
577
+ # (L, N), (L, N) -> (L, N)
578
+ pt.batched_dot(alpha_diag, g)
579
+ # beta @ gamma @ beta.T
580
+ # (L, N, 2J), (L, 2J, 2J), (L, 2J, N) -> (L, N, N)
581
+ # (L, N, N), (L, N) -> (L, N)
582
+ + pt.batched_dot((beta @ gamma @ pt.matrix_transpose(beta)), g)
583
+ )
584
+
585
+ phi = pt.matrix_transpose(
586
+ # (L, N, 1)
587
+ mu[..., None]
588
+ # (L, N, N), (L, N, M) -> (L, N, M)
589
+ + sqrt_alpha_diag
590
+ @ (
591
+ # (L, N, 2J), (L, 2J, M) -> (L, N, M)
592
+ # intermediate calcs below
593
+ # (L, N, 2J), (L, 2J, 2J) -> (L, N, 2J)
594
+ (Q @ (Lchol - IdN))
595
+ # (L, 2J, N), (L, N, M) -> (L, 2J, M)
596
+ @ (pt.matrix_transpose(Q) @ pt.matrix_transpose(u))
597
+ # (L, N, M)
598
+ + pt.matrix_transpose(u)
599
+ )
600
+ ) # fmt: off
601
+
602
+ return phi, logdet
603
+
604
+
605
+ def bfgs_sample(
606
+ num_samples: TensorConstant,
607
+ x: TensorVariable, # position
608
+ g: TensorVariable, # grad
609
+ alpha: TensorVariable,
610
+ beta: TensorVariable,
611
+ gamma: TensorVariable,
612
+ index: TensorVariable | None = None,
613
+ ) -> tuple[TensorVariable, TensorVariable]:
614
+ """sample from the BFGS approximation using the inverse hessian factors.
615
+
616
+ Parameters
617
+ ----------
618
+ num_samples : TensorConstant
619
+ number of samples to draw
620
+ x : TensorVariable
621
+ position array, shape (L, N)
622
+ g : TensorVariable
623
+ gradient array, shape (L, N)
624
+ alpha : TensorVariable
625
+ diagonal scaling matrix, shape (L, N)
626
+ beta : TensorVariable
627
+ low-rank update matrix, shape (L, N, 2J)
628
+ gamma : TensorVariable
629
+ low-rank update matrix, shape (L, 2J, 2J)
630
+ index : TensorVariable | None
631
+ optional index for selecting a single path
632
+
633
+ Returns
634
+ -------
635
+ if index is None:
636
+ phi: samples from local approximations over L (L, M, N)
637
+ logQ_phi: log density of samples of phi (L, M)
638
+ else:
639
+ psi: samples from local approximations where ELBO is maximized (1, M, N)
640
+ logQ_psi: log density of samples of psi (1, M)
641
+
642
+ Notes
643
+ -----
644
+ shapes: L=batch_size, N=num_params, J=history_size, M=num_samples
645
+ """
646
+
647
+ if index is not None:
648
+ x = x[index][None, ...]
649
+ g = g[index][None, ...]
650
+ alpha = alpha[index][None, ...]
651
+ beta = beta[index][None, ...]
652
+ gamma = gamma[index][None, ...]
653
+
654
+ L, N, JJ = beta.shape
655
+
656
+ (alpha_diag, inv_sqrt_alpha_diag, sqrt_alpha_diag), _ = pytensor.scan(
657
+ lambda a: [pt.diag(a), pt.diag(pt.sqrt(1.0 / a)), pt.diag(pt.sqrt(a))],
658
+ sequences=[alpha],
659
+ allow_gc=False,
660
+ )
661
+
662
+ u = pt.random.normal(size=(L, num_samples, N))
663
+
664
+ sample_inputs = (
665
+ x,
666
+ g,
667
+ alpha,
668
+ beta,
669
+ gamma,
670
+ alpha_diag,
671
+ inv_sqrt_alpha_diag,
672
+ sqrt_alpha_diag,
673
+ u,
674
+ )
675
+
676
+ phi, logdet = pytensor.ifelse(
677
+ JJ >= N,
678
+ bfgs_sample_dense(*sample_inputs),
679
+ bfgs_sample_sparse(*sample_inputs),
680
+ )
681
+
682
+ logQ_phi = -0.5 * (
683
+ logdet[..., None]
684
+ + pt.sum(u * u, axis=-1)
685
+ + N * pt.log(2.0 * pt.pi)
686
+ ) # fmt: off
687
+
688
+ mask = pt.isnan(logQ_phi) | pt.isinf(logQ_phi)
689
+ logQ_phi = pt.set_subtensor(logQ_phi[mask], pt.inf)
690
+ return phi, logQ_phi
691
+
692
+
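Putting the pieces together on the synthetic history (the dense branch is taken here because 2J = 4 >= N = 2):

```python
import pytensor

num_samples = pt.constant(5, dtype="int32")
phi, logq = bfgs_sample(num_samples, x_[1:], g_[1:], alpha, beta, gamma)
fn = pytensor.function([x_, g_], [phi, logq])
draws, logq_vals = fn(x_hist, g_hist)
print(draws.shape, logq_vals.shape)  # (L, M, N) = (2, 5, 2) and (L, M) = (2, 5)
```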
693
+ class LogLike(Op):
694
+ """
695
+ Op that computes the log density of the model for a batch of samples.
696
+ """
697
+
698
+ __props__ = ("logp_func",)
699
+
700
+ def __init__(self, logp_func: Callable):
701
+ self.logp_func = logp_func
702
+ super().__init__()
703
+
704
+ def make_node(self, inputs):
705
+ inputs = pt.as_tensor(inputs)
706
+ outputs = pt.tensor(dtype="float64", shape=(None, None))
707
+ return Apply(self, [inputs], [outputs])
708
+
709
+ def perform(self, node: Apply, inputs, outputs) -> None:
710
+ phi = inputs[0]
711
+ logP = np.apply_along_axis(self.logp_func, axis=-1, arr=phi)
712
+ # replace nan with -inf since np.argmax will return the first index at nan
713
+ mask = np.isnan(logP) | np.isinf(logP)
714
+ if np.all(mask):
715
+ raise PathInvalidLogP()
716
+ outputs[0][0] = np.where(mask, -np.inf, logP)
717
+
718
+
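A standalone sketch of the Op on a hypothetical standard-normal target:

```python
import numpy as np
import pytensor.tensor as pt

def std_normal_logp(x):
    return -0.5 * np.sum(x**2) - 0.5 * x.size * np.log(2.0 * np.pi)

phi = pt.tensor3("phi")                       # (L, M, N) batch of samples
logP = LogLike(std_normal_logp)(phi)          # (L, M) log densities
print(logP.eval({phi: np.zeros((2, 3, 4))}))  # each entry equals -2 * log(2 * pi)
```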
719
+ class PathStatus(Enum):
720
+ """
721
+ Statuses of a single-path pathfinder.
722
+ """
723
+
724
+ SUCCESS = auto()
725
+ ELBO_ARGMAX_AT_ZERO = auto()
726
+ # Statuses that lead to Exceptions:
727
+ INVALID_LOGP = auto()
728
+ INVALID_LOGQ = auto()
729
+ LBFGS_FAILED = auto()
730
+ PATH_FAILED = auto()
731
+
732
+
733
+ FAILED_PATH_STATUS = [
734
+ PathStatus.INVALID_LOGP,
735
+ PathStatus.INVALID_LOGQ,
736
+ PathStatus.LBFGS_FAILED,
737
+ PathStatus.PATH_FAILED,
738
+ ]
739
+
740
+
741
+ class PathException(Exception):
742
+ """
743
+ Raised when a path fails.
744
+ """
745
+
746
+ DEFAULT_MESSAGE = "Path failed."
747
+
748
+ def __init__(self, message=None, status: PathStatus = PathStatus.PATH_FAILED) -> None:
749
+ super().__init__(message or self.DEFAULT_MESSAGE)
750
+ self.status = status
751
+
752
+
753
+ class PathInvalidLogP(PathException):
754
+ """
755
+ Raised when none of the logP values in a path are finite.
756
+ """
757
+
758
+ DEFAULT_MESSAGE = "Path failed because all the logP values in a path are not finite."
759
+
760
+ def __init__(self, message=None) -> None:
761
+ super().__init__(message or self.DEFAULT_MESSAGE, PathStatus.INVALID_LOGP)
762
+
763
+
764
+ class PathInvalidLogQ(PathException):
765
+ """
766
+ Raised when none of the logQ values in a path are finite.
767
+ """
768
+
769
+ DEFAULT_MESSAGE = "Path failed because all the logQ values in a path are not finite."
770
+
771
+ def __init__(self, message=None) -> None:
772
+ super().__init__(message or self.DEFAULT_MESSAGE, PathStatus.INVALID_LOGQ)
773
+
774
+
775
+ def make_pathfinder_body(
776
+ logp_func: Callable,
777
+ num_draws: int,
778
+ maxcor: int,
779
+ num_elbo_draws: int,
780
+ epsilon: float,
781
+ **compile_kwargs: dict,
782
+ ) -> Function:
783
+ """
784
+ computes the inner components of the Pathfinder algorithm (post-LBFGS) using PyTensor variables and returns a compiled pytensor.function.
785
+
786
+ Parameters
787
+ ----------
788
+ logp_func : Callable
789
+ The target density function.
790
+ num_draws : int
791
+ Number of samples to draw from the single-path approximation.
792
+ maxcor : int
793
+ The maximum number of iterations for the L-BFGS algorithm.
794
+ num_elbo_draws : int
795
+ The number of draws for the Evidence Lower Bound (ELBO) estimation.
796
+ epsilon : float
797
+ The value used to filter out large changes in the direction of the update gradient at each iteration l in L. Iteration l is only accepted if delta_theta[l] * delta_grad[l] > epsilon * L2_norm(delta_grad[l]) for each l in L.
798
+ compile_kwargs : dict
799
+ Additional keyword arguments for the PyTensor compiler.
800
+
801
+ Returns
802
+ -------
803
+ pathfinder_body_fn : Function
804
+ A compiled pytensor.function that performs the inner components of the Pathfinder algorithm (post-LBFGS).
805
+
806
+ pathfinder_body_fn inputs:
807
+ x_full: (L+1, N),
808
+ g_full: (L+1, N)
809
+ pathfinder_body_fn outputs:
810
+ psi: (1, M, N),
811
+ logP_psi: (1, M),
812
+ logQ_psi: (1, M),
813
+ elbo_argmax: (1,)
814
+ """
815
+
816
+ # x_full, g_full: (L+1, N)
817
+ x_full = pt.matrix("x", dtype="float64")
818
+ g_full = pt.matrix("g", dtype="float64")
819
+
820
+ num_draws = pt.constant(num_draws, "num_draws", dtype="int32")
821
+ num_elbo_draws = pt.constant(num_elbo_draws, "num_elbo_draws", dtype="int32")
822
+ epsilon = pt.constant(epsilon, "epsilon", dtype="float64")
823
+ maxcor = pt.constant(maxcor, "maxcor", dtype="int32")
824
+
825
+ alpha, s, z, update_mask = alpha_recover(x_full, g_full, epsilon=epsilon)
826
+ beta, gamma = inverse_hessian_factors(alpha, s, z, update_mask, J=maxcor)
827
+
828
+ # ignore initial point - x, g: (L, N)
829
+ x = x_full[1:]
830
+ g = g_full[1:]
831
+
832
+ phi, logQ_phi = bfgs_sample(
833
+ num_samples=num_elbo_draws, x=x, g=g, alpha=alpha, beta=beta, gamma=gamma
834
+ )
835
+
836
+ loglike = LogLike(logp_func)
837
+ logP_phi = loglike(phi)
838
+ elbo = pt.mean(logP_phi - logQ_phi, axis=-1)
839
+ elbo_argmax = pt.argmax(elbo, axis=0)
840
+
841
+ # TODO: move the raise PathInvalidLogQ from single_pathfinder_fn to here to avoid computing logP_psi if logQ_psi is invalid. Possible setup: logQ_phi = PathCheck()(logQ_phi, ~pt.all(mask)), where PathCheck uses pytensor raise.
842
+
843
+ # sample from the single-path approximation
844
+ psi, logQ_psi = bfgs_sample(
845
+ num_samples=num_draws,
846
+ x=x,
847
+ g=g,
848
+ alpha=alpha,
849
+ beta=beta,
850
+ gamma=gamma,
851
+ index=elbo_argmax,
852
+ )
853
+ logP_psi = loglike(psi)
854
+
855
+ # return psi, logP_psi, logQ_psi, elbo_argmax
856
+
857
+ pathfinder_body_fn = compile_pymc(
858
+ [x_full, g_full],
859
+ [psi, logP_psi, logQ_psi, elbo_argmax],
860
+ **compile_kwargs,
861
+ )
862
+ pathfinder_body_fn.trust_input = True
863
+ return pathfinder_body_fn
864
+
865
+
866
+ def make_single_pathfinder_fn(
867
+ model,
868
+ num_draws: int,
869
+ maxcor: int | None,
870
+ maxiter: int,
871
+ ftol: float,
872
+ gtol: float,
873
+ maxls: int,
874
+ num_elbo_draws: int,
875
+ jitter: float,
876
+ epsilon: float,
877
+ pathfinder_kwargs: dict = {},
878
+ compile_kwargs: dict = {},
879
+ ) -> SinglePathfinderFn:
880
+ """
881
+ returns a seedable single-path pathfinder function that executes a compiled function performing the local approximation and sampling parts of the Pathfinder algorithm.
882
+
883
+ Parameters
884
+ ----------
885
+ model : pymc.Model
886
+ The PyMC model to fit the Pathfinder algorithm to.
887
+ num_draws : int
888
+ Number of samples to draw from the single-path approximation.
889
+ maxcor : int | None
890
+ Maximum number of variable metric corrections used to define the limited memory matrix (history size) for the L-BFGS optimisation.
891
+ maxiter : int
892
+ Maximum number of iterations for the L-BFGS optimisation.
893
+ ftol : float
894
+ Tolerance for the decrease in the objective function.
895
+ gtol : float
896
+ Tolerance for the norm of the gradient.
897
+ maxls : int
898
+ Maximum number of line search steps for the L-BFGS algorithm.
899
+ num_elbo_draws : int
900
+ Number of draws for the Evidence Lower Bound (ELBO) estimation.
901
+ jitter : float
902
+ Amount of jitter to apply to initial points. Note that Pathfinder may be highly sensitive to the jitter value. It is recommended to increase num_paths when increasing the jitter value.
903
+ epsilon : float
904
+ value used to filter out large changes in the direction of the update gradient at each iteration l in L. Iteration l is only accepted if delta_theta[l] * delta_grad[l] > epsilon * L2_norm(delta_grad[l]) for each l in L.
905
+ pathfinder_kwargs : dict
906
+ Additional keyword arguments for the Pathfinder algorithm.
907
+ compile_kwargs : dict
908
+ Additional keyword arguments for the PyTensor compiler. If not provided, the default linker is "cvm_nogc".
909
+
910
+ Returns
911
+ -------
912
+ single_pathfinder_fn : Callable
913
+ A seedable single-path pathfinder function.
914
+ """
915
+
916
+ compile_kwargs = {"mode": Mode(linker=DEFAULT_LINKER), **compile_kwargs}
917
+ logp_dlogp_kwargs = {"jacobian": pathfinder_kwargs.get("jacobian", True), **compile_kwargs}
918
+
919
+ logp_dlogp_func = get_logp_dlogp_of_ravel_inputs(model, **logp_dlogp_kwargs)
920
+
921
+ def logp_func(x):
922
+ logp, _ = logp_dlogp_func(x)
923
+ return logp
924
+
925
+ def neg_logp_dlogp_func(x):
926
+ logp, dlogp = logp_dlogp_func(x)
927
+ return -logp, -dlogp
928
+
929
+ # initial point
930
+ # TODO: remove make_initial_points function when feature request is implemented: https://github.com/pymc-devs/pymc/issues/7555
931
+ ipfn = make_initial_point_fn(model=model)
932
+ ip = Point(ipfn(None), model=model)
933
+ x_base = DictToArrayBijection.map(ip).data
934
+
935
+ # lbfgs
936
+ lbfgs = LBFGS(neg_logp_dlogp_func, maxcor, maxiter, ftol, gtol, maxls)
937
+
938
+ # pathfinder body
939
+ pathfinder_body_fn = make_pathfinder_body(
940
+ logp_func, num_draws, maxcor, num_elbo_draws, epsilon, **compile_kwargs
941
+ )
942
+ rngs = find_rng_nodes(pathfinder_body_fn.maker.fgraph.outputs)
943
+
944
+ def single_pathfinder_fn(random_seed: int) -> PathfinderResult:
945
+ try:
946
+ init_seed, *bfgs_seeds = _get_seeds_per_chain(random_seed, 3)
947
+ rng = np.random.default_rng(init_seed)
948
+ jitter_value = rng.uniform(-jitter, jitter, size=x_base.shape)
949
+ x0 = x_base + jitter_value
950
+ x, g, lbfgs_niter, lbfgs_status = lbfgs.minimize(x0)
951
+
952
+ if lbfgs_status == LBFGSStatus.INIT_FAILED:
953
+ raise LBFGSInitFailed()
954
+ elif lbfgs_status == LBFGSStatus.LBFGS_FAILED:
955
+ raise LBFGSException()
956
+
957
+ reseed_rngs(rngs, bfgs_seeds)
958
+ psi, logP_psi, logQ_psi, elbo_argmax = pathfinder_body_fn(x, g)
959
+
960
+ if np.all(~np.isfinite(logQ_psi)):
961
+ raise PathInvalidLogQ()
962
+
963
+ if elbo_argmax == 0:
964
+ path_status = PathStatus.ELBO_ARGMAX_AT_ZERO
965
+ else:
966
+ path_status = PathStatus.SUCCESS
967
+
968
+ return PathfinderResult(
969
+ samples=psi,
970
+ logP=logP_psi,
971
+ logQ=logQ_psi,
972
+ lbfgs_niter=lbfgs_niter,
973
+ elbo_argmax=elbo_argmax,
974
+ lbfgs_status=lbfgs_status,
975
+ path_status=path_status,
976
+ )
977
+ except LBFGSException as e:
978
+ return PathfinderResult(
979
+ lbfgs_status=e.status,
980
+ path_status=PathStatus.LBFGS_FAILED,
981
+ )
982
+ except PathException as e:
983
+ return PathfinderResult(
984
+ lbfgs_status=lbfgs_status,
985
+ path_status=e.status,
986
+ )
987
+
988
+ return single_pathfinder_fn
989
+
990
+
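A hypothetical end-to-end sketch for one path, using the illustrative model `m` from earlier (argument values are illustrative, not prescriptive):

```python
path_fn = make_single_pathfinder_fn(
    m,
    num_draws=200,
    maxcor=5,
    maxiter=100,
    ftol=1e-5,
    gtol=1e-8,
    maxls=100,
    num_elbo_draws=10,
    jitter=2.0,
    epsilon=1e-8,
)
result = path_fn(123)  # seedable: the same seed reproduces the same path
print(result.path_status, result.lbfgs_status)
```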
991
+ def _calculate_max_workers() -> int:
992
+ """
993
+ calculate the default number of workers to use for concurrent pathfinder runs.
994
+ """
995
+
996
+ # from limited testing, setting values higher than 0.3 makes multiprocessing a lot slower.
997
+ import multiprocessing
998
+
999
+ total_cpus = multiprocessing.cpu_count() or 1
1000
+ processes = max(2, int(total_cpus * 0.3))
1001
+ if processes % 2 != 0:
1002
+ processes += 1
1003
+ return processes
1004
+
1005
+
1006
+ def _thread(fn: SinglePathfinderFn, seed: int) -> "PathfinderResult":
1007
+ """
1008
+ execute pathfinder runs concurrently using threading.
1009
+ """
1010
+
1011
+ # kernel crashes without lock_ctx
1012
+ from pytensor.compile.compilelock import lock_ctx
1013
+
1014
+ with lock_ctx():
1015
+ rng = np.random.default_rng(seed)
1016
+ result = fn(rng)
1017
+ return result
1018
+
1019
+
1020
+ def _process(fn: SinglePathfinderFn, seed: int) -> "PathfinderResult | bytes":
1021
+ """
1022
+ execute pathfinder runs concurrently using multiprocessing.
1023
+ """
1024
+ import cloudpickle
1025
+
1026
+ from pytensor.compile.compilelock import lock_ctx
1027
+
1028
+ with lock_ctx():
1029
+ in_out_pickled = isinstance(fn, bytes)
1030
+ fn = cloudpickle.loads(fn) if in_out_pickled else fn
1031
+ rng = np.random.default_rng(seed)
1032
+ result = fn(rng) if not in_out_pickled else cloudpickle.dumps(fn(rng))
1033
+ return result
1034
+
1035
+
1036
+ def _get_mp_context(mp_ctx: str | None = None) -> str | None:
1037
+ """code snippet taken from ParallelSampler in pymc/pymc/sampling/parallel.py"""
1038
+ import multiprocessing
1039
+ import platform
1040
+
1041
+ if mp_ctx is None or isinstance(mp_ctx, str):
1042
+ if mp_ctx is None and platform.system() == "Darwin":
1043
+ if platform.processor() == "arm":
1044
+ mp_ctx = "fork"
1045
+ logger.debug(
1046
+ "mp_ctx is set to 'fork' for MacOS with ARM architecture. "
1047
+ + "This might cause unexpected behavior with JAX, which is inherently multithreaded."
1048
+ )
1049
+ else:
1050
+ mp_ctx = "forkserver"
1051
+
1052
+ mp_ctx = multiprocessing.get_context(mp_ctx)
1053
+ return mp_ctx
1054
+
1055
+
1056
+ def _execute_concurrently(
1057
+ fn: SinglePathfinderFn,
1058
+ seeds: list[int],
1059
+ concurrent: Literal["thread", "process"] | None,
1060
+ max_workers: int | None = None,
1061
+ ) -> Iterator["PathfinderResult | bytes"]:
1062
+ """
1063
+ execute pathfinder runs concurrently.
1064
+ """
1065
+ if concurrent == "thread":
1066
+ from concurrent.futures import ThreadPoolExecutor, as_completed
1067
+ elif concurrent == "process":
1068
+ from concurrent.futures import ProcessPoolExecutor, as_completed
1069
+
1070
+ import cloudpickle
1071
+ else:
1072
+ raise ValueError(f"Invalid concurrent value: {concurrent}")
1073
+
1074
+ executor_cls = ThreadPoolExecutor if concurrent == "thread" else ProcessPoolExecutor
1075
+
1076
+ concurrent_fn = _thread if concurrent == "thread" else _process
1077
+
1078
+ executor_kwargs = {} if concurrent == "thread" else {"mp_context": _get_mp_context()}
1079
+
1080
+ max_workers = max_workers or (None if concurrent == "thread" else _calculate_max_workers())
1081
+
1082
+ fn = fn if concurrent == "thread" else cloudpickle.dumps(fn)
1083
+
1084
+ with executor_cls(max_workers=max_workers, **executor_kwargs) as executor:
1085
+ futures = [executor.submit(concurrent_fn, fn, seed) for seed in seeds]
1086
+ for f in as_completed(futures):
1087
+ yield (f.result() if concurrent == "thread" else cloudpickle.loads(f.result()))
1088
+
1089
+
1090
+ def _execute_serially(fn: SinglePathfinderFn, seeds: list[int]) -> Iterator["PathfinderResult"]:
1091
+ """
1092
+ execute pathfinder runs serially.
1093
+ """
1094
+ for seed in seeds:
1095
+ rng = np.random.default_rng(seed)
1096
+ yield fn(rng)
1097
+
1098
+
1099
+ def make_generator(
1100
+ concurrent: Literal["thread", "process"] | None,
1101
+ fn: SinglePathfinderFn,
1102
+ seeds: list[int],
1103
+ max_workers: int | None = None,
1104
+ ) -> Iterator["PathfinderResult | bytes"]:
1105
+ """
1106
+ generator for executing pathfinder runs concurrently or serially.
1107
+ """
1108
+ if concurrent is not None:
1109
+ yield from _execute_concurrently(fn, seeds, concurrent, max_workers)
1110
+ else:
1111
+ yield from _execute_serially(fn, seeds)
1112
+
1113
+
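Paths can then be consumed through the generator, serially or concurrently (`path_fn` as in the earlier sketch):

```python
results = list(make_generator(concurrent=None, fn=path_fn, seeds=[1, 2, 3, 4]))
# or: make_generator(concurrent="thread", fn=path_fn, seeds=[1, 2, 3, 4])
print(sum(r.path_status == PathStatus.SUCCESS for r in results))
```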
1114
+ @dataclass(slots=True, frozen=True)
1115
+ class PathfinderResult:
1116
+ """
1117
+ container for storing results from a single pathfinder run.
1118
+
1119
+ Attributes
1120
+ ----------
1121
+ samples: posterior samples (1, M, N)
1122
+ logP: log probability of model (1, M)
1123
+ logQ: log probability of approximation (1, M)
1124
+ lbfgs_niter: number of lbfgs iterations (1,)
1125
+ elbo_argmax: elbo values at convergence (1,)
1126
+ lbfgs_status: LBFGS status
1127
+ path_status: path status
1128
+
1129
+ where:
1130
+ M: number of samples
1131
+ N: number of parameters
1132
+ """
1133
+
1134
+ samples: NDArray | None = None
1135
+ logP: NDArray | None = None
1136
+ logQ: NDArray | None = None
1137
+ lbfgs_niter: NDArray | None = None
1138
+ elbo_argmax: NDArray | None = None
1139
+ lbfgs_status: LBFGSStatus = LBFGSStatus.LBFGS_FAILED
1140
+ path_status: PathStatus = PathStatus.PATH_FAILED
1141
+
1142
+
1143
+ @dataclass(frozen=True)
1144
+ class PathfinderConfig:
1145
+ """configuration parameters for a single pathfinder"""
1146
+
1147
+ num_draws: int # same as num_draws_per_path
1148
+ maxcor: int
1149
+ maxiter: int
1150
+ ftol: float
1151
+ gtol: float
1152
+ maxls: int
1153
+ jitter: float
1154
+ epsilon: float
1155
+ num_elbo_draws: int
1156
+
1157
+
1158
+ @dataclass(slots=True, frozen=True)
1159
+ class MultiPathfinderResult:
1160
+ """
1161
+ container for aggregating results from multiple paths.
1162
+
1163
+ Attributes
1164
+ ----------
1165
+ samples: posterior samples (S, M, N)
1166
+ logP: log probability of model (S, M)
1167
+ logQ: log probability of approximation (S, M)
1168
+ lbfgs_niter: number of lbfgs iterations (S,)
1169
+ elbo_argmax: elbo values at convergence (S,)
1170
+ lbfgs_status: counter for LBFGS status occurrences
1171
+ path_status: counter for path status occurrences
1172
+ importance_sampling: importance sampling method used
1173
+ warnings: list of warnings
1174
+ pareto_k: Pareto shape estimate from importance sampling
1176
+ pathfinder_config: pathfinder configuration
1177
+ compile_time: function compile time in seconds
1178
+ compute_time: pathfinder compute time in seconds
1178
+ where:
1179
+ S: number of successful paths, where S <= num_paths
1180
+ M: number of samples per path
1181
+ N: number of parameters
1182
+ """
1183
+
1184
+ samples: NDArray | None = None
1185
+ logP: NDArray | None = None
1186
+ logQ: NDArray | None = None
1187
+ lbfgs_niter: NDArray | None = None
1188
+ elbo_argmax: NDArray | None = None
1189
+ lbfgs_status: Counter = field(default_factory=Counter)
1190
+ path_status: Counter = field(default_factory=Counter)
1191
+ importance_sampling: str = "none"
1192
+ warnings: list[str] = field(default_factory=list)
1193
+ pareto_k: float | None = None
1194
+
1195
+ # config
1196
+ num_paths: int | None = None
1197
+ num_draws: int | None = None
1198
+ pathfinder_config: PathfinderConfig | None = None
1199
+
1200
+ # timing
1201
+ compile_time: float | None = None
1202
+ compute_time: float | None = None
1203
+
1204
+ all_paths_failed: bool = False # raises ValueError if all paths failed
1205
+
1206
+ @classmethod
1207
+ def from_path_results(cls, path_results: list[PathfinderResult]) -> "MultiPathfinderResult":
1208
+ """aggregate successful pathfinder results and count the occurrences of each status in PathStatus and LBFGSStatus"""
1209
+
1210
+ NUMERIC_ATTRIBUTES = ["samples", "logP", "logQ", "lbfgs_niter", "elbo_argmax"]
1211
+
1212
+ success_results = []
1213
+ mpr = cls()
1214
+
1215
+ for pr in path_results:
1216
+ if pr.path_status not in FAILED_PATH_STATUS:
1217
+ success_results.append(tuple(getattr(pr, attr) for attr in NUMERIC_ATTRIBUTES))
1218
+
1219
+ mpr.lbfgs_status[pr.lbfgs_status] += 1
1220
+ mpr.path_status[pr.path_status] += 1
1221
+
1222
+ # if not success_results:
1223
+ # raise ValueError(
1224
+ # "All paths failed. Consider decreasing the jitter or reparameterizing the model."
1225
+ # )
1226
+
1227
+ warnings = _get_status_warning(mpr)
1228
+
1229
+ if success_results:
1230
+ results_arr = [np.asarray(x) for x in zip(*success_results)]
1231
+ return cls(
1232
+ *[np.concatenate(x) if x.ndim > 1 else x for x in results_arr],
1233
+ lbfgs_status=mpr.lbfgs_status,
1234
+ path_status=mpr.path_status,
1235
+ warnings=warnings,
1236
+ )
1237
+ else:
1238
+ return cls(
1239
+ lbfgs_status=mpr.lbfgs_status,
1240
+ path_status=mpr.path_status,
1241
+ warnings=warnings,
1242
+ all_paths_failed=True, # raises ValueError later
1243
+ )
1244
+
1245
+ def with_timing(self, compile_time: float, compute_time: float) -> Self:
1246
+ """add timing information"""
1247
+ return replace(self, compile_time=compile_time, compute_time=compute_time)
1248
+
1249
+ def with_pathfinder_config(self, config: PathfinderConfig) -> Self:
1250
+ """add pathfinder configuration"""
1251
+ return replace(self, pathfinder_config=config)
1252
+
1253
+ def with_warnings(self, warnings: list[str]) -> Self:
1254
+ """add warnings"""
1255
+ return replace(self, warnings=warnings)
1256
+
1257
+ def with_importance_sampling(
1258
+ self,
1259
+ num_draws: int,
1260
+ method: Literal["psis", "psir", "identity", "none"] | None,
1261
+ random_seed: int | None = None,
1262
+ ) -> Self:
1263
+ """perform importance sampling"""
1264
+ if not self.all_paths_failed:
1265
+ isres = _importance_sampling(
1266
+ samples=self.samples,
1267
+ logP=self.logP,
1268
+ logQ=self.logQ,
1269
+ num_draws=num_draws,
1270
+ method=method,
1271
+ random_seed=random_seed,
1272
+ )
1273
+ return replace(
1274
+ self,
1275
+ samples=isres.samples,
1276
+ importance_sampling=method,
1277
+ warnings=[*self.warnings, *isres.warnings],
1278
+ pareto_k=isres.pareto_k,
1279
+ )
1280
+ else:
1281
+ return self
1282
+
1283
+ def create_summary(self) -> Table:
1284
+ """create rich table summary of pathfinder results"""
1285
+ table = Table(
1286
+ title="Pathfinder Results",
1287
+ title_style="none",
1288
+ title_justify="left",
1289
+ show_header=False,
1290
+ box=None,
1291
+ padding=(0, 2),
1292
+ show_edge=False,
1293
+ )
1294
+ table.add_column("Description")
1295
+ table.add_column("Value")
1296
+
1297
+ # model info
1298
+ if self.samples is not None:
1299
+ table.add_row("")
1300
+ table.add_row("No. model parameters", str(self.samples.shape[-1]))
1301
+
1302
+ # config
1303
+ if self.pathfinder_config is not None:
1304
+ table.add_row("")
1305
+ table.add_row("Configuration:")
1306
+ table.add_row("num_draws_per_path", str(self.pathfinder_config.num_draws))
1307
+ table.add_row("history size (maxcor)", str(self.pathfinder_config.maxcor))
1308
+ table.add_row("max iterations", str(self.pathfinder_config.maxiter))
1309
+ table.add_row("ftol", f"{self.pathfinder_config.ftol:.2e}")
1310
+ table.add_row("gtol", f"{self.pathfinder_config.gtol:.2e}")
1311
+ table.add_row("max line search", str(self.pathfinder_config.maxls))
1312
+ table.add_row("jitter", f"{self.pathfinder_config.jitter}")
1313
+ table.add_row("epsilon", f"{self.pathfinder_config.epsilon:.2e}")
1314
+ table.add_row("ELBO draws", str(self.pathfinder_config.num_elbo_draws))
1315
+
1316
+ # lbfgs
1317
+ table.add_row("")
1318
+ table.add_row("LBFGS Status:")
1319
+ for status, count in self.lbfgs_status.items():
1320
+ table.add_row(str(status.name), str(count))
1321
+
1322
+ if self.lbfgs_niter is not None:
1323
+ table.add_row(
1324
+ "L-BFGS iterations",
1325
+ f"mean {np.mean(self.lbfgs_niter):.0f} ± std {np.std(self.lbfgs_niter):.0f}",
1326
+ )
1327
+
1328
+ # paths
1329
+ table.add_row("")
1330
+ table.add_row("Path Status:")
1331
+ for status, count in self.path_status.items():
1332
+ table.add_row(str(status.name), str(count))
1333
+
1334
+ if self.elbo_argmax is not None:
1335
+ table.add_row(
1336
+ "ELBO argmax",
1337
+ f"mean {np.mean(self.elbo_argmax):.0f} ± std {np.std(self.elbo_argmax):.0f}",
1338
+ )
1339
+
1340
+ # importance sampling section
1341
+ if not self.all_paths_failed:
1342
+ table.add_row("")
1343
+ table.add_row("Importance Sampling:")
1344
+ table.add_row("Method", self.importance_sampling)
1345
+ if self.pareto_k is not None:
1346
+ table.add_row("Pareto k", f"{self.pareto_k:.2f}")
1347
+
1348
+ if self.compile_time is not None:
1349
+ table.add_row("")
1350
+ table.add_row("Timing (seconds):")
1351
+ table.add_row("Compile", f"{self.compile_time:.2f}")
1352
+
1353
+ if self.compute_time is not None:
1354
+ table.add_row("Compute", f"{self.compute_time:.2f}")
1355
+
1356
+ if self.compile_time is not None and self.compute_time is not None:
1357
+ table.add_row("Total", f"{self.compile_time + self.compute_time:.2f}")
1358
+
1359
+ return table
1360
+
1361
+ def display_summary(self) -> None:
1362
+ """display summary including warnings"""
1363
+ console = Console()
1364
+ summary = self.create_summary()
1365
+
1366
+ # warning messages
1367
+ if self.warnings:
1368
+ warning_text = [
1369
+ Text(), # blank line
1370
+ Text("Warnings:"),
1371
+ *(
1372
+ Padding(
1373
+ Text("- " + warning, no_wrap=False).wrap(console, width=console.width - 6),
1374
+ (0, 0, 0, 2), # left padding only
1375
+ )
1376
+ for warning in self.warnings
1377
+ ),
1378
+ ]
1379
+ output = Group(summary, *warning_text)
1380
+ else:
1381
+ output = summary
1382
+
1383
+ console.print(output)
1384
+
1385
+
1386
+ def _get_status_warning(mpr: MultiPathfinderResult) -> list[str]:
1387
+ """get list of relevant LBFGSStatus and PathStatus warnings given a MultiPathfinderResult"""
1388
+ warnings = []
1389
+
1390
+ lbfgs_status_message = {
1391
+ LBFGSStatus.MAX_ITER_REACHED: "LBFGS maximum number of iterations reached. Consider increasing maxiter if this occurrence is high relative to the number of paths.",
1392
+ LBFGSStatus.INIT_FAILED: "LBFGS failed to initialise. Consider reparameterizing the model or reducing jitter if this occurrence is high relative to the number of paths.",
1393
+ LBFGSStatus.DIVERGED: "LBFGS diverged to infinity. Consider reparameterizing the model or adjusting the pathfinder arguments if this occurrence is high relative to the number of paths.",
1394
+ }
1395
+
1396
+ path_status_message = {
1397
+ PathStatus.ELBO_ARGMAX_AT_ZERO: "ELBO argmax at zero refers to the first iteration during LBFGS. A high occurrence suggests that the model's default initial point plus jitter may be too close to the posterior mean, giving poor exploration of the parameter space. Consider increasing jitter if this occurrence is high relative to the number of paths.",
1398
+ PathStatus.INVALID_LOGP: "Invalid logP values occur when a path's logP values are not finite. The failed path is not included in samples when importance sampling is used. Consider reparameterizing the model or adjusting the pathfinder arguments if this occurrence is high relative to the number of paths.",
1399
+ PathStatus.INVALID_LOGQ: "Invalid logQ values occur when a path's logQ values are not finite. The failed path is not included in samples when importance sampling is used. Consider reparameterizing the model or adjusting the pathfinder arguments if this occurrence is high relative to the number of paths.",
1400
+ }
1401
+
1402
+ for lbfgs_status in mpr.lbfgs_status:
1403
+ if lbfgs_status in lbfgs_status_message:
1404
+ warnings.append(lbfgs_status_message.get(lbfgs_status))
1405
+
1406
+ for path_status in mpr.path_status:
1407
+ if path_status in path_status_message:
1408
+ warnings.append(path_status_message.get(path_status))
1409
+
1410
+ return warnings
1411
+
1412
+
1413
+ def multipath_pathfinder(
1414
+ model: Model,
1415
+ num_paths: int,
1416
+ num_draws: int,
1417
+ num_draws_per_path: int,
1418
+ maxcor: int,
1419
+ maxiter: int,
1420
+ ftol: float,
1421
+ gtol: float,
1422
+ maxls: int,
1423
+ num_elbo_draws: int,
1424
+ jitter: float,
1425
+ epsilon: float,
1426
+ importance_sampling: Literal["psis", "psir", "identity", "none"] | None,
1427
+ progressbar: bool,
1428
+ concurrent: Literal["thread", "process"] | None,
1429
+ random_seed: RandomSeed,
1430
+ pathfinder_kwargs: dict = {},
1431
+ compile_kwargs: dict = {},
1432
+ ) -> MultiPathfinderResult:
1433
+ """
1434
+ Fit the Pathfinder Variational Inference algorithm using multiple paths with PyMC/PyTensor backend.
1435
+
1436
+ Parameters
1437
+ ----------
1438
+ model : pymc.Model
1439
+ The PyMC model to fit the Pathfinder algorithm to.
1440
+ num_paths : int
1441
+ Number of independent paths to run in the Pathfinder algorithm (default is 4). It is recommended to increase num_paths when increasing the jitter value.
1442
+ num_draws : int, optional
1443
+ Total number of samples to draw from the fitted approximation (default is 1000).
1444
+ num_draws_per_path : int, optional
1445
+ Number of samples to draw per path (default is 1000).
1446
+ maxcor : int, optional
1447
+ Maximum number of variable metric corrections used to define the limited memory matrix (default is None). If None, maxcor is set to ceil(3 * log(N)) or 5 whichever is greater, where N is the number of model parameters.
1448
+ maxiter : int, optional
1449
+ Maximum number of iterations for the L-BFGS optimisation (default is 1000).
1450
+ ftol : float, optional
1451
+ Tolerance for the decrease in the objective function (default is 1e-5).
1452
+ gtol : float, optional
1453
+ Tolerance for the norm of the gradient (default is 1e-8).
1454
+ maxls : int, optional
1455
+ Maximum number of line search steps for the L-BFGS algorithm (default is 1000).
1456
+ num_elbo_draws : int, optional
1457
+ Number of draws for the Evidence Lower Bound (ELBO) estimation (default is 10).
1458
+ jitter : float, optional
1459
+ Amount of jitter to apply to initial points (default is 2.0). Note that Pathfinder may be highly sensitive to the jitter value. It is recommended to increase num_paths when increasing the jitter value.
1460
+ epsilon: float
1461
+ value used to filter out large changes in the direction of the update gradient at each iteration l in L. Iteration l is only accepted if delta_theta[l] * delta_grad[l] > epsilon * L2_norm(delta_grad[l]) for each l in L. (default is 1e-8).
1462
+ importance_sampling : str, optional
1463
+ Importance sampling method to use, which applies sampling based on the log importance weights (logP - logQ). Options are "psis" (default), "psir", "identity", and "none". Pareto Smoothed Importance Sampling ("psis") is recommended in many cases as it gives more stable results than Pareto Smoothed Importance Resampling ("psir"). "identity" applies the log importance weights directly without resampling. "none" applies no importance sampling weights and returns the samples as is, with shape (num_paths, num_draws_per_path, N) where N is the number of model parameters; otherwise the sample shape is (num_draws, N).
1464
+ progressbar : bool, optional
1465
+ Whether to display a progress bar (default is False). Setting this to True will likely increase the computation time.
1466
+ random_seed : RandomSeed, optional
1467
+ Random seed for reproducibility.
1468
+ postprocessing_backend : str, optional
1469
+ Backend for postprocessing transformations, either "cpu" or "gpu" (default is "cpu"). This is only relevant if inference_backend is "blackjax".
1470
+ inference_backend : str, optional
1471
+ Backend for inference, either "pymc" or "blackjax" (default is "pymc").
1472
+ concurrent : str, optional
1473
+ Whether to run paths concurrently, either "thread" or "process" or None (default is None). Setting concurrent to None runs paths serially and is generally faster with smaller models because of the overhead that comes with concurrency.
1474
+ pathfinder_kwargs
1475
+ Additional keyword arguments for the Pathfinder algorithm.
1476
+ compile_kwargs
1477
+ Additional keyword arguments for the PyTensor compiler. If not provided, the default linker is "cvm_nogc".
1478
+
1479
+ Returns
1480
+ -------
1481
+ MultiPathfinderResult
1482
+ The result containing samples and other information from the Multi-Path Pathfinder algorithm.
1483
+ """
1484
+
1485
+ valid_importance_sampling = ["psis", "psir", "identity", "none", None]
1486
+ if importance_sampling is None:
1487
+ importance_sampling = "none"
1488
+ if importance_sampling.lower() not in valid_importance_sampling:
1489
+ raise ValueError(f"Invalid importance sampling method: {importance_sampling}")
1490
+
1491
+ *path_seeds, choice_seed = _get_seeds_per_chain(random_seed, num_paths + 1)
1492
+
1493
+ pathfinder_config = PathfinderConfig(
1494
+ num_draws=num_draws_per_path,
1495
+ maxcor=maxcor,
1496
+ maxiter=maxiter,
1497
+ ftol=ftol,
1498
+ gtol=gtol,
1499
+ maxls=maxls,
1500
+ num_elbo_draws=num_elbo_draws,
1501
+ jitter=jitter,
1502
+ epsilon=epsilon,
1503
+ )
1504
+
1505
+ compile_start = time.time()
1506
+ single_pathfinder_fn = make_single_pathfinder_fn(
1507
+ model,
1508
+ **asdict(pathfinder_config),
1509
+ pathfinder_kwargs=pathfinder_kwargs,
1510
+ compile_kwargs=compile_kwargs,
1511
+ )
1512
+ compile_end = time.time()
1513
+
1514
+ # NOTE: from limited tests, no concurrency is faster than thread, and thread is faster than process. But I suspect this also depends on the model size and maxcor setting.
1515
+ generator = make_generator(
1516
+ concurrent=concurrent,
1517
+ fn=single_pathfinder_fn,
1518
+ seeds=path_seeds,
1519
+ )
1520
+
1521
+ results = []
1522
+ compute_start = time.time()
1523
+ try:
1524
+ with CustomProgress(
1525
+ console=Console(theme=default_progress_theme),
1526
+ disable=not progressbar,
1527
+ ) as progress:
1528
+ task = progress.add_task("Fitting", total=num_paths)
1529
+ for result in generator:
1530
+ try:
1531
+ if isinstance(result, Exception):
1532
+ raise result
1533
+ else:
1534
+ results.append(result)
1535
+ except filelock.Timeout:
1536
+ logger.warning("Lock timeout. Retrying...")
1537
+ num_attempts = 0
1538
+ while num_attempts < 10:
1539
+ try:
1540
+ results.append(result)
1541
+ logger.info("Lock acquired. Continuing...")
1542
+ break
1543
+ except filelock.Timeout:
1544
+ num_attempts += 1
1545
+ time.sleep(0.5)
1546
+ logger.warning(f"Lock timeout. Retrying... ({num_attempts}/10)")
1547
+ except Exception as e:
1548
+ logger.warning("Unexpected error in a path: %s", str(e))
1549
+ results.append(
1550
+ PathfinderResult(
1551
+ path_status=PathStatus.PATH_FAILED,
1552
+ lbfgs_status=LBFGSStatus.LBFGS_FAILED,
1553
+ )
1554
+ )
1555
+ progress.update(task, advance=1)
1556
+ except (KeyboardInterrupt, StopIteration) as e:
1557
+ # if an exception is raised here, MultiPathfinderResult collects and reports all the successful results. The user is free to abort the process early; the results are still collected and az.InferenceData is returned.
1558
+ if isinstance(e, StopIteration):
1559
+ logger.info(str(e))
1560
+ finally:
1561
+ compute_end = time.time()
1562
+ if results:
1563
+ mpr = (
1564
+ MultiPathfinderResult.from_path_results(results)
1565
+ .with_pathfinder_config(config=pathfinder_config)
1566
+ .with_importance_sampling(
1567
+ num_draws=num_draws, method=importance_sampling, random_seed=choice_seed
1568
+ )
1569
+ .with_timing(
1570
+ compile_time=compile_end - compile_start,
1571
+ compute_time=compute_end - compute_start,
1572
+ )
1573
+ )
1574
+ # TODO: option to disable summary, save to file, etc.
1575
+ mpr.display_summary()
1576
+ if mpr.all_paths_failed:
1577
+ raise ValueError(
1578
+ "All paths failed. Consider decreasing the jitter or reparameterizing the model."
1579
+ )
1580
+ else:
1581
+ raise ValueError(
1582
+ "BUG: Failed to iterate!"
1583
+ "Please report this issue at: "
1584
+ "https://github.com/pymc-devs/pymc-extras/issues "
1585
+ "with your code to reproduce the issue and the following details:\n"
1586
+ f"pathfinder_config: \n{pathfinder_config}\n"
1587
+ f"compile_kwargs: {compile_kwargs}\n"
1588
+ f"pathfinder_kwargs: {pathfinder_kwargs}\n"
1589
+ f"num_paths: {num_paths}\n"
1590
+ f"num_draws: {num_draws}\n"
1591
+ )
1592
+
1593
+ return mpr
1594
+
1595
+
1596
+ def fit_pathfinder(
1597
+ model=None,
1598
+ num_paths: int = 4, # I
1599
+ num_draws: int = 1000, # R
1600
+ num_draws_per_path: int = 1000, # M
1601
+ maxcor: int | None = None, # J
1602
+ maxiter: int = 1000, # L^max
1603
+ ftol: float = 1e-5,
1604
+ gtol: float = 1e-8,
1605
+ maxls: int = 1000,
1606
+ num_elbo_draws: int = 10, # K
1607
+ jitter: float = 2.0,
1608
+ epsilon: float = 1e-8,
1609
+ importance_sampling: Literal["psis", "psir", "identity", "none"] = "psis",
1610
+ progressbar: bool = True,
1611
+ concurrent: Literal["thread", "process"] | None = None,
1612
+ random_seed: RandomSeed | None = None,
1613
+ postprocessing_backend: Literal["cpu", "gpu"] = "cpu",
1614
+ inference_backend: Literal["pymc", "blackjax"] = "pymc",
1615
+ pathfinder_kwargs: dict = {},
1616
+ compile_kwargs: dict = {},
1617
+ ) -> az.InferenceData:
1618
+ """
1619
+ Fit the Pathfinder Variational Inference algorithm.
1620
+
1621
+ This function fits the Pathfinder algorithm to a given PyMC model, allowing for multiple paths and draws. It supports both PyMC and BlackJAX backends.
1622
+
1623
+ Parameters
1624
+ ----------
1625
+ model : pymc.Model
1626
+ The PyMC model to fit the Pathfinder algorithm to.
1627
+ num_paths : int
1628
+ Number of independent paths to run in the Pathfinder algorithm (default is 4). It is recommended to increase num_paths when increasing the jitter value.
1629
+ num_draws : int, optional
1630
+ Total number of samples to draw from the fitted approximation (default is 1000).
1631
+ num_draws_per_path : int, optional
1632
+ Number of samples to draw per path (default is 1000).
1633
+ maxcor : int, optional
1634
+ Maximum number of variable metric corrections used to define the limited memory matrix (default is None). If None, maxcor is set to ceil(3 * log(N)) or 5, whichever is greater, where N is the number of model parameters.
1635
+ maxiter : int, optional
1636
+ Maximum number of iterations for the L-BFGS optimisation (default is 1000).
1637
+ ftol : float, optional
1638
+ Tolerance for the decrease in the objective function (default is 1e-5).
1639
+ gtol : float, optional
1640
+ Tolerance for the norm of the gradient (default is 1e-8).
1641
+ maxls : int, optional
1642
+ Maximum number of line search steps for the L-BFGS algorithm (default is 1000).
1643
+ num_elbo_draws : int, optional
1644
+ Number of draws for the Evidence Lower Bound (ELBO) estimation (default is 10).
1645
+ jitter : float, optional
1646
+ Amount of jitter to apply to initial points (default is 2.0). Note that Pathfinder may be highly sensitive to the jitter value. It is recommended to increase num_paths when increasing the jitter value.
1647
+ epsilon : float, optional
1648
+ Threshold used to filter out large changes in the direction of the update gradient at each L-BFGS iteration l. Iteration l is accepted only if delta_theta[l] * delta_grad[l] > epsilon * L2_norm(delta_grad[l]) (default is 1e-8).
1649
+ importance_sampling : str, optional
1650
+ Importance sampling method to apply, based on log importance weights equal to logP - logQ. Options are "psis" (default), "psir", "identity" and "none". Pareto Smoothed Importance Sampling ("psis") usually gives more stable results than Pareto Smoothed Importance Resampling ("psir"). "identity" applies the log importance weights directly without resampling. "none" applies no importance sampling weights and returns the samples as is, with shape (num_paths, num_draws_per_path, N), where N is the number of model parameters; otherwise the sample shape is (num_draws, N).
1651
+ progressbar : bool, optional
1652
+ Whether to display a progress bar (default is True). Setting this to False will likely reduce the computation time.
1653
+ concurrent : str, optional
1654
+ How to run paths concurrently: "thread", "process", or None (default is None). With None, paths run serially; for smaller models this is generally faster because it avoids the overhead of concurrency.
1655
+ random_seed : RandomSeed, optional
1656
+ Random seed for reproducibility.
1657
+ postprocessing_backend : str, optional
1658
+ Backend for postprocessing transformations, either "cpu" or "gpu" (default is "cpu"). Only relevant if inference_backend is "blackjax".
1659
+ inference_backend : str, optional
1660
+ Backend for inference, either "pymc" or "blackjax" (default is "pymc").
1661
+ pathfinder_kwargs : dict, optional
1662
+ Additional keyword arguments for the Pathfinder algorithm.
1663
+ compile_kwargs : dict, optional
1664
+ Additional keyword arguments for the PyTensor compiler. If not provided, the default linker is "cvm_nogc".
1665
+
1666
+ Returns
1667
+ -------
1668
+ arviz.InferenceData
1669
+ The inference data containing the results of the Pathfinder algorithm.
1670
+
1671
+ References
1672
+ ----------
1673
+ Zhang, L., Carpenter, B., Gelman, A., & Vehtari, A. (2022). Pathfinder: Parallel quasi-Newton variational inference. Journal of Machine Learning Research, 23(306), 1-49.
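+
+ Examples
+ --------
+ A minimal usage sketch with a toy model (values are illustrative only):
+
+ >>> import pymc as pm
+ >>> with pm.Model() as model:
+ ...     mu = pm.Normal("mu", 0.0, 1.0)
+ ...     pm.Normal("y", mu, 1.0, observed=[0.1, -0.3, 0.2])
+ >>> idata = fit_pathfinder(model=model, num_paths=4, num_draws=1000, random_seed=123)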
1674
+ """
1675
+
1676
+ model = modelcontext(model)
1677
+ N = DictToArrayBijection.map(model.initial_point()).data.shape[0]
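+ # N: number of free parameters after raveling the model's initial point into a flat vector.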
1678
+
1679
+ if maxcor is None:
1680
+ # Based on tests, this is a good default: larger maxcor values slow the algorithm down without reliably improving results, and any improvement tends to diminish as maxcor grows.
1681
+ maxcor = np.ceil(3 * np.log(N)).astype(np.int32)
1682
+ maxcor = max(maxcor, 5)
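+ # e.g. with N = 100 parameters: maxcor = ceil(3 * ln(100)) = ceil(13.8) = 14; for very small N the floor of 5 applies.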
1683
+
1684
+ if inference_backend == "pymc":
1685
+ mp_result = multipath_pathfinder(
1686
+ model,
1687
+ num_paths=num_paths,
1688
+ num_draws=num_draws,
1689
+ num_draws_per_path=num_draws_per_path,
1690
+ maxcor=maxcor,
1691
+ maxiter=maxiter,
1692
+ ftol=ftol,
1693
+ gtol=gtol,
1694
+ maxls=maxls,
1695
+ num_elbo_draws=num_elbo_draws,
1696
+ jitter=jitter,
1697
+ epsilon=epsilon,
1698
+ importance_sampling=importance_sampling,
1699
+ progressbar=progressbar,
1700
+ concurrent=concurrent,
1701
+ random_seed=random_seed,
1702
+ pathfinder_kwargs=pathfinder_kwargs,
1703
+ compile_kwargs=compile_kwargs,
1704
+ )
1705
+ pathfinder_samples = mp_result.samples
1706
+ elif inference_backend == "blackjax":
1707
+ if find_spec("blackjax") is None:
1708
+ raise RuntimeError("Need BlackJAX to use `pathfinder`")
1709
+ if version.parse(blackjax.__version__).major < 1:
1710
+ raise ImportError("fit_pathfinder requires blackjax 1.0 or above")
1711
+
1712
+ jitter_seed, pathfinder_seed, sample_seed = _get_seeds_per_chain(random_seed, 3)
1713
+ # TODO: extend initial points with jitter_scale to blackjax
1714
+ # TODO: extend blackjax pathfinder to multiple paths
1715
+ x0, _ = DictToArrayBijection.map(model.initial_point())
1716
+ logp_func = get_jaxified_logp_of_ravel_inputs(model)
1717
+ pathfinder_state, pathfinder_info = blackjax.vi.pathfinder.approximate(
1718
+ rng_key=jax.random.key(pathfinder_seed),
1719
+ logdensity_fn=logp_func,
1720
+ initial_position=x0,
1721
+ num_samples=num_elbo_draws,
1722
+ maxiter=maxiter,
1723
+ maxcor=maxcor,
1724
+ maxls=maxls,
1725
+ ftol=ftol,
1726
+ gtol=gtol,
1727
+ **pathfinder_kwargs,
1728
+ )
1729
+ pathfinder_samples, _ = blackjax.vi.pathfinder.sample(
1730
+ rng_key=jax.random.key(sample_seed),
1731
+ state=pathfinder_state,
1732
+ num_samples=num_draws,
1733
+ )
1734
+ else:
1735
+ raise ValueError(f"Invalid inference_backend: {inference_backend}")
1736
+
1737
+ logger.info("Transforming variables...")
1738
+
1739
+ idata = convert_flat_trace_to_idata(
1740
+ pathfinder_samples,
1741
+ postprocessing_backend=postprocessing_backend,
1742
+ inference_backend=inference_backend,
1743
+ model=model,
1744
+ importance_sampling=importance_sampling,
1745
+ )
1746
+ return idata