ml4t-diagnostic 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ml4t/diagnostic/AGENT.md +25 -0
- ml4t/diagnostic/__init__.py +166 -0
- ml4t/diagnostic/backends/__init__.py +10 -0
- ml4t/diagnostic/backends/adapter.py +192 -0
- ml4t/diagnostic/backends/polars_backend.py +899 -0
- ml4t/diagnostic/caching/__init__.py +40 -0
- ml4t/diagnostic/caching/cache.py +331 -0
- ml4t/diagnostic/caching/decorators.py +131 -0
- ml4t/diagnostic/caching/smart_cache.py +339 -0
- ml4t/diagnostic/config/AGENT.md +24 -0
- ml4t/diagnostic/config/README.md +267 -0
- ml4t/diagnostic/config/__init__.py +219 -0
- ml4t/diagnostic/config/barrier_config.py +277 -0
- ml4t/diagnostic/config/base.py +301 -0
- ml4t/diagnostic/config/event_config.py +148 -0
- ml4t/diagnostic/config/feature_config.py +404 -0
- ml4t/diagnostic/config/multi_signal_config.py +55 -0
- ml4t/diagnostic/config/portfolio_config.py +215 -0
- ml4t/diagnostic/config/report_config.py +391 -0
- ml4t/diagnostic/config/sharpe_config.py +202 -0
- ml4t/diagnostic/config/signal_config.py +206 -0
- ml4t/diagnostic/config/trade_analysis_config.py +310 -0
- ml4t/diagnostic/config/validation.py +279 -0
- ml4t/diagnostic/core/__init__.py +29 -0
- ml4t/diagnostic/core/numba_utils.py +315 -0
- ml4t/diagnostic/core/purging.py +372 -0
- ml4t/diagnostic/core/sampling.py +471 -0
- ml4t/diagnostic/errors/__init__.py +205 -0
- ml4t/diagnostic/evaluation/AGENT.md +26 -0
- ml4t/diagnostic/evaluation/__init__.py +437 -0
- ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
- ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
- ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
- ml4t/diagnostic/evaluation/dashboard.py +715 -0
- ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
- ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
- ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
- ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
- ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
- ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
- ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
- ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
- ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
- ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
- ml4t/diagnostic/evaluation/event_analysis.py +647 -0
- ml4t/diagnostic/evaluation/excursion.py +390 -0
- ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
- ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
- ml4t/diagnostic/evaluation/framework.py +935 -0
- ml4t/diagnostic/evaluation/metric_registry.py +255 -0
- ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
- ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
- ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
- ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
- ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
- ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
- ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
- ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
- ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
- ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
- ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
- ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
- ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
- ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
- ml4t/diagnostic/evaluation/multi_signal.py +550 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
- ml4t/diagnostic/evaluation/report_generation.py +824 -0
- ml4t/diagnostic/evaluation/signal_selector.py +452 -0
- ml4t/diagnostic/evaluation/stat_registry.py +139 -0
- ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
- ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
- ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
- ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
- ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
- ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
- ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
- ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
- ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
- ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
- ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
- ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
- ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
- ml4t/diagnostic/evaluation/stats/moments.py +164 -0
- ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
- ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
- ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
- ml4t/diagnostic/evaluation/themes.py +330 -0
- ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
- ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
- ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
- ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
- ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
- ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
- ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
- ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
- ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
- ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
- ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
- ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
- ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
- ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
- ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
- ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
- ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
- ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
- ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
- ml4t/diagnostic/evaluation/validated_cv.py +535 -0
- ml4t/diagnostic/evaluation/visualization.py +1050 -0
- ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
- ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
- ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
- ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
- ml4t/diagnostic/integration/__init__.py +48 -0
- ml4t/diagnostic/integration/backtest_contract.py +671 -0
- ml4t/diagnostic/integration/data_contract.py +316 -0
- ml4t/diagnostic/integration/engineer_contract.py +226 -0
- ml4t/diagnostic/logging/__init__.py +77 -0
- ml4t/diagnostic/logging/logger.py +245 -0
- ml4t/diagnostic/logging/performance.py +234 -0
- ml4t/diagnostic/logging/progress.py +234 -0
- ml4t/diagnostic/logging/wandb.py +412 -0
- ml4t/diagnostic/metrics/__init__.py +9 -0
- ml4t/diagnostic/metrics/percentiles.py +128 -0
- ml4t/diagnostic/py.typed +1 -0
- ml4t/diagnostic/reporting/__init__.py +43 -0
- ml4t/diagnostic/reporting/base.py +130 -0
- ml4t/diagnostic/reporting/html_renderer.py +275 -0
- ml4t/diagnostic/reporting/json_renderer.py +51 -0
- ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
- ml4t/diagnostic/results/AGENT.md +24 -0
- ml4t/diagnostic/results/__init__.py +105 -0
- ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
- ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
- ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
- ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
- ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
- ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
- ml4t/diagnostic/results/barrier_results/validation.py +38 -0
- ml4t/diagnostic/results/base.py +177 -0
- ml4t/diagnostic/results/event_results.py +349 -0
- ml4t/diagnostic/results/feature_results.py +787 -0
- ml4t/diagnostic/results/multi_signal_results.py +431 -0
- ml4t/diagnostic/results/portfolio_results.py +281 -0
- ml4t/diagnostic/results/sharpe_results.py +448 -0
- ml4t/diagnostic/results/signal_results/__init__.py +74 -0
- ml4t/diagnostic/results/signal_results/ic.py +581 -0
- ml4t/diagnostic/results/signal_results/irtc.py +110 -0
- ml4t/diagnostic/results/signal_results/quantile.py +392 -0
- ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
- ml4t/diagnostic/results/signal_results/turnover.py +213 -0
- ml4t/diagnostic/results/signal_results/validation.py +147 -0
- ml4t/diagnostic/signal/AGENT.md +17 -0
- ml4t/diagnostic/signal/__init__.py +69 -0
- ml4t/diagnostic/signal/_report.py +152 -0
- ml4t/diagnostic/signal/_utils.py +261 -0
- ml4t/diagnostic/signal/core.py +275 -0
- ml4t/diagnostic/signal/quantile.py +148 -0
- ml4t/diagnostic/signal/result.py +214 -0
- ml4t/diagnostic/signal/signal_ic.py +129 -0
- ml4t/diagnostic/signal/turnover.py +182 -0
- ml4t/diagnostic/splitters/AGENT.md +19 -0
- ml4t/diagnostic/splitters/__init__.py +36 -0
- ml4t/diagnostic/splitters/base.py +501 -0
- ml4t/diagnostic/splitters/calendar.py +421 -0
- ml4t/diagnostic/splitters/calendar_config.py +91 -0
- ml4t/diagnostic/splitters/combinatorial.py +1064 -0
- ml4t/diagnostic/splitters/config.py +322 -0
- ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
- ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
- ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
- ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
- ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
- ml4t/diagnostic/splitters/group_isolation.py +329 -0
- ml4t/diagnostic/splitters/persistence.py +316 -0
- ml4t/diagnostic/splitters/utils.py +207 -0
- ml4t/diagnostic/splitters/walk_forward.py +757 -0
- ml4t/diagnostic/utils/__init__.py +42 -0
- ml4t/diagnostic/utils/config.py +542 -0
- ml4t/diagnostic/utils/dependencies.py +318 -0
- ml4t/diagnostic/utils/sessions.py +127 -0
- ml4t/diagnostic/validation/__init__.py +54 -0
- ml4t/diagnostic/validation/dataframe.py +274 -0
- ml4t/diagnostic/validation/returns.py +280 -0
- ml4t/diagnostic/validation/timeseries.py +299 -0
- ml4t/diagnostic/visualization/AGENT.md +19 -0
- ml4t/diagnostic/visualization/__init__.py +223 -0
- ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
- ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
- ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
- ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
- ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
- ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
- ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
- ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
- ml4t/diagnostic/visualization/barrier_plots.py +782 -0
- ml4t/diagnostic/visualization/core.py +1060 -0
- ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
- ml4t/diagnostic/visualization/dashboards/base.py +582 -0
- ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
- ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
- ml4t/diagnostic/visualization/dashboards.py +43 -0
- ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
- ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
- ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
- ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
- ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
- ml4t/diagnostic/visualization/feature_plots.py +888 -0
- ml4t/diagnostic/visualization/interaction_plots.py +618 -0
- ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
- ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
- ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
- ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
- ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
- ml4t/diagnostic/visualization/report_generation.py +1343 -0
- ml4t/diagnostic/visualization/signal/__init__.py +103 -0
- ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
- ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
- ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
- ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
- ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
- ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
- ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
- ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
- ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
- ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
- ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
- ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,310 @@
|
|
|
1
|
+
"""Population Stability Index (PSI) for distribution drift detection.
|
|
2
|
+
|
|
3
|
+
PSI measures the distribution shift between a reference dataset (e.g., training)
|
|
4
|
+
and a test dataset (e.g., production).
|
|
5
|
+
|
|
6
|
+
PSI Interpretation:
|
|
7
|
+
- PSI < 0.1: No significant change (green)
|
|
8
|
+
- 0.1 ≤ PSI < 0.2: Small change, monitor (yellow)
|
|
9
|
+
- PSI ≥ 0.2: Significant change, investigate (red)
|
|
10
|
+
|
|
11
|
+
References:
|
|
12
|
+
- Yurdakul, B. (2018). Statistical Properties of Population Stability Index.
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
from dataclasses import dataclass
|
|
18
|
+
from typing import Literal
|
|
19
|
+
|
|
20
|
+
import numpy as np
|
|
21
|
+
import polars as pl
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@dataclass
class PSIResult:
    """Result of Population Stability Index calculation.

    Attributes:
        psi: Overall PSI value (sum of bin-level PSI contributions)
        bin_psi: PSI contribution per bin
        bin_edges: Bin boundaries (continuous) or category labels (categorical)
        reference_counts: Number of samples per bin in reference distribution
        test_counts: Number of samples per bin in test distribution
        reference_percents: Fraction of samples per bin in reference
        test_percents: Fraction of samples per bin in test
        n_bins: Number of bins used
        is_categorical: Whether feature is categorical
        alert_level: Alert level based on PSI thresholds
            - "green": PSI < 0.1 (no significant change)
            - "yellow": 0.1 ≤ PSI < 0.2 (small change, monitor)
            - "red": PSI ≥ 0.2 (significant change, investigate)
        interpretation: Human-readable interpretation
    """

    psi: float
    bin_psi: np.ndarray
    bin_edges: np.ndarray | list[str]
    reference_counts: np.ndarray
    test_counts: np.ndarray
    reference_percents: np.ndarray
    test_percents: np.ndarray
    n_bins: int
    is_categorical: bool
    alert_level: Literal["green", "yellow", "red"]
    interpretation: str

    def _bin_label(self, index: int, bin_count: int) -> str:
        """Return the display label for bin ``index`` out of ``bin_count`` bins.

        Categorical labels are passed through ``str`` so that non-string
        categories (e.g. integers) can be padded with the ``s`` format code.
        Continuous bins are rendered as half-open ``[low, high)`` intervals,
        which is how ``np.digitize`` with its default ``right=False`` assigns
        values in ``_bin_continuous``; the first and last bins are unbounded.
        """
        if self.is_categorical:
            return str(self.bin_edges[index])
        if bin_count == 1:
            # A single bin (e.g. constant reference feature) catches everything
            # and may come with only one stored edge, so do not index edges.
            return "(-inf, +inf)"
        if index == 0:
            return f"(-inf, {self.bin_edges[1]:.3f})"
        if index == bin_count - 1:
            return f"[{self.bin_edges[index]:.3f}, +inf)"
        return f"[{self.bin_edges[index]:.3f}, {self.bin_edges[index + 1]:.3f})"

    def summary(self) -> str:
        """Return formatted summary of PSI results.

        The number of rows is taken from the length of ``bin_psi`` rather than
        from ``n_bins``: tied quantile edges can leave fewer bins than were
        requested, and iterating over the stored arrays keeps the report from
        indexing past their end.
        """
        bin_count = len(self.bin_psi)

        lines = [
            "Population Stability Index (PSI) Report",
            "=" * 50,
            f"PSI Value: {self.psi:.4f}",
            f"Alert Level: {self.alert_level.upper()}",
            f"Feature Type: {'Categorical' if self.is_categorical else 'Continuous'}",
            f"Number of Bins: {bin_count}",
            "",
            f"Interpretation: {self.interpretation}",
            "",
            "Bin-Level Analysis:",
            "-" * 50,
        ]

        # One row per bin: label, reference share, test share, PSI contribution
        for i in range(bin_count):
            bin_label = self._bin_label(i, bin_count)
            lines.append(
                f"Bin {i + 1:2d} {bin_label:20s}: "
                f"Ref={self.reference_percents[i]:6.2%} "
                f"Test={self.test_percents[i]:6.2%} "
                f"PSI={self.bin_psi[i]:.4f}"
            )

        return "\n".join(lines)
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def compute_psi(
    reference: np.ndarray | pl.Series,
    test: np.ndarray | pl.Series,
    n_bins: int = 10,
    is_categorical: bool = False,
    missing_category_handling: Literal["ignore", "separate", "error"] = "separate",
    psi_threshold_yellow: float = 0.1,
    psi_threshold_red: float = 0.2,
) -> PSIResult:
    """Compute Population Stability Index (PSI) between two distributions.

    PSI measures the distribution shift between a reference dataset (e.g., training)
    and a test dataset (e.g., production). It quantifies how much the distribution
    has changed.

    Formula:
        PSI = Σ (test_% - ref_%) × ln(test_% / ref_%)

    For each bin i:
        PSI_i = (P_test[i] - P_ref[i]) × ln(P_test[i] / P_ref[i])

    Bin shares are floored at 1e-10 before the logarithm so that empty bins do
    not produce log(0) or a division by zero.

    Args:
        reference: Reference distribution (e.g., training data)
        test: Test distribution (e.g., production data)
        n_bins: Number of quantile bins requested for continuous features
            (default: 10). Tied quantile edges are merged, so the result may
            contain fewer bins; ``PSIResult.n_bins`` reports the actual count.
            Ignored for categorical features.
        is_categorical: Whether feature is categorical (default: False)
        missing_category_handling: How to handle categories in test not in reference:
            - "ignore": Skip missing categories (not recommended)
            - "separate": Create separate bin for missing categories (default)
            - "error": Raise error if new categories found
        psi_threshold_yellow: Threshold for yellow alert (default: 0.1)
        psi_threshold_red: Threshold for red alert (default: 0.2)

    Returns:
        PSIResult with overall PSI, bin-level contributions, and interpretation

    Raises:
        ValueError: If either input is empty, if ``n_bins`` is smaller than 1
            for a continuous feature, or if new categories are found with
            "error" handling

    Example:
        >>> # Continuous feature
        >>> ref = np.random.normal(0, 1, 1000)
        >>> test = np.random.normal(0.5, 1, 1000)  # Mean shifted
        >>> result = compute_psi(ref, test, n_bins=10)
        >>> print(result.summary())
        >>>
        >>> # Categorical feature
        >>> ref_cat = np.array(['A', 'B', 'C'] * 100)
        >>> test_cat = np.array(['A', 'A', 'B'] * 100)  # Distribution changed
        >>> result = compute_psi(ref_cat, test_cat, is_categorical=True)
        >>> print(f"PSI: {result.psi:.4f}, Alert: {result.alert_level}")
    """
    # Convert to numpy arrays
    if isinstance(reference, pl.Series):
        reference = reference.to_numpy()
    if isinstance(test, pl.Series):
        test = test.to_numpy()

    reference = np.asarray(reference)
    test = np.asarray(test)

    # Validate inputs
    if len(reference) == 0 or len(test) == 0:
        raise ValueError("Reference and test arrays must not be empty")

    # Bin edges for continuous features, category labels for categorical ones
    bin_edges: np.ndarray | list[str]

    if not is_categorical:
        if n_bins < 1:
            raise ValueError(f"n_bins must be at least 1, got {n_bins}")
        # Continuous feature: quantile binning
        bin_edges, ref_counts, test_counts = _bin_continuous(reference, test, n_bins)
    else:
        # Categorical feature: category-based binning
        bin_edges, ref_counts, test_counts = _bin_categorical(
            reference, test, missing_category_handling
        )

    # Report the number of bins actually produced. For continuous features the
    # binning step de-duplicates quantile edges, so this can be smaller than
    # the requested n_bins; keeping the requested value would leave the result
    # inconsistent with the per-bin arrays.
    n_bins = len(ref_counts)

    # Convert counts to percentages
    ref_percents = ref_counts / ref_counts.sum()
    test_percents = test_counts / test_counts.sum()

    # Compute PSI per bin with numerical stability
    # Add small epsilon to avoid log(0) and division by zero
    epsilon = 1e-10
    ref_percents_safe = np.maximum(ref_percents, epsilon)
    test_percents_safe = np.maximum(test_percents, epsilon)

    # PSI formula: (test% - ref%) * ln(test% / ref%)
    bin_psi = (test_percents_safe - ref_percents_safe) * np.log(
        test_percents_safe / ref_percents_safe
    )

    # Total PSI is sum of bin contributions
    psi = float(np.sum(bin_psi))

    # Determine alert level
    alert_level: Literal["green", "yellow", "red"]
    if psi < psi_threshold_yellow:
        alert_level = "green"
        interpretation = (
            f"No significant distribution change detected (PSI={psi:.4f} < {psi_threshold_yellow}). "
            "Feature distribution is stable."
        )
    elif psi < psi_threshold_red:
        alert_level = "yellow"
        interpretation = (
            f"Small distribution change detected ({psi_threshold_yellow} ≤ PSI={psi:.4f} < {psi_threshold_red}). "
            "Monitor feature closely but no immediate action required."
        )
    else:
        alert_level = "red"
        interpretation = (
            f"Significant distribution change detected (PSI={psi:.4f} ≥ {psi_threshold_red}). "
            "Investigate cause and consider model retraining."
        )

    return PSIResult(
        psi=psi,
        bin_psi=bin_psi,
        bin_edges=bin_edges,
        reference_counts=ref_counts,
        test_counts=test_counts,
        reference_percents=ref_percents,
        test_percents=test_percents,
        n_bins=n_bins,
        is_categorical=is_categorical,
        alert_level=alert_level,
        interpretation=interpretation,
    )
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
def _bin_continuous(
    reference: np.ndarray, test: np.ndarray, n_bins: int
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Count reference and test samples in quantile bins of the reference.

    Bin edges are the reference percentiles at ``n_bins + 1`` evenly spaced
    levels between 0 and 100, with duplicate edges removed. Both arrays are
    then assigned to bins using the interior edges only, so the first and
    last bins are open-ended and test values outside the reference range
    still land in a bin.

    Args:
        reference: Reference data (defines the quantile edges)
        test: Test data (counted against the reference edges)
        n_bins: Number of bins requested

    Returns:
        Tuple of (bin_edges, reference_counts, test_counts). When every
        reference percentile is identical, ``bin_edges`` holds that single
        value and each count array has one entry with the array's length.
    """
    # n_bins bins are delimited by n_bins + 1 percentile levels
    levels = np.linspace(0, 100, n_bins + 1)

    # Tied percentiles would create zero-width bins, so keep distinct edges only
    edges = np.unique(np.percentile(reference, levels))

    # Degenerate case: one distinct edge means everything shares one bin
    if len(edges) == 1:
        return edges, np.array([len(reference)]), np.array([len(test)])

    interior = edges[1:-1]
    slot_count = len(edges) - 1

    def _tally(values: np.ndarray) -> np.ndarray:
        # digitize against interior edges gives indices 0..slot_count-1;
        # minlength keeps empty trailing bins in the result
        return np.bincount(np.digitize(values, interior), minlength=slot_count)

    return edges, _tally(reference), _tally(test)
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
def _bin_categorical(
    reference: np.ndarray,
    test: np.ndarray,
    missing_handling: Literal["ignore", "separate", "error"],
) -> tuple[list[str], np.ndarray, np.ndarray]:
    """Count reference and test samples per category.

    The bins are the sorted distinct values of ``reference``. Values that
    occur only in ``test`` are handled according to ``missing_handling``:
    "error" raises, "separate" appends them (sorted) after the reference
    categories, and "ignore" leaves them out so they are not counted.

    Args:
        reference: Reference categories
        test: Test categories
        missing_handling: How to handle new categories in test

    Returns:
        Tuple of (category_labels, reference_counts, test_counts)

    Raises:
        ValueError: If new categories found and missing_handling="error"
    """
    labels = sorted(set(reference))

    # Categories present in test but never observed in reference
    unseen = set(test) - set(labels)

    if unseen and missing_handling == "error":
        raise ValueError(
            f"New categories found in test set: {unseen}. "
            "These categories were not present in reference distribution."
        )
    if unseen and missing_handling == "separate":
        # Unseen categories get their own bins after the reference ones
        labels.extend(sorted(unseen))
    # With "ignore" the label list is unchanged, so unseen values are not counted

    def _occurrences(values: np.ndarray) -> np.ndarray:
        # One equality scan per label, in label order
        return np.array([np.sum(values == label) for label in labels])

    return labels, _occurrences(reference), _occurrences(test)
