ml4t-diagnostic 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (242)
  1. ml4t/diagnostic/AGENT.md +25 -0
  2. ml4t/diagnostic/__init__.py +166 -0
  3. ml4t/diagnostic/backends/__init__.py +10 -0
  4. ml4t/diagnostic/backends/adapter.py +192 -0
  5. ml4t/diagnostic/backends/polars_backend.py +899 -0
  6. ml4t/diagnostic/caching/__init__.py +40 -0
  7. ml4t/diagnostic/caching/cache.py +331 -0
  8. ml4t/diagnostic/caching/decorators.py +131 -0
  9. ml4t/diagnostic/caching/smart_cache.py +339 -0
  10. ml4t/diagnostic/config/AGENT.md +24 -0
  11. ml4t/diagnostic/config/README.md +267 -0
  12. ml4t/diagnostic/config/__init__.py +219 -0
  13. ml4t/diagnostic/config/barrier_config.py +277 -0
  14. ml4t/diagnostic/config/base.py +301 -0
  15. ml4t/diagnostic/config/event_config.py +148 -0
  16. ml4t/diagnostic/config/feature_config.py +404 -0
  17. ml4t/diagnostic/config/multi_signal_config.py +55 -0
  18. ml4t/diagnostic/config/portfolio_config.py +215 -0
  19. ml4t/diagnostic/config/report_config.py +391 -0
  20. ml4t/diagnostic/config/sharpe_config.py +202 -0
  21. ml4t/diagnostic/config/signal_config.py +206 -0
  22. ml4t/diagnostic/config/trade_analysis_config.py +310 -0
  23. ml4t/diagnostic/config/validation.py +279 -0
  24. ml4t/diagnostic/core/__init__.py +29 -0
  25. ml4t/diagnostic/core/numba_utils.py +315 -0
  26. ml4t/diagnostic/core/purging.py +372 -0
  27. ml4t/diagnostic/core/sampling.py +471 -0
  28. ml4t/diagnostic/errors/__init__.py +205 -0
  29. ml4t/diagnostic/evaluation/AGENT.md +26 -0
  30. ml4t/diagnostic/evaluation/__init__.py +437 -0
  31. ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
  32. ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
  33. ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
  34. ml4t/diagnostic/evaluation/dashboard.py +715 -0
  35. ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
  36. ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
  37. ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
  38. ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
  39. ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
  40. ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
  41. ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
  42. ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
  43. ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
  44. ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
  45. ml4t/diagnostic/evaluation/event_analysis.py +647 -0
  46. ml4t/diagnostic/evaluation/excursion.py +390 -0
  47. ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
  48. ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
  49. ml4t/diagnostic/evaluation/framework.py +935 -0
  50. ml4t/diagnostic/evaluation/metric_registry.py +255 -0
  51. ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
  52. ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
  53. ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
  54. ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
  55. ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
  56. ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
  57. ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
  58. ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
  59. ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
  60. ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
  61. ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
  62. ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
  63. ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
  64. ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
  65. ml4t/diagnostic/evaluation/multi_signal.py +550 -0
  66. ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
  67. ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
  68. ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
  69. ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
  70. ml4t/diagnostic/evaluation/report_generation.py +824 -0
  71. ml4t/diagnostic/evaluation/signal_selector.py +452 -0
  72. ml4t/diagnostic/evaluation/stat_registry.py +139 -0
  73. ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
  74. ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
  75. ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
  76. ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
  77. ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
  78. ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
  79. ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
  80. ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
  81. ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
  82. ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
  83. ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
  84. ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
  85. ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
  86. ml4t/diagnostic/evaluation/stats/moments.py +164 -0
  87. ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
  88. ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
  89. ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
  90. ml4t/diagnostic/evaluation/themes.py +330 -0
  91. ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
  92. ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
  93. ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
  94. ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
  95. ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
  96. ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
  97. ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
  98. ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
  99. ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
  100. ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
  101. ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
  102. ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
  103. ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
  104. ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
  105. ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
  106. ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
  107. ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
  108. ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
  109. ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
  110. ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
  111. ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
  112. ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
  113. ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
  114. ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
  115. ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
  116. ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
  117. ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
  118. ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
  119. ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
  120. ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
  121. ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
  122. ml4t/diagnostic/evaluation/validated_cv.py +535 -0
  123. ml4t/diagnostic/evaluation/visualization.py +1050 -0
  124. ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
  125. ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
  126. ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
  127. ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
  128. ml4t/diagnostic/integration/__init__.py +48 -0
  129. ml4t/diagnostic/integration/backtest_contract.py +671 -0
  130. ml4t/diagnostic/integration/data_contract.py +316 -0
  131. ml4t/diagnostic/integration/engineer_contract.py +226 -0
  132. ml4t/diagnostic/logging/__init__.py +77 -0
  133. ml4t/diagnostic/logging/logger.py +245 -0
  134. ml4t/diagnostic/logging/performance.py +234 -0
  135. ml4t/diagnostic/logging/progress.py +234 -0
  136. ml4t/diagnostic/logging/wandb.py +412 -0
  137. ml4t/diagnostic/metrics/__init__.py +9 -0
  138. ml4t/diagnostic/metrics/percentiles.py +128 -0
  139. ml4t/diagnostic/py.typed +1 -0
  140. ml4t/diagnostic/reporting/__init__.py +43 -0
  141. ml4t/diagnostic/reporting/base.py +130 -0
  142. ml4t/diagnostic/reporting/html_renderer.py +275 -0
  143. ml4t/diagnostic/reporting/json_renderer.py +51 -0
  144. ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
  145. ml4t/diagnostic/results/AGENT.md +24 -0
  146. ml4t/diagnostic/results/__init__.py +105 -0
  147. ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
  148. ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
  149. ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
  150. ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
  151. ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
  152. ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
  153. ml4t/diagnostic/results/barrier_results/validation.py +38 -0
  154. ml4t/diagnostic/results/base.py +177 -0
  155. ml4t/diagnostic/results/event_results.py +349 -0
  156. ml4t/diagnostic/results/feature_results.py +787 -0
  157. ml4t/diagnostic/results/multi_signal_results.py +431 -0
  158. ml4t/diagnostic/results/portfolio_results.py +281 -0
  159. ml4t/diagnostic/results/sharpe_results.py +448 -0
  160. ml4t/diagnostic/results/signal_results/__init__.py +74 -0
  161. ml4t/diagnostic/results/signal_results/ic.py +581 -0
  162. ml4t/diagnostic/results/signal_results/irtc.py +110 -0
  163. ml4t/diagnostic/results/signal_results/quantile.py +392 -0
  164. ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
  165. ml4t/diagnostic/results/signal_results/turnover.py +213 -0
  166. ml4t/diagnostic/results/signal_results/validation.py +147 -0
  167. ml4t/diagnostic/signal/AGENT.md +17 -0
  168. ml4t/diagnostic/signal/__init__.py +69 -0
  169. ml4t/diagnostic/signal/_report.py +152 -0
  170. ml4t/diagnostic/signal/_utils.py +261 -0
  171. ml4t/diagnostic/signal/core.py +275 -0
  172. ml4t/diagnostic/signal/quantile.py +148 -0
  173. ml4t/diagnostic/signal/result.py +214 -0
  174. ml4t/diagnostic/signal/signal_ic.py +129 -0
  175. ml4t/diagnostic/signal/turnover.py +182 -0
  176. ml4t/diagnostic/splitters/AGENT.md +19 -0
  177. ml4t/diagnostic/splitters/__init__.py +36 -0
  178. ml4t/diagnostic/splitters/base.py +501 -0
  179. ml4t/diagnostic/splitters/calendar.py +421 -0
  180. ml4t/diagnostic/splitters/calendar_config.py +91 -0
  181. ml4t/diagnostic/splitters/combinatorial.py +1064 -0
  182. ml4t/diagnostic/splitters/config.py +322 -0
  183. ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
  184. ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
  185. ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
  186. ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
  187. ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
  188. ml4t/diagnostic/splitters/group_isolation.py +329 -0
  189. ml4t/diagnostic/splitters/persistence.py +316 -0
  190. ml4t/diagnostic/splitters/utils.py +207 -0
  191. ml4t/diagnostic/splitters/walk_forward.py +757 -0
  192. ml4t/diagnostic/utils/__init__.py +42 -0
  193. ml4t/diagnostic/utils/config.py +542 -0
  194. ml4t/diagnostic/utils/dependencies.py +318 -0
  195. ml4t/diagnostic/utils/sessions.py +127 -0
  196. ml4t/diagnostic/validation/__init__.py +54 -0
  197. ml4t/diagnostic/validation/dataframe.py +274 -0
  198. ml4t/diagnostic/validation/returns.py +280 -0
  199. ml4t/diagnostic/validation/timeseries.py +299 -0
  200. ml4t/diagnostic/visualization/AGENT.md +19 -0
  201. ml4t/diagnostic/visualization/__init__.py +223 -0
  202. ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
  203. ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
  204. ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
  205. ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
  206. ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
  207. ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
  208. ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
  209. ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
  210. ml4t/diagnostic/visualization/barrier_plots.py +782 -0
  211. ml4t/diagnostic/visualization/core.py +1060 -0
  212. ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
  213. ml4t/diagnostic/visualization/dashboards/base.py +582 -0
  214. ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
  215. ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
  216. ml4t/diagnostic/visualization/dashboards.py +43 -0
  217. ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
  218. ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
  219. ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
  220. ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
  221. ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
  222. ml4t/diagnostic/visualization/feature_plots.py +888 -0
  223. ml4t/diagnostic/visualization/interaction_plots.py +618 -0
  224. ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
  225. ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
  226. ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
  227. ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
  228. ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
  229. ml4t/diagnostic/visualization/report_generation.py +1343 -0
  230. ml4t/diagnostic/visualization/signal/__init__.py +103 -0
  231. ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
  232. ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
  233. ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
  234. ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
  235. ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
  236. ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
  237. ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
  238. ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
  239. ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
  240. ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
  241. ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
  242. ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
ml4t/diagnostic/results/barrier_results/__init__.py
@@ -0,0 +1,36 @@
+ """Result classes for Barrier Analysis module.
+
+ This package provides Pydantic result classes for storing and serializing
+ barrier analysis outputs including hit rates, profit factors, precision/recall,
+ and time-to-target metrics.
+
+ Triple barrier outcomes from ml4t.features:
+ - label: int (-1=SL hit, 0=timeout, 1=TP hit)
+ - label_return: float (actual return at exit)
+ - label_bars: int (bars from entry to exit)
+
+ References
+ ----------
+ Lopez de Prado, M. (2018). "Advances in Financial Machine Learning"
+ Chapter 3: Labeling (Triple Barrier Method)
+ """
+
+ from __future__ import annotations
+
+ from ml4t.diagnostic.results.barrier_results.hit_rate import HitRateResult
+ from ml4t.diagnostic.results.barrier_results.precision_recall import PrecisionRecallResult
+ from ml4t.diagnostic.results.barrier_results.profit_factor import ProfitFactorResult
+ from ml4t.diagnostic.results.barrier_results.tearsheet import BarrierTearSheet
+ from ml4t.diagnostic.results.barrier_results.time_to_target import TimeToTargetResult
+ from ml4t.diagnostic.results.barrier_results.validation import _validate_quantile_dict_keys
+
+ __all__ = [
+     # Validation helper
+     "_validate_quantile_dict_keys",
+     # Result classes
+     "HitRateResult",
+     "ProfitFactorResult",
+     "PrecisionRecallResult",
+     "TimeToTargetResult",
+     "BarrierTearSheet",
+ ]
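
Note: the classes exported above are plain Pydantic result containers. As a minimal, hypothetical usage sketch (not part of the wheel), assuming a HitRateResult has already been produced by the package's barrier analysis code, the consumer-facing surface is just summary(), get_dataframe(), and list_available_dataframes(), as defined in the files that follow:

# Hypothetical consumer code -- illustrative only, not shipped in the wheel.
from ml4t.diagnostic.results.barrier_results import HitRateResult


def inspect_hit_rates(result: HitRateResult) -> None:
    # Human-readable text report (see HitRateResult.summary below).
    print(result.summary())
    # Tabular views; the names come from list_available_dataframes().
    for name in result.list_available_dataframes():
        df = result.get_dataframe(name)
        print(name, df.shape)
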
ml4t/diagnostic/results/barrier_results/hit_rate.py
@@ -0,0 +1,304 @@
+ """Hit rate analysis results for barrier outcomes.
+
+ This module provides the HitRateResult class for storing hit rate metrics
+ (TP, SL, timeout) by signal quantile, including chi-square independence tests.
+ """
+
+ from __future__ import annotations
+
+ import polars as pl
+ from pydantic import Field, model_validator
+
+ from ml4t.diagnostic.results.barrier_results.validation import _validate_quantile_dict_keys
+ from ml4t.diagnostic.results.base import BaseResult
+
+
+ class HitRateResult(BaseResult):
+     """Results from hit rate analysis by signal decile.
+
+     Contains hit rates (% TP, % SL, % timeout) for each signal quantile,
+     along with chi-square test for independence between signal strength
+     and barrier outcome.
+
+     Examples
+     --------
+     >>> result = hit_rate_result
+     >>> print(result.summary())
+     >>> df = result.get_dataframe("hit_rates")
+     """
+
+     analysis_type: str = Field(default="barrier_hit_rate", frozen=True)
+
+     # ==========================================================================
+     # Configuration
+     # ==========================================================================
+
+     n_quantiles: int = Field(
+         ...,
+         description="Number of quantiles used",
+     )
+
+     quantile_labels: list[str] = Field(
+         ...,
+         description="Labels for each quantile (e.g., ['D1', 'D2', ..., 'D10'])",
+     )
+
+     # ==========================================================================
+     # Hit Rates by Quantile
+     # ==========================================================================
+
+     hit_rate_tp: dict[str, float] = Field(
+         ...,
+         description="Take-profit hit rate per quantile: {quantile: rate}",
+     )
+
+     hit_rate_sl: dict[str, float] = Field(
+         ...,
+         description="Stop-loss hit rate per quantile: {quantile: rate}",
+     )
+
+     hit_rate_timeout: dict[str, float] = Field(
+         ...,
+         description="Timeout hit rate per quantile: {quantile: rate}",
+     )
+
+     # ==========================================================================
+     # Counts
+     # ==========================================================================
+
+     count_tp: dict[str, int] = Field(
+         ...,
+         description="Take-profit count per quantile",
+     )
+
+     count_sl: dict[str, int] = Field(
+         ...,
+         description="Stop-loss count per quantile",
+     )
+
+     count_timeout: dict[str, int] = Field(
+         ...,
+         description="Timeout count per quantile",
+     )
+
+     count_total: dict[str, int] = Field(
+         ...,
+         description="Total count per quantile",
+     )
+
+     # ==========================================================================
+     # Statistical Test (Chi-Square Independence)
+     # ==========================================================================
+
+     chi2_statistic: float = Field(
+         ...,
+         description="Chi-square statistic for independence test",
+     )
+
+     chi2_p_value: float = Field(
+         ...,
+         description="P-value for chi-square test",
+     )
+
+     chi2_dof: int = Field(
+         ...,
+         description="Degrees of freedom for chi-square test",
+     )
+
+     is_significant: bool = Field(
+         ...,
+         description="Whether signal quantile significantly affects outcome (p < alpha)",
+     )
+
+     significance_level: float = Field(
+         ...,
+         description="Significance level used for test",
+     )
+
+     # ==========================================================================
+     # Aggregates
+     # ==========================================================================
+
+     overall_hit_rate_tp: float = Field(
+         ...,
+         description="Overall take-profit hit rate across all observations",
+     )
+
+     overall_hit_rate_sl: float = Field(
+         ...,
+         description="Overall stop-loss hit rate across all observations",
+     )
+
+     overall_hit_rate_timeout: float = Field(
+         ...,
+         description="Overall timeout hit rate across all observations",
+     )
+
+     n_observations: int = Field(
+         ...,
+         description="Total number of observations analyzed",
+     )
+
+     # ==========================================================================
+     # Monotonicity
+     # ==========================================================================
+
+     tp_rate_monotonic: bool = Field(
+         ...,
+         description="Whether TP hit rate is monotonic across quantiles",
+     )
+
+     tp_rate_direction: str = Field(
+         ...,
+         description="Direction of TP rate change: 'increasing', 'decreasing', or 'none'",
+     )
+
+     tp_rate_spearman: float = Field(
+         ...,
+         description="Spearman correlation between quantile rank and TP hit rate",
+     )
+
+     # ==========================================================================
+     # Validation
+     # ==========================================================================
+
+     @model_validator(mode="after")
+     def _validate_quantile_keys(self) -> HitRateResult:
+         """Validate that all quantile-keyed dicts have consistent keys."""
+         if self.n_quantiles != len(self.quantile_labels):
+             raise ValueError(
+                 f"n_quantiles ({self.n_quantiles}) != len(quantile_labels) ({len(self.quantile_labels)})"
+             )
+         _validate_quantile_dict_keys(
+             self.quantile_labels,
+             [
+                 ("hit_rate_tp", self.hit_rate_tp),
+                 ("hit_rate_sl", self.hit_rate_sl),
+                 ("hit_rate_timeout", self.hit_rate_timeout),
+                 ("count_tp", self.count_tp),
+                 ("count_sl", self.count_sl),
+                 ("count_timeout", self.count_timeout),
+                 ("count_total", self.count_total),
+             ],
+         )
+         return self
+
+     # ==========================================================================
+     # Methods
+     # ==========================================================================
+
+     def get_dataframe(self, name: str | None = None) -> pl.DataFrame:
+         """Get results as Polars DataFrame.
+
+         Parameters
+         ----------
+         name : str | None
+             DataFrame to retrieve:
+             - None or "hit_rates": Hit rates by quantile
+             - "counts": Raw counts by quantile and outcome
+             - "summary": Single-row summary statistics
+
+         Returns
+         -------
+         pl.DataFrame
+             Requested DataFrame
+         """
+         if name is None or name == "hit_rates":
+             return pl.DataFrame(
+                 {
+                     "quantile": self.quantile_labels,
+                     "hit_rate_tp": [self.hit_rate_tp[q] for q in self.quantile_labels],
+                     "hit_rate_sl": [self.hit_rate_sl[q] for q in self.quantile_labels],
+                     "hit_rate_timeout": [self.hit_rate_timeout[q] for q in self.quantile_labels],
+                     "count_total": [self.count_total[q] for q in self.quantile_labels],
+                 }
+             )
+
+         if name == "counts":
+             return pl.DataFrame(
+                 {
+                     "quantile": self.quantile_labels,
+                     "count_tp": [self.count_tp[q] for q in self.quantile_labels],
+                     "count_sl": [self.count_sl[q] for q in self.quantile_labels],
+                     "count_timeout": [self.count_timeout[q] for q in self.quantile_labels],
+                     "count_total": [self.count_total[q] for q in self.quantile_labels],
+                 }
+             )
+
+         if name == "summary":
+             return pl.DataFrame(
+                 {
+                     "metric": [
+                         "n_observations",
+                         "n_quantiles",
+                         "overall_hit_rate_tp",
+                         "overall_hit_rate_sl",
+                         "overall_hit_rate_timeout",
+                         "chi2_statistic",
+                         "chi2_p_value",
+                         "is_significant",
+                         "tp_rate_monotonic",
+                         "tp_rate_spearman",
+                     ],
+                     "value": [
+                         float(self.n_observations),
+                         float(self.n_quantiles),
+                         self.overall_hit_rate_tp,
+                         self.overall_hit_rate_sl,
+                         self.overall_hit_rate_timeout,
+                         self.chi2_statistic,
+                         self.chi2_p_value,
+                         float(self.is_significant),
+                         float(self.tp_rate_monotonic),
+                         self.tp_rate_spearman,
+                     ],
+                 }
+             )
+
+         raise ValueError(
+             f"Unknown DataFrame name: {name}. Available: 'hit_rates', 'counts', 'summary'"
+         )
+
+     def list_available_dataframes(self) -> list[str]:
+         """List available DataFrame views."""
+         return ["hit_rates", "counts", "summary"]
+
+     def summary(self) -> str:
+         """Get human-readable summary of hit rate results."""
+         lines = [
+             "=" * 60,
+             "Barrier Hit Rate Analysis",
+             "=" * 60,
+             "",
+             f"Observations: {self.n_observations:>10,}",
+             f"Quantiles: {self.n_quantiles:>10}",
+             "",
+             "Overall Hit Rates:",
+             f" Take-Profit: {self.overall_hit_rate_tp:>10.1%}",
+             f" Stop-Loss: {self.overall_hit_rate_sl:>10.1%}",
+             f" Timeout: {self.overall_hit_rate_timeout:>10.1%}",
+             "",
+             "Chi-Square Test (Signal Decile vs Outcome):",
+             f" Chi2 Statistic: {self.chi2_statistic:>10.2f}",
+             f" P-value: {self.chi2_p_value:>10.4f}",
+             f" DoF: {self.chi2_dof:>10}",
+             f" Significant: {'Yes' if self.is_significant else 'No':>10} (alpha={self.significance_level})",
+             "",
+             "Monotonicity (TP Rate vs Signal Strength):",
+             f" Monotonic: {'Yes' if self.tp_rate_monotonic else 'No':>10}",
+             f" Direction: {self.tp_rate_direction:>10}",
+             f" Spearman rho: {self.tp_rate_spearman:>10.4f}",
+             "",
+             "-" * 60,
+             "Hit Rates by Quantile:",
+             "-" * 60,
+             f"{'Quantile':<10} {'TP Rate':>10} {'SL Rate':>10} {'Timeout':>10} {'Count':>8}",
+         ]
+
+         for q in self.quantile_labels:
+             lines.append(
+                 f"{q:<10} {self.hit_rate_tp[q]:>10.1%} {self.hit_rate_sl[q]:>10.1%} "
+                 f"{self.hit_rate_timeout[q]:>10.1%} {self.count_total[q]:>8,}"
+             )
+
+         return "\n".join(lines)
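
Note: HitRateResult only stores the chi-square and Spearman statistics; the computation happens elsewhere in ml4t.diagnostic and is not shown in this diff. The sketch below, with made-up counts, illustrates what those fields measure (SciPy-based; the package's own implementation may differ):

# Illustration only, not package code.
import numpy as np
from scipy.stats import chi2_contingency, spearmanr

# Rows = signal quantiles (weakest to strongest), columns = outcomes (TP, SL, timeout).
counts = np.array(
    [
        [30, 50, 20],  # D1
        [45, 40, 15],  # D2
        [60, 25, 15],  # D3
    ]
)

# Chi-square test of independence between signal quantile and barrier outcome.
chi2_statistic, chi2_p_value, chi2_dof, _ = chi2_contingency(counts)
significance_level = 0.05
is_significant = chi2_p_value < significance_level

# Monotonicity: Spearman correlation between quantile rank and TP hit rate.
tp_rate = counts[:, 0] / counts.sum(axis=1)
tp_rate_spearman, _ = spearmanr(np.arange(1, len(tp_rate) + 1), tp_rate)

print(f"chi2={chi2_statistic:.2f} p={chi2_p_value:.4f} dof={chi2_dof} "
      f"significant={is_significant} spearman_rho={tp_rate_spearman:.2f}")
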
ml4t/diagnostic/results/barrier_results/precision_recall.py
@@ -0,0 +1,266 @@
+ """Precision/recall analysis results for barrier outcomes.
+
+ This module provides the PrecisionRecallResult class for storing precision,
+ recall, F1 scores, and lift metrics for barrier outcomes by signal quantile.
+ """
+
+ from __future__ import annotations
+
+ import polars as pl
+ from pydantic import Field, model_validator
+
+ from ml4t.diagnostic.results.barrier_results.validation import _validate_quantile_dict_keys
+ from ml4t.diagnostic.results.base import BaseResult
+
+
+ class PrecisionRecallResult(BaseResult):
+     """Results from precision/recall analysis for barrier outcomes.
+
+     Precision: Of signals in top quantile, what fraction hit TP?
+     Recall: Of all TP outcomes, what fraction came from top quantile?
+
+     This helps understand signal selectivity vs coverage trade-offs.
+
+     Examples
+     --------
+     >>> result = precision_recall_result
+     >>> print(result.summary())
+     >>> df = result.get_dataframe()
+     """
+
+     analysis_type: str = Field(default="barrier_precision_recall", frozen=True)
+
+     # ==========================================================================
+     # Configuration
+     # ==========================================================================
+
+     n_quantiles: int = Field(
+         ...,
+         description="Number of quantiles used",
+     )
+
+     quantile_labels: list[str] = Field(
+         ...,
+         description="Labels for each quantile (e.g., ['D1', 'D2', ..., 'D10'])",
+     )
+
+     # ==========================================================================
+     # Precision by Quantile (TP-focused)
+     # ==========================================================================
+
+     precision_tp: dict[str, float] = Field(
+         ...,
+         description="Precision for TP: P(TP | in quantile) = TP count / total in quantile",
+     )
+
+     # ==========================================================================
+     # Recall by Quantile (TP-focused)
+     # ==========================================================================
+
+     recall_tp: dict[str, float] = Field(
+         ...,
+         description="Recall for TP: P(in quantile | TP) = TP in quantile / all TP",
+     )
+
+     # ==========================================================================
+     # Cumulative Metrics (from top quantile down)
+     # ==========================================================================
+
+     cumulative_precision_tp: dict[str, float] = Field(
+         ...,
+         description="Cumulative precision: P(TP | in top k quantiles)",
+     )
+
+     cumulative_recall_tp: dict[str, float] = Field(
+         ...,
+         description="Cumulative recall: P(in top k quantiles | TP)",
+     )
+
+     cumulative_f1_tp: dict[str, float] = Field(
+         ...,
+         description="Cumulative F1 score: 2 * (precision * recall) / (precision + recall)",
+     )
+
+     # ==========================================================================
+     # Lift Metrics
+     # ==========================================================================
+
+     lift_tp: dict[str, float] = Field(
+         ...,
+         description="Lift for TP: precision / baseline TP rate",
+     )
+
+     cumulative_lift_tp: dict[str, float] = Field(
+         ...,
+         description="Cumulative lift for TP",
+     )
+
+     # ==========================================================================
+     # Baseline
+     # ==========================================================================
+
+     baseline_tp_rate: float = Field(
+         ...,
+         description="Baseline TP rate (overall TP count / total)",
+     )
+
+     total_tp_count: int = Field(
+         ...,
+         description="Total number of TP outcomes",
+     )
+
+     n_observations: int = Field(
+         ...,
+         description="Total number of observations",
+     )
+
+     # ==========================================================================
+     # Best Operating Point
+     # ==========================================================================
+
+     best_f1_quantile: str = Field(
+         ...,
+         description="Quantile with best cumulative F1 score",
+     )
+
+     best_f1_score: float = Field(
+         ...,
+         description="Best cumulative F1 score achieved",
+     )
+
+     # ==========================================================================
+     # Validation
+     # ==========================================================================
+
+     @model_validator(mode="after")
+     def _validate_quantile_keys(self) -> PrecisionRecallResult:
+         """Validate that all quantile-keyed dicts have consistent keys."""
+         if self.n_quantiles != len(self.quantile_labels):
+             raise ValueError(
+                 f"n_quantiles ({self.n_quantiles}) != len(quantile_labels) ({len(self.quantile_labels)})"
+             )
+         _validate_quantile_dict_keys(
+             self.quantile_labels,
+             [
+                 ("precision_tp", self.precision_tp),
+                 ("recall_tp", self.recall_tp),
+                 ("cumulative_precision_tp", self.cumulative_precision_tp),
+                 ("cumulative_recall_tp", self.cumulative_recall_tp),
+                 ("cumulative_f1_tp", self.cumulative_f1_tp),
+                 ("lift_tp", self.lift_tp),
+                 ("cumulative_lift_tp", self.cumulative_lift_tp),
+             ],
+         )
+         return self
+
+     def get_dataframe(self, name: str | None = None) -> pl.DataFrame:
+         """Get results as Polars DataFrame.
+
+         Parameters
+         ----------
+         name : str | None
+             DataFrame to retrieve:
+             - None or "precision_recall": Per-quantile metrics
+             - "cumulative": Cumulative metrics from top down
+             - "summary": Summary statistics
+
+         Returns
+         -------
+         pl.DataFrame
+             Requested DataFrame
+         """
+         if name is None or name == "precision_recall":
+             return pl.DataFrame(
+                 {
+                     "quantile": self.quantile_labels,
+                     "precision_tp": [self.precision_tp[q] for q in self.quantile_labels],
+                     "recall_tp": [self.recall_tp[q] for q in self.quantile_labels],
+                     "lift_tp": [self.lift_tp[q] for q in self.quantile_labels],
+                 }
+             )
+
+         if name == "cumulative":
+             return pl.DataFrame(
+                 {
+                     "quantile": self.quantile_labels,
+                     "cumulative_precision_tp": [
+                         self.cumulative_precision_tp[q] for q in self.quantile_labels
+                     ],
+                     "cumulative_recall_tp": [
+                         self.cumulative_recall_tp[q] for q in self.quantile_labels
+                     ],
+                     "cumulative_f1_tp": [self.cumulative_f1_tp[q] for q in self.quantile_labels],
+                     "cumulative_lift_tp": [
+                         self.cumulative_lift_tp[q] for q in self.quantile_labels
+                     ],
+                 }
+             )
+
+         if name == "summary":
+             return pl.DataFrame(
+                 {
+                     "metric": [
+                         "n_observations",
+                         "n_quantiles",
+                         "total_tp_count",
+                         "baseline_tp_rate",
+                         "best_f1_quantile",
+                         "best_f1_score",
+                     ],
+                     "value": [
+                         float(self.n_observations),
+                         float(self.n_quantiles),
+                         float(self.total_tp_count),
+                         self.baseline_tp_rate,
+                         self.best_f1_quantile,
+                         self.best_f1_score,
+                     ],
+                 }
+             )
+
+         raise ValueError(
+             f"Unknown DataFrame name: {name}. Available: 'precision_recall', 'cumulative', 'summary'"
+         )
+
+     def list_available_dataframes(self) -> list[str]:
+         """List available DataFrame views."""
+         return ["precision_recall", "cumulative", "summary"]
+
+     def summary(self) -> str:
+         """Get human-readable summary of precision/recall results."""
+         lines = [
+             "=" * 60,
+             "Barrier Precision/Recall Analysis (TP-focused)",
+             "=" * 60,
+             "",
+             f"Observations: {self.n_observations:>10,}",
+             f"Total TP Count: {self.total_tp_count:>10,}",
+             f"Baseline TP Rate: {self.baseline_tp_rate:>10.1%}",
+             "",
+             f"Best F1 Score: {self.best_f1_score:>10.4f} (at {self.best_f1_quantile})",
+             "",
+             "-" * 60,
+             "Per-Quantile Metrics:",
+             "-" * 60,
+             f"{'Quantile':<10} {'Precision':>10} {'Recall':>10} {'Lift':>8}",
+         ]
+
+         for q in self.quantile_labels:
+             lines.append(
+                 f"{q:<10} {self.precision_tp[q]:>10.1%} {self.recall_tp[q]:>10.1%} "
+                 f"{self.lift_tp[q]:>8.2f}x"
+             )
+
+         lines.append("")
+         lines.append("-" * 60)
+         lines.append("Cumulative Metrics (from top quantile):")
+         lines.append("-" * 60)
+         lines.append(f"{'Quantile':<10} {'Cum Prec':>10} {'Cum Recall':>10} {'Cum F1':>10}")
+
+         for q in self.quantile_labels:
+             lines.append(
+                 f"{q:<10} {self.cumulative_precision_tp[q]:>10.1%} "
+                 f"{self.cumulative_recall_tp[q]:>10.1%} {self.cumulative_f1_tp[q]:>10.4f}"
+             )
+
+         return "\n".join(lines)
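
Note: the field descriptions above pin down the arithmetic: precision_tp is the TP share within a quantile, recall_tp is the quantile's share of all TPs, lift_tp is precision over the baseline TP rate, and the cumulative variants accumulate from the top quantile down. A small self-contained sketch with made-up counts (illustrative only, not package code; the last label is assumed to be the strongest-signal quantile):

# Worked example of the stated definitions.
quantile_labels = ["D1", "D2", "D3"]
count_tp = {"D1": 30, "D2": 45, "D3": 60}
count_total = {"D1": 100, "D2": 100, "D3": 100}

total_tp_count = sum(count_tp.values())
n_observations = sum(count_total.values())
baseline_tp_rate = total_tp_count / n_observations

precision_tp = {q: count_tp[q] / count_total[q] for q in quantile_labels}
recall_tp = {q: count_tp[q] / total_tp_count for q in quantile_labels}
lift_tp = {q: precision_tp[q] / baseline_tp_rate for q in quantile_labels}

# Cumulative metrics, accumulating from the top (strongest-signal) quantile down.
cumulative_f1_tp = {}
cum_tp = cum_total = 0
for q in reversed(quantile_labels):
    cum_tp += count_tp[q]
    cum_total += count_total[q]
    cum_precision = cum_tp / cum_total
    cum_recall = cum_tp / total_tp_count
    cumulative_f1_tp[q] = 2 * cum_precision * cum_recall / (cum_precision + cum_recall)

best_f1_quantile = max(cumulative_f1_tp, key=cumulative_f1_tp.get)
print(best_f1_quantile, round(cumulative_f1_tp[best_f1_quantile], 4))
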