ml4t-diagnostic 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (242) hide show
  1. ml4t/diagnostic/AGENT.md +25 -0
  2. ml4t/diagnostic/__init__.py +166 -0
  3. ml4t/diagnostic/backends/__init__.py +10 -0
  4. ml4t/diagnostic/backends/adapter.py +192 -0
  5. ml4t/diagnostic/backends/polars_backend.py +899 -0
  6. ml4t/diagnostic/caching/__init__.py +40 -0
  7. ml4t/diagnostic/caching/cache.py +331 -0
  8. ml4t/diagnostic/caching/decorators.py +131 -0
  9. ml4t/diagnostic/caching/smart_cache.py +339 -0
  10. ml4t/diagnostic/config/AGENT.md +24 -0
  11. ml4t/diagnostic/config/README.md +267 -0
  12. ml4t/diagnostic/config/__init__.py +219 -0
  13. ml4t/diagnostic/config/barrier_config.py +277 -0
  14. ml4t/diagnostic/config/base.py +301 -0
  15. ml4t/diagnostic/config/event_config.py +148 -0
  16. ml4t/diagnostic/config/feature_config.py +404 -0
  17. ml4t/diagnostic/config/multi_signal_config.py +55 -0
  18. ml4t/diagnostic/config/portfolio_config.py +215 -0
  19. ml4t/diagnostic/config/report_config.py +391 -0
  20. ml4t/diagnostic/config/sharpe_config.py +202 -0
  21. ml4t/diagnostic/config/signal_config.py +206 -0
  22. ml4t/diagnostic/config/trade_analysis_config.py +310 -0
  23. ml4t/diagnostic/config/validation.py +279 -0
  24. ml4t/diagnostic/core/__init__.py +29 -0
  25. ml4t/diagnostic/core/numba_utils.py +315 -0
  26. ml4t/diagnostic/core/purging.py +372 -0
  27. ml4t/diagnostic/core/sampling.py +471 -0
  28. ml4t/diagnostic/errors/__init__.py +205 -0
  29. ml4t/diagnostic/evaluation/AGENT.md +26 -0
  30. ml4t/diagnostic/evaluation/__init__.py +437 -0
  31. ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
  32. ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
  33. ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
  34. ml4t/diagnostic/evaluation/dashboard.py +715 -0
  35. ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
  36. ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
  37. ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
  38. ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
  39. ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
  40. ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
  41. ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
  42. ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
  43. ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
  44. ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
  45. ml4t/diagnostic/evaluation/event_analysis.py +647 -0
  46. ml4t/diagnostic/evaluation/excursion.py +390 -0
  47. ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
  48. ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
  49. ml4t/diagnostic/evaluation/framework.py +935 -0
  50. ml4t/diagnostic/evaluation/metric_registry.py +255 -0
  51. ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
  52. ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
  53. ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
  54. ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
  55. ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
  56. ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
  57. ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
  58. ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
  59. ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
  60. ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
  61. ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
  62. ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
  63. ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
  64. ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
  65. ml4t/diagnostic/evaluation/multi_signal.py +550 -0
  66. ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
  67. ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
  68. ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
  69. ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
  70. ml4t/diagnostic/evaluation/report_generation.py +824 -0
  71. ml4t/diagnostic/evaluation/signal_selector.py +452 -0
  72. ml4t/diagnostic/evaluation/stat_registry.py +139 -0
  73. ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
  74. ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
  75. ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
  76. ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
  77. ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
  78. ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
  79. ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
  80. ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
  81. ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
  82. ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
  83. ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
  84. ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
  85. ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
  86. ml4t/diagnostic/evaluation/stats/moments.py +164 -0
  87. ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
  88. ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
  89. ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
  90. ml4t/diagnostic/evaluation/themes.py +330 -0
  91. ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
  92. ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
  93. ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
  94. ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
  95. ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
  96. ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
  97. ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
  98. ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
  99. ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
  100. ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
  101. ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
  102. ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
  103. ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
  104. ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
  105. ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
  106. ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
  107. ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
  108. ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
  109. ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
  110. ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
  111. ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
  112. ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
  113. ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
  114. ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
  115. ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
  116. ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
  117. ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
  118. ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
  119. ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
  120. ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
  121. ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
  122. ml4t/diagnostic/evaluation/validated_cv.py +535 -0
  123. ml4t/diagnostic/evaluation/visualization.py +1050 -0
  124. ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
  125. ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
  126. ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
  127. ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
  128. ml4t/diagnostic/integration/__init__.py +48 -0
  129. ml4t/diagnostic/integration/backtest_contract.py +671 -0
  130. ml4t/diagnostic/integration/data_contract.py +316 -0
  131. ml4t/diagnostic/integration/engineer_contract.py +226 -0
  132. ml4t/diagnostic/logging/__init__.py +77 -0
  133. ml4t/diagnostic/logging/logger.py +245 -0
  134. ml4t/diagnostic/logging/performance.py +234 -0
  135. ml4t/diagnostic/logging/progress.py +234 -0
  136. ml4t/diagnostic/logging/wandb.py +412 -0
  137. ml4t/diagnostic/metrics/__init__.py +9 -0
  138. ml4t/diagnostic/metrics/percentiles.py +128 -0
  139. ml4t/diagnostic/py.typed +1 -0
  140. ml4t/diagnostic/reporting/__init__.py +43 -0
  141. ml4t/diagnostic/reporting/base.py +130 -0
  142. ml4t/diagnostic/reporting/html_renderer.py +275 -0
  143. ml4t/diagnostic/reporting/json_renderer.py +51 -0
  144. ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
  145. ml4t/diagnostic/results/AGENT.md +24 -0
  146. ml4t/diagnostic/results/__init__.py +105 -0
  147. ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
  148. ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
  149. ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
  150. ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
  151. ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
  152. ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
  153. ml4t/diagnostic/results/barrier_results/validation.py +38 -0
  154. ml4t/diagnostic/results/base.py +177 -0
  155. ml4t/diagnostic/results/event_results.py +349 -0
  156. ml4t/diagnostic/results/feature_results.py +787 -0
  157. ml4t/diagnostic/results/multi_signal_results.py +431 -0
  158. ml4t/diagnostic/results/portfolio_results.py +281 -0
  159. ml4t/diagnostic/results/sharpe_results.py +448 -0
  160. ml4t/diagnostic/results/signal_results/__init__.py +74 -0
  161. ml4t/diagnostic/results/signal_results/ic.py +581 -0
  162. ml4t/diagnostic/results/signal_results/irtc.py +110 -0
  163. ml4t/diagnostic/results/signal_results/quantile.py +392 -0
  164. ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
  165. ml4t/diagnostic/results/signal_results/turnover.py +213 -0
  166. ml4t/diagnostic/results/signal_results/validation.py +147 -0
  167. ml4t/diagnostic/signal/AGENT.md +17 -0
  168. ml4t/diagnostic/signal/__init__.py +69 -0
  169. ml4t/diagnostic/signal/_report.py +152 -0
  170. ml4t/diagnostic/signal/_utils.py +261 -0
  171. ml4t/diagnostic/signal/core.py +275 -0
  172. ml4t/diagnostic/signal/quantile.py +148 -0
  173. ml4t/diagnostic/signal/result.py +214 -0
  174. ml4t/diagnostic/signal/signal_ic.py +129 -0
  175. ml4t/diagnostic/signal/turnover.py +182 -0
  176. ml4t/diagnostic/splitters/AGENT.md +19 -0
  177. ml4t/diagnostic/splitters/__init__.py +36 -0
  178. ml4t/diagnostic/splitters/base.py +501 -0
  179. ml4t/diagnostic/splitters/calendar.py +421 -0
  180. ml4t/diagnostic/splitters/calendar_config.py +91 -0
  181. ml4t/diagnostic/splitters/combinatorial.py +1064 -0
  182. ml4t/diagnostic/splitters/config.py +322 -0
  183. ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
  184. ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
  185. ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
  186. ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
  187. ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
  188. ml4t/diagnostic/splitters/group_isolation.py +329 -0
  189. ml4t/diagnostic/splitters/persistence.py +316 -0
  190. ml4t/diagnostic/splitters/utils.py +207 -0
  191. ml4t/diagnostic/splitters/walk_forward.py +757 -0
  192. ml4t/diagnostic/utils/__init__.py +42 -0
  193. ml4t/diagnostic/utils/config.py +542 -0
  194. ml4t/diagnostic/utils/dependencies.py +318 -0
  195. ml4t/diagnostic/utils/sessions.py +127 -0
  196. ml4t/diagnostic/validation/__init__.py +54 -0
  197. ml4t/diagnostic/validation/dataframe.py +274 -0
  198. ml4t/diagnostic/validation/returns.py +280 -0
  199. ml4t/diagnostic/validation/timeseries.py +299 -0
  200. ml4t/diagnostic/visualization/AGENT.md +19 -0
  201. ml4t/diagnostic/visualization/__init__.py +223 -0
  202. ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
  203. ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
  204. ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
  205. ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
  206. ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
  207. ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
  208. ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
  209. ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
  210. ml4t/diagnostic/visualization/barrier_plots.py +782 -0
  211. ml4t/diagnostic/visualization/core.py +1060 -0
  212. ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
  213. ml4t/diagnostic/visualization/dashboards/base.py +582 -0
  214. ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
  215. ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
  216. ml4t/diagnostic/visualization/dashboards.py +43 -0
  217. ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
  218. ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
  219. ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
  220. ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
  221. ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
  222. ml4t/diagnostic/visualization/feature_plots.py +888 -0
  223. ml4t/diagnostic/visualization/interaction_plots.py +618 -0
  224. ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
  225. ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
  226. ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
  227. ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
  228. ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
  229. ml4t/diagnostic/visualization/report_generation.py +1343 -0
  230. ml4t/diagnostic/visualization/signal/__init__.py +103 -0
  231. ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
  232. ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
  233. ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
  234. ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
  235. ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
  236. ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
  237. ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
  238. ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
  239. ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
  240. ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
  241. ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
  242. ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,274 @@
