arviz 0.23.1__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185)
  1. arviz/__init__.py +52 -357
  2. arviz-1.0.0rc0.dist-info/METADATA +182 -0
  3. arviz-1.0.0rc0.dist-info/RECORD +5 -0
  4. {arviz-0.23.1.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
  5. {arviz-0.23.1.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
  6. arviz/data/__init__.py +0 -55
  7. arviz/data/base.py +0 -596
  8. arviz/data/converters.py +0 -203
  9. arviz/data/datasets.py +0 -161
  10. arviz/data/example_data/code/radon/radon.json +0 -326
  11. arviz/data/example_data/data/centered_eight.nc +0 -0
  12. arviz/data/example_data/data/non_centered_eight.nc +0 -0
  13. arviz/data/example_data/data_local.json +0 -12
  14. arviz/data/example_data/data_remote.json +0 -58
  15. arviz/data/inference_data.py +0 -2386
  16. arviz/data/io_beanmachine.py +0 -112
  17. arviz/data/io_cmdstan.py +0 -1036
  18. arviz/data/io_cmdstanpy.py +0 -1233
  19. arviz/data/io_datatree.py +0 -23
  20. arviz/data/io_dict.py +0 -462
  21. arviz/data/io_emcee.py +0 -317
  22. arviz/data/io_json.py +0 -54
  23. arviz/data/io_netcdf.py +0 -68
  24. arviz/data/io_numpyro.py +0 -497
  25. arviz/data/io_pyjags.py +0 -378
  26. arviz/data/io_pyro.py +0 -333
  27. arviz/data/io_pystan.py +0 -1095
  28. arviz/data/io_zarr.py +0 -46
  29. arviz/data/utils.py +0 -139
  30. arviz/labels.py +0 -210
  31. arviz/plots/__init__.py +0 -61
  32. arviz/plots/autocorrplot.py +0 -171
  33. arviz/plots/backends/__init__.py +0 -223
  34. arviz/plots/backends/bokeh/__init__.py +0 -166
  35. arviz/plots/backends/bokeh/autocorrplot.py +0 -101
  36. arviz/plots/backends/bokeh/bfplot.py +0 -23
  37. arviz/plots/backends/bokeh/bpvplot.py +0 -193
  38. arviz/plots/backends/bokeh/compareplot.py +0 -167
  39. arviz/plots/backends/bokeh/densityplot.py +0 -239
  40. arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
  41. arviz/plots/backends/bokeh/distplot.py +0 -183
  42. arviz/plots/backends/bokeh/dotplot.py +0 -113
  43. arviz/plots/backends/bokeh/ecdfplot.py +0 -73
  44. arviz/plots/backends/bokeh/elpdplot.py +0 -203
  45. arviz/plots/backends/bokeh/energyplot.py +0 -155
  46. arviz/plots/backends/bokeh/essplot.py +0 -176
  47. arviz/plots/backends/bokeh/forestplot.py +0 -772
  48. arviz/plots/backends/bokeh/hdiplot.py +0 -54
  49. arviz/plots/backends/bokeh/kdeplot.py +0 -268
  50. arviz/plots/backends/bokeh/khatplot.py +0 -163
  51. arviz/plots/backends/bokeh/lmplot.py +0 -185
  52. arviz/plots/backends/bokeh/loopitplot.py +0 -211
  53. arviz/plots/backends/bokeh/mcseplot.py +0 -184
  54. arviz/plots/backends/bokeh/pairplot.py +0 -328
  55. arviz/plots/backends/bokeh/parallelplot.py +0 -81
  56. arviz/plots/backends/bokeh/posteriorplot.py +0 -324
  57. arviz/plots/backends/bokeh/ppcplot.py +0 -379
  58. arviz/plots/backends/bokeh/rankplot.py +0 -149
  59. arviz/plots/backends/bokeh/separationplot.py +0 -107
  60. arviz/plots/backends/bokeh/traceplot.py +0 -436
  61. arviz/plots/backends/bokeh/violinplot.py +0 -164
  62. arviz/plots/backends/matplotlib/__init__.py +0 -124
  63. arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
  64. arviz/plots/backends/matplotlib/bfplot.py +0 -78
  65. arviz/plots/backends/matplotlib/bpvplot.py +0 -177
  66. arviz/plots/backends/matplotlib/compareplot.py +0 -135
  67. arviz/plots/backends/matplotlib/densityplot.py +0 -194
  68. arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
  69. arviz/plots/backends/matplotlib/distplot.py +0 -178
  70. arviz/plots/backends/matplotlib/dotplot.py +0 -116
  71. arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
  72. arviz/plots/backends/matplotlib/elpdplot.py +0 -189
  73. arviz/plots/backends/matplotlib/energyplot.py +0 -113
  74. arviz/plots/backends/matplotlib/essplot.py +0 -180
  75. arviz/plots/backends/matplotlib/forestplot.py +0 -656
  76. arviz/plots/backends/matplotlib/hdiplot.py +0 -48
  77. arviz/plots/backends/matplotlib/kdeplot.py +0 -177
  78. arviz/plots/backends/matplotlib/khatplot.py +0 -241
  79. arviz/plots/backends/matplotlib/lmplot.py +0 -149
  80. arviz/plots/backends/matplotlib/loopitplot.py +0 -144
  81. arviz/plots/backends/matplotlib/mcseplot.py +0 -161
  82. arviz/plots/backends/matplotlib/pairplot.py +0 -355
  83. arviz/plots/backends/matplotlib/parallelplot.py +0 -58
  84. arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
  85. arviz/plots/backends/matplotlib/ppcplot.py +0 -478
  86. arviz/plots/backends/matplotlib/rankplot.py +0 -119
  87. arviz/plots/backends/matplotlib/separationplot.py +0 -97
  88. arviz/plots/backends/matplotlib/traceplot.py +0 -526
  89. arviz/plots/backends/matplotlib/tsplot.py +0 -121
  90. arviz/plots/backends/matplotlib/violinplot.py +0 -148
  91. arviz/plots/bfplot.py +0 -128
  92. arviz/plots/bpvplot.py +0 -308
  93. arviz/plots/compareplot.py +0 -177
  94. arviz/plots/densityplot.py +0 -284
  95. arviz/plots/distcomparisonplot.py +0 -197
  96. arviz/plots/distplot.py +0 -233
  97. arviz/plots/dotplot.py +0 -233
  98. arviz/plots/ecdfplot.py +0 -372
  99. arviz/plots/elpdplot.py +0 -174
  100. arviz/plots/energyplot.py +0 -147
  101. arviz/plots/essplot.py +0 -319
  102. arviz/plots/forestplot.py +0 -304
  103. arviz/plots/hdiplot.py +0 -211
  104. arviz/plots/kdeplot.py +0 -357
  105. arviz/plots/khatplot.py +0 -236
  106. arviz/plots/lmplot.py +0 -380
  107. arviz/plots/loopitplot.py +0 -224
  108. arviz/plots/mcseplot.py +0 -194
  109. arviz/plots/pairplot.py +0 -281
  110. arviz/plots/parallelplot.py +0 -204
  111. arviz/plots/plot_utils.py +0 -599
  112. arviz/plots/posteriorplot.py +0 -298
  113. arviz/plots/ppcplot.py +0 -369
  114. arviz/plots/rankplot.py +0 -232
  115. arviz/plots/separationplot.py +0 -167
  116. arviz/plots/styles/arviz-bluish.mplstyle +0 -1
  117. arviz/plots/styles/arviz-brownish.mplstyle +0 -1
  118. arviz/plots/styles/arviz-colors.mplstyle +0 -2
  119. arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
  120. arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
  121. arviz/plots/styles/arviz-doc.mplstyle +0 -88
  122. arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
  123. arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
  124. arviz/plots/styles/arviz-greenish.mplstyle +0 -1
  125. arviz/plots/styles/arviz-orangish.mplstyle +0 -1
  126. arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
  127. arviz/plots/styles/arviz-purplish.mplstyle +0 -1
  128. arviz/plots/styles/arviz-redish.mplstyle +0 -1
  129. arviz/plots/styles/arviz-royish.mplstyle +0 -1
  130. arviz/plots/styles/arviz-viridish.mplstyle +0 -1
  131. arviz/plots/styles/arviz-white.mplstyle +0 -40
  132. arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
  133. arviz/plots/traceplot.py +0 -273
  134. arviz/plots/tsplot.py +0 -440
  135. arviz/plots/violinplot.py +0 -192
  136. arviz/preview.py +0 -58
  137. arviz/py.typed +0 -0
  138. arviz/rcparams.py +0 -606
  139. arviz/sel_utils.py +0 -223
  140. arviz/static/css/style.css +0 -340
  141. arviz/static/html/icons-svg-inline.html +0 -15
  142. arviz/stats/__init__.py +0 -37
  143. arviz/stats/density_utils.py +0 -1013
  144. arviz/stats/diagnostics.py +0 -1013
  145. arviz/stats/ecdf_utils.py +0 -324
  146. arviz/stats/stats.py +0 -2422
  147. arviz/stats/stats_refitting.py +0 -119
  148. arviz/stats/stats_utils.py +0 -609
  149. arviz/tests/__init__.py +0 -1
  150. arviz/tests/base_tests/__init__.py +0 -1
  151. arviz/tests/base_tests/test_data.py +0 -1679
  152. arviz/tests/base_tests/test_data_zarr.py +0 -143
  153. arviz/tests/base_tests/test_diagnostics.py +0 -511
  154. arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
  155. arviz/tests/base_tests/test_helpers.py +0 -18
  156. arviz/tests/base_tests/test_labels.py +0 -69
  157. arviz/tests/base_tests/test_plot_utils.py +0 -342
  158. arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
  159. arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
  160. arviz/tests/base_tests/test_rcparams.py +0 -317
  161. arviz/tests/base_tests/test_stats.py +0 -925
  162. arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
  163. arviz/tests/base_tests/test_stats_numba.py +0 -45
  164. arviz/tests/base_tests/test_stats_utils.py +0 -384
  165. arviz/tests/base_tests/test_utils.py +0 -376
  166. arviz/tests/base_tests/test_utils_numba.py +0 -87
  167. arviz/tests/conftest.py +0 -46
  168. arviz/tests/external_tests/__init__.py +0 -1
  169. arviz/tests/external_tests/test_data_beanmachine.py +0 -78
  170. arviz/tests/external_tests/test_data_cmdstan.py +0 -398
  171. arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
  172. arviz/tests/external_tests/test_data_emcee.py +0 -166
  173. arviz/tests/external_tests/test_data_numpyro.py +0 -434
  174. arviz/tests/external_tests/test_data_pyjags.py +0 -119
  175. arviz/tests/external_tests/test_data_pyro.py +0 -260
  176. arviz/tests/external_tests/test_data_pystan.py +0 -307
  177. arviz/tests/helpers.py +0 -677
  178. arviz/utils.py +0 -773
  179. arviz/wrappers/__init__.py +0 -13
  180. arviz/wrappers/base.py +0 -236
  181. arviz/wrappers/wrap_pymc.py +0 -36
  182. arviz/wrappers/wrap_stan.py +0 -148
  183. arviz-0.23.1.dist-info/METADATA +0 -263
  184. arviz-0.23.1.dist-info/RECORD +0 -183
  185. arviz-0.23.1.dist-info/top_level.txt +0 -1
arviz/plots/plot_utils.py DELETED
@@ -1,599 +0,0 @@
1
- """Utilities for plotting."""
2
-
3
- import importlib
4
- import warnings
5
- from typing import Any, Dict
6
-
7
- import matplotlib as mpl
8
- import numpy as np
9
- import packaging
10
- from matplotlib.colors import to_hex
11
- from scipy.stats import mode, rankdata
12
- from scipy.interpolate import CubicSpline
13
-
14
-
15
- from ..rcparams import rcParams
16
- from ..stats.density_utils import kde
17
- from ..stats import hdi
18
-
19
# Type alias for a keyword-argument mapping: string keys, arbitrary values.
KwargSpec = Dict[str, Any]
20
-
21
-
22
def make_2d(ary):
    """Convert any array into a 2d numpy array.

    In case the array is already more than 2 dimensional, will ravel the
    dimensions after the first (in Fortran order). Scalars become a ``(1, 1)``
    array and 1d inputs become a single column.

    Parameters
    ----------
    ary : array_like
        Input data; lists, tuples and scalars are accepted as well as ndarrays.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(ary.shape[0], -1)``.
    """
    # Reshape the converted array rather than the raw input: plain lists and
    # Python scalars have no ``reshape`` method.
    ary = np.atleast_1d(ary)
    dim_0, *_ = ary.shape
    return ary.reshape(dim_0, -1, order="F")
30
-
31
-
32
def _scale_fig_size(figsize, textsize, rows=1, cols=1):
    """Scale figure properties according to rows and cols.

    Base values are read from ``matplotlib.rcParams``. Size entries stored as
    strings (e.g. ``"medium"``) are replaced by fixed fallbacks: 15 for axes
    labels, 16 for titles, 14 for tick labels.

    Parameters
    ----------
    figsize : tuple of (float, float) or None
        ``(width, height)`` in inches. If None, the rcParams figure size is
        multiplied by ``cols``/``rows``, with an extra 1.15 factor when there is
        more than one panel.
    textsize : float or None
        Target tick-label font size. If given, every size is scaled by
        ``textsize / xtick.labelsize``. If None, sizes scale with the figure
        area for a single panel, and are left unscaled otherwise.
    rows : int
        Number of rows.
    cols : int
        Number of columns.

    Returns
    -------
    figsize : tuple of (float, float)
        ``(width, height)`` in inches.
    ax_labelsize : float
        Font size for axes labels.
    titlesize : float
        Font size for titles.
    xt_labelsize : float
        Font size for axes ticks.
    linewidth : float
        Line width.
    markersize : float
        Marker size.
    """
    rc = mpl.rcParams
    base_width, base_height = tuple(rc["figure.figsize"])

    def _numeric(key, fallback):
        # rcParams may hold named sizes such as "medium"; use a number instead.
        setting = rc[key]
        return fallback if isinstance(setting, str) else setting

    base_label = _numeric("axes.labelsize", 15)
    base_title = _numeric("axes.titlesize", 16)
    base_ticks = _numeric("xtick.labelsize", 14)
    base_lw = rc["lines.linewidth"]
    base_ms = rc["lines.markersize"]

    single_panel = rows == cols == 1
    if figsize is not None:
        width, height = figsize
    else:
        growth = 1 if single_panel else 1.15
        width = base_width * cols * growth
        height = base_height * rows * growth

    if textsize is not None:
        factor = textsize / base_ticks
    elif single_panel:
        factor = ((width * height) / (base_width * base_height)) ** 0.5
    else:
        factor = 1

    return (
        (width, height),
        base_label * factor,
        base_title * factor,
        base_ticks * factor,
        base_lw * factor,
        base_ms * factor,
    )
97
-
98
-
99
def default_grid(n_items, grid=None, max_cols=4, min_cols=3):
    """Make a grid for subplots.

    Without an explicit ``grid``, tries to get as close to a square layout
    (about ``sqrt(n_items)`` columns, clipped to ``[min_cols, max_cols]``) as
    possible, preferring column counts that divide ``n_items`` evenly.

    Parameters
    ----------
    n_items : int
        Number of panels required.
    grid : tuple of (int, int), optional
        Explicit number of rows and columns; validated and returned as given.
    max_cols : int
        Maximum number of columns, inclusive.
    min_cols : int
        Minimum number of columns, inclusive.

    Returns
    -------
    (int, int)
        Rows and columns, so that rows * columns >= n_items. When ``grid`` is
        None and ``n_items <= max_cols``, a single row is returned.

    Raises
    ------
    ValueError
        If an explicit ``grid`` has fewer cells than ``n_items``.

    Warns
    -----
    UserWarning
        If an explicit ``grid`` leaves at least one whole row empty.
    """
    if grid is not None:
        rows, cols = grid
        capacity = rows * cols
        if capacity < n_items:
            raise ValueError(
                "The number of rows times columns is less than the number of subplots"
            )
        if capacity - n_items >= cols:
            warnings.warn("The number of rows times columns is larger than necessary")
        return rows, cols

    if n_items <= max_cols:
        return 1, n_items

    ideal = np.clip(round(n_items**0.5), min_cols, max_cols)
    # Try the ideal column count first, then its neighbours, looking for an
    # exact fit with no empty cells.
    for shift in (0, 1, -1, 2, -2):
        n_cols = np.clip(ideal + shift, min_cols, max_cols)
        n_rows, leftover = divmod(n_items, n_cols)
        if not leftover:
            return n_rows, n_cols
    return n_items // ideal + 1, ideal
144
-
145
-
146
def format_sig_figs(value, default=None):
    """Get a default number of significant figures.

    Gives the number of digits of the integer part of ``value``
    (``int(log10(|value|)) + 1``) or ``default``, whichever is bigger.

    Parameters
    ----------
    value : float
        Number whose magnitude determines the count.
    default : int, optional
        Minimum number of significant figures. Defaults to 2.

    Returns
    -------
    int
        ``1`` when ``value == 0``; ``default`` when ``value`` is NaN or
        infinite; otherwise the larger of the integer-part digit count and
        ``default``.

    Examples
    --------
    Returned count (and the resulting rounding of ``value``) with the default:

    0.1234 --> 2 (0.12)
    1.234 --> 2 (1.2)
    12.34 --> 2 (12)
    123.4 --> 3 (123)
    """
    if default is None:
        default = 2
    if value == 0:
        return 1
    if not np.isfinite(value):
        # log10 of NaN/inf cannot be converted to int (ValueError/OverflowError),
        # so fall back to the default precision instead of crashing.
        return default
    return max(int(np.log10(np.abs(value))) + 1, default)
163
-
164
-
165
def round_num(n, round_to):
    """Return a string representing a number with `round_to` significant figures.

    The precision comes from ``format_sig_figs(n, round_to)``. That means more
    figures are used when the integer part of ``n`` has more digits than
    ``round_to``, and a single figure is used for zero.

    Parameters
    ----------
    n : float
        Number to round.
    round_to : int
        Number of significant figures.

    Returns
    -------
    str
        ``n`` formatted with the general (``g``) format at that precision.
    """
    precision = format_sig_figs(n, round_to)
    return f"{n:.{precision}g}"
178
-
179
-
180
def color_from_dim(dataarray, dim_name):
    """Return colors and color mapping of a DataArray using coord values as color code.

    Parameters
    ----------
    dataarray : xarray.DataArray
    dim_name : str
        Dimension whose coordinates will be used as color code.

    Returns
    -------
    colors : list of float
        Color value in ``[0, 1)`` (for use with a cmap) for each element in the
        dataarray.
    color_mapping : dict
        Mapping from each distinct coord value to its color value. Values are
        assigned in order of first appearance, so the mapping is deterministic.
    """
    present_dims = dataarray.dims
    coord_values = dataarray[dim_name].values
    # ``dict.fromkeys`` dedupes while keeping first-appearance order. Iterating
    # a ``set`` made the mapping depend on hash order, which for strings changes
    # between interpreter runs (hash randomization).
    unique_coords = list(dict.fromkeys(coord_values))
    n_unique = len(unique_coords)
    color_mapping = {coord: num / n_unique for num, coord in enumerate(unique_coords)}
    if len(present_dims) > 1:
        multi_coords = dataarray.coords.to_index()
        coord_idx = present_dims.index(dim_name)
        colors = [color_mapping[coord[coord_idx]] for coord in multi_coords]
    else:
        colors = [color_mapping[coord] for coord in coord_values]
    return colors, color_mapping
207
-
208
-
209
def vectorized_to_hex(c_values, keep_alpha=False):
    """Convert a color (including vector of colors) to hex.

    Parameters
    ----------
    c_values : Matplotlib color or sequence of Matplotlib colors
        A single color is converted directly. If matplotlib's ``to_hex``
        rejects the input with ``ValueError``, it is treated as a sequence and
        converted element-wise.
    keep_alpha : bool
        Whether alpha values should be kept in the final hex values.

    Returns
    -------
    str or list of str
        A hex string for a single color, or a list of hex strings.
    """
    try:
        return to_hex(c_values, keep_alpha)
    except ValueError:
        return [to_hex(single, keep_alpha) for single in c_values]
229
-
230
-
231
def format_coords_as_labels(dataarray, skip_dims=None):
    """Format 1d or multi-d dataarray coords as strings.

    Parameters
    ----------
    dataarray : xarray.DataArray
        DataArray whose coordinates will be converted to labels.
    skip_dims : str or list-like, optional
        Dimensions whose values should not be included in the labels. Labels
        that become duplicates after dropping these levels are removed.

    Returns
    -------
    numpy.ndarray of str
        One label per coordinate combination. Multi-dimensional coordinates
        are joined with ``", "``.
    """
    index = dataarray.coords.to_index()
    if skip_dims is not None:
        index = index.droplevel(skip_dims).drop_duplicates()
    labels = index.values
    if not isinstance(labels[0], tuple):
        return np.array([f"{label}" for label in labels])
    template = ", ".join("{}" for _ in labels[0])
    return np.array([template.format(*label) for label in labels])
250
-
251
-
252
def set_xticklabels(ax, coord_labels):
    """Set xticklabels to label list using Matplotlib default formatter.

    The major locator is asked for at most 9 bins on 1/2/5/10 steps. Tick
    positions are truncated to integers and those outside
    ``[0, len(coord_labels))`` are dropped. The remaining ticks are labelled
    with ``coord_labels[ticks]``, so ``coord_labels`` must support numpy-style
    integer-array indexing.
    """
    n_labels = len(coord_labels)
    ax.xaxis.get_major_locator().set_params(nbins=9, steps=[1, 2, 5, 10])
    tick_pos = ax.get_xticks().astype(np.int64)
    tick_pos = tick_pos[(tick_pos >= 0) & (tick_pos < n_labels)]
    if len(tick_pos) <= n_labels:
        ax.set_xticks(tick_pos)
        ax.set_xticklabels(coord_labels[tick_pos])
        return
    # Truncation produced more (duplicated) ticks than labels: label every position.
    ax.set_xticks(np.arange(n_labels))
    ax.set_xticklabels(coord_labels)
263
-
264
-
265
def filter_plotters_list(plotters, plot_kind):
    """Cut list of plotters so that it is at most of length "plot.max_subplots".

    If ``rcParams["plot.max_subplots"]`` is None, ``plotters`` is returned
    unchanged. Otherwise a ``UserWarning`` naming ``plot_kind`` is emitted when
    truncation happens.
    """
    limit = rcParams["plot.max_subplots"]
    n_plotters = len(plotters)
    if limit is None or n_plotters <= limit:
        return plotters
    warnings.warn(
        f"rcParams['plot.max_subplots'] ({limit}) is smaller than the number "
        f"of variables to plot ({n_plotters}) in {plot_kind}, generating only "
        f"{limit} plots",
        UserWarning,
    )
    return plotters[:limit]
280
-
281
-
282
def get_plotting_function(plot_name, plot_module, backend):
    """Return plotting function for correct backend.

    Parameters
    ----------
    plot_name : str
        Name of the function to fetch from the backend module.
    plot_module : str
        Module name under ``arviz.plots.backends.<backend>``.
    backend : str or None
        ``"matplotlib"``/``"mpl"`` or ``"bokeh"`` (case-insensitive). If None,
        ``rcParams["plot.backend"]`` is used.

    Returns
    -------
    callable
        Attribute ``plot_name`` of the imported backend module.

    Raises
    ------
    KeyError
        If ``backend`` is not one of the supported names.
    ImportError
        If the bokeh backend is requested and Bokeh is missing or older than 1.4.0.
    """
    _backend = {
        "mpl": "matplotlib",
        "bokeh": "bokeh",
        "matplotlib": "matplotlib",
    }

    if backend is None:
        backend = rcParams["plot.backend"]
    backend = backend.lower()

    try:
        backend = _backend[backend]
    except KeyError as err:
        raise KeyError(
            f"Backend {backend} is not implemented. Try backend in {set(_backend.values())}"
        ) from err

    if backend == "bokeh":
        # ``import packaging`` alone does not load the ``packaging.version``
        # submodule, so import it explicitly instead of relying on another
        # library having done so.
        from packaging.version import parse as parse_version

        bokeh_error = "'bokeh' backend needs Bokeh (1.4.0+) installed. Please upgrade or install"
        try:
            import bokeh
        except ImportError as err:
            raise ImportError(bokeh_error) from err
        # Explicit check rather than ``assert``, which is stripped under ``python -O``.
        if parse_version(bokeh.__version__) < parse_version("1.4.0"):
            raise ImportError(bokeh_error)

    # Perform import of plotting method
    # TODO: Convert module import to top level for all plots
    module = importlib.import_module(f"arviz.plots.backends.{backend}.{plot_module}")

    plotting_method = getattr(module, plot_name)

    return plotting_method
319
-
320
-
321
def calculate_point_estimate(point_estimate, values, bw="default", circular=False, skipna=False):
    """Validate and calculate the point estimate.

    Parameters
    ----------
    point_estimate : str or None
        One of ``'mean'``, ``'median'``, ``'mode'``, ``'auto'`` or None.
        ``'auto'`` uses ``rcParams["plot.point_estimate"]``. Any other value
        raises ``ValueError``.
    values : 1-d array
        Samples to summarize. ``values.dtype`` is inspected in the mode branch.
    bw : float or str, optional
        Only used for ``'mode'`` on float data. If numeric, indicates the
        bandwidth and must be positive. If str, indicates the method to
        estimate the bandwidth and must be one of "scott", "silverman", "isj"
        or "experimental" when `circular` is False, and "taylor" (for now) when
        `circular` is True. Defaults to "default", which means "experimental"
        when the variable is not circular and "taylor" when it is.
    circular : bool, optional
        If True, the values are interpreted as a circular variable measured in
        radians and a circular KDE is used. Only valid for 1D KDE. Defaults to
        False.
    skipna : bool, optional
        If True, NaN values are ignored when computing the mean or median.
        Defaults to False.

    Returns
    -------
    point_value : float or None
        Best estimate of the data distribution, or None when no point estimate
        is requested. For non-float data the mode is returned as ``int``.
    """
    if point_estimate == "auto":
        point_estimate = rcParams["plot.point_estimate"]
    elif point_estimate not in ("mean", "median", "mode", None):
        raise ValueError(
            f"Point estimate should be 'mean', 'median', 'mode' or None, not {point_estimate}"
        )

    if point_estimate == "mean":
        return np.nanmean(values) if skipna else np.mean(values)
    if point_estimate == "median":
        return np.nanmedian(values) if skipna else np.median(values)
    if point_estimate != "mode":
        return None

    if values.dtype.kind != "f":
        # Discrete data: use the most frequent value.
        return int(mode(values).mode)
    if bw == "default":
        bw = "taylor" if circular else "experimental"
    # Continuous data: the mode is the location of the KDE peak.
    grid, density = kde(values, circular=circular, bw=bw)
    return grid[np.argmax(density)]
368
-
369
-
370
def plot_point_interval(
    ax,
    values,
    point_estimate,
    hdi_prob,
    quartiles,
    linewidth,
    markersize,
    markercolor,
    marker,
    rotated,
    intervalcolor,
    backend="matplotlib",
):
    """Plot point intervals.

    Translates the data and represents them as point and interval summaries
    drawn at coordinate 0 of the other axis.

    Parameters
    ----------
    ax : axes
        Matplotlib axes, or a Bokeh figure when ``backend`` is not
        ``"matplotlib"``.
    values : numpy.ndarray
        Values to plot (flattened for the HDI computation).
    point_estimate : str
        Plot point estimate per variable (passed to ``calculate_point_estimate``).
        Falsy values skip the point.
    hdi_prob : float
        Probability mass of the HDI drawn as the outer interval.
    quartiles : bool
        If True then the quartile interval (0.25-0.75) will be plotted with the HDI.
    linewidth : int
        Line width of the outer interval. The quartile interval uses twice this.
    markersize : int
        Markersize throughout.
    markercolor : string
        Color of the marker.
    marker : string
        Shape of the marker (matplotlib backend only; Bokeh always draws a circle).
    rotated : bool
        Whether to rotate the dot plot by 90 degrees.
    intervalcolor : string
        Color of the interval.
    backend : string, optional
        Matplotlib or Bokeh.

    Returns
    -------
    ax
        The same axes/figure that was passed in.
    """
    tail = (1 - hdi_prob) / 2
    probs = [tail, 0.25, 0.75, 1 - tail] if quartiles else [tail, 1 - tail]
    bounds = np.quantile(values, probs)
    # The outermost interval is the HDI, not the equal-tailed interval.
    bounds[0], bounds[-1] = hdi(values.flatten(), hdi_prob, multimodal=False)

    n_intervals = len(bounds) // 2
    # Reversed so the outer (HDI) interval is thinnest and inner ones thicker.
    widths = np.linspace(2 * linewidth, linewidth, n_intervals, endpoint=True)[::-1]

    if backend == "matplotlib":
        draw_segment = ax.vlines if rotated else ax.hlines
        for idx, width in enumerate(widths):
            draw_segment(
                0,
                bounds[idx],
                bounds[-(idx + 1)],
                linewidth=width,
                color=intervalcolor,
            )
        if point_estimate:
            estimate = calculate_point_estimate(point_estimate, values)
            xy = (0, estimate) if rotated else (estimate, 0)
            ax.plot(*xy, marker, markersize=markersize, color=markercolor)
    else:
        for idx, width in enumerate(widths):
            span = [bounds[idx], bounds[-(idx + 1)]]
            xs, ys = ([0, 0], span) if rotated else (span, [0, 0])
            ax.line(xs, ys, line_width=width, color=intervalcolor)
        if point_estimate:
            estimate = calculate_point_estimate(point_estimate, values)
            x_pos, y_pos = (0, estimate) if rotated else (estimate, 0)
            ax.scatter(
                x=x_pos,
                y=y_pos,
                marker="circle",
                size=markersize,
                fill_color=markercolor,
            )

    return ax
502
-
503
-
504
def is_valid_quantile(value):
    """Check if value is a number strictly between 0 and 1.

    Parameters
    ----------
    value : object
        Anything accepted by ``float()`` (numbers, numeric strings, ...).

    Returns
    -------
    bool
        True if ``float(value)`` lies in the open interval ``(0, 1)``. False
        otherwise, including NaN and values ``float()`` cannot convert
        (e.g. ``None``, lists, non-numeric strings).
    """
    try:
        value = float(value)
    except (TypeError, ValueError):
        # float() raises TypeError (not ValueError) for non-numeric types such
        # as None; those are simply not valid quantiles.
        return False
    return 0 < value < 1
511
-
512
-
513
def sample_reference_distribution(dist, shape):
    """Generate samples from a scipy distribution with a given shape.

    Draws ``dist.rvs(size=shape)`` and computes a KDE for each column
    ``[:, j]`` for ``j`` in ``range(shape[1])``.

    Returns
    -------
    tuple of numpy.ndarray
        ``(grids, densities)``, each transposed so that column ``j`` belongs to
        the ``j``-th column of the draws. This assumes ``kde`` returns
        equal-length outputs for every column (TODO confirm).
    """
    draws = dist.rvs(size=shape)
    estimates = [kde(draws[:, col]) for col in range(shape[1])]
    grids = [grid for grid, _ in estimates]
    dens = [density for _, density in estimates]
    return np.array(grids).T, np.array(dens).T
523
-
524
-
525
def set_bokeh_circular_ticks_labels(ax, hist, labels):
    """Place ticks and ticklabels on Bokeh's circular histogram.

    Draws radial tick lines at ``len(labels)`` evenly spaced angles and four
    concentric grey circles up to ``1.1 * max(hist)``. It then writes the
    labels just outside the histogram.

    NOTE(review): the label coordinates below are hard-coded for exactly eight
    positions (every 45 degrees, counter-clockwise from the positive x-axis), so
    ``labels`` is expected to have eight entries.
    """
    outer = np.max(hist) * 1.1
    tick_angles = np.linspace(-np.pi, np.pi, len(labels), endpoint=False)
    ax.annular_wedge(
        x=0,
        y=0,
        inner_radius=0,
        outer_radius=outer,
        start_angle=tick_angles,
        end_angle=tick_angles,
        line_color="grey",
    )

    ax.scatter(
        0,
        0,
        marker="circle",
        radius=np.linspace(0, outer, 4),
        fill_color=None,
        line_color="grey",
    )

    axis_r = np.max(hist * 1.05)
    diag_r = axis_r * np.sqrt(2) / 2
    pad = axis_r * 0.15

    x_pos = [
        axis_r + pad,
        diag_r + pad,
        0,
        -diag_r - pad,
        -axis_r - pad,
        -diag_r - pad,
        0,
        diag_r + pad,
    ]
    y_pos = [
        0,
        diag_r + pad / 2,
        axis_r + pad,
        diag_r + pad / 2,
        0,
        -diag_r - pad,
        -axis_r - pad,
        -diag_r - pad,
    ]
    ax.text(x_pos, y_pos, text=labels, text_align="center")

    return ax
571
-
572
-
573
def compute_ranks(ary):
    """Compute ranks for continuous and discrete variables.

    For signed-integer arrays (dtype kind ``"i"``) with at least two distinct
    values, the steps are:

    1. Fit a cubic spline through the flattened values against an evenly
       spaced grid over ``[min, max]``.
    2. Evaluate it on a grid shrunk by 0.001 at both ends.
    3. Rank the result, presumably to break ties between repeated integers
       (TODO confirm).

    Constant integer arrays skip the spline (it would need strictly increasing
    knots) and get tied average ranks.

    Parameters
    ----------
    ary : numpy.ndarray
        Values to rank.

    Returns
    -------
    numpy.ndarray
        Average ranks (``scipy.stats.rankdata``) with the same shape as ``ary``.
    """
    if ary.dtype.kind == "i" and ary.size > 0:
        ary_shape = ary.shape
        flat = ary.ravel()
        # ndarray.min/max instead of builtin min/max: same values, no Python loop.
        min_ary, max_ary = flat.min(), flat.max()
        # CubicSpline requires strictly increasing x; a constant array (or a
        # single element) would give repeated knots and raise ValueError.
        if min_ary < max_ary:
            x = np.linspace(min_ary, max_ary, flat.size)
            csi = CubicSpline(x, flat)
            ary = csi(np.linspace(min_ary + 0.001, max_ary - 0.001, flat.size)).reshape(
                ary_shape
            )
    ranks = rankdata(ary, method="average").reshape(ary.shape)

    return ranks
585
-
586
-
587
- def _init_kwargs_dict(kwargs):
588
- """Initialize kwargs dict.
589
-
590
- If the input is a dictionary, it returns
591
- a copy of the dictionary, otherwise it
592
- returns an empty dictionary.
593
-
594
- Parameters
595
- ----------
596
- kwargs : dict or None
597
- kwargs dict to initialize
598
- """
599
- return {} if kwargs is None else kwargs.copy()