ml4t-diagnostic 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (242)
  1. ml4t/diagnostic/AGENT.md +25 -0
  2. ml4t/diagnostic/__init__.py +166 -0
  3. ml4t/diagnostic/backends/__init__.py +10 -0
  4. ml4t/diagnostic/backends/adapter.py +192 -0
  5. ml4t/diagnostic/backends/polars_backend.py +899 -0
  6. ml4t/diagnostic/caching/__init__.py +40 -0
  7. ml4t/diagnostic/caching/cache.py +331 -0
  8. ml4t/diagnostic/caching/decorators.py +131 -0
  9. ml4t/diagnostic/caching/smart_cache.py +339 -0
  10. ml4t/diagnostic/config/AGENT.md +24 -0
  11. ml4t/diagnostic/config/README.md +267 -0
  12. ml4t/diagnostic/config/__init__.py +219 -0
  13. ml4t/diagnostic/config/barrier_config.py +277 -0
  14. ml4t/diagnostic/config/base.py +301 -0
  15. ml4t/diagnostic/config/event_config.py +148 -0
  16. ml4t/diagnostic/config/feature_config.py +404 -0
  17. ml4t/diagnostic/config/multi_signal_config.py +55 -0
  18. ml4t/diagnostic/config/portfolio_config.py +215 -0
  19. ml4t/diagnostic/config/report_config.py +391 -0
  20. ml4t/diagnostic/config/sharpe_config.py +202 -0
  21. ml4t/diagnostic/config/signal_config.py +206 -0
  22. ml4t/diagnostic/config/trade_analysis_config.py +310 -0
  23. ml4t/diagnostic/config/validation.py +279 -0
  24. ml4t/diagnostic/core/__init__.py +29 -0
  25. ml4t/diagnostic/core/numba_utils.py +315 -0
  26. ml4t/diagnostic/core/purging.py +372 -0
  27. ml4t/diagnostic/core/sampling.py +471 -0
  28. ml4t/diagnostic/errors/__init__.py +205 -0
  29. ml4t/diagnostic/evaluation/AGENT.md +26 -0
  30. ml4t/diagnostic/evaluation/__init__.py +437 -0
  31. ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
  32. ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
  33. ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
  34. ml4t/diagnostic/evaluation/dashboard.py +715 -0
  35. ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
  36. ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
  37. ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
  38. ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
  39. ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
  40. ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
  41. ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
  42. ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
  43. ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
  44. ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
  45. ml4t/diagnostic/evaluation/event_analysis.py +647 -0
  46. ml4t/diagnostic/evaluation/excursion.py +390 -0
  47. ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
  48. ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
  49. ml4t/diagnostic/evaluation/framework.py +935 -0
  50. ml4t/diagnostic/evaluation/metric_registry.py +255 -0
  51. ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
  52. ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
  53. ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
  54. ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
  55. ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
  56. ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
  57. ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
  58. ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
  59. ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
  60. ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
  61. ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
  62. ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
  63. ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
  64. ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
  65. ml4t/diagnostic/evaluation/multi_signal.py +550 -0
  66. ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
  67. ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
  68. ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
  69. ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
  70. ml4t/diagnostic/evaluation/report_generation.py +824 -0
  71. ml4t/diagnostic/evaluation/signal_selector.py +452 -0
  72. ml4t/diagnostic/evaluation/stat_registry.py +139 -0
  73. ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
  74. ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
  75. ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
  76. ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
  77. ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
  78. ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
  79. ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
  80. ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
  81. ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
  82. ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
  83. ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
  84. ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
  85. ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
  86. ml4t/diagnostic/evaluation/stats/moments.py +164 -0
  87. ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
  88. ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
  89. ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
  90. ml4t/diagnostic/evaluation/themes.py +330 -0
  91. ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
  92. ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
  93. ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
  94. ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
  95. ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
  96. ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
  97. ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
  98. ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
  99. ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
  100. ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
  101. ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
  102. ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
  103. ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
  104. ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
  105. ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
  106. ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
  107. ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
  108. ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
  109. ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
  110. ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
  111. ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
  112. ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
  113. ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
  114. ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
  115. ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
  116. ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
  117. ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
  118. ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
  119. ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
  120. ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
  121. ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
  122. ml4t/diagnostic/evaluation/validated_cv.py +535 -0
  123. ml4t/diagnostic/evaluation/visualization.py +1050 -0
  124. ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
  125. ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
  126. ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
  127. ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
  128. ml4t/diagnostic/integration/__init__.py +48 -0
  129. ml4t/diagnostic/integration/backtest_contract.py +671 -0
  130. ml4t/diagnostic/integration/data_contract.py +316 -0
  131. ml4t/diagnostic/integration/engineer_contract.py +226 -0
  132. ml4t/diagnostic/logging/__init__.py +77 -0
  133. ml4t/diagnostic/logging/logger.py +245 -0
  134. ml4t/diagnostic/logging/performance.py +234 -0
  135. ml4t/diagnostic/logging/progress.py +234 -0
  136. ml4t/diagnostic/logging/wandb.py +412 -0
  137. ml4t/diagnostic/metrics/__init__.py +9 -0
  138. ml4t/diagnostic/metrics/percentiles.py +128 -0
  139. ml4t/diagnostic/py.typed +1 -0
  140. ml4t/diagnostic/reporting/__init__.py +43 -0
  141. ml4t/diagnostic/reporting/base.py +130 -0
  142. ml4t/diagnostic/reporting/html_renderer.py +275 -0
  143. ml4t/diagnostic/reporting/json_renderer.py +51 -0
  144. ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
  145. ml4t/diagnostic/results/AGENT.md +24 -0
  146. ml4t/diagnostic/results/__init__.py +105 -0
  147. ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
  148. ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
  149. ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
  150. ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
  151. ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
  152. ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
  153. ml4t/diagnostic/results/barrier_results/validation.py +38 -0
  154. ml4t/diagnostic/results/base.py +177 -0
  155. ml4t/diagnostic/results/event_results.py +349 -0
  156. ml4t/diagnostic/results/feature_results.py +787 -0
  157. ml4t/diagnostic/results/multi_signal_results.py +431 -0
  158. ml4t/diagnostic/results/portfolio_results.py +281 -0
  159. ml4t/diagnostic/results/sharpe_results.py +448 -0
  160. ml4t/diagnostic/results/signal_results/__init__.py +74 -0
  161. ml4t/diagnostic/results/signal_results/ic.py +581 -0
  162. ml4t/diagnostic/results/signal_results/irtc.py +110 -0
  163. ml4t/diagnostic/results/signal_results/quantile.py +392 -0
  164. ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
  165. ml4t/diagnostic/results/signal_results/turnover.py +213 -0
  166. ml4t/diagnostic/results/signal_results/validation.py +147 -0
  167. ml4t/diagnostic/signal/AGENT.md +17 -0
  168. ml4t/diagnostic/signal/__init__.py +69 -0
  169. ml4t/diagnostic/signal/_report.py +152 -0
  170. ml4t/diagnostic/signal/_utils.py +261 -0
  171. ml4t/diagnostic/signal/core.py +275 -0
  172. ml4t/diagnostic/signal/quantile.py +148 -0
  173. ml4t/diagnostic/signal/result.py +214 -0
  174. ml4t/diagnostic/signal/signal_ic.py +129 -0
  175. ml4t/diagnostic/signal/turnover.py +182 -0
  176. ml4t/diagnostic/splitters/AGENT.md +19 -0
  177. ml4t/diagnostic/splitters/__init__.py +36 -0
  178. ml4t/diagnostic/splitters/base.py +501 -0
  179. ml4t/diagnostic/splitters/calendar.py +421 -0
  180. ml4t/diagnostic/splitters/calendar_config.py +91 -0
  181. ml4t/diagnostic/splitters/combinatorial.py +1064 -0
  182. ml4t/diagnostic/splitters/config.py +322 -0
  183. ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
  184. ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
  185. ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
  186. ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
  187. ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
  188. ml4t/diagnostic/splitters/group_isolation.py +329 -0
  189. ml4t/diagnostic/splitters/persistence.py +316 -0
  190. ml4t/diagnostic/splitters/utils.py +207 -0
  191. ml4t/diagnostic/splitters/walk_forward.py +757 -0
  192. ml4t/diagnostic/utils/__init__.py +42 -0
  193. ml4t/diagnostic/utils/config.py +542 -0
  194. ml4t/diagnostic/utils/dependencies.py +318 -0
  195. ml4t/diagnostic/utils/sessions.py +127 -0
  196. ml4t/diagnostic/validation/__init__.py +54 -0
  197. ml4t/diagnostic/validation/dataframe.py +274 -0
  198. ml4t/diagnostic/validation/returns.py +280 -0
  199. ml4t/diagnostic/validation/timeseries.py +299 -0
  200. ml4t/diagnostic/visualization/AGENT.md +19 -0
  201. ml4t/diagnostic/visualization/__init__.py +223 -0
  202. ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
  203. ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
  204. ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
  205. ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
  206. ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
  207. ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
  208. ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
  209. ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
  210. ml4t/diagnostic/visualization/barrier_plots.py +782 -0
  211. ml4t/diagnostic/visualization/core.py +1060 -0
  212. ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
  213. ml4t/diagnostic/visualization/dashboards/base.py +582 -0
  214. ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
  215. ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
  216. ml4t/diagnostic/visualization/dashboards.py +43 -0
  217. ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
  218. ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
  219. ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
  220. ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
  221. ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
  222. ml4t/diagnostic/visualization/feature_plots.py +888 -0
  223. ml4t/diagnostic/visualization/interaction_plots.py +618 -0
  224. ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
  225. ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
  226. ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
  227. ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
  228. ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
  229. ml4t/diagnostic/visualization/report_generation.py +1343 -0
  230. ml4t/diagnostic/visualization/signal/__init__.py +103 -0
  231. ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
  232. ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
  233. ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
  234. ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
  235. ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
  236. ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
  237. ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
  238. ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
  239. ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
  240. ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
  241. ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
  242. ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,279 @@
