pymc-extras 0.5.0__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pymc_extras/deserialize.py +10 -4
- pymc_extras/distributions/continuous.py +1 -1
- pymc_extras/distributions/histogram_utils.py +6 -4
- pymc_extras/distributions/multivariate/r2d2m2cp.py +4 -3
- pymc_extras/distributions/timeseries.py +4 -2
- pymc_extras/inference/dadvi/dadvi.py +162 -72
- pymc_extras/inference/laplace_approx/find_map.py +16 -39
- pymc_extras/inference/laplace_approx/idata.py +22 -4
- pymc_extras/inference/laplace_approx/laplace.py +23 -6
- pymc_extras/inference/laplace_approx/scipy_interface.py +47 -7
- pymc_extras/inference/pathfinder/idata.py +517 -0
- pymc_extras/inference/pathfinder/pathfinder.py +61 -7
- pymc_extras/model/marginal/graph_analysis.py +2 -2
- pymc_extras/model_builder.py +9 -4
- pymc_extras/prior.py +203 -8
- pymc_extras/statespace/core/compile.py +1 -1
- pymc_extras/statespace/filters/kalman_filter.py +12 -11
- pymc_extras/statespace/filters/kalman_smoother.py +1 -3
- pymc_extras/statespace/filters/utilities.py +2 -5
- pymc_extras/statespace/models/DFM.py +12 -27
- pymc_extras/statespace/models/ETS.py +190 -198
- pymc_extras/statespace/models/SARIMAX.py +5 -17
- pymc_extras/statespace/models/VARMAX.py +15 -67
- pymc_extras/statespace/models/structural/components/autoregressive.py +4 -4
- pymc_extras/statespace/models/structural/components/regression.py +4 -26
- pymc_extras/statespace/models/utilities.py +7 -0
- pymc_extras/utils/model_equivalence.py +2 -2
- pymc_extras/utils/prior.py +10 -14
- pymc_extras/utils/spline.py +4 -10
- {pymc_extras-0.5.0.dist-info → pymc_extras-0.6.0.dist-info}/METADATA +3 -3
- {pymc_extras-0.5.0.dist-info → pymc_extras-0.6.0.dist-info}/RECORD +33 -32
- {pymc_extras-0.5.0.dist-info → pymc_extras-0.6.0.dist-info}/WHEEL +1 -1
- {pymc_extras-0.5.0.dist-info → pymc_extras-0.6.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,517 @@
|
|
|
1
|
+
# Copyright 2022 The PyMC Developers
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Utilities for converting Pathfinder results to xarray and adding them to InferenceData."""
|
|
16
|
+
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
import warnings
|
|
20
|
+
|
|
21
|
+
from dataclasses import asdict
|
|
22
|
+
|
|
23
|
+
import arviz as az
|
|
24
|
+
import numpy as np
|
|
25
|
+
import pymc as pm
|
|
26
|
+
import xarray as xr
|
|
27
|
+
|
|
28
|
+
from pymc.blocking import DictToArrayBijection
|
|
29
|
+
|
|
30
|
+
from pymc_extras.inference.pathfinder.lbfgs import LBFGSStatus
|
|
31
|
+
from pymc_extras.inference.pathfinder.pathfinder import (
|
|
32
|
+
MultiPathfinderResult,
|
|
33
|
+
PathfinderConfig,
|
|
34
|
+
PathfinderResult,
|
|
35
|
+
PathStatus,
|
|
36
|
+
)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def get_param_coords(model: pm.Model | None, n_params: int) -> list[str]:
    """
    Build one label per entry of the raveled parameter vector.

    Parameters
    ----------
    model : pm.Model | None
        PyMC model whose initial point supplies variable names. When None,
        plain string indices ``"0" .. str(n_params - 1)`` are returned.
    n_params : int
        Number of labels to produce when ``model`` is None; ignored otherwise.

    Returns
    -------
    list[str]
        A variable with a single element is labelled by its bare name; larger
        variables get one ``name[i]`` label per flat (raveled) index.
    """
    if model is None:
        return list(map(str, range(n_params)))

    raveled = DictToArrayBijection.map(model.initial_point())

    labels: list[str] = []
    for name, _shape, n_elements, _dtype in raveled.point_map_info:
        if n_elements == 1:
            labels.append(name)
        else:
            labels.extend(f"{name}[{idx}]" for idx in range(n_elements))
    return labels
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def _status_counter_to_dataarray(counter, status_enum_cls) -> xr.DataArray:
    """
    Expand a status Counter into a dense DataArray over every enum member.

    Members of ``status_enum_cls`` that are absent from ``counter`` get a
    count of 0. The result has a single ``status`` dimension labelled by
    member name and is named ``status_counts``.
    """
    members = list(status_enum_cls)
    return xr.DataArray(
        np.array([counter.get(member, 0) for member in members]),
        dims=["status"],
        coords={"status": [member.name for member in members]},
        name="status_counts",
    )
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def _extract_scalar(value):
    """
    Unwrap a single value from array-like containers.

    Objects exposing ``.item()`` (NumPy arrays and scalars) are unwrapped via
    that method. Other sized containers of length one return their sole
    element. Anything else is returned unchanged.
    """
    if hasattr(value, "item"):
        return value.item()
    is_singleton = hasattr(value, "__len__") and len(value) == 1
    return value[0] if is_singleton else value
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def pathfinder_result_to_xarray(
    result: PathfinderResult,
    model: pm.Model | None = None,
) -> xr.Dataset:
    """
    Convert a PathfinderResult to an xarray Dataset.

    Parameters
    ----------
    result : PathfinderResult
        Single pathfinder run result
    model : pm.Model | None
        PyMC model for parameter name extraction

    Returns
    -------
    xr.Dataset
        Dataset with pathfinder results:

        - ``lbfgs_niter`` / ``elbo_argmax`` (when present on the result)
        - ``lbfgs_status_code``/``_name`` and ``path_status_code``/``_name``
        - ``final_sample`` along ``param``: the last draw, taken from index 0
          of any leading axes of ``result.samples`` (when samples exist)
        - ``logP_mean/std/max`` and ``logQ_mean/std/max`` (when non-empty)

        ``attrs`` carries the L-BFGS and path status names.

    Warns
    -----
    UserWarning
        If the parameter count cannot be derived from ``model`` when
        ``result.samples`` is None; parameter labels are then omitted.

    Examples
    --------
    >>> import pymc as pm
    >>> import pymc_extras as pmx
    >>>
    >>> with pm.Model() as model:
    ...     x = pm.Normal("x", 0, 1)
    ...     y = pm.Normal("y", x, 1, observed=2.0)
    >>> # Assuming we have a PathfinderResult from a pathfinder run
    >>> ds = pathfinder_result_to_xarray(result, model=model)
    >>> print(ds.data_vars)  # Shows lbfgs_niter, elbo_argmax, status info, etc.
    >>> print(ds.attrs)  # Shows metadata like lbfgs_status, path_status
    """
    data_vars = {}
    coords = {}
    attrs = {}

    # Parameter count: read from the samples when present, otherwise (for
    # results that carry L-BFGS output) inferred from the model's raveled
    # initial point.
    n_params = None
    if result.samples is not None:
        n_params = result.samples.shape[-1]
    elif getattr(result, "lbfgs_niter", None) is not None and model is not None:
        try:
            n_params = len(DictToArrayBijection.map(model.initial_point()).data)
        except Exception as err:
            # Parameter labels are best-effort metadata. Keep converting the
            # rest of the result, but do not hide why the labels are missing.
            warnings.warn(
                f"Could not determine the number of parameters from the model: {err}",
                UserWarning,
                stacklevel=2,
            )

    if n_params is not None:
        coords["param"] = get_param_coords(model, n_params)

    if result.lbfgs_niter is not None:
        data_vars["lbfgs_niter"] = xr.DataArray(_extract_scalar(result.lbfgs_niter))

    if result.elbo_argmax is not None:
        data_vars["elbo_argmax"] = xr.DataArray(_extract_scalar(result.elbo_argmax))

    data_vars["lbfgs_status_code"] = xr.DataArray(result.lbfgs_status.value)
    data_vars["lbfgs_status_name"] = xr.DataArray(result.lbfgs_status.name)
    data_vars["path_status_code"] = xr.DataArray(result.path_status.value)
    data_vars["path_status_name"] = xr.DataArray(result.path_status.name)

    samples = result.samples
    if n_params is not None and samples is not None and samples.ndim >= 2:
        # The last two axes are (draws, params). Take index 0 on every leading
        # axis and the last draw, so both (draws, params) and
        # (paths, draws, params) inputs work. A fixed ``[0, -1, :]`` index
        # would raise IndexError on 2-D samples.
        leading = (0,) * (samples.ndim - 2)
        representative_sample = samples[(*leading, -1)]
        data_vars["final_sample"] = xr.DataArray(
            representative_sample, dims=["param"], coords={"param": coords["param"]}
        )

    for label, values in (("logP", result.logP), ("logQ", result.logQ)):
        if values is None:
            continue
        flat = values.flatten() if hasattr(values, "flatten") else values
        # Skip empty inputs: np.max of an empty array raises.
        if hasattr(flat, "__len__") and len(flat) > 0:
            data_vars[f"{label}_mean"] = xr.DataArray(np.mean(flat))
            data_vars[f"{label}_std"] = xr.DataArray(np.std(flat))
            data_vars[f"{label}_max"] = xr.DataArray(np.max(flat))

    attrs["lbfgs_status"] = result.lbfgs_status.name
    attrs["path_status"] = result.path_status.name

    return xr.Dataset(data_vars, coords=coords, attrs=attrs)
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def multipathfinder_result_to_xarray(
    result: MultiPathfinderResult,
    model: pm.Model | None = None,
    *,
    store_diagnostics: bool = False,
) -> xr.Dataset:
    """
    Flatten a MultiPathfinderResult into one consolidated xarray Dataset.

    Variables are namespaced by prefix:

    - no prefix: summary statistics across all paths
    - ``paths/``: per-path values (only when some path succeeded and samples exist)
    - ``config/``: fields of ``result.pathfinder_config`` (when present)
    - ``diagnostics/``: full logP/logQ/sample arrays (only if ``store_diagnostics``)

    Parameters
    ----------
    result : MultiPathfinderResult
        Multi-path pathfinder result
    model : pm.Model | None
        PyMC model for parameter name extraction
    store_diagnostics : bool
        Whether to include potentially large diagnostic arrays

    Returns
    -------
    xr.Dataset
        Single consolidated dataset with all pathfinder results

    Examples
    --------
    >>> import pymc as pm
    >>> import pymc_extras as pmx
    >>>
    >>> with pm.Model() as model:
    ...     x = pm.Normal("x", 0, 1)
    ...
    >>> # Assuming we have a MultiPathfinderResult from multiple pathfinder runs
    >>> ds = multipathfinder_result_to_xarray(result, model=model)
    >>> print("All data:", ds.data_vars)
    >>> print(
    ...     "Summary:",
    ...     [
    ...         k
    ...         for k in ds.data_vars.keys()
    ...         if not k.startswith(("paths/", "config/", "diagnostics/"))
    ...     ],
    ... )
    >>> print("Per-path:", [k for k in ds.data_vars.keys() if k.startswith("paths/")])
    >>> print("Config:", [k for k in ds.data_vars.keys() if k.startswith("config/")])
    """
    if result.samples is None:
        n_params = None
        param_coords = None
    else:
        n_params = result.samples.shape[-1]
        param_coords = get_param_coords(model, n_params)

    data_vars: dict = {}
    coords: dict = {}
    attrs: dict = {}

    if param_coords is not None:
        coords["param"] = param_coords

    # Top-level summary statistics.
    _add_summary_data(result, data_vars, coords, attrs)

    # Per-path variables ("paths/" prefix) only make sense when samples exist.
    has_samples = result.samples is not None
    if not result.all_paths_failed and has_samples:
        _add_paths_data(result, data_vars, coords, param_coords, n_params)

    # Run configuration ("config/" prefix).
    config = result.pathfinder_config
    if config is not None:
        _add_config_data(config, data_vars)

    # Full diagnostic arrays ("diagnostics/" prefix), opt-in because of size.
    if store_diagnostics:
        _add_diagnostics_data(result, data_vars, coords, param_coords)

    return xr.Dataset(data_vars, coords=coords, attrs=attrs)
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
def _add_summary_data(
    result: MultiPathfinderResult, data_vars: dict, coords: dict, attrs: dict
) -> None:
    """
    Add summary-level statistics to the pathfinder dataset.

    Mutates ``data_vars`` and ``attrs`` in place. ``coords`` is accepted for
    signature parity with the other ``_add_*`` helpers and is not modified.

    Parameters
    ----------
    result : MultiPathfinderResult
        Source of counts, timings, status counters and logP/logQ arrays.
    data_vars : dict
        Receives scalar summaries and the ``status``-indexed count arrays.
    coords : dict
        Unused here.
    attrs : dict
        Receives ``warnings`` (as a list) when the result carries any.
    """
    if result.num_paths is not None:
        data_vars["num_paths"] = xr.DataArray(result.num_paths)
    if result.num_draws is not None:
        data_vars["num_draws"] = xr.DataArray(result.num_draws)

    if result.compile_time is not None:
        data_vars["compile_time"] = xr.DataArray(result.compile_time)
    if result.compute_time is not None:
        data_vars["compute_time"] = xr.DataArray(result.compute_time)
    # The total requires both components. Checking only compile_time would
    # raise TypeError (number + None) when compute_time is missing.
    if result.compile_time is not None and result.compute_time is not None:
        data_vars["total_time"] = xr.DataArray(result.compile_time + result.compute_time)

    data_vars["importance_sampling_method"] = xr.DataArray(result.importance_sampling or "none")
    if result.pareto_k is not None:
        data_vars["pareto_k"] = xr.DataArray(result.pareto_k)

    if result.lbfgs_status:
        data_vars["lbfgs_status_counts"] = _status_counter_to_dataarray(
            result.lbfgs_status, LBFGSStatus
        )
    if result.path_status:
        data_vars["path_status_counts"] = _status_counter_to_dataarray(
            result.path_status, PathStatus
        )

    data_vars["all_paths_failed"] = xr.DataArray(result.all_paths_failed)
    if not result.all_paths_failed and result.samples is not None:
        # NOTE(review): this uses the leading axis of ``samples``. If importance
        # sampling collapsed samples to (draws, params), this counts draws rather
        # than paths -- confirm against the importance-sampling output shape.
        data_vars["num_successful_paths"] = xr.DataArray(result.samples.shape[0])

    if result.lbfgs_niter is not None:
        data_vars["lbfgs_niter_mean"] = xr.DataArray(np.mean(result.lbfgs_niter))
        data_vars["lbfgs_niter_std"] = xr.DataArray(np.std(result.lbfgs_niter))

    if result.elbo_argmax is not None:
        data_vars["elbo_argmax_mean"] = xr.DataArray(np.mean(result.elbo_argmax))
        data_vars["elbo_argmax_std"] = xr.DataArray(np.std(result.elbo_argmax))

    if result.logP is not None:
        data_vars["logP_mean"] = xr.DataArray(np.mean(result.logP))
        data_vars["logP_std"] = xr.DataArray(np.std(result.logP))
        data_vars["logP_max"] = xr.DataArray(np.max(result.logP))

    if result.logQ is not None:
        data_vars["logQ_mean"] = xr.DataArray(np.mean(result.logQ))
        data_vars["logQ_std"] = xr.DataArray(np.std(result.logQ))
        data_vars["logQ_max"] = xr.DataArray(np.max(result.logQ))

    # Collected warnings are stored as a plain list in the dataset attributes.
    if result.warnings:
        attrs["warnings"] = list(result.warnings)
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
def _add_paths_data(
    result: MultiPathfinderResult,
    data_vars: dict,
    coords: dict,
    param_coords: list[str] | None,
    n_params: int | None,
) -> None:
    """
    Add per-path variables, each prefixed with ``paths/``.

    Sets ``coords["path"]`` to ``0 .. n_paths - 1``, where the path count
    comes from ``_determine_num_paths``. Then it adds:

    - ``lbfgs_niter`` and ``elbo_argmax`` (when present)
    - per-path mean and max of ``logP`` / ``logQ`` (reduced over axis 1)
    - ``final_sample``: the last draw of each path, when ``samples`` is 3-D

    ``final_sample`` reads ``coords["param"]``, which the caller sets whenever
    ``n_params`` is not None.
    """
    path_labels = list(range(_determine_num_paths(result)))
    coords["path"] = path_labels

    per_path_values = [
        ("lbfgs_niter", result.lbfgs_niter),
        ("elbo_argmax", result.elbo_argmax),
    ]
    for label, values in (("logP", result.logP), ("logQ", result.logQ)):
        if values is not None:
            per_path_values.append((f"{label}_mean", np.mean(values, axis=1)))
            per_path_values.append((f"{label}_max", np.max(values, axis=1)))

    for name, values in per_path_values:
        if values is None:
            continue
        data_vars[f"paths/{name}"] = xr.DataArray(
            values, dims=["path"], coords={"path": path_labels}
        )

    samples = result.samples
    if n_params is None or samples is None or samples.ndim < 3:
        return

    # samples is (paths, draws, params); keep only the last draw of each path.
    data_vars["paths/final_sample"] = xr.DataArray(
        samples[:, -1, :],
        dims=["path", "param"],
        coords={"path": path_labels, "param": coords["param"]},
    )
|
|
350
|
+
|
|
351
|
+
|
|
352
|
+
def _add_config_data(config: PathfinderConfig, data_vars: dict) -> None:
    """
    Store every field of ``config`` as a DataArray keyed ``config/<field>``.

    The dataclass is flattened with ``dataclasses.asdict``, and ``data_vars``
    is mutated in place.
    """
    for field_name, field_value in asdict(config).items():
        data_vars["config/" + field_name] = xr.DataArray(field_value)
|
|
357
|
+
|
|
358
|
+
|
|
359
|
+
def _add_diagnostics_data(
    result: MultiPathfinderResult, data_vars: dict, coords: dict, param_coords: list[str] | None
) -> None:
    """
    Add full (un-summarised) arrays, each prefixed with ``diagnostics/``.

    - ``logP_full`` / ``logQ_full`` are 2-D arrays over ``(path, draw_per_path)``.
    - ``samples_full`` is over ``(path, draw_per_path, param)``. It is added only
      when ``samples`` is 3-D and ``param_coords`` is given.

    Missing ``path`` / ``draw_per_path`` coordinates are created from the
    array shapes. When logP is present, its shape always defines
    ``draw_per_path``.
    """
    logP, logQ, samples = result.logP, result.logQ, result.samples

    def _path_draw_array(values):
        # Wrap a (path, draw_per_path) array using the shared coordinates.
        return xr.DataArray(
            values,
            dims=["path", "draw_per_path"],
            coords={"path": coords["path"], "draw_per_path": coords["draw_per_path"]},
        )

    if logP is not None:
        num_paths, draws_per_path = logP.shape
        coords.setdefault("path", list(range(num_paths)))
        coords["draw_per_path"] = list(range(draws_per_path))
        data_vars["diagnostics/logP_full"] = _path_draw_array(logP)

    if logQ is not None:
        if "draw_per_path" not in coords:
            num_paths, draws_per_path = logQ.shape
            coords.setdefault("path", list(range(num_paths)))
            coords["draw_per_path"] = list(range(draws_per_path))
        data_vars["diagnostics/logQ_full"] = _path_draw_array(logQ)

    if samples is not None and samples.ndim == 3 and param_coords is not None:
        num_paths, draws_per_path, _ = samples.shape
        coords.setdefault("path", list(range(num_paths)))
        coords.setdefault("draw_per_path", list(range(draws_per_path)))
        data_vars["diagnostics/samples_full"] = xr.DataArray(
            samples,
            dims=["path", "draw_per_path", "param"],
            coords={
                "path": coords["path"],
                "draw_per_path": coords["draw_per_path"],
                "param": coords["param"],
            },
        )
|
|
405
|
+
|
|
406
|
+
|
|
407
|
+
def _determine_num_paths(result: MultiPathfinderResult) -> int:
    """
    Infer how many paths ``result`` describes.

    ``result.samples`` may have been collapsed by importance sampling, so the
    per-path diagnostics are consulted first, in this order:

    1. ``lbfgs_niter``, then ``elbo_argmax`` (their length)
    2. ``logP``, then ``logQ`` (size of the leading axis)
    3. ``lbfgs_status``, then ``path_status`` (sum of the counts, if non-empty)
    4. ``samples`` (size of the leading axis)

    Raises
    ------
    ValueError
        If none of the above is available.
    """
    for attr in ("lbfgs_niter", "elbo_argmax"):
        values = getattr(result, attr)
        if values is not None:
            return len(values)

    for attr in ("logP", "logQ"):
        values = getattr(result, attr)
        if values is not None:
            return values.shape[0]

    for attr in ("lbfgs_status", "path_status"):
        counter = getattr(result, attr)
        if counter:
            return sum(counter.values())

    if result.samples is not None:
        return result.samples.shape[0]

    raise ValueError("Cannot determine number of paths from result")
|
|
432
|
+
|
|
433
|
+
|
|
434
|
+
def add_pathfinder_to_inference_data(
    idata: az.InferenceData,
    result: PathfinderResult | MultiPathfinderResult,
    model: pm.Model | None = None,
    *,
    group: str = "pathfinder",
    paths_group: str = "pathfinder_paths",  # Deprecated, kept for API compatibility
    diagnostics_group: str = "pathfinder_diagnostics",  # Deprecated, kept for API compatibility
    config_group: str = "pathfinder_config",  # Deprecated, kept for API compatibility
    store_diagnostics: bool = False,
) -> az.InferenceData:
    """
    Add pathfinder results to an ArviZ InferenceData object as a single consolidated group.

    All pathfinder output is consolidated under a single group with nested structure:
    - Summary statistics at the top level
    - Per-path data with 'paths/' prefix
    - Configuration with 'config/' prefix
    - Diagnostics with 'diagnostics/' prefix (if store_diagnostics=True)

    If ``group`` already exists in ``idata``, a UserWarning is emitted and the
    existing group is removed before the new one is added.

    Parameters
    ----------
    idata : az.InferenceData
        InferenceData object to modify (modified in place and also returned)
    result : PathfinderResult | MultiPathfinderResult
        Pathfinder results to add
    model : pm.Model | None
        PyMC model for parameter name extraction
    group : str
        Name for the pathfinder group (default: "pathfinder")
    paths_group : str
        Deprecated: no longer used, kept for API compatibility
    diagnostics_group : str
        Deprecated: no longer used, kept for API compatibility
    config_group : str
        Deprecated: no longer used, kept for API compatibility
    store_diagnostics : bool
        Whether to include potentially large diagnostic arrays (multi-path results only)

    Returns
    -------
    az.InferenceData
        Modified InferenceData object with consolidated pathfinder group added

    Examples
    --------
    >>> import pymc as pm
    >>> import pymc_extras as pmx
    >>>
    >>> with pm.Model() as model:
    ...     x = pm.Normal("x", 0, 1)
    ...     idata = pmx.fit(method="pathfinder", model=model, add_pathfinder_groups=False)
    >>> # Assuming we have pathfinder results
    >>> idata = add_pathfinder_to_inference_data(idata, results, model=model)
    >>> print(list(idata.groups()))  # Will show ['posterior', 'pathfinder']
    >>> # Access nested data:
    >>> print(
    ...     [k for k in idata.pathfinder.data_vars.keys() if k.startswith("paths/")]
    ... )  # Per-path data
    >>> print(
    ...     [k for k in idata.pathfinder.data_vars.keys() if k.startswith("config/")]
    ... )  # Config data
    """
    # Detect a multi-path result. isinstance() is the primary check; the duck
    # typing fallback (a Counter-like status with a callable ``values``) keeps
    # mocks and test doubles working.
    is_multipath = isinstance(result, MultiPathfinderResult) or (
        hasattr(result, "lbfgs_status")
        and hasattr(result.lbfgs_status, "values")
        and callable(getattr(result.lbfgs_status, "values"))
    )

    if is_multipath:
        consolidated_ds = multipathfinder_result_to_xarray(
            result, model=model, store_diagnostics=store_diagnostics
        )
    else:
        consolidated_ds = pathfinder_result_to_xarray(result, model=model)

    if group in idata.groups():
        warnings.warn(
            f"Group '{group}' already exists in InferenceData, it will be replaced.",
            UserWarning,
            stacklevel=2,
        )
        # InferenceData.add_groups refuses to add a group that already exists,
        # so remove the old one first. Without this, the "replaced" promise
        # above would be followed by a ValueError.
        delattr(idata, group)

    idata.add_groups({group: consolidated_ds})
    return idata
|
|
@@ -16,12 +16,13 @@
|
|
|
16
16
|
import collections
|
|
17
17
|
import logging
|
|
18
18
|
import time
|
|
19
|
+
import warnings
|
|
19
20
|
|
|
20
21
|
from collections import Counter
|
|
21
22
|
from collections.abc import Callable, Iterator
|
|
22
23
|
from dataclasses import asdict, dataclass, field, replace
|
|
23
24
|
from enum import Enum, auto
|
|
24
|
-
from typing import Literal, TypeAlias
|
|
25
|
+
from typing import Literal, Self, TypeAlias
|
|
25
26
|
|
|
26
27
|
import arviz as az
|
|
27
28
|
import filelock
|
|
@@ -59,9 +60,6 @@ from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainin
|
|
|
59
60
|
from rich.table import Table
|
|
60
61
|
from rich.text import Text
|
|
61
62
|
|
|
62
|
-
# TODO: change to typing.Self after Python versions greater than 3.10
|
|
63
|
-
from typing_extensions import Self
|
|
64
|
-
|
|
65
63
|
from pymc_extras.inference.laplace_approx.idata import add_data_to_inference_data
|
|
66
64
|
from pymc_extras.inference.pathfinder.importance_sampling import (
|
|
67
65
|
importance_sampling as _importance_sampling,
|
|
@@ -532,7 +530,7 @@ def bfgs_sample_sparse(
|
|
|
532
530
|
|
|
533
531
|
# qr_input: (L, N, 2J)
|
|
534
532
|
qr_input = inv_sqrt_alpha_diag @ beta
|
|
535
|
-
(Q, R), _ = pytensor.scan(fn=pt.
|
|
533
|
+
(Q, R), _ = pytensor.scan(fn=pt.linalg.qr, sequences=[qr_input], allow_gc=False)
|
|
536
534
|
|
|
537
535
|
IdN = pt.eye(R.shape[1])[None, ...]
|
|
538
536
|
IdN += IdN * REGULARISATION_TERM
|
|
@@ -1398,6 +1396,7 @@ def multipath_pathfinder(
|
|
|
1398
1396
|
random_seed: RandomSeed,
|
|
1399
1397
|
pathfinder_kwargs: dict = {},
|
|
1400
1398
|
compile_kwargs: dict = {},
|
|
1399
|
+
display_summary: bool = True,
|
|
1401
1400
|
) -> MultiPathfinderResult:
|
|
1402
1401
|
"""
|
|
1403
1402
|
Fit the Pathfinder Variational Inference algorithm using multiple paths with PyMC/PyTensor backend.
|
|
@@ -1556,8 +1555,9 @@ def multipath_pathfinder(
|
|
|
1556
1555
|
compute_time=compute_end - compute_start,
|
|
1557
1556
|
)
|
|
1558
1557
|
)
|
|
1559
|
-
#
|
|
1560
|
-
|
|
1558
|
+
# Display summary conditionally
|
|
1559
|
+
if display_summary:
|
|
1560
|
+
mpr.display_summary()
|
|
1561
1561
|
if mpr.all_paths_failed:
|
|
1562
1562
|
raise ValueError(
|
|
1563
1563
|
"All paths failed. Consider decreasing the jitter or reparameterizing the model."
|
|
@@ -1600,6 +1600,14 @@ def fit_pathfinder(
|
|
|
1600
1600
|
pathfinder_kwargs: dict = {},
|
|
1601
1601
|
compile_kwargs: dict = {},
|
|
1602
1602
|
initvals: dict | None = None,
|
|
1603
|
+
# New pathfinder result integration options
|
|
1604
|
+
add_pathfinder_groups: bool = True,
|
|
1605
|
+
display_summary: bool | Literal["auto"] = "auto",
|
|
1606
|
+
store_diagnostics: bool = False,
|
|
1607
|
+
pathfinder_group: str = "pathfinder",
|
|
1608
|
+
paths_group: str = "pathfinder_paths",
|
|
1609
|
+
diagnostics_group: str = "pathfinder_diagnostics",
|
|
1610
|
+
config_group: str = "pathfinder_config",
|
|
1603
1611
|
) -> az.InferenceData:
|
|
1604
1612
|
"""
|
|
1605
1613
|
Fit the Pathfinder Variational Inference algorithm.
|
|
@@ -1658,6 +1666,22 @@ def fit_pathfinder(
|
|
|
1658
1666
|
initvals: dict | None = None
|
|
1659
1667
|
Initial values for the model parameters, as str:ndarray key-value pairs. Paritial initialization is permitted.
|
|
1660
1668
|
If None, the model's default initial values are used.
|
|
1669
|
+
add_pathfinder_groups : bool, optional
|
|
1670
|
+
Whether to add pathfinder results as additional groups to the InferenceData (default is True).
|
|
1671
|
+
When True, adds pathfinder and pathfinder_paths groups with optimization diagnostics.
|
|
1672
|
+
display_summary : bool or "auto", optional
|
|
1673
|
+
Whether to display the pathfinder results summary (default is "auto").
|
|
1674
|
+
"auto" preserves current behavior, False suppresses console output.
|
|
1675
|
+
store_diagnostics : bool, optional
|
|
1676
|
+
Whether to include potentially large diagnostic arrays in the pathfinder groups (default is False).
|
|
1677
|
+
pathfinder_group : str, optional
|
|
1678
|
+
Name for the main pathfinder results group (default is "pathfinder").
|
|
1679
|
+
paths_group : str, optional
|
|
1680
|
+
Name for the per-path results group (default is "pathfinder_paths").
|
|
1681
|
+
diagnostics_group : str, optional
|
|
1682
|
+
Name for the diagnostics group (default is "pathfinder_diagnostics").
|
|
1683
|
+
config_group : str, optional
|
|
1684
|
+
Name for the configuration group (default is "pathfinder_config").
|
|
1661
1685
|
|
|
1662
1686
|
Returns
|
|
1663
1687
|
-------
|
|
@@ -1694,6 +1718,9 @@ def fit_pathfinder(
|
|
|
1694
1718
|
maxcor = np.ceil(3 * np.log(N)).astype(np.int32)
|
|
1695
1719
|
maxcor = max(maxcor, 5)
|
|
1696
1720
|
|
|
1721
|
+
# Handle display_summary logic
|
|
1722
|
+
should_display_summary = display_summary == "auto" or display_summary is True
|
|
1723
|
+
|
|
1697
1724
|
if inference_backend == "pymc":
|
|
1698
1725
|
mp_result = multipath_pathfinder(
|
|
1699
1726
|
model,
|
|
@@ -1714,6 +1741,7 @@ def fit_pathfinder(
|
|
|
1714
1741
|
random_seed=random_seed,
|
|
1715
1742
|
pathfinder_kwargs=pathfinder_kwargs,
|
|
1716
1743
|
compile_kwargs=compile_kwargs,
|
|
1744
|
+
display_summary=should_display_summary,
|
|
1717
1745
|
)
|
|
1718
1746
|
pathfinder_samples = mp_result.samples
|
|
1719
1747
|
elif inference_backend == "blackjax":
|
|
@@ -1760,4 +1788,30 @@ def fit_pathfinder(
|
|
|
1760
1788
|
|
|
1761
1789
|
idata = add_data_to_inference_data(idata, progressbar, model, compile_kwargs)
|
|
1762
1790
|
|
|
1791
|
+
# Add pathfinder results to InferenceData if requested
|
|
1792
|
+
if add_pathfinder_groups:
|
|
1793
|
+
if inference_backend == "pymc":
|
|
1794
|
+
from pymc_extras.inference.pathfinder.idata import add_pathfinder_to_inference_data
|
|
1795
|
+
|
|
1796
|
+
idata = add_pathfinder_to_inference_data(
|
|
1797
|
+
idata=idata,
|
|
1798
|
+
result=mp_result,
|
|
1799
|
+
model=model,
|
|
1800
|
+
group=pathfinder_group,
|
|
1801
|
+
paths_group=paths_group,
|
|
1802
|
+
diagnostics_group=diagnostics_group,
|
|
1803
|
+
config_group=config_group,
|
|
1804
|
+
store_diagnostics=store_diagnostics,
|
|
1805
|
+
)
|
|
1806
|
+
else:
|
|
1807
|
+
warnings.warn(
|
|
1808
|
+
f"Pathfinder diagnostic groups are only supported with the PyMC backend. "
|
|
1809
|
+
f"Current backend is '{inference_backend}', which does not support adding "
|
|
1810
|
+
"pathfinder diagnostics to InferenceData. The InferenceData will only contain "
|
|
1811
|
+
"posterior samples. To add diagnostic groups, use inference_backend='pymc', "
|
|
1812
|
+
"or set add_pathfinder_groups=False to suppress this warning.",
|
|
1813
|
+
UserWarning,
|
|
1814
|
+
stacklevel=2,
|
|
1815
|
+
)
|
|
1816
|
+
|
|
1763
1817
|
return idata
|
|
@@ -6,8 +6,8 @@ from itertools import zip_longest
|
|
|
6
6
|
from pymc import SymbolicRandomVariable
|
|
7
7
|
from pymc.model.fgraph import ModelVar
|
|
8
8
|
from pymc.variational.minibatch_rv import MinibatchRandomVariable
|
|
9
|
-
from pytensor.graph import Variable
|
|
10
|
-
from pytensor.graph.
|
|
9
|
+
from pytensor.graph.basic import Variable
|
|
10
|
+
from pytensor.graph.traversal import ancestors, io_toposort
|
|
11
11
|
from pytensor.tensor import TensorType, TensorVariable
|
|
12
12
|
from pytensor.tensor.blockwise import Blockwise
|
|
13
13
|
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
|