ml4t-diagnostic 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (242) hide show
  1. ml4t/diagnostic/AGENT.md +25 -0
  2. ml4t/diagnostic/__init__.py +166 -0
  3. ml4t/diagnostic/backends/__init__.py +10 -0
  4. ml4t/diagnostic/backends/adapter.py +192 -0
  5. ml4t/diagnostic/backends/polars_backend.py +899 -0
  6. ml4t/diagnostic/caching/__init__.py +40 -0
  7. ml4t/diagnostic/caching/cache.py +331 -0
  8. ml4t/diagnostic/caching/decorators.py +131 -0
  9. ml4t/diagnostic/caching/smart_cache.py +339 -0
  10. ml4t/diagnostic/config/AGENT.md +24 -0
  11. ml4t/diagnostic/config/README.md +267 -0
  12. ml4t/diagnostic/config/__init__.py +219 -0
  13. ml4t/diagnostic/config/barrier_config.py +277 -0
  14. ml4t/diagnostic/config/base.py +301 -0
  15. ml4t/diagnostic/config/event_config.py +148 -0
  16. ml4t/diagnostic/config/feature_config.py +404 -0
  17. ml4t/diagnostic/config/multi_signal_config.py +55 -0
  18. ml4t/diagnostic/config/portfolio_config.py +215 -0
  19. ml4t/diagnostic/config/report_config.py +391 -0
  20. ml4t/diagnostic/config/sharpe_config.py +202 -0
  21. ml4t/diagnostic/config/signal_config.py +206 -0
  22. ml4t/diagnostic/config/trade_analysis_config.py +310 -0
  23. ml4t/diagnostic/config/validation.py +279 -0
  24. ml4t/diagnostic/core/__init__.py +29 -0
  25. ml4t/diagnostic/core/numba_utils.py +315 -0
  26. ml4t/diagnostic/core/purging.py +372 -0
  27. ml4t/diagnostic/core/sampling.py +471 -0
  28. ml4t/diagnostic/errors/__init__.py +205 -0
  29. ml4t/diagnostic/evaluation/AGENT.md +26 -0
  30. ml4t/diagnostic/evaluation/__init__.py +437 -0
  31. ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
  32. ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
  33. ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
  34. ml4t/diagnostic/evaluation/dashboard.py +715 -0
  35. ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
  36. ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
  37. ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
  38. ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
  39. ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
  40. ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
  41. ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
  42. ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
  43. ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
  44. ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
  45. ml4t/diagnostic/evaluation/event_analysis.py +647 -0
  46. ml4t/diagnostic/evaluation/excursion.py +390 -0
  47. ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
  48. ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
  49. ml4t/diagnostic/evaluation/framework.py +935 -0
  50. ml4t/diagnostic/evaluation/metric_registry.py +255 -0
  51. ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
  52. ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
  53. ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
  54. ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
  55. ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
  56. ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
  57. ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
  58. ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
  59. ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
  60. ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
  61. ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
  62. ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
  63. ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
  64. ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
  65. ml4t/diagnostic/evaluation/multi_signal.py +550 -0
  66. ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
  67. ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
  68. ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
  69. ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
  70. ml4t/diagnostic/evaluation/report_generation.py +824 -0
  71. ml4t/diagnostic/evaluation/signal_selector.py +452 -0
  72. ml4t/diagnostic/evaluation/stat_registry.py +139 -0
  73. ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
  74. ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
  75. ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
  76. ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
  77. ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
  78. ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
  79. ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
  80. ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
  81. ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
  82. ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
  83. ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
  84. ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
  85. ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
  86. ml4t/diagnostic/evaluation/stats/moments.py +164 -0
  87. ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
  88. ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
  89. ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
  90. ml4t/diagnostic/evaluation/themes.py +330 -0
  91. ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
  92. ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
  93. ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
  94. ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
  95. ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
  96. ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
  97. ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
  98. ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
  99. ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
  100. ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
  101. ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
  102. ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
  103. ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
  104. ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
  105. ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
  106. ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
  107. ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
  108. ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
  109. ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
  110. ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
  111. ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
  112. ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
  113. ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
  114. ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
  115. ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
  116. ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
  117. ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
  118. ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
  119. ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
  120. ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
  121. ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
  122. ml4t/diagnostic/evaluation/validated_cv.py +535 -0
  123. ml4t/diagnostic/evaluation/visualization.py +1050 -0
  124. ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
  125. ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
  126. ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
  127. ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
  128. ml4t/diagnostic/integration/__init__.py +48 -0
  129. ml4t/diagnostic/integration/backtest_contract.py +671 -0
  130. ml4t/diagnostic/integration/data_contract.py +316 -0
  131. ml4t/diagnostic/integration/engineer_contract.py +226 -0
  132. ml4t/diagnostic/logging/__init__.py +77 -0
  133. ml4t/diagnostic/logging/logger.py +245 -0
  134. ml4t/diagnostic/logging/performance.py +234 -0
  135. ml4t/diagnostic/logging/progress.py +234 -0
  136. ml4t/diagnostic/logging/wandb.py +412 -0
  137. ml4t/diagnostic/metrics/__init__.py +9 -0
  138. ml4t/diagnostic/metrics/percentiles.py +128 -0
  139. ml4t/diagnostic/py.typed +1 -0
  140. ml4t/diagnostic/reporting/__init__.py +43 -0
  141. ml4t/diagnostic/reporting/base.py +130 -0
  142. ml4t/diagnostic/reporting/html_renderer.py +275 -0
  143. ml4t/diagnostic/reporting/json_renderer.py +51 -0
  144. ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
  145. ml4t/diagnostic/results/AGENT.md +24 -0
  146. ml4t/diagnostic/results/__init__.py +105 -0
  147. ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
  148. ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
  149. ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
  150. ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
  151. ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
  152. ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
  153. ml4t/diagnostic/results/barrier_results/validation.py +38 -0
  154. ml4t/diagnostic/results/base.py +177 -0
  155. ml4t/diagnostic/results/event_results.py +349 -0
  156. ml4t/diagnostic/results/feature_results.py +787 -0
  157. ml4t/diagnostic/results/multi_signal_results.py +431 -0
  158. ml4t/diagnostic/results/portfolio_results.py +281 -0
  159. ml4t/diagnostic/results/sharpe_results.py +448 -0
  160. ml4t/diagnostic/results/signal_results/__init__.py +74 -0
  161. ml4t/diagnostic/results/signal_results/ic.py +581 -0
  162. ml4t/diagnostic/results/signal_results/irtc.py +110 -0
  163. ml4t/diagnostic/results/signal_results/quantile.py +392 -0
  164. ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
  165. ml4t/diagnostic/results/signal_results/turnover.py +213 -0
  166. ml4t/diagnostic/results/signal_results/validation.py +147 -0
  167. ml4t/diagnostic/signal/AGENT.md +17 -0
  168. ml4t/diagnostic/signal/__init__.py +69 -0
  169. ml4t/diagnostic/signal/_report.py +152 -0
  170. ml4t/diagnostic/signal/_utils.py +261 -0
  171. ml4t/diagnostic/signal/core.py +275 -0
  172. ml4t/diagnostic/signal/quantile.py +148 -0
  173. ml4t/diagnostic/signal/result.py +214 -0
  174. ml4t/diagnostic/signal/signal_ic.py +129 -0
  175. ml4t/diagnostic/signal/turnover.py +182 -0
  176. ml4t/diagnostic/splitters/AGENT.md +19 -0
  177. ml4t/diagnostic/splitters/__init__.py +36 -0
  178. ml4t/diagnostic/splitters/base.py +501 -0
  179. ml4t/diagnostic/splitters/calendar.py +421 -0
  180. ml4t/diagnostic/splitters/calendar_config.py +91 -0
  181. ml4t/diagnostic/splitters/combinatorial.py +1064 -0
  182. ml4t/diagnostic/splitters/config.py +322 -0
  183. ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
  184. ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
  185. ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
  186. ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
  187. ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
  188. ml4t/diagnostic/splitters/group_isolation.py +329 -0
  189. ml4t/diagnostic/splitters/persistence.py +316 -0
  190. ml4t/diagnostic/splitters/utils.py +207 -0
  191. ml4t/diagnostic/splitters/walk_forward.py +757 -0
  192. ml4t/diagnostic/utils/__init__.py +42 -0
  193. ml4t/diagnostic/utils/config.py +542 -0
  194. ml4t/diagnostic/utils/dependencies.py +318 -0
  195. ml4t/diagnostic/utils/sessions.py +127 -0
  196. ml4t/diagnostic/validation/__init__.py +54 -0
  197. ml4t/diagnostic/validation/dataframe.py +274 -0
  198. ml4t/diagnostic/validation/returns.py +280 -0
  199. ml4t/diagnostic/validation/timeseries.py +299 -0
  200. ml4t/diagnostic/visualization/AGENT.md +19 -0
  201. ml4t/diagnostic/visualization/__init__.py +223 -0
  202. ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
  203. ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
  204. ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
  205. ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
  206. ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
  207. ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
  208. ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
  209. ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
  210. ml4t/diagnostic/visualization/barrier_plots.py +782 -0
  211. ml4t/diagnostic/visualization/core.py +1060 -0
  212. ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
  213. ml4t/diagnostic/visualization/dashboards/base.py +582 -0
  214. ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
  215. ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
  216. ml4t/diagnostic/visualization/dashboards.py +43 -0
  217. ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
  218. ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
  219. ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
  220. ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
  221. ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
  222. ml4t/diagnostic/visualization/feature_plots.py +888 -0
  223. ml4t/diagnostic/visualization/interaction_plots.py +618 -0
  224. ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
  225. ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
  226. ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
  227. ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
  228. ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
  229. ml4t/diagnostic/visualization/report_generation.py +1343 -0
  230. ml4t/diagnostic/visualization/signal/__init__.py +103 -0
  231. ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
  232. ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
  233. ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
  234. ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
  235. ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
  236. ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
  237. ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
  238. ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
  239. ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
  240. ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
  241. ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
  242. ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,535 @@
