arviz-0.19.0-py3-none-any.whl → arviz-0.21.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. arviz/__init__.py +1 -1
  2. arviz/data/base.py +20 -9
  3. arviz/data/converters.py +7 -3
  4. arviz/data/inference_data.py +28 -7
  5. arviz/plots/backends/__init__.py +8 -7
  6. arviz/plots/backends/bokeh/bpvplot.py +4 -3
  7. arviz/plots/backends/bokeh/densityplot.py +5 -1
  8. arviz/plots/backends/bokeh/dotplot.py +5 -2
  9. arviz/plots/backends/bokeh/essplot.py +4 -2
  10. arviz/plots/backends/bokeh/forestplot.py +11 -4
  11. arviz/plots/backends/bokeh/khatplot.py +4 -2
  12. arviz/plots/backends/bokeh/lmplot.py +9 -3
  13. arviz/plots/backends/bokeh/mcseplot.py +2 -2
  14. arviz/plots/backends/bokeh/pairplot.py +10 -5
  15. arviz/plots/backends/bokeh/ppcplot.py +2 -1
  16. arviz/plots/backends/bokeh/rankplot.py +2 -1
  17. arviz/plots/backends/bokeh/traceplot.py +2 -1
  18. arviz/plots/backends/bokeh/violinplot.py +2 -1
  19. arviz/plots/backends/matplotlib/bpvplot.py +2 -1
  20. arviz/plots/bfplot.py +9 -26
  21. arviz/plots/bpvplot.py +10 -1
  22. arviz/plots/compareplot.py +4 -4
  23. arviz/plots/ecdfplot.py +16 -8
  24. arviz/plots/forestplot.py +2 -2
  25. arviz/plots/hdiplot.py +5 -0
  26. arviz/plots/kdeplot.py +9 -2
  27. arviz/plots/plot_utils.py +5 -3
  28. arviz/preview.py +36 -5
  29. arviz/stats/__init__.py +1 -0
  30. arviz/stats/diagnostics.py +18 -14
  31. arviz/stats/ecdf_utils.py +157 -2
  32. arviz/stats/stats.py +99 -7
  33. arviz/tests/base_tests/test_data.py +41 -7
  34. arviz/tests/base_tests/test_diagnostics.py +5 -4
  35. arviz/tests/base_tests/test_plots_matplotlib.py +32 -13
  36. arviz/tests/base_tests/test_stats.py +11 -0
  37. arviz/tests/base_tests/test_stats_ecdf_utils.py +15 -2
  38. arviz/utils.py +4 -0
  39. {arviz-0.19.0.dist-info → arviz-0.21.0.dist-info}/METADATA +22 -22
  40. {arviz-0.19.0.dist-info → arviz-0.21.0.dist-info}/RECORD +43 -43
  41. {arviz-0.19.0.dist-info → arviz-0.21.0.dist-info}/WHEEL +1 -1
  42. {arviz-0.19.0.dist-info → arviz-0.21.0.dist-info}/LICENSE +0 -0
  43. {arviz-0.19.0.dist-info → arviz-0.21.0.dist-info}/top_level.txt +0 -0
