arviz 0.21.0__py3-none-any.whl → 0.23.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their respective public registries. It is provided for informational purposes only.
Files changed (37)
  1. arviz/__init__.py +49 -4
  2. arviz/data/converters.py +11 -0
  3. arviz/data/inference_data.py +46 -24
  4. arviz/data/io_datatree.py +2 -2
  5. arviz/data/io_numpyro.py +116 -5
  6. arviz/data/io_pyjags.py +1 -1
  7. arviz/plots/autocorrplot.py +12 -2
  8. arviz/plots/backends/bokeh/hdiplot.py +7 -6
  9. arviz/plots/backends/bokeh/lmplot.py +19 -3
  10. arviz/plots/backends/bokeh/pairplot.py +18 -48
  11. arviz/plots/backends/matplotlib/khatplot.py +8 -1
  12. arviz/plots/backends/matplotlib/lmplot.py +13 -7
  13. arviz/plots/backends/matplotlib/pairplot.py +14 -22
  14. arviz/plots/bpvplot.py +1 -1
  15. arviz/plots/dotplot.py +2 -0
  16. arviz/plots/forestplot.py +16 -4
  17. arviz/plots/kdeplot.py +4 -4
  18. arviz/plots/lmplot.py +41 -14
  19. arviz/plots/pairplot.py +10 -3
  20. arviz/plots/ppcplot.py +1 -1
  21. arviz/preview.py +31 -21
  22. arviz/rcparams.py +2 -2
  23. arviz/stats/density_utils.py +1 -1
  24. arviz/stats/stats.py +31 -34
  25. arviz/tests/base_tests/test_data.py +25 -4
  26. arviz/tests/base_tests/test_plots_bokeh.py +60 -2
  27. arviz/tests/base_tests/test_plots_matplotlib.py +94 -1
  28. arviz/tests/base_tests/test_stats.py +42 -1
  29. arviz/tests/base_tests/test_stats_ecdf_utils.py +2 -2
  30. arviz/tests/external_tests/test_data_numpyro.py +154 -4
  31. arviz/wrappers/base.py +1 -1
  32. arviz/wrappers/wrap_stan.py +1 -1
  33. {arviz-0.21.0.dist-info → arviz-0.23.0.dist-info}/METADATA +20 -9
  34. {arviz-0.21.0.dist-info → arviz-0.23.0.dist-info}/RECORD +37 -37
  35. {arviz-0.21.0.dist-info → arviz-0.23.0.dist-info}/WHEEL +1 -1
  36. {arviz-0.21.0.dist-info → arviz-0.23.0.dist-info/licenses}/LICENSE +0 -0
  37. {arviz-0.21.0.dist-info → arviz-0.23.0.dist-info}/top_level.txt +0 -0
arviz/plots/backends/matplotlib/khatplot.py CHANGED
@@ -7,6 +7,7 @@ from matplotlib import cm
 import matplotlib.pyplot as plt
 import numpy as np
 from matplotlib.colors import to_rgba_array
+from packaging import version
 
 from ....stats.density_utils import histogram
 from ...plot_utils import _scale_fig_size, color_from_dim, set_xticklabels, vectorized_to_hex
@@ -39,7 +40,13 @@ def plot_khat(
     show,
 ):
     """Matplotlib khat plot."""
-    if hover_label and mpl.get_backend() not in mpl.rcsetup.interactive_bk:
+    if version.parse(mpl.__version__) >= version.parse("3.9.0.dev0"):
+        interactive_backends = mpl.backends.backend_registry.list_builtin(
+            mpl.backends.BackendFilter.INTERACTIVE
+        )
+    else:
+        interactive_backends = mpl.rcsetup.interactive_bk
+    if hover_label and mpl.get_backend() not in interactive_backends:
        hover_label = False
        warnings.warn(
            "hover labels are only available with interactive backends. To switch to an "
arviz/plots/backends/matplotlib/lmplot.py CHANGED
@@ -115,12 +115,18 @@ def plot_lm(
 
         if y_model is not None:
             _, _, _, y_model_plotters = y_model[i]
+
             if kind_model == "lines":
-                for j in range(num_samples):
-                    ax_i.plot(x_plotters, y_model_plotters[..., j], **y_model_plot_kwargs)
-                ax_i.plot([], **y_model_plot_kwargs, label="Uncertainty in mean")
+                # y_model_plotters should be (points, samples)
+                y_points = y_model_plotters.shape[0]
+                if x_plotters.shape[0] == y_points:
+                    for j in range(num_samples):
+                        ax_i.plot(x_plotters, y_model_plotters[:, j], **y_model_plot_kwargs)
+
+                ax_i.plot([], **y_model_plot_kwargs, label="Uncertainty in mean")
+                y_model_mean = np.mean(y_model_plotters, axis=1)
+                ax_i.plot(x_plotters, y_model_mean, **y_model_mean_kwargs, label="Mean")
 
-                y_model_mean = np.mean(y_model_plotters, axis=1)
             else:
                 plot_hdi(
                     x_plotters,
@@ -128,10 +134,10 @@ def plot_lm(
                     fill_kwargs=y_model_fill_kwargs,
                     ax=ax_i,
                 )
-                ax_i.plot([], color=y_model_fill_kwargs["color"], label="Uncertainty in mean")
 
-                y_model_mean = np.mean(y_model_plotters, axis=(0, 1))
-                ax_i.plot(x_plotters, y_model_mean, **y_model_mean_kwargs, label="Mean")
+                ax_i.plot([], color=y_model_fill_kwargs["color"], label="Uncertainty in mean")
+                y_model_mean = np.mean(y_model_plotters, axis=0)
+                ax_i.plot(x_plotters, y_model_mean, **y_model_mean_kwargs, label="Mean")
 
         if legend:
             ax_i.legend(fontsize=xt_labelsize, loc="upper left")
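As a shape illustration (synthetic data, not ArviZ code): the rewritten "lines" branch expects y_model_plotters laid out as (points, samples), one column per posterior draw, and skips the sample lines when the x and y lengths disagree.

    import numpy as np
    import matplotlib.pyplot as plt

    x = np.linspace(0, 1, 50)                                  # 50 points
    draws = 2.0 * x[:, None] + 0.1 * np.random.randn(50, 100)  # (points, samples)
    for j in range(draws.shape[1]):
        plt.plot(x, draws[:, j], color="C0", alpha=0.05)       # one line per draw
    plt.plot(x, draws.mean(axis=1), color="C1", label="Mean")  # axis=1: average over draws
    plt.legend()
    plt.show()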
arviz/plots/backends/matplotlib/pairplot.py CHANGED
@@ -30,6 +30,8 @@ def plot_pair(
     diverging_mask,
     divergences_kwargs,
     flat_var_names,
+    flat_ref_slices,
+    flat_var_labels,
     backend_kwargs,
     marginal_kwargs,
     show,
@@ -77,24 +79,12 @@ def plot_pair(
         kde_kwargs["contour_kwargs"].setdefault("colors", "k")
 
     if reference_values:
-        reference_values_copy = {}
-        label = []
-        for variable in list(reference_values.keys()):
-            if " " in variable:
-                variable_copy = variable.replace(" ", "\n", 1)
-            else:
-                variable_copy = variable
-
-            label.append(variable_copy)
-            reference_values_copy[variable_copy] = reference_values[variable]
-
-        difference = set(flat_var_names).difference(set(label))
+        difference = set(flat_var_names).difference(set(reference_values.keys()))
 
         if difference:
-            warn = [diff.replace("\n", " ", 1) for diff in difference]
             warnings.warn(
                 "Argument reference_values does not include reference value for: {}".format(
-                    ", ".join(warn)
+                    ", ".join(difference)
                 ),
                 UserWarning,
             )
@@ -211,12 +201,12 @@ def plot_pair(
 
             if reference_values:
                 ax.plot(
-                    reference_values_copy[flat_var_names[0]],
-                    reference_values_copy[flat_var_names[1]],
+                    np.array(reference_values[flat_var_names[0]])[flat_ref_slices[0]],
+                    np.array(reference_values[flat_var_names[1]])[flat_ref_slices[1]],
                     **reference_values_kwargs,
                 )
-            ax.set_xlabel(f"{flat_var_names[0]}", fontsize=ax_labelsize, wrap=True)
-            ax.set_ylabel(f"{flat_var_names[1]}", fontsize=ax_labelsize, wrap=True)
+            ax.set_xlabel(f"{flat_var_labels[0]}", fontsize=ax_labelsize, wrap=True)
+            ax.set_ylabel(f"{flat_var_labels[1]}", fontsize=ax_labelsize, wrap=True)
             ax.tick_params(labelsize=xt_labelsize)
 
         else:
@@ -336,20 +326,22 @@ def plot_pair(
                     y_name = flat_var_names[j + not_marginals]
                     if (x_name not in difference) and (y_name not in difference):
                         ax[j, i].plot(
-                            reference_values_copy[x_name],
-                            reference_values_copy[y_name],
+                            np.array(reference_values[x_name])[flat_ref_slices[i]],
+                            np.array(reference_values[y_name])[
+                                flat_ref_slices[j + not_marginals]
+                            ],
                             **reference_values_kwargs,
                         )
 
                 if j != vars_to_plot - 1:
                     plt.setp(ax[j, i].get_xticklabels(), visible=False)
                 else:
-                    ax[j, i].set_xlabel(f"{flat_var_names[i]}", fontsize=ax_labelsize, wrap=True)
+                    ax[j, i].set_xlabel(f"{flat_var_labels[i]}", fontsize=ax_labelsize, wrap=True)
                 if i != 0:
                     plt.setp(ax[j, i].get_yticklabels(), visible=False)
                 else:
                     ax[j, i].set_ylabel(
-                        f"{flat_var_names[j + not_marginals]}",
+                        f"{flat_var_labels[j + not_marginals]}",
                         fontsize=ax_labelsize,
                         wrap=True,
                     )
arviz/plots/bpvplot.py CHANGED
@@ -251,7 +251,7 @@ def plot_bpv(
     total_pp_samples = predictive_dataset.sizes["chain"] * predictive_dataset.sizes["draw"]
 
     for key in coords.keys():
-        coords[key] = np.where(np.in1d(observed[key], coords[key]))[0]
+        coords[key] = np.where(np.isin(observed[key], coords[key]))[0]
 
     obs_plotters = filter_plotters_list(
         list(
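The swap from np.in1d to np.isin here (and in ppcplot.py below) tracks NumPy's deprecation of np.in1d; np.isin is the drop-in replacement. A quick illustration with synthetic coordinates:

    import numpy as np

    observed = np.array(["a", "b", "c", "d"])
    selected = ["b", "d"]
    print(np.where(np.isin(observed, selected))[0])  # [1 3]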
arviz/plots/dotplot.py CHANGED
@@ -2,6 +2,7 @@
 
 import numpy as np
 
+
 from ..rcparams import rcParams
 from .plot_utils import get_plotting_function
 
@@ -148,6 +149,7 @@ def plot_dot(
         raise ValueError("marker argument is valid only for matplotlib backend")
 
     values = np.ravel(values)
+    values = values[np.isfinite(values)]
     values.sort()
 
     if hdi_prob is None:
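The new np.isfinite filter keeps NaN and infinite values from corrupting the sorted values the dot plot bins. With synthetic data:

    import numpy as np

    values = np.array([0.3, np.nan, 1.2, np.inf, -0.5])
    values = values[np.isfinite(values)]  # drop NaN/inf before sorting
    values.sort()
    print(values)  # [-0.5  0.3  1.2]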
arviz/plots/forestplot.py CHANGED
@@ -51,7 +51,7 @@ def plot_forest(
     data : InferenceData
         Any object that can be converted to an :class:`arviz.InferenceData` object
         Refer to documentation of :func:`arviz.convert_to_dataset` for details.
-    kind : {"foresplot", "ridgeplot"}, default "forestplot"
+    kind : {"forestplot", "ridgeplot"}, default "forestplot"
         Specify the kind of plot:
 
         * The ``kind="forestplot"`` generates credible intervals, where the central points are the
@@ -75,8 +75,8 @@ def plot_forest(
         interpret `var_names` as substrings of the real variables names. If "regex",
         interpret `var_names` as regular expressions on the real variables names. See
        :ref:`this section <common_filter_vars>` for usage examples.
-    transform : callable, optional
-        Function to transform data (defaults to None i.e.the identity function).
+    transform : callable or dict, optional
+        Function to transform the data. Defaults to None, i.e., the identity function.
     coords : dict, optional
         Coordinates of ``var_names`` to be plotted. Passed to :meth:`xarray.Dataset.sel`.
         See :ref:`this section <common_coords>` for usage examples.
@@ -228,7 +228,19 @@ def plot_forest(
 
     datasets = [convert_to_dataset(datum) for datum in reversed(data)]
     if transform is not None:
-        datasets = [transform(dataset) for dataset in datasets]
+        if callable(transform):
+            datasets = [transform(dataset) for dataset in datasets]
+        elif isinstance(transform, dict):
+            transformed_datasets = []
+            for dataset in datasets:
+                new_dataset = dataset.copy()
+                for var_name, func in transform.items():
+                    if var_name in new_dataset:
+                        new_dataset[var_name] = func(new_dataset[var_name])
+                transformed_datasets.append(new_dataset)
+            datasets = transformed_datasets
+        else:
+            raise ValueError("transform must be either a callable or a dict {var_name: callable}")
     datasets = get_coords(
         datasets, list(reversed(coords)) if isinstance(coords, (list, tuple)) else coords
     )
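A hedged usage sketch of the new dict form of transform, with synthetic data: each key names a variable and its callable is applied to that variable only, while a bare callable keeps the old behavior of transforming whole datasets.

    import numpy as np
    import arviz as az

    idata = az.from_dict(posterior={
        "mu": np.random.randn(4, 500),
        "sigma": np.abs(np.random.randn(4, 500)),
    })
    # log-transform only sigma; mu is left untouched
    az.plot_forest(idata, transform={"sigma": np.log})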
arviz/plots/kdeplot.py CHANGED
@@ -255,6 +255,10 @@ def plot_kde(
             "or plot_pair instead of plot_kde"
         )
 
+    if backend is None:
+        backend = rcParams["plot.backend"]
+    backend = backend.lower()
+
     if values2 is None:
         if bw == "default":
             bw = "taylor" if is_circular else "experimental"
@@ -346,10 +350,6 @@ def plot_kde(
         **kwargs,
     )
 
-    if backend is None:
-        backend = rcParams["plot.backend"]
-    backend = backend.lower()
-
     # TODO: Add backend kwargs
     plot = get_plotting_function("plot_kde", "kdeplot", backend)
     ax = plot(**kde_plot_args)
arviz/plots/lmplot.py CHANGED
@@ -300,20 +300,47 @@ def plot_lm(
     # Filter out the required values to generate plotters
     if y_model is not None:
         if kind_model == "lines":
-            y_model = y_model.stack(__sample__=("chain", "draw"))[..., pp_sample_ix]
-
-            y_model = [
-                tup
-                for _, tup in zip(
-                    range(len_y),
-                    xarray_var_iter(
-                        y_model,
-                        skip_dims=set(y_model.dims),
-                        combined=True,
-                    ),
-                )
-            ]
-            y_model = _repeat_flatten_list(y_model, len_x)
+            var_name = y_model.name if y_model.name else "y_model"
+            data = y_model.values
+
+            total_samples = data.shape[0] * data.shape[1]
+            data = data.reshape(total_samples, *data.shape[2:])
+
+            if pp_sample_ix is not None:
+                data = data[pp_sample_ix]
+
+            if plot_dim is not None:
+                # For plot_dim case, transpose to get dimension first
+                data = data.transpose(1, 0, 2)[..., 0]
+
+            # Create plotter tuple(s)
+            if plot_dim is not None:
+                y_model = [(var_name, {}, {}, data) for _ in range(length_plotters)]
+            else:
+                y_model = [(var_name, {}, {}, data)]
+                y_model = _repeat_flatten_list(y_model, len_x)
+
+        elif kind_model == "hdi":
+            var_name = y_model.name if y_model.name else "y_model"
+            data = y_model.values
+
+            if plot_dim is not None:
+                # First transpose to get plot_dim first
+                data = data.transpose(2, 0, 1, 3)
+                # For plot_dim case, we just want HDI for first dimension
+                data = data[..., 0]
+
+                # Reshape to (samples, points)
+                data = data.transpose(1, 2, 0).reshape(-1, data.shape[0])
+                y_model = [(var_name, {}, {}, data) for _ in range(length_plotters)]
+
+            else:
+                data = data.reshape(-1, data.shape[-1])
+                y_model = [(var_name, {}, {}, data)]
+                y_model = _repeat_flatten_list(y_model, len_x)
+
+            if len(y_model) == 1:
+                y_model = _repeat_flatten_list(y_model, len_x)
 
     rows, cols = default_grid(length_plotters)
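A shape walkthrough of the rewritten "lines" branch, with a synthetic array standing in for y_model.values: (chain, draw, points) is flattened to (samples, points), then a random subset of draws is kept.

    import numpy as np

    data = np.random.randn(4, 500, 50)                   # (chain, draw, points)
    total_samples = data.shape[0] * data.shape[1]
    data = data.reshape(total_samples, *data.shape[2:])  # (2000, 50)
    pp_sample_ix = np.random.choice(total_samples, size=100, replace=False)
    data = data[pp_sample_ix]                            # (100, 50): 100 plotted draws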
arviz/plots/pairplot.py CHANGED
@@ -196,9 +196,14 @@ def plot_pair(
             get_coords(dataset, coords), var_names=var_names, skip_dims=combine_dims, combined=True
         )
     )
-    flat_var_names = [
-        labeller.make_label_vert(var_name, sel, isel) for var_name, sel, isel, _ in plotters
-    ]
+    flat_var_names = []
+    flat_ref_slices = []
+    flat_var_labels = []
+    for var_name, sel, isel, _ in plotters:
+        dims = [dim for dim in dataset[var_name].dims if dim not in ["chain", "draw"]]
+        flat_var_names.append(var_name)
+        flat_ref_slices.append(tuple(isel[dim] if dim in isel else slice(None) for dim in dims))
+        flat_var_labels.append(labeller.make_label_vert(var_name, sel, isel))
 
     divergent_data = None
     diverging_mask = None
@@ -253,6 +258,8 @@ def plot_pair(
         diverging_mask=diverging_mask,
         divergences_kwargs=divergences_kwargs,
         flat_var_names=flat_var_names,
+        flat_ref_slices=flat_ref_slices,
+        flat_var_labels=flat_var_labels,
         backend_kwargs=backend_kwargs,
         marginal_kwargs=marginal_kwargs,
         show=show,
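A hedged sketch of what the flat_ref_slices plumbing appears to enable, with synthetic data (this is inferred from the diff, not taken from the ArviZ docs): reference_values keys are now plain variable names, and a multidimensional variable can pass an array that is sliced per panel, rather than needing one key per flattened label.

    import numpy as np
    import arviz as az

    idata = az.from_dict(posterior={
        "mu": np.random.randn(4, 500),
        "theta": np.random.randn(4, 500, 2),
    })
    az.plot_pair(idata, reference_values={"mu": 0.0, "theta": [0.0, 0.0]})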
arviz/plots/ppcplot.py CHANGED
@@ -304,7 +304,7 @@ def plot_ppc(
     pp_sample_ix = np.random.choice(total_pp_samples, size=num_pp_samples, replace=False)
 
     for key in coords.keys():
-        coords[key] = np.where(np.in1d(observed_data[key], coords[key]))[0]
+        coords[key] = np.where(np.isin(observed_data[key], coords[key]))[0]
 
     obs_plotters = filter_plotters_list(
         list(
arviz/preview.py CHANGED
@@ -8,41 +8,51 @@ info = ""
 
 try:
     from arviz_base import *
+    import arviz_base as base
 
-    status = "arviz_base available, exposing its functions as part of arviz.preview"
-    _log.info(status)
+    _status = "arviz_base available, exposing its functions as part of arviz.preview"
+    _log.info(_status)
 except ModuleNotFoundError:
-    status = "arviz_base not installed"
-    _log.info(status)
+    _status = "arviz_base not installed"
+    _log.info(_status)
 except ImportError:
-    status = "Unable to import arviz_base"
-    _log.info(status, exc_info=True)
+    _status = "Unable to import arviz_base"
+    _log.info(_status, exc_info=True)
 
-info += status + "\n"
+info += _status + "\n"
 
 try:
     from arviz_stats import *
 
-    status = "arviz_stats available, exposing its functions as part of arviz.preview"
-    _log.info(status)
+    # the base computational module fron arviz_stats will override the alias to arviz-base
+    # arviz.stats.base will still be available
+    import arviz_base as base
+    import arviz_stats as stats
+
+    _status = "arviz_stats available, exposing its functions as part of arviz.preview"
+    _log.info(_status)
 except ModuleNotFoundError:
-    status = "arviz_stats not installed"
-    _log.info(status)
+    _status = "arviz_stats not installed"
+    _log.info(_status)
 except ImportError:
-    status = "Unable to import arviz_stats"
-    _log.info(status, exc_info=True)
-info += status + "\n"
+    _status = "Unable to import arviz_stats"
+    _log.info(_status, exc_info=True)
+info += _status + "\n"
 
 try:
     from arviz_plots import *
+    import arviz_plots as plots
 
-    status = "arviz_plots available, exposing its functions as part of arviz.preview"
-    _log.info(status)
+    _status = "arviz_plots available, exposing its functions as part of arviz.preview"
+    _log.info(_status)
 except ModuleNotFoundError:
-    status = "arviz_plots not installed"
-    _log.info(status)
+    _status = "arviz_plots not installed"
+    _log.info(_status)
 except ImportError:
-    status = "Unable to import arviz_plots"
-    _log.info(status, exc_info=True)
+    _status = "Unable to import arviz_plots"
+    _log.info(_status, exc_info=True)
+
+info += _status + "\n"
 
-info += status + "\n"
+# clean namespace
+del logging, _status, _log
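For orientation, a hedged sketch of what this refactor exposes: arviz.preview.info reports which of the three optional packages imported, and the new module aliases point at arviz_base, arviz_stats, and arviz_plots when they are installed.

    from arviz import preview

    print(preview.info)  # one status line per optional package
    # When the optional packages are installed, the aliases exist:
    # preview.base  -> arviz_base
    # preview.stats -> arviz_stats
    # preview.plots -> arviz_plots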
arviz/rcparams.py CHANGED
@@ -12,11 +12,11 @@ from pathlib import Path
 from typing import Any, Dict
 from typing_extensions import Literal
 
-NO_GET_ARGS: bool = False
+NO_GET_ARGS: bool = False  # pylint: disable=invalid-name
 try:
     from typing_extensions import get_args
 except ImportError:
-    NO_GET_ARGS = True
+    NO_GET_ARGS = True  # pylint: disable=invalid-name
 
 
 import numpy as np
arviz/stats/density_utils.py CHANGED
@@ -635,7 +635,7 @@ def _kde_circular(
     cumulative: bool, optional
         Whether return the PDF or the cumulative PDF. Defaults to False.
     grid_len: int, optional
-        The number of intervals used to bin the data pointa i.e. the length of the grid used in the
+        The number of intervals used to bin the data point i.e. the length of the grid used in the
         estimation. Defaults to 512.
     """
     # All values between -pi and pi
arviz/stats/stats.py CHANGED
@@ -1,7 +1,6 @@
 # pylint: disable=too-many-lines
 """Statistical functions in ArviZ."""
 
-import itertools
 import warnings
 from copy import deepcopy
 from typing import List, Optional, Tuple, Union, Mapping, cast, Callable
@@ -11,14 +10,14 @@ import pandas as pd
 import scipy.stats as st
 from xarray_einstats import stats
 import xarray as xr
-from scipy.optimize import minimize
+from scipy.optimize import minimize, LinearConstraint, Bounds
 from typing_extensions import Literal
 
-NO_GET_ARGS: bool = False
+NO_GET_ARGS: bool = False  # pylint: disable=invalid-name
 try:
     from typing_extensions import get_args
 except ImportError:
-    NO_GET_ARGS = True
+    NO_GET_ARGS = True  # pylint: disable=invalid-name
 
 from .. import _log
 from ..data import InferenceData, convert_to_dataset, convert_to_inference_data, extract
@@ -225,37 +224,23 @@ def compare(
     if method.lower() == "stacking":
         rows, cols, ic_i_val = _ic_matrix(ics, ic_i)
         exp_ic_i = np.exp(ic_i_val / scale_value)
-        km1 = cols - 1
-
-        def w_fuller(weights):
-            return np.concatenate((weights, [max(1.0 - np.sum(weights), 0.0)]))
 
         def log_score(weights):
-            w_full = w_fuller(weights)
-            score = 0.0
-            for i in range(rows):
-                score += np.log(np.dot(exp_ic_i[i], w_full))
-            return -score
+            return -np.sum(np.log(exp_ic_i @ weights))
 
         def gradient(weights):
-            w_full = w_fuller(weights)
-            grad = np.zeros(km1)
-            for k, i in itertools.product(range(km1), range(rows)):
-                grad[k] += (exp_ic_i[i, k] - exp_ic_i[i, km1]) / np.dot(exp_ic_i[i], w_full)
-            return -grad
-
-        theta = np.full(km1, 1.0 / cols)
-        bounds = [(0.0, 1.0) for _ in range(km1)]
-        constraints = [
-            {"type": "ineq", "fun": lambda x: -np.sum(x) + 1.0},
-            {"type": "ineq", "fun": np.sum},
-        ]
+            denominator = exp_ic_i @ weights
+            return -np.sum(exp_ic_i / denominator[:, np.newaxis], axis=0)
 
-        weights = minimize(
+        theta = np.full(cols, 1.0 / cols)
+        bounds = Bounds(lb=np.zeros(cols), ub=np.ones(cols))
+        constraints = LinearConstraint(np.ones(cols), lb=1.0, ub=1.0)
+
+        minimize_result = minimize(
             fun=log_score, x0=theta, jac=gradient, bounds=bounds, constraints=constraints
         )
 
-        weights = w_fuller(weights["x"])
+        weights = minimize_result["x"]
         ses = ics["se"]
 
     elif method.lower() == "bb-pseudo-bma":
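This refactor replaces the K-1 reparameterization (w_fuller) with a direct, vectorized optimization over all K weights on the simplex, using scipy's Bounds and LinearConstraint. A self-contained sketch with a synthetic score matrix standing in for exp(pointwise elpd / scale):

    import numpy as np
    from scipy.optimize import Bounds, LinearConstraint, minimize

    exp_ic_i = np.random.rand(100, 3)  # (observations, models), synthetic
    cols = exp_ic_i.shape[1]

    def log_score(weights):
        # negative stacked log score, minimized over the weight simplex
        return -np.sum(np.log(exp_ic_i @ weights))

    def gradient(weights):
        denominator = exp_ic_i @ weights
        return -np.sum(exp_ic_i / denominator[:, np.newaxis], axis=0)

    result = minimize(
        fun=log_score,
        x0=np.full(cols, 1.0 / cols),
        jac=gradient,
        bounds=Bounds(lb=np.zeros(cols), ub=np.ones(cols)),
        constraints=LinearConstraint(np.ones(cols), lb=1.0, ub=1.0),
    )
    weights = result["x"]  # nonnegative, sums to 1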
@@ -869,7 +854,7 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
     )
 
 
-def psislw(log_weights, reff=1.0):
+def psislw(log_weights, reff=1.0, normalize=True):
     """
     Pareto smoothed importance sampling (PSIS).
 
@@ -887,11 +872,13 @@
         Array of size (n_observations, n_samples)
     reff : float, default 1
         relative MCMC efficiency, ``ess / n``
+    normalize : bool, default True
+        return normalized log weights
 
     Returns
     -------
     lw_out : DataArray or (..., N) ndarray
-        Smoothed, truncated and normalized log weights.
+        Smoothed, truncated and possibly normalized log weights.
     kss : DataArray or (...) ndarray
         Estimates of the shape parameter *k* of the generalized Pareto
         distribution.
@@ -936,7 +923,12 @@
     out = np.empty_like(log_weights), np.empty(shape)
 
     # define kwargs
-    func_kwargs = {"cutoff_ind": cutoff_ind, "cutoffmin": cutoffmin, "out": out}
+    func_kwargs = {
+        "cutoff_ind": cutoff_ind,
+        "cutoffmin": cutoffmin,
+        "out": out,
+        "normalize": normalize,
+    }
     ufunc_kwargs = {"n_dims": 1, "n_output": 2, "ravel": False, "check_shape": False}
     kwargs = {"input_core_dims": [["__sample__"]], "output_core_dims": [["__sample__"], []]}
     log_weights, pareto_shape = _wrap_xarray_ufunc(
@@ -953,7 +945,7 @@
     return log_weights, pareto_shape
 
 
-def _psislw(log_weights, cutoff_ind, cutoffmin):
+def _psislw(log_weights, cutoff_ind, cutoffmin, normalize):
     """
     Pareto smoothed importance sampling (PSIS) for a 1D vector.
 
@@ -963,7 +955,7 @@
         Array of length n_observations
     cutoff_ind: int
     cutoffmin: float
-    k_min: float
+    normalize: bool
 
     Returns
     -------
@@ -975,7 +967,8 @@
     x = np.asarray(log_weights)
 
     # improve numerical accuracy
-    x -= np.max(x)
+    max_x = np.max(x)
+    x -= max_x
     # sort the array
     x_sort_ind = np.argsort(x)
     # divide log weights into body and right tail
@@ -1007,8 +1000,12 @@
     x[tailinds[x_tail_si]] = smoothed_tail
     # truncate smoothed values to the largest raw weight 0
     x[x > 0] = 0
+
     # renormalize weights
-    x -= _logsumexp(x)
+    if normalize:
+        x -= _logsumexp(x)
+    else:
+        x += max_x
 
     return x, k
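A hedged usage sketch of the new flag, with synthetic log weights: with normalize=True (the default) the smoothed weights sum to one per observation; with normalize=False the maximum subtracted for numerical stability is added back, preserving the original scale.

    import numpy as np
    import arviz as az

    log_weights = np.random.randn(1, 2000)  # (n_observations, n_samples)
    lw_norm, khat = az.psislw(log_weights.copy())
    print(np.allclose(np.exp(lw_norm).sum(axis=-1), 1.0))  # True
    lw_raw, _ = az.psislw(log_weights.copy(), normalize=False)  # unnormalized scale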
arviz/tests/base_tests/test_data.py CHANGED
@@ -1501,10 +1501,6 @@ class TestJSON:
         assert not os.path.exists(filepath)
 
 
-@pytest.mark.skipif(
-    not (importlib.util.find_spec("datatree") or "ARVIZ_REQUIRE_ALL_DEPS" in os.environ),
-    reason="test requires xarray-datatree library",
-)
 class TestDataTree:
     def test_datatree(self):
         idata = load_arviz_data("centered_eight")
@@ -1514,6 +1510,15 @@ class TestDataTree:
             assert_identical(ds, idata_back[group])
         assert all(group in dt.children for group in idata.groups())
 
+    def test_datatree_attrs(self):
+        idata = load_arviz_data("centered_eight")
+        idata.attrs = {"not": "empty"}
+        assert idata.attrs
+        dt = idata.to_datatree()
+        idata_back = from_datatree(dt)
+        assert dt.attrs == idata.attrs
+        assert idata_back.attrs == idata.attrs
+
 
 class TestConversions:
     def test_id_conversion_idempotent(self):
@@ -1656,3 +1661,19 @@ class TestExtractDataset:
         post = extract(idata, num_samples=10)
         assert post.sizes["sample"] == 10
         assert post.attrs == idata.posterior.attrs
+
+
+def test_convert_to_inference_data_with_array_like():
+    class ArrayLike:
+        def __init__(self, data):
+            self._data = np.asarray(data)
+
+        def __array__(self):
+            return self._data
+
+    array_like = ArrayLike(np.random.randn(4, 100))
+    idata = convert_to_inference_data(array_like, group="posterior")
+
+    assert hasattr(idata, "posterior")
+    assert "x" in idata.posterior.data_vars
+    assert idata.posterior["x"].shape == (4, 100)