ml4t-diagnostic 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (242) hide show
  1. ml4t/diagnostic/AGENT.md +25 -0
  2. ml4t/diagnostic/__init__.py +166 -0
  3. ml4t/diagnostic/backends/__init__.py +10 -0
  4. ml4t/diagnostic/backends/adapter.py +192 -0
  5. ml4t/diagnostic/backends/polars_backend.py +899 -0
  6. ml4t/diagnostic/caching/__init__.py +40 -0
  7. ml4t/diagnostic/caching/cache.py +331 -0
  8. ml4t/diagnostic/caching/decorators.py +131 -0
  9. ml4t/diagnostic/caching/smart_cache.py +339 -0
  10. ml4t/diagnostic/config/AGENT.md +24 -0
  11. ml4t/diagnostic/config/README.md +267 -0
  12. ml4t/diagnostic/config/__init__.py +219 -0
  13. ml4t/diagnostic/config/barrier_config.py +277 -0
  14. ml4t/diagnostic/config/base.py +301 -0
  15. ml4t/diagnostic/config/event_config.py +148 -0
  16. ml4t/diagnostic/config/feature_config.py +404 -0
  17. ml4t/diagnostic/config/multi_signal_config.py +55 -0
  18. ml4t/diagnostic/config/portfolio_config.py +215 -0
  19. ml4t/diagnostic/config/report_config.py +391 -0
  20. ml4t/diagnostic/config/sharpe_config.py +202 -0
  21. ml4t/diagnostic/config/signal_config.py +206 -0
  22. ml4t/diagnostic/config/trade_analysis_config.py +310 -0
  23. ml4t/diagnostic/config/validation.py +279 -0
  24. ml4t/diagnostic/core/__init__.py +29 -0
  25. ml4t/diagnostic/core/numba_utils.py +315 -0
  26. ml4t/diagnostic/core/purging.py +372 -0
  27. ml4t/diagnostic/core/sampling.py +471 -0
  28. ml4t/diagnostic/errors/__init__.py +205 -0
  29. ml4t/diagnostic/evaluation/AGENT.md +26 -0
  30. ml4t/diagnostic/evaluation/__init__.py +437 -0
  31. ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
  32. ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
  33. ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
  34. ml4t/diagnostic/evaluation/dashboard.py +715 -0
  35. ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
  36. ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
  37. ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
  38. ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
  39. ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
  40. ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
  41. ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
  42. ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
  43. ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
  44. ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
  45. ml4t/diagnostic/evaluation/event_analysis.py +647 -0
  46. ml4t/diagnostic/evaluation/excursion.py +390 -0
  47. ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
  48. ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
  49. ml4t/diagnostic/evaluation/framework.py +935 -0
  50. ml4t/diagnostic/evaluation/metric_registry.py +255 -0
  51. ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
  52. ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
  53. ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
  54. ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
  55. ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
  56. ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
  57. ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
  58. ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
  59. ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
  60. ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
  61. ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
  62. ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
  63. ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
  64. ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
  65. ml4t/diagnostic/evaluation/multi_signal.py +550 -0
  66. ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
  67. ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
  68. ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
  69. ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
  70. ml4t/diagnostic/evaluation/report_generation.py +824 -0
  71. ml4t/diagnostic/evaluation/signal_selector.py +452 -0
  72. ml4t/diagnostic/evaluation/stat_registry.py +139 -0
  73. ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
  74. ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
  75. ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
  76. ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
  77. ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
  78. ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
  79. ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
  80. ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
  81. ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
  82. ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
  83. ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
  84. ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
  85. ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
  86. ml4t/diagnostic/evaluation/stats/moments.py +164 -0
  87. ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
  88. ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
  89. ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
  90. ml4t/diagnostic/evaluation/themes.py +330 -0
  91. ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
  92. ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
  93. ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
  94. ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
  95. ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
  96. ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
  97. ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
  98. ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
  99. ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
  100. ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
  101. ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
  102. ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
  103. ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
  104. ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
  105. ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
  106. ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
  107. ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
  108. ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
  109. ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
  110. ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
  111. ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
  112. ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
  113. ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
  114. ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
  115. ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
  116. ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
  117. ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
  118. ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
  119. ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
  120. ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
  121. ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
  122. ml4t/diagnostic/evaluation/validated_cv.py +535 -0
  123. ml4t/diagnostic/evaluation/visualization.py +1050 -0
  124. ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
  125. ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
  126. ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
  127. ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
  128. ml4t/diagnostic/integration/__init__.py +48 -0
  129. ml4t/diagnostic/integration/backtest_contract.py +671 -0
  130. ml4t/diagnostic/integration/data_contract.py +316 -0
  131. ml4t/diagnostic/integration/engineer_contract.py +226 -0
  132. ml4t/diagnostic/logging/__init__.py +77 -0
  133. ml4t/diagnostic/logging/logger.py +245 -0
  134. ml4t/diagnostic/logging/performance.py +234 -0
  135. ml4t/diagnostic/logging/progress.py +234 -0
  136. ml4t/diagnostic/logging/wandb.py +412 -0
  137. ml4t/diagnostic/metrics/__init__.py +9 -0
  138. ml4t/diagnostic/metrics/percentiles.py +128 -0
  139. ml4t/diagnostic/py.typed +1 -0
  140. ml4t/diagnostic/reporting/__init__.py +43 -0
  141. ml4t/diagnostic/reporting/base.py +130 -0
  142. ml4t/diagnostic/reporting/html_renderer.py +275 -0
  143. ml4t/diagnostic/reporting/json_renderer.py +51 -0
  144. ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
  145. ml4t/diagnostic/results/AGENT.md +24 -0
  146. ml4t/diagnostic/results/__init__.py +105 -0
  147. ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
  148. ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
  149. ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
  150. ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
  151. ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
  152. ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
  153. ml4t/diagnostic/results/barrier_results/validation.py +38 -0
  154. ml4t/diagnostic/results/base.py +177 -0
  155. ml4t/diagnostic/results/event_results.py +349 -0
  156. ml4t/diagnostic/results/feature_results.py +787 -0
  157. ml4t/diagnostic/results/multi_signal_results.py +431 -0
  158. ml4t/diagnostic/results/portfolio_results.py +281 -0
  159. ml4t/diagnostic/results/sharpe_results.py +448 -0
  160. ml4t/diagnostic/results/signal_results/__init__.py +74 -0
  161. ml4t/diagnostic/results/signal_results/ic.py +581 -0
  162. ml4t/diagnostic/results/signal_results/irtc.py +110 -0
  163. ml4t/diagnostic/results/signal_results/quantile.py +392 -0
  164. ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
  165. ml4t/diagnostic/results/signal_results/turnover.py +213 -0
  166. ml4t/diagnostic/results/signal_results/validation.py +147 -0
  167. ml4t/diagnostic/signal/AGENT.md +17 -0
  168. ml4t/diagnostic/signal/__init__.py +69 -0
  169. ml4t/diagnostic/signal/_report.py +152 -0
  170. ml4t/diagnostic/signal/_utils.py +261 -0
  171. ml4t/diagnostic/signal/core.py +275 -0
  172. ml4t/diagnostic/signal/quantile.py +148 -0
  173. ml4t/diagnostic/signal/result.py +214 -0
  174. ml4t/diagnostic/signal/signal_ic.py +129 -0
  175. ml4t/diagnostic/signal/turnover.py +182 -0
  176. ml4t/diagnostic/splitters/AGENT.md +19 -0
  177. ml4t/diagnostic/splitters/__init__.py +36 -0
  178. ml4t/diagnostic/splitters/base.py +501 -0
  179. ml4t/diagnostic/splitters/calendar.py +421 -0
  180. ml4t/diagnostic/splitters/calendar_config.py +91 -0
  181. ml4t/diagnostic/splitters/combinatorial.py +1064 -0
  182. ml4t/diagnostic/splitters/config.py +322 -0
  183. ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
  184. ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
  185. ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
  186. ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
  187. ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
  188. ml4t/diagnostic/splitters/group_isolation.py +329 -0
  189. ml4t/diagnostic/splitters/persistence.py +316 -0
  190. ml4t/diagnostic/splitters/utils.py +207 -0
  191. ml4t/diagnostic/splitters/walk_forward.py +757 -0
  192. ml4t/diagnostic/utils/__init__.py +42 -0
  193. ml4t/diagnostic/utils/config.py +542 -0
  194. ml4t/diagnostic/utils/dependencies.py +318 -0
  195. ml4t/diagnostic/utils/sessions.py +127 -0
  196. ml4t/diagnostic/validation/__init__.py +54 -0
  197. ml4t/diagnostic/validation/dataframe.py +274 -0
  198. ml4t/diagnostic/validation/returns.py +280 -0
  199. ml4t/diagnostic/validation/timeseries.py +299 -0
  200. ml4t/diagnostic/visualization/AGENT.md +19 -0
  201. ml4t/diagnostic/visualization/__init__.py +223 -0
  202. ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
  203. ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
  204. ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
  205. ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
  206. ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
  207. ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
  208. ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
  209. ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
  210. ml4t/diagnostic/visualization/barrier_plots.py +782 -0
  211. ml4t/diagnostic/visualization/core.py +1060 -0
  212. ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
  213. ml4t/diagnostic/visualization/dashboards/base.py +582 -0
  214. ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
  215. ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
  216. ml4t/diagnostic/visualization/dashboards.py +43 -0
  217. ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
  218. ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
  219. ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
  220. ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
  221. ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
  222. ml4t/diagnostic/visualization/feature_plots.py +888 -0
  223. ml4t/diagnostic/visualization/interaction_plots.py +618 -0
  224. ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
  225. ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
  226. ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
  227. ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
  228. ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
  229. ml4t/diagnostic/visualization/report_generation.py +1343 -0
  230. ml4t/diagnostic/visualization/signal/__init__.py +103 -0
  231. ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
  232. ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
  233. ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
  234. ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
  235. ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
  236. ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
  237. ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
  238. ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
  239. ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
  240. ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
  241. ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
  242. ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,470 @@
