diff-diff 3.0.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62) hide show
  1. diff_diff/__init__.py +382 -0
  2. diff_diff/_backend.py +134 -0
  3. diff_diff/_rust_backend.cp314-win_amd64.pyd +0 -0
  4. diff_diff/bacon.py +1140 -0
  5. diff_diff/bootstrap_utils.py +730 -0
  6. diff_diff/continuous_did.py +1626 -0
  7. diff_diff/continuous_did_bspline.py +190 -0
  8. diff_diff/continuous_did_results.py +374 -0
  9. diff_diff/datasets.py +815 -0
  10. diff_diff/diagnostics.py +882 -0
  11. diff_diff/efficient_did.py +1770 -0
  12. diff_diff/efficient_did_bootstrap.py +359 -0
  13. diff_diff/efficient_did_covariates.py +899 -0
  14. diff_diff/efficient_did_results.py +368 -0
  15. diff_diff/efficient_did_weights.py +617 -0
  16. diff_diff/estimators.py +1501 -0
  17. diff_diff/honest_did.py +2585 -0
  18. diff_diff/imputation.py +2458 -0
  19. diff_diff/imputation_bootstrap.py +418 -0
  20. diff_diff/imputation_results.py +448 -0
  21. diff_diff/linalg.py +2538 -0
  22. diff_diff/power.py +2588 -0
  23. diff_diff/practitioner.py +869 -0
  24. diff_diff/prep.py +1738 -0
  25. diff_diff/prep_dgp.py +1718 -0
  26. diff_diff/pretrends.py +1105 -0
  27. diff_diff/results.py +918 -0
  28. diff_diff/stacked_did.py +1049 -0
  29. diff_diff/stacked_did_results.py +339 -0
  30. diff_diff/staggered.py +3895 -0
  31. diff_diff/staggered_aggregation.py +864 -0
  32. diff_diff/staggered_bootstrap.py +752 -0
  33. diff_diff/staggered_results.py +416 -0
  34. diff_diff/staggered_triple_diff.py +1545 -0
  35. diff_diff/staggered_triple_diff_results.py +416 -0
  36. diff_diff/sun_abraham.py +1685 -0
  37. diff_diff/survey.py +1981 -0
  38. diff_diff/synthetic_did.py +1136 -0
  39. diff_diff/triple_diff.py +2047 -0
  40. diff_diff/trop.py +952 -0
  41. diff_diff/trop_global.py +1270 -0
  42. diff_diff/trop_local.py +1307 -0
  43. diff_diff/trop_results.py +356 -0
  44. diff_diff/twfe.py +542 -0
  45. diff_diff/two_stage.py +1952 -0
  46. diff_diff/two_stage_bootstrap.py +520 -0
  47. diff_diff/two_stage_results.py +400 -0
  48. diff_diff/utils.py +1902 -0
  49. diff_diff/visualization/__init__.py +61 -0
  50. diff_diff/visualization/_common.py +328 -0
  51. diff_diff/visualization/_continuous.py +274 -0
  52. diff_diff/visualization/_diagnostic.py +817 -0
  53. diff_diff/visualization/_event_study.py +1086 -0
  54. diff_diff/visualization/_power.py +661 -0
  55. diff_diff/visualization/_staggered.py +833 -0
  56. diff_diff/visualization/_synthetic.py +197 -0
  57. diff_diff/wooldridge.py +1285 -0
  58. diff_diff/wooldridge_results.py +349 -0
  59. diff_diff-3.0.1.dist-info/METADATA +2997 -0
  60. diff_diff-3.0.1.dist-info/RECORD +62 -0
  61. diff_diff-3.0.1.dist-info/WHEEL +4 -0
  62. diff_diff-3.0.1.dist-info/sboms/diff_diff_rust.cyclonedx.json +5843 -0
