ml4t-diagnostic 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ml4t/diagnostic/AGENT.md +25 -0
- ml4t/diagnostic/__init__.py +166 -0
- ml4t/diagnostic/backends/__init__.py +10 -0
- ml4t/diagnostic/backends/adapter.py +192 -0
- ml4t/diagnostic/backends/polars_backend.py +899 -0
- ml4t/diagnostic/caching/__init__.py +40 -0
- ml4t/diagnostic/caching/cache.py +331 -0
- ml4t/diagnostic/caching/decorators.py +131 -0
- ml4t/diagnostic/caching/smart_cache.py +339 -0
- ml4t/diagnostic/config/AGENT.md +24 -0
- ml4t/diagnostic/config/README.md +267 -0
- ml4t/diagnostic/config/__init__.py +219 -0
- ml4t/diagnostic/config/barrier_config.py +277 -0
- ml4t/diagnostic/config/base.py +301 -0
- ml4t/diagnostic/config/event_config.py +148 -0
- ml4t/diagnostic/config/feature_config.py +404 -0
- ml4t/diagnostic/config/multi_signal_config.py +55 -0
- ml4t/diagnostic/config/portfolio_config.py +215 -0
- ml4t/diagnostic/config/report_config.py +391 -0
- ml4t/diagnostic/config/sharpe_config.py +202 -0
- ml4t/diagnostic/config/signal_config.py +206 -0
- ml4t/diagnostic/config/trade_analysis_config.py +310 -0
- ml4t/diagnostic/config/validation.py +279 -0
- ml4t/diagnostic/core/__init__.py +29 -0
- ml4t/diagnostic/core/numba_utils.py +315 -0
- ml4t/diagnostic/core/purging.py +372 -0
- ml4t/diagnostic/core/sampling.py +471 -0
- ml4t/diagnostic/errors/__init__.py +205 -0
- ml4t/diagnostic/evaluation/AGENT.md +26 -0
- ml4t/diagnostic/evaluation/__init__.py +437 -0
- ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
- ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
- ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
- ml4t/diagnostic/evaluation/dashboard.py +715 -0
- ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
- ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
- ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
- ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
- ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
- ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
- ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
- ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
- ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
- ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
- ml4t/diagnostic/evaluation/event_analysis.py +647 -0
- ml4t/diagnostic/evaluation/excursion.py +390 -0
- ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
- ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
- ml4t/diagnostic/evaluation/framework.py +935 -0
- ml4t/diagnostic/evaluation/metric_registry.py +255 -0
- ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
- ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
- ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
- ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
- ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
- ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
- ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
- ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
- ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
- ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
- ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
- ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
- ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
- ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
- ml4t/diagnostic/evaluation/multi_signal.py +550 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
- ml4t/diagnostic/evaluation/report_generation.py +824 -0
- ml4t/diagnostic/evaluation/signal_selector.py +452 -0
- ml4t/diagnostic/evaluation/stat_registry.py +139 -0
- ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
- ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
- ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
- ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
- ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
- ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
- ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
- ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
- ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
- ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
- ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
- ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
- ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
- ml4t/diagnostic/evaluation/stats/moments.py +164 -0
- ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
- ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
- ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
- ml4t/diagnostic/evaluation/themes.py +330 -0
- ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
- ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
- ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
- ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
- ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
- ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
- ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
- ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
- ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
- ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
- ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
- ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
- ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
- ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
- ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
- ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
- ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
- ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
- ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
- ml4t/diagnostic/evaluation/validated_cv.py +535 -0
- ml4t/diagnostic/evaluation/visualization.py +1050 -0
- ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
- ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
- ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
- ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
- ml4t/diagnostic/integration/__init__.py +48 -0
- ml4t/diagnostic/integration/backtest_contract.py +671 -0
- ml4t/diagnostic/integration/data_contract.py +316 -0
- ml4t/diagnostic/integration/engineer_contract.py +226 -0
- ml4t/diagnostic/logging/__init__.py +77 -0
- ml4t/diagnostic/logging/logger.py +245 -0
- ml4t/diagnostic/logging/performance.py +234 -0
- ml4t/diagnostic/logging/progress.py +234 -0
- ml4t/diagnostic/logging/wandb.py +412 -0
- ml4t/diagnostic/metrics/__init__.py +9 -0
- ml4t/diagnostic/metrics/percentiles.py +128 -0
- ml4t/diagnostic/py.typed +1 -0
- ml4t/diagnostic/reporting/__init__.py +43 -0
- ml4t/diagnostic/reporting/base.py +130 -0
- ml4t/diagnostic/reporting/html_renderer.py +275 -0
- ml4t/diagnostic/reporting/json_renderer.py +51 -0
- ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
- ml4t/diagnostic/results/AGENT.md +24 -0
- ml4t/diagnostic/results/__init__.py +105 -0
- ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
- ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
- ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
- ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
- ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
- ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
- ml4t/diagnostic/results/barrier_results/validation.py +38 -0
- ml4t/diagnostic/results/base.py +177 -0
- ml4t/diagnostic/results/event_results.py +349 -0
- ml4t/diagnostic/results/feature_results.py +787 -0
- ml4t/diagnostic/results/multi_signal_results.py +431 -0
- ml4t/diagnostic/results/portfolio_results.py +281 -0
- ml4t/diagnostic/results/sharpe_results.py +448 -0
- ml4t/diagnostic/results/signal_results/__init__.py +74 -0
- ml4t/diagnostic/results/signal_results/ic.py +581 -0
- ml4t/diagnostic/results/signal_results/irtc.py +110 -0
- ml4t/diagnostic/results/signal_results/quantile.py +392 -0
- ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
- ml4t/diagnostic/results/signal_results/turnover.py +213 -0
- ml4t/diagnostic/results/signal_results/validation.py +147 -0
- ml4t/diagnostic/signal/AGENT.md +17 -0
- ml4t/diagnostic/signal/__init__.py +69 -0
- ml4t/diagnostic/signal/_report.py +152 -0
- ml4t/diagnostic/signal/_utils.py +261 -0
- ml4t/diagnostic/signal/core.py +275 -0
- ml4t/diagnostic/signal/quantile.py +148 -0
- ml4t/diagnostic/signal/result.py +214 -0
- ml4t/diagnostic/signal/signal_ic.py +129 -0
- ml4t/diagnostic/signal/turnover.py +182 -0
- ml4t/diagnostic/splitters/AGENT.md +19 -0
- ml4t/diagnostic/splitters/__init__.py +36 -0
- ml4t/diagnostic/splitters/base.py +501 -0
- ml4t/diagnostic/splitters/calendar.py +421 -0
- ml4t/diagnostic/splitters/calendar_config.py +91 -0
- ml4t/diagnostic/splitters/combinatorial.py +1064 -0
- ml4t/diagnostic/splitters/config.py +322 -0
- ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
- ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
- ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
- ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
- ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
- ml4t/diagnostic/splitters/group_isolation.py +329 -0
- ml4t/diagnostic/splitters/persistence.py +316 -0
- ml4t/diagnostic/splitters/utils.py +207 -0
- ml4t/diagnostic/splitters/walk_forward.py +757 -0
- ml4t/diagnostic/utils/__init__.py +42 -0
- ml4t/diagnostic/utils/config.py +542 -0
- ml4t/diagnostic/utils/dependencies.py +318 -0
- ml4t/diagnostic/utils/sessions.py +127 -0
- ml4t/diagnostic/validation/__init__.py +54 -0
- ml4t/diagnostic/validation/dataframe.py +274 -0
- ml4t/diagnostic/validation/returns.py +280 -0
- ml4t/diagnostic/validation/timeseries.py +299 -0
- ml4t/diagnostic/visualization/AGENT.md +19 -0
- ml4t/diagnostic/visualization/__init__.py +223 -0
- ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
- ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
- ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
- ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
- ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
- ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
- ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
- ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
- ml4t/diagnostic/visualization/barrier_plots.py +782 -0
- ml4t/diagnostic/visualization/core.py +1060 -0
- ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
- ml4t/diagnostic/visualization/dashboards/base.py +582 -0
- ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
- ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
- ml4t/diagnostic/visualization/dashboards.py +43 -0
- ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
- ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
- ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
- ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
- ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
- ml4t/diagnostic/visualization/feature_plots.py +888 -0
- ml4t/diagnostic/visualization/interaction_plots.py +618 -0
- ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
- ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
- ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
- ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
- ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
- ml4t/diagnostic/visualization/report_generation.py +1343 -0
- ml4t/diagnostic/visualization/signal/__init__.py +103 -0
- ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
- ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
- ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
- ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
- ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
- ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
- ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
- ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
- ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
- ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
- ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
- ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,470 @@
|
|
|
1
|
+
"""Normality tests for distribution analysis.
|
|
2
|
+
|
|
3
|
+
This module provides statistical tests for normality:
|
|
4
|
+
- Jarque-Bera test: moment-based test using sample skewness and kurtosis
|
|
5
|
+
- Shapiro-Wilk test: order-statistic test, preferred for small samples
|
|
6
|
+
|
|
7
|
+
Test Comparison:
|
|
8
|
+
- Jarque-Bera: Based on sample skewness and kurtosis, asymptotically valid
|
|
9
|
+
- Shapiro-Wilk: More powerful for small samples (n < 2000), recommended
|
|
10
|
+
|
|
11
|
+
References:
|
|
12
|
+
- Jarque, C. M., & Bera, A. K. (1980). Efficient tests for normality,
|
|
13
|
+
homoscedasticity and serial independence of regression residuals.
|
|
14
|
+
Economics Letters, 6(3), 255-259. DOI: 10.1016/0165-1765(80)90024-5
|
|
15
|
+
- Shapiro, S. S., & Wilk, M. B. (1965). An analysis of variance test
|
|
16
|
+
for normality (complete samples). Biometrika, 52(3-4), 591-611.
|
|
17
|
+
DOI: 10.1093/biomet/52.3-4.591
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
from __future__ import annotations
|
|
21
|
+
|
|
22
|
+
from dataclasses import dataclass
|
|
23
|
+
|
|
24
|
+
import numpy as np
|
|
25
|
+
import pandas as pd
|
|
26
|
+
from scipy import stats
|
|
27
|
+
|
|
28
|
+
from ml4t.diagnostic.errors import ComputationError, ValidationError
|
|
29
|
+
from ml4t.diagnostic.logging import get_logger
|
|
30
|
+
|
|
31
|
+
logger = get_logger(__name__)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@dataclass
class JarqueBeraResult:
    """Result container for the Jarque-Bera normality test.

    The Jarque-Bera statistic combines sample skewness S and excess
    kurtosis K as JB = (n/6) * (S^2 + K^2/4); under the null hypothesis
    of normality it is asymptotically chi-square distributed with 2
    degrees of freedom.

    Attributes:
        statistic: Value of the Jarque-Bera test statistic
        p_value: P-value under H0 (data is normally distributed)
        skewness: Sample skewness that entered the test
        excess_kurtosis: Sample excess kurtosis (Fisher convention: normal=0)
        is_normal: True when p_value >= alpha (normality not rejected)
        n_obs: Number of observations tested
        alpha: Significance level used for the decision
    """

    statistic: float
    p_value: float
    skewness: float
    excess_kurtosis: float
    is_normal: bool
    n_obs: int
    alpha: float = 0.05

    def __repr__(self) -> str:
        """Compact one-line representation."""
        return (
            f"JarqueBeraResult(statistic={self.statistic:.4f}, "
            f"p_value={self.p_value:.4f}, is_normal={self.is_normal})"
        )

    def summary(self) -> str:
        """Render a human-readable report of the test outcome.

        Returns:
            Multi-line formatted summary string
        """
        verdict = (
            "Data is consistent with normality"
            if self.is_normal
            else "Data deviates from normality"
        )
        # Decision wording follows hypothesis-testing convention.
        if self.is_normal:
            decision = f" (Fail to reject H0 at {self.alpha * 100:.0f}% level)"
        else:
            decision = f" (Reject H0 at {self.alpha * 100:.0f}% level)"

        report = [
            "Jarque-Bera Normality Test",
            "=" * 50,
            f"Test Statistic: {self.statistic:.4f}",
            f"P-value: {self.p_value:.4f}",
            f"Observations: {self.n_obs}",
            f"Significance: α={self.alpha}",
            "",
            "Moments:",
            f"  Skewness: {self.skewness:.4f}",
            f"  Excess Kurtosis: {self.excess_kurtosis:.4f}",
            "",
            f"Conclusion: {verdict}",
            decision,
            "",
            "Test Methodology:",
            "  - JB = (n/6) * (S² + K²/4)",
            "  - H0: Data is normally distributed",
            "  - Under H0: JB ~ χ²(2)",
            "  - Asymptotically valid (requires large n)",
        ]

        # Practical guidance is only relevant when normality is rejected.
        if not self.is_normal:
            report += [
                "",
                "Implications:",
                "  - Normal distribution assumption violated",
                "  - Consider robust statistical methods",
                "  - Account for non-normality in risk models",
            ]

        return "\n".join(report)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
@dataclass
class ShapiroWilkResult:
    """Result container for the Shapiro-Wilk normality test.

    The Shapiro-Wilk test compares the sample's order statistics to those
    expected under normality; it is more powerful than Jarque-Bera for
    small samples (n < 2000). The W statistic lies in [0, 1], with values
    near 1 indicating normality.

    Attributes:
        statistic: Shapiro-Wilk W statistic
        p_value: P-value under H0 (data is normally distributed)
        is_normal: True when p_value >= alpha (normality not rejected)
        n_obs: Number of observations tested
        alpha: Significance level used for the decision
    """

    statistic: float
    p_value: float
    is_normal: bool
    n_obs: int
    alpha: float = 0.05

    def __repr__(self) -> str:
        """Compact one-line representation."""
        return (
            f"ShapiroWilkResult(statistic={self.statistic:.4f}, "
            f"p_value={self.p_value:.4f}, is_normal={self.is_normal})"
        )

    def summary(self) -> str:
        """Render a human-readable report of the test outcome.

        Returns:
            Multi-line formatted summary string
        """
        verdict = (
            "Data is consistent with normality"
            if self.is_normal
            else "Data deviates from normality"
        )
        # Decision wording follows hypothesis-testing convention.
        if self.is_normal:
            decision = f" (Fail to reject H0 at {self.alpha * 100:.0f}% level)"
        else:
            decision = f" (Reject H0 at {self.alpha * 100:.0f}% level)"

        report = [
            "Shapiro-Wilk Normality Test",
            "=" * 50,
            f"Test Statistic (W): {self.statistic:.4f}",
            f"P-value: {self.p_value:.4f}",
            f"Observations: {self.n_obs}",
            f"Significance: α={self.alpha}",
            "",
            f"Conclusion: {verdict}",
            decision,
            "",
            "Test Methodology:",
            "  - Based on correlation between data and normal scores",
            "  - W statistic ranges from 0 (non-normal) to 1 (normal)",
            "  - H0: Data is normally distributed",
            "  - More powerful than Jarque-Bera for small samples",
            "  - Recommended for n < 2000",
        ]

        # Practical guidance is only relevant when normality is rejected.
        if not self.is_normal:
            report += [
                "",
                "Implications:",
                "  - Normal distribution assumption violated",
                "  - Consider non-parametric methods",
                "  - Use robust estimators for inference",
            ]

        return "\n".join(report)
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def jarque_bera_test(
    data: pd.Series | np.ndarray,
    alpha: float = 0.05,
) -> JarqueBeraResult:
    """Run the Jarque-Bera normality test on a 1D sample.

    Checks whether the sample's skewness and kurtosis are compatible with
    a normal distribution via

        JB = (n/6) * (S^2 + K^2/4)

    where n is the sample size, S the skewness, and K the excess kurtosis.
    Under H0 (normality) JB follows a chi-square distribution with 2
    degrees of freedom, so low p-values (< alpha) reject normality.

    Args:
        data: One-dimensional sample as a pandas Series or numpy array
        alpha: Significance level for the normality decision (default 0.05)

    Returns:
        JarqueBeraResult carrying the statistic, p-value, sample moments,
        and the normality decision

    Raises:
        ValidationError: On None, empty, non-1D, non-finite, or constant
            input, or fewer than 20 observations
        ComputationError: If the underlying scipy computation fails

    Example:
        >>> import numpy as np
        >>> normal = np.random.normal(0, 1, 1000)
        >>> result = jarque_bera_test(normal)
        >>> print(f"p-value: {result.p_value:.4f}, normal: {result.is_normal}")

    Notes:
        - Asymptotic test: most reliable for large samples (n > 2000)
        - Prefer the Shapiro-Wilk test for small samples
        - Delegates to scipy.stats.jarque_bera
    """
    if data is None:
        raise ValidationError("Data cannot be None", context={"function": "jarque_bera_test"})

    # Normalize input to a numpy array (no copy for ndarray input).
    if isinstance(data, pd.Series):
        values = data.to_numpy()
    elif isinstance(data, np.ndarray):
        values = data
    else:
        raise ValidationError(
            f"Data must be pandas Series or numpy array, got {type(data)}",
            context={"function": "jarque_bera_test", "data_type": type(data).__name__},
        )

    if values.ndim != 1:
        raise ValidationError(
            f"Data must be 1-dimensional, got {values.ndim}D",
            context={"function": "jarque_bera_test", "shape": values.shape},
        )

    if len(values) == 0:
        raise ValidationError(
            "Data cannot be empty", context={"function": "jarque_bera_test", "length": 0}
        )

    # Reject NaN/inf up front rather than letting scipy propagate them.
    non_finite = ~np.isfinite(values)
    if non_finite.any():
        n_invalid = np.sum(non_finite)
        raise ValidationError(
            f"Data contains {n_invalid} NaN or infinite values",
            context={"function": "jarque_bera_test", "n_invalid": n_invalid, "length": len(values)},
        )

    # The chi-square approximation is unreliable for tiny samples.
    min_length = 20
    if len(values) < min_length:
        raise ValidationError(
            f"Insufficient data for Jarque-Bera test (need at least {min_length} observations)",
            context={
                "function": "jarque_bera_test",
                "length": len(values),
                "min_length": min_length,
            },
        )

    # Skewness and kurtosis are undefined for a zero-variance series.
    if np.std(values) == 0:
        raise ValidationError(
            "Data is constant (zero variance)",
            context={
                "function": "jarque_bera_test",
                "length": len(values),
                "mean": float(np.mean(values)),
            },
        )

    logger.info("Running Jarque-Bera test", n_obs=len(values), alpha=alpha)

    try:
        # scipy returns (statistic, p_value).
        statistic, p_value = stats.jarque_bera(values)

        # Bias-corrected sample moments, reported alongside the statistic.
        sample_skewness = float(stats.skew(values, bias=False))
        sample_kurtosis = float(stats.kurtosis(values, bias=False))

        is_normal = p_value >= alpha

        logger.info(
            "Jarque-Bera test completed",
            statistic=statistic,
            p_value=p_value,
            is_normal=is_normal,
        )

        return JarqueBeraResult(
            statistic=float(statistic),
            p_value=float(p_value),
            skewness=sample_skewness,
            excess_kurtosis=sample_kurtosis,
            is_normal=is_normal,
            n_obs=len(values),
            alpha=alpha,
        )

    except Exception as e:
        logger.error("Jarque-Bera test failed", error=str(e), n_obs=len(values))
        raise ComputationError(  # noqa: B904
            f"Jarque-Bera test computation failed: {e}",
            context={"function": "jarque_bera_test", "n_obs": len(values), "alpha": alpha},
            cause=e,
        )
|
|
328
|
+
|
|
329
|
+
|
|
330
|
+
def shapiro_wilk_test(
    data: pd.Series | np.ndarray,
    alpha: float = 0.05,
) -> ShapiroWilkResult:
    """Run the Shapiro-Wilk normality test on a 1D sample.

    Compares the sample's order statistics to those expected under
    normality. More powerful than Jarque-Bera for small samples
    (n < 2000); the W statistic lies in [0, 1], with values near 1
    indicating normality. Low p-values (< alpha) reject normality.

    Args:
        data: One-dimensional sample as a pandas Series or numpy array
        alpha: Significance level for the normality decision (default 0.05)

    Returns:
        ShapiroWilkResult carrying the W statistic, p-value, and the
        normality decision

    Raises:
        ValidationError: On None, empty, non-1D, non-finite, or constant
            input, or fewer than 3 observations
        ComputationError: If the underlying scipy computation fails

    Example:
        >>> import numpy as np
        >>> normal = np.random.normal(0, 1, 500)
        >>> result = shapiro_wilk_test(normal)
        >>> print(f"W: {result.statistic:.4f}, p-value: {result.p_value:.4f}")

    Notes:
        - Recommended over Jarque-Bera when n < 2000
        - Delegates to scipy.stats.shapiro
        - Samples longer than 5000 are truncated to the first 5000
          observations (scipy limitation)
    """
    if data is None:
        raise ValidationError("Data cannot be None", context={"function": "shapiro_wilk_test"})

    # Normalize input to a numpy array (no copy for ndarray input).
    if isinstance(data, pd.Series):
        values = data.to_numpy()
    elif isinstance(data, np.ndarray):
        values = data
    else:
        raise ValidationError(
            f"Data must be pandas Series or numpy array, got {type(data)}",
            context={"function": "shapiro_wilk_test", "data_type": type(data).__name__},
        )

    if values.ndim != 1:
        raise ValidationError(
            f"Data must be 1-dimensional, got {values.ndim}D",
            context={"function": "shapiro_wilk_test", "shape": values.shape},
        )

    if len(values) == 0:
        raise ValidationError(
            "Data cannot be empty", context={"function": "shapiro_wilk_test", "length": 0}
        )

    # Reject NaN/inf up front rather than letting scipy propagate them.
    non_finite = ~np.isfinite(values)
    if non_finite.any():
        n_invalid = np.sum(non_finite)
        raise ValidationError(
            f"Data contains {n_invalid} NaN or infinite values",
            context={"function": "shapiro_wilk_test", "n_invalid": n_invalid, "length": len(values)},
        )

    # Shapiro-Wilk is defined for samples of at least 3 observations.
    min_length = 3
    if len(values) < min_length:
        raise ValidationError(
            f"Insufficient data for Shapiro-Wilk test (need at least {min_length} observations)",
            context={
                "function": "shapiro_wilk_test",
                "length": len(values),
                "min_length": min_length,
            },
        )

    # scipy.stats.shapiro is unreliable beyond 5000 observations; keep a
    # prefix of that size. NOTE(review): for a time series the first 5000
    # points are not a random subsample — confirm this is acceptable.
    max_length = 5000
    if len(values) > max_length:
        logger.warning(
            f"Data has {len(values)} observations, using first {max_length} (scipy.stats.shapiro limitation)"
        )
        values = values[:max_length]

    # The W statistic is undefined for a zero-variance series.
    if np.std(values) == 0:
        raise ValidationError(
            "Data is constant (zero variance)",
            context={
                "function": "shapiro_wilk_test",
                "length": len(values),
                "mean": float(np.mean(values)),
            },
        )

    logger.info("Running Shapiro-Wilk test", n_obs=len(values), alpha=alpha)

    try:
        # scipy returns (statistic, p_value).
        statistic, p_value = stats.shapiro(values)

        is_normal = p_value >= alpha

        logger.info(
            "Shapiro-Wilk test completed",
            statistic=statistic,
            p_value=p_value,
            is_normal=is_normal,
        )

        return ShapiroWilkResult(
            statistic=float(statistic),
            p_value=float(p_value),
            is_normal=is_normal,
            n_obs=len(values),
            alpha=alpha,
        )

    except Exception as e:
        logger.error("Shapiro-Wilk test failed", error=str(e), n_obs=len(values))
        raise ComputationError(  # noqa: B904
            f"Shapiro-Wilk test computation failed: {e}",
            context={"function": "shapiro_wilk_test", "n_obs": len(values), "alpha": alpha},
            cause=e,
        )
|
|
"""Distribution drift detection for feature monitoring.

