ml4t-diagnostic 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ml4t/diagnostic/AGENT.md +25 -0
- ml4t/diagnostic/__init__.py +166 -0
- ml4t/diagnostic/backends/__init__.py +10 -0
- ml4t/diagnostic/backends/adapter.py +192 -0
- ml4t/diagnostic/backends/polars_backend.py +899 -0
- ml4t/diagnostic/caching/__init__.py +40 -0
- ml4t/diagnostic/caching/cache.py +331 -0
- ml4t/diagnostic/caching/decorators.py +131 -0
- ml4t/diagnostic/caching/smart_cache.py +339 -0
- ml4t/diagnostic/config/AGENT.md +24 -0
- ml4t/diagnostic/config/README.md +267 -0
- ml4t/diagnostic/config/__init__.py +219 -0
- ml4t/diagnostic/config/barrier_config.py +277 -0
- ml4t/diagnostic/config/base.py +301 -0
- ml4t/diagnostic/config/event_config.py +148 -0
- ml4t/diagnostic/config/feature_config.py +404 -0
- ml4t/diagnostic/config/multi_signal_config.py +55 -0
- ml4t/diagnostic/config/portfolio_config.py +215 -0
- ml4t/diagnostic/config/report_config.py +391 -0
- ml4t/diagnostic/config/sharpe_config.py +202 -0
- ml4t/diagnostic/config/signal_config.py +206 -0
- ml4t/diagnostic/config/trade_analysis_config.py +310 -0
- ml4t/diagnostic/config/validation.py +279 -0
- ml4t/diagnostic/core/__init__.py +29 -0
- ml4t/diagnostic/core/numba_utils.py +315 -0
- ml4t/diagnostic/core/purging.py +372 -0
- ml4t/diagnostic/core/sampling.py +471 -0
- ml4t/diagnostic/errors/__init__.py +205 -0
- ml4t/diagnostic/evaluation/AGENT.md +26 -0
- ml4t/diagnostic/evaluation/__init__.py +437 -0
- ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
- ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
- ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
- ml4t/diagnostic/evaluation/dashboard.py +715 -0
- ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
- ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
- ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
- ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
- ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
- ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
- ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
- ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
- ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
- ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
- ml4t/diagnostic/evaluation/event_analysis.py +647 -0
- ml4t/diagnostic/evaluation/excursion.py +390 -0
- ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
- ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
- ml4t/diagnostic/evaluation/framework.py +935 -0
- ml4t/diagnostic/evaluation/metric_registry.py +255 -0
- ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
- ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
- ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
- ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
- ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
- ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
- ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
- ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
- ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
- ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
- ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
- ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
- ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
- ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
- ml4t/diagnostic/evaluation/multi_signal.py +550 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
- ml4t/diagnostic/evaluation/report_generation.py +824 -0
- ml4t/diagnostic/evaluation/signal_selector.py +452 -0
- ml4t/diagnostic/evaluation/stat_registry.py +139 -0
- ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
- ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
- ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
- ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
- ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
- ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
- ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
- ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
- ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
- ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
- ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
- ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
- ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
- ml4t/diagnostic/evaluation/stats/moments.py +164 -0
- ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
- ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
- ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
- ml4t/diagnostic/evaluation/themes.py +330 -0
- ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
- ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
- ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
- ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
- ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
- ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
- ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
- ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
- ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
- ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
- ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
- ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
- ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
- ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
- ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
- ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
- ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
- ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
- ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
- ml4t/diagnostic/evaluation/validated_cv.py +535 -0
- ml4t/diagnostic/evaluation/visualization.py +1050 -0
- ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
- ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
- ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
- ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
- ml4t/diagnostic/integration/__init__.py +48 -0
- ml4t/diagnostic/integration/backtest_contract.py +671 -0
- ml4t/diagnostic/integration/data_contract.py +316 -0
- ml4t/diagnostic/integration/engineer_contract.py +226 -0
- ml4t/diagnostic/logging/__init__.py +77 -0
- ml4t/diagnostic/logging/logger.py +245 -0
- ml4t/diagnostic/logging/performance.py +234 -0
- ml4t/diagnostic/logging/progress.py +234 -0
- ml4t/diagnostic/logging/wandb.py +412 -0
- ml4t/diagnostic/metrics/__init__.py +9 -0
- ml4t/diagnostic/metrics/percentiles.py +128 -0
- ml4t/diagnostic/py.typed +1 -0
- ml4t/diagnostic/reporting/__init__.py +43 -0
- ml4t/diagnostic/reporting/base.py +130 -0
- ml4t/diagnostic/reporting/html_renderer.py +275 -0
- ml4t/diagnostic/reporting/json_renderer.py +51 -0
- ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
- ml4t/diagnostic/results/AGENT.md +24 -0
- ml4t/diagnostic/results/__init__.py +105 -0
- ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
- ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
- ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
- ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
- ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
- ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
- ml4t/diagnostic/results/barrier_results/validation.py +38 -0
- ml4t/diagnostic/results/base.py +177 -0
- ml4t/diagnostic/results/event_results.py +349 -0
- ml4t/diagnostic/results/feature_results.py +787 -0
- ml4t/diagnostic/results/multi_signal_results.py +431 -0
- ml4t/diagnostic/results/portfolio_results.py +281 -0
- ml4t/diagnostic/results/sharpe_results.py +448 -0
- ml4t/diagnostic/results/signal_results/__init__.py +74 -0
- ml4t/diagnostic/results/signal_results/ic.py +581 -0
- ml4t/diagnostic/results/signal_results/irtc.py +110 -0
- ml4t/diagnostic/results/signal_results/quantile.py +392 -0
- ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
- ml4t/diagnostic/results/signal_results/turnover.py +213 -0
- ml4t/diagnostic/results/signal_results/validation.py +147 -0
- ml4t/diagnostic/signal/AGENT.md +17 -0
- ml4t/diagnostic/signal/__init__.py +69 -0
- ml4t/diagnostic/signal/_report.py +152 -0
- ml4t/diagnostic/signal/_utils.py +261 -0
- ml4t/diagnostic/signal/core.py +275 -0
- ml4t/diagnostic/signal/quantile.py +148 -0
- ml4t/diagnostic/signal/result.py +214 -0
- ml4t/diagnostic/signal/signal_ic.py +129 -0
- ml4t/diagnostic/signal/turnover.py +182 -0
- ml4t/diagnostic/splitters/AGENT.md +19 -0
- ml4t/diagnostic/splitters/__init__.py +36 -0
- ml4t/diagnostic/splitters/base.py +501 -0
- ml4t/diagnostic/splitters/calendar.py +421 -0
- ml4t/diagnostic/splitters/calendar_config.py +91 -0
- ml4t/diagnostic/splitters/combinatorial.py +1064 -0
- ml4t/diagnostic/splitters/config.py +322 -0
- ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
- ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
- ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
- ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
- ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
- ml4t/diagnostic/splitters/group_isolation.py +329 -0
- ml4t/diagnostic/splitters/persistence.py +316 -0
- ml4t/diagnostic/splitters/utils.py +207 -0
- ml4t/diagnostic/splitters/walk_forward.py +757 -0
- ml4t/diagnostic/utils/__init__.py +42 -0
- ml4t/diagnostic/utils/config.py +542 -0
- ml4t/diagnostic/utils/dependencies.py +318 -0
- ml4t/diagnostic/utils/sessions.py +127 -0
- ml4t/diagnostic/validation/__init__.py +54 -0
- ml4t/diagnostic/validation/dataframe.py +274 -0
- ml4t/diagnostic/validation/returns.py +280 -0
- ml4t/diagnostic/validation/timeseries.py +299 -0
- ml4t/diagnostic/visualization/AGENT.md +19 -0
- ml4t/diagnostic/visualization/__init__.py +223 -0
- ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
- ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
- ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
- ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
- ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
- ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
- ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
- ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
- ml4t/diagnostic/visualization/barrier_plots.py +782 -0
- ml4t/diagnostic/visualization/core.py +1060 -0
- ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
- ml4t/diagnostic/visualization/dashboards/base.py +582 -0
- ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
- ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
- ml4t/diagnostic/visualization/dashboards.py +43 -0
- ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
- ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
- ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
- ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
- ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
- ml4t/diagnostic/visualization/feature_plots.py +888 -0
- ml4t/diagnostic/visualization/interaction_plots.py +618 -0
- ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
- ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
- ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
- ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
- ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
- ml4t/diagnostic/visualization/report_generation.py +1343 -0
- ml4t/diagnostic/visualization/signal/__init__.py +103 -0
- ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
- ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
- ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
- ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
- ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
- ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
- ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
- ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
- ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
- ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
- ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
- ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
"""Result classes for Barrier Analysis module.
|
|
2
|
+
|
|
3
|
+
This package provides Pydantic result classes for storing and serializing
|
|
4
|
+
barrier analysis outputs including hit rates, profit factors, precision/recall,
|
|
5
|
+
and time-to-target metrics.
|
|
6
|
+
|
|
7
|
+
Triple barrier outcomes from ml4t.features:
|
|
8
|
+
- label: int (-1=SL hit, 0=timeout, 1=TP hit)
|
|
9
|
+
- label_return: float (actual return at exit)
|
|
10
|
+
- label_bars: int (bars from entry to exit)
|
|
11
|
+
|
|
12
|
+
References
|
|
13
|
+
----------
|
|
14
|
+
Lopez de Prado, M. (2018). "Advances in Financial Machine Learning"
|
|
15
|
+
Chapter 3: Labeling (Triple Barrier Method)
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
from __future__ import annotations
|
|
19
|
+
|
|
20
|
+
from ml4t.diagnostic.results.barrier_results.hit_rate import HitRateResult
|
|
21
|
+
from ml4t.diagnostic.results.barrier_results.precision_recall import PrecisionRecallResult
|
|
22
|
+
from ml4t.diagnostic.results.barrier_results.profit_factor import ProfitFactorResult
|
|
23
|
+
from ml4t.diagnostic.results.barrier_results.tearsheet import BarrierTearSheet
|
|
24
|
+
from ml4t.diagnostic.results.barrier_results.time_to_target import TimeToTargetResult
|
|
25
|
+
from ml4t.diagnostic.results.barrier_results.validation import _validate_quantile_dict_keys
|
|
26
|
+
|
|
27
|
+
__all__ = [
|
|
28
|
+
# Validation helper
|
|
29
|
+
"_validate_quantile_dict_keys",
|
|
30
|
+
# Result classes
|
|
31
|
+
"HitRateResult",
|
|
32
|
+
"ProfitFactorResult",
|
|
33
|
+
"PrecisionRecallResult",
|
|
34
|
+
"TimeToTargetResult",
|
|
35
|
+
"BarrierTearSheet",
|
|
36
|
+
]
|
|
@@ -0,0 +1,304 @@
|
|
|
1
|
+
"""Hit rate analysis results for barrier outcomes.
|
|
2
|
+
|
|
3
|
+
This module provides the HitRateResult class for storing hit rate metrics
|
|
4
|
+
(TP, SL, timeout) by signal quantile, including chi-square independence tests.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import polars as pl
|
|
10
|
+
from pydantic import Field, model_validator
|
|
11
|
+
|
|
12
|
+
from ml4t.diagnostic.results.barrier_results.validation import _validate_quantile_dict_keys
|
|
13
|
+
from ml4t.diagnostic.results.base import BaseResult
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class HitRateResult(BaseResult):
    """Barrier hit rates broken down by signal quantile.

    For every signal quantile the model stores the share (and raw count) of
    observations that ended in a take-profit (TP), stop-loss (SL) or timeout
    outcome. It also stores a chi-square independence test between quantile
    and outcome, overall hit rates, and a monotonicity check on the TP rate.

    All per-quantile dictionaries are keyed by the entries of
    ``quantile_labels``; this is enforced when the model is validated.

    Examples
    --------
    >>> result = hit_rate_result
    >>> print(result.summary())
    >>> df = result.get_dataframe("counts")
    """

    analysis_type: str = Field(default="barrier_hit_rate", frozen=True)

    # --- Configuration ------------------------------------------------------

    n_quantiles: int = Field(..., description="Number of quantiles used")
    quantile_labels: list[str] = Field(
        ...,
        description="Labels for each quantile (e.g., ['D1', 'D2', ..., 'D10'])",
    )

    # --- Hit rates by quantile ----------------------------------------------

    hit_rate_tp: dict[str, float] = Field(
        ...,
        description="Take-profit hit rate per quantile: {quantile: rate}",
    )
    hit_rate_sl: dict[str, float] = Field(
        ...,
        description="Stop-loss hit rate per quantile: {quantile: rate}",
    )
    hit_rate_timeout: dict[str, float] = Field(
        ...,
        description="Timeout hit rate per quantile: {quantile: rate}",
    )

    # --- Counts -------------------------------------------------------------

    count_tp: dict[str, int] = Field(..., description="Take-profit count per quantile")
    count_sl: dict[str, int] = Field(..., description="Stop-loss count per quantile")
    count_timeout: dict[str, int] = Field(..., description="Timeout count per quantile")
    count_total: dict[str, int] = Field(..., description="Total count per quantile")

    # --- Statistical test (chi-square independence) -------------------------

    chi2_statistic: float = Field(
        ...,
        description="Chi-square statistic for independence test",
    )
    chi2_p_value: float = Field(..., description="P-value for chi-square test")
    chi2_dof: int = Field(..., description="Degrees of freedom for chi-square test")
    is_significant: bool = Field(
        ...,
        description="Whether signal quantile significantly affects outcome (p < alpha)",
    )
    significance_level: float = Field(..., description="Significance level used for test")

    # --- Aggregates ---------------------------------------------------------

    overall_hit_rate_tp: float = Field(
        ...,
        description="Overall take-profit hit rate across all observations",
    )
    overall_hit_rate_sl: float = Field(
        ...,
        description="Overall stop-loss hit rate across all observations",
    )
    overall_hit_rate_timeout: float = Field(
        ...,
        description="Overall timeout hit rate across all observations",
    )
    n_observations: int = Field(..., description="Total number of observations analyzed")

    # --- Monotonicity -------------------------------------------------------

    tp_rate_monotonic: bool = Field(
        ...,
        description="Whether TP hit rate is monotonic across quantiles",
    )
    tp_rate_direction: str = Field(
        ...,
        description="Direction of TP rate change: 'increasing', 'decreasing', or 'none'",
    )
    tp_rate_spearman: float = Field(
        ...,
        description="Spearman correlation between quantile rank and TP hit rate",
    )

    # --- Validation ---------------------------------------------------------

    @model_validator(mode="after")
    def _validate_quantile_keys(self) -> HitRateResult:
        """Check that the quantile configuration and keyed dicts agree.

        Raises
        ------
        ValueError
            If ``n_quantiles`` differs from the number of quantile labels.
            The per-dict key check is delegated to
            ``_validate_quantile_dict_keys``.
        """
        n_labels = len(self.quantile_labels)
        if self.n_quantiles != n_labels:
            raise ValueError(
                f"n_quantiles ({self.n_quantiles}) != len(quantile_labels) ({n_labels})"
            )

        keyed_fields = (
            "hit_rate_tp",
            "hit_rate_sl",
            "hit_rate_timeout",
            "count_tp",
            "count_sl",
            "count_timeout",
            "count_total",
        )
        _validate_quantile_dict_keys(
            self.quantile_labels,
            [(field_name, getattr(self, field_name)) for field_name in keyed_fields],
        )
        return self

    # --- Methods ------------------------------------------------------------

    def get_dataframe(self, name: str | None = None) -> pl.DataFrame:
        """Return one of the result views as a Polars DataFrame.

        Parameters
        ----------
        name : str | None
            View to build:
            - None or "hit_rates": hit rates (plus total count) per quantile
            - "counts": raw outcome counts per quantile
            - "summary": long-format table of scalar statistics

        Returns
        -------
        pl.DataFrame
            The requested view.

        Raises
        ------
        ValueError
            If ``name`` is not one of the available views.
        """
        labels = self.quantile_labels

        def per_quantile(*field_names):
            # One row per quantile label; columns follow the order given.
            columns = {"quantile": labels}
            for field_name in field_names:
                mapping = getattr(self, field_name)
                columns[field_name] = [mapping[q] for q in labels]
            return pl.DataFrame(columns)

        if name is None or name == "hit_rates":
            return per_quantile("hit_rate_tp", "hit_rate_sl", "hit_rate_timeout", "count_total")

        if name == "counts":
            return per_quantile("count_tp", "count_sl", "count_timeout", "count_total")

        if name == "summary":
            # Integers and booleans are cast to float so the value column is homogeneous.
            rows = [
                ("n_observations", float(self.n_observations)),
                ("n_quantiles", float(self.n_quantiles)),
                ("overall_hit_rate_tp", self.overall_hit_rate_tp),
                ("overall_hit_rate_sl", self.overall_hit_rate_sl),
                ("overall_hit_rate_timeout", self.overall_hit_rate_timeout),
                ("chi2_statistic", self.chi2_statistic),
                ("chi2_p_value", self.chi2_p_value),
                ("is_significant", float(self.is_significant)),
                ("tp_rate_monotonic", float(self.tp_rate_monotonic)),
                ("tp_rate_spearman", self.tp_rate_spearman),
            ]
            return pl.DataFrame(
                {
                    "metric": [metric for metric, _ in rows],
                    "value": [value for _, value in rows],
                }
            )

        raise ValueError(
            f"Unknown DataFrame name: {name}. Available: 'hit_rates', 'counts', 'summary'"
        )

    def list_available_dataframes(self) -> list[str]:
        """Return the view names accepted by ``get_dataframe``."""
        available = ["hit_rates", "counts", "summary"]
        return available

    def summary(self) -> str:
        """Render the hit rate results as a plain-text report.

        Returns
        -------
        str
            Newline-joined report: overall statistics followed by one table
            row per quantile label.
        """
        heavy_rule = "=" * 60
        light_rule = "-" * 60
        significant = "Yes" if self.is_significant else "No"
        monotonic = "Yes" if self.tp_rate_monotonic else "No"

        report = [
            heavy_rule,
            "Barrier Hit Rate Analysis",
            heavy_rule,
            "",
            f"Observations: {self.n_observations:>10,}",
            f"Quantiles: {self.n_quantiles:>10}",
            "",
            "Overall Hit Rates:",
            f" Take-Profit: {self.overall_hit_rate_tp:>10.1%}",
            f" Stop-Loss: {self.overall_hit_rate_sl:>10.1%}",
            f" Timeout: {self.overall_hit_rate_timeout:>10.1%}",
            "",
            "Chi-Square Test (Signal Decile vs Outcome):",
            f" Chi2 Statistic: {self.chi2_statistic:>10.2f}",
            f" P-value: {self.chi2_p_value:>10.4f}",
            f" DoF: {self.chi2_dof:>10}",
            f" Significant: {significant:>10} (alpha={self.significance_level})",
            "",
            "Monotonicity (TP Rate vs Signal Strength):",
            f" Monotonic: {monotonic:>10}",
            f" Direction: {self.tp_rate_direction:>10}",
            f" Spearman rho: {self.tp_rate_spearman:>10.4f}",
            "",
            light_rule,
            "Hit Rates by Quantile:",
            light_rule,
            f"{'Quantile':<10} {'TP Rate':>10} {'SL Rate':>10} {'Timeout':>10} {'Count':>8}",
        ]

        # One table row per quantile, in label order.
        report.extend(
            f"{q:<10} {self.hit_rate_tp[q]:>10.1%} {self.hit_rate_sl[q]:>10.1%} "
            f"{self.hit_rate_timeout[q]:>10.1%} {self.count_total[q]:>8,}"
            for q in self.quantile_labels
        )

        return "\n".join(report)
