ml4t-diagnostic 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (242) hide show
  1. ml4t/diagnostic/AGENT.md +25 -0
  2. ml4t/diagnostic/__init__.py +166 -0
  3. ml4t/diagnostic/backends/__init__.py +10 -0
  4. ml4t/diagnostic/backends/adapter.py +192 -0
  5. ml4t/diagnostic/backends/polars_backend.py +899 -0
  6. ml4t/diagnostic/caching/__init__.py +40 -0
  7. ml4t/diagnostic/caching/cache.py +331 -0
  8. ml4t/diagnostic/caching/decorators.py +131 -0
  9. ml4t/diagnostic/caching/smart_cache.py +339 -0
  10. ml4t/diagnostic/config/AGENT.md +24 -0
  11. ml4t/diagnostic/config/README.md +267 -0
  12. ml4t/diagnostic/config/__init__.py +219 -0
  13. ml4t/diagnostic/config/barrier_config.py +277 -0
  14. ml4t/diagnostic/config/base.py +301 -0
  15. ml4t/diagnostic/config/event_config.py +148 -0
  16. ml4t/diagnostic/config/feature_config.py +404 -0
  17. ml4t/diagnostic/config/multi_signal_config.py +55 -0
  18. ml4t/diagnostic/config/portfolio_config.py +215 -0
  19. ml4t/diagnostic/config/report_config.py +391 -0
  20. ml4t/diagnostic/config/sharpe_config.py +202 -0
  21. ml4t/diagnostic/config/signal_config.py +206 -0
  22. ml4t/diagnostic/config/trade_analysis_config.py +310 -0
  23. ml4t/diagnostic/config/validation.py +279 -0
  24. ml4t/diagnostic/core/__init__.py +29 -0
  25. ml4t/diagnostic/core/numba_utils.py +315 -0
  26. ml4t/diagnostic/core/purging.py +372 -0
  27. ml4t/diagnostic/core/sampling.py +471 -0
  28. ml4t/diagnostic/errors/__init__.py +205 -0
  29. ml4t/diagnostic/evaluation/AGENT.md +26 -0
  30. ml4t/diagnostic/evaluation/__init__.py +437 -0
  31. ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
  32. ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
  33. ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
  34. ml4t/diagnostic/evaluation/dashboard.py +715 -0
  35. ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
  36. ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
  37. ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
  38. ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
  39. ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
  40. ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
  41. ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
  42. ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
  43. ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
  44. ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
  45. ml4t/diagnostic/evaluation/event_analysis.py +647 -0
  46. ml4t/diagnostic/evaluation/excursion.py +390 -0
  47. ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
  48. ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
  49. ml4t/diagnostic/evaluation/framework.py +935 -0
  50. ml4t/diagnostic/evaluation/metric_registry.py +255 -0
  51. ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
  52. ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
  53. ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
  54. ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
  55. ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
  56. ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
  57. ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
  58. ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
  59. ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
  60. ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
  61. ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
  62. ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
  63. ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
  64. ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
  65. ml4t/diagnostic/evaluation/multi_signal.py +550 -0
  66. ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
  67. ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
  68. ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
  69. ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
  70. ml4t/diagnostic/evaluation/report_generation.py +824 -0
  71. ml4t/diagnostic/evaluation/signal_selector.py +452 -0
  72. ml4t/diagnostic/evaluation/stat_registry.py +139 -0
  73. ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
  74. ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
  75. ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
  76. ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
  77. ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
  78. ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
  79. ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
  80. ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
  81. ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
  82. ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
  83. ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
  84. ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
  85. ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
  86. ml4t/diagnostic/evaluation/stats/moments.py +164 -0
  87. ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
  88. ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
  89. ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
  90. ml4t/diagnostic/evaluation/themes.py +330 -0
  91. ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
  92. ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
  93. ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
  94. ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
  95. ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
  96. ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
  97. ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
  98. ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
  99. ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
  100. ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
  101. ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
  102. ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
  103. ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
  104. ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
  105. ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
  106. ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
  107. ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
  108. ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
  109. ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
  110. ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
  111. ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
  112. ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
  113. ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
  114. ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
  115. ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
  116. ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
  117. ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
  118. ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
  119. ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
  120. ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
  121. ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
  122. ml4t/diagnostic/evaluation/validated_cv.py +535 -0
  123. ml4t/diagnostic/evaluation/visualization.py +1050 -0
  124. ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
  125. ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
  126. ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
  127. ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
  128. ml4t/diagnostic/integration/__init__.py +48 -0
  129. ml4t/diagnostic/integration/backtest_contract.py +671 -0
  130. ml4t/diagnostic/integration/data_contract.py +316 -0
  131. ml4t/diagnostic/integration/engineer_contract.py +226 -0
  132. ml4t/diagnostic/logging/__init__.py +77 -0
  133. ml4t/diagnostic/logging/logger.py +245 -0
  134. ml4t/diagnostic/logging/performance.py +234 -0
  135. ml4t/diagnostic/logging/progress.py +234 -0
  136. ml4t/diagnostic/logging/wandb.py +412 -0
  137. ml4t/diagnostic/metrics/__init__.py +9 -0
  138. ml4t/diagnostic/metrics/percentiles.py +128 -0
  139. ml4t/diagnostic/py.typed +1 -0
  140. ml4t/diagnostic/reporting/__init__.py +43 -0
  141. ml4t/diagnostic/reporting/base.py +130 -0
  142. ml4t/diagnostic/reporting/html_renderer.py +275 -0
  143. ml4t/diagnostic/reporting/json_renderer.py +51 -0
  144. ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
  145. ml4t/diagnostic/results/AGENT.md +24 -0
  146. ml4t/diagnostic/results/__init__.py +105 -0
  147. ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
  148. ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
  149. ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
  150. ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
  151. ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
  152. ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
  153. ml4t/diagnostic/results/barrier_results/validation.py +38 -0
  154. ml4t/diagnostic/results/base.py +177 -0
  155. ml4t/diagnostic/results/event_results.py +349 -0
  156. ml4t/diagnostic/results/feature_results.py +787 -0
  157. ml4t/diagnostic/results/multi_signal_results.py +431 -0
  158. ml4t/diagnostic/results/portfolio_results.py +281 -0
  159. ml4t/diagnostic/results/sharpe_results.py +448 -0
  160. ml4t/diagnostic/results/signal_results/__init__.py +74 -0
  161. ml4t/diagnostic/results/signal_results/ic.py +581 -0
  162. ml4t/diagnostic/results/signal_results/irtc.py +110 -0
  163. ml4t/diagnostic/results/signal_results/quantile.py +392 -0
  164. ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
  165. ml4t/diagnostic/results/signal_results/turnover.py +213 -0
  166. ml4t/diagnostic/results/signal_results/validation.py +147 -0
  167. ml4t/diagnostic/signal/AGENT.md +17 -0
  168. ml4t/diagnostic/signal/__init__.py +69 -0
  169. ml4t/diagnostic/signal/_report.py +152 -0
  170. ml4t/diagnostic/signal/_utils.py +261 -0
  171. ml4t/diagnostic/signal/core.py +275 -0
  172. ml4t/diagnostic/signal/quantile.py +148 -0
  173. ml4t/diagnostic/signal/result.py +214 -0
  174. ml4t/diagnostic/signal/signal_ic.py +129 -0
  175. ml4t/diagnostic/signal/turnover.py +182 -0
  176. ml4t/diagnostic/splitters/AGENT.md +19 -0
  177. ml4t/diagnostic/splitters/__init__.py +36 -0
  178. ml4t/diagnostic/splitters/base.py +501 -0
  179. ml4t/diagnostic/splitters/calendar.py +421 -0
  180. ml4t/diagnostic/splitters/calendar_config.py +91 -0
  181. ml4t/diagnostic/splitters/combinatorial.py +1064 -0
  182. ml4t/diagnostic/splitters/config.py +322 -0
  183. ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
  184. ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
  185. ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
  186. ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
  187. ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
  188. ml4t/diagnostic/splitters/group_isolation.py +329 -0
  189. ml4t/diagnostic/splitters/persistence.py +316 -0
  190. ml4t/diagnostic/splitters/utils.py +207 -0
  191. ml4t/diagnostic/splitters/walk_forward.py +757 -0
  192. ml4t/diagnostic/utils/__init__.py +42 -0
  193. ml4t/diagnostic/utils/config.py +542 -0
  194. ml4t/diagnostic/utils/dependencies.py +318 -0
  195. ml4t/diagnostic/utils/sessions.py +127 -0
  196. ml4t/diagnostic/validation/__init__.py +54 -0
  197. ml4t/diagnostic/validation/dataframe.py +274 -0
  198. ml4t/diagnostic/validation/returns.py +280 -0
  199. ml4t/diagnostic/validation/timeseries.py +299 -0
  200. ml4t/diagnostic/visualization/AGENT.md +19 -0
  201. ml4t/diagnostic/visualization/__init__.py +223 -0
  202. ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
  203. ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
  204. ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
  205. ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
  206. ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
  207. ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
  208. ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
  209. ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
  210. ml4t/diagnostic/visualization/barrier_plots.py +782 -0
  211. ml4t/diagnostic/visualization/core.py +1060 -0
  212. ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
  213. ml4t/diagnostic/visualization/dashboards/base.py +582 -0
  214. ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
  215. ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
  216. ml4t/diagnostic/visualization/dashboards.py +43 -0
  217. ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
  218. ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
  219. ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
  220. ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
  221. ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
  222. ml4t/diagnostic/visualization/feature_plots.py +888 -0
  223. ml4t/diagnostic/visualization/interaction_plots.py +618 -0
  224. ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
  225. ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
  226. ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
  227. ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
  228. ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
  229. ml4t/diagnostic/visualization/report_generation.py +1343 -0
  230. ml4t/diagnostic/visualization/signal/__init__.py +103 -0
  231. ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
  232. ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
  233. ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
  234. ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
  235. ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
  236. ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
  237. ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
  238. ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
  239. ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
  240. ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
  241. ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
  242. ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,316 @@
