diff-diff 3.0.1__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diff_diff/__init__.py +382 -0
- diff_diff/_backend.py +134 -0
- diff_diff/_rust_backend.cp314-win_amd64.pyd +0 -0
- diff_diff/bacon.py +1140 -0
- diff_diff/bootstrap_utils.py +730 -0
- diff_diff/continuous_did.py +1626 -0
- diff_diff/continuous_did_bspline.py +190 -0
- diff_diff/continuous_did_results.py +374 -0
- diff_diff/datasets.py +815 -0
- diff_diff/diagnostics.py +882 -0
- diff_diff/efficient_did.py +1770 -0
- diff_diff/efficient_did_bootstrap.py +359 -0
- diff_diff/efficient_did_covariates.py +899 -0
- diff_diff/efficient_did_results.py +368 -0
- diff_diff/efficient_did_weights.py +617 -0
- diff_diff/estimators.py +1501 -0
- diff_diff/honest_did.py +2585 -0
- diff_diff/imputation.py +2458 -0
- diff_diff/imputation_bootstrap.py +418 -0
- diff_diff/imputation_results.py +448 -0
- diff_diff/linalg.py +2538 -0
- diff_diff/power.py +2588 -0
- diff_diff/practitioner.py +869 -0
- diff_diff/prep.py +1738 -0
- diff_diff/prep_dgp.py +1718 -0
- diff_diff/pretrends.py +1105 -0
- diff_diff/results.py +918 -0
- diff_diff/stacked_did.py +1049 -0
- diff_diff/stacked_did_results.py +339 -0
- diff_diff/staggered.py +3895 -0
- diff_diff/staggered_aggregation.py +864 -0
- diff_diff/staggered_bootstrap.py +752 -0
- diff_diff/staggered_results.py +416 -0
- diff_diff/staggered_triple_diff.py +1545 -0
- diff_diff/staggered_triple_diff_results.py +416 -0
- diff_diff/sun_abraham.py +1685 -0
- diff_diff/survey.py +1981 -0
- diff_diff/synthetic_did.py +1136 -0
- diff_diff/triple_diff.py +2047 -0
- diff_diff/trop.py +952 -0
- diff_diff/trop_global.py +1270 -0
- diff_diff/trop_local.py +1307 -0
- diff_diff/trop_results.py +356 -0
- diff_diff/twfe.py +542 -0
- diff_diff/two_stage.py +1952 -0
- diff_diff/two_stage_bootstrap.py +520 -0
- diff_diff/two_stage_results.py +400 -0
- diff_diff/utils.py +1902 -0
- diff_diff/visualization/__init__.py +61 -0
- diff_diff/visualization/_common.py +328 -0
- diff_diff/visualization/_continuous.py +274 -0
- diff_diff/visualization/_diagnostic.py +817 -0
- diff_diff/visualization/_event_study.py +1086 -0
- diff_diff/visualization/_power.py +661 -0
- diff_diff/visualization/_staggered.py +833 -0
- diff_diff/visualization/_synthetic.py +197 -0
- diff_diff/wooldridge.py +1285 -0
- diff_diff/wooldridge_results.py +349 -0
- diff_diff-3.0.1.dist-info/METADATA +2997 -0
- diff_diff-3.0.1.dist-info/RECORD +62 -0
- diff_diff-3.0.1.dist-info/WHEEL +4 -0
- diff_diff-3.0.1.dist-info/sboms/diff_diff_rust.cyclonedx.json +5843 -0
|
@@ -0,0 +1,864 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Aggregation methods mixin for Callaway-Sant'Anna estimator.
|
|
3
|
+
|
|
4
|
+
This module provides the mixin class containing methods for aggregating
|
|
5
|
+
group-time average treatment effects into summary measures.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
import pandas as pd
|
|
12
|
+
|
|
13
|
+
from diff_diff.utils import safe_inference_batch
|
|
14
|
+
|
|
15
|
+
# Type alias for the pre-computed data dict passed to the aggregation methods
# (defined at module scope for runtime access, e.g. by the string annotations
# Optional["PrecomputedData"] in this module). Keys read in this module:
# "survey_weights", "unit_cohorts", "unit_to_idx", "is_panel",
# "canonical_size", "all_units", "resolved_survey_unit", "df_survey".
# NOTE(review): the dict is built outside this module — confirm the full schema there.
PrecomputedData = Dict[str, Any]
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class CallawaySantAnnaAggregationMixin:
|
|
20
|
+
"""
|
|
21
|
+
Mixin class providing aggregation methods for CallawaySantAnna estimator.
|
|
22
|
+
|
|
23
|
+
This class is not intended to be used standalone. It provides methods
|
|
24
|
+
that are used by the main CallawaySantAnna class to aggregate group-time
|
|
25
|
+
effects into summary measures.
|
|
26
|
+
"""
|
|
27
|
+
|
|
28
|
+
# Type hints for attributes accessed from the main class
|
|
29
|
+
alpha: float
|
|
30
|
+
|
|
31
|
+
# Type hint for anticipation attribute accessed from main class
|
|
32
|
+
anticipation: int
|
|
33
|
+
|
|
34
|
+
# Type hint for base_period attribute accessed from main class
|
|
35
|
+
base_period: str
|
|
36
|
+
|
|
37
|
+
def _aggregate_simple(
|
|
38
|
+
self,
|
|
39
|
+
group_time_effects: Dict,
|
|
40
|
+
influence_func_info: Dict,
|
|
41
|
+
df: pd.DataFrame,
|
|
42
|
+
unit: str,
|
|
43
|
+
precomputed: Optional["PrecomputedData"] = None,
|
|
44
|
+
) -> Tuple[float, float]:
|
|
45
|
+
"""
|
|
46
|
+
Compute simple weighted average of ATT(g,t).
|
|
47
|
+
|
|
48
|
+
Weights by group size (number of treated units).
|
|
49
|
+
|
|
50
|
+
Standard errors are computed using influence function aggregation,
|
|
51
|
+
which properly accounts for covariances across (g,t) pairs due to
|
|
52
|
+
shared control units. This includes the wif (weight influence function)
|
|
53
|
+
adjustment from R's `did` package that accounts for uncertainty in
|
|
54
|
+
estimating the group-size weights.
|
|
55
|
+
|
|
56
|
+
Note: Only post-treatment effects (t >= g - anticipation) are included
|
|
57
|
+
in the overall ATT. Pre-treatment effects are computed for parallel
|
|
58
|
+
trends assessment but are not aggregated into the overall ATT.
|
|
59
|
+
"""
|
|
60
|
+
effects = []
|
|
61
|
+
weights_list = []
|
|
62
|
+
gt_pairs = []
|
|
63
|
+
groups_for_gt = []
|
|
64
|
+
|
|
65
|
+
# For survey: compute fixed per-cohort weight sums from the full
|
|
66
|
+
# unit-level sample (matching R's did::aggte pg = n_g / N).
|
|
67
|
+
survey_cohort_weights = None
|
|
68
|
+
if precomputed is not None and precomputed.get("survey_weights") is not None:
|
|
69
|
+
sw = precomputed["survey_weights"]
|
|
70
|
+
unit_cohorts = precomputed["unit_cohorts"]
|
|
71
|
+
survey_cohort_weights = {}
|
|
72
|
+
for g in np.unique(unit_cohorts):
|
|
73
|
+
if g > 0: # exclude never-treated (0)
|
|
74
|
+
survey_cohort_weights[g] = float(np.sum(sw[unit_cohorts == g]))
|
|
75
|
+
|
|
76
|
+
for (g, t), data in group_time_effects.items():
|
|
77
|
+
# Only include post-treatment effects (t >= g - anticipation)
|
|
78
|
+
# Pre-treatment effects are for parallel trends, not overall ATT
|
|
79
|
+
if t < g - self.anticipation:
|
|
80
|
+
continue
|
|
81
|
+
effects.append(data["effect"])
|
|
82
|
+
# Use fixed cohort-level survey weight sum for aggregation.
|
|
83
|
+
# For RCS, data["agg_weight"] holds the fixed cohort mass;
|
|
84
|
+
# for panel, fallback to data["n_treated"].
|
|
85
|
+
if survey_cohort_weights is not None and g in survey_cohort_weights:
|
|
86
|
+
weights_list.append(survey_cohort_weights[g])
|
|
87
|
+
else:
|
|
88
|
+
weights_list.append(data.get("agg_weight", data["n_treated"]))
|
|
89
|
+
gt_pairs.append((g, t))
|
|
90
|
+
groups_for_gt.append(g)
|
|
91
|
+
|
|
92
|
+
# Guard against empty post-treatment set
|
|
93
|
+
if len(effects) == 0:
|
|
94
|
+
import warnings
|
|
95
|
+
|
|
96
|
+
warnings.warn(
|
|
97
|
+
"No post-treatment effects available for overall ATT aggregation. "
|
|
98
|
+
"This can occur when cohorts lack post-treatment periods in the data.",
|
|
99
|
+
UserWarning,
|
|
100
|
+
stacklevel=2,
|
|
101
|
+
)
|
|
102
|
+
return np.nan, np.nan, None
|
|
103
|
+
|
|
104
|
+
effects = np.array(effects)
|
|
105
|
+
weights = np.array(weights_list, dtype=float)
|
|
106
|
+
groups_for_gt = np.array(groups_for_gt)
|
|
107
|
+
|
|
108
|
+
# Exclude NaN effects from aggregation (R's aggte() convention).
|
|
109
|
+
# No warning here — fit() emits a consolidated skip warning covering
|
|
110
|
+
# all estimation paths (vectorized, covariate, general, RC).
|
|
111
|
+
finite_mask = np.isfinite(effects)
|
|
112
|
+
if not np.all(finite_mask):
|
|
113
|
+
effects = effects[finite_mask]
|
|
114
|
+
weights = weights[finite_mask]
|
|
115
|
+
gt_pairs = [gt for gt, m in zip(gt_pairs, finite_mask) if m]
|
|
116
|
+
groups_for_gt = groups_for_gt[finite_mask]
|
|
117
|
+
|
|
118
|
+
if len(effects) == 0:
|
|
119
|
+
import warnings
|
|
120
|
+
|
|
121
|
+
warnings.warn(
|
|
122
|
+
"All post-treatment effects are NaN. Cannot compute overall ATT.",
|
|
123
|
+
UserWarning,
|
|
124
|
+
stacklevel=2,
|
|
125
|
+
)
|
|
126
|
+
return np.nan, np.nan, None
|
|
127
|
+
|
|
128
|
+
# Normalize weights
|
|
129
|
+
total_weight = np.sum(weights)
|
|
130
|
+
weights_norm = weights / total_weight
|
|
131
|
+
|
|
132
|
+
# Weighted average
|
|
133
|
+
overall_att = np.sum(weights_norm * effects)
|
|
134
|
+
|
|
135
|
+
# Compute SE using influence function aggregation with wif adjustment
|
|
136
|
+
overall_se, effective_df = self._compute_aggregated_se_with_wif(
|
|
137
|
+
gt_pairs,
|
|
138
|
+
weights_norm,
|
|
139
|
+
effects,
|
|
140
|
+
groups_for_gt,
|
|
141
|
+
influence_func_info,
|
|
142
|
+
df,
|
|
143
|
+
unit,
|
|
144
|
+
precomputed,
|
|
145
|
+
)
|
|
146
|
+
|
|
147
|
+
return overall_att, overall_se, effective_df
|
|
148
|
+
|
|
149
|
+
def _compute_aggregated_se(
|
|
150
|
+
self,
|
|
151
|
+
gt_pairs: List[Tuple[Any, Any]],
|
|
152
|
+
weights: np.ndarray,
|
|
153
|
+
influence_func_info: Dict,
|
|
154
|
+
n_units: Optional[int] = None,
|
|
155
|
+
) -> float:
|
|
156
|
+
"""
|
|
157
|
+
Compute standard error using influence function aggregation.
|
|
158
|
+
|
|
159
|
+
This properly accounts for covariances across (g,t) pairs by
|
|
160
|
+
aggregating unit-level influence functions:
|
|
161
|
+
|
|
162
|
+
ψ_i(overall) = Σ_{(g,t)} w_(g,t) × ψ_i(g,t)
|
|
163
|
+
Var(overall) = (1/n) Σ_i [ψ_i]²
|
|
164
|
+
|
|
165
|
+
This matches R's `did` package analytical SE formula.
|
|
166
|
+
|
|
167
|
+
Parameters
|
|
168
|
+
----------
|
|
169
|
+
n_units : int, optional
|
|
170
|
+
Size of the canonical index space (len(precomputed['all_units'])).
|
|
171
|
+
When provided, influence function indices (treated_idx, control_idx)
|
|
172
|
+
index directly into this space, eliminating dict lookups.
|
|
173
|
+
"""
|
|
174
|
+
if not influence_func_info:
|
|
175
|
+
return 0.0
|
|
176
|
+
|
|
177
|
+
if n_units is None:
|
|
178
|
+
# Fallback: infer size from influence function info
|
|
179
|
+
max_idx = 0
|
|
180
|
+
for g, t in gt_pairs:
|
|
181
|
+
if (g, t) in influence_func_info:
|
|
182
|
+
info = influence_func_info[(g, t)]
|
|
183
|
+
if len(info["treated_idx"]) > 0:
|
|
184
|
+
max_idx = max(max_idx, info["treated_idx"].max())
|
|
185
|
+
if len(info["control_idx"]) > 0:
|
|
186
|
+
max_idx = max(max_idx, info["control_idx"].max())
|
|
187
|
+
n_units = max_idx + 1
|
|
188
|
+
|
|
189
|
+
if n_units == 0:
|
|
190
|
+
return 0.0
|
|
191
|
+
|
|
192
|
+
# Aggregate influence functions across (g,t) pairs
|
|
193
|
+
psi_overall = np.zeros(n_units)
|
|
194
|
+
|
|
195
|
+
for j, (g, t) in enumerate(gt_pairs):
|
|
196
|
+
if (g, t) not in influence_func_info:
|
|
197
|
+
continue
|
|
198
|
+
|
|
199
|
+
info = influence_func_info[(g, t)]
|
|
200
|
+
w = weights[j]
|
|
201
|
+
|
|
202
|
+
# Vectorized influence function aggregation using index arrays
|
|
203
|
+
treated_idx = info["treated_idx"]
|
|
204
|
+
if len(treated_idx) > 0:
|
|
205
|
+
np.add.at(psi_overall, treated_idx, w * info["treated_inf"])
|
|
206
|
+
|
|
207
|
+
control_idx = info["control_idx"]
|
|
208
|
+
if len(control_idx) > 0:
|
|
209
|
+
np.add.at(psi_overall, control_idx, w * info["control_inf"])
|
|
210
|
+
|
|
211
|
+
# Compute variance: Var(θ̄) = (1/n) Σᵢ ψᵢ²
|
|
212
|
+
variance = np.sum(psi_overall**2)
|
|
213
|
+
return np.sqrt(variance)
|
|
214
|
+
|
|
215
|
+
def _compute_combined_influence_function(
|
|
216
|
+
self,
|
|
217
|
+
gt_pairs: List[Tuple[Any, Any]],
|
|
218
|
+
weights: np.ndarray,
|
|
219
|
+
effects: np.ndarray,
|
|
220
|
+
groups_for_gt: np.ndarray,
|
|
221
|
+
influence_func_info: Dict,
|
|
222
|
+
df: pd.DataFrame,
|
|
223
|
+
unit: str,
|
|
224
|
+
precomputed: Optional["PrecomputedData"] = None,
|
|
225
|
+
global_unit_to_idx: Optional[Dict[Any, int]] = None,
|
|
226
|
+
n_global_units: Optional[int] = None,
|
|
227
|
+
) -> Tuple[np.ndarray, Optional[List]]:
|
|
228
|
+
"""
|
|
229
|
+
Compute the combined (standard IF + WIF) influence function vector.
|
|
230
|
+
|
|
231
|
+
If global_unit_to_idx / n_global_units are provided, the returned vector
|
|
232
|
+
is zero-padded to the global unit set for bootstrap alignment.
|
|
233
|
+
Otherwise, the returned vector is indexed by the local unit set
|
|
234
|
+
(all units appearing in the (g,t) pairs).
|
|
235
|
+
|
|
236
|
+
Returns
|
|
237
|
+
-------
|
|
238
|
+
combined_if : np.ndarray
|
|
239
|
+
Per-unit combined influence function (standard IF + WIF).
|
|
240
|
+
all_units : list or None
|
|
241
|
+
Ordered list of units (only when using local indexing).
|
|
242
|
+
"""
|
|
243
|
+
if not influence_func_info:
|
|
244
|
+
if n_global_units is not None:
|
|
245
|
+
return np.zeros(n_global_units), None
|
|
246
|
+
return np.zeros(0), None
|
|
247
|
+
|
|
248
|
+
# Detect RCS mode via explicit flag. In RCS, obs indices ARE array positions.
|
|
249
|
+
_is_rcs = precomputed is not None and not precomputed.get("is_panel", True)
|
|
250
|
+
|
|
251
|
+
# Build unit index mapping (local or global)
|
|
252
|
+
if _is_rcs and n_global_units is not None:
|
|
253
|
+
# RCS: direct indexing — obs indices are the array positions
|
|
254
|
+
n_units = n_global_units
|
|
255
|
+
all_units = None
|
|
256
|
+
elif global_unit_to_idx is not None and n_global_units is not None:
|
|
257
|
+
n_units = n_global_units
|
|
258
|
+
all_units = None # caller already has the unit list
|
|
259
|
+
else:
|
|
260
|
+
all_units_set: Set[Any] = set()
|
|
261
|
+
for g, t in gt_pairs:
|
|
262
|
+
if (g, t) in influence_func_info:
|
|
263
|
+
info = influence_func_info[(g, t)]
|
|
264
|
+
all_units_set.update(info["treated_units"])
|
|
265
|
+
all_units_set.update(info["control_units"])
|
|
266
|
+
|
|
267
|
+
if not all_units_set:
|
|
268
|
+
return np.zeros(0), []
|
|
269
|
+
|
|
270
|
+
all_units = sorted(all_units_set)
|
|
271
|
+
n_units = len(all_units)
|
|
272
|
+
|
|
273
|
+
# Get unique groups and their information
|
|
274
|
+
unique_groups = sorted(set(groups_for_gt))
|
|
275
|
+
unique_groups_set = set(unique_groups)
|
|
276
|
+
group_to_idx = {g: i for i, g in enumerate(unique_groups)}
|
|
277
|
+
|
|
278
|
+
# Check for survey weights in precomputed data
|
|
279
|
+
survey_w = precomputed.get("survey_weights") if precomputed is not None else None
|
|
280
|
+
|
|
281
|
+
# Compute group-level probabilities matching R's formula:
|
|
282
|
+
# pg[g] = n_g / n_all (fraction of ALL units in group g)
|
|
283
|
+
# With survey weights: pg[g] = sum(sw_g) / sum(sw_all)
|
|
284
|
+
group_sizes = {}
|
|
285
|
+
if survey_w is not None:
|
|
286
|
+
# Survey-weighted group sizes
|
|
287
|
+
precomputed_cohorts = precomputed["unit_cohorts"]
|
|
288
|
+
for g in unique_groups:
|
|
289
|
+
mask_g = precomputed_cohorts == g
|
|
290
|
+
group_sizes[g] = float(np.sum(survey_w[mask_g]))
|
|
291
|
+
total_weight = float(np.sum(survey_w))
|
|
292
|
+
elif _is_rcs:
|
|
293
|
+
# RCS without survey: count observations per cohort
|
|
294
|
+
precomputed_cohorts = precomputed["unit_cohorts"]
|
|
295
|
+
for g in unique_groups:
|
|
296
|
+
group_sizes[g] = int(np.sum(precomputed_cohorts == g))
|
|
297
|
+
total_weight = float(n_units)
|
|
298
|
+
else:
|
|
299
|
+
for g in unique_groups:
|
|
300
|
+
treated_in_g = df[df["first_treat"] == g][unit].nunique()
|
|
301
|
+
group_sizes[g] = treated_in_g
|
|
302
|
+
total_weight = float(n_units)
|
|
303
|
+
|
|
304
|
+
# pg indexed by group
|
|
305
|
+
pg_by_group = np.array([group_sizes[g] / total_weight for g in unique_groups])
|
|
306
|
+
|
|
307
|
+
# pg indexed by keeper (each (g,t) pair gets its group's pg)
|
|
308
|
+
pg_keepers = np.array([pg_by_group[group_to_idx[g]] for g in groups_for_gt])
|
|
309
|
+
sum_pg_keepers = np.sum(pg_keepers)
|
|
310
|
+
|
|
311
|
+
# Guard against zero weights (no keepers = no variance)
|
|
312
|
+
if sum_pg_keepers == 0:
|
|
313
|
+
return np.zeros(n_units), all_units
|
|
314
|
+
|
|
315
|
+
# Standard aggregated influence (without wif)
|
|
316
|
+
psi_standard = np.zeros(n_units)
|
|
317
|
+
|
|
318
|
+
for j, (g, t) in enumerate(gt_pairs):
|
|
319
|
+
if (g, t) not in influence_func_info:
|
|
320
|
+
continue
|
|
321
|
+
|
|
322
|
+
info = influence_func_info[(g, t)]
|
|
323
|
+
w = weights[j]
|
|
324
|
+
|
|
325
|
+
# Vectorized influence function aggregation using precomputed index arrays
|
|
326
|
+
treated_idx = info["treated_idx"]
|
|
327
|
+
if len(treated_idx) > 0:
|
|
328
|
+
np.add.at(psi_standard, treated_idx, w * info["treated_inf"])
|
|
329
|
+
|
|
330
|
+
control_idx = info["control_idx"]
|
|
331
|
+
if len(control_idx) > 0:
|
|
332
|
+
np.add.at(psi_standard, control_idx, w * info["control_inf"])
|
|
333
|
+
|
|
334
|
+
# Build unit-group array: normalize iterator to (idx, uid) pairs
|
|
335
|
+
unit_groups_array = np.full(n_units, -1, dtype=np.float64)
|
|
336
|
+
|
|
337
|
+
if _is_rcs:
|
|
338
|
+
# RCS: direct vectorized assignment — obs indices are positions
|
|
339
|
+
precomputed_cohorts = precomputed["unit_cohorts"]
|
|
340
|
+
for g in unique_groups:
|
|
341
|
+
mask_g = precomputed_cohorts == g
|
|
342
|
+
unit_groups_array[mask_g] = g
|
|
343
|
+
elif global_unit_to_idx is not None:
|
|
344
|
+
idx_uid_pairs = [(idx, uid) for uid, idx in global_unit_to_idx.items()]
|
|
345
|
+
|
|
346
|
+
if precomputed is not None:
|
|
347
|
+
precomputed_cohorts = precomputed["unit_cohorts"]
|
|
348
|
+
precomputed_unit_to_idx = precomputed["unit_to_idx"]
|
|
349
|
+
for idx, uid in idx_uid_pairs:
|
|
350
|
+
if uid in precomputed_unit_to_idx:
|
|
351
|
+
cohort = precomputed_cohorts[precomputed_unit_to_idx[uid]]
|
|
352
|
+
if cohort in unique_groups_set:
|
|
353
|
+
unit_groups_array[idx] = cohort
|
|
354
|
+
else:
|
|
355
|
+
for idx, uid in idx_uid_pairs:
|
|
356
|
+
unit_first_treat = df[df[unit] == uid]["first_treat"].iloc[0]
|
|
357
|
+
if unit_first_treat in unique_groups_set:
|
|
358
|
+
unit_groups_array[idx] = unit_first_treat
|
|
359
|
+
else:
|
|
360
|
+
idx_uid_pairs = list(enumerate(all_units))
|
|
361
|
+
for idx, uid in idx_uid_pairs:
|
|
362
|
+
unit_first_treat = df[df[unit] == uid]["first_treat"].iloc[0]
|
|
363
|
+
if unit_first_treat in unique_groups_set:
|
|
364
|
+
unit_groups_array[idx] = unit_first_treat
|
|
365
|
+
|
|
366
|
+
# Vectorized WIF computation
|
|
367
|
+
groups_for_gt_array = np.array(groups_for_gt)
|
|
368
|
+
indicator_matrix = (
|
|
369
|
+
unit_groups_array[:, np.newaxis] == groups_for_gt_array[np.newaxis, :]
|
|
370
|
+
).astype(np.float64)
|
|
371
|
+
|
|
372
|
+
if survey_w is not None:
|
|
373
|
+
# Survey-weighted WIF matching R's did::wif() / compute.aggte.R.
|
|
374
|
+
# pg_k = E[w_i * 1{G_i=g}] is the weighted group share.
|
|
375
|
+
# IF_i(p_g) = (w_i * 1{G_i=g} - pg_k), NOT s_i * (1{G_i=g} - pg_k).
|
|
376
|
+
# The pg subtraction is NOT weighted by s_i because pg is already
|
|
377
|
+
# the population-level expected value of w_i * 1{G_i=g}.
|
|
378
|
+
if _is_rcs and precomputed is not None:
|
|
379
|
+
# RCS: survey weights are already per-observation, direct indexing
|
|
380
|
+
unit_sw = survey_w
|
|
381
|
+
elif global_unit_to_idx is not None and precomputed is not None:
|
|
382
|
+
unit_sw = np.zeros(n_units)
|
|
383
|
+
precomputed_unit_to_idx_local = precomputed["unit_to_idx"]
|
|
384
|
+
idx_uid_pairs_sw = [(idx, uid) for uid, idx in global_unit_to_idx.items()]
|
|
385
|
+
for idx, uid in idx_uid_pairs_sw:
|
|
386
|
+
if uid in precomputed_unit_to_idx_local:
|
|
387
|
+
pc_idx = precomputed_unit_to_idx_local[uid]
|
|
388
|
+
unit_sw[idx] = survey_w[pc_idx]
|
|
389
|
+
else:
|
|
390
|
+
unit_sw = np.ones(n_units)
|
|
391
|
+
|
|
392
|
+
# w_i * 1{G_i == g_k} - pg_k (matches R's did::wif)
|
|
393
|
+
weighted_indicator = indicator_matrix * unit_sw[:, np.newaxis]
|
|
394
|
+
indicator_diff = weighted_indicator - pg_keepers
|
|
395
|
+
indicator_sum_w = np.sum(indicator_diff, axis=1)
|
|
396
|
+
|
|
397
|
+
with np.errstate(divide="ignore", invalid="ignore", over="ignore"):
|
|
398
|
+
if1_matrix = indicator_diff / sum_pg_keepers
|
|
399
|
+
if2_matrix = np.outer(indicator_sum_w, pg_keepers) / (sum_pg_keepers**2)
|
|
400
|
+
wif_matrix = if1_matrix - if2_matrix
|
|
401
|
+
wif_contrib = wif_matrix @ effects
|
|
402
|
+
else:
|
|
403
|
+
indicator_sum = np.sum(indicator_matrix - pg_keepers, axis=1)
|
|
404
|
+
|
|
405
|
+
with np.errstate(divide="ignore", invalid="ignore", over="ignore"):
|
|
406
|
+
if1_matrix = (indicator_matrix - pg_keepers) / sum_pg_keepers
|
|
407
|
+
if2_matrix = np.outer(indicator_sum, pg_keepers) / (sum_pg_keepers**2)
|
|
408
|
+
wif_matrix = if1_matrix - if2_matrix
|
|
409
|
+
wif_contrib = wif_matrix @ effects
|
|
410
|
+
|
|
411
|
+
# Check for non-finite values from edge cases
|
|
412
|
+
if not np.all(np.isfinite(wif_contrib)):
|
|
413
|
+
import warnings
|
|
414
|
+
|
|
415
|
+
n_nonfinite = np.sum(~np.isfinite(wif_contrib))
|
|
416
|
+
warnings.warn(
|
|
417
|
+
f"Non-finite values ({n_nonfinite}/{len(wif_contrib)}) in weight influence "
|
|
418
|
+
"function computation. This may occur with very small samples or extreme "
|
|
419
|
+
"weights. Returning NaN for SE to signal invalid inference.",
|
|
420
|
+
RuntimeWarning,
|
|
421
|
+
stacklevel=2,
|
|
422
|
+
)
|
|
423
|
+
nan_result = np.full(n_units, np.nan)
|
|
424
|
+
return nan_result, all_units
|
|
425
|
+
|
|
426
|
+
# Scale by 1/total_weight to match R's getSE formula
|
|
427
|
+
# (for non-survey, total_weight == n_units; for survey, total_weight == sum(sw))
|
|
428
|
+
psi_wif = wif_contrib / total_weight
|
|
429
|
+
|
|
430
|
+
# Combine standard and wif terms
|
|
431
|
+
psi_total = psi_standard + psi_wif
|
|
432
|
+
|
|
433
|
+
return psi_total, all_units
|
|
434
|
+
|
|
435
|
+
def _compute_aggregated_se_with_wif(
|
|
436
|
+
self,
|
|
437
|
+
gt_pairs: List[Tuple[Any, Any]],
|
|
438
|
+
weights: np.ndarray,
|
|
439
|
+
effects: np.ndarray,
|
|
440
|
+
groups_for_gt: np.ndarray,
|
|
441
|
+
influence_func_info: Dict,
|
|
442
|
+
df: pd.DataFrame,
|
|
443
|
+
unit: str,
|
|
444
|
+
precomputed: Optional["PrecomputedData"] = None,
|
|
445
|
+
return_psi: bool = False,
|
|
446
|
+
) -> "Union[float, Tuple[float, np.ndarray]]":
|
|
447
|
+
"""
|
|
448
|
+
Compute SE with weight influence function (wif) adjustment.
|
|
449
|
+
|
|
450
|
+
This matches R's `did` package approach for aggregation,
|
|
451
|
+
which accounts for uncertainty in estimating group-size weights.
|
|
452
|
+
|
|
453
|
+
When a full survey design (strata/PSU/FPC) is available in
|
|
454
|
+
``precomputed['resolved_survey']``, the design-based variance
|
|
455
|
+
:func:`compute_survey_if_variance` is used instead of the simple
|
|
456
|
+
``sum(psi^2)`` formula.
|
|
457
|
+
|
|
458
|
+
Formula (matching R's did::aggte):
|
|
459
|
+
agg_inf_i = Σ_k w_k × inf_i_k + wif_i × ATT_k
|
|
460
|
+
se = sqrt(mean(agg_inf^2) / n)
|
|
461
|
+
"""
|
|
462
|
+
# Extract global unit info for correct pg = n_g / N_total scaling.
|
|
463
|
+
# Without this, the local path builds the unit set from only units in
|
|
464
|
+
# the selected (g,t) pairs, causing pg overestimation at extreme event
|
|
465
|
+
# times where only early-adopter groups have data.
|
|
466
|
+
global_unit_to_idx = None
|
|
467
|
+
n_global_units = None
|
|
468
|
+
if precomputed is not None:
|
|
469
|
+
global_unit_to_idx = precomputed["unit_to_idx"] # None for RCS
|
|
470
|
+
n_global_units = precomputed.get(
|
|
471
|
+
"canonical_size", len(precomputed.get("all_units", []))
|
|
472
|
+
)
|
|
473
|
+
elif df is not None and unit is not None:
|
|
474
|
+
n_global_units = df[unit].nunique()
|
|
475
|
+
|
|
476
|
+
psi_total, _ = self._compute_combined_influence_function(
|
|
477
|
+
gt_pairs,
|
|
478
|
+
weights,
|
|
479
|
+
effects,
|
|
480
|
+
groups_for_gt,
|
|
481
|
+
influence_func_info,
|
|
482
|
+
df,
|
|
483
|
+
unit,
|
|
484
|
+
precomputed,
|
|
485
|
+
global_unit_to_idx=global_unit_to_idx,
|
|
486
|
+
n_global_units=n_global_units,
|
|
487
|
+
)
|
|
488
|
+
|
|
489
|
+
if len(psi_total) == 0:
|
|
490
|
+
return (0.0, psi_total) if return_psi else 0.0
|
|
491
|
+
|
|
492
|
+
# Check for NaN propagation from non-finite WIF
|
|
493
|
+
if not np.all(np.isfinite(psi_total)):
|
|
494
|
+
return (np.nan, psi_total) if return_psi else np.nan
|
|
495
|
+
|
|
496
|
+
# Use design-based variance when full survey design is available
|
|
497
|
+
# Use unit-level resolved survey (panel IF is indexed by unit, not obs)
|
|
498
|
+
resolved_survey = (
|
|
499
|
+
precomputed.get("resolved_survey_unit") if precomputed is not None else None
|
|
500
|
+
)
|
|
501
|
+
if (
|
|
502
|
+
resolved_survey is not None
|
|
503
|
+
and hasattr(resolved_survey, "uses_replicate_variance")
|
|
504
|
+
and resolved_survey.uses_replicate_variance
|
|
505
|
+
):
|
|
506
|
+
from diff_diff.survey import compute_replicate_if_variance
|
|
507
|
+
|
|
508
|
+
variance, n_valid_rep = compute_replicate_if_variance(psi_total, resolved_survey)
|
|
509
|
+
# Compute effective df for this statistic (don't mutate shared state)
|
|
510
|
+
effective_df = None
|
|
511
|
+
if n_valid_rep < resolved_survey.n_replicates:
|
|
512
|
+
effective_df = n_valid_rep - 1 if n_valid_rep > 1 else 0
|
|
513
|
+
if np.isnan(variance):
|
|
514
|
+
se = np.nan
|
|
515
|
+
else:
|
|
516
|
+
se = np.sqrt(max(variance, 0.0))
|
|
517
|
+
if return_psi:
|
|
518
|
+
return (se, psi_total, effective_df)
|
|
519
|
+
return (se, effective_df)
|
|
520
|
+
|
|
521
|
+
if resolved_survey is not None and (
|
|
522
|
+
resolved_survey.strata is not None
|
|
523
|
+
or resolved_survey.psu is not None
|
|
524
|
+
or resolved_survey.fpc is not None
|
|
525
|
+
):
|
|
526
|
+
from diff_diff.survey import compute_survey_if_variance
|
|
527
|
+
|
|
528
|
+
variance = compute_survey_if_variance(psi_total, resolved_survey)
|
|
529
|
+
if np.isnan(variance):
|
|
530
|
+
se = np.nan
|
|
531
|
+
else:
|
|
532
|
+
se = np.sqrt(max(variance, 0.0))
|
|
533
|
+
if return_psi:
|
|
534
|
+
return (se, psi_total, None)
|
|
535
|
+
return (se, None)
|
|
536
|
+
|
|
537
|
+
variance = np.sum(psi_total**2)
|
|
538
|
+
se = np.sqrt(variance)
|
|
539
|
+
if return_psi:
|
|
540
|
+
return (se, psi_total, None)
|
|
541
|
+
return (se, None)
|
|
542
|
+
|
|
543
|
+
def _aggregate_event_study(
|
|
544
|
+
self,
|
|
545
|
+
group_time_effects: Dict,
|
|
546
|
+
influence_func_info: Dict,
|
|
547
|
+
groups: List[Any],
|
|
548
|
+
time_periods: List[Any],
|
|
549
|
+
balance_e: Optional[int] = None,
|
|
550
|
+
df: Optional[pd.DataFrame] = None,
|
|
551
|
+
unit: Optional[str] = None,
|
|
552
|
+
precomputed: Optional["PrecomputedData"] = None,
|
|
553
|
+
) -> Dict[int, Dict[str, Any]]:
|
|
554
|
+
"""
|
|
555
|
+
Aggregate effects by relative time (event study).
|
|
556
|
+
|
|
557
|
+
Computes average effect at each event time e = t - g.
|
|
558
|
+
|
|
559
|
+
Standard errors include the weight influence function (WIF)
|
|
560
|
+
adjustment that accounts for uncertainty in group-size weights,
|
|
561
|
+
matching R's did::aggte(..., type="dynamic").
|
|
562
|
+
"""
|
|
563
|
+
# Organize effects by relative time, keeping track of (g,t) pairs
|
|
564
|
+
effects_by_e: Dict[int, List[Tuple[Tuple[Any, Any], float, float]]] = {}
|
|
565
|
+
|
|
566
|
+
# Fixed per-cohort survey weights for aggregation
|
|
567
|
+
survey_cohort_weights = None
|
|
568
|
+
if precomputed is not None and precomputed.get("survey_weights") is not None:
|
|
569
|
+
sw = precomputed["survey_weights"]
|
|
570
|
+
unit_cohorts = precomputed["unit_cohorts"]
|
|
571
|
+
survey_cohort_weights = {}
|
|
572
|
+
for g in np.unique(unit_cohorts):
|
|
573
|
+
if g > 0:
|
|
574
|
+
survey_cohort_weights[g] = float(np.sum(sw[unit_cohorts == g]))
|
|
575
|
+
|
|
576
|
+
for (g, t), data in group_time_effects.items():
|
|
577
|
+
e = t - g # Relative time
|
|
578
|
+
if e not in effects_by_e:
|
|
579
|
+
effects_by_e[e] = []
|
|
580
|
+
# For RCS, data["agg_weight"] holds the fixed cohort mass;
|
|
581
|
+
# for panel, fallback to data["n_treated"].
|
|
582
|
+
w = (
|
|
583
|
+
survey_cohort_weights[g]
|
|
584
|
+
if survey_cohort_weights is not None and g in survey_cohort_weights
|
|
585
|
+
else data.get("agg_weight", data["n_treated"])
|
|
586
|
+
)
|
|
587
|
+
effects_by_e[e].append(
|
|
588
|
+
(
|
|
589
|
+
(g, t), # Keep track of the (g,t) pair
|
|
590
|
+
data["effect"],
|
|
591
|
+
w,
|
|
592
|
+
)
|
|
593
|
+
)
|
|
594
|
+
|
|
595
|
+
# Balance the panel if requested
|
|
596
|
+
if balance_e is not None:
|
|
597
|
+
# Keep only groups that have effects at relative time balance_e
|
|
598
|
+
groups_at_e = set()
|
|
599
|
+
for (g, t), data in group_time_effects.items():
|
|
600
|
+
if t - g == balance_e and np.isfinite(data["effect"]):
|
|
601
|
+
groups_at_e.add(g)
|
|
602
|
+
|
|
603
|
+
# Filter effects to only include balanced groups
|
|
604
|
+
balanced_effects: Dict[int, List[Tuple[Tuple[Any, Any], float, float]]] = {}
|
|
605
|
+
for (g, t), data in group_time_effects.items():
|
|
606
|
+
if g in groups_at_e:
|
|
607
|
+
e = t - g
|
|
608
|
+
if e not in balanced_effects:
|
|
609
|
+
balanced_effects[e] = []
|
|
610
|
+
w = (
|
|
611
|
+
survey_cohort_weights[g]
|
|
612
|
+
if survey_cohort_weights is not None and g in survey_cohort_weights
|
|
613
|
+
else data.get("agg_weight", data["n_treated"])
|
|
614
|
+
)
|
|
615
|
+
balanced_effects[e].append(
|
|
616
|
+
(
|
|
617
|
+
(g, t),
|
|
618
|
+
data["effect"],
|
|
619
|
+
w,
|
|
620
|
+
)
|
|
621
|
+
)
|
|
622
|
+
effects_by_e = balanced_effects
|
|
623
|
+
|
|
624
|
+
# Compute aggregated effects and SEs for all relative periods
|
|
625
|
+
sorted_periods = sorted(effects_by_e.items())
|
|
626
|
+
agg_effects_list = []
|
|
627
|
+
agg_ses_list = []
|
|
628
|
+
agg_n_groups = []
|
|
629
|
+
agg_effective_dfs = [] # Per-horizon effective df (replicate designs)
|
|
630
|
+
_psi_vectors = [] # Per-event-time combined IF vectors for VCV
|
|
631
|
+
_psi_event_times = [] # Event times that contributed a psi column
|
|
632
|
+
for e, effect_list in sorted_periods:
|
|
633
|
+
gt_pairs = [x[0] for x in effect_list]
|
|
634
|
+
effs = np.array([x[1] for x in effect_list])
|
|
635
|
+
ns = np.array([x[2] for x in effect_list], dtype=float)
|
|
636
|
+
|
|
637
|
+
# Exclude NaN effects from this period's aggregation
|
|
638
|
+
finite_mask = np.isfinite(effs)
|
|
639
|
+
if not np.all(finite_mask):
|
|
640
|
+
effs = effs[finite_mask]
|
|
641
|
+
ns = ns[finite_mask]
|
|
642
|
+
gt_pairs = [gt for gt, m in zip(gt_pairs, finite_mask) if m]
|
|
643
|
+
if len(effs) == 0:
|
|
644
|
+
agg_effects_list.append(np.nan)
|
|
645
|
+
agg_ses_list.append(np.nan)
|
|
646
|
+
agg_n_groups.append(0)
|
|
647
|
+
agg_effective_dfs.append(None)
|
|
648
|
+
continue
|
|
649
|
+
|
|
650
|
+
weights = ns / np.sum(ns)
|
|
651
|
+
agg_effect = np.sum(weights * effs)
|
|
652
|
+
|
|
653
|
+
# Compute SE with WIF adjustment (matching R's did::aggte)
|
|
654
|
+
groups_for_gt = np.array([g for (g, t) in gt_pairs])
|
|
655
|
+
agg_se, psi_e, eff_df = self._compute_aggregated_se_with_wif(
|
|
656
|
+
gt_pairs,
|
|
657
|
+
weights,
|
|
658
|
+
effs,
|
|
659
|
+
groups_for_gt,
|
|
660
|
+
influence_func_info,
|
|
661
|
+
df,
|
|
662
|
+
unit,
|
|
663
|
+
precomputed,
|
|
664
|
+
return_psi=True,
|
|
665
|
+
)
|
|
666
|
+
|
|
667
|
+
agg_effects_list.append(agg_effect)
|
|
668
|
+
agg_ses_list.append(agg_se)
|
|
669
|
+
agg_n_groups.append(len(effect_list))
|
|
670
|
+
agg_effective_dfs.append(eff_df)
|
|
671
|
+
_psi_vectors.append(psi_e)
|
|
672
|
+
_psi_event_times.append(e)
|
|
673
|
+
|
|
674
|
+
# Batch inference for all relative periods
|
|
675
|
+
if not agg_effects_list:
|
|
676
|
+
return {}
|
|
677
|
+
# Use per-horizon effective df if any replicate aggregation overrode it;
|
|
678
|
+
# otherwise fall back to the original df from the survey design.
|
|
679
|
+
df_survey_val = precomputed.get("df_survey") if precomputed is not None else None
|
|
680
|
+
# Guard: replicate design with undefined df → NaN inference
|
|
681
|
+
if (
|
|
682
|
+
df_survey_val is None
|
|
683
|
+
and precomputed is not None
|
|
684
|
+
and precomputed.get("resolved_survey_unit") is not None
|
|
685
|
+
and hasattr(precomputed["resolved_survey_unit"], "uses_replicate_variance")
|
|
686
|
+
and precomputed["resolved_survey_unit"].uses_replicate_variance
|
|
687
|
+
):
|
|
688
|
+
df_survey_val = 0
|
|
689
|
+
# If any horizon has a per-statistic effective df (dropped replicates),
|
|
690
|
+
# use the minimum across horizons for conservative batch inference.
|
|
691
|
+
non_none_dfs = [d for d in agg_effective_dfs if d is not None]
|
|
692
|
+
if non_none_dfs:
|
|
693
|
+
df_survey_val = min(non_none_dfs)
|
|
694
|
+
t_stats, p_values, ci_lowers, ci_uppers = safe_inference_batch(
|
|
695
|
+
np.array(agg_effects_list),
|
|
696
|
+
np.array(agg_ses_list),
|
|
697
|
+
alpha=self.alpha,
|
|
698
|
+
df=df_survey_val,
|
|
699
|
+
)
|
|
700
|
+
|
|
701
|
+
event_study_effects = {}
|
|
702
|
+
for idx, (e, _) in enumerate(sorted_periods):
|
|
703
|
+
event_study_effects[e] = {
|
|
704
|
+
"effect": agg_effects_list[idx],
|
|
705
|
+
"se": agg_ses_list[idx],
|
|
706
|
+
"t_stat": float(t_stats[idx]),
|
|
707
|
+
"p_value": float(p_values[idx]),
|
|
708
|
+
"conf_int": (float(ci_lowers[idx]), float(ci_uppers[idx])),
|
|
709
|
+
"n_groups": agg_n_groups[idx],
|
|
710
|
+
}
|
|
711
|
+
|
|
712
|
+
# Add reference period for universal base period mode (matches R did package)
|
|
713
|
+
if getattr(self, "base_period", "varying") == "universal":
|
|
714
|
+
ref_period = -1 - self.anticipation
|
|
715
|
+
if event_study_effects and ref_period not in event_study_effects:
|
|
716
|
+
event_study_effects[ref_period] = {
|
|
717
|
+
"effect": 0.0,
|
|
718
|
+
"se": np.nan,
|
|
719
|
+
"t_stat": np.nan,
|
|
720
|
+
"p_value": np.nan,
|
|
721
|
+
"conf_int": (np.nan, np.nan),
|
|
722
|
+
"n_groups": 0,
|
|
723
|
+
}
|
|
724
|
+
|
|
725
|
+
# Compute full event-study VCV from per-event-time IF vectors (Phase 7d)
|
|
726
|
+
# This enables HonestDiD to use the full covariance structure
|
|
727
|
+
event_study_vcov = None
|
|
728
|
+
valid_psi = [p for p in _psi_vectors if len(p) > 0]
|
|
729
|
+
if valid_psi:
|
|
730
|
+
try:
|
|
731
|
+
Psi = np.column_stack(valid_psi) # (n_units, n_event_times)
|
|
732
|
+
resolved_survey = (
|
|
733
|
+
precomputed.get("resolved_survey_unit") if precomputed is not None else None
|
|
734
|
+
)
|
|
735
|
+
if (
|
|
736
|
+
resolved_survey is not None
|
|
737
|
+
and not (
|
|
738
|
+
hasattr(resolved_survey, "uses_replicate_variance")
|
|
739
|
+
and resolved_survey.uses_replicate_variance
|
|
740
|
+
)
|
|
741
|
+
and (
|
|
742
|
+
resolved_survey.strata is not None
|
|
743
|
+
or resolved_survey.psu is not None
|
|
744
|
+
or resolved_survey.fpc is not None
|
|
745
|
+
)
|
|
746
|
+
):
|
|
747
|
+
from diff_diff.survey import _compute_stratified_psu_meat
|
|
748
|
+
|
|
749
|
+
meat, _, _ = _compute_stratified_psu_meat(Psi, resolved_survey)
|
|
750
|
+
event_study_vcov = meat
|
|
751
|
+
elif (
|
|
752
|
+
resolved_survey is not None
|
|
753
|
+
and hasattr(resolved_survey, "uses_replicate_variance")
|
|
754
|
+
and resolved_survey.uses_replicate_variance
|
|
755
|
+
):
|
|
756
|
+
# Replicate-weight: fall back to None (diagonal in HonestDiD)
|
|
757
|
+
# until multivariate replicate VCV is implemented
|
|
758
|
+
event_study_vcov = None
|
|
759
|
+
else:
|
|
760
|
+
# No survey: simple sum-of-outer-products
|
|
761
|
+
event_study_vcov = Psi.T @ Psi
|
|
762
|
+
except (ValueError, np.linalg.LinAlgError):
|
|
763
|
+
pass # Fall back to diagonal (None)
|
|
764
|
+
|
|
765
|
+
# Store the event-time index that matches VCV columns (for subsetting
|
|
766
|
+
# in HonestDiD when some event times are filtered out)
|
|
767
|
+
self._event_study_vcov_index = _psi_event_times if event_study_vcov is not None else None
|
|
768
|
+
|
|
769
|
+
# Attach VCV to self for CallawaySantAnna to pick up
|
|
770
|
+
self._event_study_vcov = event_study_vcov
|
|
771
|
+
|
|
772
|
+
return event_study_effects
|
|
773
|
+
|
|
774
|
+
def _aggregate_by_group(
    self,
    group_time_effects: Dict,
    influence_func_info: Dict,
    groups: List[Any],
    precomputed: Optional["PrecomputedData"] = None,
    df: Optional[pd.DataFrame] = None,
    unit: Optional[str] = None,
) -> Dict[Any, Dict[str, Any]]:
    """
    Aggregate effects by treatment cohort.

    Computes the equally weighted average effect for each cohort across all
    of its post-treatment periods (cells with ``t >= g - anticipation``).
    Cells whose effect is non-finite are excluded from that cohort's average.

    Standard errors use influence function aggregation with WIF adjustment
    to account for covariances across time periods within a cohort.
    When a full survey design is present in precomputed, uses design-based
    variance via compute_survey_if_variance().

    Parameters
    ----------
    group_time_effects : dict
        Mapping ``(g, t) -> dict`` whose values carry an ``"effect"`` entry.
    influence_func_info : dict
        Influence-function information, forwarded unchanged to
        ``_compute_aggregated_se_with_wif``.
    groups : list
        Treatment cohorts to aggregate; output keys follow this order.
    precomputed : PrecomputedData, optional
        Precomputed data. Its ``"df_survey"`` and ``"resolved_survey_unit"``
        entries control the degrees of freedom used for inference.
    df : pd.DataFrame, optional
        Forwarded unchanged to the SE computation.
    unit : str, optional
        Forwarded unchanged to the SE computation.

    Returns
    -------
    dict
        Mapping ``g -> {"effect", "se", "t_stat", "p_value", "conf_int",
        "n_periods"}``. ``n_periods`` is the number of finite cells that
        entered the cohort average. Cohorts without any finite
        post-treatment cell are omitted; ``{}`` is returned if none remain.
    """
    anticipation = self.anticipation

    # Bucket cells by cohort in one pass rather than rescanning every (g, t)
    # cell once per cohort. Dict insertion order is preserved, so each
    # cohort's cells keep the original dictionary order. The raw cell dict is
    # stored so "effect" is only read for cells that pass the period filter.
    cells_by_group = {}
    for (cell_group, cell_time), cell_data in group_time_effects.items():
        cells_by_group.setdefault(cell_group, []).append((cell_time, cell_data))

    # Collect all group aggregation data first.
    group_data_list = []
    for g in groups:
        gt_pairs = []
        effect_values = []
        for t, cell_data in cells_by_group.get(g, ()):
            # Post-treatment cells, including anticipation periods.
            if t >= g - anticipation:
                gt_pairs.append((g, t))
                effect_values.append(cell_data["effect"])

        if not gt_pairs:
            continue

        effs = np.array(effect_values)

        # Exclude NaN/inf effects from this group's aggregation.
        finite_mask = np.isfinite(effs)
        if not np.all(finite_mask):
            effs = effs[finite_mask]
            gt_pairs = [gt for gt, keep in zip(gt_pairs, finite_mask) if keep]
            if len(effs) == 0:
                continue

        # Count only the periods that actually enter the average, so the
        # reported n_periods agrees with the 1/n_periods weights below.
        n_periods = len(effs)
        weights = np.ones(n_periods) / n_periods
        agg_effect = np.sum(weights * effs)

        # Use WIF-adjusted SE (with survey design support).
        groups_for_gt = np.array([gg for (gg, _t) in gt_pairs])
        agg_se, eff_df = self._compute_aggregated_se_with_wif(
            gt_pairs, weights, effs, groups_for_gt, influence_func_info, df, unit, precomputed
        )
        group_data_list.append((g, agg_effect, agg_se, n_periods, eff_df))

    if not group_data_list:
        return {}

    # Batch inference across all cohorts.
    agg_effects = np.array([entry[1] for entry in group_data_list])
    agg_ses = np.array([entry[2] for entry in group_data_list])

    df_survey_val = None
    resolved_survey = None
    if precomputed is not None:
        df_survey_val = precomputed.get("df_survey")
        resolved_survey = precomputed.get("resolved_survey_unit")

    # Guard: replicate design with undefined df -> df=0, which presumably
    # makes safe_inference_batch return NaN inference.
    if (
        df_survey_val is None
        and resolved_survey is not None
        and getattr(resolved_survey, "uses_replicate_variance", False)
    ):
        df_survey_val = 0

    # Use the minimum per-group effective df when any group reports one
    # (presumably only when replicates were dropped).
    # NOTE(review): this also overrides the df=0 guard above; confirm intended.
    group_dfs = [entry[4] for entry in group_data_list if entry[4] is not None]
    if group_dfs:
        df_survey_val = min(group_dfs)

    t_stats, p_values, ci_lowers, ci_uppers = safe_inference_batch(
        agg_effects,
        agg_ses,
        alpha=self.alpha,
        df=df_survey_val,
    )

    group_effects = {}
    for idx, (g, agg_effect, agg_se, n_periods, _eff_df) in enumerate(group_data_list):
        group_effects[g] = {
            "effect": agg_effect,
            "se": agg_se,
            "t_stat": float(t_stats[idx]),
            "p_value": float(p_values[idx]),
            "conf_int": (float(ci_lowers[idx]), float(ci_uppers[idx])),
            "n_periods": n_periods,
        }

    return group_effects
|