pyrollmatch 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,50 @@
1
+ """
2
+ pyrollmatch — Fast rolling entry matching for staggered adoption studies.
3
+
4
+ A Python reimplementation of the R ``rollmatch`` package (RTI International)
5
+ using polars and numpy for scalable matching on large panel datasets (100K+ units).
6
+
7
+ Rolling entry matching (REM) explicitly handles staggered treatment adoption
8
+ by matching each treated unit to controls at the treated unit's specific entry
9
+ time, using accumulated (rolling-window) covariates.
10
+
11
+ Quick Start
12
+ -----------
13
+ >>> import polars as pl
14
+ >>> from pyrollmatch import rollmatch, alpha_sweep
15
+ >>>
16
+ >>> # data: panel with columns [unit_id, time, treat, entry_time, x1, x2, ...]
17
+ >>> result = rollmatch(
18
+ ... data, treat="treat", tm="time", entry="entry_time", id="unit_id",
19
+ ... covariates=["x1", "x2", "x3"],
20
+ ... alpha=0.1, num_matches=3,
21
+ ... )
22
+ >>> result.balance # SMD table
23
+ >>> result.weights # unit_id -> matching weight
24
+
25
+ References
26
+ ----------
27
+ - Witman et al. (2018). "Comparison Group Selection in the Presence of Rolling Entry."
28
+ Health Services Research, 54(1), 262-270. doi:10.1111/1475-6773.13086
29
+ - RTI International rollmatch R package: https://github.com/RTIInternational/rollmatch
30
+ """
31
+
32
+ from .core import rollmatch, alpha_sweep, RollmatchResult
33
+ from .reduce import reduce_data
34
+ from .score import score_data, ScoredResult
35
+ from .balance import compute_balance, smd_table
36
+ from .diagnostics import balance_test, equivalence_test
37
+
38
+ __version__ = "0.0.3"
39
+ __all__ = [
40
+ "rollmatch",
41
+ "alpha_sweep",
42
+ "RollmatchResult",
43
+ "reduce_data",
44
+ "score_data",
45
+ "ScoredResult",
46
+ "compute_balance",
47
+ "smd_table",
48
+ "balance_test",
49
+ "equivalence_test",
50
+ ]
pyrollmatch/balance.py ADDED
@@ -0,0 +1,121 @@
1
+ """
2
+ balance — Covariate balance computation and SMD table.
3
+ """
4
+
5
+ import polars as pl
6
+ import numpy as np
7
+
8
+
9
def _group_stats(vals_t: np.ndarray, vals_c: np.ndarray) -> tuple[float, float, float, float, float]:
    """Return (mean_t, mean_c, sd_t, sd_c, smd) for one covariate.

    SMD uses the pooled SD ``sqrt((sd_t**2 + sd_c**2) / 2)``.  Any statistic
    that cannot be computed (empty sample, singleton sample for an SD, zero
    or undefined pooled SD) is returned as NaN.
    """
    mean_t = np.mean(vals_t) if len(vals_t) > 0 else np.nan
    mean_c = np.mean(vals_c) if len(vals_c) > 0 else np.nan
    sd_t = np.std(vals_t, ddof=1) if len(vals_t) > 1 else np.nan
    sd_c = np.std(vals_c, ddof=1) if len(vals_c) > 1 else np.nan
    if np.isnan(sd_t) or np.isnan(sd_c):
        pooled = np.nan
    else:
        pooled = np.sqrt((sd_t**2 + sd_c**2) / 2)
    # NaN pooled falls through to NaN SMD (nan > 0 is False); 0 pooled too.
    smd = (mean_t - mean_c) / pooled if pooled and pooled > 0 else np.nan
    return mean_t, mean_c, sd_t, sd_c, smd


