arviz 0.23.1__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185) hide show
  1. arviz/__init__.py +52 -357
  2. arviz-1.0.0rc0.dist-info/METADATA +182 -0
  3. arviz-1.0.0rc0.dist-info/RECORD +5 -0
  4. {arviz-0.23.1.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
  5. {arviz-0.23.1.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
  6. arviz/data/__init__.py +0 -55
  7. arviz/data/base.py +0 -596
  8. arviz/data/converters.py +0 -203
  9. arviz/data/datasets.py +0 -161
  10. arviz/data/example_data/code/radon/radon.json +0 -326
  11. arviz/data/example_data/data/centered_eight.nc +0 -0
  12. arviz/data/example_data/data/non_centered_eight.nc +0 -0
  13. arviz/data/example_data/data_local.json +0 -12
  14. arviz/data/example_data/data_remote.json +0 -58
  15. arviz/data/inference_data.py +0 -2386
  16. arviz/data/io_beanmachine.py +0 -112
  17. arviz/data/io_cmdstan.py +0 -1036
  18. arviz/data/io_cmdstanpy.py +0 -1233
  19. arviz/data/io_datatree.py +0 -23
  20. arviz/data/io_dict.py +0 -462
  21. arviz/data/io_emcee.py +0 -317
  22. arviz/data/io_json.py +0 -54
  23. arviz/data/io_netcdf.py +0 -68
  24. arviz/data/io_numpyro.py +0 -497
  25. arviz/data/io_pyjags.py +0 -378
  26. arviz/data/io_pyro.py +0 -333
  27. arviz/data/io_pystan.py +0 -1095
  28. arviz/data/io_zarr.py +0 -46
  29. arviz/data/utils.py +0 -139
  30. arviz/labels.py +0 -210
  31. arviz/plots/__init__.py +0 -61
  32. arviz/plots/autocorrplot.py +0 -171
  33. arviz/plots/backends/__init__.py +0 -223
  34. arviz/plots/backends/bokeh/__init__.py +0 -166
  35. arviz/plots/backends/bokeh/autocorrplot.py +0 -101
  36. arviz/plots/backends/bokeh/bfplot.py +0 -23
  37. arviz/plots/backends/bokeh/bpvplot.py +0 -193
  38. arviz/plots/backends/bokeh/compareplot.py +0 -167
  39. arviz/plots/backends/bokeh/densityplot.py +0 -239
  40. arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
  41. arviz/plots/backends/bokeh/distplot.py +0 -183
  42. arviz/plots/backends/bokeh/dotplot.py +0 -113
  43. arviz/plots/backends/bokeh/ecdfplot.py +0 -73
  44. arviz/plots/backends/bokeh/elpdplot.py +0 -203
  45. arviz/plots/backends/bokeh/energyplot.py +0 -155
  46. arviz/plots/backends/bokeh/essplot.py +0 -176
  47. arviz/plots/backends/bokeh/forestplot.py +0 -772
  48. arviz/plots/backends/bokeh/hdiplot.py +0 -54
  49. arviz/plots/backends/bokeh/kdeplot.py +0 -268
  50. arviz/plots/backends/bokeh/khatplot.py +0 -163
  51. arviz/plots/backends/bokeh/lmplot.py +0 -185
  52. arviz/plots/backends/bokeh/loopitplot.py +0 -211
  53. arviz/plots/backends/bokeh/mcseplot.py +0 -184
  54. arviz/plots/backends/bokeh/pairplot.py +0 -328
  55. arviz/plots/backends/bokeh/parallelplot.py +0 -81
  56. arviz/plots/backends/bokeh/posteriorplot.py +0 -324
  57. arviz/plots/backends/bokeh/ppcplot.py +0 -379
  58. arviz/plots/backends/bokeh/rankplot.py +0 -149
  59. arviz/plots/backends/bokeh/separationplot.py +0 -107
  60. arviz/plots/backends/bokeh/traceplot.py +0 -436
  61. arviz/plots/backends/bokeh/violinplot.py +0 -164
  62. arviz/plots/backends/matplotlib/__init__.py +0 -124
  63. arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
  64. arviz/plots/backends/matplotlib/bfplot.py +0 -78
  65. arviz/plots/backends/matplotlib/bpvplot.py +0 -177
  66. arviz/plots/backends/matplotlib/compareplot.py +0 -135
  67. arviz/plots/backends/matplotlib/densityplot.py +0 -194
  68. arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
  69. arviz/plots/backends/matplotlib/distplot.py +0 -178
  70. arviz/plots/backends/matplotlib/dotplot.py +0 -116
  71. arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
  72. arviz/plots/backends/matplotlib/elpdplot.py +0 -189
  73. arviz/plots/backends/matplotlib/energyplot.py +0 -113
  74. arviz/plots/backends/matplotlib/essplot.py +0 -180
  75. arviz/plots/backends/matplotlib/forestplot.py +0 -656
  76. arviz/plots/backends/matplotlib/hdiplot.py +0 -48
  77. arviz/plots/backends/matplotlib/kdeplot.py +0 -177
  78. arviz/plots/backends/matplotlib/khatplot.py +0 -241
  79. arviz/plots/backends/matplotlib/lmplot.py +0 -149
  80. arviz/plots/backends/matplotlib/loopitplot.py +0 -144
  81. arviz/plots/backends/matplotlib/mcseplot.py +0 -161
  82. arviz/plots/backends/matplotlib/pairplot.py +0 -355
  83. arviz/plots/backends/matplotlib/parallelplot.py +0 -58
  84. arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
  85. arviz/plots/backends/matplotlib/ppcplot.py +0 -478
  86. arviz/plots/backends/matplotlib/rankplot.py +0 -119
  87. arviz/plots/backends/matplotlib/separationplot.py +0 -97
  88. arviz/plots/backends/matplotlib/traceplot.py +0 -526
  89. arviz/plots/backends/matplotlib/tsplot.py +0 -121
  90. arviz/plots/backends/matplotlib/violinplot.py +0 -148
  91. arviz/plots/bfplot.py +0 -128
  92. arviz/plots/bpvplot.py +0 -308
  93. arviz/plots/compareplot.py +0 -177
  94. arviz/plots/densityplot.py +0 -284
  95. arviz/plots/distcomparisonplot.py +0 -197
  96. arviz/plots/distplot.py +0 -233
  97. arviz/plots/dotplot.py +0 -233
  98. arviz/plots/ecdfplot.py +0 -372
  99. arviz/plots/elpdplot.py +0 -174
  100. arviz/plots/energyplot.py +0 -147
  101. arviz/plots/essplot.py +0 -319
  102. arviz/plots/forestplot.py +0 -304
  103. arviz/plots/hdiplot.py +0 -211
  104. arviz/plots/kdeplot.py +0 -357
  105. arviz/plots/khatplot.py +0 -236
  106. arviz/plots/lmplot.py +0 -380
  107. arviz/plots/loopitplot.py +0 -224
  108. arviz/plots/mcseplot.py +0 -194
  109. arviz/plots/pairplot.py +0 -281
  110. arviz/plots/parallelplot.py +0 -204
  111. arviz/plots/plot_utils.py +0 -599
  112. arviz/plots/posteriorplot.py +0 -298
  113. arviz/plots/ppcplot.py +0 -369
  114. arviz/plots/rankplot.py +0 -232
  115. arviz/plots/separationplot.py +0 -167
  116. arviz/plots/styles/arviz-bluish.mplstyle +0 -1
  117. arviz/plots/styles/arviz-brownish.mplstyle +0 -1
  118. arviz/plots/styles/arviz-colors.mplstyle +0 -2
  119. arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
  120. arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
  121. arviz/plots/styles/arviz-doc.mplstyle +0 -88
  122. arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
  123. arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
  124. arviz/plots/styles/arviz-greenish.mplstyle +0 -1
  125. arviz/plots/styles/arviz-orangish.mplstyle +0 -1
  126. arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
  127. arviz/plots/styles/arviz-purplish.mplstyle +0 -1
  128. arviz/plots/styles/arviz-redish.mplstyle +0 -1
  129. arviz/plots/styles/arviz-royish.mplstyle +0 -1
  130. arviz/plots/styles/arviz-viridish.mplstyle +0 -1
  131. arviz/plots/styles/arviz-white.mplstyle +0 -40
  132. arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
  133. arviz/plots/traceplot.py +0 -273
  134. arviz/plots/tsplot.py +0 -440
  135. arviz/plots/violinplot.py +0 -192
  136. arviz/preview.py +0 -58
  137. arviz/py.typed +0 -0
  138. arviz/rcparams.py +0 -606
  139. arviz/sel_utils.py +0 -223
  140. arviz/static/css/style.css +0 -340
  141. arviz/static/html/icons-svg-inline.html +0 -15
  142. arviz/stats/__init__.py +0 -37
  143. arviz/stats/density_utils.py +0 -1013
  144. arviz/stats/diagnostics.py +0 -1013
  145. arviz/stats/ecdf_utils.py +0 -324
  146. arviz/stats/stats.py +0 -2422
  147. arviz/stats/stats_refitting.py +0 -119
  148. arviz/stats/stats_utils.py +0 -609
  149. arviz/tests/__init__.py +0 -1
  150. arviz/tests/base_tests/__init__.py +0 -1
  151. arviz/tests/base_tests/test_data.py +0 -1679
  152. arviz/tests/base_tests/test_data_zarr.py +0 -143
  153. arviz/tests/base_tests/test_diagnostics.py +0 -511
  154. arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
  155. arviz/tests/base_tests/test_helpers.py +0 -18
  156. arviz/tests/base_tests/test_labels.py +0 -69
  157. arviz/tests/base_tests/test_plot_utils.py +0 -342
  158. arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
  159. arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
  160. arviz/tests/base_tests/test_rcparams.py +0 -317
  161. arviz/tests/base_tests/test_stats.py +0 -925
  162. arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
  163. arviz/tests/base_tests/test_stats_numba.py +0 -45
  164. arviz/tests/base_tests/test_stats_utils.py +0 -384
  165. arviz/tests/base_tests/test_utils.py +0 -376
  166. arviz/tests/base_tests/test_utils_numba.py +0 -87
  167. arviz/tests/conftest.py +0 -46
  168. arviz/tests/external_tests/__init__.py +0 -1
  169. arviz/tests/external_tests/test_data_beanmachine.py +0 -78
  170. arviz/tests/external_tests/test_data_cmdstan.py +0 -398
  171. arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
  172. arviz/tests/external_tests/test_data_emcee.py +0 -166
  173. arviz/tests/external_tests/test_data_numpyro.py +0 -434
  174. arviz/tests/external_tests/test_data_pyjags.py +0 -119
  175. arviz/tests/external_tests/test_data_pyro.py +0 -260
  176. arviz/tests/external_tests/test_data_pystan.py +0 -307
  177. arviz/tests/helpers.py +0 -677
  178. arviz/utils.py +0 -773
  179. arviz/wrappers/__init__.py +0 -13
  180. arviz/wrappers/base.py +0 -236
  181. arviz/wrappers/wrap_pymc.py +0 -36
  182. arviz/wrappers/wrap_stan.py +0 -148
  183. arviz-0.23.1.dist-info/METADATA +0 -263
  184. arviz-0.23.1.dist-info/RECORD +0 -183
  185. arviz-0.23.1.dist-info/top_level.txt +0 -1
arviz/plots/tsplot.py DELETED
@@ -1,440 +0,0 @@
1
- """Plot timeseries data."""
2
-
3
- import warnings
4
- import numpy as np
5
-
6
- from ..sel_utils import xarray_var_iter
7
- from ..rcparams import rcParams
8
- from .plot_utils import default_grid, get_plotting_function
9
-
10
-
11
- def plot_ts(
12
- idata,
13
- y,
14
- x=None,
15
- y_hat=None,
16
- y_holdout=None,
17
- y_forecasts=None,
18
- x_holdout=None,
19
- plot_dim=None,
20
- holdout_dim=None,
21
- num_samples=100,
22
- backend=None,
23
- backend_kwargs=None,
24
- y_kwargs=None,
25
- y_hat_plot_kwargs=None,
26
- y_mean_plot_kwargs=None,
27
- vline_kwargs=None,
28
- textsize=None,
29
- figsize=None,
30
- legend=True,
31
- axes=None,
32
- show=None,
33
- ):
34
- """Plot timeseries data.
35
-
36
- Parameters
37
- ----------
38
- idata : InferenceData
39
- :class:`arviz.InferenceData` object.
40
- y : str
41
- Variable name from ``observed_data``.
42
- Values to be plotted on y-axis before holdout.
43
- x : str, Optional
44
- Values to be plotted on x-axis before holdout.
45
- If None, coords of ``y`` dims is chosen.
46
- y_hat : str, optional
47
- Variable name from ``posterior_predictive``.
48
- Assumed to be of shape ``(chain, draw, *y_dims)``.
49
- y_holdout : str, optional
50
- Variable name from ``observed_data``.
51
- It represents the observed data after the holdout period.
52
- Useful while testing the model, when you want to compare
53
- observed test data with predictions/forecasts.
54
- y_forecasts : str, optional
55
- Variable name from ``posterior_predictive``.
56
- It represents forecasts (posterior predictive) values after holdout period.
57
- Useful to compare observed vs predictions/forecasts.
58
- Assumed shape ``(chain, draw, *shape)``.
59
- x_holdout : str, Defaults to coords of y.
60
- Variable name from ``constant_data``.
61
- If None, coords of ``y_holdout`` or
62
- coords of ``y_forecast`` (either of the two available) is chosen.
63
- plot_dim: str, Optional
64
- Should be present in ``y.dims``.
65
- Necessary for selection of ``x`` if ``x`` is None and ``y`` is multidimensional.
66
- holdout_dim: str, Optional
67
- Should be present in ``y_holdout.dims`` or ``y_forecats.dims``.
68
- Necessary to choose ``x_holdout`` if ``x`` is None and
69
- if ``y_holdout`` or ``y_forecasts`` is multidimensional.
70
- num_samples : int, default 100
71
- Number of posterior predictive samples drawn from ``y_hat`` and ``y_forecasts``.
72
- backend : {"matplotlib", "bokeh"}, default "matplotlib"
73
- Select plotting backend.
74
- y_kwargs : dict, optional
75
- Passed to :meth:`matplotlib.axes.Axes.plot` in matplotlib.
76
- y_hat_plot_kwargs : dict, optional
77
- Passed to :meth:`matplotlib.axes.Axes.plot` in matplotlib.
78
- y_mean_plot_kwargs : dict, optional
79
- Passed to :meth:`matplotlib.axes.Axes.plot` in matplotlib.
80
- vline_kwargs : dict, optional
81
- Passed to :meth:`matplotlib.axes.Axes.axvline` in matplotlib.
82
- backend_kwargs : dict, optional
83
- These are kwargs specific to the backend being used. Passed to
84
- :func:`matplotlib.pyplot.subplots`.
85
- figsize : tuple, optional
86
- Figure size. If None, it will be defined automatically.
87
- textsize : float, optional
88
- Text size scaling factor for labels, titles and lines. If None, it will be
89
- autoscaled based on ``figsize``.
90
-
91
-
92
- Returns
93
- -------
94
- axes: matplotlib axes or bokeh figures.
95
-
96
- See Also
97
- --------
98
- plot_lm : Posterior predictive and mean plots for regression-like data.
99
- plot_ppc : Plot for posterior/prior predictive checks.
100
-
101
- Examples
102
- --------
103
- Plot timeseries default plot
104
-
105
- .. plot::
106
- :context: close-figs
107
-
108
- >>> import arviz as az
109
- >>> nchains, ndraws = (4, 500)
110
- >>> obs_data = {
111
- ... "y": 2 * np.arange(1, 9) + 3,
112
- ... "z": 2 * np.arange(8, 12) + 3,
113
- ... }
114
- >>> posterior_predictive = {
115
- ... "y": np.random.normal(
116
- ... (obs_data["y"] * 1.2) - 3, size=(nchains, ndraws, len(obs_data["y"]))
117
- ... ),
118
- ... "z": np.random.normal(
119
- ... (obs_data["z"] * 1.2) - 3, size=(nchains, ndraws, len(obs_data["z"]))
120
- ... ),
121
- ... }
122
- >>> idata = az.from_dict(
123
- ... observed_data=obs_data,
124
- ... posterior_predictive=posterior_predictive,
125
- ... coords={"obs_dim": np.arange(1, 9), "pred_dim": np.arange(8, 12)},
126
- ... dims={"y": ["obs_dim"], "z": ["pred_dim"]},
127
- ... )
128
- >>> ax = az.plot_ts(idata=idata, y="y", y_holdout="z")
129
-
130
- Plot timeseries multidim plot
131
-
132
- .. plot::
133
- :context: close-figs
134
-
135
- >>> ndim1, ndim2 = (5, 7)
136
- >>> data = {
137
- ... "y": np.random.normal(size=(ndim1, ndim2)),
138
- ... "z": np.random.normal(size=(ndim1, ndim2)),
139
- ... }
140
- >>> posterior_predictive = {
141
- ... "y": np.random.randn(nchains, ndraws, ndim1, ndim2),
142
- ... "z": np.random.randn(nchains, ndraws, ndim1, ndim2),
143
- ... }
144
- >>> const_data = {"x": np.arange(1, 6), "x_pred": np.arange(5, 10)}
145
- >>> idata = az.from_dict(
146
- ... observed_data=data,
147
- ... posterior_predictive=posterior_predictive,
148
- ... constant_data=const_data,
149
- ... dims={
150
- ... "y": ["dim1", "dim2"],
151
- ... "z": ["holdout_dim1", "holdout_dim2"],
152
- ... },
153
- ... coords={
154
- ... "dim1": range(ndim1),
155
- ... "dim2": range(ndim2),
156
- ... "holdout_dim1": range(ndim1 - 1, ndim1 + 4),
157
- ... "holdout_dim2": range(ndim2 - 1, ndim2 + 6),
158
- ... },
159
- ... )
160
- >>> az.plot_ts(
161
- ... idata=idata,
162
- ... y="y",
163
- ... plot_dim="dim1",
164
- ... y_holdout="z",
165
- ... holdout_dim="holdout_dim1",
166
- ... )
167
-
168
- """
169
- # Assign default values if none is provided
170
- y_hat = y if y_hat is None and isinstance(y, str) else y_hat
171
- y_forecasts = y_holdout if y_forecasts is None and isinstance(y_holdout, str) else y_forecasts
172
- # holdout_dim = plot_dim if holdout_dim is None and plot_dim is not None else holdout_dim
173
-
174
- if isinstance(y, str):
175
- y = idata.observed_data[y]
176
-
177
- if isinstance(y_holdout, str):
178
- y_holdout = idata.observed_data[y_holdout]
179
-
180
- if len(y.dims) > 1 and plot_dim is None:
181
- raise ValueError("Argument plot_dim is needed in case of multidimensional data")
182
-
183
- if y_holdout is not None and len(y_holdout.dims) > 1 and holdout_dim is None:
184
- raise ValueError("Argument holdout_dim is needed in case of multidimensional data")
185
-
186
- # Assigning values to x
187
- x_var_names = None
188
- if isinstance(x, str):
189
- x = idata.constant_data[x]
190
- elif isinstance(x, tuple):
191
- x_var_names = x
192
- x = idata.constant_data
193
- elif x is None:
194
- if plot_dim is None:
195
- x = y.coords[y.dims[0]]
196
- else:
197
- x = y.coords[plot_dim]
198
-
199
- # If posterior_predictive is present in idata and y_hat is there, get its values
200
- if isinstance(y_hat, str):
201
- if "posterior_predictive" not in idata.groups():
202
- warnings.warn("posterior_predictive not found in idata", UserWarning)
203
- y_hat = None
204
- elif hasattr(idata.posterior_predictive, y_hat):
205
- y_hat = idata.posterior_predictive[y_hat]
206
- else:
207
- warnings.warn("y_hat not found in posterior_predictive", UserWarning)
208
- y_hat = None
209
-
210
- # If posterior_predictive is present in idata and y_forecasts is there, get its values
211
- x_holdout_var_names = None
212
- if isinstance(y_forecasts, str):
213
- if "posterior_predictive" not in idata.groups():
214
- warnings.warn("posterior_predictive not found in idata", UserWarning)
215
- y_forecasts = None
216
- elif hasattr(idata.posterior_predictive, y_forecasts):
217
- y_forecasts = idata.posterior_predictive[y_forecasts]
218
- else:
219
- warnings.warn("y_hat not found in posterior_predictive", UserWarning)
220
- y_forecasts = None
221
-
222
- # Assign values to y_holdout
223
- if isinstance(y_holdout, str):
224
- y_holdout = idata.observed_data[y_holdout]
225
-
226
- # Assign values to x_holdout.
227
- if y_holdout is not None or y_forecasts is not None:
228
- if x_holdout is None:
229
- if holdout_dim is None:
230
- if y_holdout is None:
231
- x_holdout = y_forecasts.coords[y_forecasts.dims[-1]]
232
- else:
233
- x_holdout = y_holdout.coords[y_holdout.dims[-1]]
234
- elif y_holdout is None:
235
- x_holdout = y_forecasts.coords[holdout_dim]
236
- else:
237
- x_holdout = y_holdout.coords[holdout_dim]
238
- elif isinstance(x_holdout, str):
239
- x_holdout = idata.constant_data[x_holdout]
240
- elif isinstance(x_holdout, tuple):
241
- x_holdout_var_names = x_holdout
242
- x_holdout = idata.constant_data
243
-
244
- # Choose dims to generate y plotters
245
- if plot_dim is None:
246
- skip_dims = list(y.dims)
247
- elif isinstance(plot_dim, str):
248
- skip_dims = [plot_dim]
249
- elif isinstance(plot_dim, tuple):
250
- skip_dims = list(plot_dim)
251
-
252
- # Choose dims to generate y_holdout plotters
253
- if holdout_dim is None:
254
- if y_holdout is not None:
255
- skip_holdout_dims = list(y_holdout.dims)
256
- elif y_forecasts is not None:
257
- skip_holdout_dims = list(y_forecasts.dims)
258
- elif isinstance(holdout_dim, str):
259
- skip_holdout_dims = [holdout_dim]
260
- elif isinstance(holdout_dim, tuple):
261
- skip_holdout_dims = list(holdout_dim)
262
-
263
- # Compulsory plotters
264
- y_plotters = list(
265
- xarray_var_iter(
266
- y,
267
- skip_dims=set(skip_dims),
268
- combined=True,
269
- )
270
- )
271
-
272
- # Compulsory plotters
273
- x_plotters = list(
274
- xarray_var_iter(
275
- x,
276
- var_names=x_var_names,
277
- skip_dims=set(x.dims),
278
- combined=True,
279
- )
280
- )
281
- # Necessary when multidim y
282
- # If there are multiple x and multidimensional y, we need total of len(x)*len(y) graphs
283
- len_y = len(y_plotters)
284
- len_x = len(x_plotters)
285
- length_plotters = len_x * len_y
286
- # TODO: Incompatible types in assignment (expression has type "ndarray[Any, dtype[Any]]",
287
- # TODO: variable has type "List[Any]") [assignment]
288
- y_plotters = np.tile(np.array(y_plotters, dtype=object), (len_x, 1)) # type: ignore[assignment]
289
- x_plotters = np.tile(np.array(x_plotters, dtype=object), (len_y, 1)) # type: ignore[assignment]
290
-
291
- # Generate plotters for all the available data
292
- y_mean_plotters = None
293
- y_hat_plotters = None
294
- if y_hat is not None:
295
- total_samples = y_hat.sizes["chain"] * y_hat.sizes["draw"]
296
- pp_sample_ix = np.random.choice(total_samples, size=num_samples, replace=False)
297
-
298
- y_hat_satcked = y_hat.stack(__sample__=("chain", "draw"))[..., pp_sample_ix]
299
-
300
- y_hat_plotters = list(
301
- xarray_var_iter(
302
- y_hat_satcked,
303
- skip_dims=set(skip_dims + ["__sample__"]),
304
- combined=True,
305
- )
306
- )
307
-
308
- y_mean = y_hat.mean(("chain", "draw"))
309
- y_mean_plotters = list(
310
- xarray_var_iter(
311
- y_mean,
312
- skip_dims=set(skip_dims),
313
- combined=True,
314
- )
315
- )
316
-
317
- # Necessary when multidim y
318
- # If there are multiple x and multidimensional y, we need total of len(x)*len(y) graphs
319
- # TODO: Incompatible types in assignment (expression has type "ndarray[Any, dtype[Any]]",
320
- # TODO: variable has type "List[Any]") [assignment]
321
- y_hat_plotters = np.tile(
322
- np.array(y_hat_plotters, dtype=object), (len_x, 1)
323
- ) # type: ignore[assignment]
324
- y_mean_plotters = np.tile(
325
- np.array(y_mean_plotters, dtype=object), (len_x, 1)
326
- ) # type: ignore[assignment]
327
-
328
- y_holdout_plotters = None
329
- x_holdout_plotters = None
330
- if y_holdout is not None:
331
- y_holdout_plotters = list(
332
- xarray_var_iter(
333
- y_holdout,
334
- skip_dims=set(skip_holdout_dims),
335
- combined=True,
336
- )
337
- )
338
-
339
- x_holdout_plotters = list(
340
- xarray_var_iter(
341
- x_holdout,
342
- var_names=x_holdout_var_names,
343
- skip_dims=set(x_holdout.dims),
344
- combined=True,
345
- )
346
- )
347
-
348
- # Necessary when multidim y
349
- # If there are multiple x and multidimensional y, we need total of len(x)*len(y) graphs
350
- # TODO: Incompatible types in assignment (expression has type "ndarray[Any, dtype[Any]]",
351
- # TODO: variable has type "List[Any]") [assignment]
352
- y_holdout_plotters = np.tile(
353
- np.array(y_holdout_plotters, dtype=object), (len_x, 1)
354
- ) # type: ignore[assignment]
355
- x_holdout_plotters = np.tile(
356
- np.array(x_holdout_plotters, dtype=object), (len_y, 1)
357
- ) # type: ignore[assignment]
358
-
359
- y_forecasts_plotters = None
360
- y_forecasts_mean_plotters = None
361
- if y_forecasts is not None:
362
- total_samples = y_forecasts.sizes["chain"] * y_forecasts.sizes["draw"]
363
- pp_sample_ix = np.random.choice(total_samples, size=num_samples, replace=False)
364
-
365
- y_forecasts_satcked = y_forecasts.stack(__sample__=("chain", "draw"))[..., pp_sample_ix]
366
-
367
- y_forecasts_plotters = list(
368
- xarray_var_iter(
369
- y_forecasts_satcked,
370
- skip_dims=set(skip_holdout_dims + ["__sample__"]),
371
- combined=True,
372
- )
373
- )
374
-
375
- y_forecasts_mean = y_forecasts.mean(("chain", "draw"))
376
- y_forecasts_mean_plotters = list(
377
- xarray_var_iter(
378
- y_forecasts_mean,
379
- skip_dims=set(skip_holdout_dims),
380
- combined=True,
381
- )
382
- )
383
-
384
- x_holdout_plotters = list(
385
- xarray_var_iter(
386
- x_holdout,
387
- var_names=x_holdout_var_names,
388
- skip_dims=set(x_holdout.dims),
389
- combined=True,
390
- )
391
- )
392
-
393
- # Necessary when multidim y
394
- # If there are multiple x and multidimensional y, we need total of len(x)*len(y) graphs
395
- # TODO: Incompatible types in assignment (expression has type "ndarray[Any, dtype[Any]]",
396
- # TODO: variable has type "List[Any]") [assignment]
397
- y_forecasts_mean_plotters = np.tile(
398
- np.array(y_forecasts_mean_plotters, dtype=object), (len_x, 1)
399
- ) # type: ignore[assignment]
400
- y_forecasts_plotters = np.tile(
401
- np.array(y_forecasts_plotters, dtype=object), (len_x, 1)
402
- ) # type: ignore[assignment]
403
- x_holdout_plotters = np.tile(
404
- np.array(x_holdout_plotters, dtype=object), (len_y, 1)
405
- ) # type: ignore[assignment]
406
-
407
- rows, cols = default_grid(length_plotters)
408
-
409
- tsplot_kwargs = dict(
410
- x_plotters=x_plotters,
411
- y_plotters=y_plotters,
412
- y_mean_plotters=y_mean_plotters,
413
- y_hat_plotters=y_hat_plotters,
414
- y_holdout_plotters=y_holdout_plotters,
415
- x_holdout_plotters=x_holdout_plotters,
416
- y_forecasts_plotters=y_forecasts_plotters,
417
- y_forecasts_mean_plotters=y_forecasts_mean_plotters,
418
- num_samples=num_samples,
419
- length_plotters=length_plotters,
420
- rows=rows,
421
- cols=cols,
422
- backend_kwargs=backend_kwargs,
423
- y_kwargs=y_kwargs,
424
- y_hat_plot_kwargs=y_hat_plot_kwargs,
425
- y_mean_plot_kwargs=y_mean_plot_kwargs,
426
- vline_kwargs=vline_kwargs,
427
- textsize=textsize,
428
- figsize=figsize,
429
- legend=legend,
430
- axes=axes,
431
- show=show,
432
- )
433
-
434
- if backend is None:
435
- backend = rcParams["plot.backend"]
436
- backend = backend.lower()
437
-
438
- plot = get_plotting_function("plot_ts", "tsplot", backend)
439
- ax = plot(**tsplot_kwargs)
440
- return ax
arviz/plots/violinplot.py DELETED
@@ -1,192 +0,0 @@
1
- """Plot posterior traces as violin plot."""
2
-
3
- from ..data import convert_to_dataset
4
- from ..labels import BaseLabeller
5
- from ..sel_utils import xarray_var_iter
6
- from ..utils import _var_names
7
- from ..rcparams import rcParams
8
- from .plot_utils import default_grid, filter_plotters_list, get_plotting_function
9
-
10
-
11
- def plot_violin(
12
- data,
13
- var_names=None,
14
- combine_dims=None,
15
- filter_vars=None,
16
- transform=None,
17
- quartiles=True,
18
- rug=False,
19
- side="both",
20
- hdi_prob=None,
21
- shade=0.35,
22
- bw="default",
23
- circular=False,
24
- sharex=True,
25
- sharey=True,
26
- grid=None,
27
- figsize=None,
28
- textsize=None,
29
- labeller=None,
30
- ax=None,
31
- shade_kwargs=None,
32
- rug_kwargs=None,
33
- backend=None,
34
- backend_kwargs=None,
35
- show=None,
36
- ):
37
- """Plot posterior of traces as violin plot.
38
-
39
- Notes
40
- -----
41
- If multiple chains are provided for a variable they will be combined
42
-
43
- Parameters
44
- ----------
45
- data: obj
46
- Any object that can be converted to an :class:`arviz.InferenceData` object
47
- Refer to documentation of :func:`arviz.convert_to_dataset` for details
48
- var_names: list of variable names, optional
49
- Variables to be plotted, if None all variable are plotted. Prefix the
50
- variables by ``~`` when you want to exclude them from the plot.
51
- combine_dims : set_like of str, optional
52
- List of dimensions to reduce. Defaults to reducing only the "chain" and "draw" dimensions.
53
- See the :ref:`this section <common_combine_dims>` for usage examples.
54
- filter_vars: {None, "like", "regex"}, optional, default=None
55
- If `None` (default), interpret var_names as the real variables names. If "like",
56
- interpret var_names as substrings of the real variables names. If "regex",
57
- interpret var_names as regular expressions on the real variables names. A la
58
- ``pandas.filter``.
59
- transform: callable
60
- Function to transform data (defaults to None i.e. the identity function).
61
- quartiles: bool, optional
62
- Flag for plotting the interquartile range, in addition to the ``hdi_prob`` * 100%
63
- intervals. Defaults to ``True``.
64
- rug: bool
65
- If ``True`` adds a jittered rugplot. Defaults to ``False``.
66
- side : {"both", "left", "right"}, default "both"
67
- If ``both``, both sides of the violin plot are rendered. If ``left`` or ``right``, only
68
- the respective side is rendered. By separately plotting left and right halfs with
69
- different data, split violin plots can be achieved.
70
- hdi_prob: float, optional
71
- Plots highest posterior density interval for chosen percentage of density.
72
- Defaults to 0.94.
73
- shade: float
74
- Alpha blending value for the shaded area under the curve, between 0
75
- (no shade) and 1 (opaque). Defaults to 0.
76
- bw: float or str, optional
77
- If numeric, indicates the bandwidth and must be positive.
78
- If str, indicates the method to estimate the bandwidth and must be
79
- one of "scott", "silverman", "isj" or "experimental" when ``circular`` is ``False``
80
- and "taylor" (for now) when ``circular`` is ``True``.
81
- Defaults to "default" which means "experimental" when variable is not circular
82
- and "taylor" when it is.
83
- circular: bool, optional.
84
- If ``True``, it interprets `values` is a circular variable measured in radians
85
- and a circular KDE is used. Defaults to ``False``.
86
- grid : tuple
87
- Number of rows and columns. Defaults to None, the rows and columns are
88
- automatically inferred.
89
- figsize: tuple
90
- Figure size. If None it will be defined automatically.
91
- textsize: int
92
- Text size of the point_estimates, axis ticks, and highest density interval. If None it will
93
- be autoscaled based on ``figsize``.
94
- labeller : labeller instance, optional
95
- Class providing the method ``make_label_vert`` to generate the labels in the plot titles.
96
- Read the :ref:`label_guide` for more details and usage examples.
97
- sharex: bool
98
- Defaults to ``True``, violinplots share a common x-axis scale.
99
- sharey: bool
100
- Defaults to ``True``, violinplots share a common y-axis scale.
101
- ax: numpy array-like of matplotlib axes or bokeh figures, optional
102
- A 2D array of locations into which to plot the densities. If not supplied, Arviz will create
103
- its own array of plot areas (and return it).
104
- shade_kwargs: dicts, optional
105
- Additional keywords passed to :meth:`matplotlib.axes.Axes.fill_between`, or
106
- :meth:`matplotlib.axes.Axes.barh` to control the shade.
107
- rug_kwargs: dict
108
- Keywords passed to the rug plot. If true only the right half side of the violin will be
109
- plotted.
110
- backend: str, optional
111
- Select plotting backend {"matplotlib","bokeh"}. Default to "matplotlib".
112
- backend_kwargs: bool, optional
113
- These are kwargs specific to the backend being used, passed to
114
- :func:`matplotlib.pyplot.subplots` or :func:`bokeh.plotting.figure`.
115
- For additional documentation check the plotting method of the backend.
116
- show: bool, optional
117
- Call backend show function.
118
-
119
- Returns
120
- -------
121
- axes: matplotlib axes or bokeh figures
122
-
123
- See Also
124
- --------
125
- plot_forest: Forest plot to compare HDI intervals from a number of distributions.
126
-
127
- Examples
128
- --------
129
- Show a default violin plot
130
-
131
- .. plot::
132
- :context: close-figs
133
-
134
- >>> import arviz as az
135
- >>> data = az.load_arviz_data('centered_eight')
136
- >>> az.plot_violin(data)
137
-
138
- """
139
- if labeller is None:
140
- labeller = BaseLabeller()
141
-
142
- data = convert_to_dataset(data, group="posterior")
143
- if transform is not None:
144
- data = transform(data)
145
- var_names = _var_names(var_names, data, filter_vars)
146
-
147
- plotters = filter_plotters_list(
148
- list(xarray_var_iter(data, var_names=var_names, combined=True, skip_dims=combine_dims)),
149
- "plot_violin",
150
- )
151
-
152
- rows, cols = default_grid(len(plotters), grid=grid)
153
-
154
- if hdi_prob is None:
155
- hdi_prob = rcParams["stats.ci_prob"]
156
- elif not 1 >= hdi_prob > 0:
157
- raise ValueError("The value of hdi_prob should be in the interval (0, 1]")
158
-
159
- violinplot_kwargs = dict(
160
- ax=ax,
161
- plotters=plotters,
162
- figsize=figsize,
163
- rows=rows,
164
- cols=cols,
165
- sharex=sharex,
166
- sharey=sharey,
167
- shade_kwargs=shade_kwargs,
168
- shade=shade,
169
- rug=rug,
170
- rug_kwargs=rug_kwargs,
171
- side=side,
172
- bw=bw,
173
- textsize=textsize,
174
- labeller=labeller,
175
- circular=circular,
176
- hdi_prob=hdi_prob,
177
- quartiles=quartiles,
178
- backend_kwargs=backend_kwargs,
179
- show=show,
180
- )
181
-
182
- if backend is None:
183
- backend = rcParams["plot.backend"]
184
- backend = backend.lower()
185
-
186
- if side not in ("both", "left", "right"):
187
- raise ValueError(f"'side' can only be 'both', 'left', or 'right', got: '{side}'")
188
-
189
- # TODO: Add backend kwargs
190
- plot = get_plotting_function("plot_violin", "violinplot", backend)
191
- ax = plot(**violinplot_kwargs)
192
- return ax
arviz/preview.py DELETED
@@ -1,58 +0,0 @@
1
- # pylint: disable=unused-import,unused-wildcard-import,wildcard-import,invalid-name
2
- """Expose features from arviz-xyz refactored packages inside ``arviz.preview`` namespace."""
3
- import logging
4
-
5
- _log = logging.getLogger(__name__)
6
-
7
- info = ""
8
-
9
- try:
10
- from arviz_base import *
11
- import arviz_base as base
12
-
13
- _status = "arviz_base available, exposing its functions as part of arviz.preview"
14
- _log.info(_status)
15
- except ModuleNotFoundError:
16
- _status = "arviz_base not installed"
17
- _log.info(_status)
18
- except ImportError:
19
- _status = "Unable to import arviz_base"
20
- _log.info(_status, exc_info=True)
21
-
22
- info += _status + "\n"
23
-
24
- try:
25
- from arviz_stats import *
26
-
27
- # the base computational module fron arviz_stats will override the alias to arviz-base
28
- # arviz.stats.base will still be available
29
- import arviz_base as base
30
- import arviz_stats as stats
31
-
32
- _status = "arviz_stats available, exposing its functions as part of arviz.preview"
33
- _log.info(_status)
34
- except ModuleNotFoundError:
35
- _status = "arviz_stats not installed"
36
- _log.info(_status)
37
- except ImportError:
38
- _status = "Unable to import arviz_stats"
39
- _log.info(_status, exc_info=True)
40
- info += _status + "\n"
41
-
42
- try:
43
- from arviz_plots import *
44
- import arviz_plots as plots
45
-
46
- _status = "arviz_plots available, exposing its functions as part of arviz.preview"
47
- _log.info(_status)
48
- except ModuleNotFoundError:
49
- _status = "arviz_plots not installed"
50
- _log.info(_status)
51
- except ImportError:
52
- _status = "Unable to import arviz_plots"
53
- _log.info(_status, exc_info=True)
54
-
55
- info += _status + "\n"
56
-
57
- # clean namespace
58
- del logging, _status, _log
arviz/py.typed DELETED
File without changes