temporalcv 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53) hide show
  1. temporalcv/__init__.py +473 -0
  2. temporalcv/bagging/__init__.py +215 -0
  3. temporalcv/bagging/base.py +397 -0
  4. temporalcv/bagging/strategies/__init__.py +17 -0
  5. temporalcv/bagging/strategies/block_bootstrap.py +150 -0
  6. temporalcv/bagging/strategies/feature_bagging.py +171 -0
  7. temporalcv/bagging/strategies/residual_bootstrap.py +311 -0
  8. temporalcv/bagging/strategies/stationary_bootstrap.py +184 -0
  9. temporalcv/benchmarks/__init__.py +99 -0
  10. temporalcv/benchmarks/base.py +412 -0
  11. temporalcv/benchmarks/fred.py +233 -0
  12. temporalcv/benchmarks/gluonts.py +208 -0
  13. temporalcv/benchmarks/m5.py +165 -0
  14. temporalcv/benchmarks/monash.py +258 -0
  15. temporalcv/changepoint.py +614 -0
  16. temporalcv/compare/__init__.py +117 -0
  17. temporalcv/compare/adapters/__init__.py +49 -0
  18. temporalcv/compare/adapters/multi_series.py +302 -0
  19. temporalcv/compare/adapters/statsforecast_adapter.py +216 -0
  20. temporalcv/compare/base.py +571 -0
  21. temporalcv/compare/docs.py +370 -0
  22. temporalcv/compare/results.py +366 -0
  23. temporalcv/compare/runner.py +499 -0
  24. temporalcv/conformal.py +1561 -0
  25. temporalcv/cv.py +1945 -0
  26. temporalcv/cv_financial.py +703 -0
  27. temporalcv/diagnostics/__init__.py +23 -0
  28. temporalcv/diagnostics/influence.py +240 -0
  29. temporalcv/diagnostics/sensitivity.py +250 -0
  30. temporalcv/gates.py +1790 -0
  31. temporalcv/guardrails.py +637 -0
  32. temporalcv/inference/__init__.py +41 -0
  33. temporalcv/inference/block_bootstrap_ci.py +472 -0
  34. temporalcv/inference/wild_bootstrap.py +334 -0
  35. temporalcv/lag_selection.py +407 -0
  36. temporalcv/metrics/__init__.py +152 -0
  37. temporalcv/metrics/asymmetric.py +537 -0
  38. temporalcv/metrics/core.py +587 -0
  39. temporalcv/metrics/event.py +718 -0
  40. temporalcv/metrics/financial.py +607 -0
  41. temporalcv/metrics/quantile.py +485 -0
  42. temporalcv/metrics/volatility_weighted.py +665 -0
  43. temporalcv/persistence.py +722 -0
  44. temporalcv/py.typed +2 -0
  45. temporalcv/regimes.py +578 -0
  46. temporalcv/stationarity.py +478 -0
  47. temporalcv/statistical_tests.py +3217 -0
  48. temporalcv/validators/__init__.py +55 -0
  49. temporalcv/validators/theoretical.py +503 -0
  50. temporalcv-1.0.0.dist-info/METADATA +408 -0
  51. temporalcv-1.0.0.dist-info/RECORD +53 -0
  52. temporalcv-1.0.0.dist-info/WHEEL +4 -0
  53. temporalcv-1.0.0.dist-info/licenses/LICENSE +21 -0
