diff-diff 2.1.5__tar.gz → 2.1.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. {diff_diff-2.1.5 → diff_diff-2.1.7}/PKG-INFO +1 -1
  2. {diff_diff-2.1.5 → diff_diff-2.1.7}/diff_diff/__init__.py +1 -1
  3. {diff_diff-2.1.5 → diff_diff-2.1.7}/diff_diff/honest_did.py +8 -1
  4. {diff_diff-2.1.5 → diff_diff-2.1.7}/diff_diff/pretrends.py +6 -0
  5. {diff_diff-2.1.5 → diff_diff-2.1.7}/diff_diff/staggered_aggregation.py +19 -0
  6. {diff_diff-2.1.5 → diff_diff-2.1.7}/diff_diff/staggered_bootstrap.py +6 -4
  7. {diff_diff-2.1.5 → diff_diff-2.1.7}/diff_diff/utils.py +6 -3
  8. {diff_diff-2.1.5 → diff_diff-2.1.7}/diff_diff/visualization.py +76 -22
  9. {diff_diff-2.1.5 → diff_diff-2.1.7}/pyproject.toml +5 -1
  10. {diff_diff-2.1.5 → diff_diff-2.1.7}/rust/Cargo.lock +7 -7
  11. {diff_diff-2.1.5 → diff_diff-2.1.7}/rust/Cargo.toml +1 -1
  12. {diff_diff-2.1.5 → diff_diff-2.1.7}/rust/src/bootstrap.rs +66 -12
  13. {diff_diff-2.1.5 → diff_diff-2.1.7}/README.md +0 -0
  14. {diff_diff-2.1.5 → diff_diff-2.1.7}/diff_diff/_backend.py +0 -0
  15. {diff_diff-2.1.5 → diff_diff-2.1.7}/diff_diff/bacon.py +0 -0
  16. {diff_diff-2.1.5 → diff_diff-2.1.7}/diff_diff/datasets.py +0 -0
  17. {diff_diff-2.1.5 → diff_diff-2.1.7}/diff_diff/diagnostics.py +0 -0
  18. {diff_diff-2.1.5 → diff_diff-2.1.7}/diff_diff/estimators.py +0 -0
  19. {diff_diff-2.1.5 → diff_diff-2.1.7}/diff_diff/linalg.py +0 -0
  20. {diff_diff-2.1.5 → diff_diff-2.1.7}/diff_diff/power.py +0 -0
  21. {diff_diff-2.1.5 → diff_diff-2.1.7}/diff_diff/prep.py +0 -0
  22. {diff_diff-2.1.5 → diff_diff-2.1.7}/diff_diff/prep_dgp.py +0 -0
  23. {diff_diff-2.1.5 → diff_diff-2.1.7}/diff_diff/results.py +0 -0
  24. {diff_diff-2.1.5 → diff_diff-2.1.7}/diff_diff/staggered.py +0 -0
  25. {diff_diff-2.1.5 → diff_diff-2.1.7}/diff_diff/staggered_results.py +0 -0
  26. {diff_diff-2.1.5 → diff_diff-2.1.7}/diff_diff/sun_abraham.py +0 -0
  27. {diff_diff-2.1.5 → diff_diff-2.1.7}/diff_diff/synthetic_did.py +0 -0
  28. {diff_diff-2.1.5 → diff_diff-2.1.7}/diff_diff/triple_diff.py +0 -0
  29. {diff_diff-2.1.5 → diff_diff-2.1.7}/diff_diff/trop.py +0 -0
  30. {diff_diff-2.1.5 → diff_diff-2.1.7}/diff_diff/twfe.py +0 -0
  31. {diff_diff-2.1.5 → diff_diff-2.1.7}/rust/src/lib.rs +0 -0
  32. {diff_diff-2.1.5 → diff_diff-2.1.7}/rust/src/linalg.rs +0 -0
  33. {diff_diff-2.1.5 → diff_diff-2.1.7}/rust/src/trop.rs +0 -0
  34. {diff_diff-2.1.5 → diff_diff-2.1.7}/rust/src/weights.rs +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: diff-diff
