arviz 0.17.0__py3-none-any.whl → 0.18.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (123) hide show
  1. arviz/__init__.py +3 -2
  2. arviz/data/__init__.py +5 -2
  3. arviz/data/base.py +102 -11
  4. arviz/data/converters.py +5 -0
  5. arviz/data/datasets.py +1 -0
  6. arviz/data/example_data/data_remote.json +10 -3
  7. arviz/data/inference_data.py +26 -25
  8. arviz/data/io_cmdstan.py +1 -3
  9. arviz/data/io_datatree.py +1 -0
  10. arviz/data/io_dict.py +5 -3
  11. arviz/data/io_emcee.py +1 -0
  12. arviz/data/io_numpyro.py +1 -0
  13. arviz/data/io_pyjags.py +1 -0
  14. arviz/data/io_pyro.py +1 -0
  15. arviz/data/io_pystan.py +1 -2
  16. arviz/data/utils.py +1 -0
  17. arviz/plots/__init__.py +1 -0
  18. arviz/plots/autocorrplot.py +1 -0
  19. arviz/plots/backends/bokeh/autocorrplot.py +1 -0
  20. arviz/plots/backends/bokeh/bpvplot.py +8 -2
  21. arviz/plots/backends/bokeh/compareplot.py +8 -4
  22. arviz/plots/backends/bokeh/densityplot.py +1 -0
  23. arviz/plots/backends/bokeh/distplot.py +1 -0
  24. arviz/plots/backends/bokeh/dotplot.py +1 -0
  25. arviz/plots/backends/bokeh/ecdfplot.py +1 -0
  26. arviz/plots/backends/bokeh/elpdplot.py +1 -0
  27. arviz/plots/backends/bokeh/energyplot.py +1 -0
  28. arviz/plots/backends/bokeh/forestplot.py +2 -4
  29. arviz/plots/backends/bokeh/hdiplot.py +1 -0
  30. arviz/plots/backends/bokeh/kdeplot.py +3 -3
  31. arviz/plots/backends/bokeh/khatplot.py +1 -0
  32. arviz/plots/backends/bokeh/lmplot.py +1 -0
  33. arviz/plots/backends/bokeh/loopitplot.py +1 -0
  34. arviz/plots/backends/bokeh/mcseplot.py +1 -0
  35. arviz/plots/backends/bokeh/pairplot.py +1 -0
  36. arviz/plots/backends/bokeh/parallelplot.py +1 -0
  37. arviz/plots/backends/bokeh/posteriorplot.py +1 -0
  38. arviz/plots/backends/bokeh/ppcplot.py +1 -0
  39. arviz/plots/backends/bokeh/rankplot.py +1 -0
  40. arviz/plots/backends/bokeh/separationplot.py +1 -0
  41. arviz/plots/backends/bokeh/traceplot.py +1 -0
  42. arviz/plots/backends/bokeh/violinplot.py +1 -0
  43. arviz/plots/backends/matplotlib/autocorrplot.py +1 -0
  44. arviz/plots/backends/matplotlib/bpvplot.py +1 -0
  45. arviz/plots/backends/matplotlib/compareplot.py +2 -1
  46. arviz/plots/backends/matplotlib/densityplot.py +1 -0
  47. arviz/plots/backends/matplotlib/distcomparisonplot.py +2 -3
  48. arviz/plots/backends/matplotlib/distplot.py +1 -0
  49. arviz/plots/backends/matplotlib/dotplot.py +1 -0
  50. arviz/plots/backends/matplotlib/ecdfplot.py +1 -0
  51. arviz/plots/backends/matplotlib/elpdplot.py +1 -0
  52. arviz/plots/backends/matplotlib/energyplot.py +1 -0
  53. arviz/plots/backends/matplotlib/essplot.py +6 -5
  54. arviz/plots/backends/matplotlib/forestplot.py +3 -4
  55. arviz/plots/backends/matplotlib/hdiplot.py +1 -0
  56. arviz/plots/backends/matplotlib/kdeplot.py +5 -3
  57. arviz/plots/backends/matplotlib/khatplot.py +1 -0
  58. arviz/plots/backends/matplotlib/lmplot.py +1 -0
  59. arviz/plots/backends/matplotlib/loopitplot.py +1 -0
  60. arviz/plots/backends/matplotlib/mcseplot.py +11 -10
  61. arviz/plots/backends/matplotlib/pairplot.py +2 -1
  62. arviz/plots/backends/matplotlib/parallelplot.py +1 -0
  63. arviz/plots/backends/matplotlib/posteriorplot.py +1 -0
  64. arviz/plots/backends/matplotlib/ppcplot.py +1 -0
  65. arviz/plots/backends/matplotlib/rankplot.py +1 -0
  66. arviz/plots/backends/matplotlib/separationplot.py +1 -0
  67. arviz/plots/backends/matplotlib/traceplot.py +2 -1
  68. arviz/plots/backends/matplotlib/tsplot.py +1 -0
  69. arviz/plots/backends/matplotlib/violinplot.py +2 -1
  70. arviz/plots/bfplot.py +7 -6
  71. arviz/plots/bpvplot.py +3 -2
  72. arviz/plots/compareplot.py +3 -2
  73. arviz/plots/densityplot.py +1 -0
  74. arviz/plots/distcomparisonplot.py +1 -0
  75. arviz/plots/dotplot.py +1 -0
  76. arviz/plots/ecdfplot.py +38 -112
  77. arviz/plots/elpdplot.py +2 -1
  78. arviz/plots/energyplot.py +1 -0
  79. arviz/plots/essplot.py +3 -2
  80. arviz/plots/forestplot.py +1 -0
  81. arviz/plots/hdiplot.py +1 -0
  82. arviz/plots/khatplot.py +1 -0
  83. arviz/plots/lmplot.py +1 -0
  84. arviz/plots/loopitplot.py +1 -0
  85. arviz/plots/mcseplot.py +1 -0
  86. arviz/plots/pairplot.py +2 -1
  87. arviz/plots/parallelplot.py +1 -0
  88. arviz/plots/plot_utils.py +1 -0
  89. arviz/plots/posteriorplot.py +1 -0
  90. arviz/plots/ppcplot.py +11 -5
  91. arviz/plots/rankplot.py +1 -0
  92. arviz/plots/separationplot.py +1 -0
  93. arviz/plots/traceplot.py +1 -0
  94. arviz/plots/tsplot.py +1 -0
  95. arviz/plots/violinplot.py +1 -0
  96. arviz/rcparams.py +1 -0
  97. arviz/sel_utils.py +1 -0
  98. arviz/static/css/style.css +2 -1
  99. arviz/stats/density_utils.py +4 -3
  100. arviz/stats/diagnostics.py +4 -4
  101. arviz/stats/ecdf_utils.py +166 -0
  102. arviz/stats/stats.py +16 -32
  103. arviz/stats/stats_refitting.py +1 -0
  104. arviz/stats/stats_utils.py +6 -2
  105. arviz/tests/base_tests/test_data.py +18 -4
  106. arviz/tests/base_tests/test_diagnostics.py +1 -0
  107. arviz/tests/base_tests/test_diagnostics_numba.py +1 -0
  108. arviz/tests/base_tests/test_labels.py +1 -0
  109. arviz/tests/base_tests/test_plots_matplotlib.py +6 -5
  110. arviz/tests/base_tests/test_stats.py +4 -4
  111. arviz/tests/base_tests/test_stats_ecdf_utils.py +153 -0
  112. arviz/tests/base_tests/test_stats_utils.py +4 -3
  113. arviz/tests/base_tests/test_utils.py +3 -2
  114. arviz/tests/external_tests/test_data_numpyro.py +3 -3
  115. arviz/tests/external_tests/test_data_pyro.py +3 -3
  116. arviz/tests/helpers.py +1 -1
  117. arviz/wrappers/__init__.py +1 -0
  118. {arviz-0.17.0.dist-info → arviz-0.18.0.dist-info}/METADATA +10 -9
  119. arviz-0.18.0.dist-info/RECORD +182 -0
  120. {arviz-0.17.0.dist-info → arviz-0.18.0.dist-info}/WHEEL +1 -1
  121. arviz-0.17.0.dist-info/RECORD +0 -180
  122. {arviz-0.17.0.dist-info → arviz-0.18.0.dist-info}/LICENSE +0 -0
  123. {arviz-0.17.0.dist-info → arviz-0.18.0.dist-info}/top_level.txt +0 -0
