ml4t-diagnostic 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ml4t/diagnostic/AGENT.md +25 -0
- ml4t/diagnostic/__init__.py +166 -0
- ml4t/diagnostic/backends/__init__.py +10 -0
- ml4t/diagnostic/backends/adapter.py +192 -0
- ml4t/diagnostic/backends/polars_backend.py +899 -0
- ml4t/diagnostic/caching/__init__.py +40 -0
- ml4t/diagnostic/caching/cache.py +331 -0
- ml4t/diagnostic/caching/decorators.py +131 -0
- ml4t/diagnostic/caching/smart_cache.py +339 -0
- ml4t/diagnostic/config/AGENT.md +24 -0
- ml4t/diagnostic/config/README.md +267 -0
- ml4t/diagnostic/config/__init__.py +219 -0
- ml4t/diagnostic/config/barrier_config.py +277 -0
- ml4t/diagnostic/config/base.py +301 -0
- ml4t/diagnostic/config/event_config.py +148 -0
- ml4t/diagnostic/config/feature_config.py +404 -0
- ml4t/diagnostic/config/multi_signal_config.py +55 -0
- ml4t/diagnostic/config/portfolio_config.py +215 -0
- ml4t/diagnostic/config/report_config.py +391 -0
- ml4t/diagnostic/config/sharpe_config.py +202 -0
- ml4t/diagnostic/config/signal_config.py +206 -0
- ml4t/diagnostic/config/trade_analysis_config.py +310 -0
- ml4t/diagnostic/config/validation.py +279 -0
- ml4t/diagnostic/core/__init__.py +29 -0
- ml4t/diagnostic/core/numba_utils.py +315 -0
- ml4t/diagnostic/core/purging.py +372 -0
- ml4t/diagnostic/core/sampling.py +471 -0
- ml4t/diagnostic/errors/__init__.py +205 -0
- ml4t/diagnostic/evaluation/AGENT.md +26 -0
- ml4t/diagnostic/evaluation/__init__.py +437 -0
- ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
- ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
- ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
- ml4t/diagnostic/evaluation/dashboard.py +715 -0
- ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
- ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
- ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
- ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
- ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
- ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
- ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
- ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
- ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
- ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
- ml4t/diagnostic/evaluation/event_analysis.py +647 -0
- ml4t/diagnostic/evaluation/excursion.py +390 -0
- ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
- ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
- ml4t/diagnostic/evaluation/framework.py +935 -0
- ml4t/diagnostic/evaluation/metric_registry.py +255 -0
- ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
- ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
- ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
- ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
- ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
- ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
- ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
- ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
- ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
- ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
- ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
- ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
- ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
- ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
- ml4t/diagnostic/evaluation/multi_signal.py +550 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
- ml4t/diagnostic/evaluation/report_generation.py +824 -0
- ml4t/diagnostic/evaluation/signal_selector.py +452 -0
- ml4t/diagnostic/evaluation/stat_registry.py +139 -0
- ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
- ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
- ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
- ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
- ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
- ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
- ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
- ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
- ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
- ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
- ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
- ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
- ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
- ml4t/diagnostic/evaluation/stats/moments.py +164 -0
- ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
- ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
- ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
- ml4t/diagnostic/evaluation/themes.py +330 -0
- ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
- ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
- ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
- ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
- ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
- ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
- ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
- ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
- ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
- ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
- ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
- ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
- ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
- ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
- ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
- ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
- ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
- ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
- ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
- ml4t/diagnostic/evaluation/validated_cv.py +535 -0
- ml4t/diagnostic/evaluation/visualization.py +1050 -0
- ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
- ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
- ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
- ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
- ml4t/diagnostic/integration/__init__.py +48 -0
- ml4t/diagnostic/integration/backtest_contract.py +671 -0
- ml4t/diagnostic/integration/data_contract.py +316 -0
- ml4t/diagnostic/integration/engineer_contract.py +226 -0
- ml4t/diagnostic/logging/__init__.py +77 -0
- ml4t/diagnostic/logging/logger.py +245 -0
- ml4t/diagnostic/logging/performance.py +234 -0
- ml4t/diagnostic/logging/progress.py +234 -0
- ml4t/diagnostic/logging/wandb.py +412 -0
- ml4t/diagnostic/metrics/__init__.py +9 -0
- ml4t/diagnostic/metrics/percentiles.py +128 -0
- ml4t/diagnostic/py.typed +1 -0
- ml4t/diagnostic/reporting/__init__.py +43 -0
- ml4t/diagnostic/reporting/base.py +130 -0
- ml4t/diagnostic/reporting/html_renderer.py +275 -0
- ml4t/diagnostic/reporting/json_renderer.py +51 -0
- ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
- ml4t/diagnostic/results/AGENT.md +24 -0
- ml4t/diagnostic/results/__init__.py +105 -0
- ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
- ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
- ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
- ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
- ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
- ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
- ml4t/diagnostic/results/barrier_results/validation.py +38 -0
- ml4t/diagnostic/results/base.py +177 -0
- ml4t/diagnostic/results/event_results.py +349 -0
- ml4t/diagnostic/results/feature_results.py +787 -0
- ml4t/diagnostic/results/multi_signal_results.py +431 -0
- ml4t/diagnostic/results/portfolio_results.py +281 -0
- ml4t/diagnostic/results/sharpe_results.py +448 -0
- ml4t/diagnostic/results/signal_results/__init__.py +74 -0
- ml4t/diagnostic/results/signal_results/ic.py +581 -0
- ml4t/diagnostic/results/signal_results/irtc.py +110 -0
- ml4t/diagnostic/results/signal_results/quantile.py +392 -0
- ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
- ml4t/diagnostic/results/signal_results/turnover.py +213 -0
- ml4t/diagnostic/results/signal_results/validation.py +147 -0
- ml4t/diagnostic/signal/AGENT.md +17 -0
- ml4t/diagnostic/signal/__init__.py +69 -0
- ml4t/diagnostic/signal/_report.py +152 -0
- ml4t/diagnostic/signal/_utils.py +261 -0
- ml4t/diagnostic/signal/core.py +275 -0
- ml4t/diagnostic/signal/quantile.py +148 -0
- ml4t/diagnostic/signal/result.py +214 -0
- ml4t/diagnostic/signal/signal_ic.py +129 -0
- ml4t/diagnostic/signal/turnover.py +182 -0
- ml4t/diagnostic/splitters/AGENT.md +19 -0
- ml4t/diagnostic/splitters/__init__.py +36 -0
- ml4t/diagnostic/splitters/base.py +501 -0
- ml4t/diagnostic/splitters/calendar.py +421 -0
- ml4t/diagnostic/splitters/calendar_config.py +91 -0
- ml4t/diagnostic/splitters/combinatorial.py +1064 -0
- ml4t/diagnostic/splitters/config.py +322 -0
- ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
- ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
- ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
- ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
- ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
- ml4t/diagnostic/splitters/group_isolation.py +329 -0
- ml4t/diagnostic/splitters/persistence.py +316 -0
- ml4t/diagnostic/splitters/utils.py +207 -0
- ml4t/diagnostic/splitters/walk_forward.py +757 -0
- ml4t/diagnostic/utils/__init__.py +42 -0
- ml4t/diagnostic/utils/config.py +542 -0
- ml4t/diagnostic/utils/dependencies.py +318 -0
- ml4t/diagnostic/utils/sessions.py +127 -0
- ml4t/diagnostic/validation/__init__.py +54 -0
- ml4t/diagnostic/validation/dataframe.py +274 -0
- ml4t/diagnostic/validation/returns.py +280 -0
- ml4t/diagnostic/validation/timeseries.py +299 -0
- ml4t/diagnostic/visualization/AGENT.md +19 -0
- ml4t/diagnostic/visualization/__init__.py +223 -0
- ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
- ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
- ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
- ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
- ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
- ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
- ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
- ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
- ml4t/diagnostic/visualization/barrier_plots.py +782 -0
- ml4t/diagnostic/visualization/core.py +1060 -0
- ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
- ml4t/diagnostic/visualization/dashboards/base.py +582 -0
- ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
- ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
- ml4t/diagnostic/visualization/dashboards.py +43 -0
- ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
- ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
- ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
- ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
- ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
- ml4t/diagnostic/visualization/feature_plots.py +888 -0
- ml4t/diagnostic/visualization/interaction_plots.py +618 -0
- ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
- ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
- ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
- ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
- ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
- ml4t/diagnostic/visualization/report_generation.py +1343 -0
- ml4t/diagnostic/visualization/signal/__init__.py +103 -0
- ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
- ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
- ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
- ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
- ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
- ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
- ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
- ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
- ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
- ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
- ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
- ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,935 @@
|
|
|
1
|
+
"""Main Evaluator framework implementing the Three-Tier Validation Framework.
|
|
2
|
+
|
|
3
|
+
This module provides the Evaluator class that orchestrates the complete ml4t-diagnostic
|
|
4
|
+
validation workflow:
|
|
5
|
+
|
|
6
|
+
- Tier 1 (Rigorous Backtesting): Full CPCV validation with statistical tests
|
|
7
|
+
- Tier 2 (Statistical Significance): HAC-adjusted tests and significance testing
|
|
8
|
+
- Tier 3 (Production Monitoring): Fast screening metrics for live systems
|
|
9
|
+
|
|
10
|
+
The Evaluator integrates with all splitters, metrics, and statistical tests to
|
|
11
|
+
provide a unified interface for financial ML validation.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
import warnings
|
|
15
|
+
from collections.abc import Callable
|
|
16
|
+
from datetime import datetime
|
|
17
|
+
from typing import TYPE_CHECKING, Any, Union
|
|
18
|
+
|
|
19
|
+
import numpy as np
|
|
20
|
+
import pandas as pd
|
|
21
|
+
import polars as pl
|
|
22
|
+
from joblib import Parallel, delayed
|
|
23
|
+
from sklearn.base import BaseEstimator, clone
|
|
24
|
+
|
|
25
|
+
from ml4t.diagnostic.backends.adapter import DataFrameAdapter
|
|
26
|
+
from ml4t.diagnostic.splitters.base import BaseSplitter
|
|
27
|
+
from ml4t.diagnostic.splitters.combinatorial import CombinatorialPurgedCV
|
|
28
|
+
from ml4t.diagnostic.splitters.walk_forward import PurgedWalkForwardCV
|
|
29
|
+
|
|
30
|
+
from .dashboard import create_evaluation_dashboard
|
|
31
|
+
from .metric_registry import MetricRegistry
|
|
32
|
+
from .stat_registry import StatTestRegistry
|
|
33
|
+
from .visualization import plot_ic_heatmap, plot_quantile_returns
|
|
34
|
+
|
|
35
|
+
if TYPE_CHECKING:
|
|
36
|
+
from numpy.typing import NDArray
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def get_metric_directionality(metric_name: str) -> bool:
    """Report whether larger values of a metric are preferable.

    The name is lower-cased and dashes/spaces are turned into underscores
    before it is looked up in the default ``MetricRegistry``.

    Parameters
    ----------
    metric_name : str
        Name of the metric (e.g. ``"Max Drawdown"`` or ``"max-drawdown"``).

    Returns
    -------
    bool
        True if higher values are better, False if lower values are better.
    """
    key = metric_name.lower()
    for separator in ("-", " "):
        key = key.replace(separator, "_")
    registry = MetricRegistry.default()
    return registry.is_maximize(key)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class EvaluationResult:
|
|
57
|
+
"""Container for evaluation results with rich reporting capabilities."""
|
|
58
|
+
|
|
59
|
+
def __init__(
|
|
60
|
+
self,
|
|
61
|
+
tier: int,
|
|
62
|
+
splitter_name: str,
|
|
63
|
+
metrics_results: dict[str, Any],
|
|
64
|
+
statistical_tests: dict[str, Any] | None = None,
|
|
65
|
+
fold_results: list[dict[str, Any]] | None = None,
|
|
66
|
+
metadata: dict[str, Any] | None = None,
|
|
67
|
+
oos_returns: list[np.ndarray] | None = None,
|
|
68
|
+
):
|
|
69
|
+
"""Initialize evaluation result.
|
|
70
|
+
|
|
71
|
+
Parameters
|
|
72
|
+
----------
|
|
73
|
+
tier : int
|
|
74
|
+
Tier level (1, 2, or 3) of the evaluation
|
|
75
|
+
splitter_name : str
|
|
76
|
+
Name of the cross-validation method used
|
|
77
|
+
metrics_results : Dict[str, Any]
|
|
78
|
+
Aggregated metrics results
|
|
79
|
+
statistical_tests : Optional[Dict[str, Any]]
|
|
80
|
+
Statistical test results (Tier 1 & 2)
|
|
81
|
+
fold_results : Optional[List[Dict[str, Any]]]
|
|
82
|
+
Individual fold results for detailed analysis
|
|
83
|
+
metadata : Optional[Dict[str, Any]]
|
|
84
|
+
Additional metadata about the evaluation
|
|
85
|
+
oos_returns : Optional[List[np.ndarray]]
|
|
86
|
+
Out-of-sample strategy returns from each fold for statistical testing
|
|
87
|
+
"""
|
|
88
|
+
self.tier = tier
|
|
89
|
+
self.splitter_name = splitter_name
|
|
90
|
+
self.metrics_results = metrics_results
|
|
91
|
+
self.statistical_tests = statistical_tests or {}
|
|
92
|
+
self.fold_results = fold_results or []
|
|
93
|
+
self.metadata = metadata or {}
|
|
94
|
+
self.oos_returns = oos_returns or []
|
|
95
|
+
self.timestamp = datetime.now()
|
|
96
|
+
|
|
97
|
+
def summary(self) -> dict[str, Any]:
|
|
98
|
+
"""Generate a summary of the evaluation results."""
|
|
99
|
+
summary: dict[str, Any] = {
|
|
100
|
+
"tier": self.tier,
|
|
101
|
+
"splitter": self.splitter_name,
|
|
102
|
+
"timestamp": self.timestamp.isoformat(),
|
|
103
|
+
"n_folds": len(self.fold_results),
|
|
104
|
+
"metrics": {},
|
|
105
|
+
"statistical_tests": {},
|
|
106
|
+
}
|
|
107
|
+
|
|
108
|
+
# Summarize metrics
|
|
109
|
+
for metric_name, value in self.metrics_results.items():
|
|
110
|
+
if isinstance(value, dict) and "mean" in value:
|
|
111
|
+
summary["metrics"][metric_name] = {
|
|
112
|
+
"mean": value["mean"],
|
|
113
|
+
"std": value.get("std", None),
|
|
114
|
+
"significant": value.get("significant", None),
|
|
115
|
+
}
|
|
116
|
+
else:
|
|
117
|
+
summary["metrics"][metric_name] = value
|
|
118
|
+
|
|
119
|
+
# Summarize statistical tests
|
|
120
|
+
for test_name, result in self.statistical_tests.items():
|
|
121
|
+
if isinstance(result, dict):
|
|
122
|
+
summary["statistical_tests"][test_name] = {
|
|
123
|
+
"test_statistic": result.get(
|
|
124
|
+
"test_statistic",
|
|
125
|
+
result.get("dsr", None),
|
|
126
|
+
),
|
|
127
|
+
"p_value": result.get("p_value", None),
|
|
128
|
+
"significant": result.get("p_value", 1.0) < 0.05
|
|
129
|
+
if "p_value" in result
|
|
130
|
+
else None,
|
|
131
|
+
}
|
|
132
|
+
|
|
133
|
+
return summary
|
|
134
|
+
|
|
135
|
+
def get_oos_returns_series(self) -> np.ndarray | None:
|
|
136
|
+
"""Get concatenated out-of-sample returns series for statistical testing.
|
|
137
|
+
|
|
138
|
+
Returns
|
|
139
|
+
-------
|
|
140
|
+
np.ndarray or None
|
|
141
|
+
Concatenated strategy returns from all folds, or None if not available
|
|
142
|
+
"""
|
|
143
|
+
if not self.oos_returns or len(self.oos_returns) == 0:
|
|
144
|
+
return None
|
|
145
|
+
|
|
146
|
+
# Filter out any NaN arrays from failed folds
|
|
147
|
+
valid_returns = [returns for returns in self.oos_returns if not np.all(np.isnan(returns))]
|
|
148
|
+
|
|
149
|
+
if not valid_returns:
|
|
150
|
+
return None
|
|
151
|
+
|
|
152
|
+
return np.concatenate(valid_returns)
|
|
153
|
+
|
|
154
|
+
def plot(
|
|
155
|
+
self,
|
|
156
|
+
predictions: Any | None = None,
|
|
157
|
+
returns: Any | None = None,
|
|
158
|
+
) -> Any:
|
|
159
|
+
"""Generate default visualization for evaluation results.
|
|
160
|
+
|
|
161
|
+
Parameters
|
|
162
|
+
----------
|
|
163
|
+
predictions : array-like, optional
|
|
164
|
+
Predictions for visualization
|
|
165
|
+
returns : array-like, optional
|
|
166
|
+
Returns for visualization
|
|
167
|
+
|
|
168
|
+
Returns:
|
|
169
|
+
-------
|
|
170
|
+
plotly.graph_objects.Figure
|
|
171
|
+
Interactive visualization
|
|
172
|
+
"""
|
|
173
|
+
# Default plot based on available metrics
|
|
174
|
+
if "ic" in self.metrics_results and predictions is not None and returns is not None:
|
|
175
|
+
return plot_ic_heatmap(predictions, returns)
|
|
176
|
+
if "sharpe" in self.metrics_results and returns is not None and predictions is not None:
|
|
177
|
+
return plot_quantile_returns(predictions, returns)
|
|
178
|
+
# Return a summary plot
|
|
179
|
+
import plotly.graph_objects as go
|
|
180
|
+
|
|
181
|
+
metric_names = list(self.metrics_results.keys())
|
|
182
|
+
metric_values = [
|
|
183
|
+
self.metrics_results[m].get("mean", 0)
|
|
184
|
+
if isinstance(self.metrics_results[m], dict)
|
|
185
|
+
else self.metrics_results[m]
|
|
186
|
+
for m in metric_names
|
|
187
|
+
]
|
|
188
|
+
|
|
189
|
+
fig = go.Figure(data=[go.Bar(x=metric_names, y=metric_values)])
|
|
190
|
+
fig.update_layout(
|
|
191
|
+
title=f"Evaluation Results - Tier {self.tier}",
|
|
192
|
+
xaxis_title="Metric",
|
|
193
|
+
yaxis_title="Value",
|
|
194
|
+
)
|
|
195
|
+
return fig
|
|
196
|
+
|
|
197
|
+
def to_html(
|
|
198
|
+
self,
|
|
199
|
+
filename: str,
|
|
200
|
+
predictions: Any | None = None,
|
|
201
|
+
returns: Any | None = None,
|
|
202
|
+
features: Any | None = None,
|
|
203
|
+
title: str | None = None,
|
|
204
|
+
) -> None:
|
|
205
|
+
"""Generate interactive HTML dashboard.
|
|
206
|
+
|
|
207
|
+
Parameters
|
|
208
|
+
----------
|
|
209
|
+
filename : str
|
|
210
|
+
Output HTML filename
|
|
211
|
+
predictions : array-like, optional
|
|
212
|
+
Model predictions for visualizations
|
|
213
|
+
returns : array-like, optional
|
|
214
|
+
Returns data for visualizations
|
|
215
|
+
features : array-like, optional
|
|
216
|
+
Feature data for distribution analysis
|
|
217
|
+
title : str, optional
|
|
218
|
+
Dashboard title
|
|
219
|
+
|
|
220
|
+
Examples:
|
|
221
|
+
--------
|
|
222
|
+
>>> result.to_html("evaluation_report.html", predictions=pred_df, returns=ret_df)
|
|
223
|
+
"""
|
|
224
|
+
create_evaluation_dashboard(
|
|
225
|
+
self,
|
|
226
|
+
filename,
|
|
227
|
+
predictions=predictions,
|
|
228
|
+
returns=returns,
|
|
229
|
+
features=features,
|
|
230
|
+
title=title,
|
|
231
|
+
)
|
|
232
|
+
|
|
233
|
+
def __repr__(self) -> str:
|
|
234
|
+
"""String representation of evaluation result."""
|
|
235
|
+
summary = self.summary()
|
|
236
|
+
metrics_str = ", ".join(
|
|
237
|
+
[
|
|
238
|
+
f"{k}: {v['mean']:.3f}"
|
|
239
|
+
if isinstance(v, dict) and "mean" in v
|
|
240
|
+
else f"{k}: {v:.3f}"
|
|
241
|
+
if isinstance(v, int | float)
|
|
242
|
+
else f"{k}: {v}"
|
|
243
|
+
for k, v in summary["metrics"].items()
|
|
244
|
+
],
|
|
245
|
+
)
|
|
246
|
+
|
|
247
|
+
return (
|
|
248
|
+
f"EvaluationResult(tier={self.tier}, splitter={self.splitter_name}, "
|
|
249
|
+
f"n_folds={summary['n_folds']}, metrics=[{metrics_str}])"
|
|
250
|
+
)
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
class Evaluator:
|
|
254
|
+
"""Main evaluator implementing the Three-Tier Validation Framework.
|
|
255
|
+
|
|
256
|
+
The Evaluator orchestrates the complete ml4t-diagnostic validation workflow by
|
|
257
|
+
integrating cross-validation splitters, performance metrics, and
|
|
258
|
+
statistical tests into a unified framework.
|
|
259
|
+
|
|
260
|
+
Three-Tier Framework:
|
|
261
|
+
- Tier 3: Fast screening with basic metrics
|
|
262
|
+
- Tier 2: Statistical significance testing with HAC adjustments
|
|
263
|
+
- Tier 1: Rigorous backtesting with multiple testing corrections
|
|
264
|
+
"""
|
|
265
|
+
|
|
266
|
+
# Backward-compatible class attributes (delegate to registries)
|
|
267
|
+
@property
def METRIC_REGISTRY(self) -> dict[str, Callable]:  # noqa: N802
    """Snapshot of the default metric registry as a ``{name: callable}`` dict.

    Kept for backward compatibility. This is an instance ``property``, so it
    is read from an Evaluator instance; each access builds a new dict.
    """
    registry = MetricRegistry.default()
    snapshot: dict[str, Callable] = {}
    for metric_name in registry.list_metrics():
        snapshot[metric_name] = registry.get(metric_name)
    return snapshot
|
|
272
|
+
|
|
273
|
+
@property
def STAT_TEST_REGISTRY(self) -> dict[str, Callable]:  # noqa: N802
    """Snapshot of the default statistical-test registry as ``{name: callable}``.

    Kept for backward compatibility. This is an instance ``property``, so it
    is read from an Evaluator instance; each access builds a new dict.
    """
    registry = StatTestRegistry.default()
    snapshot: dict[str, Callable] = {}
    for test_name in registry.list_tests():
        snapshot[test_name] = registry.get(test_name)
    return snapshot
|
|
278
|
+
|
|
279
|
+
def __init__(
    self,
    splitter: BaseSplitter | None = None,
    metrics: list[str] | None = None,
    statistical_tests: list[str] | None = None,
    tier: int | None = None,
    confidence_level: float = 0.05,
    bootstrap_samples: int = 1000,
    random_state: int | None = None,
    n_jobs: int = 1,
):
    """Initialize the Evaluator.

    Parameters
    ----------
    splitter : BaseSplitter | None, default None
        Cross-validation splitter. If None, a default for the tier is used.
    metrics : list[str] | None, default None
        List of metrics to compute. If None (or empty), uses tier defaults.
    statistical_tests : list[str] | None, default None
        List of statistical tests to perform. If None, uses tier defaults;
        pass an empty list to run no statistical tests.
    tier : int | None, default None
        Tier level (1, 2, or 3). If None, infers from other parameters.
    confidence_level : float, default 0.05
        Significance level for statistical tests.
    bootstrap_samples : int, default 1000
        Number of bootstrap samples for confidence intervals.
    random_state : int | None, default None
        Random seed for reproducible results.
    n_jobs : int, default 1
        Number of parallel jobs for cross-validation.
        -1 means using all processors.

    Raises
    ------
    ValueError
        Raised by the configuration validation step on an invalid configuration.

    Examples
    --------
    # Tier 3: Fast screening
    >>> evaluator = Evaluator(tier=3)
    >>> result = evaluator.evaluate(X, y, model)

    # Tier 1: Full rigorous evaluation
    >>> evaluator = Evaluator(
    ...     splitter=CombinatorialPurgedCV(n_groups=8),
    ...     metrics=["sharpe", "sortino", "max_drawdown"],
    ...     statistical_tests=["dsr", "whites_reality_check"],
    ...     tier=1
    ... )
    >>> result = evaluator.evaluate(X, y, model)
    """
    self.confidence_level = confidence_level
    self.bootstrap_samples = bootstrap_samples
    self.random_state = random_state
    self.n_jobs = n_jobs

    # Infer tier if not specified
    if tier is None:
        tier = self._infer_tier(splitter, metrics, statistical_tests)
    self.tier = tier

    # ``is None`` checks (not truthiness) so an explicit empty test list
    # disables tests instead of silently re-enabling the tier defaults.
    self.splitter = splitter if splitter is not None else self._get_default_splitter(tier)
    # An evaluator needs at least one metric, so an empty list still falls
    # back to the tier defaults (unchanged behaviour). Copy the caller's list
    # so later mutation of it does not alter this evaluator.
    self.metrics = list(metrics) if metrics else self._get_default_metrics(tier)
    self.statistical_tests = (
        list(statistical_tests)
        if statistical_tests is not None
        else self._get_default_statistical_tests(tier)
    )

    # Validate configuration
    self._validate_configuration()
|
|
343
|
+
|
|
344
|
+
@classmethod
def register_metric(
    cls,
    name: str,
    func: Callable[..., float],
    maximize: bool = True,
) -> None:
    """Add a user-defined metric to the shared default ``MetricRegistry``.

    Registration goes to ``MetricRegistry.default()``, so it is visible to
    every Evaluator, not only this class.

    Parameters
    ----------
    name : str
        Name of the metric
    func : Callable
        Function that takes (predictions, actual, strategy_returns) and returns float
    maximize : bool, default True
        Whether higher values are better

    Examples
    --------
    >>> def my_metric(predictions, actual, returns):
    ...     return np.mean(predictions > 0)
    >>> Evaluator.register_metric("my_metric", my_metric)
    """
    registry = MetricRegistry.default()
    registry.register(name, func, maximize=maximize)
|
|
369
|
+
|
|
370
|
+
@classmethod
def register_statistical_test(
    cls,
    name: str,
    func: Callable[..., dict[str, Any]],
) -> None:
    """Add a user-defined statistical test to the shared default ``StatTestRegistry``.

    Parameters
    ----------
    name : str
        Name of the test
    func : Callable
        Function that returns a dict with test results
    """
    registry = StatTestRegistry.default()
    registry.register(name, func)
|
|
386
|
+
|
|
387
|
+
def _infer_tier(
    self,
    splitter: BaseSplitter | None,
    _metrics: list[str] | None,
    statistical_tests: list[str] | None,
) -> int:
    """Guess the tier level from the supplied configuration.

    Returns 1 for a ``CombinatorialPurgedCV`` splitter or a requested
    "dsr"/"whites_reality_check" test; 2 for a requested "hac_ic"/"fdr"
    test; otherwise 3 (fast screening). ``_metrics`` is not consulted.
    """
    requested = statistical_tests or []
    rigorous_tests = ("dsr", "whites_reality_check")
    significance_tests = ("hac_ic", "fdr")

    if isinstance(splitter, CombinatorialPurgedCV):
        return 1
    if any(test in rigorous_tests for test in requested):
        return 1
    if any(test in significance_tests for test in requested):
        return 2
    return 3
|
|
407
|
+
|
|
408
|
+
def _get_default_splitter(self, tier: int) -> BaseSplitter:
    """Build a fresh default splitter for the given tier.

    Tier 1 gets ``CombinatorialPurgedCV(n_groups=8, n_test_groups=2)``;
    tier 2 gets a 5-split ``PurgedWalkForwardCV``; any other tier gets a
    3-split ``PurgedWalkForwardCV``.
    """
    if tier == 1:
        return CombinatorialPurgedCV(n_groups=8, n_test_groups=2)
    n_splits = 5 if tier == 2 else 3
    return PurgedWalkForwardCV(n_splits=n_splits)
|
|
416
|
+
|
|
417
|
+
def _get_default_metrics(self, tier: int) -> list[str]:
|
|
418
|
+
"""Get default metrics for tier."""
|
|
419
|
+
if tier == 1:
|
|
420
|
+
return ["ic", "sharpe", "sortino", "max_drawdown", "hit_rate"]
|
|
421
|
+
if tier == 2:
|
|
422
|
+
return ["ic", "sharpe", "hit_rate"]
|
|
423
|
+
# tier == 3
|
|
424
|
+
return ["ic", "hit_rate"]
|
|
425
|
+
|
|
426
|
+
def _get_default_statistical_tests(self, tier: int) -> list[str]:
|
|
427
|
+
"""Get default statistical tests for tier."""
|
|
428
|
+
if tier == 1:
|
|
429
|
+
return ["dsr", "fdr"]
|
|
430
|
+
if tier == 2:
|
|
431
|
+
return ["hac_ic"]
|
|
432
|
+
# tier == 3
|
|
433
|
+
return []
|
|
434
|
+
|
|
435
|
+
def _validate_configuration(self) -> None:
    """Validate evaluator configuration.

    Checks the scalar evaluator parameters against the ``EvaluatorConfig``
    Pydantic schema. Checks every configured metric and statistical test
    name against the default registries. Emits warnings for tier settings
    that are allowed but inconsistent with the tier's intent.

    Raises
    ------
    ValueError
        If the schema rejects a parameter (the original Pydantic
        ``ValidationError`` is chained as ``__cause__``), or if an unknown
        metric or statistical test name is configured.
    """
    from pydantic import ValidationError

    from ml4t.diagnostic.utils.config import EvaluatorConfig

    try:
        # Validate main evaluator parameters
        EvaluatorConfig(
            tier=self.tier,
            confidence_level=self.confidence_level,
            bootstrap_samples=self.bootstrap_samples,
            random_state=self.random_state,
            n_jobs=self.n_jobs,
        )

    except ValidationError as e:
        # Flatten Pydantic's structured errors into "field: message" pairs.
        # Nested locations are joined with dots so the full path survives
        # (a flat location such as ("tier",) still renders as "tier").
        error_messages = []
        for error in e.errors():
            loc = error["loc"]
            field = ".".join(str(part) for part in loc) if loc else "unknown"
            error_messages.append(f"{field}: {error['msg']}")

        # Chain the original error so its full traceback stays available.
        raise ValueError(
            f"Configuration validation failed: {'; '.join(error_messages)}",
        ) from e

    # Validate metrics against registry
    metric_registry = MetricRegistry.default()
    invalid_metrics = [m for m in self.metrics if m not in metric_registry]
    if invalid_metrics:
        raise ValueError(
            f"Unknown metrics: {invalid_metrics}. Available: {metric_registry.list_metrics()}",
        )

    # Validate statistical tests against registry
    stat_registry = StatTestRegistry.default()
    invalid_tests = [t for t in self.statistical_tests if t not in stat_registry]
    if invalid_tests:
        raise ValueError(
            f"Unknown statistical tests: {invalid_tests}. Available: {stat_registry.list_tests()}",
        )

    # Tier-specific consistency checks: advisory only, never fatal.
    if self.tier == 1 and not isinstance(self.splitter, CombinatorialPurgedCV):
        warnings.warn(
            "Tier 1 evaluation should use CombinatorialPurgedCV for maximum rigor",
            stacklevel=2,
        )

    if self.tier == 3 and len(self.statistical_tests) > 2:
        warnings.warn(
            "Tier 3 is designed for fast screening - consider limiting statistical tests",
            stacklevel=2,
        )
|
|
491
|
+
|
|
492
|
+
def evaluate(
    self,
    x: Union[pl.DataFrame, pd.DataFrame, "NDArray[Any]"],
    y: Union[pl.Series, pd.Series, "NDArray[Any]"],
    model: BaseEstimator | Callable[..., Any],
    strategy_func: Callable[..., Any] | None = None,
    **kwargs: Any,
) -> EvaluationResult:
    """Evaluate a model using the configured validation framework.

    The model is fit and scored on every train/test split produced by
    ``self.splitter``. Per-fold metrics are then aggregated, and the
    configured statistical tests run on the pooled out-of-sample results.

    Parameters
    ----------
    x : Union[pl.DataFrame, pd.DataFrame, NDArray]
        Feature matrix. It is converted to NumPy for fitting. The original
        object (not the array) is what gets passed to ``self.splitter.split``.
    y : Union[pl.Series, pd.Series, NDArray]
        Target values (returns), flattened to 1-D.
    model : Union[BaseEstimator, Callable]
        Model to evaluate. Objects exposing ``fit`` and ``predict`` are
        cloned per fold. Anything else is called as
        ``model(x_train, y_train, x_test)``, must return test-set
        predictions, and must be stateless.
    strategy_func : Optional[Callable], default None
        Called as ``strategy_func(predictions, y_test)`` to turn predictions
        into strategy returns. If None, positions are ``sign(predictions)``
        and strategy returns are ``positions * y_test``.
    **kwargs : Any
        Additional parameters passed to ``self.splitter.split``.

    Returns
    -------
    EvaluationResult
        Aggregated metrics, statistical test results, per-fold metric
        dicts, run metadata and the per-fold out-of-sample strategy returns.

    Raises
    ------
    ValueError
        If ``x`` and ``y`` have different numbers of samples.

    Notes
    -----
    A fold that raises is reported with a warning. It contributes NaN
    metrics and *empty* prediction/actual/return arrays, so it cannot inject
    NaNs into the pooled data used by the statistical tests.

    Examples
    --------
    >>> from sklearn.ensemble import RandomForestRegressor
    >>> model = RandomForestRegressor(n_estimators=50)
    >>> evaluator = Evaluator(tier=2)
    >>> result = evaluator.evaluate(X, y, model)
    >>> print(result.summary())
    """
    # Convert inputs to consistent format
    x_array = DataFrameAdapter.to_numpy(x)
    y_array = DataFrameAdapter.to_numpy(y).flatten()

    if len(x_array) != len(y_array):
        raise ValueError("x and y must have the same number of samples")

    # NOTE: this seeds NumPy's global legacy RNG, which also affects any
    # other code in the process that draws from np.random.
    if self.random_state is not None:
        np.random.seed(self.random_state)

    def process_fold(
        fold_idx,
        train_idx,
        test_idx,
        model,
        x_array,
        y_array,
        strategy_func,
    ):
        """Fit, predict and score one fold.

        Returns ``(fold_metrics, predictions, y_test, strategy_returns)``.
        If anything in the fold raises, the metrics are all NaN and the three
        arrays are empty.
        """
        try:
            x_train, x_test = x_array[train_idx], x_array[test_idx]
            y_train, y_test = y_array[train_idx], y_array[test_idx]

            if hasattr(model, "fit") and hasattr(model, "predict"):
                # Clone to prevent shared state between parallel processes
                model_clone = clone(model)

                if hasattr(model_clone, "random_state") and self.random_state is not None:
                    # Deterministic but different seed per fold
                    model_clone.random_state = self.random_state + fold_idx

                model_clone.fit(x_train, y_train)
                predictions = model_clone.predict(x_test)
            else:
                # Callable model (must be stateless)
                predictions = model(x_train, y_train, x_test)

            if strategy_func is not None:
                strategy_returns = strategy_func(predictions, y_test)
            else:
                positions = np.sign(predictions)
                strategy_returns = positions * y_test

            fold_metrics = {}
            metric_registry = MetricRegistry.default()
            for metric_name in self.metrics:
                try:
                    if metric_name in metric_registry:
                        metric_func = metric_registry.get(metric_name)
                        value = metric_func(predictions, y_test, strategy_returns)

                        # The drawdown metric may return a dict; keep the scalar.
                        if metric_name == "max_drawdown" and isinstance(value, dict):
                            value = value["max_drawdown"]

                        fold_metrics[metric_name] = value
                except Exception as e:
                    fold_metrics[metric_name] = np.nan
                    warnings.warn(
                        f"Fold {fold_idx}: Failed to calculate {metric_name}: {e}",
                        stacklevel=2,
                    )

            fold_metrics["fold"] = fold_idx
            fold_metrics["n_train"] = len(train_idx)
            fold_metrics["n_test"] = len(test_idx)

            return fold_metrics, predictions, y_test, strategy_returns

        except Exception as e:
            warnings.warn(
                f"Fold {fold_idx} failed with error: {e}. Returning NaN results.",
                stacklevel=2,
            )
            nan_metrics = dict.fromkeys(self.metrics, np.nan)
            nan_metrics.update(
                {
                    "fold": fold_idx,
                    "n_train": len(train_idx),
                    "n_test": len(test_idx),
                },
            )

            # Empty arrays rather than [nan]: a failed fold must contribute
            # nothing to the pooled predictions/actuals (otherwise it poisons
            # HAC IC with NaN). It must also not appear as a length-1 return
            # series (which would skew the DSR sample size and feed NaN into
            # White's Reality Check).
            return (
                nan_metrics,
                np.array([], dtype=float),
                np.array([], dtype=float),
                np.array([], dtype=float),
            )

    splits = list(self.splitter.split(x, y, **kwargs))

    if self.n_jobs == 1:
        results = []
        for fold_idx, (train_idx, test_idx) in enumerate(splits):
            result = process_fold(
                fold_idx, train_idx, test_idx, model, x_array, y_array, strategy_func
            )
            results.append(result)
    else:
        # Use loky backend for process isolation (prevents race conditions)
        results = Parallel(n_jobs=self.n_jobs, backend="loky")(
            delayed(process_fold)(
                fold_idx, train_idx, test_idx, model, x_array, y_array, strategy_func
            )
            for fold_idx, (train_idx, test_idx) in enumerate(splits)
        )

    fold_results = [r[0] for r in results]
    all_predictions = [pred for r in results for pred in r[1]]
    all_actual = [actual for r in results for actual in r[2]]
    oos_returns = [r[3] for r in results]

    metrics_results = self._aggregate_metrics(fold_results)
    statistical_tests = self._perform_statistical_tests(
        fold_results,
        all_predictions,
        all_actual,
        metrics_results,
        oos_returns,
    )

    metadata = {
        "n_samples": len(x_array),
        "n_features": x_array.shape[1] if x_array.ndim > 1 else 1,
        # Shallow copy so mutating the result's metadata cannot alter the
        # live splitter's attributes (and vice versa).
        "splitter_params": dict(self.splitter.__dict__),
        "tier": self.tier,
        "random_state": self.random_state,
    }

    return EvaluationResult(
        tier=self.tier,
        splitter_name=self.splitter.__class__.__name__,
        metrics_results=metrics_results,
        statistical_tests=statistical_tests,
        fold_results=fold_results,
        metadata=metadata,
        oos_returns=oos_returns,
    )
|
|
664
|
+
|
|
665
|
+
def _aggregate_metrics(self, fold_results: list[dict[str, Any]]) -> dict[str, Any]:
|
|
666
|
+
"""Aggregate metrics across folds."""
|
|
667
|
+
aggregated = {}
|
|
668
|
+
|
|
669
|
+
for metric_name in self.metrics:
|
|
670
|
+
values = [fold.get(metric_name, np.nan) for fold in fold_results]
|
|
671
|
+
valid_values = [v for v in values if not np.isnan(v)]
|
|
672
|
+
|
|
673
|
+
if valid_values:
|
|
674
|
+
aggregated[metric_name] = {
|
|
675
|
+
"mean": np.mean(valid_values),
|
|
676
|
+
"std": np.std(valid_values, ddof=1) if len(valid_values) > 1 else 0.0,
|
|
677
|
+
"min": np.min(valid_values),
|
|
678
|
+
"max": np.max(valid_values),
|
|
679
|
+
"values": valid_values,
|
|
680
|
+
"n_valid": len(valid_values),
|
|
681
|
+
}
|
|
682
|
+
|
|
683
|
+
# Add confidence interval for mean if multiple folds
|
|
684
|
+
if len(valid_values) > 1:
|
|
685
|
+
se = aggregated[metric_name]["std"] / np.sqrt(len(valid_values))
|
|
686
|
+
from scipy.stats import t
|
|
687
|
+
|
|
688
|
+
t_val = t.ppf(1 - self.confidence_level / 2, len(valid_values) - 1)
|
|
689
|
+
margin = t_val * se
|
|
690
|
+
|
|
691
|
+
aggregated[metric_name]["ci_lower"] = aggregated[metric_name]["mean"] - margin
|
|
692
|
+
aggregated[metric_name]["ci_upper"] = aggregated[metric_name]["mean"] + margin
|
|
693
|
+
else:
|
|
694
|
+
aggregated[metric_name] = {
|
|
695
|
+
"mean": np.nan,
|
|
696
|
+
"std": np.nan,
|
|
697
|
+
"min": np.nan,
|
|
698
|
+
"max": np.nan,
|
|
699
|
+
"values": [],
|
|
700
|
+
"n_valid": 0,
|
|
701
|
+
}
|
|
702
|
+
|
|
703
|
+
return aggregated
|
|
704
|
+
|
|
705
|
+
def _perform_statistical_tests(
    self,
    fold_results: list[dict[str, Any]],
    all_predictions: list[float],
    all_actual: list[float],
    metrics_results: dict[str, Any],
    oos_returns: list[np.ndarray],
) -> dict[str, Any]:
    """Run each configured statistical test and collect the results.

    Tests run in the order of ``self.statistical_tests``. Known names are
    dispatched as follows:

    - ``"dsr"``: deflated Sharpe ratio on the best per-fold Sharpe. The
      number of folds is used as the number of trials. Requires the
      ``"sharpe"`` metric.
    - ``"hac_ic"``: called on the pooled predictions and actuals. Requires
      the ``"ic"`` metric.
    - ``"fdr"``: applied to the ``p_value`` entries of tests that already
      ran, so it should be listed after them.
    - ``"whites_reality_check"``: concatenated out-of-sample strategy
      returns against a zero benchmark. Requires more than one fold, all
      with non-empty returns.

    Any other registered test is called with the keyword arguments
    ``fold_results``, ``predictions``, ``actual`` and ``metrics_results``.

    Parameters
    ----------
    fold_results : list[dict[str, Any]]
        Per-fold metric dicts.
    all_predictions : list[float]
        Predictions pooled across folds.
    all_actual : list[float]
        Realised targets pooled across folds, aligned with ``all_predictions``.
    metrics_results : dict[str, Any]
        Aggregated metrics as produced by ``_aggregate_metrics``.
    oos_returns : list[np.ndarray]
        Out-of-sample strategy returns, one array per fold.

    Returns
    -------
    dict[str, Any]
        Test name -> result. Tests skipped for missing inputs are absent.
        Tests that raise are stored as ``{"error": message}``.
    """
    statistical_results: dict[str, Any] = {}
    stat_registry = StatTestRegistry.default()

    for test_name in self.statistical_tests:
        try:
            if test_name not in stat_registry:
                warnings.warn(
                    f"Unknown statistical test: {test_name}",
                    stacklevel=2,
                )
                continue

            test_func = stat_registry.get(test_name)
            result: Any

            # Dispatch on the test name alone and check prerequisites inside
            # each branch. Previously a named test with missing inputs fell
            # through to the generic call below, with keyword arguments it
            # does not accept.
            if test_name == "dsr":
                if "sharpe" not in metrics_results:
                    warnings.warn(
                        "Skipping statistical test dsr: it requires the 'sharpe' metric",
                        stacklevel=2,
                    )
                    continue
                sharpe_values = metrics_results["sharpe"]["values"]
                if not sharpe_values or len(oos_returns) == 0:
                    continue

                best_sharpe = float(np.max(sharpe_values))
                n_trials = len(fold_results)
                # With a single Sharpe value the sample variance is undefined,
                # so a small positive placeholder is used instead.
                variance_trials = (
                    float(np.var(sharpe_values, ddof=1)) if len(sharpe_values) > 1 else 0.001
                )
                # Average out-of-sample length over folds that produced returns
                n_samples = int(
                    np.mean([len(returns) for returns in oos_returns if len(returns) > 0])
                )
                dsr_result = test_func(
                    observed_sharpe=best_sharpe,
                    n_samples=n_samples,
                    n_trials=n_trials,
                    variance_trials=variance_trials,
                )
                # Convert DSRResult dataclass to dict for consistency
                result = {
                    "dsr": dsr_result.probability,
                    "p_value": dsr_result.p_value,
                    "expected_max_sharpe": dsr_result.expected_max_sharpe,
                    "z_score": dsr_result.z_score,
                    "is_significant": dsr_result.is_significant,
                }

            elif test_name == "hac_ic":
                if "ic" not in metrics_results:
                    warnings.warn(
                        "Skipping statistical test hac_ic: it requires the 'ic' metric",
                        stacklevel=2,
                    )
                    continue
                result = test_func(
                    predictions=np.array(all_predictions),
                    returns=np.array(all_actual),
                    return_details=True,
                )

            elif test_name == "fdr":
                # Multiple-testing correction over p-values from earlier tests
                p_values = [
                    prior["p_value"]
                    for prior in statistical_results.values()
                    if isinstance(prior, dict) and "p_value" in prior
                ]
                if not p_values:
                    continue
                result = test_func(
                    p_values,
                    alpha=self.confidence_level,
                    return_details=True,
                )

            elif test_name == "whites_reality_check":
                if len(oos_returns) <= 1 or not all(len(returns) > 0 for returns in oos_returns):
                    continue
                # One continuous OOS series, compared with a zero-return benchmark
                strategy_returns_series = np.concatenate(oos_returns)
                benchmark_returns = np.zeros(len(strategy_returns_series))
                # The test function expects a 2-D (n_obs, n_strategies) array
                strategy_returns_matrix = strategy_returns_series.reshape(-1, 1)
                result = test_func(
                    returns_benchmark=benchmark_returns,
                    returns_strategies=strategy_returns_matrix,
                    # Bootstrap count is capped at 500 here to bound runtime
                    bootstrap_samples=min(self.bootstrap_samples, 500),
                    random_state=self.random_state,
                )

            else:
                # Generic test function call
                result = test_func(
                    fold_results=fold_results,
                    predictions=all_predictions,
                    actual=all_actual,
                    metrics_results=metrics_results,
                )

            statistical_results[test_name] = result

        except Exception as e:
            warnings.warn(
                f"Error in statistical test {test_name}: {e}",
                stacklevel=2,
            )
            # Store error in a way that's compatible with the expected type
            error_result: dict[str, Any] = {"error": str(e)}
            statistical_results[test_name] = error_result

    return statistical_results
|
|
830
|
+
|
|
831
|
+
def batch_evaluate(
    self,
    models: list[BaseEstimator | Callable[..., Any]],
    x: Union[pl.DataFrame, pd.DataFrame, "NDArray[Any]"],
    y: Union[pl.Series, pd.Series, "NDArray[Any]"],
    model_names: list[str] | None = None,
    **kwargs: Any,
) -> dict[str, EvaluationResult]:
    """Evaluate multiple models with the same validation framework.

    Parameters
    ----------
    models : list[Union[BaseEstimator, Callable]]
        Models to evaluate; each is passed to ``self.evaluate``.
    x : Union[pl.DataFrame, pd.DataFrame, NDArray]
        Feature matrix.
    y : Union[pl.Series, pd.Series, NDArray]
        Target values.
    model_names : list[str] | None, default None
        Result keys, one per model, which must be unique. If None, each
        model is named by its ``__name__`` string attribute when it has one
        (e.g. plain functions), otherwise by its class name. A repeated
        derived name gets an ``_<index>`` suffix, so no result overwrites
        another.
    **kwargs : Any
        Additional keyword arguments passed to ``self.evaluate``.

    Returns
    -------
    dict[str, EvaluationResult]
        Model name -> evaluation result, in the order of ``models``.

    Raises
    ------
    ValueError
        If ``model_names`` does not have one entry per model, or contains
        duplicates.
    """
    if model_names is None:
        model_names = []
        taken: set[str] = set()
        for idx, model in enumerate(models):
            base = getattr(model, "__name__", None)
            if not isinstance(base, str):
                base = model.__class__.__name__
            # Two instances of one estimator class (or two plain functions)
            # would otherwise share a key, and the later result would
            # silently replace the earlier one.
            name, suffix = base, idx
            while name in taken:
                name = f"{base}_{suffix}"
                suffix += 1
            taken.add(name)
            model_names.append(name)

    if len(models) != len(model_names):
        raise ValueError("Number of models must match number of model names")
    if len(set(model_names)) != len(model_names):
        raise ValueError("Model names must be unique")

    results = {}
    for model, name in zip(models, model_names, strict=True):
        print(f"Evaluating {name}...")
        results[name] = self.evaluate(x, y, model, **kwargs)

    return results
|
|
874
|
+
|
|
875
|
+
def compare_models(
    self,
    batch_results: dict[str, EvaluationResult],
    primary_metric: str = "sharpe",
) -> dict[str, Any]:
    """Rank evaluated models by the across-fold mean of one metric.

    Parameters
    ----------
    batch_results : dict[str, EvaluationResult]
        Results from batch_evaluate().
    primary_metric : str, default "sharpe"
        Metric whose ``mean`` entry in ``metrics_results`` is used for ranking.

    Returns
    -------
    dict[str, Any]
        A summary with these keys:

        - ``primary_metric``: the metric used for ranking.
        - ``n_models``: the number of inputs, including models without a
          valid value.
        - ``ranking``: ``{"model": name, primary_metric: mean}`` entries,
          best first.
        - ``best_model``: the name of the top-ranked model.
        - ``model_details``: each result's ``summary()``.

        Drawdown-type metrics rank by smallest absolute value. Other metrics
        rank by ``get_metric_directionality``. If ``batch_results`` is empty,
        or no model has a non-NaN mean, a dict with a single ``"error"`` key
        is returned instead.
    """
    if not batch_results:
        return {"error": "No results to compare"}

    # Mean of the primary metric for every model (NaN when not computed)
    means = {
        model_name: evaluation.metrics_results.get(primary_metric, {}).get("mean", np.nan)
        for model_name, evaluation in batch_results.items()
    }
    scored = {model_name: mean for model_name, mean in means.items() if not np.isnan(mean)}
    if not scored:
        return {"error": f"No valid {primary_metric} values found"}

    higher_is_better = get_metric_directionality(primary_metric)

    if "drawdown" in primary_metric.lower():
        # Drawdowns are ranked by magnitude: closer to zero is better.
        ordered = sorted(scored.items(), key=lambda pair: abs(pair[1]))
    else:
        ordered = sorted(scored.items(), key=lambda pair: pair[1], reverse=higher_is_better)

    comparison: dict[str, Any] = {
        "primary_metric": primary_metric,
        "n_models": len(batch_results),
        "ranking": [{"model": model_name, primary_metric: mean} for model_name, mean in ordered],
        "best_model": ordered[0][0] if ordered else None,
        "model_details": {
            model_name: evaluation.summary() for model_name, evaluation in batch_results.items()
        },
    }
    return comparison
|