arviz 0.23.3__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185)
  1. arviz/__init__.py +52 -367
  2. arviz-1.0.0rc0.dist-info/METADATA +182 -0
  3. arviz-1.0.0rc0.dist-info/RECORD +5 -0
  4. {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
  5. {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
  6. arviz/data/__init__.py +0 -55
  7. arviz/data/base.py +0 -596
  8. arviz/data/converters.py +0 -203
  9. arviz/data/datasets.py +0 -161
  10. arviz/data/example_data/code/radon/radon.json +0 -326
  11. arviz/data/example_data/data/centered_eight.nc +0 -0
  12. arviz/data/example_data/data/non_centered_eight.nc +0 -0
  13. arviz/data/example_data/data_local.json +0 -12
  14. arviz/data/example_data/data_remote.json +0 -58
  15. arviz/data/inference_data.py +0 -2386
  16. arviz/data/io_beanmachine.py +0 -112
  17. arviz/data/io_cmdstan.py +0 -1036
  18. arviz/data/io_cmdstanpy.py +0 -1233
  19. arviz/data/io_datatree.py +0 -23
  20. arviz/data/io_dict.py +0 -462
  21. arviz/data/io_emcee.py +0 -317
  22. arviz/data/io_json.py +0 -54
  23. arviz/data/io_netcdf.py +0 -68
  24. arviz/data/io_numpyro.py +0 -497
  25. arviz/data/io_pyjags.py +0 -378
  26. arviz/data/io_pyro.py +0 -333
  27. arviz/data/io_pystan.py +0 -1095
  28. arviz/data/io_zarr.py +0 -46
  29. arviz/data/utils.py +0 -139
  30. arviz/labels.py +0 -210
  31. arviz/plots/__init__.py +0 -61
  32. arviz/plots/autocorrplot.py +0 -171
  33. arviz/plots/backends/__init__.py +0 -223
  34. arviz/plots/backends/bokeh/__init__.py +0 -166
  35. arviz/plots/backends/bokeh/autocorrplot.py +0 -101
  36. arviz/plots/backends/bokeh/bfplot.py +0 -23
  37. arviz/plots/backends/bokeh/bpvplot.py +0 -193
  38. arviz/plots/backends/bokeh/compareplot.py +0 -167
  39. arviz/plots/backends/bokeh/densityplot.py +0 -239
  40. arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
  41. arviz/plots/backends/bokeh/distplot.py +0 -183
  42. arviz/plots/backends/bokeh/dotplot.py +0 -113
  43. arviz/plots/backends/bokeh/ecdfplot.py +0 -73
  44. arviz/plots/backends/bokeh/elpdplot.py +0 -203
  45. arviz/plots/backends/bokeh/energyplot.py +0 -155
  46. arviz/plots/backends/bokeh/essplot.py +0 -176
  47. arviz/plots/backends/bokeh/forestplot.py +0 -772
  48. arviz/plots/backends/bokeh/hdiplot.py +0 -54
  49. arviz/plots/backends/bokeh/kdeplot.py +0 -268
  50. arviz/plots/backends/bokeh/khatplot.py +0 -163
  51. arviz/plots/backends/bokeh/lmplot.py +0 -185
  52. arviz/plots/backends/bokeh/loopitplot.py +0 -211
  53. arviz/plots/backends/bokeh/mcseplot.py +0 -184
  54. arviz/plots/backends/bokeh/pairplot.py +0 -328
  55. arviz/plots/backends/bokeh/parallelplot.py +0 -81
  56. arviz/plots/backends/bokeh/posteriorplot.py +0 -324
  57. arviz/plots/backends/bokeh/ppcplot.py +0 -379
  58. arviz/plots/backends/bokeh/rankplot.py +0 -149
  59. arviz/plots/backends/bokeh/separationplot.py +0 -107
  60. arviz/plots/backends/bokeh/traceplot.py +0 -436
  61. arviz/plots/backends/bokeh/violinplot.py +0 -164
  62. arviz/plots/backends/matplotlib/__init__.py +0 -124
  63. arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
  64. arviz/plots/backends/matplotlib/bfplot.py +0 -78
  65. arviz/plots/backends/matplotlib/bpvplot.py +0 -177
  66. arviz/plots/backends/matplotlib/compareplot.py +0 -135
  67. arviz/plots/backends/matplotlib/densityplot.py +0 -194
  68. arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
  69. arviz/plots/backends/matplotlib/distplot.py +0 -178
  70. arviz/plots/backends/matplotlib/dotplot.py +0 -116
  71. arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
  72. arviz/plots/backends/matplotlib/elpdplot.py +0 -189
  73. arviz/plots/backends/matplotlib/energyplot.py +0 -113
  74. arviz/plots/backends/matplotlib/essplot.py +0 -180
  75. arviz/plots/backends/matplotlib/forestplot.py +0 -656
  76. arviz/plots/backends/matplotlib/hdiplot.py +0 -48
  77. arviz/plots/backends/matplotlib/kdeplot.py +0 -177
  78. arviz/plots/backends/matplotlib/khatplot.py +0 -241
  79. arviz/plots/backends/matplotlib/lmplot.py +0 -149
  80. arviz/plots/backends/matplotlib/loopitplot.py +0 -144
  81. arviz/plots/backends/matplotlib/mcseplot.py +0 -161
  82. arviz/plots/backends/matplotlib/pairplot.py +0 -355
  83. arviz/plots/backends/matplotlib/parallelplot.py +0 -58
  84. arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
  85. arviz/plots/backends/matplotlib/ppcplot.py +0 -478
  86. arviz/plots/backends/matplotlib/rankplot.py +0 -119
  87. arviz/plots/backends/matplotlib/separationplot.py +0 -97
  88. arviz/plots/backends/matplotlib/traceplot.py +0 -526
  89. arviz/plots/backends/matplotlib/tsplot.py +0 -121
  90. arviz/plots/backends/matplotlib/violinplot.py +0 -148
  91. arviz/plots/bfplot.py +0 -128
  92. arviz/plots/bpvplot.py +0 -308
  93. arviz/plots/compareplot.py +0 -177
  94. arviz/plots/densityplot.py +0 -284
  95. arviz/plots/distcomparisonplot.py +0 -197
  96. arviz/plots/distplot.py +0 -233
  97. arviz/plots/dotplot.py +0 -233
  98. arviz/plots/ecdfplot.py +0 -372
  99. arviz/plots/elpdplot.py +0 -174
  100. arviz/plots/energyplot.py +0 -147
  101. arviz/plots/essplot.py +0 -319
  102. arviz/plots/forestplot.py +0 -304
  103. arviz/plots/hdiplot.py +0 -211
  104. arviz/plots/kdeplot.py +0 -357
  105. arviz/plots/khatplot.py +0 -236
  106. arviz/plots/lmplot.py +0 -380
  107. arviz/plots/loopitplot.py +0 -224
  108. arviz/plots/mcseplot.py +0 -194
  109. arviz/plots/pairplot.py +0 -281
  110. arviz/plots/parallelplot.py +0 -204
  111. arviz/plots/plot_utils.py +0 -599
  112. arviz/plots/posteriorplot.py +0 -298
  113. arviz/plots/ppcplot.py +0 -369
  114. arviz/plots/rankplot.py +0 -232
  115. arviz/plots/separationplot.py +0 -167
  116. arviz/plots/styles/arviz-bluish.mplstyle +0 -1
  117. arviz/plots/styles/arviz-brownish.mplstyle +0 -1
  118. arviz/plots/styles/arviz-colors.mplstyle +0 -2
  119. arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
  120. arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
  121. arviz/plots/styles/arviz-doc.mplstyle +0 -88
  122. arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
  123. arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
  124. arviz/plots/styles/arviz-greenish.mplstyle +0 -1
  125. arviz/plots/styles/arviz-orangish.mplstyle +0 -1
  126. arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
  127. arviz/plots/styles/arviz-purplish.mplstyle +0 -1
  128. arviz/plots/styles/arviz-redish.mplstyle +0 -1
  129. arviz/plots/styles/arviz-royish.mplstyle +0 -1
  130. arviz/plots/styles/arviz-viridish.mplstyle +0 -1
  131. arviz/plots/styles/arviz-white.mplstyle +0 -40
  132. arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
  133. arviz/plots/traceplot.py +0 -273
  134. arviz/plots/tsplot.py +0 -440
  135. arviz/plots/violinplot.py +0 -192
  136. arviz/preview.py +0 -58
  137. arviz/py.typed +0 -0
  138. arviz/rcparams.py +0 -606
  139. arviz/sel_utils.py +0 -223
  140. arviz/static/css/style.css +0 -340
  141. arviz/static/html/icons-svg-inline.html +0 -15
  142. arviz/stats/__init__.py +0 -37
  143. arviz/stats/density_utils.py +0 -1013
  144. arviz/stats/diagnostics.py +0 -1013
  145. arviz/stats/ecdf_utils.py +0 -324
  146. arviz/stats/stats.py +0 -2422
  147. arviz/stats/stats_refitting.py +0 -119
  148. arviz/stats/stats_utils.py +0 -609
  149. arviz/tests/__init__.py +0 -1
  150. arviz/tests/base_tests/__init__.py +0 -1
  151. arviz/tests/base_tests/test_data.py +0 -1679
  152. arviz/tests/base_tests/test_data_zarr.py +0 -143
  153. arviz/tests/base_tests/test_diagnostics.py +0 -511
  154. arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
  155. arviz/tests/base_tests/test_helpers.py +0 -18
  156. arviz/tests/base_tests/test_labels.py +0 -69
  157. arviz/tests/base_tests/test_plot_utils.py +0 -342
  158. arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
  159. arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
  160. arviz/tests/base_tests/test_rcparams.py +0 -317
  161. arviz/tests/base_tests/test_stats.py +0 -925
  162. arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
  163. arviz/tests/base_tests/test_stats_numba.py +0 -45
  164. arviz/tests/base_tests/test_stats_utils.py +0 -384
  165. arviz/tests/base_tests/test_utils.py +0 -376
  166. arviz/tests/base_tests/test_utils_numba.py +0 -87
  167. arviz/tests/conftest.py +0 -46
  168. arviz/tests/external_tests/__init__.py +0 -1
  169. arviz/tests/external_tests/test_data_beanmachine.py +0 -78
  170. arviz/tests/external_tests/test_data_cmdstan.py +0 -398
  171. arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
  172. arviz/tests/external_tests/test_data_emcee.py +0 -166
  173. arviz/tests/external_tests/test_data_numpyro.py +0 -434
  174. arviz/tests/external_tests/test_data_pyjags.py +0 -119
  175. arviz/tests/external_tests/test_data_pyro.py +0 -260
  176. arviz/tests/external_tests/test_data_pystan.py +0 -307
  177. arviz/tests/helpers.py +0 -677
  178. arviz/utils.py +0 -773
  179. arviz/wrappers/__init__.py +0 -13
  180. arviz/wrappers/base.py +0 -236
  181. arviz/wrappers/wrap_pymc.py +0 -36
  182. arviz/wrappers/wrap_stan.py +0 -148
  183. arviz-0.23.3.dist-info/METADATA +0 -264
  184. arviz-0.23.3.dist-info/RECORD +0 -183
  185. arviz-0.23.3.dist-info/top_level.txt +0 -1
arviz/utils.py DELETED
@@ -1,773 +0,0 @@
1
- # pylint: disable=too-many-nested-blocks
2
- """General utilities."""
3
- import functools
4
- import importlib
5
- import importlib.resources
6
- import re
7
- import warnings
8
- from functools import lru_cache
9
-
10
- import matplotlib.pyplot as plt
11
- import numpy as np
12
- from numpy import newaxis
13
-
14
- from .rcparams import rcParams
15
-
16
-
17
# Package-relative paths of the static assets read by ``_load_static_files``:
# an inline SVG icon sheet and the stylesheet used for the HTML repr.
STATIC_FILES = (
    "static/html/icons-svg-inline.html",
    "static/css/style.css",
)
18
-
19
-
20
class BehaviourChangeWarning(Warning):
    """Dedicated warning category.

    Subclassing :class:`Warning` directly lets users target exactly these
    warnings with :func:`warnings.filterwarnings`.
    """
22
-
23
-
24
- def _check_tilde_start(x):
25
- return bool(isinstance(x, str) and x.startswith("~"))
26
-
27
-
28
- def _var_names(var_names, data, filter_vars=None, errors="raise"):
29
- """Handle var_names input across arviz.
30
-
31
- Parameters
32
- ----------
33
- var_names: str, list, or None
34
- data : xarray.Dataset
35
- Posterior data in an xarray
36
- filter_vars: {None, "like", "regex"}, optional, default=None
37
- If `None` (default), interpret var_names as the real variables names. If "like",
38
- interpret var_names as substrings of the real variables names. If "regex",
39
- interpret var_names as regular expressions on the real variables names. A la
40
- `pandas.filter`.
41
- errors: {"raise", "ignore"}, optional, default="raise"
42
- Select either to raise or ignore the invalid names.
43
-
44
- Returns
45
- -------
46
- var_name: list or None
47
- """
48
- if filter_vars not in {None, "like", "regex"}:
49
- raise ValueError(
50
- f"'filter_vars' can only be None, 'like', or 'regex', got: '{filter_vars}'"
51
- )
52
-
53
- if errors not in {"raise", "ignore"}:
54
- raise ValueError(f"'errors' can only be 'raise', or 'ignore', got: '{errors}'")
55
-
56
- if var_names is not None:
57
- if isinstance(data, (list, tuple)):
58
- all_vars = []
59
- for dataset in data:
60
- dataset_vars = list(dataset.data_vars)
61
- for var in dataset_vars:
62
- if var not in all_vars:
63
- all_vars.append(var)
64
- else:
65
- all_vars = list(data.data_vars)
66
-
67
- all_vars_tilde = [var for var in all_vars if _check_tilde_start(var)]
68
- if all_vars_tilde:
69
- warnings.warn(
70
- """ArviZ treats '~' as a negation character for variable selection.
71
- Your model has variables names starting with '~', {0}. Please double check
72
- your results to ensure all variables are included""".format(
73
- ", ".join(all_vars_tilde)
74
- )
75
- )
76
-
77
- try:
78
- var_names = _subset_list(
79
- var_names, all_vars, filter_items=filter_vars, warn=False, errors=errors
80
- )
81
- except KeyError as err:
82
- msg = " ".join(("var names:", f"{err}", "in dataset"))
83
- raise KeyError(msg) from err
84
- return var_names
85
-
86
-
87
def _subset_list(subset, whole_list, filter_items=None, warn=True, errors="raise"):
    """Handle list subsetting (var_names, groups...) across arviz.

    Parameters
    ----------
    subset : str, list, or None
        Requested items. Entries starting with ``~`` that are not themselves
        elements of ``whole_list`` are treated as exclusions.
    whole_list : list
        List from which to select a subset according to subset elements and
        filter_items value.
    filter_items : {None, "like", "regex"}, optional
        If `None` (default), interpret `subset` as the exact elements in `whole_list`
        names. If "like", interpret `subset` as substrings of the elements in
        `whole_list`. If "regex", interpret `subset` as regular expressions to match
        elements in `whole_list`. A la `pandas.filter`.
    warn : bool, default True
        Warn when ``whole_list`` contains elements starting with ``~``.
    errors: {"raise", "ignore"}, optional, default="raise"
        Select either to raise or ignore the invalid names.

    Returns
    -------
    list or None
        A subset of ``whole_list`` fulfilling the requests imposed by ``subset``
        and ``filter_items``. With "like"/"regex", each matching element appears
        once, in ``whole_list`` order, even if several patterns match it.

    Raises
    ------
    KeyError
        If ``errors="raise"`` and some selected items are not in ``whole_list``.
    """
    if subset is None:
        return None
    if isinstance(subset, str):
        subset = [subset]

    whole_list_tilde = [item for item in whole_list if _check_tilde_start(item)]
    if whole_list_tilde and warn:
        warnings.warn(
            "ArviZ treats '~' as a negation character for selection. There are "
            "elements in `whole_list` starting with '~', {0}. Please double check "
            "your results to ensure all elements are included".format(
                ", ".join(whole_list_tilde)
            )
        )

    # Negated requests: "~name" whose full spelling is not itself a valid item.
    excluded_items = [
        item[1:] for item in subset if _check_tilde_start(item) and item not in whole_list
    ]
    filter_items = str(filter_items).lower()
    if excluded_items:
        not_found = []

        if filter_items in {"like", "regex"}:
            # Expand each exclusion pattern into the concrete items it matches,
            # building a new list instead of mutating the one being iterated.
            patterns, excluded_items = excluded_items, []
            for pattern in patterns:
                if filter_items == "like":
                    real_items = [real_item for real_item in whole_list if pattern in real_item]
                else:
                    real_items = [
                        real_item for real_item in whole_list if re.search(pattern, real_item)
                    ]
                if not real_items:
                    not_found.append(pattern)
                excluded_items.extend(real_items)
        not_found.extend(item for item in excluded_items if item not in whole_list)
        if not_found:
            warnings.warn(
                f"Items starting with ~: {not_found} have not been found and will be ignored"
            )
        subset = [item for item in whole_list if item not in excluded_items]

    elif filter_items == "like":
        # ``any`` keeps each element once even when several names match it.
        subset = [item for item in whole_list if any(name in item for name in subset)]
    elif filter_items == "regex":
        subset = [item for item in whole_list if any(re.search(name, item) for name in subset)]

    existing_items = np.isin(subset, whole_list)
    if not np.all(existing_items) and (errors == "raise"):
        raise KeyError(f"{np.array(subset)[~existing_items]} are not present")

    return subset
161
-
162
-
163
class lazy_property:  # pylint: disable=invalid-name
    """Descriptor that computes a value on first access and caches it per instance.

    Used to load numba the first time it is needed. Only ``__get__`` is defined,
    so after the first access the value stored in the instance ``__dict__``
    under the getter's name shadows this descriptor.
    """

    def __init__(self, fget):
        """Remember the getter ``fget`` and copy its metadata onto the descriptor."""
        self.fget = fget
        # Expose the getter's __name__, __doc__, etc. on the descriptor itself.
        functools.update_wrapper(self, fget)

    def __get__(self, obj, cls):
        """Compute the value, store it on ``obj`` and return it."""
        if obj is not None:
            computed = self.fget(obj)
            setattr(obj, self.fget.__name__, computed)
            return computed
        # Looked up on the class rather than on an instance.
        return self
181
-
182
-
183
class maybe_numba_fn:  # pylint: disable=invalid-name
    """Wrap a function to (maybe) use a (lazy) jit-compiled version.

    On every call, ``Numba.numba_flag`` decides whether the compiled function
    or the original Python function is invoked.
    """

    def __init__(self, function, **kwargs):
        """Store ``function`` and the keyword arguments later given to ``numba.jit``.

        ``nopython=True`` is used unless the caller passes ``nopython`` explicitly.
        """
        self.function = function
        self.kwargs = {"nopython": True, **kwargs}

    @lazy_property
    def numba_fn(self):
        """Jit-compiled ``self.function``, built on first access.

        Falls back to the plain function when an ImportError is raised.
        """
        try:
            numba = importlib.import_module("numba")
            compiled = numba.jit(**self.kwargs)(self.function)
        except ImportError:
            compiled = self.function
        return compiled

    def __call__(self, *args, **kwargs):
        """Call the compiled function if numba is enabled, the original otherwise."""
        target = self.numba_fn if Numba.numba_flag else self.function
        return target(*args, **kwargs)
208
-
209
-
210
class interactive_backend:  # pylint: disable=invalid-name
    """Context manager to change backend temporarily in ipython session.

    It uses ipython magic to change temporarily from the ipython inline backend to
    an interactive backend of choice. It cannot be used outside ipython sessions nor
    to change backends different than inline -> interactive.

    Notes
    -----
    The first time ``interactive_backend`` context manager is called, any of the available
    interactive backends can be chosen. The following times, this same backend must be used
    unless the kernel is restarted.

    Parameters
    ----------
    backend : str, optional
        Interactive backend to use. It will be passed to ``%matplotlib`` magic, refer to
        its docs to see available options.

    Examples
    --------
    Inside an ipython session (i.e. a jupyter notebook) with the inline backend set:

    .. code::

        >>> import arviz as az
        >>> idata = az.load_arviz_data("centered_eight")
        >>> az.plot_posterior(idata) # inline
        >>> with az.interactive_backend():
        ...     az.plot_density(idata) # interactive
        >>> az.plot_trace(idata) # inline

    """

    # based on matplotlib.rc_context
    def __init__(self, backend=""):
        """Switch the running IPython shell to ``backend`` via ``%matplotlib``.

        Raises
        ------
        ImportError
            If IPython cannot be imported.
        EnvironmentError
            If not running inside an IPython session.
        """
        try:
            from IPython import get_ipython
        except ImportError as err:
            raise ImportError(
                "The exception below was raised while importing Ipython, this "
                f"context manager can only be used inside ipython sessions:\n{err}"
            ) from err
        self.ipython = get_ipython()
        if self.ipython is None:
            raise EnvironmentError("This context manager can only be used inside ipython sessions")
        # ``InteractiveShell.magic`` is deprecated; ``run_line_magic`` is the supported API.
        self.ipython.run_line_magic("matplotlib", backend)

    def __enter__(self):
        """Enter context manager."""
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        """Show open figures (blocking), then switch back to the inline backend."""
        try:
            plt.show(block=True)
        finally:
            # Restore inline even if showing the figures fails.
            self.ipython.run_line_magic("matplotlib", "inline")
267
-
268
-
269
def conditional_jit(_func=None, **kwargs):
    """Use numba's jit decorator if numba is installed.

    Notes
    -----
    Usable both bare and with arguments::

        @conditional_jit
        def my_func():
            return

        @conditional_jit(nopython=True)
        def my_func():
            return

    Either way the function is wrapped in ``maybe_numba_fn`` (which receives
    ``kwargs``), and the wrapper takes the function's metadata.
    """

    def decorate(fn):
        return functools.wraps(fn)(maybe_numba_fn(fn, **kwargs))

    if _func is None:
        return decorate
    return decorate(_func)
291
-
292
-
293
def conditional_vect(function=None, **kwargs):  # noqa: D202
    """Use numba's vectorize decorator if numba is installed.

    Notes
    -----
    Usable both bare and with arguments::

        @conditional_vect
        def my_func():
            return

        @conditional_vect(nopython=True)
        def my_func():
            return

    When numba cannot be imported, the function is returned unchanged.
    """

    def wrapper(func):
        try:
            numba = importlib.import_module("numba")
            return numba.vectorize(**kwargs)(func)
        except ImportError:
            return func

    return wrapper(function) if function else wrapper
321
-
322
-
323
def numba_check():
    """Return True if the ``numba`` package can be found, without importing it."""
    # ``importlib.util`` is a submodule; import it explicitly instead of relying
    # on some other module having imported it as a side effect.
    import importlib.util  # pylint: disable=import-outside-toplevel

    return importlib.util.find_spec("numba") is not None
327
-
328
-
329
class Numba:
    """A class to toggle numba states."""

    numba_flag = numba_check()
    """bool: Indicates whether Numba optimizations are enabled.

    Defaults to True when numba is installed and False otherwise."""

    @classmethod
    def disable_numba(cls):
        """To disable numba."""
        cls.numba_flag = False

    @classmethod
    def enable_numba(cls):
        """To enable numba.

        Raises
        ------
        ValueError
            If numba is not installed.
        """
        if not numba_check():
            raise ValueError("Numba is not installed")
        cls.numba_flag = True
347
-
348
-
349
def _numba_var(numba_function, standard_numpy_func, data, axis=None, ddof=0):
    """Compute a variance with the numba or the numpy implementation.

    Parameters
    ----------
    numba_function : callable
        Custom numba function included in stats/stats_utils.py.
    standard_numpy_func : callable
        Standard function included in the numpy library.
    data : array
    axis : axis along which the variance is calculated.
    ddof : degrees of freedom allowed while calculating variance.

    Returns
    -------
    array
        Result of ``numba_function`` when ``Numba.numba_flag`` is set, otherwise
        of ``standard_numpy_func``; both receive ``axis`` and ``ddof`` as keywords.
    """
    chosen = numba_function if Numba.numba_flag else standard_numpy_func
    return chosen(data, axis=axis, ddof=ddof)
375
-
376
-
377
- def _stack(x, y):
378
- assert x.shape[1:] == y.shape[1:]
379
- return np.vstack((x, y))
380
-
381
-
382
def arange(x):
    """Return ``np.arange(x)`` (kept as a function so it can be jitted)."""
    values = np.arange(x)
    return values
385
-
386
-
387
def one_de(x):
    """Jitting numpy atleast_1d: return ``x`` with at least one dimension.

    An ndarray with one or more dimensions is returned as the same object.
    """
    if isinstance(x, np.ndarray):
        return x.reshape(1) if x.ndim == 0 else x
    return np.atleast_1d(x)
393
-
394
-
395
def two_de(x):
    """Jitting numpy atleast_2d: return ``x`` with at least two dimensions.

    0-d arrays become ``(1, 1)``, 1-d arrays gain a leading axis, and arrays
    with two or more dimensions are returned as the same object.
    """
    if not isinstance(x, np.ndarray):
        return np.atleast_2d(x)
    if x.ndim == 0:
        return x.reshape(1, 1)
    if x.ndim == 1:
        return x[np.newaxis, :]
    return x
406
-
407
-
408
def expand_dims(x):
    """Jitting numpy expand_dims: add a new leading axis of length 1."""
    if isinstance(x, np.ndarray):
        return x.reshape((1, *x.shape))
    return np.expand_dims(x, 0)
414
-
415
-
416
@conditional_jit(cache=True, nopython=True)
def _dot(x, y):
    """Return ``np.dot(x, y)``; jit-compiled through ``conditional_jit``."""
    product = np.dot(x, y)
    return product
419
-
420
-
421
@conditional_jit(cache=True, nopython=True)
def _cov_1d(x):
    """Sample variance of a 1-D array, normalised by ``len(x) - 1``."""
    centered = x - x.mean()
    # Divisor N - 1 (unbiased estimator).
    denominator = centered.shape[0] - 1
    return np.dot(centered.T, centered.conj()) / denominator
426
-
427
-
428
def _cov(data):
    """Estimate the covariance of ``data``.

    Parameters
    ----------
    data : numpy.ndarray
        1-D array (handled by ``_cov_1d``) or 2-D array whose rows are variables
        and whose columns are observations.

    Returns
    -------
    numpy.ndarray or scalar
        For 2-D input: the covariance matrix normalised by ``N - 1`` (``N`` being the
        number of columns), squeezed, with ``1e-6`` added to its diagonal. A
        single-row input yields a 0-d array.

    Raises
    ------
    ValueError
        If ``data`` is neither 1-D nor 2-D.
    """
    if data.ndim == 1:
        return _cov_1d(data)
    if data.ndim != 2:
        raise ValueError(f"{data.ndim} dimension arrays are not supported")

    x = data.astype(float)
    avg = x.mean(axis=1)
    ddof = x.shape[1] - 1
    if ddof <= 0:
        warnings.warn("Degrees of freedom <= 0 for slice", RuntimeWarning, stacklevel=2)
        ddof = 0.0
    x -= avg[:, None]
    prod = _dot(x, x.T.conj())
    prod *= np.true_divide(1, ddof)
    prod = prod.squeeze()
    # A single variable squeezes down to a 0-d array, which has no ``shape[0]``;
    # add the jitter directly instead of building an identity matrix.
    if prod.ndim == 0:
        prod += 1e-6
    else:
        prod += 1e-6 * np.eye(prod.shape[0])
    return prod
447
-
448
-
449
def flatten_inference_data_to_dict(
    data,
    var_names=None,
    groups=None,
    dimensions=None,
    group_info=False,
    var_name_format=None,
    index_origin=None,
):
    """Transform data to dictionary.

    Parameters
    ----------
    data : obj
        Any object that can be converted to an az.InferenceData object
        Refer to documentation of az.convert_to_inference_data for details
    var_names : str or list of str, optional
        Variables to be processed, if None all variables are processed.
        A single string is treated as one exact variable name.
    groups : str or list of str, optional
        Select groups for CDS. Predefined group names are
        {"posterior_groups", "prior_groups", "posterior_groups_warmup"}

        - posterior_groups: posterior, posterior_predictive, sample_stats
        - prior_groups: prior, prior_predictive, sample_stats_prior
        - posterior_groups_warmup: warmup_posterior, warmup_posterior_predictive,
          warmup_sample_stats

        Defaults to the "posterior_groups" list. Groups missing from ``data`` are skipped.
    dimensions : str, or list of str, optional
        Select dimensions along to slice the data. By default uses ("chain", "draw").
    group_info : bool
        Add group info for `var_name_format`
    var_name_format : str or tuple of tuple of string, optional
        Select column name format for non-scalar input.
        Predefined options are {"brackets", "underscore", "cds"}
        (examples with ``index_origin=0``)

        "brackets":
            - group_info == False: theta[0,0]
            - group_info == True: theta_posterior[0,0]
        "underscore":
            - group_info == False: theta_0_0
            - group_info == True: theta_posterior_0_0
        "cds":
            - group_info == False: theta_ARVIZ_CDS_SELECTION_0_0
            - group_info == True: theta_ARVIZ_GROUP_posterior_ARVIZ_CDS_SELECTION_0_0
        tuple:
            Structure:
                tuple: (dim_info, group_info)
                    dim_info: (str: `.join` separator,
                               str: dim_separator_start,
                               str: dim_separator_end)
                    group_info: (str: group separator start, str: group separator end)
            Example: ((",", "[", "]"), ("_", ""))
                - group_info == False: theta[0,0]
                - group_info == True: theta_posterior[0,0]
    index_origin : int, optional
        Start parameter indices from `index_origin`. Either 0 or 1.
        Defaults to ``rcParams["data.index_origin"]``.

    Returns
    -------
    dict
        Maps each name in ``dimensions`` to its coordinate values, and each
        (flattened) variable name to its values. Variables with more than one
        axis after stacking are split into one entry per index combination over
        all but their last axis.

    Raises
    ------
    TypeError
        If ``groups`` or ``var_name_format`` is an unknown predefined string.
    """
    from .data import convert_to_inference_data

    data = convert_to_inference_data(data)

    # A bare string is one variable name. Without this, ``var_name not in var_names``
    # below would perform substring matching on the string.
    if isinstance(var_names, str):
        var_names = [var_names]

    if groups is None:
        groups = ["posterior", "posterior_predictive", "sample_stats"]
    elif isinstance(groups, str):
        if groups.lower() == "posterior_groups":
            groups = ["posterior", "posterior_predictive", "sample_stats"]
        elif groups.lower() == "prior_groups":
            groups = ["prior", "prior_predictive", "sample_stats_prior"]
        elif groups.lower() == "posterior_groups_warmup":
            groups = ["warmup_posterior", "warmup_posterior_predictive", "warmup_sample_stats"]
        else:
            raise TypeError(
                "Valid predefined groups are "
                "{posterior_groups, prior_groups, posterior_groups_warmup}"
            )

    if dimensions is None:
        dimensions = "chain", "draw"
    elif isinstance(dimensions, str):
        dimensions = (dimensions,)

    if var_name_format is None:
        var_name_format = "brackets"

    if isinstance(var_name_format, str):
        var_name_format = var_name_format.lower()

    if var_name_format == "brackets":
        dim_join_separator, dim_separator_start, dim_separator_end = ",", "[", "]"
        group_separator_start, group_separator_end = "_", ""
    elif var_name_format == "underscore":
        dim_join_separator, dim_separator_start, dim_separator_end = "_", "_", ""
        group_separator_start, group_separator_end = "_", ""
    elif var_name_format == "cds":
        dim_join_separator, dim_separator_start, dim_separator_end = (
            "_",
            "_ARVIZ_CDS_SELECTION_",
            "",
        )
        group_separator_start, group_separator_end = "_ARVIZ_GROUP_", ""
    elif isinstance(var_name_format, str):
        msg = 'Invalid predefined format. Select one {"brackets", "underscore", "cds"}'
        raise TypeError(msg)
    else:
        (
            (dim_join_separator, dim_separator_start, dim_separator_end),
            (group_separator_start, group_separator_end),
        ) = var_name_format

    if index_origin is None:
        index_origin = rcParams["data.index_origin"]

    data_dict = {}
    for group in groups:
        if not hasattr(data, group):
            continue
        group_data = getattr(data, group).stack(stack_dimension=dimensions)
        # Infix placed between the variable name and the index part of each key.
        group_part = (
            f"{group_separator_start}{group}{group_separator_end}" if group_info else ""
        )
        for var_name, var in group_data.data_vars.items():
            if var_names is not None and var_name not in var_names:
                continue
            for dim_name in dimensions:
                if dim_name not in data_dict:
                    data_dict[dim_name] = var.coords.get(dim_name).values
            # Only materialise values for variables that are actually kept.
            var_values = var.values
            if len(var.shape) == 1:
                data_dict[f"{var_name}{group_part}"] = var_values
                continue
            # One entry per index combination over all but the last axis.
            for loc in np.ndindex(var.shape[:-1]):
                dim_join = dim_join_separator.join(str(item + index_origin) for item in loc)
                var_name_dim = (
                    f"{var_name}{group_part}{dim_separator_start}{dim_join}{dim_separator_end}"
                )
                data_dict[var_name_dim] = var_values[loc]
    return data_dict
622
-
623
-
624
def get_coords(data, coords):
    """Subselects xarray DataSet or DataArray object to provided coords. Raises exception if fails.

    ``data`` may also be a list/tuple of such objects. ``coords`` is then either
    one mapping applied to every element or a list/tuple of mappings paired
    element-wise, and a list of selections is returned.

    Raises
    ------
    ValueError
        If coords name are not available in data

    KeyError
        If coords dims are not available in data

    Returns
    -------
    data: xarray
        xarray.DataSet or xarray.DataArray object, same type as input
    """
    if isinstance(data, (list, tuple)):
        if isinstance(coords, (list, tuple)):
            coords_per_item = coords
        else:
            coords_per_item = [coords] * len(data)
        selections = []
        for position, (item, item_coords) in enumerate(zip(data, coords_per_item)):
            try:
                selections.append(get_coords(item, item_coords))
            except ValueError as err:
                raise ValueError(f"Error in data[{position}]: {err}") from err
            except KeyError as err:
                raise KeyError(f"Error in data[{position}]: {err}") from err
        return selections

    try:
        return data.sel(**coords)
    except ValueError as err:
        unknown_keys = set(coords.keys()) - set(data.coords.keys())
        raise ValueError(f"Coords {unknown_keys} are invalid coordinate keys") from err
    except KeyError as err:
        raise KeyError(
            "Coords should follow mapping format {coord_name:[dim1, dim2]}. "
            "Check that coords structure is correct and"
            f" dimensions are valid. {err}"
        ) from err
667
-
668
-
669
@lru_cache(maxsize=None)
def _load_static_files():
    """Read every file in ``STATIC_FILES`` from the arviz package, once.

    Returns a list of the UTF-8 decoded contents, in ``STATIC_FILES`` order.
    The result is cached, so later calls return the same list.

    Clone from xarray.core.formatted_html_template.
    """

    def read_resource(relative_path):
        resource = importlib.resources.files("arviz").joinpath(relative_path)
        return resource.read_text(encoding="utf-8")

    return list(map(read_resource, STATIC_FILES))
679
-
680
-
681
class HtmlTemplate:
    """Contain html templates for InferenceData repr.

    Attributes
    ----------
    html_template : str
        Outer wrapper with a single positional ``{}`` placeholder for the group
        sections (presumably rendered ``element_template`` items; confirm against
        the caller).
    element_template : str
        Collapsible section for one group, formatted with the ``group_id``,
        ``group`` and ``xr_data`` fields.
    css_template : str
        ``<style>`` block holding the packaged CSS plus a fixed-width override.
    """

    html_template = """
<div>
<div class='xr-header'>
<div class="xr-obj-type">arviz.InferenceData</div>
</div>
<ul class="xr-sections group-sections">
{}
</ul>
</div>
"""
    element_template = """
<li class = "xr-section-item">
<input id="idata_{group_id}" class="xr-section-summary-in" type="checkbox">
<label for="idata_{group_id}" class = "xr-section-summary">{group}</label>
<div class="xr-section-inline-details"></div>
<div class="xr-section-details">
<ul id="xr-dataset-coord-list" class="xr-var-list">
<div style="padding-left:2rem;">{xr_data}<br></div>
</ul>
</div>
</li>
"""
    # Runs once, when the class body executes (at import). The loader returns
    # the contents of STATIC_FILES in order (icons HTML, then CSS); only the
    # CSS is used here.
    _, css_style = _load_static_files()  # pylint: disable=protected-access
    # Force a fixed width for the xarray wrapper on top of the packaged CSS.
    specific_style = ".xr-wrap{width:700px!important;}"
    css_template = f"<style> {css_style}{specific_style} </style>"
709
-
710
-
711
def either_dict_or_kwargs(
    pos_kwargs,
    kw_kwargs,
    func_name,
):
    """Return whichever of the positional mapping or the keyword arguments was given.

    Clone from xarray.core.utils.

    Raises
    ------
    ValueError
        If ``pos_kwargs`` is indexable but has no ``keys`` (e.g. a list or str),
        or if both ``pos_kwargs`` and non-empty ``kw_kwargs`` are given.
    """
    if pos_kwargs is None:
        return kw_kwargs
    is_sequence_like = hasattr(pos_kwargs, "__getitem__") and not hasattr(pos_kwargs, "keys")
    if is_sequence_like:
        raise ValueError(f"the first argument to .{func_name} must be a dictionary")
    if not kw_kwargs:
        return pos_kwargs
    raise ValueError(f"cannot specify both keyword and positional arguments to .{func_name}")
724
-
725
-
726
class Dask:
    """Class to toggle Dask states.

    Warnings
    --------
    Dask integration is an experimental feature still in progress. It can already be used
    but it doesn't work with all stats nor diagnostics yet.
    """

    dask_flag = False
    """bool: Enables Dask parallelization when set to True. Defaults to False."""
    dask_kwargs = None
    """dict or None: Additional keyword arguments for Dask configuration.

    None while Dask is disabled; a dict (empty unless given) once enabled."""

    @classmethod
    def enable_dask(cls, dask_kwargs=None):
        """To enable Dask.

        Parameters
        ----------
        dask_kwargs : dict, optional
            Dask related kwargs passed to :func:`~arviz.wrap_xarray_ufunc`.
            ``None`` is stored as an empty dict.
        """
        cls.dask_flag = True
        # Always store a mapping so that ``{**Dask.dask_kwargs, ...}`` merges work.
        cls.dask_kwargs = {} if dask_kwargs is None else dask_kwargs

    @classmethod
    def disable_dask(cls):
        """To disable Dask."""
        cls.dask_flag = False
        cls.dask_kwargs = None
758
-
759
-
760
def conditional_dask(func):
    """Conditionally pass dask kwargs to `wrap_xarray_ufunc`.

    When ``Dask.dask_flag`` is False, ``func`` is called unchanged. Otherwise the
    call receives a ``dask_kwargs`` keyword that merges ``Dask.dask_kwargs``
    (global defaults) with any ``dask_kwargs`` the caller passed; the caller's
    values win on conflicts.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not Dask.dask_flag:
            return func(*args, **kwargs)
        user_kwargs = kwargs.pop("dask_kwargs", None)
        if user_kwargs is None:
            user_kwargs = {}
        # ``Dask.dask_kwargs`` may be None (no defaults configured); ``{**None}``
        # would raise TypeError, so treat it as an empty mapping.
        default_kwargs = Dask.dask_kwargs or {}
        return func(*args, dask_kwargs={**default_kwargs, **user_kwargs}, **kwargs)

    return wrapper