arviz 0.23.1__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arviz/__init__.py +52 -357
- arviz-1.0.0rc0.dist-info/METADATA +182 -0
- arviz-1.0.0rc0.dist-info/RECORD +5 -0
- {arviz-0.23.1.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
- {arviz-0.23.1.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
- arviz/data/__init__.py +0 -55
- arviz/data/base.py +0 -596
- arviz/data/converters.py +0 -203
- arviz/data/datasets.py +0 -161
- arviz/data/example_data/code/radon/radon.json +0 -326
- arviz/data/example_data/data/centered_eight.nc +0 -0
- arviz/data/example_data/data/non_centered_eight.nc +0 -0
- arviz/data/example_data/data_local.json +0 -12
- arviz/data/example_data/data_remote.json +0 -58
- arviz/data/inference_data.py +0 -2386
- arviz/data/io_beanmachine.py +0 -112
- arviz/data/io_cmdstan.py +0 -1036
- arviz/data/io_cmdstanpy.py +0 -1233
- arviz/data/io_datatree.py +0 -23
- arviz/data/io_dict.py +0 -462
- arviz/data/io_emcee.py +0 -317
- arviz/data/io_json.py +0 -54
- arviz/data/io_netcdf.py +0 -68
- arviz/data/io_numpyro.py +0 -497
- arviz/data/io_pyjags.py +0 -378
- arviz/data/io_pyro.py +0 -333
- arviz/data/io_pystan.py +0 -1095
- arviz/data/io_zarr.py +0 -46
- arviz/data/utils.py +0 -139
- arviz/labels.py +0 -210
- arviz/plots/__init__.py +0 -61
- arviz/plots/autocorrplot.py +0 -171
- arviz/plots/backends/__init__.py +0 -223
- arviz/plots/backends/bokeh/__init__.py +0 -166
- arviz/plots/backends/bokeh/autocorrplot.py +0 -101
- arviz/plots/backends/bokeh/bfplot.py +0 -23
- arviz/plots/backends/bokeh/bpvplot.py +0 -193
- arviz/plots/backends/bokeh/compareplot.py +0 -167
- arviz/plots/backends/bokeh/densityplot.py +0 -239
- arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
- arviz/plots/backends/bokeh/distplot.py +0 -183
- arviz/plots/backends/bokeh/dotplot.py +0 -113
- arviz/plots/backends/bokeh/ecdfplot.py +0 -73
- arviz/plots/backends/bokeh/elpdplot.py +0 -203
- arviz/plots/backends/bokeh/energyplot.py +0 -155
- arviz/plots/backends/bokeh/essplot.py +0 -176
- arviz/plots/backends/bokeh/forestplot.py +0 -772
- arviz/plots/backends/bokeh/hdiplot.py +0 -54
- arviz/plots/backends/bokeh/kdeplot.py +0 -268
- arviz/plots/backends/bokeh/khatplot.py +0 -163
- arviz/plots/backends/bokeh/lmplot.py +0 -185
- arviz/plots/backends/bokeh/loopitplot.py +0 -211
- arviz/plots/backends/bokeh/mcseplot.py +0 -184
- arviz/plots/backends/bokeh/pairplot.py +0 -328
- arviz/plots/backends/bokeh/parallelplot.py +0 -81
- arviz/plots/backends/bokeh/posteriorplot.py +0 -324
- arviz/plots/backends/bokeh/ppcplot.py +0 -379
- arviz/plots/backends/bokeh/rankplot.py +0 -149
- arviz/plots/backends/bokeh/separationplot.py +0 -107
- arviz/plots/backends/bokeh/traceplot.py +0 -436
- arviz/plots/backends/bokeh/violinplot.py +0 -164
- arviz/plots/backends/matplotlib/__init__.py +0 -124
- arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
- arviz/plots/backends/matplotlib/bfplot.py +0 -78
- arviz/plots/backends/matplotlib/bpvplot.py +0 -177
- arviz/plots/backends/matplotlib/compareplot.py +0 -135
- arviz/plots/backends/matplotlib/densityplot.py +0 -194
- arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
- arviz/plots/backends/matplotlib/distplot.py +0 -178
- arviz/plots/backends/matplotlib/dotplot.py +0 -116
- arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
- arviz/plots/backends/matplotlib/elpdplot.py +0 -189
- arviz/plots/backends/matplotlib/energyplot.py +0 -113
- arviz/plots/backends/matplotlib/essplot.py +0 -180
- arviz/plots/backends/matplotlib/forestplot.py +0 -656
- arviz/plots/backends/matplotlib/hdiplot.py +0 -48
- arviz/plots/backends/matplotlib/kdeplot.py +0 -177
- arviz/plots/backends/matplotlib/khatplot.py +0 -241
- arviz/plots/backends/matplotlib/lmplot.py +0 -149
- arviz/plots/backends/matplotlib/loopitplot.py +0 -144
- arviz/plots/backends/matplotlib/mcseplot.py +0 -161
- arviz/plots/backends/matplotlib/pairplot.py +0 -355
- arviz/plots/backends/matplotlib/parallelplot.py +0 -58
- arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
- arviz/plots/backends/matplotlib/ppcplot.py +0 -478
- arviz/plots/backends/matplotlib/rankplot.py +0 -119
- arviz/plots/backends/matplotlib/separationplot.py +0 -97
- arviz/plots/backends/matplotlib/traceplot.py +0 -526
- arviz/plots/backends/matplotlib/tsplot.py +0 -121
- arviz/plots/backends/matplotlib/violinplot.py +0 -148
- arviz/plots/bfplot.py +0 -128
- arviz/plots/bpvplot.py +0 -308
- arviz/plots/compareplot.py +0 -177
- arviz/plots/densityplot.py +0 -284
- arviz/plots/distcomparisonplot.py +0 -197
- arviz/plots/distplot.py +0 -233
- arviz/plots/dotplot.py +0 -233
- arviz/plots/ecdfplot.py +0 -372
- arviz/plots/elpdplot.py +0 -174
- arviz/plots/energyplot.py +0 -147
- arviz/plots/essplot.py +0 -319
- arviz/plots/forestplot.py +0 -304
- arviz/plots/hdiplot.py +0 -211
- arviz/plots/kdeplot.py +0 -357
- arviz/plots/khatplot.py +0 -236
- arviz/plots/lmplot.py +0 -380
- arviz/plots/loopitplot.py +0 -224
- arviz/plots/mcseplot.py +0 -194
- arviz/plots/pairplot.py +0 -281
- arviz/plots/parallelplot.py +0 -204
- arviz/plots/plot_utils.py +0 -599
- arviz/plots/posteriorplot.py +0 -298
- arviz/plots/ppcplot.py +0 -369
- arviz/plots/rankplot.py +0 -232
- arviz/plots/separationplot.py +0 -167
- arviz/plots/styles/arviz-bluish.mplstyle +0 -1
- arviz/plots/styles/arviz-brownish.mplstyle +0 -1
- arviz/plots/styles/arviz-colors.mplstyle +0 -2
- arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
- arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
- arviz/plots/styles/arviz-doc.mplstyle +0 -88
- arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
- arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
- arviz/plots/styles/arviz-greenish.mplstyle +0 -1
- arviz/plots/styles/arviz-orangish.mplstyle +0 -1
- arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
- arviz/plots/styles/arviz-purplish.mplstyle +0 -1
- arviz/plots/styles/arviz-redish.mplstyle +0 -1
- arviz/plots/styles/arviz-royish.mplstyle +0 -1
- arviz/plots/styles/arviz-viridish.mplstyle +0 -1
- arviz/plots/styles/arviz-white.mplstyle +0 -40
- arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
- arviz/plots/traceplot.py +0 -273
- arviz/plots/tsplot.py +0 -440
- arviz/plots/violinplot.py +0 -192
- arviz/preview.py +0 -58
- arviz/py.typed +0 -0
- arviz/rcparams.py +0 -606
- arviz/sel_utils.py +0 -223
- arviz/static/css/style.css +0 -340
- arviz/static/html/icons-svg-inline.html +0 -15
- arviz/stats/__init__.py +0 -37
- arviz/stats/density_utils.py +0 -1013
- arviz/stats/diagnostics.py +0 -1013
- arviz/stats/ecdf_utils.py +0 -324
- arviz/stats/stats.py +0 -2422
- arviz/stats/stats_refitting.py +0 -119
- arviz/stats/stats_utils.py +0 -609
- arviz/tests/__init__.py +0 -1
- arviz/tests/base_tests/__init__.py +0 -1
- arviz/tests/base_tests/test_data.py +0 -1679
- arviz/tests/base_tests/test_data_zarr.py +0 -143
- arviz/tests/base_tests/test_diagnostics.py +0 -511
- arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
- arviz/tests/base_tests/test_helpers.py +0 -18
- arviz/tests/base_tests/test_labels.py +0 -69
- arviz/tests/base_tests/test_plot_utils.py +0 -342
- arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
- arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
- arviz/tests/base_tests/test_rcparams.py +0 -317
- arviz/tests/base_tests/test_stats.py +0 -925
- arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
- arviz/tests/base_tests/test_stats_numba.py +0 -45
- arviz/tests/base_tests/test_stats_utils.py +0 -384
- arviz/tests/base_tests/test_utils.py +0 -376
- arviz/tests/base_tests/test_utils_numba.py +0 -87
- arviz/tests/conftest.py +0 -46
- arviz/tests/external_tests/__init__.py +0 -1
- arviz/tests/external_tests/test_data_beanmachine.py +0 -78
- arviz/tests/external_tests/test_data_cmdstan.py +0 -398
- arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
- arviz/tests/external_tests/test_data_emcee.py +0 -166
- arviz/tests/external_tests/test_data_numpyro.py +0 -434
- arviz/tests/external_tests/test_data_pyjags.py +0 -119
- arviz/tests/external_tests/test_data_pyro.py +0 -260
- arviz/tests/external_tests/test_data_pystan.py +0 -307
- arviz/tests/helpers.py +0 -677
- arviz/utils.py +0 -773
- arviz/wrappers/__init__.py +0 -13
- arviz/wrappers/base.py +0 -236
- arviz/wrappers/wrap_pymc.py +0 -36
- arviz/wrappers/wrap_stan.py +0 -148
- arviz-0.23.1.dist-info/METADATA +0 -263
- arviz-0.23.1.dist-info/RECORD +0 -183
- arviz-0.23.1.dist-info/top_level.txt +0 -1
|
@@ -1,526 +0,0 @@
|
|
|
1
|
-
"""Matplotlib traceplot."""
|
|
2
|
-
|
|
3
|
-
import warnings
|
|
4
|
-
from itertools import cycle
|
|
5
|
-
|
|
6
|
-
from matplotlib import gridspec
|
|
7
|
-
import matplotlib.pyplot as plt
|
|
8
|
-
import numpy as np
|
|
9
|
-
from matplotlib.lines import Line2D
|
|
10
|
-
import matplotlib.ticker as mticker
|
|
11
|
-
|
|
12
|
-
from ....stats.density_utils import get_bins
|
|
13
|
-
from ...distplot import plot_dist
|
|
14
|
-
from ...plot_utils import _scale_fig_size, format_coords_as_labels
|
|
15
|
-
from ...rankplot import plot_rank
|
|
16
|
-
from . import backend_kwarg_defaults, backend_show, dealiase_sel_kwargs, matplotlib_kwarg_dealiaser
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
def plot_trace(
    data,
    var_names,  # pylint: disable=unused-argument
    divergences,
    kind,
    figsize,
    rug,
    lines,
    circ_var_names,
    circ_var_units,
    compact,
    compact_prop,
    combined,
    chain_prop,
    legend,
    labeller,
    plot_kwargs,
    fill_kwargs,
    rug_kwargs,
    hist_kwargs,
    trace_kwargs,
    rank_kwargs,
    plotters,
    divergence_data,
    axes,
    backend_kwargs,
    backend_config,  # pylint: disable=unused-argument
    show,
):
    """Plot distribution (histogram or kernel density estimates) and sampled values.

    Each variable gets one row with two columns: the distribution on the left and the
    sampled values (or a rank plot) on the right. If divergence data is available, the
    location of divergent draws is marked with black tick markers along the top or bottom
    edge of the axes.

    Parameters
    ----------
    data : obj
        Dataset holding the sampled values. It must expose ``chain`` and ``draw``
        coordinates and support ``data[var_name].sel(...)``. Presumably this is the posterior
        group already extracted by the public ``az.plot_trace``; confirm against the caller.
    var_names : string, or list of strings
        One or more variables to be plotted (unused here; selection is done via ``plotters``).
    divergences : {"bottom", "top", None, False}
        Plot location of divergences on the traceplots. Options are "bottom", "top", or False-y.
    kind : {"trace", "rank_bars", "rank_vlines"}, optional
        Choose between plotting sampled values per iteration and rank plots.
    figsize : figure size tuple
        If None, size is (12, variables * 2)
    rug : bool
        If True adds a rugplot. Defaults to False. Ignored for 2D KDE. Only affects continuous
        variables.
    lines : tuple or list
        List of tuple of (var_name, {'coord': selection}, [line_positions]) to be overplotted as
        vertical lines on the density and horizontal lines on the trace.
    circ_var_names : string, or list of strings
        List of circular variables to account for when plotting KDE.
    circ_var_units : str
        Whether the variables in `circ_var_names` are in "degrees" or "radians".
    compact : bool
        Whether multidimensional variables share a single pair of axes. When True, chains are
        distinguished by linestyle by default and coordinates by ``compact_prop``.
    compact_prop : str, tuple or dict, optional
        Matplotlib property used to distinguish coordinates in compact mode. Defaults to
        ``"color"`` when ``compact`` is True. Tuples are deprecated.
    combined : bool
        Flag for combining multiple chains into a single line. If False (default), chains will be
        plotted separately.
    chain_prop : str, tuple or dict, optional
        Matplotlib property used to distinguish chains. Defaults to ``"color"``, or to
        linestyles in compact mode. Tuples are deprecated.
    legend : bool
        Add a legend to the figure with the chain color code.
    labeller : object
        Provides ``make_label_vert(var_name, selection, isel)`` used for the axis titles.
    plot_kwargs : dict
        Extra keyword arguments passed to `arviz.plot_dist`. Only affects continuous variables.
    fill_kwargs : dict
        Extra keyword arguments passed to `arviz.plot_dist`. Only affects continuous variables.
    rug_kwargs : dict
        Extra keyword arguments passed to `arviz.plot_dist`. Only affects continuous variables.
    hist_kwargs : dict
        Extra keyword arguments passed to `arviz.plot_dist`. Only affects discrete variables.
    trace_kwargs : dict
        Extra keyword arguments passed to `plt.plot`. Its ``alpha`` is also used for the
        divergence markers and reference lines.
    rank_kwargs : dict
        Extra keyword arguments passed to `arviz.plot_rank`
    plotters : sequence of tuple
        ``(var_name, selection, isel, value)`` entries, one per plotted row.
    divergence_data : object or None
        Divergence indicators supporting ``.dims`` and ``.sel``. Each chain's values are used
        as a boolean mask over draws. Only read when ``divergences`` is truthy.
    axes : 2D array of matplotlib axes, optional
        Axes to draw on, indexed ``axes[row, column]``. A new figure is created if None.
    backend_kwargs : dict, optional
        Passed to ``plt.figure`` when a new figure is created.
    backend_config : dict, optional
        Unused by the matplotlib backend.
    show : bool, optional
        Call ``plt.show()`` at the end.

    Returns
    -------
    axes : matplotlib axes
        Array of axes with two columns.


    Examples
    --------
    Plot a subset variables

    .. plot::
        :context: close-figs

        >>> import arviz as az
        >>> data = az.load_arviz_data('non_centered_eight')
        >>> coords = {'school': ['Choate', 'Lawrenceville']}
        >>> az.plot_trace(data, var_names=('theta_t', 'theta'), coords=coords)

    Show all dimensions of multidimensional variables in the same plot

    .. plot::
        :context: close-figs

        >>> az.plot_trace(data, compact=True)

    Combine all chains into one distribution

    .. plot::
        :context: close-figs

        >>> az.plot_trace(data, var_names=('theta_t', 'theta'), coords=coords, combined=True)


    Plot reference lines against distribution and trace

    .. plot::
        :context: close-figs

        >>> lines = (('theta_t',{'school': "Choate"}, [-1]),)
        >>> az.plot_trace(data, var_names=('theta_t', 'theta'), coords=coords, lines=lines)

    """
    # Set plot default backend kwargs
    if backend_kwargs is None:
        backend_kwargs = {}

    if circ_var_names is None:
        circ_var_names = []

    backend_kwargs = {**backend_kwarg_defaults(), **backend_kwargs}

    if lines is None:
        lines = ()

    # One extra style slot is reserved for the "combined" line, which is addressed with index -1.
    num_chain_props = len(data.chain) + 1 if combined else len(data.chain)
    if not compact:
        if chain_prop is None:
            chain_prop = "color"
    else:
        if chain_prop is None:
            chain_prop = {"linestyle": ("solid", "dotted", "dashed", "dashdot")}
        if compact_prop is None:
            compact_prop = "color"

    chain_prop = _as_prop_dict(chain_prop, "chain_prop")
    # Cycle every property so that there is exactly one value per chain slot.
    chain_prop = {
        prop_name: [prop for _, prop in zip(range(num_chain_props), cycle(props))]
        for prop_name, props in chain_prop.items()
    }
    compact_prop = _as_prop_dict(compact_prop, "compact_prop")

    if figsize is None:
        figsize = (12, len(plotters) * 2)

    backend_kwargs.setdefault("figsize", figsize)

    trace_kwargs = matplotlib_kwarg_dealiaser(trace_kwargs, "plot")
    trace_kwargs.setdefault("alpha", 0.35)

    hist_kwargs = matplotlib_kwarg_dealiaser(hist_kwargs, "hist")
    hist_kwargs.setdefault("alpha", 0.35)

    plot_kwargs = matplotlib_kwarg_dealiaser(plot_kwargs, "plot")
    fill_kwargs = matplotlib_kwarg_dealiaser(fill_kwargs, "fill_between")
    rug_kwargs = matplotlib_kwarg_dealiaser(rug_kwargs, "scatter")
    rank_kwargs = matplotlib_kwarg_dealiaser(rank_kwargs, "bar")
    if compact:
        rank_kwargs.setdefault("bar_kwargs", {})
        rank_kwargs["bar_kwargs"].setdefault("alpha", 0.2)

    textsize = plot_kwargs.pop("textsize", 10)

    figsize, _, titlesize, xt_labelsize, linewidth, _ = _scale_fig_size(
        figsize, textsize, rows=len(plotters), cols=2
    )

    trace_kwargs.setdefault("linewidth", linewidth)
    plot_kwargs.setdefault("linewidth", linewidth)

    # Warn (without failing) about reference lines targeting variables that are not plotted.
    all_var_names = {plotter[0] for plotter in plotters}
    invalid_var_names = {line[0] for line in lines if line[0] not in all_var_names}
    if invalid_var_names:
        warnings.warn(
            "A valid var_name should be provided, "
            f"found {invalid_var_names} expected from {all_var_names}"
        )

    if axes is None:
        fig = plt.figure(**backend_kwargs)
        spec = gridspec.GridSpec(ncols=2, nrows=len(plotters), figure=fig)

    # Markers and reference lines share the trace transparency.
    overlay_alpha = trace_kwargs["alpha"]

    # pylint: disable=too-many-nested-blocks
    for idx, (var_name, selection, isel, value) in enumerate(plotters):
        for idy in range(2):
            value = np.atleast_2d(value)

            # Only the distribution column (idy == 0) of a circular variable is polar. The trace
            # column (idy == 1) stays cartesian but may relabel its y ticks in degrees.
            circular = var_name in circ_var_names and not idy
            circ_units_trace = circ_var_units if (var_name in circ_var_names and idy) else False

            if axes is None:
                ax = fig.add_subplot(spec[idx, idy], polar=circular)
            else:
                ax = axes[idx, idy]

            if len(value.shape) == 2:
                if compact_prop:
                    aux_plot_kwargs = dealiase_sel_kwargs(plot_kwargs, compact_prop, 0)
                    aux_trace_kwargs = dealiase_sel_kwargs(trace_kwargs, compact_prop, 0)
                else:
                    aux_plot_kwargs = plot_kwargs
                    aux_trace_kwargs = trace_kwargs

                ax = _plot_chains_mpl(
                    ax,
                    idy,
                    value,
                    data,
                    chain_prop,
                    combined,
                    xt_labelsize,
                    rug,
                    kind,
                    aux_trace_kwargs,
                    hist_kwargs,
                    aux_plot_kwargs,
                    fill_kwargs,
                    rug_kwargs,
                    rank_kwargs,
                    circular,
                    circ_var_units,
                    circ_units_trace,
                )

            else:
                # Compact mode: every coordinate of the variable is drawn on this same axis.
                sub_data = data[var_name].sel(**selection)
                legend_labels = format_coords_as_labels(sub_data, skip_dims=("chain", "draw"))
                legend_title = ", ".join(
                    f"{coord_name}"
                    for coord_name in sub_data.coords
                    if coord_name not in {"chain", "draw"}
                )
                # Flatten all non chain/draw dims into a single trailing axis.
                value = value.reshape((value.shape[0], value.shape[1], -1))
                compact_prop_iter = {
                    prop_name: [prop for _, prop in zip(range(value.shape[2]), cycle(props))]
                    for prop_name, props in compact_prop.items()
                }
                handles = []
                for sub_idx, label in zip(range(value.shape[2]), legend_labels):
                    aux_plot_kwargs = dealiase_sel_kwargs(plot_kwargs, compact_prop_iter, sub_idx)
                    aux_trace_kwargs = dealiase_sel_kwargs(trace_kwargs, compact_prop_iter, sub_idx)
                    ax = _plot_chains_mpl(
                        ax,
                        idy,
                        value[..., sub_idx],
                        data,
                        chain_prop,
                        combined,
                        xt_labelsize,
                        rug,
                        kind,
                        aux_trace_kwargs,
                        hist_kwargs,
                        aux_plot_kwargs,
                        fill_kwargs,
                        rug_kwargs,
                        rank_kwargs,
                        circular,
                        circ_var_units,
                        circ_units_trace,
                    )
                    if legend:
                        handles.append(
                            Line2D(
                                [],
                                [],
                                label=label,
                                **dealiase_sel_kwargs(aux_plot_kwargs, chain_prop, 0),
                            )
                        )
                if legend and idy == 0:
                    ax.legend(handles=handles, title=legend_title)

            # Discrete (integer) variables get one x tick per histogram bin.
            if value[0].dtype.kind == "i" and idy == 0:
                xticks = get_bins(value)
                ax.set_xticks(xticks[:-1])
            y = 1 / textsize
            if not idy:
                ax.set_yticks([])
                if circular:
                    y = 0.13 if selection else 0.12
            ax.set_title(
                labeller.make_label_vert(var_name, selection, isel),
                fontsize=titlesize,
                wrap=True,
                y=textsize * y,
            )
            ax.tick_params(labelsize=xt_labelsize)

            # Limits of the data just drawn; markers and reference lines are placed against them.
            xlims = ax.get_xlim()
            ylims = ax.get_ylim()

            if divergences:
                _plot_divergences_mpl(
                    ax,
                    idy,
                    value,
                    data,
                    divergence_data,
                    selection,
                    divergences,
                    circular,
                    kind,
                    ylims,
                    overlay_alpha,
                )

            _plot_reference_lines_mpl(
                ax, idy, var_name, selection, lines, xlims, ylims, overlay_alpha
            )

            if kind == "trace" and idy:
                ax.set_xlim(left=data.draw.min(), right=data.draw.max())

    if legend:
        # Chain handles mirror what was drawn per chain: traces when combined, densities otherwise.
        legend_kwargs = trace_kwargs if combined else plot_kwargs
        handles = [
            Line2D(
                [], [], label=chain_id, **dealiase_sel_kwargs(legend_kwargs, chain_prop, chain_id)
            )
            for chain_id in range(data.sizes["chain"])
        ]
        if combined:
            handles.insert(
                0,
                Line2D(
                    [], [], label="combined", **dealiase_sel_kwargs(plot_kwargs, chain_prop, -1)
                ),
            )
        ax.figure.axes[1].legend(handles=handles, title="chain", loc="upper right")

    if axes is None:
        axes = np.array(ax.figure.axes).reshape(-1, 2)

    if backend_show(show):
        plt.show()

    return axes


def _as_prop_dict(prop, arg_name):
    """Normalize a ``chain_prop``/``compact_prop`` argument into a ``{property: values}`` dict.

    A string names a matplotlib property whose values are read from the current
    ``axes.prop_cycle``. A ``(name, values)`` tuple is converted with a FutureWarning.
    Anything else (a dict, or None) is returned unchanged.
    """
    if isinstance(prop, str):
        return {prop: plt.rcParams["axes.prop_cycle"].by_key()[prop]}
    if isinstance(prop, tuple):
        warnings.warn(
            f"{arg_name} as a tuple will be deprecated in a future warning, use a dict instead",
            FutureWarning,
        )
        return {prop[0]: prop[1]}
    return prop


def _plot_divergences_mpl(
    ax, idy, value, data, divergence_data, selection, divergences, circular, kind, ylims, alpha
):
    """Mark divergent draws of one variable with black ticks on a distribution or trace axis.

    ``value`` is indexed as ``value[chain, draw_index]``. Markers go on ``ylims[1]`` when
    ``divergences == "top"``, otherwise on ``ylims[0]``. On polar (circular) axes, radial
    ticks are drawn at each divergent value instead.
    """
    # divergence_data may lack some of the variable's dims; select only on the shared ones.
    div_selection = {k: v for k, v in selection.items() if k in divergence_data.dims}
    divs = np.atleast_2d(divergence_data.sel(**div_selection).values)
    ylocs = ylims[1] if divergences == "top" else ylims[0]
    marker_style = {
        "color": "black",
        "markeredgewidth": 1.5,
        "markersize": 30,
        "alpha": alpha,
        "zorder": 0.6,
    }

    for chain, chain_divs in enumerate(divs):
        div_idxs = np.arange(len(chain_divs))[chain_divs]
        if div_idxs.size == 0:
            continue
        values = value[chain, div_idxs]

        if circular:
            # Recomputed per chain because earlier plotting may have changed the radial limits.
            tick = [ax.get_rmin() + ax.get_rmax() * 0.60, ax.get_rmax()]
            for val in values:
                ax.plot([val, val], tick, **marker_style)
        elif kind == "trace" and idy:
            div_draws = data.draw.values[chain_divs]
            ax.plot(
                div_draws,
                np.zeros_like(div_idxs) + ylocs,
                marker="|",
                linestyle="None",
                **marker_style,
            )
        elif not idy:
            ax.plot(
                values,
                np.zeros_like(values) + ylocs,
                marker="|",
                linestyle="None",
                **marker_style,
            )


def _plot_reference_lines_mpl(ax, idy, var_name, selection, lines, xlims, ylims, alpha):
    """Overlay the user ``lines`` matching ``(var_name, selection)`` on one axis.

    Lines are horizontal on the trace column (``idy == 1``) and vertical on the distribution
    column.

    Raises
    ------
    ValueError
        If a non-scalar line position is not numeric.
    """
    for _, _, vlines in (line for line in lines if line[0] == var_name and line[1] == selection):
        if isinstance(vlines, (float, int)):
            line_values = [vlines]
        else:
            line_values = np.atleast_1d(vlines).ravel()
            if not np.issubdtype(line_values.dtype, np.number):
                raise ValueError(f"line-positions should be numeric, found {line_values}")
        if idy:
            ax.hlines(
                line_values,
                xlims[0],
                xlims[1],
                colors="black",
                linewidth=1.5,
                alpha=alpha,
            )
        else:
            ax.vlines(
                line_values,
                ylims[0],
                ylims[1],
                colors="black",
                linewidth=1.5,
                alpha=alpha,
            )
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
def _plot_chains_mpl(
    axes,
    idy,
    value,
    data,
    chain_prop,
    combined,
    xt_labelsize,
    rug,
    kind,
    trace_kwargs,
    hist_kwargs,
    plot_kwargs,
    fill_kwargs,
    rug_kwargs,
    rank_kwargs,
    circular,
    circ_var_units,
    circ_units_trace,
):
    """Draw all chains of one variable (or of one coordinate of it) on a single axis.

    ``value`` is iterated row by row, one row per chain, and each row is plotted against
    ``data.draw``. ``idy`` selects the column:

    * ``0``: distribution via ``plot_dist``, per chain or combined.
    * ``1``: sampled values (``kind == "trace"``) or a rank plot
      (``kind`` in ``{"rank_bars", "rank_vlines"}``).

    Per-chain styling is picked from ``chain_prop`` with ``dealiase_sel_kwargs``; the combined
    density uses slot ``-1``. When ``circ_units_trace == "degrees"``, the trace y tick labels
    (radians) are rewritten in degrees.

    Returns
    -------
    axes : matplotlib axes
        The axis drawn on. ``plot_dist``/``plot_rank`` return values replace it.
    """
    # Only the polar distribution axis is circular; elsewhere plot_dist must treat data as linear.
    if not circular:
        circ_var_units = False

    for chain_idx, row in enumerate(value):
        if kind == "trace" and idy:
            aux_kwargs = dealiase_sel_kwargs(trace_kwargs, chain_prop, chain_idx)
            axes.plot(data.draw.values, row, **aux_kwargs)

        if not combined and not idy:
            aux_kwargs = dealiase_sel_kwargs(plot_kwargs, chain_prop, chain_idx)
            axes = plot_dist(
                values=row,
                textsize=xt_labelsize,
                rug=rug,
                ax=axes,
                hist_kwargs=hist_kwargs,
                plot_kwargs=aux_kwargs,
                fill_kwargs=fill_kwargs,
                rug_kwargs=rug_kwargs,
                backend="matplotlib",
                show=False,
                is_circular=circ_var_units,
            )

    if kind == "trace" and idy and circ_units_trace == "degrees":
        # Relabel once, after every chain is drawn. Installing a FixedLocator freezes the tick
        # positions, so doing it per chain left the ticks fitted to the first chain only.
        y_tick_locs = axes.get_yticks()
        y_tick_labels = [i + 2 * 180 if i < 0 else i for i in np.rad2deg(y_tick_locs)]
        axes.yaxis.set_major_locator(mticker.FixedLocator(y_tick_locs))
        axes.set_yticklabels([f"{i:.0f}°" for i in y_tick_labels])

    # Rank plots consume all chains at once, so they are drawn a single time.
    if kind == "rank_bars" and idy:
        axes = plot_rank(data=value, kind="bars", ax=axes, **rank_kwargs)
    elif kind == "rank_vlines" and idy:
        axes = plot_rank(data=value, kind="vlines", ax=axes, **rank_kwargs)

    if combined and not idy:
        aux_kwargs = dealiase_sel_kwargs(plot_kwargs, chain_prop, -1)
        axes = plot_dist(
            values=value.flatten(),
            textsize=xt_labelsize,
            rug=rug,
            ax=axes,
            hist_kwargs=hist_kwargs,
            plot_kwargs=aux_kwargs,
            fill_kwargs=fill_kwargs,
            rug_kwargs=rug_kwargs,
            backend="matplotlib",
            show=False,
            is_circular=circ_var_units,
        )
    return axes
|
|
@@ -1,121 +0,0 @@
|
|
|
1
|
-
"""Matplotlib plot time series figure."""
|
|
2
|
-
|
|
3
|
-
import matplotlib.pyplot as plt
|
|
4
|
-
import numpy as np
|
|
5
|
-
|
|
6
|
-
from ...plot_utils import _scale_fig_size
|
|
7
|
-
from . import create_axes_grid, backend_show, matplotlib_kwarg_dealiaser, backend_kwarg_defaults
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
def plot_ts(
    x_plotters,
    y_plotters,
    y_mean_plotters,
    y_hat_plotters,
    y_holdout_plotters,
    x_holdout_plotters,
    y_forecasts_plotters,
    y_forecasts_mean_plotters,
    num_samples,
    backend_kwargs,
    y_kwargs,
    y_hat_plot_kwargs,
    y_mean_plot_kwargs,
    vline_kwargs,
    length_plotters,
    rows,
    cols,
    textsize,
    figsize,
    legend,
    axes,
    show,
):
    """Matplotlib time series.

    For each of the first ``length_plotters`` axes, the following are drawn:

    * the observed series (``x_plotters``/``y_plotters``);
    * optionally ``num_samples`` fitted draws plus their mean;
    * optionally holdout observations and ``num_samples`` forecast draws plus their mean,
      both plotted against ``x_holdout_plotters``.

    Every ``*_plotters`` argument is indexed per axis, and its last tuple element holds the
    values to plot. Sample arrays are indexed ``[..., j]`` for the j-th draw. A dashed vertical
    line marks the last observed x whenever holdout or forecast data are given.

    Returns
    -------
    axes : array of matplotlib axes
    """
    if backend_kwargs is None:
        backend_kwargs = {}

    backend_kwargs = {
        **backend_kwarg_defaults(),
        **backend_kwargs,
    }

    if figsize is None:
        figsize = (12, rows * 5)

    backend_kwargs.setdefault("figsize", figsize)
    backend_kwargs.setdefault("squeeze", False)

    # Only the tick-label size is needed here; the figure size was already fixed above.
    _, _, _, xt_labelsize, _, _ = _scale_fig_size(figsize, textsize, rows, cols)

    if axes is None:
        _, axes = create_axes_grid(length_plotters, rows, cols, backend_kwargs=backend_kwargs)

    y_kwargs = matplotlib_kwarg_dealiaser(y_kwargs, "plot")
    y_kwargs.setdefault("color", "blue")
    y_kwargs.setdefault("zorder", 10)

    y_hat_plot_kwargs = matplotlib_kwarg_dealiaser(y_hat_plot_kwargs, "plot")
    y_hat_plot_kwargs.setdefault("color", "grey")
    y_hat_plot_kwargs.setdefault("alpha", 0.1)

    y_mean_plot_kwargs = matplotlib_kwarg_dealiaser(y_mean_plot_kwargs, "plot")
    y_mean_plot_kwargs.setdefault("color", "red")
    y_mean_plot_kwargs.setdefault("linestyle", "dashed")

    vline_kwargs = matplotlib_kwarg_dealiaser(vline_kwargs, "plot")
    vline_kwargs.setdefault("color", "black")
    vline_kwargs.setdefault("linestyle", "dashed")

    has_model_output = y_hat_plotters is not None or y_forecasts_plotters is not None
    has_out_of_sample = y_holdout_plotters is not None or y_forecasts_plotters is not None

    for i, ax_i in enumerate(np.ravel(axes)[:length_plotters]):
        y_var_name, _, _, y_plotters_i = y_plotters[i]
        x_var_name, _, _, x_plotters_i = x_plotters[i]

        ax_i.plot(x_plotters_i, y_plotters_i, **y_kwargs)
        # Empty artists that only exist to populate the legend.
        ax_i.plot([], label="Actual", **y_kwargs)
        if has_model_output:
            ax_i.plot([], label="Fitted", **y_mean_plot_kwargs)
            ax_i.plot([], label="Uncertainty", **y_hat_plot_kwargs)

        ax_i.set_xlabel(x_var_name)
        ax_i.set_ylabel(y_var_name)

        if y_hat_plotters is not None:
            *_, y_hat_plotters_i = y_hat_plotters[i]
            for j in range(num_samples):
                ax_i.plot(x_plotters_i, y_hat_plotters_i[..., j], **y_hat_plot_kwargs)

            *_, y_mean_plotters_i = y_mean_plotters[i]
            ax_i.plot(x_plotters_i, y_mean_plotters_i, **y_mean_plot_kwargs)

        if y_holdout_plotters is not None:
            *_, y_holdout_plotters_i = y_holdout_plotters[i]
            *_, x_holdout_plotters_i = x_holdout_plotters[i]
            ax_i.plot(x_holdout_plotters_i, y_holdout_plotters_i, **y_kwargs)

        if y_forecasts_plotters is not None:
            *_, y_forecasts_plotters_i = y_forecasts_plotters[i]
            # Forecasts share the holdout x-axis values.
            *_, x_forecasts_plotters_i = x_holdout_plotters[i]
            for j in range(num_samples):
                ax_i.plot(
                    x_forecasts_plotters_i, y_forecasts_plotters_i[..., j], **y_hat_plot_kwargs
                )

            *_, y_forecasts_mean_plotters_i = y_forecasts_mean_plotters[i]
            ax_i.plot(x_forecasts_plotters_i, y_forecasts_mean_plotters_i, **y_mean_plot_kwargs)

        if has_out_of_sample:
            # Draw the end-of-observed-data marker once, even when holdout and forecasts coexist.
            ax_i.axvline(x_plotters_i[-1], **vline_kwargs)

        if legend:
            ax_i.legend(fontsize=xt_labelsize, loc="upper left")

    if backend_show(show):
        plt.show()

    return axes
|