1
+ """Fold persistence for cross-validation reproducibility.
2
+
3
+ This module provides utilities for saving and loading cross-validation fold
4
+ configurations, enabling reproducible research and efficient caching of expensive
5
+ split computations (especially for CPCV with many combinations).
6
+
7
+ Examples
8
+ --------
9
+ >>> from ml4t.diagnostic.splitters import PurgedWalkForwardCV
10
+ >>> from ml4t.diagnostic.splitters.persistence import save_folds, load_folds
11
+ >>>
12
+ >>> # Save fold configuration
13
+ >>> cv = PurgedWalkForwardCV(n_splits=5, test_size=100)
14
+ >>> folds = list(cv.split(X))
15
+ >>> save_folds(folds, X, "my_folds.json", metadata={"strategy": "walk_forward"})
16
+ >>>
17
+ >>> # Load and reuse fold configuration
18
+ >>> loaded_folds, metadata = load_folds("my_folds.json")
19
+ >>> for train_idx, test_idx in loaded_folds:
20
+ ...     # Use same splits as original
21
+ ...     pass
22
+ """
23
+
24
+ from __future__ import annotations
25
+
26
+ import json
27
+ from pathlib import Path
28
+ from typing import Any
29
+
30
+ import numpy as np
31
+ import pandas as pd
32
+ import polars as pl
33
+ from numpy.typing import NDArray
34
+
35
+ from ml4t.diagnostic.config.base import BaseConfig
36
+
37
+
38
+ def save_folds(
39
+ folds: list[tuple[NDArray[np.int_], NDArray[np.int_]]],
40
+ X: NDArray[np.floating] | pd.DataFrame | pl.DataFrame,
41
+ filepath: str | Path,
42
+ *,
43
+ metadata: dict[str, Any] | None = None,
44
+ include_timestamps: bool = True,
45
+ ) -> None:
46
+ """Save cross-validation folds to disk.
47
+
48
+ Parameters
49
+ ----------
50
+ folds : list[tuple[NDArray, NDArray]]
51
+ List of (train_indices, test_indices) tuples from CV splitter.
52
+ X : array-like or DataFrame
53
+ Original data used for splitting (for timestamp extraction if DataFrame).
54
+ filepath : str or Path
55
+ Path to save fold configuration (JSON format).
56
+ metadata : dict, optional
57
+ Additional metadata to store (e.g., splitter config, data info).
58
+ include_timestamps : bool, default=True
59
+ If True and X is a DataFrame with DatetimeIndex, save timestamps
60
+ alongside indices for better human readability.
61
+
62
+ Examples
63
+ --------
64
+ >>> from ml4t.diagnostic.splitters import PurgedWalkForwardCV
65
+ >>> cv = PurgedWalkForwardCV(n_splits=5, test_size=100)
66
+ >>> folds = list(cv.split(X))
67
+ >>> save_folds(folds, X, "cv_folds.json", metadata={"n_splits": 5})
68
+ """
69
+ filepath = Path(filepath)
70
+
71
+ # Extract timestamps if available
72
+ timestamps = None
73
+ if include_timestamps and isinstance(X, pd.DataFrame | pd.Series):
74
+ if isinstance(X.index, pd.DatetimeIndex):
75
+ timestamps = X.index.astype(str).tolist()
76
+ elif include_timestamps and isinstance(X, pl.DataFrame):
77
+ # Polars doesn't have index, check if first column is datetime
78
+ first_col = X.columns[0]
79
+ if X[first_col].dtype == pl.Datetime:
80
+ timestamps = X[first_col].cast(pl.Utf8).to_list()
81
+
82
+ # Build fold data structure
83
+ fold_data: dict[str, Any] = {
84
+ "version": "1.0",
85
+ "n_folds": len(folds),
86
+ "n_samples": len(X),
87
+ "folds": [],
88
+ "metadata": metadata or {},
89
+ }
90
+
91
+ if timestamps:
92
+ fold_data["timestamps"] = timestamps
93
+
94
+ for fold_idx, (train_idx, test_idx) in enumerate(folds):
95
+ fold_info = {
96
+ "fold_id": fold_idx,
97
+ "train_indices": train_idx.tolist(),
98
+ "test_indices": test_idx.tolist(),
99
+ "train_size": len(train_idx),
100
+ "test_size": len(test_idx),
101
+ }
102
+
103
+ # Add timestamp ranges if available (handle empty folds)
104
+ if timestamps:
105
+ if len(train_idx) > 0:
106
+ fold_info["train_start"] = timestamps[train_idx[0]]
107
+ fold_info["train_end"] = timestamps[train_idx[-1]]
108
+ else:
109
+ fold_info["train_start"] = None
110
+ fold_info["train_end"] = None
111
+
112
+ if len(test_idx) > 0:
113
+ fold_info["test_start"] = timestamps[test_idx[0]]
114
+ fold_info["test_end"] = timestamps[test_idx[-1]]
115
+ else:
116
+ fold_info["test_start"] = None
117
+ fold_info["test_end"] = None
118
+
119
+ fold_data["folds"].append(fold_info)
120
+
121
+ # Save to JSON
122
+ filepath.parent.mkdir(parents=True, exist_ok=True)
123
+ with filepath.open("w") as f:
124
+ json.dump(fold_data, f, indent=2)
125
+
126
+
127
def load_folds(
    filepath: str | Path,
) -> tuple[list[tuple[NDArray[np.int_], NDArray[np.int_]]], dict[str, Any]]:
    """Load cross-validation folds from disk.

    Parameters
    ----------
    filepath : str or Path
        Path to a fold file written by ``save_folds`` (JSON format).

    Returns
    -------
    folds : list[tuple[NDArray, NDArray]]
        List of (train_indices, test_indices) tuples with dtype ``np.int_``.
    metadata : dict
        Metadata dictionary stored with the folds (``{}`` if absent or null).

    Raises
    ------
    FileNotFoundError
        If ``filepath`` does not exist.
    ValueError
        If the file is not a JSON object or has an unsupported version.

    Examples
    --------
    >>> folds, metadata = load_folds("cv_folds.json")
    >>> print(f"Loaded {len(folds)} folds")
    >>> print(f"Metadata: {metadata}")
    """
    filepath = Path(filepath)

    if not filepath.exists():
        raise FileNotFoundError(f"Fold file not found: {filepath}")

    with filepath.open("r", encoding="utf-8") as f:
        fold_data = json.load(f)

    # A valid fold file is a JSON object; anything else would otherwise fail
    # with an unhelpful AttributeError on ``.get`` below.
    if not isinstance(fold_data, dict):
        raise ValueError(f"Malformed fold file (expected a JSON object): {filepath}")

    if fold_data.get("version") != "1.0":
        raise ValueError(f"Unsupported fold file version: {fold_data.get('version')}")

    folds = [
        (
            np.array(fold_info["train_indices"], dtype=np.int_),
            np.array(fold_info["test_indices"], dtype=np.int_),
        )
        for fold_info in fold_data["folds"]
    ]

    # ``or {}`` also covers an explicit ``"metadata": null`` so callers always
    # receive a dict, as documented.
    metadata = fold_data.get("metadata") or {}

    return folds, metadata