def compute_balance(
    scored_data: pl.DataFrame,
    matches: pl.DataFrame,
    treat: str,
    id: str,
    tm: str,
    covariates: list[str],
) -> pl.DataFrame:
    """Compute covariate balance before and after matching.

    Returns a table with means, SDs, and SMDs for each covariate,
    both in the full sample and the matched sample.

    Parameters
    ----------
    scored_data : pl.DataFrame
        Reduced data with treatment indicator and covariates.
    matches : pl.DataFrame
        Match results with treat_id and control_id columns.
    treat : str
        Treatment indicator column.
    id : str
        Unit identifier column.
    tm : str
        Time period column.
    covariates : list[str]
        Covariate column names.

    Returns
    -------
    pl.DataFrame with columns:
        covariate, full_mean_t, full_mean_c, full_sd_t, full_sd_c,
        full_smd, matched_mean_t, matched_mean_c, matched_sd_t,
        matched_sd_c, matched_smd
    """
    # Pre-compute matched data ONCE (not per covariate)
    treat_matches = matches.select(tm, "treat_id").unique().rename({"treat_id": id})
    control_matches = matches.select(tm, "control_id").unique().rename({"control_id": id})
    matched_ids_df = pl.concat([treat_matches, control_matches])
    matched_data = scored_data.join(matched_ids_df, on=[tm, id], how="semi")

    # Pre-split by treatment group
    full_t = scored_data.filter(pl.col(treat) == 1)
    full_c = scored_data.filter(pl.col(treat) == 0)
    match_t = matched_data.filter(pl.col(treat) == 1)
    match_c = matched_data.filter(pl.col(treat) == 0)

    rows = []

    for cov in covariates:
        # Identical statistics for the full and matched samples; the shared
        # helper keeps the two computations from drifting apart.
        f_mean_t, f_mean_c, f_sd_t, f_sd_c, f_smd = _group_stats(
            full_t[cov].drop_nulls().to_numpy(),
            full_c[cov].drop_nulls().to_numpy(),
        )
        m_mean_t, m_mean_c, m_sd_t, m_sd_c, m_smd = _group_stats(
            match_t[cov].drop_nulls().to_numpy(),
            match_c[cov].drop_nulls().to_numpy(),
        )

        rows.append({
            "covariate": cov,
            "full_mean_t": round(f_mean_t, 4),
            "full_mean_c": round(f_mean_c, 4),
            "full_sd_t": round(f_sd_t, 4),
            "full_sd_c": round(f_sd_c, 4),
            "full_smd": round(f_smd, 4),
            "matched_mean_t": round(m_mean_t, 4),
            "matched_mean_c": round(m_mean_c, 4),
            "matched_sd_t": round(m_sd_t, 4),
            "matched_sd_c": round(m_sd_c, 4),
            "matched_smd": round(m_smd, 4),
        })

    return pl.DataFrame(rows)
94
+
95
+
96
def smd_table(balance: pl.DataFrame, threshold: float = 0.1) -> None:
    """Print a formatted SMD table with pass/fail indicators.

    Parameters
    ----------
    balance : pl.DataFrame
        Output from compute_balance().
    threshold : float
        |SMD| threshold for pass/fail (default 0.1).
    """
    max_smd = balance["matched_smd"].abs().max()
    # polars max() returns None on an empty / all-null column; map to NaN so
    # the comparison and the :.4f format below don't raise TypeError.
    if max_smd is None:
        max_smd = float("nan")
    all_pass = max_smd < threshold  # NaN compares False -> reported as FAIL

    print(f"\n{'='*70}")
    print(f" Standardized Mean Differences (threshold: |SMD| < {threshold})")
    print(f" Max |SMD| = {max_smd:.4f} {'✓ ALL PASS' if all_pass else '✗ SOME FAIL'}")
    print(f"{'='*70}\n")
    print(f" {'Covariate':<30} {'Full SMD':>10} {'Matched SMD':>12} {'Pass':>6}")
    print(f" {'-'*30} {'-'*10} {'-'*12} {'-'*6}")

    for row in balance.iter_rows(named=True):
        smd = row["matched_smd"]
        full_smd = row["full_smd"]
        passed = abs(smd) < threshold if smd is not None else False
        # A null SMD must not reach the float format spec (TypeError); render
        # a placeholder of the same width instead.
        smd_txt = f"{smd:>12.4f}" if smd is not None else f"{'--':>12}"
        full_txt = f"{full_smd:>10.4f}" if full_smd is not None else f"{'--':>10}"
        print(f" {row['covariate']:<30} {full_txt} {smd_txt} {'✓' if passed else '✗':>6}")

    print()
