ml4t_diagnostic-0.1.0a1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ml4t/diagnostic/AGENT.md +25 -0
- ml4t/diagnostic/__init__.py +166 -0
- ml4t/diagnostic/backends/__init__.py +10 -0
- ml4t/diagnostic/backends/adapter.py +192 -0
- ml4t/diagnostic/backends/polars_backend.py +899 -0
- ml4t/diagnostic/caching/__init__.py +40 -0
- ml4t/diagnostic/caching/cache.py +331 -0
- ml4t/diagnostic/caching/decorators.py +131 -0
- ml4t/diagnostic/caching/smart_cache.py +339 -0
- ml4t/diagnostic/config/AGENT.md +24 -0
- ml4t/diagnostic/config/README.md +267 -0
- ml4t/diagnostic/config/__init__.py +219 -0
- ml4t/diagnostic/config/barrier_config.py +277 -0
- ml4t/diagnostic/config/base.py +301 -0
- ml4t/diagnostic/config/event_config.py +148 -0
- ml4t/diagnostic/config/feature_config.py +404 -0
- ml4t/diagnostic/config/multi_signal_config.py +55 -0
- ml4t/diagnostic/config/portfolio_config.py +215 -0
- ml4t/diagnostic/config/report_config.py +391 -0
- ml4t/diagnostic/config/sharpe_config.py +202 -0
- ml4t/diagnostic/config/signal_config.py +206 -0
- ml4t/diagnostic/config/trade_analysis_config.py +310 -0
- ml4t/diagnostic/config/validation.py +279 -0
- ml4t/diagnostic/core/__init__.py +29 -0
- ml4t/diagnostic/core/numba_utils.py +315 -0
- ml4t/diagnostic/core/purging.py +372 -0
- ml4t/diagnostic/core/sampling.py +471 -0
- ml4t/diagnostic/errors/__init__.py +205 -0
- ml4t/diagnostic/evaluation/AGENT.md +26 -0
- ml4t/diagnostic/evaluation/__init__.py +437 -0
- ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
- ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
- ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
- ml4t/diagnostic/evaluation/dashboard.py +715 -0
- ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
- ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
- ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
- ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
- ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
- ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
- ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
- ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
- ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
- ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
- ml4t/diagnostic/evaluation/event_analysis.py +647 -0
- ml4t/diagnostic/evaluation/excursion.py +390 -0
- ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
- ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
- ml4t/diagnostic/evaluation/framework.py +935 -0
- ml4t/diagnostic/evaluation/metric_registry.py +255 -0
- ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
- ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
- ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
- ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
- ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
- ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
- ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
- ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
- ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
- ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
- ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
- ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
- ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
- ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
- ml4t/diagnostic/evaluation/multi_signal.py +550 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
- ml4t/diagnostic/evaluation/report_generation.py +824 -0
- ml4t/diagnostic/evaluation/signal_selector.py +452 -0
- ml4t/diagnostic/evaluation/stat_registry.py +139 -0
- ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
- ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
- ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
- ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
- ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
- ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
- ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
- ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
- ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
- ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
- ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
- ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
- ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
- ml4t/diagnostic/evaluation/stats/moments.py +164 -0
- ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
- ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
- ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
- ml4t/diagnostic/evaluation/themes.py +330 -0
- ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
- ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
- ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
- ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
- ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
- ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
- ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
- ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
- ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
- ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
- ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
- ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
- ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
- ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
- ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
- ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
- ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
- ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
- ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
- ml4t/diagnostic/evaluation/validated_cv.py +535 -0
- ml4t/diagnostic/evaluation/visualization.py +1050 -0
- ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
- ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
- ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
- ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
- ml4t/diagnostic/integration/__init__.py +48 -0
- ml4t/diagnostic/integration/backtest_contract.py +671 -0
- ml4t/diagnostic/integration/data_contract.py +316 -0
- ml4t/diagnostic/integration/engineer_contract.py +226 -0
- ml4t/diagnostic/logging/__init__.py +77 -0
- ml4t/diagnostic/logging/logger.py +245 -0
- ml4t/diagnostic/logging/performance.py +234 -0
- ml4t/diagnostic/logging/progress.py +234 -0
- ml4t/diagnostic/logging/wandb.py +412 -0
- ml4t/diagnostic/metrics/__init__.py +9 -0
- ml4t/diagnostic/metrics/percentiles.py +128 -0
- ml4t/diagnostic/py.typed +1 -0
- ml4t/diagnostic/reporting/__init__.py +43 -0
- ml4t/diagnostic/reporting/base.py +130 -0
- ml4t/diagnostic/reporting/html_renderer.py +275 -0
- ml4t/diagnostic/reporting/json_renderer.py +51 -0
- ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
- ml4t/diagnostic/results/AGENT.md +24 -0
- ml4t/diagnostic/results/__init__.py +105 -0
- ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
- ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
- ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
- ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
- ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
- ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
- ml4t/diagnostic/results/barrier_results/validation.py +38 -0
- ml4t/diagnostic/results/base.py +177 -0
- ml4t/diagnostic/results/event_results.py +349 -0
- ml4t/diagnostic/results/feature_results.py +787 -0
- ml4t/diagnostic/results/multi_signal_results.py +431 -0
- ml4t/diagnostic/results/portfolio_results.py +281 -0
- ml4t/diagnostic/results/sharpe_results.py +448 -0
- ml4t/diagnostic/results/signal_results/__init__.py +74 -0
- ml4t/diagnostic/results/signal_results/ic.py +581 -0
- ml4t/diagnostic/results/signal_results/irtc.py +110 -0
- ml4t/diagnostic/results/signal_results/quantile.py +392 -0
- ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
- ml4t/diagnostic/results/signal_results/turnover.py +213 -0
- ml4t/diagnostic/results/signal_results/validation.py +147 -0
- ml4t/diagnostic/signal/AGENT.md +17 -0
- ml4t/diagnostic/signal/__init__.py +69 -0
- ml4t/diagnostic/signal/_report.py +152 -0
- ml4t/diagnostic/signal/_utils.py +261 -0
- ml4t/diagnostic/signal/core.py +275 -0
- ml4t/diagnostic/signal/quantile.py +148 -0
- ml4t/diagnostic/signal/result.py +214 -0
- ml4t/diagnostic/signal/signal_ic.py +129 -0
- ml4t/diagnostic/signal/turnover.py +182 -0
- ml4t/diagnostic/splitters/AGENT.md +19 -0
- ml4t/diagnostic/splitters/__init__.py +36 -0
- ml4t/diagnostic/splitters/base.py +501 -0
- ml4t/diagnostic/splitters/calendar.py +421 -0
- ml4t/diagnostic/splitters/calendar_config.py +91 -0
- ml4t/diagnostic/splitters/combinatorial.py +1064 -0
- ml4t/diagnostic/splitters/config.py +322 -0
- ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
- ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
- ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
- ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
- ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
- ml4t/diagnostic/splitters/group_isolation.py +329 -0
- ml4t/diagnostic/splitters/persistence.py +316 -0
- ml4t/diagnostic/splitters/utils.py +207 -0
- ml4t/diagnostic/splitters/walk_forward.py +757 -0
- ml4t/diagnostic/utils/__init__.py +42 -0
- ml4t/diagnostic/utils/config.py +542 -0
- ml4t/diagnostic/utils/dependencies.py +318 -0
- ml4t/diagnostic/utils/sessions.py +127 -0
- ml4t/diagnostic/validation/__init__.py +54 -0
- ml4t/diagnostic/validation/dataframe.py +274 -0
- ml4t/diagnostic/validation/returns.py +280 -0
- ml4t/diagnostic/validation/timeseries.py +299 -0
- ml4t/diagnostic/visualization/AGENT.md +19 -0
- ml4t/diagnostic/visualization/__init__.py +223 -0
- ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
- ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
- ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
- ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
- ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
- ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
- ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
- ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
- ml4t/diagnostic/visualization/barrier_plots.py +782 -0
- ml4t/diagnostic/visualization/core.py +1060 -0
- ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
- ml4t/diagnostic/visualization/dashboards/base.py +582 -0
- ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
- ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
- ml4t/diagnostic/visualization/dashboards.py +43 -0
- ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
- ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
- ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
- ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
- ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
- ml4t/diagnostic/visualization/feature_plots.py +888 -0
- ml4t/diagnostic/visualization/interaction_plots.py +618 -0
- ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
- ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
- ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
- ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
- ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
- ml4t/diagnostic/visualization/report_generation.py +1343 -0
- ml4t/diagnostic/visualization/signal/__init__.py +103 -0
- ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
- ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
- ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
- ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
- ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
- ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
- ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
- ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
- ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
- ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
- ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
- ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
ml4t/diagnostic/splitters/config.py
@@ -0,0 +1,322 @@
"""Configuration classes for cross-validation splitters.

This module provides Pydantic-based configuration for all CV splitters,
enabling reproducible, serializable, and validated split strategies.

Integration with qdata
----------------------
Session-aware splitting consumes `session_date` column from qdata library:

    from ml4t.data.sessions import SessionAssigner
    assigner = SessionAssigner.from_exchange('CME')
    df_with_sessions = assigner.assign_sessions(df)

    config = PurgedWalkForwardConfig(
        n_splits=5,
        test_size=4,  # 4 sessions
        align_to_sessions=True
    )
    cv = PurgedWalkForwardCV.from_config(config)
    for train, test in cv.split(df_with_sessions):
        # Fold boundaries aligned to session boundaries
        pass

Examples
--------
>>> # Parameter-based initialization (backward compatible)
>>> cv = PurgedWalkForwardCV(n_splits=5, test_size=100)
>>>
>>> # Config-based initialization
>>> config = PurgedWalkForwardConfig(n_splits=5, test_size=100)
>>> cv = PurgedWalkForwardCV.from_config(config)
>>>
>>> # Serialize config for reproducibility
>>> config.to_json("cv_config.json")
>>> loaded_config = PurgedWalkForwardConfig.from_json("cv_config.json")
>>> cv = PurgedWalkForwardCV.from_config(loaded_config)
"""

from __future__ import annotations

from typing import Any

from pydantic import Field, field_validator, model_validator

from ml4t.diagnostic.config.base import BaseConfig


class SplitterConfig(BaseConfig):
    """Base configuration for all cross-validation splitters.

    All splitter configs inherit from this class to ensure consistent
    serialization, validation, and reproducibility.

    Attributes
    ----------
    n_splits : int
        Number of cross-validation folds.
    label_horizon : int
        Number of periods ahead that labels look.
        Used for purging and embargo calculations.
    embargo_td : int | None
        Embargo buffer to prevent serial correlation leakage.
        If None, no embargo is applied.
    align_to_sessions : bool
        If True, fold boundaries are aligned to trading session boundaries.
        Requires 'session_date' column in data (from ml4t.data.sessions.SessionAssigner).
    session_col : str
        Column name containing session identifiers.
        Default: 'session_date' (standard qdata column name).
    isolate_groups : bool
        If True, ensures no overlap between train/test group identifiers.
        Useful for multi-asset validation to prevent data leakage.
    """

    n_splits: int = Field(5, gt=0, description="Number of cross-validation folds")
    label_horizon: Any = Field(
        0,
        description="Number of periods ahead that labels look (for purging/embargo). Can be int or pd.Timedelta.",
    )
    embargo_td: Any = Field(
        None,
        description="Embargo buffer in periods (prevents serial correlation leakage). Can be int, pd.Timedelta, or None.",
    )
    align_to_sessions: bool = Field(
        False,
        description=(
            "Align fold boundaries to session boundaries. "
            "Requires 'session_date' column from ml4t.data.sessions.SessionAssigner."
        ),
    )
    session_col: str = Field(
        "session_date",
        description="Column name containing session identifiers (default: qdata standard)",
    )
    timestamp_col: str | None = Field(
        None,
        description=(
            "Column name containing timestamps for time-based sizes. "
            "Required for Polars DataFrames with time-based test_size/train_size. "
            "If None, falls back to pandas DatetimeIndex (backward compatible)."
        ),
    )
    isolate_groups: bool = Field(
        False,
        description=(
            "Prevent same group (symbol/contract) from appearing in both train and test sets"
        ),
    )

    @field_validator("label_horizon")
    @classmethod
    def validate_label_horizon(cls, v: Any) -> Any:
        """Validate label_horizon is either int >= 0 or a timedelta-like object."""
        if isinstance(v, int):
            if v < 0:
                raise ValueError("label_horizon must be greater than or equal to 0")
            return v
        # Allow timedelta-like objects (pd.Timedelta, datetime.timedelta)
        if hasattr(v, "total_seconds"):
            return v
        # Handle ISO 8601 duration strings from JSON serialization
        if isinstance(v, str):
            import pandas as pd

            try:
                return pd.Timedelta(v)
            except Exception as e:
                raise ValueError(f"Could not parse label_horizon string '{v}' as Timedelta: {e}")  # noqa: B904
        raise ValueError(f"label_horizon must be int >= 0 or timedelta-like object, got {type(v)}")

    @field_validator("embargo_td")
    @classmethod
    def validate_embargo_td(cls, v: Any) -> Any:
        """Validate embargo_td is either None, int >= 0, or a timedelta-like object."""
        if v is None:
            return v
        if isinstance(v, int):
            if v < 0:
                raise ValueError("embargo_td must be greater than or equal to 0")
            return v
        # Allow timedelta-like objects (pd.Timedelta, datetime.timedelta)
        if hasattr(v, "total_seconds"):
            return v
        # Handle ISO 8601 duration strings from JSON serialization
        if isinstance(v, str):
            import pandas as pd

            try:
                return pd.Timedelta(v)
            except Exception as e:
                raise ValueError(f"Could not parse embargo_td string '{v}' as Timedelta: {e}")  # noqa: B904
        raise ValueError(
            f"embargo_td must be None, int >= 0, or timedelta-like object, got {type(v)}"
        )


class PurgedWalkForwardConfig(SplitterConfig):
    """Configuration for Purged Walk-Forward Cross-Validation.

    Walk-forward validation is the standard approach for time-series backtesting,
    where the model is trained on historical data and tested on future periods.

    Attributes
    ----------
    test_size : int | float | str | None
        Test set size specification:
        - int: Number of samples (or sessions if align_to_sessions=True)
        - float: Proportion of dataset (0.0 to 1.0)
        - str: Time-based ('4W', '3M') - NOT supported with align_to_sessions=True
        - None: Auto-calculated to maintain equal test set sizes
    train_size : int | float | str | None
        Training set size specification (same format as test_size).
        If None, uses expanding window (all data before test set).
    step_size : int | None
        Step size between consecutive splits:
        - int: Number of samples (or sessions if align_to_sessions=True)
        - None: Defaults to test_size (non-overlapping test sets)
    """

    test_size: int | float | str | None = Field(
        None,
        description=(
            "Test set size: int (samples/sessions), float (proportion), "
            "str (time-based, e.g., '4W'). "
            "Time-based NOT supported with align_to_sessions=True."
        ),
    )
    train_size: int | float | str | None = Field(
        None,
        description=(
            "Train set size: int (samples/sessions), float (proportion), "
            "str (time-based, e.g., '12W'). "
            "None uses expanding window (all data before test)."
        ),
    )
    step_size: int | None = Field(
        None,
        ge=1,
        description=(
            "Step size between splits (int: samples/sessions). None defaults to test_size (non-overlapping)."
        ),
    )
    isolate_groups: bool = Field(
        False,
        description=(
            "Default False for walk-forward (opt-in). Set True for multi-asset validation to prevent group leakage."
        ),
    )

    @field_validator("test_size", "train_size")
    @classmethod
    def validate_size_with_sessions(
        cls, v: int | float | str | None, info
    ) -> int | float | str | None:
        """Validate that time-based sizes are not used with session alignment."""
        if v is None:
            return v

        align_to_sessions = info.data.get("align_to_sessions", False)
        if align_to_sessions and isinstance(v, str):
            raise ValueError(
                f"align_to_sessions=True does not support time-based size specifications. "
                f"Use integer (number of sessions) or float (proportion). Got: {v!r}"
            )
        return v


class CombinatorialPurgedConfig(SplitterConfig):
    """Configuration for Combinatorial Purged Cross-Validation (CPCV).

    CPCV is designed for multi-asset strategies and combating overfitting by
    creating multiple test sets from combinatorial group selections.

    Reference: Bailey & López de Prado (2014)
    "The Deflated Sharpe Ratio: Correcting for Selection Bias, Backtest Overfitting and Non-Normality"

    Attributes
    ----------
    n_groups : int
        Number of groups to partition the timeline into (typically 8-12).
    n_test_groups : int
        Number of groups used for each test set (typically 2-3).
        Total folds = C(n_groups, n_test_groups).
    max_combinations : int | None
        Maximum number of folds to generate.
        If C(n_groups, n_test_groups) > max_combinations, randomly sample.
    contiguous_test_blocks : bool
        If True, only use contiguous test groups (reduces overfitting).
        If False, allow any combination (more folds).
    """

    n_groups: int = Field(
        8, gt=1, description="Number of groups to partition timeline into (typically 8-12)"
    )
    n_test_groups: int = Field(2, gt=0, description="Number of groups per test set (typically 2-3)")
    max_combinations: int | None = Field(
        None,
        gt=0,
        description=(
            "Maximum folds to generate. If C(n_groups, n_test_groups) exceeds this, randomly sample."
        ),
    )
    contiguous_test_blocks: bool = Field(
        False,
        description=(
            "If True, only use contiguous test groups (less overfitting). If False, allow any combination."
        ),
    )
    embargo_pct: float | None = Field(
        None,
        ge=0.0,
        lt=1.0,
        description=(
            "Embargo size as percentage of total samples. "
            "Alternative to embargo_td. Cannot specify both."
        ),
    )
    isolate_groups: bool = Field(
        True,
        description=(
            "Default True for CPCV (opt-out). "
            "CPCV is designed for multi-asset validation, so group isolation is aggressive by default."
        ),
    )
    random_state: int | None = Field(
        None,
        description=(
            "Random seed for sampling when max_combinations limits the number of folds. "
            "Use for reproducible subset selection from C(n_groups, n_test_groups) combinations."
        ),
    )

    @field_validator("n_test_groups")
    @classmethod
    def validate_n_test_groups(cls, v: int, info) -> int:
        """Validate that n_test_groups < n_groups (must leave groups for training)."""
        n_groups = info.data.get("n_groups")
        if n_groups is not None and v >= n_groups:
            raise ValueError(
                f"n_test_groups ({v}) cannot exceed n_groups ({n_groups}). "
                f"Must leave at least one group for training. "
                f"Typically n_test_groups is 2-3 for CPCV."
            )
        return v

    @model_validator(mode="after")
    def validate_embargo_mutual_exclusivity(self) -> CombinatorialPurgedConfig:
        """Validate that embargo_td and embargo_pct are mutually exclusive."""
        if self.embargo_td is not None and self.embargo_pct is not None:
            raise ValueError(
                "Cannot specify both 'embargo_td' and 'embargo_pct'. "
                "Choose one method for setting the embargo period."
            )
        return self


# Export all config classes
__all__ = [
    "SplitterConfig",
    "PurgedWalkForwardConfig",
    "CombinatorialPurgedConfig",
]
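For orientation, a minimal usage sketch of these config classes. It assumes this hunk is ml4t/diagnostic/splitters/config.py (the only file in the listing adding 322 lines) and that BaseConfig behaves like a standard Pydantic v2 model; the values are illustrative, not taken from the package documentation.

# Hedged sketch: import path inferred from the file listing above; BaseConfig is
# assumed to be a Pydantic v2 BaseModel subclass (it uses Field/field_validator).
import pandas as pd
from pydantic import ValidationError

from ml4t.diagnostic.splitters.config import (
    CombinatorialPurgedConfig,
    PurgedWalkForwardConfig,
)

# embargo_td accepts ints, timedelta-like objects, and duration strings
# (the string form is parsed with pd.Timedelta by the field validator).
wf = PurgedWalkForwardConfig(n_splits=5, test_size=100, embargo_td=pd.Timedelta("1D"))

# Session alignment rejects time-based sizes; use a session count instead.
try:
    PurgedWalkForwardConfig(n_splits=5, test_size="4W", align_to_sessions=True)
except ValidationError as exc:
    print(exc)  # align_to_sessions=True does not support time-based size specifications

# CPCV config: embargo_td and embargo_pct are mutually exclusive (model validator).
try:
    CombinatorialPurgedConfig(n_groups=8, n_test_groups=2, embargo_td=5, embargo_pct=0.01)
except ValidationError as exc:
    print(exc)  # Cannot specify both 'embargo_td' and 'embargo_pct'

Note the split in validation styles: per-field validators (label_horizon, embargo_td, test_size) run as each field is parsed, while the embargo mutual-exclusivity check is an after-model validator so it can see both fields at once.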
ml4t/diagnostic/splitters/cpcv/__init__.py
@@ -0,0 +1,57 @@
"""Combinatorial Purged Cross-Validation (CPCV) submodules.

This package provides modular components for CPCV:

- combinations: Combination generation and reservoir sampling
- partitioning: Group partitioning strategies
- windows: Time window computation for purging
- purge_engine: Core purging and embargo logic
"""

from ml4t.diagnostic.splitters.cpcv.combinations import (
    iter_combinations,
    reservoir_sample_combinations,
)
from ml4t.diagnostic.splitters.cpcv.partitioning import (
    boundaries_to_indices,
    create_contiguous_partitions,
    create_session_partitions,
    exact_indices_to_array,
    validate_contiguous_partitions,
)
from ml4t.diagnostic.splitters.cpcv.purge_engine import (
    apply_multi_asset_purging,
    apply_segment_purging,
    apply_single_asset_purging,
    prepare_test_groups_data,
    process_asset_purging,
)
from ml4t.diagnostic.splitters.cpcv.windows import (
    TimeWindow,
    find_contiguous_segments,
    merge_windows,
    timestamp_window_from_indices,
)

__all__ = [
    # combinations
    "iter_combinations",
    "reservoir_sample_combinations",
    # partitioning
    "create_contiguous_partitions",
    "validate_contiguous_partitions",
    "create_session_partitions",
    "boundaries_to_indices",
    "exact_indices_to_array",
    # windows
    "TimeWindow",
    "timestamp_window_from_indices",
    "find_contiguous_segments",
    "merge_windows",
    # purge_engine
    "apply_single_asset_purging",
    "apply_multi_asset_purging",
    "prepare_test_groups_data",
    "process_asset_purging",
    "apply_segment_purging",
]
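The __init__ flattens the four submodules into one public namespace. Assuming the wheel is installed as listed above, both import styles below should resolve to the same objects; this is a small illustrative check, not an example from the package docs.

import ml4t.diagnostic.splitters.cpcv as cpcv
from ml4t.diagnostic.splitters.cpcv.combinations import iter_combinations

# Package-level re-export and the submodule definition are the same function object.
assert cpcv.iter_combinations is iter_combinations
print(len(cpcv.__all__))  # 16 names, grouped by submodule as in the __all__ above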
ml4t/diagnostic/splitters/cpcv/combinations.py
@@ -0,0 +1,119 @@
"""Combination generation and sampling for CPCV.

This module handles the combinatorial aspects of CPCV:
- Generating C(n,k) test group combinations
- Reservoir sampling for large combination spaces
- Lazy iteration over combinations
"""

from __future__ import annotations

import itertools
import math
from collections.abc import Iterator
from typing import TYPE_CHECKING

import numpy as np

if TYPE_CHECKING:
    pass


def iter_combinations(
    n_groups: int,
    n_test_groups: int,
    max_combinations: int | None = None,
    random_state: int | None = None,
) -> Iterator[tuple[int, ...]]:
    """Iterate over test group combinations, optionally sampling.

    When max_combinations is None or larger than total combinations,
    yields all C(n_groups, n_test_groups) combinations.

    When max_combinations is smaller, uses reservoir sampling to
    select a random subset without materializing all combinations.

    Parameters
    ----------
    n_groups : int
        Total number of groups to choose from.
    n_test_groups : int
        Number of groups to choose for each combination.
    max_combinations : int, optional
        Maximum number of combinations to yield.
        If None, yields all combinations.
    random_state : int, optional
        Random seed for reproducible sampling.

    Yields
    ------
    tuple[int, ...]
        Test group indices for each combination.

    Examples
    --------
    >>> list(iter_combinations(4, 2))
    [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]

    >>> list(iter_combinations(4, 2, max_combinations=3, random_state=42))
    [(0, 2), (1, 2), (2, 3)]
    """
    total_combinations = math.comb(n_groups, n_test_groups)

    if max_combinations is None or total_combinations <= max_combinations:
        # Yield all combinations directly from generator
        yield from itertools.combinations(range(n_groups), n_test_groups)
    else:
        # Use reservoir sampling for subset selection
        rng = np.random.default_rng(random_state)
        sampled = reservoir_sample_combinations(n_groups, n_test_groups, max_combinations, rng)
        yield from sampled


def reservoir_sample_combinations(
    n_groups: int,
    n_test_groups: int,
    max_combinations: int,
    rng: np.random.Generator,
) -> list[tuple[int, ...]]:
    """Sample combinations using reservoir sampling.

    Samples directly from the combinations iterator without materializing
    all C(n,k) combinations in memory. Time complexity O(C(n,k)) but
    space complexity O(max_combinations).

    Parameters
    ----------
    n_groups : int
        Total number of groups to choose from.
    n_test_groups : int
        Number of groups to choose for each combination.
    max_combinations : int
        Number of combinations to sample.
    rng : np.random.Generator
        Random number generator for reproducible sampling.

    Returns
    -------
    list[tuple[int, ...]]
        Sampled combinations, randomly selected with uniform probability.

    Notes
    -----
    Uses Algorithm R (Vitter, 1985) for reservoir sampling:
    - First k items fill the reservoir
    - Each subsequent item i replaces a random reservoir item with probability k/i
    - Result is a uniform random sample of size k
    """
    reservoir: list[tuple[int, ...]] = []

    for i, combo in enumerate(itertools.combinations(range(n_groups), n_test_groups)):
        if i < max_combinations:
            reservoir.append(combo)
        else:
            # Reservoir sampling: replace with probability max_combinations/(i+1)
            j = rng.integers(0, i + 1)
            if j < max_combinations:
                reservoir[j] = combo

    return reservoir
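A short worked example of the two functions above, with illustrative values and the import path inferred from the file listing: with n_groups=8 and n_test_groups=2 there are C(8,2) = 28 candidate test-group pairs, and capping at max_combinations=10 returns a reproducible, uniformly sampled subset of 10.

import math

# Import path inferred from the listing (ml4t/diagnostic/splitters/cpcv/combinations.py).
from ml4t.diagnostic.splitters.cpcv.combinations import iter_combinations

all_folds = list(iter_combinations(8, 2))
assert len(all_folds) == math.comb(8, 2) == 28    # every pair of groups appears exactly once

sampled = list(iter_combinations(8, 2, max_combinations=10, random_state=7))
assert len(sampled) == 10                         # reservoir keeps exactly max_combinations folds
assert set(sampled) <= set(all_folds)             # each sampled fold is a valid group pair
assert sampled == list(iter_combinations(8, 2, max_combinations=10, random_state=7))  # seeded, so reproducible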