skxperiments 0.1.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36) hide show
  1. skxperiments/__init__.py +5 -0
  2. skxperiments/core/__init__.py +42 -0
  3. skxperiments/core/assignment.py +589 -0
  4. skxperiments/core/base.py +512 -0
  5. skxperiments/core/exceptions.py +145 -0
  6. skxperiments/core/potential_outcomes.py +168 -0
  7. skxperiments/core/results.py +624 -0
  8. skxperiments/design/__init__.py +22 -0
  9. skxperiments/design/balance.py +182 -0
  10. skxperiments/design/blocked_crd.py +157 -0
  11. skxperiments/design/crd.py +162 -0
  12. skxperiments/design/factorial.py +174 -0
  13. skxperiments/design/power.py +233 -0
  14. skxperiments/design/rerandomized_crd.py +319 -0
  15. skxperiments/diagnostics/__init__.py +21 -0
  16. skxperiments/diagnostics/aa_test.py +277 -0
  17. skxperiments/diagnostics/balance_report.py +224 -0
  18. skxperiments/diagnostics/srm.py +327 -0
  19. skxperiments/estimators/__init__.py +23 -0
  20. skxperiments/estimators/blocked_difference_in_means.py +197 -0
  21. skxperiments/estimators/cuped.py +280 -0
  22. skxperiments/estimators/difference_in_means.py +161 -0
  23. skxperiments/estimators/factorial_estimator.py +213 -0
  24. skxperiments/estimators/lin_estimator.py +298 -0
  25. skxperiments/inference/__init__.py +17 -0
  26. skxperiments/inference/bootstrap.py +450 -0
  27. skxperiments/inference/multiple.py +365 -0
  28. skxperiments/inference/neyman.py +386 -0
  29. skxperiments/inference/randomization_test.py +319 -0
  30. skxperiments/pipeline.py +366 -0
  31. skxperiments/reporting/__init__.py +30 -0
  32. skxperiments/reporting/plots.py +411 -0
  33. skxperiments/reporting/summary.py +185 -0
  34. skxperiments-0.1.0.dev0.dist-info/METADATA +272 -0
  35. skxperiments-0.1.0.dev0.dist-info/RECORD +36 -0
  36. skxperiments-0.1.0.dev0.dist-info/WHEEL +4 -0