1
+ """Validated Cross-Validation combining CPCV with DSR for robust strategy assessment."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, field
6
+ from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
7
+
8
+ import numpy as np
9
+ from pydantic import BaseModel, Field
10
+
11
+ from ml4t.diagnostic.config import StatisticalConfig
12
+ from ml4t.diagnostic.evaluation.stats import deflated_sharpe_ratio_from_statistics
13
+ from ml4t.diagnostic.splitters.combinatorial import CombinatorialPurgedCV
14
+
15
+ if TYPE_CHECKING:
16
+ from collections.abc import Callable
17
+
18
+ import polars as pl
19
+
20
+
21
@runtime_checkable
class ModelProtocol(Protocol):
    """Structural type for estimators that expose ``fit`` and ``predict``.

    Any object providing these two methods satisfies the protocol; because the
    protocol is runtime-checkable, ``isinstance(obj, ModelProtocol)`` only
    checks that both attributes exist (signatures are not verified).
    """

    def fit(self, X: Any, y: Any) -> Any:
        """Train the model on features ``X`` and targets ``y``."""

    def predict(self, X: Any) -> Any:
        """Produce predictions for features ``X``."""
32
+
33
+
34
@dataclass
class ValidationFoldResult:
    """Outcome of a single cross-validation fold.

    Attributes
    ----------
    fold_idx : int
        Zero-based position of the fold in the split sequence.
    train_size : int
        Number of training rows (0 when built from pre-computed Sharpes).
    test_size : int
        Number of test rows (0 when built from pre-computed Sharpes).
    sharpe_ratio : float
        Sharpe ratio attributed to this fold.
    returns : np.ndarray
        Strategy returns on the fold's test rows (empty when unavailable).
    predictions : np.ndarray or None
        Model output on the test rows, if it was recorded.
    """

    fold_idx: int  # position in the split order
    train_size: int  # rows used for fitting
    test_size: int  # rows used for evaluation
    sharpe_ratio: float  # fold-level Sharpe ratio
    returns: np.ndarray  # per-row strategy returns on the test set
    predictions: np.ndarray | None = field(default=None)  # optional model output
44
+
45
+
46
@dataclass
class ValidationResult:
    """Complete result from validated cross-validation.

    Combines cross-validation performance with statistical significance testing.

    Attributes
    ----------
    fold_results : list[ValidationFoldResult]
        Results from each CV fold
    n_folds : int
        Number of folds completed
    mean_sharpe : float
        Mean Sharpe ratio across folds
    std_sharpe : float
        Standard deviation of Sharpe ratios
    dsr : float
        Deflated Sharpe Ratio (probability true SR > 0)
    dsr_zscore : float
        DSR z-score
    expected_max_sharpe : float
        Expected maximum Sharpe under null hypothesis
    is_significant : bool
        Whether DSR > significance threshold
    significance_level : float
        Threshold the DSR was compared against (e.g. 0.95)
    interpretation : list[str]
        Human-readable interpretation of results
    """

    fold_results: list[ValidationFoldResult] = field(default_factory=list)
    n_folds: int = 0
    mean_sharpe: float = 0.0
    std_sharpe: float = 0.0
    dsr: float = 0.0
    dsr_zscore: float = 0.0
    expected_max_sharpe: float = 0.0
    is_significant: bool = False
    significance_level: float = 0.95
    interpretation: list[str] = field(default_factory=list)

    def summary(self) -> str:
        """Generate human-readable summary.

        Returns
        -------
        str
            Formatted multi-line summary string
        """
        # ``:g`` on the percentage keeps levels such as 0.975 or 0.999 exact
        # ("97.5%", "99.9%"); ``:.0%`` would round them to "98%" / "100%".
        level_label = f"{self.significance_level * 100:g}%"
        lines = [
            "=" * 50,
            "Validated Cross-Validation Results",
            "=" * 50,
            "",
            f"Folds completed: {self.n_folds}",
            f"Mean Sharpe: {self.mean_sharpe:.4f}",
            f"Std Sharpe: {self.std_sharpe:.4f}",
            "",
            "--- Statistical Significance ---",
            f"DSR (probability true SR > 0): {self.dsr:.4f}",
            f"DSR z-score: {self.dsr_zscore:.4f}",
            f"Expected max SR under null: {self.expected_max_sharpe:.4f}",
            f"Significant at {level_label}: {'YES' if self.is_significant else 'NO'}",
            "",
            "--- Interpretation ---",
        ]
        lines.extend(f" - {interp}" for interp in self.interpretation)
        return "\n".join(lines)

    def to_dict(self) -> dict[str, Any]:
        """Export to dictionary.

        Returns
        -------
        dict
            Dictionary representation. Per-fold detail is reduced to the list
            of fold Sharpe ratios under ``"fold_sharpes"``.
        """
        return {
            "n_folds": self.n_folds,
            "mean_sharpe": self.mean_sharpe,
            "std_sharpe": self.std_sharpe,
            "dsr": self.dsr,
            "dsr_zscore": self.dsr_zscore,
            "expected_max_sharpe": self.expected_max_sharpe,
            "is_significant": self.is_significant,
            "significance_level": self.significance_level,
            # Copy so callers mutating the export cannot alter this result.
            "interpretation": list(self.interpretation),
            "fold_sharpes": [fr.sharpe_ratio for fr in self.fold_results],
        }
136
+
137
+
138
class ValidatedCrossValidationConfig(BaseModel):
    """Configuration for ValidatedCrossValidation.

    Groups the CPCV splitter parameters, the DSR / significance parameters and
    execution options. Value bounds are enforced by the pydantic ``Field``
    constraints below.
    """

    # --- Cross-validation (forwarded to CombinatorialPurgedCV) ---
    n_groups: int = Field(10, ge=2, description="Number of CV groups")
    n_test_groups: int = Field(2, ge=1, description="Groups per test set")
    embargo_pct: float = Field(0.01, ge=0, le=0.2, description="Embargo fraction")
    label_horizon: int = Field(0, ge=0, description="Label look-ahead samples")

    # --- Deflated Sharpe Ratio / significance ---
    # NOTE(review): sharpe_star is not read by ValidatedCrossValidation in this file.
    sharpe_star: float = Field(0.0, description="Benchmark Sharpe ratio")
    significance_level: float = Field(0.95, ge=0.5, le=0.999)
    annualization_factor: float = Field(252.0, gt=0, description="For Sharpe annualization")

    # --- Execution ---
    # NOTE(review): random_state is not read by ValidatedCrossValidation in this file.
    random_state: int | None = Field(None)
154
+
155
+
156
class ValidatedCrossValidation:
    """Orchestrates CPCV with DSR computation for robust strategy validation.

    Combines Combinatorial Purged Cross-Validation with Deflated Sharpe Ratio
    to provide statistically rigorous assessment of trading strategies.

    This addresses the workflow fragmentation where users must manually:
    1. Run CPCV
    2. Collect Sharpe ratios
    3. Compute DSR
    4. Interpret results

    Examples
    --------
    >>> # Basic usage with model
    >>> vcv = ValidatedCrossValidation(config)
    >>> result = vcv.fit_evaluate(X, y, model, times=dates)
    >>> print(result.summary())

    >>> # With custom returns computation; called as returns_fn(y_true, y_pred)
    >>> def compute_returns(y_true, y_pred):
    ...     positions = np.sign(y_pred)
    ...     return positions * y_true  # Simple return
    >>> result = vcv.fit_evaluate(X, y, model, times=dates, returns_fn=compute_returns)

    >>> # Just evaluate pre-computed fold Sharpes
    >>> result = vcv.evaluate_sharpes([0.5, 0.6, 0.4, 0.7, 0.3])
    """

    def __init__(
        self,
        config: ValidatedCrossValidationConfig | None = None,
        statistical_config: StatisticalConfig | None = None,
    ):
        """Initialize ValidatedCrossValidation.

        Parameters
        ----------
        config : ValidatedCrossValidationConfig, optional
            CV and evaluation configuration. Defaults to
            ``ValidatedCrossValidationConfig()``.
        statistical_config : StatisticalConfig, optional
            Statistical testing configuration (for advanced DSR settings).
            NOTE(review): stored on the instance but not read by this class.
        """
        self.config = config if config is not None else ValidatedCrossValidationConfig()
        self.statistical_config = (
            statistical_config if statistical_config is not None else StatisticalConfig()
        )

        # The CPCV splitter is built once from the config and reused by fit_evaluate.
        self._cv = CombinatorialPurgedCV(
            n_groups=self.config.n_groups,
            n_test_groups=self.config.n_test_groups,
            embargo_pct=self.config.embargo_pct,
            label_horizon=self.config.label_horizon,
        )

    def fit_evaluate(
        self,
        X: np.ndarray | pl.DataFrame,
        y: np.ndarray | pl.Series,
        model: ModelProtocol,
        times: np.ndarray | pl.Series | None = None,
        returns_fn: Callable[[np.ndarray, np.ndarray], np.ndarray] | None = None,
    ) -> ValidationResult:
        """Run cross-validation and compute DSR in one call.

        The same ``model`` object is re-fitted on every fold's training rows.

        Parameters
        ----------
        X : array-like
            Features matrix
        y : array-like
            Target variable (or returns if returns_fn not provided)
        model : ModelProtocol
            Model with fit/predict interface
        times : array-like, optional
            Timestamps for purging. Required for temporal purging.
        returns_fn : callable, optional
            Function(y_true, y_pred) -> returns.
            If None, assumes y contains returns and uses ``sign(predictions)``
            as positions.

        Returns
        -------
        ValidationResult
            Complete validation results with DSR
        """
        import polars as pl

        # Normalise all inputs to numpy so they can be indexed by fold indices.
        X_np = X.to_numpy() if isinstance(X, pl.DataFrame) else np.asarray(X)
        y_np = y.to_numpy() if isinstance(y, pl.Series) else np.asarray(y)
        if times is None:
            times_np = None
        elif isinstance(times, pl.Series):
            times_np = times.to_numpy()
        else:
            times_np = np.asarray(times)

        fold_results: list[ValidationFoldResult] = []

        for fold_idx, (train_idx, test_idx) in enumerate(self._cv.split(X_np, y_np, times_np)):
            model.fit(X_np[train_idx], y_np[train_idx])
            predictions = np.asarray(model.predict(X_np[test_idx]))
            y_test = y_np[test_idx]

            if returns_fn is not None:
                fold_returns = np.asarray(returns_fn(y_test, predictions))
            else:
                signals = predictions
                # A column vector such as (n, 1) from predict() would otherwise
                # broadcast against (n,) targets into an (n, n) "returns" matrix.
                if signals.shape != y_test.shape and signals.size == y_test.size:
                    signals = signals.reshape(y_test.shape)
                fold_returns = np.sign(signals) * y_test

            fold_results.append(
                ValidationFoldResult(
                    fold_idx=fold_idx,
                    train_size=len(train_idx),
                    test_size=len(test_idx),
                    sharpe_ratio=self._compute_sharpe(fold_returns),
                    returns=fold_returns,
                    predictions=predictions,
                )
            )

        return self._compute_validation_result(fold_results)

    def evaluate_sharpes(self, sharpe_ratios: list[float]) -> ValidationResult:
        """Evaluate pre-computed Sharpe ratios with DSR.

        Use when you've already computed Sharpe ratios from custom evaluation.

        Parameters
        ----------
        sharpe_ratios : list[float]
            Sharpe ratios from each CV fold or strategy

        Returns
        -------
        ValidationResult
            Complete validation results with DSR. Fold sizes are recorded as 0
            and fold returns as empty arrays, since only Sharpes are known.

        Examples
        --------
        >>> sharpes = [0.5, 0.6, 0.4, 0.7, 0.3, 0.55]
        >>> result = vcv.evaluate_sharpes(sharpes)
        >>> print(f"DSR: {result.dsr:.4f}")
        """
        fold_results = [
            ValidationFoldResult(
                fold_idx=i,
                train_size=0,
                test_size=0,
                sharpe_ratio=sr,
                returns=np.array([]),
            )
            for i, sr in enumerate(sharpe_ratios)
        ]
        return self._compute_validation_result(fold_results)

    def _compute_sharpe(self, returns: np.ndarray) -> float:
        """Compute annualized Sharpe ratio.

        Parameters
        ----------
        returns : array-like
            Period returns (flattened before use)

        Returns
        -------
        float
            ``mean / std(ddof=1) * sqrt(annualization_factor)``, or 0.0 when
            fewer than two returns are given or all returns are identical.
        """
        values = np.asarray(returns, dtype=float).ravel()
        # ptp == 0 is an exact "all values identical" test. Comparing np.std to 0
        # is not: rounding in the mean (e.g. [0.1, 0.1, 0.1]) leaves a ~1e-17 std
        # that would turn constant returns into an astronomically large Sharpe.
        if values.size < 2 or np.ptp(values) == 0:
            return 0.0

        mean_ret = values.mean()
        std_ret = values.std(ddof=1)
        return float(mean_ret / std_ret * np.sqrt(self.config.annualization_factor))

    def _compute_validation_result(
        self, fold_results: list[ValidationFoldResult]
    ) -> ValidationResult:
        """Compute final validation result with DSR.

        Parameters
        ----------
        fold_results : list[ValidationFoldResult]
            Results from each fold

        Returns
        -------
        ValidationResult
            Complete validation result; an empty result with a
            "No folds completed" interpretation when ``fold_results`` is empty.
        """
        sharpes = [fr.sharpe_ratio for fr in fold_results]
        n_folds = len(sharpes)

        if n_folds == 0:
            return ValidationResult(interpretation=["No folds completed"])

        mean_sharpe = float(np.mean(sharpes))
        std_sharpe = float(np.std(sharpes, ddof=1)) if n_folds > 1 else 0.0
        max_sharpe = float(np.max(sharpes))

        # Variance of the fold Sharpes (0.0 for a single fold, as std_sharpe is).
        var_sharpes = std_sharpe**2

        # The best fold is treated as the "observed" (selected) Sharpe and is
        # deflated for having been picked among n_folds trials.
        dsr_result = deflated_sharpe_ratio_from_statistics(
            observed_sharpe=max_sharpe,
            n_trials=n_folds,
            variance_trials=var_sharpes,
            # NOTE(review): fixed sample length ("annual Sharpes"); it does not
            # use the fold test sizes or config.annualization_factor.
            n_samples=252,
            skewness=0.0,  # Assume symmetric
            excess_kurtosis=0.0,  # Assume normal (Fisher convention: normal=0)
        )

        dsr = dsr_result.probability
        dsr_zscore = dsr_result.z_score
        expected_max = dsr_result.expected_max_sharpe

        # bool(): comparing NumPy scalars yields np.bool_, which the json module
        # cannot serialize when the result is exported through to_dict().
        is_significant = bool(dsr > self.config.significance_level)

        interpretation = self._generate_interpretation(
            mean_sharpe=mean_sharpe,
            max_sharpe=max_sharpe,
            expected_max=expected_max,
            dsr=dsr,
            is_significant=is_significant,
        )

        return ValidationResult(
            fold_results=fold_results,
            n_folds=n_folds,
            mean_sharpe=mean_sharpe,
            std_sharpe=std_sharpe,
            dsr=dsr,
            dsr_zscore=dsr_zscore,
            expected_max_sharpe=expected_max,
            is_significant=is_significant,
            significance_level=self.config.significance_level,
            interpretation=interpretation,
        )

    def _generate_interpretation(
        self,
        mean_sharpe: float,
        max_sharpe: float,
        expected_max: float,
        dsr: float,
        is_significant: bool,
    ) -> list[str]:
        """Generate human-readable interpretation.

        Parameters
        ----------
        mean_sharpe : float
            Mean Sharpe across folds
        max_sharpe : float
            Maximum observed Sharpe
        expected_max : float
            Expected max under null
        dsr : float
            Deflated Sharpe Ratio
        is_significant : bool
            Whether result is significant

        Returns
        -------
        list[str]
            Interpretation strings, in this order: significance, overfitting
            check, optional stability note, recommendation.
        """
        interp = []

        # ``:g`` keeps levels like 0.975 exact ("97.5%"); ``:.0%`` rounds them.
        level_label = f"{self.config.significance_level * 100:g}%"

        # Significance assessment (significance requires a strictly larger DSR).
        if is_significant:
            interp.append(
                f"Strategy is statistically significant (DSR={dsr:.2%} > {level_label})"
            )
        else:
            interp.append(f"Strategy is NOT significant (DSR={dsr:.2%} <= {level_label})")

        # Overfitting assessment
        inflation = max_sharpe - expected_max
        if inflation > 0:
            interp.append(
                f"Potential overfitting: observed SR ({max_sharpe:.3f}) exceeds null expectation ({expected_max:.3f}) by {inflation:.3f}"
            )
        else:
            interp.append("No obvious overfitting: observed SR below null expectation")

        # Mean vs max: a best fold far above the average signals instability.
        if max_sharpe > 2 * mean_sharpe and mean_sharpe > 0:
            interp.append("High variance in fold performance suggests unstable strategy")
        elif mean_sharpe > 0.5:
            interp.append("Consistent positive performance across folds")

        # Recommendation
        if is_significant and mean_sharpe > 0.3:
            interp.append(
                "Recommendation: Strategy shows robust performance, consider paper trading"
            )
        elif is_significant:
            interp.append(
                "Recommendation: Significant but modest returns, investigate improvements"
            )
        else:
            interp.append(
                "Recommendation: Strategy likely overfit, revisit feature selection or model"
            )

        return interp
486
+
487
+
488
# Convenience function
def validated_cross_val_score(
    model: ModelProtocol,
    X: np.ndarray,
    y: np.ndarray,
    times: np.ndarray | None = None,
    n_groups: int = 10,
    embargo_pct: float = 0.01,
    *,
    n_test_groups: int = 2,
    label_horizon: int = 0,
    returns_fn: Callable[[np.ndarray, np.ndarray], np.ndarray] | None = None,
) -> ValidationResult:
    """Convenience function for validated cross-validation.

    Builds a ``ValidatedCrossValidationConfig`` from the arguments (all other
    settings keep their config defaults) and runs
    ``ValidatedCrossValidation.fit_evaluate``.

    Parameters
    ----------
    model : ModelProtocol
        Model with fit/predict interface
    X : np.ndarray
        Features
    y : np.ndarray
        Target (or returns when ``returns_fn`` is None)
    times : np.ndarray, optional
        Timestamps for purging
    n_groups : int, default 10
        Number of CV groups
    embargo_pct : float, default 0.01
        Embargo fraction
    n_test_groups : int, default 2
        Groups per test set (keyword-only)
    label_horizon : int, default 0
        Label look-ahead samples used for purging (keyword-only)
    returns_fn : callable, optional
        Function(y_true, y_pred) -> returns (keyword-only). If None, ``y`` is
        treated as returns and ``sign(predictions)`` as positions.

    Returns
    -------
    ValidationResult
        Validation results with DSR

    Examples
    --------
    >>> from sklearn.ensemble import RandomForestRegressor
    >>> result = validated_cross_val_score(
    ...     model=RandomForestRegressor(),
    ...     X=features,
    ...     y=returns,
    ...     times=dates,
    ... )
    >>> print(f"DSR: {result.dsr:.4f}")
    """
    config = ValidatedCrossValidationConfig(
        n_groups=n_groups,
        n_test_groups=n_test_groups,
        embargo_pct=embargo_pct,
        label_horizon=label_horizon,
    )
    vcv = ValidatedCrossValidation(config)
    return vcv.fit_evaluate(X, y, model, times=times, returns_fn=returns_fn)