172
+
173
+
174
def save_config(
    config: Any,  # SplitterConfig or subclass
    filepath: str | Path,
) -> None:
    """Write a splitter configuration to disk as JSON.

    Delegates to ``config.to_json`` with ``filepath`` normalized to a
    ``pathlib.Path``, so configs and folds share one persistence API.

    Parameters
    ----------
    config : SplitterConfig
        Configuration object exposing a ``to_json(path)`` method.
    filepath : str or Path
        Destination path for the JSON file.

    Examples
    --------
    >>> from ml4t.diagnostic.splitters.config import PurgedWalkForwardConfig
    >>> config = PurgedWalkForwardConfig(n_splits=5, test_size=100)
    >>> save_config(config, "cv_config.json")
    """
    config.to_json(Path(filepath))
198
+
199
+
200
def load_config(
    filepath: str | Path,
    config_class: type[BaseConfig],
) -> BaseConfig:
    """Read a splitter configuration from a JSON file.

    Delegates to ``config_class.from_json`` with ``filepath`` normalized to a
    ``pathlib.Path``, so configs and folds share one persistence API.

    Parameters
    ----------
    filepath : str or Path
        Path to a previously saved configuration (JSON format).
    config_class : type
        Configuration class to instantiate (e.g. ``PurgedWalkForwardConfig``);
        must provide a ``from_json(path)`` constructor.

    Returns
    -------
    config : SplitterConfig
        Whatever ``config_class.from_json`` returns.

    Examples
    --------
    >>> from ml4t.diagnostic.splitters.config import PurgedWalkForwardConfig
    >>> config = load_config("cv_config.json", PurgedWalkForwardConfig)
    >>> print(config.n_splits)
    """
    return config_class.from_json(Path(filepath))
