arviz-0.17.0-py3-none-any.whl → arviz-0.17.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
arviz/__init__.py CHANGED
@@ -1,6 +1,6 @@
1
1
  # pylint: disable=wildcard-import,invalid-name,wrong-import-position
2
2
  """ArviZ is a library for exploratory analysis of Bayesian models."""
3
- __version__ = "0.17.0"
3
+ __version__ = "0.17.1"
4
4
 
5
5
  import logging
6
6
  import os
@@ -56,6 +56,7 @@ SUPPORTED_GROUPS = [
56
56
  "posterior_predictive",
57
57
  "predictions",
58
58
  "log_likelihood",
59
+ "log_prior",
59
60
  "sample_stats",
60
61
  "prior",
61
62
  "prior_predictive",
@@ -63,6 +64,8 @@ SUPPORTED_GROUPS = [
63
64
  "observed_data",
64
65
  "constant_data",
65
66
  "predictions_constant_data",
67
+ "unconstrained_posterior",
68
+ "unconstrained_prior",
66
69
  ]
67
70
 
68
71
  WARMUP_TAG = "warmup_"
@@ -73,6 +76,7 @@ SUPPORTED_GROUPS_WARMUP = [
73
76
  f"{WARMUP_TAG}predictions",
74
77
  f"{WARMUP_TAG}sample_stats",
75
78
  f"{WARMUP_TAG}log_likelihood",
79
+ f"{WARMUP_TAG}log_prior",
76
80
  ]
77
81
 
78
82
  SUPPORTED_GROUPS_ALL = SUPPORTED_GROUPS + SUPPORTED_GROUPS_WARMUP
@@ -250,8 +254,7 @@ class InferenceData(Mapping[str, xr.Dataset]):
250
254
 
251
255
  def __iter__(self) -> Iterator[str]:
252
256
  """Iterate over groups in InferenceData object."""
253
- for group in self._groups_all:
254
- yield group
257
+ yield from self._groups_all
255
258
 
256
259
  def __contains__(self, key: object) -> bool:
257
260
  """Return True if the named item is present, and False otherwise."""
@@ -1490,7 +1493,7 @@ class InferenceData(Mapping[str, xr.Dataset]):
1490
1493
 
1491
1494
  import numpy as np
1492
1495
  rng = np.random.default_rng(73)
1493
- ary = rng.normal(size=(post.dims["chain"], post.dims["draw"], obs.dims["match"]))
1496
+ ary = rng.normal(size=(post.sizes["chain"], post.sizes["draw"], obs.sizes["match"]))
1494
1497
  idata.add_groups(
1495
1498
  log_likelihood={"home_points": ary},
1496
1499
  dims={"home_points": ["match"]},
arviz/data/io_pystan.py CHANGED
@@ -676,8 +676,7 @@ def get_draws(fit, variables=None, ignore=None, warmup=False, dtypes=None):
676
676
  for item in par_keys:
677
677
  _, shape = item.replace("]", "").split("[")
678
678
  shape_idx_min = min(int(shape_value) for shape_value in shape.split(","))
679
- if shape_idx_min < shift:
680
- shift = shape_idx_min
679
+ shift = min(shift, shape_idx_min)
681
680
  # If shift is higher than 1, this will probably mean that Stan
682
681
  # has implemented sparse structure (saves only non-zero parts),
683
682
  # but let's hope that dims are still corresponding to the full shape
@@ -171,8 +171,13 @@ def plot_bpv(
171
171
  ax_i.line(0, 0, legend_label=f"bpv={p_value:.2f}", alpha=0)
172
172
 
173
173
  if plot_mean:
174
- ax_i.circle(
175
- obs_vals.mean(), 0, fill_color=color, line_color="black", size=markersize
174
+ ax_i.scatter(
175
+ obs_vals.mean(),
176
+ 0,
177
+ fill_color=color,
178
+ line_color="black",
179
+ size=markersize,
180
+ marker="circle",
176
181
  )
177
182
 
178
183
  _title = Title()
@@ -69,13 +69,14 @@ def plot_compare(
69
69
  err_ys.append((y, y))
70
70
 
71
71
  # plot them
72
- dif_tri = ax.triangle(
72
+ dif_tri = ax.scatter(
73
73
  comp_df[information_criterion].iloc[1:],
74
74
  yticks_pos[1::2],
75
75
  line_color=plot_kwargs.get("color_dse", "grey"),
76
76
  fill_color=plot_kwargs.get("color_dse", "grey"),
77
77
  line_width=2,
78
78
  size=6,
79
+ marker="triangle",
79
80
  )
80
81
  dif_line = ax.multi_line(err_xs, err_ys, line_color=plot_kwargs.get("color_dse", "grey"))
81
82
 
@@ -85,13 +86,14 @@ def plot_compare(
85
86
  ax.yaxis.ticker = yticks_pos[::2]
86
87
  ax.yaxis.major_label_overrides = dict(zip(yticks_pos[::2], yticks_labels))
87
88
 
88
- elpd_circ = ax.circle(
89
+ elpd_circ = ax.scatter(
89
90
  comp_df[information_criterion],
90
91
  yticks_pos[::2],
91
92
  line_color=plot_kwargs.get("color_ic", "black"),
92
93
  fill_color=None,
93
94
  line_width=2,
94
95
  size=6,
96
+ marker="circle",
95
97
  )
96
98
  elpd_label = [elpd_circ]
97
99
 
@@ -110,7 +112,7 @@ def plot_compare(
110
112
 
111
113
  labels.append(("ELPD", elpd_label))
112
114
 
113
- scale = comp_df["scale"][0]
115
+ scale = comp_df["scale"].iloc[0]
114
116
 
115
117
  if insample_dev:
116
118
  p_ic = comp_df[f"p_{information_criterion.split('_')[1]}"]
@@ -120,13 +122,14 @@ def plot_compare(
120
122
  correction = -p_ic
121
123
  elif scale == "deviance":
122
124
  correction = -(2 * p_ic)
123
- insample_circ = ax.circle(
125
+ insample_circ = ax.scatter(
124
126
  comp_df[information_criterion] + correction,
125
127
  yticks_pos[::2],
126
128
  line_color=plot_kwargs.get("color_insample_dev", "black"),
127
129
  fill_color=plot_kwargs.get("color_insample_dev", "black"),
128
130
  line_width=2,
129
131
  size=6,
132
+ marker="circle",
130
133
  )
131
134
  labels.append(("In-sample ELPD", [insample_circ]))
132
135
 
@@ -15,7 +15,6 @@ from ....rcparams import rcParams
15
15
  from ....stats import hdi
16
16
  from ....stats.density_utils import get_bins, histogram, kde
17
17
  from ....stats.diagnostics import _ess, _rhat
18
- from ....utils import conditional_jit
19
18
  from ...plot_utils import _scale_fig_size
20
19
  from .. import show_layout
21
20
  from . import backend_kwarg_defaults
@@ -277,7 +276,6 @@ class PlotHandler:
277
276
  """Collect labels and ticks from plotters."""
278
277
  val = self.plotters.values()
279
278
 
280
- @conditional_jit(forceobj=True, nopython=False)
281
279
  def label_idxs():
282
280
  labels, idxs = [], []
283
281
  for plotter in val:
@@ -640,7 +638,7 @@ class VarHandler:
640
638
  grouped_data = [[(0, datum)] for datum in self.data]
641
639
  skip_dims = self.combine_dims.union({"chain"})
642
640
  else:
643
- grouped_data = [datum.groupby("chain") for datum in self.data]
641
+ grouped_data = [datum.groupby("chain", squeeze=False) for datum in self.data]
644
642
  skip_dims = self.combine_dims
645
643
 
646
644
  label_dict = OrderedDict()
@@ -648,7 +646,7 @@ class VarHandler:
648
646
  for name, grouped_datum in zip(self.model_names, grouped_data):
649
647
  for _, sub_data in grouped_datum:
650
648
  datum_iter = xarray_var_iter(
651
- sub_data,
649
+ sub_data.squeeze(),
652
650
  var_names=[self.var_name],
653
651
  skip_dims=skip_dims,
654
652
  reverse_selections=True,
@@ -84,7 +84,7 @@ def plot_compare(
84
84
  else:
85
85
  ax.set_yticks(yticks_pos[::2])
86
86
 
87
- scale = comp_df["scale"][0]
87
+ scale = comp_df["scale"].iloc[0]
88
88
 
89
89
  if insample_dev:
90
90
  p_ic = comp_df[f"p_{information_criterion.split('_')[1]}"]
@@ -11,7 +11,6 @@ from ....stats import hdi
11
11
  from ....stats.density_utils import get_bins, histogram, kde
12
12
  from ....stats.diagnostics import _ess, _rhat
13
13
  from ....sel_utils import xarray_var_iter
14
- from ....utils import conditional_jit
15
14
  from ...plot_utils import _scale_fig_size
16
15
  from . import backend_kwarg_defaults, backend_show
17
16
 
@@ -236,7 +235,6 @@ class PlotHandler:
236
235
  """Collect labels and ticks from plotters."""
237
236
  val = self.plotters.values()
238
237
 
239
- @conditional_jit(forceobj=True, nopython=False)
240
238
  def label_idxs():
241
239
  labels, idxs = [], []
242
240
  for plotter in val:
@@ -536,7 +534,7 @@ class VarHandler:
536
534
  grouped_data = [[(0, datum)] for datum in self.data]
537
535
  skip_dims = self.combine_dims.union({"chain"})
538
536
  else:
539
- grouped_data = [datum.groupby("chain") for datum in self.data]
537
+ grouped_data = [datum.groupby("chain", squeeze=False) for datum in self.data]
540
538
  skip_dims = self.combine_dims
541
539
 
542
540
  label_dict = OrderedDict()
@@ -544,7 +542,7 @@ class VarHandler:
544
542
  for name, grouped_datum in zip(self.model_names, grouped_data):
545
543
  for _, sub_data in grouped_datum:
546
544
  datum_iter = xarray_var_iter(
547
- sub_data,
545
+ sub_data.squeeze(),
548
546
  var_names=[self.var_name],
549
547
  skip_dims=skip_dims,
550
548
  reverse_selections=True,
@@ -430,7 +430,7 @@ def plot_trace(
430
430
  Line2D(
431
431
  [], [], label=chain_id, **dealiase_sel_kwargs(legend_kwargs, chain_prop, chain_id)
432
432
  )
433
- for chain_id in range(data.dims["chain"])
433
+ for chain_id in range(data.sizes["chain"])
434
434
  ]
435
435
  if combined:
436
436
  handles.insert(
arviz/plots/bfplot.py CHANGED
@@ -38,7 +38,7 @@ def plot_bf(
38
38
  algorithm presented in [1]_.
39
39
 
40
40
  Parameters
41
- -----------
41
+ ----------
42
42
  idata : InferenceData
43
43
  Any object that can be converted to an :class:`arviz.InferenceData` object
44
44
  Refer to documentation of :func:`arviz.convert_to_dataset` for details.
@@ -52,16 +52,16 @@ def plot_bf(
52
52
  Tuple of valid Matplotlib colors. First element for the prior, second for the posterior.
53
53
  figsize : (float, float), optional
54
54
  Figure size. If `None` it will be defined automatically.
55
- textsize: float, optional
55
+ textsize : float, optional
56
56
  Text size scaling factor for labels, titles and lines. If `None` it will be auto
57
57
  scaled based on `figsize`.
58
- plot_kwargs : dicts, optional
58
+ plot_kwargs : dict, optional
59
59
  Additional keywords passed to :func:`matplotlib.pyplot.plot`.
60
- hist_kwargs : dicts, optional
60
+ hist_kwargs : dict, optional
61
61
  Additional keywords passed to :func:`arviz.plot_dist`. Only works for discrete variables.
62
62
  ax : axes, optional
63
63
  :class:`matplotlib.axes.Axes` or :class:`bokeh.plotting.Figure`.
64
- backend :{"matplotlib", "bokeh"}, default "matplotlib"
64
+ backend : {"matplotlib", "bokeh"}, default "matplotlib"
65
65
  Select plotting backend.
66
66
  backend_kwargs : dict, optional
67
67
  These are kwargs specific to the backend being used, passed to
@@ -78,7 +78,7 @@ def plot_bf(
78
78
  References
79
79
  ----------
80
80
  .. [1] Heck, D., 2019. A caveat on the Savage-Dickey density ratio:
81
- The case of computing Bayes factors for regression parameters.
81
+ The case of computing Bayes factors for regression parameters.
82
82
 
83
83
  Examples
84
84
  --------
@@ -92,6 +92,7 @@ def plot_bf(
92
92
  >>> idata = az.from_dict(posterior={"a":np.random.normal(1, 0.5, 5000)},
93
93
  ... prior={"a":np.random.normal(0, 1, 5000)})
94
94
  >>> az.plot_bf(idata, var_name="a", ref_val=0)
95
+
95
96
  """
96
97
  posterior = extract(idata, var_names=var_name).values
97
98
 
arviz/plots/bpvplot.py CHANGED
@@ -230,11 +230,11 @@ def plot_bpv(
230
230
 
231
231
  if flatten_pp is None:
232
232
  if flatten is None:
233
- flatten_pp = list(predictive_dataset.dims.keys())
233
+ flatten_pp = list(predictive_dataset.dims)
234
234
  else:
235
235
  flatten_pp = flatten
236
236
  if flatten is None:
237
- flatten = list(observed.dims.keys())
237
+ flatten = list(observed.dims)
238
238
 
239
239
  if coords is None:
240
240
  coords = {}
@@ -90,10 +90,10 @@ def plot_compare(
90
90
  References
91
91
  ----------
92
92
  .. [1] Vehtari et al. (2016). Practical Bayesian model evaluation using leave-one-out
93
- cross-validation and WAIC https://arxiv.org/abs/1507.04544
93
+ cross-validation and WAIC https://arxiv.org/abs/1507.04544
94
94
 
95
95
  .. [2] McElreath R. (2022). Statistical Rethinking A Bayesian Course with Examples in
96
- R and Stan, Second edition, CRC Press.
96
+ R and Stan, Second edition, CRC Press.
97
97
 
98
98
  Examples
99
99
  --------
arviz/plots/ecdfplot.py CHANGED
@@ -1,8 +1,9 @@
1
1
  """Plot ecdf or ecdf-difference plot with confidence bands."""
2
2
  import numpy as np
3
- from scipy.stats import uniform, binom
3
+ from scipy.stats import uniform
4
4
 
5
5
  from ..rcparams import rcParams
6
+ from ..stats.ecdf_utils import compute_ecdf, ecdf_confidence_band, _get_ecdf_points
6
7
  from .plot_utils import get_plotting_function
7
8
 
8
9
 
@@ -26,7 +27,7 @@ def plot_ecdf(
26
27
  show=None,
27
28
  backend=None,
28
29
  backend_kwargs=None,
29
- **kwargs
30
+ **kwargs,
30
31
  ):
31
32
  r"""Plot ECDF or ECDF-Difference Plot with Confidence bands.
32
33
 
@@ -48,6 +49,7 @@ def plot_ecdf(
48
49
  Values to compare to the original sample.
49
50
  cdf : callable, optional
50
51
  Cumulative distribution function of the distribution to compare the original sample.
52
+ The function must take as input a numpy array of draws from the distribution.
51
53
  difference : bool, default False
52
54
  If True then plot ECDF-difference plot otherwise ECDF plot.
53
55
  pit : bool, default False
@@ -180,75 +182,47 @@ def plot_ecdf(
180
182
  values = np.ravel(values)
181
183
  values.sort()
182
184
 
183
- ## This block computes gamma and uses it to get the upper and lower confidence bands
184
- ## Here we check if we want confidence bands or not
185
- if confidence_bands:
186
- ## If plotting PIT then we find the PIT values of sample.
187
- ## Basically here we generate the evaluation points(x) and find the PIT values.
188
- ## z is the evaluation point for our uniform distribution in compute_gamma()
189
- if pit:
190
- x = np.linspace(1 / npoints, 1, npoints)
191
- z = x
192
- ## Finding PIT for our sample
193
- probs = cdf(values) if cdf else compute_ecdf(values2, values) / len(values2)
194
- else:
195
- ## If not PIT use sample for plots and for evaluation points(x) use equally spaced
196
- ## points between minimum and maximum of sample
197
- ## For z we have used cdf(x)
198
- x = np.linspace(values[0], values[-1], npoints)
199
- z = cdf(x) if cdf else compute_ecdf(values2, x)
200
- probs = values
201
-
202
- n = len(values) # number of samples
203
- ## Computing gamma
204
- gamma = fpr if pointwise else compute_gamma(n, z, npoints, num_trials, fpr)
205
- ## Using gamma to get the confidence intervals
206
- lower, higher = get_lims(gamma, n, z)
207
-
208
- ## This block is for whether to plot ECDF or ECDF-difference
209
- if not difference:
210
- ## We store the coordinates of our ecdf in x_coord, y_coord
211
- x_coord, y_coord = get_ecdf_points(x, probs, difference)
185
+ if pit:
186
+ eval_points = np.linspace(1 / npoints, 1, npoints)
187
+ if cdf:
188
+ sample = cdf(values)
212
189
  else:
213
- ## Here we subtract the ecdf value as here we are plotting the ECDF-difference
214
- x_coord, y_coord = get_ecdf_points(x, probs, difference)
215
- for i, x_i in enumerate(x):
216
- y_coord[i] = y_coord[i] - (
217
- x_i if pit else cdf(x_i) if cdf else compute_ecdf(values2, x_i)
218
- )
219
-
220
- ## Similarly we subtract from the upper and lower bounds
221
- if pit:
222
- lower = lower - x
223
- higher = higher - x
224
- else:
225
- lower = lower - (cdf(x) if cdf else compute_ecdf(values2, x))
226
- higher = higher - (cdf(x) if cdf else compute_ecdf(values2, x))
227
-
190
+ sample = compute_ecdf(values2, values) / len(values2)
191
+ cdf_at_eval_points = eval_points
192
+ rvs = uniform(0, 1).rvs
228
193
  else:
229
- if pit:
230
- x = np.linspace(1 / npoints, 1, npoints)
231
- probs = cdf(values)
194
+ eval_points = np.linspace(values[0], values[-1], npoints)
195
+ sample = values
196
+ if confidence_bands or difference:
197
+ if cdf:
198
+ cdf_at_eval_points = cdf(eval_points)
199
+ else:
200
+ cdf_at_eval_points = compute_ecdf(values2, eval_points)
232
201
  else:
233
- x = np.linspace(values[0], values[-1], npoints)
234
- probs = values
202
+ cdf_at_eval_points = np.zeros_like(eval_points)
203
+ rvs = None
204
+
205
+ x_coord, y_coord = _get_ecdf_points(sample, eval_points, difference)
235
206
 
207
+ if difference:
208
+ y_coord -= cdf_at_eval_points
209
+
210
+ if confidence_bands:
211
+ ndraws = len(values)
212
+ band_kwargs = {"prob": 1 - fpr, "num_trials": num_trials, "rvs": rvs, "random_state": None}
213
+ band_kwargs["method"] = "pointwise" if pointwise else "simulated"
214
+ lower, higher = ecdf_confidence_band(ndraws, eval_points, cdf_at_eval_points, **band_kwargs)
215
+
216
+ if difference:
217
+ lower -= cdf_at_eval_points
218
+ higher -= cdf_at_eval_points
219
+ else:
236
220
  lower, higher = None, None
237
- ## This block is for whether to plot ECDF or ECDF-difference
238
- if not difference:
239
- x_coord, y_coord = get_ecdf_points(x, probs, difference)
240
- else:
241
- ## Here we subtract the ecdf value as here we are plotting the ECDF-difference
242
- x_coord, y_coord = get_ecdf_points(x, probs, difference)
243
- for i, x_i in enumerate(x):
244
- y_coord[i] = y_coord[i] - (
245
- x_i if pit else cdf(x_i) if cdf else compute_ecdf(values2, x_i)
246
- )
247
221
 
248
222
  ecdf_plot_args = dict(
249
223
  x_coord=x_coord,
250
224
  y_coord=y_coord,
251
- x_bands=x,
225
+ x_bands=eval_points,
252
226
  lower=lower,
253
227
  higher=higher,
254
228
  confidence_bands=confidence_bands,
@@ -260,7 +234,7 @@ def plot_ecdf(
260
234
  ax=ax,
261
235
  show=show,
262
236
  backend_kwargs=backend_kwargs,
263
- **kwargs
237
+ **kwargs,
264
238
  )
265
239
 
266
240
  if backend is None:
@@ -271,52 +245,3 @@ def plot_ecdf(
271
245
  ax = plot(**ecdf_plot_args)
272
246
 
273
247
  return ax
274
-
275
-
276
- def compute_ecdf(sample, z):
277
- """Compute ECDF.
278
-
279
- This function computes the ecdf value at the evaluation point
280
- or a sorted set of evaluation points.
281
- """
282
- return np.searchsorted(sample, z, side="right") / len(sample)
283
-
284
-
285
- def get_ecdf_points(x, probs, difference):
286
- """Compute the coordinates for the ecdf points using compute_ecdf."""
287
- y = compute_ecdf(probs, x)
288
-
289
- if not difference:
290
- x = np.insert(x, 0, x[0])
291
- y = np.insert(y, 0, 0)
292
- return x, y
293
-
294
-
295
- def compute_gamma(n, z, npoints=None, num_trials=1000, fpr=0.05):
296
- """Compute gamma for confidence interval calculation.
297
-
298
- This function simulates an adjusted value of gamma to account for multiplicity
299
- when forming an 1-fpr level confidence envelope for the ECDF of a sample.
300
- """
301
- if npoints is None:
302
- npoints = n
303
- gamma = []
304
- for _ in range(num_trials):
305
- unif_samples = uniform.rvs(0, 1, n)
306
- unif_samples = np.sort(unif_samples)
307
- gamma_m = 1000
308
- ## Can compute ecdf for all the z together or one at a time.
309
- f_z = compute_ecdf(unif_samples, z)
310
- f_z = compute_ecdf(unif_samples, z)
311
- gamma_m = 2 * min(
312
- np.amin(binom.cdf(n * f_z, n, z)), np.amin(1 - binom.cdf(n * f_z - 1, n, z))
313
- )
314
- gamma.append(gamma_m)
315
- return np.quantile(gamma, fpr)
316
-
317
-
318
- def get_lims(gamma, n, z):
319
- """Compute the simultaneous 1 - fpr level confidence bands."""
320
- lower = binom.ppf(gamma / 2, n, z)
321
- upper = binom.ppf(1 - gamma / 2, n, z)
322
- return lower / n, upper / n
arviz/plots/elpdplot.py CHANGED
@@ -98,7 +98,7 @@ def plot_elpd(
98
98
  References
99
99
  ----------
100
100
  .. [1] Vehtari et al. (2016). Practical Bayesian model evaluation using leave-one-out
101
- cross-validation and WAIC https://arxiv.org/abs/1507.04544
101
+ cross-validation and WAIC https://arxiv.org/abs/1507.04544
102
102
 
103
103
  Examples
104
104
  --------
arviz/plots/essplot.py CHANGED
@@ -202,8 +202,8 @@ def plot_ess(
202
202
 
203
203
  data = get_coords(convert_to_dataset(idata, group="posterior"), coords)
204
204
  var_names = _var_names(var_names, data, filter_vars)
205
- n_draws = data.dims["draw"]
206
- n_samples = n_draws * data.dims["chain"]
205
+ n_draws = data.sizes["draw"]
206
+ n_samples = n_draws * data.sizes["chain"]
207
207
 
208
208
  ess_tail_dataset = None
209
209
  mean_ess = None
arviz/plots/pairplot.py CHANGED
@@ -229,7 +229,7 @@ def plot_pair(
229
229
  )
230
230
 
231
231
  if gridsize == "auto":
232
- gridsize = int(dataset.dims["draw"] ** 0.35)
232
+ gridsize = int(dataset.sizes["draw"] ** 0.35)
233
233
 
234
234
  numvars = len(flat_var_names)
235
235
 
arviz/plots/ppcplot.py CHANGED
@@ -19,7 +19,7 @@ def plot_ppc(
19
19
  kind="kde",
20
20
  alpha=None,
21
21
  mean=True,
22
- observed=True,
22
+ observed=None,
23
23
  observed_rug=False,
24
24
  color=None,
25
25
  colors=None,
@@ -60,8 +60,9 @@ def plot_ppc(
60
60
  Defaults to 0.2 for ``kind = kde`` and cumulative, for scatter defaults to 0.7.
61
61
  mean : bool, default True
62
62
  Whether or not to plot the mean posterior/prior predictive distribution.
63
- observed : bool, default True
64
- Whether or not to plot the observed data.
63
+ observed : bool, optional
64
+ Whether or not to plot the observed data. Defaults to True for ``group = posterior``
65
+ and False for ``group = prior``.
65
66
  observed_rug : bool, default False
66
67
  Whether or not to plot a rug plot for the observed data. Only valid if `observed` is
67
68
  `True` and for kind `kde` or `cumulative`.
@@ -253,8 +254,12 @@ def plot_ppc(
253
254
 
254
255
  if group == "posterior":
255
256
  predictive_dataset = data.posterior_predictive
257
+ if observed is None:
258
+ observed = True
256
259
  elif group == "prior":
257
260
  predictive_dataset = data.prior_predictive
261
+ if observed is None:
262
+ observed = False
258
263
 
259
264
  if var_names is None:
260
265
  var_names = list(observed_data.data_vars)
@@ -264,11 +269,11 @@ def plot_ppc(
264
269
 
265
270
  if flatten_pp is None:
266
271
  if flatten is None:
267
- flatten_pp = list(predictive_dataset.dims.keys())
272
+ flatten_pp = list(predictive_dataset.dims)
268
273
  else:
269
274
  flatten_pp = flatten
270
275
  if flatten is None:
271
- flatten = list(observed_data.dims.keys())
276
+ flatten = list(observed_data.dims)
272
277
 
273
278
  if coords is None:
274
279
  coords = {}
@@ -231,8 +231,8 @@ def _fixed_point(t, N, k_sq, a_sq):
231
231
  Z. I. Botev, J. F. Grotowski, and D. P. Kroese.
232
232
  Ann. Statist. 38 (2010), no. 5, 2916--2957.
233
233
  """
234
- k_sq = np.asfarray(k_sq, dtype=np.float64)
235
- a_sq = np.asfarray(a_sq, dtype=np.float64)
234
+ k_sq = np.asarray(k_sq, dtype=np.float64)
235
+ a_sq = np.asarray(a_sq, dtype=np.float64)
236
236
 
237
237
  l = 7
238
238
  f = np.sum(np.power(k_sq, l) * a_sq * np.exp(-k_sq * np.pi**2 * t))
@@ -457,10 +457,10 @@ def ks_summary(pareto_tail_indices):
457
457
  """
458
458
  _numba_flag = Numba.numba_flag
459
459
  if _numba_flag:
460
- bins = np.asarray([-np.Inf, 0.5, 0.7, 1, np.Inf])
460
+ bins = np.asarray([-np.inf, 0.5, 0.7, 1, np.inf])
461
461
  kcounts, *_ = _histogram(pareto_tail_indices, bins)
462
462
  else:
463
- kcounts, *_ = _histogram(pareto_tail_indices, bins=[-np.Inf, 0.5, 0.7, 1, np.Inf])
463
+ kcounts, *_ = _histogram(pareto_tail_indices, bins=[-np.inf, 0.5, 0.7, 1, np.inf])
464
464
  kprop = kcounts / len(pareto_tail_indices) * 100
465
465
  df_k = pd.DataFrame(
466
466
  dict(_=["(good)", "(ok)", "(bad)", "(very bad)"], Count=kcounts, Pct=kprop)