3
- Version: 2.1.5
3
+ Version: 2.1.7
4
4
  Classifier: Development Status :: 5 - Production/Stable
5
5
  Classifier: Intended Audience :: Science/Research
6
6
  Classifier: Operating System :: OS Independent
@@ -136,7 +136,7 @@ from diff_diff.datasets import (
136
136
  load_mpdta,
137
137
  )
138
138
 
139
- __version__ = "2.1.5"
139
+ __version__ = "2.1.7"
140
140
  __all__ = [
141
141
  # Estimators
142
142
  "DifferenceInDifferences",
@@ -584,7 +584,12 @@ def _extract_event_study_params(
584
584
  )
585
585
 
586
586
  # Extract event study effects by relative time
587
- event_effects = results.event_study_effects
587
+ # Filter out normalization constraints (n_groups=0) and non-finite SEs
588
+ event_effects = {
589
+ t: data for t, data in results.event_study_effects.items()
590
+ if data.get('n_groups', 1) > 0
591
+ and np.isfinite(data.get('se', np.nan))
592
+ }
588
593
  rel_times = sorted(event_effects.keys())
589
594
 
590
595
  # Split into pre and post
@@ -1261,10 +1266,12 @@ class HonestDiD:
1261
1266
  from diff_diff.staggered import CallawaySantAnnaResults
1262
1267
  if isinstance(results, CallawaySantAnnaResults):
1263
1268
  if results.event_study_effects:
1269
+ # Filter out normalization constraints (n_groups=0, e.g. reference period)
1264
1270
  pre_effects = [
1265
1271
  abs(results.event_study_effects[t]['effect'])
1266
1272
  for t in results.event_study_effects
1267
1273
  if t < 0
1274
+ and results.event_study_effects[t].get('n_groups', 1) > 0
1268
1275
  ]
1269
1276
  if pre_effects:
1270
1277
  return max(pre_effects)
@@ -656,9 +656,12 @@ class PreTrendsPower:
656
656
  )
657
657
 
658
658
  # Get pre-period effects (negative relative times)
659
+ # Filter out normalization constraints (n_groups=0) and non-finite SEs
659
660
  pre_effects = {
660
661
  t: data for t, data in results.event_study_effects.items()
661
662
  if t < 0
663
+ and data.get('n_groups', 1) > 0
664
+ and np.isfinite(data.get('se', np.nan))
662
665
  }
663
666
 
664
667
  if not pre_effects:
@@ -680,9 +683,12 @@ class PreTrendsPower:
680
683
  from diff_diff.sun_abraham import SunAbrahamResults
681
684
  if isinstance(results, SunAbrahamResults):
682
685
  # Get pre-period effects (negative relative times)
686
+ # Filter out normalization constraints (n_groups=0) and non-finite SEs
683
687
  pre_effects = {
684
688
  t: data for t, data in results.event_study_effects.items()
685
689
  if t < 0
690
+ and data.get('n_groups', 1) > 0
691
+ and np.isfinite(data.get('se', np.nan))
686
692
  }
687
693
 
688
694
  if not pre_effects:
@@ -34,6 +34,9 @@ class CallawaySantAnnaAggregationMixin:
34
34
  # Type hint for anticipation attribute accessed from main class
35
35
  anticipation: int
36
36
 
37
+ # Type hint for base_period attribute accessed from main class
38
+ base_period: str
39
+
37
40
  def _aggregate_simple(
38
41
  self,
39
42
  group_time_effects: Dict,
@@ -414,6 +417,22 @@ class CallawaySantAnnaAggregationMixin:
414
417
  'n_groups': len(effect_list),
415
418
  }
416
419
 