This module provides comprehensive drift detection with three complementary methods
and a unified analysis interface:

**Individual Methods**:
- **PSI (Population Stability Index)**: Bin-based distribution comparison
- **Wasserstein Distance**: Optimal transport metric for continuous features
- **Domain Classifier**: ML-based multivariate drift detection with feature importance

**Unified Interface**:
- **analyze_drift()**: Multi-method drift analysis with consensus-based flagging

Distribution drift is critical for ML model monitoring:
- Feature distributions change over time (concept drift)
- Model performance degrades when test distribution differs from training
- Early detection allows proactive model retraining
- Multi-method consensus increases confidence in drift detection

PSI Interpretation:
- PSI < 0.1: No significant change (green)
- 0.1 ≤ PSI < 0.2: Small change, monitor (yellow)
- PSI ≥ 0.2: Significant change, investigate (red)

Wasserstein Distance Interpretation:
- W = 0: Identical distributions
- W > 0: Distribution drift detected
- Larger values indicate greater drift magnitude
- Threshold calibrated via permutation testing

Domain Classifier Interpretation:
- AUC ≈ 0.5: No drift (random guess between reference and test)
- AUC = 0.6: Weak drift
- AUC = 0.7-0.8: Moderate drift
- AUC > 0.9: Strong drift
- Feature importance identifies which features drifted

