arviz 0.23.3__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185) hide show
  1. arviz/__init__.py +52 -367
  2. arviz-1.0.0rc0.dist-info/METADATA +182 -0
  3. arviz-1.0.0rc0.dist-info/RECORD +5 -0
  4. {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
  5. {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
  6. arviz/data/__init__.py +0 -55
  7. arviz/data/base.py +0 -596
  8. arviz/data/converters.py +0 -203
  9. arviz/data/datasets.py +0 -161
  10. arviz/data/example_data/code/radon/radon.json +0 -326
  11. arviz/data/example_data/data/centered_eight.nc +0 -0
  12. arviz/data/example_data/data/non_centered_eight.nc +0 -0
  13. arviz/data/example_data/data_local.json +0 -12
  14. arviz/data/example_data/data_remote.json +0 -58
  15. arviz/data/inference_data.py +0 -2386
  16. arviz/data/io_beanmachine.py +0 -112
  17. arviz/data/io_cmdstan.py +0 -1036
  18. arviz/data/io_cmdstanpy.py +0 -1233
  19. arviz/data/io_datatree.py +0 -23
  20. arviz/data/io_dict.py +0 -462
  21. arviz/data/io_emcee.py +0 -317
  22. arviz/data/io_json.py +0 -54
  23. arviz/data/io_netcdf.py +0 -68
  24. arviz/data/io_numpyro.py +0 -497
  25. arviz/data/io_pyjags.py +0 -378
  26. arviz/data/io_pyro.py +0 -333
  27. arviz/data/io_pystan.py +0 -1095
  28. arviz/data/io_zarr.py +0 -46
  29. arviz/data/utils.py +0 -139
  30. arviz/labels.py +0 -210
  31. arviz/plots/__init__.py +0 -61
  32. arviz/plots/autocorrplot.py +0 -171
  33. arviz/plots/backends/__init__.py +0 -223
  34. arviz/plots/backends/bokeh/__init__.py +0 -166
  35. arviz/plots/backends/bokeh/autocorrplot.py +0 -101
  36. arviz/plots/backends/bokeh/bfplot.py +0 -23
  37. arviz/plots/backends/bokeh/bpvplot.py +0 -193
  38. arviz/plots/backends/bokeh/compareplot.py +0 -167
  39. arviz/plots/backends/bokeh/densityplot.py +0 -239
  40. arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
  41. arviz/plots/backends/bokeh/distplot.py +0 -183
  42. arviz/plots/backends/bokeh/dotplot.py +0 -113
  43. arviz/plots/backends/bokeh/ecdfplot.py +0 -73
  44. arviz/plots/backends/bokeh/elpdplot.py +0 -203
  45. arviz/plots/backends/bokeh/energyplot.py +0 -155
  46. arviz/plots/backends/bokeh/essplot.py +0 -176
  47. arviz/plots/backends/bokeh/forestplot.py +0 -772
  48. arviz/plots/backends/bokeh/hdiplot.py +0 -54
  49. arviz/plots/backends/bokeh/kdeplot.py +0 -268
  50. arviz/plots/backends/bokeh/khatplot.py +0 -163
  51. arviz/plots/backends/bokeh/lmplot.py +0 -185
  52. arviz/plots/backends/bokeh/loopitplot.py +0 -211
  53. arviz/plots/backends/bokeh/mcseplot.py +0 -184
  54. arviz/plots/backends/bokeh/pairplot.py +0 -328
  55. arviz/plots/backends/bokeh/parallelplot.py +0 -81
  56. arviz/plots/backends/bokeh/posteriorplot.py +0 -324
  57. arviz/plots/backends/bokeh/ppcplot.py +0 -379
  58. arviz/plots/backends/bokeh/rankplot.py +0 -149
  59. arviz/plots/backends/bokeh/separationplot.py +0 -107
  60. arviz/plots/backends/bokeh/traceplot.py +0 -436
  61. arviz/plots/backends/bokeh/violinplot.py +0 -164
  62. arviz/plots/backends/matplotlib/__init__.py +0 -124
  63. arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
  64. arviz/plots/backends/matplotlib/bfplot.py +0 -78
  65. arviz/plots/backends/matplotlib/bpvplot.py +0 -177
  66. arviz/plots/backends/matplotlib/compareplot.py +0 -135
  67. arviz/plots/backends/matplotlib/densityplot.py +0 -194
  68. arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
  69. arviz/plots/backends/matplotlib/distplot.py +0 -178
  70. arviz/plots/backends/matplotlib/dotplot.py +0 -116
  71. arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
  72. arviz/plots/backends/matplotlib/elpdplot.py +0 -189
  73. arviz/plots/backends/matplotlib/energyplot.py +0 -113
  74. arviz/plots/backends/matplotlib/essplot.py +0 -180
  75. arviz/plots/backends/matplotlib/forestplot.py +0 -656
  76. arviz/plots/backends/matplotlib/hdiplot.py +0 -48
  77. arviz/plots/backends/matplotlib/kdeplot.py +0 -177
  78. arviz/plots/backends/matplotlib/khatplot.py +0 -241
  79. arviz/plots/backends/matplotlib/lmplot.py +0 -149
  80. arviz/plots/backends/matplotlib/loopitplot.py +0 -144
  81. arviz/plots/backends/matplotlib/mcseplot.py +0 -161
  82. arviz/plots/backends/matplotlib/pairplot.py +0 -355
  83. arviz/plots/backends/matplotlib/parallelplot.py +0 -58
  84. arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
  85. arviz/plots/backends/matplotlib/ppcplot.py +0 -478
  86. arviz/plots/backends/matplotlib/rankplot.py +0 -119
  87. arviz/plots/backends/matplotlib/separationplot.py +0 -97
  88. arviz/plots/backends/matplotlib/traceplot.py +0 -526
  89. arviz/plots/backends/matplotlib/tsplot.py +0 -121
  90. arviz/plots/backends/matplotlib/violinplot.py +0 -148
  91. arviz/plots/bfplot.py +0 -128
  92. arviz/plots/bpvplot.py +0 -308
  93. arviz/plots/compareplot.py +0 -177
  94. arviz/plots/densityplot.py +0 -284
  95. arviz/plots/distcomparisonplot.py +0 -197
  96. arviz/plots/distplot.py +0 -233
  97. arviz/plots/dotplot.py +0 -233
  98. arviz/plots/ecdfplot.py +0 -372
  99. arviz/plots/elpdplot.py +0 -174
  100. arviz/plots/energyplot.py +0 -147
  101. arviz/plots/essplot.py +0 -319
  102. arviz/plots/forestplot.py +0 -304
  103. arviz/plots/hdiplot.py +0 -211
  104. arviz/plots/kdeplot.py +0 -357
  105. arviz/plots/khatplot.py +0 -236
  106. arviz/plots/lmplot.py +0 -380
  107. arviz/plots/loopitplot.py +0 -224
  108. arviz/plots/mcseplot.py +0 -194
  109. arviz/plots/pairplot.py +0 -281
  110. arviz/plots/parallelplot.py +0 -204
  111. arviz/plots/plot_utils.py +0 -599
  112. arviz/plots/posteriorplot.py +0 -298
  113. arviz/plots/ppcplot.py +0 -369
  114. arviz/plots/rankplot.py +0 -232
  115. arviz/plots/separationplot.py +0 -167
  116. arviz/plots/styles/arviz-bluish.mplstyle +0 -1
  117. arviz/plots/styles/arviz-brownish.mplstyle +0 -1
  118. arviz/plots/styles/arviz-colors.mplstyle +0 -2
  119. arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
  120. arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
  121. arviz/plots/styles/arviz-doc.mplstyle +0 -88
  122. arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
  123. arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
  124. arviz/plots/styles/arviz-greenish.mplstyle +0 -1
  125. arviz/plots/styles/arviz-orangish.mplstyle +0 -1
  126. arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
  127. arviz/plots/styles/arviz-purplish.mplstyle +0 -1
  128. arviz/plots/styles/arviz-redish.mplstyle +0 -1
  129. arviz/plots/styles/arviz-royish.mplstyle +0 -1
  130. arviz/plots/styles/arviz-viridish.mplstyle +0 -1
  131. arviz/plots/styles/arviz-white.mplstyle +0 -40
  132. arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
  133. arviz/plots/traceplot.py +0 -273
  134. arviz/plots/tsplot.py +0 -440
  135. arviz/plots/violinplot.py +0 -192
  136. arviz/preview.py +0 -58
  137. arviz/py.typed +0 -0
  138. arviz/rcparams.py +0 -606
  139. arviz/sel_utils.py +0 -223
  140. arviz/static/css/style.css +0 -340
  141. arviz/static/html/icons-svg-inline.html +0 -15
  142. arviz/stats/__init__.py +0 -37
  143. arviz/stats/density_utils.py +0 -1013
  144. arviz/stats/diagnostics.py +0 -1013
  145. arviz/stats/ecdf_utils.py +0 -324
  146. arviz/stats/stats.py +0 -2422
  147. arviz/stats/stats_refitting.py +0 -119
  148. arviz/stats/stats_utils.py +0 -609
  149. arviz/tests/__init__.py +0 -1
  150. arviz/tests/base_tests/__init__.py +0 -1
  151. arviz/tests/base_tests/test_data.py +0 -1679
  152. arviz/tests/base_tests/test_data_zarr.py +0 -143
  153. arviz/tests/base_tests/test_diagnostics.py +0 -511
  154. arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
  155. arviz/tests/base_tests/test_helpers.py +0 -18
  156. arviz/tests/base_tests/test_labels.py +0 -69
  157. arviz/tests/base_tests/test_plot_utils.py +0 -342
  158. arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
  159. arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
  160. arviz/tests/base_tests/test_rcparams.py +0 -317
  161. arviz/tests/base_tests/test_stats.py +0 -925
  162. arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
  163. arviz/tests/base_tests/test_stats_numba.py +0 -45
  164. arviz/tests/base_tests/test_stats_utils.py +0 -384
  165. arviz/tests/base_tests/test_utils.py +0 -376
  166. arviz/tests/base_tests/test_utils_numba.py +0 -87
  167. arviz/tests/conftest.py +0 -46
  168. arviz/tests/external_tests/__init__.py +0 -1
  169. arviz/tests/external_tests/test_data_beanmachine.py +0 -78
  170. arviz/tests/external_tests/test_data_cmdstan.py +0 -398
  171. arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
  172. arviz/tests/external_tests/test_data_emcee.py +0 -166
  173. arviz/tests/external_tests/test_data_numpyro.py +0 -434
  174. arviz/tests/external_tests/test_data_pyjags.py +0 -119
  175. arviz/tests/external_tests/test_data_pyro.py +0 -260
  176. arviz/tests/external_tests/test_data_pystan.py +0 -307
  177. arviz/tests/helpers.py +0 -677
  178. arviz/utils.py +0 -773
  179. arviz/wrappers/__init__.py +0 -13
  180. arviz/wrappers/base.py +0 -236
  181. arviz/wrappers/wrap_pymc.py +0 -36
  182. arviz/wrappers/wrap_stan.py +0 -148
  183. arviz-0.23.3.dist-info/METADATA +0 -264
  184. arviz-0.23.3.dist-info/RECORD +0 -183
  185. arviz-0.23.3.dist-info/top_level.txt +0 -1
arviz/data/__init__.py DELETED
@@ -1,55 +0,0 @@
1
- """Code for loading and manipulating data structures."""
2
-
3
- from .base import CoordSpec, DimSpec, dict_to_dataset, numpy_to_data_array, pytree_to_dataset
4
- from .converters import convert_to_dataset, convert_to_inference_data
5
- from .datasets import clear_data_home, list_datasets, load_arviz_data
6
- from .inference_data import InferenceData, concat
7
- from .io_beanmachine import from_beanmachine
8
- from .io_cmdstan import from_cmdstan
9
- from .io_cmdstanpy import from_cmdstanpy
10
- from .io_datatree import from_datatree, to_datatree
11
- from .io_dict import from_dict, from_pytree
12
- from .io_emcee import from_emcee
13
- from .io_json import from_json, to_json
14
- from .io_netcdf import from_netcdf, to_netcdf
15
- from .io_numpyro import from_numpyro
16
- from .io_pyjags import from_pyjags
17
- from .io_pyro import from_pyro
18
- from .io_pystan import from_pystan
19
- from .io_zarr import from_zarr, to_zarr
20
- from .utils import extract, extract_dataset
21
-
22
- __all__ = [
23
- "InferenceData",
24
- "concat",
25
- "load_arviz_data",
26
- "list_datasets",
27
- "clear_data_home",
28
- "numpy_to_data_array",
29
- "extract",
30
- "extract_dataset",
31
- "dict_to_dataset",
32
- "convert_to_dataset",
33
- "convert_to_inference_data",
34
- "from_beanmachine",
35
- "from_pyjags",
36
- "from_pystan",
37
- "from_emcee",
38
- "from_cmdstan",
39
- "from_cmdstanpy",
40
- "from_datatree",
41
- "from_dict",
42
- "from_pytree",
43
- "from_json",
44
- "from_pyro",
45
- "from_numpyro",
46
- "from_netcdf",
47
- "pytree_to_dataset",
48
- "to_datatree",
49
- "to_json",
50
- "to_netcdf",
51
- "from_zarr",
52
- "to_zarr",
53
- "CoordSpec",
54
- "DimSpec",
55
- ]
arviz/data/base.py DELETED
@@ -1,596 +0,0 @@
1
- """Low level converters usually used by other functions."""
2
-
3
- import datetime
4
- import functools
5
- import importlib
6
- import re
7
- import warnings
8
- from copy import deepcopy
9
- from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
10
-
11
- import numpy as np
12
- import xarray as xr
13
-
14
- try:
15
- import tree
16
- except ImportError:
17
- tree = None
18
-
19
- try:
20
- import ujson as json
21
- except ImportError:
22
- # mypy struggles with conditional imports expressed as catching ImportError:
23
- # https://github.com/python/mypy/issues/1153
24
- import json # type: ignore
25
-
26
- from .. import __version__, utils
27
- from ..rcparams import rcParams
28
-
29
- CoordSpec = Dict[str, List[Any]]
30
- DimSpec = Dict[str, List[str]]
31
- RequiresArgTypeT = TypeVar("RequiresArgTypeT")
32
- RequiresReturnTypeT = TypeVar("RequiresReturnTypeT")
33
-
34
-
35
- class requires: # pylint: disable=invalid-name
36
- """Decorator to return None if an object does not have the required attribute.
37
-
38
- If the decorator is called various times on the same function with different
39
- attributes, it will return None if one of them is missing. If instead a list
40
- of attributes is passed, it will return None if all attributes in the list are
41
- missing. Both functionalities can be combined as desired.
42
- """
43
-
44
- def __init__(self, *props: Union[str, List[str]]) -> None:
45
- self.props: Tuple[Union[str, List[str]], ...] = props
46
-
47
- # Until typing.ParamSpec (https://www.python.org/dev/peps/pep-0612/) is available
48
- # in all our supported Python versions, there is no way to simultaneously express
49
- # the following two properties:
50
- # - the input function may take arbitrary args/kwargs, and
51
- # - the output function takes those same arbitrary args/kwargs, but has a different return type.
52
- # We either have to limit the input function to e.g. only allowing a "self" argument,
53
- # or we have to adopt the current approach of annotating the returned function as if
54
- # it was defined as "def f(*args: Any, **kwargs: Any) -> Optional[RequiresReturnTypeT]".
55
- #
56
- # Since all functions decorated with @requires currently only accept a single argument,
57
- # we choose to limit application of @requires to only functions of one argument.
58
- # When typing.ParamSpec is available, this definition can be updated to use it.
59
- # See https://github.com/arviz-devs/arviz/pull/1504 for more discussion.
60
- def __call__(
61
- self, func: Callable[[RequiresArgTypeT], RequiresReturnTypeT]
62
- ) -> Callable[[RequiresArgTypeT], Optional[RequiresReturnTypeT]]: # noqa: D202
63
- """Wrap the decorated function."""
64
-
65
- def wrapped(cls: RequiresArgTypeT) -> Optional[RequiresReturnTypeT]:
66
- """Return None if not all props are available."""
67
- for prop in self.props:
68
- prop = [prop] if isinstance(prop, str) else prop
69
- if all((getattr(cls, prop_i) is None for prop_i in prop)):
70
- return None
71
- return func(cls)
72
-
73
- return wrapped
74
-
75
-
76
- def _yield_flat_up_to(shallow_tree, input_tree, path=()):
77
- """Yields (path, value) pairs of input_tree flattened up to shallow_tree.
78
-
79
- Adapted from dm-tree (https://github.com/google-deepmind/tree) to allow
80
- lists as leaves.
81
-
82
- Args:
83
- shallow_tree: Nested structure. Traverse no further than its leaf nodes.
84
- input_tree: Nested structure. Return the paths and values from this tree.
85
- Must have the same upper structure as shallow_tree.
86
- path: Tuple. Optional argument, only used when recursing. The path from the
87
- root of the original shallow_tree, down to the root of the shallow_tree
88
- arg of this recursive call.
89
-
90
- Yields:
91
- Pairs of (path, value), where path is the tuple path of a leaf node in
92
- shallow_tree, and value is the value of the corresponding node in
93
- input_tree.
94
- """
95
- # pylint: disable=protected-access
96
- if tree is None:
97
- raise ImportError("Missing optional dependency 'dm-tree'. Use pip or conda to install it")
98
-
99
- if isinstance(shallow_tree, tree._TEXT_OR_BYTES) or not (
100
- isinstance(shallow_tree, tree.collections_abc.Mapping)
101
- or tree._is_namedtuple(shallow_tree)
102
- or tree._is_attrs(shallow_tree)
103
- ):
104
- yield (path, input_tree)
105
- else:
106
- input_tree = dict(tree._yield_sorted_items(input_tree))
107
- for shallow_key, shallow_subtree in tree._yield_sorted_items(shallow_tree):
108
- subpath = path + (shallow_key,)
109
- input_subtree = input_tree[shallow_key]
110
- for leaf_path, leaf_value in _yield_flat_up_to(
111
- shallow_subtree, input_subtree, path=subpath
112
- ):
113
- yield (leaf_path, leaf_value)
114
- # pylint: enable=protected-access
115
-
116
-
117
- def _flatten_with_path(structure):
118
- return list(_yield_flat_up_to(structure, structure))
119
-
120
-
121
- def generate_dims_coords(
122
- shape,
123
- var_name,
124
- dims=None,
125
- coords=None,
126
- default_dims=None,
127
- index_origin=None,
128
- skip_event_dims=None,
129
- ):
130
- """Generate default dimensions and coordinates for a variable.
131
-
132
- Parameters
133
- ----------
134
- shape : tuple[int]
135
- Shape of the variable
136
- var_name : str
137
- Name of the variable. If no dimension name(s) is provided, ArviZ
138
- will generate a default dimension name using ``var_name``, e.g.,
139
- ``"foo_dim_0"`` for the first dimension if ``var_name`` is ``"foo"``.
140
- dims : list
141
- List of dimensions for the variable
142
- coords : dict[str] -> list[str]
143
- Map of dimensions to coordinates
144
- default_dims : list[str]
145
- Dimension names that are not part of the variable's shape. For example,
146
- when manipulating Monte Carlo traces, the ``default_dims`` would be
147
- ``["chain" , "draw"]`` which ArviZ uses as its own names for dimensions
148
- of MCMC traces.
149
- index_origin : int, optional
150
- Starting value of integer coordinate values. Defaults to the value in rcParam
151
- ``data.index_origin``.
152
- skip_event_dims : bool, default False
153
-
154
- Returns
155
- -------
156
- list[str]
157
- Default dims
158
- dict[str] -> list[str]
159
- Default coords
160
- """
161
- if index_origin is None:
162
- index_origin = rcParams["data.index_origin"]
163
- if default_dims is None:
164
- default_dims = []
165
- if dims is None:
166
- dims = []
167
- if skip_event_dims is None:
168
- skip_event_dims = False
169
-
170
- if coords is None:
171
- coords = {}
172
-
173
- coords = deepcopy(coords)
174
- dims = deepcopy(dims)
175
-
176
- ndims = len([dim for dim in dims if dim not in default_dims])
177
- if ndims > len(shape):
178
- if skip_event_dims:
179
- dims = dims[: len(shape)]
180
- else:
181
- warnings.warn(
182
- (
183
- "In variable {var_name}, there are "
184
- + "more dims ({dims_len}) given than exist ({shape_len}). "
185
- + "Passed array should have shape ({defaults}*shape)"
186
- ).format(
187
- var_name=var_name,
188
- dims_len=len(dims),
189
- shape_len=len(shape),
190
- defaults=",".join(default_dims) + ", " if default_dims is not None else "",
191
- ),
192
- UserWarning,
193
- )
194
- if skip_event_dims:
195
- # this is needed in case the reduction keeps the dimension with size 1
196
- for i, (dim, dim_size) in enumerate(zip(dims, shape)):
197
- if (dim in coords) and (dim_size != len(coords[dim])):
198
- dims = dims[:i]
199
- break
200
-
201
- for i, dim_len in enumerate(shape):
202
- idx = i + len([dim for dim in default_dims if dim in dims])
203
- if len(dims) < idx + 1:
204
- dim_name = f"{var_name}_dim_{i}"
205
- dims.append(dim_name)
206
- elif dims[idx] is None:
207
- dim_name = f"{var_name}_dim_{i}"
208
- dims[idx] = dim_name
209
- dim_name = dims[idx]
210
- if dim_name not in coords:
211
- coords[dim_name] = np.arange(index_origin, dim_len + index_origin)
212
- coords = {
213
- key: coord
214
- for key, coord in coords.items()
215
- if any(key == dim for dim in dims + default_dims)
216
- }
217
- return dims, coords
218
-
219
-
220
- def numpy_to_data_array(
221
- ary,
222
- *,
223
- var_name="data",
224
- coords=None,
225
- dims=None,
226
- default_dims=None,
227
- index_origin=None,
228
- skip_event_dims=None,
229
- ):
230
- """Convert a numpy array to an xarray.DataArray.
231
-
232
- By default, the first two dimensions will be (chain, draw), and any remaining
233
- dimensions will be "shape".
234
- * If the numpy array is 1d, this dimension is interpreted as draw
235
- * If the numpy array is 2d, it is interpreted as (chain, draw)
236
- * If the numpy array is 3 or more dimensions, the last dimensions are kept as shapes.
237
-
238
- To modify this behaviour, use ``default_dims``.
239
-
240
- Parameters
241
- ----------
242
- ary : np.ndarray
243
- A numpy array. If it has 2 or more dimensions, the first dimension should be
244
- independent chains from a simulation. Use `np.expand_dims(ary, 0)` to add a
245
- single dimension to the front if there is only 1 chain.
246
- var_name : str
247
- If there are no dims passed, this string is used to name dimensions
248
- coords : dict[str, iterable]
249
- A dictionary containing the values that are used as index. The key
250
- is the name of the dimension, the values are the index values.
251
- dims : List(str)
252
- A list of coordinate names for the variable
253
- default_dims : list of str, optional
254
- Passed to :py:func:`generate_dims_coords`. Defaults to ``["chain", "draw"]``, and
255
- an empty list is accepted
256
- index_origin : int, optional
257
- Passed to :py:func:`generate_dims_coords`
258
- skip_event_dims : bool
259
-
260
- Returns
261
- -------
262
- xr.DataArray
263
- Will have the same data as passed, but with coordinates and dimensions
264
- """
265
- # manage and transform copies
266
- if default_dims is None:
267
- default_dims = ["chain", "draw"]
268
- if "chain" in default_dims and "draw" in default_dims:
269
- ary = utils.two_de(ary)
270
- n_chains, n_samples, *_ = ary.shape
271
- if n_chains > n_samples:
272
- warnings.warn(
273
- "More chains ({n_chains}) than draws ({n_samples}). "
274
- "Passed array should have shape (chains, draws, *shape)".format(
275
- n_chains=n_chains, n_samples=n_samples
276
- ),
277
- UserWarning,
278
- )
279
- else:
280
- ary = utils.one_de(ary)
281
-
282
- dims, coords = generate_dims_coords(
283
- ary.shape[len(default_dims) :],
284
- var_name,
285
- dims=dims,
286
- coords=coords,
287
- default_dims=default_dims,
288
- index_origin=index_origin,
289
- skip_event_dims=skip_event_dims,
290
- )
291
-
292
- # reversed order for default dims: 'chain', 'draw'
293
- if "draw" not in dims and "draw" in default_dims:
294
- dims = ["draw"] + dims
295
- if "chain" not in dims and "chain" in default_dims:
296
- dims = ["chain"] + dims
297
-
298
- index_origin = rcParams["data.index_origin"]
299
- if "chain" not in coords and "chain" in default_dims:
300
- coords["chain"] = np.arange(index_origin, n_chains + index_origin)
301
- if "draw" not in coords and "draw" in default_dims:
302
- coords["draw"] = np.arange(index_origin, n_samples + index_origin)
303
-
304
- # filter coords based on the dims
305
- coords = {key: xr.IndexVariable((key,), data=np.asarray(coords[key])) for key in dims}
306
- return xr.DataArray(ary, coords=coords, dims=dims)
307
-
308
-
309
- def dict_to_dataset(
310
- data,
311
- *,
312
- attrs=None,
313
- library=None,
314
- coords=None,
315
- dims=None,
316
- default_dims=None,
317
- index_origin=None,
318
- skip_event_dims=None,
319
- ):
320
- """Convert a dictionary or pytree of numpy arrays to an xarray.Dataset.
321
-
322
- ArviZ itself supports conversion of flat dictionaries.
323
- Support for pytrees requires ``dm-tree`` which is an optional dependency.
324
- See https://jax.readthedocs.io/en/latest/pytrees.html for what a pytree is, but
325
- this includes at least dictionaries and tuple types.
326
-
327
- Parameters
328
- ----------
329
- data : dict of {str : array_like or dict} or pytree
330
- Data to convert. Keys are variable names.
331
- attrs : dict, optional
332
- Json serializable metadata to attach to the dataset, in addition to defaults.
333
- library : module, optional
334
- Library used for performing inference. Will be attached to the attrs metadata.
335
- coords : dict of {str : ndarray}, optional
336
- Coordinates for the dataset
337
- dims : dict of {str : list of str}, optional
338
- Dimensions of each variable. The keys are variable names, values are lists of
339
- coordinates.
340
- default_dims : list of str, optional
341
- Passed to :py:func:`numpy_to_data_array`
342
- index_origin : int, optional
343
- Passed to :py:func:`numpy_to_data_array`
344
- skip_event_dims : bool, optional
345
- If True, cut extra dims whenever present to match the shape of the data.
346
- Necessary for PPLs which have the same name in both observed data and log
347
- likelihood groups, to account for their different shapes when observations are
348
- multivariate.
349
-
350
- Returns
351
- -------
352
- xarray.Dataset
353
- In case of nested pytrees, the variable name will be a tuple of individual names.
354
-
355
- Notes
356
- -----
357
- This function is available through two aliases: ``dict_to_dataset`` or ``pytree_to_dataset``.
358
-
359
- Examples
360
- --------
361
- Convert a dictionary with two 2D variables to a Dataset.
362
-
363
- .. ipython::
364
-
365
- In [1]: import arviz as az
366
- ...: import numpy as np
367
- ...: az.dict_to_dataset({'x': np.random.randn(4, 100), 'y': np.random.rand(4, 100)})
368
-
369
- Note that unlike the :class:`xarray.Dataset` constructor, ArviZ has added extra
370
- information to the generated Dataset such as default dimension names for sampled
371
- dimensions and some attributes.
372
-
373
- The function is also general enough to work on pytrees such as nested dictionaries:
374
-
375
- .. ipython::
376
-
377
- In [1]: az.pytree_to_dataset({'top': {'second': 1.}, 'top2': 1.})
378
-
379
- which has two variables (as many as leaves) named ``('top', 'second')`` and ``top2``.
380
-
381
- Dimensions and co-ordinates can be defined as usual:
382
-
383
- .. ipython::
384
-
385
- In [1]: datadict = {
386
- ...: "top": {"a": np.random.randn(100), "b": np.random.randn(1, 100, 10)},
387
- ...: "d": np.random.randn(100),
388
- ...: }
389
- ...: az.dict_to_dataset(
390
- ...: datadict,
391
- ...: coords={"c": np.arange(10)},
392
- ...: dims={("top", "b"): ["c"]}
393
- ...: )
394
-
395
- """
396
- if dims is None:
397
- dims = {}
398
-
399
- if tree is not None:
400
- try:
401
- data = {k[0] if len(k) == 1 else k: v for k, v in _flatten_with_path(data)}
402
- except TypeError: # probably unsortable keys -- the function will still work if
403
- pass # it is an honest dictionary.
404
-
405
- data_vars = {
406
- key: numpy_to_data_array(
407
- values,
408
- var_name=key,
409
- coords=coords,
410
- dims=dims.get(key),
411
- default_dims=default_dims,
412
- index_origin=index_origin,
413
- skip_event_dims=skip_event_dims,
414
- )
415
- for key, values in data.items()
416
- }
417
- return xr.Dataset(data_vars=data_vars, attrs=make_attrs(attrs=attrs, library=library))
418
-
419
-
420
- pytree_to_dataset = dict_to_dataset
421
-
422
-
423
- def make_attrs(attrs=None, library=None):
424
- """Make standard attributes to attach to xarray datasets.
425
-
426
- Parameters
427
- ----------
428
- attrs : dict (optional)
429
- Additional attributes to add or overwrite
430
-
431
- Returns
432
- -------
433
- dict
434
- attrs
435
- """
436
- default_attrs = {
437
- "created_at": datetime.datetime.now(datetime.timezone.utc).isoformat(),
438
- "arviz_version": __version__,
439
- }
440
- if library is not None:
441
- library_name = library.__name__
442
- default_attrs["inference_library"] = library_name
443
- try:
444
- version = importlib.metadata.version(library_name)
445
- default_attrs["inference_library_version"] = version
446
- except importlib.metadata.PackageNotFoundError:
447
- if hasattr(library, "__version__"):
448
- version = library.__version__
449
- default_attrs["inference_library_version"] = version
450
- if attrs is not None:
451
- default_attrs.update(attrs)
452
- return default_attrs
453
-
454
-
455
- def _extend_xr_method(func, doc=None, description="", examples="", see_also=""):
456
- """Make wrapper to extend methods from xr.Dataset to InferenceData Class.
457
-
458
- Parameters
459
- ----------
460
- func : callable
461
- An xr.Dataset function
462
- doc : str
463
- docstring for the func
464
- description : str
465
- the description of the func to be added in docstring
466
- examples : str
467
- the examples of the func to be added in docstring
468
- see_also : str, list
469
- the similar methods of func to be included in See Also section of docstring
470
-
471
- """
472
- # pydocstyle requires a non empty line
473
-
474
- @functools.wraps(func)
475
- def wrapped(self, *args, **kwargs):
476
- _filter = kwargs.pop("filter_groups", None)
477
- _groups = kwargs.pop("groups", None)
478
- _inplace = kwargs.pop("inplace", False)
479
-
480
- out = self if _inplace else deepcopy(self)
481
-
482
- groups = self._group_names(_groups, _filter) # pylint: disable=protected-access
483
- for group in groups:
484
- xr_data = getattr(out, group)
485
- xr_data = func(xr_data, *args, **kwargs) # pylint: disable=not-callable
486
- setattr(out, group, xr_data)
487
-
488
- return None if _inplace else out
489
-
490
- description_default = """{method_name} method is extended from xarray.Dataset methods.
491
-
492
- {description}
493
-
494
- For more info see :meth:`xarray:xarray.Dataset.{method_name}`.
495
- In addition to the arguments available in the original method, the following
496
- ones are added by ArviZ to adapt the method to being called on an ``InferenceData`` object.
497
- """.format(
498
- description=description, method_name=func.__name__ # pylint: disable=no-member
499
- )
500
- params = """
501
- Other Parameters
502
- ----------------
503
- groups: str or list of str, optional
504
- Groups where the selection is to be applied. Can either be group names
505
- or metagroup names.
506
- filter_groups: {None, "like", "regex"}, optional, default=None
507
- If `None` (default), interpret groups as the real group or metagroup names.
508
- If "like", interpret groups as substrings of the real group or metagroup names.
509
- If "regex", interpret groups as regular expressions on the real group or
510
- metagroup names. A la `pandas.filter`.
511
- inplace: bool, optional
512
- If ``True``, modify the InferenceData object inplace,
513
- otherwise, return the modified copy.
514
- """
515
-
516
- if not isinstance(see_also, str):
517
- see_also = "\n".join(see_also)
518
- see_also_basic = """
519
- See Also
520
- --------
521
- xarray.Dataset.{method_name}
522
- {custom_see_also}
523
- """.format(
524
- method_name=func.__name__, custom_see_also=see_also # pylint: disable=no-member
525
- )
526
- wrapped.__doc__ = (
527
- description_default + params + examples + see_also_basic if doc is None else doc
528
- )
529
-
530
- return wrapped
531
-
532
-
533
- def _make_json_serializable(data: dict) -> dict:
534
- """Convert `data` with numpy.ndarray-like values to JSON-serializable form."""
535
- ret = {}
536
- for key, value in data.items():
537
- try:
538
- json.dumps(value)
539
- except (TypeError, OverflowError):
540
- pass
541
- else:
542
- ret[key] = value
543
- continue
544
- if isinstance(value, dict):
545
- ret[key] = _make_json_serializable(value)
546
- elif isinstance(value, np.ndarray):
547
- ret[key] = np.asarray(value).tolist()
548
- else:
549
- raise TypeError(
550
- f"Value associated with variable `{type(value)}` is not JSON serializable."
551
- )
552
- return ret
553
-
554
-
555
- def infer_stan_dtypes(stan_code):
556
- """Infer Stan integer variables from generated quantities block."""
557
- # Remove old deprecated comments
558
- stan_code = "\n".join(
559
- line if "#" not in line else line[: line.find("#")] for line in stan_code.splitlines()
560
- )
561
- pattern_remove_comments = re.compile(
562
- r'//.*?$|/\*.*?\*/|\'(?:\\.|[^\\\'])*\'|"(?:\\.|[^\\"])*"', re.DOTALL | re.MULTILINE
563
- )
564
- stan_code = re.sub(pattern_remove_comments, "", stan_code)
565
-
566
- # Check generated quantities
567
- if "generated quantities" not in stan_code:
568
- return {}
569
-
570
- # Extract generated quantities block
571
- gen_quantities_location = stan_code.index("generated quantities")
572
- block_start = gen_quantities_location + stan_code[gen_quantities_location:].index("{")
573
-
574
- curly_bracket_count = 0
575
- block_end = None
576
- for block_end, char in enumerate(stan_code[block_start:], block_start + 1):
577
- if char == "{":
578
- curly_bracket_count += 1
579
- elif char == "}":
580
- curly_bracket_count -= 1
581
-
582
- if curly_bracket_count == 0:
583
- break
584
-
585
- stan_code = stan_code[block_start:block_end]
586
-
587
- stan_integer = r"int"
588
- stan_limits = r"(?:\<[^\>]+\>)*" # ignore group: 0 or more <....>
589
- stan_param = r"([^;=\s\[]+)" # capture group: ends= ";", "=", "[" or whitespace
590
- stan_ws = r"\s*" # 0 or more whitespace
591
- stan_ws_one = r"\s+" # 1 or more whitespace
592
- pattern_int = re.compile(
593
- "".join((stan_integer, stan_ws_one, stan_limits, stan_ws, stan_param)), re.IGNORECASE
594
- )
595
- dtypes = {key.strip(): "int" for key in re.findall(pattern_int, stan_code)}
596
- return dtypes