arviz 0.16.1__py3-none-any.whl → 0.17.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60):
  1. arviz/__init__.py +1 -1
  2. arviz/data/inference_data.py +34 -7
  3. arviz/data/io_beanmachine.py +6 -1
  4. arviz/data/io_cmdstanpy.py +439 -50
  5. arviz/data/io_pyjags.py +5 -2
  6. arviz/data/io_pystan.py +1 -2
  7. arviz/labels.py +2 -0
  8. arviz/plots/backends/bokeh/bpvplot.py +7 -2
  9. arviz/plots/backends/bokeh/compareplot.py +7 -4
  10. arviz/plots/backends/bokeh/densityplot.py +0 -1
  11. arviz/plots/backends/bokeh/distplot.py +0 -2
  12. arviz/plots/backends/bokeh/forestplot.py +3 -5
  13. arviz/plots/backends/bokeh/kdeplot.py +0 -2
  14. arviz/plots/backends/bokeh/pairplot.py +0 -4
  15. arviz/plots/backends/matplotlib/bfplot.py +0 -1
  16. arviz/plots/backends/matplotlib/bpvplot.py +3 -3
  17. arviz/plots/backends/matplotlib/compareplot.py +1 -1
  18. arviz/plots/backends/matplotlib/dotplot.py +1 -1
  19. arviz/plots/backends/matplotlib/forestplot.py +2 -4
  20. arviz/plots/backends/matplotlib/kdeplot.py +0 -1
  21. arviz/plots/backends/matplotlib/khatplot.py +0 -1
  22. arviz/plots/backends/matplotlib/lmplot.py +4 -5
  23. arviz/plots/backends/matplotlib/pairplot.py +0 -1
  24. arviz/plots/backends/matplotlib/ppcplot.py +8 -5
  25. arviz/plots/backends/matplotlib/traceplot.py +1 -2
  26. arviz/plots/bfplot.py +7 -6
  27. arviz/plots/bpvplot.py +7 -2
  28. arviz/plots/compareplot.py +2 -2
  29. arviz/plots/ecdfplot.py +37 -112
  30. arviz/plots/elpdplot.py +1 -1
  31. arviz/plots/essplot.py +2 -2
  32. arviz/plots/kdeplot.py +0 -1
  33. arviz/plots/pairplot.py +1 -1
  34. arviz/plots/plot_utils.py +0 -1
  35. arviz/plots/ppcplot.py +51 -45
  36. arviz/plots/separationplot.py +0 -1
  37. arviz/stats/__init__.py +2 -0
  38. arviz/stats/density_utils.py +2 -2
  39. arviz/stats/diagnostics.py +2 -3
  40. arviz/stats/ecdf_utils.py +165 -0
  41. arviz/stats/stats.py +241 -38
  42. arviz/stats/stats_utils.py +36 -7
  43. arviz/tests/base_tests/test_data.py +73 -5
  44. arviz/tests/base_tests/test_plots_bokeh.py +0 -1
  45. arviz/tests/base_tests/test_plots_matplotlib.py +24 -1
  46. arviz/tests/base_tests/test_stats.py +43 -1
  47. arviz/tests/base_tests/test_stats_ecdf_utils.py +153 -0
  48. arviz/tests/base_tests/test_stats_utils.py +3 -3
  49. arviz/tests/external_tests/test_data_beanmachine.py +2 -0
  50. arviz/tests/external_tests/test_data_numpyro.py +3 -3
  51. arviz/tests/external_tests/test_data_pyjags.py +3 -1
  52. arviz/tests/external_tests/test_data_pyro.py +3 -3
  53. arviz/tests/helpers.py +8 -8
  54. arviz/utils.py +15 -7
  55. arviz/wrappers/wrap_pymc.py +1 -1
  56. {arviz-0.16.1.dist-info → arviz-0.17.1.dist-info}/METADATA +16 -15
  57. {arviz-0.16.1.dist-info → arviz-0.17.1.dist-info}/RECORD +60 -58
  58. {arviz-0.16.1.dist-info → arviz-0.17.1.dist-info}/WHEEL +1 -1
  59. {arviz-0.16.1.dist-info → arviz-0.17.1.dist-info}/LICENSE +0 -0
  60. {arviz-0.16.1.dist-info → arviz-0.17.1.dist-info}/top_level.txt +0 -0