1
+ """Normality tests for distribution analysis.
2
+
3
+ This module provides statistical tests for normality:
4
+ - Jarque-Bera test: Based on sample skewness and kurtosis, asymptotically valid
5
+ - Shapiro-Wilk test: More powerful for small samples (n < 2000), recommended
6
+
7
+ Test Comparison:
8
+ - Jarque-Bera: Based on sample skewness and kurtosis, asymptotically valid
9
+ - Shapiro-Wilk: More powerful for small samples (n < 2000), recommended
10
+
11
+ References:
12
+ - Jarque, C. M., & Bera, A. K. (1980). Efficient tests for normality,
13
+ homoscedasticity and serial independence of regression residuals.
14
+ Economics Letters, 6(3), 255-259. DOI: 10.1016/0165-1765(80)90024-5
15
+ - Shapiro, S. S., & Wilk, M. B. (1965). An analysis of variance test
16
+ for normality (complete samples). Biometrika, 52(3-4), 591-611.
17
+ DOI: 10.1093/biomet/52.3-4.591
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ from dataclasses import dataclass
23
+
24
+ import numpy as np
25
+ import pandas as pd
26
+ from scipy import stats
27
+
28
+ from ml4t.diagnostic.errors import ComputationError, ValidationError
29
+ from ml4t.diagnostic.logging import get_logger
30
+
31
+ logger = get_logger(__name__)
32
+
33
+
34
@dataclass
class JarqueBeraResult:
    """Outcome of a Jarque-Bera normality test.

    The test compares sample skewness S and excess kurtosis K against their
    normal-distribution values via JB = (n/6) * (S^2 + K^2/4); under the
    null hypothesis of normality JB follows a χ²(2) distribution.

    Attributes:
        statistic: Jarque-Bera test statistic
        p_value: P-value for null hypothesis (data is normally distributed)
        skewness: Sample skewness used in test
        excess_kurtosis: Sample excess kurtosis used in test (Fisher: normal=0)
        is_normal: Whether data is consistent with normality (p >= alpha)
        n_obs: Number of observations
        alpha: Significance level used
    """

    statistic: float
    p_value: float
    skewness: float
    excess_kurtosis: float
    is_normal: bool
    n_obs: int
    alpha: float = 0.05

    def __repr__(self) -> str:
        """Concise one-line representation."""
        return (
            f"JarqueBeraResult(statistic={self.statistic:.4f}, "
            f"p_value={self.p_value:.4f}, is_normal={self.is_normal})"
        )

    def summary(self) -> str:
        """Human-readable summary of Jarque-Bera test.

        Returns:
            Formatted summary string
        """
        level = f"{self.alpha * 100:.0f}"

        header = [
            "Jarque-Bera Normality Test",
            "=" * 50,
            f"Test Statistic: {self.statistic:.4f}",
            f"P-value: {self.p_value:.4f}",
            f"Observations: {self.n_obs}",
            f"Significance: α={self.alpha}",
        ]

        moments = [
            "",
            "Moments:",
            f" Skewness: {self.skewness:.4f}",
            f" Excess Kurtosis: {self.excess_kurtosis:.4f}",
        ]

        if self.is_normal:
            verdict = [
                "",
                "Conclusion: Data is consistent with normality",
                f" (Fail to reject H0 at {level}% level)",
            ]
        else:
            verdict = [
                "",
                "Conclusion: Data deviates from normality",
                f" (Reject H0 at {level}% level)",
            ]

        methodology = [
            "",
            "Test Methodology:",
            " - JB = (n/6) * (S² + K²/4)",
            " - H0: Data is normally distributed",
            " - Under H0: JB ~ χ²(2)",
            " - Asymptotically valid (requires large n)",
        ]

        sections = header + moments + verdict + methodology

        # Practical guidance only shown when normality was rejected.
        if not self.is_normal:
            sections += [
                "",
                "Implications:",
                " - Normal distribution assumption violated",
                " - Consider robust statistical methods",
                " - Account for non-normality in risk models",
            ]

        return "\n".join(sections)
112
+
113
+
114
@dataclass
class ShapiroWilkResult:
    """Outcome of a Shapiro-Wilk normality test.

    The W statistic measures the correlation between the ordered sample and
    the expected normal order statistics; it lies in [0, 1], with values
    near 1 consistent with normality. Preferred over Jarque-Bera for small
    samples (n < 2000).

    Attributes:
        statistic: Shapiro-Wilk test statistic (W)
        p_value: P-value for null hypothesis (data is normally distributed)
        is_normal: Whether data is consistent with normality (p >= alpha)
        n_obs: Number of observations
        alpha: Significance level used
    """

    statistic: float
    p_value: float
    is_normal: bool
    n_obs: int
    alpha: float = 0.05

    def __repr__(self) -> str:
        """Concise one-line representation."""
        return (
            f"ShapiroWilkResult(statistic={self.statistic:.4f}, "
            f"p_value={self.p_value:.4f}, is_normal={self.is_normal})"
        )

    def summary(self) -> str:
        """Human-readable summary of Shapiro-Wilk test.

        Returns:
            Formatted summary string
        """
        level = f"{self.alpha * 100:.0f}"

        header = [
            "Shapiro-Wilk Normality Test",
            "=" * 50,
            f"Test Statistic (W): {self.statistic:.4f}",
            f"P-value: {self.p_value:.4f}",
            f"Observations: {self.n_obs}",
            f"Significance: α={self.alpha}",
        ]

        if self.is_normal:
            verdict = [
                "",
                "Conclusion: Data is consistent with normality",
                f" (Fail to reject H0 at {level}% level)",
            ]
        else:
            verdict = [
                "",
                "Conclusion: Data deviates from normality",
                f" (Reject H0 at {level}% level)",
            ]

        methodology = [
            "",
            "Test Methodology:",
            " - Based on correlation between data and normal scores",
            " - W statistic ranges from 0 (non-normal) to 1 (normal)",
            " - H0: Data is normally distributed",
            " - More powerful than Jarque-Bera for small samples",
            " - Recommended for n < 2000",
        ]

        sections = header + verdict + methodology

        # Practical guidance only shown when normality was rejected.
        if not self.is_normal:
            sections += [
                "",
                "Implications:",
                " - Normal distribution assumption violated",
                " - Consider non-parametric methods",
                " - Use robust estimators for inference",
            ]

        return "\n".join(sections)
184
+
185
+
186
def jarque_bera_test(
    data: pd.Series | np.ndarray,
    alpha: float = 0.05,
) -> JarqueBeraResult:
    """Run the Jarque-Bera normality test on a 1-D sample.

    The statistic JB = (n/6) * (S^2 + K^2/4) compares sample skewness S and
    excess kurtosis K against their values under normality (both zero).
    Under H0 (normality), JB is asymptotically χ²(2) distributed; a p-value
    below ``alpha`` rejects normality.

    Args:
        data: One-dimensional sample (pandas Series or numpy array)
        alpha: Significance level (default 0.05)

    Returns:
        JarqueBeraResult with the statistic, p-value, sample moments, and
        the normality conclusion.

    Raises:
        ValidationError: If data is None, the wrong type, not 1-D, empty,
            contains NaN/inf, has fewer than 20 observations, or is constant.
        ComputationError: If the underlying scipy computation fails.

    Example:
        >>> import numpy as np
        >>> normal = np.random.normal(0, 1, 1000)
        >>> result = jarque_bera_test(normal)
        >>> print(f"p-value: {result.p_value:.4f}, normal: {result.is_normal}")

    Notes:
        - Asymptotic test: prefer shapiro_wilk_test for n < 2000
        - Delegates to scipy.stats.jarque_bera
    """

    def _validated_array(raw: pd.Series | np.ndarray) -> np.ndarray:
        """Coerce input to a finite 1-D ndarray, raising ValidationError otherwise."""
        if raw is None:
            raise ValidationError("Data cannot be None", context={"function": "jarque_bera_test"})

        if isinstance(raw, pd.Series):
            values = raw.to_numpy()
        elif isinstance(raw, np.ndarray):
            values = raw
        else:
            raise ValidationError(
                f"Data must be pandas Series or numpy array, got {type(raw)}",
                context={"function": "jarque_bera_test", "data_type": type(raw).__name__},
            )

        if values.ndim != 1:
            raise ValidationError(
                f"Data must be 1-dimensional, got {values.ndim}D",
                context={"function": "jarque_bera_test", "shape": values.shape},
            )

        if len(values) == 0:
            raise ValidationError(
                "Data cannot be empty", context={"function": "jarque_bera_test", "length": 0}
            )

        # Reject NaN/inf up front: scipy would propagate NaN silently.
        invalid_mask = ~np.isfinite(values)
        if np.any(invalid_mask):
            n_invalid = np.sum(invalid_mask)
            raise ValidationError(
                f"Data contains {n_invalid} NaN or infinite values",
                context={
                    "function": "jarque_bera_test",
                    "n_invalid": n_invalid,
                    "length": len(values),
                },
            )

        # JB is asymptotic; below this size the χ²(2) approximation is poor.
        min_length = 20
        if len(values) < min_length:
            raise ValidationError(
                f"Insufficient data for Jarque-Bera test (need at least {min_length} observations)",
                context={
                    "function": "jarque_bera_test",
                    "length": len(values),
                    "min_length": min_length,
                },
            )

        # Zero variance makes skewness/kurtosis undefined.
        if np.std(values) == 0:
            raise ValidationError(
                "Data is constant (zero variance)",
                context={
                    "function": "jarque_bera_test",
                    "length": len(values),
                    "mean": float(np.mean(values)),
                },
            )

        return values

    arr = _validated_array(data)

    logger.info("Running Jarque-Bera test", n_obs=len(arr), alpha=alpha)

    try:
        # scipy returns (statistic, p_value).
        jb_stat, p_value = stats.jarque_bera(arr)

        # Bias-corrected moments reported alongside the test result.
        skewness = float(stats.skew(arr, bias=False))
        excess_kurtosis = float(stats.kurtosis(arr, bias=False))

        is_normal = p_value >= alpha

        logger.info(
            "Jarque-Bera test completed",
            statistic=jb_stat,
            p_value=p_value,
            is_normal=is_normal,
        )

        return JarqueBeraResult(
            statistic=float(jb_stat),
            p_value=float(p_value),
            skewness=skewness,
            excess_kurtosis=excess_kurtosis,
            is_normal=is_normal,
            n_obs=len(arr),
            alpha=alpha,
        )

    except Exception as e:
        logger.error("Jarque-Bera test failed", error=str(e), n_obs=len(arr))
        raise ComputationError(  # noqa: B904
            f"Jarque-Bera test computation failed: {e}",
            context={"function": "jarque_bera_test", "n_obs": len(arr), "alpha": alpha},
            cause=e,
        )
328
+
329
+
330
def shapiro_wilk_test(
    data: pd.Series | np.ndarray,
    alpha: float = 0.05,
) -> ShapiroWilkResult:
    """Run the Shapiro-Wilk normality test on a 1-D sample.

    The W statistic measures the correlation between the ordered sample and
    expected normal order statistics; W near 1 is consistent with normality.
    A p-value below ``alpha`` rejects the null hypothesis of normality.

    Args:
        data: One-dimensional sample (pandas Series or numpy array)
        alpha: Significance level (default 0.05)

    Returns:
        ShapiroWilkResult with the W statistic, p-value, and conclusion.

    Raises:
        ValidationError: If data is None, the wrong type, not 1-D, empty,
            contains NaN/inf, has fewer than 3 observations, or is constant.
        ComputationError: If the underlying scipy computation fails.

    Example:
        >>> import numpy as np
        >>> normal = np.random.normal(0, 1, 500)
        >>> result = shapiro_wilk_test(normal)
        >>> print(f"W: {result.statistic:.4f}, p-value: {result.p_value:.4f}")

    Notes:
        - More powerful than Jarque-Bera for small samples (n < 2000)
        - Samples longer than 5000 are truncated to the first 5000
          observations (scipy.stats.shapiro limitation)
        - Delegates to scipy.stats.shapiro
    """

    def _validated_array(raw: pd.Series | np.ndarray) -> np.ndarray:
        """Coerce input to a finite 1-D ndarray of usable length, or raise."""
        if raw is None:
            raise ValidationError("Data cannot be None", context={"function": "shapiro_wilk_test"})

        if isinstance(raw, pd.Series):
            values = raw.to_numpy()
        elif isinstance(raw, np.ndarray):
            values = raw
        else:
            raise ValidationError(
                f"Data must be pandas Series or numpy array, got {type(raw)}",
                context={"function": "shapiro_wilk_test", "data_type": type(raw).__name__},
            )

        if values.ndim != 1:
            raise ValidationError(
                f"Data must be 1-dimensional, got {values.ndim}D",
                context={"function": "shapiro_wilk_test", "shape": values.shape},
            )

        if len(values) == 0:
            raise ValidationError(
                "Data cannot be empty", context={"function": "shapiro_wilk_test", "length": 0}
            )

        # Reject NaN/inf up front: scipy would propagate NaN silently.
        invalid_mask = ~np.isfinite(values)
        if np.any(invalid_mask):
            n_invalid = np.sum(invalid_mask)
            raise ValidationError(
                f"Data contains {n_invalid} NaN or infinite values",
                context={
                    "function": "shapiro_wilk_test",
                    "n_invalid": n_invalid,
                    "length": len(values),
                },
            )

        # Shapiro-Wilk is undefined below 3 observations.
        min_length = 3
        if len(values) < min_length:
            raise ValidationError(
                f"Insufficient data for Shapiro-Wilk test (need at least {min_length} observations)",
                context={
                    "function": "shapiro_wilk_test",
                    "length": len(values),
                    "min_length": min_length,
                },
            )

        # scipy's p-value accuracy degrades beyond 5000; truncate with a warning.
        max_length = 5000
        if len(values) > max_length:
            logger.warning(
                f"Data has {len(values)} observations, using first {max_length} (scipy.stats.shapiro limitation)"
            )
            values = values[:max_length]

        # Zero variance makes the test degenerate (checked post-truncation).
        if np.std(values) == 0:
            raise ValidationError(
                "Data is constant (zero variance)",
                context={
                    "function": "shapiro_wilk_test",
                    "length": len(values),
                    "mean": float(np.mean(values)),
                },
            )

        return values

    arr = _validated_array(data)

    logger.info("Running Shapiro-Wilk test", n_obs=len(arr), alpha=alpha)

    try:
        # scipy returns (statistic, p_value).
        w_stat, p_value = stats.shapiro(arr)

        is_normal = p_value >= alpha

        logger.info(
            "Shapiro-Wilk test completed",
            statistic=w_stat,
            p_value=p_value,
            is_normal=is_normal,
        )

        return ShapiroWilkResult(
            statistic=float(w_stat),
            p_value=float(p_value),
            is_normal=is_normal,
            n_obs=len(arr),
            alpha=alpha,
        )

    except Exception as e:
        logger.error("Shapiro-Wilk test failed", error=str(e), n_obs=len(arr))
        raise ComputationError(  # noqa: B904
            f"Shapiro-Wilk test computation failed: {e}",
            context={"function": "shapiro_wilk_test", "n_obs": len(arr), "alpha": alpha},
            cause=e,
        )
@@ -0,0 +1,139 @@
1
"""Distribution drift detection for feature monitoring.

