diff-diff 3.0.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. diff_diff/__init__.py +382 -0
  2. diff_diff/_backend.py +134 -0
  3. diff_diff/_rust_backend.cp314-win_amd64.pyd +0 -0
  4. diff_diff/bacon.py +1140 -0
  5. diff_diff/bootstrap_utils.py +730 -0
  6. diff_diff/continuous_did.py +1626 -0
  7. diff_diff/continuous_did_bspline.py +190 -0
  8. diff_diff/continuous_did_results.py +374 -0
  9. diff_diff/datasets.py +815 -0
  10. diff_diff/diagnostics.py +882 -0
  11. diff_diff/efficient_did.py +1770 -0
  12. diff_diff/efficient_did_bootstrap.py +359 -0
  13. diff_diff/efficient_did_covariates.py +899 -0
  14. diff_diff/efficient_did_results.py +368 -0
  15. diff_diff/efficient_did_weights.py +617 -0
  16. diff_diff/estimators.py +1501 -0
  17. diff_diff/honest_did.py +2585 -0
  18. diff_diff/imputation.py +2458 -0
  19. diff_diff/imputation_bootstrap.py +418 -0
  20. diff_diff/imputation_results.py +448 -0
  21. diff_diff/linalg.py +2538 -0
  22. diff_diff/power.py +2588 -0
  23. diff_diff/practitioner.py +869 -0
  24. diff_diff/prep.py +1738 -0
  25. diff_diff/prep_dgp.py +1718 -0
  26. diff_diff/pretrends.py +1105 -0
  27. diff_diff/results.py +918 -0
  28. diff_diff/stacked_did.py +1049 -0
  29. diff_diff/stacked_did_results.py +339 -0
  30. diff_diff/staggered.py +3895 -0
  31. diff_diff/staggered_aggregation.py +864 -0
  32. diff_diff/staggered_bootstrap.py +752 -0
  33. diff_diff/staggered_results.py +416 -0
  34. diff_diff/staggered_triple_diff.py +1545 -0
  35. diff_diff/staggered_triple_diff_results.py +416 -0
  36. diff_diff/sun_abraham.py +1685 -0
  37. diff_diff/survey.py +1981 -0
  38. diff_diff/synthetic_did.py +1136 -0
  39. diff_diff/triple_diff.py +2047 -0
  40. diff_diff/trop.py +952 -0
  41. diff_diff/trop_global.py +1270 -0
  42. diff_diff/trop_local.py +1307 -0
  43. diff_diff/trop_results.py +356 -0
  44. diff_diff/twfe.py +542 -0
  45. diff_diff/two_stage.py +1952 -0
  46. diff_diff/two_stage_bootstrap.py +520 -0
  47. diff_diff/two_stage_results.py +400 -0
  48. diff_diff/utils.py +1902 -0
  49. diff_diff/visualization/__init__.py +61 -0
  50. diff_diff/visualization/_common.py +328 -0
  51. diff_diff/visualization/_continuous.py +274 -0
  52. diff_diff/visualization/_diagnostic.py +817 -0
  53. diff_diff/visualization/_event_study.py +1086 -0
  54. diff_diff/visualization/_power.py +661 -0
  55. diff_diff/visualization/_staggered.py +833 -0
  56. diff_diff/visualization/_synthetic.py +197 -0
  57. diff_diff/wooldridge.py +1285 -0
  58. diff_diff/wooldridge_results.py +349 -0
  59. diff_diff-3.0.1.dist-info/METADATA +2997 -0
  60. diff_diff-3.0.1.dist-info/RECORD +62 -0
  61. diff_diff-3.0.1.dist-info/WHEEL +4 -0
  62. diff_diff-3.0.1.dist-info/sboms/diff_diff_rust.cyclonedx.json +5843 -0