229
+
230
+
231
def verify_folds(
    folds: list[tuple[NDArray[np.int_], NDArray[np.int_]]],
    n_samples: int,
) -> dict[str, Any]:
    """Verify fold integrity and compute statistics.

    Parameters
    ----------
    folds : list[tuple[NDArray, NDArray]]
        List of (train_indices, test_indices) tuples. Plain lists of ints are
        accepted as well as arrays.
    n_samples : int
        Total number of samples in the dataset.

    Returns
    -------
    stats : dict
        ``valid`` (bool) and ``errors`` (list of messages). Also contains
        ``n_folds``, ``n_samples``, per-fold ``train_sizes``/``test_sizes``,
        and ``coverage``/``train_coverage``/``test_coverage``. Coverage is the
        fraction of ``range(n_samples)`` covered by in-range indices, and is
        ``0.0`` when ``n_samples <= 0``. When there is at least one fold,
        ``avg_train_size``, ``std_train_size``, ``avg_test_size`` and
        ``std_test_size`` are included as well.

    Examples
    --------
    >>> folds, _ = load_folds("cv_folds.json")
    >>> stats = verify_folds(folds, n_samples=1000)
    >>> print(f"Valid: {stats['valid']}")
    >>> print(f"Coverage: {stats['coverage']:.1%}")
    """
    stats: dict[str, Any] = {
        "valid": True,
        "errors": [],
        "n_folds": len(folds),
        "n_samples": n_samples,
        "train_sizes": [],
        "test_sizes": [],
    }

    all_train_indices: set[int] = set()
    all_test_indices: set[int] = set()

    for fold_idx, (train_idx, test_idx) in enumerate(folds):
        # np.asarray so list inputs support the vectorized checks below.
        train_arr = np.asarray(train_idx)
        test_arr = np.asarray(test_idx)
        stats["train_sizes"].append(len(train_arr))
        stats["test_sizes"].append(len(test_arr))

        # Distinct indices present in both train and test (leakage).
        overlap = np.intersect1d(train_arr, test_arr)
        if overlap.size:
            stats["valid"] = False
            stats["errors"].append(
                f"Fold {fold_idx}: {overlap.size} overlapping indices between train and test"
            )

        train_in_range = (train_arr >= 0) & (train_arr < n_samples)
        if not train_in_range.all():
            stats["valid"] = False
            stats["errors"].append(f"Fold {fold_idx}: Train indices out of range")

        test_in_range = (test_arr >= 0) & (test_arr < n_samples)
        if not test_in_range.all():
            stats["valid"] = False
            stats["errors"].append(f"Fold {fold_idx}: Test indices out of range")

        # Only in-range indices count toward coverage, so it never exceeds 1.0.
        all_train_indices.update(train_arr[train_in_range].tolist())
        all_test_indices.update(test_arr[test_in_range].tolist())

    all_indices = all_train_indices | all_test_indices
    if n_samples > 0:
        stats["coverage"] = len(all_indices) / n_samples
        stats["train_coverage"] = len(all_train_indices) / n_samples
        stats["test_coverage"] = len(all_test_indices) / n_samples
    else:
        # Avoid ZeroDivisionError for an empty dataset.
        stats["coverage"] = 0.0
        stats["train_coverage"] = 0.0
        stats["test_coverage"] = 0.0

    if stats["train_sizes"]:
        train_sizes: list[int] = stats["train_sizes"]
        test_sizes: list[int] = stats["test_sizes"]
        stats["avg_train_size"] = np.mean(train_sizes)
        stats["std_train_size"] = np.std(train_sizes)
        stats["avg_test_size"] = np.mean(test_sizes)
        stats["std_test_size"] = np.std(test_sizes)

    return stats
