arviz 0.23.3__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arviz/__init__.py +52 -367
- arviz-1.0.0rc0.dist-info/METADATA +182 -0
- arviz-1.0.0rc0.dist-info/RECORD +5 -0
- {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
- {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
- arviz/data/__init__.py +0 -55
- arviz/data/base.py +0 -596
- arviz/data/converters.py +0 -203
- arviz/data/datasets.py +0 -161
- arviz/data/example_data/code/radon/radon.json +0 -326
- arviz/data/example_data/data/centered_eight.nc +0 -0
- arviz/data/example_data/data/non_centered_eight.nc +0 -0
- arviz/data/example_data/data_local.json +0 -12
- arviz/data/example_data/data_remote.json +0 -58
- arviz/data/inference_data.py +0 -2386
- arviz/data/io_beanmachine.py +0 -112
- arviz/data/io_cmdstan.py +0 -1036
- arviz/data/io_cmdstanpy.py +0 -1233
- arviz/data/io_datatree.py +0 -23
- arviz/data/io_dict.py +0 -462
- arviz/data/io_emcee.py +0 -317
- arviz/data/io_json.py +0 -54
- arviz/data/io_netcdf.py +0 -68
- arviz/data/io_numpyro.py +0 -497
- arviz/data/io_pyjags.py +0 -378
- arviz/data/io_pyro.py +0 -333
- arviz/data/io_pystan.py +0 -1095
- arviz/data/io_zarr.py +0 -46
- arviz/data/utils.py +0 -139
- arviz/labels.py +0 -210
- arviz/plots/__init__.py +0 -61
- arviz/plots/autocorrplot.py +0 -171
- arviz/plots/backends/__init__.py +0 -223
- arviz/plots/backends/bokeh/__init__.py +0 -166
- arviz/plots/backends/bokeh/autocorrplot.py +0 -101
- arviz/plots/backends/bokeh/bfplot.py +0 -23
- arviz/plots/backends/bokeh/bpvplot.py +0 -193
- arviz/plots/backends/bokeh/compareplot.py +0 -167
- arviz/plots/backends/bokeh/densityplot.py +0 -239
- arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
- arviz/plots/backends/bokeh/distplot.py +0 -183
- arviz/plots/backends/bokeh/dotplot.py +0 -113
- arviz/plots/backends/bokeh/ecdfplot.py +0 -73
- arviz/plots/backends/bokeh/elpdplot.py +0 -203
- arviz/plots/backends/bokeh/energyplot.py +0 -155
- arviz/plots/backends/bokeh/essplot.py +0 -176
- arviz/plots/backends/bokeh/forestplot.py +0 -772
- arviz/plots/backends/bokeh/hdiplot.py +0 -54
- arviz/plots/backends/bokeh/kdeplot.py +0 -268
- arviz/plots/backends/bokeh/khatplot.py +0 -163
- arviz/plots/backends/bokeh/lmplot.py +0 -185
- arviz/plots/backends/bokeh/loopitplot.py +0 -211
- arviz/plots/backends/bokeh/mcseplot.py +0 -184
- arviz/plots/backends/bokeh/pairplot.py +0 -328
- arviz/plots/backends/bokeh/parallelplot.py +0 -81
- arviz/plots/backends/bokeh/posteriorplot.py +0 -324
- arviz/plots/backends/bokeh/ppcplot.py +0 -379
- arviz/plots/backends/bokeh/rankplot.py +0 -149
- arviz/plots/backends/bokeh/separationplot.py +0 -107
- arviz/plots/backends/bokeh/traceplot.py +0 -436
- arviz/plots/backends/bokeh/violinplot.py +0 -164
- arviz/plots/backends/matplotlib/__init__.py +0 -124
- arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
- arviz/plots/backends/matplotlib/bfplot.py +0 -78
- arviz/plots/backends/matplotlib/bpvplot.py +0 -177
- arviz/plots/backends/matplotlib/compareplot.py +0 -135
- arviz/plots/backends/matplotlib/densityplot.py +0 -194
- arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
- arviz/plots/backends/matplotlib/distplot.py +0 -178
- arviz/plots/backends/matplotlib/dotplot.py +0 -116
- arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
- arviz/plots/backends/matplotlib/elpdplot.py +0 -189
- arviz/plots/backends/matplotlib/energyplot.py +0 -113
- arviz/plots/backends/matplotlib/essplot.py +0 -180
- arviz/plots/backends/matplotlib/forestplot.py +0 -656
- arviz/plots/backends/matplotlib/hdiplot.py +0 -48
- arviz/plots/backends/matplotlib/kdeplot.py +0 -177
- arviz/plots/backends/matplotlib/khatplot.py +0 -241
- arviz/plots/backends/matplotlib/lmplot.py +0 -149
- arviz/plots/backends/matplotlib/loopitplot.py +0 -144
- arviz/plots/backends/matplotlib/mcseplot.py +0 -161
- arviz/plots/backends/matplotlib/pairplot.py +0 -355
- arviz/plots/backends/matplotlib/parallelplot.py +0 -58
- arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
- arviz/plots/backends/matplotlib/ppcplot.py +0 -478
- arviz/plots/backends/matplotlib/rankplot.py +0 -119
- arviz/plots/backends/matplotlib/separationplot.py +0 -97
- arviz/plots/backends/matplotlib/traceplot.py +0 -526
- arviz/plots/backends/matplotlib/tsplot.py +0 -121
- arviz/plots/backends/matplotlib/violinplot.py +0 -148
- arviz/plots/bfplot.py +0 -128
- arviz/plots/bpvplot.py +0 -308
- arviz/plots/compareplot.py +0 -177
- arviz/plots/densityplot.py +0 -284
- arviz/plots/distcomparisonplot.py +0 -197
- arviz/plots/distplot.py +0 -233
- arviz/plots/dotplot.py +0 -233
- arviz/plots/ecdfplot.py +0 -372
- arviz/plots/elpdplot.py +0 -174
- arviz/plots/energyplot.py +0 -147
- arviz/plots/essplot.py +0 -319
- arviz/plots/forestplot.py +0 -304
- arviz/plots/hdiplot.py +0 -211
- arviz/plots/kdeplot.py +0 -357
- arviz/plots/khatplot.py +0 -236
- arviz/plots/lmplot.py +0 -380
- arviz/plots/loopitplot.py +0 -224
- arviz/plots/mcseplot.py +0 -194
- arviz/plots/pairplot.py +0 -281
- arviz/plots/parallelplot.py +0 -204
- arviz/plots/plot_utils.py +0 -599
- arviz/plots/posteriorplot.py +0 -298
- arviz/plots/ppcplot.py +0 -369
- arviz/plots/rankplot.py +0 -232
- arviz/plots/separationplot.py +0 -167
- arviz/plots/styles/arviz-bluish.mplstyle +0 -1
- arviz/plots/styles/arviz-brownish.mplstyle +0 -1
- arviz/plots/styles/arviz-colors.mplstyle +0 -2
- arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
- arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
- arviz/plots/styles/arviz-doc.mplstyle +0 -88
- arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
- arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
- arviz/plots/styles/arviz-greenish.mplstyle +0 -1
- arviz/plots/styles/arviz-orangish.mplstyle +0 -1
- arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
- arviz/plots/styles/arviz-purplish.mplstyle +0 -1
- arviz/plots/styles/arviz-redish.mplstyle +0 -1
- arviz/plots/styles/arviz-royish.mplstyle +0 -1
- arviz/plots/styles/arviz-viridish.mplstyle +0 -1
- arviz/plots/styles/arviz-white.mplstyle +0 -40
- arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
- arviz/plots/traceplot.py +0 -273
- arviz/plots/tsplot.py +0 -440
- arviz/plots/violinplot.py +0 -192
- arviz/preview.py +0 -58
- arviz/py.typed +0 -0
- arviz/rcparams.py +0 -606
- arviz/sel_utils.py +0 -223
- arviz/static/css/style.css +0 -340
- arviz/static/html/icons-svg-inline.html +0 -15
- arviz/stats/__init__.py +0 -37
- arviz/stats/density_utils.py +0 -1013
- arviz/stats/diagnostics.py +0 -1013
- arviz/stats/ecdf_utils.py +0 -324
- arviz/stats/stats.py +0 -2422
- arviz/stats/stats_refitting.py +0 -119
- arviz/stats/stats_utils.py +0 -609
- arviz/tests/__init__.py +0 -1
- arviz/tests/base_tests/__init__.py +0 -1
- arviz/tests/base_tests/test_data.py +0 -1679
- arviz/tests/base_tests/test_data_zarr.py +0 -143
- arviz/tests/base_tests/test_diagnostics.py +0 -511
- arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
- arviz/tests/base_tests/test_helpers.py +0 -18
- arviz/tests/base_tests/test_labels.py +0 -69
- arviz/tests/base_tests/test_plot_utils.py +0 -342
- arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
- arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
- arviz/tests/base_tests/test_rcparams.py +0 -317
- arviz/tests/base_tests/test_stats.py +0 -925
- arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
- arviz/tests/base_tests/test_stats_numba.py +0 -45
- arviz/tests/base_tests/test_stats_utils.py +0 -384
- arviz/tests/base_tests/test_utils.py +0 -376
- arviz/tests/base_tests/test_utils_numba.py +0 -87
- arviz/tests/conftest.py +0 -46
- arviz/tests/external_tests/__init__.py +0 -1
- arviz/tests/external_tests/test_data_beanmachine.py +0 -78
- arviz/tests/external_tests/test_data_cmdstan.py +0 -398
- arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
- arviz/tests/external_tests/test_data_emcee.py +0 -166
- arviz/tests/external_tests/test_data_numpyro.py +0 -434
- arviz/tests/external_tests/test_data_pyjags.py +0 -119
- arviz/tests/external_tests/test_data_pyro.py +0 -260
- arviz/tests/external_tests/test_data_pystan.py +0 -307
- arviz/tests/helpers.py +0 -677
- arviz/utils.py +0 -773
- arviz/wrappers/__init__.py +0 -13
- arviz/wrappers/base.py +0 -236
- arviz/wrappers/wrap_pymc.py +0 -36
- arviz/wrappers/wrap_stan.py +0 -148
- arviz-0.23.3.dist-info/METADATA +0 -264
- arviz-0.23.3.dist-info/RECORD +0 -183
- arviz-0.23.3.dist-info/top_level.txt +0 -1
arviz/data/converters.py
DELETED
|
@@ -1,203 +0,0 @@
|
|
|
1
|
-
"""High level conversion functions."""
|
|
2
|
-
|
|
3
|
-
import numpy as np
|
|
4
|
-
import xarray as xr
|
|
5
|
-
import pandas as pd
|
|
6
|
-
|
|
7
|
-
try:
|
|
8
|
-
from tree import is_nested
|
|
9
|
-
except ImportError:
|
|
10
|
-
is_nested = lambda obj: False
|
|
11
|
-
|
|
12
|
-
from .base import dict_to_dataset
|
|
13
|
-
from .inference_data import InferenceData
|
|
14
|
-
from .io_beanmachine import from_beanmachine
|
|
15
|
-
from .io_cmdstan import from_cmdstan
|
|
16
|
-
from .io_cmdstanpy import from_cmdstanpy
|
|
17
|
-
from .io_emcee import from_emcee
|
|
18
|
-
from .io_numpyro import from_numpyro
|
|
19
|
-
from .io_pyro import from_pyro
|
|
20
|
-
from .io_pystan import from_pystan
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
# pylint: disable=too-many-return-statements
def convert_to_inference_data(obj, *, group="posterior", coords=None, dims=None, **kwargs):
    r"""Convert a supported object to an InferenceData object.

    This function sends `obj` to the right conversion function. It is idempotent,
    in that it will return arviz.InferenceData objects unchanged.

    Parameters
    ----------
    obj : dict, str, np.ndarray, xr.Dataset, pystan fit
        A supported object to convert to InferenceData:

        | InferenceData: returns unchanged
        | str: Attempts to load the cmdstan csv or netcdf dataset from disk
        | pystan fit: Automatically extracts data
        | cmdstanpy fit: Automatically extracts data
        | cmdstan csv-list: Automatically extracts data
        | emcee sampler: Automatically extracts data
        | pyro MCMC: Automatically extracts data
        | beanmachine MonteCarloSamples: Automatically extracts data
        | xarray.Dataset: adds to InferenceData as only group
        | xarray.DataArray: creates an xarray dataset as the only group, gives the
          array an arbitrary name, if name not set
        | dict: creates an xarray dataset as the only group
        | numpy array: creates an xarray dataset as the only group, gives the
          array an arbitrary name
        | object with __array__: converts to numpy array, then creates an xarray dataset as
          the only group, gives the array an arbitrary name
    group : str
        If `obj` is a dict or numpy array, assigns the resulting xarray
        dataset to this group. Default: "posterior".
    coords : dict[str, iterable]
        A dictionary containing the values that are used as index. The key
        is the name of the dimension, the values are the index values.
    dims : dict[str, List(str)]
        A mapping from variables to a list of coordinate names for the variable
    kwargs
        Rest of the supported keyword arguments transferred to conversion function.

    Returns
    -------
    InferenceData
    """
    # Stash obj under the requested group name so the library-specific
    # converters below (which take group-named keyword arguments such as
    # ``posterior=`` or ``prior=``) receive it transparently via **kwargs.
    kwargs[group] = obj
    kwargs["coords"] = coords
    kwargs["dims"] = dims

    # Cases that convert to InferenceData
    if isinstance(obj, InferenceData):
        # Idempotent path: coords/dims cannot be re-applied to an existing object.
        if coords is not None or dims is not None:
            raise TypeError("Cannot use coords or dims arguments with InferenceData value.")
        return obj
    elif isinstance(obj, str):
        if obj.endswith(".csv"):
            # CmdStan stores sample_stats alongside the draws in the same CSV,
            # so remap the group name to one from_cmdstan understands.
            if group == "sample_stats":
                kwargs["posterior"] = kwargs.pop(group)
            elif group == "sample_stats_prior":
                kwargs["prior"] = kwargs.pop(group)
            return from_cmdstan(**kwargs)
        else:
            # Any other string is treated as a netcdf file path.
            if coords is not None or dims is not None:
                raise TypeError(
                    "Cannot use coords or dims arguments reading InferenceData from netcdf."
                )
            return InferenceData.from_netcdf(obj)
    elif (
        obj.__class__.__name__ in {"StanFit4Model", "CmdStanMCMC"}
        or obj.__class__.__module__ == "stan.fit"
    ):
        # Name/module checks (rather than isinstance) keep pystan/cmdstanpy
        # optional dependencies.
        if group == "sample_stats":
            kwargs["posterior"] = kwargs.pop(group)
        elif group == "sample_stats_prior":
            kwargs["prior"] = kwargs.pop(group)
        if obj.__class__.__name__ == "CmdStanMCMC":
            return from_cmdstanpy(**kwargs)
        else:  # pystan or pystan3
            return from_pystan(**kwargs)
    elif obj.__class__.__name__ == "EnsembleSampler":  # ugly, but doesn't make emcee a requirement
        return from_emcee(sampler=kwargs.pop(group), **kwargs)
    elif obj.__class__.__name__ == "MonteCarloSamples":
        return from_beanmachine(sampler=kwargs.pop(group), **kwargs)
    # "numpyro" does not start with "pyro", so these two branches are disjoint.
    elif obj.__class__.__name__ == "MCMC" and obj.__class__.__module__.startswith("pyro"):
        return from_pyro(posterior=kwargs.pop(group), **kwargs)
    elif obj.__class__.__name__ == "MCMC" and obj.__class__.__module__.startswith("numpyro"):
        return from_numpyro(posterior=kwargs.pop(group), **kwargs)

    # Cases that convert to xarray
    if isinstance(obj, xr.Dataset):
        dataset = obj
    elif isinstance(obj, xr.DataArray):
        if obj.name is None:
            # A DataArray needs a name to become a Dataset variable.
            obj.name = "x"
        dataset = obj.to_dataset()
    elif isinstance(obj, dict):
        dataset = dict_to_dataset(obj, coords=coords, dims=dims)
    elif is_nested(obj) and not isinstance(obj, (list, tuple)):
        # Pytree support (only active when dm-tree is installed; see the
        # is_nested fallback defined at import time).
        dataset = dict_to_dataset(obj, coords=coords, dims=dims)
    elif isinstance(obj, np.ndarray):
        dataset = dict_to_dataset({"x": obj}, coords=coords, dims=dims)
    elif (
        hasattr(obj, "__array__")
        and callable(getattr(obj, "__array__"))
        and (not isinstance(obj, pd.DataFrame))
    ):
        # Generic array-likes; DataFrames are excluded so they fail with the
        # explicit error below instead of silently losing their labels.
        obj = obj.__array__()
        dataset = dict_to_dataset({"x": obj}, coords=coords, dims=dims)
    elif isinstance(obj, (list, tuple)) and isinstance(obj[0], str) and obj[0].endswith(".csv"):
        # A list/tuple of CmdStan CSV paths (one per chain).
        if group == "sample_stats":
            kwargs["posterior"] = kwargs.pop(group)
        elif group == "sample_stats_prior":
            kwargs["prior"] = kwargs.pop(group)
        return from_cmdstan(**kwargs)
    else:
        allowable_types = (
            "xarray dataarray",
            "xarray dataset",
            "dict",
            "pytree (if 'dm-tree' is installed)",
            "netcdf filename",
            "numpy array",
            "object with __array__",
            "pystan fit",
            "emcee fit",
            "pyro mcmc fit",
            "numpyro mcmc fit",
            "cmdstan fit csv filename",
            "cmdstanpy fit",
        )
        raise ValueError(
            f'Can only convert {", ".join(allowable_types)} to InferenceData, '
            f"not {obj.__class__.__name__}"
        )

    return InferenceData(**{group: dataset})
|
|
157
|
-
|
|
158
|
-
def convert_to_dataset(obj, *, group="posterior", coords=None, dims=None):
    """Convert a supported object to an xarray dataset.

    This function is idempotent, in that it will return xarray.Dataset functions
    unchanged. Raises `ValueError` if the desired group can not be extracted.

    Note this goes through a DataInference object. See `convert_to_inference_data`
    for more details. Raises ValueError if it can not work out the desired
    conversion.

    Parameters
    ----------
    obj : dict, str, np.ndarray, xr.Dataset, pystan fit
        A supported object to convert to InferenceData:

        - InferenceData: returns unchanged
        - str: Attempts to load the netcdf dataset from disk
        - pystan fit: Automatically extracts data
        - xarray.Dataset: adds to InferenceData as only group
        - xarray.DataArray: creates an xarray dataset as the only group, gives the
          array an arbitrary name, if name not set
        - dict: creates an xarray dataset as the only group
        - numpy array: creates an xarray dataset as the only group, gives the
          array an arbitrary name

    group : str
        If `obj` is a dict or numpy array, assigns the resulting xarray
        dataset to this group.
    coords : dict[str, iterable]
        A dictionary containing the values that are used as index. The key
        is the name of the dimension, the values are the index values.
    dims : dict[str, List(str)]
        A mapping from variables to a list of coordinate names for the variable

    Returns
    -------
    xarray.Dataset

    Raises
    ------
    ValueError
        If the requested ``group`` is not present on the converted object.
    """
    # Delegate all format detection to the InferenceData dispatcher, then
    # pull out just the requested group.
    inference_data = convert_to_inference_data(obj, group=group, coords=coords, dims=dims)
    dataset = getattr(inference_data, group, None)
    if dataset is None:
        # Fix: the message previously contained a literal "(unknown)" where the
        # module path belongs, and passed filename=__file__ to .format() without
        # ever using it. Restore the file reference and drop the dead argument.
        raise ValueError(
            f"Can not extract {group} from {obj}! See {__file__} for other "
            "conversion utilities."
        )
    return dataset
arviz/data/datasets.py
DELETED
|
@@ -1,161 +0,0 @@
|
|
|
1
|
-
"""Base IO code for all datasets. Heavily influenced by scikit-learn's implementation."""
|
|
2
|
-
|
|
3
|
-
import hashlib
|
|
4
|
-
import itertools
|
|
5
|
-
import json
|
|
6
|
-
import os
|
|
7
|
-
import shutil
|
|
8
|
-
from collections import namedtuple
|
|
9
|
-
from urllib.request import urlretrieve
|
|
10
|
-
|
|
11
|
-
from ..rcparams import rcParams
|
|
12
|
-
from .io_netcdf import from_netcdf
|
|
13
|
-
|
|
14
|
-
# Metadata record for a dataset bundled inside the installed package.
LocalFileMetadata = namedtuple("LocalFileMetadata", ["name", "filename", "description"])

# Metadata record for a dataset fetched over the network; ``checksum`` holds the
# expected SHA256 of the downloaded file (verified in load_arviz_data).
RemoteFileMetadata = namedtuple(
    "RemoteFileMetadata", ["name", "filename", "url", "checksum", "description"]
)
# Directory shipped with the package holding the example data files and the two
# JSON index files read below.
_EXAMPLE_DATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "example_data")
_LOCAL_DATA_DIR = os.path.join(_EXAMPLE_DATA_DIR, "data")

# Build the registry of locally shipped datasets, keyed by dataset name.
# Filenames in the JSON index are relative to _LOCAL_DATA_DIR.
with open(os.path.join(_EXAMPLE_DATA_DIR, "data_local.json"), "r", encoding="utf-8") as f:
    LOCAL_DATASETS = {
        entry["name"]: LocalFileMetadata(
            name=entry["name"],
            filename=os.path.join(_LOCAL_DATA_DIR, entry["filename"]),
            description=entry["description"],
        )
        for entry in json.load(f)
    }

# Registry of downloadable datasets; JSON entries map 1:1 onto the namedtuple fields.
with open(os.path.join(_EXAMPLE_DATA_DIR, "data_remote.json"), "r", encoding="utf-8") as f:
    REMOTE_DATASETS = {entry["name"]: RemoteFileMetadata(**entry) for entry in json.load(f)}
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
def get_data_home(data_home=None):
    """Return the path of the arviz data dir.

    This folder is used by some dataset loaders to avoid downloading the
    data several times.

    By default the data dir is set to a folder named 'arviz_data' in the
    user home folder.

    Alternatively, it can be set by the 'ARVIZ_DATA' environment
    variable or programmatically by giving an explicit folder path. The '~'
    symbol is expanded to the user home folder.

    If the folder does not already exist, it is automatically created.

    Parameters
    ----------
    data_home : str | None
        The path to arviz data dir.

    Returns
    -------
    str
        User-expanded path to the data directory (created if missing).
    """
    if data_home is None:
        data_home = os.environ.get("ARVIZ_DATA", os.path.join("~", "arviz_data"))
    data_home = os.path.expanduser(data_home)
    # exist_ok=True avoids the check-then-create race of the previous
    # os.path.exists() + os.makedirs() pair (two processes could both pass
    # the exists() test and one would then crash inside makedirs()).
    os.makedirs(data_home, exist_ok=True)
    return data_home
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
def clear_data_home(data_home=None):
    """Delete all the content of the data home cache.

    Parameters
    ----------
    data_home : str | None
        The path to arviz data dir.
    """
    # Resolve the cache directory (creating it if absent, per get_data_home's
    # contract), then remove it and everything beneath it.
    shutil.rmtree(get_data_home(data_home))
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
def _sha256(path):
|
|
77
|
-
"""Calculate the sha256 hash of the file at path."""
|
|
78
|
-
sha256hash = hashlib.sha256()
|
|
79
|
-
chunk_size = 8192
|
|
80
|
-
with open(path, "rb") as buff:
|
|
81
|
-
while True:
|
|
82
|
-
buffer = buff.read(chunk_size)
|
|
83
|
-
if not buffer:
|
|
84
|
-
break
|
|
85
|
-
sha256hash.update(buffer)
|
|
86
|
-
return sha256hash.hexdigest()
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
def load_arviz_data(dataset=None, data_home=None, **kwargs):
    """Load a local or remote pre-made dataset.

    Run with no parameters to get a list of all available models.

    The directory to save to can also be set with the environment
    variable `ARVIZ_DATA`. The checksum of the dataset is checked against a
    hardcoded value to watch for data corruption.

    Run `az.clear_data_home` to clear the data directory.

    Parameters
    ----------
    dataset : str
        Name of dataset to load.
    data_home : str, optional
        Where to save remote datasets
    **kwargs : dict, optional
        Keyword arguments passed to :func:`arviz.from_netcdf`.

    Returns
    -------
    xarray.Dataset

    Raises
    ------
    IOError
        If the downloaded file's SHA256 checksum does not match the expected one.
    ValueError
        If `dataset` is not a known local or remote dataset name.
    """
    if dataset in LOCAL_DATASETS:
        # Shipped with the package: no download or checksum verification needed.
        resource = LOCAL_DATASETS[dataset]
        return from_netcdf(resource.filename, **kwargs)

    elif dataset in REMOTE_DATASETS:
        remote = REMOTE_DATASETS[dataset]
        home_dir = get_data_home(data_home=data_home)
        file_path = os.path.join(home_dir, remote.filename)

        if not os.path.exists(file_path):
            http_type = rcParams["data.http_protocol"]

            # Replaces http type. Redundant if http_type is http, useful if http_type is https
            url = remote.url.replace("http", http_type)
            urlretrieve(url, file_path)

        # Verify integrity on every load, not only after a fresh download.
        checksum = _sha256(file_path)
        if remote.checksum != checksum:
            # Fix: the second fragment was previously a plain (non-f) string, so
            # the literal text "{remote.checksum}" was shown to users instead of
            # the expected checksum value.
            raise IOError(
                f"{file_path} has an SHA256 checksum ({checksum}) differing from expected "
                f"({remote.checksum}), file may be corrupted. "
                "Run `arviz.clear_data_home()` and try again, or please open an issue."
            )
        return from_netcdf(file_path, **kwargs)
    else:
        if dataset is None:
            # No name given: return the full name -> metadata mapping instead.
            return dict(itertools.chain(LOCAL_DATASETS.items(), REMOTE_DATASETS.items()))
        else:
            raise ValueError(
                f"Dataset {dataset} not found! The following are available:\n{list_datasets()}"
            )
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
def list_datasets():
    """Get a string representation of all available datasets with descriptions."""

    def describe(name, resource):
        # Resolve where the dataset lives based on its metadata type.
        if isinstance(resource, LocalFileMetadata):
            location = f"local: {resource.filename}"
        elif isinstance(resource, RemoteFileMetadata):
            location = f"remote: {resource.url}"
        else:
            location = "unknown"
        # Name underlined with '=' signs, then description, then location.
        return f"{name}\n{'=' * len(name)}\n{resource.description}\n\n{location}"

    all_entries = itertools.chain(LOCAL_DATASETS.items(), REMOTE_DATASETS.items())
    separator = f"\n\n{10 * '-'}\n\n"
    return separator.join(describe(name, resource) for name, resource in all_entries)
|