1
+ """DataFrame validation utilities."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any
6
+
7
+ import polars as pl
8
+
9
+
10
+ class ValidationError(ValueError):
11
+ """Validation error with helpful context."""
12
+
13
+ def __init__(self, message: str, context: dict[str, Any] | None = None):
14
+ """Initialize validation error.
15
+
16
+ Args:
17
+ message: Error message
18
+ context: Additional context (columns, types, etc.)
19
+ """
20
+ self.context = context or {}
21
+ super().__init__(self._format_message(message))
22
+
23
+ def _format_message(self, message: str) -> str:
24
+ """Format error message with context."""
25
+ if not self.context:
26
+ return message
27
+
28
+ context_str = "\n".join(f" {k}: {v}" for k, v in self.context.items())
29
+ return f"{message}\nContext:\n{context_str}"
30
+
31
+
32
class DataFrameValidator:
    """Fluent validator for Polars DataFrames.

    Every check raises :class:`ValidationError` on failure and returns
    ``self`` on success, so checks can be chained.

    Examples:
        >>> validator = DataFrameValidator(df)
        >>> validator.require_columns(["close", "volume"])
        >>> validator.require_numeric(["close", "volume"])
        >>> validator.check_nulls(allow_nulls=False)
    """

    def __init__(self, df: pl.DataFrame):
        """Initialize validator.

        Args:
            df: DataFrame to validate
        """
        self.df = df

    def require_columns(self, columns: list[str]) -> DataFrameValidator:
        """Require specific columns to exist.

        Args:
            columns: Required column names

        Returns:
            Self for chaining

        Raises:
            ValidationError: If any required column is missing
        """
        available = set(self.df.columns)
        missing = [col for col in columns if col not in available]

        if missing:
            raise ValidationError(
                f"Missing required columns: {missing}",
                context={
                    "required": columns,
                    "available": self.df.columns,
                    "missing": missing,
                },
            )

        return self

    def require_numeric(self, columns: list[str]) -> DataFrameValidator:
        """Require columns to be numeric types.

        Columns that are not present in the DataFrame are skipped. Combine
        with :meth:`require_columns` to also enforce presence.

        Args:
            columns: Column names that must be numeric

        Returns:
            Self for chaining

        Raises:
            ValidationError: If any present column has a non-numeric dtype
        """
        available = set(self.df.columns)
        non_numeric: list[tuple[str, str]] = []

        for col in columns:
            if col not in available:
                continue

            dtype = self.df[col].dtype
            if not dtype.is_numeric():
                non_numeric.append((col, str(dtype)))

        if non_numeric:
            raise ValidationError(
                f"Non-numeric columns: {[col for col, _ in non_numeric]}",
                context={
                    "expected": "numeric types (Int*, Float*, Decimal)",
                    "actual": dict(non_numeric),
                },
            )

        return self

    def check_nulls(
        self, columns: list[str] | None = None, allow_nulls: bool = False
    ) -> DataFrameValidator:
        """Check for null values in columns.

        Only nulls are counted (via ``Series.null_count``); floating-point NaN
        values are not treated as nulls here.

        Args:
            columns: Columns to check. ``None`` means all columns. An explicit
                empty list checks nothing. Columns not present in the
                DataFrame are skipped.
            allow_nulls: Whether nulls are allowed

        Returns:
            Self for chaining

        Raises:
            ValidationError: If nulls found when not allowed
        """
        if allow_nulls:
            return self

        # Use an explicit None test: ``columns or ...`` would turn an empty
        # list into "check every column".
        check_columns = self.df.columns if columns is None else columns
        available = set(self.df.columns)
        null_counts: dict[str, int] = {}

        for col in check_columns:
            if col not in available:
                continue

            null_count = self.df[col].null_count()
            if null_count > 0:
                null_counts[col] = null_count

        if null_counts:
            total_nulls = sum(null_counts.values())
            raise ValidationError(
                f"Found {total_nulls} null values",
                context={
                    "null_columns": list(null_counts.keys()),
                    "null_counts": null_counts,
                },
            )

        return self

    def check_empty(self) -> DataFrameValidator:
        """Check if DataFrame is empty.

        Returns:
            Self for chaining

        Raises:
            ValidationError: If DataFrame has zero rows
        """
        if len(self.df) == 0:
            raise ValidationError(
                "DataFrame is empty",
                context={"shape": self.df.shape, "columns": self.df.columns},
            )

        return self

    def check_min_rows(self, min_rows: int) -> DataFrameValidator:
        """Check minimum number of rows.

        Args:
            min_rows: Minimum required rows

        Returns:
            Self for chaining

        Raises:
            ValidationError: If the DataFrame has fewer than ``min_rows`` rows
        """
        n_rows = len(self.df)
        if n_rows < min_rows:
            raise ValidationError(
                f"Insufficient rows: {n_rows} < {min_rows}",
                context={"required": min_rows, "actual": n_rows},
            )

        return self
185
+
186
+
187
def validate_dataframe(
    df: pl.DataFrame,
    required_columns: list[str] | None = None,
    numeric_columns: list[str] | None = None,
    allow_nulls: bool = True,
    min_rows: int = 1,
) -> None:
    """Validate DataFrame structure and content.

    Checks run in this order: row count, required columns, numeric dtypes,
    then nulls (across all columns). The first failing check raises.

    Args:
        df: DataFrame to validate
        required_columns: Columns that must exist
        numeric_columns: Columns that must be numeric. Columns listed here but
            absent from ``df`` are skipped unless also in ``required_columns``.
        allow_nulls: Whether null values are allowed
        min_rows: Minimum number of rows required. A value of ``0`` (or less)
            accepts an empty DataFrame. Any positive value rejects it.

    Raises:
        ValidationError: If validation fails

    Examples:
        >>> validate_dataframe(
        ...     df,
        ...     required_columns=["close", "volume"],
        ...     numeric_columns=["close", "volume"],
        ...     allow_nulls=False,
        ...     min_rows=100
        ... )
    """
    validator = DataFrameValidator(df)

    # The dedicated "empty" error is only meaningful when at least one row is
    # required. Running it unconditionally would make min_rows=0 impossible.
    if min_rows > 0:
        validator.check_empty()
    validator.check_min_rows(min_rows)

    if required_columns:
        validator.require_columns(required_columns)

    if numeric_columns:
        validator.require_numeric(numeric_columns)

    if not allow_nulls:
        validator.check_nulls(allow_nulls=False)
227
+
228
+
229
def _dtype_matches(actual: str, expected: str) -> bool:
    """Return True if the dtype string ``actual`` satisfies ``expected``.

    ``expected`` matches either the full dtype string (for example
    ``"Datetime(time_unit='us', time_zone=None)"``) or its base name, which is
    the text before any parameter list (``"Datetime"``, ``"List"``,
    ``"Decimal"``). Exact comparison avoids substring false positives such as
    ``"Int64"`` matching ``"UInt64"`` or ``"Date"`` matching ``"Datetime(...)"``.
    """
    if actual == expected:
        return True
    return actual.split("(", 1)[0] == expected


def validate_schema(df: pl.DataFrame, expected_schema: dict[str, str | type]) -> None:
    """Validate DataFrame schema matches expected types.

    Each expected type (a dtype name string or a Polars dtype class/instance)
    is compared by its string form against ``str(df[col].dtype)``. It matches
    when it equals the full dtype string or the dtype's base name, so
    ``"Datetime"`` / ``pl.Datetime`` accept any time unit. A parametrised
    expectation such as ``pl.Datetime("ms")`` must match exactly.

    Args:
        df: DataFrame to validate
        expected_schema: Map of column name to expected type

    Raises:
        ValidationError: If any column is missing or has a different dtype

    Examples:
        >>> validate_schema(df, {
        ...     "close": "Float64",
        ...     "volume": "Int64",
        ...     "date": pl.Date
        ... })
    """
    mismatches: dict[str, tuple[str, str | type]] = {}

    for col, expected_type in expected_schema.items():
        if col not in df.columns:
            mismatches[col] = ("missing", expected_type)
            continue

        expected_type_str = str(expected_type)
        actual_type_str = str(df[col].dtype)

        if not _dtype_matches(actual_type_str, expected_type_str):
            mismatches[col] = (actual_type_str, expected_type_str)

    if mismatches:
        raise ValidationError(
            "Schema mismatch",
            context={
                "mismatches": {
                    col: f"expected {exp}, got {act}" for col, (act, exp) in mismatches.items()
                }
            },
        )
@@ -0,0 +1,280 @@
1
+ """Returns validation utilities."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import SupportsFloat, cast
6
+
7
+ import polars as pl
8
+
9
+ from ml4t.diagnostic.validation.dataframe import ValidationError
10
+
11
+
12
class ReturnsValidator:
    """Fluent validator for a returns series.

    Accepts a Polars Series directly, or a DataFrame together with the name of
    the column that holds the returns. Every check raises
    :class:`ValidationError` on failure and returns ``self`` on success.

    Examples:
        >>> validator = ReturnsValidator(returns)
        >>> validator.check_numeric()
        >>> validator.check_bounds(-0.5, 0.5)
        >>> validator.check_distribution()
    """

    def __init__(self, returns: pl.Series | pl.DataFrame, column: str | None = None):
        """Initialize validator.

        Args:
            returns: Returns series or DataFrame
            column: Column name if DataFrame provided

        Raises:
            ValueError: If a DataFrame is given without ``column``
        """
        if isinstance(returns, pl.DataFrame):
            if column is None:
                raise ValueError("column required when passing DataFrame")
            self.returns = returns[column]
        else:
            self.returns = returns

    def check_numeric(self) -> ReturnsValidator:
        """Check that returns are numeric.

        Returns:
            Self for chaining

        Raises:
            ValidationError: If not numeric
        """
        if not self.returns.dtype.is_numeric():
            raise ValidationError(
                "Returns must be numeric",
                context={"dtype": str(self.returns.dtype)},
            )

        return self

    def check_bounds(
        self, lower: float | None = None, upper: float | None = None
    ) -> ReturnsValidator:
        """Check returns fall within bounds.

        Nulls are ignored. The lower bound is checked before the upper one.

        Args:
            lower: Lower bound (inclusive), or ``None`` to skip
            upper: Upper bound (inclusive), or ``None`` to skip

        Returns:
            Self for chaining

        Raises:
            ValidationError: If returns are not numeric or fall out of bounds
        """
        self.check_numeric()

        # Drop nulls for bounds checking
        clean_returns = self.returns.drop_nulls()

        if len(clean_returns) == 0:
            return self

        if lower is not None:
            min_val = float(cast(SupportsFloat, clean_returns.min()))
            if min_val < lower:
                out_of_bounds = (clean_returns < lower).sum()
                raise ValidationError(
                    f"Returns below lower bound: {min_val:.4f} < {lower}",
                    context={
                        "lower_bound": lower,
                        "min_value": min_val,
                        "count_out_of_bounds": out_of_bounds,
                    },
                )

        if upper is not None:
            max_val = float(cast(SupportsFloat, clean_returns.max()))
            if max_val > upper:
                out_of_bounds = (clean_returns > upper).sum()
                raise ValidationError(
                    f"Returns above upper bound: {max_val:.4f} > {upper}",
                    context={
                        "upper_bound": upper,
                        "max_value": max_val,
                        "count_out_of_bounds": out_of_bounds,
                    },
                )

        return self

    def check_finite(self) -> ReturnsValidator:
        """Check for infinite values (``inf`` / ``-inf``).

        Returns:
            Self for chaining

        Raises:
            ValidationError: If returns are not numeric or infinite values found
        """
        self.check_numeric()

        # Check for inf/-inf
        is_inf = self.returns.is_infinite()
        inf_count = is_inf.sum()

        if inf_count > 0:
            raise ValidationError(
                f"Found {inf_count} infinite values",
                context={"infinite_count": inf_count},
            )

        return self

    def check_nulls(self, allow_nulls: bool = False) -> ReturnsValidator:
        """Check for null values.

        Args:
            allow_nulls: Whether nulls are allowed

        Returns:
            Self for chaining

        Raises:
            ValidationError: If nulls found when not allowed
        """
        if not allow_nulls:
            null_count = self.returns.null_count()

            if null_count > 0:
                raise ValidationError(
                    f"Found {null_count} null values",
                    context={
                        "null_count": null_count,
                        "total_count": len(self.returns),
                    },
                )

        return self

    def check_distribution(
        self,
        max_abs_skew: float | None = None,
        max_abs_kurtosis: float | None = None,
        min_observations: int = 30,
    ) -> ReturnsValidator:
        """Check distribution characteristics.

        Uses the standard moment estimators on the non-null values:
        skewness ``g1 = m3 / m2**1.5`` and excess kurtosis
        ``g2 = m4 / m2**2 - 3``, where ``m_k`` is the k-th central moment with
        divisor n.

        Args:
            max_abs_skew: Maximum absolute skewness (None = no check)
            max_abs_kurtosis: Maximum absolute excess kurtosis (None = no check)
            min_observations: Minimum number of non-null values required before
                any distribution check is applied. With fewer values the
                checks are skipped.

        Returns:
            Self for chaining

        Raises:
            ValidationError: If returns are not numeric, or skewness / excess
                kurtosis exceed their limits (skewness is checked first)
        """
        self.check_numeric()

        clean_returns = self.returns.drop_nulls()

        if len(clean_returns) < min_observations:
            # Need sufficient data for distribution checks
            return self

        if max_abs_skew is None and max_abs_kurtosis is None:
            return self

        mean = float(cast(SupportsFloat, clean_returns.mean()))
        # ddof=0 so the std matches the divisor-n central moments below.
        # Polars' default std uses ddof=1, which would mix estimators and
        # systematically shrink both statistics.
        std = float(cast(SupportsFloat, clean_returns.std(ddof=0)))

        if not std > 0:
            # std is zero (constant series) or NaN: moments are undefined.
            return self

        deviations = clean_returns - mean

        if max_abs_skew is not None:
            skew = float(cast(SupportsFloat, (deviations**3).mean())) / (std**3)

            if abs(skew) > max_abs_skew:
                raise ValidationError(
                    f"Extreme skewness detected: {skew:.2f}",
                    context={
                        "skewness": skew,
                        "max_allowed": max_abs_skew,
                    },
                )

        if max_abs_kurtosis is not None:
            kurtosis = float(cast(SupportsFloat, (deviations**4).mean())) / (std**4) - 3

            if abs(kurtosis) > max_abs_kurtosis:
                raise ValidationError(
                    f"Extreme kurtosis detected: {kurtosis:.2f}",
                    context={
                        "kurtosis": kurtosis,
                        "max_allowed": max_abs_kurtosis,
                    },
                )

        return self
