arviz 0.23.3__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185)
  1. arviz/__init__.py +52 -367
  2. arviz-1.0.0rc0.dist-info/METADATA +182 -0
  3. arviz-1.0.0rc0.dist-info/RECORD +5 -0
  4. {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
  5. {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
  6. arviz/data/__init__.py +0 -55
  7. arviz/data/base.py +0 -596
  8. arviz/data/converters.py +0 -203
  9. arviz/data/datasets.py +0 -161
  10. arviz/data/example_data/code/radon/radon.json +0 -326
  11. arviz/data/example_data/data/centered_eight.nc +0 -0
  12. arviz/data/example_data/data/non_centered_eight.nc +0 -0
  13. arviz/data/example_data/data_local.json +0 -12
  14. arviz/data/example_data/data_remote.json +0 -58
  15. arviz/data/inference_data.py +0 -2386
  16. arviz/data/io_beanmachine.py +0 -112
  17. arviz/data/io_cmdstan.py +0 -1036
  18. arviz/data/io_cmdstanpy.py +0 -1233
  19. arviz/data/io_datatree.py +0 -23
  20. arviz/data/io_dict.py +0 -462
  21. arviz/data/io_emcee.py +0 -317
  22. arviz/data/io_json.py +0 -54
  23. arviz/data/io_netcdf.py +0 -68
  24. arviz/data/io_numpyro.py +0 -497
  25. arviz/data/io_pyjags.py +0 -378
  26. arviz/data/io_pyro.py +0 -333
  27. arviz/data/io_pystan.py +0 -1095
  28. arviz/data/io_zarr.py +0 -46
  29. arviz/data/utils.py +0 -139
  30. arviz/labels.py +0 -210
  31. arviz/plots/__init__.py +0 -61
  32. arviz/plots/autocorrplot.py +0 -171
  33. arviz/plots/backends/__init__.py +0 -223
  34. arviz/plots/backends/bokeh/__init__.py +0 -166
  35. arviz/plots/backends/bokeh/autocorrplot.py +0 -101
  36. arviz/plots/backends/bokeh/bfplot.py +0 -23
  37. arviz/plots/backends/bokeh/bpvplot.py +0 -193
  38. arviz/plots/backends/bokeh/compareplot.py +0 -167
  39. arviz/plots/backends/bokeh/densityplot.py +0 -239
  40. arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
  41. arviz/plots/backends/bokeh/distplot.py +0 -183
  42. arviz/plots/backends/bokeh/dotplot.py +0 -113
  43. arviz/plots/backends/bokeh/ecdfplot.py +0 -73
  44. arviz/plots/backends/bokeh/elpdplot.py +0 -203
  45. arviz/plots/backends/bokeh/energyplot.py +0 -155
  46. arviz/plots/backends/bokeh/essplot.py +0 -176
  47. arviz/plots/backends/bokeh/forestplot.py +0 -772
  48. arviz/plots/backends/bokeh/hdiplot.py +0 -54
  49. arviz/plots/backends/bokeh/kdeplot.py +0 -268
  50. arviz/plots/backends/bokeh/khatplot.py +0 -163
  51. arviz/plots/backends/bokeh/lmplot.py +0 -185
  52. arviz/plots/backends/bokeh/loopitplot.py +0 -211
  53. arviz/plots/backends/bokeh/mcseplot.py +0 -184
  54. arviz/plots/backends/bokeh/pairplot.py +0 -328
  55. arviz/plots/backends/bokeh/parallelplot.py +0 -81
  56. arviz/plots/backends/bokeh/posteriorplot.py +0 -324
  57. arviz/plots/backends/bokeh/ppcplot.py +0 -379
  58. arviz/plots/backends/bokeh/rankplot.py +0 -149
  59. arviz/plots/backends/bokeh/separationplot.py +0 -107
  60. arviz/plots/backends/bokeh/traceplot.py +0 -436
  61. arviz/plots/backends/bokeh/violinplot.py +0 -164
  62. arviz/plots/backends/matplotlib/__init__.py +0 -124
  63. arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
  64. arviz/plots/backends/matplotlib/bfplot.py +0 -78
  65. arviz/plots/backends/matplotlib/bpvplot.py +0 -177
  66. arviz/plots/backends/matplotlib/compareplot.py +0 -135
  67. arviz/plots/backends/matplotlib/densityplot.py +0 -194
  68. arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
  69. arviz/plots/backends/matplotlib/distplot.py +0 -178
  70. arviz/plots/backends/matplotlib/dotplot.py +0 -116
  71. arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
  72. arviz/plots/backends/matplotlib/elpdplot.py +0 -189
  73. arviz/plots/backends/matplotlib/energyplot.py +0 -113
  74. arviz/plots/backends/matplotlib/essplot.py +0 -180
  75. arviz/plots/backends/matplotlib/forestplot.py +0 -656
  76. arviz/plots/backends/matplotlib/hdiplot.py +0 -48
  77. arviz/plots/backends/matplotlib/kdeplot.py +0 -177
  78. arviz/plots/backends/matplotlib/khatplot.py +0 -241
  79. arviz/plots/backends/matplotlib/lmplot.py +0 -149
  80. arviz/plots/backends/matplotlib/loopitplot.py +0 -144
  81. arviz/plots/backends/matplotlib/mcseplot.py +0 -161
  82. arviz/plots/backends/matplotlib/pairplot.py +0 -355
  83. arviz/plots/backends/matplotlib/parallelplot.py +0 -58
  84. arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
  85. arviz/plots/backends/matplotlib/ppcplot.py +0 -478
  86. arviz/plots/backends/matplotlib/rankplot.py +0 -119
  87. arviz/plots/backends/matplotlib/separationplot.py +0 -97
  88. arviz/plots/backends/matplotlib/traceplot.py +0 -526
  89. arviz/plots/backends/matplotlib/tsplot.py +0 -121
  90. arviz/plots/backends/matplotlib/violinplot.py +0 -148
  91. arviz/plots/bfplot.py +0 -128
  92. arviz/plots/bpvplot.py +0 -308
  93. arviz/plots/compareplot.py +0 -177
  94. arviz/plots/densityplot.py +0 -284
  95. arviz/plots/distcomparisonplot.py +0 -197
  96. arviz/plots/distplot.py +0 -233
  97. arviz/plots/dotplot.py +0 -233
  98. arviz/plots/ecdfplot.py +0 -372
  99. arviz/plots/elpdplot.py +0 -174
  100. arviz/plots/energyplot.py +0 -147
  101. arviz/plots/essplot.py +0 -319
  102. arviz/plots/forestplot.py +0 -304
  103. arviz/plots/hdiplot.py +0 -211
  104. arviz/plots/kdeplot.py +0 -357
  105. arviz/plots/khatplot.py +0 -236
  106. arviz/plots/lmplot.py +0 -380
  107. arviz/plots/loopitplot.py +0 -224
  108. arviz/plots/mcseplot.py +0 -194
  109. arviz/plots/pairplot.py +0 -281
  110. arviz/plots/parallelplot.py +0 -204
  111. arviz/plots/plot_utils.py +0 -599
  112. arviz/plots/posteriorplot.py +0 -298
  113. arviz/plots/ppcplot.py +0 -369
  114. arviz/plots/rankplot.py +0 -232
  115. arviz/plots/separationplot.py +0 -167
  116. arviz/plots/styles/arviz-bluish.mplstyle +0 -1
  117. arviz/plots/styles/arviz-brownish.mplstyle +0 -1
  118. arviz/plots/styles/arviz-colors.mplstyle +0 -2
  119. arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
  120. arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
  121. arviz/plots/styles/arviz-doc.mplstyle +0 -88
  122. arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
  123. arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
  124. arviz/plots/styles/arviz-greenish.mplstyle +0 -1
  125. arviz/plots/styles/arviz-orangish.mplstyle +0 -1
  126. arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
  127. arviz/plots/styles/arviz-purplish.mplstyle +0 -1
  128. arviz/plots/styles/arviz-redish.mplstyle +0 -1
  129. arviz/plots/styles/arviz-royish.mplstyle +0 -1
  130. arviz/plots/styles/arviz-viridish.mplstyle +0 -1
  131. arviz/plots/styles/arviz-white.mplstyle +0 -40
  132. arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
  133. arviz/plots/traceplot.py +0 -273
  134. arviz/plots/tsplot.py +0 -440
  135. arviz/plots/violinplot.py +0 -192
  136. arviz/preview.py +0 -58
  137. arviz/py.typed +0 -0
  138. arviz/rcparams.py +0 -606
  139. arviz/sel_utils.py +0 -223
  140. arviz/static/css/style.css +0 -340
  141. arviz/static/html/icons-svg-inline.html +0 -15
  142. arviz/stats/__init__.py +0 -37
  143. arviz/stats/density_utils.py +0 -1013
  144. arviz/stats/diagnostics.py +0 -1013
  145. arviz/stats/ecdf_utils.py +0 -324
  146. arviz/stats/stats.py +0 -2422
  147. arviz/stats/stats_refitting.py +0 -119
  148. arviz/stats/stats_utils.py +0 -609
  149. arviz/tests/__init__.py +0 -1
  150. arviz/tests/base_tests/__init__.py +0 -1
  151. arviz/tests/base_tests/test_data.py +0 -1679
  152. arviz/tests/base_tests/test_data_zarr.py +0 -143
  153. arviz/tests/base_tests/test_diagnostics.py +0 -511
  154. arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
  155. arviz/tests/base_tests/test_helpers.py +0 -18
  156. arviz/tests/base_tests/test_labels.py +0 -69
  157. arviz/tests/base_tests/test_plot_utils.py +0 -342
  158. arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
  159. arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
  160. arviz/tests/base_tests/test_rcparams.py +0 -317
  161. arviz/tests/base_tests/test_stats.py +0 -925
  162. arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
  163. arviz/tests/base_tests/test_stats_numba.py +0 -45
  164. arviz/tests/base_tests/test_stats_utils.py +0 -384
  165. arviz/tests/base_tests/test_utils.py +0 -376
  166. arviz/tests/base_tests/test_utils_numba.py +0 -87
  167. arviz/tests/conftest.py +0 -46
  168. arviz/tests/external_tests/__init__.py +0 -1
  169. arviz/tests/external_tests/test_data_beanmachine.py +0 -78
  170. arviz/tests/external_tests/test_data_cmdstan.py +0 -398
  171. arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
  172. arviz/tests/external_tests/test_data_emcee.py +0 -166
  173. arviz/tests/external_tests/test_data_numpyro.py +0 -434
  174. arviz/tests/external_tests/test_data_pyjags.py +0 -119
  175. arviz/tests/external_tests/test_data_pyro.py +0 -260
  176. arviz/tests/external_tests/test_data_pystan.py +0 -307
  177. arviz/tests/helpers.py +0 -677
  178. arviz/utils.py +0 -773
  179. arviz/wrappers/__init__.py +0 -13
  180. arviz/wrappers/base.py +0 -236
  181. arviz/wrappers/wrap_pymc.py +0 -36
  182. arviz/wrappers/wrap_stan.py +0 -148
  183. arviz-0.23.3.dist-info/METADATA +0 -264
  184. arviz-0.23.3.dist-info/RECORD +0 -183
  185. arviz-0.23.3.dist-info/top_level.txt +0 -1
@@ -1,298 +0,0 @@
1
- """Plot posterior densities."""
2
-
3
- from ..data import convert_to_dataset
4
- from ..labels import BaseLabeller
5
- from ..sel_utils import xarray_var_iter
6
- from ..utils import _var_names, get_coords
7
- from ..rcparams import rcParams
8
- from .plot_utils import default_grid, filter_plotters_list, get_plotting_function
9
-
10
-
11
def plot_posterior(
    data,
    var_names=None,
    filter_vars=None,
    combine_dims=None,
    transform=None,
    coords=None,
    grid=None,
    figsize=None,
    textsize=None,
    hdi_prob=None,
    multimodal=False,
    skipna=False,
    round_to=None,
    point_estimate="auto",
    group="posterior",
    rope=None,
    ref_val=None,
    rope_color="C2",
    ref_val_color="C1",
    kind=None,
    bw="default",
    circular=False,
    bins=None,
    labeller=None,
    ax=None,
    backend=None,
    backend_kwargs=None,
    show=None,
    **kwargs
):
    r"""Plot Posterior densities in the style of John K. Kruschke's book.

    Parameters
    ----------
    data : obj
        Any object that can be converted to an :class:`arviz.InferenceData` object.
        Refer to the documentation of :func:`arviz.convert_to_dataset` for details.
    var_names : list of str, optional
        Variables to be plotted. If None, all variables of ``group`` are plotted. Prefix
        the variables with ``~`` when you want to exclude them from the plot.
    filter_vars : {None, "like", "regex"}, optional, default=None
        If `None` (default), interpret var_names as the real variables names. If "like",
        interpret var_names as substrings of the real variables names. If "regex",
        interpret var_names as regular expressions on the real variables names. A la
        ``pandas.filter``.
    combine_dims : set_like of str, optional
        List of dimensions to reduce. Defaults to reducing only the "chain" and "draw" dimensions.
        See :ref:`this section <common_combine_dims>` for usage examples.
    transform : callable, optional
        Function applied to the converted dataset before plotting. Defaults to None,
        i.e. the identity function.
    coords : mapping, optional
        Coordinates of var_names to be plotted. Passed to :meth:`xarray.Dataset.sel`.
    grid : tuple, optional
        Number of rows and columns. Defaults to None, the rows and columns are
        automatically inferred.
    figsize : tuple, optional
        Figure size. If None it will be defined automatically.
    textsize : float, optional
        Text size scaling factor for labels, titles and lines. If None it will be autoscaled
        based on ``figsize``.
    hdi_prob : float or "hide", optional
        Plots highest density interval for chosen percentage of density. Must be in the
        interval (0, 1]. Use "hide" to hide the highest density interval. Defaults to
        rcParam ``stats.ci_prob``.
    multimodal : bool, default False
        If True it may compute more than one credible interval if the distribution is
        multimodal and the modes are well separated.
    skipna : bool, default False
        If True ignores nan values when computing the hdi and point estimates.
    round_to : int, optional
        Controls formatting of floats. Defaults to 2 or the integer part, whichever is bigger.
    point_estimate : {"auto", "mean", "median", "mode", None}, default "auto"
        Plot point estimate per variable. ``"auto"`` falls back to rcParam
        ``plot.point_estimate``; None disables the point estimate.
    group : str, default "posterior"
        Specifies which InferenceData group should be plotted.
    rope : list, tuple or dictionary of {str: tuples or lists}, optional
        Lower and upper values of the Region Of Practical Equivalence, either shared by
        all variables or given per variable (and coordinate) as a dictionary.
        See :ref:`this section <common_rope>` for usage examples.
    ref_val : float, list or dictionary of floats, optional
        Display the percentage below and above the values in ref_val. Must be None (default),
        a constant, a list or a dictionary like the one in the examples below. If a list is
        provided, its length should match the number of plotted distributions (one per
        variable and coordinate combination).
    rope_color : str, default "C2"
        Specifies the color of ROPE and displayed percentage within ROPE.
    ref_val_color : str, default "C1"
        Specifies the color of the displayed percentage.
    kind : {"kde", "hist"}, optional
        Type of plot to display. For discrete variables this argument is ignored and
        a histogram is always used. Defaults to rcParam ``plot.density_kind``.
    bw : float or str, optional
        If numeric, indicates the bandwidth and must be positive.
        If str, indicates the method to estimate the bandwidth and must be
        one of "scott", "silverman", "isj" or "experimental" when `circular` is False
        and "taylor" (for now) when `circular` is True.
        Defaults to "default" which means "experimental" when variable is not circular
        and "taylor" when it is. Only works if `kind == kde`.
    circular : bool, default False
        If True, it interprets the values passed are from a circular variable measured in
        radians and a circular KDE is used. Only valid for 1D KDE.
        Only works if `kind == kde`.
    bins : integer or sequence or 'auto', optional
        Controls the number of bins, accepts the same keywords :func:`matplotlib.pyplot.hist`
        does. Only works if `kind == hist`. If None (default) it will use `auto` for
        continuous variables and `range(xmin, xmax + 1)` for discrete variables.
    labeller : labeller instance, optional
        Class providing the method ``make_label_vert`` to generate the labels in the plot
        titles. Defaults to :class:`arviz.labels.BaseLabeller`.
        Read the :ref:`label_guide` for more details and usage examples.
    ax : numpy array-like of matplotlib axes or bokeh figures, optional
        A 2D array of locations into which to plot the densities. If not supplied, ArviZ will
        create its own array of plot areas (and return it).
    backend : {"matplotlib", "bokeh"}, optional
        Select plotting backend (case-insensitive). Defaults to rcParam ``plot.backend``.
    backend_kwargs : dict, optional
        These are kwargs specific to the backend being used, passed to
        :func:`matplotlib.pyplot.subplots` or :func:`bokeh.plotting.figure`.
    show : bool, optional
        Call backend show function.
    **kwargs
        Passed as-is to :func:`matplotlib.pyplot.hist` or :func:`matplotlib.pyplot.plot`
        function depending on the value of `kind`.

    Returns
    -------
    axes : matplotlib axes or bokeh figures

    Raises
    ------
    ValueError
        If ``hdi_prob`` is a number outside (0, 1] or ``point_estimate`` is not one of
        "auto", "mean", "median", "mode" or None.

    See Also
    --------
    plot_dist : Plot distribution as histogram or kernel density estimates.
    plot_density : Generate KDE plots for continuous variables and histograms for discrete ones.
    plot_forest : Forest plot to compare HDI intervals from a number of distributions.

    Examples
    --------
    Show a default kernel density plot following style of John Kruschke

    .. plot::
        :context: close-figs

        >>> import arviz as az
        >>> data = az.load_arviz_data('centered_eight')
        >>> az.plot_posterior(data)

    Plot subset variables by specifying variable name exactly

    .. plot::
        :context: close-figs

        >>> az.plot_posterior(data, var_names=['mu'])

    Plot Region of Practical Equivalence (rope) and select variables with regular expressions

    .. plot::
        :context: close-figs

        >>> az.plot_posterior(data, var_names=['mu', '^the'], filter_vars="regex", rope=(-1, 1))

    Plot Region of Practical Equivalence for selected distributions

    .. plot::
        :context: close-figs

        >>> rope = {'mu': [{'rope': (-2, 2)}], 'theta': [{'school': 'Choate', 'rope': (2, 4)}]}
        >>> az.plot_posterior(data, var_names=['mu', 'theta'], rope=rope)

    Using `coords` argument to plot only a subset of data

    .. plot::
        :context: close-figs

        >>> coords = {"school": ["Choate","Phillips Exeter"]}
        >>> az.plot_posterior(data, var_names=["mu", "theta"], coords=coords)

    Add reference lines

    .. plot::
        :context: close-figs

        >>> az.plot_posterior(data, var_names=['mu', 'theta'], ref_val=0)

    Show point estimate of distribution

    .. plot::
        :context: close-figs

        >>> az.plot_posterior(data, var_names=['mu', 'theta'], point_estimate='mode')

    Show reference values using variable names and coordinates

    .. plot::
        :context: close-figs

        >>> az.plot_posterior(data, ref_val= {"theta": [{"school": "Deerfield", "ref_val": 4},
        ...                                             {"school": "Choate", "ref_val": 3}]})

    Show reference values using a list

    .. plot::
        :context: close-figs

        >>> az.plot_posterior(data, ref_val=[1] + [5] * 8 + [1])


    Plot posterior as a histogram

    .. plot::
        :context: close-figs

        >>> az.plot_posterior(data, var_names=['mu'], kind='hist')

    Change size of highest density interval

    .. plot::
        :context: close-figs

        >>> az.plot_posterior(data, var_names=['mu'], hdi_prob=.75)
    """
    data = convert_to_dataset(data, group=group)
    if transform is not None:
        data = transform(data)
    var_names = _var_names(var_names, data, filter_vars)

    if coords is None:
        coords = {}

    if labeller is None:
        labeller = BaseLabeller()

    if hdi_prob is None:
        hdi_prob = rcParams["stats.ci_prob"]
    elif hdi_prob != "hide" and not 0 < hdi_prob <= 1:
        # ``None`` is handled by the branch above, so only "hide" needs to be exempted here.
        raise ValueError("The value of hdi_prob should be in the interval (0, 1]")

    if point_estimate == "auto":
        point_estimate = rcParams["plot.point_estimate"]
    elif point_estimate not in {"mean", "median", "mode", None}:
        raise ValueError("The value of point_estimate must be either mean, median, mode or None.")

    if kind is None:
        kind = rcParams["plot.density_kind"]

    plotters = filter_plotters_list(
        list(
            xarray_var_iter(
                get_coords(data, coords), var_names=var_names, combined=True, skip_dims=combine_dims
            )
        ),
        "plot_posterior",
    )
    length_plotters = len(plotters)
    rows, cols = default_grid(length_plotters, grid=grid)

    posteriorplot_kwargs = dict(
        ax=ax,
        length_plotters=length_plotters,
        rows=rows,
        cols=cols,
        figsize=figsize,
        plotters=plotters,
        bw=bw,
        circular=circular,
        bins=bins,
        kind=kind,
        point_estimate=point_estimate,
        round_to=round_to,
        hdi_prob=hdi_prob,
        multimodal=multimodal,
        skipna=skipna,
        textsize=textsize,
        ref_val=ref_val,
        rope=rope,
        ref_val_color=ref_val_color,
        rope_color=rope_color,
        labeller=labeller,
        kwargs=kwargs,
        backend_kwargs=backend_kwargs,
        show=show,
    )

    if backend is None:
        backend = rcParams["plot.backend"]
    backend = backend.lower()

    plot = get_plotting_function("plot_posterior", "posteriorplot", backend)
    ax = plot(**posteriorplot_kwargs)
    return ax
arviz/plots/ppcplot.py DELETED
@@ -1,369 +0,0 @@
1
- """Posterior/Prior predictive plot."""
2
-
3
- import logging
4
- import warnings
5
- from numbers import Integral
6
-
7
- import numpy as np
8
-
9
- from ..labels import BaseLabeller
10
- from ..sel_utils import xarray_var_iter
11
- from ..rcparams import rcParams
12
- from ..utils import _var_names
13
- from .plot_utils import default_grid, filter_plotters_list, get_plotting_function
14
-
15
_log = logging.getLogger(__name__)


def plot_ppc(
    data,
    kind="kde",
    alpha=None,
    mean=True,
    observed=None,
    observed_rug=False,
    color=None,
    colors=None,
    grid=None,
    figsize=None,
    textsize=None,
    data_pairs=None,
    var_names=None,
    filter_vars=None,
    coords=None,
    flatten=None,
    flatten_pp=None,
    num_pp_samples=None,
    random_seed=None,
    jitter=None,
    animated=False,
    animation_kwargs=None,
    legend=True,
    labeller=None,
    ax=None,
    backend=None,
    backend_kwargs=None,
    group="posterior",
    show=None,
):
    """
    Plot for posterior/prior predictive checks.

    Parameters
    ----------
    data : InferenceData
        :class:`arviz.InferenceData` object containing the observed and posterior/prior
        predictive data.
    kind : {"kde", "cumulative", "scatter"}, default "kde"
        Type of plot to display. The value is case-insensitive.
    alpha : float, optional
        Opacity of posterior/prior predictive density curves.
        Defaults to 0.2 for ``kind = kde`` and cumulative, for scatter defaults to 0.7.
    mean : bool, default True
        Whether or not to plot the mean posterior/prior predictive distribution.
    observed : bool, optional
        Whether or not to plot the observed data. Defaults to True for ``group = posterior``
        and False for ``group = prior``.
    observed_rug : bool, default False
        Whether or not to plot a rug plot for the observed data. Only valid if `observed` is
        `True` and for kind `kde` or `cumulative`.
    color : str, optional
        Deprecated in favor of ``colors``. If given, it replaces the first element of
        ``colors`` (the posterior/prior predictive color) and a ``FutureWarning`` is emitted.
    colors : list, optional
        List with 3 valid matplotlib colors corresponding to the posterior/prior predictive
        distribution, observed data and mean of the posterior/prior predictive distribution.
        Defaults to ["C0", "k", "C1"]. The sequence passed in is never modified.
    grid : tuple, optional
        Number of rows and columns. Defaults to None, the rows and columns are
        automatically inferred.
    figsize : tuple, optional
        Figure size. If None, it will be defined automatically.
    textsize : float, optional
        Text size scaling factor for labels, titles and lines. If None, it will be
        autoscaled based on ``figsize``.
    data_pairs : dict, optional
        Dictionary containing relations between observed data and posterior/prior predictive data.
        Dictionary structure:

        - key = data var_name
        - value = posterior/prior predictive var_name

        For example, ``data_pairs = {'y' : 'y_hat'}``
        If None, it will assume that the observed data and the posterior/prior
        predictive data have the same variable name.
    var_names : list of str, optional
        Variables to be plotted, if `None` all variables are plotted. Prefix the
        variables by ``~`` when you want to exclude them from the plot.
    filter_vars : {None, "like", "regex"}, default None
        If `None` (default), interpret var_names as the real variables names. If "like",
        interpret var_names as substrings of the real variables names. If "regex",
        interpret var_names as regular expressions on the real variables names. A la
        ``pandas.filter``.
    coords : dict, optional
        Dictionary mapping dimensions to selected coordinates to be plotted.
        Dimensions without a mapping specified will include all coordinates for
        that dimension. Defaults to including all coordinates for all
        dimensions if None.
    flatten : list, optional
        List of dimensions to flatten in ``observed_data``. Only flattens across the coordinates
        specified in the ``coords`` argument. Defaults to flattening all of the dimensions.
    flatten_pp : list, optional
        List of dimensions to flatten in posterior_predictive/prior_predictive. Only flattens
        across the coordinates specified in the ``coords`` argument. Defaults to flattening all
        of the dimensions. Dimensions should match flatten excluding dimensions for ``data_pairs``
        parameters. If ``flatten`` is defined and ``flatten_pp`` is None, then
        ``flatten_pp = flatten``.
    num_pp_samples : int, optional
        The number of posterior/prior predictive samples to plot, between 1 and
        ``chain * draw``. For ``kind='scatter'`` and ``animated=False`` it defaults to a
        maximum of 5 samples (and jitter defaults to 0.7 unless defined). Otherwise it
        defaults to all provided samples.
    random_seed : int, optional
        Random number generator seed passed to ``numpy.random.seed`` to allow
        reproducibility of the plot. By default, no seed will be provided
        and the plot will change each call if a random sample is specified
        by ``num_pp_samples``.
    jitter : float, optional
        If ``kind`` is "scatter", jitter will add random uniform noise to the height
        of the ppc samples and observed data. Defaults to None.
    animated : bool, default False
        Create an animation of one posterior/prior predictive sample per frame.
        Only works with matplotlib backend.
        To run animations inside a notebook you have to use the `nbAgg` matplotlib's backend.
        Try with `%matplotlib notebook` or `%matplotlib nbAgg`. You can switch back to the
        default matplotlib's backend with `%matplotlib inline` or `%matplotlib auto`.
        If switching back and forth between matplotlib's backend, you may need to run twice the cell
        with the animation.
        If you experience problems rendering the animation try setting
        ``animation_kwargs({'blit':False})`` or changing the matplotlib's backend (e.g. to TkAgg)
        If you run the animation from a script write ``ax, ani = az.plot_ppc(.)``
    animation_kwargs : dict, optional
        Keywords passed to :class:`matplotlib.animation.FuncAnimation`. Only used when
        ``animated=True``, which requires the matplotlib backend.
    legend : bool, default True
        Add legend to figure.
    labeller : labeller, optional
        Class providing the method ``make_pp_label`` to generate the labels in the plot titles.
        Defaults to :class:`arviz.labels.BaseLabeller`.
        Read the :ref:`label_guide` for more details and usage examples.
    ax : numpy array-like of matplotlib_axes or bokeh figures, optional
        A 2D array of locations into which to plot the densities. If not supplied, ArviZ will
        create its own array of plot areas (and return it).
    backend : {"matplotlib", "bokeh"}, optional
        Select plotting backend (case-insensitive). Defaults to rcParam ``plot.backend``.
    backend_kwargs : dict, optional
        These are kwargs specific to the backend being used, passed to
        :func:`matplotlib.pyplot.subplots` or :func:`bokeh.plotting.figure`.
        For additional documentation check the plotting method of the backend.
    group : {"prior", "posterior"}, optional
        Specifies which InferenceData group should be plotted. Defaults to 'posterior'.
        Other value can be 'prior'.
    show : bool, optional
        Call backend show function.

    Returns
    -------
    axes : matplotlib_axes or bokeh_figures
    ani : matplotlib.animation.FuncAnimation, optional
        Only provided if `animated` is ``True``.

    Raises
    ------
    TypeError
        If ``group`` or ``kind`` are invalid, ``data`` lacks the required groups,
        ``colors`` is a string, ``num_pp_samples`` is not an integer in
        ``[1, chain * draw]``, or ``animated`` is requested with the bokeh backend.
    ValueError
        If ``colors`` does not contain exactly 3 items.

    See Also
    --------
    plot_bpv : Plot Bayesian p-value for observed data and Posterior/Prior predictive.
    plot_loo_pit : Plot for posterior predictive checks using cross validation.
    plot_lm : Posterior predictive and mean plots for regression-like data.
    plot_ts : Plot timeseries data.

    Examples
    --------
    Plot the observed data KDE overlaid on posterior predictive KDEs.

    .. plot::
        :context: close-figs

        >>> import arviz as az
        >>> data = az.load_arviz_data('radon')
        >>> az.plot_ppc(data, data_pairs={"y":"y"})

    Plot the overlay with empirical CDFs.

    .. plot::
        :context: close-figs

        >>> az.plot_ppc(data, kind='cumulative')

    Use the ``coords`` and ``flatten`` parameters to plot selected variable dimensions
    across multiple plots. We will now modify the dimension ``obs_id`` to indicate
    the name of the county where the measure was taken. The change has to
    be done on both ``posterior_predictive`` and ``observed_data`` groups, which is
    why we will use :meth:`~arviz.InferenceData.map` to apply the same function to
    both groups. Afterwards, we will select the counties to be plotted with the
    ``coords`` arg.

    .. plot::
        :context: close-figs

        >>> obs_county = data.posterior["County"][data.constant_data["county_idx"]]
        >>> data = data.assign_coords(obs_id=obs_county, groups="observed_vars")
        >>> az.plot_ppc(data, coords={'obs_id': ['ANOKA', 'BELTRAMI']}, flatten=[])

    Plot the overlay using a stacked scatter plot that is particularly useful
    when the sample sizes are small.

    .. plot::
        :context: close-figs

        >>> az.plot_ppc(data, kind='scatter', flatten=[],
        ...             coords={'obs_id': ['AITKIN', 'BELTRAMI']})

    Plot random posterior predictive sub-samples.

    .. plot::
        :context: close-figs

        >>> az.plot_ppc(data, num_pp_samples=30, random_seed=7)
    """
    if group not in ("posterior", "prior"):
        raise TypeError("`group` argument must be either `posterior` or `prior`")

    for required_group in (f"{group}_predictive", "observed_data"):
        if not hasattr(data, required_group):
            raise TypeError(
                f'`data` argument must have the group "{required_group}" for ppcplot'
            )

    # Normalize once so that validation, the scatter-specific default for
    # ``num_pp_samples`` below and the backend all see the same spelling.
    kind = kind.lower()
    if kind not in ("kde", "cumulative", "scatter"):
        raise TypeError("`kind` argument must be either `kde`, `cumulative`, or `scatter`")

    if colors is None:
        colors = ["C0", "k", "C1"]

    if isinstance(colors, str):
        raise TypeError("colors should be a list with 3 items.")

    if len(colors) != 3:
        raise ValueError("colors should be a list with 3 items.")

    # Work on a private copy: the deprecated ``color`` override below must not
    # mutate the caller's sequence (and must not fail on tuples).
    colors = list(colors)

    if color is not None:
        warnings.warn("color has been deprecated in favor of colors", FutureWarning, stacklevel=2)
        colors[0] = color

    if data_pairs is None:
        data_pairs = {}

    if backend is None:
        backend = rcParams["plot.backend"]
    backend = backend.lower()
    if backend == "bokeh" and animated:
        raise TypeError("Animation option is only supported with matplotlib backend.")

    observed_data = data.observed_data

    if group == "posterior":
        predictive_dataset = data.posterior_predictive
        if observed is None:
            observed = True
    elif group == "prior":
        predictive_dataset = data.prior_predictive
        if observed is None:
            observed = False

    if var_names is None:
        var_names = list(observed_data.data_vars)
    var_names = _var_names(var_names, observed_data, filter_vars)
    pp_var_names = [data_pairs.get(var, var) for var in var_names]
    pp_var_names = _var_names(pp_var_names, predictive_dataset, filter_vars)

    if flatten_pp is None:
        if flatten is None:
            flatten_pp = list(predictive_dataset.dims)
        else:
            flatten_pp = flatten
    if flatten is None:
        flatten = list(observed_data.dims)

    if coords is None:
        coords = {}
    else:
        # Copied because label coordinates are converted to positions in place below.
        coords = coords.copy()

    if labeller is None:
        labeller = BaseLabeller()

    if random_seed is not None:
        # Seeds the global legacy RNG on purpose, as documented, so that any random
        # draws made later (sub-sampling here, jitter in the backend) are reproducible.
        np.random.seed(random_seed)

    total_pp_samples = predictive_dataset.sizes["chain"] * predictive_dataset.sizes["draw"]
    if num_pp_samples is None:
        if kind == "scatter" and not animated:
            num_pp_samples = min(5, total_pp_samples)
        else:
            num_pp_samples = total_pp_samples

    if (
        not isinstance(num_pp_samples, Integral)
        or num_pp_samples < 1
        or num_pp_samples > total_pp_samples
    ):
        raise TypeError(f"`num_pp_samples` must be an integer between 1 and {total_pp_samples}.")

    pp_sample_ix = np.random.choice(total_pp_samples, size=num_pp_samples, replace=False)

    # Translate coordinate labels into positional indices of ``observed_data`` so the
    # same ``isel`` selection can be applied to both observed and predictive data.
    for key in coords.keys():
        coords[key] = np.where(np.isin(observed_data[key], coords[key]))[0]

    obs_plotters = filter_plotters_list(
        list(
            xarray_var_iter(
                observed_data.isel(coords),
                skip_dims=set(flatten),
                var_names=var_names,
                combined=True,
                dim_order=["chain", "draw"],
            )
        ),
        "plot_ppc",
    )
    length_plotters = len(obs_plotters)
    # Only as many predictive plotters as observed ones are kept (zip truncates).
    pp_plotters = [
        tup
        for _, tup in zip(
            range(length_plotters),
            xarray_var_iter(
                predictive_dataset.isel(coords),
                var_names=pp_var_names,
                skip_dims=set(flatten_pp),
                combined=True,
                dim_order=["chain", "draw"],
            ),
        )
    ]
    rows, cols = default_grid(length_plotters, grid=grid)

    ppcplot_kwargs = dict(
        ax=ax,
        length_plotters=length_plotters,
        rows=rows,
        cols=cols,
        figsize=figsize,
        animated=animated,
        obs_plotters=obs_plotters,
        pp_plotters=pp_plotters,
        predictive_dataset=predictive_dataset,
        pp_sample_ix=pp_sample_ix,
        kind=kind,
        alpha=alpha,
        colors=colors,
        jitter=jitter,
        textsize=textsize,
        mean=mean,
        observed=observed,
        observed_rug=observed_rug,
        total_pp_samples=total_pp_samples,
        legend=legend,
        labeller=labeller,
        group=group,
        animation_kwargs=animation_kwargs,
        num_pp_samples=num_pp_samples,
        backend_kwargs=backend_kwargs,
        show=show,
    )

    plot = get_plotting_function("plot_ppc", "ppcplot", backend)
    axes = plot(**ppcplot_kwargs)
    return axes