arviz 0.23.3__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185)
  1. arviz/__init__.py +52 -367
  2. arviz-1.0.0rc0.dist-info/METADATA +182 -0
  3. arviz-1.0.0rc0.dist-info/RECORD +5 -0
  4. {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
  5. {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
  6. arviz/data/__init__.py +0 -55
  7. arviz/data/base.py +0 -596
  8. arviz/data/converters.py +0 -203
  9. arviz/data/datasets.py +0 -161
  10. arviz/data/example_data/code/radon/radon.json +0 -326
  11. arviz/data/example_data/data/centered_eight.nc +0 -0
  12. arviz/data/example_data/data/non_centered_eight.nc +0 -0
  13. arviz/data/example_data/data_local.json +0 -12
  14. arviz/data/example_data/data_remote.json +0 -58
  15. arviz/data/inference_data.py +0 -2386
  16. arviz/data/io_beanmachine.py +0 -112
  17. arviz/data/io_cmdstan.py +0 -1036
  18. arviz/data/io_cmdstanpy.py +0 -1233
  19. arviz/data/io_datatree.py +0 -23
  20. arviz/data/io_dict.py +0 -462
  21. arviz/data/io_emcee.py +0 -317
  22. arviz/data/io_json.py +0 -54
  23. arviz/data/io_netcdf.py +0 -68
  24. arviz/data/io_numpyro.py +0 -497
  25. arviz/data/io_pyjags.py +0 -378
  26. arviz/data/io_pyro.py +0 -333
  27. arviz/data/io_pystan.py +0 -1095
  28. arviz/data/io_zarr.py +0 -46
  29. arviz/data/utils.py +0 -139
  30. arviz/labels.py +0 -210
  31. arviz/plots/__init__.py +0 -61
  32. arviz/plots/autocorrplot.py +0 -171
  33. arviz/plots/backends/__init__.py +0 -223
  34. arviz/plots/backends/bokeh/__init__.py +0 -166
  35. arviz/plots/backends/bokeh/autocorrplot.py +0 -101
  36. arviz/plots/backends/bokeh/bfplot.py +0 -23
  37. arviz/plots/backends/bokeh/bpvplot.py +0 -193
  38. arviz/plots/backends/bokeh/compareplot.py +0 -167
  39. arviz/plots/backends/bokeh/densityplot.py +0 -239
  40. arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
  41. arviz/plots/backends/bokeh/distplot.py +0 -183
  42. arviz/plots/backends/bokeh/dotplot.py +0 -113
  43. arviz/plots/backends/bokeh/ecdfplot.py +0 -73
  44. arviz/plots/backends/bokeh/elpdplot.py +0 -203
  45. arviz/plots/backends/bokeh/energyplot.py +0 -155
  46. arviz/plots/backends/bokeh/essplot.py +0 -176
  47. arviz/plots/backends/bokeh/forestplot.py +0 -772
  48. arviz/plots/backends/bokeh/hdiplot.py +0 -54
  49. arviz/plots/backends/bokeh/kdeplot.py +0 -268
  50. arviz/plots/backends/bokeh/khatplot.py +0 -163
  51. arviz/plots/backends/bokeh/lmplot.py +0 -185
  52. arviz/plots/backends/bokeh/loopitplot.py +0 -211
  53. arviz/plots/backends/bokeh/mcseplot.py +0 -184
  54. arviz/plots/backends/bokeh/pairplot.py +0 -328
  55. arviz/plots/backends/bokeh/parallelplot.py +0 -81
  56. arviz/plots/backends/bokeh/posteriorplot.py +0 -324
  57. arviz/plots/backends/bokeh/ppcplot.py +0 -379
  58. arviz/plots/backends/bokeh/rankplot.py +0 -149
  59. arviz/plots/backends/bokeh/separationplot.py +0 -107
  60. arviz/plots/backends/bokeh/traceplot.py +0 -436
  61. arviz/plots/backends/bokeh/violinplot.py +0 -164
  62. arviz/plots/backends/matplotlib/__init__.py +0 -124
  63. arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
  64. arviz/plots/backends/matplotlib/bfplot.py +0 -78
  65. arviz/plots/backends/matplotlib/bpvplot.py +0 -177
  66. arviz/plots/backends/matplotlib/compareplot.py +0 -135
  67. arviz/plots/backends/matplotlib/densityplot.py +0 -194
  68. arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
  69. arviz/plots/backends/matplotlib/distplot.py +0 -178
  70. arviz/plots/backends/matplotlib/dotplot.py +0 -116
  71. arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
  72. arviz/plots/backends/matplotlib/elpdplot.py +0 -189
  73. arviz/plots/backends/matplotlib/energyplot.py +0 -113
  74. arviz/plots/backends/matplotlib/essplot.py +0 -180
  75. arviz/plots/backends/matplotlib/forestplot.py +0 -656
  76. arviz/plots/backends/matplotlib/hdiplot.py +0 -48
  77. arviz/plots/backends/matplotlib/kdeplot.py +0 -177
  78. arviz/plots/backends/matplotlib/khatplot.py +0 -241
  79. arviz/plots/backends/matplotlib/lmplot.py +0 -149
  80. arviz/plots/backends/matplotlib/loopitplot.py +0 -144
  81. arviz/plots/backends/matplotlib/mcseplot.py +0 -161
  82. arviz/plots/backends/matplotlib/pairplot.py +0 -355
  83. arviz/plots/backends/matplotlib/parallelplot.py +0 -58
  84. arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
  85. arviz/plots/backends/matplotlib/ppcplot.py +0 -478
  86. arviz/plots/backends/matplotlib/rankplot.py +0 -119
  87. arviz/plots/backends/matplotlib/separationplot.py +0 -97
  88. arviz/plots/backends/matplotlib/traceplot.py +0 -526
  89. arviz/plots/backends/matplotlib/tsplot.py +0 -121
  90. arviz/plots/backends/matplotlib/violinplot.py +0 -148
  91. arviz/plots/bfplot.py +0 -128
  92. arviz/plots/bpvplot.py +0 -308
  93. arviz/plots/compareplot.py +0 -177
  94. arviz/plots/densityplot.py +0 -284
  95. arviz/plots/distcomparisonplot.py +0 -197
  96. arviz/plots/distplot.py +0 -233
  97. arviz/plots/dotplot.py +0 -233
  98. arviz/plots/ecdfplot.py +0 -372
  99. arviz/plots/elpdplot.py +0 -174
  100. arviz/plots/energyplot.py +0 -147
  101. arviz/plots/essplot.py +0 -319
  102. arviz/plots/forestplot.py +0 -304
  103. arviz/plots/hdiplot.py +0 -211
  104. arviz/plots/kdeplot.py +0 -357
  105. arviz/plots/khatplot.py +0 -236
  106. arviz/plots/lmplot.py +0 -380
  107. arviz/plots/loopitplot.py +0 -224
  108. arviz/plots/mcseplot.py +0 -194
  109. arviz/plots/pairplot.py +0 -281
  110. arviz/plots/parallelplot.py +0 -204
  111. arviz/plots/plot_utils.py +0 -599
  112. arviz/plots/posteriorplot.py +0 -298
  113. arviz/plots/ppcplot.py +0 -369
  114. arviz/plots/rankplot.py +0 -232
  115. arviz/plots/separationplot.py +0 -167
  116. arviz/plots/styles/arviz-bluish.mplstyle +0 -1
  117. arviz/plots/styles/arviz-brownish.mplstyle +0 -1
  118. arviz/plots/styles/arviz-colors.mplstyle +0 -2
  119. arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
  120. arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
  121. arviz/plots/styles/arviz-doc.mplstyle +0 -88
  122. arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
  123. arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
  124. arviz/plots/styles/arviz-greenish.mplstyle +0 -1
  125. arviz/plots/styles/arviz-orangish.mplstyle +0 -1
  126. arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
  127. arviz/plots/styles/arviz-purplish.mplstyle +0 -1
  128. arviz/plots/styles/arviz-redish.mplstyle +0 -1
  129. arviz/plots/styles/arviz-royish.mplstyle +0 -1
  130. arviz/plots/styles/arviz-viridish.mplstyle +0 -1
  131. arviz/plots/styles/arviz-white.mplstyle +0 -40
  132. arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
  133. arviz/plots/traceplot.py +0 -273
  134. arviz/plots/tsplot.py +0 -440
  135. arviz/plots/violinplot.py +0 -192
  136. arviz/preview.py +0 -58
  137. arviz/py.typed +0 -0
  138. arviz/rcparams.py +0 -606
  139. arviz/sel_utils.py +0 -223
  140. arviz/static/css/style.css +0 -340
  141. arviz/static/html/icons-svg-inline.html +0 -15
  142. arviz/stats/__init__.py +0 -37
  143. arviz/stats/density_utils.py +0 -1013
  144. arviz/stats/diagnostics.py +0 -1013
  145. arviz/stats/ecdf_utils.py +0 -324
  146. arviz/stats/stats.py +0 -2422
  147. arviz/stats/stats_refitting.py +0 -119
  148. arviz/stats/stats_utils.py +0 -609
  149. arviz/tests/__init__.py +0 -1
  150. arviz/tests/base_tests/__init__.py +0 -1
  151. arviz/tests/base_tests/test_data.py +0 -1679
  152. arviz/tests/base_tests/test_data_zarr.py +0 -143
  153. arviz/tests/base_tests/test_diagnostics.py +0 -511
  154. arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
  155. arviz/tests/base_tests/test_helpers.py +0 -18
  156. arviz/tests/base_tests/test_labels.py +0 -69
  157. arviz/tests/base_tests/test_plot_utils.py +0 -342
  158. arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
  159. arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
  160. arviz/tests/base_tests/test_rcparams.py +0 -317
  161. arviz/tests/base_tests/test_stats.py +0 -925
  162. arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
  163. arviz/tests/base_tests/test_stats_numba.py +0 -45
  164. arviz/tests/base_tests/test_stats_utils.py +0 -384
  165. arviz/tests/base_tests/test_utils.py +0 -376
  166. arviz/tests/base_tests/test_utils_numba.py +0 -87
  167. arviz/tests/conftest.py +0 -46
  168. arviz/tests/external_tests/__init__.py +0 -1
  169. arviz/tests/external_tests/test_data_beanmachine.py +0 -78
  170. arviz/tests/external_tests/test_data_cmdstan.py +0 -398
  171. arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
  172. arviz/tests/external_tests/test_data_emcee.py +0 -166
  173. arviz/tests/external_tests/test_data_numpyro.py +0 -434
  174. arviz/tests/external_tests/test_data_pyjags.py +0 -119
  175. arviz/tests/external_tests/test_data_pyro.py +0 -260
  176. arviz/tests/external_tests/test_data_pystan.py +0 -307
  177. arviz/tests/helpers.py +0 -677
  178. arviz/utils.py +0 -773
  179. arviz/wrappers/__init__.py +0 -13
  180. arviz/wrappers/base.py +0 -236
  181. arviz/wrappers/wrap_pymc.py +0 -36
  182. arviz/wrappers/wrap_stan.py +0 -148
  183. arviz-0.23.3.dist-info/METADATA +0 -264
  184. arviz-0.23.3.dist-info/RECORD +0 -183
  185. arviz-0.23.3.dist-info/top_level.txt +0 -1
@@ -1,609 +0,0 @@
1
- """Stats-utility functions for ArviZ."""
2
-
3
- import warnings
4
- from collections.abc import Sequence
5
- from copy import copy as _copy
6
- from copy import deepcopy as _deepcopy
7
-
8
- import numpy as np
9
- import pandas as pd
10
- from scipy.fftpack import next_fast_len
11
- from scipy.interpolate import CubicSpline
12
- from scipy.stats.mstats import mquantiles
13
- from xarray import apply_ufunc
14
-
15
- from .. import _log
16
- from ..utils import conditional_jit, conditional_vect, conditional_dask
17
- from .density_utils import histogram as _histogram
18
-
19
-
20
# Public names exported by ``from ... import *`` for this module.
__all__ = [
    "autocorr",
    "autocov",
    "ELPDData",
    "make_ufunc",
    "smooth_data",
    "wrap_xarray_ufunc",
]
21
-
22
-
23
def autocov(ary, axis=-1):
    """Compute autocovariance estimates for every lag for the input array.

    Parameters
    ----------
    ary : Numpy array
        An array containing MCMC samples
    axis : int, optional
        Axis along which the autocovariance is computed. Negative values count
        from the last axis, as in numpy. Defaults to -1.

    Returns
    -------
    acov: Numpy array same size as the input array
    """
    # ``>=`` (not ``>``) so that axis=0 is kept as-is instead of becoming ``ndim``.
    axis = axis if axis >= 0 else ary.ndim + axis
    n_lags = ary.shape[axis]
    # Zero-pad to at least twice the length so the circular FFT correlation
    # equals the linear one for all lags.
    fft_len = next_fast_len(2 * n_lags)

    ary = ary - ary.mean(axis, keepdims=True)

    # added to silence tuple warning for a submodule
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        spectrum = np.fft.rfft(ary, n=fft_len, axis=axis)
        # Power spectrum |X|^2; its inverse transform is the autocovariance.
        spectrum *= np.conjugate(spectrum)

        keep_lags = tuple(
            slice(0, n_lags) if dim == axis else slice(None) for dim in range(ary.ndim)
        )
        cov = np.fft.irfft(spectrum, n=fft_len, axis=axis)[keep_lags]
        cov /= n_lags

    return cov


def autocorr(ary, axis=-1):
    """Compute autocorrelation using FFT for every lag for the input array.

    See https://en.wikipedia.org/wiki/autocorrelation#Efficient_computation

    Parameters
    ----------
    ary : Numpy array
        An array containing MCMC samples
    axis : int, optional
        Axis along which the autocorrelation is computed. Negative values count
        from the last axis, as in numpy. Defaults to -1.

    Returns
    -------
    acorr: Numpy array same size as the input array
    """
    corr = autocov(ary, axis=axis)
    axis = axis if axis >= 0 else corr.ndim + axis
    # Lag-0 slice (the variance), kept with length 1 so it broadcasts over all lags.
    lag_zero = tuple(
        slice(None, 1) if dim == axis else slice(None, None) for dim in range(corr.ndim)
    )
    # A zero variance yields 0/0; return NaN silently in that case.
    with np.errstate(invalid="ignore"):
        corr /= corr[lag_zero]
    return corr
79
-
80
-
81
def make_ufunc(
    func, n_dims=2, n_output=1, n_input=1, index=Ellipsis, ravel=True, check_shape=None
):  # noqa: D202
    """Make ufunc from a function taking 1D array input.

    Parameters
    ----------
    func : callable
    n_dims : int, optional
        Number of core dimensions not broadcasted. Dimensions are skipped from the end.
        At minimum n_dims > 0.
    n_output : int, optional
        Select number of results returned by `func`.
        If n_output > 1, ufunc returns a tuple of objects else returns an object.
    n_input : int, optional
        Number of **array** inputs to func, i.e. ``n_input=2`` means that func is called
        with ``func(ary1, ary2, *args, **kwargs)``
    index : int, optional
        Slice ndarray with `index`. Defaults to `Ellipsis`.
    ravel : bool, optional
        If true, ravel the ndarray before calling `func`.
    check_shape: bool, optional
        If false, do not check if the shape of the output is compatible with n_dims and
        n_output. By default, True only for n_input=1. If n_input is larger than 1, the last
        input array is used to check the shape, however, shape checking with multiple inputs
        may not be correct.

    Returns
    -------
    callable
        ufunc wrapper for `func`. The wrapper accepts the keyword arguments ``out``
        (preallocated result array, or tuple of arrays when ``n_output > 1``) and
        ``out_shape`` (extra trailing result dimensions). Positional arguments after
        the first ``n_input`` and all other keyword arguments are forwarded to `func`.

    Raises
    ------
    TypeError
        If ``n_dims < 1``. The returned wrapper raises TypeError when shape checking
        is enabled and a preallocated ``out`` has the wrong shape.
    """
    if n_dims < 1:
        raise TypeError("n_dims must be one or higher.")

    if check_shape is None:
        # Shape checking is only reliable with a single array input.
        check_shape = n_input == 1

    def _ufunc(*args, out=None, out_shape=None, **kwargs):
        """General ufunc for single-output function."""
        arys = args[:n_input]
        # None means `func` returns one scalar per element (stored via ``.item()``);
        # a negative int is the number of trailing result dimensions in ``out``.
        n_dims_out = None
        if out is None:
            if out_shape is None:
                out = np.empty(arys[-1].shape[:-n_dims])
            else:
                out = np.empty((*arys[-1].shape[:-n_dims], *out_shape))
                # ``or None``: an empty out_shape means scalar results; ``-0`` would
                # slice away every loop dimension and call `func` only once.
                n_dims_out = -len(out_shape) or None
        elif check_shape:
            if out.shape != arys[-1].shape[:-n_dims]:
                msg = f"Shape incorrect for `out`: {out.shape}."
                msg += f" Correct shape is {arys[-1].shape[:-n_dims]}"
                raise TypeError(msg)
        for idx in np.ndindex(out.shape[:n_dims_out]):
            arys_idx = [ary[idx].ravel() if ravel else ary[idx] for ary in arys]
            out_idx = np.asarray(func(*arys_idx, *args[n_input:], **kwargs))[index]
            if n_dims_out is None:
                out_idx = out_idx.item()
            out[idx] = out_idx
        return out

    def _multi_ufunc(*args, out=None, out_shape=None, **kwargs):
        """General ufunc for multi-output function."""
        arys = args[:n_input]
        element_shape = arys[-1].shape[:-n_dims]
        if out is None:
            if out_shape is None:
                out = tuple(np.empty(element_shape) for _ in range(n_output))
            else:
                out = tuple(np.empty((*element_shape, *out_shape[i])) for i in range(n_output))

        elif check_shape:
            correct_shape = tuple(element_shape for _ in range(n_output))
            if isinstance(out, tuple):
                found_shape = tuple(item.shape for item in out)
            else:
                # A non-tuple ``out`` can never match; describe what was received.
                found_shape = f"not tuple, type={type(out)}"
            if found_shape != correct_shape:
                msg = f"Shapes incorrect for `out`: {found_shape}."
                msg += f" Correct shapes are {correct_shape}"
                raise TypeError(msg)
        for idx in np.ndindex(element_shape):
            arys_idx = [ary[idx].ravel() if ravel else ary[idx] for ary in arys]
            results = func(*arys_idx, *args[n_input:], **kwargs)
            for i, res in enumerate(results):
                out[i][idx] = np.asarray(res)[index]
        return out

    ufunc = _multi_ufunc if n_output > 1 else _ufunc

    update_docstring(ufunc, func, n_output)
    return ufunc
182
-
183
-
184
@conditional_dask
def wrap_xarray_ufunc(
    ufunc,
    *datasets,
    ufunc_kwargs=None,
    func_args=None,
    func_kwargs=None,
    dask_kwargs=None,
    **kwargs,
):
    """Wrap make_ufunc with xarray.apply_ufunc.

    Parameters
    ----------
    ufunc : callable
    *datasets : xarray.Dataset
    ufunc_kwargs : dict
        Keyword arguments passed to `make_ufunc`. The dict is copied, so the
        defaults filled in here are not written back into the caller's object.
            - 'n_dims', int, by default the length of the last entry of
              ``input_core_dims`` (2 with the default ("chain", "draw") core dims)
            - 'n_output', int, by default 1
            - 'n_input', int, by default len(datasets)
            - 'index', slice, by default Ellipsis
            - 'ravel', bool, by default True
    func_args : tuple
        Arguments passed to 'ufunc'.
    func_kwargs : dict
        Keyword arguments passed to 'ufunc'.
            - 'out_shape', int, by default None
    dask_kwargs : dict
        Dask related kwargs passed to :func:`xarray:xarray.apply_ufunc`.
        Use ``enable_dask`` method of :class:`arviz.Dask` to set default kwargs.
    **kwargs
        Passed to :func:`xarray.apply_ufunc`.

    Returns
    -------
    xarray.Dataset
    """
    # Work on a copy: the setdefault calls below must not mutate the caller's dict.
    ufunc_kwargs = {} if ufunc_kwargs is None else dict(ufunc_kwargs)
    ufunc_kwargs.setdefault("n_input", len(datasets))
    if func_args is None:
        func_args = tuple()
    if func_kwargs is None:
        func_kwargs = {}
    if dask_kwargs is None:
        dask_kwargs = {}

    # By default every array input (datasets and extra func_args) has ("chain", "draw")
    # as core dimensions, and each output has no core dimensions.
    kwargs.setdefault(
        "input_core_dims", tuple(("chain", "draw") for _ in range(len(func_args) + len(datasets)))
    )
    ufunc_kwargs.setdefault("n_dims", len(kwargs["input_core_dims"][-1]))
    kwargs.setdefault("output_core_dims", tuple([] for _ in range(ufunc_kwargs.get("n_output", 1))))

    callable_ufunc = make_ufunc(ufunc, **ufunc_kwargs)

    return apply_ufunc(
        callable_ufunc, *datasets, *func_args, kwargs=func_kwargs, **dask_kwargs, **kwargs
    )
243
-
244
-
245
def update_docstring(ufunc, func, n_output=1):
    """Update ArviZ generated ufunc docstring.

    Appends usage notes for ``xarray.apply_ufunc`` and, when available, the
    docstring of the wrapped ``func`` to ``ufunc.__doc__``.

    Parameters
    ----------
    ufunc : callable
        Generated wrapper whose ``__doc__`` is extended in place. A missing
        docstring (``None``, e.g. under ``python -OO``) is treated as empty.
    func : callable
        Wrapped function; its ``__module__``, ``__name__`` and ``__doc__`` are used
        when they are present.
    n_output : int, optional
        Number of outputs of ``func``; adds an ``output_core_dims`` hint when > 1.
    """
    module = getattr(func, "__module__", None)
    if not isinstance(module, str):
        module = ""
    name = getattr(func, "__name__", "")
    docstring = getattr(func, "__doc__", None)
    if not isinstance(docstring, str):
        docstring = ""
    # Same "module.name" spelling everywhere; skip the dot when one part is missing.
    qualified_name = ".".join(part for part in (module, name) if part)

    input_core_dims = 'tuple(("chain", "draw") for _ in range(n_args))'
    if n_output > 1:
        output_core_dims = f" tuple([] for _ in range({n_output}))"
        call_example = (
            f"xr.apply_ufunc(ufunc, dataset, input_core_dims={input_core_dims}, "
            f"output_core_dims={output_core_dims})"
        )
    else:
        call_example = f"xr.apply_ufunc(ufunc, dataset, input_core_dims={input_core_dims})"

    parts = [ufunc.__doc__ or "", "\n\n"]
    if qualified_name:
        parts += ["This function is a ufunc wrapper for ", qualified_name, "\n"]
    parts += [
        'Call ufunc with n_args from xarray against "chain" and "draw" dimensions:',
        "\n\n",
        call_example,
        "\n\n",
        "For example: np.std(data, ddof=1) --> n_args=2",
    ]
    if docstring:
        parts += ["\n\n", qualified_name, " docstring:", "\n\n", docstring]
    ufunc.__doc__ = "".join(parts)
281
-
282
-
283
def logsumexp(ary, *, b=None, b_inv=None, axis=None, keepdims=False, out=None, copy=True):
    """Stable logsumexp when b >= 0 and b is scalar.

    Computes ``log(b * sum(exp(ary)))`` (or ``log(sum(exp(ary)) / b_inv)``) along
    ``axis``, subtracting the maximum before exponentiating for numerical stability.

    b_inv overwrites b unless b_inv is None.

    Parameters
    ----------
    ary : array_like
        Input values. Integer, unsigned and boolean inputs are converted to float64.
    b : scalar, optional
        Non-negative scale factor inside the logarithm.
    b_inv : scalar, optional
        Inverse scale factor; takes precedence over ``b`` when not None.
    axis : int or sequence of int, optional
        Axis or axes to reduce. ``None`` reduces over all elements.
    keepdims : bool, optional
        Keep the reduced axes with length one.
    out : numpy.ndarray, optional
        Output array with the shape of the reduced result.
    copy : bool, optional
        If False and ``ary`` is already a float ndarray, it is overwritten in place.

    Returns
    -------
    numpy.ndarray or scalar
        A scalar of the input's float dtype when the result is 0-dimensional.
    """
    # check dimensions for result arrays
    ary = np.asarray(ary)
    # The in-place subtraction/exp below need a floating dtype.
    if ary.dtype.kind in "biu":
        ary = ary.astype(np.float64)
    dtype = ary.dtype.type
    shape = ary.shape
    shape_len = len(shape)
    if isinstance(axis, Sequence):
        axis = tuple(axis_i if axis_i >= 0 else shape_len + axis_i for axis_i in axis)
        agroup = axis
    else:
        axis = axis if (axis is None) or (axis >= 0) else shape_len + axis
        agroup = (axis,)
    shape_max = (
        tuple(1 for _ in shape)
        if axis is None
        else tuple(1 if i in agroup else d for i, d in enumerate(shape))
    )
    # create result arrays
    if out is None:
        if not keepdims:
            out_shape = (
                tuple()
                if axis is None
                else tuple(d for i, d in enumerate(shape) if i not in agroup)
            )
        else:
            out_shape = shape_max
        out = np.empty(out_shape, dtype=dtype)
    if b_inv == 0:
        return np.full_like(out, np.inf, dtype=dtype) if out.shape else np.inf
    if b_inv is None and b == 0:
        return np.full_like(out, -np.inf) if out.shape else -np.inf
    ary_max = np.empty(shape_max, dtype=dtype)
    # calculations
    ary.max(axis=axis, keepdims=True, out=ary_max)
    # A non-finite maximum (all -inf, or any +inf) would turn ``ary - ary_max`` into
    # NaN (inf - inf); shifting by 0 instead gives the correct -inf / +inf result.
    np.copyto(ary_max, 0, where=~np.isfinite(ary_max))
    if copy:
        ary = ary.copy()
    ary -= ary_max
    np.exp(ary, out=ary)
    ary.sum(axis=axis, keepdims=keepdims, out=out)
    # log(0) only happens for all -inf input, where -inf is the correct answer.
    with np.errstate(divide="ignore"):
        np.log(out, out=out)
    if b_inv is not None:
        ary_max -= np.log(b_inv)
    elif b:
        ary_max += np.log(b)
    # Reshape instead of squeeze so only the reduced axes are dropped; squeeze would
    # also drop length-1 axes that were not reduced and break broadcasting.
    out += ary_max if keepdims else ary_max.reshape(out.shape)
    # transform to scalar if possible
    return out if out.shape else dtype(out)
337
-
338
-
339
def quantile(ary, q, axis=None, limit=None):
    """Use same quantile function as R (Type 7).

    Parameters
    ----------
    ary : array_like
        Input data.
    q : array_like
        Probabilities at which quantiles are computed.
    axis : int, optional
        Axis along which quantiles are computed; ``None`` uses the flattened data.
    limit : tuple, optional
        ``(lower, upper)`` bounds outside of which values are masked. No limits
        are applied when omitted.

    Returns
    -------
    numpy.ma.MaskedArray
        Result of :func:`scipy.stats.mstats.mquantiles`.
    """
    # alphap = betap = 1 selects the R type 7 plotting positions in mquantiles.
    type7_positions = {"alphap": 1, "betap": 1}
    bounds = () if limit is None else limit
    return mquantiles(ary, q, axis=axis, limit=bounds, **type7_positions)
344
-
345
-
346
def not_valid(ary, check_nan=True, check_shape=True, nan_kwargs=None, shape_kwargs=None):
    """Validate ndarray.

    Parameters
    ----------
    ary : numpy.ndarray
    check_nan : bool
        Check if any value contains NaN.
    check_shape : bool
        Check if array has correct shape. Assumes dimensions in order (chain, draw, *shape).
        For 1D arrays (shape = (n,)) assumes chain equals 1.
    nan_kwargs : dict
        Valid kwargs are:
            axis : int,
                Defaults to None.
            how : str, {"all", "any"}
                Default to "any".
    shape_kwargs : dict
        Valid kwargs are:
            min_chains : int
                Defaults to 2.
            min_draws : int
                Defaults to 4.

    Returns
    -------
    bool or numpy.ndarray
        True when a check failed. When ``nan_kwargs["axis"]`` is given, the NaN
        result (and therefore the return value) is an array.
        A warning is logged for each failed check.
    """
    ary = np.asarray(ary)

    nan_error = False
    draw_error = False
    chain_error = False

    if check_nan:
        if nan_kwargs is None:
            nan_kwargs = {}

        isnan = np.isnan(ary)
        axis = nan_kwargs.get("axis", None)
        if nan_kwargs.get("how", "any").lower() == "all":
            nan_error = isnan.all(axis)
        else:
            nan_error = isnan.any(axis)

        # nan_error is a numpy bool or (with an explicit axis) a bool array.
        if np.any(nan_error):
            _log.warning("Array contains NaN-value.")

    if check_shape:
        shape = ary.shape

        if shape_kwargs is None:
            shape_kwargs = {}

        min_chains = shape_kwargs.get("min_chains", 2)
        min_draws = shape_kwargs.get("min_draws", 4)
        error_msg = f"Shape validation failed: input_shape: {shape}, "
        error_msg += f"minimum_shape: (chains={min_chains}, draws={min_draws})"

        # A 1D array is a single chain of draws, as documented above.
        if len(shape) < 2:
            n_chains, n_draws = 1, shape[0]
        else:
            n_chains, n_draws = shape[0], shape[1]
        chain_error = n_chains < min_chains
        draw_error = n_draws < min_draws

        if chain_error or draw_error:
            _log.warning(error_msg)

    return nan_error | chain_error | draw_error
414
-
415
-
416
def get_log_likelihood(idata, var_name=None, single_var=True):
    """Retrieve the log likelihood dataarray of a given variable.

    Parameters
    ----------
    idata : InferenceData-like
        Object exposing a ``log_likelihood`` group, or (deprecated) a
        ``sample_stats.log_likelihood`` variable.
    var_name : str, optional
        Name of the log likelihood variable to select.
    single_var : bool, optional
        When ``var_name`` is None and several variables exist, raise instead of
        returning all of them.

    Returns
    -------
    The selected log likelihood variable(s).

    Raises
    ------
    TypeError
        If no log likelihood is found, ``var_name`` is unknown, or several
        variables exist while ``single_var`` is True.
    """
    if not hasattr(idata, "log_likelihood"):
        legacy_stats = getattr(idata, "sample_stats", None)
        if legacy_stats is not None and hasattr(legacy_stats, "log_likelihood"):
            warnings.warn(
                "Storing the log_likelihood in sample_stats groups has been deprecated",
                DeprecationWarning,
            )
            return legacy_stats.log_likelihood
        raise TypeError("log likelihood not found in inference data object")

    group = idata.log_likelihood
    if var_name is not None:
        try:
            return group[var_name]
        except KeyError as err:
            raise TypeError(f"No log likelihood data named {var_name} found") from err

    available = list(group.data_vars)
    if len(available) <= 1:
        return group[available[0]]
    if single_var:
        raise TypeError(
            f"Found several log likelihood arrays {available}, var_name cannot be None"
        )
    return group[available]
445
-
446
-
447
# Template for the ELPDData summary header, formatted in two passes: the first
# ``str.format`` call fills the single-brace column widths, turning the doubled
# braces into the placeholders that the second call fills with the values.
BASE_FMT = """Computed from {{n_samples}} posterior samples and \
{{n_points}} observations log-likelihood matrix.

{{0:{0}}} Estimate SE
{{scale}}_{{kind}} {{1:8.2f}} {{2:7.2f}}
p_{{kind:{1}}} {{3:8.2f}} -"""
# Two-pass template (width first, then values) for the Pareto k diagnostic table
# appended to LOO summaries; field {8} is the ``good_k`` threshold.
POINTWISE_LOO_FMT = """------

Pareto k diagnostic values:
{{0:>{0}}} {{1:>6}}
(-Inf, {{8:.2f}}] (good) {{2:{0}d}} {{5:6.1f}}%
({{8:.2f}}, 1] (bad) {{3:{0}d}} {{6:6.1f}}%
(1, Inf) (very bad) {{4:{0}d}} {{7:6.1f}}%
"""
# Maps the ``scale`` entry of an ELPDData object to the label used in its summary.
SCALE_DICT = {"deviance": "deviance", "log": "elpd", "negative_log": "-elpd"}
462
-
463
-
464
class ELPDData(pd.Series):  # pylint: disable=too-many-ancestors
    """Class to contain the data from elpd information criterion like waic or loo.

    A ``pandas.Series`` subclass. Besides label access such as ``self["scale"]``,
    its methods read entries as attributes (``n_samples``, ``n_data_points``,
    ``warning``, ``pareto_k``, ``good_k``), so those entries are expected to be
    present -- presumably set by the functions that build the object; verify
    against the callers.
    """

    def __str__(self):
        """Print elpd data in a user friendly way.

        Returns
        -------
        str
            Summary built from ``BASE_FMT`` and, for LOO results containing
            ``pareto_k``, the Pareto k table from ``POINTWISE_LOO_FMT``.

        Raises
        ------
        ValueError
            If the second ``_``-separated part of the first index label is neither
            ``"loo"`` nor ``"waic"``.
        """
        kind = self.index[0].split("_")[1]

        if kind not in ("loo", "waic"):
            raise ValueError("Invalid ELPDData object")

        scale_str = SCALE_DICT[self["scale"]]
        padding = len(scale_str) + len(kind) + 1
        # First pass fills the column widths, second pass the values.
        base = BASE_FMT.format(padding, padding - 2)
        base = base.format(
            "",
            kind=kind,
            scale=scale_str,
            n_samples=self.n_samples,
            n_points=self.n_data_points,
            # Fields {1}-{3} take the first Series values in order (presumably the
            # estimate, its SE and p -- verify the ordering); str.format ignores
            # the remaining extra positional values.
            *self.values,
        )

        if self.warning:
            base += "\n\nThere has been a warning during the calculation. Please check the results."

        if kind == "loo" and "pareto_k" in self:
            # Three ranges split at good_k and 1; assumes _histogram returns the
            # counts as its first item.
            bins = np.asarray([-np.inf, self.good_k, 1, np.inf])
            counts, *_ = _histogram(self.pareto_k.values, bins)
            # Count column is at least 4 characters wide, wider for large counts.
            extended = POINTWISE_LOO_FMT.format(max(4, len(str(np.max(counts)))))
            extended = extended.format(
                "Count",
                "Pct.",
                *[*counts, *(counts / np.sum(counts) * 100)],
                self.good_k,
            )
            base = "\n".join([base, extended])
        return base

    def __repr__(self):
        """Alias to ``__str__``."""
        return self.__str__()

    def copy(self, deep=True):  # pylint:disable=overridden-final-method
        """Perform a pandas deep copy of the ELPDData plus a copy of the stored data.

        Parameters
        ----------
        deep : bool
            If True each stored entry is copied with ``copy.deepcopy``, otherwise
            with ``copy.copy``. The Series itself is always copied with the
            ``pandas.Series.copy`` defaults.

        Returns
        -------
        ELPDData
        """
        copied_obj = pd.Series.copy(self)
        for key in copied_obj.keys():
            if deep:
                copied_obj[key] = _deepcopy(copied_obj[key])
            else:
                copied_obj[key] = _copy(copied_obj[key])
        return ELPDData(copied_obj)
515
-
516
-
517
@conditional_jit(nopython=True)
def stats_variance_1d(data, ddof=0):
    """Compute the variance of a 1D sequence of numbers.

    Parameters
    ----------
    data : array_like
        One-dimensional sequence of numbers.
    ddof : int, optional
        Delta degrees of freedom: the sum of squared deviations is divided by
        ``len(data) - ddof``. Defaults to 0.

    Returns
    -------
    float
    """
    n_obs = len(data)
    # Two passes (mean first, then squared deviations) avoid the catastrophic
    # cancellation of ``E[x**2] - E[x]**2`` when the mean is large relative to
    # the spread. Plain loops keep the function compatible with numba nopython mode.
    mean = 0.0
    for value in data:
        mean += value
    mean /= n_obs
    sq_dev = 0.0
    for value in data:
        diff = value - mean
        sq_dev += diff * diff
    return sq_dev / (n_obs - ddof)
526
-
527
-
528
def stats_variance_2d(data, ddof=0, axis=1):
    """Compute the variance of each row or column of a 2D array.

    Parameters
    ----------
    data : numpy.ndarray
        1D or 2D array. 1D input is forwarded to ``stats_variance_1d``.
    ddof : int, optional
        Delta degrees of freedom forwarded to ``stats_variance_1d``.
    axis : int, optional
        1 (or -1) gives one variance per row, 0 (or -2) one variance per column.
        Defaults to 1.

    Returns
    -------
    numpy.ndarray or float
        Float for 1D input, otherwise one value per row/column.
    """
    if data.ndim == 1:
        return stats_variance_1d(data, ddof=ddof)
    n_rows, n_cols = data.shape
    # Interpret negative axes like numpy so axis=-1 means rows, not columns.
    if axis < 0:
        axis += data.ndim
    if axis == 1:
        var = np.zeros(n_rows)
        for i in range(n_rows):
            var[i] = stats_variance_1d(data[i], ddof=ddof)
    else:
        var = np.zeros(n_cols)
        for i in range(n_cols):
            var[i] = stats_variance_1d(data[:, i], ddof=ddof)

    return var
542
-
543
-
544
@conditional_vect
def _sqrt(a_a, b_b):
    """Return ``(a_a + b_b) ** 0.5``, vectorized through ``conditional_vect``."""
    total = a_a + b_b
    return total ** 0.5
547
-
548
-
549
def _circfunc(samples, high, low, skipna):
    """Convert samples to angles with ``_angle`` (``low`` -> 0, ``high`` -> 2*pi).

    When ``skipna`` is true NaN values are dropped first (this flattens the array).
    Returns ``np.nan`` if no samples remain.
    """
    values = np.asarray(samples)
    if skipna:
        values = values[~np.isnan(values)]
    if not values.size:
        return np.nan
    return _angle(values, low, high, np.pi)
556
-
557
-
558
@conditional_vect
def _angle(samples, low, high, p_i=np.pi):
    """Linearly rescale samples so that ``low`` maps to 0 and ``high`` to ``2 * p_i``."""
    shifted = samples - low
    return shifted * 2.0 * p_i / (high - low)
562
-
563
-
564
def _circular_standard_deviation(samples, high=2 * np.pi, low=0, skipna=False, axis=None):
    """Compute the circular standard deviation of samples defined on ``[low, high]``.

    Parameters
    ----------
    samples : array_like
    high, low : float
        Bounds of the circular domain. Default to ``2 * pi`` and 0.
    skipna : bool
        Drop NaN values before computing (this flattens the samples).
    axis : int, optional
        Axis over which the circular means are taken.

    Returns
    -------
    float or numpy.ndarray
        ``sqrt(-2 * log(R))`` rescaled back to the units of ``[low, high]``, where
        ``R`` is the mean resultant length.
    """
    ang = _circfunc(samples, high, low, skipna)
    s_s = np.sin(ang).mean(axis=axis)
    c_c = np.cos(ang).mean(axis=axis)
    # R is at most 1 mathematically, but rounding can push it slightly above 1
    # (e.g. constant samples). log(R) would then be positive and the square
    # root NaN instead of 0. np.minimum keeps NaN inputs as NaN.
    r_r = np.minimum(np.hypot(s_s, c_c), 1.0)
    return ((high - low) / 2.0 / np.pi) * np.sqrt(-2 * np.log(r_r))
570
-
571
-
572
def smooth_data(obs_vals, pp_vals):
    """Smooth data using a cubic spline.

    Helper function for discrete data in plot_bpv, loo_pit and plot_loo_pit.

    Parameters
    ----------
    obs_vals : (N) array-like
        Observed data
    pp_vals : (S, N) array-like
        Posterior predictive samples. ``N`` is the number of observations,
        and ``S`` is the number of samples (generally n_chains*n_draws).

    Returns
    -------
    obs_vals : (N) ndarray
        Smoothed observed data
    pp_vals : (S, N) ndarray
        Smoothed posterior predictive samples
    """
    # Accept any array-like as documented; ``.shape`` below needs an ndarray.
    obs_vals = np.asarray(obs_vals)
    pp_vals = np.asarray(pp_vals)

    # Knots are placed evenly on [0, 1]; the spline is then evaluated at the same
    # number of points spread over [0.01, 0.99].
    x = np.linspace(0, 1, len(obs_vals))
    csi = CubicSpline(x, obs_vals)
    obs_vals = csi(np.linspace(0.01, 0.99, len(obs_vals)))

    x = np.linspace(0, 1, pp_vals.shape[1])
    csi = CubicSpline(x, pp_vals, axis=1)
    pp_vals = csi(np.linspace(0.01, 0.99, pp_vals.shape[1]))

    return obs_vals, pp_vals
601
-
602
-
603
def get_log_prior(idata, var_names=None):
    """Retrieve the log prior dataarray of a given variable.

    Parameters
    ----------
    idata : InferenceData-like
        Object exposing a ``log_prior`` group.
    var_names : optional
        Selection passed to the group's ``__getitem__``; all data variables
        when None.

    Raises
    ------
    TypeError
        If ``idata`` has no ``log_prior`` attribute.
    """
    try:
        prior_group = idata.log_prior
    except AttributeError:
        raise TypeError("log prior not found in inference data object") from None
    selection = list(prior_group.data_vars) if var_names is None else var_names
    return prior_group[selection]
arviz/tests/__init__.py DELETED
@@ -1 +0,0 @@
1
- """Test suite."""
@@ -1 +0,0 @@
1
- """Base test suite."""