ml4t-diagnostic 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ml4t/diagnostic/AGENT.md +25 -0
- ml4t/diagnostic/__init__.py +166 -0
- ml4t/diagnostic/backends/__init__.py +10 -0
- ml4t/diagnostic/backends/adapter.py +192 -0
- ml4t/diagnostic/backends/polars_backend.py +899 -0
- ml4t/diagnostic/caching/__init__.py +40 -0
- ml4t/diagnostic/caching/cache.py +331 -0
- ml4t/diagnostic/caching/decorators.py +131 -0
- ml4t/diagnostic/caching/smart_cache.py +339 -0
- ml4t/diagnostic/config/AGENT.md +24 -0
- ml4t/diagnostic/config/README.md +267 -0
- ml4t/diagnostic/config/__init__.py +219 -0
- ml4t/diagnostic/config/barrier_config.py +277 -0
- ml4t/diagnostic/config/base.py +301 -0
- ml4t/diagnostic/config/event_config.py +148 -0
- ml4t/diagnostic/config/feature_config.py +404 -0
- ml4t/diagnostic/config/multi_signal_config.py +55 -0
- ml4t/diagnostic/config/portfolio_config.py +215 -0
- ml4t/diagnostic/config/report_config.py +391 -0
- ml4t/diagnostic/config/sharpe_config.py +202 -0
- ml4t/diagnostic/config/signal_config.py +206 -0
- ml4t/diagnostic/config/trade_analysis_config.py +310 -0
- ml4t/diagnostic/config/validation.py +279 -0
- ml4t/diagnostic/core/__init__.py +29 -0
- ml4t/diagnostic/core/numba_utils.py +315 -0
- ml4t/diagnostic/core/purging.py +372 -0
- ml4t/diagnostic/core/sampling.py +471 -0
- ml4t/diagnostic/errors/__init__.py +205 -0
- ml4t/diagnostic/evaluation/AGENT.md +26 -0
- ml4t/diagnostic/evaluation/__init__.py +437 -0
- ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
- ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
- ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
- ml4t/diagnostic/evaluation/dashboard.py +715 -0
- ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
- ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
- ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
- ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
- ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
- ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
- ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
- ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
- ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
- ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
- ml4t/diagnostic/evaluation/event_analysis.py +647 -0
- ml4t/diagnostic/evaluation/excursion.py +390 -0
- ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
- ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
- ml4t/diagnostic/evaluation/framework.py +935 -0
- ml4t/diagnostic/evaluation/metric_registry.py +255 -0
- ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
- ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
- ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
- ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
- ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
- ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
- ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
- ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
- ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
- ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
- ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
- ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
- ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
- ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
- ml4t/diagnostic/evaluation/multi_signal.py +550 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
- ml4t/diagnostic/evaluation/report_generation.py +824 -0
- ml4t/diagnostic/evaluation/signal_selector.py +452 -0
- ml4t/diagnostic/evaluation/stat_registry.py +139 -0
- ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
- ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
- ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
- ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
- ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
- ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
- ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
- ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
- ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
- ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
- ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
- ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
- ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
- ml4t/diagnostic/evaluation/stats/moments.py +164 -0
- ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
- ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
- ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
- ml4t/diagnostic/evaluation/themes.py +330 -0
- ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
- ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
- ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
- ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
- ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
- ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
- ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
- ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
- ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
- ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
- ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
- ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
- ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
- ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
- ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
- ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
- ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
- ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
- ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
- ml4t/diagnostic/evaluation/validated_cv.py +535 -0
- ml4t/diagnostic/evaluation/visualization.py +1050 -0
- ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
- ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
- ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
- ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
- ml4t/diagnostic/integration/__init__.py +48 -0
- ml4t/diagnostic/integration/backtest_contract.py +671 -0
- ml4t/diagnostic/integration/data_contract.py +316 -0
- ml4t/diagnostic/integration/engineer_contract.py +226 -0
- ml4t/diagnostic/logging/__init__.py +77 -0
- ml4t/diagnostic/logging/logger.py +245 -0
- ml4t/diagnostic/logging/performance.py +234 -0
- ml4t/diagnostic/logging/progress.py +234 -0
- ml4t/diagnostic/logging/wandb.py +412 -0
- ml4t/diagnostic/metrics/__init__.py +9 -0
- ml4t/diagnostic/metrics/percentiles.py +128 -0
- ml4t/diagnostic/py.typed +1 -0
- ml4t/diagnostic/reporting/__init__.py +43 -0
- ml4t/diagnostic/reporting/base.py +130 -0
- ml4t/diagnostic/reporting/html_renderer.py +275 -0
- ml4t/diagnostic/reporting/json_renderer.py +51 -0
- ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
- ml4t/diagnostic/results/AGENT.md +24 -0
- ml4t/diagnostic/results/__init__.py +105 -0
- ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
- ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
- ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
- ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
- ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
- ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
- ml4t/diagnostic/results/barrier_results/validation.py +38 -0
- ml4t/diagnostic/results/base.py +177 -0
- ml4t/diagnostic/results/event_results.py +349 -0
- ml4t/diagnostic/results/feature_results.py +787 -0
- ml4t/diagnostic/results/multi_signal_results.py +431 -0
- ml4t/diagnostic/results/portfolio_results.py +281 -0
- ml4t/diagnostic/results/sharpe_results.py +448 -0
- ml4t/diagnostic/results/signal_results/__init__.py +74 -0
- ml4t/diagnostic/results/signal_results/ic.py +581 -0
- ml4t/diagnostic/results/signal_results/irtc.py +110 -0
- ml4t/diagnostic/results/signal_results/quantile.py +392 -0
- ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
- ml4t/diagnostic/results/signal_results/turnover.py +213 -0
- ml4t/diagnostic/results/signal_results/validation.py +147 -0
- ml4t/diagnostic/signal/AGENT.md +17 -0
- ml4t/diagnostic/signal/__init__.py +69 -0
- ml4t/diagnostic/signal/_report.py +152 -0
- ml4t/diagnostic/signal/_utils.py +261 -0
- ml4t/diagnostic/signal/core.py +275 -0
- ml4t/diagnostic/signal/quantile.py +148 -0
- ml4t/diagnostic/signal/result.py +214 -0
- ml4t/diagnostic/signal/signal_ic.py +129 -0
- ml4t/diagnostic/signal/turnover.py +182 -0
- ml4t/diagnostic/splitters/AGENT.md +19 -0
- ml4t/diagnostic/splitters/__init__.py +36 -0
- ml4t/diagnostic/splitters/base.py +501 -0
- ml4t/diagnostic/splitters/calendar.py +421 -0
- ml4t/diagnostic/splitters/calendar_config.py +91 -0
- ml4t/diagnostic/splitters/combinatorial.py +1064 -0
- ml4t/diagnostic/splitters/config.py +322 -0
- ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
- ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
- ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
- ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
- ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
- ml4t/diagnostic/splitters/group_isolation.py +329 -0
- ml4t/diagnostic/splitters/persistence.py +316 -0
- ml4t/diagnostic/splitters/utils.py +207 -0
- ml4t/diagnostic/splitters/walk_forward.py +757 -0
- ml4t/diagnostic/utils/__init__.py +42 -0
- ml4t/diagnostic/utils/config.py +542 -0
- ml4t/diagnostic/utils/dependencies.py +318 -0
- ml4t/diagnostic/utils/sessions.py +127 -0
- ml4t/diagnostic/validation/__init__.py +54 -0
- ml4t/diagnostic/validation/dataframe.py +274 -0
- ml4t/diagnostic/validation/returns.py +280 -0
- ml4t/diagnostic/validation/timeseries.py +299 -0
- ml4t/diagnostic/visualization/AGENT.md +19 -0
- ml4t/diagnostic/visualization/__init__.py +223 -0
- ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
- ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
- ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
- ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
- ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
- ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
- ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
- ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
- ml4t/diagnostic/visualization/barrier_plots.py +782 -0
- ml4t/diagnostic/visualization/core.py +1060 -0
- ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
- ml4t/diagnostic/visualization/dashboards/base.py +582 -0
- ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
- ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
- ml4t/diagnostic/visualization/dashboards.py +43 -0
- ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
- ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
- ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
- ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
- ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
- ml4t/diagnostic/visualization/feature_plots.py +888 -0
- ml4t/diagnostic/visualization/interaction_plots.py +618 -0
- ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
- ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
- ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
- ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
- ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
- ml4t/diagnostic/visualization/report_generation.py +1343 -0
- ml4t/diagnostic/visualization/signal/__init__.py +103 -0
- ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
- ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
- ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
- ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
- ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
- ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
- ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
- ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
- ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
- ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
- ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
- ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
ml4t/diagnostic/splitters/persistence.py
@@ -0,0 +1,316 @@
"""Fold persistence for cross-validation reproducibility.

This module provides utilities for saving and loading cross-validation fold
configurations, enabling reproducible research and efficient caching of expensive
split computations (especially for CPCV with many combinations).

Examples
--------
>>> from ml4t.diagnostic.splitters import PurgedWalkForwardCV
>>> from ml4t.diagnostic.splitters.persistence import save_folds, load_folds
>>>
>>> # Save fold configuration
>>> cv = PurgedWalkForwardCV(n_splits=5, test_size=100)
>>> folds = list(cv.split(X))
>>> save_folds(folds, X, "my_folds.json", metadata={"strategy": "walk_forward"})
>>>
>>> # Load and reuse fold configuration
>>> loaded_folds, metadata = load_folds("my_folds.json")
>>> for train_idx, test_idx in loaded_folds:
...     # Use same splits as original
...     pass
"""

from __future__ import annotations

import json
from pathlib import Path
from typing import Any

import numpy as np
import pandas as pd
import polars as pl
from numpy.typing import NDArray

from ml4t.diagnostic.config.base import BaseConfig


def save_folds(
    folds: list[tuple[NDArray[np.int_], NDArray[np.int_]]],
    X: NDArray[np.floating] | pd.DataFrame | pl.DataFrame,
    filepath: str | Path,
    *,
    metadata: dict[str, Any] | None = None,
    include_timestamps: bool = True,
) -> None:
    """Save cross-validation folds to disk.

    Parameters
    ----------
    folds : list[tuple[NDArray, NDArray]]
        List of (train_indices, test_indices) tuples from CV splitter.
    X : array-like or DataFrame
        Original data used for splitting (for timestamp extraction if DataFrame).
    filepath : str or Path
        Path to save fold configuration (JSON format).
    metadata : dict, optional
        Additional metadata to store (e.g., splitter config, data info).
    include_timestamps : bool, default=True
        If True and X is a DataFrame with a DatetimeIndex, save timestamps
        alongside indices for better human readability.

    Examples
    --------
    >>> from ml4t.diagnostic.splitters import PurgedWalkForwardCV
    >>> cv = PurgedWalkForwardCV(n_splits=5, test_size=100)
    >>> folds = list(cv.split(X))
    >>> save_folds(folds, X, "cv_folds.json", metadata={"n_splits": 5})
    """
    filepath = Path(filepath)

    # Extract timestamps if available
    timestamps = None
    if include_timestamps and isinstance(X, pd.DataFrame | pd.Series):
        if isinstance(X.index, pd.DatetimeIndex):
            timestamps = X.index.astype(str).tolist()
    elif include_timestamps and isinstance(X, pl.DataFrame):
        # Polars has no index; check whether the first column is datetime
        first_col = X.columns[0]
        if X[first_col].dtype == pl.Datetime:
            timestamps = X[first_col].cast(pl.Utf8).to_list()

    # Build fold data structure
    fold_data: dict[str, Any] = {
        "version": "1.0",
        "n_folds": len(folds),
        "n_samples": len(X),
        "folds": [],
        "metadata": metadata or {},
    }

    if timestamps:
        fold_data["timestamps"] = timestamps

    for fold_idx, (train_idx, test_idx) in enumerate(folds):
        fold_info = {
            "fold_id": fold_idx,
            "train_indices": train_idx.tolist(),
            "test_indices": test_idx.tolist(),
            "train_size": len(train_idx),
            "test_size": len(test_idx),
        }

        # Add timestamp ranges if available (handle empty folds)
        if timestamps:
            if len(train_idx) > 0:
                fold_info["train_start"] = timestamps[train_idx[0]]
                fold_info["train_end"] = timestamps[train_idx[-1]]
            else:
                fold_info["train_start"] = None
                fold_info["train_end"] = None

            if len(test_idx) > 0:
                fold_info["test_start"] = timestamps[test_idx[0]]
                fold_info["test_end"] = timestamps[test_idx[-1]]
            else:
                fold_info["test_start"] = None
                fold_info["test_end"] = None

        fold_data["folds"].append(fold_info)

    # Save to JSON
    filepath.parent.mkdir(parents=True, exist_ok=True)
    with filepath.open("w") as f:
        json.dump(fold_data, f, indent=2)


def load_folds(
    filepath: str | Path,
) -> tuple[list[tuple[NDArray[np.int_], NDArray[np.int_]]], dict[str, Any]]:
    """Load cross-validation folds from disk.

    Parameters
    ----------
    filepath : str or Path
        Path to saved fold configuration (JSON format).

    Returns
    -------
    folds : list[tuple[NDArray, NDArray]]
        List of (train_indices, test_indices) tuples.
    metadata : dict
        Metadata dictionary stored with folds.

    Examples
    --------
    >>> folds, metadata = load_folds("cv_folds.json")
    >>> print(f"Loaded {len(folds)} folds")
    >>> print(f"Metadata: {metadata}")
    """
    filepath = Path(filepath)

    if not filepath.exists():
        raise FileNotFoundError(f"Fold file not found: {filepath}")

    with filepath.open("r") as f:
        fold_data = json.load(f)

    # Validate version
    if fold_data.get("version") != "1.0":
        raise ValueError(f"Unsupported fold file version: {fold_data.get('version')}")

    # Reconstruct folds
    folds = []
    for fold_info in fold_data["folds"]:
        train_idx = np.array(fold_info["train_indices"], dtype=np.int_)
        test_idx = np.array(fold_info["test_indices"], dtype=np.int_)
        folds.append((train_idx, test_idx))

    metadata = fold_data.get("metadata", {})

    return folds, metadata


def save_config(
    config: Any,  # SplitterConfig or subclass
    filepath: str | Path,
) -> None:
    """Save splitter configuration to disk.

    This is a convenience wrapper around config.to_json() for consistency
    with the persistence API.

    Parameters
    ----------
    config : SplitterConfig
        Configuration object to save.
    filepath : str or Path
        Path to save configuration (JSON format).

    Examples
    --------
    >>> from ml4t.diagnostic.splitters.config import PurgedWalkForwardConfig
    >>> config = PurgedWalkForwardConfig(n_splits=5, test_size=100)
    >>> save_config(config, "cv_config.json")
    """
    filepath = Path(filepath)
    config.to_json(filepath)


def load_config(
    filepath: str | Path,
    config_class: type[BaseConfig],
) -> BaseConfig:
    """Load splitter configuration from disk.

    This is a convenience wrapper around config_class.from_json() for consistency
    with the persistence API.

    Parameters
    ----------
    filepath : str or Path
        Path to saved configuration (JSON format).
    config_class : type
        Configuration class to instantiate (e.g., PurgedWalkForwardConfig).

    Returns
    -------
    config : SplitterConfig
        Loaded configuration object.

    Examples
    --------
    >>> from ml4t.diagnostic.splitters.config import PurgedWalkForwardConfig
    >>> config = load_config("cv_config.json", PurgedWalkForwardConfig)
    >>> print(config.n_splits)
    """
    filepath = Path(filepath)
    return config_class.from_json(filepath)


def verify_folds(
    folds: list[tuple[NDArray[np.int_], NDArray[np.int_]]],
    n_samples: int,
) -> dict[str, Any]:
    """Verify fold integrity and compute statistics.

    Parameters
    ----------
    folds : list[tuple[NDArray, NDArray]]
        List of (train_indices, test_indices) tuples.
    n_samples : int
        Total number of samples in the dataset.

    Returns
    -------
    stats : dict
        Dictionary containing fold statistics and validation results.

    Examples
    --------
    >>> folds, _ = load_folds("cv_folds.json")
    >>> stats = verify_folds(folds, n_samples=1000)
    >>> print(f"Valid: {stats['valid']}")
    >>> print(f"Coverage: {stats['coverage']:.1%}")
    """
    stats: dict[str, Any] = {
        "valid": True,
        "errors": [],
        "n_folds": len(folds),
        "n_samples": n_samples,
        "train_sizes": [],
        "test_sizes": [],
    }

    all_train_indices: set[int] = set()
    all_test_indices: set[int] = set()

    for fold_idx, (train_idx, test_idx) in enumerate(folds):
        stats["train_sizes"].append(len(train_idx))
        stats["test_sizes"].append(len(test_idx))

        # Check for index overlap within fold
        overlap = set(train_idx) & set(test_idx)
        if overlap:
            stats["valid"] = False
            stats["errors"].append(
                f"Fold {fold_idx}: {len(overlap)} overlapping indices between train and test"
            )

        # Check for out-of-range indices
        if np.any(train_idx < 0) or np.any(train_idx >= n_samples):
            stats["valid"] = False
            stats["errors"].append(f"Fold {fold_idx}: Train indices out of range")

        if np.any(test_idx < 0) or np.any(test_idx >= n_samples):
            stats["valid"] = False
            stats["errors"].append(f"Fold {fold_idx}: Test indices out of range")

        all_train_indices.update(train_idx)
        all_test_indices.update(test_idx)

    # Compute coverage statistics
    all_indices = all_train_indices | all_test_indices
    stats["coverage"] = len(all_indices) / n_samples
    stats["train_coverage"] = len(all_train_indices) / n_samples
    stats["test_coverage"] = len(all_test_indices) / n_samples

    # Compute size statistics
    if stats["train_sizes"]:
        train_sizes: list[int] = stats["train_sizes"]
        test_sizes: list[int] = stats["test_sizes"]
        stats["avg_train_size"] = np.mean(train_sizes)
        stats["std_train_size"] = np.std(train_sizes)
        stats["avg_test_size"] = np.mean(test_sizes)
        stats["std_test_size"] = np.std(test_sizes)

    return stats


__all__ = [
    "save_folds",
    "load_folds",
    "save_config",
    "load_config",
    "verify_folds",
]
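Usage note: save_folds, load_folds, and verify_folds form a reproducibility round trip: persist the splits once, reload them later, and check their integrity before reuse. Below is a minimal sketch of that round trip on synthetic data; the arrays, file name, and metadata values are illustrative, not part of the package.

# Round-trip sketch (illustrative data only).
import numpy as np

from ml4t.diagnostic.splitters.persistence import load_folds, save_folds, verify_folds

# Two toy walk-forward folds over 10 samples.
folds = [
    (np.arange(0, 4), np.arange(4, 6)),
    (np.arange(0, 6), np.arange(6, 8)),
]
X = np.zeros((10, 3))  # stand-in feature matrix; only len(X) is recorded

# On disk this produces a JSON dict with "version", "n_folds", "n_samples",
# "folds" (fold_id, train/test indices and sizes), and "metadata".
save_folds(folds, X, "toy_folds.json", metadata={"strategy": "toy"})

loaded, meta = load_folds("toy_folds.json")
stats = verify_folds(loaded, n_samples=10)
assert stats["valid"] and meta == {"strategy": "toy"}
print(f"coverage: {stats['coverage']:.1%}")  # indices 0-7 of 10 appear: 80.0%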
ml4t/diagnostic/splitters/utils.py
@@ -0,0 +1,207 @@
"""Utility functions for cross-validation splitters.

This module contains shared functionality used across different splitter
implementations, particularly for handling timestamp conversions and
boundary calculations.
"""

from typing import Any, cast

import numpy as np
import pandas as pd
from pandas import Timedelta


def convert_indices_to_timestamps(
    start_idx: int,
    end_idx: int,
    timestamps: pd.DatetimeIndex | np.ndarray | None = None,
) -> tuple[int | Any, int | Any]:
    """Convert indices to timestamps with robust boundary handling.

    This function handles the conversion of array indices to timestamp values,
    with robust estimation when the end index extends beyond available data.
    It is designed to handle both regular and irregular time series frequencies.

    Parameters
    ----------
    start_idx : int
        Starting index.
    end_idx : int
        Ending index (exclusive).
    timestamps : pd.DatetimeIndex or np.ndarray, optional
        Array of timestamps. If None, returns the original indices.

    Returns
    -------
    tuple[int | Any, int | Any]
        (start_time, end_time), where each value is either a timestamp or an index.

    Examples
    --------
    >>> import pandas as pd
    >>> timestamps = pd.date_range('2020-01-01', periods=100, freq='D')
    >>> start_time, end_time = convert_indices_to_timestamps(10, 20, timestamps)
    >>> print(start_time, end_time)
    2020-01-11 00:00:00 2020-01-21 00:00:00

    >>> # Handle end index beyond data
    >>> start_time, end_time = convert_indices_to_timestamps(90, 105, timestamps)
    >>> print(end_time)  # Estimated based on frequency
    2020-04-15 00:00:00
    """
    if timestamps is None:
        return start_idx, end_idx

    # Convert start index (always available)
    start_time = timestamps[start_idx]

    # Handle end index with robust boundary checking
    if end_idx < len(timestamps):
        # Direct lookup when index is within bounds
        end_time = timestamps[end_idx]
    else:
        # Estimate end time when beyond available data
        end_time = _estimate_timestamp_beyond_data(end_idx, timestamps)

    return start_time, end_time


def _estimate_timestamp_beyond_data(
    target_idx: int,
    timestamps: pd.DatetimeIndex | np.ndarray,
) -> Any:
    """Estimate the timestamp for an index beyond available data.

    This function provides robust timestamp estimation for irregular
    time series by using multiple frequency estimation methods.

    Parameters
    ----------
    target_idx : int
        Target index beyond the timestamp array.
    timestamps : pd.DatetimeIndex or np.ndarray
        Available timestamps.

    Returns
    -------
    Any
        Estimated timestamp.
    """
    if len(timestamps) < 2:
        # Cannot estimate frequency with fewer than 2 points
        return timestamps[-1]

    # Calculate how many steps beyond the data we need
    steps_beyond = target_idx - len(timestamps) + 1

    if isinstance(timestamps, pd.DatetimeIndex):
        # Use pandas DatetimeIndex inference for better frequency handling
        try:
            # Try to infer frequency from the index
            freq = timestamps.freq or pd.infer_freq(timestamps)
            if freq is not None:
                # freq is a DateOffset or str; arithmetic works at runtime
                return cast(
                    Any, timestamps[-1] + steps_beyond * pd.tseries.frequencies.to_offset(freq)
                )
        except (ValueError, TypeError):
            # Fall back to simple difference calculation
            pass

    # Robust frequency estimation using multiple methods;
    # estimated_freq can be Timedelta or np.timedelta64 depending on input type
    estimated_freq: Timedelta | np.timedelta64 | Any
    if len(timestamps) >= 10:
        # Use the median of recent differences for more robust estimation
        recent_diffs = np.diff(timestamps[-10:])
        # Sort and take the middle value to preserve the timedelta type
        sorted_diffs = np.sort(recent_diffs)
        mid_idx = len(sorted_diffs) // 2
        estimated_freq = sorted_diffs[mid_idx]
    elif len(timestamps) >= 3:
        # Use the median of all differences
        all_diffs = np.diff(timestamps)
        # Sort and take the middle value to preserve the timedelta type
        sorted_diffs = np.sort(all_diffs)
        mid_idx = len(sorted_diffs) // 2
        estimated_freq = sorted_diffs[mid_idx]
    else:
        # Simple two-point difference
        estimated_freq = timestamps[-1] - timestamps[-2]

    # Estimate the target timestamp; cast needed for mixed datetime arithmetic
    estimated_time: Any = timestamps[-1] + steps_beyond * estimated_freq

    return estimated_time


def validate_timestamp_array(
    timestamps: pd.DatetimeIndex | np.ndarray | None,
    n_samples: int,
) -> None:
    """Validate a timestamp array for use in cross-validation.

    Parameters
    ----------
    timestamps : pd.DatetimeIndex or np.ndarray, optional
        Timestamp array to validate.
    n_samples : int
        Expected number of samples.

    Raises
    ------
    ValueError
        If timestamps are invalid or mismatched with the sample count.
    """
    if timestamps is None:
        return

    if len(timestamps) != n_samples:
        raise ValueError(
            f"Timestamp array length ({len(timestamps)}) does not match "
            f"number of samples ({n_samples})"
        )

    if len(timestamps) > 1:
        # Check for non-decreasing order (allows duplicate timestamps)
        if isinstance(timestamps, pd.DatetimeIndex):
            if not timestamps.is_monotonic_increasing:
                raise ValueError("Timestamps must be in non-decreasing order")
        else:
            if not np.all(np.diff(timestamps) >= 0):
                raise ValueError("Timestamps must be in non-decreasing order")


def get_time_boundaries(
    group_boundaries: list[tuple[int, int]],
    group_indices: list[int],
    timestamps: pd.DatetimeIndex | np.ndarray | None = None,
) -> list[tuple[int | Any, int | Any]]:
    """Convert multiple group boundaries from indices to timestamps.

    Parameters
    ----------
    group_boundaries : list[tuple[int, int]]
        List of (start_idx, end_idx) boundaries.
    group_indices : list[int]
        Indices of the groups to convert.
    timestamps : pd.DatetimeIndex or np.ndarray, optional
        Timestamp array.

    Returns
    -------
    list[tuple[int | Any, int | Any]]
        List of (start_time, end_time) boundaries.
    """
    time_boundaries = []

    for group_idx in group_indices:
        start_idx, end_idx = group_boundaries[group_idx]
        start_time, end_time = convert_indices_to_timestamps(
            start_idx,
            end_idx,
            timestamps,
        )
        time_boundaries.append((start_time, end_time))

    return time_boundaries
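Usage note: a minimal sketch of the boundary helpers above, assuming a regular business-day index (the dates and boundaries are illustrative). With a DatetimeIndex whose frequency is known or inferable, extrapolation past the end of the data uses the frequency offset; for irregular input the helpers fall back to the median of recent differences.

# Boundary-helper sketch (illustrative dates only).
import pandas as pd

from ml4t.diagnostic.splitters.utils import (
    convert_indices_to_timestamps,
    get_time_boundaries,
    validate_timestamp_array,
)

# Five business days: Wed 2020-01-01 .. Tue 2020-01-07.
ts = pd.date_range("2020-01-01", periods=5, freq="B")
validate_timestamp_array(ts, n_samples=5)  # passes: correct length, sorted

start, end = convert_indices_to_timestamps(1, 3, ts)
print(start, end)  # 2020-01-02 00:00:00 2020-01-06 00:00:00

# end_idx == len(ts): one step beyond the data, estimated from the "B" offset.
_, beyond = convert_indices_to_timestamps(3, 5, ts)
print(beyond)  # 2020-01-08 00:00:00, the next business day after 2020-01-07

# Convert several group boundaries at once.
bounds = get_time_boundaries([(0, 2), (2, 5)], [0, 1], ts)
# [(2020-01-01, 2020-01-03), (2020-01-03, 2020-01-08)]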