arviz 0.17.1__py3-none-any.whl → 0.19.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arviz/__init__.py +4 -2
- arviz/data/__init__.py +5 -2
- arviz/data/base.py +102 -11
- arviz/data/converters.py +5 -0
- arviz/data/datasets.py +1 -0
- arviz/data/example_data/data_remote.json +10 -3
- arviz/data/inference_data.py +20 -22
- arviz/data/io_cmdstan.py +5 -3
- arviz/data/io_datatree.py +1 -0
- arviz/data/io_dict.py +5 -3
- arviz/data/io_emcee.py +1 -0
- arviz/data/io_numpyro.py +2 -1
- arviz/data/io_pyjags.py +1 -0
- arviz/data/io_pyro.py +1 -0
- arviz/data/utils.py +1 -0
- arviz/plots/__init__.py +1 -0
- arviz/plots/autocorrplot.py +1 -0
- arviz/plots/backends/bokeh/autocorrplot.py +1 -0
- arviz/plots/backends/bokeh/bpvplot.py +1 -0
- arviz/plots/backends/bokeh/compareplot.py +1 -0
- arviz/plots/backends/bokeh/densityplot.py +1 -0
- arviz/plots/backends/bokeh/distplot.py +1 -0
- arviz/plots/backends/bokeh/dotplot.py +1 -0
- arviz/plots/backends/bokeh/ecdfplot.py +2 -2
- arviz/plots/backends/bokeh/elpdplot.py +1 -0
- arviz/plots/backends/bokeh/energyplot.py +1 -0
- arviz/plots/backends/bokeh/hdiplot.py +1 -0
- arviz/plots/backends/bokeh/kdeplot.py +3 -3
- arviz/plots/backends/bokeh/khatplot.py +9 -3
- arviz/plots/backends/bokeh/lmplot.py +1 -0
- arviz/plots/backends/bokeh/loopitplot.py +1 -0
- arviz/plots/backends/bokeh/mcseplot.py +1 -0
- arviz/plots/backends/bokeh/pairplot.py +3 -6
- arviz/plots/backends/bokeh/parallelplot.py +1 -0
- arviz/plots/backends/bokeh/posteriorplot.py +1 -0
- arviz/plots/backends/bokeh/ppcplot.py +1 -0
- arviz/plots/backends/bokeh/rankplot.py +1 -0
- arviz/plots/backends/bokeh/separationplot.py +1 -0
- arviz/plots/backends/bokeh/traceplot.py +1 -0
- arviz/plots/backends/bokeh/violinplot.py +1 -0
- arviz/plots/backends/matplotlib/autocorrplot.py +1 -0
- arviz/plots/backends/matplotlib/bpvplot.py +1 -0
- arviz/plots/backends/matplotlib/compareplot.py +1 -0
- arviz/plots/backends/matplotlib/densityplot.py +1 -0
- arviz/plots/backends/matplotlib/distcomparisonplot.py +2 -3
- arviz/plots/backends/matplotlib/distplot.py +1 -0
- arviz/plots/backends/matplotlib/dotplot.py +1 -0
- arviz/plots/backends/matplotlib/ecdfplot.py +2 -2
- arviz/plots/backends/matplotlib/elpdplot.py +1 -0
- arviz/plots/backends/matplotlib/energyplot.py +1 -0
- arviz/plots/backends/matplotlib/essplot.py +6 -5
- arviz/plots/backends/matplotlib/forestplot.py +1 -0
- arviz/plots/backends/matplotlib/hdiplot.py +1 -0
- arviz/plots/backends/matplotlib/kdeplot.py +5 -3
- arviz/plots/backends/matplotlib/khatplot.py +8 -3
- arviz/plots/backends/matplotlib/lmplot.py +1 -0
- arviz/plots/backends/matplotlib/loopitplot.py +1 -0
- arviz/plots/backends/matplotlib/mcseplot.py +11 -10
- arviz/plots/backends/matplotlib/pairplot.py +2 -1
- arviz/plots/backends/matplotlib/parallelplot.py +1 -0
- arviz/plots/backends/matplotlib/posteriorplot.py +1 -0
- arviz/plots/backends/matplotlib/ppcplot.py +1 -0
- arviz/plots/backends/matplotlib/rankplot.py +1 -0
- arviz/plots/backends/matplotlib/separationplot.py +1 -0
- arviz/plots/backends/matplotlib/traceplot.py +2 -1
- arviz/plots/backends/matplotlib/tsplot.py +1 -0
- arviz/plots/backends/matplotlib/violinplot.py +2 -1
- arviz/plots/bpvplot.py +3 -2
- arviz/plots/compareplot.py +1 -0
- arviz/plots/densityplot.py +2 -1
- arviz/plots/distcomparisonplot.py +1 -0
- arviz/plots/dotplot.py +3 -2
- arviz/plots/ecdfplot.py +206 -89
- arviz/plots/elpdplot.py +1 -0
- arviz/plots/energyplot.py +1 -0
- arviz/plots/essplot.py +3 -2
- arviz/plots/forestplot.py +2 -1
- arviz/plots/hdiplot.py +3 -2
- arviz/plots/khatplot.py +24 -6
- arviz/plots/lmplot.py +1 -0
- arviz/plots/loopitplot.py +3 -2
- arviz/plots/mcseplot.py +4 -1
- arviz/plots/pairplot.py +1 -0
- arviz/plots/parallelplot.py +1 -0
- arviz/plots/plot_utils.py +3 -4
- arviz/plots/posteriorplot.py +2 -1
- arviz/plots/ppcplot.py +1 -0
- arviz/plots/rankplot.py +3 -2
- arviz/plots/separationplot.py +1 -0
- arviz/plots/traceplot.py +1 -0
- arviz/plots/tsplot.py +1 -0
- arviz/plots/violinplot.py +2 -1
- arviz/preview.py +17 -0
- arviz/rcparams.py +28 -2
- arviz/sel_utils.py +1 -0
- arviz/static/css/style.css +2 -1
- arviz/stats/density_utils.py +2 -1
- arviz/stats/diagnostics.py +15 -11
- arviz/stats/ecdf_utils.py +12 -8
- arviz/stats/stats.py +31 -16
- arviz/stats/stats_refitting.py +1 -0
- arviz/stats/stats_utils.py +13 -7
- arviz/tests/base_tests/test_data.py +15 -2
- arviz/tests/base_tests/test_data_zarr.py +0 -1
- arviz/tests/base_tests/test_diagnostics.py +1 -0
- arviz/tests/base_tests/test_diagnostics_numba.py +2 -6
- arviz/tests/base_tests/test_helpers.py +2 -2
- arviz/tests/base_tests/test_labels.py +1 -0
- arviz/tests/base_tests/test_plot_utils.py +5 -13
- arviz/tests/base_tests/test_plots_matplotlib.py +98 -7
- arviz/tests/base_tests/test_rcparams.py +12 -0
- arviz/tests/base_tests/test_stats.py +5 -5
- arviz/tests/base_tests/test_stats_numba.py +2 -7
- arviz/tests/base_tests/test_stats_utils.py +1 -0
- arviz/tests/base_tests/test_utils.py +3 -2
- arviz/tests/base_tests/test_utils_numba.py +2 -5
- arviz/tests/external_tests/test_data_pystan.py +5 -5
- arviz/tests/helpers.py +18 -10
- arviz/utils.py +4 -0
- arviz/wrappers/__init__.py +1 -0
- {arviz-0.17.1.dist-info → arviz-0.19.0.dist-info}/METADATA +13 -9
- arviz-0.19.0.dist-info/RECORD +183 -0
- arviz-0.17.1.dist-info/RECORD +0 -182
- {arviz-0.17.1.dist-info → arviz-0.19.0.dist-info}/LICENSE +0 -0
- {arviz-0.17.1.dist-info → arviz-0.19.0.dist-info}/WHEEL +0 -0
- {arviz-0.17.1.dist-info → arviz-0.19.0.dist-info}/top_level.txt +0 -0
arviz/__init__.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
# pylint: disable=wildcard-import,invalid-name,wrong-import-position
|
|
2
2
|
"""ArviZ is a library for exploratory analysis of Bayesian models."""
|
|
3
|
-
__version__ = "0.
|
|
3
|
+
__version__ = "0.19.0"
|
|
4
4
|
|
|
5
5
|
import logging
|
|
6
6
|
import os
|
|
@@ -37,6 +37,7 @@ from .stats import *
|
|
|
37
37
|
from .rcparams import rc_context, rcParams
|
|
38
38
|
from .utils import Numba, Dask, interactive_backend
|
|
39
39
|
from .wrappers import *
|
|
40
|
+
from . import preview
|
|
40
41
|
|
|
41
42
|
# add ArviZ's styles to matplotlib's styles
|
|
42
43
|
_arviz_style_path = os.path.join(os.path.dirname(__file__), "plots", "styles")
|
|
@@ -315,7 +316,8 @@ _linear_grey_10_95_c0 = [
|
|
|
315
316
|
|
|
316
317
|
def _mpl_cm(name, colorlist):
|
|
317
318
|
cmap = LinearSegmentedColormap.from_list(name, colorlist, N=256)
|
|
318
|
-
|
|
319
|
+
if "cet_" + name not in mpl.colormaps():
|
|
320
|
+
mpl.colormaps.register(cmap, name="cet_" + name)
|
|
319
321
|
|
|
320
322
|
|
|
321
323
|
try:
|
arviz/data/__init__.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
"""Code for loading and manipulating data structures."""
|
|
2
|
-
|
|
2
|
+
|
|
3
|
+
from .base import CoordSpec, DimSpec, dict_to_dataset, numpy_to_data_array, pytree_to_dataset
|
|
3
4
|
from .converters import convert_to_dataset, convert_to_inference_data
|
|
4
5
|
from .datasets import clear_data_home, list_datasets, load_arviz_data
|
|
5
6
|
from .inference_data import InferenceData, concat
|
|
@@ -7,7 +8,7 @@ from .io_beanmachine import from_beanmachine
|
|
|
7
8
|
from .io_cmdstan import from_cmdstan
|
|
8
9
|
from .io_cmdstanpy import from_cmdstanpy
|
|
9
10
|
from .io_datatree import from_datatree, to_datatree
|
|
10
|
-
from .io_dict import from_dict
|
|
11
|
+
from .io_dict import from_dict, from_pytree
|
|
11
12
|
from .io_emcee import from_emcee
|
|
12
13
|
from .io_json import from_json, to_json
|
|
13
14
|
from .io_netcdf import from_netcdf, to_netcdf
|
|
@@ -38,10 +39,12 @@ __all__ = [
|
|
|
38
39
|
"from_cmdstanpy",
|
|
39
40
|
"from_datatree",
|
|
40
41
|
"from_dict",
|
|
42
|
+
"from_pytree",
|
|
41
43
|
"from_json",
|
|
42
44
|
"from_pyro",
|
|
43
45
|
"from_numpyro",
|
|
44
46
|
"from_netcdf",
|
|
47
|
+
"pytree_to_dataset",
|
|
45
48
|
"to_datatree",
|
|
46
49
|
"to_json",
|
|
47
50
|
"to_netcdf",
|
arviz/data/base.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
"""Low level converters usually used by other functions."""
|
|
2
|
+
|
|
2
3
|
import datetime
|
|
3
4
|
import functools
|
|
4
5
|
import importlib
|
|
@@ -8,6 +9,7 @@ from copy import deepcopy
|
|
|
8
9
|
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
|
|
9
10
|
|
|
10
11
|
import numpy as np
|
|
12
|
+
import tree
|
|
11
13
|
import xarray as xr
|
|
12
14
|
|
|
13
15
|
try:
|
|
@@ -67,6 +69,48 @@ class requires: # pylint: disable=invalid-name
|
|
|
67
69
|
return wrapped
|
|
68
70
|
|
|
69
71
|
|
|
72
|
+
def _yield_flat_up_to(shallow_tree, input_tree, path=()):
|
|
73
|
+
"""Yields (path, value) pairs of input_tree flattened up to shallow_tree.
|
|
74
|
+
|
|
75
|
+
Adapted from dm-tree (https://github.com/google-deepmind/tree) to allow
|
|
76
|
+
lists as leaves.
|
|
77
|
+
|
|
78
|
+
Args:
|
|
79
|
+
shallow_tree: Nested structure. Traverse no further than its leaf nodes.
|
|
80
|
+
input_tree: Nested structure. Return the paths and values from this tree.
|
|
81
|
+
Must have the same upper structure as shallow_tree.
|
|
82
|
+
path: Tuple. Optional argument, only used when recursing. The path from the
|
|
83
|
+
root of the original shallow_tree, down to the root of the shallow_tree
|
|
84
|
+
arg of this recursive call.
|
|
85
|
+
|
|
86
|
+
Yields:
|
|
87
|
+
Pairs of (path, value), where path the tuple path of a leaf node in
|
|
88
|
+
shallow_tree, and value is the value of the corresponding node in
|
|
89
|
+
input_tree.
|
|
90
|
+
"""
|
|
91
|
+
# pylint: disable=protected-access
|
|
92
|
+
if isinstance(shallow_tree, tree._TEXT_OR_BYTES) or not (
|
|
93
|
+
isinstance(shallow_tree, tree.collections_abc.Mapping)
|
|
94
|
+
or tree._is_namedtuple(shallow_tree)
|
|
95
|
+
or tree._is_attrs(shallow_tree)
|
|
96
|
+
):
|
|
97
|
+
yield (path, input_tree)
|
|
98
|
+
else:
|
|
99
|
+
input_tree = dict(tree._yield_sorted_items(input_tree))
|
|
100
|
+
for shallow_key, shallow_subtree in tree._yield_sorted_items(shallow_tree):
|
|
101
|
+
subpath = path + (shallow_key,)
|
|
102
|
+
input_subtree = input_tree[shallow_key]
|
|
103
|
+
for leaf_path, leaf_value in _yield_flat_up_to(
|
|
104
|
+
shallow_subtree, input_subtree, path=subpath
|
|
105
|
+
):
|
|
106
|
+
yield (leaf_path, leaf_value)
|
|
107
|
+
# pylint: enable=protected-access
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def _flatten_with_path(structure):
|
|
111
|
+
return list(_yield_flat_up_to(structure, structure))
|
|
112
|
+
|
|
113
|
+
|
|
70
114
|
def generate_dims_coords(
|
|
71
115
|
shape,
|
|
72
116
|
var_name,
|
|
@@ -255,7 +299,7 @@ def numpy_to_data_array(
|
|
|
255
299
|
return xr.DataArray(ary, coords=coords, dims=dims)
|
|
256
300
|
|
|
257
301
|
|
|
258
|
-
def
|
|
302
|
+
def pytree_to_dataset(
|
|
259
303
|
data,
|
|
260
304
|
*,
|
|
261
305
|
attrs=None,
|
|
@@ -266,26 +310,29 @@ def dict_to_dataset(
|
|
|
266
310
|
index_origin=None,
|
|
267
311
|
skip_event_dims=None,
|
|
268
312
|
):
|
|
269
|
-
"""Convert a dictionary of numpy arrays to an xarray.Dataset.
|
|
313
|
+
"""Convert a dictionary or pytree of numpy arrays to an xarray.Dataset.
|
|
314
|
+
|
|
315
|
+
See https://jax.readthedocs.io/en/latest/pytrees.html for what a pytree is, but
|
|
316
|
+
this inclues at least dictionaries and tuple types.
|
|
270
317
|
|
|
271
318
|
Parameters
|
|
272
319
|
----------
|
|
273
|
-
data : dict
|
|
320
|
+
data : dict of {str : array_like or dict} or pytree
|
|
274
321
|
Data to convert. Keys are variable names.
|
|
275
|
-
attrs : dict
|
|
322
|
+
attrs : dict, optional
|
|
276
323
|
Json serializable metadata to attach to the dataset, in addition to defaults.
|
|
277
|
-
library : module
|
|
324
|
+
library : module, optional
|
|
278
325
|
Library used for performing inference. Will be attached to the attrs metadata.
|
|
279
|
-
coords : dict
|
|
326
|
+
coords : dict of {str : ndarray}, optional
|
|
280
327
|
Coordinates for the dataset
|
|
281
|
-
dims : dict
|
|
328
|
+
dims : dict of {str : list of str}, optional
|
|
282
329
|
Dimensions of each variable. The keys are variable names, values are lists of
|
|
283
330
|
coordinates.
|
|
284
331
|
default_dims : list of str, optional
|
|
285
332
|
Passed to :py:func:`numpy_to_data_array`
|
|
286
333
|
index_origin : int, optional
|
|
287
334
|
Passed to :py:func:`numpy_to_data_array`
|
|
288
|
-
skip_event_dims : bool
|
|
335
|
+
skip_event_dims : bool, optional
|
|
289
336
|
If True, cut extra dims whenever present to match the shape of the data.
|
|
290
337
|
Necessary for PPLs which have the same name in both observed data and log
|
|
291
338
|
likelihood groups, to account for their different shapes when observations are
|
|
@@ -293,15 +340,56 @@ def dict_to_dataset(
|
|
|
293
340
|
|
|
294
341
|
Returns
|
|
295
342
|
-------
|
|
296
|
-
|
|
343
|
+
xarray.Dataset
|
|
344
|
+
In case of nested pytrees, the variable name will be a tuple of individual names.
|
|
345
|
+
|
|
346
|
+
Notes
|
|
347
|
+
-----
|
|
348
|
+
This function is available through two aliases: ``dict_to_dataset`` or ``pytree_to_dataset``.
|
|
297
349
|
|
|
298
350
|
Examples
|
|
299
351
|
--------
|
|
300
|
-
|
|
352
|
+
Convert a dictionary with two 2D variables to a Dataset.
|
|
353
|
+
|
|
354
|
+
.. ipython::
|
|
355
|
+
|
|
356
|
+
In [1]: import arviz as az
|
|
357
|
+
...: import numpy as np
|
|
358
|
+
...: az.dict_to_dataset({'x': np.random.randn(4, 100), 'y': np.random.rand(4, 100)})
|
|
359
|
+
|
|
360
|
+
Note that unlike the :class:`xarray.Dataset` constructor, ArviZ has added extra
|
|
361
|
+
information to the generated Dataset such as default dimension names for sampled
|
|
362
|
+
dimensions and some attributes.
|
|
363
|
+
|
|
364
|
+
The function is also general enough to work on pytrees such as nested dictionaries:
|
|
365
|
+
|
|
366
|
+
.. ipython::
|
|
367
|
+
|
|
368
|
+
In [1]: az.pytree_to_dataset({'top': {'second': 1.}, 'top2': 1.})
|
|
369
|
+
|
|
370
|
+
which has two variables (as many as leafs) named ``('top', 'second')`` and ``top2``.
|
|
371
|
+
|
|
372
|
+
Dimensions and co-ordinates can be defined as usual:
|
|
373
|
+
|
|
374
|
+
.. ipython::
|
|
375
|
+
|
|
376
|
+
In [1]: datadict = {
|
|
377
|
+
...: "top": {"a": np.random.randn(100), "b": np.random.randn(1, 100, 10)},
|
|
378
|
+
...: "d": np.random.randn(100),
|
|
379
|
+
...: }
|
|
380
|
+
...: az.dict_to_dataset(
|
|
381
|
+
...: datadict,
|
|
382
|
+
...: coords={"c": np.arange(10)},
|
|
383
|
+
...: dims={("top", "b"): ["c"]}
|
|
384
|
+
...: )
|
|
301
385
|
|
|
302
386
|
"""
|
|
303
387
|
if dims is None:
|
|
304
388
|
dims = {}
|
|
389
|
+
try:
|
|
390
|
+
data = {k[0] if len(k) == 1 else k: v for k, v in _flatten_with_path(data)}
|
|
391
|
+
except TypeError: # probably unsortable keys -- the function will still work if
|
|
392
|
+
pass # it is an honest dictionary.
|
|
305
393
|
|
|
306
394
|
data_vars = {
|
|
307
395
|
key: numpy_to_data_array(
|
|
@@ -318,6 +406,9 @@ def dict_to_dataset(
|
|
|
318
406
|
return xr.Dataset(data_vars=data_vars, attrs=make_attrs(attrs=attrs, library=library))
|
|
319
407
|
|
|
320
408
|
|
|
409
|
+
dict_to_dataset = pytree_to_dataset
|
|
410
|
+
|
|
411
|
+
|
|
321
412
|
def make_attrs(attrs=None, library=None):
|
|
322
413
|
"""Make standard attributes to attach to xarray datasets.
|
|
323
414
|
|
|
@@ -332,7 +423,7 @@ def make_attrs(attrs=None, library=None):
|
|
|
332
423
|
attrs
|
|
333
424
|
"""
|
|
334
425
|
default_attrs = {
|
|
335
|
-
"created_at": datetime.datetime.
|
|
426
|
+
"created_at": datetime.datetime.now(datetime.timezone.utc).isoformat(),
|
|
336
427
|
"arviz_version": __version__,
|
|
337
428
|
}
|
|
338
429
|
if library is not None:
|
arviz/data/converters.py
CHANGED
|
@@ -1,5 +1,7 @@
|
|
|
1
1
|
"""High level conversion functions."""
|
|
2
|
+
|
|
2
3
|
import numpy as np
|
|
4
|
+
import tree
|
|
3
5
|
import xarray as xr
|
|
4
6
|
|
|
5
7
|
from .base import dict_to_dataset
|
|
@@ -105,6 +107,8 @@ def convert_to_inference_data(obj, *, group="posterior", coords=None, dims=None,
|
|
|
105
107
|
dataset = obj.to_dataset()
|
|
106
108
|
elif isinstance(obj, dict):
|
|
107
109
|
dataset = dict_to_dataset(obj, coords=coords, dims=dims)
|
|
110
|
+
elif tree.is_nested(obj) and not isinstance(obj, (list, tuple)):
|
|
111
|
+
dataset = dict_to_dataset(obj, coords=coords, dims=dims)
|
|
108
112
|
elif isinstance(obj, np.ndarray):
|
|
109
113
|
dataset = dict_to_dataset({"x": obj}, coords=coords, dims=dims)
|
|
110
114
|
elif isinstance(obj, (list, tuple)) and isinstance(obj[0], str) and obj[0].endswith(".csv"):
|
|
@@ -118,6 +122,7 @@ def convert_to_inference_data(obj, *, group="posterior", coords=None, dims=None,
|
|
|
118
122
|
"xarray dataarray",
|
|
119
123
|
"xarray dataset",
|
|
120
124
|
"dict",
|
|
125
|
+
"pytree",
|
|
121
126
|
"netcdf filename",
|
|
122
127
|
"numpy array",
|
|
123
128
|
"pystan fit",
|
arviz/data/datasets.py
CHANGED
|
@@ -9,9 +9,16 @@
|
|
|
9
9
|
{
|
|
10
10
|
"name": "rugby",
|
|
11
11
|
"filename": "rugby.nc",
|
|
12
|
-
"url": "http://
|
|
13
|
-
"checksum": "
|
|
14
|
-
"description": "The Six Nations Championship is a yearly rugby competition between Italy, Ireland, Scotland, England, France and Wales. Fifteen games are played each year, representing all combinations of the six teams.\n\nThis example uses and includes results from 2014 - 2017, comprising 60 total games. It models latent parameters for each team's attack and defense, as well as a parameter for home team advantage.\n\nSee https://
|
|
12
|
+
"url": "http://figshare.com/ndownloader/files/44916469",
|
|
13
|
+
"checksum": "f4a5e699a8a4cc93f722eb97929dd7c4895c59a2183f05309f5082f3f81eb228",
|
|
14
|
+
"description": "The Six Nations Championship is a yearly rugby competition between Italy, Ireland, Scotland, England, France and Wales. Fifteen games are played each year, representing all combinations of the six teams.\n\nThis example uses and includes results from 2014 - 2017, comprising 60 total games. It models latent parameters for each team's attack and defense, as well as a global parameter for home team advantage.\n\nSee https://github.com/arviz-devs/arviz_example_data/blob/main/code/rugby/rugby.ipynb for the whole model specification."
|
|
15
|
+
},
|
|
16
|
+
{
|
|
17
|
+
"name": "rugby_field",
|
|
18
|
+
"filename": "rugby_field.nc",
|
|
19
|
+
"url": "http://figshare.com/ndownloader/files/44667112",
|
|
20
|
+
"checksum": "53a99da7ac40d82cd01bb0b089263b9633ee016f975700e941b4c6ea289a1fb0",
|
|
21
|
+
"description": "A variant of the 'rugby' example dataset. The Six Nations Championship is a yearly rugby competition between Italy, Ireland, Scotland, England, France and Wales. Fifteen games are played each year, representing all combinations of the six teams.\n\nThis example uses and includes results from 2014 - 2017, comprising 60 total games. It models latent parameters for each team's attack and defense, with each team having different values depending on them being home or away team.\n\nSee https://github.com/arviz-devs/arviz_example_data/blob/main/code/rugby_field/rugby_field.ipynb for the whole model specification."
|
|
15
22
|
},
|
|
16
23
|
{
|
|
17
24
|
"name": "regression1d",
|
arviz/data/inference_data.py
CHANGED
|
@@ -9,7 +9,7 @@ from collections import OrderedDict, defaultdict
|
|
|
9
9
|
from collections.abc import MutableMapping, Sequence
|
|
10
10
|
from copy import copy as ccopy
|
|
11
11
|
from copy import deepcopy
|
|
12
|
-
|
|
12
|
+
import datetime
|
|
13
13
|
from html import escape
|
|
14
14
|
from typing import (
|
|
15
15
|
TYPE_CHECKING,
|
|
@@ -394,8 +394,10 @@ class InferenceData(Mapping[str, xr.Dataset]):
|
|
|
394
394
|
)
|
|
395
395
|
|
|
396
396
|
try:
|
|
397
|
-
with
|
|
398
|
-
filename, mode="r"
|
|
397
|
+
with (
|
|
398
|
+
h5netcdf.File(filename, mode="r")
|
|
399
|
+
if engine == "h5netcdf"
|
|
400
|
+
else nc.Dataset(filename, mode="r")
|
|
399
401
|
) as file_handle:
|
|
400
402
|
if base_group == "/":
|
|
401
403
|
data = file_handle
|
|
@@ -744,11 +746,11 @@ class InferenceData(Mapping[str, xr.Dataset]):
|
|
|
744
746
|
if len(dfs) > 1:
|
|
745
747
|
for group, df in dfs.items():
|
|
746
748
|
df.columns = [
|
|
747
|
-
|
|
748
|
-
|
|
749
|
-
|
|
750
|
-
|
|
751
|
-
|
|
749
|
+
(
|
|
750
|
+
col
|
|
751
|
+
if col in ("draw", "chain")
|
|
752
|
+
else (group, *col) if isinstance(col, tuple) else (group, col)
|
|
753
|
+
)
|
|
752
754
|
for col in df.columns
|
|
753
755
|
]
|
|
754
756
|
dfs, *dfs_tail = list(dfs.values())
|
|
@@ -1475,12 +1477,12 @@ class InferenceData(Mapping[str, xr.Dataset]):
|
|
|
1475
1477
|
Examples
|
|
1476
1478
|
--------
|
|
1477
1479
|
Add a ``log_likelihood`` group to the "rugby" example InferenceData after loading.
|
|
1478
|
-
It originally doesn't have the ``log_likelihood`` group:
|
|
1479
1480
|
|
|
1480
1481
|
.. jupyter-execute::
|
|
1481
1482
|
|
|
1482
1483
|
import arviz as az
|
|
1483
1484
|
idata = az.load_arviz_data("rugby")
|
|
1485
|
+
del idata.log_likelihood
|
|
1484
1486
|
idata2 = idata.copy()
|
|
1485
1487
|
post = idata.posterior
|
|
1486
1488
|
obs = idata.observed_data
|
|
@@ -1609,13 +1611,13 @@ class InferenceData(Mapping[str, xr.Dataset]):
|
|
|
1609
1611
|
.. jupyter-execute::
|
|
1610
1612
|
|
|
1611
1613
|
import arviz as az
|
|
1612
|
-
idata = az.load_arviz_data("
|
|
1614
|
+
idata = az.load_arviz_data("radon")
|
|
1613
1615
|
|
|
1614
1616
|
Second InferenceData:
|
|
1615
1617
|
|
|
1616
1618
|
.. jupyter-execute::
|
|
1617
1619
|
|
|
1618
|
-
other_idata = az.load_arviz_data("
|
|
1620
|
+
other_idata = az.load_arviz_data("rugby")
|
|
1619
1621
|
|
|
1620
1622
|
Call the ``extend`` method:
|
|
1621
1623
|
|
|
@@ -1687,6 +1689,7 @@ class InferenceData(Mapping[str, xr.Dataset]):
|
|
|
1687
1689
|
compute = _extend_xr_method(xr.Dataset.compute)
|
|
1688
1690
|
persist = _extend_xr_method(xr.Dataset.persist)
|
|
1689
1691
|
quantile = _extend_xr_method(xr.Dataset.quantile)
|
|
1692
|
+
close = _extend_xr_method(xr.Dataset.close)
|
|
1690
1693
|
|
|
1691
1694
|
# The following lines use methods on xr.Dataset that are dynamically defined and attached.
|
|
1692
1695
|
# As a result mypy cannot see them, so we have to suppress the resulting mypy errors.
|
|
@@ -1918,8 +1921,7 @@ def concat(
|
|
|
1918
1921
|
copy: bool = True,
|
|
1919
1922
|
inplace: "Literal[True]",
|
|
1920
1923
|
reset_dim: bool = True,
|
|
1921
|
-
) -> None:
|
|
1922
|
-
...
|
|
1924
|
+
) -> None: ...
|
|
1923
1925
|
|
|
1924
1926
|
|
|
1925
1927
|
@overload
|
|
@@ -1929,8 +1931,7 @@ def concat(
|
|
|
1929
1931
|
copy: bool = True,
|
|
1930
1932
|
inplace: "Literal[False]",
|
|
1931
1933
|
reset_dim: bool = True,
|
|
1932
|
-
) -> InferenceData:
|
|
1933
|
-
...
|
|
1934
|
+
) -> InferenceData: ...
|
|
1934
1935
|
|
|
1935
1936
|
|
|
1936
1937
|
@overload
|
|
@@ -1941,8 +1942,7 @@ def concat(
|
|
|
1941
1942
|
copy: bool = True,
|
|
1942
1943
|
inplace: "Literal[False]",
|
|
1943
1944
|
reset_dim: bool = True,
|
|
1944
|
-
) -> InferenceData:
|
|
1945
|
-
...
|
|
1945
|
+
) -> InferenceData: ...
|
|
1946
1946
|
|
|
1947
1947
|
|
|
1948
1948
|
@overload
|
|
@@ -1953,8 +1953,7 @@ def concat(
|
|
|
1953
1953
|
copy: bool = True,
|
|
1954
1954
|
inplace: "Literal[True]",
|
|
1955
1955
|
reset_dim: bool = True,
|
|
1956
|
-
) -> None:
|
|
1957
|
-
...
|
|
1956
|
+
) -> None: ...
|
|
1958
1957
|
|
|
1959
1958
|
|
|
1960
1959
|
@overload
|
|
@@ -1965,8 +1964,7 @@ def concat(
|
|
|
1965
1964
|
copy: bool = True,
|
|
1966
1965
|
inplace: bool = False,
|
|
1967
1966
|
reset_dim: bool = True,
|
|
1968
|
-
) -> Optional[InferenceData]:
|
|
1969
|
-
...
|
|
1967
|
+
) -> Optional[InferenceData]: ...
|
|
1970
1968
|
|
|
1971
1969
|
|
|
1972
1970
|
# pylint: disable=protected-access, inconsistent-return-statements
|
|
@@ -2083,7 +2081,7 @@ def concat(*args, dim=None, copy=True, inplace=False, reset_dim=True):
|
|
|
2083
2081
|
else:
|
|
2084
2082
|
return args[0]
|
|
2085
2083
|
|
|
2086
|
-
current_time =
|
|
2084
|
+
current_time = datetime.datetime.now(datetime.timezone.utc).isoformat()
|
|
2087
2085
|
combined_attr = defaultdict(list)
|
|
2088
2086
|
for idata in args:
|
|
2089
2087
|
for key, val in idata.attrs.items():
|
arviz/data/io_cmdstan.py
CHANGED
|
@@ -732,14 +732,13 @@ def _process_configuration(comments):
|
|
|
732
732
|
key = (
|
|
733
733
|
"warmup_time_seconds"
|
|
734
734
|
if "(Warm-up)" in comment
|
|
735
|
-
else "sampling_time_seconds"
|
|
736
|
-
if "(Sampling)" in comment
|
|
737
|
-
else "total_time_seconds"
|
|
735
|
+
else "sampling_time_seconds" if "(Sampling)" in comment else "total_time_seconds"
|
|
738
736
|
)
|
|
739
737
|
results[key] = float(value)
|
|
740
738
|
elif "=" in comment:
|
|
741
739
|
match_int = re.search(r"^(\S+)\s*=\s*([-+]?[0-9]+)$", comment)
|
|
742
740
|
match_float = re.search(r"^(\S+)\s*=\s*([-+]?[0-9]+\.[0-9]+)$", comment)
|
|
741
|
+
match_str_bool = re.search(r"^(\S+)\s*=\s*(true|false)$", comment)
|
|
743
742
|
match_str = re.search(r"^(\S+)\s*=\s*(\S+)$", comment)
|
|
744
743
|
match_empty = re.search(r"^(\S+)\s*=\s*$", comment)
|
|
745
744
|
if match_int:
|
|
@@ -748,6 +747,9 @@ def _process_configuration(comments):
|
|
|
748
747
|
elif match_float:
|
|
749
748
|
key, value = match_float.group(1), match_float.group(2)
|
|
750
749
|
results[key] = float(value)
|
|
750
|
+
elif match_str_bool:
|
|
751
|
+
key, value = match_str_bool.group(1), match_str_bool.group(2)
|
|
752
|
+
results[key] = int(value == "true")
|
|
751
753
|
elif match_str:
|
|
752
754
|
key, value = match_str.group(1), match_str.group(2)
|
|
753
755
|
results[key] = value
|
arviz/data/io_datatree.py
CHANGED
arviz/data/io_dict.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
"""Dictionary specific conversion code."""
|
|
2
|
+
|
|
2
3
|
import warnings
|
|
3
4
|
from typing import Optional
|
|
4
5
|
|
|
@@ -59,9 +60,7 @@ class DictConverter:
|
|
|
59
60
|
self.coords = (
|
|
60
61
|
coords
|
|
61
62
|
if pred_coords is None
|
|
62
|
-
else pred_coords
|
|
63
|
-
if coords is None
|
|
64
|
-
else {**coords, **pred_coords}
|
|
63
|
+
else pred_coords if coords is None else {**coords, **pred_coords}
|
|
65
64
|
)
|
|
66
65
|
self.index_origin = index_origin
|
|
67
66
|
self.coords = coords
|
|
@@ -458,3 +457,6 @@ def from_dict(
|
|
|
458
457
|
attrs=attrs,
|
|
459
458
|
**kwargs,
|
|
460
459
|
).to_inference_data()
|
|
460
|
+
|
|
461
|
+
|
|
462
|
+
from_pytree = from_dict
|
arviz/data/io_emcee.py
CHANGED
arviz/data/io_numpyro.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
"""NumPyro-specific conversion code."""
|
|
2
|
+
|
|
2
3
|
import logging
|
|
3
4
|
from typing import Callable, Optional
|
|
4
5
|
|
|
@@ -193,7 +194,7 @@ class NumPyroConverter:
|
|
|
193
194
|
)
|
|
194
195
|
for obs_name, log_like in log_likelihood_dict.items():
|
|
195
196
|
shape = (self.nchains, self.ndraws) + log_like.shape[1:]
|
|
196
|
-
data[obs_name] = np.reshape(
|
|
197
|
+
data[obs_name] = np.reshape(np.asarray(log_like), shape)
|
|
197
198
|
return dict_to_dataset(
|
|
198
199
|
data,
|
|
199
200
|
library=self.numpyro,
|
arviz/data/io_pyjags.py
CHANGED
arviz/data/io_pyro.py
CHANGED
arviz/data/utils.py
CHANGED
arviz/plots/__init__.py
CHANGED
arviz/plots/autocorrplot.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
"""Bokeh ecdfplot."""
|
|
2
|
+
|
|
2
3
|
from matplotlib.colors import to_hex
|
|
3
4
|
|
|
4
5
|
from ...plot_utils import _scale_fig_size
|
|
@@ -12,7 +13,6 @@ def plot_ecdf(
|
|
|
12
13
|
x_bands,
|
|
13
14
|
lower,
|
|
14
15
|
higher,
|
|
15
|
-
confidence_bands,
|
|
16
16
|
plot_kwargs,
|
|
17
17
|
fill_kwargs,
|
|
18
18
|
plot_outline_kwargs,
|
|
@@ -57,7 +57,7 @@ def plot_ecdf(
|
|
|
57
57
|
plot_outline_kwargs.setdefault("color", to_hex("C0"))
|
|
58
58
|
plot_outline_kwargs.setdefault("alpha", 0.2)
|
|
59
59
|
|
|
60
|
-
if
|
|
60
|
+
if x_bands is not None:
|
|
61
61
|
ax.step(x_coord, y_coord, **plot_kwargs)
|
|
62
62
|
|
|
63
63
|
if fill_band:
|
|
@@ -6,7 +6,7 @@ from numbers import Integral
|
|
|
6
6
|
import numpy as np
|
|
7
7
|
from bokeh.models import ColumnDataSource
|
|
8
8
|
from bokeh.models.glyphs import Scatter
|
|
9
|
-
from matplotlib
|
|
9
|
+
from matplotlib import colormaps
|
|
10
10
|
from matplotlib.colors import rgb2hex
|
|
11
11
|
from matplotlib.pyplot import rcParams as mpl_rcParams
|
|
12
12
|
|
|
@@ -188,7 +188,7 @@ def plot_kde(
|
|
|
188
188
|
|
|
189
189
|
cmap = contourf_kwargs.pop("cmap", "viridis")
|
|
190
190
|
if isinstance(cmap, str):
|
|
191
|
-
cmap =
|
|
191
|
+
cmap = colormaps[cmap]
|
|
192
192
|
if isinstance(cmap, Callable):
|
|
193
193
|
colors = [rgb2hex(item) for item in cmap(np.linspace(0, 1, len(levels_scaled) + 1))]
|
|
194
194
|
else:
|
|
@@ -225,7 +225,7 @@ def plot_kde(
|
|
|
225
225
|
else:
|
|
226
226
|
cmap = pcolormesh_kwargs.pop("cmap", "viridis")
|
|
227
227
|
if isinstance(cmap, str):
|
|
228
|
-
cmap =
|
|
228
|
+
cmap = colormaps[cmap]
|
|
229
229
|
if isinstance(cmap, Callable):
|
|
230
230
|
colors = [rgb2hex(item) for item in cmap(np.linspace(0, 1, 256))]
|
|
231
231
|
else:
|