pymc-extras 0.2.4__py3-none-any.whl → 0.2.6__py3-none-any.whl

This diff shows the content of publicly released package versions as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
Files changed (69)
  1. pymc_extras/__init__.py +6 -4
  2. pymc_extras/distributions/__init__.py +2 -0
  3. pymc_extras/distributions/continuous.py +3 -2
  4. pymc_extras/distributions/discrete.py +3 -1
  5. pymc_extras/distributions/transforms/__init__.py +3 -0
  6. pymc_extras/distributions/transforms/partial_order.py +227 -0
  7. pymc_extras/inference/__init__.py +4 -2
  8. pymc_extras/inference/find_map.py +62 -17
  9. pymc_extras/inference/fit.py +6 -4
  10. pymc_extras/inference/laplace.py +14 -8
  11. pymc_extras/inference/pathfinder/lbfgs.py +49 -13
  12. pymc_extras/inference/pathfinder/pathfinder.py +89 -103
  13. pymc_extras/statespace/core/statespace.py +191 -52
  14. pymc_extras/statespace/filters/distributions.py +15 -16
  15. pymc_extras/statespace/filters/kalman_filter.py +1 -18
  16. pymc_extras/statespace/filters/kalman_smoother.py +2 -6
  17. pymc_extras/statespace/models/ETS.py +10 -0
  18. pymc_extras/statespace/models/SARIMAX.py +26 -5
  19. pymc_extras/statespace/models/VARMAX.py +12 -2
  20. pymc_extras/statespace/models/structural.py +18 -5
  21. pymc_extras/statespace/utils/data_tools.py +24 -9
  22. pymc_extras-0.2.6.dist-info/METADATA +318 -0
  23. pymc_extras-0.2.6.dist-info/RECORD +65 -0
  24. {pymc_extras-0.2.4.dist-info → pymc_extras-0.2.6.dist-info}/WHEEL +1 -2
  25. pymc_extras/version.py +0 -11
  26. pymc_extras/version.txt +0 -1
  27. pymc_extras-0.2.4.dist-info/METADATA +0 -110
  28. pymc_extras-0.2.4.dist-info/RECORD +0 -105
  29. pymc_extras-0.2.4.dist-info/top_level.txt +0 -2
  30. tests/__init__.py +0 -13
  31. tests/distributions/__init__.py +0 -19
  32. tests/distributions/test_continuous.py +0 -185
  33. tests/distributions/test_discrete.py +0 -210
  34. tests/distributions/test_discrete_markov_chain.py +0 -258
  35. tests/distributions/test_multivariate.py +0 -304
  36. tests/model/__init__.py +0 -0
  37. tests/model/marginal/__init__.py +0 -0
  38. tests/model/marginal/test_distributions.py +0 -132
  39. tests/model/marginal/test_graph_analysis.py +0 -182
  40. tests/model/marginal/test_marginal_model.py +0 -967
  41. tests/model/test_model_api.py +0 -38
  42. tests/statespace/__init__.py +0 -0
  43. tests/statespace/test_ETS.py +0 -411
  44. tests/statespace/test_SARIMAX.py +0 -405
  45. tests/statespace/test_VARMAX.py +0 -184
  46. tests/statespace/test_coord_assignment.py +0 -116
  47. tests/statespace/test_distributions.py +0 -270
  48. tests/statespace/test_kalman_filter.py +0 -326
  49. tests/statespace/test_representation.py +0 -175
  50. tests/statespace/test_statespace.py +0 -872
  51. tests/statespace/test_statespace_JAX.py +0 -156
  52. tests/statespace/test_structural.py +0 -836
  53. tests/statespace/utilities/__init__.py +0 -0
  54. tests/statespace/utilities/shared_fixtures.py +0 -9
  55. tests/statespace/utilities/statsmodel_local_level.py +0 -42
  56. tests/statespace/utilities/test_helpers.py +0 -310
  57. tests/test_blackjax_smc.py +0 -222
  58. tests/test_find_map.py +0 -103
  59. tests/test_histogram_approximation.py +0 -109
  60. tests/test_laplace.py +0 -265
  61. tests/test_linearmodel.py +0 -208
  62. tests/test_model_builder.py +0 -306
  63. tests/test_pathfinder.py +0 -203
  64. tests/test_pivoted_cholesky.py +0 -24
  65. tests/test_printing.py +0 -98
  66. tests/test_prior_from_trace.py +0 -172
  67. tests/test_splines.py +0 -77
  68. tests/utils.py +0 -0
  69. {pymc_extras-0.2.4.dist-info → pymc_extras-0.2.6.dist-info/licenses}/LICENSE +0 -0
pymc_extras/inference/pathfinder/lbfgs.py

@@ -37,11 +37,14 @@ class LBFGSHistoryManager:
         initial position
     maxiter : int
         maximum number of iterations to store
+    epsilon : float
+        tolerance for lbfgs update
     """

     value_grad_fn: Callable[[NDArray[np.float64]], tuple[np.float64, NDArray[np.float64]]]
     x0: NDArray[np.float64]
     maxiter: int
+    epsilon: float
     x_history: NDArray[np.float64] = field(init=False)
     g_history: NDArray[np.float64] = field(init=False)
     count: int = field(init=False)
@@ -52,7 +55,7 @@ class LBFGSHistoryManager:
         self.count = 0

         value, grad = self.value_grad_fn(self.x0)
-        if np.all(np.isfinite(grad)) and np.isfinite(value):
+        if self.entry_condition_met(self.x0, value, grad):
             self.add_entry(self.x0, grad)

     def add_entry(self, x: NDArray[np.float64], g: NDArray[np.float64]) -> None:
@@ -75,18 +78,39 @@ class LBFGSHistoryManager:
             x=self.x_history[: self.count], g=self.g_history[: self.count], count=self.count
         )

+    def entry_condition_met(self, x, value, grad) -> bool:
+        """Checks if the LBFGS iteration should continue."""
+
+        if np.all(np.isfinite(grad)) and np.isfinite(value) and (self.count < self.maxiter + 1):
+            if self.count == 0:
+                return True
+            else:
+                s = x - self.x_history[self.count - 1]
+                z = grad - self.g_history[self.count - 1]
+                sz = (s * z).sum(axis=-1)
+                update = sz > self.epsilon * np.sqrt(np.sum(z**2, axis=-1))
+
+                if update:
+                    return True
+                else:
+                    return False
+        else:
+            return False
+
     def __call__(self, x: NDArray[np.float64]) -> None:
         value, grad = self.value_grad_fn(x)
-        if np.all(np.isfinite(grad)) and np.isfinite(value) and self.count < self.maxiter + 1:
+        if self.entry_condition_met(x, value, grad):
             self.add_entry(x, grad)


 class LBFGSStatus(Enum):
     CONVERGED = auto()
     MAX_ITER_REACHED = auto()
-    DIVERGED = auto()
+    NON_FINITE = auto()
+    LOW_UPDATE_PCT = auto()
     # Statuses that lead to Exceptions:
     INIT_FAILED = auto()
+    INIT_FAILED_LOW_UPDATE_PCT = auto()
     LBFGS_FAILED = auto()

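Note: entry_condition_met above is the curvature filter that replaces the old finite-value-only check. A minimal NumPy sketch of the acceptance rule (toy values, mine, not from the package): an entry is stored only when s·z > epsilon * ||z||, i.e. when the step and the gradient change show positive curvature.

    import numpy as np

    epsilon = 1e-8  # same default the new code uses
    x_prev, g_prev = np.array([0.0, 0.0]), np.array([-2.0, -1.0])
    x_new, g_new = np.array([0.5, 0.25]), np.array([-1.0, -0.5])

    s = x_new - x_prev         # position difference
    z = g_new - g_prev         # gradient difference
    sz = (s * z).sum(axis=-1)  # inner product s . z

    accept = sz > epsilon * np.sqrt(np.sum(z**2, axis=-1))
    print(accept)  # True: positive curvature, so the entry is kept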
@@ -101,8 +125,8 @@ class LBFGSException(Exception):
 class LBFGSInitFailed(LBFGSException):
     DEFAULT_MESSAGE = "LBFGS failed to initialise."

-    def __init__(self, message=None):
-        super().__init__(message or self.DEFAULT_MESSAGE, LBFGSStatus.INIT_FAILED)
+    def __init__(self, status: LBFGSStatus, message=None):
+        super().__init__(message or self.DEFAULT_MESSAGE, status)


 class LBFGS:
@@ -122,10 +146,12 @@ class LBFGS:
         gradient tolerance for convergence, defaults to 1e-8
     maxls : int, optional
         maximum number of line search steps, defaults to 1000
+    epsilon : float, optional
+        tolerance for lbfgs update, defaults to 1e-8
     """

     def __init__(
-        self, value_grad_fn, maxcor, maxiter=1000, ftol=1e-5, gtol=1e-8, maxls=1000
+        self, value_grad_fn, maxcor, maxiter=1000, ftol=1e-5, gtol=1e-8, maxls=1000, epsilon=1e-8
     ) -> None:
         self.value_grad_fn = value_grad_fn
         self.maxcor = maxcor
@@ -133,6 +159,7 @@ class LBFGS:
         self.ftol = ftol
         self.gtol = gtol
         self.maxls = maxls
+        self.epsilon = epsilon

     def minimize(self, x0) -> tuple[NDArray, NDArray, int, LBFGSStatus]:
         """minimizes objective function starting from initial position.
@@ -157,7 +184,7 @@ class LBFGS:
         x0 = np.array(x0, dtype=np.float64)

         history_manager = LBFGSHistoryManager(
-            value_grad_fn=self.value_grad_fn, x0=x0, maxiter=self.maxiter
+            value_grad_fn=self.value_grad_fn, x0=x0, maxiter=self.maxiter, epsilon=self.epsilon
         )

         result = minimize(
@@ -177,13 +204,22 @@ class LBFGS:
         history = history_manager.get_history()

         # warnings and suggestions for LBFGSStatus are displayed at the end
-        if result.status == 1:
-            lbfgs_status = LBFGSStatus.MAX_ITER_REACHED
-        elif (result.status == 2) or (history.count <= 1):
-            if result.nit <= 1:
+        # threshold determining if the number of lbfgs updates is low compared to iterations
+        low_update_threshold = 3
+
+        if history.count <= 1:  # triggers LBFGSInitFailed
+            if result.nit < low_update_threshold:
                 lbfgs_status = LBFGSStatus.INIT_FAILED
-            elif result.fun == np.inf:
-                lbfgs_status = LBFGSStatus.DIVERGED
+            else:
+                lbfgs_status = LBFGSStatus.INIT_FAILED_LOW_UPDATE_PCT
+        elif result.status == 1:
+            # (result.nit > maxiter) or (result.nit > maxls)
+            lbfgs_status = LBFGSStatus.MAX_ITER_REACHED
+        elif result.status == 2:
+            # precision loss resulting to inf or nan
+            lbfgs_status = LBFGSStatus.NON_FINITE
+        elif history.count * low_update_threshold < result.nit:
+            lbfgs_status = LBFGSStatus.LOW_UPDATE_PCT
         else:
             lbfgs_status = LBFGSStatus.CONVERGED

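Note: a condensed view of the new status classification in LBFGS.minimize (my paraphrase of the hunk above; result comes from scipy.optimize.minimize with method "L-BFGS-B", and history.count is the number of accepted history entries):

    # history.count <= 1 and result.nit < 3   -> INIT_FAILED
    # history.count <= 1 and result.nit >= 3  -> INIT_FAILED_LOW_UPDATE_PCT
    # result.status == 1                      -> MAX_ITER_REACHED  (maxiter or maxls hit)
    # result.status == 2                      -> NON_FINITE        (replaces DIVERGED)
    # history.count * 3 < result.nit          -> LOW_UPDATE_PCT    (new)
    # otherwise                               -> CONVERGED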
pymc_extras/inference/pathfinder/pathfinder.py

@@ -12,22 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+
 import collections
 import logging
 import time
-import warnings as _warnings

 from collections import Counter
 from collections.abc import Callable, Iterator
 from dataclasses import asdict, dataclass, field, replace
 from enum import Enum, auto
-from importlib.util import find_spec
 from typing import Literal, TypeAlias

 import arviz as az
-import blackjax
 import filelock
-import jax
 import numpy as np
 import pymc as pm
 import pytensor
@@ -42,11 +39,10 @@ from pymc.initial_point import make_initial_point_fn
 from pymc.model import modelcontext
 from pymc.model.core import Point
 from pymc.pytensorf import (
-    compile_pymc,
+    compile,
     find_rng_nodes,
     reseed_rngs,
 )
-from pymc.sampling.jax import get_jaxified_graph
 from pymc.util import (
     CustomProgress,
     RandomSeed,
@@ -67,6 +63,7 @@ from rich.text import Text
 # TODO: change to typing.Self after Python versions greater than 3.10
 from typing_extensions import Self

+from pymc_extras.inference.laplace import add_data_to_inferencedata
 from pymc_extras.inference.pathfinder.importance_sampling import (
     importance_sampling as _importance_sampling,
 )
@@ -78,9 +75,6 @@ from pymc_extras.inference.pathfinder.lbfgs import (
 )

 logger = logging.getLogger(__name__)
-_warnings.filterwarnings(
-    "ignore", category=FutureWarning, message="compile_pymc was renamed to compile"
-)

 REGULARISATION_TERM = 1e-8
 DEFAULT_LINKER = "cvm_nogc"
@@ -105,6 +99,8 @@ def get_jaxified_logp_of_ravel_inputs(model: Model, jacobian: bool = True) -> Ca
         A JAX function that computes the log-probability of a PyMC model with ravelled inputs.
     """

+    from pymc.sampling.jax import get_jaxified_graph
+
     # TODO: JAX: test if we should get jaxified graph of dlogp as well
     new_logprob, new_input = pm.pytensorf.join_nonshared_inputs(
         model.initial_point(), (model.logp(jacobian=jacobian),), model.value_vars, ()
@@ -144,7 +140,7 @@ def get_logp_dlogp_of_ravel_inputs(
         [model.logp(jacobian=jacobian), model.dlogp(jacobian=jacobian)],
         model.value_vars,
     )
-    logp_dlogp_fn = compile_pymc([inputs], (logP, dlogP), **compile_kwargs)
+    logp_dlogp_fn = compile([inputs], (logP, dlogP), **compile_kwargs)
     logp_dlogp_fn.trust_input = True

     return logp_dlogp_fn
@@ -224,6 +220,10 @@ def convert_flat_trace_to_idata(
         result = [res.reshape(num_paths, num_pdraws, *res.shape[2:]) for res in result]

     elif inference_backend == "blackjax":
+        import jax
+
+        from pymc.sampling.jax import get_jaxified_graph
+
         jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
         result = jax.vmap(jax.vmap(jax_fn))(
             *jax.device_put(list(trace.values()), jax.devices(postprocessing_backend)[0])
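Note: jax and blackjax are no longer imported at module level; each backend branch now imports them locally, so the JAX stack stays an optional dependency. A sketch of the pattern (illustrative helper, not in the package):

    def jax_stack_available() -> bool:
        # importing inside the function means users of the default "pymc"
        # backend never trigger (or need) the JAX import at module load
        try:
            import jax  # noqa: F401
        except ImportError:
            return False
        return True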
@@ -237,8 +237,8 @@


 def alpha_recover(
-    x: TensorVariable, g: TensorVariable, epsilon: TensorVariable
-) -> tuple[TensorVariable, TensorVariable, TensorVariable, TensorVariable]:
+    x: TensorVariable, g: TensorVariable
+) -> tuple[TensorVariable, TensorVariable, TensorVariable]:
     """compute the diagonal elements of the inverse Hessian at each iterations of L-BFGS and filter updates.

     Parameters
@@ -247,9 +247,6 @@
         position array, shape (L+1, N)
     g : TensorVariable
         gradient array, shape (L+1, N)
-    epsilon : float
-        threshold for filtering updates based on inner product of position
-        and gradient differences

     Returns
     -------
@@ -259,15 +256,13 @@
         position differences, shape (L, N)
     z : TensorVariable
         gradient differences, shape (L, N)
-    update_mask : TensorVariable
-        mask for filtering updates, shape (L,)

     Notes
     -----
     shapes: L=batch_size, N=num_params
     """

-    def compute_alpha_l(alpha_lm1, s_l, z_l) -> TensorVariable:
+    def compute_alpha_l(s_l, z_l, alpha_lm1) -> TensorVariable:
         # alpha_lm1: (N,)
         # s_l: (N,)
         # z_l: (N,)
@@ -281,43 +276,28 @@
         )  # fmt:off
         return 1.0 / inv_alpha_l

-    def return_alpha_lm1(alpha_lm1, s_l, z_l) -> TensorVariable:
-        return alpha_lm1[-1]
-
-    def scan_body(update_mask_l, s_l, z_l, alpha_lm1) -> TensorVariable:
-        return pt.switch(
-            update_mask_l,
-            compute_alpha_l(alpha_lm1, s_l, z_l),
-            return_alpha_lm1(alpha_lm1, s_l, z_l),
-        )
-
     Lp1, N = x.shape
     s = pt.diff(x, axis=0)
     z = pt.diff(g, axis=0)
     alpha_l_init = pt.ones(N)
-    sz = (s * z).sum(axis=-1)
-    # update_mask = sz > epsilon * pt.linalg.norm(z, axis=-1)
-    # pt.linalg.norm does not work with JAX!!
-    update_mask = sz > epsilon * pt.sqrt(pt.sum(z**2, axis=-1))

     alpha, _ = pytensor.scan(
-        fn=scan_body,
+        fn=compute_alpha_l,
         outputs_info=alpha_l_init,
-        sequences=[update_mask, s, z],
+        sequences=[s, z],
         n_steps=Lp1 - 1,
         allow_gc=False,
     )

     # assert np.all(alpha.eval() > 0), "alpha cannot be negative"
-    # alpha: (L, N), update_mask: (L, N)
-    return alpha, s, z, update_mask
+    # alpha: (L, N)
+    return alpha, s, z


 def inverse_hessian_factors(
     alpha: TensorVariable,
     s: TensorVariable,
     z: TensorVariable,
-    update_mask: TensorVariable,
     J: TensorConstant,
 ) -> tuple[TensorVariable, TensorVariable]:
     """compute the inverse hessian factors for the BFGS approximation.
@@ -330,8 +310,6 @@
         position differences, shape (L, N)
     z : TensorVariable
         gradient differences, shape (L, N)
-    update_mask : TensorVariable
-        mask for filtering updates, shape (L,)
     J : TensorConstant
         history size for L-BFGS

@@ -350,30 +328,19 @@
     # NOTE: get_chi_matrix_1 is a modified version of get_chi_matrix_2 to closely follow Zhang et al., (2022)
     # NOTE: get_chi_matrix_2 is from blackjax which MAYBE incorrectly implemented

-    def get_chi_matrix_1(
-        diff: TensorVariable, update_mask: TensorVariable, J: TensorConstant
-    ) -> TensorVariable:
+    def get_chi_matrix_1(diff: TensorVariable, J: TensorConstant) -> TensorVariable:
         L, N = diff.shape
         j_last = pt.as_tensor(J - 1)  # since indexing starts at 0

-        def chi_update(chi_lm1, diff_l) -> TensorVariable:
+        def chi_update(diff_l, chi_lm1) -> TensorVariable:
             chi_l = pt.roll(chi_lm1, -1, axis=0)
             return pt.set_subtensor(chi_l[j_last], diff_l)

-        def no_op(chi_lm1, diff_l) -> TensorVariable:
-            return chi_lm1
-
-        def scan_body(update_mask_l, diff_l, chi_lm1) -> TensorVariable:
-            return pt.switch(update_mask_l, chi_update(chi_lm1, diff_l), no_op(chi_lm1, diff_l))
-
         chi_init = pt.zeros((J, N))
         chi_mat, _ = pytensor.scan(
-            fn=scan_body,
+            fn=chi_update,
             outputs_info=chi_init,
-            sequences=[
-                update_mask,
-                diff,
-            ],
+            sequences=[diff],
             allow_gc=False,
         )

@@ -382,19 +349,15 @@
         # (L, N, J)
         return chi_mat

-    def get_chi_matrix_2(
-        diff: TensorVariable, update_mask: TensorVariable, J: TensorConstant
-    ) -> TensorVariable:
+    def get_chi_matrix_2(diff: TensorVariable, J: TensorConstant) -> TensorVariable:
         L, N = diff.shape

-        diff_masked = update_mask[:, None] * diff
-
         # diff_padded: (L+J, N)
         pad_width = pt.zeros(shape=(2, 2), dtype="int32")
-        pad_width = pt.set_subtensor(pad_width[0, 0], J)
-        diff_padded = pt.pad(diff_masked, pad_width, mode="constant")
+        pad_width = pt.set_subtensor(pad_width[0, 0], J - 1)
+        diff_padded = pt.pad(diff, pad_width, mode="constant")

-        index = pt.arange(L)[:, None] + pt.arange(J)[None, :]
+        index = pt.arange(L)[..., None] + pt.arange(J)[None, ...]
         index = index.reshape((L, J))

         chi_mat = pt.matrix_transpose(diff_padded[index])
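Note: get_chi_matrix_2 builds, for every iteration l, the window of the last J differences via padding plus index arithmetic. A NumPy sketch of the same trick (shapes assumed, values mine):

    import numpy as np

    L, J, N = 5, 3, 2
    diff = np.arange(L * N, dtype=float).reshape(L, N)  # (L, N)
    diff_padded = np.pad(diff, ((J - 1, 0), (0, 0)))    # J-1 leading zero rows

    index = np.arange(L)[:, None] + np.arange(J)[None, :]  # (L, J) window indices
    windows = diff_padded[index]                           # (L, J, N)

    # row l holds [diff[l-J+1], ..., diff[l]], zero-filled where history is short
    assert np.allclose(windows[-1], diff[-J:])

The padding change from J to J - 1 shifts the window by one, so it ends at the current step's difference diff[l] rather than at diff[l-1].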
@@ -403,8 +366,10 @@
         return chi_mat

     L, N = alpha.shape
-    S = get_chi_matrix_1(s, update_mask, J)
-    Z = get_chi_matrix_1(z, update_mask, J)
+
+    # changed to get_chi_matrix_2 after removing update_mask
+    S = get_chi_matrix_2(s, J)
+    Z = get_chi_matrix_2(z, J)

     # E: (L, J, J)
     Ij = pt.eye(J)[None, ...]
@@ -489,6 +454,7 @@ def bfgs_sample_dense(

     N = x.shape[-1]
     IdN = pt.eye(N)[None, ...]
+    IdN += IdN * REGULARISATION_TERM

     # inverse Hessian
     H_inv = (
@@ -504,7 +470,10 @@

     logdet = 2.0 * pt.sum(pt.log(pt.abs(pt.diagonal(Lchol, axis1=-2, axis2=-1))), axis=-1)

-    mu = x - pt.batched_dot(H_inv, g)
+    # mu = x - pt.einsum("ijk,ik->ij", H_inv, g)  # causes error: Multiple destroyers of g
+
+    batched_dot = pt.vectorize(pt.dot, signature="(ijk),(ilk)->(ij)")
+    mu = x - batched_dot(H_inv, pt.matrix_transpose(g[..., None]))

     phi = pt.matrix_transpose(
         # (L, N, 1)
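Note: the new batched_dot replaces pt.batched_dot and the commented-out einsum; per L-BFGS iteration it applies the inverse-Hessian estimate to the gradient. The intended math, in a NumPy sketch (shapes assumed, values mine):

    import numpy as np

    L, N = 4, 3
    rng = np.random.default_rng(0)
    H_inv = rng.normal(size=(L, N, N))  # one inverse-Hessian estimate per iteration
    x = rng.normal(size=(L, N))
    g = rng.normal(size=(L, N))

    mu = x - np.matmul(H_inv, g[..., None])[..., 0]  # (L, N)

    # equivalent to the einsum form the diff leaves commented out
    assert np.allclose(mu, x - np.einsum("lij,lj->li", H_inv, g))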
@@ -565,23 +534,28 @@ def bfgs_sample_sparse(
     # qr_input: (L, N, 2J)
     qr_input = inv_sqrt_alpha_diag @ beta
     (Q, R), _ = pytensor.scan(fn=pt.nlinalg.qr, sequences=[qr_input], allow_gc=False)
+
     IdN = pt.eye(R.shape[1])[None, ...]
+    IdN += IdN * REGULARISATION_TERM
+
     Lchol_input = IdN + R @ gamma @ pt.matrix_transpose(R)

+    # TODO: make robust Lchol calcs more robust, ie. try exceptions, increase REGULARISATION_TERM if non-finite exists
     Lchol = pt.linalg.cholesky(Lchol_input, lower=False, check_finite=False, on_error="nan")

     logdet = 2.0 * pt.sum(pt.log(pt.abs(pt.diagonal(Lchol, axis1=-2, axis2=-1))), axis=-1)
     logdet += pt.sum(pt.log(alpha), axis=-1)

+    # inverse Hessian
+    # (L, N, N) + (L, N, 2J), (L, 2J, 2J), (L, 2J, N) -> (L, N, N)
+    H_inv = alpha_diag + (beta @ gamma @ pt.matrix_transpose(beta))
+
     # NOTE: changed the sign from "x + " to "x -" of the expression to match Stan which differs from Zhang et al., (2022). same for dense version.
-    mu = x - (
-        # (L, N), (L, N) -> (L, N)
-        pt.batched_dot(alpha_diag, g)
-        # beta @ gamma @ beta.T
-        # (L, N, 2J), (L, 2J, 2J), (L, 2J, N) -> (L, N, N)
-        # (L, N, N), (L, N) -> (L, N)
-        + pt.batched_dot((beta @ gamma @ pt.matrix_transpose(beta)), g)
-    )
+
+    # mu = x - pt.einsum("ijk,ik->ij", H_inv, g)  # causes error: Multiple destroyers of g
+
+    batched_dot = pt.vectorize(pt.dot, signature="(ijk),(ilk)->(ij)")
+    mu = x - batched_dot(H_inv, pt.matrix_transpose(g[..., None]))

     phi = pt.matrix_transpose(
         # (L, N, 1)
@@ -589,8 +563,6 @@
         # (L, N, N), (L, N, M) -> (L, N, M)
         + sqrt_alpha_diag
         @ (
-            # (L, N, 2J), (L, 2J, M) -> (L, N, M)
-            # intermediate calcs below
             # (L, N, 2J), (L, 2J, 2J) -> (L, N, 2J)
             (Q @ (Lchol - IdN))
             # (L, 2J, N), (L, N, M) -> (L, 2J, M)
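Note: the IdN += IdN * REGULARISATION_TERM lines (here and in the dense sampler) add diagonal jitter before the Cholesky factorization. A minimal NumPy sketch of why, using a deliberately degenerate case (values mine):

    import numpy as np

    REGULARISATION_TERM = 1e-8  # the module-level constant
    J2 = 4
    contribution = -np.eye(J2)  # worst case: R @ gamma @ R.T cancels the identity

    plain = np.eye(J2) + contribution                                 # singular
    jittered = np.eye(J2) * (1 + REGULARISATION_TERM) + contribution  # positive definite

    np.linalg.cholesky(jittered)   # succeeds
    # np.linalg.cholesky(plain)    # raises LinAlgError: not positive definite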
@@ -778,7 +750,6 @@ def make_pathfinder_body(
     num_draws: int,
     maxcor: int,
     num_elbo_draws: int,
-    epsilon: float,
     **compile_kwargs: dict,
 ) -> Function:
     """
@@ -794,8 +765,6 @@
         The maximum number of iterations for the L-BFGS algorithm.
     num_elbo_draws : int
         The number of draws for the Evidence Lower Bound (ELBO) estimation.
-    epsilon : float
-        The value used to filter out large changes in the direction of the update gradient at each iteration l in L. Iteration l is only accepted if delta_theta[l] * delta_grad[l] > epsilon * L2_norm(delta_grad[l]) for each l in L.
     compile_kwargs : dict
         Additional keyword arguments for the PyTensor compiler.

@@ -820,11 +789,10 @@

     num_draws = pt.constant(num_draws, "num_draws", dtype="int32")
     num_elbo_draws = pt.constant(num_elbo_draws, "num_elbo_draws", dtype="int32")
-    epsilon = pt.constant(epsilon, "epsilon", dtype="float64")
     maxcor = pt.constant(maxcor, "maxcor", dtype="int32")

-    alpha, s, z, update_mask = alpha_recover(x_full, g_full, epsilon=epsilon)
-    beta, gamma = inverse_hessian_factors(alpha, s, z, update_mask, J=maxcor)
+    alpha, s, z = alpha_recover(x_full, g_full)
+    beta, gamma = inverse_hessian_factors(alpha, s, z, J=maxcor)

     # ignore initial point - x, g: (L, N)
     x = x_full[1:]
@@ -855,7 +823,7 @@

     # return psi, logP_psi, logQ_psi, elbo_argmax

-    pathfinder_body_fn = compile_pymc(
+    pathfinder_body_fn = compile(
         [x_full, g_full],
         [psi, logP_psi, logQ_psi, elbo_argmax],
         **compile_kwargs,
@@ -934,11 +902,11 @@ def make_single_pathfinder_fn(
     x_base = DictToArrayBijection.map(ip).data

     # lbfgs
-    lbfgs = LBFGS(neg_logp_dlogp_func, maxcor, maxiter, ftol, gtol, maxls)
+    lbfgs = LBFGS(neg_logp_dlogp_func, maxcor, maxiter, ftol, gtol, maxls, epsilon)

     # pathfinder body
     pathfinder_body_fn = make_pathfinder_body(
-        logp_func, num_draws, maxcor, num_elbo_draws, epsilon, **compile_kwargs
+        logp_func, num_draws, maxcor, num_elbo_draws, **compile_kwargs
     )
     rngs = find_rng_nodes(pathfinder_body_fn.maker.fgraph.outputs)

@@ -950,8 +918,8 @@
         x0 = x_base + jitter_value
         x, g, lbfgs_niter, lbfgs_status = lbfgs.minimize(x0)

-        if lbfgs_status == LBFGSStatus.INIT_FAILED:
-            raise LBFGSInitFailed()
+        if lbfgs_status in {LBFGSStatus.INIT_FAILED, LBFGSStatus.INIT_FAILED_LOW_UPDATE_PCT}:
+            raise LBFGSInitFailed(lbfgs_status)
         elif lbfgs_status == LBFGSStatus.LBFGS_FAILED:
             raise LBFGSException()

@@ -1389,15 +1357,16 @@ def _get_status_warning(mpr: MultiPathfinderResult) -> list[str]:
     warnings = []

     lbfgs_status_message = {
-        LBFGSStatus.MAX_ITER_REACHED: "LBFGS maximum number of iterations reached. Consider increasing maxiter if this occurence is high relative to the number of paths.",
-        LBFGSStatus.INIT_FAILED: "LBFGS failed to initialise. Consider reparameterizing the model or reducing jitter if this occurence is high relative to the number of paths.",
-        LBFGSStatus.DIVERGED: "LBFGS diverged to infinity. Consider reparameterizing the model or adjusting the pathfinder arguments if this occurence is high relative to the number of paths.",
+        LBFGSStatus.MAX_ITER_REACHED: "MAX_ITER_REACHED: LBFGS maximum number of iterations reached. Consider increasing maxiter if this occurence is high relative to the number of paths.",
+        LBFGSStatus.INIT_FAILED: "INIT_FAILED: LBFGS failed to initialize. Consider reparameterizing the model or reducing jitter if this occurence is high relative to the number of paths.",
+        LBFGSStatus.NON_FINITE: "NON_FINITE: LBFGS objective function produced inf or nan at the last iteration. Consider reparameterizing the model or adjusting the pathfinder arguments if this occurence is high relative to the number of paths.",
+        LBFGSStatus.LOW_UPDATE_PCT: "LOW_UPDATE_PCT: Majority of LBFGS iterations were not accepted due to the either: (1) LBFGS function or gradient values containing too many inf or nan values or (2) gradient changes being significantly large, set by epsilon. Consider reparameterizing the model, adjusting initvals or jitter or other pathfinder arguments if this occurence is high relative to the number of paths.",
+        LBFGSStatus.INIT_FAILED_LOW_UPDATE_PCT: "INIT_FAILED_LOW_UPDATE_PCT: LBFGS failed to initialize due to the either: (1) LBFGS function or gradient values containing too many inf or nan values or (2) gradient changes being significantly large, set by epsilon. Consider reparameterizing the model, adjusting initvals or jitter or other pathfinder arguments if this occurence is high relative to the number of paths.",
     }

     path_status_message = {
-        PathStatus.ELBO_ARGMAX_AT_ZERO: "ELBO argmax at zero refers to the first iteration during LBFGS. A high occurrence suggests the model's default initial point + jitter is may be too close to the mean posterior and a poor exploration of the parameter space. Consider increasing jitter if this occurence is high relative to the number of paths.",
-        PathStatus.ELBO_ARGMAX_AT_ZERO: "ELBO argmax at zero refers to the first iteration during LBFGS. A high occurrence suggests the model's default initial point + jitter values are concentrated in high-density regions in the target distribution and may result in poor exploration of the parameter space. Consider increasing jitter if this occurrence is high relative to the number of paths.",
-        PathStatus.INVALID_LOGQ: "Invalid logQ values occur when a path's logQ values are not finite. The failed path is not included in samples when importance sampling is used. Consider reparameterizing the model or adjusting the pathfinder arguments if this occurence is high relative to the number of paths.",
+        PathStatus.ELBO_ARGMAX_AT_ZERO: "ELBO_ARGMAX_AT_ZERO: ELBO argmax at zero refers to the first iteration during LBFGS. A high occurrence suggests the model's default initial point + jitter values are concentrated in high-density regions in the target distribution and may result in poor exploration of the parameter space. Consider increasing jitter if this occurrence is high relative to the number of paths.",
+        PathStatus.INVALID_LOGQ: "INVALID_LOGQ: Invalid logQ values occur when a path's logQ values are not finite. The failed path is not included in samples when importance sampling is used. Consider reparameterizing the model or adjusting the pathfinder arguments if this occurence is high relative to the number of paths.",
     }

     for lbfgs_status in mpr.lbfgs_status:
@@ -1567,8 +1536,9 @@ def multipath_pathfinder(
                     task,
                     description=desc.format(path_idx=path_idx),
                     completed=path_idx,
-                    refresh=True,
                 )
+            # Ensure the progress bar visually reaches 100% and shows 'Completed'
+            progress.update(task, completed=num_paths, description="Completed")
         except (KeyboardInterrupt, StopIteration) as e:
             # if exception is raised here, MultiPathfinderResult will collect all the successful results and report the results. User is free to abort the process earlier and the results will still be collected and return az.InferenceData.
             if isinstance(e, StopIteration):
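Note: the progress-bar change drops the per-step refresh and pushes the bar to its total after the loop, so the display ends at 100% instead of one tick short. A standalone rich sketch of the same fix (task setup is mine):

    from rich.progress import Progress

    num_paths = 4
    with Progress() as progress:
        task = progress.add_task("Paths", total=num_paths)
        for path_idx in range(num_paths):
            progress.update(task, completed=path_idx)  # lags by one on the last step
        # final update so the bar renders 100% and the label reads 'Completed'
        progress.update(task, completed=num_paths, description="Completed")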
@@ -1618,7 +1588,7 @@ def fit_pathfinder(
     maxiter: int = 1000,  # L^max
     ftol: float = 1e-5,
     gtol: float = 1e-8,
-    maxls=1000,
+    maxls: int = 1000,
     num_elbo_draws: int = 10,  # K
     jitter: float = 2.0,
     epsilon: float = 1e-8,
@@ -1630,6 +1600,7 @@
     inference_backend: Literal["pymc", "blackjax"] = "pymc",
     pathfinder_kwargs: dict = {},
     compile_kwargs: dict = {},
+    initvals: dict | None = None,
 ) -> az.InferenceData:
     """
     Fit the Pathfinder Variational Inference algorithm.
@@ -1665,12 +1636,12 @@
     importance_sampling : str, None, optional
         Method to apply sampling based on log importance weights (logP - logQ).
         Options are:
-        "psis" : Pareto Smoothed Importance Sampling (default)
-            Recommended for more stable results.
-        "psir" : Pareto Smoothed Importance Resampling
-            Less stable than PSIS.
-        "identity" : Applies log importance weights directly without resampling.
-        None : No importance sampling weights. Returns raw samples of size (num_paths, num_draws_per_path, N) where N is number of model parameters. Other methods return samples of size (num_draws, N).
+
+        - "psis" : Pareto Smoothed Importance Sampling (default). Usually most stable.
+        - "psir" : Pareto Smoothed Importance Resampling. Less stable than PSIS.
+        - "identity" : Applies log importance weights directly without resampling.
+        - None : No importance sampling weights. Returns raw samples of size (num_paths, num_draws_per_path, N) where N is number of model parameters. Other methods return samples of size (num_draws, N).
+
     progressbar : bool, optional
         Whether to display a progress bar (default is True). Setting this to False will likely reduce the computation time.
     random_seed : RandomSeed, optional
@@ -1685,10 +1656,13 @@
         Additional keyword arguments for the Pathfinder algorithm.
     compile_kwargs
         Additional keyword arguments for the PyTensor compiler. If not provided, the default linker is "cvm_nogc".
+    initvals: dict | None = None
+        Initial values for the model parameters, as str:ndarray key-value pairs. Paritial initialization is permitted.
+        If None, the model's default initial values are used.

     Returns
     -------
-    arviz.InferenceData
+    :class:`~arviz.InferenceData`
         The inference data containing the results of the Pathfinder algorithm.

     References
@@ -1698,6 +1672,14 @@

     model = modelcontext(model)

+    if initvals is not None:
+        model = pm.model.fgraph.clone_model(model)  # Create a clone of the model
+        for (
+            rv_name,
+            ivals,
+        ) in initvals.items():  # Set the initial values for the variables in the clone
+            model.set_initval(model.named_vars[rv_name], ivals)
+
     valid_importance_sampling = {"psis", "psir", "identity", None}

     if importance_sampling is not None:
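Note: hypothetical usage of the new initvals argument (assuming fit_pathfinder is importable from pymc_extras.inference); partial initialization is allowed, and unlisted variables keep the model defaults:

    import numpy as np
    import pymc as pm

    from pymc_extras.inference import fit_pathfinder

    with pm.Model():
        mu = pm.Normal("mu", 0.0, 10.0)
        sigma = pm.HalfNormal("sigma", 1.0)
        pm.Normal("y", mu, sigma, observed=np.random.normal(5.0, 1.0, size=100))

        # only "mu" is pinned; "sigma" keeps its default initial value
        idata = fit_pathfinder(initvals={"mu": np.array(5.0)}, random_seed=1)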
@@ -1736,8 +1718,9 @@
         )
         pathfinder_samples = mp_result.samples
     elif inference_backend == "blackjax":
-        if find_spec("blackjax") is None:
-            raise RuntimeError("Need BlackJAX to use `pathfinder`")
+        import blackjax
+        import jax
+
         if version.parse(blackjax.__version__).major < 1:
             raise ImportError("fit_pathfinder requires blackjax 1.0 or above")

@@ -1775,4 +1758,7 @@
         model=model,
         importance_sampling=importance_sampling,
     )
+
+    idata = add_data_to_inferencedata(idata, progressbar, model, compile_kwargs)
+
     return idata