pymc-extras 0.4.1__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37) hide show
  1. pymc_extras/deserialize.py +10 -4
  2. pymc_extras/distributions/continuous.py +1 -1
  3. pymc_extras/distributions/histogram_utils.py +6 -4
  4. pymc_extras/distributions/multivariate/r2d2m2cp.py +4 -3
  5. pymc_extras/distributions/timeseries.py +4 -2
  6. pymc_extras/inference/__init__.py +8 -1
  7. pymc_extras/inference/dadvi/__init__.py +0 -0
  8. pymc_extras/inference/dadvi/dadvi.py +351 -0
  9. pymc_extras/inference/fit.py +5 -0
  10. pymc_extras/inference/laplace_approx/find_map.py +32 -47
  11. pymc_extras/inference/laplace_approx/idata.py +27 -6
  12. pymc_extras/inference/laplace_approx/laplace.py +24 -6
  13. pymc_extras/inference/laplace_approx/scipy_interface.py +47 -7
  14. pymc_extras/inference/pathfinder/idata.py +517 -0
  15. pymc_extras/inference/pathfinder/pathfinder.py +61 -7
  16. pymc_extras/model/marginal/graph_analysis.py +2 -2
  17. pymc_extras/model_builder.py +9 -4
  18. pymc_extras/prior.py +203 -8
  19. pymc_extras/statespace/core/compile.py +1 -1
  20. pymc_extras/statespace/filters/kalman_filter.py +12 -11
  21. pymc_extras/statespace/filters/kalman_smoother.py +1 -3
  22. pymc_extras/statespace/filters/utilities.py +2 -5
  23. pymc_extras/statespace/models/DFM.py +834 -0
  24. pymc_extras/statespace/models/ETS.py +190 -198
  25. pymc_extras/statespace/models/SARIMAX.py +9 -21
  26. pymc_extras/statespace/models/VARMAX.py +22 -74
  27. pymc_extras/statespace/models/structural/components/autoregressive.py +4 -4
  28. pymc_extras/statespace/models/structural/components/regression.py +4 -26
  29. pymc_extras/statespace/models/utilities.py +7 -0
  30. pymc_extras/statespace/utils/constants.py +3 -1
  31. pymc_extras/utils/model_equivalence.py +2 -2
  32. pymc_extras/utils/prior.py +10 -14
  33. pymc_extras/utils/spline.py +4 -10
  34. {pymc_extras-0.4.1.dist-info → pymc_extras-0.6.0.dist-info}/METADATA +3 -3
  35. {pymc_extras-0.4.1.dist-info → pymc_extras-0.6.0.dist-info}/RECORD +37 -33
  36. {pymc_extras-0.4.1.dist-info → pymc_extras-0.6.0.dist-info}/WHEEL +1 -1
  37. {pymc_extras-0.4.1.dist-info → pymc_extras-0.6.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,517 @@
1
+ # Copyright 2022 The PyMC Developers
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Utilities for converting Pathfinder results to xarray and adding them to InferenceData."""
16
+
17
+ from __future__ import annotations
18
+
19
+ import warnings
20
+
21
+ from dataclasses import asdict
22
+
23
+ import arviz as az
24
+ import numpy as np
25
+ import pymc as pm
26
+ import xarray as xr
27
+
28
+ from pymc.blocking import DictToArrayBijection
29
+
30
+ from pymc_extras.inference.pathfinder.lbfgs import LBFGSStatus
31
+ from pymc_extras.inference.pathfinder.pathfinder import (
32
+ MultiPathfinderResult,
33
+ PathfinderConfig,
34
+ PathfinderResult,
35
+ PathStatus,
36
+ )
37
+
38
+
39
def get_param_coords(model: pm.Model | None, n_params: int) -> list[str]:
    """
    Build labels for the ``param`` coordinate from a PyMC model.

    Parameters
    ----------
    model : pm.Model | None
        Model whose free-variable names are used as labels. When None, plain
        numeric indices ``"0" ... "n_params-1"`` are returned instead.
    n_params : int
        Number of parameters; only used for the numeric fallback.

    Returns
    -------
    list[str]
        One label per raveled parameter; vector variables are expanded to
        ``"name[i]"`` entries.
    """
    if model is None:
        return [str(idx) for idx in range(n_params)]

    # Ravel the model's initial point to recover variable names and sizes in
    # the same order pathfinder's flat parameter vector uses.
    point = model.initial_point()
    raveled = DictToArrayBijection.map(point)

    labels: list[str] = []
    for name, _shape, size, _dtype in raveled.point_map_info:
        if size == 1:
            labels.append(name)
        else:
            labels.extend(f"{name}[{idx}]" for idx in range(size))
    return labels
69
+
70
+
71
def _status_counter_to_dataarray(counter, status_enum_cls) -> xr.DataArray:
    """Convert a Counter keyed by status enum members into a dense DataArray.

    Every member of ``status_enum_cls`` gets an entry; statuses absent from
    ``counter`` are reported with an explicit count of zero.
    """
    members = list(status_enum_cls)
    labels = [member.name for member in members]
    tallies = np.array([counter.get(member, 0) for member in members])
    return xr.DataArray(
        tallies,
        dims=["status"],
        coords={"status": labels},
        name="status_counts",
    )
81
+
82
+
83
+ def _extract_scalar(value):
84
+ """Extract scalar from array-like or return as-is."""
85
+ if hasattr(value, "item"):
86
+ return value.item()
87
+ elif hasattr(value, "__len__") and len(value) == 1:
88
+ return value[0]
89
+ return value
90
+
91
+
92
def pathfinder_result_to_xarray(
    result: PathfinderResult,
    model: pm.Model | None = None,
) -> xr.Dataset:
    """
    Convert a single-path PathfinderResult into an xarray Dataset.

    Parameters
    ----------
    result : PathfinderResult
        Result of one pathfinder run.
    model : pm.Model | None
        Model used to label the ``param`` coordinate; numeric labels are used
        when no model is given.

    Returns
    -------
    xr.Dataset
        Dataset with scalar diagnostics, status codes/names, and — when
        available — a representative final sample and logP/logQ summaries.

    Examples
    --------
    >>> import pymc as pm
    >>> import pymc_extras as pmx
    >>>
    >>> with pm.Model() as model:
    ...     x = pm.Normal("x", 0, 1)
    ...     y = pm.Normal("y", x, 1, observed=2.0)
    >>> # Assuming we have a PathfinderResult from a pathfinder run
    >>> ds = pathfinder_result_to_xarray(result, model=model)
    >>> print(ds.data_vars)  # Shows lbfgs_niter, elbo_argmax, status info, etc.
    >>> print(ds.attrs)  # Shows metadata like lbfgs_status, path_status
    """
    data_vars: dict = {}
    coords: dict = {}
    attrs: dict = {}

    # Determine the parameter count: prefer the sample array's trailing axis,
    # otherwise fall back to raveling the model's initial point.
    n_params = None
    if result.samples is not None:
        n_params = result.samples.shape[-1]
    elif hasattr(result, "lbfgs_niter") and result.lbfgs_niter is not None:
        if model is not None:
            try:
                n_params = len(DictToArrayBijection.map(model.initial_point()).data)
            except Exception:
                # Best effort only; the Dataset is still useful without params.
                pass

    if n_params is not None:
        coords["param"] = get_param_coords(model, n_params)

    if result.lbfgs_niter is not None:
        data_vars["lbfgs_niter"] = xr.DataArray(_extract_scalar(result.lbfgs_niter))
    if result.elbo_argmax is not None:
        data_vars["elbo_argmax"] = xr.DataArray(_extract_scalar(result.elbo_argmax))

    # Store both the numeric code and the readable name for each status enum.
    data_vars["lbfgs_status_code"] = xr.DataArray(result.lbfgs_status.value)
    data_vars["lbfgs_status_name"] = xr.DataArray(result.lbfgs_status.name)
    data_vars["path_status_code"] = xr.DataArray(result.path_status.value)
    data_vars["path_status_name"] = xr.DataArray(result.path_status.name)

    if n_params is not None and result.samples is not None and result.samples.ndim >= 2:
        # Last draw of the first chain serves as a representative sample.
        data_vars["final_sample"] = xr.DataArray(
            result.samples[0, -1, :], dims=["param"], coords={"param": coords["param"]}
        )

    # Summarize logP and logQ identically; skip empty or missing arrays.
    for label, values in (("logP", result.logP), ("logQ", result.logQ)):
        if values is None:
            continue
        flat = values.flatten() if hasattr(values, "flatten") else values
        if hasattr(flat, "__len__") and len(flat) > 0:
            data_vars[f"{label}_mean"] = xr.DataArray(np.mean(flat))
            data_vars[f"{label}_std"] = xr.DataArray(np.std(flat))
            data_vars[f"{label}_max"] = xr.DataArray(np.max(flat))

    attrs["lbfgs_status"] = result.lbfgs_status.name
    attrs["path_status"] = result.path_status.name

    return xr.Dataset(data_vars, coords=coords, attrs=attrs)
180
+
181
+
182
def multipathfinder_result_to_xarray(
    result: MultiPathfinderResult,
    model: pm.Model | None = None,
    *,
    store_diagnostics: bool = False,
) -> xr.Dataset:
    """
    Consolidate a MultiPathfinderResult into a single xarray Dataset.

    Summary statistics live at the top level; per-path arrays carry a
    ``paths/`` prefix, configuration a ``config/`` prefix, and (optionally)
    full diagnostic arrays a ``diagnostics/`` prefix.

    Parameters
    ----------
    result : MultiPathfinderResult
        Multi-path pathfinder result
    model : pm.Model | None
        PyMC model for parameter name extraction
    store_diagnostics : bool
        Whether to include potentially large diagnostic arrays

    Returns
    -------
    xr.Dataset
        Single consolidated dataset with all pathfinder results

    Examples
    --------
    >>> import pymc as pm
    >>> import pymc_extras as pmx
    >>>
    >>> with pm.Model() as model:
    ...     x = pm.Normal("x", 0, 1)
    ...
    >>> # Assuming we have a MultiPathfinderResult from multiple pathfinder runs
    >>> ds = multipathfinder_result_to_xarray(result, model=model)
    >>> print("Per-path:", [k for k in ds.data_vars.keys() if k.startswith("paths/")])
    >>> print("Config:", [k for k in ds.data_vars.keys() if k.startswith("config/")])
    """
    data_vars: dict = {}
    coords: dict = {}
    attrs: dict = {}

    # Label the parameter axis when samples reveal the parameter count.
    n_params = None if result.samples is None else result.samples.shape[-1]
    param_coords = None
    if n_params is not None:
        param_coords = get_param_coords(model, n_params)
        coords["param"] = param_coords

    # Top-level summary statistics.
    _add_summary_data(result, data_vars, coords, attrs)

    # Per-path data ('paths/' prefix) — only meaningful when something ran.
    if not result.all_paths_failed and result.samples is not None:
        _add_paths_data(result, data_vars, coords, param_coords, n_params)

    # Configuration ('config/' prefix).
    if result.pathfinder_config is not None:
        _add_config_data(result.pathfinder_config, data_vars)

    # Optional large diagnostic arrays ('diagnostics/' prefix).
    if store_diagnostics:
        _add_diagnostics_data(result, data_vars, coords, param_coords)

    return xr.Dataset(data_vars, coords=coords, attrs=attrs)
254
+
255
+
256
def _add_summary_data(
    result: MultiPathfinderResult, data_vars: dict, coords: dict, attrs: dict
) -> None:
    """Add summary-level statistics to the pathfinder dataset.

    Mutates ``data_vars`` and ``attrs`` in place; ``coords`` is accepted for
    signature symmetry with the other ``_add_*`` helpers but is not used here.
    """
    if result.num_paths is not None:
        data_vars["num_paths"] = xr.DataArray(result.num_paths)
    if result.num_draws is not None:
        data_vars["num_draws"] = xr.DataArray(result.num_draws)

    if result.compile_time is not None:
        data_vars["compile_time"] = xr.DataArray(result.compile_time)
    if result.compute_time is not None:
        data_vars["compute_time"] = xr.DataArray(result.compute_time)
    # BUG FIX: the original guarded total_time only on compile_time, raising
    # TypeError when compute_time was None; require both to be present.
    if result.compile_time is not None and result.compute_time is not None:
        data_vars["total_time"] = xr.DataArray(result.compile_time + result.compute_time)

    data_vars["importance_sampling_method"] = xr.DataArray(result.importance_sampling or "none")
    if result.pareto_k is not None:
        data_vars["pareto_k"] = xr.DataArray(result.pareto_k)

    # Status fields are Counters here; render them as dense per-status counts.
    if result.lbfgs_status:
        data_vars["lbfgs_status_counts"] = _status_counter_to_dataarray(
            result.lbfgs_status, LBFGSStatus
        )
    if result.path_status:
        data_vars["path_status_counts"] = _status_counter_to_dataarray(
            result.path_status, PathStatus
        )

    data_vars["all_paths_failed"] = xr.DataArray(result.all_paths_failed)
    if not result.all_paths_failed and result.samples is not None:
        data_vars["num_successful_paths"] = xr.DataArray(result.samples.shape[0])

    if result.lbfgs_niter is not None:
        data_vars["lbfgs_niter_mean"] = xr.DataArray(np.mean(result.lbfgs_niter))
        data_vars["lbfgs_niter_std"] = xr.DataArray(np.std(result.lbfgs_niter))

    if result.elbo_argmax is not None:
        data_vars["elbo_argmax_mean"] = xr.DataArray(np.mean(result.elbo_argmax))
        data_vars["elbo_argmax_std"] = xr.DataArray(np.std(result.elbo_argmax))

    if result.logP is not None:
        data_vars["logP_mean"] = xr.DataArray(np.mean(result.logP))
        data_vars["logP_std"] = xr.DataArray(np.std(result.logP))
        data_vars["logP_max"] = xr.DataArray(np.max(result.logP))

    if result.logQ is not None:
        data_vars["logQ_mean"] = xr.DataArray(np.mean(result.logQ))
        data_vars["logQ_std"] = xr.DataArray(np.std(result.logQ))
        data_vars["logQ_max"] = xr.DataArray(np.max(result.logQ))

    # Preserve any accumulated warnings as Dataset metadata.
    if result.warnings:
        attrs["warnings"] = list(result.warnings)
310
+
311
+
312
def _add_paths_data(
    result: MultiPathfinderResult,
    data_vars: dict,
    coords: dict,
    param_coords: list[str] | None,
    n_params: int | None,
) -> None:
    """Add per-path diagnostics to the pathfinder dataset with 'paths/' prefix.

    Mutates ``data_vars`` and ``coords`` in place, registering the ``path``
    coordinate and storing one value per pathfinder path.
    """
    n_paths = _determine_num_paths(result)
    coords["path"] = list(range(n_paths))
    path_coord = {"path": coords["path"]}

    def _store(name: str, values) -> None:
        # Skip absent diagnostics; store the rest as 1-D per-path arrays.
        if values is not None:
            data_vars[f"paths/{name}"] = xr.DataArray(values, dims=["path"], coords=path_coord)

    _store("lbfgs_niter", result.lbfgs_niter)
    _store("elbo_argmax", result.elbo_argmax)

    # Per-path mean/max over draws for the log densities.
    if result.logP is not None:
        _store("logP_mean", np.mean(result.logP, axis=1))
        _store("logP_max", np.max(result.logP, axis=1))
    if result.logQ is not None:
        _store("logQ_mean", np.mean(result.logQ, axis=1))
        _store("logQ_max", np.max(result.logQ, axis=1))

    if n_params is not None and result.samples is not None and result.samples.ndim >= 3:
        # Final draw from each path: shape (n_paths, n_params).
        data_vars["paths/final_sample"] = xr.DataArray(
            result.samples[:, -1, :],
            dims=["path", "param"],
            coords={"path": coords["path"], "param": coords["param"]},
        )
350
+
351
+
352
def _add_config_data(config: PathfinderConfig, data_vars: dict) -> None:
    """Add configuration parameters to the pathfinder dataset with 'config/' prefix.

    Each dataclass field of ``config`` becomes a scalar ``config/<name>`` entry.
    """
    for name, setting in asdict(config).items():
        data_vars[f"config/{name}"] = xr.DataArray(setting)
357
+
358
+
359
def _add_diagnostics_data(
    result: MultiPathfinderResult, data_vars: dict, coords: dict, param_coords: list[str] | None
) -> None:
    """Add detailed diagnostics to the pathfinder dataset with 'diagnostics/' prefix.

    Stores the full per-draw logP/logQ arrays and (when 3-D) the full sample
    array. Mutates ``data_vars`` and ``coords`` in place; the ``path`` and
    ``draw_per_path`` coordinates are registered lazily by whichever branch
    runs first, so the branch order below matters.
    """
    if result.logP is not None:
        # logP is (n_paths, n_draws_per_path); register both coordinates.
        n_paths, n_draws_per_path = result.logP.shape
        if "path" not in coords:
            coords["path"] = list(range(n_paths))
        coords["draw_per_path"] = list(range(n_draws_per_path))

        data_vars["diagnostics/logP_full"] = xr.DataArray(
            result.logP,
            dims=["path", "draw_per_path"],
            coords={"path": coords["path"], "draw_per_path": coords["draw_per_path"]},
        )

    if result.logQ is not None:
        # Only derive the coordinates from logQ if logP did not already do so.
        if "draw_per_path" not in coords:
            n_paths, n_draws_per_path = result.logQ.shape
            if "path" not in coords:
                coords["path"] = list(range(n_paths))
            coords["draw_per_path"] = list(range(n_draws_per_path))

        data_vars["diagnostics/logQ_full"] = xr.DataArray(
            result.logQ,
            dims=["path", "draw_per_path"],
            coords={"path": coords["path"], "draw_per_path": coords["draw_per_path"]},
        )

    # Full samples are only stored when per-path structure is intact (ndim == 3,
    # i.e. not collapsed by importance sampling) and parameter labels exist.
    if result.samples is not None and result.samples.ndim == 3 and param_coords is not None:
        n_paths, n_draws_per_path, n_params = result.samples.shape

        if "path" not in coords:
            coords["path"] = list(range(n_paths))
        if "draw_per_path" not in coords:
            coords["draw_per_path"] = list(range(n_draws_per_path))

        data_vars["diagnostics/samples_full"] = xr.DataArray(
            result.samples,
            dims=["path", "draw_per_path", "param"],
            coords={
                "path": coords["path"],
                "draw_per_path": coords["draw_per_path"],
                "param": coords["param"],
            },
        )
405
+
406
+
407
+ def _determine_num_paths(result: MultiPathfinderResult) -> int:
408
+ """
409
+ Determine the number of paths from per-path arrays.
410
+
411
+ When importance sampling is applied, result.samples may be collapsed,
412
+ so we use per-path diagnostic arrays to determine the true path count.
413
+ """
414
+ if result.lbfgs_niter is not None:
415
+ return len(result.lbfgs_niter)
416
+ elif result.elbo_argmax is not None:
417
+ return len(result.elbo_argmax)
418
+ elif result.logP is not None:
419
+ return result.logP.shape[0]
420
+ elif result.logQ is not None:
421
+ return result.logQ.shape[0]
422
+
423
+ if result.lbfgs_status:
424
+ return sum(result.lbfgs_status.values())
425
+ elif result.path_status:
426
+ return sum(result.path_status.values())
427
+
428
+ if result.samples is not None:
429
+ return result.samples.shape[0]
430
+
431
+ raise ValueError("Cannot determine number of paths from result")
432
+
433
+
434
def add_pathfinder_to_inference_data(
    idata: az.InferenceData,
    result: PathfinderResult | MultiPathfinderResult,
    model: pm.Model | None = None,
    *,
    group: str = "pathfinder",
    paths_group: str = "pathfinder_paths",  # Deprecated, kept for API compatibility
    diagnostics_group: str = "pathfinder_diagnostics",  # Deprecated, kept for API compatibility
    config_group: str = "pathfinder_config",  # Deprecated, kept for API compatibility
    store_diagnostics: bool = False,
) -> az.InferenceData:
    """
    Add pathfinder results to an ArviZ InferenceData object as one consolidated group.

    All pathfinder output lives under a single group with nested naming:
    summary statistics at the top level, per-path data under a ``paths/``
    prefix, configuration under ``config/``, and (when
    ``store_diagnostics=True``) large arrays under ``diagnostics/``.

    Parameters
    ----------
    idata : az.InferenceData
        InferenceData object to modify
    result : PathfinderResult | MultiPathfinderResult
        Pathfinder results to add
    model : pm.Model | None
        PyMC model for parameter name extraction
    group : str
        Name for the pathfinder group (default: "pathfinder")
    paths_group : str
        Deprecated: no longer used, kept for API compatibility
    diagnostics_group : str
        Deprecated: no longer used, kept for API compatibility
    config_group : str
        Deprecated: no longer used, kept for API compatibility
    store_diagnostics : bool
        Whether to include potentially large diagnostic arrays

    Returns
    -------
    az.InferenceData
        Modified InferenceData object with the consolidated pathfinder group added

    Examples
    --------
    >>> import pymc as pm
    >>> import pymc_extras as pmx
    >>>
    >>> with pm.Model() as model:
    ...     x = pm.Normal("x", 0, 1)
    ...     idata = pmx.fit(method="pathfinder", model=model, add_pathfinder_groups=False)
    >>> # Assuming we have pathfinder results
    >>> idata = add_pathfinder_to_inference_data(idata, results, model=model)
    >>> print(list(idata.groups()))  # Will show ['posterior', 'pathfinder']
    """
    # Decide single- vs multi-path. isinstance() is the primary check; the
    # duck-typing fallback (Counter-like lbfgs_status with a callable
    # .values()) keeps mocks and test doubles working.
    if isinstance(result, MultiPathfinderResult):
        is_multipath = True
    else:
        status = getattr(result, "lbfgs_status", None)
        is_multipath = callable(getattr(status, "values", None))

    if is_multipath:
        consolidated = multipathfinder_result_to_xarray(
            result, model=model, store_diagnostics=store_diagnostics
        )
    else:
        consolidated = pathfinder_result_to_xarray(result, model=model)

    if group in idata.groups():
        warnings.warn(f"Group '{group}' already exists in InferenceData, it will be replaced.")

    idata.add_groups({group: consolidated})
    return idata
@@ -16,12 +16,13 @@
16
16
  import collections
17
17
  import logging
18
18
  import time
19
+ import warnings
19
20
 
20
21
  from collections import Counter
21
22
  from collections.abc import Callable, Iterator
22
23
  from dataclasses import asdict, dataclass, field, replace
23
24
  from enum import Enum, auto
24
- from typing import Literal, TypeAlias
25
+ from typing import Literal, Self, TypeAlias
25
26
 
26
27
  import arviz as az
27
28
  import filelock
@@ -59,9 +60,6 @@ from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainin
59
60
  from rich.table import Table
60
61
  from rich.text import Text
61
62
 
62
- # TODO: change to typing.Self after Python versions greater than 3.10
63
- from typing_extensions import Self
64
-
65
63
  from pymc_extras.inference.laplace_approx.idata import add_data_to_inference_data
66
64
  from pymc_extras.inference.pathfinder.importance_sampling import (
67
65
  importance_sampling as _importance_sampling,
@@ -532,7 +530,7 @@ def bfgs_sample_sparse(
532
530
 
533
531
  # qr_input: (L, N, 2J)
534
532
  qr_input = inv_sqrt_alpha_diag @ beta
535
- (Q, R), _ = pytensor.scan(fn=pt.nlinalg.qr, sequences=[qr_input], allow_gc=False)
533
+ (Q, R), _ = pytensor.scan(fn=pt.linalg.qr, sequences=[qr_input], allow_gc=False)
536
534
 
537
535
  IdN = pt.eye(R.shape[1])[None, ...]
538
536
  IdN += IdN * REGULARISATION_TERM
@@ -1398,6 +1396,7 @@ def multipath_pathfinder(
1398
1396
  random_seed: RandomSeed,
1399
1397
  pathfinder_kwargs: dict = {},
1400
1398
  compile_kwargs: dict = {},
1399
+ display_summary: bool = True,
1401
1400
  ) -> MultiPathfinderResult:
1402
1401
  """
1403
1402
  Fit the Pathfinder Variational Inference algorithm using multiple paths with PyMC/PyTensor backend.
@@ -1556,8 +1555,9 @@ def multipath_pathfinder(
1556
1555
  compute_time=compute_end - compute_start,
1557
1556
  )
1558
1557
  )