|
|
@@ -0,0 +1,388 @@
|
|
|
1
|
+
"""Wasserstein distance for continuous distribution drift detection.
|
|
2
|
+
|
|
3
|
+
The Wasserstein distance (Earth Mover's Distance) measures the minimum cost
|
|
4
|
+
to transform one probability distribution into another.
|
|
5
|
+
|
|
6
|
+
Properties:
|
|
7
|
+
- True metric: non-negative, symmetric, triangle inequality
|
|
8
|
+
- More sensitive to small shifts than PSI
|
|
9
|
+
- Natural interpretation as "transport cost"
|
|
10
|
+
- No binning artifacts
|
|
11
|
+
|
|
12
|
+
References:
|
|
13
|
+
- Villani, C. (2009). Optimal Transport: Old and New. Springer.
|
|
14
|
+
- Ramdas, A., et al. (2017). On Wasserstein Two-Sample Testing.
|
|
15
|
+
Entropy, 19(2), 47.
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
from __future__ import annotations
|
|
19
|
+
|
|
20
|
+
import time
|
|
21
|
+
from dataclasses import dataclass
|
|
22
|
+
from typing import Any
|
|
23
|
+
|
|
24
|
+
import numpy as np
|
|
25
|
+
import polars as pl
|
|
26
|
+
from scipy.stats import wasserstein_distance
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@dataclass
class WassersteinResult:
    """Result of Wasserstein distance calculation.

    The Wasserstein distance (also called Earth Mover's Distance) measures the
    minimum "cost" to transform one distribution into another. It's a true metric
    and doesn't require binning, making it ideal for continuous features.

    Attributes:
        distance: Wasserstein distance value (W_p)
        p: Order of Wasserstein distance (1 or 2)
        threshold: Calibrated threshold from permutation test (if calibrated)
        p_value: Statistical significance p-value (if calibrated)
        drifted: Whether drift was detected (distance > threshold)
        n_reference: Number of samples in reference distribution
        n_test: Number of samples in test distribution
        reference_stats: Summary statistics of reference distribution;
            summary() reads the "mean", "std", "min" and "max" keys
        test_stats: Summary statistics of test distribution (same keys)
        threshold_calibration_config: Configuration used for threshold calibration
        interpretation: Human-readable interpretation
        computation_time: Time taken to compute (seconds)
    """

    distance: float
    p: int
    threshold: float | None
    p_value: float | None
    drifted: bool
    n_reference: int
    n_test: int
    reference_stats: dict[str, float]
    test_stats: dict[str, float]
    threshold_calibration_config: dict[str, Any] | None
    interpretation: str
    computation_time: float

    def summary(self) -> str:
        """Return formatted summary of Wasserstein distance results.

        The threshold-calibration section is included only when a threshold
        is present.
        """
        lines = [
            "Wasserstein Distance Drift Detection Report",
            "=" * 60,
            f"Wasserstein-{self.p} Distance: {self.distance:.6f}",
            f"Drift Detected: {'YES' if self.drifted else 'NO'}",
            "",
            "Sample Sizes:",
            f" Reference: {self.n_reference:,}",
            f" Test: {self.n_test:,}",
            "",
        ]

        if self.threshold is not None:
            # Compare against None explicitly: a p-value of exactly 0.0 is a
            # valid result and must be printed rather than shown as "N/A".
            if self.p_value is not None:
                p_value_line = f" P-value: {self.p_value:.4f}"
            else:
                p_value_line = " P-value: N/A"

            lines.extend(
                [
                    "Threshold Calibration:",
                    f" Threshold: {self.threshold:.6f}",
                    p_value_line,
                    f" Config: {self.threshold_calibration_config}",
                    "",
                ]
            )

        lines.extend(
            [
                "Distribution Statistics:",
                "-" * 60,
                f"Reference: Mean={self.reference_stats['mean']:.4f}, "
                f"Std={self.reference_stats['std']:.4f}, "
                f"Min={self.reference_stats['min']:.4f}, "
                f"Max={self.reference_stats['max']:.4f}",
                f"Test: Mean={self.test_stats['mean']:.4f}, "
                f"Std={self.test_stats['std']:.4f}, "
                f"Min={self.test_stats['min']:.4f}, "
                f"Max={self.test_stats['max']:.4f}",
                "",
                f"Interpretation: {self.interpretation}",
                "",
                f"Computation Time: {self.computation_time:.3f}s",
            ]
        )

        return "\n".join(lines)
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def compute_wasserstein_distance(
    reference: np.ndarray | pl.Series,
    test: np.ndarray | pl.Series,
    p: int = 1,
    threshold_calibration: bool = True,
    n_permutations: int = 1000,
    alpha: float = 0.05,
    n_samples: int | None = None,
    random_state: int | None = None,
) -> WassersteinResult:
    """Compute Wasserstein distance between reference and test distributions.

    The Wasserstein distance (Earth Mover's Distance) measures the minimum cost
    to transform one probability distribution into another. Unlike PSI, it doesn't
    require binning and provides a true metric with desirable properties:
    - Metric properties: non-negative, symmetric, triangle inequality
    - More sensitive to small shifts than PSI
    - Natural interpretation as "transport cost"
    - No binning artifacts

    The p-Wasserstein distance is defined as:
        W_p(P, Q) = (∫|F_P^{-1}(u) - F_Q^{-1}(u)|^p du)^{1/p}

    For empirical distributions with sorted samples x_1 ≤ ... ≤ x_n:
        W_1(P, Q) = (1/n) Σ|x_i^P - x_i^Q|

    Threshold calibration uses a permutation test:
        H0: reference and test come from the same distribution
        H1: distributions differ

    When ``threshold_calibration`` is False, no p-value is computed and the
    threshold falls back to a heuristic of ``0.5 * std(reference)``.

    Args:
        reference: Reference distribution (e.g., training data)
        test: Test distribution (e.g., production data)
        p: Order of Wasserstein distance (1 or 2). Default: 1
            - p=1: More robust, easier to interpret
            - p=2: More sensitive to tail differences
        threshold_calibration: Whether to calibrate threshold via permutation test
        n_permutations: Number of permutations for threshold calibration
            (must be >= 1 when calibration is enabled)
        alpha: Significance level for threshold (default: 0.05); must lie in
            [0, 1] when calibration is enabled
        n_samples: Subsample to this many samples if provided (for large datasets);
            must be >= 1. Each input is only subsampled when it is longer
            than ``n_samples``.
        random_state: Random seed for reproducibility. When given, it is
            passed to ``np.random.seed``, i.e. it reseeds NumPy's global
            random state.

    Returns:
        WassersteinResult with distance, threshold, p-value, and interpretation

    Raises:
        ValueError: If either input is empty, p is not in {1, 2}, n_samples
            is less than 1, or (with calibration enabled) n_permutations is
            less than 1 or alpha is outside [0, 1]

    Example:
        >>> # Detect mean shift
        >>> ref = np.random.normal(0, 1, 1000)
        >>> test = np.random.normal(0.5, 1, 1000)  # Mean shifted by 0.5
        >>> result = compute_wasserstein_distance(ref, test)
        >>> print(result.summary())
        >>>
        >>> # Detect variance shift
        >>> test_var = np.random.normal(0, 2, 1000)  # Variance doubled
        >>> result = compute_wasserstein_distance(ref, test_var)
        >>> print(f"Distance: {result.distance:.4f}, Drifted: {result.drifted}")
        >>>
        >>> # Without threshold calibration (faster)
        >>> result = compute_wasserstein_distance(
        ...     ref, test, threshold_calibration=False
        ... )
    """
    start_time = time.time()

    # Convert to numpy arrays
    if isinstance(reference, pl.Series):
        reference = reference.to_numpy()
    if isinstance(test, pl.Series):
        test = test.to_numpy()

    reference = np.asarray(reference, dtype=np.float64)
    test = np.asarray(test, dtype=np.float64)

    # Validate inputs
    if len(reference) == 0 or len(test) == 0:
        raise ValueError("Reference and test arrays must not be empty")

    if p not in [1, 2]:
        raise ValueError(f"p must be 1 or 2, got {p}")

    # Reject a non-positive subsample size here; otherwise it only surfaces
    # later as an unrelated-looking error from the sampling/statistics calls.
    if n_samples is not None and n_samples < 1:
        raise ValueError(f"n_samples must be at least 1, got {n_samples}")

    if threshold_calibration:
        # An empty null distribution cannot yield a meaningful threshold or
        # p-value, so fail fast instead of calibrating on zero permutations.
        if n_permutations < 1:
            raise ValueError(
                f"n_permutations must be at least 1, got {n_permutations}"
            )
        # Checked up front so a bad alpha is reported before the (expensive)
        # permutation loop runs rather than after it.
        if not 0 <= alpha <= 1:
            raise ValueError(f"alpha must be between 0 and 1, got {alpha}")

    # Set random state (global NumPy RNG; also drives the permutation test)
    if random_state is not None:
        np.random.seed(random_state)

    # Subsample if requested
    if n_samples is not None and len(reference) > n_samples:
        indices_ref = np.random.choice(len(reference), n_samples, replace=False)
        reference = reference[indices_ref]
    if n_samples is not None and len(test) > n_samples:
        indices_test = np.random.choice(len(test), n_samples, replace=False)
        test = test[indices_test]

    n_reference = len(reference)
    n_test = len(test)

    # Compute distribution statistics
    reference_stats = {
        "mean": float(np.mean(reference)),
        "std": float(np.std(reference)),
        "min": float(np.min(reference)),
        "max": float(np.max(reference)),
        "median": float(np.median(reference)),
        "q25": float(np.percentile(reference, 25)),
        "q75": float(np.percentile(reference, 75)),
    }

    test_stats = {
        "mean": float(np.mean(test)),
        "std": float(np.std(test)),
        "min": float(np.min(test)),
        "max": float(np.max(test)),
        "median": float(np.median(test)),
        "q25": float(np.percentile(test, 25)),
        "q75": float(np.percentile(test, 75)),
    }

    # Compute Wasserstein distance
    if p == 1:
        distance = float(wasserstein_distance(reference, test))
    else:  # p == 2
        # scipy's wasserstein_distance computes W_1
        # For W_2, we need to compute it manually
        distance = _wasserstein_2(reference, test)

    # Threshold calibration via permutation test
    threshold = None
    p_value = None
    calibration_config = None

    if threshold_calibration:
        threshold, p_value = _calibrate_wasserstein_threshold(
            reference, test, distance, n_permutations, alpha, p
        )
        calibration_config = {
            "n_permutations": n_permutations,
            "alpha": alpha,
            "method": "permutation",
        }

    # Determine drift status
    if threshold is None:
        # Without calibration, use heuristic based on distribution statistics
        # Drift if distance > 0.5 * std of reference
        threshold = 0.5 * reference_stats["std"]
    drifted = distance > threshold

    # Generate interpretation
    if drifted:
        if p_value is not None:
            interpretation = (
                f"Distribution drift detected (W_{p}={distance:.6f} > {threshold:.6f}, "
                f"p={p_value:.4f}). The test distribution differs significantly from "
                f"the reference distribution."
            )
        else:
            interpretation = (
                f"Distribution drift detected (W_{p}={distance:.6f} > {threshold:.6f}). "
                f"The test distribution differs from the reference distribution."
            )
    else:
        if p_value is not None:
            interpretation = (
                f"No significant drift detected (W_{p}={distance:.6f} ≤ {threshold:.6f}, "
                f"p={p_value:.4f}). Distributions are consistent."
            )
        else:
            interpretation = (
                f"No significant drift detected (W_{p}={distance:.6f} ≤ {threshold:.6f}). "
                f"Distributions are consistent."
            )

    computation_time = time.time() - start_time

    return WassersteinResult(
        distance=distance,
        p=p,
        threshold=threshold,
        p_value=p_value,
        drifted=drifted,
        n_reference=n_reference,
        n_test=n_test,
        reference_stats=reference_stats,
        test_stats=test_stats,
        threshold_calibration_config=calibration_config,
        interpretation=interpretation,
        computation_time=computation_time,
    )
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
def _wasserstein_2(u_values: np.ndarray, v_values: np.ndarray) -> float:
|
|
304
|
+
"""Compute Wasserstein-2 distance between two 1D distributions.
|
|
305
|
+
|
|
306
|
+
W_2(P, Q) = sqrt(∫|F_P^{-1}(u) - F_Q^{-1}(u)|^2 du)
|
|
307
|
+
|
|
308
|
+
For empirical distributions, this is computed as:
|
|
309
|
+
W_2 = sqrt((1/n) Σ(x_i - y_i)^2) where x, y are sorted samples
|
|
310
|
+
|
|
311
|
+
Args:
|
|
312
|
+
u_values: First distribution samples
|
|
313
|
+
v_values: Second distribution samples
|
|
314
|
+
|
|
315
|
+
Returns:
|
|
316
|
+
Wasserstein-2 distance
|
|
317
|
+
"""
|
|
318
|
+
u_sorted = np.sort(u_values)
|
|
319
|
+
v_sorted = np.sort(v_values)
|
|
320
|
+
|
|
321
|
+
# Align to same length via CDF interpolation
|
|
322
|
+
# Use linear interpolation between sorted samples
|
|
323
|
+
n = min(len(u_sorted), len(v_sorted))
|
|
324
|
+
u_quantiles = np.interp(np.linspace(0, 1, n), np.linspace(0, 1, len(u_sorted)), u_sorted)
|
|
325
|
+
v_quantiles = np.interp(np.linspace(0, 1, n), np.linspace(0, 1, len(v_sorted)), v_sorted)
|
|
326
|
+
|
|
327
|
+
# Compute L2 distance
|
|
328
|
+
return float(np.sqrt(np.mean((u_quantiles - v_quantiles) ** 2)))
|
|
329
|
+
|
|
330
|
+
|
|
331
|
+
def _calibrate_wasserstein_threshold(
|
|
332
|
+
reference: np.ndarray,
|
|
333
|
+
test: np.ndarray,
|
|
334
|
+
observed_distance: float,
|
|
335
|
+
n_permutations: int,
|
|
336
|
+
alpha: float,
|
|
337
|
+
p: int,
|
|
338
|
+
) -> tuple[float, float]:
|
|
339
|
+
"""Calibrate Wasserstein distance threshold via permutation test.
|
|
340
|
+
|
|
341
|
+
Tests the null hypothesis that reference and test come from the same
|
|
342
|
+
distribution by computing the null distribution of Wasserstein distances
|
|
343
|
+
under random permutations.
|
|
344
|
+
|
|
345
|
+
H0: P_ref = P_test (no drift)
|
|
346
|
+
H1: P_ref ≠ P_test (drift detected)
|
|
347
|
+
|
|
348
|
+
Args:
|
|
349
|
+
reference: Reference distribution samples
|
|
350
|
+
test: Test distribution samples
|
|
351
|
+
observed_distance: Observed Wasserstein distance
|
|
352
|
+
n_permutations: Number of permutations
|
|
353
|
+
alpha: Significance level
|
|
354
|
+
p: Order of Wasserstein distance
|
|
355
|
+
|
|
356
|
+
Returns:
|
|
357
|
+
Tuple of (threshold, p_value)
|
|
358
|
+
- threshold: (1-alpha) quantile of null distribution
|
|
359
|
+
- p_value: Fraction of null distances >= observed
|
|
360
|
+
"""
|
|
361
|
+
# Pool all samples
|
|
362
|
+
pooled = np.concatenate([reference, test])
|
|
363
|
+
n_ref = len(reference)
|
|
364
|
+
|
|
365
|
+
# Compute null distribution
|
|
366
|
+
null_distances = np.zeros(n_permutations)
|
|
367
|
+
|
|
368
|
+
for i in range(n_permutations):
|
|
369
|
+
# Random permutation
|
|
370
|
+
np.random.shuffle(pooled)
|
|
371
|
+
|
|
372
|
+
# Split into two groups
|
|
373
|
+
ref_perm = pooled[:n_ref]
|
|
374
|
+
test_perm = pooled[n_ref:]
|
|
375
|
+
|
|
376
|
+
# Compute distance
|
|
377
|
+
if p == 1:
|
|
378
|
+
null_distances[i] = wasserstein_distance(ref_perm, test_perm)
|
|
379
|
+
else: # p == 2
|
|
380
|
+
null_distances[i] = _wasserstein_2(ref_perm, test_perm)
|
|
381
|
+
|
|
382
|
+
# Compute threshold as (1-alpha) quantile
|
|
383
|
+
threshold = float(np.percentile(null_distances, (1 - alpha) * 100))
|
|
384
|
+
|
|
385
|
+
# Compute p-value
|
|
386
|
+
p_value = float(np.mean(null_distances >= observed_distance))
|
|
387
|
+
|
|
388
|
+
return threshold, p_value
|