1
+ """Custom validators and validation utilities.
2
+
3
+ This module provides reusable validators, custom types, and validation
4
+ helpers used across the configuration system.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from enum import Enum
10
+ from typing import Annotated
11
+
12
+ from pydantic import Field
13
+
14
+ # Custom type aliases for common constraints
15
+ PositiveInt = Annotated[int, Field(gt=0)]
16
+ NonNegativeInt = Annotated[int, Field(ge=0)]
17
+ PositiveFloat = Annotated[float, Field(gt=0.0)]
18
+ NonNegativeFloat = Annotated[float, Field(ge=0.0)]
19
+ Probability = Annotated[float, Field(ge=0.0, le=1.0)]
20
+ CorrelationValue = Annotated[float, Field(ge=-1.0, le=1.0)]
21
+
22
+
23
class SignificanceLevel(float, Enum):
    """Standard significance levels (alpha) for hypothesis testing.

    Members are ``float``-valued so they can be used directly in
    numeric comparisons.
    """

    LEVEL_01 = 0.01  # 1% significance
    LEVEL_05 = 0.05  # 5% significance (conventional default)
    LEVEL_10 = 0.10  # 10% significance
29
+
30
+
31
class CorrelationMethod(str, Enum):
    """Correlation calculation methods.

    Each member's value is the lowercase string code for the method.
    """

    PEARSON = "pearson"  # linear correlation
    SPEARMAN = "spearman"  # rank-based correlation
    KENDALL = "kendall"  # ordinal (tau) correlation
37
+
38
+
39
class StationarityTest(str, Enum):
    """Stationarity test types.

    Each member's value is the short string code for the test.
    """

    ADF = "adf"  # Augmented Dickey-Fuller
    KPSS = "kpss"  # Kwiatkowski-Phillips-Schmidt-Shin
    PP = "pp"  # Phillips-Perron
45
+
46
+
47
class RegressionType(str, Enum):
    """Regression types for stationarity tests.

    Values are the single/double-letter codes describing which
    deterministic terms are included in the test regression.
    """

    CONSTANT = "c"  # Constant only
    CONSTANT_TREND = "ct"  # Constant and trend
    CONSTANT_TREND_SQUARED = "ctt"  # Constant, trend, and trend squared
    NONE = "n"  # No constant or trend
54
+
55
+
56
class ClusteringMethod(str, Enum):
    """Clustering algorithm types.

    Each member's value is the lowercase string code for the algorithm.
    """

    HIERARCHICAL = "hierarchical"  # agglomerative clustering
    KMEANS = "kmeans"  # k-means clustering
    DBSCAN = "dbscan"  # density-based clustering
62
+
63
+
64
class LinkageMethod(str, Enum):
    """Linkage methods for hierarchical clustering.

    Each member's value is the lowercase string code for the linkage rule.
    """

    WARD = "ward"  # minimize within-cluster variance
    COMPLETE = "complete"  # maximum pairwise distance
    AVERAGE = "average"  # mean pairwise distance
    SINGLE = "single"  # minimum pairwise distance
71
+
72
+
73
class DistanceMetric(str, Enum):
    """Distance metrics for clustering.

    Each member's value is the lowercase string code for the metric.
    """

    EUCLIDEAN = "euclidean"  # L2 distance
    CORRELATION = "correlation"  # correlation-based distance
    MANHATTAN = "manhattan"  # L1 distance
    COSINE = "cosine"  # cosine distance
80
+
81
+
82
class NormalityTest(str, Enum):
    """Normality test types.

    Values are the string codes for the tests; note the abbreviated
    value ``"ks"`` for Kolmogorov-Smirnov.
    """

    JARQUE_BERA = "jarque_bera"  # skewness/kurtosis based
    SHAPIRO = "shapiro"  # Shapiro-Wilk
    KOLMOGOROV_SMIRNOV = "ks"  # Kolmogorov-Smirnov
    ANDERSON = "anderson"  # Anderson-Darling
89
+
90
+
91
class OutlierMethod(str, Enum):
    """Outlier detection methods.

    Each member's value is the lowercase string code for the method.
    """

    ZSCORE = "zscore"  # standard-score threshold
    IQR = "iqr"  # interquartile-range fences
    ISOLATION_FOREST = "isolation_forest"  # tree-ensemble isolation
97
+
98
+
99
class VolatilityClusterMethod(str, Enum):
    """Methods for detecting volatility clustering.

    Each member's value is the lowercase string code for the test.
    """

    LJUNG_BOX = "ljung_box"  # Ljung-Box test on squared residuals
    ENGLE_ARCH = "engle_arch"  # Engle's ARCH LM test
104
+
105
+
106
class ThresholdOptimizationTarget(str, Enum):
    """Optimization targets for threshold analysis.

    Values are the string codes for each objective; note the abbreviated
    value ``"ic"`` for the information coefficient.
    """

    SHARPE = "sharpe"  # maximize Sharpe ratio
    PRECISION = "precision"  # maximize precision
    RECALL = "recall"  # maximize recall
    F1 = "f1"  # maximize F1 score
    INFORMATION_COEFFICIENT = "ic"  # maximize information coefficient
114
+
115
+
116
class DriftDetectionMethod(str, Enum):
    """Feature drift detection methods.

    Values are the string codes for each method; note the abbreviated
    value ``"ks"`` for Kolmogorov-Smirnov.
    """

    KOLMOGOROV_SMIRNOV = "ks"  # two-sample KS test
    WASSERSTEIN = "wasserstein"  # Wasserstein (earth-mover) distance
    PSI = "psi"  # Population Stability Index
122
+
123
+
124
class PortfolioMetric(str, Enum):
    """Portfolio performance metrics.

    Values are the string codes for each metric; note the abbreviated
    value ``"max_dd"`` for maximum drawdown.
    """

    SHARPE = "sharpe"  # Sharpe ratio
    SORTINO = "sortino"  # Sortino ratio (downside deviation)
    CALMAR = "calmar"  # Calmar ratio (return / max drawdown)
    MAX_DRAWDOWN = "max_dd"  # maximum drawdown
    VAR = "var"  # Value at Risk
    CVAR = "cvar"  # Conditional Value at Risk
    OMEGA = "omega"  # Omega ratio
134
+
135
+
136
class TimeFrequency(str, Enum):
    """Time aggregation frequencies.

    Each member's value is the lowercase string code for the frequency.
    """

    DAILY = "daily"
    WEEKLY = "weekly"
    MONTHLY = "monthly"
    QUARTERLY = "quarterly"
    ANNUAL = "annual"
144
+
145
+
146
class FDRMethod(str, Enum):
    """False Discovery Rate control methods.

    Values are the string codes for each procedure; note the abbreviated
    values ``"bh"`` and ``"by"``.
    """

    BONFERRONI = "bonferroni"  # Bonferroni correction
    HOLM = "holm"  # Holm step-down
    BENJAMINI_HOCHBERG = "bh"  # Benjamini-Hochberg step-up
    BENJAMINI_YEKUTIELI = "by"  # Benjamini-Yekutieli (dependence-robust)
153
+
154
+
155
class BayesianPriorDistribution(str, Enum):
    """Prior distributions for Bayesian analysis.

    Each member's value is the lowercase string code for the distribution.
    """

    NORMAL = "normal"  # Gaussian prior
    STUDENT_T = "student_t"  # heavy-tailed Student-t prior
    UNIFORM = "uniform"  # flat prior
161
+
162
+
163
class ReportFormat(str, Enum):
    """Report output formats.

    Each member's value is the lowercase string code for the format.
    """

    HTML = "html"
    JSON = "json"
    PDF = "pdf"
169
+
170
+
171
class ReportTemplate(str, Enum):
    """Report templates.

    Each member's value is the lowercase string code for the template.
    """

    FULL = "full"  # complete report
    SUMMARY = "summary"  # condensed report
    DIAGNOSTIC = "diagnostic"  # diagnostics-focused report
177
+
178
+
179
class ReportTheme(str, Enum):
    """Report visual themes.

    Each member's value is the lowercase string code for the theme.
    """

    LIGHT = "light"
    DARK = "dark"
    PROFESSIONAL = "professional"
185
+
186
+
187
class TableFormat(str, Enum):
    """Table formatting styles.

    Each member's value is the lowercase string code for the style.
    """

    STYLED = "styled"  # formatted/styled tables
    PLAIN = "plain"  # unstyled tables
    DATATABLES = "datatables"  # interactive DataTables rendering
193
+
194
+
195
class DataFrameExportFormat(str, Enum):
    """DataFrame serialization formats for JSON.

    The value names mirror pandas-style ``orient`` codes; the inline
    comments show the resulting JSON layout.
    """

    RECORDS = "records"  # list of dicts
    SPLIT = "split"  # {index: [...], columns: [...], data: [...]}
    INDEX = "index"  # {index: {column: value}}
201
+
202
+
203
def validate_positive_int(v: int, field_name: str = "value") -> int:
    """Validate that an integer is strictly positive.

    Args:
        v: Value to validate
        field_name: Name of field for error messages

    Returns:
        The validated value, unchanged

    Raises:
        ValueError: If value is not positive
    """
    # Guard clause: accept the happy path first.
    if v > 0:
        return v
    raise ValueError(f"{field_name} must be positive (got {v})")
219
+
220
+
221
def validate_probability(v: float, field_name: str = "probability") -> float:
    """Validate that a float lies in the closed interval [0, 1].

    Args:
        v: Value to validate
        field_name: Name of field for error messages

    Returns:
        The validated value, unchanged

    Raises:
        ValueError: If value is not in [0, 1]
    """
    # Explicit bound checks instead of a chained comparison.
    if v < 0.0 or v > 1.0:
        raise ValueError(f"{field_name} must be in [0, 1] (got {v})")
    return v
237
+
238
+
239
def validate_significance_level(v: float) -> float:
    """Validate that a significance level is one of the standard values.

    Args:
        v: Significance level

    Returns:
        The validated significance level, unchanged

    Raises:
        ValueError: If not one of the standard levels 0.01, 0.05, 0.10
    """
    # Accept only the three conventional alpha levels.
    if v in (0.01, 0.05, 0.10):
        return v
    raise ValueError(
        f"Significance level {v} is non-standard. Consider using 0.01, 0.05, or 0.10 for interpretability."
    )
257
+
258
+
259
def validate_min_max_range(
    min_val: float, max_val: float, field_prefix: str = "range"
) -> tuple[float, float]:
    """Validate that a (min, max) pair forms a non-empty range.

    Args:
        min_val: Minimum value
        max_val: Maximum value
        field_prefix: Prefix for error messages

    Returns:
        The validated (min, max) tuple

    Raises:
        ValueError: If min >= max
    """
    if min_val >= max_val:
        # Build the message separately to keep the raise line short.
        message = (
            f"{field_prefix}_min must be < {field_prefix}_max (got {min_val} >= {max_val})"
        )
        raise ValueError(message)
    return min_val, max_val
@@ -0,0 +1,29 @@
1
+ """Core functionality for ml4t-diagnostic.
2
+
3
+ This module contains the fundamental logic for purging, embargo, and sampling
4
+ that underlies all cross-validation splitters.
5
+ """
6
+
7
+ from ml4t.diagnostic.core.purging import (
8
+ apply_purging_and_embargo,
9
+ calculate_embargo_indices,
10
+ calculate_purge_indices,
11
+ )
12
+ from ml4t.diagnostic.core.sampling import (
13
+ balanced_subsample,
14
+ block_bootstrap,
15
+ event_based_sample,
16
+ sample_weights_by_importance,
17
+ stratified_sample_time_series,
18
+ )
19
+
20
+ __all__: list[str] = [
21
+ "apply_purging_and_embargo",
22
+ "balanced_subsample",
23
+ "calculate_embargo_indices",
24
+ "calculate_purge_indices",
25
+ "event_based_sample",
26
+ "sample_weights_by_importance",
27
+ "block_bootstrap",
28
+ "stratified_sample_time_series",
29
+ ]
@@ -0,0 +1,315 @@
1
+ """Numba-optimized utility functions for ML4T Diagnostic.
2
+
3
+ This module contains JIT-compiled functions for performance-critical operations.
4
+ Numba is used to optimize computationally intensive loops and array operations.
5
+
6
+ Note: Numba functions work best with NumPy arrays and simple Python types.
7
+ They cannot handle Pandas objects directly.
8
+ """
9
+
10
+ import numpy as np
11
+ from numba import jit
12
+
13
+
14
@jit(nopython=True, cache=True)
def calculate_drawdown_numba(
    cum_returns: np.ndarray,
) -> tuple[float, int, int, int]:
    """Numba-optimized maximum drawdown calculation.

    Single forward pass: tracks the running peak of the series and records
    the deepest peak-to-current decline observed.

    Parameters
    ----------
    cum_returns : np.ndarray
        Array of cumulative returns

    Returns
    -------
    Tuple[float, int, int, int]
        (max_drawdown, duration, peak_idx, trough_idx).  ``max_drawdown``
        is the additive difference ``cum_returns[trough_idx] -
        cum_returns[peak_idx]`` (<= 0), not a percentage.  ``duration`` is
        the number of steps from that peak to its trough.  Empty input
        yields ``(nan, -1, -1, -1)``.
    """
    n = len(cum_returns)
    if n == 0:
        # Sentinel values: no data, no drawdown.
        return np.nan, -1, -1, -1

    max_drawdown = 0.0
    max_duration = 0
    peak_idx = 0
    trough_idx = 0
    current_peak = cum_returns[0]
    current_peak_idx = 0

    for i in range(1, n):
        # Update peak if necessary
        if cum_returns[i] > current_peak:
            current_peak = cum_returns[i]
            current_peak_idx = i

        # Current decline relative to the most recent peak (<= 0).
        drawdown = cum_returns[i] - current_peak

        # Record the deepest decline seen so far, along with the peak and
        # trough positions that produced it.
        if drawdown < max_drawdown:
            max_drawdown = drawdown
            peak_idx = current_peak_idx
            trough_idx = i
            max_duration = i - current_peak_idx

    return max_drawdown, max_duration, peak_idx, trough_idx
58
+
59
+
60
@jit(nopython=True, cache=True)
def purge_indices_numba(
    test_start: int,
    _test_end: int,
    label_horizon: int,
    n_samples: int,
) -> np.ndarray:
    """Numba-optimized calculation of purge indices.

    Purging removes the training samples immediately before the test window
    whose forward-looking labels would overlap it: the half-open window
    ``[test_start - label_horizon, test_start)`` clipped to ``[0, n_samples)``.

    Parameters
    ----------
    test_start : int
        Start index of test period
    _test_end : int
        End index of test period (unused here; the leading underscore marks
        it as such — presumably kept for a uniform call signature)
    label_horizon : int
        Forward-looking period of labels
    n_samples : int
        Total number of samples

    Returns
    -------
    np.ndarray
        int64 array of indices to purge (empty when the window is empty)
    """
    purge_start = max(0, test_start - label_horizon)
    purge_end = min(test_start, n_samples)

    # Degenerate window (e.g. test_start <= 0 or zero horizon): nothing to purge.
    if purge_start >= purge_end:
        return np.empty(0, dtype=np.int64)

    return np.arange(purge_start, purge_end, dtype=np.int64)
92
+
93
+
94
@jit(nopython=True, cache=True)
def embargo_indices_numba(
    test_end: int,
    embargo_size: int,
    n_samples: int,
) -> np.ndarray:
    """Numba-optimized calculation of embargo indices.

    The embargo window covers the samples immediately after the test set,
    clipped to the end of the data.

    Parameters
    ----------
    test_end : int
        End index of test period
    embargo_size : int
        Number of samples to embargo after test set
    n_samples : int
        Total number of samples

    Returns
    -------
    np.ndarray
        int64 array of indices to embargo (empty when the window is empty)
    """
    window_stop = min(test_end + embargo_size, n_samples)

    # Nothing to embargo when the clipped window collapses.
    if window_stop <= test_end:
        return np.empty(0, dtype=np.int64)

    return np.arange(test_end, window_stop, dtype=np.int64)
123
+
124
+
125
@jit(nopython=True, cache=True, parallel=True)
def block_bootstrap_numba(
    indices: np.ndarray,
    n_samples: int,
    sample_length: int,
    seed: int,
) -> np.ndarray:
    """Numba-optimized block bootstrap sampling.

    Fills the output by repeatedly choosing a random block start and copying
    ``sample_length`` consecutive entries from ``indices``, preserving local
    ordering within each block.

    Parameters
    ----------
    indices : np.ndarray
        Array of indices to sample from
    n_samples : int
        Number of bootstrap samples to generate
    sample_length : int
        Length of each sequential sample
    seed : int
        Random seed for reproducibility

    Returns
    -------
    np.ndarray
        Bootstrap sample indices

    Notes
    -----
    When ``sample_length >= len(indices)`` the function falls back to a
    deterministic tiling of ``indices`` — no randomness is used on that path.
    NOTE(review): ``parallel=True`` is set but the loops below are plain
    ``range`` loops, so this presumably compiles serially — confirm intent.
    """
    # Seed the RNG used by the np.random calls below.
    np.random.seed(seed)
    n_indices = len(indices)

    # Handle edge cases: the requested block is at least as long as the pool.
    if sample_length >= n_indices:
        if n_samples <= n_indices:
            return indices[:n_samples].copy()
        # Repeat indices to meet n_samples requirement
        repeats = (n_samples // n_indices) + 1
        result = np.empty(repeats * n_indices, dtype=indices.dtype)
        for i in range(repeats):
            result[i * n_indices : (i + 1) * n_indices] = indices
        return result[:n_samples]

    # Pre-allocate result array
    result = np.empty(n_samples, dtype=indices.dtype)
    filled = 0

    while filled < n_samples:
        # Sample a random starting point; +1 keeps the full block in range.
        start_idx = np.random.randint(0, n_indices - sample_length + 1)

        # Determine how many samples to take (last block may be partial).
        samples_to_take = min(sample_length, n_samples - filled)

        # Copy sequential samples
        for i in range(samples_to_take):
            result[filled + i] = indices[start_idx + i]

        filled += samples_to_take

    return result
182
+
183
+
184
@jit(nopython=True, cache=True)
def rolling_sharpe_numba(
    returns: np.ndarray,
    window: int,
    risk_free_rate: float = 0.0,
    periods_per_year: int = 252,
) -> np.ndarray:
    """Numba-optimized rolling Sharpe ratio calculation.

    Parameters
    ----------
    returns : np.ndarray
        Array of returns
    window : int
        Rolling window size
    risk_free_rate : float
        Risk-free rate (annualized); converted to per-period by simple
        division, not compounding
    periods_per_year : int
        Number of periods per year for annualization

    Returns
    -------
    np.ndarray
        Array of rolling Sharpe ratios.  The first ``window - 1`` entries
        are NaN; a window with zero std is 0.0 when its mean excess return
        is also ~0, NaN otherwise.
    """
    n = len(returns)
    if n < window:
        # Not enough data for even one full window.
        return np.full(n, np.nan)

    result = np.full(n, np.nan)
    daily_rf = risk_free_rate / periods_per_year
    sqrt_periods = np.sqrt(periods_per_year)

    for i in range(window - 1, n):
        # Trailing window ending at (and including) position i.
        window_returns = returns[i - window + 1 : i + 1]
        excess_returns = window_returns - daily_rf

        mean_excess = np.mean(excess_returns)
        # np.std is the population (ddof=0) standard deviation.
        std_excess = np.std(excess_returns)

        if std_excess > 0:
            result[i] = mean_excess / std_excess * sqrt_periods
        else:
            # If std is zero, check if mean is also zero
            if abs(mean_excess) < 1e-10:
                result[i] = 0.0  # flat, zero-excess window: define Sharpe as 0
            else:
                result[i] = np.nan  # constant non-zero excess: undefined

    return result
234
+
235
+
236
@jit(nopython=True, cache=True, parallel=True)
def calculate_ic_vectorized(
    predictions: np.ndarray,
    returns: np.ndarray,
    method: int = 0,  # 0=pearson, 1=spearman
) -> float:
    """Numba-optimized Information Coefficient calculation.

    Pairwise-complete correlation between predictions and returns: rows
    where either input is NaN are dropped before correlating.

    Parameters
    ----------
    predictions : np.ndarray
        Array of predictions
    returns : np.ndarray
        Array of returns
    method : int
        0 for Pearson, 1 for Spearman

    Returns
    -------
    float
        Information coefficient.  NaN when the inputs have mismatched
        lengths or fewer than two valid pairs; 0.0 when either side has
        zero variance.

    Notes
    -----
    NOTE(review): ``parallel=True`` is set but there is no ``prange`` loop
    here — presumably compiles serially; confirm intent.
    """
    n = len(predictions)
    if n != len(returns) or n < 2:
        return np.nan

    # Remove NaN values (pairwise deletion).
    valid_mask = ~(np.isnan(predictions) | np.isnan(returns))
    pred_clean = predictions[valid_mask]
    ret_clean = returns[valid_mask]

    if len(pred_clean) < 2:
        return np.nan

    if method == 1:  # Spearman
        # Spearman = Pearson computed on average ranks (ties share a rank).
        pred_clean = _rank_data_numba(pred_clean)
        ret_clean = _rank_data_numba(ret_clean)

    # Calculate Pearson correlation
    pred_mean = np.mean(pred_clean)
    ret_mean = np.mean(ret_clean)

    numerator = np.sum((pred_clean - pred_mean) * (ret_clean - ret_mean))
    denominator = np.sqrt(
        np.sum((pred_clean - pred_mean) ** 2) * np.sum((ret_clean - ret_mean) ** 2)
    )

    if denominator == 0:
        # Zero variance on at least one side: correlation undefined, report 0.
        return 0.0

    return numerator / denominator
289
+
290
@jit(nopython=True, cache=True)
def _rank_data_numba(data: np.ndarray) -> np.ndarray:
    """Helper function to rank data for Spearman correlation.

    Returns 1-based ranks; tied values receive the mean of the ranks they
    would otherwise occupy (the "average" tie method).
    """
    n = len(data)
    indices = np.argsort(data)
    ranks = np.empty(n)

    # Provisional 1-based ranks in sort order (ties still arbitrary here).
    for i in range(n):
        ranks[indices[i]] = i + 1

    # Handle ties by averaging ranks
    sorted_data = data[indices]
    i = 0
    while i < n:
        j = i
        # Find all equal values
        while j < n - 1 and sorted_data[j] == sorted_data[j + 1]:
            j += 1
        # Average ranks for ties
        if i != j:
            # Mean of the first and last provisional rank of the tied run
            # equals the mean of the whole run (consecutive integers).
            avg_rank = (ranks[indices[i]] + ranks[indices[j]]) / 2
            for k in range(i, j + 1):
                ranks[indices[k]] = avg_rank
        i = j + 1

    return ranks