|
|
@@ -0,0 +1,266 @@
|
|
|
1
|
+
"""Precision/recall analysis results for barrier outcomes.
|
|
2
|
+
|
|
3
|
+
This module provides the PrecisionRecallResult class for storing precision,
|
|
4
|
+
recall, F1 scores, and lift metrics for barrier outcomes by signal quantile.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import polars as pl
|
|
10
|
+
from pydantic import Field, model_validator
|
|
11
|
+
|
|
12
|
+
from ml4t.diagnostic.results.barrier_results.validation import _validate_quantile_dict_keys
|
|
13
|
+
from ml4t.diagnostic.results.base import BaseResult
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class PrecisionRecallResult(BaseResult):
|
|
17
|
+
"""Results from precision/recall analysis for barrier outcomes.
|
|
18
|
+
|
|
19
|
+
Precision: Of signals in top quantile, what fraction hit TP?
|
|
20
|
+
Recall: Of all TP outcomes, what fraction came from top quantile?
|
|
21
|
+
|
|
22
|
+
This helps understand signal selectivity vs coverage trade-offs.
|
|
23
|
+
|
|
24
|
+
Examples
|
|
25
|
+
--------
|
|
26
|
+
>>> result = precision_recall_result
|
|
27
|
+
>>> print(result.summary())
|
|
28
|
+
>>> df = result.get_dataframe()
|
|
29
|
+
"""
|
|
30
|
+
|
|
31
|
+
analysis_type: str = Field(default="barrier_precision_recall", frozen=True)
|
|
32
|
+
|
|
33
|
+
# ==========================================================================
|
|
34
|
+
# Configuration
|
|
35
|
+
# ==========================================================================
|
|
36
|
+
|
|
37
|
+
n_quantiles: int = Field(
|
|
38
|
+
...,
|
|
39
|
+
description="Number of quantiles used",
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
quantile_labels: list[str] = Field(
|
|
43
|
+
...,
|
|
44
|
+
description="Labels for each quantile (e.g., ['D1', 'D2', ..., 'D10'])",
|
|
45
|
+
)
|
|
46
|
+
|
|
47
|
+
# ==========================================================================
|
|
48
|
+
# Precision by Quantile (TP-focused)
|
|
49
|
+
# ==========================================================================
|
|
50
|
+
|
|
51
|
+
precision_tp: dict[str, float] = Field(
|
|
52
|
+
...,
|
|
53
|
+
description="Precision for TP: P(TP | in quantile) = TP count / total in quantile",
|
|
54
|
+
)
|
|
55
|
+
|
|
56
|
+
# ==========================================================================
|
|
57
|
+
# Recall by Quantile (TP-focused)
|
|
58
|
+
# ==========================================================================
|
|
59
|
+
|
|
60
|
+
recall_tp: dict[str, float] = Field(
|
|
61
|
+
...,
|
|
62
|
+
description="Recall for TP: P(in quantile | TP) = TP in quantile / all TP",
|
|
63
|
+
)
|
|
64
|
+
|
|
65
|
+
# ==========================================================================
|
|
66
|
+
# Cumulative Metrics (from top quantile down)
|
|
67
|
+
# ==========================================================================
|
|
68
|
+
|
|
69
|
+
cumulative_precision_tp: dict[str, float] = Field(
|
|
70
|
+
...,
|
|
71
|
+
description="Cumulative precision: P(TP | in top k quantiles)",
|
|
72
|
+
)
|
|
73
|
+
|
|
74
|
+
cumulative_recall_tp: dict[str, float] = Field(
|
|
75
|
+
...,
|
|
76
|
+
description="Cumulative recall: P(in top k quantiles | TP)",
|
|
77
|
+
)
|
|
78
|
+
|
|
79
|
+
cumulative_f1_tp: dict[str, float] = Field(
|
|
80
|
+
...,
|
|
81
|
+
description="Cumulative F1 score: 2 * (precision * recall) / (precision + recall)",
|
|
82
|
+
)
|
|
83
|
+
|
|
84
|
+
# ==========================================================================
|
|
85
|
+
# Lift Metrics
|
|
86
|
+
# ==========================================================================
|
|
87
|
+
|
|
88
|
+
lift_tp: dict[str, float] = Field(
|
|
89
|
+
...,
|
|
90
|
+
description="Lift for TP: precision / baseline TP rate",
|
|
91
|
+
)
|
|
92
|
+
|
|
93
|
+
cumulative_lift_tp: dict[str, float] = Field(
|
|
94
|
+
...,
|
|
95
|
+
description="Cumulative lift for TP",
|
|
96
|
+
)
|
|
97
|
+
|
|
98
|
+
# ==========================================================================
|
|
99
|
+
# Baseline
|
|
100
|
+
# ==========================================================================
|
|
101
|
+
|
|
102
|
+
baseline_tp_rate: float = Field(
|
|
103
|
+
...,
|
|
104
|
+
description="Baseline TP rate (overall TP count / total)",
|
|
105
|
+
)
|
|
106
|
+
|
|
107
|
+
total_tp_count: int = Field(
|
|
108
|
+
...,
|
|
109
|
+
description="Total number of TP outcomes",
|
|
110
|
+
)
|
|
111
|
+
|
|
112
|
+
n_observations: int = Field(
|
|
113
|
+
...,
|
|
114
|
+
description="Total number of observations",
|
|
115
|
+
)
|
|
116
|
+
|
|
117
|
+
# ==========================================================================
|
|
118
|
+
# Best Operating Point
|
|
119
|
+
# ==========================================================================
|
|
120
|
+
|
|
121
|
+
best_f1_quantile: str = Field(
|
|
122
|
+
...,
|
|
123
|
+
description="Quantile with best cumulative F1 score",
|
|
124
|
+
)
|
|
125
|
+
|
|
126
|
+
best_f1_score: float = Field(
|
|
127
|
+
...,
|
|
128
|
+
description="Best cumulative F1 score achieved",
|
|
129
|
+
)
|
|
130
|
+
|
|
131
|
+
# ==========================================================================
|
|
132
|
+
# Validation
|
|
133
|
+
# ==========================================================================
|
|
134
|
+
|
|
135
|
+
@model_validator(mode="after")
def _validate_quantile_keys(self) -> PrecisionRecallResult:
    """Check that the quantile count and the per-quantile dicts match the labels.

    The length check is done here; the per-dict key check is delegated to
    ``_validate_quantile_dict_keys``, which is defined outside this block.

    Returns
    -------
    PrecisionRecallResult
        The same instance, once validation has passed.

    Raises
    ------
    ValueError
        If ``n_quantiles`` differs from the number of ``quantile_labels``.
    """
    if len(self.quantile_labels) != self.n_quantiles:
        raise ValueError(
            f"n_quantiles ({self.n_quantiles}) != len(quantile_labels) ({len(self.quantile_labels)})"
        )

    # Every dict that is expected to carry one entry per quantile label.
    keyed_field_names = (
        "precision_tp",
        "recall_tp",
        "cumulative_precision_tp",
        "cumulative_recall_tp",
        "cumulative_f1_tp",
        "lift_tp",
        "cumulative_lift_tp",
    )
    named_dicts = [(field_name, getattr(self, field_name)) for field_name in keyed_field_names]
    _validate_quantile_dict_keys(self.quantile_labels, named_dicts)
    return self
|
|
155
|
+
|
|
156
|
+
def get_dataframe(self, name: str | None = None) -> pl.DataFrame:
    """Get results as Polars DataFrame.

    Parameters
    ----------
    name : str | None
        DataFrame to retrieve:
        - None or "precision_recall": Per-quantile metrics
        - "cumulative": Cumulative metrics from top down
        - "summary": Summary statistics

    Returns
    -------
    pl.DataFrame
        Requested DataFrame. The per-quantile and cumulative views have one
        row per entry of ``quantile_labels``, in that order. The summary
        view has a ``metric`` column and a text ``value`` column.

    Raises
    ------
    ValueError
        If ``name`` is not one of the supported view names.
    """
    if name is None or name == "precision_recall":
        labels = self.quantile_labels
        return pl.DataFrame(
            {
                "quantile": labels,
                "precision_tp": [self.precision_tp[q] for q in labels],
                "recall_tp": [self.recall_tp[q] for q in labels],
                "lift_tp": [self.lift_tp[q] for q in labels],
            }
        )

    if name == "cumulative":
        labels = self.quantile_labels
        return pl.DataFrame(
            {
                "quantile": labels,
                "cumulative_precision_tp": [self.cumulative_precision_tp[q] for q in labels],
                "cumulative_recall_tp": [self.cumulative_recall_tp[q] for q in labels],
                "cumulative_f1_tp": [self.cumulative_f1_tp[q] for q in labels],
                "cumulative_lift_tp": [self.cumulative_lift_tp[q] for q in labels],
            }
        )

    if name == "summary":
        summary_values = {
            "n_observations": self.n_observations,
            "n_quantiles": self.n_quantiles,
            "total_tp_count": self.total_tp_count,
            "baseline_tp_rate": self.baseline_tp_rate,
            "best_f1_quantile": self.best_f1_quantile,
            "best_f1_score": self.best_f1_score,
        }
        # best_f1_quantile is a label (str) while the other entries are
        # numeric. A Polars column holds a single dtype, so mixing floats and
        # a string in one list is rejected; render every value as text.
        return pl.DataFrame(
            {
                "metric": list(summary_values),
                "value": [str(value) for value in summary_values.values()],
            }
        )

    raise ValueError(
        f"Unknown DataFrame name: {name}. Available: 'precision_recall', 'cumulative', 'summary'"
    )
|
|
224
|
+
|
|
225
|
+
def list_available_dataframes(self) -> list[str]:
    """Return the view names accepted by ``get_dataframe``."""
    view_names = ("precision_recall", "cumulative", "summary")
    # Hand back a fresh list on every call so callers may mutate it freely.
    return list(view_names)
|
|
228
|
+
|
|
229
|
+
def summary(self) -> str:
    """Render the precision/recall results as a plain-text report.

    The report has three parts: a header with the overall counts and the
    best F1 operating point, a per-quantile table (precision, recall, lift)
    and a cumulative table (precision, recall, F1). Table rows follow the
    order of ``quantile_labels``.

    Returns
    -------
    str
        The report lines joined with newlines.
    """
    heavy_rule = "=" * 60
    light_rule = "-" * 60

    header = [
        heavy_rule,
        "Barrier Precision/Recall Analysis (TP-focused)",
        heavy_rule,
        "",
        f"Observations: {self.n_observations:>10,}",
        f"Total TP Count: {self.total_tp_count:>10,}",
        f"Baseline TP Rate: {self.baseline_tp_rate:>10.1%}",
        "",
        f"Best F1 Score: {self.best_f1_score:>10.4f} (at {self.best_f1_quantile})",
        "",
    ]

    per_quantile_rows = [
        f"{q:<10} {self.precision_tp[q]:>10.1%} {self.recall_tp[q]:>10.1%} {self.lift_tp[q]:>8.2f}x"
        for q in self.quantile_labels
    ]
    per_quantile_table = [
        light_rule,
        "Per-Quantile Metrics:",
        light_rule,
        f"{'Quantile':<10} {'Precision':>10} {'Recall':>10} {'Lift':>8}",
        *per_quantile_rows,
    ]

    cumulative_rows = [
        f"{q:<10} {self.cumulative_precision_tp[q]:>10.1%} {self.cumulative_recall_tp[q]:>10.1%} {self.cumulative_f1_tp[q]:>10.4f}"
        for q in self.quantile_labels
    ]
    cumulative_table = [
        "",
        light_rule,
        "Cumulative Metrics (from top quantile):",
        light_rule,
        f"{'Quantile':<10} {'Cum Prec':>10} {'Cum Recall':>10} {'Cum F1':>10}",
        *cumulative_rows,
    ]

    return "\n".join(header + per_quantile_table + cumulative_table)
|