arviz 0.21.0__py3-none-any.whl → 0.23.0__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (37)
  1. arviz/__init__.py +49 -4
  2. arviz/data/converters.py +11 -0
  3. arviz/data/inference_data.py +46 -24
  4. arviz/data/io_datatree.py +2 -2
  5. arviz/data/io_numpyro.py +116 -5
  6. arviz/data/io_pyjags.py +1 -1
  7. arviz/plots/autocorrplot.py +12 -2
  8. arviz/plots/backends/bokeh/hdiplot.py +7 -6
  9. arviz/plots/backends/bokeh/lmplot.py +19 -3
  10. arviz/plots/backends/bokeh/pairplot.py +18 -48
  11. arviz/plots/backends/matplotlib/khatplot.py +8 -1
  12. arviz/plots/backends/matplotlib/lmplot.py +13 -7
  13. arviz/plots/backends/matplotlib/pairplot.py +14 -22
  14. arviz/plots/bpvplot.py +1 -1
  15. arviz/plots/dotplot.py +2 -0
  16. arviz/plots/forestplot.py +16 -4
  17. arviz/plots/kdeplot.py +4 -4
  18. arviz/plots/lmplot.py +41 -14
  19. arviz/plots/pairplot.py +10 -3
  20. arviz/plots/ppcplot.py +1 -1
  21. arviz/preview.py +31 -21
  22. arviz/rcparams.py +2 -2
  23. arviz/stats/density_utils.py +1 -1
  24. arviz/stats/stats.py +31 -34
  25. arviz/tests/base_tests/test_data.py +25 -4
  26. arviz/tests/base_tests/test_plots_bokeh.py +60 -2
  27. arviz/tests/base_tests/test_plots_matplotlib.py +94 -1
  28. arviz/tests/base_tests/test_stats.py +42 -1
  29. arviz/tests/base_tests/test_stats_ecdf_utils.py +2 -2
  30. arviz/tests/external_tests/test_data_numpyro.py +154 -4
  31. arviz/wrappers/base.py +1 -1
  32. arviz/wrappers/wrap_stan.py +1 -1
  33. {arviz-0.21.0.dist-info → arviz-0.23.0.dist-info}/METADATA +20 -9
  34. {arviz-0.21.0.dist-info → arviz-0.23.0.dist-info}/RECORD +37 -37
  35. {arviz-0.21.0.dist-info → arviz-0.23.0.dist-info}/WHEEL +1 -1
  36. {arviz-0.21.0.dist-info → arviz-0.23.0.dist-info/licenses}/LICENSE +0 -0
  37. {arviz-0.21.0.dist-info → arviz-0.23.0.dist-info}/top_level.txt +0 -0
arviz/tests/base_tests/test_plots_bokeh.py CHANGED
@@ -8,6 +8,7 @@ from pandas import DataFrame # pylint: disable=wrong-import-position
  from scipy.stats import norm # pylint: disable=wrong-import-position
 
  from ...data import from_dict, load_arviz_data # pylint: disable=wrong-import-position
+ from ...labels import MapLabeller # pylint: disable=wrong-import-position
  from ...plots import ( # pylint: disable=wrong-import-position
  plot_autocorr,
  plot_bpv,
@@ -773,7 +774,6 @@ def test_plot_mcse_no_divergences(models):
  {"divergences": True, "var_names": ["theta", "mu"]},
  {"kind": "kde", "var_names": ["theta"]},
  {"kind": "hexbin", "var_names": ["theta"]},
- {"kind": "hexbin", "var_names": ["theta"]},
  {
  "kind": "hexbin",
  "var_names": ["theta"],
@@ -785,6 +785,21 @@ def test_plot_mcse_no_divergences(models):
  "reference_values": {"mu": 0, "tau": 0},
  "reference_values_kwargs": {"line_color": "blue"},
  },
+ {
+ "var_names": ["mu", "tau"],
+ "reference_values": {"mu": 0, "tau": 0},
+ "labeller": MapLabeller({"mu": r"$\mu$", "theta": r"$\theta"}),
+ },
+ {
+ "var_names": ["theta"],
+ "reference_values": {"theta": [0.0] * 8},
+ "labeller": MapLabeller({"theta": r"$\theta$"}),
+ },
+ {
+ "var_names": ["theta"],
+ "reference_values": {"theta": np.zeros(8)},
+ "labeller": MapLabeller({"theta": r"$\theta$"}),
+ },
  ],
  )
  def test_plot_pair(models, kwargs):
