arviz 0.23.1__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185)
  1. arviz/__init__.py +52 -357
  2. arviz-1.0.0rc0.dist-info/METADATA +182 -0
  3. arviz-1.0.0rc0.dist-info/RECORD +5 -0
  4. {arviz-0.23.1.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
  5. {arviz-0.23.1.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
  6. arviz/data/__init__.py +0 -55
  7. arviz/data/base.py +0 -596
  8. arviz/data/converters.py +0 -203
  9. arviz/data/datasets.py +0 -161
  10. arviz/data/example_data/code/radon/radon.json +0 -326
  11. arviz/data/example_data/data/centered_eight.nc +0 -0
  12. arviz/data/example_data/data/non_centered_eight.nc +0 -0
  13. arviz/data/example_data/data_local.json +0 -12
  14. arviz/data/example_data/data_remote.json +0 -58
  15. arviz/data/inference_data.py +0 -2386
  16. arviz/data/io_beanmachine.py +0 -112
  17. arviz/data/io_cmdstan.py +0 -1036
  18. arviz/data/io_cmdstanpy.py +0 -1233
  19. arviz/data/io_datatree.py +0 -23
  20. arviz/data/io_dict.py +0 -462
  21. arviz/data/io_emcee.py +0 -317
  22. arviz/data/io_json.py +0 -54
  23. arviz/data/io_netcdf.py +0 -68
  24. arviz/data/io_numpyro.py +0 -497
  25. arviz/data/io_pyjags.py +0 -378
  26. arviz/data/io_pyro.py +0 -333
  27. arviz/data/io_pystan.py +0 -1095
  28. arviz/data/io_zarr.py +0 -46
  29. arviz/data/utils.py +0 -139
  30. arviz/labels.py +0 -210
  31. arviz/plots/__init__.py +0 -61
  32. arviz/plots/autocorrplot.py +0 -171
  33. arviz/plots/backends/__init__.py +0 -223
  34. arviz/plots/backends/bokeh/__init__.py +0 -166
  35. arviz/plots/backends/bokeh/autocorrplot.py +0 -101
  36. arviz/plots/backends/bokeh/bfplot.py +0 -23
  37. arviz/plots/backends/bokeh/bpvplot.py +0 -193
  38. arviz/plots/backends/bokeh/compareplot.py +0 -167
  39. arviz/plots/backends/bokeh/densityplot.py +0 -239
  40. arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
  41. arviz/plots/backends/bokeh/distplot.py +0 -183
  42. arviz/plots/backends/bokeh/dotplot.py +0 -113
  43. arviz/plots/backends/bokeh/ecdfplot.py +0 -73
  44. arviz/plots/backends/bokeh/elpdplot.py +0 -203
  45. arviz/plots/backends/bokeh/energyplot.py +0 -155
  46. arviz/plots/backends/bokeh/essplot.py +0 -176
  47. arviz/plots/backends/bokeh/forestplot.py +0 -772
  48. arviz/plots/backends/bokeh/hdiplot.py +0 -54
  49. arviz/plots/backends/bokeh/kdeplot.py +0 -268
  50. arviz/plots/backends/bokeh/khatplot.py +0 -163
  51. arviz/plots/backends/bokeh/lmplot.py +0 -185
  52. arviz/plots/backends/bokeh/loopitplot.py +0 -211
  53. arviz/plots/backends/bokeh/mcseplot.py +0 -184
  54. arviz/plots/backends/bokeh/pairplot.py +0 -328
  55. arviz/plots/backends/bokeh/parallelplot.py +0 -81
  56. arviz/plots/backends/bokeh/posteriorplot.py +0 -324
  57. arviz/plots/backends/bokeh/ppcplot.py +0 -379
  58. arviz/plots/backends/bokeh/rankplot.py +0 -149
  59. arviz/plots/backends/bokeh/separationplot.py +0 -107
  60. arviz/plots/backends/bokeh/traceplot.py +0 -436
  61. arviz/plots/backends/bokeh/violinplot.py +0 -164
  62. arviz/plots/backends/matplotlib/__init__.py +0 -124
  63. arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
  64. arviz/plots/backends/matplotlib/bfplot.py +0 -78
  65. arviz/plots/backends/matplotlib/bpvplot.py +0 -177
  66. arviz/plots/backends/matplotlib/compareplot.py +0 -135
  67. arviz/plots/backends/matplotlib/densityplot.py +0 -194
  68. arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
  69. arviz/plots/backends/matplotlib/distplot.py +0 -178
  70. arviz/plots/backends/matplotlib/dotplot.py +0 -116
  71. arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
  72. arviz/plots/backends/matplotlib/elpdplot.py +0 -189
  73. arviz/plots/backends/matplotlib/energyplot.py +0 -113
  74. arviz/plots/backends/matplotlib/essplot.py +0 -180
  75. arviz/plots/backends/matplotlib/forestplot.py +0 -656
  76. arviz/plots/backends/matplotlib/hdiplot.py +0 -48
  77. arviz/plots/backends/matplotlib/kdeplot.py +0 -177
  78. arviz/plots/backends/matplotlib/khatplot.py +0 -241
  79. arviz/plots/backends/matplotlib/lmplot.py +0 -149
  80. arviz/plots/backends/matplotlib/loopitplot.py +0 -144
  81. arviz/plots/backends/matplotlib/mcseplot.py +0 -161
  82. arviz/plots/backends/matplotlib/pairplot.py +0 -355
  83. arviz/plots/backends/matplotlib/parallelplot.py +0 -58
  84. arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
  85. arviz/plots/backends/matplotlib/ppcplot.py +0 -478
  86. arviz/plots/backends/matplotlib/rankplot.py +0 -119
  87. arviz/plots/backends/matplotlib/separationplot.py +0 -97
  88. arviz/plots/backends/matplotlib/traceplot.py +0 -526
  89. arviz/plots/backends/matplotlib/tsplot.py +0 -121
  90. arviz/plots/backends/matplotlib/violinplot.py +0 -148
  91. arviz/plots/bfplot.py +0 -128
  92. arviz/plots/bpvplot.py +0 -308
  93. arviz/plots/compareplot.py +0 -177
  94. arviz/plots/densityplot.py +0 -284
  95. arviz/plots/distcomparisonplot.py +0 -197
  96. arviz/plots/distplot.py +0 -233
  97. arviz/plots/dotplot.py +0 -233
  98. arviz/plots/ecdfplot.py +0 -372
  99. arviz/plots/elpdplot.py +0 -174
  100. arviz/plots/energyplot.py +0 -147
  101. arviz/plots/essplot.py +0 -319
  102. arviz/plots/forestplot.py +0 -304
  103. arviz/plots/hdiplot.py +0 -211
  104. arviz/plots/kdeplot.py +0 -357
  105. arviz/plots/khatplot.py +0 -236
  106. arviz/plots/lmplot.py +0 -380
  107. arviz/plots/loopitplot.py +0 -224
  108. arviz/plots/mcseplot.py +0 -194
  109. arviz/plots/pairplot.py +0 -281
  110. arviz/plots/parallelplot.py +0 -204
  111. arviz/plots/plot_utils.py +0 -599
  112. arviz/plots/posteriorplot.py +0 -298
  113. arviz/plots/ppcplot.py +0 -369
  114. arviz/plots/rankplot.py +0 -232
  115. arviz/plots/separationplot.py +0 -167
  116. arviz/plots/styles/arviz-bluish.mplstyle +0 -1
  117. arviz/plots/styles/arviz-brownish.mplstyle +0 -1
  118. arviz/plots/styles/arviz-colors.mplstyle +0 -2
  119. arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
  120. arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
  121. arviz/plots/styles/arviz-doc.mplstyle +0 -88
  122. arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
  123. arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
  124. arviz/plots/styles/arviz-greenish.mplstyle +0 -1
  125. arviz/plots/styles/arviz-orangish.mplstyle +0 -1
  126. arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
  127. arviz/plots/styles/arviz-purplish.mplstyle +0 -1
  128. arviz/plots/styles/arviz-redish.mplstyle +0 -1
  129. arviz/plots/styles/arviz-royish.mplstyle +0 -1
  130. arviz/plots/styles/arviz-viridish.mplstyle +0 -1
  131. arviz/plots/styles/arviz-white.mplstyle +0 -40
  132. arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
  133. arviz/plots/traceplot.py +0 -273
  134. arviz/plots/tsplot.py +0 -440
  135. arviz/plots/violinplot.py +0 -192
  136. arviz/preview.py +0 -58
  137. arviz/py.typed +0 -0
  138. arviz/rcparams.py +0 -606
  139. arviz/sel_utils.py +0 -223
  140. arviz/static/css/style.css +0 -340
  141. arviz/static/html/icons-svg-inline.html +0 -15
  142. arviz/stats/__init__.py +0 -37
  143. arviz/stats/density_utils.py +0 -1013
  144. arviz/stats/diagnostics.py +0 -1013
  145. arviz/stats/ecdf_utils.py +0 -324
  146. arviz/stats/stats.py +0 -2422
  147. arviz/stats/stats_refitting.py +0 -119
  148. arviz/stats/stats_utils.py +0 -609
  149. arviz/tests/__init__.py +0 -1
  150. arviz/tests/base_tests/__init__.py +0 -1
  151. arviz/tests/base_tests/test_data.py +0 -1679
  152. arviz/tests/base_tests/test_data_zarr.py +0 -143
  153. arviz/tests/base_tests/test_diagnostics.py +0 -511
  154. arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
  155. arviz/tests/base_tests/test_helpers.py +0 -18
  156. arviz/tests/base_tests/test_labels.py +0 -69
  157. arviz/tests/base_tests/test_plot_utils.py +0 -342
  158. arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
  159. arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
  160. arviz/tests/base_tests/test_rcparams.py +0 -317
  161. arviz/tests/base_tests/test_stats.py +0 -925
  162. arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
  163. arviz/tests/base_tests/test_stats_numba.py +0 -45
  164. arviz/tests/base_tests/test_stats_utils.py +0 -384
  165. arviz/tests/base_tests/test_utils.py +0 -376
  166. arviz/tests/base_tests/test_utils_numba.py +0 -87
  167. arviz/tests/conftest.py +0 -46
  168. arviz/tests/external_tests/__init__.py +0 -1
  169. arviz/tests/external_tests/test_data_beanmachine.py +0 -78
  170. arviz/tests/external_tests/test_data_cmdstan.py +0 -398
  171. arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
  172. arviz/tests/external_tests/test_data_emcee.py +0 -166
  173. arviz/tests/external_tests/test_data_numpyro.py +0 -434
  174. arviz/tests/external_tests/test_data_pyjags.py +0 -119
  175. arviz/tests/external_tests/test_data_pyro.py +0 -260
  176. arviz/tests/external_tests/test_data_pystan.py +0 -307
  177. arviz/tests/helpers.py +0 -677
  178. arviz/utils.py +0 -773
  179. arviz/wrappers/__init__.py +0 -13
  180. arviz/wrappers/base.py +0 -236
  181. arviz/wrappers/wrap_pymc.py +0 -36
  182. arviz/wrappers/wrap_stan.py +0 -148
  183. arviz-0.23.1.dist-info/METADATA +0 -263
  184. arviz-0.23.1.dist-info/RECORD +0 -183
  185. arviz-0.23.1.dist-info/top_level.txt +0 -1
@@ -1,656 +0,0 @@
1
- """Matplotlib forestplot."""
2
-
3
- from collections import OrderedDict, defaultdict
4
- from itertools import tee
5
-
6
- import matplotlib.pyplot as plt
7
- import numpy as np
8
- from matplotlib.colors import to_rgba
9
- from matplotlib.lines import Line2D
10
-
11
- from ....stats import hdi
12
- from ....stats.density_utils import get_bins, histogram, kde
13
- from ....stats.diagnostics import _ess, _rhat
14
- from ....sel_utils import xarray_var_iter
15
- from ...plot_utils import _scale_fig_size
16
- from . import backend_kwarg_defaults, backend_show
17
-
18
-
19
def pairwise(iterable):
    """Return an iterator over overlapping consecutive pairs of ``iterable``.

    ``[a, b, c, ...]`` becomes ``(a, b), (b, c), ...``. Fewer than two items
    produce no pairs. Based on the itertools documentation recipe.
    """
    current, ahead = tee(iterable)
    # Advance the second copy by one so it always runs a single step in front.
    next(ahead, None)
    pairs = zip(current, ahead)
    return pairs
24
-
25
-
26
def plot_forest(
    ax,
    datasets,
    var_names,
    model_names,
    combined,
    combine_dims,
    colors,
    figsize,
    width_ratios,
    linewidth,
    markersize,
    kind,
    ncols,
    hdi_prob,
    quartiles,
    rope,
    ridgeplot_overlap,
    ridgeplot_alpha,
    ridgeplot_kind,
    ridgeplot_truncate,
    ridgeplot_quantiles,
    textsize,
    legend,
    labeller,
    ess,
    r_hat,
    backend_kwargs,
    backend_config,  # pylint: disable=unused-argument
    show,
):
    """Matplotlib forest plot.

    Draws either a forest plot (``kind="forestplot"``) or a ridge plot
    (``kind="ridgeplot"``) on the first axes. Optionally adds an ESS column and
    an R-hat column on the following axes.

    Returns
    -------
    numpy.ndarray
        The axes that were drawn on, as produced by ``np.atleast_1d``.

    Raises
    ------
    TypeError
        If ``kind`` is neither "forestplot" nor "ridgeplot".
    """
    handler = PlotHandler(
        datasets,
        var_names=var_names,
        model_names=model_names,
        combined=combined,
        combine_dims=combine_dims,
        colors=colors,
        labeller=labeller,
    )

    if figsize is None:
        # Ridge plots get a wider and taller canvas than forest plots.
        is_ridge = kind == "ridgeplot"
        if is_ridge:
            width = min(14, sum(width_ratios) * 4)
            height = handler.fig_height() * 1.2
        else:
            width = min(12, sum(width_ratios) * 2)
            height = handler.fig_height()
        figsize = (width, height)

    figsize, _, titlesize, xt_labelsize, auto_linewidth, auto_markersize = _scale_fig_size(
        figsize, textsize, 1.1, 1
    )
    linewidth = auto_linewidth if linewidth is None else linewidth
    markersize = auto_markersize if markersize is None else markersize

    user_backend_kwargs = {} if backend_kwargs is None else backend_kwargs
    backend_kwargs = {**backend_kwarg_defaults(), **user_backend_kwargs}

    if ax is not None:
        axes = ax
    else:
        _, axes = plt.subplots(
            nrows=1,
            ncols=ncols,
            figsize=figsize,
            gridspec_kw={"width_ratios": width_ratios},
            sharey=True,
            **backend_kwargs,
        )
    axes = np.atleast_1d(axes)

    if kind == "forestplot":
        handler.forestplot(
            hdi_prob, quartiles, xt_labelsize, titlesize, linewidth, markersize, axes[0], rope
        )
    elif kind == "ridgeplot":
        handler.ridgeplot(
            hdi_prob,
            ridgeplot_overlap,
            linewidth,
            markersize,
            ridgeplot_alpha,
            ridgeplot_kind,
            ridgeplot_truncate,
            ridgeplot_quantiles,
            axes[0],
        )
    else:
        raise TypeError(
            f"Argument 'kind' must be one of 'forestplot' or 'ridgeplot' (you provided {kind})"
        )

    # Diagnostic columns fill the axes after the main one, ESS first, then R-hat.
    column = 1
    for enabled, draw_column in ((ess, handler.plot_neff), (r_hat, handler.plot_rhat)):
        if enabled:
            draw_column(axes[column], xt_labelsize, titlesize, markersize)
            column += 1

    several_models = len(handler.data) > 1
    for current_ax in axes:
        if kind == "ridgeplot":
            current_ax.grid(False)
        else:
            current_ax.grid(False, axis="y")
        # Hide y tick marks and the vertical spines.
        current_ax.tick_params(axis="y", left=False, right=False)
        for location, spine in current_ax.spines.items():
            if location in ["left", "right"]:
                spine.set_visible(False)
        if several_models:
            handler.make_bands(current_ax)

    labels, ticks = handler.labels_and_ticks()
    main_ax = axes[0]
    main_ax.set_yticks(ticks)
    main_ax.set_yticklabels(labels)

    plotters = list(handler.plotters.values())
    top = handler.y_max() - plotters[-1].group_offset
    if kind == "ridgeplot":
        top += ridgeplot_overlap  # head-room so the topmost ridge is not clipped
    main_ax.set_ylim(-plotters[0].group_offset, top)

    if legend:
        handler.legend(ax=main_ax)

    if backend_show(show):
        plt.show()

    return axes
172
-
173
-
174
class PlotHandler:
    """Coordinate the per-variable ``VarHandler`` objects that make up a forest/ridge plot.

    Model names, colors and variable names are stored reversed because y values grow
    upwards, so reversing keeps the first item at the top of the plot. One
    ``VarHandler`` is built per variable, stacked vertically. The drawing routines
    used by ``plot_forest`` live here.
    """

    def __init__(self, datasets, var_names, model_names, combined, combine_dims, colors, labeller):
        """Validate inputs and build one plotter per variable.

        Parameters
        ----------
        datasets : list
            One dataset per model; each must expose ``data_vars``.
        var_names : list of str or None
            Variables to plot. ``None`` plots every data variable found in any dataset,
            in first-seen order.
        model_names : list or None
            One name per dataset. ``None`` generates ``"Model {idx}"`` names when there is
            more than one dataset, and a single ``None`` name otherwise.
        combined : bool
            Forwarded to each ``VarHandler``.
        combine_dims : set-like
            Forwarded to each ``VarHandler``.
        colors : str or list
            ``"cycle"`` assigns ``"C0"``, ``"C1"``, ... per dataset. Any other string is
            used for every dataset. Otherwise one color per dataset is expected.
        labeller
            Object providing ``make_label_flat`` and ``make_model_label``.

        Raises
        ------
        ValueError
            If ``model_names`` is given and its length differs from ``datasets``.
        """
        self.data = datasets

        if model_names is None:
            if len(self.data) > 1:
                model_names = [f"Model {idx}" for idx, _ in enumerate(self.data)]
            else:
                model_names = [None]
        elif len(model_names) != len(self.data):
            raise ValueError("The number of model names does not match the number of models")

        self.model_names = list(reversed(model_names))  # y-values are upside down

        if var_names is None:
            # Ordered union of every dataset's data variables (first occurrence wins).
            # A plain set union made the row order depend on string hashing, so
            # multi-model plots could be ordered differently between runs.
            found = dict.fromkeys(name for datum in self.data for name in datum.data_vars)
            self.var_names = list(reversed(list(found)))  # y-values are upside down
        else:
            self.var_names = list(reversed(var_names))  # y-values are upside down

        self.combined = combined
        self.combine_dims = combine_dims

        if colors == "cycle":
            # TODO: Use matplotlib prop cycle instead
            colors = [f"C{idx}" for idx, _ in enumerate(self.data)]
        elif isinstance(colors, str):
            colors = [colors for _ in self.data]

        self.colors = list(reversed(colors))  # y-values are upside down
        self.labeller = labeller

        self.plotters = self.make_plotters()

    def make_plotters(self):
        """Create one ``VarHandler`` per variable, each starting where the previous one ends."""
        plotters, y = {}, 0
        for var_name in self.var_names:
            plotters[var_name] = VarHandler(
                var_name,
                self.data,
                y,
                model_names=self.model_names,
                combined=self.combined,
                combine_dims=self.combine_dims,
                colors=self.colors,
                labeller=self.labeller,
            )
            y = plotters[var_name].y_max()
        return plotters

    def labels_and_ticks(self):
        """Return y tick labels and positions collected from every plotter.

        Within one variable, rows sharing the same label collapse into a single tick
        placed at the mean of their y positions.
        """
        labels, ticks = [], []
        for plotter in self.plotters.values():
            sub_labels, sub_ticks, _, _ = plotter.labels_ticks_and_vals()
            ticks_by_label = defaultdict(list)
            for label, tick in zip(sub_labels, sub_ticks):
                ticks_by_label[label].append(tick)
            labels.append(list(ticks_by_label))
            ticks.append([np.mean(positions) for positions in ticks_by_label.values()])
        return np.concatenate(labels), np.concatenate(ticks)

    def legend(self, ax):
        """Add a legend with one colored line per model."""
        handles = [Line2D([], [], color=c) for c in self.colors]
        ax.legend(handles=handles, labels=self.model_names)

    def display_multiple_ropes(self, rope, ax, y, linewidth, var_name, selection):
        """Draw every ROPE entry in ``rope[var_name]`` that matches ``selection``.

        ``rope[var_name]`` is a list of dicts. Each dict holds the interval under the
        ``"rope"`` key; every other key/value pair must equal the corresponding entry of
        ``selection`` for that interval to be drawn at height ``y``.
        """
        for sel in rope.get(var_name, []):
            if all(k in selection and selection[k] == v for k, v in sel.items() if k != "rope"):
                ax.plot(
                    sel["rope"],
                    (y + 0.05, y + 0.05),
                    lw=linewidth * 2,
                    color="C2",
                    solid_capstyle="round",
                    zorder=0,
                    alpha=0.7,
                )
        return ax

    def ridgeplot(
        self,
        hdi_prob,
        mult,
        linewidth,
        markersize,
        alpha,
        ridgeplot_kind,
        ridgeplot_truncate,
        ridgeplot_quantiles,
        ax,
    ):
        """Draw a ridge (density or histogram) for every row of every plotter.

        Parameters
        ----------
        hdi_prob : float
            Probability for the highest density interval.
        mult : float
            How much to multiply height by. Values greater than 1 make ridges overlap.
        linewidth : float
            Width of line on border of ridges.
        markersize : float
            Size of the quantile markers.
        alpha : float or None
            Transparency of ridges. ``None`` means 1.0. With 0, the border takes the
            ridge color instead of black.
        ridgeplot_kind : str
            "auto" (KDE for continuous, histogram for discrete), "hist" or "density".
        ridgeplot_truncate : bool
            Whether to truncate densities to the HDI given by ``hdi_prob``.
        ridgeplot_quantiles : list or None
            Quantiles in ascending order marked on each ridge, e.g. [.25, .5, .75].
        ax : Axes
            Axes to draw on.
        """
        if alpha is None:
            alpha = 1.0
        zorder = 0
        for plotter in self.plotters.values():
            for x, y_min, y_max, hdi_, y_q, color in plotter.ridgeplot(
                hdi_prob, mult, ridgeplot_kind
            ):
                # ``y_q`` (cumulative density) is aligned with the untruncated grid, so keep
                # it, plus the scalar baseline, for placing quantile markers below.
                x_full = x
                baseline = y_min

                if alpha == 0:
                    border = color
                    facecolor = "None"
                else:
                    border = "k"

                in_hdi = (x >= hdi_[0]) & (x <= hdi_[1])
                if x.dtype.kind == "i":
                    if ridgeplot_truncate:
                        facecolor = to_rgba(color, alpha)
                        y_max = y_max[in_hdi]
                        x = x[in_hdi]
                    else:
                        # Only bars inside the HDI are filled.
                        facecolor = [to_rgba(color, alpha) if ci else "None" for ci in in_hdi]
                    y_min = np.ones_like(x) * y_min
                    ax.bar(
                        x,
                        y_max - y_min,
                        bottom=y_min,
                        linewidth=linewidth,
                        ec=border,
                        color=facecolor,
                        alpha=None,
                        zorder=zorder,
                    )
                else:
                    tr_x = x[in_hdi]
                    tr_y_min = np.ones_like(tr_x) * y_min
                    tr_y_max = y_max[in_hdi]
                    y_min = np.ones_like(x) * y_min
                    if ridgeplot_truncate:
                        ax.plot(tr_x, tr_y_max, "-", linewidth=linewidth, color=border, zorder=zorder)
                        ax.plot(tr_x, tr_y_min, "-", linewidth=linewidth, color=border, zorder=zorder)
                    else:
                        ax.plot(x, y_max, "-", linewidth=linewidth, color=border, zorder=zorder)
                        ax.plot(x, y_min, "-", linewidth=linewidth, color=border, zorder=zorder)
                    # The fill always covers only the HDI.
                    ax.fill_between(tr_x, tr_y_max, tr_y_min, alpha=alpha, color=color, zorder=zorder)

                if ridgeplot_quantiles is not None:
                    # Clip the index: float rounding can leave every y_q entry below 1.0.
                    last = len(x_full) - 1
                    quantiles = [
                        x_full[min(int(np.sum(y_q < quant)), last)] for quant in ridgeplot_quantiles
                    ]
                    ax.plot(
                        quantiles,
                        np.full(len(quantiles), baseline),
                        "d",
                        mfc=border,
                        mec=border,
                        ms=markersize,
                    )
                # Each later ridge is drawn beneath the previous ones.
                zorder -= 1
        return ax

    def forestplot(
        self, hdi_prob, quartiles, xt_labelsize, titlesize, linewidth, markersize, ax, rope
    ):
        """Draw the HDI lines and median markers for every row of every plotter.

        Parameters
        ----------
        hdi_prob : float
            Probability for the highest density interval drawn as the outer line.
        quartiles : bool
            Whether to also draw the interquartile range as a thicker inner line.
        xt_labelsize : float
            Size of tick text.
        titlesize : float
            Size of title text.
        linewidth : float
            Width of the outermost line; the innermost is twice as wide.
        markersize : float
            Size reference for the median marker (drawn at 0.75x).
        ax : Axes
            Axes to draw on.
        rope : None, dict or sequence of length 2
            ``None`` for no ROPE. A dict maps ``var_name`` to a list of dicts with a
            ``"rope"`` interval plus selection keys. A pair ``(lo, hi)`` shades one
            vertical band.

        Returns
        -------
        Axes
            ``ax``.

        Raises
        ------
        ValueError
            If ``rope`` is not None, a dict, or of length 2.
        """
        # Percentiles to compute; the outer pair is later replaced with the HDI bounds.
        endpoint = 100 * (1 - hdi_prob) / 2
        if quartiles:
            qlist = [endpoint, 25, 50, 75, 100 - endpoint]
        else:
            qlist = [endpoint, 50, 100 - endpoint]

        for plotter in self.plotters.values():
            for y, selection, values, color in plotter.treeplot(qlist, hdi_prob):
                if isinstance(rope, dict):
                    self.display_multiple_ropes(rope, ax, y, linewidth, plotter.var_name, selection)

                mid = len(values) // 2
                # Outermost interval (j=0) gets ``linewidth``, innermost gets 2 * linewidth.
                widths = np.linspace(2 * linewidth, linewidth, mid, endpoint=True)[::-1]
                for width, j in zip(widths, range(mid)):
                    ax.hlines(y, values[j], values[-(j + 1)], linewidth=width, color=color)
                ax.plot(
                    values[mid],
                    y,
                    "o",
                    mfc=ax.get_facecolor(),
                    markersize=markersize * 0.75,
                    color=color,
                )
        ax.tick_params(labelsize=xt_labelsize)
        ax.set_title(f"{hdi_prob:.1%} HDI", fontsize=titlesize, wrap=True)

        if rope is None or isinstance(rope, dict):
            return ax
        if len(rope) != 2:
            raise ValueError(
                "Argument `rope` must be None, a dictionary like "
                '{"var_name": [{"rope": (lo, hi)}]}, or an '
                "iterable of length 2"
            )
        # axvspan's ymin/ymax are axes fractions, so (0, 1) spans the full height.
        ax.axvspan(rope[0], rope[1], 0, 1, color="C2", alpha=0.5)
        return ax

    def plot_neff(self, ax, xt_labelsize, titlesize, markersize):
        """Draw the effective sample size of every row as a dot on ``ax``."""
        for plotter in self.plotters.values():
            for y, ess_value, color in plotter.ess():
                if ess_value is not None:
                    ax.plot(
                        ess_value,
                        y,
                        "o",
                        color=color,
                        clip_on=False,
                        markersize=markersize,
                        markeredgecolor="k",
                    )
        ax.set_xlim(left=0)
        ax.set_title("ess", fontsize=titlesize, wrap=True)
        ax.tick_params(labelsize=xt_labelsize)
        return ax

    def plot_rhat(self, ax, xt_labelsize, titlesize, markersize):
        """Draw R-hat of every row that has one as a dot on ``ax`` (x range 0.9 to 2.1)."""
        for plotter in self.plotters.values():
            for y, r_hat, color in plotter.r_hat():
                if r_hat is not None:
                    ax.plot(r_hat, y, "o", color=color, markersize=markersize, markeredgecolor="k")
        ax.set_xlim(left=0.9, right=2.1)
        ax.set_xticks([1, 2])
        ax.tick_params(labelsize=xt_labelsize)
        ax.set_title("r_hat", fontsize=titlesize, wrap=True)
        return ax

    def make_bands(self, ax):
        """Shade alternating horizontal bands to visually separate row groups.

        A band boundary is placed halfway between two consecutive rows whenever the
        color index drops, i.e. iteration wraps back to an earlier model.
        """
        if not self.plotters:
            return ax
        y_vals, y_prev, is_zero = [0], None, False
        prev_color_index = 0
        for plotter in self.plotters.values():
            for y, *_, color in plotter.iterator():
                color_index = self.colors.index(color)
                if color_index < prev_color_index:
                    if not is_zero and y_prev is not None:
                        y_vals.append((y + y_prev) * 0.5)
                        is_zero = True
                else:
                    is_zero = False
                prev_color_index = color_index
                y_prev = y

        offset = plotter.group_offset  # pylint: disable=undefined-loop-variable

        y_vals.append(y_prev + offset)
        for idx, (y_start, y_stop) in enumerate(pairwise(y_vals)):
            ax.axhspan(y_start, y_stop, color="k", alpha=0.1 * (idx % 2))
        return ax

    def fig_height(self):
        """Return a hand-tuned figure height from the number of models, variables and rows."""
        return (
            4
            + len(self.data) * len(self.var_names)
            - 1
            + 0.1 * sum(1 for j in self.plotters.values() for _ in j.iterator())
        )

    def y_max(self):
        """Return the largest ``y_max()`` over all plotters."""
        return max(p.y_max() for p in self.plotters.values())
509
-
510
-
511
# pylint: disable=too-many-instance-attributes
class VarHandler:
    """Lay out and summarise one variable across all models (and chains).

    Rows are stacked upwards from ``y_start``. The generator methods yield the
    per-row data that the ``PlotHandler`` drawing methods consume.
    """

    def __init__(
        self, var_name, data, y_start, model_names, combined, combine_dims, colors, labeller
    ):
        self.var_name = var_name
        self.data = data
        self.y_start = y_start
        self.model_names = model_names
        self.combined = combined
        self.combine_dims = combine_dims
        self.colors = colors
        self.labeller = labeller
        # model_names and colors are paired by position.
        self.model_color = dict(zip(self.model_names, self.colors))
        # NOTE(review): assumes the ``chain`` coordinate holds integer chain ids so its
        # maximum tracks the chain count — confirm for relabelled chains.
        max_chains = max(datum.chain.max().values for datum in data)
        # Vertical spacing added after each row, after each model within a label, and
        # after each label group, respectively (see ``iterator``).
        self.chain_offset = len(data) * 0.45 / max(1, max_chains)
        self.var_offset = 1.5 * self.chain_offset
        self.group_offset = 2 * self.var_offset

    def iterator(self):
        """Yield one row per (label, model, chain) for this variable.

        Yields
        ------
        tuple
            ``(y, row_label, label, selection, values, color)``. ``y`` grows by
            ``chain_offset`` per row, ``var_offset`` per model and ``group_offset``
            per label.
        """
        if self.combined:
            # One group per dataset; the chain dimension is skipped (not split out).
            grouped_data = [[(0, datum)] for datum in self.data]
            skip_dims = self.combine_dims.union({"chain"})
        else:
            grouped_data = [datum.groupby("chain", squeeze=False) for datum in self.data]
            skip_dims = self.combine_dims

        label_dict = OrderedDict()
        selection_list = []
        for name, grouped_datum in zip(self.model_names, grouped_data):
            for _, sub_data in grouped_datum:
                datum_list = list(
                    xarray_var_iter(
                        sub_data.squeeze(),
                        var_names=[self.var_name],
                        skip_dims=skip_dims,
                        reverse_selections=True,
                    )
                )
                for _, selection, isel, values in datum_list:
                    selection_list.append(selection)
                    # The variable name goes into the label only when there is no
                    # selection, or when the running selection count is a multiple of
                    # the number of selections in this group.
                    if not selection or not len(selection_list) % len(datum_list):
                        var_name = self.var_name
                    else:
                        var_name = ""
                    label = self.labeller.make_label_flat(var_name, selection, isel)
                    label_dict.setdefault(label, OrderedDict()).setdefault(name, []).append(values)

        y = self.y_start
        for idx, (label, model_data) in enumerate(label_dict.items()):
            for model_name, value_list in model_data.items():
                row_label = self.labeller.make_model_label(model_name, label)
                for values in value_list:
                    # NOTE(review): ``selection_list[idx]`` assumes labels first appear in
                    # the same order as the first group's selections.
                    yield (
                        y,
                        row_label,
                        label,
                        selection_list[idx],
                        values,
                        self.model_color[model_name],
                    )
                    y += self.chain_offset
                y += self.var_offset
            y += self.group_offset

    def labels_ticks_and_vals(self):
        """Return parallel lists of row labels, y positions, value arrays and colors.

        Rows are grouped by row label (first-seen order) before being flattened.
        """
        rows_by_label = defaultdict(list)
        for y, label, _, _, vals, color in self.iterator():
            rows_by_label[label].append((y, vals, color))
        labels, ticks, vals, colors = [], [], [], []
        for label, rows in rows_by_label.items():
            for tick, row_vals, color in rows:
                labels.append(label)
                ticks.append(tick)
                vals.append(np.array(row_vals))
                colors.append(color)  # the row's own model color
        return labels, ticks, vals, colors

    def treeplot(self, qlist, hdi_prob):
        """Yield ``(y, selection, ntiles, color)`` for each row.

        ``ntiles`` holds the ``qlist`` percentiles of the flattened values, with the
        first and last entries replaced by the unimodal HDI bounds.
        """
        for y, _, _, selection, values, color in self.iterator():
            flat = values.flatten()
            ntiles = np.percentile(flat, qlist)
            ntiles[0], ntiles[-1] = hdi(flat, hdi_prob, multimodal=False)
            yield y, selection, ntiles, color

    def ridgeplot(self, hdi_prob, mult, ridgeplot_kind):
        """Yield ridge geometry for each row of this variable.

        Non-finite values are dropped first. With ``hdi_prob == 1`` the interval is
        the data min/max.

        Yields
        ------
        tuple
            ``(x, y, y + mult * pdf / scaling, hdi, cumulative_pdf, color)``, where
            ``scaling`` is the largest density value across all rows of this variable.

        Raises
        ------
        ValueError
            If ``ridgeplot_kind`` is not "auto", "hist" or "density".
        """
        # Validate up front: an unknown kind used to fall through and either raise
        # UnboundLocalError or silently reuse the previous row's x/density.
        if ridgeplot_kind not in ("auto", "hist", "density"):
            raise ValueError(
                "Argument 'ridgeplot_kind' must be one of 'auto', 'hist' or 'density' "
                f"(you provided {ridgeplot_kind})"
            )

        xvals, hdi_vals, yvals, pdfs, pdfs_q, colors = [], [], [], [], [], []
        for y, *_, values, color in self.iterator():
            yvals.append(y)
            colors.append(color)
            values = values.flatten()
            values = values[np.isfinite(values)]

            if hdi_prob != 1:
                hdi_ = hdi(values, hdi_prob, multimodal=False)
            else:
                hdi_ = min(values), max(values)

            if ridgeplot_kind == "auto":
                # Integer-valued samples are drawn as histograms.
                kind = "hist" if np.all(np.mod(values, 1) == 0) else "density"
            else:
                kind = ridgeplot_kind

            if kind == "hist":
                _, density, x = histogram(values, bins=get_bins(values))
                x = x[:-1]  # left bin edges, one per density value
            else:
                x, density = kde(values)

            density_q = density.cumsum() / density.sum()

            xvals.append(x)
            pdfs.append(density)
            pdfs_q.append(density_q)
            hdi_vals.append(hdi_)

        scaling = max(np.max(j) for j in pdfs)
        for y, x, hdi_val, pdf, pdf_q, color in zip(yvals, xvals, hdi_vals, pdfs, pdfs_q, colors):
            yield x, y, mult * pdf / scaling + y, hdi_val, pdf_q, color

    def ess(self):
        """Yield ``(y, ess, color)`` for each row, with ESS computed by ``_ess``."""
        _, y_vals, values, colors = self.labels_ticks_and_vals()
        for y, value, color in zip(y_vals, values, colors):
            yield y, _ess(value), color

    def r_hat(self):
        """Yield ``(y, r_hat, color)`` for each row.

        ``r_hat`` is None unless the row's values are 2-D with at least two entries
        along the first axis.
        """
        _, y_vals, values, colors = self.labels_ticks_and_vals()
        for y, value, color in zip(y_vals, values, colors):
            if value.ndim != 2 or value.shape[0] < 2:
                yield y, None, color
            else:
                yield y, _rhat(value), color

    def y_max(self):
        """Return the top y extent of this variable, including trailing padding."""
        end_y = max(y for y, *_ in self.iterator())

        if self.combined:
            end_y += self.group_offset

        return end_y + 2 * self.group_offset
@@ -1,48 +0,0 @@
1
- """Matplotlib hdiplot."""
2
-
3
- import matplotlib.pyplot as plt
4
- from matplotlib import _pylab_helpers
5
-
6
- from ...plot_utils import _scale_fig_size, vectorized_to_hex
7
- from . import backend_kwarg_defaults, backend_show, create_axes_grid, matplotlib_kwarg_dealiaser
8
-
9
-
10
def plot_hdi(ax, x_data, y_data, color, figsize, plot_kwargs, fill_kwargs, backend_kwargs, show):
    """Matplotlib HDI plot.

    Draws ``y_data`` against ``x_data`` as lines, invisible by default (alpha 0). Fills
    the band between ``y_data[:, 0]`` and ``y_data[:, 1]``, semi-transparent by default
    (alpha 0.5). Without ``ax``, it reuses the current axes of the active figure, or
    creates a new single-axes grid when no figure is active.

    Returns
    -------
    Axes
        The axes drawn on.
    """
    extra_backend_kwargs = {} if backend_kwargs is None else backend_kwargs
    backend_kwargs = {**backend_kwarg_defaults(), **extra_backend_kwargs}

    line_kwargs = matplotlib_kwarg_dealiaser(plot_kwargs, "plot")
    line_kwargs["color"] = vectorized_to_hex(line_kwargs.get("color", color))
    line_kwargs.setdefault("alpha", 0)

    band_kwargs = matplotlib_kwarg_dealiaser(fill_kwargs, "fill_between")
    band_kwargs["color"] = vectorized_to_hex(band_kwargs.get("color", color))
    band_kwargs.setdefault("alpha", 0.5)

    scaled_figsize = _scale_fig_size(figsize, None)[0]
    backend_kwargs.setdefault("figsize", scaled_figsize)
    backend_kwargs["squeeze"] = True

    if ax is None:
        active_manager = _pylab_helpers.Gcf.get_active()
        if active_manager is None:
            _, ax = create_axes_grid(1, backend_kwargs=backend_kwargs)
        else:
            ax = active_manager.canvas.figure.gca()

    ax.plot(x_data, y_data, **line_kwargs)
    ax.fill_between(x_data, y_data[:, 0], y_data[:, 1], **band_kwargs)

    if backend_show(show):
        plt.show()

    return ax