diff-diff 2.1.5__tar.gz → 2.1.7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {diff_diff-2.1.5 → diff_diff-2.1.7}/PKG-INFO +1 -1
- {diff_diff-2.1.5 → diff_diff-2.1.7}/diff_diff/__init__.py +1 -1
- {diff_diff-2.1.5 → diff_diff-2.1.7}/diff_diff/honest_did.py +8 -1
- {diff_diff-2.1.5 → diff_diff-2.1.7}/diff_diff/pretrends.py +6 -0
- {diff_diff-2.1.5 → diff_diff-2.1.7}/diff_diff/staggered_aggregation.py +19 -0
- {diff_diff-2.1.5 → diff_diff-2.1.7}/diff_diff/staggered_bootstrap.py +6 -4
- {diff_diff-2.1.5 → diff_diff-2.1.7}/diff_diff/utils.py +6 -3
- {diff_diff-2.1.5 → diff_diff-2.1.7}/diff_diff/visualization.py +76 -22
- {diff_diff-2.1.5 → diff_diff-2.1.7}/pyproject.toml +5 -1
- {diff_diff-2.1.5 → diff_diff-2.1.7}/rust/Cargo.lock +7 -7
- {diff_diff-2.1.5 → diff_diff-2.1.7}/rust/Cargo.toml +1 -1
- {diff_diff-2.1.5 → diff_diff-2.1.7}/rust/src/bootstrap.rs +66 -12
- {diff_diff-2.1.5 → diff_diff-2.1.7}/README.md +0 -0
- {diff_diff-2.1.5 → diff_diff-2.1.7}/diff_diff/_backend.py +0 -0
- {diff_diff-2.1.5 → diff_diff-2.1.7}/diff_diff/bacon.py +0 -0
- {diff_diff-2.1.5 → diff_diff-2.1.7}/diff_diff/datasets.py +0 -0
- {diff_diff-2.1.5 → diff_diff-2.1.7}/diff_diff/diagnostics.py +0 -0
- {diff_diff-2.1.5 → diff_diff-2.1.7}/diff_diff/estimators.py +0 -0
- {diff_diff-2.1.5 → diff_diff-2.1.7}/diff_diff/linalg.py +0 -0
- {diff_diff-2.1.5 → diff_diff-2.1.7}/diff_diff/power.py +0 -0
- {diff_diff-2.1.5 → diff_diff-2.1.7}/diff_diff/prep.py +0 -0
- {diff_diff-2.1.5 → diff_diff-2.1.7}/diff_diff/prep_dgp.py +0 -0
- {diff_diff-2.1.5 → diff_diff-2.1.7}/diff_diff/results.py +0 -0
- {diff_diff-2.1.5 → diff_diff-2.1.7}/diff_diff/staggered.py +0 -0
- {diff_diff-2.1.5 → diff_diff-2.1.7}/diff_diff/staggered_results.py +0 -0
- {diff_diff-2.1.5 → diff_diff-2.1.7}/diff_diff/sun_abraham.py +0 -0
- {diff_diff-2.1.5 → diff_diff-2.1.7}/diff_diff/synthetic_did.py +0 -0
- {diff_diff-2.1.5 → diff_diff-2.1.7}/diff_diff/triple_diff.py +0 -0
- {diff_diff-2.1.5 → diff_diff-2.1.7}/diff_diff/trop.py +0 -0
- {diff_diff-2.1.5 → diff_diff-2.1.7}/diff_diff/twfe.py +0 -0
- {diff_diff-2.1.5 → diff_diff-2.1.7}/rust/src/lib.rs +0 -0
- {diff_diff-2.1.5 → diff_diff-2.1.7}/rust/src/linalg.rs +0 -0
- {diff_diff-2.1.5 → diff_diff-2.1.7}/rust/src/trop.rs +0 -0
- {diff_diff-2.1.5 → diff_diff-2.1.7}/rust/src/weights.rs +0 -0
|
@@ -584,7 +584,12 @@ def _extract_event_study_params(
|
|
|
584
584
|
)
|
|
585
585
|
|
|
586
586
|
# Extract event study effects by relative time
|
|
587
|
-
|
|
587
|
+
# Filter out normalization constraints (n_groups=0) and non-finite SEs
|
|
588
|
+
event_effects = {
|
|
589
|
+
t: data for t, data in results.event_study_effects.items()
|
|
590
|
+
if data.get('n_groups', 1) > 0
|
|
591
|
+
and np.isfinite(data.get('se', np.nan))
|
|
592
|
+
}
|
|
588
593
|
rel_times = sorted(event_effects.keys())
|
|
589
594
|
|
|
590
595
|
# Split into pre and post
|
|
@@ -1261,10 +1266,12 @@ class HonestDiD:
|
|
|
1261
1266
|
from diff_diff.staggered import CallawaySantAnnaResults
|
|
1262
1267
|
if isinstance(results, CallawaySantAnnaResults):
|
|
1263
1268
|
if results.event_study_effects:
|
|
1269
|
+
# Filter out normalization constraints (n_groups=0, e.g. reference period)
|
|
1264
1270
|
pre_effects = [
|
|
1265
1271
|
abs(results.event_study_effects[t]['effect'])
|
|
1266
1272
|
for t in results.event_study_effects
|
|
1267
1273
|
if t < 0
|
|
1274
|
+
and results.event_study_effects[t].get('n_groups', 1) > 0
|
|
1268
1275
|
]
|
|
1269
1276
|
if pre_effects:
|
|
1270
1277
|
return max(pre_effects)
|
|
@@ -656,9 +656,12 @@ class PreTrendsPower:
|
|
|
656
656
|
)
|
|
657
657
|
|
|
658
658
|
# Get pre-period effects (negative relative times)
|
|
659
|
+
# Filter out normalization constraints (n_groups=0) and non-finite SEs
|
|
659
660
|
pre_effects = {
|
|
660
661
|
t: data for t, data in results.event_study_effects.items()
|
|
661
662
|
if t < 0
|
|
663
|
+
and data.get('n_groups', 1) > 0
|
|
664
|
+
and np.isfinite(data.get('se', np.nan))
|
|
662
665
|
}
|
|
663
666
|
|
|
664
667
|
if not pre_effects:
|
|
@@ -680,9 +683,12 @@ class PreTrendsPower:
|
|
|
680
683
|
from diff_diff.sun_abraham import SunAbrahamResults
|
|
681
684
|
if isinstance(results, SunAbrahamResults):
|
|
682
685
|
# Get pre-period effects (negative relative times)
|
|
686
|
+
# Filter out normalization constraints (n_groups=0) and non-finite SEs
|
|
683
687
|
pre_effects = {
|
|
684
688
|
t: data for t, data in results.event_study_effects.items()
|
|
685
689
|
if t < 0
|
|
690
|
+
and data.get('n_groups', 1) > 0
|
|
691
|
+
and np.isfinite(data.get('se', np.nan))
|
|
686
692
|
}
|
|
687
693
|
|
|
688
694
|
if not pre_effects:
|
|
@@ -34,6 +34,9 @@ class CallawaySantAnnaAggregationMixin:
|
|
|
34
34
|
# Type hint for anticipation attribute accessed from main class
|
|
35
35
|
anticipation: int
|
|
36
36
|
|
|
37
|
+
# Type hint for base_period attribute accessed from main class
|
|
38
|
+
base_period: str
|
|
39
|
+
|
|
37
40
|
def _aggregate_simple(
|
|
38
41
|
self,
|
|
39
42
|
group_time_effects: Dict,
|
|
@@ -414,6 +417,22 @@ class CallawaySantAnnaAggregationMixin:
|
|
|
414
417
|
'n_groups': len(effect_list),
|
|
415
418
|
}
|
|
416
419
|
|
|
420
|
+
# Add reference period for universal base period mode (matches R did package)
|
|
421
|
+
# The reference period e = -1 - anticipation has effect = 0 by construction
|
|
422
|
+
# Only add if there are actual computed effects (guard against empty data)
|
|
423
|
+
if getattr(self, 'base_period', 'varying') == "universal":
|
|
424
|
+
ref_period = -1 - self.anticipation
|
|
425
|
+
# Only inject reference if we have at least one real effect
|
|
426
|
+
if event_study_effects and ref_period not in event_study_effects:
|
|
427
|
+
event_study_effects[ref_period] = {
|
|
428
|
+
'effect': 0.0,
|
|
429
|
+
'se': np.nan, # Undefined - no data, normalization constraint
|
|
430
|
+
't_stat': np.nan, # Undefined - normalization constraint
|
|
431
|
+
'p_value': np.nan,
|
|
432
|
+
'conf_int': (np.nan, np.nan), # NaN propagation for undefined inference
|
|
433
|
+
'n_groups': 0, # No groups contribute - fixed by construction
|
|
434
|
+
}
|
|
435
|
+
|
|
417
436
|
return event_study_effects
|
|
418
437
|
|
|
419
438
|
def _aggregate_by_group(
|
|
@@ -60,12 +60,13 @@ def _generate_bootstrap_weights(
|
|
|
60
60
|
|
|
61
61
|
elif weight_type == "webb":
|
|
62
62
|
# Webb's 6-point distribution (recommended for few clusters)
|
|
63
|
+
# Values: ±√(3/2), ±1, ±√(1/2) with equal probabilities (1/6 each)
|
|
64
|
+
# This matches R's did package: E[w]=0, Var(w)=1.0
|
|
63
65
|
values = np.array([
|
|
64
66
|
-np.sqrt(3 / 2), -np.sqrt(2 / 2), -np.sqrt(1 / 2),
|
|
65
67
|
np.sqrt(1 / 2), np.sqrt(2 / 2), np.sqrt(3 / 2)
|
|
66
68
|
])
|
|
67
|
-
|
|
68
|
-
return rng.choice(values, size=n_units, p=probs)
|
|
69
|
+
return rng.choice(values, size=n_units) # Equal probs (1/6 each)
|
|
69
70
|
|
|
70
71
|
else:
|
|
71
72
|
raise ValueError(
|
|
@@ -152,12 +153,13 @@ def _generate_bootstrap_weights_batch_numpy(
|
|
|
152
153
|
|
|
153
154
|
elif weight_type == "webb":
|
|
154
155
|
# Webb's 6-point distribution
|
|
156
|
+
# Values: ±√(3/2), ±1, ±√(1/2) with equal probabilities (1/6 each)
|
|
157
|
+
# This matches R's did package: E[w]=0, Var(w)=1.0
|
|
155
158
|
values = np.array([
|
|
156
159
|
-np.sqrt(3 / 2), -np.sqrt(2 / 2), -np.sqrt(1 / 2),
|
|
157
160
|
np.sqrt(1 / 2), np.sqrt(2 / 2), np.sqrt(3 / 2)
|
|
158
161
|
])
|
|
159
|
-
|
|
160
|
-
return rng.choice(values, size=(n_bootstrap, n_units), p=probs)
|
|
162
|
+
return rng.choice(values, size=(n_bootstrap, n_units)) # Equal probs (1/6 each)
|
|
161
163
|
|
|
162
164
|
else:
|
|
163
165
|
raise ValueError(
|
|
@@ -238,7 +238,7 @@ def _generate_webb_weights(n_clusters: int, rng: np.random.Generator) -> np.ndar
|
|
|
238
238
|
Generate Webb's 6-point distribution weights.
|
|
239
239
|
|
|
240
240
|
Values: {-sqrt(3/2), -sqrt(2/2), -sqrt(1/2), sqrt(1/2), sqrt(2/2), sqrt(3/2)}
|
|
241
|
-
with probabilities
|
|
241
|
+
with equal probabilities (1/6 each), giving E[w]=0 and Var(w)=1.0.
|
|
242
242
|
|
|
243
243
|
This distribution is recommended for very few clusters (G < 10) as it
|
|
244
244
|
provides better finite-sample properties than Rademacher weights.
|
|
@@ -259,13 +259,16 @@ def _generate_webb_weights(n_clusters: int, rng: np.random.Generator) -> np.ndar
|
|
|
259
259
|
----------
|
|
260
260
|
Webb, M. D. (2014). Reworking wild bootstrap based inference for
|
|
261
261
|
clustered errors. Queen's Economics Department Working Paper No. 1315.
|
|
262
|
+
|
|
263
|
+
Note: Uses equal probabilities (1/6 each) matching R's `did` package,
|
|
264
|
+
which gives unit variance for consistency with other weight distributions.
|
|
262
265
|
"""
|
|
263
266
|
values = np.array([
|
|
264
267
|
-np.sqrt(3 / 2), -np.sqrt(2 / 2), -np.sqrt(1 / 2),
|
|
265
268
|
np.sqrt(1 / 2), np.sqrt(2 / 2), np.sqrt(3 / 2)
|
|
266
269
|
])
|
|
267
|
-
|
|
268
|
-
return np.asarray(rng.choice(values, size=n_clusters
|
|
270
|
+
# Equal probabilities (1/6 each) matching R's did package, giving Var(w) = 1.0
|
|
271
|
+
return np.asarray(rng.choice(values, size=n_clusters))
|
|
269
272
|
|
|
270
273
|
|
|
271
274
|
def _generate_mammen_weights(n_clusters: int, rng: np.random.Generator) -> np.ndarray:
|
|
@@ -73,8 +73,10 @@ def plot_event_study(
|
|
|
73
73
|
periods : list, optional
|
|
74
74
|
List of periods to plot. If None, uses all periods from results.
|
|
75
75
|
reference_period : any, optional
|
|
76
|
-
The reference period
|
|
77
|
-
|
|
76
|
+
The reference period to highlight. When explicitly provided, effects
|
|
77
|
+
are normalized (ref effect subtracted) and ref SE is set to NaN.
|
|
78
|
+
When None, the reference period is inferred from results where possible;
|
|
79
|
+
an inferred reference only gets hollow marker styling (no normalization).
|
|
78
80
|
pre_periods : list, optional
|
|
79
81
|
List of pre-treatment periods. Used for shading.
|
|
80
82
|
post_periods : list, optional
|
|
@@ -151,8 +153,9 @@ def plot_event_study(
|
|
|
151
153
|
trends holds. Large pre-treatment effects suggest the assumption may
|
|
152
154
|
be violated.
|
|
153
155
|
|
|
154
|
-
2. **Reference period**: Usually the last pre-treatment period (t=-1)
|
|
155
|
-
|
|
156
|
+
2. **Reference period**: Usually the last pre-treatment period (t=-1).
|
|
157
|
+
When explicitly specified via ``reference_period``, effects are normalized
|
|
158
|
+
to zero at this period. When auto-inferred, shown with hollow marker only.
|
|
156
159
|
|
|
157
160
|
3. **Post-treatment periods**: The treatment effects of interest. These
|
|
158
161
|
show how the outcome evolved after treatment.
|
|
@@ -170,10 +173,18 @@ def plot_event_study(
|
|
|
170
173
|
|
|
171
174
|
from scipy import stats as scipy_stats
|
|
172
175
|
|
|
176
|
+
# Track if reference_period was explicitly provided by user
|
|
177
|
+
reference_period_explicit = reference_period is not None
|
|
178
|
+
|
|
173
179
|
# Extract data from results if provided
|
|
174
180
|
if results is not None:
|
|
175
|
-
|
|
176
|
-
|
|
181
|
+
extracted = _extract_plot_data(
|
|
182
|
+
results, periods, pre_periods, post_periods, reference_period
|
|
183
|
+
)
|
|
184
|
+
effects, se, periods, pre_periods, post_periods, reference_period, reference_inferred = extracted
|
|
185
|
+
# If reference was inferred from results, it was NOT explicitly provided
|
|
186
|
+
if reference_inferred:
|
|
187
|
+
reference_period_explicit = False
|
|
177
188
|
elif effects is None or se is None:
|
|
178
189
|
raise ValueError(
|
|
179
190
|
"Must provide either 'results' or both 'effects' and 'se'"
|
|
@@ -192,16 +203,35 @@ def plot_event_study(
|
|
|
192
203
|
# Compute confidence intervals
|
|
193
204
|
critical_value = scipy_stats.norm.ppf(1 - alpha / 2)
|
|
194
205
|
|
|
206
|
+
# Normalize effects to reference period ONLY if explicitly specified by user
|
|
207
|
+
# Auto-inferred reference periods (from CallawaySantAnna) just get hollow marker styling,
|
|
208
|
+
# NO normalization. This prevents unintended normalization when the reference period
|
|
209
|
+
# isn't a true identifying constraint (e.g., CallawaySantAnna with base_period="varying").
|
|
210
|
+
if (reference_period is not None and reference_period in effects and
|
|
211
|
+
reference_period_explicit):
|
|
212
|
+
ref_effect = effects[reference_period]
|
|
213
|
+
if np.isfinite(ref_effect):
|
|
214
|
+
effects = {p: e - ref_effect for p, e in effects.items()}
|
|
215
|
+
# Set reference SE to NaN (it's now a constraint, not an estimate)
|
|
216
|
+
# This follows fixest convention where the omitted category has no SE/CI
|
|
217
|
+
se = {p: (np.nan if p == reference_period else s) for p, s in se.items()}
|
|
218
|
+
|
|
195
219
|
plot_data = []
|
|
196
220
|
for period in periods:
|
|
197
221
|
effect = effects.get(period, np.nan)
|
|
198
222
|
std_err = se.get(period, np.nan)
|
|
199
223
|
|
|
200
|
-
|
|
224
|
+
# Skip entries with NaN effect, but allow NaN SE (will plot without error bars)
|
|
225
|
+
if np.isnan(effect):
|
|
201
226
|
continue
|
|
202
227
|
|
|
203
|
-
|
|
204
|
-
|
|
228
|
+
# Compute CI only if SE is finite
|
|
229
|
+
if np.isfinite(std_err):
|
|
230
|
+
ci_lower = effect - critical_value * std_err
|
|
231
|
+
ci_upper = effect + critical_value * std_err
|
|
232
|
+
else:
|
|
233
|
+
ci_lower = np.nan
|
|
234
|
+
ci_upper = np.nan
|
|
205
235
|
|
|
206
236
|
plot_data.append({
|
|
207
237
|
'period': period,
|
|
@@ -244,13 +274,20 @@ def plot_event_study(
|
|
|
244
274
|
ref_x = period_to_x[reference_period]
|
|
245
275
|
ax.axvline(x=ref_x, color='gray', linestyle=':', linewidth=1, zorder=1)
|
|
246
276
|
|
|
247
|
-
# Plot error bars
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
277
|
+
# Plot error bars (only for entries with finite CI)
|
|
278
|
+
has_ci = df['ci_lower'].notna() & df['ci_upper'].notna()
|
|
279
|
+
if has_ci.any():
|
|
280
|
+
df_with_ci = df[has_ci]
|
|
281
|
+
x_with_ci = [period_to_x[p] for p in df_with_ci['period']]
|
|
282
|
+
yerr = [
|
|
283
|
+
df_with_ci['effect'] - df_with_ci['ci_lower'],
|
|
284
|
+
df_with_ci['ci_upper'] - df_with_ci['effect']
|
|
285
|
+
]
|
|
286
|
+
ax.errorbar(
|
|
287
|
+
x_with_ci, df_with_ci['effect'], yerr=yerr,
|
|
288
|
+
fmt='none', color=color, capsize=capsize, linewidth=linewidth,
|
|
289
|
+
capthick=linewidth, zorder=2
|
|
290
|
+
)
|
|
254
291
|
|
|
255
292
|
# Plot point estimates
|
|
256
293
|
for i, row in df.iterrows():
|
|
@@ -291,14 +328,17 @@ def _extract_plot_data(
|
|
|
291
328
|
pre_periods: Optional[List[Any]],
|
|
292
329
|
post_periods: Optional[List[Any]],
|
|
293
330
|
reference_period: Optional[Any],
|
|
294
|
-
) -> Tuple[Dict, Dict, List, List, List, Any]:
|
|
331
|
+
) -> Tuple[Dict, Dict, List, List, List, Any, bool]:
|
|
295
332
|
"""
|
|
296
333
|
Extract plotting data from various result types.
|
|
297
334
|
|
|
298
335
|
Returns
|
|
299
336
|
-------
|
|
300
337
|
tuple
|
|
301
|
-
(effects, se, periods, pre_periods, post_periods, reference_period)
|
|
338
|
+
(effects, se, periods, pre_periods, post_periods, reference_period, reference_inferred)
|
|
339
|
+
|
|
340
|
+
reference_inferred is True if reference_period was auto-detected from results
|
|
341
|
+
rather than explicitly provided by the user.
|
|
302
342
|
"""
|
|
303
343
|
# Handle DataFrame input
|
|
304
344
|
if isinstance(results, pd.DataFrame):
|
|
@@ -315,7 +355,8 @@ def _extract_plot_data(
|
|
|
315
355
|
if periods is None:
|
|
316
356
|
periods = list(results['period'])
|
|
317
357
|
|
|
318
|
-
|
|
358
|
+
# DataFrame input: reference_period was already set by caller, never inferred here
|
|
359
|
+
return effects, se, periods, pre_periods, post_periods, reference_period, False
|
|
319
360
|
|
|
320
361
|
# Handle MultiPeriodDiDResults
|
|
321
362
|
if hasattr(results, 'period_effects'):
|
|
@@ -335,7 +376,8 @@ def _extract_plot_data(
|
|
|
335
376
|
if periods is None:
|
|
336
377
|
periods = post_periods
|
|
337
378
|
|
|
338
|
-
|
|
379
|
+
# MultiPeriodDiDResults: reference_period was already set by caller, never inferred here
|
|
380
|
+
return effects, se, periods, pre_periods, post_periods, reference_period, False
|
|
339
381
|
|
|
340
382
|
# Handle CallawaySantAnnaResults (event study aggregation)
|
|
341
383
|
if hasattr(results, 'event_study_effects') and results.event_study_effects is not None:
|
|
@@ -349,9 +391,21 @@ def _extract_plot_data(
|
|
|
349
391
|
if periods is None:
|
|
350
392
|
periods = sorted(effects.keys())
|
|
351
393
|
|
|
394
|
+
# Track if reference_period was explicitly provided vs auto-inferred
|
|
395
|
+
reference_inferred = False
|
|
396
|
+
|
|
352
397
|
# Reference period is typically -1 for event study
|
|
353
398
|
if reference_period is None:
|
|
354
|
-
|
|
399
|
+
reference_inferred = True # We're about to infer it
|
|
400
|
+
# Detect reference period from n_groups=0 marker (normalization constraint)
|
|
401
|
+
# This handles anticipation > 0 where reference is at e = -1 - anticipation
|
|
402
|
+
for period, effect_data in results.event_study_effects.items():
|
|
403
|
+
if effect_data.get('n_groups', 1) == 0:
|
|
404
|
+
reference_period = period
|
|
405
|
+
break
|
|
406
|
+
# Fall back to -1 if no n_groups=0 marker is found (backward compatibility)
|
|
407
|
+
if reference_period is None:
|
|
408
|
+
reference_period = -1
|
|
355
409
|
|
|
356
410
|
if pre_periods is None:
|
|
357
411
|
pre_periods = [p for p in periods if p < 0]
|
|
@@ -359,7 +413,7 @@ def _extract_plot_data(
|
|
|
359
413
|
if post_periods is None:
|
|
360
414
|
post_periods = [p for p in periods if p >= 0]
|
|
361
415
|
|
|
362
|
-
return effects, se, periods, pre_periods, post_periods, reference_period
|
|
416
|
+
return effects, se, periods, pre_periods, post_periods, reference_period, reference_inferred
|
|
363
417
|
|
|
364
418
|
raise TypeError(
|
|
365
419
|
f"Cannot extract plot data from {type(results).__name__}. "
|
|
@@ -4,7 +4,7 @@ build-backend = "maturin"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "diff-diff"
|
|
7
|
-
version = "2.1.5"
|
|
7
|
+
version = "2.1.7"
|
|
8
8
|
description = "A library for Difference-in-Differences causal inference analysis"
|
|
9
9
|
readme = "README.md"
|
|
10
10
|
license = "MIT"
|
|
@@ -70,7 +70,11 @@ python-packages = ["diff_diff"]
|
|
|
70
70
|
[tool.pytest.ini_options]
|
|
71
71
|
testpaths = ["tests"]
|
|
72
72
|
python_files = "test_*.py"
|
|
73
|
+
# Run all tests including slow ones by default; use `pytest -m 'not slow'` for faster local runs
|
|
73
74
|
addopts = "-v --tb=short"
|
|
75
|
+
markers = [
|
|
76
|
+
"slow: marks tests as slow (run `pytest -m 'not slow'` to exclude, or `pytest -m slow` to run only slow tests)",
|
|
77
|
+
]
|
|
74
78
|
|
|
75
79
|
[tool.black]
|
|
76
80
|
line-length = 100
|
|
@@ -115,9 +115,9 @@ dependencies = [
|
|
|
115
115
|
|
|
116
116
|
[[package]]
|
|
117
117
|
name = "cc"
|
|
118
|
-
version = "1.2.
|
|
118
|
+
version = "1.2.54"
|
|
119
119
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
120
|
-
checksum = "
|
|
120
|
+
checksum = "6354c81bbfd62d9cfa9cb3c773c2b7b2a3a482d569de977fd0e961f6e7c00583"
|
|
121
121
|
dependencies = [
|
|
122
122
|
"find-msvc-tools",
|
|
123
123
|
"shlex",
|
|
@@ -289,7 +289,7 @@ dependencies = [
|
|
|
289
289
|
|
|
290
290
|
[[package]]
|
|
291
291
|
name = "diff_diff_rust"
|
|
292
|
-
version = "2.1.
|
|
292
|
+
version = "2.1.7"
|
|
293
293
|
dependencies = [
|
|
294
294
|
"ndarray",
|
|
295
295
|
"ndarray-linalg",
|
|
@@ -1220,9 +1220,9 @@ dependencies = [
|
|
|
1220
1220
|
|
|
1221
1221
|
[[package]]
|
|
1222
1222
|
name = "quote"
|
|
1223
|
-
version = "1.0.
|
|
1223
|
+
version = "1.0.44"
|
|
1224
1224
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
1225
|
-
checksum = "
|
|
1225
|
+
checksum = "21b2ebcf727b7760c461f091f9f0f539b77b8e87f2fd88131e7f1b433b3cece4"
|
|
1226
1226
|
dependencies = [
|
|
1227
1227
|
"proc-macro2",
|
|
1228
1228
|
]
|
|
@@ -1846,9 +1846,9 @@ checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be"
|
|
|
1846
1846
|
|
|
1847
1847
|
[[package]]
|
|
1848
1848
|
name = "uuid"
|
|
1849
|
-
version = "1.
|
|
1849
|
+
version = "1.20.0"
|
|
1850
1850
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
1851
|
-
checksum = "
|
|
1851
|
+
checksum = "ee48d38b119b0cd71fe4141b30f5ba9c7c5d9f4e7a3a8b4a674e4b6ef789976f"
|
|
1852
1852
|
dependencies = [
|
|
1853
1853
|
"getrandom 0.3.4",
|
|
1854
1854
|
"js-sys",
|
|
@@ -115,24 +115,24 @@ fn generate_mammen_batch(n_bootstrap: usize, n_units: usize, seed: u64) -> Array
|
|
|
115
115
|
|
|
116
116
|
/// Generate Webb 6-point distribution weights.
|
|
117
117
|
///
|
|
118
|
-
/// Six-point distribution
|
|
119
|
-
/// E[w] = 0,
|
|
118
|
+
/// Six-point distribution with equal probabilities (1/6 each) matching R's `did` package:
|
|
119
|
+
/// E[w] = 0, Var[w] = 1
|
|
120
120
|
///
|
|
121
|
-
/// Values: ±√(3/2), ±√(
|
|
121
|
+
/// Values: ±√(3/2), ±√(2/2)=±1, ±√(1/2)
|
|
122
122
|
fn generate_webb_batch(n_bootstrap: usize, n_units: usize, seed: u64) -> Array2<f64> {
|
|
123
123
|
// Webb 6-point values
|
|
124
|
-
let val1 = (3.0_f64 / 2.0).sqrt(); // √(3/2) ≈ 1.
|
|
125
|
-
let val2 =
|
|
126
|
-
let val3 = (1.0_f64 /
|
|
124
|
+
let val1 = (3.0_f64 / 2.0).sqrt(); // √(3/2) ≈ 1.2247
|
|
125
|
+
let val2 = 1.0_f64; // √(2/2) = 1.0
|
|
126
|
+
let val3 = (1.0_f64 / 2.0).sqrt(); // √(1/2) ≈ 0.7071
|
|
127
127
|
|
|
128
|
-
//
|
|
129
|
-
// Equal probability: u in [0, 1/6) -> -val1, [1/6, 2/6) -> -val2, etc.
|
|
128
|
+
// Values in order: -val1, -val2, -val3, val3, val2, val1
|
|
130
129
|
let weights_table = [-val1, -val2, -val3, val3, val2, val1];
|
|
131
130
|
|
|
132
131
|
// Pre-allocate output array - eliminates double allocation
|
|
133
132
|
let mut weights = Array2::<f64>::zeros((n_bootstrap, n_units));
|
|
134
133
|
|
|
135
134
|
// Fill rows in parallel with chunk size tuning
|
|
135
|
+
// Use uniform selection (1/6 probability each) matching R's did package
|
|
136
136
|
weights
|
|
137
137
|
.axis_iter_mut(Axis(0))
|
|
138
138
|
.into_par_iter()
|
|
@@ -141,10 +141,8 @@ fn generate_webb_batch(n_bootstrap: usize, n_units: usize, seed: u64) -> Array2<
|
|
|
141
141
|
.for_each(|(i, mut row)| {
|
|
142
142
|
let mut rng = Xoshiro256PlusPlus::seed_from_u64(seed.wrapping_add(i as u64));
|
|
143
143
|
for elem in row.iter_mut() {
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
// Clamp to 5 to handle edge case where u == 1.0
|
|
147
|
-
let bucket = ((u * 6.0).floor() as usize).min(5);
|
|
144
|
+
// Uniform selection: generate integer 0-5, index into weights_table
|
|
145
|
+
let bucket = rng.gen_range(0..6);
|
|
148
146
|
*elem = weights_table[bucket];
|
|
149
147
|
}
|
|
150
148
|
});
|
|
@@ -225,4 +223,60 @@ mod tests {
|
|
|
225
223
|
// Different seeds should produce different results
|
|
226
224
|
assert_ne!(weights1, weights2);
|
|
227
225
|
}
|
|
226
|
+
|
|
227
|
+
#[test]
|
|
228
|
+
fn test_webb_mean_approx_zero() {
|
|
229
|
+
let weights = generate_webb_batch(10000, 1, 42);
|
|
230
|
+
let mean: f64 = weights.iter().sum::<f64>() / weights.len() as f64;
|
|
231
|
+
|
|
232
|
+
// With 10000 samples, mean should be close to 0
|
|
233
|
+
assert!(
|
|
234
|
+
mean.abs() < 0.1,
|
|
235
|
+
"Webb mean should be close to 0, got {}",
|
|
236
|
+
mean
|
|
237
|
+
);
|
|
238
|
+
}
|
|
239
|
+
|
|
240
|
+
#[test]
|
|
241
|
+
fn test_webb_variance_approx_correct() {
|
|
242
|
+
// Webb's 6-point distribution with values ±√(3/2), ±1, ±√(1/2)
|
|
243
|
+
// and equal probabilities (1/6 each) should have variance = 1.0
|
|
244
|
+
// This matches R's did package behavior.
|
|
245
|
+
// Theoretical: Var = (1/6) * (3/2 + 1 + 1/2 + 1/2 + 1 + 3/2) = (1/6) * 6 = 1.0
|
|
246
|
+
let weights = generate_webb_batch(10000, 100, 42);
|
|
247
|
+
let n = weights.len() as f64;
|
|
248
|
+
let mean: f64 = weights.iter().sum::<f64>() / n;
|
|
249
|
+
let variance: f64 = weights.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n;
|
|
250
|
+
|
|
251
|
+
// Theoretical variance = 1.0 with equal probabilities
|
|
252
|
+
// Allow for sampling error in the empirical estimate
|
|
253
|
+
assert!(
|
|
254
|
+
(variance - 1.0).abs() < 0.05,
|
|
255
|
+
"Webb variance should be ~1.0 (matching R's did package), got {}",
|
|
256
|
+
variance
|
|
257
|
+
);
|
|
258
|
+
}
|
|
259
|
+
|
|
260
|
+
#[test]
|
|
261
|
+
fn test_webb_values_correct() {
|
|
262
|
+
// Verify that Webb weights only take the expected 6 values
|
|
263
|
+
let weights = generate_webb_batch(100, 1000, 42);
|
|
264
|
+
|
|
265
|
+
let val1 = (3.0_f64 / 2.0).sqrt(); // ≈ 1.2247
|
|
266
|
+
let val2 = 1.0_f64;
|
|
267
|
+
let val3 = (1.0_f64 / 2.0).sqrt(); // ≈ 0.7071
|
|
268
|
+
|
|
269
|
+
let expected_values = [-val1, -val2, -val3, val3, val2, val1];
|
|
270
|
+
|
|
271
|
+
for w in weights.iter() {
|
|
272
|
+
let matches_expected = expected_values
|
|
273
|
+
.iter()
|
|
274
|
+
.any(|&expected| (*w - expected).abs() < 1e-10);
|
|
275
|
+
assert!(
|
|
276
|
+
matches_expected,
|
|
277
|
+
"Webb weight {} is not one of the expected values",
|
|
278
|
+
w
|
|
279
|
+
);
|
|
280
|
+
}
|
|
281
|
+
}
|
|
228
282
|
}
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|