ml4t-diagnostic 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (242) hide show
  1. ml4t/diagnostic/AGENT.md +25 -0
  2. ml4t/diagnostic/__init__.py +166 -0
  3. ml4t/diagnostic/backends/__init__.py +10 -0
  4. ml4t/diagnostic/backends/adapter.py +192 -0
  5. ml4t/diagnostic/backends/polars_backend.py +899 -0
  6. ml4t/diagnostic/caching/__init__.py +40 -0
  7. ml4t/diagnostic/caching/cache.py +331 -0
  8. ml4t/diagnostic/caching/decorators.py +131 -0
  9. ml4t/diagnostic/caching/smart_cache.py +339 -0
  10. ml4t/diagnostic/config/AGENT.md +24 -0
  11. ml4t/diagnostic/config/README.md +267 -0
  12. ml4t/diagnostic/config/__init__.py +219 -0
  13. ml4t/diagnostic/config/barrier_config.py +277 -0
  14. ml4t/diagnostic/config/base.py +301 -0
  15. ml4t/diagnostic/config/event_config.py +148 -0
  16. ml4t/diagnostic/config/feature_config.py +404 -0
  17. ml4t/diagnostic/config/multi_signal_config.py +55 -0
  18. ml4t/diagnostic/config/portfolio_config.py +215 -0
  19. ml4t/diagnostic/config/report_config.py +391 -0
  20. ml4t/diagnostic/config/sharpe_config.py +202 -0
  21. ml4t/diagnostic/config/signal_config.py +206 -0
  22. ml4t/diagnostic/config/trade_analysis_config.py +310 -0
  23. ml4t/diagnostic/config/validation.py +279 -0
  24. ml4t/diagnostic/core/__init__.py +29 -0
  25. ml4t/diagnostic/core/numba_utils.py +315 -0
  26. ml4t/diagnostic/core/purging.py +372 -0
  27. ml4t/diagnostic/core/sampling.py +471 -0
  28. ml4t/diagnostic/errors/__init__.py +205 -0
  29. ml4t/diagnostic/evaluation/AGENT.md +26 -0
  30. ml4t/diagnostic/evaluation/__init__.py +437 -0
  31. ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
  32. ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
  33. ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
  34. ml4t/diagnostic/evaluation/dashboard.py +715 -0
  35. ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
  36. ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
  37. ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
  38. ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
  39. ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
  40. ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
  41. ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
  42. ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
  43. ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
  44. ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
  45. ml4t/diagnostic/evaluation/event_analysis.py +647 -0
  46. ml4t/diagnostic/evaluation/excursion.py +390 -0
  47. ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
  48. ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
  49. ml4t/diagnostic/evaluation/framework.py +935 -0
  50. ml4t/diagnostic/evaluation/metric_registry.py +255 -0
  51. ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
  52. ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
  53. ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
  54. ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
  55. ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
  56. ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
  57. ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
  58. ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
  59. ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
  60. ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
  61. ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
  62. ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
  63. ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
  64. ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
  65. ml4t/diagnostic/evaluation/multi_signal.py +550 -0
  66. ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
  67. ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
  68. ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
  69. ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
  70. ml4t/diagnostic/evaluation/report_generation.py +824 -0
  71. ml4t/diagnostic/evaluation/signal_selector.py +452 -0
  72. ml4t/diagnostic/evaluation/stat_registry.py +139 -0
  73. ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
  74. ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
  75. ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
  76. ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
  77. ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
  78. ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
  79. ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
  80. ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
  81. ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
  82. ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
  83. ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
  84. ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
  85. ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
  86. ml4t/diagnostic/evaluation/stats/moments.py +164 -0
  87. ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
  88. ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
  89. ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
  90. ml4t/diagnostic/evaluation/themes.py +330 -0
  91. ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
  92. ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
  93. ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
  94. ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
  95. ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
  96. ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
  97. ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
  98. ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
  99. ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
  100. ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
  101. ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
  102. ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
  103. ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
  104. ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
  105. ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
  106. ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
  107. ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
  108. ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
  109. ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
  110. ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
  111. ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
  112. ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
  113. ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
  114. ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
  115. ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
  116. ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
  117. ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
  118. ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
  119. ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
  120. ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
  121. ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
  122. ml4t/diagnostic/evaluation/validated_cv.py +535 -0
  123. ml4t/diagnostic/evaluation/visualization.py +1050 -0
  124. ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
  125. ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
  126. ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
  127. ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
  128. ml4t/diagnostic/integration/__init__.py +48 -0
  129. ml4t/diagnostic/integration/backtest_contract.py +671 -0
  130. ml4t/diagnostic/integration/data_contract.py +316 -0
  131. ml4t/diagnostic/integration/engineer_contract.py +226 -0
  132. ml4t/diagnostic/logging/__init__.py +77 -0
  133. ml4t/diagnostic/logging/logger.py +245 -0
  134. ml4t/diagnostic/logging/performance.py +234 -0
  135. ml4t/diagnostic/logging/progress.py +234 -0
  136. ml4t/diagnostic/logging/wandb.py +412 -0
  137. ml4t/diagnostic/metrics/__init__.py +9 -0
  138. ml4t/diagnostic/metrics/percentiles.py +128 -0
  139. ml4t/diagnostic/py.typed +1 -0
  140. ml4t/diagnostic/reporting/__init__.py +43 -0
  141. ml4t/diagnostic/reporting/base.py +130 -0
  142. ml4t/diagnostic/reporting/html_renderer.py +275 -0
  143. ml4t/diagnostic/reporting/json_renderer.py +51 -0
  144. ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
  145. ml4t/diagnostic/results/AGENT.md +24 -0
  146. ml4t/diagnostic/results/__init__.py +105 -0
  147. ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
  148. ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
  149. ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
  150. ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
  151. ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
  152. ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
  153. ml4t/diagnostic/results/barrier_results/validation.py +38 -0
  154. ml4t/diagnostic/results/base.py +177 -0
  155. ml4t/diagnostic/results/event_results.py +349 -0
  156. ml4t/diagnostic/results/feature_results.py +787 -0
  157. ml4t/diagnostic/results/multi_signal_results.py +431 -0
  158. ml4t/diagnostic/results/portfolio_results.py +281 -0
  159. ml4t/diagnostic/results/sharpe_results.py +448 -0
  160. ml4t/diagnostic/results/signal_results/__init__.py +74 -0
  161. ml4t/diagnostic/results/signal_results/ic.py +581 -0
  162. ml4t/diagnostic/results/signal_results/irtc.py +110 -0
  163. ml4t/diagnostic/results/signal_results/quantile.py +392 -0
  164. ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
  165. ml4t/diagnostic/results/signal_results/turnover.py +213 -0
  166. ml4t/diagnostic/results/signal_results/validation.py +147 -0
  167. ml4t/diagnostic/signal/AGENT.md +17 -0
  168. ml4t/diagnostic/signal/__init__.py +69 -0
  169. ml4t/diagnostic/signal/_report.py +152 -0
  170. ml4t/diagnostic/signal/_utils.py +261 -0
  171. ml4t/diagnostic/signal/core.py +275 -0
  172. ml4t/diagnostic/signal/quantile.py +148 -0
  173. ml4t/diagnostic/signal/result.py +214 -0
  174. ml4t/diagnostic/signal/signal_ic.py +129 -0
  175. ml4t/diagnostic/signal/turnover.py +182 -0
  176. ml4t/diagnostic/splitters/AGENT.md +19 -0
  177. ml4t/diagnostic/splitters/__init__.py +36 -0
  178. ml4t/diagnostic/splitters/base.py +501 -0
  179. ml4t/diagnostic/splitters/calendar.py +421 -0
  180. ml4t/diagnostic/splitters/calendar_config.py +91 -0
  181. ml4t/diagnostic/splitters/combinatorial.py +1064 -0
  182. ml4t/diagnostic/splitters/config.py +322 -0
  183. ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
  184. ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
  185. ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
  186. ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
  187. ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
  188. ml4t/diagnostic/splitters/group_isolation.py +329 -0
  189. ml4t/diagnostic/splitters/persistence.py +316 -0
  190. ml4t/diagnostic/splitters/utils.py +207 -0
  191. ml4t/diagnostic/splitters/walk_forward.py +757 -0
  192. ml4t/diagnostic/utils/__init__.py +42 -0
  193. ml4t/diagnostic/utils/config.py +542 -0
  194. ml4t/diagnostic/utils/dependencies.py +318 -0
  195. ml4t/diagnostic/utils/sessions.py +127 -0
  196. ml4t/diagnostic/validation/__init__.py +54 -0
  197. ml4t/diagnostic/validation/dataframe.py +274 -0
  198. ml4t/diagnostic/validation/returns.py +280 -0
  199. ml4t/diagnostic/validation/timeseries.py +299 -0
  200. ml4t/diagnostic/visualization/AGENT.md +19 -0
  201. ml4t/diagnostic/visualization/__init__.py +223 -0
  202. ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
  203. ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
  204. ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
  205. ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
  206. ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
  207. ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
  208. ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
  209. ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
  210. ml4t/diagnostic/visualization/barrier_plots.py +782 -0
  211. ml4t/diagnostic/visualization/core.py +1060 -0
  212. ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
  213. ml4t/diagnostic/visualization/dashboards/base.py +582 -0
  214. ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
  215. ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
  216. ml4t/diagnostic/visualization/dashboards.py +43 -0
  217. ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
  218. ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
  219. ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
  220. ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
  221. ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
  222. ml4t/diagnostic/visualization/feature_plots.py +888 -0
  223. ml4t/diagnostic/visualization/interaction_plots.py +618 -0
  224. ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
  225. ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
  226. ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
  227. ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
  228. ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
  229. ml4t/diagnostic/visualization/report_generation.py +1343 -0
  230. ml4t/diagnostic/visualization/signal/__init__.py +103 -0
  231. ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
  232. ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
  233. ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
  234. ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
  235. ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
  236. ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
  237. ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
  238. ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
  239. ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
  240. ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
  241. ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
  242. ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,432 @@
