arviz 0.17.1__py3-none-any.whl → 0.18.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. arviz/__init__.py +3 -2
  2. arviz/data/__init__.py +5 -2
  3. arviz/data/base.py +102 -11
  4. arviz/data/converters.py +5 -0
  5. arviz/data/datasets.py +1 -0
  6. arviz/data/example_data/data_remote.json +10 -3
  7. arviz/data/inference_data.py +20 -22
  8. arviz/data/io_cmdstan.py +1 -3
  9. arviz/data/io_datatree.py +1 -0
  10. arviz/data/io_dict.py +5 -3
  11. arviz/data/io_emcee.py +1 -0
  12. arviz/data/io_numpyro.py +1 -0
  13. arviz/data/io_pyjags.py +1 -0
  14. arviz/data/io_pyro.py +1 -0
  15. arviz/data/utils.py +1 -0
  16. arviz/plots/__init__.py +1 -0
  17. arviz/plots/autocorrplot.py +1 -0
  18. arviz/plots/backends/bokeh/autocorrplot.py +1 -0
  19. arviz/plots/backends/bokeh/bpvplot.py +1 -0
  20. arviz/plots/backends/bokeh/compareplot.py +1 -0
  21. arviz/plots/backends/bokeh/densityplot.py +1 -0
  22. arviz/plots/backends/bokeh/distplot.py +1 -0
  23. arviz/plots/backends/bokeh/dotplot.py +1 -0
  24. arviz/plots/backends/bokeh/ecdfplot.py +1 -0
  25. arviz/plots/backends/bokeh/elpdplot.py +1 -0
  26. arviz/plots/backends/bokeh/energyplot.py +1 -0
  27. arviz/plots/backends/bokeh/hdiplot.py +1 -0
  28. arviz/plots/backends/bokeh/kdeplot.py +3 -3
  29. arviz/plots/backends/bokeh/khatplot.py +1 -0
  30. arviz/plots/backends/bokeh/lmplot.py +1 -0
  31. arviz/plots/backends/bokeh/loopitplot.py +1 -0
  32. arviz/plots/backends/bokeh/mcseplot.py +1 -0
  33. arviz/plots/backends/bokeh/pairplot.py +1 -0
  34. arviz/plots/backends/bokeh/parallelplot.py +1 -0
  35. arviz/plots/backends/bokeh/posteriorplot.py +1 -0
  36. arviz/plots/backends/bokeh/ppcplot.py +1 -0
  37. arviz/plots/backends/bokeh/rankplot.py +1 -0
  38. arviz/plots/backends/bokeh/separationplot.py +1 -0
  39. arviz/plots/backends/bokeh/traceplot.py +1 -0
  40. arviz/plots/backends/bokeh/violinplot.py +1 -0
  41. arviz/plots/backends/matplotlib/autocorrplot.py +1 -0
  42. arviz/plots/backends/matplotlib/bpvplot.py +1 -0
  43. arviz/plots/backends/matplotlib/compareplot.py +1 -0
  44. arviz/plots/backends/matplotlib/densityplot.py +1 -0
  45. arviz/plots/backends/matplotlib/distcomparisonplot.py +2 -3
  46. arviz/plots/backends/matplotlib/distplot.py +1 -0
  47. arviz/plots/backends/matplotlib/dotplot.py +1 -0
  48. arviz/plots/backends/matplotlib/ecdfplot.py +1 -0
  49. arviz/plots/backends/matplotlib/elpdplot.py +1 -0
  50. arviz/plots/backends/matplotlib/energyplot.py +1 -0
  51. arviz/plots/backends/matplotlib/essplot.py +6 -5
  52. arviz/plots/backends/matplotlib/forestplot.py +1 -0
  53. arviz/plots/backends/matplotlib/hdiplot.py +1 -0
  54. arviz/plots/backends/matplotlib/kdeplot.py +5 -3
  55. arviz/plots/backends/matplotlib/khatplot.py +1 -0
  56. arviz/plots/backends/matplotlib/lmplot.py +1 -0
  57. arviz/plots/backends/matplotlib/loopitplot.py +1 -0
  58. arviz/plots/backends/matplotlib/mcseplot.py +11 -10
  59. arviz/plots/backends/matplotlib/pairplot.py +2 -1
  60. arviz/plots/backends/matplotlib/parallelplot.py +1 -0
  61. arviz/plots/backends/matplotlib/posteriorplot.py +1 -0
  62. arviz/plots/backends/matplotlib/ppcplot.py +1 -0
  63. arviz/plots/backends/matplotlib/rankplot.py +1 -0
  64. arviz/plots/backends/matplotlib/separationplot.py +1 -0
  65. arviz/plots/backends/matplotlib/traceplot.py +1 -0
  66. arviz/plots/backends/matplotlib/tsplot.py +1 -0
  67. arviz/plots/backends/matplotlib/violinplot.py +2 -1
  68. arviz/plots/bpvplot.py +1 -0
  69. arviz/plots/compareplot.py +1 -0
  70. arviz/plots/densityplot.py +1 -0
  71. arviz/plots/distcomparisonplot.py +1 -0
  72. arviz/plots/dotplot.py +1 -0
  73. arviz/plots/ecdfplot.py +1 -0
  74. arviz/plots/elpdplot.py +1 -0
  75. arviz/plots/energyplot.py +1 -0
  76. arviz/plots/essplot.py +1 -0
  77. arviz/plots/forestplot.py +1 -0
  78. arviz/plots/hdiplot.py +1 -0
  79. arviz/plots/khatplot.py +1 -0
  80. arviz/plots/lmplot.py +1 -0
  81. arviz/plots/loopitplot.py +1 -0
  82. arviz/plots/mcseplot.py +1 -0
  83. arviz/plots/pairplot.py +1 -0
  84. arviz/plots/parallelplot.py +1 -0
  85. arviz/plots/plot_utils.py +1 -0
  86. arviz/plots/posteriorplot.py +1 -0
  87. arviz/plots/ppcplot.py +1 -0
  88. arviz/plots/rankplot.py +1 -0
  89. arviz/plots/separationplot.py +1 -0
  90. arviz/plots/traceplot.py +1 -0
  91. arviz/plots/tsplot.py +1 -0
  92. arviz/plots/violinplot.py +1 -0
  93. arviz/rcparams.py +1 -0
  94. arviz/sel_utils.py +1 -0
  95. arviz/static/css/style.css +2 -1
  96. arviz/stats/density_utils.py +2 -1
  97. arviz/stats/diagnostics.py +2 -2
  98. arviz/stats/ecdf_utils.py +1 -0
  99. arviz/stats/stats_refitting.py +1 -0
  100. arviz/stats/stats_utils.py +5 -1
  101. arviz/tests/base_tests/test_data.py +14 -0
  102. arviz/tests/base_tests/test_diagnostics.py +1 -0
  103. arviz/tests/base_tests/test_diagnostics_numba.py +1 -0
  104. arviz/tests/base_tests/test_labels.py +1 -0
  105. arviz/tests/base_tests/test_plots_matplotlib.py +6 -5
  106. arviz/tests/base_tests/test_stats.py +4 -4
  107. arviz/tests/base_tests/test_stats_utils.py +1 -0
  108. arviz/tests/base_tests/test_utils.py +3 -2
  109. arviz/tests/helpers.py +1 -1
  110. arviz/wrappers/__init__.py +1 -0
  111. {arviz-0.17.1.dist-info → arviz-0.18.0.dist-info}/METADATA +7 -7
  112. arviz-0.18.0.dist-info/RECORD +182 -0
  113. arviz-0.17.1.dist-info/RECORD +0 -182
  114. {arviz-0.17.1.dist-info → arviz-0.18.0.dist-info}/LICENSE +0 -0
  115. {arviz-0.17.1.dist-info → arviz-0.18.0.dist-info}/WHEEL +0 -0
  116. {arviz-0.17.1.dist-info → arviz-0.18.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,5 @@
1
1
  """Bokeh Posterior predictive plot."""
2
+
2
3
  import numpy as np
3
4
  from bokeh.models.annotations import Legend
4
5
  from bokeh.models.glyphs import Scatter
@@ -1,4 +1,5 @@
1
1
  """Bokeh rankplot."""
2
+
2
3
  import numpy as np
3
4
 
4
5
  from bokeh.models import Span
@@ -1,4 +1,5 @@
1
1
  """Bokeh separation plot."""
2
+
2
3
  import numpy as np
3
4
 
4
5
  from ...plot_utils import _scale_fig_size, vectorized_to_hex
@@ -1,4 +1,5 @@
1
1
  """Bokeh Traceplot."""
2
+
2
3
  import warnings
3
4
  from collections.abc import Iterable
4
5
  from itertools import cycle
@@ -1,4 +1,5 @@
1
1
  """Bokeh Violinplot."""
2
+
2
3
  import numpy as np
3
4
  from bokeh.models.annotations import Title
4
5
 
@@ -1,4 +1,5 @@
1
1
  """Matplotlib Autocorrplot."""
2
+
2
3
  import matplotlib.pyplot as plt
3
4
  import numpy as np
4
5
 
@@ -1,4 +1,5 @@
1
1
  """Matplotib Bayesian p-value Posterior predictive plot."""
2
+
2
3
  import matplotlib.pyplot as plt
3
4
  import numpy as np
4
5
  from scipy import stats
@@ -1,4 +1,5 @@
1
1
  """Matplotlib Compareplot."""
2
+
2
3
  import matplotlib.pyplot as plt
3
4
 
4
5
  from ...plot_utils import _scale_fig_size
@@ -1,4 +1,5 @@
1
1
  """Matplotlib Densityplot."""
2
+
2
3
  from itertools import cycle
3
4
 
4
5
  import matplotlib.pyplot as plt
@@ -1,4 +1,5 @@
1
1
  """Matplotlib Density Comparison plot."""
2
+
2
3
  import matplotlib.pyplot as plt
3
4
  import numpy as np
4
5
 
@@ -88,9 +89,7 @@ def plot_dist_comparison(
88
89
  kwargs = (
89
90
  prior_kwargs
90
91
  if group.startswith("prior")
91
- else posterior_kwargs
92
- if group.startswith("posterior")
93
- else observed_kwargs
92
+ else posterior_kwargs if group.startswith("posterior") else observed_kwargs
94
93
  )
95
94
  for idx2, (
96
95
  var_name,
@@ -1,4 +1,5 @@
1
1
  """Matplotlib distplot."""
2
+
2
3
  import matplotlib.pyplot as plt
3
4
  from matplotlib import _pylab_helpers
4
5
  import numpy as np
@@ -1,4 +1,5 @@
1
1
  """Matplotlib dotplot."""
2
+
2
3
  import math
3
4
  import warnings
4
5
  import numpy as np
@@ -1,4 +1,5 @@
1
1
  """Matplotlib ecdfplot."""
2
+
2
3
  import matplotlib.pyplot as plt
3
4
  from matplotlib.colors import to_hex
4
5
 
@@ -1,4 +1,5 @@
1
1
  """Matplotlib ELPDPlot."""
2
+
2
3
  import warnings
3
4
 
4
5
  from matplotlib import cm
@@ -1,4 +1,5 @@
1
1
  """Matplotlib energyplot."""
2
+
2
3
  from itertools import cycle
3
4
 
4
5
  import matplotlib.pyplot as plt
@@ -1,4 +1,5 @@
1
1
  """Matplotlib energyplot."""
2
+
2
3
  import matplotlib.pyplot as plt
3
4
  import numpy as np
4
5
  from scipy.stats import rankdata
@@ -127,11 +128,11 @@ def plot_ess(
127
128
  ax_.annotate(
128
129
  "mean",
129
130
  (text_x, mean_ess_i),
130
- va=text_va
131
- if text_va is not None
132
- else "bottom"
133
- if mean_ess_i >= sd_ess_i
134
- else "top",
131
+ va=(
132
+ text_va
133
+ if text_va is not None
134
+ else "bottom" if mean_ess_i >= sd_ess_i else "top"
135
+ ),
135
136
  **text_kwargs,
136
137
  )
137
138
  ax_.axhline(sd_ess_i, **extra_kwargs)
@@ -1,4 +1,5 @@
1
1
  """Matplotlib forestplot."""
2
+
2
3
  from collections import OrderedDict, defaultdict
3
4
  from itertools import tee
4
5
 
@@ -1,4 +1,5 @@
1
1
  """Matplotlib hdiplot."""
2
+
2
3
  import matplotlib.pyplot as plt
3
4
  from matplotlib import _pylab_helpers
4
5
 
@@ -1,4 +1,5 @@
1
1
  """Matplotlib kdeplot."""
2
+
2
3
  import numpy as np
3
4
  from matplotlib import pyplot as plt
4
5
  from matplotlib import _pylab_helpers
@@ -162,10 +163,11 @@ def plot_kde(
162
163
  ax.grid(False)
163
164
  if contour:
164
165
  qcfs = ax.contourf(x_x, y_y, density, antialiased=True, **contourf_kwargs)
165
- qcs = ax.contour(x_x, y_y, density, **contour_kwargs)
166
+ ax.contour(x_x, y_y, density, **contour_kwargs)
166
167
  if not fill_last:
167
- qcfs.collections[0].set_alpha(0)
168
- qcs.collections[0].set_alpha(0)
168
+ alpha = np.ones(len(qcfs.allsegs), dtype=float)
169
+ alpha[0] = 0
170
+ qcfs.set_alpha(alpha)
169
171
  else:
170
172
  ax.pcolormesh(x_x, y_y, density, **pcolormesh_kwargs)
171
173
 
@@ -1,4 +1,5 @@
1
1
  """Matplotlib khatplot."""
2
+
2
3
  import warnings
3
4
 
4
5
  import matplotlib as mpl
@@ -1,4 +1,5 @@
1
1
  """Matplotlib plot linear regression figure."""
2
+
2
3
  import matplotlib.pyplot as plt
3
4
  import numpy as np
4
5
 
@@ -1,4 +1,5 @@
1
1
  """Matplotlib loopitplot."""
2
+
2
3
  import matplotlib.pyplot as plt
3
4
  import numpy as np
4
5
  from matplotlib.colors import hsv_to_rgb, rgb_to_hsv, to_hex, to_rgb
@@ -1,4 +1,5 @@
1
1
  """Matplotlib mcseplot."""
2
+
2
3
  import matplotlib.pyplot as plt
3
4
  import numpy as np
4
5
  from scipy.stats import rankdata
@@ -94,22 +95,22 @@ def plot_mcse(
94
95
  ax_.annotate(
95
96
  "mean",
96
97
  (text_x, mean_mcse_i),
97
- va=text_va
98
- if text_va is not None
99
- else "bottom"
100
- if mean_mcse_i > sd_mcse_i
101
- else "top",
98
+ va=(
99
+ text_va
100
+ if text_va is not None
101
+ else "bottom" if mean_mcse_i > sd_mcse_i else "top"
102
+ ),
102
103
  **text_kwargs,
103
104
  )
104
105
  ax_.axhline(sd_mcse_i, **extra_kwargs)
105
106
  ax_.annotate(
106
107
  "sd",
107
108
  (text_x, sd_mcse_i),
108
- va=text_va
109
- if text_va is not None
110
- else "bottom"
111
- if sd_mcse_i >= mean_mcse_i
112
- else "top",
109
+ va=(
110
+ text_va
111
+ if text_va is not None
112
+ else "bottom" if sd_mcse_i >= mean_mcse_i else "top"
113
+ ),
113
114
  **text_kwargs,
114
115
  )
115
116
  if rug:
@@ -1,4 +1,5 @@
1
1
  """Matplotlib pairplot."""
2
+
2
3
  import warnings
3
4
  from copy import deepcopy
4
5
 
@@ -333,7 +334,7 @@ def plot_pair(
333
334
  if reference_values:
334
335
  x_name = flat_var_names[i]
335
336
  y_name = flat_var_names[j + not_marginals]
336
- if x_name and y_name not in difference:
337
+ if (x_name not in difference) and (y_name not in difference):
337
338
  ax[j, i].plot(
338
339
  reference_values_copy[x_name],
339
340
  reference_values_copy[y_name],
@@ -1,4 +1,5 @@
1
1
  """Matplotlib Parallel coordinates plot."""
2
+
2
3
  import matplotlib.pyplot as plt
3
4
  import numpy as np
4
5
 
@@ -1,4 +1,5 @@
1
1
  """Matplotlib Plot posterior densities."""
2
+
2
3
  from numbers import Number
3
4
 
4
5
  import matplotlib.pyplot as plt
@@ -1,4 +1,5 @@
1
1
  """Matplotlib Posterior predictive plot."""
2
+
2
3
  import logging
3
4
  import platform
4
5
 
@@ -1,4 +1,5 @@
1
1
  """Matplotlib rankplot."""
2
+
2
3
  import matplotlib.pyplot as plt
3
4
  import numpy as np
4
5
 
@@ -1,4 +1,5 @@
1
1
  """Matplotlib separation plot."""
2
+
2
3
  import matplotlib.pyplot as plt
3
4
  import numpy as np
4
5
 
@@ -1,4 +1,5 @@
1
1
  """Matplotlib traceplot."""
2
+
2
3
  import warnings
3
4
  from itertools import cycle
4
5
 
@@ -1,4 +1,5 @@
1
1
  """Matplotlib plot time series figure."""
2
+
2
3
  import matplotlib.pyplot as plt
3
4
  import numpy as np
4
5
 
@@ -1,4 +1,5 @@
1
1
  """Matplotlib Violinplot."""
2
+
2
3
  import matplotlib.pyplot as plt
3
4
  import numpy as np
4
5
 
@@ -60,7 +61,7 @@ def plot_violin(
60
61
  cols,
61
62
  backend_kwargs=backend_kwargs,
62
63
  )
63
- fig.set_constrained_layout(False)
64
+ fig.set_layout_engine("none")
64
65
  fig.subplots_adjust(wspace=0)
65
66
 
66
67
  ax = np.atleast_1d(ax)
arviz/plots/bpvplot.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Bayesian p-value Posterior/Prior predictive plot."""
2
+
2
3
  import numpy as np
3
4
 
4
5
  from ..labels import BaseLabeller
@@ -1,4 +1,5 @@
1
1
  """Summary plot for model comparison."""
2
+
2
3
  import numpy as np
3
4
 
4
5
  from ..labels import BaseLabeller
@@ -1,4 +1,5 @@
1
1
  """KDE and histogram plots for multiple variables."""
2
+
2
3
  import warnings
3
4
 
4
5
  from ..data import convert_to_dataset
@@ -1,4 +1,5 @@
1
1
  """Density Comparison plot."""
2
+
2
3
  import warnings
3
4
  from ..labels import BaseLabeller
4
5
  from ..rcparams import rcParams
arviz/plots/dotplot.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Plot distribution as dot plot or quantile dot plot."""
2
+
2
3
  import numpy as np
3
4
 
4
5
  from ..rcparams import rcParams
arviz/plots/ecdfplot.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Plot ecdf or ecdf-difference plot with confidence bands."""
2
+
2
3
  import numpy as np
3
4
  from scipy.stats import uniform
4
5
 
arviz/plots/elpdplot.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Plot pointwise elpd estimations of inference data."""
2
+
2
3
  import numpy as np
3
4
 
4
5
  from ..rcparams import rcParams
arviz/plots/energyplot.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Plot energy transition distribution in HMC inference."""
2
+
2
3
  import warnings
3
4
 
4
5
  from ..data import convert_to_dataset
arviz/plots/essplot.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Plot quantile or local effective sample sizes."""
2
+
2
3
  import numpy as np
3
4
  import xarray as xr
4
5
 
arviz/plots/forestplot.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Forest plot."""
2
+
2
3
  from ..data import convert_to_dataset
3
4
  from ..labels import BaseLabeller, NoModelLabeller
4
5
  from ..rcparams import rcParams
arviz/plots/hdiplot.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Plot highest density intervals for regression data."""
2
+
2
3
  import warnings
3
4
 
4
5
  import numpy as np
arviz/plots/khatplot.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Pareto tail indices plot."""
2
+
2
3
  import logging
3
4
 
4
5
  import numpy as np
arviz/plots/lmplot.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Plot regression figure."""
2
+
2
3
  import warnings
3
4
  from numbers import Integral
4
5
  from itertools import repeat
arviz/plots/loopitplot.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Plot LOO-PIT predictive checks of inference data."""
2
+
2
3
  import numpy as np
3
4
  from scipy import stats
4
5
 
arviz/plots/mcseplot.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Plot quantile MC standard error."""
2
+
2
3
  import numpy as np
3
4
  import xarray as xr
4
5
 
arviz/plots/pairplot.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Plot a scatter, kde and/or hexbin of sampled parameters."""
2
+
2
3
  import warnings
3
4
  from typing import List, Optional, Union
4
5
 
@@ -1,4 +1,5 @@
1
1
  """Parallel coordinates plot showing posterior points with and without divergences marked."""
2
+
2
3
  import numpy as np
3
4
  from scipy.stats import rankdata
4
5
 
arviz/plots/plot_utils.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Utilities for plotting."""
2
+
2
3
  import importlib
3
4
  import warnings
4
5
  from typing import Any, Dict
@@ -1,4 +1,5 @@
1
1
  """Plot posterior densities."""
2
+
2
3
  from ..data import convert_to_dataset
3
4
  from ..labels import BaseLabeller
4
5
  from ..sel_utils import xarray_var_iter
arviz/plots/ppcplot.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Posterior/Prior predictive plot."""
2
+
2
3
  import logging
3
4
  import warnings
4
5
  from numbers import Integral
arviz/plots/rankplot.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Histograms of ranked posterior draws, plotted for each chain."""
2
+
2
3
  from itertools import cycle
3
4
 
4
5
  import matplotlib.pyplot as plt
@@ -1,4 +1,5 @@
1
1
  """Separation plot for discrete outcome models."""
2
+
2
3
  import warnings
3
4
 
4
5
  import numpy as np
arviz/plots/traceplot.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Plot kde or histograms and values from MCMC samples."""
2
+
2
3
  import warnings
3
4
  from typing import Any, Callable, List, Mapping, Optional, Tuple, Union, Sequence
4
5
 
arviz/plots/tsplot.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Plot timeseries data."""
2
+
2
3
  import warnings
3
4
  import numpy as np
4
5
 
arviz/plots/violinplot.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Plot posterior traces as violin plot."""
2
+
2
3
  from ..data import convert_to_dataset
3
4
  from ..labels import BaseLabeller
4
5
  from ..sel_utils import xarray_var_iter
arviz/rcparams.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """ArviZ rcparams. Based on matplotlib's implementation."""
2
+
2
3
  import locale
3
4
  import logging
4
5
  import os
arviz/sel_utils.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Utilities for selecting and iterating on xarray objects."""
2
+
2
3
  from itertools import product, tee
3
4
 
4
5
  import numpy as np
@@ -302,7 +302,8 @@ dl.xr-attrs {
302
302
  grid-template-columns: 125px auto;
303
303
  }
304
304
 
305
- .xr-attrs dt, dd {
305
+ .xr-attrs dt,
306
+ .xr-attrs dd {
306
307
  padding: 0;
307
308
  margin: 0;
308
309
  float: left;
@@ -5,7 +5,8 @@ import warnings
5
5
  import numpy as np
6
6
  from scipy.fftpack import fft
7
7
  from scipy.optimize import brentq
8
- from scipy.signal import convolve, convolve2d, gaussian # pylint: disable=no-name-in-module
8
+ from scipy.signal import convolve, convolve2d
9
+ from scipy.signal.windows import gaussian
9
10
  from scipy.sparse import coo_matrix
10
11
  from scipy.special import ive # pylint: disable=no-name-in-module
11
12
 
@@ -836,7 +836,7 @@ def _mcse_sd(ary):
836
836
  return np.nan
837
837
  ess = _ess_sd(ary)
838
838
  if _numba_flag:
839
- sd = float(_sqrt(svar(np.ravel(ary), ddof=1), np.zeros(1)))
839
+ sd = float(_sqrt(svar(np.ravel(ary), ddof=1), np.zeros(1)).item())
840
840
  else:
841
841
  sd = np.std(ary, ddof=1)
842
842
  fac_mcse_sd = np.sqrt(np.exp(1) * (1 - 1 / ess) ** (ess - 1) - 1)
@@ -904,7 +904,7 @@ def _mc_error(ary, batches=5, circular=False):
904
904
  else:
905
905
  std = stats.circstd(ary, high=np.pi, low=-np.pi)
906
906
  elif _numba_flag:
907
- std = float(_sqrt(svar(ary), np.zeros(1)))
907
+ std = float(_sqrt(svar(ary), np.zeros(1)).item())
908
908
  else:
909
909
  std = np.std(ary)
910
910
  return std / np.sqrt(len(ary))
arviz/stats/ecdf_utils.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Functions for evaluating ECDFs and their confidence bands."""
2
+
2
3
  from typing import Any, Callable, Optional, Tuple
3
4
  import warnings
4
5
 
@@ -1,4 +1,5 @@
1
1
  """Stats functions that require refitting the model."""
2
+
2
3
  import logging
3
4
  import warnings
4
5
 
@@ -1,4 +1,5 @@
1
1
  """Stats-utility functions for ArviZ."""
2
+
2
3
  import warnings
3
4
  from collections.abc import Sequence
4
5
  from copy import copy as _copy
@@ -134,7 +135,10 @@ def make_ufunc(
134
135
  raise TypeError(msg)
135
136
  for idx in np.ndindex(out.shape[:n_dims_out]):
136
137
  arys_idx = [ary[idx].ravel() if ravel else ary[idx] for ary in arys]
137
- out[idx] = np.asarray(func(*arys_idx, *args[n_input:], **kwargs))[index]
138
+ out_idx = np.asarray(func(*arys_idx, *args[n_input:], **kwargs))[index]
139
+ if n_dims_out is None:
140
+ out_idx = out_idx.item()
141
+ out[idx] = out_idx
138
142
  return out
139
143
 
140
144
  def _multi_ufunc(*args, out=None, out_shape=None, **kwargs):
@@ -1077,6 +1077,20 @@ def test_dict_to_dataset():
1077
1077
  assert set(dataset.b.coords) == {"chain", "draw", "c"}
1078
1078
 
1079
1079
 
1080
+ def test_nested_dict_to_dataset():
1081
+ datadict = {
1082
+ "top": {"a": np.random.randn(100), "b": np.random.randn(1, 100, 10)},
1083
+ "d": np.random.randn(100),
1084
+ }
1085
+ dataset = convert_to_dataset(datadict, coords={"c": np.arange(10)}, dims={("top", "b"): ["c"]})
1086
+ assert set(dataset.data_vars) == {("top", "a"), ("top", "b"), "d"}
1087
+ assert set(dataset.coords) == {"chain", "draw", "c"}
1088
+
1089
+ assert set(dataset[("top", "a")].coords) == {"chain", "draw"}
1090
+ assert set(dataset[("top", "b")].coords) == {"chain", "draw", "c"}
1091
+ assert set(dataset.d.coords) == {"chain", "draw"}
1092
+
1093
+
1080
1094
  def test_dict_to_dataset_event_dims_error():
1081
1095
  datadict = {"a": np.random.randn(1, 100, 10)}
1082
1096
  coords = {"b": np.arange(10), "c": ["x", "y", "z"]}
@@ -1,4 +1,5 @@
1
1
  """Test Diagnostic methods"""
2
+
2
3
  # pylint: disable=redefined-outer-name, no-member, too-many-public-methods
3
4
  import os
4
5
 
@@ -1,4 +1,5 @@
1
1
  """Test Diagnostic methods"""
2
+
2
3
  import importlib
3
4
 
4
5
  # pylint: disable=redefined-outer-name, no-member, too-many-public-methods
@@ -1,4 +1,5 @@
1
1
  """Tests for labeller classes."""
2
+
2
3
  import pytest
3
4
 
4
5
  from ...labels import (