panelkit 0.2.0__tar.gz → 0.2.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82) hide show
  1. {panelkit-0.2.0 → panelkit-0.2.1}/Cargo.lock +5 -5
  2. {panelkit-0.2.0 → panelkit-0.2.1}/Cargo.toml +1 -1
  3. {panelkit-0.2.0 → panelkit-0.2.1}/GUIDE.md +46 -4
  4. {panelkit-0.2.0 → panelkit-0.2.1}/PKG-INFO +15 -3
  5. {panelkit-0.2.0 → panelkit-0.2.1}/README.md +14 -2
  6. {panelkit-0.2.0 → panelkit-0.2.1}/crates/geo/src/power.rs +13 -1
  7. {panelkit-0.2.0 → panelkit-0.2.1}/crates/geo/src/selection.rs +36 -5
  8. {panelkit-0.2.0 → panelkit-0.2.1}/crates/geo/tests/geo.rs +25 -3
  9. {panelkit-0.2.0 → panelkit-0.2.1}/crates/pypanelkit/src/api_geo.rs +8 -2
  10. {panelkit-0.2.0 → panelkit-0.2.1}/pyproject.toml +1 -1
  11. {panelkit-0.2.0 → panelkit-0.2.1}/python/panelkit/_panelkit.pyi +3 -0
  12. {panelkit-0.2.0 → panelkit-0.2.1}/python/panelkit/design.py +259 -45
  13. {panelkit-0.2.0 → panelkit-0.2.1}/BENCHMARKS.md +0 -0
  14. {panelkit-0.2.0 → panelkit-0.2.1}/LICENSE-APACHE +0 -0
  15. {panelkit-0.2.0 → panelkit-0.2.1}/LICENSE-MIT +0 -0
  16. {panelkit-0.2.0 → panelkit-0.2.1}/crates/estimators/Cargo.toml +0 -0
  17. {panelkit-0.2.0 → panelkit-0.2.1}/crates/estimators/benches/estimators.rs +0 -0
  18. {panelkit-0.2.0 → panelkit-0.2.1}/crates/estimators/src/did/bacon.rs +0 -0
  19. {panelkit-0.2.0 → panelkit-0.2.1}/crates/estimators/src/did/callaway.rs +0 -0
  20. {panelkit-0.2.0 → panelkit-0.2.1}/crates/estimators/src/did/mod.rs +0 -0
  21. {panelkit-0.2.0 → panelkit-0.2.1}/crates/estimators/src/did/sunab.rs +0 -0
  22. {panelkit-0.2.0 → panelkit-0.2.1}/crates/estimators/src/did/twfe.rs +0 -0
  23. {panelkit-0.2.0 → panelkit-0.2.1}/crates/estimators/src/fe/mod.rs +0 -0
  24. {panelkit-0.2.0 → panelkit-0.2.1}/crates/estimators/src/fe/within.rs +0 -0
  25. {panelkit-0.2.0 → panelkit-0.2.1}/crates/estimators/src/lib.rs +0 -0
  26. {panelkit-0.2.0 → panelkit-0.2.1}/crates/estimators/src/mcnnm/mod.rs +0 -0
  27. {panelkit-0.2.0 → panelkit-0.2.1}/crates/estimators/src/mcnnm/softimpute.rs +0 -0
  28. {panelkit-0.2.0 → panelkit-0.2.1}/crates/estimators/src/panel.rs +0 -0
  29. {panelkit-0.2.0 → panelkit-0.2.1}/crates/estimators/src/result.rs +0 -0
  30. {panelkit-0.2.0 → panelkit-0.2.1}/crates/estimators/src/sc/augmented.rs +0 -0
  31. {panelkit-0.2.0 → panelkit-0.2.1}/crates/estimators/src/sc/cpasc.rs +0 -0
  32. {panelkit-0.2.0 → panelkit-0.2.1}/crates/estimators/src/sc/mod.rs +0 -0
  33. {panelkit-0.2.0 → panelkit-0.2.1}/crates/estimators/src/sc/sdid.rs +0 -0
  34. {panelkit-0.2.0 → panelkit-0.2.1}/crates/estimators/src/sc/synthetic.rs +0 -0
  35. {panelkit-0.2.0 → panelkit-0.2.1}/crates/estimators/tests/cpasc.rs +0 -0
  36. {panelkit-0.2.0 → panelkit-0.2.1}/crates/estimators/tests/did.rs +0 -0
  37. {panelkit-0.2.0 → panelkit-0.2.1}/crates/estimators/tests/sc.rs +0 -0
  38. {panelkit-0.2.0 → panelkit-0.2.1}/crates/estimators/tests/sc_family.rs +0 -0
  39. {panelkit-0.2.0 → panelkit-0.2.1}/crates/geo/Cargo.toml +0 -0
  40. {panelkit-0.2.0 → panelkit-0.2.1}/crates/geo/src/diagnostics.rs +0 -0
  41. {panelkit-0.2.0 → panelkit-0.2.1}/crates/geo/src/lib.rs +0 -0
  42. {panelkit-0.2.0 → panelkit-0.2.1}/crates/geo/src/types.rs +0 -0
  43. {panelkit-0.2.0 → panelkit-0.2.1}/crates/inference/Cargo.toml +0 -0
  44. {panelkit-0.2.0 → panelkit-0.2.1}/crates/inference/src/batch.rs +0 -0
  45. {panelkit-0.2.0 → panelkit-0.2.1}/crates/inference/src/bootstrap.rs +0 -0
  46. {panelkit-0.2.0 → panelkit-0.2.1}/crates/inference/src/ci.rs +0 -0
  47. {panelkit-0.2.0 → panelkit-0.2.1}/crates/inference/src/lib.rs +0 -0
  48. {panelkit-0.2.0 → panelkit-0.2.1}/crates/inference/src/parallel.rs +0 -0
  49. {panelkit-0.2.0 → panelkit-0.2.1}/crates/inference/src/placebo.rs +0 -0
  50. {panelkit-0.2.0 → panelkit-0.2.1}/crates/inference/tests/inference.rs +0 -0
  51. {panelkit-0.2.0 → panelkit-0.2.1}/crates/linalg/Cargo.toml +0 -0
  52. {panelkit-0.2.0 → panelkit-0.2.1}/crates/linalg/src/error.rs +0 -0
  53. {panelkit-0.2.0 → panelkit-0.2.1}/crates/linalg/src/factor/cholesky.rs +0 -0
  54. {panelkit-0.2.0 → panelkit-0.2.1}/crates/linalg/src/factor/eig_sym.rs +0 -0
  55. {panelkit-0.2.0 → panelkit-0.2.1}/crates/linalg/src/factor/mod.rs +0 -0
  56. {panelkit-0.2.0 → panelkit-0.2.1}/crates/linalg/src/factor/qr.rs +0 -0
  57. {panelkit-0.2.0 → panelkit-0.2.1}/crates/linalg/src/factor/randomized.rs +0 -0
  58. {panelkit-0.2.0 → panelkit-0.2.1}/crates/linalg/src/factor/svd.rs +0 -0
  59. {panelkit-0.2.0 → panelkit-0.2.1}/crates/linalg/src/factor/svd_gram.rs +0 -0
  60. {panelkit-0.2.0 → panelkit-0.2.1}/crates/linalg/src/lib.rs +0 -0
  61. {panelkit-0.2.0 → panelkit-0.2.1}/crates/linalg/src/matrix.rs +0 -0
  62. {panelkit-0.2.0 → panelkit-0.2.1}/crates/linalg/src/ops/matmul.rs +0 -0
  63. {panelkit-0.2.0 → panelkit-0.2.1}/crates/linalg/src/ops/mod.rs +0 -0
  64. {panelkit-0.2.0 → panelkit-0.2.1}/crates/linalg/src/ops/norms.rs +0 -0
  65. {panelkit-0.2.0 → panelkit-0.2.1}/crates/linalg/src/ops/transform.rs +0 -0
  66. {panelkit-0.2.0 → panelkit-0.2.1}/crates/linalg/src/opt/mod.rs +0 -0
  67. {panelkit-0.2.0 → panelkit-0.2.1}/crates/linalg/src/opt/simplex.rs +0 -0
  68. {panelkit-0.2.0 → panelkit-0.2.1}/crates/linalg/src/opt/softthresh.rs +0 -0
  69. {panelkit-0.2.0 → panelkit-0.2.1}/crates/linalg/src/rng.rs +0 -0
  70. {panelkit-0.2.0 → panelkit-0.2.1}/crates/linalg/src/solve/lstsq.rs +0 -0
  71. {panelkit-0.2.0 → panelkit-0.2.1}/crates/linalg/src/solve/mod.rs +0 -0
  72. {panelkit-0.2.0 → panelkit-0.2.1}/crates/linalg/src/solve/spd.rs +0 -0
  73. {panelkit-0.2.0 → panelkit-0.2.1}/crates/linalg/tests/numerics.rs +0 -0
  74. {panelkit-0.2.0 → panelkit-0.2.1}/crates/pypanelkit/Cargo.toml +0 -0
  75. {panelkit-0.2.0 → panelkit-0.2.1}/crates/pypanelkit/src/api_did.rs +0 -0
  76. {panelkit-0.2.0 → panelkit-0.2.1}/crates/pypanelkit/src/api_sc.rs +0 -0
  77. {panelkit-0.2.0 → panelkit-0.2.1}/crates/pypanelkit/src/convert.rs +0 -0
  78. {panelkit-0.2.0 → panelkit-0.2.1}/crates/pypanelkit/src/lib.rs +0 -0
  79. {panelkit-0.2.0 → panelkit-0.2.1}/crates/pypanelkit/src/results.rs +0 -0
  80. {panelkit-0.2.0 → panelkit-0.2.1}/python/panelkit/__init__.py +0 -0
  81. {panelkit-0.2.0 → panelkit-0.2.1}/python/panelkit/estimators.py +0 -0
  82. {panelkit-0.2.0 → panelkit-0.2.1}/python/panelkit/py.typed +0 -0
