ml4t-diagnostic 0.1.0a1__py3-none-any.whl
- ml4t/diagnostic/AGENT.md +25 -0
- ml4t/diagnostic/__init__.py +166 -0
- ml4t/diagnostic/backends/__init__.py +10 -0
- ml4t/diagnostic/backends/adapter.py +192 -0
- ml4t/diagnostic/backends/polars_backend.py +899 -0
- ml4t/diagnostic/caching/__init__.py +40 -0
- ml4t/diagnostic/caching/cache.py +331 -0
- ml4t/diagnostic/caching/decorators.py +131 -0
- ml4t/diagnostic/caching/smart_cache.py +339 -0
- ml4t/diagnostic/config/AGENT.md +24 -0
- ml4t/diagnostic/config/README.md +267 -0
- ml4t/diagnostic/config/__init__.py +219 -0
- ml4t/diagnostic/config/barrier_config.py +277 -0
- ml4t/diagnostic/config/base.py +301 -0
- ml4t/diagnostic/config/event_config.py +148 -0
- ml4t/diagnostic/config/feature_config.py +404 -0
- ml4t/diagnostic/config/multi_signal_config.py +55 -0
- ml4t/diagnostic/config/portfolio_config.py +215 -0
- ml4t/diagnostic/config/report_config.py +391 -0
- ml4t/diagnostic/config/sharpe_config.py +202 -0
- ml4t/diagnostic/config/signal_config.py +206 -0
- ml4t/diagnostic/config/trade_analysis_config.py +310 -0
- ml4t/diagnostic/config/validation.py +279 -0
- ml4t/diagnostic/core/__init__.py +29 -0
- ml4t/diagnostic/core/numba_utils.py +315 -0
- ml4t/diagnostic/core/purging.py +372 -0
- ml4t/diagnostic/core/sampling.py +471 -0
- ml4t/diagnostic/errors/__init__.py +205 -0
- ml4t/diagnostic/evaluation/AGENT.md +26 -0
- ml4t/diagnostic/evaluation/__init__.py +437 -0
- ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
- ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
- ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
- ml4t/diagnostic/evaluation/dashboard.py +715 -0
- ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
- ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
- ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
- ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
- ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
- ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
- ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
- ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
- ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
- ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
- ml4t/diagnostic/evaluation/event_analysis.py +647 -0
- ml4t/diagnostic/evaluation/excursion.py +390 -0
- ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
- ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
- ml4t/diagnostic/evaluation/framework.py +935 -0
- ml4t/diagnostic/evaluation/metric_registry.py +255 -0
- ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
- ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
- ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
- ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
- ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
- ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
- ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
- ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
- ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
- ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
- ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
- ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
- ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
- ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
- ml4t/diagnostic/evaluation/multi_signal.py +550 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
- ml4t/diagnostic/evaluation/report_generation.py +824 -0
- ml4t/diagnostic/evaluation/signal_selector.py +452 -0
- ml4t/diagnostic/evaluation/stat_registry.py +139 -0
- ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
- ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
- ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
- ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
- ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
- ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
- ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
- ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
- ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
- ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
- ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
- ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
- ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
- ml4t/diagnostic/evaluation/stats/moments.py +164 -0
- ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
- ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
- ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
- ml4t/diagnostic/evaluation/themes.py +330 -0
- ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
- ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
- ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
- ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
- ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
- ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
- ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
- ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
- ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
- ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
- ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
- ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
- ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
- ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
- ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
- ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
- ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
- ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
- ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
- ml4t/diagnostic/evaluation/validated_cv.py +535 -0
- ml4t/diagnostic/evaluation/visualization.py +1050 -0
- ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
- ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
- ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
- ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
- ml4t/diagnostic/integration/__init__.py +48 -0
- ml4t/diagnostic/integration/backtest_contract.py +671 -0
- ml4t/diagnostic/integration/data_contract.py +316 -0
- ml4t/diagnostic/integration/engineer_contract.py +226 -0
- ml4t/diagnostic/logging/__init__.py +77 -0
- ml4t/diagnostic/logging/logger.py +245 -0
- ml4t/diagnostic/logging/performance.py +234 -0
- ml4t/diagnostic/logging/progress.py +234 -0
- ml4t/diagnostic/logging/wandb.py +412 -0
- ml4t/diagnostic/metrics/__init__.py +9 -0
- ml4t/diagnostic/metrics/percentiles.py +128 -0
- ml4t/diagnostic/py.typed +1 -0
- ml4t/diagnostic/reporting/__init__.py +43 -0
- ml4t/diagnostic/reporting/base.py +130 -0
- ml4t/diagnostic/reporting/html_renderer.py +275 -0
- ml4t/diagnostic/reporting/json_renderer.py +51 -0
- ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
- ml4t/diagnostic/results/AGENT.md +24 -0
- ml4t/diagnostic/results/__init__.py +105 -0
- ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
- ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
- ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
- ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
- ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
- ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
- ml4t/diagnostic/results/barrier_results/validation.py +38 -0
- ml4t/diagnostic/results/base.py +177 -0
- ml4t/diagnostic/results/event_results.py +349 -0
- ml4t/diagnostic/results/feature_results.py +787 -0
- ml4t/diagnostic/results/multi_signal_results.py +431 -0
- ml4t/diagnostic/results/portfolio_results.py +281 -0
- ml4t/diagnostic/results/sharpe_results.py +448 -0
- ml4t/diagnostic/results/signal_results/__init__.py +74 -0
- ml4t/diagnostic/results/signal_results/ic.py +581 -0
- ml4t/diagnostic/results/signal_results/irtc.py +110 -0
- ml4t/diagnostic/results/signal_results/quantile.py +392 -0
- ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
- ml4t/diagnostic/results/signal_results/turnover.py +213 -0
- ml4t/diagnostic/results/signal_results/validation.py +147 -0
- ml4t/diagnostic/signal/AGENT.md +17 -0
- ml4t/diagnostic/signal/__init__.py +69 -0
- ml4t/diagnostic/signal/_report.py +152 -0
- ml4t/diagnostic/signal/_utils.py +261 -0
- ml4t/diagnostic/signal/core.py +275 -0
- ml4t/diagnostic/signal/quantile.py +148 -0
- ml4t/diagnostic/signal/result.py +214 -0
- ml4t/diagnostic/signal/signal_ic.py +129 -0
- ml4t/diagnostic/signal/turnover.py +182 -0
- ml4t/diagnostic/splitters/AGENT.md +19 -0
- ml4t/diagnostic/splitters/__init__.py +36 -0
- ml4t/diagnostic/splitters/base.py +501 -0
- ml4t/diagnostic/splitters/calendar.py +421 -0
- ml4t/diagnostic/splitters/calendar_config.py +91 -0
- ml4t/diagnostic/splitters/combinatorial.py +1064 -0
- ml4t/diagnostic/splitters/config.py +322 -0
- ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
- ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
- ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
- ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
- ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
- ml4t/diagnostic/splitters/group_isolation.py +329 -0
- ml4t/diagnostic/splitters/persistence.py +316 -0
- ml4t/diagnostic/splitters/utils.py +207 -0
- ml4t/diagnostic/splitters/walk_forward.py +757 -0
- ml4t/diagnostic/utils/__init__.py +42 -0
- ml4t/diagnostic/utils/config.py +542 -0
- ml4t/diagnostic/utils/dependencies.py +318 -0
- ml4t/diagnostic/utils/sessions.py +127 -0
- ml4t/diagnostic/validation/__init__.py +54 -0
- ml4t/diagnostic/validation/dataframe.py +274 -0
- ml4t/diagnostic/validation/returns.py +280 -0
- ml4t/diagnostic/validation/timeseries.py +299 -0
- ml4t/diagnostic/visualization/AGENT.md +19 -0
- ml4t/diagnostic/visualization/__init__.py +223 -0
- ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
- ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
- ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
- ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
- ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
- ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
- ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
- ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
- ml4t/diagnostic/visualization/barrier_plots.py +782 -0
- ml4t/diagnostic/visualization/core.py +1060 -0
- ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
- ml4t/diagnostic/visualization/dashboards/base.py +582 -0
- ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
- ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
- ml4t/diagnostic/visualization/dashboards.py +43 -0
- ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
- ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
- ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
- ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
- ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
- ml4t/diagnostic/visualization/feature_plots.py +888 -0
- ml4t/diagnostic/visualization/interaction_plots.py +618 -0
- ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
- ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
- ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
- ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
- ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
- ml4t/diagnostic/visualization/report_generation.py +1343 -0
- ml4t/diagnostic/visualization/signal/__init__.py +103 -0
- ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
- ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
- ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
- ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
- ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
- ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
- ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
- ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
- ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
- ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
- ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
- ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
ml4t/diagnostic/evaluation/threshold_analysis.py
@@ -0,0 +1,957 @@
"""Threshold Analysis for Trading Signal Optimization.

This module provides tools for evaluating trading signals across multiple thresholds
and finding optimal threshold values. Essential for:
- Identifying optimal thresholds for indicator-based signals
- Understanding metric behavior and trade-offs
- Detecting monotonicity violations
- Assessing threshold sensitivity

The module integrates with binary_metrics to compute metrics at each threshold.

Usage Example:
    >>> import polars as pl
    >>> from ml4t.diagnostic.evaluation.threshold_analysis import (
    ...     evaluate_threshold_sweep,
    ...     find_optimal_threshold,
    ...     check_monotonicity,
    ... )
    >>>
    >>> # Indicator values and labels
    >>> indicator = pl.Series([45, 55, 65, 75, 85, 35, 72, 68, 52, 88])
    >>> labels = pl.Series([0, 0, 1, 1, 1, 0, 1, 1, 0, 1])
    >>>
    >>> # Sweep across thresholds
    >>> thresholds = [50, 60, 70, 80]
    >>> results = evaluate_threshold_sweep(indicator, labels, thresholds)
    >>> print(results)
    >>>
    >>> # Find optimal threshold (the result is a dataclass, not a dict)
    >>> optimal = find_optimal_threshold(results, metric="lift", min_coverage=0.1)
    >>> print(f"Optimal threshold: {optimal.threshold}")
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Literal

import numpy as np
import polars as pl

from .binary_metrics import (
    binary_classification_report,
)

# ============================================================================
# Core Threshold Sweep
# ============================================================================

def evaluate_threshold_sweep(
    indicator: pl.Series,
    labels: pl.Series,
    thresholds: list[float],
    direction: Literal["above", "below"] = "above",
    returns: pl.Series | None = None,
) -> pl.DataFrame:
    """Evaluate binary classification metrics across multiple thresholds.

    For each threshold, generates binary signals and computes precision, recall,
    F1, lift, coverage, and statistical significance.

    Parameters
    ----------
    indicator : pl.Series
        Continuous indicator values (e.g., RSI, momentum score)
    labels : pl.Series
        Binary labels (1=positive outcome, 0=negative outcome)
    thresholds : list[float]
        List of threshold values to evaluate
    direction : {'above', 'below'}, default 'above'
        Signal direction:
        - 'above': signal=1 when indicator > threshold
        - 'below': signal=1 when indicator < threshold
    returns : pl.Series, optional
        Returns for additional return analysis

    Returns
    -------
    pl.DataFrame
        DataFrame with one row per threshold, in ascending threshold order,
        with columns:
        - threshold: threshold value
        - precision, recall, f1_score, specificity, lift, coverage
        - n_signals, n_positives, n_total
        - binomial_pvalue, z_test_pvalue
        - is_significant: whether precision > base_rate at p<0.05
        - mean_return_on_signal (if returns provided)

    Examples
    --------
    >>> indicator = pl.Series([45, 55, 65, 75, 85, 35, 72, 68, 52, 88])
    >>> labels = pl.Series([0, 0, 1, 1, 1, 0, 1, 1, 0, 1])
    >>> results = evaluate_threshold_sweep(indicator, labels, [50, 60, 70, 80])
    >>> print(results.select(["threshold", "precision", "lift", "coverage"]))
    """
    if len(indicator) != len(labels):
        raise ValueError(
            f"indicator and labels must have same length, got {len(indicator)} and {len(labels)}"
        )

    if len(thresholds) == 0:
        raise ValueError("thresholds must not be empty")

    results = []

    for threshold in sorted(thresholds):
        # Generate signals based on direction
        if direction == "above":
            signals = (indicator > threshold).cast(pl.Int8)
        else:
            signals = (indicator < threshold).cast(pl.Int8)

        # Get comprehensive report
        report = binary_classification_report(signals, labels, returns=returns)

        row = {
            "threshold": threshold,
            "precision": report.precision,
            "recall": report.recall,
            "f1_score": report.f1_score,
            "specificity": report.specificity,
            "lift": report.lift,
            "coverage": report.coverage,
            "n_signals": report.confusion_matrix.n_signals,
            "n_positives": report.confusion_matrix.n_positives,
            "n_total": report.confusion_matrix.n_total,
            "base_rate": report.base_rate,
            "binomial_pvalue": report.binomial_pvalue,
            "z_test_pvalue": report.z_test_pvalue,
            "is_significant": 1.0 if report.is_significant else 0.0,
        }

        # Add return metrics if available
        if returns is not None and report.mean_return_on_signal is not None:
            row["mean_return_on_signal"] = report.mean_return_on_signal or 0.0
            row["mean_return_no_signal"] = report.mean_return_no_signal or 0.0
            row["return_lift"] = report.return_lift or 0.0

        results.append(row)

    return pl.DataFrame(results)

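
# --- Illustrative sketch (added for exposition; not in the published wheel) --
# A minimal, hedged example of a returns-aware sweep: passing `returns` makes
# the sweep add the mean_return_on_signal / mean_return_no_signal / return_lift
# columns populated above. The numbers are synthetic, purely for illustration.
def _example_sweep_with_returns() -> pl.DataFrame:
    indicator = pl.Series([45.0, 55.0, 65.0, 75.0, 85.0, 35.0, 72.0, 68.0, 52.0, 88.0])
    labels = pl.Series([0, 0, 1, 1, 1, 0, 1, 1, 0, 1])
    returns = pl.Series([-0.010, 0.002, 0.015, 0.010, 0.020, -0.005, 0.012, 0.008, -0.002, 0.025])
    return evaluate_threshold_sweep(indicator, labels, [50, 60, 70, 80], returns=returns)
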
def evaluate_percentile_thresholds(
    indicator: pl.Series,
    labels: pl.Series,
    percentiles: list[float] | None = None,
    direction: Literal["above", "below"] = "above",
    returns: pl.Series | None = None,
) -> pl.DataFrame:
    """Evaluate thresholds at indicator percentiles.

    Instead of specifying absolute threshold values, this function computes
    thresholds based on the indicator's distribution. Useful when the indicator
    scale varies across assets or time periods.

    Parameters
    ----------
    indicator : pl.Series
        Continuous indicator values
    labels : pl.Series
        Binary labels
    percentiles : list[float], optional
        Percentiles to evaluate (default: [10, 25, 50, 75, 90])
    direction : {'above', 'below'}, default 'above'
        Signal direction
    returns : pl.Series, optional
        Returns for additional analysis

    Returns
    -------
    pl.DataFrame
        Same as evaluate_threshold_sweep, with an additional 'percentile' column

    Examples
    --------
    >>> results = evaluate_percentile_thresholds(indicator, labels)
    >>> print(results.select(["percentile", "threshold", "precision"]))
    """
    if percentiles is None:
        percentiles = [10.0, 25.0, 50.0, 75.0, 90.0]

    # Sort percentiles first: evaluate_threshold_sweep evaluates thresholds in
    # ascending order, so the percentile column attached below only stays
    # aligned with its threshold if the percentiles are ascending as well
    # (quantiles are monotonic in the percentile).
    percentiles = sorted(percentiles)

    # Compute threshold values at each percentile
    thresholds = []
    for p in percentiles:
        q_val = indicator.quantile(p / 100.0)
        threshold = float(q_val) if q_val is not None else 0.0
        thresholds.append(threshold)

    # Evaluate
    results = evaluate_threshold_sweep(indicator, labels, thresholds, direction, returns)

    # Add percentile column
    results = results.with_columns(pl.Series("percentile", percentiles))

    # Reorder columns
    cols = ["percentile", "threshold"] + [
        c for c in results.columns if c not in ["percentile", "threshold"]
    ]
    return results.select(cols)

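
# --- Illustrative sketch (added for exposition; not in the published wheel) --
# Percentile thresholds adapt to the indicator's scale, so the same call works
# for an RSI-like series in [0, 100] or a z-scored factor. Hedged example:
def _example_percentile_sweep() -> pl.DataFrame:
    indicator = pl.Series([45.0, 55.0, 65.0, 75.0, 85.0, 35.0, 72.0, 68.0, 52.0, 88.0])
    labels = pl.Series([0, 0, 1, 1, 1, 0, 1, 1, 0, 1])
    # Top-quartile and top-decile entry thresholds, whatever the raw scale is.
    return evaluate_percentile_thresholds(indicator, labels, percentiles=[75.0, 90.0])
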

# ============================================================================
# Optimal Threshold Finding
# ============================================================================

@dataclass
class OptimalThresholdResult:
    """Result of optimal threshold search.

    Attributes
    ----------
    threshold : float | None
        Optimal threshold value, or None if no valid threshold found
    found : bool
        Whether a valid threshold was found
    metric : str
        Metric that was optimized
    metric_value : float | None
        Value of the optimized metric at optimal threshold
    precision : float | None
        Precision at optimal threshold
    recall : float | None
        Recall at optimal threshold
    f1_score : float | None
        F1 score at optimal threshold
    lift : float | None
        Lift at optimal threshold
    coverage : float | None
        Coverage at optimal threshold
    n_signals : int | None
        Number of signals at optimal threshold
    is_significant : bool
        Whether optimal threshold is statistically significant
    reason : str | None
        Reason if no threshold found
    """

    threshold: float | None
    found: bool
    metric: str
    metric_value: float | None = None
    precision: float | None = None
    recall: float | None = None
    f1_score: float | None = None
    lift: float | None = None
    coverage: float | None = None
    n_signals: int | None = None
    is_significant: bool = False
    reason: str | None = None

    def to_dict(self) -> dict:
        """Convert to dictionary."""
        return {
            "threshold": self.threshold,
            "found": self.found,
            "metric": self.metric,
            "metric_value": self.metric_value,
            "precision": self.precision,
            "recall": self.recall,
            "f1_score": self.f1_score,
            "lift": self.lift,
            "coverage": self.coverage,
            "n_signals": self.n_signals,
            "is_significant": self.is_significant,
            "reason": self.reason,
        }

def find_optimal_threshold(
    results_df: pl.DataFrame,
    metric: str = "lift",
    min_coverage: float = 0.01,
    max_coverage: float = 1.0,
    require_significant: bool = False,
    min_signals: int = 1,
) -> OptimalThresholdResult:
    """Find optimal threshold based on specified metric and constraints.

    Parameters
    ----------
    results_df : pl.DataFrame
        DataFrame from evaluate_threshold_sweep()
    metric : str, default 'lift'
        Metric to optimize ('lift', 'precision', 'f1_score', 'recall')
    min_coverage : float, default 0.01
        Minimum required signal coverage (1%)
    max_coverage : float, default 1.0
        Maximum allowed signal coverage
    require_significant : bool, default False
        Only consider statistically significant thresholds
    min_signals : int, default 1
        Minimum number of signals required

    Returns
    -------
    OptimalThresholdResult
        Result containing optimal threshold and associated metrics

    Examples
    --------
    >>> results = evaluate_threshold_sweep(indicator, labels, thresholds)
    >>> optimal = find_optimal_threshold(results, metric="lift", min_coverage=0.05)
    >>> if optimal.found:
    ...     print(f"Optimal threshold: {optimal.threshold}")
    ...     print(f"Lift: {optimal.lift:.2f}")
    """
    if metric not in results_df.columns:
        return OptimalThresholdResult(
            threshold=None,
            found=False,
            metric=metric,
            reason=f"Metric '{metric}' not in results",
        )

    # Apply filters
    filtered = results_df

    # Coverage constraints
    if "coverage" in filtered.columns:
        filtered = filtered.filter(
            (pl.col("coverage") >= min_coverage) & (pl.col("coverage") <= max_coverage)
        )

    # Minimum signals
    if "n_signals" in filtered.columns:
        filtered = filtered.filter(pl.col("n_signals") >= min_signals)

    # Statistical significance
    if require_significant and "is_significant" in filtered.columns:
        filtered = filtered.filter(pl.col("is_significant") == 1.0)

    # Filter out NaN values in target metric
    filtered = filtered.filter(pl.col(metric).is_not_nan())

    if len(filtered) == 0:
        return OptimalThresholdResult(
            threshold=None,
            found=False,
            metric=metric,
            reason="No thresholds meet constraints",
        )

    # Find maximum
    optimal_idx_result = filtered[metric].arg_max()
    if optimal_idx_result is None:
        return OptimalThresholdResult(
            threshold=None,
            found=False,
            metric=metric,
            reason=f"Could not locate a maximum for metric '{metric}'",
        )
    optimal_idx: int = optimal_idx_result

    # Extract values
    def safe_float(col: str) -> float | None:
        if col in filtered.columns:
            val = filtered[col][optimal_idx]
            return float(val) if val is not None else None
        return None

    def safe_int(col: str) -> int | None:
        if col in filtered.columns:
            val = filtered[col][optimal_idx]
            return int(val) if val is not None else None
        return None

    is_sig = False
    if "is_significant" in filtered.columns:
        is_sig = bool(filtered["is_significant"][optimal_idx])

    threshold_val = filtered["threshold"][optimal_idx]
    return OptimalThresholdResult(
        threshold=float(threshold_val) if threshold_val is not None else 0.0,
        found=True,
        metric=metric,
        metric_value=safe_float(metric),
        precision=safe_float("precision"),
        recall=safe_float("recall"),
        f1_score=safe_float("f1_score"),
        lift=safe_float("lift"),
        coverage=safe_float("coverage"),
        n_signals=safe_int("n_signals"),
        is_significant=is_sig,
    )

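
# --- Illustrative sketch (added for exposition; not in the published wheel) --
# Constraints can easily empty the candidate set; a hedged example of checking
# `found` and relaxing constraints rather than assuming the search succeeded.
def _example_find_optimal(results: pl.DataFrame) -> OptimalThresholdResult:
    strict = find_optimal_threshold(
        results, metric="lift", min_coverage=0.10, require_significant=True
    )
    if strict.found:
        return strict
    # Fall back: drop the significance requirement but keep a coverage floor.
    return find_optimal_threshold(results, metric="lift", min_coverage=0.05)
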
def find_threshold_for_target_coverage(
    results_df: pl.DataFrame,
    target_coverage: float,
    tolerance: float = 0.05,
) -> OptimalThresholdResult:
    """Find threshold that achieves target coverage.

    Useful when you want a specific signal frequency regardless of metric value.

    Parameters
    ----------
    results_df : pl.DataFrame
        DataFrame from evaluate_threshold_sweep()
    target_coverage : float
        Target signal coverage (e.g., 0.10 for 10%)
    tolerance : float, default 0.05
        Acceptable deviation from target (e.g., 0.05 means 5-15% for target=10%)

    Returns
    -------
    OptimalThresholdResult
        Result with threshold closest to target coverage
    """
    if "coverage" not in results_df.columns:
        return OptimalThresholdResult(
            threshold=None,
            found=False,
            metric="coverage",
            reason="No coverage column in results",
        )

    # Find threshold closest to target coverage
    results_df = results_df.with_columns(
        (pl.col("coverage") - target_coverage).abs().alias("_coverage_diff")
    )

    # Filter by tolerance
    filtered = results_df.filter(pl.col("_coverage_diff") <= tolerance)

    if len(filtered) == 0:
        return OptimalThresholdResult(
            threshold=None,
            found=False,
            metric="coverage",
            reason=f"No threshold within {tolerance:.1%} of target {target_coverage:.1%}",
        )

    # Find closest
    closest_idx_result = filtered["_coverage_diff"].arg_min()
    if closest_idx_result is None:
        return OptimalThresholdResult(
            threshold=None,
            found=False,
            metric="coverage",
            reason="No threshold found",
        )
    closest_idx: int = closest_idx_result

    def safe_float(col: str) -> float | None:
        if col in filtered.columns:
            val = filtered[col][closest_idx]
            return float(val) if val is not None else None
        return None

    threshold_val = filtered["threshold"][closest_idx]
    n_signals_val = filtered["n_signals"][closest_idx] if "n_signals" in filtered.columns else None
    is_sig_val = (
        filtered["is_significant"][closest_idx] if "is_significant" in filtered.columns else False
    )

    return OptimalThresholdResult(
        threshold=float(threshold_val) if threshold_val is not None else 0.0,
        found=True,
        metric="coverage",
        metric_value=safe_float("coverage"),
        precision=safe_float("precision"),
        recall=safe_float("recall"),
        f1_score=safe_float("f1_score"),
        lift=safe_float("lift"),
        coverage=safe_float("coverage"),
        n_signals=int(n_signals_val) if n_signals_val is not None else None,
        is_significant=bool(is_sig_val) if is_sig_val is not None else False,
    )

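
# --- Illustrative sketch (added for exposition; not in the published wheel) --
# Hedged example: target trading roughly 10% of the time (coverage near 0.10),
# accepting anything in the 5%-15% band, then inspect the quality metrics.
def _example_target_coverage(results: pl.DataFrame) -> None:
    hit = find_threshold_for_target_coverage(results, target_coverage=0.10, tolerance=0.05)
    if hit.found:
        print(f"threshold={hit.threshold}, coverage={hit.coverage:.1%}, lift={hit.lift}")
    else:
        print(f"no match: {hit.reason}")
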

# ============================================================================
# Monotonicity Analysis
# ============================================================================

@dataclass
class MonotonicityResult:
    """Result of monotonicity analysis.

    Attributes
    ----------
    metric : str
        Metric that was analyzed
    is_monotonic : bool
        Whether metric is monotonic (either direction)
    is_monotonic_increasing : bool
        Whether metric is monotonically increasing
    is_monotonic_decreasing : bool
        Whether metric is monotonically decreasing
    direction_changes : int
        Number of direction reversals
    violations : list[tuple[int, float, float]]
        List of (index, previous_value, current_value) for violations
    max_violation : float | None
        Largest decrease (for increasing expectation)
    """

    metric: str
    is_monotonic: bool
    is_monotonic_increasing: bool
    is_monotonic_decreasing: bool
    direction_changes: int
    violations: list[tuple[int, float, float]]
    max_violation: float | None = None

    def to_dict(self) -> dict:
        """Convert to dictionary."""
        return {
            "metric": self.metric,
            "is_monotonic": self.is_monotonic,
            "is_monotonic_increasing": self.is_monotonic_increasing,
            "is_monotonic_decreasing": self.is_monotonic_decreasing,
            "direction_changes": self.direction_changes,
            "n_violations": len(self.violations),
            "max_violation": self.max_violation,
        }

def check_monotonicity(
    results_df: pl.DataFrame,
    metric: str,
) -> MonotonicityResult:
    """Check if metric exhibits monotonic behavior across thresholds.

    Non-monotonic behavior can indicate:
    - Regime changes in the data
    - Data quality issues
    - Complex indicator dynamics
    - Overfitting at certain thresholds

    Parameters
    ----------
    results_df : pl.DataFrame
        DataFrame from evaluate_threshold_sweep(), sorted by threshold
    metric : str
        Metric to analyze ('precision', 'recall', 'lift', 'f1_score', etc.)

    Returns
    -------
    MonotonicityResult
        Analysis result with monotonicity status and violations

    Examples
    --------
    >>> results = evaluate_threshold_sweep(indicator, labels, thresholds)
    >>> mono = check_monotonicity(results, "lift")
    >>> if not mono.is_monotonic:
    ...     print(f"Warning: {mono.direction_changes} direction changes")
    """
    if metric not in results_df.columns:
        raise ValueError(f"Metric '{metric}' not in results")

    # Sort by threshold to ensure proper ordering
    sorted_df = results_df.sort("threshold")
    values = sorted_df[metric].to_numpy()

    # Handle NaN values
    valid_mask = ~np.isnan(values)
    if not np.any(valid_mask):
        return MonotonicityResult(
            metric=metric,
            is_monotonic=False,
            is_monotonic_increasing=False,
            is_monotonic_decreasing=False,
            direction_changes=0,
            violations=[],
        )

    # Use only valid values for analysis
    valid_values = values[valid_mask]

    if len(valid_values) < 2:
        return MonotonicityResult(
            metric=metric,
            is_monotonic=True,
            is_monotonic_increasing=True,
            is_monotonic_decreasing=True,
            direction_changes=0,
            violations=[],
        )

    # Compute differences
    diffs = np.diff(valid_values)

    # Check increasing/decreasing
    is_increasing = bool(np.all(diffs >= -1e-10))  # Small tolerance for floating point
    is_decreasing = bool(np.all(diffs <= 1e-10))

    # Count direction changes (ignoring zero differences)
    nonzero_diffs = diffs[np.abs(diffs) > 1e-10]
    if len(nonzero_diffs) > 0:
        signs = np.sign(nonzero_diffs)
        direction_changes = int(np.sum(np.abs(np.diff(signs)) > 0))
    else:
        direction_changes = 0

    # Find violations (decreases when we'd expect increase)
    violations = []
    for i in range(1, len(valid_values)):
        if valid_values[i] < valid_values[i - 1] - 1e-10:
            violations.append((i, float(valid_values[i - 1]), float(valid_values[i])))

    # Max violation
    max_violation = None
    if violations:
        max_violation = max(v[1] - v[2] for v in violations)

    return MonotonicityResult(
        metric=metric,
        is_monotonic=is_increasing or is_decreasing,
        is_monotonic_increasing=is_increasing,
        is_monotonic_decreasing=is_decreasing,
        direction_changes=direction_changes,
        violations=violations,
        max_violation=max_violation,
    )

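
# --- Worked example (added for exposition; not in the published wheel) ------
# For lift values [1.0, 1.2, 1.1, 1.3]: diffs are [+0.2, -0.1, +0.2], so the
# sign sequence [+, -, +] has two reversals (direction_changes == 2), one
# violation against an increasing expectation (1.2 -> 1.1), and
# max_violation is about 0.1.
def _example_monotonicity() -> MonotonicityResult:
    df = pl.DataFrame({"threshold": [1, 2, 3, 4], "lift": [1.0, 1.2, 1.1, 1.3]})
    return check_monotonicity(df, "lift")
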
def analyze_all_metrics_monotonicity(
    results_df: pl.DataFrame,
    metrics: list[str] | None = None,
) -> dict[str, MonotonicityResult]:
    """Analyze monotonicity for multiple metrics.

    Parameters
    ----------
    results_df : pl.DataFrame
        DataFrame from evaluate_threshold_sweep()
    metrics : list[str], optional
        Metrics to analyze (default: precision, recall, lift, f1_score, coverage)

    Returns
    -------
    dict[str, MonotonicityResult]
        Dictionary mapping metric name to monotonicity result
    """
    if metrics is None:
        metrics = ["precision", "recall", "lift", "f1_score", "coverage"]

    results = {}
    for metric in metrics:
        if metric in results_df.columns:
            results[metric] = check_monotonicity(results_df, metric)

    return results


# ============================================================================
# Threshold Sensitivity Analysis
# ============================================================================

@dataclass
class SensitivityResult:
    """Result of threshold sensitivity analysis.

    Attributes
    ----------
    metric : str
        Metric analyzed
    mean_value : float
        Mean metric value across thresholds
    std_value : float
        Standard deviation of metric
    min_value : float
        Minimum metric value
    max_value : float
        Maximum metric value
    range_value : float
        Range (max - min)
    coefficient_of_variation : float
        CV = std / mean (relative variability)
    is_stable : bool
        Whether metric is relatively stable (CV < 0.2)
    """

    metric: str
    mean_value: float
    std_value: float
    min_value: float
    max_value: float
    range_value: float
    coefficient_of_variation: float
    is_stable: bool

    def to_dict(self) -> dict:
        """Convert to dictionary."""
        return {
            "metric": self.metric,
            "mean": self.mean_value,
            "std": self.std_value,
            "min": self.min_value,
            "max": self.max_value,
            "range": self.range_value,
            "cv": self.coefficient_of_variation,
            "is_stable": self.is_stable,
        }

def analyze_threshold_sensitivity(
    results_df: pl.DataFrame,
    metric: str,
    stability_threshold: float = 0.2,
) -> SensitivityResult:
    """Analyze how sensitive a metric is to threshold changes.

    A highly sensitive metric changes dramatically across thresholds,
    suggesting careful threshold selection is important.

    Parameters
    ----------
    results_df : pl.DataFrame
        DataFrame from evaluate_threshold_sweep()
    metric : str
        Metric to analyze
    stability_threshold : float, default 0.2
        CV threshold for considering the metric stable

    Returns
    -------
    SensitivityResult
        Sensitivity analysis result

    Examples
    --------
    >>> sensitivity = analyze_threshold_sensitivity(results, "lift")
    >>> if not sensitivity.is_stable:
    ...     cv = sensitivity.coefficient_of_variation
    ...     print(f"Warning: {sensitivity.metric} varies significantly (CV={cv:.2f})")
    """
    if metric not in results_df.columns:
        raise ValueError(f"Metric '{metric}' not in results")

    values = results_df[metric].drop_nulls().drop_nans()

    if len(values) == 0:
        return SensitivityResult(
            metric=metric,
            mean_value=float("nan"),
            std_value=float("nan"),
            min_value=float("nan"),
            max_value=float("nan"),
            range_value=float("nan"),
            coefficient_of_variation=float("nan"),
            is_stable=False,
        )

    mean_result = values.mean()
    std_result = values.std()
    min_result = values.min()
    max_result = values.max()

    mean_val = (
        float(mean_result)
        if mean_result is not None and isinstance(mean_result, int | float)
        else 0.0
    )
    std_val = (
        float(std_result) if std_result is not None and isinstance(std_result, int | float) else 0.0
    )
    min_val = (
        float(min_result) if min_result is not None and isinstance(min_result, int | float) else 0.0
    )
    max_val = (
        float(max_result) if max_result is not None and isinstance(max_result, int | float) else 0.0
    )
    range_val = max_val - min_val

    # Coefficient of variation
    cv = std_val / mean_val if mean_val != 0 else float("inf")

    return SensitivityResult(
        metric=metric,
        mean_value=mean_val,
        std_value=std_val,
        min_value=min_val,
        max_value=max_val,
        range_value=range_val,
        coefficient_of_variation=cv,
        is_stable=cv < stability_threshold,
    )

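
# --- Worked example (added for exposition; not in the published wheel) ------
# CV = std / mean. For lift values [2.0, 2.2, 1.8, 2.0] the mean is 2.0 and
# the sample standard deviation is about 0.163, so CV is about 0.08 -- below
# the default 0.2 threshold, so the metric counts as stable across thresholds.
def _example_sensitivity() -> SensitivityResult:
    df = pl.DataFrame({"threshold": [1, 2, 3, 4], "lift": [2.0, 2.2, 1.8, 2.0]})
    return analyze_threshold_sensitivity(df, "lift")
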

# ============================================================================
# Summary Functions
# ============================================================================

@dataclass
class ThresholdAnalysisSummary:
    """Complete threshold analysis summary.

    Attributes
    ----------
    n_thresholds : int
        Number of thresholds evaluated
    optimal : OptimalThresholdResult
        Optimal threshold result
    monotonicity : dict[str, MonotonicityResult]
        Monotonicity results per metric
    sensitivity : dict[str, SensitivityResult]
        Sensitivity results per metric
    significant_count : int
        Number of statistically significant thresholds
    best_per_metric : dict[str, float]
        Best threshold for each metric
    """

    n_thresholds: int
    optimal: OptimalThresholdResult
    monotonicity: dict[str, MonotonicityResult]
    sensitivity: dict[str, SensitivityResult]
    significant_count: int
    best_per_metric: dict[str, float]

    def to_dict(self) -> dict:
        """Convert to dictionary."""
        return {
            "n_thresholds": self.n_thresholds,
            "optimal": self.optimal.to_dict(),
            "monotonicity": {k: v.to_dict() for k, v in self.monotonicity.items()},
            "sensitivity": {k: v.to_dict() for k, v in self.sensitivity.items()},
            "significant_count": self.significant_count,
            "best_per_metric": self.best_per_metric,
        }

def create_threshold_analysis_summary(
    results_df: pl.DataFrame,
    optimize_metric: str = "lift",
    min_coverage: float = 0.01,
    metrics: list[str] | None = None,
) -> ThresholdAnalysisSummary:
    """Create comprehensive threshold analysis summary.

    Parameters
    ----------
    results_df : pl.DataFrame
        DataFrame from evaluate_threshold_sweep()
    optimize_metric : str, default 'lift'
        Metric to optimize for optimal threshold
    min_coverage : float, default 0.01
        Minimum coverage for optimal threshold
    metrics : list[str], optional
        Metrics to analyze

    Returns
    -------
    ThresholdAnalysisSummary
        Complete analysis summary

    Examples
    --------
    >>> results = evaluate_threshold_sweep(indicator, labels, thresholds)
    >>> summary = create_threshold_analysis_summary(results)
    >>> print(f"Optimal threshold: {summary.optimal.threshold}")
    >>> print(f"Significant thresholds: {summary.significant_count}")
    """
    if metrics is None:
        metrics = ["precision", "recall", "lift", "f1_score", "coverage"]

    # Filter to available metrics
    available_metrics = [m for m in metrics if m in results_df.columns]

    # Optimal threshold
    optimal = find_optimal_threshold(results_df, metric=optimize_metric, min_coverage=min_coverage)

    # Monotonicity analysis
    monotonicity = {}
    for metric in available_metrics:
        monotonicity[metric] = check_monotonicity(results_df, metric)

    # Sensitivity analysis
    sensitivity = {}
    for metric in available_metrics:
        sensitivity[metric] = analyze_threshold_sensitivity(results_df, metric)

    # Significant count
    sig_count = 0
    if "is_significant" in results_df.columns:
        sig_count = int(results_df.filter(pl.col("is_significant") == 1.0).height)

    # Best threshold per metric
    best_per_metric = {}
    for metric in available_metrics:
        result = find_optimal_threshold(results_df, metric=metric, min_coverage=0.0)
        if result.found and result.threshold is not None:
            best_per_metric[metric] = result.threshold

    return ThresholdAnalysisSummary(
        n_thresholds=len(results_df),
        optimal=optimal,
        monotonicity=monotonicity,
        sensitivity=sensitivity,
        significant_count=sig_count,
        best_per_metric=best_per_metric,
    )

def format_threshold_analysis(summary: ThresholdAnalysisSummary) -> str:
    """Format threshold analysis summary as a human-readable string.

    Parameters
    ----------
    summary : ThresholdAnalysisSummary
        Summary from create_threshold_analysis_summary()

    Returns
    -------
    str
        Formatted string
    """
    lines = [
        "Threshold Analysis Summary",
        "=" * 50,
        "",
        f"Thresholds Evaluated: {summary.n_thresholds}",
        f"Statistically Significant: {summary.significant_count}",
        "",
    ]

    # Optimal threshold
    lines.append("Optimal Threshold:")
    if summary.optimal.found:
        lines.append(f"  Threshold: {summary.optimal.threshold}")
        lines.append(f"  Optimized Metric: {summary.optimal.metric}")
        lines.append(f"  {summary.optimal.metric}: {summary.optimal.metric_value:.3f}")
        lines.append(f"  Precision: {summary.optimal.precision:.3f}")
        lines.append(f"  Recall: {summary.optimal.recall:.3f}")
        lines.append(f"  Coverage: {summary.optimal.coverage:.1%}")
        lines.append(f"  Significant: {summary.optimal.is_significant}")
    else:
        lines.append(f"  Not found: {summary.optimal.reason}")

    # Monotonicity
    lines.extend(["", "Monotonicity Analysis:"])
    for metric, mono in summary.monotonicity.items():
        status = "[+] Monotonic" if mono.is_monotonic else f"[-] {mono.direction_changes} reversals"
        lines.append(f"  {metric}: {status}")

    # Sensitivity
    lines.extend(["", "Sensitivity Analysis:"])
    for metric, sens in summary.sensitivity.items():
        status = (
            "[+] Stable"
            if sens.is_stable
            else f"[-] Variable (CV={sens.coefficient_of_variation:.2f})"
        )
        lines.append(f"  {metric}: {status}")

    # Best per metric
    lines.extend(["", "Best Threshold per Metric:"])
    for metric, threshold in summary.best_per_metric.items():
        lines.append(f"  {metric}: {threshold}")

    return "\n".join(lines)
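
# --- Illustrative sketch (added for exposition; not in the published wheel) --
# End-to-end: sweep -> summary -> report. A hedged example wiring this
# module's pieces together; data is synthetic and the __main__ guard keeps it
# from running on import.
def _example_full_analysis() -> str:
    indicator = pl.Series([45.0, 55.0, 65.0, 75.0, 85.0, 35.0, 72.0, 68.0, 52.0, 88.0])
    labels = pl.Series([0, 0, 1, 1, 1, 0, 1, 1, 0, 1])
    results = evaluate_threshold_sweep(indicator, labels, [50, 60, 70, 80])
    summary = create_threshold_analysis_summary(results, optimize_metric="lift")
    return format_threshold_analysis(summary)


if __name__ == "__main__":
    print(_example_full_analysis())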