ml4t-diagnostic 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ml4t/diagnostic/AGENT.md +25 -0
- ml4t/diagnostic/__init__.py +166 -0
- ml4t/diagnostic/backends/__init__.py +10 -0
- ml4t/diagnostic/backends/adapter.py +192 -0
- ml4t/diagnostic/backends/polars_backend.py +899 -0
- ml4t/diagnostic/caching/__init__.py +40 -0
- ml4t/diagnostic/caching/cache.py +331 -0
- ml4t/diagnostic/caching/decorators.py +131 -0
- ml4t/diagnostic/caching/smart_cache.py +339 -0
- ml4t/diagnostic/config/AGENT.md +24 -0
- ml4t/diagnostic/config/README.md +267 -0
- ml4t/diagnostic/config/__init__.py +219 -0
- ml4t/diagnostic/config/barrier_config.py +277 -0
- ml4t/diagnostic/config/base.py +301 -0
- ml4t/diagnostic/config/event_config.py +148 -0
- ml4t/diagnostic/config/feature_config.py +404 -0
- ml4t/diagnostic/config/multi_signal_config.py +55 -0
- ml4t/diagnostic/config/portfolio_config.py +215 -0
- ml4t/diagnostic/config/report_config.py +391 -0
- ml4t/diagnostic/config/sharpe_config.py +202 -0
- ml4t/diagnostic/config/signal_config.py +206 -0
- ml4t/diagnostic/config/trade_analysis_config.py +310 -0
- ml4t/diagnostic/config/validation.py +279 -0
- ml4t/diagnostic/core/__init__.py +29 -0
- ml4t/diagnostic/core/numba_utils.py +315 -0
- ml4t/diagnostic/core/purging.py +372 -0
- ml4t/diagnostic/core/sampling.py +471 -0
- ml4t/diagnostic/errors/__init__.py +205 -0
- ml4t/diagnostic/evaluation/AGENT.md +26 -0
- ml4t/diagnostic/evaluation/__init__.py +437 -0
- ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
- ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
- ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
- ml4t/diagnostic/evaluation/dashboard.py +715 -0
- ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
- ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
- ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
- ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
- ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
- ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
- ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
- ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
- ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
- ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
- ml4t/diagnostic/evaluation/event_analysis.py +647 -0
- ml4t/diagnostic/evaluation/excursion.py +390 -0
- ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
- ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
- ml4t/diagnostic/evaluation/framework.py +935 -0
- ml4t/diagnostic/evaluation/metric_registry.py +255 -0
- ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
- ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
- ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
- ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
- ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
- ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
- ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
- ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
- ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
- ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
- ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
- ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
- ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
- ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
- ml4t/diagnostic/evaluation/multi_signal.py +550 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
- ml4t/diagnostic/evaluation/report_generation.py +824 -0
- ml4t/diagnostic/evaluation/signal_selector.py +452 -0
- ml4t/diagnostic/evaluation/stat_registry.py +139 -0
- ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
- ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
- ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
- ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
- ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
- ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
- ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
- ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
- ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
- ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
- ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
- ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
- ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
- ml4t/diagnostic/evaluation/stats/moments.py +164 -0
- ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
- ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
- ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
- ml4t/diagnostic/evaluation/themes.py +330 -0
- ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
- ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
- ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
- ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
- ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
- ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
- ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
- ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
- ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
- ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
- ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
- ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
- ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
- ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
- ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
- ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
- ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
- ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
- ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
- ml4t/diagnostic/evaluation/validated_cv.py +535 -0
- ml4t/diagnostic/evaluation/visualization.py +1050 -0
- ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
- ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
- ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
- ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
- ml4t/diagnostic/integration/__init__.py +48 -0
- ml4t/diagnostic/integration/backtest_contract.py +671 -0
- ml4t/diagnostic/integration/data_contract.py +316 -0
- ml4t/diagnostic/integration/engineer_contract.py +226 -0
- ml4t/diagnostic/logging/__init__.py +77 -0
- ml4t/diagnostic/logging/logger.py +245 -0
- ml4t/diagnostic/logging/performance.py +234 -0
- ml4t/diagnostic/logging/progress.py +234 -0
- ml4t/diagnostic/logging/wandb.py +412 -0
- ml4t/diagnostic/metrics/__init__.py +9 -0
- ml4t/diagnostic/metrics/percentiles.py +128 -0
- ml4t/diagnostic/py.typed +1 -0
- ml4t/diagnostic/reporting/__init__.py +43 -0
- ml4t/diagnostic/reporting/base.py +130 -0
- ml4t/diagnostic/reporting/html_renderer.py +275 -0
- ml4t/diagnostic/reporting/json_renderer.py +51 -0
- ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
- ml4t/diagnostic/results/AGENT.md +24 -0
- ml4t/diagnostic/results/__init__.py +105 -0
- ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
- ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
- ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
- ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
- ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
- ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
- ml4t/diagnostic/results/barrier_results/validation.py +38 -0
- ml4t/diagnostic/results/base.py +177 -0
- ml4t/diagnostic/results/event_results.py +349 -0
- ml4t/diagnostic/results/feature_results.py +787 -0
- ml4t/diagnostic/results/multi_signal_results.py +431 -0
- ml4t/diagnostic/results/portfolio_results.py +281 -0
- ml4t/diagnostic/results/sharpe_results.py +448 -0
- ml4t/diagnostic/results/signal_results/__init__.py +74 -0
- ml4t/diagnostic/results/signal_results/ic.py +581 -0
- ml4t/diagnostic/results/signal_results/irtc.py +110 -0
- ml4t/diagnostic/results/signal_results/quantile.py +392 -0
- ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
- ml4t/diagnostic/results/signal_results/turnover.py +213 -0
- ml4t/diagnostic/results/signal_results/validation.py +147 -0
- ml4t/diagnostic/signal/AGENT.md +17 -0
- ml4t/diagnostic/signal/__init__.py +69 -0
- ml4t/diagnostic/signal/_report.py +152 -0
- ml4t/diagnostic/signal/_utils.py +261 -0
- ml4t/diagnostic/signal/core.py +275 -0
- ml4t/diagnostic/signal/quantile.py +148 -0
- ml4t/diagnostic/signal/result.py +214 -0
- ml4t/diagnostic/signal/signal_ic.py +129 -0
- ml4t/diagnostic/signal/turnover.py +182 -0
- ml4t/diagnostic/splitters/AGENT.md +19 -0
- ml4t/diagnostic/splitters/__init__.py +36 -0
- ml4t/diagnostic/splitters/base.py +501 -0
- ml4t/diagnostic/splitters/calendar.py +421 -0
- ml4t/diagnostic/splitters/calendar_config.py +91 -0
- ml4t/diagnostic/splitters/combinatorial.py +1064 -0
- ml4t/diagnostic/splitters/config.py +322 -0
- ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
- ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
- ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
- ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
- ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
- ml4t/diagnostic/splitters/group_isolation.py +329 -0
- ml4t/diagnostic/splitters/persistence.py +316 -0
- ml4t/diagnostic/splitters/utils.py +207 -0
- ml4t/diagnostic/splitters/walk_forward.py +757 -0
- ml4t/diagnostic/utils/__init__.py +42 -0
- ml4t/diagnostic/utils/config.py +542 -0
- ml4t/diagnostic/utils/dependencies.py +318 -0
- ml4t/diagnostic/utils/sessions.py +127 -0
- ml4t/diagnostic/validation/__init__.py +54 -0
- ml4t/diagnostic/validation/dataframe.py +274 -0
- ml4t/diagnostic/validation/returns.py +280 -0
- ml4t/diagnostic/validation/timeseries.py +299 -0
- ml4t/diagnostic/visualization/AGENT.md +19 -0
- ml4t/diagnostic/visualization/__init__.py +223 -0
- ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
- ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
- ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
- ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
- ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
- ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
- ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
- ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
- ml4t/diagnostic/visualization/barrier_plots.py +782 -0
- ml4t/diagnostic/visualization/core.py +1060 -0
- ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
- ml4t/diagnostic/visualization/dashboards/base.py +582 -0
- ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
- ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
- ml4t/diagnostic/visualization/dashboards.py +43 -0
- ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
- ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
- ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
- ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
- ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
- ml4t/diagnostic/visualization/feature_plots.py +888 -0
- ml4t/diagnostic/visualization/interaction_plots.py +618 -0
- ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
- ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
- ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
- ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
- ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
- ml4t/diagnostic/visualization/report_generation.py +1343 -0
- ml4t/diagnostic/visualization/signal/__init__.py +103 -0
- ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
- ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
- ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
- ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
- ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
- ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
- ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
- ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
- ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
- ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
- ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
- ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,517 @@
|
|
|
1
|
+
"""Domain classifier for multivariate distribution drift detection.
|
|
2
|
+
|
|
3
|
+
The domain classifier trains a binary model to distinguish reference (label=0)
|
|
4
|
+
from test (label=1) samples. AUC indicates drift magnitude, feature importances
|
|
5
|
+
show which features drifted.
|
|
6
|
+
|
|
7
|
+
Advantages:
|
|
8
|
+
- Detects multivariate drift and feature interactions
|
|
9
|
+
- Non-parametric (no distributional assumptions)
|
|
10
|
+
- Interpretable via feature importance
|
|
11
|
+
- Sensitive to subtle multivariate shifts
|
|
12
|
+
|
|
13
|
+
AUC Interpretation:
|
|
14
|
+
- AUC ≈ 0.5: No drift (random guess)
|
|
15
|
+
- AUC = 0.6: Weak drift
|
|
16
|
+
- AUC = 0.7-0.8: Moderate drift
|
|
17
|
+
- AUC > 0.9: Strong drift
|
|
18
|
+
|
|
19
|
+
References:
|
|
20
|
+
- Lopez-Paz, D., & Oquab, M. (2017). Revisiting Classifier Two-Sample Tests.
|
|
21
|
+
ICLR 2017.
|
|
22
|
+
- Rabanser, S., et al. (2019). Failing Loudly: An Empirical Study of Methods
|
|
23
|
+
for Detecting Dataset Shift. NeurIPS 2019.
|
|
24
|
+
"""
|
|
25
|
+
|
|
26
|
+
from __future__ import annotations
|
|
27
|
+
|
|
28
|
+
import time
|
|
29
|
+
from dataclasses import dataclass
|
|
30
|
+
from typing import Any
|
|
31
|
+
|
|
32
|
+
import numpy as np
|
|
33
|
+
import pandas as pd
|
|
34
|
+
import polars as pl
|
|
35
|
+
|
|
36
|
+
# Lazy check for optional ML dependencies (imported on first use to avoid slow startup)
|
|
37
|
+
LIGHTGBM_AVAILABLE: bool | None = None
|
|
38
|
+
XGBOOST_AVAILABLE: bool | None = None
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def _check_lightgbm_available() -> bool:
|
|
42
|
+
"""Check if lightgbm is available (lazy check)."""
|
|
43
|
+
global LIGHTGBM_AVAILABLE
|
|
44
|
+
if LIGHTGBM_AVAILABLE is None:
|
|
45
|
+
try:
|
|
46
|
+
import lightgbm # noqa: F401
|
|
47
|
+
|
|
48
|
+
LIGHTGBM_AVAILABLE = True
|
|
49
|
+
except ImportError:
|
|
50
|
+
LIGHTGBM_AVAILABLE = False
|
|
51
|
+
return LIGHTGBM_AVAILABLE
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def _check_xgboost_available() -> bool:
|
|
55
|
+
"""Check if xgboost is available (lazy check)."""
|
|
56
|
+
global XGBOOST_AVAILABLE
|
|
57
|
+
if XGBOOST_AVAILABLE is None:
|
|
58
|
+
try:
|
|
59
|
+
import xgboost # noqa: F401
|
|
60
|
+
|
|
61
|
+
XGBOOST_AVAILABLE = True
|
|
62
|
+
except ImportError:
|
|
63
|
+
XGBOOST_AVAILABLE = False
|
|
64
|
+
return XGBOOST_AVAILABLE
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
@dataclass
class DomainClassifierResult:
    """Outcome of a domain-classifier drift check.

    A binary model is trained to tell reference rows (label=0) apart from
    test rows (label=1). The AUC expresses how separable the two samples are,
    and the feature importances point at the features that drifted.

    Attributes:
        auc: AUC-ROC score (0.5 = no drift, 1.0 = complete distribution shift)
        drifted: Whether drift was detected (auc > threshold)
        feature_importances: DataFrame with feature, importance, rank columns
        threshold: AUC threshold used for drift detection
        n_reference: Number of samples in reference distribution
        n_test: Number of samples in test distribution
        n_features: Number of features used
        model_type: Type of classifier used (lightgbm, xgboost, sklearn)
        cv_auc_mean: Mean AUC from cross-validation
        cv_auc_std: Std of AUC from cross-validation
        interpretation: Human-readable interpretation
        computation_time: Time taken to compute (seconds)
        metadata: Additional metadata
    """

    auc: float
    drifted: bool
    feature_importances: pl.DataFrame
    threshold: float
    n_reference: int
    n_test: int
    n_features: int
    model_type: str
    cv_auc_mean: float
    cv_auc_std: float
    interpretation: str
    computation_time: float
    metadata: dict[str, Any]

    def summary(self) -> str:
        """Render the result as a multi-line, human-readable text report.

        Returns:
            Report text: headline metrics, sample sizes, model info, the first
            five rows of ``feature_importances``, the interpretation and the
            computation time, joined with newlines.
        """
        drift_flag = "YES" if self.drifted else "NO"

        header = [
            "Domain Classifier Drift Detection Report",
            "=" * 60,
            f"AUC-ROC: {self.auc:.4f} (CV: {self.cv_auc_mean:.4f} ± {self.cv_auc_std:.4f})",
            f"Drift Detected: {drift_flag}",
            f"Threshold: {self.threshold:.4f}",
            "",
            "Sample Sizes:",
            f" Reference: {self.n_reference:,}",
            f" Test: {self.n_test:,}",
            "",
            f"Model: {self.model_type}",
            f"Features: {self.n_features}",
            "",
            "Top 5 Most Drifted Features:",
            "-" * 60,
        ]

        # One report line per row, limited to the first five rows of the table.
        feature_rows = [
            f" {entry['rank']:2d}. {entry['feature']:30s} (importance: {entry['importance']:.4f})"
            for entry in self.feature_importances.head(5).iter_rows(named=True)
        ]

        footer = [
            "",
            f"Interpretation: {self.interpretation}",
            "",
            f"Computation Time: {self.computation_time:.3f}s",
        ]

        return "\n".join(header + feature_rows + footer)
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def compute_domain_classifier_drift(
    reference: np.ndarray | pd.DataFrame | pl.DataFrame,
    test: np.ndarray | pd.DataFrame | pl.DataFrame,
    features: list[str] | None = None,
    *,
    model_type: str = "lightgbm",
    n_estimators: int = 100,
    max_depth: int = 5,
    threshold: float = 0.6,
    cv_folds: int = 5,
    random_state: int = 42,
) -> DomainClassifierResult:
    """Detect distribution drift using domain classifier.

    Trains a binary classifier to distinguish reference (label=0) from test (label=1)
    samples. AUC-ROC indicates drift magnitude, feature importance shows which features
    drifted most.

    The domain classifier approach detects multivariate drift by testing whether
    a classifier can distinguish between two distributions. If AUC ≈ 0.5, the
    distributions are indistinguishable (no drift). If AUC → 1.0, the distributions
    are completely separated (strong drift).

    The reported ``auc`` (and therefore the ``drifted`` flag) is the mean
    cross-validated AUC, i.e. it is measured on rows the model did not see
    during fitting. Scoring the model on its own training rows would report
    high separability even for identical distributions, so that in-sample
    value is only kept for diagnostics under ``metadata["in_sample_auc"]``.

    **Advantages**:
    - Detects multivariate drift and feature interactions
    - Non-parametric (no distributional assumptions)
    - Interpretable via feature importance
    - Sensitive to subtle multivariate shifts

    **AUC Interpretation**:
    - AUC ≈ 0.5: No drift (random guess)
    - AUC = 0.6: Weak drift
    - AUC = 0.7-0.8: Moderate drift
    - AUC > 0.9: Strong drift

    Args:
        reference: Reference distribution (e.g., training data).
            Can be numpy array, pandas DataFrame, or polars DataFrame.
        test: Test distribution (e.g., production data).
            Can be numpy array, pandas DataFrame, or polars DataFrame.
        features: List of feature names to use. If None, uses all numeric columns.
            Only applicable for DataFrame inputs.
        model_type: Classifier type. Options:
            - "lightgbm": LightGBM (default, fastest)
            - "xgboost": XGBoost
            - "sklearn": sklearn RandomForestClassifier (always available)
        n_estimators: Number of trees/estimators (default: 100)
        max_depth: Maximum tree depth (default: 5)
        threshold: AUC threshold for flagging drift (default: 0.6)
        cv_folds: Number of cross-validation folds (default: 5)
        random_state: Random seed for reproducibility (default: 42)

    Returns:
        DomainClassifierResult with AUC, feature importances, drift flag, etc.

    Raises:
        ValueError: If inputs are invalid or model_type is unknown
        ImportError: If required ML library is not installed

    Example:
        The printed AUC values below are illustrative.

        >>> import numpy as np
        >>> import polars as pl
        >>> from ml4t.diagnostic.evaluation.drift import compute_domain_classifier_drift
        >>>
        >>> # No drift (identical distributions)
        >>> np.random.seed(42)
        >>> ref = pl.DataFrame({
        ...     "x1": np.random.normal(0, 1, 500),
        ...     "x2": np.random.normal(0, 1, 500),
        ... })
        >>> test = pl.DataFrame({
        ...     "x1": np.random.normal(0, 1, 500),
        ...     "x2": np.random.normal(0, 1, 500),
        ... })
        >>> result = compute_domain_classifier_drift(ref, test)
        >>> print(f"AUC: {result.auc:.4f}, Drifted: {result.drifted}")
        AUC: 0.5123, Drifted: False
        >>>
        >>> # Strong drift (mean shift)
        >>> test_shifted = pl.DataFrame({
        ...     "x1": np.random.normal(2, 1, 500),
        ...     "x2": np.random.normal(2, 1, 500),
        ... })
        >>> result = compute_domain_classifier_drift(ref, test_shifted)
        >>> print(f"AUC: {result.auc:.4f}, Drifted: {result.drifted}")
        AUC: 0.9876, Drifted: True
        >>> print(result.summary())
        >>>
        >>> # Interaction-based drift
        >>> test_corr = pl.DataFrame({
        ...     "x1": np.random.normal(0, 1, 500),
        ...     "x2": np.random.normal(0, 1, 500) + 0.8 * np.random.normal(0, 1, 500),
        ... })
        >>> result = compute_domain_classifier_drift(ref, test_corr)
        >>> # Will detect correlation change via feature interactions

    References:
        - Lopez-Paz, D., & Oquab, M. (2017). Revisiting Classifier Two-Sample Tests.
          ICLR 2017.
        - Rabanser, S., et al. (2019). Failing Loudly: An Empirical Study of Methods
          for Detecting Dataset Shift. NeurIPS 2019.
    """
    # perf_counter is monotonic, so the elapsed time cannot be skewed by
    # wall-clock adjustments.
    start_time = time.perf_counter()

    # Prepare data
    X, y, feature_names = _prepare_domain_classification_data(reference, test, features)

    # Train classifier with cross-validation
    model, cv_scores = _train_domain_classifier(
        X,
        y,
        model_type=model_type,
        n_estimators=n_estimators,
        max_depth=max_depth,
        cv_folds=cv_folds,
        random_state=random_state,
    )

    # Extract feature importances (from the model fitted on all rows)
    importances_df = _extract_feature_importances(model, feature_names)

    cv_auc_mean = float(np.mean(cv_scores))
    cv_auc_std = float(np.std(cv_scores))

    # In-sample AUC: the model is scored on the rows it was fitted on, so this
    # is optimistically biased. Kept only as a diagnostic in the metadata.
    from sklearn.metrics import roc_auc_score

    in_sample_auc = float(roc_auc_score(y, model.predict_proba(X)[:, 1]))

    # The drift statistic must come from held-out rows; use the CV estimate.
    final_auc = cv_auc_mean

    # Determine drift status
    drifted = final_auc > threshold

    # Generate interpretation
    if drifted:
        if final_auc > 0.9:
            severity = "strong"
        elif final_auc > 0.7:
            severity = "moderate"
        else:
            severity = "weak"

        interpretation = (
            f"{severity.capitalize()} distribution drift detected "
            f"(AUC={final_auc:.4f} > {threshold:.4f}). "
            f"The classifier can distinguish reference from test distributions. "
            f"Top drifted feature: {importances_df['feature'][0]}."
        )
    else:
        interpretation = (
            f"No significant drift detected (AUC={final_auc:.4f} ≤ {threshold:.4f}). "
            f"Distributions are indistinguishable by the classifier."
        )

    computation_time = time.perf_counter() - start_time

    return DomainClassifierResult(
        auc=final_auc,
        drifted=drifted,
        feature_importances=importances_df,
        threshold=threshold,
        n_reference=int(np.sum(y == 0)),
        n_test=int(np.sum(y == 1)),
        n_features=len(feature_names),
        model_type=model_type,
        cv_auc_mean=cv_auc_mean,
        cv_auc_std=cv_auc_std,
        interpretation=interpretation,
        computation_time=computation_time,
        metadata={
            "n_estimators": n_estimators,
            "max_depth": max_depth,
            "cv_folds": cv_folds,
            "random_state": random_state,
            "in_sample_auc": in_sample_auc,
        },
    )
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
def _prepare_domain_classification_data(
|
|
323
|
+
reference: np.ndarray | pd.DataFrame | pl.DataFrame,
|
|
324
|
+
test: np.ndarray | pd.DataFrame | pl.DataFrame,
|
|
325
|
+
features: list[str] | None = None,
|
|
326
|
+
) -> tuple[np.ndarray, np.ndarray, list[str]]:
|
|
327
|
+
"""Prepare labeled dataset for domain classification.
|
|
328
|
+
|
|
329
|
+
Args:
|
|
330
|
+
reference: Reference distribution
|
|
331
|
+
test: Test distribution
|
|
332
|
+
features: Feature names to use (for DataFrames)
|
|
333
|
+
|
|
334
|
+
Returns:
|
|
335
|
+
Tuple of (X, y, feature_names):
|
|
336
|
+
- X: Feature matrix (reference + test concatenated)
|
|
337
|
+
- y: Labels (0 for reference, 1 for test)
|
|
338
|
+
- feature_names: List of feature names
|
|
339
|
+
|
|
340
|
+
Raises:
|
|
341
|
+
ValueError: If inputs are invalid or incompatible
|
|
342
|
+
"""
|
|
343
|
+
# Convert to numpy arrays
|
|
344
|
+
if isinstance(reference, pl.DataFrame):
|
|
345
|
+
if features is None:
|
|
346
|
+
# Use all numeric columns
|
|
347
|
+
features = [
|
|
348
|
+
c
|
|
349
|
+
for c in reference.columns
|
|
350
|
+
if reference[c].dtype
|
|
351
|
+
in (pl.Float64, pl.Float32, pl.Int64, pl.Int32, pl.Int16, pl.Int8)
|
|
352
|
+
]
|
|
353
|
+
X_ref = reference[features].to_numpy()
|
|
354
|
+
feature_names = features
|
|
355
|
+
|
|
356
|
+
elif isinstance(reference, pd.DataFrame):
|
|
357
|
+
if features is None:
|
|
358
|
+
# Use all numeric columns
|
|
359
|
+
features = list(reference.select_dtypes(include=[np.number]).columns)
|
|
360
|
+
X_ref = reference[features].to_numpy()
|
|
361
|
+
feature_names = features
|
|
362
|
+
|
|
363
|
+
elif isinstance(reference, np.ndarray):
|
|
364
|
+
X_ref = reference
|
|
365
|
+
if features is None:
|
|
366
|
+
# Generate default feature names
|
|
367
|
+
if X_ref.ndim == 1:
|
|
368
|
+
X_ref = X_ref.reshape(-1, 1)
|
|
369
|
+
feature_names = [f"feature_{i}" for i in range(X_ref.shape[1])]
|
|
370
|
+
else:
|
|
371
|
+
feature_names = features
|
|
372
|
+
|
|
373
|
+
else:
|
|
374
|
+
raise ValueError(
|
|
375
|
+
f"Unsupported reference type: {type(reference)}. "
|
|
376
|
+
"Must be numpy array, pandas DataFrame, or polars DataFrame."
|
|
377
|
+
)
|
|
378
|
+
|
|
379
|
+
# Process test data
|
|
380
|
+
if isinstance(test, pl.DataFrame | pd.DataFrame):
|
|
381
|
+
X_test = test[feature_names].to_numpy()
|
|
382
|
+
elif isinstance(test, np.ndarray):
|
|
383
|
+
X_test = test
|
|
384
|
+
if X_test.ndim == 1:
|
|
385
|
+
X_test = X_test.reshape(-1, 1)
|
|
386
|
+
else:
|
|
387
|
+
raise ValueError(
|
|
388
|
+
f"Unsupported test type: {type(test)}. Must be numpy array, pandas DataFrame, or polars DataFrame."
|
|
389
|
+
)
|
|
390
|
+
|
|
391
|
+
# Validate shapes
|
|
392
|
+
if X_ref.shape[1] != X_test.shape[1]:
|
|
393
|
+
raise ValueError(
|
|
394
|
+
f"Feature count mismatch: reference has {X_ref.shape[1]} features, test has {X_test.shape[1]} features."
|
|
395
|
+
)
|
|
396
|
+
|
|
397
|
+
# Concatenate and create labels
|
|
398
|
+
X = np.vstack([X_ref, X_test])
|
|
399
|
+
y = np.concatenate([np.zeros(len(X_ref)), np.ones(len(X_test))])
|
|
400
|
+
|
|
401
|
+
return X, y, feature_names
|
|
402
|
+
|
|
403
|
+
|
|
404
|
+
def _train_domain_classifier(
    X: np.ndarray,
    y: np.ndarray,
    model_type: str = "lightgbm",
    n_estimators: int = 100,
    max_depth: int = 5,
    cv_folds: int = 5,
    random_state: int = 42,
) -> tuple[Any, np.ndarray]:
    """Fit the reference-vs-test classifier and score it with cross-validation.

    Args:
        X: Feature matrix
        y: Labels (0=reference, 1=test)
        model_type: Classifier type ("lightgbm", "xgboost" or "sklearn")
        n_estimators: Number of trees
        max_depth: Maximum tree depth
        cv_folds: Cross-validation folds
        random_state: Random seed passed to the estimator

    Returns:
        Tuple of (trained_model, cv_auc_scores). The scores come from
        ``cross_val_score`` with ``scoring="roc_auc"``; the returned model is
        then fitted on all of ``X``/``y``.

    Raises:
        ValueError: If model_type is unknown
        ImportError: If required library is not installed
    """
    from sklearn.model_selection import cross_val_score

    # Hyper-parameters every supported estimator accepts under the same names.
    shared_params = {
        "n_estimators": n_estimators,
        "max_depth": max_depth,
        "random_state": random_state,
    }

    if model_type == "sklearn":
        from sklearn.ensemble import RandomForestClassifier

        classifier = RandomForestClassifier(**shared_params)

    elif model_type == "xgboost":
        if not _check_xgboost_available():
            raise ImportError(
                "XGBoost required for domain classifier drift detection. Install with: pip install xgboost"
            )

        import xgboost as xgb

        classifier = xgb.XGBClassifier(**shared_params, verbosity=0)

    elif model_type == "lightgbm":
        if not _check_lightgbm_available():
            raise ImportError(
                "LightGBM required for domain classifier drift detection. "
                "Install with: pip install ml4t-diagnostic[ml] or pip install lightgbm"
            )

        import lightgbm as lgb

        classifier = lgb.LGBMClassifier(
            **shared_params,
            verbose=-1,
            force_col_wise=True,  # Suppress warning
        )

    else:
        raise ValueError(
            f"Unknown model_type: '{model_type}'. Must be 'lightgbm', 'xgboost', or 'sklearn'."
        )

    # Score with cross-validation first, then fit the same estimator on all rows.
    fold_aucs = cross_val_score(classifier, X, y, cv=cv_folds, scoring="roc_auc")
    classifier.fit(X, y)

    return classifier, fold_aucs
