ssbc 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ssbc/hyperparameter.py ADDED
@@ -0,0 +1,258 @@
1
+ """Hyperparameter sweep and optimization for Mondrian conformal prediction."""
2
+
3
+ import itertools
4
+ from collections.abc import Callable
5
+ from typing import Any, Literal
6
+
7
+ import numpy as np
8
+ import pandas as pd
9
+
10
+ from .conformal import mondrian_conformal_calibrate
11
+ from .visualization import plot_parallel_coordinates_plotly, report_prediction_stats
12
+
13
+
14
+ def sweep_hyperparams_and_collect(
15
+ class_data: dict[int, dict[str, Any]],
16
+ alpha_0: np.ndarray,
17
+ delta_0: np.ndarray,
18
+ alpha_1: np.ndarray,
19
+ delta_1: np.ndarray,
20
+ mode: Literal["beta", "beta-binomial"] = "beta",
21
+ extra_metrics: dict[str, Callable] | None = None,
22
+ quiet: bool = True,
23
+ ) -> pd.DataFrame:
24
+ """Sweep (a0,d0,a1,d1), run mondrian_conformal_calibrate + report_prediction_stats,
25
+ and return a tidy DataFrame with hyperparams + selected metrics.
26
+
27
+ This function performs a grid search over hyperparameter combinations and
28
+ evaluates the resulting conformal prediction performance.
29
+
30
+ Parameters
31
+ ----------
32
+ class_data : dict
33
+ Output from split_by_class()
34
+ alpha_0 : array-like
35
+ Grid of alpha values for class 0
36
+ delta_0 : array-like
37
+ Grid of delta values for class 0
38
+ alpha_1 : array-like
39
+ Grid of alpha values for class 1
40
+ delta_1 : array-like
41
+ Grid of delta values for class 1
42
+ mode : str, default="beta"
43
+ "beta" or "beta-binomial" mode for SSBC
44
+ extra_metrics : dict of {name: function}, optional
45
+ Additional metrics to compute. Each function takes the summary dict
46
+ and returns a scalar value.
47
+ quiet : bool, default=True
48
+ If True, suppress progress output
49
+
50
+ Returns
51
+ -------
52
+ pd.DataFrame
53
+ Tidy dataframe with one row per hyperparameter combination.
54
+ Columns include:
55
+ - a0, d0, a1, d1: hyperparameters
56
+ - cov: overall coverage rate
57
+ - sing_rate: singleton prediction rate
58
+ - err_all: overall singleton error rate
59
+ - err_pred0, err_pred1: errors by predicted class
60
+ - err_y0, err_y1: errors by true class
61
+ - esc_rate: escalation rate (doublets + abstentions)
62
+ - n_total, sing_count, m_abst, m_doublets: counts
63
+ - Any additional metrics from extra_metrics
64
+
65
+ Examples
66
+ --------
67
+ >>> import numpy as np
68
+ >>> from ssbc import BinaryClassifierSimulator, split_by_class
69
+ >>>
70
+ >>> # Generate data
71
+ >>> sim = BinaryClassifierSimulator(0.1, (2, 8), (8, 2), seed=42)
72
+ >>> labels, probs = sim.generate(1000)
73
+ >>> class_data = split_by_class(labels, probs)
74
+ >>>
75
+ >>> # Define grid
76
+ >>> alpha_grid = np.arange(0.05, 0.20, 0.05)
77
+ >>> delta_grid = np.arange(0.05, 0.20, 0.05)
78
+ >>>
79
+ >>> # Run sweep
80
+ >>> df = sweep_hyperparams_and_collect(
81
+ ... class_data,
82
+ ... alpha_0=alpha_grid, delta_0=delta_grid,
83
+ ... alpha_1=alpha_grid, delta_1=delta_grid,
84
+ ... )
85
+ >>>
86
+ >>> # Analyze results
87
+ >>> print(df[['a0', 'a1', 'cov', 'sing_rate', 'err_all']].head())
88
+
89
+ Notes
90
+ -----
91
+ The function performs a complete grid search, so the total number of
92
+ evaluations is len(alpha_0) × len(delta_0) × len(alpha_1) × len(delta_1).
93
+ For large grids, this can be computationally expensive.
94
+ """
95
+ rows = []
96
+ combos = list(itertools.product(alpha_0, delta_0, alpha_1, delta_1))
97
+
98
+ for a0, d0, a1, d1 in combos:
99
+ if not quiet:
100
+ print(f"a0={a0:.3f}, d0={d0:.3f}, a1={a1:.3f}, d1={d1:.3f}")
101
+
102
+ cal_result, pred_stats = mondrian_conformal_calibrate(
103
+ class_data=class_data,
104
+ alpha_target={0: float(a0), 1: float(a1)},
105
+ delta={0: float(d0), 1: float(d1)},
106
+ mode=mode,
107
+ )
108
+ summary = report_prediction_stats(pred_stats, cal_result, verbose=False)
109
+
110
+ # Robust getter
111
+ def g(d, *keys, default=None):
112
+ """Navigate nested dict safely."""
113
+ cur = d
114
+ for k in keys:
115
+ if not isinstance(cur, dict) or k not in cur:
116
+ return default
117
+ cur = cur[k]
118
+ return cur
119
+
120
+ n_total = int(g(summary, "marginal", "n_total", default=0) or 0)
121
+ cov = float(g(summary, "marginal", "coverage", "rate", default=0.0) or 0.0)
122
+ sing_rate = float(g(summary, "marginal", "singletons", "rate", default=0.0) or 0.0)
123
+ sing_cnt = int(g(summary, "marginal", "singletons", "count", default=0) or 0)
124
+ abst_cnt = int(g(summary, "marginal", "abstentions", "count", default=0) or 0)
125
+ doub_cnt = int(g(summary, "marginal", "doublets", "count", default=0) or 0)
126
+ esc_rate = (abst_cnt + doub_cnt) / float(n_total if n_total else 1)
127
+
128
+ err_all = float(g(summary, "marginal", "singletons", "errors", "rate", default=0.0) or 0.0)
129
+ err_p0 = float(g(summary, "marginal", "singletons", "errors_by_pred", "pred_0", "rate", default=0.0) or 0.0)
130
+ err_p1 = float(g(summary, "marginal", "singletons", "errors_by_pred", "pred_1", "rate", default=0.0) or 0.0)
131
+
132
+ err_y0 = float(g(summary, 0, "singletons", "error_given_singleton", "rate", default=0.0) or 0.0)
133
+ err_y1 = float(g(summary, 1, "singletons", "error_given_singleton", "rate", default=0.0) or 0.0)
134
+
135
+ row = {
136
+ "a0": float(a0),
137
+ "d0": float(d0),
138
+ "a1": float(a1),
139
+ "d1": float(d1),
140
+ "cov": cov,
141
+ "sing_rate": sing_rate,
142
+ "err_all": err_all,
143
+ "err_pred0": err_p0,
144
+ "err_pred1": err_p1,
145
+ "err_y0": err_y0,
146
+ "err_y1": err_y1,
147
+ "esc_rate": esc_rate,
148
+ "n_total": int(n_total),
149
+ "sing_count": int(sing_cnt),
150
+ "m_abst": abst_cnt,
151
+ "m_doublets": doub_cnt,
152
+ }
153
+
154
+ if extra_metrics:
155
+ for name, fn in extra_metrics.items():
156
+ try:
157
+ row[name] = fn(summary)
158
+ except Exception:
159
+ row[name] = np.nan
160
+
161
+ rows.append(row)
162
+
163
+ df = pd.DataFrame(rows)
164
+ return df.sort_values(["a0", "d0", "a1", "d1"], kind="mergesort").reset_index(drop=True)
165
+
166
+
167
def sweep_and_plot_parallel_plotly(
    class_data: dict[int, dict[str, Any]],
    delta_0: np.ndarray,
    delta_1: np.ndarray,
    alpha_0: np.ndarray,
    alpha_1: np.ndarray,
    mode: Literal["beta", "beta-binomial"] = "beta",
    extra_metrics: dict[str, Callable] | None = None,
    color: str = "err_all",
    color_continuous_scale=None,
    title: str | None = None,
    height: int = 600,
):
    """Run a hyperparameter sweep and build a plotly parallel-coordinates figure of it.

    This is a one-call convenience around ``sweep_hyperparams_and_collect``
    (always run quietly) followed by ``plot_parallel_coordinates_plotly``.

    Parameters
    ----------
    class_data : dict
        Output from split_by_class()
    delta_0, delta_1 : array-like
        Grid of delta values for classes 0 and 1
    alpha_0, alpha_1 : array-like
        Grid of alpha values for classes 0 and 1
    mode : str, default="beta"
        "beta" or "beta-binomial" mode for SSBC
    extra_metrics : dict of {name: function}, optional
        Additional metrics to compute
    color : str, default='err_all'
        Column to use for coloring the parallel coordinates
    color_continuous_scale : plotly colorscale, optional
        Color scale for the plot
    title : str, optional
        Plot title. When omitted, a title stating the number of swept
        configurations is generated.
    height : int, default=600
        Plot height in pixels

    Returns
    -------
    df : pd.DataFrame
        Results dataframe
    fig : plotly.graph_objects.Figure
        Interactive parallel coordinates plot

    Examples
    --------
    >>> import numpy as np
    >>> from ssbc import BinaryClassifierSimulator, split_by_class
    >>>
    >>> # Generate data
    >>> sim = BinaryClassifierSimulator(0.1, (2, 8), (8, 2), seed=42)
    >>> labels, probs = sim.generate(1000)
    >>> class_data = split_by_class(labels, probs)
    >>>
    >>> # Run sweep and plot
    >>> df, fig = sweep_and_plot_parallel_plotly(
    ...     class_data,
    ...     delta_0=np.arange(0.05, 0.20, 0.05),
    ...     delta_1=np.arange(0.05, 0.20, 0.05),
    ...     alpha_0=np.arange(0.05, 0.20, 0.05),
    ...     alpha_1=np.arange(0.05, 0.20, 0.05),
    ...     color='err_all'
    ... )
    >>> fig.show()  # Display in notebook
    >>> # Or save: fig.write_html("sweep_results.html")

    Notes
    -----
    The parallel coordinates plot allows interactive exploration of the
    hyperparameter space. You can brush (select) ranges on any axis to
    filter configurations and see their impact on other metrics.
    """
    results = sweep_hyperparams_and_collect(
        class_data=class_data,
        alpha_0=alpha_0,
        delta_0=delta_0,
        alpha_1=alpha_1,
        delta_1=delta_1,
        mode=mode,
        extra_metrics=extra_metrics,
        quiet=True,
    )

    plot_title = title if title is not None else f"Mondrian Hyperparameter Sweep (n={len(results)} configs)"

    figure = plot_parallel_coordinates_plotly(
        results,
        color=color,
        color_continuous_scale=color_continuous_scale,
        title=plot_title,
        height=height,
    )
    return results, figure