pyrollmatch/core.py ADDED
@@ -0,0 +1,301 @@
1
+ """
2
+ core — Main rollmatch orchestration and alpha sweep.
3
+ """
4
+
5
+ import polars as pl
6
+ import numpy as np
7
+ from dataclasses import dataclass, field
8
+
9
+ from .reduce import reduce_data
10
+ from .score import score_data
11
+ from .match import match_all_periods
12
+ from .balance import compute_balance, smd_table
13
+
14
+
15
@dataclass
class RollmatchResult:
    """Result from rollmatch."""
    # Matched pairs: the DataFrame returned by match_all_periods
    # (one row per treat_id/control_id pairing).
    matched_data: pl.DataFrame
    # Covariate balance table from compute_balance() (full vs. matched SMDs).
    balance: pl.DataFrame
    # Distinct treated units present in the scored data.
    n_treated_total: int
    # Distinct treated units that received at least one match.
    n_treated_matched: int
    # Distinct control units used as matches.
    n_controls_matched: int
    # Caliper multiplier used for this run (0 = no caliper).
    alpha: float
    weights: pl.DataFrame  # id -> weight
25
+
26
+
27
def _compute_weights(matches: pl.DataFrame, id: str, num_matches: int) -> pl.DataFrame:
    """Derive per-unit matching weights from the matched pairs.

    Following the R rollmatch convention:
    - each matched treated unit gets weight 1.0;
    - each control gets the sum, over every treated unit it serves, of
      1 / (that treated unit's actual number of matches).

    This yields proper inverse-probability-style weighting when treated
    units end up with different numbers of matches (e.g., tight calipers).
    ``num_matches`` is unused here (weights are based on *actual* match
    counts) and is retained only for interface compatibility.
    """
    # Share of each treated unit's weight flowing to every one of its pairs:
    # a window count over treat_id gives that unit's actual match count.
    pair_shares = matches.with_columns(
        (1.0 / pl.len().over("treat_id")).alias("treatment_weight")
    )

    # Treated side: every matched treated unit carries a flat weight of 1.0.
    treated_side = (
        matches.select("treat_id").unique()
        .rename({"treat_id": id})
        .with_columns(pl.lit(1.0).alias("weight"))
    )

    # Control side: accumulate the shares received across all pairings.
    control_side = (
        pair_shares
        .group_by("control_id")
        .agg(pl.col("treatment_weight").sum().alias("weight"))
        .rename({"control_id": id})
    )

    # A unit appearing on both sides (possible across entry periods) gets
    # the sum of its role-specific weights.
    stacked = pl.concat([treated_side, control_side])
    return stacked.group_by(id).agg(pl.col("weight").sum())
