arviz 0.16.1__py3-none-any.whl → 0.17.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. arviz/__init__.py +1 -1
  2. arviz/data/inference_data.py +34 -7
  3. arviz/data/io_beanmachine.py +6 -1
  4. arviz/data/io_cmdstanpy.py +439 -50
  5. arviz/data/io_pyjags.py +5 -2
  6. arviz/data/io_pystan.py +1 -2
  7. arviz/labels.py +2 -0
  8. arviz/plots/backends/bokeh/bpvplot.py +7 -2
  9. arviz/plots/backends/bokeh/compareplot.py +7 -4
  10. arviz/plots/backends/bokeh/densityplot.py +0 -1
  11. arviz/plots/backends/bokeh/distplot.py +0 -2
  12. arviz/plots/backends/bokeh/forestplot.py +3 -5
  13. arviz/plots/backends/bokeh/kdeplot.py +0 -2
  14. arviz/plots/backends/bokeh/pairplot.py +0 -4
  15. arviz/plots/backends/matplotlib/bfplot.py +0 -1
  16. arviz/plots/backends/matplotlib/bpvplot.py +3 -3
  17. arviz/plots/backends/matplotlib/compareplot.py +1 -1
  18. arviz/plots/backends/matplotlib/dotplot.py +1 -1
  19. arviz/plots/backends/matplotlib/forestplot.py +2 -4
  20. arviz/plots/backends/matplotlib/kdeplot.py +0 -1
  21. arviz/plots/backends/matplotlib/khatplot.py +0 -1
  22. arviz/plots/backends/matplotlib/lmplot.py +4 -5
  23. arviz/plots/backends/matplotlib/pairplot.py +0 -1
  24. arviz/plots/backends/matplotlib/ppcplot.py +8 -5
  25. arviz/plots/backends/matplotlib/traceplot.py +1 -2
  26. arviz/plots/bfplot.py +7 -6
  27. arviz/plots/bpvplot.py +7 -2
  28. arviz/plots/compareplot.py +2 -2
  29. arviz/plots/ecdfplot.py +37 -112
  30. arviz/plots/elpdplot.py +1 -1
  31. arviz/plots/essplot.py +2 -2
  32. arviz/plots/kdeplot.py +0 -1
  33. arviz/plots/pairplot.py +1 -1
  34. arviz/plots/plot_utils.py +0 -1
  35. arviz/plots/ppcplot.py +51 -45
  36. arviz/plots/separationplot.py +0 -1
  37. arviz/stats/__init__.py +2 -0
  38. arviz/stats/density_utils.py +2 -2
  39. arviz/stats/diagnostics.py +2 -3
  40. arviz/stats/ecdf_utils.py +165 -0
  41. arviz/stats/stats.py +241 -38
  42. arviz/stats/stats_utils.py +36 -7
  43. arviz/tests/base_tests/test_data.py +73 -5
  44. arviz/tests/base_tests/test_plots_bokeh.py +0 -1
  45. arviz/tests/base_tests/test_plots_matplotlib.py +24 -1
  46. arviz/tests/base_tests/test_stats.py +43 -1
  47. arviz/tests/base_tests/test_stats_ecdf_utils.py +153 -0
  48. arviz/tests/base_tests/test_stats_utils.py +3 -3
  49. arviz/tests/external_tests/test_data_beanmachine.py +2 -0
  50. arviz/tests/external_tests/test_data_numpyro.py +3 -3
  51. arviz/tests/external_tests/test_data_pyjags.py +3 -1
  52. arviz/tests/external_tests/test_data_pyro.py +3 -3
  53. arviz/tests/helpers.py +8 -8
  54. arviz/utils.py +15 -7
  55. arviz/wrappers/wrap_pymc.py +1 -1
  56. {arviz-0.16.1.dist-info → arviz-0.17.1.dist-info}/METADATA +16 -15
  57. {arviz-0.16.1.dist-info → arviz-0.17.1.dist-info}/RECORD +60 -58
  58. {arviz-0.16.1.dist-info → arviz-0.17.1.dist-info}/WHEEL +1 -1
  59. {arviz-0.16.1.dist-info → arviz-0.17.1.dist-info}/LICENSE +0 -0
  60. {arviz-0.16.1.dist-info → arviz-0.17.1.dist-info}/top_level.txt +0 -0
arviz/data/io_cmdstanpy.py CHANGED
@@ -1,3 +1,4 @@
+ # pylint: disable=too-many-lines
  """CmdStanPy-specific conversion code."""
  import logging
  import re
@@ -72,15 +73,26 @@ class CmdStanPyConverter:

          self.dtypes = dtypes

-         if hasattr(self.posterior, "metadata"):
+         if hasattr(self.posterior, "metadata") and hasattr(
+             self.posterior.metadata, "stan_vars_cols"
+         ):
+             if self.log_likelihood is True and "log_lik" in self.posterior.metadata.stan_vars_cols:
+                 self.log_likelihood = ["log_lik"]
+         elif hasattr(self.posterior, "metadata") and hasattr(
+             self.posterior.metadata, "stan_vars_cols"
+         ):
              if self.log_likelihood is True and "log_lik" in self.posterior.metadata.stan_vars_cols:
                  self.log_likelihood = ["log_lik"]
          elif hasattr(self.posterior, "stan_vars_cols"):
              if self.log_likelihood is True and "log_lik" in self.posterior.stan_vars_cols:
                  self.log_likelihood = ["log_lik"]
+         elif hasattr(self.posterior, "metadata") and hasattr(self.posterior.metadata, "stan_vars"):
+             if self.log_likelihood is True and "log_lik" in self.posterior.metadata.stan_vars:
+                 self.log_likelihood = ["log_lik"]
          elif (
              self.log_likelihood is True
              and self.posterior is not None
+             and hasattr(self.posterior, "column_names")
              and any(name.split("[")[0] == "log_lik" for name in self.posterior.column_names)
          ):
              self.log_likelihood = ["log_lik"]
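Note on the pattern above: the converter feature-detects the cmdstanpy API generation rather than comparing version numbers. Reading the branches together with the `pre_v_*` helper names used later in this file: `stan_vars_cols` directly on the fit is the oldest metadata layout, `fit.metadata.stan_vars_cols` is the pre-1.2 layout, `fit.metadata.stan_vars` is the cmdstanpy 1.2+ layout, and a fit exposing only `column_names` predates 0.9.68. A minimal sketch of that dispatch, with a hypothetical helper name:

    # Hypothetical sketch of the feature detection used throughout this
    # module; `fit` stands in for a cmdstanpy CmdStanMCMC-like object.
    def detect_cmdstanpy_generation(fit):
        if hasattr(fit, "metadata") and hasattr(fit.metadata, "stan_vars"):
            return "v1.2+"  # stan_vars replaced stan_vars_cols
        if hasattr(fit, "stan_vars_cols"):
            return "pre_v_1_0_0"  # columns hang directly off the fit
        if hasattr(fit, "metadata") and hasattr(fit.metadata, "stan_vars_cols"):
            return "pre_v_1_2_0"  # columns moved onto fit.metadata
        return "pre_v_0_9_68"  # only raw CSV column names are available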
@@ -95,11 +107,17 @@ class CmdStanPyConverter:
          """Extract posterior samples from output csv."""
          if not (hasattr(self.posterior, "metadata") or hasattr(self.posterior, "stan_vars_cols")):
              return self.posterior_to_xarray_pre_v_0_9_68()
+         if (
+             hasattr(self.posterior, "metadata")
+             and hasattr(self.posterior.metadata, "stan_vars_cols")
+         ) or hasattr(self.posterior, "stan_vars_cols"):
+             return self.posterior_to_xarray_pre_v_1_0_0()
+         if hasattr(self.posterior, "metadata") and hasattr(
+             self.posterior.metadata, "stan_vars_cols"
+         ):
+             return self.posterior_to_xarray_pre_v_1_2_0()

-         if hasattr(self.posterior, "metadata"):
-             items = list(self.posterior.metadata.stan_vars_cols.keys())
-         else:
-             items = list(self.posterior.stan_vars_cols.keys())
+         items = list(self.posterior.metadata.stan_vars)
          if self.posterior_predictive is not None:
              try:
                  items = _filter(items, self.posterior_predictive)
@@ -119,9 +137,8 @@ class CmdStanPyConverter:
          valid_cols = []
          for item in items:
              if hasattr(self.posterior, "metadata"):
-                 valid_cols.extend(self.posterior.metadata.stan_vars_cols[item])
-             else:
-                 valid_cols.extend(self.posterior.stan_vars_cols[item])
+                 if item in self.posterior.metadata.stan_vars:
+                     valid_cols.append(item)

          data, data_warmup = _unpack_fit(
              self.posterior,
@@ -130,7 +147,6 @@ class CmdStanPyConverter:
              self.dtypes,
          )

-         # copy dims and coords - Mitzi question: why???
          dims = deepcopy(self.dims) if self.dims is not None else {}
          coords = deepcopy(self.coords) if self.coords is not None else {}

@@ -165,6 +181,12 @@ class CmdStanPyConverter:
          """Extract sample_stats from fit."""
          if not (hasattr(fit, "metadata") or hasattr(fit, "sampler_vars_cols")):
              return self.sample_stats_to_xarray_pre_v_0_9_68(fit)
+         if (hasattr(fit, "metadata") and hasattr(fit.metadata, "stan_vars_cols")) or hasattr(
+             fit, "stan_vars_cols"
+         ):
+             return self.sample_stats_to_xarray_pre_v_1_0_0(fit)
+         if hasattr(fit, "metadata") and hasattr(fit.metadata, "stan_vars_cols"):
+             return self.sample_stats_to_xarray_pre_v_1_2_0(fit)

          dtypes = {
              "divergent__": bool,
@@ -172,10 +194,9 @@ class CmdStanPyConverter:
              "treedepth__": np.int64,
              **self.dtypes,
          }
-         if hasattr(fit, "metadata"):
-             items = list(fit.metadata._method_vars_cols.keys())  # pylint: disable=protected-access
-         else:
-             items = list(fit.sampler_vars_cols.keys())
+
+         items = list(fit.method_variables())  # pylint: disable=protected-access
+
          rename_dict = {
              "divergent": "diverging",
              "n_leapfrog": "n_steps",
@@ -196,6 +217,7 @@ class CmdStanPyConverter:
              data[name] = data.pop(item).astype(dtypes.get(item, float))
              if data_warmup:
                  data_warmup[name] = data_warmup.pop(item).astype(dtypes.get(item, float))
+
          return (
              dict_to_dataset(
                  data,
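Note: the loop above normalizes CmdStan's sampler-stat names to ArviZ's conventions by stripping the trailing "__" and then consulting `rename_dict`. A self-contained illustration of that renaming:

    # Illustration of the sampler-stat renaming in sample_stats_to_xarray.
    import re

    rename_dict = {
        "divergent": "diverging",
        "n_leapfrog": "n_steps",
        "treedepth": "tree_depth",
        "stepsize": "step_size",
        "accept_stat": "acceptance_rate",
    }

    for item in ["divergent__", "accept_stat__", "lp__"]:
        name = re.sub("__$", "", item)  # "divergent__" -> "divergent"
        name = rename_dict.get(name, name)  # map to the ArviZ name if known
        print(item, "->", name)
    # divergent__ -> diverging
    # accept_stat__ -> acceptance_rate
    # lp__ -> lp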
@@ -229,19 +251,35 @@ class CmdStanPyConverter:
          """Convert predictive samples to xarray."""
          predictive = _as_set(names)

-         if hasattr(fit, "metadata") or hasattr(fit, "stan_vars_cols"):
-             data, data_warmup = _unpack_fit(
+         if not (hasattr(fit, "metadata") or hasattr(fit, "stan_vars_cols")):  # pre_v_0_9_68
+             valid_cols = _filter_columns(fit.column_names, predictive)
+             data, data_warmup = _unpack_frame(
+                 fit,
+                 fit.column_names,
+                 valid_cols,
+                 self.save_warmup,
+                 self.dtypes,
+             )
+         elif (hasattr(fit, "metadata") and hasattr(fit.metadata, "sample_vars_cols")) or hasattr(
+             fit, "stan_vars_cols"
+         ):  # pre_v_1_0_0
+             data, data_warmup = _unpack_fit_pre_v_1_0_0(
                  fit,
                  predictive,
                  self.save_warmup,
                  self.dtypes,
              )
-         else:  # pre_v_0_9_68
-             valid_cols = _filter_columns(fit.column_names, predictive)
-             data, data_warmup = _unpack_frame(
+         elif hasattr(fit, "metadata") and hasattr(fit.metadata, "stan_vars_cols"):  # pre_v_1_2_0
+             data, data_warmup = _unpack_fit_pre_v_1_2_0(
                  fit,
-                 fit.column_names,
-                 valid_cols,
+                 predictive,
+                 self.save_warmup,
+                 self.dtypes,
+             )
+         else:
+             data, data_warmup = _unpack_fit(
+                 fit,
+                 predictive,
                  self.save_warmup,
                  self.dtypes,
              )
@@ -269,14 +307,9 @@ class CmdStanPyConverter:
          """Convert out of sample predictions samples to xarray."""
          predictions = _as_set(self.predictions)

-         if hasattr(self.posterior, "metadata") or hasattr(self.posterior, "stan_vars_cols"):
-             data, data_warmup = _unpack_fit(
-                 self.posterior,
-                 predictions,
-                 self.save_warmup,
-                 self.dtypes,
-             )
-         else:  # pre_v_0_9_68
+         if not (
+             hasattr(self.posterior, "metadata") or hasattr(self.posterior, "stan_vars_cols")
+         ):  # pre_v_0_9_68
              columns = self.posterior.column_names
              valid_cols = _filter_columns(columns, predictions)
              data, data_warmup = _unpack_frame(
@@ -286,6 +319,34 @@ class CmdStanPyConverter:
                  self.save_warmup,
                  self.dtypes,
              )
+         elif (
+             hasattr(self.posterior, "metadata")
+             and hasattr(self.posterior.metadata, "sample_vars_cols")
+         ) or hasattr(
+             self.posterior, "stan_vars_cols"
+         ):  # pre_v_1_0_0
+             data, data_warmup = _unpack_fit_pre_v_1_0_0(
+                 self.posterior,
+                 predictions,
+                 self.save_warmup,
+                 self.dtypes,
+             )
+         elif hasattr(self.posterior, "metadata") and hasattr(
+             self.posterior.metadata, "stan_vars_cols"
+         ):  # pre_v_1_2_0
+             data, data_warmup = _unpack_fit_pre_v_1_2_0(
+                 self.posterior,
+                 predictions,
+                 self.save_warmup,
+                 self.dtypes,
+             )
+         else:
+             data, data_warmup = _unpack_fit(
+                 self.posterior,
+                 predictions,
+                 self.save_warmup,
+                 self.dtypes,
+             )

          return (
              dict_to_dataset(
@@ -310,14 +371,9 @@ class CmdStanPyConverter:
          """Convert elementwise log likelihood samples to xarray."""
          log_likelihood = _as_set(self.log_likelihood)

-         if hasattr(self.posterior, "metadata") or hasattr(self.posterior, "stan_vars_cols"):
-             data, data_warmup = _unpack_fit(
-                 self.posterior,
-                 log_likelihood,
-                 self.save_warmup,
-                 self.dtypes,
-             )
-         else:  # pre_v_0_9_68
+         if not (
+             hasattr(self.posterior, "metadata") or hasattr(self.posterior, "stan_vars_cols")
+         ):  # pre_v_0_9_68
              columns = self.posterior.column_names
              valid_cols = _filter_columns(columns, log_likelihood)
              data, data_warmup = _unpack_frame(
@@ -327,6 +383,35 @@ class CmdStanPyConverter:
                  self.save_warmup,
                  self.dtypes,
              )
+         elif (
+             hasattr(self.posterior, "metadata")
+             and hasattr(self.posterior.metadata, "sample_vars_cols")
+         ) or hasattr(
+             self.posterior, "stan_vars_cols"
+         ):  # pre_v_1_0_0
+             data, data_warmup = _unpack_fit_pre_v_1_0_0(
+                 self.posterior,
+                 log_likelihood,
+                 self.save_warmup,
+                 self.dtypes,
+             )
+         elif hasattr(self.posterior, "metadata") and hasattr(
+             self.posterior.metadata, "stan_vars_cols"
+         ):  # pre_v_1_2_0
+             data, data_warmup = _unpack_fit_pre_v_1_2_0(
+                 self.posterior,
+                 log_likelihood,
+                 self.save_warmup,
+                 self.dtypes,
+             )
+         else:
+             data, data_warmup = _unpack_fit(
+                 self.posterior,
+                 log_likelihood,
+                 self.save_warmup,
+                 self.dtypes,
+             )
+
          if isinstance(self.log_likelihood, dict):
              data = {obs_name: data[lik_name] for obs_name, lik_name in self.log_likelihood.items()}
              if data_warmup:
@@ -356,8 +441,29 @@ class CmdStanPyConverter:
      @requires("prior")
      def prior_to_xarray(self):
          """Convert prior samples to xarray."""
-         if hasattr(self.posterior, "metadata") or hasattr(self.prior, "stan_vars_cols"):
-             if hasattr(self.posterior, "metadata"):
+         if not (
+             hasattr(self.prior, "metadata") or hasattr(self.prior, "stan_vars_cols")
+         ):  # pre_v_0_9_68
+             columns = self.prior.column_names
+             prior_predictive = _as_set(self.prior_predictive)
+             prior_predictive = _filter_columns(columns, prior_predictive)
+
+             invalid_cols = set(prior_predictive + [col for col in columns if col.endswith("__")])
+             valid_cols = [col for col in columns if col not in invalid_cols]
+
+             data, data_warmup = _unpack_frame(
+                 self.prior,
+                 columns,
+                 valid_cols,
+                 self.save_warmup,
+                 self.dtypes,
+             )
+         elif (
+             hasattr(self.prior, "metadata") and hasattr(self.prior.metadata, "sample_vars_cols")
+         ) or hasattr(
+             self.prior, "stan_vars_cols"
+         ):  # pre_v_1_0_0
+             if hasattr(self.prior, "metadata"):
                  items = list(self.prior.metadata.stan_vars_cols.keys())
              else:
                  items = list(self.prior.stan_vars_cols.keys())
@@ -366,24 +472,37 @@ class CmdStanPyConverter:
                      items = _filter(items, self.prior_predictive)
                  except ValueError:
                      pass
-             data, data_warmup = _unpack_fit(
+             data, data_warmup = _unpack_fit_pre_v_1_0_0(
                  self.prior,
                  items,
                  self.save_warmup,
                  self.dtypes,
              )
-         else:  # pre_v_0_9_68
-             columns = self.prior.column_names
-             prior_predictive = _as_set(self.prior_predictive)
-             prior_predictive = _filter_columns(columns, prior_predictive)
-
-             invalid_cols = set(prior_predictive + [col for col in columns if col.endswith("__")])
-             valid_cols = [col for col in columns if col not in invalid_cols]
-
-             data, data_warmup = _unpack_frame(
+         elif hasattr(self.prior, "metadata") and hasattr(
+             self.prior.metadata, "stan_vars_cols"
+         ):  # pre_v_1_2_0
+             items = list(self.prior.metadata.stan_vars_cols.keys())
+             if self.prior_predictive is not None:
+                 try:
+                     items = _filter(items, self.prior_predictive)
+                 except ValueError:
+                     pass
+             data, data_warmup = _unpack_fit_pre_v_1_2_0(
                  self.prior,
-                 columns,
-                 valid_cols,
+                 items,
+                 self.save_warmup,
+                 self.dtypes,
+             )
+         else:
+             items = list(self.prior.metadata.stan_vars.keys())
+             if self.prior_predictive is not None:
+                 try:
+                     items = _filter(items, self.prior_predictive)
+                 except ValueError:
+                     pass
+             data, data_warmup = _unpack_fit(
+                 self.prior,
+                 items,
                  self.save_warmup,
                  self.dtypes,
              )
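Note: every call to `_filter` above is wrapped in try/except ValueError, which implies the helper removes the requested names from `items` rather than silently ignoring missing ones. A hypothetical sketch of that contract (the real helper, defined elsewhere in this module, also accepts str and dict specs):

    # Hypothetical sketch of the _filter contract implied by the call sites.
    def _filter_sketch(items, names):
        items = list(items)
        for name in names:
            items.remove(name)  # list.remove raises ValueError when absent
        return items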
@@ -466,6 +585,113 @@ class CmdStanPyConverter:
              },
          )

+     def posterior_to_xarray_pre_v_1_2_0(self):
+         items = list(self.posterior.metadata.stan_vars_cols)
+         if self.posterior_predictive is not None:
+             try:
+                 items = _filter(items, self.posterior_predictive)
+             except ValueError:
+                 pass
+         if self.predictions is not None:
+             try:
+                 items = _filter(items, self.predictions)
+             except ValueError:
+                 pass
+         if self.log_likelihood is not None:
+             try:
+                 items = _filter(items, self.log_likelihood)
+             except ValueError:
+                 pass
+
+         valid_cols = []
+         for item in items:
+             if hasattr(self.posterior, "metadata"):
+                 if item in self.posterior.metadata.stan_vars_cols:
+                     valid_cols.append(item)
+
+         data, data_warmup = _unpack_fit_pre_v_1_2_0(
+             self.posterior,
+             items,
+             self.save_warmup,
+             self.dtypes,
+         )
+
+         dims = deepcopy(self.dims) if self.dims is not None else {}
+         coords = deepcopy(self.coords) if self.coords is not None else {}
+
+         return (
+             dict_to_dataset(
+                 data,
+                 library=self.cmdstanpy,
+                 coords=coords,
+                 dims=dims,
+                 index_origin=self.index_origin,
+             ),
+             dict_to_dataset(
+                 data_warmup,
+                 library=self.cmdstanpy,
+                 coords=coords,
+                 dims=dims,
+                 index_origin=self.index_origin,
+             ),
+         )
+
+     @requires("posterior")
+     def posterior_to_xarray_pre_v_1_0_0(self):
+         if hasattr(self.posterior, "metadata"):
+             items = list(self.posterior.metadata.stan_vars_cols.keys())
+         else:
+             items = list(self.posterior.stan_vars_cols.keys())
+         if self.posterior_predictive is not None:
+             try:
+                 items = _filter(items, self.posterior_predictive)
+             except ValueError:
+                 pass
+         if self.predictions is not None:
+             try:
+                 items = _filter(items, self.predictions)
+             except ValueError:
+                 pass
+         if self.log_likelihood is not None:
+             try:
+                 items = _filter(items, self.log_likelihood)
+             except ValueError:
+                 pass
+
+         valid_cols = []
+         for item in items:
+             if hasattr(self.posterior, "metadata"):
+                 valid_cols.extend(self.posterior.metadata.stan_vars_cols[item])
+             else:
+                 valid_cols.extend(self.posterior.stan_vars_cols[item])
+
+         data, data_warmup = _unpack_fit_pre_v_1_0_0(
+             self.posterior,
+             items,
+             self.save_warmup,
+             self.dtypes,
+         )
+
+         dims = deepcopy(self.dims) if self.dims is not None else {}
+         coords = deepcopy(self.coords) if self.coords is not None else {}
+
+         return (
+             dict_to_dataset(
+                 data,
+                 library=self.cmdstanpy,
+                 coords=coords,
+                 dims=dims,
+                 index_origin=self.index_origin,
+             ),
+             dict_to_dataset(
+                 data_warmup,
+                 library=self.cmdstanpy,
+                 coords=coords,
+                 dims=dims,
+                 index_origin=self.index_origin,
+             ),
+         )
+
      @requires("posterior")
      def posterior_to_xarray_pre_v_0_9_68(self):
          """Extract posterior samples from output csv."""
@@ -544,6 +770,103 @@ class CmdStanPyConverter:
              ),
          )

+     def sample_stats_to_xarray_pre_v_1_2_0(self, fit):
+         dtypes = {
+             "divergent__": bool,
+             "n_leapfrog__": np.int64,
+             "treedepth__": np.int64,
+             **self.dtypes,
+         }
+
+         items = list(fit.metadata.method_vars_cols.keys())  # pylint: disable=protected-access
+
+         rename_dict = {
+             "divergent": "diverging",
+             "n_leapfrog": "n_steps",
+             "treedepth": "tree_depth",
+             "stepsize": "step_size",
+             "accept_stat": "acceptance_rate",
+         }
+
+         data, data_warmup = _unpack_fit_pre_v_1_2_0(
+             fit,
+             items,
+             self.save_warmup,
+             self.dtypes,
+         )
+         for item in items:
+             name = re.sub("__$", "", item)
+             name = rename_dict.get(name, name)
+             data[name] = data.pop(item).astype(dtypes.get(item, float))
+             if data_warmup:
+                 data_warmup[name] = data_warmup.pop(item).astype(dtypes.get(item, float))
+
+         return (
+             dict_to_dataset(
+                 data,
+                 library=self.cmdstanpy,
+                 coords=self.coords,
+                 dims=self.dims,
+                 index_origin=self.index_origin,
+             ),
+             dict_to_dataset(
+                 data_warmup,
+                 library=self.cmdstanpy,
+                 coords=self.coords,
+                 dims=self.dims,
+                 index_origin=self.index_origin,
+             ),
+         )
+
+     def sample_stats_to_xarray_pre_v_1_0_0(self, fit):
+         """Extract sample_stats from fit."""
+         dtypes = {
+             "divergent__": bool,
+             "n_leapfrog__": np.int64,
+             "treedepth__": np.int64,
+             **self.dtypes,
+         }
+         if hasattr(fit, "metadata"):
+             items = list(fit.metadata._method_vars_cols.keys())  # pylint: disable=protected-access
+         else:
+             items = list(fit.sampler_vars_cols.keys())
+         rename_dict = {
+             "divergent": "diverging",
+             "n_leapfrog": "n_steps",
+             "treedepth": "tree_depth",
+             "stepsize": "step_size",
+             "accept_stat": "acceptance_rate",
+         }
+
+         data, data_warmup = _unpack_fit_pre_v_1_0_0(
+             fit,
+             items,
+             self.save_warmup,
+             self.dtypes,
+         )
+         for item in items:
+             name = re.sub("__$", "", item)
+             name = rename_dict.get(name, name)
+             data[name] = data.pop(item).astype(dtypes.get(item, float))
+             if data_warmup:
+                 data_warmup[name] = data_warmup.pop(item).astype(dtypes.get(item, float))
+         return (
+             dict_to_dataset(
+                 data,
+                 library=self.cmdstanpy,
+                 coords=self.coords,
+                 dims=self.dims,
+                 index_origin=self.index_origin,
+             ),
+             dict_to_dataset(
+                 data_warmup,
+                 library=self.cmdstanpy,
+                 coords=self.coords,
+                 dims=self.dims,
+                 index_origin=self.index_origin,
+             ),
+         )
+
      def sample_stats_to_xarray_pre_v_0_9_68(self, fit):
          """Extract sample_stats from fit."""
          dtypes = {"divergent__": bool, "n_leapfrog__": np.int64, "treedepth__": np.int64}
@@ -612,6 +935,72 @@ def _filter_columns(columns, spec):


  def _unpack_fit(fit, items, save_warmup, dtypes):
+     num_warmup = 0
+     if save_warmup:
+         if not fit._save_warmup:  # pylint: disable=protected-access
+             save_warmup = False
+         else:
+             num_warmup = fit.num_draws_warmup
+
+     nchains = fit.chains
+     sample = {}
+     sample_warmup = {}
+     stan_vars_cols = list(fit.metadata.stan_vars)
+     sampler_vars = fit.method_variables()
+     for item in items:
+         if item in stan_vars_cols:
+             raw_draws = fit.stan_variable(item, inc_warmup=save_warmup)
+             raw_draws = np.swapaxes(
+                 raw_draws.reshape((-1, nchains, *raw_draws.shape[1:]), order="F"), 0, 1
+             )
+         elif item in sampler_vars:
+             raw_draws = np.swapaxes(sampler_vars[item], 0, 1)
+         else:
+             raise ValueError(f"fit data, unknown variable: {item}")
+         raw_draws = raw_draws.astype(dtypes.get(item))
+         if save_warmup:
+             sample_warmup[item] = raw_draws[:, :num_warmup, ...]
+             sample[item] = raw_draws[:, num_warmup:, ...]
+         else:
+             sample[item] = raw_draws
+
+     return sample, sample_warmup
+
+
+ def _unpack_fit_pre_v_1_2_0(fit, items, save_warmup, dtypes):
+     num_warmup = 0
+     if save_warmup:
+         if not fit._save_warmup:  # pylint: disable=protected-access
+             save_warmup = False
+         else:
+             num_warmup = fit.num_draws_warmup
+
+     nchains = fit.chains
+     sample = {}
+     sample_warmup = {}
+     stan_vars_cols = list(fit.metadata.stan_vars_cols)
+     sampler_vars = fit.method_variables()
+     for item in items:
+         if item in stan_vars_cols:
+             raw_draws = fit.stan_variable(item, inc_warmup=save_warmup)
+             raw_draws = np.swapaxes(
+                 raw_draws.reshape((-1, nchains, *raw_draws.shape[1:]), order="F"), 0, 1
+             )
+         elif item in sampler_vars:
+             raw_draws = np.swapaxes(sampler_vars[item], 0, 1)
+         else:
+             raise ValueError(f"fit data, unknown variable: {item}")
+         raw_draws = raw_draws.astype(dtypes.get(item))
+         if save_warmup:
+             sample_warmup[item] = raw_draws[:, :num_warmup, ...]
+             sample[item] = raw_draws[:, num_warmup:, ...]
+         else:
+             sample[item] = raw_draws
+
+     return sample, sample_warmup
+
+
+ def _unpack_fit_pre_v_1_0_0(fit, items, save_warmup, dtypes):
      """Transform fit to dictionary containing ndarrays.

      Parameters
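Note: the reshape + swapaxes in the `_unpack_fit` variants above converts cmdstanpy's flattened draws into ArviZ's (chain, draw, ...) layout; the `order="F"` reshape implies the flat array holds all of chain 0's draws first, then chain 1's, and so on. A small numpy demonstration of the scalar case:

    # Demonstration of the (draws * chains,) -> (chains, draws) conversion.
    import numpy as np

    nchains, ndraws = 2, 3
    # Fake chain-concatenated draws: draw d of chain c encoded as 10 * c + d.
    flat = np.array([10 * c + d for c in range(nchains) for d in range(ndraws)])
    draws = np.swapaxes(flat.reshape((-1, nchains), order="F"), 0, 1)
    print(draws)
    # [[ 0  1  2]
    #  [10 11 12]]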
arviz/data/io_pyjags.py CHANGED
@@ -52,7 +52,7 @@ class PyJAGSConverter:
          self.save_warmup = rcParams["data.save_warmup"] if save_warmup is None else save_warmup
          self.warmup_iterations = warmup_iterations

-         import pyjags
+         import pyjags  # pylint: disable=import-error

          self.pyjags = pyjags

@@ -149,7 +149,10 @@ def get_draws(
      variables = tuple(variables)

      if warmup_iterations > 0:
-         (warmup_samples, actual_samples,) = _split_pyjags_dict_in_warmup_and_actual_samples(
+         (
+             warmup_samples,
+             actual_samples,
+         ) = _split_pyjags_dict_in_warmup_and_actual_samples(
              pyjags_samples=pyjags_samples,
              warmup_iterations=warmup_iterations,
              variable_names=variables,
arviz/data/io_pystan.py CHANGED
@@ -676,8 +676,7 @@ def get_draws(fit, variables=None, ignore=None, warmup=False, dtypes=None):
      for item in par_keys:
          _, shape = item.replace("]", "").split("[")
          shape_idx_min = min(int(shape_value) for shape_value in shape.split(","))
-         if shape_idx_min < shift:
-             shift = shape_idx_min
+         shift = min(shift, shape_idx_min)
      # If shift is higher than 1, this will probably mean that Stan
      # has implemented sparse structure (saves only non-zero parts),
      # but let's hope that dims are still corresponding to the full shape
arviz/labels.py CHANGED
@@ -100,6 +100,8 @@ class BaseLabeller:
          """WIP."""
          var_name_str = self.var_name_to_str(var_name)
          pp_var_name_str = self.var_name_to_str(pp_var_name)
+         if var_name_str == pp_var_name_str:
+             return f"{var_name_str}"
          return f"{var_name_str} / {pp_var_name_str}"

      def model_name_to_str(self, model_name):
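Note: the guard added above collapses the label to a single name when the observed and posterior-predictive variables share it, instead of rendering "y / y". Assuming the method shown is `BaseLabeller.var_pp_to_str` (the body matches that method), the behavior is:

    # Assumed usage; the method name is inferred from the hunk's body.
    from arviz.labels import BaseLabeller

    labeller = BaseLabeller()
    print(labeller.var_pp_to_str("y", "y"))      # "y" (new in 0.17)
    print(labeller.var_pp_to_str("y", "y_hat"))  # "y / y_hat"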
arviz/plots/backends/bokeh/bpvplot.py CHANGED
@@ -171,8 +171,13 @@
              ax_i.line(0, 0, legend_label=f"bpv={p_value:.2f}", alpha=0)

          if plot_mean:
-             ax_i.circle(
-                 obs_vals.mean(), 0, fill_color=color, line_color="black", size=markersize
+             ax_i.scatter(
+                 obs_vals.mean(),
+                 0,
+                 fill_color=color,
+                 line_color="black",
+                 size=markersize,
+                 marker="circle",
              )

          _title = Title()
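Note: swapping `ax_i.circle(...)` for `ax_i.scatter(..., marker="circle")` follows Bokeh's move away from per-shape glyph methods; recent Bokeh releases deprecate `figure.circle` with a `size` argument in favor of the generic scatter glyph (Bokeh 3.4 is where the deprecation landed, if memory serves). An equivalent standalone call:

    # Equivalent non-deprecated Bokeh call (assumes a Bokeh 3.x environment).
    from bokeh.plotting import figure

    p = figure()
    # Deprecated form: p.circle(0, 0, size=10, fill_color="navy", line_color="black")
    p.scatter(0, 0, size=10, marker="circle", fill_color="navy", line_color="black")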