ssbc/simulation.py ADDED
@@ -0,0 +1,148 @@
1
+ """Simulation utilities for testing conformal prediction."""
2
+
3
+ import numpy as np
4
+
5
+
6
+ class BinaryClassifierSimulator:
7
+ """Simulate binary classification data with probabilities from Beta distributions.
8
+
9
+ This simulator generates realistic classification scenarios where the predicted
10
+ probabilities for each class follow Beta distributions. Useful for testing and
11
+ benchmarking conformal prediction methods.
12
+
13
+ Parameters
14
+ ----------
15
+ p_class1 : float
16
+ Probability of drawing class 1 (class imbalance parameter)
17
+ Must be in [0, 1]
18
+ beta_params_class0 : tuple of (a, b)
19
+ Beta distribution parameters for p(class=1) when true label is 0
20
+ Typically use parameters that give low probabilities (e.g., (2, 8))
21
+ beta_params_class1 : tuple of (a, b)
22
+ Beta distribution parameters for p(class=1) when true label is 1
23
+ Typically use parameters that give high probabilities (e.g., (8, 2))
24
+ seed : int, optional
25
+ Random seed for reproducibility
26
+
27
+ Attributes
28
+ ----------
29
+ p_class1 : float
30
+ Probability of class 1
31
+ p_class0 : float
32
+ Probability of class 0 (= 1 - p_class1)
33
+ a0, b0 : float
34
+ Beta parameters for class 0
35
+ a1, b1 : float
36
+ Beta parameters for class 1
37
+ rng : numpy.random.Generator
38
+ Random number generator
39
+
40
+ Examples
41
+ --------
42
+ >>> # Simulate imbalanced data: 10% positive class
43
+ >>> # Class 0: Beta(2, 8) → mean p(class=1) = 0.2 (low scores, correct)
44
+ >>> # Class 1: Beta(8, 2) → mean p(class=1) = 0.8 (high scores, correct)
45
+ >>> sim = BinaryClassifierSimulator(
46
+ ... p_class1=0.10,
47
+ ... beta_params_class0=(2, 8),
48
+ ... beta_params_class1=(8, 2),
49
+ ... seed=42
50
+ ... )
51
+ >>> labels, probs = sim.generate(n_samples=100)
52
+ >>> print(labels.shape)
53
+ (100,)
54
+ >>> print(probs.shape)
55
+ (100, 2)
56
+
57
+ Notes
58
+ -----
59
+ The Beta distribution parameters (a, b) control the shape:
60
+ - Mean = a / (a + b)
61
+ - For a classifier that works well:
62
+ - Class 0 should have low p(class=1): use (a, b) with a < b
63
+ - Class 1 should have high p(class=1): use (a, b) with a > b
64
+ """
65
+
66
+ def __init__(
67
+ self,
68
+ p_class1: float,
69
+ beta_params_class0: tuple[float, float],
70
+ beta_params_class1: tuple[float, float],
71
+ seed: int | None = None,
72
+ ):
73
+ """Initialize the binary classifier simulator."""
74
+ if not 0 <= p_class1 <= 1:
75
+ raise ValueError("p_class1 must be in [0, 1]")
76
+
77
+ self.p_class1 = p_class1
78
+ self.p_class0 = 1.0 - p_class1
79
+ self.a0, self.b0 = beta_params_class0
80
+ self.a1, self.b1 = beta_params_class1
81
+ self.rng = np.random.default_rng(seed)
82
+
83
+ # Validate beta parameters
84
+ if self.a0 <= 0 or self.b0 <= 0:
85
+ raise ValueError("Beta parameters for class 0 must be positive")
86
+ if self.a1 <= 0 or self.b1 <= 0:
87
+ raise ValueError("Beta parameters for class 1 must be positive")
88
+
89
+ def generate(self, n_samples: int) -> tuple[np.ndarray, np.ndarray]:
90
+ """Generate n_samples of (label, p(class=0), p(class=1)).
91
+
92
+ Parameters
93
+ ----------
94
+ n_samples : int
95
+ Number of samples to generate
96
+
97
+ Returns
98
+ -------
99
+ labels : np.ndarray, shape (n_samples,)
100
+ True binary labels (0 or 1)
101
+ probs : np.ndarray, shape (n_samples, 2)
102
+ Classification probabilities [p(class=0), p(class=1)]
103
+ Each row sums to 1.0
104
+
105
+ Examples
106
+ --------
107
+ >>> sim = BinaryClassifierSimulator(
108
+ ... p_class1=0.5,
109
+ ... beta_params_class0=(2, 8),
110
+ ... beta_params_class1=(8, 2),
111
+ ... seed=42
112
+ ... )
113
+ >>> labels, probs = sim.generate(n_samples=5)
114
+ >>> print(f"Generated {len(labels)} samples")
115
+ Generated 5 samples
116
+ >>> print(f"Class balance: {np.bincount(labels)}")
117
+ Class balance: [2 3]
118
+ """
119
+ if n_samples <= 0:
120
+ raise ValueError("n_samples must be positive")
121
+
122
+ # Draw true labels according to class distribution
123
+ labels = self.rng.choice([0, 1], size=n_samples, p=[self.p_class0, self.p_class1])
124
+
125
+ # Initialize probability array
126
+ probs = np.zeros((n_samples, 2))
127
+
128
+ # For each label, draw classification probability from appropriate Beta
129
+ for i, label in enumerate(labels):
130
+ if label == 0:
131
+ # True label is 0: sample p(class=1) from Beta(a0, b0)
132
+ p_class1 = self.rng.beta(self.a0, self.b0)
133
+ else:
134
+ # True label is 1: sample p(class=1) from Beta(a1, b1)
135
+ p_class1 = self.rng.beta(self.a1, self.b1)
136
+
137
+ probs[i, 1] = p_class1 # p(class=1)
138
+ probs[i, 0] = 1.0 - p_class1 # p(class=0)
139
+
140
+ return labels, probs
141
+
142
+ def __repr__(self) -> str:
143
+ """String representation of the simulator."""
144
+ return (
145
+ f"BinaryClassifierSimulator(p_class1={self.p_class1:.3f}, "
146
+ f"beta_class0=({self.a0}, {self.b0}), "
147
+ f"beta_class1=({self.a1}, {self.b1}))"
148
+ )
ssbc/ssbc.py ADDED
@@ -0,0 +1 @@
1
+ """Main module."""
ssbc/statistics.py ADDED
@@ -0,0 +1,158 @@
1
+ """Statistical utility functions for SSBC."""
2
+
3
+ from typing import Any
4
+
5
+ import numpy as np
6
+ from scipy import stats
7
+
8
+
9
def clopper_pearson_intervals(labels: np.ndarray, confidence: float = 0.95) -> dict[int, dict[str, Any]]:
    """Compute Clopper-Pearson (exact binomial) confidence intervals for class prevalences.

    Parameters
    ----------
    labels : array-like
        Binary labels (0 or 1). Any array-like (e.g. a list) is accepted and
        converted with ``np.asarray``.
    confidence : float, default=0.95
        Confidence level (e.g., 0.95 for 95% CI)

    Returns
    -------
    dict
        Dictionary with keys 0 and 1, each containing:
        - 'count': number of samples in this class (int)
        - 'proportion': observed proportion (float)
        - 'lower': lower bound of CI (float)
        - 'upper': upper bound of CI (float)
        For empty ``labels`` every entry is 0 / 0.0, matching the
        ``total <= 0`` convention of ``cp_interval``.

    Examples
    --------
    >>> labels = np.array([0, 0, 1, 1, 0])
    >>> intervals = clopper_pearson_intervals(labels, confidence=0.95)
    >>> print(intervals[0]['proportion'])
    0.6

    Notes
    -----
    The Clopper-Pearson interval is an exact binomial confidence interval
    based on Beta distribution quantiles. It provides conservative coverage
    guarantees. Values other than 0 and 1 are counted in the total but in
    neither class.
    """
    # Without asarray, a plain list makes `labels == label` a scalar False,
    # silently giving zero counts for both classes.
    labels = np.asarray(labels)
    alpha = 1 - confidence
    n_total = int(labels.size)

    intervals: dict[int, dict[str, Any]] = {}

    for label in (0, 1):
        count = int(np.count_nonzero(labels == label))

        if n_total == 0:
            # No observations: avoid 0/0 and invalid Beta(., 0) quantiles.
            intervals[label] = {"count": 0, "proportion": 0.0, "lower": 0.0, "upper": 0.0}
            continue

        proportion = count / n_total

        # Clopper-Pearson uses Beta distribution quantiles
        # Lower bound: Beta(count, n-count+1) at alpha/2   (exactly 0 when count == 0)
        # Upper bound: Beta(count+1, n-count) at 1-alpha/2 (exactly 1 when count == n)
        if count == 0:
            lower = 0.0
        else:
            lower = float(stats.beta.ppf(alpha / 2, count, n_total - count + 1))

        if count == n_total:
            upper = 1.0
        else:
            upper = float(stats.beta.ppf(1 - alpha / 2, count + 1, n_total - count))

        intervals[label] = {"count": count, "proportion": float(proportion), "lower": lower, "upper": upper}

    return intervals
