arviz 0.23.3__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185)
  1. arviz/__init__.py +52 -367
  2. arviz-1.0.0rc0.dist-info/METADATA +182 -0
  3. arviz-1.0.0rc0.dist-info/RECORD +5 -0
  4. {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
  5. {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
  6. arviz/data/__init__.py +0 -55
  7. arviz/data/base.py +0 -596
  8. arviz/data/converters.py +0 -203
  9. arviz/data/datasets.py +0 -161
  10. arviz/data/example_data/code/radon/radon.json +0 -326
  11. arviz/data/example_data/data/centered_eight.nc +0 -0
  12. arviz/data/example_data/data/non_centered_eight.nc +0 -0
  13. arviz/data/example_data/data_local.json +0 -12
  14. arviz/data/example_data/data_remote.json +0 -58
  15. arviz/data/inference_data.py +0 -2386
  16. arviz/data/io_beanmachine.py +0 -112
  17. arviz/data/io_cmdstan.py +0 -1036
  18. arviz/data/io_cmdstanpy.py +0 -1233
  19. arviz/data/io_datatree.py +0 -23
  20. arviz/data/io_dict.py +0 -462
  21. arviz/data/io_emcee.py +0 -317
  22. arviz/data/io_json.py +0 -54
  23. arviz/data/io_netcdf.py +0 -68
  24. arviz/data/io_numpyro.py +0 -497
  25. arviz/data/io_pyjags.py +0 -378
  26. arviz/data/io_pyro.py +0 -333
  27. arviz/data/io_pystan.py +0 -1095
  28. arviz/data/io_zarr.py +0 -46
  29. arviz/data/utils.py +0 -139
  30. arviz/labels.py +0 -210
  31. arviz/plots/__init__.py +0 -61
  32. arviz/plots/autocorrplot.py +0 -171
  33. arviz/plots/backends/__init__.py +0 -223
  34. arviz/plots/backends/bokeh/__init__.py +0 -166
  35. arviz/plots/backends/bokeh/autocorrplot.py +0 -101
  36. arviz/plots/backends/bokeh/bfplot.py +0 -23
  37. arviz/plots/backends/bokeh/bpvplot.py +0 -193
  38. arviz/plots/backends/bokeh/compareplot.py +0 -167
  39. arviz/plots/backends/bokeh/densityplot.py +0 -239
  40. arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
  41. arviz/plots/backends/bokeh/distplot.py +0 -183
  42. arviz/plots/backends/bokeh/dotplot.py +0 -113
  43. arviz/plots/backends/bokeh/ecdfplot.py +0 -73
  44. arviz/plots/backends/bokeh/elpdplot.py +0 -203
  45. arviz/plots/backends/bokeh/energyplot.py +0 -155
  46. arviz/plots/backends/bokeh/essplot.py +0 -176
  47. arviz/plots/backends/bokeh/forestplot.py +0 -772
  48. arviz/plots/backends/bokeh/hdiplot.py +0 -54
  49. arviz/plots/backends/bokeh/kdeplot.py +0 -268
  50. arviz/plots/backends/bokeh/khatplot.py +0 -163
  51. arviz/plots/backends/bokeh/lmplot.py +0 -185
  52. arviz/plots/backends/bokeh/loopitplot.py +0 -211
  53. arviz/plots/backends/bokeh/mcseplot.py +0 -184
  54. arviz/plots/backends/bokeh/pairplot.py +0 -328
  55. arviz/plots/backends/bokeh/parallelplot.py +0 -81
  56. arviz/plots/backends/bokeh/posteriorplot.py +0 -324
  57. arviz/plots/backends/bokeh/ppcplot.py +0 -379
  58. arviz/plots/backends/bokeh/rankplot.py +0 -149
  59. arviz/plots/backends/bokeh/separationplot.py +0 -107
  60. arviz/plots/backends/bokeh/traceplot.py +0 -436
  61. arviz/plots/backends/bokeh/violinplot.py +0 -164
  62. arviz/plots/backends/matplotlib/__init__.py +0 -124
  63. arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
  64. arviz/plots/backends/matplotlib/bfplot.py +0 -78
  65. arviz/plots/backends/matplotlib/bpvplot.py +0 -177
  66. arviz/plots/backends/matplotlib/compareplot.py +0 -135
  67. arviz/plots/backends/matplotlib/densityplot.py +0 -194
  68. arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
  69. arviz/plots/backends/matplotlib/distplot.py +0 -178
  70. arviz/plots/backends/matplotlib/dotplot.py +0 -116
  71. arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
  72. arviz/plots/backends/matplotlib/elpdplot.py +0 -189
  73. arviz/plots/backends/matplotlib/energyplot.py +0 -113
  74. arviz/plots/backends/matplotlib/essplot.py +0 -180
  75. arviz/plots/backends/matplotlib/forestplot.py +0 -656
  76. arviz/plots/backends/matplotlib/hdiplot.py +0 -48
  77. arviz/plots/backends/matplotlib/kdeplot.py +0 -177
  78. arviz/plots/backends/matplotlib/khatplot.py +0 -241
  79. arviz/plots/backends/matplotlib/lmplot.py +0 -149
  80. arviz/plots/backends/matplotlib/loopitplot.py +0 -144
  81. arviz/plots/backends/matplotlib/mcseplot.py +0 -161
  82. arviz/plots/backends/matplotlib/pairplot.py +0 -355
  83. arviz/plots/backends/matplotlib/parallelplot.py +0 -58
  84. arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
  85. arviz/plots/backends/matplotlib/ppcplot.py +0 -478
  86. arviz/plots/backends/matplotlib/rankplot.py +0 -119
  87. arviz/plots/backends/matplotlib/separationplot.py +0 -97
  88. arviz/plots/backends/matplotlib/traceplot.py +0 -526
  89. arviz/plots/backends/matplotlib/tsplot.py +0 -121
  90. arviz/plots/backends/matplotlib/violinplot.py +0 -148
  91. arviz/plots/bfplot.py +0 -128
  92. arviz/plots/bpvplot.py +0 -308
  93. arviz/plots/compareplot.py +0 -177
  94. arviz/plots/densityplot.py +0 -284
  95. arviz/plots/distcomparisonplot.py +0 -197
  96. arviz/plots/distplot.py +0 -233
  97. arviz/plots/dotplot.py +0 -233
  98. arviz/plots/ecdfplot.py +0 -372
  99. arviz/plots/elpdplot.py +0 -174
  100. arviz/plots/energyplot.py +0 -147
  101. arviz/plots/essplot.py +0 -319
  102. arviz/plots/forestplot.py +0 -304
  103. arviz/plots/hdiplot.py +0 -211
  104. arviz/plots/kdeplot.py +0 -357
  105. arviz/plots/khatplot.py +0 -236
  106. arviz/plots/lmplot.py +0 -380
  107. arviz/plots/loopitplot.py +0 -224
  108. arviz/plots/mcseplot.py +0 -194
  109. arviz/plots/pairplot.py +0 -281
  110. arviz/plots/parallelplot.py +0 -204
  111. arviz/plots/plot_utils.py +0 -599
  112. arviz/plots/posteriorplot.py +0 -298
  113. arviz/plots/ppcplot.py +0 -369
  114. arviz/plots/rankplot.py +0 -232
  115. arviz/plots/separationplot.py +0 -167
  116. arviz/plots/styles/arviz-bluish.mplstyle +0 -1
  117. arviz/plots/styles/arviz-brownish.mplstyle +0 -1
  118. arviz/plots/styles/arviz-colors.mplstyle +0 -2
  119. arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
  120. arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
  121. arviz/plots/styles/arviz-doc.mplstyle +0 -88
  122. arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
  123. arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
  124. arviz/plots/styles/arviz-greenish.mplstyle +0 -1
  125. arviz/plots/styles/arviz-orangish.mplstyle +0 -1
  126. arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
  127. arviz/plots/styles/arviz-purplish.mplstyle +0 -1
  128. arviz/plots/styles/arviz-redish.mplstyle +0 -1
  129. arviz/plots/styles/arviz-royish.mplstyle +0 -1
  130. arviz/plots/styles/arviz-viridish.mplstyle +0 -1
  131. arviz/plots/styles/arviz-white.mplstyle +0 -40
  132. arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
  133. arviz/plots/traceplot.py +0 -273
  134. arviz/plots/tsplot.py +0 -440
  135. arviz/plots/violinplot.py +0 -192
  136. arviz/preview.py +0 -58
  137. arviz/py.typed +0 -0
  138. arviz/rcparams.py +0 -606
  139. arviz/sel_utils.py +0 -223
  140. arviz/static/css/style.css +0 -340
  141. arviz/static/html/icons-svg-inline.html +0 -15
  142. arviz/stats/__init__.py +0 -37
  143. arviz/stats/density_utils.py +0 -1013
  144. arviz/stats/diagnostics.py +0 -1013
  145. arviz/stats/ecdf_utils.py +0 -324
  146. arviz/stats/stats.py +0 -2422
  147. arviz/stats/stats_refitting.py +0 -119
  148. arviz/stats/stats_utils.py +0 -609
  149. arviz/tests/__init__.py +0 -1
  150. arviz/tests/base_tests/__init__.py +0 -1
  151. arviz/tests/base_tests/test_data.py +0 -1679
  152. arviz/tests/base_tests/test_data_zarr.py +0 -143
  153. arviz/tests/base_tests/test_diagnostics.py +0 -511
  154. arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
  155. arviz/tests/base_tests/test_helpers.py +0 -18
  156. arviz/tests/base_tests/test_labels.py +0 -69
  157. arviz/tests/base_tests/test_plot_utils.py +0 -342
  158. arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
  159. arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
  160. arviz/tests/base_tests/test_rcparams.py +0 -317
  161. arviz/tests/base_tests/test_stats.py +0 -925
  162. arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
  163. arviz/tests/base_tests/test_stats_numba.py +0 -45
  164. arviz/tests/base_tests/test_stats_utils.py +0 -384
  165. arviz/tests/base_tests/test_utils.py +0 -376
  166. arviz/tests/base_tests/test_utils_numba.py +0 -87
  167. arviz/tests/conftest.py +0 -46
  168. arviz/tests/external_tests/__init__.py +0 -1
  169. arviz/tests/external_tests/test_data_beanmachine.py +0 -78
  170. arviz/tests/external_tests/test_data_cmdstan.py +0 -398
  171. arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
  172. arviz/tests/external_tests/test_data_emcee.py +0 -166
  173. arviz/tests/external_tests/test_data_numpyro.py +0 -434
  174. arviz/tests/external_tests/test_data_pyjags.py +0 -119
  175. arviz/tests/external_tests/test_data_pyro.py +0 -260
  176. arviz/tests/external_tests/test_data_pystan.py +0 -307
  177. arviz/tests/helpers.py +0 -677
  178. arviz/utils.py +0 -773
  179. arviz/wrappers/__init__.py +0 -13
  180. arviz/wrappers/base.py +0 -236
  181. arviz/wrappers/wrap_pymc.py +0 -36
  182. arviz/wrappers/wrap_stan.py +0 -148
  183. arviz-0.23.3.dist-info/METADATA +0 -264
  184. arviz-0.23.3.dist-info/RECORD +0 -183
  185. arviz-0.23.3.dist-info/top_level.txt +0 -1
arviz/data/inference_data.py DELETED
@@ -1,2386 +0,0 @@
- # pylint: disable=too-many-lines,too-many-public-methods
- """Data structure for using netcdf groups with xarray."""
- import os
- import re
- import sys
- import uuid
- import warnings
- from collections import OrderedDict, defaultdict
- from collections.abc import MutableMapping, Sequence
- from copy import copy as ccopy
- from copy import deepcopy
- import datetime
- from html import escape
- from typing import (
-     TYPE_CHECKING,
-     Any,
-     Iterable,
-     Iterator,
-     List,
-     Mapping,
-     Optional,
-     Tuple,
-     TypeVar,
-     Union,
-     overload,
- )
-
- import numpy as np
- import xarray as xr
- from packaging import version
-
- from ..rcparams import rcParams
- from ..utils import HtmlTemplate, _subset_list, _var_names, either_dict_or_kwargs
- from .base import _extend_xr_method, _make_json_serializable, dict_to_dataset
-
- if sys.version_info[:2] >= (3, 9):
-     # As of 3.9, collections.abc types support generic parameters themselves.
-     from collections.abc import ItemsView, ValuesView
- else:
-     # These typing imports are deprecated in 3.9, and moved to collections.abc instead.
-     from typing import ItemsView, ValuesView
-
- if TYPE_CHECKING:
-     from typing_extensions import Literal
-
- try:
-     import ujson as json
- except ImportError:
-     # mypy struggles with conditional imports expressed as catching ImportError:
-     # https://github.com/python/mypy/issues/1153
-     import json  # type: ignore
-
-
- SUPPORTED_GROUPS = [
-     "posterior",
-     "posterior_predictive",
-     "predictions",
-     "log_likelihood",
-     "log_prior",
-     "sample_stats",
-     "prior",
-     "prior_predictive",
-     "sample_stats_prior",
-     "observed_data",
-     "constant_data",
-     "predictions_constant_data",
-     "unconstrained_posterior",
-     "unconstrained_prior",
- ]
-
- WARMUP_TAG = "warmup_"
-
- SUPPORTED_GROUPS_WARMUP = [
-     f"{WARMUP_TAG}posterior",
-     f"{WARMUP_TAG}posterior_predictive",
-     f"{WARMUP_TAG}predictions",
-     f"{WARMUP_TAG}sample_stats",
-     f"{WARMUP_TAG}log_likelihood",
-     f"{WARMUP_TAG}log_prior",
- ]
-
- SUPPORTED_GROUPS_ALL = SUPPORTED_GROUPS + SUPPORTED_GROUPS_WARMUP
-
- InferenceDataT = TypeVar("InferenceDataT", bound="InferenceData")
-
-
- def _compressible_dtype(dtype):
-     """Check basic dtypes for automatic compression."""
-     if dtype.kind == "V":
-         return all(_compressible_dtype(item) for item, _ in dtype.fields.values())
-     return dtype.kind in {"b", "i", "u", "f", "c", "S"}
-
-
- class InferenceData(Mapping[str, xr.Dataset]):
-     """Container for inference data storage using xarray.
-
-     For a detailed introduction to ``InferenceData`` objects and their usage, see
-     :ref:`xarray_for_arviz`. This page provides help and documentation
-     on ``InferenceData`` methods and their low level implementation.
-     """
-
-     def __init__(
-         self,
-         attrs: Union[None, Mapping[Any, Any]] = None,
-         warn_on_custom_groups: bool = False,
-         **kwargs: Union[xr.Dataset, List[xr.Dataset], Tuple[xr.Dataset, xr.Dataset]],
-     ) -> None:
-         """Initialize InferenceData object from keyword xarray datasets.
-
-         Parameters
-         ----------
-         attrs : dict
-             sets global attribute for InferenceData object.
-         warn_on_custom_groups : bool, default False
-             Emit a warning when custom groups are present in the InferenceData.
-             "custom group" means any group whose name isn't defined in :ref:`schema`
-         kwargs :
-             Keyword arguments of xarray datasets
-
-         Examples
-         --------
-         Initiate an InferenceData object from scratch; this is not recommended. InferenceData
-         objects should be initialized using ``from_xyz`` methods, see :ref:`data_api` for more
-         details.
-
-         .. ipython::
-
-             In [1]: import arviz as az
-                ...: import numpy as np
-                ...: import xarray as xr
-                ...: dataset = xr.Dataset(
-                ...:     {
-                ...:         "a": (["chain", "draw", "a_dim"], np.random.normal(size=(4, 100, 3))),
-                ...:         "b": (["chain", "draw"], np.random.normal(size=(4, 100))),
-                ...:     },
-                ...:     coords={
-                ...:         "chain": (["chain"], np.arange(4)),
-                ...:         "draw": (["draw"], np.arange(100)),
-                ...:         "a_dim": (["a_dim"], ["x", "y", "z"]),
-                ...:     }
-                ...: )
-                ...: idata = az.InferenceData(posterior=dataset, prior=dataset)
-                ...: idata
-
-         We have created an ``InferenceData`` object with two groups. Now we can check its
-         contents:
-
-         .. ipython::
-
-             In [1]: idata.posterior
-
-         """
-         self._groups: List[str] = []
-         self._groups_warmup: List[str] = []
-         self._attrs: Union[None, dict] = dict(attrs) if attrs is not None else None
-         save_warmup = kwargs.pop("save_warmup", False)
-         key_list = [key for key in SUPPORTED_GROUPS_ALL if key in kwargs]
-         for key in kwargs:
-             if key not in SUPPORTED_GROUPS_ALL:
-                 key_list.append(key)
-                 if warn_on_custom_groups:
-                     warnings.warn(
-                         f"{key} group is not defined in the InferenceData scheme", UserWarning
-                     )
-         for key in key_list:
-             dataset = kwargs[key]
-             dataset_warmup = None
-             if dataset is None:
-                 continue
-             elif isinstance(dataset, (list, tuple)):
-                 dataset, dataset_warmup = dataset
-             elif not isinstance(dataset, xr.Dataset):
-                 raise ValueError(
-                     "Arguments to InferenceData must be xarray Datasets "
-                     f"(argument '{key}' was type '{type(dataset)}')"
-                 )
-             if not key.startswith(WARMUP_TAG):
-                 if dataset:
-                     setattr(self, key, dataset)
-                     self._groups.append(key)
-             elif key.startswith(WARMUP_TAG):
-                 if dataset:
-                     setattr(self, key, dataset)
-                     self._groups_warmup.append(key)
-             if save_warmup and dataset_warmup is not None and dataset_warmup:
-                 key = f"{WARMUP_TAG}{key}"
-                 setattr(self, key, dataset_warmup)
-                 self._groups_warmup.append(key)
-
-     @property
-     def attrs(self) -> dict:
-         """Attributes of InferenceData object."""
-         if self._attrs is None:
-             self._attrs = {}
-         return self._attrs
-
-     @attrs.setter
-     def attrs(self, value) -> None:
-         self._attrs = dict(value)
-
-     def __repr__(self) -> str:
-         """Make string representation of InferenceData object."""
-         msg = "Inference data with groups:\n\t> {options}".format(
-             options="\n\t> ".join(self._groups)
-         )
-         if self._groups_warmup:
-             msg += f"\n\nWarmup iterations saved ({WARMUP_TAG}*)."
-         return msg
-
-     def _repr_html_(self) -> str:
-         """Make html representation of InferenceData object."""
-         try:
-             from xarray.core.options import OPTIONS
-
-             display_style = OPTIONS["display_style"]
-             if display_style == "text":
-                 html_repr = f"<pre>{escape(repr(self))}</pre>"
-             else:
-                 elements = "".join(
-                     [
-                         HtmlTemplate.element_template.format(
-                             group_id=group + str(uuid.uuid4()),
-                             group=group,
-                             xr_data=getattr(  # pylint: disable=protected-access
-                                 self, group
-                             )._repr_html_(),
-                         )
-                         for group in self._groups_all
-                     ]
-                 )
-                 formatted_html_template = (  # pylint: disable=possibly-unused-variable
-                     HtmlTemplate.html_template.format(elements)
-                 )
-                 css_template = HtmlTemplate.css_template  # pylint: disable=possibly-unused-variable
-                 html_repr = f"{locals()['formatted_html_template']}{locals()['css_template']}"
-         except:  # pylint: disable=bare-except
-             html_repr = f"<pre>{escape(repr(self))}</pre>"
-         return html_repr
-
-     def __delattr__(self, group: str) -> None:
-         """Delete a group from the InferenceData object."""
-         if group in self._groups:
-             self._groups.remove(group)
-         elif group in self._groups_warmup:
-             self._groups_warmup.remove(group)
-         object.__delattr__(self, group)
-
-     def __delitem__(self, key: str) -> None:
-         """Delete an item from the InferenceData object using del idata[key]."""
-         self.__delattr__(key)
-
-     @property
-     def _groups_all(self) -> List[str]:
-         return self._groups + self._groups_warmup
-
-     def __len__(self) -> int:
-         """Return the number of groups in this InferenceData object."""
-         return len(self._groups_all)
-
-     def __iter__(self) -> Iterator[str]:
-         """Iterate over groups in InferenceData object."""
-         yield from self._groups_all
-
-     def __contains__(self, key: object) -> bool:
-         """Return True if the named item is present, and False otherwise."""
-         return key in self._groups_all
-
-     def __getitem__(self, key: str) -> xr.Dataset:
-         """Get item by key."""
-         if key not in self._groups_all:
-             raise KeyError(key)
-         return getattr(self, key)
-
-     def __setitem__(self, key: str, value: xr.Dataset):
-         """Set item by key and update group list accordingly."""
-         if key.startswith(WARMUP_TAG):
-             self._groups_warmup.append(key)
-         else:
-             self._groups.append(key)
-         setattr(self, key, value)
-
-     def groups(self) -> List[str]:
-         """Return all groups present in InferenceData object."""
-         return self._groups_all
-
-     class InferenceDataValuesView(ValuesView[xr.Dataset]):
-         """ValuesView implementation for InferenceData, to allow it to implement Mapping."""
-
-         def __init__(  # pylint: disable=super-init-not-called
-             self, parent: "InferenceData"
-         ) -> None:
-             """Create a new InferenceDataValuesView from an InferenceData object."""
-             self.parent = parent
-
-         def __len__(self) -> int:
-             """Return the number of groups in the parent InferenceData."""
-             return len(self.parent._groups_all)
-
-         def __iter__(self) -> Iterator[xr.Dataset]:
-             """Iterate through the Xarray datasets present in the InferenceData object."""
-             parent = self.parent
-             for group in parent._groups_all:
-                 yield getattr(parent, group)
-
-         def __contains__(self, key: object) -> bool:
-             """Return True if the given Xarray dataset is one of the values, and False otherwise."""
-             if not isinstance(key, xr.Dataset):
-                 return False
-
-             for dataset in self:
-                 if dataset.equals(key):
-                     return True
-
-             return False
-
-     def values(self) -> "InferenceData.InferenceDataValuesView":
-         """Return a view over the Xarray Datasets present in the InferenceData object."""
-         return InferenceData.InferenceDataValuesView(self)
-
-     class InferenceDataItemsView(ItemsView[str, xr.Dataset]):
-         """ItemsView implementation for InferenceData, to allow it to implement Mapping."""
-
-         def __init__(  # pylint: disable=super-init-not-called
-             self, parent: "InferenceData"
-         ) -> None:
-             """Create a new InferenceDataItemsView from an InferenceData object."""
-             self.parent = parent
-
-         def __len__(self) -> int:
-             """Return the number of groups in the parent InferenceData."""
-             return len(self.parent._groups_all)
-
-         def __iter__(self) -> Iterator[Tuple[str, xr.Dataset]]:
-             """Iterate through the groups and corresponding Xarray datasets in the InferenceData."""
-             parent = self.parent
-             for group in parent._groups_all:
-                 yield group, getattr(parent, group)
-
-         def __contains__(self, key: object) -> bool:
-             """Return True if the (group, dataset) tuple is present, and False otherwise."""
-             parent = self.parent
-             if not isinstance(key, tuple) or len(key) != 2:
-                 return False
-
-             group, dataset = key
-             if group not in parent._groups_all:
-                 return False
-
-             if not isinstance(dataset, xr.Dataset):
-                 return False
-
-             existing_dataset = getattr(parent, group)
-             return existing_dataset.equals(dataset)
-
-     def items(self) -> "InferenceData.InferenceDataItemsView":
-         """Return a view over the groups and datasets present in the InferenceData object."""
-         return InferenceData.InferenceDataItemsView(self)
-
-     @staticmethod
-     def from_netcdf(
-         filename,
-         *,
-         engine="h5netcdf",
-         group_kwargs=None,
-         regex=False,
-         base_group: str = "/",
-     ) -> "InferenceData":
-         """Initialize object from a netcdf file.
-
-         Expects that the file will have groups, each of which can be loaded by xarray.
-         By default, the datasets of the InferenceData object will be lazily loaded instead
-         of being loaded into memory. This behaviour is regulated by the value of
-         ``az.rcParams["data.load"]``.
-
-         Parameters
-         ----------
-         filename : str
-             location of netcdf file
-         engine : {"h5netcdf", "netcdf4"}, default "h5netcdf"
-             Library used to read the netcdf file.
-         group_kwargs : dict of {str: dict}, optional
-             Keyword arguments to be passed into each call of :func:`xarray.open_dataset`.
-             The keys of the higher level should be group names or regex matching group
-             names, the inner dicts are passed to ``open_dataset``.
-             This feature is currently experimental.
-         regex : bool, default False
-             Specifies whether regex search should be used to extend the keyword arguments.
-             This feature is currently experimental.
-         base_group : str, default "/"
-             The group in the netCDF file where the InferenceData is stored. By default,
-             assumes that the file only contains an InferenceData object.
-
-         Returns
-         -------
-         InferenceData
-         """
-         groups = {}
-         attrs = {}
-
-         if engine == "h5netcdf":
-             import h5netcdf
-         elif engine == "netcdf4":
-             import netCDF4 as nc
-         else:
-             raise ValueError(
-                 f"Invalid value for engine: {engine}. Valid options are: h5netcdf or netcdf4"
-             )
-
-         try:
-             with (
-                 h5netcdf.File(filename, mode="r")
-                 if engine == "h5netcdf"
-                 else nc.Dataset(filename, mode="r")
-             ) as file_handle:
-                 if base_group == "/":
-                     data = file_handle
-                 else:
-                     data = file_handle[base_group]
-
-                 data_groups = list(data.groups)
-
-             for group in data_groups:
-                 group_kws = {}
-                 if group_kwargs is not None and regex is False:
-                     group_kws = group_kwargs.get(group, {})
-                 if group_kwargs is not None and regex is True:
-                     for key, kws in group_kwargs.items():
-                         if re.search(key, group):
-                             group_kws = kws
-                 group_kws.setdefault("engine", engine)
-                 data = xr.open_dataset(filename, group=f"{base_group}/{group}", **group_kws)
-                 if rcParams["data.load"] == "eager":
-                     with data:
-                         groups[group] = data.load()
-                 else:
-                     groups[group] = data
-
-             with xr.open_dataset(filename, engine=engine, group=base_group) as data:
-                 attrs.update(data.load().attrs)
-
-             return InferenceData(attrs=attrs, **groups)
-         except OSError as err:
-             if err.errno == -101:
-                 raise type(err)(
-                     str(err)
-                     + (
-                         " while reading a NetCDF file. This is probably an error in HDF5, "
-                         "which happens because your OS does not support HDF5 file locking. See "
-                         "https://stackoverflow.com/questions/49317927/"
-                         "errno-101-netcdf-hdf-error-when-opening-netcdf-file#49317928"
-                         " for a possible solution."
-                     )
-                 ) from err
-             raise err
-
-     def to_netcdf(
-         self,
-         filename: str,
-         compress: bool = True,
-         groups: Optional[List[str]] = None,
-         engine: str = "h5netcdf",
-         base_group: str = "/",
-         overwrite_existing: bool = True,
-     ) -> str:
-         """Write InferenceData to netcdf4 file.
-
-         Parameters
-         ----------
-         filename : str
-             Location to write to
-         compress : bool, optional
-             Whether to compress result. Note this saves disk space, but may make
-             saving and loading somewhat slower (default: True).
-         groups : list, optional
-             Write only these groups to netcdf file.
-         engine : {"h5netcdf", "netcdf4"}, default "h5netcdf"
-             Library used to write the netcdf file.
-         base_group : str, default "/"
-             The group in the netCDF file where the InferenceData will be stored.
-             By default, writes to the root of the netCDF file.
-         overwrite_existing : bool, default True
-             Whether to overwrite the existing file or append to it.
-
-         Returns
-         -------
-         str
-             Location of netcdf file
-         """
-         if base_group is None:
-             base_group = "/"
-
-         if os.path.exists(filename) and not overwrite_existing:
-             mode = "a"
-         else:
-             mode = "w"  # overwrite first, then append
-
-         if self._attrs:
-             xr.Dataset(attrs=self._attrs).to_netcdf(
-                 filename, mode=mode, engine=engine, group=base_group
-             )
-             mode = "a"
-
-         if self._groups_all:  # checks whether a group is present or not.
-             if groups is None:
-                 groups = self._groups_all
-             else:
-                 groups = [group for group in self._groups_all if group in groups]
-
-             for group in groups:
-                 data = getattr(self, group)
-                 kwargs = {"engine": engine}
-                 if compress:
-                     kwargs["encoding"] = {
-                         var_name: {"zlib": True}
-                         for var_name, values in data.variables.items()
-                         if _compressible_dtype(values.dtype)
-                     }
-                 data.to_netcdf(filename, mode=mode, group=f"{base_group}/{group}", **kwargs)
-                 data.close()
-                 mode = "a"
-         elif not self._attrs:  # creates a netcdf file for an empty InferenceData object.
-             if engine == "h5netcdf":
-                 import h5netcdf
-
-                 empty_netcdf_file = h5netcdf.File(filename, mode="w")
-             elif engine == "netcdf4":
-                 import netCDF4 as nc
-
-                 empty_netcdf_file = nc.Dataset(filename, mode="w", format="NETCDF4")
-             empty_netcdf_file.close()
-         return filename
-
-     def to_datatree(self):
-         """Convert InferenceData object to a :class:`~xarray.DataTree`."""
-         try:
-             from xarray import DataTree
-         except ImportError as err:
-             raise ImportError(
-                 "xarray must have DataTree in order to use InferenceData.to_datatree. "
-                 "Update to xarray>=2024.11.0"
-             ) from err
-         dt = DataTree.from_dict({group: ds for group, ds in self.items()})
-         dt.attrs = self.attrs
-         return dt
-
-     @staticmethod
-     def from_datatree(datatree):
-         """Create an InferenceData object from a :class:`~xarray.DataTree`.
-
-         Parameters
-         ----------
-         datatree : DataTree
-         """
-         return InferenceData(
-             attrs=datatree.attrs,
-             **{group: child.to_dataset() for group, child in datatree.children.items()},
-         )
-
-     def to_dict(self, groups=None, filter_groups=None):
-         """Convert InferenceData to a dictionary following xarray naming conventions.
-
-         Parameters
-         ----------
-         groups : list, optional
-             Groups where the transformation is to be applied. Can either be group names
-             or metagroup names.
-         filter_groups : {None, "like", "regex"}, optional, default=None
-             If `None` (default), interpret groups as the real group or metagroup names.
-             If "like", interpret groups as substrings of the real group or metagroup names.
-             If "regex", interpret groups as regular expressions on the real group or
-             metagroup names. A la `pandas.filter`.
-
-         Returns
-         -------
-         dict
-             A dictionary containing all groups of InferenceData object.
-         """
-         ret = defaultdict(dict)
-         if self._groups_all:  # checks whether a group is present or not.
-             if groups is None:
-                 groups = self._group_names(groups, filter_groups)
-             else:
-                 groups = [group for group in self._groups_all if group in groups]
-
-             for group in groups:
-                 dataset = getattr(self, group)
-                 data = {}
-                 for var_name, dataarray in dataset.items():
-                     data[var_name] = dataarray.values
-                     dims = []
-                     for coord_name, coord_values in dataarray.coords.items():
-                         if coord_name not in ("chain", "draw") and not coord_name.startswith(
-                             f"{var_name}_dim_"
-                         ):
-                             dims.append(coord_name)
-                             ret["coords"][coord_name] = coord_values.values
-
-                     if group in (
-                         "predictions",
-                         "predictions_constant_data",
-                     ):
-                         dims_key = "pred_dims"
-                     else:
-                         dims_key = "dims"
-                     if len(dims) > 0:
-                         ret[dims_key][var_name] = dims
-                 ret[group] = data
-                 ret[f"{group}_attrs"] = dataset.attrs
-
-         ret["attrs"] = self.attrs
-         return ret
-
-     def to_json(self, filename, groups=None, filter_groups=None, **kwargs):
-         """Write InferenceData to a json file.
-
-         Parameters
-         ----------
-         filename : str
-             Location to write to
-         groups : list, optional
-             Groups where the transformation is to be applied. Can either be group names
-             or metagroup names.
-         filter_groups : {None, "like", "regex"}, optional, default=None
-             If `None` (default), interpret groups as the real group or metagroup names.
-             If "like", interpret groups as substrings of the real group or metagroup names.
-             If "regex", interpret groups as regular expressions on the real group or
-             metagroup names. A la `pandas.filter`.
-         kwargs : dict
-             kwargs passed to json.dump()
-
-         Returns
-         -------
-         str
-             Location of json file
-         """
-         idata_dict = _make_json_serializable(
-             self.to_dict(groups=groups, filter_groups=filter_groups)
-         )
-
-         with open(filename, "w", encoding="utf8") as file:
-             json.dump(idata_dict, file, **kwargs)
-
-         return filename
-
-     def to_dataframe(
-         self,
-         groups=None,
-         filter_groups=None,
-         var_names=None,
-         filter_vars=None,
-         include_coords=True,
-         include_index=True,
-         index_origin=None,
-     ):
-         """Convert InferenceData to a :class:`pandas.DataFrame` following xarray naming conventions.
-
-         This returns a dataframe in "wide" format, where each item in an n-dimensional array is
-         unpacked. To access "tidy" format, use the xarray functionality found for each dataset.
-
-         When there are multiple groups, the function adds a group identifier to the variable name.
-
-         Data groups ("observed_data", "constant_data", "predictions_constant_data") are
-         skipped implicitly.
-
-         Raises TypeError if no valid groups are found.
-         Raises ValueError if no data are selected.
-
-         Parameters
-         ----------
-         groups : str or list of str, optional
-             Groups where the transformation is to be applied. Can either be group names
-             or metagroup names.
-         filter_groups : {None, "like", "regex"}, optional, default=None
-             If `None` (default), interpret groups as the real group or metagroup names.
-             If "like", interpret groups as substrings of the real group or metagroup names.
-             If "regex", interpret groups as regular expressions on the real group or
-             metagroup names. A la `pandas.filter`.
-         var_names : str or list of str, optional
-             Variables to be extracted. Prefix the variables by `~` when you want to exclude them.
-         filter_vars : {None, "like", "regex"}, optional
-             If `None` (default), interpret var_names as the real variables names. If "like",
-             interpret var_names as substrings of the real variables names. If "regex",
-             interpret var_names as regular expressions on the real variables names. A la
-             `pandas.filter`.
-             Like with plotting, sometimes it's easier to subset saying what to exclude
-             instead of what to include.
-         include_coords : bool
-             Add coordinate values to column name (tuple).
-         include_index : bool
-             Add index information for multidimensional arrays.
-         index_origin : {0, 1}, optional
-             Starting index for multidimensional objects. 0- or 1-based.
-             Defaults to rcParams["data.index_origin"].
-
-         Returns
-         -------
-         pandas.DataFrame
-             A pandas DataFrame containing all selected groups of InferenceData object.
-         """
-         # pylint: disable=too-many-nested-blocks
-         if not include_coords and not include_index:
-             raise TypeError("Both include_coords and include_index can not be False.")
-         if index_origin is None:
-             index_origin = rcParams["data.index_origin"]
-         if index_origin not in [0, 1]:
-             raise TypeError(f"index_origin must be 0 or 1, saw {index_origin}")
-
-         group_names = list(
-             filter(lambda x: "data" not in x, self._group_names(groups, filter_groups))
-         )
-
-         if not group_names:
-             raise TypeError(f"No valid groups found: {groups}")
-
-         dfs = {}
-         for group in group_names:
-             dataset = self[group]
-             group_var_names = _var_names(var_names, dataset, filter_vars, "ignore")
-             if (group_var_names is not None) and not group_var_names:
-                 continue
-             if group_var_names is not None:
-                 dataset = dataset[[var_name for var_name in group_var_names if var_name in dataset]]
-             df = None
-             coords_to_idx = {
-                 name: dict(map(reversed, enumerate(dataset.coords[name].values, index_origin)))
-                 for name in list(filter(lambda x: x not in ("chain", "draw"), dataset.coords))
-             }
-             for data_array in dataset.values():
-                 dataframe = data_array.to_dataframe()
-                 if list(filter(lambda x: x not in ("chain", "draw"), data_array.dims)):
-                     levels = [
-                         idx
-                         for idx, dim in enumerate(data_array.dims)
-                         if dim not in ("chain", "draw")
-                     ]
-                     dataframe = dataframe.unstack(level=levels)
-                     tuple_columns = []
-                     for name, *coords in dataframe.columns:
-                         if include_index:
-                             idxs = []
-                             for coordname, coorditem in zip(dataframe.columns.names[1:], coords):
-                                 idxs.append(coords_to_idx[coordname][coorditem])
-                             if include_coords:
-                                 tuple_columns.append(
-                                     (f"{name}[{','.join(map(str, idxs))}]", *coords)
-                                 )
-                             else:
-                                 tuple_columns.append(f"{name}[{','.join(map(str, idxs))}]")
-                         else:
-                             tuple_columns.append((name, *coords))
-
-                     dataframe.columns = tuple_columns
-                     dataframe.sort_index(axis=1, inplace=True)
-                 if df is None:
-                     df = dataframe
-                     continue
-                 df = df.join(dataframe, how="outer")
-             if df is not None:
-                 df = df.reset_index()
-                 dfs[group] = df
-         if not dfs:
-             raise ValueError("No data selected for the dataframe.")
-         if len(dfs) > 1:
-             for group, df in dfs.items():
-                 df.columns = [
-                     (
-                         col
-                         if col in ("draw", "chain")
-                         else (group, *col) if isinstance(col, tuple) else (group, col)
-                     )
-                     for col in df.columns
-                 ]
-             dfs, *dfs_tail = list(dfs.values())
-             for df in dfs_tail:
-                 dfs = dfs.merge(df, how="outer", copy=False)
-         else:
-             (dfs,) = dfs.values()  # pylint: disable=unbalanced-dict-unpacking
-         return dfs
-
-     def to_zarr(self, store=None):
-         """Convert InferenceData to a :class:`zarr.hierarchy.Group`.
-
-         The zarr storage uses the same group names as the InferenceData.
-
-         Raises
-         ------
-         TypeError
-             If no valid store is found.
-
-         Parameters
-         ----------
-         store : zarr.storage (i.e. MutableMapping) or str, optional
-             Zarr storage class or path to desired DirectoryStore.
-
-         Returns
-         -------
-         zarr.hierarchy.group
-             A zarr hierarchy group containing the InferenceData.
-
-         References
-         ----------
-         https://zarr.readthedocs.io/
-         """
-         try:
-             import zarr
-         except ImportError as err:
-             raise ImportError("'to_zarr' method needs Zarr (>=2.5.0,<3) installed.") from err
-         if version.parse(zarr.__version__) < version.parse("2.5.0"):
-             raise ImportError(
-                 "Found zarr<2.5.0, please upgrade to a zarr (>=2.5.0,<3) to use 'to_zarr'"
-             )
-         if version.parse(zarr.__version__) >= version.parse("3.0.0.dev0"):
-             raise ImportError(
-                 "Found zarr>=3, which is not supported by ArviZ. Instead, you can use "
-                 "'dt = InferenceData.to_datatree' followed by 'dt.to_zarr()' "
-                 "(needs xarray>=2024.11.0)"
-             )
-
-         # Check store type and create store if necessary
-         if store is None:
-             store = zarr.storage.TempStore(suffix="arviz")
-         elif isinstance(store, str):
-             store = zarr.storage.DirectoryStore(path=store)
-         elif not isinstance(store, MutableMapping):
-             raise TypeError(f"No valid store found: {store}")
-
-         groups = self.groups()
-
-         if not groups:
-             raise TypeError("No valid groups found!")
-
-         # order matters here, saving attrs after the groups will erase the groups.
-         if self.attrs:
-             xr.Dataset(attrs=self.attrs).to_zarr(store=store, mode="w")
-
-         for group in groups:
-             # Create zarr group in store with same group name
-             getattr(self, group).to_zarr(store=store, group=group, mode="w")
-
-         return zarr.open(store)  # Open store to get overarching group
-
-     @staticmethod
-     def from_zarr(store) -> "InferenceData":
-         """Initialize object from a zarr store or path.
-
-         Expects that the zarr store will have groups, each of which can be loaded by xarray.
-         By default, the datasets of the InferenceData object will be lazily loaded instead
-         of being loaded into memory. This behaviour is regulated by the value of
-         ``az.rcParams["data.load"]``.
-
-         Parameters
-         ----------
-         store : MutableMapping or zarr.hierarchy.Group or str
-             Zarr storage class or path to desired Store.
-
-         Returns
-         -------
-         InferenceData object
-
-         References
-         ----------
-         https://zarr.readthedocs.io/
-         """
-         try:
-             import zarr
-         except ImportError as err:
-             raise ImportError("'from_zarr' method needs Zarr (>=2.5.0,<3) installed.") from err
-         if version.parse(zarr.__version__) < version.parse("2.5.0"):
-             raise ImportError(
-                 "Found zarr<2.5.0, please upgrade to a zarr (>=2.5.0,<3) to use 'from_zarr'"
-             )
-         if version.parse(zarr.__version__) >= version.parse("3.0.0.dev0"):
-             raise ImportError(
-                 "Found zarr>=3, which is not supported by ArviZ. Instead, you can use "
-                 "'xarray.open_datatree' followed by 'arviz.InferenceData.from_datatree' "
-                 "(needs xarray>=2024.11.0)"
-             )
-
-         # Check store type and create store if necessary
-         if isinstance(store, str):
-             store = zarr.storage.DirectoryStore(path=store)
-         elif isinstance(store, zarr.hierarchy.Group):
-             store = store.store
-         elif not isinstance(store, MutableMapping):
-             raise TypeError(f"No valid store found: {store}")
-
-         groups = {}
-         zarr_handle = zarr.open(store, mode="r")
-
-         # Open each group via xarray method
-         for key_group, _ in zarr_handle.groups():
-             with xr.open_zarr(store=store, group=key_group) as data:
-                 groups[key_group] = data.load() if rcParams["data.load"] == "eager" else data
-
-         with xr.open_zarr(store=store) as root:
-             attrs = root.attrs
-
-         return InferenceData(attrs=attrs, **groups)
-
-     def __add__(self, other: "InferenceData") -> "InferenceData":
-         """Concatenate two InferenceData objects."""
-         return concat(self, other, copy=True, inplace=False)
-
-     def sel(
-         self: InferenceDataT,
-         groups: Optional[Union[str, List[str]]] = None,
-         filter_groups: Optional["Literal['like', 'regex']"] = None,
-         inplace: bool = False,
-         chain_prior: Optional[bool] = None,
-         **kwargs: Any,
-     ) -> Optional[InferenceDataT]:
-         """Perform an xarray selection on all groups.
-
-         Loops groups to perform Dataset.sel(key=item)
-         for every kwarg if key is a dimension of the dataset.
-         One example could be performing a burn in cut on the InferenceData object
-         or discarding a chain. The selection is performed on all relevant groups (like
-         posterior, prior, sample stats) while non relevant groups like observed data are
-         omitted. See :meth:`xarray.Dataset.sel <xarray:xarray.Dataset.sel>`
-
-         Parameters
-         ----------
-         groups : str or list of str, optional
-             Groups where the selection is to be applied. Can either be group names
-             or metagroup names.
-         filter_groups : {None, "like", "regex"}, optional, default=None
-             If `None` (default), interpret groups as the real group or metagroup names.
-             If "like", interpret groups as substrings of the real group or metagroup names.
-             If "regex", interpret groups as regular expressions on the real group or
-             metagroup names. A la `pandas.filter`.
-         inplace : bool, optional
-             If ``True``, modify the InferenceData object inplace,
-             otherwise, return the modified copy.
-         chain_prior : bool, optional, deprecated
-             If ``False``, do not select prior related groups using ``chain`` dim.
-             Otherwise, use selection on ``chain`` if present. Default=False
-         kwargs : dict, optional
-             It must be accepted by Dataset.sel().
-
-         Returns
-         -------
-         InferenceData
-             A new InferenceData object by default.
-             When `inplace==True` perform selection in-place and return `None`
-
-         Examples
-         --------
-         Use ``sel`` to discard one chain of the InferenceData object. We first check the
-         dimensions of the original object:
-
-         .. jupyter-execute::
-
-             import arviz as az
-             idata = az.load_arviz_data("centered_eight")
-             idata
-
-         In order to remove the third chain:
-
-         .. jupyter-execute::
-
-             idata_subset = idata.sel(chain=[0, 1, 3], groups="posterior_groups")
-             idata_subset
-
-         See Also
-         --------
-         xarray.Dataset.sel :
-             Returns a new dataset with each array indexed by tick labels along the specified
-             dimension(s).
-         isel : Returns a new dataset with each array indexed along the specified dimension(s).
-         """
-         if chain_prior is not None:
-             warnings.warn(
-                 "chain_prior has been deprecated. Use groups argument and "
-                 "rcParams['data.metagroups'] instead.",
-                 DeprecationWarning,
-             )
-         else:
-             chain_prior = False
-         group_names = self._group_names(groups, filter_groups)
-
-         out = self if inplace else deepcopy(self)
-         for group in group_names:
-             dataset = getattr(self, group)
-             valid_keys = set(kwargs.keys()).intersection(dataset.dims)
-             if not chain_prior and "prior" in group:
-                 valid_keys -= {"chain"}
-             dataset = dataset.sel(**{key: kwargs[key] for key in valid_keys})
-             setattr(out, group, dataset)
-         if inplace:
-             return None
-         else:
-             return out
-
-     def isel(
-         self: InferenceDataT,
-         groups: Optional[Union[str, List[str]]] = None,
-         filter_groups: Optional["Literal['like', 'regex']"] = None,
-         inplace: bool = False,
-         **kwargs: Any,
-     ) -> Optional[InferenceDataT]:
-         """Perform an xarray selection on all groups.
-
-         Loops groups to perform Dataset.isel(key=item)
-         for every kwarg if key is a dimension of the dataset.
-         One example could be performing a burn in cut on the InferenceData object
-         or discarding a chain. The selection is performed on all relevant groups (like
-         posterior, prior, sample stats) while non relevant groups like observed data are
-         omitted. See :meth:`xarray:xarray.Dataset.isel`
-
-         Parameters
-         ----------
-         groups : str or list of str, optional
-             Groups where the selection is to be applied. Can either be group names
-             or metagroup names.
-         filter_groups : {None, "like", "regex"}, optional
-             If `None` (default), interpret groups as the real group or metagroup names.
-             If "like", interpret groups as substrings of the real group or metagroup names.
-             If "regex", interpret groups as regular expressions on the real group or
-             metagroup names. A la `pandas.filter`.
-         inplace : bool, optional
-             If ``True``, modify the InferenceData object inplace,
-             otherwise, return the modified copy.
-         kwargs : dict, optional
-             It must be accepted by :meth:`xarray:xarray.Dataset.isel`.
-
-         Returns
-         -------
-         InferenceData
-             A new InferenceData object by default.
-             When `inplace==True` perform selection in-place and return `None`
-
-         Examples
-         --------
-         Use ``isel`` to discard one chain of the InferenceData object. We first check the
-         dimensions of the original object:
-
-         .. jupyter-execute::
-
-             import arviz as az
-             idata = az.load_arviz_data("centered_eight")
-             idata
-
-         In order to remove the third chain:
-
-         .. jupyter-execute::
-
-             idata_subset = idata.isel(chain=[0, 1, 3], groups="posterior_groups")
-             idata_subset
-
-         You can expand the groups and coords in each group to see how now only the chains 0, 1 and
-         3 are present.
-
-         See Also
-         --------
-         xarray.Dataset.isel :
-             Returns a new dataset with each array indexed along the specified dimension(s).
-         sel :
-             Returns a new dataset with each array indexed by tick labels along the specified
-             dimension(s).
-         """
-         group_names = self._group_names(groups, filter_groups)
-
-         out = self if inplace else deepcopy(self)
-         for group in group_names:
-             dataset = getattr(self, group)
-             valid_keys = set(kwargs.keys()).intersection(dataset.dims)
-             dataset = dataset.isel(**{key: kwargs[key] for key in valid_keys})
-             setattr(out, group, dataset)
-         if inplace:
-             return None
-         else:
-             return out
-
-     def stack(
-         self,
-         dimensions=None,
-         groups=None,
-         filter_groups=None,
-         inplace=False,
-         **kwargs,
-     ):
-         """Perform an xarray stacking on all groups.
-
-         Stack any number of existing dimensions into a single new dimension.
-         Loops groups to perform Dataset.stack(key=value)
-         for every kwarg if value is a dimension of the dataset.
-         The selection is performed on all relevant groups (like
-         posterior, prior, sample stats) while non relevant groups like observed data are
-         omitted. See :meth:`xarray:xarray.Dataset.stack`
-
-         Parameters
-         ----------
-         dimensions : dict, optional
-             Names of new dimensions, and the existing dimensions that they replace.
-         groups : str or list of str, optional
-             Groups where the selection is to be applied. Can either be group names
-             or metagroup names.
-         filter_groups : {None, "like", "regex"}, optional
-             If `None` (default), interpret groups as the real group or metagroup names.
-             If "like", interpret groups as substrings of the real group or metagroup names.
-             If "regex", interpret groups as regular expressions on the real group or
-             metagroup names. A la `pandas.filter`.
-         inplace : bool, optional
-             If ``True``, modify the InferenceData object inplace,
-             otherwise, return the modified copy.
-         kwargs : dict, optional
-             It must be accepted by :meth:`xarray:xarray.Dataset.stack`.
-
-         Returns
-         -------
-         InferenceData
-             A new InferenceData object by default.
-             When `inplace==True` perform selection in-place and return `None`
-
-         Examples
-         --------
-         Use ``stack`` to stack any number of existing dimensions into a single new dimension.
-         We first check the original object:
-
-         .. jupyter-execute::
-
-             import arviz as az
-             idata = az.load_arviz_data("rugby")
-             idata
-
-         In order to stack two dimensions ``chain`` and ``draw`` to ``sample``, we can use:
-
-         .. jupyter-execute::
-
-             idata.stack(sample=["chain", "draw"], inplace=True)
-             idata
-
-         We can also take the example of a custom InferenceData object and perform stacking. We
-         first check the original object:
-
-         .. jupyter-execute::
-
-             import numpy as np
-             datadict = {
-                 "a": np.random.randn(100),
-                 "b": np.random.randn(1, 100, 10),
-                 "c": np.random.randn(1, 100, 3, 4),
-             }
-             coords = {
-                 "c1": np.arange(3),
-                 "c99": np.arange(4),
-                 "b1": np.arange(10),
-             }
-             dims = {"c": ["c1", "c99"], "b": ["b1"]}
-             idata = az.from_dict(
-                 posterior=datadict, posterior_predictive=datadict, coords=coords, dims=dims
-             )
-             idata
-
-         In order to stack two dimensions ``c1`` and ``c99`` to ``z``, we can use:
-
-         .. jupyter-execute::
-
-             idata.stack(z=["c1", "c99"], inplace=True)
-             idata
-
-         See Also
-         --------
-         xarray.Dataset.stack : Stack any number of existing dimensions into a single new dimension.
-         unstack : Perform an xarray unstacking on all groups of InferenceData object.
-         """
-         groups = self._group_names(groups, filter_groups)
-
-         dimensions = {} if dimensions is None else dimensions
-         dimensions.update(kwargs)
-         out = self if inplace else deepcopy(self)
-         for group in groups:
-             dataset = getattr(self, group)
-             kwarg_dict = {}
-             for key, value in dimensions.items():
-                 try:
-                     if not set(value).difference(dataset.dims):
-                         kwarg_dict[key] = value
-                 except TypeError:
-                     kwarg_dict[key] = value
-             dataset = dataset.stack(**kwarg_dict)
-             setattr(out, group, dataset)
-         if inplace:
-             return None
-         else:
-             return out
-
-     def unstack(self, dim=None, groups=None, filter_groups=None, inplace=False):
-         """Perform an xarray unstacking on all groups.
-
-         Unstack existing dimensions corresponding to MultiIndexes into multiple new dimensions.
-         Loops groups to perform Dataset.unstack(key=value).
-         The selection is performed on all relevant groups (like posterior, prior,
-         sample stats) while non relevant groups like observed data are omitted.
-         See :meth:`xarray:xarray.Dataset.unstack`
-
-         Parameters
-         ----------
-         dim : Hashable or iterable of Hashable, optional
-             Dimension(s) over which to unstack. By default unstacks all MultiIndexes.
-         groups : str or list of str, optional
-             Groups where the selection is to be applied. Can either be group names
-             or metagroup names.
-         filter_groups : {None, "like", "regex"}, optional
-             If `None` (default), interpret groups as the real group or metagroup names.
-             If "like", interpret groups as substrings of the real group or metagroup names.
-             If "regex", interpret groups as regular expressions on the real group or
-             metagroup names. A la `pandas.filter`.
-         inplace : bool, optional
-             If ``True``, modify the InferenceData object inplace,
-             otherwise, return the modified copy.
-
-         Returns
-         -------
-         InferenceData
-             A new InferenceData object by default.
-             When `inplace==True` perform selection in place and return `None`
-
-         Examples
-         --------
-         Use ``unstack`` to unstack existing dimensions corresponding to MultiIndexes into
-         multiple new dimensions. We first stack two dimensions ``c1`` and ``c99`` to ``z``:
-
-         .. jupyter-execute::
-
-             import arviz as az
-             import numpy as np
-             datadict = {
-                 "a": np.random.randn(100),
-                 "b": np.random.randn(1, 100, 10),
-                 "c": np.random.randn(1, 100, 3, 4),
-             }
-             coords = {
-                 "c1": np.arange(3),
-                 "c99": np.arange(4),
-                 "b1": np.arange(10),
-             }
-             dims = {"c": ["c1", "c99"], "b": ["b1"]}
-             idata = az.from_dict(
-                 posterior=datadict, posterior_predictive=datadict, coords=coords, dims=dims
-             )
-             idata.stack(z=["c1", "c99"], inplace=True)
-             idata
-
-         In order to unstack the dimension ``z``, we use:
-
-         .. jupyter-execute::
-
-             idata.unstack(inplace=True)
-             idata
-
-         See Also
-         --------
-         xarray.Dataset.unstack :
-             Unstack existing dimensions corresponding to MultiIndexes into multiple new dimensions.
-         stack : Perform an xarray stacking on all groups of InferenceData object.
-         """
-         groups = self._group_names(groups, filter_groups)
-         if isinstance(dim, str):
-             dim = [dim]
-
-         out = self if inplace else deepcopy(self)
-         for group in groups:
-             dataset = getattr(self, group)
-             valid_dims = set(dim).intersection(dataset.dims) if dim is not None else dim
-             dataset = dataset.unstack(dim=valid_dims)
-             setattr(out, group, dataset)
-         if inplace:
-             return None
-         else:
-             return out
-
-     def rename(self, name_dict=None, groups=None, filter_groups=None, inplace=False):
-         """Perform xarray renaming of variables and dimensions on all groups.
-
-         Loops over groups to perform Dataset.rename(name_dict)
-         for every key in ``name_dict`` that is a dimension or data variable of the dataset.
-         The renaming is performed on all relevant groups (like
-         posterior, prior, sample stats) while non-relevant groups like observed data are
-         omitted. See :meth:`xarray:xarray.Dataset.rename`
-
-         Parameters
-         ----------
-         name_dict : dict
-             Dictionary whose keys are current variable or dimension names
-             and whose values are the desired names.
-         groups : str or list of str, optional
-             Groups where the selection is to be applied. Can either be group names
-             or metagroup names.
-         filter_groups : {None, "like", "regex"}, optional
-             If `None` (default), interpret groups as the real group or metagroup names.
-             If "like", interpret groups as substrings of the real group or metagroup names.
-             If "regex", interpret groups as regular expressions on the real group or
-             metagroup names. A la `pandas.filter`.
-         inplace : bool, optional
-             If ``True``, modify the InferenceData object inplace,
-             otherwise, return the modified copy.
-
-         Returns
-         -------
-         InferenceData
-             A new InferenceData object by default.
-             When ``inplace==True``, renaming is performed in place and ``None`` is returned.
-
-         Examples
-         --------
-         Use ``rename`` to rename variables and dimensions on all groups of the InferenceData
-         object. We first check the original object:
-
-         .. jupyter-execute::
-
-             import arviz as az
-             idata = az.load_arviz_data("rugby")
-             idata
-
-         In order to rename the dimensions and variables, we use:
-
-         .. jupyter-execute::
-
-             idata.rename({"team": "team_new", "match": "match_new"}, inplace=True)
-             idata
-
-         See Also
-         --------
-         xarray.Dataset.rename : Returns a new object with renamed variables and dimensions.
-         rename_vars :
-             Perform xarray renaming of variable or coordinate names on all groups of an
-             InferenceData object.
-         rename_dims : Perform xarray renaming of dimensions on all groups of an InferenceData object.
-         """
-         groups = self._group_names(groups, filter_groups)
-         if "chain" in name_dict.keys() or "draw" in name_dict.keys():
-             raise KeyError("'chain' or 'draw' dimensions can't be renamed")
-         out = self if inplace else deepcopy(self)
-
-         for group in groups:
-             dataset = getattr(self, group)
-             expected_keys = list(dataset.data_vars) + list(dataset.dims)
-             valid_keys = set(name_dict.keys()).intersection(expected_keys)
-             dataset = dataset.rename({key: name_dict[key] for key in valid_keys})
-             setattr(out, group, dataset)
-         if inplace:
-             return None
-         else:
-             return out
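One behaviour worth noting in the body above: ``rename`` intersects ``name_dict`` with each group's variables and dimensions, so keys absent from a group are skipped silently rather than raising. A minimal sketch (toy data, illustrative names)::

    import numpy as np
    import arviz as az

    idata = az.from_dict(
        posterior={"mu": np.random.randn(4, 100)},
        prior={"mu": np.random.randn(4, 100)},
    )
    # "mu" is renamed wherever it exists; "not_there" is ignored per group.
    out = idata.rename({"mu": "location", "not_there": "whatever"})
    assert "location" in out.posterior and "location" in out.prior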
-     def rename_vars(self, name_dict=None, groups=None, filter_groups=None, inplace=False):
-         """Perform xarray renaming of variable or coordinate names on all groups.
-
-         Loops over groups to perform Dataset.rename_vars(name_dict)
-         for every key in ``name_dict`` that is a variable or coordinate name of the dataset.
-         The renaming is performed on all relevant groups (like
-         posterior, prior, sample stats) while non-relevant groups like observed data are
-         omitted. See :meth:`xarray:xarray.Dataset.rename_vars`
-
-         Parameters
-         ----------
-         name_dict : dict
-             Dictionary whose keys are current variable or coordinate names
-             and whose values are the desired names.
-         groups : str or list of str, optional
-             Groups where the selection is to be applied. Can either be group names
-             or metagroup names.
-         filter_groups : {None, "like", "regex"}, optional
-             If `None` (default), interpret groups as the real group or metagroup names.
-             If "like", interpret groups as substrings of the real group or metagroup names.
-             If "regex", interpret groups as regular expressions on the real group or
-             metagroup names. A la `pandas.filter`.
-         inplace : bool, optional
-             If ``True``, modify the InferenceData object inplace,
-             otherwise, return the modified copy.
-
-         Returns
-         -------
-         InferenceData
-             A new InferenceData object with renamed variables including coordinates by default.
-             When ``inplace==True``, renaming is performed in place and ``None`` is returned.
-
-         Examples
-         --------
-         Use ``rename_vars`` to rename variables and coordinates on all groups of the
-         InferenceData object. We first check the data variables of the original object:
-
-         .. jupyter-execute::
-
-             import arviz as az
-             idata = az.load_arviz_data("rugby")
-             idata
-
-         In order to rename the data variables, we use:
-
-         .. jupyter-execute::
-
-             idata.rename_vars({"home": "home_new"}, inplace=True)
-             idata
-
-         See Also
-         --------
-         xarray.Dataset.rename_vars :
-             Returns a new object with renamed variables including coordinates.
-         rename :
-             Perform xarray renaming of variables and dimensions on all groups of an InferenceData
-             object.
-         rename_dims : Perform xarray renaming of dimensions on all groups of an InferenceData object.
-         """
-         groups = self._group_names(groups, filter_groups)
-
-         out = self if inplace else deepcopy(self)
-         for group in groups:
-             dataset = getattr(self, group)
-             valid_keys = set(name_dict.keys()).intersection(dataset.data_vars)
-             dataset = dataset.rename_vars({key: name_dict[key] for key in valid_keys})
-             setattr(out, group, dataset)
-         if inplace:
-             return None
-         else:
-             return out
-     def rename_dims(self, name_dict=None, groups=None, filter_groups=None, inplace=False):
-         """Perform xarray renaming of dimensions on all groups.
-
-         Loops over groups to perform Dataset.rename_dims(name_dict)
-         for every key in ``name_dict`` that is a dimension of the dataset.
-         The renaming is performed on all relevant groups (like
-         posterior, prior, sample stats) while non-relevant groups like observed data are
-         omitted. See :meth:`xarray:xarray.Dataset.rename_dims`
-
-         Parameters
-         ----------
-         name_dict : dict
-             Dictionary whose keys are current dimension names and whose values are the desired
-             names.
-         groups : str or list of str, optional
-             Groups where the selection is to be applied. Can either be group names
-             or metagroup names.
-         filter_groups : {None, "like", "regex"}, optional
-             If `None` (default), interpret groups as the real group or metagroup names.
-             If "like", interpret groups as substrings of the real group or metagroup names.
-             If "regex", interpret groups as regular expressions on the real group or
-             metagroup names. A la `pandas.filter`.
-         inplace : bool, optional
-             If ``True``, modify the InferenceData object inplace,
-             otherwise, return the modified copy.
-
-         Returns
-         -------
-         InferenceData
-             A new InferenceData object with renamed dimensions by default.
-             When ``inplace==True``, renaming is performed in place and ``None`` is returned.
-
-         Examples
-         --------
-         Use ``rename_dims`` to rename dimensions on all groups of the InferenceData
-         object. We first check the dimensions of the original object:
-
-         .. jupyter-execute::
-
-             import arviz as az
-             idata = az.load_arviz_data("rugby")
-             idata
-
-         In order to rename the dimensions, we use:
-
-         .. jupyter-execute::
-
-             idata.rename_dims({"team": "team_new"}, inplace=True)
-             idata
-
-         See Also
-         --------
-         xarray.Dataset.rename_dims : Returns a new object with renamed dimensions only.
-         rename :
-             Perform xarray renaming of variables and dimensions on all groups of an InferenceData
-             object.
-         rename_vars :
-             Perform xarray renaming of variable or coordinate names on all groups of an
-             InferenceData object.
-         """
-         groups = self._group_names(groups, filter_groups)
-         if "chain" in name_dict.keys() or "draw" in name_dict.keys():
-             raise KeyError("'chain' or 'draw' dimensions can't be renamed")
-
-         out = self if inplace else deepcopy(self)
-         for group in groups:
-             dataset = getattr(self, group)
-             valid_keys = set(name_dict.keys()).intersection(dataset.dims)
-             dataset = dataset.rename_dims({key: name_dict[key] for key in valid_keys})
-             setattr(out, group, dataset)
-         if inplace:
-             return None
-         else:
-             return out
-     def add_groups(
-         self, group_dict=None, coords=None, dims=None, warn_on_custom_groups=False, **kwargs
-     ):
-         """Add new groups to an InferenceData object.
-
-         Parameters
-         ----------
-         group_dict : dict of {str : dict or xarray.Dataset}, optional
-             Groups to be added.
-         coords : dict of {str : array_like}, optional
-             Coordinates for the dataset.
-         dims : dict of {str : list of str}, optional
-             Dimensions of each variable. The keys are variable names, values are lists of
-             coordinates.
-         warn_on_custom_groups : bool, default False
-             Emit a warning when custom groups are present in the InferenceData.
-             "custom group" means any group whose name isn't defined in :ref:`schema`.
-         kwargs : dict, optional
-             The keyword arguments form of group_dict. One of group_dict or kwargs must be provided.
-
-         Examples
-         --------
-         Add a ``log_likelihood`` group to the "rugby" example InferenceData after loading.
-
-         .. jupyter-execute::
-
-             import arviz as az
-             idata = az.load_arviz_data("rugby")
-             del idata.log_likelihood
-             idata2 = idata.copy()
-             post = idata.posterior
-             obs = idata.observed_data
-             idata
-
-         Knowing the model, we can compute it manually. In this case however,
-         we will generate random samples with the right shape.
-
-         .. jupyter-execute::
-
-             import numpy as np
-             rng = np.random.default_rng(73)
-             ary = rng.normal(size=(post.sizes["chain"], post.sizes["draw"], obs.sizes["match"]))
-             idata.add_groups(
-                 log_likelihood={"home_points": ary},
-                 dims={"home_points": ["match"]},
-             )
-             idata
-
-         This is fine if we have raw data, but a bit inconvenient if we start with labeled
-         data already. Why provide dims and coords manually again?
-         Let's generate a fake log likelihood (it doesn't match the model, but it serves
-         the same illustration purpose here), working from the posterior and
-         observed_data groups manually:
-
-         .. jupyter-execute::
-
-             import xarray as xr
-             from xarray_einstats.stats import XrDiscreteRV
-             from scipy.stats import poisson
-             dist = XrDiscreteRV(poisson, np.exp(post["atts"]))
-             log_lik = dist.logpmf(obs["home_points"]).to_dataset(name="home_points")
-             idata2.add_groups({"log_likelihood": log_lik})
-             idata2
-
-         Note that in the first example we have used the ``kwargs`` argument
-         and in the second we have used the ``group_dict`` one.
-
-         See Also
-         --------
-         extend : Extend InferenceData with groups from another InferenceData.
-         concat : Concatenate InferenceData objects.
-         """
-         group_dict = either_dict_or_kwargs(group_dict, kwargs, "add_groups")
-         if not group_dict:
-             raise ValueError("One of group_dict or kwargs must be provided.")
-         repeated_groups = [group for group in group_dict.keys() if group in self._groups]
-         if repeated_groups:
-             raise ValueError(f"{repeated_groups} group(s) already exists.")
-         for group, dataset in group_dict.items():
-             if warn_on_custom_groups and group not in SUPPORTED_GROUPS_ALL:
-                 warnings.warn(
-                     f"The group {group} is not defined in the InferenceData scheme",
-                     UserWarning,
-                 )
-             if dataset is None:
-                 continue
-             elif isinstance(dataset, dict):
-                 if (
-                     group in ("observed_data", "constant_data", "predictions_constant_data")
-                     or group not in SUPPORTED_GROUPS_ALL
-                 ):
-                     warnings.warn(
-                         "the default dims 'chain' and 'draw' will be added automatically",
-                         UserWarning,
-                     )
-                 dataset = dict_to_dataset(dataset, coords=coords, dims=dims)
-             elif isinstance(dataset, xr.DataArray):
-                 if dataset.name is None:
-                     dataset.name = "x"
-                 dataset = dataset.to_dataset()
-             elif not isinstance(dataset, xr.Dataset):
-                 raise ValueError(
-                     "Arguments to add_groups() must be xr.Dataset, xr.DataArray or dicts "
-                     f"(argument '{group}' was type '{type(dataset)}')"
-                 )
-             if dataset:
-                 setattr(self, group, dataset)
-                 if group.startswith(WARMUP_TAG):
-                     supported_order = [
-                         key for key in SUPPORTED_GROUPS_ALL if key in self._groups_warmup
-                     ]
-                     if (supported_order == self._groups_warmup) and (group in SUPPORTED_GROUPS_ALL):
-                         group_order = [
-                             key
-                             for key in SUPPORTED_GROUPS_ALL
-                             if key in self._groups_warmup + [group]
-                         ]
-                         group_idx = group_order.index(group)
-                         self._groups_warmup.insert(group_idx, group)
-                     else:
-                         self._groups_warmup.append(group)
-                 else:
-                     supported_order = [key for key in SUPPORTED_GROUPS_ALL if key in self._groups]
-                     if (supported_order == self._groups) and (group in SUPPORTED_GROUPS_ALL):
-                         group_order = [
-                             key for key in SUPPORTED_GROUPS_ALL if key in self._groups + [group]
-                         ]
-                         group_idx = group_order.index(group)
-                         self._groups.insert(group_idx, group)
-                     else:
-                         self._groups.append(group)
-     def extend(self, other, join="left", warn_on_custom_groups=False):
-         """Extend InferenceData with groups from another InferenceData.
-
-         Parameters
-         ----------
-         other : InferenceData
-             InferenceData to be added.
-         join : {'left', 'right'}, default 'left'
-             Defines which group to keep when the same group is present in both objects.
-             'left' will discard the group in ``other`` whereas 'right' will keep the
-             group in ``other`` and discard the one in ``self``.
-         warn_on_custom_groups : bool, default False
-             Emit a warning when custom groups are present in the InferenceData.
-             "custom group" means any group whose name isn't defined in :ref:`schema`.
-
-         Examples
-         --------
-         Take two InferenceData objects, and extend the first with the groups it doesn't have
-         but that are present in the second InferenceData object.
-
-         First InferenceData:
-
-         .. jupyter-execute::
-
-             import arviz as az
-             idata = az.load_arviz_data("radon")
-
-         Second InferenceData:
-
-         .. jupyter-execute::
-
-             other_idata = az.load_arviz_data("rugby")
-
-         Call the ``extend`` method:
-
-         .. jupyter-execute::
-
-             idata.extend(other_idata)
-             idata
-
-         Note that the first InferenceData now has more groups, taken from the second one,
-         while the groups it originally had are unmodified, even when also present in the
-         second InferenceData.
-
-         See Also
-         --------
-         add_groups : Add new groups to an InferenceData object.
-         concat : Concatenate InferenceData objects.
-
-         """
-         if not isinstance(other, InferenceData):
-             raise ValueError("Extending is possible between two InferenceData objects only.")
-         if join not in ("left", "right"):
-             raise ValueError(f"join must be either 'left' or 'right', found {join}")
-         for group in other._groups_all:  # pylint: disable=protected-access
-             if hasattr(self, group) and join == "left":
-                 continue
-             if warn_on_custom_groups and group not in SUPPORTED_GROUPS_ALL:
-                 warnings.warn(
-                     f"{group} group is not defined in the InferenceData scheme", UserWarning
-                 )
-             dataset = getattr(other, group)
-             setattr(self, group, dataset)
-             if group.startswith(WARMUP_TAG):
-                 if group not in self._groups_warmup:
-                     supported_order = [
-                         key for key in SUPPORTED_GROUPS_ALL if key in self._groups_warmup
-                     ]
-                     if (supported_order == self._groups_warmup) and (group in SUPPORTED_GROUPS_ALL):
-                         group_order = [
-                             key
-                             for key in SUPPORTED_GROUPS_ALL
-                             if key in self._groups_warmup + [group]
-                         ]
-                         group_idx = group_order.index(group)
-                         self._groups_warmup.insert(group_idx, group)
-                     else:
-                         self._groups_warmup.append(group)
-             elif group not in self._groups:
-                 supported_order = [key for key in SUPPORTED_GROUPS_ALL if key in self._groups]
-                 if (supported_order == self._groups) and (group in SUPPORTED_GROUPS_ALL):
-                     group_order = [
-                         key for key in SUPPORTED_GROUPS_ALL if key in self._groups + [group]
-                     ]
-                     group_idx = group_order.index(group)
-                     self._groups.insert(group_idx, group)
-                 else:
-                     self._groups.append(group)
-     set_index = _extend_xr_method(xr.Dataset.set_index, see_also="reset_index")
-     get_index = _extend_xr_method(xr.Dataset.get_index)
-     reset_index = _extend_xr_method(xr.Dataset.reset_index, see_also="set_index")
-     set_coords = _extend_xr_method(xr.Dataset.set_coords, see_also="reset_coords")
-     reset_coords = _extend_xr_method(xr.Dataset.reset_coords, see_also="set_coords")
-     assign = _extend_xr_method(xr.Dataset.assign)
-     assign_coords = _extend_xr_method(xr.Dataset.assign_coords)
-     sortby = _extend_xr_method(xr.Dataset.sortby)
-     chunk = _extend_xr_method(xr.Dataset.chunk)
-     unify_chunks = _extend_xr_method(xr.Dataset.unify_chunks)
-     load = _extend_xr_method(xr.Dataset.load)
-     compute = _extend_xr_method(xr.Dataset.compute)
-     persist = _extend_xr_method(xr.Dataset.persist)
-     quantile = _extend_xr_method(xr.Dataset.quantile)
-     close = _extend_xr_method(xr.Dataset.close)
-
-     # The following lines use methods on xr.Dataset that are dynamically defined and attached.
-     # As a result mypy cannot see them, so we have to suppress the resulting mypy errors.
-     mean = _extend_xr_method(xr.Dataset.mean, see_also="median")  # type: ignore[attr-defined]
-     median = _extend_xr_method(xr.Dataset.median, see_also="mean")  # type: ignore[attr-defined]
-     min = _extend_xr_method(xr.Dataset.min, see_also=["max", "sum"])  # type: ignore[attr-defined]
-     max = _extend_xr_method(xr.Dataset.max, see_also=["min", "sum"])  # type: ignore[attr-defined]
-     cumsum = _extend_xr_method(xr.Dataset.cumsum, see_also="sum")  # type: ignore[attr-defined]
-     sum = _extend_xr_method(xr.Dataset.sum, see_also="cumsum")  # type: ignore[attr-defined]
-     def _group_names(
-         self,
-         groups: Optional[Union[str, List[str]]],
-         filter_groups: Optional["Literal['like', 'regex']"] = None,
-     ) -> List[str]:
-         """Handle expansion of group names input across arviz.
-
-         Parameters
-         ----------
-         groups: str, list of str or None
-             Group or metagroup names.
-         filter_groups: {None, "like", "regex"}, optional, default=None
-             If `None` (default), interpret groups as the real group or metagroup names.
-             If "like", interpret groups as substrings of the real group or metagroup names.
-             If "regex", interpret groups as regular expressions on the real group or
-             metagroup names. A la `pandas.filter`.
-
-         Returns
-         -------
-         groups: list
-         """
-         if filter_groups not in {None, "like", "regex"}:
-             raise ValueError(
-                 f"'filter_groups' can only be None, 'like', or 'regex', got: '{filter_groups}'"
-             )
-
-         all_groups = self._groups_all
-         if groups is None:
-             return all_groups
-         if isinstance(groups, str):
-             groups = [groups]
-         sel_groups = []
-         metagroups = rcParams["data.metagroups"]
-         for group in groups:
-             if group[0] == "~":
-                 sel_groups.extend(
-                     [f"~{item}" for item in metagroups[group[1:]] if item in all_groups]
-                     if group[1:] in metagroups
-                     else [group]
-                 )
-             else:
-                 sel_groups.extend(
-                     [item for item in metagroups[group] if item in all_groups]
-                     if group in metagroups
-                     else [group]
-                 )
-
-         try:
-             group_names = _subset_list(sel_groups, all_groups, filter_items=filter_groups)
-         except KeyError as err:
-             msg = " ".join(("groups:", f"{err}", "in InferenceData"))
-             raise KeyError(msg) from err
-         return group_names
-     def map(self, fun, groups=None, filter_groups=None, inplace=False, args=None, **kwargs):
-         """Apply a function to multiple groups.
-
-         Applies ``fun`` groupwise to the selected ``InferenceData`` groups and overwrites the
-         group with the result of the function.
-
-         Parameters
-         ----------
-         fun : callable
-             Function to be applied to each group. Assumes the function is called as
-             ``fun(dataset, *args, **kwargs)``.
-         groups : str or list of str, optional
-             Groups where the selection is to be applied. Can either be group names
-             or metagroup names.
-         filter_groups : {None, "like", "regex"}, optional
-             If `None` (default), interpret groups as the real group or metagroup names.
-             If "like", interpret groups as substrings of the real group or metagroup names.
-             If "regex", interpret groups as regular expressions on the real group or
-             metagroup names. A la `pandas.filter`.
-         inplace : bool, optional
-             If ``True``, modify the InferenceData object inplace,
-             otherwise, return the modified copy.
-         args : array_like, optional
-             Positional arguments passed to ``fun``.
-         **kwargs : mapping, optional
-             Keyword arguments passed to ``fun``.
-
-         Returns
-         -------
-         InferenceData
-             A new InferenceData object by default.
-             When ``inplace==True``, the mapping is performed in place and ``None`` is returned.
-
-         Examples
-         --------
-         Shift observed_data, prior_predictive and posterior_predictive.
-
-         .. jupyter-execute::
-
-             import arviz as az
-             import numpy as np
-             idata = az.load_arviz_data("non_centered_eight")
-             idata_shifted_obs = idata.map(lambda x: x + 3, groups="observed_vars")
-             idata_shifted_obs
-
-         Rename and update the coordinate values in both posterior and prior groups.
-
-         .. jupyter-execute::
-
-             idata = az.load_arviz_data("radon")
-             idata = idata.map(
-                 lambda ds: ds.rename({"g_coef": "uranium_coefs"}).assign(
-                     uranium_coefs=["intercept", "u_slope"]
-                 ),
-                 groups=["posterior", "prior"]
-             )
-             idata
-
-         Add extra coordinates to all groups containing observed variables:
-
-         .. jupyter-execute::
-
-             idata = az.load_arviz_data("rugby")
-             home_team, away_team = np.array([
-                 m.split() for m in idata.observed_data.match.values
-             ]).T
-             idata = idata.map(
-                 lambda ds, **kwargs: ds.assign_coords(**kwargs),
-                 groups="observed_vars",
-                 home_team=("match", home_team),
-                 away_team=("match", away_team),
-             )
-             idata
-
-         """
-         if args is None:
-             args = []
-         groups = self._group_names(groups, filter_groups)
-
-         out = self if inplace else deepcopy(self)
-         for group in groups:
-             dataset = getattr(self, group)
-             dataset = fun(dataset, *args, **kwargs)
-             setattr(out, group, dataset)
-         if inplace:
-             return None
-         else:
-             return out
-     def _wrap_xarray_method(
-         self, method, groups=None, filter_groups=None, inplace=False, args=None, **kwargs
-     ):
-         """Extend an xarray.Dataset method to the InferenceData object.
-
-         Parameters
-         ----------
-         method: str
-             Method to be extended. Must be a ``xarray.Dataset`` method.
-         groups: str or list of str, optional
-             Groups where the selection is to be applied. Can either be group names
-             or metagroup names.
-         inplace: bool, optional
-             If ``True``, modify the InferenceData object inplace,
-             otherwise, return the modified copy.
-         **kwargs: mapping, optional
-             Keyword arguments passed to the xarray Dataset method.
-
-         Returns
-         -------
-         InferenceData
-             A new InferenceData object by default.
-             When ``inplace==True``, the method is applied in place and ``None`` is returned.
-
-         Examples
-         --------
-         Compute the mean of the groups in the ``latent_vars`` metagroup:
-
-         .. ipython::
-
-             In [1]: import arviz as az
-                ...: idata = az.load_arviz_data("non_centered_eight")
-                ...: idata_means = idata._wrap_xarray_method("mean", groups="latent_vars")
-                ...: print(idata_means.posterior)
-                ...: print(idata_means.observed_data)
-
-         .. ipython::
-
-             In [1]: idata_stack = idata._wrap_xarray_method(
-                ...:     "stack",
-                ...:     groups=["posterior_groups", "prior_groups"],
-                ...:     sample=["chain", "draw"]
-                ...: )
-                ...: print(idata_stack.posterior)
-                ...: print(idata_stack.prior)
-                ...: print(idata_stack.observed_data)
-
-         """
-         if args is None:
-             args = []
-         groups = self._group_names(groups, filter_groups)
-
-         method = getattr(xr.Dataset, method)
-
-         out = self if inplace else deepcopy(self)
-         for group in groups:
-             dataset = getattr(self, group)
-             dataset = method(dataset, *args, **kwargs)
-             setattr(out, group, dataset)
-         if inplace:
-             return None
-         else:
-             return out
-     def copy(self) -> "InferenceData":
-         """Return a fresh copy of the ``InferenceData`` object."""
-         return deepcopy(self)
-
-
- @overload
- def concat(
-     *args,
-     dim: Optional[str] = None,
-     copy: bool = True,
-     inplace: "Literal[True]",
-     reset_dim: bool = True,
- ) -> None: ...
-
-
- @overload
- def concat(
-     *args,
-     dim: Optional[str] = None,
-     copy: bool = True,
-     inplace: "Literal[False]",
-     reset_dim: bool = True,
- ) -> InferenceData: ...
-
-
- @overload
- def concat(
-     ids: Iterable[InferenceData],
-     dim: Optional[str] = None,
-     *,
-     copy: bool = True,
-     inplace: "Literal[False]",
-     reset_dim: bool = True,
- ) -> InferenceData: ...
-
-
- @overload
- def concat(
-     ids: Iterable[InferenceData],
-     dim: Optional[str] = None,
-     *,
-     copy: bool = True,
-     inplace: "Literal[True]",
-     reset_dim: bool = True,
- ) -> None: ...
-
-
- @overload
- def concat(
-     ids: Iterable[InferenceData],
-     dim: Optional[str] = None,
-     *,
-     copy: bool = True,
-     inplace: bool = False,
-     reset_dim: bool = True,
- ) -> Optional[InferenceData]: ...
-
-
- # pylint: disable=protected-access, inconsistent-return-statements
- def concat(*args, dim=None, copy=True, inplace=False, reset_dim=True):
-     """Concatenate InferenceData objects.
-
-     Concatenates over `group`, `chain` or `draw`.
-     By default concatenates over unique groups.
-     To concatenate over `chain` or `draw`, the function
-     needs identical groups and variables.
-
-     The variables in the data groups are merged if `dim` is not found.
-
-     Parameters
-     ----------
-     *args : InferenceData
-         Variable length list or sequence of InferenceData objects.
-     dim : str, optional
-         Dimension to concatenate over. If None, concatenates over
-         unique groups.
-     copy : bool
-         If True, groups are copied to the new InferenceData object.
-         Used only if `dim` is None.
-     inplace : bool
-         If True, merge args into the first object.
-     reset_dim : bool
-         Valid only if dim is not None.
-
-     Returns
-     -------
-     InferenceData
-         A new InferenceData object by default.
-         When ``inplace==True``, args are merged into the first arg and ``None`` is returned.
-
-     See Also
-     --------
-     add_groups : Add new groups to an InferenceData object.
-     extend : Extend InferenceData with groups from another InferenceData.
-
-     Examples
-     --------
-     Use the ``concat`` function to concatenate InferenceData objects. It concatenates over
-     unique groups by default. We first create an ``InferenceData`` object:
-
-     .. ipython::
-
-         In [1]: import arviz as az
-            ...: import numpy as np
-            ...: data = {
-            ...:     "a": np.random.normal(size=(4, 100, 3)),
-            ...:     "b": np.random.normal(size=(4, 100)),
-            ...: }
-            ...: coords = {"a_dim": ["x", "y", "z"]}
-            ...: dataA = az.from_dict(data, coords=coords, dims={"a": ["a_dim"]})
-            ...: dataA
-
-     We have created an ``InferenceData`` object with the default group 'posterior'. Now, we will
-     create another ``InferenceData`` object:
-
-     .. ipython::
-
-         In [1]: dataB = az.from_dict(prior=data, coords=coords, dims={"a": ["a_dim"]})
-            ...: dataB
-
-     We have created another ``InferenceData`` object with group 'prior'. Now, we will concatenate
-     these two ``InferenceData`` objects:
-
-     .. ipython::
-
-         In [1]: az.concat(dataA, dataB)
-
-     Now, we will concatenate over chain (or draw). It requires identical groups and variables.
-     Here we are concatenating two identical ``InferenceData`` objects over dimension chain:
-
-     .. ipython::
-
-         In [1]: az.concat(dataA, dataA, dim="chain")
-
-     It will create an ``InferenceData`` with the original group 'posterior'. In a similar way,
-     we can also concatenate over draws.
-
-     """
-     # pylint: disable=undefined-loop-variable, too-many-nested-blocks
-     if len(args) == 0:
-         if inplace:
-             return
-         return InferenceData()
-
-     if len(args) == 1 and isinstance(args[0], Sequence):
-         args = args[0]
-
-     # assert that all args are InferenceData
-     for i, arg in enumerate(args):
-         if not isinstance(arg, InferenceData):
-             raise TypeError(
-                 "Concatenating is supported only "
-                 "between InferenceData objects. Input arg {} is {}".format(i, type(arg))
-             )
-
-     if dim is not None and dim.lower() not in {"group", "chain", "draw"}:
-         msg = f'Invalid `dim`: {dim}. Valid `dim` are {{"group", "chain", "draw"}}'
-         raise TypeError(msg)
-     dim = dim.lower() if dim is not None else dim
-
-     if len(args) == 1 and isinstance(args[0], InferenceData):
-         if inplace:
-             return None
-         else:
-             if copy:
-                 return deepcopy(args[0])
-             else:
-                 return args[0]
-
-     current_time = datetime.datetime.now(datetime.timezone.utc).isoformat()
-     combined_attr = defaultdict(list)
-     for idata in args:
-         for key, val in idata.attrs.items():
-             combined_attr[key].append(val)
-
-     for key, val in combined_attr.items():
-         all_same = True
-         for indx in range(len(val) - 1):
-             if val[indx] != val[indx + 1]:
-                 all_same = False
-                 break
-         if all_same:
-             combined_attr[key] = val[0]
-     if inplace:
-         setattr(args[0], "_attrs", dict(combined_attr))
-
-     if not inplace:
-         # Keep order for python 3.5
-         inference_data_dict = OrderedDict()
-
-     if dim is None:
-         arg0 = args[0]
-         arg0_groups = ccopy(arg0._groups_all)
-         args_groups = {}
-         # check if groups are independent
-         # Concat over unique groups
-         for arg in args[1:]:
-             for group in arg._groups_all:
-                 if group in args_groups or group in arg0_groups:
-                     msg = (
-                         "Concatenating overlapping groups is not supported unless `dim` is defined."
-                         " Valid dimensions are `chain` and `draw`. Alternatively, use extend to"
-                         " combine InferenceData with overlapping groups"
-                     )
-                     raise TypeError(msg)
-                 group_data = getattr(arg, group)
-                 args_groups[group] = deepcopy(group_data) if copy else group_data
-         # add arg0 to args_groups if inplace is False
-         # otherwise it will merge args_groups to arg0
-         # inference data object
-         if not inplace:
-             for group in arg0_groups:
-                 group_data = getattr(arg0, group)
-                 args_groups[group] = deepcopy(group_data) if copy else group_data
-
-         other_groups = [group for group in args_groups if group not in SUPPORTED_GROUPS_ALL]
-
-         for group in SUPPORTED_GROUPS_ALL + other_groups:
-             if group not in args_groups:
-                 continue
-             if inplace:
-                 if group.startswith(WARMUP_TAG):
-                     arg0._groups_warmup.append(group)
-                 else:
-                     arg0._groups.append(group)
-                 setattr(arg0, group, args_groups[group])
-             else:
-                 inference_data_dict[group] = args_groups[group]
-         if inplace:
-             other_groups = [
-                 group for group in arg0_groups if group not in SUPPORTED_GROUPS_ALL
-             ] + other_groups
-             sorted_groups = [
-                 group for group in SUPPORTED_GROUPS + other_groups if group in arg0._groups
-             ]
-             setattr(arg0, "_groups", sorted_groups)
-             sorted_groups_warmup = [
-                 group
-                 for group in SUPPORTED_GROUPS_WARMUP + other_groups
-                 if group in arg0._groups_warmup
-             ]
-             setattr(arg0, "_groups_warmup", sorted_groups_warmup)
-     else:
-         arg0 = args[0]
-         arg0_groups = arg0._groups_all
-         for arg in args[1:]:
-             for group0 in arg0_groups:
-                 if group0 not in arg._groups_all:
-                     if group0 == "observed_data":
-                         continue
-                     msg = "Mismatch between the groups."
-                     raise TypeError(msg)
-             for group in arg._groups_all:
-                 # handle data groups separately
-                 if group not in ["observed_data", "constant_data", "predictions_constant_data"]:
-                     # assert that groups are equal
-                     if group not in arg0_groups:
-                         msg = "Mismatch between the groups."
-                         raise TypeError(msg)
-
-                     # assert that variables are equal
-                     group_data = getattr(arg, group)
-                     group_vars = group_data.data_vars
-
-                     if not inplace and group in inference_data_dict:
-                         group0_data = inference_data_dict[group]
-                     else:
-                         group0_data = getattr(arg0, group)
-                     group0_vars = group0_data.data_vars
-
-                     for var in group0_vars:
-                         if var not in group_vars:
-                             msg = "Mismatch between the variables."
-                             raise TypeError(msg)
-
-                     for var in group_vars:
-                         if var not in group0_vars:
-                             msg = "Mismatch between the variables."
-                             raise TypeError(msg)
-                         var_dims = group_data[var].dims
-                         var0_dims = group0_data[var].dims
-                         if var_dims != var0_dims:
-                             msg = "Mismatch between the dimensions."
-                             raise TypeError(msg)
-
-                         if dim not in var_dims or dim not in var0_dims:
-                             msg = f"Dimension {dim} missing."
-                             raise TypeError(msg)
-
-                     # xr.concat
-                     concatenated_group = xr.concat((group0_data, group_data), dim=dim)
-                     if reset_dim:
-                         concatenated_group[dim] = range(concatenated_group[dim].size)
-
-                     # handle attrs
-                     if hasattr(group0_data, "attrs"):
-                         group0_attrs = deepcopy(getattr(group0_data, "attrs"))
-                     else:
-                         group0_attrs = OrderedDict()
-
-                     if hasattr(group_data, "attrs"):
-                         group_attrs = getattr(group_data, "attrs")
-                     else:
-                         group_attrs = {}
-
-                     # gather attrs results to group0_attrs
-                     for attr_key, attr_values in group_attrs.items():
-                         group0_attr_values = group0_attrs.get(attr_key, None)
-                         equality = attr_values == group0_attr_values
-                         if hasattr(equality, "__iter__"):
-                             equality = np.all(equality)
-                         if equality:
-                             continue
-                         # handle special cases:
-                         if attr_key in ("created_at", "previous_created_at"):
-                             # check the defaults (key check on the attrs dict)
-                             if "previous_created_at" not in group0_attrs:
-                                 group0_attrs["previous_created_at"] = []
-                                 if group0_attr_values is not None:
-                                     group0_attrs["previous_created_at"].append(group0_attr_values)
-                             # check previous values
-                             if attr_key == "previous_created_at":
-                                 if not isinstance(attr_values, list):
-                                     attr_values = [attr_values]
-                                 group0_attrs["previous_created_at"].extend(attr_values)
-                                 continue
-                             # update "created_at"
-                             if group0_attr_values != current_time:
-                                 group0_attrs[attr_key] = current_time
-                             group0_attrs["previous_created_at"].append(attr_values)
-
-                         elif attr_key in group0_attrs:
-                             combined_key = f"combined_{attr_key}"
-                             if combined_key not in group0_attrs:
-                                 group0_attrs[combined_key] = [group0_attr_values]
-                             group0_attrs[combined_key].append(attr_values)
-                         else:
-                             group0_attrs[attr_key] = attr_values
-                     # update attrs
-                     setattr(concatenated_group, "attrs", group0_attrs)
-
-                     if inplace:
-                         setattr(arg0, group, concatenated_group)
-                     else:
-                         inference_data_dict[group] = concatenated_group
-                 else:
-                     # observed_data, "constant_data", "predictions_constant_data",
-                     group_data = getattr(arg, group)
-                     if group not in arg0_groups:
-                         setattr(arg0, group, deepcopy(group_data) if copy else group_data)
-                         arg0._groups.append(group)
-                         continue
-
-                     # assert that variables are equal
-                     group_vars = group_data.data_vars
-
-                     group0_data = getattr(arg0, group)
-                     if not inplace:
-                         group0_data = deepcopy(group0_data)
-                     group0_vars = group0_data.data_vars
-
-                     for var in group_vars:
-                         if var not in group0_vars:
-                             var_data = group_data[var]
-                             getattr(arg0, group)[var] = var_data
-                         else:
-                             var_data = group_data[var]
-                             var0_data = group0_data[var]
-                             if dim in var_data.dims and dim in var0_data.dims:
-                                 # concatenate the variables, not the whole datasets
-                                 concatenated_var = xr.concat((var0_data, var_data), dim=dim)
-                                 group0_data[var] = concatenated_var
-
-                     # handle attrs
-                     if hasattr(group0_data, "attrs"):
-                         group0_attrs = getattr(group0_data, "attrs")
-                     else:
-                         group0_attrs = OrderedDict()
-
-                     if hasattr(group_data, "attrs"):
-                         group_attrs = getattr(group_data, "attrs")
-                     else:
-                         group_attrs = {}
-
-                     # gather attrs results to group0_attrs
-                     for attr_key, attr_values in group_attrs.items():
-                         group0_attr_values = group0_attrs.get(attr_key, None)
-                         equality = attr_values == group0_attr_values
-                         if hasattr(equality, "__iter__"):
-                             equality = np.all(equality)
-                         if equality:
-                             continue
-                         # handle special cases:
-                         if attr_key in ("created_at", "previous_created_at"):
-                             # check the defaults (key check on the attrs dict)
-                             if "previous_created_at" not in group0_attrs:
-                                 group0_attrs["previous_created_at"] = []
-                                 if group0_attr_values is not None:
-                                     group0_attrs["previous_created_at"].append(group0_attr_values)
-                             # check previous values
-                             if attr_key == "previous_created_at":
-                                 if not isinstance(attr_values, list):
-                                     attr_values = [attr_values]
-                                 group0_attrs["previous_created_at"].extend(attr_values)
-                                 continue
-                             # update "created_at"
-                             if group0_attr_values != current_time:
-                                 group0_attrs[attr_key] = current_time
-                             group0_attrs["previous_created_at"].append(attr_values)
-
-                         elif attr_key in group0_attrs:
-                             combined_key = f"combined_{attr_key}"
-                             if combined_key not in group0_attrs:
-                                 group0_attrs[combined_key] = [group0_attr_values]
-                             group0_attrs[combined_key].append(attr_values)
-
-                         else:
-                             group0_attrs[attr_key] = attr_values
-                     # update attrs
-                     setattr(group0_data, "attrs", group0_attrs)
-
-                     if inplace:
-                         setattr(arg0, group, group0_data)
-                     else:
-                         inference_data_dict[group] = group0_data
-
-     if not inplace:
-         inference_data_dict["attrs"] = combined_attr
-     return None if inplace else InferenceData(**inference_data_dict)
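Finally, a compact sketch of both concat modes handled above (merging unique groups vs. concatenating along an explicit ``dim``), using the same toy data as the docstring::

    import numpy as np
    import arviz as az

    data = {"a": np.random.normal(size=(4, 100, 3)), "b": np.random.normal(size=(4, 100))}
    coords = {"a_dim": ["x", "y", "z"]}
    dims = {"a": ["a_dim"]}

    post = az.from_dict(data, coords=coords, dims=dims)        # group "posterior"
    prior = az.from_dict(prior=data, coords=coords, dims=dims)

    both = az.concat(post, prior)                 # merge over unique groups
    doubled = az.concat(post, post, dim="chain")  # needs identical groups/vars
    assert doubled.posterior.sizes["chain"] == 8  # 4 + 4, renumbered by reset_dim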