ml4t-diagnostic 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (242)
  1. ml4t/diagnostic/AGENT.md +25 -0
  2. ml4t/diagnostic/__init__.py +166 -0
  3. ml4t/diagnostic/backends/__init__.py +10 -0
  4. ml4t/diagnostic/backends/adapter.py +192 -0
  5. ml4t/diagnostic/backends/polars_backend.py +899 -0
  6. ml4t/diagnostic/caching/__init__.py +40 -0
  7. ml4t/diagnostic/caching/cache.py +331 -0
  8. ml4t/diagnostic/caching/decorators.py +131 -0
  9. ml4t/diagnostic/caching/smart_cache.py +339 -0
  10. ml4t/diagnostic/config/AGENT.md +24 -0
  11. ml4t/diagnostic/config/README.md +267 -0
  12. ml4t/diagnostic/config/__init__.py +219 -0
  13. ml4t/diagnostic/config/barrier_config.py +277 -0
  14. ml4t/diagnostic/config/base.py +301 -0
  15. ml4t/diagnostic/config/event_config.py +148 -0
  16. ml4t/diagnostic/config/feature_config.py +404 -0
  17. ml4t/diagnostic/config/multi_signal_config.py +55 -0
  18. ml4t/diagnostic/config/portfolio_config.py +215 -0
  19. ml4t/diagnostic/config/report_config.py +391 -0
  20. ml4t/diagnostic/config/sharpe_config.py +202 -0
  21. ml4t/diagnostic/config/signal_config.py +206 -0
  22. ml4t/diagnostic/config/trade_analysis_config.py +310 -0
  23. ml4t/diagnostic/config/validation.py +279 -0
  24. ml4t/diagnostic/core/__init__.py +29 -0
  25. ml4t/diagnostic/core/numba_utils.py +315 -0
  26. ml4t/diagnostic/core/purging.py +372 -0
  27. ml4t/diagnostic/core/sampling.py +471 -0
  28. ml4t/diagnostic/errors/__init__.py +205 -0
  29. ml4t/diagnostic/evaluation/AGENT.md +26 -0
  30. ml4t/diagnostic/evaluation/__init__.py +437 -0
  31. ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
  32. ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
  33. ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
  34. ml4t/diagnostic/evaluation/dashboard.py +715 -0
  35. ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
  36. ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
  37. ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
  38. ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
  39. ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
  40. ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
  41. ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
  42. ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
  43. ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
  44. ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
  45. ml4t/diagnostic/evaluation/event_analysis.py +647 -0
  46. ml4t/diagnostic/evaluation/excursion.py +390 -0
  47. ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
  48. ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
  49. ml4t/diagnostic/evaluation/framework.py +935 -0
  50. ml4t/diagnostic/evaluation/metric_registry.py +255 -0
  51. ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
  52. ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
  53. ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
  54. ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
  55. ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
  56. ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
  57. ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
  58. ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
  59. ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
  60. ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
  61. ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
  62. ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
  63. ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
  64. ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
  65. ml4t/diagnostic/evaluation/multi_signal.py +550 -0
  66. ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
  67. ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
  68. ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
  69. ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
  70. ml4t/diagnostic/evaluation/report_generation.py +824 -0
  71. ml4t/diagnostic/evaluation/signal_selector.py +452 -0
  72. ml4t/diagnostic/evaluation/stat_registry.py +139 -0
  73. ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
  74. ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
  75. ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
  76. ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
  77. ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
  78. ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
  79. ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
  80. ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
  81. ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
  82. ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
  83. ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
  84. ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
  85. ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
  86. ml4t/diagnostic/evaluation/stats/moments.py +164 -0
  87. ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
  88. ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
  89. ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
  90. ml4t/diagnostic/evaluation/themes.py +330 -0
  91. ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
  92. ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
  93. ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
  94. ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
  95. ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
  96. ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
  97. ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
  98. ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
  99. ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
  100. ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
  101. ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
  102. ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
  103. ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
  104. ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
  105. ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
  106. ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
  107. ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
  108. ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
  109. ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
  110. ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
  111. ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
  112. ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
  113. ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
  114. ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
  115. ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
  116. ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
  117. ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
  118. ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
  119. ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
  120. ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
  121. ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
  122. ml4t/diagnostic/evaluation/validated_cv.py +535 -0
  123. ml4t/diagnostic/evaluation/visualization.py +1050 -0
  124. ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
  125. ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
  126. ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
  127. ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
  128. ml4t/diagnostic/integration/__init__.py +48 -0
  129. ml4t/diagnostic/integration/backtest_contract.py +671 -0
  130. ml4t/diagnostic/integration/data_contract.py +316 -0
  131. ml4t/diagnostic/integration/engineer_contract.py +226 -0
  132. ml4t/diagnostic/logging/__init__.py +77 -0
  133. ml4t/diagnostic/logging/logger.py +245 -0
  134. ml4t/diagnostic/logging/performance.py +234 -0
  135. ml4t/diagnostic/logging/progress.py +234 -0
  136. ml4t/diagnostic/logging/wandb.py +412 -0
  137. ml4t/diagnostic/metrics/__init__.py +9 -0
  138. ml4t/diagnostic/metrics/percentiles.py +128 -0
  139. ml4t/diagnostic/py.typed +1 -0
  140. ml4t/diagnostic/reporting/__init__.py +43 -0
  141. ml4t/diagnostic/reporting/base.py +130 -0
  142. ml4t/diagnostic/reporting/html_renderer.py +275 -0
  143. ml4t/diagnostic/reporting/json_renderer.py +51 -0
  144. ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
  145. ml4t/diagnostic/results/AGENT.md +24 -0
  146. ml4t/diagnostic/results/__init__.py +105 -0
  147. ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
  148. ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
  149. ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
  150. ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
  151. ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
  152. ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
  153. ml4t/diagnostic/results/barrier_results/validation.py +38 -0
  154. ml4t/diagnostic/results/base.py +177 -0
  155. ml4t/diagnostic/results/event_results.py +349 -0
  156. ml4t/diagnostic/results/feature_results.py +787 -0
  157. ml4t/diagnostic/results/multi_signal_results.py +431 -0
  158. ml4t/diagnostic/results/portfolio_results.py +281 -0
  159. ml4t/diagnostic/results/sharpe_results.py +448 -0
  160. ml4t/diagnostic/results/signal_results/__init__.py +74 -0
  161. ml4t/diagnostic/results/signal_results/ic.py +581 -0
  162. ml4t/diagnostic/results/signal_results/irtc.py +110 -0
  163. ml4t/diagnostic/results/signal_results/quantile.py +392 -0
  164. ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
  165. ml4t/diagnostic/results/signal_results/turnover.py +213 -0
  166. ml4t/diagnostic/results/signal_results/validation.py +147 -0
  167. ml4t/diagnostic/signal/AGENT.md +17 -0
  168. ml4t/diagnostic/signal/__init__.py +69 -0
  169. ml4t/diagnostic/signal/_report.py +152 -0
  170. ml4t/diagnostic/signal/_utils.py +261 -0
  171. ml4t/diagnostic/signal/core.py +275 -0
  172. ml4t/diagnostic/signal/quantile.py +148 -0
  173. ml4t/diagnostic/signal/result.py +214 -0
  174. ml4t/diagnostic/signal/signal_ic.py +129 -0
  175. ml4t/diagnostic/signal/turnover.py +182 -0
  176. ml4t/diagnostic/splitters/AGENT.md +19 -0
  177. ml4t/diagnostic/splitters/__init__.py +36 -0
  178. ml4t/diagnostic/splitters/base.py +501 -0
  179. ml4t/diagnostic/splitters/calendar.py +421 -0
  180. ml4t/diagnostic/splitters/calendar_config.py +91 -0
  181. ml4t/diagnostic/splitters/combinatorial.py +1064 -0
  182. ml4t/diagnostic/splitters/config.py +322 -0
  183. ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
  184. ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
  185. ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
  186. ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
  187. ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
  188. ml4t/diagnostic/splitters/group_isolation.py +329 -0
  189. ml4t/diagnostic/splitters/persistence.py +316 -0
  190. ml4t/diagnostic/splitters/utils.py +207 -0
  191. ml4t/diagnostic/splitters/walk_forward.py +757 -0
  192. ml4t/diagnostic/utils/__init__.py +42 -0
  193. ml4t/diagnostic/utils/config.py +542 -0
  194. ml4t/diagnostic/utils/dependencies.py +318 -0
  195. ml4t/diagnostic/utils/sessions.py +127 -0
  196. ml4t/diagnostic/validation/__init__.py +54 -0
  197. ml4t/diagnostic/validation/dataframe.py +274 -0
  198. ml4t/diagnostic/validation/returns.py +280 -0
  199. ml4t/diagnostic/validation/timeseries.py +299 -0
  200. ml4t/diagnostic/visualization/AGENT.md +19 -0
  201. ml4t/diagnostic/visualization/__init__.py +223 -0
  202. ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
  203. ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
  204. ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
  205. ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
  206. ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
  207. ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
  208. ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
  209. ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
  210. ml4t/diagnostic/visualization/barrier_plots.py +782 -0
  211. ml4t/diagnostic/visualization/core.py +1060 -0
  212. ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
  213. ml4t/diagnostic/visualization/dashboards/base.py +582 -0
  214. ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
  215. ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
  216. ml4t/diagnostic/visualization/dashboards.py +43 -0
  217. ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
  218. ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
  219. ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
  220. ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
  221. ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
  222. ml4t/diagnostic/visualization/feature_plots.py +888 -0
  223. ml4t/diagnostic/visualization/interaction_plots.py +618 -0
  224. ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
  225. ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
  226. ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
  227. ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
  228. ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
  229. ml4t/diagnostic/visualization/report_generation.py +1343 -0
  230. ml4t/diagnostic/visualization/signal/__init__.py +103 -0
  231. ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
  232. ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
  233. ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
  234. ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
  235. ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
  236. ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
  237. ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
  238. ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
  239. ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
  240. ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
  241. ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
  242. ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
ml4t/diagnostic/evaluation/barrier_analysis.py
@@ -0,0 +1,1050 @@
"""Barrier Analysis module for triple barrier outcome evaluation.

This module provides analysis of signal quality using triple barrier outcomes
(take-profit, stop-loss, timeout) instead of simple forward returns.

The BarrierAnalysis class computes:
- Hit rates by signal decile (% TP, % SL, % timeout)
- Profit factor by decile (sum TP returns / |sum SL returns|)
- Statistical tests for signal-outcome independence (chi-square)
- Monotonicity tests for signal strength vs outcome relationship

Triple barrier outcomes from ml4t.features:
- label: int (-1=SL hit, 0=timeout, 1=TP hit)
- label_return: float (actual return at exit)
- label_bars: int (bars from entry to exit)

References
----------
Lopez de Prado, M. (2018). "Advances in Financial Machine Learning",
Chapter 3: Labeling (Triple Barrier Method).
"""

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING

import numpy as np
import polars as pl
from scipy import stats

from ml4t.diagnostic.config.barrier_config import BarrierConfig, BarrierLabel
from ml4t.diagnostic.results.barrier_results import (
    BarrierTearSheet,
    HitRateResult,
    PrecisionRecallResult,
    ProfitFactorResult,
    TimeToTargetResult,
)

if TYPE_CHECKING:
    pass


class BarrierAnalysis:
    """Analyze signal quality using triple barrier outcomes.

    This class evaluates how well a signal predicts barrier outcomes
    (take-profit hit, stop-loss hit, or timeout) rather than raw returns.

    Parameters
    ----------
    signal_data : pl.DataFrame
        DataFrame with columns: [date_col, asset_col, signal_col]
        Contains signal values for each asset-date pair.

    barrier_labels : pl.DataFrame
        DataFrame with columns: [date_col, asset_col, label_col, label_return_col, label_bars_col]
        Contains triple barrier outcomes from ml4t.features.triple_barrier_labels().

    config : BarrierConfig | None, optional
        Configuration for analysis. Uses defaults if not provided.

    Examples
    --------
    >>> from ml4t.diagnostic.evaluation import BarrierAnalysis
    >>> from ml4t.diagnostic.config import BarrierConfig
    >>>
    >>> # Basic usage
    >>> analysis = BarrierAnalysis(signals_df, barriers_df)
    >>> hit_rates = analysis.compute_hit_rates()
    >>> print(hit_rates.summary())
    >>>
    >>> # With custom config
    >>> config = BarrierConfig(n_quantiles=5)
    >>> analysis = BarrierAnalysis(signals_df, barriers_df, config=config)
    >>> profit_factor = analysis.compute_profit_factor()
    """

    def __init__(
        self,
        signal_data: pl.DataFrame,
        barrier_labels: pl.DataFrame,
        config: BarrierConfig | None = None,
    ) -> None:
        """Initialize BarrierAnalysis.

        Parameters
        ----------
        signal_data : pl.DataFrame
            Signal values with date, asset, signal columns.
        barrier_labels : pl.DataFrame
            Barrier outcomes with date, asset, label, label_return, label_bars columns.
        config : BarrierConfig | None
            Configuration object. Uses defaults if None.

        Raises
        ------
        ValueError
            If required columns are missing or data is invalid.
        """
        self.config = config or BarrierConfig()
        self._validate_inputs(signal_data, barrier_labels)

        # Store original data
        self._signal_data = signal_data
        self._barrier_labels = barrier_labels

        # Merge and prepare data
        self._merged_data = self._prepare_data(signal_data, barrier_labels)

        # Cache for computed results
        self._hit_rate_result: HitRateResult | None = None
        self._profit_factor_result: ProfitFactorResult | None = None
        self._precision_recall_result: PrecisionRecallResult | None = None
        self._time_to_target_result: TimeToTargetResult | None = None

    def _validate_inputs(
        self,
        signal_data: pl.DataFrame,
        barrier_labels: pl.DataFrame,
    ) -> None:
        """Validate input DataFrames have required columns and valid data.

        Raises
        ------
        ValueError
            If validation fails.
        """
        cfg = self.config

        # Check signal_data columns
        signal_required = {cfg.date_col, cfg.asset_col, cfg.signal_col}
        signal_cols = set(signal_data.columns)
        missing_signal = signal_required - signal_cols
        if missing_signal:
            raise ValueError(
                f"signal_data missing required columns: {missing_signal}. "
                f"Available columns: {signal_cols}"
            )

        # Check barrier_labels columns
        barrier_required = {cfg.date_col, cfg.asset_col, cfg.label_col, cfg.label_return_col}
        barrier_cols = set(barrier_labels.columns)
        missing_barrier = barrier_required - barrier_cols
        if missing_barrier:
            raise ValueError(
                f"barrier_labels missing required columns: {missing_barrier}. "
                f"Available columns: {barrier_cols}"
            )

        # Check for empty DataFrames
        if signal_data.height == 0:
            raise ValueError("signal_data is empty")
        if barrier_labels.height == 0:
            raise ValueError("barrier_labels is empty")

        # Validate label values
        valid_labels = {-1, 0, 1}
        unique_labels = set(barrier_labels[cfg.label_col].unique().to_list())
        invalid_labels = unique_labels - valid_labels
        if invalid_labels:
            raise ValueError(
                f"barrier_labels[{cfg.label_col}] contains invalid values: {invalid_labels}. "
                f"Expected values: {valid_labels} (-1=SL, 0=timeout, 1=TP)"
            )

    def _prepare_data(
        self,
        signal_data: pl.DataFrame,
        barrier_labels: pl.DataFrame,
    ) -> pl.DataFrame:
        """Merge signal data with barrier labels and prepare for analysis.

        Returns
        -------
        pl.DataFrame
            Merged DataFrame with signal values and barrier outcomes,
            plus computed quantile labels.
        """
        cfg = self.config

        # Merge on date and asset
        merged = signal_data.join(
            barrier_labels,
            on=[cfg.date_col, cfg.asset_col],
            how="inner",
        )
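        # Inner join: rows present in only one of the two frames are dropped,
        # so all downstream counts reflect the intersection of signals and labels.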

        if merged.height == 0:
            raise ValueError(
                "No matching rows after merging signal_data and barrier_labels. "
                "Check that date and asset columns match."
            )

        # Filter outliers if configured
        if cfg.filter_zscore is not None:
            signal_mean = merged[cfg.signal_col].mean()
            signal_std = merged[cfg.signal_col].std()
            if signal_std is not None and signal_std > 0:
                merged = merged.filter(
                    ((pl.col(cfg.signal_col) - signal_mean) / signal_std).abs() <= cfg.filter_zscore
                )

        # Drop NaN signals
        merged = merged.drop_nulls(subset=[cfg.signal_col])

        if merged.height == 0:
            raise ValueError("No valid observations after filtering NaN signals and outliers")

        # Add quantile labels
        merged = self._add_quantile_labels(merged)

        return merged

    def _add_quantile_labels(self, df: pl.DataFrame) -> pl.DataFrame:
        """Add quantile labels to DataFrame based on signal values.

        Parameters
        ----------
        df : pl.DataFrame
            DataFrame with signal column.

        Returns
        -------
        pl.DataFrame
            DataFrame with added 'quantile' column.
        """
        cfg = self.config
        n_q = cfg.n_quantiles

        # Generate quantile labels (D1, D2, ..., D10 for deciles)
        quantile_labels = [f"D{i + 1}" for i in range(n_q)]

        if cfg.decile_method.value == "quantile":
            # Equal frequency bins (like pd.qcut)
            df = df.with_columns(
                pl.col(cfg.signal_col)
                .qcut(n_q, labels=quantile_labels, allow_duplicates=True)
                .alias("quantile")
            )
        else:
            # Equal width bins (like pd.cut). Unlike pandas, Polars' cut()
            # expects explicit breakpoints rather than a bin count, so derive
            # the n_q - 1 interior edges from the observed signal range.
            s_min = float(df[cfg.signal_col].min())
            s_max = float(df[cfg.signal_col].max())
            breaks = list(np.linspace(s_min, s_max, n_q + 1)[1:-1])
            df = df.with_columns(
                pl.col(cfg.signal_col).cut(breaks, labels=quantile_labels).alias("quantile")
            )

        return df

    @property
    def merged_data(self) -> pl.DataFrame:
        """Get the merged and prepared data."""
        return self._merged_data

    @property
    def n_observations(self) -> int:
        """Total number of observations after merging."""
        return self._merged_data.height

    @property
    def n_assets(self) -> int:
        """Number of unique assets."""
        return self._merged_data[self.config.asset_col].n_unique()

    @property
    def n_dates(self) -> int:
        """Number of unique dates."""
        return self._merged_data[self.config.date_col].n_unique()

    @property
    def date_range(self) -> tuple[str, str]:
        """Date range (start, end) as ISO strings."""
        dates = self._merged_data[self.config.date_col]
        min_date = dates.min()
        max_date = dates.max()
        return (str(min_date), str(max_date))

    @property
    def quantile_labels(self) -> list[str]:
        """List of quantile labels used."""
        return [f"D{i + 1}" for i in range(self.config.n_quantiles)]

    def compute_hit_rates(self) -> HitRateResult:
        """Compute hit rates by signal decile.

        For each signal quantile, calculates the percentage of observations
        that hit TP, SL, or timeout barriers.

        Includes chi-square test for independence between signal strength
        and barrier outcome.

        Returns
        -------
        HitRateResult
            Results containing hit rates per quantile, chi-square test,
            and monotonicity analysis.

        Examples
        --------
        >>> result = analysis.compute_hit_rates()
        >>> print(result.summary())
        >>> df = result.get_dataframe("hit_rates")
        """
        if self._hit_rate_result is not None:
            return self._hit_rate_result

        cfg = self.config
        df = self._merged_data
        q_labels = self.quantile_labels

        # Initialize containers
        hit_rate_tp: dict[str, float] = {}
        hit_rate_sl: dict[str, float] = {}
        hit_rate_timeout: dict[str, float] = {}
        count_tp: dict[str, int] = {}
        count_sl: dict[str, int] = {}
        count_timeout: dict[str, int] = {}
        count_total: dict[str, int] = {}

        # Build contingency table for chi-square test
        # Rows: quantiles, Columns: outcomes (SL, Timeout, TP)
        contingency = np.zeros((cfg.n_quantiles, 3), dtype=np.int64)

        for i, q in enumerate(q_labels):
            q_data = df.filter(pl.col("quantile") == q)
            n_total = q_data.height

            if n_total == 0:
                # Handle empty quantile
                hit_rate_tp[q] = 0.0
                hit_rate_sl[q] = 0.0
                hit_rate_timeout[q] = 0.0
                count_tp[q] = 0
                count_sl[q] = 0
                count_timeout[q] = 0
                count_total[q] = 0
                continue

            # Count outcomes
            n_tp = q_data.filter(pl.col(cfg.label_col) == BarrierLabel.TAKE_PROFIT.value).height
            n_sl = q_data.filter(pl.col(cfg.label_col) == BarrierLabel.STOP_LOSS.value).height
            n_timeout = q_data.filter(pl.col(cfg.label_col) == BarrierLabel.TIMEOUT.value).height

            # Hit rates
            hit_rate_tp[q] = n_tp / n_total
            hit_rate_sl[q] = n_sl / n_total
            hit_rate_timeout[q] = n_timeout / n_total

            # Counts
            count_tp[q] = n_tp
            count_sl[q] = n_sl
            count_timeout[q] = n_timeout
            count_total[q] = n_total

            # Contingency table row
            contingency[i, 0] = n_sl
            contingency[i, 1] = n_timeout
            contingency[i, 2] = n_tp

        # Chi-square test for independence
        # H0: Signal quantile and barrier outcome are independent
        # H1: They are dependent (signal predicts outcome)

        # Remove rows/cols with all zeros to avoid chi2 issues
        row_sums = contingency.sum(axis=1)
        col_sums = contingency.sum(axis=0)
        valid_rows = row_sums > 0
        valid_cols = col_sums > 0

        if valid_rows.sum() < 2 or valid_cols.sum() < 2:
            # Not enough data for chi-square test
            chi2_stat = 0.0
            chi2_p = 1.0
            chi2_dof = 0
            warnings.warn(
                "Insufficient variation in data for chi-square test. "
                "Need at least 2 non-empty quantiles and 2 different outcomes.",
                UserWarning,
                stacklevel=2,
            )
        else:
            contingency_valid = contingency[valid_rows][:, valid_cols]
            chi2_stat, chi2_p, chi2_dof, _ = stats.chi2_contingency(contingency_valid)
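            # For an R x C table, chi2_contingency reports dof = (R - 1) * (C - 1),
            # computed here on the reduced table with empty rows/columns removed.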

        # Overall hit rates
        total_obs = df.height
        overall_tp = (
            df.filter(pl.col(cfg.label_col) == BarrierLabel.TAKE_PROFIT.value).height / total_obs
        )
        overall_sl = (
            df.filter(pl.col(cfg.label_col) == BarrierLabel.STOP_LOSS.value).height / total_obs
        )
        overall_timeout = (
            df.filter(pl.col(cfg.label_col) == BarrierLabel.TIMEOUT.value).height / total_obs
        )

        # Monotonicity analysis for TP rate
        tp_rates = [hit_rate_tp[q] for q in q_labels]
        tp_monotonic, tp_direction, tp_spearman = self._analyze_monotonicity(tp_rates)

        self._hit_rate_result = HitRateResult(
            n_quantiles=cfg.n_quantiles,
            quantile_labels=q_labels,
            hit_rate_tp=hit_rate_tp,
            hit_rate_sl=hit_rate_sl,
            hit_rate_timeout=hit_rate_timeout,
            count_tp=count_tp,
            count_sl=count_sl,
            count_timeout=count_timeout,
            count_total=count_total,
            chi2_statistic=float(chi2_stat),
            chi2_p_value=float(chi2_p),
            chi2_dof=int(chi2_dof),
            is_significant=chi2_p < cfg.significance_level,
            significance_level=cfg.significance_level,
            overall_hit_rate_tp=overall_tp,
            overall_hit_rate_sl=overall_sl,
            overall_hit_rate_timeout=overall_timeout,
            n_observations=total_obs,
            tp_rate_monotonic=tp_monotonic,
            tp_rate_direction=tp_direction,
            tp_rate_spearman=tp_spearman,
        )

        return self._hit_rate_result

    def compute_profit_factor(self) -> ProfitFactorResult:
        """Compute profit factor by signal decile.

        Profit Factor = Sum(TP returns) / |Sum(SL returns)|

        A profit factor > 1 indicates the quantile is net profitable
        when trading based on the signal.

        Returns
        -------
        ProfitFactorResult
            Results containing profit factor per quantile and
            return statistics.

        Examples
        --------
        >>> result = analysis.compute_profit_factor()
        >>> print(result.summary())
        >>> df = result.get_dataframe()
        """
        if self._profit_factor_result is not None:
            return self._profit_factor_result

        cfg = self.config
        df = self._merged_data
        q_labels = self.quantile_labels
        eps = cfg.profit_factor_epsilon

        # Initialize containers
        profit_factor: dict[str, float] = {}
        sum_tp_returns: dict[str, float] = {}
        sum_sl_returns: dict[str, float] = {}
        sum_timeout_returns: dict[str, float] = {}
        sum_all_returns: dict[str, float] = {}
        avg_tp_return: dict[str, float] = {}
        avg_sl_return: dict[str, float] = {}
        avg_return: dict[str, float] = {}
        count_tp: dict[str, int] = {}
        count_sl: dict[str, int] = {}
        count_total: dict[str, int] = {}

        for q in q_labels:
            q_data = df.filter(pl.col("quantile") == q)
            n_total = q_data.height

            if n_total == 0:
                profit_factor[q] = 0.0
                sum_tp_returns[q] = 0.0
                sum_sl_returns[q] = 0.0
                sum_timeout_returns[q] = 0.0
                sum_all_returns[q] = 0.0
                avg_tp_return[q] = 0.0
                avg_sl_return[q] = 0.0
                avg_return[q] = 0.0
                count_tp[q] = 0
                count_sl[q] = 0
                count_total[q] = 0
                continue

            # TP returns
            tp_data = q_data.filter(pl.col(cfg.label_col) == BarrierLabel.TAKE_PROFIT.value)
            n_tp = tp_data.height
            s_tp = tp_data[cfg.label_return_col].sum() if n_tp > 0 else 0.0

            # SL returns
            sl_data = q_data.filter(pl.col(cfg.label_col) == BarrierLabel.STOP_LOSS.value)
            n_sl = sl_data.height
            s_sl = sl_data[cfg.label_return_col].sum() if n_sl > 0 else 0.0

            # Timeout returns
            timeout_data = q_data.filter(pl.col(cfg.label_col) == BarrierLabel.TIMEOUT.value)
            s_timeout = timeout_data[cfg.label_return_col].sum() if timeout_data.height > 0 else 0.0

            # Total returns
            s_all = q_data[cfg.label_return_col].sum()

            # Profit factor: PF = sum(TP) / |sum(SL)|
            # SL returns are typically negative, so we use abs
            denom = abs(s_sl) + eps if s_sl != 0 else eps
            pf = s_tp / denom if s_tp > 0 else 0.0
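            # Worked example: sum(TP returns) = +0.30 and sum(SL returns) = -0.15
            # give PF = 0.30 / (0.15 + eps) ≈ 2.0, i.e. TP gains are twice the
            # magnitude of SL losses in this quantile.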

            # Store results
            profit_factor[q] = float(pf)
            sum_tp_returns[q] = float(s_tp) if s_tp is not None else 0.0
            sum_sl_returns[q] = float(s_sl) if s_sl is not None else 0.0
            sum_timeout_returns[q] = float(s_timeout) if s_timeout is not None else 0.0
            sum_all_returns[q] = float(s_all) if s_all is not None else 0.0
            avg_tp_return[q] = float(s_tp / n_tp) if n_tp > 0 and s_tp is not None else 0.0
            avg_sl_return[q] = float(s_sl / n_sl) if n_sl > 0 and s_sl is not None else 0.0
            avg_return[q] = float(s_all / n_total) if s_all is not None else 0.0
            count_tp[q] = n_tp
            count_sl[q] = n_sl
            count_total[q] = n_total

        # Overall metrics
        total_obs = df.height
        total_tp_returns = df.filter(pl.col(cfg.label_col) == BarrierLabel.TAKE_PROFIT.value)[
            cfg.label_return_col
        ].sum()
        total_sl_returns = df.filter(pl.col(cfg.label_col) == BarrierLabel.STOP_LOSS.value)[
            cfg.label_return_col
        ].sum()

        total_tp_returns = float(total_tp_returns) if total_tp_returns is not None else 0.0
        total_sl_returns = float(total_sl_returns) if total_sl_returns is not None else 0.0

        overall_pf_denom = abs(total_sl_returns) + eps if total_sl_returns != 0 else eps
        overall_pf = total_tp_returns / overall_pf_denom if total_tp_returns > 0 else 0.0

        overall_sum = df[cfg.label_return_col].sum()
        overall_sum = float(overall_sum) if overall_sum is not None else 0.0
        overall_avg = overall_sum / total_obs

        # Monotonicity analysis for profit factor
        pf_values = [profit_factor[q] for q in q_labels]
        pf_monotonic, pf_direction, pf_spearman = self._analyze_monotonicity(pf_values)

        self._profit_factor_result = ProfitFactorResult(
            n_quantiles=cfg.n_quantiles,
            quantile_labels=q_labels,
            profit_factor=profit_factor,
            sum_tp_returns=sum_tp_returns,
            sum_sl_returns=sum_sl_returns,
            sum_timeout_returns=sum_timeout_returns,
            sum_all_returns=sum_all_returns,
            avg_tp_return=avg_tp_return,
            avg_sl_return=avg_sl_return,
            avg_return=avg_return,
            count_tp=count_tp,
            count_sl=count_sl,
            count_total=count_total,
            overall_profit_factor=overall_pf,
            overall_sum_returns=overall_sum,
            overall_avg_return=overall_avg,
            n_observations=total_obs,
            pf_monotonic=pf_monotonic,
            pf_direction=pf_direction,
            pf_spearman=pf_spearman,
        )

        return self._profit_factor_result

    def compute_precision_recall(self) -> PrecisionRecallResult:
        """Compute precision and recall metrics for barrier outcomes.

        For each signal quantile, computes:
        - Precision: P(TP | in quantile) = TP count / total in quantile
        - Recall: P(in quantile | TP) = TP in quantile / all TP

        Also computes cumulative metrics from the top quantile downward,
        and lift (precision relative to baseline TP rate).

        Returns
        -------
        PrecisionRecallResult
            Results containing precision, recall, F1, and lift metrics
            per quantile and cumulative from the top down.

        Examples
        --------
        >>> result = analysis.compute_precision_recall()
        >>> print(result.summary())
        >>> df = result.get_dataframe("cumulative")
        """
        if self._precision_recall_result is not None:
            return self._precision_recall_result

        cfg = self.config
        df = self._merged_data
        q_labels = self.quantile_labels

        # Total TP count (baseline)
        total_tp = df.filter(pl.col(cfg.label_col) == BarrierLabel.TAKE_PROFIT.value).height
        total_obs = df.height
        baseline_tp_rate = total_tp / total_obs if total_obs > 0 else 0.0

        # Per-quantile precision and recall
        precision_tp: dict[str, float] = {}
        recall_tp: dict[str, float] = {}
        lift_tp: dict[str, float] = {}

        # Count TP per quantile for cumulative calculations
        tp_counts: dict[str, int] = {}
        total_counts: dict[str, int] = {}

        for q in q_labels:
            q_data = df.filter(pl.col("quantile") == q)
            n_total = q_data.height
            n_tp = q_data.filter(pl.col(cfg.label_col) == BarrierLabel.TAKE_PROFIT.value).height

            tp_counts[q] = n_tp
            total_counts[q] = n_total

            # Precision: P(TP | in this quantile)
            prec = n_tp / n_total if n_total > 0 else 0.0
            precision_tp[q] = prec

            # Recall: P(in this quantile | TP)
            rec = n_tp / total_tp if total_tp > 0 else 0.0
            recall_tp[q] = rec

            # Lift: precision / baseline
            lift = prec / baseline_tp_rate if baseline_tp_rate > 0 else 0.0
            lift_tp[q] = lift

        # Cumulative metrics (from top quantile down)
        # Reverse order: D10 is highest signal, then D9, etc.
        reversed_labels = list(reversed(q_labels))

        cumulative_precision_tp: dict[str, float] = {}
        cumulative_recall_tp: dict[str, float] = {}
        cumulative_f1_tp: dict[str, float] = {}
        cumulative_lift_tp: dict[str, float] = {}

        cum_tp = 0
        cum_total = 0

        best_f1 = 0.0
        best_f1_q = q_labels[-1]  # Default to top quantile

        for q in reversed_labels:
            cum_tp += tp_counts[q]
            cum_total += total_counts[q]

            # Cumulative precision
            cum_prec = cum_tp / cum_total if cum_total > 0 else 0.0
            cumulative_precision_tp[q] = cum_prec

            # Cumulative recall
            cum_rec = cum_tp / total_tp if total_tp > 0 else 0.0
            cumulative_recall_tp[q] = cum_rec

            # F1 score
            if cum_prec + cum_rec > 0:
                f1 = 2 * cum_prec * cum_rec / (cum_prec + cum_rec)
            else:
                f1 = 0.0
            cumulative_f1_tp[q] = f1

            # Track best F1
            if f1 > best_f1:
                best_f1 = f1
                best_f1_q = q

            # Cumulative lift
            cum_lift = cum_prec / baseline_tp_rate if baseline_tp_rate > 0 else 0.0
            cumulative_lift_tp[q] = cum_lift
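            # Lift reads as a multiple of the base rate: a baseline TP rate of
            # 10% with cumulative precision of 15% gives cum_lift = 0.15 / 0.10 = 1.5.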

        self._precision_recall_result = PrecisionRecallResult(
            n_quantiles=cfg.n_quantiles,
            quantile_labels=q_labels,
            precision_tp=precision_tp,
            recall_tp=recall_tp,
            cumulative_precision_tp=cumulative_precision_tp,
            cumulative_recall_tp=cumulative_recall_tp,
            cumulative_f1_tp=cumulative_f1_tp,
            lift_tp=lift_tp,
            cumulative_lift_tp=cumulative_lift_tp,
            baseline_tp_rate=baseline_tp_rate,
            total_tp_count=total_tp,
            n_observations=total_obs,
            best_f1_quantile=best_f1_q,
            best_f1_score=best_f1,
        )

        return self._precision_recall_result

    def compute_time_to_target(self) -> TimeToTargetResult:
        """Compute time-to-target metrics by signal decile.

        Analyzes how quickly different signal quantiles reach their barrier
        outcomes (TP, SL, or timeout). Uses the `label_bars` column from
        barrier labels to measure time to exit.

        Returns
        -------
        TimeToTargetResult
            Results containing mean, median, and std of bars to exit
            per quantile and outcome type.

        Raises
        ------
        ValueError
            If label_bars column is not available in barrier_labels.

        Examples
        --------
        >>> result = analysis.compute_time_to_target()
        >>> print(result.summary())
        >>> df = result.get_dataframe("detailed")
        """
        if self._time_to_target_result is not None:
            return self._time_to_target_result

        cfg = self.config
        df = self._merged_data
        q_labels = self.quantile_labels

        # Check if label_bars column exists
        if cfg.label_bars_col not in df.columns:
            raise ValueError(
                f"Time-to-target analysis requires '{cfg.label_bars_col}' column in barrier_labels. "
                f"Available columns: {df.columns}"
            )

        # Initialize containers
        mean_bars_tp: dict[str, float] = {}
        mean_bars_sl: dict[str, float] = {}
        mean_bars_timeout: dict[str, float] = {}
        mean_bars_all: dict[str, float] = {}
        median_bars_tp: dict[str, float] = {}
        median_bars_sl: dict[str, float] = {}
        median_bars_all: dict[str, float] = {}
        std_bars_tp: dict[str, float] = {}
        std_bars_sl: dict[str, float] = {}
        std_bars_all: dict[str, float] = {}
        count_tp: dict[str, int] = {}
        count_sl: dict[str, int] = {}
        count_timeout: dict[str, int] = {}
        tp_faster_than_sl: dict[str, bool] = {}
        speed_advantage_tp: dict[str, float] = {}

        for q in q_labels:
            q_data = df.filter(pl.col("quantile") == q)

            # TP outcomes
            tp_data = q_data.filter(pl.col(cfg.label_col) == BarrierLabel.TAKE_PROFIT.value)
            n_tp = tp_data.height
            count_tp[q] = n_tp

            if n_tp > 0:
                tp_bars = tp_data[cfg.label_bars_col]
                mean_bars_tp[q] = float(tp_bars.mean() or 0.0)
                median_bars_tp[q] = float(tp_bars.median() or 0.0)
                std_bars_tp[q] = float(tp_bars.std() or 0.0)
            else:
                mean_bars_tp[q] = 0.0
                median_bars_tp[q] = 0.0
                std_bars_tp[q] = 0.0

            # SL outcomes
            sl_data = q_data.filter(pl.col(cfg.label_col) == BarrierLabel.STOP_LOSS.value)
            n_sl = sl_data.height
            count_sl[q] = n_sl

            if n_sl > 0:
                sl_bars = sl_data[cfg.label_bars_col]
                mean_bars_sl[q] = float(sl_bars.mean() or 0.0)
                median_bars_sl[q] = float(sl_bars.median() or 0.0)
                std_bars_sl[q] = float(sl_bars.std() or 0.0)
            else:
                mean_bars_sl[q] = 0.0
                median_bars_sl[q] = 0.0
                std_bars_sl[q] = 0.0

            # Timeout outcomes
            timeout_data = q_data.filter(pl.col(cfg.label_col) == BarrierLabel.TIMEOUT.value)
            n_timeout = timeout_data.height
            count_timeout[q] = n_timeout

            if n_timeout > 0:
                mean_bars_timeout[q] = float(timeout_data[cfg.label_bars_col].mean() or 0.0)
            else:
                mean_bars_timeout[q] = 0.0

            # All outcomes
            n_all = q_data.height
            if n_all > 0:
                all_bars = q_data[cfg.label_bars_col]
                mean_bars_all[q] = float(all_bars.mean() or 0.0)
                median_bars_all[q] = float(all_bars.median() or 0.0)
                std_bars_all[q] = float(all_bars.std() or 0.0)
            else:
                mean_bars_all[q] = 0.0
                median_bars_all[q] = 0.0
                std_bars_all[q] = 0.0

            # Speed analysis: is TP reached faster than SL?
            if n_tp > 0 and n_sl > 0:
                tp_faster = mean_bars_tp[q] < mean_bars_sl[q]
                speed_adv = mean_bars_sl[q] - mean_bars_tp[q]
            elif n_tp > 0:
                tp_faster = True
                speed_adv = 0.0
            elif n_sl > 0:
                tp_faster = False
                speed_adv = 0.0
            else:
                tp_faster = False
                speed_adv = 0.0

            tp_faster_than_sl[q] = tp_faster
            speed_advantage_tp[q] = speed_adv
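            # Sign convention: speed_advantage_tp > 0 means TP exits resolve in
            # fewer bars than SL exits for this quantile; it stays 0.0 whenever
            # one of the two outcomes never occurs.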

        # Overall statistics
        total_obs = df.height
        all_bars = df[cfg.label_bars_col]
        overall_mean_bars = float(all_bars.mean() or 0.0)
        overall_median_bars = float(all_bars.median() or 0.0)

        tp_all = df.filter(pl.col(cfg.label_col) == BarrierLabel.TAKE_PROFIT.value)
        overall_mean_bars_tp = (
            float(tp_all[cfg.label_bars_col].mean() or 0.0) if tp_all.height > 0 else 0.0
        )

        sl_all = df.filter(pl.col(cfg.label_col) == BarrierLabel.STOP_LOSS.value)
        overall_mean_bars_sl = (
            float(sl_all[cfg.label_bars_col].mean() or 0.0) if sl_all.height > 0 else 0.0
        )

        self._time_to_target_result = TimeToTargetResult(
            n_quantiles=cfg.n_quantiles,
            quantile_labels=q_labels,
            mean_bars_tp=mean_bars_tp,
            mean_bars_sl=mean_bars_sl,
            mean_bars_timeout=mean_bars_timeout,
            mean_bars_all=mean_bars_all,
            median_bars_tp=median_bars_tp,
            median_bars_sl=median_bars_sl,
            median_bars_all=median_bars_all,
            std_bars_tp=std_bars_tp,
            std_bars_sl=std_bars_sl,
            std_bars_all=std_bars_all,
            count_tp=count_tp,
            count_sl=count_sl,
            count_timeout=count_timeout,
            overall_mean_bars=overall_mean_bars,
            overall_median_bars=overall_median_bars,
            overall_mean_bars_tp=overall_mean_bars_tp,
            overall_mean_bars_sl=overall_mean_bars_sl,
            n_observations=total_obs,
            tp_faster_than_sl=tp_faster_than_sl,
            speed_advantage_tp=speed_advantage_tp,
        )

        return self._time_to_target_result

    def _analyze_monotonicity(
        self,
        values: list[float],
    ) -> tuple[bool, str, float]:
        """Analyze monotonicity of values across quantiles.

        Parameters
        ----------
        values : list[float]
            Values for each quantile (ordered by quantile rank).

        Returns
        -------
        tuple[bool, str, float]
            (is_monotonic, direction, spearman_correlation)
            direction is 'increasing', 'decreasing', or 'none'
        """
        if len(values) < 2:
            return False, "none", 0.0

        # Remove any NaN/inf values for correlation
        valid_values = [v for v in values if np.isfinite(v)]
        if len(valid_values) < 2:
            return False, "none", 0.0

        # Spearman correlation with rank
        ranks = list(range(len(valid_values)))
        try:
            spearman_corr, _ = stats.spearmanr(ranks, valid_values)
        except Exception:
            spearman_corr = 0.0

        spearman_corr = float(spearman_corr) if np.isfinite(spearman_corr) else 0.0

        # Check monotonicity (non-strict: ties are allowed, but at least one
        # step must move strictly in the claimed direction)
        diffs = [values[i + 1] - values[i] for i in range(len(values) - 1)]
        all_increasing = all(d >= 0 for d in diffs) and any(d > 0 for d in diffs)
        all_decreasing = all(d <= 0 for d in diffs) and any(d < 0 for d in diffs)
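        # Note: this check uses the raw `values`; any NaN makes every comparison
        # False, so a series containing NaN is reported as non-monotonic.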

        if all_increasing:
            return True, "increasing", spearman_corr
        elif all_decreasing:
            return True, "decreasing", spearman_corr
        else:
            return False, "none", spearman_corr

    def create_tear_sheet(
        self,
        include_time_to_target: bool = True,
        include_figures: bool = True,
        theme: str | None = None,
    ) -> BarrierTearSheet:
        """Create a comprehensive tear sheet with all analysis results.

        Parameters
        ----------
        include_time_to_target : bool, default=True
            If True, include time-to-target analysis. Requires `label_bars`
            column in barrier_labels. Set to False if the column is not available.
        include_figures : bool, default=True
            If True, generate Plotly figures for visualization.
            Set to False to skip figure generation (faster).
        theme : str | None
            Plot theme: 'default', 'dark', 'print', 'presentation'.
            If None, uses the default theme.

        Returns
        -------
        BarrierTearSheet
            Complete results including hit rates, profit factor,
            precision/recall, time-to-target, figures, and metadata.

        Examples
        --------
        >>> tear_sheet = analysis.create_tear_sheet()
        >>> tear_sheet.save_html("barrier_analysis.html")
        >>> print(tear_sheet.summary())
        """
        # Compute all metrics
        hit_rate = self.compute_hit_rates()
        profit_factor = self.compute_profit_factor()
        precision_recall = self.compute_precision_recall()

        # Time-to-target is optional (requires label_bars column)
        time_to_target = None
        if include_time_to_target:
            try:
                time_to_target = self.compute_time_to_target()
            except ValueError:
                # label_bars column not available, skip
                pass

        # Generate figures if requested
        figures: dict[str, str] = {}
        if include_figures:
            figures = self._generate_figures(
                hit_rate=hit_rate,
                profit_factor=profit_factor,
                precision_recall=precision_recall,
                time_to_target=time_to_target,
                theme=theme,
            )

        return BarrierTearSheet(
            hit_rate_result=hit_rate,
            profit_factor_result=profit_factor,
            precision_recall_result=precision_recall,
            time_to_target_result=time_to_target,
            signal_name=self.config.signal_name,
            n_assets=self.n_assets,
            n_dates=self.n_dates,
            n_observations=self.n_observations,
            date_range=self.date_range,
            figures=figures,
        )

    def _generate_figures(
        self,
        hit_rate: HitRateResult,
        profit_factor: ProfitFactorResult,
        precision_recall: PrecisionRecallResult,
        time_to_target: TimeToTargetResult | None,
        theme: str | None = None,
    ) -> dict[str, str]:
        """Generate Plotly figures for the tear sheet.

        Parameters
        ----------
        hit_rate : HitRateResult
            Hit rate analysis results.
        profit_factor : ProfitFactorResult
            Profit factor analysis results.
        precision_recall : PrecisionRecallResult
            Precision/recall analysis results.
        time_to_target : TimeToTargetResult | None
            Time-to-target analysis results (optional).
        theme : str | None
            Plot theme.

        Returns
        -------
        dict[str, str]
            Dict mapping figure names to JSON-serialized Plotly figures.
        """
        import plotly.io as pio

        from ml4t.diagnostic.visualization.barrier_plots import (
            plot_hit_rate_heatmap,
            plot_precision_recall_curve,
            plot_profit_factor_bar,
            plot_time_to_target_box,
        )

        figures: dict[str, str] = {}

        # Hit Rate Heatmap
        try:
            fig = plot_hit_rate_heatmap(hit_rate, theme=theme)
            figures["hit_rate_heatmap"] = pio.to_json(fig)
        except Exception:
            pass  # Skip if visualization fails

        # Profit Factor Bar Chart
        try:
            fig = plot_profit_factor_bar(profit_factor, theme=theme)
            figures["profit_factor_bar"] = pio.to_json(fig)
        except Exception:
            pass

        # Precision/Recall Curve
        try:
            fig = plot_precision_recall_curve(precision_recall, theme=theme)
            figures["precision_recall_curve"] = pio.to_json(fig)
        except Exception:
            pass

        # Time-to-Target Box Plots (if available)
        if time_to_target is not None:
            try:
                fig = plot_time_to_target_box(
                    time_to_target, outcome_type="comparison", theme=theme
                )
                figures["time_to_target_comparison"] = pio.to_json(fig)
            except Exception:
                pass

        return figures
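
For orientation, a minimal end-to-end sketch of driving this module. The imports and method calls mirror the class docstring above; the toy frames and the column names (date, asset, signal, label, label_return, label_bars) are illustrative assumptions, not confirmed BarrierConfig defaults. If your column names differ, the listing suggests setting date_col, asset_col, signal_col, label_col, label_return_col, and label_bars_col on the config, since those fields are read throughout.

    import polars as pl

    from ml4t.diagnostic.config import BarrierConfig
    from ml4t.diagnostic.evaluation import BarrierAnalysis

    # Toy inputs; real labels come from ml4t.features.triple_barrier_labels().
    signals_df = pl.DataFrame({
        "date": ["2024-01-02"] * 4,
        "asset": ["AAPL", "MSFT", "NVDA", "TSLA"],
        "signal": [0.8, 0.1, -0.4, 0.6],
    })
    barriers_df = pl.DataFrame({
        "date": ["2024-01-02"] * 4,
        "asset": ["AAPL", "MSFT", "NVDA", "TSLA"],
        "label": [1, 0, -1, 1],  # 1 = TP, 0 = timeout, -1 = SL
        "label_return": [0.03, 0.0, -0.02, 0.025],
        "label_bars": [5, 20, 8, 3],
    })

    # Quartiles instead of deciles, since the toy sample is tiny.
    config = BarrierConfig(n_quantiles=4)
    analysis = BarrierAnalysis(signals_df, barriers_df, config=config)

    hit_rates = analysis.compute_hit_rates()  # per-quantile TP/SL/timeout rates + chi-square
    pf = analysis.compute_profit_factor()     # sum(TP) / |sum(SL)| per quantile
    tear_sheet = analysis.create_tear_sheet(include_figures=False)
    print(tear_sheet.summary())

On real data, a statistically meaningful read requires far more observations per quantile; the tiny frames here only exercise the API surface shown in the listing.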