pyrollmatch 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyrollmatch/__init__.py +50 -0
- pyrollmatch/balance.py +121 -0
- pyrollmatch/core.py +301 -0
- pyrollmatch/diagnostics.py +207 -0
- pyrollmatch/match.py +263 -0
- pyrollmatch/reduce.py +72 -0
- pyrollmatch/score.py +229 -0
- pyrollmatch-0.0.3.dist-info/METADATA +278 -0
- pyrollmatch-0.0.3.dist-info/RECORD +11 -0
- pyrollmatch-0.0.3.dist-info/WHEEL +4 -0
- pyrollmatch-0.0.3.dist-info/licenses/LICENSE +21 -0
pyrollmatch/__init__.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
"""
pyrollmatch — Fast rolling entry matching for staggered adoption studies.

A Python reimplementation of the R ``rollmatch`` package (RTI International)
using polars and numpy for scalable matching on large panel datasets (100K+ units).

Rolling entry matching (REM) explicitly handles staggered treatment adoption
by matching each treated unit to controls at the treated unit's specific entry
time, using accumulated (rolling-window) covariates.

Quick Start
-----------
>>> import polars as pl
>>> from pyrollmatch import rollmatch, alpha_sweep
>>>
>>> # data: panel with columns [unit_id, time, treat, entry_time, x1, x2, ...]
>>> result = rollmatch(
...     data, treat="treat", tm="time", entry="entry_time", id="unit_id",
...     covariates=["x1", "x2", "x3"],
...     alpha=0.1, num_matches=3,
... )
>>> result.balance  # SMD table
>>> result.weights  # unit_id -> matching weight

References
----------
- Witman et al. (2018). "Comparison Group Selection in the Presence of Rolling Entry."
  Health Services Research, 54(1), 262-270. doi:10.1111/1475-6773.13086
- RTI International rollmatch R package: https://github.com/RTIInternational/rollmatch
"""

# Re-export the public API from the implementation modules.
from .core import rollmatch, alpha_sweep, RollmatchResult
from .reduce import reduce_data
from .score import score_data, ScoredResult
from .balance import compute_balance, smd_table
from .diagnostics import balance_test, equivalence_test

