arviz 0.23.3__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arviz/__init__.py +52 -367
- arviz-1.0.0rc0.dist-info/METADATA +182 -0
- arviz-1.0.0rc0.dist-info/RECORD +5 -0
- {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
- {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
- arviz/data/__init__.py +0 -55
- arviz/data/base.py +0 -596
- arviz/data/converters.py +0 -203
- arviz/data/datasets.py +0 -161
- arviz/data/example_data/code/radon/radon.json +0 -326
- arviz/data/example_data/data/centered_eight.nc +0 -0
- arviz/data/example_data/data/non_centered_eight.nc +0 -0
- arviz/data/example_data/data_local.json +0 -12
- arviz/data/example_data/data_remote.json +0 -58
- arviz/data/inference_data.py +0 -2386
- arviz/data/io_beanmachine.py +0 -112
- arviz/data/io_cmdstan.py +0 -1036
- arviz/data/io_cmdstanpy.py +0 -1233
- arviz/data/io_datatree.py +0 -23
- arviz/data/io_dict.py +0 -462
- arviz/data/io_emcee.py +0 -317
- arviz/data/io_json.py +0 -54
- arviz/data/io_netcdf.py +0 -68
- arviz/data/io_numpyro.py +0 -497
- arviz/data/io_pyjags.py +0 -378
- arviz/data/io_pyro.py +0 -333
- arviz/data/io_pystan.py +0 -1095
- arviz/data/io_zarr.py +0 -46
- arviz/data/utils.py +0 -139
- arviz/labels.py +0 -210
- arviz/plots/__init__.py +0 -61
- arviz/plots/autocorrplot.py +0 -171
- arviz/plots/backends/__init__.py +0 -223
- arviz/plots/backends/bokeh/__init__.py +0 -166
- arviz/plots/backends/bokeh/autocorrplot.py +0 -101
- arviz/plots/backends/bokeh/bfplot.py +0 -23
- arviz/plots/backends/bokeh/bpvplot.py +0 -193
- arviz/plots/backends/bokeh/compareplot.py +0 -167
- arviz/plots/backends/bokeh/densityplot.py +0 -239
- arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
- arviz/plots/backends/bokeh/distplot.py +0 -183
- arviz/plots/backends/bokeh/dotplot.py +0 -113
- arviz/plots/backends/bokeh/ecdfplot.py +0 -73
- arviz/plots/backends/bokeh/elpdplot.py +0 -203
- arviz/plots/backends/bokeh/energyplot.py +0 -155
- arviz/plots/backends/bokeh/essplot.py +0 -176
- arviz/plots/backends/bokeh/forestplot.py +0 -772
- arviz/plots/backends/bokeh/hdiplot.py +0 -54
- arviz/plots/backends/bokeh/kdeplot.py +0 -268
- arviz/plots/backends/bokeh/khatplot.py +0 -163
- arviz/plots/backends/bokeh/lmplot.py +0 -185
- arviz/plots/backends/bokeh/loopitplot.py +0 -211
- arviz/plots/backends/bokeh/mcseplot.py +0 -184
- arviz/plots/backends/bokeh/pairplot.py +0 -328
- arviz/plots/backends/bokeh/parallelplot.py +0 -81
- arviz/plots/backends/bokeh/posteriorplot.py +0 -324
- arviz/plots/backends/bokeh/ppcplot.py +0 -379
- arviz/plots/backends/bokeh/rankplot.py +0 -149
- arviz/plots/backends/bokeh/separationplot.py +0 -107
- arviz/plots/backends/bokeh/traceplot.py +0 -436
- arviz/plots/backends/bokeh/violinplot.py +0 -164
- arviz/plots/backends/matplotlib/__init__.py +0 -124
- arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
- arviz/plots/backends/matplotlib/bfplot.py +0 -78
- arviz/plots/backends/matplotlib/bpvplot.py +0 -177
- arviz/plots/backends/matplotlib/compareplot.py +0 -135
- arviz/plots/backends/matplotlib/densityplot.py +0 -194
- arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
- arviz/plots/backends/matplotlib/distplot.py +0 -178
- arviz/plots/backends/matplotlib/dotplot.py +0 -116
- arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
- arviz/plots/backends/matplotlib/elpdplot.py +0 -189
- arviz/plots/backends/matplotlib/energyplot.py +0 -113
- arviz/plots/backends/matplotlib/essplot.py +0 -180
- arviz/plots/backends/matplotlib/forestplot.py +0 -656
- arviz/plots/backends/matplotlib/hdiplot.py +0 -48
- arviz/plots/backends/matplotlib/kdeplot.py +0 -177
- arviz/plots/backends/matplotlib/khatplot.py +0 -241
- arviz/plots/backends/matplotlib/lmplot.py +0 -149
- arviz/plots/backends/matplotlib/loopitplot.py +0 -144
- arviz/plots/backends/matplotlib/mcseplot.py +0 -161
- arviz/plots/backends/matplotlib/pairplot.py +0 -355
- arviz/plots/backends/matplotlib/parallelplot.py +0 -58
- arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
- arviz/plots/backends/matplotlib/ppcplot.py +0 -478
- arviz/plots/backends/matplotlib/rankplot.py +0 -119
- arviz/plots/backends/matplotlib/separationplot.py +0 -97
- arviz/plots/backends/matplotlib/traceplot.py +0 -526
- arviz/plots/backends/matplotlib/tsplot.py +0 -121
- arviz/plots/backends/matplotlib/violinplot.py +0 -148
- arviz/plots/bfplot.py +0 -128
- arviz/plots/bpvplot.py +0 -308
- arviz/plots/compareplot.py +0 -177
- arviz/plots/densityplot.py +0 -284
- arviz/plots/distcomparisonplot.py +0 -197
- arviz/plots/distplot.py +0 -233
- arviz/plots/dotplot.py +0 -233
- arviz/plots/ecdfplot.py +0 -372
- arviz/plots/elpdplot.py +0 -174
- arviz/plots/energyplot.py +0 -147
- arviz/plots/essplot.py +0 -319
- arviz/plots/forestplot.py +0 -304
- arviz/plots/hdiplot.py +0 -211
- arviz/plots/kdeplot.py +0 -357
- arviz/plots/khatplot.py +0 -236
- arviz/plots/lmplot.py +0 -380
- arviz/plots/loopitplot.py +0 -224
- arviz/plots/mcseplot.py +0 -194
- arviz/plots/pairplot.py +0 -281
- arviz/plots/parallelplot.py +0 -204
- arviz/plots/plot_utils.py +0 -599
- arviz/plots/posteriorplot.py +0 -298
- arviz/plots/ppcplot.py +0 -369
- arviz/plots/rankplot.py +0 -232
- arviz/plots/separationplot.py +0 -167
- arviz/plots/styles/arviz-bluish.mplstyle +0 -1
- arviz/plots/styles/arviz-brownish.mplstyle +0 -1
- arviz/plots/styles/arviz-colors.mplstyle +0 -2
- arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
- arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
- arviz/plots/styles/arviz-doc.mplstyle +0 -88
- arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
- arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
- arviz/plots/styles/arviz-greenish.mplstyle +0 -1
- arviz/plots/styles/arviz-orangish.mplstyle +0 -1
- arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
- arviz/plots/styles/arviz-purplish.mplstyle +0 -1
- arviz/plots/styles/arviz-redish.mplstyle +0 -1
- arviz/plots/styles/arviz-royish.mplstyle +0 -1
- arviz/plots/styles/arviz-viridish.mplstyle +0 -1
- arviz/plots/styles/arviz-white.mplstyle +0 -40
- arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
- arviz/plots/traceplot.py +0 -273
- arviz/plots/tsplot.py +0 -440
- arviz/plots/violinplot.py +0 -192
- arviz/preview.py +0 -58
- arviz/py.typed +0 -0
- arviz/rcparams.py +0 -606
- arviz/sel_utils.py +0 -223
- arviz/static/css/style.css +0 -340
- arviz/static/html/icons-svg-inline.html +0 -15
- arviz/stats/__init__.py +0 -37
- arviz/stats/density_utils.py +0 -1013
- arviz/stats/diagnostics.py +0 -1013
- arviz/stats/ecdf_utils.py +0 -324
- arviz/stats/stats.py +0 -2422
- arviz/stats/stats_refitting.py +0 -119
- arviz/stats/stats_utils.py +0 -609
- arviz/tests/__init__.py +0 -1
- arviz/tests/base_tests/__init__.py +0 -1
- arviz/tests/base_tests/test_data.py +0 -1679
- arviz/tests/base_tests/test_data_zarr.py +0 -143
- arviz/tests/base_tests/test_diagnostics.py +0 -511
- arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
- arviz/tests/base_tests/test_helpers.py +0 -18
- arviz/tests/base_tests/test_labels.py +0 -69
- arviz/tests/base_tests/test_plot_utils.py +0 -342
- arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
- arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
- arviz/tests/base_tests/test_rcparams.py +0 -317
- arviz/tests/base_tests/test_stats.py +0 -925
- arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
- arviz/tests/base_tests/test_stats_numba.py +0 -45
- arviz/tests/base_tests/test_stats_utils.py +0 -384
- arviz/tests/base_tests/test_utils.py +0 -376
- arviz/tests/base_tests/test_utils_numba.py +0 -87
- arviz/tests/conftest.py +0 -46
- arviz/tests/external_tests/__init__.py +0 -1
- arviz/tests/external_tests/test_data_beanmachine.py +0 -78
- arviz/tests/external_tests/test_data_cmdstan.py +0 -398
- arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
- arviz/tests/external_tests/test_data_emcee.py +0 -166
- arviz/tests/external_tests/test_data_numpyro.py +0 -434
- arviz/tests/external_tests/test_data_pyjags.py +0 -119
- arviz/tests/external_tests/test_data_pyro.py +0 -260
- arviz/tests/external_tests/test_data_pystan.py +0 -307
- arviz/tests/helpers.py +0 -677
- arviz/utils.py +0 -773
- arviz/wrappers/__init__.py +0 -13
- arviz/wrappers/base.py +0 -236
- arviz/wrappers/wrap_pymc.py +0 -36
- arviz/wrappers/wrap_stan.py +0 -148
- arviz-0.23.3.dist-info/METADATA +0 -264
- arviz-0.23.3.dist-info/RECORD +0 -183
- arviz-0.23.3.dist-info/top_level.txt +0 -1
arviz/utils.py
DELETED
|
@@ -1,773 +0,0 @@
|
|
|
1
|
-
# pylint: disable=too-many-nested-blocks
|
|
2
|
-
"""General utilities."""
|
|
3
|
-
import functools
|
|
4
|
-
import importlib
|
|
5
|
-
import importlib.resources
|
|
6
|
-
import re
|
|
7
|
-
import warnings
|
|
8
|
-
from functools import lru_cache
|
|
9
|
-
|
|
10
|
-
import matplotlib.pyplot as plt
|
|
11
|
-
import numpy as np
|
|
12
|
-
from numpy import newaxis
|
|
13
|
-
|
|
14
|
-
from .rcparams import rcParams
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
STATIC_FILES = ("static/html/icons-svg-inline.html", "static/css/style.css")
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
class BehaviourChangeWarning(Warning):
|
|
21
|
-
"""Custom warning to ease filtering it."""
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
def _check_tilde_start(x):
|
|
25
|
-
return bool(isinstance(x, str) and x.startswith("~"))
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
def _var_names(var_names, data, filter_vars=None, errors="raise"):
|
|
29
|
-
"""Handle var_names input across arviz.
|
|
30
|
-
|
|
31
|
-
Parameters
|
|
32
|
-
----------
|
|
33
|
-
var_names: str, list, or None
|
|
34
|
-
data : xarray.Dataset
|
|
35
|
-
Posterior data in an xarray
|
|
36
|
-
filter_vars: {None, "like", "regex"}, optional, default=None
|
|
37
|
-
If `None` (default), interpret var_names as the real variables names. If "like",
|
|
38
|
-
interpret var_names as substrings of the real variables names. If "regex",
|
|
39
|
-
interpret var_names as regular expressions on the real variables names. A la
|
|
40
|
-
`pandas.filter`.
|
|
41
|
-
errors: {"raise", "ignore"}, optional, default="raise"
|
|
42
|
-
Select either to raise or ignore the invalid names.
|
|
43
|
-
|
|
44
|
-
Returns
|
|
45
|
-
-------
|
|
46
|
-
var_name: list or None
|
|
47
|
-
"""
|
|
48
|
-
if filter_vars not in {None, "like", "regex"}:
|
|
49
|
-
raise ValueError(
|
|
50
|
-
f"'filter_vars' can only be None, 'like', or 'regex', got: '{filter_vars}'"
|
|
51
|
-
)
|
|
52
|
-
|
|
53
|
-
if errors not in {"raise", "ignore"}:
|
|
54
|
-
raise ValueError(f"'errors' can only be 'raise', or 'ignore', got: '{errors}'")
|
|
55
|
-
|
|
56
|
-
if var_names is not None:
|
|
57
|
-
if isinstance(data, (list, tuple)):
|
|
58
|
-
all_vars = []
|
|
59
|
-
for dataset in data:
|
|
60
|
-
dataset_vars = list(dataset.data_vars)
|
|
61
|
-
for var in dataset_vars:
|
|
62
|
-
if var not in all_vars:
|
|
63
|
-
all_vars.append(var)
|
|
64
|
-
else:
|
|
65
|
-
all_vars = list(data.data_vars)
|
|
66
|
-
|
|
67
|
-
all_vars_tilde = [var for var in all_vars if _check_tilde_start(var)]
|
|
68
|
-
if all_vars_tilde:
|
|
69
|
-
warnings.warn(
|
|
70
|
-
"""ArviZ treats '~' as a negation character for variable selection.
|
|
71
|
-
Your model has variables names starting with '~', {0}. Please double check
|
|
72
|
-
your results to ensure all variables are included""".format(
|
|
73
|
-
", ".join(all_vars_tilde)
|
|
74
|
-
)
|
|
75
|
-
)
|
|
76
|
-
|
|
77
|
-
try:
|
|
78
|
-
var_names = _subset_list(
|
|
79
|
-
var_names, all_vars, filter_items=filter_vars, warn=False, errors=errors
|
|
80
|
-
)
|
|
81
|
-
except KeyError as err:
|
|
82
|
-
msg = " ".join(("var names:", f"{err}", "in dataset"))
|
|
83
|
-
raise KeyError(msg) from err
|
|
84
|
-
return var_names
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
def _subset_list(subset, whole_list, filter_items=None, warn=True, errors="raise"):
|
|
88
|
-
"""Handle list subsetting (var_names, groups...) across arviz.
|
|
89
|
-
|
|
90
|
-
Parameters
|
|
91
|
-
----------
|
|
92
|
-
subset : str, list, or None
|
|
93
|
-
whole_list : list
|
|
94
|
-
List from which to select a subset according to subset elements and
|
|
95
|
-
filter_items value.
|
|
96
|
-
filter_items : {None, "like", "regex"}, optional
|
|
97
|
-
If `None` (default), interpret `subset` as the exact elements in `whole_list`
|
|
98
|
-
names. If "like", interpret `subset` as substrings of the elements in
|
|
99
|
-
`whole_list`. If "regex", interpret `subset` as regular expressions to match
|
|
100
|
-
elements in `whole_list`. A la `pandas.filter`.
|
|
101
|
-
errors: {"raise", "ignore"}, optional, default="raise"
|
|
102
|
-
Select either to raise or ignore the invalid names.
|
|
103
|
-
|
|
104
|
-
Returns
|
|
105
|
-
-------
|
|
106
|
-
list or None
|
|
107
|
-
A subset of ``whole_list`` fulfilling the requests imposed by ``subset``
|
|
108
|
-
and ``filter_items``.
|
|
109
|
-
"""
|
|
110
|
-
if subset is not None:
|
|
111
|
-
if isinstance(subset, str):
|
|
112
|
-
subset = [subset]
|
|
113
|
-
|
|
114
|
-
whole_list_tilde = [item for item in whole_list if _check_tilde_start(item)]
|
|
115
|
-
if whole_list_tilde and warn:
|
|
116
|
-
warnings.warn(
|
|
117
|
-
"ArviZ treats '~' as a negation character for selection. There are "
|
|
118
|
-
"elements in `whole_list` starting with '~', {0}. Please double check"
|
|
119
|
-
"your results to ensure all elements are included".format(
|
|
120
|
-
", ".join(whole_list_tilde)
|
|
121
|
-
)
|
|
122
|
-
)
|
|
123
|
-
|
|
124
|
-
excluded_items = [
|
|
125
|
-
item[1:] for item in subset if _check_tilde_start(item) and item not in whole_list
|
|
126
|
-
]
|
|
127
|
-
filter_items = str(filter_items).lower()
|
|
128
|
-
if excluded_items:
|
|
129
|
-
not_found = []
|
|
130
|
-
|
|
131
|
-
if filter_items in {"like", "regex"}:
|
|
132
|
-
for pattern in excluded_items[:]:
|
|
133
|
-
excluded_items.remove(pattern)
|
|
134
|
-
if filter_items == "like":
|
|
135
|
-
real_items = [real_item for real_item in whole_list if pattern in real_item]
|
|
136
|
-
else:
|
|
137
|
-
# i.e filter_items == "regex"
|
|
138
|
-
real_items = [
|
|
139
|
-
real_item for real_item in whole_list if re.search(pattern, real_item)
|
|
140
|
-
]
|
|
141
|
-
if not real_items:
|
|
142
|
-
not_found.append(pattern)
|
|
143
|
-
excluded_items.extend(real_items)
|
|
144
|
-
not_found.extend([item for item in excluded_items if item not in whole_list])
|
|
145
|
-
if not_found:
|
|
146
|
-
warnings.warn(
|
|
147
|
-
f"Items starting with ~: {not_found} have not been found and will be ignored"
|
|
148
|
-
)
|
|
149
|
-
subset = [item for item in whole_list if item not in excluded_items]
|
|
150
|
-
|
|
151
|
-
elif filter_items == "like":
|
|
152
|
-
subset = [item for item in whole_list for name in subset if name in item]
|
|
153
|
-
elif filter_items == "regex":
|
|
154
|
-
subset = [item for item in whole_list for name in subset if re.search(name, item)]
|
|
155
|
-
|
|
156
|
-
existing_items = np.isin(subset, whole_list)
|
|
157
|
-
if not np.all(existing_items) and (errors == "raise"):
|
|
158
|
-
raise KeyError(f"{np.array(subset)[~existing_items]} are not present")
|
|
159
|
-
|
|
160
|
-
return subset
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
class lazy_property: # pylint: disable=invalid-name
|
|
164
|
-
"""Used to load numba first time it is needed."""
|
|
165
|
-
|
|
166
|
-
def __init__(self, fget):
|
|
167
|
-
"""Lazy load a property with `fget`."""
|
|
168
|
-
self.fget = fget
|
|
169
|
-
|
|
170
|
-
# copy the getter function's docstring and other attributes
|
|
171
|
-
functools.update_wrapper(self, fget)
|
|
172
|
-
|
|
173
|
-
def __get__(self, obj, cls):
|
|
174
|
-
"""Call the function, set the attribute."""
|
|
175
|
-
if obj is None:
|
|
176
|
-
return self
|
|
177
|
-
|
|
178
|
-
value = self.fget(obj)
|
|
179
|
-
setattr(obj, self.fget.__name__, value)
|
|
180
|
-
return value
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
class maybe_numba_fn: # pylint: disable=invalid-name
|
|
184
|
-
"""Wrap a function to (maybe) use a (lazy) jit-compiled version."""
|
|
185
|
-
|
|
186
|
-
def __init__(self, function, **kwargs):
|
|
187
|
-
"""Wrap a function and save compilation keywords."""
|
|
188
|
-
self.function = function
|
|
189
|
-
kwargs.setdefault("nopython", True)
|
|
190
|
-
self.kwargs = kwargs
|
|
191
|
-
|
|
192
|
-
@lazy_property
|
|
193
|
-
def numba_fn(self):
|
|
194
|
-
"""Memoized compiled function."""
|
|
195
|
-
try:
|
|
196
|
-
numba = importlib.import_module("numba")
|
|
197
|
-
numba_fn = numba.jit(**self.kwargs)(self.function)
|
|
198
|
-
except ImportError:
|
|
199
|
-
numba_fn = self.function
|
|
200
|
-
return numba_fn
|
|
201
|
-
|
|
202
|
-
def __call__(self, *args, **kwargs):
|
|
203
|
-
"""Call the jitted function or normal, depending on flag."""
|
|
204
|
-
if Numba.numba_flag:
|
|
205
|
-
return self.numba_fn(*args, **kwargs)
|
|
206
|
-
else:
|
|
207
|
-
return self.function(*args, **kwargs)
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
class interactive_backend: # pylint: disable=invalid-name
|
|
211
|
-
"""Context manager to change backend temporarily in ipython sesson.
|
|
212
|
-
|
|
213
|
-
It uses ipython magic to change temporarily from the ipython inline backend to
|
|
214
|
-
an interactive backend of choice. It cannot be used outside ipython sessions nor
|
|
215
|
-
to change backends different than inline -> interactive.
|
|
216
|
-
|
|
217
|
-
Notes
|
|
218
|
-
-----
|
|
219
|
-
The first time ``interactive_backend`` context manager is called, any of the available
|
|
220
|
-
interactive backends can be chosen. The following times, this same backend must be used
|
|
221
|
-
unless the kernel is restarted.
|
|
222
|
-
|
|
223
|
-
Parameters
|
|
224
|
-
----------
|
|
225
|
-
backend : str, optional
|
|
226
|
-
Interactive backend to use. It will be passed to ``%matplotlib`` magic, refer to
|
|
227
|
-
its docs to see available options.
|
|
228
|
-
|
|
229
|
-
Examples
|
|
230
|
-
--------
|
|
231
|
-
Inside an ipython session (i.e. a jupyter notebook) with the inline backend set:
|
|
232
|
-
|
|
233
|
-
.. code::
|
|
234
|
-
|
|
235
|
-
>>> import arviz as az
|
|
236
|
-
>>> idata = az.load_arviz_data("centered_eight")
|
|
237
|
-
>>> az.plot_posterior(idata) # inline
|
|
238
|
-
>>> with az.interactive_backend():
|
|
239
|
-
... az.plot_density(idata) # interactive
|
|
240
|
-
>>> az.plot_trace(idata) # inline
|
|
241
|
-
|
|
242
|
-
"""
|
|
243
|
-
|
|
244
|
-
# based on matplotlib.rc_context
|
|
245
|
-
def __init__(self, backend=""):
|
|
246
|
-
"""Initialize context manager."""
|
|
247
|
-
try:
|
|
248
|
-
from IPython import get_ipython
|
|
249
|
-
except ImportError as err:
|
|
250
|
-
raise ImportError(
|
|
251
|
-
"The exception below was risen while importing Ipython, this "
|
|
252
|
-
f"context manager can only be used inside ipython sessions:\n{err}"
|
|
253
|
-
) from err
|
|
254
|
-
self.ipython = get_ipython()
|
|
255
|
-
if self.ipython is None:
|
|
256
|
-
raise EnvironmentError("This context manager can only be used inside ipython sessions")
|
|
257
|
-
self.ipython.magic(f"matplotlib {backend}")
|
|
258
|
-
|
|
259
|
-
def __enter__(self):
|
|
260
|
-
"""Enter context manager."""
|
|
261
|
-
return self
|
|
262
|
-
|
|
263
|
-
def __exit__(self, exc_type, exc_value, exc_tb):
|
|
264
|
-
"""Exit context manager."""
|
|
265
|
-
plt.show(block=True)
|
|
266
|
-
self.ipython.magic("matplotlib inline")
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
def conditional_jit(_func=None, **kwargs):
|
|
270
|
-
"""Use numba's jit decorator if numba is installed.
|
|
271
|
-
|
|
272
|
-
Notes
|
|
273
|
-
-----
|
|
274
|
-
If called without arguments then return wrapped function.
|
|
275
|
-
|
|
276
|
-
@conditional_jit
|
|
277
|
-
def my_func():
|
|
278
|
-
return
|
|
279
|
-
|
|
280
|
-
else called with arguments
|
|
281
|
-
|
|
282
|
-
@conditional_jit(nopython=True)
|
|
283
|
-
def my_func():
|
|
284
|
-
return
|
|
285
|
-
|
|
286
|
-
"""
|
|
287
|
-
if _func is None:
|
|
288
|
-
return lambda fn: functools.wraps(fn)(maybe_numba_fn(fn, **kwargs))
|
|
289
|
-
lazy_numba = maybe_numba_fn(_func, **kwargs)
|
|
290
|
-
return functools.wraps(_func)(lazy_numba)
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
def conditional_vect(function=None, **kwargs): # noqa: D202
|
|
294
|
-
"""Use numba's vectorize decorator if numba is installed.
|
|
295
|
-
|
|
296
|
-
Notes
|
|
297
|
-
-----
|
|
298
|
-
If called without arguments then return wrapped function.
|
|
299
|
-
@conditional_vect
|
|
300
|
-
def my_func():
|
|
301
|
-
return
|
|
302
|
-
else called with arguments
|
|
303
|
-
@conditional_vect(nopython=True)
|
|
304
|
-
def my_func():
|
|
305
|
-
return
|
|
306
|
-
|
|
307
|
-
"""
|
|
308
|
-
|
|
309
|
-
def wrapper(function):
|
|
310
|
-
try:
|
|
311
|
-
numba = importlib.import_module("numba")
|
|
312
|
-
return numba.vectorize(**kwargs)(function)
|
|
313
|
-
|
|
314
|
-
except ImportError:
|
|
315
|
-
return function
|
|
316
|
-
|
|
317
|
-
if function:
|
|
318
|
-
return wrapper(function)
|
|
319
|
-
else:
|
|
320
|
-
return wrapper
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
def numba_check():
|
|
324
|
-
"""Check if numba is installed."""
|
|
325
|
-
numba = importlib.util.find_spec("numba")
|
|
326
|
-
return numba is not None
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
class Numba:
|
|
330
|
-
"""A class to toggle numba states."""
|
|
331
|
-
|
|
332
|
-
numba_flag = numba_check()
|
|
333
|
-
"""bool: Indicates whether Numba optimizations are enabled. Defaults to False."""
|
|
334
|
-
|
|
335
|
-
@classmethod
|
|
336
|
-
def disable_numba(cls):
|
|
337
|
-
"""To disable numba."""
|
|
338
|
-
cls.numba_flag = False
|
|
339
|
-
|
|
340
|
-
@classmethod
|
|
341
|
-
def enable_numba(cls):
|
|
342
|
-
"""To enable numba."""
|
|
343
|
-
if numba_check():
|
|
344
|
-
cls.numba_flag = True
|
|
345
|
-
else:
|
|
346
|
-
raise ValueError("Numba is not installed")
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
def _numba_var(numba_function, standard_numpy_func, data, axis=None, ddof=0):
|
|
350
|
-
"""Replace the numpy methods used to calculate variance.
|
|
351
|
-
|
|
352
|
-
Parameters
|
|
353
|
-
----------
|
|
354
|
-
numba_function : function()
|
|
355
|
-
Custom numba function included in stats/stats_utils.py.
|
|
356
|
-
|
|
357
|
-
standard_numpy_func: function()
|
|
358
|
-
Standard function included in the numpy library.
|
|
359
|
-
|
|
360
|
-
data : array.
|
|
361
|
-
axis : axis along which the variance is calculated.
|
|
362
|
-
ddof : degrees of freedom allowed while calculating variance.
|
|
363
|
-
|
|
364
|
-
Returns
|
|
365
|
-
-------
|
|
366
|
-
array:
|
|
367
|
-
variance values calculate by appropriate function for numba speedup
|
|
368
|
-
if Numba is installed or enabled.
|
|
369
|
-
|
|
370
|
-
"""
|
|
371
|
-
if Numba.numba_flag:
|
|
372
|
-
return numba_function(data, axis=axis, ddof=ddof)
|
|
373
|
-
else:
|
|
374
|
-
return standard_numpy_func(data, axis=axis, ddof=ddof)
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
def _stack(x, y):
|
|
378
|
-
assert x.shape[1:] == y.shape[1:]
|
|
379
|
-
return np.vstack((x, y))
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
def arange(x):
|
|
383
|
-
"""Jitting numpy arange."""
|
|
384
|
-
return np.arange(x)
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
def one_de(x):
|
|
388
|
-
"""Jitting numpy atleast_1d."""
|
|
389
|
-
if not isinstance(x, np.ndarray):
|
|
390
|
-
return np.atleast_1d(x)
|
|
391
|
-
result = x.reshape(1) if x.ndim == 0 else x
|
|
392
|
-
return result
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
def two_de(x):
|
|
396
|
-
"""Jitting numpy at_least_2d."""
|
|
397
|
-
if not isinstance(x, np.ndarray):
|
|
398
|
-
return np.atleast_2d(x)
|
|
399
|
-
if x.ndim == 0:
|
|
400
|
-
result = x.reshape(1, 1)
|
|
401
|
-
elif x.ndim == 1:
|
|
402
|
-
result = x[newaxis, :]
|
|
403
|
-
else:
|
|
404
|
-
result = x
|
|
405
|
-
return result
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
def expand_dims(x):
|
|
409
|
-
"""Jitting numpy expand_dims."""
|
|
410
|
-
if not isinstance(x, np.ndarray):
|
|
411
|
-
return np.expand_dims(x, 0)
|
|
412
|
-
shape = x.shape
|
|
413
|
-
return x.reshape(shape[:0] + (1,) + shape[0:])
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
@conditional_jit(cache=True, nopython=True)
|
|
417
|
-
def _dot(x, y):
|
|
418
|
-
return np.dot(x, y)
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
@conditional_jit(cache=True, nopython=True)
|
|
422
|
-
def _cov_1d(x):
|
|
423
|
-
x = x - x.mean()
|
|
424
|
-
ddof = x.shape[0] - 1
|
|
425
|
-
return np.dot(x.T, x.conj()) / ddof
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
# @conditional_jit(cache=True)
|
|
429
|
-
def _cov(data):
|
|
430
|
-
if data.ndim == 1:
|
|
431
|
-
return _cov_1d(data)
|
|
432
|
-
elif data.ndim == 2:
|
|
433
|
-
x = data.astype(float)
|
|
434
|
-
avg, _ = np.average(x, axis=1, weights=None, returned=True)
|
|
435
|
-
ddof = x.shape[1] - 1
|
|
436
|
-
if ddof <= 0:
|
|
437
|
-
warnings.warn("Degrees of freedom <= 0 for slice", RuntimeWarning, stacklevel=2)
|
|
438
|
-
ddof = 0.0
|
|
439
|
-
x -= avg[:, None]
|
|
440
|
-
prod = _dot(x, x.T.conj())
|
|
441
|
-
prod *= np.true_divide(1, ddof)
|
|
442
|
-
prod = prod.squeeze()
|
|
443
|
-
prod += 1e-6 * np.eye(prod.shape[0])
|
|
444
|
-
return prod
|
|
445
|
-
else:
|
|
446
|
-
raise ValueError(f"{data.ndim} dimension arrays are not supported")
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
def flatten_inference_data_to_dict(
|
|
450
|
-
data,
|
|
451
|
-
var_names=None,
|
|
452
|
-
groups=None,
|
|
453
|
-
dimensions=None,
|
|
454
|
-
group_info=False,
|
|
455
|
-
var_name_format=None,
|
|
456
|
-
index_origin=None,
|
|
457
|
-
):
|
|
458
|
-
"""Transform data to dictionary.
|
|
459
|
-
|
|
460
|
-
Parameters
|
|
461
|
-
----------
|
|
462
|
-
data : obj
|
|
463
|
-
Any object that can be converted to an az.InferenceData object
|
|
464
|
-
Refer to documentation of az.convert_to_inference_data for details
|
|
465
|
-
var_names : str or list of str, optional
|
|
466
|
-
Variables to be processed, if None all variables are processed.
|
|
467
|
-
groups : str or list of str, optional
|
|
468
|
-
Select groups for CDS. Default groups are
|
|
469
|
-
{"posterior_groups", "prior_groups", "posterior_groups_warmup"}
|
|
470
|
-
- posterior_groups: posterior, posterior_predictive, sample_stats
|
|
471
|
-
- prior_groups: prior, prior_predictive, sample_stats_prior
|
|
472
|
-
- posterior_groups_warmup: warmup_posterior, warmup_posterior_predictive,
|
|
473
|
-
warmup_sample_stats
|
|
474
|
-
ignore_groups : str or list of str, optional
|
|
475
|
-
Ignore specific groups from CDS.
|
|
476
|
-
dimension : str, or list of str, optional
|
|
477
|
-
Select dimensions along to slice the data. By default uses ("chain", "draw").
|
|
478
|
-
group_info : bool
|
|
479
|
-
Add group info for `var_name_format`
|
|
480
|
-
var_name_format : str or tuple of tuple of string, optional
|
|
481
|
-
Select column name format for non-scalar input.
|
|
482
|
-
Predefined options are {"brackets", "underscore", "cds"}
|
|
483
|
-
"brackets":
|
|
484
|
-
- add_group_info == False: theta[0,0]
|
|
485
|
-
- add_group_info == True: theta_posterior[0,0]
|
|
486
|
-
"underscore":
|
|
487
|
-
- add_group_info == False: theta_0_0
|
|
488
|
-
- add_group_info == True: theta_posterior_0_0_
|
|
489
|
-
"cds":
|
|
490
|
-
- add_group_info == False: theta_ARVIZ_CDS_SELECTION_0_0
|
|
491
|
-
- add_group_info == True: theta_ARVIZ_GROUP_posterior__ARVIZ_CDS_SELECTION_0_0
|
|
492
|
-
tuple:
|
|
493
|
-
Structure:
|
|
494
|
-
tuple: (dim_info, group_info)
|
|
495
|
-
dim_info: (str: `.join` separator,
|
|
496
|
-
str: dim_separator_start,
|
|
497
|
-
str: dim_separator_end)
|
|
498
|
-
group_info: (str: group separator start, str: group separator end)
|
|
499
|
-
Example: ((",", "[", "]"), ("_", ""))
|
|
500
|
-
- add_group_info == False: theta[0,0]
|
|
501
|
-
- add_group_info == True: theta_posterior[0,0]
|
|
502
|
-
index_origin : int, optional
|
|
503
|
-
Start parameter indices from `index_origin`. Either 0 or 1.
|
|
504
|
-
|
|
505
|
-
Returns
|
|
506
|
-
-------
|
|
507
|
-
dict
|
|
508
|
-
"""
|
|
509
|
-
from .data import convert_to_inference_data
|
|
510
|
-
|
|
511
|
-
data = convert_to_inference_data(data)
|
|
512
|
-
|
|
513
|
-
if groups is None:
|
|
514
|
-
groups = ["posterior", "posterior_predictive", "sample_stats"]
|
|
515
|
-
elif isinstance(groups, str):
|
|
516
|
-
if groups.lower() == "posterior_groups":
|
|
517
|
-
groups = ["posterior", "posterior_predictive", "sample_stats"]
|
|
518
|
-
elif groups.lower() == "prior_groups":
|
|
519
|
-
groups = ["prior", "prior_predictive", "sample_stats_prior"]
|
|
520
|
-
elif groups.lower() == "posterior_groups_warmup":
|
|
521
|
-
groups = ["warmup_posterior", "warmup_posterior_predictive", "warmup_sample_stats"]
|
|
522
|
-
else:
|
|
523
|
-
raise TypeError(
|
|
524
|
-
(
|
|
525
|
-
"Valid predefined groups are "
|
|
526
|
-
"{posterior_groups, prior_groups, posterior_groups_warmup}"
|
|
527
|
-
)
|
|
528
|
-
)
|
|
529
|
-
|
|
530
|
-
if dimensions is None:
|
|
531
|
-
dimensions = "chain", "draw"
|
|
532
|
-
elif isinstance(dimensions, str):
|
|
533
|
-
dimensions = (dimensions,)
|
|
534
|
-
|
|
535
|
-
if var_name_format is None:
|
|
536
|
-
var_name_format = "brackets"
|
|
537
|
-
|
|
538
|
-
if isinstance(var_name_format, str):
|
|
539
|
-
var_name_format = var_name_format.lower()
|
|
540
|
-
|
|
541
|
-
if var_name_format == "brackets":
|
|
542
|
-
dim_join_separator, dim_separator_start, dim_separator_end = ",", "[", "]"
|
|
543
|
-
group_separator_start, group_separator_end = "_", ""
|
|
544
|
-
elif var_name_format == "underscore":
|
|
545
|
-
dim_join_separator, dim_separator_start, dim_separator_end = "_", "_", ""
|
|
546
|
-
group_separator_start, group_separator_end = "_", ""
|
|
547
|
-
elif var_name_format == "cds":
|
|
548
|
-
dim_join_separator, dim_separator_start, dim_separator_end = (
|
|
549
|
-
"_",
|
|
550
|
-
"_ARVIZ_CDS_SELECTION_",
|
|
551
|
-
"",
|
|
552
|
-
)
|
|
553
|
-
group_separator_start, group_separator_end = "_ARVIZ_GROUP_", ""
|
|
554
|
-
elif isinstance(var_name_format, str):
|
|
555
|
-
msg = 'Invalid predefined format. Select one {"brackets", "underscore", "cds"}'
|
|
556
|
-
raise TypeError(msg)
|
|
557
|
-
else:
|
|
558
|
-
(
|
|
559
|
-
(dim_join_separator, dim_separator_start, dim_separator_end),
|
|
560
|
-
(group_separator_start, group_separator_end),
|
|
561
|
-
) = var_name_format
|
|
562
|
-
|
|
563
|
-
if index_origin is None:
|
|
564
|
-
index_origin = rcParams["data.index_origin"]
|
|
565
|
-
|
|
566
|
-
data_dict = {}
|
|
567
|
-
for group in groups:
|
|
568
|
-
if hasattr(data, group):
|
|
569
|
-
group_data = getattr(data, group).stack(stack_dimension=dimensions)
|
|
570
|
-
for var_name, var in group_data.data_vars.items():
|
|
571
|
-
var_values = var.values
|
|
572
|
-
if var_names is not None and var_name not in var_names:
|
|
573
|
-
continue
|
|
574
|
-
for dim_name in dimensions:
|
|
575
|
-
if dim_name not in data_dict:
|
|
576
|
-
data_dict[dim_name] = var.coords.get(dim_name).values
|
|
577
|
-
if len(var.shape) == 1:
|
|
578
|
-
if group_info:
|
|
579
|
-
var_name_dim = (
|
|
580
|
-
"{var_name}" "{group_separator_start}{group}{group_separator_end}"
|
|
581
|
-
).format(
|
|
582
|
-
var_name=var_name,
|
|
583
|
-
group_separator_start=group_separator_start,
|
|
584
|
-
group=group,
|
|
585
|
-
group_separator_end=group_separator_end,
|
|
586
|
-
)
|
|
587
|
-
else:
|
|
588
|
-
var_name_dim = f"{var_name}"
|
|
589
|
-
data_dict[var_name_dim] = var.values
|
|
590
|
-
else:
|
|
591
|
-
for loc in np.ndindex(var.shape[:-1]):
|
|
592
|
-
if group_info:
|
|
593
|
-
var_name_dim = (
|
|
594
|
-
"{var_name}"
|
|
595
|
-
"{group_separator_start}{group}{group_separator_end}"
|
|
596
|
-
"{dim_separator_start}{dim_join}{dim_separator_end}"
|
|
597
|
-
).format(
|
|
598
|
-
var_name=var_name,
|
|
599
|
-
group_separator_start=group_separator_start,
|
|
600
|
-
group=group,
|
|
601
|
-
group_separator_end=group_separator_end,
|
|
602
|
-
dim_separator_start=dim_separator_start,
|
|
603
|
-
dim_join=dim_join_separator.join(
|
|
604
|
-
(str(item + index_origin) for item in loc)
|
|
605
|
-
),
|
|
606
|
-
dim_separator_end=dim_separator_end,
|
|
607
|
-
)
|
|
608
|
-
else:
|
|
609
|
-
var_name_dim = (
|
|
610
|
-
"{var_name}" "{dim_separator_start}{dim_join}{dim_separator_end}"
|
|
611
|
-
).format(
|
|
612
|
-
var_name=var_name,
|
|
613
|
-
dim_separator_start=dim_separator_start,
|
|
614
|
-
dim_join=dim_join_separator.join(
|
|
615
|
-
(str(item + index_origin) for item in loc)
|
|
616
|
-
),
|
|
617
|
-
dim_separator_end=dim_separator_end,
|
|
618
|
-
)
|
|
619
|
-
|
|
620
|
-
data_dict[var_name_dim] = var_values[loc]
|
|
621
|
-
return data_dict
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
def get_coords(data, coords):
    """Subset an xarray Dataset or DataArray to the provided coords.

    Accepts either a single xarray object or a list/tuple of them; in the
    latter case the selection is applied element-wise (recursively), and a
    single coords mapping is broadcast to every element when needed.

    Raises
    ------
    ValueError
        If coords names are not available in data.
    KeyError
        If coords dims are not available in data.

    Returns
    -------
    data : xarray.Dataset or xarray.DataArray
        Same type as the input, restricted to the requested coords.
    """
    if isinstance(data, (list, tuple)):
        # Broadcast one mapping over every element of the collection.
        if not isinstance(coords, (list, tuple)):
            coords = [coords] * len(data)
        subset = []
        for position, (element, selection) in enumerate(zip(data, coords)):
            try:
                subset.append(get_coords(element, selection))
            except ValueError as err:
                raise ValueError(f"Error in data[{position}]: {err}") from err
            except KeyError as err:
                raise KeyError(f"Error in data[{position}]: {err}") from err
        return subset

    try:
        return data.sel(**coords)
    except ValueError as err:
        # .sel raises ValueError for unknown coordinate names; report which
        # requested keys are missing from the data.
        bad_keys = set(coords.keys()) - set(data.coords.keys())
        raise ValueError(f"Coords {bad_keys} are invalid coordinate keys") from err
    except KeyError as err:
        # .sel raises KeyError for coordinate values/dims that do not exist.
        raise KeyError(
            (
                "Coords should follow mapping format {{coord_name:[dim1, dim2]}}. "
                "Check that coords structure is correct and"
                " dimensions are valid. {}"
            ).format(err)
        ) from err
|
|
667
|
-
|
|
668
|
-
|
|
669
|
-
@lru_cache(None)
def _load_static_files():
    """Lazily load the resource files into memory the first time they are needed.

    Clone from xarray.core.formatted_html_template.
    """
    # Resolve the package root once, then read every static resource as text;
    # lru_cache makes this a one-time cost per process.
    package_root = importlib.resources.files("arviz")
    return [
        package_root.joinpath(name).read_text(encoding="utf-8") for name in STATIC_FILES
    ]
|
|
679
|
-
|
|
680
|
-
|
|
681
|
-
class HtmlTemplate:
    """Contain html templates for InferenceData repr."""

    # Outer shell of the HTML repr: an object-type header plus a <ul> whose
    # single ``{}`` placeholder receives the concatenated per-group sections.
    html_template = """
            <div>
              <div class='xr-header'>
                <div class="xr-obj-type">arviz.InferenceData</div>
              </div>
              <ul class="xr-sections group-sections">
              {}
              </ul>
            </div>
            """
    # One collapsible section per InferenceData group; formatted with
    # ``group_id`` (pairs the checkbox with its label), ``group`` (display
    # name) and ``xr_data`` (the group dataset's own HTML repr).
    element_template = """
            <li class = "xr-section-item">
                  <input id="idata_{group_id}" class="xr-section-summary-in" type="checkbox">
                  <label for="idata_{group_id}" class = "xr-section-summary">{group}</label>
                  <div class="xr-section-inline-details"></div>
                  <div class="xr-section-details">
                      <ul id="xr-dataset-coord-list" class="xr-var-list">
                          <div style="padding-left:2rem;">{xr_data}<br></div>
                      </ul>
                  </div>
            </li>
            """
    # Second entry of the static resource files — presumably the CSS cloned
    # from xarray's HTML repr (TODO confirm against STATIC_FILES ordering).
    # Evaluated once, when the class body executes at import time.
    _, css_style = _load_static_files()  # pylint: disable=protected-access
    # Extra rule forcing a fixed repr width on top of the xarray defaults.
    specific_style = ".xr-wrap{width:700px!important;}"
    css_template = f"<style> {css_style}{specific_style} </style>"
|
|
709
|
-
|
|
710
|
-
|
|
711
|
-
def either_dict_or_kwargs(
    pos_kwargs,
    kw_kwargs,
    func_name,
):
    """Resolve the dict-vs-kwargs calling convention; clone from xarray.core.utils.

    Exactly one of ``pos_kwargs`` (a mapping passed positionally) or
    ``kw_kwargs`` (arguments collected via ``**kwargs``) may carry the values;
    ``func_name`` is only used to build error messages.
    """
    if pos_kwargs is None:
        # Nothing passed positionally: the keyword form wins.
        return kw_kwargs
    # Reject sequence-like positional arguments: subscriptable but without a
    # mapping's ``keys`` method.
    if hasattr(pos_kwargs, "__getitem__") and not hasattr(pos_kwargs, "keys"):
        raise ValueError(f"the first argument to .{func_name} must be a dictionary")
    if kw_kwargs:
        raise ValueError(f"cannot specify both keyword and positional arguments to .{func_name}")
    return pos_kwargs
|
|
724
|
-
|
|
725
|
-
|
|
726
|
-
class Dask:
    """Class to toggle Dask states.

    Warnings
    --------
    Dask integration is an experimental feature still in progress. It can already be used
    but it doesn't work with all stats nor diagnostics yet.
    """

    dask_flag = False
    """bool: Enables Dask parallelization when set to True. Defaults to False."""
    dask_kwargs = None
    """dict or None: Additional keyword arguments for Dask configuration.
    Defaults to None."""

    @classmethod
    def enable_dask(cls, dask_kwargs=None):
        """To enable Dask.

        Parameters
        ----------
        dask_kwargs : dict, optional
            Dask related kwargs passed to :func:`~arviz.wrap_xarray_ufunc`.
            When None, no extra kwargs are stored.
        """
        cls.dask_flag = True
        cls.dask_kwargs = dask_kwargs

    @classmethod
    def disable_dask(cls):
        """To disable Dask and clear any stored Dask kwargs."""
        cls.dask_flag = False
        cls.dask_kwargs = None
|
|
758
|
-
|
|
759
|
-
|
|
760
|
-
def conditional_dask(func):
    """Conditionally pass dask kwargs to `wrap_xarray_ufunc`.

    When :class:`Dask` is enabled, inject the globally configured
    ``Dask.dask_kwargs`` — merged with any caller-supplied ``dask_kwargs``,
    with the caller's values taking precedence — into ``func``. When Dask is
    disabled, ``func`` is called untouched.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not Dask.dask_flag:
            return func(*args, **kwargs)
        user_kwargs = kwargs.pop("dask_kwargs", None)
        if user_kwargs is None:
            user_kwargs = {}
        # BUG FIX: Dask.dask_kwargs is None when enable_dask() was called
        # without arguments; unpacking None with ** raised TypeError here.
        default_kwargs = Dask.dask_kwargs or {}
        return func(*args, dask_kwargs={**default_kwargs, **user_kwargs}, **kwargs)

    return wrapper
|