diff-diff 3.0.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. diff_diff/__init__.py +382 -0
  2. diff_diff/_backend.py +134 -0
  3. diff_diff/_rust_backend.cp314-win_amd64.pyd +0 -0
  4. diff_diff/bacon.py +1140 -0
  5. diff_diff/bootstrap_utils.py +730 -0
  6. diff_diff/continuous_did.py +1626 -0
  7. diff_diff/continuous_did_bspline.py +190 -0
  8. diff_diff/continuous_did_results.py +374 -0
  9. diff_diff/datasets.py +815 -0
  10. diff_diff/diagnostics.py +882 -0
  11. diff_diff/efficient_did.py +1770 -0
  12. diff_diff/efficient_did_bootstrap.py +359 -0
  13. diff_diff/efficient_did_covariates.py +899 -0
  14. diff_diff/efficient_did_results.py +368 -0
  15. diff_diff/efficient_did_weights.py +617 -0
  16. diff_diff/estimators.py +1501 -0
  17. diff_diff/honest_did.py +2585 -0
  18. diff_diff/imputation.py +2458 -0
  19. diff_diff/imputation_bootstrap.py +418 -0
  20. diff_diff/imputation_results.py +448 -0
  21. diff_diff/linalg.py +2538 -0
  22. diff_diff/power.py +2588 -0
  23. diff_diff/practitioner.py +869 -0
  24. diff_diff/prep.py +1738 -0
  25. diff_diff/prep_dgp.py +1718 -0
  26. diff_diff/pretrends.py +1105 -0
  27. diff_diff/results.py +918 -0
  28. diff_diff/stacked_did.py +1049 -0
  29. diff_diff/stacked_did_results.py +339 -0
  30. diff_diff/staggered.py +3895 -0
  31. diff_diff/staggered_aggregation.py +864 -0
  32. diff_diff/staggered_bootstrap.py +752 -0
  33. diff_diff/staggered_results.py +416 -0
  34. diff_diff/staggered_triple_diff.py +1545 -0
  35. diff_diff/staggered_triple_diff_results.py +416 -0
  36. diff_diff/sun_abraham.py +1685 -0
  37. diff_diff/survey.py +1981 -0
  38. diff_diff/synthetic_did.py +1136 -0
  39. diff_diff/triple_diff.py +2047 -0
  40. diff_diff/trop.py +952 -0
  41. diff_diff/trop_global.py +1270 -0
  42. diff_diff/trop_local.py +1307 -0
  43. diff_diff/trop_results.py +356 -0
  44. diff_diff/twfe.py +542 -0
  45. diff_diff/two_stage.py +1952 -0
  46. diff_diff/two_stage_bootstrap.py +520 -0
  47. diff_diff/two_stage_results.py +400 -0
  48. diff_diff/utils.py +1902 -0
  49. diff_diff/visualization/__init__.py +61 -0
  50. diff_diff/visualization/_common.py +328 -0
  51. diff_diff/visualization/_continuous.py +274 -0
  52. diff_diff/visualization/_diagnostic.py +817 -0
  53. diff_diff/visualization/_event_study.py +1086 -0
  54. diff_diff/visualization/_power.py +661 -0
  55. diff_diff/visualization/_staggered.py +833 -0
  56. diff_diff/visualization/_synthetic.py +197 -0
  57. diff_diff/wooldridge.py +1285 -0
  58. diff_diff/wooldridge_results.py +349 -0
  59. diff_diff-3.0.1.dist-info/METADATA +2997 -0
  60. diff_diff-3.0.1.dist-info/RECORD +62 -0
  61. diff_diff-3.0.1.dist-info/WHEEL +4 -0
  62. diff_diff-3.0.1.dist-info/sboms/diff_diff_rust.cyclonedx.json +5843 -0