67
+
68
+
69
+ def cp_interval(count: int, total: int, confidence: float = 0.95) -> dict[str, float]:
70
+ """Compute Clopper-Pearson exact confidence interval.
71
+
72
+ Helper function for computing a single CI from count and total.
73
+
74
+ Parameters
75
+ ----------
76
+ count : int
77
+ Number of successes
78
+ total : int
79
+ Total number of trials
80
+ confidence : float, default=0.95
81
+ Confidence level
82
+
83
+ Returns
84
+ -------
85
+ dict
86
+ Dictionary with keys:
87
+ - 'count': original count
88
+ - 'proportion': count/total
89
+ - 'lower': lower CI bound
90
+ - 'upper': upper CI bound
91
+ """
92
+ alpha = 1 - confidence
93
+ count = int(count)
94
+ total = int(total)
95
+
96
+ if total <= 0:
97
+ return {"count": count, "proportion": 0.0, "lower": 0.0, "upper": 0.0}
98
+
99
+ p = count / total
100
+
101
+ if count == 0:
102
+ lower = 0.0
103
+ upper = stats.beta.ppf(1 - alpha / 2, 1, total)
104
+ elif count == total:
105
+ lower = stats.beta.ppf(alpha / 2, total, 1)
106
+ upper = 1.0
107
+ else:
108
+ lower = stats.beta.ppf(alpha / 2, count, total - count + 1)
109
+ upper = stats.beta.ppf(1 - alpha / 2, count + 1, total - count)
110
+
111
+ return {"count": count, "proportion": float(p), "lower": float(lower), "upper": float(upper)}
112
+
113
+
114
+ def ensure_ci(d: dict[str, Any], count: int, total: int, confidence: float = 0.95) -> tuple[float, float, float]:
115
+ """Extract or compute rate and confidence interval from a dictionary.
116
+
117
+ If the dictionary already contains rate/CI information, use it.
118
+ Otherwise, compute Clopper-Pearson CI from count/total.
119
+
120
+ Parameters
121
+ ----------
122
+ d : dict
123
+ Dictionary that may contain 'rate'/'proportion' and 'lower'/'upper'
124
+ count : int
125
+ Count for CI computation (if needed)
126
+ total : int
127
+ Total for CI computation (if needed)
128
+ confidence : float, default=0.95
129
+ Confidence level
130
+
131
+ Returns
132
+ -------
133
+ tuple of (rate, lower, upper)
134
+ Rate and confidence interval bounds
135
+ """
136
+ # Try to get existing rate
137
+ r = None
138
+ if isinstance(d, dict):
139
+ if "rate" in d:
140
+ r = float(d["rate"])
141
+ elif "proportion" in d:
142
+ r = float(d["proportion"])
143
+
144
+ # Try to get existing CI
145
+ lo, hi = 0.0, 0.0
146
+ if isinstance(d, dict):
147
+ if "ci_95" in d and isinstance(d["ci_95"], tuple | list) and len(d["ci_95"]) == 2:
148
+ lo, hi = float(d["ci_95"][0]), float(d["ci_95"][1])
149
+ else:
150
+ lo = float(d.get("lower", 0.0))
151
+ hi = float(d.get("upper", 0.0))
152
+
153
+ # If missing or invalid, compute CP interval
154
+ if r is None or (lo == 0.0 and hi == 0.0 and (count > 0 or total > 0)):
155
+ ci = cp_interval(count, total, confidence)
156
+ return ci["proportion"], ci["lower"], ci["upper"]
157
+
158
+ return float(r), float(lo), float(hi)
ssbc/utils.py ADDED
@@ -0,0 +1,2 @@
1
def do_something_useful():
    """Placeholder utility: print a reminder that it should be replaced."""
    message = "Replace this with a utility function"
    print(message)