61
+
62
+
63
def rollmatch(
    data: pl.DataFrame,
    treat: str,
    tm: str,
    entry: str,
    id: str,
    covariates: list[str],
    lookback: int = 1,
    alpha: float = 0,
    num_matches: int = 3,
    replacement: bool = True,
    standard_deviation: str = "average",
    model_type: str = "logistic",
    match_on: str = "logit",
    block_size: int = 2000,
    verbose: bool = True,
) -> RollmatchResult | None:
    """Run the full rolling entry matching pipeline.

    Pipeline: reduce_data -> drop-null covariates -> score_data ->
    match_all_periods -> compute_balance -> weights.  Returns None (rather
    than raising) if any stage leaves nothing to work with.

    Parameters
    ----------
    data : pl.DataFrame
        Panel data with unit × time observations.
    treat : str
        Binary treatment column (1=treated, 0=control).
    tm : str
        Time period column (integer).
    entry : str
        Entry period column. Treatment onset for treated units; null or
        any value > max(tm) for controls.
    id : str
        Unit identifier column.
    covariates : list[str]
        Covariate column names for matching.
    lookback : int
        Periods to look back from entry for baseline covariates.
    alpha : float
        Caliper multiplier (0 = no caliper).
    num_matches : int
        Number of control matches per treated unit.
    replacement : bool
        Allow control reuse within time period.
    standard_deviation : str
        Method for pooled SD in caliper.
    model_type : str
        Propensity model type ("logistic").
    match_on : str
        Score type ("logit" or "pscore").
    block_size : int
        Block size for memory-efficient matching.
    verbose : bool
        Print progress.

    Returns
    -------
    RollmatchResult or None if matching fails.
    """
    if verbose:
        n_treat = data.filter(pl.col(treat) == 1)[id].n_unique()
        n_ctrl = data.filter(pl.col(treat) == 0)[id].n_unique()
        print(f"rollmatch: {n_treat} treated, {n_ctrl} controls, alpha={alpha}")

    # Step 1: Reduce data
    if verbose:
        print("  Step 1: reduce_data...")
    reduced = reduce_data(data, treat, tm, entry, id, lookback)
    if verbose:
        print(f"    Reduced: {reduced.height} rows")

    # Drop rows with NaN in covariates; the propensity model cannot use them.
    reduced = reduced.drop_nulls(subset=covariates)
    if verbose:
        print(f"    After dropping NaN: {reduced.height} rows")

    # Bail out early: nothing left to score or match.
    if reduced.height == 0:
        if verbose:
            print("    ERROR: No valid rows after NaN removal")
        return None

    # Step 2: Score data (fit propensity model, attach score column)
    if verbose:
        print("  Step 2: score_data...")
    scored = score_data(reduced, covariates, treat, model_type, match_on)
    if verbose:
        print(f"    Scored: {scored.height} rows")

    # Step 3: Match treated to controls within each entry period
    if verbose:
        print(f"  Step 3: matching (alpha={alpha}, num_matches={num_matches})...")
    matches = match_all_periods(
        scored, treat, tm, entry, id,
        alpha=alpha, num_matches=num_matches,
        replacement=replacement, standard_deviation=standard_deviation,
        block_size=block_size,
    )

    # A too-tight caliper (or degenerate data) can leave zero matches.
    if matches is None or matches.height == 0:
        if verbose:
            print("    No matches found!")
        return None

    n_treated_matched = matches["treat_id"].n_unique()
    n_controls_matched = matches["control_id"].n_unique()
    n_treated_total = scored.filter(pl.col(treat) == 1)[id].n_unique()

    if verbose:
        print(f"    Matched: {matches.height} pairs")
        print(f"    Treated matched: {n_treated_matched}/{n_treated_total} "
              f"({100*n_treated_matched/n_treated_total:.1f}%)")
        print(f"    Controls used: {n_controls_matched}")

    # Step 4: Balance diagnostics on the matched sample
    if verbose:
        print("  Step 4: balance...")
    balance = compute_balance(scored, matches, treat, id, tm, covariates)

    # Step 5: Compute weights (treated = 1, controls = summed 1/k shares)
    weights = _compute_weights(matches, id, num_matches)

    if verbose:
        smd_table(balance)

    return RollmatchResult(
        matched_data=matches,
        balance=balance,
        n_treated_total=n_treated_total,
        n_treated_matched=n_treated_matched,
        n_controls_matched=n_controls_matched,
        alpha=alpha,
        weights=weights,
    )
