ml4t-diagnostic 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (242) hide show
  1. ml4t/diagnostic/AGENT.md +25 -0
  2. ml4t/diagnostic/__init__.py +166 -0
  3. ml4t/diagnostic/backends/__init__.py +10 -0
  4. ml4t/diagnostic/backends/adapter.py +192 -0
  5. ml4t/diagnostic/backends/polars_backend.py +899 -0
  6. ml4t/diagnostic/caching/__init__.py +40 -0
  7. ml4t/diagnostic/caching/cache.py +331 -0
  8. ml4t/diagnostic/caching/decorators.py +131 -0
  9. ml4t/diagnostic/caching/smart_cache.py +339 -0
  10. ml4t/diagnostic/config/AGENT.md +24 -0
  11. ml4t/diagnostic/config/README.md +267 -0
  12. ml4t/diagnostic/config/__init__.py +219 -0
  13. ml4t/diagnostic/config/barrier_config.py +277 -0
  14. ml4t/diagnostic/config/base.py +301 -0
  15. ml4t/diagnostic/config/event_config.py +148 -0
  16. ml4t/diagnostic/config/feature_config.py +404 -0
  17. ml4t/diagnostic/config/multi_signal_config.py +55 -0
  18. ml4t/diagnostic/config/portfolio_config.py +215 -0
  19. ml4t/diagnostic/config/report_config.py +391 -0
  20. ml4t/diagnostic/config/sharpe_config.py +202 -0
  21. ml4t/diagnostic/config/signal_config.py +206 -0
  22. ml4t/diagnostic/config/trade_analysis_config.py +310 -0
  23. ml4t/diagnostic/config/validation.py +279 -0
  24. ml4t/diagnostic/core/__init__.py +29 -0
  25. ml4t/diagnostic/core/numba_utils.py +315 -0
  26. ml4t/diagnostic/core/purging.py +372 -0
  27. ml4t/diagnostic/core/sampling.py +471 -0
  28. ml4t/diagnostic/errors/__init__.py +205 -0
  29. ml4t/diagnostic/evaluation/AGENT.md +26 -0
  30. ml4t/diagnostic/evaluation/__init__.py +437 -0
  31. ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
  32. ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
  33. ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
  34. ml4t/diagnostic/evaluation/dashboard.py +715 -0
  35. ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
  36. ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
  37. ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
  38. ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
  39. ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
  40. ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
  41. ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
  42. ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
  43. ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
  44. ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
  45. ml4t/diagnostic/evaluation/event_analysis.py +647 -0
  46. ml4t/diagnostic/evaluation/excursion.py +390 -0
  47. ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
  48. ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
  49. ml4t/diagnostic/evaluation/framework.py +935 -0
  50. ml4t/diagnostic/evaluation/metric_registry.py +255 -0
  51. ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
  52. ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
  53. ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
  54. ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
  55. ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
  56. ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
  57. ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
  58. ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
  59. ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
  60. ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
  61. ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
  62. ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
  63. ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
  64. ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
  65. ml4t/diagnostic/evaluation/multi_signal.py +550 -0
  66. ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
  67. ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
  68. ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
  69. ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
  70. ml4t/diagnostic/evaluation/report_generation.py +824 -0
  71. ml4t/diagnostic/evaluation/signal_selector.py +452 -0
  72. ml4t/diagnostic/evaluation/stat_registry.py +139 -0
  73. ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
  74. ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
  75. ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
  76. ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
  77. ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
  78. ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
  79. ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
  80. ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
  81. ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
  82. ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
  83. ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
  84. ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
  85. ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
  86. ml4t/diagnostic/evaluation/stats/moments.py +164 -0
  87. ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
  88. ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
  89. ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
  90. ml4t/diagnostic/evaluation/themes.py +330 -0
  91. ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
  92. ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
  93. ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
  94. ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
  95. ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
  96. ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
  97. ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
  98. ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
  99. ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
  100. ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
  101. ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
  102. ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
  103. ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
  104. ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
  105. ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
  106. ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
  107. ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
  108. ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
  109. ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
  110. ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
  111. ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
  112. ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
  113. ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
  114. ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
  115. ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
  116. ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
  117. ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
  118. ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
  119. ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
  120. ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
  121. ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
  122. ml4t/diagnostic/evaluation/validated_cv.py +535 -0
  123. ml4t/diagnostic/evaluation/visualization.py +1050 -0
  124. ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
  125. ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
  126. ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
  127. ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
  128. ml4t/diagnostic/integration/__init__.py +48 -0
  129. ml4t/diagnostic/integration/backtest_contract.py +671 -0
  130. ml4t/diagnostic/integration/data_contract.py +316 -0
  131. ml4t/diagnostic/integration/engineer_contract.py +226 -0
  132. ml4t/diagnostic/logging/__init__.py +77 -0
  133. ml4t/diagnostic/logging/logger.py +245 -0
  134. ml4t/diagnostic/logging/performance.py +234 -0
  135. ml4t/diagnostic/logging/progress.py +234 -0
  136. ml4t/diagnostic/logging/wandb.py +412 -0
  137. ml4t/diagnostic/metrics/__init__.py +9 -0
  138. ml4t/diagnostic/metrics/percentiles.py +128 -0
  139. ml4t/diagnostic/py.typed +1 -0
  140. ml4t/diagnostic/reporting/__init__.py +43 -0
  141. ml4t/diagnostic/reporting/base.py +130 -0
  142. ml4t/diagnostic/reporting/html_renderer.py +275 -0
  143. ml4t/diagnostic/reporting/json_renderer.py +51 -0
  144. ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
  145. ml4t/diagnostic/results/AGENT.md +24 -0
  146. ml4t/diagnostic/results/__init__.py +105 -0
  147. ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
  148. ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
  149. ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
  150. ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
  151. ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
  152. ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
  153. ml4t/diagnostic/results/barrier_results/validation.py +38 -0
  154. ml4t/diagnostic/results/base.py +177 -0
  155. ml4t/diagnostic/results/event_results.py +349 -0
  156. ml4t/diagnostic/results/feature_results.py +787 -0
  157. ml4t/diagnostic/results/multi_signal_results.py +431 -0
  158. ml4t/diagnostic/results/portfolio_results.py +281 -0
  159. ml4t/diagnostic/results/sharpe_results.py +448 -0
  160. ml4t/diagnostic/results/signal_results/__init__.py +74 -0
  161. ml4t/diagnostic/results/signal_results/ic.py +581 -0
  162. ml4t/diagnostic/results/signal_results/irtc.py +110 -0
  163. ml4t/diagnostic/results/signal_results/quantile.py +392 -0
  164. ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
  165. ml4t/diagnostic/results/signal_results/turnover.py +213 -0
  166. ml4t/diagnostic/results/signal_results/validation.py +147 -0
  167. ml4t/diagnostic/signal/AGENT.md +17 -0
  168. ml4t/diagnostic/signal/__init__.py +69 -0
  169. ml4t/diagnostic/signal/_report.py +152 -0
  170. ml4t/diagnostic/signal/_utils.py +261 -0
  171. ml4t/diagnostic/signal/core.py +275 -0
  172. ml4t/diagnostic/signal/quantile.py +148 -0
  173. ml4t/diagnostic/signal/result.py +214 -0
  174. ml4t/diagnostic/signal/signal_ic.py +129 -0
  175. ml4t/diagnostic/signal/turnover.py +182 -0
  176. ml4t/diagnostic/splitters/AGENT.md +19 -0
  177. ml4t/diagnostic/splitters/__init__.py +36 -0
  178. ml4t/diagnostic/splitters/base.py +501 -0
  179. ml4t/diagnostic/splitters/calendar.py +421 -0
  180. ml4t/diagnostic/splitters/calendar_config.py +91 -0
  181. ml4t/diagnostic/splitters/combinatorial.py +1064 -0
  182. ml4t/diagnostic/splitters/config.py +322 -0
  183. ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
  184. ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
  185. ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
  186. ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
  187. ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
  188. ml4t/diagnostic/splitters/group_isolation.py +329 -0
  189. ml4t/diagnostic/splitters/persistence.py +316 -0
  190. ml4t/diagnostic/splitters/utils.py +207 -0
  191. ml4t/diagnostic/splitters/walk_forward.py +757 -0
  192. ml4t/diagnostic/utils/__init__.py +42 -0
  193. ml4t/diagnostic/utils/config.py +542 -0
  194. ml4t/diagnostic/utils/dependencies.py +318 -0
  195. ml4t/diagnostic/utils/sessions.py +127 -0
  196. ml4t/diagnostic/validation/__init__.py +54 -0
  197. ml4t/diagnostic/validation/dataframe.py +274 -0
  198. ml4t/diagnostic/validation/returns.py +280 -0
  199. ml4t/diagnostic/validation/timeseries.py +299 -0
  200. ml4t/diagnostic/visualization/AGENT.md +19 -0
  201. ml4t/diagnostic/visualization/__init__.py +223 -0
  202. ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
  203. ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
  204. ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
  205. ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
  206. ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
  207. ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
  208. ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
  209. ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
  210. ml4t/diagnostic/visualization/barrier_plots.py +782 -0
  211. ml4t/diagnostic/visualization/core.py +1060 -0
  212. ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
  213. ml4t/diagnostic/visualization/dashboards/base.py +582 -0
  214. ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
  215. ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
  216. ml4t/diagnostic/visualization/dashboards.py +43 -0
  217. ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
  218. ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
  219. ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
  220. ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
  221. ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
  222. ml4t/diagnostic/visualization/feature_plots.py +888 -0
  223. ml4t/diagnostic/visualization/interaction_plots.py +618 -0
  224. ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
  225. ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
  226. ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
  227. ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
  228. ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
  229. ml4t/diagnostic/visualization/report_generation.py +1343 -0
  230. ml4t/diagnostic/visualization/signal/__init__.py +103 -0
  231. ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
  232. ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
  233. ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
  234. ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
  235. ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
  236. ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
  237. ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
  238. ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
  239. ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
  240. ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
  241. ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
  242. ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,255 @@
