arviz 0.21.0__py3-none-any.whl → 0.22.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -30,6 +30,8 @@ def plot_pair(
30
30
  diverging_mask,
31
31
  divergences_kwargs,
32
32
  flat_var_names,
33
+ flat_ref_slices,
34
+ flat_var_labels,
33
35
  backend_kwargs,
34
36
  marginal_kwargs,
35
37
  show,
@@ -77,24 +79,12 @@ def plot_pair(
77
79
  kde_kwargs["contour_kwargs"].setdefault("colors", "k")
78
80
 
79
81
  if reference_values:
80
- reference_values_copy = {}
81
- label = []
82
- for variable in list(reference_values.keys()):
83
- if " " in variable:
84
- variable_copy = variable.replace(" ", "\n", 1)
85
- else:
86
- variable_copy = variable
87
-
88
- label.append(variable_copy)
89
- reference_values_copy[variable_copy] = reference_values[variable]
90
-
91
- difference = set(flat_var_names).difference(set(label))
82
+ difference = set(flat_var_names).difference(set(reference_values.keys()))
92
83
 
93
84
  if difference:
94
- warn = [diff.replace("\n", " ", 1) for diff in difference]
95
85
  warnings.warn(
96
86
  "Argument reference_values does not include reference value for: {}".format(
97
- ", ".join(warn)
87
+ ", ".join(difference)
98
88
  ),
99
89
  UserWarning,
100
90
  )
@@ -211,12 +201,12 @@ def plot_pair(
211
201
 
212
202
  if reference_values:
213
203
  ax.plot(
214
- reference_values_copy[flat_var_names[0]],
215
- reference_values_copy[flat_var_names[1]],
204
+ np.array(reference_values[flat_var_names[0]])[flat_ref_slices[0]],
205
+ np.array(reference_values[flat_var_names[1]])[flat_ref_slices[1]],
216
206
  **reference_values_kwargs,
217
207
  )
218
- ax.set_xlabel(f"{flat_var_names[0]}", fontsize=ax_labelsize, wrap=True)
219
- ax.set_ylabel(f"{flat_var_names[1]}", fontsize=ax_labelsize, wrap=True)
208
+ ax.set_xlabel(f"{flat_var_labels[0]}", fontsize=ax_labelsize, wrap=True)
209
+ ax.set_ylabel(f"{flat_var_labels[1]}", fontsize=ax_labelsize, wrap=True)
220
210
  ax.tick_params(labelsize=xt_labelsize)
221
211
 
222
212
  else:
@@ -336,20 +326,22 @@ def plot_pair(
336
326
  y_name = flat_var_names[j + not_marginals]
337
327
  if (x_name not in difference) and (y_name not in difference):
338
328
  ax[j, i].plot(
339
- reference_values_copy[x_name],
340
- reference_values_copy[y_name],
329
+ np.array(reference_values[x_name])[flat_ref_slices[i]],
330
+ np.array(reference_values[y_name])[
331
+ flat_ref_slices[j + not_marginals]
332
+ ],
341
333
  **reference_values_kwargs,
342
334
  )
343
335
 
344
336
  if j != vars_to_plot - 1:
345
337
  plt.setp(ax[j, i].get_xticklabels(), visible=False)
346
338
  else:
347
- ax[j, i].set_xlabel(f"{flat_var_names[i]}", fontsize=ax_labelsize, wrap=True)
339
+ ax[j, i].set_xlabel(f"{flat_var_labels[i]}", fontsize=ax_labelsize, wrap=True)
348
340
  if i != 0:
349
341
  plt.setp(ax[j, i].get_yticklabels(), visible=False)
350
342
  else:
351
343
  ax[j, i].set_ylabel(
352
- f"{flat_var_names[j + not_marginals]}",
344
+ f"{flat_var_labels[j + not_marginals]}",
353
345
  fontsize=ax_labelsize,
354
346
  wrap=True,
355
347
  )
arviz/plots/kdeplot.py CHANGED
@@ -255,6 +255,10 @@ def plot_kde(
255
255
  "or plot_pair instead of plot_kde"
256
256
  )
257
257
 
258
+ if backend is None:
259
+ backend = rcParams["plot.backend"]
260
+ backend = backend.lower()
261
+
258
262
  if values2 is None:
259
263
  if bw == "default":
260
264
  bw = "taylor" if is_circular else "experimental"
@@ -346,10 +350,6 @@ def plot_kde(
346
350
  **kwargs,
347
351
  )
348
352
 
349
- if backend is None:
350
- backend = rcParams["plot.backend"]
351
- backend = backend.lower()
352
-
353
353
  # TODO: Add backend kwargs
354
354
  plot = get_plotting_function("plot_kde", "kdeplot", backend)
355
355
  ax = plot(**kde_plot_args)
arviz/plots/lmplot.py CHANGED
@@ -300,20 +300,47 @@ def plot_lm(
300
300
  # Filter out the required values to generate plotters
301
301
  if y_model is not None:
302
302
  if kind_model == "lines":
303
- y_model = y_model.stack(__sample__=("chain", "draw"))[..., pp_sample_ix]
304
-
305
- y_model = [
306
- tup
307
- for _, tup in zip(
308
- range(len_y),
309
- xarray_var_iter(
310
- y_model,
311
- skip_dims=set(y_model.dims),
312
- combined=True,
313
- ),
314
- )
315
- ]
316
- y_model = _repeat_flatten_list(y_model, len_x)
303
+ var_name = y_model.name if y_model.name else "y_model"
304
+ data = y_model.values
305
+
306
+ total_samples = data.shape[0] * data.shape[1]
307
+ data = data.reshape(total_samples, *data.shape[2:])
308
+
309
+ if pp_sample_ix is not None:
310
+ data = data[pp_sample_ix]
311
+
312
+ if plot_dim is not None:
313
+ # For plot_dim case, transpose to get dimension first
314
+ data = data.transpose(1, 0, 2)[..., 0]
315
+
316
+ # Create plotter tuple(s)
317
+ if plot_dim is not None:
318
+ y_model = [(var_name, {}, {}, data) for _ in range(length_plotters)]
319
+ else:
320
+ y_model = [(var_name, {}, {}, data)]
321
+ y_model = _repeat_flatten_list(y_model, len_x)
322
+
323
+ elif kind_model == "hdi":
324
+ var_name = y_model.name if y_model.name else "y_model"
325
+ data = y_model.values
326
+
327
+ if plot_dim is not None:
328
+ # First transpose to get plot_dim first
329
+ data = data.transpose(2, 0, 1, 3)
330
+ # For plot_dim case, we just want HDI for first dimension
331
+ data = data[..., 0]
332
+
333
+ # Reshape to (samples, points)
334
+ data = data.transpose(1, 2, 0).reshape(-1, data.shape[0])
335
+ y_model = [(var_name, {}, {}, data) for _ in range(length_plotters)]
336
+
337
+ else:
338
+ data = data.reshape(-1, data.shape[-1])
339
+ y_model = [(var_name, {}, {}, data)]
340
+ y_model = _repeat_flatten_list(y_model, len_x)
341
+
342
+ if len(y_model) == 1:
343
+ y_model = _repeat_flatten_list(y_model, len_x)
317
344
 
318
345
  rows, cols = default_grid(length_plotters)
319
346
 
arviz/plots/pairplot.py CHANGED
@@ -196,9 +196,14 @@ def plot_pair(
196
196
  get_coords(dataset, coords), var_names=var_names, skip_dims=combine_dims, combined=True
197
197
  )
198
198
  )
199
- flat_var_names = [
200
- labeller.make_label_vert(var_name, sel, isel) for var_name, sel, isel, _ in plotters
201
- ]
199
+ flat_var_names = []
200
+ flat_ref_slices = []
201
+ flat_var_labels = []
202
+ for var_name, sel, isel, _ in plotters:
203
+ dims = [dim for dim in dataset[var_name].dims if dim not in ["chain", "draw"]]
204
+ flat_var_names.append(var_name)
205
+ flat_ref_slices.append(tuple(isel[dim] if dim in isel else slice(None) for dim in dims))
206
+ flat_var_labels.append(labeller.make_label_vert(var_name, sel, isel))
202
207
 
203
208
  divergent_data = None
204
209
  diverging_mask = None
@@ -253,6 +258,8 @@ def plot_pair(
253
258
  diverging_mask=diverging_mask,
254
259
  divergences_kwargs=divergences_kwargs,
255
260
  flat_var_names=flat_var_names,
261
+ flat_ref_slices=flat_ref_slices,
262
+ flat_var_labels=flat_var_labels,
256
263
  backend_kwargs=backend_kwargs,
257
264
  marginal_kwargs=marginal_kwargs,
258
265
  show=show,
@@ -635,7 +635,7 @@ def _kde_circular(
635
635
  cumulative: bool, optional
636
636
  Whether return the PDF or the cumulative PDF. Defaults to False.
637
637
  grid_len: int, optional
638
- The number of intervals used to bin the data pointa i.e. the length of the grid used in the
638
+ The number of intervals used to bin the data point i.e. the length of the grid used in the
639
639
  estimation. Defaults to 512.
640
640
  """
641
641
  # All values between -pi and pi
arviz/stats/stats.py CHANGED
@@ -869,7 +869,7 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
869
869
  )
870
870
 
871
871
 
872
- def psislw(log_weights, reff=1.0):
872
+ def psislw(log_weights, reff=1.0, normalize=True):
873
873
  """
874
874
  Pareto smoothed importance sampling (PSIS).
875
875
 
@@ -887,11 +887,13 @@ def psislw(log_weights, reff=1.0):
887
887
  Array of size (n_observations, n_samples)
888
888
  reff : float, default 1
889
889
  relative MCMC efficiency, ``ess / n``
890
+ normalize : bool, default True
891
+ return normalized log weights
890
892
 
891
893
  Returns
892
894
  -------
893
895
  lw_out : DataArray or (..., N) ndarray
894
- Smoothed, truncated and normalized log weights.
896
+ Smoothed, truncated and possibly normalized log weights.
895
897
  kss : DataArray or (...) ndarray
896
898
  Estimates of the shape parameter *k* of the generalized Pareto
897
899
  distribution.
@@ -936,7 +938,12 @@ def psislw(log_weights, reff=1.0):
936
938
  out = np.empty_like(log_weights), np.empty(shape)
937
939
 
938
940
  # define kwargs
939
- func_kwargs = {"cutoff_ind": cutoff_ind, "cutoffmin": cutoffmin, "out": out}
941
+ func_kwargs = {
942
+ "cutoff_ind": cutoff_ind,
943
+ "cutoffmin": cutoffmin,
944
+ "out": out,
945
+ "normalize": normalize,
946
+ }
940
947
  ufunc_kwargs = {"n_dims": 1, "n_output": 2, "ravel": False, "check_shape": False}
941
948
  kwargs = {"input_core_dims": [["__sample__"]], "output_core_dims": [["__sample__"], []]}
942
949
  log_weights, pareto_shape = _wrap_xarray_ufunc(
@@ -953,7 +960,7 @@ def psislw(log_weights, reff=1.0):
953
960
  return log_weights, pareto_shape
954
961
 
955
962
 
956
- def _psislw(log_weights, cutoff_ind, cutoffmin):
963
+ def _psislw(log_weights, cutoff_ind, cutoffmin, normalize):
957
964
  """
958
965
  Pareto smoothed importance sampling (PSIS) for a 1D vector.
959
966
 
@@ -963,7 +970,7 @@ def _psislw(log_weights, cutoff_ind, cutoffmin):
963
970
  Array of length n_observations
964
971
  cutoff_ind: int
965
972
  cutoffmin: float
966
- k_min: float
973
+ normalize: bool
967
974
 
968
975
  Returns
969
976
  -------
@@ -975,7 +982,8 @@ def _psislw(log_weights, cutoff_ind, cutoffmin):
975
982
  x = np.asarray(log_weights)
976
983
 
977
984
  # improve numerical accuracy
978
- x -= np.max(x)
985
+ max_x = np.max(x)
986
+ x -= max_x
979
987
  # sort the array
980
988
  x_sort_ind = np.argsort(x)
981
989
  # divide log weights into body and right tail
@@ -1007,8 +1015,12 @@ def _psislw(log_weights, cutoff_ind, cutoffmin):
1007
1015
  x[tailinds[x_tail_si]] = smoothed_tail
1008
1016
  # truncate smoothed values to the largest raw weight 0
1009
1017
  x[x > 0] = 0
1018
+
1010
1019
  # renormalize weights
1011
- x -= _logsumexp(x)
1020
+ if normalize:
1021
+ x -= _logsumexp(x)
1022
+ else:
1023
+ x += max_x
1012
1024
 
1013
1025
  return x, k
1014
1026
 
@@ -1501,10 +1501,6 @@ class TestJSON:
1501
1501
  assert not os.path.exists(filepath)
1502
1502
 
1503
1503
 
1504
- @pytest.mark.skipif(
1505
- not (importlib.util.find_spec("datatree") or "ARVIZ_REQUIRE_ALL_DEPS" in os.environ),
1506
- reason="test requires xarray-datatree library",
1507
- )
1508
1504
  class TestDataTree:
1509
1505
  def test_datatree(self):
1510
1506
  idata = load_arviz_data("centered_eight")
@@ -8,6 +8,7 @@ from pandas import DataFrame # pylint: disable=wrong-import-position
8
8
  from scipy.stats import norm # pylint: disable=wrong-import-position
9
9
 
10
10
  from ...data import from_dict, load_arviz_data # pylint: disable=wrong-import-position
11
+ from ...labels import MapLabeller # pylint: disable=wrong-import-position
11
12
  from ...plots import ( # pylint: disable=wrong-import-position
12
13
  plot_autocorr,
13
14
  plot_bpv,
@@ -773,7 +774,6 @@ def test_plot_mcse_no_divergences(models):
773
774
  {"divergences": True, "var_names": ["theta", "mu"]},
774
775
  {"kind": "kde", "var_names": ["theta"]},
775
776
  {"kind": "hexbin", "var_names": ["theta"]},
776
- {"kind": "hexbin", "var_names": ["theta"]},
777
777
  {
778
778
  "kind": "hexbin",
779
779
  "var_names": ["theta"],
@@ -785,6 +785,21 @@ def test_plot_mcse_no_divergences(models):
785
785
  "reference_values": {"mu": 0, "tau": 0},
786
786
  "reference_values_kwargs": {"line_color": "blue"},
787
787
  },
788
+ {
789
+ "var_names": ["mu", "tau"],
790
+ "reference_values": {"mu": 0, "tau": 0},
791
+ "labeller": MapLabeller({"mu": r"$\mu$", "theta": r"$\theta"}),
792
+ },
793
+ {
794
+ "var_names": ["theta"],
795
+ "reference_values": {"theta": [0.0] * 8},
796
+ "labeller": MapLabeller({"theta": r"$\theta$"}),
797
+ },
798
+ {
799
+ "var_names": ["theta"],
800
+ "reference_values": {"theta": np.zeros(8)},
801
+ "labeller": MapLabeller({"theta": r"$\theta$"}),
802
+ },
788
803
  ],
789
804
  )
790
805
  def test_plot_pair(models, kwargs):
@@ -1201,7 +1216,7 @@ def test_plot_dot_rotated(continuous_model, kwargs):
1201
1216
  },
1202
1217
  ],
1203
1218
  )
1204
- def test_plot_lm(models, kwargs):
1219
+ def test_plot_lm_1d(models, kwargs):
1205
1220
  """Test functionality for 1D data."""
1206
1221
  idata = models.model_1
1207
1222
  if "constant_data" not in idata.groups():
@@ -1228,3 +1243,46 @@ def test_plot_lm_list():
1228
1243
  """Test the plots when input data is list or ndarray."""
1229
1244
  y = [1, 2, 3, 4, 5]
1230
1245
  assert plot_lm(y=y, x=np.arange(len(y)), show=False, backend="bokeh")
1246
+
1247
+
1248
+ def generate_lm_1d_data():
1249
+ rng = np.random.default_rng()
1250
+ return from_dict(
1251
+ observed_data={"y": rng.normal(size=7)},
1252
+ posterior_predictive={"y": rng.normal(size=(4, 1000, 7)) / 2},
1253
+ posterior={"y_model": rng.normal(size=(4, 1000, 7))},
1254
+ dims={"y": ["dim1"]},
1255
+ coords={"dim1": range(7)},
1256
+ )
1257
+
1258
+
1259
+ def generate_lm_2d_data():
1260
+ rng = np.random.default_rng()
1261
+ return from_dict(
1262
+ observed_data={"y": rng.normal(size=(5, 7))},
1263
+ posterior_predictive={"y": rng.normal(size=(4, 1000, 5, 7)) / 2},
1264
+ posterior={"y_model": rng.normal(size=(4, 1000, 5, 7))},
1265
+ dims={"y": ["dim1", "dim2"]},
1266
+ coords={"dim1": range(5), "dim2": range(7)},
1267
+ )
1268
+
1269
+
1270
+ @pytest.mark.parametrize("data", ("1d", "2d"))
1271
+ @pytest.mark.parametrize("kind", ("lines", "hdi"))
1272
+ @pytest.mark.parametrize("use_y_model", (True, False))
1273
+ def test_plot_lm(data, kind, use_y_model):
1274
+ if data == "1d":
1275
+ idata = generate_lm_1d_data()
1276
+ else:
1277
+ idata = generate_lm_2d_data()
1278
+
1279
+ kwargs = {"idata": idata, "y": "y", "kind_model": kind, "backend": "bokeh", "show": False}
1280
+ if data == "2d":
1281
+ kwargs["plot_dim"] = "dim1"
1282
+ if use_y_model:
1283
+ kwargs["y_model"] = "y_model"
1284
+ if kind == "lines":
1285
+ kwargs["num_samples"] = 50
1286
+
1287
+ ax = plot_lm(**kwargs)
1288
+ assert ax is not None
@@ -14,6 +14,7 @@ from pandas import DataFrame
14
14
  from scipy.stats import gaussian_kde, norm
15
15
 
16
16
  from ...data import from_dict, load_arviz_data
17
+ from ...labels import MapLabeller
17
18
  from ...plots import (
18
19
  plot_autocorr,
19
20
  plot_bf,
@@ -599,6 +600,21 @@ def test_plot_kde_inference_data(models):
599
600
  "reference_values": {"mu": 0, "tau": 0},
600
601
  "reference_values_kwargs": {"c": "C0", "marker": "*"},
601
602
  },
603
+ {
604
+ "var_names": ["mu", "tau"],
605
+ "reference_values": {"mu": 0, "tau": 0},
606
+ "labeller": MapLabeller({"mu": r"$\mu$", "theta": r"$\theta"}),
607
+ },
608
+ {
609
+ "var_names": ["theta"],
610
+ "reference_values": {"theta": [0.0] * 8},
611
+ "labeller": MapLabeller({"theta": r"$\theta$"}),
612
+ },
613
+ {
614
+ "var_names": ["theta"],
615
+ "reference_values": {"theta": np.zeros(8)},
616
+ "labeller": MapLabeller({"theta": r"$\theta$"}),
617
+ },
602
618
  ],
603
619
  )
604
620
  def test_plot_pair(models, kwargs):
@@ -1914,7 +1930,7 @@ def test_wilkinson_algorithm(continuous_model):
1914
1930
  },
1915
1931
  ],
1916
1932
  )
1917
- def test_plot_lm(models, kwargs):
1933
+ def test_plot_lm_1d(models, kwargs):
1918
1934
  """Test functionality for 1D data."""
1919
1935
  idata = models.model_1
1920
1936
  if "constant_data" not in idata.groups():
@@ -2102,3 +2118,63 @@ def test_plot_bf():
2102
2118
  )
2103
2119
  _, bf_plot = plot_bf(idata, var_name="a", ref_val=0)
2104
2120
  assert bf_plot is not None
2121
+
2122
+
2123
+ def generate_lm_1d_data():
2124
+ rng = np.random.default_rng()
2125
+ return from_dict(
2126
+ observed_data={"y": rng.normal(size=7)},
2127
+ posterior_predictive={"y": rng.normal(size=(4, 1000, 7)) / 2},
2128
+ posterior={"y_model": rng.normal(size=(4, 1000, 7))},
2129
+ dims={"y": ["dim1"]},
2130
+ coords={"dim1": range(7)},
2131
+ )
2132
+
2133
+
2134
+ def generate_lm_2d_data():
2135
+ rng = np.random.default_rng()
2136
+ return from_dict(
2137
+ observed_data={"y": rng.normal(size=(5, 7))},
2138
+ posterior_predictive={"y": rng.normal(size=(4, 1000, 5, 7)) / 2},
2139
+ posterior={"y_model": rng.normal(size=(4, 1000, 5, 7))},
2140
+ dims={"y": ["dim1", "dim2"]},
2141
+ coords={"dim1": range(5), "dim2": range(7)},
2142
+ )
2143
+
2144
+
2145
+ @pytest.mark.parametrize("data", ("1d", "2d"))
2146
+ @pytest.mark.parametrize("kind", ("lines", "hdi"))
2147
+ @pytest.mark.parametrize("use_y_model", (True, False))
2148
+ def test_plot_lm(data, kind, use_y_model):
2149
+ if data == "1d":
2150
+ idata = generate_lm_1d_data()
2151
+ else:
2152
+ idata = generate_lm_2d_data()
2153
+
2154
+ kwargs = {"idata": idata, "y": "y", "kind_model": kind}
2155
+ if data == "2d":
2156
+ kwargs["plot_dim"] = "dim1"
2157
+ if use_y_model:
2158
+ kwargs["y_model"] = "y_model"
2159
+ if kind == "lines":
2160
+ kwargs["num_samples"] = 50
2161
+
2162
+ ax = plot_lm(**kwargs)
2163
+ assert ax is not None
2164
+
2165
+
2166
+ @pytest.mark.parametrize(
2167
+ "coords, expected_vars",
2168
+ [
2169
+ ({"school": ["Choate"]}, ["theta"]),
2170
+ ({"school": ["Lawrenceville"]}, ["theta"]),
2171
+ ({}, ["theta"]),
2172
+ ],
2173
+ )
2174
+ def test_plot_autocorr_coords(coords, expected_vars):
2175
+ """Test plot_autocorr with coords kwarg."""
2176
+ idata = load_arviz_data("centered_eight")
2177
+
2178
+ axes = plot_autocorr(idata, var_names=expected_vars, coords=coords, show=False)
2179
+
2180
+ assert axes is not None
@@ -14,7 +14,7 @@ from scipy.stats import linregress, norm, halfcauchy
14
14
  from xarray import DataArray, Dataset
15
15
  from xarray_einstats.stats import XrContinuousRV
16
16
 
17
- from ...data import concat, convert_to_inference_data, from_dict, load_arviz_data
17
+ from ...data import concat, convert_to_inference_data, from_dict, load_arviz_data, InferenceData
18
18
  from ...rcparams import rcParams
19
19
  from ...stats import (
20
20
  apply_test_function,
@@ -882,3 +882,44 @@ def test_bayes_factor():
882
882
  bf_dict1 = bayes_factor(idata, prior=np.random.normal(0, 10, 5000), var_name="a", ref_val=0)
883
883
  assert bf_dict0["BF10"] > bf_dict0["BF01"]
884
884
  assert bf_dict1["BF10"] < bf_dict1["BF01"]
885
+
886
+
887
+ def test_compare_sorting_consistency():
888
+ chains, draws = 4, 1000
889
+
890
+ # Model 1 - good fit
891
+ log_lik1 = np.random.normal(-2, 1, size=(chains, draws))
892
+ posterior1 = Dataset(
893
+ {"theta": (("chain", "draw"), np.random.normal(0, 1, size=(chains, draws)))},
894
+ coords={"chain": range(chains), "draw": range(draws)},
895
+ )
896
+ log_like1 = Dataset(
897
+ {"y": (("chain", "draw"), log_lik1)},
898
+ coords={"chain": range(chains), "draw": range(draws)},
899
+ )
900
+ data1 = InferenceData(posterior=posterior1, log_likelihood=log_like1)
901
+
902
+ # Model 2 - poor fit (higher variance)
903
+ log_lik2 = np.random.normal(-5, 2, size=(chains, draws))
904
+ posterior2 = Dataset(
905
+ {"theta": (("chain", "draw"), np.random.normal(0, 1, size=(chains, draws)))},
906
+ coords={"chain": range(chains), "draw": range(draws)},
907
+ )
908
+ log_like2 = Dataset(
909
+ {"y": (("chain", "draw"), log_lik2)},
910
+ coords={"chain": range(chains), "draw": range(draws)},
911
+ )
912
+ data2 = InferenceData(posterior=posterior2, log_likelihood=log_like2)
913
+
914
+ # Compare models in different orders
915
+ comp_dict1 = {"M1": data1, "M2": data2}
916
+ comp_dict2 = {"M2": data2, "M1": data1}
917
+
918
+ comparison1 = compare(comp_dict1, method="bb-pseudo-bma")
919
+ comparison2 = compare(comp_dict2, method="bb-pseudo-bma")
920
+
921
+ assert comparison1.index.tolist() == comparison2.index.tolist()
922
+
923
+ se1 = comparison1["se"].values
924
+ se2 = comparison2["se"].values
925
+ np.testing.assert_array_almost_equal(se1, se2)