ml4t-diagnostic 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (242)
  1. ml4t/diagnostic/AGENT.md +25 -0
  2. ml4t/diagnostic/__init__.py +166 -0
  3. ml4t/diagnostic/backends/__init__.py +10 -0
  4. ml4t/diagnostic/backends/adapter.py +192 -0
  5. ml4t/diagnostic/backends/polars_backend.py +899 -0
  6. ml4t/diagnostic/caching/__init__.py +40 -0
  7. ml4t/diagnostic/caching/cache.py +331 -0
  8. ml4t/diagnostic/caching/decorators.py +131 -0
  9. ml4t/diagnostic/caching/smart_cache.py +339 -0
  10. ml4t/diagnostic/config/AGENT.md +24 -0
  11. ml4t/diagnostic/config/README.md +267 -0
  12. ml4t/diagnostic/config/__init__.py +219 -0
  13. ml4t/diagnostic/config/barrier_config.py +277 -0
  14. ml4t/diagnostic/config/base.py +301 -0
  15. ml4t/diagnostic/config/event_config.py +148 -0
  16. ml4t/diagnostic/config/feature_config.py +404 -0
  17. ml4t/diagnostic/config/multi_signal_config.py +55 -0
  18. ml4t/diagnostic/config/portfolio_config.py +215 -0
  19. ml4t/diagnostic/config/report_config.py +391 -0
  20. ml4t/diagnostic/config/sharpe_config.py +202 -0
  21. ml4t/diagnostic/config/signal_config.py +206 -0
  22. ml4t/diagnostic/config/trade_analysis_config.py +310 -0
  23. ml4t/diagnostic/config/validation.py +279 -0
  24. ml4t/diagnostic/core/__init__.py +29 -0
  25. ml4t/diagnostic/core/numba_utils.py +315 -0
  26. ml4t/diagnostic/core/purging.py +372 -0
  27. ml4t/diagnostic/core/sampling.py +471 -0
  28. ml4t/diagnostic/errors/__init__.py +205 -0
  29. ml4t/diagnostic/evaluation/AGENT.md +26 -0
  30. ml4t/diagnostic/evaluation/__init__.py +437 -0
  31. ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
  32. ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
  33. ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
  34. ml4t/diagnostic/evaluation/dashboard.py +715 -0
  35. ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
  36. ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
  37. ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
  38. ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
  39. ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
  40. ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
  41. ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
  42. ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
  43. ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
  44. ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
  45. ml4t/diagnostic/evaluation/event_analysis.py +647 -0
  46. ml4t/diagnostic/evaluation/excursion.py +390 -0
  47. ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
  48. ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
  49. ml4t/diagnostic/evaluation/framework.py +935 -0
  50. ml4t/diagnostic/evaluation/metric_registry.py +255 -0
  51. ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
  52. ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
  53. ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
  54. ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
  55. ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
  56. ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
  57. ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
  58. ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
  59. ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
  60. ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
  61. ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
  62. ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
  63. ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
  64. ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
  65. ml4t/diagnostic/evaluation/multi_signal.py +550 -0
  66. ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
  67. ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
  68. ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
  69. ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
  70. ml4t/diagnostic/evaluation/report_generation.py +824 -0
  71. ml4t/diagnostic/evaluation/signal_selector.py +452 -0
  72. ml4t/diagnostic/evaluation/stat_registry.py +139 -0
  73. ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
  74. ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
  75. ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
  76. ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
  77. ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
  78. ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
  79. ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
  80. ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
  81. ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
  82. ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
  83. ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
  84. ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
  85. ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
  86. ml4t/diagnostic/evaluation/stats/moments.py +164 -0
  87. ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
  88. ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
  89. ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
  90. ml4t/diagnostic/evaluation/themes.py +330 -0
  91. ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
  92. ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
  93. ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
  94. ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
  95. ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
  96. ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
  97. ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
  98. ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
  99. ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
  100. ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
  101. ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
  102. ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
  103. ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
  104. ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
  105. ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
  106. ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
  107. ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
  108. ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
  109. ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
  110. ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
  111. ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
  112. ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
  113. ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
  114. ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
  115. ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
  116. ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
  117. ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
  118. ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
  119. ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
  120. ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
  121. ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
  122. ml4t/diagnostic/evaluation/validated_cv.py +535 -0
  123. ml4t/diagnostic/evaluation/visualization.py +1050 -0
  124. ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
  125. ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
  126. ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
  127. ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
  128. ml4t/diagnostic/integration/__init__.py +48 -0
  129. ml4t/diagnostic/integration/backtest_contract.py +671 -0
  130. ml4t/diagnostic/integration/data_contract.py +316 -0
  131. ml4t/diagnostic/integration/engineer_contract.py +226 -0
  132. ml4t/diagnostic/logging/__init__.py +77 -0
  133. ml4t/diagnostic/logging/logger.py +245 -0
  134. ml4t/diagnostic/logging/performance.py +234 -0
  135. ml4t/diagnostic/logging/progress.py +234 -0
  136. ml4t/diagnostic/logging/wandb.py +412 -0
  137. ml4t/diagnostic/metrics/__init__.py +9 -0
  138. ml4t/diagnostic/metrics/percentiles.py +128 -0
  139. ml4t/diagnostic/py.typed +1 -0
  140. ml4t/diagnostic/reporting/__init__.py +43 -0
  141. ml4t/diagnostic/reporting/base.py +130 -0
  142. ml4t/diagnostic/reporting/html_renderer.py +275 -0
  143. ml4t/diagnostic/reporting/json_renderer.py +51 -0
  144. ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
  145. ml4t/diagnostic/results/AGENT.md +24 -0
  146. ml4t/diagnostic/results/__init__.py +105 -0
  147. ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
  148. ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
  149. ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
  150. ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
  151. ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
  152. ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
  153. ml4t/diagnostic/results/barrier_results/validation.py +38 -0
  154. ml4t/diagnostic/results/base.py +177 -0
  155. ml4t/diagnostic/results/event_results.py +349 -0
  156. ml4t/diagnostic/results/feature_results.py +787 -0
  157. ml4t/diagnostic/results/multi_signal_results.py +431 -0
  158. ml4t/diagnostic/results/portfolio_results.py +281 -0
  159. ml4t/diagnostic/results/sharpe_results.py +448 -0
  160. ml4t/diagnostic/results/signal_results/__init__.py +74 -0
  161. ml4t/diagnostic/results/signal_results/ic.py +581 -0
  162. ml4t/diagnostic/results/signal_results/irtc.py +110 -0
  163. ml4t/diagnostic/results/signal_results/quantile.py +392 -0
  164. ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
  165. ml4t/diagnostic/results/signal_results/turnover.py +213 -0
  166. ml4t/diagnostic/results/signal_results/validation.py +147 -0
  167. ml4t/diagnostic/signal/AGENT.md +17 -0
  168. ml4t/diagnostic/signal/__init__.py +69 -0
  169. ml4t/diagnostic/signal/_report.py +152 -0
  170. ml4t/diagnostic/signal/_utils.py +261 -0
  171. ml4t/diagnostic/signal/core.py +275 -0
  172. ml4t/diagnostic/signal/quantile.py +148 -0
  173. ml4t/diagnostic/signal/result.py +214 -0
  174. ml4t/diagnostic/signal/signal_ic.py +129 -0
  175. ml4t/diagnostic/signal/turnover.py +182 -0
  176. ml4t/diagnostic/splitters/AGENT.md +19 -0
  177. ml4t/diagnostic/splitters/__init__.py +36 -0
  178. ml4t/diagnostic/splitters/base.py +501 -0
  179. ml4t/diagnostic/splitters/calendar.py +421 -0
  180. ml4t/diagnostic/splitters/calendar_config.py +91 -0
  181. ml4t/diagnostic/splitters/combinatorial.py +1064 -0
  182. ml4t/diagnostic/splitters/config.py +322 -0
  183. ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
  184. ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
  185. ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
  186. ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
  187. ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
  188. ml4t/diagnostic/splitters/group_isolation.py +329 -0
  189. ml4t/diagnostic/splitters/persistence.py +316 -0
  190. ml4t/diagnostic/splitters/utils.py +207 -0
  191. ml4t/diagnostic/splitters/walk_forward.py +757 -0
  192. ml4t/diagnostic/utils/__init__.py +42 -0
  193. ml4t/diagnostic/utils/config.py +542 -0
  194. ml4t/diagnostic/utils/dependencies.py +318 -0
  195. ml4t/diagnostic/utils/sessions.py +127 -0
  196. ml4t/diagnostic/validation/__init__.py +54 -0
  197. ml4t/diagnostic/validation/dataframe.py +274 -0
  198. ml4t/diagnostic/validation/returns.py +280 -0
  199. ml4t/diagnostic/validation/timeseries.py +299 -0
  200. ml4t/diagnostic/visualization/AGENT.md +19 -0
  201. ml4t/diagnostic/visualization/__init__.py +223 -0
  202. ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
  203. ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
  204. ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
  205. ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
  206. ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
  207. ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
  208. ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
  209. ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
  210. ml4t/diagnostic/visualization/barrier_plots.py +782 -0
  211. ml4t/diagnostic/visualization/core.py +1060 -0
  212. ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
  213. ml4t/diagnostic/visualization/dashboards/base.py +582 -0
  214. ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
  215. ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
  216. ml4t/diagnostic/visualization/dashboards.py +43 -0
  217. ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
  218. ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
  219. ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
  220. ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
  221. ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
  222. ml4t/diagnostic/visualization/feature_plots.py +888 -0
  223. ml4t/diagnostic/visualization/interaction_plots.py +618 -0
  224. ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
  225. ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
  226. ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
  227. ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
  228. ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
  229. ml4t/diagnostic/visualization/report_generation.py +1343 -0
  230. ml4t/diagnostic/visualization/signal/__init__.py +103 -0
  231. ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
  232. ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
  233. ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
  234. ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
  235. ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
  236. ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
  237. ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
  238. ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
  239. ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
  240. ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
  241. ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
  242. ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
ml4t/diagnostic/evaluation/drift/domain_classifier.py
@@ -0,0 +1,517 @@
+"""Domain classifier for multivariate distribution drift detection.
+
+The domain classifier trains a binary model to distinguish reference (label=0)
+from test (label=1) samples. AUC indicates drift magnitude, feature importances
+show which features drifted.
+
+Advantages:
+- Detects multivariate drift and feature interactions
+- Non-parametric (no distributional assumptions)
+- Interpretable via feature importance
+- Sensitive to subtle multivariate shifts
+
+AUC Interpretation:
+- AUC ≈ 0.5: No drift (random guess)
+- AUC = 0.6: Weak drift
+- AUC = 0.7-0.8: Moderate drift
+- AUC > 0.9: Strong drift
+
+References:
+- Lopez-Paz, D., & Oquab, M. (2017). Revisiting Classifier Two-Sample Tests.
+  ICLR 2017.
+- Rabanser, S., et al. (2019). Failing Loudly: An Empirical Study of Methods
+  for Detecting Dataset Shift. NeurIPS 2019.
+"""
+
+from __future__ import annotations
+
+import time
+from dataclasses import dataclass
+from typing import Any
+
+import numpy as np
+import pandas as pd
+import polars as pl
+
+# Lazy check for optional ML dependencies (imported on first use to avoid slow startup)
+LIGHTGBM_AVAILABLE: bool | None = None
+XGBOOST_AVAILABLE: bool | None = None
+
+
+def _check_lightgbm_available() -> bool:
+    """Check if lightgbm is available (lazy check)."""
+    global LIGHTGBM_AVAILABLE
+    if LIGHTGBM_AVAILABLE is None:
+        try:
+            import lightgbm  # noqa: F401
+
+            LIGHTGBM_AVAILABLE = True
+        except ImportError:
+            LIGHTGBM_AVAILABLE = False
+    return LIGHTGBM_AVAILABLE
+
+
+def _check_xgboost_available() -> bool:
+    """Check if xgboost is available (lazy check)."""
+    global XGBOOST_AVAILABLE
+    if XGBOOST_AVAILABLE is None:
+        try:
+            import xgboost  # noqa: F401
+
+            XGBOOST_AVAILABLE = True
+        except ImportError:
+            XGBOOST_AVAILABLE = False
+    return XGBOOST_AVAILABLE
+
+
+@dataclass
+class DomainClassifierResult:
+    """Result of domain classifier drift detection.
+
+    Domain classifier trains a binary model to distinguish reference (label=0)
+    from test (label=1) samples. AUC indicates drift magnitude, feature importances
+    show which features drifted.
+
+    Attributes:
+        auc: AUC-ROC score (0.5 = no drift, 1.0 = complete distribution shift)
+        drifted: Whether drift was detected (auc > threshold)
+        feature_importances: DataFrame with feature, importance, rank columns
+        threshold: AUC threshold used for drift detection
+        n_reference: Number of samples in reference distribution
+        n_test: Number of samples in test distribution
+        n_features: Number of features used
+        model_type: Type of classifier used (lightgbm, xgboost, sklearn)
+        cv_auc_mean: Mean AUC from cross-validation
+        cv_auc_std: Std of AUC from cross-validation
+        interpretation: Human-readable interpretation
+        computation_time: Time taken to compute (seconds)
+        metadata: Additional metadata
+    """
+
+    auc: float
+    drifted: bool
+    feature_importances: pl.DataFrame
+    threshold: float
+    n_reference: int
+    n_test: int
+    n_features: int
+    model_type: str
+    cv_auc_mean: float
+    cv_auc_std: float
+    interpretation: str
+    computation_time: float
+    metadata: dict[str, Any]
+
+    def summary(self) -> str:
+        """Return formatted summary of domain classifier results."""
+        lines = [
+            "Domain Classifier Drift Detection Report",
+            "=" * 60,
+            f"AUC-ROC: {self.auc:.4f} (CV: {self.cv_auc_mean:.4f} ± {self.cv_auc_std:.4f})",
+            f"Drift Detected: {'YES' if self.drifted else 'NO'}",
+            f"Threshold: {self.threshold:.4f}",
+            "",
+            "Sample Sizes:",
+            f" Reference: {self.n_reference:,}",
+            f" Test: {self.n_test:,}",
+            "",
+            f"Model: {self.model_type}",
+            f"Features: {self.n_features}",
+            "",
+            "Top 5 Most Drifted Features:",
+            "-" * 60,
+        ]
+
+        # Show top 5 features
+        top_features = self.feature_importances.head(5)
+        for row in top_features.iter_rows(named=True):
+            lines.append(
+                f" {row['rank']:2d}. {row['feature']:30s} (importance: {row['importance']:.4f})"
+            )
+
+        lines.extend(
+            [
+                "",
+                f"Interpretation: {self.interpretation}",
+                "",
+                f"Computation Time: {self.computation_time:.3f}s",
+            ]
+        )
+
+        return "\n".join(lines)
+
+
+def compute_domain_classifier_drift(
+    reference: np.ndarray | pd.DataFrame | pl.DataFrame,
+    test: np.ndarray | pd.DataFrame | pl.DataFrame,
+    features: list[str] | None = None,
+    *,
+    model_type: str = "lightgbm",
+    n_estimators: int = 100,
+    max_depth: int = 5,
+    threshold: float = 0.6,
+    cv_folds: int = 5,
+    random_state: int = 42,
+) -> DomainClassifierResult:
+    """Detect distribution drift using domain classifier.
+
+    Trains a binary classifier to distinguish reference (label=0) from test (label=1)
+    samples. AUC-ROC indicates drift magnitude, feature importance shows which features
+    drifted most.
+
+    The domain classifier approach detects multivariate drift by testing whether
+    a classifier can distinguish between two distributions. If AUC ≈ 0.5, the
+    distributions are indistinguishable (no drift). If AUC → 1.0, the distributions
+    are completely separated (strong drift).
+
+    **Advantages**:
+    - Detects multivariate drift and feature interactions
+    - Non-parametric (no distributional assumptions)
+    - Interpretable via feature importance
+    - Sensitive to subtle multivariate shifts
+
+    **AUC Interpretation**:
+    - AUC ≈ 0.5: No drift (random guess)
+    - AUC = 0.6: Weak drift
+    - AUC = 0.7-0.8: Moderate drift
+    - AUC > 0.9: Strong drift
+
+    Args:
+        reference: Reference distribution (e.g., training data).
+            Can be numpy array, pandas DataFrame, or polars DataFrame.
+        test: Test distribution (e.g., production data).
+            Can be numpy array, pandas DataFrame, or polars DataFrame.
+        features: List of feature names to use. If None, uses all numeric columns.
+            Only applicable for DataFrame inputs.
+        model_type: Classifier type. Options:
+            - "lightgbm": LightGBM (default, fastest)
+            - "xgboost": XGBoost
+            - "sklearn": sklearn RandomForestClassifier (always available)
+        n_estimators: Number of trees/estimators (default: 100)
+        max_depth: Maximum tree depth (default: 5)
+        threshold: AUC threshold for flagging drift (default: 0.6)
+        cv_folds: Number of cross-validation folds (default: 5)
+        random_state: Random seed for reproducibility (default: 42)
+
+    Returns:
+        DomainClassifierResult with AUC, feature importances, drift flag, etc.
+
+    Raises:
+        ValueError: If inputs are invalid or model_type is unknown
+        ImportError: If required ML library is not installed
+
+    Example:
+        >>> import numpy as np
+        >>> import polars as pl
+        >>> from ml4t.diagnostic.evaluation.drift import compute_domain_classifier_drift
+        >>>
+        >>> # No drift (identical distributions)
+        >>> np.random.seed(42)
+        >>> ref = pl.DataFrame({
+        ...     "x1": np.random.normal(0, 1, 500),
+        ...     "x2": np.random.normal(0, 1, 500),
+        ... })
+        >>> test = pl.DataFrame({
+        ...     "x1": np.random.normal(0, 1, 500),
+        ...     "x2": np.random.normal(0, 1, 500),
+        ... })
+        >>> result = compute_domain_classifier_drift(ref, test)
+        >>> print(f"AUC: {result.auc:.4f}, Drifted: {result.drifted}")
+        AUC: 0.5123, Drifted: False
+        >>>
+        >>> # Strong drift (mean shift)
+        >>> test_shifted = pl.DataFrame({
+        ...     "x1": np.random.normal(2, 1, 500),
+        ...     "x2": np.random.normal(2, 1, 500),
+        ... })
+        >>> result = compute_domain_classifier_drift(ref, test_shifted)
+        >>> print(f"AUC: {result.auc:.4f}, Drifted: {result.drifted}")
+        AUC: 0.9876, Drifted: True
+        >>> print(result.summary())
+        >>>
+        >>> # Interaction-based drift
+        >>> test_corr = pl.DataFrame({
+        ...     "x1": np.random.normal(0, 1, 500),
+        ...     "x2": np.random.normal(0, 1, 500) + 0.8 * np.random.normal(0, 1, 500),
+        ... })
+        >>> result = compute_domain_classifier_drift(ref, test_corr)
+        >>> # Will detect correlation change via feature interactions
+
+    References:
+        - Lopez-Paz, D., & Oquab, M. (2017). Revisiting Classifier Two-Sample Tests.
+          ICLR 2017.
+        - Rabanser, S., et al. (2019). Failing Loudly: An Empirical Study of Methods
+          for Detecting Dataset Shift. NeurIPS 2019.
+    """
+    start_time = time.time()
+
+    # Prepare data
+    X, y, feature_names = _prepare_domain_classification_data(reference, test, features)
+
+    # Train classifier with cross-validation
+    model, cv_scores = _train_domain_classifier(
+        X,
+        y,
+        model_type=model_type,
+        n_estimators=n_estimators,
+        max_depth=max_depth,
+        cv_folds=cv_folds,
+        random_state=random_state,
+    )
+
+    # Extract feature importances
+    importances_df = _extract_feature_importances(model, feature_names)
+
+    # Compute final AUC on full data
+    from sklearn.metrics import roc_auc_score
+
+    y_pred_proba = model.predict_proba(X)[:, 1]
+    final_auc = float(roc_auc_score(y, y_pred_proba))
+
+    # Determine drift status
+    drifted = final_auc > threshold
+
+    # Generate interpretation
+    cv_auc_mean = float(np.mean(cv_scores))
+    cv_auc_std = float(np.std(cv_scores))
+
+    if drifted:
+        if final_auc > 0.9:
+            severity = "strong"
+        elif final_auc > 0.7:
+            severity = "moderate"
+        else:
+            severity = "weak"
+
+        interpretation = (
+            f"{severity.capitalize()} distribution drift detected "
+            f"(AUC={final_auc:.4f} > {threshold:.4f}). "
+            f"The classifier can distinguish reference from test distributions. "
+            f"Top drifted feature: {importances_df['feature'][0]}."
+        )
+    else:
+        interpretation = (
+            f"No significant drift detected (AUC={final_auc:.4f} ≤ {threshold:.4f}). "
+            f"Distributions are indistinguishable by the classifier."
+        )
+
+    computation_time = time.time() - start_time
+
+    return DomainClassifierResult(
+        auc=final_auc,
+        drifted=drifted,
+        feature_importances=importances_df,
+        threshold=threshold,
+        n_reference=int(np.sum(y == 0)),
+        n_test=int(np.sum(y == 1)),
+        n_features=len(feature_names),
+        model_type=model_type,
+        cv_auc_mean=cv_auc_mean,
+        cv_auc_std=cv_auc_std,
+        interpretation=interpretation,
+        computation_time=computation_time,
+        metadata={
+            "n_estimators": n_estimators,
+            "max_depth": max_depth,
+            "cv_folds": cv_folds,
+            "random_state": random_state,
+        },
+    )
+
+
+def _prepare_domain_classification_data(
+    reference: np.ndarray | pd.DataFrame | pl.DataFrame,
+    test: np.ndarray | pd.DataFrame | pl.DataFrame,
+    features: list[str] | None = None,
+) -> tuple[np.ndarray, np.ndarray, list[str]]:
+    """Prepare labeled dataset for domain classification.
+
+    Args:
+        reference: Reference distribution
+        test: Test distribution
+        features: Feature names to use (for DataFrames)
+
+    Returns:
+        Tuple of (X, y, feature_names):
+        - X: Feature matrix (reference + test concatenated)
+        - y: Labels (0 for reference, 1 for test)
+        - feature_names: List of feature names
+
+    Raises:
+        ValueError: If inputs are invalid or incompatible
+    """
+    # Convert to numpy arrays
+    if isinstance(reference, pl.DataFrame):
+        if features is None:
+            # Use all numeric columns
+            features = [
+                c
+                for c in reference.columns
+                if reference[c].dtype
+                in (pl.Float64, pl.Float32, pl.Int64, pl.Int32, pl.Int16, pl.Int8)
+            ]
+        X_ref = reference[features].to_numpy()
+        feature_names = features
+
+    elif isinstance(reference, pd.DataFrame):
+        if features is None:
+            # Use all numeric columns
+            features = list(reference.select_dtypes(include=[np.number]).columns)
+        X_ref = reference[features].to_numpy()
+        feature_names = features
+
+    elif isinstance(reference, np.ndarray):
+        X_ref = reference
+        if features is None:
+            # Generate default feature names
+            if X_ref.ndim == 1:
+                X_ref = X_ref.reshape(-1, 1)
+            feature_names = [f"feature_{i}" for i in range(X_ref.shape[1])]
+        else:
+            feature_names = features
+
+    else:
+        raise ValueError(
+            f"Unsupported reference type: {type(reference)}. "
+            "Must be numpy array, pandas DataFrame, or polars DataFrame."
+        )
+
+    # Process test data
+    if isinstance(test, pl.DataFrame | pd.DataFrame):
+        X_test = test[feature_names].to_numpy()
+    elif isinstance(test, np.ndarray):
+        X_test = test
+        if X_test.ndim == 1:
+            X_test = X_test.reshape(-1, 1)
+    else:
+        raise ValueError(
+            f"Unsupported test type: {type(test)}. Must be numpy array, pandas DataFrame, or polars DataFrame."
+        )
+
+    # Validate shapes
+    if X_ref.shape[1] != X_test.shape[1]:
+        raise ValueError(
+            f"Feature count mismatch: reference has {X_ref.shape[1]} features, test has {X_test.shape[1]} features."
+        )
+
+    # Concatenate and create labels
+    X = np.vstack([X_ref, X_test])
+    y = np.concatenate([np.zeros(len(X_ref)), np.ones(len(X_test))])
+
+    return X, y, feature_names
+
+
+def _train_domain_classifier(
+    X: np.ndarray,
+    y: np.ndarray,
+    model_type: str = "lightgbm",
+    n_estimators: int = 100,
+    max_depth: int = 5,
+    cv_folds: int = 5,
+    random_state: int = 42,
+) -> tuple[Any, np.ndarray]:
+    """Train binary classifier for domain classification.
+
+    Args:
+        X: Feature matrix
+        y: Labels (0=reference, 1=test)
+        model_type: Classifier type
+        n_estimators: Number of trees
+        max_depth: Maximum tree depth
+        cv_folds: Cross-validation folds
+        random_state: Random seed
+
+    Returns:
+        Tuple of (trained_model, cv_auc_scores)
+
+    Raises:
+        ValueError: If model_type is unknown
+        ImportError: If required library is not installed
+    """
+    from sklearn.model_selection import cross_val_score
+
+    # Select and configure model
+    if model_type == "lightgbm":
+        if not _check_lightgbm_available():
+            raise ImportError(
+                "LightGBM required for domain classifier drift detection. "
+                "Install with: pip install ml4t-diagnostic[ml] or pip install lightgbm"
+            )
+
+        import lightgbm as lgb
+
+        model = lgb.LGBMClassifier(
+            n_estimators=n_estimators,
+            max_depth=max_depth,
+            random_state=random_state,
+            verbose=-1,
+            force_col_wise=True,  # Suppress warning
+        )
+
+    elif model_type == "xgboost":
+        if not _check_xgboost_available():
+            raise ImportError(
+                "XGBoost required for domain classifier drift detection. Install with: pip install xgboost"
+            )
+
+        import xgboost as xgb
+
+        model = xgb.XGBClassifier(
+            n_estimators=n_estimators,
+            max_depth=max_depth,
+            random_state=random_state,
+            verbosity=0,
+        )
+
+    elif model_type == "sklearn":
+        from sklearn.ensemble import RandomForestClassifier
+
+        model = RandomForestClassifier(
+            n_estimators=n_estimators,
+            max_depth=max_depth,
+            random_state=random_state,
+        )
+
+    else:
+        raise ValueError(
+            f"Unknown model_type: '{model_type}'. Must be 'lightgbm', 'xgboost', or 'sklearn'."
+        )
+
+    # Cross-validation for AUC
+    cv_scores = cross_val_score(model, X, y, cv=cv_folds, scoring="roc_auc")
+
+    # Train on full data
+    model.fit(X, y)
+
+    return model, cv_scores
+
+
+def _extract_feature_importances(model: Any, feature_names: list[str]) -> pl.DataFrame:
+    """Extract and rank feature importances.
+
+    Args:
+        model: Trained model with feature_importances_ attribute
+        feature_names: List of feature names
+
+    Returns:
+        Polars DataFrame with columns: feature, importance, rank
+
+    Raises:
+        ValueError: If model doesn't have feature importances
+    """
+    # Get importances (works for LightGBM, XGBoost, sklearn)
+    if hasattr(model, "feature_importances_"):
+        importances = model.feature_importances_
+    else:
+        raise ValueError(f"Model type {type(model)} does not have feature_importances_ attribute")
+
+    # Create DataFrame
+    df = pl.DataFrame({"feature": feature_names, "importance": importances})
+
+    # Sort by importance (descending)
+    df = df.sort("importance", descending=True)
+
+    # Add rank
+    df = df.with_columns(pl.arange(1, len(df) + 1).alias("rank"))
+
+    return df
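
For orientation, the drift detector added in this file can be exercised end-to-end without the optional LightGBM/XGBoost extras by selecting the sklearn backend (a plain RandomForestClassifier, per the docstring above). A minimal sketch, assuming numpy, polars, and scikit-learn are installed and that compute_domain_classifier_drift is re-exported from ml4t.diagnostic.evaluation.drift as the docstring's own example shows:

import numpy as np
import polars as pl

from ml4t.diagnostic.evaluation.drift import compute_domain_classifier_drift

rng = np.random.default_rng(0)

# Reference window vs. a test window whose second feature has drifted in mean
ref = pl.DataFrame({"x1": rng.normal(0, 1, 1000), "x2": rng.normal(0, 1, 1000)})
test = pl.DataFrame({"x1": rng.normal(0, 1, 1000), "x2": rng.normal(0.5, 1, 1000)})

# model_type="sklearn" avoids the optional LightGBM/XGBoost dependencies
result = compute_domain_classifier_drift(ref, test, model_type="sklearn", cv_folds=3)

print(result.summary())                    # formatted report, including CV AUC
print(result.feature_importances.head(3))  # features ranked by drift contribution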