1
+ """Metric registry for evaluation metrics with metadata.
2
+
3
+ This module provides a centralized registry for evaluation metrics,
4
+ including directionality (whether higher is better) and tier defaults.
5
+ """
6
+
7
+ from collections.abc import Callable
8
+ from typing import Any
9
+
10
+
11
+ class MetricRegistry:
12
+ """Registry of evaluation metrics with metadata.
13
+
14
+ The MetricRegistry provides a centralized place to register and query
15
+ metrics, including their computation functions, directionality (whether
16
+ higher values are better), and tier defaults.
17
+
18
+ Attributes
19
+ ----------
20
+ _metrics : dict[str, Callable]
21
+ Mapping of metric names to computation functions
22
+ _directionality : dict[str, bool]
23
+ Mapping of metric names to directionality (True = higher is better)
24
+ _tier_defaults : dict[int, list[str]]
25
+ Default metrics for each evaluation tier
26
+
27
+ Examples
28
+ --------
29
+ >>> registry = MetricRegistry()
30
+ >>> registry.register("sharpe", sharpe_func, maximize=True, tiers=[1, 2, 3])
31
+ >>> func = registry.get("sharpe")
32
+ >>> registry.is_maximize("sharpe")
33
+ True
34
+ """
35
+
36
+ _instance: "MetricRegistry | None" = None
37
+
38
+ def __init__(self) -> None:
39
+ """Initialize empty registry."""
40
+ self._metrics: dict[str, Callable[..., Any]] = {}
41
+ self._directionality: dict[str, bool] = {}
42
+ self._tier_defaults: dict[int, list[str]] = {1: [], 2: [], 3: []}
43
+
44
+ @classmethod
45
+ def default(cls) -> "MetricRegistry":
46
+ """Get or create the default singleton registry instance.
47
+
48
+ Returns
49
+ -------
50
+ MetricRegistry
51
+ The default registry instance with standard metrics registered
52
+ """
53
+ if cls._instance is None:
54
+ cls._instance = cls()
55
+ cls._instance._register_defaults()
56
+ return cls._instance
57
+
58
+ @classmethod
59
+ def reset_default(cls) -> None:
60
+ """Reset the default singleton instance (primarily for testing)."""
61
+ cls._instance = None
62
+
63
+ def register(
64
+ self,
65
+ name: str,
66
+ func: Callable[..., Any],
67
+ maximize: bool = True,
68
+ tiers: list[int] | None = None,
69
+ ) -> None:
70
+ """Register a metric with the registry.
71
+
72
+ Parameters
73
+ ----------
74
+ name : str
75
+ Unique name for the metric
76
+ func : Callable
77
+ Function that computes the metric.
78
+ Signature: (predictions, actual, strategy_returns) -> float
79
+ maximize : bool, default True
80
+ Whether higher values are better (True) or lower (False)
81
+ tiers : list[int], optional
82
+ Evaluation tiers where this metric is a default
83
+ """
84
+ self._metrics[name] = func
85
+ self._directionality[name] = maximize
86
+ if tiers:
87
+ for tier in tiers:
88
+ if tier in self._tier_defaults and name not in self._tier_defaults[tier]:
89
+ self._tier_defaults[tier].append(name)
90
+
91
+ def get(self, name: str) -> Callable[..., Any]:
92
+ """Get a metric function by name.
93
+
94
+ Parameters
95
+ ----------
96
+ name : str
97
+ Name of the metric
98
+
99
+ Returns
100
+ -------
101
+ Callable
102
+ The metric computation function
103
+
104
+ Raises
105
+ ------
106
+ KeyError
107
+ If metric name is not registered
108
+ """
109
+ if name not in self._metrics:
110
+ raise KeyError(f"Unknown metric: {name}. Available: {list(self._metrics.keys())}")
111
+ return self._metrics[name]
112
+
113
+ def is_maximize(self, name: str) -> bool:
114
+ """Get whether higher values are better for a metric.
115
+
116
+ Parameters
117
+ ----------
118
+ name : str
119
+ Name of the metric
120
+
121
+ Returns
122
+ -------
123
+ bool
124
+ True if higher values are better, False otherwise
125
+ """
126
+ if name in self._directionality:
127
+ return self._directionality[name]
128
+ return self._infer_directionality(name)
129
+
130
+ def _infer_directionality(self, name: str) -> bool:
131
+ """Infer directionality for unknown metrics based on naming conventions."""
132
+ normalized = name.lower().replace("-", "_").replace(" ", "_")
133
+
134
+ if any(term in normalized for term in ["drawdown", "risk", "error", "loss", "volatility"]):
135
+ return False
136
+
137
+ if any(
138
+ term in normalized
139
+ for term in ["return", "profit", "gain", "ratio", "score", "coefficient"]
140
+ ):
141
+ return True
142
+
143
+ return True # Default to higher is better
144
+
145
+ def get_by_tier(self, tier: int) -> list[str]:
146
+ """Get default metrics for a specific tier.
147
+
148
+ Parameters
149
+ ----------
150
+ tier : int
151
+ Evaluation tier (1, 2, or 3)
152
+
153
+ Returns
154
+ -------
155
+ list[str]
156
+ List of default metric names for the tier
157
+ """
158
+ return self._tier_defaults.get(tier, []).copy()
159
+
160
+ def list_metrics(self) -> list[str]:
161
+ """List all registered metric names.
162
+
163
+ Returns
164
+ -------
165
+ list[str]
166
+ Sorted list of metric names
167
+ """
168
+ return sorted(self._metrics.keys())
169
+
170
+ def __contains__(self, name: str) -> bool:
171
+ """Check if a metric is registered."""
172
+ return name in self._metrics
173
+
174
+ def _register_defaults(self) -> None:
175
+ """Register default metrics."""
176
+ from . import metrics
177
+
178
+ # Core metrics with tier assignments
179
+ self.register(
180
+ "ic",
181
+ lambda pred, actual, _returns: metrics.information_coefficient(pred, actual),
182
+ maximize=True,
183
+ tiers=[1, 2, 3],
184
+ )
185
+ self.register(
186
+ "sharpe",
187
+ lambda _pred, _actual, returns: metrics.sharpe_ratio(returns),
188
+ maximize=True,
189
+ tiers=[1, 2],
190
+ )
191
+ self.register(
192
+ "sortino",
193
+ lambda _pred, _actual, returns: metrics.sortino_ratio(returns),
194
+ maximize=True,
195
+ tiers=[1],
196
+ )
197
+ self.register(
198
+ "max_drawdown",
199
+ lambda _pred, _actual, returns: metrics.maximum_drawdown(returns),
200
+ maximize=False,
201
+ tiers=[1],
202
+ )
203
+ self.register(
204
+ "hit_rate",
205
+ lambda pred, actual, _returns: metrics.hit_rate(pred, actual),
206
+ maximize=True,
207
+ tiers=[1, 2, 3],
208
+ )
209
+
210
+ # Additional directionality mappings for common metric names
211
+ self._register_directionality_defaults()
212
+
213
+ def _register_directionality_defaults(self) -> None:
214
+ """Register directionality for common metric names (for get_metric_directionality)."""
215
+ # Performance metrics (higher is better)
216
+ for name in [
217
+ "sharpe_ratio",
218
+ "sortino_ratio",
219
+ "calmar",
220
+ "calmar_ratio",
221
+ "information_ratio",
222
+ "omega_ratio",
223
+ "profit_factor",
224
+ "total_return",
225
+ "mean_return",
226
+ "cumulative_return",
227
+ "annualized_return",
228
+ "win_rate",
229
+ "accuracy",
230
+ "information_coefficient",
231
+ "ic_mean",
232
+ "spearman",
233
+ "pearson",
234
+ "r_squared",
235
+ "r2",
236
+ "t_statistic",
237
+ "z_score",
238
+ ]:
239
+ self._directionality[name] = True
240
+
241
+ # Risk metrics (lower is better)
242
+ for name in [
243
+ "maximum_drawdown",
244
+ "drawdown",
245
+ "volatility",
246
+ "downside_deviation",
247
+ "value_at_risk",
248
+ "var",
249
+ "cvar",
250
+ "conditional_value_at_risk",
251
+ "tracking_error",
252
+ "beta",
253
+ "p_value",
254
+ ]:
255
+ self._directionality[name] = False
@@ -0,0 +1,23 @@
1
+ # metrics/ - Feature Metrics
2
+
3
+ IC, importance, and interaction analysis.
4
+
5
+ ## Modules
6
+
7
+ | File | Lines | Purpose |
8
+ |------|-------|---------|
9
+ | information_coefficient.py | 527 | Core IC functions |
10
+ | ic_statistics.py | 446 | HAC-adjusted IC, decay |
11
+ | conditional_ic.py | 469 | Conditional IC |
12
+ | importance_classical.py | 375 | PFI, MDI |
13
+ | importance_mda.py | 371 | Mean Decrease Accuracy |
14
+ | importance_shap.py | 715 | SHAP importance |
15
+ | importance_analysis.py | 338 | Multi-method comparison |
16
+ | interactions.py | 772 | H-statistic, SHAP interactions |
17
+ | feature_outcome.py | 475 | Feature-outcome analysis |
18
+ | monotonicity.py | 226 | Monotonicity tests |
19
+ | risk_adjusted.py | 324 | Sharpe, Sortino, drawdown |
20
+
21
+ ## Key Functions
22
+
23
+ `information_coefficient()`, `compute_ic_series()`, `analyze_ml_importance()`, `compute_h_statistic()`, `compute_shap_importance()`
@@ -0,0 +1,133 @@
1
+ """Core performance metrics for financial ML evaluation.
2
+
3
+ This package implements the core metrics used across ml4t-diagnostic's
4
+ Four-Tier Validation Framework:
5
+
6
+ - **Tier 3**: Fast screening metrics (IC, hit rate)
7
+ - **Tier 2**: Statistical significance metrics (HAC-adjusted IC, Sharpe with CI)
8
+ - **Tier 1**: Comprehensive metrics (deflated Sharpe, maximum drawdown)
9
+
10
+ All metrics are implemented with:
11
+ - Mathematical correctness validated by property-based tests
12
+ - Numerical stability for edge cases
13
+ - Polars-native implementation for performance
14
+ - Support for confidence intervals and statistical inference
15
+
16
+ Submodules
17
+ ----------
18
+ information_coefficient : Core Information Coefficient calculations
19
+ ic_statistics : HAC-adjusted IC and decay analysis
20
+ conditional_ic : IC conditional on feature quantiles
21
+ monotonicity : Monotonic relationship tests
22
+ risk_adjusted : Sharpe, Sortino, Maximum Drawdown
23
+ basic : Hit rate, forward returns
24
+ feature_outcome : Comprehensive feature-outcome analysis
25
+ importance_classical : Permutation and MDI importance
26
+ importance_shap : SHAP-based importance
27
+ importance_mda : Mean Decrease in Accuracy importance
28
+ importance_analysis : Multi-method importance comparison
29
+ interactions : Feature interaction detection (H-statistic, SHAP)
30
+ """
31
+
32
+ # IC metrics
33
+ # Re-export cov_hac from statsmodels for backward compatibility
34
+ from statsmodels.stats.sandwich_covariance import cov_hac
35
+
36
+ # Basic metrics
37
+ from ml4t.diagnostic.evaluation.metrics.basic import (
38
+ compute_forward_returns,
39
+ hit_rate,
40
+ )
41
+
42
+ # Conditional IC
43
+ from ml4t.diagnostic.evaluation.metrics.conditional_ic import (
44
+ compute_conditional_ic,
45
+ )
46
+
47
+ # Feature outcome analysis
48
+ from ml4t.diagnostic.evaluation.metrics.feature_outcome import (
49
+ analyze_feature_outcome,
50
+ )
51
+
52
+ # IC statistics
53
+ from ml4t.diagnostic.evaluation.metrics.ic_statistics import (
54
+ compute_ic_decay,
55
+ compute_ic_hac_stats,
56
+ )
57
+ from ml4t.diagnostic.evaluation.metrics.importance_analysis import (
58
+ analyze_ml_importance,
59
+ )
60
+
61
+ # Importance methods
62
+ from ml4t.diagnostic.evaluation.metrics.importance_classical import (
63
+ compute_mdi_importance,
64
+ compute_permutation_importance,
65
+ )
66
+ from ml4t.diagnostic.evaluation.metrics.importance_mda import (
67
+ compute_mda_importance,
68
+ )
69
+ from ml4t.diagnostic.evaluation.metrics.importance_shap import (
70
+ compute_shap_importance,
71
+ )
72
+ from ml4t.diagnostic.evaluation.metrics.information_coefficient import (
73
+ compute_ic_by_horizon,
74
+ compute_ic_ir,
75
+ compute_ic_series,
76
+ information_coefficient,
77
+ )
78
+
79
+ # Interaction detection
80
+ from ml4t.diagnostic.evaluation.metrics.interactions import (
81
+ analyze_interactions,
82
+ compute_h_statistic,
83
+ compute_shap_interactions,
84
+ )
85
+
86
+ # Monotonicity
87
+ from ml4t.diagnostic.evaluation.metrics.monotonicity import (
88
+ compute_monotonicity,
89
+ )
90
+
91
+ # Risk-adjusted metrics
92
+ from ml4t.diagnostic.evaluation.metrics.risk_adjusted import (
93
+ maximum_drawdown,
94
+ sharpe_ratio,
95
+ sharpe_ratio_with_ci,
96
+ sortino_ratio,
97
+ )
98
+
99
+ __all__ = [
100
+ # IC metrics
101
+ "information_coefficient",
102
+ "compute_ic_series",
103
+ "compute_ic_by_horizon",
104
+ "compute_ic_ir",
105
+ # IC statistics
106
+ "compute_ic_hac_stats",
107
+ "compute_ic_decay",
108
+ "cov_hac", # Re-exported from statsmodels for backward compatibility
109
+ # Conditional IC
110
+ "compute_conditional_ic",
111
+ # Monotonicity
112
+ "compute_monotonicity",
113
+ # Risk-adjusted
114
+ "sharpe_ratio",
115
+ "sharpe_ratio_with_ci",
116
+ "maximum_drawdown",
117
+ "sortino_ratio",
118
+ # Basic
119
+ "hit_rate",
120
+ "compute_forward_returns",
121
+ # Feature outcome
122
+ "analyze_feature_outcome",
123
+ # Importance
124
+ "compute_permutation_importance",
125
+ "compute_mdi_importance",
126
+ "compute_shap_importance",
127
+ "compute_mda_importance",
128
+ "analyze_ml_importance",
129
+ # Interactions
130
+ "compute_h_statistic",
131
+ "compute_shap_interactions",
132
+ "analyze_interactions",
133
+ ]
@@ -0,0 +1,160 @@
1
+ """Basic metrics: hit rate and forward returns calculation.
2
+
3
+ This module provides fundamental building blocks for feature evaluation.
4
+ """
5
+
6
+ from typing import TYPE_CHECKING, Union, cast
7
+
8
+ import numpy as np
9
+ import pandas as pd
10
+ import polars as pl
11
+
12
+ from ml4t.diagnostic.backends.adapter import DataFrameAdapter
13
+
14
+ if TYPE_CHECKING:
15
+ from numpy.typing import NDArray
16
+
17
+
18
def hit_rate(
    predictions: Union[pl.Series, pd.Series, "NDArray"],
    returns: Union[pl.Series, pd.Series, "NDArray"],
) -> float:
    """Return the share of predictions whose sign matches the realized return.

    Both inputs are converted to flat NumPy arrays, pairs containing a NaN are
    discarded, and the signs of the remaining pairs are compared.

    Parameters
    ----------
    predictions : Union[pl.Series, pd.Series, np.ndarray]
        Model predictions or scores
    returns : Union[pl.Series, pd.Series, np.ndarray]
        Forward returns corresponding to predictions

    Returns
    -------
    float
        Hit rate as a percentage (0-100), or NaN when no valid pair remains

    Raises
    ------
    ValueError
        If the two inputs differ in length

    Notes
    -----
    A pair in which the prediction or the return is exactly zero is always
    counted as a hit; this can only raise the reported hit rate.

    Examples
    --------
    >>> predictions = np.array([0.1, -0.2, 0.3, -0.1])
    >>> returns = np.array([0.02, -0.01, 0.05, 0.01])  # Note: last one wrong direction
    >>> hr = hit_rate(predictions, returns)
    >>> print(f"Hit Rate: {hr:.1f}%")
    Hit Rate: 75.0%
    """
    scores = DataFrameAdapter.to_numpy(predictions).flatten()
    outcomes = DataFrameAdapter.to_numpy(returns).flatten()

    if len(scores) != len(outcomes):
        raise ValueError("Predictions and returns must have the same length")

    # Keep only the pairs where both values are present.
    keep = ~np.isnan(scores) & ~np.isnan(outcomes)
    scores = scores[keep]
    outcomes = outcomes[keep]

    if len(scores) == 0:
        return np.nan

    # A hit is a sign match, or any pair with a zero on either side.
    same_sign = np.sign(scores) == np.sign(outcomes)
    has_zero = (scores == 0) | (outcomes == 0)
    hits = same_sign | has_zero

    return float(np.mean(hits) * 100)
