arviz 0.17.1__py3-none-any.whl → 0.18.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116) hide show
  1. arviz/__init__.py +3 -2
  2. arviz/data/__init__.py +5 -2
  3. arviz/data/base.py +102 -11
  4. arviz/data/converters.py +5 -0
  5. arviz/data/datasets.py +1 -0
  6. arviz/data/example_data/data_remote.json +10 -3
  7. arviz/data/inference_data.py +20 -22
  8. arviz/data/io_cmdstan.py +1 -3
  9. arviz/data/io_datatree.py +1 -0
  10. arviz/data/io_dict.py +5 -3
  11. arviz/data/io_emcee.py +1 -0
  12. arviz/data/io_numpyro.py +1 -0
  13. arviz/data/io_pyjags.py +1 -0
  14. arviz/data/io_pyro.py +1 -0
  15. arviz/data/utils.py +1 -0
  16. arviz/plots/__init__.py +1 -0
  17. arviz/plots/autocorrplot.py +1 -0
  18. arviz/plots/backends/bokeh/autocorrplot.py +1 -0
  19. arviz/plots/backends/bokeh/bpvplot.py +1 -0
  20. arviz/plots/backends/bokeh/compareplot.py +1 -0
  21. arviz/plots/backends/bokeh/densityplot.py +1 -0
  22. arviz/plots/backends/bokeh/distplot.py +1 -0
  23. arviz/plots/backends/bokeh/dotplot.py +1 -0
  24. arviz/plots/backends/bokeh/ecdfplot.py +1 -0
  25. arviz/plots/backends/bokeh/elpdplot.py +1 -0
  26. arviz/plots/backends/bokeh/energyplot.py +1 -0
  27. arviz/plots/backends/bokeh/hdiplot.py +1 -0
  28. arviz/plots/backends/bokeh/kdeplot.py +3 -3
  29. arviz/plots/backends/bokeh/khatplot.py +1 -0
  30. arviz/plots/backends/bokeh/lmplot.py +1 -0
  31. arviz/plots/backends/bokeh/loopitplot.py +1 -0
  32. arviz/plots/backends/bokeh/mcseplot.py +1 -0
  33. arviz/plots/backends/bokeh/pairplot.py +1 -0
  34. arviz/plots/backends/bokeh/parallelplot.py +1 -0
  35. arviz/plots/backends/bokeh/posteriorplot.py +1 -0
  36. arviz/plots/backends/bokeh/ppcplot.py +1 -0
  37. arviz/plots/backends/bokeh/rankplot.py +1 -0
  38. arviz/plots/backends/bokeh/separationplot.py +1 -0
  39. arviz/plots/backends/bokeh/traceplot.py +1 -0
  40. arviz/plots/backends/bokeh/violinplot.py +1 -0
  41. arviz/plots/backends/matplotlib/autocorrplot.py +1 -0
  42. arviz/plots/backends/matplotlib/bpvplot.py +1 -0
  43. arviz/plots/backends/matplotlib/compareplot.py +1 -0
  44. arviz/plots/backends/matplotlib/densityplot.py +1 -0
  45. arviz/plots/backends/matplotlib/distcomparisonplot.py +2 -3
  46. arviz/plots/backends/matplotlib/distplot.py +1 -0
  47. arviz/plots/backends/matplotlib/dotplot.py +1 -0
  48. arviz/plots/backends/matplotlib/ecdfplot.py +1 -0
  49. arviz/plots/backends/matplotlib/elpdplot.py +1 -0
  50. arviz/plots/backends/matplotlib/energyplot.py +1 -0
  51. arviz/plots/backends/matplotlib/essplot.py +6 -5
  52. arviz/plots/backends/matplotlib/forestplot.py +1 -0
  53. arviz/plots/backends/matplotlib/hdiplot.py +1 -0
  54. arviz/plots/backends/matplotlib/kdeplot.py +5 -3
  55. arviz/plots/backends/matplotlib/khatplot.py +1 -0
  56. arviz/plots/backends/matplotlib/lmplot.py +1 -0
  57. arviz/plots/backends/matplotlib/loopitplot.py +1 -0
  58. arviz/plots/backends/matplotlib/mcseplot.py +11 -10
  59. arviz/plots/backends/matplotlib/pairplot.py +2 -1
  60. arviz/plots/backends/matplotlib/parallelplot.py +1 -0
  61. arviz/plots/backends/matplotlib/posteriorplot.py +1 -0
  62. arviz/plots/backends/matplotlib/ppcplot.py +1 -0
  63. arviz/plots/backends/matplotlib/rankplot.py +1 -0
  64. arviz/plots/backends/matplotlib/separationplot.py +1 -0
  65. arviz/plots/backends/matplotlib/traceplot.py +1 -0
  66. arviz/plots/backends/matplotlib/tsplot.py +1 -0
  67. arviz/plots/backends/matplotlib/violinplot.py +2 -1
  68. arviz/plots/bpvplot.py +1 -0
  69. arviz/plots/compareplot.py +1 -0
  70. arviz/plots/densityplot.py +1 -0
  71. arviz/plots/distcomparisonplot.py +1 -0
  72. arviz/plots/dotplot.py +1 -0
  73. arviz/plots/ecdfplot.py +1 -0
  74. arviz/plots/elpdplot.py +1 -0
  75. arviz/plots/energyplot.py +1 -0
  76. arviz/plots/essplot.py +1 -0
  77. arviz/plots/forestplot.py +1 -0
  78. arviz/plots/hdiplot.py +1 -0
  79. arviz/plots/khatplot.py +1 -0
  80. arviz/plots/lmplot.py +1 -0
  81. arviz/plots/loopitplot.py +1 -0
  82. arviz/plots/mcseplot.py +1 -0
  83. arviz/plots/pairplot.py +1 -0
  84. arviz/plots/parallelplot.py +1 -0
  85. arviz/plots/plot_utils.py +1 -0
  86. arviz/plots/posteriorplot.py +1 -0
  87. arviz/plots/ppcplot.py +1 -0
  88. arviz/plots/rankplot.py +1 -0
  89. arviz/plots/separationplot.py +1 -0
  90. arviz/plots/traceplot.py +1 -0
  91. arviz/plots/tsplot.py +1 -0
  92. arviz/plots/violinplot.py +1 -0
  93. arviz/rcparams.py +1 -0
  94. arviz/sel_utils.py +1 -0
  95. arviz/static/css/style.css +2 -1
  96. arviz/stats/density_utils.py +2 -1
  97. arviz/stats/diagnostics.py +2 -2
  98. arviz/stats/ecdf_utils.py +1 -0
  99. arviz/stats/stats_refitting.py +1 -0
  100. arviz/stats/stats_utils.py +5 -1
  101. arviz/tests/base_tests/test_data.py +14 -0
  102. arviz/tests/base_tests/test_diagnostics.py +1 -0
  103. arviz/tests/base_tests/test_diagnostics_numba.py +1 -0
  104. arviz/tests/base_tests/test_labels.py +1 -0
  105. arviz/tests/base_tests/test_plots_matplotlib.py +6 -5
  106. arviz/tests/base_tests/test_stats.py +4 -4
  107. arviz/tests/base_tests/test_stats_utils.py +1 -0
  108. arviz/tests/base_tests/test_utils.py +3 -2
  109. arviz/tests/helpers.py +1 -1
  110. arviz/wrappers/__init__.py +1 -0
  111. {arviz-0.17.1.dist-info → arviz-0.18.0.dist-info}/METADATA +7 -7
  112. arviz-0.18.0.dist-info/RECORD +182 -0
  113. arviz-0.17.1.dist-info/RECORD +0 -182
  114. {arviz-0.17.1.dist-info → arviz-0.18.0.dist-info}/LICENSE +0 -0
  115. {arviz-0.17.1.dist-info → arviz-0.18.0.dist-info}/WHEEL +0 -0
  116. {arviz-0.17.1.dist-info → arviz-0.18.0.dist-info}/top_level.txt +0 -0
