diff-diff 2.3.2__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,492 @@
1
+ """
2
+ Aggregation methods mixin for Callaway-Sant'Anna estimator.
3
+
4
+ This module provides the mixin class containing methods for aggregating
5
+ group-time average treatment effects into summary measures.
6
+ """
7
+
8
+ from typing import Any, Dict, List, Optional, Set, Tuple
9
+
10
+ import numpy as np
11
+ import pandas as pd
12
+
13
+ from diff_diff.utils import (
14
+ compute_confidence_interval,
15
+ compute_p_value,
16
+ )
17
+
18
+ # Type alias for pre-computed structures (defined at module scope for runtime access)
19
+ PrecomputedData = Dict[str, Any]
20
+
21
+
22
class CallawaySantAnnaAggregationMixin:
    """
    Mixin class providing aggregation methods for CallawaySantAnna estimator.

    This class is not intended to be used standalone. It provides methods
    that are used by the main CallawaySantAnna class to aggregate group-time
    effects ATT(g, t) into summary measures: an overall ATT
    (``_aggregate_simple``), event-study effects by relative time
    (``_aggregate_event_study``), and cohort-level effects
    (``_aggregate_by_group``).

    The attribute annotations below declare state this mixin reads from the
    host class; they exist so type checkers know about them.
    """

    # Significance level forwarded to compute_confidence_interval
    # (presumably e.g. 0.05 for a 95% CI -- confirm against utils).
    alpha: float

    # Number of anticipation periods. Effects with t < g - anticipation are
    # treated as pre-treatment and excluded from post-treatment aggregates.
    anticipation: int

    # Base-period mode. When equal to "universal", a zero reference period
    # is injected into the event-study output (matches R's did package).
    base_period: str

    def _aggregate_simple(
        self,
        group_time_effects: Dict,
        influence_func_info: Dict,
        df: pd.DataFrame,
        unit: str,
        precomputed: Optional["PrecomputedData"] = None,
    ) -> Tuple[float, float]:
        """
        Compute simple weighted average of ATT(g,t).

        Weights by group size (number of treated units).

        Standard errors are computed using influence function aggregation,
        which properly accounts for covariances across (g,t) pairs due to
        shared control units. This includes the wif (weight influence function)
        adjustment from R's `did` package that accounts for uncertainty in
        estimating the group-size weights.

        Note: Only post-treatment effects (t >= g - anticipation) are included
        in the overall ATT. Pre-treatment effects are computed for parallel
        trends assessment but are not aggregated into the overall ATT.

        Parameters
        ----------
        group_time_effects : Dict
            Mapping (g, t) -> dict with at least 'effect' and 'n_treated'.
        influence_func_info : Dict
            Mapping (g, t) -> per-unit influence-function data (see
            _compute_aggregated_se for the expected keys).
        df : pd.DataFrame
            Long-format panel containing a 'first_treat' column and the
            unit-identifier column named by `unit`.
        unit : str
            Name of the unit-identifier column in `df`.
        precomputed : Optional[PrecomputedData]
            Optional cached unit/cohort lookups to avoid per-unit DataFrame
            scans in the SE computation.

        Returns
        -------
        Tuple[float, float]
            (overall_att, overall_se); (nan, nan) when no post-treatment
            effects exist (a UserWarning is emitted in that case).
        """
        effects = []
        weights_list = []
        gt_pairs = []
        groups_for_gt = []

        for (g, t), data in group_time_effects.items():
            # Only include post-treatment effects (t >= g - anticipation)
            # Pre-treatment effects are for parallel trends, not overall ATT
            if t < g - self.anticipation:
                continue
            effects.append(data['effect'])
            weights_list.append(data['n_treated'])
            gt_pairs.append((g, t))
            groups_for_gt.append(g)

        # Guard against empty post-treatment set
        if len(effects) == 0:
            import warnings
            warnings.warn(
                "No post-treatment effects available for overall ATT aggregation. "
                "This can occur when cohorts lack post-treatment periods in the data.",
                UserWarning,
                stacklevel=2
            )
            return np.nan, np.nan

        effects = np.array(effects)
        weights = np.array(weights_list, dtype=float)
        groups_for_gt = np.array(groups_for_gt)

        # Normalize weights so they sum to 1 across the kept (g,t) pairs
        total_weight = np.sum(weights)
        weights_norm = weights / total_weight

        # Weighted average
        overall_att = np.sum(weights_norm * effects)

        # Compute SE using influence function aggregation with wif adjustment
        overall_se = self._compute_aggregated_se_with_wif(
            gt_pairs, weights_norm, effects, groups_for_gt,
            influence_func_info, df, unit, precomputed
        )

        return overall_att, overall_se

    def _compute_aggregated_se(
        self,
        gt_pairs: List[Tuple[Any, Any]],
        weights: np.ndarray,
        influence_func_info: Dict,
    ) -> float:
        """
        Compute standard error using influence function aggregation.

        This properly accounts for covariances across (g,t) pairs by
        aggregating unit-level influence functions:

            ψ_i(overall) = Σ_{(g,t)} w_(g,t) × ψ_i(g,t)
            Var(overall) = (1/n) Σ_i [ψ_i]²

        This matches R's `did` package analytical SE formula.

        Parameters
        ----------
        gt_pairs : List[Tuple[Any, Any]]
            (g, t) pairs to aggregate, aligned with `weights`.
        weights : np.ndarray
            Aggregation weight for each pair in `gt_pairs`.
        influence_func_info : Dict
            Per-pair dicts with keys 'treated_units', 'control_units',
            'treated_inf', 'control_inf' (unit lists and aligned
            influence-function values).

        Returns
        -------
        float
            The aggregated standard error; 0.0 when no influence-function
            information is available.
        """
        if not influence_func_info:
            # Fallback if no influence functions available
            return 0.0

        # Build unit index mapping from all (g,t) pairs
        all_units = set()
        for (g, t) in gt_pairs:
            if (g, t) in influence_func_info:
                info = influence_func_info[(g, t)]
                all_units.update(info['treated_units'])
                all_units.update(info['control_units'])

        if not all_units:
            return 0.0

        all_units = sorted(all_units)
        n_units = len(all_units)
        unit_to_idx = {u: i for i, u in enumerate(all_units)}

        # Aggregate influence functions across (g,t) pairs; a unit appearing
        # in several pairs accumulates all of its weighted contributions.
        psi_overall = np.zeros(n_units)

        for j, (g, t) in enumerate(gt_pairs):
            if (g, t) not in influence_func_info:
                continue

            info = influence_func_info[(g, t)]
            w = weights[j]

            # Treated unit contributions
            for i, unit_id in enumerate(info['treated_units']):
                idx = unit_to_idx[unit_id]
                psi_overall[idx] += w * info['treated_inf'][i]

            # Control unit contributions
            for i, unit_id in enumerate(info['control_units']):
                idx = unit_to_idx[unit_id]
                psi_overall[idx] += w * info['control_inf'][i]

        # Compute variance: Var(θ̄) = (1/n) Σᵢ ψᵢ²
        # NOTE(review): no explicit 1/n division appears here, so the stored
        # influence values are presumably already scaled by 1/n upstream
        # (consistent with the 1/n_units scaling of the wif term in
        # _compute_aggregated_se_with_wif) -- confirm against the producer
        # of influence_func_info.
        variance = np.sum(psi_overall ** 2)
        return np.sqrt(variance)

    def _compute_aggregated_se_with_wif(
        self,
        gt_pairs: List[Tuple[Any, Any]],
        weights: np.ndarray,
        effects: np.ndarray,
        groups_for_gt: np.ndarray,
        influence_func_info: Dict,
        df: pd.DataFrame,
        unit: str,
        precomputed: Optional["PrecomputedData"] = None,
    ) -> float:
        """
        Compute SE with weight influence function (wif) adjustment.

        This matches R's `did` package approach for "simple" aggregation,
        which accounts for uncertainty in estimating group-size weights.

        The wif adjustment adds variance due to the fact that aggregation
        weights w_g = n_g / N depend on estimated group sizes.

        Formula (matching R's did::aggte):
            agg_inf_i = Σ_k w_k × inf_i_k + wif_i × ATT_k
            se = sqrt(mean(agg_inf^2) / n)

        where:
        - k indexes "keepers" (post-treatment (g,t) pairs)
        - w_k = pg[k] / sum(pg[keepers]) where pg = n_g / n_all
        - wif captures how unit i influences the weight estimation

        Parameters
        ----------
        gt_pairs, weights, effects, groups_for_gt
            Aligned arrays/lists over the kept post-treatment (g,t) pairs:
            the pair itself, its normalized weight, its ATT estimate, and
            its cohort g.
        influence_func_info : Dict
            Same structure as in _compute_aggregated_se.
        df : pd.DataFrame
            Panel with 'first_treat' and the `unit` column; used for group
            sizes and (in the slow path) unit-cohort lookups.
        unit : str
            Name of the unit-identifier column.
        precomputed : Optional[PrecomputedData]
            Optional cache with 'all_units', 'unit_cohorts', 'unit_to_idx'
            enabling the O(n_units) cohort lookup fast path.

        Returns
        -------
        float
            Aggregated SE; 0.0 when no influence info / units / keepers,
            NaN when the wif computation produces non-finite values.
        """
        if not influence_func_info:
            return 0.0

        # Build unit index mapping
        all_units_set: Set[Any] = set()
        for (g, t) in gt_pairs:
            if (g, t) in influence_func_info:
                info = influence_func_info[(g, t)]
                all_units_set.update(info['treated_units'])
                all_units_set.update(info['control_units'])

        if not all_units_set:
            return 0.0

        all_units = sorted(all_units_set)
        n_units = len(all_units)
        unit_to_idx = {u: i for i, u in enumerate(all_units)}

        # Get unique groups and their information
        unique_groups = sorted(set(groups_for_gt))
        unique_groups_set = set(unique_groups)
        group_to_idx = {g: i for i, g in enumerate(unique_groups)}

        # Compute group-level probabilities matching R's formula:
        #   pg[g] = n_g / n_all (fraction of ALL units in group g)
        # This differs from our old formula which used n_g / total_treated.
        # NOTE(review): n_units here counts only units that appear in some
        # (g,t) influence function, which may differ from the full panel's
        # unit count -- confirm this matches the intended "n_all".
        group_sizes = {}
        for g in unique_groups:
            treated_in_g = df[df['first_treat'] == g][unit].nunique()
            group_sizes[g] = treated_in_g

        # pg indexed by group
        pg_by_group = np.array([group_sizes[g] / n_units for g in unique_groups])

        # pg indexed by keeper (each (g,t) pair gets its group's pg)
        # This matches R's: pg <- pgg[match(group, originalglist)]
        pg_keepers = np.array([pg_by_group[group_to_idx[g]] for g in groups_for_gt])
        sum_pg_keepers = np.sum(pg_keepers)

        # Guard against zero weights (no keepers = no variance)
        if sum_pg_keepers == 0:
            return 0.0

        # Standard aggregated influence (without wif)
        psi_standard = np.zeros(n_units)

        for j, (g, t) in enumerate(gt_pairs):
            if (g, t) not in influence_func_info:
                continue

            info = influence_func_info[(g, t)]
            w = weights[j]

            # Vectorized influence function aggregation for treated units.
            # np.add.at accumulates correctly even with duplicate indices.
            treated_indices = np.array([unit_to_idx[uid] for uid in info['treated_units']])
            if len(treated_indices) > 0:
                np.add.at(psi_standard, treated_indices, w * info['treated_inf'])

            # Vectorized influence function aggregation for control units
            control_indices = np.array([unit_to_idx[uid] for uid in info['control_units']])
            if len(control_indices) > 0:
                np.add.at(psi_standard, control_indices, w * info['control_inf'])

        # Build unit-group array using precomputed data if available
        # This is O(n_units) instead of O(n_units × n_obs) DataFrame lookups
        if precomputed is not None:
            # Use precomputed cohort mapping
            precomputed_units = precomputed['all_units']
            precomputed_cohorts = precomputed['unit_cohorts']
            precomputed_unit_to_idx = precomputed['unit_to_idx']

            # Build unit_groups_array for the units in this SE computation
            # A value of -1 indicates never-treated or other (not in unique_groups)
            unit_groups_array = np.full(n_units, -1, dtype=np.float64)
            for i, uid in enumerate(all_units):
                if uid in precomputed_unit_to_idx:
                    cohort = precomputed_cohorts[precomputed_unit_to_idx[uid]]
                    if cohort in unique_groups_set:
                        unit_groups_array[i] = cohort
        else:
            # Fallback: build from DataFrame (slow path for backward compatibility)
            unit_groups_array = np.full(n_units, -1, dtype=np.float64)
            for i, uid in enumerate(all_units):
                unit_first_treat = df[df[unit] == uid]['first_treat'].iloc[0]
                if unit_first_treat in unique_groups_set:
                    unit_groups_array[i] = unit_first_treat

        # Vectorized WIF computation
        # R's wif formula:
        #   if1[i,k] = (indicator(G_i == group_k) - pg[k]) / sum(pg[keepers])
        #   if2[i,k] = indicator_sum[i] * pg[k] / sum(pg[keepers])^2
        #   wif[i,k] = if1[i,k] - if2[i,k]
        #   wif_contrib[i] = sum_k(wif[i,k] * att[k])

        # Build indicator matrix: (n_units, n_keepers)
        # indicator_matrix[i, k] = 1.0 if unit i belongs to group for keeper k
        groups_for_gt_array = np.array(groups_for_gt)
        indicator_matrix = (unit_groups_array[:, np.newaxis] == groups_for_gt_array[np.newaxis, :]).astype(np.float64)

        # Vectorized indicator_sum: sum over keepers
        # indicator_sum[i] = sum_k(indicator(G_i == group_k) - pg[k])
        indicator_sum = np.sum(indicator_matrix - pg_keepers, axis=1)

        # Vectorized wif matrix computation
        # Suppress RuntimeWarnings for edge cases (small samples, extreme weights)
        # in division operations and matrix multiplication; non-finite results
        # are detected explicitly below.
        with np.errstate(divide='ignore', invalid='ignore', over='ignore'):
            # if1_matrix[i,k] = (indicator[i,k] - pg[k]) / sum_pg
            if1_matrix = (indicator_matrix - pg_keepers) / sum_pg_keepers
            # if2_matrix[i,k] = indicator_sum[i] * pg[k] / sum_pg^2
            if2_matrix = np.outer(indicator_sum, pg_keepers) / (sum_pg_keepers ** 2)
            wif_matrix = if1_matrix - if2_matrix

            # Single matrix-vector multiply for all contributions
            # wif_contrib[i] = sum_k(wif[i,k] * att[k])
            wif_contrib = wif_matrix @ effects

        # Check for non-finite values from edge cases
        if not np.all(np.isfinite(wif_contrib)):
            import warnings
            n_nonfinite = np.sum(~np.isfinite(wif_contrib))
            warnings.warn(
                f"Non-finite values ({n_nonfinite}/{len(wif_contrib)}) in weight influence "
                "function computation. This may occur with very small samples or extreme "
                "weights. Returning NaN for SE to signal invalid inference.",
                RuntimeWarning,
                stacklevel=2
            )
            return np.nan  # Signal invalid inference instead of biased SE

        # Scale by 1/n_units to match R's getSE formula: sqrt(mean(IF^2)/n)
        psi_wif = wif_contrib / n_units

        # Combine standard and wif terms
        psi_total = psi_standard + psi_wif

        # Compute variance and SE
        # R's formula: sqrt(mean(IF^2) / n) = sqrt(sum(IF^2) / n^2)
        # NOTE(review): as in _compute_aggregated_se, this relies on the
        # influence values already carrying a 1/n scale -- confirm upstream.
        variance = np.sum(psi_total ** 2)
        return np.sqrt(variance)

    def _aggregate_event_study(
        self,
        group_time_effects: Dict,
        influence_func_info: Dict,
        groups: List[Any],
        time_periods: List[Any],
        balance_e: Optional[int] = None,
    ) -> Dict[int, Dict[str, Any]]:
        """
        Aggregate effects by relative time (event study).

        Computes average effect at each event time e = t - g.

        Standard errors use influence function aggregation to account for
        covariances across (g,t) pairs.

        Parameters
        ----------
        group_time_effects : Dict
            Mapping (g, t) -> dict with 'effect' and 'n_treated'.
        influence_func_info : Dict
            Per-pair influence-function data (see _compute_aggregated_se).
        groups, time_periods : List[Any]
            NOTE: currently unused in this method; presumably kept for
            interface parity with the other aggregators.
        balance_e : Optional[int]
            If given, restrict to cohorts observed at relative time
            balance_e (balanced event-study panel).

        Returns
        -------
        Dict[int, Dict[str, Any]]
            Mapping event time e -> dict with 'effect', 'se', 't_stat',
            'p_value', 'conf_int', 'n_groups'.
        """
        # Organize effects by relative time, keeping track of (g,t) pairs
        effects_by_e: Dict[int, List[Tuple[Tuple[Any, Any], float, int]]] = {}

        for (g, t), data in group_time_effects.items():
            e = t - g  # Relative time
            if e not in effects_by_e:
                effects_by_e[e] = []
            effects_by_e[e].append((
                (g, t),  # Keep track of the (g,t) pair
                data['effect'],
                data['n_treated']
            ))

        # Balance the panel if requested
        if balance_e is not None:
            # Keep only groups that have effects at relative time balance_e
            groups_at_e = set()
            for (g, t), data in group_time_effects.items():
                if t - g == balance_e:
                    groups_at_e.add(g)

            # Filter effects to only include balanced groups
            balanced_effects: Dict[int, List[Tuple[Tuple[Any, Any], float, int]]] = {}
            for (g, t), data in group_time_effects.items():
                if g in groups_at_e:
                    e = t - g
                    if e not in balanced_effects:
                        balanced_effects[e] = []
                    balanced_effects[e].append((
                        (g, t),
                        data['effect'],
                        data['n_treated']
                    ))
            effects_by_e = balanced_effects

        # Compute aggregated effects
        event_study_effects = {}

        for e, effect_list in sorted(effects_by_e.items()):
            gt_pairs = [x[0] for x in effect_list]
            effs = np.array([x[1] for x in effect_list])
            ns = np.array([x[2] for x in effect_list], dtype=float)

            # Weight by group size
            weights = ns / np.sum(ns)

            agg_effect = np.sum(weights * effs)

            # Compute SE using influence function aggregation
            agg_se = self._compute_aggregated_se(
                gt_pairs, weights, influence_func_info
            )

            # t-stat is undefined for non-positive or non-finite SE
            t_stat = agg_effect / agg_se if np.isfinite(agg_se) and agg_se > 0 else np.nan
            p_val = compute_p_value(t_stat)
            ci = compute_confidence_interval(agg_effect, agg_se, self.alpha)

            event_study_effects[e] = {
                'effect': agg_effect,
                'se': agg_se,
                't_stat': t_stat,
                'p_value': p_val,
                'conf_int': ci,
                'n_groups': len(effect_list),
            }

        # Add reference period for universal base period mode (matches R did package)
        # The reference period e = -1 - anticipation has effect = 0 by construction
        # Only add if there are actual computed effects (guard against empty data)
        if getattr(self, 'base_period', 'varying') == "universal":
            ref_period = -1 - self.anticipation
            # Only inject reference if we have at least one real effect
            if event_study_effects and ref_period not in event_study_effects:
                event_study_effects[ref_period] = {
                    'effect': 0.0,
                    'se': np.nan,  # Undefined - no data, normalization constraint
                    't_stat': np.nan,  # Undefined - normalization constraint
                    'p_value': np.nan,
                    'conf_int': (np.nan, np.nan),  # NaN propagation for undefined inference
                    'n_groups': 0,  # No groups contribute - fixed by construction
                }

        return event_study_effects

    def _aggregate_by_group(
        self,
        group_time_effects: Dict,
        influence_func_info: Dict,
        groups: List[Any],
    ) -> Dict[Any, Dict[str, Any]]:
        """
        Aggregate effects by treatment cohort.

        Computes average effect for each cohort across all post-treatment periods.

        Standard errors use influence function aggregation to account for
        covariances across time periods within a cohort.

        Parameters
        ----------
        group_time_effects : Dict
            Mapping (g, t) -> dict with 'effect' (and other fields).
        influence_func_info : Dict
            Per-pair influence-function data (see _compute_aggregated_se).
        groups : List[Any]
            Cohorts to aggregate; cohorts with no post-treatment effects
            are silently skipped.

        Returns
        -------
        Dict[Any, Dict[str, Any]]
            Mapping cohort g -> dict with 'effect', 'se', 't_stat',
            'p_value', 'conf_int', 'n_periods'.
        """
        group_effects = {}

        for g in groups:
            # Get all effects for this group (post-treatment only: t >= g - anticipation)
            # Keep track of (g, t) pairs for influence function aggregation
            g_effects = [
                ((g, t), data['effect'])
                for (gg, t), data in group_time_effects.items()
                if gg == g and t >= g - self.anticipation
            ]

            if not g_effects:
                continue

            gt_pairs = [x[0] for x in g_effects]
            effs = np.array([x[1] for x in g_effects])

            # Equal weight across time periods for a group
            weights = np.ones(len(effs)) / len(effs)

            agg_effect = np.sum(weights * effs)

            # Compute SE using influence function aggregation
            agg_se = self._compute_aggregated_se(
                gt_pairs, weights, influence_func_info
            )

            # t-stat is undefined for non-positive or non-finite SE
            t_stat = agg_effect / agg_se if np.isfinite(agg_se) and agg_se > 0 else np.nan
            p_val = compute_p_value(t_stat)
            ci = compute_confidence_interval(agg_effect, agg_se, self.alpha)

            group_effects[g] = {
                'effect': agg_effect,
                'se': agg_se,
                't_stat': t_stat,
                'p_value': p_val,
                'conf_int': ci,
                'n_periods': len(g_effects),
            }

        return group_effects