ml4t-diagnostic 0.1.0a1__py3-none-any.whl
This diff shows the content of a publicly available package version as released to one of the supported registries; it is provided for informational purposes only.
- ml4t/diagnostic/AGENT.md +25 -0
- ml4t/diagnostic/__init__.py +166 -0
- ml4t/diagnostic/backends/__init__.py +10 -0
- ml4t/diagnostic/backends/adapter.py +192 -0
- ml4t/diagnostic/backends/polars_backend.py +899 -0
- ml4t/diagnostic/caching/__init__.py +40 -0
- ml4t/diagnostic/caching/cache.py +331 -0
- ml4t/diagnostic/caching/decorators.py +131 -0
- ml4t/diagnostic/caching/smart_cache.py +339 -0
- ml4t/diagnostic/config/AGENT.md +24 -0
- ml4t/diagnostic/config/README.md +267 -0
- ml4t/diagnostic/config/__init__.py +219 -0
- ml4t/diagnostic/config/barrier_config.py +277 -0
- ml4t/diagnostic/config/base.py +301 -0
- ml4t/diagnostic/config/event_config.py +148 -0
- ml4t/diagnostic/config/feature_config.py +404 -0
- ml4t/diagnostic/config/multi_signal_config.py +55 -0
- ml4t/diagnostic/config/portfolio_config.py +215 -0
- ml4t/diagnostic/config/report_config.py +391 -0
- ml4t/diagnostic/config/sharpe_config.py +202 -0
- ml4t/diagnostic/config/signal_config.py +206 -0
- ml4t/diagnostic/config/trade_analysis_config.py +310 -0
- ml4t/diagnostic/config/validation.py +279 -0
- ml4t/diagnostic/core/__init__.py +29 -0
- ml4t/diagnostic/core/numba_utils.py +315 -0
- ml4t/diagnostic/core/purging.py +372 -0
- ml4t/diagnostic/core/sampling.py +471 -0
- ml4t/diagnostic/errors/__init__.py +205 -0
- ml4t/diagnostic/evaluation/AGENT.md +26 -0
- ml4t/diagnostic/evaluation/__init__.py +437 -0
- ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
- ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
- ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
- ml4t/diagnostic/evaluation/dashboard.py +715 -0
- ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
- ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
- ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
- ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
- ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
- ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
- ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
- ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
- ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
- ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
- ml4t/diagnostic/evaluation/event_analysis.py +647 -0
- ml4t/diagnostic/evaluation/excursion.py +390 -0
- ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
- ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
- ml4t/diagnostic/evaluation/framework.py +935 -0
- ml4t/diagnostic/evaluation/metric_registry.py +255 -0
- ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
- ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
- ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
- ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
- ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
- ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
- ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
- ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
- ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
- ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
- ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
- ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
- ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
- ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
- ml4t/diagnostic/evaluation/multi_signal.py +550 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
- ml4t/diagnostic/evaluation/report_generation.py +824 -0
- ml4t/diagnostic/evaluation/signal_selector.py +452 -0
- ml4t/diagnostic/evaluation/stat_registry.py +139 -0
- ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
- ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
- ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
- ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
- ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
- ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
- ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
- ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
- ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
- ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
- ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
- ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
- ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
- ml4t/diagnostic/evaluation/stats/moments.py +164 -0
- ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
- ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
- ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
- ml4t/diagnostic/evaluation/themes.py +330 -0
- ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
- ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
- ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
- ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
- ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
- ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
- ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
- ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
- ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
- ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
- ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
- ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
- ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
- ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
- ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
- ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
- ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
- ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
- ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
- ml4t/diagnostic/evaluation/validated_cv.py +535 -0
- ml4t/diagnostic/evaluation/visualization.py +1050 -0
- ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
- ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
- ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
- ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
- ml4t/diagnostic/integration/__init__.py +48 -0
- ml4t/diagnostic/integration/backtest_contract.py +671 -0
- ml4t/diagnostic/integration/data_contract.py +316 -0
- ml4t/diagnostic/integration/engineer_contract.py +226 -0
- ml4t/diagnostic/logging/__init__.py +77 -0
- ml4t/diagnostic/logging/logger.py +245 -0
- ml4t/diagnostic/logging/performance.py +234 -0
- ml4t/diagnostic/logging/progress.py +234 -0
- ml4t/diagnostic/logging/wandb.py +412 -0
- ml4t/diagnostic/metrics/__init__.py +9 -0
- ml4t/diagnostic/metrics/percentiles.py +128 -0
- ml4t/diagnostic/py.typed +1 -0
- ml4t/diagnostic/reporting/__init__.py +43 -0
- ml4t/diagnostic/reporting/base.py +130 -0
- ml4t/diagnostic/reporting/html_renderer.py +275 -0
- ml4t/diagnostic/reporting/json_renderer.py +51 -0
- ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
- ml4t/diagnostic/results/AGENT.md +24 -0
- ml4t/diagnostic/results/__init__.py +105 -0
- ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
- ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
- ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
- ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
- ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
- ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
- ml4t/diagnostic/results/barrier_results/validation.py +38 -0
- ml4t/diagnostic/results/base.py +177 -0
- ml4t/diagnostic/results/event_results.py +349 -0
- ml4t/diagnostic/results/feature_results.py +787 -0
- ml4t/diagnostic/results/multi_signal_results.py +431 -0
- ml4t/diagnostic/results/portfolio_results.py +281 -0
- ml4t/diagnostic/results/sharpe_results.py +448 -0
- ml4t/diagnostic/results/signal_results/__init__.py +74 -0
- ml4t/diagnostic/results/signal_results/ic.py +581 -0
- ml4t/diagnostic/results/signal_results/irtc.py +110 -0
- ml4t/diagnostic/results/signal_results/quantile.py +392 -0
- ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
- ml4t/diagnostic/results/signal_results/turnover.py +213 -0
- ml4t/diagnostic/results/signal_results/validation.py +147 -0
- ml4t/diagnostic/signal/AGENT.md +17 -0
- ml4t/diagnostic/signal/__init__.py +69 -0
- ml4t/diagnostic/signal/_report.py +152 -0
- ml4t/diagnostic/signal/_utils.py +261 -0
- ml4t/diagnostic/signal/core.py +275 -0
- ml4t/diagnostic/signal/quantile.py +148 -0
- ml4t/diagnostic/signal/result.py +214 -0
- ml4t/diagnostic/signal/signal_ic.py +129 -0
- ml4t/diagnostic/signal/turnover.py +182 -0
- ml4t/diagnostic/splitters/AGENT.md +19 -0
- ml4t/diagnostic/splitters/__init__.py +36 -0
- ml4t/diagnostic/splitters/base.py +501 -0
- ml4t/diagnostic/splitters/calendar.py +421 -0
- ml4t/diagnostic/splitters/calendar_config.py +91 -0
- ml4t/diagnostic/splitters/combinatorial.py +1064 -0
- ml4t/diagnostic/splitters/config.py +322 -0
- ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
- ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
- ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
- ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
- ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
- ml4t/diagnostic/splitters/group_isolation.py +329 -0
- ml4t/diagnostic/splitters/persistence.py +316 -0
- ml4t/diagnostic/splitters/utils.py +207 -0
- ml4t/diagnostic/splitters/walk_forward.py +757 -0
- ml4t/diagnostic/utils/__init__.py +42 -0
- ml4t/diagnostic/utils/config.py +542 -0
- ml4t/diagnostic/utils/dependencies.py +318 -0
- ml4t/diagnostic/utils/sessions.py +127 -0
- ml4t/diagnostic/validation/__init__.py +54 -0
- ml4t/diagnostic/validation/dataframe.py +274 -0
- ml4t/diagnostic/validation/returns.py +280 -0
- ml4t/diagnostic/validation/timeseries.py +299 -0
- ml4t/diagnostic/visualization/AGENT.md +19 -0
- ml4t/diagnostic/visualization/__init__.py +223 -0
- ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
- ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
- ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
- ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
- ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
- ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
- ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
- ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
- ml4t/diagnostic/visualization/barrier_plots.py +782 -0
- ml4t/diagnostic/visualization/core.py +1060 -0
- ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
- ml4t/diagnostic/visualization/dashboards/base.py +582 -0
- ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
- ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
- ml4t/diagnostic/visualization/dashboards.py +43 -0
- ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
- ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
- ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
- ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
- ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
- ml4t/diagnostic/visualization/feature_plots.py +888 -0
- ml4t/diagnostic/visualization/interaction_plots.py +618 -0
- ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
- ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
- ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
- ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
- ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
- ml4t/diagnostic/visualization/report_generation.py +1343 -0
- ml4t/diagnostic/visualization/signal/__init__.py +103 -0
- ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
- ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
- ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
- ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
- ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
- ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
- ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
- ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
- ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
- ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
- ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
- ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,910 @@
"""Binary Classification Metrics for Trading Signal Evaluation.

This module provides precision, recall, lift, and coverage metrics for evaluating
binary trading signals against labeled outcomes. Designed to complement the
existing Signal Analysis and Feature Diagnostics capabilities.

Key Features:
- Polars-native implementation (fast, memory-efficient)
- Statistical significance testing (binomial test, proportions z-test)
- Confidence intervals via Wilson score
- Sparse signal support (handles low coverage gracefully)
- Comprehensive report generation

Usage Example:
    >>> import polars as pl
    >>> from ml4t.diagnostic.evaluation.binary_metrics import (
    ...     precision, recall, lift, coverage, binary_classification_report
    ... )
    >>>
    >>> # Example data
    >>> signals = pl.Series([1, 0, 1, 1, 0, 1, 0, 0, 1, 0])
    >>> labels = pl.Series([1, 0, 1, 0, 0, 1, 0, 1, 1, 0])
    >>>
    >>> # Compute metrics
    >>> prec = precision(signals, labels)
    >>> rec = recall(signals, labels)
    >>> print(f"Precision: {prec:.3f}, Recall: {rec:.3f}")

References:
    Wilson, E.B. (1927). "Probable inference, the law of succession,
    and statistical inference". Journal of the American Statistical
    Association.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Literal

import numpy as np
import polars as pl
from scipy import stats

# ============================================================================
# Core Metrics
# ============================================================================


def precision(signals: pl.Series, labels: pl.Series) -> float:
    """Compute precision: P(label=1 | signal=1).

    Precision measures the accuracy of positive predictions. In trading:
    - High precision = most signals lead to profitable outcomes
    - Low precision = many false positives (unprofitable trades)

    Parameters
    ----------
    signals : pl.Series
        Binary series (1=signal, 0=no signal)
    labels : pl.Series
        Binary series (1=positive outcome, 0=negative outcome)

    Returns
    -------
    float
        Precision value in [0, 1], or NaN if no signals

    Formula
    -------
    precision = TP / (TP + FP)
    where TP = true positives, FP = false positives
    """
    n_signals = signals.sum()
    if n_signals == 0:
        return float("nan")

    tp = ((signals == 1) & (labels == 1)).sum()
    fp = ((signals == 1) & (labels == 0)).sum()

    return float(tp / (tp + fp))


def recall(signals: pl.Series, labels: pl.Series) -> float:
    """Compute recall (sensitivity): P(signal=1 | label=1).

    Recall measures coverage of positive outcomes. In trading:
    - High recall = captures most profitable opportunities
    - Low recall = misses many profitable opportunities

    Parameters
    ----------
    signals : pl.Series
        Binary series (1=signal, 0=no signal)
    labels : pl.Series
        Binary series (1=positive outcome, 0=negative outcome)

    Returns
    -------
    float
        Recall value in [0, 1], or NaN if no positive labels

    Formula
    -------
    recall = TP / (TP + FN)
    where TP = true positives, FN = false negatives
    """
    n_positives = labels.sum()
    if n_positives == 0:
        return float("nan")

    tp = ((signals == 1) & (labels == 1)).sum()
    fn = ((signals == 0) & (labels == 1)).sum()

    return float(tp / (tp + fn))


def coverage(signals: pl.Series) -> float:
    """Compute signal coverage: fraction of observations with signals.

    Coverage measures how frequently the indicator generates signals:
    - High coverage (>20%) = many trading opportunities
    - Low coverage (<5%) = sparse/rare signals

    Parameters
    ----------
    signals : pl.Series
        Binary series (1=signal, 0=no signal)

    Returns
    -------
    float
        Coverage value in [0, 1]

    Formula
    -------
    coverage = (# signals) / (# total observations)
    """
    n = len(signals)
    if n == 0:
        return float("nan")
    return float(signals.sum() / n)


def lift(signals: pl.Series, labels: pl.Series) -> float:
    """Compute lift: precision / base_rate.

    Lift measures improvement over random selection:
    - Lift > 1.0 = signal better than random
    - Lift < 1.0 = signal worse than random
    - Lift = 1.0 = signal no better than random

    Parameters
    ----------
    signals : pl.Series
        Binary series (1=signal, 0=no signal)
    labels : pl.Series
        Binary series (1=positive outcome, 0=negative outcome)

    Returns
    -------
    float
        Lift value (typically 0.5 - 3.0), or NaN if no signals or labels

    Formula
    -------
    lift = precision / base_rate
    where base_rate = P(label=1) overall
    """
    n = len(labels)
    if n == 0:
        return float("nan")

    base_rate = labels.sum() / n
    if base_rate == 0 or signals.sum() == 0:
        return float("nan")

    prec = precision(signals, labels)
    return float(prec / base_rate)

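
# Illustrative arithmetic, using the example data from the module docstring:
# those ten rows contain 5 signals of which 4 hit, so precision = 4/5 = 0.8;
# 5 of the 10 rows are positive, so base_rate = 0.5 and lift = 0.8 / 0.5 = 1.6,
# meaning the signal selects winners 1.6x as often as random selection would.
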

def f1_score(signals: pl.Series, labels: pl.Series) -> float:
    """Compute F1 score: harmonic mean of precision and recall.

    F1 balances precision and recall:
    - F1 = 1.0 = perfect precision and recall
    - F1 = 0.0 = zero precision or recall

    Parameters
    ----------
    signals : pl.Series
        Binary series (1=signal, 0=no signal)
    labels : pl.Series
        Binary series (1=positive outcome, 0=negative outcome)

    Returns
    -------
    float
        F1 score in [0, 1], or NaN if undefined

    Formula
    -------
    F1 = 2 * (precision * recall) / (precision + recall)
    """
    prec = precision(signals, labels)
    rec = recall(signals, labels)

    if np.isnan(prec) or np.isnan(rec) or (prec + rec) == 0:
        return float("nan")

    return 2 * (prec * rec) / (prec + rec)


def specificity(signals: pl.Series, labels: pl.Series) -> float:
    """Compute specificity: P(signal=0 | label=0).

    Specificity measures the true negative rate:
    - High specificity = correctly avoids bad trades
    - Low specificity = many false positives

    Parameters
    ----------
    signals : pl.Series
        Binary series (1=signal, 0=no signal)
    labels : pl.Series
        Binary series (1=positive outcome, 0=negative outcome)

    Returns
    -------
    float
        Specificity value in [0, 1], or NaN if no negative labels

    Formula
    -------
    specificity = TN / (TN + FP)
    where TN = true negatives, FP = false positives
    """
    n_negatives = (labels == 0).sum()
    if n_negatives == 0:
        return float("nan")

    tn = ((signals == 0) & (labels == 0)).sum()
    fp = ((signals == 1) & (labels == 0)).sum()

    return float(tn / (tn + fp))


def balanced_accuracy(signals: pl.Series, labels: pl.Series) -> float:
    """Compute balanced accuracy: average of recall and specificity.

    Balanced accuracy is useful when classes are imbalanced:
    - Equal weight to both positive and negative class performance
    - Range [0, 1], where 0.5 = random classifier

    Parameters
    ----------
    signals : pl.Series
        Binary series (1=signal, 0=no signal)
    labels : pl.Series
        Binary series (1=positive outcome, 0=negative outcome)

    Returns
    -------
    float
        Balanced accuracy in [0, 1], or NaN if undefined

    Formula
    -------
    balanced_accuracy = (recall + specificity) / 2
    """
    rec = recall(signals, labels)
    spec = specificity(signals, labels)

    if np.isnan(rec) or np.isnan(spec):
        return float("nan")

    return (rec + spec) / 2

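
# Quick sanity numbers on the docstring data (illustrative only): recall = 4/5 = 0.8
# and specificity = 4/5 = 0.8, so f1_score = 0.8 and balanced_accuracy = 0.8,
# comfortably above the 0.5 a random classifier would score.
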

# ============================================================================
# Confidence Intervals
# ============================================================================


def wilson_score_interval(
    n_successes: int,
    n_trials: int,
    confidence: float = 0.95,
) -> tuple[float, float]:
    """Compute Wilson score confidence interval for a proportion.

    More accurate than normal approximation, especially for small samples
    or extreme proportions. Recommended for trading signal evaluation.

    Parameters
    ----------
    n_successes : int
        Number of successes (e.g., true positives)
    n_trials : int
        Total number of trials (e.g., total signals)
    confidence : float, default 0.95
        Confidence level for the interval

    Returns
    -------
    tuple[float, float]
        (lower_bound, upper_bound) of the confidence interval

    References
    ----------
    Wilson, E.B. (1927). "Probable inference, the law of succession,
    and statistical inference". Journal of the American Statistical
    Association.

    Examples
    --------
    >>> lower, upper = wilson_score_interval(45, 100, confidence=0.95)
    >>> print(f"95% CI: [{lower:.3f}, {upper:.3f}]")
    """
    if n_trials == 0:
        return (float("nan"), float("nan"))

    z = stats.norm.ppf(1 - (1 - confidence) / 2)
    p_hat = n_successes / n_trials

    denominator = 1 + z**2 / n_trials
    center = (p_hat + z**2 / (2 * n_trials)) / denominator
    margin = z * np.sqrt((p_hat * (1 - p_hat) + z**2 / (4 * n_trials)) / n_trials) / denominator

    return (float(center - margin), float(center + margin))

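
# The closed form implemented above is the standard Wilson interval:
#     center = (p_hat + z^2 / (2 * n)) / (1 + z^2 / n)
#     margin = z * sqrt(p_hat * (1 - p_hat) / n + z^2 / (4 * n^2)) / (1 + z^2 / n)
# Unlike the normal ("Wald") interval its endpoints never leave [0, 1]; for the
# docstring example, wilson_score_interval(45, 100) works out to roughly
# (0.356, 0.547).
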

# ============================================================================
# Statistical Tests
# ============================================================================


def binomial_test_precision(
    tp: int,
    n: int,
    prevalence: float,
    alternative: Literal["greater", "less", "two-sided"] = "greater",
) -> float:
    """Test if precision is significantly better than random using binomial test.

    Null hypothesis: precision = prevalence (signal no better than random)
    Alternative: precision > prevalence (signal better than random)

    Parameters
    ----------
    tp : int
        True positives (# signals with positive outcomes)
    n : int
        Total signals (# times signal=1)
    prevalence : float
        Base rate P(label=1) in population
    alternative : {'greater', 'less', 'two-sided'}, default 'greater'
        Alternative hypothesis direction

    Returns
    -------
    float
        p-value for the binomial test

    Notes
    -----
    Interpretation:
    - p < 0.05 => precision significantly > prevalence (good signal!)
    - p >= 0.05 => precision not significantly better than random
    """
    if n == 0:
        return float("nan")

    # Handle edge case where prevalence is 0 or 1
    if prevalence <= 0 or prevalence >= 1:
        return float("nan")

    result = stats.binomtest(tp, n, prevalence, alternative=alternative)
    return float(result.pvalue)

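
# Illustrative call (hypothetical numbers): binomial_test_precision(60, 100, 0.5)
# asks how likely 60+ hits in 100 signals would be if the true hit rate were the
# 50% base rate; the one-sided p-value comes out to roughly 0.03, so such a
# signal would clear the conventional 5% bar.
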

def proportions_z_test(
    signals: pl.Series,
    labels: pl.Series,
    alternative: Literal["greater", "less", "two-sided"] = "greater",
) -> tuple[float, float]:
    """Test if precision differs from base rate using z-test.

    More powerful than binomial test for large samples (n > 30).
    Null hypothesis: precision = base_rate

    Parameters
    ----------
    signals : pl.Series
        Binary series (1=signal, 0=no signal)
    labels : pl.Series
        Binary series (1=positive outcome, 0=negative outcome)
    alternative : {'greater', 'less', 'two-sided'}, default 'greater'
        Alternative hypothesis direction

    Returns
    -------
    tuple[float, float]
        (z_statistic, p_value)

    Notes
    -----
    Interpretation:
    - p < 0.05 => precision significantly different from base rate
    - z > 0 => precision > base rate (good)
    - z < 0 => precision < base rate (bad)
    """
    n_signals = int(signals.sum())
    n_total = len(labels)

    if n_signals == 0 or n_total == 0:
        return (float("nan"), float("nan"))

    # Signal group precision
    tp = int(((signals == 1) & (labels == 1)).sum())
    p1 = tp / n_signals

    # Population base rate
    p2 = float(labels.sum() / n_total)
    n2 = n_total - n_signals

    if n2 == 0:
        return (float("nan"), float("nan"))

    # Pooled proportion
    p_pool = float(labels.sum() / n_total)

    # Standard error
    se = np.sqrt(p_pool * (1 - p_pool) * (1 / n_signals + 1 / n2))

    if se == 0:
        return (float("nan"), float("nan"))

    # Z-statistic
    z = (p1 - p2) / se

    # P-value
    if alternative == "greater":
        p_value = 1 - stats.norm.cdf(z)
    elif alternative == "less":
        p_value = stats.norm.cdf(z)
    else:  # two-sided
        p_value = 2 * (1 - stats.norm.cdf(abs(z)))

    return (float(z), float(p_value))

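
# Note on the construction above: p2 is the population base rate, so the signal
# rows contribute to both proportions being compared. Since
# p1 - p2 = (1 - coverage) * (p1 - p_no_signal), the overlap shrinks the
# contrast, making the test somewhat conservative, increasingly so as coverage
# grows.
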

def compare_precisions_z_test(
    signals1: pl.Series,
    labels1: pl.Series,
    signals2: pl.Series,
    labels2: pl.Series,
    alternative: Literal["greater", "less", "two-sided"] = "two-sided",
) -> tuple[float, float]:
    """Compare precision between two strategies using z-test.

    Tests whether strategy 1 has significantly different precision than strategy 2.

    Parameters
    ----------
    signals1 : pl.Series
        Binary signals from strategy 1
    labels1 : pl.Series
        Binary labels for strategy 1
    signals2 : pl.Series
        Binary signals from strategy 2
    labels2 : pl.Series
        Binary labels for strategy 2
    alternative : {'greater', 'less', 'two-sided'}, default 'two-sided'
        Alternative hypothesis direction

    Returns
    -------
    tuple[float, float]
        (z_statistic, p_value)
    """
    n1 = int(signals1.sum())
    n2 = int(signals2.sum())

    if n1 == 0 or n2 == 0:
        return (float("nan"), float("nan"))

    tp1 = int(((signals1 == 1) & (labels1 == 1)).sum())
    tp2 = int(((signals2 == 1) & (labels2 == 1)).sum())

    p1 = tp1 / n1
    p2 = tp2 / n2

    # Pooled proportion
    p_pool = (tp1 + tp2) / (n1 + n2)

    # Standard error
    se = np.sqrt(p_pool * (1 - p_pool) * (1 / n1 + 1 / n2))

    if se == 0:
        return (float("nan"), float("nan"))

    z = (p1 - p2) / se

    if alternative == "greater":
        p_value = 1 - stats.norm.cdf(z)
    elif alternative == "less":
        p_value = stats.norm.cdf(z)
    else:
        p_value = 2 * (1 - stats.norm.cdf(abs(z)))

    return (float(z), float(p_value))

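
# The two strategies are treated as independent samples here (an unpaired test),
# which fits signals evaluated on different periods or instruments; if both
# strategies fire on the very same rows, a paired test such as McNemar's would
# be the sharper choice.
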

# ============================================================================
# Confusion Matrix
# ============================================================================


@dataclass
class ConfusionMatrix:
    """Confusion matrix for binary classification.

    Attributes
    ----------
    tp : int
        True positives
    fp : int
        False positives
    tn : int
        True negatives
    fn : int
        False negatives
    """

    tp: int
    fp: int
    tn: int
    fn: int

    @property
    def n_signals(self) -> int:
        """Total positive predictions."""
        return self.tp + self.fp

    @property
    def n_positives(self) -> int:
        """Total actual positives."""
        return self.tp + self.fn

    @property
    def n_negatives(self) -> int:
        """Total actual negatives."""
        return self.tn + self.fp

    @property
    def n_total(self) -> int:
        """Total observations."""
        return self.tp + self.fp + self.tn + self.fn

    def to_dict(self) -> dict[str, int]:
        """Convert to dictionary."""
        return {
            "tp": self.tp,
            "fp": self.fp,
            "tn": self.tn,
            "fn": self.fn,
            "n_signals": self.n_signals,
            "n_positives": self.n_positives,
            "n_negatives": self.n_negatives,
            "n_total": self.n_total,
        }


def compute_confusion_matrix(signals: pl.Series, labels: pl.Series) -> ConfusionMatrix:
    """Compute confusion matrix from signals and labels.

    Parameters
    ----------
    signals : pl.Series
        Binary series (1=signal, 0=no signal)
    labels : pl.Series
        Binary series (1=positive outcome, 0=negative outcome)

    Returns
    -------
    ConfusionMatrix
        Confusion matrix with tp, fp, tn, fn
    """
    tp = int(((signals == 1) & (labels == 1)).sum())
    fp = int(((signals == 1) & (labels == 0)).sum())
    tn = int(((signals == 0) & (labels == 0)).sum())
    fn = int(((signals == 0) & (labels == 1)).sum())

    return ConfusionMatrix(tp=tp, fp=fp, tn=tn, fn=fn)

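
# Illustrative check with the module docstring's example data:
#     >>> cm = compute_confusion_matrix(signals, labels)
#     >>> (cm.tp, cm.fp, cm.tn, cm.fn)
#     (4, 1, 4, 1)
#     >>> cm.n_signals, cm.n_positives, cm.n_total
#     (5, 5, 10)
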

# ============================================================================
# Comprehensive Report
# ============================================================================


@dataclass
class BinaryClassificationReport:
    """Comprehensive binary classification report.

    Attributes
    ----------
    precision : float
        Precision (positive predictive value)
    recall : float
        Recall (sensitivity, true positive rate)
    f1_score : float
        Harmonic mean of precision and recall
    specificity : float
        True negative rate
    balanced_accuracy : float
        Average of recall and specificity
    lift : float
        Improvement over random selection
    coverage : float
        Fraction of observations with signals
    confusion_matrix : ConfusionMatrix
        Confusion matrix details
    base_rate : float
        Population prevalence of positive class
    precision_ci : tuple[float, float]
        Wilson score CI for precision
    recall_ci : tuple[float, float]
        Wilson score CI for recall
    binomial_pvalue : float
        P-value for binomial test of precision > base_rate
    z_test_stat : float
        Z-statistic for precision vs base_rate
    z_test_pvalue : float
        P-value for z-test
    mean_return_on_signal : float | None
        Mean return when signal=1 (if returns provided)
    mean_return_no_signal : float | None
        Mean return when signal=0 (if returns provided)
    return_lift : float | None
        Ratio of signal return to no-signal return (if returns provided)
    """

    precision: float
    recall: float
    f1_score: float
    specificity: float
    balanced_accuracy: float
    lift: float
    coverage: float
    confusion_matrix: ConfusionMatrix
    base_rate: float
    precision_ci: tuple[float, float]
    recall_ci: tuple[float, float]
    binomial_pvalue: float
    z_test_stat: float
    z_test_pvalue: float
    mean_return_on_signal: float | None = None
    mean_return_no_signal: float | None = None
    return_lift: float | None = None

    def to_dict(self) -> dict:
        """Convert report to dictionary."""
        result = {
            "precision": self.precision,
            "recall": self.recall,
            "f1_score": self.f1_score,
            "specificity": self.specificity,
            "balanced_accuracy": self.balanced_accuracy,
            "lift": self.lift,
            "coverage": self.coverage,
            "base_rate": self.base_rate,
            "precision_ci": self.precision_ci,
            "recall_ci": self.recall_ci,
            "binomial_pvalue": self.binomial_pvalue,
            "z_test_stat": self.z_test_stat,
            "z_test_pvalue": self.z_test_pvalue,
            **self.confusion_matrix.to_dict(),
        }
        if self.mean_return_on_signal is not None:
            result["mean_return_on_signal"] = self.mean_return_on_signal
            result["mean_return_no_signal"] = self.mean_return_no_signal
            result["return_lift"] = self.return_lift
        return result

    @property
    def is_significant(self) -> bool:
        """Whether precision is significantly better than base rate at p<0.05."""
        return self.binomial_pvalue < 0.05

    @property
    def is_sparse(self) -> bool:
        """Whether signal coverage is below 5%."""
        return self.coverage < 0.05


def binary_classification_report(
    signals: pl.Series,
    labels: pl.Series,
    returns: pl.Series | None = None,
    confidence: float = 0.95,
) -> BinaryClassificationReport:
    """Generate comprehensive binary classification report for trading signal.

    Computes all key metrics with confidence intervals and statistical tests.

    Parameters
    ----------
    signals : pl.Series
        Binary series (1=signal, 0=no signal)
    labels : pl.Series
        Binary series (1=positive outcome, 0=negative outcome)
    returns : pl.Series, optional
        Series of returns for additional return analysis
    confidence : float, default 0.95
        Confidence level for Wilson score intervals

    Returns
    -------
    BinaryClassificationReport
        Comprehensive report with all metrics, CIs, and statistical tests

    Examples
    --------
    >>> report = binary_classification_report(signals, labels)
    >>> print(f"Precision: {report.precision:.3f} "
    ...       f"[{report.precision_ci[0]:.3f}, {report.precision_ci[1]:.3f}]")
    >>> print(f"Statistical significance: p={report.binomial_pvalue:.4f}")
    >>> if report.is_significant:
    ...     print("Signal is significantly better than random!")
    """
    # Compute confusion matrix
    cm = compute_confusion_matrix(signals, labels)

    # Basic metrics
    prec = precision(signals, labels)
    rec = recall(signals, labels)
    f1 = f1_score(signals, labels)
    spec = specificity(signals, labels)
    bal_acc = balanced_accuracy(signals, labels)
    lift_val = lift(signals, labels)
    cov = coverage(signals)

    # Base rate
    base_rate = cm.n_positives / cm.n_total if cm.n_total > 0 else float("nan")

    # Confidence intervals
    prec_ci = wilson_score_interval(cm.tp, cm.n_signals, confidence)
    rec_ci = wilson_score_interval(cm.tp, cm.n_positives, confidence)

    # Statistical tests
    binom_pvalue = binomial_test_precision(cm.tp, cm.n_signals, base_rate)
    z_stat, z_pvalue = proportions_z_test(signals, labels)

    # Returns analysis (if provided)
    mean_ret_signal = None
    mean_ret_no_signal = None
    ret_lift = None

    if returns is not None:
        signal_mask = signals == 1
        no_signal_mask = signals == 0

        if signal_mask.sum() > 0:
            val = returns.filter(signal_mask).mean()
            if val is not None and isinstance(val, int | float):
                mean_ret_signal = float(val)
        if no_signal_mask.sum() > 0:
            val = returns.filter(no_signal_mask).mean()
            if val is not None and isinstance(val, int | float):
                mean_ret_no_signal = float(val)

        if (
            mean_ret_signal is not None
            and mean_ret_no_signal is not None
            and mean_ret_no_signal != 0
        ):
            ret_lift = mean_ret_signal / mean_ret_no_signal

    return BinaryClassificationReport(
        precision=prec,
        recall=rec,
        f1_score=f1,
        specificity=spec,
        balanced_accuracy=bal_acc,
        lift=lift_val,
        coverage=cov,
        confusion_matrix=cm,
        base_rate=base_rate,
        precision_ci=prec_ci,
        recall_ci=rec_ci,
        binomial_pvalue=binom_pvalue,
        z_test_stat=z_stat,
        z_test_pvalue=z_pvalue,
        mean_return_on_signal=mean_ret_signal,
        mean_return_no_signal=mean_ret_no_signal,
        return_lift=ret_lift,
    )

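
# Worked through on the docstring's ten-row example, the report comes back with
# precision 0.8, recall 0.8, lift 1.6, coverage 0.5, and a binomial p-value of
# P(X >= 4 | Binomial(5, 0.5)) = 6/32 = 0.1875, so is_significant is False: with
# only five signals, even a strong hit rate cannot clear the 5% bar.
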

def format_classification_report(report: BinaryClassificationReport) -> str:
    """Format binary classification report as human-readable string.

    Parameters
    ----------
    report : BinaryClassificationReport
        Report from binary_classification_report()

    Returns
    -------
    str
        Formatted string with metrics and interpretation
    """
    cm = report.confusion_matrix

    lines = [
        "Binary Classification Report",
        "=" * 50,
        "",
        f"Sample Size: {cm.n_total:,}",
        f"Base Rate: {report.base_rate:.3f} ({cm.n_positives:,} positives)",
        "",
        "Metrics:",
        f"  Precision:    {report.precision:.3f} "
        f"[{report.precision_ci[0]:.3f}, {report.precision_ci[1]:.3f}]",
        f"  Recall:       {report.recall:.3f} "
        f"[{report.recall_ci[0]:.3f}, {report.recall_ci[1]:.3f}]",
        f"  F1 Score:     {report.f1_score:.3f}",
        f"  Specificity:  {report.specificity:.3f}",
        f"  Balanced Acc: {report.balanced_accuracy:.3f}",
        f"  Lift:         {report.lift:.3f}",
        f"  Coverage:     {report.coverage:.3f} ({cm.n_signals:,} signals)",
        "",
        "Confusion Matrix:",
        f"  TP: {cm.tp:>6,}  FP: {cm.fp:>6,}",
        f"  FN: {cm.fn:>6,}  TN: {cm.tn:>6,}",
        "",
        "Statistical Significance:",
        f"  Binomial test p-value: {report.binomial_pvalue:.4f}",
        f"  Z-test statistic: {report.z_test_stat:.3f}",
        f"  Z-test p-value: {report.z_test_pvalue:.4f}",
    ]

    # Add returns analysis if available
    if report.mean_return_on_signal is not None:
        lines.extend(
            [
                "",
                "Returns Analysis:",
                f"  Mean return (signal): {report.mean_return_on_signal:.4f}",
                f"  Mean return (no signal): {report.mean_return_no_signal:.4f}",
                f"  Return lift: {report.return_lift:.3f}",
            ]
        )

    # Interpretation
    lines.extend(["", "Interpretation:"])

    if report.is_significant:
        lines.append("  [+] Signal precision significantly > base rate (p < 0.05)")
    else:
        lines.append("  [-] Signal precision NOT significantly > base rate (p >= 0.05)")

    if report.lift > 1.2:
        lines.append("  [+] Strong lift (>1.2x better than random)")
    elif report.lift > 1.0:
        lines.append("  [~] Moderate lift (>1.0x better than random)")
    else:
        lines.append("  [-] No lift (<= 1.0x, not better than random)")

    if report.is_sparse:
        lines.append("  [!] Very sparse signals (<5% coverage)")
    elif report.coverage > 0.20:
        lines.append("  [+] High signal frequency (>20% coverage)")

    return "\n".join(lines)


# ============================================================================
# Convenience Functions
# ============================================================================


def compute_all_metrics(
    signals: pl.Series,
    labels: pl.Series,
) -> dict[str, float]:
    """Compute all binary classification metrics.

    Parameters
    ----------
    signals : pl.Series
        Binary series (1=signal, 0=no signal)
    labels : pl.Series
        Binary series (1=positive outcome, 0=negative outcome)

    Returns
    -------
    dict[str, float]
        Dictionary with all metric values
    """
    return {
        "precision": precision(signals, labels),
        "recall": recall(signals, labels),
        "f1_score": f1_score(signals, labels),
        "specificity": specificity(signals, labels),
        "balanced_accuracy": balanced_accuracy(signals, labels),
        "lift": lift(signals, labels),
        "coverage": coverage(signals),
    }
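
A minimal end-to-end pass over this module, for orientation. The data is the
ten-row example from the module docstring and everything called is the module's
own public API; treat this as an illustrative sketch rather than shipped
package code:

    import polars as pl

    from ml4t.diagnostic.evaluation.binary_metrics import (
        binary_classification_report,
        format_classification_report,
    )

    signals = pl.Series([1, 0, 1, 1, 0, 1, 0, 0, 1, 0])
    labels = pl.Series([1, 0, 1, 0, 0, 1, 0, 1, 1, 0])

    report = binary_classification_report(signals, labels)
    print(format_classification_report(report))

With only five signals the binomial p-value is 0.1875, so the printed
interpretation flags the precision as not significantly better than the base
rate even though lift is 1.6: sparse signals need a longer history before the
significance tests become meaningful.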