@@ -462,7 +462,7 @@ checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e"
462
462
 
463
463
  [[package]]
464
464
  name = "panelkit-estimators"
465
- version = "0.2.0"
465
+ version = "0.2.1"
466
466
  dependencies = [
467
467
  "criterion",
468
468
  "panelkit-linalg",
@@ -471,7 +471,7 @@ dependencies = [
471
471
 
472
472
  [[package]]
473
473
  name = "panelkit-geo"
474
- version = "0.2.0"
474
+ version = "0.2.1"
475
475
  dependencies = [
476
476
  "panelkit-estimators",
477
477
  "panelkit-inference",
@@ -482,7 +482,7 @@ dependencies = [
482
482
 
483
483
  [[package]]
484
484
  name = "panelkit-inference"
485
- version = "0.2.0"
485
+ version = "0.2.1"
486
486
  dependencies = [
487
487
  "panelkit-estimators",
488
488
  "panelkit-linalg",
@@ -491,7 +491,7 @@ dependencies = [
491
491
 
492
492
  [[package]]
493
493
  name = "panelkit-linalg"
494
- version = "0.2.0"
494
+ version = "0.2.1"
495
495
  dependencies = [
496
496
  "proptest",
497
497
  "rayon",
@@ -623,7 +623,7 @@ dependencies = [
623
623
 
624
624
  [[package]]
625
625
  name = "pypanelkit"
626
- version = "0.2.0"
626
+ version = "0.2.1"
627
627
  dependencies = [
628
628
  "numpy",
629
629
  "panelkit-estimators",
@@ -3,7 +3,7 @@ resolver = "2"
3
3
  members = ["crates/linalg", "crates/estimators", "crates/inference", "crates/geo", "crates/pypanelkit"]
4
4
 
5
5
  [workspace.package]
6
- version = "0.2.0"
6
+ version = "0.2.1"
7
7
  edition = "2021"
8
8
  rust-version = "1.74"
9
9
  license = "MIT OR Apache-2.0"
@@ -251,16 +251,58 @@ SC / ASC / SDID. Returns a report with:
251
251
  estimate-accuracy CI, design-quality bars).
252
252
 
253
253
  Key options: `alpha` (significance level, default 0.10), `target_power`
254
- (default 0.80), `lifts` (the % grid), `methods`, `recommended` (default SDID).
254
+ (default 0.80), `lifts` (the % grid), `methods`, `recommended` (default SDID),
255
+ `lookback`.
256
+
257
+ **How power is simulated (many placebos, not one).** For a treated set, the test
258
+ window of length `test_len` is *slid across the whole history*: every valid start
259
+ position is one placebo experiment. The detection threshold (critical |ATT|)
260
+ comes from those same windows with **no** injected lift (the historical null), and
261
+ power at lift τ is the share of windows whose estimated effect clears that
262
+ threshold. So the estimate is averaged over **many** placebos — `result.n_windows`
263
+ reports how many.
264
+
265
+ **Relationship to GeoLift's `lookback_window`.** GeoLift's lookback is exactly
266
+ this idea — how many recent test-start points to simulate over. By default
267
+ panelkit uses *all* available windows (more placebo samples → a more stable power
268
+ estimate). Pass `lookback=k` to use only the **most-recent k** windows: those have
269
+ the longest pre-periods and reflect current dynamics, so they're the most
270
+ representative of the test you're about to run — at the cost of fewer samples (a
271
+ noisier estimate). It matters when older history is unrepresentative (regime
272
+ change, growth, format changes) or when early windows have very short pre-periods;
273
+ use a `lookback` covering your relevant recent history (e.g. the last ~6–12
274
+ months of windows).
255
275
 
256
276
  ### Choosing a specification — `design.recommend(test_lengths, n_geos_options, target_lift, alphas=…)`
257
277
 
258
278
  Sweeps designs across **test length × number of geos × alpha** and recommends the
259
279
  best (smallest MDE among trustworthy designs, ties broken toward shorter/cheaper).
260
280
  `grid.summary()` prints the recommendation + alternatives; `grid.plot(path)`
261
- renders the **tradeoffs figure** (MDE vs length per #geos, an MDE heatmap over
262
- length × #geos, and alpha sensitivity). Use it to find the "knee" — the cheapest
263
- design that still detects your target lift.
281
+ renders the **tradeoffs figure**. Use it to find the "knee" — the cheapest design
282
+ that still detects your target lift.
283
+
284
+ **Reading the tradeoffs figure:**
285
+ - **Top panel** — minimum detectable lift (%) vs test length, one line per number
286
+ of treated geos. *Lower is better.* The red band marks lifts you *can't*
287
+ detect; lines below your target lift are viable designs. More geos and longer
288
+ tests pull the line down (more signal), but cost more holdout/time — pick the
289
+ knee where the curve flattens.
290
+ - **Bottom-left heatmap** — the same MDE across every (test length × #geos) cell,
291
+ green = small detectable lift (good), red = large (bad), grey = underpowered.
292
+ - **Bottom-right** — with multiple alphas, how the MDE of the recommended design
293
+ moves with the significance level (looser α → smaller MDE, more false
294
+ positives); with one alpha, design confidence by spec.
295
+ - The black ★ marks the recommended design.
296
+
297
+ ### Guardrails — `design.diagnose(treated, test_len)`
298
+
299
+ Before trusting a design, check it. `diagnose` returns a report with
300
+ `.summary()` and `.plot(path)` (the **guardrails figure**): the pre-period fit
301
+ (treated vs synthetic control, so you can *see* whether the counterfactual
302
+ tracks), a seasonality ACF, the holdout share against a healthy band, and a
303
+ banner listing any plain-language warnings (weak fit, volatile markets, strong
304
+ seasonality vs short history, tiny/huge holdout, too few donors). It also exposes
305
+ `.confidence`, `.holdout_pct`, and `.warnings`.
264
306
 
265
307
  ### Picking markets — `design.select_markets(test_len, target_lift, max_treated, …)`
266
308
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: panelkit
3
- Version: 0.2.0
3
+ Version: 0.2.1
4
4
  Classifier: Programming Language :: Rust
5
5
  Classifier: Programming Language :: Python :: 3
6
6
  Classifier: Topic :: Scientific/Engineering
@@ -217,11 +217,16 @@ rep = design.power(treated=["chicago", "denver"], test_len=8, alpha=0.10)
217
217
  print(rep.summary()) # plain-English report: MDE, confidence, warnings
218
218
  rep.plot("design.png") # the figure below
219
219
 
220
+ # guardrails: is this design trustworthy? (pre-fit, seasonality, holdout, warnings)
221
+ guard = design.diagnose(treated=["chicago", "denver"], test_len=8)
222
+ print(guard.summary())
223
+ guard.plot("guardrails.png") # the guardrails figure below
224
+
220
225
  # let it pick the markets for you:
221
226
  ranked = design.select_markets(test_len=8, target_lift=0.05, max_treated=3)
222
227
 
223
228
  # or sweep specifications (length × #geos × significance) and recommend one:
224
- grid = design.recommend(test_lengths=[4, 6, 8, 12], n_geos_options=[1, 2, 3, 4],
229
+ grid = design.recommend(test_lengths=[4, 6, 8, 12], n_geos_options=[3, 5, 10, 20],
225
230
  target_lift=0.05, alphas=[0.05, 0.10])
226
231
  print(grid.summary())
227
232
  grid.plot("tradeoffs.png") # the tradeoffs figure below
@@ -229,9 +234,16 @@ grid.plot("tradeoffs.png") # the tradeoffs figure below
229
234
 
230
235
  ![geo design report](assets/geo_design.png)
231
236
 
237
+ **Guardrails — can you trust the design?** `diagnose(...)` visualizes the
238
+ pre-period fit (treated vs synthetic control), seasonality, holdout share, and
239
+ surfaces plain-language warnings when the design is risky:
240
+
241
+ ![guardrails](assets/geo_guardrails.png)
242
+
232
243
  **Recommendations across specifications.** `recommend(...)` sweeps test length ×
233
244
  number of geos × significance level (`alpha`) and points you at the cheapest
234
- design that still detects your target lift — with a figure of the tradeoffs:
245
+ design that still detects your target lift — with a readable figure of the
246
+ tradeoffs (MDE vs length per #geos, an intuitive heatmap, and alpha sensitivity):
235
247
 
236
248
  ![specification tradeoffs](assets/geo_scenarios.png)
237
249
 
@@ -187,11 +187,16 @@ rep = design.power(treated=["chicago", "denver"], test_len=8, alpha=0.10)
187
187
  print(rep.summary()) # plain-English report: MDE, confidence, warnings
188
188
  rep.plot("design.png") # the figure below
189
189
 
190
+ # guardrails: is this design trustworthy? (pre-fit, seasonality, holdout, warnings)
191
+ guard = design.diagnose(treated=["chicago", "denver"], test_len=8)
192
+ print(guard.summary())
193
+ guard.plot("guardrails.png") # the guardrails figure below
194
+
190
195
  # let it pick the markets for you:
191
196
  ranked = design.select_markets(test_len=8, target_lift=0.05, max_treated=3)
192
197
 
193
198
  # or sweep specifications (length × #geos × significance) and recommend one:
194
- grid = design.recommend(test_lengths=[4, 6, 8, 12], n_geos_options=[1, 2, 3, 4],
199
+ grid = design.recommend(test_lengths=[4, 6, 8, 12], n_geos_options=[3, 5, 10, 20],
195
200
  target_lift=0.05, alphas=[0.05, 0.10])
196
201
  print(grid.summary())
197
202
  grid.plot("tradeoffs.png") # the tradeoffs figure below
@@ -199,9 +204,16 @@ grid.plot("tradeoffs.png") # the tradeoffs figure below
199
204
 
200
205
  ![geo design report](assets/geo_design.png)
201
206
 
207
+ **Guardrails — can you trust the design?** `diagnose(...)` visualizes the
208
+ pre-period fit (treated vs synthetic control), seasonality, holdout share, and
209
+ surfaces plain-language warnings when the design is risky:
210
+
211
+ ![guardrails](assets/geo_guardrails.png)
212
+
202
213
  **Recommendations across specifications.** `recommend(...)` sweeps test length ×
203
214
  number of geos × significance level (`alpha`) and points you at the cheapest
204
- design that still detects your target lift — with a figure of the tradeoffs:
215
+ design that still detects your target lift — with a readable figure of the
216
+ tradeoffs (MDE vs length per #geos, an intuitive heatmap, and alpha sensitivity):
205
217
 
206
218
  ![specification tradeoffs](assets/geo_scenarios.png)
207
219
 
@@ -103,6 +103,7 @@ pub fn power_curve(
103
103
  alpha: f64,
104
104
  target_power: f64,
105
105
  min_pre: usize,
106
+ lookback: Option<usize>,
106
107
  ) -> PowerResult {
107
108
  let t = y.cols();
108
109
  assert!(test_len >= 1 && test_len < t, "test_len out of range");
@@ -111,7 +112,18 @@ pub fn power_curve(
111
112
  first <= t - test_len,
112
113
  "not enough periods for the requested pre-window + test_len"
113
114
  );
114
- let starts: Vec<usize> = (first..=(t - test_len)).collect();
115
+ // Every valid sliding test-window start position is one historical placebo.
116
+ // We power over MANY of them (the count is `n_windows`). `lookback`, when set,
117
+ // keeps only the most-recent K windows — GeoLift's "lookback_window": those
118
+ // are the most representative of the upcoming test (recent dynamics, longest
119
+ // pre-periods), at the cost of fewer placebo samples.
120
+ let mut starts: Vec<usize> = (first..=(t - test_len)).collect();
121
+ if let Some(k) = lookback {
122
+ let k = k.max(1);
123
+ if starts.len() > k {
124
+ starts = starts.split_off(starts.len() - k);
125
+ }
126
+ }
115
127
  let n_windows = starts.len();
116
128
  let (base_mean, base_sum) = treated_baseline(y, treated);
117
129
 
@@ -46,6 +46,13 @@ pub struct SelectConfig {
46
46
  /// How many candidate sets to sample/evaluate.
47
47
  pub n_candidates: usize,
48
48
  pub seed: u64,
49
+ /// If `Some(k)`, only consider candidate sets of **exactly** `k` markets
50
+ /// (used by the spec sweep so each "#geos" row reflects that size). If
51
+ /// `None`, considers all sizes from 1 to `max_treated`.
52
+ pub exact_size: Option<usize>,
53
+ /// Number of most-recent historical placebo windows to power over
54
+ /// (GeoLift's lookback). `None` = all available windows.
55
+ pub lookback: Option<usize>,
49
56
  }
50
57
 
51
58
  /// Evaluate a single candidate set: quick power probe + diagnostics → score.
@@ -61,6 +68,7 @@ pub fn evaluate(y: &Mat, treated: &[usize], cfg: &SelectConfig) -> MarketCandida
61
68
  cfg.alpha,
62
69
  cfg.target_power,
63
70
  cfg.min_pre,
71
+ cfg.lookback,
64
72
  );
65
73
  let power_at_target = pr
66
74
  .points
@@ -87,13 +95,36 @@ pub fn evaluate(y: &Mat, treated: &[usize], cfg: &SelectConfig) -> MarketCandida
87
95
  }
88
96
  }
89
97
 
90
- /// Build the candidate list: every single eligible market, plus sampled subsets
91
- /// of size 2..=max_treated.
98
+ /// Build the candidate list. With `exact_size = Some(k)`, every candidate has
99
+ /// exactly `k` markets; otherwise it's every singleton plus sampled subsets of
100
+ /// size 2..=max_treated.
92
101
  fn candidate_sets(cfg: &SelectConfig) -> Vec<Vec<usize>> {
93
- let mut sets: Vec<Vec<usize>> = cfg.eligible.iter().map(|&u| vec![u]).collect();
102
+ let mut rng = Xoshiro256pp::seed_from_u64(cfg.seed);
103
+ let mut seen = std::collections::HashSet::new();
104
+ let mut sets: Vec<Vec<usize>> = Vec::new();
105
+
106
+ if let Some(k) = cfg.exact_size {
107
+ let k = k.min(cfg.eligible.len()).max(1);
108
+ if k == 1 {
109
+ return cfg.eligible.iter().map(|&u| vec![u]).collect();
110
+ }
111
+ let mut attempts = 0;
112
+ while sets.len() < cfg.n_candidates && attempts < cfg.n_candidates * 40 {
113
+ attempts += 1;
114
+ let mut pool = cfg.eligible.clone();
115
+ rng.shuffle(&mut pool);
116
+ let mut pick: Vec<usize> = pool.into_iter().take(k).collect();
117
+ pick.sort_unstable();
118
+ if seen.insert(pick.clone()) {
119
+ sets.push(pick);
120
+ }
121
+ }
122
+ return sets;
123
+ }
124
+
125
+ // Mixed-size search: all singletons + sampled subsets of size 2..=max_treated.
126
+ sets = cfg.eligible.iter().map(|&u| vec![u]).collect();
94
127
  if cfg.max_treated >= 2 && cfg.eligible.len() >= 2 {
95
- let mut rng = Xoshiro256pp::seed_from_u64(cfg.seed);
96
- let mut seen = std::collections::HashSet::new();
97
128
  for s in &sets {
98
129
  seen.insert(s.clone());
99
130
  }
@@ -31,7 +31,7 @@ fn geo_panel(n: usize, t: usize, seed: u64) -> Mat {
31
31
  fn power_increases_with_lift_and_mde_is_sane() {
32
32
  let y = geo_panel(15, 60, 1);
33
33
  let lifts = vec![0.0, 0.02, 0.05, 0.10, 0.20];
34
- let pr = power_curve(&y, &[0], 10, &lifts, Method::Sc, 0.10, 0.8, 20);
34
+ let pr = power_curve(&y, &[0], 10, &lifts, Method::Sc, 0.10, 0.8, 20, None);
35
35
 
36
36
  // Power is (weakly) increasing in lift.
37
37
  for w in pr.points.windows(2) {
@@ -60,11 +60,24 @@ fn power_increases_with_lift_and_mde_is_sane() {
60
60
  }
61
61
  }
62
62
 
63
+ #[test]
64
+ fn lookback_limits_to_recent_windows() {
65
+ let y = geo_panel(15, 60, 1);
66
+ let lifts = vec![0.0, 0.05];
67
+ let all = power_curve(&y, &[0], 10, &lifts, Method::Sc, 0.10, 0.8, 20, None);
68
+ let recent = power_curve(&y, &[0], 10, &lifts, Method::Sc, 0.10, 0.8, 20, Some(8));
69
+ assert_eq!(recent.n_windows, 8, "lookback should cap to 8 windows");
70
+ assert!(
71
+ all.n_windows > recent.n_windows,
72
+ "all-windows count should exceed lookback"
73
+ );
74
+ }
75
+
63
76
  #[test]
64
77
  fn estimated_lift_tracks_true_lift() {
65
78
  let y = geo_panel(15, 60, 2);
66
79
  let lifts = vec![0.0, 0.10];
67
- let pr = power_curve(&y, &[0], 10, &lifts, Method::Sc, 0.10, 0.8, 20);
80
+ let pr = power_curve(&y, &[0], 10, &lifts, Method::Sc, 0.10, 0.8, 20, None);
68
81
  // At a 10% injected lift, the mean estimated lift should be in the ballpark.
69
82
  let p10 = pr.points.last().unwrap();
70
83
  assert!(
@@ -90,7 +103,7 @@ fn all_three_methods_run() {
90
103
  let y = geo_panel(15, 60, 4);
91
104
  let lifts = vec![0.0, 0.10];
92
105
  for m in [Method::Sc, Method::Asc, Method::Sdid] {
93
- let pr = power_curve(&y, &[0], 10, &lifts, m, 0.10, 0.8, 20);
106
+ let pr = power_curve(&y, &[0], 10, &lifts, m, 0.10, 0.8, 20, None);
94
107
  assert_eq!(pr.method, m);
95
108
  assert_eq!(pr.points.len(), 2);
96
109
  }
@@ -110,6 +123,8 @@ fn market_selection_ranks_candidates() {
110
123
  min_pre: 20,
111
124
  n_candidates: 20,
112
125
  seed: 7,
126
+ exact_size: None,
127
+ lookback: None,
113
128
  };
114
129
  let ranked = select_markets(&y, &cfg);
115
130
  assert!(!ranked.is_empty());
@@ -117,6 +132,13 @@ fn market_selection_ranks_candidates() {
117
132
  for w in ranked.windows(2) {
118
133
  assert!(w[0].score >= w[1].score - 1e-12);
119
134
  }
135
+ // exact_size: every candidate has exactly that many markets.
136
+ let cfg2 = SelectConfig {
137
+ exact_size: Some(2),
138
+ ..cfg.clone()
139
+ };
140
+ let ranked2 = select_markets(&y, &cfg2);
141
+ assert!(ranked2.iter().all(|c| c.treated.len() == 2));
120
142
  // Every candidate has a valid holdout and confidence.
121
143
  for c in &ranked {
122
144
  assert!(c.holdout_pct > 0.0 && c.holdout_pct < 1.0);
@@ -23,7 +23,7 @@ fn parse_method(s: &str) -> PyResult<Method> {
23
23
 
24
24
  /// Power analysis for one method via historical placebo with injected lift.
25
25
  #[pyfunction]
26
- #[pyo3(signature = (y, treated, test_len, lifts, method="sdid", alpha=0.1, target_power=0.8, min_pre=0))]
26
+ #[pyo3(signature = (y, treated, test_len, lifts, method="sdid", alpha=0.1, target_power=0.8, min_pre=0, lookback=None))]
27
27
  #[allow(clippy::too_many_arguments)]
28
28
  pub fn geo_power(
29
29
  py: Python<'_>,
@@ -35,6 +35,7 @@ pub fn geo_power(
35
35
  alpha: f64,
36
36
  target_power: f64,
37
37
  min_pre: usize,
38
+ lookback: Option<usize>,
38
39
  ) -> PyResult<PyPowerResult> {
39
40
  let m = parse_method(method)?;
40
41
  let mat = mat_from_numpy(&y);
@@ -53,6 +54,7 @@ pub fn geo_power(
53
54
  alpha,
54
55
  target_power,
55
56
  min_pre,
57
+ lookback,
56
58
  )
57
59
  });
58
60
  Ok(PyPowerResult {
@@ -94,7 +96,7 @@ pub fn geo_diagnostics(
94
96
 
95
97
  /// Search and rank candidate treatment-market sets.
96
98
  #[pyfunction]
97
- #[pyo3(signature = (y, eligible, max_treated, test_len, target_lift, method="sdid", alpha=0.1, target_power=0.8, min_pre=0, n_candidates=200, seed=0))]
99
+ #[pyo3(signature = (y, eligible, max_treated, test_len, target_lift, method="sdid", alpha=0.1, target_power=0.8, min_pre=0, n_candidates=200, seed=0, exact_size=None, lookback=None))]
98
100
  #[allow(clippy::too_many_arguments)]
99
101
  pub fn geo_select(
100
102
  py: Python<'_>,
@@ -109,6 +111,8 @@ pub fn geo_select(
109
111
  min_pre: usize,
110
112
  n_candidates: usize,
111
113
  seed: u64,
114
+ exact_size: Option<usize>,
115
+ lookback: Option<usize>,
112
116
  ) -> PyResult<Vec<PyMarketCandidate>> {
113
117
  let m = parse_method(method)?;
114
118
  let mat = mat_from_numpy(&y);
@@ -128,6 +132,8 @@ pub fn geo_select(
128
132
  min_pre,
129
133
  n_candidates,
130
134
  seed,
135
+ exact_size,
136
+ lookback,
131
137
  };
132
138
  let ranked = py.allow_threads(move || select_markets(&mat, &cfg));
133
139
  Ok(ranked
@@ -4,7 +4,7 @@ build-backend = "maturin"
4
4
 
5
5
  [project]
6
6
  name = "panelkit"
7
- version = "0.2.0"
7
+ version = "0.2.1"
8
8
  description = "Fast, from-scratch causal-inference estimators for panel/geo experiments (SC, ASC, SDID, DiD, MC-NNM)."
9
9
  readme = "README.md"
10
10
  requires-python = ">=3.9"
@@ -160,6 +160,7 @@ def geo_power(
160
160
  alpha: float = ...,
161
161
  target_power: float = ...,
162
162
  min_pre: int = ...,
163
+ lookback: Optional[int] = ...,
163
164
  ) -> PowerResult: ...
164
165
  def geo_diagnostics(
165
166
  y: npt.NDArray[np.float64], treated: Sequence[int], test_len: int
@@ -176,6 +177,8 @@ def geo_select(
176
177
  min_pre: int = ...,
177
178
  n_candidates: int = ...,
178
179
  seed: int = ...,
180
+ exact_size: Optional[int] = ...,
181
+ lookback: Optional[int] = ...,
179
182
  ) -> list[MarketCandidate]: ...
180
183
  def fit_callaway_py(
181
184
  y: npt.NDArray[np.float64],
@@ -134,6 +134,56 @@ def _verdict(confidence, mde_pct):
134
134
  "history length, or holdout size before spending.")
135
135
 
136
136
 
137
+ class _DiagnosticsReport:
138
+ """Real-world guardrails for a design, with a summary and a visual."""
139
+
140
+ def __init__(self, treated_names, t0, test_len, diag, treated_series, synthetic):
141
+ self.treated_names = treated_names
142
+ self.t0 = t0
143
+ self.test_len = test_len
144
+ self._raw = diag
145
+ self.treated_series = np.asarray(treated_series, dtype=float)
146
+ self.synthetic = np.asarray(synthetic, dtype=float)
147
+
148
+ @property
149
+ def holdout_pct(self):
150
+ return self._raw.holdout_pct
151
+
152
+ @property
153
+ def confidence(self):
154
+ return self._raw.confidence
155
+
156
+ @property
157
+ def warnings(self):
158
+ return list(self._raw.warnings)
159
+
160
+ def summary(self) -> str:
161
+ d = self._raw
162
+ lines = ["GUARDRAILS — " + ", ".join(map(str, self.treated_names))]
163
+ lines.append(f" holdout : {100*d.holdout_pct:.1f}% of volume")
164
+ lines.append(f" pre-period fit : rel. RMSPE {d.pre_fit_rel:.2f} "
165
+ f"({'good' if d.pre_fit_rel < 0.25 else 'fair' if d.pre_fit_rel < 0.5 else 'weak'})")
166
+ lines.append(f" improvement v naive: {100*d.improvement_vs_naive:.0f}%")
167
+ lines.append(f" seasonality : {d.seasonality_strength:.2f}")
168
+ lines.append(f" stability : {d.stability_score:.2f}")
169
+ lines.append(f" confidence : {d.confidence:.0f}/100")
170
+ if d.warnings:
171
+ lines.append(" warnings:")
172
+ for w in d.warnings:
173
+ lines.append(f" ⚠ {w}")
174
+ else:
175
+ lines.append(" ✓ no warnings")
176
+ return "\n".join(lines)
177
+
178
+ def plot(self, path: str | None = None):
179
+ """Render the guardrails figure. Returns the matplotlib Figure."""
180
+ return _plot_guardrails(self, path)
181
+
182
+ def __repr__(self):
183
+ return (f"GuardrailsReport(confidence={self.confidence:.0f}, "
184
+ f"holdout={100*self.holdout_pct:.1f}%, warnings={len(self.warnings)})")
185
+
186
+
137
187
  class GeoDesign:
138
188
  """A geo panel ready for power analysis and market selection.
139
189
 
@@ -281,23 +331,48 @@ class GeoDesign:
281
331
  alpha: float = 0.10,
282
332
  target_power: float = 0.80,
283
333
  recommended: str = "SDID",
334
+ lookback: int | None = None,
284
335
  ) -> _PowerReport:
285
- """Power analysis for a specified treated-market set across methods."""
336
+ """Power analysis for a specified treated-market set across methods.
337
+
338
+ Powers over many historical placebo windows (sliding the test window
339
+ across history); ``lookback=k`` restricts to the most-recent ``k`` windows
340
+ (GeoLift-style), which are most representative of the upcoming test."""
286
341
  idx = self._resolve(treated)
287
342
  names = [self.names[i] for i in idx]
288
343
  lifts = list(_DEFAULT_LIFTS if lifts is None else lifts)
289
344
  if 0.0 not in lifts:
290
345
  lifts = [0.0] + list(lifts)
291
346
  lifts = sorted(set(float(x) for x in lifts))
347
+ lb = None if lookback is None else int(lookback)
292
348
  results = {}
293
349
  for m in methods:
294
350
  results[m] = _panelkit.geo_power(
295
- self.Y, idx, int(test_len), lifts, m.lower(), alpha, target_power, 0
351
+ self.Y, idx, int(test_len), lifts, m.lower(), alpha, target_power, 0, lb
296
352
  )
297
353
  diag = _panelkit.geo_diagnostics(self.Y, idx, int(test_len))
298
354
  rec = recommended if recommended in results else list(results)[0]
299
355
  return _PowerReport(self, idx, names, test_len, results, diag, rec, alpha, target_power)
300
356
 
357
+ def diagnose(self, treated, test_len: int) -> "_DiagnosticsReport":
358
+ """Real-world guardrails for a treated-market set: pre-period fit,
359
+ seasonality, holdout, stability, and warnings — with a visual.
360
+
361
+ Returns a report with ``.summary()`` and ``.plot(path)`` (the guardrails
362
+ figure: treated-vs-synthetic pre-fit, seasonality ACF, holdout share, and
363
+ a scorecard listing any warnings)."""
364
+ idx = self._resolve(treated)
365
+ names = [self.names[i] for i in idx]
366
+ t0 = self.t - int(test_len)
367
+ diag = _panelkit.geo_diagnostics(self.Y, idx, int(test_len))
368
+ # Treated-average series and the SC counterfactual (from the SC weights).
369
+ treated_series = self.Y[idx].mean(axis=0)
370
+ scres = _panelkit.fit_sc(self.Y, idx, int(t0), 0.0, False, 0.95)
371
+ w = np.asarray(scres.weights, dtype=float)
372
+ donors = np.asarray(scres.donor_ids, dtype=int)
373
+ synthetic = self.Y[donors].T @ w if len(donors) else np.full(self.t, np.nan)
374
+ return _DiagnosticsReport(names, t0, test_len, diag, treated_series, synthetic)
375
+
301
376
  def select_markets(
302
377
  self,
303
378
  test_len: int,
@@ -310,12 +385,20 @@ class GeoDesign:
310
385
  n_candidates: int = 200,
311
386
  seed: int = 0,
312
387
  top: int = 10,
388
+ exact_size: int | None = None,
389
+ lookback: int | None = None,
313
390
  ) -> list:
314
- """Search candidate treatment-market sets and return the top ranked."""
391
+ """Search candidate treatment-market sets and return the top ranked.
392
+
393
+ ``exact_size=k`` restricts the search to sets of exactly ``k`` markets
394
+ (otherwise sizes 1..``max_treated`` are considered). ``lookback=k`` powers
395
+ over the most-recent ``k`` historical windows (GeoLift-style)."""
315
396
  elig = self._resolve(eligible) if eligible is not None else list(range(self.n))
316
397
  ranked = _panelkit.geo_select(
317
398
  self.Y, elig, int(max_treated), int(test_len), float(target_lift),
318
399
  method.lower(), alpha, target_power, 0, int(n_candidates), int(seed),
400
+ None if exact_size is None else int(exact_size),
401
+ None if lookback is None else int(lookback),
319
402
  )
320
403
  out = []
321
404
  for c in ranked[:top]:
@@ -342,6 +425,7 @@ class GeoDesign:
342
425
  n_candidates: int = 80,
343
426
  seed: int = 0,
344
427
  min_confidence: float = 60.0,
428
+ lookback: int | None = None,
345
429
  ) -> "_ScenarioGrid":
346
430
  """Sweep designs across **specifications** — test length × number of geos
347
431
  × significance level (alpha) — and recommend the best.
@@ -359,10 +443,9 @@ class GeoDesign:
359
443
  test_len=tl, target_lift=target_lift, max_treated=ng,
360
444
  eligible=eligible, method=method, alpha=alpha,
361
445
  target_power=target_power, n_candidates=n_candidates,
362
- seed=seed, top=n_candidates,
446
+ seed=seed, top=1, exact_size=ng, lookback=lookback,
363
447
  )
364
- exact = [c for c in ranked if len(c["markets"]) == ng]
365
- best = exact[0] if exact else (ranked[0] if ranked else None)
448
+ best = ranked[0] if ranked else None
366
449
  if best is None:
367
450
  continue
368
451
  rows.append({
@@ -549,6 +632,10 @@ def _plot_power(rep: _PowerReport, path):
549
632
  return fig
550
633
 
551
634
 
635
+ # Distinct, colorblind-friendly line colors (one per #geos), not a gradient.
636
+ _GEO_PALETTE = ["#2563eb", "#059669", "#d97706", "#dc2626", "#7c3aed", "#0891b2"]
637
+
638
+
552
639
  def _plot_scenarios(grid: "_ScenarioGrid", path):
553
640
  _, plt = _require_mpl()
554
641
  import numpy as _np
@@ -557,37 +644,56 @@ def _plot_scenarios(grid: "_ScenarioGrid", path):
557
644
  a0 = grid.alphas[0] # primary alpha for the main panels
558
645
  rec = grid.recommended
559
646
  by = {(r["alpha"], r["test_len"], r["n_geos"]): r for r in grid.rows}
647
+ color_for = {ng: _GEO_PALETTE[i % len(_GEO_PALETTE)]
648
+ for i, ng in enumerate(grid.n_geos_options)}
560
649
 
561
- fig = plt.figure(figsize=(11, 7.4))
650
+ plt.rcParams.update({"font.size": 11, "axes.titlesize": 12})
651
+ fig = plt.figure(figsize=(12, 7.6))
562
652
  fig.patch.set_facecolor("white")
563
- gs = GridSpec(2, 2, figure=fig, height_ratios=[1.1, 1.0], hspace=0.36, wspace=0.28)
653
+ gs = GridSpec(2, 2, figure=fig, height_ratios=[1.15, 1.0], hspace=0.42, wspace=0.30)
564
654
 
565
- # Panel 1: MDE vs test length, one line per #geos (at primary alpha).
655
+ # ---- Panel 1: MDE vs test length, one labelled line per #geos. ----
566
656
  ax = fig.add_subplot(gs[0, :])
567
- cmap = plt.get_cmap("viridis")
568
- for j, ng in enumerate(grid.n_geos_options):
657
+ ymax = 0.0
658
+ for ng in grid.n_geos_options:
569
659
  xs, ys = [], []
570
660
  for tl in grid.test_lengths:
571
661
  r = by.get((a0, tl, ng))
572
662
  if r and r["mde_pct"] is not None:
573
663
  xs.append(tl)
574
664
  ys.append(100 * r["mde_pct"])
575
- if xs:
576
- color = cmap(j / max(1, len(grid.n_geos_options) - 1))
577
- ax.plot(xs, ys, "o-", color=color, lw=2.0, markersize=5, label=f"{ng} geos")
578
- ax.axhline(100 * grid.target_lift, ls="--", color="#374151", lw=1.0,
579
- label=f"target lift {100*grid.target_lift:.0f}%")
665
+ if not xs:
666
+ continue
667
+ ymax = max(ymax, max(ys))
668
+ c = color_for[ng]
669
+ ax.plot(xs, ys, "o-", color=c, lw=2.6, markersize=7, label=f"{ng} geos", zorder=3)
670
+ # label each line at its right end so you don't need to trace the legend
671
+ ax.annotate(f"{ng} geos", (xs[-1], ys[-1]), textcoords="offset points",
672
+ xytext=(8, 0), va="center", color=c, fontweight="bold", fontsize=10)
673
+ tgt = 100 * grid.target_lift
674
+ ax.axhline(tgt, ls="--", color="#374151", lw=1.2)
675
+ ax.axhspan(tgt, max(ymax, tgt) * 1.08 + 0.5, color="#fca5a5", alpha=0.12)
676
+ ax.annotate("can't detect below this lift", (grid.test_lengths[0], tgt),
677
+ textcoords="offset points", xytext=(4, 6), color="#b91c1c", fontsize=9)
580
678
  if rec is not None and rec["alpha"] == a0 and rec["mde_pct"] is not None:
581
- ax.plot(rec["test_len"], 100 * rec["mde_pct"], "*", color="#dc2626",
582
- markersize=20, zorder=5, label="recommended")
679
+ ax.plot(rec["test_len"], 100 * rec["mde_pct"], "*", color="#111827",
680
+ markersize=22, zorder=6)
681
+ ax.annotate("recommended", (rec["test_len"], 100 * rec["mde_pct"]),
682
+ textcoords="offset points", xytext=(6, -16), fontweight="bold")
583
683
  ax.set_xlabel("test length (periods)")
584
- ax.set_ylabel("minimum detectable lift (%)")
585
- ax.set_title(f"Detectable lift vs test length & number of geos (α={a0:.2f})",
586
- fontweight="bold")
684
+ ax.set_ylabel("min. detectable lift (%) · lower = better")
685
+ ax.set_title(f"How small a lift can you detect? (α = {a0:.2f})", fontweight="bold")
686
+ ax.set_xticks(grid.test_lengths)
687
+ ax.set_ylim(0, max(ymax, tgt) * 1.12 + 0.5)
688
+ ax.margins(x=0.08)
587
689
  ax.grid(True, alpha=0.25)
588
- ax.legend(loc="upper right", framealpha=0.9, fontsize=8, ncol=2)
690
+ # endpoint labels already identify lines; keep the legend out of the way
691
+ # (lower-left, where the curves don't go).
692
+ ax.legend(title="treatment markets", loc="lower left", framealpha=0.95, ncol=2,
693
+ fontsize=9)
694
+ ax.margins(x=0.12) # room for the right-edge endpoint labels
589
695
 
590
- # Panel 2: heatmap of MDE over (n_geos × test_len) at primary alpha.
696
+ # ---- Panel 2: MDE heatmap (red = worse, green = better), readable text. ----
591
697
  ax2 = fig.add_subplot(gs[1, 0])
592
698
  grid_mde = _np.full((len(grid.n_geos_options), len(grid.test_lengths)), _np.nan)
593
699
  for i, ng in enumerate(grid.n_geos_options):
@@ -595,23 +701,34 @@ def _plot_scenarios(grid: "_ScenarioGrid", path):
595
701
  r = by.get((a0, tl, ng))
596
702
  if r and r["mde_pct"] is not None:
597
703
  grid_mde[i, k] = 100 * r["mde_pct"]
598
- im = ax2.imshow(grid_mde, aspect="auto", cmap="viridis_r", origin="lower")
704
+ cmap = plt.get_cmap("RdYlGn_r").copy()
705
+ cmap.set_bad("#e5e7eb") # grey for un-powered cells
706
+ finite = grid_mde[_np.isfinite(grid_mde)]
707
+ vmin = float(finite.min()) if finite.size else 0.0
708
+ vmax = float(finite.max()) if finite.size else 1.0
709
+ im = ax2.imshow(grid_mde, aspect="auto", cmap=cmap, origin="lower",
710
+ vmin=vmin, vmax=vmax)
599
711
  ax2.set_xticks(range(len(grid.test_lengths)))
600
712
  ax2.set_xticklabels(grid.test_lengths)
601
713
  ax2.set_yticks(range(len(grid.n_geos_options)))
602
714
  ax2.set_yticklabels(grid.n_geos_options)
603
715
  ax2.set_xlabel("test length")
604
716
  ax2.set_ylabel("number of geos")
605
- ax2.set_title("MDE (%) heatmap", fontweight="bold")
717
+ ax2.set_title("Detectable lift (%) by design", fontweight="bold")
718
+ span = (vmax - vmin) or 1.0
606
719
  for i in range(grid_mde.shape[0]):
607
720
  for k in range(grid_mde.shape[1]):
608
- if not _np.isnan(grid_mde[i, k]):
609
- ax2.text(k, i, f"{grid_mde[i, k]:.1f}", ha="center", va="center",
610
- color="white", fontsize=8)
611
- fig.colorbar(im, ax=ax2, fraction=0.046, pad=0.04, label="MDE %")
612
-
613
- # Panel 3: alpha sensitivity for the recommended (test_len, n_geos), or — if
614
- # only one alpha — confidence vs test length by #geos.
721
+ v = grid_mde[i, k]
722
+ if not _np.isnan(v):
723
+ r_, g_, b_, _ = cmap((v - vmin) / span)
724
+ lum = 0.299 * r_ + 0.587 * g_ + 0.114 * b_
725
+ ax2.text(k, i, f"{v:.1f}", ha="center", va="center",
726
+ color="black" if lum > 0.55 else "white",
727
+ fontsize=10, fontweight="bold")
728
+ fig.colorbar(im, ax=ax2, fraction=0.046, pad=0.04).set_label(
729
+ "MDE (%) — greener is better", fontsize=9)
730
+
731
+ # ---- Panel 3: alpha sensitivity (recommended spec), else confidence. ----
615
732
  ax3 = fig.add_subplot(gs[1, 1])
616
733
  if len(grid.alphas) > 1 and rec is not None:
617
734
  xs, ys = [], []
@@ -620,14 +737,17 @@ def _plot_scenarios(grid: "_ScenarioGrid", path):
620
737
  if r and r["mde_pct"] is not None:
621
738
  xs.append(a)
622
739
  ys.append(100 * r["mde_pct"])
623
- ax3.plot(xs, ys, "o-", color=_PK_BLUE, lw=2.0)
740
+ ax3.plot(xs, ys, "o-", color=_PK_BLUE, lw=2.6, markersize=7)
741
+ for xa, ya in zip(xs, ys):
742
+ ax3.annotate(f"{ya:.1f}%", (xa, ya), textcoords="offset points",
743
+ xytext=(0, 8), ha="center", fontsize=9)
624
744
  ax3.set_xlabel("significance level α")
625
- ax3.set_ylabel("MDE (%)")
626
- ax3.set_title(f"Alpha sensitivity ({rec['n_geos']}g × {rec['test_len']}p)",
745
+ ax3.set_ylabel("min. detectable lift (%)")
746
+ ax3.set_title(f"Looser α → smaller MDE ({rec['n_geos']}g × {rec['test_len']}p)",
627
747
  fontweight="bold")
748
+ ax3.margins(x=0.15, y=0.2)
628
749
  else:
629
- cmap = plt.get_cmap("viridis")
630
- for j, ng in enumerate(grid.n_geos_options):
750
+ for ng in grid.n_geos_options:
631
751
  xs, ys = [], []
632
752
  for tl in grid.test_lengths:
633
753
  r = by.get((a0, tl, ng))
@@ -635,19 +755,113 @@ def _plot_scenarios(grid: "_ScenarioGrid", path):
635
755
  xs.append(tl)
636
756
  ys.append(r["confidence"])
637
757
  if xs:
638
- ax3.plot(xs, ys, "o-", lw=1.8, markersize=4,
639
- color=cmap(j / max(1, len(grid.n_geos_options) - 1)),
640
- label=f"{ng} geos")
641
- ax3.axhline(grid.min_confidence, ls=":", color="#dc2626", lw=1.0,
758
+ ax3.plot(xs, ys, "o-", lw=2.2, markersize=6,
759
+ color=color_for[ng], label=f"{ng} geos")
760
+ ax3.axhline(grid.min_confidence, ls=":", color="#dc2626", lw=1.2,
642
761
  label="min confidence")
643
762
  ax3.set_xlabel("test length")
644
- ax3.set_ylabel("design confidence")
645
- ax3.legend(fontsize=7, framealpha=0.9)
763
+ ax3.set_ylabel("design confidence (0–100)")
764
+ ax3.legend(fontsize=8, framealpha=0.95)
646
765
  ax3.set_title("Design confidence by spec", fontweight="bold")
647
766
  ax3.grid(True, alpha=0.25)
648
767
 
649
- fig.suptitle("panelkit · specification tradeoffs", fontsize=13, fontweight="bold",
650
- x=0.01, ha="left")
768
+ fig.suptitle("panelkit · specification tradeoffs", fontsize=14, fontweight="bold",
769
+ x=0.012, ha="left")
770
+ if path:
771
+ fig.savefig(path, dpi=150, bbox_inches="tight")
772
+ return fig
773
+
774
+
775
+ def _plot_guardrails(rep: "_DiagnosticsReport", path):
776
+ _, plt = _require_mpl()
777
+ import numpy as _np
778
+ from matplotlib.gridspec import GridSpec
779
+
780
+ d = rep._raw
781
+ t0 = rep.t0
782
+ T = len(rep.treated_series)
783
+ x = _np.arange(T)
784
+
785
+ plt.rcParams.update({"font.size": 11, "axes.titlesize": 12})
786
+ fig = plt.figure(figsize=(12, 7.8))
787
+ fig.patch.set_facecolor("white")
788
+ gs = GridSpec(2, 2, figure=fig, height_ratios=[1.0, 1.0], hspace=0.40, wspace=0.26)
789
+
790
+ # ---- A: pre-period fit — treated vs synthetic control. ----
791
+ ax = fig.add_subplot(gs[0, :])
792
+ ax.axvspan(t0 - 0.5, T - 0.5, color="#dbeafe", alpha=0.5, label="test window")
793
+ ax.plot(x, rep.treated_series, color="#111827", lw=2.2, label="treated (actual)")
794
+ if _np.isfinite(rep.synthetic).all():
795
+ ax.plot(x, rep.synthetic, color="#2563eb", lw=2.0, ls="--",
796
+ label="synthetic control")
797
+ ax.axvline(t0 - 0.5, color="#374151", lw=1.0, ls=":")
798
+ fit_word = "good" if d.pre_fit_rel < 0.25 else "fair" if d.pre_fit_rel < 0.5 else "weak"
799
+ fit_color = "#059669" if d.pre_fit_rel < 0.25 else "#d97706" if d.pre_fit_rel < 0.5 else "#dc2626"
800
+ ax.set_title("Pre-period fit: does the synthetic control track the treated markets?",
801
+ fontweight="bold")
802
+ ax.set_xlabel("period")
803
+ ax.set_ylabel("outcome")
804
+ ax.grid(True, alpha=0.25)
805
+ ax.legend(loc="upper left", framealpha=0.95, fontsize=9)
806
+ ax.annotate(f"pre-fit: {fit_word} (rel. RMSPE {d.pre_fit_rel:.2f})",
807
+ xy=(0.99, 0.04), xycoords="axes fraction", ha="right",
808
+ color=fit_color, fontweight="bold", fontsize=10)
809
+
810
+ # ---- B: seasonality — ACF of pre-period first differences. ----
811
+ axb = fig.add_subplot(gs[1, 0])
812
+ pre = rep.treated_series[:t0]
813
+ dd = _np.diff(pre)
814
+ dd = dd - dd.mean()
815
+ denom = (dd ** 2).sum()
816
+ max_lag = int(min(len(dd) // 2, 26))
817
+ lags = list(range(1, max(max_lag, 2)))
818
+ acf = [float((dd[lag:] * dd[:-lag]).sum() / denom) if denom > 0 else 0.0 for lag in lags]
819
+ best_lag = lags[int(_np.argmax(acf))] if acf else 0
820
+ colors = ["#dc2626" if (lg == best_lag and d.seasonality_strength > 0.3) else "#93c5fd"
821
+ for lg in lags]
822
+ axb.bar(lags, acf, color=colors)
823
+ axb.axhline(0, color="#374151", lw=0.8)
824
+ axb.set_xlabel("lag (periods)")
825
+ axb.set_ylabel("autocorrelation")
826
+ seas_word = ("strong" if d.seasonality_strength > 0.5 else
827
+ "some" if d.seasonality_strength > 0.3 else "weak")
828
+ title = f"Seasonality: {seas_word} (strength {d.seasonality_strength:.2f})"
829
+ if d.seasonality_strength > 0.3 and best_lag:
830
+ title += f", ≈{best_lag}-period cycle"
831
+ axb.set_title(title, fontweight="bold")
832
+ axb.grid(True, axis="y", alpha=0.25)
833
+
834
+ # ---- C: holdout share. ----
835
+ axc = fig.add_subplot(gs[1, 1])
836
+ h = d.holdout_pct
837
+ in_band = 0.03 <= h <= 0.35
838
+ bar_color = "#059669" if in_band else "#d97706"
839
+ axc.barh([0], [100 * h], color=bar_color, height=0.5, label="treated")
840
+ axc.barh([0], [100 * (1 - h)], left=[100 * h], color="#e5e7eb", height=0.5,
841
+ label="control / donors")
842
+ axc.axvspan(3, 35, color="#bbf7d0", alpha=0.35) # healthy band
843
+ axc.set_xlim(0, 100)
844
+ axc.set_yticks([])
845
+ axc.set_xlabel("% of total volume")
846
+ axc.set_title(f"Holdout: treated = {100*h:.1f}% of volume "
847
+ f"({'healthy' if in_band else 'check'})", fontweight="bold")
848
+ axc.annotate(f"{100*h:.1f}%", (100 * h / 2, 0), ha="center", va="center",
849
+ color="white", fontweight="bold")
850
+ axc.legend(loc="lower right", fontsize=8, framealpha=0.95)
851
+ axc.annotate("healthy 3–35%", (19, 0.32), ha="center", color="#15803d", fontsize=8)
852
+
853
+ # ---- Warnings / verdict banner across the bottom. ----
854
+ warns = list(d.warnings)
855
+ if warns:
856
+ txt = "⚠ Guardrail warnings:\n" + "\n".join(f" • {w}" for w in warns)
857
+ box = dict(boxstyle="round,pad=0.5", fc="#fef3c7", ec="#d97706")
858
+ else:
859
+ txt = "✓ No guardrail warnings — design looks clean."
860
+ box = dict(boxstyle="round,pad=0.5", fc="#dcfce7", ec="#059669")
861
+ fig.text(0.012, -0.02, txt, ha="left", va="top", fontsize=9, bbox=box, wrap=True)
862
+
863
+ fig.suptitle(f"panelkit · guardrails — confidence {d.confidence:.0f}/100",
864
+ fontsize=14, fontweight="bold", x=0.012, ha="left")
651
865
  if path:
652
866
  fig.savefig(path, dpi=150, bbox_inches="tight")
653
867
  return fig
File without changes
File without changes
File without changes
File without changes
File without changes