308
+
309
+
310
# Public API of this module: the names exported by ``from ... import *``.
__all__ = [
    "save_folds",
    "load_folds",
    "save_config",
    "load_config",
    "verify_folds",
]
@@ -0,0 +1,207 @@
1
+ """Utility functions for cross-validation splitters.
2
+
3
+ This module contains shared functionality used across different splitter
4
+ implementations, particularly for handling timestamp conversions and
5
+ boundary calculations.
6
+ """
7
+
8
+ from typing import Any, cast
9
+
10
+ import numpy as np
11
+ import pandas as pd
12
+ from pandas import Timedelta
13
+
14
+
15
+ def convert_indices_to_timestamps(
16
+ start_idx: int,
17
+ end_idx: int,
18
+ timestamps: pd.DatetimeIndex | np.ndarray | None = None,
19
+ ) -> tuple[int | Any, int | Any]:
20
+ """Convert indices to timestamps with robust boundary handling.
21
+
22
+ This function handles the conversion of array indices to timestamp values,
23
+ with robust estimation when the end index extends beyond available data.
24
+ It's designed to handle both regular and irregular time series frequencies.
25
+
26
+ Parameters
27
+ ----------
28
+ start_idx : int
29
+ Starting index
30
+ end_idx : int
31
+ Ending index (exclusive)
32
+ timestamps : pd.DatetimeIndex or np.ndarray, optional
33
+ Array of timestamps. If None, returns original indices.
34
+
35
+ Returns:
36
+ -------
37
+ tuple[Union[int, Any], Union[int, Any]]
38
+ (start_time, end_time) where times are either timestamps or indices
39
+
40
+ Examples:
41
+ --------
42
+ >>> import pandas as pd
43
+ >>> timestamps = pd.date_range('2020-01-01', periods=100, freq='D')
44
+ >>> start_time, end_time = convert_indices_to_timestamps(10, 20, timestamps)
45
+ >>> print(start_time, end_time)
46
+ 2020-01-11 00:00:00 2020-01-21 00:00:00
47
+
48
+ >>> # Handle end index beyond data
49
+ >>> start_time, end_time = convert_indices_to_timestamps(90, 105, timestamps)
50
+ >>> print(end_time) # Estimated based on frequency
51
+ 2020-04-15 00:00:00
52
+ """
53
+ if timestamps is None:
54
+ return start_idx, end_idx
55
+
56
+ # Convert start index (always available)
57
+ start_time = timestamps[start_idx]
58
+
59
+ # Handle end index with robust boundary checking
60
+ if end_idx < len(timestamps):
61
+ # Direct lookup when index is within bounds
62
+ end_time = timestamps[end_idx]
63
+ else:
64
+ # Estimate end time when beyond available data
65
+ end_time = _estimate_timestamp_beyond_data(end_idx, timestamps)
66
+
67
+ return start_time, end_time
68
+
69
+
70
+ def _estimate_timestamp_beyond_data(
71
+ target_idx: int,
72
+ timestamps: pd.DatetimeIndex | np.ndarray,
73
+ ) -> Any:
74
+ """Estimate timestamp for an index beyond available data.
75
+
76
+ This function provides robust timestamp estimation for irregular
77
+ time series by using multiple frequency estimation methods.
78
+
79
+ Parameters
80
+ ----------
81
+ target_idx : int
82
+ Target index beyond the timestamp array
83
+ timestamps : pd.DatetimeIndex or np.ndarray
84
+ Available timestamps
85
+
86
+ Returns:
87
+ -------
88
+ Any
89
+ Estimated timestamp
90
+ """
91
+ if len(timestamps) < 2:
92
+ # Can't estimate frequency with fewer than 2 points
93
+ return timestamps[-1]
94
+
95
+ # Calculate how many steps beyond the data we need
96
+ steps_beyond = target_idx - len(timestamps) + 1
97
+
98
+ if isinstance(timestamps, pd.DatetimeIndex):
99
+ # Use pandas DatetimeIndex inference for better frequency handling
100
+ try:
101
+ # Try to infer frequency from the index
102
+ freq = timestamps.freq or pd.infer_freq(timestamps)
103
+ if freq is not None:
104
+ # freq is DateOffset or str - arithmetic works at runtime
105
+ return cast(
106
+ Any, timestamps[-1] + steps_beyond * pd.tseries.frequencies.to_offset(freq)
107
+ )
108
+ except (ValueError, TypeError):
109
+ # Fall back to simple difference calculation
110
+ pass
111
+
112
+ # Robust frequency estimation using multiple methods
113
+ # estimated_freq can be Timedelta or np.timedelta64 depending on input type
114
+ estimated_freq: Timedelta | np.timedelta64 | Any
115
+ if len(timestamps) >= 10:
116
+ # Use median of recent differences for more robust estimation
117
+ recent_diffs = np.diff(timestamps[-10:])
118
+ # Sort and take middle value to preserve timedelta type
119
+ sorted_diffs = np.sort(recent_diffs)
120
+ mid_idx = len(sorted_diffs) // 2
121
+ estimated_freq = sorted_diffs[mid_idx]
122
+ elif len(timestamps) >= 3:
123
+ # Use median of all differences
124
+ all_diffs = np.diff(timestamps)
125
+ # Sort and take middle value to preserve timedelta type
126
+ sorted_diffs = np.sort(all_diffs)
127
+ mid_idx = len(sorted_diffs) // 2
128
+ estimated_freq = sorted_diffs[mid_idx]
129
+ else:
130
+ # Simple two-point difference
131
+ estimated_freq = timestamps[-1] - timestamps[-2]
132
+
133
+ # Estimate the target timestamp - cast needed for mixed datetime arithmetic
134
+ estimated_time: Any = timestamps[-1] + steps_beyond * estimated_freq
135
+
136
+ return estimated_time
137
+
138
+
139
+ def validate_timestamp_array(
140
+ timestamps: pd.DatetimeIndex | np.ndarray | None,
141
+ n_samples: int,
142
+ ) -> None:
143
+ """Validate timestamp array for use in cross-validation.
144
+
145
+ Parameters
146
+ ----------
147
+ timestamps : pd.DatetimeIndex or np.ndarray, optional
148
+ Timestamp array to validate
149
+ n_samples : int
150
+ Expected number of samples
151
+
152
+ Raises:
153
+ ------
154
+ ValueError
155
+ If timestamps are invalid or mismatched with sample count
156
+ """
157
+ if timestamps is None:
158
+ return
159
+
160
+ if len(timestamps) != n_samples:
161
+ raise ValueError(
162
+ f"Timestamp array length ({len(timestamps)}) does not match number of samples ({n_samples})",
163
+ )
164
+
165
+ if len(timestamps) > 1:
166
+ # Check for non-decreasing order (allows for duplicate timestamps)
167
+ if isinstance(timestamps, pd.DatetimeIndex):
168
+ if not timestamps.is_monotonic_increasing:
169
+ raise ValueError("Timestamps must be in non-decreasing order")
170
+ else:
171
+ if not np.all(np.diff(timestamps) >= 0):
172
+ raise ValueError("Timestamps must be in non-decreasing order")
173
+
174
+
175
+ def get_time_boundaries(
176
+ group_boundaries: list[tuple[int, int]],
177
+ group_indices: list[int],
178
+ timestamps: pd.DatetimeIndex | np.ndarray | None = None,
179
+ ) -> list[tuple[int | Any, int | Any]]:
180
+ """Convert multiple group boundaries from indices to timestamps.
181
+
182
+ Parameters
183
+ ----------
184
+ group_boundaries : list[tuple[int, int]]
185
+ List of (start_idx, end_idx) boundaries
186
+ group_indices : list[int]
187
+ Indices of groups to convert
188
+ timestamps : pd.DatetimeIndex or np.ndarray, optional
189
+ Timestamp array
190
+
191
+ Returns:
192
+ -------
193
+ list[tuple[Union[int, Any], Union[int, Any]]]
194
+ List of (start_time, end_time) boundaries
195
+ """
196
+ time_boundaries = []
197
+
198
+ for group_idx in group_indices:
199
+ start_idx, end_idx = group_boundaries[group_idx]
200
+ start_time, end_time = convert_indices_to_timestamps(
201
+ start_idx,
202
+ end_idx,
203
+ timestamps,
204
+ )
205
+ time_boundaries.append((start_time, end_time))
206
+
207
+ return time_boundaries