ml4t-diagnostic 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ml4t/diagnostic/AGENT.md +25 -0
- ml4t/diagnostic/__init__.py +166 -0
- ml4t/diagnostic/backends/__init__.py +10 -0
- ml4t/diagnostic/backends/adapter.py +192 -0
- ml4t/diagnostic/backends/polars_backend.py +899 -0
- ml4t/diagnostic/caching/__init__.py +40 -0
- ml4t/diagnostic/caching/cache.py +331 -0
- ml4t/diagnostic/caching/decorators.py +131 -0
- ml4t/diagnostic/caching/smart_cache.py +339 -0
- ml4t/diagnostic/config/AGENT.md +24 -0
- ml4t/diagnostic/config/README.md +267 -0
- ml4t/diagnostic/config/__init__.py +219 -0
- ml4t/diagnostic/config/barrier_config.py +277 -0
- ml4t/diagnostic/config/base.py +301 -0
- ml4t/diagnostic/config/event_config.py +148 -0
- ml4t/diagnostic/config/feature_config.py +404 -0
- ml4t/diagnostic/config/multi_signal_config.py +55 -0
- ml4t/diagnostic/config/portfolio_config.py +215 -0
- ml4t/diagnostic/config/report_config.py +391 -0
- ml4t/diagnostic/config/sharpe_config.py +202 -0
- ml4t/diagnostic/config/signal_config.py +206 -0
- ml4t/diagnostic/config/trade_analysis_config.py +310 -0
- ml4t/diagnostic/config/validation.py +279 -0
- ml4t/diagnostic/core/__init__.py +29 -0
- ml4t/diagnostic/core/numba_utils.py +315 -0
- ml4t/diagnostic/core/purging.py +372 -0
- ml4t/diagnostic/core/sampling.py +471 -0
- ml4t/diagnostic/errors/__init__.py +205 -0
- ml4t/diagnostic/evaluation/AGENT.md +26 -0
- ml4t/diagnostic/evaluation/__init__.py +437 -0
- ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
- ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
- ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
- ml4t/diagnostic/evaluation/dashboard.py +715 -0
- ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
- ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
- ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
- ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
- ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
- ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
- ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
- ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
- ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
- ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
- ml4t/diagnostic/evaluation/event_analysis.py +647 -0
- ml4t/diagnostic/evaluation/excursion.py +390 -0
- ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
- ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
- ml4t/diagnostic/evaluation/framework.py +935 -0
- ml4t/diagnostic/evaluation/metric_registry.py +255 -0
- ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
- ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
- ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
- ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
- ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
- ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
- ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
- ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
- ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
- ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
- ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
- ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
- ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
- ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
- ml4t/diagnostic/evaluation/multi_signal.py +550 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
- ml4t/diagnostic/evaluation/report_generation.py +824 -0
- ml4t/diagnostic/evaluation/signal_selector.py +452 -0
- ml4t/diagnostic/evaluation/stat_registry.py +139 -0
- ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
- ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
- ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
- ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
- ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
- ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
- ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
- ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
- ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
- ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
- ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
- ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
- ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
- ml4t/diagnostic/evaluation/stats/moments.py +164 -0
- ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
- ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
- ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
- ml4t/diagnostic/evaluation/themes.py +330 -0
- ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
- ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
- ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
- ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
- ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
- ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
- ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
- ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
- ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
- ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
- ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
- ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
- ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
- ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
- ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
- ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
- ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
- ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
- ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
- ml4t/diagnostic/evaluation/validated_cv.py +535 -0
- ml4t/diagnostic/evaluation/visualization.py +1050 -0
- ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
- ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
- ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
- ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
- ml4t/diagnostic/integration/__init__.py +48 -0
- ml4t/diagnostic/integration/backtest_contract.py +671 -0
- ml4t/diagnostic/integration/data_contract.py +316 -0
- ml4t/diagnostic/integration/engineer_contract.py +226 -0
- ml4t/diagnostic/logging/__init__.py +77 -0
- ml4t/diagnostic/logging/logger.py +245 -0
- ml4t/diagnostic/logging/performance.py +234 -0
- ml4t/diagnostic/logging/progress.py +234 -0
- ml4t/diagnostic/logging/wandb.py +412 -0
- ml4t/diagnostic/metrics/__init__.py +9 -0
- ml4t/diagnostic/metrics/percentiles.py +128 -0
- ml4t/diagnostic/py.typed +1 -0
- ml4t/diagnostic/reporting/__init__.py +43 -0
- ml4t/diagnostic/reporting/base.py +130 -0
- ml4t/diagnostic/reporting/html_renderer.py +275 -0
- ml4t/diagnostic/reporting/json_renderer.py +51 -0
- ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
- ml4t/diagnostic/results/AGENT.md +24 -0
- ml4t/diagnostic/results/__init__.py +105 -0
- ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
- ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
- ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
- ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
- ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
- ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
- ml4t/diagnostic/results/barrier_results/validation.py +38 -0
- ml4t/diagnostic/results/base.py +177 -0
- ml4t/diagnostic/results/event_results.py +349 -0
- ml4t/diagnostic/results/feature_results.py +787 -0
- ml4t/diagnostic/results/multi_signal_results.py +431 -0
- ml4t/diagnostic/results/portfolio_results.py +281 -0
- ml4t/diagnostic/results/sharpe_results.py +448 -0
- ml4t/diagnostic/results/signal_results/__init__.py +74 -0
- ml4t/diagnostic/results/signal_results/ic.py +581 -0
- ml4t/diagnostic/results/signal_results/irtc.py +110 -0
- ml4t/diagnostic/results/signal_results/quantile.py +392 -0
- ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
- ml4t/diagnostic/results/signal_results/turnover.py +213 -0
- ml4t/diagnostic/results/signal_results/validation.py +147 -0
- ml4t/diagnostic/signal/AGENT.md +17 -0
- ml4t/diagnostic/signal/__init__.py +69 -0
- ml4t/diagnostic/signal/_report.py +152 -0
- ml4t/diagnostic/signal/_utils.py +261 -0
- ml4t/diagnostic/signal/core.py +275 -0
- ml4t/diagnostic/signal/quantile.py +148 -0
- ml4t/diagnostic/signal/result.py +214 -0
- ml4t/diagnostic/signal/signal_ic.py +129 -0
- ml4t/diagnostic/signal/turnover.py +182 -0
- ml4t/diagnostic/splitters/AGENT.md +19 -0
- ml4t/diagnostic/splitters/__init__.py +36 -0
- ml4t/diagnostic/splitters/base.py +501 -0
- ml4t/diagnostic/splitters/calendar.py +421 -0
- ml4t/diagnostic/splitters/calendar_config.py +91 -0
- ml4t/diagnostic/splitters/combinatorial.py +1064 -0
- ml4t/diagnostic/splitters/config.py +322 -0
- ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
- ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
- ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
- ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
- ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
- ml4t/diagnostic/splitters/group_isolation.py +329 -0
- ml4t/diagnostic/splitters/persistence.py +316 -0
- ml4t/diagnostic/splitters/utils.py +207 -0
- ml4t/diagnostic/splitters/walk_forward.py +757 -0
- ml4t/diagnostic/utils/__init__.py +42 -0
- ml4t/diagnostic/utils/config.py +542 -0
- ml4t/diagnostic/utils/dependencies.py +318 -0
- ml4t/diagnostic/utils/sessions.py +127 -0
- ml4t/diagnostic/validation/__init__.py +54 -0
- ml4t/diagnostic/validation/dataframe.py +274 -0
- ml4t/diagnostic/validation/returns.py +280 -0
- ml4t/diagnostic/validation/timeseries.py +299 -0
- ml4t/diagnostic/visualization/AGENT.md +19 -0
- ml4t/diagnostic/visualization/__init__.py +223 -0
- ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
- ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
- ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
- ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
- ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
- ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
- ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
- ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
- ml4t/diagnostic/visualization/barrier_plots.py +782 -0
- ml4t/diagnostic/visualization/core.py +1060 -0
- ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
- ml4t/diagnostic/visualization/dashboards/base.py +582 -0
- ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
- ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
- ml4t/diagnostic/visualization/dashboards.py +43 -0
- ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
- ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
- ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
- ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
- ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
- ml4t/diagnostic/visualization/feature_plots.py +888 -0
- ml4t/diagnostic/visualization/interaction_plots.py +618 -0
- ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
- ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
- ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
- ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
- ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
- ml4t/diagnostic/visualization/report_generation.py +1343 -0
- ml4t/diagnostic/visualization/signal/__init__.py +103 -0
- ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
- ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
- ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
- ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
- ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
- ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
- ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
- ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
- ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
- ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
- ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
- ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,279 @@
|
|
|
1
|
+
"""Custom validators and validation utilities.
|
|
2
|
+
|
|
3
|
+
This module provides reusable validators, custom types, and validation
|
|
4
|
+
helpers used across the configuration system.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from enum import Enum
|
|
10
|
+
from typing import Annotated
|
|
11
|
+
|
|
12
|
+
from pydantic import Field
|
|
13
|
+
|
|
14
|
+
# Custom type aliases for common constraints.
# Each alias attaches a pydantic ``Field`` bound through ``Annotated`` so the
# constraint travels with the type wherever it is used as a field annotation.
# ``gt`` is an exclusive bound; ``ge``/``le`` are inclusive bounds.
PositiveInt = Annotated[int, Field(gt=0)]  # int > 0
NonNegativeInt = Annotated[int, Field(ge=0)]  # int >= 0
PositiveFloat = Annotated[float, Field(gt=0.0)]  # float > 0
NonNegativeFloat = Annotated[float, Field(ge=0.0)]  # float >= 0
Probability = Annotated[float, Field(ge=0.0, le=1.0)]  # 0 <= p <= 1
CorrelationValue = Annotated[float, Field(ge=-1.0, le=1.0)]  # -1 <= r <= 1
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class SignificanceLevel(float, Enum):
    """Standard significance levels for hypothesis testing.

    The ``float`` mixin makes each member behave as its numeric value, so
    ``SignificanceLevel.LEVEL_05 == 0.05`` and members can be used directly
    in arithmetic and comparisons.
    """

    LEVEL_01 = 0.01  # 1% level
    LEVEL_05 = 0.05  # 5% level
    LEVEL_10 = 0.10  # 10% level
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class CorrelationMethod(str, Enum):
    """Correlation calculation methods.

    The ``str`` mixin makes members compare equal to their lowercase string
    values, e.g. ``CorrelationMethod.PEARSON == "pearson"``.
    """

    PEARSON = "pearson"  # Pearson product-moment correlation
    SPEARMAN = "spearman"  # Spearman rank correlation
    KENDALL = "kendall"  # Kendall tau rank correlation
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class StationarityTest(str, Enum):
    """Stationarity test types.

    Values are short lowercase codes; the ``str`` mixin makes members compare
    equal to them (e.g. ``StationarityTest.ADF == "adf"``).
    """

    ADF = "adf"  # Augmented Dickey-Fuller
    KPSS = "kpss"  # Kwiatkowski-Phillips-Schmidt-Shin
    PP = "pp"  # Phillips-Perron
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class RegressionType(str, Enum):
    """Regression types for stationarity tests.

    Values are the short deterministic-term codes ("c", "ct", "ctt", "n").
    NOTE(review): not every test accepts every code (e.g. statsmodels' KPSS
    only takes "c"/"ct") — confirm callers validate the combination.
    """

    CONSTANT = "c"  # Constant only
    CONSTANT_TREND = "ct"  # Constant and trend
    CONSTANT_TREND_SQUARED = "ctt"  # Constant, trend, and trend squared
    NONE = "n"  # No constant or trend
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class ClusteringMethod(str, Enum):
    """Clustering algorithm types.

    The ``str`` mixin makes members compare equal to their string values.
    """

    HIERARCHICAL = "hierarchical"  # agglomerative; see LinkageMethod
    KMEANS = "kmeans"  # k-means
    DBSCAN = "dbscan"  # density-based clustering
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class LinkageMethod(str, Enum):
    """Linkage methods for hierarchical clustering.

    The ``str`` mixin makes members compare equal to their string values.
    NOTE(review): Ward linkage is only well-defined for Euclidean distances —
    confirm callers do not pair WARD with a non-Euclidean DistanceMetric.
    """

    WARD = "ward"  # minimize within-cluster variance increase
    COMPLETE = "complete"  # maximum pairwise distance between clusters
    AVERAGE = "average"  # mean pairwise distance between clusters
    SINGLE = "single"  # minimum pairwise distance between clusters
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class DistanceMetric(str, Enum):
    """Distance metrics for clustering.

    The ``str`` mixin makes members compare equal to their string values.
    """

    EUCLIDEAN = "euclidean"
    CORRELATION = "correlation"  # correlation-based distance
    MANHATTAN = "manhattan"
    COSINE = "cosine"
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
class NormalityTest(str, Enum):
    """Normality test types.

    The ``str`` mixin makes members compare equal to their string values.
    Note that the Kolmogorov-Smirnov member uses the short code ``"ks"``.
    """

    JARQUE_BERA = "jarque_bera"  # Jarque-Bera
    SHAPIRO = "shapiro"  # Shapiro-Wilk
    KOLMOGOROV_SMIRNOV = "ks"  # Kolmogorov-Smirnov (short code)
    ANDERSON = "anderson"  # Anderson-Darling
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
class OutlierMethod(str, Enum):
    """Outlier detection methods.

    The ``str`` mixin makes members compare equal to their string values.
    """

    ZSCORE = "zscore"  # standard-score threshold
    IQR = "iqr"  # interquartile-range fences
    ISOLATION_FOREST = "isolation_forest"  # tree-based isolation
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
class VolatilityClusterMethod(str, Enum):
    """Methods for detecting volatility clustering.

    The ``str`` mixin makes members compare equal to their string values.
    """

    LJUNG_BOX = "ljung_box"  # Ljung-Box autocorrelation test
    ENGLE_ARCH = "engle_arch"  # Engle's ARCH test
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
class ThresholdOptimizationTarget(str, Enum):
    """Optimization targets for threshold analysis.

    The ``str`` mixin makes members compare equal to their string values.
    The information-coefficient member uses the short code ``"ic"``.
    """

    SHARPE = "sharpe"
    PRECISION = "precision"
    RECALL = "recall"
    F1 = "f1"
    INFORMATION_COEFFICIENT = "ic"  # short code
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
class DriftDetectionMethod(str, Enum):
    """Feature drift detection methods.

    The ``str`` mixin makes members compare equal to their string values.
    The Kolmogorov-Smirnov member uses the short code ``"ks"``.
    """

    KOLMOGOROV_SMIRNOV = "ks"  # Kolmogorov-Smirnov (short code)
    WASSERSTEIN = "wasserstein"  # Wasserstein distance
    PSI = "psi"  # Population Stability Index
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
class PortfolioMetric(str, Enum):
    """Portfolio performance metrics.

    The ``str`` mixin makes members compare equal to their string values.
    Maximum drawdown uses the abbreviated value ``"max_dd"``.
    """

    SHARPE = "sharpe"
    SORTINO = "sortino"
    CALMAR = "calmar"
    MAX_DRAWDOWN = "max_dd"  # abbreviated value
    VAR = "var"  # Value at Risk
    CVAR = "cvar"  # Conditional Value at Risk
    OMEGA = "omega"
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
class TimeFrequency(str, Enum):
    """Time aggregation frequencies.

    The ``str`` mixin makes members compare equal to their string values.
    """

    DAILY = "daily"
    WEEKLY = "weekly"
    MONTHLY = "monthly"
    QUARTERLY = "quarterly"
    ANNUAL = "annual"
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
class FDRMethod(str, Enum):
    """Multiple-testing correction methods.

    Despite the class name, Bonferroni and Holm control the family-wise error
    rate; Benjamini-Hochberg and Benjamini-Yekutieli control the false
    discovery rate. The ``str`` mixin makes members compare equal to their
    string values; the two Benjamini variants use short codes.
    """

    BONFERRONI = "bonferroni"
    HOLM = "holm"  # Holm step-down
    BENJAMINI_HOCHBERG = "bh"  # short code
    BENJAMINI_YEKUTIELI = "by"  # short code
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
class BayesianPriorDistribution(str, Enum):
    """Prior distributions for Bayesian analysis.

    The ``str`` mixin makes members compare equal to their string values.
    """

    NORMAL = "normal"
    STUDENT_T = "student_t"  # Student's t
    UNIFORM = "uniform"
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
class ReportFormat(str, Enum):
    """Report output formats.

    The ``str`` mixin makes members compare equal to their string values.
    """

    HTML = "html"
    JSON = "json"
    PDF = "pdf"
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
class ReportTemplate(str, Enum):
    """Report templates.

    The ``str`` mixin makes members compare equal to their string values.
    """

    FULL = "full"
    SUMMARY = "summary"
    DIAGNOSTIC = "diagnostic"
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
class ReportTheme(str, Enum):
    """Report visual themes.

    The ``str`` mixin makes members compare equal to their string values.
    """

    LIGHT = "light"
    DARK = "dark"
    PROFESSIONAL = "professional"
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
class TableFormat(str, Enum):
    """Table formatting styles.

    The ``str`` mixin makes members compare equal to their string values.
    """

    STYLED = "styled"
    PLAIN = "plain"
    DATATABLES = "datatables"  # presumably the DataTables JS widget — confirm in renderer
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
class DataFrameExportFormat(str, Enum):
    """DataFrame serialization formats for JSON.

    The ``str`` mixin makes members compare equal to their string values.
    NOTE(review): the values match pandas ``to_json(orient=...)`` names;
    presumably they are passed through as ``orient`` — verify at call sites.
    """

    RECORDS = "records"  # list of dicts
    SPLIT = "split"  # {index: [...], columns: [...], data: [...]}
    INDEX = "index"  # {index: {column: value}}
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
def validate_positive_int(v: int, field_name: str = "value") -> int:
    """Ensure ``v`` is strictly greater than zero.

    Args:
        v: Candidate value.
        field_name: Label used in the error message.

    Returns:
        ``v`` itself, unchanged, when it passes the check.

    Raises:
        ValueError: If ``v`` is zero or negative.
    """
    non_positive = v <= 0
    if not non_positive:
        return v
    raise ValueError(f"{field_name} must be positive (got {v})")
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def validate_probability(v: float, field_name: str = "probability") -> float:
    """Ensure ``v`` lies in the closed interval [0, 1].

    NaN fails the range comparison and is therefore rejected.

    Args:
        v: Candidate value.
        field_name: Label used in the error message.

    Returns:
        ``v`` itself, unchanged, when it is within bounds.

    Raises:
        ValueError: If ``v`` is below 0, above 1, or NaN.
    """
    lower_bound, upper_bound = 0.0, 1.0
    if lower_bound <= v <= upper_bound:
        return v
    raise ValueError(f"{field_name} must be in [0, 1] (got {v})")
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
def validate_significance_level(v: float) -> float:
    """Validate that a significance level is one of the standard values.

    Accepts 0.01, 0.05 and 0.10. A value that differs from one of these only
    by floating-point rounding is also accepted and snapped to the canonical
    level. For example, ``1 - 0.95`` evaluates to ``0.050000000000000044`` and
    is returned as 0.05, so computed alphas are not spuriously rejected.

    Args:
        v: Significance level

    Returns:
        ``v`` unchanged when it equals a standard level exactly; otherwise
        the canonical standard level it rounds to.

    Raises:
        ValueError: If ``v`` is not (within rounding error of) a standard level.
    """
    standard_levels = (0.01, 0.05, 0.10)
    if v in standard_levels:
        # Exact match: hand back the caller's object untouched (it may be,
        # e.g., a float-valued enum member).
        return v

    # Far above float64 rounding noise (~1e-17 here), far below the spacing
    # between standard levels.
    tolerance = 1e-9
    for level in standard_levels:
        try:
            if abs(v - level) <= tolerance:
                return level
        except TypeError:
            # Non-numeric input: fall through to the ValueError below.
            break

    raise ValueError(
        f"Significance level {v} is non-standard. Consider using 0.01, 0.05, or 0.10 for interpretability."
    )
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
def validate_min_max_range(
    min_val: float, max_val: float, field_prefix: str = "range"
) -> tuple[float, float]:
    """Validate that min < max.

    NaN bounds are rejected explicitly. Every comparison involving NaN is
    False, so a plain ``min_val >= max_val`` check would let them through.

    Args:
        min_val: Minimum value
        max_val: Maximum value
        field_prefix: Prefix for error messages

    Returns:
        Validated (min, max) tuple

    Raises:
        ValueError: If either bound is NaN, or if min >= max
    """
    # ``x != x`` is True only for NaN; it works for any float-like type
    # without importing math.
    if min_val != min_val or max_val != max_val:
        raise ValueError(
            f"{field_prefix}_min and {field_prefix}_max must not be NaN "
            f"(got {min_val}, {max_val})"
        )
    if min_val >= max_val:
        raise ValueError(
            f"{field_prefix}_min must be < {field_prefix}_max (got {min_val} >= {max_val})"
        )
    return min_val, max_val
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
"""Core functionality for ml4t-diagnostic.
|
|
2
|
+
|
|
3
|
+
This module contains the fundamental logic for purging, embargo, and sampling
|
|
4
|
+
that underlies all cross-validation splitters.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from ml4t.diagnostic.core.purging import (
|
|
8
|
+
apply_purging_and_embargo,
|
|
9
|
+
calculate_embargo_indices,
|
|
10
|
+
calculate_purge_indices,
|
|
11
|
+
)
|
|
12
|
+
from ml4t.diagnostic.core.sampling import (
|
|
13
|
+
balanced_subsample,
|
|
14
|
+
block_bootstrap,
|
|
15
|
+
event_based_sample,
|
|
16
|
+
sample_weights_by_importance,
|
|
17
|
+
stratified_sample_time_series,
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
# Public API of ``ml4t.diagnostic.core``, kept in alphabetical order.
__all__: list[str] = [
    "apply_purging_and_embargo",
    "balanced_subsample",
    "block_bootstrap",
    "calculate_embargo_indices",
    "calculate_purge_indices",
    "event_based_sample",
    "sample_weights_by_importance",
    "stratified_sample_time_series",
]
|
|
@@ -0,0 +1,315 @@
|
|
|
1
|
+
"""Numba-optimized utility functions for ML4T Diagnostic.
|
|
2
|
+
|
|
3
|
+
This module contains JIT-compiled functions for performance-critical operations.
|
|
4
|
+
Numba is used to optimize computationally intensive loops and array operations.
|
|
5
|
+
|
|
6
|
+
Note: Numba functions work best with NumPy arrays and simple Python types.
|
|
7
|
+
They cannot handle Pandas objects directly.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
from numba import jit
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@jit(nopython=True, cache=True)
def calculate_drawdown_numba(
    cum_returns: np.ndarray,
) -> tuple[float, int, int, int]:
    """Find the deepest peak-to-trough decline of a cumulative-return path.

    The drawdown at step ``i`` is the absolute difference
    ``cum_returns[i] - running_max`` (a subtraction, not a ratio).
    NOTE(review): if ``cum_returns`` is a compounded wealth path, a
    percentage drawdown may be intended instead — confirm with callers.

    Parameters
    ----------
    cum_returns : np.ndarray
        Array of cumulative returns

    Returns
    -------
    Tuple[float, int, int, int]
        ``(max_drawdown, duration, peak_idx, trough_idx)``.

        - ``max_drawdown`` is <= 0.
        - ``duration`` is the number of steps from the peak to the trough of
          that drawdown.
        - A new peak is recorded only on a strictly higher value, and the
          earliest trough reaching the minimum is kept.
        - Empty input yields ``(nan, -1, -1, -1)``; a path that never falls
          below its running maximum yields ``(0.0, 0, 0, 0)``.
    """
    length = len(cum_returns)
    if length == 0:
        return np.nan, -1, -1, -1

    running_peak = cum_returns[0]
    running_peak_at = 0
    worst_drawdown = 0.0
    worst_span = 0
    worst_peak_at = 0
    worst_trough_at = 0

    for idx in range(1, length):
        value = cum_returns[idx]
        if value > running_peak:
            # A fresh high has zero drawdown, which can never beat the
            # current worst (always <= 0), so move straight on.
            running_peak = value
            running_peak_at = idx
            continue

        gap = value - running_peak
        if gap < worst_drawdown:
            worst_drawdown = gap
            worst_peak_at = running_peak_at
            worst_trough_at = idx
            worst_span = idx - running_peak_at

    return worst_drawdown, worst_span, worst_peak_at, worst_trough_at
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
@jit(nopython=True, cache=True)
def purge_indices_numba(
    test_start: int,
    _test_end: int,
    label_horizon: int,
    n_samples: int,
) -> np.ndarray:
    """Return the indices immediately preceding the test window that must be purged.

    The purged span is ``[test_start - label_horizon, test_start)``, clipped
    to ``[0, n_samples)``. Those are the samples whose forward-looking labels
    can reach into the test period.

    Parameters
    ----------
    test_start : int
        Start index of test period
    _test_end : int
        End index of test period (accepted for signature symmetry; unused)
    label_horizon : int
        Forward-looking period of labels
    n_samples : int
        Total number of samples

    Returns
    -------
    np.ndarray
        ``int64`` array of indices to purge (empty when the span is empty)
    """
    lo = test_start - label_horizon
    if lo < 0:
        lo = 0
    hi = test_start if test_start < n_samples else n_samples

    if hi <= lo:
        return np.empty(0, dtype=np.int64)
    return np.arange(lo, hi, dtype=np.int64)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
@jit(nopython=True, cache=True)
def embargo_indices_numba(
    test_end: int,
    embargo_size: int,
    n_samples: int,
) -> np.ndarray:
    """Return the indices immediately following the test window that are embargoed.

    The embargoed span is ``[test_end, test_end + embargo_size)``, capped at
    ``n_samples``.

    Parameters
    ----------
    test_end : int
        End index of test period (first index after it)
    embargo_size : int
        Number of samples to embargo after test set
    n_samples : int
        Total number of samples

    Returns
    -------
    np.ndarray
        ``int64`` array of indices to embargo (empty when the span is empty)
    """
    stop = test_end + embargo_size
    if stop > n_samples:
        stop = n_samples

    if test_end >= stop:
        return np.empty(0, dtype=np.int64)
    return np.arange(test_end, stop, dtype=np.int64)
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
@jit(nopython=True, cache=True, parallel=True)
def block_bootstrap_numba(
    indices: np.ndarray,
    n_samples: int,
    sample_length: int,
    seed: int,
) -> np.ndarray:
    """Numba-optimized block bootstrap sampling.

    Repeatedly picks a uniformly random start position and copies
    ``sample_length`` consecutive entries of ``indices`` from there.
    Blocks are drawn with replacement and concatenated until ``n_samples``
    values are collected; the last block is truncated to fit.

    When ``sample_length >= len(indices)`` no random resampling happens. The
    first ``n_samples`` entries are returned, tiling ``indices`` end-to-end
    if more values are requested than are available.

    Parameters
    ----------
    indices : np.ndarray
        Array of indices to sample from
    n_samples : int
        Number of bootstrap samples to generate
    sample_length : int
        Length of each sequential sample (must be >= 1)
    seed : int
        Random seed for reproducibility (seeds Numba's internal generator)

    Returns
    -------
    np.ndarray
        Bootstrap sample indices, length ``n_samples``

    Raises
    ------
    ValueError
        If ``sample_length < 1``. A zero or negative block length never
        advances the fill position, so the sampling loop would never
        terminate.
    ValueError
        If ``indices`` is empty while ``n_samples > 0``. Otherwise the tiling
        step would divide by zero.
    """
    if sample_length < 1:
        raise ValueError("sample_length must be >= 1")

    np.random.seed(seed)
    n_indices = len(indices)

    # Handle edge cases: a block at least as long as the data cannot be resampled
    if sample_length >= n_indices:
        if n_samples <= n_indices:
            return indices[:n_samples].copy()
        if n_indices == 0:
            raise ValueError("indices must be non-empty when n_samples > 0")
        # Repeat indices to meet n_samples requirement
        repeats = (n_samples // n_indices) + 1
        result = np.empty(repeats * n_indices, dtype=indices.dtype)
        for i in range(repeats):
            result[i * n_indices : (i + 1) * n_indices] = indices
        return result[:n_samples]

    # Pre-allocate result array
    result = np.empty(n_samples, dtype=indices.dtype)
    filled = 0

    while filled < n_samples:
        # Random block start such that the whole block fits inside `indices`
        start_idx = np.random.randint(0, n_indices - sample_length + 1)

        # Truncate the final block so we never overfill
        samples_to_take = min(sample_length, n_samples - filled)

        # Copy sequential samples
        for i in range(samples_to_take):
            result[filled + i] = indices[start_idx + i]

        filled += samples_to_take

    return result
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
@jit(nopython=True, cache=True)
def rolling_sharpe_numba(
    returns: np.ndarray,
    window: int,
    risk_free_rate: float = 0.0,
    periods_per_year: int = 252,
) -> np.ndarray:
    """Numba-optimized rolling Sharpe ratio calculation.

    For every index ``i >= window - 1`` the Sharpe ratio of the trailing
    ``window`` returns is

        mean(excess) / std(excess) * sqrt(periods_per_year)

    where ``excess = returns - risk_free_rate / periods_per_year`` and
    ``std`` is the population standard deviation (``np.std``, ddof=0).
    The first ``window - 1`` outputs are NaN.

    A window is treated as constant when its standard deviation is zero, or
    so small relative to the mean that it can only be floating-point rounding
    noise (left alone, such noise blows the ratio up to ~1e16). A constant
    window yields 0.0 when its mean excess return is ~0, and NaN otherwise.
    Windows containing NaN yield NaN.

    Parameters
    ----------
    returns : np.ndarray
        Array of returns
    window : int
        Rolling window size (must be >= 1)
    risk_free_rate : float
        Risk-free rate (annualized)
    periods_per_year : int
        Number of periods per year for annualization

    Returns
    -------
    np.ndarray
        Array of rolling Sharpe ratios, same length as ``returns``

    Raises
    ------
    ValueError
        If ``window < 1``. A non-positive window would otherwise slice from
        the end of the array and write meaningless values.
    """
    if window < 1:
        raise ValueError("window must be >= 1")

    n = len(returns)
    result = np.full(n, np.nan)
    if n < window:
        return result

    daily_rf = risk_free_rate / periods_per_year
    sqrt_periods = np.sqrt(periods_per_year)

    for i in range(window - 1, n):
        window_returns = returns[i - window + 1 : i + 1]
        excess_returns = window_returns - daily_rf

        mean_excess = np.mean(excess_returns)
        std_excess = np.std(excess_returns)

        # A std that is tiny relative to |mean| is rounding noise from a
        # constant window, not genuine volatility.
        if std_excess > 1e-12 * abs(mean_excess):
            result[i] = mean_excess / std_excess * sqrt_periods
        elif abs(mean_excess) < 1e-10:
            # Constant window with ~zero excess return: define Sharpe as 0.
            result[i] = 0.0
        # Otherwise a constant, non-zero excess return has no finite Sharpe;
        # leave the NaN already stored in ``result``.

    return result
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
@jit(nopython=True, cache=True, parallel=True)
def calculate_ic_vectorized(
    predictions: np.ndarray,
    returns: np.ndarray,
    method: int = 0,  # 0=pearson, 1=spearman
) -> float:
    """Numba-optimized Information Coefficient calculation.

    Computes the Pearson correlation between ``predictions`` and ``returns``
    after dropping every position where either array is NaN. With
    ``method == 1``, both cleaned arrays are first replaced by their
    average-tie ranks (via ``_rank_data_numba``), which gives the Spearman
    rank correlation.

    Parameters
    ----------
    predictions : np.ndarray
        Array of predictions
    returns : np.ndarray
        Array of returns; must have the same length as ``predictions``
    method : int
        0 for Pearson, 1 for Spearman. Any other value is treated as
        Pearson, because only ``method == 1`` is special-cased.

    Returns
    -------
    float
        Information coefficient.

        - NaN when the lengths differ or fewer than two non-NaN pairs remain.
        - 0.0 (not NaN) when either cleaned series has zero variance.
    """
    n = len(predictions)
    if n != len(returns) or n < 2:
        return np.nan

    # Remove NaN values pairwise: a NaN in either array drops that position
    valid_mask = ~(np.isnan(predictions) | np.isnan(returns))
    pred_clean = predictions[valid_mask]
    ret_clean = returns[valid_mask]

    if len(pred_clean) < 2:
        return np.nan

    if method == 1:  # Spearman
        # Rank the data; Pearson on ranks is Spearman's rho.
        # NOTE(review): this rebinds the cleaned arrays to float64 rank
        # arrays; with non-float64 inputs Numba may fail to unify the types —
        # confirm callers always pass float64.
        pred_clean = _rank_data_numba(pred_clean)
        ret_clean = _rank_data_numba(ret_clean)

    # Calculate Pearson correlation
    pred_mean = np.mean(pred_clean)
    ret_mean = np.mean(ret_clean)

    numerator = np.sum((pred_clean - pred_mean) * (ret_clean - ret_mean))
    denominator = np.sqrt(
        np.sum((pred_clean - pred_mean) ** 2) * np.sum((ret_clean - ret_mean) ** 2)
    )

    # Zero variance in either series: correlation is undefined; reported as 0.0
    if denominator == 0:
        return 0.0

    return numerator / denominator
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
@jit(nopython=True, cache=True)
def _rank_data_numba(data: np.ndarray) -> np.ndarray:
    """Return 1-based ranks of ``data``, giving tied values their average rank.

    Used to turn a Pearson correlation into a Spearman rank correlation.
    The output is a float64 array aligned position-for-position with ``data``.
    """
    size = len(data)
    order = np.argsort(data)
    ordered = data[order]
    ranks = np.empty(size)

    group_start = 0
    while group_start < size:
        # Extend the group across every consecutive equal value in sorted order.
        group_end = group_start
        while group_end + 1 < size and ordered[group_end + 1] == ordered[group_end]:
            group_end += 1

        # Sorted positions group_start..group_end correspond to ranks
        # group_start+1..group_end+1; every tied member gets their mean.
        shared_rank = 0.5 * (group_start + group_end) + 1.0
        for pos in range(group_start, group_end + 1):
            ranks[order[pos]] = shared_rank

        group_start = group_end + 1

    return ranks
|