@@ -0,0 +1,899 @@
1
+ """
2
+ Doubly robust math for the Efficient DiD estimator (with covariates).
3
+
4
+ Implements the with-covariates path from Chen, Sant'Anna & Xie (2025):
5
+ OLS outcome regression (linear working model), sieve-based propensity
6
+ score ratios (Eq 4.1-4.2), sieve-based inverse propensities (step 4),
7
+ kernel-smoothed conditional Omega*(X) for per-unit efficient weights,
8
+ doubly robust generated outcomes (Eq 4.4), and the efficient influence
9
+ function for analytical standard errors.
10
+
11
+ The DR property ensures consistency if either the OLS outcome model or
12
+ the sieve propensity ratio is correctly specified. The OLS working model
13
+ does not generically guarantee the semiparametric efficiency bound unless
14
+ the conditional mean is linear in covariates (see REGISTRY.md).
15
+
16
+ All functions are pure (no state), operating on pre-pivoted numpy arrays.
17
+ """
18
+
19
+ import warnings
20
+ from itertools import combinations_with_replacement
21
+ from math import comb
22
+ from typing import Dict, List, Optional, Tuple
23
+
24
+ import numpy as np
25
+ from scipy.spatial.distance import cdist
26
+
27
+ from diff_diff.linalg import solve_ols
28
+
29
+ # ---------------------------------------------------------------------------
30
+ # Outcome regression
31
+ # ---------------------------------------------------------------------------
32
+
33
+
34
def estimate_outcome_regression(
    outcome_wide: np.ndarray,
    covariate_matrix: np.ndarray,
    group_mask: np.ndarray,
    t_col: int,
    tpre_col: int,
    unit_weights: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Fit a linear model of the outcome change within one group; predict for all.

    The change ``Y[:, t_col] - Y[:, tpre_col]`` of the units selected by
    ``group_mask`` is regressed on an intercept plus their covariates via
    ``solve_ols``. The fitted coefficients are then applied to the covariates
    of **every** unit, so units outside the group receive extrapolated
    predictions. This is the working model for
    ``m_hat_{g',t,tpre}(X) = E[Y_t - Y_{tpre} | G=g', X]``.

    Parameters
    ----------
    outcome_wide : ndarray, shape (n_units, n_periods)
        Pivoted outcome matrix.
    covariate_matrix : ndarray, shape (n_units, n_covariates)
        Unit-level (time-invariant) covariates.
    group_mask : ndarray of bool, shape (n_units,)
        Selects the comparison-group units used for the fit.
    t_col, tpre_col : int
        Column indices in ``outcome_wide`` of the two periods being differenced.
    unit_weights : ndarray, shape (n_units,), optional
        Unit-level survey weights. When given, the group's weights are passed
        to ``solve_ols`` with ``weight_type="pweight"``.

    Returns
    -------
    m_hat : ndarray, shape (n_units,)
        Predicted outcome change for every unit. Non-finite predictions are
        replaced by 0.0 (with a ``UserWarning``).
    """

    def _with_intercept(covariates: np.ndarray) -> np.ndarray:
        # Prepend a column of ones so the first coefficient is the intercept.
        return np.column_stack([np.ones(len(covariates)), covariates])

    # Response and design restricted to the comparison group.
    group_outcomes = outcome_wide[group_mask]
    outcome_change = group_outcomes[:, t_col] - group_outcomes[:, tpre_col]
    design_group = _with_intercept(covariate_matrix[group_mask])

    if unit_weights is None:
        group_weights, weight_kind = None, None
    else:
        group_weights, weight_kind = unit_weights[group_mask], "pweight"

    coefficients, _, _ = solve_ols(
        design_group,
        outcome_change,
        weights=group_weights,
        weight_type=weight_kind,
        return_vcov=False,
        rank_deficient_action="warn",
    )

    # Non-finite coefficients contribute nothing to the prediction.
    # NOTE(review): presumably these mark columns dropped by solve_ols for
    # rank deficiency -- confirm against diff_diff.linalg.
    usable_coef = np.where(np.isfinite(coefficients), coefficients, 0.0)
    predictions = _with_intercept(covariate_matrix) @ usable_coef

    bad = ~np.isfinite(predictions)
    if bad.any():
        n_bad = int(bad.sum())
        warnings.warn(
            f"Outcome regression produced {n_bad} non-finite prediction(s). "
            "Setting to 0.0 (equivalent to no covariate adjustment).",
            UserWarning,
            stacklevel=2,
        )
        predictions[bad] = 0.0

    return predictions
102
+
103
+
104
+ # ---------------------------------------------------------------------------
105
+ # Sieve-based propensity ratio estimation (Eq 4.1-4.2)
106
+ # ---------------------------------------------------------------------------
107
+
108
+
109
def _polynomial_sieve_basis(X: np.ndarray, degree: int) -> np.ndarray:
    """Return the polynomial sieve basis of total degree at most ``degree``.

    Columns are the intercept followed by every monomial
    ``X_1^{a_1} * ... * X_d^{a_d}`` with ``1 <= a_1 + ... + a_d <= degree``,
    ordered by total degree and, within a degree, in the order produced by
    ``combinations_with_replacement`` over the covariate indices.

    Covariates are centred and scaled (column mean / population std) before
    the monomials are formed; columns whose std is below ``1e-10`` are only
    centred (divided by 1).

    Parameters
    ----------
    X : ndarray, shape (n, d)
        Covariate matrix.
    degree : int
        Maximum total polynomial degree.

    Returns
    -------
    basis : ndarray, shape (n, n_basis)
        Sieve basis matrix with ``n_basis = C(degree + d, d)``.
    """
    n_rows, n_cov = X.shape

    # Unweighted standardisation: it only improves conditioning of the
    # normal equations and is shared by weighted and unweighted fits.
    centre = X.mean(axis=0)
    scale = X.std(axis=0)
    scale = np.where(scale < 1e-10, 1.0, scale)  # constant columns: no rescale
    Z = (X - centre) / scale

    terms = [np.ones(n_rows)]  # degree-0 term
    for order in range(1, degree + 1):
        # Each multiset of `order` covariate indices is one monomial.
        for picks in combinations_with_replacement(range(n_cov), order):
            monomial = np.ones(n_rows)
            for cov_idx in picks:
                monomial *= Z[:, cov_idx]
            terms.append(monomial)

    return np.column_stack(terms)
150
+
151
+
152
def estimate_propensity_ratio_sieve(
    covariate_matrix: np.ndarray,
    mask_g: np.ndarray,
    mask_gp: np.ndarray,
    k_max: Optional[int] = None,
    criterion: str = "bic",
    ratio_clip: float = 20.0,
    unit_weights: Optional[np.ndarray] = None,
) -> np.ndarray:
    r"""Estimate the propensity ratio ``r_{g,g'}(X)`` with a polynomial sieve.

    For each sieve degree K = 1, ..., k_max the convex problem (Eq 4.1-4.2)

    .. math::
        \hat\beta_K = \arg\min_{\beta} \frac{1}{n}
        \sum_i \bigl[ G_{g',i} (\psi^K(X_i)'\beta)^2
        - 2 G_{g,i} (\psi^K(X_i)'\beta) \bigr]

    is solved in closed form through its first-order condition
    ``(Psi_{g'}' Psi_{g'}) beta = Psi_g.sum(axis=0)`` (both sides weighted by
    the survey weights when ``unit_weights`` is given).

    The degree is chosen by minimising
    ``IC(K) = 2 * loss(K) + C_n * n_basis(K) / n`` where ``n_basis(K)`` is the
    number of basis columns, ``n`` is the unweighted number of units in the
    two groups, and ``C_n`` is 2 for ``"aic"`` and ``log(n)`` otherwise.

    Degrees whose basis dimension is not smaller than the comparison group
    end the search; degrees whose linear system is singular or yields a
    non-finite solution are skipped. If no degree succeeds the ratio falls
    back to a constant 1 and a ``UserWarning`` is issued. When the two masks
    are identical the function returns 1 for every unit without fitting.

    Parameters
    ----------
    covariate_matrix : ndarray, shape (n_units, n_covariates)
    mask_g : ndarray of bool, shape (n_units,)
        Target treatment group mask.
    mask_gp : ndarray of bool, shape (n_units,)
        Comparison group mask.
    k_max : int or None
        Maximum polynomial degree. None = ``min(floor(n_gp^{1/5}), 5)``.
        Values below 1 are raised to 1.
    criterion : str
        ``"aic"`` selects the AIC penalty; any other value uses the BIC
        penalty.
    ratio_clip : float
        Ratios are clipped to ``[1/ratio_clip, ratio_clip]``; a
        ``UserWarning`` reports how many units needed clipping.
    unit_weights : ndarray, shape (n_units,), optional
        Survey weights at the unit level. When provided, the normal
        equations and the loss normalisation are weighted.

    Returns
    -------
    ratio : ndarray, shape (n_units,)
        Estimated ``r_{g,g'}(X_i)`` for every unit.
    """
    n_units = len(covariate_matrix)
    n_gp = int(np.sum(mask_gp))

    # Same cohort on both sides: the ratio is identically 1.
    if np.array_equal(mask_g, mask_gp):
        return np.ones(n_units)

    n_cov = covariate_matrix.shape[1]

    # Default degree cap is driven by the comparison-group size.
    if k_max is None:
        k_max = min(int(n_gp**0.2), 5)
    k_max = max(k_max, 1)

    # The complexity penalty counts observations, never survey weights.
    n_obs = int(np.sum(mask_g)) + n_gp
    penalty = 2.0 if criterion == "aic" else np.log(max(n_obs, 2))

    weighted = unit_weights is not None
    if weighted:
        wts_g = unit_weights[mask_g]
        wts_gp = unit_weights[mask_gp]
        loss_norm = float(np.sum(wts_g)) + float(np.sum(wts_gp))
    else:
        wts_g = wts_gp = None
        loss_norm = float(n_obs)

    best_ic = np.inf
    best_ratio = np.ones(n_units)  # used when every degree fails

    for sieve_degree in range(1, k_max + 1):
        n_terms = comb(sieve_degree + n_cov, n_cov)

        # The basis must stay smaller than the comparison group; higher
        # degrees only grow it, so stop searching.
        if n_terms >= n_gp:
            break

        basis = _polynomial_sieve_basis(covariate_matrix, sieve_degree)
        basis_gp = basis[mask_gp]  # (n_gp, n_terms)
        basis_g = basis[mask_g]  # (n_g, n_terms)

        if weighted:
            gram = basis_gp.T @ (wts_gp[:, None] * basis_gp)
            rhs = (wts_g[:, None] * basis_g).sum(axis=0)
        else:
            gram = basis_gp.T @ basis_gp
            rhs = basis_g.sum(axis=0)

        try:
            beta = np.linalg.solve(gram, rhs)
        except np.linalg.LinAlgError:
            continue  # singular system: move on to the next degree
        if not np.all(np.isfinite(beta)):
            continue

        # At the optimum gram @ beta = rhs, hence
        # loss = (beta'gram beta - 2 rhs'beta) / n_w = -(rhs'beta) / n_w.
        loss = -float(rhs @ beta) / loss_norm
        ic_value = 2.0 * loss + penalty * n_terms / n_obs

        if ic_value < best_ic:
            best_ic = ic_value
            best_ratio = basis @ beta  # fitted ratio for all units

    if best_ic == np.inf:
        warnings.warn(
            "Propensity ratio sieve estimation failed for all K values. "
            "Falling back to constant ratio of 1 (no ratio adjustment). "
            "The DR estimator relies on outcome regression only.",
            UserWarning,
            stacklevel=2,
        )

    # Overlap diagnostic: report how many fitted ratios fall outside the
    # admissible band before they are clipped into it.
    lower_bound = 1.0 / ratio_clip
    out_of_band = (best_ratio < lower_bound) | (best_ratio > ratio_clip)
    n_extreme = int(np.sum(out_of_band))
    if n_extreme > 0:
        pct = 100.0 * n_extreme / n_units
        warnings.warn(
            f"Sieve propensity ratios for {n_extreme} of {n_units} units "
            f"({pct:.1f}%) were outside [{lower_bound:.2f}, {ratio_clip:.1f}] "
            f"and will be clipped. This may indicate overlap assumption "
            f"violations (near-zero propensity scores for some covariate values).",
            UserWarning,
            stacklevel=2,
        )

    # The population ratio p_g(X)/p_{g'}(X) is non-negative, so clip.
    return np.clip(best_ratio, lower_bound, ratio_clip)
303
+
304
+
305
+ # ---------------------------------------------------------------------------
306
+ # Sieve-based inverse propensity estimation (Algorithm step 4)
307
+ # ---------------------------------------------------------------------------
308
+
309
+
310
def estimate_inverse_propensity_sieve(
    covariate_matrix: np.ndarray,
    group_mask: np.ndarray,
    k_max: Optional[int] = None,
    criterion: str = "bic",
    unit_weights: Optional[np.ndarray] = None,
) -> np.ndarray:
    r"""Estimate ``s_{g'}(X) = 1/p_{g'}(X)`` with a polynomial sieve.

    For each sieve degree K the convex problem

    .. math::
        \hat\beta_K = \arg\min_\beta \frac{1}{n}
        \sum_i \bigl[ G_{g',i} (\psi^K(X_i)'\beta)^2
        - 2 (\psi^K(X_i)'\beta) \bigr]

    is solved through its first-order condition
    ``(Psi_{g'}' Psi_{g'}) beta = Psi_all.sum(axis=0)``. The structure matches
    the propensity-ratio estimator except that the right-hand side sums the
    basis over **all** units (algorithm step 4).

    The degree is selected by ``2 * loss + C_n * n_basis / n_units`` with
    ``C_n = 2`` for ``"aic"`` and ``log(n_units)`` otherwise. Degrees with a
    singular system or a non-finite solution are skipped; the search stops
    once the basis dimension reaches the group size. If nothing fits, the
    unconditional value (total size over group size, weighted when survey
    weights are given) is used and a ``UserWarning`` is issued.

    Parameters
    ----------
    covariate_matrix : ndarray, shape (n_units, n_covariates)
    group_mask : ndarray of bool, shape (n_units,)
        Mask for the group whose inverse propensity to estimate.
    k_max : int or None
        Maximum polynomial degree. None = ``min(floor(n_group^{1/5}), 5)``.
        Values below 1 are raised to 1.
    criterion : str
        ``"aic"`` selects the AIC penalty; any other value uses the BIC
        penalty.
    unit_weights : ndarray, shape (n_units,), optional
        Survey weights at the unit level. When provided, the normal
        equations and the loss normalisation are weighted.

    Returns
    -------
    s_hat : ndarray, shape (n_units,)
        Estimated ``1/p_{g'}(X_i)`` for every unit, clipped to
        ``[1, n_units]``. All ones when the group is empty or carries no
        positive total survey weight.
    """
    n_units = len(covariate_matrix)
    n_group = int(np.sum(group_mask))
    n_cov = covariate_matrix.shape[1]

    if n_group == 0:
        return np.ones(n_units)

    if k_max is None:
        k_max = min(int(n_group**0.2), 5)
    k_max = max(k_max, 1)

    # The complexity penalty counts observations, never survey weights.
    penalty = 2.0 if criterion == "aic" else np.log(max(n_units, 2))

    weighted = unit_weights is not None
    if weighted:
        wts_group = unit_weights[group_mask]
        group_weight_total = float(np.sum(wts_group))
        if group_weight_total <= 0:
            # The group has no survey weight: nothing to estimate.
            return np.ones(n_units)
        loss_norm = float(np.sum(unit_weights))
        unconditional = loss_norm / group_weight_total
    else:
        wts_group = None
        loss_norm = float(n_units)
        unconditional = n_units / n_group

    best_ic = np.inf
    best_s = np.full(n_units, unconditional)  # used when every degree fails

    for sieve_degree in range(1, k_max + 1):
        n_terms = comb(sieve_degree + n_cov, n_cov)
        if n_terms >= n_group:
            break

        basis = _polynomial_sieve_basis(covariate_matrix, sieve_degree)
        basis_group = basis[group_mask]

        if weighted:
            gram = basis_group.T @ (wts_group[:, None] * basis_group)
            rhs = (unit_weights[:, None] * basis).sum(axis=0)
        else:
            gram = basis_group.T @ basis_group
            rhs = basis.sum(axis=0)  # every unit, not just the group

        try:
            beta = np.linalg.solve(gram, rhs)
        except np.linalg.LinAlgError:
            continue
        if not np.all(np.isfinite(beta)):
            continue

        # At the optimum gram @ beta = rhs, so loss = -(rhs'beta) / n_w.
        loss = -float(rhs @ beta) / loss_norm
        ic_value = 2.0 * loss + penalty * n_terms / n_units

        if ic_value < best_ic:
            best_ic = ic_value
            best_s = basis @ beta

    if best_ic == np.inf:
        warnings.warn(
            "Inverse propensity sieve estimation failed for all K values. "
            "Falling back to unconditional n/n_group scaling.",
            UserWarning,
            stacklevel=2,
        )

    # Overlap diagnostic: count estimates outside the admissible range.
    upper_bound = float(n_units)
    n_clipped = int(np.sum((best_s < 1.0) | (best_s > upper_bound)))
    if n_clipped > 0:
        pct = 100.0 * n_clipped / n_units
        warnings.warn(
            f"Inverse propensity estimates for {n_clipped} of {n_units} units "
            f"({pct:.1f}%) were outside [1, {n_units}] and will be clipped. "
            f"This may indicate overlap assumption violations.",
            UserWarning,
            stacklevel=2,
        )

    # s = 1/p is at least 1 because p <= 1; cap it at n_units from above.
    return np.clip(best_s, 1.0, upper_bound)
442
+
443
+
444
+ # ---------------------------------------------------------------------------
445
+ # Doubly robust generated outcomes (Eq 4.4)
446
+ # ---------------------------------------------------------------------------
447
+
448
+
449
def compute_generated_outcomes_cov(
    target_g: float,
    target_t: float,
    valid_pairs: List[Tuple[float, float]],
    outcome_wide: np.ndarray,
    cohort_masks: Dict[float, np.ndarray],
    never_treated_mask: np.ndarray,
    period_to_col: Dict[float, int],
    period_1_col: int,
    cohort_fractions: Dict[float, float],
    m_hat_cache: Dict[Tuple, np.ndarray],
    r_hat_cache: Dict[Tuple[float, float], np.ndarray],
    never_treated_val: float = np.inf,
) -> np.ndarray:
    """Build the doubly robust generated outcomes of Eq 4.4, one column per pair.

    For every valid pair ``(g', t_pre)`` the column is the sum of three
    pieces, each non-zero only on the units of one group::

        Term 1 (units of the target cohort g):
            (1 / pi_g) * (Y_{i,t} - Y_{i,1}
                          - m_{inf,t,tpre}(X_i) - m_{g',tpre,1}(X_i))

        Term 2 (never-treated units):
            -r_{g,inf}(X_i) * (1 / pi_g)
                * (Y_{i,t} - Y_{i,tpre} - m_{inf,t,tpre}(X_i))

        Term 3 (units of the comparison cohort g'):
            -r_{g,g'}(X_i) * (1 / pi_g)
                * (Y_{i,tpre} - Y_{i,1} - m_{g',tpre,1}(X_i))

    Terms 2 and 3 are skipped when the corresponding cohort fraction is
    missing or not positive.

    Parameters
    ----------
    target_g, target_t : float
        Target cohort and period.
    valid_pairs : list of (g', t_pre)
        Comparison cohort / pre-period pairs; an infinite ``g'`` selects the
        never-treated units for Term 3.
    outcome_wide : ndarray, shape (n_units, n_periods)
    cohort_masks : dict
        Boolean unit mask per cohort.
    never_treated_mask : ndarray of bool, shape (n_units,)
    period_to_col : dict
        Maps a period to its column in ``outcome_wide``.
    period_1_col : int
        Column of the baseline period ``Y_1``.
    cohort_fractions : dict
        Cohort shares ``pi``; the never-treated share is looked up under
        ``never_treated_val``.
    m_hat_cache : dict
        Outcome-regression predictions keyed by ``(group, col_a, col_b)``.
    r_hat_cache : dict
        Propensity ratios keyed by ``(target_g, group)``.
    never_treated_val : float
        Key used for the never-treated group in the caches.

    Returns
    -------
    gen_out : ndarray, shape (n_units, H)
        Per-unit generated outcome for each valid pair. Shape
        ``(n_units, 0)`` when there are no pairs; all zeros when the target
        cohort fraction is not positive.
    """
    n_pairs = len(valid_pairs)
    n_units = outcome_wide.shape[0]
    if n_pairs == 0:
        return np.empty((n_units, 0))

    t_col = period_to_col[target_t]
    base_col = period_1_col
    treated = cohort_masks[target_g]
    pi_g = cohort_fractions[target_g]

    # Without a positive target-cohort share every term is undefined.
    if pi_g <= 0:
        return np.zeros((n_units, n_pairs))

    gen_out = np.zeros((n_units, n_pairs))
    inv_pi_g = 1.0 / pi_g
    pi_never = cohort_fractions.get(never_treated_val, 0.0)

    for col, (gp, tpre) in enumerate(valid_pairs):
        tpre_col = period_to_col[tpre]

        m_never = m_hat_cache[(never_treated_val, t_col, tpre_col)]
        m_comp = m_hat_cache[(gp, tpre_col, base_col)]
        ratio_never = r_hat_cache[(target_g, never_treated_val)]
        ratio_comp = r_hat_cache[(target_g, gp)]

        # Term 1 -- target cohort. The comparison is False for a NaN share,
        # which therefore contributes nothing here.
        if pi_g > 0:
            long_change = outcome_wide[treated, t_col] - outcome_wide[treated, base_col]
            resid_treated = long_change - m_never[treated] - m_comp[treated]
            gen_out[treated, col] += inv_pi_g * resid_treated

        # Term 2 -- never-treated units.
        if pi_never > 0:
            post_change = (
                outcome_wide[never_treated_mask, t_col]
                - outcome_wide[never_treated_mask, tpre_col]
            )
            resid_never = post_change - m_never[never_treated_mask]
            gen_out[never_treated_mask, col] -= (
                ratio_never[never_treated_mask] * inv_pi_g * resid_never
            )

        # Term 3 -- comparison cohort (never-treated when g' is infinite).
        comp_mask = never_treated_mask if np.isinf(gp) else cohort_masks[gp]
        if cohort_fractions.get(gp, 0.0) > 0:
            pre_change = outcome_wide[comp_mask, tpre_col] - outcome_wide[comp_mask, base_col]
            resid_comp = pre_change - m_comp[comp_mask]
            gen_out[comp_mask, col] -= ratio_comp[comp_mask] * inv_pi_g * resid_comp

    return gen_out
538
+
539
+
540
+ # ---------------------------------------------------------------------------
541
+ # Kernel-smoothed conditional Omega* (Eq 3.12)
542
+ # ---------------------------------------------------------------------------
543
+
544
+
545
def _silverman_bandwidth(X: np.ndarray) -> float:
    """Rule-of-thumb (Silverman) kernel bandwidth for an ``(n, d)`` matrix.

    ``h = (4 / (d + 2))^{1/(d+4)} * median_std * n^{-1/(d+4)}``

    ``median_std`` is the median of the per-column standard deviations, with
    standard deviations below ``1e-10`` counted as 1. The result is floored
    at ``1e-10``.
    """
    n_obs, n_dim = X.shape

    column_spread = np.std(X, axis=0)
    # Constant columns must not drag the typical scale towards zero.
    column_spread = np.where(column_spread < 1e-10, 1.0, column_spread)
    typical_spread = float(np.median(column_spread))

    rate = 1.0 / (n_dim + 4)
    bandwidth = (4.0 / (n_dim + 2)) ** rate * typical_spread * n_obs ** (-rate)
    return max(bandwidth, 1e-10)
556
+
557
+
558
def _kernel_weights_matrix(
    X_all: np.ndarray,
    X_group: np.ndarray,
    bandwidth: float,
    group_weights: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Row-normalised Gaussian kernel weights (Nadaraya-Watson).

    Entry ``[i, j]`` is the kernel weight that group unit ``j`` receives when
    smoothing at target point ``X_all[i]``:
    ``K_h(X_all[i], X_group[j]) / sum_k K_h(X_all[i], X_group[k])``.

    The kernel is evaluated on squared distances shifted by each row's
    smallest finite distance. The shift cancels in the normalisation, so the
    weights are mathematically unchanged, but the nearest group unit always
    has an unnormalised weight of exactly 1. A target point that lies far
    from every group unit (relative to the bandwidth) therefore still gets a
    properly normalised row instead of a row of zeros caused by underflow.

    Parameters
    ----------
    X_all : ndarray, shape (n_all, d)
        Points at which the smoother is evaluated.
    X_group : ndarray, shape (n_group, d)
        Covariates of the group units being averaged over.
    bandwidth : float
        Gaussian kernel bandwidth ``h``.
    group_weights : ndarray, shape (n_group,), optional
        Survey weights for the group units. When provided, kernel weights
        are multiplied by survey weights before row-normalisation, making
        the Nadaraya-Watson estimator survey-weighted.

    Returns
    -------
    weights : ndarray, shape (n_all, n_group)
        Each row sums to 1, except rows whose (survey-weighted) kernel mass
        is below ``1e-15``; those are returned unnormalised to avoid a
        division by zero.
    """
    # Squared distances: (n_all, n_group)
    dist_sq = cdist(X_all, X_group, metric="sqeuclidean")

    if dist_sq.shape[1] > 0:
        # Shift each row by its smallest distance so exp() cannot underflow
        # for the whole row at once. A non-finite minimum is left unshifted
        # so inf/NaN distances behave exactly as without the shift.
        nearest = dist_sq.min(axis=1, keepdims=True)
        nearest = np.where(np.isfinite(nearest), nearest, 0.0)
        dist_sq = dist_sq - nearest

    # Gaussian kernel
    raw = np.exp(-dist_sq / (2.0 * bandwidth**2))

    # Survey-weight: each group unit j contributes in proportion to w_j * K_h
    if group_weights is not None:
        raw = raw * group_weights[np.newaxis, :]

    # Normalise each row, guarding against an (almost) empty kernel mass.
    row_sums = raw.sum(axis=1, keepdims=True)
    row_sums[row_sums < 1e-15] = 1.0
    return raw / row_sums
589
+
590
+
591
def _kernel_weighted_cov(
    A: np.ndarray,
    B: np.ndarray,
    W: np.ndarray,
) -> np.ndarray:
    """Local covariance of ``A`` and ``B`` under kernel weights ``W``.

    For every target row ``i`` of ``W`` the two series are centred at their
    kernel-weighted means and the weighted average of the product of the
    deviations is returned.

    Parameters
    ----------
    A : ndarray, shape (n_group,)
    B : ndarray, shape (n_group,)
    W : ndarray, shape (n_all, n_group)
        Normalized kernel weights (rows sum to 1).

    Returns
    -------
    cov : ndarray, shape (n_all,)
        ``Cov_hat(A, B | X_i)`` for each target unit i.
    """
    # Kernel-weighted means at each target point, as (n_all, 1) columns so
    # they broadcast against the (1, n_group) rows below.
    mean_a = (W @ A)[:, np.newaxis]
    mean_b = (W @ B)[:, np.newaxis]

    # Deviations of every group unit from the local means: (n_all, n_group)
    dev_a = A[np.newaxis, :] - mean_a
    dev_b = B[np.newaxis, :] - mean_b

    return np.sum(W * dev_a * dev_b, axis=1)
621
+
622
+
623
def compute_omega_star_conditional(
    target_g: float,
    target_t: float,
    valid_pairs: List[Tuple[float, float]],
    outcome_wide: np.ndarray,
    cohort_masks: Dict[float, np.ndarray],
    never_treated_mask: np.ndarray,
    period_to_col: Dict[float, int],
    period_1_col: int,
    cohort_fractions: Dict[float, float],
    covariate_matrix: np.ndarray,
    s_hat_cache: Dict[float, np.ndarray],
    bandwidth: Optional[float] = None,
    unit_weights: Optional[np.ndarray] = None,
    never_treated_val: float = np.inf,
) -> np.ndarray:
    r"""Kernel-smoothed conditional Omega\*(X_i) for each unit (Eq 3.12).

    Every entry ``(j, k)`` of the per-unit matrix is the sum of up to five
    conditional covariances, each estimated by Gaussian-kernel
    Nadaraya-Watson smoothing around ``X_i`` and scaled by the per-unit
    inverse propensity ``s_hat(X_i) = 1/p(X_i)`` of the group it is computed
    on:

    1. ``s_g * Cov(Y_t-Y_1, Y_t-Y_1 | G=g, X)`` (identical for all entries);
    2. ``s_inf * Cov(Y_t-Y_{tpre_j}, Y_t-Y_{tpre_k} | G=inf, X)``;
    3. ``-1{g'_j = g} * s_g * Cov(Y_t-Y_1, Y_{tpre_j}-Y_1 | G=g, X)``;
    4. ``-1{g'_k = g} * s_g * Cov(Y_t-Y_1, Y_{tpre_k}-Y_1 | G=g, X)``;
    5. ``1{g'_j = g'_k} * s_{g'_j} * Cov(Y_{tpre_j}-Y_1, Y_{tpre_k}-Y_1 | G=g'_j, X)``.

    Only the upper triangle is computed; the result is mirrored.

    Parameters
    ----------
    target_g, target_t : float
        Target group-time.
    valid_pairs : list of (g', t_pre)
        An infinite ``g'`` denotes the never-treated group.
    outcome_wide : ndarray, shape (n_units, n_periods)
    cohort_masks, never_treated_mask, period_to_col, period_1_col,
    cohort_fractions : pre-computed data structures
    covariate_matrix : ndarray, shape (n_units, n_covariates)
    s_hat_cache : dict
        Inverse propensity estimates ``{group: s_hat(X_i)}``, each of shape
        ``(n_units,)``. Groups missing from the cache use the constant
        ``1 / max(cohort fraction, 1e-10)``.
    bandwidth : float or None
        Kernel bandwidth. None = Silverman's rule on ``covariate_matrix``.
    unit_weights : ndarray, shape (n_units,), optional
        Survey weights at the unit level. When provided, kernel-smoothed
        covariances use survey-weighted Nadaraya-Watson regression.
    never_treated_val : float
        Key of the never-treated group in ``s_hat_cache`` and
        ``cohort_fractions``.

    Returns
    -------
    omega : ndarray, shape (n_units, H, H)
        Per-unit conditional covariance matrices (``(n_units, 0, 0)`` when
        there are no valid pairs).
    """
    n_pairs = len(valid_pairs)
    n_units = outcome_wide.shape[0]
    if n_pairs == 0:
        return np.empty((n_units, 0, 0))

    if bandwidth is None:
        bandwidth = _silverman_bandwidth(covariate_matrix)

    t_col = period_to_col[target_t]
    base_col = period_1_col
    treated = cohort_masks[target_g]

    def _flat_inverse_propensity(fraction):
        # Constant 1/pi stand-in for a group without a sieve estimate.
        return np.full(n_units, 1.0 / max(fraction, 1e-10))

    def _survey_weights(mask):
        return None if unit_weights is None else unit_weights[mask]

    def _change_from_base(panel, col):
        return panel[:, col] - panel[:, base_col]

    Y_never = outcome_wide[never_treated_mask]
    X_never = covariate_matrix[never_treated_mask]

    # Per-unit inverse propensities from the sieve step (Eq 3.12).
    s_treated = s_hat_cache.get(
        target_g, _flat_inverse_propensity(cohort_fractions[target_g])
    )
    s_never = s_hat_cache.get(
        never_treated_val,
        _flat_inverse_propensity(cohort_fractions.get(never_treated_val, 1e-10)),
    )

    if n_units > 5000:
        warnings.warn(
            f"Conditional Omega* estimation with n={n_units} is expensive "
            f"(O(n^2 * H^2)). Consider using fewer units.",
            UserWarning,
            stacklevel=2,
        )

    wts_treated = _survey_weights(treated)
    wts_never = _survey_weights(never_treated_mask)

    Y_treated = outcome_wide[treated]
    X_treated = covariate_matrix[treated]
    treated_long_change = _change_from_base(Y_treated, t_col)

    # Kernel weights from every unit to the members of each group.
    W_treated = _kernel_weights_matrix(
        covariate_matrix, X_treated, bandwidth, group_weights=wts_treated
    )
    W_never = _kernel_weights_matrix(
        covariate_matrix, X_never, bandwidth, group_weights=wts_never
    )

    # Never-treated changes Y_t - Y_tpre, one series per distinct pre-period.
    never_post_change = {}
    for _, tpre in valid_pairs:
        pre_col = period_to_col[tpre]
        if pre_col not in never_post_change:
            never_post_change[pre_col] = Y_never[:, t_col] - Y_never[:, pre_col]

    # Lazily filled per comparison cohort (Term 5).
    cohort_kernel: Dict[float, np.ndarray] = {}
    cohort_panel: Dict[float, np.ndarray] = {}

    omega = np.zeros((n_units, n_pairs, n_pairs))

    # Term 1 does not depend on (j, k).
    shared_term = s_treated * _kernel_weighted_cov(
        treated_long_change, treated_long_change, W_treated
    )

    for j, (gp_j, tpre_j) in enumerate(valid_pairs):
        col_j = period_to_col[tpre_j]

        for k in range(j, n_pairs):
            gp_k, tpre_k = valid_pairs[k]
            col_k = period_to_col[tpre_k]

            entry = shared_term.copy()

            # Term 2: never-treated post-vs-pre changes.
            entry += s_never * _kernel_weighted_cov(
                never_post_change[col_j], never_post_change[col_k], W_never
            )

            # Term 3: pair j uses the target cohort itself as comparison.
            if gp_j == target_g:
                entry -= s_treated * _kernel_weighted_cov(
                    treated_long_change, _change_from_base(Y_treated, col_j), W_treated
                )

            # Term 4: same for pair k.
            if gp_k == target_g:
                entry -= s_treated * _kernel_weighted_cov(
                    treated_long_change, _change_from_base(Y_treated, col_k), W_treated
                )

            # Term 5: both pairs share the comparison cohort.
            if gp_j == gp_k:
                if np.isinf(gp_j):
                    entry += s_never * _kernel_weighted_cov(
                        _change_from_base(Y_never, col_j),
                        _change_from_base(Y_never, col_k),
                        W_never,
                    )
                else:
                    s_comp = s_hat_cache.get(
                        gp_j, _flat_inverse_propensity(cohort_fractions.get(gp_j, 1e-10))
                    )
                    if gp_j not in cohort_kernel:
                        comp_mask = cohort_masks[gp_j]
                        cohort_kernel[gp_j] = _kernel_weights_matrix(
                            covariate_matrix,
                            covariate_matrix[comp_mask],
                            bandwidth,
                            group_weights=_survey_weights(comp_mask),
                        )
                        cohort_panel[gp_j] = outcome_wide[comp_mask]
                    Y_comp = cohort_panel[gp_j]
                    entry += s_comp * _kernel_weighted_cov(
                        _change_from_base(Y_comp, col_j),
                        _change_from_base(Y_comp, col_k),
                        cohort_kernel[gp_j],
                    )

            omega[:, j, k] = entry
            if j != k:
                omega[:, k, j] = entry  # mirror into the lower triangle

    return omega
784
+
785
+
786
+ # ---------------------------------------------------------------------------
787
+ # Per-unit efficient weights from conditional Omega*
788
+ # ---------------------------------------------------------------------------
789
+
790
+
791
def compute_per_unit_weights(
    omega_conditional: np.ndarray,
    cond_threshold: float = 1e12,
) -> np.ndarray:
    """Turn per-unit conditional covariance matrices into combination weights.

    ``w(X_i) = 1' Omega*(X_i)^{-1} / (1' Omega*(X_i)^{-1} 1)``

    The pseudoinverse replaces the inverse when the condition number exceeds
    ``cond_threshold`` or when inversion raises ``LinAlgError``. Units whose
    matrix is numerically all zero (``np.allclose`` to 0), or whose
    normalising constant is below ``1e-15`` in absolute value, get equal
    weights ``1/H``.

    Parameters
    ----------
    omega_conditional : ndarray, shape (n_units, H, H)
        Per-unit conditional covariance matrices.
    cond_threshold : float
        Condition number threshold for pseudoinverse fallback.

    Returns
    -------
    weights : ndarray, shape (n_units, H)
        Per-unit efficient combination weights (each row sums to 1).
        ``(n_units, 0)`` for ``H == 0`` and all ones for ``H == 1``.
    """
    n_units, n_pairs, _ = omega_conditional.shape

    # Degenerate sizes need no linear algebra.
    if n_pairs == 0:
        return np.empty((n_units, 0))
    if n_pairs == 1:
        return np.ones((n_units, 1))

    unit_vec = np.ones(n_pairs)
    weights = np.zeros((n_units, n_pairs))

    for i in range(n_units):
        omega_i = omega_conditional[i]

        if np.allclose(omega_i, 0.0):
            weights[i] = unit_vec / n_pairs
            continue

        # Prefer the exact inverse; use the pseudoinverse when the matrix is
        # ill-conditioned or inversion fails outright.
        use_pinv = float(np.linalg.cond(omega_i)) > cond_threshold
        if not use_pinv:
            try:
                omega_inv = np.linalg.inv(omega_i)
            except np.linalg.LinAlgError:
                use_pinv = True
        if use_pinv:
            omega_inv = np.linalg.pinv(omega_i)

        unnormalised = unit_vec @ omega_inv
        total = unnormalised @ unit_vec

        if abs(total) < 1e-15:
            weights[i] = unit_vec / n_pairs
        else:
            weights[i] = unnormalised / total

    return weights
847
+
848
+
849
+ # ---------------------------------------------------------------------------
850
+ # EIF computation
851
+ # ---------------------------------------------------------------------------
852
+
853
+
854
def compute_eif_cov(
    weights: np.ndarray,
    generated_outcomes: np.ndarray,
    att_gt: float,
    n_units: int,
) -> np.ndarray:
    """Per-unit efficient influence function from DR generated outcomes.

    The generated outcomes of each unit are combined with the efficient
    weights and centred on the scalar ATT estimate:

    * global weights ``(H,)``: ``EIF_i = w @ gen_out_i - ATT``
    * per-unit weights ``(n_units, H)``: ``EIF_i = w(X_i) @ gen_out_i - ATT``

    Centring on the ATT is what makes ``mean(EIF)`` approximately zero. The
    plug-in EIF treats estimated per-unit weights as fixed, valid under
    Neyman orthogonality (Remark 4.2).

    Parameters
    ----------
    weights : ndarray, shape (H,) or (n_units, H)
        Efficient combination weights.
    generated_outcomes : ndarray, shape (n_units, H)
        Per-unit generated outcomes.
    att_gt : float
        Scalar ATT estimate for this (g, t) cell.
    n_units : int
        Total number of units; only used to size the all-zero result
        returned when ``weights`` is empty.

    Returns
    -------
    eif : ndarray, shape (n_units,)
        EIF value for every unit. Sample mean is approximately zero.
    """
    # No comparison pairs: the influence function is identically zero.
    if weights.size == 0:
        return np.zeros(n_units)

    if weights.ndim != 1:
        # Row-wise dot product of each unit's weights with its outcomes.
        combined = np.sum(weights * generated_outcomes, axis=1)
    else:
        # One weight vector shared by all units.
        combined = generated_outcomes @ weights

    return combined - att_gt