arviz 0.23.3__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185) hide show
  1. arviz/__init__.py +52 -367
  2. arviz-1.0.0rc0.dist-info/METADATA +182 -0
  3. arviz-1.0.0rc0.dist-info/RECORD +5 -0
  4. {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
  5. {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
  6. arviz/data/__init__.py +0 -55
  7. arviz/data/base.py +0 -596
  8. arviz/data/converters.py +0 -203
  9. arviz/data/datasets.py +0 -161
  10. arviz/data/example_data/code/radon/radon.json +0 -326
  11. arviz/data/example_data/data/centered_eight.nc +0 -0
  12. arviz/data/example_data/data/non_centered_eight.nc +0 -0
  13. arviz/data/example_data/data_local.json +0 -12
  14. arviz/data/example_data/data_remote.json +0 -58
  15. arviz/data/inference_data.py +0 -2386
  16. arviz/data/io_beanmachine.py +0 -112
  17. arviz/data/io_cmdstan.py +0 -1036
  18. arviz/data/io_cmdstanpy.py +0 -1233
  19. arviz/data/io_datatree.py +0 -23
  20. arviz/data/io_dict.py +0 -462
  21. arviz/data/io_emcee.py +0 -317
  22. arviz/data/io_json.py +0 -54
  23. arviz/data/io_netcdf.py +0 -68
  24. arviz/data/io_numpyro.py +0 -497
  25. arviz/data/io_pyjags.py +0 -378
  26. arviz/data/io_pyro.py +0 -333
  27. arviz/data/io_pystan.py +0 -1095
  28. arviz/data/io_zarr.py +0 -46
  29. arviz/data/utils.py +0 -139
  30. arviz/labels.py +0 -210
  31. arviz/plots/__init__.py +0 -61
  32. arviz/plots/autocorrplot.py +0 -171
  33. arviz/plots/backends/__init__.py +0 -223
  34. arviz/plots/backends/bokeh/__init__.py +0 -166
  35. arviz/plots/backends/bokeh/autocorrplot.py +0 -101
  36. arviz/plots/backends/bokeh/bfplot.py +0 -23
  37. arviz/plots/backends/bokeh/bpvplot.py +0 -193
  38. arviz/plots/backends/bokeh/compareplot.py +0 -167
  39. arviz/plots/backends/bokeh/densityplot.py +0 -239
  40. arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
  41. arviz/plots/backends/bokeh/distplot.py +0 -183
  42. arviz/plots/backends/bokeh/dotplot.py +0 -113
  43. arviz/plots/backends/bokeh/ecdfplot.py +0 -73
  44. arviz/plots/backends/bokeh/elpdplot.py +0 -203
  45. arviz/plots/backends/bokeh/energyplot.py +0 -155
  46. arviz/plots/backends/bokeh/essplot.py +0 -176
  47. arviz/plots/backends/bokeh/forestplot.py +0 -772
  48. arviz/plots/backends/bokeh/hdiplot.py +0 -54
  49. arviz/plots/backends/bokeh/kdeplot.py +0 -268
  50. arviz/plots/backends/bokeh/khatplot.py +0 -163
  51. arviz/plots/backends/bokeh/lmplot.py +0 -185
  52. arviz/plots/backends/bokeh/loopitplot.py +0 -211
  53. arviz/plots/backends/bokeh/mcseplot.py +0 -184
  54. arviz/plots/backends/bokeh/pairplot.py +0 -328
  55. arviz/plots/backends/bokeh/parallelplot.py +0 -81
  56. arviz/plots/backends/bokeh/posteriorplot.py +0 -324
  57. arviz/plots/backends/bokeh/ppcplot.py +0 -379
  58. arviz/plots/backends/bokeh/rankplot.py +0 -149
  59. arviz/plots/backends/bokeh/separationplot.py +0 -107
  60. arviz/plots/backends/bokeh/traceplot.py +0 -436
  61. arviz/plots/backends/bokeh/violinplot.py +0 -164
  62. arviz/plots/backends/matplotlib/__init__.py +0 -124
  63. arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
  64. arviz/plots/backends/matplotlib/bfplot.py +0 -78
  65. arviz/plots/backends/matplotlib/bpvplot.py +0 -177
  66. arviz/plots/backends/matplotlib/compareplot.py +0 -135
  67. arviz/plots/backends/matplotlib/densityplot.py +0 -194
  68. arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
  69. arviz/plots/backends/matplotlib/distplot.py +0 -178
  70. arviz/plots/backends/matplotlib/dotplot.py +0 -116
  71. arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
  72. arviz/plots/backends/matplotlib/elpdplot.py +0 -189
  73. arviz/plots/backends/matplotlib/energyplot.py +0 -113
  74. arviz/plots/backends/matplotlib/essplot.py +0 -180
  75. arviz/plots/backends/matplotlib/forestplot.py +0 -656
  76. arviz/plots/backends/matplotlib/hdiplot.py +0 -48
  77. arviz/plots/backends/matplotlib/kdeplot.py +0 -177
  78. arviz/plots/backends/matplotlib/khatplot.py +0 -241
  79. arviz/plots/backends/matplotlib/lmplot.py +0 -149
  80. arviz/plots/backends/matplotlib/loopitplot.py +0 -144
  81. arviz/plots/backends/matplotlib/mcseplot.py +0 -161
  82. arviz/plots/backends/matplotlib/pairplot.py +0 -355
  83. arviz/plots/backends/matplotlib/parallelplot.py +0 -58
  84. arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
  85. arviz/plots/backends/matplotlib/ppcplot.py +0 -478
  86. arviz/plots/backends/matplotlib/rankplot.py +0 -119
  87. arviz/plots/backends/matplotlib/separationplot.py +0 -97
  88. arviz/plots/backends/matplotlib/traceplot.py +0 -526
  89. arviz/plots/backends/matplotlib/tsplot.py +0 -121
  90. arviz/plots/backends/matplotlib/violinplot.py +0 -148
  91. arviz/plots/bfplot.py +0 -128
  92. arviz/plots/bpvplot.py +0 -308
  93. arviz/plots/compareplot.py +0 -177
  94. arviz/plots/densityplot.py +0 -284
  95. arviz/plots/distcomparisonplot.py +0 -197
  96. arviz/plots/distplot.py +0 -233
  97. arviz/plots/dotplot.py +0 -233
  98. arviz/plots/ecdfplot.py +0 -372
  99. arviz/plots/elpdplot.py +0 -174
  100. arviz/plots/energyplot.py +0 -147
  101. arviz/plots/essplot.py +0 -319
  102. arviz/plots/forestplot.py +0 -304
  103. arviz/plots/hdiplot.py +0 -211
  104. arviz/plots/kdeplot.py +0 -357
  105. arviz/plots/khatplot.py +0 -236
  106. arviz/plots/lmplot.py +0 -380
  107. arviz/plots/loopitplot.py +0 -224
  108. arviz/plots/mcseplot.py +0 -194
  109. arviz/plots/pairplot.py +0 -281
  110. arviz/plots/parallelplot.py +0 -204
  111. arviz/plots/plot_utils.py +0 -599
  112. arviz/plots/posteriorplot.py +0 -298
  113. arviz/plots/ppcplot.py +0 -369
  114. arviz/plots/rankplot.py +0 -232
  115. arviz/plots/separationplot.py +0 -167
  116. arviz/plots/styles/arviz-bluish.mplstyle +0 -1
  117. arviz/plots/styles/arviz-brownish.mplstyle +0 -1
  118. arviz/plots/styles/arviz-colors.mplstyle +0 -2
  119. arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
  120. arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
  121. arviz/plots/styles/arviz-doc.mplstyle +0 -88
  122. arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
  123. arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
  124. arviz/plots/styles/arviz-greenish.mplstyle +0 -1
  125. arviz/plots/styles/arviz-orangish.mplstyle +0 -1
  126. arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
  127. arviz/plots/styles/arviz-purplish.mplstyle +0 -1
  128. arviz/plots/styles/arviz-redish.mplstyle +0 -1
  129. arviz/plots/styles/arviz-royish.mplstyle +0 -1
  130. arviz/plots/styles/arviz-viridish.mplstyle +0 -1
  131. arviz/plots/styles/arviz-white.mplstyle +0 -40
  132. arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
  133. arviz/plots/traceplot.py +0 -273
  134. arviz/plots/tsplot.py +0 -440
  135. arviz/plots/violinplot.py +0 -192
  136. arviz/preview.py +0 -58
  137. arviz/py.typed +0 -0
  138. arviz/rcparams.py +0 -606
  139. arviz/sel_utils.py +0 -223
  140. arviz/static/css/style.css +0 -340
  141. arviz/static/html/icons-svg-inline.html +0 -15
  142. arviz/stats/__init__.py +0 -37
  143. arviz/stats/density_utils.py +0 -1013
  144. arviz/stats/diagnostics.py +0 -1013
  145. arviz/stats/ecdf_utils.py +0 -324
  146. arviz/stats/stats.py +0 -2422
  147. arviz/stats/stats_refitting.py +0 -119
  148. arviz/stats/stats_utils.py +0 -609
  149. arviz/tests/__init__.py +0 -1
  150. arviz/tests/base_tests/__init__.py +0 -1
  151. arviz/tests/base_tests/test_data.py +0 -1679
  152. arviz/tests/base_tests/test_data_zarr.py +0 -143
  153. arviz/tests/base_tests/test_diagnostics.py +0 -511
  154. arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
  155. arviz/tests/base_tests/test_helpers.py +0 -18
  156. arviz/tests/base_tests/test_labels.py +0 -69
  157. arviz/tests/base_tests/test_plot_utils.py +0 -342
  158. arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
  159. arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
  160. arviz/tests/base_tests/test_rcparams.py +0 -317
  161. arviz/tests/base_tests/test_stats.py +0 -925
  162. arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
  163. arviz/tests/base_tests/test_stats_numba.py +0 -45
  164. arviz/tests/base_tests/test_stats_utils.py +0 -384
  165. arviz/tests/base_tests/test_utils.py +0 -376
  166. arviz/tests/base_tests/test_utils_numba.py +0 -87
  167. arviz/tests/conftest.py +0 -46
  168. arviz/tests/external_tests/__init__.py +0 -1
  169. arviz/tests/external_tests/test_data_beanmachine.py +0 -78
  170. arviz/tests/external_tests/test_data_cmdstan.py +0 -398
  171. arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
  172. arviz/tests/external_tests/test_data_emcee.py +0 -166
  173. arviz/tests/external_tests/test_data_numpyro.py +0 -434
  174. arviz/tests/external_tests/test_data_pyjags.py +0 -119
  175. arviz/tests/external_tests/test_data_pyro.py +0 -260
  176. arviz/tests/external_tests/test_data_pystan.py +0 -307
  177. arviz/tests/helpers.py +0 -677
  178. arviz/utils.py +0 -773
  179. arviz/wrappers/__init__.py +0 -13
  180. arviz/wrappers/base.py +0 -236
  181. arviz/wrappers/wrap_pymc.py +0 -36
  182. arviz/wrappers/wrap_stan.py +0 -148
  183. arviz-0.23.3.dist-info/METADATA +0 -264
  184. arviz-0.23.3.dist-info/RECORD +0 -183
  185. arviz-0.23.3.dist-info/top_level.txt +0 -1
@@ -1,135 +0,0 @@
1
- """Matplotlib Compareplot."""
2
-
3
- import matplotlib.pyplot as plt
4
-
5
- from ...plot_utils import _scale_fig_size
6
- from . import backend_kwarg_defaults, backend_show, create_axes_grid
7
-
8
-
9
def plot_compare(
    ax,
    comp_df,
    legend,
    title,
    figsize,
    plot_ic_diff,
    plot_standard_error,
    insample_dev,
    yticks_pos,
    yticks_labels,
    plot_kwargs,
    information_criterion,
    textsize,
    step,
    backend_kwargs,
    show,
):
    """Matplotlib compare plot.

    Parameters
    ----------
    ax : matplotlib axes or None
        Axes to draw on. When None, a single new axes is created.
    comp_df : pandas.DataFrame
        Model comparison table, one row per model. The columns read here are
        ``information_criterion``, ``se``, ``dse``, ``scale`` and, when
        ``insample_dev`` is true, ``p_<suffix>`` where ``<suffix>`` is the
        second ``_``-separated token of ``information_criterion``.
    legend : bool
        Draw a legend anchored to the right of the axes.
    title : bool
        Draw a "Model comparison" title stating whether higher or lower is better.
    figsize : tuple or None
        Figure size; when None it defaults to ``(6, len(comp_df))``.
    plot_ic_diff : bool
        Also draw difference markers with ``dse`` error bars for every row
        except the first one.
    plot_standard_error : bool
        Draw ``se`` error bars around each model's value; otherwise markers only.
    insample_dev : bool
        Also draw the in-sample value: the criterion corrected by the
        ``p_<suffix>`` column according to ``scale``.
    yticks_pos : sequence
        Y positions; even entries hold the models, odd entries the differences.
    yticks_labels : sequence
        Labels for the y ticks.
    plot_kwargs : dict
        Style overrides looked up with ``.get``: ``color_ic``, ``marker_ic``,
        ``marker_fc``, ``color_dse``, ``marker_dse``, ``color_insample_dev``,
        ``marker_insample_dev``, ``ls_min_ic``, ``color_ls_min_ic``.
    information_criterion : str
        Column holding the criterion values; also used in the x label.
    textsize : float
        Base text size forwarded to ``_scale_fig_size``.
    step : float
        Margin used for the y limits ``(-1 + step, -step)``.
    backend_kwargs : dict or None
        Overrides merged on top of ``backend_kwarg_defaults()``; ``squeeze``
        is always forced to True.
    show : bool
        Forwarded to ``backend_show`` to decide whether ``plt.show()`` is called.

    Returns
    -------
    ax : matplotlib axes

    Raises
    ------
    ValueError
        If ``insample_dev`` is true and ``comp_df["scale"]`` is not one of
        ``"log"``, ``"negative_log"`` or ``"deviance"``.
    """
    if backend_kwargs is None:
        backend_kwargs = {}

    backend_kwargs = {
        **backend_kwarg_defaults(),
        **backend_kwargs,
    }

    if figsize is None:
        figsize = (6, len(comp_df))

    figsize, ax_labelsize, _, xt_labelsize, linewidth, _ = _scale_fig_size(figsize, textsize, 1, 1)

    backend_kwargs.setdefault("figsize", figsize)
    backend_kwargs["squeeze"] = True

    if ax is None:
        _, ax = create_axes_grid(1, backend_kwargs=backend_kwargs)

    # Model values sit on the even tick positions.
    if plot_standard_error:
        ax.errorbar(
            x=comp_df[information_criterion],
            y=yticks_pos[::2],
            xerr=comp_df.se,
            label="ELPD",
            color=plot_kwargs.get("color_ic", "k"),
            fmt=plot_kwargs.get("marker_ic", "o"),
            mfc=plot_kwargs.get("marker_fc", "white"),
            mew=linewidth,
            lw=linewidth,
        )
    else:
        ax.plot(
            comp_df[information_criterion],
            yticks_pos[::2],
            label="ELPD",
            color=plot_kwargs.get("color_ic", "k"),
            marker=plot_kwargs.get("marker_ic", "o"),
            mfc=plot_kwargs.get("marker_fc", "white"),
            mew=linewidth,
            lw=0,
            zorder=3,
        )

    if plot_ic_diff:
        ax.set_yticks(yticks_pos)
        # Differences sit on the odd positions; the first (reference) row has none.
        ax.errorbar(
            x=comp_df[information_criterion].iloc[1:],
            y=yticks_pos[1::2],
            xerr=comp_df.dse[1:],
            label="ELPD difference",
            color=plot_kwargs.get("color_dse", "grey"),
            fmt=plot_kwargs.get("marker_dse", "^"),
            mew=linewidth,
            elinewidth=linewidth,
        )

    else:
        ax.set_yticks(yticks_pos[::2])

    scale = comp_df["scale"].iloc[0]

    if insample_dev:
        p_ic = comp_df[f"p_{information_criterion.split('_')[1]}"]
        if scale == "log":
            correction = p_ic
        elif scale == "negative_log":
            correction = -p_ic
        elif scale == "deviance":
            correction = -(2 * p_ic)
        else:
            # Previously this fell through and failed with an unbound `correction`.
            raise ValueError(
                f"Unsupported scale {scale!r} for the in-sample value; "
                "expected 'log', 'negative_log' or 'deviance'."
            )
        ax.plot(
            comp_df[information_criterion] + correction,
            yticks_pos[::2],
            label="In-sample ELPD",
            color=plot_kwargs.get("color_insample_dev", "k"),
            marker=plot_kwargs.get("marker_insample_dev", "o"),
            mew=linewidth,
            lw=0,
        )

    # Reference line at the first row's value.
    ax.axvline(
        comp_df[information_criterion].iloc[0],
        ls=plot_kwargs.get("ls_min_ic", "--"),
        color=plot_kwargs.get("color_ls_min_ic", "grey"),
        lw=linewidth,
    )
    if legend:
        ax.legend(bbox_to_anchor=(1.01, 1), loc="upper left", ncol=1, fontsize=ax_labelsize)

    if title:
        ax.set_title(
            f"Model comparison\n{'higher' if scale == 'log' else 'lower'} is better",
            fontsize=ax_labelsize,
        )

    if scale == "negative_log":
        scale = "-log"

    ax.set_xlabel(f"{information_criterion} ({scale})", fontsize=ax_labelsize)
    ax.set_ylabel("ranked models", fontsize=ax_labelsize)
    ax.set_yticklabels(yticks_labels)
    ax.set_ylim(-1 + step, 0 - step)
    ax.tick_params(labelsize=xt_labelsize)

    if backend_show(show):
        plt.show()

    return ax
@@ -1,194 +0,0 @@
1
- """Matplotlib Densityplot."""
2
-
3
- from itertools import cycle
4
-
5
- import matplotlib.pyplot as plt
6
- import numpy as np
7
-
8
- from ....stats import hdi
9
- from ....stats.density_utils import get_bins, kde
10
- from ...plot_utils import _scale_fig_size, calculate_point_estimate
11
- from . import backend_kwarg_defaults, backend_show, create_axes_grid
12
-
13
-
14
def plot_density(
    ax, all_labels, to_plot, colors, bw, circular, figsize, length_plotters,
    rows, cols, textsize, labeller, hdi_prob, point_estimate, hdi_markers,
    outline, shade, n_data, data_labels, backend_kwargs, show,
):
    """Matplotlib densityplot.

    Draws every variable of every dataset in ``to_plot`` onto the axes matched
    to its label, so datasets sharing a label are overlaid on the same axes.

    Parameters
    ----------
    ax : matplotlib axes, array of axes, or None
        Target axes. When None, ``length_plotters`` axes are created in a
        ``rows`` x ``cols`` grid. Axes are paired with ``all_labels`` in
        flattened order.
    all_labels : sequence of str
        Panel labels, one per axes.
    to_plot : sequence
        One iterable of ``(var_name, selection, isel, values)`` per dataset.
    colors : "cycle", str or sequence
        ``"cycle"`` takes ``n_data`` colors from matplotlib's property cycle,
        a single string is repeated for every dataset, anything else is
        indexed by dataset position.
    bw, circular, hdi_prob, point_estimate, hdi_markers, outline, shade
        Forwarded unchanged to ``_d_helper`` for every panel.
    figsize, textsize, rows, cols
        Forwarded to ``_scale_fig_size``.
    length_plotters : int
        Number of axes created when ``ax`` is None.
    labeller
        ``make_label_vert`` builds the panel label for each variable.
    n_data : int
        Number of datasets; a legend is drawn only when it exceeds 1.
    data_labels : sequence of str
        Legend entries, drawn on the first axes.
    backend_kwargs : dict or None
        Overrides merged on top of ``backend_kwarg_defaults()``.
    show : bool
        Forwarded to ``backend_show``.

    Returns
    -------
    ax
    """
    layout_kwargs = dict(backend_kwarg_defaults())
    if backend_kwargs is not None:
        layout_kwargs.update(backend_kwargs)

    if colors == "cycle":
        palette = plt.rcParams["axes.prop_cycle"].by_key()["color"]
        colors = [color for color, _ in zip(cycle(palette), range(n_data))]
    elif isinstance(colors, str):
        colors = [colors] * n_data

    figsize, _, titlesize, xt_labelsize, linewidth, markersize = _scale_fig_size(
        figsize, textsize, rows, cols
    )

    layout_kwargs.setdefault("figsize", figsize)
    layout_kwargs.setdefault("squeeze", False)
    if ax is None:
        _, ax = create_axes_grid(length_plotters, rows, cols, backend_kwargs=layout_kwargs)

    flat_axes = np.ravel(ax)
    axis_by_label = dict(zip(all_labels, flat_axes))

    for data_idx, plotters in enumerate(to_plot):
        for var_name, selection, isel, values in plotters:
            panel_label = labeller.make_label_vert(var_name, selection, isel)
            _d_helper(
                values.flatten(), panel_label, colors[data_idx], bw, circular,
                titlesize, xt_labelsize, linewidth, markersize, hdi_prob,
                point_estimate, hdi_markers, outline, shade,
                axis_by_label[panel_label],
            )

    if n_data > 1:
        # Empty lines act as legend proxies, one per dataset.
        legend_ax = flat_axes.item(0)
        for data_idx, data_label in enumerate(data_labels):
            legend_ax.plot([], label=data_label, c=colors[data_idx], markersize=markersize)
        legend_ax.legend(fontsize=xt_labelsize)

    if backend_show(show):
        plt.show()

    return ax
102
-
103
-
104
def _d_helper(
    vec,
    vname,
    color,
    bw,
    circular,
    titlesize,
    xt_labelsize,
    linewidth,
    markersize,
    hdi_prob,
    point_estimate,
    hdi_markers,
    outline,
    shade,
    ax,
):
    """Plot an individual dimension.

    Float samples are drawn as a KDE; samples of any other dtype are drawn as
    a histogram.

    Parameters
    ----------
    vec : array
        1D array of samples.
    vname : str
        Title of the axes.
    color : str
        Matplotlib color.
    bw : float or str
        If numeric, indicates the bandwidth and must be positive.
        If str, indicates the method to estimate the bandwidth and must be
        one of "scott", "silverman", "isj" or "experimental" when `circular` is False
        and "taylor" (for now) when `circular` is True.
    circular : bool
        Forwarded to ``kde``.
    titlesize : float
        Font size for the title.
    xt_labelsize : float
        Font size for the x tick labels.
    linewidth : float
        Thickness of the KDE outline.
    markersize : float
        Size of the HDI and point-estimate markers.
    hdi_prob : float
        Probability for the highest density interval. For float data the KDE
        is fitted only to the samples inside the HDI (all samples when 1) and
        its height is multiplied by ``hdi_prob``.
    point_estimate : str or None
        Forwarded to ``calculate_point_estimate``; None skips the marker.
    hdi_markers : str
        Matplotlib marker format drawn at both ends of the plotted range (the
        KDE grid ends for float data, the HDI bounds otherwise); a falsy value
        skips them.
    outline : bool
        Draw the KDE line with closing vertical segments, or a step histogram.
    shade : float
        Alpha of the filled area / filled histogram; 0 disables the fill.
    ax : matplotlib axes
    """
    if vec.dtype.kind == "f":
        if hdi_prob != 1:
            hdi_ = hdi(vec, hdi_prob, multimodal=False)
            new_vec = vec[(vec >= hdi_[0]) & (vec <= hdi_[1])]
        else:
            new_vec = vec

        x, density = kde(new_vec, circular=circular, bw=bw)
        # Scale so the area reflects the probability mass kept inside the HDI.
        density *= hdi_prob
        xmin, xmax = x[0], x[-1]
        ymin, ymax = density[0], density[-1]

        if outline:
            ax.plot(x, density, color=color, lw=linewidth)
            # Close the curve with vertical segments down to just below zero.
            ax.plot([xmin, xmin], [-ymin / 100, ymin], color=color, ls="-", lw=linewidth)
            ax.plot([xmax, xmax], [-ymax / 100, ymax], color=color, ls="-", lw=linewidth)

        if shade:
            ax.fill_between(x, density, color=color, alpha=shade)

    else:
        xmin, xmax = hdi(vec, hdi_prob, multimodal=False)
        bins = get_bins(vec)
        if outline:
            ax.hist(vec, bins=bins, color=color, histtype="step", align="left")
        if shade:
            # Use the same alignment as the outline; the default "mid" would
            # shift the filled bars half a bin away from the step outline.
            ax.hist(vec, bins=bins, color=color, alpha=shade, align="left")

    if hdi_markers:
        ax.plot(xmin, 0, hdi_markers, color=color, markeredgecolor="k", markersize=markersize)
        ax.plot(xmax, 0, hdi_markers, color=color, markeredgecolor="k", markersize=markersize)

    if point_estimate is not None:
        est = calculate_point_estimate(point_estimate, vec, bw)
        ax.plot(est, 0, "o", color=color, markeredgecolor="k", markersize=markersize)

    ax.set_yticks([])
    ax.set_title(vname, fontsize=titlesize, wrap=True)
    for pos in ["left", "right", "top"]:
        ax.spines[pos].set_visible(False)
    ax.tick_params(labelsize=xt_labelsize)
@@ -1,119 +0,0 @@
1
- """Matplotlib Density Comparison plot."""
2
-
3
- import matplotlib.pyplot as plt
4
- import numpy as np
5
-
6
- from ...distplot import plot_dist
7
- from ...plot_utils import _scale_fig_size
8
- from . import backend_kwarg_defaults, backend_show
9
-
10
-
11
def plot_dist_comparison(
    ax,
    nvars,
    ngroups,
    figsize,
    dc_plotters,
    legend,
    groups,
    textsize,
    labeller,
    prior_kwargs,
    posterior_kwargs,
    observed_kwargs,
    backend_kwargs,
    show,
):
    """Matplotlib Density Comparison plot.

    For each variable, draw one panel per group plus one wider panel where
    every group is overlaid.

    Parameters
    ----------
    ax : numpy.ndarray of axes or None
        Axes array of shape ``(nvars, ngroups + 1)``; the last column holds
        the overlay panels. When None, a figure with ``2 * nvars`` gridspec
        rows is created: the per-group panels on one row and a full-width
        overlay panel on the row below, for each variable.
    nvars, ngroups : int
        Number of variables and of groups.
    figsize, textsize
        Forwarded to ``_scale_fig_size``.
    dc_plotters : sequence
        One iterable of ``(var_name, sel, isel, data)`` per group.
    legend : bool
        Pass the group name as ``label`` to ``plot_dist``.
    groups : sequence of str
        Group names. Groups starting with "prior" use ``prior_kwargs``, groups
        starting with "posterior" use ``posterior_kwargs``, all others use
        ``observed_kwargs``.
    labeller
        ``make_label_vert`` builds the x label of each overlay panel.
    prior_kwargs, posterior_kwargs, observed_kwargs : dict or None
        Keyword arguments for ``plot_dist``. Defaults for ``plot_kwargs``
        (color C1/C0/C2, linewidth) and ``hist_kwargs`` (alpha 0.5) are
        filled into copies; the caller's dicts are not modified.
    backend_kwargs : dict or None
        Overrides merged on top of ``backend_kwarg_defaults()`` and passed to
        ``plt.figure``.
    show : bool
        Forwarded to ``backend_show``.

    Returns
    -------
    axes : numpy.ndarray of axes

    Raises
    ------
    ValueError
        If ``ax`` is given with a shape other than ``(nvars, ngroups + 1)``.
    """
    backend_kwargs = {
        **backend_kwarg_defaults(),
        **(backend_kwargs if backend_kwargs is not None else {}),
    }

    (figsize, _, _, _, linewidth, _) = _scale_fig_size(figsize, textsize, 2 * nvars, ngroups)

    backend_kwargs.setdefault("figsize", figsize)

    # Defaults are applied to copies so the caller's dicts stay untouched.
    posterior_kwargs = _dist_kwargs_with_defaults(posterior_kwargs, "C0", linewidth)
    prior_kwargs = _dist_kwargs_with_defaults(prior_kwargs, "C1", linewidth)
    observed_kwargs = _dist_kwargs_with_defaults(observed_kwargs, "C2", linewidth)

    if ax is None:
        axes = np.empty((nvars, ngroups + 1), dtype=object)
        fig = plt.figure(**backend_kwargs)
        gs = fig.add_gridspec(ncols=ngroups, nrows=nvars * 2)
        for i in range(nvars):
            for j in range(ngroups):
                axes[i, j] = fig.add_subplot(gs[2 * i, j])
            # Full-width overlay panel below the per-group panels.
            axes[i, -1] = fig.add_subplot(gs[2 * i + 1, :])

    else:
        axes = ax
        if ax.shape != (nvars, ngroups + 1):
            raise ValueError(
                f"Found {axes.shape} shape of axes, "
                f"which is not equal to data shape {(nvars, ngroups + 1)}."
            )

    for idx, plotter in enumerate(dc_plotters):
        group = groups[idx]
        if group.startswith("prior"):
            kwargs = prior_kwargs
        elif group.startswith("posterior"):
            kwargs = posterior_kwargs
        else:
            kwargs = observed_kwargs
        label = f"{group}" if legend else None
        for idx2, (var_name, sel, isel, data) in enumerate(plotter):
            # Draw on the group's own panel and on the shared overlay panel.
            plot_dist(data, label=label, ax=axes[idx2, idx], **kwargs)
            plot_dist(data, label=label, ax=axes[idx2, -1], **kwargs)
            if idx == 0:
                axes[idx2, -1].set_xlabel(labeller.make_label_vert(var_name, sel, isel))

    if backend_show(show):
        plt.show()

    return axes


def _dist_kwargs_with_defaults(kwargs, color, linewidth):
    """Return a copy of ``kwargs`` with plot/hist defaults for one group filled in.

    The outer dict and its nested ``plot_kwargs``/``hist_kwargs`` dicts are
    copied, so the dicts passed by the caller are never modified. A missing or
    None nested entry is treated as an empty dict.
    """
    kwargs = dict(kwargs) if kwargs is not None else {}

    plot_kwargs = dict(kwargs.get("plot_kwargs") or {})
    plot_kwargs.setdefault("color", color)
    plot_kwargs.setdefault("linewidth", linewidth)

    hist_kwargs = dict(kwargs.get("hist_kwargs") or {})
    hist_kwargs.setdefault("alpha", 0.5)

    kwargs["plot_kwargs"] = plot_kwargs
    kwargs["hist_kwargs"] = hist_kwargs
    return kwargs
@@ -1,178 +0,0 @@
1
- """Matplotlib distplot."""
2
-
3
- import matplotlib.pyplot as plt
4
- from matplotlib import _pylab_helpers
5
- import numpy as np
6
-
7
- from ....stats.density_utils import get_bins
8
- from ...kdeplot import plot_kde
9
- from ...plot_utils import _scale_fig_size, _init_kwargs_dict
10
- from . import backend_kwarg_defaults, backend_show, create_axes_grid, matplotlib_kwarg_dealiaser
11
-
12
-
13
def plot_dist(
    values,
    values2,
    color,
    kind,
    cumulative,
    label,
    rotated,
    rug,
    bw,
    quantiles,
    contour,
    fill_last,
    figsize,
    textsize,
    plot_kwargs,
    fill_kwargs,
    rug_kwargs,
    contour_kwargs,
    contourf_kwargs,
    pcolormesh_kwargs,
    hist_kwargs,
    is_circular,
    ax,
    backend_kwargs,
    show,
):
    """Matplotlib distplot.

    Parameters
    ----------
    values, values2 : array-like
        Data to plot; ``values2`` is forwarded to ``plot_kde`` (histograms
        raise NotImplementedError when it is not None).
    color : str
        Default color for the histogram or KDE line.
    kind : str
        ``"hist"`` or ``"kde"``; any other value draws nothing.
    cumulative, label, rotated
        Applied to the histogram or forwarded to ``plot_kde``.
    rug, bw, quantiles, contour, fill_last, plot_kwargs, fill_kwargs,
    rug_kwargs, contour_kwargs, contourf_kwargs, pcolormesh_kwargs
        Forwarded to ``plot_kde`` when ``kind == "kde"``.
    figsize, textsize
        Forwarded to ``_scale_fig_size``.
    hist_kwargs : dict or None
        Keyword arguments for ``ax.hist`` when ``kind == "hist"``.
    is_circular : bool or str
        Requests a polar axes when new axes are created; ``"degrees"``
        converts histogram data from degrees.
    ax : matplotlib axes or None
        When None, the current axes of the active figure is used if a figure
        is active, otherwise a new axes is created.
    backend_kwargs : dict or None
        Overrides merged on top of ``backend_kwarg_defaults()``.
    show : bool
        Forwarded to ``backend_show``.

    Returns
    -------
    ax : matplotlib axes
    """
    backend_kwargs = _init_kwargs_dict(backend_kwargs)

    backend_kwargs = {
        **backend_kwarg_defaults(),
        **backend_kwargs,
    }

    figsize, *_ = _scale_fig_size(figsize, textsize)

    backend_kwargs.setdefault("figsize", figsize)
    backend_kwargs["squeeze"] = True
    # Copy the nested dict so the caller's subplot_kw is not modified.
    subplot_kw = dict(backend_kwargs.get("subplot_kw") or {})
    subplot_kw.setdefault("polar", is_circular)
    backend_kwargs["subplot_kw"] = subplot_kw

    if ax is None:
        fig_manager = _pylab_helpers.Gcf.get_active()
        if fig_manager is not None:
            ax = fig_manager.canvas.figure.gca()
        else:
            _, ax = create_axes_grid(
                1,
                backend_kwargs=backend_kwargs,
            )

    if kind == "hist":
        hist_kwargs = matplotlib_kwarg_dealiaser(hist_kwargs, "hist")
        hist_kwargs.setdefault("cumulative", cumulative)
        hist_kwargs.setdefault("color", color)
        hist_kwargs.setdefault("label", label)
        hist_kwargs.setdefault("rwidth", 0.9)
        hist_kwargs.setdefault("density", True)

        if rotated:
            hist_kwargs.setdefault("orientation", "horizontal")
        else:
            hist_kwargs.setdefault("orientation", "vertical")

        ax = _histplot_mpl_op(
            values=values,
            values2=values2,
            rotated=rotated,
            ax=ax,
            hist_kwargs=hist_kwargs,
            is_circular=is_circular,
        )

    elif kind == "kde":
        plot_kwargs = matplotlib_kwarg_dealiaser(plot_kwargs, "plot")
        plot_kwargs.setdefault("color", color)
        legend = label is not None

        ax = plot_kde(
            values,
            values2,
            cumulative=cumulative,
            rug=rug,
            label=label,
            bw=bw,
            quantiles=quantiles,
            rotated=rotated,
            contour=contour,
            legend=legend,
            fill_last=fill_last,
            textsize=textsize,
            plot_kwargs=plot_kwargs,
            fill_kwargs=fill_kwargs,
            rug_kwargs=rug_kwargs,
            contour_kwargs=contour_kwargs,
            contourf_kwargs=contourf_kwargs,
            pcolormesh_kwargs=pcolormesh_kwargs,
            ax=ax,
            backend="matplotlib",
            backend_kwargs=backend_kwargs,
            is_circular=is_circular,
            # Showing is handled once below; forwarding `show` here would
            # trigger plt.show() a second time.
            show=False,
        )

    if backend_show(show):
        plt.show()

    return ax
122
-
123
-
124
def _histplot_mpl_op(values, values2, rotated, ax, hist_kwargs, is_circular):
    """Add a histogram for the data to the axes.

    Parameters
    ----------
    values : array-like
        Samples to histogram; converted with ``np.asarray`` and flattened.
    values2 : array-like or None
        Must be None; 2D histograms are not implemented.
    rotated : bool
        If True, bin positions are set as y ticks (horizontal histogram).
    ax : matplotlib axes
    hist_kwargs : dict or None
        Keyword arguments for ``ax.hist``. An optional ``"bins"`` entry gives
        the bin edges; otherwise ``get_bins`` chooses them. The dict is copied
        and never modified.
    is_circular : bool or str
        ``"degrees"`` converts values and bins from degrees to radians. Any
        other truthy value labels the angular ticks in radians. For any
        circular plot the radial tick labels are hidden and the radial limit
        is 1.5 times the tallest bar.

    Returns
    -------
    ax : matplotlib axes

    Raises
    ------
    NotImplementedError
        If ``values2`` is not None.
    """
    # Work on a copy so the caller's dict keeps its "bins" and gets no "align".
    hist_kwargs = {} if hist_kwargs is None else dict(hist_kwargs)
    bins = hist_kwargs.pop("bins", None)
    # Accept lists and other array-likes; dtype is inspected below.
    values = np.asarray(values)

    if is_circular == "degrees":
        if bins is None:
            bins = get_bins(values)
        values = np.deg2rad(values)
        bins = np.deg2rad(bins)

    elif is_circular:
        labels = [
            "0",
            f"{np.pi/4:.2f}",
            f"{np.pi/2:.2f}",
            f"{3*np.pi/4:.2f}",
            f"{np.pi:.2f}",
            f"{-3*np.pi/4:.2f}",
            f"{-np.pi/2:.2f}",
            f"{-np.pi/4:.2f}",
        ]

        ax.set_xticklabels(labels)

    if values2 is not None:
        raise NotImplementedError("Insert hexbin plot here")

    if bins is None:
        bins = get_bins(values)

    # Signed and unsigned integer samples are discrete: align bars on the left
    # bin edge and place ticks there; continuous data uses bin centres.
    discrete = values.dtype.kind in "iu"
    hist_kwargs.setdefault("align", "left" if discrete else "mid")

    n, bins, _ = ax.hist(values.flatten(), bins=bins, **hist_kwargs)

    ticks = bins[:-1] if discrete else (bins[1:] + bins[:-1]) / 2

    if rotated:
        ax.set_yticks(ticks)
    elif not is_circular:
        ax.set_xticks(ticks)

    if is_circular:
        ax.set_ylim(0, 1.5 * n.max())
        ax.set_yticklabels([])

    if hist_kwargs.get("label") is not None:
        ax.legend()

    return ax