ml4t-diagnostic 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ml4t/diagnostic/AGENT.md +25 -0
- ml4t/diagnostic/__init__.py +166 -0
- ml4t/diagnostic/backends/__init__.py +10 -0
- ml4t/diagnostic/backends/adapter.py +192 -0
- ml4t/diagnostic/backends/polars_backend.py +899 -0
- ml4t/diagnostic/caching/__init__.py +40 -0
- ml4t/diagnostic/caching/cache.py +331 -0
- ml4t/diagnostic/caching/decorators.py +131 -0
- ml4t/diagnostic/caching/smart_cache.py +339 -0
- ml4t/diagnostic/config/AGENT.md +24 -0
- ml4t/diagnostic/config/README.md +267 -0
- ml4t/diagnostic/config/__init__.py +219 -0
- ml4t/diagnostic/config/barrier_config.py +277 -0
- ml4t/diagnostic/config/base.py +301 -0
- ml4t/diagnostic/config/event_config.py +148 -0
- ml4t/diagnostic/config/feature_config.py +404 -0
- ml4t/diagnostic/config/multi_signal_config.py +55 -0
- ml4t/diagnostic/config/portfolio_config.py +215 -0
- ml4t/diagnostic/config/report_config.py +391 -0
- ml4t/diagnostic/config/sharpe_config.py +202 -0
- ml4t/diagnostic/config/signal_config.py +206 -0
- ml4t/diagnostic/config/trade_analysis_config.py +310 -0
- ml4t/diagnostic/config/validation.py +279 -0
- ml4t/diagnostic/core/__init__.py +29 -0
- ml4t/diagnostic/core/numba_utils.py +315 -0
- ml4t/diagnostic/core/purging.py +372 -0
- ml4t/diagnostic/core/sampling.py +471 -0
- ml4t/diagnostic/errors/__init__.py +205 -0
- ml4t/diagnostic/evaluation/AGENT.md +26 -0
- ml4t/diagnostic/evaluation/__init__.py +437 -0
- ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
- ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
- ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
- ml4t/diagnostic/evaluation/dashboard.py +715 -0
- ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
- ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
- ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
- ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
- ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
- ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
- ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
- ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
- ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
- ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
- ml4t/diagnostic/evaluation/event_analysis.py +647 -0
- ml4t/diagnostic/evaluation/excursion.py +390 -0
- ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
- ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
- ml4t/diagnostic/evaluation/framework.py +935 -0
- ml4t/diagnostic/evaluation/metric_registry.py +255 -0
- ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
- ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
- ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
- ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
- ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
- ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
- ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
- ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
- ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
- ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
- ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
- ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
- ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
- ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
- ml4t/diagnostic/evaluation/multi_signal.py +550 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
- ml4t/diagnostic/evaluation/report_generation.py +824 -0
- ml4t/diagnostic/evaluation/signal_selector.py +452 -0
- ml4t/diagnostic/evaluation/stat_registry.py +139 -0
- ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
- ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
- ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
- ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
- ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
- ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
- ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
- ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
- ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
- ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
- ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
- ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
- ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
- ml4t/diagnostic/evaluation/stats/moments.py +164 -0
- ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
- ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
- ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
- ml4t/diagnostic/evaluation/themes.py +330 -0
- ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
- ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
- ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
- ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
- ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
- ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
- ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
- ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
- ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
- ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
- ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
- ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
- ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
- ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
- ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
- ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
- ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
- ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
- ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
- ml4t/diagnostic/evaluation/validated_cv.py +535 -0
- ml4t/diagnostic/evaluation/visualization.py +1050 -0
- ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
- ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
- ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
- ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
- ml4t/diagnostic/integration/__init__.py +48 -0
- ml4t/diagnostic/integration/backtest_contract.py +671 -0
- ml4t/diagnostic/integration/data_contract.py +316 -0
- ml4t/diagnostic/integration/engineer_contract.py +226 -0
- ml4t/diagnostic/logging/__init__.py +77 -0
- ml4t/diagnostic/logging/logger.py +245 -0
- ml4t/diagnostic/logging/performance.py +234 -0
- ml4t/diagnostic/logging/progress.py +234 -0
- ml4t/diagnostic/logging/wandb.py +412 -0
- ml4t/diagnostic/metrics/__init__.py +9 -0
- ml4t/diagnostic/metrics/percentiles.py +128 -0
- ml4t/diagnostic/py.typed +1 -0
- ml4t/diagnostic/reporting/__init__.py +43 -0
- ml4t/diagnostic/reporting/base.py +130 -0
- ml4t/diagnostic/reporting/html_renderer.py +275 -0
- ml4t/diagnostic/reporting/json_renderer.py +51 -0
- ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
- ml4t/diagnostic/results/AGENT.md +24 -0
- ml4t/diagnostic/results/__init__.py +105 -0
- ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
- ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
- ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
- ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
- ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
- ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
- ml4t/diagnostic/results/barrier_results/validation.py +38 -0
- ml4t/diagnostic/results/base.py +177 -0
- ml4t/diagnostic/results/event_results.py +349 -0
- ml4t/diagnostic/results/feature_results.py +787 -0
- ml4t/diagnostic/results/multi_signal_results.py +431 -0
- ml4t/diagnostic/results/portfolio_results.py +281 -0
- ml4t/diagnostic/results/sharpe_results.py +448 -0
- ml4t/diagnostic/results/signal_results/__init__.py +74 -0
- ml4t/diagnostic/results/signal_results/ic.py +581 -0
- ml4t/diagnostic/results/signal_results/irtc.py +110 -0
- ml4t/diagnostic/results/signal_results/quantile.py +392 -0
- ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
- ml4t/diagnostic/results/signal_results/turnover.py +213 -0
- ml4t/diagnostic/results/signal_results/validation.py +147 -0
- ml4t/diagnostic/signal/AGENT.md +17 -0
- ml4t/diagnostic/signal/__init__.py +69 -0
- ml4t/diagnostic/signal/_report.py +152 -0
- ml4t/diagnostic/signal/_utils.py +261 -0
- ml4t/diagnostic/signal/core.py +275 -0
- ml4t/diagnostic/signal/quantile.py +148 -0
- ml4t/diagnostic/signal/result.py +214 -0
- ml4t/diagnostic/signal/signal_ic.py +129 -0
- ml4t/diagnostic/signal/turnover.py +182 -0
- ml4t/diagnostic/splitters/AGENT.md +19 -0
- ml4t/diagnostic/splitters/__init__.py +36 -0
- ml4t/diagnostic/splitters/base.py +501 -0
- ml4t/diagnostic/splitters/calendar.py +421 -0
- ml4t/diagnostic/splitters/calendar_config.py +91 -0
- ml4t/diagnostic/splitters/combinatorial.py +1064 -0
- ml4t/diagnostic/splitters/config.py +322 -0
- ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
- ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
- ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
- ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
- ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
- ml4t/diagnostic/splitters/group_isolation.py +329 -0
- ml4t/diagnostic/splitters/persistence.py +316 -0
- ml4t/diagnostic/splitters/utils.py +207 -0
- ml4t/diagnostic/splitters/walk_forward.py +757 -0
- ml4t/diagnostic/utils/__init__.py +42 -0
- ml4t/diagnostic/utils/config.py +542 -0
- ml4t/diagnostic/utils/dependencies.py +318 -0
- ml4t/diagnostic/utils/sessions.py +127 -0
- ml4t/diagnostic/validation/__init__.py +54 -0
- ml4t/diagnostic/validation/dataframe.py +274 -0
- ml4t/diagnostic/validation/returns.py +280 -0
- ml4t/diagnostic/validation/timeseries.py +299 -0
- ml4t/diagnostic/visualization/AGENT.md +19 -0
- ml4t/diagnostic/visualization/__init__.py +223 -0
- ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
- ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
- ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
- ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
- ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
- ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
- ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
- ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
- ml4t/diagnostic/visualization/barrier_plots.py +782 -0
- ml4t/diagnostic/visualization/core.py +1060 -0
- ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
- ml4t/diagnostic/visualization/dashboards/base.py +582 -0
- ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
- ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
- ml4t/diagnostic/visualization/dashboards.py +43 -0
- ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
- ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
- ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
- ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
- ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
- ml4t/diagnostic/visualization/feature_plots.py +888 -0
- ml4t/diagnostic/visualization/interaction_plots.py +618 -0
- ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
- ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
- ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
- ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
- ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
- ml4t/diagnostic/visualization/report_generation.py +1343 -0
- ml4t/diagnostic/visualization/signal/__init__.py +103 -0
- ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
- ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
- ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
- ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
- ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
- ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
- ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
- ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
- ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
- ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
- ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
- ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,432 @@
|
|
|
1
|
+
"""Unified drift analysis using multiple detection methods.
|
|
2
|
+
|
|
3
|
+
This module provides the main analyze_drift() function that combines
|
|
4
|
+
PSI, Wasserstein, and Domain Classifier methods for comprehensive
|
|
5
|
+
drift detection.
|
|
6
|
+
|
|
7
|
+
Consensus Logic:
|
|
8
|
+
A feature is flagged as drifted if the fraction of methods detecting drift
|
|
9
|
+
meets or exceeds the consensus_threshold. For example, with threshold=0.5:
|
|
10
|
+
- If 2/3 methods detect drift → flagged as drifted
|
|
11
|
+
- If 1/3 methods detect drift → not flagged as drifted
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
import time
import warnings
from dataclasses import dataclass, field
from typing import Any
|
|
19
|
+
|
|
20
|
+
import numpy as np
|
|
21
|
+
import pandas as pd
|
|
22
|
+
import polars as pl
|
|
23
|
+
|
|
24
|
+
from ml4t.diagnostic.evaluation.drift.domain_classifier import (
|
|
25
|
+
DomainClassifierResult,
|
|
26
|
+
compute_domain_classifier_drift,
|
|
27
|
+
)
|
|
28
|
+
from ml4t.diagnostic.evaluation.drift.population_stability_index import (
|
|
29
|
+
PSIResult,
|
|
30
|
+
compute_psi,
|
|
31
|
+
)
|
|
32
|
+
from ml4t.diagnostic.evaluation.drift.wasserstein import (
|
|
33
|
+
WassersteinResult,
|
|
34
|
+
compute_wasserstein_distance,
|
|
35
|
+
)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@dataclass
class FeatureDriftResult:
    """Per-feature outcome of the univariate drift checks.

    Bundles the raw result object of every univariate method that was run on
    a single feature together with the consensus verdict derived from them.

    Attributes:
        feature: Name of the analyzed feature.
        psi_result: PSI result, or None when PSI was not run on this feature.
        wasserstein_result: Wasserstein result, or None when that method was
            not run on this feature.
        drifted: Consensus verdict across the methods that ran.
        n_methods_run: How many univariate methods produced a result.
        n_methods_detected: How many of those methods reported drift.
        drift_probability: Fraction of the methods run that reported drift.
        interpretation: Short human-readable explanation of the verdict.
    """

    feature: str
    psi_result: PSIResult | None = None
    wasserstein_result: WassersteinResult | None = None
    drifted: bool = False
    n_methods_run: int = 0
    n_methods_detected: int = 0
    drift_probability: float = 0.0
    interpretation: str = ""

    def summary(self) -> str:
        """Return a multi-line, human-readable report for this feature.

        The report always contains the feature name, the consensus verdict and
        the drift probability; one extra line is added per method result that
        is present.
        """
        report = [
            f"Feature: {self.feature}",
            f" Drifted: {self.drifted} ({self.n_methods_detected}/{self.n_methods_run} methods)",
            f" Drift Probability: {self.drift_probability:.2%}",
        ]

        psi = self.psi_result
        if psi is not None:
            report.append(f" PSI: {psi.psi:.4f} ({psi.alert_level})")

        ws = self.wasserstein_result
        if ws is not None:
            verdict = "drifted" if ws.drifted else "no drift"
            report.append(f" Wasserstein: {ws.distance:.4f} ({verdict})")

        return "\n".join(report)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
@dataclass
class DriftSummaryResult:
    """Summary of multi-method drift analysis across features.

    This result aggregates drift detection across multiple methods (PSI,
    Wasserstein, Domain Classifier) to provide a comprehensive drift assessment.

    Attributes:
        feature_results: Per-feature drift results (PSI + Wasserstein)
        domain_classifier_result: Multivariate drift result (if domain classifier was run)
        n_features: Total number of features analyzed
        n_features_drifted: Number of features flagged as drifted
        drifted_features: List of feature names that drifted
        overall_drifted: Overall drift flag (True if any feature drifted or domain classifier detected drift)
        consensus_threshold: Minimum fraction of methods that must agree to flag drift
        methods_used: List of drift detection methods used
        univariate_methods: Methods run on individual features
        multivariate_methods: Methods run on all features jointly
        interpretation: Human-readable interpretation
        computation_time: Total time taken for all methods (seconds)
    """

    feature_results: list[FeatureDriftResult]
    domain_classifier_result: DomainClassifierResult | None = None
    n_features: int = 0
    n_features_drifted: int = 0
    drifted_features: list[str] = field(default_factory=list)
    overall_drifted: bool = False
    consensus_threshold: float = 0.5
    methods_used: list[str] = field(default_factory=list)
    univariate_methods: list[str] = field(default_factory=list)
    multivariate_methods: list[str] = field(default_factory=list)
    interpretation: str = ""
    computation_time: float = 0.0

    def summary(self) -> str:
        """Generate a comprehensive, multi-line summary of the drift analysis."""
        lines = ["=" * 60]
        lines.append("Drift Analysis Summary")
        lines.append("=" * 60)
        lines.append(f"Methods Used: {', '.join(self.methods_used)}")
        lines.append(f"Consensus Threshold: {self.consensus_threshold:.0%}")
        lines.append(f"Total Features: {self.n_features}")
        # max(1, ...) guards the percentage against division by zero.
        lines.append(
            f"Drifted Features: {self.n_features_drifted} ({self.n_features_drifted / max(1, self.n_features):.0%})"
        )
        lines.append(f"Overall Drift Detected: {self.overall_drifted}")
        lines.append("")

        if self.drifted_features:
            lines.append("Drifted Features:")
            for feature in self.drifted_features:
                lines.append(f" - {feature}")
            lines.append("")

        if self.domain_classifier_result is not None:
            lines.append("Multivariate Drift (Domain Classifier):")
            lines.append(f" AUC: {self.domain_classifier_result.auc:.4f}")
            lines.append(f" Drifted: {self.domain_classifier_result.drifted}")
            lines.append("")

        lines.append(f"Computation Time: {self.computation_time:.2f}s")
        lines.append("=" * 60)

        return "\n".join(lines)

    def to_dataframe(self) -> pl.DataFrame:
        """Convert feature-level results to a DataFrame.

        The columns ``feature``, ``drifted``, ``drift_probability``,
        ``n_methods_detected`` and ``n_methods_run`` are always present.
        ``psi``/``psi_alert``, ``wasserstein_distance``/``wasserstein_drifted``
        and ``wasserstein_pvalue`` appear only when at least one feature carries
        the corresponding result. Features lacking them get nulls.

        Returns:
            Polars DataFrame with per-feature drift analysis results
        """
        data = []
        for result in self.feature_results:
            row = {
                "feature": result.feature,
                "drifted": result.drifted,
                "drift_probability": result.drift_probability,
                "n_methods_detected": result.n_methods_detected,
                "n_methods_run": result.n_methods_run,
            }

            if result.psi_result is not None:
                row["psi"] = result.psi_result.psi
                row["psi_alert"] = result.psi_result.alert_level

            if result.wasserstein_result is not None:
                row["wasserstein_distance"] = result.wasserstein_result.distance
                row["wasserstein_drifted"] = result.wasserstein_result.drifted
                if result.wasserstein_result.p_value is not None:
                    row["wasserstein_pvalue"] = result.wasserstein_result.p_value

            data.append(row)

        # Rows have heterogeneous keys. Polars infers the schema from only the
        # first 100 rows by default, so an optional column that first shows up
        # later would be silently dropped. Scan every row instead.
        return pl.DataFrame(data, infer_schema_length=None)
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
def _to_pandas(frame: pd.DataFrame | pl.DataFrame) -> pd.DataFrame:
    """Return ``frame`` as a pandas DataFrame, converting polars input."""
    if isinstance(frame, pl.DataFrame):
        return frame.to_pandas()
    return frame


def _analyze_feature(
    feature: str,
    ref_values: np.ndarray,
    test_values: np.ndarray,
    methods: list[str],
    consensus_threshold: float,
    psi_config: dict[str, Any],
    wasserstein_config: dict[str, Any],
) -> FeatureDriftResult:
    """Run the requested univariate methods on one feature and apply consensus.

    A method that raises is reported via ``warnings.warn`` and excluded from
    the consensus (it counts neither as run nor as detected).
    """
    psi_result = None
    wasserstein_result = None
    n_methods_run = 0
    n_methods_detected = 0

    if "psi" in methods:
        try:
            psi_result = compute_psi(ref_values, test_values, **psi_config)
        except Exception as e:  # best-effort: one failing method must not abort the analysis
            warnings.warn(
                f"PSI failed for feature {feature}: {e}", RuntimeWarning, stacklevel=3
            )
        else:
            n_methods_run += 1
            # Both the "yellow" and "red" PSI alert levels count as drift.
            if psi_result.alert_level in ("yellow", "red"):
                n_methods_detected += 1

    if "wasserstein" in methods:
        try:
            wasserstein_result = compute_wasserstein_distance(
                ref_values, test_values, **wasserstein_config
            )
        except Exception as e:  # best-effort: one failing method must not abort the analysis
            warnings.warn(
                f"Wasserstein failed for feature {feature}: {e}",
                RuntimeWarning,
                stacklevel=3,
            )
        else:
            n_methods_run += 1
            if wasserstein_result.drifted:
                n_methods_detected += 1

    drift_probability = n_methods_detected / max(1, n_methods_run)
    # A feature on which no method produced a result cannot be flagged, even
    # when consensus_threshold == 0 (0 >= 0 would otherwise be True).
    drifted = n_methods_run > 0 and drift_probability >= consensus_threshold

    if drifted:
        interpretation = f"{n_methods_detected}/{n_methods_run} methods detected drift (probability: {drift_probability:.0%})"
    else:
        interpretation = (
            f"No consensus drift ({n_methods_detected}/{n_methods_run} methods, "
            f"threshold: {consensus_threshold:.0%})"
        )

    return FeatureDriftResult(
        feature=feature,
        psi_result=psi_result,
        wasserstein_result=wasserstein_result,
        drifted=drifted,
        n_methods_run=n_methods_run,
        n_methods_detected=n_methods_detected,
        drift_probability=drift_probability,
        interpretation=interpretation,
    )


def analyze_drift(
    reference: pd.DataFrame | pl.DataFrame,
    test: pd.DataFrame | pl.DataFrame,
    features: list[str] | None = None,
    *,
    methods: list[str] | None = None,
    consensus_threshold: float = 0.5,
    # PSI parameters
    psi_config: dict[str, Any] | None = None,
    # Wasserstein parameters
    wasserstein_config: dict[str, Any] | None = None,
    # Domain classifier parameters
    domain_classifier_config: dict[str, Any] | None = None,
) -> DriftSummaryResult:
    """Comprehensive drift analysis using multiple detection methods.

    This function provides a unified interface for drift detection across multiple
    methods (PSI, Wasserstein, Domain Classifier). It runs univariate methods on
    each feature and optionally multivariate methods on all features jointly.

    **Univariate Methods** (run per feature):
    - PSI: Population Stability Index (binning-based)
    - Wasserstein: Earth Mover's Distance (metric-based)

    **Multivariate Methods** (run on all features):
    - Domain Classifier: ML-based drift detection with feature importance

    **Consensus Logic**:
    A feature is flagged as drifted if the fraction of univariate methods that
    ran on it and detected drift meets or exceeds the consensus_threshold.
    For example, with PSI and Wasserstein both enabled and threshold=0.5:
    - If 1/2 methods detect drift → flagged as drifted
    - If 0/2 methods detect drift → not flagged as drifted
    A feature on which no univariate method produced a result is never flagged.
    The domain classifier does not vote per feature; it only feeds into
    ``overall_drifted``.

    Methods that raise are reported with ``warnings.warn`` (``RuntimeWarning``)
    and skipped, so one failing method does not abort the analysis.

    Args:
        reference: Reference distribution (e.g., training data)
            Can be pandas or polars DataFrame
        test: Test distribution (e.g., production data)
            Can be pandas or polars DataFrame
        features: List of feature names to analyze. If None, uses all numeric
            columns of ``reference`` (they must also exist in ``test``).
        methods: List of methods to use. Options: ["psi", "wasserstein", "domain_classifier"]
            Default: ["psi", "wasserstein", "domain_classifier"]. Duplicates are ignored.
        consensus_threshold: Minimum fraction (in [0, 1]) of methods that must
            detect drift to flag a feature as drifted (default: 0.5)
        psi_config: Configuration dict for PSI. Keys:
            - n_bins: int (default: 10)
            - is_categorical: bool (default: False)
            - psi_threshold_yellow: float (default: 0.1)
            - psi_threshold_red: float (default: 0.2)
        wasserstein_config: Configuration dict for Wasserstein. Keys:
            - p: int (default: 1)
            - threshold_calibration: bool (default: True)
            - n_permutations: int (default: 1000)
            - alpha: float (default: 0.05)
        domain_classifier_config: Configuration dict for domain classifier. Keys:
            - model_type: str (default: "lightgbm")
            - n_estimators: int (default: 100)
            - max_depth: int (default: 5)
            - threshold: float (default: 0.6)
            - cv_folds: int (default: 5)

    Returns:
        DriftSummaryResult with per-feature results, multivariate results,
        and overall drift assessment

    Raises:
        ValueError: If reference/test is None, consensus_threshold is outside
            [0, 1], requested features are missing from either frame, there
            are no features to analyze, or the methods list is empty or
            contains unknown methods.

    Example:
        >>> import numpy as np
        >>> import pandas as pd
        >>> from ml4t.diagnostic.evaluation.drift import analyze_drift
        >>>
        >>> # Create reference and test data
        >>> reference = pd.DataFrame({
        ...     'feature1': np.random.normal(0, 1, 1000),
        ...     'feature2': np.random.normal(0, 1, 1000)
        ... })
        >>> test = pd.DataFrame({
        ...     'feature1': np.random.normal(0.5, 1, 1000),  # Mean shifted
        ...     'feature2': np.random.normal(0, 1, 1000)  # No shift
        ... })
        >>>
        >>> # Run drift analysis
        >>> result = analyze_drift(reference, test)
        >>> print(result.summary())
        >>>
        >>> # Check which features drifted
        >>> print(f"Drifted features: {result.drifted_features}")
        >>>
        >>> # Get per-feature details
        >>> df = result.to_dataframe()
        >>> print(df)
    """
    # perf_counter is monotonic and intended for measuring durations.
    start_time = time.perf_counter()

    # Input validation
    if reference is None or test is None:
        raise ValueError("reference and test must not be None")
    # Written so that NaN is rejected as well.
    if not 0.0 <= consensus_threshold <= 1.0:
        raise ValueError(
            f"consensus_threshold must be in [0, 1], got {consensus_threshold}"
        )

    # Work on pandas internally for easier per-column processing.
    reference_pd = _to_pandas(reference)
    test_pd = _to_pandas(test)

    # Determine features to analyze (copy so the caller's list is not aliased).
    if features is None:
        features = reference_pd.select_dtypes(include=[np.number]).columns.tolist()
    else:
        features = list(features)

    # Validate in both cases: inferred reference columns may be absent from test.
    missing_in_ref = set(features) - set(reference_pd.columns)
    missing_in_test = set(features) - set(test_pd.columns)
    if missing_in_ref or missing_in_test:
        raise ValueError(
            f"Features not found - reference: {missing_in_ref}, test: {missing_in_test}"
        )

    if not features:
        raise ValueError("No features to analyze")

    # Determine methods to use (deduplicated, order preserved, caller's list copied).
    valid_methods = ["psi", "wasserstein", "domain_classifier"]
    if methods is None:
        methods = list(valid_methods)
    else:
        methods = list(dict.fromkeys(methods))

    if not methods:
        raise ValueError(f"methods must not be empty. Valid: {valid_methods}")

    invalid_methods = set(methods) - set(valid_methods)
    if invalid_methods:
        raise ValueError(f"Invalid methods: {invalid_methods}. Valid: {valid_methods}")

    # Separate univariate and multivariate methods
    univariate_methods = [m for m in methods if m in ("psi", "wasserstein")]
    multivariate_methods = [m for m in methods if m == "domain_classifier"]

    # Default configs
    psi_config = {} if psi_config is None else psi_config
    wasserstein_config = {} if wasserstein_config is None else wasserstein_config
    domain_classifier_config = (
        {} if domain_classifier_config is None else domain_classifier_config
    )

    # Run univariate methods on each feature
    feature_results = [
        _analyze_feature(
            feature,
            # Explicitly convert to ndarray to handle ExtensionArray types
            np.asarray(reference_pd[feature].values, dtype=np.float64),
            np.asarray(test_pd[feature].values, dtype=np.float64),
            methods,
            consensus_threshold,
            psi_config,
            wasserstein_config,
        )
        for feature in features
    ]

    # Run multivariate domain classifier if requested
    domain_classifier_result = None
    if "domain_classifier" in methods:
        try:
            domain_classifier_result = compute_domain_classifier_drift(
                reference[features], test[features], **domain_classifier_config
            )
        except Exception as e:  # best-effort: report and continue without it
            warnings.warn(
                f"Domain classifier failed: {e}", RuntimeWarning, stacklevel=2
            )

    # Aggregate results
    n_features = len(features)
    drifted_features = [r.feature for r in feature_results if r.drifted]
    n_features_drifted = len(drifted_features)

    multivariate_drifted = (
        domain_classifier_result is not None and domain_classifier_result.drifted
    )
    overall_drifted = n_features_drifted > 0 or multivariate_drifted

    # Interpretation
    if overall_drifted:
        interpretation = (
            f"Drift detected in {n_features_drifted}/{n_features} features "
            f"({n_features_drifted / max(1, n_features):.0%})"
        )
        # Without this, drift found only by the domain classifier would read
        # as "Drift detected in 0/N features".
        if multivariate_drifted:
            interpretation += (
                "; multivariate drift detected by domain classifier "
                f"(AUC: {domain_classifier_result.auc:.4f})"
            )
    else:
        interpretation = f"No drift detected across {n_features} features"

    computation_time = time.perf_counter() - start_time

    return DriftSummaryResult(
        feature_results=feature_results,
        domain_classifier_result=domain_classifier_result,
        n_features=n_features,
        n_features_drifted=n_features_drifted,
        drifted_features=drifted_features,
        overall_drifted=overall_drifted,
        consensus_threshold=consensus_threshold,
        methods_used=methods,
        univariate_methods=univariate_methods,
        multivariate_methods=multivariate_methods,
        interpretation=interpretation,
        computation_time=computation_time,
    )
|