ml4t_diagnostic-0.1.0a1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (242)
  1. ml4t/diagnostic/AGENT.md +25 -0
  2. ml4t/diagnostic/__init__.py +166 -0
  3. ml4t/diagnostic/backends/__init__.py +10 -0
  4. ml4t/diagnostic/backends/adapter.py +192 -0
  5. ml4t/diagnostic/backends/polars_backend.py +899 -0
  6. ml4t/diagnostic/caching/__init__.py +40 -0
  7. ml4t/diagnostic/caching/cache.py +331 -0
  8. ml4t/diagnostic/caching/decorators.py +131 -0
  9. ml4t/diagnostic/caching/smart_cache.py +339 -0
  10. ml4t/diagnostic/config/AGENT.md +24 -0
  11. ml4t/diagnostic/config/README.md +267 -0
  12. ml4t/diagnostic/config/__init__.py +219 -0
  13. ml4t/diagnostic/config/barrier_config.py +277 -0
  14. ml4t/diagnostic/config/base.py +301 -0
  15. ml4t/diagnostic/config/event_config.py +148 -0
  16. ml4t/diagnostic/config/feature_config.py +404 -0
  17. ml4t/diagnostic/config/multi_signal_config.py +55 -0
  18. ml4t/diagnostic/config/portfolio_config.py +215 -0
  19. ml4t/diagnostic/config/report_config.py +391 -0
  20. ml4t/diagnostic/config/sharpe_config.py +202 -0
  21. ml4t/diagnostic/config/signal_config.py +206 -0
  22. ml4t/diagnostic/config/trade_analysis_config.py +310 -0
  23. ml4t/diagnostic/config/validation.py +279 -0
  24. ml4t/diagnostic/core/__init__.py +29 -0
  25. ml4t/diagnostic/core/numba_utils.py +315 -0
  26. ml4t/diagnostic/core/purging.py +372 -0
  27. ml4t/diagnostic/core/sampling.py +471 -0
  28. ml4t/diagnostic/errors/__init__.py +205 -0
  29. ml4t/diagnostic/evaluation/AGENT.md +26 -0
  30. ml4t/diagnostic/evaluation/__init__.py +437 -0
  31. ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
  32. ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
  33. ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
  34. ml4t/diagnostic/evaluation/dashboard.py +715 -0
  35. ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
  36. ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
  37. ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
  38. ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
  39. ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
  40. ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
  41. ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
  42. ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
  43. ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
  44. ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
  45. ml4t/diagnostic/evaluation/event_analysis.py +647 -0
  46. ml4t/diagnostic/evaluation/excursion.py +390 -0
  47. ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
  48. ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
  49. ml4t/diagnostic/evaluation/framework.py +935 -0
  50. ml4t/diagnostic/evaluation/metric_registry.py +255 -0
  51. ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
  52. ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
  53. ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
  54. ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
  55. ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
  56. ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
  57. ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
  58. ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
  59. ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
  60. ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
  61. ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
  62. ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
  63. ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
  64. ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
  65. ml4t/diagnostic/evaluation/multi_signal.py +550 -0
  66. ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
  67. ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
  68. ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
  69. ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
  70. ml4t/diagnostic/evaluation/report_generation.py +824 -0
  71. ml4t/diagnostic/evaluation/signal_selector.py +452 -0
  72. ml4t/diagnostic/evaluation/stat_registry.py +139 -0
  73. ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
  74. ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
  75. ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
  76. ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
  77. ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
  78. ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
  79. ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
  80. ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
  81. ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
  82. ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
  83. ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
  84. ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
  85. ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
  86. ml4t/diagnostic/evaluation/stats/moments.py +164 -0
  87. ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
  88. ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
  89. ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
  90. ml4t/diagnostic/evaluation/themes.py +330 -0
  91. ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
  92. ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
  93. ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
  94. ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
  95. ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
  96. ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
  97. ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
  98. ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
  99. ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
  100. ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
  101. ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
  102. ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
  103. ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
  104. ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
  105. ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
  106. ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
  107. ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
  108. ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
  109. ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
  110. ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
  111. ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
  112. ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
  113. ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
  114. ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
  115. ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
  116. ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
  117. ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
  118. ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
  119. ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
  120. ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
  121. ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
  122. ml4t/diagnostic/evaluation/validated_cv.py +535 -0
  123. ml4t/diagnostic/evaluation/visualization.py +1050 -0
  124. ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
  125. ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
  126. ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
  127. ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
  128. ml4t/diagnostic/integration/__init__.py +48 -0
  129. ml4t/diagnostic/integration/backtest_contract.py +671 -0
  130. ml4t/diagnostic/integration/data_contract.py +316 -0
  131. ml4t/diagnostic/integration/engineer_contract.py +226 -0
  132. ml4t/diagnostic/logging/__init__.py +77 -0
  133. ml4t/diagnostic/logging/logger.py +245 -0
  134. ml4t/diagnostic/logging/performance.py +234 -0
  135. ml4t/diagnostic/logging/progress.py +234 -0
  136. ml4t/diagnostic/logging/wandb.py +412 -0
  137. ml4t/diagnostic/metrics/__init__.py +9 -0
  138. ml4t/diagnostic/metrics/percentiles.py +128 -0
  139. ml4t/diagnostic/py.typed +1 -0
  140. ml4t/diagnostic/reporting/__init__.py +43 -0
  141. ml4t/diagnostic/reporting/base.py +130 -0
  142. ml4t/diagnostic/reporting/html_renderer.py +275 -0
  143. ml4t/diagnostic/reporting/json_renderer.py +51 -0
  144. ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
  145. ml4t/diagnostic/results/AGENT.md +24 -0
  146. ml4t/diagnostic/results/__init__.py +105 -0
  147. ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
  148. ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
  149. ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
  150. ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
  151. ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
  152. ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
  153. ml4t/diagnostic/results/barrier_results/validation.py +38 -0
  154. ml4t/diagnostic/results/base.py +177 -0
  155. ml4t/diagnostic/results/event_results.py +349 -0
  156. ml4t/diagnostic/results/feature_results.py +787 -0
  157. ml4t/diagnostic/results/multi_signal_results.py +431 -0
  158. ml4t/diagnostic/results/portfolio_results.py +281 -0
  159. ml4t/diagnostic/results/sharpe_results.py +448 -0
  160. ml4t/diagnostic/results/signal_results/__init__.py +74 -0
  161. ml4t/diagnostic/results/signal_results/ic.py +581 -0
  162. ml4t/diagnostic/results/signal_results/irtc.py +110 -0
  163. ml4t/diagnostic/results/signal_results/quantile.py +392 -0
  164. ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
  165. ml4t/diagnostic/results/signal_results/turnover.py +213 -0
  166. ml4t/diagnostic/results/signal_results/validation.py +147 -0
  167. ml4t/diagnostic/signal/AGENT.md +17 -0
  168. ml4t/diagnostic/signal/__init__.py +69 -0
  169. ml4t/diagnostic/signal/_report.py +152 -0
  170. ml4t/diagnostic/signal/_utils.py +261 -0
  171. ml4t/diagnostic/signal/core.py +275 -0
  172. ml4t/diagnostic/signal/quantile.py +148 -0
  173. ml4t/diagnostic/signal/result.py +214 -0
  174. ml4t/diagnostic/signal/signal_ic.py +129 -0
  175. ml4t/diagnostic/signal/turnover.py +182 -0
  176. ml4t/diagnostic/splitters/AGENT.md +19 -0
  177. ml4t/diagnostic/splitters/__init__.py +36 -0
  178. ml4t/diagnostic/splitters/base.py +501 -0
  179. ml4t/diagnostic/splitters/calendar.py +421 -0
  180. ml4t/diagnostic/splitters/calendar_config.py +91 -0
  181. ml4t/diagnostic/splitters/combinatorial.py +1064 -0
  182. ml4t/diagnostic/splitters/config.py +322 -0
  183. ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
  184. ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
  185. ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
  186. ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
  187. ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
  188. ml4t/diagnostic/splitters/group_isolation.py +329 -0
  189. ml4t/diagnostic/splitters/persistence.py +316 -0
  190. ml4t/diagnostic/splitters/utils.py +207 -0
  191. ml4t/diagnostic/splitters/walk_forward.py +757 -0
  192. ml4t/diagnostic/utils/__init__.py +42 -0
  193. ml4t/diagnostic/utils/config.py +542 -0
  194. ml4t/diagnostic/utils/dependencies.py +318 -0
  195. ml4t/diagnostic/utils/sessions.py +127 -0
  196. ml4t/diagnostic/validation/__init__.py +54 -0
  197. ml4t/diagnostic/validation/dataframe.py +274 -0
  198. ml4t/diagnostic/validation/returns.py +280 -0
  199. ml4t/diagnostic/validation/timeseries.py +299 -0
  200. ml4t/diagnostic/visualization/AGENT.md +19 -0
  201. ml4t/diagnostic/visualization/__init__.py +223 -0
  202. ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
  203. ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
  204. ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
  205. ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
  206. ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
  207. ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
  208. ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
  209. ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
  210. ml4t/diagnostic/visualization/barrier_plots.py +782 -0
  211. ml4t/diagnostic/visualization/core.py +1060 -0
  212. ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
  213. ml4t/diagnostic/visualization/dashboards/base.py +582 -0
  214. ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
  215. ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
  216. ml4t/diagnostic/visualization/dashboards.py +43 -0
  217. ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
  218. ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
  219. ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
  220. ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
  221. ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
  222. ml4t/diagnostic/visualization/feature_plots.py +888 -0
  223. ml4t/diagnostic/visualization/interaction_plots.py +618 -0
  224. ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
  225. ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
  226. ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
  227. ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
  228. ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
  229. ml4t/diagnostic/visualization/report_generation.py +1343 -0
  230. ml4t/diagnostic/visualization/signal/__init__.py +103 -0
  231. ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
  232. ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
  233. ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
  234. ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
  235. ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
  236. ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
  237. ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
  238. ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
  239. ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
  240. ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
  241. ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
  242. ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
ml4t/diagnostic/logging/wandb.py
@@ -0,0 +1,412 @@
+ """Weights & Biases integration for experiment tracking.
+
+ This module provides hooks for logging ml4t-diagnostic experiments to W&B,
+ enabling tracking of evaluation metrics, hyperparameters, and
+ visualizations across experiments.
+ """
+
+ import numbers
+ import warnings
+ from typing import Any, SupportsFloat, cast
+
+ import numpy as np
+ import pandas as pd
+
+ try:
+     import wandb  # type: ignore[import-not-found,unused-ignore]
+
+     HAS_WANDB = True
+ except ImportError:
+     HAS_WANDB = False
+
+
+ class WandbLogger:
+     """Logger for Weights & Biases experiment tracking.
+
+     This class provides a unified interface for logging ml4t-diagnostic
+     experiments to W&B, handling initialization, metric logging,
+     and artifact management.
+     """
+
+     def __init__(
+         self,
+         project: str | None = None,
+         entity: str | None = None,
+         name: str | None = None,
+         config: dict[str, Any] | None = None,
+         tags: list[str] | None = None,
+         notes: str | None = None,
+         disabled: bool = False,
+     ):
+         """Initialize W&B logger.
+
+         Parameters
+         ----------
+         project : str, optional
+             W&B project name
+         entity : str, optional
+             W&B entity (team or username)
+         name : str, optional
+             Run name
+         config : dict, optional
+             Configuration dictionary to log
+         tags : list[str], optional
+             Tags for the run
+         notes : str, optional
+             Notes about the run
+         disabled : bool
+             If True, disables W&B logging
+         """
+         self.disabled = disabled or not HAS_WANDB
+         self.run = None
+
+         if self.disabled:
+             if not HAS_WANDB and not disabled:
+                 warnings.warn(
+                     "wandb not installed. Install with: pip install wandb",
+                     stacklevel=2,
+                 )
+             return
+
+         # Initialize W&B run
+         self.run = wandb.init(
+             project=project or "ml4t-diagnostic",
+             entity=entity,
+             name=name,
+             config=config,
+             tags=tags or [],
+             notes=notes,
+             reinit=True,
+         )
+
+     def log_config(self, config: dict[str, Any]) -> None:
+         """Log configuration parameters.
+
+         Parameters
+         ----------
+         config : dict
+             Configuration dictionary
+         """
+         if self.disabled or self.run is None:
+             return
+
+         # Flatten nested config for W&B
+         flat_config = self._flatten_dict(config)
+         wandb.config.update(flat_config)
+
+     def log_metrics(
+         self,
+         metrics: dict[str, Any],
+         step: int | None = None,
+         prefix: str = "",
+     ) -> None:
+         """Log evaluation metrics.
+
+         Parameters
+         ----------
+         metrics : dict
+             Metrics to log
+         step : int, optional
+             Step number (e.g., CV fold)
+         prefix : str
+             Prefix for metric names
+         """
+         if self.disabled or self.run is None:
+             return
+
+         # Prepare metrics for logging
+         log_dict = {}
+
+         for name, value in metrics.items():
+             key = f"{prefix}{name}" if prefix else name
+
+             if isinstance(value, dict):
+                 # Handle nested metrics (e.g., with confidence intervals)
+                 for sub_key, sub_value in value.items():
+                     if isinstance(sub_value, numbers.Number):
+                         log_dict[f"{key}/{sub_key}"] = float(cast(SupportsFloat, sub_value))
+             elif isinstance(value, numbers.Number):
+                 log_dict[key] = float(cast(SupportsFloat, value))
+             elif isinstance(value, list | np.ndarray):
+                 # Log array statistics
+                 if len(value) > 0:
+                     log_dict[f"{key}/mean"] = float(np.mean(value))
+                     log_dict[f"{key}/std"] = float(np.std(value))
+                     log_dict[f"{key}/min"] = float(np.min(value))
+                     log_dict[f"{key}/max"] = float(np.max(value))
+
+         if step is not None:
+             log_dict["step"] = step
+
+         wandb.log(log_dict)
+
+     def log_fold_results(
+         self,
+         fold_idx: int,
+         train_size: int,
+         test_size: int,
+         metrics: dict[str, Any],
+     ) -> None:
+         """Log results from a single CV fold.
+
+         Parameters
+         ----------
+         fold_idx : int
+             Fold index
+         train_size : int
+             Training set size
+         test_size : int
+             Test set size
+         metrics : dict
+             Fold metrics
+         """
+         if self.disabled or self.run is None:
+             return
+
+         # Add metrics with fold prefix
+         self.log_metrics(metrics, step=fold_idx, prefix="fold/")
+
+         # Log fold metadata
+         wandb.log(
+             {
+                 "fold/train_size": train_size,
+                 "fold/test_size": test_size,
+                 "fold/train_test_ratio": train_size / test_size if test_size > 0 else 0,
+             },
+             step=fold_idx,
+         )
+
+     def log_statistical_tests(self, tests: dict[str, Any]) -> None:
+         """Log statistical test results.
+
+         Parameters
+         ----------
+         tests : dict
+             Statistical test results
+         """
+         if self.disabled or self.run is None:
+             return
+
+         log_dict = {}
+
+         for test_name, result in tests.items():
+             if isinstance(result, dict):
+                 for key, value in result.items():
+                     if isinstance(value, numbers.Number):
+                         log_dict[f"stats/{test_name}/{key}"] = float(cast(SupportsFloat, value))
+                     elif key == "significant" and isinstance(value, bool):
+                         log_dict[f"stats/{test_name}/{key}"] = int(value)
+
+         wandb.log(log_dict)
+
+     def log_figure(
+         self,
+         figure: Any,
+         name: str,
+         step: int | None = None,
+     ) -> None:
+         """Log a Plotly figure.
+
+         Parameters
+         ----------
+         figure : plotly.graph_objects.Figure
+             Figure to log
+         name : str
+             Figure name
+         step : int, optional
+             Step number
+         """
+         if self.disabled or self.run is None:
+             return
+
+         # Convert Plotly figure to W&B
+         wandb.log({f"plots/{name}": figure}, step=step)
+
+     def log_evaluation_summary(
+         self,
+         result: Any,  # EvaluationResult
+         _predictions: Any | None = None,
+         _returns: Any | None = None,
+     ) -> None:
+         """Log complete evaluation summary.
+
+         Parameters
+         ----------
+         result : EvaluationResult
+             Evaluation result object
+         _predictions : array-like, optional
+             Predictions for additional logging (currently unused)
+         _returns : array-like, optional
+             Returns for additional logging (currently unused)
+         """
+         if self.disabled or self.run is None:
+             return
+
+         # Log summary metrics
+         summary = result.summary()
+
+         # Log aggregate metrics
+         self.log_metrics(summary["metrics"], prefix="summary/")
+
+         # Log statistical tests
+         if summary.get("statistical_tests"):
+             self.log_statistical_tests(summary["statistical_tests"])
+
+         # Log metadata
+         wandb.log(
+             {
+                 "summary/tier": result.tier,
+                 "summary/n_folds": summary["n_folds"],
+                 "summary/splitter": result.splitter_name,
+             },
+         )
+
+         # Create summary table
+         if result.fold_results:
+             fold_data = []
+             for fold in result.fold_results:
+                 fold_row = {"fold": fold.get("fold", 0)}
+                 fold_row.update(
+                     {k: v for k, v in fold.items() if isinstance(v, numbers.Number)},
+                 )
+                 fold_data.append(fold_row)
+
+             fold_table = wandb.Table(dataframe=pd.DataFrame(fold_data))
+             wandb.log({"tables/fold_results": fold_table})
+
+     def log_artifact(
+         self,
+         artifact_path: str,
+         name: str,
+         artifact_type: str = "evaluation",
+         metadata: dict[str, Any] | None = None,
+     ) -> None:
+         """Log an artifact (model, dataset, etc.).
+
+         Parameters
+         ----------
+         artifact_path : str
+             Path to artifact file
+         name : str
+             Artifact name
+         artifact_type : str
+             Type of artifact
+         metadata : dict, optional
+             Additional metadata
+         """
+         if self.disabled or self.run is None:
+             return
+
+         artifact = wandb.Artifact(
+             name=name,
+             type=artifact_type,
+             metadata=metadata or {},
+         )
+         artifact.add_file(artifact_path)
+         wandb.log_artifact(artifact)
+
+     def finish(self) -> None:
+         """Finish the W&B run."""
+         if self.disabled or self.run is None:
+             return
+
+         wandb.finish()
+
+     def __enter__(self):
+         """Context manager entry."""
+         return self
+
+     def __exit__(self, exc_type, exc_val, exc_tb):
+         """Context manager exit."""
+         self.finish()
+
+     @staticmethod
+     def _flatten_dict(
+         d: dict[str, Any],
+         parent_key: str = "",
+         sep: str = "/",
+     ) -> dict[str, Any]:
+         """Flatten nested dictionary."""
+         items: list[tuple[str, Any]] = []
+
+         for k, v in d.items():
+             new_key = f"{parent_key}{sep}{k}" if parent_key else k
+
+             if isinstance(v, dict):
+                 items.extend(WandbLogger._flatten_dict(v, new_key, sep=sep).items())
+             else:
+                 items.append((new_key, v))
+
+         return dict(items)
+
+
+ def log_experiment(
+     evaluator: Any,
+     X: Any,
+     y: Any,
+     model: Any,
+     project: str | None = None,
+     config: dict[str, Any] | None = None,
+     tags: list[str] | None = None,
+     **kwargs: Any,
+ ) -> Any:
+     """Convenience function to run and log an experiment.
+
+     Parameters
+     ----------
+     evaluator : Evaluator
+         Configured ml4t-diagnostic evaluator
+     X : array-like
+         Features
+     y : array-like
+         Labels
+     model : estimator
+         Model to evaluate
+     project : str, optional
+         W&B project name
+     config : dict, optional
+         Additional config to log
+     tags : list[str], optional
+         Experiment tags
+     **kwargs : Any
+         Additional arguments passed to evaluate()
+
+     Returns
+     -------
+     EvaluationResult
+         Result with W&B logging
+     """
+     if not HAS_WANDB:
+         warnings.warn(
+             "wandb not installed. Running without logging. Install with: pip install wandb",
+             stacklevel=2,
+         )
+         return evaluator.evaluate(X, y, model, **kwargs)
+
+     # Initialize logger
+     with WandbLogger(project=project, config=config, tags=tags) as logger:
+         # Log evaluator configuration
+         logger.log_config(
+             {
+                 "evaluator": {
+                     "tier": evaluator.tier,
+                     "splitter": evaluator.splitter.__class__.__name__,
+                     "metrics": evaluator.metrics,
+                     "statistical_tests": evaluator.statistical_tests,
+                     "confidence_level": evaluator.confidence_level,
+                     "bootstrap_samples": evaluator.bootstrap_samples,
+                 },
+             },
+         )
+
+         # Log model info if available
+         if hasattr(model, "get_params"):
+             logger.log_config({"model": model.get_params()})
+
+         # Run evaluation
+         result = evaluator.evaluate(X, y, model, **kwargs)
+
+         # Log results
+         logger.log_evaluation_summary(result)
+
+     return result
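
A minimal usage sketch (not part of the packaged code) of the WandbLogger class above, used as a context manager to log per-fold cross-validation metrics. The metric names, sizes, and values are hypothetical placeholders; only the WandbLogger API shown in the diff is assumed. If wandb is not installed, the logger disables itself and every call becomes a no-op.

```python
# Hypothetical example: per-fold logging with WandbLogger (values are made up).
from ml4t.diagnostic.logging.wandb import WandbLogger

fold_metrics = [
    {"ic": 0.042, "sharpe": 1.31},
    {"ic": 0.037, "sharpe": 1.18},
]

with WandbLogger(project="ml4t-diagnostic", tags=["demo"]) as logger:
    # Nested config is flattened to keys like "splitter/n_splits"
    logger.log_config({"splitter": {"n_splits": len(fold_metrics), "embargo_days": 5}})
    for i, metrics in enumerate(fold_metrics):
        # Train/test sizes are illustrative numbers only
        logger.log_fold_results(fold_idx=i, train_size=10_000, test_size=2_000, metrics=metrics)
    # A list value is logged as mean/std/min/max under the prefixed key
    logger.log_metrics({"ic": [m["ic"] for m in fold_metrics]}, prefix="summary/")
```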
ml4t/diagnostic/metrics/__init__.py
@@ -0,0 +1,9 @@
+ """
+ Metrics module for ML4T Diagnostic.
+
+ Provides statistical metrics and percentile computation utilities for model evaluation.
+ """
+
+ from ml4t.diagnostic.metrics.percentiles import compute_fold_percentiles
+
+ __all__ = ["compute_fold_percentiles"]
ml4t/diagnostic/metrics/percentiles.py
@@ -0,0 +1,128 @@
+ """
+ Percentile computation utilities for threshold-based signal generation.
+
+ Provides fast percentile computation from fold-specific predictions using Polars,
+ designed to prevent data leakage by computing thresholds from training data only.
+ """
+
+ from collections.abc import Sequence
+
+ import pandas as pd
+ import polars as pl
+
+
+ def compute_fold_percentiles(
+     predictions: pd.DataFrame | pl.DataFrame,
+     percentiles: Sequence[float],
+     fold_col: str = "fold_id",
+     iteration_col: str = "iteration",
+     prediction_col: str = "prediction",
+     verbose: bool = True,
+ ) -> pd.DataFrame:
+     """
+     Compute percentiles from predictions grouped by fold and iteration.
+
+     Uses an efficient Polars group_by operation to compute percentiles 10-50x faster
+     than nested loops. Designed for threshold-based signal generation where
+     thresholds must be computed from TRAINING predictions only to prevent data leakage.
+
+     Performance: ~50-100ms for 89M predictions with 26 percentiles (vs 5-10s with loops)
+
+     Args:
+         predictions: DataFrame with predictions to compute percentiles from.
+             Must contain: fold_col, iteration_col, prediction_col
+         percentiles: List of percentiles to compute (e.g., [0.1, 0.5, 1, ..., 99, 99.5, 99.9]).
+             Values should be in range [0, 100]
+         fold_col: Name of fold identifier column (default: "fold_id")
+         iteration_col: Name of iteration/checkpoint column (default: "iteration")
+         prediction_col: Name of prediction values column (default: "prediction")
+         verbose: Print progress information (default: True)
+
+     Returns:
+         DataFrame with columns: [fold_col, iteration_col, p{percentile}, ...]
+         - One row per (fold, iteration) combination
+         - Percentile columns named like "p0.1", "p99.9", etc.
+
+     Example:
+         >>> # Training predictions: 13 folds × 10 iterations × 687k samples
+         >>> import numpy as np
+         >>> import pandas as pd
+         >>> predictions = pd.DataFrame({
+         ...     'fold_id': [0] * 1000 + [1] * 1000,
+         ...     'iteration': [50] * 500 + [100] * 500 + [50] * 500 + [100] * 500,
+         ...     'prediction': np.random.rand(2000)
+         ... })
+         >>>
+         >>> # Compute percentiles for LONG and SHORT strategies
+         >>> percentiles = [0.1, 0.5, 1, 5, 10, 90, 95, 99, 99.5, 99.9]
+         >>> thresholds = compute_fold_percentiles(predictions, percentiles)
+         >>>
+         >>> # Result: 2 folds × 2 iterations = 4 rows
+         >>> thresholds.shape
+         (4, 12)  # 2 meta columns + 10 percentile columns
+         >>>
+         >>> # Use for signal generation
+         >>> fold_0_iter_100 = thresholds[
+         ...     (thresholds['fold_id'] == 0) & (thresholds['iteration'] == 100)
+         ... ]
+         >>> long_threshold = fold_0_iter_100['p95'].values[0]
+         >>> short_threshold = fold_0_iter_100['p5'].values[0]
+
+     Methodology:
+         1. Convert predictions to Polars (if pandas)
+         2. Group by (fold_id, iteration)
+         3. Compute all percentiles in single aggregation
+         4. Return as pandas DataFrame
+
+     Data Leakage Prevention:
+         CRITICAL: This function should ONLY be called on TRAINING predictions.
+         - Training: compute_fold_percentiles(train_predictions) → save thresholds
+         - Validation: Apply saved thresholds to OOS predictions
+         - NEVER: compute_fold_percentiles(val_predictions) → data leakage!
+
+     Performance Notes:
+         - Polars group_by is 10-50x faster than nested loops
+         - Memory usage: O(n_predictions) for single pass
+         - Time complexity: O(n * log(n)) for sorting within groups
+         - Recommended for predictions > 1M rows
+     """
+     if verbose:
+         print("\nComputing fold-specific percentiles (Fast Polars Method)...")
+
+     # Convert to Polars if pandas
+     preds_pl = pl.from_pandas(predictions) if isinstance(predictions, pd.DataFrame) else predictions
+
+     # Validate required columns
+     required_cols = {fold_col, iteration_col, prediction_col}
+     available_cols = set(preds_pl.columns)
+     missing = required_cols - available_cols
+     if missing:
+         raise ValueError(f"Missing required columns: {missing}. Available: {available_cols}")
+
+     # Convert percentiles to quantiles (0-1 range)
+     quantiles = [p / 100 for p in percentiles]
+
+     # Compute percentiles with single group_by operation
+     percentiles_df = (
+         preds_pl.group_by([fold_col, iteration_col])
+         .agg(
+             [
+                 pl.col(prediction_col).quantile(q, interpolation="linear").alias(f"p{p}")
+                 for q, p in zip(quantiles, percentiles, strict=False)
+             ]
+         )
+         .sort([fold_col, iteration_col])
+     )
+
+     # Convert back to pandas for compatibility
+     result = percentiles_df.to_pandas()
+
+     if verbose:
+         n_folds = result[fold_col].nunique()
+         n_iterations = result[iteration_col].nunique()
+         print(f"✓ Computed {len(result)} percentile arrays")
+         print(
+             f"✓ Structure: {n_folds} folds × {n_iterations} iterations × {len(percentiles)} percentiles"
+         )
+         print(f"✓ Percentile columns: {sorted([c for c in result.columns if c.startswith('p')])}")
+
+     return result
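
A short sketch (not part of the packaged code) of the leakage-safe workflow the docstring above prescribes: compute thresholds on training predictions, then apply them to out-of-sample predictions. The data, column values, and signal rule are synthetic; only compute_fold_percentiles and its defaults are assumed from the diff.

```python
# Hypothetical workflow: train-only thresholds applied to OOS predictions.
import numpy as np
import pandas as pd

from ml4t.diagnostic.metrics import compute_fold_percentiles

rng = np.random.default_rng(0)
train_preds = pd.DataFrame({
    "fold_id": np.repeat([0, 1], 1000),
    "iteration": np.tile(np.repeat([50, 100], 500), 2),
    "prediction": rng.random(2000),
})

# 1. Thresholds from TRAINING predictions only (columns "p5" and "p95")
thresholds = compute_fold_percentiles(train_preds, percentiles=[5, 95], verbose=False)

# 2. Look up the saved thresholds for fold 0, iteration 100
row = thresholds[(thresholds["fold_id"] == 0) & (thresholds["iteration"] == 100)].iloc[0]

# 3. Apply them to (synthetic) OOS predictions: long above p95, short below p5
oos_preds = pd.Series(rng.random(500))
signal = np.where(oos_preds > row["p95"], 1, np.where(oos_preds < row["p5"], -1, 0))
```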
ml4t/diagnostic/py.typed
@@ -0,0 +1 @@
+ # PEP 561 marker file - this package supports type checking
ml4t/diagnostic/reporting/__init__.py
@@ -0,0 +1,43 @@
+ """Report generation module for ML4T Diagnostic results.
+
+ Provides flexible report generation in multiple formats:
+ - HTML: Rich, styled reports with tables and charts
+ - JSON: Machine-readable structured output
+ - Markdown: Human-readable documentation
+
+ Examples:
+     >>> from ml4t.diagnostic.reporting import ReportFactory, ReportFormat
+     >>> from ml4t.diagnostic.results import FeatureDiagnosticsResult
+     >>>
+     >>> # Generate HTML report
+     >>> html_report = ReportFactory.render(result, ReportFormat.HTML)
+     >>>
+     >>> # Generate JSON report
+     >>> json_report = ReportFactory.render(result, ReportFormat.JSON, indent=4)
+     >>>
+     >>> # Generate Markdown report
+     >>> md_report = ReportFactory.render(result, ReportFormat.MARKDOWN)
+     >>>
+     >>> # Save to file
+     >>> generator = ReportFactory.create(ReportFormat.HTML)
+     >>> html = generator.render(result)
+     >>> generator.save(html, "report.html")
+ """
+
+ from ml4t.diagnostic.reporting.base import ReportFactory, ReportFormat, ReportGenerator
+
+ # Import renderers to trigger registration
+ from ml4t.diagnostic.reporting.html_renderer import HTMLReportGenerator
+ from ml4t.diagnostic.reporting.json_renderer import JSONReportGenerator
+ from ml4t.diagnostic.reporting.markdown_renderer import MarkdownReportGenerator
+
+ __all__ = [
+     # Factory and base
+     "ReportFactory",
+     "ReportFormat",
+     "ReportGenerator",
+     # Renderers
+     "HTMLReportGenerator",
+     "JSONReportGenerator",
+     "MarkdownReportGenerator",
+ ]