215
+
216
+
217
def validate_returns(
    returns: pl.Series | pl.DataFrame,
    column: str | None = None,
    bounds: tuple[float, float] | None = None,
    allow_nulls: bool = False,
    check_finite: bool = True,
) -> None:
    """Run the standard set of checks on a returns series.

    The checks run in this order, and the first failure raises: numeric dtype,
    nulls (unless allowed), infinite values (if requested), then bounds (if
    given).

    Args:
        returns: Returns series or DataFrame
        column: Column name if DataFrame
        bounds: (lower, upper) bounds for returns
        allow_nulls: Whether null values allowed
        check_finite: Whether to check for infinite values

    Raises:
        ValidationError: If validation fails

    Examples:
        >>> validate_returns(
        ...     returns,
        ...     bounds=(-0.5, 0.5),
        ...     allow_nulls=False,
        ...     check_finite=True
        ... )
    """
    checker = ReturnsValidator(returns, column).check_numeric()

    if not allow_nulls:
        checker = checker.check_nulls(allow_nulls=False)

    if check_finite:
        checker = checker.check_finite()

    if bounds is None:
        return

    low, high = bounds
    checker.check_bounds(low, high)
257
+
258
+
259
def validate_bounds(
    returns: pl.Series | pl.DataFrame,
    column: str | None = None,
    lower: float | None = None,
    upper: float | None = None,
) -> None:
    """Ensure every non-null return lies within ``[lower, upper]``.

    A bound left as ``None`` is not checked.

    Args:
        returns: Returns series or DataFrame
        column: Column name if DataFrame
        lower: Lower bound (inclusive)
        upper: Upper bound (inclusive)

    Raises:
        ValidationError: If returns out of bounds

    Examples:
        >>> validate_bounds(returns, lower=-1.0, upper=1.0)
    """
    checker = ReturnsValidator(returns, column)
    checker.check_numeric()
    checker.check_bounds(lower, upper)