skxperiments 0.1.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- skxperiments/__init__.py +5 -0
- skxperiments/core/__init__.py +42 -0
- skxperiments/core/assignment.py +589 -0
- skxperiments/core/base.py +512 -0
- skxperiments/core/exceptions.py +145 -0
- skxperiments/core/potential_outcomes.py +168 -0
- skxperiments/core/results.py +624 -0
- skxperiments/design/__init__.py +22 -0
- skxperiments/design/balance.py +182 -0
- skxperiments/design/blocked_crd.py +157 -0
- skxperiments/design/crd.py +162 -0
- skxperiments/design/factorial.py +174 -0
- skxperiments/design/power.py +233 -0
- skxperiments/design/rerandomized_crd.py +319 -0
- skxperiments/diagnostics/__init__.py +21 -0
- skxperiments/diagnostics/aa_test.py +277 -0
- skxperiments/diagnostics/balance_report.py +224 -0
- skxperiments/diagnostics/srm.py +327 -0
- skxperiments/estimators/__init__.py +23 -0
- skxperiments/estimators/blocked_difference_in_means.py +197 -0
- skxperiments/estimators/cuped.py +280 -0
- skxperiments/estimators/difference_in_means.py +161 -0
- skxperiments/estimators/factorial_estimator.py +213 -0
- skxperiments/estimators/lin_estimator.py +298 -0
- skxperiments/inference/__init__.py +17 -0
- skxperiments/inference/bootstrap.py +450 -0
- skxperiments/inference/multiple.py +365 -0
- skxperiments/inference/neyman.py +386 -0
- skxperiments/inference/randomization_test.py +319 -0
- skxperiments/pipeline.py +366 -0
- skxperiments/reporting/__init__.py +30 -0
- skxperiments/reporting/plots.py +411 -0
- skxperiments/reporting/summary.py +185 -0
- skxperiments-0.1.0.dev0.dist-info/METADATA +272 -0
- skxperiments-0.1.0.dev0.dist-info/RECORD +36 -0
- skxperiments-0.1.0.dev0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,277 @@
|
|
|
1
|
+
"""A/A test diagnostic: calibration of a design + inference pipeline.
|
|
2
|
+
|
|
3
|
+
An A/A test repeatedly re-randomizes a design on a *fixed* dataset and
|
|
4
|
+
runs an inference procedure each time. Because the treatment is re-drawn
|
|
5
|
+
while the outcome is held fixed, the true effect is zero by construction,
|
|
6
|
+
so a well-calibrated pipeline should reject the null at rate ``alpha`` and
|
|
7
|
+
produce uniformly distributed p-values. ``AATest`` measures both: the
|
|
8
|
+
false-positive rate (compared to ``alpha`` with an exact binomial test)
|
|
9
|
+
and the uniformity of the p-values (Kolmogorov-Smirnov).
|
|
10
|
+
|
|
11
|
+
Cost note
|
|
12
|
+
---------
|
|
13
|
+
Each simulation runs the wrapped inference once. With a resampling-based
|
|
14
|
+
inference (``RandomizationTest``, ``BootstrapCI``) this is a nested loop:
|
|
15
|
+
``O(n_simulations x n_resamples)`` estimator fits. For routine calibration
|
|
16
|
+
prefer the analytic ``NeymanCI``; reserve the resampling inferences for
|
|
17
|
+
small ``n_simulations`` (and expect a slow run).
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
from dataclasses import dataclass
|
|
21
|
+
|
|
22
|
+
import numpy as np
|
|
23
|
+
import pandas as pd
|
|
24
|
+
from scipy.stats import binomtest, kstest
|
|
25
|
+
|
|
26
|
+
from skxperiments.core.base import BaseDesign, BaseInference, DiagnosticsReport
|
|
27
|
+
from skxperiments.core.exceptions import InvalidDesignError
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@dataclass(frozen=True, eq=False)
class AAResult:
    """Outcome of an A/A calibration run.

    Attributes
    ----------
    n_simulations : int
        How many re-randomizations were executed.
    alpha : float
        The nominal significance level used for every simulated test.
    meta_threshold : float
        Cut-off for the binomial false-positive-rate test; a binomial
        p-value below it marks the pipeline as miscalibrated.
    false_positive_rate : float
        Share of simulations whose ``p_value < alpha``.
    n_false_positives : int
        Count of simulations whose ``p_value < alpha``.
    fp_test_pvalue : float
        Two-sided exact binomial-test p-value for observing
        ``n_false_positives`` rejections in ``n_simulations`` trials when
        the true rejection rate is ``alpha``.
    ks_statistic : float
        Kolmogorov-Smirnov distance between the p-values and Uniform(0, 1).
    ks_pvalue : float
        p-value of that KS test (a secondary, whole-distribution check).
    p_values : np.ndarray
        All ``n_simulations`` p-values, ordered by simulation.
    flagged : bool
        Whether ``fp_test_pvalue < meta_threshold``, i.e. the observed
        false-positive rate is inconsistent with ``alpha``.
    """

    n_simulations: int
    alpha: float
    meta_threshold: float
    false_positive_rate: float
    n_false_positives: int
    fp_test_pvalue: float
    ks_statistic: float
    ks_pvalue: float
    p_values: np.ndarray
    flagged: bool

    def summary(self) -> "AAResult":
        """Print a human-readable summary table, then return ``self``."""
        verdict = "FLAGGED — miscalibrated" if self.flagged else "OK"
        fp_line = (
            f"false-positive {self.false_positive_rate:.4f} "
            f"({self.n_false_positives}/{self.n_simulations})"
        )
        rows = [
            "A/A Test",
            "--------",
            f"simulations {self.n_simulations}",
            f"alpha {self.alpha}",
            fp_line,
            f"FP binomial p {self.fp_test_pvalue:.4f}",
            f"KS uniform p {self.ks_pvalue:.4f}",
            f"status {verdict}",
        ]
        print("\n".join(rows))
        return self

    def to_dict(self) -> dict:
        """Scalar summary fields as a plain dict (the p_values array is omitted)."""
        scalar_fields = (
            "n_simulations",
            "alpha",
            "meta_threshold",
            "false_positive_rate",
            "n_false_positives",
            "fp_test_pvalue",
            "ks_statistic",
            "ks_pvalue",
            "flagged",
        )
        return {name: getattr(self, name) for name in scalar_fields}

    def to_diagnostics_report(self) -> DiagnosticsReport:
        """Build a ``DiagnosticsReport`` suitable for pipeline aggregation.

        A false-positive rate inconsistent with ``alpha`` is recorded as a
        flag; a KS p-value below ``meta_threshold`` (non-uniform p-values)
        is recorded as a warning.
        """
        report = DiagnosticsReport()
        if self.flagged:
            report.flags.append(
                "A/A false-positive rate miscalibrated: "
                f"{self.false_positive_rate:.3f} vs alpha={self.alpha} "
                f"(binomial p={self.fp_test_pvalue:.2e} < {self.meta_threshold})."
            )
        non_uniform = self.ks_pvalue < self.meta_threshold
        if non_uniform:
            report.warnings.append(
                f"A/A p-values deviate from uniform (KS p={self.ks_pvalue:.2e})."
            )
        return report
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
class AATest:
    """A/A test for a design + inference pipeline.

    Re-randomizes ``design`` on a fixed DataFrame ``n_simulations`` times
    and runs ``inference`` on each draw, collecting the p-values. Because
    the treatment is re-drawn while the outcome is fixed, the true effect
    is zero, so the false-positive rate should equal ``alpha`` and the
    p-values should be uniform.

    Parameters
    ----------
    design : BaseDesign
        The design whose randomization is being calibrated.
    inference : BaseInference
        A configured inference object (which already wraps an estimator),
        e.g. ``NeymanCI(DifferenceInMeans("y"))`` or
        ``RandomizationTest(...)``. Must produce a finite scalar
        ``p_value``.
    n_simulations : int, optional
        Number of re-randomizations, by default 1000. Python ints and
        NumPy integer scalars are accepted; booleans are rejected.
    alpha : float, optional
        Nominal significance level in (0, 1), by default 0.05.
    meta_threshold : float, optional
        Threshold in (0, 1) for the binomial false-positive-rate test, by
        default 0.001. The pipeline is flagged when the binomial-test
        p-value falls below it.
    seed : int or None, optional
        Random seed for reproducibility, by default None.

    Raises
    ------
    InvalidDesignError
        If ``design`` / ``inference`` have the wrong type, or if
        ``n_simulations``, ``alpha`` or ``meta_threshold`` have the wrong
        type or are out of range.

    Notes
    -----
    The wrapped ``inference`` is refitted on each draw; its ``seed`` (if
    any) is varied per simulation and restored afterward. The outcome
    column is whatever the wrapped estimator resolves against the data;
    the supplied DataFrame must contain it (and any covariates) and must
    not contain the design's treatment column.
    """

    def __init__(
        self,
        design: BaseDesign,
        inference: BaseInference,
        n_simulations: int = 1000,
        alpha: float = 0.05,
        meta_threshold: float = 0.001,
        seed: int | None = None,
    ) -> None:
        if not isinstance(design, BaseDesign):
            raise InvalidDesignError(
                "design must be an instance of BaseDesign, got "
                f"{type(design).__name__}."
            )
        if not isinstance(inference, BaseInference):
            raise InvalidDesignError(
                "inference must be an instance of BaseInference, got "
                f"{type(inference).__name__}."
            )

        # bool is a subclass of int, so it must be excluded explicitly.
        # NumPy integer scalars (e.g. np.int64 from a parameter grid) are
        # accepted; np.bool_ is not an np.integer, so it stays rejected.
        if isinstance(n_simulations, bool) or not isinstance(
            n_simulations, (int, np.integer)
        ):
            raise InvalidDesignError(
                "n_simulations must be an integer, got "
                f"{type(n_simulations).__name__}."
            )
        if n_simulations <= 0:
            raise InvalidDesignError(
                f"n_simulations must be > 0, got {n_simulations}."
            )

        for name, value in (("alpha", alpha), ("meta_threshold", meta_threshold)):
            if not isinstance(value, (int, float)) or isinstance(value, bool):
                raise InvalidDesignError(
                    f"{name} must be a float in (0, 1), got "
                    f"{type(value).__name__}."
                )
            # The chained comparison is False for NaN, so NaN is rejected too.
            if not (0.0 < value < 1.0):
                raise InvalidDesignError(
                    f"{name} must be in (0, 1), got {value}."
                )

        self.design = design
        self.inference = inference
        # Normalize NumPy integers to a plain int for range()/np.empty and
        # for the AAResult field.
        self.n_simulations = int(n_simulations)
        self.alpha = alpha
        self.meta_threshold = meta_threshold
        self.seed = seed

    def _scalar_p_value(self, raw) -> float:
        """Coerce one inference p-value to a finite float, or raise.

        Raises
        ------
        InvalidDesignError
            If ``raw`` is None, cannot be converted to a single float, or
            is NaN / infinite.
        """
        inference_name = type(self.inference).__name__
        if raw is None:
            raise InvalidDesignError(
                "AATest requires an inference that produces a scalar "
                "p_value (e.g., RandomizationTest, NeymanCI, "
                "BootstrapCI). The supplied "
                f"{inference_name} returned "
                "p_value=None."
            )
        try:
            value = float(raw)
        except (TypeError, ValueError) as err:
            raise InvalidDesignError(
                "AATest requires an inference that produces a scalar "
                f"p_value; the supplied {inference_name} returned {raw!r}."
            ) from err
        # A NaN p-value would count as "not rejected" and make the KS test
        # return NaN, silently corrupting the calibration verdict.
        if not np.isfinite(value):
            raise InvalidDesignError(
                "AATest requires a finite p_value; the supplied "
                f"{inference_name} returned p_value={value}."
            )
        return value

    def run(self, df: pd.DataFrame) -> AAResult:
        """Run the A/A test on a fixed DataFrame.

        Parameters
        ----------
        df : pd.DataFrame
            Data with the outcome (and any covariates) but without the
            design's treatment column.

        Returns
        -------
        AAResult

        Raises
        ------
        InvalidDesignError
            If the wrapped inference returns a p-value that is None,
            not a scalar, or not finite.
        """
        base = self.design.randomize(df)

        # Both seed streams are drawn up front from a local generator, in a
        # fixed order, so a given ``self.seed`` reproduces the same
        # sequence of assignment draws and inference seeds.
        rng = np.random.default_rng(self.seed)
        sim_seeds = rng.integers(0, 2**32, size=self.n_simulations)
        inf_seeds = rng.integers(0, 2**32, size=self.n_simulations)

        has_seed = hasattr(self.inference, "seed")
        original_inf_seed = getattr(self.inference, "seed", None)

        p_values = np.empty(self.n_simulations, dtype=float)
        try:
            for i, (sim_seed, inf_seed) in enumerate(zip(sim_seeds, inf_seeds)):
                assignment = base.draw(seed=int(sim_seed))
                if has_seed:
                    self.inference.seed = int(inf_seed)
                self.inference.fit(assignment)
                result = self.inference.estimate()
                p_values[i] = self._scalar_p_value(result.p_value)
        finally:
            # Restore the caller's seed even if a fit/estimate call raised.
            if has_seed:
                self.inference.seed = original_inf_seed

        n_fp = int(np.sum(p_values < self.alpha))
        fp_rate = n_fp / self.n_simulations
        fp_test_pvalue = float(
            binomtest(n_fp, self.n_simulations, self.alpha).pvalue
        )
        ks_statistic, ks_pvalue = kstest(p_values, "uniform")

        return AAResult(
            n_simulations=self.n_simulations,
            alpha=self.alpha,
            meta_threshold=self.meta_threshold,
            false_positive_rate=fp_rate,
            n_false_positives=n_fp,
            fp_test_pvalue=fp_test_pvalue,
            ks_statistic=float(ks_statistic),
            ks_pvalue=float(ks_pvalue),
            p_values=p_values,
            flagged=bool(fp_test_pvalue < self.meta_threshold),
        )
|
|
@@ -0,0 +1,224 @@
|
|
|
1
|
+
"""Covariate balance report diagnostic.
|
|
2
|
+
|
|
3
|
+
Wraps ``check_balance`` (Phase 2) to produce a covariate balance report:
|
|
4
|
+
the standardized mean difference (SMD) per covariate, plus a flag when
|
|
5
|
+
any covariate exceeds an absolute-SMD threshold. The conventional cutoff
|
|
6
|
+
for "meaningful" imbalance is ``|SMD| > 0.1`` (Austin 2009).
|
|
7
|
+
|
|
8
|
+
The Love plot is intentionally not produced here: rendering lives in the
|
|
9
|
+
Phase 7 reporting layer (``plot_balance``), which centralizes the
|
|
10
|
+
optional matplotlib dependency. ``BalanceResult.to_dataframe`` exposes
|
|
11
|
+
the table that such a plot would consume.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from dataclasses import dataclass
|
|
15
|
+
|
|
16
|
+
import numpy as np
|
|
17
|
+
import pandas as pd
|
|
18
|
+
|
|
19
|
+
from skxperiments.core.assignment import (
|
|
20
|
+
BlockedAssignment,
|
|
21
|
+
CRDAssignment,
|
|
22
|
+
)
|
|
23
|
+
from skxperiments.core.base import DiagnosticsReport
|
|
24
|
+
from skxperiments.core.exceptions import InvalidDesignError
|
|
25
|
+
from skxperiments.design.balance import check_balance
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dataclass(frozen=True, eq=False)
class BalanceResult:
    """Outcome of a covariate balance check.

    Attributes
    ----------
    table : pd.DataFrame
        Balance table produced by ``check_balance``: one row per covariate
        with the columns ``covariate``, ``mean_treated``,
        ``mean_control``, ``std_pooled`` and ``smd``.
    threshold : float
        A covariate whose absolute SMD is strictly greater than this value
        is treated as imbalanced.

    Notes
    -----
    When a covariate does not vary within groups, ``std_pooled == 0`` and
    its SMD is NaN. Such covariates appear in ``constant_covariates`` and
    are never reported as imbalanced.
    """

    table: pd.DataFrame
    threshold: float

    def _rows(self):
        """Yield ``(covariate, smd)`` pairs in table order."""
        for record in self.table.itertuples():
            yield record.covariate, record.smd

    @property
    def smd(self) -> dict:
        """Dict mapping each covariate name to its SMD as a float."""
        return {name: float(value) for name, value in self._rows()}

    @property
    def imbalanced(self) -> list[str]:
        """Names of covariates with a defined SMD whose magnitude exceeds the threshold."""
        offenders = []
        for name, value in self._rows():
            if np.isnan(value):
                continue
            if abs(value) > self.threshold:
                offenders.append(name)
        return offenders

    @property
    def constant_covariates(self) -> list[str]:
        """Names of covariates whose SMD is NaN (zero within-group variance)."""
        return [name for name, value in self._rows() if np.isnan(value)]

    @property
    def flagged(self) -> bool:
        """Whether at least one covariate is imbalanced."""
        return bool(self.imbalanced)

    @property
    def max_abs_smd(self) -> float:
        """Maximum |SMD| over covariates, ignoring NaN; NaN when none is defined."""
        magnitudes = np.abs(self.table["smd"].to_numpy(dtype=float))
        if np.isnan(magnitudes).all():
            return float("nan")
        return float(np.nanmax(magnitudes))

    def to_dataframe(self) -> pd.DataFrame:
        """Return a deep copy of the balance table."""
        return self.table.copy(deep=True)

    def to_dict(self) -> dict:
        """Summary fields collected into a plain dictionary."""
        summary_fields = (
            "threshold",
            "flagged",
            "imbalanced",
            "constant_covariates",
            "max_abs_smd",
            "smd",
        )
        return {name: getattr(self, name) for name in summary_fields}

    def summary(self) -> "BalanceResult":
        """Print a human-readable balance table, then return ``self``."""
        verdict = "FLAGGED — covariate imbalance" if self.flagged else "OK"
        out = [
            "Balance Report",
            "--------------",
            f"threshold |SMD| > {self.threshold}",
            f"max |SMD| {self.max_abs_smd:.4f}",
            f"status {verdict}",
            "covariate SMD",
        ]
        for name, value in self._rows():
            if np.isnan(value):
                text, marker = "nan (constant)", ""
            else:
                text = f"{value:+.4f}"
                marker = " *" if abs(value) > self.threshold else ""
            out.append(f" {name}: {text}{marker}")
        print("\n".join(out))
        return self

    def to_diagnostics_report(self) -> DiagnosticsReport:
        """Build a ``DiagnosticsReport`` suitable for pipeline aggregation.

        Imbalanced covariates are recorded as flags; covariates with an
        undefined (NaN) SMD are recorded as warnings.
        """
        report = DiagnosticsReport()
        offenders = self.imbalanced
        if offenders:
            report.flags.append(
                f"Covariate imbalance (|SMD| > {self.threshold}): {offenders}."
            )
        constants = self.constant_covariates
        if constants:
            report.warnings.append(
                f"Constant covariates with undefined SMD: {constants}."
            )
        return report
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
class BalanceReport:
    """Covariate balance diagnostic for two-arm designs.

    Computes the standardized mean difference (SMD) per covariate via
    ``check_balance`` and flags covariates whose absolute SMD exceeds
    ``threshold``.

    Parameters
    ----------
    covariates : list of str or None, optional
        Covariates to check. If None, all numeric columns except the
        treatment column are used (see ``check_balance``). By default None.
    threshold : float, optional
        Absolute-SMD threshold for flagging imbalance, by default 0.1
        (Austin 2009). Must be positive; NaN is rejected. SMDs are
        unbounded, so no upper limit applies.

    Raises
    ------
    InvalidDesignError
        If ``threshold`` is not a positive real number (NaN included), or
        if ``covariates`` is neither None nor a list of strings.

    Notes
    -----
    Supports ``CRDAssignment`` (including rerandomized) and
    ``BlockedAssignment``. ``FactorialAssignment`` is rejected: a single
    treated-vs-control SMD is not defined for multi-cell designs.
    """

    def __init__(
        self,
        covariates: list[str] | None = None,
        threshold: float = 0.1,
    ) -> None:
        # bool is a subclass of int, so it must be excluded explicitly.
        if not isinstance(threshold, (int, float)) or isinstance(
            threshold, bool
        ):
            raise InvalidDesignError(
                "threshold must be a positive float, got "
                f"{type(threshold).__name__}."
            )
        # Written as ``not threshold > 0`` rather than ``threshold <= 0`` so
        # that NaN (for which every comparison is False) is rejected as
        # well. A NaN threshold would make ``abs(smd) > threshold`` False
        # for every covariate and silently disable the diagnostic.
        if not threshold > 0.0:
            raise InvalidDesignError(
                f"threshold must be > 0, got {threshold}."
            )

        if covariates is not None:
            if not isinstance(covariates, list) or not all(
                isinstance(c, str) for c in covariates
            ):
                raise InvalidDesignError(
                    "covariates must be None or a list of column names."
                )

        self.covariates = covariates
        self.threshold = threshold

    def run(
        self,
        assignment: CRDAssignment | BlockedAssignment,
    ) -> BalanceResult:
        """Compute the balance report for an assignment.

        Parameters
        ----------
        assignment : CRDAssignment or BlockedAssignment
            Two-arm assignment to check.

        Returns
        -------
        BalanceResult

        Raises
        ------
        InvalidDesignError
            If the assignment is not two-arm, or if a covariate is missing
            or contains NaN (propagated from ``check_balance``).
        """
        if not isinstance(assignment, (CRDAssignment, BlockedAssignment)):
            raise InvalidDesignError(
                f"BalanceReport supports two-arm designs (CRDAssignment, "
                f"BlockedAssignment); received "
                f"{type(assignment).__name__}. Treated-vs-control SMD is "
                f"not defined for multi-cell factorial designs."
            )

        table = check_balance(assignment, self.covariates)
        return BalanceResult(table=table, threshold=self.threshold)
|