@@ -0,0 +1,617 @@
1
+ """
2
+ Mathematical core for the Efficient DiD estimator.
3
+
4
+ Implements the no-covariates path from Chen, Sant'Anna & Xie (2025):
5
+ optimal weighting via the inverse of the conditional covariance matrix Omega*,
6
+ generated outcomes from within-group sample means, and the efficient
7
+ influence function for analytical standard errors.
8
+
9
+ All functions are pure (no state), operating on pre-pivoted numpy arrays.
10
+ """
11
+
12
+ import warnings
13
+ from typing import Dict, List, Optional, Tuple
14
+
15
+ import numpy as np
16
+
17
+
18
def enumerate_valid_triples(
    target_g: float,
    treatment_groups: List[float],
    time_periods: List[float],
    period_1: float,
    pt_assumption: str,
    anticipation: int = 0,
    never_treated_val: float = np.inf,
) -> List[Tuple[float, float]]:
    """Enumerate valid (g', t_pre) comparison pairs for target (g, t).

    PT-Post (``pt_assumption == "post"``) is just-identified: the only pair is
    the never-treated group with baseline ``g - 1 - anticipation``, and it is
    returned only if that baseline is not earlier than ``period_1``.

    PT-All (any other value) is overidentified: every candidate comparison
    group g' is considered. The candidates are the never-treated sentinel
    followed by every entry of ``treatment_groups``, so g' = g is included.
    Each candidate is paired with every ``t_pre`` in ``time_periods`` that
    satisfies both of the following:

    * ``t_pre`` is not ``period_1``. That period is the universal baseline
      Y_1, so using it would make the ``Y_{t_pre} - Y_1`` term identically 0.
    * ``t_pre`` is pre-treatment for the comparison group,
      i.e. ``t_pre < g' - anticipation``. This check is skipped for the
      never-treated candidate.

    No constraint ties ``t_pre`` to the target cohort g: g only enters the
    ``Y_t - Y_1`` term, which does not depend on ``t_pre``.

    Parameters
    ----------
    target_g : float
        Treatment cohort of the target group.
    treatment_groups : list of float
        All treatment cohort identifiers (finite values only).
    time_periods : list of float
        All observed time periods, sorted.
    period_1 : float
        Earliest observed period (universal baseline).
    pt_assumption : str
        ``"all"`` or ``"post"``.
    anticipation : int
        Number of anticipation periods.
    never_treated_val : float
        Sentinel for the never-treated group (default ``np.inf``).

    Returns
    -------
    list of (g', t_pre) tuples
        Valid comparison pairs, grouped by candidate in the order above and
        by ``time_periods`` order within each candidate. Empty if none exist.
    """
    if pt_assumption == "post":
        baseline = target_g - 1 - anticipation
        return [(never_treated_val, baseline)] if baseline >= period_1 else []

    # Never-treated pairs telescope to mean_g(Y_t - Y_1) - mean_inf(Y_t - Y_1)
    # whatever t_pre is. They carry no information beyond the basic 2x2 DiD
    # but are kept for implementation simplicity.
    pairs: List[Tuple[float, float]] = []
    for comparison in [never_treated_val, *treatment_groups]:
        start = np.inf if np.isinf(comparison) else comparison - anticipation
        unrestricted = np.isinf(start)
        pairs.extend(
            (comparison, t_pre)
            for t_pre in time_periods
            if t_pre != period_1 and (unrestricted or not t_pre >= start)
        )
    return pairs