@@ -0,0 +1,277 @@
1
+ """A/A test diagnostic: calibration of a design + inference pipeline.
2
+
3
+ An A/A test repeatedly re-randomizes a design on a *fixed* dataset and
4
+ runs an inference procedure each time. Because the treatment is re-drawn
5
+ while the outcome is held fixed, the true effect is zero by construction,
6
+ so a well-calibrated pipeline should reject the null at rate ``alpha`` and
7
+ produce uniformly distributed p-values. ``AATest`` measures both: the
8
+ false-positive rate (compared to ``alpha`` with an exact binomial test)
9
+ and the uniformity of the p-values (Kolmogorov-Smirnov).
10
+
11
+ Cost note
12
+ ---------
13
+ Each simulation runs the wrapped inference once. With a resampling-based
14
+ inference (``RandomizationTest``, ``BootstrapCI``) this is a nested loop:
15
+ ``O(n_simulations x n_resamples)`` estimator fits. For routine calibration
16
+ prefer the analytic ``NeymanCI``; reserve the resampling inferences for
17
+ small ``n_simulations`` (and expect a slow run).
18
+ """
19
+
20
+ from dataclasses import dataclass
21
+
22
+ import numpy as np
23
+ import pandas as pd
24
+ from scipy.stats import binomtest, kstest
25
+
26
+ from skxperiments.core.base import BaseDesign, BaseInference, DiagnosticsReport
27
+ from skxperiments.core.exceptions import InvalidDesignError
28
+
29
+
30
@dataclass(frozen=True, eq=False)
class AAResult:
    """Immutable record of a single A/A calibration run.

    Instances are frozen and, because ``eq=False``, compare by identity.

    Attributes
    ----------
    n_simulations : int
        How many re-randomizations were performed.
    alpha : float
        Nominal significance level used to count false positives.
    meta_threshold : float
        Cut-off for the binomial false-positive-rate test; a binomial
        p-value below it marks the pipeline as miscalibrated.
    false_positive_rate : float
        Share of simulations with ``p_value < alpha``.
    n_false_positives : int
        Count of simulations with ``p_value < alpha``.
    fp_test_pvalue : float
        Two-sided exact binomial-test p-value for ``n_false_positives``
        successes in ``n_simulations`` trials under null rate ``alpha``.
    ks_statistic : float
        Kolmogorov-Smirnov statistic of the p-values versus Uniform(0, 1).
    ks_pvalue : float
        KS test p-value (secondary signal covering the full distribution).
    p_values : np.ndarray
        The ``n_simulations`` p-values, in simulation order.
    flagged : bool
        True when ``fp_test_pvalue < meta_threshold``, i.e. the observed
        false-positive rate is incompatible with ``alpha``.
    """

    n_simulations: int
    alpha: float
    meta_threshold: float
    false_positive_rate: float
    n_false_positives: int
    fp_test_pvalue: float
    ks_statistic: float
    ks_pvalue: float
    p_values: np.ndarray
    flagged: bool

    def summary(self) -> "AAResult":
        """Print the headline numbers as a text table and return ``self``."""
        verdict = "OK"
        if self.flagged:
            verdict = "FLAGGED — miscalibrated"
        report_lines = [
            "A/A Test",
            "--------",
            f"simulations {self.n_simulations}",
            f"alpha {self.alpha}",
            (
                f"false-positive {self.false_positive_rate:.4f} "
                f"({self.n_false_positives}/{self.n_simulations})"
            ),
            f"FP binomial p {self.fp_test_pvalue:.4f}",
            f"KS uniform p {self.ks_pvalue:.4f}",
            f"status {verdict}",
        ]
        print("\n".join(report_lines))
        return self

    def to_dict(self) -> dict:
        """Return every scalar field as a dict; the p_values array is left out."""
        scalar_fields = (
            "n_simulations",
            "alpha",
            "meta_threshold",
            "false_positive_rate",
            "n_false_positives",
            "fp_test_pvalue",
            "ks_statistic",
            "ks_pvalue",
            "flagged",
        )
        return {name: getattr(self, name) for name in scalar_fields}

    def to_diagnostics_report(self) -> DiagnosticsReport:
        """Translate this result into a ``DiagnosticsReport``.

        A flagged false-positive rate is appended to ``report.flags``; a KS
        p-value below ``meta_threshold`` is appended to ``report.warnings``.
        The two checks are independent, so both entries may be present.
        """
        report = DiagnosticsReport()

        if self.flagged:
            flag_message = (
                f"A/A false-positive rate miscalibrated: "
                f"{self.false_positive_rate:.3f} vs alpha={self.alpha} "
                f"(binomial p={self.fp_test_pvalue:.2e} < "
                f"{self.meta_threshold})."
            )
            report.flags.append(flag_message)

        if self.ks_pvalue < self.meta_threshold:
            warning_message = (
                f"A/A p-values deviate from uniform "
                f"(KS p={self.ks_pvalue:.2e})."
            )
            report.warnings.append(warning_message)

        return report