temporalcv/__init__.py ADDED
@@ -0,0 +1,473 @@
1
+ """
2
+ temporalcv: Temporal cross-validation with leakage protection for time-series ML.
3
+
4
+ This package provides rigorous validation tools for time-series forecasting,
5
+ including:
6
+
7
+ - Validation gates for detecting data leakage
8
+ - Walk-forward cross-validation with gap enforcement
9
+ - Statistical tests (Diebold-Mariano, Pesaran-Timmermann)
10
+ - High-persistence series handling (move-conditional skill score, MC-SS; move thresholds)
11
+ - Regime classification (volatility, direction)
12
+ - Conformal prediction intervals with coverage guarantees
13
+ - Time-series-aware bagging with bootstrap strategies
14
+
15
+ Example
16
+ -------
17
+ >>> from temporalcv import run_gates, WalkForwardCV
18
+ >>> from temporalcv.gates import gate_signal_verification
19
+ >>>
20
+ >>> # Signal verification: does model have predictive power?
21
+ >>> result = gate_signal_verification(model=my_model, X=X, y=y, random_state=42)
22
+ >>> if result.status.name == "HALT":
23
+ ... # Model has signal - could be legitimate or leakage
24
+ ... print("Model has signal - investigate source")
25
+ >>>
26
+ >>> # Aggregate multiple gates
27
+ >>> report = run_gates([result])
28
+ >>> if report.status == "HALT":
29
+ ... print(f"Investigation needed: {report.failures}")
30
+
31
+ >>> # Move-conditional metrics for high-persistence series
32
+ >>> from temporalcv import compute_move_threshold, compute_move_conditional_metrics
33
+ >>> threshold = compute_move_threshold(train_actuals) # From training only!
34
+ >>> mc = compute_move_conditional_metrics(predictions, actuals, threshold=threshold)
35
+ >>> print(f"MC-SS: {mc.skill_score:.3f}")
36
+
37
+ >>> # Conformal prediction intervals
38
+ >>> from temporalcv import SplitConformalPredictor
39
+ >>> conformal = SplitConformalPredictor(alpha=0.05)
40
+ >>> conformal.calibrate(cal_preds, cal_actuals)
41
+ >>> intervals = conformal.predict_interval(test_preds)
42
+ >>> print(f"Coverage: {intervals.coverage(test_actuals):.1%}")
43
+ """
44
+
45
+ from __future__ import annotations
46
+
47
+ __version__ = "1.0.0"
48
+
49
+ # Gates module exports
50
+ from temporalcv.gates import (
51
+ GateStatus,
52
+ GateResult,
53
+ ValidationReport,
54
+ StratifiedValidationReport,
55
+ gate_signal_verification,
56
+ gate_synthetic_ar1,
57
+ gate_suspicious_improvement,
58
+ gate_temporal_boundary,
59
+ gate_residual_diagnostics,
60
+ gate_theoretical_bounds,
61
+ run_gates,
62
+ run_gates_stratified,
63
+ )
64
+
65
+ # Statistical tests exports
66
+ from temporalcv.statistical_tests import (
67
+ DMTestResult,
68
+ PTTestResult,
69
+ GWTestResult,
70
+ CWTestResult,
71
+ MultiModelComparisonResult,
72
+ MultiHorizonResult,
73
+ MultiModelHorizonResult,
74
+ EncompassingTestResult,
75
+ BidirectionalEncompassingResult,
76
+ RealityCheckResult,
77
+ SPATestResult,
78
+ dm_test,
79
+ pt_test,
80
+ gw_test,
81
+ cw_test,
82
+ compare_multiple_models,
83
+ compare_horizons,
84
+ compare_models_horizons,
85
+ forecast_encompassing_test,
86
+ forecast_encompassing_bidirectional,
87
+ reality_check_test,
88
+ spa_test,
89
+ compute_hac_variance,
90
+ )
91
+
92
+ # Cross-validation exports
93
+ from temporalcv.cv import (
94
+ SplitInfo,
95
+ SplitResult,
96
+ WalkForwardResults,
97
+ NestedCVResult,
98
+ WalkForwardCV,
99
+ CrossFitCV,
100
+ NestedWalkForwardCV,
101
+ walk_forward_evaluate,
102
+ )
103
+
104
+ # Regime classification exports
105
+ from temporalcv.regimes import (
106
+ classify_volatility_regime,
107
+ classify_direction_regime,
108
+ get_combined_regimes,
109
+ get_regime_counts,
110
+ mask_low_n_regimes,
111
+ StratifiedMetricsResult,
112
+ compute_stratified_metrics,
113
+ )
114
+
115
+ # High-persistence metrics exports
116
+ from temporalcv.persistence import (
117
+ MoveDirection,
118
+ MoveConditionalResult,
119
+ compute_move_threshold,
120
+ classify_moves,
121
+ compute_move_conditional_metrics,
122
+ compute_direction_accuracy,
123
+ compute_move_only_mae,
124
+ compute_persistence_mae,
125
+ )
126
+
127
+ # Core metrics exports
128
+ from temporalcv.metrics import (
129
+ compute_mae,
130
+ compute_mse,
131
+ compute_rmse,
132
+ compute_mape,
133
+ compute_smape,
134
+ compute_bias,
135
+ compute_naive_error,
136
+ compute_mase,
137
+ compute_mrae,
138
+ compute_theils_u,
139
+ compute_forecast_correlation,
140
+ compute_r_squared,
141
+ )
142
+
143
+ # Quantile/interval metrics exports
144
+ from temporalcv.metrics import (
145
+ compute_pinball_loss,
146
+ compute_crps,
147
+ compute_interval_score,
148
+ compute_quantile_coverage,
149
+ compute_winkler_score,
150
+ )
151
+
152
+ # Financial/trading metrics exports
153
+ from temporalcv.metrics import (
154
+ compute_sharpe_ratio,
155
+ compute_max_drawdown,
156
+ compute_cumulative_return,
157
+ compute_information_ratio,
158
+ compute_hit_rate,
159
+ compute_profit_factor,
160
+ compute_calmar_ratio,
161
+ )
162
+
163
+ # Asymmetric loss exports
164
+ from temporalcv.metrics import (
165
+ compute_linex_loss,
166
+ compute_asymmetric_mape,
167
+ compute_directional_loss,
168
+ compute_squared_log_error,
169
+ compute_huber_loss,
170
+ )
171
+
172
+ # Volatility-weighted metrics exports
173
+ from temporalcv.metrics import (
174
+ VolatilityEstimator,
175
+ RollingVolatility,
176
+ EWMAVolatility,
177
+ compute_local_volatility,
178
+ compute_volatility_normalized_mae,
179
+ compute_volatility_weighted_mae,
180
+ VolatilityStratifiedResult,
181
+ compute_volatility_stratified_metrics,
182
+ )
183
+
184
+ # Conformal prediction exports
185
+ from temporalcv.conformal import (
186
+ PredictionInterval,
187
+ SplitConformalPredictor,
188
+ AdaptiveConformalPredictor,
189
+ BellmanConformalPredictor,
190
+ BootstrapUncertainty,
191
+ evaluate_interval_quality,
192
+ walk_forward_conformal,
193
+ CoverageDiagnostics,
194
+ compute_coverage_diagnostics,
195
+ )
196
+
197
+ # Bagging exports
198
+ from temporalcv.bagging import (
199
+ BootstrapStrategy,
200
+ TimeSeriesBagger,
201
+ MovingBlockBootstrap,
202
+ StationaryBootstrap,
203
+ FeatureBagging,
204
+ ResidualBootstrap,
205
+ create_block_bagger,
206
+ create_stationary_bagger,
207
+ create_feature_bagger,
208
+ create_residual_bagger,
209
+ )
210
+
211
+ # Diagnostics exports
212
+ from temporalcv.diagnostics import (
213
+ InfluenceDiagnostic,
214
+ compute_dm_influence,
215
+ GapSensitivityResult,
216
+ gap_sensitivity_analysis,
217
+ )
218
+
219
+ # Inference exports
220
+ from temporalcv.inference import (
221
+ WildBootstrapResult,
222
+ wild_cluster_bootstrap,
223
+ )
224
+
225
+ # Validators exports (theoretical bounds)
226
+ from temporalcv.validators import (
227
+ theoretical_ar1_mse_bound,
228
+ theoretical_ar1_mae_bound,
229
+ theoretical_ar2_mse_bound,
230
+ check_against_ar1_bounds,
231
+ generate_ar1_series,
232
+ generate_ar2_series,
233
+ )
234
+
235
+ # Guardrails exports (unified validation)
236
+ from temporalcv.guardrails import (
237
+ GuardrailResult,
238
+ check_suspicious_improvement,
239
+ check_minimum_sample_size,
240
+ check_stratified_sample_size,
241
+ check_forecast_horizon_consistency,
242
+ check_residual_autocorrelation,
243
+ run_all_guardrails,
244
+ )
245
+
246
+ # Stationarity tests exports
247
+ from temporalcv.stationarity import (
248
+ StationarityTestResult,
249
+ StationarityConclusion,
250
+ JointStationarityResult,
251
+ adf_test,
252
+ kpss_test,
253
+ pp_test,
254
+ check_stationarity,
255
+ difference_until_stationary,
256
+ )
257
+
258
+ # Lag selection exports
259
+ from temporalcv.lag_selection import (
260
+ LagSelectionResult,
261
+ select_lag_pacf,
262
+ select_lag_aic,
263
+ select_lag_bic,
264
+ auto_select_lag,
265
+ suggest_cv_gap,
266
+ )
267
+
268
+ # Changepoint detection exports
269
+ from temporalcv.changepoint import (
270
+ Changepoint,
271
+ ChangepointResult,
272
+ detect_changepoints,
273
+ detect_changepoints_variance,
274
+ detect_changepoints_pelt,
275
+ classify_regimes_from_changepoints,
276
+ create_regime_indicators,
277
+ get_segment_boundaries,
278
+ )
279
+
280
+ # Financial CV exports
281
+ from temporalcv.cv_financial import (
282
+ PurgedSplit,
283
+ PurgedKFold,
284
+ CombinatorialPurgedCV,
285
+ PurgedWalkForward,
286
+ compute_label_overlap,
287
+ estimate_purge_gap,
288
+ )
289
+
290
+ __all__ = [
291
+ "__version__",
292
+ # Gates
293
+ "GateStatus",
294
+ "GateResult",
295
+ "ValidationReport",
296
+ "StratifiedValidationReport",
297
+ "gate_signal_verification",
298
+ "gate_synthetic_ar1",
299
+ "gate_suspicious_improvement",
300
+ "gate_temporal_boundary",
301
+ "gate_residual_diagnostics",
302
+ "gate_theoretical_bounds",
303
+ "run_gates",
304
+ "run_gates_stratified",
305
+ # Statistical tests
306
+ "DMTestResult",
307
+ "PTTestResult",
308
+ "GWTestResult",
309
+ "CWTestResult",
310
+ "EncompassingTestResult",
311
+ "BidirectionalEncompassingResult",
312
+ "RealityCheckResult",
313
+ "SPATestResult",
314
+ "MultiHorizonResult",
315
+ "MultiModelHorizonResult",
316
+ "MultiModelComparisonResult",
317
+ "dm_test",
318
+ "pt_test",
319
+ "gw_test",
320
+ "cw_test",
321
+ "compare_multiple_models",
322
+ "compare_horizons",
323
+ "compare_models_horizons",
324
+ "forecast_encompassing_test",
325
+ "forecast_encompassing_bidirectional",
326
+ "reality_check_test",
327
+ "spa_test",
328
+ "compute_hac_variance",
329
+ # Cross-validation
330
+ "SplitInfo",
331
+ "SplitResult",
332
+ "WalkForwardResults",
333
+ "NestedCVResult",
334
+ "WalkForwardCV",
335
+ "CrossFitCV",
336
+ "NestedWalkForwardCV",
337
+ "walk_forward_evaluate",
338
+ # Regime classification
339
+ "classify_volatility_regime",
340
+ "classify_direction_regime",
341
+ "get_combined_regimes",
342
+ "get_regime_counts",
343
+ "mask_low_n_regimes",
344
+ "StratifiedMetricsResult",
345
+ "compute_stratified_metrics",
346
+ # High-persistence metrics
347
+ "MoveDirection",
348
+ "MoveConditionalResult",
349
+ "compute_move_threshold",
350
+ "classify_moves",
351
+ "compute_move_conditional_metrics",
352
+ "compute_direction_accuracy",
353
+ "compute_move_only_mae",
354
+ "compute_persistence_mae",
355
+ # Core metrics
356
+ "compute_mae",
357
+ "compute_mse",
358
+ "compute_rmse",
359
+ "compute_mape",
360
+ "compute_smape",
361
+ "compute_bias",
362
+ "compute_naive_error",
363
+ "compute_mase",
364
+ "compute_mrae",
365
+ "compute_theils_u",
366
+ "compute_forecast_correlation",
367
+ "compute_r_squared",
368
+ # Quantile/interval metrics
369
+ "compute_pinball_loss",
370
+ "compute_crps",
371
+ "compute_interval_score",
372
+ "compute_quantile_coverage",
373
+ "compute_winkler_score",
374
+ # Financial/trading metrics
375
+ "compute_sharpe_ratio",
376
+ "compute_max_drawdown",
377
+ "compute_cumulative_return",
378
+ "compute_information_ratio",
379
+ "compute_hit_rate",
380
+ "compute_profit_factor",
381
+ "compute_calmar_ratio",
382
+ # Asymmetric loss functions
383
+ "compute_linex_loss",
384
+ "compute_asymmetric_mape",
385
+ "compute_directional_loss",
386
+ "compute_squared_log_error",
387
+ "compute_huber_loss",
388
+ # Volatility-weighted metrics
389
+ "VolatilityEstimator",
390
+ "RollingVolatility",
391
+ "EWMAVolatility",
392
+ "compute_local_volatility",
393
+ "compute_volatility_normalized_mae",
394
+ "compute_volatility_weighted_mae",
395
+ "VolatilityStratifiedResult",
396
+ "compute_volatility_stratified_metrics",
397
+ # Conformal prediction
398
+ "PredictionInterval",
399
+ "SplitConformalPredictor",
400
+ "AdaptiveConformalPredictor",
401
+ "BellmanConformalPredictor",
402
+ "BootstrapUncertainty",
403
+ "evaluate_interval_quality",
404
+ "walk_forward_conformal",
405
+ "CoverageDiagnostics",
406
+ "compute_coverage_diagnostics",
407
+ # Bagging
408
+ "BootstrapStrategy",
409
+ "TimeSeriesBagger",
410
+ "MovingBlockBootstrap",
411
+ "StationaryBootstrap",
412
+ "FeatureBagging",
413
+ "ResidualBootstrap",
414
+ "create_block_bagger",
415
+ "create_stationary_bagger",
416
+ "create_feature_bagger",
417
+ "create_residual_bagger",
418
+ # Diagnostics
419
+ "InfluenceDiagnostic",
420
+ "compute_dm_influence",
421
+ "GapSensitivityResult",
422
+ "gap_sensitivity_analysis",
423
+ # Inference
424
+ "WildBootstrapResult",
425
+ "wild_cluster_bootstrap",
426
+ # Validators (theoretical bounds)
427
+ "theoretical_ar1_mse_bound",
428
+ "theoretical_ar1_mae_bound",
429
+ "theoretical_ar2_mse_bound",
430
+ "check_against_ar1_bounds",
431
+ "generate_ar1_series",
432
+ "generate_ar2_series",
433
+ # Guardrails (unified validation)
434
+ "GuardrailResult",
435
+ "check_suspicious_improvement",
436
+ "check_minimum_sample_size",
437
+ "check_stratified_sample_size",
438
+ "check_forecast_horizon_consistency",
439
+ "check_residual_autocorrelation",
440
+ "run_all_guardrails",
441
+ # Stationarity tests
442
+ "StationarityTestResult",
443
+ "StationarityConclusion",
444
+ "JointStationarityResult",
445
+ "adf_test",
446
+ "kpss_test",
447
+ "pp_test",
448
+ "check_stationarity",
449
+ "difference_until_stationary",
450
+ # Lag selection
451
+ "LagSelectionResult",
452
+ "select_lag_pacf",
453
+ "select_lag_aic",
454
+ "select_lag_bic",
455
+ "auto_select_lag",
456
+ "suggest_cv_gap",
457
+ # Changepoint detection
458
+ "Changepoint",
459
+ "ChangepointResult",
460
+ "detect_changepoints",
461
+ "detect_changepoints_variance",
462
+ "detect_changepoints_pelt",
463
+ "classify_regimes_from_changepoints",
464
+ "create_regime_indicators",
465
+ "get_segment_boundaries",
466
+ # Financial CV
467
+ "PurgedSplit",
468
+ "PurgedKFold",
469
+ "CombinatorialPurgedCV",
470
+ "PurgedWalkForward",
471
+ "compute_label_overlap",
472
+ "estimate_purge_gap",
473
+ ]
@@ -0,0 +1,215 @@
1
+ """
2
+ Time Series Bagging Framework.
3
+
4
+ Generic, model-agnostic bagging for time series with methodologically
5
+ correct bootstrap strategies from the literature.
6
+
7
+ Available Strategies
8
+ --------------------
9
+ - MovingBlockBootstrap: Block bootstrap preserving local autocorrelation
10
+ - StationaryBootstrap: Geometric block lengths for stationarity
11
+ - FeatureBagging: Random subspace method (feature bootstrap)
+ - ResidualBootstrap: Residual-based bootstrap strategy
12
+
13
+ Factory Functions
14
+ -----------------
15
+ - create_block_bagger: Create bagger with Moving Block Bootstrap
16
+ - create_stationary_bagger: Create bagger with Stationary Bootstrap
17
+ - create_feature_bagger: Create bagger with Feature Bagging
+ - create_residual_bagger: Create bagger with Residual Bootstrap
18
+
19
+ Example
20
+ -------
21
+ >>> from sklearn.linear_model import Ridge
22
+ >>> from temporalcv.bagging import create_block_bagger
23
+ >>>
24
+ >>> # Create bagged Ridge model
25
+ >>> bagged_ridge = create_block_bagger(Ridge(), n_estimators=20)
26
+ >>> bagged_ridge.fit(X_train, y_train)
27
+ >>> predictions = bagged_ridge.predict(X_test)
28
+ >>> mean, lower, upper = bagged_ridge.predict_interval(X_test)
29
+
30
+ References
31
+ ----------
32
+ - Kunsch (1989). "The Jackknife and the Bootstrap for General Stationary Observations"
33
+ - Politis & Romano (1994). "The Stationary Bootstrap"
34
+ - Ho (1998). "The Random Subspace Method"
35
+ - Bergmeir, Hyndman & Benitez (2016). "Bagging Exponential Smoothing"
36
+ """
37
+
38
+ from typing import Optional
39
+
40
+ from temporalcv.bagging.base import (
41
+ SupportsPredict,
42
+ BootstrapStrategy,
43
+ TimeSeriesBagger,
44
+ )
45
+ from temporalcv.bagging.strategies import (
46
+ MovingBlockBootstrap,
47
+ StationaryBootstrap,
48
+ FeatureBagging,
49
+ ResidualBootstrap,
50
+ create_residual_bagger,
51
+ )
52
+
53
+
54
def create_block_bagger(
    base_model: SupportsPredict,
    n_estimators: int = 20,
    block_length: Optional[int] = None,
    aggregation: str = "mean",
    random_state: Optional[int] = None,
) -> TimeSeriesBagger:
    """
    Build a ``TimeSeriesBagger`` that resamples with a Moving Block Bootstrap.

    Convenience factory: pairs ``base_model`` with a
    ``MovingBlockBootstrap(block_length=block_length)`` strategy and forwards
    the remaining options to ``TimeSeriesBagger`` unchanged.

    Parameters
    ----------
    base_model : SupportsPredict
        Estimator to ensemble (e.g., Ridge, ElasticNet).
    n_estimators : int, default=20
        How many bootstrap estimators to use.
    block_length : int or None, default=None
        Length of each block. ``None`` means auto-compute as n^(1/3).
    aggregation : {"mean", "median"}, default="mean"
        Rule for combining the individual predictions.
    random_state : int or None, default=None
        Seed for reproducibility; ``None`` gives non-deterministic results.

    Returns
    -------
    TimeSeriesBagger
        Bagged model ready for fit/predict.

    Examples
    --------
    >>> from sklearn.linear_model import Ridge
    >>> bagger = create_block_bagger(Ridge(), n_estimators=50)
    >>> bagger.fit(X_train, y_train)

    See Also
    --------
    create_stationary_bagger : Alternative with geometric block lengths.
    create_feature_bagger : Feature subspace method instead of row bootstrap.
    MovingBlockBootstrap : The underlying bootstrap strategy.
    """
    return TimeSeriesBagger(
        base_model,
        MovingBlockBootstrap(block_length=block_length),
        n_estimators=n_estimators,
        aggregation=aggregation,  # type: ignore[arg-type]
        random_state=random_state,
    )