@@ -1201,7 +1216,7 @@ def test_plot_dot_rotated(continuous_model, kwargs):
  },
  ],
  )
- def test_plot_lm(models, kwargs):
+ def test_plot_lm_1d(models, kwargs):
  """Test functionality for 1D data."""
  idata = models.model_1
  if "constant_data" not in idata.groups():
@@ -1228,3 +1243,46 @@ def test_plot_lm_list():
  """Test the plots when input data is list or ndarray."""
  y = [1, 2, 3, 4, 5]
  assert plot_lm(y=y, x=np.arange(len(y)), show=False, backend="bokeh")
+
+
+ def generate_lm_1d_data():
+ rng = np.random.default_rng()
+ return from_dict(
+ observed_data={"y": rng.normal(size=7)},
+ posterior_predictive={"y": rng.normal(size=(4, 1000, 7)) / 2},
+ posterior={"y_model": rng.normal(size=(4, 1000, 7))},
+ dims={"y": ["dim1"]},
+ coords={"dim1": range(7)},
+ )
+
+
+ def generate_lm_2d_data():
+ rng = np.random.default_rng()
+ return from_dict(
+ observed_data={"y": rng.normal(size=(5, 7))},
+ posterior_predictive={"y": rng.normal(size=(4, 1000, 5, 7)) / 2},
+ posterior={"y_model": rng.normal(size=(4, 1000, 5, 7))},
+ dims={"y": ["dim1", "dim2"]},
+ coords={"dim1": range(5), "dim2": range(7)},
+ )
+
+
+ @pytest.mark.parametrize("data", ("1d", "2d"))
+ @pytest.mark.parametrize("kind", ("lines", "hdi"))
+ @pytest.mark.parametrize("use_y_model", (True, False))
+ def test_plot_lm(data, kind, use_y_model):
+ if data == "1d":
+ idata = generate_lm_1d_data()
+ else:
+ idata = generate_lm_2d_data()
+
+ kwargs = {"idata": idata, "y": "y", "kind_model": kind, "backend": "bokeh", "show": False}
+ if data == "2d":
+ kwargs["plot_dim"] = "dim1"
+ if use_y_model:
+ kwargs["y_model"] = "y_model"
+ if kind == "lines":
+ kwargs["num_samples"] = 50
+
+ ax = plot_lm(**kwargs)
+ assert ax is not None
arviz/tests/base_tests/test_plots_matplotlib.py CHANGED
@@ -14,6 +14,7 @@ from pandas import DataFrame
  from scipy.stats import gaussian_kde, norm
 
  from ...data import from_dict, load_arviz_data
+ from ...labels import MapLabeller
  from ...plots import (
  plot_autocorr,
  plot_bf,
@@ -599,6 +600,21 @@ def test_plot_kde_inference_data(models):
  "reference_values": {"mu": 0, "tau": 0},
  "reference_values_kwargs": {"c": "C0", "marker": "*"},
  },
+ {
+ "var_names": ["mu", "tau"],
+ "reference_values": {"mu": 0, "tau": 0},
+ "labeller": MapLabeller({"mu": r"$\mu$", "theta": r"$\theta"}),
+ },
+ {
+ "var_names": ["theta"],
+ "reference_values": {"theta": [0.0] * 8},
+ "labeller": MapLabeller({"theta": r"$\theta$"}),
+ },
+ {
+ "var_names": ["theta"],
+ "reference_values": {"theta": np.zeros(8)},
+ "labeller": MapLabeller({"theta": r"$\theta$"}),
+ },
  ],
  )
  def test_plot_pair(models, kwargs):