122
+
123
+
124
class AATest:
    """Calibration check for a design + inference pipeline via A/A simulation.

    The design is re-randomized ``n_simulations`` times on one fixed
    DataFrame and ``inference`` is run on every draw. Only the treatment
    labels change between draws while the outcome stays fixed, so the true
    effect is zero: the rejection rate should match ``alpha`` and the
    collected p-values should look uniform.

    Parameters
    ----------
    design : BaseDesign
        Design whose randomization is being calibrated.
    inference : BaseInference
        A configured inference object that already wraps an estimator,
        e.g. ``NeymanCI(DifferenceInMeans("y"))`` or
        ``RandomizationTest(...)``. It must yield a scalar ``p_value``.
    n_simulations : int, optional
        Number of re-randomizations, by default 1000.
    alpha : float, optional
        Nominal significance level, by default 0.05.
    meta_threshold : float, optional
        Cut-off for the binomial false-positive-rate test, by default
        0.001. The pipeline is flagged when the binomial p-value is
        below it.
    seed : int or None, optional
        Random seed for reproducibility, by default None.

    Notes
    -----
    ``inference`` is refitted on each draw. If it exposes a ``seed``
    attribute, that attribute is overwritten per simulation and put back
    to its original value afterwards, even when a simulation raises. The
    outcome column is whatever the wrapped estimator resolves against the
    data; the DataFrame handed to ``run`` must contain it (plus any
    covariates) and must not contain the design's treatment column.
    """

    def __init__(
        self,
        design: BaseDesign,
        inference: BaseInference,
        n_simulations: int = 1000,
        alpha: float = 0.05,
        meta_threshold: float = 0.001,
        seed: int | None = None,
    ) -> None:
        # --- collaborators -------------------------------------------------
        if not isinstance(design, BaseDesign):
            design_type = type(design).__name__
            raise InvalidDesignError(
                f"design must be an instance of BaseDesign, got "
                f"{design_type}."
            )
        if not isinstance(inference, BaseInference):
            inference_type = type(inference).__name__
            raise InvalidDesignError(
                f"inference must be an instance of BaseInference, got "
                f"{inference_type}."
            )

        # --- simulation count ----------------------------------------------
        # bool is a subclass of int, so it is excluded explicitly.
        if isinstance(n_simulations, bool) or not isinstance(n_simulations, int):
            raise InvalidDesignError(
                f"n_simulations must be an integer, got "
                f"{type(n_simulations).__name__}."
            )
        if n_simulations < 1:
            raise InvalidDesignError(
                f"n_simulations must be > 0, got {n_simulations}."
            )

        # --- probabilities -------------------------------------------------
        probabilities = {"alpha": alpha, "meta_threshold": meta_threshold}
        for label, candidate in probabilities.items():
            is_number = isinstance(candidate, (int, float))
            if isinstance(candidate, bool) or not is_number:
                raise InvalidDesignError(
                    f"{label} must be a float in (0, 1), got "
                    f"{type(candidate).__name__}."
                )
            # Written as a negated chained comparison on purpose: NaN fails
            # both comparisons and is therefore rejected here.
            if not (0.0 < candidate < 1.0):
                raise InvalidDesignError(
                    f"{label} must be in (0, 1), got {candidate}."
                )

        self.design = design
        self.inference = inference
        self.n_simulations = n_simulations
        self.alpha = alpha
        self.meta_threshold = meta_threshold
        self.seed = seed

    def run(self, df: pd.DataFrame) -> AAResult:
        """Execute the A/A simulation on a fixed DataFrame.

        Parameters
        ----------
        df : pd.DataFrame
            Data holding the outcome (and any covariates) but not the
            design's treatment column.

        Returns
        -------
        AAResult
            False-positive count and rate, the binomial and KS test
            outputs, the raw p-values, and the ``flagged`` verdict.

        Raises
        ------
        InvalidDesignError
            If the wrapped inference returns ``p_value=None``.
        """
        template = self.design.randomize(df)
        total = self.n_simulations

        # Both seed streams come from one generator. The draw seeds are
        # sampled before the inference seeds; that order fixes the
        # reproducible sequence for a given ``self.seed``.
        generator = np.random.default_rng(self.seed)
        draw_seeds = generator.integers(0, 2**32, size=total)
        inference_seeds = generator.integers(0, 2**32, size=total)

        inference = self.inference
        reseed = hasattr(inference, "seed")
        saved_seed = getattr(inference, "seed", None)

        p_values = np.empty(total, dtype=float)
        try:
            seed_pairs = zip(draw_seeds, inference_seeds)
            for idx, (draw_seed, inference_seed) in enumerate(seed_pairs):
                assignment = template.draw(seed=int(draw_seed))
                if reseed:
                    inference.seed = int(inference_seed)
                inference.fit(assignment)
                estimate = inference.estimate()
                if estimate.p_value is None:
                    raise InvalidDesignError(
                        "AATest requires an inference that produces a scalar "
                        "p_value (e.g., RandomizationTest, NeymanCI, "
                        "BootstrapCI). The supplied "
                        f"{type(self.inference).__name__} returned "
                        "p_value=None."
                    )
                p_values[idx] = float(estimate.p_value)
        finally:
            # Leave the caller's inference object as it was found.
            if reseed:
                inference.seed = saved_seed

        rejections = int(np.sum(p_values < self.alpha))
        binomial = binomtest(rejections, total, self.alpha)
        binomial_p = float(binomial.pvalue)
        ks = kstest(p_values, "uniform")

        return AAResult(
            n_simulations=total,
            alpha=self.alpha,
            meta_threshold=self.meta_threshold,
            false_positive_rate=rejections / total,
            n_false_positives=rejections,
            fp_test_pvalue=binomial_p,
            ks_statistic=float(ks.statistic),
            ks_pvalue=float(ks.pvalue),
            p_values=p_values,
            flagged=bool(binomial_p < self.meta_threshold),
        )
@@ -0,0 +1,224 @@
1
+ """Covariate balance report diagnostic.
2
+
3
+ Wraps ``check_balance`` (Phase 2) to produce a covariate balance report:
4
+ the standardized mean difference (SMD) per covariate, plus a flag when
5
+ any covariate exceeds an absolute-SMD threshold. The conventional cutoff
6
+ for "meaningful" imbalance is ``|SMD| > 0.1`` (Austin 2009).
7
+
8
+ The Love plot is intentionally not produced here: rendering lives in the
9
+ Phase 7 reporting layer (``plot_balance``), which centralizes the
10
+ optional matplotlib dependency. ``BalanceResult.to_dataframe`` exposes
11
+ the table that such a plot would consume.
12
+ """
13
+
14
+ from dataclasses import dataclass
15
+
16
+ import numpy as np
17
+ import pandas as pd
18
+
19
+ from skxperiments.core.assignment import (
20
+ BlockedAssignment,
21
+ CRDAssignment,
22
+ )
23
+ from skxperiments.core.base import DiagnosticsReport
24
+ from skxperiments.core.exceptions import InvalidDesignError
25
+ from skxperiments.design.balance import check_balance
26
+
27
+
28
@dataclass(frozen=True, eq=False)
class BalanceResult:
    """Outcome of a covariate balance check.

    Attributes
    ----------
    table : pd.DataFrame
        Full balance table from ``check_balance``: one row per covariate
        with columns ``covariate``, ``mean_treated``, ``mean_control``,
        ``std_pooled``, ``smd``.
    threshold : float
        Absolute-SMD level above which a covariate counts as imbalanced.

    Notes
    -----
    A covariate without within-group variation has ``std_pooled == 0`` and
    an undefined (NaN) SMD. Such covariates are listed in
    ``constant_covariates`` and are never treated as imbalanced.
    """

    table: pd.DataFrame
    threshold: float

    @property
    def smd(self) -> dict:
        """Covariate name mapped to its SMD as a float, in table order."""
        mapping = {}
        for record in self.table.itertuples():
            mapping[record.covariate] = float(record.smd)
        return mapping

    @property
    def imbalanced(self) -> list[str]:
        """Names of covariates whose defined ``|SMD|`` is above the threshold."""
        offenders = []
        for record in self.table.itertuples():
            if np.isnan(record.smd):
                # Undefined SMD: reported separately, never as imbalance.
                continue
            if abs(record.smd) > self.threshold:
                offenders.append(record.covariate)
        return offenders

    @property
    def constant_covariates(self) -> list[str]:
        """Names of covariates whose SMD is NaN (no within-group variance)."""
        undefined = []
        for record in self.table.itertuples():
            if np.isnan(record.smd):
                undefined.append(record.covariate)
        return undefined

    @property
    def flagged(self) -> bool:
        """Whether at least one covariate is imbalanced."""
        return bool(self.imbalanced)

    @property
    def max_abs_smd(self) -> float:
        """Largest ``|SMD|`` over covariates; NaN when none is defined."""
        magnitudes = np.abs(self.table["smd"].to_numpy(dtype=float))
        if np.isnan(magnitudes).all():
            return float("nan")
        return float(np.nanmax(magnitudes))

    def to_dataframe(self) -> pd.DataFrame:
        """Hand back a copy of the balance table."""
        duplicate = self.table.copy()
        return duplicate

    def to_dict(self) -> dict:
        """Collect the summary fields into a plain dictionary."""
        return dict(
            threshold=self.threshold,
            flagged=self.flagged,
            imbalanced=self.imbalanced,
            constant_covariates=self.constant_covariates,
            max_abs_smd=self.max_abs_smd,
            smd=self.smd,
        )

    def summary(self) -> "BalanceResult":
        """Print a text summary (header plus one line per covariate); return self."""
        verdict = "OK"
        if self.flagged:
            verdict = "FLAGGED — covariate imbalance"
        output = [
            "Balance Report",
            "--------------",
            f"threshold |SMD| > {self.threshold}",
            f"max |SMD| {self.max_abs_smd:.4f}",
            f"status {verdict}",
            "covariate SMD",
        ]
        for record in self.table.itertuples():
            value = record.smd
            if np.isnan(value):
                rendered, marker = "nan (constant)", ""
            elif abs(value) > self.threshold:
                # Trailing star marks covariates above the threshold.
                rendered, marker = f"{value:+.4f}", " *"
            else:
                rendered, marker = f"{value:+.4f}", ""
            output.append(f" {record.covariate}: {rendered}{marker}")
        print("\n".join(output))
        return self

    def to_diagnostics_report(self) -> DiagnosticsReport:
        """Translate this result into a ``DiagnosticsReport``.

        Imbalanced covariates are appended to ``report.flags``; covariates
        with an undefined SMD are appended to ``report.warnings``.
        """
        report = DiagnosticsReport()

        offenders = self.imbalanced
        if offenders:
            report.flags.append(
                f"Covariate imbalance (|SMD| > {self.threshold}): "
                f"{offenders}."
            )

        undefined = self.constant_covariates
        if undefined:
            report.warnings.append(
                f"Constant covariates with undefined SMD: "
                f"{undefined}."
            )

        return report
140
+
141
+
142
class BalanceReport:
    """Covariate balance diagnostic for two-arm designs.

    Computes the standardized mean difference (SMD) per covariate via
    ``check_balance`` and flags covariates whose absolute SMD exceeds
    ``threshold``.

    Parameters
    ----------
    covariates : list of str or None, optional
        Covariates to check. If None, all numeric columns except the
        treatment column are used (see ``check_balance``). By default None.
    threshold : float, optional
        Absolute-SMD threshold for flagging imbalance, by default 0.1
        (Austin 2009). Must be positive; NaN is rejected. SMDs are
        unbounded, so no upper limit applies.

    Notes
    -----
    Supports ``CRDAssignment`` (including rerandomized) and
    ``BlockedAssignment``. ``FactorialAssignment`` is rejected: a single
    treated-vs-control SMD is not defined for multi-cell designs.
    """

    def __init__(
        self,
        covariates: list[str] | None = None,
        threshold: float = 0.1,
    ) -> None:
        # bool is a subclass of int, so it is excluded explicitly.
        if not isinstance(threshold, (int, float)) or isinstance(
            threshold, bool
        ):
            raise InvalidDesignError(
                f"threshold must be a positive float, got "
                f"{type(threshold).__name__}."
            )
        # ``not threshold > 0`` instead of ``threshold <= 0``: every
        # comparison with NaN is False, so the latter would let NaN through
        # and ``abs(smd) > nan`` would then never flag any covariate.
        if not threshold > 0.0:
            raise InvalidDesignError(
                f"threshold must be > 0, got {threshold}."
            )

        if covariates is not None:
            if not isinstance(covariates, list) or not all(
                isinstance(c, str) for c in covariates
            ):
                raise InvalidDesignError(
                    "covariates must be None or a list of column names."
                )

        self.covariates = covariates
        self.threshold = threshold

    def run(
        self,
        assignment: CRDAssignment | BlockedAssignment,
    ) -> BalanceResult:
        """Compute the balance report for an assignment.

        Parameters
        ----------
        assignment : CRDAssignment or BlockedAssignment
            Two-arm assignment to check.

        Returns
        -------
        BalanceResult
            The ``check_balance`` table paired with this report's threshold.

        Raises
        ------
        InvalidDesignError
            If the assignment is not two-arm, or if a covariate is missing
            or contains NaN (propagated from ``check_balance``).
        """
        if not isinstance(assignment, (CRDAssignment, BlockedAssignment)):
            raise InvalidDesignError(
                f"BalanceReport supports two-arm designs (CRDAssignment, "
                f"BlockedAssignment); received "
                f"{type(assignment).__name__}. Treated-vs-control SMD is "
                f"not defined for multi-cell factorial designs."
            )

        table = check_balance(assignment, self.covariates)
        return BalanceResult(table=table, threshold=self.threshold)