1559
- # TODO: option to disable summary, save to file, etc.
1560
- mpr.display_summary()
1558
+ # Display summary conditionally
1559
+ if display_summary:
1560
+ mpr.display_summary()
1561
1561
  if mpr.all_paths_failed:
1562
1562
  raise ValueError(
1563
1563
  "All paths failed. Consider decreasing the jitter or reparameterizing the model."
@@ -1600,6 +1600,14 @@ def fit_pathfinder(
1600
1600
  pathfinder_kwargs: dict = {},
1601
1601
  compile_kwargs: dict = {},
1602
1602
  initvals: dict | None = None,
1603
+ # New pathfinder result integration options
1604
+ add_pathfinder_groups: bool = True,
1605
+ display_summary: bool | Literal["auto"] = "auto",
1606
+ store_diagnostics: bool = False,
1607
+ pathfinder_group: str = "pathfinder",
1608
+ paths_group: str = "pathfinder_paths",
1609
+ diagnostics_group: str = "pathfinder_diagnostics",
1610
+ config_group: str = "pathfinder_config",
1603
1611
  ) -> az.InferenceData:
1604
1612
  """
1605
1613
  Fit the Pathfinder Variational Inference algorithm.
@@ -1658,6 +1666,22 @@ def fit_pathfinder(
1658
1666
  initvals: dict | None = None
1659
1667
  Initial values for the model parameters, as str:ndarray key-value pairs. Partial initialization is permitted.
1660
1668
  If None, the model's default initial values are used.
1669
+ add_pathfinder_groups : bool, optional
1670
+ Whether to add pathfinder results as additional groups to the InferenceData (default is True).
1671
+ When True, adds pathfinder and pathfinder_paths groups with optimization diagnostics.
1672
+ display_summary : bool or "auto", optional
1673
+ Whether to display the pathfinder results summary (default is "auto").
1674
+ "auto" preserves current behavior, False suppresses console output.
1675
+ store_diagnostics : bool, optional
1676
+ Whether to include potentially large diagnostic arrays in the pathfinder groups (default is False).
1677
+ pathfinder_group : str, optional
1678
+ Name for the main pathfinder results group (default is "pathfinder").
1679
+ paths_group : str, optional
1680
+ Name for the per-path results group (default is "pathfinder_paths").
1681
+ diagnostics_group : str, optional
1682
+ Name for the diagnostics group (default is "pathfinder_diagnostics").
1683
+ config_group : str, optional
1684
+ Name for the configuration group (default is "pathfinder_config").
1661
1685
 
1662
1686
  Returns
1663
1687
  -------
@@ -1694,6 +1718,9 @@ def fit_pathfinder(
1694
1718
  maxcor = np.ceil(3 * np.log(N)).astype(np.int32)
1695
1719
  maxcor = max(maxcor, 5)
1696
1720
 
1721
+ # Handle display_summary logic
1722
+ should_display_summary = display_summary == "auto" or display_summary is True
1723
+
1697
1724
  if inference_backend == "pymc":
1698
1725
  mp_result = multipath_pathfinder(
1699
1726
  model,
@@ -1714,6 +1741,7 @@ def fit_pathfinder(
1714
1741
  random_seed=random_seed,
1715
1742
  pathfinder_kwargs=pathfinder_kwargs,
1716
1743
  compile_kwargs=compile_kwargs,
1744
+ display_summary=should_display_summary,
1717
1745
  )
1718
1746
  pathfinder_samples = mp_result.samples
1719
1747
  elif inference_backend == "blackjax":
@@ -1760,4 +1788,30 @@ def fit_pathfinder(
1760
1788
 
1761
1789
  idata = add_data_to_inference_data(idata, progressbar, model, compile_kwargs)
1762
1790
 
1791
+ # Add pathfinder results to InferenceData if requested
1792
+ if add_pathfinder_groups:
1793
+ if inference_backend == "pymc":
1794
+ from pymc_extras.inference.pathfinder.idata import add_pathfinder_to_inference_data
1795
+
1796
+ idata = add_pathfinder_to_inference_data(
1797
+ idata=idata,
1798
+ result=mp_result,
1799
+ model=model,
1800
+ group=pathfinder_group,
1801
+ paths_group=paths_group,
1802
+ diagnostics_group=diagnostics_group,
1803
+ config_group=config_group,
1804
+ store_diagnostics=store_diagnostics,
1805
+ )
1806
+ else:
1807
+ warnings.warn(
1808
+ f"Pathfinder diagnostic groups are only supported with the PyMC backend. "
1809
+ f"Current backend is '{inference_backend}', which does not support adding "
1810
+ "pathfinder diagnostics to InferenceData. The InferenceData will only contain "
1811
+ "posterior samples. To add diagnostic groups, use inference_backend='pymc', "
1812
+ "or set add_pathfinder_groups=False to suppress this warning.",
1813
+ UserWarning,
1814
+ stacklevel=2,
1815
+ )
1816
+
1763
1817
  return idata
@@ -6,8 +6,8 @@ from itertools import zip_longest
6
6
  from pymc import SymbolicRandomVariable
7
7
  from pymc.model.fgraph import ModelVar
8
8
  from pymc.variational.minibatch_rv import MinibatchRandomVariable
9
- from pytensor.graph import Variable, ancestors
10
- from pytensor.graph.basic import io_toposort
9
+ from pytensor.graph.basic import Variable
10
+ from pytensor.graph.traversal import ancestors, io_toposort
11
11
  from pytensor.tensor import TensorType, TensorVariable
12
12
  from pytensor.tensor.blockwise import Blockwise
13
13
  from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise