arviz 0.23.3__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185) hide show
  1. arviz/__init__.py +52 -367
  2. arviz-1.0.0rc0.dist-info/METADATA +182 -0
  3. arviz-1.0.0rc0.dist-info/RECORD +5 -0
  4. {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
  5. {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
  6. arviz/data/__init__.py +0 -55
  7. arviz/data/base.py +0 -596
  8. arviz/data/converters.py +0 -203
  9. arviz/data/datasets.py +0 -161
  10. arviz/data/example_data/code/radon/radon.json +0 -326
  11. arviz/data/example_data/data/centered_eight.nc +0 -0
  12. arviz/data/example_data/data/non_centered_eight.nc +0 -0
  13. arviz/data/example_data/data_local.json +0 -12
  14. arviz/data/example_data/data_remote.json +0 -58
  15. arviz/data/inference_data.py +0 -2386
  16. arviz/data/io_beanmachine.py +0 -112
  17. arviz/data/io_cmdstan.py +0 -1036
  18. arviz/data/io_cmdstanpy.py +0 -1233
  19. arviz/data/io_datatree.py +0 -23
  20. arviz/data/io_dict.py +0 -462
  21. arviz/data/io_emcee.py +0 -317
  22. arviz/data/io_json.py +0 -54
  23. arviz/data/io_netcdf.py +0 -68
  24. arviz/data/io_numpyro.py +0 -497
  25. arviz/data/io_pyjags.py +0 -378
  26. arviz/data/io_pyro.py +0 -333
  27. arviz/data/io_pystan.py +0 -1095
  28. arviz/data/io_zarr.py +0 -46
  29. arviz/data/utils.py +0 -139
  30. arviz/labels.py +0 -210
  31. arviz/plots/__init__.py +0 -61
  32. arviz/plots/autocorrplot.py +0 -171
  33. arviz/plots/backends/__init__.py +0 -223
  34. arviz/plots/backends/bokeh/__init__.py +0 -166
  35. arviz/plots/backends/bokeh/autocorrplot.py +0 -101
  36. arviz/plots/backends/bokeh/bfplot.py +0 -23
  37. arviz/plots/backends/bokeh/bpvplot.py +0 -193
  38. arviz/plots/backends/bokeh/compareplot.py +0 -167
  39. arviz/plots/backends/bokeh/densityplot.py +0 -239
  40. arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
  41. arviz/plots/backends/bokeh/distplot.py +0 -183
  42. arviz/plots/backends/bokeh/dotplot.py +0 -113
  43. arviz/plots/backends/bokeh/ecdfplot.py +0 -73
  44. arviz/plots/backends/bokeh/elpdplot.py +0 -203
  45. arviz/plots/backends/bokeh/energyplot.py +0 -155
  46. arviz/plots/backends/bokeh/essplot.py +0 -176
  47. arviz/plots/backends/bokeh/forestplot.py +0 -772
  48. arviz/plots/backends/bokeh/hdiplot.py +0 -54
  49. arviz/plots/backends/bokeh/kdeplot.py +0 -268
  50. arviz/plots/backends/bokeh/khatplot.py +0 -163
  51. arviz/plots/backends/bokeh/lmplot.py +0 -185
  52. arviz/plots/backends/bokeh/loopitplot.py +0 -211
  53. arviz/plots/backends/bokeh/mcseplot.py +0 -184
  54. arviz/plots/backends/bokeh/pairplot.py +0 -328
  55. arviz/plots/backends/bokeh/parallelplot.py +0 -81
  56. arviz/plots/backends/bokeh/posteriorplot.py +0 -324
  57. arviz/plots/backends/bokeh/ppcplot.py +0 -379
  58. arviz/plots/backends/bokeh/rankplot.py +0 -149
  59. arviz/plots/backends/bokeh/separationplot.py +0 -107
  60. arviz/plots/backends/bokeh/traceplot.py +0 -436
  61. arviz/plots/backends/bokeh/violinplot.py +0 -164
  62. arviz/plots/backends/matplotlib/__init__.py +0 -124
  63. arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
  64. arviz/plots/backends/matplotlib/bfplot.py +0 -78
  65. arviz/plots/backends/matplotlib/bpvplot.py +0 -177
  66. arviz/plots/backends/matplotlib/compareplot.py +0 -135
  67. arviz/plots/backends/matplotlib/densityplot.py +0 -194
  68. arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
  69. arviz/plots/backends/matplotlib/distplot.py +0 -178
  70. arviz/plots/backends/matplotlib/dotplot.py +0 -116
  71. arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
  72. arviz/plots/backends/matplotlib/elpdplot.py +0 -189
  73. arviz/plots/backends/matplotlib/energyplot.py +0 -113
  74. arviz/plots/backends/matplotlib/essplot.py +0 -180
  75. arviz/plots/backends/matplotlib/forestplot.py +0 -656
  76. arviz/plots/backends/matplotlib/hdiplot.py +0 -48
  77. arviz/plots/backends/matplotlib/kdeplot.py +0 -177
  78. arviz/plots/backends/matplotlib/khatplot.py +0 -241
  79. arviz/plots/backends/matplotlib/lmplot.py +0 -149
  80. arviz/plots/backends/matplotlib/loopitplot.py +0 -144
  81. arviz/plots/backends/matplotlib/mcseplot.py +0 -161
  82. arviz/plots/backends/matplotlib/pairplot.py +0 -355
  83. arviz/plots/backends/matplotlib/parallelplot.py +0 -58
  84. arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
  85. arviz/plots/backends/matplotlib/ppcplot.py +0 -478
  86. arviz/plots/backends/matplotlib/rankplot.py +0 -119
  87. arviz/plots/backends/matplotlib/separationplot.py +0 -97
  88. arviz/plots/backends/matplotlib/traceplot.py +0 -526
  89. arviz/plots/backends/matplotlib/tsplot.py +0 -121
  90. arviz/plots/backends/matplotlib/violinplot.py +0 -148
  91. arviz/plots/bfplot.py +0 -128
  92. arviz/plots/bpvplot.py +0 -308
  93. arviz/plots/compareplot.py +0 -177
  94. arviz/plots/densityplot.py +0 -284
  95. arviz/plots/distcomparisonplot.py +0 -197
  96. arviz/plots/distplot.py +0 -233
  97. arviz/plots/dotplot.py +0 -233
  98. arviz/plots/ecdfplot.py +0 -372
  99. arviz/plots/elpdplot.py +0 -174
  100. arviz/plots/energyplot.py +0 -147
  101. arviz/plots/essplot.py +0 -319
  102. arviz/plots/forestplot.py +0 -304
  103. arviz/plots/hdiplot.py +0 -211
  104. arviz/plots/kdeplot.py +0 -357
  105. arviz/plots/khatplot.py +0 -236
  106. arviz/plots/lmplot.py +0 -380
  107. arviz/plots/loopitplot.py +0 -224
  108. arviz/plots/mcseplot.py +0 -194
  109. arviz/plots/pairplot.py +0 -281
  110. arviz/plots/parallelplot.py +0 -204
  111. arviz/plots/plot_utils.py +0 -599
  112. arviz/plots/posteriorplot.py +0 -298
  113. arviz/plots/ppcplot.py +0 -369
  114. arviz/plots/rankplot.py +0 -232
  115. arviz/plots/separationplot.py +0 -167
  116. arviz/plots/styles/arviz-bluish.mplstyle +0 -1
  117. arviz/plots/styles/arviz-brownish.mplstyle +0 -1
  118. arviz/plots/styles/arviz-colors.mplstyle +0 -2
  119. arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
  120. arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
  121. arviz/plots/styles/arviz-doc.mplstyle +0 -88
  122. arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
  123. arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
  124. arviz/plots/styles/arviz-greenish.mplstyle +0 -1
  125. arviz/plots/styles/arviz-orangish.mplstyle +0 -1
  126. arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
  127. arviz/plots/styles/arviz-purplish.mplstyle +0 -1
  128. arviz/plots/styles/arviz-redish.mplstyle +0 -1
  129. arviz/plots/styles/arviz-royish.mplstyle +0 -1
  130. arviz/plots/styles/arviz-viridish.mplstyle +0 -1
  131. arviz/plots/styles/arviz-white.mplstyle +0 -40
  132. arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
  133. arviz/plots/traceplot.py +0 -273
  134. arviz/plots/tsplot.py +0 -440
  135. arviz/plots/violinplot.py +0 -192
  136. arviz/preview.py +0 -58
  137. arviz/py.typed +0 -0
  138. arviz/rcparams.py +0 -606
  139. arviz/sel_utils.py +0 -223
  140. arviz/static/css/style.css +0 -340
  141. arviz/static/html/icons-svg-inline.html +0 -15
  142. arviz/stats/__init__.py +0 -37
  143. arviz/stats/density_utils.py +0 -1013
  144. arviz/stats/diagnostics.py +0 -1013
  145. arviz/stats/ecdf_utils.py +0 -324
  146. arviz/stats/stats.py +0 -2422
  147. arviz/stats/stats_refitting.py +0 -119
  148. arviz/stats/stats_utils.py +0 -609
  149. arviz/tests/__init__.py +0 -1
  150. arviz/tests/base_tests/__init__.py +0 -1
  151. arviz/tests/base_tests/test_data.py +0 -1679
  152. arviz/tests/base_tests/test_data_zarr.py +0 -143
  153. arviz/tests/base_tests/test_diagnostics.py +0 -511
  154. arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
  155. arviz/tests/base_tests/test_helpers.py +0 -18
  156. arviz/tests/base_tests/test_labels.py +0 -69
  157. arviz/tests/base_tests/test_plot_utils.py +0 -342
  158. arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
  159. arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
  160. arviz/tests/base_tests/test_rcparams.py +0 -317
  161. arviz/tests/base_tests/test_stats.py +0 -925
  162. arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
  163. arviz/tests/base_tests/test_stats_numba.py +0 -45
  164. arviz/tests/base_tests/test_stats_utils.py +0 -384
  165. arviz/tests/base_tests/test_utils.py +0 -376
  166. arviz/tests/base_tests/test_utils_numba.py +0 -87
  167. arviz/tests/conftest.py +0 -46
  168. arviz/tests/external_tests/__init__.py +0 -1
  169. arviz/tests/external_tests/test_data_beanmachine.py +0 -78
  170. arviz/tests/external_tests/test_data_cmdstan.py +0 -398
  171. arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
  172. arviz/tests/external_tests/test_data_emcee.py +0 -166
  173. arviz/tests/external_tests/test_data_numpyro.py +0 -434
  174. arviz/tests/external_tests/test_data_pyjags.py +0 -119
  175. arviz/tests/external_tests/test_data_pyro.py +0 -260
  176. arviz/tests/external_tests/test_data_pystan.py +0 -307
  177. arviz/tests/helpers.py +0 -677
  178. arviz/utils.py +0 -773
  179. arviz/wrappers/__init__.py +0 -13
  180. arviz/wrappers/base.py +0 -236
  181. arviz/wrappers/wrap_pymc.py +0 -36
  182. arviz/wrappers/wrap_stan.py +0 -148
  183. arviz-0.23.3.dist-info/METADATA +0 -264
  184. arviz-0.23.3.dist-info/RECORD +0 -183
  185. arviz-0.23.3.dist-info/top_level.txt +0 -1
@@ -1,177 +0,0 @@
1
- """Summary plot for model comparison."""
2
-
3
- import numpy as np
4
-
5
- from ..labels import BaseLabeller
6
- from ..rcparams import rcParams
7
- from .plot_utils import get_plotting_function
8
-
9
-
10
def plot_compare(
    comp_df,
    insample_dev=False,
    plot_standard_error=True,
    plot_ic_diff=False,
    order_by_rank=True,
    legend=False,
    title=True,
    figsize=None,
    textsize=None,
    labeller=None,
    plot_kwargs=None,
    ax=None,
    backend=None,
    backend_kwargs=None,
    show=None,
):
    r"""Summary plot for model comparison.

    Models are compared based on their expected log pointwise predictive density (ELPD).
    This plot is in the style of the one used in [2]_. Chapter 6 in the first edition
    or 7 in the second.

    Notes
    -----
    The ELPD is estimated either by Pareto smoothed importance sampling leave-one-out
    cross-validation (LOO) or using the widely applicable information criterion (WAIC).
    We recommend LOO in line with the work presented by [1]_.

    Parameters
    ----------
    comp_df : pandas.DataFrame
        Result of the :func:`arviz.compare` method. It must contain an ``elpd_loo`` or
        ``elpd_waic`` column and, when ``order_by_rank`` is True, a ``rank`` column.
        The DataFrame passed in is not modified.
    insample_dev : bool, default False
        Plot in-sample ELPD, that is the value of the information criteria without the
        penalization given by the effective number of parameters (p_loo or p_waic).
    plot_standard_error : bool, default True
        Plot the standard error of the ELPD.
    plot_ic_diff : bool, default False
        Plot standard error of the difference in ELPD between each model
        and the top-ranked model.
    order_by_rank : bool, default True
        If True, sort the models by their ``rank`` column so the best model is used as
        reference. Sorting is done on a copy of `comp_df`.
    legend : bool, default False
        Add legend to figure.
    title : bool, default True
        Show a title with a description of how to interpret the plot.
    figsize : (float, float), optional
        If `None`, size is (6, num of models) inches.
    textsize : float, optional
        Text size scaling factor for labels, titles and lines. If `None` it will be autoscaled based
        on `figsize`.
    labeller : Labeller, optional
        Class providing the method ``model_name_to_str`` to generate the model labels on the
        y axis. Read the :ref:`label_guide` for more details and usage examples.
    plot_kwargs : dict, optional
        Optional arguments for plot elements. Currently accepts 'color_ic',
        'marker_ic', 'color_insample_dev', 'marker_insample_dev', 'color_dse',
        'marker_dse', 'ls_min_ic', 'color_ls_min_ic', 'fontsize'.
    ax : matplotlib_axes or bokeh_figure, optional
        Matplotlib axes or bokeh figure.
    backend : {"matplotlib", "bokeh"}, default "matplotlib"
        Select plotting backend.
    backend_kwargs : dict, optional
        These are kwargs specific to the backend being used, passed to
        :func:`matplotlib.pyplot.subplots` or :class:`bokeh.plotting.figure`.
        For additional documentation check the plotting method of the backend.
    show : bool, optional
        Call backend show function.

    Returns
    -------
    axes : matplotlib_axes or bokeh_figure

    Raises
    ------
    ValueError
        If `comp_df` contains neither an ``elpd_loo`` nor an ``elpd_waic`` column.

    See Also
    --------
    plot_elpd : Plot pointwise elpd differences between two or more models.
    compare : Compare models based on PSIS-LOO loo or WAIC waic cross-validation.
    loo : Compute Pareto-smoothed importance sampling leave-one-out cross-validation (PSIS-LOO-CV).
    waic : Compute the widely applicable information criterion.

    References
    ----------
    .. [1] Vehtari et al. (2016). Practical Bayesian model evaluation using leave-one-out
       cross-validation and WAIC https://arxiv.org/abs/1507.04544

    .. [2] McElreath R. (2022). Statistical Rethinking A Bayesian Course with Examples in
       R and Stan, Second edition, CRC Press.

    Examples
    --------
    Show default compare plot

    .. plot::
        :context: close-figs

        >>> import arviz as az
        >>> model_compare = az.compare({'Centered 8 schools': az.load_arviz_data('centered_eight'),
        ...                             'Non-centered 8 schools': az.load_arviz_data('non_centered_eight')})
        >>> az.plot_compare(model_compare)

    Include the in-sample ELPD

    .. plot::
        :context: close-figs

        >>> az.plot_compare(model_compare, insample_dev=True)

    """
    if plot_kwargs is None:
        plot_kwargs = {}

    if labeller is None:
        labeller = BaseLabeller()

    # Find which ELPD estimate the comparison was computed with (case-insensitive).
    _information_criterion = ["elpd_loo", "elpd_waic"]
    column_index = [c.lower() for c in comp_df.columns]
    for information_criterion in _information_criterion:
        if information_criterion in column_index:
            break
    else:
        raise ValueError(
            "comp_df must contain one of the following "
            f"information criterion: {_information_criterion}"
        )

    # Sort a copy rather than in place so the caller's DataFrame is left untouched.
    # This must happen before the tick labels are built so that each label matches
    # the row plotted at that position.
    if order_by_rank:
        comp_df = comp_df.sort_values(by="rank")

    # Two tick slots per model (except the last): even slots hold the models and odd
    # slots sit half a step lower, where the ELPD-difference markers go.
    yticks_pos, step = np.linspace(0, -1, (comp_df.shape[0] * 2) - 1, retstep=True)
    yticks_pos[1::2] = yticks_pos[1::2] + step / 2
    labels = [labeller.model_name_to_str(model_name) for model_name in comp_df.index]

    if plot_ic_diff:
        # Only label the model slots; difference slots stay blank.
        yticks_labels = [""] * len(yticks_pos)
        yticks_labels[0] = labels[0]
        yticks_labels[2::2] = labels[1:]
    else:
        yticks_labels = labels

    compareplot_kwargs = dict(
        ax=ax,
        comp_df=comp_df,
        legend=legend,
        title=title,
        figsize=figsize,
        plot_ic_diff=plot_ic_diff,
        plot_standard_error=plot_standard_error,
        insample_dev=insample_dev,
        yticks_pos=yticks_pos,
        yticks_labels=yticks_labels,
        plot_kwargs=plot_kwargs,
        information_criterion=information_criterion,
        textsize=textsize,
        step=step,
        backend_kwargs=backend_kwargs,
        show=show,
    )

    if backend is None:
        backend = rcParams["plot.backend"]
    backend = backend.lower()

    # TODO: Add backend kwargs
    plot = get_plotting_function("plot_compare", "compareplot", backend)
    ax = plot(**compareplot_kwargs)

    return ax
@@ -1,284 +0,0 @@
1
- """KDE and histogram plots for multiple variables."""
2
-
3
- import warnings
4
-
5
- from ..data import convert_to_dataset
6
- from ..labels import BaseLabeller
7
- from ..sel_utils import (
8
- xarray_var_iter,
9
- )
10
- from ..rcparams import rcParams
11
- from ..utils import _var_names
12
- from .plot_utils import default_grid, get_plotting_function
13
-
14
-
15
- # pylint:disable-msg=too-many-function-args
16
- def plot_density(
17
- data,
18
- group="posterior",
19
- data_labels=None,
20
- var_names=None,
21
- filter_vars=None,
22
- combine_dims=None,
23
- transform=None,
24
- hdi_prob=None,
25
- point_estimate="auto",
26
- colors="cycle",
27
- outline=True,
28
- hdi_markers="",
29
- shade=0.0,
30
- bw="default",
31
- circular=False,
32
- grid=None,
33
- figsize=None,
34
- textsize=None,
35
- labeller=None,
36
- ax=None,
37
- backend=None,
38
- backend_kwargs=None,
39
- show=None,
40
- ):
41
- r"""Generate KDE plots for continuous variables and histograms for discrete ones.
42
-
43
- Plots are truncated at their 100*(1-alpha)% highest density intervals. Plots are grouped per
44
- variable and colors assigned to models.
45
-
46
- Parameters
47
- ----------
48
- data : InferenceData or iterable of InferenceData
49
- Any object that can be converted to an :class:`arviz.InferenceData` object, or an Iterator
50
- returning a sequence of such objects.
51
- Refer to documentation of :func:`arviz.convert_to_dataset` for details.
52
- group : {"posterior", "prior"}, default "posterior"
53
- Specifies which InferenceData group should be plotted. If "posterior", then the values
54
- in `posterior_predictive` group are compared to the ones in `observed_data`, if "prior" then
55
- the same comparison happens, but with the values in `prior_predictive` group.
56
- data_labels : list of str, default None
57
- List with names for the datasets passed as "data." Useful when plotting more than one
58
- dataset. Must be the same shape as the data parameter.
59
- var_names : list of str, optional
60
- List of variables to plot. If multiple datasets are supplied and `var_names` is not None,
61
- will print the same set of variables for each dataset. Defaults to None, which results in
62
- all the variables being plotted.
63
- filter_vars : {None, "like", "regex"}, default None
64
- If `None` (default), interpret `var_names` as the real variables names. If "like",
65
- interpret `var_names` as substrings of the real variables names. If "regex",
66
- interpret `var_names` as regular expressions on the real variables names. See
67
- :ref:`this section <common_filter_vars>` for usage examples.
68
- combine_dims : set_like of str, optional
69
- List of dimensions to reduce. Defaults to reducing only the "chain" and "draw" dimensions.
70
- See :ref:`this section <common_combine_dims>` for usage examples.
71
- transform : callable
72
- Function to transform data (defaults to `None` i.e. the identity function).
73
- hdi_prob : float, default 0.94
74
- Probability for the highest density interval. Should be in the interval (0, 1].
75
- See :ref:`this section <common_hdi_prob>` for usage examples.
76
- point_estimate : str, optional
77
- Plot point estimate per variable. Values should be 'mean', 'median', 'mode' or None.
78
- Defaults to 'auto' i.e. it falls back to default set in ``rcParams``.
79
- colors : str or list of str, optional
80
- List with valid matplotlib colors, one color per model. Alternative a string can be passed.
81
- If the string is `cycle`, it will automatically choose a color per model from matplotlib's
82
- cycle. If a single color is passed, e.g. 'k', 'C2' or 'red' this color will be used for all
83
- models. Defaults to `cycle`.
84
- outline : bool, default True
85
- Use a line to draw KDEs and histograms.
86
- hdi_markers : str
87
- A valid `matplotlib.markers` like 'v', used to indicate the limits of the highest density
88
- interval. Defaults to empty string (no marker).
89
- shade : float, default 0
90
- Alpha blending value for the shaded area under the curve, between 0 (no shade) and 1
91
- (opaque).
92
- bw : float or str, optional
93
- If numeric, indicates the bandwidth and must be positive.
94
- If str, indicates the method to estimate the bandwidth and must be
95
- one of "scott", "silverman", "isj" or "experimental" when `circular` is False
96
- and "taylor" (for now) when `circular` is True.
97
- Defaults to "default" which means "experimental" when variable is not circular
98
- and "taylor" when it is.
99
- circular : bool, default False
100
- If True, it interprets the values passed are from a circular variable measured in radians
101
- and a circular KDE is used. Only valid for 1D KDE.
102
- grid : tuple, optional
103
- Number of rows and columns. Defaults to ``None``, the rows and columns are
104
- automatically inferred. See :ref:`this section <common_grid>` for usage examples.
105
- figsize : (float, float), optional
106
- Figure size. If `None` it will be defined automatically.
107
- textsize : float, optional
108
- Text size scaling factor for labels, titles and lines. If `None` it will be autoscaled based
109
- on `figsize`.
110
- labeller : Labeller, optional
111
- Class providing the method ``make_label_vert`` to generate the labels in the plot titles.
112
- Read the :ref:`label_guide` for more details and usage examples.
113
- ax : 2D array-like of matplotlib_axes or bokeh_figure, optional
114
- A 2D array of locations into which to plot the densities. If not supplied, ArviZ will create
115
- its own array of plot areas (and return it).
116
- backend : {"matplotlib", "bokeh"}, default "matplotlib"
117
- Select plotting backend.
118
- backend_kwargs : dict, optional
119
- These are kwargs specific to the backend being used, passed to
120
- :func:`matplotlib.pyplot.subplots` or :class:`bokeh.plotting.figure`.
121
- For additional documentation check the plotting method of the backend.
122
- show : bool, optional
123
- Call backend show function.
124
-
125
- Returns
126
- -------
127
- axes : 2D ndarray of matplotlib_axes or bokeh_figure
128
-
129
- See Also
130
- --------
131
- plot_dist : Plot distribution as histogram or kernel density estimates.
132
- plot_posterior : Plot Posterior densities in the style of John K. Kruschke's book.
133
-
134
- Examples
135
- --------
136
- Plot default density plot
137
-
138
- .. plot::
139
- :context: close-figs
140
-
141
- >>> import arviz as az
142
- >>> centered = az.load_arviz_data('centered_eight')
143
- >>> non_centered = az.load_arviz_data('non_centered_eight')
144
- >>> az.plot_density([centered, non_centered])
145
-
146
- Plot variables in a 4x5 grid
147
-
148
- .. plot::
149
- :context: close-figs
150
-
151
- >>> az.plot_density([centered, non_centered], grid=(4, 5))
152
-
153
- Plot subset variables by specifying variable name exactly
154
-
155
- .. plot::
156
- :context: close-figs
157
-
158
- >>> az.plot_density([centered, non_centered], var_names=["mu"])
159
-
160
- Plot a specific `az.InferenceData` group
161
-
162
- .. plot::
163
- :context: close-figs
164
-
165
- >>> az.plot_density([centered, non_centered], var_names=["mu"], group="prior")
166
-
167
- Specify highest density interval
168
-
169
- .. plot::
170
- :context: close-figs
171
-
172
- >>> az.plot_density([centered, non_centered], var_names=["mu"], hdi_prob=.5)
173
-
174
- Shade plots and/or remove outlines
175
-
176
- .. plot::
177
- :context: close-figs
178
-
179
- >>> az.plot_density([centered, non_centered], var_names=["mu"], outline=False, shade=.8)
180
-
181
- Specify binwidth for kernel density estimation
182
-
183
- .. plot::
184
- :context: close-figs
185
-
186
- >>> az.plot_density([centered, non_centered], var_names=["mu"], bw=.9)
187
- """
188
- if isinstance(data, (list, tuple)):
189
- datasets = [convert_to_dataset(datum, group=group) for datum in data]
190
- else:
191
- datasets = [convert_to_dataset(data, group=group)]
192
-
193
- if transform is not None:
194
- datasets = [transform(dataset) for dataset in datasets]
195
-
196
- if labeller is None:
197
- labeller = BaseLabeller()
198
-
199
- var_names = _var_names(var_names, datasets, filter_vars)
200
-
201
- n_data = len(datasets)
202
-
203
- if data_labels is None:
204
- data_labels = [f"{idx}" for idx in range(n_data)] if n_data > 1 else [""]
205
- elif len(data_labels) != n_data:
206
- raise ValueError(
207
- f"The number of names for the models ({len(data_labels)}) "
208
- f"does not match the number of models ({n_data})"
209
- )
210
-
211
- if hdi_prob is None:
212
- hdi_prob = rcParams["stats.ci_prob"]
213
- elif not 1 >= hdi_prob > 0:
214
- raise ValueError("The value of hdi_prob should be in the interval (0, 1]")
215
-
216
- to_plot = [
217
- list(xarray_var_iter(data, var_names, combined=True, skip_dims=combine_dims))
218
- for data in datasets
219
- ]
220
- all_labels = []
221
- length_plotters = []
222
- for plotters in to_plot:
223
- length_plotters.append(len(plotters))
224
- for var_name, selection, isel, _ in plotters:
225
- label = labeller.make_label_vert(var_name, selection, isel)
226
- if label not in all_labels:
227
- all_labels.append(label)
228
- length_plotters = len(all_labels)
229
- max_plots = rcParams["plot.max_subplots"]
230
- max_plots = length_plotters if max_plots is None else max_plots
231
- if length_plotters > max_plots:
232
- warnings.warn(
233
- "rcParams['plot.max_subplots'] ({max_plots}) is smaller than the number "
234
- "of variables to plot ({len_plotters}) in plot_density, generating only "
235
- "{max_plots} plots".format(max_plots=max_plots, len_plotters=length_plotters),
236
- UserWarning,
237
- )
238
- all_labels = all_labels[:max_plots]
239
- to_plot = [
240
- [
241
- (var_name, selection, values)
242
- for var_name, selection, isel, values in plotters
243
- if labeller.make_label_vert(var_name, selection, isel) in all_labels
244
- ]
245
- for plotters in to_plot
246
- ]
247
- length_plotters = max_plots
248
- rows, cols = default_grid(length_plotters, grid=grid, max_cols=3)
249
-
250
- if bw == "default":
251
- bw = "taylor" if circular else "experimental"
252
-
253
- plot_density_kwargs = dict(
254
- ax=ax,
255
- all_labels=all_labels,
256
- to_plot=to_plot,
257
- colors=colors,
258
- bw=bw,
259
- circular=circular,
260
- figsize=figsize,
261
- length_plotters=length_plotters,
262
- rows=rows,
263
- cols=cols,
264
- textsize=textsize,
265
- labeller=labeller,
266
- hdi_prob=hdi_prob,
267
- point_estimate=point_estimate,
268
- hdi_markers=hdi_markers,
269
- outline=outline,
270
- shade=shade,
271
- n_data=n_data,
272
- data_labels=data_labels,
273
- backend_kwargs=backend_kwargs,
274
- show=show,
275
- )
276
-
277
- if backend is None:
278
- backend = rcParams["plot.backend"]
279
- backend = backend.lower()
280
-
281
- # TODO: Add backend kwargs
282
- plot = get_plotting_function("plot_density", "densityplot", backend)
283
- ax = plot(**plot_density_kwargs)
284
- return ax
@@ -1,197 +0,0 @@
1
- """Density Comparison plot."""
2
-
3
- import warnings
4
- from ..labels import BaseLabeller
5
- from ..rcparams import rcParams
6
- from ..utils import _var_names, get_coords
7
- from .plot_utils import get_plotting_function
8
- from ..sel_utils import xarray_var_iter, xarray_sel_iter
9
-
10
-
11
def plot_dist_comparison(
    data,
    kind="latent",
    figsize=None,
    textsize=None,
    var_names=None,
    coords=None,
    combine_dims=None,
    transform=None,
    legend=True,
    labeller=None,
    ax=None,
    prior_kwargs=None,
    posterior_kwargs=None,
    observed_kwargs=None,
    backend=None,
    backend_kwargs=None,
    show=None,
):
    r"""Plot to compare fitted and unfitted distributions.

    The resulting plots will show the compared distributions both on
    separate axes (particularly useful when one of them is substantially tighter
    than another), and plotted together, displaying a grid of three plots per
    distribution.

    Parameters
    ----------
    data : InferenceData
        Any object that can be converted to an :class:`arviz.InferenceData` object
        containing the posterior/prior data. Refer to documentation of
        :func:`arviz.convert_to_dataset` for details.
    kind : {"latent", "observed"}, default "latent"
        Kind of plot to display. The "latent" option includes {"prior", "posterior"},
        and the "observed" option includes
        {"observed_data", "prior_predictive", "posterior_predictive"}.
        Groups missing from `data` are skipped.
    figsize : (float, float), optional
        Figure size. If ``None`` it will be defined automatically.
    textsize : float, optional
        Text size scaling factor for labels, titles and lines. If ``None`` it will be
        autoscaled based on `figsize`.
    var_names : str, list, list of lists, optional
        if str, plot the variable. if list, plot all the variables in list
        of all groups. if list of lists, plot the vars of groups in respective lists.
        See :ref:`this section <common_var_names>` for usage examples.
    coords : dict, optional
        Dictionary mapping dimensions to selected coordinates to be plotted.
        Dimensions without a mapping specified will include all coordinates for
        that dimension. See :ref:`this section <common_coords>` for usage examples.
    combine_dims : set_like of str, optional
        List of dimensions to reduce. Defaults to reducing only the "chain" and "draw" dimensions.
        See :ref:`this section <common_combine_dims>` for usage examples.
    transform : callable, optional
        Function to transform data (defaults to `None` i.e. the identity function).
    legend : bool
        Add legend to figure. By default True.
    labeller : Labeller, optional
        Class providing the method ``make_pp_label`` to generate the labels in the plot titles.
        Read the :ref:`label_guide` for more details and usage examples.
    ax : (nvars, 3) array-like of matplotlib_axes, optional
        Matplotlib axes: The ax argument should have shape (nvars, 3), where the
        last column is for the combined before/after plots and columns 0 and 1 are
        for the before and after plots, respectively.
    prior_kwargs : dict, optional
        Additional keywords passed to :func:`arviz.plot_dist` for prior/predictive groups.
    posterior_kwargs : dict, optional
        Additional keywords passed to :func:`arviz.plot_dist` for posterior/predictive groups.
    observed_kwargs : dict, optional
        Additional keywords passed to :func:`arviz.plot_dist` for observed_data group.
    backend : {"matplotlib", "bokeh"}, default "matplotlib"
        Select plotting backend.
    backend_kwargs : dict, optional
        These are kwargs specific to the backend being used, passed to
        :func:`matplotlib.pyplot.subplots` or :class:`bokeh.plotting.figure`.
        For additional documentation check the plotting method of the backend.
    show : bool, optional
        Call backend show function.

    Returns
    -------
    axes : 2D ndarray of matplotlib_axes
        Returned object will have shape (nvars, 3),
        where the last column is the combined plot and the first columns are the single plots.

    Raises
    ------
    ValueError
        If `data` contains none of the groups required by `kind`.

    Warns
    -----
    UserWarning
        If there are more subplots than ``rcParams["plot.max_subplots"]``; only the first ones
        are drawn.

    See Also
    --------
    plot_bpv : Plot Bayesian p-value for observed data and Posterior/Prior predictive.

    Examples
    --------
    Plot the prior/posterior plot for specified vars and coords.

    .. plot::
        :context: close-figs

        >>> import arviz as az
        >>> data = az.load_arviz_data('rugby')
        >>> az.plot_dist_comparison(data, var_names=["defs"], coords={"team" : ["Italy"]})

    """
    all_groups = ["prior", "posterior"]

    if kind == "observed":
        all_groups = ["observed_data", "prior_predictive", "posterior_predictive"]

    if coords is None:
        coords = {}

    if labeller is None:
        labeller = BaseLabeller()

    datasets = []
    groups = []
    for group in all_groups:
        # Missing groups are skipped on purpose. Only a missing attribute is tolerated,
        # so unrelated errors are no longer silently swallowed.
        try:
            dataset = getattr(data, group)
        except AttributeError:
            continue
        datasets.append(dataset)
        groups.append(group)

    if not datasets:
        raise ValueError(
            f"data does not contain any of the groups {all_groups} needed for kind={kind!r}"
        )

    if var_names is None:
        var_names = list(datasets[0].data_vars)

    if isinstance(var_names, str):
        var_names = [var_names]

    # A flat list of names applies to every group; a list of lists is per group.
    if isinstance(var_names[0], str):
        var_names = [var_names for _ in datasets]

    var_names = [
        _var_names(group_var_names, dataset)
        for group_var_names, dataset in zip(var_names, datasets)
    ]

    if transform is not None:
        datasets = [transform(dataset) for dataset in datasets]

    datasets = get_coords(datasets, coords)

    # Every variable gets one subplot per group plus one combined subplot.
    plots_per_var = len(groups) + 1
    max_subplots = rcParams["plot.max_subplots"]
    # ``plot.max_subplots`` may be None, meaning "no limit"; slicing with None keeps all.
    len_plots = None if max_subplots is None else (max_subplots // plots_per_var or 1)
    dc_plotters = [
        list(
            xarray_var_iter(
                dataset, var_names=group_var_names, combined=True, skip_dims=combine_dims
            )
        )[:len_plots]
        for dataset, group_var_names in zip(datasets, var_names)
    ]

    # Count with the same skip_dims as dc_plotters so the warning reflects the
    # subplots that would actually be drawn without truncation.
    total_plots = (
        sum(
            1
            for _ in xarray_sel_iter(
                datasets[0], var_names=var_names[0], combined=True, skip_dims=combine_dims
            )
        )
        * plots_per_var
    )
    maxplots = len(dc_plotters[0]) * plots_per_var

    if max_subplots is not None and total_plots > max_subplots:
        warnings.warn(
            "rcParams['plot.max_subplots'] ({rcParam}) is smaller than the number "
            "of subplots to plot ({len_plotters}), generating only {max_plots} "
            "plots".format(rcParam=max_subplots, len_plotters=total_plots, max_plots=maxplots),
            UserWarning,
        )

    nvars = len(dc_plotters[0])
    ngroups = len(groups)

    distcomparisonplot_kwargs = dict(
        ax=ax,
        nvars=nvars,
        ngroups=ngroups,
        figsize=figsize,
        dc_plotters=dc_plotters,
        legend=legend,
        groups=groups,
        textsize=textsize,
        labeller=labeller,
        prior_kwargs=prior_kwargs,
        posterior_kwargs=posterior_kwargs,
        observed_kwargs=observed_kwargs,
        backend_kwargs=backend_kwargs,
        show=show,
    )

    if backend is None:
        backend = rcParams["plot.backend"]
    backend = backend.lower()

    # TODO: Add backend kwargs
    plot = get_plotting_function("plot_dist_comparison", "distcomparisonplot", backend)
    axes = plot(**distcomparisonplot_kwargs)
    return axes