# Package version string; keep in sync with the distribution metadata.
__version__ = "0.0.3"
# Names exported by `from pyrollmatch import *`.
__all__ = [
    "rollmatch",
    "alpha_sweep",
    "RollmatchResult",
    "reduce_data",
    "score_data",
    "ScoredResult",
    "compute_balance",
    "smd_table",
    "balance_test",
    "equivalence_test",
]
|
pyrollmatch/balance.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
"""
|
|
2
|
+
balance — Covariate balance computation and SMD table.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import polars as pl
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def _group_stats(vals_t: "np.ndarray", vals_c: "np.ndarray") -> tuple[float, float, float, float, float]:
    """Return (mean_t, mean_c, sd_t, sd_c, smd) for one treated/control sample pair.

    NaN is returned for any statistic that cannot be computed: means need
    at least one observation, sample SDs (ddof=1) at least two, and the
    SMD a positive pooled SD sqrt((sd_t**2 + sd_c**2) / 2).
    """
    mean_t = np.mean(vals_t) if len(vals_t) > 0 else np.nan
    mean_c = np.mean(vals_c) if len(vals_c) > 0 else np.nan
    sd_t = np.std(vals_t, ddof=1) if len(vals_t) > 1 else np.nan
    sd_c = np.std(vals_c, ddof=1) if len(vals_c) > 1 else np.nan
    pooled = np.sqrt((sd_t**2 + sd_c**2) / 2) if not (np.isnan(sd_t) or np.isnan(sd_c)) else np.nan
    # `pooled and pooled > 0`: a NaN pooled SD is truthy but NaN > 0 is
    # False, and a zero pooled SD is falsy — both yield NaN, as intended.
    smd = (mean_t - mean_c) / pooled if pooled and pooled > 0 else np.nan
    return mean_t, mean_c, sd_t, sd_c, smd


def compute_balance(
    scored_data: pl.DataFrame,
    matches: pl.DataFrame,
    treat: str,
    id: str,
    tm: str,
    covariates: list[str],
) -> pl.DataFrame:
    """Compute covariate balance before and after matching.

    Returns a table with means, SDs, and SMDs for each covariate,
    both in the full sample and the matched sample.

    Parameters
    ----------
    scored_data : pl.DataFrame
        Reduced data with treatment indicator and covariates.
    matches : pl.DataFrame
        Match results with treat_id and control_id columns.
    treat : str
        Treatment indicator column.
    id : str
        Unit identifier column.
    tm : str
        Time period column.
    covariates : list[str]
        Covariate column names.

    Returns
    -------
    pl.DataFrame with columns:
        covariate, full_mean_t, full_mean_c, full_sd_t, full_sd_c,
        full_smd, matched_mean_t, matched_mean_c, matched_sd_t,
        matched_sd_c, matched_smd
    """
    # Pre-compute matched data ONCE (not per covariate)
    treat_matches = matches.select(tm, "treat_id").unique().rename({"treat_id": id})
    control_matches = matches.select(tm, "control_id").unique().rename({"control_id": id})
    matched_ids_df = pl.concat([treat_matches, control_matches])
    matched_data = scored_data.join(matched_ids_df, on=[tm, id], how="semi")

    # Pre-split by treatment group
    full_t = scored_data.filter(pl.col(treat) == 1)
    full_c = scored_data.filter(pl.col(treat) == 0)
    match_t = matched_data.filter(pl.col(treat) == 1)
    match_c = matched_data.filter(pl.col(treat) == 0)

    rows = []

    for cov in covariates:
        # Identical statistics for the full and matched samples; the shared
        # computation lives in _group_stats to avoid duplicated formulas.
        f_mean_t, f_mean_c, f_sd_t, f_sd_c, f_smd = _group_stats(
            full_t[cov].drop_nulls().to_numpy(),
            full_c[cov].drop_nulls().to_numpy(),
        )
        m_mean_t, m_mean_c, m_sd_t, m_sd_c, m_smd = _group_stats(
            match_t[cov].drop_nulls().to_numpy(),
            match_c[cov].drop_nulls().to_numpy(),
        )

        rows.append({
            "covariate": cov,
            "full_mean_t": round(f_mean_t, 4),
            "full_mean_c": round(f_mean_c, 4),
            "full_sd_t": round(f_sd_t, 4),
            "full_sd_c": round(f_sd_c, 4),
            "full_smd": round(f_smd, 4),
            "matched_mean_t": round(m_mean_t, 4),
            "matched_mean_c": round(m_mean_c, 4),
            "matched_sd_t": round(m_sd_t, 4),
            "matched_sd_c": round(m_sd_c, 4),
            "matched_smd": round(m_smd, 4),
        })

    return pl.DataFrame(rows)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def smd_table(balance: pl.DataFrame, threshold: float = 0.1) -> None:
    """Print a formatted SMD table with pass/fail indicators.

    Parameters
    ----------
    balance : pl.DataFrame
        Output from compute_balance().
    threshold : float
        |SMD| threshold for pass/fail (default 0.1).
    """
    max_smd = balance["matched_smd"].abs().max()
    # Series.max() returns None when every matched_smd is null; treat that
    # as a failure instead of raising on the comparison/format below.
    all_pass = max_smd is not None and max_smd < threshold
    max_smd_str = f"{max_smd:.4f}" if max_smd is not None else "n/a"

    print(f"\n{'='*70}")
    print(f" Standardized Mean Differences (threshold: |SMD| < {threshold})")
    print(f" Max |SMD| = {max_smd_str} {'✓ ALL PASS' if all_pass else '✗ SOME FAIL'}")
    print(f"{'='*70}\n")
    print(f" {'Covariate':<30} {'Full SMD':>10} {'Matched SMD':>12} {'Pass':>6}")
    print(f" {'-'*30} {'-'*10} {'-'*12} {'-'*6}")

    for row in balance.iter_rows(named=True):
        smd = row["matched_smd"]
        passed = abs(smd) < threshold if smd is not None else False
        # Null SMDs (e.g. a covariate with < 2 obs in a group) render as
        # "n/a" rather than crashing the float format specifier.
        full_str = f"{row['full_smd']:>10.4f}" if row["full_smd"] is not None else f"{'n/a':>10}"
        smd_str = f"{smd:>12.4f}" if smd is not None else f"{'n/a':>12}"
        print(f" {row['covariate']:<30} {full_str} {smd_str} {'✓' if passed else '✗':>6}")

    print()
|
pyrollmatch/core.py
ADDED
|
@@ -0,0 +1,301 @@
|
|
|
1
|
+
"""
|
|
2
|
+
core — Main rollmatch orchestration and alpha sweep.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import polars as pl
|
|
6
|
+
import numpy as np
|
|
7
|
+
from dataclasses import dataclass, field
|
|
8
|
+
|
|
9
|
+
from .reduce import reduce_data
|
|
10
|
+
from .score import score_data
|
|
11
|
+
from .match import match_all_periods
|
|
12
|
+
from .balance import compute_balance, smd_table
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass
class RollmatchResult:
    """Result from rollmatch."""
    # One row per (treat_id, control_id) matched pair, as produced by
    # match_all_periods().
    matched_data: pl.DataFrame
    # Covariate balance table from compute_balance().
    balance: pl.DataFrame
    # Distinct treated units in the scored sample.
    n_treated_total: int
    # Distinct treated units that received at least one match.
    n_treated_matched: int
    # Distinct control units used in any match.
    n_controls_matched: int
    # Caliper multiplier this result was produced with.
    alpha: float
    weights: pl.DataFrame  # id -> weight
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _compute_weights(matches: pl.DataFrame, id: str, num_matches: int) -> pl.DataFrame:
    """Derive per-unit matching weights from the matched pairs.

    Following R rollmatch convention:
    - each matched treated unit gets a weight of exactly 1.0;
    - each pair carries a treatment_weight of 1 / (#matches of its
      treated unit), and a control's weight is the sum of the
      treatment_weights of every pair it appears in.

    This keeps the weighting proper (inverse-probability style) when
    treated units end up with different numbers of matches, e.g. under
    tight calipers.
    """
    # Attach each treated unit's match count, then the per-pair weight.
    per_pair = (
        matches
        .join(
            matches.group_by("treat_id").len().rename({"len": "total_matches"}),
            on="treat_id",
        )
        .with_columns((1.0 / pl.col("total_matches")).alias("treatment_weight"))
    )

    # Every matched treated unit contributes a flat weight of 1.0.
    treated = (
        matches.select("treat_id")
        .unique()
        .rename({"treat_id": id})
        .with_columns(pl.lit(1.0).alias("weight"))
    )

    # Controls accumulate the treatment weights of the pairs they serve.
    controls = (
        per_pair
        .group_by("control_id")
        .agg(pl.col("treatment_weight").sum().alias("weight"))
        .rename({"control_id": id})
    )

    # A unit can appear in both roles; collapse to one row per id by summing.
    return pl.concat([treated, controls]).group_by(id).agg(pl.col("weight").sum())
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def rollmatch(
    data: pl.DataFrame,
    treat: str,
    tm: str,
    entry: str,
    id: str,
    covariates: list[str],
    lookback: int = 1,
    alpha: float = 0,
    num_matches: int = 3,
    replacement: bool = True,
    standard_deviation: str = "average",
    model_type: str = "logistic",
    match_on: str = "logit",
    block_size: int = 2000,
    verbose: bool = True,
) -> RollmatchResult | None:
    """Run the full rolling entry matching pipeline.

    Pipeline: reduce_data -> drop null covariates -> score_data ->
    match_all_periods -> compute_balance -> _compute_weights.

    Parameters
    ----------
    data : pl.DataFrame
        Panel data with unit × time observations.
    treat : str
        Binary treatment column (1=treated, 0=control).
    tm : str
        Time period column (integer).
    entry : str
        Entry period column. Treatment onset for treated units; null or
        any value > max(tm) for controls.
    id : str
        Unit identifier column.
    covariates : list[str]
        Covariate column names for matching.
    lookback : int
        Periods to look back from entry for baseline covariates.
    alpha : float
        Caliper multiplier (0 = no caliper).
    num_matches : int
        Number of control matches per treated unit.
    replacement : bool
        Allow control reuse within time period.
    standard_deviation : str
        Method for pooled SD in caliper.
    model_type : str
        Propensity model type ("logistic").
    match_on : str
        Score type ("logit" or "pscore").
    block_size : int
        Block size for memory-efficient matching.
    verbose : bool
        Print progress.

    Returns
    -------
    RollmatchResult or None if matching fails.
    """
    if verbose:
        n_treat = data.filter(pl.col(treat) == 1)[id].n_unique()
        n_ctrl = data.filter(pl.col(treat) == 0)[id].n_unique()
        print(f"rollmatch: {n_treat} treated, {n_ctrl} controls, alpha={alpha}")

    # Step 1: Reduce data
    if verbose:
        print(" Step 1: reduce_data...")
    reduced = reduce_data(data, treat, tm, entry, id, lookback)
    if verbose:
        print(f" Reduced: {reduced.height} rows")

    # Drop rows with NaN in covariates
    reduced = reduced.drop_nulls(subset=covariates)
    if verbose:
        print(f" After dropping NaN: {reduced.height} rows")

    # Bail out early: scoring/matching cannot run on an empty frame.
    if reduced.height == 0:
        if verbose:
            print(" ERROR: No valid rows after NaN removal")
        return None

    # Step 2: Score data
    if verbose:
        print(" Step 2: score_data...")
    scored = score_data(reduced, covariates, treat, model_type, match_on)
    if verbose:
        print(f" Scored: {scored.height} rows")

    # Step 3: Match
    if verbose:
        print(f" Step 3: matching (alpha={alpha}, num_matches={num_matches})...")
    matches = match_all_periods(
        scored, treat, tm, entry, id,
        alpha=alpha, num_matches=num_matches,
        replacement=replacement, standard_deviation=standard_deviation,
        block_size=block_size,
    )

    if matches is None or matches.height == 0:
        if verbose:
            print(" No matches found!")
        return None

    n_treated_matched = matches["treat_id"].n_unique()
    n_controls_matched = matches["control_id"].n_unique()
    n_treated_total = scored.filter(pl.col(treat) == 1)[id].n_unique()

    if verbose:
        # NOTE(review): assumes n_treated_total > 0 here; matches being
        # non-empty suggests scored contains treated rows — confirm
        # match_all_periods only pairs treated ids drawn from `scored`.
        print(f" Matched: {matches.height} pairs")
        print(f" Treated matched: {n_treated_matched}/{n_treated_total} "
              f"({100*n_treated_matched/n_treated_total:.1f}%)")
        print(f" Controls used: {n_controls_matched}")

    # Step 4: Balance
    if verbose:
        print(" Step 4: balance...")
    balance = compute_balance(scored, matches, treat, id, tm, covariates)

    # Step 5: Compute weights
    weights = _compute_weights(matches, id, num_matches)

    if verbose:
        smd_table(balance)

    return RollmatchResult(
        matched_data=matches,
        balance=balance,
        n_treated_total=n_treated_total,
        n_treated_matched=n_treated_matched,
        n_controls_matched=n_controls_matched,
        alpha=alpha,
        weights=weights,
    )
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def alpha_sweep(
    data: pl.DataFrame,
    treat: str,
    tm: str,
    entry: str,
    id: str,
    covariates: list[str],
    alphas: list[float] | None = None,
    lookback: int = 1,
    num_matches: int = 3,
    replacement: bool = True,
    standard_deviation: str = "average",
    model_type: str = "logistic",
    match_on: str = "logit",
    block_size: int = 2000,
    smd_threshold: float = 0.1,
) -> tuple[pl.DataFrame, RollmatchResult | None]:
    """Run rollmatch across multiple alpha values and select the best.

    Best = fully balanced (all |SMD| < threshold) with highest match rate.
    If none fully balance, select the one with lowest max|SMD|.

    Parameters
    ----------
    data : pl.DataFrame
        Panel data.
    alphas : list[float]
        Caliper multipliers to try. Default: [0.01, 0.02, 0.05, 0.1, 0.15, 0.2]
    smd_threshold : float
        |SMD| threshold for "balanced" (default 0.1).
    (other params same as rollmatch)

    Returns
    -------
    (summary_df, best_result)
    """
    if alphas is None:
        alphas = [0.01, 0.02, 0.05, 0.1, 0.15, 0.2]

    # Pre-compute reduce + score once (shared across alphas)
    reduced = reduce_data(data, treat, tm, entry, id, lookback)
    reduced = reduced.drop_nulls(subset=covariates)
    scored = score_data(reduced, covariates, treat, model_type, match_on)

    results = []
    best_result = None
    # Lexicographic score: full balance beats any match rate, then higher
    # match rate wins; initialized so any real candidate replaces it.
    best_score = (-1, -np.inf)  # (all_pass, match_rate)

    for alpha in alphas:
        print(f" alpha={alpha:.2f} ... ", end="", flush=True)

        matches = match_all_periods(
            scored, treat, tm, entry, id,
            alpha=alpha, num_matches=num_matches,
            replacement=replacement, standard_deviation=standard_deviation,
            block_size=block_size,
        )

        if matches is None or matches.height == 0:
            print("no matches")
            continue

        balance = compute_balance(scored, matches, treat, id, tm, covariates)
        # NOTE(review): max() is None if every matched_smd is null; the
        # comparison below would then raise — confirm matched samples
        # always yield at least one finite SMD.
        max_smd = balance["matched_smd"].abs().max()
        all_pass = max_smd < smd_threshold

        n_treat_total = scored.filter(pl.col(treat) == 1)[id].n_unique()
        n_treat_matched = matches["treat_id"].n_unique()
        match_rate = n_treat_matched / n_treat_total

        results.append({
            "alpha": alpha,
            "n_matched_pairs": matches.height,
            "n_treated_matched": n_treat_matched,
            "pct_treated": round(100 * match_rate, 1),
            "max_abs_smd": round(max_smd, 4),
            "all_pass": all_pass,
        })

        print(f"matched={n_treat_matched}/{n_treat_total} ({100*match_rate:.0f}%), "
              f"max|SMD|={max_smd:.4f} {'✓' if all_pass else '✗'}")

        # Track best
        score = (int(all_pass), match_rate)
        if score > best_score:
            best_score = score
            # Weights are only materialized for the current best candidate.
            weights = _compute_weights(matches, id, num_matches)

            best_result = RollmatchResult(
                matched_data=matches,
                balance=balance,
                n_treated_total=n_treat_total,
                n_treated_matched=n_treat_matched,
                n_controls_matched=matches["control_id"].n_unique(),
                alpha=alpha,
                weights=weights,
            )

    summary = pl.DataFrame(results) if results else pl.DataFrame()

    if best_result:
        print(f"\n Best: alpha={best_result.alpha} "
              f"(matched={best_result.n_treated_matched}/{best_result.n_treated_total}, "
              f"max|SMD|={best_result.balance['matched_smd'].abs().max():.4f})")

    return summary, best_result
