arviz 0.23.1__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185)
  1. arviz/__init__.py +52 -357
  2. arviz-1.0.0rc0.dist-info/METADATA +182 -0
  3. arviz-1.0.0rc0.dist-info/RECORD +5 -0
  4. {arviz-0.23.1.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
  5. {arviz-0.23.1.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
  6. arviz/data/__init__.py +0 -55
  7. arviz/data/base.py +0 -596
  8. arviz/data/converters.py +0 -203
  9. arviz/data/datasets.py +0 -161
  10. arviz/data/example_data/code/radon/radon.json +0 -326
  11. arviz/data/example_data/data/centered_eight.nc +0 -0
  12. arviz/data/example_data/data/non_centered_eight.nc +0 -0
  13. arviz/data/example_data/data_local.json +0 -12
  14. arviz/data/example_data/data_remote.json +0 -58
  15. arviz/data/inference_data.py +0 -2386
  16. arviz/data/io_beanmachine.py +0 -112
  17. arviz/data/io_cmdstan.py +0 -1036
  18. arviz/data/io_cmdstanpy.py +0 -1233
  19. arviz/data/io_datatree.py +0 -23
  20. arviz/data/io_dict.py +0 -462
  21. arviz/data/io_emcee.py +0 -317
  22. arviz/data/io_json.py +0 -54
  23. arviz/data/io_netcdf.py +0 -68
  24. arviz/data/io_numpyro.py +0 -497
  25. arviz/data/io_pyjags.py +0 -378
  26. arviz/data/io_pyro.py +0 -333
  27. arviz/data/io_pystan.py +0 -1095
  28. arviz/data/io_zarr.py +0 -46
  29. arviz/data/utils.py +0 -139
  30. arviz/labels.py +0 -210
  31. arviz/plots/__init__.py +0 -61
  32. arviz/plots/autocorrplot.py +0 -171
  33. arviz/plots/backends/__init__.py +0 -223
  34. arviz/plots/backends/bokeh/__init__.py +0 -166
  35. arviz/plots/backends/bokeh/autocorrplot.py +0 -101
  36. arviz/plots/backends/bokeh/bfplot.py +0 -23
  37. arviz/plots/backends/bokeh/bpvplot.py +0 -193
  38. arviz/plots/backends/bokeh/compareplot.py +0 -167
  39. arviz/plots/backends/bokeh/densityplot.py +0 -239
  40. arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
  41. arviz/plots/backends/bokeh/distplot.py +0 -183
  42. arviz/plots/backends/bokeh/dotplot.py +0 -113
  43. arviz/plots/backends/bokeh/ecdfplot.py +0 -73
  44. arviz/plots/backends/bokeh/elpdplot.py +0 -203
  45. arviz/plots/backends/bokeh/energyplot.py +0 -155
  46. arviz/plots/backends/bokeh/essplot.py +0 -176
  47. arviz/plots/backends/bokeh/forestplot.py +0 -772
  48. arviz/plots/backends/bokeh/hdiplot.py +0 -54
  49. arviz/plots/backends/bokeh/kdeplot.py +0 -268
  50. arviz/plots/backends/bokeh/khatplot.py +0 -163
  51. arviz/plots/backends/bokeh/lmplot.py +0 -185
  52. arviz/plots/backends/bokeh/loopitplot.py +0 -211
  53. arviz/plots/backends/bokeh/mcseplot.py +0 -184
  54. arviz/plots/backends/bokeh/pairplot.py +0 -328
  55. arviz/plots/backends/bokeh/parallelplot.py +0 -81
  56. arviz/plots/backends/bokeh/posteriorplot.py +0 -324
  57. arviz/plots/backends/bokeh/ppcplot.py +0 -379
  58. arviz/plots/backends/bokeh/rankplot.py +0 -149
  59. arviz/plots/backends/bokeh/separationplot.py +0 -107
  60. arviz/plots/backends/bokeh/traceplot.py +0 -436
  61. arviz/plots/backends/bokeh/violinplot.py +0 -164
  62. arviz/plots/backends/matplotlib/__init__.py +0 -124
  63. arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
  64. arviz/plots/backends/matplotlib/bfplot.py +0 -78
  65. arviz/plots/backends/matplotlib/bpvplot.py +0 -177
  66. arviz/plots/backends/matplotlib/compareplot.py +0 -135
  67. arviz/plots/backends/matplotlib/densityplot.py +0 -194
  68. arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
  69. arviz/plots/backends/matplotlib/distplot.py +0 -178
  70. arviz/plots/backends/matplotlib/dotplot.py +0 -116
  71. arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
  72. arviz/plots/backends/matplotlib/elpdplot.py +0 -189
  73. arviz/plots/backends/matplotlib/energyplot.py +0 -113
  74. arviz/plots/backends/matplotlib/essplot.py +0 -180
  75. arviz/plots/backends/matplotlib/forestplot.py +0 -656
  76. arviz/plots/backends/matplotlib/hdiplot.py +0 -48
  77. arviz/plots/backends/matplotlib/kdeplot.py +0 -177
  78. arviz/plots/backends/matplotlib/khatplot.py +0 -241
  79. arviz/plots/backends/matplotlib/lmplot.py +0 -149
  80. arviz/plots/backends/matplotlib/loopitplot.py +0 -144
  81. arviz/plots/backends/matplotlib/mcseplot.py +0 -161
  82. arviz/plots/backends/matplotlib/pairplot.py +0 -355
  83. arviz/plots/backends/matplotlib/parallelplot.py +0 -58
  84. arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
  85. arviz/plots/backends/matplotlib/ppcplot.py +0 -478
  86. arviz/plots/backends/matplotlib/rankplot.py +0 -119
  87. arviz/plots/backends/matplotlib/separationplot.py +0 -97
  88. arviz/plots/backends/matplotlib/traceplot.py +0 -526
  89. arviz/plots/backends/matplotlib/tsplot.py +0 -121
  90. arviz/plots/backends/matplotlib/violinplot.py +0 -148
  91. arviz/plots/bfplot.py +0 -128
  92. arviz/plots/bpvplot.py +0 -308
  93. arviz/plots/compareplot.py +0 -177
  94. arviz/plots/densityplot.py +0 -284
  95. arviz/plots/distcomparisonplot.py +0 -197
  96. arviz/plots/distplot.py +0 -233
  97. arviz/plots/dotplot.py +0 -233
  98. arviz/plots/ecdfplot.py +0 -372
  99. arviz/plots/elpdplot.py +0 -174
  100. arviz/plots/energyplot.py +0 -147
  101. arviz/plots/essplot.py +0 -319
  102. arviz/plots/forestplot.py +0 -304
  103. arviz/plots/hdiplot.py +0 -211
  104. arviz/plots/kdeplot.py +0 -357
  105. arviz/plots/khatplot.py +0 -236
  106. arviz/plots/lmplot.py +0 -380
  107. arviz/plots/loopitplot.py +0 -224
  108. arviz/plots/mcseplot.py +0 -194
  109. arviz/plots/pairplot.py +0 -281
  110. arviz/plots/parallelplot.py +0 -204
  111. arviz/plots/plot_utils.py +0 -599
  112. arviz/plots/posteriorplot.py +0 -298
  113. arviz/plots/ppcplot.py +0 -369
  114. arviz/plots/rankplot.py +0 -232
  115. arviz/plots/separationplot.py +0 -167
  116. arviz/plots/styles/arviz-bluish.mplstyle +0 -1
  117. arviz/plots/styles/arviz-brownish.mplstyle +0 -1
  118. arviz/plots/styles/arviz-colors.mplstyle +0 -2
  119. arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
  120. arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
  121. arviz/plots/styles/arviz-doc.mplstyle +0 -88
  122. arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
  123. arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
  124. arviz/plots/styles/arviz-greenish.mplstyle +0 -1
  125. arviz/plots/styles/arviz-orangish.mplstyle +0 -1
  126. arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
  127. arviz/plots/styles/arviz-purplish.mplstyle +0 -1
  128. arviz/plots/styles/arviz-redish.mplstyle +0 -1
  129. arviz/plots/styles/arviz-royish.mplstyle +0 -1
  130. arviz/plots/styles/arviz-viridish.mplstyle +0 -1
  131. arviz/plots/styles/arviz-white.mplstyle +0 -40
  132. arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
  133. arviz/plots/traceplot.py +0 -273
  134. arviz/plots/tsplot.py +0 -440
  135. arviz/plots/violinplot.py +0 -192
  136. arviz/preview.py +0 -58
  137. arviz/py.typed +0 -0
  138. arviz/rcparams.py +0 -606
  139. arviz/sel_utils.py +0 -223
  140. arviz/static/css/style.css +0 -340
  141. arviz/static/html/icons-svg-inline.html +0 -15
  142. arviz/stats/__init__.py +0 -37
  143. arviz/stats/density_utils.py +0 -1013
  144. arviz/stats/diagnostics.py +0 -1013
  145. arviz/stats/ecdf_utils.py +0 -324
  146. arviz/stats/stats.py +0 -2422
  147. arviz/stats/stats_refitting.py +0 -119
  148. arviz/stats/stats_utils.py +0 -609
  149. arviz/tests/__init__.py +0 -1
  150. arviz/tests/base_tests/__init__.py +0 -1
  151. arviz/tests/base_tests/test_data.py +0 -1679
  152. arviz/tests/base_tests/test_data_zarr.py +0 -143
  153. arviz/tests/base_tests/test_diagnostics.py +0 -511
  154. arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
  155. arviz/tests/base_tests/test_helpers.py +0 -18
  156. arviz/tests/base_tests/test_labels.py +0 -69
  157. arviz/tests/base_tests/test_plot_utils.py +0 -342
  158. arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
  159. arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
  160. arviz/tests/base_tests/test_rcparams.py +0 -317
  161. arviz/tests/base_tests/test_stats.py +0 -925
  162. arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
  163. arviz/tests/base_tests/test_stats_numba.py +0 -45
  164. arviz/tests/base_tests/test_stats_utils.py +0 -384
  165. arviz/tests/base_tests/test_utils.py +0 -376
  166. arviz/tests/base_tests/test_utils_numba.py +0 -87
  167. arviz/tests/conftest.py +0 -46
  168. arviz/tests/external_tests/__init__.py +0 -1
  169. arviz/tests/external_tests/test_data_beanmachine.py +0 -78
  170. arviz/tests/external_tests/test_data_cmdstan.py +0 -398
  171. arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
  172. arviz/tests/external_tests/test_data_emcee.py +0 -166
  173. arviz/tests/external_tests/test_data_numpyro.py +0 -434
  174. arviz/tests/external_tests/test_data_pyjags.py +0 -119
  175. arviz/tests/external_tests/test_data_pyro.py +0 -260
  176. arviz/tests/external_tests/test_data_pystan.py +0 -307
  177. arviz/tests/helpers.py +0 -677
  178. arviz/utils.py +0 -773
  179. arviz/wrappers/__init__.py +0 -13
  180. arviz/wrappers/base.py +0 -236
  181. arviz/wrappers/wrap_pymc.py +0 -36
  182. arviz/wrappers/wrap_stan.py +0 -148
  183. arviz-0.23.1.dist-info/METADATA +0 -263
  184. arviz-0.23.1.dist-info/RECORD +0 -183
  185. arviz-0.23.1.dist-info/top_level.txt +0 -1
arviz/data/io_emcee.py DELETED
@@ -1,317 +0,0 @@
1
- """emcee-specific conversion code."""
2
-
3
- import warnings
4
- from collections import OrderedDict
5
-
6
- import numpy as np
7
- import xarray as xr
8
-
9
- from .. import utils
10
- from .base import dict_to_dataset, generate_dims_coords, make_attrs
11
- from .inference_data import InferenceData
12
-
13
-
14
def _verify_names(sampler, var_names, arg_names, slices):
    """Make sure var_names and arg_names are assigned reasonably.

    This is meant to run before loading emcee objects into InferenceData.
    In case var_names or arg_names is None, will provide defaults. If they are
    not None, it verifies there are the right number of them.

    Parameters
    ----------
    sampler : emcee.EnsembleSampler
        Fitted emcee sampler
    var_names : list[str] or None
        Names for the emcee parameters
    arg_names : list[str] or None
        Names for the args/observations provided to emcee
    slices : list[seq] or None
        slices to select the variables (used for multidimensional variables)

    Returns
    -------
    list[str], list[str], list[seq]
        Defaults for var_names, arg_names and slices

    Raises
    ------
    ValueError
        If the number of ``var_names`` or ``arg_names`` does not match the number
        of variables/args found in the sampler.

    Warns
    -----
    UserWarning
        If ``slices`` do not cover every parameter of the chain, or if they overlap.
    """
    # There are 3 possible cases: emcee2, emcee3 and sampler read from h5 file (emcee3 only)
    if hasattr(sampler, "args"):
        # emcee2: chain is an attribute and args live on the sampler itself
        ndim = sampler.chain.shape[-1]
        num_args = len(sampler.args)
    elif hasattr(sampler, "log_prob_fn"):
        # emcee3 sampler: args are attached to the wrapped log-probability function
        ndim = sampler.get_chain().shape[-1]
        num_args = len(sampler.log_prob_fn.args)
    else:
        ndim = sampler.get_chain().shape[-1]
        num_args = 0  # emcee only stores the posterior samples

    if slices is None:
        slices = utils.arange(ndim)
        num_vars = ndim
    else:
        num_vars = len(slices)
        indices = utils.arange(ndim)
        slicing_try = np.concatenate([utils.one_de(indices[idx]) for idx in slices])
        # Only distinct indices count as "captured": duplicated indices from
        # overlapping slices must not inflate the reported number.
        num_captured = len(set(slicing_try))
        if num_captured != ndim:
            warnings.warn(
                "Check slices: Not all parameters in chain captured. "
                f"{ndim} are present, and {num_captured} have been captured.",
                UserWarning,
            )
        if len(slicing_try) != num_captured:
            warnings.warn(f"Overlapping slices. Check the index present: {slicing_try}", UserWarning)

    if var_names is None:
        var_names = [f"var_{idx}" for idx in range(num_vars)]
    if arg_names is None:
        arg_names = [f"arg_{idx}" for idx in range(num_args)]

    if len(var_names) != num_vars:
        raise ValueError(
            f"The sampler has {num_vars} variables, "
            f"but only {len(var_names)} var_names were provided!"
        )

    if len(arg_names) != num_args:
        raise ValueError(
            f"The sampler has {num_args} args, "
            f"but only {len(arg_names)} arg_names were provided!"
        )
    return var_names, arg_names, slices
83
-
84
-
85
# pylint: disable=too-many-instance-attributes
class EmceeConverter:
    """Encapsulate emcee specific logic.

    Names and slices are validated (and defaulted) by ``_verify_names`` on
    construction; see :func:`from_emcee` for the meaning of each argument.
    """

    def __init__(
        self,
        sampler,
        var_names=None,
        slices=None,
        arg_names=None,
        arg_groups=None,
        blob_names=None,
        blob_groups=None,
        index_origin=None,
        coords=None,
        dims=None,
    ):
        var_names, arg_names, slices = _verify_names(sampler, var_names, arg_names, slices)
        self.sampler = sampler
        self.var_names = var_names
        self.slices = slices
        self.arg_names = arg_names
        self.arg_groups = arg_groups
        self.blob_names = blob_names
        self.blob_groups = blob_groups
        self.index_origin = index_origin
        self.coords = coords
        self.dims = dims
        # Imported here so emcee is only needed when a converter is actually built.
        import emcee

        self.emcee = emcee

    def posterior_to_xarray(self):
        """Convert the posterior to an xarray dataset.

        Each variable is obtained by indexing the last (parameter) axis of the
        chain with the matching entry of ``self.slices``.
        """
        # Use emcee3 syntax, else use emcee2
        if hasattr(self.sampler, "get_chain"):
            # Swap the first two axes so the layout matches emcee2's ``chain``
            # (presumably draw/walker order; see emcee docs).
            samples_ary = self.sampler.get_chain().swapaxes(0, 1)
        else:
            samples_ary = self.sampler.chain

        data = {
            var_name: (samples_ary[(..., idx)])
            for idx, var_name in zip(self.slices, self.var_names)
        }
        return dict_to_dataset(
            data,
            library=self.emcee,
            coords=self.coords,
            dims=self.dims,
            index_origin=self.index_origin,
        )

    def args_to_xarray(self):
        """Convert emcee args to observed and constant_data xarray Datasets.

        Returns
        -------
        dict of {str: xarray.Dataset}
            One dataset per group used in ``arg_groups``.

        Raises
        ------
        ValueError
            If ``arg_names`` and ``arg_groups`` have different lengths.
        SyntaxError
            If a group other than ``observed_data``/``constant_data`` is requested.
        """
        dims = {} if self.dims is None else self.dims
        if self.arg_groups is None:
            self.arg_groups = ["observed_data" for _ in self.arg_names]
        if len(self.arg_names) != len(self.arg_groups):
            raise ValueError(
                "arg_names and arg_groups must have the same length, or arg_groups be None"
            )
        arg_groups_set = set(self.arg_groups)
        bad_groups = [
            group for group in arg_groups_set if group not in ("observed_data", "constant_data")
        ]
        if bad_groups:
            raise SyntaxError(
                "all arg_groups values should be either 'observed_data' or 'constant_data' , "
                f"not {bad_groups}"
            )
        obs_const_dict = {group: OrderedDict() for group in arg_groups_set}
        for idx, (arg_name, group) in enumerate(zip(self.arg_names, self.arg_groups)):
            # Use emcee3 syntax, else use emcee2
            arg_array = np.atleast_1d(
                self.sampler.log_prob_fn.args[idx]
                if hasattr(self.sampler, "log_prob_fn")
                else self.sampler.args[idx]
            )
            arg_dims = dims.get(arg_name)
            arg_dims, coords = generate_dims_coords(
                arg_array.shape,
                arg_name,
                dims=arg_dims,
                coords=self.coords,
                index_origin=self.index_origin,
            )
            # filter coords based on the dims
            coords = {key: xr.IndexVariable((key,), data=coords[key]) for key in arg_dims}
            obs_const_dict[group][arg_name] = xr.DataArray(arg_array, dims=arg_dims, coords=coords)
        for key, values in obs_const_dict.items():
            obs_const_dict[key] = xr.Dataset(data_vars=values, attrs=make_attrs(library=self.emcee))
        return obs_const_dict

    def blobs_to_dict(self):
        """Convert blobs to dictionary {groupname: xr.Dataset}.

        It also stores lp values in sample_stats group.

        Raises
        ------
        ValueError
            If ``blob_names`` and ``blob_groups`` lengths differ, if blob names are
            given but the sampler has no blobs, or if the number of blob names does
            not match the number of blobs.
        SyntaxError
            If a blob group would overwrite ``posterior``, ``observed_data`` or
            ``constant_data``.
        """
        store_blobs = self.blob_names is not None
        self.blob_names = [] if self.blob_names is None else self.blob_names
        if self.blob_groups is None:
            self.blob_groups = ["log_likelihood" for _ in self.blob_names]
        if len(self.blob_names) != len(self.blob_groups):
            raise ValueError(
                "blob_names and blob_groups must have the same length, or blob_groups be None"
            )
        # An explicitly empty ``blob_names`` requests no blobs: skip reading them so a
        # sampler without blobs does not crash on ``blobs.shape`` below.
        if store_blobs and self.blob_names:
            # Parse the whole major component; indexing the first character of the
            # version string would misread e.g. "10.0" as major version 1.
            emcee_major = int(str(self.emcee.__version__).split(".", maxsplit=1)[0])
            if emcee_major >= 3:
                blobs = self.sampler.get_blobs()
            else:
                blobs = np.array(self.sampler.blobs, dtype=object)
            if blobs is None or blobs.size == 0:
                raise ValueError("No blobs in sampler, blob_names must be None")
            if len(blobs.shape) == 2:
                blobs = np.expand_dims(blobs, axis=-1)
            blobs = blobs.swapaxes(0, 2)
            nblobs, nwalkers, ndraws, *_ = blobs.shape
            if len(self.blob_names) != nblobs and len(self.blob_names) > 1:
                raise ValueError(
                    "Incorrect number of blob names. "
                    f"Expected {nblobs}, found {len(self.blob_names)}"
                )
        blob_groups_set = set(self.blob_groups)
        blob_groups_set.add("sample_stats")
        idata_groups = ("posterior", "observed_data", "constant_data")
        if np.any(np.isin(list(blob_groups_set), idata_groups)):
            raise SyntaxError(
                f"{idata_groups} groups should not come from blobs. "
                "Using them here would overwrite their actual values"
            )
        blob_dict = {group: OrderedDict() for group in blob_groups_set}
        if len(self.blob_names) == 1:
            blob_dict[self.blob_groups[0]][self.blob_names[0]] = blobs.swapaxes(0, 2).swapaxes(0, 1)
        else:
            for i_blob, (name, group) in enumerate(zip(self.blob_names, self.blob_groups)):
                # for coherent blobs (all having the same dimensions) one line is enough
                blob = blobs[i_blob]
                # for blobs of different size, we get an array of arrays, which we convert
                # to an ndarray per blob_name
                if blob.dtype == object:
                    blob = blob.reshape(-1)
                    blob = np.stack(blob)
                    blob = blob.reshape((nwalkers, ndraws, -1))
                blob_dict[group][name] = np.squeeze(blob)

        # store lp in sample_stats group
        blob_dict["sample_stats"]["lp"] = (
            self.sampler.get_log_prob().swapaxes(0, 1)
            if hasattr(self.sampler, "get_log_prob")
            else self.sampler.lnprobability
        )
        for key, values in blob_dict.items():
            blob_dict[key] = dict_to_dataset(
                values,
                library=self.emcee,
                coords=self.coords,
                dims=self.dims,
                index_origin=self.index_origin,
            )
        return blob_dict

    def to_inference_data(self):
        """Convert all available data to an InferenceData object."""
        blobs_dict = self.blobs_to_dict()
        obs_const_dict = self.args_to_xarray()
        return InferenceData(
            **{"posterior": self.posterior_to_xarray(), **obs_const_dict, **blobs_dict}
        )
253
-
254
-
255
def from_emcee(
    sampler=None,
    var_names=None,
    slices=None,
    arg_names=None,
    arg_groups=None,
    blob_names=None,
    blob_groups=None,
    index_origin=None,
    coords=None,
    dims=None,
):
    """Convert emcee data into an InferenceData object.

    For a usage example read :ref:`emcee_conversion`

    Parameters
    ----------
    sampler : emcee.EnsembleSampler
        Fitted sampler from emcee.
    var_names : list of str, optional
        A list of names for variables in the sampler
    slices : list of array-like or slice, optional
        A list containing the indexes of each variable. Should only be used
        for multidimensional variables.
    arg_names : list of str, optional
        A list of names for args in the sampler
    arg_groups : list of str, optional
        A list of the group names (either ``observed_data`` or ``constant_data``) where
        args in the sampler are stored. If None, all args will be stored in observed
        data group.
    blob_names : list of str, optional
        A list of names for blobs in the sampler. When None,
        blobs are omitted, independently of them being present
        in the sampler or not.
    blob_groups : list of str, optional
        A list of the groups where blob_names variables
        should be assigned respectively. If blob_names!=None
        and blob_groups is None, all variables are assigned
        to log_likelihood group
    index_origin : int, optional
        Forwarded unchanged to the dataset builders; presumably the starting
        value of automatically generated integer coordinates.
    coords : dict of {str : array_like}, optional
        Map of dimensions to coordinates
    dims : dict of {str : list of str}, optional
        Map variable names to their coordinates

    Returns
    -------
    arviz.InferenceData
    """
    converter = EmceeConverter(
        sampler=sampler,
        var_names=var_names,
        slices=slices,
        arg_names=arg_names,
        arg_groups=arg_groups,
        blob_names=blob_names,
        blob_groups=blob_groups,
        index_origin=index_origin,
        coords=coords,
        dims=dims,
    )
    return converter.to_inference_data()
arviz/data/io_json.py DELETED
@@ -1,54 +0,0 @@
1
- """Input and output support for data."""
2
-
3
- from .io_dict import from_dict
4
-
5
- try:
6
- import ujson as json
7
- except ImportError:
8
- # Can't find ujson using json
9
- # mypy struggles with conditional imports expressed as catching ImportError:
10
- # https://github.com/python/mypy/issues/1153
11
- import json # type: ignore
12
-
13
-
14
def from_json(filename):
    """Read an InferenceData object back from a json file.

    The faster `ujson` (https://github.com/ultrajson/ultrajson) is used when it
    is installed; otherwise the standard library ``json`` module is used.

    Parameters
    ----------
    filename : str
        location of json file

    Returns
    -------
    InferenceData object
    """
    # Opened in binary mode; json.load accepts the resulting bytes.
    with open(filename, "rb") as json_file:
        contents = json.load(json_file)
    return from_dict(save_warmup=True, **contents)
32
-
33
-
34
def to_json(idata, filename):
    """Write ``idata`` to a json file by delegating to ``idata.to_json``.

    The faster `ujson` (https://github.com/ultrajson/ultrajson) is used when it
    is available.

    WARNING: Only idempotent in case `idata` is InferenceData.

    Parameters
    ----------
    idata : InferenceData
        Object to be saved
    filename : str
        name or path of the file to load trace

    Returns
    -------
    str
        filename saved to
    """
    return idata.to_json(filename)
arviz/data/io_netcdf.py DELETED
@@ -1,68 +0,0 @@
1
- """Input and output support for data."""
2
-
3
- from .converters import convert_to_inference_data
4
- from .inference_data import InferenceData
5
-
6
-
7
def from_netcdf(filename, *, engine="h5netcdf", group_kwargs=None, regex=False):
    """Load netcdf file back into an arviz.InferenceData.

    Parameters
    ----------
    filename : str
        name or path of the file to load trace
    engine : {"h5netcdf", "netcdf4"}, default "h5netcdf"
        Library used to read the netcdf file.
    group_kwargs : dict of {str: dict}, optional
        Keyword arguments to be passed into each call of :func:`xarray.open_dataset`.
        The keys of the higher level should be group names or regex matching group
        names, the inner dicts are passed to ``open_dataset``.
        This feature is currently experimental. ``None`` is treated as an empty dict.
    regex : bool, default False
        Specifies where regex search should be used to extend the keyword arguments.

    Returns
    -------
    InferenceData object

    Notes
    -----
    By default, the datasets of the InferenceData object will be lazily loaded instead
    of loaded into memory. This behaviour is regulated by the value of
    ``az.rcParams["data.load"]``.
    """
    return InferenceData.from_netcdf(
        filename,
        engine=engine,
        group_kwargs={} if group_kwargs is None else group_kwargs,
        regex=regex,
    )
39
-
40
-
41
def to_netcdf(data, filename, *, group="posterior", engine="h5netcdf", coords=None, dims=None):
    """Save dataset as a netcdf file.

    ``data`` is first passed through ``convert_to_inference_data``, and the result
    is written with its own ``to_netcdf`` method.

    WARNING: Only idempotent in case `data` is InferenceData

    Parameters
    ----------
    data : InferenceData, or any object accepted by `convert_to_inference_data`
        Object to be saved
    filename : str
        name or path of the file to write
    group : str (optional)
        In case `data` is not InferenceData, this is the group it will be saved to
    engine : {"h5netcdf", "netcdf4"}, default "h5netcdf"
        Library used to write the netcdf file.
    coords : dict (optional)
        See `convert_to_inference_data`
    dims : dict (optional)
        See `convert_to_inference_data`

    Returns
    -------
    str
        filename saved to
    """
    idata = convert_to_inference_data(data, group=group, coords=coords, dims=dims)
    return idata.to_netcdf(filename, engine=engine)