bayesian-pricing 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bayesian_pricing/__init__.py +47 -0
- bayesian_pricing/_utils.py +89 -0
- bayesian_pricing/diagnostics.py +299 -0
- bayesian_pricing/frequency.py +534 -0
- bayesian_pricing/relativities.py +284 -0
- bayesian_pricing/severity.py +395 -0
- bayesian_pricing-0.1.0.dist-info/METADATA +221 -0
- bayesian_pricing-0.1.0.dist-info/RECORD +10 -0
- bayesian_pricing-0.1.0.dist-info/WHEEL +4 -0
- bayesian_pricing-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
"""
|
|
2
|
+
bayesian-pricing: Hierarchical Bayesian models for insurance pricing thin-data segments.
|
|
3
|
+
|
|
4
|
+
The central problem this library solves is the sparse-cell problem in personal lines rating.
|
|
5
|
+
A motor book with 1M policies might have 4.5M theoretical rating cells. Most are empty or
|
|
6
|
+
contain fewer than 30 observations. GBMs overfit or refuse to split. Saturated GLMs overfit.
|
|
7
|
+
Ridge GLMs shrink uniformly regardless of exposure. None of these are right.
|
|
8
|
+
|
|
9
|
+
The correct answer is partial pooling: thin segments borrow strength from related segments
|
|
10
|
+
via a shared population distribution. The degree of borrowing is data-driven -- determined
|
|
11
|
+
by the ratio of within-segment sampling noise to between-segment signal variance. This is
|
|
12
|
+
the Bayesian posterior.
|
|
13
|
+
|
|
14
|
+
Under Normal-Normal conjugacy, this is exactly Bühlmann-Straub credibility. This library
|
|
15
|
+
generalises that to Poisson (frequency) and Gamma (severity) likelihoods, with multiple
|
|
16
|
+
crossed random effects, using PyMC 5.x under the hood.
|
|
17
|
+
|
|
18
|
+
Primary classes:
|
|
19
|
+
HierarchicalFrequency: Poisson hierarchical model for claim frequency
|
|
20
|
+
HierarchicalSeverity: Gamma hierarchical model for claim severity
|
|
21
|
+
BayesianRelativities: Extract multiplicative relativities from the posterior
|
|
22
|
+
|
|
23
|
+
Usage::
|
|
24
|
+
|
|
25
|
+
from bayesian_pricing import HierarchicalFrequency, BayesianRelativities
|
|
26
|
+
|
|
27
|
+
freq_model = HierarchicalFrequency(group_cols=["veh_group", "age_band"])
|
|
28
|
+
freq_model.fit(df, claim_count_col="claims", exposure_col="earned_exposure")
|
|
29
|
+
|
|
30
|
+
rel = BayesianRelativities(freq_model)
|
|
31
|
+
print(rel.relativities()) # DataFrame with posterior median + credible interval
|
|
32
|
+
print(rel.credibility_factors()) # How much weight each segment puts on own data
|
|
33
|
+
"""
|
|
34
|
+
|
|
35
|
+
from bayesian_pricing.frequency import HierarchicalFrequency
|
|
36
|
+
from bayesian_pricing.severity import HierarchicalSeverity
|
|
37
|
+
from bayesian_pricing.relativities import BayesianRelativities
|
|
38
|
+
from bayesian_pricing.diagnostics import convergence_summary, posterior_predictive_check
|
|
39
|
+
|
|
40
|
+
__version__ = "0.1.0"
|
|
41
|
+
__all__ = [
|
|
42
|
+
"HierarchicalFrequency",
|
|
43
|
+
"HierarchicalSeverity",
|
|
44
|
+
"BayesianRelativities",
|
|
45
|
+
"convergence_summary",
|
|
46
|
+
"posterior_predictive_check",
|
|
47
|
+
]
|
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
"""Internal utilities for bayesian-pricing."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import pandas as pd
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def _check_pymc() -> None:
|
|
10
|
+
"""Raise a helpful ImportError if PyMC is not installed."""
|
|
11
|
+
try:
|
|
12
|
+
import pymc # noqa: F401
|
|
13
|
+
except ImportError:
|
|
14
|
+
raise ImportError(
|
|
15
|
+
"PyMC is required for fitting Bayesian models. Install it with:\n\n"
|
|
16
|
+
" uv add pymc\n\n"
|
|
17
|
+
"Or install this package with the pymc extras:\n\n"
|
|
18
|
+
" uv add bayesian-pricing[pymc]\n\n"
|
|
19
|
+
"PyMC requires C++ compiler tools on some platforms. See:\n"
|
|
20
|
+
" https://www.pymc.io/projects/docs/en/stable/installation.html\n\n"
|
|
21
|
+
"For GPU acceleration (large portfolios), install with NumPyro backend:\n"
|
|
22
|
+
" uv add bayesian-pricing[numpyro]"
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def _validate_columns_present(df: pd.DataFrame, cols: list[str]) -> None:
|
|
27
|
+
"""Raise ValueError if any column is missing from the DataFrame."""
|
|
28
|
+
missing = [c for c in cols if c not in df.columns]
|
|
29
|
+
if missing:
|
|
30
|
+
raise ValueError(
|
|
31
|
+
f"Columns not found in data: {missing}. "
|
|
32
|
+
f"Available columns: {df.columns.tolist()}"
|
|
33
|
+
)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _validate_positive(series: pd.Series, name: str) -> None:
|
|
37
|
+
"""Raise ValueError if any value is non-positive."""
|
|
38
|
+
if (series <= 0).any():
|
|
39
|
+
n_bad = (series <= 0).sum()
|
|
40
|
+
raise ValueError(
|
|
41
|
+
f"Column '{name}' must be strictly positive. "
|
|
42
|
+
f"Found {n_bad} non-positive values. "
|
|
43
|
+
f"Min value: {series.min()}"
|
|
44
|
+
)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _portfolio_mean_rate(
|
|
48
|
+
claims: pd.Series, exposure: pd.Series
|
|
49
|
+
) -> float:
|
|
50
|
+
"""Compute exposure-weighted portfolio mean claim rate.
|
|
51
|
+
|
|
52
|
+
This is the maximum likelihood estimate of the overall Poisson rate --
|
|
53
|
+
total claims divided by total exposure. Used as the prior mean for the
|
|
54
|
+
intercept when the user does not provide one.
|
|
55
|
+
|
|
56
|
+
The prior should ideally come from a long-run average, not the training data,
|
|
57
|
+
to avoid the prior adapting to the same data the likelihood uses. But in
|
|
58
|
+
practice, for a weakly informative prior (sigma=0.5 on log scale), this
|
|
59
|
+
makes little difference.
|
|
60
|
+
"""
|
|
61
|
+
total_claims = float(claims.sum())
|
|
62
|
+
total_exposure = float(exposure.sum())
|
|
63
|
+
if total_exposure == 0:
|
|
64
|
+
raise ValueError("Total exposure is zero. Cannot compute portfolio mean rate.")
|
|
65
|
+
if total_claims == 0:
|
|
66
|
+
# Avoid log(0) in prior; use a small positive value
|
|
67
|
+
return 0.001
|
|
68
|
+
return total_claims / total_exposure
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def _segment_index(series: pd.Series) -> tuple[np.ndarray, np.ndarray]:
|
|
72
|
+
"""Convert a categorical series to integer indices and unique levels.
|
|
73
|
+
|
|
74
|
+
Returns:
|
|
75
|
+
(indices, levels): indices maps each row to a position in levels.
|
|
76
|
+
"""
|
|
77
|
+
# Use pandas Categorical for consistent ordering
|
|
78
|
+
cat = pd.Categorical(series)
|
|
79
|
+
indices = cat.codes.copy()
|
|
80
|
+
levels = np.array(cat.categories)
|
|
81
|
+
|
|
82
|
+
if (indices < 0).any():
|
|
83
|
+
n_null = (indices < 0).sum()
|
|
84
|
+
raise ValueError(
|
|
85
|
+
f"Column '{series.name}' contains {n_null} null/NaN values. "
|
|
86
|
+
"Fill or drop these before fitting."
|
|
87
|
+
)
|
|
88
|
+
|
|
89
|
+
return indices, levels
|
|
@@ -0,0 +1,299 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Diagnostics for hierarchical Bayesian insurance models.
|
|
3
|
+
|
|
4
|
+
Before you trust a model's output, you need to know whether the MCMC sampler
|
|
5
|
+
actually explored the posterior correctly. Two types of failures are common in
|
|
6
|
+
hierarchical models:
|
|
7
|
+
|
|
8
|
+
1. Non-convergence: chains did not mix. R-hat > 1.01 is the standard flag.
|
|
9
|
+
Cause: usually centered parameterization creating funnel geometry. Fix:
|
|
10
|
+
ensure non-centered parameterization is used (it is, by default in this library).
|
|
11
|
+
|
|
12
|
+
2. Divergences: the HMC trajectory hit a region where the step size is too large.
|
|
13
|
+
A small number (<0.1% of samples) is acceptable. More than 1% indicates a
|
|
14
|
+
poorly specified model. Increase target_accept in SamplerConfig or check
|
|
15
|
+
your priors.
|
|
16
|
+
|
|
17
|
+
After convergence, validate that the model actually describes your data:
|
|
18
|
+
|
|
19
|
+
3. Posterior predictive check: simulate new datasets from the posterior and
|
|
20
|
+
compare to observed data. If the model is correct, the observed statistics
|
|
21
|
+
(mean, variance, 95th percentile) should fall within the simulated range.
|
|
22
|
+
|
|
23
|
+
4. Calibration on holdout: the 90% credible interval should contain the true
|
|
24
|
+
value 90% of the time. If it contains it 99% of the time, the intervals are
|
|
25
|
+
too wide and the model is underconfident. If 70%, they are too narrow (overconfident).
|
|
26
|
+
|
|
27
|
+
These are the checks a Lloyd's of London actuary would want to see in a model
|
|
28
|
+
validation report. The functions here support all of them.
|
|
29
|
+
"""
|
|
30
|
+
|
|
31
|
+
from __future__ import annotations
|
|
32
|
+
|
|
33
|
+
from typing import Optional
|
|
34
|
+
|
|
35
|
+
import numpy as np
|
|
36
|
+
import pandas as pd
|
|
37
|
+
|
|
38
|
+
from bayesian_pricing._utils import _check_pymc
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def convergence_summary(model, return_warnings: bool = True) -> pd.DataFrame:
    """Summarise MCMC convergence diagnostics.

    Returns the ArviZ summary table for every parameter in the model.
    For NUTS fits the key columns are:

    - r_hat: Gelman-Rubin statistic. Should be < 1.01. Values > 1.05 indicate
      serious non-convergence and the results should not be used.
    - ess_bulk: Effective sample size for bulk of posterior. Target > 400.
      Low ESS means the chain mixed slowly -- your estimates are less precise
      than the nominal sample count suggests.
    - ess_tail: ESS for the tails of the distribution. More relevant for
      credible intervals. Target > 400.

    Divergences are not a column of the table. For NUTS fits the count of
    divergent transitions and its percentage of all draws are attached as
    ``summary.attrs["n_divergences"]`` and ``summary.attrs["pct_divergences"]``.
    Zero is ideal; fewer than 0.1% of draws is reported as a note, anything
    more as a warning.

    When ``idata.sample_stats`` has no ``diverging`` variable the fit is
    treated as Pathfinder output: only ``az.summary(kind="stats")`` is
    returned and no attrs are set.

    Args:
        model: Fitted HierarchicalFrequency or HierarchicalSeverity.
        return_warnings: If True, print actionable warnings when diagnostics
            are outside acceptable ranges.

    Returns:
        DataFrame with one row per parameter (or parameter level for vectors).

    Raises:
        ImportError: If PyMC is not installed.
        RuntimeError: If the model has not been fitted.
    """
    _check_pymc()
    import arviz as az

    if not getattr(model, "_fitted", False):
        raise RuntimeError("Model not fitted. Call .fit() first.")

    idata = model.idata

    # Check if NUTS was used (pathfinder has no r_hat)
    is_nuts = hasattr(idata, "sample_stats") and "diverging" in idata.sample_stats

    if not is_nuts:
        # Pathfinder: only basic summary, no convergence diagnostics
        summary = az.summary(idata, kind="stats")
        if return_warnings:
            print(
                "WARNING: Model was fitted with Pathfinder (variational inference). "
                "R-hat and ESS diagnostics are not available. "
                "Re-fit with SamplerConfig(method='nuts') for production use."
            )
        return summary

    summary = az.summary(idata, round_to=4)

    # Count divergences. The percentage is computed once, guarded against an
    # empty trace, so the printed value and the stored attr always agree and
    # a zero-draw trace cannot raise ZeroDivisionError.
    diverging = idata.sample_stats["diverging"]
    n_divergent = int(diverging.sum().item())
    total_samples = diverging.sizes["chain"] * diverging.sizes["draw"]
    pct_divergent = n_divergent / total_samples * 100 if total_samples else 0

    if return_warnings:
        # R-hat check
        if "r_hat" in summary.columns:
            bad_rhat = summary[summary["r_hat"] > 1.01]
            max_rhat = summary["r_hat"].max()
            if len(bad_rhat) > 0:
                print(
                    f"WARNING: {len(bad_rhat)} parameter(s) have R-hat > 1.01. "
                    "Non-convergence detected. Do not use these results. "
                    f"Check: {bad_rhat.index.tolist()[:5]}"
                )
            elif max_rhat > 1.005:
                print(
                    f"NOTE: Maximum R-hat is {max_rhat:.4f}. "
                    "Marginally acceptable. Consider longer chains."
                )
            else:
                print(f"R-hat: OK (max = {max_rhat:.4f})")

        # ESS check (fall back to ess_mean when ess_bulk is not reported)
        ess_col = "ess_bulk" if "ess_bulk" in summary.columns else "ess_mean"
        if ess_col in summary.columns:
            low_ess = summary[summary[ess_col] < 400]
            if len(low_ess) > 0:
                print(
                    f"WARNING: {len(low_ess)} parameter(s) have ESS < 400. "
                    "Increase draws or tune in SamplerConfig."
                )
            else:
                print(f"ESS: OK (min bulk = {summary[ess_col].min():.0f})")

        # Divergence check
        if n_divergent == 0:
            print("Divergences: none")
        elif pct_divergent < 0.1:
            print(
                f"NOTE: {n_divergent} divergences ({pct_divergent:.3f}%). "
                "Small number, probably fine. Check with az.plot_trace()."
            )
        else:
            print(
                f"WARNING: {n_divergent} divergences ({pct_divergent:.2f}%). "
                "This is too many. Try SamplerConfig(target_accept=0.95). "
                "If problem persists, check model specification."
            )

    summary.attrs["n_divergences"] = n_divergent
    summary.attrs["pct_divergences"] = pct_divergent
    return summary
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def posterior_predictive_check(
    model,
    claim_count_col: Optional[str] = None,
    severity_col: Optional[str] = None,
    n_stats: int = 200,
    random_seed: Optional[int] = None,
) -> dict:
    """Compare observed statistics to the posterior predictive distribution.

    This is the fundamental model validation: simulate datasets from the
    fitted model and check whether the observed data looks plausible given
    those simulations. If the observed mean sits above the 95th percentile
    of the simulated means the model is under-predicting; if it sits below
    the 5th percentile the model is over-predicting. Either is a problem.

    The function returns a dict of check statistics. Each statistic key maps
    to a sub-dict with:
      - observed: the statistic computed on actual data
      - simulated_mean: mean of the statistic across posterior predictive draws
      - simulated_p5, simulated_p95: credible range of the statistic
      - posterior_predictive_p: where the observed value sits within the
        simulated distribution, as a fraction in [0, 1]
        (should be between 0.05 and 0.95 for a well-calibrated model)
      - pass: whether posterior_predictive_p is inside that range

    Statistics checked:
      - mean: overall mean prediction
      - variance: prediction variance (tests dispersion)
      - p90, p95: upper tail (important for large claim detection)

    An extra ``"_summary"`` key holds ``passed``, ``total``,
    ``failed_checks`` and a human-readable ``interpretation``.

    Args:
        model: A fitted model (HierarchicalFrequency or HierarchicalSeverity).
        claim_count_col: Column of observed claim counts in the model's
            segment data; used for frequency models.
        severity_col: Column of observed severities in the model's segment
            data; used for severity models when ``claim_count_col`` does
            not match.
        n_stats: Number of posterior predictive samples to use. More gives
            tighter bounds on the check statistics but takes longer.
        random_seed: Seed for the subsample of posterior predictive draws.
            ``None`` (the default) draws from NumPy's global random state;
            pass an integer for a reproducible report.

    Returns:
        Dict of check statistics. Examine the ``posterior_predictive_p``
        values: they should all be in [0.05, 0.95].

    Raises:
        ImportError: If PyMC is not installed.
        RuntimeError: If the model is not fitted, or has no posterior
            predictive samples.
        ValueError: If no column of observed values can be determined.
    """
    _check_pymc()

    if not getattr(model, "_fitted", False):
        raise RuntimeError("Model not fitted. Call .fit() first.")

    idata = model.idata

    if not hasattr(idata, "posterior_predictive"):
        raise RuntimeError(
            "No posterior predictive samples found. "
            "This should be computed automatically during fit(). "
            "If using a custom workflow, call pm.sample_posterior_predictive() "
            "and pass the result to extend_inferencedata=True."
        )

    pp = idata.posterior_predictive

    # Determine which predictive variable to check
    pp_vars = list(pp.data_vars)
    if not pp_vars:
        raise RuntimeError("posterior_predictive contains no variables.")

    pp_var = pp_vars[0]  # "claims" or "severity"
    # Assumed layout is (chains, draws, n_segments); every leading axis is
    # flattened so each row is one simulated dataset.
    pp_data = pp[pp_var].values
    pp_flat = pp_data.reshape(-1, pp_data.shape[-1])  # (n_samples, n_segments)

    # Sample a subset if large
    if pp_flat.shape[0] > n_stats:
        if random_seed is None:
            # Unseeded calls keep using the global NumPy random state.
            idx = np.random.choice(pp_flat.shape[0], n_stats, replace=False)
        else:
            rng = np.random.default_rng(random_seed)
            idx = rng.choice(pp_flat.shape[0], n_stats, replace=False)
        pp_flat = pp_flat[idx]

    # Get observed values
    df = model._segment_data
    if claim_count_col and claim_count_col in df.columns:
        observed_vals = df[claim_count_col].values.astype(float)
    elif severity_col and severity_col in df.columns:
        observed_vals = df[severity_col].values.astype(float)
    else:
        # Try to infer from model: first column that is not a grouping column.
        # NOTE(review): this branch is also reached when a column name was
        # passed but is absent from the segment data -- confirm that the
        # silent fallback is intended.
        non_group_cols = [c for c in df.columns if c not in model.group_cols]
        if not non_group_cols:
            raise ValueError(
                "Cannot determine observed values column. "
                "Pass claim_count_col or severity_col explicitly."
            )
        observed_vals = df[non_group_cols[0]].values.astype(float)

    results = {}

    # Mean check
    obs_mean = float(np.mean(observed_vals))
    sim_means = pp_flat.mean(axis=1)
    results["mean"] = _stat_check(obs_mean, sim_means)

    # Variance check
    obs_var = float(np.var(observed_vals))
    sim_vars = pp_flat.var(axis=1)
    results["variance"] = _stat_check(obs_var, sim_vars)

    # 90th percentile check
    obs_p90 = float(np.percentile(observed_vals, 90))
    sim_p90s = np.percentile(pp_flat, 90, axis=1)
    results["p90"] = _stat_check(obs_p90, sim_p90s)

    # 95th percentile
    obs_p95 = float(np.percentile(observed_vals, 95))
    sim_p95s = np.percentile(pp_flat, 95, axis=1)
    results["p95"] = _stat_check(obs_p95, sim_p95s)

    # Summary: did all checks pass? Count the checks before "_summary" is
    # inserted so the totals cover exactly the statistics above.
    n_checks = len(results)
    failed = [k for k, v in results.items() if not v["pass"]]
    results["_summary"] = {
        "passed": n_checks - len(failed),
        "total": n_checks,
        "failed_checks": failed,
        "interpretation": (
            "All checks passed. Model appears well-calibrated."
            if not failed
            else f"Failed checks: {failed}. "
            "The model may be mis-specified for these statistics. "
            "Consider alternative likelihood distributions."
        ),
    }

    return results
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
def _stat_check(
|
|
288
|
+
observed: float, simulated: np.ndarray, alpha: float = 0.05
|
|
289
|
+
) -> dict:
|
|
290
|
+
"""Check whether observed statistic is within the simulated range."""
|
|
291
|
+
p_value = float(np.mean(simulated <= observed))
|
|
292
|
+
return {
|
|
293
|
+
"observed": observed,
|
|
294
|
+
"simulated_mean": float(simulated.mean()),
|
|
295
|
+
"simulated_p5": float(np.percentile(simulated, 5)),
|
|
296
|
+
"simulated_p95": float(np.percentile(simulated, 95)),
|
|
297
|
+
"posterior_predictive_p": p_value,
|
|
298
|
+
"pass": alpha <= p_value <= (1 - alpha),
|
|
299
|
+
}
|