arviz 0.16.1__py3-none-any.whl → 0.17.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arviz/__init__.py +1 -1
- arviz/data/inference_data.py +34 -7
- arviz/data/io_beanmachine.py +6 -1
- arviz/data/io_cmdstanpy.py +439 -50
- arviz/data/io_pyjags.py +5 -2
- arviz/data/io_pystan.py +1 -2
- arviz/labels.py +2 -0
- arviz/plots/backends/bokeh/bpvplot.py +7 -2
- arviz/plots/backends/bokeh/compareplot.py +7 -4
- arviz/plots/backends/bokeh/densityplot.py +0 -1
- arviz/plots/backends/bokeh/distplot.py +0 -2
- arviz/plots/backends/bokeh/forestplot.py +3 -5
- arviz/plots/backends/bokeh/kdeplot.py +0 -2
- arviz/plots/backends/bokeh/pairplot.py +0 -4
- arviz/plots/backends/matplotlib/bfplot.py +0 -1
- arviz/plots/backends/matplotlib/bpvplot.py +3 -3
- arviz/plots/backends/matplotlib/compareplot.py +1 -1
- arviz/plots/backends/matplotlib/dotplot.py +1 -1
- arviz/plots/backends/matplotlib/forestplot.py +2 -4
- arviz/plots/backends/matplotlib/kdeplot.py +0 -1
- arviz/plots/backends/matplotlib/khatplot.py +0 -1
- arviz/plots/backends/matplotlib/lmplot.py +4 -5
- arviz/plots/backends/matplotlib/pairplot.py +0 -1
- arviz/plots/backends/matplotlib/ppcplot.py +8 -5
- arviz/plots/backends/matplotlib/traceplot.py +1 -2
- arviz/plots/bfplot.py +7 -6
- arviz/plots/bpvplot.py +7 -2
- arviz/plots/compareplot.py +2 -2
- arviz/plots/ecdfplot.py +37 -112
- arviz/plots/elpdplot.py +1 -1
- arviz/plots/essplot.py +2 -2
- arviz/plots/kdeplot.py +0 -1
- arviz/plots/pairplot.py +1 -1
- arviz/plots/plot_utils.py +0 -1
- arviz/plots/ppcplot.py +51 -45
- arviz/plots/separationplot.py +0 -1
- arviz/stats/__init__.py +2 -0
- arviz/stats/density_utils.py +2 -2
- arviz/stats/diagnostics.py +2 -3
- arviz/stats/ecdf_utils.py +165 -0
- arviz/stats/stats.py +241 -38
- arviz/stats/stats_utils.py +36 -7
- arviz/tests/base_tests/test_data.py +73 -5
- arviz/tests/base_tests/test_plots_bokeh.py +0 -1
- arviz/tests/base_tests/test_plots_matplotlib.py +24 -1
- arviz/tests/base_tests/test_stats.py +43 -1
- arviz/tests/base_tests/test_stats_ecdf_utils.py +153 -0
- arviz/tests/base_tests/test_stats_utils.py +3 -3
- arviz/tests/external_tests/test_data_beanmachine.py +2 -0
- arviz/tests/external_tests/test_data_numpyro.py +3 -3
- arviz/tests/external_tests/test_data_pyjags.py +3 -1
- arviz/tests/external_tests/test_data_pyro.py +3 -3
- arviz/tests/helpers.py +8 -8
- arviz/utils.py +15 -7
- arviz/wrappers/wrap_pymc.py +1 -1
- {arviz-0.16.1.dist-info → arviz-0.17.1.dist-info}/METADATA +16 -15
- {arviz-0.16.1.dist-info → arviz-0.17.1.dist-info}/RECORD +60 -58
- {arviz-0.16.1.dist-info → arviz-0.17.1.dist-info}/WHEEL +1 -1
- {arviz-0.16.1.dist-info → arviz-0.17.1.dist-info}/LICENSE +0 -0
- {arviz-0.16.1.dist-info → arviz-0.17.1.dist-info}/top_level.txt +0 -0
arviz/data/io_cmdstanpy.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
# pylint: disable=too-many-lines
|
|
1
2
|
"""CmdStanPy-specific conversion code."""
|
|
2
3
|
import logging
|
|
3
4
|
import re
|
|
@@ -72,15 +73,26 @@ class CmdStanPyConverter:
|
|
|
72
73
|
|
|
73
74
|
self.dtypes = dtypes
|
|
74
75
|
|
|
75
|
-
if hasattr(self.posterior, "metadata")
|
|
76
|
+
if hasattr(self.posterior, "metadata") and hasattr(
|
|
77
|
+
self.posterior.metadata, "stan_vars_cols"
|
|
78
|
+
):
|
|
79
|
+
if self.log_likelihood is True and "log_lik" in self.posterior.metadata.stan_vars_cols:
|
|
80
|
+
self.log_likelihood = ["log_lik"]
|
|
81
|
+
elif hasattr(self.posterior, "metadata") and hasattr(
|
|
82
|
+
self.posterior.metadata, "stan_vars_cols"
|
|
83
|
+
):
|
|
76
84
|
if self.log_likelihood is True and "log_lik" in self.posterior.metadata.stan_vars_cols:
|
|
77
85
|
self.log_likelihood = ["log_lik"]
|
|
78
86
|
elif hasattr(self.posterior, "stan_vars_cols"):
|
|
79
87
|
if self.log_likelihood is True and "log_lik" in self.posterior.stan_vars_cols:
|
|
80
88
|
self.log_likelihood = ["log_lik"]
|
|
89
|
+
elif hasattr(self.posterior, "metadata") and hasattr(self.posterior.metadata, "stan_vars"):
|
|
90
|
+
if self.log_likelihood is True and "log_lik" in self.posterior.metadata.stan_vars:
|
|
91
|
+
self.log_likelihood = ["log_lik"]
|
|
81
92
|
elif (
|
|
82
93
|
self.log_likelihood is True
|
|
83
94
|
and self.posterior is not None
|
|
95
|
+
and hasattr(self.posterior, "column_names")
|
|
84
96
|
and any(name.split("[")[0] == "log_lik" for name in self.posterior.column_names)
|
|
85
97
|
):
|
|
86
98
|
self.log_likelihood = ["log_lik"]
|
|
@@ -95,11 +107,17 @@ class CmdStanPyConverter:
|
|
|
95
107
|
"""Extract posterior samples from output csv."""
|
|
96
108
|
if not (hasattr(self.posterior, "metadata") or hasattr(self.posterior, "stan_vars_cols")):
|
|
97
109
|
return self.posterior_to_xarray_pre_v_0_9_68()
|
|
110
|
+
if (
|
|
111
|
+
hasattr(self.posterior, "metadata")
|
|
112
|
+
and hasattr(self.posterior.metadata, "stan_vars_cols")
|
|
113
|
+
) or hasattr(self.posterior, "stan_vars_cols"):
|
|
114
|
+
return self.posterior_to_xarray_pre_v_1_0_0()
|
|
115
|
+
if hasattr(self.posterior, "metadata") and hasattr(
|
|
116
|
+
self.posterior.metadata, "stan_vars_cols"
|
|
117
|
+
):
|
|
118
|
+
return self.posterior_to_xarray_pre_v_1_2_0()
|
|
98
119
|
|
|
99
|
-
|
|
100
|
-
items = list(self.posterior.metadata.stan_vars_cols.keys())
|
|
101
|
-
else:
|
|
102
|
-
items = list(self.posterior.stan_vars_cols.keys())
|
|
120
|
+
items = list(self.posterior.metadata.stan_vars)
|
|
103
121
|
if self.posterior_predictive is not None:
|
|
104
122
|
try:
|
|
105
123
|
items = _filter(items, self.posterior_predictive)
|
|
@@ -119,9 +137,8 @@ class CmdStanPyConverter:
|
|
|
119
137
|
valid_cols = []
|
|
120
138
|
for item in items:
|
|
121
139
|
if hasattr(self.posterior, "metadata"):
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
valid_cols.extend(self.posterior.stan_vars_cols[item])
|
|
140
|
+
if item in self.posterior.metadata.stan_vars:
|
|
141
|
+
valid_cols.append(item)
|
|
125
142
|
|
|
126
143
|
data, data_warmup = _unpack_fit(
|
|
127
144
|
self.posterior,
|
|
@@ -130,7 +147,6 @@ class CmdStanPyConverter:
|
|
|
130
147
|
self.dtypes,
|
|
131
148
|
)
|
|
132
149
|
|
|
133
|
-
# copy dims and coords - Mitzi question: why???
|
|
134
150
|
dims = deepcopy(self.dims) if self.dims is not None else {}
|
|
135
151
|
coords = deepcopy(self.coords) if self.coords is not None else {}
|
|
136
152
|
|
|
@@ -165,6 +181,12 @@ class CmdStanPyConverter:
|
|
|
165
181
|
"""Extract sample_stats from fit."""
|
|
166
182
|
if not (hasattr(fit, "metadata") or hasattr(fit, "sampler_vars_cols")):
|
|
167
183
|
return self.sample_stats_to_xarray_pre_v_0_9_68(fit)
|
|
184
|
+
if (hasattr(fit, "metadata") and hasattr(fit.metadata, "stan_vars_cols")) or hasattr(
|
|
185
|
+
fit, "stan_vars_cols"
|
|
186
|
+
):
|
|
187
|
+
return self.sample_stats_to_xarray_pre_v_1_0_0(fit)
|
|
188
|
+
if hasattr(fit, "metadata") and hasattr(fit.metadata, "stan_vars_cols"):
|
|
189
|
+
return self.sample_stats_to_xarray_pre_v_1_2_0(fit)
|
|
168
190
|
|
|
169
191
|
dtypes = {
|
|
170
192
|
"divergent__": bool,
|
|
@@ -172,10 +194,9 @@ class CmdStanPyConverter:
|
|
|
172
194
|
"treedepth__": np.int64,
|
|
173
195
|
**self.dtypes,
|
|
174
196
|
}
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
items = list(fit.sampler_vars_cols.keys())
|
|
197
|
+
|
|
198
|
+
items = list(fit.method_variables()) # pylint: disable=protected-access
|
|
199
|
+
|
|
179
200
|
rename_dict = {
|
|
180
201
|
"divergent": "diverging",
|
|
181
202
|
"n_leapfrog": "n_steps",
|
|
@@ -196,6 +217,7 @@ class CmdStanPyConverter:
|
|
|
196
217
|
data[name] = data.pop(item).astype(dtypes.get(item, float))
|
|
197
218
|
if data_warmup:
|
|
198
219
|
data_warmup[name] = data_warmup.pop(item).astype(dtypes.get(item, float))
|
|
220
|
+
|
|
199
221
|
return (
|
|
200
222
|
dict_to_dataset(
|
|
201
223
|
data,
|
|
@@ -229,19 +251,35 @@ class CmdStanPyConverter:
|
|
|
229
251
|
"""Convert predictive samples to xarray."""
|
|
230
252
|
predictive = _as_set(names)
|
|
231
253
|
|
|
232
|
-
if hasattr(fit, "metadata") or hasattr(fit, "stan_vars_cols"):
|
|
233
|
-
|
|
254
|
+
if not (hasattr(fit, "metadata") or hasattr(fit, "stan_vars_cols")): # pre_v_0_9_68
|
|
255
|
+
valid_cols = _filter_columns(fit.column_names, predictive)
|
|
256
|
+
data, data_warmup = _unpack_frame(
|
|
257
|
+
fit,
|
|
258
|
+
fit.column_names,
|
|
259
|
+
valid_cols,
|
|
260
|
+
self.save_warmup,
|
|
261
|
+
self.dtypes,
|
|
262
|
+
)
|
|
263
|
+
elif (hasattr(fit, "metadata") and hasattr(fit.metadata, "sample_vars_cols")) or hasattr(
|
|
264
|
+
fit, "stan_vars_cols"
|
|
265
|
+
): # pre_v_1_0_0
|
|
266
|
+
data, data_warmup = _unpack_fit_pre_v_1_0_0(
|
|
234
267
|
fit,
|
|
235
268
|
predictive,
|
|
236
269
|
self.save_warmup,
|
|
237
270
|
self.dtypes,
|
|
238
271
|
)
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
data, data_warmup = _unpack_frame(
|
|
272
|
+
elif hasattr(fit, "metadata") and hasattr(fit.metadata, "stan_vars_cols"): # pre_v_1_2_0
|
|
273
|
+
data, data_warmup = _unpack_fit_pre_v_1_2_0(
|
|
242
274
|
fit,
|
|
243
|
-
|
|
244
|
-
|
|
275
|
+
predictive,
|
|
276
|
+
self.save_warmup,
|
|
277
|
+
self.dtypes,
|
|
278
|
+
)
|
|
279
|
+
else:
|
|
280
|
+
data, data_warmup = _unpack_fit(
|
|
281
|
+
fit,
|
|
282
|
+
predictive,
|
|
245
283
|
self.save_warmup,
|
|
246
284
|
self.dtypes,
|
|
247
285
|
)
|
|
@@ -269,14 +307,9 @@ class CmdStanPyConverter:
|
|
|
269
307
|
"""Convert out of sample predictions samples to xarray."""
|
|
270
308
|
predictions = _as_set(self.predictions)
|
|
271
309
|
|
|
272
|
-
if
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
predictions,
|
|
276
|
-
self.save_warmup,
|
|
277
|
-
self.dtypes,
|
|
278
|
-
)
|
|
279
|
-
else: # pre_v_0_9_68
|
|
310
|
+
if not (
|
|
311
|
+
hasattr(self.posterior, "metadata") or hasattr(self.posterior, "stan_vars_cols")
|
|
312
|
+
): # pre_v_0_9_68
|
|
280
313
|
columns = self.posterior.column_names
|
|
281
314
|
valid_cols = _filter_columns(columns, predictions)
|
|
282
315
|
data, data_warmup = _unpack_frame(
|
|
@@ -286,6 +319,34 @@ class CmdStanPyConverter:
|
|
|
286
319
|
self.save_warmup,
|
|
287
320
|
self.dtypes,
|
|
288
321
|
)
|
|
322
|
+
elif (
|
|
323
|
+
hasattr(self.posterior, "metadata")
|
|
324
|
+
and hasattr(self.posterior.metadata, "sample_vars_cols")
|
|
325
|
+
) or hasattr(
|
|
326
|
+
self.posterior, "stan_vars_cols"
|
|
327
|
+
): # pre_v_1_0_0
|
|
328
|
+
data, data_warmup = _unpack_fit_pre_v_1_0_0(
|
|
329
|
+
self.posterior,
|
|
330
|
+
predictions,
|
|
331
|
+
self.save_warmup,
|
|
332
|
+
self.dtypes,
|
|
333
|
+
)
|
|
334
|
+
elif hasattr(self.posterior, "metadata") and hasattr(
|
|
335
|
+
self.posterior.metadata, "stan_vars_cols"
|
|
336
|
+
): # pre_v_1_2_0
|
|
337
|
+
data, data_warmup = _unpack_fit_pre_v_1_2_0(
|
|
338
|
+
self.posterior,
|
|
339
|
+
predictions,
|
|
340
|
+
self.save_warmup,
|
|
341
|
+
self.dtypes,
|
|
342
|
+
)
|
|
343
|
+
else:
|
|
344
|
+
data, data_warmup = _unpack_fit(
|
|
345
|
+
self.posterior,
|
|
346
|
+
predictions,
|
|
347
|
+
self.save_warmup,
|
|
348
|
+
self.dtypes,
|
|
349
|
+
)
|
|
289
350
|
|
|
290
351
|
return (
|
|
291
352
|
dict_to_dataset(
|
|
@@ -310,14 +371,9 @@ class CmdStanPyConverter:
|
|
|
310
371
|
"""Convert elementwise log likelihood samples to xarray."""
|
|
311
372
|
log_likelihood = _as_set(self.log_likelihood)
|
|
312
373
|
|
|
313
|
-
if
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
log_likelihood,
|
|
317
|
-
self.save_warmup,
|
|
318
|
-
self.dtypes,
|
|
319
|
-
)
|
|
320
|
-
else: # pre_v_0_9_68
|
|
374
|
+
if not (
|
|
375
|
+
hasattr(self.posterior, "metadata") or hasattr(self.posterior, "stan_vars_cols")
|
|
376
|
+
): # pre_v_0_9_68
|
|
321
377
|
columns = self.posterior.column_names
|
|
322
378
|
valid_cols = _filter_columns(columns, log_likelihood)
|
|
323
379
|
data, data_warmup = _unpack_frame(
|
|
@@ -327,6 +383,35 @@ class CmdStanPyConverter:
|
|
|
327
383
|
self.save_warmup,
|
|
328
384
|
self.dtypes,
|
|
329
385
|
)
|
|
386
|
+
elif (
|
|
387
|
+
hasattr(self.posterior, "metadata")
|
|
388
|
+
and hasattr(self.posterior.metadata, "sample_vars_cols")
|
|
389
|
+
) or hasattr(
|
|
390
|
+
self.posterior, "stan_vars_cols"
|
|
391
|
+
): # pre_v_1_0_0
|
|
392
|
+
data, data_warmup = _unpack_fit_pre_v_1_0_0(
|
|
393
|
+
self.posterior,
|
|
394
|
+
log_likelihood,
|
|
395
|
+
self.save_warmup,
|
|
396
|
+
self.dtypes,
|
|
397
|
+
)
|
|
398
|
+
elif hasattr(self.posterior, "metadata") and hasattr(
|
|
399
|
+
self.posterior.metadata, "stan_vars_cols"
|
|
400
|
+
): # pre_v_1_2_0
|
|
401
|
+
data, data_warmup = _unpack_fit_pre_v_1_2_0(
|
|
402
|
+
self.posterior,
|
|
403
|
+
log_likelihood,
|
|
404
|
+
self.save_warmup,
|
|
405
|
+
self.dtypes,
|
|
406
|
+
)
|
|
407
|
+
else:
|
|
408
|
+
data, data_warmup = _unpack_fit(
|
|
409
|
+
self.posterior,
|
|
410
|
+
log_likelihood,
|
|
411
|
+
self.save_warmup,
|
|
412
|
+
self.dtypes,
|
|
413
|
+
)
|
|
414
|
+
|
|
330
415
|
if isinstance(self.log_likelihood, dict):
|
|
331
416
|
data = {obs_name: data[lik_name] for obs_name, lik_name in self.log_likelihood.items()}
|
|
332
417
|
if data_warmup:
|
|
@@ -356,8 +441,29 @@ class CmdStanPyConverter:
|
|
|
356
441
|
@requires("prior")
|
|
357
442
|
def prior_to_xarray(self):
|
|
358
443
|
"""Convert prior samples to xarray."""
|
|
359
|
-
if
|
|
360
|
-
|
|
444
|
+
if not (
|
|
445
|
+
hasattr(self.prior, "metadata") or hasattr(self.prior, "stan_vars_cols")
|
|
446
|
+
): # pre_v_0_9_68
|
|
447
|
+
columns = self.prior.column_names
|
|
448
|
+
prior_predictive = _as_set(self.prior_predictive)
|
|
449
|
+
prior_predictive = _filter_columns(columns, prior_predictive)
|
|
450
|
+
|
|
451
|
+
invalid_cols = set(prior_predictive + [col for col in columns if col.endswith("__")])
|
|
452
|
+
valid_cols = [col for col in columns if col not in invalid_cols]
|
|
453
|
+
|
|
454
|
+
data, data_warmup = _unpack_frame(
|
|
455
|
+
self.prior,
|
|
456
|
+
columns,
|
|
457
|
+
valid_cols,
|
|
458
|
+
self.save_warmup,
|
|
459
|
+
self.dtypes,
|
|
460
|
+
)
|
|
461
|
+
elif (
|
|
462
|
+
hasattr(self.prior, "metadata") and hasattr(self.prior.metadata, "sample_vars_cols")
|
|
463
|
+
) or hasattr(
|
|
464
|
+
self.prior, "stan_vars_cols"
|
|
465
|
+
): # pre_v_1_0_0
|
|
466
|
+
if hasattr(self.prior, "metadata"):
|
|
361
467
|
items = list(self.prior.metadata.stan_vars_cols.keys())
|
|
362
468
|
else:
|
|
363
469
|
items = list(self.prior.stan_vars_cols.keys())
|
|
@@ -366,24 +472,37 @@ class CmdStanPyConverter:
|
|
|
366
472
|
items = _filter(items, self.prior_predictive)
|
|
367
473
|
except ValueError:
|
|
368
474
|
pass
|
|
369
|
-
data, data_warmup =
|
|
475
|
+
data, data_warmup = _unpack_fit_pre_v_1_0_0(
|
|
370
476
|
self.prior,
|
|
371
477
|
items,
|
|
372
478
|
self.save_warmup,
|
|
373
479
|
self.dtypes,
|
|
374
480
|
)
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
481
|
+
elif hasattr(self.prior, "metadata") and hasattr(
|
|
482
|
+
self.prior.metadata, "stan_vars_cols"
|
|
483
|
+
): # pre_v_1_2_0
|
|
484
|
+
items = list(self.prior.metadata.stan_vars_cols.keys())
|
|
485
|
+
if self.prior_predictive is not None:
|
|
486
|
+
try:
|
|
487
|
+
items = _filter(items, self.prior_predictive)
|
|
488
|
+
except ValueError:
|
|
489
|
+
pass
|
|
490
|
+
data, data_warmup = _unpack_fit_pre_v_1_2_0(
|
|
384
491
|
self.prior,
|
|
385
|
-
|
|
386
|
-
|
|
492
|
+
items,
|
|
493
|
+
self.save_warmup,
|
|
494
|
+
self.dtypes,
|
|
495
|
+
)
|
|
496
|
+
else:
|
|
497
|
+
items = list(self.prior.metadata.stan_vars.keys())
|
|
498
|
+
if self.prior_predictive is not None:
|
|
499
|
+
try:
|
|
500
|
+
items = _filter(items, self.prior_predictive)
|
|
501
|
+
except ValueError:
|
|
502
|
+
pass
|
|
503
|
+
data, data_warmup = _unpack_fit(
|
|
504
|
+
self.prior,
|
|
505
|
+
items,
|
|
387
506
|
self.save_warmup,
|
|
388
507
|
self.dtypes,
|
|
389
508
|
)
|
|
@@ -466,6 +585,113 @@ class CmdStanPyConverter:
|
|
|
466
585
|
},
|
|
467
586
|
)
|
|
468
587
|
|
|
588
|
+
def posterior_to_xarray_pre_v_1_2_0(self):
|
|
589
|
+
items = list(self.posterior.metadata.stan_vars_cols)
|
|
590
|
+
if self.posterior_predictive is not None:
|
|
591
|
+
try:
|
|
592
|
+
items = _filter(items, self.posterior_predictive)
|
|
593
|
+
except ValueError:
|
|
594
|
+
pass
|
|
595
|
+
if self.predictions is not None:
|
|
596
|
+
try:
|
|
597
|
+
items = _filter(items, self.predictions)
|
|
598
|
+
except ValueError:
|
|
599
|
+
pass
|
|
600
|
+
if self.log_likelihood is not None:
|
|
601
|
+
try:
|
|
602
|
+
items = _filter(items, self.log_likelihood)
|
|
603
|
+
except ValueError:
|
|
604
|
+
pass
|
|
605
|
+
|
|
606
|
+
valid_cols = []
|
|
607
|
+
for item in items:
|
|
608
|
+
if hasattr(self.posterior, "metadata"):
|
|
609
|
+
if item in self.posterior.metadata.stan_vars_cols:
|
|
610
|
+
valid_cols.append(item)
|
|
611
|
+
|
|
612
|
+
data, data_warmup = _unpack_fit_pre_v_1_2_0(
|
|
613
|
+
self.posterior,
|
|
614
|
+
items,
|
|
615
|
+
self.save_warmup,
|
|
616
|
+
self.dtypes,
|
|
617
|
+
)
|
|
618
|
+
|
|
619
|
+
dims = deepcopy(self.dims) if self.dims is not None else {}
|
|
620
|
+
coords = deepcopy(self.coords) if self.coords is not None else {}
|
|
621
|
+
|
|
622
|
+
return (
|
|
623
|
+
dict_to_dataset(
|
|
624
|
+
data,
|
|
625
|
+
library=self.cmdstanpy,
|
|
626
|
+
coords=coords,
|
|
627
|
+
dims=dims,
|
|
628
|
+
index_origin=self.index_origin,
|
|
629
|
+
),
|
|
630
|
+
dict_to_dataset(
|
|
631
|
+
data_warmup,
|
|
632
|
+
library=self.cmdstanpy,
|
|
633
|
+
coords=coords,
|
|
634
|
+
dims=dims,
|
|
635
|
+
index_origin=self.index_origin,
|
|
636
|
+
),
|
|
637
|
+
)
|
|
638
|
+
|
|
639
|
+
@requires("posterior")
|
|
640
|
+
def posterior_to_xarray_pre_v_1_0_0(self):
|
|
641
|
+
if hasattr(self.posterior, "metadata"):
|
|
642
|
+
items = list(self.posterior.metadata.stan_vars_cols.keys())
|
|
643
|
+
else:
|
|
644
|
+
items = list(self.posterior.stan_vars_cols.keys())
|
|
645
|
+
if self.posterior_predictive is not None:
|
|
646
|
+
try:
|
|
647
|
+
items = _filter(items, self.posterior_predictive)
|
|
648
|
+
except ValueError:
|
|
649
|
+
pass
|
|
650
|
+
if self.predictions is not None:
|
|
651
|
+
try:
|
|
652
|
+
items = _filter(items, self.predictions)
|
|
653
|
+
except ValueError:
|
|
654
|
+
pass
|
|
655
|
+
if self.log_likelihood is not None:
|
|
656
|
+
try:
|
|
657
|
+
items = _filter(items, self.log_likelihood)
|
|
658
|
+
except ValueError:
|
|
659
|
+
pass
|
|
660
|
+
|
|
661
|
+
valid_cols = []
|
|
662
|
+
for item in items:
|
|
663
|
+
if hasattr(self.posterior, "metadata"):
|
|
664
|
+
valid_cols.extend(self.posterior.metadata.stan_vars_cols[item])
|
|
665
|
+
else:
|
|
666
|
+
valid_cols.extend(self.posterior.stan_vars_cols[item])
|
|
667
|
+
|
|
668
|
+
data, data_warmup = _unpack_fit_pre_v_1_0_0(
|
|
669
|
+
self.posterior,
|
|
670
|
+
items,
|
|
671
|
+
self.save_warmup,
|
|
672
|
+
self.dtypes,
|
|
673
|
+
)
|
|
674
|
+
|
|
675
|
+
dims = deepcopy(self.dims) if self.dims is not None else {}
|
|
676
|
+
coords = deepcopy(self.coords) if self.coords is not None else {}
|
|
677
|
+
|
|
678
|
+
return (
|
|
679
|
+
dict_to_dataset(
|
|
680
|
+
data,
|
|
681
|
+
library=self.cmdstanpy,
|
|
682
|
+
coords=coords,
|
|
683
|
+
dims=dims,
|
|
684
|
+
index_origin=self.index_origin,
|
|
685
|
+
),
|
|
686
|
+
dict_to_dataset(
|
|
687
|
+
data_warmup,
|
|
688
|
+
library=self.cmdstanpy,
|
|
689
|
+
coords=coords,
|
|
690
|
+
dims=dims,
|
|
691
|
+
index_origin=self.index_origin,
|
|
692
|
+
),
|
|
693
|
+
)
|
|
694
|
+
|
|
469
695
|
@requires("posterior")
|
|
470
696
|
def posterior_to_xarray_pre_v_0_9_68(self):
|
|
471
697
|
"""Extract posterior samples from output csv."""
|
|
@@ -544,6 +770,103 @@ class CmdStanPyConverter:
|
|
|
544
770
|
),
|
|
545
771
|
)
|
|
546
772
|
|
|
773
|
+
def sample_stats_to_xarray_pre_v_1_2_0(self, fit):
|
|
774
|
+
dtypes = {
|
|
775
|
+
"divergent__": bool,
|
|
776
|
+
"n_leapfrog__": np.int64,
|
|
777
|
+
"treedepth__": np.int64,
|
|
778
|
+
**self.dtypes,
|
|
779
|
+
}
|
|
780
|
+
|
|
781
|
+
items = list(fit.metadata.method_vars_cols.keys()) # pylint: disable=protected-access
|
|
782
|
+
|
|
783
|
+
rename_dict = {
|
|
784
|
+
"divergent": "diverging",
|
|
785
|
+
"n_leapfrog": "n_steps",
|
|
786
|
+
"treedepth": "tree_depth",
|
|
787
|
+
"stepsize": "step_size",
|
|
788
|
+
"accept_stat": "acceptance_rate",
|
|
789
|
+
}
|
|
790
|
+
|
|
791
|
+
data, data_warmup = _unpack_fit_pre_v_1_2_0(
|
|
792
|
+
fit,
|
|
793
|
+
items,
|
|
794
|
+
self.save_warmup,
|
|
795
|
+
self.dtypes,
|
|
796
|
+
)
|
|
797
|
+
for item in items:
|
|
798
|
+
name = re.sub("__$", "", item)
|
|
799
|
+
name = rename_dict.get(name, name)
|
|
800
|
+
data[name] = data.pop(item).astype(dtypes.get(item, float))
|
|
801
|
+
if data_warmup:
|
|
802
|
+
data_warmup[name] = data_warmup.pop(item).astype(dtypes.get(item, float))
|
|
803
|
+
|
|
804
|
+
return (
|
|
805
|
+
dict_to_dataset(
|
|
806
|
+
data,
|
|
807
|
+
library=self.cmdstanpy,
|
|
808
|
+
coords=self.coords,
|
|
809
|
+
dims=self.dims,
|
|
810
|
+
index_origin=self.index_origin,
|
|
811
|
+
),
|
|
812
|
+
dict_to_dataset(
|
|
813
|
+
data_warmup,
|
|
814
|
+
library=self.cmdstanpy,
|
|
815
|
+
coords=self.coords,
|
|
816
|
+
dims=self.dims,
|
|
817
|
+
index_origin=self.index_origin,
|
|
818
|
+
),
|
|
819
|
+
)
|
|
820
|
+
|
|
821
|
+
def sample_stats_to_xarray_pre_v_1_0_0(self, fit):
|
|
822
|
+
"""Extract sample_stats from fit."""
|
|
823
|
+
dtypes = {
|
|
824
|
+
"divergent__": bool,
|
|
825
|
+
"n_leapfrog__": np.int64,
|
|
826
|
+
"treedepth__": np.int64,
|
|
827
|
+
**self.dtypes,
|
|
828
|
+
}
|
|
829
|
+
if hasattr(fit, "metadata"):
|
|
830
|
+
items = list(fit.metadata._method_vars_cols.keys()) # pylint: disable=protected-access
|
|
831
|
+
else:
|
|
832
|
+
items = list(fit.sampler_vars_cols.keys())
|
|
833
|
+
rename_dict = {
|
|
834
|
+
"divergent": "diverging",
|
|
835
|
+
"n_leapfrog": "n_steps",
|
|
836
|
+
"treedepth": "tree_depth",
|
|
837
|
+
"stepsize": "step_size",
|
|
838
|
+
"accept_stat": "acceptance_rate",
|
|
839
|
+
}
|
|
840
|
+
|
|
841
|
+
data, data_warmup = _unpack_fit_pre_v_1_0_0(
|
|
842
|
+
fit,
|
|
843
|
+
items,
|
|
844
|
+
self.save_warmup,
|
|
845
|
+
self.dtypes,
|
|
846
|
+
)
|
|
847
|
+
for item in items:
|
|
848
|
+
name = re.sub("__$", "", item)
|
|
849
|
+
name = rename_dict.get(name, name)
|
|
850
|
+
data[name] = data.pop(item).astype(dtypes.get(item, float))
|
|
851
|
+
if data_warmup:
|
|
852
|
+
data_warmup[name] = data_warmup.pop(item).astype(dtypes.get(item, float))
|
|
853
|
+
return (
|
|
854
|
+
dict_to_dataset(
|
|
855
|
+
data,
|
|
856
|
+
library=self.cmdstanpy,
|
|
857
|
+
coords=self.coords,
|
|
858
|
+
dims=self.dims,
|
|
859
|
+
index_origin=self.index_origin,
|
|
860
|
+
),
|
|
861
|
+
dict_to_dataset(
|
|
862
|
+
data_warmup,
|
|
863
|
+
library=self.cmdstanpy,
|
|
864
|
+
coords=self.coords,
|
|
865
|
+
dims=self.dims,
|
|
866
|
+
index_origin=self.index_origin,
|
|
867
|
+
),
|
|
868
|
+
)
|
|
869
|
+
|
|
547
870
|
def sample_stats_to_xarray_pre_v_0_9_68(self, fit):
|
|
548
871
|
"""Extract sample_stats from fit."""
|
|
549
872
|
dtypes = {"divergent__": bool, "n_leapfrog__": np.int64, "treedepth__": np.int64}
|
|
@@ -612,6 +935,72 @@ def _filter_columns(columns, spec):
|
|
|
612
935
|
|
|
613
936
|
|
|
614
937
|
def _unpack_fit(fit, items, save_warmup, dtypes):
|
|
938
|
+
num_warmup = 0
|
|
939
|
+
if save_warmup:
|
|
940
|
+
if not fit._save_warmup: # pylint: disable=protected-access
|
|
941
|
+
save_warmup = False
|
|
942
|
+
else:
|
|
943
|
+
num_warmup = fit.num_draws_warmup
|
|
944
|
+
|
|
945
|
+
nchains = fit.chains
|
|
946
|
+
sample = {}
|
|
947
|
+
sample_warmup = {}
|
|
948
|
+
stan_vars_cols = list(fit.metadata.stan_vars)
|
|
949
|
+
sampler_vars = fit.method_variables()
|
|
950
|
+
for item in items:
|
|
951
|
+
if item in stan_vars_cols:
|
|
952
|
+
raw_draws = fit.stan_variable(item, inc_warmup=save_warmup)
|
|
953
|
+
raw_draws = np.swapaxes(
|
|
954
|
+
raw_draws.reshape((-1, nchains, *raw_draws.shape[1:]), order="F"), 0, 1
|
|
955
|
+
)
|
|
956
|
+
elif item in sampler_vars:
|
|
957
|
+
raw_draws = np.swapaxes(sampler_vars[item], 0, 1)
|
|
958
|
+
else:
|
|
959
|
+
raise ValueError(f"fit data, unknown variable: {item}")
|
|
960
|
+
raw_draws = raw_draws.astype(dtypes.get(item))
|
|
961
|
+
if save_warmup:
|
|
962
|
+
sample_warmup[item] = raw_draws[:, :num_warmup, ...]
|
|
963
|
+
sample[item] = raw_draws[:, num_warmup:, ...]
|
|
964
|
+
else:
|
|
965
|
+
sample[item] = raw_draws
|
|
966
|
+
|
|
967
|
+
return sample, sample_warmup
|
|
968
|
+
|
|
969
|
+
|
|
970
|
+
def _unpack_fit_pre_v_1_2_0(fit, items, save_warmup, dtypes):
|
|
971
|
+
num_warmup = 0
|
|
972
|
+
if save_warmup:
|
|
973
|
+
if not fit._save_warmup: # pylint: disable=protected-access
|
|
974
|
+
save_warmup = False
|
|
975
|
+
else:
|
|
976
|
+
num_warmup = fit.num_draws_warmup
|
|
977
|
+
|
|
978
|
+
nchains = fit.chains
|
|
979
|
+
sample = {}
|
|
980
|
+
sample_warmup = {}
|
|
981
|
+
stan_vars_cols = list(fit.metadata.stan_vars_cols)
|
|
982
|
+
sampler_vars = fit.method_variables()
|
|
983
|
+
for item in items:
|
|
984
|
+
if item in stan_vars_cols:
|
|
985
|
+
raw_draws = fit.stan_variable(item, inc_warmup=save_warmup)
|
|
986
|
+
raw_draws = np.swapaxes(
|
|
987
|
+
raw_draws.reshape((-1, nchains, *raw_draws.shape[1:]), order="F"), 0, 1
|
|
988
|
+
)
|
|
989
|
+
elif item in sampler_vars:
|
|
990
|
+
raw_draws = np.swapaxes(sampler_vars[item], 0, 1)
|
|
991
|
+
else:
|
|
992
|
+
raise ValueError(f"fit data, unknown variable: {item}")
|
|
993
|
+
raw_draws = raw_draws.astype(dtypes.get(item))
|
|
994
|
+
if save_warmup:
|
|
995
|
+
sample_warmup[item] = raw_draws[:, :num_warmup, ...]
|
|
996
|
+
sample[item] = raw_draws[:, num_warmup:, ...]
|
|
997
|
+
else:
|
|
998
|
+
sample[item] = raw_draws
|
|
999
|
+
|
|
1000
|
+
return sample, sample_warmup
|
|
1001
|
+
|
|
1002
|
+
|
|
1003
|
+
def _unpack_fit_pre_v_1_0_0(fit, items, save_warmup, dtypes):
|
|
615
1004
|
"""Transform fit to dictionary containing ndarrays.
|
|
616
1005
|
|
|
617
1006
|
Parameters
|
arviz/data/io_pyjags.py
CHANGED
|
@@ -52,7 +52,7 @@ class PyJAGSConverter:
|
|
|
52
52
|
self.save_warmup = rcParams["data.save_warmup"] if save_warmup is None else save_warmup
|
|
53
53
|
self.warmup_iterations = warmup_iterations
|
|
54
54
|
|
|
55
|
-
import pyjags
|
|
55
|
+
import pyjags # pylint: disable=import-error
|
|
56
56
|
|
|
57
57
|
self.pyjags = pyjags
|
|
58
58
|
|
|
@@ -149,7 +149,10 @@ def get_draws(
|
|
|
149
149
|
variables = tuple(variables)
|
|
150
150
|
|
|
151
151
|
if warmup_iterations > 0:
|
|
152
|
-
(
|
|
152
|
+
(
|
|
153
|
+
warmup_samples,
|
|
154
|
+
actual_samples,
|
|
155
|
+
) = _split_pyjags_dict_in_warmup_and_actual_samples(
|
|
153
156
|
pyjags_samples=pyjags_samples,
|
|
154
157
|
warmup_iterations=warmup_iterations,
|
|
155
158
|
variable_names=variables,
|
arviz/data/io_pystan.py
CHANGED
|
@@ -676,8 +676,7 @@ def get_draws(fit, variables=None, ignore=None, warmup=False, dtypes=None):
|
|
|
676
676
|
for item in par_keys:
|
|
677
677
|
_, shape = item.replace("]", "").split("[")
|
|
678
678
|
shape_idx_min = min(int(shape_value) for shape_value in shape.split(","))
|
|
679
|
-
|
|
680
|
-
shift = shape_idx_min
|
|
679
|
+
shift = min(shift, shape_idx_min)
|
|
681
680
|
# If shift is higher than 1, this will probably mean that Stan
|
|
682
681
|
# has implemented sparse structure (saves only non-zero parts),
|
|
683
682
|
# but let's hope that dims are still corresponding to the full shape
|
arviz/labels.py
CHANGED
|
@@ -100,6 +100,8 @@ class BaseLabeller:
|
|
|
100
100
|
"""WIP."""
|
|
101
101
|
var_name_str = self.var_name_to_str(var_name)
|
|
102
102
|
pp_var_name_str = self.var_name_to_str(pp_var_name)
|
|
103
|
+
if var_name_str == pp_var_name_str:
|
|
104
|
+
return f"{var_name_str}"
|
|
103
105
|
return f"{var_name_str} / {pp_var_name_str}"
|
|
104
106
|
|
|
105
107
|
def model_name_to_str(self, model_name):
|
|
@@ -171,8 +171,13 @@ def plot_bpv(
|
|
|
171
171
|
ax_i.line(0, 0, legend_label=f"bpv={p_value:.2f}", alpha=0)
|
|
172
172
|
|
|
173
173
|
if plot_mean:
|
|
174
|
-
ax_i.
|
|
175
|
-
obs_vals.mean(),
|
|
174
|
+
ax_i.scatter(
|
|
175
|
+
obs_vals.mean(),
|
|
176
|
+
0,
|
|
177
|
+
fill_color=color,
|
|
178
|
+
line_color="black",
|
|
179
|
+
size=markersize,
|
|
180
|
+
marker="circle",
|
|
176
181
|
)
|
|
177
182
|
|
|
178
183
|
_title = Title()
|