ml4t_diagnostic-0.1.0a1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (242)
  1. ml4t/diagnostic/AGENT.md +25 -0
  2. ml4t/diagnostic/__init__.py +166 -0
  3. ml4t/diagnostic/backends/__init__.py +10 -0
  4. ml4t/diagnostic/backends/adapter.py +192 -0
  5. ml4t/diagnostic/backends/polars_backend.py +899 -0
  6. ml4t/diagnostic/caching/__init__.py +40 -0
  7. ml4t/diagnostic/caching/cache.py +331 -0
  8. ml4t/diagnostic/caching/decorators.py +131 -0
  9. ml4t/diagnostic/caching/smart_cache.py +339 -0
  10. ml4t/diagnostic/config/AGENT.md +24 -0
  11. ml4t/diagnostic/config/README.md +267 -0
  12. ml4t/diagnostic/config/__init__.py +219 -0
  13. ml4t/diagnostic/config/barrier_config.py +277 -0
  14. ml4t/diagnostic/config/base.py +301 -0
  15. ml4t/diagnostic/config/event_config.py +148 -0
  16. ml4t/diagnostic/config/feature_config.py +404 -0
  17. ml4t/diagnostic/config/multi_signal_config.py +55 -0
  18. ml4t/diagnostic/config/portfolio_config.py +215 -0
  19. ml4t/diagnostic/config/report_config.py +391 -0
  20. ml4t/diagnostic/config/sharpe_config.py +202 -0
  21. ml4t/diagnostic/config/signal_config.py +206 -0
  22. ml4t/diagnostic/config/trade_analysis_config.py +310 -0
  23. ml4t/diagnostic/config/validation.py +279 -0
  24. ml4t/diagnostic/core/__init__.py +29 -0
  25. ml4t/diagnostic/core/numba_utils.py +315 -0
  26. ml4t/diagnostic/core/purging.py +372 -0
  27. ml4t/diagnostic/core/sampling.py +471 -0
  28. ml4t/diagnostic/errors/__init__.py +205 -0
  29. ml4t/diagnostic/evaluation/AGENT.md +26 -0
  30. ml4t/diagnostic/evaluation/__init__.py +437 -0
  31. ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
  32. ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
  33. ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
  34. ml4t/diagnostic/evaluation/dashboard.py +715 -0
  35. ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
  36. ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
  37. ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
  38. ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
  39. ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
  40. ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
  41. ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
  42. ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
  43. ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
  44. ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
  45. ml4t/diagnostic/evaluation/event_analysis.py +647 -0
  46. ml4t/diagnostic/evaluation/excursion.py +390 -0
  47. ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
  48. ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
  49. ml4t/diagnostic/evaluation/framework.py +935 -0
  50. ml4t/diagnostic/evaluation/metric_registry.py +255 -0
  51. ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
  52. ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
  53. ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
  54. ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
  55. ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
  56. ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
  57. ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
  58. ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
  59. ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
  60. ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
  61. ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
  62. ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
  63. ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
  64. ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
  65. ml4t/diagnostic/evaluation/multi_signal.py +550 -0
  66. ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
  67. ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
  68. ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
  69. ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
  70. ml4t/diagnostic/evaluation/report_generation.py +824 -0
  71. ml4t/diagnostic/evaluation/signal_selector.py +452 -0
  72. ml4t/diagnostic/evaluation/stat_registry.py +139 -0
  73. ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
  74. ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
  75. ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
  76. ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
  77. ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
  78. ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
  79. ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
  80. ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
  81. ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
  82. ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
  83. ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
  84. ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
  85. ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
  86. ml4t/diagnostic/evaluation/stats/moments.py +164 -0
  87. ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
  88. ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
  89. ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
  90. ml4t/diagnostic/evaluation/themes.py +330 -0
  91. ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
  92. ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
  93. ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
  94. ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
  95. ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
  96. ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
  97. ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
  98. ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
  99. ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
  100. ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
  101. ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
  102. ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
  103. ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
  104. ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
  105. ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
  106. ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
  107. ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
  108. ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
  109. ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
  110. ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
  111. ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
  112. ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
  113. ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
  114. ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
  115. ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
  116. ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
  117. ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
  118. ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
  119. ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
  120. ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
  121. ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
  122. ml4t/diagnostic/evaluation/validated_cv.py +535 -0
  123. ml4t/diagnostic/evaluation/visualization.py +1050 -0
  124. ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
  125. ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
  126. ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
  127. ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
  128. ml4t/diagnostic/integration/__init__.py +48 -0
  129. ml4t/diagnostic/integration/backtest_contract.py +671 -0
  130. ml4t/diagnostic/integration/data_contract.py +316 -0
  131. ml4t/diagnostic/integration/engineer_contract.py +226 -0
  132. ml4t/diagnostic/logging/__init__.py +77 -0
  133. ml4t/diagnostic/logging/logger.py +245 -0
  134. ml4t/diagnostic/logging/performance.py +234 -0
  135. ml4t/diagnostic/logging/progress.py +234 -0
  136. ml4t/diagnostic/logging/wandb.py +412 -0
  137. ml4t/diagnostic/metrics/__init__.py +9 -0
  138. ml4t/diagnostic/metrics/percentiles.py +128 -0
  139. ml4t/diagnostic/py.typed +1 -0
  140. ml4t/diagnostic/reporting/__init__.py +43 -0
  141. ml4t/diagnostic/reporting/base.py +130 -0
  142. ml4t/diagnostic/reporting/html_renderer.py +275 -0
  143. ml4t/diagnostic/reporting/json_renderer.py +51 -0
  144. ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
  145. ml4t/diagnostic/results/AGENT.md +24 -0
  146. ml4t/diagnostic/results/__init__.py +105 -0
  147. ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
  148. ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
  149. ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
  150. ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
  151. ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
  152. ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
  153. ml4t/diagnostic/results/barrier_results/validation.py +38 -0
  154. ml4t/diagnostic/results/base.py +177 -0
  155. ml4t/diagnostic/results/event_results.py +349 -0
  156. ml4t/diagnostic/results/feature_results.py +787 -0
  157. ml4t/diagnostic/results/multi_signal_results.py +431 -0
  158. ml4t/diagnostic/results/portfolio_results.py +281 -0
  159. ml4t/diagnostic/results/sharpe_results.py +448 -0
  160. ml4t/diagnostic/results/signal_results/__init__.py +74 -0
  161. ml4t/diagnostic/results/signal_results/ic.py +581 -0
  162. ml4t/diagnostic/results/signal_results/irtc.py +110 -0
  163. ml4t/diagnostic/results/signal_results/quantile.py +392 -0
  164. ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
  165. ml4t/diagnostic/results/signal_results/turnover.py +213 -0
  166. ml4t/diagnostic/results/signal_results/validation.py +147 -0
  167. ml4t/diagnostic/signal/AGENT.md +17 -0
  168. ml4t/diagnostic/signal/__init__.py +69 -0
  169. ml4t/diagnostic/signal/_report.py +152 -0
  170. ml4t/diagnostic/signal/_utils.py +261 -0
  171. ml4t/diagnostic/signal/core.py +275 -0
  172. ml4t/diagnostic/signal/quantile.py +148 -0
  173. ml4t/diagnostic/signal/result.py +214 -0
  174. ml4t/diagnostic/signal/signal_ic.py +129 -0
  175. ml4t/diagnostic/signal/turnover.py +182 -0
  176. ml4t/diagnostic/splitters/AGENT.md +19 -0
  177. ml4t/diagnostic/splitters/__init__.py +36 -0
  178. ml4t/diagnostic/splitters/base.py +501 -0
  179. ml4t/diagnostic/splitters/calendar.py +421 -0
  180. ml4t/diagnostic/splitters/calendar_config.py +91 -0
  181. ml4t/diagnostic/splitters/combinatorial.py +1064 -0
  182. ml4t/diagnostic/splitters/config.py +322 -0
  183. ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
  184. ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
  185. ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
  186. ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
  187. ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
  188. ml4t/diagnostic/splitters/group_isolation.py +329 -0
  189. ml4t/diagnostic/splitters/persistence.py +316 -0
  190. ml4t/diagnostic/splitters/utils.py +207 -0
  191. ml4t/diagnostic/splitters/walk_forward.py +757 -0
  192. ml4t/diagnostic/utils/__init__.py +42 -0
  193. ml4t/diagnostic/utils/config.py +542 -0
  194. ml4t/diagnostic/utils/dependencies.py +318 -0
  195. ml4t/diagnostic/utils/sessions.py +127 -0
  196. ml4t/diagnostic/validation/__init__.py +54 -0
  197. ml4t/diagnostic/validation/dataframe.py +274 -0
  198. ml4t/diagnostic/validation/returns.py +280 -0
  199. ml4t/diagnostic/validation/timeseries.py +299 -0
  200. ml4t/diagnostic/visualization/AGENT.md +19 -0
  201. ml4t/diagnostic/visualization/__init__.py +223 -0
  202. ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
  203. ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
  204. ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
  205. ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
  206. ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
  207. ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
  208. ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
  209. ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
  210. ml4t/diagnostic/visualization/barrier_plots.py +782 -0
  211. ml4t/diagnostic/visualization/core.py +1060 -0
  212. ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
  213. ml4t/diagnostic/visualization/dashboards/base.py +582 -0
  214. ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
  215. ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
  216. ml4t/diagnostic/visualization/dashboards.py +43 -0
  217. ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
  218. ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
  219. ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
  220. ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
  221. ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
  222. ml4t/diagnostic/visualization/feature_plots.py +888 -0
  223. ml4t/diagnostic/visualization/interaction_plots.py +618 -0
  224. ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
  225. ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
  226. ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
  227. ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
  228. ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
  229. ml4t/diagnostic/visualization/report_generation.py +1343 -0
  230. ml4t/diagnostic/visualization/signal/__init__.py +103 -0
  231. ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
  232. ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
  233. ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
  234. ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
  235. ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
  236. ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
  237. ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
  238. ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
  239. ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
  240. ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
  241. ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
  242. ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
ml4t/diagnostic/evaluation/framework.py
@@ -0,0 +1,935 @@
"""Main Evaluator framework implementing the Three-Tier Validation Framework.

This module provides the Evaluator class that orchestrates the complete ml4t-diagnostic
validation workflow:

- Tier 1 (Rigorous Backtesting): Full CPCV validation with statistical tests
- Tier 2 (Statistical Significance): HAC-adjusted tests and significance testing
- Tier 3 (Production Monitoring): Fast screening metrics for live systems

The Evaluator integrates with all splitters, metrics, and statistical tests to
provide a unified interface for financial ML validation.
"""

import warnings
from collections.abc import Callable
from datetime import datetime
from typing import TYPE_CHECKING, Any, Union

import numpy as np
import pandas as pd
import polars as pl
from joblib import Parallel, delayed
from sklearn.base import BaseEstimator, clone

from ml4t.diagnostic.backends.adapter import DataFrameAdapter
from ml4t.diagnostic.splitters.base import BaseSplitter
from ml4t.diagnostic.splitters.combinatorial import CombinatorialPurgedCV
from ml4t.diagnostic.splitters.walk_forward import PurgedWalkForwardCV

from .dashboard import create_evaluation_dashboard
from .metric_registry import MetricRegistry
from .stat_registry import StatTestRegistry
from .visualization import plot_ic_heatmap, plot_quantile_returns

if TYPE_CHECKING:
    from numpy.typing import NDArray


def get_metric_directionality(metric_name: str) -> bool:
    """Get whether a metric should be maximized (True) or minimized (False).

    Parameters
    ----------
    metric_name : str
        Name of the metric

    Returns
    -------
    bool
        True if higher values are better, False if lower values are better
    """
    normalized = metric_name.lower().replace("-", "_").replace(" ", "_")
    return MetricRegistry.default().is_maximize(normalized)

class EvaluationResult:
    """Container for evaluation results with rich reporting capabilities."""

    def __init__(
        self,
        tier: int,
        splitter_name: str,
        metrics_results: dict[str, Any],
        statistical_tests: dict[str, Any] | None = None,
        fold_results: list[dict[str, Any]] | None = None,
        metadata: dict[str, Any] | None = None,
        oos_returns: list[np.ndarray] | None = None,
    ):
        """Initialize evaluation result.

        Parameters
        ----------
        tier : int
            Tier level (1, 2, or 3) of the evaluation
        splitter_name : str
            Name of the cross-validation method used
        metrics_results : dict[str, Any]
            Aggregated metrics results
        statistical_tests : Optional[dict[str, Any]]
            Statistical test results (Tier 1 & 2)
        fold_results : Optional[list[dict[str, Any]]]
            Individual fold results for detailed analysis
        metadata : Optional[dict[str, Any]]
            Additional metadata about the evaluation
        oos_returns : Optional[list[np.ndarray]]
            Out-of-sample strategy returns from each fold for statistical testing
        """
        self.tier = tier
        self.splitter_name = splitter_name
        self.metrics_results = metrics_results
        self.statistical_tests = statistical_tests or {}
        self.fold_results = fold_results or []
        self.metadata = metadata or {}
        self.oos_returns = oos_returns or []
        self.timestamp = datetime.now()

    def summary(self) -> dict[str, Any]:
        """Generate a summary of the evaluation results."""
        summary: dict[str, Any] = {
            "tier": self.tier,
            "splitter": self.splitter_name,
            "timestamp": self.timestamp.isoformat(),
            "n_folds": len(self.fold_results),
            "metrics": {},
            "statistical_tests": {},
        }

        # Summarize metrics
        for metric_name, value in self.metrics_results.items():
            if isinstance(value, dict) and "mean" in value:
                summary["metrics"][metric_name] = {
                    "mean": value["mean"],
                    "std": value.get("std", None),
                    "significant": value.get("significant", None),
                }
            else:
                summary["metrics"][metric_name] = value

        # Summarize statistical tests
        for test_name, result in self.statistical_tests.items():
            if isinstance(result, dict):
                summary["statistical_tests"][test_name] = {
                    "test_statistic": result.get(
                        "test_statistic",
                        result.get("dsr", None),
                    ),
                    "p_value": result.get("p_value", None),
                    "significant": result.get("p_value", 1.0) < 0.05
                    if "p_value" in result
                    else None,
                }

        return summary

    def get_oos_returns_series(self) -> np.ndarray | None:
        """Get concatenated out-of-sample returns series for statistical testing.

        Returns
        -------
        np.ndarray or None
            Concatenated strategy returns from all folds, or None if not available
        """
        if not self.oos_returns or len(self.oos_returns) == 0:
            return None

        # Filter out any NaN arrays from failed folds
        valid_returns = [returns for returns in self.oos_returns if not np.all(np.isnan(returns))]

        if not valid_returns:
            return None

        return np.concatenate(valid_returns)

    def plot(
        self,
        predictions: Any | None = None,
        returns: Any | None = None,
    ) -> Any:
        """Generate default visualization for evaluation results.

        Parameters
        ----------
        predictions : array-like, optional
            Predictions for visualization
        returns : array-like, optional
            Returns for visualization

        Returns
        -------
        plotly.graph_objects.Figure
            Interactive visualization
        """
        # Default plot based on available metrics
        if "ic" in self.metrics_results and predictions is not None and returns is not None:
            return plot_ic_heatmap(predictions, returns)
        if "sharpe" in self.metrics_results and returns is not None and predictions is not None:
            return plot_quantile_returns(predictions, returns)
        # Return a summary plot
        import plotly.graph_objects as go

        metric_names = list(self.metrics_results.keys())
        metric_values = [
            self.metrics_results[m].get("mean", 0)
            if isinstance(self.metrics_results[m], dict)
            else self.metrics_results[m]
            for m in metric_names
        ]

        fig = go.Figure(data=[go.Bar(x=metric_names, y=metric_values)])
        fig.update_layout(
            title=f"Evaluation Results - Tier {self.tier}",
            xaxis_title="Metric",
            yaxis_title="Value",
        )
        return fig

    def to_html(
        self,
        filename: str,
        predictions: Any | None = None,
        returns: Any | None = None,
        features: Any | None = None,
        title: str | None = None,
    ) -> None:
        """Generate interactive HTML dashboard.

        Parameters
        ----------
        filename : str
            Output HTML filename
        predictions : array-like, optional
            Model predictions for visualizations
        returns : array-like, optional
            Returns data for visualizations
        features : array-like, optional
            Feature data for distribution analysis
        title : str, optional
            Dashboard title

        Examples
        --------
        >>> result.to_html("evaluation_report.html", predictions=pred_df, returns=ret_df)
        """
        create_evaluation_dashboard(
            self,
            filename,
            predictions=predictions,
            returns=returns,
            features=features,
            title=title,
        )

    def __repr__(self) -> str:
        """String representation of evaluation result."""
        summary = self.summary()
        metrics_str = ", ".join(
            [
                f"{k}: {v['mean']:.3f}"
                if isinstance(v, dict) and "mean" in v
                else f"{k}: {v:.3f}"
                if isinstance(v, int | float)
                else f"{k}: {v}"
                for k, v in summary["metrics"].items()
            ],
        )

        return (
            f"EvaluationResult(tier={self.tier}, splitter={self.splitter_name}, "
            f"n_folds={summary['n_folds']}, metrics=[{metrics_str}])"
        )

class Evaluator:
    """Main evaluator implementing the Three-Tier Validation Framework.

    The Evaluator orchestrates the complete ml4t-diagnostic validation workflow by
    integrating cross-validation splitters, performance metrics, and
    statistical tests into a unified framework.

    Three-Tier Framework:
    - Tier 3: Fast screening with basic metrics
    - Tier 2: Statistical significance testing with HAC adjustments
    - Tier 1: Rigorous backtesting with multiple testing corrections
    """

    # Backward-compatible class attributes (delegate to registries)
    @property
    def METRIC_REGISTRY(self) -> dict[str, Callable]:  # noqa: N802
        """Get metric registry (backward compatibility)."""
        registry = MetricRegistry.default()
        return {name: registry.get(name) for name in registry.list_metrics()}

    @property
    def STAT_TEST_REGISTRY(self) -> dict[str, Callable]:  # noqa: N802
        """Get stat test registry (backward compatibility)."""
        registry = StatTestRegistry.default()
        return {name: registry.get(name) for name in registry.list_tests()}

    def __init__(
        self,
        splitter: BaseSplitter | None = None,
        metrics: list[str] | None = None,
        statistical_tests: list[str] | None = None,
        tier: int | None = None,
        confidence_level: float = 0.05,
        bootstrap_samples: int = 1000,
        random_state: int | None = None,
        n_jobs: int = 1,
    ):
        """Initialize the Evaluator.

        Parameters
        ----------
        splitter : Optional[BaseSplitter], default None
            Cross-validation splitter. If None, inferred from tier
        metrics : Optional[list[str]], default None
            List of metrics to compute. If None, uses tier defaults
        statistical_tests : Optional[list[str]], default None
            List of statistical tests to perform. If None, uses tier defaults
        tier : Optional[int], default None
            Tier level (1, 2, or 3). If None, inferred from the other parameters
        confidence_level : float, default 0.05
            Significance level for statistical tests
        bootstrap_samples : int, default 1000
            Number of bootstrap samples for confidence intervals
        random_state : Optional[int], default None
            Random seed for reproducible results
        n_jobs : int, default 1
            Number of parallel jobs for cross-validation.
            -1 means using all processors

        Examples
        --------
        # Tier 3: Fast screening
        >>> evaluator = Evaluator(tier=3)
        >>> result = evaluator.evaluate(X, y, model)

        # Tier 1: Full rigorous evaluation
        >>> evaluator = Evaluator(
        ...     splitter=CombinatorialPurgedCV(n_groups=8),
        ...     metrics=["sharpe", "sortino", "max_drawdown"],
        ...     statistical_tests=["dsr", "whites_reality_check"],
        ...     tier=1
        ... )
        >>> result = evaluator.evaluate(X, y, model)
        """
        self.confidence_level = confidence_level
        self.bootstrap_samples = bootstrap_samples
        self.random_state = random_state
        self.n_jobs = n_jobs

        # Infer tier if not specified
        if tier is None:
            tier = self._infer_tier(splitter, metrics, statistical_tests)

        self.tier = tier
        self.splitter = splitter or self._get_default_splitter(tier)
        self.metrics = metrics or self._get_default_metrics(tier)
        self.statistical_tests = statistical_tests or self._get_default_statistical_tests(tier)

        # Validate configuration
        self._validate_configuration()

    @classmethod
    def register_metric(
        cls,
        name: str,
        func: Callable[..., float],
        maximize: bool = True,
    ) -> None:
        """Register a custom metric function.

        Parameters
        ----------
        name : str
            Name of the metric
        func : Callable
            Function that takes (predictions, actual, strategy_returns) and returns float
        maximize : bool, default True
            Whether higher values are better

        Examples
        --------
        >>> def my_metric(predictions, actual, returns):
        ...     return np.mean(predictions > 0)
        >>> Evaluator.register_metric("my_metric", my_metric)
        """
        MetricRegistry.default().register(name, func, maximize=maximize)

    @classmethod
    def register_statistical_test(
        cls,
        name: str,
        func: Callable[..., dict[str, Any]],
    ) -> None:
        """Register a custom statistical test function.

        Parameters
        ----------
        name : str
            Name of the test
        func : Callable
            Function that returns a dict with test results
        """
        StatTestRegistry.default().register(name, func)

    def _infer_tier(
        self,
        splitter: BaseSplitter | None,
        _metrics: list[str] | None,
        statistical_tests: list[str] | None,
    ) -> int:
        """Infer tier level from configuration."""
        # Tier 1 indicators: CPCV splitter or advanced statistical tests
        if isinstance(splitter, CombinatorialPurgedCV) or (
            statistical_tests
            and any(test in ["dsr", "whites_reality_check"] for test in statistical_tests)
        ):
            return 1

        # Tier 2 indicators: HAC tests or confidence intervals
        if statistical_tests and any(test in ["hac_ic", "fdr"] for test in statistical_tests):
            return 2

        # Default to Tier 3 (fast screening)
        return 3

    def _get_default_splitter(self, tier: int) -> BaseSplitter:
        """Get default splitter for tier."""
        if tier == 1:
            return CombinatorialPurgedCV(n_groups=8, n_test_groups=2)
        if tier == 2:
            return PurgedWalkForwardCV(n_splits=5)
        # tier == 3
        return PurgedWalkForwardCV(n_splits=3)

    def _get_default_metrics(self, tier: int) -> list[str]:
        """Get default metrics for tier."""
        if tier == 1:
            return ["ic", "sharpe", "sortino", "max_drawdown", "hit_rate"]
        if tier == 2:
            return ["ic", "sharpe", "hit_rate"]
        # tier == 3
        return ["ic", "hit_rate"]

    def _get_default_statistical_tests(self, tier: int) -> list[str]:
        """Get default statistical tests for tier."""
        if tier == 1:
            return ["dsr", "fdr"]
        if tier == 2:
            return ["hac_ic"]
        # tier == 3
        return []

    def _validate_configuration(self) -> None:
        """Validate evaluator configuration using Pydantic schemas."""
        from pydantic import ValidationError

        from ml4t.diagnostic.utils.config import EvaluatorConfig

        try:
            # Validate main evaluator parameters
            EvaluatorConfig(
                tier=self.tier,
                confidence_level=self.confidence_level,
                bootstrap_samples=self.bootstrap_samples,
                random_state=self.random_state,
                n_jobs=self.n_jobs,
            )

        except ValidationError as e:
            # Convert Pydantic validation errors to clearer messages
            error_messages = []
            for error in e.errors():
                field = error["loc"][0] if error["loc"] else "unknown"
                message = error["msg"]
                error_messages.append(f"{field}: {message}")

            raise ValueError(  # noqa: B904
                f"Configuration validation failed: {'; '.join(error_messages)}",
            )

        # Validate metrics against registry
        metric_registry = MetricRegistry.default()
        invalid_metrics = [m for m in self.metrics if m not in metric_registry]
        if invalid_metrics:
            raise ValueError(
                f"Unknown metrics: {invalid_metrics}. Available: {metric_registry.list_metrics()}",
            )

        # Validate statistical tests against registry
        stat_registry = StatTestRegistry.default()
        invalid_tests = [t for t in self.statistical_tests if t not in stat_registry]
        if invalid_tests:
            raise ValueError(
                f"Unknown statistical tests: {invalid_tests}. Available: {stat_registry.list_tests()}",
            )

        # Tier-specific validations with Pydantic-style consistency checks
        if self.tier == 1 and not isinstance(self.splitter, CombinatorialPurgedCV):
            warnings.warn(
                "Tier 1 evaluation should use CombinatorialPurgedCV for maximum rigor",
                stacklevel=2,
            )

        if self.tier == 3 and len(self.statistical_tests) > 2:
            warnings.warn(
                "Tier 3 is designed for fast screening - consider limiting statistical tests",
                stacklevel=2,
            )

    def evaluate(
        self,
        x: Union[pl.DataFrame, pd.DataFrame, "NDArray[Any]"],
        y: Union[pl.Series, pd.Series, "NDArray[Any]"],
        model: BaseEstimator | Callable[..., Any],
        strategy_func: Callable[..., Any] | None = None,
        **kwargs: Any,
    ) -> EvaluationResult:
        """Evaluate a model using the configured validation framework.

        Parameters
        ----------
        x : Union[pl.DataFrame, pd.DataFrame, NDArray]
            Feature matrix
        y : Union[pl.Series, pd.Series, NDArray]
            Target values (returns)
        model : Union[BaseEstimator, Callable]
            Model to evaluate (scikit-learn compatible or callable)
        strategy_func : Optional[Callable], default None
            Function to convert predictions to strategy returns. If None,
            positions are taken as the sign of the predictions
        **kwargs : Any
            Additional parameters passed to the splitter

        Returns
        -------
        EvaluationResult
            Comprehensive evaluation results

        Examples
        --------
        >>> from sklearn.ensemble import RandomForestRegressor
        >>> model = RandomForestRegressor(n_estimators=50)
        >>> evaluator = Evaluator(tier=2)
        >>> result = evaluator.evaluate(X, y, model)
        >>> print(result.summary())
        """
        # Convert inputs to consistent format
        x_array = DataFrameAdapter.to_numpy(x)
        y_array = DataFrameAdapter.to_numpy(y).flatten()

        if len(x_array) != len(y_array):
            raise ValueError("x and y must have the same number of samples")

        # Set random seed if specified
        if self.random_state is not None:
            np.random.seed(self.random_state)

        def process_fold(
            fold_idx,
            train_idx,
            test_idx,
            model,
            x_array,
            y_array,
            strategy_func,
        ):
            """Process a single fold with full process isolation."""
            try:
                x_train, x_test = x_array[train_idx], x_array[test_idx]
                y_train, y_test = y_array[train_idx], y_array[test_idx]

                if hasattr(model, "fit") and hasattr(model, "predict"):
                    # Clone to prevent shared state between parallel processes
                    model_clone = clone(model)

                    if hasattr(model_clone, "random_state") and self.random_state is not None:
                        # Deterministic but different seed per fold
                        model_clone.random_state = self.random_state + fold_idx

                    model_clone.fit(x_train, y_train)
                    predictions = model_clone.predict(x_test)
                else:
                    # Callable model (must be stateless)
                    predictions = model(x_train, y_train, x_test)

                if strategy_func is not None:
                    strategy_returns = strategy_func(predictions, y_test)
                else:
                    positions = np.sign(predictions)
                    strategy_returns = positions * y_test

                fold_metrics = {}
                metric_registry = MetricRegistry.default()
                for metric_name in self.metrics:
                    try:
                        if metric_name in metric_registry:
                            metric_func = metric_registry.get(metric_name)
                            value = metric_func(predictions, y_test, strategy_returns)

                            if metric_name == "max_drawdown" and isinstance(value, dict):
                                value = value["max_drawdown"]

                            fold_metrics[metric_name] = value
                    except Exception as e:
                        fold_metrics[metric_name] = np.nan
                        warnings.warn(
                            f"Fold {fold_idx}: Failed to calculate {metric_name}: {e}",
                            stacklevel=2,
                        )

                fold_metrics["fold"] = fold_idx
                fold_metrics["n_train"] = len(train_idx)
                fold_metrics["n_test"] = len(test_idx)

                return fold_metrics, predictions, y_test, strategy_returns

            except Exception as e:
                warnings.warn(
                    f"Fold {fold_idx} failed with error: {e}. Returning NaN results.",
                    stacklevel=2,
                )
                nan_metrics = dict.fromkeys(self.metrics, np.nan)
                nan_metrics.update(
                    {
                        "fold": fold_idx,
                        "n_train": len(train_idx),
                        "n_test": len(test_idx),
                    },
                )

                return nan_metrics, np.array([np.nan]), np.array([np.nan]), np.array([np.nan])

        splits = list(self.splitter.split(x, y, **kwargs))

        if self.n_jobs == 1:
            results = []
            for fold_idx, (train_idx, test_idx) in enumerate(splits):
                result = process_fold(
                    fold_idx, train_idx, test_idx, model, x_array, y_array, strategy_func
                )
                results.append(result)
        else:
            # Use loky backend for process isolation (prevents race conditions)
            results = Parallel(n_jobs=self.n_jobs, backend="loky")(
                delayed(process_fold)(
                    fold_idx, train_idx, test_idx, model, x_array, y_array, strategy_func
                )
                for fold_idx, (train_idx, test_idx) in enumerate(splits)
            )

        fold_results = [r[0] for r in results]
        all_predictions = [pred for r in results for pred in r[1]]
        all_actual = [actual for r in results for actual in r[2]]
        oos_returns = [r[3] for r in results]

        metrics_results = self._aggregate_metrics(fold_results)
        statistical_tests = self._perform_statistical_tests(
            fold_results,
            all_predictions,
            all_actual,
            metrics_results,
            oos_returns,
        )

        metadata = {
            "n_samples": len(x_array),
            "n_features": x_array.shape[1] if x_array.ndim > 1 else 1,
            "splitter_params": self.splitter.__dict__,
            "tier": self.tier,
            "random_state": self.random_state,
        }

        return EvaluationResult(
            tier=self.tier,
            splitter_name=self.splitter.__class__.__name__,
            metrics_results=metrics_results,
            statistical_tests=statistical_tests,
            fold_results=fold_results,
            metadata=metadata,
            oos_returns=oos_returns,
        )

    def _aggregate_metrics(self, fold_results: list[dict[str, Any]]) -> dict[str, Any]:
        """Aggregate metrics across folds."""
        aggregated = {}

        for metric_name in self.metrics:
            values = [fold.get(metric_name, np.nan) for fold in fold_results]
            valid_values = [v for v in values if not np.isnan(v)]

            if valid_values:
                aggregated[metric_name] = {
                    "mean": np.mean(valid_values),
                    "std": np.std(valid_values, ddof=1) if len(valid_values) > 1 else 0.0,
                    "min": np.min(valid_values),
                    "max": np.max(valid_values),
                    "values": valid_values,
                    "n_valid": len(valid_values),
                }

                # Add confidence interval for mean if multiple folds
                if len(valid_values) > 1:
                    se = aggregated[metric_name]["std"] / np.sqrt(len(valid_values))
                    from scipy.stats import t

                    t_val = t.ppf(1 - self.confidence_level / 2, len(valid_values) - 1)
                    margin = t_val * se

                    aggregated[metric_name]["ci_lower"] = aggregated[metric_name]["mean"] - margin
                    aggregated[metric_name]["ci_upper"] = aggregated[metric_name]["mean"] + margin
            else:
                aggregated[metric_name] = {
                    "mean": np.nan,
                    "std": np.nan,
                    "min": np.nan,
                    "max": np.nan,
                    "values": [],
                    "n_valid": 0,
                }

        return aggregated

    def _perform_statistical_tests(
        self,
        fold_results: list[dict[str, Any]],
        all_predictions: list[float],
        all_actual: list[float],
        metrics_results: dict[str, Any],
        oos_returns: list[np.ndarray],
    ) -> dict[str, Any]:
        """Perform statistical tests based on configuration."""
        statistical_results: dict[str, Any] = {}
        stat_registry = StatTestRegistry.default()

        for test_name in self.statistical_tests:
            try:
                if test_name in stat_registry:
                    test_func = stat_registry.get(test_name)

                    # Prepare test-specific arguments
                    if test_name == "dsr" and "sharpe" in metrics_results:
                        sharpe_values = metrics_results["sharpe"]["values"]
                        if sharpe_values and len(oos_returns) > 0:
                            best_sharpe = float(np.max(sharpe_values))
                            n_trials = len(fold_results)
                            # Calculate variance across trials
                            variance_trials = (
                                float(np.var(sharpe_values, ddof=1))
                                if len(sharpe_values) > 1
                                else 0.001
                            )
                            # Calculate average sample size per fold
                            n_samples = int(
                                np.mean(
                                    [len(returns) for returns in oos_returns if len(returns) > 0]
                                )
                            )
                            # Use deflated_sharpe_ratio_from_statistics with the new API
                            dsr_result = test_func(
                                observed_sharpe=best_sharpe,
                                n_samples=n_samples,
                                n_trials=n_trials,
                                variance_trials=variance_trials,
                            )
                            # Convert DSRResult dataclass to dict for consistency
                            result = {
                                "dsr": dsr_result.probability,
                                "p_value": dsr_result.p_value,
                                "expected_max_sharpe": dsr_result.expected_max_sharpe,
                                "z_score": dsr_result.z_score,
                                "is_significant": dsr_result.is_significant,
                            }
                        else:
                            continue

                    elif test_name == "hac_ic" and "ic" in metrics_results:
                        result = test_func(
                            predictions=np.array(all_predictions),
                            returns=np.array(all_actual),
                            return_details=True,
                        )

                    elif test_name == "fdr":
                        # Collect p-values from other tests
                        p_values = []
                        for test_result in statistical_results.values():
                            if isinstance(test_result, dict) and "p_value" in test_result:
                                p_values.append(test_result["p_value"])

                        if p_values:
                            result = test_func(
                                p_values,
                                alpha=self.confidence_level,
                                return_details=True,
                            )
                        else:
                            continue

                    elif test_name == "whites_reality_check":
                        if len(oos_returns) > 1 and all(
                            len(returns) > 0 for returns in oos_returns
                        ):
                            # Concatenate all OOS returns into a single time series;
                            # this is the correct input for White's Reality Check
                            strategy_returns_series = np.concatenate(oos_returns)

                            # Create benchmark (zero returns) of the same length
                            benchmark_returns = np.zeros(len(strategy_returns_series))

                            # Reshape for the test function (expects a 2D array of strategies)
                            strategy_returns_matrix = strategy_returns_series.reshape(-1, 1)

                            result = test_func(
                                returns_benchmark=benchmark_returns,
                                returns_strategies=strategy_returns_matrix,
                                bootstrap_samples=min(self.bootstrap_samples, 500),
                                random_state=self.random_state,
                            )
                        else:
                            continue
                    else:
                        # Generic test function call
                        result = test_func(
                            fold_results=fold_results,
                            predictions=all_predictions,
                            actual=all_actual,
                            metrics_results=metrics_results,
                        )

                    statistical_results[test_name] = result
                else:
                    warnings.warn(
                        f"Unknown statistical test: {test_name}",
                        stacklevel=2,
                    )
                    continue

            except Exception as e:
                warnings.warn(
                    f"Error in statistical test {test_name}: {e}",
                    stacklevel=2,
                )
                # Store the error in a way that's compatible with the expected type
                error_result: dict[str, Any] = {"error": str(e)}
                statistical_results[test_name] = error_result

        return statistical_results

    def batch_evaluate(
        self,
        models: list[BaseEstimator | Callable[..., Any]],
        x: Union[pl.DataFrame, pd.DataFrame, "NDArray[Any]"],
        y: Union[pl.Series, pd.Series, "NDArray[Any]"],
        model_names: list[str] | None = None,
        **kwargs: Any,
    ) -> dict[str, EvaluationResult]:
        """Evaluate multiple models with the same validation framework.

        Parameters
        ----------
        models : list[Union[BaseEstimator, Callable]]
            List of models to evaluate
        x : Union[pl.DataFrame, pd.DataFrame, NDArray]
            Feature matrix
        y : Union[pl.Series, pd.Series, NDArray]
            Target values
        model_names : Optional[list[str]], default None
            Names for the models. If None, uses model class names
        **kwargs : Any
            Additional parameters passed to evaluate()

        Returns
        -------
        dict[str, EvaluationResult]
            Dictionary mapping model names to evaluation results
        """
        if model_names is None:
            model_names = [
                model.__class__.__name__ if hasattr(model, "__class__") else f"Model_{i}"
                for i, model in enumerate(models)
            ]

        if len(models) != len(model_names):
            raise ValueError("Number of models must match number of model names")

        results = {}
        for model, name in zip(models, model_names, strict=False):
            print(f"Evaluating {name}...")
            results[name] = self.evaluate(x, y, model, **kwargs)

        return results

    def compare_models(
        self,
        batch_results: dict[str, EvaluationResult],
        primary_metric: str = "sharpe",
    ) -> dict[str, Any]:
        """Compare multiple model evaluation results.

        Parameters
        ----------
        batch_results : dict[str, EvaluationResult]
            Results from batch_evaluate()
        primary_metric : str, default "sharpe"
            Primary metric for ranking models

        Returns
        -------
        dict[str, Any]
            Comparison summary with rankings and statistical tests
        """
        if not batch_results:
            return {"error": "No results to compare"}

        # Extract primary metric values
        model_metrics = {}
        for name, result in batch_results.items():
            metric_value = result.metrics_results.get(primary_metric, {}).get(
                "mean",
                np.nan,
            )
            model_metrics[name] = metric_value

        # Rank models
        valid_models = {k: v for k, v in model_metrics.items() if not np.isnan(v)}
        if not valid_models:
            return {"error": f"No valid {primary_metric} values found"}

        # Determine sort order based on metric directionality
        maximize = get_metric_directionality(primary_metric)

        # Special handling for drawdown metrics (they're negative, closer to 0 is better)
        if "drawdown" in primary_metric.lower():
            # For drawdown, sort by absolute value (smaller absolute value is better)
            ranked_models = sorted(valid_models.items(), key=lambda x: abs(x[1]))
        else:
            # Regular sorting based on directionality
            ranked_models = sorted(valid_models.items(), key=lambda x: x[1], reverse=maximize)

        # Create comparison summary
        comparison: dict[str, Any] = {
            "primary_metric": primary_metric,
            "n_models": len(batch_results),
            "ranking": [{"model": name, primary_metric: value} for name, value in ranked_models],
            "best_model": ranked_models[0][0] if ranked_models else None,
            "model_details": {},
        }

        # Add detailed results for each model
        for name, result in batch_results.items():
            comparison["model_details"][name] = result.summary()

        return comparison
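
Usage sketch (illustrative, not part of the wheel contents above): the snippet below exercises the Tier 3 and Tier 1 workflows exactly as the Evaluator docstrings describe. The synthetic data, the RandomForestRegressor choice, and the explicit ml4t.diagnostic.evaluation.framework import path are assumptions for demonstration; the package may also re-export Evaluator from a higher-level module.

# Illustrative sketch only -- data shapes and import paths are assumptions.
import numpy as np
from sklearn.ensemble import RandomForestRegressor

from ml4t.diagnostic.evaluation.framework import Evaluator
from ml4t.diagnostic.splitters.combinatorial import CombinatorialPurgedCV

rng = np.random.default_rng(42)
X = rng.normal(size=(500, 10))        # feature matrix (assumed shape)
y = rng.normal(scale=0.01, size=500)  # target returns (assumed scale)

# Tier 3: fast screening with tier defaults (PurgedWalkForwardCV, ic/hit_rate)
screen = Evaluator(tier=3, random_state=42)
print(screen.evaluate(X, y, RandomForestRegressor(n_estimators=50)).summary())

# Tier 1: rigorous evaluation with CPCV plus DSR and White's Reality Check
rigorous = Evaluator(
    splitter=CombinatorialPurgedCV(n_groups=8, n_test_groups=2),
    metrics=["sharpe", "sortino", "max_drawdown"],
    statistical_tests=["dsr", "whites_reality_check"],
    tier=1,
    random_state=42,
)
result = rigorous.evaluate(X, y, RandomForestRegressor(n_estimators=50))
result.to_html("evaluation_report.html")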