ml4t-diagnostic 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ml4t/diagnostic/AGENT.md +25 -0
- ml4t/diagnostic/__init__.py +166 -0
- ml4t/diagnostic/backends/__init__.py +10 -0
- ml4t/diagnostic/backends/adapter.py +192 -0
- ml4t/diagnostic/backends/polars_backend.py +899 -0
- ml4t/diagnostic/caching/__init__.py +40 -0
- ml4t/diagnostic/caching/cache.py +331 -0
- ml4t/diagnostic/caching/decorators.py +131 -0
- ml4t/diagnostic/caching/smart_cache.py +339 -0
- ml4t/diagnostic/config/AGENT.md +24 -0
- ml4t/diagnostic/config/README.md +267 -0
- ml4t/diagnostic/config/__init__.py +219 -0
- ml4t/diagnostic/config/barrier_config.py +277 -0
- ml4t/diagnostic/config/base.py +301 -0
- ml4t/diagnostic/config/event_config.py +148 -0
- ml4t/diagnostic/config/feature_config.py +404 -0
- ml4t/diagnostic/config/multi_signal_config.py +55 -0
- ml4t/diagnostic/config/portfolio_config.py +215 -0
- ml4t/diagnostic/config/report_config.py +391 -0
- ml4t/diagnostic/config/sharpe_config.py +202 -0
- ml4t/diagnostic/config/signal_config.py +206 -0
- ml4t/diagnostic/config/trade_analysis_config.py +310 -0
- ml4t/diagnostic/config/validation.py +279 -0
- ml4t/diagnostic/core/__init__.py +29 -0
- ml4t/diagnostic/core/numba_utils.py +315 -0
- ml4t/diagnostic/core/purging.py +372 -0
- ml4t/diagnostic/core/sampling.py +471 -0
- ml4t/diagnostic/errors/__init__.py +205 -0
- ml4t/diagnostic/evaluation/AGENT.md +26 -0
- ml4t/diagnostic/evaluation/__init__.py +437 -0
- ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
- ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
- ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
- ml4t/diagnostic/evaluation/dashboard.py +715 -0
- ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
- ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
- ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
- ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
- ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
- ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
- ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
- ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
- ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
- ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
- ml4t/diagnostic/evaluation/event_analysis.py +647 -0
- ml4t/diagnostic/evaluation/excursion.py +390 -0
- ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
- ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
- ml4t/diagnostic/evaluation/framework.py +935 -0
- ml4t/diagnostic/evaluation/metric_registry.py +255 -0
- ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
- ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
- ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
- ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
- ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
- ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
- ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
- ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
- ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
- ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
- ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
- ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
- ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
- ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
- ml4t/diagnostic/evaluation/multi_signal.py +550 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
- ml4t/diagnostic/evaluation/report_generation.py +824 -0
- ml4t/diagnostic/evaluation/signal_selector.py +452 -0
- ml4t/diagnostic/evaluation/stat_registry.py +139 -0
- ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
- ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
- ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
- ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
- ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
- ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
- ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
- ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
- ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
- ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
- ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
- ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
- ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
- ml4t/diagnostic/evaluation/stats/moments.py +164 -0
- ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
- ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
- ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
- ml4t/diagnostic/evaluation/themes.py +330 -0
- ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
- ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
- ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
- ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
- ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
- ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
- ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
- ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
- ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
- ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
- ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
- ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
- ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
- ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
- ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
- ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
- ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
- ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
- ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
- ml4t/diagnostic/evaluation/validated_cv.py +535 -0
- ml4t/diagnostic/evaluation/visualization.py +1050 -0
- ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
- ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
- ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
- ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
- ml4t/diagnostic/integration/__init__.py +48 -0
- ml4t/diagnostic/integration/backtest_contract.py +671 -0
- ml4t/diagnostic/integration/data_contract.py +316 -0
- ml4t/diagnostic/integration/engineer_contract.py +226 -0
- ml4t/diagnostic/logging/__init__.py +77 -0
- ml4t/diagnostic/logging/logger.py +245 -0
- ml4t/diagnostic/logging/performance.py +234 -0
- ml4t/diagnostic/logging/progress.py +234 -0
- ml4t/diagnostic/logging/wandb.py +412 -0
- ml4t/diagnostic/metrics/__init__.py +9 -0
- ml4t/diagnostic/metrics/percentiles.py +128 -0
- ml4t/diagnostic/py.typed +1 -0
- ml4t/diagnostic/reporting/__init__.py +43 -0
- ml4t/diagnostic/reporting/base.py +130 -0
- ml4t/diagnostic/reporting/html_renderer.py +275 -0
- ml4t/diagnostic/reporting/json_renderer.py +51 -0
- ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
- ml4t/diagnostic/results/AGENT.md +24 -0
- ml4t/diagnostic/results/__init__.py +105 -0
- ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
- ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
- ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
- ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
- ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
- ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
- ml4t/diagnostic/results/barrier_results/validation.py +38 -0
- ml4t/diagnostic/results/base.py +177 -0
- ml4t/diagnostic/results/event_results.py +349 -0
- ml4t/diagnostic/results/feature_results.py +787 -0
- ml4t/diagnostic/results/multi_signal_results.py +431 -0
- ml4t/diagnostic/results/portfolio_results.py +281 -0
- ml4t/diagnostic/results/sharpe_results.py +448 -0
- ml4t/diagnostic/results/signal_results/__init__.py +74 -0
- ml4t/diagnostic/results/signal_results/ic.py +581 -0
- ml4t/diagnostic/results/signal_results/irtc.py +110 -0
- ml4t/diagnostic/results/signal_results/quantile.py +392 -0
- ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
- ml4t/diagnostic/results/signal_results/turnover.py +213 -0
- ml4t/diagnostic/results/signal_results/validation.py +147 -0
- ml4t/diagnostic/signal/AGENT.md +17 -0
- ml4t/diagnostic/signal/__init__.py +69 -0
- ml4t/diagnostic/signal/_report.py +152 -0
- ml4t/diagnostic/signal/_utils.py +261 -0
- ml4t/diagnostic/signal/core.py +275 -0
- ml4t/diagnostic/signal/quantile.py +148 -0
- ml4t/diagnostic/signal/result.py +214 -0
- ml4t/diagnostic/signal/signal_ic.py +129 -0
- ml4t/diagnostic/signal/turnover.py +182 -0
- ml4t/diagnostic/splitters/AGENT.md +19 -0
- ml4t/diagnostic/splitters/__init__.py +36 -0
- ml4t/diagnostic/splitters/base.py +501 -0
- ml4t/diagnostic/splitters/calendar.py +421 -0
- ml4t/diagnostic/splitters/calendar_config.py +91 -0
- ml4t/diagnostic/splitters/combinatorial.py +1064 -0
- ml4t/diagnostic/splitters/config.py +322 -0
- ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
- ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
- ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
- ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
- ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
- ml4t/diagnostic/splitters/group_isolation.py +329 -0
- ml4t/diagnostic/splitters/persistence.py +316 -0
- ml4t/diagnostic/splitters/utils.py +207 -0
- ml4t/diagnostic/splitters/walk_forward.py +757 -0
- ml4t/diagnostic/utils/__init__.py +42 -0
- ml4t/diagnostic/utils/config.py +542 -0
- ml4t/diagnostic/utils/dependencies.py +318 -0
- ml4t/diagnostic/utils/sessions.py +127 -0
- ml4t/diagnostic/validation/__init__.py +54 -0
- ml4t/diagnostic/validation/dataframe.py +274 -0
- ml4t/diagnostic/validation/returns.py +280 -0
- ml4t/diagnostic/validation/timeseries.py +299 -0
- ml4t/diagnostic/visualization/AGENT.md +19 -0
- ml4t/diagnostic/visualization/__init__.py +223 -0
- ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
- ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
- ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
- ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
- ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
- ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
- ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
- ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
- ml4t/diagnostic/visualization/barrier_plots.py +782 -0
- ml4t/diagnostic/visualization/core.py +1060 -0
- ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
- ml4t/diagnostic/visualization/dashboards/base.py +582 -0
- ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
- ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
- ml4t/diagnostic/visualization/dashboards.py +43 -0
- ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
- ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
- ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
- ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
- ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
- ml4t/diagnostic/visualization/feature_plots.py +888 -0
- ml4t/diagnostic/visualization/interaction_plots.py +618 -0
- ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
- ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
- ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
- ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
- ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
- ml4t/diagnostic/visualization/report_generation.py +1343 -0
- ml4t/diagnostic/visualization/signal/__init__.py +103 -0
- ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
- ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
- ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
- ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
- ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
- ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
- ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
- ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
- ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
- ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
- ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
- ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,1050 @@
|
|
|
1
|
+
"""Barrier Analysis module for triple barrier outcome evaluation.
|
|
2
|
+
|
|
3
|
+
This module provides analysis of signal quality using triple barrier outcomes
|
|
4
|
+
(take-profit, stop-loss, timeout) instead of simple forward returns.
|
|
5
|
+
|
|
6
|
+
The BarrierAnalysis class computes:
|
|
7
|
+
- Hit rates by signal decile (% TP, % SL, % timeout)
|
|
8
|
+
- Profit factor by decile (sum TP returns / |sum SL returns|)
|
|
9
|
+
- Statistical tests for signal-outcome independence (chi-square)
|
|
10
|
+
- Monotonicity tests for signal strength vs outcome relationship
|
|
11
|
+
|
|
12
|
+
Triple barrier outcomes from ml4t.features:
|
|
13
|
+
- label: int (-1=SL hit, 0=timeout, 1=TP hit)
|
|
14
|
+
- label_return: float (actual return at exit)
|
|
15
|
+
- label_bars: int (bars from entry to exit)
|
|
16
|
+
|
|
17
|
+
References
|
|
18
|
+
----------
|
|
19
|
+
Lopez de Prado, M. (2018). "Advances in Financial Machine Learning"
|
|
20
|
+
Chapter 3: Labeling (Triple Barrier Method)
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
from __future__ import annotations
|
|
24
|
+
|
|
25
|
+
import warnings
|
|
26
|
+
from typing import TYPE_CHECKING
|
|
27
|
+
|
|
28
|
+
import numpy as np
|
|
29
|
+
import polars as pl
|
|
30
|
+
from scipy import stats
|
|
31
|
+
|
|
32
|
+
from ml4t.diagnostic.config.barrier_config import BarrierConfig, BarrierLabel
|
|
33
|
+
from ml4t.diagnostic.results.barrier_results import (
|
|
34
|
+
BarrierTearSheet,
|
|
35
|
+
HitRateResult,
|
|
36
|
+
PrecisionRecallResult,
|
|
37
|
+
ProfitFactorResult,
|
|
38
|
+
TimeToTargetResult,
|
|
39
|
+
)
|
|
40
|
+
|
|
41
|
+
if TYPE_CHECKING:
|
|
42
|
+
pass
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class BarrierAnalysis:
|
|
46
|
+
"""Analyze signal quality using triple barrier outcomes.
|
|
47
|
+
|
|
48
|
+
This class evaluates how well a signal predicts barrier outcomes
|
|
49
|
+
(take-profit hit, stop-loss hit, or timeout) rather than raw returns.
|
|
50
|
+
|
|
51
|
+
Parameters
|
|
52
|
+
----------
|
|
53
|
+
signal_data : pl.DataFrame
|
|
54
|
+
DataFrame with columns: [date_col, asset_col, signal_col]
|
|
55
|
+
Contains signal values for each asset-date pair.
|
|
56
|
+
|
|
57
|
+
barrier_labels : pl.DataFrame
|
|
58
|
+
DataFrame with columns: [date_col, asset_col, label_col, label_return_col, label_bars_col]
|
|
59
|
+
Contains triple barrier outcomes from ml4t.features.triple_barrier_labels().
|
|
60
|
+
|
|
61
|
+
config : BarrierConfig | None, optional
|
|
62
|
+
Configuration for analysis. Uses defaults if not provided.
|
|
63
|
+
|
|
64
|
+
Examples
|
|
65
|
+
--------
|
|
66
|
+
>>> from ml4t.diagnostic.evaluation import BarrierAnalysis
|
|
67
|
+
>>> from ml4t.diagnostic.config import BarrierConfig
|
|
68
|
+
>>>
|
|
69
|
+
>>> # Basic usage
|
|
70
|
+
>>> analysis = BarrierAnalysis(signals_df, barriers_df)
|
|
71
|
+
>>> hit_rates = analysis.compute_hit_rates()
|
|
72
|
+
>>> print(hit_rates.summary())
|
|
73
|
+
>>>
|
|
74
|
+
>>> # With custom config
|
|
75
|
+
>>> config = BarrierConfig(n_quantiles=5)
|
|
76
|
+
>>> analysis = BarrierAnalysis(signals_df, barriers_df, config=config)
|
|
77
|
+
>>> profit_factor = analysis.compute_profit_factor()
|
|
78
|
+
"""
|
|
79
|
+
|
|
80
|
+
def __init__(
    self,
    signal_data: pl.DataFrame,
    barrier_labels: pl.DataFrame,
    config: BarrierConfig | None = None,
) -> None:
    """Set up the analysis by validating, joining and bucketing the inputs.

    Parameters
    ----------
    signal_data : pl.DataFrame
        Per asset-date signal values; must contain the configured date,
        asset and signal columns.
    barrier_labels : pl.DataFrame
        Triple barrier outcomes; must contain the configured date, asset,
        label and label-return columns.
    config : BarrierConfig | None
        Analysis settings. A falsy value (including ``None``) is replaced
        by a default ``BarrierConfig()``.

    Raises
    ------
    ValueError
        Propagated from ``_validate_inputs`` / ``_prepare_data`` when a
        required column is missing, a frame is empty, labels fall outside
        {-1, 0, 1}, or no rows survive the join and filtering.
    """
    # Config must be resolved first: validation and preparation both read it.
    self.config = config if config else BarrierConfig()
    self._validate_inputs(signal_data, barrier_labels)

    # Keep the caller's raw frames next to the joined/prepared view.
    self._signal_data, self._barrier_labels = signal_data, barrier_labels
    self._merged_data = self._prepare_data(signal_data, barrier_labels)

    # Result caches, populated lazily by the compute_* methods
    # (e.g. compute_hit_rates returns its cached result on repeat calls).
    self._hit_rate_result: HitRateResult | None = None
    self._profit_factor_result: ProfitFactorResult | None = None
    self._precision_recall_result: PrecisionRecallResult | None = None
    self._time_to_target_result: TimeToTargetResult | None = None
|
|
117
|
+
|
|
118
|
+
def _validate_inputs(
    self,
    signal_data: pl.DataFrame,
    barrier_labels: pl.DataFrame,
) -> None:
    """Check that both input frames are usable before they are joined.

    Checks run in a fixed order and the first failure raises:

    1. required columns (``signal_data`` first, then ``barrier_labels``);
    2. non-emptiness (same frame order);
    3. the label column only holds -1, 0 or 1.

    Raises
    ------
    ValueError
        If a frame lacks a required column, has zero rows, or the label
        column contains a value outside {-1, 0, 1}.
    """
    cfg = self.config
    checked_frames = (
        ("signal_data", signal_data, {cfg.date_col, cfg.asset_col, cfg.signal_col}),
        (
            "barrier_labels",
            barrier_labels,
            {cfg.date_col, cfg.asset_col, cfg.label_col, cfg.label_return_col},
        ),
    )

    # Pass 1: every required column must be present.
    for frame_name, frame, required in checked_frames:
        present = set(frame.columns)
        absent = required - present
        if absent:
            raise ValueError(
                f"{frame_name} missing required columns: {absent}. "
                f"Available columns: {present}"
            )

    # Pass 2: neither frame may be empty. Kept as a separate pass so a
    # column error on either frame takes precedence over an empty frame.
    for frame_name, frame, _required in checked_frames:
        if frame.height == 0:
            raise ValueError(f"{frame_name} is empty")

    # Pass 3: labels must use the -1 / 0 / 1 (SL / timeout / TP) encoding.
    allowed = {-1, 0, 1}
    observed = set(barrier_labels[cfg.label_col].unique().to_list())
    unexpected = observed - allowed
    if unexpected:
        raise ValueError(
            f"barrier_labels[{cfg.label_col}] contains invalid values: {unexpected}. "
            f"Expected values: {allowed} (-1=SL, 0=timeout, 1=TP)"
        )
|
|
167
|
+
|
|
168
|
+
def _prepare_data(
    self,
    signal_data: pl.DataFrame,
    barrier_labels: pl.DataFrame,
) -> pl.DataFrame:
    """Join signals to barrier outcomes, clean the signal and bucket it.

    Steps:

    1. inner-join on (date, asset);
    2. drop rows whose signal is null or, for float signals, NaN;
    3. if ``filter_zscore`` is set, drop rows whose signal z-score exceeds it;
    4. attach a ``quantile`` column via ``_add_quantile_labels``.

    Returns
    -------
    pl.DataFrame
        Merged DataFrame with signal values and barrier outcomes,
        plus computed quantile labels.

    Raises
    ------
    ValueError
        If the join produces no rows, or no rows remain after removing
        missing signals and outliers.
    """
    cfg = self.config
    signal = pl.col(cfg.signal_col)

    # Merge on date and asset
    merged = signal_data.join(
        barrier_labels,
        on=[cfg.date_col, cfg.asset_col],
        how="inner",
    )

    if merged.height == 0:
        raise ValueError(
            "No matching rows after merging signal_data and barrier_labels. "
            "Check that date and asset columns match."
        )

    # Remove missing signals BEFORE computing z-score statistics.
    # ``drop_nulls`` does not remove float NaN. A single NaN would make the
    # mean/std NaN, which silently disables the outlier filter, and NaN rows
    # would still be assigned to a quantile bucket.
    merged = merged.drop_nulls(subset=[cfg.signal_col])
    if merged.schema[cfg.signal_col].is_float():
        merged = merged.filter(signal.is_not_nan())

    # Filter outliers if configured
    if cfg.filter_zscore is not None:
        signal_mean = merged[cfg.signal_col].mean()
        signal_std = merged[cfg.signal_col].std()
        # std is None for too few rows; a zero std means nothing to filter.
        if signal_std is not None and signal_std > 0:
            merged = merged.filter(
                ((signal - signal_mean) / signal_std).abs() <= cfg.filter_zscore
            )

    if merged.height == 0:
        raise ValueError("No valid observations after filtering NaN signals and outliers")

    # Add quantile labels
    return self._add_quantile_labels(merged)
|
|
215
|
+
|
|
216
|
+
def _add_quantile_labels(self, df: pl.DataFrame) -> pl.DataFrame:
    """Add quantile labels to DataFrame based on signal values.

    Labels are ``D1`` .. ``D{n_quantiles}``. Polars ``cut``/``qcut`` assign
    them to bins in ascending order of the break points.

    - ``decile_method`` value ``"quantile"``: equal-frequency bins via
      ``qcut``, with duplicate edges allowed.
    - Any other value: equal-width bins over the finite signal range,
      mirroring ``pandas.cut`` with an integer bin count.

    Parameters
    ----------
    df : pl.DataFrame
        DataFrame with signal column.

    Returns
    -------
    pl.DataFrame
        DataFrame with added 'quantile' column.

    Raises
    ------
    ValueError
        If equal-width binning is requested but the signal column has no
        finite values.
    """
    cfg = self.config
    n_q = cfg.n_quantiles
    labels = self.quantile_labels

    if cfg.decile_method.value == "quantile":
        # Equal frequency bins (like pd.qcut)
        bucket = pl.col(cfg.signal_col).qcut(n_q, labels=labels, allow_duplicates=True)
    else:
        # Equal width bins (like pd.cut). Polars ``Expr.cut`` takes explicit
        # break points rather than a bin count (passing ``n_q`` raises), so
        # derive the n_q - 1 interior edges. The outer bins are open-ended.
        values = df[cfg.signal_col].to_numpy().astype(np.float64)
        finite = values[np.isfinite(values)]
        if finite.size == 0:
            raise ValueError(
                f"Cannot build equal-width bins: {cfg.signal_col} has no finite values"
            )
        lo, hi = float(finite.min()), float(finite.max())
        if hi <= lo:
            # Constant signal: widen the range by 0.1% per side, as
            # pandas.cut does, so that the break points are distinct.
            pad = 0.001 * abs(lo) if lo != 0 else 0.001
            lo, hi = lo - pad, hi + pad
        breaks = np.linspace(lo, hi, n_q + 1)[1:-1].tolist()
        bucket = pl.col(cfg.signal_col).cut(breaks, labels=labels)

    return df.with_columns(bucket.alias("quantile"))
|
|
249
|
+
|
|
250
|
+
@property
|
|
251
|
+
def merged_data(self) -> pl.DataFrame:
|
|
252
|
+
"""Get the merged and prepared data."""
|
|
253
|
+
return self._merged_data
|
|
254
|
+
|
|
255
|
+
@property
|
|
256
|
+
def n_observations(self) -> int:
|
|
257
|
+
"""Total number of observations after merging."""
|
|
258
|
+
return self._merged_data.height
|
|
259
|
+
|
|
260
|
+
@property
|
|
261
|
+
def n_assets(self) -> int:
|
|
262
|
+
"""Number of unique assets."""
|
|
263
|
+
return self._merged_data[self.config.asset_col].n_unique()
|
|
264
|
+
|
|
265
|
+
@property
|
|
266
|
+
def n_dates(self) -> int:
|
|
267
|
+
"""Number of unique dates."""
|
|
268
|
+
return self._merged_data[self.config.date_col].n_unique()
|
|
269
|
+
|
|
270
|
+
@property
|
|
271
|
+
def date_range(self) -> tuple[str, str]:
|
|
272
|
+
"""Date range (start, end) as ISO strings."""
|
|
273
|
+
dates = self._merged_data[self.config.date_col]
|
|
274
|
+
min_date = dates.min()
|
|
275
|
+
max_date = dates.max()
|
|
276
|
+
return (str(min_date), str(max_date))
|
|
277
|
+
|
|
278
|
+
@property
|
|
279
|
+
def quantile_labels(self) -> list[str]:
|
|
280
|
+
"""List of quantile labels used."""
|
|
281
|
+
return [f"D{i + 1}" for i in range(self.config.n_quantiles)]
|
|
282
|
+
|
|
283
|
+
def compute_hit_rates(self) -> HitRateResult:
    """Compute hit rates by signal decile.

    For each signal quantile, calculates the share of observations that hit
    TP, SL, or timeout barriers. It also runs a chi-square test of
    independence between signal quantile and barrier outcome.

    The result is cached on the instance; later calls return the same
    object without recomputing.

    Returns
    -------
    HitRateResult
        Results containing:

        - hit rates and counts per quantile;
        - chi-square statistic, p-value, dof and significance flag;
        - overall hit rates;
        - monotonicity analysis of the TP rate.

    Warns
    -----
    UserWarning
        When fewer than two non-empty quantiles or fewer than two distinct
        outcomes exist. The chi-square test is then skipped and reported as
        statistic 0.0, p-value 1.0, dof 0.

    Examples
    --------
    >>> result = analysis.compute_hit_rates()
    >>> print(result.summary())
    >>> df = result.get_dataframe("hit_rates")
    """
    if self._hit_rate_result is not None:
        return self._hit_rate_result

    cfg = self.config
    df = self._merged_data
    q_labels = self.quantile_labels

    # Hoisted loop invariants: the label column expression and encodings.
    label = pl.col(cfg.label_col)
    tp_value = BarrierLabel.TAKE_PROFIT.value
    sl_value = BarrierLabel.STOP_LOSS.value
    timeout_value = BarrierLabel.TIMEOUT.value

    # Initialize containers
    hit_rate_tp: dict[str, float] = {}
    hit_rate_sl: dict[str, float] = {}
    hit_rate_timeout: dict[str, float] = {}
    count_tp: dict[str, int] = {}
    count_sl: dict[str, int] = {}
    count_timeout: dict[str, int] = {}
    count_total: dict[str, int] = {}

    # Contingency table for the chi-square test.
    # Rows: quantiles. Columns: outcomes ordered (SL, Timeout, TP).
    contingency = np.zeros((cfg.n_quantiles, 3), dtype=np.int64)

    for i, q in enumerate(q_labels):
        q_data = df.filter(pl.col("quantile") == q)
        n_total = q_data.height

        if n_total == 0:
            # Empty quantile: report zero rates/counts and leave its
            # contingency row at zero (it is dropped before the test).
            # NOTE(review): these 0.0 rates still feed the monotonicity
            # analysis below as if they were observed values.
            hit_rate_tp[q] = 0.0
            hit_rate_sl[q] = 0.0
            hit_rate_timeout[q] = 0.0
            count_tp[q] = 0
            count_sl[q] = 0
            count_timeout[q] = 0
            count_total[q] = 0
            continue

        # Count outcomes
        n_tp = q_data.filter(label == tp_value).height
        n_sl = q_data.filter(label == sl_value).height
        n_timeout = q_data.filter(label == timeout_value).height

        # Hit rates
        hit_rate_tp[q] = n_tp / n_total
        hit_rate_sl[q] = n_sl / n_total
        hit_rate_timeout[q] = n_timeout / n_total

        # Counts
        count_tp[q] = n_tp
        count_sl[q] = n_sl
        count_timeout[q] = n_timeout
        count_total[q] = n_total

        # Contingency table row
        contingency[i, 0] = n_sl
        contingency[i, 1] = n_timeout
        contingency[i, 2] = n_tp

    # Chi-square test for independence
    # H0: signal quantile and barrier outcome are independent.
    # H1: they are dependent (signal predicts outcome).
    # All-zero rows/columns are removed because they make expected
    # frequencies zero.
    row_sums = contingency.sum(axis=1)
    col_sums = contingency.sum(axis=0)
    valid_rows = row_sums > 0
    valid_cols = col_sums > 0

    if valid_rows.sum() < 2 or valid_cols.sum() < 2:
        # Not enough data for the chi-square test
        chi2_stat = 0.0
        chi2_p = 1.0
        chi2_dof = 0
        warnings.warn(
            "Insufficient variation in data for chi-square test. "
            "Need at least 2 non-empty quantiles and 2 different outcomes.",
            UserWarning,
            stacklevel=2,  # attribute the warning to the caller
        )
    else:
        contingency_valid = contingency[valid_rows][:, valid_cols]
        # scipy applies Yates' continuity correction by default when dof == 1.
        chi2_stat, chi2_p, chi2_dof, _ = stats.chi2_contingency(contingency_valid)

    # Overall hit rates across all prepared observations
    total_obs = df.height
    overall_tp = df.filter(label == tp_value).height / total_obs
    overall_sl = df.filter(label == sl_value).height / total_obs
    overall_timeout = df.filter(label == timeout_value).height / total_obs

    # Monotonicity analysis for TP rate (quantile order D1 .. Dn)
    tp_rates = [hit_rate_tp[q] for q in q_labels]
    tp_monotonic, tp_direction, tp_spearman = self._analyze_monotonicity(tp_rates)

    self._hit_rate_result = HitRateResult(
        n_quantiles=cfg.n_quantiles,
        quantile_labels=q_labels,
        hit_rate_tp=hit_rate_tp,
        hit_rate_sl=hit_rate_sl,
        hit_rate_timeout=hit_rate_timeout,
        count_tp=count_tp,
        count_sl=count_sl,
        count_timeout=count_timeout,
        count_total=count_total,
        chi2_statistic=float(chi2_stat),
        chi2_p_value=float(chi2_p),
        chi2_dof=int(chi2_dof),
        # Coerce: comparing scipy's numpy p-value yields numpy.bool_, while
        # the fallback branch yields a Python bool. Always pass a real bool.
        is_significant=bool(chi2_p < cfg.significance_level),
        significance_level=cfg.significance_level,
        overall_hit_rate_tp=overall_tp,
        overall_hit_rate_sl=overall_sl,
        overall_hit_rate_timeout=overall_timeout,
        n_observations=total_obs,
        tp_rate_monotonic=tp_monotonic,
        tp_rate_direction=tp_direction,
        tp_rate_spearman=tp_spearman,
    )

    return self._hit_rate_result
|
|
426
|
+
|
|
427
|
+
def compute_profit_factor(self) -> ProfitFactorResult:
    """Compute profit factor by signal decile.

    Profit Factor = Sum(TP returns) / (|Sum(SL returns)| + epsilon)

    A profit factor > 1 indicates the quantile is net profitable
    when trading based on the signal. ``config.profit_factor_epsilon`` is
    added to the denominator so a quantile without stop-loss returns does
    not divide by zero. A quantile whose summed TP return is not positive
    (including one with no TP outcomes at all) gets a profit factor of 0.0.

    Returns
    -------
    ProfitFactorResult
        Results containing profit factor per quantile and
        return statistics. The result is cached on the instance and
        returned directly on subsequent calls.

    Examples
    --------
    >>> result = analysis.compute_profit_factor()
    >>> print(result.summary())
    >>> df = result.get_dataframe()
    """
    if self._profit_factor_result is not None:
        return self._profit_factor_result

    cfg = self.config
    df = self._merged_data
    q_labels = self.quantile_labels
    eps = cfg.profit_factor_epsilon

    def _to_float(value) -> float:
        """Convert an aggregation result to float, mapping None to 0.0."""
        return float(value) if value is not None else 0.0

    def _label_returns(frame, label) -> tuple[float, int]:
        """Return (sum of returns, row count) for rows of ``frame`` with ``label``."""
        subset = frame.filter(pl.col(cfg.label_col) == label.value)
        if subset.height == 0:
            return 0.0, 0
        return _to_float(subset[cfg.label_return_col].sum()), subset.height

    def _profit_factor(tp_sum: float, sl_sum: float) -> float:
        """PF = tp_sum / (|sl_sum| + eps); 0.0 unless tp_sum is positive."""
        # abs(0.0) + eps == eps, so no separate zero-SL branch is needed.
        return tp_sum / (abs(sl_sum) + eps) if tp_sum > 0 else 0.0

    # Initialize containers
    profit_factor: dict[str, float] = {}
    sum_tp_returns: dict[str, float] = {}
    sum_sl_returns: dict[str, float] = {}
    sum_timeout_returns: dict[str, float] = {}
    sum_all_returns: dict[str, float] = {}
    avg_tp_return: dict[str, float] = {}
    avg_sl_return: dict[str, float] = {}
    avg_return: dict[str, float] = {}
    count_tp: dict[str, int] = {}
    count_sl: dict[str, int] = {}
    count_total: dict[str, int] = {}

    for q in q_labels:
        q_data = df.filter(pl.col("quantile") == q)
        n_total = q_data.height

        # Sums are 0.0 and counts 0 for empty groups, so an empty quantile
        # naturally yields all-zero statistics.
        s_tp, n_tp = _label_returns(q_data, BarrierLabel.TAKE_PROFIT)
        s_sl, n_sl = _label_returns(q_data, BarrierLabel.STOP_LOSS)
        s_timeout, _ = _label_returns(q_data, BarrierLabel.TIMEOUT)
        s_all = _to_float(q_data[cfg.label_return_col].sum()) if n_total > 0 else 0.0

        profit_factor[q] = _profit_factor(s_tp, s_sl)
        sum_tp_returns[q] = s_tp
        sum_sl_returns[q] = s_sl
        sum_timeout_returns[q] = s_timeout
        sum_all_returns[q] = s_all
        avg_tp_return[q] = s_tp / n_tp if n_tp > 0 else 0.0
        avg_sl_return[q] = s_sl / n_sl if n_sl > 0 else 0.0
        avg_return[q] = s_all / n_total if n_total > 0 else 0.0
        count_tp[q] = n_tp
        count_sl[q] = n_sl
        count_total[q] = n_total

    # Overall metrics
    total_obs = df.height
    total_tp_returns, _ = _label_returns(df, BarrierLabel.TAKE_PROFIT)
    total_sl_returns, _ = _label_returns(df, BarrierLabel.STOP_LOSS)
    overall_pf = _profit_factor(total_tp_returns, total_sl_returns)

    overall_sum = _to_float(df[cfg.label_return_col].sum())
    # Guard against an empty dataset instead of raising ZeroDivisionError.
    overall_avg = overall_sum / total_obs if total_obs > 0 else 0.0

    # Monotonicity analysis for profit factor
    pf_values = [profit_factor[q] for q in q_labels]
    pf_monotonic, pf_direction, pf_spearman = self._analyze_monotonicity(pf_values)

    self._profit_factor_result = ProfitFactorResult(
        n_quantiles=cfg.n_quantiles,
        quantile_labels=q_labels,
        profit_factor=profit_factor,
        sum_tp_returns=sum_tp_returns,
        sum_sl_returns=sum_sl_returns,
        sum_timeout_returns=sum_timeout_returns,
        sum_all_returns=sum_all_returns,
        avg_tp_return=avg_tp_return,
        avg_sl_return=avg_sl_return,
        avg_return=avg_return,
        count_tp=count_tp,
        count_sl=count_sl,
        count_total=count_total,
        overall_profit_factor=overall_pf,
        overall_sum_returns=overall_sum,
        overall_avg_return=overall_avg,
        n_observations=total_obs,
        pf_monotonic=pf_monotonic,
        pf_direction=pf_direction,
        pf_spearman=pf_spearman,
    )

    return self._profit_factor_result
|
|
568
|
+
|
|
569
|
+
def compute_precision_recall(self) -> PrecisionRecallResult:
    """Compute precision, recall, F1 and lift for take-profit outcomes.

    Per quantile:
    - Precision: P(TP | in quantile) = TP count / total in quantile
    - Recall: P(in quantile | TP) = TP in quantile / all TP
    - Lift: precision divided by the dataset-wide TP rate

    Cumulative metrics are accumulated from the last quantile label
    (highest signal) towards the first, and the cumulative F1 score is
    tracked to find the best cut-off. Ratios with a zero denominator are
    reported as 0.0.

    Returns
    -------
    PrecisionRecallResult
        Results containing precision, recall, F1, and lift metrics
        per quantile and cumulative from top down. Cached on the instance.

    Examples
    --------
    >>> result = analysis.compute_precision_recall()
    >>> print(result.summary())
    >>> df = result.get_dataframe("cumulative")
    """
    cached = self._precision_recall_result
    if cached is not None:
        return cached

    cfg = self.config
    df = self._merged_data
    q_labels = self.quantile_labels
    is_tp = pl.col(cfg.label_col) == BarrierLabel.TAKE_PROFIT.value

    def _ratio(numerator, denominator) -> float:
        """Divide, returning 0.0 when the denominator is not positive."""
        return numerator / denominator if denominator > 0 else 0.0

    # Dataset-wide baseline
    total_tp = df.filter(is_tp).height
    total_obs = df.height
    baseline_tp_rate = _ratio(total_tp, total_obs)

    precision_tp: dict[str, float] = {}
    recall_tp: dict[str, float] = {}
    lift_tp: dict[str, float] = {}
    tp_counts: dict[str, int] = {}
    total_counts: dict[str, int] = {}

    for label in q_labels:
        bucket = df.filter(pl.col("quantile") == label)
        bucket_size = bucket.height
        bucket_tp = bucket.filter(is_tp).height

        tp_counts[label] = bucket_tp
        total_counts[label] = bucket_size

        bucket_precision = _ratio(bucket_tp, bucket_size)
        precision_tp[label] = bucket_precision
        recall_tp[label] = _ratio(bucket_tp, total_tp)
        lift_tp[label] = _ratio(bucket_precision, baseline_tp_rate)

    cumulative_precision_tp: dict[str, float] = {}
    cumulative_recall_tp: dict[str, float] = {}
    cumulative_f1_tp: dict[str, float] = {}
    cumulative_lift_tp: dict[str, float] = {}

    running_tp = 0
    running_total = 0
    best_f1 = 0.0
    best_f1_q = q_labels[-1]  # fall back to the top quantile

    # Walk from the highest-signal quantile downwards.
    for label in reversed(q_labels):
        running_tp += tp_counts[label]
        running_total += total_counts[label]

        running_precision = _ratio(running_tp, running_total)
        running_recall = _ratio(running_tp, total_tp)
        cumulative_precision_tp[label] = running_precision
        cumulative_recall_tp[label] = running_recall

        pr_sum = running_precision + running_recall
        f1 = 2 * running_precision * running_recall / pr_sum if pr_sum > 0 else 0.0
        cumulative_f1_tp[label] = f1

        if f1 > best_f1:
            best_f1, best_f1_q = f1, label

        cumulative_lift_tp[label] = _ratio(running_precision, baseline_tp_rate)

    self._precision_recall_result = PrecisionRecallResult(
        n_quantiles=cfg.n_quantiles,
        quantile_labels=q_labels,
        precision_tp=precision_tp,
        recall_tp=recall_tp,
        cumulative_precision_tp=cumulative_precision_tp,
        cumulative_recall_tp=cumulative_recall_tp,
        cumulative_f1_tp=cumulative_f1_tp,
        lift_tp=lift_tp,
        cumulative_lift_tp=cumulative_lift_tp,
        baseline_tp_rate=baseline_tp_rate,
        total_tp_count=total_tp,
        n_observations=total_obs,
        best_f1_quantile=best_f1_q,
        best_f1_score=best_f1,
    )

    return self._precision_recall_result
|
|
693
|
+
|
|
694
|
+
def compute_time_to_target(self) -> TimeToTargetResult:
    """Compute time-to-target metrics by signal decile.

    For every signal quantile, summarises how many bars observations took
    to exit, split by barrier outcome (take-profit, stop-loss, timeout).
    Bar counts are read from the ``config.label_bars_col`` column.

    Returns
    -------
    TimeToTargetResult
        Mean, median and std of bars to exit per quantile and outcome type
        (timeouts report the mean only), overall averages, and a per-quantile
        TP-vs-SL speed comparison. Empty outcome groups report 0.0.
        The result is cached on the instance.

    Raises
    ------
    ValueError
        If label_bars column is not available in barrier_labels.

    Examples
    --------
    >>> result = analysis.compute_time_to_target()
    >>> print(result.summary())
    >>> df = result.get_dataframe("detailed")
    """
    cached = self._time_to_target_result
    if cached is not None:
        return cached

    cfg = self.config
    df = self._merged_data
    q_labels = self.quantile_labels
    bars_col = cfg.label_bars_col

    if bars_col not in df.columns:
        raise ValueError(
            f"Time-to-target analysis requires '{bars_col}' column in barrier_labels. "
            f"Available columns: {df.columns}"
        )

    outcome = pl.col(cfg.label_col)

    def _summarise(frame) -> tuple[float, float, float]:
        """Mean, median and std of the bar column; zeros for an empty frame."""
        if frame.height == 0:
            return 0.0, 0.0, 0.0
        bars = frame[bars_col]
        return (
            float(bars.mean() or 0.0),
            float(bars.median() or 0.0),
            float(bars.std() or 0.0),
        )

    mean_bars_tp: dict[str, float] = {}
    mean_bars_sl: dict[str, float] = {}
    mean_bars_timeout: dict[str, float] = {}
    mean_bars_all: dict[str, float] = {}
    median_bars_tp: dict[str, float] = {}
    median_bars_sl: dict[str, float] = {}
    median_bars_all: dict[str, float] = {}
    std_bars_tp: dict[str, float] = {}
    std_bars_sl: dict[str, float] = {}
    std_bars_all: dict[str, float] = {}
    count_tp: dict[str, int] = {}
    count_sl: dict[str, int] = {}
    count_timeout: dict[str, int] = {}
    tp_faster_than_sl: dict[str, bool] = {}
    speed_advantage_tp: dict[str, float] = {}

    for q in q_labels:
        q_data = df.filter(pl.col("quantile") == q)
        tp_data = q_data.filter(outcome == BarrierLabel.TAKE_PROFIT.value)
        sl_data = q_data.filter(outcome == BarrierLabel.STOP_LOSS.value)
        timeout_data = q_data.filter(outcome == BarrierLabel.TIMEOUT.value)
        n_tp, n_sl, n_timeout = tp_data.height, sl_data.height, timeout_data.height

        count_tp[q] = n_tp
        count_sl[q] = n_sl
        count_timeout[q] = n_timeout

        mean_bars_tp[q], median_bars_tp[q], std_bars_tp[q] = _summarise(tp_data)
        mean_bars_sl[q], median_bars_sl[q], std_bars_sl[q] = _summarise(sl_data)
        mean_bars_all[q], median_bars_all[q], std_bars_all[q] = _summarise(q_data)
        mean_bars_timeout[q] = (
            float(timeout_data[bars_col].mean() or 0.0) if n_timeout > 0 else 0.0
        )

        if n_tp > 0 and n_sl > 0:
            tp_faster_than_sl[q] = mean_bars_tp[q] < mean_bars_sl[q]
            speed_advantage_tp[q] = mean_bars_sl[q] - mean_bars_tp[q]
        else:
            # Only one side (or neither) observed: nothing to compare, so TP
            # counts as "faster" only when TP outcomes exist at all.
            tp_faster_than_sl[q] = n_tp > 0
            speed_advantage_tp[q] = 0.0

    # Dataset-wide statistics
    total_obs = df.height
    every_bar = df[bars_col]
    overall_mean_bars = float(every_bar.mean() or 0.0)
    overall_median_bars = float(every_bar.median() or 0.0)

    def _overall_mean(label) -> float:
        """Mean bars for all rows with the given barrier label (0.0 if none)."""
        subset = df.filter(outcome == label.value)
        return float(subset[bars_col].mean() or 0.0) if subset.height > 0 else 0.0

    overall_mean_bars_tp = _overall_mean(BarrierLabel.TAKE_PROFIT)
    overall_mean_bars_sl = _overall_mean(BarrierLabel.STOP_LOSS)

    self._time_to_target_result = TimeToTargetResult(
        n_quantiles=cfg.n_quantiles,
        quantile_labels=q_labels,
        mean_bars_tp=mean_bars_tp,
        mean_bars_sl=mean_bars_sl,
        mean_bars_timeout=mean_bars_timeout,
        mean_bars_all=mean_bars_all,
        median_bars_tp=median_bars_tp,
        median_bars_sl=median_bars_sl,
        median_bars_all=median_bars_all,
        std_bars_tp=std_bars_tp,
        std_bars_sl=std_bars_sl,
        std_bars_all=std_bars_all,
        count_tp=count_tp,
        count_sl=count_sl,
        count_timeout=count_timeout,
        overall_mean_bars=overall_mean_bars,
        overall_median_bars=overall_median_bars,
        overall_mean_bars_tp=overall_mean_bars_tp,
        overall_mean_bars_sl=overall_mean_bars_sl,
        n_observations=total_obs,
        tp_faster_than_sl=tp_faster_than_sl,
        speed_advantage_tp=speed_advantage_tp,
    )

    return self._time_to_target_result
|
|
863
|
+
|
|
864
|
+
def _analyze_monotonicity(
    self,
    values: list[float],
) -> tuple[bool, str, float]:
    """Analyze monotonicity of values across quantiles.

    A sequence is reported as monotonic when every consecutive step has the
    same sign or is flat, with at least one non-zero step (non-strict
    monotonicity; a constant sequence is *not* monotonic). A NaN anywhere
    breaks monotonicity because comparisons with NaN are False.

    The Spearman correlation between quantile position and value is computed
    on the finite values only, and is reported as 0.0 when it cannot be
    computed (fewer than two finite values, constant input, or an error).

    Parameters
    ----------
    values : list[float]
        Values for each quantile (ordered by quantile rank).

    Returns
    -------
    tuple[bool, str, float]
        (is_monotonic, direction, spearman_correlation)
        direction is 'increasing', 'decreasing', or 'none'
    """
    import warnings

    if len(values) < 2:
        return False, "none", 0.0

    # Remove any NaN/inf values for correlation
    valid_values = [v for v in values if np.isfinite(v)]
    if len(valid_values) < 2:
        return False, "none", 0.0

    # Spearman correlation with position; only relative order of the
    # positions matters, so consecutive ranks are fine after filtering.
    ranks = list(range(len(valid_values)))
    try:
        with warnings.catch_warnings():
            # Constant input makes SciPy warn and return NaN. NaN is mapped to
            # 0.0 below, so the warning would only be noise for callers.
            warnings.simplefilter("ignore")
            spearman_corr, _ = stats.spearmanr(ranks, valid_values)
    except Exception:
        # Best-effort statistic: never let it abort the surrounding analysis.
        spearman_corr = 0.0

    spearman_corr = float(spearman_corr) if np.isfinite(spearman_corr) else 0.0

    # Non-strict monotonicity: flat steps allowed, but at least one real move.
    diffs = [nxt - cur for cur, nxt in zip(values, values[1:])]
    all_increasing = all(d >= 0 for d in diffs) and any(d > 0 for d in diffs)
    all_decreasing = all(d <= 0 for d in diffs) and any(d < 0 for d in diffs)

    if all_increasing:
        return True, "increasing", spearman_corr
    if all_decreasing:
        return True, "decreasing", spearman_corr
    return False, "none", spearman_corr
|
|
909
|
+
|
|
910
|
+
def create_tear_sheet(
    self,
    include_time_to_target: bool = True,
    include_figures: bool = True,
    theme: str | None = None,
) -> BarrierTearSheet:
    """Assemble a tear sheet bundling every barrier analysis result.

    Parameters
    ----------
    include_time_to_target : bool, default=True
        Attempt the time-to-target analysis. If it raises ``ValueError``
        (e.g. the `label_bars` column is missing from barrier_labels), the
        section is left out rather than failing the whole tear sheet.
    include_figures : bool, default=True
        Build Plotly figures (JSON-serialized) for the report. Pass False
        to skip figure generation, which is faster.
    theme : str | None
        Plot theme: 'default', 'dark', 'print', 'presentation'.
        ``None`` selects the default theme.

    Returns
    -------
    BarrierTearSheet
        Hit rates, profit factor, precision/recall, optional
        time-to-target results, figures, and dataset metadata.

    Examples
    --------
    >>> tear_sheet = analysis.create_tear_sheet()
    >>> tear_sheet.save_html("barrier_analysis.html")
    >>> print(tear_sheet.summary())
    """
    hit_rate = self.compute_hit_rates()
    profit_factor = self.compute_profit_factor()
    precision_recall = self.compute_precision_recall()

    time_to_target: TimeToTargetResult | None = None
    if include_time_to_target:
        try:
            time_to_target = self.compute_time_to_target()
        except ValueError:
            # Raised when the bars column is absent; omit this section.
            time_to_target = None

    figures: dict[str, str] = (
        self._generate_figures(
            hit_rate=hit_rate,
            profit_factor=profit_factor,
            precision_recall=precision_recall,
            time_to_target=time_to_target,
            theme=theme,
        )
        if include_figures
        else {}
    )

    cfg = self.config
    return BarrierTearSheet(
        hit_rate_result=hit_rate,
        profit_factor_result=profit_factor,
        precision_recall_result=precision_recall,
        time_to_target_result=time_to_target,
        signal_name=cfg.signal_name,
        n_assets=self.n_assets,
        n_dates=self.n_dates,
        n_observations=self.n_observations,
        date_range=self.date_range,
        figures=figures,
    )
|
|
979
|
+
|
|
980
|
+
def _generate_figures(
    self,
    hit_rate: HitRateResult,
    profit_factor: ProfitFactorResult,
    precision_recall: PrecisionRecallResult,
    time_to_target: TimeToTargetResult | None,
    theme: str | None = None,
) -> dict[str, str]:
    """Generate Plotly figures for the tear sheet.

    Figures are built independently on a best-effort basis: if a plotting
    function (or its JSON serialization) raises, that figure is omitted
    from the result and a ``RuntimeWarning`` naming it is emitted, instead
    of the failure being silently swallowed.

    Parameters
    ----------
    hit_rate : HitRateResult
        Hit rate analysis results.
    profit_factor : ProfitFactorResult
        Profit factor analysis results.
    precision_recall : PrecisionRecallResult
        Precision/recall analysis results.
    time_to_target : TimeToTargetResult | None
        Time-to-target analysis results (optional; skipped when None).
    theme : str | None
        Plot theme.

    Returns
    -------
    dict[str, str]
        Dict mapping figure names to JSON-serialized Plotly figures.
    """
    import warnings

    import plotly.io as pio

    from ml4t.diagnostic.visualization.barrier_plots import (
        plot_hit_rate_heatmap,
        plot_precision_recall_curve,
        plot_profit_factor_bar,
        plot_time_to_target_box,
    )

    # (figure name, zero-arg builder) pairs, in the order they are rendered.
    builders = [
        ("hit_rate_heatmap", lambda: plot_hit_rate_heatmap(hit_rate, theme=theme)),
        ("profit_factor_bar", lambda: plot_profit_factor_bar(profit_factor, theme=theme)),
        (
            "precision_recall_curve",
            lambda: plot_precision_recall_curve(precision_recall, theme=theme),
        ),
    ]
    if time_to_target is not None:
        builders.append(
            (
                "time_to_target_comparison",
                lambda: plot_time_to_target_box(
                    time_to_target, outcome_type="comparison", theme=theme
                ),
            )
        )

    figures: dict[str, str] = {}
    for name, build in builders:
        try:
            figures[name] = pio.to_json(build())
        except Exception as exc:
            # Best-effort: one broken chart must not sink the tear sheet,
            # but the caller should know the figure is missing.
            warnings.warn(
                f"Skipping barrier tear sheet figure '{name}': "
                f"{type(exc).__name__}: {exc}",
                RuntimeWarning,
                stacklevel=3,
            )

    return figures
|