arviz 0.23.1__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185)
  1. arviz/__init__.py +52 -357
  2. arviz-1.0.0rc0.dist-info/METADATA +182 -0
  3. arviz-1.0.0rc0.dist-info/RECORD +5 -0
  4. {arviz-0.23.1.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
  5. {arviz-0.23.1.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
  6. arviz/data/__init__.py +0 -55
  7. arviz/data/base.py +0 -596
  8. arviz/data/converters.py +0 -203
  9. arviz/data/datasets.py +0 -161
  10. arviz/data/example_data/code/radon/radon.json +0 -326
  11. arviz/data/example_data/data/centered_eight.nc +0 -0
  12. arviz/data/example_data/data/non_centered_eight.nc +0 -0
  13. arviz/data/example_data/data_local.json +0 -12
  14. arviz/data/example_data/data_remote.json +0 -58
  15. arviz/data/inference_data.py +0 -2386
  16. arviz/data/io_beanmachine.py +0 -112
  17. arviz/data/io_cmdstan.py +0 -1036
  18. arviz/data/io_cmdstanpy.py +0 -1233
  19. arviz/data/io_datatree.py +0 -23
  20. arviz/data/io_dict.py +0 -462
  21. arviz/data/io_emcee.py +0 -317
  22. arviz/data/io_json.py +0 -54
  23. arviz/data/io_netcdf.py +0 -68
  24. arviz/data/io_numpyro.py +0 -497
  25. arviz/data/io_pyjags.py +0 -378
  26. arviz/data/io_pyro.py +0 -333
  27. arviz/data/io_pystan.py +0 -1095
  28. arviz/data/io_zarr.py +0 -46
  29. arviz/data/utils.py +0 -139
  30. arviz/labels.py +0 -210
  31. arviz/plots/__init__.py +0 -61
  32. arviz/plots/autocorrplot.py +0 -171
  33. arviz/plots/backends/__init__.py +0 -223
  34. arviz/plots/backends/bokeh/__init__.py +0 -166
  35. arviz/plots/backends/bokeh/autocorrplot.py +0 -101
  36. arviz/plots/backends/bokeh/bfplot.py +0 -23
  37. arviz/plots/backends/bokeh/bpvplot.py +0 -193
  38. arviz/plots/backends/bokeh/compareplot.py +0 -167
  39. arviz/plots/backends/bokeh/densityplot.py +0 -239
  40. arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
  41. arviz/plots/backends/bokeh/distplot.py +0 -183
  42. arviz/plots/backends/bokeh/dotplot.py +0 -113
  43. arviz/plots/backends/bokeh/ecdfplot.py +0 -73
  44. arviz/plots/backends/bokeh/elpdplot.py +0 -203
  45. arviz/plots/backends/bokeh/energyplot.py +0 -155
  46. arviz/plots/backends/bokeh/essplot.py +0 -176
  47. arviz/plots/backends/bokeh/forestplot.py +0 -772
  48. arviz/plots/backends/bokeh/hdiplot.py +0 -54
  49. arviz/plots/backends/bokeh/kdeplot.py +0 -268
  50. arviz/plots/backends/bokeh/khatplot.py +0 -163
  51. arviz/plots/backends/bokeh/lmplot.py +0 -185
  52. arviz/plots/backends/bokeh/loopitplot.py +0 -211
  53. arviz/plots/backends/bokeh/mcseplot.py +0 -184
  54. arviz/plots/backends/bokeh/pairplot.py +0 -328
  55. arviz/plots/backends/bokeh/parallelplot.py +0 -81
  56. arviz/plots/backends/bokeh/posteriorplot.py +0 -324
  57. arviz/plots/backends/bokeh/ppcplot.py +0 -379
  58. arviz/plots/backends/bokeh/rankplot.py +0 -149
  59. arviz/plots/backends/bokeh/separationplot.py +0 -107
  60. arviz/plots/backends/bokeh/traceplot.py +0 -436
  61. arviz/plots/backends/bokeh/violinplot.py +0 -164
  62. arviz/plots/backends/matplotlib/__init__.py +0 -124
  63. arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
  64. arviz/plots/backends/matplotlib/bfplot.py +0 -78
  65. arviz/plots/backends/matplotlib/bpvplot.py +0 -177
  66. arviz/plots/backends/matplotlib/compareplot.py +0 -135
  67. arviz/plots/backends/matplotlib/densityplot.py +0 -194
  68. arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
  69. arviz/plots/backends/matplotlib/distplot.py +0 -178
  70. arviz/plots/backends/matplotlib/dotplot.py +0 -116
  71. arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
  72. arviz/plots/backends/matplotlib/elpdplot.py +0 -189
  73. arviz/plots/backends/matplotlib/energyplot.py +0 -113
  74. arviz/plots/backends/matplotlib/essplot.py +0 -180
  75. arviz/plots/backends/matplotlib/forestplot.py +0 -656
  76. arviz/plots/backends/matplotlib/hdiplot.py +0 -48
  77. arviz/plots/backends/matplotlib/kdeplot.py +0 -177
  78. arviz/plots/backends/matplotlib/khatplot.py +0 -241
  79. arviz/plots/backends/matplotlib/lmplot.py +0 -149
  80. arviz/plots/backends/matplotlib/loopitplot.py +0 -144
  81. arviz/plots/backends/matplotlib/mcseplot.py +0 -161
  82. arviz/plots/backends/matplotlib/pairplot.py +0 -355
  83. arviz/plots/backends/matplotlib/parallelplot.py +0 -58
  84. arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
  85. arviz/plots/backends/matplotlib/ppcplot.py +0 -478
  86. arviz/plots/backends/matplotlib/rankplot.py +0 -119
  87. arviz/plots/backends/matplotlib/separationplot.py +0 -97
  88. arviz/plots/backends/matplotlib/traceplot.py +0 -526
  89. arviz/plots/backends/matplotlib/tsplot.py +0 -121
  90. arviz/plots/backends/matplotlib/violinplot.py +0 -148
  91. arviz/plots/bfplot.py +0 -128
  92. arviz/plots/bpvplot.py +0 -308
  93. arviz/plots/compareplot.py +0 -177
  94. arviz/plots/densityplot.py +0 -284
  95. arviz/plots/distcomparisonplot.py +0 -197
  96. arviz/plots/distplot.py +0 -233
  97. arviz/plots/dotplot.py +0 -233
  98. arviz/plots/ecdfplot.py +0 -372
  99. arviz/plots/elpdplot.py +0 -174
  100. arviz/plots/energyplot.py +0 -147
  101. arviz/plots/essplot.py +0 -319
  102. arviz/plots/forestplot.py +0 -304
  103. arviz/plots/hdiplot.py +0 -211
  104. arviz/plots/kdeplot.py +0 -357
  105. arviz/plots/khatplot.py +0 -236
  106. arviz/plots/lmplot.py +0 -380
  107. arviz/plots/loopitplot.py +0 -224
  108. arviz/plots/mcseplot.py +0 -194
  109. arviz/plots/pairplot.py +0 -281
  110. arviz/plots/parallelplot.py +0 -204
  111. arviz/plots/plot_utils.py +0 -599
  112. arviz/plots/posteriorplot.py +0 -298
  113. arviz/plots/ppcplot.py +0 -369
  114. arviz/plots/rankplot.py +0 -232
  115. arviz/plots/separationplot.py +0 -167
  116. arviz/plots/styles/arviz-bluish.mplstyle +0 -1
  117. arviz/plots/styles/arviz-brownish.mplstyle +0 -1
  118. arviz/plots/styles/arviz-colors.mplstyle +0 -2
  119. arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
  120. arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
  121. arviz/plots/styles/arviz-doc.mplstyle +0 -88
  122. arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
  123. arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
  124. arviz/plots/styles/arviz-greenish.mplstyle +0 -1
  125. arviz/plots/styles/arviz-orangish.mplstyle +0 -1
  126. arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
  127. arviz/plots/styles/arviz-purplish.mplstyle +0 -1
  128. arviz/plots/styles/arviz-redish.mplstyle +0 -1
  129. arviz/plots/styles/arviz-royish.mplstyle +0 -1
  130. arviz/plots/styles/arviz-viridish.mplstyle +0 -1
  131. arviz/plots/styles/arviz-white.mplstyle +0 -40
  132. arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
  133. arviz/plots/traceplot.py +0 -273
  134. arviz/plots/tsplot.py +0 -440
  135. arviz/plots/violinplot.py +0 -192
  136. arviz/preview.py +0 -58
  137. arviz/py.typed +0 -0
  138. arviz/rcparams.py +0 -606
  139. arviz/sel_utils.py +0 -223
  140. arviz/static/css/style.css +0 -340
  141. arviz/static/html/icons-svg-inline.html +0 -15
  142. arviz/stats/__init__.py +0 -37
  143. arviz/stats/density_utils.py +0 -1013
  144. arviz/stats/diagnostics.py +0 -1013
  145. arviz/stats/ecdf_utils.py +0 -324
  146. arviz/stats/stats.py +0 -2422
  147. arviz/stats/stats_refitting.py +0 -119
  148. arviz/stats/stats_utils.py +0 -609
  149. arviz/tests/__init__.py +0 -1
  150. arviz/tests/base_tests/__init__.py +0 -1
  151. arviz/tests/base_tests/test_data.py +0 -1679
  152. arviz/tests/base_tests/test_data_zarr.py +0 -143
  153. arviz/tests/base_tests/test_diagnostics.py +0 -511
  154. arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
  155. arviz/tests/base_tests/test_helpers.py +0 -18
  156. arviz/tests/base_tests/test_labels.py +0 -69
  157. arviz/tests/base_tests/test_plot_utils.py +0 -342
  158. arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
  159. arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
  160. arviz/tests/base_tests/test_rcparams.py +0 -317
  161. arviz/tests/base_tests/test_stats.py +0 -925
  162. arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
  163. arviz/tests/base_tests/test_stats_numba.py +0 -45
  164. arviz/tests/base_tests/test_stats_utils.py +0 -384
  165. arviz/tests/base_tests/test_utils.py +0 -376
  166. arviz/tests/base_tests/test_utils_numba.py +0 -87
  167. arviz/tests/conftest.py +0 -46
  168. arviz/tests/external_tests/__init__.py +0 -1
  169. arviz/tests/external_tests/test_data_beanmachine.py +0 -78
  170. arviz/tests/external_tests/test_data_cmdstan.py +0 -398
  171. arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
  172. arviz/tests/external_tests/test_data_emcee.py +0 -166
  173. arviz/tests/external_tests/test_data_numpyro.py +0 -434
  174. arviz/tests/external_tests/test_data_pyjags.py +0 -119
  175. arviz/tests/external_tests/test_data_pyro.py +0 -260
  176. arviz/tests/external_tests/test_data_pystan.py +0 -307
  177. arviz/tests/helpers.py +0 -677
  178. arviz/utils.py +0 -773
  179. arviz/wrappers/__init__.py +0 -13
  180. arviz/wrappers/base.py +0 -236
  181. arviz/wrappers/wrap_pymc.py +0 -36
  182. arviz/wrappers/wrap_stan.py +0 -148
  183. arviz-0.23.1.dist-info/METADATA +0 -263
  184. arviz-0.23.1.dist-info/RECORD +0 -183
  185. arviz-0.23.1.dist-info/top_level.txt +0 -1
@@ -1,526 +0,0 @@
1
- """Matplotlib traceplot."""
2
-
3
- import warnings
4
- from itertools import cycle
5
-
6
- from matplotlib import gridspec
7
- import matplotlib.pyplot as plt
8
- import numpy as np
9
- from matplotlib.lines import Line2D
10
- import matplotlib.ticker as mticker
11
-
12
- from ....stats.density_utils import get_bins
13
- from ...distplot import plot_dist
14
- from ...plot_utils import _scale_fig_size, format_coords_as_labels
15
- from ...rankplot import plot_rank
16
- from . import backend_kwarg_defaults, backend_show, dealiase_sel_kwargs, matplotlib_kwarg_dealiaser
17
-
18
-
19
def plot_trace(
    data,
    var_names,  # pylint: disable=unused-argument
    divergences,
    kind,
    figsize,
    rug,
    lines,
    circ_var_names,
    circ_var_units,
    compact,
    compact_prop,
    combined,
    chain_prop,
    legend,
    labeller,
    plot_kwargs,
    fill_kwargs,
    rug_kwargs,
    hist_kwargs,
    trace_kwargs,
    rank_kwargs,
    plotters,
    divergence_data,
    axes,
    backend_kwargs,
    backend_config,  # pylint: disable=unused-argument
    show,
):
    """Plot distribution (histogram or kernel density estimates) and sampled values.

    Each plotted variable gets one row with two columns: the distribution on the left and the
    trace (or rank plot) on the right. If `divergences` is truthy, divergent draws are marked
    with black ``|`` markers along the top or bottom edge of each axis (radial tick segments on
    polar axes).

    Parameters
    ----------
    data : obj
        Any object that can be converted to an az.InferenceData object
        Refer to documentation of az.convert_to_dataset for details
    var_names : string, or list of strings
        One or more variables to be plotted. Unused by this backend; selection has already been
        applied to `plotters`.
    divergences : {"bottom", "top", None, False}
        Plot location of divergences on the traceplots. Options are "bottom", "top", or False-y.
    kind : {"trace", "rank_bars", "rank_vlines"}
        Choose between plotting sampled values per iteration and rank plots.
    figsize : figure size tuple
        If None, size is (12, 2 * number of rows).
    rug : bool
        If True adds a rugplot. Defaults to False. Ignored for 2D KDE. Only affects continuous
        variables.
    lines : tuple or list
        List of tuple of (var_name, {'coord': selection}, [line_positions]) to be overplotted as
        vertical lines on the density and horizontal lines on the trace. A warning is emitted
        for entries whose var_name is not being plotted.
    circ_var_names : string, or list of strings
        List of circular variables to account for when plotting KDE.
    circ_var_units : str
        Whether the variables in `circ_var_names` are in "degrees" or "radians".
    compact : bool
        If True, multi-dimensional variables are drawn on a single row, with `compact_prop`
        distinguishing the flattened sub-dimensions.
    compact_prop : str, dict or tuple
        Property cycled across the sub-dimensions of a compact variable. Defaults to "color"
        when `compact` is True. Tuples are deprecated in favour of dicts.
    combined : bool
        Flag for combining multiple chains into a single line. If False (default), chains will be
        plotted separately.
    chain_prop : str, dict or tuple
        Property cycled across chains. Defaults to "color", or to a set of linestyles when
        `compact` is True. Tuples are deprecated in favour of dicts.
    legend : bool
        Add a legend to the figure with the chain color code.
    labeller : obj
        Object whose ``make_label_vert(var_name, selection, isel)`` produces each axis title.
    plot_kwargs : dict
        Extra keyword arguments passed to `arviz.plot_dist`. Only affects continuous variables.
        A ``textsize`` entry (default 10) is removed and used to scale fonts.
    fill_kwargs : dict
        Extra keyword arguments passed to `arviz.plot_dist`. Only affects continuous variables.
    rug_kwargs : dict
        Extra keyword arguments passed to `arviz.plot_dist`. Only affects continuous variables.
    hist_kwargs : dict
        Extra keyword arguments passed to `arviz.plot_dist`. Only affects discrete variables.
    trace_kwargs : dict
        Extra keyword arguments passed to `plt.plot`
    rank_kwargs : dict
        Extra keyword arguments passed to `arviz.plot_rank`
    plotters : sequence
        Sequence of ``(var_name, selection, isel, values)`` tuples, one per row of axes.
    divergence_data : obj
        Object supporting ``.dims`` and ``.sel``; presumably an xarray DataArray of per-draw
        divergence flags indexed by chain and draw (confirm with the caller).
    axes : 2D array of matplotlib axes, optional
        Axes indexed as ``axes[row, column]``. Created when None.
    backend_kwargs : dict, optional
        Passed to ``plt.figure`` when `axes` is None, on top of the backend defaults.
    show : bool, optional
        Whether to call ``plt.show()`` (resolved through ``backend_show``).

    Returns
    -------
    axes : matplotlib axes
        When `axes` is None, a ``(n_rows, 2)`` array of the created axes; otherwise the `axes`
        that were passed in.


    Examples
    --------
    Plot a subset variables

    .. plot::
        :context: close-figs

        >>> import arviz as az
        >>> data = az.load_arviz_data('non_centered_eight')
        >>> coords = {'school': ['Choate', 'Lawrenceville']}
        >>> az.plot_trace(data, var_names=('theta_t', 'theta'), coords=coords)

    Show all dimensions of multidimensional variables in the same plot

    .. plot::
        :context: close-figs

        >>> az.plot_trace(data, compact=True)

    Combine all chains into one distribution

    .. plot::
        :context: close-figs

        >>> az.plot_trace(data, var_names=('theta_t', 'theta'), coords=coords, combined=True)


    Plot reference lines against distribution and trace

    .. plot::
        :context: close-figs

        >>> lines = (('theta_t',{'school': "Choate"}, [-1]),)
        >>> az.plot_trace(data, var_names=('theta_t', 'theta'), coords=coords, lines=lines)

    """
    # Set plot default backend kwargs
    if backend_kwargs is None:
        backend_kwargs = {}
    backend_kwargs = {**backend_kwarg_defaults(), **backend_kwargs}

    if circ_var_names is None:
        circ_var_names = []
    if lines is None:
        lines = ()

    # The extra chain property slot (index -1) is used for the "combined" distribution.
    num_chain_props = len(data.chain) + 1 if combined else len(data.chain)
    if compact:
        if chain_prop is None:
            chain_prop = {"linestyle": ("solid", "dotted", "dashed", "dashdot")}
        if compact_prop is None:
            compact_prop = "color"
    elif chain_prop is None:
        chain_prop = "color"

    chain_prop = _prop_to_dict(chain_prop, "chain_prop")
    # Repeat each property sequence so there is one entry per chain (plus "combined").
    chain_prop = {
        prop_name: [prop for _, prop in zip(range(num_chain_props), cycle(props))]
        for prop_name, props in chain_prop.items()
    }
    compact_prop = _prop_to_dict(compact_prop, "compact_prop")

    if figsize is None:
        figsize = (12, len(plotters) * 2)
    backend_kwargs.setdefault("figsize", figsize)

    trace_kwargs = matplotlib_kwarg_dealiaser(trace_kwargs, "plot")
    trace_kwargs.setdefault("alpha", 0.35)
    hist_kwargs = matplotlib_kwarg_dealiaser(hist_kwargs, "hist")
    hist_kwargs.setdefault("alpha", 0.35)
    plot_kwargs = matplotlib_kwarg_dealiaser(plot_kwargs, "plot")
    fill_kwargs = matplotlib_kwarg_dealiaser(fill_kwargs, "fill_between")
    rug_kwargs = matplotlib_kwarg_dealiaser(rug_kwargs, "scatter")
    rank_kwargs = matplotlib_kwarg_dealiaser(rank_kwargs, "bar")
    if compact:
        rank_kwargs.setdefault("bar_kwargs", {})
        rank_kwargs["bar_kwargs"].setdefault("alpha", 0.2)

    textsize = plot_kwargs.pop("textsize", 10)
    figsize, _, titlesize, xt_labelsize, linewidth, _ = _scale_fig_size(
        figsize, textsize, rows=len(plotters), cols=2
    )
    trace_kwargs.setdefault("linewidth", linewidth)
    plot_kwargs.setdefault("linewidth", linewidth)

    _warn_on_unknown_line_vars(lines, plotters)

    if axes is None:
        fig = plt.figure(**backend_kwargs)
        spec = gridspec.GridSpec(ncols=2, nrows=len(plotters), figure=fig)

    # pylint: disable=too-many-nested-blocks
    for idx, (var_name, selection, isel, value) in enumerate(plotters):
        # Everything below is identical for both columns, so compute it once per row.
        value = np.atleast_2d(value)
        is_circ_var = var_name in circ_var_names
        legend_labels = legend_title = compact_prop_iter = None
        if value.ndim != 2:
            # Compact variable: flatten trailing dims to (chain, draw, n_sub).
            value = value.reshape((value.shape[0], value.shape[1], -1))
            sub_data = data[var_name].sel(**selection)
            legend_labels = format_coords_as_labels(sub_data, skip_dims=("chain", "draw"))
            legend_title = ", ".join(
                f"{coord_name}"
                for coord_name in sub_data.coords
                if coord_name not in {"chain", "draw"}
            )
            compact_prop_iter = {
                prop_name: [prop for _, prop in zip(range(value.shape[2]), cycle(props))]
                for prop_name, props in compact_prop.items()
            }

        for idy in range(2):
            # Column 0 is the distribution (polar for circular variables), column 1 the trace.
            circular = is_circ_var and not idy
            circ_units_trace = circ_var_units if is_circ_var and idy else False

            if axes is None:
                ax = fig.add_subplot(spec[idx, idy], polar=circular)
            else:
                ax = axes[idx, idy]

            if value.ndim == 2:
                if compact_prop:
                    aux_plot_kwargs = dealiase_sel_kwargs(plot_kwargs, compact_prop, 0)
                    aux_trace_kwargs = dealiase_sel_kwargs(trace_kwargs, compact_prop, 0)
                else:
                    aux_plot_kwargs = plot_kwargs
                    aux_trace_kwargs = trace_kwargs
                ax = _plot_chains_mpl(
                    ax,
                    idy,
                    value,
                    data,
                    chain_prop,
                    combined,
                    xt_labelsize,
                    rug,
                    kind,
                    aux_trace_kwargs,
                    hist_kwargs,
                    aux_plot_kwargs,
                    fill_kwargs,
                    rug_kwargs,
                    rank_kwargs,
                    circular,
                    circ_var_units,
                    circ_units_trace,
                )
            else:
                handles = []
                for sub_idx, label in zip(range(value.shape[2]), legend_labels):
                    aux_plot_kwargs = dealiase_sel_kwargs(plot_kwargs, compact_prop_iter, sub_idx)
                    aux_trace_kwargs = dealiase_sel_kwargs(trace_kwargs, compact_prop_iter, sub_idx)
                    ax = _plot_chains_mpl(
                        ax,
                        idy,
                        value[..., sub_idx],
                        data,
                        chain_prop,
                        combined,
                        xt_labelsize,
                        rug,
                        kind,
                        aux_trace_kwargs,
                        hist_kwargs,
                        aux_plot_kwargs,
                        fill_kwargs,
                        rug_kwargs,
                        rank_kwargs,
                        circular,
                        circ_var_units,
                        circ_units_trace,
                    )
                    if legend:
                        handles.append(
                            Line2D(
                                [],
                                [],
                                label=label,
                                **dealiase_sel_kwargs(aux_plot_kwargs, chain_prop, 0),
                            )
                        )
                if legend and idy == 0:
                    ax.legend(handles=handles, title=legend_title)

            # Integer-valued draws: put x ticks on the histogram bin edges.
            if value[0].dtype.kind == "i" and idy == 0:
                ax.set_xticks(get_bins(value)[:-1])

            if not idy:
                ax.set_yticks([])
            # Polar axes need the title pushed further out; otherwise the title sits at y=1.
            title_frac = (0.13 if selection else 0.12) if circular else 1 / textsize
            ax.set_title(
                labeller.make_label_vert(var_name, selection, isel),
                fontsize=titlesize,
                wrap=True,
                y=textsize * title_frac,
            )
            ax.tick_params(labelsize=xt_labelsize)

            # Capture limits before adding markers/lines so they do not rescale the axis.
            xlims = ax.get_xlim()
            ylims = ax.get_ylim()

            if divergences:
                _plot_divergences_mpl(
                    ax,
                    idy,
                    value,
                    data,
                    selection,
                    divergence_data,
                    divergences,
                    kind,
                    circular,
                    ylims,
                    trace_kwargs["alpha"],
                    hist_kwargs["alpha"],
                )

            _plot_reference_lines_mpl(
                ax, idy, var_name, selection, lines, xlims, ylims, trace_kwargs["alpha"]
            )

            if kind == "trace" and idy:
                ax.set_xlim(left=data.draw.min(), right=data.draw.max())

    if legend:
        legend_kwargs = trace_kwargs if combined else plot_kwargs
        handles = [
            Line2D(
                [], [], label=chain_id, **dealiase_sel_kwargs(legend_kwargs, chain_prop, chain_id)
            )
            for chain_id in range(data.sizes["chain"])
        ]
        if combined:
            handles.insert(
                0,
                Line2D(
                    [], [], label="combined", **dealiase_sel_kwargs(plot_kwargs, chain_prop, -1)
                ),
            )
        ax.figure.axes[1].legend(handles=handles, title="chain", loc="upper right")

    if axes is None:
        axes = np.array(ax.figure.axes).reshape(-1, 2)

    if backend_show(show):
        plt.show()

    return axes


def _prop_to_dict(prop, arg_name):
    """Normalise a ``chain_prop``/``compact_prop`` argument to a ``{prop_name: values}`` dict.

    A string is expanded to the values of that property in the active matplotlib prop cycle.
    A ``(name, values)`` tuple is converted with a FutureWarning. Anything else (a dict or
    None) is returned unchanged.
    """
    if isinstance(prop, str):
        return {prop: plt.rcParams["axes.prop_cycle"].by_key()[prop]}
    if isinstance(prop, tuple):
        warnings.warn(
            f"{arg_name} as a tuple will be deprecated in a future warning, use a dict instead",
            FutureWarning,
        )
        return {prop[0]: prop[1]}
    return prop


def _warn_on_unknown_line_vars(lines, plotters):
    """Warn when a reference line names a variable that is not among the plotted ones."""
    all_var_names = {plotter[0] for plotter in plotters}
    invalid_var_names = {line[0] for line in lines if line[0] not in all_var_names}
    if invalid_var_names:
        warnings.warn(
            f"A valid var_name should be provided, found {invalid_var_names} "
            f"expected from {all_var_names}"
        )


def _plot_divergences_mpl(
    ax,
    idy,
    value,
    data,
    selection,
    divergence_data,
    divergences,
    kind,
    circular,
    ylims,
    trace_alpha,
    hist_alpha,
):
    """Mark divergent draws on one axis.

    The distribution column gets ``|`` markers at the divergent values. The trace column
    (``kind="trace"`` only) gets ``|`` markers at the divergent draw indices. Polar axes get a
    radial segment per divergent value. Markers sit at the bottom edge, or the top edge when
    ``divergences == "top"``.
    """
    div_selection = {k: v for k, v in selection.items() if k in divergence_data.dims}
    divs = np.atleast_2d(divergence_data.sel(**div_selection).values)
    yloc = ylims[1] if divergences == "top" else ylims[0]
    marker_style = {"color": "black", "markeredgewidth": 1.5, "markersize": 30, "zorder": 0.6}

    for chain, chain_divs in enumerate(divs):
        div_draws = data.draw.values[chain_divs]
        div_idxs = np.arange(len(chain_divs))[chain_divs]
        if div_idxs.size == 0:
            continue
        values = value[chain, div_idxs]

        if circular:
            tick = [ax.get_rmin() + ax.get_rmax() * 0.60, ax.get_rmax()]
            for val in values:
                ax.plot([val, val], tick, alpha=trace_alpha, **marker_style)
        elif kind == "trace" and idy:
            # NOTE(review): trace-column markers take their alpha from hist_kwargs while the
            # distribution column uses trace_kwargs; kept as-is to preserve existing behavior.
            ax.plot(
                div_draws,
                np.zeros_like(div_idxs) + yloc,
                marker="|",
                linestyle="None",
                alpha=hist_alpha,
                **marker_style,
            )
        elif not idy:
            ax.plot(
                values,
                np.zeros_like(values) + yloc,
                marker="|",
                linestyle="None",
                alpha=trace_alpha,
                **marker_style,
            )


def _plot_reference_lines_mpl(ax, idy, var_name, selection, lines, xlims, ylims, alpha):
    """Overplot reference lines for ``(var_name, selection)``.

    Lines are vertical on the distribution column and horizontal on the trace column.

    Raises
    ------
    ValueError
        If a non-scalar line position is not numeric.
    """
    line_style = {"colors": "black", "linewidth": 1.5, "alpha": alpha}
    for _, _, positions in (line for line in lines if line[0] == var_name and line[1] == selection):
        if isinstance(positions, (float, int)):
            line_values = [positions]
        else:
            line_values = np.atleast_1d(positions).ravel()
            if not np.issubdtype(line_values.dtype, np.number):
                raise ValueError(f"line-positions should be numeric, found {line_values}")
        if idy:
            ax.hlines(line_values, xlims[0], xlims[1], **line_style)
        else:
            ax.vlines(line_values, ylims[0], ylims[1], **line_style)
452
-
453
-
454
def _plot_chains_mpl(
    axes,
    idy,
    value,
    data,
    chain_prop,
    combined,
    xt_labelsize,
    rug,
    kind,
    trace_kwargs,
    hist_kwargs,
    plot_kwargs,
    fill_kwargs,
    rug_kwargs,
    rank_kwargs,
    circular,
    circ_var_units,
    circ_units_trace,
):
    """Draw every chain of one variable onto a single axis and return that axis.

    ``value`` is iterated row by row, one row per chain (as passed in by ``plot_trace``).

    * ``idy == 0`` (distribution column): one ``plot_dist`` per chain. When ``combined`` is
      set, a single ``plot_dist`` of the pooled draws is drawn instead.
    * ``idy != 0`` (trace column): a line per chain against ``data.draw`` for
      ``kind == "trace"``, or one ``plot_rank`` call for ``"rank_bars"``/``"rank_vlines"``.

    ``circ_var_units`` is forwarded to ``plot_dist`` only when ``circular`` is truthy.
    ``circ_units_trace == "degrees"`` relabels the trace y ticks in degrees.
    """
    dist_units = circ_var_units if circular else False
    on_trace_column = bool(idy)
    shared_dist_kwargs = {
        "textsize": xt_labelsize,
        "rug": rug,
        "hist_kwargs": hist_kwargs,
        "fill_kwargs": fill_kwargs,
        "rug_kwargs": rug_kwargs,
        "backend": "matplotlib",
        "show": False,
        "is_circular": dist_units,
    }

    for chain_idx, chain_draws in enumerate(value):
        if kind == "trace":
            line_kwargs = dealiase_sel_kwargs(trace_kwargs, chain_prop, chain_idx)
            if on_trace_column:
                axes.plot(data.draw.values, chain_draws, **line_kwargs)
                if circ_units_trace == "degrees":
                    # Pin the current (radian) tick positions, then label them in degrees,
                    # shifting negative angles up by a full turn.
                    tick_locs = axes.get_yticks()
                    tick_degrees = np.rad2deg(tick_locs)
                    axes.yaxis.set_major_locator(mticker.FixedLocator(tick_locs))
                    axes.set_yticklabels(
                        [f"{(deg + 360 if deg < 0 else deg):.0f}°" for deg in tick_degrees]
                    )

        if not combined:
            chain_dist_kwargs = dealiase_sel_kwargs(plot_kwargs, chain_prop, chain_idx)
            if not on_trace_column:
                axes = plot_dist(
                    values=chain_draws,
                    ax=axes,
                    plot_kwargs=chain_dist_kwargs,
                    **shared_dist_kwargs,
                )

    if on_trace_column and kind in ("rank_bars", "rank_vlines"):
        rank_style = "bars" if kind == "rank_bars" else "vlines"
        axes = plot_rank(data=value, kind=rank_style, ax=axes, **rank_kwargs)

    if combined:
        # Index -1 is the extra property slot reserved for the pooled distribution.
        pooled_kwargs = dealiase_sel_kwargs(plot_kwargs, chain_prop, -1)
        if not on_trace_column:
            axes = plot_dist(
                values=value.flatten(),
                ax=axes,
                plot_kwargs=pooled_kwargs,
                **shared_dist_kwargs,
            )

    return axes
@@ -1,121 +0,0 @@
1
- """Matplotlib plot time series figure."""
2
-
3
- import matplotlib.pyplot as plt
4
- import numpy as np
5
-
6
- from ...plot_utils import _scale_fig_size
7
- from . import create_axes_grid, backend_show, matplotlib_kwarg_dealiaser, backend_kwarg_defaults
8
-
9
-
10
def plot_ts(
    x_plotters,
    y_plotters,
    y_mean_plotters,
    y_hat_plotters,
    y_holdout_plotters,
    x_holdout_plotters,
    y_forecasts_plotters,
    y_forecasts_mean_plotters,
    num_samples,
    backend_kwargs,
    y_kwargs,
    y_hat_plot_kwargs,
    y_mean_plot_kwargs,
    vline_kwargs,
    length_plotters,
    rows,
    cols,
    textsize,
    figsize,
    legend,
    axes,
    show,
):
    """Matplotlib time series.

    Each ``*_plotters`` argument is indexed per panel, and the last element of each entry is
    the array to draw. For ``x_plotters`` and ``y_plotters`` the first element is also used as
    the axis label. On each of the first ``length_plotters`` axes, the function draws:

    * the observed series ``y`` against ``x``;
    * if ``y_hat_plotters`` is given, ``num_samples`` fitted samples (``y_hat[..., j]``) and
      the fitted mean from ``y_mean_plotters``;
    * if ``y_holdout_plotters`` is given, the held-out observations against
      ``x_holdout_plotters``;
    * if ``y_forecasts_plotters`` is given, ``num_samples`` forecast samples and the forecast
      mean from ``y_forecasts_mean_plotters``, also against ``x_holdout_plotters``.

    A single vertical line marks the last fitted x value whenever holdout or forecast data is
    present.

    Returns
    -------
    axes
        The axes passed in, or the grid created by ``create_axes_grid`` when `axes` is None.
    """
    if backend_kwargs is None:
        backend_kwargs = {}
    backend_kwargs = {
        **backend_kwarg_defaults(),
        **backend_kwargs,
    }

    if figsize is None:
        figsize = (12, rows * 5)
    backend_kwargs.setdefault("figsize", figsize)
    backend_kwargs.setdefault("squeeze", False)

    # Only the tick-label size is needed; the figure size was already fixed above.
    _, _, _, xt_labelsize, _, _ = _scale_fig_size(figsize, textsize, rows, cols)

    if axes is None:
        _, axes = create_axes_grid(length_plotters, rows, cols, backend_kwargs=backend_kwargs)

    y_kwargs = matplotlib_kwarg_dealiaser(y_kwargs, "plot")
    y_kwargs.setdefault("color", "blue")
    y_kwargs.setdefault("zorder", 10)

    y_hat_plot_kwargs = matplotlib_kwarg_dealiaser(y_hat_plot_kwargs, "plot")
    y_hat_plot_kwargs.setdefault("color", "grey")
    y_hat_plot_kwargs.setdefault("alpha", 0.1)

    y_mean_plot_kwargs = matplotlib_kwarg_dealiaser(y_mean_plot_kwargs, "plot")
    y_mean_plot_kwargs.setdefault("color", "red")
    y_mean_plot_kwargs.setdefault("linestyle", "dashed")

    vline_kwargs = matplotlib_kwarg_dealiaser(vline_kwargs, "plot")
    vline_kwargs.setdefault("color", "black")
    vline_kwargs.setdefault("linestyle", "dashed")

    has_holdout = y_holdout_plotters is not None
    has_forecasts = y_forecasts_plotters is not None

    for i, ax_i in enumerate(np.ravel(axes)[:length_plotters]):
        y_var_name, _, _, y_values = y_plotters[i]
        x_var_name, _, _, x_values = x_plotters[i]

        ax_i.plot(x_values, y_values, **y_kwargs)
        # Empty proxy artists so the legend gets one entry per series type.
        ax_i.plot([], label="Actual", **y_kwargs)
        if y_hat_plotters is not None or has_forecasts:
            ax_i.plot([], label="Fitted", **y_mean_plot_kwargs)
            ax_i.plot([], label="Uncertainty", **y_hat_plot_kwargs)

        ax_i.set_xlabel(x_var_name)
        ax_i.set_ylabel(y_var_name)

        if y_hat_plotters is not None:
            *_, y_hat_values = y_hat_plotters[i]
            for j in range(num_samples):
                ax_i.plot(x_values, y_hat_values[..., j], **y_hat_plot_kwargs)

            *_, y_mean_values = y_mean_plotters[i]
            ax_i.plot(x_values, y_mean_values, **y_mean_plot_kwargs)

        if has_holdout:
            *_, y_holdout_values = y_holdout_plotters[i]
            *_, x_holdout_values = x_holdout_plotters[i]
            ax_i.plot(x_holdout_values, y_holdout_values, **y_kwargs)

        if has_forecasts:
            *_, y_forecast_values = y_forecasts_plotters[i]
            *_, x_forecast_values = x_holdout_plotters[i]
            for j in range(num_samples):
                ax_i.plot(x_forecast_values, y_forecast_values[..., j], **y_hat_plot_kwargs)

            *_, y_forecast_mean = y_forecasts_mean_plotters[i]
            ax_i.plot(x_forecast_values, y_forecast_mean, **y_mean_plot_kwargs)

        # One separator at the end of the fitted range. It used to be drawn twice when both
        # holdout and forecast data were present.
        if has_holdout or has_forecasts:
            ax_i.axvline(x_values[-1], **vline_kwargs)

        if legend:
            ax_i.legend(fontsize=xt_labelsize, loc="upper left")

    if backend_show(show):
        plt.show()

    return axes