@@ -1914,7 +1930,7 @@ def test_wilkinson_algorithm(continuous_model):
  },
  ],
  )
- def test_plot_lm(models, kwargs):
+ def test_plot_lm_1d(models, kwargs):
  """Test functionality for 1D data."""
  idata = models.model_1
  if "constant_data" not in idata.groups():
@@ -2102,3 +2118,80 @@ def test_plot_bf():
  )
  _, bf_plot = plot_bf(idata, var_name="a", ref_val=0)
  assert bf_plot is not None
+
+
+ def generate_lm_1d_data():
+ rng = np.random.default_rng()
+ return from_dict(
+ observed_data={"y": rng.normal(size=7)},
+ posterior_predictive={"y": rng.normal(size=(4, 1000, 7)) / 2},
+ posterior={"y_model": rng.normal(size=(4, 1000, 7))},
+ dims={"y": ["dim1"]},
+ coords={"dim1": range(7)},
+ )
+
+
+ def generate_lm_2d_data():
+ rng = np.random.default_rng()
+ return from_dict(
+ observed_data={"y": rng.normal(size=(5, 7))},
+ posterior_predictive={"y": rng.normal(size=(4, 1000, 5, 7)) / 2},
+ posterior={"y_model": rng.normal(size=(4, 1000, 5, 7))},
+ dims={"y": ["dim1", "dim2"]},
+ coords={"dim1": range(5), "dim2": range(7)},
+ )
+
+
+ @pytest.mark.parametrize("data", ("1d", "2d"))
+ @pytest.mark.parametrize("kind", ("lines", "hdi"))
+ @pytest.mark.parametrize("use_y_model", (True, False))
+ def test_plot_lm(data, kind, use_y_model):
+ if data == "1d":
+ idata = generate_lm_1d_data()
+ else:
+ idata = generate_lm_2d_data()
+
+ kwargs = {"idata": idata, "y": "y", "kind_model": kind}
+ if data == "2d":
+ kwargs["plot_dim"] = "dim1"
+ if use_y_model:
+ kwargs["y_model"] = "y_model"
+ if kind == "lines":
+ kwargs["num_samples"] = 50
+
+ ax = plot_lm(**kwargs)
+ assert ax is not None
+
+
+ @pytest.mark.parametrize(
+ "coords, expected_vars",
+ [
+ ({"school": ["Choate"]}, ["theta"]),
+ ({"school": ["Lawrenceville"]}, ["theta"]),
+ ({}, ["theta"]),
+ ],
+ )
+ def test_plot_autocorr_coords(coords, expected_vars):
+ """Test plot_autocorr with coords kwarg."""
+ idata = load_arviz_data("centered_eight")
+
+ axes = plot_autocorr(idata, var_names=expected_vars, coords=coords, show=False)
+ assert axes is not None
+
+
+ def test_plot_forest_with_transform():
+ """Test if plot_forest runs successfully with a transform dictionary."""
+ data = xr.Dataset(
+ {
+ "var1": (["chain", "draw"], np.array([[1, 2, 3], [4, 5, 6]])),
+ "var2": (["chain", "draw"], np.array([[7, 8, 9], [10, 11, 12]])),
+ },
+ coords={"chain": [0, 1], "draw": [0, 1, 2]},
+ )
+ transform_dict = {
+ "var1": lambda x: x + 1,
+ "var2": lambda x: x * 2,
+ }
+
+ axes = plot_forest(data, transform=transform_dict, show=False)
+ assert axes is not None
arviz/tests/base_tests/test_stats.py CHANGED
@@ -14,7 +14,7 @@ from scipy.stats import linregress, norm, halfcauchy
  from xarray import DataArray, Dataset
  from xarray_einstats.stats import XrContinuousRV
 
