ml4t-diagnostic 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ml4t/diagnostic/AGENT.md +25 -0
- ml4t/diagnostic/__init__.py +166 -0
- ml4t/diagnostic/backends/__init__.py +10 -0
- ml4t/diagnostic/backends/adapter.py +192 -0
- ml4t/diagnostic/backends/polars_backend.py +899 -0
- ml4t/diagnostic/caching/__init__.py +40 -0
- ml4t/diagnostic/caching/cache.py +331 -0
- ml4t/diagnostic/caching/decorators.py +131 -0
- ml4t/diagnostic/caching/smart_cache.py +339 -0
- ml4t/diagnostic/config/AGENT.md +24 -0
- ml4t/diagnostic/config/README.md +267 -0
- ml4t/diagnostic/config/__init__.py +219 -0
- ml4t/diagnostic/config/barrier_config.py +277 -0
- ml4t/diagnostic/config/base.py +301 -0
- ml4t/diagnostic/config/event_config.py +148 -0
- ml4t/diagnostic/config/feature_config.py +404 -0
- ml4t/diagnostic/config/multi_signal_config.py +55 -0
- ml4t/diagnostic/config/portfolio_config.py +215 -0
- ml4t/diagnostic/config/report_config.py +391 -0
- ml4t/diagnostic/config/sharpe_config.py +202 -0
- ml4t/diagnostic/config/signal_config.py +206 -0
- ml4t/diagnostic/config/trade_analysis_config.py +310 -0
- ml4t/diagnostic/config/validation.py +279 -0
- ml4t/diagnostic/core/__init__.py +29 -0
- ml4t/diagnostic/core/numba_utils.py +315 -0
- ml4t/diagnostic/core/purging.py +372 -0
- ml4t/diagnostic/core/sampling.py +471 -0
- ml4t/diagnostic/errors/__init__.py +205 -0
- ml4t/diagnostic/evaluation/AGENT.md +26 -0
- ml4t/diagnostic/evaluation/__init__.py +437 -0
- ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
- ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
- ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
- ml4t/diagnostic/evaluation/dashboard.py +715 -0
- ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
- ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
- ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
- ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
- ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
- ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
- ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
- ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
- ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
- ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
- ml4t/diagnostic/evaluation/event_analysis.py +647 -0
- ml4t/diagnostic/evaluation/excursion.py +390 -0
- ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
- ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
- ml4t/diagnostic/evaluation/framework.py +935 -0
- ml4t/diagnostic/evaluation/metric_registry.py +255 -0
- ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
- ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
- ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
- ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
- ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
- ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
- ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
- ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
- ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
- ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
- ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
- ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
- ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
- ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
- ml4t/diagnostic/evaluation/multi_signal.py +550 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
- ml4t/diagnostic/evaluation/report_generation.py +824 -0
- ml4t/diagnostic/evaluation/signal_selector.py +452 -0
- ml4t/diagnostic/evaluation/stat_registry.py +139 -0
- ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
- ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
- ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
- ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
- ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
- ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
- ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
- ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
- ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
- ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
- ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
- ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
- ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
- ml4t/diagnostic/evaluation/stats/moments.py +164 -0
- ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
- ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
- ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
- ml4t/diagnostic/evaluation/themes.py +330 -0
- ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
- ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
- ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
- ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
- ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
- ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
- ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
- ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
- ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
- ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
- ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
- ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
- ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
- ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
- ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
- ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
- ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
- ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
- ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
- ml4t/diagnostic/evaluation/validated_cv.py +535 -0
- ml4t/diagnostic/evaluation/visualization.py +1050 -0
- ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
- ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
- ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
- ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
- ml4t/diagnostic/integration/__init__.py +48 -0
- ml4t/diagnostic/integration/backtest_contract.py +671 -0
- ml4t/diagnostic/integration/data_contract.py +316 -0
- ml4t/diagnostic/integration/engineer_contract.py +226 -0
- ml4t/diagnostic/logging/__init__.py +77 -0
- ml4t/diagnostic/logging/logger.py +245 -0
- ml4t/diagnostic/logging/performance.py +234 -0
- ml4t/diagnostic/logging/progress.py +234 -0
- ml4t/diagnostic/logging/wandb.py +412 -0
- ml4t/diagnostic/metrics/__init__.py +9 -0
- ml4t/diagnostic/metrics/percentiles.py +128 -0
- ml4t/diagnostic/py.typed +1 -0
- ml4t/diagnostic/reporting/__init__.py +43 -0
- ml4t/diagnostic/reporting/base.py +130 -0
- ml4t/diagnostic/reporting/html_renderer.py +275 -0
- ml4t/diagnostic/reporting/json_renderer.py +51 -0
- ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
- ml4t/diagnostic/results/AGENT.md +24 -0
- ml4t/diagnostic/results/__init__.py +105 -0
- ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
- ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
- ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
- ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
- ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
- ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
- ml4t/diagnostic/results/barrier_results/validation.py +38 -0
- ml4t/diagnostic/results/base.py +177 -0
- ml4t/diagnostic/results/event_results.py +349 -0
- ml4t/diagnostic/results/feature_results.py +787 -0
- ml4t/diagnostic/results/multi_signal_results.py +431 -0
- ml4t/diagnostic/results/portfolio_results.py +281 -0
- ml4t/diagnostic/results/sharpe_results.py +448 -0
- ml4t/diagnostic/results/signal_results/__init__.py +74 -0
- ml4t/diagnostic/results/signal_results/ic.py +581 -0
- ml4t/diagnostic/results/signal_results/irtc.py +110 -0
- ml4t/diagnostic/results/signal_results/quantile.py +392 -0
- ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
- ml4t/diagnostic/results/signal_results/turnover.py +213 -0
- ml4t/diagnostic/results/signal_results/validation.py +147 -0
- ml4t/diagnostic/signal/AGENT.md +17 -0
- ml4t/diagnostic/signal/__init__.py +69 -0
- ml4t/diagnostic/signal/_report.py +152 -0
- ml4t/diagnostic/signal/_utils.py +261 -0
- ml4t/diagnostic/signal/core.py +275 -0
- ml4t/diagnostic/signal/quantile.py +148 -0
- ml4t/diagnostic/signal/result.py +214 -0
- ml4t/diagnostic/signal/signal_ic.py +129 -0
- ml4t/diagnostic/signal/turnover.py +182 -0
- ml4t/diagnostic/splitters/AGENT.md +19 -0
- ml4t/diagnostic/splitters/__init__.py +36 -0
- ml4t/diagnostic/splitters/base.py +501 -0
- ml4t/diagnostic/splitters/calendar.py +421 -0
- ml4t/diagnostic/splitters/calendar_config.py +91 -0
- ml4t/diagnostic/splitters/combinatorial.py +1064 -0
- ml4t/diagnostic/splitters/config.py +322 -0
- ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
- ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
- ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
- ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
- ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
- ml4t/diagnostic/splitters/group_isolation.py +329 -0
- ml4t/diagnostic/splitters/persistence.py +316 -0
- ml4t/diagnostic/splitters/utils.py +207 -0
- ml4t/diagnostic/splitters/walk_forward.py +757 -0
- ml4t/diagnostic/utils/__init__.py +42 -0
- ml4t/diagnostic/utils/config.py +542 -0
- ml4t/diagnostic/utils/dependencies.py +318 -0
- ml4t/diagnostic/utils/sessions.py +127 -0
- ml4t/diagnostic/validation/__init__.py +54 -0
- ml4t/diagnostic/validation/dataframe.py +274 -0
- ml4t/diagnostic/validation/returns.py +280 -0
- ml4t/diagnostic/validation/timeseries.py +299 -0
- ml4t/diagnostic/visualization/AGENT.md +19 -0
- ml4t/diagnostic/visualization/__init__.py +223 -0
- ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
- ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
- ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
- ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
- ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
- ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
- ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
- ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
- ml4t/diagnostic/visualization/barrier_plots.py +782 -0
- ml4t/diagnostic/visualization/core.py +1060 -0
- ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
- ml4t/diagnostic/visualization/dashboards/base.py +582 -0
- ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
- ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
- ml4t/diagnostic/visualization/dashboards.py +43 -0
- ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
- ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
- ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
- ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
- ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
- ml4t/diagnostic/visualization/feature_plots.py +888 -0
- ml4t/diagnostic/visualization/interaction_plots.py +618 -0
- ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
- ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
- ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
- ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
- ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
- ml4t/diagnostic/visualization/report_generation.py +1343 -0
- ml4t/diagnostic/visualization/signal/__init__.py +103 -0
- ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
- ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
- ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
- ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
- ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
- ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
- ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
- ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
- ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
- ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
- ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
- ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,535 @@
|
|
|
1
|
+
"""Validated Cross-Validation combining CPCV with DSR for robust strategy assessment."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
from pydantic import BaseModel, Field
|
|
10
|
+
|
|
11
|
+
from ml4t.diagnostic.config import StatisticalConfig
|
|
12
|
+
from ml4t.diagnostic.evaluation.stats import deflated_sharpe_ratio_from_statistics
|
|
13
|
+
from ml4t.diagnostic.splitters.combinatorial import CombinatorialPurgedCV
|
|
14
|
+
|
|
15
|
+
if TYPE_CHECKING:
|
|
16
|
+
from collections.abc import Callable
|
|
17
|
+
|
|
18
|
+
import polars as pl
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@runtime_checkable
class ModelProtocol(Protocol):
    """Structural interface for any estimator exposing ``fit`` and ``predict``.

    No inheritance is required: any object defining both methods satisfies
    the protocol. Because of ``@runtime_checkable``, an ``isinstance`` check
    only verifies that both attributes exist; it does not inspect signatures.
    """

    def fit(self, X: Any, y: Any) -> Any:
        """Train the model on features ``X`` and target ``y``."""

    def predict(self, X: Any) -> Any:
        """Return predictions for features ``X``."""
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@dataclass
class ValidationFoldResult:
    """Outcome of evaluating the strategy on a single cross-validation split.

    Attributes
    ----------
    fold_idx : int
        Position of the split in the order it was produced.
    train_size : int
        Number of training samples (0 when built from pre-computed Sharpes).
    test_size : int
        Number of test samples (0 when built from pre-computed Sharpes).
    sharpe_ratio : float
        Sharpe ratio measured on the test split.
    returns : np.ndarray
        Per-sample strategy returns on the test split.
    predictions : np.ndarray or None
        Raw model predictions for the test split, if available.
    """

    # Identification and split sizes.
    fold_idx: int
    train_size: int
    test_size: int
    # Performance on the held-out split.
    sharpe_ratio: float
    returns: np.ndarray
    # Optional: only populated when a model was actually evaluated.
    predictions: np.ndarray | None = field(default=None)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@dataclass
class ValidationResult:
    """Complete result from validated cross-validation.

    Combines cross-validation performance with statistical significance testing.

    Attributes
    ----------
    fold_results : list[ValidationFoldResult]
        Results from each CV fold
    n_folds : int
        Number of folds completed
    mean_sharpe : float
        Mean Sharpe ratio across folds
    std_sharpe : float
        Standard deviation of Sharpe ratios
    dsr : float
        Deflated Sharpe Ratio (probability true SR > 0)
    dsr_zscore : float
        DSR z-score
    expected_max_sharpe : float
        Expected maximum Sharpe under null hypothesis
    is_significant : bool
        Whether DSR > significance threshold
    significance_level : float
        Threshold the DSR was compared against (default 0.95)
    interpretation : list[str]
        Human-readable interpretation of results
    """

    fold_results: list[ValidationFoldResult] = field(default_factory=list)
    n_folds: int = 0
    mean_sharpe: float = 0.0
    std_sharpe: float = 0.0
    dsr: float = 0.0
    dsr_zscore: float = 0.0
    expected_max_sharpe: float = 0.0
    is_significant: bool = False
    significance_level: float = 0.95
    interpretation: list[str] = field(default_factory=list)

    def summary(self) -> str:
        """Generate human-readable summary.

        Returns
        -------
        str
            Formatted multi-line summary: fold statistics, DSR statistics,
            then one bullet per interpretation entry.
        """
        lines = [
            "=" * 50,
            "Validated Cross-Validation Results",
            "=" * 50,
            "",
            f"Folds completed: {self.n_folds}",
            f"Mean Sharpe: {self.mean_sharpe:.4f}",
            f"Std Sharpe: {self.std_sharpe:.4f}",
            "",
            "--- Statistical Significance ---",
            f"DSR (probability true SR > 0): {self.dsr:.4f}",
            f"DSR z-score: {self.dsr_zscore:.4f}",
            f"Expected max SR under null: {self.expected_max_sharpe:.4f}",
            f"Significant at {self.significance_level:.0%}: {'YES' if self.is_significant else 'NO'}",
            "",
            "--- Interpretation ---",
        ]
        lines.extend(f" - {interp}" for interp in self.interpretation)
        return "\n".join(lines)

    def to_dict(self) -> dict[str, Any]:
        """Export to dictionary.

        Per-fold returns/predictions are not included; only the per-fold
        Sharpe ratios are exported under ``"fold_sharpes"``.

        Returns
        -------
        dict
            Dictionary representation. Lists are fresh copies, so mutating
            the returned dict never alters this result object.
        """
        return {
            "n_folds": self.n_folds,
            "mean_sharpe": self.mean_sharpe,
            "std_sharpe": self.std_sharpe,
            "dsr": self.dsr,
            "dsr_zscore": self.dsr_zscore,
            "expected_max_sharpe": self.expected_max_sharpe,
            "is_significant": self.is_significant,
            "significance_level": self.significance_level,
            # Copy so callers cannot mutate our internal list through the export.
            "interpretation": list(self.interpretation),
            "fold_sharpes": [fr.sharpe_ratio for fr in self.fold_results],
        }
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
class ValidatedCrossValidationConfig(BaseModel):
    """Configuration for ValidatedCrossValidation.

    Groups three kinds of settings: the purged CV splitter geometry, the
    Sharpe/DSR evaluation parameters, and execution options. Bounds are
    enforced by pydantic field constraints at construction time.
    """

    # --- Splitter geometry (forwarded to CombinatorialPurgedCV) ---
    n_groups: int = Field(10, ge=2, description="Number of CV groups")
    n_test_groups: int = Field(2, ge=1, description="Groups per test set")
    embargo_pct: float = Field(0.01, ge=0, le=0.2, description="Embargo fraction")
    label_horizon: int = Field(0, ge=0, description="Label look-ahead samples")

    # --- Sharpe / DSR evaluation ---
    # NOTE(review): sharpe_star is not read by ValidatedCrossValidation in this
    # module; confirm whether it should be passed to the DSR computation.
    sharpe_star: float = Field(0.0, description="Benchmark Sharpe ratio")
    # DSR must strictly exceed this value for a result to be flagged significant.
    significance_level: float = Field(0.95, ge=0.5, le=0.999)
    # Periods per year; fold Sharpes are scaled by sqrt(annualization_factor).
    annualization_factor: float = Field(252.0, gt=0, description="For Sharpe annualization")

    # --- Execution ---
    # NOTE(review): not consumed anywhere in this module.
    random_state: int | None = Field(None)
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
class ValidatedCrossValidation:
    """Orchestrates CPCV with DSR computation for robust strategy validation.

    Combines Combinatorial Purged Cross-Validation with Deflated Sharpe Ratio
    to provide statistically rigorous assessment of trading strategies.

    This addresses the workflow fragmentation where users must manually:
    1. Run CPCV
    2. Collect Sharpe ratios
    3. Compute DSR
    4. Interpret results

    Examples
    --------
    >>> # Basic usage with model
    >>> vcv = ValidatedCrossValidation(config)
    >>> result = vcv.fit_evaluate(X, y, model, times=dates)
    >>> print(result.summary())

    >>> # With custom returns computation; called as returns_fn(y_true, y_pred)
    >>> def compute_returns(y_true, y_pred):
    ...     positions = np.sign(y_pred)
    ...     return positions * y_true  # Simple return
    >>> result = vcv.fit_evaluate(X, y, model, times=dates, returns_fn=compute_returns)

    >>> # Just evaluate pre-computed fold Sharpes
    >>> result = vcv.evaluate_sharpes([0.5, 0.6, 0.4, 0.7, 0.3])
    """

    def __init__(
        self,
        config: ValidatedCrossValidationConfig | None = None,
        statistical_config: StatisticalConfig | None = None,
    ):
        """Initialize ValidatedCrossValidation.

        Parameters
        ----------
        config : ValidatedCrossValidationConfig, optional
            CV and evaluation configuration. A default config is used if omitted.
        statistical_config : StatisticalConfig, optional
            Statistical testing configuration (for advanced DSR settings).
            NOTE(review): stored on the instance but not read by any method of
            this class; DSR parameters currently come from ``config`` only.
        """
        self.config = config or ValidatedCrossValidationConfig()
        self.statistical_config = statistical_config or StatisticalConfig()

        # Build the splitter once from the config; fit_evaluate reuses it.
        self._cv = CombinatorialPurgedCV(
            n_groups=self.config.n_groups,
            n_test_groups=self.config.n_test_groups,
            embargo_pct=self.config.embargo_pct,
            label_horizon=self.config.label_horizon,
        )

    @staticmethod
    def _flatten_column(values: Any) -> np.ndarray:
        """Return ``values`` as an array, collapsing an ``(n, 1)`` column to ``(n,)``."""
        arr = np.asarray(values)
        if arr.ndim == 2 and arr.shape[1] == 1:
            return arr[:, 0]
        return arr

    def fit_evaluate(
        self,
        X: np.ndarray | pl.DataFrame,
        y: np.ndarray | pl.Series,
        model: ModelProtocol,
        times: np.ndarray | pl.Series | None = None,
        returns_fn: Callable[[np.ndarray, np.ndarray], np.ndarray] | None = None,
    ) -> ValidationResult:
        """Run cross-validation and compute DSR in one call.

        The same ``model`` object is re-fit on each training split and then
        used to predict the corresponding test split.

        Parameters
        ----------
        X : array-like
            Features matrix
        y : array-like
            Target variable (or returns if returns_fn not provided)
        model : ModelProtocol
            Model with fit/predict interface
        times : array-like, optional
            Timestamps for purging. Required for temporal purging. Passed to
            the splitter's ``split`` as its third positional argument.
        returns_fn : callable, optional
            Function ``(y_true, y_pred) -> returns`` applied to each test fold.
            If None, assumes y contains returns and ``sign(prediction)`` is the
            position.

        Returns
        -------
        ValidationResult
            Complete validation results with DSR
        """
        import polars as pl

        # Normalise every input to numpy so fold indices apply uniformly.
        X_np = X.to_numpy() if isinstance(X, pl.DataFrame) else np.asarray(X)
        y_np = y.to_numpy() if isinstance(y, pl.Series) else np.asarray(y)
        if times is None:
            times_np = None
        elif isinstance(times, pl.Series):
            times_np = times.to_numpy()
        else:
            times_np = np.asarray(times)

        fold_results: list[ValidationFoldResult] = []

        for fold_idx, (train_idx, test_idx) in enumerate(self._cv.split(X_np, y_np, times_np)):
            model.fit(X_np[train_idx], y_np[train_idx])
            predictions = model.predict(X_np[test_idx])

            if returns_fn is not None:
                fold_returns = returns_fn(y_np[test_idx], predictions)
            else:
                # Default: y holds returns and sign(prediction) is the position.
                # Column vectors (e.g. (n, 1) model outputs) are flattened first
                # so the product is element-wise instead of broadcasting to (n, n).
                signals = self._flatten_column(predictions)
                realized = self._flatten_column(y_np[test_idx])
                fold_returns = np.sign(signals) * realized

            fold_results.append(
                ValidationFoldResult(
                    fold_idx=fold_idx,
                    train_size=len(train_idx),
                    test_size=len(test_idx),
                    sharpe_ratio=self._compute_sharpe(fold_returns),
                    returns=fold_returns,
                    predictions=predictions,
                )
            )

        return self._compute_validation_result(fold_results)

    def evaluate_sharpes(self, sharpe_ratios: list[float]) -> ValidationResult:
        """Evaluate pre-computed Sharpe ratios with DSR.

        Use when you've already computed Sharpe ratios from custom evaluation.
        The synthetic fold results carry zero sizes and empty returns.

        Parameters
        ----------
        sharpe_ratios : list[float]
            Sharpe ratios from each CV fold or strategy

        Returns
        -------
        ValidationResult
            Complete validation results with DSR

        Examples
        --------
        >>> sharpes = [0.5, 0.6, 0.4, 0.7, 0.3, 0.55]
        >>> result = vcv.evaluate_sharpes(sharpes)
        >>> print(f"DSR: {result.dsr:.4f}")
        """
        fold_results = [
            ValidationFoldResult(
                fold_idx=i,
                train_size=0,
                test_size=0,
                sharpe_ratio=sr,
                returns=np.array([]),
            )
            for i, sr in enumerate(sharpe_ratios)
        ]
        return self._compute_validation_result(fold_results)

    def _compute_sharpe(self, returns: np.ndarray) -> float:
        """Compute annualized Sharpe ratio.

        Parameters
        ----------
        returns : np.ndarray
            Period returns

        Returns
        -------
        float
            ``mean / sample-std * sqrt(annualization_factor)``; 0.0 when fewer
            than two returns are given or they have zero dispersion.
        """
        if len(returns) < 2 or np.std(returns) == 0:
            return 0.0

        mean_ret = np.mean(returns)
        std_ret = np.std(returns, ddof=1)

        sharpe = (mean_ret / std_ret) * np.sqrt(self.config.annualization_factor)
        return float(sharpe)

    def _compute_validation_result(
        self, fold_results: list[ValidationFoldResult]
    ) -> ValidationResult:
        """Compute final validation result with DSR.

        Parameters
        ----------
        fold_results : list[ValidationFoldResult]
            Results from each fold

        Returns
        -------
        ValidationResult
            Complete validation result
        """
        sharpes = [fr.sharpe_ratio for fr in fold_results]
        n_folds = len(sharpes)

        if n_folds == 0:
            # Report the configured threshold rather than the dataclass default.
            return ValidationResult(
                significance_level=self.config.significance_level,
                interpretation=["No folds completed"],
            )

        mean_sharpe = float(np.mean(sharpes))
        std_sharpe = float(np.std(sharpes, ddof=1)) if n_folds > 1 else 0.0
        max_sharpe = float(np.max(sharpes))

        var_sharpes = std_sharpe**2 if n_folds > 1 else 0.0

        # The best fold is treated as the "observed" Sharpe (the one we'd select),
        # and every fold counts as one trial for the multiple-testing deflation.
        dsr_result = deflated_sharpe_ratio_from_statistics(
            observed_sharpe=max_sharpe,
            n_trials=n_folds,
            variance_trials=var_sharpes,
            # NOTE(review): hard-coded "one year of daily data"; it follows
            # neither config.annualization_factor nor the actual fold test sizes.
            n_samples=252,
            skewness=0.0,  # Assume symmetric
            excess_kurtosis=0.0,  # Assume normal (Fisher convention: normal=0)
        )

        dsr = dsr_result.probability
        dsr_zscore = dsr_result.z_score
        expected_max = dsr_result.expected_max_sharpe

        is_significant = dsr > self.config.significance_level

        interpretation = self._generate_interpretation(
            mean_sharpe=mean_sharpe,
            max_sharpe=max_sharpe,
            expected_max=expected_max,
            dsr=dsr,
            is_significant=is_significant,
        )

        return ValidationResult(
            fold_results=fold_results,
            n_folds=n_folds,
            mean_sharpe=mean_sharpe,
            std_sharpe=std_sharpe,
            dsr=dsr,
            dsr_zscore=dsr_zscore,
            expected_max_sharpe=expected_max,
            is_significant=is_significant,
            significance_level=self.config.significance_level,
            interpretation=interpretation,
        )

    def _generate_interpretation(
        self,
        mean_sharpe: float,
        max_sharpe: float,
        expected_max: float,
        dsr: float,
        is_significant: bool,
    ) -> list[str]:
        """Generate human-readable interpretation.

        Parameters
        ----------
        mean_sharpe : float
            Mean Sharpe across folds
        max_sharpe : float
            Maximum observed Sharpe
        expected_max : float
            Expected max under null
        dsr : float
            Deflated Sharpe Ratio
        is_significant : bool
            Whether result is significant

        Returns
        -------
        list[str]
            Interpretation strings
        """
        interp = []
        level = self.config.significance_level

        # Significance assessment; significance is a strict ``dsr > level``.
        if is_significant:
            interp.append(
                f"Strategy is statistically significant (DSR={dsr:.2%} > {level:.0%})"
            )
        else:
            interp.append(f"Strategy is NOT significant (DSR={dsr:.2%} <= {level:.0%})")

        # Overfitting assessment: the null's expected maximum is what the best
        # of n_folds skill-less trials would reach by luck alone. Falling short
        # of it is the warning sign; exceeding it is evidence against pure
        # selection bias.
        inflation = max_sharpe - expected_max
        if inflation > 0:
            interp.append(
                f"Observed SR ({max_sharpe:.3f}) exceeds null expectation ({expected_max:.3f}) by {inflation:.3f}: not explained by selection bias alone"
            )
        else:
            interp.append(
                f"Potential overfitting: observed SR ({max_sharpe:.3f}) does not exceed null expectation ({expected_max:.3f})"
            )

        # Mean vs max: a best fold far above the average signals instability.
        if max_sharpe > 2 * mean_sharpe and mean_sharpe > 0:
            interp.append("High variance in fold performance suggests unstable strategy")
        elif mean_sharpe > 0.5:
            interp.append("Consistent positive performance across folds")

        # Recommendation
        if is_significant and mean_sharpe > 0.3:
            interp.append(
                "Recommendation: Strategy shows robust performance, consider paper trading"
            )
        elif is_significant:
            interp.append(
                "Recommendation: Significant but modest returns, investigate improvements"
            )
        else:
            interp.append(
                "Recommendation: Strategy likely overfit, revisit feature selection or model"
            )

        return interp
|
|
486
|
+
|
|
487
|
+
|
|
488
|
+
# Convenience function
|
|
489
|
+
def validated_cross_val_score(
    model: ModelProtocol,
    X: np.ndarray,
    y: np.ndarray,
    times: np.ndarray | None = None,
    n_groups: int = 10,
    embargo_pct: float = 0.01,
    *,
    n_test_groups: int = 2,
    label_horizon: int = 0,
    returns_fn: Callable[[np.ndarray, np.ndarray], np.ndarray] | None = None,
) -> ValidationResult:
    """Convenience function for validated cross-validation.

    Builds a ValidatedCrossValidationConfig from the arguments (all other
    settings keep their defaults) and runs ``fit_evaluate``.

    Parameters
    ----------
    model : ModelProtocol
        Model with fit/predict interface
    X : np.ndarray
        Features
    y : np.ndarray
        Target (or returns)
    times : np.ndarray, optional
        Timestamps for purging
    n_groups : int, default 10
        Number of CV groups
    embargo_pct : float, default 0.01
        Embargo fraction
    n_test_groups : int, default 2
        Groups per test set (keyword-only)
    label_horizon : int, default 0
        Label look-ahead samples used for purging (keyword-only)
    returns_fn : callable, optional
        Function ``(y_true, y_pred) -> returns``. If None, ``y`` is treated as
        returns and ``sign(prediction)`` as the position (keyword-only).

    Returns
    -------
    ValidationResult
        Validation results with DSR

    Examples
    --------
    >>> from sklearn.ensemble import RandomForestRegressor
    >>> result = validated_cross_val_score(
    ...     model=RandomForestRegressor(),
    ...     X=features,
    ...     y=returns,
    ...     times=dates,
    ... )
    >>> print(f"DSR: {result.dsr:.4f}")
    """
    config = ValidatedCrossValidationConfig(
        n_groups=n_groups,
        n_test_groups=n_test_groups,
        embargo_pct=embargo_pct,
        label_horizon=label_horizon,
    )
    vcv = ValidatedCrossValidation(config)
    return vcv.fit_evaluate(X, y, model, times=times, returns_fn=returns_fn)
|