arviz 0.21.0__py3-none-any.whl → 0.23.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arviz/__init__.py +49 -4
- arviz/data/converters.py +11 -0
- arviz/data/inference_data.py +46 -24
- arviz/data/io_datatree.py +2 -2
- arviz/data/io_numpyro.py +116 -5
- arviz/data/io_pyjags.py +1 -1
- arviz/plots/autocorrplot.py +12 -2
- arviz/plots/backends/bokeh/hdiplot.py +7 -6
- arviz/plots/backends/bokeh/lmplot.py +19 -3
- arviz/plots/backends/bokeh/pairplot.py +18 -48
- arviz/plots/backends/matplotlib/khatplot.py +8 -1
- arviz/plots/backends/matplotlib/lmplot.py +13 -7
- arviz/plots/backends/matplotlib/pairplot.py +14 -22
- arviz/plots/bpvplot.py +1 -1
- arviz/plots/dotplot.py +2 -0
- arviz/plots/forestplot.py +16 -4
- arviz/plots/kdeplot.py +4 -4
- arviz/plots/lmplot.py +41 -14
- arviz/plots/pairplot.py +10 -3
- arviz/plots/ppcplot.py +1 -1
- arviz/preview.py +31 -21
- arviz/rcparams.py +2 -2
- arviz/stats/density_utils.py +1 -1
- arviz/stats/stats.py +31 -34
- arviz/tests/base_tests/test_data.py +25 -4
- arviz/tests/base_tests/test_plots_bokeh.py +60 -2
- arviz/tests/base_tests/test_plots_matplotlib.py +94 -1
- arviz/tests/base_tests/test_stats.py +42 -1
- arviz/tests/base_tests/test_stats_ecdf_utils.py +2 -2
- arviz/tests/external_tests/test_data_numpyro.py +154 -4
- arviz/wrappers/base.py +1 -1
- arviz/wrappers/wrap_stan.py +1 -1
- {arviz-0.21.0.dist-info → arviz-0.23.0.dist-info}/METADATA +20 -9
- {arviz-0.21.0.dist-info → arviz-0.23.0.dist-info}/RECORD +37 -37
- {arviz-0.21.0.dist-info → arviz-0.23.0.dist-info}/WHEEL +1 -1
- {arviz-0.21.0.dist-info → arviz-0.23.0.dist-info/licenses}/LICENSE +0 -0
- {arviz-0.21.0.dist-info → arviz-0.23.0.dist-info}/top_level.txt +0 -0
|
@@ -7,6 +7,7 @@ from matplotlib import cm
|
|
|
7
7
|
import matplotlib.pyplot as plt
|
|
8
8
|
import numpy as np
|
|
9
9
|
from matplotlib.colors import to_rgba_array
|
|
10
|
+
from packaging import version
|
|
10
11
|
|
|
11
12
|
from ....stats.density_utils import histogram
|
|
12
13
|
from ...plot_utils import _scale_fig_size, color_from_dim, set_xticklabels, vectorized_to_hex
|
|
@@ -39,7 +40,13 @@ def plot_khat(
|
|
|
39
40
|
show,
|
|
40
41
|
):
|
|
41
42
|
"""Matplotlib khat plot."""
|
|
42
|
-
if
|
|
43
|
+
if version.parse(mpl.__version__) >= version.parse("3.9.0.dev0"):
|
|
44
|
+
interactive_backends = mpl.backends.backend_registry.list_builtin(
|
|
45
|
+
mpl.backends.BackendFilter.INTERACTIVE
|
|
46
|
+
)
|
|
47
|
+
else:
|
|
48
|
+
interactive_backends = mpl.rcsetup.interactive_bk
|
|
49
|
+
if hover_label and mpl.get_backend() not in interactive_backends:
|
|
43
50
|
hover_label = False
|
|
44
51
|
warnings.warn(
|
|
45
52
|
"hover labels are only available with interactive backends. To switch to an "
|
|
@@ -115,12 +115,18 @@ def plot_lm(
|
|
|
115
115
|
|
|
116
116
|
if y_model is not None:
|
|
117
117
|
_, _, _, y_model_plotters = y_model[i]
|
|
118
|
+
|
|
118
119
|
if kind_model == "lines":
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
120
|
+
# y_model_plotters should be (points, samples)
|
|
121
|
+
y_points = y_model_plotters.shape[0]
|
|
122
|
+
if x_plotters.shape[0] == y_points:
|
|
123
|
+
for j in range(num_samples):
|
|
124
|
+
ax_i.plot(x_plotters, y_model_plotters[:, j], **y_model_plot_kwargs)
|
|
125
|
+
|
|
126
|
+
ax_i.plot([], **y_model_plot_kwargs, label="Uncertainty in mean")
|
|
127
|
+
y_model_mean = np.mean(y_model_plotters, axis=1)
|
|
128
|
+
ax_i.plot(x_plotters, y_model_mean, **y_model_mean_kwargs, label="Mean")
|
|
122
129
|
|
|
123
|
-
y_model_mean = np.mean(y_model_plotters, axis=1)
|
|
124
130
|
else:
|
|
125
131
|
plot_hdi(
|
|
126
132
|
x_plotters,
|
|
@@ -128,10 +134,10 @@ def plot_lm(
|
|
|
128
134
|
fill_kwargs=y_model_fill_kwargs,
|
|
129
135
|
ax=ax_i,
|
|
130
136
|
)
|
|
131
|
-
ax_i.plot([], color=y_model_fill_kwargs["color"], label="Uncertainty in mean")
|
|
132
137
|
|
|
133
|
-
|
|
134
|
-
|
|
138
|
+
ax_i.plot([], color=y_model_fill_kwargs["color"], label="Uncertainty in mean")
|
|
139
|
+
y_model_mean = np.mean(y_model_plotters, axis=0)
|
|
140
|
+
ax_i.plot(x_plotters, y_model_mean, **y_model_mean_kwargs, label="Mean")
|
|
135
141
|
|
|
136
142
|
if legend:
|
|
137
143
|
ax_i.legend(fontsize=xt_labelsize, loc="upper left")
|
|
@@ -30,6 +30,8 @@ def plot_pair(
|
|
|
30
30
|
diverging_mask,
|
|
31
31
|
divergences_kwargs,
|
|
32
32
|
flat_var_names,
|
|
33
|
+
flat_ref_slices,
|
|
34
|
+
flat_var_labels,
|
|
33
35
|
backend_kwargs,
|
|
34
36
|
marginal_kwargs,
|
|
35
37
|
show,
|
|
@@ -77,24 +79,12 @@ def plot_pair(
|
|
|
77
79
|
kde_kwargs["contour_kwargs"].setdefault("colors", "k")
|
|
78
80
|
|
|
79
81
|
if reference_values:
|
|
80
|
-
|
|
81
|
-
label = []
|
|
82
|
-
for variable in list(reference_values.keys()):
|
|
83
|
-
if " " in variable:
|
|
84
|
-
variable_copy = variable.replace(" ", "\n", 1)
|
|
85
|
-
else:
|
|
86
|
-
variable_copy = variable
|
|
87
|
-
|
|
88
|
-
label.append(variable_copy)
|
|
89
|
-
reference_values_copy[variable_copy] = reference_values[variable]
|
|
90
|
-
|
|
91
|
-
difference = set(flat_var_names).difference(set(label))
|
|
82
|
+
difference = set(flat_var_names).difference(set(reference_values.keys()))
|
|
92
83
|
|
|
93
84
|
if difference:
|
|
94
|
-
warn = [diff.replace("\n", " ", 1) for diff in difference]
|
|
95
85
|
warnings.warn(
|
|
96
86
|
"Argument reference_values does not include reference value for: {}".format(
|
|
97
|
-
", ".join(
|
|
87
|
+
", ".join(difference)
|
|
98
88
|
),
|
|
99
89
|
UserWarning,
|
|
100
90
|
)
|
|
@@ -211,12 +201,12 @@ def plot_pair(
|
|
|
211
201
|
|
|
212
202
|
if reference_values:
|
|
213
203
|
ax.plot(
|
|
214
|
-
|
|
215
|
-
|
|
204
|
+
np.array(reference_values[flat_var_names[0]])[flat_ref_slices[0]],
|
|
205
|
+
np.array(reference_values[flat_var_names[1]])[flat_ref_slices[1]],
|
|
216
206
|
**reference_values_kwargs,
|
|
217
207
|
)
|
|
218
|
-
ax.set_xlabel(f"{
|
|
219
|
-
ax.set_ylabel(f"{
|
|
208
|
+
ax.set_xlabel(f"{flat_var_labels[0]}", fontsize=ax_labelsize, wrap=True)
|
|
209
|
+
ax.set_ylabel(f"{flat_var_labels[1]}", fontsize=ax_labelsize, wrap=True)
|
|
220
210
|
ax.tick_params(labelsize=xt_labelsize)
|
|
221
211
|
|
|
222
212
|
else:
|
|
@@ -336,20 +326,22 @@ def plot_pair(
|
|
|
336
326
|
y_name = flat_var_names[j + not_marginals]
|
|
337
327
|
if (x_name not in difference) and (y_name not in difference):
|
|
338
328
|
ax[j, i].plot(
|
|
339
|
-
|
|
340
|
-
|
|
329
|
+
np.array(reference_values[x_name])[flat_ref_slices[i]],
|
|
330
|
+
np.array(reference_values[y_name])[
|
|
331
|
+
flat_ref_slices[j + not_marginals]
|
|
332
|
+
],
|
|
341
333
|
**reference_values_kwargs,
|
|
342
334
|
)
|
|
343
335
|
|
|
344
336
|
if j != vars_to_plot - 1:
|
|
345
337
|
plt.setp(ax[j, i].get_xticklabels(), visible=False)
|
|
346
338
|
else:
|
|
347
|
-
ax[j, i].set_xlabel(f"{
|
|
339
|
+
ax[j, i].set_xlabel(f"{flat_var_labels[i]}", fontsize=ax_labelsize, wrap=True)
|
|
348
340
|
if i != 0:
|
|
349
341
|
plt.setp(ax[j, i].get_yticklabels(), visible=False)
|
|
350
342
|
else:
|
|
351
343
|
ax[j, i].set_ylabel(
|
|
352
|
-
f"{
|
|
344
|
+
f"{flat_var_labels[j + not_marginals]}",
|
|
353
345
|
fontsize=ax_labelsize,
|
|
354
346
|
wrap=True,
|
|
355
347
|
)
|
arviz/plots/bpvplot.py
CHANGED
|
@@ -251,7 +251,7 @@ def plot_bpv(
|
|
|
251
251
|
total_pp_samples = predictive_dataset.sizes["chain"] * predictive_dataset.sizes["draw"]
|
|
252
252
|
|
|
253
253
|
for key in coords.keys():
|
|
254
|
-
coords[key] = np.where(np.
|
|
254
|
+
coords[key] = np.where(np.isin(observed[key], coords[key]))[0]
|
|
255
255
|
|
|
256
256
|
obs_plotters = filter_plotters_list(
|
|
257
257
|
list(
|
arviz/plots/dotplot.py
CHANGED
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
import numpy as np
|
|
4
4
|
|
|
5
|
+
|
|
5
6
|
from ..rcparams import rcParams
|
|
6
7
|
from .plot_utils import get_plotting_function
|
|
7
8
|
|
|
@@ -148,6 +149,7 @@ def plot_dot(
|
|
|
148
149
|
raise ValueError("marker argument is valid only for matplotlib backend")
|
|
149
150
|
|
|
150
151
|
values = np.ravel(values)
|
|
152
|
+
values = values[np.isfinite(values)]
|
|
151
153
|
values.sort()
|
|
152
154
|
|
|
153
155
|
if hdi_prob is None:
|
arviz/plots/forestplot.py
CHANGED
|
@@ -51,7 +51,7 @@ def plot_forest(
|
|
|
51
51
|
data : InferenceData
|
|
52
52
|
Any object that can be converted to an :class:`arviz.InferenceData` object
|
|
53
53
|
Refer to documentation of :func:`arviz.convert_to_dataset` for details.
|
|
54
|
-
kind : {"
|
|
54
|
+
kind : {"forestplot", "ridgeplot"}, default "forestplot"
|
|
55
55
|
Specify the kind of plot:
|
|
56
56
|
|
|
57
57
|
* The ``kind="forestplot"`` generates credible intervals, where the central points are the
|
|
@@ -75,8 +75,8 @@ def plot_forest(
|
|
|
75
75
|
interpret `var_names` as substrings of the real variables names. If "regex",
|
|
76
76
|
interpret `var_names` as regular expressions on the real variables names. See
|
|
77
77
|
:ref:`this section <common_filter_vars>` for usage examples.
|
|
78
|
-
transform : callable, optional
|
|
79
|
-
Function to transform data
|
|
78
|
+
transform : callable or dict, optional
|
|
79
|
+
Function to transform the data. Defaults to None, i.e., the identity function.
|
|
80
80
|
coords : dict, optional
|
|
81
81
|
Coordinates of ``var_names`` to be plotted. Passed to :meth:`xarray.Dataset.sel`.
|
|
82
82
|
See :ref:`this section <common_coords>` for usage examples.
|
|
@@ -228,7 +228,19 @@ def plot_forest(
|
|
|
228
228
|
|
|
229
229
|
datasets = [convert_to_dataset(datum) for datum in reversed(data)]
|
|
230
230
|
if transform is not None:
|
|
231
|
-
|
|
231
|
+
if callable(transform):
|
|
232
|
+
datasets = [transform(dataset) for dataset in datasets]
|
|
233
|
+
elif isinstance(transform, dict):
|
|
234
|
+
transformed_datasets = []
|
|
235
|
+
for dataset in datasets:
|
|
236
|
+
new_dataset = dataset.copy()
|
|
237
|
+
for var_name, func in transform.items():
|
|
238
|
+
if var_name in new_dataset:
|
|
239
|
+
new_dataset[var_name] = func(new_dataset[var_name])
|
|
240
|
+
transformed_datasets.append(new_dataset)
|
|
241
|
+
datasets = transformed_datasets
|
|
242
|
+
else:
|
|
243
|
+
raise ValueError("transform must be either a callable or a dict {var_name: callable}")
|
|
232
244
|
datasets = get_coords(
|
|
233
245
|
datasets, list(reversed(coords)) if isinstance(coords, (list, tuple)) else coords
|
|
234
246
|
)
|
arviz/plots/kdeplot.py
CHANGED
|
@@ -255,6 +255,10 @@ def plot_kde(
|
|
|
255
255
|
"or plot_pair instead of plot_kde"
|
|
256
256
|
)
|
|
257
257
|
|
|
258
|
+
if backend is None:
|
|
259
|
+
backend = rcParams["plot.backend"]
|
|
260
|
+
backend = backend.lower()
|
|
261
|
+
|
|
258
262
|
if values2 is None:
|
|
259
263
|
if bw == "default":
|
|
260
264
|
bw = "taylor" if is_circular else "experimental"
|
|
@@ -346,10 +350,6 @@ def plot_kde(
|
|
|
346
350
|
**kwargs,
|
|
347
351
|
)
|
|
348
352
|
|
|
349
|
-
if backend is None:
|
|
350
|
-
backend = rcParams["plot.backend"]
|
|
351
|
-
backend = backend.lower()
|
|
352
|
-
|
|
353
353
|
# TODO: Add backend kwargs
|
|
354
354
|
plot = get_plotting_function("plot_kde", "kdeplot", backend)
|
|
355
355
|
ax = plot(**kde_plot_args)
|
arviz/plots/lmplot.py
CHANGED
|
@@ -300,20 +300,47 @@ def plot_lm(
|
|
|
300
300
|
# Filter out the required values to generate plotters
|
|
301
301
|
if y_model is not None:
|
|
302
302
|
if kind_model == "lines":
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
303
|
+
var_name = y_model.name if y_model.name else "y_model"
|
|
304
|
+
data = y_model.values
|
|
305
|
+
|
|
306
|
+
total_samples = data.shape[0] * data.shape[1]
|
|
307
|
+
data = data.reshape(total_samples, *data.shape[2:])
|
|
308
|
+
|
|
309
|
+
if pp_sample_ix is not None:
|
|
310
|
+
data = data[pp_sample_ix]
|
|
311
|
+
|
|
312
|
+
if plot_dim is not None:
|
|
313
|
+
# For plot_dim case, transpose to get dimension first
|
|
314
|
+
data = data.transpose(1, 0, 2)[..., 0]
|
|
315
|
+
|
|
316
|
+
# Create plotter tuple(s)
|
|
317
|
+
if plot_dim is not None:
|
|
318
|
+
y_model = [(var_name, {}, {}, data) for _ in range(length_plotters)]
|
|
319
|
+
else:
|
|
320
|
+
y_model = [(var_name, {}, {}, data)]
|
|
321
|
+
y_model = _repeat_flatten_list(y_model, len_x)
|
|
322
|
+
|
|
323
|
+
elif kind_model == "hdi":
|
|
324
|
+
var_name = y_model.name if y_model.name else "y_model"
|
|
325
|
+
data = y_model.values
|
|
326
|
+
|
|
327
|
+
if plot_dim is not None:
|
|
328
|
+
# First transpose to get plot_dim first
|
|
329
|
+
data = data.transpose(2, 0, 1, 3)
|
|
330
|
+
# For plot_dim case, we just want HDI for first dimension
|
|
331
|
+
data = data[..., 0]
|
|
332
|
+
|
|
333
|
+
# Reshape to (samples, points)
|
|
334
|
+
data = data.transpose(1, 2, 0).reshape(-1, data.shape[0])
|
|
335
|
+
y_model = [(var_name, {}, {}, data) for _ in range(length_plotters)]
|
|
336
|
+
|
|
337
|
+
else:
|
|
338
|
+
data = data.reshape(-1, data.shape[-1])
|
|
339
|
+
y_model = [(var_name, {}, {}, data)]
|
|
340
|
+
y_model = _repeat_flatten_list(y_model, len_x)
|
|
341
|
+
|
|
342
|
+
if len(y_model) == 1:
|
|
343
|
+
y_model = _repeat_flatten_list(y_model, len_x)
|
|
317
344
|
|
|
318
345
|
rows, cols = default_grid(length_plotters)
|
|
319
346
|
|
arviz/plots/pairplot.py
CHANGED
|
@@ -196,9 +196,14 @@ def plot_pair(
|
|
|
196
196
|
get_coords(dataset, coords), var_names=var_names, skip_dims=combine_dims, combined=True
|
|
197
197
|
)
|
|
198
198
|
)
|
|
199
|
-
flat_var_names = [
|
|
200
|
-
|
|
201
|
-
]
|
|
199
|
+
flat_var_names = []
|
|
200
|
+
flat_ref_slices = []
|
|
201
|
+
flat_var_labels = []
|
|
202
|
+
for var_name, sel, isel, _ in plotters:
|
|
203
|
+
dims = [dim for dim in dataset[var_name].dims if dim not in ["chain", "draw"]]
|
|
204
|
+
flat_var_names.append(var_name)
|
|
205
|
+
flat_ref_slices.append(tuple(isel[dim] if dim in isel else slice(None) for dim in dims))
|
|
206
|
+
flat_var_labels.append(labeller.make_label_vert(var_name, sel, isel))
|
|
202
207
|
|
|
203
208
|
divergent_data = None
|
|
204
209
|
diverging_mask = None
|
|
@@ -253,6 +258,8 @@ def plot_pair(
|
|
|
253
258
|
diverging_mask=diverging_mask,
|
|
254
259
|
divergences_kwargs=divergences_kwargs,
|
|
255
260
|
flat_var_names=flat_var_names,
|
|
261
|
+
flat_ref_slices=flat_ref_slices,
|
|
262
|
+
flat_var_labels=flat_var_labels,
|
|
256
263
|
backend_kwargs=backend_kwargs,
|
|
257
264
|
marginal_kwargs=marginal_kwargs,
|
|
258
265
|
show=show,
|
arviz/plots/ppcplot.py
CHANGED
|
@@ -304,7 +304,7 @@ def plot_ppc(
|
|
|
304
304
|
pp_sample_ix = np.random.choice(total_pp_samples, size=num_pp_samples, replace=False)
|
|
305
305
|
|
|
306
306
|
for key in coords.keys():
|
|
307
|
-
coords[key] = np.where(np.
|
|
307
|
+
coords[key] = np.where(np.isin(observed_data[key], coords[key]))[0]
|
|
308
308
|
|
|
309
309
|
obs_plotters = filter_plotters_list(
|
|
310
310
|
list(
|
arviz/preview.py
CHANGED
|
@@ -8,41 +8,51 @@ info = ""
|
|
|
8
8
|
|
|
9
9
|
try:
|
|
10
10
|
from arviz_base import *
|
|
11
|
+
import arviz_base as base
|
|
11
12
|
|
|
12
|
-
|
|
13
|
-
_log.info(
|
|
13
|
+
_status = "arviz_base available, exposing its functions as part of arviz.preview"
|
|
14
|
+
_log.info(_status)
|
|
14
15
|
except ModuleNotFoundError:
|
|
15
|
-
|
|
16
|
-
_log.info(
|
|
16
|
+
_status = "arviz_base not installed"
|
|
17
|
+
_log.info(_status)
|
|
17
18
|
except ImportError:
|
|
18
|
-
|
|
19
|
-
_log.info(
|
|
19
|
+
_status = "Unable to import arviz_base"
|
|
20
|
+
_log.info(_status, exc_info=True)
|
|
20
21
|
|
|
21
|
-
info +=
|
|
22
|
+
info += _status + "\n"
|
|
22
23
|
|
|
23
24
|
try:
|
|
24
25
|
from arviz_stats import *
|
|
25
26
|
|
|
26
|
-
|
|
27
|
-
|
|
27
|
+
# the base computational module fron arviz_stats will override the alias to arviz-base
|
|
28
|
+
# arviz.stats.base will still be available
|
|
29
|
+
import arviz_base as base
|
|
30
|
+
import arviz_stats as stats
|
|
31
|
+
|
|
32
|
+
_status = "arviz_stats available, exposing its functions as part of arviz.preview"
|
|
33
|
+
_log.info(_status)
|
|
28
34
|
except ModuleNotFoundError:
|
|
29
|
-
|
|
30
|
-
_log.info(
|
|
35
|
+
_status = "arviz_stats not installed"
|
|
36
|
+
_log.info(_status)
|
|
31
37
|
except ImportError:
|
|
32
|
-
|
|
33
|
-
_log.info(
|
|
34
|
-
info +=
|
|
38
|
+
_status = "Unable to import arviz_stats"
|
|
39
|
+
_log.info(_status, exc_info=True)
|
|
40
|
+
info += _status + "\n"
|
|
35
41
|
|
|
36
42
|
try:
|
|
37
43
|
from arviz_plots import *
|
|
44
|
+
import arviz_plots as plots
|
|
38
45
|
|
|
39
|
-
|
|
40
|
-
_log.info(
|
|
46
|
+
_status = "arviz_plots available, exposing its functions as part of arviz.preview"
|
|
47
|
+
_log.info(_status)
|
|
41
48
|
except ModuleNotFoundError:
|
|
42
|
-
|
|
43
|
-
_log.info(
|
|
49
|
+
_status = "arviz_plots not installed"
|
|
50
|
+
_log.info(_status)
|
|
44
51
|
except ImportError:
|
|
45
|
-
|
|
46
|
-
_log.info(
|
|
52
|
+
_status = "Unable to import arviz_plots"
|
|
53
|
+
_log.info(_status, exc_info=True)
|
|
54
|
+
|
|
55
|
+
info += _status + "\n"
|
|
47
56
|
|
|
48
|
-
|
|
57
|
+
# clean namespace
|
|
58
|
+
del logging, _status, _log
|
arviz/rcparams.py
CHANGED
|
@@ -12,11 +12,11 @@ from pathlib import Path
|
|
|
12
12
|
from typing import Any, Dict
|
|
13
13
|
from typing_extensions import Literal
|
|
14
14
|
|
|
15
|
-
NO_GET_ARGS: bool = False
|
|
15
|
+
NO_GET_ARGS: bool = False # pylint: disable=invalid-name
|
|
16
16
|
try:
|
|
17
17
|
from typing_extensions import get_args
|
|
18
18
|
except ImportError:
|
|
19
|
-
NO_GET_ARGS = True
|
|
19
|
+
NO_GET_ARGS = True # pylint: disable=invalid-name
|
|
20
20
|
|
|
21
21
|
|
|
22
22
|
import numpy as np
|
arviz/stats/density_utils.py
CHANGED
|
@@ -635,7 +635,7 @@ def _kde_circular(
|
|
|
635
635
|
cumulative: bool, optional
|
|
636
636
|
Whether return the PDF or the cumulative PDF. Defaults to False.
|
|
637
637
|
grid_len: int, optional
|
|
638
|
-
The number of intervals used to bin the data
|
|
638
|
+
The number of intervals used to bin the data point i.e. the length of the grid used in the
|
|
639
639
|
estimation. Defaults to 512.
|
|
640
640
|
"""
|
|
641
641
|
# All values between -pi and pi
|
arviz/stats/stats.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
# pylint: disable=too-many-lines
|
|
2
2
|
"""Statistical functions in ArviZ."""
|
|
3
3
|
|
|
4
|
-
import itertools
|
|
5
4
|
import warnings
|
|
6
5
|
from copy import deepcopy
|
|
7
6
|
from typing import List, Optional, Tuple, Union, Mapping, cast, Callable
|
|
@@ -11,14 +10,14 @@ import pandas as pd
|
|
|
11
10
|
import scipy.stats as st
|
|
12
11
|
from xarray_einstats import stats
|
|
13
12
|
import xarray as xr
|
|
14
|
-
from scipy.optimize import minimize
|
|
13
|
+
from scipy.optimize import minimize, LinearConstraint, Bounds
|
|
15
14
|
from typing_extensions import Literal
|
|
16
15
|
|
|
17
|
-
NO_GET_ARGS: bool = False
|
|
16
|
+
NO_GET_ARGS: bool = False # pylint: disable=invalid-name
|
|
18
17
|
try:
|
|
19
18
|
from typing_extensions import get_args
|
|
20
19
|
except ImportError:
|
|
21
|
-
NO_GET_ARGS = True
|
|
20
|
+
NO_GET_ARGS = True # pylint: disable=invalid-name
|
|
22
21
|
|
|
23
22
|
from .. import _log
|
|
24
23
|
from ..data import InferenceData, convert_to_dataset, convert_to_inference_data, extract
|
|
@@ -225,37 +224,23 @@ def compare(
|
|
|
225
224
|
if method.lower() == "stacking":
|
|
226
225
|
rows, cols, ic_i_val = _ic_matrix(ics, ic_i)
|
|
227
226
|
exp_ic_i = np.exp(ic_i_val / scale_value)
|
|
228
|
-
km1 = cols - 1
|
|
229
|
-
|
|
230
|
-
def w_fuller(weights):
|
|
231
|
-
return np.concatenate((weights, [max(1.0 - np.sum(weights), 0.0)]))
|
|
232
227
|
|
|
233
228
|
def log_score(weights):
|
|
234
|
-
|
|
235
|
-
score = 0.0
|
|
236
|
-
for i in range(rows):
|
|
237
|
-
score += np.log(np.dot(exp_ic_i[i], w_full))
|
|
238
|
-
return -score
|
|
229
|
+
return -np.sum(np.log(exp_ic_i @ weights))
|
|
239
230
|
|
|
240
231
|
def gradient(weights):
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
for k, i in itertools.product(range(km1), range(rows)):
|
|
244
|
-
grad[k] += (exp_ic_i[i, k] - exp_ic_i[i, km1]) / np.dot(exp_ic_i[i], w_full)
|
|
245
|
-
return -grad
|
|
246
|
-
|
|
247
|
-
theta = np.full(km1, 1.0 / cols)
|
|
248
|
-
bounds = [(0.0, 1.0) for _ in range(km1)]
|
|
249
|
-
constraints = [
|
|
250
|
-
{"type": "ineq", "fun": lambda x: -np.sum(x) + 1.0},
|
|
251
|
-
{"type": "ineq", "fun": np.sum},
|
|
252
|
-
]
|
|
232
|
+
denominator = exp_ic_i @ weights
|
|
233
|
+
return -np.sum(exp_ic_i / denominator[:, np.newaxis], axis=0)
|
|
253
234
|
|
|
254
|
-
|
|
235
|
+
theta = np.full(cols, 1.0 / cols)
|
|
236
|
+
bounds = Bounds(lb=np.zeros(cols), ub=np.ones(cols))
|
|
237
|
+
constraints = LinearConstraint(np.ones(cols), lb=1.0, ub=1.0)
|
|
238
|
+
|
|
239
|
+
minimize_result = minimize(
|
|
255
240
|
fun=log_score, x0=theta, jac=gradient, bounds=bounds, constraints=constraints
|
|
256
241
|
)
|
|
257
242
|
|
|
258
|
-
weights =
|
|
243
|
+
weights = minimize_result["x"]
|
|
259
244
|
ses = ics["se"]
|
|
260
245
|
|
|
261
246
|
elif method.lower() == "bb-pseudo-bma":
|
|
@@ -869,7 +854,7 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
|
|
|
869
854
|
)
|
|
870
855
|
|
|
871
856
|
|
|
872
|
-
def psislw(log_weights, reff=1.0):
|
|
857
|
+
def psislw(log_weights, reff=1.0, normalize=True):
|
|
873
858
|
"""
|
|
874
859
|
Pareto smoothed importance sampling (PSIS).
|
|
875
860
|
|
|
@@ -887,11 +872,13 @@ def psislw(log_weights, reff=1.0):
|
|
|
887
872
|
Array of size (n_observations, n_samples)
|
|
888
873
|
reff : float, default 1
|
|
889
874
|
relative MCMC efficiency, ``ess / n``
|
|
875
|
+
normalize : bool, default True
|
|
876
|
+
return normalized log weights
|
|
890
877
|
|
|
891
878
|
Returns
|
|
892
879
|
-------
|
|
893
880
|
lw_out : DataArray or (..., N) ndarray
|
|
894
|
-
Smoothed, truncated and normalized log weights.
|
|
881
|
+
Smoothed, truncated and possibly normalized log weights.
|
|
895
882
|
kss : DataArray or (...) ndarray
|
|
896
883
|
Estimates of the shape parameter *k* of the generalized Pareto
|
|
897
884
|
distribution.
|
|
@@ -936,7 +923,12 @@ def psislw(log_weights, reff=1.0):
|
|
|
936
923
|
out = np.empty_like(log_weights), np.empty(shape)
|
|
937
924
|
|
|
938
925
|
# define kwargs
|
|
939
|
-
func_kwargs = {
|
|
926
|
+
func_kwargs = {
|
|
927
|
+
"cutoff_ind": cutoff_ind,
|
|
928
|
+
"cutoffmin": cutoffmin,
|
|
929
|
+
"out": out,
|
|
930
|
+
"normalize": normalize,
|
|
931
|
+
}
|
|
940
932
|
ufunc_kwargs = {"n_dims": 1, "n_output": 2, "ravel": False, "check_shape": False}
|
|
941
933
|
kwargs = {"input_core_dims": [["__sample__"]], "output_core_dims": [["__sample__"], []]}
|
|
942
934
|
log_weights, pareto_shape = _wrap_xarray_ufunc(
|
|
@@ -953,7 +945,7 @@ def psislw(log_weights, reff=1.0):
|
|
|
953
945
|
return log_weights, pareto_shape
|
|
954
946
|
|
|
955
947
|
|
|
956
|
-
def _psislw(log_weights, cutoff_ind, cutoffmin):
|
|
948
|
+
def _psislw(log_weights, cutoff_ind, cutoffmin, normalize):
|
|
957
949
|
"""
|
|
958
950
|
Pareto smoothed importance sampling (PSIS) for a 1D vector.
|
|
959
951
|
|
|
@@ -963,7 +955,7 @@ def _psislw(log_weights, cutoff_ind, cutoffmin):
|
|
|
963
955
|
Array of length n_observations
|
|
964
956
|
cutoff_ind: int
|
|
965
957
|
cutoffmin: float
|
|
966
|
-
|
|
958
|
+
normalize: bool
|
|
967
959
|
|
|
968
960
|
Returns
|
|
969
961
|
-------
|
|
@@ -975,7 +967,8 @@ def _psislw(log_weights, cutoff_ind, cutoffmin):
|
|
|
975
967
|
x = np.asarray(log_weights)
|
|
976
968
|
|
|
977
969
|
# improve numerical accuracy
|
|
978
|
-
|
|
970
|
+
max_x = np.max(x)
|
|
971
|
+
x -= max_x
|
|
979
972
|
# sort the array
|
|
980
973
|
x_sort_ind = np.argsort(x)
|
|
981
974
|
# divide log weights into body and right tail
|
|
@@ -1007,8 +1000,12 @@ def _psislw(log_weights, cutoff_ind, cutoffmin):
|
|
|
1007
1000
|
x[tailinds[x_tail_si]] = smoothed_tail
|
|
1008
1001
|
# truncate smoothed values to the largest raw weight 0
|
|
1009
1002
|
x[x > 0] = 0
|
|
1003
|
+
|
|
1010
1004
|
# renormalize weights
|
|
1011
|
-
|
|
1005
|
+
if normalize:
|
|
1006
|
+
x -= _logsumexp(x)
|
|
1007
|
+
else:
|
|
1008
|
+
x += max_x
|
|
1012
1009
|
|
|
1013
1010
|
return x, k
|
|
1014
1011
|
|
|
@@ -1501,10 +1501,6 @@ class TestJSON:
|
|
|
1501
1501
|
assert not os.path.exists(filepath)
|
|
1502
1502
|
|
|
1503
1503
|
|
|
1504
|
-
@pytest.mark.skipif(
|
|
1505
|
-
not (importlib.util.find_spec("datatree") or "ARVIZ_REQUIRE_ALL_DEPS" in os.environ),
|
|
1506
|
-
reason="test requires xarray-datatree library",
|
|
1507
|
-
)
|
|
1508
1504
|
class TestDataTree:
|
|
1509
1505
|
def test_datatree(self):
|
|
1510
1506
|
idata = load_arviz_data("centered_eight")
|
|
@@ -1514,6 +1510,15 @@ class TestDataTree:
|
|
|
1514
1510
|
assert_identical(ds, idata_back[group])
|
|
1515
1511
|
assert all(group in dt.children for group in idata.groups())
|
|
1516
1512
|
|
|
1513
|
+
def test_datatree_attrs(self):
|
|
1514
|
+
idata = load_arviz_data("centered_eight")
|
|
1515
|
+
idata.attrs = {"not": "empty"}
|
|
1516
|
+
assert idata.attrs
|
|
1517
|
+
dt = idata.to_datatree()
|
|
1518
|
+
idata_back = from_datatree(dt)
|
|
1519
|
+
assert dt.attrs == idata.attrs
|
|
1520
|
+
assert idata_back.attrs == idata.attrs
|
|
1521
|
+
|
|
1517
1522
|
|
|
1518
1523
|
class TestConversions:
|
|
1519
1524
|
def test_id_conversion_idempotent(self):
|
|
@@ -1656,3 +1661,19 @@ class TestExtractDataset:
|
|
|
1656
1661
|
post = extract(idata, num_samples=10)
|
|
1657
1662
|
assert post.sizes["sample"] == 10
|
|
1658
1663
|
assert post.attrs == idata.posterior.attrs
|
|
1664
|
+
|
|
1665
|
+
|
|
1666
|
+
def test_convert_to_inference_data_with_array_like():
|
|
1667
|
+
class ArrayLike:
|
|
1668
|
+
def __init__(self, data):
|
|
1669
|
+
self._data = np.asarray(data)
|
|
1670
|
+
|
|
1671
|
+
def __array__(self):
|
|
1672
|
+
return self._data
|
|
1673
|
+
|
|
1674
|
+
array_like = ArrayLike(np.random.randn(4, 100))
|
|
1675
|
+
idata = convert_to_inference_data(array_like, group="posterior")
|
|
1676
|
+
|
|
1677
|
+
assert hasattr(idata, "posterior")
|
|
1678
|
+
assert "x" in idata.posterior.data_vars
|
|
1679
|
+
assert idata.posterior["x"].shape == (4, 100)
|