pymc-extras 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pymc_extras/deserialize.py +10 -4
- pymc_extras/distributions/continuous.py +1 -1
- pymc_extras/distributions/histogram_utils.py +6 -4
- pymc_extras/distributions/multivariate/r2d2m2cp.py +4 -3
- pymc_extras/distributions/timeseries.py +14 -12
- pymc_extras/inference/dadvi/dadvi.py +149 -128
- pymc_extras/inference/laplace_approx/find_map.py +16 -39
- pymc_extras/inference/laplace_approx/idata.py +22 -4
- pymc_extras/inference/laplace_approx/laplace.py +196 -151
- pymc_extras/inference/laplace_approx/scipy_interface.py +47 -7
- pymc_extras/inference/pathfinder/idata.py +517 -0
- pymc_extras/inference/pathfinder/pathfinder.py +71 -12
- pymc_extras/inference/smc/sampling.py +2 -2
- pymc_extras/model/marginal/distributions.py +4 -2
- pymc_extras/model/marginal/graph_analysis.py +2 -2
- pymc_extras/model/marginal/marginal_model.py +12 -2
- pymc_extras/model_builder.py +9 -4
- pymc_extras/prior.py +203 -8
- pymc_extras/statespace/core/compile.py +1 -1
- pymc_extras/statespace/core/statespace.py +2 -1
- pymc_extras/statespace/filters/distributions.py +15 -13
- pymc_extras/statespace/filters/kalman_filter.py +24 -22
- pymc_extras/statespace/filters/kalman_smoother.py +3 -5
- pymc_extras/statespace/filters/utilities.py +2 -5
- pymc_extras/statespace/models/DFM.py +12 -27
- pymc_extras/statespace/models/ETS.py +190 -198
- pymc_extras/statespace/models/SARIMAX.py +5 -17
- pymc_extras/statespace/models/VARMAX.py +15 -67
- pymc_extras/statespace/models/structural/components/autoregressive.py +4 -4
- pymc_extras/statespace/models/structural/components/regression.py +4 -26
- pymc_extras/statespace/models/utilities.py +7 -0
- pymc_extras/utils/model_equivalence.py +2 -2
- pymc_extras/utils/prior.py +10 -14
- pymc_extras/utils/spline.py +4 -10
- {pymc_extras-0.5.0.dist-info → pymc_extras-0.7.0.dist-info}/METADATA +4 -4
- {pymc_extras-0.5.0.dist-info → pymc_extras-0.7.0.dist-info}/RECORD +38 -37
- {pymc_extras-0.5.0.dist-info → pymc_extras-0.7.0.dist-info}/WHEEL +1 -1
- {pymc_extras-0.5.0.dist-info → pymc_extras-0.7.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,517 @@
|
|
|
1
|
+
# Copyright 2022 The PyMC Developers
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Utilities for converting Pathfinder results to xarray and adding them to InferenceData."""
|
|
16
|
+
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
import warnings
|
|
20
|
+
|
|
21
|
+
from dataclasses import asdict
|
|
22
|
+
|
|
23
|
+
import arviz as az
|
|
24
|
+
import numpy as np
|
|
25
|
+
import pymc as pm
|
|
26
|
+
import xarray as xr
|
|
27
|
+
|
|
28
|
+
from pymc.blocking import DictToArrayBijection
|
|
29
|
+
|
|
30
|
+
from pymc_extras.inference.pathfinder.lbfgs import LBFGSStatus
|
|
31
|
+
from pymc_extras.inference.pathfinder.pathfinder import (
|
|
32
|
+
MultiPathfinderResult,
|
|
33
|
+
PathfinderConfig,
|
|
34
|
+
PathfinderResult,
|
|
35
|
+
PathStatus,
|
|
36
|
+
)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def get_param_coords(model: pm.Model | None, n_params: int) -> list[str]:
    """
    Build flat labels for every entry of the raveled parameter vector.

    Parameters
    ----------
    model : pm.Model | None
        Model whose initial point provides the variable names. When None, the
        labels are just the stringified positions ``"0"`` .. ``str(n_params - 1)``.
    n_params : int
        Length of the parameter vector; only consulted when ``model`` is None.

    Returns
    -------
    list[str]
        One label per flattened entry: the bare variable name for variables of
        size 1, otherwise ``"name[i]"`` for each flat index ``i``.
    """
    if model is None:
        return list(map(str, range(n_params)))

    raveled = DictToArrayBijection.map(model.initial_point())
    labels: list[str] = []
    for name, _shape, n_elems, _dtype in raveled.point_map_info:
        if n_elems == 1:
            labels.append(name)
        else:
            # Multi-element variables are labelled by flat (raveled) index.
            labels.extend(f"{name}[{idx}]" for idx in range(n_elems))
    return labels
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def _status_counter_to_dataarray(counter, status_enum_cls) -> xr.DataArray:
    """
    Expand a status Counter into a DataArray covering every enum member.

    Members missing from ``counter`` get a count of 0. The result therefore has
    one entry per member of ``status_enum_cls``, in iteration order, along a
    ``status`` dimension labelled by member name.
    """
    members = list(status_enum_cls)
    tallies = np.array([counter.get(member, 0) for member in members])
    return xr.DataArray(
        tallies,
        dims=["status"],
        coords={"status": [member.name for member in members]},
        name="status_counts",
    )
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def _extract_scalar(value):
    """
    Unwrap a single value from a scalar-like container.

    Objects exposing ``.item()`` (e.g. NumPy arrays and scalars) are unwrapped
    through that method. Other sized objects of length one yield their first
    element. Anything else is returned unchanged.
    """
    if hasattr(value, "item"):
        return value.item()
    is_singleton = hasattr(value, "__len__") and len(value) == 1
    return value[0] if is_singleton else value
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def pathfinder_result_to_xarray(
    result: PathfinderResult,
    model: pm.Model | None = None,
) -> xr.Dataset:
    """
    Convert a PathfinderResult to an xarray Dataset.

    The dataset contains:

    - scalar L-BFGS / ELBO diagnostics;
    - the L-BFGS and path status, each as both code and name;
    - a representative final sample labelled by parameter;
    - mean/std/max summaries of ``logP`` and ``logQ``.

    The status names are also stored in ``attrs``.

    Parameters
    ----------
    result : PathfinderResult
        Single pathfinder run result
    model : pm.Model | None
        PyMC model for parameter name extraction. When ``result.samples`` is
        None it is also used, on a best-effort basis, to infer the number of
        parameters.

    Returns
    -------
    xr.Dataset
        Dataset with pathfinder results

    Examples
    --------
    >>> import pymc as pm
    >>> import pymc_extras as pmx
    >>>
    >>> with pm.Model() as model:
    ...     x = pm.Normal("x", 0, 1)
    ...     y = pm.Normal("y", x, 1, observed=2.0)
    >>> # Assuming we have a PathfinderResult from a pathfinder run
    >>> ds = pathfinder_result_to_xarray(result, model=model)
    >>> print(ds.data_vars)  # Shows lbfgs_niter, elbo_argmax, status info, etc.
    >>> print(ds.attrs)  # Shows metadata like lbfgs_status, path_status
    """
    data_vars = {}
    coords = {}
    attrs = {}

    # Prefer the samples' trailing axis for the parameter count. Without
    # samples (e.g. a failed path), fall back to the model's raveled initial point.
    n_params = None
    if result.samples is not None:
        n_params = result.samples.shape[-1]
    elif getattr(result, "lbfgs_niter", None) is not None and model is not None:
        try:
            n_params = len(DictToArrayBijection.map(model.initial_point()).data)
        except Exception:
            # Deliberately best-effort: parameter labels are optional metadata,
            # so a model that cannot produce an initial point just yields none.
            n_params = None

    if n_params is not None:
        coords["param"] = get_param_coords(model, n_params)

    if result.lbfgs_niter is not None:
        data_vars["lbfgs_niter"] = xr.DataArray(_extract_scalar(result.lbfgs_niter))
    if result.elbo_argmax is not None:
        data_vars["elbo_argmax"] = xr.DataArray(_extract_scalar(result.elbo_argmax))

    data_vars["lbfgs_status_code"] = xr.DataArray(result.lbfgs_status.value)
    data_vars["lbfgs_status_name"] = xr.DataArray(result.lbfgs_status.name)
    data_vars["path_status_code"] = xr.DataArray(result.path_status.value)
    data_vars["path_status_name"] = xr.DataArray(result.path_status.name)

    if n_params is not None and result.samples is not None:
        samples = result.samples
        # Keep the last draw as a representative point. Samples are expected to
        # be (draws, params) or (paths, draws, params). Each rank is handled
        # explicitly because three-index access on a 2-D array raises IndexError.
        if samples.ndim == 2:
            representative_sample = samples[-1, :]
        elif samples.ndim == 3:
            representative_sample = samples[0, -1, :]
        else:
            representative_sample = None
        if representative_sample is not None:
            data_vars["final_sample"] = xr.DataArray(
                representative_sample, dims=["param"], coords={"param": coords["param"]}
            )

    for label, values in (("logP", result.logP), ("logQ", result.logQ)):
        if values is None:
            continue
        flat = values.flatten() if hasattr(values, "flatten") else values
        # Empty or unsized inputs have no meaningful summary statistics.
        if hasattr(flat, "__len__") and len(flat) > 0:
            data_vars[f"{label}_mean"] = xr.DataArray(np.mean(flat))
            data_vars[f"{label}_std"] = xr.DataArray(np.std(flat))
            data_vars[f"{label}_max"] = xr.DataArray(np.max(flat))

    attrs["lbfgs_status"] = result.lbfgs_status.name
    attrs["path_status"] = result.path_status.name

    return xr.Dataset(data_vars, coords=coords, attrs=attrs)
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def multipathfinder_result_to_xarray(
    result: MultiPathfinderResult,
    model: pm.Model | None = None,
    *,
    store_diagnostics: bool = False,
) -> xr.Dataset:
    """
    Flatten a MultiPathfinderResult into one consolidated xarray Dataset.

    Summary statistics sit at the top level. Per-path arrays use a ``paths/``
    prefix and are added only when some path succeeded and samples exist.
    Configuration values use a ``config/`` prefix. Full diagnostic arrays use a
    ``diagnostics/`` prefix and are added only when ``store_diagnostics`` is True.

    Parameters
    ----------
    result : MultiPathfinderResult
        Multi-path pathfinder result
    model : pm.Model | None
        PyMC model for parameter name extraction
    store_diagnostics : bool
        Whether to include potentially large diagnostic arrays

    Returns
    -------
    xr.Dataset
        Single consolidated dataset with all pathfinder results

    Examples
    --------
    >>> import pymc as pm
    >>> import pymc_extras as pmx
    >>>
    >>> with pm.Model() as model:
    ...     x = pm.Normal("x", 0, 1)
    ...
    >>> # Assuming we have a MultiPathfinderResult from multiple pathfinder runs
    >>> ds = multipathfinder_result_to_xarray(result, model=model)
    >>> print("All data:", ds.data_vars)
    >>> print(
    ...     "Summary:",
    ...     [
    ...         k
    ...         for k in ds.data_vars.keys()
    ...         if not k.startswith(("paths/", "config/", "diagnostics/"))
    ...     ],
    ... )
    >>> print("Per-path:", [k for k in ds.data_vars.keys() if k.startswith("paths/")])
    >>> print("Config:", [k for k in ds.data_vars.keys() if k.startswith("config/")])
    """
    data_vars: dict = {}
    coords: dict = {}
    attrs: dict = {}

    samples = result.samples
    n_params = None if samples is None else samples.shape[-1]
    param_coords = None if n_params is None else get_param_coords(model, n_params)
    if param_coords is not None:
        coords["param"] = param_coords

    # Top-level summary statistics (no prefix).
    _add_summary_data(result, data_vars, coords, attrs)

    # Per-path arrays ("paths/") only make sense when some path produced samples.
    if not result.all_paths_failed and samples is not None:
        _add_paths_data(result, data_vars, coords, param_coords, n_params)

    # Run configuration ("config/").
    config = result.pathfinder_config
    if config is not None:
        _add_config_data(config, data_vars)

    # Potentially large full arrays ("diagnostics/"), opt-in only.
    if store_diagnostics:
        _add_diagnostics_data(result, data_vars, coords, param_coords)

    return xr.Dataset(data_vars, coords=coords, attrs=attrs)
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
def _add_summary_data(
    result: MultiPathfinderResult, data_vars: dict, coords: dict, attrs: dict
) -> None:
    """
    Add summary-level statistics to the pathfinder dataset (no prefix).

    Writes the following into ``data_vars``:

    - run counts and timings;
    - importance-sampling information;
    - L-BFGS / path status histograms;
    - success information;
    - aggregate statistics of ``lbfgs_niter``, ``elbo_argmax``, ``logP`` and ``logQ``.

    Any result warnings are copied into ``attrs["warnings"]``.

    Parameters
    ----------
    result : MultiPathfinderResult
        Multi-path pathfinder result to summarise.
    data_vars : dict
        Mapping of variable name to DataArray, updated in place.
    coords : dict
        Accepted for symmetry with the other builders; not modified here.
    attrs : dict
        Dataset attributes, updated in place.
    """
    if result.num_paths is not None:
        data_vars["num_paths"] = xr.DataArray(result.num_paths)
    if result.num_draws is not None:
        data_vars["num_draws"] = xr.DataArray(result.num_draws)

    if result.compile_time is not None:
        data_vars["compile_time"] = xr.DataArray(result.compile_time)
    if result.compute_time is not None:
        data_vars["compute_time"] = xr.DataArray(result.compute_time)
    # total_time needs both components; summing with a missing one would raise.
    if result.compile_time is not None and result.compute_time is not None:
        data_vars["total_time"] = xr.DataArray(result.compile_time + result.compute_time)

    data_vars["importance_sampling_method"] = xr.DataArray(result.importance_sampling or "none")
    if result.pareto_k is not None:
        data_vars["pareto_k"] = xr.DataArray(result.pareto_k)

    if result.lbfgs_status:
        data_vars["lbfgs_status_counts"] = _status_counter_to_dataarray(
            result.lbfgs_status, LBFGSStatus
        )
    if result.path_status:
        data_vars["path_status_counts"] = _status_counter_to_dataarray(
            result.path_status, PathStatus
        )

    data_vars["all_paths_failed"] = xr.DataArray(result.all_paths_failed)
    if not result.all_paths_failed and result.samples is not None:
        n_successful = _count_successful_paths(result)
        if n_successful is not None:
            data_vars["num_successful_paths"] = xr.DataArray(n_successful)

    for name in ("lbfgs_niter", "elbo_argmax"):
        values = getattr(result, name)
        if values is not None:
            data_vars[f"{name}_mean"] = xr.DataArray(np.mean(values))
            data_vars[f"{name}_std"] = xr.DataArray(np.std(values))

    for name in ("logP", "logQ"):
        values = getattr(result, name)
        if values is not None:
            data_vars[f"{name}_mean"] = xr.DataArray(np.mean(values))
            data_vars[f"{name}_std"] = xr.DataArray(np.std(values))
            data_vars[f"{name}_max"] = xr.DataArray(np.max(values))

    if result.warnings:
        attrs["warnings"] = list(result.warnings)


def _count_successful_paths(result: MultiPathfinderResult) -> int | None:
    """
    Return how many paths produced samples, or None if it cannot be determined.

    Un-collapsed samples (ndim >= 3) carry paths on their leading axis. After
    importance sampling flattens ``result.samples`` into a single draw axis,
    that leading axis counts draws rather than paths. In that case the per-path
    ``logP`` array, which has paths on its first axis, is used instead.
    """
    samples = result.samples
    if samples is None:
        return None
    if np.ndim(samples) >= 3:
        return int(np.shape(samples)[0])
    if result.logP is not None and np.ndim(result.logP) >= 2:
        return int(np.shape(result.logP)[0])
    return None
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
def _add_paths_data(
    result: MultiPathfinderResult,
    data_vars: dict,
    coords: dict,
    param_coords: list[str] | None,
    n_params: int | None,
) -> None:
    """
    Add per-path arrays to the pathfinder dataset under a 'paths/' prefix.

    Registers a ``path`` coordinate ``0 .. n_paths - 1``, where the count comes
    from ``_determine_num_paths``. It then stores:

    - the per-path ``lbfgs_niter`` and ``elbo_argmax`` arrays;
    - per-path mean/max of ``logP`` and ``logQ``, reduced over axis 1;
    - for samples with ndim >= 3, each path's last draw.

    ``param_coords`` is accepted but not read; parameter labels are taken from
    ``coords["param"]``.
    """
    path_labels = list(range(_determine_num_paths(result)))
    coords["path"] = path_labels

    def _store(key: str, values) -> None:
        # Absent diagnostics are skipped rather than stored as empty variables.
        if values is None:
            return
        data_vars["paths/" + key] = xr.DataArray(
            values, dims=["path"], coords={"path": coords["path"]}
        )

    _store("lbfgs_niter", result.lbfgs_niter)
    _store("elbo_argmax", result.elbo_argmax)

    for prefix, log_density in (("logP", result.logP), ("logQ", result.logQ)):
        if log_density is not None:
            _store(f"{prefix}_mean", np.mean(log_density, axis=1))
            _store(f"{prefix}_max", np.max(log_density, axis=1))

    samples = result.samples
    if n_params is None or samples is None or samples.ndim < 3:
        return
    data_vars["paths/final_sample"] = xr.DataArray(
        samples[:, -1, :],  # last draw of every path
        dims=["path", "param"],
        coords={"path": coords["path"], "param": coords["param"]},
    )
|
|
350
|
+
|
|
351
|
+
|
|
352
|
+
def _add_config_data(config: PathfinderConfig, data_vars: dict) -> None:
    """
    Store every field of the PathfinderConfig dataclass in ``data_vars``.

    Each field becomes a DataArray under the key ``config/<field name>``.
    """
    for field_name, setting in asdict(config).items():
        data_vars["config/" + field_name] = xr.DataArray(setting)
|
|
357
|
+
|
|
358
|
+
|
|
359
|
+
def _add_diagnostics_data(
    result: MultiPathfinderResult, data_vars: dict, coords: dict, param_coords: list[str] | None
) -> None:
    """
    Add full per-draw diagnostics to the pathfinder dataset with 'diagnostics/' prefix.

    Stores the following:

    - ``result.logP`` and ``result.logQ`` (unpacked as 2-D: paths x draws) as
      ``diagnostics/logP_full`` / ``diagnostics/logQ_full``;
    - when the samples are 3-D and parameter labels are known, ``result.samples``
      as ``diagnostics/samples_full``.

    Missing ``path`` / ``draw_per_path`` / ``param`` entries in ``coords`` are
    filled in. Existing entries are reused, so coordinates stay shared with
    variables added by earlier builders.

    Parameters
    ----------
    result : MultiPathfinderResult
        Multi-path pathfinder result.
    data_vars : dict
        Mapping of variable name to DataArray, updated in place.
    coords : dict
        Dataset coordinates, updated in place.
    param_coords : list[str] | None
        Parameter labels; ``diagnostics/samples_full`` is skipped when None.
    """

    def _ensure_path_draw_coords(n_paths: int, n_draws_per_path: int) -> None:
        # Never overwrite coordinates an earlier builder already registered.
        coords.setdefault("path", list(range(n_paths)))
        coords.setdefault("draw_per_path", list(range(n_draws_per_path)))

    for label, values in (("logP", result.logP), ("logQ", result.logQ)):
        if values is None:
            continue
        n_paths, n_draws_per_path = values.shape
        _ensure_path_draw_coords(n_paths, n_draws_per_path)
        data_vars[f"diagnostics/{label}_full"] = xr.DataArray(
            values,
            dims=["path", "draw_per_path"],
            coords={"path": coords["path"], "draw_per_path": coords["draw_per_path"]},
        )

    samples = result.samples
    if samples is not None and samples.ndim == 3 and param_coords is not None:
        n_paths, n_draws_per_path, _ = samples.shape
        _ensure_path_draw_coords(n_paths, n_draws_per_path)
        # Reuse an existing "param" coordinate when present; otherwise register
        # the labels passed in, so callers need not pre-populate coords["param"].
        param_labels = coords.setdefault("param", param_coords)
        data_vars["diagnostics/samples_full"] = xr.DataArray(
            samples,
            dims=["path", "draw_per_path", "param"],
            coords={
                "path": coords["path"],
                "draw_per_path": coords["draw_per_path"],
                "param": param_labels,
            },
        )
|
|
405
|
+
|
|
406
|
+
|
|
407
|
+
def _determine_num_paths(result: MultiPathfinderResult) -> int:
    """
    Infer how many paths a MultiPathfinderResult describes.

    Sources are tried in this order:

    1. Per-path diagnostic arrays (``lbfgs_niter``, ``elbo_argmax``, then the
       leading axis of ``logP`` / ``logQ``). These come first because
       ``result.samples`` can be collapsed across paths once importance
       sampling has run.
    2. The totals of the non-empty status counters.
    3. The leading axis of ``result.samples``.

    Raises
    ------
    ValueError
        If none of those fields is available.
    """
    for attr in ("lbfgs_niter", "elbo_argmax"):
        per_path = getattr(result, attr)
        if per_path is not None:
            return len(per_path)

    for attr in ("logP", "logQ"):
        per_path_draws = getattr(result, attr)
        if per_path_draws is not None:
            return per_path_draws.shape[0]

    for attr in ("lbfgs_status", "path_status"):
        counter = getattr(result, attr)
        if counter:
            return sum(counter.values())

    if result.samples is not None:
        return result.samples.shape[0]

    raise ValueError("Cannot determine number of paths from result")
|
|
432
|
+
|
|
433
|
+
|
|
434
|
+
def add_pathfinder_to_inference_data(
    idata: az.InferenceData,
    result: PathfinderResult | MultiPathfinderResult,
    model: pm.Model | None = None,
    *,
    group: str = "pathfinder",
    paths_group: str = "pathfinder_paths",  # Deprecated, kept for API compatibility
    diagnostics_group: str = "pathfinder_diagnostics",  # Deprecated, kept for API compatibility
    config_group: str = "pathfinder_config",  # Deprecated, kept for API compatibility
    store_diagnostics: bool = False,
) -> az.InferenceData:
    """
    Add pathfinder results to an ArviZ InferenceData object as a single consolidated group.

    All pathfinder output is now consolidated under a single group with nested structure:
    - Summary statistics at the top level
    - Per-path data with 'paths/' prefix
    - Configuration with 'config/' prefix
    - Diagnostics with 'diagnostics/' prefix (if store_diagnostics=True)

    If ``group`` already exists in ``idata``, a warning is emitted and the
    existing group is replaced.

    Parameters
    ----------
    idata : az.InferenceData
        InferenceData object to modify (modified in place and also returned)
    result : PathfinderResult | MultiPathfinderResult
        Pathfinder results to add
    model : pm.Model | None
        PyMC model for parameter name extraction
    group : str
        Name for the pathfinder group (default: "pathfinder")
    paths_group : str
        Deprecated: no longer used, kept for API compatibility
    diagnostics_group : str
        Deprecated: no longer used, kept for API compatibility
    config_group : str
        Deprecated: no longer used, kept for API compatibility
    store_diagnostics : bool
        Whether to include potentially large diagnostic arrays

    Returns
    -------
    az.InferenceData
        Modified InferenceData object with consolidated pathfinder group added

    Examples
    --------
    >>> import pymc as pm
    >>> import pymc_extras as pmx
    >>>
    >>> with pm.Model() as model:
    ...     x = pm.Normal("x", 0, 1)
    ...     idata = pmx.fit(method="pathfinder", model=model, add_pathfinder_groups=False)
    >>> # Assuming we have pathfinder results
    >>> idata = add_pathfinder_to_inference_data(idata, results, model=model)
    >>> print(list(idata.groups()))  # Will show ['posterior', 'pathfinder']
    >>> # Access nested data:
    >>> print(
    ...     [k for k in idata.pathfinder.data_vars.keys() if k.startswith("paths/")]
    ... )  # Per-path data
    >>> print(
    ...     [k for k in idata.pathfinder.data_vars.keys() if k.startswith("config/")]
    ... )  # Config data
    """
    # Detect a multi-path result. isinstance() is the primary check; the duck
    # typing fallback (a status field with a callable ``values``, as a Counter
    # has) keeps mocks and test doubles working.
    is_multipath = isinstance(result, MultiPathfinderResult) or (
        hasattr(result, "lbfgs_status")
        and hasattr(result.lbfgs_status, "values")
        and callable(getattr(result.lbfgs_status, "values"))
    )

    if is_multipath:
        consolidated_ds = multipathfinder_result_to_xarray(
            result, model=model, store_diagnostics=store_diagnostics
        )
    else:
        consolidated_ds = pathfinder_result_to_xarray(result, model=model)

    if group in idata.groups():
        warnings.warn(
            f"Group '{group}' already exists in InferenceData, it will be replaced.",
            stacklevel=2,
        )
        # InferenceData.add_groups refuses to add a group that already exists,
        # so drop the old one first to actually perform the announced replacement.
        delattr(idata, group)

    idata.add_groups({group: consolidated_ds})
    return idata
|