arviz 0.23.3__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185) hide show
  1. arviz/__init__.py +52 -367
  2. arviz-1.0.0rc0.dist-info/METADATA +182 -0
  3. arviz-1.0.0rc0.dist-info/RECORD +5 -0
  4. {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
  5. {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
  6. arviz/data/__init__.py +0 -55
  7. arviz/data/base.py +0 -596
  8. arviz/data/converters.py +0 -203
  9. arviz/data/datasets.py +0 -161
  10. arviz/data/example_data/code/radon/radon.json +0 -326
  11. arviz/data/example_data/data/centered_eight.nc +0 -0
  12. arviz/data/example_data/data/non_centered_eight.nc +0 -0
  13. arviz/data/example_data/data_local.json +0 -12
  14. arviz/data/example_data/data_remote.json +0 -58
  15. arviz/data/inference_data.py +0 -2386
  16. arviz/data/io_beanmachine.py +0 -112
  17. arviz/data/io_cmdstan.py +0 -1036
  18. arviz/data/io_cmdstanpy.py +0 -1233
  19. arviz/data/io_datatree.py +0 -23
  20. arviz/data/io_dict.py +0 -462
  21. arviz/data/io_emcee.py +0 -317
  22. arviz/data/io_json.py +0 -54
  23. arviz/data/io_netcdf.py +0 -68
  24. arviz/data/io_numpyro.py +0 -497
  25. arviz/data/io_pyjags.py +0 -378
  26. arviz/data/io_pyro.py +0 -333
  27. arviz/data/io_pystan.py +0 -1095
  28. arviz/data/io_zarr.py +0 -46
  29. arviz/data/utils.py +0 -139
  30. arviz/labels.py +0 -210
  31. arviz/plots/__init__.py +0 -61
  32. arviz/plots/autocorrplot.py +0 -171
  33. arviz/plots/backends/__init__.py +0 -223
  34. arviz/plots/backends/bokeh/__init__.py +0 -166
  35. arviz/plots/backends/bokeh/autocorrplot.py +0 -101
  36. arviz/plots/backends/bokeh/bfplot.py +0 -23
  37. arviz/plots/backends/bokeh/bpvplot.py +0 -193
  38. arviz/plots/backends/bokeh/compareplot.py +0 -167
  39. arviz/plots/backends/bokeh/densityplot.py +0 -239
  40. arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
  41. arviz/plots/backends/bokeh/distplot.py +0 -183
  42. arviz/plots/backends/bokeh/dotplot.py +0 -113
  43. arviz/plots/backends/bokeh/ecdfplot.py +0 -73
  44. arviz/plots/backends/bokeh/elpdplot.py +0 -203
  45. arviz/plots/backends/bokeh/energyplot.py +0 -155
  46. arviz/plots/backends/bokeh/essplot.py +0 -176
  47. arviz/plots/backends/bokeh/forestplot.py +0 -772
  48. arviz/plots/backends/bokeh/hdiplot.py +0 -54
  49. arviz/plots/backends/bokeh/kdeplot.py +0 -268
  50. arviz/plots/backends/bokeh/khatplot.py +0 -163
  51. arviz/plots/backends/bokeh/lmplot.py +0 -185
  52. arviz/plots/backends/bokeh/loopitplot.py +0 -211
  53. arviz/plots/backends/bokeh/mcseplot.py +0 -184
  54. arviz/plots/backends/bokeh/pairplot.py +0 -328
  55. arviz/plots/backends/bokeh/parallelplot.py +0 -81
  56. arviz/plots/backends/bokeh/posteriorplot.py +0 -324
  57. arviz/plots/backends/bokeh/ppcplot.py +0 -379
  58. arviz/plots/backends/bokeh/rankplot.py +0 -149
  59. arviz/plots/backends/bokeh/separationplot.py +0 -107
  60. arviz/plots/backends/bokeh/traceplot.py +0 -436
  61. arviz/plots/backends/bokeh/violinplot.py +0 -164
  62. arviz/plots/backends/matplotlib/__init__.py +0 -124
  63. arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
  64. arviz/plots/backends/matplotlib/bfplot.py +0 -78
  65. arviz/plots/backends/matplotlib/bpvplot.py +0 -177
  66. arviz/plots/backends/matplotlib/compareplot.py +0 -135
  67. arviz/plots/backends/matplotlib/densityplot.py +0 -194
  68. arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
  69. arviz/plots/backends/matplotlib/distplot.py +0 -178
  70. arviz/plots/backends/matplotlib/dotplot.py +0 -116
  71. arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
  72. arviz/plots/backends/matplotlib/elpdplot.py +0 -189
  73. arviz/plots/backends/matplotlib/energyplot.py +0 -113
  74. arviz/plots/backends/matplotlib/essplot.py +0 -180
  75. arviz/plots/backends/matplotlib/forestplot.py +0 -656
  76. arviz/plots/backends/matplotlib/hdiplot.py +0 -48
  77. arviz/plots/backends/matplotlib/kdeplot.py +0 -177
  78. arviz/plots/backends/matplotlib/khatplot.py +0 -241
  79. arviz/plots/backends/matplotlib/lmplot.py +0 -149
  80. arviz/plots/backends/matplotlib/loopitplot.py +0 -144
  81. arviz/plots/backends/matplotlib/mcseplot.py +0 -161
  82. arviz/plots/backends/matplotlib/pairplot.py +0 -355
  83. arviz/plots/backends/matplotlib/parallelplot.py +0 -58
  84. arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
  85. arviz/plots/backends/matplotlib/ppcplot.py +0 -478
  86. arviz/plots/backends/matplotlib/rankplot.py +0 -119
  87. arviz/plots/backends/matplotlib/separationplot.py +0 -97
  88. arviz/plots/backends/matplotlib/traceplot.py +0 -526
  89. arviz/plots/backends/matplotlib/tsplot.py +0 -121
  90. arviz/plots/backends/matplotlib/violinplot.py +0 -148
  91. arviz/plots/bfplot.py +0 -128
  92. arviz/plots/bpvplot.py +0 -308
  93. arviz/plots/compareplot.py +0 -177
  94. arviz/plots/densityplot.py +0 -284
  95. arviz/plots/distcomparisonplot.py +0 -197
  96. arviz/plots/distplot.py +0 -233
  97. arviz/plots/dotplot.py +0 -233
  98. arviz/plots/ecdfplot.py +0 -372
  99. arviz/plots/elpdplot.py +0 -174
  100. arviz/plots/energyplot.py +0 -147
  101. arviz/plots/essplot.py +0 -319
  102. arviz/plots/forestplot.py +0 -304
  103. arviz/plots/hdiplot.py +0 -211
  104. arviz/plots/kdeplot.py +0 -357
  105. arviz/plots/khatplot.py +0 -236
  106. arviz/plots/lmplot.py +0 -380
  107. arviz/plots/loopitplot.py +0 -224
  108. arviz/plots/mcseplot.py +0 -194
  109. arviz/plots/pairplot.py +0 -281
  110. arviz/plots/parallelplot.py +0 -204
  111. arviz/plots/plot_utils.py +0 -599
  112. arviz/plots/posteriorplot.py +0 -298
  113. arviz/plots/ppcplot.py +0 -369
  114. arviz/plots/rankplot.py +0 -232
  115. arviz/plots/separationplot.py +0 -167
  116. arviz/plots/styles/arviz-bluish.mplstyle +0 -1
  117. arviz/plots/styles/arviz-brownish.mplstyle +0 -1
  118. arviz/plots/styles/arviz-colors.mplstyle +0 -2
  119. arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
  120. arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
  121. arviz/plots/styles/arviz-doc.mplstyle +0 -88
  122. arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
  123. arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
  124. arviz/plots/styles/arviz-greenish.mplstyle +0 -1
  125. arviz/plots/styles/arviz-orangish.mplstyle +0 -1
  126. arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
  127. arviz/plots/styles/arviz-purplish.mplstyle +0 -1
  128. arviz/plots/styles/arviz-redish.mplstyle +0 -1
  129. arviz/plots/styles/arviz-royish.mplstyle +0 -1
  130. arviz/plots/styles/arviz-viridish.mplstyle +0 -1
  131. arviz/plots/styles/arviz-white.mplstyle +0 -40
  132. arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
  133. arviz/plots/traceplot.py +0 -273
  134. arviz/plots/tsplot.py +0 -440
  135. arviz/plots/violinplot.py +0 -192
  136. arviz/preview.py +0 -58
  137. arviz/py.typed +0 -0
  138. arviz/rcparams.py +0 -606
  139. arviz/sel_utils.py +0 -223
  140. arviz/static/css/style.css +0 -340
  141. arviz/static/html/icons-svg-inline.html +0 -15
  142. arviz/stats/__init__.py +0 -37
  143. arviz/stats/density_utils.py +0 -1013
  144. arviz/stats/diagnostics.py +0 -1013
  145. arviz/stats/ecdf_utils.py +0 -324
  146. arviz/stats/stats.py +0 -2422
  147. arviz/stats/stats_refitting.py +0 -119
  148. arviz/stats/stats_utils.py +0 -609
  149. arviz/tests/__init__.py +0 -1
  150. arviz/tests/base_tests/__init__.py +0 -1
  151. arviz/tests/base_tests/test_data.py +0 -1679
  152. arviz/tests/base_tests/test_data_zarr.py +0 -143
  153. arviz/tests/base_tests/test_diagnostics.py +0 -511
  154. arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
  155. arviz/tests/base_tests/test_helpers.py +0 -18
  156. arviz/tests/base_tests/test_labels.py +0 -69
  157. arviz/tests/base_tests/test_plot_utils.py +0 -342
  158. arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
  159. arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
  160. arviz/tests/base_tests/test_rcparams.py +0 -317
  161. arviz/tests/base_tests/test_stats.py +0 -925
  162. arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
  163. arviz/tests/base_tests/test_stats_numba.py +0 -45
  164. arviz/tests/base_tests/test_stats_utils.py +0 -384
  165. arviz/tests/base_tests/test_utils.py +0 -376
  166. arviz/tests/base_tests/test_utils_numba.py +0 -87
  167. arviz/tests/conftest.py +0 -46
  168. arviz/tests/external_tests/__init__.py +0 -1
  169. arviz/tests/external_tests/test_data_beanmachine.py +0 -78
  170. arviz/tests/external_tests/test_data_cmdstan.py +0 -398
  171. arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
  172. arviz/tests/external_tests/test_data_emcee.py +0 -166
  173. arviz/tests/external_tests/test_data_numpyro.py +0 -434
  174. arviz/tests/external_tests/test_data_pyjags.py +0 -119
  175. arviz/tests/external_tests/test_data_pyro.py +0 -260
  176. arviz/tests/external_tests/test_data_pystan.py +0 -307
  177. arviz/tests/helpers.py +0 -677
  178. arviz/utils.py +0 -773
  179. arviz/wrappers/__init__.py +0 -13
  180. arviz/wrappers/base.py +0 -236
  181. arviz/wrappers/wrap_pymc.py +0 -36
  182. arviz/wrappers/wrap_stan.py +0 -148
  183. arviz-0.23.3.dist-info/METADATA +0 -264
  184. arviz-0.23.3.dist-info/RECORD +0 -183
  185. arviz-0.23.3.dist-info/top_level.txt +0 -1
arviz/plots/khatplot.py DELETED
@@ -1,236 +0,0 @@
1
- """Pareto tail indices plot."""
2
-
3
- import logging
4
- import warnings
5
-
6
- import numpy as np
7
- from xarray import DataArray
8
-
9
- from ..rcparams import rcParams
10
- from ..stats import ELPDData
11
- from ..utils import get_coords
12
- from .plot_utils import format_coords_as_labels, get_plotting_function
13
-
14
- _log = logging.getLogger(__name__)
15
-
16
-
17
def plot_khat(
    khats,
    color="C0",
    xlabels=False,
    show_hlines=False,
    show_bins=False,
    bin_format="{1:.1f}%",
    annotate=False,
    threshold=None,
    hover_label=False,
    hover_format="{1}",
    figsize=None,
    textsize=None,
    coords=None,
    legend=False,
    markersize=None,
    ax=None,
    hlines_kwargs=None,
    backend=None,
    backend_kwargs=None,
    show=None,
    **kwargs
):
    r"""Plot Pareto tail indices :math:`\hat{k}` for diagnosing convergence in PSIS-LOO.

    Parameters
    ----------
    khats : ELPDData
        The input Pareto tail indices to be plotted. A ``DataArray`` or a NumPy array is
        still accepted but emits a ``FutureWarning``; arrays are flattened and plotted
        without coordinate labels, legend or ``coords`` selection.
    color : str or array_like, default "C0"
        Colors of the scatter plot, if color is a str all dots will have the same color,
        if it is the size of the observations, each dot will have the specified color,
        otherwise, it will be interpreted as a list of the dims to be used for the color
        code. If Matplotlib c argument is passed, it will override the color argument.
        ``None`` is treated as ``"C0"``.
    xlabels : bool, default False
        Use coords as xticklabels.
    show_hlines : bool, default False
        Show the horizontal lines, by default at the values [0, 0.5, 0.7, 1].
    show_bins : bool, default False
        Show the percentage of khats falling in each bin, as delimited by hlines.
    bin_format : str, default "{1:.1f}%"
        The string is used as formatting guide calling ``bin_format.format(count, pct)``.
    annotate : bool or float, default False
        Deprecated alias of `threshold`. When truthy, a warning is logged and its value
        replaces `threshold`.
    threshold : float, optional
        Show the labels of k values larger than `threshold`. If ``None`` (default), no
        observations will be highlighted.
    hover_label : bool, default False
        Show the datapoint label when hovering over it with the mouse. Requires an interactive
        backend.
    hover_format : str, default "{1}"
        String used to format the hover label via ``hover_format.format(idx, coord_label)``
    figsize : (float, float), optional
        Figure size. If ``None`` it will be defined automatically.
    textsize : float, optional
        Text size scaling factor for labels, titles and lines. If ``None`` it will be autoscaled
        based on `figsize`.
    coords : mapping, optional
        Coordinates of points to plot. **All** values are used for computation, but only
        a subset can be plotted for convenience. See :ref:`this section <common_coords>` for
        usage examples.
    legend : bool, default False
        Include a legend to the plot. Only taken into account when color argument is a dim name.
    markersize : int, optional
        markersize for scatter plot. Defaults to ``None`` in which case it will
        be chosen based on autoscaling for figsize.
    ax : axes, optional
        Matplotlib axes or bokeh figures.
    hlines_kwargs : dict, optional
        Additional keywords passed to
        :meth:`matplotlib.axes.Axes.hlines`.
    backend : {"matplotlib", "bokeh"}, default "matplotlib"
        Select plotting backend. If ``None``, ``rcParams["plot.backend"]`` is used.
    backend_kwargs : dict, optional
        These are kwargs specific to the backend being used, passed to
        :func:`matplotlib.pyplot.subplots` or :class:`bokeh.plotting.figure`.
        For additional documentation check the plotting method of the backend.
    show : bool, optional
        Call backend show function.
    kwargs :
        Additional keywords passed to
        :meth:`matplotlib.axes.Axes.scatter`.

    Returns
    -------
    axes : matplotlib_axes or bokeh_figures

    Raises
    ------
    ValueError
        If `khats` is not an ``ELPDData``, a ``DataArray`` or a NumPy array, or if the
        ``pareto_k`` attribute of the ``ELPDData`` is not a ``DataArray``.

    See Also
    --------
    psislw : Pareto smoothed importance sampling (PSIS).

    Examples
    --------
    Plot estimated pareto shape parameters showing how many fall in each category.

    .. plot::
        :context: close-figs

        >>> import arviz as az
        >>> radon = az.load_arviz_data("radon")
        >>> loo_radon = az.loo(radon, pointwise=True)
        >>> az.plot_khat(loo_radon, show_bins=True)

    Show xlabels

    .. plot::
        :context: close-figs

        >>> centered_eight = az.load_arviz_data("centered_eight")
        >>> khats = az.loo(centered_eight, pointwise=True).pareto_k
        >>> az.plot_khat(khats, xlabels=True, threshold=1)

    Use custom color scheme

    .. plot::
        :context: close-figs

        >>> counties = radon.posterior.County[radon.constant_data.county_idx].values
        >>> colors = [
        ...    "blue" if county[-1] in ("A", "N") else "green" for county in counties
        ... ]
        >>> az.plot_khat(loo_radon, color=colors)

    Notes
    -----
    The Generalized Pareto distribution (GPD) diagnoses convergence rates for importance
    sampling. GPD has parameters offset, scale, and shape. The shape parameter (:math:`k`)
    tells the distribution's number of finite moments. The pre-asymptotic convergence rate
    of importance sampling can be estimated based on the fractional number of finite moments
    of the importance ratio distribution. GPD is fitted to the largest importance ratios and
    interprets the estimated shape parameter :math:`k`, i.e., :math:`\hat{k}` can then be
    used as a diagnostic (most importantly if :math:`\hat{k} > 0.7`, then the convergence
    rate is impractically low). See [1]_.

    References
    ----------
    .. [1] Vehtari, A., Simpson, D., Gelman, A., Yao, Y., Gabry, J. (2024).
        Pareto Smoothed Importance Sampling. Journal of Machine Learning
        Research, 25(72):1-58.

    """
    if annotate:
        _log.warning("annotate will be deprecated, please use threshold instead")
        threshold = annotate

    if coords is None:
        coords = {}

    if color is None:
        color = "C0"

    if isinstance(khats, np.ndarray):
        # Adjacent literals are concatenated verbatim, so each fragment carries its
        # own trailing space to keep the sentences separated in the final message.
        warnings.warn(
            "support for arrays will be deprecated, please use ELPDData. "
            "The reason for this, is that we need to know the numbers of draws "
            "sampled from the posterior",
            FutureWarning,
        )
        # A bare array carries no coordinates: plot it flat and unlabeled.
        khats = khats.flatten()
        xlabels = False
        legend = False
        dims = []
        good_k = None
    else:
        if isinstance(khats, ELPDData):
            good_k = khats.good_k
            khats = khats.pareto_k
        elif isinstance(khats, DataArray):
            good_k = None
            warnings.warn(
                "support for DataArrays will be deprecated, please use ELPDData. "
                "The reason for this, is that we need to know the numbers of draws "
                "sampled from the posterior",
                FutureWarning,
            )
        # Reject unsupported inputs without first emitting a deprecation warning
        # that does not apply to them.
        if not isinstance(khats, DataArray):
            raise ValueError("Incorrect khat data input. Check the documentation")

        khats = get_coords(khats, coords)
        dims = khats.dims

    n_data_points = khats.size
    xdata = np.arange(n_data_points)
    if isinstance(khats, DataArray):
        coord_labels = format_coords_as_labels(khats)
    else:
        # Arrays have no coords, so the positional index doubles as the label.
        coord_labels = xdata.astype(str)

    plot_khat_kwargs = dict(
        hover_label=hover_label,
        hover_format=hover_format,
        ax=ax,
        figsize=figsize,
        xdata=xdata,
        khats=khats,
        good_k=good_k,
        kwargs=kwargs,
        threshold=threshold,
        coord_labels=coord_labels,
        show_hlines=show_hlines,
        show_bins=show_bins,
        hlines_kwargs=hlines_kwargs,
        xlabels=xlabels,
        legend=legend,
        color=color,
        dims=dims,
        textsize=textsize,
        markersize=markersize,
        n_data_points=n_data_points,
        bin_format=bin_format,
        backend_kwargs=backend_kwargs,
        show=show,
    )

    if backend is None:
        backend = rcParams["plot.backend"]
    backend = backend.lower()

    # TODO: Add backend kwargs
    plot = get_plotting_function("plot_khat", "khatplot", backend)
    axes = plot(**plot_khat_kwargs)
    return axes
arviz/plots/lmplot.py DELETED
@@ -1,380 +0,0 @@
1
- """Plot regression figure."""
2
-
3
- import warnings
4
- from numbers import Integral
5
- from itertools import repeat
6
-
7
- import xarray as xr
8
- import numpy as np
9
- from xarray.core.dataarray import DataArray
10
-
11
- from ..sel_utils import xarray_var_iter
12
- from ..rcparams import rcParams
13
- from .plot_utils import default_grid, filter_plotters_list, get_plotting_function
14
-
15
-
16
- def _repeat_flatten_list(lst, n):
17
- return [item for sublist in repeat(lst, n) for item in sublist]
18
-
19
-
20
def plot_lm(
    y,
    idata=None,
    x=None,
    y_model=None,
    y_hat=None,
    num_samples=50,
    kind_pp="samples",
    kind_model="lines",
    xjitter=False,
    plot_dim=None,
    backend=None,
    y_kwargs=None,
    y_hat_plot_kwargs=None,
    y_hat_fill_kwargs=None,
    y_model_plot_kwargs=None,
    y_model_fill_kwargs=None,
    y_model_mean_kwargs=None,
    backend_kwargs=None,
    show=None,
    figsize=None,
    textsize=None,
    axes=None,
    legend=True,
    grid=True,
):
    """Posterior predictive and mean plots for regression-like data.

    Parameters
    ----------
    y : str or DataArray or ndarray
        If str, variable name from ``observed_data``.
    idata : InferenceData, Optional
        Optional only if ``y`` is not str.
    x : str, tuple of strings, DataArray or array-like, optional
        If str or tuple, variable name from ``constant_data``.
        If ndarray, could be 1D, or 2D for multiple plots.
        If None, coords name of ``y`` (``y`` should be DataArray).
    y_model : str or Sequence, Optional
        If str, variable name from ``posterior``.
        Its dimensions should be same as ``y`` plus added chains and draws.
    y_hat : str, Optional
        If str, variable name from ``posterior_predictive``.
        Its dimensions should be same as ``y`` plus added chains and draws.
        If None and ``y`` is a str, the name given in ``y`` is used.
    num_samples : int, Optional, Default 50
        Significant if ``kind_pp`` is "samples" or ``kind_model`` is "lines".
        Number of samples to be drawn from posterior predictive or posterior.
        Must lie between 1 and the total number of available samples (chains times draws).
    kind_pp : {"samples", "hdi"}, Default "samples"
        Options to visualize uncertainty in data.
    kind_model : {"lines", "hdi"}, Default "lines"
        Options to visualize uncertainty in mean of the data.
    xjitter : bool, default False
        Forwarded unchanged to the backend plotting function.
    plot_dim : str, Optional
        Necessary if ``y`` is multidimensional.
    backend : str, Optional
        Select plotting backend {"matplotlib","bokeh"}. Default "matplotlib".
    y_kwargs : dict, optional
        Passed to :meth:`mpl:matplotlib.axes.Axes.plot` in matplotlib
        and :meth:`bokeh:bokeh.plotting.Figure.circle` in bokeh
    y_hat_plot_kwargs : dict, optional
        Passed to :meth:`mpl:matplotlib.axes.Axes.plot` in matplotlib
        and :meth:`bokeh:bokeh.plotting.Figure.circle` in bokeh
    y_hat_fill_kwargs : dict, optional
        Passed to :func:`arviz.plot_hdi`
    y_model_plot_kwargs : dict, optional
        Passed to :meth:`mpl:matplotlib.axes.Axes.plot` in matplotlib
        and :meth:`bokeh:bokeh.plotting.Figure.line` in bokeh
    y_model_fill_kwargs : dict, optional
        Significant if ``kind_model`` is "hdi". Passed to :func:`arviz.plot_hdi`
    y_model_mean_kwargs : dict, optional
        Passed to :meth:`mpl:matplotlib.axes.Axes.plot` in matplotlib
        and :meth:`bokeh:bokeh.plotting.Figure.line` in bokeh
    backend_kwargs : dict, optional
        These are kwargs specific to the backend being used. Passed to
        :func:`matplotlib.pyplot.subplots` or
        :func:`bokeh.plotting.figure`.
    figsize : (float, float), optional
        Figure size. If None it will be defined automatically.
    textsize : float, optional
        Text size scaling factor for labels, titles and lines. If None it will be
        autoscaled based on ``figsize``.
    axes : 2D numpy array-like of matplotlib_axes or bokeh_figures, optional
        A 2D array of locations into which to plot the densities. If not supplied, Arviz will create
        its own array of plot areas (and return it).
    show : bool, optional
        Call backend show function.
    legend : bool, optional
        Add legend to figure. By default True.
    grid : bool, optional
        Add grid to figure. By default True.

    Returns
    -------
    axes: matplotlib axes or bokeh figures

    Raises
    ------
    ValueError
        If ``kind_pp`` or ``kind_model`` is not one of the supported options, or if ``y``
        is multidimensional and ``plot_dim`` is not given.
    TypeError
        If ``num_samples`` is not an integer between 1 and the number of available samples,
        or if ``plot_dim`` is neither None, a str nor a tuple.

    See Also
    --------
    plot_ts : Plot timeseries data
    plot_ppc : Plot for posterior/prior predictive checks

    Examples
    --------
    Plot regression default plot

    .. plot::
        :context: close-figs

        >>> import arviz as az
        >>> import numpy as np
        >>> import xarray as xr
        >>> idata = az.load_arviz_data('regression1d')
        >>> x = xr.DataArray(np.linspace(0, 1, 100))
        >>> idata.posterior["y_model"] = idata.posterior["intercept"] + idata.posterior["slope"]*x
        >>> az.plot_lm(idata=idata, y="y", x=x)

    Plot regression data and mean uncertainty

    .. plot::
        :context: close-figs

        >>> az.plot_lm(idata=idata, y="y", x=x, y_model="y_model")

    Plot regression data and mean uncertainty in hdi form

    .. plot::
        :context: close-figs

        >>> az.plot_lm(
        ...     idata=idata, y="y", x=x, y_model="y_model", kind_pp="hdi", kind_model="hdi"
        ... )

    Plot regression data for multi-dimensional y using plot_dim

    .. plot::
        :context: close-figs

        >>> data = az.from_dict(
        ...     observed_data = { "y": np.random.normal(size=(5, 7)) },
        ...     posterior_predictive = {"y": np.random.randn(4, 1000, 5, 7) / 2},
        ...     dims={"y": ["dim1", "dim2"]},
        ...     coords={"dim1": range(5), "dim2": range(7)}
        ... )
        >>> az.plot_lm(idata=data, y="y", plot_dim="dim1")
    """
    if kind_pp not in ("samples", "hdi"):
        raise ValueError("kind_pp should be either samples or hdi")

    if kind_model not in ("lines", "hdi"):
        raise ValueError("kind_model should be either lines or hdi")

    # When y is given by name, default y_hat to the same variable name.
    if y_hat is None and isinstance(y, str):
        y_hat = y

    if isinstance(y, str):
        y = idata.observed_data[y]
    elif not isinstance(y, DataArray):
        y = xr.DataArray(y)

    if len(y.dims) > 1 and plot_dim is None:
        raise ValueError("Argument plot_dim is needed in case of multidimensional data")

    # Resolve x into an xarray object and record which of its dims are plotted along
    # (and therefore must not be iterated over when building plotters).
    x_var_names = None
    if isinstance(x, str):
        x = idata.constant_data[x]
        x_skip_dims = x.dims
    elif isinstance(x, tuple):
        x_var_names = x
        x = idata.constant_data
        x_skip_dims = x.dims
    elif isinstance(x, DataArray):
        x_skip_dims = x.dims
    elif x is None:
        x = y.coords[y.dims[0]] if plot_dim is None else y.coords[plot_dim]
        x_skip_dims = x.dims
    else:
        x = xr.DataArray(x)
        # Only the last dim is skipped, so a 2D array yields one plotter per row.
        x_skip_dims = [x.dims[-1]]

    # If posterior is present in idata and y_model is there, get its values
    if isinstance(y_model, str):
        if "posterior" not in idata.groups():
            warnings.warn("Posterior not found in idata", UserWarning)
            y_model = None
        elif hasattr(idata.posterior, y_model):
            y_model = idata.posterior[y_model]
        else:
            warnings.warn("y_model not found in posterior", UserWarning)
            y_model = None

    # If posterior_predictive is present in idata and y_hat is there, get its values
    if isinstance(y_hat, str):
        if "posterior_predictive" not in idata.groups():
            warnings.warn("posterior_predictive not found in idata", UserWarning)
            y_hat = None
        elif hasattr(idata.posterior_predictive, y_hat):
            y_hat = idata.posterior_predictive[y_hat]
        else:
            warnings.warn("y_hat not found in posterior_predictive", UserWarning)
            y_hat = None

    # Check if num_samples is valid and generate num_samples random sample indexes.
    # Only needed if kind_pp="samples" or kind_model="lines". Not req for plotting hdi
    pp_sample_ix = None
    if (y_hat is not None and kind_pp == "samples") or (
        y_model is not None and kind_model == "lines"
    ):
        if y_hat is not None:
            total_pp_samples = y_hat.sizes["chain"] * y_hat.sizes["draw"]
        else:
            total_pp_samples = y_model.sizes["chain"] * y_model.sizes["draw"]

        if (
            not isinstance(num_samples, Integral)
            or num_samples < 1
            or num_samples > total_pp_samples
        ):
            raise TypeError(f"`num_samples` must be an integer between 1 and {total_pp_samples}.")

        pp_sample_ix = np.random.choice(total_pp_samples, size=num_samples, replace=False)

    # crucial step in case of multidim y
    if plot_dim is None:
        skip_dims = list(y.dims)
    elif isinstance(plot_dim, str):
        skip_dims = [plot_dim]
    elif isinstance(plot_dim, tuple):
        skip_dims = list(plot_dim)
    else:
        # Fail here with a clear message instead of an unbound-variable error later on.
        raise TypeError("`plot_dim` must be None, a str or a tuple of str")

    # Generate x axis plotters.
    x = filter_plotters_list(
        plotters=list(
            xarray_var_iter(
                x,
                var_names=x_var_names,
                skip_dims=set(x_skip_dims),
                combined=True,
            )
        ),
        plot_kind="plot_lm",
    )

    # Generate y axis plotters
    y = filter_plotters_list(
        plotters=list(
            xarray_var_iter(
                y,
                skip_dims=set(skip_dims),
                combined=True,
            )
        ),
        plot_kind="plot_lm",
    )

    # If there are multiple x and multidimensional y, we need total of len(x)*len(y) graphs
    len_y = len(y)
    len_x = len(x)
    length_plotters = len_x * len_y
    y = _repeat_flatten_list(y, len_x)
    x = _repeat_flatten_list(x, len_y)

    # Filter out the required values to generate plotters
    if y_hat is not None:
        if kind_pp == "samples":
            # Collapse chain/draw into one sample dim and keep only the drawn indexes.
            y_hat = y_hat.stack(__sample__=("chain", "draw"))[..., pp_sample_ix]
            skip_dims += ["__sample__"]

        # Keep at most len_y plotters so y_hat lines up with the y plotters.
        y_hat = [
            tup
            for _, tup in zip(
                range(len_y),
                xarray_var_iter(
                    y_hat,
                    skip_dims=set(skip_dims),
                    combined=True,
                ),
            )
        ]

        y_hat = _repeat_flatten_list(y_hat, len_x)

    # Filter out the required values to generate plotters
    if y_model is not None:
        if kind_model == "lines":
            var_name = y_model.name if y_model.name else "y_model"
            data = y_model.values

            # Merge the two leading axes (chain and draw) into a single sample axis.
            total_samples = data.shape[0] * data.shape[1]
            data = data.reshape(total_samples, *data.shape[2:])

            if pp_sample_ix is not None:
                data = data[pp_sample_ix]

            if plot_dim is not None:
                # For plot_dim case, transpose to get dimension first
                data = data.transpose(1, 0, 2)[..., 0]

            # Create plotter tuple(s)
            if plot_dim is not None:
                y_model = [(var_name, {}, {}, data) for _ in range(length_plotters)]
            else:
                y_model = [(var_name, {}, {}, data)]
                y_model = _repeat_flatten_list(y_model, len_x)

        elif kind_model == "hdi":
            var_name = y_model.name if y_model.name else "y_model"
            data = y_model.values

            if plot_dim is not None:
                # First transpose to get plot_dim first
                data = data.transpose(2, 0, 1, 3)
                # For plot_dim case, we just want HDI for first dimension
                data = data[..., 0]

                # Reshape to (samples, points)
                data = data.transpose(1, 2, 0).reshape(-1, data.shape[0])
                y_model = [(var_name, {}, {}, data) for _ in range(length_plotters)]

            else:
                data = data.reshape(-1, data.shape[-1])
                y_model = [(var_name, {}, {}, data)]
                y_model = _repeat_flatten_list(y_model, len_x)

        # A single-entry list is repeated once per x plotter.
        if len(y_model) == 1:
            y_model = _repeat_flatten_list(y_model, len_x)

    rows, cols = default_grid(length_plotters)

    lmplot_kwargs = dict(
        x=x,
        y=y,
        y_model=y_model,
        y_hat=y_hat,
        num_samples=num_samples,
        kind_pp=kind_pp,
        kind_model=kind_model,
        length_plotters=length_plotters,
        xjitter=xjitter,
        rows=rows,
        cols=cols,
        y_kwargs=y_kwargs,
        y_hat_plot_kwargs=y_hat_plot_kwargs,
        y_hat_fill_kwargs=y_hat_fill_kwargs,
        y_model_plot_kwargs=y_model_plot_kwargs,
        y_model_fill_kwargs=y_model_fill_kwargs,
        y_model_mean_kwargs=y_model_mean_kwargs,
        backend_kwargs=backend_kwargs,
        show=show,
        figsize=figsize,
        textsize=textsize,
        axes=axes,
        legend=legend,
        grid=grid,
    )

    if backend is None:
        backend = rcParams["plot.backend"]
    backend = backend.lower()

    plot = get_plotting_function("plot_lm", "lmplot", backend)
    ax = plot(**lmplot_kwargs)
    return ax