When to Use:
- **PSI**: Categorical features or when binning is acceptable
- **Wasserstein**: Continuous features, more sensitive to small shifts
- **Domain Classifier**: Multivariate drift, interaction detection
- **analyze_drift()**: Comprehensive analysis with multiple methods
- Model monitoring: Compare production data to training data
- Temporal drift: Compare recent data to historical baseline
- Segmentation drift: Compare distributions across segments

References:
    - Yurdakul, B. (2018). Statistical Properties of Population Stability Index.
      https://scholarship.richmond.edu/honors-theses/1131/
    - Webb, G. I., et al. (2016). Characterizing concept drift.
      Data Mining and Knowledge Discovery, 30(4), 964-994.
    - Villani, C. (2009). Optimal Transport: Old and New. Springer.
    - Ramdas, A., et al. (2017). On Wasserstein Two-Sample Testing and Related
      Families of Nonparametric Tests. Entropy, 19(2), 47.
    - Lopez-Paz, D., & Oquab, M. (2017). Revisiting Classifier Two-Sample Tests.
      ICLR 2017.
    - Rabanser, S., et al. (2019). Failing Loudly: An Empirical Study of Methods
      for Detecting Dataset Shift. NeurIPS 2019.

Example - Individual Methods:
    >>> import numpy as np
    >>> from ml4t.diagnostic.evaluation.drift import (
    ...     compute_psi, compute_wasserstein_distance, compute_domain_classifier_drift
    ... )
    >>>
    >>> # PSI for univariate drift
    >>> reference = np.random.normal(0, 1, 1000)
    >>> test = np.random.normal(0.5, 1, 1000)  # Mean shifted
    >>> psi_result = compute_psi(reference, test, n_bins=10)
    >>> print(f"PSI: {psi_result.psi:.4f}, Alert: {psi_result.alert_level}")
    >>>
    >>> # Wasserstein for continuous features
    >>> ws_result = compute_wasserstein_distance(reference, test)
    >>> print(f"Wasserstein: {ws_result.distance:.4f}, Drifted: {ws_result.drifted}")