arviz/plots/bpvplot.py CHANGED
@@ -16,6 +16,7 @@ def plot_bpv(
16
16
  bpv=True,
17
17
  plot_mean=True,
18
18
  reference="analytical",
19
+ smoothing=None,
19
20
  mse=False,
20
21
  n_ref=100,
21
22
  hdi_prob=0.94,
@@ -72,6 +73,9 @@ def plot_bpv(
72
73
  reference : {"analytical", "samples", None}, default "analytical"
73
74
  How to compute the distributions used as reference for ``kind=u_values``
74
75
  or ``kind=p_values``. Use `None` to not plot any reference.
76
+ smoothing : bool, optional
77
+ If True and the data has integer dtype, smooth the data before computing the p-values,
78
+ u-values or tstat. By default, True when `kind` is "u_value" and False otherwise.
75
79
  mse : bool, default False
76
80
  Show scaled mean square error between uniform distribution and marginal p_value
77
81
  distribution.
@@ -166,7 +170,8 @@ def plot_bpv(
166
170
  Notes
167
171
  -----
168
172
  Discrete data is smoothed before computing either p-values or u-values using the
169
- function :func:`~arviz.smooth_data`
173
+ function :func:`~arviz.smooth_data` if the data is integer type
174
+ and the smoothing parameter is True.
170
175
 
171
176
  Examples
172
177
  --------
@@ -206,6 +211,9 @@ def plot_bpv(
206
211
  elif not 1 >= hdi_prob > 0:
207
212
  raise ValueError("The value of hdi_prob should be in the interval (0, 1]")
208
213
 
214
+ if smoothing is None:
215
+ smoothing = kind.lower() == "u_value"
216
+
209
217
  if data_pairs is None:
210
218
  data_pairs = {}
211
219
 
@@ -291,6 +299,7 @@ def plot_bpv(
291
299
  plot_ref_kwargs=plot_ref_kwargs,
292
300
  backend_kwargs=backend_kwargs,
293
301
  show=show,
302
+ smoothing=smoothing,
294
303
  )
295
304
 
296
305
  # TODO: Add backend kwargs
@@ -11,9 +11,9 @@ def plot_compare(
11
11
  comp_df,
12
12
  insample_dev=False,
13
13
  plot_standard_error=True,
14
- plot_ic_diff=True,
14
+ plot_ic_diff=False,
15
15
  order_by_rank=True,
16
- legend=True,
16
+ legend=False,
17
17
  title=True,
18
18
  figsize=None,
19
19
  textsize=None,
@@ -45,12 +45,12 @@ def plot_compare(
45
45
  penalization given by the effective number of parameters (p_loo or p_waic).
46
46
  plot_standard_error : bool, default True
47
47
  Plot the standard error of the ELPD.
48
- plot_ic_diff : bool, default True
48
+ plot_ic_diff : bool, default False
49
49
  Plot standard error of the difference in ELPD between each model
50
50
  and the top-ranked model.
51
51
  order_by_rank : bool, default True
52
52
  If True ensure the best model is used as reference.
53
- legend : bool, default True
53
+ legend : bool, default False
54
54
  Add legend to figure.
55
55
  figsize : (float, float), optional
56
56
  If `None`, size is (6, num of models) inches.
arviz/plots/ecdfplot.py CHANGED
@@ -73,6 +73,7 @@ def plot_ecdf(
73
73
  - False: No confidence bands are plotted (default).
74
74
  - True: Plot bands computed with the default algorithm (subject to change)
75
75
  - "pointwise": Compute the pointwise (i.e. marginal) confidence band.
76
+ - "optimized": Use optimization to estimate a simultaneous confidence band.
76
77
  - "simulated": Use Monte Carlo simulation to estimate a simultaneous confidence
77
78
  band.
78
79
 
@@ -216,8 +217,7 @@ def plot_ecdf(
216
217
  >>> pit_vals = distribution.cdf(sample)
217
218
  >>> uniform_dist = uniform(0, 1)
218
219
  >>> az.plot_ecdf(
219
- >>> pit_vals, cdf=uniform_dist.cdf,
220
- >>> rvs=uniform_dist.rvs, confidence_bands=True
220
+ >>> pit_vals, cdf=uniform_dist.cdf, confidence_bands=True,
221
221
  >>> )
222
222
 
223
223
  Plot an ECDF-difference plot of PIT values.
@@ -226,8 +226,8 @@ def plot_ecdf(
226
226
  :context: close-figs
227
227
 
228
228
  >>> az.plot_ecdf(
229
- >>> pit_vals, cdf = uniform_dist.cdf, rvs = uniform_dist.rvs,
230
- >>> confidence_bands = True, difference = True
229
+ >>> pit_vals, cdf = uniform_dist.cdf, confidence_bands = True,
230
+ >>> difference = True
231
231
  >>> )
232
232
  """
233
233
  if confidence_bands is True:
@@ -238,9 +238,12 @@ def plot_ecdf(
238
238
  )
239
239
  confidence_bands = "pointwise"
240
240
  else:
241
- confidence_bands = "simulated"
242
- elif confidence_bands == "simulated" and pointwise:
243
- raise ValueError("Cannot specify both `confidence_bands='simulated'` and `pointwise=True`")
241
+ confidence_bands = "auto"
242
+ # if pointwise specified, confidence_bands must be a bool or 'pointwise'
243
+ elif confidence_bands not in [False, "pointwise"] and pointwise:
244
+ raise ValueError(
245
+ f"Cannot specify both `confidence_bands='{confidence_bands}'` and `pointwise=True`"
246
+ )
244
247
 
245
248
  if fpr is not None:
246
249
  warnings.warn(
@@ -298,7 +301,7 @@ def plot_ecdf(
298
301
  "`eval_points` explicitly.",
299
302
  BehaviourChangeWarning,
300
303
  )
301
- if confidence_bands == "simulated":
304
+ if confidence_bands in ["optimized", "simulated"]:
302
305
  warnings.warn(
303
306
  "For simultaneous bands to be correctly calibrated, specify `eval_points` "
304
307
  "independent of the `values`"
@@ -319,6 +322,11 @@ def plot_ecdf(
319
322
 
320
323
  if confidence_bands:
321
324
  ndraws = len(values)
325
+ if confidence_bands == "auto":
326
+ if ndraws < 200 or num_trials >= 250 * np.sqrt(ndraws):
327
+ confidence_bands = "optimized"
328
+ else:
329
+ confidence_bands = "simulated"
322
330
  x_bands = eval_points
323
331
  lower, higher = ecdf_confidence_band(
324
332
  ndraws,
arviz/plots/forestplot.py CHANGED
@@ -55,8 +55,8 @@ def plot_forest(
55
55
  Specify the kind of plot:
56
56
 
57
57
  * The ``kind="forestplot"`` generates credible intervals, where the central points are the
58
- estimated posterior means, the thick lines are the central quartiles, and the thin lines
59
- represent the :math:`100\times`(`hdi_prob`)% highest density intervals.
58
+ estimated posterior median, the thick lines are the central quartiles, and the thin lines
59
+ represent the :math:`100\times(hdi\_prob)\%` highest density intervals.
60
60
  * The ``kind="ridgeplot"`` option generates density plots (kernel density estimate or
61
61
  histograms) in the same graph. Ridge plots can be configured to have different overlap,
62
62
  truncation bounds and quantile markers.
arviz/plots/hdiplot.py CHANGED
@@ -136,6 +136,11 @@ def plot_hdi(
136
136
  x = np.asarray(x)
137
137
  x_shape = x.shape
138
138
 
139
+ if isinstance(x[0], str):
140
+ raise NotImplementedError(
141
+ "The `arviz.plot_hdi()` function does not support categorical data. "
142
+ "Consider using `arviz.plot_forest()`."
143
+ )
139
144
  if y is None and hdi_data is None:
140
145
  raise ValueError("One of {y, hdi_data} is required")
141
146
  if hdi_data is not None and y is not None:
arviz/plots/kdeplot.py CHANGED
@@ -72,7 +72,7 @@ def plot_kde(
72
72
  If True plot the 2D KDE using contours, otherwise plot a smooth 2D KDE.
73
73
  hdi_probs : list, optional
74
74
  Plots highest density credibility regions for the provided probabilities for a 2D KDE.
75
- Defaults to matplotlib chosen levels with no fixed probability associated.
75
+ Defaults to [0.5, 0.8, 0.94].
76
76
  fill_last : bool, default False
77
77
  If True fill the last contour of the 2D KDE plot.
78
78
  figsize : (float, float), optional
@@ -270,6 +270,9 @@ def plot_kde(
270
270
  gridsize = (128, 128) if contour else (256, 256)
271
271
  density, xmin, xmax, ymin, ymax = _fast_kde_2d(values, values2, gridsize=gridsize)
272
272
 
273
+ if hdi_probs is None:
274
+ hdi_probs = [0.5, 0.8, 0.94]
275
+
273
276
  if hdi_probs is not None:
274
277
  # Check hdi probs are within bounds (0, 1)
275
278
  if min(hdi_probs) <= 0 or max(hdi_probs) >= 1:
@@ -289,7 +292,11 @@ def plot_kde(
289
292
  "Using 'hdi_probs' in favor of 'levels'.",
290
293
  UserWarning,
291
294
  )
292
- contour_kwargs["levels"] = contour_level_list
295
+
296
+ if backend == "bokeh":
297
+ contour_kwargs["levels"] = contour_level_list
298
+ elif backend == "matplotlib":
299
+ contour_kwargs["levels"] = contour_level_list[1:]
293
300
 
294
301
  contourf_kwargs = _init_kwargs_dict(contourf_kwargs)
295
302
  if "levels" in contourf_kwargs:
arviz/plots/plot_utils.py CHANGED
@@ -482,16 +482,18 @@ def plot_point_interval(
482
482
  if point_estimate:
483
483
  point_value = calculate_point_estimate(point_estimate, values)
484
484
  if rotated:
485
- ax.circle(
485
+ ax.scatter(
486
486
  x=0,
487
487
  y=point_value,
488
+ marker="circle",
488
489
  size=markersize,
489
490
  fill_color=markercolor,
490
491
  )
491
492
  else:
492
- ax.circle(
493
+ ax.scatter(
493
494
  x=point_value,
494
495
  y=0,
496
+ marker="circle",
495
497
  size=markersize,
496
498
  fill_color=markercolor,
497
499
  )
@@ -534,7 +536,7 @@ def set_bokeh_circular_ticks_labels(ax, hist, labels):
534
536
  )
535
537
 
536
538
  radii_circles = np.linspace(0, np.max(hist) * 1.1, 4)
537
- ax.circle(0, 0, radius=radii_circles, fill_color=None, line_color="grey")
539
+ ax.scatter(0, 0, marker="circle", radius=radii_circles, fill_color=None, line_color="grey")
538
540
 
539
541
  offset = np.max(hist * 1.05) * 0.15
540
542
  ticks_labels_pos_1 = np.max(hist * 1.05)
arviz/preview.py CHANGED
@@ -1,17 +1,48 @@
1
- # pylint: disable=unused-import,unused-wildcard-import,wildcard-import
1
+ # pylint: disable=unused-import,unused-wildcard-import,wildcard-import,invalid-name
2
2
  """Expose features from arviz-xyz refactored packages inside ``arviz.preview`` namespace."""
3
+ import logging
4
+
5
+ _log = logging.getLogger(__name__)
6
+
7
+ info = ""
3
8
 
4
9
  try:
5
10
  from arviz_base import *
11
+
12
+ status = "arviz_base available, exposing its functions as part of arviz.preview"
13
+ _log.info(status)
6
14
  except ModuleNotFoundError:
7
- pass
15
+ status = "arviz_base not installed"
16
+ _log.info(status)
17
+ except ImportError:
18
+ status = "Unable to import arviz_base"
19
+ _log.info(status, exc_info=True)
20
+
21
+ info += status + "\n"
8
22
 
9
23
  try:
10
- import arviz_stats
24
+ from arviz_stats import *
25
+
26
+ status = "arviz_stats available, exposing its functions as part of arviz.preview"
27
+ _log.info(status)
11
28
  except ModuleNotFoundError:
12
- pass
29
+ status = "arviz_stats not installed"
30
+ _log.info(status)
31
+ except ImportError:
32
+ status = "Unable to import arviz_stats"
33
+ _log.info(status, exc_info=True)
34
+ info += status + "\n"
13
35
 
14
36
  try:
15
37
  from arviz_plots import *
38
+
39
+ status = "arviz_plots available, exposing its functions as part of arviz.preview"
40
+ _log.info(status)
16
41
  except ModuleNotFoundError:
17
- pass
42
+ status = "arviz_plots not installed"
43
+ _log.info(status)
44
+ except ImportError:
45
+ status = "Unable to import arviz_plots"
46
+ _log.info(status, exc_info=True)
47
+
48
+ info += status + "\n"
arviz/stats/__init__.py CHANGED
@@ -9,6 +9,7 @@ from .stats_utils import *
9
9
 
10
10
  __all__ = [
11
11
  "apply_test_function",
12
+ "bayes_factor",
12
13
  "bfmi",
13
14
  "compare",
14
15
  "hdi",
@@ -744,8 +744,8 @@ def _ess_sd(ary, relative=False):
744
744
  ary = np.asarray(ary)
745
745
  if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
746
746
  return np.nan
747
- ary = _split_chains(ary)
748
- return min(_ess(ary, relative=relative), _ess(ary**2, relative=relative))
747
+ ary = (ary - ary.mean()) ** 2
748
+ return _ess(_split_chains(ary), relative=relative)
749
749
 
750
750
 
751
751
  def _ess_quantile(ary, prob, relative=False):
@@ -838,13 +838,15 @@ def _mcse_sd(ary):
838
838
  ary = np.asarray(ary)
839
839
  if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
840
840
  return np.nan
841
- ess = _ess_sd(ary)
841
+ sims_c2 = (ary - ary.mean()) ** 2
842
+ ess = _ess_mean(sims_c2)
843
+ evar = (sims_c2).mean()
844
+ varvar = ((sims_c2**2).mean() - evar**2) / ess
845
+ varsd = varvar / evar / 4
842
846
  if _numba_flag:
843
- sd = float(_sqrt(svar(np.ravel(ary), ddof=1), np.zeros(1)).item())
847
+ mcse_sd_value = float(_sqrt(np.ravel(varsd), np.zeros(1)))
844
848
  else:
845
- sd = np.std(ary, ddof=1)
846
- fac_mcse_sd = np.sqrt(np.exp(1) * (1 - 1 / ess) ** (ess - 1) - 1)
847
- mcse_sd_value = sd * fac_mcse_sd
849
+ mcse_sd_value = np.sqrt(varsd)
848
850
  return mcse_sd_value
849
851
 
850
852
 
@@ -973,19 +975,21 @@ def _multichain_statistics(ary, focus="mean"):
973
975
  # ess mean
974
976
  ess_mean_value = _ess_mean(ary)
975
977
 
976
- # ess sd
977
- ess_sd_value = _ess_sd(ary)
978
-
979
978
  # mcse_mean
980
- sd = np.std(ary, ddof=1)
981
- mcse_mean_value = sd / np.sqrt(ess_mean_value)
979
+ sims_c2 = (ary - ary.mean()) ** 2
980
+ sims_c2_sum = sims_c2.sum()
981
+ var = sims_c2_sum / (sims_c2.size - 1)
982
+ mcse_mean_value = np.sqrt(var / ess_mean_value)
982
983
 
983
984
  # ess bulk
984
985
  ess_bulk_value = _ess(z_split)
985
986
 
986
987
  # mcse_sd
987
- fac_mcse_sd = np.sqrt(np.exp(1) * (1 - 1 / ess_sd_value) ** (ess_sd_value - 1) - 1)
988
- mcse_sd_value = sd * fac_mcse_sd
988
+ evar = sims_c2_sum / sims_c2.size
989
+ ess_mean_sims = _ess_mean(sims_c2)
990
+ varvar = ((sims_c2**2).mean() - evar**2) / ess_mean_sims
991
+ varsd = varvar / evar / 4
992
+ mcse_sd_value = np.sqrt(varsd)
989
993
 
990
994
  return (
991
995
  mcse_mean_value,
arviz/stats/ecdf_utils.py CHANGED
@@ -1,10 +1,25 @@
1
1
  """Functions for evaluating ECDFs and their confidence bands."""
2
2
 
3
+ import math
3
4
  from typing import Any, Callable, Optional, Tuple
4
5
  import warnings
5
6
 
6
7
  import numpy as np
7
8
  from scipy.stats import uniform, binom
9
+ from scipy.optimize import minimize_scalar
10
+
11
+ try:
12
+ from numba import jit, vectorize
13
+ except ImportError:
14
+
15
+ def jit(*args, **kwargs): # pylint: disable=unused-argument
16
+ return lambda f: f
17
+
18
+ def vectorize(*args, **kwargs): # pylint: disable=unused-argument
19
+ return lambda f: f
20
+
21
+
22
+ from ..utils import Numba
8
23
 
9
24
 
10
25
  def compute_ecdf(sample: np.ndarray, eval_points: np.ndarray) -> np.ndarray:
@@ -73,7 +88,7 @@ def ecdf_confidence_band(
73
88
  eval_points: np.ndarray,
74
89
  cdf_at_eval_points: np.ndarray,
75
90
  prob: float = 0.95,
76
- method="simulated",
91
+ method="optimized",
77
92
  **kwargs,
78
93
  ) -> Tuple[np.ndarray, np.ndarray]:
79
94
  """Compute the `prob`-level confidence band for the ECDF.
@@ -92,6 +107,7 @@ def ecdf_confidence_band(
92
107
  method : string, default "simulated"
93
108
  The method used to compute the confidence band. Valid options are:
94
109
  - "pointwise": Compute the pointwise (i.e. marginal) confidence band.
110
+ - "optimized": Use optimization to estimate a simultaneous confidence band.
95
111
  - "simulated": Use Monte Carlo simulation to estimate a simultaneous confidence band.
96
112
  `rvs` must be provided.
97
113
  rvs: callable, optional
@@ -115,12 +131,18 @@ def ecdf_confidence_band(
115
131
 
116
132
  if method == "pointwise":
117
133
  prob_pointwise = prob
134
+ elif method == "optimized":
135
+ prob_pointwise = _optimize_simultaneous_ecdf_band_probability(
136
+ ndraws, eval_points, cdf_at_eval_points, prob=prob, **kwargs
137
+ )
118
138
  elif method == "simulated":
119
139
  prob_pointwise = _simulate_simultaneous_ecdf_band_probability(
120
140
  ndraws, eval_points, cdf_at_eval_points, prob=prob, **kwargs
121
141
  )
122
142
  else:
123
- raise ValueError(f"Unknown method {method}. Valid options are 'pointwise' or 'simulated'.")
143
+ raise ValueError(
144
+ f"Unknown method {method}. Valid options are 'pointwise', 'optimized', or 'simulated'."
145
+ )
124
146
 
125
147
  prob_lower, prob_upper = _get_pointwise_confidence_band(
126
148
  prob_pointwise, ndraws, cdf_at_eval_points
@@ -129,6 +151,139 @@ def ecdf_confidence_band(
129
151
  return prob_lower, prob_upper
130
152
 
131
153
 
154
+ def _update_ecdf_band_interior_probabilities(
155
+ prob_left: np.ndarray,
156
+ interval_left: np.ndarray,
157
+ interval_right: np.ndarray,
158
+ p: float,
159
+ ndraws: int,
160
+ ) -> np.ndarray:
161
+ """Update the probability that an ECDF has been within the envelope including at the current
162
+ point.
163
+
164
+ Arguments
165
+ ---------
166
+ prob_left : np.ndarray
167
+ For each point in the interior at the previous point, the joint probability that it and all
168
+ points before are in the interior.
169
+ interval_left : np.ndarray
170
+ The set of points in the interior at the previous point.
171
+ interval_right : np.ndarray
172
+ The set of points in the interior at the current point.
173
+ p : float
174
+ The probability of any given point found between the previous point and the current one.
175
+ ndraws : int
176
+ Number of draws in the original dataset.
177
+
178
+ Returns
179
+ -------
180
+ prob_right : np.ndarray
181
+ For each point in the interior at the current point, the joint probability that it and all
182
+ previous points are in the interior.
183
+ """
184
+ interval_left = interval_left[:, np.newaxis]
185
+ prob_conditional = binom.pmf(interval_right, ndraws - interval_left, p, loc=interval_left)
186
+ prob_right = prob_left.dot(prob_conditional)
187
+ return prob_right
188
+
189
+
190
+ @vectorize(["float64(int64, int64, float64, int64)"])
191
+ def _binom_pmf(k, n, p, loc):
192
+ k -= loc
193
+ if k < 0 or k > n:
194
+ return 0.0
195
+ if p == 0:
196
+ return 1.0 if k == 0 else 0.0
197
+ if p == 1:
198
+ return 1.0 if k == n else 0.0
199
+ if k == 0:
200
+ return (1 - p) ** n
201
+ if k == n:
202
+ return p**n
203
+ lbinom = math.lgamma(n + 1) - math.lgamma(k + 1) - math.lgamma(n - k + 1)
204
+ return np.exp(lbinom + k * np.log(p) + (n - k) * np.log1p(-p))
205
+
206
+
207
+ @jit(nopython=True)
208
+ def _update_ecdf_band_interior_probabilities_numba(
209
+ prob_left: np.ndarray,
210
+ interval_left: np.ndarray,
211
+ interval_right: np.ndarray,
212
+ p: float,
213
+ ndraws: int,
214
+ ) -> np.ndarray:
215
+ interval_left = interval_left[:, np.newaxis]
216
+ prob_conditional = _binom_pmf(interval_right, ndraws - interval_left, p, interval_left)
217
+ prob_right = prob_left.dot(prob_conditional)
218
+ return prob_right
219
+
220
+
221
+ def _ecdf_band_interior_probability(prob_between_points, ndraws, lower_count, upper_count):
222
+ interval_left = np.arange(1)
223
+ prob_interior = np.ones(1)
224
+ for i in range(prob_between_points.shape[0]):
225
+ interval_right = np.arange(lower_count[i], upper_count[i])
226
+ prob_interior = _update_ecdf_band_interior_probabilities(
227
+ prob_interior, interval_left, interval_right, prob_between_points[i], ndraws
228
+ )
229
+ interval_left = interval_right
230
+ return prob_interior.sum()
231
+
232
+
233
+ @jit(nopython=True)
234
+ def _ecdf_band_interior_probability_numba(prob_between_points, ndraws, lower_count, upper_count):
235
+ interval_left = np.arange(1)
236
+ prob_interior = np.ones(1)
237
+ for i in range(prob_between_points.shape[0]):
238
+ interval_right = np.arange(lower_count[i], upper_count[i])
239
+ prob_interior = _update_ecdf_band_interior_probabilities_numba(
240
+ prob_interior, interval_left, interval_right, prob_between_points[i], ndraws
241
+ )
242
+ interval_left = interval_right
243
+ return prob_interior.sum()
244
+
245
+
246
+ def _ecdf_band_optimization_objective(
247
+ prob_pointwise: float,
248
+ cdf_at_eval_points: np.ndarray,
249
+ ndraws: int,
250
+ prob_target: float,
251
+ ) -> float:
252
+ """Objective function for optimizing the simultaneous confidence band probability."""
253
+ lower, upper = _get_pointwise_confidence_band(prob_pointwise, ndraws, cdf_at_eval_points)
254
+ lower_count = (lower * ndraws).astype(int)
255
+ upper_count = (upper * ndraws).astype(int) + 1
256
+ cdf_with_zero = np.insert(cdf_at_eval_points[:-1], 0, 0)
257
+ prob_between_points = (cdf_at_eval_points - cdf_with_zero) / (1 - cdf_with_zero)
258
+ if Numba.numba_flag:
259
+ prob_interior = _ecdf_band_interior_probability_numba(
260
+ prob_between_points, ndraws, lower_count, upper_count
261
+ )
262
+ else:
263
+ prob_interior = _ecdf_band_interior_probability(
264
+ prob_between_points, ndraws, lower_count, upper_count
265
+ )
266
+ return abs(prob_interior - prob_target)
267
+
268
+
269
+ def _optimize_simultaneous_ecdf_band_probability(
270
+ ndraws: int,
271
+ eval_points: np.ndarray, # pylint: disable=unused-argument
272
+ cdf_at_eval_points: np.ndarray,
273
+ prob: float = 0.95,
274
+ **kwargs, # pylint: disable=unused-argument
275
+ ):
276
+ """Estimate probability for simultaneous confidence band using optimization.
277
+
278
+ This function simulates the pointwise probability needed to construct pointwise confidence bands
279
+ that form a `prob`-level confidence envelope for the ECDF of a sample.
280
+ """
281
+ cdf_at_eval_points = np.unique(cdf_at_eval_points)
282
+ objective = lambda p: _ecdf_band_optimization_objective(p, cdf_at_eval_points, ndraws, prob)
283
+ prob_pointwise = minimize_scalar(objective, bounds=(prob, 1), method="bounded").x
284
+ return prob_pointwise
285
+
286
+
132
287
  def _simulate_simultaneous_ecdf_band_probability(
133
288
  ndraws: int,
134
289
  eval_points: np.ndarray,