- from ...data import concat, convert_to_inference_data, from_dict, load_arviz_data
+ from ...data import concat, convert_to_inference_data, from_dict, load_arviz_data, InferenceData
  from ...rcparams import rcParams
  from ...stats import (
  apply_test_function,
@@ -882,3 +882,44 @@ def test_bayes_factor():
  bf_dict1 = bayes_factor(idata, prior=np.random.normal(0, 10, 5000), var_name="a", ref_val=0)
  assert bf_dict0["BF10"] > bf_dict0["BF01"]
  assert bf_dict1["BF10"] < bf_dict1["BF01"]
+
+
+ def test_compare_sorting_consistency():
+ chains, draws = 4, 1000
+
+ # Model 1 - good fit
+ log_lik1 = np.random.normal(-2, 1, size=(chains, draws))
+ posterior1 = Dataset(
+ {"theta": (("chain", "draw"), np.random.normal(0, 1, size=(chains, draws)))},
+ coords={"chain": range(chains), "draw": range(draws)},
+ )
+ log_like1 = Dataset(
+ {"y": (("chain", "draw"), log_lik1)},
+ coords={"chain": range(chains), "draw": range(draws)},
+ )
+ data1 = InferenceData(posterior=posterior1, log_likelihood=log_like1)
+
+ # Model 2 - poor fit (higher variance)
+ log_lik2 = np.random.normal(-5, 2, size=(chains, draws))
+ posterior2 = Dataset(
+ {"theta": (("chain", "draw"), np.random.normal(0, 1, size=(chains, draws)))},
+ coords={"chain": range(chains), "draw": range(draws)},
+ )
+ log_like2 = Dataset(
+ {"y": (("chain", "draw"), log_lik2)},
+ coords={"chain": range(chains), "draw": range(draws)},
+ )
+ data2 = InferenceData(posterior=posterior2, log_likelihood=log_like2)
+
+ # Compare models in different orders
+ comp_dict1 = {"M1": data1, "M2": data2}
+ comp_dict2 = {"M2": data2, "M1": data1}
+
+ comparison1 = compare(comp_dict1, method="bb-pseudo-bma")
+ comparison2 = compare(comp_dict2, method="bb-pseudo-bma")
+
+ assert comparison1.index.tolist() == comparison2.index.tolist()
+
+ se1 = comparison1["se"].values
+ se2 = comparison2["se"].values
+ np.testing.assert_array_almost_equal(se1, se2)
arviz/tests/base_tests/test_stats_ecdf_utils.py CHANGED
@@ -13,9 +13,9 @@ from ...stats.ecdf_utils import (
  try:
  import numba # pylint: disable=unused-import
 
- numba_options = [True, False]
+ numba_options = [True, False] # pylint: disable=invalid-name
  except ImportError:
- numba_options = [False]
+ numba_options = [False] # pylint: disable=invalid-name
 
 
  def test_compute_ecdf():
arviz/tests/external_tests/test_data_numpyro.py CHANGED
@@ -1,4 +1,4 @@
- # pylint: disable=no-member, invalid-name, redefined-outer-name
+ # pylint: disable=no-member, invalid-name, redefined-outer-name, too-many-public-methods
  from collections import namedtuple
  import numpy as np
  import pytest
@@ -46,7 +46,9 @@ class TestDataNumPyro:
  )
  return predictions
 
- def get_inference_data(self, data, eight_schools_params, predictions_data, predictions_params):
+ def get_inference_data(
+ self, data, eight_schools_params, predictions_data, predictions_params, infer_dims=False
+ ):
  posterior_samples = data.obj.get_samples()
  model = data.obj.sampler.model
  posterior_predictive = Predictive(model, posterior_samples)(
@@ -55,6 +57,12 @@ class TestDataNumPyro:
  prior = Predictive(model, num_samples=500)(
  PRNGKey(2), eight_schools_params["J"], eight_schools_params["sigma"]
  )
+ dims = {"theta": ["school"], "eta": ["school"], "obs": ["school"]}
+ pred_dims = {"theta": ["school_pred"], "eta": ["school_pred"], "obs": ["school_pred"]}
+ if infer_dims:
+ dims = None
+ pred_dims = None
+
  predictions = predictions_data
  return from_numpyro(
  posterior=data.obj,
@@ -65,8 +73,8 @@ class TestDataNumPyro:
  "school": np.arange(eight_schools_params["J"]),
  "school_pred": np.arange(predictions_params["J"]),
  },
- dims={"theta": ["school"], "eta": ["school"], "obs": ["school"]},
- pred_dims={"theta": ["school_pred"], "eta": ["school_pred"], "obs": ["school_pred"]},
+ dims=dims,
+ pred_dims=pred_dims,
  )
 
  def test_inference_data_namedtuple(self, data):
@@ -77,6 +85,7 @@ class TestDataNumPyro:
  data.obj.get_samples = lambda *args, **kwargs: data_namedtuple
  inference_data = from_numpyro(
  posterior=data.obj,
+ dims={}, # This mock test needs to turn off autodims like so or mock group_by_chain
  )
  assert isinstance(data.obj.get_samples(), Samples)
  data.obj.get_samples = _old_fn
@@ -282,3 +291,144 @@ class TestDataNumPyro:
  mcmc.run(PRNGKey(0))
  inference_data = from_numpyro(mcmc)
  assert inference_data.observed_data
+
+ def test_mcmc_infer_dims(self):
+ import numpyro
+ import numpyro.distributions as dist
+ from numpyro.infer import MCMC, NUTS
+
+ def model():
+ # note: group2 gets assigned dim=-1 and group1 is assigned dim=-2
+ with numpyro.plate("group2", 5), numpyro.plate("group1", 10):
+ _ = numpyro.sample("param", dist.Normal(0, 1))
+
+ mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
+ mcmc.run(PRNGKey(0))
+ inference_data = from_numpyro(
+ mcmc, coords={"group1": np.arange(10), "group2": np.arange(5)}
+ )
+ assert inference_data.posterior.param.dims == ("chain", "draw", "group1", "group2")
+ assert all(dim in inference_data.posterior.param.coords for dim in ("group1", "group2"))
+
+ def test_mcmc_infer_unsorted_dims(self):
+ import numpyro
+ import numpyro.distributions as dist
+ from numpyro.infer import MCMC, NUTS
+
+ def model():
+ group1_plate = numpyro.plate("group1", 10, dim=-1)
+ group2_plate = numpyro.plate("group2", 5, dim=-2)
+
+ # the plate contexts are entered in a different order than the pre-defined dims
+ # we should make sure this still works because the trace has all of the info it needs
+ with group2_plate, group1_plate:
+ _ = numpyro.sample("param", dist.Normal(0, 1))
+
+ mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
+ mcmc.run(PRNGKey(0))
+ inference_data = from_numpyro(
+ mcmc, coords={"group1": np.arange(10), "group2": np.arange(5)}
+ )
+ assert inference_data.posterior.param.dims == ("chain", "draw", "group2", "group1")
+ assert all(dim in inference_data.posterior.param.coords for dim in ("group1", "group2"))
+
+ def test_mcmc_infer_dims_no_coords(self):
+ import numpyro
+ import numpyro.distributions as dist
+ from numpyro.infer import MCMC, NUTS
+
+ def model():
+ with numpyro.plate("group", 5):
+ _ = numpyro.sample("param", dist.Normal(0, 1))
+
+ mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
+ mcmc.run(PRNGKey(0))
+ inference_data = from_numpyro(mcmc)
+ assert inference_data.posterior.param.dims == ("chain", "draw", "group")
+
+ def test_mcmc_event_dims(self):
+ import numpyro
+ import numpyro.distributions as dist
+ from numpyro.infer import MCMC, NUTS
+
+ def model():
+ _ = numpyro.sample(
+ "gamma", dist.ZeroSumNormal(1, event_shape=(10,)), infer={"event_dims": ["groups"]}
+ )
+
+ mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
+ mcmc.run(PRNGKey(0))
+ inference_data = from_numpyro(mcmc, coords={"groups": np.arange(10)})
+ assert inference_data.posterior.gamma.dims == ("chain", "draw", "groups")
+ assert "groups" in inference_data.posterior.gamma.coords
+
+ @pytest.mark.xfail
+ def test_mcmc_inferred_dims_univariate(self):
+ import numpyro
+ import numpyro.distributions as dist
+ from numpyro.infer import MCMC, NUTS
+ import jax.numpy as jnp
+
+ def model():
+ alpha = numpyro.sample("alpha", dist.Normal(0, 1))
+ sigma = numpyro.sample("sigma", dist.HalfNormal(1))
+ with numpyro.plate("obs_idx", 3):
+ # mu is plated by obs_idx, but isnt broadcasted to the plate shape
+ # the expected behavior is that this should cause a failure
+ mu = numpyro.deterministic("mu", alpha)
+ return numpyro.sample("y", dist.Normal(mu, sigma), obs=jnp.array([-1, 0, 1]))
+
+ mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
+ mcmc.run(PRNGKey(0))
+ inference_data = from_numpyro(mcmc, coords={"obs_idx": np.arange(3)})
+ assert inference_data.posterior.mu.dims == ("chain", "draw", "obs_idx")
+ assert "obs_idx" in inference_data.posterior.mu.coords
+
+ def test_mcmc_extra_event_dims(self):
+ import numpyro
+ import numpyro.distributions as dist
+ from numpyro.infer import MCMC, NUTS
+
+ def model():
+ gamma = numpyro.sample("gamma", dist.ZeroSumNormal(1, event_shape=(10,)))
+ _ = numpyro.deterministic("gamma_plus1", gamma + 1)
+
+ mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
+ mcmc.run(PRNGKey(0))
+ inference_data = from_numpyro(
+ mcmc, coords={"groups": np.arange(10)}, extra_event_dims={"gamma_plus1": ["groups"]}
+ )
+ assert inference_data.posterior.gamma_plus1.dims == ("chain", "draw", "groups")
+ assert "groups" in inference_data.posterior.gamma_plus1.coords
+
+ def test_mcmc_predictions_infer_dims(
+ self, data, eight_schools_params, predictions_data, predictions_params
+ ):
+ inference_data = self.get_inference_data(
+ data, eight_schools_params, predictions_data, predictions_params, infer_dims=True
+ )
+ assert inference_data.predictions.obs.dims == ("chain", "draw", "J")
+ assert "J" in inference_data.predictions.obs.coords
+
+ def test_potential_energy_sign_conversion(self):
+ """Test that potential energy is converted to log probability (lp) with correct sign."""
+ import numpyro
+ import numpyro.distributions as dist
+ from numpyro.infer import MCMC, NUTS
+
+ num_samples = 10
+
+ def simple_model():
+ numpyro.sample("x", dist.Normal(0, 1))
+
+ nuts_kernel = NUTS(simple_model)
+ mcmc = MCMC(nuts_kernel, num_samples=num_samples, num_warmup=5)
+ mcmc.run(PRNGKey(0), extra_fields=["potential_energy"])
+
+ # Get the raw extra fields from NumPyro
+ extra_fields = mcmc.get_extra_fields(group_by_chain=True)
+ # Convert to ArviZ InferenceData
+ inference_data = from_numpyro(mcmc)
+ arviz_lp = inference_data["sample_stats"]["lp"].values
+
+ np.testing.assert_array_equal(arviz_lp, -extra_fields["potential_energy"])
arviz/wrappers/base.py CHANGED
@@ -197,7 +197,7 @@ class SamplingWrapper:
  """Check that all methods listed are implemented.
 
  Not all functions that require refitting need to have all the methods implemented in
- order to work properly. This function shoulg be used before using the SamplingWrapper and
+ order to work properly. This function should be used before using the SamplingWrapper and
  its subclasses to get informative error messages.
 
  Parameters
arviz/wrappers/wrap_stan.py CHANGED
@@ -44,7 +44,7 @@ class StanSamplingWrapper(SamplingWrapper):
  excluded_observed_data : str
  Variable name containing the pointwise log likelihood data of the excluded
  data. As PyStan cannot call C++ functions and log_likelihood__i is already
- calculated *during* the simultion, instead of the value on which to evaluate
+ calculated *during* the simulation, instead of the value on which to evaluate
  the likelihood, ``log_likelihood__i`` expects a string so it can extract the
  corresponding data from the InferenceData object.
  """
{arviz-0.21.0.dist-info → arviz-0.23.0.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
- Metadata-Version: 2.1
+ Metadata-Version: 2.4
  Name: arviz
- Version: 0.21.0
+ Version: 0.23.0
  Summary: Exploratory analysis of Bayesian models
  Home-page: http://github.com/arviz-devs/arviz
  Author: ArviZ Developers
@@ -22,14 +22,14 @@ Requires-Python: >=3.10
  Description-Content-Type: text/markdown
  License-File: LICENSE
  Requires-Dist: setuptools>=60.0.0
- Requires-Dist: matplotlib>=3.5
- Requires-Dist: numpy>=1.23.0
- Requires-Dist: scipy>=1.9.0
+ Requires-Dist: matplotlib>=3.8
+ Requires-Dist: numpy>=1.26.0
+ Requires-Dist: scipy>=1.11.0
  Requires-Dist: packaging
- Requires-Dist: pandas>=1.5.0
- Requires-Dist: xarray>=2022.6.0
+ Requires-Dist: pandas>=2.1.0
+ Requires-Dist: xarray>=2023.7.0
  Requires-Dist: h5netcdf>=1.0.2
- Requires-Dist: typing-extensions>=4.1.0
+ Requires-Dist: typing_extensions>=4.1.0
  Requires-Dist: xarray-einstats>=0.3
  Provides-Extra: all
  Requires-Dist: numba; extra == "all"
@@ -39,12 +39,23 @@ Requires-Dist: contourpy; extra == "all"
  Requires-Dist: ujson; extra == "all"
  Requires-Dist: dask[distributed]; extra == "all"
  Requires-Dist: zarr<3,>=2.5.0; extra == "all"
- Requires-Dist: xarray-datatree; extra == "all"
+ Requires-Dist: xarray>=2024.11.0; extra == "all"
  Requires-Dist: dm-tree>=0.1.8; extra == "all"
  Provides-Extra: preview
  Requires-Dist: arviz-base[h5netcdf]; extra == "preview"
  Requires-Dist: arviz-stats[xarray]; extra == "preview"
  Requires-Dist: arviz-plots; extra == "preview"
+ Dynamic: author
+ Dynamic: classifier
+ Dynamic: description
+ Dynamic: description-content-type
+ Dynamic: home-page
+ Dynamic: license
+ Dynamic: license-file
+ Dynamic: provides-extra
+ Dynamic: requires-dist
+ Dynamic: requires-python
+ Dynamic: summary
 
  <img src="https://raw.githubusercontent.com/arviz-devs/arviz-project/main/arviz_logos/ArviZ.png#gh-light-mode-only" width=200></img>
  <img src="https://raw.githubusercontent.com/arviz-devs/arviz-project/main/arviz_logos/ArviZ_white.png#gh-dark-mode-only" width=200></img>