arviz 0.23.3__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185):
  1. arviz/__init__.py +52 -367
  2. arviz-1.0.0rc0.dist-info/METADATA +182 -0
  3. arviz-1.0.0rc0.dist-info/RECORD +5 -0
  4. {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
  5. {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
  6. arviz/data/__init__.py +0 -55
  7. arviz/data/base.py +0 -596
  8. arviz/data/converters.py +0 -203
  9. arviz/data/datasets.py +0 -161
  10. arviz/data/example_data/code/radon/radon.json +0 -326
  11. arviz/data/example_data/data/centered_eight.nc +0 -0
  12. arviz/data/example_data/data/non_centered_eight.nc +0 -0
  13. arviz/data/example_data/data_local.json +0 -12
  14. arviz/data/example_data/data_remote.json +0 -58
  15. arviz/data/inference_data.py +0 -2386
  16. arviz/data/io_beanmachine.py +0 -112
  17. arviz/data/io_cmdstan.py +0 -1036
  18. arviz/data/io_cmdstanpy.py +0 -1233
  19. arviz/data/io_datatree.py +0 -23
  20. arviz/data/io_dict.py +0 -462
  21. arviz/data/io_emcee.py +0 -317
  22. arviz/data/io_json.py +0 -54
  23. arviz/data/io_netcdf.py +0 -68
  24. arviz/data/io_numpyro.py +0 -497
  25. arviz/data/io_pyjags.py +0 -378
  26. arviz/data/io_pyro.py +0 -333
  27. arviz/data/io_pystan.py +0 -1095
  28. arviz/data/io_zarr.py +0 -46
  29. arviz/data/utils.py +0 -139
  30. arviz/labels.py +0 -210
  31. arviz/plots/__init__.py +0 -61
  32. arviz/plots/autocorrplot.py +0 -171
  33. arviz/plots/backends/__init__.py +0 -223
  34. arviz/plots/backends/bokeh/__init__.py +0 -166
  35. arviz/plots/backends/bokeh/autocorrplot.py +0 -101
  36. arviz/plots/backends/bokeh/bfplot.py +0 -23
  37. arviz/plots/backends/bokeh/bpvplot.py +0 -193
  38. arviz/plots/backends/bokeh/compareplot.py +0 -167
  39. arviz/plots/backends/bokeh/densityplot.py +0 -239
  40. arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
  41. arviz/plots/backends/bokeh/distplot.py +0 -183
  42. arviz/plots/backends/bokeh/dotplot.py +0 -113
  43. arviz/plots/backends/bokeh/ecdfplot.py +0 -73
  44. arviz/plots/backends/bokeh/elpdplot.py +0 -203
  45. arviz/plots/backends/bokeh/energyplot.py +0 -155
  46. arviz/plots/backends/bokeh/essplot.py +0 -176
  47. arviz/plots/backends/bokeh/forestplot.py +0 -772
  48. arviz/plots/backends/bokeh/hdiplot.py +0 -54
  49. arviz/plots/backends/bokeh/kdeplot.py +0 -268
  50. arviz/plots/backends/bokeh/khatplot.py +0 -163
  51. arviz/plots/backends/bokeh/lmplot.py +0 -185
  52. arviz/plots/backends/bokeh/loopitplot.py +0 -211
  53. arviz/plots/backends/bokeh/mcseplot.py +0 -184
  54. arviz/plots/backends/bokeh/pairplot.py +0 -328
  55. arviz/plots/backends/bokeh/parallelplot.py +0 -81
  56. arviz/plots/backends/bokeh/posteriorplot.py +0 -324
  57. arviz/plots/backends/bokeh/ppcplot.py +0 -379
  58. arviz/plots/backends/bokeh/rankplot.py +0 -149
  59. arviz/plots/backends/bokeh/separationplot.py +0 -107
  60. arviz/plots/backends/bokeh/traceplot.py +0 -436
  61. arviz/plots/backends/bokeh/violinplot.py +0 -164
  62. arviz/plots/backends/matplotlib/__init__.py +0 -124
  63. arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
  64. arviz/plots/backends/matplotlib/bfplot.py +0 -78
  65. arviz/plots/backends/matplotlib/bpvplot.py +0 -177
  66. arviz/plots/backends/matplotlib/compareplot.py +0 -135
  67. arviz/plots/backends/matplotlib/densityplot.py +0 -194
  68. arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
  69. arviz/plots/backends/matplotlib/distplot.py +0 -178
  70. arviz/plots/backends/matplotlib/dotplot.py +0 -116
  71. arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
  72. arviz/plots/backends/matplotlib/elpdplot.py +0 -189
  73. arviz/plots/backends/matplotlib/energyplot.py +0 -113
  74. arviz/plots/backends/matplotlib/essplot.py +0 -180
  75. arviz/plots/backends/matplotlib/forestplot.py +0 -656
  76. arviz/plots/backends/matplotlib/hdiplot.py +0 -48
  77. arviz/plots/backends/matplotlib/kdeplot.py +0 -177
  78. arviz/plots/backends/matplotlib/khatplot.py +0 -241
  79. arviz/plots/backends/matplotlib/lmplot.py +0 -149
  80. arviz/plots/backends/matplotlib/loopitplot.py +0 -144
  81. arviz/plots/backends/matplotlib/mcseplot.py +0 -161
  82. arviz/plots/backends/matplotlib/pairplot.py +0 -355
  83. arviz/plots/backends/matplotlib/parallelplot.py +0 -58
  84. arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
  85. arviz/plots/backends/matplotlib/ppcplot.py +0 -478
  86. arviz/plots/backends/matplotlib/rankplot.py +0 -119
  87. arviz/plots/backends/matplotlib/separationplot.py +0 -97
  88. arviz/plots/backends/matplotlib/traceplot.py +0 -526
  89. arviz/plots/backends/matplotlib/tsplot.py +0 -121
  90. arviz/plots/backends/matplotlib/violinplot.py +0 -148
  91. arviz/plots/bfplot.py +0 -128
  92. arviz/plots/bpvplot.py +0 -308
  93. arviz/plots/compareplot.py +0 -177
  94. arviz/plots/densityplot.py +0 -284
  95. arviz/plots/distcomparisonplot.py +0 -197
  96. arviz/plots/distplot.py +0 -233
  97. arviz/plots/dotplot.py +0 -233
  98. arviz/plots/ecdfplot.py +0 -372
  99. arviz/plots/elpdplot.py +0 -174
  100. arviz/plots/energyplot.py +0 -147
  101. arviz/plots/essplot.py +0 -319
  102. arviz/plots/forestplot.py +0 -304
  103. arviz/plots/hdiplot.py +0 -211
  104. arviz/plots/kdeplot.py +0 -357
  105. arviz/plots/khatplot.py +0 -236
  106. arviz/plots/lmplot.py +0 -380
  107. arviz/plots/loopitplot.py +0 -224
  108. arviz/plots/mcseplot.py +0 -194
  109. arviz/plots/pairplot.py +0 -281
  110. arviz/plots/parallelplot.py +0 -204
  111. arviz/plots/plot_utils.py +0 -599
  112. arviz/plots/posteriorplot.py +0 -298
  113. arviz/plots/ppcplot.py +0 -369
  114. arviz/plots/rankplot.py +0 -232
  115. arviz/plots/separationplot.py +0 -167
  116. arviz/plots/styles/arviz-bluish.mplstyle +0 -1
  117. arviz/plots/styles/arviz-brownish.mplstyle +0 -1
  118. arviz/plots/styles/arviz-colors.mplstyle +0 -2
  119. arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
  120. arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
  121. arviz/plots/styles/arviz-doc.mplstyle +0 -88
  122. arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
  123. arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
  124. arviz/plots/styles/arviz-greenish.mplstyle +0 -1
  125. arviz/plots/styles/arviz-orangish.mplstyle +0 -1
  126. arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
  127. arviz/plots/styles/arviz-purplish.mplstyle +0 -1
  128. arviz/plots/styles/arviz-redish.mplstyle +0 -1
  129. arviz/plots/styles/arviz-royish.mplstyle +0 -1
  130. arviz/plots/styles/arviz-viridish.mplstyle +0 -1
  131. arviz/plots/styles/arviz-white.mplstyle +0 -40
  132. arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
  133. arviz/plots/traceplot.py +0 -273
  134. arviz/plots/tsplot.py +0 -440
  135. arviz/plots/violinplot.py +0 -192
  136. arviz/preview.py +0 -58
  137. arviz/py.typed +0 -0
  138. arviz/rcparams.py +0 -606
  139. arviz/sel_utils.py +0 -223
  140. arviz/static/css/style.css +0 -340
  141. arviz/static/html/icons-svg-inline.html +0 -15
  142. arviz/stats/__init__.py +0 -37
  143. arviz/stats/density_utils.py +0 -1013
  144. arviz/stats/diagnostics.py +0 -1013
  145. arviz/stats/ecdf_utils.py +0 -324
  146. arviz/stats/stats.py +0 -2422
  147. arviz/stats/stats_refitting.py +0 -119
  148. arviz/stats/stats_utils.py +0 -609
  149. arviz/tests/__init__.py +0 -1
  150. arviz/tests/base_tests/__init__.py +0 -1
  151. arviz/tests/base_tests/test_data.py +0 -1679
  152. arviz/tests/base_tests/test_data_zarr.py +0 -143
  153. arviz/tests/base_tests/test_diagnostics.py +0 -511
  154. arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
  155. arviz/tests/base_tests/test_helpers.py +0 -18
  156. arviz/tests/base_tests/test_labels.py +0 -69
  157. arviz/tests/base_tests/test_plot_utils.py +0 -342
  158. arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
  159. arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
  160. arviz/tests/base_tests/test_rcparams.py +0 -317
  161. arviz/tests/base_tests/test_stats.py +0 -925
  162. arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
  163. arviz/tests/base_tests/test_stats_numba.py +0 -45
  164. arviz/tests/base_tests/test_stats_utils.py +0 -384
  165. arviz/tests/base_tests/test_utils.py +0 -376
  166. arviz/tests/base_tests/test_utils_numba.py +0 -87
  167. arviz/tests/conftest.py +0 -46
  168. arviz/tests/external_tests/__init__.py +0 -1
  169. arviz/tests/external_tests/test_data_beanmachine.py +0 -78
  170. arviz/tests/external_tests/test_data_cmdstan.py +0 -398
  171. arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
  172. arviz/tests/external_tests/test_data_emcee.py +0 -166
  173. arviz/tests/external_tests/test_data_numpyro.py +0 -434
  174. arviz/tests/external_tests/test_data_pyjags.py +0 -119
  175. arviz/tests/external_tests/test_data_pyro.py +0 -260
  176. arviz/tests/external_tests/test_data_pystan.py +0 -307
  177. arviz/tests/helpers.py +0 -677
  178. arviz/utils.py +0 -773
  179. arviz/wrappers/__init__.py +0 -13
  180. arviz/wrappers/base.py +0 -236
  181. arviz/wrappers/wrap_pymc.py +0 -36
  182. arviz/wrappers/wrap_stan.py +0 -148
  183. arviz-0.23.3.dist-info/METADATA +0 -264
  184. arviz-0.23.3.dist-info/RECORD +0 -183
  185. arviz-0.23.3.dist-info/top_level.txt +0 -1
@@ -1,772 +0,0 @@
1
- # pylint: disable=all
2
- """Bokeh forestplot."""
3
- from collections import OrderedDict, defaultdict
4
- from itertools import cycle, tee
5
-
6
- import bokeh.plotting as bkp
7
- import matplotlib.pyplot as plt
8
- import numpy as np
9
- from bokeh.models import Band, ColumnDataSource, DataRange1d
10
- from bokeh.models.annotations import Title, Legend
11
- from bokeh.models.tickers import FixedTicker
12
-
13
- from ....sel_utils import xarray_var_iter
14
- from ....rcparams import rcParams
15
- from ....stats import hdi
16
- from ....stats.density_utils import get_bins, histogram, kde
17
- from ....stats.diagnostics import _ess, _rhat
18
- from ...plot_utils import _scale_fig_size
19
- from .. import show_layout
20
- from . import backend_kwarg_defaults
21
-
22
-
23
def pairwise(iterable):
    """Return consecutive overlapping pairs: ``[a, b, c, ...] -> (a, b), (b, c), ...``.

    Recipe from the itertools documentation; works on any iterable, including
    one-shot iterators, because the input is split with ``tee``.
    """
    leading, trailing = tee(iterable)
    # Advance the trailing copy by one so zip lines up element i with element i + 1.
    next(trailing, None)
    return zip(leading, trailing)
28
-
29
-
30
def plot_forest(
    ax,
    datasets,
    var_names,
    model_names,
    combined,
    combine_dims,
    colors,
    figsize,
    width_ratios,
    linewidth,
    markersize,
    kind,
    ncols,
    hdi_prob,
    quartiles,
    rope,
    ridgeplot_overlap,
    ridgeplot_alpha,
    ridgeplot_kind,
    ridgeplot_truncate,
    ridgeplot_quantiles,
    textsize,
    legend,
    labeller,
    ess,
    r_hat,
    backend_config,
    backend_kwargs,
    show,
):
    """Bokeh forest plot.

    Builds a ``PlotHandler`` for ``datasets`` and draws a forest plot or a
    ridge plot in the first column. It then optionally adds ESS and R-hat
    columns (in that order), applies shared axis ranges and y tick labels,
    and adds legends.

    Returns
    -------
    numpy.ndarray
        The figures (either created here or the ``ax`` passed in), wrapped
        with ``np.atleast_2d``.

    Raises
    ------
    TypeError
        If ``kind`` is neither ``"forestplot"`` nor ``"ridgeplot"``.
    """
    handler = PlotHandler(
        datasets,
        var_names=var_names,
        model_names=model_names,
        combined=combined,
        combine_dims=combine_dims,
        colors=colors,
        labeller=labeller,
    )
    is_ridge = kind == "ridgeplot"

    # Default figure size: ridge plots get a taller, slightly wider canvas.
    if figsize is None:
        if is_ridge:
            figsize = (min(14, sum(width_ratios) * 3), handler.fig_height() * 3)
        else:
            figsize = (min(12, sum(width_ratios) * 2), handler.fig_height())

    (figsize, _, _, _, auto_linewidth, auto_markersize) = _scale_fig_size(
        figsize, textsize, 1.1, 1
    )
    linewidth = auto_linewidth if linewidth is None else linewidth
    markersize = auto_markersize if markersize is None else markersize

    # User-supplied entries win over the rcParams-derived defaults.
    backend_config = {
        **backend_kwarg_defaults(
            ("bounds_x_range", "plot.bokeh.bounds_x_range"),
            ("bounds_y_range", "plot.bokeh.bounds_y_range"),
        ),
        **({} if backend_config is None else backend_config),
    }
    backend_kwargs = {
        **backend_kwarg_defaults(("dpi", "plot.bokeh.figure.dpi")),
        **({} if backend_kwargs is None else backend_kwargs),
    }
    dpi = backend_kwargs.pop("dpi")

    if ax is not None:
        axes = ax
    else:
        total_ratio = sum(width_ratios)
        figures = []
        for column, ratio in zip(range(ncols), width_ratios):
            fig_kwargs = dict(backend_kwargs)
            fig_kwargs.setdefault("height", int(figsize[1] * dpi))
            fig_kwargs.setdefault("width", int(figsize[0] * (ratio / total_ratio) * dpi * 1.25))
            figure = bkp.figure(**fig_kwargs)
            if column == 0:
                # Later columns are created with the first column's y_range.
                backend_kwargs.setdefault("y_range", figure.y_range)
            figures.append(figure)
        axes = figures

    axes = np.atleast_2d(axes)

    plotted = defaultdict(list)
    if kind == "forestplot":
        handler.forestplot(
            hdi_prob=hdi_prob,
            quartiles=quartiles,
            linewidth=linewidth,
            markersize=markersize,
            ax=axes[0, 0],
            rope=rope,
            plotted=plotted,
        )
    elif is_ridge:
        handler.ridgeplot(
            hdi_prob=hdi_prob,
            mult=ridgeplot_overlap,
            linewidth=linewidth,
            markersize=markersize,
            alpha=ridgeplot_alpha,
            ridgeplot_kind=ridgeplot_kind,
            ridgeplot_truncate=ridgeplot_truncate,
            ridgeplot_quantiles=ridgeplot_quantiles,
            ax=axes[0, 0],
            plotted=plotted,
        )
    else:
        raise TypeError(
            f"Argument 'kind' must be one of 'forestplot' or 'ridgeplot' (you provided {kind})"
        )

    # Optional diagnostic columns, each in the next free column with its own legend.
    column = 1
    for enabled, draw in ((ess, handler.plot_neff), (r_hat, handler.plot_rhat)):
        if not enabled:
            continue
        column_plotted = defaultdict(list)
        draw(axes[0, column], markersize, column_plotted)
        if legend:
            handler.legend(axes[0, column], column_plotted)
        column += 1

    plotters = list(handler.plotters.values())
    y_top = handler.y_max() - plotters[-1].group_offset
    if is_ridge:  # leave room at the top for the highest ridge
        y_top += ridgeplot_overlap
    y_bottom = -plotters[0].group_offset

    for position, figure in enumerate(axes.ravel()):
        if is_ridge:
            figure.xgrid.grid_line_color = None
        figure.ygrid.grid_line_color = None
        if position:
            figure.yaxis.visible = False
        figure.outline_line_color = None
        # NOTE(review): these fresh DataRange1d objects replace whatever ranges
        # were set earlier, which discards the x_range start/end written by
        # plot_neff/plot_rhat and the y_range shared through backend_kwargs.
        # This looks unintended; confirm before changing.
        figure.x_range = DataRange1d(bounds=backend_config["bounds_x_range"], min_interval=1)
        figure.y_range = DataRange1d(bounds=backend_config["bounds_y_range"], min_interval=2)
        figure.y_range._property_values["start"] = y_bottom  # pylint: disable=protected-access
        figure.y_range._property_values["end"] = y_top  # pylint: disable=protected-access

    labels, ticks = handler.labels_and_ticks()
    ticks = [int(tick) if tick.is_integer() else tick for tick in ticks]
    main_axis = axes[0, 0]
    main_axis.yaxis.ticker = FixedTicker(ticks=ticks)
    main_axis.yaxis.major_label_overrides = {
        str(tick): str(label) for tick, label in zip(ticks, labels)
    }

    if legend:
        handler.legend(main_axis, plotted)
    show_layout(axes, show)

    return axes
208
-
209
-
210
class PlotHandler:
    """Class to handle logic from ForestPlot.

    Owns one ``VarHandler`` per variable (``self.plotters``) and draws forest
    plots, ridge plots, ESS and R-hat columns onto bokeh figures. Variable,
    model and color lists are stored reversed because y values grow upwards,
    so the first requested entry ends up at the top of the plot.
    """

    # pylint: disable=inconsistent-return-statements

    def __init__(self, datasets, var_names, model_names, combined, combine_dims, colors, labeller):
        self.data = datasets

        if model_names is None:
            if len(self.data) > 1:
                model_names = [f"Model {idx}" for idx, _ in enumerate(self.data)]
            else:
                model_names = [""]
        elif len(model_names) != len(self.data):
            raise ValueError("The number of model names does not match the number of models")

        self.model_names = list(reversed(model_names))  # y-values are upside down

        if var_names is None:
            # Ordered union of the data variables of every dataset, in first-seen
            # order. A plain ``set`` union made the order depend on string
            # hashing, so the layout changed between interpreter runs.
            var_names = list(dict.fromkeys(name for datum in self.data for name in datum.data_vars))
        self.var_names = list(reversed(var_names))  # y-values are upside down

        self.combined = combined
        self.combine_dims = combine_dims

        if isinstance(colors, str) and colors == "cycle":
            colors = [self._cycle_color(idx) for idx in range(len(self.data))]
        elif isinstance(colors, str):
            colors = [colors for _ in self.data]

        self.colors = list(reversed(colors))  # y-values are upside down
        self.labeller = labeller

        self.plotters = self.make_plotters()

    @staticmethod
    def _cycle_color(index):
        """Return entry ``index`` of matplotlib's color cycle, wrapping around."""
        palette = plt.rcParams["axes.prop_cycle"].by_key()["color"]
        return palette[index % len(palette)]

    def make_plotters(self):
        """Initialize an object for each variable to be plotted.

        Each ``VarHandler`` starts at the previous one's ``y_max()``, so the
        variables are stacked vertically.
        """
        plotters, y = {}, 0
        for var_name in self.var_names:
            plotters[var_name] = VarHandler(
                var_name,
                self.data,
                y,
                model_names=self.model_names,
                combined=self.combined,
                combine_dims=self.combine_dims,
                colors=self.colors,
                labeller=self.labeller,
            )
            y = plotters[var_name].y_max()
        return plotters

    def labels_and_ticks(self):
        """Collect labels and ticks from plotters.

        Rows sharing a label within one variable are collapsed into a single
        tick, placed at the mean of their y positions.

        Returns
        -------
        tuple of numpy.ndarray
            ``(labels, ticks)`` concatenated over all variables.
        """
        labels, idxs = [], []
        for plotter in self.plotters.values():
            sub_labels, sub_idxs, *_ = plotter.labels_ticks_and_vals()
            labels_to_idxs = defaultdict(list)
            for label, idx in zip(sub_labels, sub_idxs):
                labels_to_idxs[label].append(idx)
            labels.append(list(labels_to_idxs))
            idxs.append([np.mean(all_idx) for all_idx in labels_to_idxs.values()])
        return np.concatenate(labels), np.concatenate(idxs)

    def legend(self, ax, plotted):
        """Add interactive legend with colorcoded model info.

        One legend entry per key of ``plotted``; clicking an entry hides its glyphs.
        """
        legend_it = list(plotted.items())

        legend = Legend(items=legend_it, orientation="vertical", location="top_left")
        ax.add_layout(legend, "above")
        ax.legend.click_policy = "hide"

    def display_multiple_ropes(
        self, rope, ax, y, linewidth, var_name, selection, plotted, model_name
    ):
        """Display ROPE when more than one interval is provided.

        ``rope`` maps variable names to lists of dicts. A dict's ``"rope"``
        interval is drawn when all its other keys match ``selection``.
        """
        rope_color = self._cycle_color(2)
        for sel in rope.get(var_name, []):
            # pylint: disable=line-too-long
            if all(k in selection and selection[k] == v for k, v in sel.items() if k != "rope"):
                vals = sel["rope"]
                plotted[model_name].append(
                    ax.line(
                        vals,
                        (y + 0.05, y + 0.05),
                        line_width=linewidth * 2,
                        color=rope_color,
                        line_alpha=0.7,
                    )
                )
        return ax

    def ridgeplot(
        self,
        hdi_prob,
        mult,
        linewidth,
        markersize,
        alpha,
        ridgeplot_kind,
        ridgeplot_truncate,
        ridgeplot_quantiles,
        ax,
        plotted,
    ):
        """Draw ridgeplot for each plotter.

        Parameters
        ----------
        hdi_prob : float
            Probability for the highest density interval.
        mult : float
            How much to multiply height by. Set this to greater than 1 to have some overlap.
        linewidth : float
            Width of line on border of ridges
        markersize : float
            Size of marker in center of forestplot line
        alpha : float
            Transparency of ridges
        ridgeplot_kind : string
            By default ("auto") continuous variables are plotted using KDEs and discrete ones using
            histograms. To override this use "hist" to plot histograms and "density" for KDEs
        ridgeplot_truncate: bool
            Whether to truncate densities according to the value of hdi_prob. Defaults to True
        ridgeplot_quantiles: list
            Quantiles in ascending order used to segment the KDE. Use [.25, .5, .75] for quartiles.
            Defaults to None.
        ax : Axes
            Axes to draw on
        plotted : dict
            Contains glyphs for each model
        """
        if alpha is None:
            alpha = 1.0
        for plotter in list(self.plotters.values())[::-1]:
            for x, y_min, y_max, hdi_, y_q, color, model_name in plotter.ridgeplot(
                hdi_prob, mult, ridgeplot_kind
            ):
                if alpha == 0:
                    border = color
                    facecolor = None
                else:
                    border = "black"
                    facecolor = color
                # Keep the untruncated grid and the scalar baseline. ``y_q`` is
                # defined on the full grid, so quantile markers must index it.
                full_x = x
                baseline = y_min
                in_hdi = (x >= hdi_[0]) & (x <= hdi_[1])

                if x.dtype.kind == "i":
                    # Local fill alpha: reassigning ``alpha`` itself leaked a
                    # list into every later ridge.
                    fill_alpha = alpha
                    if ridgeplot_truncate:
                        y_max = y_max[in_hdi]
                        x = x[in_hdi]
                    else:
                        facecolor = color
                        fill_alpha = [alpha if inside else 0 for inside in in_hdi]
                    y_min = np.ones_like(x) * baseline
                    plotted[model_name].append(
                        ax.vbar(
                            x=x,
                            # bokeh's ``top`` is an absolute coordinate, not a height
                            top=y_max,
                            bottom=y_min,
                            width=0.9,
                            line_color=border,
                            color=facecolor,
                            fill_alpha=fill_alpha,
                        )
                    )
                else:
                    tr_x = x[in_hdi]
                    tr_y_min = np.ones_like(tr_x) * baseline
                    tr_y_max = y_max[in_hdi]
                    y_min = np.ones_like(x) * baseline
                    patch = ax.patch(
                        np.concatenate([tr_x, tr_x[::-1]]),
                        np.concatenate([tr_y_min, tr_y_max[::-1]]),
                        fill_color=color,
                        fill_alpha=alpha,
                        line_width=0,
                    )
                    patch.level = "overlay"
                    plotted[model_name].append(patch)
                    # Outline only the HDI region when truncating, the whole density otherwise.
                    if ridgeplot_truncate:
                        outlines = ((tr_x, tr_y_max), (tr_x, tr_y_min))
                    else:
                        outlines = ((x, y_max), (x, y_min))
                    for line_x, line_y in outlines:
                        plotted[model_name].append(
                            ax.line(
                                line_x,
                                line_y,
                                line_dash="solid",
                                line_width=linewidth,
                                line_color=border,
                            )
                        )
                if ridgeplot_quantiles is not None:
                    last = len(full_x) - 1
                    # Clamp: rounding in the cumulative sum can make every entry < quant.
                    quantiles = [
                        full_x[min(int(np.sum(y_q < quant)), last)] for quant in ridgeplot_quantiles
                    ]
                    plotted[model_name].append(
                        ax.diamond(
                            quantiles,
                            np.ones_like(quantiles) * baseline,
                            line_color="black",
                            fill_color="black",
                            size=markersize,
                        )
                    )

        return ax

    def forestplot(self, hdi_prob, quartiles, linewidth, markersize, ax, rope, plotted):
        """Draw forestplot for each plotter.

        Parameters
        ----------
        hdi_prob : float
            Probability for the highest density interval. Width of each line.
        quartiles : bool
            Whether to mark quartiles
        linewidth : float
            Width of forestplot line
        markersize : float
            Size of marker in center of forestplot line
        ax : Axes
            Axes to draw on
        rope : None, dict or length-2 iterable
            A single (lo, hi) interval is drawn as a band spanning the plot; a
            dict is drawn per variable/selection via ``display_multiple_ropes``.
        plotted : dict
            Contains glyphs for each model

        Raises
        ------
        ValueError
            If ``rope`` is not None, a dict, or of length 2.
        """
        if rope is None or isinstance(rope, dict):
            pass
        elif len(rope) == 2:
            # Make the band tall enough to cover every row.
            extent = self.y_max() * 2
            cds = ColumnDataSource(
                {
                    "x": rope,
                    "lower": [-extent, -extent],
                    "upper": [extent, extent],
                }
            )

            band = Band(
                base="x",
                lower="lower",
                upper="upper",
                fill_color=self._cycle_color(2),
                line_alpha=0.5,
                source=cds,
            )

            ax.renderers.append(band)
        else:
            raise ValueError(
                "Argument `rope` must be None, a dictionary like "
                '{"var_name": {"rope": (lo, hi)}}, or an '
                "iterable of length 2"
            )
        # Quantiles to be calculated; the outer pair is replaced by the HDI in VarHandler.treeplot
        endpoint = 100 * (1 - hdi_prob) / 2
        if quartiles:
            qlist = [endpoint, 25, 50, 75, 100 - endpoint]
        else:
            qlist = [endpoint, 50, 100 - endpoint]

        for plotter in self.plotters.values():
            for y, model_name, selection, values, color in plotter.treeplot(qlist, hdi_prob):
                if isinstance(rope, dict):
                    self.display_multiple_ropes(
                        rope, ax, y, linewidth, plotter.var_name, selection, plotted, model_name
                    )

                # Nested intervals: the innermost pair gets the thickest line.
                mid = len(values) // 2
                widths = np.linspace(2 * linewidth, linewidth, mid, endpoint=True)[::-1]
                for width, j in zip(widths, range(mid)):
                    plotted[model_name].append(
                        ax.line(
                            [values[j], values[-(j + 1)]],
                            [y, y],
                            line_width=width,
                            line_color=color,
                        )
                    )
                plotted[model_name].append(
                    ax.scatter(
                        x=values[mid],
                        y=y,
                        marker="circle",
                        size=markersize * 0.75,
                        fill_color=color,
                    )
                )
        _title = Title()
        _title.text = f"{hdi_prob:.1%} HDI"
        ax.title = _title

        return ax

    def plot_neff(self, ax, markersize, plotted):
        """Draw effective n for each plotter.

        Rows whose ESS is None are skipped. The x range is set to
        [0, 1.07 * max ESS].
        """
        max_ess = 0
        for plotter in self.plotters.values():
            for y, ess, color, model_name in plotter.ess():
                if ess is not None:
                    plotted[model_name].append(
                        ax.scatter(
                            x=ess,
                            y=y,
                            marker="circle",
                            fill_color=color,
                            size=markersize,
                            line_color="black",
                        )
                    )
                    if ess > max_ess:
                        max_ess = ess
        ax.x_range._property_values["start"] = 0  # pylint: disable=protected-access
        ax.x_range._property_values["end"] = 1.07 * max_ess  # pylint: disable=protected-access

        _title = Title()
        _title.text = "ess"
        ax.title = _title

        ax.xaxis[0].ticker.desired_num_ticks = 3

        return ax

    def plot_rhat(self, ax, markersize, plotted):
        """Draw r-hat for each plotter.

        Rows whose R-hat is None are skipped. The x range is fixed to [0.9, 2.1].
        """
        for plotter in self.plotters.values():
            for y, r_hat, color, model_name in plotter.r_hat():
                if r_hat is not None:
                    plotted[model_name].append(
                        ax.scatter(
                            x=r_hat,
                            y=y,
                            marker="circle",
                            fill_color=color,
                            size=markersize,
                            line_color="black",
                        )
                    )
        ax.x_range._property_values["start"] = 0.9  # pylint: disable=protected-access
        ax.x_range._property_values["end"] = 2.1  # pylint: disable=protected-access

        _title = Title()
        _title.text = "r_hat"
        ax.title = _title

        ax.xaxis[0].ticker.desired_num_ticks = 3

        return ax

    def fig_height(self):
        """Figure out the height of this plot."""
        # hand-tuned
        return (
            4
            + len(self.data) * len(self.var_names)
            - 1
            + 0.1 * sum(1 for j in self.plotters.values() for _ in j.iterator())
        )

    def y_max(self):
        """Get maximum y value for the plot."""
        return max(p.y_max() for p in self.plotters.values())
620
-
621
-
622
class VarHandler:
    """Handle individual variable logic.

    Groups the values of one variable across models (and chains, unless
    ``combined``) into labelled rows. Rows get y positions starting at
    ``y_start``; the class also computes the per-row data that
    ``PlotHandler`` draws (intervals, densities, ESS, R-hat).
    """

    def __init__(
        self, var_name, data, y_start, model_names, combined, combine_dims, colors, labeller
    ):
        self.var_name = var_name
        self.data = data
        self.y_start = y_start
        self.model_names = model_names
        self.combined = combined
        self.combine_dims = combine_dims
        self.colors = colors
        self.labeller = labeller
        self.model_color = dict(zip(self.model_names, self.colors))
        # NOTE(review): this is the largest *value* of the ``chain`` coordinate,
        # which equals the chain count minus one for 0-based chain labels.
        max_chains = max(datum.chain.max().values for datum in data)
        # Vertical spacing between chains, between models/labels and between variables.
        self.chain_offset = len(data) * 0.45 / max(1, max_chains)
        self.var_offset = 1.5 * self.chain_offset
        self.group_offset = 2 * self.var_offset

    def iterator(self):
        """Iterate over models and chains for each variable.

        Yields
        ------
        tuple
            ``(y, row_label, model_name, label, selection, values, color)`` for
            each row. ``y`` starts at ``self.y_start`` and increases by
            ``chain_offset`` per row, ``var_offset`` per model and
            ``group_offset`` per label.
        """
        if self.combined:
            grouped_data = [[(0, datum)] for datum in self.data]
            skip_dims = self.combine_dims.union({"chain"})
        else:
            grouped_data = [datum.groupby("chain", squeeze=False) for datum in self.data]
            skip_dims = self.combine_dims

        # label -> model name -> list of value arrays
        label_dict = OrderedDict()
        selection_list = []
        for name, grouped_datum in zip(self.model_names, grouped_data):
            for _, sub_data in grouped_datum:
                datum_list = list(
                    xarray_var_iter(
                        sub_data.squeeze(),
                        var_names=[self.var_name],
                        skip_dims=skip_dims,
                        reverse_selections=True,
                    )
                )
                for _, selection, isel, values in datum_list:
                    selection_list.append(selection)
                    # Attach the variable name when there is no selection, or when
                    # the running count is a multiple of the group size, i.e. the
                    # last entry of a group (assumes equal-sized groups).
                    if not selection or not len(selection_list) % len(datum_list):
                        var_name = self.var_name
                    else:
                        var_name = ""
                    label = self.labeller.make_label_flat(var_name, selection, isel)
                    label_dict.setdefault(label, OrderedDict()).setdefault(name, []).append(values)

        y = self.y_start
        for idx, (label, model_data) in enumerate(label_dict.items()):
            for model_name, value_list in model_data.items():
                row_label = self.labeller.make_model_label(model_name, label)
                for values in value_list:
                    # NOTE(review): selection is looked up by label index in the
                    # flat list; this matches as long as the first model lists
                    # every label in the same order.
                    yield (
                        y,
                        row_label,
                        model_name,
                        label,
                        selection_list[idx],
                        values,
                        self.model_color[model_name],
                    )
                    y += self.chain_offset
                y += self.var_offset
            y += self.group_offset

    def labels_ticks_and_vals(self):
        """Get labels, ticks, values, and colors for the variable.

        Returns
        -------
        tuple of lists
            ``(labels, ticks, vals, colors, model_names)``, one entry per row,
            grouped by row label in first-seen order.
        """
        y_ticks = defaultdict(list)
        for y, label, model_name, _, _, vals, color in self.iterator():
            y_ticks[label].append((y, vals, color, model_name))
        labels, ticks, vals, colors, model_names = [], [], [], [], []
        for label, all_data in y_ticks.items():
            for tick, values, color, model_name in all_data:
                labels.append(label)
                ticks.append(tick)
                vals.append(np.array(values))
                model_names.append(model_name)
                colors.append(color)
        return labels, ticks, vals, colors, model_names

    def treeplot(self, qlist, hdi_prob):
        """Get data for each treeplot for the variable.

        The first and last percentiles of ``qlist`` are replaced by the
        unimodal HDI bounds.
        """
        for y, _, model_name, _, selection, values, color in self.iterator():
            ntiles = np.percentile(values.flatten(), qlist)
            ntiles[0], ntiles[-1] = hdi(values.flatten(), hdi_prob, multimodal=False)
            yield y, model_name, selection, ntiles, color

    def ridgeplot(self, hdi_prob, mult, ridgeplot_kind):
        """Get data for each ridgeplot for the variable.

        Yields ``(x, y, y + mult * density / scaling, hdi, cumulative_density,
        color, model_name)`` per row. ``scaling`` is the largest density over
        all rows of this variable.

        Raises
        ------
        ValueError
            If ``ridgeplot_kind`` is not ``"auto"``, ``"hist"`` or ``"density"``
            (``None`` is treated as ``"auto"``).
        """
        if ridgeplot_kind is None:
            ridgeplot_kind = "auto"
        if ridgeplot_kind not in ("auto", "hist", "density"):
            raise ValueError(
                "Argument 'ridgeplot_kind' must be one of 'auto', 'hist' or 'density' "
                f"(you provided {ridgeplot_kind})"
            )

        xvals, hdi_vals, yvals, pdfs, pdfs_q, colors, model_names = [], [], [], [], [], [], []

        for y, _, model_name, *_, values, color in self.iterator():
            yvals.append(y)
            colors.append(color)
            model_names.append(model_name)
            values = values.flatten()
            values = values[np.isfinite(values)]

            if hdi_prob != 1:
                hdi_ = hdi(values, hdi_prob, multimodal=False)
            else:
                hdi_ = min(values), max(values)

            if ridgeplot_kind == "auto":
                # Integer-valued samples are treated as discrete.
                kind = "hist" if np.all(np.mod(values, 1) == 0) else "density"
            else:
                kind = ridgeplot_kind

            if kind == "hist":
                bins = get_bins(values)
                _, density, x = histogram(values, bins=bins)
                x = x[:-1]  # left bin edges
            else:
                x, density = kde(values)

            density_q = density.cumsum() / density.sum()

            xvals.append(x)
            pdfs.append(density)
            pdfs_q.append(density_q)
            hdi_vals.append(hdi_)

        scaling = max(np.max(j) for j in pdfs)
        for y, x, hdi_val, pdf, pdf_q, color, model_name in zip(
            yvals, xvals, hdi_vals, pdfs, pdfs_q, colors, model_names
        ):
            yield x, y, mult * pdf / scaling + y, hdi_val, pdf_q, color, model_name

    def ess(self):
        """Get effective n data for the variable, one ``(y, ess, color, model)`` per row."""
        _, y_vals, values, colors, model_names = self.labels_ticks_and_vals()
        for y, value, color, model_name in zip(y_vals, values, colors, model_names):
            yield y, _ess(value), color, model_name

    def r_hat(self):
        """Get rhat data for the variable.

        Yields None as the R-hat for rows whose values are not 2-D or have
        fewer than two entries along the first axis.
        """
        _, y_vals, values, colors, model_names = self.labels_ticks_and_vals()
        for y, value, color, model_name in zip(y_vals, values, colors, model_names):
            if value.ndim != 2 or value.shape[0] < 2:
                yield y, None, color, model_name
            else:
                yield y, _rhat(value), color, model_name

    def y_max(self):
        """Get max y value for the variable, including trailing padding."""
        end_y = max(y for y, *_ in self.iterator())

        if self.combined:
            end_y += self.group_offset

        return end_y + 2 * self.group_offset