77
+
78
+
79
def compute_forward_returns(
    prices: pl.DataFrame | pd.DataFrame,
    periods: int | list[int] = 1,
    price_col: str = "close",
    group_col: str | None = None,
) -> pl.DataFrame | pd.DataFrame:
    """Compute forward returns for given periods.

    This is a helper function for IC analysis, computing the forward-looking
    returns that will be correlated with predictions/features. For each
    period ``p`` the forward return of a row is ``price[t + p] / price[t] - 1``,
    where ``t + p`` is counted in rows (within the row's group when
    ``group_col`` is given). Rows without a price ``p`` rows ahead get a
    missing value. The input frame is not modified.

    Parameters
    ----------
    prices : Union[pl.DataFrame, pd.DataFrame]
        Price data with at least price_col and optionally group_col
    periods : Union[int, list[int]], default 1
        Forward periods to compute (e.g., [1, 5, 21] for 1d, 1w, 1m)
    price_col : str, default "close"
        Column name containing prices
    group_col : str | None, default None
        Column for grouping (e.g., 'symbol' for multi-asset)

    Returns
    -------
    Union[pl.DataFrame, pd.DataFrame]
        DataFrame with forward return columns: fwd_ret_1, fwd_ret_5, etc.
        The result has the same frame type as the input.

    Examples
    --------
    >>> prices = pl.DataFrame({
    ...     "date": ["2024-01-01", "2024-01-02", "2024-01-03"],
    ...     "close": [100.0, 102.0, 101.0]
    ... })
    >>> fwd_returns = compute_forward_returns(prices, periods=[1, 2])
    >>> print(fwd_returns.columns)
    ['date', 'close', 'fwd_ret_1', 'fwd_ret_2']
    """
    # Ensure periods is a list
    if isinstance(periods, int):
        periods = [periods]

    if isinstance(prices, pl.DataFrame):
        df = prices.clone()
        for period in periods:
            future_price = pl.col(price_col).shift(-period)
            if group_col is not None:
                # Look ahead within each group only (e.g., per symbol).
                future_price = future_price.over(group_col)
            df = df.with_columns(
                (future_price / pl.col(price_col) - 1).alias(f"fwd_ret_{period}")
            )
        return df

    # pandas - use different variable name to avoid type conflict
    df_pd = cast(pd.DataFrame, prices).copy()

    if group_col is not None:
        # The look-ahead shift must happen inside each group. Shifting the
        # whole column after a grouped pct_change would hand rows the return
        # of a different group whenever the groups are interleaved (e.g.,
        # rows ordered by date, then symbol). This mirrors the Polars branch.
        grouped_prices = df_pd.groupby(group_col)[price_col]
        for period in periods:
            future_price = grouped_prices.shift(-period)
            df_pd[f"fwd_ret_{period}"] = future_price / df_pd[price_col] - 1
    else:
        # Simple forward returns
        for period in periods:
            df_pd[f"fwd_ret_{period}"] = df_pd[price_col].pct_change(period).shift(-period)

    return df_pd