arviz 0.23.1__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185) hide show
  1. arviz/__init__.py +52 -357
  2. arviz-1.0.0rc0.dist-info/METADATA +182 -0
  3. arviz-1.0.0rc0.dist-info/RECORD +5 -0
  4. {arviz-0.23.1.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
  5. {arviz-0.23.1.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
  6. arviz/data/__init__.py +0 -55
  7. arviz/data/base.py +0 -596
  8. arviz/data/converters.py +0 -203
  9. arviz/data/datasets.py +0 -161
  10. arviz/data/example_data/code/radon/radon.json +0 -326
  11. arviz/data/example_data/data/centered_eight.nc +0 -0
  12. arviz/data/example_data/data/non_centered_eight.nc +0 -0
  13. arviz/data/example_data/data_local.json +0 -12
  14. arviz/data/example_data/data_remote.json +0 -58
  15. arviz/data/inference_data.py +0 -2386
  16. arviz/data/io_beanmachine.py +0 -112
  17. arviz/data/io_cmdstan.py +0 -1036
  18. arviz/data/io_cmdstanpy.py +0 -1233
  19. arviz/data/io_datatree.py +0 -23
  20. arviz/data/io_dict.py +0 -462
  21. arviz/data/io_emcee.py +0 -317
  22. arviz/data/io_json.py +0 -54
  23. arviz/data/io_netcdf.py +0 -68
  24. arviz/data/io_numpyro.py +0 -497
  25. arviz/data/io_pyjags.py +0 -378
  26. arviz/data/io_pyro.py +0 -333
  27. arviz/data/io_pystan.py +0 -1095
  28. arviz/data/io_zarr.py +0 -46
  29. arviz/data/utils.py +0 -139
  30. arviz/labels.py +0 -210
  31. arviz/plots/__init__.py +0 -61
  32. arviz/plots/autocorrplot.py +0 -171
  33. arviz/plots/backends/__init__.py +0 -223
  34. arviz/plots/backends/bokeh/__init__.py +0 -166
  35. arviz/plots/backends/bokeh/autocorrplot.py +0 -101
  36. arviz/plots/backends/bokeh/bfplot.py +0 -23
  37. arviz/plots/backends/bokeh/bpvplot.py +0 -193
  38. arviz/plots/backends/bokeh/compareplot.py +0 -167
  39. arviz/plots/backends/bokeh/densityplot.py +0 -239
  40. arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
  41. arviz/plots/backends/bokeh/distplot.py +0 -183
  42. arviz/plots/backends/bokeh/dotplot.py +0 -113
  43. arviz/plots/backends/bokeh/ecdfplot.py +0 -73
  44. arviz/plots/backends/bokeh/elpdplot.py +0 -203
  45. arviz/plots/backends/bokeh/energyplot.py +0 -155
  46. arviz/plots/backends/bokeh/essplot.py +0 -176
  47. arviz/plots/backends/bokeh/forestplot.py +0 -772
  48. arviz/plots/backends/bokeh/hdiplot.py +0 -54
  49. arviz/plots/backends/bokeh/kdeplot.py +0 -268
  50. arviz/plots/backends/bokeh/khatplot.py +0 -163
  51. arviz/plots/backends/bokeh/lmplot.py +0 -185
  52. arviz/plots/backends/bokeh/loopitplot.py +0 -211
  53. arviz/plots/backends/bokeh/mcseplot.py +0 -184
  54. arviz/plots/backends/bokeh/pairplot.py +0 -328
  55. arviz/plots/backends/bokeh/parallelplot.py +0 -81
  56. arviz/plots/backends/bokeh/posteriorplot.py +0 -324
  57. arviz/plots/backends/bokeh/ppcplot.py +0 -379
  58. arviz/plots/backends/bokeh/rankplot.py +0 -149
  59. arviz/plots/backends/bokeh/separationplot.py +0 -107
  60. arviz/plots/backends/bokeh/traceplot.py +0 -436
  61. arviz/plots/backends/bokeh/violinplot.py +0 -164
  62. arviz/plots/backends/matplotlib/__init__.py +0 -124
  63. arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
  64. arviz/plots/backends/matplotlib/bfplot.py +0 -78
  65. arviz/plots/backends/matplotlib/bpvplot.py +0 -177
  66. arviz/plots/backends/matplotlib/compareplot.py +0 -135
  67. arviz/plots/backends/matplotlib/densityplot.py +0 -194
  68. arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
  69. arviz/plots/backends/matplotlib/distplot.py +0 -178
  70. arviz/plots/backends/matplotlib/dotplot.py +0 -116
  71. arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
  72. arviz/plots/backends/matplotlib/elpdplot.py +0 -189
  73. arviz/plots/backends/matplotlib/energyplot.py +0 -113
  74. arviz/plots/backends/matplotlib/essplot.py +0 -180
  75. arviz/plots/backends/matplotlib/forestplot.py +0 -656
  76. arviz/plots/backends/matplotlib/hdiplot.py +0 -48
  77. arviz/plots/backends/matplotlib/kdeplot.py +0 -177
  78. arviz/plots/backends/matplotlib/khatplot.py +0 -241
  79. arviz/plots/backends/matplotlib/lmplot.py +0 -149
  80. arviz/plots/backends/matplotlib/loopitplot.py +0 -144
  81. arviz/plots/backends/matplotlib/mcseplot.py +0 -161
  82. arviz/plots/backends/matplotlib/pairplot.py +0 -355
  83. arviz/plots/backends/matplotlib/parallelplot.py +0 -58
  84. arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
  85. arviz/plots/backends/matplotlib/ppcplot.py +0 -478
  86. arviz/plots/backends/matplotlib/rankplot.py +0 -119
  87. arviz/plots/backends/matplotlib/separationplot.py +0 -97
  88. arviz/plots/backends/matplotlib/traceplot.py +0 -526
  89. arviz/plots/backends/matplotlib/tsplot.py +0 -121
  90. arviz/plots/backends/matplotlib/violinplot.py +0 -148
  91. arviz/plots/bfplot.py +0 -128
  92. arviz/plots/bpvplot.py +0 -308
  93. arviz/plots/compareplot.py +0 -177
  94. arviz/plots/densityplot.py +0 -284
  95. arviz/plots/distcomparisonplot.py +0 -197
  96. arviz/plots/distplot.py +0 -233
  97. arviz/plots/dotplot.py +0 -233
  98. arviz/plots/ecdfplot.py +0 -372
  99. arviz/plots/elpdplot.py +0 -174
  100. arviz/plots/energyplot.py +0 -147
  101. arviz/plots/essplot.py +0 -319
  102. arviz/plots/forestplot.py +0 -304
  103. arviz/plots/hdiplot.py +0 -211
  104. arviz/plots/kdeplot.py +0 -357
  105. arviz/plots/khatplot.py +0 -236
  106. arviz/plots/lmplot.py +0 -380
  107. arviz/plots/loopitplot.py +0 -224
  108. arviz/plots/mcseplot.py +0 -194
  109. arviz/plots/pairplot.py +0 -281
  110. arviz/plots/parallelplot.py +0 -204
  111. arviz/plots/plot_utils.py +0 -599
  112. arviz/plots/posteriorplot.py +0 -298
  113. arviz/plots/ppcplot.py +0 -369
  114. arviz/plots/rankplot.py +0 -232
  115. arviz/plots/separationplot.py +0 -167
  116. arviz/plots/styles/arviz-bluish.mplstyle +0 -1
  117. arviz/plots/styles/arviz-brownish.mplstyle +0 -1
  118. arviz/plots/styles/arviz-colors.mplstyle +0 -2
  119. arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
  120. arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
  121. arviz/plots/styles/arviz-doc.mplstyle +0 -88
  122. arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
  123. arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
  124. arviz/plots/styles/arviz-greenish.mplstyle +0 -1
  125. arviz/plots/styles/arviz-orangish.mplstyle +0 -1
  126. arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
  127. arviz/plots/styles/arviz-purplish.mplstyle +0 -1
  128. arviz/plots/styles/arviz-redish.mplstyle +0 -1
  129. arviz/plots/styles/arviz-royish.mplstyle +0 -1
  130. arviz/plots/styles/arviz-viridish.mplstyle +0 -1
  131. arviz/plots/styles/arviz-white.mplstyle +0 -40
  132. arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
  133. arviz/plots/traceplot.py +0 -273
  134. arviz/plots/tsplot.py +0 -440
  135. arviz/plots/violinplot.py +0 -192
  136. arviz/preview.py +0 -58
  137. arviz/py.typed +0 -0
  138. arviz/rcparams.py +0 -606
  139. arviz/sel_utils.py +0 -223
  140. arviz/static/css/style.css +0 -340
  141. arviz/static/html/icons-svg-inline.html +0 -15
  142. arviz/stats/__init__.py +0 -37
  143. arviz/stats/density_utils.py +0 -1013
  144. arviz/stats/diagnostics.py +0 -1013
  145. arviz/stats/ecdf_utils.py +0 -324
  146. arviz/stats/stats.py +0 -2422
  147. arviz/stats/stats_refitting.py +0 -119
  148. arviz/stats/stats_utils.py +0 -609
  149. arviz/tests/__init__.py +0 -1
  150. arviz/tests/base_tests/__init__.py +0 -1
  151. arviz/tests/base_tests/test_data.py +0 -1679
  152. arviz/tests/base_tests/test_data_zarr.py +0 -143
  153. arviz/tests/base_tests/test_diagnostics.py +0 -511
  154. arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
  155. arviz/tests/base_tests/test_helpers.py +0 -18
  156. arviz/tests/base_tests/test_labels.py +0 -69
  157. arviz/tests/base_tests/test_plot_utils.py +0 -342
  158. arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
  159. arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
  160. arviz/tests/base_tests/test_rcparams.py +0 -317
  161. arviz/tests/base_tests/test_stats.py +0 -925
  162. arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
  163. arviz/tests/base_tests/test_stats_numba.py +0 -45
  164. arviz/tests/base_tests/test_stats_utils.py +0 -384
  165. arviz/tests/base_tests/test_utils.py +0 -376
  166. arviz/tests/base_tests/test_utils_numba.py +0 -87
  167. arviz/tests/conftest.py +0 -46
  168. arviz/tests/external_tests/__init__.py +0 -1
  169. arviz/tests/external_tests/test_data_beanmachine.py +0 -78
  170. arviz/tests/external_tests/test_data_cmdstan.py +0 -398
  171. arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
  172. arviz/tests/external_tests/test_data_emcee.py +0 -166
  173. arviz/tests/external_tests/test_data_numpyro.py +0 -434
  174. arviz/tests/external_tests/test_data_pyjags.py +0 -119
  175. arviz/tests/external_tests/test_data_pyro.py +0 -260
  176. arviz/tests/external_tests/test_data_pystan.py +0 -307
  177. arviz/tests/helpers.py +0 -677
  178. arviz/utils.py +0 -773
  179. arviz/wrappers/__init__.py +0 -13
  180. arviz/wrappers/base.py +0 -236
  181. arviz/wrappers/wrap_pymc.py +0 -36
  182. arviz/wrappers/wrap_stan.py +0 -148
  183. arviz-0.23.1.dist-info/METADATA +0 -263
  184. arviz-0.23.1.dist-info/RECORD +0 -183
  185. arviz-0.23.1.dist-info/top_level.txt +0 -1
@@ -1,324 +0,0 @@
1
- """Bokeh Plot posterior densities."""
2
-
3
- from numbers import Number
4
- from typing import Optional
5
-
6
- import numpy as np
7
- from bokeh.models.annotations import Title
8
-
9
- from ....stats import hdi
10
- from ....stats.density_utils import get_bins, histogram
11
- from ...kdeplot import plot_kde
12
- from ...plot_utils import (
13
- _scale_fig_size,
14
- calculate_point_estimate,
15
- format_sig_figs,
16
- round_num,
17
- vectorized_to_hex,
18
- )
19
- from .. import show_layout
20
- from . import backend_kwarg_defaults, create_axes_grid
21
-
22
-
23
def plot_posterior(
    ax,
    length_plotters,
    rows,
    cols,
    figsize,
    plotters,
    bw,
    circular,
    bins,
    kind,
    point_estimate,
    round_to,
    hdi_prob,
    multimodal,
    skipna,
    textsize,
    ref_val,
    rope,
    ref_val_color,
    rope_color,
    labeller,
    kwargs,
    backend_kwargs,
    show,
):
    """Bokeh posterior plot.

    Draws one posterior panel per entry of ``plotters``, pairing entries with
    the non-``None`` figures of ``ax`` in order, titles each panel with
    ``labeller.make_label_vert`` and finally hands the grid to ``show_layout``.

    Returns
    -------
    The grid of Bokeh figures: either the one created here via
    ``create_axes_grid`` or the user-supplied ``ax`` promoted to 2D.
    """
    user_backend_kwargs = {} if backend_kwargs is None else backend_kwargs
    # User-supplied backend options take precedence over rcParams defaults.
    backend_kwargs = {
        **backend_kwarg_defaults(("dpi", "plot.bokeh.figure.dpi")),
        **user_backend_kwargs,
    }

    figsize, ax_labelsize, *_unused_sizes, linewidth, _markersize = _scale_fig_size(
        figsize, textsize, rows, cols
    )

    if ax is not None:
        ax = np.atleast_2d(ax)
    else:
        ax = create_axes_grid(
            length_plotters,
            rows,
            cols,
            figsize=figsize,
            backend_kwargs=backend_kwargs,
        )

    # Skip empty grid cells so panels line up with the plotters in order.
    live_figures = (fig for fig in ax.flatten() if fig is not None)
    for panel_idx, ((var_name, selection, isel, x), fig) in enumerate(
        zip(plotters, live_figures)
    ):
        _plot_posterior_op(
            panel_idx,
            x.flatten(),
            var_name,
            selection,
            ax=fig,
            bw=bw,
            circular=circular,
            bins=bins,
            kind=kind,
            point_estimate=point_estimate,
            round_to=round_to,
            hdi_prob=hdi_prob,
            multimodal=multimodal,
            skipna=skipna,
            linewidth=linewidth,
            ref_val=ref_val,
            rope=rope,
            ref_val_color=ref_val_color,
            rope_color=rope_color,
            ax_labelsize=ax_labelsize,
            **kwargs,
        )
        panel_title = Title()
        panel_title.text = labeller.make_label_vert(var_name, selection, isel)
        fig.title = panel_title

    show_layout(ax, show)

    return ax
107
-
108
-
109
def _plot_posterior_op(
    idx,
    values,
    var_name,
    selection,
    ax,
    bw,
    circular,
    linewidth,
    bins,
    kind,
    point_estimate,
    hdi_prob,
    multimodal,
    skipna,
    ref_val,
    rope,
    ref_val_color,
    rope_color,
    ax_labelsize,
    round_to: Optional[int] = None,
    **kwargs,
):  # noqa: D202
    """Artist to draw a single posterior panel on the Bokeh figure ``ax``.

    The distribution is drawn as a KDE (``kind="kde"`` with float values), as a
    histogram (integer values, or float values with ``kind="hist"``) or as a
    two-bar chart (boolean values). HDI, point estimate, reference value and
    ROPE annotations are then added, placed in data coordinates relative to
    ``max_data`` (the height of the drawn distribution).

    Parameters
    ----------
    idx : int
        Position of this panel; used to index ``ref_val`` when it is a list.
    values : numpy.ndarray
        Samples of the variable for this panel.
    ref_val : None, number, list or dict
        A list is indexed by ``idx``; a dict maps ``var_name`` to a list of
        dicts holding selection coordinates plus a ``"ref_val"`` key.
    rope : None, dict or length-2 iterable
        A dict maps ``var_name`` to a list of dicts holding selection
        coordinates plus a ``"rope"`` key with ``(lo, hi)``.
    round_to : int, optional
        Passed to ``format_sig_figs``/``round_num`` for the annotation text.
    **kwargs
        Forwarded to ``plot_kde`` as ``plot_kwargs`` for the KDE case.

    Raises
    ------
    ValueError
        If ``ref_val`` or ``rope`` have an unsupported form.
    TypeError
        If ``values`` are not float, integer or boolean.
    """

    def format_as_percent(x, round_to=0):
        return f"{100 * x:.{round_to:d}f}%"

    def display_ref_val(max_data):
        if ref_val is None:
            return
        elif isinstance(ref_val, dict):
            val = None
            for sel in ref_val.get(var_name, []):
                if all(
                    k in selection and selection[k] == v for k, v in sel.items() if k != "ref_val"
                ):
                    val = sel["ref_val"]
                    break
            if val is None:
                return
        elif isinstance(ref_val, list):
            val = ref_val[idx]
        elif isinstance(ref_val, Number):
            val = ref_val
        else:
            raise ValueError(
                "Argument `ref_val` must be None, a constant, a list or a "
                'dictionary like {"var_name": [{"ref_val": ref_val}]}'
            )
        less_than_ref_probability = (values < val).mean()
        greater_than_ref_probability = (values >= val).mean()
        ref_in_posterior = (
            f"{format_as_percent(less_than_ref_probability, 1)} <{val:g}< "
            f"{format_as_percent(greater_than_ref_probability, 1)}"
        )
        ax.line(
            [val, val],
            [0, 0.8 * max_data],
            line_color=vectorized_to_hex(ref_val_color),
            line_alpha=0.65,
        )

        ax.text(
            x=[values.mean()],
            y=[max_data * 0.6],
            text=[ref_in_posterior],
            text_color=vectorized_to_hex(ref_val_color),
            text_align="center",
        )

    def display_rope(max_data):
        if rope is None:
            return
        elif isinstance(rope, dict):
            vals = None
            for sel in rope.get(var_name, []):
                # pylint: disable=line-too-long
                if all(k in selection and selection[k] == v for k, v in sel.items() if k != "rope"):
                    vals = sel["rope"]
                    break
            if vals is None:
                return
        elif len(rope) == 2:
            vals = rope
        else:
            # The dict example mirrors the parsing above: a list of selection dicts.
            raise ValueError(
                "Argument `rope` must be None, a dictionary like "
                '{"var_name": [{"rope": (lo, hi)}]}, or an '
                "iterable of length 2"
            )
        rope_text = [f"{val:.{format_sig_figs(val, round_to)}g}" for val in vals]

        ax.line(
            vals,
            (max_data * 0.02, max_data * 0.02),
            line_width=linewidth * 5,
            line_color=vectorized_to_hex(rope_color),
            line_alpha=0.7,
        )
        probability_within_rope = ((values > vals[0]) & (values <= vals[1])).mean()
        text_props = dict(
            text_color=vectorized_to_hex(rope_color),
            text_align="center",
        )
        ax.text(
            x=[values.mean()],
            y=[max_data * 0.45],
            text=[f"{format_as_percent(probability_within_rope, 1)} in ROPE"],
            **text_props,
        )

        ax.text(
            x=vals,
            y=[max_data * 0.2, max_data * 0.2],
            text_font_size=f"{ax_labelsize}pt",
            text=rope_text,
            **text_props,
        )

    def display_point_estimate(max_data):
        if not point_estimate:
            return
        point_value = calculate_point_estimate(point_estimate, values, bw, circular)
        sig_figs = format_sig_figs(point_value, round_to)
        point_text = f"{point_estimate}={point_value:.{sig_figs}g}"

        ax.text(x=[point_value], y=[max_data * 0.8], text=[point_text], text_align="center")

    def display_hdi(max_data):
        # hdi returns one (low, high) pair, or several pairs when multimodal.
        hdi_probs = hdi(
            values, hdi_prob=hdi_prob, circular=circular, multimodal=multimodal, skipna=skipna
        )  # type: np.ndarray

        for hdi_i in np.atleast_2d(hdi_probs):
            ax.line(
                hdi_i,
                (max_data * 0.02, max_data * 0.02),
                line_width=linewidth * 2,
                line_color="black",
            )

            ax.text(
                x=list(hdi_i) + [(hdi_i[0] + hdi_i[1]) / 2],
                y=[max_data * 0.07, max_data * 0.07, max_data * 0.3],
                text=(
                    [str(round_num(bound, round_to)) for bound in hdi_i]
                    + [f"{format_as_percent(hdi_prob)} HDI"]
                ),
                text_align="center",
            )

    def format_axes():
        ax.yaxis.visible = False
        ax.yaxis.major_tick_line_color = None
        ax.yaxis.minor_tick_line_color = None
        ax.yaxis.major_label_text_font_size = "0pt"
        ax.xgrid.grid_line_color = None
        ax.ygrid.grid_line_color = None

    if skipna:
        values = values[~np.isnan(values)]

    if kind == "kde" and values.dtype.kind == "f":
        kwargs.setdefault("line_width", linewidth)
        plot_kde(
            values,
            bw=bw,
            is_circular=circular,
            fill_kwargs={"fill_alpha": kwargs.pop("fill_alpha", 0)},
            plot_kwargs=kwargs,
            ax=ax,
            rug=False,
            backend="bokeh",
            backend_kwargs={},
            show=False,
        )
        # Annotations are drawn in density coordinates, so their height must
        # come from the KDE curve itself, not from the sample values
        # (values.max() is an x-axis quantity and can even be negative).
        from ....stats.density_utils import kde  # pylint: disable=import-outside-toplevel

        kde_bw = ("taylor" if circular else "experimental") if bw == "default" else bw
        _, density = kde(values, circular=circular, bw=kde_bw)
        finite_density = density[np.isfinite(density)]
        # Degenerate samples (e.g. constant values) give no finite density;
        # any positive height then keeps the annotations visible.
        max_data = finite_density.max() if finite_density.size else 1.0
    elif values.dtype.kind == "i" or (values.dtype.kind == "f" and kind == "hist"):
        if bins is None:
            bins = get_bins(values)
        _, hist, edges = histogram(values, bins=bins)
        max_data = hist.max()
        ax.quad(
            top=hist, bottom=0, left=edges[:-1], right=edges[1:], fill_alpha=0.35, line_alpha=0.35
        )
    elif values.dtype.kind == "b":
        # Two bars: count of False at x=0 and count of True at x=1.
        hist = np.array([(~values).sum(), values.sum()])
        max_data = hist.max()
        edges = np.array([-0.5, 0.5, 1.5])
        ax.quad(
            top=hist, bottom=0, left=edges[:-1], right=edges[1:], fill_alpha=0.35, line_alpha=0.35
        )
        # An HDI is not meaningful for boolean samples.
        hdi_prob = "hide"
        ax.xaxis.ticker = [0, 1]
        ax.xaxis.major_label_overrides = {0: "False", 1: "True"}
    else:
        raise TypeError("Values must be float, integer or boolean")

    format_axes()
    if hdi_prob != "hide":
        display_hdi(max_data)
    display_point_estimate(max_data)
    display_ref_val(max_data)
    display_rope(max_data)
@@ -1,379 +0,0 @@
1
- """Bokeh Posterior predictive plot."""
2
-
3
- import numpy as np
4
- from bokeh.models.annotations import Legend
5
- from bokeh.models.glyphs import Scatter
6
- from bokeh.models import ColumnDataSource
7
-
8
-
9
- from ....stats.density_utils import get_bins, histogram, kde
10
- from ...kdeplot import plot_kde
11
- from ...plot_utils import _scale_fig_size, vectorized_to_hex
12
-
13
-
14
- from .. import show_layout
15
- from . import backend_kwarg_defaults, create_axes_grid
16
-
17
-
18
def plot_ppc(
    ax,
    length_plotters,
    rows,
    cols,
    figsize,
    animated,
    obs_plotters,
    pp_plotters,
    predictive_dataset,
    pp_sample_ix,
    kind,
    alpha,
    colors,
    textsize,
    mean,
    observed,
    observed_rug,
    jitter,
    total_pp_samples,
    legend,
    labeller,
    group,
    animation_kwargs,  # pylint: disable=unused-argument
    num_pp_samples,
    backend_kwargs,
    show,
):
    """Bokeh ppc plot.

    For every variable, draws the predictive draws selected by ``pp_sample_ix``
    and, optionally, the observed data (``observed``, ``observed_rug``) and the
    predictive mean (``mean``). ``kind`` selects the representation:
    ``"kde"``, ``"cumulative"`` or ``"scatter"``. ``colors`` is indexed as
    0 = predictive draws, 1 = observed, 2 = predictive mean.

    Returns
    -------
    The grid of Bokeh figures (either created here or ``ax`` promoted to 2D);
    it may contain ``None`` placeholders, which are skipped when drawing.

    Raises
    ------
    ValueError
        If the number of non-``None`` axes differs from ``length_plotters``,
        if ``jitter`` is negative, or if the predictive data is neither
        integer nor float.
    """
    if backend_kwargs is None:
        backend_kwargs = {}

    backend_kwargs = {
        **backend_kwarg_defaults(
            ("dpi", "plot.bokeh.figure.dpi"),
        ),
        **backend_kwargs,
    }

    colors = vectorized_to_hex(colors)

    (figsize, *_, linewidth, markersize) = _scale_fig_size(figsize, textsize, rows, cols)
    if ax is None:
        axes = create_axes_grid(
            length_plotters,
            rows,
            cols,
            figsize=figsize,
            backend_kwargs=backend_kwargs,
        )
    else:
        axes = np.atleast_2d(ax)

    # Count only real figures: the drawing loop below skips None cells, so
    # None padding in the grid must not trigger a spurious mismatch.
    n_axes = sum(item is not None for item in axes.ravel())
    if n_axes != length_plotters:
        raise ValueError(
            f"Found {length_plotters} variables to plot but {n_axes} axes instances. "
            "They must be equal."
        )

    if alpha is None:
        if animated:
            alpha = 1
        elif kind.lower() == "scatter":
            alpha = 0.7
        else:
            alpha = 0.2

    if jitter is None:
        jitter = 0.0
    if jitter < 0.0:
        raise ValueError("jitter must be >=0.")

    for i, ax_i in enumerate(item for item in axes.flatten() if item is not None):
        var_name, sel, isel, obs_vals = obs_plotters[i]
        pp_var_name, _, _, pp_vals = pp_plotters[i]
        dtype = predictive_dataset[pp_var_name].dtype.kind
        legend_it = []

        if dtype not in ["i", "f"]:
            raise ValueError(
                f"The data type of the predictive data must be one of 'i' or 'f', but is '{dtype}'"
            )

        # flatten non-specified dimensions
        obs_vals = obs_vals.flatten()
        pp_vals = pp_vals.reshape(total_pp_samples, -1)
        pp_sampled_vals = pp_vals[pp_sample_ix]
        cds_rug = ColumnDataSource({"_": np.array(obs_vals)})

        if kind == "kde":
            plot_kwargs = {
                "line_color": colors[0],
                "line_alpha": alpha,
                "line_width": 0.5 * linewidth,
            }

            pp_densities = []
            pp_xs = []
            for vals in pp_sampled_vals:
                vals = np.array([vals]).flatten()
                if dtype == "f":
                    pp_x, pp_density = kde(vals)
                    pp_densities.append(pp_density)
                    pp_xs.append(pp_x)
                else:
                    bins = get_bins(vals)
                    _, hist, bin_edges = histogram(vals, bins=bins)
                    # Repeat the first bar so the step has one value per edge.
                    hist = np.concatenate((hist[:1], hist))
                    pp_densities.append(hist)
                    pp_xs.append(bin_edges)

            if dtype == "f":
                multi_line = ax_i.multi_line(pp_xs, pp_densities, **plot_kwargs)
                legend_it.append((f"{group.capitalize()} predictive", [multi_line]))
            else:
                all_steps = []
                for x_s, y_s in zip(pp_xs, pp_densities):
                    step = ax_i.step(x_s, y_s, **plot_kwargs)
                    all_steps.append(step)
                legend_it.append((f"{group.capitalize()} predictive", all_steps))

            if observed:
                label = "Observed"
                if dtype == "f":
                    _, glyph = plot_kde(
                        obs_vals,
                        plot_kwargs={"line_color": colors[1], "line_width": linewidth},
                        fill_kwargs={"alpha": 0},
                        ax=ax_i,
                        backend="bokeh",
                        backend_kwargs={},
                        show=False,
                        return_glyph=True,
                    )
                    legend_it.append((label, glyph))
                    if observed_rug:
                        glyph = Scatter(
                            x="_",
                            y=0.0,
                            marker="dash",
                            angle=np.pi / 2,
                            line_color=colors[1],
                            line_width=linewidth,
                        )
                        ax_i.add_glyph(cds_rug, glyph)
                else:
                    bins = get_bins(obs_vals)
                    _, hist, bin_edges = histogram(obs_vals, bins=bins)
                    hist = np.concatenate((hist[:1], hist))
                    step = ax_i.step(
                        bin_edges,
                        hist,
                        line_color=colors[1],
                        line_width=linewidth,
                        mode="center",
                    )
                    legend_it.append((label, [step]))

            if mean:
                label = f"{group.capitalize()} predictive mean"
                if dtype == "f":
                    rep = len(pp_densities)
                    len_density = len(pp_densities[0])

                    # Re-grid every draw's density onto a common x grid before averaging.
                    new_x = np.linspace(np.min(pp_xs), np.max(pp_xs), len_density)
                    new_d = np.zeros((rep, len_density))
                    bins = np.digitize(pp_xs, new_x, right=True)
                    new_x -= (new_x[1] - new_x[0]) / 2
                    for irep in range(rep):
                        new_d[irep][bins[irep]] = pp_densities[irep]
                    line = ax_i.line(
                        new_x,
                        new_d.mean(0),
                        color=colors[2],
                        line_dash="dashed",
                        line_width=linewidth,
                    )
                    legend_it.append((label, [line]))
                else:
                    vals = pp_vals.flatten()
                    bins = get_bins(vals)
                    _, hist, bin_edges = histogram(vals, bins=bins)
                    hist = np.concatenate((hist[:1], hist))
                    step = ax_i.step(
                        bin_edges,
                        hist,
                        line_color=colors[2],
                        line_width=linewidth,
                        line_dash="dashed",
                        mode="center",
                    )
                    legend_it.append((label, [step]))
            ax_i.yaxis.major_tick_line_color = None
            ax_i.yaxis.minor_tick_line_color = None
            ax_i.yaxis.major_label_text_font_size = "0pt"

        elif kind == "cumulative":
            if observed:
                label = "Observed"
                if dtype == "f":
                    glyph = ax_i.line(
                        *_empirical_cdf(obs_vals),
                        line_color=colors[1],
                        line_width=linewidth,
                    )
                    glyph.level = "overlay"
                    legend_it.append((label, [glyph]))

                else:
                    step = ax_i.step(
                        *_empirical_cdf(obs_vals),
                        line_color=colors[1],
                        line_width=linewidth,
                        mode="center",
                    )
                    legend_it.append((label, [step]))

                if observed_rug:
                    glyph = Scatter(
                        x="_",
                        y=0.0,
                        marker="dash",
                        angle=np.pi / 2,
                        line_color=colors[1],
                        line_width=linewidth,
                    )
                    ax_i.add_glyph(cds_rug, glyph)

            # Even rows hold the sorted x values, odd rows the matching ECDF levels.
            pp_densities = np.empty((2 * len(pp_sampled_vals), pp_sampled_vals[0].size))
            for idx, vals in enumerate(pp_sampled_vals):
                vals = np.array([vals]).flatten()
                pp_x, pp_density = _empirical_cdf(vals)
                pp_densities[2 * idx] = pp_x
                pp_densities[2 * idx + 1] = pp_density
            multi_line = ax_i.multi_line(
                list(pp_densities[::2]),
                list(pp_densities[1::2]),
                line_alpha=alpha,
                line_color=colors[0],
                line_width=linewidth,
            )
            legend_it.append((f"{group.capitalize()} predictive", [multi_line]))
            if mean:
                label = f"{group.capitalize()} predictive mean"
                line = ax_i.line(
                    *_empirical_cdf(pp_vals.flatten()),
                    color=colors[2],
                    line_dash="dashed",
                    line_width=linewidth,
                )
                legend_it.append((label, [line]))

        elif kind == "scatter":
            if mean:
                label = f"{group.capitalize()} predictive mean"
                if dtype == "f":
                    _, glyph = plot_kde(
                        pp_vals.flatten(),
                        plot_kwargs={
                            "line_color": colors[2],
                            "line_dash": "dashed",
                            "line_width": linewidth,
                        },
                        ax=ax_i,
                        backend="bokeh",
                        backend_kwargs={},
                        show=False,
                        return_glyph=True,
                    )
                    legend_it.append((label, glyph))
                else:
                    vals = pp_vals.flatten()
                    bins = get_bins(vals)
                    _, hist, bin_edges = histogram(vals, bins=bins)
                    hist = np.concatenate((hist[:1], hist))
                    step = ax_i.step(
                        bin_edges,
                        hist,
                        color=colors[2],
                        line_width=linewidth,
                        line_dash="dashed",
                        mode="center",
                    )
                    legend_it.append((label, [step]))

            # Each predictive draw gets its own row in [0, 0.1]; observed sits at 0.
            jitter_scale = 0.1
            y_rows = np.linspace(0, 0.1, num_pp_samples + 1)
            scale_low = 0
            scale_high = jitter_scale * jitter

            if observed:
                label = "Observed"
                obs_yvals = np.zeros_like(obs_vals, dtype=np.float64)
                if jitter:
                    obs_yvals += np.random.uniform(
                        low=scale_low, high=scale_high, size=len(obs_vals)
                    )
                glyph = ax_i.scatter(
                    obs_vals,
                    obs_yvals,
                    marker="circle",
                    line_color=colors[1],
                    fill_color=colors[1],
                    size=markersize,
                    line_alpha=alpha,
                )
                glyph.level = "overlay"
                legend_it.append((label, [glyph]))

            all_scatter = []
            for vals, y in zip(pp_sampled_vals, y_rows[1:]):
                vals = np.ravel(vals)
                yvals = np.full_like(vals, y, dtype=np.float64)
                if jitter:
                    yvals += np.random.uniform(low=scale_low, high=scale_high, size=len(vals))
                scatter = ax_i.scatter(
                    vals,
                    yvals,
                    line_color=colors[0],
                    fill_color=colors[0],
                    size=markersize,
                    fill_alpha=alpha,
                )
                all_scatter.append(scatter)

            legend_it.append((f"{group.capitalize()} predictive", all_scatter))
            ax_i.yaxis.major_tick_line_color = None
            ax_i.yaxis.minor_tick_line_color = None
            ax_i.yaxis.major_label_text_font_size = "0pt"

        if legend:
            # Use a separate name so the user's `legend` flag is not overwritten
            # by a Legend model and is re-evaluated unchanged for every axis.
            legend_obj = Legend(
                items=legend_it,
                location="top_left",
                orientation="vertical",
            )
            ax_i.add_layout(legend_obj)
            if textsize is not None:
                ax_i.legend.label_text_font_size = f"{textsize}pt"
            ax_i.legend.click_policy = "hide"
        ax_i.xaxis.axis_label = labeller.make_pp_label(var_name, pp_var_name, sel, isel)

    show_layout(axes, show)

    return axes
364
-
365
-
366
- def _empirical_cdf(data):
367
- """Compute empirical cdf of a numpy array.
368
-
369
- Parameters
370
- ----------
371
- data : np.array
372
- 1d array
373
-
374
- Returns
375
- -------
376
- np.array, np.array
377
- x and y coordinates for the empirical cdf of the data
378
- """
379
- return np.sort(data), np.linspace(0, 1, len(data))