arviz 0.17.0__py3-none-any.whl → 0.18.0__py3-none-any.whl

This diff compares publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between versions exactly as they appear in their respective public registries.
Files changed (123)
  1. arviz/__init__.py +3 -2
  2. arviz/data/__init__.py +5 -2
  3. arviz/data/base.py +102 -11
  4. arviz/data/converters.py +5 -0
  5. arviz/data/datasets.py +1 -0
  6. arviz/data/example_data/data_remote.json +10 -3
  7. arviz/data/inference_data.py +26 -25
  8. arviz/data/io_cmdstan.py +1 -3
  9. arviz/data/io_datatree.py +1 -0
  10. arviz/data/io_dict.py +5 -3
  11. arviz/data/io_emcee.py +1 -0
  12. arviz/data/io_numpyro.py +1 -0
  13. arviz/data/io_pyjags.py +1 -0
  14. arviz/data/io_pyro.py +1 -0
  15. arviz/data/io_pystan.py +1 -2
  16. arviz/data/utils.py +1 -0
  17. arviz/plots/__init__.py +1 -0
  18. arviz/plots/autocorrplot.py +1 -0
  19. arviz/plots/backends/bokeh/autocorrplot.py +1 -0
  20. arviz/plots/backends/bokeh/bpvplot.py +8 -2
  21. arviz/plots/backends/bokeh/compareplot.py +8 -4
  22. arviz/plots/backends/bokeh/densityplot.py +1 -0
  23. arviz/plots/backends/bokeh/distplot.py +1 -0
  24. arviz/plots/backends/bokeh/dotplot.py +1 -0
  25. arviz/plots/backends/bokeh/ecdfplot.py +1 -0
  26. arviz/plots/backends/bokeh/elpdplot.py +1 -0
  27. arviz/plots/backends/bokeh/energyplot.py +1 -0
  28. arviz/plots/backends/bokeh/forestplot.py +2 -4
  29. arviz/plots/backends/bokeh/hdiplot.py +1 -0
  30. arviz/plots/backends/bokeh/kdeplot.py +3 -3
  31. arviz/plots/backends/bokeh/khatplot.py +1 -0
  32. arviz/plots/backends/bokeh/lmplot.py +1 -0
  33. arviz/plots/backends/bokeh/loopitplot.py +1 -0
  34. arviz/plots/backends/bokeh/mcseplot.py +1 -0
  35. arviz/plots/backends/bokeh/pairplot.py +1 -0
  36. arviz/plots/backends/bokeh/parallelplot.py +1 -0
  37. arviz/plots/backends/bokeh/posteriorplot.py +1 -0
  38. arviz/plots/backends/bokeh/ppcplot.py +1 -0
  39. arviz/plots/backends/bokeh/rankplot.py +1 -0
  40. arviz/plots/backends/bokeh/separationplot.py +1 -0
  41. arviz/plots/backends/bokeh/traceplot.py +1 -0
  42. arviz/plots/backends/bokeh/violinplot.py +1 -0
  43. arviz/plots/backends/matplotlib/autocorrplot.py +1 -0
  44. arviz/plots/backends/matplotlib/bpvplot.py +1 -0
  45. arviz/plots/backends/matplotlib/compareplot.py +2 -1
  46. arviz/plots/backends/matplotlib/densityplot.py +1 -0
  47. arviz/plots/backends/matplotlib/distcomparisonplot.py +2 -3
  48. arviz/plots/backends/matplotlib/distplot.py +1 -0
  49. arviz/plots/backends/matplotlib/dotplot.py +1 -0
  50. arviz/plots/backends/matplotlib/ecdfplot.py +1 -0
  51. arviz/plots/backends/matplotlib/elpdplot.py +1 -0
  52. arviz/plots/backends/matplotlib/energyplot.py +1 -0
  53. arviz/plots/backends/matplotlib/essplot.py +6 -5
  54. arviz/plots/backends/matplotlib/forestplot.py +3 -4
  55. arviz/plots/backends/matplotlib/hdiplot.py +1 -0
  56. arviz/plots/backends/matplotlib/kdeplot.py +5 -3
  57. arviz/plots/backends/matplotlib/khatplot.py +1 -0
  58. arviz/plots/backends/matplotlib/lmplot.py +1 -0
  59. arviz/plots/backends/matplotlib/loopitplot.py +1 -0
  60. arviz/plots/backends/matplotlib/mcseplot.py +11 -10
  61. arviz/plots/backends/matplotlib/pairplot.py +2 -1
  62. arviz/plots/backends/matplotlib/parallelplot.py +1 -0
  63. arviz/plots/backends/matplotlib/posteriorplot.py +1 -0
  64. arviz/plots/backends/matplotlib/ppcplot.py +1 -0
  65. arviz/plots/backends/matplotlib/rankplot.py +1 -0
  66. arviz/plots/backends/matplotlib/separationplot.py +1 -0
  67. arviz/plots/backends/matplotlib/traceplot.py +2 -1
  68. arviz/plots/backends/matplotlib/tsplot.py +1 -0
  69. arviz/plots/backends/matplotlib/violinplot.py +2 -1
  70. arviz/plots/bfplot.py +7 -6
  71. arviz/plots/bpvplot.py +3 -2
  72. arviz/plots/compareplot.py +3 -2
  73. arviz/plots/densityplot.py +1 -0
  74. arviz/plots/distcomparisonplot.py +1 -0
  75. arviz/plots/dotplot.py +1 -0
  76. arviz/plots/ecdfplot.py +38 -112
  77. arviz/plots/elpdplot.py +2 -1
  78. arviz/plots/energyplot.py +1 -0
  79. arviz/plots/essplot.py +3 -2
  80. arviz/plots/forestplot.py +1 -0
  81. arviz/plots/hdiplot.py +1 -0
  82. arviz/plots/khatplot.py +1 -0
  83. arviz/plots/lmplot.py +1 -0
  84. arviz/plots/loopitplot.py +1 -0
  85. arviz/plots/mcseplot.py +1 -0
  86. arviz/plots/pairplot.py +2 -1
  87. arviz/plots/parallelplot.py +1 -0
  88. arviz/plots/plot_utils.py +1 -0
  89. arviz/plots/posteriorplot.py +1 -0
  90. arviz/plots/ppcplot.py +11 -5
  91. arviz/plots/rankplot.py +1 -0
  92. arviz/plots/separationplot.py +1 -0
  93. arviz/plots/traceplot.py +1 -0
  94. arviz/plots/tsplot.py +1 -0
  95. arviz/plots/violinplot.py +1 -0
  96. arviz/rcparams.py +1 -0
  97. arviz/sel_utils.py +1 -0
  98. arviz/static/css/style.css +2 -1
  99. arviz/stats/density_utils.py +4 -3
  100. arviz/stats/diagnostics.py +4 -4
  101. arviz/stats/ecdf_utils.py +166 -0
  102. arviz/stats/stats.py +16 -32
  103. arviz/stats/stats_refitting.py +1 -0
  104. arviz/stats/stats_utils.py +6 -2
  105. arviz/tests/base_tests/test_data.py +18 -4
  106. arviz/tests/base_tests/test_diagnostics.py +1 -0
  107. arviz/tests/base_tests/test_diagnostics_numba.py +1 -0
  108. arviz/tests/base_tests/test_labels.py +1 -0
  109. arviz/tests/base_tests/test_plots_matplotlib.py +6 -5
  110. arviz/tests/base_tests/test_stats.py +4 -4
  111. arviz/tests/base_tests/test_stats_ecdf_utils.py +153 -0
  112. arviz/tests/base_tests/test_stats_utils.py +4 -3
  113. arviz/tests/base_tests/test_utils.py +3 -2
  114. arviz/tests/external_tests/test_data_numpyro.py +3 -3
  115. arviz/tests/external_tests/test_data_pyro.py +3 -3
  116. arviz/tests/helpers.py +1 -1
  117. arviz/wrappers/__init__.py +1 -0
  118. {arviz-0.17.0.dist-info → arviz-0.18.0.dist-info}/METADATA +10 -9
  119. arviz-0.18.0.dist-info/RECORD +182 -0
  120. {arviz-0.17.0.dist-info → arviz-0.18.0.dist-info}/WHEEL +1 -1
  121. arviz-0.17.0.dist-info/RECORD +0 -180
  122. {arviz-0.17.0.dist-info → arviz-0.18.0.dist-info}/LICENSE +0 -0
  123. {arviz-0.17.0.dist-info → arviz-0.18.0.dist-info}/top_level.txt +0 -0
arviz/__init__.py CHANGED
@@ -1,6 +1,6 @@
  # pylint: disable=wildcard-import,invalid-name,wrong-import-position
  """ArviZ is a library for exploratory analysis of Bayesian models."""
- __version__ = "0.17.0"
+ __version__ = "0.18.0"
 
  import logging
  import os
@@ -315,7 +315,8 @@ _linear_grey_10_95_c0 = [
 
  def _mpl_cm(name, colorlist):
      cmap = LinearSegmentedColormap.from_list(name, colorlist, N=256)
-     mpl.colormaps.register(cmap, name="cet_" + name)
+     if "cet_" + name not in mpl.colormaps():
+         mpl.colormaps.register(cmap, name="cet_" + name)
 
 
  try:
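The guard above makes colormap registration idempotent: matplotlib raises `ValueError` when the same name is registered twice, which could break repeated imports. A minimal sketch of the pattern, not part of the diff (the `"demo"` name is hypothetical; assumes a matplotlib version where `mpl.colormaps` is the global `ColormapRegistry` and calling it lists registered names, as the hunk itself does):

```python
import matplotlib as mpl
from matplotlib.colors import LinearSegmentedColormap

# Build a two-color map; registering an existing name raises ValueError,
# so membership is checked first -- re-running this block is a no-op.
cmap = LinearSegmentedColormap.from_list("demo", ["#000000", "#ffffff"], N=256)
if "cet_demo" not in mpl.colormaps():
    mpl.colormaps.register(cmap, name="cet_demo")
```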
arviz/data/__init__.py CHANGED
@@ -1,5 +1,6 @@
  """Code for loading and manipulating data structures."""
- from .base import CoordSpec, DimSpec, dict_to_dataset, numpy_to_data_array
+
+ from .base import CoordSpec, DimSpec, dict_to_dataset, numpy_to_data_array, pytree_to_dataset
  from .converters import convert_to_dataset, convert_to_inference_data
  from .datasets import clear_data_home, list_datasets, load_arviz_data
  from .inference_data import InferenceData, concat
@@ -7,7 +8,7 @@ from .io_beanmachine import from_beanmachine
  from .io_cmdstan import from_cmdstan
  from .io_cmdstanpy import from_cmdstanpy
  from .io_datatree import from_datatree, to_datatree
- from .io_dict import from_dict
+ from .io_dict import from_dict, from_pytree
  from .io_emcee import from_emcee
  from .io_json import from_json, to_json
  from .io_netcdf import from_netcdf, to_netcdf
@@ -38,10 +39,12 @@ __all__ = [
      "from_cmdstanpy",
      "from_datatree",
      "from_dict",
+     "from_pytree",
      "from_json",
      "from_pyro",
      "from_numpyro",
      "from_netcdf",
+     "pytree_to_dataset",
      "to_datatree",
      "to_json",
      "to_netcdf",
arviz/data/base.py CHANGED
@@ -1,4 +1,5 @@
  """Low level converters usually used by other functions."""
+
  import datetime
  import functools
  import importlib
@@ -8,6 +9,7 @@ from copy import deepcopy
  from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
 
  import numpy as np
+ import tree
  import xarray as xr
 
  try:
@@ -67,6 +69,48 @@ class requires:  # pylint: disable=invalid-name
          return wrapped
 
 
+ def _yield_flat_up_to(shallow_tree, input_tree, path=()):
+     """Yields (path, value) pairs of input_tree flattened up to shallow_tree.
+
+     Adapted from dm-tree (https://github.com/google-deepmind/tree) to allow
+     lists as leaves.
+
+     Args:
+       shallow_tree: Nested structure. Traverse no further than its leaf nodes.
+       input_tree: Nested structure. Return the paths and values from this tree.
+         Must have the same upper structure as shallow_tree.
+       path: Tuple. Optional argument, only used when recursing. The path from the
+         root of the original shallow_tree, down to the root of the shallow_tree
+         arg of this recursive call.
+
+     Yields:
+       Pairs of (path, value), where path is the tuple path of a leaf node in
+       shallow_tree, and value is the value of the corresponding node in
+       input_tree.
+     """
+     # pylint: disable=protected-access
+     if isinstance(shallow_tree, tree._TEXT_OR_BYTES) or not (
+         isinstance(shallow_tree, tree.collections_abc.Mapping)
+         or tree._is_namedtuple(shallow_tree)
+         or tree._is_attrs(shallow_tree)
+     ):
+         yield (path, input_tree)
+     else:
+         input_tree = dict(tree._yield_sorted_items(input_tree))
+         for shallow_key, shallow_subtree in tree._yield_sorted_items(shallow_tree):
+             subpath = path + (shallow_key,)
+             input_subtree = input_tree[shallow_key]
+             for leaf_path, leaf_value in _yield_flat_up_to(
+                 shallow_subtree, input_subtree, path=subpath
+             ):
+                 yield (leaf_path, leaf_value)
+     # pylint: enable=protected-access
+
+
+ def _flatten_with_path(structure):
+     return list(_yield_flat_up_to(structure, structure))
+
+
  def generate_dims_coords(
      shape,
      var_name,
@@ -255,7 +299,7 @@ def numpy_to_data_array(
      return xr.DataArray(ary, coords=coords, dims=dims)
 
 
- def dict_to_dataset(
+ def pytree_to_dataset(
      data,
      *,
      attrs=None,
@@ -266,26 +310,29 @@
      index_origin=None,
      skip_event_dims=None,
  ):
-     """Convert a dictionary of numpy arrays to an xarray.Dataset.
+     """Convert a dictionary or pytree of numpy arrays to an xarray.Dataset.
+
+     See https://jax.readthedocs.io/en/latest/pytrees.html for what a pytree is, but
+     this includes at least dictionaries and tuple types.
 
      Parameters
      ----------
-     data : dict[str] -> ndarray
+     data : dict of {str : array_like or dict} or pytree
          Data to convert. Keys are variable names.
-     attrs : dict
+     attrs : dict, optional
          Json serializable metadata to attach to the dataset, in addition to defaults.
-     library : module
+     library : module, optional
          Library used for performing inference. Will be attached to the attrs metadata.
-     coords : dict[str] -> ndarray
+     coords : dict of {str : ndarray}, optional
          Coordinates for the dataset
-     dims : dict[str] -> list[str]
+     dims : dict of {str : list of str}, optional
          Dimensions of each variable. The keys are variable names, values are lists of
          coordinates.
      default_dims : list of str, optional
          Passed to :py:func:`numpy_to_data_array`
      index_origin : int, optional
          Passed to :py:func:`numpy_to_data_array`
-     skip_event_dims : bool
+     skip_event_dims : bool, optional
          If True, cut extra dims whenever present to match the shape of the data.
          Necessary for PPLs which have the same name in both observed data and log
          likelihood groups, to account for their different shapes when observations are
@@ -293,15 +340,56 @@ def dict_to_dataset(
 
      Returns
      -------
-     xr.Dataset
+     xarray.Dataset
+         In case of nested pytrees, the variable name will be a tuple of individual names.
+
+     Notes
+     -----
+     This function is available through two aliases: ``dict_to_dataset`` or ``pytree_to_dataset``.
 
      Examples
      --------
-     dict_to_dataset({'x': np.random.randn(4, 100), 'y': np.random.rand(4, 100)})
+     Convert a dictionary with two 2D variables to a Dataset.
+
+     .. ipython::
+
+         In [1]: import arviz as az
+            ...: import numpy as np
+            ...: az.dict_to_dataset({'x': np.random.randn(4, 100), 'y': np.random.rand(4, 100)})
+
+     Note that unlike the :class:`xarray.Dataset` constructor, ArviZ has added extra
+     information to the generated Dataset such as default dimension names for sampled
+     dimensions and some attributes.
+
+     The function is also general enough to work on pytrees such as nested dictionaries:
+
+     .. ipython::
+
+         In [1]: az.pytree_to_dataset({'top': {'second': 1.}, 'top2': 1.})
+
+     which has two variables (as many as leaves) named ``('top', 'second')`` and ``top2``.
+
+     Dimensions and coordinates can be defined as usual:
+
+     .. ipython::
+
+         In [1]: datadict = {
+            ...:     "top": {"a": np.random.randn(100), "b": np.random.randn(1, 100, 10)},
+            ...:     "d": np.random.randn(100),
+            ...: }
+            ...: az.dict_to_dataset(
+            ...:     datadict,
+            ...:     coords={"c": np.arange(10)},
+            ...:     dims={("top", "b"): ["c"]}
+            ...: )
 
      """
      if dims is None:
          dims = {}
+     try:
+         data = {k[0] if len(k) == 1 else k: v for k, v in _flatten_with_path(data)}
+     except TypeError:  # probably unsortable keys -- the function will still work if
+         pass  # it is an honest dictionary.
 
      data_vars = {
          key: numpy_to_data_array(
@@ -318,6 +406,9 @@
      return xr.Dataset(data_vars=data_vars, attrs=make_attrs(attrs=attrs, library=library))
 
 
+ dict_to_dataset = pytree_to_dataset
+
+
  def make_attrs(attrs=None, library=None):
      """Make standard attributes to attach to xarray datasets.
 
@@ -332,7 +423,7 @@ def make_attrs(attrs=None, library=None):
      attrs
      """
      default_attrs = {
-         "created_at": datetime.datetime.utcnow().isoformat(),
+         "created_at": datetime.datetime.now(datetime.timezone.utc).isoformat(),
          "arviz_version": __version__,
      }
      if library is not None:
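Taken together, the new flattening helpers and the renamed `pytree_to_dataset` turn nested structures into per-leaf variables. A minimal usage sketch distilled from the docstring examples above (not part of the diff; variable names illustrative):

```python
import arviz as az
import numpy as np

# Leaf paths become variable names: nested paths turn into tuples,
# single-element paths collapse to plain strings.
ds = az.pytree_to_dataset(
    {"top": {"second": np.random.randn(4, 100)}, "top2": np.random.randn(4, 100)}
)
print(list(ds.data_vars))  # [('top', 'second'), 'top2']
```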
arviz/data/converters.py CHANGED
@@ -1,5 +1,7 @@
  """High level conversion functions."""
+
  import numpy as np
+ import tree
  import xarray as xr
 
  from .base import dict_to_dataset
@@ -105,6 +107,8 @@ def convert_to_inference_data(obj, *, group="posterior", coords=None, dims=None,
          dataset = obj.to_dataset()
      elif isinstance(obj, dict):
          dataset = dict_to_dataset(obj, coords=coords, dims=dims)
+     elif tree.is_nested(obj) and not isinstance(obj, (list, tuple)):
+         dataset = dict_to_dataset(obj, coords=coords, dims=dims)
      elif isinstance(obj, np.ndarray):
          dataset = dict_to_dataset({"x": obj}, coords=coords, dims=dims)
      elif isinstance(obj, (list, tuple)) and isinstance(obj[0], str) and obj[0].endswith(".csv"):
@@ -118,6 +122,7 @@
          "xarray dataarray",
          "xarray dataset",
          "dict",
+         "pytree",
          "netcdf filename",
          "numpy array",
          "pystan fit",
arviz/data/datasets.py CHANGED
@@ -1,4 +1,5 @@
  """Base IO code for all datasets. Heavily influenced by scikit-learn's implementation."""
+
  import hashlib
  import itertools
  import json
arviz/data/example_data/data_remote.json CHANGED
@@ -9,9 +9,16 @@
  {
      "name": "rugby",
      "filename": "rugby.nc",
-     "url": "http://ndownloader.figshare.com/files/16254359",
-     "checksum": "9eecd2c6317e45b0388dd97ae6326adecf94128b5a7d15a52c9fcfac0937e2a6",
-     "description": "The Six Nations Championship is a yearly rugby competition between Italy, Ireland, Scotland, England, France and Wales. Fifteen games are played each year, representing all combinations of the six teams.\n\nThis example uses and includes results from 2014 - 2017, comprising 60 total games. It models latent parameters for each team's attack and defense, as well as a parameter for home team advantage.\n\nSee https://docs.pymc.io/notebooks/rugby_analytics.html by Peader Coyle for more details and references."
+     "url": "http://figshare.com/ndownloader/files/44916469",
+     "checksum": "f4a5e699a8a4cc93f722eb97929dd7c4895c59a2183f05309f5082f3f81eb228",
+     "description": "The Six Nations Championship is a yearly rugby competition between Italy, Ireland, Scotland, England, France and Wales. Fifteen games are played each year, representing all combinations of the six teams.\n\nThis example uses and includes results from 2014 - 2017, comprising 60 total games. It models latent parameters for each team's attack and defense, as well as a global parameter for home team advantage.\n\nSee https://github.com/arviz-devs/arviz_example_data/blob/main/code/rugby/rugby.ipynb for the whole model specification."
+ },
+ {
+     "name": "rugby_field",
+     "filename": "rugby_field.nc",
+     "url": "http://figshare.com/ndownloader/files/44667112",
+     "checksum": "53a99da7ac40d82cd01bb0b089263b9633ee016f975700e941b4c6ea289a1fb0",
+     "description": "A variant of the 'rugby' example dataset. The Six Nations Championship is a yearly rugby competition between Italy, Ireland, Scotland, England, France and Wales. Fifteen games are played each year, representing all combinations of the six teams.\n\nThis example uses and includes results from 2014 - 2017, comprising 60 total games. It models latent parameters for each team's attack and defense, with each team having different values depending on them being home or away team.\n\nSee https://github.com/arviz-devs/arviz_example_data/blob/main/code/rugby_field/rugby_field.ipynb for the whole model specification."
  },
  {
      "name": "regression1d",
arviz/data/inference_data.py CHANGED
@@ -9,7 +9,7 @@ from collections import OrderedDict, defaultdict
  from collections.abc import MutableMapping, Sequence
  from copy import copy as ccopy
  from copy import deepcopy
- from datetime import datetime
+ import datetime
  from html import escape
  from typing import (
      TYPE_CHECKING,
@@ -56,6 +56,7 @@ SUPPORTED_GROUPS = [
      "posterior_predictive",
      "predictions",
      "log_likelihood",
+     "log_prior",
      "sample_stats",
      "prior",
      "prior_predictive",
@@ -63,6 +64,8 @@ SUPPORTED_GROUPS = [
      "observed_data",
      "constant_data",
      "predictions_constant_data",
+     "unconstrained_posterior",
+     "unconstrained_prior",
  ]
 
  WARMUP_TAG = "warmup_"
@@ -73,6 +76,7 @@ SUPPORTED_GROUPS_WARMUP = [
      f"{WARMUP_TAG}predictions",
      f"{WARMUP_TAG}sample_stats",
      f"{WARMUP_TAG}log_likelihood",
+     f"{WARMUP_TAG}log_prior",
  ]
 
  SUPPORTED_GROUPS_ALL = SUPPORTED_GROUPS + SUPPORTED_GROUPS_WARMUP
@@ -250,8 +254,7 @@ class InferenceData(Mapping[str, xr.Dataset]):
 
      def __iter__(self) -> Iterator[str]:
          """Iterate over groups in InferenceData object."""
-         for group in self._groups_all:
-             yield group
+         yield from self._groups_all
 
      def __contains__(self, key: object) -> bool:
          """Return True if the named item is present, and False otherwise."""
@@ -391,8 +394,10 @@ class InferenceData(Mapping[str, xr.Dataset]):
          )
 
          try:
-             with h5netcdf.File(filename, mode="r") if engine == "h5netcdf" else nc.Dataset(
-                 filename, mode="r"
+             with (
+                 h5netcdf.File(filename, mode="r")
+                 if engine == "h5netcdf"
+                 else nc.Dataset(filename, mode="r")
              ) as file_handle:
                  if base_group == "/":
                      data = file_handle
@@ -741,11 +746,11 @@ class InferenceData(Mapping[str, xr.Dataset]):
          if len(dfs) > 1:
              for group, df in dfs.items():
                  df.columns = [
-                     col
-                     if col in ("draw", "chain")
-                     else (group, *col)
-                     if isinstance(col, tuple)
-                     else (group, col)
+                     (
+                         col
+                         if col in ("draw", "chain")
+                         else (group, *col) if isinstance(col, tuple) else (group, col)
+                     )
                      for col in df.columns
                  ]
              dfs, *dfs_tail = list(dfs.values())
@@ -1472,12 +1477,12 @@ class InferenceData(Mapping[str, xr.Dataset]):
      Examples
      --------
      Add a ``log_likelihood`` group to the "rugby" example InferenceData after loading.
-     It originally doesn't have the ``log_likelihood`` group:
 
      .. jupyter-execute::
 
          import arviz as az
          idata = az.load_arviz_data("rugby")
+         del idata.log_likelihood
          idata2 = idata.copy()
          post = idata.posterior
          obs = idata.observed_data
@@ -1490,7 +1495,7 @@ class InferenceData(Mapping[str, xr.Dataset]):
 
          import numpy as np
          rng = np.random.default_rng(73)
-         ary = rng.normal(size=(post.dims["chain"], post.dims["draw"], obs.dims["match"]))
+         ary = rng.normal(size=(post.sizes["chain"], post.sizes["draw"], obs.sizes["match"]))
          idata.add_groups(
              log_likelihood={"home_points": ary},
              dims={"home_points": ["match"]},
@@ -1606,13 +1611,13 @@ class InferenceData(Mapping[str, xr.Dataset]):
      .. jupyter-execute::
 
          import arviz as az
-         idata = az.load_arviz_data("rugby")
+         idata = az.load_arviz_data("radon")
 
      Second InferenceData:
 
      .. jupyter-execute::
 
-         other_idata = az.load_arviz_data("radon")
+         other_idata = az.load_arviz_data("rugby")
 
      Call the ``extend`` method:
 
@@ -1684,6 +1689,7 @@ class InferenceData(Mapping[str, xr.Dataset]):
      compute = _extend_xr_method(xr.Dataset.compute)
      persist = _extend_xr_method(xr.Dataset.persist)
      quantile = _extend_xr_method(xr.Dataset.quantile)
+     close = _extend_xr_method(xr.Dataset.close)
 
      # The following lines use methods on xr.Dataset that are dynamically defined and attached.
      # As a result mypy cannot see them, so we have to suppress the resulting mypy errors.
@@ -1915,8 +1921,7 @@ def concat(
      copy: bool = True,
      inplace: "Literal[True]",
      reset_dim: bool = True,
- ) -> None:
-     ...
+ ) -> None: ...
 
 
  @overload
@@ -1926,8 +1931,7 @@
      copy: bool = True,
      inplace: "Literal[False]",
      reset_dim: bool = True,
- ) -> InferenceData:
-     ...
+ ) -> InferenceData: ...
 
 
  @overload
@@ -1938,8 +1942,7 @@
      copy: bool = True,
      inplace: "Literal[False]",
      reset_dim: bool = True,
- ) -> InferenceData:
-     ...
+ ) -> InferenceData: ...
 
 
  @overload
@@ -1950,8 +1953,7 @@
      copy: bool = True,
      inplace: "Literal[True]",
      reset_dim: bool = True,
- ) -> None:
-     ...
+ ) -> None: ...
 
 
  @overload
@@ -1962,8 +1964,7 @@
      copy: bool = True,
      inplace: bool = False,
      reset_dim: bool = True,
- ) -> Optional[InferenceData]:
-     ...
+ ) -> Optional[InferenceData]: ...
 
 
  # pylint: disable=protected-access, inconsistent-return-statements
@@ -2080,7 +2081,7 @@ def concat(*args, dim=None, copy=True, inplace=False, reset_dim=True):
      else:
          return args[0]
 
-     current_time = str(datetime.now())
+     current_time = datetime.datetime.now(datetime.timezone.utc).isoformat()
      combined_attr = defaultdict(list)
      for idata in args:
          for key, val in idata.attrs.items():
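With `log_prior` added to `SUPPORTED_GROUPS` (plus a warmup variant), it can be attached like any other group. A minimal sketch, not part of the diff (array shapes illustrative):

```python
import arviz as az
import numpy as np

idata = az.from_dict(posterior={"mu": np.random.randn(4, 100)})
# "log_prior" now passes the supported-group check in add_groups.
idata.add_groups(log_prior={"mu": np.random.randn(4, 100)})
print(idata.groups())
```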
arviz/data/io_cmdstan.py CHANGED
@@ -732,9 +732,7 @@ def _process_configuration(comments):
              key = (
                  "warmup_time_seconds"
                  if "(Warm-up)" in comment
-                 else "sampling_time_seconds"
-                 if "(Sampling)" in comment
-                 else "total_time_seconds"
+                 else "sampling_time_seconds" if "(Sampling)" in comment else "total_time_seconds"
              )
              results[key] = float(value)
          elif "=" in comment:
arviz/data/io_datatree.py CHANGED
@@ -1,4 +1,5 @@
  """Conversion between InferenceData and DataTree."""
+
  from .inference_data import InferenceData
 
 
arviz/data/io_dict.py CHANGED
@@ -1,4 +1,5 @@
  """Dictionary specific conversion code."""
+
  import warnings
  from typing import Optional
 
@@ -59,9 +60,7 @@ class DictConverter:
          self.coords = (
              coords
              if pred_coords is None
-             else pred_coords
-             if coords is None
-             else {**coords, **pred_coords}
+             else pred_coords if coords is None else {**coords, **pred_coords}
          )
          self.index_origin = index_origin
          self.coords = coords
@@ -458,3 +457,6 @@ def from_dict(
          attrs=attrs,
          **kwargs,
      ).to_inference_data()
+
+
+ from_pytree = from_dict
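`from_pytree` is a plain alias of `from_dict`, and since `dict_to_dataset` now flattens pytrees, nested per-group structures work as well. A sketch, not part of the diff (names illustrative):

```python
import arviz as az
import numpy as np

# Nested leaves become tuple-named variables, here ("mu", "loc").
idata = az.from_pytree(posterior={"mu": {"loc": np.random.randn(4, 100)}})
```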
arviz/data/io_emcee.py CHANGED
@@ -1,4 +1,5 @@
  """emcee-specific conversion code."""
+
  import warnings
  from collections import OrderedDict
 
arviz/data/io_numpyro.py CHANGED
@@ -1,4 +1,5 @@
  """NumPyro-specific conversion code."""
+
  import logging
  from typing import Callable, Optional
 
arviz/data/io_pyjags.py CHANGED
@@ -1,4 +1,5 @@
  """Convert PyJAGS sample dictionaries to ArviZ inference data objects."""
+
  import typing as tp
  from collections import OrderedDict
  from collections.abc import Iterable
arviz/data/io_pyro.py CHANGED
@@ -1,4 +1,5 @@
  """Pyro-specific conversion code."""
+
  import logging
  from typing import Callable, Optional
  import warnings
arviz/data/io_pystan.py CHANGED
@@ -676,8 +676,7 @@ def get_draws(fit, variables=None, ignore=None, warmup=False, dtypes=None):
      for item in par_keys:
          _, shape = item.replace("]", "").split("[")
          shape_idx_min = min(int(shape_value) for shape_value in shape.split(","))
-         if shape_idx_min < shift:
-             shift = shape_idx_min
+         shift = min(shift, shape_idx_min)
      # If shift is higher than 1, this will probably mean that Stan
      # has implemented sparse structure (saves only non-zero parts),
      # but let's hope that dims are still corresponding to the full shape
arviz/data/utils.py CHANGED
@@ -1,4 +1,5 @@
  """Data specific utilities."""
+
  import warnings
  import numpy as np
 
arviz/plots/__init__.py CHANGED
@@ -1,4 +1,5 @@
  """Plotting functions."""
+
  from .autocorrplot import plot_autocorr
  from .bpvplot import plot_bpv
  from .bfplot import plot_bf
arviz/plots/autocorrplot.py CHANGED
@@ -1,4 +1,5 @@
  """Autocorrelation plot of data."""
+
  from ..data import convert_to_dataset
  from ..labels import BaseLabeller
  from ..sel_utils import xarray_var_iter
arviz/plots/backends/bokeh/autocorrplot.py CHANGED
@@ -1,4 +1,5 @@
  """Bokeh Autocorrplot."""
+
  import numpy as np
  from bokeh.models import DataRange1d, BoxAnnotation
  from bokeh.models.annotations import Title
arviz/plots/backends/bokeh/bpvplot.py CHANGED
@@ -1,4 +1,5 @@
  """Bokeh Bayesian p-value Posterior predictive plot."""
+
  import numpy as np
  from bokeh.models import BoxAnnotation
  from bokeh.models.annotations import Title
@@ -171,8 +172,13 @@ def plot_bpv(
          ax_i.line(0, 0, legend_label=f"bpv={p_value:.2f}", alpha=0)
 
          if plot_mean:
-             ax_i.circle(
-                 obs_vals.mean(), 0, fill_color=color, line_color="black", size=markersize
+             ax_i.scatter(
+                 obs_vals.mean(),
+                 0,
+                 fill_color=color,
+                 line_color="black",
+                 size=markersize,
+                 marker="circle",
              )
 
          _title = Title()
arviz/plots/backends/bokeh/compareplot.py CHANGED
@@ -1,4 +1,5 @@
  """Bokeh Compareplot."""
+
  from bokeh.models import Span
  from bokeh.models.annotations import Title, Legend
 
@@ -69,13 +70,14 @@ def plot_compare(
          err_ys.append((y, y))
 
      # plot them
-     dif_tri = ax.triangle(
+     dif_tri = ax.scatter(
          comp_df[information_criterion].iloc[1:],
          yticks_pos[1::2],
          line_color=plot_kwargs.get("color_dse", "grey"),
          fill_color=plot_kwargs.get("color_dse", "grey"),
          line_width=2,
          size=6,
+         marker="triangle",
      )
      dif_line = ax.multi_line(err_xs, err_ys, line_color=plot_kwargs.get("color_dse", "grey"))
 
@@ -85,13 +87,14 @@ def plot_compare(
      ax.yaxis.ticker = yticks_pos[::2]
      ax.yaxis.major_label_overrides = dict(zip(yticks_pos[::2], yticks_labels))
 
-     elpd_circ = ax.circle(
+     elpd_circ = ax.scatter(
          comp_df[information_criterion],
          yticks_pos[::2],
          line_color=plot_kwargs.get("color_ic", "black"),
          fill_color=None,
          line_width=2,
          size=6,
+         marker="circle",
      )
      elpd_label = [elpd_circ]
 
@@ -110,7 +113,7 @@ def plot_compare(
 
      labels.append(("ELPD", elpd_label))
 
-     scale = comp_df["scale"][0]
+     scale = comp_df["scale"].iloc[0]
 
      if insample_dev:
          p_ic = comp_df[f"p_{information_criterion.split('_')[1]}"]
@@ -120,13 +123,14 @@ def plot_compare(
              correction = -p_ic
          elif scale == "deviance":
              correction = -(2 * p_ic)
-         insample_circ = ax.circle(
+         insample_circ = ax.scatter(
              comp_df[information_criterion] + correction,
              yticks_pos[::2],
              line_color=plot_kwargs.get("color_insample_dev", "black"),
              fill_color=plot_kwargs.get("color_insample_dev", "black"),
              line_width=2,
              size=6,
+             marker="circle",
          )
          labels.append(("In-sample ELPD", [insample_circ]))
 
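The `ax.circle`/`ax.triangle` to `ax.scatter(..., marker=...)` changes here (and in `bpvplot.py` above) track Bokeh's deprecation of the glyph-specific marker methods in favor of a single `scatter` call. A sketch of the migration pattern, not part of the diff (assumes a Bokeh 3.x figure):

```python
from bokeh.plotting import figure

p = figure()
# old, deprecated:  p.circle(x, y, size=6)  /  p.triangle(x, y, size=6)
p.scatter([1, 2, 3], [4, 5, 6], size=6, marker="circle")
p.scatter([1, 2, 3], [4, 5, 6], size=6, marker="triangle")
```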
arviz/plots/backends/bokeh/densityplot.py CHANGED
@@ -1,4 +1,5 @@
  """Bokeh Densityplot."""
+
  from collections import defaultdict
  from itertools import cycle
 
arviz/plots/backends/bokeh/distplot.py CHANGED
@@ -1,4 +1,5 @@
  """Bokeh Distplot."""
+
  import matplotlib.pyplot as plt
  import numpy as np