ml4t-diagnostic 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (242)
  1. ml4t/diagnostic/AGENT.md +25 -0
  2. ml4t/diagnostic/__init__.py +166 -0
  3. ml4t/diagnostic/backends/__init__.py +10 -0
  4. ml4t/diagnostic/backends/adapter.py +192 -0
  5. ml4t/diagnostic/backends/polars_backend.py +899 -0
  6. ml4t/diagnostic/caching/__init__.py +40 -0
  7. ml4t/diagnostic/caching/cache.py +331 -0
  8. ml4t/diagnostic/caching/decorators.py +131 -0
  9. ml4t/diagnostic/caching/smart_cache.py +339 -0
  10. ml4t/diagnostic/config/AGENT.md +24 -0
  11. ml4t/diagnostic/config/README.md +267 -0
  12. ml4t/diagnostic/config/__init__.py +219 -0
  13. ml4t/diagnostic/config/barrier_config.py +277 -0
  14. ml4t/diagnostic/config/base.py +301 -0
  15. ml4t/diagnostic/config/event_config.py +148 -0
  16. ml4t/diagnostic/config/feature_config.py +404 -0
  17. ml4t/diagnostic/config/multi_signal_config.py +55 -0
  18. ml4t/diagnostic/config/portfolio_config.py +215 -0
  19. ml4t/diagnostic/config/report_config.py +391 -0
  20. ml4t/diagnostic/config/sharpe_config.py +202 -0
  21. ml4t/diagnostic/config/signal_config.py +206 -0
  22. ml4t/diagnostic/config/trade_analysis_config.py +310 -0
  23. ml4t/diagnostic/config/validation.py +279 -0
  24. ml4t/diagnostic/core/__init__.py +29 -0
  25. ml4t/diagnostic/core/numba_utils.py +315 -0
  26. ml4t/diagnostic/core/purging.py +372 -0
  27. ml4t/diagnostic/core/sampling.py +471 -0
  28. ml4t/diagnostic/errors/__init__.py +205 -0
  29. ml4t/diagnostic/evaluation/AGENT.md +26 -0
  30. ml4t/diagnostic/evaluation/__init__.py +437 -0
  31. ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
  32. ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
  33. ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
  34. ml4t/diagnostic/evaluation/dashboard.py +715 -0
  35. ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
  36. ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
  37. ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
  38. ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
  39. ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
  40. ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
  41. ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
  42. ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
  43. ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
  44. ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
  45. ml4t/diagnostic/evaluation/event_analysis.py +647 -0
  46. ml4t/diagnostic/evaluation/excursion.py +390 -0
  47. ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
  48. ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
  49. ml4t/diagnostic/evaluation/framework.py +935 -0
  50. ml4t/diagnostic/evaluation/metric_registry.py +255 -0
  51. ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
  52. ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
  53. ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
  54. ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
  55. ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
  56. ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
  57. ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
  58. ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
  59. ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
  60. ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
  61. ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
  62. ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
  63. ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
  64. ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
  65. ml4t/diagnostic/evaluation/multi_signal.py +550 -0
  66. ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
  67. ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
  68. ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
  69. ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
  70. ml4t/diagnostic/evaluation/report_generation.py +824 -0
  71. ml4t/diagnostic/evaluation/signal_selector.py +452 -0
  72. ml4t/diagnostic/evaluation/stat_registry.py +139 -0
  73. ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
  74. ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
  75. ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
  76. ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
  77. ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
  78. ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
  79. ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
  80. ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
  81. ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
  82. ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
  83. ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
  84. ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
  85. ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
  86. ml4t/diagnostic/evaluation/stats/moments.py +164 -0
  87. ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
  88. ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
  89. ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
  90. ml4t/diagnostic/evaluation/themes.py +330 -0
  91. ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
  92. ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
  93. ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
  94. ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
  95. ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
  96. ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
  97. ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
  98. ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
  99. ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
  100. ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
  101. ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
  102. ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
  103. ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
  104. ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
  105. ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
  106. ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
  107. ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
  108. ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
  109. ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
  110. ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
  111. ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
  112. ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
  113. ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
  114. ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
  115. ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
  116. ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
  117. ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
  118. ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
  119. ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
  120. ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
  121. ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
  122. ml4t/diagnostic/evaluation/validated_cv.py +535 -0
  123. ml4t/diagnostic/evaluation/visualization.py +1050 -0
  124. ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
  125. ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
  126. ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
  127. ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
  128. ml4t/diagnostic/integration/__init__.py +48 -0
  129. ml4t/diagnostic/integration/backtest_contract.py +671 -0
  130. ml4t/diagnostic/integration/data_contract.py +316 -0
  131. ml4t/diagnostic/integration/engineer_contract.py +226 -0
  132. ml4t/diagnostic/logging/__init__.py +77 -0
  133. ml4t/diagnostic/logging/logger.py +245 -0
  134. ml4t/diagnostic/logging/performance.py +234 -0
  135. ml4t/diagnostic/logging/progress.py +234 -0
  136. ml4t/diagnostic/logging/wandb.py +412 -0
  137. ml4t/diagnostic/metrics/__init__.py +9 -0
  138. ml4t/diagnostic/metrics/percentiles.py +128 -0
  139. ml4t/diagnostic/py.typed +1 -0
  140. ml4t/diagnostic/reporting/__init__.py +43 -0
  141. ml4t/diagnostic/reporting/base.py +130 -0
  142. ml4t/diagnostic/reporting/html_renderer.py +275 -0
  143. ml4t/diagnostic/reporting/json_renderer.py +51 -0
  144. ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
  145. ml4t/diagnostic/results/AGENT.md +24 -0
  146. ml4t/diagnostic/results/__init__.py +105 -0
  147. ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
  148. ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
  149. ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
  150. ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
  151. ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
  152. ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
  153. ml4t/diagnostic/results/barrier_results/validation.py +38 -0
  154. ml4t/diagnostic/results/base.py +177 -0
  155. ml4t/diagnostic/results/event_results.py +349 -0
  156. ml4t/diagnostic/results/feature_results.py +787 -0
  157. ml4t/diagnostic/results/multi_signal_results.py +431 -0
  158. ml4t/diagnostic/results/portfolio_results.py +281 -0
  159. ml4t/diagnostic/results/sharpe_results.py +448 -0
  160. ml4t/diagnostic/results/signal_results/__init__.py +74 -0
  161. ml4t/diagnostic/results/signal_results/ic.py +581 -0
  162. ml4t/diagnostic/results/signal_results/irtc.py +110 -0
  163. ml4t/diagnostic/results/signal_results/quantile.py +392 -0
  164. ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
  165. ml4t/diagnostic/results/signal_results/turnover.py +213 -0
  166. ml4t/diagnostic/results/signal_results/validation.py +147 -0
  167. ml4t/diagnostic/signal/AGENT.md +17 -0
  168. ml4t/diagnostic/signal/__init__.py +69 -0
  169. ml4t/diagnostic/signal/_report.py +152 -0
  170. ml4t/diagnostic/signal/_utils.py +261 -0
  171. ml4t/diagnostic/signal/core.py +275 -0
  172. ml4t/diagnostic/signal/quantile.py +148 -0
  173. ml4t/diagnostic/signal/result.py +214 -0
  174. ml4t/diagnostic/signal/signal_ic.py +129 -0
  175. ml4t/diagnostic/signal/turnover.py +182 -0
  176. ml4t/diagnostic/splitters/AGENT.md +19 -0
  177. ml4t/diagnostic/splitters/__init__.py +36 -0
  178. ml4t/diagnostic/splitters/base.py +501 -0
  179. ml4t/diagnostic/splitters/calendar.py +421 -0
  180. ml4t/diagnostic/splitters/calendar_config.py +91 -0
  181. ml4t/diagnostic/splitters/combinatorial.py +1064 -0
  182. ml4t/diagnostic/splitters/config.py +322 -0
  183. ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
  184. ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
  185. ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
  186. ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
  187. ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
  188. ml4t/diagnostic/splitters/group_isolation.py +329 -0
  189. ml4t/diagnostic/splitters/persistence.py +316 -0
  190. ml4t/diagnostic/splitters/utils.py +207 -0
  191. ml4t/diagnostic/splitters/walk_forward.py +757 -0
  192. ml4t/diagnostic/utils/__init__.py +42 -0
  193. ml4t/diagnostic/utils/config.py +542 -0
  194. ml4t/diagnostic/utils/dependencies.py +318 -0
  195. ml4t/diagnostic/utils/sessions.py +127 -0
  196. ml4t/diagnostic/validation/__init__.py +54 -0
  197. ml4t/diagnostic/validation/dataframe.py +274 -0
  198. ml4t/diagnostic/validation/returns.py +280 -0
  199. ml4t/diagnostic/validation/timeseries.py +299 -0
  200. ml4t/diagnostic/visualization/AGENT.md +19 -0
  201. ml4t/diagnostic/visualization/__init__.py +223 -0
  202. ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
  203. ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
  204. ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
  205. ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
  206. ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
  207. ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
  208. ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
  209. ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
  210. ml4t/diagnostic/visualization/barrier_plots.py +782 -0
  211. ml4t/diagnostic/visualization/core.py +1060 -0
  212. ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
  213. ml4t/diagnostic/visualization/dashboards/base.py +582 -0
  214. ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
  215. ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
  216. ml4t/diagnostic/visualization/dashboards.py +43 -0
  217. ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
  218. ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
  219. ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
  220. ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
  221. ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
  222. ml4t/diagnostic/visualization/feature_plots.py +888 -0
  223. ml4t/diagnostic/visualization/interaction_plots.py +618 -0
  224. ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
  225. ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
  226. ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
  227. ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
  228. ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
  229. ml4t/diagnostic/visualization/report_generation.py +1343 -0
  230. ml4t/diagnostic/visualization/signal/__init__.py +103 -0
  231. ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
  232. ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
  233. ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
  234. ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
  235. ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
  236. ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
  237. ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
  238. ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
  239. ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
  240. ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
  241. ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
  242. ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,310 @@
+ """Population Stability Index (PSI) for distribution drift detection.
+
+ PSI measures the distribution shift between a reference dataset (e.g., training)
+ and a test dataset (e.g., production).
+
+ PSI Interpretation:
+ - PSI < 0.1: No significant change (green)
+ - 0.1 ≤ PSI < 0.2: Small change, monitor (yellow)
+ - PSI ≥ 0.2: Significant change, investigate (red)
+
+ References:
+ - Yurdakul, B. (2018). Statistical Properties of Population Stability Index.
+ """
+
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import Literal
+
+ import numpy as np
+ import polars as pl
+
+
+ @dataclass
+ class PSIResult:
+     """Result of Population Stability Index calculation.
+
+     Attributes:
+         psi: Overall PSI value (sum of bin-level PSI contributions)
+         bin_psi: PSI contribution per bin
+         bin_edges: Bin boundaries (continuous) or category labels (categorical)
+         reference_counts: Number of samples per bin in the reference distribution
+         test_counts: Number of samples per bin in the test distribution
+         reference_percents: Fraction of samples per bin in the reference distribution
+         test_percents: Fraction of samples per bin in the test distribution
+         n_bins: Number of bins used
+         is_categorical: Whether the feature is categorical
+         alert_level: Alert level based on PSI thresholds
+             - "green": PSI < 0.1 (no significant change)
+             - "yellow": 0.1 ≤ PSI < 0.2 (small change, monitor)
+             - "red": PSI ≥ 0.2 (significant change, investigate)
+         interpretation: Human-readable interpretation
+     """
+
+     psi: float
+     bin_psi: np.ndarray
+     bin_edges: np.ndarray | list[str]
+     reference_counts: np.ndarray
+     test_counts: np.ndarray
+     reference_percents: np.ndarray
+     test_percents: np.ndarray
+     n_bins: int
+     is_categorical: bool
+     alert_level: Literal["green", "yellow", "red"]
+     interpretation: str
+
+     def summary(self) -> str:
+         """Return a formatted summary of PSI results."""
+         lines = [
+             "Population Stability Index (PSI) Report",
+             "=" * 50,
+             f"PSI Value: {self.psi:.4f}",
+             f"Alert Level: {self.alert_level.upper()}",
+             f"Feature Type: {'Categorical' if self.is_categorical else 'Continuous'}",
+             f"Number of Bins: {self.n_bins}",
+             "",
+             f"Interpretation: {self.interpretation}",
+             "",
+             "Bin-Level Analysis:",
+             "-" * 50,
+         ]
+
+         # Add bin-level details
+         for i in range(self.n_bins):
+             if self.is_categorical:
+                 bin_label = self.bin_edges[i]
+             else:
+                 if i == 0:
+                     bin_label = f"(-inf, {self.bin_edges[i + 1]:.3f}]"
+                 elif i == self.n_bins - 1:
+                     bin_label = f"({self.bin_edges[i]:.3f}, +inf)"
+                 else:
+                     bin_label = f"({self.bin_edges[i]:.3f}, {self.bin_edges[i + 1]:.3f}]"
+
+             lines.append(
+                 f"Bin {i + 1:2d} {bin_label:20s}: "
+                 f"Ref={self.reference_percents[i]:6.2%} "
+                 f"Test={self.test_percents[i]:6.2%} "
+                 f"PSI={self.bin_psi[i]:.4f}"
+             )
+
+         return "\n".join(lines)
+
+
+ def compute_psi(
+     reference: np.ndarray | pl.Series,
+     test: np.ndarray | pl.Series,
+     n_bins: int = 10,
+     is_categorical: bool = False,
+     missing_category_handling: Literal["ignore", "separate", "error"] = "separate",
+     psi_threshold_yellow: float = 0.1,
+     psi_threshold_red: float = 0.2,
+ ) -> PSIResult:
+     """Compute Population Stability Index (PSI) between two distributions.
+
+     PSI measures the distribution shift between a reference dataset (e.g., training)
+     and a test dataset (e.g., production). It quantifies how much the distribution
+     has changed.
+
+     Formula:
+         PSI = Σ (test_% - ref_%) × ln(test_% / ref_%)
+
+     For each bin i:
+         PSI_i = (P_test[i] - P_ref[i]) × ln(P_test[i] / P_ref[i])
+
+     Args:
+         reference: Reference distribution (e.g., training data)
+         test: Test distribution (e.g., production data)
+         n_bins: Number of quantile bins for continuous features (default: 10)
+         is_categorical: Whether the feature is categorical (default: False)
+         missing_category_handling: How to handle test categories absent from reference:
+             - "ignore": Skip new categories (not recommended)
+             - "separate": Create separate bins for new categories (default)
+             - "error": Raise an error if new categories are found
+         psi_threshold_yellow: Threshold for yellow alert (default: 0.1)
+         psi_threshold_red: Threshold for red alert (default: 0.2)
+
+     Returns:
+         PSIResult with overall PSI, bin-level contributions, and interpretation
+
+     Raises:
+         ValueError: If inputs are invalid, or new categories are found with "error" handling
+
+     Example:
+         >>> # Continuous feature
+         >>> ref = np.random.normal(0, 1, 1000)
+         >>> test = np.random.normal(0.5, 1, 1000)  # Mean shifted
+         >>> result = compute_psi(ref, test, n_bins=10)
+         >>> print(result.summary())
+         >>>
+         >>> # Categorical feature
+         >>> ref_cat = np.array(['A', 'B', 'C'] * 100)
+         >>> test_cat = np.array(['A', 'A', 'B'] * 100)  # Distribution changed
+         >>> result = compute_psi(ref_cat, test_cat, is_categorical=True)
+         >>> print(f"PSI: {result.psi:.4f}, Alert: {result.alert_level}")
+     """
+     # Convert to numpy arrays
+     if isinstance(reference, pl.Series):
+         reference = reference.to_numpy()
+     if isinstance(test, pl.Series):
+         test = test.to_numpy()
+
+     reference = np.asarray(reference)
+     test = np.asarray(test)
+
+     # Validate inputs
+     if len(reference) == 0 or len(test) == 0:
+         raise ValueError("Reference and test arrays must not be empty")
+
+     # Variables with union types for both branches
+     bin_labels: np.ndarray | list[str]
+     bin_edges: np.ndarray | list[str]
+
+     if not is_categorical:
+         # Continuous feature: quantile binning
+         bin_edges, ref_counts, test_counts = _bin_continuous(reference, test, n_bins)
+         bin_labels = bin_edges  # Will be formatted in summary()
+         # Edges may collapse for constant or low-cardinality features,
+         # so recompute the effective bin count from the histogram
+         n_bins = len(ref_counts)
+     else:
+         # Categorical feature: category-based binning
+         bin_labels, ref_counts, test_counts = _bin_categorical(
+             reference, test, missing_category_handling
+         )
+         bin_edges = bin_labels
+         n_bins = len(bin_labels)
+
+     # Convert counts to proportions
+     ref_percents = ref_counts / ref_counts.sum()
+     test_percents = test_counts / test_counts.sum()
+
+     # Compute PSI per bin with numerical stability:
+     # clamp proportions to a small epsilon to avoid log(0) and division by zero
+     epsilon = 1e-10
+     ref_percents_safe = np.maximum(ref_percents, epsilon)
+     test_percents_safe = np.maximum(test_percents, epsilon)
+
+     # PSI formula: (test% - ref%) * ln(test% / ref%)
+     bin_psi = (test_percents_safe - ref_percents_safe) * np.log(
+         test_percents_safe / ref_percents_safe
+     )
+
+     # Total PSI is the sum of bin contributions
+     psi = float(np.sum(bin_psi))
+
+     # Determine alert level
+     alert_level: Literal["green", "yellow", "red"]
+     if psi < psi_threshold_yellow:
+         alert_level = "green"
+         interpretation = (
+             f"No significant distribution change detected (PSI={psi:.4f} < {psi_threshold_yellow}). "
+             "Feature distribution is stable."
+         )
+     elif psi < psi_threshold_red:
+         alert_level = "yellow"
+         interpretation = (
+             f"Small distribution change detected ({psi_threshold_yellow} ≤ PSI={psi:.4f} < {psi_threshold_red}). "
+             "Monitor the feature closely; no immediate action required."
+         )
+     else:
+         alert_level = "red"
+         interpretation = (
+             f"Significant distribution change detected (PSI={psi:.4f} ≥ {psi_threshold_red}). "
+             "Investigate the cause and consider model retraining."
+         )
+
+     return PSIResult(
+         psi=psi,
+         bin_psi=bin_psi,
+         bin_edges=bin_edges,
+         reference_counts=ref_counts,
+         test_counts=test_counts,
+         reference_percents=ref_percents,
+         test_percents=test_percents,
+         n_bins=n_bins,
+         is_categorical=is_categorical,
+         alert_level=alert_level,
+         interpretation=interpretation,
+     )
+
+
+ def _bin_continuous(
+     reference: np.ndarray, test: np.ndarray, n_bins: int
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
+     """Bin continuous features using quantiles of the reference distribution.
+
+     Quantile binning yields roughly equal-sized bins in the reference
+     distribution; the test distribution is binned using the same edges.
+
+     Args:
+         reference: Reference data (used to compute quantiles)
+         test: Test data (binned using reference quantiles)
+         n_bins: Number of bins
+
+     Returns:
+         Tuple of (bin_edges, reference_counts, test_counts)
+     """
+     # Compute quantiles of the reference distribution:
+     # n_bins + 1 edges define n_bins bins
+     quantiles = np.linspace(0, 100, n_bins + 1)
+     bin_edges = np.percentile(reference, quantiles)
+
+     # Ensure edges are unique (handles constant or low-cardinality features)
+     bin_edges = np.unique(bin_edges)
+
+     # If all values are the same, create a single bin
+     if len(bin_edges) == 1:
+         return bin_edges, np.array([len(reference)]), np.array([len(test)])
+
+     # Bin both distributions using the same interior edges;
+     # np.digitize leaves the outer bins open-ended
+     ref_bins = np.digitize(reference, bin_edges[1:-1])
+     test_bins = np.digitize(test, bin_edges[1:-1])
+
+     # Count samples per bin
+     ref_counts = np.bincount(ref_bins, minlength=len(bin_edges) - 1)
+     test_counts = np.bincount(test_bins, minlength=len(bin_edges) - 1)
+
+     return bin_edges, ref_counts, test_counts
+
+
+ def _bin_categorical(
+     reference: np.ndarray,
+     test: np.ndarray,
+     missing_handling: Literal["ignore", "separate", "error"],
+ ) -> tuple[list[str], np.ndarray, np.ndarray]:
+     """Bin categorical features by category labels.
+
+     Args:
+         reference: Reference categories
+         test: Test categories
+         missing_handling: How to handle test categories absent from reference
+
+     Returns:
+         Tuple of (category_labels, reference_counts, test_counts)
+
+     Raises:
+         ValueError: If new categories are found and missing_handling="error"
+     """
+     # Get unique categories from the reference
+     ref_categories = sorted(set(reference))
+     test_categories = set(test)
+
+     # Check for new categories in the test set
+     new_categories = test_categories - set(ref_categories)
+
+     if new_categories:
+         if missing_handling == "error":
+             raise ValueError(
+                 f"New categories found in test set: {new_categories}. "
+                 "These categories were not present in the reference distribution."
+             )
+         elif missing_handling == "separate":
+             # Append new categories as their own bins
+             ref_categories.extend(sorted(new_categories))
+         # else "ignore": new categories are dropped
+
+     # Count occurrences per category
+     ref_counts = np.array([np.sum(reference == cat) for cat in ref_categories])
+     test_counts = np.array([np.sum(test == cat) for cat in ref_categories])
+
+     return ref_categories, ref_counts, test_counts
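
For orientation, here is a minimal, hypothetical usage sketch of compute_psi (not part of the package; it assumes the wheel is installed so the module is importable under the path shown in the file listing above). The two-bin arithmetic in the comments can be checked by hand against the PSI formula in the docstring:

import numpy as np
from ml4t.diagnostic.evaluation.drift.population_stability_index import compute_psi

# Mean-shifted continuous feature: reference ~ N(0, 1), test ~ N(0.5, 1)
rng = np.random.default_rng(42)
ref = rng.normal(0.0, 1.0, 10_000)
test = rng.normal(0.5, 1.0, 10_000)

result = compute_psi(ref, test, n_bins=10)
print(result.alert_level)  # a 0.5-sigma mean shift gives PSI around 0.24 -> "red"

# Hand-check of the formula on two bins: ref 50%/50% vs. test 70%/30%
# PSI = (0.7-0.5)*ln(0.7/0.5) + (0.3-0.5)*ln(0.3/0.5) ≈ 0.0673 + 0.1022 ≈ 0.169 -> "yellow"
ref_pct = np.array([0.5, 0.5])
test_pct = np.array([0.7, 0.3])
print(f"PSI = {((test_pct - ref_pct) * np.log(test_pct / ref_pct)).sum():.4f}")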
@@ -0,0 +1,388 @@
+ """Wasserstein distance for continuous distribution drift detection.
+
+ The Wasserstein distance (Earth Mover's Distance) measures the minimum cost
+ of transforming one probability distribution into another.
+
+ Properties:
+ - True metric: non-negative, symmetric, triangle inequality
+ - More sensitive to small shifts than PSI
+ - Natural interpretation as "transport cost"
+ - No binning artifacts
+
+ References:
+ - Villani, C. (2009). Optimal Transport: Old and New. Springer.
+ - Ramdas, A., et al. (2017). On Wasserstein Two-Sample Testing.
+   Entropy, 19(2), 47.
+ """
+
+ from __future__ import annotations
+
+ import time
+ from dataclasses import dataclass
+ from typing import Any
+
+ import numpy as np
+ import polars as pl
+ from scipy.stats import wasserstein_distance
+
+
+ @dataclass
+ class WassersteinResult:
+     """Result of Wasserstein distance calculation.
+
+     The Wasserstein distance (also called Earth Mover's Distance) measures the
+     minimum "cost" of transforming one distribution into another. It is a true
+     metric and requires no binning, making it well suited to continuous features.
+
+     Attributes:
+         distance: Wasserstein distance value (W_p)
+         p: Order of the Wasserstein distance (1 or 2)
+         threshold: Calibrated threshold from the permutation test (if calibrated)
+         p_value: Statistical significance p-value (if calibrated)
+         drifted: Whether drift was detected (distance > threshold)
+         n_reference: Number of samples in the reference distribution
+         n_test: Number of samples in the test distribution
+         reference_stats: Summary statistics of the reference distribution
+         test_stats: Summary statistics of the test distribution
+         threshold_calibration_config: Configuration used for threshold calibration
+         interpretation: Human-readable interpretation
+         computation_time: Time taken to compute (seconds)
+     """
+
+     distance: float
+     p: int
+     threshold: float | None
+     p_value: float | None
+     drifted: bool
+     n_reference: int
+     n_test: int
+     reference_stats: dict[str, float]
+     test_stats: dict[str, float]
+     threshold_calibration_config: dict[str, Any] | None
+     interpretation: str
+     computation_time: float
+
+     def summary(self) -> str:
+         """Return a formatted summary of Wasserstein distance results."""
+         lines = [
+             "Wasserstein Distance Drift Detection Report",
+             "=" * 60,
+             f"Wasserstein-{self.p} Distance: {self.distance:.6f}",
+             f"Drift Detected: {'YES' if self.drifted else 'NO'}",
+             "",
+             "Sample Sizes:",
+             f"  Reference: {self.n_reference:,}",
+             f"  Test: {self.n_test:,}",
+             "",
+         ]
+
+         if self.threshold is not None:
+             lines.extend(
+                 [
+                     "Threshold Calibration:",
+                     f"  Threshold: {self.threshold:.6f}",
+                     # Compare against None explicitly so a p-value of 0.0 still prints
+                     f"  P-value: {self.p_value:.4f}" if self.p_value is not None else "  P-value: N/A",
+                     f"  Config: {self.threshold_calibration_config}",
+                     "",
+                 ]
+             )
+
+         lines.extend(
+             [
+                 "Distribution Statistics:",
+                 "-" * 60,
+                 f"Reference: Mean={self.reference_stats['mean']:.4f}, "
+                 f"Std={self.reference_stats['std']:.4f}, "
+                 f"Min={self.reference_stats['min']:.4f}, "
+                 f"Max={self.reference_stats['max']:.4f}",
+                 f"Test: Mean={self.test_stats['mean']:.4f}, "
+                 f"Std={self.test_stats['std']:.4f}, "
+                 f"Min={self.test_stats['min']:.4f}, "
+                 f"Max={self.test_stats['max']:.4f}",
+                 "",
+                 f"Interpretation: {self.interpretation}",
+                 "",
+                 f"Computation Time: {self.computation_time:.3f}s",
+             ]
+         )
+
+         return "\n".join(lines)
+
+
+ def compute_wasserstein_distance(
+     reference: np.ndarray | pl.Series,
+     test: np.ndarray | pl.Series,
+     p: int = 1,
+     threshold_calibration: bool = True,
+     n_permutations: int = 1000,
+     alpha: float = 0.05,
+     n_samples: int | None = None,
+     random_state: int | None = None,
+ ) -> WassersteinResult:
+     """Compute the Wasserstein distance between reference and test distributions.
+
+     The Wasserstein distance (Earth Mover's Distance) measures the minimum cost
+     of transforming one probability distribution into another. Unlike PSI, it
+     requires no binning and is a true metric with desirable properties:
+     - Metric properties: non-negative, symmetric, triangle inequality
+     - More sensitive to small shifts than PSI
+     - Natural interpretation as "transport cost"
+     - No binning artifacts
+
+     The p-Wasserstein distance is defined as:
+         W_p(P, Q) = (∫|F_P^{-1}(u) - F_Q^{-1}(u)|^p du)^{1/p}
+
+     For empirical distributions with equal sample sizes and sorted samples
+     x_1 ≤ ... ≤ x_n:
+         W_1(P, Q) = (1/n) Σ|x_i^P - x_i^Q|
+
+     Threshold calibration uses a permutation test:
+         H0: reference and test come from the same distribution
+         H1: the distributions differ
+
+     Args:
+         reference: Reference distribution (e.g., training data)
+         test: Test distribution (e.g., production data)
+         p: Order of the Wasserstein distance (1 or 2). Default: 1
+             - p=1: More robust, easier to interpret
+             - p=2: More sensitive to tail differences
+         threshold_calibration: Whether to calibrate the threshold via permutation test
+         n_permutations: Number of permutations for threshold calibration
+         alpha: Significance level for the threshold (default: 0.05)
+         n_samples: If provided, subsample each input to this many samples (for large datasets)
+         random_state: Random seed for reproducibility
+
+     Returns:
+         WassersteinResult with distance, threshold, p-value, and interpretation
+
+     Raises:
+         ValueError: If inputs are invalid or p is not in {1, 2}
+
+     Example:
+         >>> # Detect a mean shift
+         >>> ref = np.random.normal(0, 1, 1000)
+         >>> test = np.random.normal(0.5, 1, 1000)  # Mean shifted by 0.5
+         >>> result = compute_wasserstein_distance(ref, test)
+         >>> print(result.summary())
+         >>>
+         >>> # Detect a variance shift
+         >>> test_var = np.random.normal(0, 2, 1000)  # Standard deviation doubled
+         >>> result = compute_wasserstein_distance(ref, test_var)
+         >>> print(f"Distance: {result.distance:.4f}, Drifted: {result.drifted}")
+         >>>
+         >>> # Without threshold calibration (faster)
+         >>> result = compute_wasserstein_distance(
+         ...     ref, test, threshold_calibration=False
+         ... )
+     """
+     start_time = time.time()
+
+     # Convert to numpy arrays
+     if isinstance(reference, pl.Series):
+         reference = reference.to_numpy()
+     if isinstance(test, pl.Series):
+         test = test.to_numpy()
+
+     reference = np.asarray(reference, dtype=np.float64)
+     test = np.asarray(test, dtype=np.float64)
+
+     # Validate inputs
+     if len(reference) == 0 or len(test) == 0:
+         raise ValueError("Reference and test arrays must not be empty")
+
+     if p not in [1, 2]:
+         raise ValueError(f"p must be 1 or 2, got {p}")
+
+     # Set random state
+     if random_state is not None:
+         np.random.seed(random_state)
+
+     # Subsample if requested
+     if n_samples is not None and len(reference) > n_samples:
+         indices_ref = np.random.choice(len(reference), n_samples, replace=False)
+         reference = reference[indices_ref]
+     if n_samples is not None and len(test) > n_samples:
+         indices_test = np.random.choice(len(test), n_samples, replace=False)
+         test = test[indices_test]
+
+     n_reference = len(reference)
+     n_test = len(test)
+
+     # Compute distribution statistics
+     reference_stats = {
+         "mean": float(np.mean(reference)),
+         "std": float(np.std(reference)),
+         "min": float(np.min(reference)),
+         "max": float(np.max(reference)),
+         "median": float(np.median(reference)),
+         "q25": float(np.percentile(reference, 25)),
+         "q75": float(np.percentile(reference, 75)),
+     }
+
+     test_stats = {
+         "mean": float(np.mean(test)),
+         "std": float(np.std(test)),
+         "min": float(np.min(test)),
+         "max": float(np.max(test)),
+         "median": float(np.median(test)),
+         "q25": float(np.percentile(test, 25)),
+         "q75": float(np.percentile(test, 75)),
+     }
+
+     # Compute the Wasserstein distance
+     if p == 1:
+         distance = float(wasserstein_distance(reference, test))
+     else:  # p == 2
+         # scipy's wasserstein_distance computes W_1,
+         # so W_2 is computed manually
+         distance = _wasserstein_2(reference, test)
+
+     # Threshold calibration via permutation test
+     threshold = None
+     p_value = None
+     calibration_config = None
+
+     if threshold_calibration:
+         threshold, p_value = _calibrate_wasserstein_threshold(
+             reference, test, distance, n_permutations, alpha, p
+         )
+         calibration_config = {
+             "n_permutations": n_permutations,
+             "alpha": alpha,
+             "method": "permutation",
+         }
+
+     # Determine drift status
+     if threshold is not None:
+         drifted = distance > threshold
+     else:
+         # Without calibration, fall back to a heuristic:
+         # flag drift if distance > 0.5 * std of the reference
+         threshold = 0.5 * reference_stats["std"]
+         drifted = distance > threshold
+
+     # Generate the interpretation
+     if drifted:
+         if p_value is not None:
+             interpretation = (
+                 f"Distribution drift detected (W_{p}={distance:.6f} > {threshold:.6f}, "
+                 f"p={p_value:.4f}). The test distribution differs significantly from "
+                 f"the reference distribution."
+             )
+         else:
+             interpretation = (
+                 f"Distribution drift detected (W_{p}={distance:.6f} > {threshold:.6f}). "
+                 f"The test distribution differs from the reference distribution."
+             )
+     else:
+         if p_value is not None:
+             interpretation = (
+                 f"No significant drift detected (W_{p}={distance:.6f} ≤ {threshold:.6f}, "
+                 f"p={p_value:.4f}). Distributions are consistent."
+             )
+         else:
+             interpretation = (
+                 f"No significant drift detected (W_{p}={distance:.6f} ≤ {threshold:.6f}). "
+                 "Distributions are consistent."
+             )
+
+     computation_time = time.time() - start_time
+
+     return WassersteinResult(
+         distance=distance,
+         p=p,
+         threshold=threshold,
+         p_value=p_value,
+         drifted=drifted,
+         n_reference=n_reference,
+         n_test=n_test,
+         reference_stats=reference_stats,
+         test_stats=test_stats,
+         threshold_calibration_config=calibration_config,
+         interpretation=interpretation,
+         computation_time=computation_time,
+     )
+
+
+ def _wasserstein_2(u_values: np.ndarray, v_values: np.ndarray) -> float:
+     """Compute the Wasserstein-2 distance between two 1D distributions.
+
+     W_2(P, Q) = sqrt(∫|F_P^{-1}(u) - F_Q^{-1}(u)|^2 du)
+
+     For empirical distributions with equal sample sizes, this reduces to:
+         W_2 = sqrt((1/n) Σ(x_i - y_i)^2) where x, y are sorted samples
+
+     Args:
+         u_values: First distribution samples
+         v_values: Second distribution samples
+
+     Returns:
+         Wasserstein-2 distance
+     """
+     u_sorted = np.sort(u_values)
+     v_sorted = np.sort(v_values)
+
+     # Align both samples to a common length by linear interpolation
+     # of the empirical quantile functions
+     n = min(len(u_sorted), len(v_sorted))
+     u_quantiles = np.interp(np.linspace(0, 1, n), np.linspace(0, 1, len(u_sorted)), u_sorted)
+     v_quantiles = np.interp(np.linspace(0, 1, n), np.linspace(0, 1, len(v_sorted)), v_sorted)
+
+     # L2 distance between the aligned quantile functions
+     return float(np.sqrt(np.mean((u_quantiles - v_quantiles) ** 2)))
+
+
+ def _calibrate_wasserstein_threshold(
+     reference: np.ndarray,
+     test: np.ndarray,
+     observed_distance: float,
+     n_permutations: int,
+     alpha: float,
+     p: int,
+ ) -> tuple[float, float]:
+     """Calibrate the Wasserstein distance threshold via a permutation test.
+
+     Tests the null hypothesis that reference and test come from the same
+     distribution by computing the null distribution of Wasserstein distances
+     under random permutations of the pooled samples.
+
+     H0: P_ref = P_test (no drift)
+     H1: P_ref ≠ P_test (drift detected)
+
+     Args:
+         reference: Reference distribution samples
+         test: Test distribution samples
+         observed_distance: Observed Wasserstein distance
+         n_permutations: Number of permutations
+         alpha: Significance level
+         p: Order of the Wasserstein distance
+
+     Returns:
+         Tuple of (threshold, p_value)
+         - threshold: (1-alpha) quantile of the null distribution
+         - p_value: Fraction of null distances >= observed
+     """
+     # Pool all samples
+     pooled = np.concatenate([reference, test])
+     n_ref = len(reference)
+
+     # Compute the null distribution
+     null_distances = np.zeros(n_permutations)
+
+     for i in range(n_permutations):
+         # Random permutation (shuffles the pooled array in place)
+         np.random.shuffle(pooled)
+
+         # Split into two groups of the original sizes
+         ref_perm = pooled[:n_ref]
+         test_perm = pooled[n_ref:]
+
+         # Compute the distance under the null
+         if p == 1:
+             null_distances[i] = wasserstein_distance(ref_perm, test_perm)
+         else:  # p == 2
+             null_distances[i] = _wasserstein_2(ref_perm, test_perm)
+
+     # Compute the threshold as the (1-alpha) quantile of the null distribution
+     threshold = float(np.percentile(null_distances, (1 - alpha) * 100))
+
+     # Compute the p-value as the fraction of null distances >= observed
+     p_value = float(np.mean(null_distances >= observed_distance))
+
+     return threshold, p_value
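
As with PSI above, a brief hypothetical usage sketch for the Wasserstein module (assumes the wheel is installed; n_permutations is reduced here for speed). For two normals with equal scale, W_1 equals the absolute mean shift, which makes the output easy to sanity-check:

import numpy as np
from scipy.stats import wasserstein_distance
from ml4t.diagnostic.evaluation.drift.wasserstein import compute_wasserstein_distance

rng = np.random.default_rng(0)
ref = rng.normal(0.0, 1.0, 5_000)
shifted = rng.normal(1.0, 1.0, 5_000)  # mean shift of 1.0 -> W_1 ≈ 1.0

print(f"W_1 ≈ {wasserstein_distance(ref, shifted):.3f}")

# Shifted sample: the permutation test should reject H0 (drifted=True, p-value ≈ 0)
result = compute_wasserstein_distance(ref, shifted, n_permutations=200, random_state=0)
print(result.drifted, result.p_value)

# Fresh sample from the same distribution: no drift expected (large p-value)
same = rng.normal(0.0, 1.0, 5_000)
result = compute_wasserstein_distance(ref, same, n_permutations=200, random_state=0)
print(result.drifted, result.p_value)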