194
+
195
+
196
def alpha_sweep(
    data: pl.DataFrame,
    treat: str,
    tm: str,
    entry: str,
    id: str,
    covariates: list[str],
    alphas: list[float] | None = None,
    lookback: int = 1,
    num_matches: int = 3,
    replacement: bool = True,
    standard_deviation: str = "average",
    model_type: str = "logistic",
    match_on: str = "logit",
    block_size: int = 2000,
    smd_threshold: float = 0.1,
) -> tuple[pl.DataFrame, RollmatchResult | None]:
    """Run rollmatch across multiple alpha values and select the best.

    Best = fully balanced (all |SMD| < threshold) with highest match rate.
    If none fully balance, select the one with lowest max|SMD|.

    Parameters
    ----------
    data : pl.DataFrame
        Panel data.
    alphas : list[float]
        Caliper multipliers to try. Default: [0.01, 0.02, 0.05, 0.1, 0.15, 0.2]
    smd_threshold : float
        |SMD| threshold for "balanced" (default 0.1).
    (other params same as rollmatch)

    Returns
    -------
    (summary_df, best_result)
    """
    if alphas is None:
        alphas = [0.01, 0.02, 0.05, 0.1, 0.15, 0.2]

    # Pre-compute reduce + score once: neither step depends on alpha, so
    # sharing them avoids refitting the propensity model per alpha.
    reduced = reduce_data(data, treat, tm, entry, id, lookback)
    reduced = reduced.drop_nulls(subset=covariates)
    scored = score_data(reduced, covariates, treat, model_type, match_on)

    results = []
    best_result = None
    best_score = (-1, -np.inf)  # (all_pass, match_rate), lexicographic

    for alpha in alphas:
        print(f"  alpha={alpha:.2f} ... ", end="", flush=True)

        matches = match_all_periods(
            scored, treat, tm, entry, id,
            alpha=alpha, num_matches=num_matches,
            replacement=replacement, standard_deviation=standard_deviation,
            block_size=block_size,
        )

        if matches is None or matches.height == 0:
            print("no matches")
            continue

        balance = compute_balance(scored, matches, treat, id, tm, covariates)
        max_smd = balance["matched_smd"].abs().max()
        # polars max() returns None when every matched_smd is null; treat
        # that as "unbalanced" (NaN) instead of crashing on the comparison,
        # round(), and :.4f format below.
        if max_smd is None:
            max_smd = float("nan")
        all_pass = max_smd < smd_threshold  # NaN compares False

        n_treat_total = scored.filter(pl.col(treat) == 1)[id].n_unique()
        n_treat_matched = matches["treat_id"].n_unique()
        match_rate = n_treat_matched / n_treat_total

        results.append({
            "alpha": alpha,
            "n_matched_pairs": matches.height,
            "n_treated_matched": n_treat_matched,
            "pct_treated": round(100 * match_rate, 1),
            "max_abs_smd": round(max_smd, 4),
            "all_pass": all_pass,
        })

        print(f"matched={n_treat_matched}/{n_treat_total} ({100*match_rate:.0f}%), "
              f"max|SMD|={max_smd:.4f} {'✓' if all_pass else '✗'}")

        # Track best: full balance dominates; match rate breaks ties.
        score = (int(all_pass), match_rate)
        if score > best_score:
            best_score = score
            weights = _compute_weights(matches, id, num_matches)

            best_result = RollmatchResult(
                matched_data=matches,
                balance=balance,
                n_treated_total=n_treat_total,
                n_treated_matched=n_treat_matched,
                n_controls_matched=matches["control_id"].n_unique(),
                alpha=alpha,
                weights=weights,
            )

    summary = pl.DataFrame(results) if results else pl.DataFrame()

    if best_result:
        # Same None-guard as above for the winner's max|SMD| display.
        best_max = best_result.balance["matched_smd"].abs().max()
        if best_max is None:
            best_max = float("nan")
        print(f"\n  Best: alpha={best_result.alpha} "
              f"(matched={best_result.n_treated_matched}/{best_result.n_treated_total}, "
              f"max|SMD|={best_max:.4f})")

    return summary, best_result