1
+ """Unified drift analysis using multiple detection methods.
2
+
3
+ This module provides the main analyze_drift() function that combines
4
+ PSI, Wasserstein, and Domain Classifier methods for comprehensive
5
+ drift detection.
6
+
7
+ Consensus Logic:
8
+ A feature is flagged as drifted if the fraction of univariate methods (PSI,
9
+ Wasserstein) detecting drift is at least the consensus_threshold. For example, with threshold=0.5:
10
+ - If 1/2 methods detect drift → flagged as drifted
11
+ - If 0/2 methods detect drift → not flagged as drifted
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import time
17
+ from dataclasses import dataclass, field
18
+ from typing import Any
19
+
20
+ import numpy as np
21
+ import pandas as pd
22
+ import polars as pl
23
+
24
+ from ml4t.diagnostic.evaluation.drift.domain_classifier import (
25
+ DomainClassifierResult,
26
+ compute_domain_classifier_drift,
27
+ )
28
+ from ml4t.diagnostic.evaluation.drift.population_stability_index import (
29
+ PSIResult,
30
+ compute_psi,
31
+ )
32
+ from ml4t.diagnostic.evaluation.drift.wasserstein import (
33
+ WassersteinResult,
34
+ compute_wasserstein_distance,
35
+ )
36
+
37
+
38
@dataclass
class FeatureDriftResult:
    """Outcome of the univariate drift detectors for one feature.

    Attributes:
        feature: Name of the analyzed feature.
        psi_result: PSI detection result, or None when PSI produced no result.
        wasserstein_result: Wasserstein detection result, or None when that
            method produced no result.
        drifted: Consensus flag combining the methods that were run.
        n_methods_run: How many methods produced a result for this feature.
        n_methods_detected: How many of those methods reported drift.
        drift_probability: Share of the methods run that reported drift.
        interpretation: Free-text explanation of the consensus decision.
    """

    feature: str
    psi_result: PSIResult | None = None
    wasserstein_result: WassersteinResult | None = None
    drifted: bool = False
    n_methods_run: int = 0
    n_methods_detected: int = 0
    drift_probability: float = 0.0
    interpretation: str = ""

    def summary(self) -> str:
        """Return a multi-line, human-readable report for this feature.

        The consensus lines are always present; a PSI line and a Wasserstein
        line are appended only when the corresponding result is available.
        """
        psi = self.psi_result
        wasserstein = self.wasserstein_result

        report = [
            f"Feature: {self.feature}",
            f" Drifted: {self.drifted} ({self.n_methods_detected}/{self.n_methods_run} methods)",
            f" Drift Probability: {self.drift_probability:.2%}",
        ]

        if psi is not None:
            report.append(f" PSI: {psi.psi:.4f} ({psi.alert_level})")

        if wasserstein is not None:
            verdict = "drifted" if wasserstein.drifted else "no drift"
            report.append(f" Wasserstein: {wasserstein.distance:.4f} ({verdict})")

        return "\n".join(report)
78
+
79
+
80
@dataclass
class DriftSummaryResult:
    """Summary of multi-method drift analysis across features.

    Aggregates the per-feature (univariate) results and the optional
    multivariate domain-classifier result into one overall assessment.

    Attributes:
        feature_results: Per-feature drift results (PSI + Wasserstein).
        domain_classifier_result: Multivariate drift result, or None when the
            domain classifier produced no result.
        n_features: Total number of features analyzed.
        n_features_drifted: Number of features flagged as drifted.
        drifted_features: Names of the features flagged as drifted.
        overall_drifted: True if any feature drifted or the domain classifier
            detected drift.
        consensus_threshold: Minimum fraction of methods that must agree to
            flag a feature as drifted.
        methods_used: All drift detection methods requested.
        univariate_methods: Methods run on individual features.
        multivariate_methods: Methods run on all features jointly.
        interpretation: Human-readable interpretation.
        computation_time: Total time taken for all methods (seconds).
    """

    feature_results: list[FeatureDriftResult]
    domain_classifier_result: DomainClassifierResult | None = None
    n_features: int = 0
    n_features_drifted: int = 0
    drifted_features: list[str] = field(default_factory=list)
    overall_drifted: bool = False
    consensus_threshold: float = 0.5
    methods_used: list[str] = field(default_factory=list)
    univariate_methods: list[str] = field(default_factory=list)
    multivariate_methods: list[str] = field(default_factory=list)
    interpretation: str = ""
    computation_time: float = 0.0

    def summary(self) -> str:
        """Generate a multi-line, human-readable summary of the analysis.

        Returns:
            Text report with the headline counts, the list of drifted features
            (if any), the domain-classifier section (if available) and the
            computation time.
        """
        rule = "=" * 60
        # max(1, ...) keeps the percentage well-defined when no feature exists.
        drifted_share = self.n_features_drifted / max(1, self.n_features)

        lines = [
            rule,
            "Drift Analysis Summary",
            rule,
            f"Methods Used: {', '.join(self.methods_used)}",
            f"Consensus Threshold: {self.consensus_threshold:.0%}",
            f"Total Features: {self.n_features}",
            f"Drifted Features: {self.n_features_drifted} ({drifted_share:.0%})",
            f"Overall Drift Detected: {self.overall_drifted}",
            "",
        ]

        if self.drifted_features:
            lines.append("Drifted Features:")
            lines.extend(f" - {feature}" for feature in self.drifted_features)
            lines.append("")

        classifier = self.domain_classifier_result
        if classifier is not None:
            lines.append("Multivariate Drift (Domain Classifier):")
            lines.append(f" AUC: {classifier.auc:.4f}")
            lines.append(f" Drifted: {classifier.drifted}")
            lines.append("")

        lines.append(f"Computation Time: {self.computation_time:.2f}s")
        lines.append(rule)

        return "\n".join(lines)

    def to_dataframe(self) -> pl.DataFrame:
        """Convert feature-level results to a DataFrame.

        The consensus columns are always present. The PSI, Wasserstein and
        Wasserstein p-value columns are included when at least one feature has
        the corresponding value; features lacking it get a null in that column.
        Every row therefore carries the same keys, so the resulting schema does
        not depend on which feature happens to come first.

        Returns:
            Polars DataFrame with one row per entry in ``feature_results``.
        """
        results = self.feature_results

        # Decide the optional columns once for the whole result set so that
        # all row dicts share an identical set of keys.
        has_psi = any(r.psi_result is not None for r in results)
        has_wasserstein = any(r.wasserstein_result is not None for r in results)
        has_pvalue = any(
            r.wasserstein_result is not None and r.wasserstein_result.p_value is not None
            for r in results
        )

        rows: list[dict[str, Any]] = []
        for result in results:
            row: dict[str, Any] = {
                "feature": result.feature,
                "drifted": result.drifted,
                "drift_probability": result.drift_probability,
                "n_methods_detected": result.n_methods_detected,
                "n_methods_run": result.n_methods_run,
            }

            if has_psi:
                psi = result.psi_result
                row["psi"] = None if psi is None else psi.psi
                row["psi_alert"] = None if psi is None else psi.alert_level

            if has_wasserstein:
                wasserstein = result.wasserstein_result
                row["wasserstein_distance"] = (
                    None if wasserstein is None else wasserstein.distance
                )
                row["wasserstein_drifted"] = (
                    None if wasserstein is None else wasserstein.drifted
                )
                if has_pvalue:
                    row["wasserstein_pvalue"] = (
                        None if wasserstein is None else wasserstein.p_value
                    )

            rows.append(row)

        # Scan every row when inferring dtypes: a column whose leading values
        # are all null must still pick up the type of its later values.
        return pl.DataFrame(rows, infer_schema_length=None)
175
+
176
+
177
def analyze_drift(
    reference: pd.DataFrame | pl.DataFrame,
    test: pd.DataFrame | pl.DataFrame,
    features: list[str] | None = None,
    *,
    methods: list[str] | None = None,
    consensus_threshold: float = 0.5,
    # PSI parameters
    psi_config: dict[str, Any] | None = None,
    # Wasserstein parameters
    wasserstein_config: dict[str, Any] | None = None,
    # Domain classifier parameters
    domain_classifier_config: dict[str, Any] | None = None,
) -> DriftSummaryResult:
    """Comprehensive drift analysis using multiple detection methods.

    This function provides a unified interface for drift detection across
    multiple methods (PSI, Wasserstein, Domain Classifier). It runs the
    requested univariate methods on each feature and, if requested, the
    multivariate method on all features jointly.

    **Univariate Methods** (run per feature):
        - PSI: Population Stability Index (binning-based)
        - Wasserstein: Earth Mover's Distance (metric-based)

    **Multivariate Methods** (run on all features):
        - Domain Classifier: ML-based drift detection with feature importance

    **Consensus Logic**:
        A feature is flagged as drifted when the fraction of univariate
        methods that ran successfully and detected drift is at least
        ``consensus_threshold``. For example, with threshold=0.5 and both
        univariate methods run:
        - If 1/2 methods detect drift → flagged as drifted
        - If 0/2 methods detect drift → not flagged as drifted
        A feature for which no univariate method produced a result is never
        flagged. The domain classifier does not take part in the per-feature
        vote; it only contributes to the overall drift flag.

    A method that raises for a feature is skipped for that feature and a
    ``RuntimeWarning`` is emitted; the analysis continues with the remaining
    methods.

    Args:
        reference: Reference distribution (e.g., training data).
            Can be pandas or polars DataFrame.
        test: Test distribution (e.g., production data).
            Can be pandas or polars DataFrame.
        features: List of feature names to analyze. If None, uses all numeric
            columns of ``reference``.
        methods: List of methods to use.
            Options: ["psi", "wasserstein", "domain_classifier"]
            Default: ["psi", "wasserstein", "domain_classifier"]
        consensus_threshold: Minimum fraction of methods that must detect
            drift to flag a feature as drifted (default: 0.5).
        psi_config: Keyword arguments forwarded to ``compute_psi``. Keys:
            - n_bins: int (default: 10)
            - is_categorical: bool (default: False)
            - psi_threshold_yellow: float (default: 0.1)
            - psi_threshold_red: float (default: 0.2)
        wasserstein_config: Keyword arguments forwarded to
            ``compute_wasserstein_distance``. Keys:
            - p: int (default: 1)
            - threshold_calibration: bool (default: True)
            - n_permutations: int (default: 1000)
            - alpha: float (default: 0.05)
        domain_classifier_config: Keyword arguments forwarded to
            ``compute_domain_classifier_drift``. Keys:
            - model_type: str (default: "lightgbm")
            - n_estimators: int (default: 100)
            - max_depth: int (default: 5)
            - threshold: float (default: 0.6)
            - cv_folds: int (default: 5)

    Returns:
        DriftSummaryResult with per-feature results, multivariate results,
        and overall drift assessment.

    Raises:
        ValueError: If ``reference`` or ``test`` is None, if ``methods`` is
            empty or contains an unknown method, if a feature is missing from
            either frame, or if there are no features to analyze.

    Example:
        >>> import numpy as np
        >>> import pandas as pd
        >>> from ml4t.diagnostic.evaluation.drift import analyze_drift
        >>>
        >>> # Create reference and test data
        >>> reference = pd.DataFrame({
        ...     'feature1': np.random.normal(0, 1, 1000),
        ...     'feature2': np.random.normal(0, 1, 1000)
        ... })
        >>> test = pd.DataFrame({
        ...     'feature1': np.random.normal(0.5, 1, 1000),  # Mean shifted
        ...     'feature2': np.random.normal(0, 1, 1000)  # No shift
        ... })
        >>>
        >>> # Run drift analysis
        >>> result = analyze_drift(reference, test)
        >>> print(result.summary())
        >>>
        >>> # Check which features drifted
        >>> print(f"Drifted features: {result.drifted_features}")
        >>>
        >>> # Get per-feature details
        >>> df = result.to_dataframe()
        >>> print(df)
    """
    # Imported locally so the block does not depend on a file-level import.
    import warnings

    # perf_counter is monotonic, so the duration cannot be skewed by
    # wall-clock adjustments.
    start_time = time.perf_counter()

    # Input validation
    if reference is None or test is None:
        raise ValueError("reference and test must not be None")

    # Determine and validate the methods before doing any work on the data.
    valid_methods = ["psi", "wasserstein", "domain_classifier"]
    if methods is None:
        methods = list(valid_methods)
    else:
        # Copy so the returned result does not alias the caller's list.
        methods = list(methods)
        if not methods:
            raise ValueError(f"methods must not be empty. Valid: {valid_methods}")

    invalid_methods = set(methods) - set(valid_methods)
    if invalid_methods:
        raise ValueError(f"Invalid methods: {invalid_methods}. Valid: {valid_methods}")

    # Separate univariate and multivariate methods
    univariate_methods = [m for m in methods if m in ("psi", "wasserstein")]
    multivariate_methods = [m for m in methods if m == "domain_classifier"]

    # Convert to pandas for easier processing
    reference_pd: pd.DataFrame
    test_pd: pd.DataFrame
    if isinstance(reference, pl.DataFrame):
        reference_pd = reference.to_pandas()
    else:
        reference_pd = reference
    if isinstance(test, pl.DataFrame):
        test_pd = test.to_pandas()
    else:
        test_pd = test

    # Determine features to analyze
    if features is None:
        # Use all numeric columns of the reference frame
        features = reference_pd.select_dtypes(include=[np.number]).columns.tolist()

    # Validate that every feature exists in both frames. This also covers
    # auto-detected features, which may be absent from ``test``.
    missing_in_ref = set(features) - set(reference_pd.columns)
    missing_in_test = set(features) - set(test_pd.columns)
    if missing_in_ref or missing_in_test:
        raise ValueError(
            f"Features not found - reference: {missing_in_ref}, test: {missing_in_test}"
        )

    if not features:
        raise ValueError("No features to analyze")

    # Set default configs
    if psi_config is None:
        psi_config = {}
    if wasserstein_config is None:
        wasserstein_config = {}
    if domain_classifier_config is None:
        domain_classifier_config = {}

    run_psi = "psi" in methods
    run_wasserstein = "wasserstein" in methods

    # Run univariate methods on each feature
    feature_results = []
    for feature in features:
        # Explicitly convert to ndarray to handle ExtensionArray types
        ref_values = np.asarray(reference_pd[feature].values, dtype=np.float64)
        test_values = np.asarray(test_pd[feature].values, dtype=np.float64)

        psi_result = None
        wasserstein_result = None
        n_methods_run = 0
        n_methods_detected = 0

        # PSI
        if run_psi:
            try:
                psi_result = compute_psi(ref_values, test_values, **psi_config)
                n_methods_run += 1
                if psi_result.alert_level in ("yellow", "red"):
                    n_methods_detected += 1
            except Exception as e:
                # Best effort: report the failure and continue with the rest.
                warnings.warn(
                    f"PSI failed for feature {feature}: {e}",
                    RuntimeWarning,
                    stacklevel=2,
                )

        # Wasserstein
        if run_wasserstein:
            try:
                wasserstein_result = compute_wasserstein_distance(
                    ref_values, test_values, **wasserstein_config
                )
                n_methods_run += 1
                if wasserstein_result.drifted:
                    n_methods_detected += 1
            except Exception as e:
                # Best effort: report the failure and continue with the rest.
                warnings.warn(
                    f"Wasserstein failed for feature {feature}: {e}",
                    RuntimeWarning,
                    stacklevel=2,
                )

        # Consensus drift flag. Without any successful method there is no
        # evidence, so the feature is never flagged (even for threshold <= 0).
        drift_probability = n_methods_detected / max(1, n_methods_run)
        drifted = n_methods_run > 0 and drift_probability >= consensus_threshold

        # Interpretation
        if drifted:
            interpretation = f"{n_methods_detected}/{n_methods_run} methods detected drift (probability: {drift_probability:.0%})"
        else:
            interpretation = (
                f"No consensus drift ({n_methods_detected}/{n_methods_run} methods, "
                f"threshold: {consensus_threshold:.0%})"
            )

        feature_results.append(
            FeatureDriftResult(
                feature=feature,
                psi_result=psi_result,
                wasserstein_result=wasserstein_result,
                drifted=drifted,
                n_methods_run=n_methods_run,
                n_methods_detected=n_methods_detected,
                drift_probability=drift_probability,
                interpretation=interpretation,
            )
        )

    # Run multivariate domain classifier if requested
    domain_classifier_result = None
    if "domain_classifier" in methods:
        try:
            domain_classifier_result = compute_domain_classifier_drift(
                reference[features], test[features], **domain_classifier_config
            )
        except Exception as e:
            # Best effort: report the failure and continue.
            warnings.warn(
                f"Domain classifier failed: {e}",
                RuntimeWarning,
                stacklevel=2,
            )

    # Aggregate results
    n_features = len(features)
    drifted_features = [r.feature for r in feature_results if r.drifted]
    n_features_drifted = len(drifted_features)

    # Overall drift flag
    classifier_drifted = domain_classifier_result is not None and bool(
        domain_classifier_result.drifted
    )
    overall_drifted = n_features_drifted > 0 or classifier_drifted

    # Interpretation
    if overall_drifted:
        interpretation = (
            f"Drift detected in {n_features_drifted}/{n_features} features "
            f"({n_features_drifted / max(1, n_features):.0%})"
        )
        if n_features_drifted == 0:
            # Only the multivariate method fired; say so instead of leaving a
            # bare "0/N features" message.
            interpretation += "; multivariate drift detected by domain classifier"
    else:
        interpretation = f"No drift detected across {n_features} features"

    computation_time = time.perf_counter() - start_time

    return DriftSummaryResult(
        feature_results=feature_results,
        domain_classifier_result=domain_classifier_result,
        n_features=n_features,
        n_features_drifted=n_features_drifted,
        drifted_features=drifted_features,
        overall_drifted=overall_drifted,
        consensus_threshold=consensus_threshold,
        methods_used=methods,
        univariate_methods=univariate_methods,
        multivariate_methods=multivariate_methods,
        interpretation=interpretation,
        computation_time=computation_time,
    )