102
+
103
+
104
def create_stationary_bagger(
    base_model: SupportsPredict,
    n_estimators: int = 20,
    expected_block_length: Optional[float] = None,
    aggregation: str = "mean",
    random_state: Optional[int] = None,
) -> TimeSeriesBagger:
    """
    Build a ``TimeSeriesBagger`` that resamples with a Stationary Bootstrap.

    Convenience factory: pairs ``base_model`` with a
    ``StationaryBootstrap(expected_block_length=expected_block_length)``
    strategy and forwards the remaining options to ``TimeSeriesBagger``
    unchanged.

    Parameters
    ----------
    base_model : SupportsPredict
        Estimator to ensemble (e.g., Ridge, ElasticNet).
    n_estimators : int, default=20
        How many bootstrap estimators to use.
    expected_block_length : float or None, default=None
        Expected block length. ``None`` means auto-compute as n^(1/3).
    aggregation : {"mean", "median"}, default="mean"
        Rule for combining the individual predictions.
    random_state : int or None, default=None
        Seed for reproducibility; ``None`` gives non-deterministic results.

    Returns
    -------
    TimeSeriesBagger
        Bagged model ready for fit/predict.

    Examples
    --------
    >>> from sklearn.linear_model import ElasticNet
    >>> bagger = create_stationary_bagger(ElasticNet(), n_estimators=50)

    See Also
    --------
    create_block_bagger : Fixed block lengths (simpler).
    StationaryBootstrap : The underlying bootstrap strategy.
    """
    return TimeSeriesBagger(
        base_model,
        StationaryBootstrap(expected_block_length=expected_block_length),
        n_estimators=n_estimators,
        aggregation=aggregation,  # type: ignore[arg-type]
        random_state=random_state,
    )