420
+ # Add reference period for universal base period mode (matches R did package)
421
+ # The reference period e = -1 - anticipation has effect = 0 by construction
422
+ # Only add if there are actual computed effects (guard against empty data)
423
+ if getattr(self, 'base_period', 'varying') == "universal":
424
+ ref_period = -1 - self.anticipation
425
+ # Only inject reference if we have at least one real effect
426
+ if event_study_effects and ref_period not in event_study_effects:
427
+ event_study_effects[ref_period] = {
428
+ 'effect': 0.0,
429
+ 'se': np.nan, # Undefined - no data, normalization constraint
430
+ 't_stat': np.nan, # Undefined - normalization constraint
431
+ 'p_value': np.nan,
432
+ 'conf_int': (np.nan, np.nan), # NaN propagation for undefined inference
433
+ 'n_groups': 0, # No groups contribute - fixed by construction
434
+ }
435
+
417
436
  return event_study_effects
418
437
 
419
438
  def _aggregate_by_group(
@@ -60,12 +60,13 @@ def _generate_bootstrap_weights(
60
60
 
61
61
  elif weight_type == "webb":
62
62
  # Webb's 6-point distribution (recommended for few clusters)
63
+ # Values: ±√(3/2), ±1, ±√(1/2) with equal probabilities (1/6 each)
64
+ # This matches R's did package: E[w]=0, Var(w)=1.0
63
65
  values = np.array([
64
66
  -np.sqrt(3 / 2), -np.sqrt(2 / 2), -np.sqrt(1 / 2),
65
67
  np.sqrt(1 / 2), np.sqrt(2 / 2), np.sqrt(3 / 2)
66
68
  ])
67
- probs = np.array([1, 2, 3, 3, 2, 1]) / 12
68
- return rng.choice(values, size=n_units, p=probs)
69
+ return rng.choice(values, size=n_units) # Equal probs (1/6 each)
69
70
 
70
71
  else:
71
72
  raise ValueError(
@@ -152,12 +153,13 @@ def _generate_bootstrap_weights_batch_numpy(
152
153
 
153
154
  elif weight_type == "webb":
154
155
  # Webb's 6-point distribution
156
+ # Values: ±√(3/2), ±1, ±√(1/2) with equal probabilities (1/6 each)
157
+ # This matches R's did package: E[w]=0, Var(w)=1.0
155
158
  values = np.array([
156
159
  -np.sqrt(3 / 2), -np.sqrt(2 / 2), -np.sqrt(1 / 2),
157
160
  np.sqrt(1 / 2), np.sqrt(2 / 2), np.sqrt(3 / 2)
158
161
  ])
159
- probs = np.array([1, 2, 3, 3, 2, 1]) / 12
160
- return rng.choice(values, size=(n_bootstrap, n_units), p=probs)
162
+ return rng.choice(values, size=(n_bootstrap, n_units)) # Equal probs (1/6 each)
161
163
 
162
164
  else:
163
165
  raise ValueError(
@@ -238,7 +238,7 @@ def _generate_webb_weights(n_clusters: int, rng: np.random.Generator) -> np.ndar
238
238
  Generate Webb's 6-point distribution weights.
239
239
 
240
240
  Values: {-sqrt(3/2), -sqrt(2/2), -sqrt(1/2), sqrt(1/2), sqrt(2/2), sqrt(3/2)}
241
- with probabilities proportional to {1, 2, 3, 3, 2, 1}.
241
+ with equal probabilities (1/6 each), giving E[w]=0 and Var(w)=1.0.
242
242
 
243
243
  This distribution is recommended for very few clusters (G < 10) as it
244
244
  provides better finite-sample properties than Rademacher weights.
@@ -259,13 +259,16 @@ def _generate_webb_weights(n_clusters: int, rng: np.random.Generator) -> np.ndar
259
259
  ----------
260
260
  Webb, M. D. (2014). Reworking wild bootstrap based inference for
261
261
  clustered errors. Queen's Economics Department Working Paper No. 1315.
262
+
263
+ Note: Uses equal probabilities (1/6 each) matching R's `did` package,
264
+ which gives unit variance for consistency with other weight distributions.
262
265
  """
263
266
  values = np.array([
264
267
  -np.sqrt(3 / 2), -np.sqrt(2 / 2), -np.sqrt(1 / 2),
265
268
  np.sqrt(1 / 2), np.sqrt(2 / 2), np.sqrt(3 / 2)
266
269
  ])
267
- probs = np.array([1, 2, 3, 3, 2, 1]) / 12
268
- return np.asarray(rng.choice(values, size=n_clusters, p=probs))
270
+ # Equal probabilities (1/6 each) matching R's did package, giving Var(w) = 1.0
271
+ return np.asarray(rng.choice(values, size=n_clusters))
269
272
 
270
273
 
271
274
  def _generate_mammen_weights(n_clusters: int, rng: np.random.Generator) -> np.ndarray:
@@ -73,8 +73,10 @@ def plot_event_study(
73
73
  periods : list, optional
74
74
  List of periods to plot. If None, uses all periods from results.
75
75
  reference_period : any, optional
76
- The reference period (normalized to effect=0). Will be shown as a
77
- hollow marker. If None, tries to infer from results.
76
+ The reference period to highlight. When explicitly provided, effects
77
+ are normalized (ref effect subtracted) and ref SE is set to NaN.
78
+ When None and auto-inferred from results, only hollow marker styling
79
+ is applied (no normalization). If None, tries to infer from results.
78
80
  pre_periods : list, optional
79
81
  List of pre-treatment periods. Used for shading.
80
82
  post_periods : list, optional
@@ -151,8 +153,9 @@ def plot_event_study(
151
153
  trends holds. Large pre-treatment effects suggest the assumption may
152
154
  be violated.
153
155
 
154
- 2. **Reference period**: Usually the last pre-treatment period (t=-1),
155
- normalized to zero. This is the omitted category.
156
+ 2. **Reference period**: Usually the last pre-treatment period (t=-1).
157
+ When explicitly specified via ``reference_period``, effects are normalized
158
+ to zero at this period. When auto-inferred, shown with hollow marker only.
156
159
 
157
160
  3. **Post-treatment periods**: The treatment effects of interest. These
158
161
  show how the outcome evolved after treatment.
@@ -170,10 +173,18 @@ def plot_event_study(
170
173
 
171
174
  from scipy import stats as scipy_stats
172
175
 
176
+ # Track if reference_period was explicitly provided by user
177
+ reference_period_explicit = reference_period is not None
178
+
173
179
  # Extract data from results if provided
174
180
  if results is not None:
175
- effects, se, periods, pre_periods, post_periods, reference_period = \
176
- _extract_plot_data(results, periods, pre_periods, post_periods, reference_period)
181
+ extracted = _extract_plot_data(
182
+ results, periods, pre_periods, post_periods, reference_period
183
+ )
184
+ effects, se, periods, pre_periods, post_periods, reference_period, reference_inferred = extracted
185
+ # If reference was inferred from results, it was NOT explicitly provided
186
+ if reference_inferred:
187
+ reference_period_explicit = False
177
188
  elif effects is None or se is None:
178
189
  raise ValueError(
179
190
  "Must provide either 'results' or both 'effects' and 'se'"
@@ -192,16 +203,35 @@ def plot_event_study(
192
203
  # Compute confidence intervals
193
204
  critical_value = scipy_stats.norm.ppf(1 - alpha / 2)
194
205
 
206
+ # Normalize effects to reference period ONLY if explicitly specified by user
207
+ # Auto-inferred reference periods (from CallawaySantAnna) just get hollow marker styling,
208
+ # NO normalization. This prevents unintended normalization when the reference period
209
+ # isn't a true identifying constraint (e.g., CallawaySantAnna with base_period="varying").
210
+ if (reference_period is not None and reference_period in effects and
211
+ reference_period_explicit):
212
+ ref_effect = effects[reference_period]
213
+ if np.isfinite(ref_effect):
214
+ effects = {p: e - ref_effect for p, e in effects.items()}
215
+ # Set reference SE to NaN (it's now a constraint, not an estimate)
216
+ # This follows fixest convention where the omitted category has no SE/CI
217
+ se = {p: (np.nan if p == reference_period else s) for p, s in se.items()}
218
+
195
219
  plot_data = []
196
220
  for period in periods:
197
221
  effect = effects.get(period, np.nan)
198
222
  std_err = se.get(period, np.nan)
199
223
 
200
- if np.isnan(effect) or np.isnan(std_err):
224
+ # Skip entries with NaN effect, but allow NaN SE (will plot without error bars)
225
+ if np.isnan(effect):
201
226
  continue
202
227
 
203
- ci_lower = effect - critical_value * std_err
204
- ci_upper = effect + critical_value * std_err
228
+ # Compute CI only if SE is finite
229
+ if np.isfinite(std_err):
230
+ ci_lower = effect - critical_value * std_err
231
+ ci_upper = effect + critical_value * std_err
232
+ else:
233
+ ci_lower = np.nan
234
+ ci_upper = np.nan
205
235
 
206
236
  plot_data.append({
207
237
  'period': period,
@@ -244,13 +274,20 @@ def plot_event_study(
244
274
  ref_x = period_to_x[reference_period]
245
275
  ax.axvline(x=ref_x, color='gray', linestyle=':', linewidth=1, zorder=1)
246
276
 
247
- # Plot error bars
248
- yerr = [df['effect'] - df['ci_lower'], df['ci_upper'] - df['effect']]
249
- ax.errorbar(
250
- x_vals, df['effect'], yerr=yerr,
251
- fmt='none', color=color, capsize=capsize, linewidth=linewidth,
252
- capthick=linewidth, zorder=2
253
- )
277
+ # Plot error bars (only for entries with finite CI)
278
+ has_ci = df['ci_lower'].notna() & df['ci_upper'].notna()
279
+ if has_ci.any():
280
+ df_with_ci = df[has_ci]
281
+ x_with_ci = [period_to_x[p] for p in df_with_ci['period']]
282
+ yerr = [
283
+ df_with_ci['effect'] - df_with_ci['ci_lower'],
284
+ df_with_ci['ci_upper'] - df_with_ci['effect']
285
+ ]
286
+ ax.errorbar(
287
+ x_with_ci, df_with_ci['effect'], yerr=yerr,
288
+ fmt='none', color=color, capsize=capsize, linewidth=linewidth,
289
+ capthick=linewidth, zorder=2
290
+ )
254
291
 
255
292
  # Plot point estimates
256
293
  for i, row in df.iterrows():
@@ -291,14 +328,17 @@ def _extract_plot_data(
291
328
  pre_periods: Optional[List[Any]],
292
329
  post_periods: Optional[List[Any]],
293
330
  reference_period: Optional[Any],
294
- ) -> Tuple[Dict, Dict, List, List, List, Any]:
331
+ ) -> Tuple[Dict, Dict, List, List, List, Any, bool]:
295
332
  """
296
333
  Extract plotting data from various result types.
297
334
 
298
335
  Returns
299
336
  -------
300
337
  tuple
301
- (effects, se, periods, pre_periods, post_periods, reference_period)
338
+ (effects, se, periods, pre_periods, post_periods, reference_period, reference_inferred)
339
+
340
+ reference_inferred is True if reference_period was auto-detected from results
341
+ rather than explicitly provided by the user.
302
342
  """
303
343
  # Handle DataFrame input
304
344
  if isinstance(results, pd.DataFrame):
@@ -315,7 +355,8 @@ def _extract_plot_data(
315
355
  if periods is None:
316
356
  periods = list(results['period'])
317
357
 
318
- return effects, se, periods, pre_periods, post_periods, reference_period
358
+ # DataFrame input: reference_period was already set by caller, never inferred here
359
+ return effects, se, periods, pre_periods, post_periods, reference_period, False
319
360
 
320
361
  # Handle MultiPeriodDiDResults
321
362
  if hasattr(results, 'period_effects'):
@@ -335,7 +376,8 @@ def _extract_plot_data(
335
376
  if periods is None:
336
377
  periods = post_periods
337
378
 
338
- return effects, se, periods, pre_periods, post_periods, reference_period
379
+ # MultiPeriodDiDResults: reference_period was already set by caller, never inferred here
380
+ return effects, se, periods, pre_periods, post_periods, reference_period, False
339
381
 
340
382
  # Handle CallawaySantAnnaResults (event study aggregation)
341
383
  if hasattr(results, 'event_study_effects') and results.event_study_effects is not None:
@@ -349,9 +391,21 @@ def _extract_plot_data(
349
391
  if periods is None:
350
392
  periods = sorted(effects.keys())
351
393
 
394
+ # Track if reference_period was explicitly provided vs auto-inferred
395
+ reference_inferred = False
396
+
352
397
  # Reference period is typically -1 for event study
353
398
  if reference_period is None:
354
- reference_period = -1
399
+ reference_inferred = True # We're about to infer it
400
+ # Detect reference period from n_groups=0 marker (normalization constraint)
401
+ # This handles anticipation > 0 where reference is at e = -1 - anticipation
402
+ for period, effect_data in results.event_study_effects.items():
403
+ if effect_data.get('n_groups', 1) == 0:
404
+ reference_period = period
405
+ break
406
+ # Fallback to -1 if no marker found (backward compatibility)
407
+ if reference_period is None:
408
+ reference_period = -1
355
409
 
356
410
  if pre_periods is None:
357
411
  pre_periods = [p for p in periods if p < 0]
@@ -359,7 +413,7 @@ def _extract_plot_data(
359
413
  if post_periods is None:
360
414
  post_periods = [p for p in periods if p >= 0]
361
415
 
362
- return effects, se, periods, pre_periods, post_periods, reference_period
416
+ return effects, se, periods, pre_periods, post_periods, reference_period, reference_inferred
363
417
 
364
418
  raise TypeError(
365
419
  f"Cannot extract plot data from {type(results).__name__}. "
@@ -4,7 +4,7 @@ build-backend = "maturin"
4
4
 
5
5
  [project]
6
6
  name = "diff-diff"
7
- version = "2.1.5"
7
+ version = "2.1.7"
8
8
  description = "A library for Difference-in-Differences causal inference analysis"
9
9
  readme = "README.md"
10
10
  license = "MIT"
@@ -70,7 +70,11 @@ python-packages = ["diff_diff"]
70
70
  [tool.pytest.ini_options]
71
71
  testpaths = ["tests"]
72
72
  python_files = "test_*.py"
73
+ # Run all tests including slow ones by default; use `pytest -m 'not slow'` for faster local runs
73
74
  addopts = "-v --tb=short"
75
+ markers = [
76
+ "slow: marks tests as slow (run `pytest -m 'not slow'` to exclude, or `pytest -m slow` to run only slow tests)",
77
+ ]
74
78
 
75
79
  [tool.black]
76
80
  line-length = 100
@@ -115,9 +115,9 @@ dependencies = [
115
115
 
116
116
  [[package]]
117
117
  name = "cc"
118
- version = "1.2.53"
118
+ version = "1.2.54"
119
119
  source = "registry+https://github.com/rust-lang/crates.io-index"
120
- checksum = "755d2fce177175ffca841e9a06afdb2c4ab0f593d53b4dee48147dfaade85932"
120
+ checksum = "6354c81bbfd62d9cfa9cb3c773c2b7b2a3a482d569de977fd0e961f6e7c00583"
121
121
  dependencies = [
122
122
  "find-msvc-tools",
123
123
  "shlex",
@@ -289,7 +289,7 @@ dependencies = [
289
289
 
290
290
  [[package]]
291
291
  name = "diff_diff_rust"
292
- version = "2.1.5"
292
+ version = "2.1.7"
293
293
  dependencies = [
294
294
  "ndarray",
295
295
  "ndarray-linalg",
@@ -1220,9 +1220,9 @@ dependencies = [
1220
1220
 
1221
1221
  [[package]]
1222
1222
  name = "quote"
1223
- version = "1.0.43"
1223
+ version = "1.0.44"
1224
1224
  source = "registry+https://github.com/rust-lang/crates.io-index"
1225
- checksum = "dc74d9a594b72ae6656596548f56f667211f8a97b3d4c3d467150794690dc40a"
1225
+ checksum = "21b2ebcf727b7760c461f091f9f0f539b77b8e87f2fd88131e7f1b433b3cece4"
1226
1226
  dependencies = [
1227
1227
  "proc-macro2",
1228
1228
  ]
@@ -1846,9 +1846,9 @@ checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be"
1846
1846
 
1847
1847
  [[package]]
1848
1848
  name = "uuid"
1849
- version = "1.19.0"
1849
+ version = "1.20.0"
1850
1850
  source = "registry+https://github.com/rust-lang/crates.io-index"
1851
- checksum = "e2e054861b4bd027cd373e18e8d8d8e6548085000e41290d95ce0c373a654b4a"
1851
+ checksum = "ee48d38b119b0cd71fe4141b30f5ba9c7c5d9f4e7a3a8b4a674e4b6ef789976f"
1852
1852
  dependencies = [
1853
1853
  "getrandom 0.3.4",
1854
1854
  "js-sys",
@@ -1,6 +1,6 @@
1
1
  [package]
2
2
  name = "diff_diff_rust"
3
- version = "2.1.5"
3
+ version = "2.1.7"
4
4
  edition = "2021"
5
5
  description = "Rust backend for diff-diff DiD library"
6
6
  license = "MIT"
@@ -115,24 +115,24 @@ fn generate_mammen_batch(n_bootstrap: usize, n_units: usize, seed: u64) -> Array
115
115
 
116
116
  /// Generate Webb 6-point distribution weights.
117
117
  ///
118
- /// Six-point distribution that matches additional moments:
119
- /// E[w] = 0, E[w²] = 1, E[w³] = 0, E[w⁴] = 1
118
+ /// Six-point distribution with equal probabilities (1/6 each) matching R's `did` package:
119
+ /// E[w] = 0, Var[w] = 1
120
120
  ///
121
- /// Values: ±√(3/2), ±√(1/2), ±√(1/6) with specific probabilities
121
+ /// Values: ±√(3/2), ±√(2/2)=±1, ±√(1/2)
122
122
  fn generate_webb_batch(n_bootstrap: usize, n_units: usize, seed: u64) -> Array2<f64> {
123
123
  // Webb 6-point values
124
- let val1 = (3.0_f64 / 2.0).sqrt(); // √(3/2) ≈ 1.225
125
- let val2 = (1.0_f64 / 2.0).sqrt(); // √(1/2) 0.707
126
- let val3 = (1.0_f64 / 6.0).sqrt(); // √(1/6) ≈ 0.408
124
+ let val1 = (3.0_f64 / 2.0).sqrt(); // √(3/2) ≈ 1.2247
125
+ let val2 = 1.0_f64; // √(2/2) = 1.0
126
+ let val3 = (1.0_f64 / 2.0).sqrt(); // √(1/2) ≈ 0.7071
127
127
 
128
- // Lookup table for direct index computation (replaces 6-way if-else)
129
- // Equal probability: u in [0, 1/6) -> -val1, [1/6, 2/6) -> -val2, etc.
128
+ // Values in order: -val1, -val2, -val3, val3, val2, val1
130
129
  let weights_table = [-val1, -val2, -val3, val3, val2, val1];
131
130
 
132
131
  // Pre-allocate output array - eliminates double allocation
133
132
  let mut weights = Array2::<f64>::zeros((n_bootstrap, n_units));
134
133
 
135
134
  // Fill rows in parallel with chunk size tuning
135
+ // Use uniform selection (1/6 probability each) matching R's did package
136
136
  weights
137
137
  .axis_iter_mut(Axis(0))
138
138
  .into_par_iter()
@@ -141,10 +141,8 @@ fn generate_webb_batch(n_bootstrap: usize, n_units: usize, seed: u64) -> Array2<
141
141
  .for_each(|(i, mut row)| {
142
142
  let mut rng = Xoshiro256PlusPlus::seed_from_u64(seed.wrapping_add(i as u64));
143
143
  for elem in row.iter_mut() {
144
- let u = rng.gen::<f64>();
145
- // Direct bucket computation: multiply by 6 and floor to get index 0-5
146
- // Clamp to 5 to handle edge case where u == 1.0
147
- let bucket = ((u * 6.0).floor() as usize).min(5);
144
+ // Uniform selection: generate integer 0-5, index into weights_table
145
+ let bucket = rng.gen_range(0..6);
148
146
  *elem = weights_table[bucket];
149
147
  }
150
148
  });
@@ -225,4 +223,60 @@ mod tests {
225
223
  // Different seeds should produce different results
226
224
  assert_ne!(weights1, weights2);
227
225
  }
226
+
227
+ #[test]
228
+ fn test_webb_mean_approx_zero() {
229
+ let weights = generate_webb_batch(10000, 1, 42);
230
+ let mean: f64 = weights.iter().sum::<f64>() / weights.len() as f64;
231
+
232
+ // With 10000 samples, mean should be close to 0
233
+ assert!(
234
+ mean.abs() < 0.1,
235
+ "Webb mean should be close to 0, got {}",
236
+ mean
237
+ );
238
+ }
239
+
240
+ #[test]
241
+ fn test_webb_variance_approx_correct() {
242
+ // Webb's 6-point distribution with values ±√(3/2), ±1, ±√(1/2)
243
+ // and equal probabilities (1/6 each) should have variance = 1.0
244
+ // This matches R's did package behavior.
245
+ // Theoretical: Var = (1/6) * (3/2 + 1 + 1/2 + 1/2 + 1 + 3/2) = (1/6) * 6 = 1.0
246
+ let weights = generate_webb_batch(10000, 100, 42);
247
+ let n = weights.len() as f64;
248
+ let mean: f64 = weights.iter().sum::<f64>() / n;
249
+ let variance: f64 = weights.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n;
250
+
251
+ // Theoretical variance = 1.0 with equal probabilities
252
+ // Allow some statistical variance in the estimate
253
+ assert!(
254
+ (variance - 1.0).abs() < 0.05,
255
+ "Webb variance should be ~1.0 (matching R's did package), got {}",
256
+ variance
257
+ );
258
+ }
259
+
260
+ #[test]
261
+ fn test_webb_values_correct() {
262
+ // Verify that Webb weights only take the expected 6 values
263
+ let weights = generate_webb_batch(100, 1000, 42);
264
+
265
+ let val1 = (3.0_f64 / 2.0).sqrt(); // ≈ 1.2247
266
+ let val2 = 1.0_f64;
267
+ let val3 = (1.0_f64 / 2.0).sqrt(); // ≈ 0.7071
268
+
269
+ let expected_values = [-val1, -val2, -val3, val3, val2, val1];
270
+
271
+ for w in weights.iter() {
272
+ let matches_expected = expected_values
273
+ .iter()
274
+ .any(|&expected| (*w - expected).abs() < 1e-10);
275
+ assert!(
276
+ matches_expected,
277
+ "Webb weight {} is not one of the expected values",
278
+ w
279
+ );
280
+ }
281
+ }
228
282
  }
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes