arviz 0.23.3__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185)
  1. arviz/__init__.py +52 -367
  2. arviz-1.0.0rc0.dist-info/METADATA +182 -0
  3. arviz-1.0.0rc0.dist-info/RECORD +5 -0
  4. {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
  5. {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
  6. arviz/data/__init__.py +0 -55
  7. arviz/data/base.py +0 -596
  8. arviz/data/converters.py +0 -203
  9. arviz/data/datasets.py +0 -161
  10. arviz/data/example_data/code/radon/radon.json +0 -326
  11. arviz/data/example_data/data/centered_eight.nc +0 -0
  12. arviz/data/example_data/data/non_centered_eight.nc +0 -0
  13. arviz/data/example_data/data_local.json +0 -12
  14. arviz/data/example_data/data_remote.json +0 -58
  15. arviz/data/inference_data.py +0 -2386
  16. arviz/data/io_beanmachine.py +0 -112
  17. arviz/data/io_cmdstan.py +0 -1036
  18. arviz/data/io_cmdstanpy.py +0 -1233
  19. arviz/data/io_datatree.py +0 -23
  20. arviz/data/io_dict.py +0 -462
  21. arviz/data/io_emcee.py +0 -317
  22. arviz/data/io_json.py +0 -54
  23. arviz/data/io_netcdf.py +0 -68
  24. arviz/data/io_numpyro.py +0 -497
  25. arviz/data/io_pyjags.py +0 -378
  26. arviz/data/io_pyro.py +0 -333
  27. arviz/data/io_pystan.py +0 -1095
  28. arviz/data/io_zarr.py +0 -46
  29. arviz/data/utils.py +0 -139
  30. arviz/labels.py +0 -210
  31. arviz/plots/__init__.py +0 -61
  32. arviz/plots/autocorrplot.py +0 -171
  33. arviz/plots/backends/__init__.py +0 -223
  34. arviz/plots/backends/bokeh/__init__.py +0 -166
  35. arviz/plots/backends/bokeh/autocorrplot.py +0 -101
  36. arviz/plots/backends/bokeh/bfplot.py +0 -23
  37. arviz/plots/backends/bokeh/bpvplot.py +0 -193
  38. arviz/plots/backends/bokeh/compareplot.py +0 -167
  39. arviz/plots/backends/bokeh/densityplot.py +0 -239
  40. arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
  41. arviz/plots/backends/bokeh/distplot.py +0 -183
  42. arviz/plots/backends/bokeh/dotplot.py +0 -113
  43. arviz/plots/backends/bokeh/ecdfplot.py +0 -73
  44. arviz/plots/backends/bokeh/elpdplot.py +0 -203
  45. arviz/plots/backends/bokeh/energyplot.py +0 -155
  46. arviz/plots/backends/bokeh/essplot.py +0 -176
  47. arviz/plots/backends/bokeh/forestplot.py +0 -772
  48. arviz/plots/backends/bokeh/hdiplot.py +0 -54
  49. arviz/plots/backends/bokeh/kdeplot.py +0 -268
  50. arviz/plots/backends/bokeh/khatplot.py +0 -163
  51. arviz/plots/backends/bokeh/lmplot.py +0 -185
  52. arviz/plots/backends/bokeh/loopitplot.py +0 -211
  53. arviz/plots/backends/bokeh/mcseplot.py +0 -184
  54. arviz/plots/backends/bokeh/pairplot.py +0 -328
  55. arviz/plots/backends/bokeh/parallelplot.py +0 -81
  56. arviz/plots/backends/bokeh/posteriorplot.py +0 -324
  57. arviz/plots/backends/bokeh/ppcplot.py +0 -379
  58. arviz/plots/backends/bokeh/rankplot.py +0 -149
  59. arviz/plots/backends/bokeh/separationplot.py +0 -107
  60. arviz/plots/backends/bokeh/traceplot.py +0 -436
  61. arviz/plots/backends/bokeh/violinplot.py +0 -164
  62. arviz/plots/backends/matplotlib/__init__.py +0 -124
  63. arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
  64. arviz/plots/backends/matplotlib/bfplot.py +0 -78
  65. arviz/plots/backends/matplotlib/bpvplot.py +0 -177
  66. arviz/plots/backends/matplotlib/compareplot.py +0 -135
  67. arviz/plots/backends/matplotlib/densityplot.py +0 -194
  68. arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
  69. arviz/plots/backends/matplotlib/distplot.py +0 -178
  70. arviz/plots/backends/matplotlib/dotplot.py +0 -116
  71. arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
  72. arviz/plots/backends/matplotlib/elpdplot.py +0 -189
  73. arviz/plots/backends/matplotlib/energyplot.py +0 -113
  74. arviz/plots/backends/matplotlib/essplot.py +0 -180
  75. arviz/plots/backends/matplotlib/forestplot.py +0 -656
  76. arviz/plots/backends/matplotlib/hdiplot.py +0 -48
  77. arviz/plots/backends/matplotlib/kdeplot.py +0 -177
  78. arviz/plots/backends/matplotlib/khatplot.py +0 -241
  79. arviz/plots/backends/matplotlib/lmplot.py +0 -149
  80. arviz/plots/backends/matplotlib/loopitplot.py +0 -144
  81. arviz/plots/backends/matplotlib/mcseplot.py +0 -161
  82. arviz/plots/backends/matplotlib/pairplot.py +0 -355
  83. arviz/plots/backends/matplotlib/parallelplot.py +0 -58
  84. arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
  85. arviz/plots/backends/matplotlib/ppcplot.py +0 -478
  86. arviz/plots/backends/matplotlib/rankplot.py +0 -119
  87. arviz/plots/backends/matplotlib/separationplot.py +0 -97
  88. arviz/plots/backends/matplotlib/traceplot.py +0 -526
  89. arviz/plots/backends/matplotlib/tsplot.py +0 -121
  90. arviz/plots/backends/matplotlib/violinplot.py +0 -148
  91. arviz/plots/bfplot.py +0 -128
  92. arviz/plots/bpvplot.py +0 -308
  93. arviz/plots/compareplot.py +0 -177
  94. arviz/plots/densityplot.py +0 -284
  95. arviz/plots/distcomparisonplot.py +0 -197
  96. arviz/plots/distplot.py +0 -233
  97. arviz/plots/dotplot.py +0 -233
  98. arviz/plots/ecdfplot.py +0 -372
  99. arviz/plots/elpdplot.py +0 -174
  100. arviz/plots/energyplot.py +0 -147
  101. arviz/plots/essplot.py +0 -319
  102. arviz/plots/forestplot.py +0 -304
  103. arviz/plots/hdiplot.py +0 -211
  104. arviz/plots/kdeplot.py +0 -357
  105. arviz/plots/khatplot.py +0 -236
  106. arviz/plots/lmplot.py +0 -380
  107. arviz/plots/loopitplot.py +0 -224
  108. arviz/plots/mcseplot.py +0 -194
  109. arviz/plots/pairplot.py +0 -281
  110. arviz/plots/parallelplot.py +0 -204
  111. arviz/plots/plot_utils.py +0 -599
  112. arviz/plots/posteriorplot.py +0 -298
  113. arviz/plots/ppcplot.py +0 -369
  114. arviz/plots/rankplot.py +0 -232
  115. arviz/plots/separationplot.py +0 -167
  116. arviz/plots/styles/arviz-bluish.mplstyle +0 -1
  117. arviz/plots/styles/arviz-brownish.mplstyle +0 -1
  118. arviz/plots/styles/arviz-colors.mplstyle +0 -2
  119. arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
  120. arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
  121. arviz/plots/styles/arviz-doc.mplstyle +0 -88
  122. arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
  123. arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
  124. arviz/plots/styles/arviz-greenish.mplstyle +0 -1
  125. arviz/plots/styles/arviz-orangish.mplstyle +0 -1
  126. arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
  127. arviz/plots/styles/arviz-purplish.mplstyle +0 -1
  128. arviz/plots/styles/arviz-redish.mplstyle +0 -1
  129. arviz/plots/styles/arviz-royish.mplstyle +0 -1
  130. arviz/plots/styles/arviz-viridish.mplstyle +0 -1
  131. arviz/plots/styles/arviz-white.mplstyle +0 -40
  132. arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
  133. arviz/plots/traceplot.py +0 -273
  134. arviz/plots/tsplot.py +0 -440
  135. arviz/plots/violinplot.py +0 -192
  136. arviz/preview.py +0 -58
  137. arviz/py.typed +0 -0
  138. arviz/rcparams.py +0 -606
  139. arviz/sel_utils.py +0 -223
  140. arviz/static/css/style.css +0 -340
  141. arviz/static/html/icons-svg-inline.html +0 -15
  142. arviz/stats/__init__.py +0 -37
  143. arviz/stats/density_utils.py +0 -1013
  144. arviz/stats/diagnostics.py +0 -1013
  145. arviz/stats/ecdf_utils.py +0 -324
  146. arviz/stats/stats.py +0 -2422
  147. arviz/stats/stats_refitting.py +0 -119
  148. arviz/stats/stats_utils.py +0 -609
  149. arviz/tests/__init__.py +0 -1
  150. arviz/tests/base_tests/__init__.py +0 -1
  151. arviz/tests/base_tests/test_data.py +0 -1679
  152. arviz/tests/base_tests/test_data_zarr.py +0 -143
  153. arviz/tests/base_tests/test_diagnostics.py +0 -511
  154. arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
  155. arviz/tests/base_tests/test_helpers.py +0 -18
  156. arviz/tests/base_tests/test_labels.py +0 -69
  157. arviz/tests/base_tests/test_plot_utils.py +0 -342
  158. arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
  159. arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
  160. arviz/tests/base_tests/test_rcparams.py +0 -317
  161. arviz/tests/base_tests/test_stats.py +0 -925
  162. arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
  163. arviz/tests/base_tests/test_stats_numba.py +0 -45
  164. arviz/tests/base_tests/test_stats_utils.py +0 -384
  165. arviz/tests/base_tests/test_utils.py +0 -376
  166. arviz/tests/base_tests/test_utils_numba.py +0 -87
  167. arviz/tests/conftest.py +0 -46
  168. arviz/tests/external_tests/__init__.py +0 -1
  169. arviz/tests/external_tests/test_data_beanmachine.py +0 -78
  170. arviz/tests/external_tests/test_data_cmdstan.py +0 -398
  171. arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
  172. arviz/tests/external_tests/test_data_emcee.py +0 -166
  173. arviz/tests/external_tests/test_data_numpyro.py +0 -434
  174. arviz/tests/external_tests/test_data_pyjags.py +0 -119
  175. arviz/tests/external_tests/test_data_pyro.py +0 -260
  176. arviz/tests/external_tests/test_data_pystan.py +0 -307
  177. arviz/tests/helpers.py +0 -677
  178. arviz/utils.py +0 -773
  179. arviz/wrappers/__init__.py +0 -13
  180. arviz/wrappers/base.py +0 -236
  181. arviz/wrappers/wrap_pymc.py +0 -36
  182. arviz/wrappers/wrap_stan.py +0 -148
  183. arviz-0.23.3.dist-info/METADATA +0 -264
  184. arviz-0.23.3.dist-info/RECORD +0 -183
  185. arviz-0.23.3.dist-info/top_level.txt +0 -1
arviz/plots/essplot.py DELETED
@@ -1,319 +0,0 @@
1
- """Plot quantile or local effective sample sizes."""
2
-
3
- import numpy as np
4
- import xarray as xr
5
-
6
- from ..data import convert_to_dataset
7
- from ..labels import BaseLabeller
8
- from ..rcparams import rcParams
9
- from ..sel_utils import xarray_var_iter
10
- from ..stats import ess
11
- from ..utils import _var_names, get_coords
12
- from .plot_utils import default_grid, filter_plotters_list, get_plotting_function
13
-
14
-
15
- def plot_ess(
16
- idata,
17
- var_names=None,
18
- filter_vars=None,
19
- kind="local",
20
- relative=False,
21
- coords=None,
22
- figsize=None,
23
- grid=None,
24
- textsize=None,
25
- rug=False,
26
- rug_kind="diverging",
27
- n_points=20,
28
- extra_methods=False,
29
- min_ess=400,
30
- labeller=None,
31
- ax=None,
32
- extra_kwargs=None,
33
- text_kwargs=None,
34
- hline_kwargs=None,
35
- rug_kwargs=None,
36
- backend=None,
37
- backend_kwargs=None,
38
- show=None,
39
- **kwargs,
40
- ):
41
- r"""Generate quantile, local, or evolution ESS plots.
42
-
43
- The local and the quantile ESS plots are recommended for checking
44
- that there are enough samples for all the explored regions of the
45
- parameter space. Checking local and quantile ESS is particularly
46
- relevant when working with HDI intervals as opposed to ESS bulk,
47
- which is suitable for point estimates.
48
-
49
- Parameters
50
- ----------
51
- idata : InferenceData
52
- Any object that can be converted to an :class:`arviz.InferenceData` object
53
- Refer to documentation of :func:`arviz.convert_to_dataset` for details.
54
- var_names : list of str, optional
55
- Variables to be plotted. Prefix the variables by ``~`` when you want to exclude
56
- them from the plot. See :ref:`this section <common_var_names>` for usage examples.
57
- filter_vars : {None, "like", "regex"}, default None
58
- If `None` (default), interpret `var_names` as the real variables names. If "like",
59
- interpret `var_names` as substrings of the real variables names. If "regex",
60
- interpret `var_names` as regular expressions on the real variables names. See
61
- :ref:`this section <common_filter_vars>` for usage examples.
62
- kind : {"local", "quantile", "evolution"}, default "local"
63
- Specify the kind of plot:
64
-
65
- * The ``kind="local"`` argument plots the ESS for estimating the probability
66
- of small intervals between consecutive quantiles of the posterior.
67
- * The ``kind="quantile"`` argument plots the ESS for estimating
68
- quantiles of the posterior, evenly spaced in (0, 1).
69
- * The ``kind="evolution"`` argument plots the bulk and tail ESS
70
- computed on an increasing number of draws of the posterior.
71
-
72
- relative : bool, default False
73
- Show relative ess in plot ``ress = ess / N``.
74
- coords : dict, optional
75
- Coordinates of `var_names` to be plotted. Passed to :meth:`xarray.Dataset.sel`.
76
- See :ref:`this section <common_coords>` for usage examples.
77
- grid : tuple, optional
78
- Number of rows and columns. By default, the rows and columns are
79
- automatically inferred. See :ref:`this section <common_grid>` for usage examples.
80
- figsize : (float, float), optional
81
- Figure size. If ``None`` it will be defined automatically.
82
- textsize : float, optional
83
- Text size scaling factor for labels, titles and lines. If ``None`` it will be autoscaled
84
- based on `figsize`.
85
- rug : bool, default False
86
- Add a `rug plot <https://en.wikipedia.org/wiki/Rug_plot>`_ for a specific subset of values.
87
- rug_kind : str, default "diverging"
88
- Variable in sample stats to use as rug mask. Must be a boolean variable.
89
- n_points : int, default 20
90
- Number of points for which to plot their quantile/local ess or number of subsets
91
- in the evolution plot.
92
- extra_methods : bool, default False
93
- Plot mean and sd ESS as horizontal lines. Not taken into account if ``kind = 'evolution'``.
94
- min_ess : int, default 400
95
- Minimum number of ESS desired. If ``relative=True`` the line is plotted at
96
- ``min_ess / n_samples`` for local and quantile kinds and as a curve following
97
- the ``min_ess / n`` dependency in evolution kind.
98
- labeller : Labeller, optional
99
- Class providing the method ``make_label_vert`` to generate the labels in the plot titles.
100
- Read the :ref:`label_guide` for more details and usage examples.
101
- ax : 2D array-like of matplotlib_axes or bokeh_figure, optional
102
- A 2D array of locations into which to plot the densities. If not supplied, ArviZ will create
103
- its own array of plot areas (and return it).
104
- extra_kwargs : dict, optional
105
- If evolution plot, `extra_kwargs` is used to plot ess tail and differentiate it
106
- from ess bulk. Otherwise, passed to extra methods lines.
107
- text_kwargs : dict, optional
108
- Only taken into account when ``extra_methods=True``. kwargs passed to ax.annotate
109
- for extra methods lines labels. It accepts the additional
110
- key ``x`` to set ``xy=(text_kwargs["x"], mcse)``
111
- hline_kwargs : dict, optional
112
- kwargs passed to :func:`~matplotlib.axes.Axes.axhline` or to :class:`~bokeh.models.Span`
113
- depending on the backend for the horizontal minimum ESS line.
114
- For relative ess evolution plots the kwargs are passed to
115
- :func:`~matplotlib.axes.Axes.plot` or to :class:`~bokeh.plotting.figure.line`
116
- rug_kwargs : dict
117
- kwargs passed to rug plot.
118
- backend : {"matplotlib", "bokeh"}, optional
119
- Select plotting backend. Defaults to ``rcParams["plot.backend"]``.
120
- backend_kwargs : dict, optional
121
- These are kwargs specific to the backend being used, passed to
122
- :func:`matplotlib.pyplot.subplots` or :class:`bokeh.plotting.figure`.
123
- For additional documentation check the plotting method of the backend.
124
- show : bool, optional
125
- Call backend show function.
126
- **kwargs
127
- Passed as-is to :meth:`mpl:matplotlib.axes.Axes.hist` or
128
- :meth:`mpl:matplotlib.axes.Axes.plot` function depending on the
129
- value of `kind`.
130
-
131
- Returns
132
- -------
133
- axes : matplotlib_axes or bokeh_figure
134
-
135
- See Also
136
- --------
137
- ess : Calculate estimate of the effective sample size.
138
-
139
- References
140
- ----------
141
- .. [1] Vehtari et al. (2021). Rank-normalization, folding, and
142
- localization: An improved Rhat for assessing convergence of
143
- MCMC. Bayesian analysis, 16(2):667-718.
144
-
145
- Examples
146
- --------
147
- Plot local ESS.
148
-
149
- .. plot::
150
- :context: close-figs
151
-
152
- >>> import arviz as az
153
- >>> idata = az.load_arviz_data("centered_eight")
154
- >>> coords = {"school": ["Choate", "Lawrenceville"]}
155
- >>> az.plot_ess(
156
- ... idata, kind="local", var_names=["mu", "theta"], coords=coords
157
- ... )
158
-
159
- Plot ESS evolution as the number of samples increases. When the model is converging properly,
160
- both lines in this plot should be roughly linear.
161
-
162
- .. plot::
163
- :context: close-figs
164
-
165
- >>> az.plot_ess(
166
- ... idata, kind="evolution", var_names=["mu", "theta"], coords=coords
167
- ... )
168
-
169
- Customize local ESS plot to look like reference paper.
170
-
171
- .. plot::
172
- :context: close-figs
173
-
174
- >>> az.plot_ess(
175
- ... idata, kind="local", var_names=["mu"], drawstyle="steps-mid", color="k",
176
- ... linestyle="-", marker=None, rug=True, rug_kwargs={"color": "r"}
177
- ... )
178
-
179
- Customize ESS evolution plot to look like reference paper.
180
-
181
- .. plot::
182
- :context: close-figs
183
-
184
- >>> extra_kwargs = {"color": "lightsteelblue"}
185
- >>> az.plot_ess(
186
- ... idata, kind="evolution", var_names=["mu"],
187
- ... color="royalblue", extra_kwargs=extra_kwargs
188
- ... )
189
-
190
- """
191
- valid_kinds = ("local", "quantile", "evolution")
192
- kind = kind.lower()
193
- if kind not in valid_kinds:
194
- raise ValueError(f"Invalid kind, kind must be one of {valid_kinds} not {kind}")
195
-
196
- if coords is None:
197
- coords = {}
198
- if "chain" in coords or "draw" in coords:
199
- raise ValueError("chain and draw are invalid coordinates for this kind of plot")
200
- if labeller is None:
201
- labeller = BaseLabeller()
202
- extra_methods = False if kind == "evolution" else extra_methods  # mean/sd ESS lines are never drawn on evolution plots
203
-
204
- data = get_coords(convert_to_dataset(idata, group="posterior"), coords)
205
- var_names = _var_names(var_names, data, filter_vars)
206
- n_draws = data.sizes["draw"]
207
- n_samples = n_draws * data.sizes["chain"]
208
-
209
- ess_tail_dataset = None  # only computed for kind="evolution"
210
- mean_ess = None  # mean_ess / sd_ess are only computed when extra_methods is True
211
- sd_ess = None
212
-
213
- if kind == "quantile":
214
- probs = np.linspace(1 / n_points, 1 - 1 / n_points, n_points)  # n_points quantile levels strictly inside (0, 1)
215
- xdata = probs
216
- ylabel = "{} for quantiles"
217
- ess_dataset = xr.concat(
218
- [
219
- ess(data, var_names=var_names, relative=relative, method="quantile", prob=p)
220
- for p in probs
221
- ],
222
- dim="ess_dim",
223
- )
224
- elif kind == "local":
225
- probs = np.linspace(0, 1, n_points, endpoint=False)  # left edges of n_points intervals of width 1 / n_points
226
- xdata = probs
227
- ylabel = "{} for small intervals"
228
- ess_dataset = xr.concat(
229
- [
230
- ess(
231
- data,
232
- var_names=var_names,
233
- relative=relative,
234
- method="local",
235
- prob=[p, p + 1 / n_points],
236
- )
237
- for p in probs
238
- ],
239
- dim="ess_dim",
240
- )
241
- else:
242
- first_draw = data.draw.values[0]
243
- ylabel = "{}"
244
- xdata = np.linspace(n_samples / n_points, n_samples, n_points)  # x axis: total samples (all chains) per subset
245
- draw_divisions = np.linspace(n_draws // n_points, n_draws, n_points, dtype=int)  # draw offsets from first_draw ending each growing subset
246
- ess_dataset = xr.concat(
247
- [
248
- ess(
249
- data.sel(draw=slice(first_draw + draw_div)),
250
- var_names=var_names,
251
- relative=relative,
252
- method="bulk",
253
- )
254
- for draw_div in draw_divisions
255
- ],
256
- dim="ess_dim",
257
- )
258
- ess_tail_dataset = xr.concat(
259
- [
260
- ess(
261
- data.sel(draw=slice(first_draw + draw_div)),
262
- var_names=var_names,
263
- relative=relative,
264
- method="tail",
265
- )
266
- for draw_div in draw_divisions
267
- ],
268
- dim="ess_dim",
269
- )
270
-
271
- plotters = filter_plotters_list(
272
- list(xarray_var_iter(ess_dataset, var_names=var_names, skip_dims={"ess_dim"})), "plot_ess"
273
- )
274
- length_plotters = len(plotters)
275
- rows, cols = default_grid(length_plotters, grid=grid)
276
-
277
- if extra_methods:
278
- mean_ess = ess(data, var_names=var_names, method="mean", relative=relative)
279
- sd_ess = ess(data, var_names=var_names, method="sd", relative=relative)
280
-
281
- essplot_kwargs = dict(
282
- ax=ax,
283
- plotters=plotters,
284
- xdata=xdata,
285
- ess_tail_dataset=ess_tail_dataset,
286
- mean_ess=mean_ess,
287
- sd_ess=sd_ess,
288
- idata=idata,
289
- data=data,
290
- kind=kind,
291
- extra_methods=extra_methods,
292
- textsize=textsize,
293
- rows=rows,
294
- cols=cols,
295
- figsize=figsize,
296
- kwargs=kwargs,
297
- extra_kwargs=extra_kwargs,
298
- text_kwargs=text_kwargs,
299
- n_samples=n_samples,
300
- relative=relative,
301
- min_ess=min_ess,
302
- labeller=labeller,
303
- ylabel=ylabel,
304
- rug=rug,
305
- rug_kind=rug_kind,
306
- rug_kwargs=rug_kwargs,
307
- hline_kwargs=hline_kwargs,
308
- backend_kwargs=backend_kwargs,
309
- show=show,
310
- )
311
-
312
- if backend is None:
313
- backend = rcParams["plot.backend"]
314
- backend = backend.lower()
315
-
316
- # TODO: Add backend kwargs
317
- plot = get_plotting_function("plot_ess", "essplot", backend)
318
- ax = plot(**essplot_kwargs)
319
- return ax
arviz/plots/forestplot.py DELETED
@@ -1,304 +0,0 @@
1
- """Forest plot."""
2
-
3
- from ..data import convert_to_dataset
4
- from ..labels import BaseLabeller, NoModelLabeller
5
- from ..rcparams import rcParams
6
- from ..utils import _var_names, get_coords
7
- from .plot_utils import get_plotting_function
8
-
9
-
10
- def plot_forest(
11
- data,
12
- kind="forestplot",
13
- model_names=None,
14
- var_names=None,
15
- filter_vars=None,
16
- transform=None,
17
- coords=None,
18
- combined=False,
19
- combine_dims=None,
20
- hdi_prob=None,
21
- rope=None,
22
- quartiles=True,
23
- ess=False,
24
- r_hat=False,
25
- colors="cycle",
26
- textsize=None,
27
- linewidth=None,
28
- markersize=None,
29
- legend=True,
30
- labeller=None,
31
- ridgeplot_alpha=None,
32
- ridgeplot_overlap=2,
33
- ridgeplot_kind="auto",
34
- ridgeplot_truncate=True,
35
- ridgeplot_quantiles=None,
36
- figsize=None,
37
- ax=None,
38
- backend=None,
39
- backend_config=None,
40
- backend_kwargs=None,
41
- show=None,
42
- ):
43
- r"""Forest plot to compare HDI intervals from a number of distributions.
44
-
45
- Generate forest or ridge plots to compare distributions from a model or list of models.
46
- Additionally, the function can display effective sample sizes (ess) and Rhats to visualize
47
- convergence diagnostics alongside the distributions.
48
-
49
- Parameters
50
- ----------
51
- data : InferenceData
52
- Any object that can be converted to an :class:`arviz.InferenceData` object
53
- Refer to documentation of :func:`arviz.convert_to_dataset` for details.
54
- kind : {"forestplot", "ridgeplot"}, default "forestplot"
55
- Specify the kind of plot:
56
-
57
- * The ``kind="forestplot"`` generates credible intervals, where the central points are the
58
- estimated posterior median, the thick lines are the central quartiles, and the thin lines
59
- represent the :math:`100\times(hdi\_prob)\%` highest density intervals.
60
- * The ``kind="ridgeplot"`` option generates density plots (kernel density estimate or
61
- histograms) in the same graph. Ridge plots can be configured to have different overlap,
62
- truncation bounds and quantile markers.
63
-
64
- model_names : list of str, optional
65
- List with names for the models in the list of data. Useful when plotting more than one
66
- dataset.
67
- var_names : list of str, optional
68
- Variables to be plotted. Prefix the variables by ``~`` when you want to exclude
69
- them from the plot. See :ref:`this section <common_var_names>` for usage examples.
70
- combine_dims : set_like of str, optional
71
- List of dimensions to reduce. Defaults to reducing only the "chain" and "draw" dimensions.
72
- See :ref:`this section <common_combine_dims>` for usage examples.
73
- filter_vars : {None, "like", "regex"}, default None
74
- If `None` (default), interpret `var_names` as the real variables names. If "like",
75
- interpret `var_names` as substrings of the real variables names. If "regex",
76
- interpret `var_names` as regular expressions on the real variables names. See
77
- :ref:`this section <common_filter_vars>` for usage examples.
78
- transform : callable or dict, optional
79
- Function to transform each dataset, or a dict mapping variable names to the function applied to that variable only. Defaults to None, i.e., the identity function.
80
- coords : dict, optional
81
- Coordinates of ``var_names`` to be plotted. Passed to :meth:`xarray.Dataset.sel`.
82
- See :ref:`this section <common_coords>` for usage examples.
83
- combined : bool, default False
84
- Flag for combining multiple chains into a single chain. If False, chains will
85
- be plotted separately. See :ref:`this section <common_combine>` for usage examples.
86
- hdi_prob : float, optional
87
- Plots highest posterior density interval for chosen percentage of density. Must be in the interval (0, 1]. Defaults to ``rcParams["stats.ci_prob"]``.
88
- See :ref:`this section <common_hdi_prob>` for usage examples.
89
- rope : list, tuple or dictionary of {str : tuples or lists}, optional
90
- A dictionary of tuples with the lower and upper values of the Region Of Practical
91
- Equivalence. See :ref:`this section <common_rope>` for usage examples.
92
- quartiles : bool, default True
93
- Flag for plotting the interquartile range, in addition to the ``hdi_prob`` intervals.
94
- r_hat : bool, default False
95
- Flag for plotting Split R-hat statistics. Requires 2 or more chains.
96
- ess : bool, default False
97
- Flag for plotting the effective sample size.
98
- colors : list or string, optional
99
- list with valid matplotlib colors, one color per model. Alternatively, a string can be passed.
100
- If the string is `cycle`, it will automatically choose a color per model from matplotlib's
101
- cycle. If a single color is passed, eg 'k', 'C2', 'red' this color will be used for all
102
- models. Defaults to 'cycle'.
103
- textsize : float, optional
104
- Text size scaling factor for labels, titles and lines. If `None` it will be autoscaled based
105
- on ``figsize``.
106
- linewidth : int, optional
107
- Line width throughout. If `None` it will be autoscaled based on ``figsize``.
108
- markersize : int, optional
109
- Markersize throughout. If `None` it will be autoscaled based on ``figsize``.
110
- legend : bool, optional
111
- Show a legend with the color encoded model information.
112
- Defaults to True; always disabled when a single model is passed.
113
- labeller : Labeller, optional
114
- Class providing the method ``make_label_vert`` to generate the labels in the plot titles.
115
- Read the :ref:`label_guide` for more details and usage examples.
116
- ridgeplot_alpha: float, optional
117
- Transparency for ridgeplot fill. If ``ridgeplot_alpha=0``, border is colored by model,
118
- otherwise a `black` outline is used.
119
- ridgeplot_overlap : float, default 2
120
- Overlap height for ridgeplots.
121
- ridgeplot_kind : string, optional
122
- By default ("auto") continuous variables are plotted using KDEs and discrete ones using
123
- histograms. To override this use "hist" to plot histograms and "density" for KDEs.
124
- ridgeplot_truncate : bool, default True
125
- Whether to truncate densities according to the value of ``hdi_prob``.
126
- ridgeplot_quantiles : list, optional
127
- Quantiles in ascending order used to segment the KDE. Use [.25, .5, .75] for quartiles.
128
- figsize : (float, float), optional
129
- Figure size. If `None`, it will be defined automatically.
130
- ax : axes, optional
131
- :class:`matplotlib.axes.Axes` or :class:`bokeh.plotting.Figure`.
132
- backend : {"matplotlib", "bokeh"}, optional
133
- Select plotting backend. Defaults to ``rcParams["plot.backend"]``.
134
- backend_config : dict, optional
135
- Currently specifies the bounds to use for bokeh axes. Defaults to value set in ``rcParams``.
136
- backend_kwargs : dict, optional
137
- These are kwargs specific to the backend being used, passed to
138
- :func:`matplotlib.pyplot.subplots` or :class:`bokeh.plotting.figure`.
139
- For additional documentation check the plotting method of the backend.
140
- show : bool, optional
141
- Call backend show function.
142
-
143
- Returns
144
- -------
145
- 1D ndarray of matplotlib_axes or bokeh_figures
146
-
147
- See Also
148
- --------
149
- plot_posterior : Plot Posterior densities in the style of John K. Kruschke's book.
150
- plot_density : Generate KDE plots for continuous variables and histograms for discrete ones.
151
- summary : Create a data frame with summary statistics.
152
-
153
- Examples
154
- --------
155
- Forestplot
156
-
157
- .. plot::
158
- :context: close-figs
159
-
160
- >>> import arviz as az
161
- >>> non_centered_data = az.load_arviz_data('non_centered_eight')
162
- >>> axes = az.plot_forest(non_centered_data,
163
- ... kind='forestplot',
164
- ... var_names=["^the"],
165
- ... filter_vars="regex",
166
- ... combined=True,
167
- ... figsize=(9, 7))
168
- >>> axes[0].set_title('Estimated theta for 8 schools model')
169
-
170
- Forestplot with multiple datasets
171
-
172
- .. plot::
173
- :context: close-figs
174
-
175
- >>> centered_data = az.load_arviz_data('centered_eight')
176
- >>> axes = az.plot_forest([non_centered_data, centered_data],
177
- ... model_names = ["non centered eight", "centered eight"],
178
- ... kind='forestplot',
179
- ... var_names=["^the"],
180
- ... filter_vars="regex",
181
- ... combined=True,
182
- ... figsize=(9, 7))
183
- >>> axes[0].set_title('Estimated theta for 8 schools models')
184
-
185
- Ridgeplot
186
-
187
- .. plot::
188
- :context: close-figs
189
-
190
- >>> axes = az.plot_forest(non_centered_data,
191
- ... kind='ridgeplot',
192
- ... var_names=['theta'],
193
- ... combined=True,
194
- ... ridgeplot_overlap=3,
195
- ... colors='white',
196
- ... figsize=(9, 7))
197
- >>> axes[0].set_title('Estimated theta for 8 schools model')
198
-
199
- Ridgeplot non-truncated and with quantiles
200
-
201
- .. plot::
202
- :context: close-figs
203
-
204
- >>> axes = az.plot_forest(non_centered_data,
205
- ... kind='ridgeplot',
206
- ... var_names=['theta'],
207
- ... combined=True,
208
- ... ridgeplot_truncate=False,
209
- ... ridgeplot_quantiles=[.25, .5, .75],
210
- ... ridgeplot_overlap=0.7,
211
- ... colors='white',
212
- ... figsize=(9, 7))
213
- >>> axes[0].set_title('Estimated theta for 8 schools model')
214
- """
215
- if not isinstance(data, (list, tuple)):
216
- data = [data]
217
- if len(data) == 1:
218
- legend = False  # a single model has nothing to tell apart in a legend
219
-
220
- if coords is None:
221
- coords = {}
222
-
223
- if combine_dims is None:
224
- combine_dims = set()
225
-
226
- if labeller is None:
227
- labeller = NoModelLabeller() if legend else BaseLabeller()
228
-
229
- datasets = [convert_to_dataset(datum) for datum in reversed(data)]  # NOTE(review): models are handled last-to-first, presumably for the backends' drawing order — confirm
230
- if transform is not None:
231
- if callable(transform):
232
- datasets = [transform(dataset) for dataset in datasets]
233
- elif isinstance(transform, dict):
234
- transformed_datasets = []
235
- for dataset in datasets:
236
- new_dataset = dataset.copy()
237
- for var_name, func in transform.items():
238
- if var_name in new_dataset:
239
- new_dataset[var_name] = func(new_dataset[var_name])
240
- transformed_datasets.append(new_dataset)
241
- datasets = transformed_datasets
242
- else:
243
- raise ValueError("transform must be either a callable or a dict {var_name: callable}")
244
- datasets = get_coords(
245
- datasets, list(reversed(coords)) if isinstance(coords, (list, tuple)) else coords  # a per-model coords list is reversed to stay aligned with `datasets`
246
- )
247
-
248
- var_names = _var_names(var_names, datasets, filter_vars)
249
-
250
- ncols, width_ratios = 1, [3]  # main panel; ess / r_hat each add one narrower column below
251
-
252
- if ess:
253
- ncols += 1
254
- width_ratios.append(1)
255
-
256
- if r_hat:
257
- ncols += 1
258
- width_ratios.append(1)
259
-
260
- if hdi_prob is None:
261
- hdi_prob = rcParams["stats.ci_prob"]
262
- elif not 1 >= hdi_prob > 0:
263
- raise ValueError("The value of hdi_prob should be in the interval (0, 1]")
264
-
265
- plot_forest_kwargs = dict(
266
- ax=ax,
267
- datasets=datasets,
268
- var_names=var_names,
269
- model_names=model_names,
270
- combined=combined,
271
- combine_dims=combine_dims,
272
- colors=colors,
273
- figsize=figsize,
274
- width_ratios=width_ratios,
275
- linewidth=linewidth,
276
- markersize=markersize,
277
- kind=kind,
278
- ncols=ncols,
279
- hdi_prob=hdi_prob,
280
- quartiles=quartiles,
281
- rope=rope,
282
- ridgeplot_overlap=ridgeplot_overlap,
283
- ridgeplot_alpha=ridgeplot_alpha,
284
- ridgeplot_kind=ridgeplot_kind,
285
- ridgeplot_truncate=ridgeplot_truncate,
286
- ridgeplot_quantiles=ridgeplot_quantiles,
287
- textsize=textsize,
288
- legend=legend,
289
- labeller=labeller,
290
- ess=ess,
291
- r_hat=r_hat,
292
- backend_kwargs=backend_kwargs,
293
- backend_config=backend_config,
294
- show=show,
295
- )
296
-
297
- if backend is None:
298
- backend = rcParams["plot.backend"]
299
- backend = backend.lower()
300
-
301
- # TODO: Add backend kwargs
302
- plot = get_plotting_function("plot_forest", "forestplot", backend)
303
- axes = plot(**plot_forest_kwargs)
304
- return axes