arviz/__init__.py CHANGED
@@ -1,6 +1,6 @@
1
1
  # pylint: disable=wildcard-import,invalid-name,wrong-import-position
2
2
  """ArviZ is a library for exploratory analysis of Bayesian models."""
3
- __version__ = "0.17.1"
3
+ __version__ = "0.18.0"
4
4
 
5
5
  import logging
6
6
  import os
@@ -315,7 +315,8 @@ _linear_grey_10_95_c0 = [
315
315
 
316
316
  def _mpl_cm(name, colorlist):
317
317
  cmap = LinearSegmentedColormap.from_list(name, colorlist, N=256)
318
- mpl.colormaps.register(cmap, name="cet_" + name)
318
+ if "cet_" + name not in mpl.colormaps():
319
+ mpl.colormaps.register(cmap, name="cet_" + name)
319
320
 
320
321
 
321
322
  try:
arviz/data/__init__.py CHANGED
@@ -1,5 +1,6 @@
1
1
  """Code for loading and manipulating data structures."""
2
- from .base import CoordSpec, DimSpec, dict_to_dataset, numpy_to_data_array
2
+
3
+ from .base import CoordSpec, DimSpec, dict_to_dataset, numpy_to_data_array, pytree_to_dataset
3
4
  from .converters import convert_to_dataset, convert_to_inference_data
4
5
  from .datasets import clear_data_home, list_datasets, load_arviz_data
5
6
  from .inference_data import InferenceData, concat
@@ -7,7 +8,7 @@ from .io_beanmachine import from_beanmachine
7
8
  from .io_cmdstan import from_cmdstan
8
9
  from .io_cmdstanpy import from_cmdstanpy
9
10
  from .io_datatree import from_datatree, to_datatree
10
- from .io_dict import from_dict
11
+ from .io_dict import from_dict, from_pytree
11
12
  from .io_emcee import from_emcee
12
13
  from .io_json import from_json, to_json
13
14
  from .io_netcdf import from_netcdf, to_netcdf
@@ -38,10 +39,12 @@ __all__ = [
38
39
  "from_cmdstanpy",
39
40
  "from_datatree",
40
41
  "from_dict",
42
+ "from_pytree",
41
43
  "from_json",
42
44
  "from_pyro",
43
45
  "from_numpyro",
44
46
  "from_netcdf",
47
+ "pytree_to_dataset",
45
48
  "to_datatree",
46
49
  "to_json",
47
50
  "to_netcdf",
arviz/data/base.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Low level converters usually used by other functions."""
2
+
2
3
  import datetime
3
4
  import functools
4
5
  import importlib
@@ -8,6 +9,7 @@ from copy import deepcopy
8
9
  from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
9
10
 
10
11
  import numpy as np
12
+ import tree
11
13
  import xarray as xr
12
14
 
13
15
  try:
@@ -67,6 +69,48 @@ class requires: # pylint: disable=invalid-name
67
69
  return wrapped
68
70
 
69
71
 
72
+ def _yield_flat_up_to(shallow_tree, input_tree, path=()):
73
+ """Yields (path, value) pairs of input_tree flattened up to shallow_tree.
74
+
75
+ Adapted from dm-tree (https://github.com/google-deepmind/tree) to allow
76
+ lists as leaves.
77
+
78
+ Args:
79
+ shallow_tree: Nested structure. Traverse no further than its leaf nodes.
80
+ input_tree: Nested structure. Return the paths and values from this tree.
81
+ Must have the same upper structure as shallow_tree.
82
+ path: Tuple. Optional argument, only used when recursing. The path from the
83
+ root of the original shallow_tree, down to the root of the shallow_tree
84
+ arg of this recursive call.
85
+
86
+ Yields:
87
+ Pairs of (path, value), where path the tuple path of a leaf node in
88
+ shallow_tree, and value is the value of the corresponding node in
89
+ input_tree.
90
+ """
91
+ # pylint: disable=protected-access
92
+ if isinstance(shallow_tree, tree._TEXT_OR_BYTES) or not (
93
+ isinstance(shallow_tree, tree.collections_abc.Mapping)
94
+ or tree._is_namedtuple(shallow_tree)
95
+ or tree._is_attrs(shallow_tree)
96
+ ):
97
+ yield (path, input_tree)
98
+ else:
99
+ input_tree = dict(tree._yield_sorted_items(input_tree))
100
+ for shallow_key, shallow_subtree in tree._yield_sorted_items(shallow_tree):
101
+ subpath = path + (shallow_key,)
102
+ input_subtree = input_tree[shallow_key]
103
+ for leaf_path, leaf_value in _yield_flat_up_to(
104
+ shallow_subtree, input_subtree, path=subpath
105
+ ):
106
+ yield (leaf_path, leaf_value)
107
+ # pylint: enable=protected-access
108
+
109
+
110
+ def _flatten_with_path(structure):
111
+ return list(_yield_flat_up_to(structure, structure))
112
+
113
+
70
114
  def generate_dims_coords(
71
115
  shape,
72
116
  var_name,
@@ -255,7 +299,7 @@ def numpy_to_data_array(
255
299
  return xr.DataArray(ary, coords=coords, dims=dims)
256
300
 
257
301
 
258
- def dict_to_dataset(
302
+ def pytree_to_dataset(
259
303
  data,
260
304
  *,
261
305
  attrs=None,
@@ -266,26 +310,29 @@ def dict_to_dataset(
266
310
  index_origin=None,
267
311
  skip_event_dims=None,
268
312
  ):
269
- """Convert a dictionary of numpy arrays to an xarray.Dataset.
313
+ """Convert a dictionary or pytree of numpy arrays to an xarray.Dataset.
314
+
315
+ See https://jax.readthedocs.io/en/latest/pytrees.html for what a pytree is, but
316
+ this includes at least dictionaries and tuple types.
270
317
 
271
318
  Parameters
272
319
  ----------
273
- data : dict[str] -> ndarray
320
+ data : dict of {str : array_like or dict} or pytree
274
321
  Data to convert. Keys are variable names.
275
- attrs : dict
322
+ attrs : dict, optional
276
323
  Json serializable metadata to attach to the dataset, in addition to defaults.
277
- library : module
324
+ library : module, optional
278
325
  Library used for performing inference. Will be attached to the attrs metadata.
279
- coords : dict[str] -> ndarray
326
+ coords : dict of {str : ndarray}, optional
280
327
  Coordinates for the dataset
281
- dims : dict[str] -> list[str]
328
+ dims : dict of {str : list of str}, optional
282
329
  Dimensions of each variable. The keys are variable names, values are lists of
283
330
  coordinates.
284
331
  default_dims : list of str, optional
285
332
  Passed to :py:func:`numpy_to_data_array`
286
333
  index_origin : int, optional
287
334
  Passed to :py:func:`numpy_to_data_array`
288
- skip_event_dims : bool
335
+ skip_event_dims : bool, optional
289
336
  If True, cut extra dims whenever present to match the shape of the data.
290
337
  Necessary for PPLs which have the same name in both observed data and log
291
338
  likelihood groups, to account for their different shapes when observations are
@@ -293,15 +340,56 @@ def dict_to_dataset(
293
340
 
294
341
  Returns
295
342
  -------
296
- xr.Dataset
343
+ xarray.Dataset
344
+ In case of nested pytrees, the variable name will be a tuple of individual names.
345
+
346
+ Notes
347
+ -----
348
+ This function is available through two aliases: ``dict_to_dataset`` or ``pytree_to_dataset``.
297
349
 
298
350
  Examples
299
351
  --------
300
- dict_to_dataset({'x': np.random.randn(4, 100), 'y': np.random.rand(4, 100)})
352
+ Convert a dictionary with two 2D variables to a Dataset.
353
+
354
+ .. ipython::
355
+
356
+ In [1]: import arviz as az
357
+ ...: import numpy as np
358
+ ...: az.dict_to_dataset({'x': np.random.randn(4, 100), 'y': np.random.rand(4, 100)})
359
+
360
+ Note that unlike the :class:`xarray.Dataset` constructor, ArviZ has added extra
361
+ information to the generated Dataset such as default dimension names for sampled
362
+ dimensions and some attributes.
363
+
364
+ The function is also general enough to work on pytrees such as nested dictionaries:
365
+
366
+ .. ipython::
367
+
368
+ In [1]: az.pytree_to_dataset({'top': {'second': 1.}, 'top2': 1.})
369
+
370
+ which has two variables (as many as leaves) named ``('top', 'second')`` and ``top2``.
371
+
372
+ Dimensions and coordinates can be defined as usual:
373
+
374
+ .. ipython::
375
+
376
+ In [1]: datadict = {
377
+ ...: "top": {"a": np.random.randn(100), "b": np.random.randn(1, 100, 10)},
378
+ ...: "d": np.random.randn(100),
379
+ ...: }
380
+ ...: az.dict_to_dataset(
381
+ ...: datadict,
382
+ ...: coords={"c": np.arange(10)},
383
+ ...: dims={("top", "b"): ["c"]}
384
+ ...: )
301
385
 
302
386
  """
303
387
  if dims is None:
304
388
  dims = {}
389
+ try:
390
+ data = {k[0] if len(k) == 1 else k: v for k, v in _flatten_with_path(data)}
391
+ except TypeError: # probably unsortable keys -- the function will still work if
392
+ pass # it is an honest dictionary.
305
393
 
306
394
  data_vars = {
307
395
  key: numpy_to_data_array(
@@ -318,6 +406,9 @@ def dict_to_dataset(
318
406
  return xr.Dataset(data_vars=data_vars, attrs=make_attrs(attrs=attrs, library=library))
319
407
 
320
408
 
409
+ dict_to_dataset = pytree_to_dataset
410
+
411
+
321
412
  def make_attrs(attrs=None, library=None):
322
413
  """Make standard attributes to attach to xarray datasets.
323
414
 
@@ -332,7 +423,7 @@ def make_attrs(attrs=None, library=None):
332
423
  attrs
333
424
  """
334
425
  default_attrs = {
335
- "created_at": datetime.datetime.utcnow().isoformat(),
426
+ "created_at": datetime.datetime.now(datetime.timezone.utc).isoformat(),
336
427
  "arviz_version": __version__,
337
428
  }
338
429
  if library is not None:
arviz/data/converters.py CHANGED
@@ -1,5 +1,7 @@
1
1
  """High level conversion functions."""
2
+
2
3
  import numpy as np
4
+ import tree
3
5
  import xarray as xr
4
6
 
5
7
  from .base import dict_to_dataset
@@ -105,6 +107,8 @@ def convert_to_inference_data(obj, *, group="posterior", coords=None, dims=None,
105
107
  dataset = obj.to_dataset()
106
108
  elif isinstance(obj, dict):
107
109
  dataset = dict_to_dataset(obj, coords=coords, dims=dims)
110
+ elif tree.is_nested(obj) and not isinstance(obj, (list, tuple)):
111
+ dataset = dict_to_dataset(obj, coords=coords, dims=dims)
108
112
  elif isinstance(obj, np.ndarray):
109
113
  dataset = dict_to_dataset({"x": obj}, coords=coords, dims=dims)
110
114
  elif isinstance(obj, (list, tuple)) and isinstance(obj[0], str) and obj[0].endswith(".csv"):
@@ -118,6 +122,7 @@ def convert_to_inference_data(obj, *, group="posterior", coords=None, dims=None,
118
122
  "xarray dataarray",
119
123
  "xarray dataset",
120
124
  "dict",
125
+ "pytree",
121
126
  "netcdf filename",
122
127
  "numpy array",
123
128
  "pystan fit",
arviz/data/datasets.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Base IO code for all datasets. Heavily influenced by scikit-learn's implementation."""
2
+
2
3
  import hashlib
3
4
  import itertools
4
5
  import json
@@ -9,9 +9,16 @@
9
9
  {
10
10
  "name": "rugby",
11
11
  "filename": "rugby.nc",
12
- "url": "http://ndownloader.figshare.com/files/16254359",
13
- "checksum": "9eecd2c6317e45b0388dd97ae6326adecf94128b5a7d15a52c9fcfac0937e2a6",
14
- "description": "The Six Nations Championship is a yearly rugby competition between Italy, Ireland, Scotland, England, France and Wales. Fifteen games are played each year, representing all combinations of the six teams.\n\nThis example uses and includes results from 2014 - 2017, comprising 60 total games. It models latent parameters for each team's attack and defense, as well as a parameter for home team advantage.\n\nSee https://docs.pymc.io/notebooks/rugby_analytics.html by Peader Coyle for more details and references."
12
+ "url": "http://figshare.com/ndownloader/files/44916469",
13
+ "checksum": "f4a5e699a8a4cc93f722eb97929dd7c4895c59a2183f05309f5082f3f81eb228",
14
+ "description": "The Six Nations Championship is a yearly rugby competition between Italy, Ireland, Scotland, England, France and Wales. Fifteen games are played each year, representing all combinations of the six teams.\n\nThis example uses and includes results from 2014 - 2017, comprising 60 total games. It models latent parameters for each team's attack and defense, as well as a global parameter for home team advantage.\n\nSee https://github.com/arviz-devs/arviz_example_data/blob/main/code/rugby/rugby.ipynb for the whole model specification."
15
+ },
16
+ {
17
+ "name": "rugby_field",
18
+ "filename": "rugby_field.nc",
19
+ "url": "http://figshare.com/ndownloader/files/44667112",
20
+ "checksum": "53a99da7ac40d82cd01bb0b089263b9633ee016f975700e941b4c6ea289a1fb0",
21
+ "description": "A variant of the 'rugby' example dataset. The Six Nations Championship is a yearly rugby competition between Italy, Ireland, Scotland, England, France and Wales. Fifteen games are played each year, representing all combinations of the six teams.\n\nThis example uses and includes results from 2014 - 2017, comprising 60 total games. It models latent parameters for each team's attack and defense, with each team having different values depending on them being home or away team.\n\nSee https://github.com/arviz-devs/arviz_example_data/blob/main/code/rugby_field/rugby_field.ipynb for the whole model specification."
15
22
  },
16
23
  {
17
24
  "name": "regression1d",
@@ -9,7 +9,7 @@ from collections import OrderedDict, defaultdict
9
9
  from collections.abc import MutableMapping, Sequence
10
10
  from copy import copy as ccopy
11
11
  from copy import deepcopy
12
- from datetime import datetime
12
+ import datetime
13
13
  from html import escape
14
14
  from typing import (
15
15
  TYPE_CHECKING,
@@ -394,8 +394,10 @@ class InferenceData(Mapping[str, xr.Dataset]):
394
394
  )
395
395
 
396
396
  try:
397
- with h5netcdf.File(filename, mode="r") if engine == "h5netcdf" else nc.Dataset(
398
- filename, mode="r"
397
+ with (
398
+ h5netcdf.File(filename, mode="r")
399
+ if engine == "h5netcdf"
400
+ else nc.Dataset(filename, mode="r")
399
401
  ) as file_handle:
400
402
  if base_group == "/":
401
403
  data = file_handle
@@ -744,11 +746,11 @@ class InferenceData(Mapping[str, xr.Dataset]):
744
746
  if len(dfs) > 1:
745
747
  for group, df in dfs.items():
746
748
  df.columns = [
747
- col
748
- if col in ("draw", "chain")
749
- else (group, *col)
750
- if isinstance(col, tuple)
751
- else (group, col)
749
+ (
750
+ col
751
+ if col in ("draw", "chain")
752
+ else (group, *col) if isinstance(col, tuple) else (group, col)
753
+ )
752
754
  for col in df.columns
753
755
  ]
754
756
  dfs, *dfs_tail = list(dfs.values())
@@ -1475,12 +1477,12 @@ class InferenceData(Mapping[str, xr.Dataset]):
1475
1477
  Examples
1476
1478
  --------
1477
1479
  Add a ``log_likelihood`` group to the "rugby" example InferenceData after loading.
1478
- It originally doesn't have the ``log_likelihood`` group:
1479
1480
 
1480
1481
  .. jupyter-execute::
1481
1482
 
1482
1483
  import arviz as az
1483
1484
  idata = az.load_arviz_data("rugby")
1485
+ del idata.log_likelihood
1484
1486
  idata2 = idata.copy()
1485
1487
  post = idata.posterior
1486
1488
  obs = idata.observed_data
@@ -1609,13 +1611,13 @@ class InferenceData(Mapping[str, xr.Dataset]):
1609
1611
  .. jupyter-execute::
1610
1612
 
1611
1613
  import arviz as az
1612
- idata = az.load_arviz_data("rugby")
1614
+ idata = az.load_arviz_data("radon")
1613
1615
 
1614
1616
  Second InferenceData:
1615
1617
 
1616
1618
  .. jupyter-execute::
1617
1619
 
1618
- other_idata = az.load_arviz_data("radon")
1620
+ other_idata = az.load_arviz_data("rugby")
1619
1621
 
1620
1622
  Call the ``extend`` method:
1621
1623
 
@@ -1687,6 +1689,7 @@ class InferenceData(Mapping[str, xr.Dataset]):
1687
1689
  compute = _extend_xr_method(xr.Dataset.compute)
1688
1690
  persist = _extend_xr_method(xr.Dataset.persist)
1689
1691
  quantile = _extend_xr_method(xr.Dataset.quantile)
1692
+ close = _extend_xr_method(xr.Dataset.close)
1690
1693
 
1691
1694
  # The following lines use methods on xr.Dataset that are dynamically defined and attached.
1692
1695
  # As a result mypy cannot see them, so we have to suppress the resulting mypy errors.
@@ -1918,8 +1921,7 @@ def concat(
1918
1921
  copy: bool = True,
1919
1922
  inplace: "Literal[True]",
1920
1923
  reset_dim: bool = True,
1921
- ) -> None:
1922
- ...
1924
+ ) -> None: ...
1923
1925
 
1924
1926
 
1925
1927
  @overload
@@ -1929,8 +1931,7 @@ def concat(
1929
1931
  copy: bool = True,
1930
1932
  inplace: "Literal[False]",
1931
1933
  reset_dim: bool = True,
1932
- ) -> InferenceData:
1933
- ...
1934
+ ) -> InferenceData: ...
1934
1935
 
1935
1936
 
1936
1937
  @overload
@@ -1941,8 +1942,7 @@ def concat(
1941
1942
  copy: bool = True,
1942
1943
  inplace: "Literal[False]",
1943
1944
  reset_dim: bool = True,
1944
- ) -> InferenceData:
1945
- ...
1945
+ ) -> InferenceData: ...
1946
1946
 
1947
1947
 
1948
1948
  @overload
@@ -1953,8 +1953,7 @@ def concat(
1953
1953
  copy: bool = True,
1954
1954
  inplace: "Literal[True]",
1955
1955
  reset_dim: bool = True,
1956
- ) -> None:
1957
- ...
1956
+ ) -> None: ...
1958
1957
 
1959
1958
 
1960
1959
  @overload
@@ -1965,8 +1964,7 @@ def concat(
1965
1964
  copy: bool = True,
1966
1965
  inplace: bool = False,
1967
1966
  reset_dim: bool = True,
1968
- ) -> Optional[InferenceData]:
1969
- ...
1967
+ ) -> Optional[InferenceData]: ...
1970
1968
 
1971
1969
 
1972
1970
  # pylint: disable=protected-access, inconsistent-return-statements
@@ -2083,7 +2081,7 @@ def concat(*args, dim=None, copy=True, inplace=False, reset_dim=True):
2083
2081
  else:
2084
2082
  return args[0]
2085
2083
 
2086
- current_time = str(datetime.now())
2084
+ current_time = datetime.datetime.now(datetime.timezone.utc).isoformat()
2087
2085
  combined_attr = defaultdict(list)
2088
2086
  for idata in args:
2089
2087
  for key, val in idata.attrs.items():
arviz/data/io_cmdstan.py CHANGED
@@ -732,9 +732,7 @@ def _process_configuration(comments):
732
732
  key = (
733
733
  "warmup_time_seconds"
734
734
  if "(Warm-up)" in comment
735
- else "sampling_time_seconds"
736
- if "(Sampling)" in comment
737
- else "total_time_seconds"
735
+ else "sampling_time_seconds" if "(Sampling)" in comment else "total_time_seconds"
738
736
  )
739
737
  results[key] = float(value)
740
738
  elif "=" in comment:
arviz/data/io_datatree.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Conversion between InferenceData and DataTree."""
2
+
2
3
  from .inference_data import InferenceData
3
4
 
4
5
 
arviz/data/io_dict.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Dictionary specific conversion code."""
2
+
2
3
  import warnings
3
4
  from typing import Optional
4
5
 
@@ -59,9 +60,7 @@ class DictConverter:
59
60
  self.coords = (
60
61
  coords
61
62
  if pred_coords is None
62
- else pred_coords
63
- if coords is None
64
- else {**coords, **pred_coords}
63
+ else pred_coords if coords is None else {**coords, **pred_coords}
65
64
  )
66
65
  self.index_origin = index_origin
67
66
  self.coords = coords
@@ -458,3 +457,6 @@ def from_dict(
458
457
  attrs=attrs,
459
458
  **kwargs,
460
459
  ).to_inference_data()
460
+
461
+
462
+ from_pytree = from_dict
arviz/data/io_emcee.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """emcee-specific conversion code."""
2
+
2
3
  import warnings
3
4
  from collections import OrderedDict
4
5
 
arviz/data/io_numpyro.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """NumPyro-specific conversion code."""
2
+
2
3
  import logging
3
4
  from typing import Callable, Optional
4
5
 
arviz/data/io_pyjags.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Convert PyJAGS sample dictionaries to ArviZ inference data objects."""
2
+
2
3
  import typing as tp
3
4
  from collections import OrderedDict
4
5
  from collections.abc import Iterable
arviz/data/io_pyro.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Pyro-specific conversion code."""
2
+
2
3
  import logging
3
4
  from typing import Callable, Optional
4
5
  import warnings
arviz/data/utils.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Data specific utilities."""
2
+
2
3
  import warnings
3
4
  import numpy as np
4
5
 
arviz/plots/__init__.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Plotting functions."""
2
+
2
3
  from .autocorrplot import plot_autocorr
3
4
  from .bpvplot import plot_bpv
4
5
  from .bfplot import plot_bf
@@ -1,4 +1,5 @@
1
1
  """Autocorrelation plot of data."""
2
+
2
3
  from ..data import convert_to_dataset
3
4
  from ..labels import BaseLabeller
4
5
  from ..sel_utils import xarray_var_iter
@@ -1,4 +1,5 @@
1
1
  """Bokeh Autocorrplot."""
2
+
2
3
  import numpy as np
3
4
  from bokeh.models import DataRange1d, BoxAnnotation
4
5
  from bokeh.models.annotations import Title
@@ -1,4 +1,5 @@
1
1
  """Bokeh Bayesian p-value Posterior predictive plot."""
2
+
2
3
  import numpy as np
3
4
  from bokeh.models import BoxAnnotation
4
5
  from bokeh.models.annotations import Title
@@ -1,4 +1,5 @@
1
1
  """Bokeh Compareplot."""
2
+
2
3
  from bokeh.models import Span
3
4
  from bokeh.models.annotations import Title, Legend
4
5
 
@@ -1,4 +1,5 @@
1
1
  """Bokeh Densityplot."""
2
+
2
3
  from collections import defaultdict
3
4
  from itertools import cycle
4
5
 
@@ -1,4 +1,5 @@
1
1
  """Bokeh Distplot."""
2
+
2
3
  import matplotlib.pyplot as plt
3
4
  import numpy as np
4
5
 
@@ -1,4 +1,5 @@
1
1
  """Bokeh dotplot."""
2
+
2
3
  import math
3
4
  import warnings
4
5
  import numpy as np
@@ -1,4 +1,5 @@
1
1
  """Bokeh ecdfplot."""
2
+
2
3
  from matplotlib.colors import to_hex
3
4
 
4
5
  from ...plot_utils import _scale_fig_size
@@ -1,4 +1,5 @@
1
1
  """Bokeh ELPDPlot."""
2
+
2
3
  import warnings
3
4
 
4
5
  import bokeh.plotting as bkp
@@ -1,4 +1,5 @@
1
1
  """Bokeh energyplot."""
2
+
2
3
  from itertools import cycle
3
4
 
4
5
  import numpy as np
@@ -1,4 +1,5 @@
1
1
  """Bokeh hdiplot."""
2
+
2
3
  import numpy as np
3
4
 
4
5
  from ...plot_utils import _scale_fig_size, vectorized_to_hex
@@ -6,7 +6,7 @@ from numbers import Integral
6
6
  import numpy as np
7
7
  from bokeh.models import ColumnDataSource
8
8
  from bokeh.models.glyphs import Scatter
9
- from matplotlib.cm import get_cmap
9
+ from matplotlib import colormaps
10
10
  from matplotlib.colors import rgb2hex
11
11
  from matplotlib.pyplot import rcParams as mpl_rcParams
12
12
 
@@ -188,7 +188,7 @@ def plot_kde(
188
188
 
189
189
  cmap = contourf_kwargs.pop("cmap", "viridis")
190
190
  if isinstance(cmap, str):
191
- cmap = get_cmap(cmap)
191
+ cmap = colormaps[cmap]
192
192
  if isinstance(cmap, Callable):
193
193
  colors = [rgb2hex(item) for item in cmap(np.linspace(0, 1, len(levels_scaled) + 1))]
194
194
  else:
@@ -225,7 +225,7 @@ def plot_kde(
225
225
  else:
226
226
  cmap = pcolormesh_kwargs.pop("cmap", "viridis")
227
227
  if isinstance(cmap, str):
228
- cmap = get_cmap(cmap)
228
+ cmap = colormaps[cmap]
229
229
  if isinstance(cmap, Callable):
230
230
  colors = [rgb2hex(item) for item in cmap(np.linspace(0, 1, 256))]
231
231
  else:
@@ -1,4 +1,5 @@
1
1
  """Bokeh pareto shape plot."""
2
+
2
3
  from collections.abc import Iterable
3
4
 
4
5
  from matplotlib import cm
@@ -1,4 +1,5 @@
1
1
  """Bokeh linear regression plot."""
2
+
2
3
  import numpy as np
3
4
  from bokeh.models.annotations import Legend
4
5
 
@@ -1,4 +1,5 @@
1
1
  """Bokeh loopitplot."""
2
+
2
3
  import numpy as np
3
4
  from bokeh.models import BoxAnnotation
4
5
  from matplotlib.colors import hsv_to_rgb, rgb_to_hsv, to_hex, to_rgb
@@ -1,4 +1,5 @@
1
1
  """Bokeh mcseplot."""
2
+
2
3
  import numpy as np
3
4
  from bokeh.models import ColumnDataSource, Span
4
5
  from bokeh.models.glyphs import Scatter
@@ -1,4 +1,5 @@
1
1
  """Bokeh pairplot."""
2
+
2
3
  import warnings
3
4
  from copy import deepcopy
4
5
  from uuid import uuid4
@@ -1,4 +1,5 @@
1
1
  """Bokeh Parallel coordinates plot."""
2
+
2
3
  import numpy as np
3
4
  from bokeh.models import DataRange1d
4
5
  from bokeh.models.tickers import FixedTicker
@@ -1,4 +1,5 @@
1
1
  """Bokeh Plot posterior densities."""
2
+
2
3
  from numbers import Number
3
4
  from typing import Optional
4
5