@@ -0,0 +1,207 @@
1
+ """
2
+ diagnostics — Post-matching diagnostic tests.
3
+
4
+ Includes t-tests, SMD tests, variance ratio tests, and equivalence tests
5
+ for assessing matching quality.
6
+ """
7
+
8
+ import numpy as np
9
+ import polars as pl
10
+ from scipy import stats
11
+
12
+
13
def balance_test(
    scored_data: pl.DataFrame,
    matches: pl.DataFrame,
    treat: str,
    id: str,
    tm: str,
    covariates: list[str],
    threshold: float = 0.1,
) -> pl.DataFrame:
    """Run comprehensive balance diagnostics on matched sample.

    For each covariate, computes:
    - Standardized mean difference (SMD)
    - Two-sample t-test (H0: means are equal)
    - Variance ratio (treat/control)
    - Kolmogorov-Smirnov test (H0: distributions are equal)

    Covariates with fewer than 2 observations in either arm are skipped.

    Parameters
    ----------
    scored_data : pl.DataFrame
        Reduced data with treatment indicator and covariates.
    matches : pl.DataFrame
        Match results with treat_id, control_id, tm columns.
    treat : str
        Treatment indicator column.
    id : str
        Unit identifier column.
    tm : str
        Time period column.
    covariates : list[str]
        Covariate column names.
    threshold : float
        SMD threshold for pass/fail (default 0.1).

    Returns
    -------
    pl.DataFrame with diagnostics per covariate (empty if no covariate
    had enough data).
    """
    # Get matched units (semi-join keeps only rows present in the match set)
    treat_matches = matches.select(tm, "treat_id").unique().rename({"treat_id": id})
    control_matches = matches.select(tm, "control_id").unique().rename({"control_id": id})
    matched_ids = pl.concat([treat_matches, control_matches])
    matched_data = scored_data.join(matched_ids, on=[tm, id], how="semi")

    rows = []
    for cov in covariates:
        vals_t = matched_data.filter(pl.col(treat) == 1)[cov].drop_nulls().to_numpy().astype(float)
        vals_c = matched_data.filter(pl.col(treat) == 0)[cov].drop_nulls().to_numpy().astype(float)

        # Each test below needs at least 2 observations per arm.
        if len(vals_t) < 2 or len(vals_c) < 2:
            continue

        # SMD with pooled SD sqrt((sd_t^2 + sd_c^2) / 2)
        sd_t, sd_c = np.std(vals_t, ddof=1), np.std(vals_c, ddof=1)
        pooled_sd = np.sqrt((sd_t**2 + sd_c**2) / 2)
        smd = (np.mean(vals_t) - np.mean(vals_c)) / pooled_sd if pooled_sd > 0 else np.nan

        # Two-sample t-test (Welch's)
        t_stat, t_pvalue = stats.ttest_ind(vals_t, vals_c, equal_var=False)

        # Variance ratio (treat/control); NaN when control variance is 0
        var_ratio = np.var(vals_t, ddof=1) / np.var(vals_c, ddof=1) if np.var(vals_c, ddof=1) > 0 else np.nan

        # KS test
        ks_stat, ks_pvalue = stats.ks_2samp(vals_t, vals_c)

        rows.append({
            "covariate": cov,
            "mean_treated": round(np.mean(vals_t), 4),
            "mean_control": round(np.mean(vals_c), 4),
            "smd": round(smd, 4),
            "smd_pass": bool(abs(smd) < threshold),
            "t_stat": round(t_stat, 4),
            "t_pvalue": round(t_pvalue, 4),
            "var_ratio": round(var_ratio, 4),
            "var_ratio_pass": bool(0.5 < var_ratio < 2.0) if not np.isnan(var_ratio) else False,
            "ks_stat": round(ks_stat, 4),
            "ks_pvalue": round(ks_pvalue, 4),
        })

    result = pl.DataFrame(rows)

    # Guard: with zero qualifying covariates, pl.DataFrame([]) has NO columns,
    # so filtering on "smd_pass" below would raise ColumnNotFoundError.
    if result.height == 0:
        print("\n balance_test: no covariates with enough matched data to test")
        return result

    # Print summary
    n_pass_smd = result.filter(pl.col("smd_pass")).height
    n_pass_var = result.filter(pl.col("var_ratio_pass")).height
    n_total = result.height

    print(f"\n{'='*70}")
    print(" Post-Matching Balance Diagnostics")
    print(f"{'='*70}")
    print(f" SMD < {threshold}: {n_pass_smd}/{n_total} pass")
    print(f" Variance ratio in (0.5, 2.0): {n_pass_var}/{n_total} pass")
    print(f"{'='*70}\n")

    print(f" {'Covariate':<25} {'SMD':>8} {'t-test p':>10} {'VR':>8} {'KS p':>8}")
    print(f" {'-'*25} {'-'*8} {'-'*10} {'-'*8} {'-'*8}")
    for row in result.iter_rows(named=True):
        smd_flag = "✓" if row["smd_pass"] else "✗"
        vr_flag = "✓" if row["var_ratio_pass"] else "✗"
        print(f" {row['covariate']:<25} {row['smd']:>7.4f}{smd_flag} {row['t_pvalue']:>10.4f} {row['var_ratio']:>7.3f}{vr_flag} {row['ks_pvalue']:>8.4f}")

    return result