This module provides comprehensive drift detection with three complementary methods
and a unified analysis interface:

**Individual Methods**:
- **PSI (Population Stability Index)**: Bin-based distribution comparison
- **Wasserstein Distance**: Optimal transport metric for continuous features
- **Domain Classifier**: ML-based multivariate drift detection with feature importance

**Unified Interface**:
- **analyze_drift()**: Multi-method drift analysis with consensus-based flagging

Distribution drift is critical for ML model monitoring:
- Feature distributions change over time (concept drift)
- Model performance degrades when test distribution differs from training
- Early detection allows proactive model retraining
- Multi-method consensus increases confidence in drift detection

PSI Interpretation:
- PSI < 0.1: No significant change (green)
- 0.1 ≤ PSI < 0.2: Small change, monitor (yellow)
- PSI ≥ 0.2: Significant change, investigate (red)

Wasserstein Distance Interpretation:
- W = 0: Identical distributions
- W > 0: Distribution drift detected
- Larger values indicate greater drift magnitude
- Threshold calibrated via permutation testing

Domain Classifier Interpretation:
- AUC ≈ 0.5: No drift (random guess between reference and test)
- AUC = 0.6: Weak drift
- AUC = 0.7-0.8: Moderate drift
- AUC > 0.9: Strong drift
- Feature importance identifies which features drifted