Example - Unified Analysis:
    >>> import pandas as pd
    >>> from ml4t.diagnostic.evaluation.drift import analyze_drift
    >>>
    >>> # Create reference and test datasets
    >>> reference = pd.DataFrame({
    ...     'feature1': np.random.normal(0, 1, 1000),
    ...     'feature2': np.random.normal(0, 1, 1000),
    ... })
    >>> test = pd.DataFrame({
    ...     'feature1': np.random.normal(0.5, 1, 1000),  # Drifted
    ...     'feature2': np.random.normal(0, 1, 1000),    # Stable
    ... })
    >>>
    >>> # Comprehensive drift analysis with all methods
    >>> result = analyze_drift(reference, test)
    >>> print(result.summary())
    >>> print(f"Drifted features: {result.drifted_features}")
    >>>
    >>> # Get detailed results as DataFrame
    >>> df = result.to_dataframe()
    >>> print(df)
    >>>
    >>> # Use specific methods only
    >>> result = analyze_drift(reference, test, methods=['psi', 'wasserstein'])
    >>>
    >>> # Customize consensus threshold (default: 0.5)
    >>> result = analyze_drift(reference, test, consensus_threshold=0.66)
"""

# Import from submodules and re-export.
# This package module is a thin facade: each drift method lives in its own
# submodule, and this __init__ exposes the combined public API in one place.

# Unified multi-method analysis (consensus-based flagging across methods).
from ml4t.diagnostic.evaluation.drift.analysis import (
    DriftSummaryResult,
    FeatureDriftResult,
    analyze_drift,
)

# ML-based multivariate drift detection (classifier two-sample test).
from ml4t.diagnostic.evaluation.drift.domain_classifier import (
    DomainClassifierResult,
    compute_domain_classifier_drift,
)

# Bin-based univariate drift (Population Stability Index).
from ml4t.diagnostic.evaluation.drift.population_stability_index import (
    PSIResult,
    compute_psi,
)

# Optimal-transport distance for continuous features.
from ml4t.diagnostic.evaluation.drift.wasserstein import (
    WassersteinResult,
    compute_wasserstein_distance,
)

# Public API: every imported name above is re-exported here, grouped by method.
__all__ = [
    # PSI
    "compute_psi",
    "PSIResult",
    # Wasserstein
    "compute_wasserstein_distance",
    "WassersteinResult",
    # Domain Classifier
    "compute_domain_classifier_drift",
    "DomainClassifierResult",
    # Unified analysis
    "analyze_drift",
    "FeatureDriftResult",
    "DriftSummaryResult",
]