150
+
151
+
152
def create_feature_bagger(
    base_model: SupportsPredict,
    n_estimators: int = 20,
    max_features: float = 0.7,
    aggregation: str = "mean",
    random_state: Optional[int] = None,
) -> TimeSeriesBagger:
    """
    Build a ``TimeSeriesBagger`` that uses Feature Bagging (Random Subspace).

    Convenience factory: pairs ``base_model`` with a
    ``FeatureBagging(max_features=max_features)`` strategy and forwards the
    remaining options to ``TimeSeriesBagger`` unchanged.

    Parameters
    ----------
    base_model : SupportsPredict
        Estimator to ensemble (e.g., Ridge, ElasticNet).
    n_estimators : int, default=20
        How many bootstrap estimators to use.
    max_features : float, default=0.7
        Fraction of features given to each estimator (0.0-1.0).
    aggregation : {"mean", "median"}, default="mean"
        Rule for combining the individual predictions.
    random_state : int or None, default=None
        Seed for reproducibility; ``None`` gives non-deterministic results.

    Returns
    -------
    TimeSeriesBagger
        Bagged model ready for fit/predict.

    Examples
    --------
    >>> from sklearn.linear_model import Ridge
    >>> bagger = create_feature_bagger(Ridge(), max_features=0.6)

    See Also
    --------
    create_block_bagger : Bootstrap rows instead of features.
    FeatureBagging : The underlying bootstrap strategy.
    """
    return TimeSeriesBagger(
        base_model,
        FeatureBagging(max_features=max_features),
        n_estimators=n_estimators,
        aggregation=aggregation,  # type: ignore[arg-type]
        random_state=random_state,
    )
198
+
199
+
200
+ __all__ = [
201
+ # Core classes
202
+ "SupportsPredict",
203
+ "BootstrapStrategy",
204
+ "TimeSeriesBagger",
205
+ # Strategies
206
+ "MovingBlockBootstrap",
207
+ "StationaryBootstrap",
208
+ "FeatureBagging",
209
+ "ResidualBootstrap",
210
+ # Factory functions
211
+ "create_block_bagger",
212
+ "create_stationary_bagger",
213
+ "create_feature_bagger",
214
+ "create_residual_bagger",
215
+ ]