When to Use:
- **PSI**: Categorical features or when binning is acceptable
- **Wasserstein**: Continuous features, more sensitive to small shifts
- **Domain Classifier**: Multivariate drift, interaction detection
- **analyze_drift()**: Comprehensive analysis with multiple methods
- Model monitoring: Compare production data to training data
- Temporal drift: Compare recent data to historical baseline
- Segmentation drift: Compare distributions across segments

References:
    - Yurdakul, B. (2018). Statistical Properties of Population Stability Index.
      https://scholarship.richmond.edu/honors-theses/1131/
    - Webb, G. I., et al. (2016). Characterizing concept drift.
      Data Mining and Knowledge Discovery, 30(4), 964-994.
    - Villani, C. (2009). Optimal Transport: Old and New. Springer.
    - Ramdas, A., et al. (2017). On Wasserstein Two-Sample Testing and Related
      Families of Nonparametric Tests. Entropy, 19(2), 47.
    - Lopez-Paz, D., & Oquab, M. (2017). Revisiting Classifier Two-Sample Tests.
      ICLR 2017.
    - Rabanser, S., et al. (2019). Failing Loudly: An Empirical Study of Methods
      for Detecting Dataset Shift. NeurIPS 2019.

Example - Individual Methods:
    >>> import numpy as np
    >>> from ml4t.diagnostic.evaluation.drift import (
    ...     compute_psi, compute_wasserstein_distance, compute_domain_classifier_drift
    ... )
    >>>
    >>> # PSI for univariate drift
    >>> reference = np.random.normal(0, 1, 1000)
    >>> test = np.random.normal(0.5, 1, 1000)  # Mean shifted
    >>> psi_result = compute_psi(reference, test, n_bins=10)
    >>> print(f"PSI: {psi_result.psi:.4f}, Alert: {psi_result.alert_level}")
    >>>
    >>> # Wasserstein for continuous features
    >>> ws_result = compute_wasserstein_distance(reference, test)
    >>> print(f"Wasserstein: {ws_result.distance:.4f}, Drifted: {ws_result.drifted}")

