arviz 0.23.3__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185)
  1. arviz/__init__.py +52 -367
  2. arviz-1.0.0rc0.dist-info/METADATA +182 -0
  3. arviz-1.0.0rc0.dist-info/RECORD +5 -0
  4. {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
  5. {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
  6. arviz/data/__init__.py +0 -55
  7. arviz/data/base.py +0 -596
  8. arviz/data/converters.py +0 -203
  9. arviz/data/datasets.py +0 -161
  10. arviz/data/example_data/code/radon/radon.json +0 -326
  11. arviz/data/example_data/data/centered_eight.nc +0 -0
  12. arviz/data/example_data/data/non_centered_eight.nc +0 -0
  13. arviz/data/example_data/data_local.json +0 -12
  14. arviz/data/example_data/data_remote.json +0 -58
  15. arviz/data/inference_data.py +0 -2386
  16. arviz/data/io_beanmachine.py +0 -112
  17. arviz/data/io_cmdstan.py +0 -1036
  18. arviz/data/io_cmdstanpy.py +0 -1233
  19. arviz/data/io_datatree.py +0 -23
  20. arviz/data/io_dict.py +0 -462
  21. arviz/data/io_emcee.py +0 -317
  22. arviz/data/io_json.py +0 -54
  23. arviz/data/io_netcdf.py +0 -68
  24. arviz/data/io_numpyro.py +0 -497
  25. arviz/data/io_pyjags.py +0 -378
  26. arviz/data/io_pyro.py +0 -333
  27. arviz/data/io_pystan.py +0 -1095
  28. arviz/data/io_zarr.py +0 -46
  29. arviz/data/utils.py +0 -139
  30. arviz/labels.py +0 -210
  31. arviz/plots/__init__.py +0 -61
  32. arviz/plots/autocorrplot.py +0 -171
  33. arviz/plots/backends/__init__.py +0 -223
  34. arviz/plots/backends/bokeh/__init__.py +0 -166
  35. arviz/plots/backends/bokeh/autocorrplot.py +0 -101
  36. arviz/plots/backends/bokeh/bfplot.py +0 -23
  37. arviz/plots/backends/bokeh/bpvplot.py +0 -193
  38. arviz/plots/backends/bokeh/compareplot.py +0 -167
  39. arviz/plots/backends/bokeh/densityplot.py +0 -239
  40. arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
  41. arviz/plots/backends/bokeh/distplot.py +0 -183
  42. arviz/plots/backends/bokeh/dotplot.py +0 -113
  43. arviz/plots/backends/bokeh/ecdfplot.py +0 -73
  44. arviz/plots/backends/bokeh/elpdplot.py +0 -203
  45. arviz/plots/backends/bokeh/energyplot.py +0 -155
  46. arviz/plots/backends/bokeh/essplot.py +0 -176
  47. arviz/plots/backends/bokeh/forestplot.py +0 -772
  48. arviz/plots/backends/bokeh/hdiplot.py +0 -54
  49. arviz/plots/backends/bokeh/kdeplot.py +0 -268
  50. arviz/plots/backends/bokeh/khatplot.py +0 -163
  51. arviz/plots/backends/bokeh/lmplot.py +0 -185
  52. arviz/plots/backends/bokeh/loopitplot.py +0 -211
  53. arviz/plots/backends/bokeh/mcseplot.py +0 -184
  54. arviz/plots/backends/bokeh/pairplot.py +0 -328
  55. arviz/plots/backends/bokeh/parallelplot.py +0 -81
  56. arviz/plots/backends/bokeh/posteriorplot.py +0 -324
  57. arviz/plots/backends/bokeh/ppcplot.py +0 -379
  58. arviz/plots/backends/bokeh/rankplot.py +0 -149
  59. arviz/plots/backends/bokeh/separationplot.py +0 -107
  60. arviz/plots/backends/bokeh/traceplot.py +0 -436
  61. arviz/plots/backends/bokeh/violinplot.py +0 -164
  62. arviz/plots/backends/matplotlib/__init__.py +0 -124
  63. arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
  64. arviz/plots/backends/matplotlib/bfplot.py +0 -78
  65. arviz/plots/backends/matplotlib/bpvplot.py +0 -177
  66. arviz/plots/backends/matplotlib/compareplot.py +0 -135
  67. arviz/plots/backends/matplotlib/densityplot.py +0 -194
  68. arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
  69. arviz/plots/backends/matplotlib/distplot.py +0 -178
  70. arviz/plots/backends/matplotlib/dotplot.py +0 -116
  71. arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
  72. arviz/plots/backends/matplotlib/elpdplot.py +0 -189
  73. arviz/plots/backends/matplotlib/energyplot.py +0 -113
  74. arviz/plots/backends/matplotlib/essplot.py +0 -180
  75. arviz/plots/backends/matplotlib/forestplot.py +0 -656
  76. arviz/plots/backends/matplotlib/hdiplot.py +0 -48
  77. arviz/plots/backends/matplotlib/kdeplot.py +0 -177
  78. arviz/plots/backends/matplotlib/khatplot.py +0 -241
  79. arviz/plots/backends/matplotlib/lmplot.py +0 -149
  80. arviz/plots/backends/matplotlib/loopitplot.py +0 -144
  81. arviz/plots/backends/matplotlib/mcseplot.py +0 -161
  82. arviz/plots/backends/matplotlib/pairplot.py +0 -355
  83. arviz/plots/backends/matplotlib/parallelplot.py +0 -58
  84. arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
  85. arviz/plots/backends/matplotlib/ppcplot.py +0 -478
  86. arviz/plots/backends/matplotlib/rankplot.py +0 -119
  87. arviz/plots/backends/matplotlib/separationplot.py +0 -97
  88. arviz/plots/backends/matplotlib/traceplot.py +0 -526
  89. arviz/plots/backends/matplotlib/tsplot.py +0 -121
  90. arviz/plots/backends/matplotlib/violinplot.py +0 -148
  91. arviz/plots/bfplot.py +0 -128
  92. arviz/plots/bpvplot.py +0 -308
  93. arviz/plots/compareplot.py +0 -177
  94. arviz/plots/densityplot.py +0 -284
  95. arviz/plots/distcomparisonplot.py +0 -197
  96. arviz/plots/distplot.py +0 -233
  97. arviz/plots/dotplot.py +0 -233
  98. arviz/plots/ecdfplot.py +0 -372
  99. arviz/plots/elpdplot.py +0 -174
  100. arviz/plots/energyplot.py +0 -147
  101. arviz/plots/essplot.py +0 -319
  102. arviz/plots/forestplot.py +0 -304
  103. arviz/plots/hdiplot.py +0 -211
  104. arviz/plots/kdeplot.py +0 -357
  105. arviz/plots/khatplot.py +0 -236
  106. arviz/plots/lmplot.py +0 -380
  107. arviz/plots/loopitplot.py +0 -224
  108. arviz/plots/mcseplot.py +0 -194
  109. arviz/plots/pairplot.py +0 -281
  110. arviz/plots/parallelplot.py +0 -204
  111. arviz/plots/plot_utils.py +0 -599
  112. arviz/plots/posteriorplot.py +0 -298
  113. arviz/plots/ppcplot.py +0 -369
  114. arviz/plots/rankplot.py +0 -232
  115. arviz/plots/separationplot.py +0 -167
  116. arviz/plots/styles/arviz-bluish.mplstyle +0 -1
  117. arviz/plots/styles/arviz-brownish.mplstyle +0 -1
  118. arviz/plots/styles/arviz-colors.mplstyle +0 -2
  119. arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
  120. arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
  121. arviz/plots/styles/arviz-doc.mplstyle +0 -88
  122. arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
  123. arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
  124. arviz/plots/styles/arviz-greenish.mplstyle +0 -1
  125. arviz/plots/styles/arviz-orangish.mplstyle +0 -1
  126. arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
  127. arviz/plots/styles/arviz-purplish.mplstyle +0 -1
  128. arviz/plots/styles/arviz-redish.mplstyle +0 -1
  129. arviz/plots/styles/arviz-royish.mplstyle +0 -1
  130. arviz/plots/styles/arviz-viridish.mplstyle +0 -1
  131. arviz/plots/styles/arviz-white.mplstyle +0 -40
  132. arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
  133. arviz/plots/traceplot.py +0 -273
  134. arviz/plots/tsplot.py +0 -440
  135. arviz/plots/violinplot.py +0 -192
  136. arviz/preview.py +0 -58
  137. arviz/py.typed +0 -0
  138. arviz/rcparams.py +0 -606
  139. arviz/sel_utils.py +0 -223
  140. arviz/static/css/style.css +0 -340
  141. arviz/static/html/icons-svg-inline.html +0 -15
  142. arviz/stats/__init__.py +0 -37
  143. arviz/stats/density_utils.py +0 -1013
  144. arviz/stats/diagnostics.py +0 -1013
  145. arviz/stats/ecdf_utils.py +0 -324
  146. arviz/stats/stats.py +0 -2422
  147. arviz/stats/stats_refitting.py +0 -119
  148. arviz/stats/stats_utils.py +0 -609
  149. arviz/tests/__init__.py +0 -1
  150. arviz/tests/base_tests/__init__.py +0 -1
  151. arviz/tests/base_tests/test_data.py +0 -1679
  152. arviz/tests/base_tests/test_data_zarr.py +0 -143
  153. arviz/tests/base_tests/test_diagnostics.py +0 -511
  154. arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
  155. arviz/tests/base_tests/test_helpers.py +0 -18
  156. arviz/tests/base_tests/test_labels.py +0 -69
  157. arviz/tests/base_tests/test_plot_utils.py +0 -342
  158. arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
  159. arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
  160. arviz/tests/base_tests/test_rcparams.py +0 -317
  161. arviz/tests/base_tests/test_stats.py +0 -925
  162. arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
  163. arviz/tests/base_tests/test_stats_numba.py +0 -45
  164. arviz/tests/base_tests/test_stats_utils.py +0 -384
  165. arviz/tests/base_tests/test_utils.py +0 -376
  166. arviz/tests/base_tests/test_utils_numba.py +0 -87
  167. arviz/tests/conftest.py +0 -46
  168. arviz/tests/external_tests/__init__.py +0 -1
  169. arviz/tests/external_tests/test_data_beanmachine.py +0 -78
  170. arviz/tests/external_tests/test_data_cmdstan.py +0 -398
  171. arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
  172. arviz/tests/external_tests/test_data_emcee.py +0 -166
  173. arviz/tests/external_tests/test_data_numpyro.py +0 -434
  174. arviz/tests/external_tests/test_data_pyjags.py +0 -119
  175. arviz/tests/external_tests/test_data_pyro.py +0 -260
  176. arviz/tests/external_tests/test_data_pystan.py +0 -307
  177. arviz/tests/helpers.py +0 -677
  178. arviz/utils.py +0 -773
  179. arviz/wrappers/__init__.py +0 -13
  180. arviz/wrappers/base.py +0 -236
  181. arviz/wrappers/wrap_pymc.py +0 -36
  182. arviz/wrappers/wrap_stan.py +0 -148
  183. arviz-0.23.3.dist-info/METADATA +0 -264
  184. arviz-0.23.3.dist-info/RECORD +0 -183
  185. arviz-0.23.3.dist-info/top_level.txt +0 -1
@@ -1,149 +0,0 @@
1
- """Bokeh rankplot."""
2
-
3
- import numpy as np
4
-
5
- from bokeh.models import Span
6
- from bokeh.models.annotations import Title
7
- from bokeh.models.tickers import FixedTicker
8
-
9
- from ....stats.density_utils import histogram
10
- from ...plot_utils import _scale_fig_size, compute_ranks
11
- from .. import show_layout
12
- from . import backend_kwarg_defaults, create_axes_grid
13
-
14
-
15
def plot_rank(
    axes,
    length_plotters,
    rows,
    cols,
    figsize,
    plotters,
    bins,
    kind,
    colors,
    ref_line,
    labels,
    labeller,
    ref_line_kwargs,
    bar_kwargs,
    vlines_kwargs,
    marker_vlines_kwargs,
    backend_kwargs,
    show,
):
    """Bokeh rank plot.

    For every entry in ``plotters`` the pooled ranks are histogrammed per chain
    and drawn either as stacked bars (``kind="bars"``, one unit-height band per
    chain) or as markers on vertical lines (``kind="vlines"``).

    Parameters
    ----------
    axes : array-like of bokeh figures or None
        Figures to draw on. A new grid is created with ``create_axes_grid``
        when ``None``.
    length_plotters, rows, cols : int
        Number of panels and grid shape, used when creating ``axes``.
    figsize : tuple or None
        Forwarded to ``_scale_fig_size``.
    plotters : iterable of (var_name, selection, isel, var_data)
        One entry per panel. ``var_data`` is passed to ``compute_ranks``, and
        the resulting ranks are iterated row by row (one row per chain —
        presumably shape ``(chain, draw)``; confirm with the caller).
    bins : int or sequence
        Forwarded to ``np.histogram_bin_edges``.
    kind : {"bars", "vlines"}
        Plot style.
    colors : sequence
        Colors indexed by chain position.
    ref_line : bool
        Whether to draw the reference line(s) at the mean count.
    labels : bool
        Show axis labels/tick labels (True) or hide them (False).
    labeller : object
        Provides ``make_label_vert`` for the panel titles.
    ref_line_kwargs, bar_kwargs, vlines_kwargs, marker_vlines_kwargs : dict or None
        Extra glyph keyword arguments. Missing defaults are filled in with
        ``setdefault``.
    backend_kwargs : dict or None
        Merged over the rcParams defaults and passed to ``create_axes_grid``.
    show : bool
        Forwarded to ``show_layout``.

    Returns
    -------
    numpy.ndarray
        The (at least 2-D) array of figures.
    """
    if ref_line_kwargs is None:
        ref_line_kwargs = {}
    ref_line_kwargs.setdefault("line_dash", "dashed")
    ref_line_kwargs.setdefault("line_color", "black")

    if bar_kwargs is None:
        bar_kwargs = {}
    bar_kwargs.setdefault("line_color", "white")

    if vlines_kwargs is None:
        vlines_kwargs = {}
    vlines_kwargs.setdefault("line_width", 2)
    vlines_kwargs.setdefault("line_dash", "solid")

    if marker_vlines_kwargs is None:
        marker_vlines_kwargs = {}
    marker_vlines_kwargs.setdefault("marker", "circle")

    if backend_kwargs is None:
        backend_kwargs = {}

    backend_kwargs = {
        **backend_kwarg_defaults(
            ("dpi", "plot.bokeh.figure.dpi"),
        ),
        **backend_kwargs,
    }
    figsize, *_ = _scale_fig_size(figsize, None, rows=rows, cols=cols)
    if axes is None:
        axes = create_axes_grid(
            length_plotters,
            rows,
            cols,
            figsize=figsize,
            sharex=True,
            sharey=True,
            backend_kwargs=backend_kwargs,
        )
    else:
        axes = np.atleast_2d(axes)

    for ax, (var_name, selection, isel, var_data) in zip(
        (item for item in axes.flatten() if item is not None), plotters
    ):
        ranks = compute_ranks(var_data)
        bin_ary = np.histogram_bin_edges(ranks, bins=bins, range=(0, ranks.size))
        all_counts = np.empty((len(ranks), len(bin_ary) - 1))
        for idx, row in enumerate(ranks):
            _, all_counts[idx], _ = histogram(row, bins=bin_ary)
        # Scale so the tallest bar reaches 0.95 of a chain's unit-height band.
        counts_normalizer = all_counts.max() / 0.95
        gap = 1
        width = bin_ary[1] - bin_ary[0]

        # Per-variable merge: a user-supplied width still wins. The computed
        # width is no longer written into the caller's dict, where it would
        # stick for every later variable.
        bar_glyph_kwargs = {"width": width, **bar_kwargs}
        # Center the bins
        bin_ary = (bin_ary[1:] + bin_ary[:-1]) / 2

        y_ticks = []
        if kind == "bars":
            for idx, counts in enumerate(all_counts):
                counts = counts / counts_normalizer
                # Each chain gets its own band starting at idx * gap.
                y_ticks.append(idx * gap)
                ax.vbar(
                    x=bin_ary,
                    top=y_ticks[-1] + counts,
                    bottom=y_ticks[-1],
                    fill_color=colors[idx],
                    **bar_glyph_kwargs,
                )
                if ref_line:
                    hline = Span(location=y_ticks[-1] + counts.mean(), **ref_line_kwargs)
                    ax.add_layout(hline)
            if labels:
                ax.yaxis.axis_label = "Chain"
        elif kind == "vlines":
            ymin = np.full(len(all_counts), all_counts.mean())
            for idx, counts in enumerate(all_counts):
                ax.scatter(
                    bin_ary,
                    counts,
                    fill_color=colors[idx],
                    line_color=colors[idx],
                    **marker_vlines_kwargs,
                )
                x_locations = [(bin_center, bin_center) for bin_center in bin_ary]
                y_locations = [(ymin[idx], count) for count in counts]
                ax.multi_line(x_locations, y_locations, line_color=colors[idx], **vlines_kwargs)

            if ref_line:
                hline = Span(location=all_counts.mean(), **ref_line_kwargs)
                ax.add_layout(hline)

        if labels:
            ax.xaxis.axis_label = "Rank (all chains)"

            # The chain ticks live on the y axis, so their label overrides
            # must be set on the y axis as well (they were set on the x axis).
            ax.yaxis.ticker = FixedTicker(ticks=y_ticks)
            ax.yaxis.major_label_overrides = dict(
                zip(map(str, y_ticks), map(str, range(len(y_ticks))))
            )

        else:
            ax.yaxis.major_tick_line_color = None
            ax.yaxis.minor_tick_line_color = None

            ax.xaxis.major_label_text_font_size = "0pt"
            ax.yaxis.major_label_text_font_size = "0pt"

        _title = Title()
        _title.text = labeller.make_label_vert(var_name, selection, isel)
        ax.title = _title

    show_layout(axes, show)

    return axes
@@ -1,107 +0,0 @@
1
- """Bokeh separation plot."""
2
-
3
- import numpy as np
4
-
5
- from ...plot_utils import _scale_fig_size, vectorized_to_hex
6
- from .. import show_layout
7
- from . import backend_kwarg_defaults, create_axes_grid
8
-
9
-
10
def plot_separation(
    y,
    y_hat,
    y_hat_line,
    label_y_hat,
    expected_events,
    figsize,
    textsize,
    color,
    legend,
    locs,
    width,
    ax,
    plot_kwargs,
    y_hat_line_kwargs,
    exp_events_kwargs,
    backend_kwargs,
    show,
):
    """Bokeh separation plot.

    Observations are ordered by ``y_hat`` (ascending, via ``np.argsort``). One
    bar of height 1 is drawn per entry of ``locs``: fully opaque where the
    sorted ``y`` is non-zero, alpha 0.3 otherwise.

    Parameters
    ----------
    y : numpy.ndarray
        Observed outcomes, indexed with the argsort of ``y_hat``.
    y_hat : numpy.ndarray
        Predicted values used for ordering, for the optional line and for the
        expected-events marker.
    y_hat_line : bool
        Draw ``y_hat`` (sorted) as a line over ``[0, 1]``.
    label_y_hat : str
        Legend label for the ``y_hat`` line when ``legend`` is true.
    expected_events : bool
        Draw a triangle marker at the expected-events position.
    figsize, textsize :
        Forwarded to ``_scale_fig_size``.
    color :
        Bar color, converted with ``vectorized_to_hex``.
    legend : bool
        Add legend labels to the line and the marker.
    locs : iterable
        Bar x locations, one per sorted observation.
    width : float
        Bar width.
    ax : bokeh figure or None
        Figure to draw on. A new one is created when ``None``.
    plot_kwargs, y_hat_line_kwargs, exp_events_kwargs : dict or None
        Extra glyph keyword arguments.
    backend_kwargs : dict or None
        Merged over rcParams defaults; ``x_range``/``y_range`` are forced to (0, 1).
    show : bool
        Forwarded to ``show_layout``.

    Returns
    -------
    bokeh figure
        The figure that was drawn on.
    """
    if backend_kwargs is None:
        backend_kwargs = {}

    # Copy so that the caller's dict does not get its "color" overwritten.
    plot_kwargs = {} if plot_kwargs is None else dict(plot_kwargs)
    plot_kwargs["color"] = vectorized_to_hex(color)

    backend_kwargs = {
        **backend_kwarg_defaults(),
        **backend_kwargs,
    }

    if y_hat_line_kwargs is None:
        y_hat_line_kwargs = {}

    y_hat_line_kwargs.setdefault("color", "black")
    y_hat_line_kwargs.setdefault("line_width", 2)

    if exp_events_kwargs is None:
        exp_events_kwargs = {}

    exp_events_kwargs.setdefault("color", "black")
    exp_events_kwargs.setdefault("size", 15)

    if legend:
        y_hat_line_kwargs.setdefault("legend_label", label_y_hat)
        exp_events_kwargs.setdefault(
            "legend_label",
            "Expected events",
        )

    figsize, *_ = _scale_fig_size(figsize, textsize)

    idx = np.argsort(y_hat)
    # Reorder once up front; the original re-indexed inside the loop (O(n^2)).
    y_sorted = y[idx]
    y_hat_sorted = y_hat[idx]

    backend_kwargs["x_range"] = (0, 1)
    backend_kwargs["y_range"] = (0, 1)

    if ax is None:
        ax = create_axes_grid(1, figsize=figsize, squeeze=True, backend_kwargs=backend_kwargs)

    for i, loc in enumerate(locs):
        alpha = 1 if y_sorted[i] != 0 else 0.3
        ax.vbar(
            loc,
            top=1,
            width=width,
            fill_alpha=alpha,
            line_alpha=alpha,
            **plot_kwargs,
        )

    if y_hat_line:
        ax.line(
            np.linspace(0, 1, len(y_hat)),
            y_hat_sorted,
            **y_hat_line_kwargs,
        )

    if expected_events:
        n_expected = int(np.round(np.sum(y_hat)))
        # Clamp at 0: when n_expected == len(y_hat), the index would be -1 and
        # would silently wrap to the largest prediction.
        marker_idx = max(len(y_hat) - n_expected - 1, 0)
        ax.triangle(
            y_hat_sorted[marker_idx],
            0,
            **exp_events_kwargs,
        )

    ax.axis.visible = False
    ax.xgrid.grid_line_color = None
    ax.ygrid.grid_line_color = None

    show_layout(ax, show)

    return ax
@@ -1,436 +0,0 @@
1
- """Bokeh Traceplot."""
2
-
3
- import warnings
4
- from collections.abc import Iterable
5
- from itertools import cycle
6
-
7
- import bokeh.plotting as bkp
8
- import matplotlib.pyplot as plt
9
- import numpy as np
10
- from bokeh.models import ColumnDataSource, DataRange1d, Span
11
- from bokeh.models.glyphs import Scatter
12
- from bokeh.models.annotations import Title
13
-
14
- from ...distplot import plot_dist
15
- from ...plot_utils import _scale_fig_size
16
- from ...rankplot import plot_rank
17
- from .. import show_layout
18
- from . import backend_kwarg_defaults, dealiase_sel_kwargs
19
- from ....sel_utils import xarray_var_iter
20
-
21
-
22
def _cds_column_name(var_name, selection):
    """Return the ColumnDataSource column name for one variable selection.

    An empty selection maps to ``var_name``. Otherwise the selection's keys and
    values are flattened — string or non-iterable values are kept as-is, other
    iterables are unpacked — converted with ``str`` and joined with ``"_"``
    after a ``_ARVIZ_CDS_SELECTION_`` marker.
    """
    if not selection:
        return var_name
    parts = []
    for key, sel_value in selection.items():
        parts.append(key)
        if isinstance(sel_value, str) or not isinstance(sel_value, Iterable):
            parts.append(sel_value)
        else:
            parts.extend(sel_value)
    return "{}_ARVIZ_CDS_SELECTION_{}".format(var_name, "_".join(str(item) for item in parts))


def plot_trace(
    data,
    var_names,
    divergences,
    kind,
    figsize,
    rug,
    lines,
    circ_var_names,  # pylint: disable=unused-argument
    circ_var_units,  # pylint: disable=unused-argument
    compact,
    compact_prop,
    combined,
    chain_prop,
    legend,
    labeller,
    plot_kwargs,
    fill_kwargs,
    rug_kwargs,
    hist_kwargs,
    trace_kwargs,
    rank_kwargs,
    plotters,
    divergence_data,
    axes,
    backend_kwargs,
    backend_config,
    show,
):
    """Bokeh traceplot.

    Builds one row of two figures per entry of ``plotters``: a density panel
    (column 0) and a trace/rank panel (column 1). All variables are first
    packed into one ColumnDataSource per chain, then drawn by
    ``_plot_chains_bokeh``. Panels get titles, optional reference lines from
    ``lines``, legend settings and optional divergence markers.

    Parameters mirror the backend-agnostic ``plot_trace``. Backend-specific ones:

    backend_kwargs : dict or None
        Passed to ``bokeh.plotting.figure``. ``dpi`` is read from it and removed.
    backend_config : dict or None
        Must provide (or default) ``bounds_y_range`` for the panels' y ranges.
    divergence_data :
        Required whenever ``divergences is not False`` (checked with ``assert``).

    Returns
    -------
    numpy.ndarray
        2-D array of bokeh figures, shape ``(len(plotters), 2)`` when created here.
    """
    # If divergences are plotted they must be provided
    if divergences is not False:
        assert divergence_data is not None

    if backend_config is None:
        backend_config = {}
    backend_config = {
        **backend_kwarg_defaults(
            ("bounds_y_range", "plot.bokeh.bounds_y_range"),
        ),
        **backend_config,
    }

    # Set plot default backend kwargs
    if backend_kwargs is None:
        backend_kwargs = {}
    backend_kwargs = {
        **backend_kwarg_defaults(
            ("dpi", "plot.bokeh.figure.dpi"),
        ),
        **backend_kwargs,
    }
    dpi = backend_kwargs.pop("dpi")

    if figsize is None:
        figsize = (12, len(plotters) * 2)

    figsize, _, _, _, linewidth, _ = _scale_fig_size(figsize, 10, rows=len(plotters), cols=2)

    backend_kwargs.setdefault("height", int(figsize[1] * dpi // len(plotters)))
    backend_kwargs.setdefault("width", int(figsize[0] * dpi // 2))

    if lines is None:
        lines = ()

    prop_cycle = plt.rcParams["axes.prop_cycle"].by_key()

    # One property value per chain, plus one for the "combined" density.
    num_chain_props = len(data.chain) + 1 if combined else len(data.chain)
    if not compact:
        if chain_prop is None:
            chain_prop = {"line_color": prop_cycle["color"]}
    else:
        if chain_prop is None:
            chain_prop = {"line_dash": ("solid", "dotted", "dashed", "dashdot")}
        if compact_prop is None:
            compact_prop = {"line_color": prop_cycle["color"]}

    if isinstance(chain_prop, str):
        chain_prop = {chain_prop: prop_cycle[chain_prop]}
    if isinstance(chain_prop, tuple):
        warnings.warn(
            "chain_prop as a tuple will be deprecated in a future warning, use a dict instead",
            FutureWarning,
        )
        chain_prop = {chain_prop[0]: chain_prop[1]}
    # Cycle each property list so it has exactly num_chain_props entries.
    chain_prop = {
        prop_name: [prop for _, prop in zip(range(num_chain_props), cycle(props))]
        for prop_name, props in chain_prop.items()
    }

    if isinstance(compact_prop, str):
        compact_prop = {compact_prop: prop_cycle[compact_prop]}
    if isinstance(compact_prop, tuple):
        warnings.warn(
            "compact_prop as a tuple will be deprecated in a future warning, use a dict instead",
            FutureWarning,
        )
        compact_prop = {compact_prop[0]: compact_prop[1]}
    # NOTE(review): compact_prop is normalised above but never used below in
    # this backend — confirm whether compact styling is meant to be applied.

    trace_kwargs = {} if trace_kwargs is None else trace_kwargs
    trace_kwargs.setdefault("alpha", 0.35)
    trace_kwargs.setdefault("line_width", linewidth)

    if hist_kwargs is None:
        hist_kwargs = {}
    hist_kwargs.setdefault("alpha", 0.35)

    if plot_kwargs is None:
        plot_kwargs = {}
    plot_kwargs.setdefault("line_width", linewidth)

    if fill_kwargs is None:
        fill_kwargs = {}
    if rug_kwargs is None:
        rug_kwargs = {}
    if rank_kwargs is None:
        rank_kwargs = {}

    if axes is None:
        axes = []
        trace_fig_kwargs = backend_kwargs.copy()
        for _ in range(len(plotters)):
            row = [bkp.figure(**backend_kwargs), bkp.figure(**trace_fig_kwargs)]
            # Every trace panel shares the first trace panel's x range
            # (setdefault is a no-op after the first row).
            trace_fig_kwargs.setdefault("x_range", row[1].x_range)
            axes.append(row)

    axes = np.atleast_2d(axes)

    # Pack every (flattened) variable into one column per selection,
    # one dict of columns per chain.
    cds_data = {}
    cds_var_groups = {}
    for var_name, selection, _, var_value in xarray_var_iter(
        data, var_names=var_names, combined=True
    ):
        cds_name = _cds_column_name(var_name, selection)
        cds_var_groups.setdefault(var_name, []).append(cds_name)
        for chain_idx in range(len(data.chain.values)):
            cds_data.setdefault(chain_idx, {})[cds_name] = var_value[chain_idx]

    # Choose a draw-index column name that cannot clash with a variable column.
    draw_name = "draw"
    while draw_name in cds_data.get(0, {}):
        draw_name += "w"

    for chain_columns in cds_data.values():
        chain_columns[draw_name] = data.draw.values

    cds_data = {chain_idx: ColumnDataSource(columns) for chain_idx, columns in cds_data.items()}

    shared_chain_kwargs = dict(
        data=cds_data,
        x_name=draw_name,
        chain_prop=chain_prop,
        combined=combined,
        rug=rug,
        kind=kind,
        legend=legend,
        trace_kwargs=trace_kwargs,
        hist_kwargs=hist_kwargs,
        plot_kwargs=plot_kwargs,
        fill_kwargs=fill_kwargs,
        rug_kwargs=rug_kwargs,
        rank_kwargs=rank_kwargs,
    )

    for idx, (var_name, selection, isel, value) in enumerate(plotters):
        value = np.atleast_2d(value)
        ax_density, ax_trace = axes[idx, 0], axes[idx, 1]

        # A 2-D value maps to exactly one CDS column. Higher-dimensional
        # values are drawn from every column recorded for the variable.
        if len(value.shape) == 2:
            y_names = [_cds_column_name(var_name, selection)]
        else:
            y_names = cds_var_groups[var_name]

        for y_name in y_names:
            if rug:
                rug_kwargs["y"] = y_name
            _plot_chains_bokeh(
                ax_density=ax_density,
                ax_trace=ax_trace,
                y_name=y_name,
                **shared_chain_kwargs,
            )

        panel_label = labeller.make_label_vert(var_name, selection, isel)
        for ax in (ax_density, ax_trace):
            _title = Title()
            _title.text = panel_label
            ax.title = _title
            ax.y_range = DataRange1d(bounds=backend_config["bounds_y_range"], min_interval=0.1)

        for _, _, vlines in (j for j in lines if j[0] == var_name and j[1] == selection):
            if isinstance(vlines, (float, int)):
                line_values = [vlines]
            else:
                line_values = np.atleast_1d(vlines).ravel()

            for line_value in line_values:
                # Vertical line on the density panel, horizontal on the trace.
                ax_density.renderers.append(
                    Span(
                        location=line_value,
                        dimension="height",
                        line_color="black",
                        line_width=1.5,
                        line_alpha=0.75,
                    )
                )
                ax_trace.renderers.append(
                    Span(
                        location=line_value,
                        dimension="width",
                        line_color="black",
                        line_width=1.5,
                        line_alpha=trace_kwargs["alpha"],
                    )
                )

        for ax in (ax_density, ax_trace):
            if legend:
                ax.legend.location = "top_left"
                ax.legend.click_policy = "hide"
            elif ax.legend:
                ax.legend.visible = False

        if divergences:
            div_kwargs = {
                "size": 14,
                "line_color": "red",
                "line_width": 2,
                "line_alpha": 0.50,
                "angle": np.pi / 2,
            }

            div_selection = {k: v for k, v in selection.items() if k in divergence_data.dims}
            divs = np.atleast_2d(divergence_data.sel(**div_selection).values)

            for chain, chain_divs in enumerate(divs):
                # chain_divs is used as a boolean mask over draw positions.
                div_idxs = np.arange(len(chain_divs))[chain_divs]
                if div_idxs.size == 0:
                    continue
                tmp_cds = ColumnDataSource({"y": value[chain, div_idxs], "x": div_idxs})
                y_div_trace = value.max() if divergences == "top" else value.min()
                glyph_density = Scatter(x="y", y=0.0, marker="dash", **div_kwargs)
                if kind == "trace":
                    glyph_trace = Scatter(x="x", y=y_div_trace, marker="dash", **div_kwargs)
                    ax_trace.add_glyph(tmp_cds, glyph_trace)

                ax_density.add_glyph(tmp_cds, glyph_density)

    show_layout(axes, show)

    return axes
356
-
357
-
358
def _plot_chains_bokeh(
    ax_density,
    ax_trace,
    data,
    x_name,
    y_name,
    chain_prop,
    combined,
    rug,
    kind,
    legend,
    trace_kwargs,
    hist_kwargs,
    plot_kwargs,
    fill_kwargs,
    rug_kwargs,
    rank_kwargs,
):
    """Draw the density and trace/rank glyphs of one CDS column for all chains.

    Parameters
    ----------
    ax_density, ax_trace : bokeh figures
        Density panel and trace/rank panel.
    data : dict
        Maps chain index to a source exposing ``.data`` (a ColumnDataSource
        when called from ``plot_trace``).
    x_name, y_name : str
        Column names of the draw index and of the variable to plot.
    chain_prop : dict
        Per-chain glyph properties, resolved with ``dealiase_sel_kwargs``.
    combined : bool
        Draw one density over all chains instead of one per chain.
    rug : bool
        Forwarded to ``plot_dist``.
    kind : {"trace", "rank_bars", "rank_vlines"}
        What to draw on ``ax_trace``.
    legend : bool
        Attach ``legend_label`` values to the drawn glyphs.
    trace_kwargs : dict
        Line kwargs for the trace. An optional ``"marker"`` key (default True)
        toggles the per-draw scatter markers. The caller's dict is not modified.
    hist_kwargs, plot_kwargs, fill_kwargs, rug_kwargs : dict
        Forwarded to ``plot_dist``. ``plot_kwargs``/``rug_kwargs`` get
        ``legend_label``/``cds`` written into them, as before.
    rank_kwargs : dict
        Forwarded to ``plot_rank``.
    """
    # Work on a copy: popping "marker" from the shared dict would make every
    # subsequent variable silently fall back to marker=True.
    trace_kwargs = dict(trace_kwargs)
    marker = trace_kwargs.pop("marker", True)
    for chain_idx, cds in data.items():
        if kind == "trace":
            if legend:
                trace_kwargs["legend_label"] = f"chain {chain_idx}"
            ax_trace.line(
                x=x_name,
                y=y_name,
                source=cds,
                **dealiase_sel_kwargs(trace_kwargs, chain_prop, chain_idx),
            )
            if marker:
                ax_trace.scatter(
                    x=x_name,
                    y=y_name,
                    marker="circle",
                    source=cds,
                    radius=0.30,
                    alpha=0.5,
                    **dealiase_sel_kwargs({}, chain_prop, chain_idx),
                )
        if not combined:
            rug_kwargs["cds"] = cds
            if legend:
                plot_kwargs["legend_label"] = f"chain {chain_idx}"
            plot_dist(
                cds.data[y_name],
                ax=ax_density,
                rug=rug,
                hist_kwargs=hist_kwargs,
                plot_kwargs=dealiase_sel_kwargs(plot_kwargs, chain_prop, chain_idx),
                fill_kwargs=fill_kwargs,
                rug_kwargs=rug_kwargs,
                backend="bokeh",
                backend_kwargs={},
                show=False,
            )

    # Rank plots use all chains at once, so they are drawn once, after the loop.
    rank_kind = {"rank_bars": "bars", "rank_vlines": "vlines"}.get(kind)
    if rank_kind is not None:
        value = np.array([item.data[y_name] for item in data.values()])
        plot_rank(value, kind=rank_kind, ax=ax_trace, backend="bokeh", show=False, **rank_kwargs)

    if combined:
        rug_kwargs["cds"] = data
        if legend:
            plot_kwargs["legend_label"] = "combined chains"
        plot_dist(
            np.concatenate([item.data[y_name] for item in data.values()]).flatten(),
            ax=ax_density,
            rug=rug,
            hist_kwargs=hist_kwargs,
            # Index -1 selects the extra property slot reserved for "combined".
            plot_kwargs=dealiase_sel_kwargs(plot_kwargs, chain_prop, -1),
            fill_kwargs=fill_kwargs,
            rug_kwargs=rug_kwargs,
            backend="bokeh",
            backend_kwargs={},
            show=False,
        )