|
|
@@ -0,0 +1,207 @@
|
|
|
1
|
+
"""
|
|
2
|
+
diagnostics — Post-matching diagnostic tests.
|
|
3
|
+
|
|
4
|
+
Includes t-tests, SMD tests, variance ratio tests, and equivalence tests
|
|
5
|
+
for assessing matching quality.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
import polars as pl
|
|
10
|
+
from scipy import stats
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def balance_test(
    scored_data: pl.DataFrame,
    matches: pl.DataFrame,
    treat: str,
    id: str,
    tm: str,
    covariates: list[str],
    threshold: float = 0.1,
) -> pl.DataFrame:
    """Run comprehensive balance diagnostics on matched sample.

    For each covariate with at least two observations per arm, computes:
    - Standardized mean difference (SMD)
    - Welch two-sample t-test (H0: means are equal)
    - Variance ratio (treat/control)
    - Kolmogorov-Smirnov test (H0: distributions are equal)

    Parameters
    ----------
    scored_data : pl.DataFrame
        Reduced data with treatment indicator and covariates.
    matches : pl.DataFrame
        Match results with treat_id, control_id, tm columns.
    treat : str
        Treatment indicator column.
    id : str
        Unit identifier column.
    tm : str
        Time period column.
    covariates : list[str]
        Covariate column names.
    threshold : float
        SMD threshold for pass/fail (default 0.1).

    Returns
    -------
    pl.DataFrame with diagnostics per covariate.
    """
    # Keep only (period, unit) pairs that appear in some match, either role.
    matched_units = pl.concat([
        matches.select(tm, "treat_id").unique().rename({"treat_id": id}),
        matches.select(tm, "control_id").unique().rename({"control_id": id}),
    ])
    sample = scored_data.join(matched_units, on=[tm, id], how="semi")

    # Split once; the per-covariate loop only extracts columns.
    treated_rows = sample.filter(pl.col(treat) == 1)
    control_rows = sample.filter(pl.col(treat) == 0)

    records = []
    for cov in covariates:
        t_arr = treated_rows[cov].drop_nulls().to_numpy().astype(float)
        c_arr = control_rows[cov].drop_nulls().to_numpy().astype(float)

        # Sample SDs and both tests need >= 2 observations per arm.
        if len(t_arr) < 2 or len(c_arr) < 2:
            continue

        # SMD against the pooled (average-of-variances) SD.
        sd_t = np.std(t_arr, ddof=1)
        sd_c = np.std(c_arr, ddof=1)
        pooled_sd = np.sqrt((sd_t**2 + sd_c**2) / 2)
        smd = (np.mean(t_arr) - np.mean(c_arr)) / pooled_sd if pooled_sd > 0 else np.nan

        # Welch's t-test (does not assume equal variances).
        t_stat, t_pvalue = stats.ttest_ind(t_arr, c_arr, equal_var=False)

        # Variance ratio, guarding the degenerate zero-variance control arm.
        var_c = np.var(c_arr, ddof=1)
        var_ratio = np.var(t_arr, ddof=1) / var_c if var_c > 0 else np.nan

        # Nonparametric distributional comparison.
        ks_stat, ks_pvalue = stats.ks_2samp(t_arr, c_arr)

        records.append({
            "covariate": cov,
            "mean_treated": round(np.mean(t_arr), 4),
            "mean_control": round(np.mean(c_arr), 4),
            "smd": round(smd, 4),
            "smd_pass": bool(abs(smd) < threshold),
            "t_stat": round(t_stat, 4),
            "t_pvalue": round(t_pvalue, 4),
            "var_ratio": round(var_ratio, 4),
            "var_ratio_pass": bool(0.5 < var_ratio < 2.0) if not np.isnan(var_ratio) else False,
            "ks_stat": round(ks_stat, 4),
            "ks_pvalue": round(ks_pvalue, 4),
        })

    result = pl.DataFrame(records)

    # Summary header followed by one line per covariate.
    n_pass_smd = result.filter(pl.col("smd_pass")).height
    n_pass_var = result.filter(pl.col("var_ratio_pass")).height
    n_total = result.height

    print(f"\n{'='*70}")
    print(f" Post-Matching Balance Diagnostics")
    print(f"{'='*70}")
    print(f" SMD < {threshold}: {n_pass_smd}/{n_total} pass")
    print(f" Variance ratio in (0.5, 2.0): {n_pass_var}/{n_total} pass")
    print(f"{'='*70}\n")

    print(f" {'Covariate':<25} {'SMD':>8} {'t-test p':>10} {'VR':>8} {'KS p':>8}")
    print(f" {'-'*25} {'-'*8} {'-'*10} {'-'*8} {'-'*8}")
    for row in result.iter_rows(named=True):
        smd_flag = "✓" if row["smd_pass"] else "✗"
        vr_flag = "✓" if row["var_ratio_pass"] else "✗"
        print(f" {row['covariate']:<25} {row['smd']:>7.4f}{smd_flag} {row['t_pvalue']:>10.4f} {row['var_ratio']:>7.3f}{vr_flag} {row['ks_pvalue']:>8.4f}")

    return result
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def equivalence_test(
    scored_data: pl.DataFrame,
    matches: pl.DataFrame,
    treat: str,
    id: str,
    tm: str,
    covariates: list[str],
    multiplier: float = 0.36,
) -> pl.DataFrame:
    """TOST equivalence test for covariate balance.

    Tests H0: |SMD| >= delta (non-equivalence).
    Rejection = GOOD (positive evidence of negligible difference).
    Uses Hartman & Hidalgo (2018) approach: delta = multiplier × pooled_SD.

    Parameters
    ----------
    scored_data : pl.DataFrame
        Reduced data.
    matches : pl.DataFrame
        Match results.
    treat, id, tm : str
        Column names.
    covariates : list[str]
        Covariate names.
    multiplier : float
        Equivalence bound as fraction of pooled SD (default 0.36).

    Returns
    -------
    pl.DataFrame with TOST results per covariate.
    """
    # Restrict to (period, unit) pairs appearing in some match, either role.
    treat_matches = matches.select(tm, "treat_id").unique().rename({"treat_id": id})
    control_matches = matches.select(tm, "control_id").unique().rename({"control_id": id})
    matched_ids = pl.concat([treat_matches, control_matches])
    matched_data = scored_data.join(matched_ids, on=[tm, id], how="semi")

    rows = []
    for cov in covariates:
        vals_t = matched_data.filter(pl.col(treat) == 1)[cov].drop_nulls().to_numpy().astype(float)
        vals_c = matched_data.filter(pl.col(treat) == 0)[cov].drop_nulls().to_numpy().astype(float)

        # ddof=1 variances need at least two observations per arm.
        if len(vals_t) < 2 or len(vals_c) < 2:
            continue

        m, n = len(vals_t), len(vals_c)
        diff = np.mean(vals_t) - np.mean(vals_c)
        var_t = np.var(vals_t, ddof=1)
        var_c = np.var(vals_c, ddof=1)

        # Pooled SD: weighted formula matching Hartman & Hidalgo (2018)
        # equivtest R package: sqrt(((m-1)*var_x + (n-1)*var_y) / (m+n-2))
        pooled_sd = np.sqrt(((m - 1) * var_t + (n - 1) * var_c) / (m + n - 2))
        delta = multiplier * pooled_sd

        # Two one-sided t-tests following equivtest::tost()
        # Uses Welch's t-test (unequal variances)
        se = np.sqrt(var_t / m + var_c / n)
        # Welch–Satterthwaite df; note se**4 == (var_t/m + var_c/n)**2.
        df_welch = se**4 / ((var_t/m)**2/(m-1) + (var_c/n)**2/(n-1)) if se > 0 else 1

        # Upper test: H0: diff >= delta, alt: diff < delta
        t_upper = (diff - delta) / se if se > 0 else np.inf
        p_upper = stats.t.cdf(t_upper, df=df_welch)

        # Lower test: H0: diff <= -delta, alt: diff > -delta
        t_lower = (diff + delta) / se if se > 0 else -np.inf
        p_lower = 1 - stats.t.cdf(t_lower, df=df_welch)

        # TOST p-value: both one-sided tests must reject, so take the max.
        tost_p = max(p_upper, p_lower)

        rows.append({
            "covariate": cov,
            "diff": round(diff, 6),
            "se": round(se, 6),
            "delta": round(delta, 4),
            "tost_p_upper": round(p_upper, 4),
            "tost_p_lower": round(p_lower, 4),
            "tost_p": round(tost_p, 4),
            "equivalent": bool(tost_p < 0.05),
        })

    result = pl.DataFrame(rows)

    # NOTE(review): if every covariate is skipped above, `rows` is empty and
    # the filter below would fail on the missing "equivalent" column —
    # confirm callers always supply populated matched data.
    n_equiv = result.filter(pl.col("equivalent")).height
    print(f"\n TOST Equivalence Test (bound = {multiplier}σ)")
    print(f" Equivalent: {n_equiv}/{result.height} covariates (p < 0.05 = GOOD)")
    for row in result.iter_rows(named=True):
        flag = "✓ EQUIV" if row["equivalent"] else " not equiv"
        print(f" {row['covariate']:<25} p={row['tost_p']:.4f} {flag}")

    return result
|