Example - Unified Analysis:
    >>> import numpy as np
    >>> import pandas as pd
    >>> from ml4t.diagnostic.evaluation.drift import analyze_drift
    >>>
    >>> # Create reference and test datasets
    >>> reference = pd.DataFrame({
    ...     'feature1': np.random.normal(0, 1, 1000),
    ...     'feature2': np.random.normal(0, 1, 1000),
    ... })
    >>> test = pd.DataFrame({
    ...     'feature1': np.random.normal(0.5, 1, 1000),  # Drifted
    ...     'feature2': np.random.normal(0, 1, 1000),    # Stable
    ... })
    >>>
    >>> # Comprehensive drift analysis with all methods
    >>> result = analyze_drift(reference, test)
    >>> print(result.summary())
    >>> print(f"Drifted features: {result.drifted_features}")
    >>>
    >>> # Get detailed results as DataFrame
    >>> df = result.to_dataframe()
    >>> print(df)
    >>>
    >>> # Use specific methods only
    >>> result = analyze_drift(reference, test, methods=['psi', 'wasserstein'])
    >>>
    >>> # Customize consensus threshold (default: 0.5)
    >>> result = analyze_drift(reference, test, consensus_threshold=0.66)
"""

# Import from submodules and re-export so callers can use the short
# `ml4t.diagnostic.evaluation.drift` path instead of the full module paths.
from ml4t.diagnostic.evaluation.drift.analysis import (
    DriftSummaryResult,
    FeatureDriftResult,
    analyze_drift,
)
from ml4t.diagnostic.evaluation.drift.domain_classifier import (
    DomainClassifierResult,
    compute_domain_classifier_drift,
)
from ml4t.diagnostic.evaluation.drift.population_stability_index import (
    PSIResult,
    compute_psi,
)
from ml4t.diagnostic.evaluation.drift.wasserstein import (
    WassersteinResult,
    compute_wasserstein_distance,
)

# Public API of the drift subpackage, grouped by detection method.
__all__ = [
    # PSI
    "compute_psi",
    "PSIResult",
    # Wasserstein
    "compute_wasserstein_distance",
    "WassersteinResult",
    # Domain Classifier
    "compute_domain_classifier_drift",
    "DomainClassifierResult",
    # Unified analysis
    "analyze_drift",
    "FeatureDriftResult",
    "DriftSummaryResult",
]