arviz/plots/kdeplot.py CHANGED
@@ -256,7 +256,6 @@ def plot_kde(
256
256
  )
257
257
 
258
258
  if values2 is None:
259
-
260
259
  if bw == "default":
261
260
  bw = "taylor" if is_circular else "experimental"
262
261
 
arviz/plots/pairplot.py CHANGED
@@ -229,7 +229,7 @@ def plot_pair(
229
229
  )
230
230
 
231
231
  if gridsize == "auto":
232
- gridsize = int(dataset.dims["draw"] ** 0.35)
232
+ gridsize = int(dataset.sizes["draw"] ** 0.35)
233
233
 
234
234
  numvars = len(flat_var_names)
235
235
 
arviz/plots/plot_utils.py CHANGED
@@ -364,7 +364,6 @@ def calculate_point_estimate(point_estimate, values, bw="default", circular=Fals
364
364
  else:
365
365
  point_value = int(mode(values).mode)
366
366
  elif point_estimate == "median":
367
-
368
367
  point_value = np.nanmedian(values) if skipna else np.median(values)
369
368
  return point_value
370
369
 
arviz/plots/ppcplot.py CHANGED
@@ -19,7 +19,7 @@ def plot_ppc(
19
19
  kind="kde",
20
20
  alpha=None,
21
21
  mean=True,
22
- observed=True,
22
+ observed=None,
23
23
  observed_rug=False,
24
24
  color=None,
25
25
  colors=None,
@@ -50,37 +50,35 @@ def plot_ppc(
50
50
 
51
51
  Parameters
52
52
  ----------
53
- data: az.InferenceData object
53
+ data : InferenceData
54
54
  :class:`arviz.InferenceData` object containing the observed and posterior/prior
55
55
  predictive data.
56
- kind: str
57
- Type of plot to display ("kde", "cumulative", or "scatter"). Defaults to `kde`.
58
- alpha: float
56
+ kind : str, default "kde"
57
+ Type of plot to display ("kde", "cumulative", or "scatter").
58
+ alpha : float, optional
59
59
  Opacity of posterior/prior predictive density curves.
60
60
  Defaults to 0.2 for ``kind = kde`` and cumulative, for scatter defaults to 0.7.
61
- mean: bool
61
+ mean : bool, default True
62
62
  Whether or not to plot the mean posterior/prior predictive distribution.
63
- Defaults to ``True``.
64
- observed: bool, default True
65
- Whether or not to plot the observed data.
66
- observed_rug: bool, default False
63
+ observed : bool, optional
64
+ Whether or not to plot the observed data. Defaults to True for ``group = posterior``
65
+ and False for ``group = prior``.
66
+ observed_rug : bool, default False
67
67
  Whether or not to plot a rug plot for the observed data. Only valid if `observed` is
68
68
  `True` and for kind `kde` or `cumulative`.
69
- color: str
70
- Valid matplotlib ``color``. Defaults to ``C0``.
71
- color: list
69
+ color : list, optional
72
70
  List with valid matplotlib colors corresponding to the posterior/prior predictive
73
71
  distribution, observed data and mean of the posterior/prior predictive distribution.
74
72
  Defaults to ["C0", "k", "C1"].
75
- grid : tuple
73
+ grid : tuple, optional
76
74
  Number of rows and columns. Defaults to None, the rows and columns are
77
75
  automatically inferred.
78
- figsize: tuple
76
+ figsize : tuple, optional
79
77
  Figure size. If None, it will be defined automatically.
80
- textsize: float
78
+ textsize : float, optional
81
79
  Text size scaling factor for labels, titles and lines. If None, it will be
82
80
  autoscaled based on ``figsize``.
83
- data_pairs: dict
81
+ data_pairs : dict, optional
84
82
  Dictionary containing relations between observed data and posterior/prior predictive data.
85
83
  Dictionary structure:
86
84
 
@@ -90,84 +88,86 @@ def plot_ppc(
90
88
  For example, ``data_pairs = {'y' : 'y_hat'}``
91
89
  If None, it will assume that the observed data and the posterior/prior
92
90
  predictive data have the same variable name.
93
- var_names: list of variable names
91
+ var_names : list of str, optional
94
92
  Variables to be plotted, if `None` all variable are plotted. Prefix the
95
93
  variables by ``~`` when you want to exclude them from the plot.
96
- filter_vars: {None, "like", "regex"}, optional, default=None
94
+ filter_vars : {None, "like", "regex"}, default None
97
95
  If `None` (default), interpret var_names as the real variables names. If "like",
98
96
  interpret var_names as substrings of the real variables names. If "regex",
99
97
  interpret var_names as regular expressions on the real variables names. A la
100
98
  ``pandas.filter``.
101
- coords: dict
99
+ coords : dict, optional
102
100
  Dictionary mapping dimensions to selected coordinates to be plotted.
103
101
  Dimensions without a mapping specified will include all coordinates for
104
102
  that dimension. Defaults to including all coordinates for all
105
103
  dimensions if None.
106
- flatten: list
104
+ flatten : list
107
105
  List of dimensions to flatten in ``observed_data``. Only flattens across the coordinates
108
106
  specified in the ``coords`` argument. Defaults to flattening all of the dimensions.
109
- flatten_pp: list
107
+ flatten_pp : list
110
108
  List of dimensions to flatten in posterior_predictive/prior_predictive. Only flattens
111
109
  across the coordinates specified in the ``coords`` argument. Defaults to flattening all
112
110
  of the dimensions. Dimensions should match flatten excluding dimensions for ``data_pairs``
113
111
  parameters. If ``flatten`` is defined and ``flatten_pp`` is None, then
114
112
  ``flatten_pp = flatten``.
115
- num_pp_samples: int
113
+ num_pp_samples : int
116
114
  The number of posterior/prior predictive samples to plot. For ``kind`` = 'scatter' and
117
115
  ``animation = False`` if defaults to a maximum of 5 samples and will set jitter to 0.7.
118
116
  unless defined. Otherwise it defaults to all provided samples.
119
- random_seed: int
117
+ random_seed : int
120
118
  Random number generator seed passed to ``numpy.random.seed`` to allow
121
119
  reproducibility of the plot. By default, no seed will be provided
122
120
  and the plot will change each call if a random sample is specified
123
121
  by ``num_pp_samples``.
124
- jitter: float
122
+ jitter : float, default 0
125
123
  If ``kind`` is "scatter", jitter will add random uniform noise to the height
126
- of the ppc samples and observed data. By default 0.
127
- animated: bool
124
+ of the ppc samples and observed data.
125
+ animated : bool, default False
128
126
  Create an animation of one posterior/prior predictive sample per frame.
129
- Defaults to ``False``. Only works with matploblib backend.
127
+ Only works with matploblib backend.
130
128
  To run animations inside a notebook you have to use the `nbAgg` matplotlib's backend.
131
129
  Try with `%matplotlib notebook` or `%matplotlib nbAgg`. You can switch back to the
132
130
  default matplotlib's backend with `%matplotlib inline` or `%matplotlib auto`.
133
131
  If switching back and forth between matplotlib's backend, you may need to run twice the cell
134
132
  with the animation.
135
133
  If you experience problems rendering the animation try setting
136
- `animation_kwargs({'blit':False}`) or changing the matplotlib's backend (e.g. to TkAgg)
137
- If you run the animation from a script write `ax, ani = az.plot_ppc(.)`
134
+ ``animation_kwargs({'blit':False})`` or changing the matplotlib's backend (e.g. to TkAgg)
135
+ If you run the animation from a script write ``ax, ani = az.plot_ppc(.)``
138
136
  animation_kwargs : dict
139
137
  Keywords passed to :class:`matplotlib.animation.FuncAnimation`. Ignored with
140
138
  matplotlib backend.
141
- legend : bool
142
- Add legend to figure. By default ``True``.
143
- labeller : labeller instance, optional
139
+ legend : bool, default True
140
+ Add legend to figure.
141
+ labeller : labeller, optional
144
142
  Class providing the method ``make_pp_label`` to generate the labels in the plot titles.
145
143
  Read the :ref:`label_guide` for more details and usage examples.
146
- ax: numpy array-like of matplotlib axes or bokeh figures, optional
144
+ ax : numpy array-like of matplotlib_axes or bokeh figures, optional
147
145
  A 2D array of locations into which to plot the densities. If not supplied, Arviz will create
148
146
  its own array of plot areas (and return it).
149
- backend: str, optional
147
+ backend : str, optional
150
148
  Select plotting backend {"matplotlib","bokeh"}. Default to "matplotlib".
151
- backend_kwargs: bool, optional
149
+ backend_kwargs : dict, optional
152
150
  These are kwargs specific to the backend being used, passed to
153
151
  :func:`matplotlib.pyplot.subplots` or :func:`bokeh.plotting.figure`.
154
152
  For additional documentation check the plotting method of the backend.
155
- group: {"prior", "posterior"}, optional
153
+ group : {"prior", "posterior"}, optional
156
154
  Specifies which InferenceData group should be plotted. Defaults to 'posterior'.
157
155
  Other value can be 'prior'.
158
- show: bool, optional
156
+ show : bool, optional
159
157
  Call backend show function.
160
158
 
161
159
  Returns
162
160
  -------
163
- axes: matplotlib axes or bokeh figures
161
+ axes : matplotlib_axes or bokeh_figures
162
+ ani : matplotlib.animation.FuncAnimation, optional
163
+ Only provided if `animated` is ``True``.
164
164
 
165
165
  See Also
166
166
  --------
167
- plot_bpv: Plot Bayesian p-value for observed data and Posterior/Prior predictive.
168
- plot_lm: Posterior predictive and mean plots for regression-like data.
169
- plot_ppc: plot for posterior/prior predictive checks.
170
- plot_ts: Plot timeseries data.
167
+ plot_bpv : Plot Bayesian p-value for observed data and Posterior/Prior predictive.
168
+ plot_loo_pit : Plot for posterior predictive checks using cross validation.
169
+ plot_lm : Posterior predictive and mean plots for regression-like data.
170
+ plot_ts : Plot timeseries data.
171
171
 
172
172
  Examples
173
173
  --------
@@ -254,8 +254,12 @@ def plot_ppc(
254
254
 
255
255
  if group == "posterior":
256
256
  predictive_dataset = data.posterior_predictive
257
+ if observed is None:
258
+ observed = True
257
259
  elif group == "prior":
258
260
  predictive_dataset = data.prior_predictive
261
+ if observed is None:
262
+ observed = False
259
263
 
260
264
  if var_names is None:
261
265
  var_names = list(observed_data.data_vars)
@@ -265,11 +269,11 @@ def plot_ppc(
265
269
 
266
270
  if flatten_pp is None:
267
271
  if flatten is None:
268
- flatten_pp = list(predictive_dataset.dims.keys())
272
+ flatten_pp = list(predictive_dataset.dims)
269
273
  else:
270
274
  flatten_pp = flatten
271
275
  if flatten is None:
272
- flatten = list(observed_data.dims.keys())
276
+ flatten = list(observed_data.dims)
273
277
 
274
278
  if coords is None:
275
279
  coords = {}
@@ -308,6 +312,7 @@ def plot_ppc(
308
312
  skip_dims=set(flatten),
309
313
  var_names=var_names,
310
314
  combined=True,
315
+ dim_order=["chain", "draw"],
311
316
  )
312
317
  ),
313
318
  "plot_ppc",
@@ -322,6 +327,7 @@ def plot_ppc(
322
327
  var_names=pp_var_names,
323
328
  skip_dims=set(flatten_pp),
324
329
  combined=True,
330
+ dim_order=["chain", "draw"],
325
331
  ),
326
332
  )
327
333
  ]
@@ -110,7 +110,6 @@ def plot_separation(
110
110
  )
111
111
 
112
112
  else:
113
-
114
113
  if y_hat is None and isinstance(y, str):
115
114
  label_y_hat = y
116
115
  y_hat = y
arviz/stats/__init__.py CHANGED
@@ -28,7 +28,9 @@ __all__ = [
28
28
  "autocorr",
29
29
  "autocov",
30
30
  "make_ufunc",
31
+ "smooth_data",
31
32
  "wrap_xarray_ufunc",
32
33
  "reloo",
33
34
  "_calculate_ics",
35
+ "psens",
34
36
  ]
@@ -231,8 +231,8 @@ def _fixed_point(t, N, k_sq, a_sq):
231
231
  Z. I. Botev, J. F. Grotowski, and D. P. Kroese.
232
232
  Ann. Statist. 38 (2010), no. 5, 2916--2957.
233
233
  """
234
- k_sq = np.asfarray(k_sq, dtype=np.float64)
235
- a_sq = np.asfarray(a_sq, dtype=np.float64)
234
+ k_sq = np.asarray(k_sq, dtype=np.float64)
235
+ a_sq = np.asarray(a_sq, dtype=np.float64)
236
236
 
237
237
  l = 7
238
238
  f = np.sum(np.power(k_sq, l) * a_sq * np.exp(-k_sq * np.pi**2 * t))
@@ -457,10 +457,10 @@ def ks_summary(pareto_tail_indices):
457
457
  """
458
458
  _numba_flag = Numba.numba_flag
459
459
  if _numba_flag:
460
- bins = np.asarray([-np.Inf, 0.5, 0.7, 1, np.Inf])
460
+ bins = np.asarray([-np.inf, 0.5, 0.7, 1, np.inf])
461
461
  kcounts, *_ = _histogram(pareto_tail_indices, bins)
462
462
  else:
463
- kcounts, *_ = _histogram(pareto_tail_indices, bins=[-np.Inf, 0.5, 0.7, 1, np.Inf])
463
+ kcounts, *_ = _histogram(pareto_tail_indices, bins=[-np.inf, 0.5, 0.7, 1, np.inf])
464
464
  kprop = kcounts / len(pareto_tail_indices) * 100
465
465
  df_k = pd.DataFrame(
466
466
  dict(_=["(good)", "(ok)", "(bad)", "(very bad)"], Count=kcounts, Pct=kprop)
@@ -889,7 +889,6 @@ def _mc_error(ary, batches=5, circular=False):
889
889
  """
890
890
  _numba_flag = Numba.numba_flag
891
891
  if ary.ndim > 1:
892
-
893
892
  dims = np.shape(ary)
894
893
  trace = np.transpose([t.ravel() for t in ary])
895
894
 
@@ -0,0 +1,165 @@
1
+ """Functions for evaluating ECDFs and their confidence bands."""
2
+ from typing import Any, Callable, Optional, Tuple
3
+ import warnings
4
+
5
+ import numpy as np
6
+ from scipy.stats import uniform, binom
7
+
8
+
9
def compute_ecdf(sample: np.ndarray, eval_points: np.ndarray) -> np.ndarray:
    """Evaluate the empirical CDF of a pre-sorted `sample` at `eval_points`.

    `sample` must already be sorted in ascending order, because the lookup is done with
    :func:`numpy.searchsorted`. Each returned value is the fraction of sample values that
    are less than or equal to the corresponding evaluation point.
    """
    num_draws = len(sample)
    num_at_or_below = np.searchsorted(sample, eval_points, side="right")
    return num_at_or_below / num_draws
12
+
13
+
14
+ def _get_ecdf_points(
15
+ sample: np.ndarray, eval_points: np.ndarray, difference: bool
16
+ ) -> Tuple[np.ndarray, np.ndarray]:
17
+ """Compute the coordinates for the ecdf points using compute_ecdf."""
18
+ x = eval_points
19
+ y = compute_ecdf(sample, eval_points)
20
+
21
+ if not difference and y[0] > 0:
22
+ x = np.insert(x, 0, x[0])
23
+ y = np.insert(y, 0, 0)
24
+ return x, y
25
+
26
+
27
+ def _simulate_ecdf(
28
+ ndraws: int,
29
+ eval_points: np.ndarray,
30
+ rvs: Callable[[int, Optional[Any]], np.ndarray],
31
+ random_state: Optional[Any] = None,
32
+ ) -> np.ndarray:
33
+ """Simulate ECDF at the `eval_points` using the given random variable sampler"""
34
+ sample = rvs(ndraws, random_state=random_state)
35
+ sample.sort()
36
+ return compute_ecdf(sample, eval_points)
37
+
38
+
39
+ def _fit_pointwise_band_probability(
40
+ ndraws: int,
41
+ ecdf_at_eval_points: np.ndarray,
42
+ cdf_at_eval_points: np.ndarray,
43
+ ) -> float:
44
+ """Compute the smallest marginal probability of a pointwise confidence band that
45
+ contains the ECDF."""
46
+ ecdf_scaled = (ndraws * ecdf_at_eval_points).astype(int)
47
+ prob_lower_tail = np.amin(binom.cdf(ecdf_scaled, ndraws, cdf_at_eval_points))
48
+ prob_upper_tail = np.amin(binom.sf(ecdf_scaled - 1, ndraws, cdf_at_eval_points))
49
+ prob_pointwise = 1 - 2 * min(prob_lower_tail, prob_upper_tail)
50
+ return prob_pointwise
51
+
52
+
53
+ def _get_pointwise_confidence_band(
54
+ prob: float, ndraws: int, cdf_at_eval_points: np.ndarray
55
+ ) -> Tuple[np.ndarray, np.ndarray]:
56
+ """Compute the `prob`-level pointwise confidence band."""
57
+ count_lower, count_upper = binom.interval(prob, ndraws, cdf_at_eval_points)
58
+ prob_lower = count_lower / ndraws
59
+ prob_upper = count_upper / ndraws
60
+ return prob_lower, prob_upper
61
+
62
+
63
def ecdf_confidence_band(
    ndraws: int,
    eval_points: np.ndarray,
    cdf_at_eval_points: np.ndarray,
    prob: float = 0.95,
    method="simulated",
    **kwargs,
) -> Tuple[np.ndarray, np.ndarray]:
    """Compute the `prob`-level confidence band for the ECDF.

    Parameters
    ----------
    ndraws : int
        Number of samples in the original dataset.
    eval_points : np.ndarray
        Points at which the ECDF is evaluated. If these are dependent on the sample
        values, simultaneous confidence bands may not be correctly calibrated.
    cdf_at_eval_points : np.ndarray
        CDF values at the evaluation points.
    prob : float, default 0.95
        The target probability that a true ECDF lies within the confidence band.
    method : str, default "simulated"
        The method used to compute the confidence band. Valid options are:

        - "pointwise": Compute the pointwise (i.e. marginal) confidence band.
        - "simulated": Use Monte Carlo simulation to estimate a simultaneous confidence
          band. If `rvs` is not given, the variable is assumed to be continuous (a
          ``UserWarning`` is emitted) and the band is calibrated with uniform draws.
    rvs : callable, optional
        A function that takes an integer `ndraws` and optionally the object passed to
        `random_state` and returns an array of `ndraws` samples from the same distribution
        as the original dataset. Required if `method` is "simulated" and variable is discrete.
    num_trials : int, default 1000
        The number of random ECDFs to generate for constructing simultaneous confidence bands
        (if `method` is "simulated").
    random_state : {None, int, `numpy.random.Generator`,
                    `numpy.random.RandomState`}, optional
        If `None`, the `numpy.random.RandomState` singleton is used. If an `int`, a new
        ``numpy.random.RandomState`` instance is used, seeded with seed. If a `RandomState` or
        `Generator` instance, the instance is used.

    `rvs`, `num_trials` and `random_state` are passed through ``**kwargs`` and are only
    used when `method` is "simulated"; they are ignored for "pointwise".

    Returns
    -------
    prob_lower : np.ndarray
        Lower confidence band for the ECDF at the evaluation points.
    prob_upper : np.ndarray
        Upper confidence band for the ECDF at the evaluation points.

    Raises
    ------
    ValueError
        If `prob` is not strictly between 0 and 1, or if `method` is not recognized.
    """
    if not 0 < prob < 1:
        raise ValueError(f"Invalid value for `prob`. Expected 0 < prob < 1, but got {prob}.")

    if method == "pointwise":
        prob_pointwise = prob
    elif method == "simulated":
        # Find the marginal probability whose pointwise band reaches `prob`
        # simultaneous coverage.
        prob_pointwise = _simulate_simultaneous_ecdf_band_probability(
            ndraws, eval_points, cdf_at_eval_points, prob=prob, **kwargs
        )
    else:
        raise ValueError(f"Unknown method {method}. Valid options are 'pointwise' or 'simulated'.")

    prob_lower, prob_upper = _get_pointwise_confidence_band(
        prob_pointwise, ndraws, cdf_at_eval_points
    )

    return prob_lower, prob_upper
126
+
127
+
128
def _simulate_simultaneous_ecdf_band_probability(
    ndraws: int,
    eval_points: np.ndarray,
    cdf_at_eval_points: np.ndarray,
    prob: float = 0.95,
    rvs: Optional[Callable[[int, Optional[Any]], np.ndarray]] = None,
    num_trials: int = 1000,
    random_state: Optional[Any] = None,
) -> float:
    """Estimate probability for simultaneous confidence band using simulation.

    This function simulates the pointwise probability needed to construct pointwise
    confidence bands that form a `prob`-level confidence envelope for the ECDF
    of a sample.

    Parameters
    ----------
    ndraws : int
        Number of draws in each simulated sample.
    eval_points : np.ndarray
        Points at which the ECDF is evaluated.
    cdf_at_eval_points : np.ndarray
        Reference CDF values at `eval_points`.
    prob : float, default 0.95
        Target simultaneous coverage; used as the quantile level over the trials.
    rvs : callable, optional
        Sampler ``rvs(ndraws, random_state=...)``. If None, the variable is assumed to be
        continuous: a ``UserWarning`` is emitted, uniform(0, 1) draws are used and the
        simulated ECDFs are evaluated at `cdf_at_eval_points` instead of `eval_points`.
    num_trials : int, default 1000
        Number of simulated ECDFs.
    random_state : {None, int, `numpy.random.Generator`, `numpy.random.RandomState`}, optional
        Randomness source passed to `rvs`. An integer seed is first turned into a single
        ``numpy.random.RandomState`` so that successive trials draw different samples.

    Returns
    -------
    float
        The `prob`-quantile of the per-trial pointwise band probabilities.
    """
    if isinstance(random_state, (int, np.integer)):
        # scipy-style samplers build a fresh RandomState from an int seed on every call,
        # so forwarding the raw int would make every trial draw the identical sample and
        # collapse the quantile below to a single value. Seed one generator up front
        # instead so the trials consume one reproducible stream.
        random_state = np.random.RandomState(random_state)

    if rvs is None:
        warnings.warn(
            "Assuming variable is continuous for calibration of pointwise bands. "
            "If the variable is discrete, specify random variable sampler `rvs`.",
            UserWarning,
        )
        # if variable continuous, we can calibrate the confidence band using a uniform
        # distribution
        rvs = uniform(0, 1).rvs
        eval_points_sim = cdf_at_eval_points
    else:
        eval_points_sim = eval_points

    probs_pointwise = np.empty(num_trials)
    for i in range(num_trials):
        ecdf_at_eval_points = _simulate_ecdf(
            ndraws, eval_points_sim, rvs, random_state=random_state
        )
        prob_pointwise = _fit_pointwise_band_probability(
            ndraws, ecdf_at_eval_points, cdf_at_eval_points
        )
        probs_pointwise[i] = prob_pointwise
    return np.quantile(probs_pointwise, prob)