102
+
103
+
104
+ def _sample_cov(
105
+ a: np.ndarray,
106
+ b: np.ndarray,
107
+ w: Optional[np.ndarray] = None,
108
+ ) -> float:
109
+ """Sample covariance between two 1-D arrays (ddof=1).
110
+
111
+ Returns 0.0 if fewer than 2 observations.
112
+
113
+ Parameters
114
+ ----------
115
+ a, b : ndarray, shape (n,)
116
+ Data arrays.
117
+ w : ndarray, shape (n,), optional
118
+ Survey weights. When provided, computes the reliability-weighted
119
+ covariance: ``sum(w*(a-a_bar)*(b-b_bar)) / (sum(w) - 1)`` where
120
+ ``a_bar = average(a, weights=w)``.
121
+ """
122
+ n = len(a)
123
+ if n < 2:
124
+ return 0.0
125
+ if w is None:
126
+ return float(((a - a.mean()) * (b - b.mean())).sum() / (n - 1))
127
+ # Weighted covariance with reliability weights (Bessel-style correction)
128
+ a_bar = float(np.average(a, weights=w))
129
+ b_bar = float(np.average(b, weights=w))
130
+ sum_w = float(np.sum(w))
131
+ if sum_w <= 1.0:
132
+ return 0.0
133
+ return float(np.sum(w * (a - a_bar) * (b - b_bar)) / (sum_w - 1.0))
134
+
135
+
136
def compute_omega_star_nocov(
    target_g: float,
    target_t: float,
    valid_pairs: List[Tuple[float, float]],
    outcome_wide: np.ndarray,
    cohort_masks: Dict[float, np.ndarray],
    never_treated_mask: np.ndarray,
    period_to_col: Dict[float, int],
    period_1_col: int,
    cohort_fractions: Dict[float, float],
    never_treated_val: float = np.inf,
    unit_weights: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Build the |H| x |H| covariance matrix Omega* (Eq 3.12, unconditional).

    Omega*[j, k] is the covariance between the influence functions of the
    generated outcomes for pairs ``j = (g'_j, tpre_j)`` and
    ``k = (g'_k, tpre_k)``. Each generated outcome is::

        mean(Y_t - Y_1 | G=g) - mean(Y_t - Y_tpre | G=inf) - mean(Y_tpre - Y_1 | G=g')

    Write ``A_j = Y_t - Y_{tpre_j}`` and ``B_j = Y_{tpre_j} - Y_1``. Each entry
    is then a sum of within-group sample covariances, each scaled by the
    inverse cohort fraction of the group it is taken over:

    1. ``(1/pi_g) Var(Y_t - Y_1 | G=g)``
    2. ``(1/pi_inf) Cov(A_j, A_k | G=inf)``
    3. ``-1{g'_j = g} (1/pi_g) Cov(Y_t - Y_1, B_j | G=g)``
    4. ``-1{g'_k = g} (1/pi_g) Cov(Y_t - Y_1, B_k | G=g)``
    5. ``1{g'_j = g'_k} (1/pi_{g'_j}) Cov(B_j, B_k | G=g'_j)``
    6. ``1{g'_k = inf} (1/pi_inf) Cov(A_j, B_k | G=inf)``
    7. ``1{g'_j = inf} (1/pi_inf) Cov(B_j, A_k | G=inf)``

    Terms 6 and 7 arise because a never-treated comparison (g' = inf) uses the
    never-treated group in both the second and the third mean. The per-unit
    influence function built by :func:`compute_eif_nocov` gives never-treated
    units both pieces, so their cross-covariance must appear in Omega*.

    With these terms, every g' = inf pair has the same influence function,
    since its generated outcome telescopes to a 2x2 DiD. Their rows of
    Omega* therefore coincide mathematically, and Omega* is singular when
    there is more than one such pair. :func:`compute_efficient_weights` uses
    a pseudoinverse when the condition number exceeds its threshold.

    A never-treated term is skipped when ``pi_inf <= 0`` or when there are
    fewer than two never-treated units. Terms 1, 3 and 4 are skipped when
    ``pi_g <= 0``. Term 5 for a finite g' is skipped when that cohort has a
    non-positive fraction, no mask, or fewer than two units.

    Parameters
    ----------
    target_g : float
        Target treatment cohort.
    target_t : float
        Target time period.
    valid_pairs : list of (g', t_pre) tuples
        Valid comparison pairs from :func:`enumerate_valid_triples`.
    outcome_wide : ndarray, shape (n_units, n_periods)
        Pivoted outcome matrix.
    cohort_masks : dict
        ``{cohort: bool_mask}`` over the unit dimension.
    never_treated_mask : ndarray of bool
        Mask for never-treated units.
    period_to_col : dict
        ``{period: column_index}`` in ``outcome_wide``.
    period_1_col : int
        Column index of the earliest period (universal baseline Y_1).
    cohort_fractions : dict
        ``{cohort: n_cohort / n}`` for each cohort.
    never_treated_val : float
        Sentinel for the never-treated group. It is used to look up
        ``pi_inf``; never-treated pairs themselves are recognized with
        ``np.isinf``.
    unit_weights : ndarray, shape (n_units,), optional
        Survey weights at the unit level. When provided, all sample means
        and covariances are weighted.

    Returns
    -------
    ndarray, shape (|H|, |H|)
        Symmetric covariance matrix. Empty (0, 0) array if ``valid_pairs``
        is empty.
    """
    H = len(valid_pairs)
    if H == 0:
        return np.empty((0, 0))

    t_col = period_to_col[target_t]
    y1_col = period_1_col

    # Target group g: Y_t - Y_1 and its weights
    g_mask = cohort_masks[target_g]
    Y_g = outcome_wide[g_mask]
    pi_g = cohort_fractions[target_g]
    w_g = unit_weights[g_mask] if unit_weights is not None else None
    Yg_t_minus_1 = Y_g[:, t_col] - Y_g[:, y1_col]

    # Never-treated group
    Y_inf = outcome_wide[never_treated_mask]
    pi_inf = cohort_fractions.get(never_treated_val, 0.0)
    w_inf = unit_weights[never_treated_mask] if unit_weights is not None else None

    # Term 1 is the same for every (j, k)
    term1 = 0.0
    if pi_g > 0:
        term1 = (1.0 / pi_g) * _sample_cov(Yg_t_minus_1, Yg_t_minus_1, w=w_g)

    # Never-treated A = Y_t - Y_tpre and B = Y_tpre - Y_1, keyed by tpre column.
    # Both dicts stay empty with fewer than 2 never-treated units, and every
    # never-treated term is then skipped.
    inf_t_minus_tpre: Dict[int, np.ndarray] = {}
    inf_tpre_minus_1: Dict[int, np.ndarray] = {}
    if len(Y_inf) >= 2:
        for _, tpre in valid_pairs:
            tpre_col = period_to_col[tpre]
            if tpre_col not in inf_t_minus_tpre:
                inf_t_minus_tpre[tpre_col] = Y_inf[:, t_col] - Y_inf[:, tpre_col]
                inf_tpre_minus_1[tpre_col] = Y_inf[:, tpre_col] - Y_inf[:, y1_col]
    use_inf = pi_inf > 0 and bool(inf_t_minus_tpre)

    # Target group: Y_tpre - Y_1 for each tpre whose comparison group is g itself
    g_tpre_minus_1: Dict[int, np.ndarray] = {}
    if pi_g > 0:
        for gp, tpre in valid_pairs:
            if gp == target_g:
                tpre_col = period_to_col[tpre]
                if tpre_col not in g_tpre_minus_1:
                    g_tpre_minus_1[tpre_col] = Y_g[:, tpre_col] - Y_g[:, y1_col]

    # Finite comparison cohorts: cached outcome submatrices and weights
    gp_outcomes: Dict[float, np.ndarray] = {}
    gp_weights: Dict[float, Optional[np.ndarray]] = {}
    for gp, _ in valid_pairs:
        if not np.isinf(gp) and gp not in gp_outcomes and gp in cohort_masks:
            gp_mask = cohort_masks[gp]
            gp_outcomes[gp] = outcome_wide[gp_mask]
            gp_weights[gp] = unit_weights[gp_mask] if unit_weights is not None else None

    # Lazily filled cache of comparison-cohort Y_tpre - Y_1, keyed by (g', tpre_col)
    gp_tpre_minus_1: Dict[Tuple[float, int], np.ndarray] = {}

    omega = np.zeros((H, H))

    for j, (gp_j, tpre_j) in enumerate(valid_pairs):
        tpre_j_col = period_to_col[tpre_j]
        j_is_inf = bool(np.isinf(gp_j))

        for k in range(j, H):
            gp_k, tpre_k = valid_pairs[k]
            tpre_k_col = period_to_col[tpre_k]
            k_is_inf = bool(np.isinf(gp_k))

            val = term1

            # Term 2: (1/pi_inf) * Cov(A_j, A_k | G=inf)
            if use_inf:
                val += (1.0 / pi_inf) * _sample_cov(
                    inf_t_minus_tpre[tpre_j_col],
                    inf_t_minus_tpre[tpre_k_col],
                    w=w_inf,
                )

            # Term 3: -1{g'_j == g} / pi_g * Cov(Y_t - Y_1, B_j | G=g)
            if gp_j == target_g and tpre_j_col in g_tpre_minus_1:
                val -= (1.0 / pi_g) * _sample_cov(
                    Yg_t_minus_1, g_tpre_minus_1[tpre_j_col], w=w_g
                )

            # Term 4: -1{g'_k == g} / pi_g * Cov(Y_t - Y_1, B_k | G=g)
            if gp_k == target_g and tpre_k_col in g_tpre_minus_1:
                val -= (1.0 / pi_g) * _sample_cov(
                    Yg_t_minus_1, g_tpre_minus_1[tpre_k_col], w=w_g
                )

            # Term 5: 1{g'_j == g'_k} / pi_{g'_j} * Cov(B_j, B_k | G=g'_j)
            if gp_j == gp_k:
                if j_is_inf:
                    if use_inf:
                        val += (1.0 / pi_inf) * _sample_cov(
                            inf_tpre_minus_1[tpre_j_col],
                            inf_tpre_minus_1[tpre_k_col],
                            w=w_inf,
                        )
                else:
                    pi_gp = cohort_fractions.get(gp_j, 0.0)
                    Y_gp = gp_outcomes.get(gp_j)
                    if pi_gp > 0 and Y_gp is not None and len(Y_gp) >= 2:
                        key_j = (gp_j, tpre_j_col)
                        if key_j not in gp_tpre_minus_1:
                            gp_tpre_minus_1[key_j] = Y_gp[:, tpre_j_col] - Y_gp[:, y1_col]
                        key_k = (gp_j, tpre_k_col)
                        if key_k not in gp_tpre_minus_1:
                            gp_tpre_minus_1[key_k] = Y_gp[:, tpre_k_col] - Y_gp[:, y1_col]
                        val += (1.0 / pi_gp) * _sample_cov(
                            gp_tpre_minus_1[key_j],
                            gp_tpre_minus_1[key_k],
                            w=gp_weights[gp_j],
                        )

            # Terms 6/7: never-treated cross terms. When g' = inf, the
            # never-treated influence function carries both A and B pieces.
            if use_inf:
                if k_is_inf:
                    val += (1.0 / pi_inf) * _sample_cov(
                        inf_t_minus_tpre[tpre_j_col],
                        inf_tpre_minus_1[tpre_k_col],
                        w=w_inf,
                    )
                if j_is_inf:
                    val += (1.0 / pi_inf) * _sample_cov(
                        inf_tpre_minus_1[tpre_j_col],
                        inf_t_minus_tpre[tpre_k_col],
                        w=w_inf,
                    )

            omega[j, k] = val
            if j != k:
                omega[k, j] = val

    return omega
319
+
320
+
321
def compute_efficient_weights(
    omega_star: np.ndarray,
    cond_threshold: float = 1e12,
) -> Tuple[np.ndarray, bool, float]:
    """Compute efficient weights from Omega* inverse (Eq 3.13 / 4.3).

    ``w = ones @ inv(Omega*) / (ones @ inv(Omega*) @ ones)``

    Fallbacks:

    * H == 0 returns an empty array; H == 1 returns ``[1.0]``.
    * If every entry is within ``np.allclose``'s default absolute tolerance
      of 0, a UserWarning is emitted and uniform weights are returned.
    * If the condition number exceeds ``cond_threshold``, a UserWarning is
      emitted and the pseudoinverse is used.
    * If ``np.linalg.inv`` raises ``LinAlgError``, the pseudoinverse is used
      without a warning.
    * If the normalizer ``ones @ inv(Omega*) @ ones`` is numerically zero
      relative to the scale of Omega*, a UserWarning is emitted and uniform
      weights are returned.

    Parameters
    ----------
    omega_star : ndarray, shape (H, H)
        Covariance matrix from :func:`compute_omega_star_nocov`.
    cond_threshold : float
        If condition number exceeds this, use pseudoinverse + warning.

    Returns
    -------
    weights : ndarray, shape (H,)
        Efficient combination weights (sum to 1).
    used_pinv : bool
        True if pseudoinverse was used.
    cond_number : float
        Condition number of Omega* (avoids recomputation by caller).
    """
    H = omega_star.shape[0]
    if H == 0:
        return np.array([]), False, 0.0
    if H == 1:
        return np.array([1.0]), False, 1.0

    ones = np.ones(H)
    used_pinv = False

    # NOTE: np.allclose uses an absolute tolerance (atol=1e-8), so matrices
    # whose entries are all tiny in absolute terms are also treated as zero.
    if np.allclose(omega_star, 0.0):
        warnings.warn(
            "Omega* matrix is all zeros; using uniform weights.",
            UserWarning,
            stacklevel=2,
        )
        return ones / H, False, np.inf

    cond = float(np.linalg.cond(omega_star))
    if cond > cond_threshold:
        warnings.warn(
            f"Omega* condition number ({cond:.2e}) exceeds threshold "
            f"({cond_threshold:.2e}); using pseudoinverse for weights.",
            UserWarning,
            stacklevel=2,
        )
        omega_inv = np.linalg.pinv(omega_star)
        used_pinv = True
    else:
        try:
            omega_inv = np.linalg.inv(omega_star)
        except np.linalg.LinAlgError:
            omega_inv = np.linalg.pinv(omega_star)
            used_pinv = True

    numerator = ones @ omega_inv  # shape (H,)
    denominator = numerator @ ones  # scalar

    # 1' Omega^-1 1 scales like 1 / |Omega|. An absolute cutoff therefore
    # misfires for outcomes in large units (e.g. currency), where a perfectly
    # well-conditioned Omega has a tiny normalizer. Compare the scale-free
    # product instead: for a positive-definite Omega, Gershgorin gives
    # lambda_max <= H * max|Omega|, so |denominator| * max|Omega| >= 1.
    scale = float(np.max(np.abs(omega_star)))
    if abs(denominator) * scale < 1e-15:
        warnings.warn(
            "Denominator of efficient weights is near zero; using uniform weights.",
            UserWarning,
            stacklevel=2,
        )
        return ones / H, used_pinv, cond

    weights = numerator / denominator
    return weights, used_pinv, cond
393
+
394
+
395
def compute_generated_outcomes_nocov(
    target_g: float,
    target_t: float,
    valid_pairs: List[Tuple[float, float]],
    outcome_wide: np.ndarray,
    cohort_masks: Dict[float, np.ndarray],
    never_treated_mask: np.ndarray,
    period_to_col: Dict[float, int],
    period_1_col: int,
    never_treated_val: float = np.inf,
    unit_weights: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Return one generated outcome per valid comparison pair.

    Without covariates, each generated outcome is a difference of three
    within-group means (Eq 3.9 / 4.4 simplified)::

        Y_hat_j = mean(Y_t - Y_1 | G=g)
                - mean(Y_t - Y_{t_pre} | G=inf)
                - mean(Y_{t_pre} - Y_1 | G=g')

    Here ``inf`` is the never-treated group and ``g'`` is pair j's comparison
    cohort. A comparison value for which ``np.isinf`` is true is read from
    the never-treated rows; any other value is looked up in ``cohort_masks``.

    Parameters
    ----------
    target_g, target_t : float
        Target group-time.
    valid_pairs : list of (g', t_pre)
        Valid comparison pairs.
    outcome_wide : ndarray, shape (n_units, n_periods)
    cohort_masks, never_treated_mask, period_to_col, period_1_col :
        Pre-computed data structures.
    never_treated_val : float
        Sentinel for never-treated (not consulted by this function).
    unit_weights : ndarray, shape (n_units,), optional
        Survey weights at the unit level. When provided, every mean becomes
        a weighted mean via ``np.average``.

    Returns
    -------
    ndarray, shape (|H|,)
        Scalar generated outcome for each pair (empty if there are no pairs).
    """
    n_pairs = len(valid_pairs)
    if n_pairs == 0:
        return np.array([])

    col_t = period_to_col[target_t]
    col_1 = period_1_col

    def _avg(values: np.ndarray, wts: Optional[np.ndarray]) -> float:
        # Plain mean when unweighted, otherwise a survey-weighted mean.
        if wts is None:
            return float(np.mean(values))
        return float(np.average(values, weights=wts))

    def _slice_weights(mask: np.ndarray) -> Optional[np.ndarray]:
        return None if unit_weights is None else unit_weights[mask]

    target_mask = cohort_masks[target_g]
    Y_target = outcome_wide[target_mask]
    target_term = _avg(Y_target[:, col_t] - Y_target[:, col_1], _slice_weights(target_mask))

    Y_never = outcome_wide[never_treated_mask]
    wts_never = _slice_weights(never_treated_mask)

    generated = np.empty(n_pairs)
    for idx, (comparison, t_pre) in enumerate(valid_pairs):
        col_pre = period_to_col[t_pre]

        # Never-treated evolution from t_pre to t
        evolution = _avg(Y_never[:, col_t] - Y_never[:, col_pre], wts_never)

        # Comparison cohort's shift from Y_1 to t_pre
        if np.isinf(comparison):
            Y_cmp, wts_cmp = Y_never, wts_never
        else:
            cmp_mask = cohort_masks[comparison]
            Y_cmp, wts_cmp = outcome_wide[cmp_mask], _slice_weights(cmp_mask)
        baseline_shift = _avg(Y_cmp[:, col_pre] - Y_cmp[:, col_1], wts_cmp)

        generated[idx] = target_term - evolution - baseline_shift

    return generated
484
+
485
+
486
def compute_eif_nocov(
    target_g: float,
    target_t: float,
    weights: np.ndarray,
    valid_pairs: List[Tuple[float, float]],
    outcome_wide: np.ndarray,
    cohort_masks: Dict[float, np.ndarray],
    never_treated_mask: np.ndarray,
    period_to_col: Dict[float, int],
    period_1_col: int,
    cohort_fractions: Dict[float, float],
    n_units: int,
    never_treated_val: float = np.inf,
    unit_weights: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Compute per-unit efficient influence function values.

    The efficient estimate is ``sum_j w_j * Y_hat_j``, where each generated
    outcome (see :func:`compute_generated_outcomes_nocov`) is::

        Y_hat_j = mean(Y_t - Y_1 | G=g)
                - mean(Y_t - Y_{tpre_j} | G=inf)
                - mean(Y_{tpre_j} - Y_1 | G=g'_j)

    A within-group mean ``mean(X | G=c)`` has influence function
    ``1{G_i = c} / pi_c * (X_i - mean(X | G=c))``. For each pair j, unit i
    therefore receives the following, scaled by ``w_j``:

    * **Treated term** (i in cohort g):
      ``(1/pi_g) * (Y_{i,t} - Y_{i,1} - mean(Y_t - Y_1 | G=g))``
    * **Never-treated term** (i never treated):
      ``-(1/pi_inf) * (Y_{i,t} - Y_{i,tpre_j} - mean(Y_t - Y_{tpre_j} | G=inf))``
    * **Comparison cohort term** (i in cohort g'_j, which may itself be the
      never-treated group):
      ``-(1/pi_{g'_j}) * (Y_{i,tpre_j} - Y_{i,1} - mean(Y_{tpre_j} - Y_1 | G=g'_j))``

    A term is skipped entirely when its cohort fraction is not positive.
    In that case its group mean is not computed, so an empty group (e.g. no
    never-treated units) does not trigger an empty-mean warning or, with
    survey weights, a ``ZeroDivisionError`` from ``np.average``.

    The derivation follows Theorem 3.2 and Eq 3.9-3.10, simplified for the
    no-covariates case where propensity score ratios equal cohort fraction
    ratios.

    Parameters
    ----------
    target_g, target_t : float
        Target group-time.
    weights : ndarray, shape (H,)
        Efficient weights.
    valid_pairs : list of (g', t_pre)
    outcome_wide, cohort_masks, never_treated_mask, period_to_col,
    period_1_col, cohort_fractions, n_units, never_treated_val :
        Pre-computed data structures. ``never_treated_val`` is used to look
        up ``pi_inf``; never-treated pairs are recognized with ``np.isinf``.
    unit_weights : ndarray, shape (n_units,), optional
        Survey weights at the unit level. When provided, within-group means
        are weighted means.

    Returns
    -------
    ndarray, shape (n_units,)
        EIF value for every unit (zeros if there are no valid pairs).
    """
    H = len(valid_pairs)
    if H == 0:
        return np.zeros(n_units)

    t_col = period_to_col[target_t]
    y1_col = period_1_col

    g_mask = cohort_masks[target_g]
    Y_g = outcome_wide[g_mask]
    pi_g = cohort_fractions[target_g]

    Y_inf = outcome_wide[never_treated_mask]
    pi_inf = cohort_fractions.get(never_treated_val, 0.0)

    # Per-cohort weights
    w_g = unit_weights[g_mask] if unit_weights is not None else None
    w_inf = unit_weights[never_treated_mask] if unit_weights is not None else None

    def _wmean(vals: np.ndarray, w: Optional[np.ndarray]) -> float:
        # Plain mean when unweighted, otherwise a survey-weighted mean.
        if w is not None:
            return float(np.average(vals, weights=w))
        return float(np.mean(vals))

    eif = np.zeros(n_units)

    # The treated-group piece does not depend on j; compute it once.
    treated_demeaned = None
    if pi_g > 0:
        Yg_t_minus_1 = Y_g[:, t_col] - Y_g[:, y1_col]
        mean_g_t_1 = _wmean(Yg_t_minus_1, w_g)
        treated_demeaned = (1.0 / pi_g) * (Yg_t_minus_1 - mean_g_t_1)

    # Never-treated Y_t - Y_tpre and its mean, cached per tpre column
    inf_diffs: Dict[int, np.ndarray] = {}
    inf_means: Dict[int, float] = {}

    for j, (gp, tpre) in enumerate(valid_pairs):
        w_j = weights[j]
        tpre_col = period_to_col[tpre]

        # --- Treated term (units in cohort g) ---
        if treated_demeaned is not None:
            eif[g_mask] += w_j * treated_demeaned

        # --- Never-treated term: evolution Y_t - Y_tpre ---
        if pi_inf > 0:
            if tpre_col not in inf_diffs:
                inf_diffs[tpre_col] = Y_inf[:, t_col] - Y_inf[:, tpre_col]
                inf_means[tpre_col] = _wmean(inf_diffs[tpre_col], w_inf)
            inf_contrib = -(1.0 / pi_inf) * (inf_diffs[tpre_col] - inf_means[tpre_col])
            eif[never_treated_mask] += w_j * inf_contrib

        # --- Comparison cohort term: baseline shift Y_tpre - Y_1 ---
        if np.isinf(gp):
            # The comparison group is the never-treated group itself, so its
            # units receive this term in addition to the evolution term.
            if pi_inf > 0:
                inf_base = Y_inf[:, tpre_col] - Y_inf[:, y1_col]
                gp_contrib = -(1.0 / pi_inf) * (inf_base - _wmean(inf_base, w_inf))
                eif[never_treated_mask] += w_j * gp_contrib
        else:
            pi_gp = cohort_fractions.get(gp, 0.0)
            if pi_gp > 0:
                gp_mask = cohort_masks[gp]
                Y_gp = outcome_wide[gp_mask]
                w_gp = unit_weights[gp_mask] if unit_weights is not None else None
                gp_base = Y_gp[:, tpre_col] - Y_gp[:, y1_col]
                gp_contrib = -(1.0 / pi_gp) * (gp_base - _wmean(gp_base, w_gp))
                eif[gp_mask] += w_j * gp_contrib

    return eif