115
+
116
+
117
def equivalence_test(
    scored_data: pl.DataFrame,
    matches: pl.DataFrame,
    treat: str,
    id: str,
    tm: str,
    covariates: list[str],
    multiplier: float = 0.36,
) -> pl.DataFrame:
    """TOST equivalence test for covariate balance.

    Tests H0: |SMD| >= delta (non-equivalence).
    Rejection = GOOD (positive evidence of negligible difference).
    Uses Hartman & Hidalgo (2018) approach: delta = multiplier × pooled_SD.

    Covariates with fewer than 2 observations in either arm are skipped.

    Parameters
    ----------
    scored_data : pl.DataFrame
        Reduced data.
    matches : pl.DataFrame
        Match results.
    treat, id, tm : str
        Column names.
    covariates : list[str]
        Covariate names.
    multiplier : float
        Equivalence bound as fraction of pooled SD (default 0.36).

    Returns
    -------
    pl.DataFrame with TOST results per covariate (empty if no covariate
    had enough data).
    """
    treat_matches = matches.select(tm, "treat_id").unique().rename({"treat_id": id})
    control_matches = matches.select(tm, "control_id").unique().rename({"control_id": id})
    matched_ids = pl.concat([treat_matches, control_matches])
    matched_data = scored_data.join(matched_ids, on=[tm, id], how="semi")

    rows = []
    for cov in covariates:
        vals_t = matched_data.filter(pl.col(treat) == 1)[cov].drop_nulls().to_numpy().astype(float)
        vals_c = matched_data.filter(pl.col(treat) == 0)[cov].drop_nulls().to_numpy().astype(float)

        # Variances below need ddof=1, i.e. at least 2 observations per arm.
        if len(vals_t) < 2 or len(vals_c) < 2:
            continue

        m, n = len(vals_t), len(vals_c)
        diff = np.mean(vals_t) - np.mean(vals_c)
        var_t = np.var(vals_t, ddof=1)
        var_c = np.var(vals_c, ddof=1)

        # Pooled SD: weighted formula matching Hartman & Hidalgo (2018)
        # equivtest R package: sqrt(((m-1)*var_x + (n-1)*var_y) / (m+n-2))
        pooled_sd = np.sqrt(((m - 1) * var_t + (n - 1) * var_c) / (m + n - 2))
        delta = multiplier * pooled_sd

        # Two one-sided t-tests following equivtest::tost()
        # Uses Welch's t-test (unequal variances); df via Welch–Satterthwaite
        # (note se**4 == (var_t/m + var_c/n)**2).
        se = np.sqrt(var_t / m + var_c / n)
        df_welch = se**4 / ((var_t/m)**2/(m-1) + (var_c/n)**2/(n-1)) if se > 0 else 1

        # Upper test: H0: diff >= delta, alt: diff < delta
        t_upper = (diff - delta) / se if se > 0 else np.inf
        p_upper = stats.t.cdf(t_upper, df=df_welch)

        # Lower test: H0: diff <= -delta, alt: diff > -delta
        t_lower = (diff + delta) / se if se > 0 else -np.inf
        p_lower = 1 - stats.t.cdf(t_lower, df=df_welch)

        # TOST rejects only if BOTH one-sided tests reject -> take the max p.
        tost_p = max(p_upper, p_lower)

        rows.append({
            "covariate": cov,
            "diff": round(diff, 6),
            "se": round(se, 6),
            "delta": round(delta, 4),
            "tost_p_upper": round(p_upper, 4),
            "tost_p_lower": round(p_lower, 4),
            "tost_p": round(tost_p, 4),
            "equivalent": bool(tost_p < 0.05),
        })

    result = pl.DataFrame(rows)

    # Guard: with zero qualifying covariates, pl.DataFrame([]) has NO columns,
    # so filtering on "equivalent" below would raise ColumnNotFoundError.
    if result.height == 0:
        print(f"\n TOST Equivalence Test (bound = {multiplier}σ)")
        print(" No covariates with enough matched data to test.")
        return result

    n_equiv = result.filter(pl.col("equivalent")).height
    print(f"\n TOST Equivalence Test (bound = {multiplier}σ)")
    print(f" Equivalent: {n_equiv}/{result.height} covariates (p < 0.05 = GOOD)")
    for row in result.iter_rows(named=True):
        flag = "✓ EQUIV" if row["equivalent"] else " not equiv"
        print(f" {row['covariate']:<25} p={row['tost_p']:.4f} {flag}")

    return result