@@ -0,0 +1,864 @@
1
+ """
2
+ Aggregation methods mixin for Callaway-Sant'Anna estimator.
3
+
4
+ This module provides the mixin class containing methods for aggregating
5
+ group-time average treatment effects into summary measures.
6
+ """
7
+
8
+ from typing import Any, Dict, List, Optional, Set, Tuple, Union
9
+
10
+ import numpy as np
11
+ import pandas as pd
12
+
13
+ from diff_diff.utils import safe_inference_batch
14
+
15
+ # Type alias for pre-computed structures (defined at module scope for runtime access)
16
+ PrecomputedData = Dict[str, Any]
17
+
18
+
19
+ class CallawaySantAnnaAggregationMixin:
20
+ """
21
+ Mixin class providing aggregation methods for CallawaySantAnna estimator.
22
+
23
+ This class is not intended to be used standalone. It provides methods
24
+ that are used by the main CallawaySantAnna class to aggregate group-time
25
+ effects into summary measures.
26
+ """
27
+
28
+ # Type hints for attributes accessed from the main class
29
+ alpha: float
30
+
31
+ # Type hint for anticipation attribute accessed from main class
32
+ anticipation: int
33
+
34
+ # Type hint for base_period attribute accessed from main class
35
+ base_period: str
36
+
37
+ def _aggregate_simple(
38
+ self,
39
+ group_time_effects: Dict,
40
+ influence_func_info: Dict,
41
+ df: pd.DataFrame,
42
+ unit: str,
43
+ precomputed: Optional["PrecomputedData"] = None,
44
+ ) -> Tuple[float, float]:
45
+ """
46
+ Compute simple weighted average of ATT(g,t).
47
+
48
+ Weights by group size (number of treated units).
49
+
50
+ Standard errors are computed using influence function aggregation,
51
+ which properly accounts for covariances across (g,t) pairs due to
52
+ shared control units. This includes the wif (weight influence function)
53
+ adjustment from R's `did` package that accounts for uncertainty in
54
+ estimating the group-size weights.
55
+
56
+ Note: Only post-treatment effects (t >= g - anticipation) are included
57
+ in the overall ATT. Pre-treatment effects are computed for parallel
58
+ trends assessment but are not aggregated into the overall ATT.
59
+ """
60
+ effects = []
61
+ weights_list = []
62
+ gt_pairs = []
63
+ groups_for_gt = []
64
+
65
+ # For survey: compute fixed per-cohort weight sums from the full
66
+ # unit-level sample (matching R's did::aggte pg = n_g / N).
67
+ survey_cohort_weights = None
68
+ if precomputed is not None and precomputed.get("survey_weights") is not None:
69
+ sw = precomputed["survey_weights"]
70
+ unit_cohorts = precomputed["unit_cohorts"]
71
+ survey_cohort_weights = {}
72
+ for g in np.unique(unit_cohorts):
73
+ if g > 0: # exclude never-treated (0)
74
+ survey_cohort_weights[g] = float(np.sum(sw[unit_cohorts == g]))
75
+
76
+ for (g, t), data in group_time_effects.items():
77
+ # Only include post-treatment effects (t >= g - anticipation)
78
+ # Pre-treatment effects are for parallel trends, not overall ATT
79
+ if t < g - self.anticipation:
80
+ continue
81
+ effects.append(data["effect"])
82
+ # Use fixed cohort-level survey weight sum for aggregation.
83
+ # For RCS, data["agg_weight"] holds the fixed cohort mass;
84
+ # for panel, fallback to data["n_treated"].
85
+ if survey_cohort_weights is not None and g in survey_cohort_weights:
86
+ weights_list.append(survey_cohort_weights[g])
87
+ else:
88
+ weights_list.append(data.get("agg_weight", data["n_treated"]))
89
+ gt_pairs.append((g, t))
90
+ groups_for_gt.append(g)
91
+
92
+ # Guard against empty post-treatment set
93
+ if len(effects) == 0:
94
+ import warnings
95
+
96
+ warnings.warn(
97
+ "No post-treatment effects available for overall ATT aggregation. "
98
+ "This can occur when cohorts lack post-treatment periods in the data.",
99
+ UserWarning,
100
+ stacklevel=2,
101
+ )
102
+ return np.nan, np.nan, None
103
+
104
+ effects = np.array(effects)
105
+ weights = np.array(weights_list, dtype=float)
106
+ groups_for_gt = np.array(groups_for_gt)
107
+
108
+ # Exclude NaN effects from aggregation (R's aggte() convention).
109
+ # No warning here — fit() emits a consolidated skip warning covering
110
+ # all estimation paths (vectorized, covariate, general, RC).
111
+ finite_mask = np.isfinite(effects)
112
+ if not np.all(finite_mask):
113
+ effects = effects[finite_mask]
114
+ weights = weights[finite_mask]
115
+ gt_pairs = [gt for gt, m in zip(gt_pairs, finite_mask) if m]
116
+ groups_for_gt = groups_for_gt[finite_mask]
117
+
118
+ if len(effects) == 0:
119
+ import warnings
120
+
121
+ warnings.warn(
122
+ "All post-treatment effects are NaN. Cannot compute overall ATT.",
123
+ UserWarning,
124
+ stacklevel=2,
125
+ )
126
+ return np.nan, np.nan, None
127
+
128
+ # Normalize weights
129
+ total_weight = np.sum(weights)
130
+ weights_norm = weights / total_weight
131
+
132
+ # Weighted average
133
+ overall_att = np.sum(weights_norm * effects)
134
+
135
+ # Compute SE using influence function aggregation with wif adjustment
136
+ overall_se, effective_df = self._compute_aggregated_se_with_wif(
137
+ gt_pairs,
138
+ weights_norm,
139
+ effects,
140
+ groups_for_gt,
141
+ influence_func_info,
142
+ df,
143
+ unit,
144
+ precomputed,
145
+ )
146
+
147
+ return overall_att, overall_se, effective_df
148
+
149
+ def _compute_aggregated_se(
150
+ self,
151
+ gt_pairs: List[Tuple[Any, Any]],
152
+ weights: np.ndarray,
153
+ influence_func_info: Dict,
154
+ n_units: Optional[int] = None,
155
+ ) -> float:
156
+ """
157
+ Compute standard error using influence function aggregation.
158
+
159
+ This properly accounts for covariances across (g,t) pairs by
160
+ aggregating unit-level influence functions:
161
+
162
+ ψ_i(overall) = Σ_{(g,t)} w_(g,t) × ψ_i(g,t)
163
+ Var(overall) = (1/n) Σ_i [ψ_i]²
164
+
165
+ This matches R's `did` package analytical SE formula.
166
+
167
+ Parameters
168
+ ----------
169
+ n_units : int, optional
170
+ Size of the canonical index space (len(precomputed['all_units'])).
171
+ When provided, influence function indices (treated_idx, control_idx)
172
+ index directly into this space, eliminating dict lookups.
173
+ """
174
+ if not influence_func_info:
175
+ return 0.0
176
+
177
+ if n_units is None:
178
+ # Fallback: infer size from influence function info
179
+ max_idx = 0
180
+ for g, t in gt_pairs:
181
+ if (g, t) in influence_func_info:
182
+ info = influence_func_info[(g, t)]
183
+ if len(info["treated_idx"]) > 0:
184
+ max_idx = max(max_idx, info["treated_idx"].max())
185
+ if len(info["control_idx"]) > 0:
186
+ max_idx = max(max_idx, info["control_idx"].max())
187
+ n_units = max_idx + 1
188
+
189
+ if n_units == 0:
190
+ return 0.0
191
+
192
+ # Aggregate influence functions across (g,t) pairs
193
+ psi_overall = np.zeros(n_units)
194
+
195
+ for j, (g, t) in enumerate(gt_pairs):
196
+ if (g, t) not in influence_func_info:
197
+ continue
198
+
199
+ info = influence_func_info[(g, t)]
200
+ w = weights[j]
201
+
202
+ # Vectorized influence function aggregation using index arrays
203
+ treated_idx = info["treated_idx"]
204
+ if len(treated_idx) > 0:
205
+ np.add.at(psi_overall, treated_idx, w * info["treated_inf"])
206
+
207
+ control_idx = info["control_idx"]
208
+ if len(control_idx) > 0:
209
+ np.add.at(psi_overall, control_idx, w * info["control_inf"])
210
+
211
+ # Compute variance: Var(θ̄) = (1/n) Σᵢ ψᵢ²
212
+ variance = np.sum(psi_overall**2)
213
+ return np.sqrt(variance)
214
+
215
+ def _compute_combined_influence_function(
216
+ self,
217
+ gt_pairs: List[Tuple[Any, Any]],
218
+ weights: np.ndarray,
219
+ effects: np.ndarray,
220
+ groups_for_gt: np.ndarray,
221
+ influence_func_info: Dict,
222
+ df: pd.DataFrame,
223
+ unit: str,
224
+ precomputed: Optional["PrecomputedData"] = None,
225
+ global_unit_to_idx: Optional[Dict[Any, int]] = None,
226
+ n_global_units: Optional[int] = None,
227
+ ) -> Tuple[np.ndarray, Optional[List]]:
228
+ """
229
+ Compute the combined (standard IF + WIF) influence function vector.
230
+
231
+ If global_unit_to_idx / n_global_units are provided, the returned vector
232
+ is zero-padded to the global unit set for bootstrap alignment.
233
+ Otherwise, the returned vector is indexed by the local unit set
234
+ (all units appearing in the (g,t) pairs).
235
+
236
+ Returns
237
+ -------
238
+ combined_if : np.ndarray
239
+ Per-unit combined influence function (standard IF + WIF).
240
+ all_units : list or None
241
+ Ordered list of units (only when using local indexing).
242
+ """
243
+ if not influence_func_info:
244
+ if n_global_units is not None:
245
+ return np.zeros(n_global_units), None
246
+ return np.zeros(0), None
247
+
248
+ # Detect RCS mode via explicit flag. In RCS, obs indices ARE array positions.
249
+ _is_rcs = precomputed is not None and not precomputed.get("is_panel", True)
250
+
251
+ # Build unit index mapping (local or global)
252
+ if _is_rcs and n_global_units is not None:
253
+ # RCS: direct indexing — obs indices are the array positions
254
+ n_units = n_global_units
255
+ all_units = None
256
+ elif global_unit_to_idx is not None and n_global_units is not None:
257
+ n_units = n_global_units
258
+ all_units = None # caller already has the unit list
259
+ else:
260
+ all_units_set: Set[Any] = set()
261
+ for g, t in gt_pairs:
262
+ if (g, t) in influence_func_info:
263
+ info = influence_func_info[(g, t)]
264
+ all_units_set.update(info["treated_units"])
265
+ all_units_set.update(info["control_units"])
266
+
267
+ if not all_units_set:
268
+ return np.zeros(0), []
269
+
270
+ all_units = sorted(all_units_set)
271
+ n_units = len(all_units)
272
+
273
+ # Get unique groups and their information
274
+ unique_groups = sorted(set(groups_for_gt))
275
+ unique_groups_set = set(unique_groups)
276
+ group_to_idx = {g: i for i, g in enumerate(unique_groups)}
277
+
278
+ # Check for survey weights in precomputed data
279
+ survey_w = precomputed.get("survey_weights") if precomputed is not None else None
280
+
281
+ # Compute group-level probabilities matching R's formula:
282
+ # pg[g] = n_g / n_all (fraction of ALL units in group g)
283
+ # With survey weights: pg[g] = sum(sw_g) / sum(sw_all)
284
+ group_sizes = {}
285
+ if survey_w is not None:
286
+ # Survey-weighted group sizes
287
+ precomputed_cohorts = precomputed["unit_cohorts"]
288
+ for g in unique_groups:
289
+ mask_g = precomputed_cohorts == g
290
+ group_sizes[g] = float(np.sum(survey_w[mask_g]))
291
+ total_weight = float(np.sum(survey_w))
292
+ elif _is_rcs:
293
+ # RCS without survey: count observations per cohort
294
+ precomputed_cohorts = precomputed["unit_cohorts"]
295
+ for g in unique_groups:
296
+ group_sizes[g] = int(np.sum(precomputed_cohorts == g))
297
+ total_weight = float(n_units)
298
+ else:
299
+ for g in unique_groups:
300
+ treated_in_g = df[df["first_treat"] == g][unit].nunique()
301
+ group_sizes[g] = treated_in_g
302
+ total_weight = float(n_units)
303
+
304
+ # pg indexed by group
305
+ pg_by_group = np.array([group_sizes[g] / total_weight for g in unique_groups])
306
+
307
+ # pg indexed by keeper (each (g,t) pair gets its group's pg)
308
+ pg_keepers = np.array([pg_by_group[group_to_idx[g]] for g in groups_for_gt])
309
+ sum_pg_keepers = np.sum(pg_keepers)
310
+
311
+ # Guard against zero weights (no keepers = no variance)
312
+ if sum_pg_keepers == 0:
313
+ return np.zeros(n_units), all_units
314
+
315
+ # Standard aggregated influence (without wif)
316
+ psi_standard = np.zeros(n_units)
317
+
318
+ for j, (g, t) in enumerate(gt_pairs):
319
+ if (g, t) not in influence_func_info:
320
+ continue
321
+
322
+ info = influence_func_info[(g, t)]
323
+ w = weights[j]
324
+
325
+ # Vectorized influence function aggregation using precomputed index arrays
326
+ treated_idx = info["treated_idx"]
327
+ if len(treated_idx) > 0:
328
+ np.add.at(psi_standard, treated_idx, w * info["treated_inf"])
329
+
330
+ control_idx = info["control_idx"]
331
+ if len(control_idx) > 0:
332
+ np.add.at(psi_standard, control_idx, w * info["control_inf"])
333
+
334
+ # Build unit-group array: normalize iterator to (idx, uid) pairs
335
+ unit_groups_array = np.full(n_units, -1, dtype=np.float64)
336
+
337
+ if _is_rcs:
338
+ # RCS: direct vectorized assignment — obs indices are positions
339
+ precomputed_cohorts = precomputed["unit_cohorts"]
340
+ for g in unique_groups:
341
+ mask_g = precomputed_cohorts == g
342
+ unit_groups_array[mask_g] = g
343
+ elif global_unit_to_idx is not None:
344
+ idx_uid_pairs = [(idx, uid) for uid, idx in global_unit_to_idx.items()]
345
+
346
+ if precomputed is not None:
347
+ precomputed_cohorts = precomputed["unit_cohorts"]
348
+ precomputed_unit_to_idx = precomputed["unit_to_idx"]
349
+ for idx, uid in idx_uid_pairs:
350
+ if uid in precomputed_unit_to_idx:
351
+ cohort = precomputed_cohorts[precomputed_unit_to_idx[uid]]
352
+ if cohort in unique_groups_set:
353
+ unit_groups_array[idx] = cohort
354
+ else:
355
+ for idx, uid in idx_uid_pairs:
356
+ unit_first_treat = df[df[unit] == uid]["first_treat"].iloc[0]
357
+ if unit_first_treat in unique_groups_set:
358
+ unit_groups_array[idx] = unit_first_treat
359
+ else:
360
+ idx_uid_pairs = list(enumerate(all_units))
361
+ for idx, uid in idx_uid_pairs:
362
+ unit_first_treat = df[df[unit] == uid]["first_treat"].iloc[0]
363
+ if unit_first_treat in unique_groups_set:
364
+ unit_groups_array[idx] = unit_first_treat
365
+
366
+ # Vectorized WIF computation
367
+ groups_for_gt_array = np.array(groups_for_gt)
368
+ indicator_matrix = (
369
+ unit_groups_array[:, np.newaxis] == groups_for_gt_array[np.newaxis, :]
370
+ ).astype(np.float64)
371
+
372
+ if survey_w is not None:
373
+ # Survey-weighted WIF matching R's did::wif() / compute.aggte.R.
374
+ # pg_k = E[w_i * 1{G_i=g}] is the weighted group share.
375
+ # IF_i(p_g) = (w_i * 1{G_i=g} - pg_k), NOT s_i * (1{G_i=g} - pg_k).
376
+ # The pg subtraction is NOT weighted by s_i because pg is already
377
+ # the population-level expected value of w_i * 1{G_i=g}.
378
+ if _is_rcs and precomputed is not None:
379
+ # RCS: survey weights are already per-observation, direct indexing
380
+ unit_sw = survey_w
381
+ elif global_unit_to_idx is not None and precomputed is not None:
382
+ unit_sw = np.zeros(n_units)
383
+ precomputed_unit_to_idx_local = precomputed["unit_to_idx"]
384
+ idx_uid_pairs_sw = [(idx, uid) for uid, idx in global_unit_to_idx.items()]
385
+ for idx, uid in idx_uid_pairs_sw:
386
+ if uid in precomputed_unit_to_idx_local:
387
+ pc_idx = precomputed_unit_to_idx_local[uid]
388
+ unit_sw[idx] = survey_w[pc_idx]
389
+ else:
390
+ unit_sw = np.ones(n_units)
391
+
392
+ # w_i * 1{G_i == g_k} - pg_k (matches R's did::wif)
393
+ weighted_indicator = indicator_matrix * unit_sw[:, np.newaxis]
394
+ indicator_diff = weighted_indicator - pg_keepers
395
+ indicator_sum_w = np.sum(indicator_diff, axis=1)
396
+
397
+ with np.errstate(divide="ignore", invalid="ignore", over="ignore"):
398
+ if1_matrix = indicator_diff / sum_pg_keepers
399
+ if2_matrix = np.outer(indicator_sum_w, pg_keepers) / (sum_pg_keepers**2)
400
+ wif_matrix = if1_matrix - if2_matrix
401
+ wif_contrib = wif_matrix @ effects
402
+ else:
403
+ indicator_sum = np.sum(indicator_matrix - pg_keepers, axis=1)
404
+
405
+ with np.errstate(divide="ignore", invalid="ignore", over="ignore"):
406
+ if1_matrix = (indicator_matrix - pg_keepers) / sum_pg_keepers
407
+ if2_matrix = np.outer(indicator_sum, pg_keepers) / (sum_pg_keepers**2)
408
+ wif_matrix = if1_matrix - if2_matrix
409
+ wif_contrib = wif_matrix @ effects
410
+
411
+ # Check for non-finite values from edge cases
412
+ if not np.all(np.isfinite(wif_contrib)):
413
+ import warnings
414
+
415
+ n_nonfinite = np.sum(~np.isfinite(wif_contrib))
416
+ warnings.warn(
417
+ f"Non-finite values ({n_nonfinite}/{len(wif_contrib)}) in weight influence "
418
+ "function computation. This may occur with very small samples or extreme "
419
+ "weights. Returning NaN for SE to signal invalid inference.",
420
+ RuntimeWarning,
421
+ stacklevel=2,
422
+ )
423
+ nan_result = np.full(n_units, np.nan)
424
+ return nan_result, all_units
425
+
426
+ # Scale by 1/total_weight to match R's getSE formula
427
+ # (for non-survey, total_weight == n_units; for survey, total_weight == sum(sw))
428
+ psi_wif = wif_contrib / total_weight
429
+
430
+ # Combine standard and wif terms
431
+ psi_total = psi_standard + psi_wif
432
+
433
+ return psi_total, all_units
434
+
435
+ def _compute_aggregated_se_with_wif(
436
+ self,
437
+ gt_pairs: List[Tuple[Any, Any]],
438
+ weights: np.ndarray,
439
+ effects: np.ndarray,
440
+ groups_for_gt: np.ndarray,
441
+ influence_func_info: Dict,
442
+ df: pd.DataFrame,
443
+ unit: str,
444
+ precomputed: Optional["PrecomputedData"] = None,
445
+ return_psi: bool = False,
446
+ ) -> "Union[float, Tuple[float, np.ndarray]]":
447
+ """
448
+ Compute SE with weight influence function (wif) adjustment.
449
+
450
+ This matches R's `did` package approach for aggregation,
451
+ which accounts for uncertainty in estimating group-size weights.
452
+
453
+ When a full survey design (strata/PSU/FPC) is available in
454
+ ``precomputed['resolved_survey']``, the design-based variance
455
+ :func:`compute_survey_if_variance` is used instead of the simple
456
+ ``sum(psi^2)`` formula.
457
+
458
+ Formula (matching R's did::aggte):
459
+ agg_inf_i = Σ_k w_k × inf_i_k + wif_i × ATT_k
460
+ se = sqrt(mean(agg_inf^2) / n)
461
+ """
462
+ # Extract global unit info for correct pg = n_g / N_total scaling.
463
+ # Without this, the local path builds the unit set from only units in
464
+ # the selected (g,t) pairs, causing pg overestimation at extreme event
465
+ # times where only early-adopter groups have data.
466
+ global_unit_to_idx = None
467
+ n_global_units = None
468
+ if precomputed is not None:
469
+ global_unit_to_idx = precomputed["unit_to_idx"] # None for RCS
470
+ n_global_units = precomputed.get(
471
+ "canonical_size", len(precomputed.get("all_units", []))
472
+ )
473
+ elif df is not None and unit is not None:
474
+ n_global_units = df[unit].nunique()
475
+
476
+ psi_total, _ = self._compute_combined_influence_function(
477
+ gt_pairs,
478
+ weights,
479
+ effects,
480
+ groups_for_gt,
481
+ influence_func_info,
482
+ df,
483
+ unit,
484
+ precomputed,
485
+ global_unit_to_idx=global_unit_to_idx,
486
+ n_global_units=n_global_units,
487
+ )
488
+
489
+ if len(psi_total) == 0:
490
+ return (0.0, psi_total) if return_psi else 0.0
491
+
492
+ # Check for NaN propagation from non-finite WIF
493
+ if not np.all(np.isfinite(psi_total)):
494
+ return (np.nan, psi_total) if return_psi else np.nan
495
+
496
+ # Use design-based variance when full survey design is available
497
+ # Use unit-level resolved survey (panel IF is indexed by unit, not obs)
498
+ resolved_survey = (
499
+ precomputed.get("resolved_survey_unit") if precomputed is not None else None
500
+ )
501
+ if (
502
+ resolved_survey is not None
503
+ and hasattr(resolved_survey, "uses_replicate_variance")
504
+ and resolved_survey.uses_replicate_variance
505
+ ):
506
+ from diff_diff.survey import compute_replicate_if_variance
507
+
508
+ variance, n_valid_rep = compute_replicate_if_variance(psi_total, resolved_survey)
509
+ # Compute effective df for this statistic (don't mutate shared state)
510
+ effective_df = None
511
+ if n_valid_rep < resolved_survey.n_replicates:
512
+ effective_df = n_valid_rep - 1 if n_valid_rep > 1 else 0
513
+ if np.isnan(variance):
514
+ se = np.nan
515
+ else:
516
+ se = np.sqrt(max(variance, 0.0))
517
+ if return_psi:
518
+ return (se, psi_total, effective_df)
519
+ return (se, effective_df)
520
+
521
+ if resolved_survey is not None and (
522
+ resolved_survey.strata is not None
523
+ or resolved_survey.psu is not None
524
+ or resolved_survey.fpc is not None
525
+ ):
526
+ from diff_diff.survey import compute_survey_if_variance
527
+
528
+ variance = compute_survey_if_variance(psi_total, resolved_survey)
529
+ if np.isnan(variance):
530
+ se = np.nan
531
+ else:
532
+ se = np.sqrt(max(variance, 0.0))
533
+ if return_psi:
534
+ return (se, psi_total, None)
535
+ return (se, None)
536
+
537
+ variance = np.sum(psi_total**2)
538
+ se = np.sqrt(variance)
539
+ if return_psi:
540
+ return (se, psi_total, None)
541
+ return (se, None)
542
+
543
+ def _aggregate_event_study(
544
+ self,
545
+ group_time_effects: Dict,
546
+ influence_func_info: Dict,
547
+ groups: List[Any],
548
+ time_periods: List[Any],
549
+ balance_e: Optional[int] = None,
550
+ df: Optional[pd.DataFrame] = None,
551
+ unit: Optional[str] = None,
552
+ precomputed: Optional["PrecomputedData"] = None,
553
+ ) -> Dict[int, Dict[str, Any]]:
554
+ """
555
+ Aggregate effects by relative time (event study).
556
+
557
+ Computes average effect at each event time e = t - g.
558
+
559
+ Standard errors include the weight influence function (WIF)
560
+ adjustment that accounts for uncertainty in group-size weights,
561
+ matching R's did::aggte(..., type="dynamic").
562
+ """
563
+ # Organize effects by relative time, keeping track of (g,t) pairs
564
+ effects_by_e: Dict[int, List[Tuple[Tuple[Any, Any], float, float]]] = {}
565
+
566
+ # Fixed per-cohort survey weights for aggregation
567
+ survey_cohort_weights = None
568
+ if precomputed is not None and precomputed.get("survey_weights") is not None:
569
+ sw = precomputed["survey_weights"]
570
+ unit_cohorts = precomputed["unit_cohorts"]
571
+ survey_cohort_weights = {}
572
+ for g in np.unique(unit_cohorts):
573
+ if g > 0:
574
+ survey_cohort_weights[g] = float(np.sum(sw[unit_cohorts == g]))
575
+
576
+ for (g, t), data in group_time_effects.items():
577
+ e = t - g # Relative time
578
+ if e not in effects_by_e:
579
+ effects_by_e[e] = []
580
+ # For RCS, data["agg_weight"] holds the fixed cohort mass;
581
+ # for panel, fallback to data["n_treated"].
582
+ w = (
583
+ survey_cohort_weights[g]
584
+ if survey_cohort_weights is not None and g in survey_cohort_weights
585
+ else data.get("agg_weight", data["n_treated"])
586
+ )
587
+ effects_by_e[e].append(
588
+ (
589
+ (g, t), # Keep track of the (g,t) pair
590
+ data["effect"],
591
+ w,
592
+ )
593
+ )
594
+
595
+ # Balance the panel if requested
596
+ if balance_e is not None:
597
+ # Keep only groups that have effects at relative time balance_e
598
+ groups_at_e = set()
599
+ for (g, t), data in group_time_effects.items():
600
+ if t - g == balance_e and np.isfinite(data["effect"]):
601
+ groups_at_e.add(g)
602
+
603
+ # Filter effects to only include balanced groups
604
+ balanced_effects: Dict[int, List[Tuple[Tuple[Any, Any], float, float]]] = {}
605
+ for (g, t), data in group_time_effects.items():
606
+ if g in groups_at_e:
607
+ e = t - g
608
+ if e not in balanced_effects:
609
+ balanced_effects[e] = []
610
+ w = (
611
+ survey_cohort_weights[g]
612
+ if survey_cohort_weights is not None and g in survey_cohort_weights
613
+ else data.get("agg_weight", data["n_treated"])
614
+ )
615
+ balanced_effects[e].append(
616
+ (
617
+ (g, t),
618
+ data["effect"],
619
+ w,
620
+ )
621
+ )
622
+ effects_by_e = balanced_effects
623
+
624
+ # Compute aggregated effects and SEs for all relative periods
625
+ sorted_periods = sorted(effects_by_e.items())
626
+ agg_effects_list = []
627
+ agg_ses_list = []
628
+ agg_n_groups = []
629
+ agg_effective_dfs = [] # Per-horizon effective df (replicate designs)
630
+ _psi_vectors = [] # Per-event-time combined IF vectors for VCV
631
+ _psi_event_times = [] # Event times that contributed a psi column
632
+ for e, effect_list in sorted_periods:
633
+ gt_pairs = [x[0] for x in effect_list]
634
+ effs = np.array([x[1] for x in effect_list])
635
+ ns = np.array([x[2] for x in effect_list], dtype=float)
636
+
637
+ # Exclude NaN effects from this period's aggregation
638
+ finite_mask = np.isfinite(effs)
639
+ if not np.all(finite_mask):
640
+ effs = effs[finite_mask]
641
+ ns = ns[finite_mask]
642
+ gt_pairs = [gt for gt, m in zip(gt_pairs, finite_mask) if m]
643
+ if len(effs) == 0:
644
+ agg_effects_list.append(np.nan)
645
+ agg_ses_list.append(np.nan)
646
+ agg_n_groups.append(0)
647
+ agg_effective_dfs.append(None)
648
+ continue
649
+
650
+ weights = ns / np.sum(ns)
651
+ agg_effect = np.sum(weights * effs)
652
+
653
+ # Compute SE with WIF adjustment (matching R's did::aggte)
654
+ groups_for_gt = np.array([g for (g, t) in gt_pairs])
655
+ agg_se, psi_e, eff_df = self._compute_aggregated_se_with_wif(
656
+ gt_pairs,
657
+ weights,
658
+ effs,
659
+ groups_for_gt,
660
+ influence_func_info,
661
+ df,
662
+ unit,
663
+ precomputed,
664
+ return_psi=True,
665
+ )
666
+
667
+ agg_effects_list.append(agg_effect)
668
+ agg_ses_list.append(agg_se)
669
+ agg_n_groups.append(len(effect_list))
670
+ agg_effective_dfs.append(eff_df)
671
+ _psi_vectors.append(psi_e)
672
+ _psi_event_times.append(e)
673
+
674
+ # Batch inference for all relative periods
675
+ if not agg_effects_list:
676
+ return {}
677
+ # Use per-horizon effective df if any replicate aggregation overrode it;
678
+ # otherwise fall back to the original df from the survey design.
679
+ df_survey_val = precomputed.get("df_survey") if precomputed is not None else None
680
+ # Guard: replicate design with undefined df → NaN inference
681
+ if (
682
+ df_survey_val is None
683
+ and precomputed is not None
684
+ and precomputed.get("resolved_survey_unit") is not None
685
+ and hasattr(precomputed["resolved_survey_unit"], "uses_replicate_variance")
686
+ and precomputed["resolved_survey_unit"].uses_replicate_variance
687
+ ):
688
+ df_survey_val = 0
689
+ # If any horizon has a per-statistic effective df (dropped replicates),
690
+ # use the minimum across horizons for conservative batch inference.
691
+ non_none_dfs = [d for d in agg_effective_dfs if d is not None]
692
+ if non_none_dfs:
693
+ df_survey_val = min(non_none_dfs)
694
+ t_stats, p_values, ci_lowers, ci_uppers = safe_inference_batch(
695
+ np.array(agg_effects_list),
696
+ np.array(agg_ses_list),
697
+ alpha=self.alpha,
698
+ df=df_survey_val,
699
+ )
700
+
701
+ event_study_effects = {}
702
+ for idx, (e, _) in enumerate(sorted_periods):
703
+ event_study_effects[e] = {
704
+ "effect": agg_effects_list[idx],
705
+ "se": agg_ses_list[idx],
706
+ "t_stat": float(t_stats[idx]),
707
+ "p_value": float(p_values[idx]),
708
+ "conf_int": (float(ci_lowers[idx]), float(ci_uppers[idx])),
709
+ "n_groups": agg_n_groups[idx],
710
+ }
711
+
712
+ # Add reference period for universal base period mode (matches R did package)
713
+ if getattr(self, "base_period", "varying") == "universal":
714
+ ref_period = -1 - self.anticipation
715
+ if event_study_effects and ref_period not in event_study_effects:
716
+ event_study_effects[ref_period] = {
717
+ "effect": 0.0,
718
+ "se": np.nan,
719
+ "t_stat": np.nan,
720
+ "p_value": np.nan,
721
+ "conf_int": (np.nan, np.nan),
722
+ "n_groups": 0,
723
+ }
724
+
725
+ # Compute full event-study VCV from per-event-time IF vectors (Phase 7d)
726
+ # This enables HonestDiD to use the full covariance structure
727
+ event_study_vcov = None
728
+ valid_psi = [p for p in _psi_vectors if len(p) > 0]
729
+ if valid_psi:
730
+ try:
731
+ Psi = np.column_stack(valid_psi) # (n_units, n_event_times)
732
+ resolved_survey = (
733
+ precomputed.get("resolved_survey_unit") if precomputed is not None else None
734
+ )
735
+ if (
736
+ resolved_survey is not None
737
+ and not (
738
+ hasattr(resolved_survey, "uses_replicate_variance")
739
+ and resolved_survey.uses_replicate_variance
740
+ )
741
+ and (
742
+ resolved_survey.strata is not None
743
+ or resolved_survey.psu is not None
744
+ or resolved_survey.fpc is not None
745
+ )
746
+ ):
747
+ from diff_diff.survey import _compute_stratified_psu_meat
748
+
749
+ meat, _, _ = _compute_stratified_psu_meat(Psi, resolved_survey)
750
+ event_study_vcov = meat
751
+ elif (
752
+ resolved_survey is not None
753
+ and hasattr(resolved_survey, "uses_replicate_variance")
754
+ and resolved_survey.uses_replicate_variance
755
+ ):
756
+ # Replicate-weight: fall back to None (diagonal in HonestDiD)
757
+ # until multivariate replicate VCV is implemented
758
+ event_study_vcov = None
759
+ else:
760
+ # No survey: simple sum-of-outer-products
761
+ event_study_vcov = Psi.T @ Psi
762
+ except (ValueError, np.linalg.LinAlgError):
763
+ pass # Fall back to diagonal (None)
764
+
765
+ # Store the event-time index that matches VCV columns (for subsetting
766
+ # in HonestDiD when some event times are filtered out)
767
+ self._event_study_vcov_index = _psi_event_times if event_study_vcov is not None else None
768
+
769
+ # Attach VCV to self for CallawaySantAnna to pick up
770
+ self._event_study_vcov = event_study_vcov
771
+
772
+ return event_study_effects
773
+
774
+ def _aggregate_by_group(
775
+ self,
776
+ group_time_effects: Dict,
777
+ influence_func_info: Dict,
778
+ groups: List[Any],
779
+ precomputed: Optional["PrecomputedData"] = None,
780
+ df: Optional[pd.DataFrame] = None,
781
+ unit: Optional[str] = None,
782
+ ) -> Dict[Any, Dict[str, Any]]:
783
+ """
784
+ Aggregate effects by treatment cohort.
785
+
786
+ Computes average effect for each cohort across all post-treatment periods.
787
+
788
+ Standard errors use influence function aggregation with WIF adjustment
789
+ to account for covariances across time periods within a cohort.
790
+ When a full survey design is present in precomputed, uses design-based
791
+ variance via compute_survey_if_variance().
792
+ """
793
+ # Collect all group aggregation data first
794
+ group_data_list = []
795
+ for g in groups:
796
+ g_effects = [
797
+ ((g, t), data["effect"])
798
+ for (gg, t), data in group_time_effects.items()
799
+ if gg == g and t >= g - self.anticipation
800
+ ]
801
+
802
+ if not g_effects:
803
+ continue
804
+
805
+ gt_pairs = [x[0] for x in g_effects]
806
+ effs = np.array([x[1] for x in g_effects])
807
+
808
+ # Exclude NaN effects from this group's aggregation
809
+ finite_mask = np.isfinite(effs)
810
+ if not np.all(finite_mask):
811
+ effs = effs[finite_mask]
812
+ gt_pairs = [gt for gt, m in zip(gt_pairs, finite_mask) if m]
813
+ if len(effs) == 0:
814
+ continue
815
+
816
+ weights = np.ones(len(effs)) / len(effs)
817
+ agg_effect = np.sum(weights * effs)
818
+
819
+ # Use WIF-adjusted SE (with survey design support)
820
+ groups_for_gt = np.array([gg for (gg, t) in gt_pairs])
821
+ agg_se, eff_df = self._compute_aggregated_se_with_wif(
822
+ gt_pairs, weights, effs, groups_for_gt, influence_func_info, df, unit, precomputed
823
+ )
824
+ group_data_list.append((g, agg_effect, agg_se, len(g_effects), eff_df))
825
+
826
+ if not group_data_list:
827
+ return {}
828
+
829
+ # Batch inference
830
+ agg_effects = np.array([x[1] for x in group_data_list])
831
+ agg_ses = np.array([x[2] for x in group_data_list])
832
+ df_survey_val = precomputed.get("df_survey") if precomputed is not None else None
833
+ # Guard: replicate design with undefined df → NaN inference
834
+ if (
835
+ df_survey_val is None
836
+ and precomputed is not None
837
+ and precomputed.get("resolved_survey_unit") is not None
838
+ and hasattr(precomputed["resolved_survey_unit"], "uses_replicate_variance")
839
+ and precomputed["resolved_survey_unit"].uses_replicate_variance
840
+ ):
841
+ df_survey_val = 0
842
+ # Use minimum per-group effective df if any dropped replicates
843
+ non_none_dfs = [x[4] for x in group_data_list if x[4] is not None]
844
+ if non_none_dfs:
845
+ df_survey_val = min(non_none_dfs)
846
+ t_stats, p_values, ci_lowers, ci_uppers = safe_inference_batch(
847
+ agg_effects,
848
+ agg_ses,
849
+ alpha=self.alpha,
850
+ df=df_survey_val,
851
+ )
852
+
853
+ group_effects = {}
854
+ for idx, (g, agg_effect, agg_se, n_periods, _eff_df) in enumerate(group_data_list):
855
+ group_effects[g] = {
856
+ "effect": agg_effect,
857
+ "se": agg_se,
858
+ "t_stat": float(t_stats[idx]),
859
+ "p_value": float(p_values[idx]),
860
+ "conf_int": (float(ci_lowers[idx]), float(ci_uppers[idx])),
861
+ "n_periods": n_periods,
862
+ }
863
+
864
+ return group_effects