pymc-extras 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff compares publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registry.
Files changed (38)
  1. pymc_extras/deserialize.py +10 -4
  2. pymc_extras/distributions/continuous.py +1 -1
  3. pymc_extras/distributions/histogram_utils.py +6 -4
  4. pymc_extras/distributions/multivariate/r2d2m2cp.py +4 -3
  5. pymc_extras/distributions/timeseries.py +14 -12
  6. pymc_extras/inference/dadvi/dadvi.py +149 -128
  7. pymc_extras/inference/laplace_approx/find_map.py +16 -39
  8. pymc_extras/inference/laplace_approx/idata.py +22 -4
  9. pymc_extras/inference/laplace_approx/laplace.py +196 -151
  10. pymc_extras/inference/laplace_approx/scipy_interface.py +47 -7
  11. pymc_extras/inference/pathfinder/idata.py +517 -0
  12. pymc_extras/inference/pathfinder/pathfinder.py +71 -12
  13. pymc_extras/inference/smc/sampling.py +2 -2
  14. pymc_extras/model/marginal/distributions.py +4 -2
  15. pymc_extras/model/marginal/graph_analysis.py +2 -2
  16. pymc_extras/model/marginal/marginal_model.py +12 -2
  17. pymc_extras/model_builder.py +9 -4
  18. pymc_extras/prior.py +203 -8
  19. pymc_extras/statespace/core/compile.py +1 -1
  20. pymc_extras/statespace/core/statespace.py +2 -1
  21. pymc_extras/statespace/filters/distributions.py +15 -13
  22. pymc_extras/statespace/filters/kalman_filter.py +24 -22
  23. pymc_extras/statespace/filters/kalman_smoother.py +3 -5
  24. pymc_extras/statespace/filters/utilities.py +2 -5
  25. pymc_extras/statespace/models/DFM.py +12 -27
  26. pymc_extras/statespace/models/ETS.py +190 -198
  27. pymc_extras/statespace/models/SARIMAX.py +5 -17
  28. pymc_extras/statespace/models/VARMAX.py +15 -67
  29. pymc_extras/statespace/models/structural/components/autoregressive.py +4 -4
  30. pymc_extras/statespace/models/structural/components/regression.py +4 -26
  31. pymc_extras/statespace/models/utilities.py +7 -0
  32. pymc_extras/utils/model_equivalence.py +2 -2
  33. pymc_extras/utils/prior.py +10 -14
  34. pymc_extras/utils/spline.py +4 -10
  35. {pymc_extras-0.5.0.dist-info → pymc_extras-0.7.0.dist-info}/METADATA +4 -4
  36. {pymc_extras-0.5.0.dist-info → pymc_extras-0.7.0.dist-info}/RECORD +38 -37
  37. {pymc_extras-0.5.0.dist-info → pymc_extras-0.7.0.dist-info}/WHEEL +1 -1
  38. {pymc_extras-0.5.0.dist-info → pymc_extras-0.7.0.dist-info}/licenses/LICENSE +0 -0
pymc_extras/inference/pathfinder/idata.py (new file)
@@ -0,0 +1,517 @@
+ # Copyright 2022 The PyMC Developers
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #   http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Utilities for converting Pathfinder results to xarray and adding them to InferenceData."""
+
+ from __future__ import annotations
+
+ import warnings
+
+ from dataclasses import asdict
+
+ import arviz as az
+ import numpy as np
+ import pymc as pm
+ import xarray as xr
+
+ from pymc.blocking import DictToArrayBijection
+
+ from pymc_extras.inference.pathfinder.lbfgs import LBFGSStatus
+ from pymc_extras.inference.pathfinder.pathfinder import (
+     MultiPathfinderResult,
+     PathfinderConfig,
+     PathfinderResult,
+     PathStatus,
+ )
+
+
+ def get_param_coords(model: pm.Model | None, n_params: int) -> list[str]:
+     """
+     Get parameter coordinate labels from a PyMC model.
+
+     Parameters
+     ----------
+     model : pm.Model | None
+         PyMC model to extract variable names from. If None, returns numeric indices.
+     n_params : int
+         Number of parameters (for fallback indexing when model is None).
+
+     Returns
+     -------
+     list[str]
+         Parameter coordinate labels.
+     """
+     if model is None:
+         return [str(i) for i in range(n_params)]
+
+     ip = model.initial_point()
+     bij = DictToArrayBijection.map(ip)
+
+     coords = []
+     for var_name, shape, size, _ in bij.point_map_info:
+         if size == 1:
+             coords.append(var_name)
+         else:
+             for i in range(size):
+                 coords.append(f"{var_name}[{i}]")
+     return coords
+
+
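For orientation, here is a minimal sketch of the flattening that get_param_coords performs. The model and shapes are illustrative only; note that initial_point operates on the transformed parameter space, so constrained variables appear under their transformed names (e.g. sigma_log__):

    import pymc as pm

    with pm.Model() as m:
        mu = pm.Normal("mu", 0, 1)
        sigma = pm.HalfNormal("sigma", 1, shape=2)

    # get_param_coords(m, 3) would return labels for the flattened vector:
    #   ["mu", "sigma_log__[0]", "sigma_log__[1]"]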
+ def _status_counter_to_dataarray(counter, status_enum_cls) -> xr.DataArray:
+     """Convert a Counter of status values to a dense xarray DataArray."""
+     all_statuses = list(status_enum_cls)
+     status_names = [s.name for s in all_statuses]
+
+     counts = np.array([counter.get(status, 0) for status in all_statuses])
+
+     return xr.DataArray(
+         counts, dims=["status"], coords={"status": status_names}, name="status_counts"
+     )
+
+
+ def _extract_scalar(value):
+     """Extract scalar from array-like or return as-is."""
+     if hasattr(value, "item"):
+         return value.item()
+     elif hasattr(value, "__len__") and len(value) == 1:
+         return value[0]
+     return value
+
+
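A quick illustration of the densification _status_counter_to_dataarray provides: statuses that never occurred still get an explicit zero count, so datasets from different runs share the same "status" coordinate. The enum below is a stand-in, not the real LBFGSStatus:

    from collections import Counter
    from enum import Enum

    class DemoStatus(Enum):  # stand-in for LBFGSStatus / PathStatus
        CONVERGED = 1
        DIVERGED = 2
        MAX_ITER = 3

    counter = Counter({DemoStatus.CONVERGED: 3, DemoStatus.MAX_ITER: 1})
    da = _status_counter_to_dataarray(counter, DemoStatus)
    # da.values -> [3, 0, 1], coords: status = ["CONVERGED", "DIVERGED", "MAX_ITER"]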
+ def pathfinder_result_to_xarray(
+     result: PathfinderResult,
+     model: pm.Model | None = None,
+ ) -> xr.Dataset:
+     """
+     Convert a PathfinderResult to an xarray Dataset.
+
+     Parameters
+     ----------
+     result : PathfinderResult
+         Single pathfinder run result.
+     model : pm.Model | None
+         PyMC model for parameter name extraction.
+
+     Returns
+     -------
+     xr.Dataset
+         Dataset with pathfinder results.
+
+     Examples
+     --------
+     >>> import pymc as pm
+     >>> import pymc_extras as pmx
+     >>>
+     >>> with pm.Model() as model:
+     ...     x = pm.Normal("x", 0, 1)
+     ...     y = pm.Normal("y", x, 1, observed=2.0)
+     >>> # Assuming we have a PathfinderResult from a pathfinder run
+     >>> ds = pathfinder_result_to_xarray(result, model=model)
+     >>> print(ds.data_vars)  # Shows lbfgs_niter, elbo_argmax, status info, etc.
+     >>> print(ds.attrs)  # Shows metadata like lbfgs_status, path_status
+     """
+     data_vars = {}
+     coords = {}
+     attrs = {}
+
+     n_params = None
+     if result.samples is not None:
+         n_params = result.samples.shape[-1]
+     elif hasattr(result, "lbfgs_niter") and result.lbfgs_niter is not None:
+         if model is not None:
+             try:
+                 ip = model.initial_point()
+                 n_params = len(DictToArrayBijection.map(ip).data)
+             except Exception:
+                 pass
+
+     if n_params is not None:
+         coords["param"] = get_param_coords(model, n_params)
+
+     if result.lbfgs_niter is not None:
+         data_vars["lbfgs_niter"] = xr.DataArray(_extract_scalar(result.lbfgs_niter))
+
+     if result.elbo_argmax is not None:
+         data_vars["elbo_argmax"] = xr.DataArray(_extract_scalar(result.elbo_argmax))
+
+     data_vars["lbfgs_status_code"] = xr.DataArray(result.lbfgs_status.value)
+     data_vars["lbfgs_status_name"] = xr.DataArray(result.lbfgs_status.name)
+     data_vars["path_status_code"] = xr.DataArray(result.path_status.value)
+     data_vars["path_status_name"] = xr.DataArray(result.path_status.name)
+
+     if n_params is not None and result.samples is not None:
+         # samples are (paths, draws, params); the [0, -1, :] indexing needs ndim >= 3
+         if result.samples.ndim >= 3:
+             representative_sample = result.samples[0, -1, :]
+             data_vars["final_sample"] = xr.DataArray(
+                 representative_sample, dims=["param"], coords={"param": coords["param"]}
+             )
+
+     if result.logP is not None:
+         logP = result.logP.flatten() if hasattr(result.logP, "flatten") else result.logP
+         if hasattr(logP, "__len__") and len(logP) > 0:
+             data_vars["logP_mean"] = xr.DataArray(np.mean(logP))
+             data_vars["logP_std"] = xr.DataArray(np.std(logP))
+             data_vars["logP_max"] = xr.DataArray(np.max(logP))
+
+     if result.logQ is not None:
+         logQ = result.logQ.flatten() if hasattr(result.logQ, "flatten") else result.logQ
+         if hasattr(logQ, "__len__") and len(logQ) > 0:
+             data_vars["logQ_mean"] = xr.DataArray(np.mean(logQ))
+             data_vars["logQ_std"] = xr.DataArray(np.std(logQ))
+             data_vars["logQ_max"] = xr.DataArray(np.max(logQ))
+
+     attrs["lbfgs_status"] = result.lbfgs_status.name
+     attrs["path_status"] = result.path_status.name
+
+     ds = xr.Dataset(data_vars, coords=coords, attrs=attrs)
+
+     return ds
+
+
+ def multipathfinder_result_to_xarray(
+     result: MultiPathfinderResult,
+     model: pm.Model | None = None,
+     *,
+     store_diagnostics: bool = False,
+ ) -> xr.Dataset:
+     """
+     Convert a MultiPathfinderResult to a single consolidated xarray Dataset.
+
+     Parameters
+     ----------
+     result : MultiPathfinderResult
+         Multi-path pathfinder result.
+     model : pm.Model | None
+         PyMC model for parameter name extraction.
+     store_diagnostics : bool
+         Whether to include potentially large diagnostic arrays.
+
+     Returns
+     -------
+     xr.Dataset
+         Single consolidated dataset with all pathfinder results.
+
+     Examples
+     --------
+     >>> import pymc as pm
+     >>> import pymc_extras as pmx
+     >>>
+     >>> with pm.Model() as model:
+     ...     x = pm.Normal("x", 0, 1)
+     ...
+     >>> # Assuming we have a MultiPathfinderResult from multiple pathfinder runs
+     >>> ds = multipathfinder_result_to_xarray(result, model=model)
+     >>> print("All data:", ds.data_vars)
+     >>> print(
+     ...     "Summary:",
+     ...     [
+     ...         k
+     ...         for k in ds.data_vars.keys()
+     ...         if not k.startswith(("paths/", "config/", "diagnostics/"))
+     ...     ],
+     ... )
+     >>> print("Per-path:", [k for k in ds.data_vars.keys() if k.startswith("paths/")])
+     >>> print("Config:", [k for k in ds.data_vars.keys() if k.startswith("config/")])
+     """
+     n_params = result.samples.shape[-1] if result.samples is not None else None
+     param_coords = get_param_coords(model, n_params) if n_params is not None else None
+
+     data_vars = {}
+     coords = {}
+     attrs = {}
+
+     # Add parameter coordinates if available
+     if param_coords is not None:
+         coords["param"] = param_coords
+
+     # Build summary-level data (top level)
+     _add_summary_data(result, data_vars, coords, attrs)
+
+     # Build per-path data (with paths/ prefix)
+     if not result.all_paths_failed and result.samples is not None:
+         _add_paths_data(result, data_vars, coords, param_coords, n_params)
+
+     # Build configuration data (with config/ prefix)
+     if result.pathfinder_config is not None:
+         _add_config_data(result.pathfinder_config, data_vars)
+
+     # Build diagnostics data (with diagnostics/ prefix) if requested
+     if store_diagnostics:
+         _add_diagnostics_data(result, data_vars, coords, param_coords)
+
+     return xr.Dataset(data_vars, coords=coords, attrs=attrs)
+
+
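The consolidated layout keeps everything in one flat Dataset and encodes the hierarchy in slash-prefixed variable names rather than separate InferenceData groups. A sketch of what access looks like (which variables exist depends on the fields the result actually carries):

    ds = multipathfinder_result_to_xarray(result, model=model)

    ds["num_paths"]          # top-level summary scalar
    ds["paths/elbo_argmax"]  # per-path array with a "path" dimension
    ds["config/num_draws"]   # a PathfinderConfig field, if a config was stored
    summary = {k: v for k, v in ds.data_vars.items() if "/" not in k}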
+ def _add_summary_data(
+     result: MultiPathfinderResult, data_vars: dict, coords: dict, attrs: dict
+ ) -> None:
+     """Add summary-level statistics to the pathfinder dataset."""
+     if result.num_paths is not None:
+         data_vars["num_paths"] = xr.DataArray(result.num_paths)
+     if result.num_draws is not None:
+         data_vars["num_draws"] = xr.DataArray(result.num_draws)
+
+     if result.compile_time is not None:
+         data_vars["compile_time"] = xr.DataArray(result.compile_time)
+     if result.compute_time is not None:
+         data_vars["compute_time"] = xr.DataArray(result.compute_time)
+     # Only report a total when both components are available
+     if result.compile_time is not None and result.compute_time is not None:
+         data_vars["total_time"] = xr.DataArray(result.compile_time + result.compute_time)
+
+     data_vars["importance_sampling_method"] = xr.DataArray(result.importance_sampling or "none")
+     if result.pareto_k is not None:
+         data_vars["pareto_k"] = xr.DataArray(result.pareto_k)
+
+     if result.lbfgs_status:
+         data_vars["lbfgs_status_counts"] = _status_counter_to_dataarray(
+             result.lbfgs_status, LBFGSStatus
+         )
+     if result.path_status:
+         data_vars["path_status_counts"] = _status_counter_to_dataarray(
+             result.path_status, PathStatus
+         )
+
+     data_vars["all_paths_failed"] = xr.DataArray(result.all_paths_failed)
+     if not result.all_paths_failed and result.samples is not None:
+         data_vars["num_successful_paths"] = xr.DataArray(result.samples.shape[0])
+
+     if result.lbfgs_niter is not None:
+         data_vars["lbfgs_niter_mean"] = xr.DataArray(np.mean(result.lbfgs_niter))
+         data_vars["lbfgs_niter_std"] = xr.DataArray(np.std(result.lbfgs_niter))
+
+     if result.elbo_argmax is not None:
+         data_vars["elbo_argmax_mean"] = xr.DataArray(np.mean(result.elbo_argmax))
+         data_vars["elbo_argmax_std"] = xr.DataArray(np.std(result.elbo_argmax))
+
+     if result.logP is not None:
+         data_vars["logP_mean"] = xr.DataArray(np.mean(result.logP))
+         data_vars["logP_std"] = xr.DataArray(np.std(result.logP))
+         data_vars["logP_max"] = xr.DataArray(np.max(result.logP))
+
+     if result.logQ is not None:
+         data_vars["logQ_mean"] = xr.DataArray(np.mean(result.logQ))
+         data_vars["logQ_std"] = xr.DataArray(np.std(result.logQ))
+         data_vars["logQ_max"] = xr.DataArray(np.max(result.logQ))
+
+     # Add warnings to attributes
+     if result.warnings:
+         attrs["warnings"] = list(result.warnings)
+
+
+ def _add_paths_data(
+     result: MultiPathfinderResult,
+     data_vars: dict,
+     coords: dict,
+     param_coords: list[str] | None,
+     n_params: int | None,
+ ) -> None:
+     """Add per-path diagnostics to the pathfinder dataset with 'paths/' prefix."""
+     n_paths = _determine_num_paths(result)
+
+     # Add path coordinate
+     coords["path"] = list(range(n_paths))
+
+     def _add_path_scalar(name: str, data):
+         """Add a per-path scalar array to data_vars with paths/ prefix."""
+         if data is not None:
+             data_vars[f"paths/{name}"] = xr.DataArray(
+                 data, dims=["path"], coords={"path": coords["path"]}
+             )
+
+     _add_path_scalar("lbfgs_niter", result.lbfgs_niter)
+     _add_path_scalar("elbo_argmax", result.elbo_argmax)
+
+     if result.logP is not None:
+         _add_path_scalar("logP_mean", np.mean(result.logP, axis=1))
+         _add_path_scalar("logP_max", np.max(result.logP, axis=1))
+
+     if result.logQ is not None:
+         _add_path_scalar("logQ_mean", np.mean(result.logQ, axis=1))
+         _add_path_scalar("logQ_max", np.max(result.logQ, axis=1))
+
+     if n_params is not None and result.samples is not None and result.samples.ndim >= 3:
+         final_samples = result.samples[:, -1, :]  # (S, N)
+         data_vars["paths/final_sample"] = xr.DataArray(
+             final_samples,
+             dims=["path", "param"],
+             coords={"path": coords["path"], "param": coords["param"]},
+         )
+
+
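A small shape sketch of the final_sample extraction above, with illustrative sizes:

    import numpy as np

    samples = np.zeros((4, 1000, 3))  # (paths S, draws per path M, params N)
    final = samples[:, -1, :]         # last draw of each path -> shape (4, 3)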
+ def _add_config_data(config: PathfinderConfig, data_vars: dict) -> None:
+     """Add configuration parameters to the pathfinder dataset with 'config/' prefix."""
+     config_dict = asdict(config)
+     for key, value in config_dict.items():
+         data_vars[f"config/{key}"] = xr.DataArray(value)
+
+
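asdict flattens the config dataclass into one scalar variable per field. A sketch with a stand-in dataclass (the real PathfinderConfig is defined in pathfinder.py):

    from dataclasses import asdict, dataclass

    @dataclass
    class DemoConfig:  # stand-in for PathfinderConfig
        num_draws: int = 1000
        maxcor: int = 5

    # _add_config_data(DemoConfig(), data_vars) adds:
    #   data_vars["config/num_draws"] = xr.DataArray(1000)
    #   data_vars["config/maxcor"] = xr.DataArray(5)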
+ def _add_diagnostics_data(
+     result: MultiPathfinderResult, data_vars: dict, coords: dict, param_coords: list[str] | None
+ ) -> None:
+     """Add detailed diagnostics to the pathfinder dataset with 'diagnostics/' prefix."""
+     if result.logP is not None:
+         n_paths, n_draws_per_path = result.logP.shape
+         if "path" not in coords:
+             coords["path"] = list(range(n_paths))
+         coords["draw_per_path"] = list(range(n_draws_per_path))
+
+         data_vars["diagnostics/logP_full"] = xr.DataArray(
+             result.logP,
+             dims=["path", "draw_per_path"],
+             coords={"path": coords["path"], "draw_per_path": coords["draw_per_path"]},
+         )
+
+     if result.logQ is not None:
+         if "draw_per_path" not in coords:
+             n_paths, n_draws_per_path = result.logQ.shape
+             if "path" not in coords:
+                 coords["path"] = list(range(n_paths))
+             coords["draw_per_path"] = list(range(n_draws_per_path))
+
+         data_vars["diagnostics/logQ_full"] = xr.DataArray(
+             result.logQ,
+             dims=["path", "draw_per_path"],
+             coords={"path": coords["path"], "draw_per_path": coords["draw_per_path"]},
+         )
+
+     if result.samples is not None and result.samples.ndim == 3 and param_coords is not None:
+         n_paths, n_draws_per_path, n_params = result.samples.shape
+
+         if "path" not in coords:
+             coords["path"] = list(range(n_paths))
+         if "draw_per_path" not in coords:
+             coords["draw_per_path"] = list(range(n_draws_per_path))
+
+         data_vars["diagnostics/samples_full"] = xr.DataArray(
+             result.samples,
+             dims=["path", "draw_per_path", "param"],
+             coords={
+                 "path": coords["path"],
+                 "draw_per_path": coords["draw_per_path"],
+                 "param": coords["param"],
+             },
+         )
+
+
+ def _determine_num_paths(result: MultiPathfinderResult) -> int:
+     """
+     Determine the number of paths from per-path arrays.
+
+     When importance sampling is applied, result.samples may be collapsed,
+     so we use per-path diagnostic arrays to determine the true path count.
+     """
+     if result.lbfgs_niter is not None:
+         return len(result.lbfgs_niter)
+     elif result.elbo_argmax is not None:
+         return len(result.elbo_argmax)
+     elif result.logP is not None:
+         return result.logP.shape[0]
+     elif result.logQ is not None:
+         return result.logQ.shape[0]
+
+     if result.lbfgs_status:
+         return sum(result.lbfgs_status.values())
+     elif result.path_status:
+         return sum(result.path_status.values())
+
+     if result.samples is not None:
+         return result.samples.shape[0]
+
+     raise ValueError("Cannot determine number of paths from result")
+
+
+ def add_pathfinder_to_inference_data(
+     idata: az.InferenceData,
+     result: PathfinderResult | MultiPathfinderResult,
+     model: pm.Model | None = None,
+     *,
+     group: str = "pathfinder",
+     paths_group: str = "pathfinder_paths",  # Deprecated, kept for API compatibility
+     diagnostics_group: str = "pathfinder_diagnostics",  # Deprecated, kept for API compatibility
+     config_group: str = "pathfinder_config",  # Deprecated, kept for API compatibility
+     store_diagnostics: bool = False,
+ ) -> az.InferenceData:
+     """
+     Add pathfinder results to an ArviZ InferenceData object as a single consolidated group.
+
+     All pathfinder output is now consolidated under a single group with nested structure:
+
+     - Summary statistics at the top level
+     - Per-path data with 'paths/' prefix
+     - Configuration with 'config/' prefix
+     - Diagnostics with 'diagnostics/' prefix (if store_diagnostics=True)
+
+     Parameters
+     ----------
+     idata : az.InferenceData
+         InferenceData object to modify.
+     result : PathfinderResult | MultiPathfinderResult
+         Pathfinder results to add.
+     model : pm.Model | None
+         PyMC model for parameter name extraction.
+     group : str
+         Name for the pathfinder group (default: "pathfinder").
+     paths_group : str
+         Deprecated: no longer used, kept for API compatibility.
+     diagnostics_group : str
+         Deprecated: no longer used, kept for API compatibility.
+     config_group : str
+         Deprecated: no longer used, kept for API compatibility.
+     store_diagnostics : bool
+         Whether to include potentially large diagnostic arrays.
+
+     Returns
+     -------
+     az.InferenceData
+         Modified InferenceData object with the consolidated pathfinder group added.
+
+     Examples
+     --------
+     >>> import pymc as pm
+     >>> import pymc_extras as pmx
+     >>>
+     >>> with pm.Model() as model:
+     ...     x = pm.Normal("x", 0, 1)
+     ...     idata = pmx.fit(method="pathfinder", model=model, add_pathfinder_groups=False)
+     >>> # Assuming we have pathfinder results
+     >>> idata = add_pathfinder_to_inference_data(idata, results, model=model)
+     >>> print(list(idata.groups()))  # Will show ['posterior', 'pathfinder']
+     >>> # Access nested data:
+     >>> print(
+     ...     [k for k in idata.pathfinder.data_vars.keys() if k.startswith("paths/")]
+     ... )  # Per-path data
+     >>> print(
+     ...     [k for k in idata.pathfinder.data_vars.keys() if k.startswith("config/")]
+     ... )  # Config data
+     """
+     # Detect if this is a multi-path result.
+     # Use isinstance() as the primary check, but fall back to duck typing for compatibility
+     # with mocks and testing (MultiPathfinderResult has Counter-type status fields).
+     is_multipath = isinstance(result, MultiPathfinderResult) or (
+         hasattr(result, "lbfgs_status")
+         and hasattr(result.lbfgs_status, "values")
+         and callable(getattr(result.lbfgs_status, "values"))
+     )
+
+     if is_multipath:
+         consolidated_ds = multipathfinder_result_to_xarray(
+             result, model=model, store_diagnostics=store_diagnostics
+         )
+     else:
+         consolidated_ds = pathfinder_result_to_xarray(result, model=model)
+
+     if group in idata.groups():
+         warnings.warn(f"Group '{group}' already exists in InferenceData, it will be replaced.")
+         # az.InferenceData.add_groups does not overwrite existing groups, so drop the old one first
+         delattr(idata, group)
+
+     idata.add_groups({group: consolidated_ds})
+     return idata
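Taken together, a typical end-to-end flow might look like the following sketch. It assumes the pathfinder fit path attaches the consolidated group by default, as the docstrings above suggest; exact variable names depend on the run:

    import pymc as pm
    import pymc_extras as pmx

    with pm.Model() as model:
        x = pm.Normal("x", 0, 1)
        y = pm.Normal("y", x, 1, observed=2.0)
        idata = pmx.fit(method="pathfinder", model=model)

    pf = idata.pathfinder
    print(pf["path_status_counts"])  # how many paths succeeded/failed
    print(pf["paths/elbo_argmax"])   # per-path diagnostics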