arviz/stats/stats.py CHANGED
@@ -146,6 +146,7 @@ def compare(
146
146
  Compare the centered and non centered models of the eight school problem:
147
147
 
148
148
  .. ipython::
149
+ :okwarning:
149
150
 
150
151
  In [1]: import arviz as az
151
152
  ...: data1 = az.load_arviz_data("non_centered_eight")
@@ -157,6 +158,7 @@ def compare(
157
158
  weights using the stacking method.
158
159
 
159
160
  .. ipython::
161
+ :okwarning:
160
162
 
161
163
  In [1]: az.compare(compare_dict, ic="loo", method="stacking", scale="log")
162
164
 
@@ -180,37 +182,19 @@ def compare(
180
182
  except Exception as e:
181
183
  raise e.__class__("Encountered error in ELPD computation of compare.") from e
182
184
  names = list(ics_dict.keys())
183
- if ic == "loo":
185
+ if ic in {"loo", "waic"}:
184
186
  df_comp = pd.DataFrame(
185
- index=names,
186
- columns=[
187
- "rank",
188
- "elpd_loo",
189
- "p_loo",
190
- "elpd_diff",
191
- "weight",
192
- "se",
193
- "dse",
194
- "warning",
195
- "scale",
196
- ],
197
- dtype=np.float_,
198
- )
199
- elif ic == "waic":
200
- df_comp = pd.DataFrame(
201
- index=names,
202
- columns=[
203
- "rank",
204
- "elpd_waic",
205
- "p_waic",
206
- "elpd_diff",
207
- "weight",
208
- "se",
209
- "dse",
210
- "warning",
211
- "scale",
212
- ],
213
- dtype=np.float_,
187
+ {
188
+ "rank": pd.Series(index=names, dtype="int"),
189
+ f"elpd_{ic}": pd.Series(index=names, dtype="float"),
190
+ f"p_{ic}": pd.Series(index=names, dtype="float"),
191
+ "elpd_diff": pd.Series(index=names, dtype="float"),
192
+ "weight": pd.Series(index=names, dtype="float"),
193
+ "se": pd.Series(index=names, dtype="float"),
194
+ "dse": pd.Series(index=names, dtype="float"),
195
+ "warning": pd.Series(index=names, dtype="boolean"),
196
+ "scale": pd.Series(index=names, dtype="str"),
197
+ }
214
198
  )
215
199
  else:
216
200
  raise NotImplementedError(f"The information criterion {ic} is not supported.")
@@ -632,7 +616,7 @@ def _hdi(ary, hdi_prob, circular, skipna):
632
616
  ary = np.sort(ary)
633
617
  interval_idx_inc = int(np.floor(hdi_prob * n))
634
618
  n_intervals = n - interval_idx_inc
635
- interval_width = np.subtract(ary[interval_idx_inc:], ary[:n_intervals], dtype=np.float_)
619
+ interval_width = np.subtract(ary[interval_idx_inc:], ary[:n_intervals], dtype=np.float64)
636
620
 
637
621
  if len(interval_width) == 0:
638
622
  raise ValueError("Too few elements for interval calculation. ")
@@ -2096,7 +2080,7 @@ def weight_predictions(idatas, weights=None):
2096
2080
  weights /= weights.sum()
2097
2081
 
2098
2082
  len_idatas = [
2099
- idata.posterior_predictive.dims["chain"] * idata.posterior_predictive.dims["draw"]
2083
+ idata.posterior_predictive.sizes["chain"] * idata.posterior_predictive.sizes["draw"]
2100
2084
  for idata in idatas
2101
2085
  ]
2102
2086
 
@@ -1,4 +1,5 @@
1
1
  """Stats functions that require refitting the model."""
2
+
2
3
  import logging
3
4
  import warnings
4
5
 
@@ -1,4 +1,5 @@
1
1
  """Stats-utility functions for ArviZ."""
2
+
2
3
  import warnings
3
4
  from collections.abc import Sequence
4
5
  from copy import copy as _copy
@@ -134,7 +135,10 @@ def make_ufunc(
134
135
  raise TypeError(msg)
135
136
  for idx in np.ndindex(out.shape[:n_dims_out]):
136
137
  arys_idx = [ary[idx].ravel() if ravel else ary[idx] for ary in arys]
137
- out[idx] = np.asarray(func(*arys_idx, *args[n_input:], **kwargs))[index]
138
+ out_idx = np.asarray(func(*arys_idx, *args[n_input:], **kwargs))[index]
139
+ if n_dims_out is None:
140
+ out_idx = out_idx.item()
141
+ out[idx] = out_idx
138
142
  return out
139
143
 
140
144
  def _multi_ufunc(*args, out=None, out_shape=None, **kwargs):
@@ -484,7 +488,7 @@ class ELPDData(pd.Series): # pylint: disable=too-many-ancestors
484
488
  base += "\n\nThere has been a warning during the calculation. Please check the results."
485
489
 
486
490
  if kind == "loo" and "pareto_k" in self:
487
- bins = np.asarray([-np.Inf, 0.5, 0.7, 1, np.Inf])
491
+ bins = np.asarray([-np.inf, 0.5, 0.7, 1, np.inf])
488
492
  counts, *_ = _histogram(self.pareto_k.values, bins)
489
493
  extended = POINTWISE_LOO_FMT.format(max(4, len(str(np.max(counts)))))
490
494
  extended = extended.format(
@@ -1077,6 +1077,20 @@ def test_dict_to_dataset():
1077
1077
  assert set(dataset.b.coords) == {"chain", "draw", "c"}
1078
1078
 
1079
1079
 
1080
+ def test_nested_dict_to_dataset():
1081
+ datadict = {
1082
+ "top": {"a": np.random.randn(100), "b": np.random.randn(1, 100, 10)},
1083
+ "d": np.random.randn(100),
1084
+ }
1085
+ dataset = convert_to_dataset(datadict, coords={"c": np.arange(10)}, dims={("top", "b"): ["c"]})
1086
+ assert set(dataset.data_vars) == {("top", "a"), ("top", "b"), "d"}
1087
+ assert set(dataset.coords) == {"chain", "draw", "c"}
1088
+
1089
+ assert set(dataset[("top", "a")].coords) == {"chain", "draw"}
1090
+ assert set(dataset[("top", "b")].coords) == {"chain", "draw", "c"}
1091
+ assert set(dataset.d.coords) == {"chain", "draw"}
1092
+
1093
+
1080
1094
  def test_dict_to_dataset_event_dims_error():
1081
1095
  datadict = {"a": np.random.randn(1, 100, 10)}
1082
1096
  coords = {"b": np.arange(10), "c": ["x", "y", "z"]}
@@ -1241,7 +1255,7 @@ class TestDataDict:
1241
1255
  self.check_var_names_coords_dims(inference_data.prior_predictive)
1242
1256
  self.check_var_names_coords_dims(inference_data.sample_stats_prior)
1243
1257
 
1244
- pred_dims = inference_data.predictions.dims["school_pred"]
1258
+ pred_dims = inference_data.predictions.sizes["school_pred"]
1245
1259
  assert pred_dims == 8
1246
1260
 
1247
1261
  def test_inference_data_warmup(self, data, eight_schools_params):
@@ -1586,8 +1600,8 @@ class TestExtractDataset:
1586
1600
  idata = load_arviz_data("centered_eight")
1587
1601
  post = extract(idata, combined=False)
1588
1602
  assert "sample" not in post.dims
1589
- assert post.dims["chain"] == 4
1590
- assert post.dims["draw"] == 500
1603
+ assert post.sizes["chain"] == 4
1604
+ assert post.sizes["draw"] == 500
1591
1605
 
1592
1606
  def test_var_name_group(self):
1593
1607
  idata = load_arviz_data("centered_eight")
@@ -1607,5 +1621,5 @@ class TestExtractDataset:
1607
1621
  def test_subset_samples(self):
1608
1622
  idata = load_arviz_data("centered_eight")
1609
1623
  post = extract(idata, num_samples=10)
1610
- assert post.dims["sample"] == 10
1624
+ assert post.sizes["sample"] == 10
1611
1625
  assert post.attrs == idata.posterior.attrs
@@ -1,4 +1,5 @@
1
1
  """Test Diagnostic methods"""
2
+
2
3
  # pylint: disable=redefined-outer-name, no-member, too-many-public-methods
3
4
  import os
4
5
 
@@ -1,4 +1,5 @@
1
1
  """Test Diagnostic methods"""
2
+
2
3
  import importlib
3
4
 
4
5
  # pylint: disable=redefined-outer-name, no-member, too-many-public-methods
@@ -1,4 +1,5 @@
1
1
  """Tests for labeller classes."""
2
+
2
3
  import pytest
3
4
 
4
5
  from ...labels import (
@@ -1,4 +1,5 @@
1
1
  """Tests use the default backend."""
2
+
2
3
  # pylint: disable=redefined-outer-name,too-many-lines
3
4
  import os
4
5
  from copy import deepcopy
@@ -54,7 +55,7 @@ from ..helpers import ( # pylint: disable=unused-import
54
55
  eight_schools_params,
55
56
  models,
56
57
  multidim_models,
57
- TestRandomVariable,
58
+ RandomVariableTestClass,
58
59
  )
59
60
 
60
61
  rcParams["data.load"] = "eager"
@@ -168,9 +169,9 @@ def test_plot_density_no_subset():
168
169
 
169
170
  def test_plot_density_nonstring_varnames():
170
171
  """Test plot_density works when variables are not strings."""
171
- rv1 = TestRandomVariable("a")
172
- rv2 = TestRandomVariable("b")
173
- rv3 = TestRandomVariable("c")
172
+ rv1 = RandomVariableTestClass("a")
173
+ rv2 = RandomVariableTestClass("b")
174
+ rv3 = RandomVariableTestClass("c")
174
175
  model_ab = from_dict(
175
176
  {
176
177
  rv1: np.random.normal(size=200),
@@ -752,7 +753,7 @@ def test_plot_ppc_transposed():
752
753
  )
753
754
  x, y = ax.get_lines()[2].get_data()
754
755
  assert not np.isclose(y[0], 0)
755
- assert np.all(np.array([40, 43, 10, 9]) == x)
756
+ assert np.all(np.array([47, 44, 15, 11]) == x)
756
757
 
757
758
 
758
759
  @pytest.mark.parametrize("kind", ["kde", "cumulative", "scatter"])
@@ -89,25 +89,25 @@ def test_hdi_idata(centered_eight):
89
89
  data = centered_eight.posterior
90
90
  result = hdi(data)
91
91
  assert isinstance(result, Dataset)
92
- assert dict(result.dims) == {"school": 8, "hdi": 2}
92
+ assert dict(result.sizes) == {"school": 8, "hdi": 2}
93
93
 
94
94
  result = hdi(data, input_core_dims=[["chain"]])
95
95
  assert isinstance(result, Dataset)
96
- assert result.dims == {"draw": 500, "hdi": 2, "school": 8}
96
+ assert result.sizes == {"draw": 500, "hdi": 2, "school": 8}
97
97
 
98
98
 
99
99
  def test_hdi_idata_varnames(centered_eight):
100
100
  data = centered_eight.posterior
101
101
  result = hdi(data, var_names=["mu", "theta"])
102
102
  assert isinstance(result, Dataset)
103
- assert result.dims == {"hdi": 2, "school": 8}
103
+ assert result.sizes == {"hdi": 2, "school": 8}
104
104
  assert list(result.data_vars.keys()) == ["mu", "theta"]
105
105
 
106
106
 
107
107
  def test_hdi_idata_group(centered_eight):
108
108
  result_posterior = hdi(centered_eight, group="posterior", var_names="mu")
109
109
  result_prior = hdi(centered_eight, group="prior", var_names="mu")
110
- assert result_prior.dims == {"hdi": 2}
110
+ assert result_prior.sizes == {"hdi": 2}
111
111
  range_posterior = result_posterior.mu.values[1] - result_posterior.mu.values[0]
112
112
  range_prior = result_prior.mu.values[1] - result_prior.mu.values[0]
113
113
  assert range_posterior < range_prior
@@ -0,0 +1,153 @@
1
+ import pytest
2
+
3
+ import numpy as np
4
+ import scipy.stats
5
+ from ...stats.ecdf_utils import (
6
+ compute_ecdf,
7
+ ecdf_confidence_band,
8
+ _get_ecdf_points,
9
+ _simulate_ecdf,
10
+ _get_pointwise_confidence_band,
11
+ )
12
+
13
+
14
+ def test_compute_ecdf():
15
+ """Test compute_ecdf function."""
16
+ sample = np.array([1, 2, 3, 3, 4, 5])
17
+ eval_points = np.arange(0, 7, 0.1)
18
+ ecdf_expected = (sample[:, None] <= eval_points).mean(axis=0)
19
+ assert np.allclose(compute_ecdf(sample, eval_points), ecdf_expected)
20
+ assert np.allclose(compute_ecdf(sample / 2 + 10, eval_points / 2 + 10), ecdf_expected)
21
+
22
+
23
+ @pytest.mark.parametrize("difference", [True, False])
24
+ def test_get_ecdf_points(difference):
25
+ """Test _get_ecdf_points."""
26
+ # if first point already outside support, no need to insert it
27
+ sample = np.array([1, 2, 3, 3, 4, 5, 5])
28
+ eval_points = np.arange(-1, 7, 0.1)
29
+ x, y = _get_ecdf_points(sample, eval_points, difference)
30
+ assert np.array_equal(x, eval_points)
31
+ assert np.array_equal(y, compute_ecdf(sample, eval_points))
32
+
33
+ # if first point is inside support, insert it if not in difference mode
34
+ eval_points = np.arange(1, 6, 0.1)
35
+ x, y = _get_ecdf_points(sample, eval_points, difference)
36
+ assert len(x) == len(eval_points) + 1 - difference
37
+ assert len(y) == len(eval_points) + 1 - difference
38
+
39
+ # if not in difference mode, first point should be (eval_points[0], 0)
40
+ if not difference:
41
+ assert x[0] == eval_points[0]
42
+ assert y[0] == 0
43
+ assert np.allclose(x[1:], eval_points)
44
+ assert np.allclose(y[1:], compute_ecdf(sample, eval_points))
45
+ assert x[-1] == eval_points[-1]
46
+ assert y[-1] == 1
47
+
48
+
49
+ @pytest.mark.parametrize(
50
+ "dist", [scipy.stats.norm(3, 10), scipy.stats.binom(10, 0.5)], ids=["continuous", "discrete"]
51
+ )
52
+ @pytest.mark.parametrize("seed", [32, 87])
53
+ def test_simulate_ecdf(dist, seed):
54
+ """Test _simulate_ecdf."""
55
+ ndraws = 1000
56
+ eval_points = np.arange(0, 1, 0.1)
57
+
58
+ rvs = dist.rvs
59
+
60
+ random_state = np.random.default_rng(seed)
61
+ ecdf = _simulate_ecdf(ndraws, eval_points, rvs, random_state=random_state)
62
+ random_state = np.random.default_rng(seed)
63
+ ecdf_expected = compute_ecdf(np.sort(rvs(ndraws, random_state=random_state)), eval_points)
64
+
65
+ assert np.allclose(ecdf, ecdf_expected)
66
+
67
+
68
+ @pytest.mark.parametrize("prob", [0.8, 0.9])
69
+ @pytest.mark.parametrize(
70
+ "dist", [scipy.stats.norm(3, 10), scipy.stats.poisson(100)], ids=["continuous", "discrete"]
71
+ )
72
+ @pytest.mark.parametrize("ndraws", [10_000])
73
+ def test_get_pointwise_confidence_band(dist, prob, ndraws, num_trials=1_000, seed=57):
74
+ """Test _get_pointwise_confidence_band."""
75
+ eval_points = np.linspace(*dist.interval(0.99), 10)
76
+ cdf_at_eval_points = dist.cdf(eval_points)
77
+
78
+ ecdf_lower, ecdf_upper = _get_pointwise_confidence_band(prob, ndraws, cdf_at_eval_points)
79
+
80
+ # check basic properties
81
+ assert np.all(ecdf_lower >= 0)
82
+ assert np.all(ecdf_upper <= 1)
83
+ assert np.all(ecdf_lower <= ecdf_upper)
84
+
85
+ # use simulation to estimate lower and upper bounds on pointwise probability
86
+ in_interval = []
87
+ random_state = np.random.default_rng(seed)
88
+ for _ in range(num_trials):
89
+ ecdf = _simulate_ecdf(ndraws, eval_points, dist.rvs, random_state=random_state)
90
+ in_interval.append((ecdf_lower <= ecdf) & (ecdf < ecdf_upper))
91
+ asymptotic_dist = scipy.stats.norm(
92
+ np.mean(in_interval, axis=0), scipy.stats.sem(in_interval, axis=0)
93
+ )
94
+ prob_lower, prob_upper = asymptotic_dist.interval(0.999)
95
+
96
+ # check target probability within all bounds
97
+ assert np.all(prob_lower <= prob)
98
+ assert np.all(prob <= prob_upper)
99
+
100
+
101
+ @pytest.mark.parametrize("prob", [0.8, 0.9])
102
+ @pytest.mark.parametrize(
103
+ "dist, rvs",
104
+ [
105
+ (scipy.stats.norm(3, 10), scipy.stats.norm(3, 10).rvs),
106
+ (scipy.stats.norm(3, 10), None),
107
+ (scipy.stats.poisson(100), scipy.stats.poisson(100).rvs),
108
+ ],
109
+ ids=["continuous", "continuous default rvs", "discrete"],
110
+ )
111
+ @pytest.mark.parametrize("ndraws", [10_000])
112
+ @pytest.mark.parametrize("method", ["pointwise", "simulated"])
113
+ def test_ecdf_confidence_band(dist, rvs, prob, ndraws, method, num_trials=1_000, seed=57):
114
+ """Test test_ecdf_confidence_band."""
115
+ eval_points = np.linspace(*dist.interval(0.99), 10)
116
+ cdf_at_eval_points = dist.cdf(eval_points)
117
+ random_state = np.random.default_rng(seed)
118
+
119
+ ecdf_lower, ecdf_upper = ecdf_confidence_band(
120
+ ndraws,
121
+ eval_points,
122
+ cdf_at_eval_points,
123
+ prob=prob,
124
+ rvs=rvs,
125
+ random_state=random_state,
126
+ method=method,
127
+ )
128
+
129
+ if method == "pointwise":
130
+ # these values tested elsewhere, we just make sure they're the same
131
+ ecdf_lower_pointwise, ecdf_upper_pointwise = _get_pointwise_confidence_band(
132
+ prob, ndraws, cdf_at_eval_points
133
+ )
134
+ assert np.array_equal(ecdf_lower, ecdf_lower_pointwise)
135
+ assert np.array_equal(ecdf_upper, ecdf_upper_pointwise)
136
+ return
137
+
138
+ # check basic properties
139
+ assert np.all(ecdf_lower >= 0)
140
+ assert np.all(ecdf_upper <= 1)
141
+ assert np.all(ecdf_lower <= ecdf_upper)
142
+
143
+ # use simulation to estimate lower and upper bounds on simultaneous probability
144
+ in_envelope = []
145
+ random_state = np.random.default_rng(seed)
146
+ for _ in range(num_trials):
147
+ ecdf = _simulate_ecdf(ndraws, eval_points, dist.rvs, random_state=random_state)
148
+ in_envelope.append(np.all(ecdf_lower <= ecdf) & np.all(ecdf < ecdf_upper))
149
+ asymptotic_dist = scipy.stats.norm(np.mean(in_envelope), scipy.stats.sem(in_envelope))
150
+ prob_lower, prob_upper = asymptotic_dist.interval(0.999)
151
+
152
+ # check target probability within bounds
153
+ assert prob_lower <= prob <= prob_upper
@@ -1,4 +1,5 @@
1
1
  """Tests for stats_utils."""
2
+
2
3
  # pylint: disable=no-member
3
4
  import numpy as np
4
5
  import pytest
@@ -344,9 +345,9 @@ def test_variance_bad_data():
344
345
 
345
346
  def test_histogram():
346
347
  school = load_arviz_data("non_centered_eight").posterior["mu"].values
347
- k_count_az, k_dens_az, _ = histogram(school, bins=np.asarray([-np.Inf, 0.5, 0.7, 1, np.Inf]))
348
- k_dens_np, *_ = np.histogram(school, bins=[-np.Inf, 0.5, 0.7, 1, np.Inf], density=True)
349
- k_count_np, *_ = np.histogram(school, bins=[-np.Inf, 0.5, 0.7, 1, np.Inf], density=False)
348
+ k_count_az, k_dens_az, _ = histogram(school, bins=np.asarray([-np.inf, 0.5, 0.7, 1, np.inf]))
349
+ k_dens_np, *_ = np.histogram(school, bins=[-np.inf, 0.5, 0.7, 1, np.inf], density=True)
350
+ k_count_np, *_ = np.histogram(school, bins=[-np.inf, 0.5, 0.7, 1, np.inf], density=False)
350
351
  assert np.allclose(k_count_az, k_count_np)
351
352
  assert np.allclose(k_dens_az, k_dens_np)
352
353
 
@@ -1,4 +1,5 @@
1
1
  """Tests for arviz.utils."""
2
+
2
3
  # pylint: disable=redefined-outer-name, no-member
3
4
  from unittest.mock import Mock
4
5
 
@@ -17,7 +18,7 @@ from ...utils import (
17
18
  one_de,
18
19
  two_de,
19
20
  )
20
- from ..helpers import TestRandomVariable
21
+ from ..helpers import RandomVariableTestClass
21
22
 
22
23
 
23
24
  @pytest.fixture(scope="session")
@@ -123,7 +124,7 @@ def test_var_names_filter(var_args):
123
124
 
124
125
  def test_nonstring_var_names():
125
126
  """Check that non-string variables are preserved"""
126
- mu = TestRandomVariable("mu")
127
+ mu = RandomVariableTestClass("mu")
127
128
  samples = np.random.randn(10)
128
129
  data = dict_to_dataset({mu: samples})
129
130
  assert _var_names([mu], data) == [mu]
@@ -101,8 +101,8 @@ class TestDataNumPyro:
101
101
  assert not fails
102
102
 
103
103
  # test dims
104
- dims = inference_data.posterior_predictive.dims["school"]
105
- pred_dims = inference_data.predictions.dims["school_pred"]
104
+ dims = inference_data.posterior_predictive.sizes["school"]
105
+ pred_dims = inference_data.predictions.sizes["school_pred"]
106
106
  assert dims == 8
107
107
  assert pred_dims == 8
108
108
 
@@ -240,7 +240,7 @@ class TestDataNumPyro:
240
240
  def test_inference_data_num_chains(self, predictions_data, chains):
241
241
  predictions = predictions_data
242
242
  inference_data = from_numpyro(predictions=predictions, num_chains=chains)
243
- nchains = inference_data.predictions.dims["chain"]
243
+ nchains = inference_data.predictions.sizes["chain"]
244
244
  assert nchains == chains
245
245
 
246
246
  @pytest.mark.parametrize("nchains", [1, 2])
@@ -83,8 +83,8 @@ class TestDataPyro:
83
83
  assert not fails
84
84
 
85
85
  # test dims
86
- dims = inference_data.posterior_predictive.dims["school"]
87
- pred_dims = inference_data.predictions.dims["school_pred"]
86
+ dims = inference_data.posterior_predictive.sizes["school"]
87
+ pred_dims = inference_data.predictions.sizes["school_pred"]
88
88
  assert dims == 8
89
89
  assert pred_dims == 8
90
90
 
@@ -225,7 +225,7 @@ class TestDataPyro:
225
225
  def test_inference_data_num_chains(self, predictions_data, chains):
226
226
  predictions = predictions_data
227
227
  inference_data = from_pyro(predictions=predictions, num_chains=chains)
228
- nchains = inference_data.predictions.dims["chain"]
228
+ nchains = inference_data.predictions.sizes["chain"]
229
229
  assert nchains == chains
230
230
 
231
231
  @pytest.mark.parametrize("log_likelihood", [True, False])
arviz/tests/helpers.py CHANGED
@@ -18,7 +18,7 @@ from ..data import InferenceData, from_dict
18
18
  _log = logging.getLogger(__name__)
19
19
 
20
20
 
21
- class TestRandomVariable:
21
+ class RandomVariableTestClass:
22
22
  """Example class for random variables."""
23
23
 
24
24
  def __init__(self, name):
@@ -1,4 +1,5 @@
1
1
  """Sampling wrappers."""
2
+
2
3
  from .base import SamplingWrapper
3
4
  from .wrap_stan import PyStan2SamplingWrapper, PyStanSamplingWrapper, CmdStanPySamplingWrapper
4
5
  from .wrap_pymc import PyMCSamplingWrapper
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: arviz
3
- Version: 0.17.0
3
+ Version: 0.18.0
4
4
  Summary: Exploratory analysis of Bayesian models
5
5
  Home-page: http://github.com/arviz-devs/arviz
6
6
  Author: ArviZ Developers
@@ -12,22 +12,23 @@ Classifier: Intended Audience :: Education
12
12
  Classifier: License :: OSI Approved :: Apache Software License
13
13
  Classifier: Programming Language :: Python
14
14
  Classifier: Programming Language :: Python :: 3
15
- Classifier: Programming Language :: Python :: 3.9
16
15
  Classifier: Programming Language :: Python :: 3.10
17
16
  Classifier: Programming Language :: Python :: 3.11
17
+ Classifier: Programming Language :: Python :: 3.12
18
18
  Classifier: Topic :: Scientific/Engineering
19
19
  Classifier: Topic :: Scientific/Engineering :: Visualization
20
20
  Classifier: Topic :: Scientific/Engineering :: Mathematics
21
- Requires-Python: >=3.9
21
+ Requires-Python: >=3.10
22
22
  Description-Content-Type: text/markdown
23
23
  License-File: LICENSE
24
24
  Requires-Dist: setuptools >=60.0.0
25
25
  Requires-Dist: matplotlib >=3.5
26
- Requires-Dist: numpy <2.0,>=1.22.0
27
- Requires-Dist: scipy >=1.8.0
26
+ Requires-Dist: numpy <2.0,>=1.23.0
27
+ Requires-Dist: scipy >=1.9.0
28
28
  Requires-Dist: packaging
29
- Requires-Dist: pandas >=1.4.0
30
- Requires-Dist: xarray >=0.21.0
29
+ Requires-Dist: pandas >=1.5.0
30
+ Requires-Dist: dm-tree >=0.1.8
31
+ Requires-Dist: xarray >=2022.6.0
31
32
  Requires-Dist: h5netcdf >=1.0.2
32
33
  Requires-Dist: typing-extensions >=4.1.0
33
34
  Requires-Dist: xarray-einstats >=0.3
@@ -52,8 +53,7 @@ Requires-Dist: xarray-datatree ; extra == 'all'
52
53
  [![DOI](http://joss.theoj.org/papers/10.21105/joss.01143/status.svg)](https://doi.org/10.21105/joss.01143) [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.2540945.svg)](https://doi.org/10.5281/zenodo.2540945)
53
54
  [![Powered by NumFOCUS](https://img.shields.io/badge/powered%20by-NumFOCUS-orange.svg?style=flat&colorA=E1523D&colorB=007D8A)](https://numfocus.org)
54
55
 
55
- ArviZ (pronounced "AR-_vees_") is a Python package for exploratory analysis of Bayesian models.
56
- Includes functions for posterior analysis, data storage, model checking, comparison and diagnostics.
56
+ ArviZ (pronounced "AR-_vees_") is a Python package for exploratory analysis of Bayesian models. It includes functions for posterior analysis, data storage, model checking, comparison and diagnostics.
57
57
 
58
58
  ### ArviZ in other languages
59
59
  ArviZ also has a Julia wrapper available [ArviZ.jl](https://julia.arviz.org/).
@@ -202,6 +202,7 @@ python setup.py install
202
202
 
203
203
  <a href="https://python.arviz.org/en/latest/examples/index.html">And more...</a>
204
204
  </div>
205
+
205
206
  ## Dependencies
206
207
 
207
208
  ArviZ is tested on Python 3.10, 3.11 and 3.12, and depends on NumPy, SciPy, xarray, and Matplotlib.