arviz 0.23.3__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (185)
  1. arviz/__init__.py +52 -367
  2. arviz-1.0.0rc0.dist-info/METADATA +182 -0
  3. arviz-1.0.0rc0.dist-info/RECORD +5 -0
  4. {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
  5. {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
  6. arviz/data/__init__.py +0 -55
  7. arviz/data/base.py +0 -596
  8. arviz/data/converters.py +0 -203
  9. arviz/data/datasets.py +0 -161
  10. arviz/data/example_data/code/radon/radon.json +0 -326
  11. arviz/data/example_data/data/centered_eight.nc +0 -0
  12. arviz/data/example_data/data/non_centered_eight.nc +0 -0
  13. arviz/data/example_data/data_local.json +0 -12
  14. arviz/data/example_data/data_remote.json +0 -58
  15. arviz/data/inference_data.py +0 -2386
  16. arviz/data/io_beanmachine.py +0 -112
  17. arviz/data/io_cmdstan.py +0 -1036
  18. arviz/data/io_cmdstanpy.py +0 -1233
  19. arviz/data/io_datatree.py +0 -23
  20. arviz/data/io_dict.py +0 -462
  21. arviz/data/io_emcee.py +0 -317
  22. arviz/data/io_json.py +0 -54
  23. arviz/data/io_netcdf.py +0 -68
  24. arviz/data/io_numpyro.py +0 -497
  25. arviz/data/io_pyjags.py +0 -378
  26. arviz/data/io_pyro.py +0 -333
  27. arviz/data/io_pystan.py +0 -1095
  28. arviz/data/io_zarr.py +0 -46
  29. arviz/data/utils.py +0 -139
  30. arviz/labels.py +0 -210
  31. arviz/plots/__init__.py +0 -61
  32. arviz/plots/autocorrplot.py +0 -171
  33. arviz/plots/backends/__init__.py +0 -223
  34. arviz/plots/backends/bokeh/__init__.py +0 -166
  35. arviz/plots/backends/bokeh/autocorrplot.py +0 -101
  36. arviz/plots/backends/bokeh/bfplot.py +0 -23
  37. arviz/plots/backends/bokeh/bpvplot.py +0 -193
  38. arviz/plots/backends/bokeh/compareplot.py +0 -167
  39. arviz/plots/backends/bokeh/densityplot.py +0 -239
  40. arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
  41. arviz/plots/backends/bokeh/distplot.py +0 -183
  42. arviz/plots/backends/bokeh/dotplot.py +0 -113
  43. arviz/plots/backends/bokeh/ecdfplot.py +0 -73
  44. arviz/plots/backends/bokeh/elpdplot.py +0 -203
  45. arviz/plots/backends/bokeh/energyplot.py +0 -155
  46. arviz/plots/backends/bokeh/essplot.py +0 -176
  47. arviz/plots/backends/bokeh/forestplot.py +0 -772
  48. arviz/plots/backends/bokeh/hdiplot.py +0 -54
  49. arviz/plots/backends/bokeh/kdeplot.py +0 -268
  50. arviz/plots/backends/bokeh/khatplot.py +0 -163
  51. arviz/plots/backends/bokeh/lmplot.py +0 -185
  52. arviz/plots/backends/bokeh/loopitplot.py +0 -211
  53. arviz/plots/backends/bokeh/mcseplot.py +0 -184
  54. arviz/plots/backends/bokeh/pairplot.py +0 -328
  55. arviz/plots/backends/bokeh/parallelplot.py +0 -81
  56. arviz/plots/backends/bokeh/posteriorplot.py +0 -324
  57. arviz/plots/backends/bokeh/ppcplot.py +0 -379
  58. arviz/plots/backends/bokeh/rankplot.py +0 -149
  59. arviz/plots/backends/bokeh/separationplot.py +0 -107
  60. arviz/plots/backends/bokeh/traceplot.py +0 -436
  61. arviz/plots/backends/bokeh/violinplot.py +0 -164
  62. arviz/plots/backends/matplotlib/__init__.py +0 -124
  63. arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
  64. arviz/plots/backends/matplotlib/bfplot.py +0 -78
  65. arviz/plots/backends/matplotlib/bpvplot.py +0 -177
  66. arviz/plots/backends/matplotlib/compareplot.py +0 -135
  67. arviz/plots/backends/matplotlib/densityplot.py +0 -194
  68. arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
  69. arviz/plots/backends/matplotlib/distplot.py +0 -178
  70. arviz/plots/backends/matplotlib/dotplot.py +0 -116
  71. arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
  72. arviz/plots/backends/matplotlib/elpdplot.py +0 -189
  73. arviz/plots/backends/matplotlib/energyplot.py +0 -113
  74. arviz/plots/backends/matplotlib/essplot.py +0 -180
  75. arviz/plots/backends/matplotlib/forestplot.py +0 -656
  76. arviz/plots/backends/matplotlib/hdiplot.py +0 -48
  77. arviz/plots/backends/matplotlib/kdeplot.py +0 -177
  78. arviz/plots/backends/matplotlib/khatplot.py +0 -241
  79. arviz/plots/backends/matplotlib/lmplot.py +0 -149
  80. arviz/plots/backends/matplotlib/loopitplot.py +0 -144
  81. arviz/plots/backends/matplotlib/mcseplot.py +0 -161
  82. arviz/plots/backends/matplotlib/pairplot.py +0 -355
  83. arviz/plots/backends/matplotlib/parallelplot.py +0 -58
  84. arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
  85. arviz/plots/backends/matplotlib/ppcplot.py +0 -478
  86. arviz/plots/backends/matplotlib/rankplot.py +0 -119
  87. arviz/plots/backends/matplotlib/separationplot.py +0 -97
  88. arviz/plots/backends/matplotlib/traceplot.py +0 -526
  89. arviz/plots/backends/matplotlib/tsplot.py +0 -121
  90. arviz/plots/backends/matplotlib/violinplot.py +0 -148
  91. arviz/plots/bfplot.py +0 -128
  92. arviz/plots/bpvplot.py +0 -308
  93. arviz/plots/compareplot.py +0 -177
  94. arviz/plots/densityplot.py +0 -284
  95. arviz/plots/distcomparisonplot.py +0 -197
  96. arviz/plots/distplot.py +0 -233
  97. arviz/plots/dotplot.py +0 -233
  98. arviz/plots/ecdfplot.py +0 -372
  99. arviz/plots/elpdplot.py +0 -174
  100. arviz/plots/energyplot.py +0 -147
  101. arviz/plots/essplot.py +0 -319
  102. arviz/plots/forestplot.py +0 -304
  103. arviz/plots/hdiplot.py +0 -211
  104. arviz/plots/kdeplot.py +0 -357
  105. arviz/plots/khatplot.py +0 -236
  106. arviz/plots/lmplot.py +0 -380
  107. arviz/plots/loopitplot.py +0 -224
  108. arviz/plots/mcseplot.py +0 -194
  109. arviz/plots/pairplot.py +0 -281
  110. arviz/plots/parallelplot.py +0 -204
  111. arviz/plots/plot_utils.py +0 -599
  112. arviz/plots/posteriorplot.py +0 -298
  113. arviz/plots/ppcplot.py +0 -369
  114. arviz/plots/rankplot.py +0 -232
  115. arviz/plots/separationplot.py +0 -167
  116. arviz/plots/styles/arviz-bluish.mplstyle +0 -1
  117. arviz/plots/styles/arviz-brownish.mplstyle +0 -1
  118. arviz/plots/styles/arviz-colors.mplstyle +0 -2
  119. arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
  120. arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
  121. arviz/plots/styles/arviz-doc.mplstyle +0 -88
  122. arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
  123. arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
  124. arviz/plots/styles/arviz-greenish.mplstyle +0 -1
  125. arviz/plots/styles/arviz-orangish.mplstyle +0 -1
  126. arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
  127. arviz/plots/styles/arviz-purplish.mplstyle +0 -1
  128. arviz/plots/styles/arviz-redish.mplstyle +0 -1
  129. arviz/plots/styles/arviz-royish.mplstyle +0 -1
  130. arviz/plots/styles/arviz-viridish.mplstyle +0 -1
  131. arviz/plots/styles/arviz-white.mplstyle +0 -40
  132. arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
  133. arviz/plots/traceplot.py +0 -273
  134. arviz/plots/tsplot.py +0 -440
  135. arviz/plots/violinplot.py +0 -192
  136. arviz/preview.py +0 -58
  137. arviz/py.typed +0 -0
  138. arviz/rcparams.py +0 -606
  139. arviz/sel_utils.py +0 -223
  140. arviz/static/css/style.css +0 -340
  141. arviz/static/html/icons-svg-inline.html +0 -15
  142. arviz/stats/__init__.py +0 -37
  143. arviz/stats/density_utils.py +0 -1013
  144. arviz/stats/diagnostics.py +0 -1013
  145. arviz/stats/ecdf_utils.py +0 -324
  146. arviz/stats/stats.py +0 -2422
  147. arviz/stats/stats_refitting.py +0 -119
  148. arviz/stats/stats_utils.py +0 -609
  149. arviz/tests/__init__.py +0 -1
  150. arviz/tests/base_tests/__init__.py +0 -1
  151. arviz/tests/base_tests/test_data.py +0 -1679
  152. arviz/tests/base_tests/test_data_zarr.py +0 -143
  153. arviz/tests/base_tests/test_diagnostics.py +0 -511
  154. arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
  155. arviz/tests/base_tests/test_helpers.py +0 -18
  156. arviz/tests/base_tests/test_labels.py +0 -69
  157. arviz/tests/base_tests/test_plot_utils.py +0 -342
  158. arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
  159. arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
  160. arviz/tests/base_tests/test_rcparams.py +0 -317
  161. arviz/tests/base_tests/test_stats.py +0 -925
  162. arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
  163. arviz/tests/base_tests/test_stats_numba.py +0 -45
  164. arviz/tests/base_tests/test_stats_utils.py +0 -384
  165. arviz/tests/base_tests/test_utils.py +0 -376
  166. arviz/tests/base_tests/test_utils_numba.py +0 -87
  167. arviz/tests/conftest.py +0 -46
  168. arviz/tests/external_tests/__init__.py +0 -1
  169. arviz/tests/external_tests/test_data_beanmachine.py +0 -78
  170. arviz/tests/external_tests/test_data_cmdstan.py +0 -398
  171. arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
  172. arviz/tests/external_tests/test_data_emcee.py +0 -166
  173. arviz/tests/external_tests/test_data_numpyro.py +0 -434
  174. arviz/tests/external_tests/test_data_pyjags.py +0 -119
  175. arviz/tests/external_tests/test_data_pyro.py +0 -260
  176. arviz/tests/external_tests/test_data_pystan.py +0 -307
  177. arviz/tests/helpers.py +0 -677
  178. arviz/utils.py +0 -773
  179. arviz/wrappers/__init__.py +0 -13
  180. arviz/wrappers/base.py +0 -236
  181. arviz/wrappers/wrap_pymc.py +0 -36
  182. arviz/wrappers/wrap_stan.py +0 -148
  183. arviz-0.23.3.dist-info/METADATA +0 -264
  184. arviz-0.23.3.dist-info/RECORD +0 -183
  185. arviz-0.23.3.dist-info/top_level.txt +0 -1
@@ -1,144 +0,0 @@
1
- """Matplotlib loopitplot."""
2
-
3
- import matplotlib.pyplot as plt
4
- import numpy as np
5
- from matplotlib.colors import hsv_to_rgb, rgb_to_hsv, to_hex, to_rgb
6
- from xarray import DataArray
7
-
8
- from ....stats.density_utils import kde
9
- from ...plot_utils import _scale_fig_size
10
- from . import backend_kwarg_defaults, backend_show, create_axes_grid, matplotlib_kwarg_dealiaser
11
-
12
-
13
- def plot_loo_pit(
14
- ax,
15
- figsize,
16
- ecdf,
17
- loo_pit,
18
- loo_pit_ecdf,
19
- unif_ecdf,
20
- p975,
21
- p025,
22
- fill_kwargs,
23
- ecdf_fill,
24
- use_hdi,
25
- x_vals,
26
- hdi_kwargs,
27
- hdi_odds,
28
- n_unif,
29
- unif,
30
- plot_unif_kwargs,
31
- loo_pit_kde,
32
- legend,
33
- labeller,
34
- y_hat,
35
- y,
36
- color,
37
- textsize,
38
- hdi_prob,
39
- plot_kwargs,
40
- backend_kwargs,
41
- show,
42
- ):
43
- """Matplotlib loo pit plot."""
44
- if backend_kwargs is None:
45
- backend_kwargs = {}
46
-
47
- backend_kwargs = {
48
- **backend_kwarg_defaults(),
49
- **backend_kwargs,
50
- }
51
-
52
- (figsize, _, _, xt_labelsize, linewidth, _) = _scale_fig_size(figsize, textsize, 1, 1)
53
- backend_kwargs.setdefault("figsize", figsize)
54
- backend_kwargs["squeeze"] = True
55
-
56
- if ax is None:
57
- _, ax = create_axes_grid(1, backend_kwargs=backend_kwargs)
58
-
59
- plot_kwargs = matplotlib_kwarg_dealiaser(plot_kwargs, "plot")
60
- plot_kwargs["color"] = to_hex(color)
61
- plot_kwargs.setdefault("linewidth", linewidth * 1.4)
62
- if isinstance(y, str):
63
- xlabel = y
64
- elif isinstance(y, DataArray) and y.name is not None:
65
- xlabel = y.name
66
- elif isinstance(y_hat, str):
67
- xlabel = y_hat
68
- elif isinstance(y_hat, DataArray) and y_hat.name is not None:
69
- xlabel = y_hat.name
70
- else:
71
- xlabel = ""
72
- label = "LOO-PIT ECDF" if ecdf else "LOO-PIT"
73
- xlabel = labeller.var_name_to_str(y)
74
-
75
- plot_kwargs.setdefault("label", label)
76
- plot_kwargs.setdefault("zorder", 5)
77
-
78
- plot_unif_kwargs = matplotlib_kwarg_dealiaser(plot_unif_kwargs, "plot")
79
- light_color = rgb_to_hsv(to_rgb(plot_kwargs.get("color")))
80
- light_color[1] /= 2 # pylint: disable=unsupported-assignment-operation
81
- light_color[2] += (1 - light_color[2]) / 2 # pylint: disable=unsupported-assignment-operation
82
- plot_unif_kwargs.setdefault("color", to_hex(hsv_to_rgb(light_color)))
83
- plot_unif_kwargs.setdefault("alpha", 0.5)
84
- plot_unif_kwargs.setdefault("linewidth", 0.6 * linewidth)
85
-
86
- if ecdf:
87
- n_data_points = loo_pit.size
88
- plot_kwargs.setdefault("drawstyle", "steps-mid" if n_data_points < 100 else "default")
89
- plot_unif_kwargs.setdefault("drawstyle", "steps-mid" if n_data_points < 100 else "default")
90
-
91
- if ecdf_fill:
92
- if fill_kwargs is None:
93
- fill_kwargs = {}
94
- fill_kwargs.setdefault("color", to_hex(hsv_to_rgb(light_color)))
95
- fill_kwargs.setdefault("alpha", 0.5)
96
- fill_kwargs.setdefault(
97
- "step", "mid" if plot_kwargs["drawstyle"] == "steps-mid" else None
98
- )
99
- fill_kwargs.setdefault("label", f"{hdi_prob * 100:.3g}% credible interval")
100
- elif use_hdi:
101
- if hdi_kwargs is None:
102
- hdi_kwargs = {}
103
- hdi_kwargs.setdefault("color", to_hex(hsv_to_rgb(light_color)))
104
- hdi_kwargs.setdefault("alpha", 0.35)
105
- hdi_kwargs.setdefault("label", "Uniform HDI")
106
-
107
- if ecdf:
108
- ax.plot(
109
- np.hstack((0, loo_pit, 1)), np.hstack((0, loo_pit - loo_pit_ecdf, 0)), **plot_kwargs
110
- )
111
-
112
- if ecdf_fill:
113
- ax.fill_between(unif_ecdf, p975 - unif_ecdf, p025 - unif_ecdf, **fill_kwargs)
114
- else:
115
- ax.plot(unif_ecdf, p975 - unif_ecdf, unif_ecdf, p025 - unif_ecdf, **plot_unif_kwargs)
116
- else:
117
- x_ss = np.empty((n_unif, len(loo_pit_kde)))
118
- u_dens = np.empty((n_unif, len(loo_pit_kde)))
119
- if use_hdi:
120
- ax.axhspan(*hdi_odds, **hdi_kwargs)
121
-
122
- # Adds horizontal reference line
123
- ax.axhline(1, color="w", zorder=1)
124
- else:
125
- for idx in range(n_unif):
126
- x_s, unif_density = kde(unif[idx, :])
127
- x_ss[idx] = x_s
128
- u_dens[idx] = unif_density
129
- ax.plot(x_ss.T, u_dens.T, **plot_unif_kwargs)
130
- ax.plot(x_vals, loo_pit_kde, **plot_kwargs)
131
- ax.set_xlim(0, 1)
132
- ax.set_ylim(0, None)
133
- ax.set_xlabel(xlabel)
134
- ax.tick_params(labelsize=xt_labelsize)
135
- if legend:
136
- if not (use_hdi or (ecdf and ecdf_fill)):
137
- label = f"{hdi_prob * 100:.3g}% credible interval" if ecdf else "Uniform"
138
- ax.plot([], label=label, **plot_unif_kwargs)
139
- ax.legend()
140
-
141
- if backend_show(show):
142
- plt.show()
143
-
144
- return ax
@@ -1,161 +0,0 @@
1
- """Matplotlib mcseplot."""
2
-
3
- import matplotlib.pyplot as plt
4
- import numpy as np
5
- from scipy.stats import rankdata
6
-
7
- from ....stats.stats_utils import quantile as _quantile
8
- from ...plot_utils import _scale_fig_size
9
- from . import backend_kwarg_defaults, backend_show, create_axes_grid, matplotlib_kwarg_dealiaser
10
-
11
-
12
- def plot_mcse(
13
- ax,
14
- plotters,
15
- length_plotters,
16
- rows,
17
- cols,
18
- figsize,
19
- errorbar,
20
- rug,
21
- data,
22
- probs,
23
- kwargs,
24
- extra_methods,
25
- mean_mcse,
26
- sd_mcse,
27
- textsize,
28
- labeller,
29
- text_kwargs,
30
- rug_kwargs,
31
- extra_kwargs,
32
- idata,
33
- rug_kind,
34
- backend_kwargs,
35
- show,
36
- ):
37
- """Matplotlib mcseplot."""
38
- if backend_kwargs is None:
39
- backend_kwargs = {}
40
-
41
- backend_kwargs = {
42
- **backend_kwarg_defaults(),
43
- **backend_kwargs,
44
- }
45
-
46
- (figsize, ax_labelsize, titlesize, xt_labelsize, _linewidth, _markersize) = _scale_fig_size(
47
- figsize, textsize, rows, cols
48
- )
49
- backend_kwargs.setdefault("figsize", figsize)
50
- backend_kwargs["squeeze"] = True
51
-
52
- kwargs = matplotlib_kwarg_dealiaser(kwargs, "plot")
53
- kwargs.setdefault("linestyle", "none")
54
- kwargs.setdefault("linewidth", _linewidth)
55
- kwargs.setdefault("markersize", _markersize)
56
- kwargs.setdefault("marker", "_" if errorbar else "o")
57
- kwargs.setdefault("zorder", 3)
58
-
59
- extra_kwargs = matplotlib_kwarg_dealiaser(extra_kwargs, "plot")
60
- extra_kwargs.setdefault("linestyle", "-")
61
- extra_kwargs.setdefault("linewidth", _linewidth / 2)
62
- extra_kwargs.setdefault("color", "k")
63
- extra_kwargs.setdefault("alpha", 0.5)
64
- text_x = None
65
- text_va = None
66
- if extra_methods:
67
- text_kwargs = matplotlib_kwarg_dealiaser(text_kwargs, "text")
68
- text_x = text_kwargs.pop("x", 1)
69
- text_kwargs.setdefault("fontsize", xt_labelsize * 0.7)
70
- text_kwargs.setdefault("alpha", extra_kwargs["alpha"])
71
- text_kwargs.setdefault("color", extra_kwargs["color"])
72
- text_kwargs.setdefault("horizontalalignment", "right")
73
- text_va = text_kwargs.pop("verticalalignment", None)
74
-
75
- if ax is None:
76
- _, ax = create_axes_grid(
77
- length_plotters,
78
- rows,
79
- cols,
80
- backend_kwargs=backend_kwargs,
81
- )
82
-
83
- for (var_name, selection, isel, x), ax_ in zip(plotters, np.ravel(ax)):
84
- if errorbar or rug:
85
- values = data[var_name].sel(**selection).values.flatten()
86
- if errorbar:
87
- quantile_values = _quantile(values, probs)
88
- ax_.errorbar(probs, quantile_values, yerr=x, **kwargs)
89
- else:
90
- ax_.plot(probs, x, label="quantile", **kwargs)
91
- if extra_methods:
92
- mean_mcse_i = mean_mcse[var_name].sel(**selection).values.item()
93
- sd_mcse_i = sd_mcse[var_name].sel(**selection).values.item()
94
- ax_.axhline(mean_mcse_i, **extra_kwargs)
95
- ax_.annotate(
96
- "mean",
97
- (text_x, mean_mcse_i),
98
- va=(
99
- text_va
100
- if text_va is not None
101
- else "bottom" if mean_mcse_i > sd_mcse_i else "top"
102
- ),
103
- **text_kwargs,
104
- )
105
- ax_.axhline(sd_mcse_i, **extra_kwargs)
106
- ax_.annotate(
107
- "sd",
108
- (text_x, sd_mcse_i),
109
- va=(
110
- text_va
111
- if text_va is not None
112
- else "bottom" if sd_mcse_i >= mean_mcse_i else "top"
113
- ),
114
- **text_kwargs,
115
- )
116
- if rug:
117
- rug_kwargs = matplotlib_kwarg_dealiaser(rug_kwargs, "plot")
118
- if not hasattr(idata, "sample_stats"):
119
- raise ValueError("InferenceData object must contain sample_stats for rug plot")
120
- if not hasattr(idata.sample_stats, rug_kind):
121
- raise ValueError(f"InferenceData does not contain {rug_kind} data")
122
- rug_kwargs.setdefault("marker", "|")
123
- rug_kwargs.setdefault("linestyle", rug_kwargs.pop("ls", "None"))
124
- rug_kwargs.setdefault("color", rug_kwargs.pop("c", kwargs.get("color", "C0")))
125
- rug_kwargs.setdefault("space", 0.1)
126
- rug_kwargs.setdefault("markersize", rug_kwargs.pop("ms", 2 * _markersize))
127
-
128
- mask = idata.sample_stats[rug_kind].values.flatten()
129
- values = rankdata(values, method="average")[mask]
130
- y_min, y_max = ax_.get_ylim()
131
- y_min = y_min if errorbar else 0
132
- rug_space = (y_max - y_min) * rug_kwargs.pop("space")
133
- rug_x, rug_y = values / (len(mask) - 1), np.full_like(values, y_min) - rug_space
134
- ax_.plot(rug_x, rug_y, **rug_kwargs)
135
- ax_.axhline(y_min, color="k", linewidth=_linewidth, alpha=0.7)
136
-
137
- ax_.set_title(
138
- labeller.make_label_vert(var_name, selection, isel), fontsize=titlesize, wrap=True
139
- )
140
- ax_.tick_params(labelsize=xt_labelsize)
141
- ax_.set_xlabel("Quantile", fontsize=ax_labelsize, wrap=True)
142
- ax_.set_ylabel(
143
- r"Value $\pm$ MCSE for quantiles" if errorbar else "MCSE for quantiles",
144
- fontsize=ax_labelsize,
145
- wrap=True,
146
- )
147
- ax_.set_xlim(0, 1)
148
- if rug:
149
- ax_.yaxis.get_major_locator().set_params(nbins="auto", steps=[1, 2, 5, 10])
150
- y_min, y_max = ax_.get_ylim()
151
- yticks = ax_.get_yticks()
152
- yticks = yticks[(yticks >= y_min) & (yticks < y_max)]
153
- ax_.set_yticks(yticks)
154
- ax_.set_yticklabels([f"{ytick:.3g}" for ytick in yticks])
155
- elif not errorbar:
156
- ax_.set_ylim(bottom=0)
157
-
158
- if backend_show(show):
159
- plt.show()
160
-
161
- return ax
@@ -1,355 +0,0 @@
1
- """Matplotlib pairplot."""
2
-
3
- import warnings
4
- from copy import deepcopy
5
-
6
- import matplotlib.pyplot as plt
7
- import numpy as np
8
- from mpl_toolkits.axes_grid1 import make_axes_locatable
9
-
10
- from ....rcparams import rcParams
11
- from ...distplot import plot_dist
12
- from ...kdeplot import plot_kde
13
- from ...plot_utils import _scale_fig_size, calculate_point_estimate, _init_kwargs_dict
14
- from . import backend_kwarg_defaults, backend_show, matplotlib_kwarg_dealiaser
15
-
16
-
17
- def plot_pair(
18
- ax,
19
- plotters,
20
- numvars,
21
- figsize,
22
- textsize,
23
- kind,
24
- scatter_kwargs,
25
- kde_kwargs,
26
- hexbin_kwargs,
27
- gridsize,
28
- colorbar,
29
- divergences,
30
- diverging_mask,
31
- divergences_kwargs,
32
- flat_var_names,
33
- flat_ref_slices,
34
- flat_var_labels,
35
- backend_kwargs,
36
- marginal_kwargs,
37
- show,
38
- marginals,
39
- point_estimate,
40
- point_estimate_kwargs,
41
- point_estimate_marker_kwargs,
42
- reference_values,
43
- reference_values_kwargs,
44
- ):
45
- """Matplotlib pairplot."""
46
- backend_kwargs = _init_kwargs_dict(backend_kwargs)
47
- backend_kwargs = {
48
- **backend_kwarg_defaults(),
49
- **backend_kwargs,
50
- }
51
-
52
- scatter_kwargs = matplotlib_kwarg_dealiaser(scatter_kwargs, "scatter")
53
-
54
- scatter_kwargs.setdefault("marker", ".")
55
- scatter_kwargs.setdefault("lw", 0)
56
- # Sets the default zorder higher than zorder of grid, which is 0.5
57
- scatter_kwargs.setdefault("zorder", 0.6)
58
-
59
- kde_kwargs = _init_kwargs_dict(kde_kwargs)
60
-
61
- hexbin_kwargs = matplotlib_kwarg_dealiaser(hexbin_kwargs, "hexbin")
62
- hexbin_kwargs.setdefault("mincnt", 1)
63
-
64
- divergences_kwargs = matplotlib_kwarg_dealiaser(divergences_kwargs, "plot")
65
- divergences_kwargs.setdefault("marker", "o")
66
- divergences_kwargs.setdefault("markeredgecolor", "k")
67
- divergences_kwargs.setdefault("color", "C1")
68
- divergences_kwargs.setdefault("lw", 0)
69
-
70
- marginal_kwargs = _init_kwargs_dict(marginal_kwargs)
71
-
72
- point_estimate_kwargs = matplotlib_kwarg_dealiaser(point_estimate_kwargs, "fill_between")
73
- point_estimate_kwargs.setdefault("color", "k")
74
-
75
- if kind != "kde":
76
- kde_kwargs.setdefault("contourf_kwargs", {})
77
- kde_kwargs["contourf_kwargs"].setdefault("alpha", 0)
78
- kde_kwargs.setdefault("contour_kwargs", {})
79
- kde_kwargs["contour_kwargs"].setdefault("colors", "k")
80
-
81
- if reference_values:
82
- difference = set(flat_var_names).difference(set(reference_values.keys()))
83
-
84
- if difference:
85
- warnings.warn(
86
- "Argument reference_values does not include reference value for: {}".format(
87
- ", ".join(difference)
88
- ),
89
- UserWarning,
90
- )
91
-
92
- reference_values_kwargs = matplotlib_kwarg_dealiaser(reference_values_kwargs, "plot")
93
-
94
- reference_values_kwargs.setdefault("color", "C2")
95
- reference_values_kwargs.setdefault("markeredgecolor", "k")
96
- reference_values_kwargs.setdefault("marker", "o")
97
-
98
- point_estimate_marker_kwargs = matplotlib_kwarg_dealiaser(
99
- point_estimate_marker_kwargs, "scatter"
100
- )
101
- point_estimate_marker_kwargs.setdefault("marker", "s")
102
- point_estimate_marker_kwargs.setdefault("color", "k")
103
-
104
- # pylint: disable=too-many-nested-blocks
105
- if numvars == 2:
106
- (figsize, ax_labelsize, _, xt_labelsize, linewidth, markersize) = _scale_fig_size(
107
- figsize, textsize, numvars - 1, numvars - 1
108
- )
109
- backend_kwargs.setdefault("figsize", figsize)
110
-
111
- marginal_kwargs.setdefault("plot_kwargs", {})
112
- marginal_kwargs["plot_kwargs"].setdefault("linewidth", linewidth)
113
-
114
- point_estimate_marker_kwargs.setdefault("s", markersize + 50)
115
-
116
- # Flatten data
117
- x = plotters[0][-1].flatten()
118
- y = plotters[1][-1].flatten()
119
- if ax is None:
120
- if marginals:
121
- # Instantiate figure and grid
122
- widths = [2, 2, 2, 1]
123
- heights = [1.4, 2, 2, 2]
124
- fig = plt.figure(**backend_kwargs)
125
- grid = plt.GridSpec(
126
- 4,
127
- 4,
128
- hspace=0.1,
129
- wspace=0.1,
130
- figure=fig,
131
- width_ratios=widths,
132
- height_ratios=heights,
133
- )
134
- # Set up main plot
135
- ax = fig.add_subplot(grid[1:, :-1])
136
- # Set up top KDE
137
- ax_hist_x = fig.add_subplot(grid[0, :-1], sharex=ax)
138
- ax_hist_x.set_yticks([])
139
- # Set up right KDE
140
- ax_hist_y = fig.add_subplot(grid[1:, -1], sharey=ax)
141
- ax_hist_y.set_xticks([])
142
- ax_return = np.array([[ax_hist_x, None], [ax, ax_hist_y]])
143
-
144
- for val, ax_, rotate in ((x, ax_hist_x, False), (y, ax_hist_y, True)):
145
- plot_dist(val, textsize=xt_labelsize, rotated=rotate, ax=ax_, **marginal_kwargs)
146
-
147
- # Personalize axes
148
- ax_hist_x.tick_params(labelleft=False, labelbottom=False)
149
- ax_hist_y.tick_params(labelleft=False, labelbottom=False)
150
- else:
151
- fig, ax = plt.subplots(numvars - 1, numvars - 1, **backend_kwargs)
152
- else:
153
- if marginals:
154
- assert ax.shape == (numvars, numvars)
155
- if ax[0, 1] is not None and ax[0, 1].get_figure() is not None:
156
- ax[0, 1].remove()
157
- ax_return = ax
158
- ax_hist_x = ax[0, 0]
159
- ax_hist_y = ax[1, 1]
160
- ax = ax[1, 0]
161
- for val, ax_, rotate in ((x, ax_hist_x, False), (y, ax_hist_y, True)):
162
- plot_dist(val, textsize=xt_labelsize, rotated=rotate, ax=ax_, **marginal_kwargs)
163
- else:
164
- ax = np.atleast_2d(ax)[0, 0]
165
-
166
- if "scatter" in kind:
167
- ax.scatter(x, y, **scatter_kwargs)
168
- if "kde" in kind:
169
- plot_kde(x, y, ax=ax, **kde_kwargs)
170
- if "hexbin" in kind:
171
- hexbin = ax.hexbin(
172
- x,
173
- y,
174
- gridsize=gridsize,
175
- **hexbin_kwargs,
176
- )
177
- ax.grid(False)
178
-
179
- if kind == "hexbin" and colorbar:
180
- cbar = ax.figure.colorbar(hexbin, ticks=[hexbin.norm.vmin, hexbin.norm.vmax], ax=ax)
181
- cbar.ax.set_yticklabels(["low", "high"], fontsize=ax_labelsize)
182
-
183
- if divergences:
184
- ax.plot(
185
- x[diverging_mask],
186
- y[diverging_mask],
187
- **divergences_kwargs,
188
- )
189
-
190
- if point_estimate:
191
- pe_x = calculate_point_estimate(point_estimate, x)
192
- pe_y = calculate_point_estimate(point_estimate, y)
193
- if marginals:
194
- ax_hist_x.axvline(pe_x, **point_estimate_kwargs)
195
- ax_hist_y.axhline(pe_y, **point_estimate_kwargs)
196
-
197
- ax.axvline(pe_x, **point_estimate_kwargs)
198
- ax.axhline(pe_y, **point_estimate_kwargs)
199
-
200
- ax.scatter(pe_x, pe_y, **point_estimate_marker_kwargs)
201
-
202
- if reference_values:
203
- ax.plot(
204
- np.array(reference_values[flat_var_names[0]])[flat_ref_slices[0]],
205
- np.array(reference_values[flat_var_names[1]])[flat_ref_slices[1]],
206
- **reference_values_kwargs,
207
- )
208
- ax.set_xlabel(f"{flat_var_labels[0]}", fontsize=ax_labelsize, wrap=True)
209
- ax.set_ylabel(f"{flat_var_labels[1]}", fontsize=ax_labelsize, wrap=True)
210
- ax.tick_params(labelsize=xt_labelsize)
211
-
212
- else:
213
- not_marginals = int(not marginals)
214
- num_subplot_cols = numvars - not_marginals
215
- max_plots = (
216
- num_subplot_cols**2
217
- if rcParams["plot.max_subplots"] is None
218
- else rcParams["plot.max_subplots"]
219
- )
220
- cols_to_plot = np.sum(np.arange(1, num_subplot_cols + 1).cumsum() <= max_plots)
221
- if cols_to_plot < num_subplot_cols:
222
- vars_to_plot = cols_to_plot
223
- warnings.warn(
224
- "rcParams['plot.max_subplots'] ({max_plots}) is smaller than the number "
225
- "of resulting pair plots with these variables, generating only a "
226
- "{side}x{side} grid".format(max_plots=max_plots, side=vars_to_plot),
227
- UserWarning,
228
- )
229
- else:
230
- vars_to_plot = numvars - not_marginals
231
-
232
- (figsize, ax_labelsize, _, xt_labelsize, _, markersize) = _scale_fig_size(
233
- figsize, textsize, vars_to_plot, vars_to_plot
234
- )
235
- backend_kwargs.setdefault("figsize", figsize)
236
- point_estimate_marker_kwargs.setdefault("s", markersize + 50)
237
-
238
- if ax is None:
239
- if backend_kwargs.pop("sharex", None) is not None:
240
- warnings.warn(
241
- "'sharex' keyword is ignored. For non-standard sharing, provide 'ax'.",
242
- UserWarning,
243
- )
244
- if backend_kwargs.pop("sharey", None) is not None:
245
- warnings.warn(
246
- "'sharey' keyword is ignored. For non-standard sharing, provide 'ax'.",
247
- UserWarning,
248
- )
249
- backend_kwargs["sharex"] = "col"
250
- if not_marginals:
251
- backend_kwargs["sharey"] = "row"
252
- fig, ax = plt.subplots(
253
- vars_to_plot,
254
- vars_to_plot,
255
- **backend_kwargs,
256
- )
257
- if backend_kwargs.get("sharey") is None:
258
- for j in range(0, vars_to_plot):
259
- for i in range(0, j):
260
- ax[j, i].axes.sharey(ax[j, 0])
261
-
262
- hexbin_values = []
263
- for i in range(0, vars_to_plot):
264
- var1 = plotters[i][-1].flatten()
265
-
266
- for j in range(0, vars_to_plot):
267
- var2 = plotters[j + not_marginals][-1].flatten()
268
- if i > j:
269
- if ax[j, i].get_figure() is not None:
270
- ax[j, i].remove()
271
- continue
272
-
273
- elif i == j and marginals:
274
- loc = "right"
275
- plot_dist(var1, ax=ax[i, j], **marginal_kwargs)
276
-
277
- else:
278
- if i == j:
279
- loc = "left"
280
-
281
- if "scatter" in kind:
282
- ax[j, i].scatter(var1, var2, **scatter_kwargs)
283
-
284
- if "kde" in kind:
285
- plot_kde(
286
- var1,
287
- var2,
288
- ax=ax[j, i],
289
- **deepcopy(kde_kwargs),
290
- )
291
-
292
- if "hexbin" in kind:
293
- ax[j, i].grid(False)
294
- hexbin = ax[j, i].hexbin(var1, var2, gridsize=gridsize, **hexbin_kwargs)
295
-
296
- if divergences:
297
- ax[j, i].plot(
298
- var1[diverging_mask], var2[diverging_mask], **divergences_kwargs
299
- )
300
-
301
- if kind == "hexbin" and colorbar:
302
- hexbin_values.append(hexbin.norm.vmin)
303
- hexbin_values.append(hexbin.norm.vmax)
304
- divider = make_axes_locatable(ax[-1, -1])
305
- cax = divider.append_axes(loc, size="7%", pad="5%")
306
- cbar = fig.colorbar(
307
- hexbin, ticks=[hexbin.norm.vmin, hexbin.norm.vmax], cax=cax
308
- )
309
- cbar.ax.set_yticklabels(["low", "high"], fontsize=ax_labelsize)
310
-
311
- if point_estimate:
312
- pe_x = calculate_point_estimate(point_estimate, var1)
313
- pe_y = calculate_point_estimate(point_estimate, var2)
314
- ax[j, i].axvline(pe_x, **point_estimate_kwargs)
315
- ax[j, i].axhline(pe_y, **point_estimate_kwargs)
316
-
317
- if marginals:
318
- ax[j - 1, i].axvline(pe_x, **point_estimate_kwargs)
319
- pe_last = calculate_point_estimate(point_estimate, plotters[-1][-1])
320
- ax[-1, -1].axvline(pe_last, **point_estimate_kwargs)
321
-
322
- ax[j, i].scatter(pe_x, pe_y, **point_estimate_marker_kwargs)
323
-
324
- if reference_values:
325
- x_name = flat_var_names[i]
326
- y_name = flat_var_names[j + not_marginals]
327
- if (x_name not in difference) and (y_name not in difference):
328
- ax[j, i].plot(
329
- np.array(reference_values[x_name])[flat_ref_slices[i]],
330
- np.array(reference_values[y_name])[
331
- flat_ref_slices[j + not_marginals]
332
- ],
333
- **reference_values_kwargs,
334
- )
335
-
336
- if j != vars_to_plot - 1:
337
- plt.setp(ax[j, i].get_xticklabels(), visible=False)
338
- else:
339
- ax[j, i].set_xlabel(f"{flat_var_labels[i]}", fontsize=ax_labelsize, wrap=True)
340
- if i != 0:
341
- plt.setp(ax[j, i].get_yticklabels(), visible=False)
342
- else:
343
- ax[j, i].set_ylabel(
344
- f"{flat_var_labels[j + not_marginals]}",
345
- fontsize=ax_labelsize,
346
- wrap=True,
347
- )
348
- ax[j, i].tick_params(labelsize=xt_labelsize)
349
-
350
- if backend_show(show):
351
- plt.show()
352
-
353
- if marginals and numvars == 2:
354
- return ax_return
355
- return ax