|
|
487
|
+
|
|
488
|
+
|
|
489
|
+
def _extract_feature_importances(model: Any, feature_names: list[str]) -> pl.DataFrame:
|
|
490
|
+
"""Extract and rank feature importances.
|
|
491
|
+
|
|
492
|
+
Args:
|
|
493
|
+
model: Trained model with feature_importances_ attribute
|
|
494
|
+
feature_names: List of feature names
|
|
495
|
+
|
|
496
|
+
Returns:
|
|
497
|
+
Polars DataFrame with columns: feature, importance, rank
|
|
498
|
+
|
|
499
|
+
Raises:
|
|
500
|
+
ValueError: If model doesn't have feature importances
|
|
501
|
+
"""
|
|
502
|
+
# Get importances (works for LightGBM, XGBoost, sklearn)
|
|
503
|
+
if hasattr(model, "feature_importances_"):
|
|
504
|
+
importances = model.feature_importances_
|
|
505
|
+
else:
|
|
506
|
+
raise ValueError(f"Model type {type(model)} does not have feature_importances_ attribute")
|
|
507
|
+
|
|
508
|
+
# Create DataFrame
|
|
509
|
+
df = pl.DataFrame({"feature": feature_names, "importance": importances})
|
|
510
|
+
|
|
511
|
+
# Sort by importance (descending)
|
|
512
|
+
df = df.sort("importance", descending=True)
|
|
513
|
+
|
|
514
|
+
# Add rank
|
|
515
|
+
df = df.with_columns(pl.arange(1, len(df) + 1).alias("rank"))
|
|
516
|
+
|
|
517
|
+
return df
|