ml4t-diagnostic 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ml4t/diagnostic/AGENT.md +25 -0
- ml4t/diagnostic/__init__.py +166 -0
- ml4t/diagnostic/backends/__init__.py +10 -0
- ml4t/diagnostic/backends/adapter.py +192 -0
- ml4t/diagnostic/backends/polars_backend.py +899 -0
- ml4t/diagnostic/caching/__init__.py +40 -0
- ml4t/diagnostic/caching/cache.py +331 -0
- ml4t/diagnostic/caching/decorators.py +131 -0
- ml4t/diagnostic/caching/smart_cache.py +339 -0
- ml4t/diagnostic/config/AGENT.md +24 -0
- ml4t/diagnostic/config/README.md +267 -0
- ml4t/diagnostic/config/__init__.py +219 -0
- ml4t/diagnostic/config/barrier_config.py +277 -0
- ml4t/diagnostic/config/base.py +301 -0
- ml4t/diagnostic/config/event_config.py +148 -0
- ml4t/diagnostic/config/feature_config.py +404 -0
- ml4t/diagnostic/config/multi_signal_config.py +55 -0
- ml4t/diagnostic/config/portfolio_config.py +215 -0
- ml4t/diagnostic/config/report_config.py +391 -0
- ml4t/diagnostic/config/sharpe_config.py +202 -0
- ml4t/diagnostic/config/signal_config.py +206 -0
- ml4t/diagnostic/config/trade_analysis_config.py +310 -0
- ml4t/diagnostic/config/validation.py +279 -0
- ml4t/diagnostic/core/__init__.py +29 -0
- ml4t/diagnostic/core/numba_utils.py +315 -0
- ml4t/diagnostic/core/purging.py +372 -0
- ml4t/diagnostic/core/sampling.py +471 -0
- ml4t/diagnostic/errors/__init__.py +205 -0
- ml4t/diagnostic/evaluation/AGENT.md +26 -0
- ml4t/diagnostic/evaluation/__init__.py +437 -0
- ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
- ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
- ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
- ml4t/diagnostic/evaluation/dashboard.py +715 -0
- ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
- ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
- ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
- ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
- ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
- ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
- ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
- ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
- ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
- ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
- ml4t/diagnostic/evaluation/event_analysis.py +647 -0
- ml4t/diagnostic/evaluation/excursion.py +390 -0
- ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
- ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
- ml4t/diagnostic/evaluation/framework.py +935 -0
- ml4t/diagnostic/evaluation/metric_registry.py +255 -0
- ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
- ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
- ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
- ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
- ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
- ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
- ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
- ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
- ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
- ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
- ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
- ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
- ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
- ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
- ml4t/diagnostic/evaluation/multi_signal.py +550 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
- ml4t/diagnostic/evaluation/report_generation.py +824 -0
- ml4t/diagnostic/evaluation/signal_selector.py +452 -0
- ml4t/diagnostic/evaluation/stat_registry.py +139 -0
- ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
- ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
- ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
- ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
- ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
- ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
- ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
- ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
- ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
- ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
- ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
- ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
- ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
- ml4t/diagnostic/evaluation/stats/moments.py +164 -0
- ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
- ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
- ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
- ml4t/diagnostic/evaluation/themes.py +330 -0
- ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
- ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
- ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
- ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
- ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
- ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
- ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
- ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
- ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
- ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
- ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
- ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
- ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
- ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
- ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
- ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
- ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
- ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
- ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
- ml4t/diagnostic/evaluation/validated_cv.py +535 -0
- ml4t/diagnostic/evaluation/visualization.py +1050 -0
- ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
- ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
- ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
- ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
- ml4t/diagnostic/integration/__init__.py +48 -0
- ml4t/diagnostic/integration/backtest_contract.py +671 -0
- ml4t/diagnostic/integration/data_contract.py +316 -0
- ml4t/diagnostic/integration/engineer_contract.py +226 -0
- ml4t/diagnostic/logging/__init__.py +77 -0
- ml4t/diagnostic/logging/logger.py +245 -0
- ml4t/diagnostic/logging/performance.py +234 -0
- ml4t/diagnostic/logging/progress.py +234 -0
- ml4t/diagnostic/logging/wandb.py +412 -0
- ml4t/diagnostic/metrics/__init__.py +9 -0
- ml4t/diagnostic/metrics/percentiles.py +128 -0
- ml4t/diagnostic/py.typed +1 -0
- ml4t/diagnostic/reporting/__init__.py +43 -0
- ml4t/diagnostic/reporting/base.py +130 -0
- ml4t/diagnostic/reporting/html_renderer.py +275 -0
- ml4t/diagnostic/reporting/json_renderer.py +51 -0
- ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
- ml4t/diagnostic/results/AGENT.md +24 -0
- ml4t/diagnostic/results/__init__.py +105 -0
- ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
- ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
- ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
- ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
- ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
- ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
- ml4t/diagnostic/results/barrier_results/validation.py +38 -0
- ml4t/diagnostic/results/base.py +177 -0
- ml4t/diagnostic/results/event_results.py +349 -0
- ml4t/diagnostic/results/feature_results.py +787 -0
- ml4t/diagnostic/results/multi_signal_results.py +431 -0
- ml4t/diagnostic/results/portfolio_results.py +281 -0
- ml4t/diagnostic/results/sharpe_results.py +448 -0
- ml4t/diagnostic/results/signal_results/__init__.py +74 -0
- ml4t/diagnostic/results/signal_results/ic.py +581 -0
- ml4t/diagnostic/results/signal_results/irtc.py +110 -0
- ml4t/diagnostic/results/signal_results/quantile.py +392 -0
- ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
- ml4t/diagnostic/results/signal_results/turnover.py +213 -0
- ml4t/diagnostic/results/signal_results/validation.py +147 -0
- ml4t/diagnostic/signal/AGENT.md +17 -0
- ml4t/diagnostic/signal/__init__.py +69 -0
- ml4t/diagnostic/signal/_report.py +152 -0
- ml4t/diagnostic/signal/_utils.py +261 -0
- ml4t/diagnostic/signal/core.py +275 -0
- ml4t/diagnostic/signal/quantile.py +148 -0
- ml4t/diagnostic/signal/result.py +214 -0
- ml4t/diagnostic/signal/signal_ic.py +129 -0
- ml4t/diagnostic/signal/turnover.py +182 -0
- ml4t/diagnostic/splitters/AGENT.md +19 -0
- ml4t/diagnostic/splitters/__init__.py +36 -0
- ml4t/diagnostic/splitters/base.py +501 -0
- ml4t/diagnostic/splitters/calendar.py +421 -0
- ml4t/diagnostic/splitters/calendar_config.py +91 -0
- ml4t/diagnostic/splitters/combinatorial.py +1064 -0
- ml4t/diagnostic/splitters/config.py +322 -0
- ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
- ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
- ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
- ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
- ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
- ml4t/diagnostic/splitters/group_isolation.py +329 -0
- ml4t/diagnostic/splitters/persistence.py +316 -0
- ml4t/diagnostic/splitters/utils.py +207 -0
- ml4t/diagnostic/splitters/walk_forward.py +757 -0
- ml4t/diagnostic/utils/__init__.py +42 -0
- ml4t/diagnostic/utils/config.py +542 -0
- ml4t/diagnostic/utils/dependencies.py +318 -0
- ml4t/diagnostic/utils/sessions.py +127 -0
- ml4t/diagnostic/validation/__init__.py +54 -0
- ml4t/diagnostic/validation/dataframe.py +274 -0
- ml4t/diagnostic/validation/returns.py +280 -0
- ml4t/diagnostic/validation/timeseries.py +299 -0
- ml4t/diagnostic/visualization/AGENT.md +19 -0
- ml4t/diagnostic/visualization/__init__.py +223 -0
- ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
- ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
- ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
- ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
- ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
- ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
- ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
- ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
- ml4t/diagnostic/visualization/barrier_plots.py +782 -0
- ml4t/diagnostic/visualization/core.py +1060 -0
- ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
- ml4t/diagnostic/visualization/dashboards/base.py +582 -0
- ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
- ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
- ml4t/diagnostic/visualization/dashboards.py +43 -0
- ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
- ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
- ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
- ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
- ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
- ml4t/diagnostic/visualization/feature_plots.py +888 -0
- ml4t/diagnostic/visualization/interaction_plots.py +618 -0
- ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
- ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
- ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
- ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
- ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
- ml4t/diagnostic/visualization/report_generation.py +1343 -0
- ml4t/diagnostic/visualization/signal/__init__.py +103 -0
- ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
- ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
- ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
- ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
- ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
- ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
- ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
- ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
- ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
- ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
- ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
- ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,501 @@
|
|
|
1
|
+
"""Base class for all time-series cross-validation splitters.
|
|
2
|
+
|
|
3
|
+
This module defines the abstract base class that all ml4t-diagnostic splitters inherit from,
|
|
4
|
+
ensuring compatibility with scikit-learn's cross-validation framework while adding
|
|
5
|
+
support for time-series specific features like purging and embargo.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from abc import ABC, abstractmethod
|
|
9
|
+
from collections.abc import Generator
|
|
10
|
+
from typing import TYPE_CHECKING, Any, Union, cast
|
|
11
|
+
|
|
12
|
+
import numpy as np
|
|
13
|
+
import pandas as pd
|
|
14
|
+
import polars as pl
|
|
15
|
+
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
from numpy.typing import NDArray
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class BaseSplitter(ABC):
|
|
21
|
+
"""Abstract base class for all ml4t-diagnostic time-series splitters.
|
|
22
|
+
|
|
23
|
+
This class defines the interface that all splitters must implement to ensure
|
|
24
|
+
compatibility with scikit-learn's model selection tools while providing
|
|
25
|
+
additional functionality for financial time-series validation.
|
|
26
|
+
|
|
27
|
+
All splitters should support purging (removing training data that could leak
|
|
28
|
+
information into test data) and embargo (adding gaps between train and test
|
|
29
|
+
sets to account for serial correlation).
|
|
30
|
+
|
|
31
|
+
Session-Aware Splitting
|
|
32
|
+
-----------------------
|
|
33
|
+
Splitters can optionally align fold boundaries to trading session boundaries
|
|
34
|
+
by setting ``align_to_sessions=True``. This requires the data to have a
|
|
35
|
+
session column (default: 'session_date') that identifies trading sessions.
|
|
36
|
+
|
|
37
|
+
Trading sessions are atomic units that should never be split across train/test
|
|
38
|
+
folds. For intraday data (e.g., CME futures with Sunday 5pm - Friday 4pm sessions),
|
|
39
|
+
this prevents subtle lookahead bias from mid-session splits.
|
|
40
|
+
|
|
41
|
+
**Integration with qdata library:**
|
|
42
|
+
|
|
43
|
+
The session column should be added using the ``qdata`` library's session
|
|
44
|
+
assignment functionality::
|
|
45
|
+
|
|
46
|
+
from qdata import DataManager
|
|
47
|
+
|
|
48
|
+
manager = DataManager()
|
|
49
|
+
df = manager.load(symbol="BTC", exchange="CME", calendar="CME_Globex_Crypto")
|
|
50
|
+
# df now has 'session_date' column automatically assigned
|
|
51
|
+
|
|
52
|
+
Or manually using SessionAssigner::
|
|
53
|
+
|
|
54
|
+
from ml4t.data.sessions import SessionAssigner
|
|
55
|
+
|
|
56
|
+
assigner = SessionAssigner.from_exchange('CME')
|
|
57
|
+
df_with_sessions = assigner.assign_sessions(df)
|
|
58
|
+
|
|
59
|
+
Then use with ml4t-diagnostic splitters::
|
|
60
|
+
|
|
61
|
+
from ml4t.diagnostic.splitters import PurgedWalkForwardCV
|
|
62
|
+
|
|
63
|
+
cv = PurgedWalkForwardCV(
|
|
64
|
+
n_splits=5,
|
|
65
|
+
align_to_sessions=True, # Align folds to session boundaries
|
|
66
|
+
session_col='session_date'
|
|
67
|
+
)
|
|
68
|
+
|
|
69
|
+
for train_idx, test_idx in cv.split(df_with_sessions):
|
|
70
|
+
# Fold boundaries respect session boundaries
|
|
71
|
+
pass
|
|
72
|
+
"""
|
|
73
|
+
|
|
74
|
+
@abstractmethod
|
|
75
|
+
def split(
|
|
76
|
+
self,
|
|
77
|
+
X: Union[pl.DataFrame, pd.DataFrame, "NDArray[Any]"],
|
|
78
|
+
y: Union[pl.Series, pd.Series, "NDArray[Any]"] | None = None,
|
|
79
|
+
groups: Union[pl.Series, pd.Series, "NDArray[Any]"] | None = None,
|
|
80
|
+
) -> Generator[tuple["NDArray[np.intp]", "NDArray[np.intp]"], None, None]:
|
|
81
|
+
"""Generate indices to split data into training and test sets.
|
|
82
|
+
|
|
83
|
+
Parameters
|
|
84
|
+
----------
|
|
85
|
+
X : polars.DataFrame, pandas.DataFrame, or numpy.ndarray
|
|
86
|
+
Training data with shape (n_samples, n_features).
|
|
87
|
+
|
|
88
|
+
y : polars.Series, pandas.Series, numpy.ndarray, or None, default=None
|
|
89
|
+
Target variable with shape (n_samples,). Always ignored but kept
|
|
90
|
+
for scikit-learn compatibility.
|
|
91
|
+
|
|
92
|
+
groups : polars.Series, pandas.Series, numpy.ndarray, or None, default=None
|
|
93
|
+
Group labels for samples, used for multi-asset splitting.
|
|
94
|
+
Shape (n_samples,).
|
|
95
|
+
|
|
96
|
+
Yields:
|
|
97
|
+
------
|
|
98
|
+
train : numpy.ndarray
|
|
99
|
+
The training set indices for that split.
|
|
100
|
+
|
|
101
|
+
test : numpy.ndarray
|
|
102
|
+
The testing set indices for that split.
|
|
103
|
+
|
|
104
|
+
Notes:
|
|
105
|
+
-----
|
|
106
|
+
The indices returned are integer positions, not labels or timestamps.
|
|
107
|
+
This ensures compatibility with numpy array indexing and scikit-learn.
|
|
108
|
+
"""
|
|
109
|
+
|
|
110
|
+
def get_n_splits(
|
|
111
|
+
self,
|
|
112
|
+
X: Union[pl.DataFrame, pd.DataFrame, "NDArray[Any]"] | None = None,
|
|
113
|
+
y: Union[pl.Series, pd.Series, "NDArray[Any]"] | None = None,
|
|
114
|
+
groups: Union[pl.Series, pd.Series, "NDArray[Any]"] | None = None,
|
|
115
|
+
) -> int:
|
|
116
|
+
"""Return the number of splitting iterations in the cross-validator.
|
|
117
|
+
|
|
118
|
+
Parameters
|
|
119
|
+
----------
|
|
120
|
+
X : polars.DataFrame, pandas.DataFrame, numpy.ndarray, or None, default=None
|
|
121
|
+
Training data. Some splitters may use properties of X to determine
|
|
122
|
+
the number of splits.
|
|
123
|
+
|
|
124
|
+
y : polars.Series, pandas.Series, numpy.ndarray, or None, default=None
|
|
125
|
+
Always ignored, exists for compatibility.
|
|
126
|
+
|
|
127
|
+
groups : polars.Series, pandas.Series, numpy.ndarray, or None, default=None
|
|
128
|
+
Group labels. Some splitters may use this to determine splits.
|
|
129
|
+
|
|
130
|
+
Returns:
|
|
131
|
+
-------
|
|
132
|
+
n_splits : int
|
|
133
|
+
The number of splitting iterations.
|
|
134
|
+
|
|
135
|
+
Notes:
|
|
136
|
+
-----
|
|
137
|
+
Most splitters can determine the number of splits from their parameters
|
|
138
|
+
alone, but some (like GroupKFold variants) may need to inspect the data.
|
|
139
|
+
"""
|
|
140
|
+
raise NotImplementedError(
|
|
141
|
+
f"{self.__class__.__name__} must implement get_n_splits()",
|
|
142
|
+
)
|
|
143
|
+
|
|
144
|
+
def _get_n_samples(
|
|
145
|
+
self,
|
|
146
|
+
X: Union[pl.DataFrame, pd.DataFrame, "NDArray[Any]"],
|
|
147
|
+
) -> int:
|
|
148
|
+
"""Get the number of samples in X regardless of type.
|
|
149
|
+
|
|
150
|
+
Parameters
|
|
151
|
+
----------
|
|
152
|
+
X : polars.DataFrame, pandas.DataFrame, or numpy.ndarray
|
|
153
|
+
The data to get the sample count from.
|
|
154
|
+
|
|
155
|
+
Returns:
|
|
156
|
+
-------
|
|
157
|
+
n_samples : int
|
|
158
|
+
The number of samples (rows) in X.
|
|
159
|
+
"""
|
|
160
|
+
if isinstance(X, pl.DataFrame):
|
|
161
|
+
return X.height
|
|
162
|
+
if isinstance(X, pl.LazyFrame):
|
|
163
|
+
# LazyFrame doesn't have height, need to collect first
|
|
164
|
+
return X.collect().height
|
|
165
|
+
if isinstance(X, pd.DataFrame):
|
|
166
|
+
return len(X)
|
|
167
|
+
if isinstance(X, np.ndarray):
|
|
168
|
+
return int(X.shape[0])
|
|
169
|
+
raise TypeError(
|
|
170
|
+
f"X must be a Polars DataFrame, Pandas DataFrame, or numpy array. Got {type(X).__name__}",
|
|
171
|
+
)
|
|
172
|
+
|
|
173
|
+
def _validate_data(
|
|
174
|
+
self,
|
|
175
|
+
X: Union[pl.DataFrame, pd.DataFrame, "NDArray[Any]"],
|
|
176
|
+
y: Union[pl.Series, pd.Series, "NDArray[Any]"] | None = None,
|
|
177
|
+
groups: Union[pl.Series, pd.Series, "NDArray[Any]"] | None = None,
|
|
178
|
+
) -> int:
|
|
179
|
+
"""Validate input data and return the number of samples.
|
|
180
|
+
|
|
181
|
+
Parameters
|
|
182
|
+
----------
|
|
183
|
+
X : polars.DataFrame, pandas.DataFrame, or numpy.ndarray
|
|
184
|
+
Training data.
|
|
185
|
+
|
|
186
|
+
y : polars.Series, pandas.Series, numpy.ndarray, or None
|
|
187
|
+
Target variable.
|
|
188
|
+
|
|
189
|
+
groups : polars.Series, pandas.Series, numpy.ndarray, or None
|
|
190
|
+
Group labels.
|
|
191
|
+
|
|
192
|
+
Returns:
|
|
193
|
+
-------
|
|
194
|
+
n_samples : int
|
|
195
|
+
The number of samples in the data.
|
|
196
|
+
|
|
197
|
+
Raises:
|
|
198
|
+
------
|
|
199
|
+
ValueError
|
|
200
|
+
If the input data has inconsistent lengths.
|
|
201
|
+
TypeError
|
|
202
|
+
If the input data types are not supported.
|
|
203
|
+
"""
|
|
204
|
+
n_samples = self._get_n_samples(X)
|
|
205
|
+
|
|
206
|
+
# Validate y if provided
|
|
207
|
+
if y is not None:
|
|
208
|
+
if isinstance(y, pl.Series | pd.Series):
|
|
209
|
+
n_y = len(y)
|
|
210
|
+
elif isinstance(y, np.ndarray):
|
|
211
|
+
n_y = y.shape[0]
|
|
212
|
+
else:
|
|
213
|
+
raise TypeError(
|
|
214
|
+
f"y must be a Polars Series, Pandas Series, or numpy array. Got {type(y).__name__}",
|
|
215
|
+
)
|
|
216
|
+
|
|
217
|
+
if n_y != n_samples:
|
|
218
|
+
raise ValueError(
|
|
219
|
+
f"X and y have inconsistent lengths: X has {n_samples} samples, y has {n_y} samples",
|
|
220
|
+
)
|
|
221
|
+
|
|
222
|
+
# Validate groups if provided
|
|
223
|
+
if groups is not None:
|
|
224
|
+
if isinstance(groups, pl.Series | pd.Series):
|
|
225
|
+
n_groups = len(groups)
|
|
226
|
+
elif isinstance(groups, np.ndarray):
|
|
227
|
+
n_groups = groups.shape[0]
|
|
228
|
+
else:
|
|
229
|
+
raise TypeError(
|
|
230
|
+
f"groups must be a Polars Series, Pandas Series, or numpy array. Got {type(groups).__name__}",
|
|
231
|
+
)
|
|
232
|
+
|
|
233
|
+
if n_groups != n_samples:
|
|
234
|
+
raise ValueError(
|
|
235
|
+
f"X and groups have inconsistent lengths: X has {n_samples} samples, groups has {n_groups} samples",
|
|
236
|
+
)
|
|
237
|
+
|
|
238
|
+
return n_samples
|
|
239
|
+
|
|
240
|
+
def _validate_session_alignment(
|
|
241
|
+
self,
|
|
242
|
+
X: Union[pl.DataFrame, pd.DataFrame, "NDArray[Any]"],
|
|
243
|
+
align_to_sessions: bool,
|
|
244
|
+
session_col: str,
|
|
245
|
+
) -> None:
|
|
246
|
+
"""Validate that session column exists if session alignment is enabled.
|
|
247
|
+
|
|
248
|
+
Parameters
|
|
249
|
+
----------
|
|
250
|
+
X : polars.DataFrame, pandas.DataFrame, or numpy.ndarray
|
|
251
|
+
Training data that may contain session column.
|
|
252
|
+
|
|
253
|
+
align_to_sessions : bool
|
|
254
|
+
Whether session alignment is requested.
|
|
255
|
+
|
|
256
|
+
session_col : str
|
|
257
|
+
Name of the session column to look for.
|
|
258
|
+
|
|
259
|
+
Raises
|
|
260
|
+
------
|
|
261
|
+
ValueError
|
|
262
|
+
If align_to_sessions=True but session column is missing or X is not a DataFrame.
|
|
263
|
+
|
|
264
|
+
Notes
|
|
265
|
+
-----
|
|
266
|
+
This method provides helpful error messages that guide users to the qdata library
|
|
267
|
+
for session assignment if the required column is missing.
|
|
268
|
+
"""
|
|
269
|
+
if not align_to_sessions:
|
|
270
|
+
return # Skip validation if not using sessions
|
|
271
|
+
|
|
272
|
+
# Check that X is a DataFrame (sessions require column access)
|
|
273
|
+
if not hasattr(X, "columns"):
|
|
274
|
+
raise ValueError(
|
|
275
|
+
f"align_to_sessions=True requires X to be a DataFrame "
|
|
276
|
+
f"(Polars or Pandas), got {type(X).__name__}.\n"
|
|
277
|
+
f"\n"
|
|
278
|
+
f"Session alignment works with tabular data that has a session "
|
|
279
|
+
f"identifier column. NumPy arrays do not support column names."
|
|
280
|
+
)
|
|
281
|
+
|
|
282
|
+
# Check for session column
|
|
283
|
+
columns = list(cast(Any, X.columns))
|
|
284
|
+
if session_col not in columns:
|
|
285
|
+
raise ValueError(
|
|
286
|
+
f"align_to_sessions=True requires '{session_col}' column in X, "
|
|
287
|
+
f"but it was not found.\n"
|
|
288
|
+
f"\n"
|
|
289
|
+
f"Available columns: {columns}\n"
|
|
290
|
+
f"\n"
|
|
291
|
+
f"To add session dates to your data using the qdata library:\n"
|
|
292
|
+
f"\n"
|
|
293
|
+
f"Option 1 - Using DataManager (recommended):\n"
|
|
294
|
+
f" from qdata import DataManager\n"
|
|
295
|
+
f" manager = DataManager()\n"
|
|
296
|
+
f" df = manager.load(\n"
|
|
297
|
+
f" symbol='BTC',\n"
|
|
298
|
+
f" exchange='CME',\n"
|
|
299
|
+
f" calendar='CME_Globex_Crypto'\n"
|
|
300
|
+
f" )\n"
|
|
301
|
+
f" # df now has '{session_col}' column automatically\n"
|
|
302
|
+
f"\n"
|
|
303
|
+
f"Option 2 - Using SessionAssigner directly:\n"
|
|
304
|
+
f" from ml4t.data.sessions import SessionAssigner\n"
|
|
305
|
+
f" assigner = SessionAssigner.from_exchange('CME')\n"
|
|
306
|
+
f" df_with_sessions = assigner.assign_sessions(df)\n"
|
|
307
|
+
f"\n"
|
|
308
|
+
f"Option 3 - If you have a different session column:\n"
|
|
309
|
+
f" cv = {self.__class__.__name__}(\n"
|
|
310
|
+
f" ...,\n"
|
|
311
|
+
f" align_to_sessions=True,\n"
|
|
312
|
+
f" session_col='your_column_name' # Specify your column\n"
|
|
313
|
+
f" )\n"
|
|
314
|
+
f"\n"
|
|
315
|
+
f"Option 4 - Disable session alignment:\n"
|
|
316
|
+
f" cv = {self.__class__.__name__}(\n"
|
|
317
|
+
f" ...,\n"
|
|
318
|
+
f" align_to_sessions=False # Use standard splitting\n"
|
|
319
|
+
f" )\n"
|
|
320
|
+
)
|
|
321
|
+
|
|
322
|
+
def _get_unique_sessions(
|
|
323
|
+
self,
|
|
324
|
+
X: pl.DataFrame | pd.DataFrame,
|
|
325
|
+
session_col: str,
|
|
326
|
+
) -> pl.Series | pd.Series:
|
|
327
|
+
"""Extract unique session identifiers in order of first appearance.
|
|
328
|
+
|
|
329
|
+
Parameters
|
|
330
|
+
----------
|
|
331
|
+
X : polars.DataFrame or pandas.DataFrame
|
|
332
|
+
Data containing session column.
|
|
333
|
+
|
|
334
|
+
session_col : str
|
|
335
|
+
Name of the session column.
|
|
336
|
+
|
|
337
|
+
Returns
|
|
338
|
+
-------
|
|
339
|
+
sessions : polars.Series or pandas.Series
|
|
340
|
+
Unique session identifiers in order of first appearance.
|
|
341
|
+
|
|
342
|
+
Notes
|
|
343
|
+
-----
|
|
344
|
+
Sessions are returned in the order they first appear in the data, which
|
|
345
|
+
is the correct chronological order if the data is sorted by time (as it
|
|
346
|
+
should be for time-series cross-validation).
|
|
347
|
+
|
|
348
|
+
Previously this method sorted by session ID, which is incorrect when
|
|
349
|
+
session IDs are not naturally sortable in chronological order.
|
|
350
|
+
"""
|
|
351
|
+
if isinstance(X, pl.DataFrame):
|
|
352
|
+
# maintain_order=True preserves order of first appearance
|
|
353
|
+
return X[session_col].unique(maintain_order=True)
|
|
354
|
+
else: # pandas DataFrame
|
|
355
|
+
# drop_duplicates without sorting preserves first appearance order
|
|
356
|
+
return X[session_col].drop_duplicates().reset_index(drop=True)
|
|
357
|
+
|
|
358
|
+
def _session_to_indices(
|
|
359
|
+
self,
|
|
360
|
+
X: pl.DataFrame | pd.DataFrame,
|
|
361
|
+
session_col: str,
|
|
362
|
+
) -> tuple[list[Any], dict[Any, "NDArray[np.intp]"]]:
|
|
363
|
+
"""Map each session to its row indices, preserving appearance order.
|
|
364
|
+
|
|
365
|
+
This is the key helper for session-aligned CV. It returns EXACT indices
|
|
366
|
+
per session, not (start, end) boundaries, which is critical for correct
|
|
367
|
+
behavior with non-contiguous or interleaved data.
|
|
368
|
+
|
|
369
|
+
Parameters
|
|
370
|
+
----------
|
|
371
|
+
X : polars.DataFrame or pandas.DataFrame
|
|
372
|
+
Data containing session column.
|
|
373
|
+
|
|
374
|
+
session_col : str
|
|
375
|
+
Name of the session column.
|
|
376
|
+
|
|
377
|
+
Returns
|
|
378
|
+
-------
|
|
379
|
+
ordered_sessions : list
|
|
380
|
+
Session IDs in order of first appearance.
|
|
381
|
+
|
|
382
|
+
session_indices : dict
|
|
383
|
+
Mapping from session ID to numpy array of row indices (sorted).
|
|
384
|
+
|
|
385
|
+
Examples
|
|
386
|
+
--------
|
|
387
|
+
>>> # Data with interleaved assets
|
|
388
|
+
>>> X = pl.DataFrame({
|
|
389
|
+
... "session": ["A", "A", "B", "A", "B"],
|
|
390
|
+
... "asset": ["X", "Y", "X", "X", "Y"]
|
|
391
|
+
... })
|
|
392
|
+
>>> sessions, indices = splitter._session_to_indices(X, "session")
|
|
393
|
+
>>> sessions
|
|
394
|
+
['A', 'B']
|
|
395
|
+
>>> indices['A']
|
|
396
|
+
array([0, 1, 3]) # Exact indices, NOT range(0, 3)
|
|
397
|
+
>>> indices['B']
|
|
398
|
+
array([2, 4])
|
|
399
|
+
"""
|
|
400
|
+
if isinstance(X, pl.DataFrame):
|
|
401
|
+
# Polars: use group_by with maintain_order=True
|
|
402
|
+
# Add row indices, group by session, collect indices per group
|
|
403
|
+
df_with_idx = X.with_row_index("__row_idx__")
|
|
404
|
+
grouped = df_with_idx.group_by(session_col, maintain_order=True).agg(
|
|
405
|
+
pl.col("__row_idx__")
|
|
406
|
+
)
|
|
407
|
+
ordered_sessions = grouped[session_col].to_list()
|
|
408
|
+
session_indices = {
|
|
409
|
+
row[session_col]: np.array(row["__row_idx__"], dtype=np.intp)
|
|
410
|
+
for row in grouped.iter_rows(named=True)
|
|
411
|
+
}
|
|
412
|
+
else:
|
|
413
|
+
# Pandas: use groupby().indices (fast, returns dict of arrays)
|
|
414
|
+
grouped = X.groupby(session_col, sort=False)
|
|
415
|
+
session_indices_raw = grouped.indices
|
|
416
|
+
# Preserve appearance order using drop_duplicates
|
|
417
|
+
ordered_sessions = X[session_col].drop_duplicates().tolist()
|
|
418
|
+
session_indices = {
|
|
419
|
+
session: np.array(session_indices_raw[session], dtype=np.intp)
|
|
420
|
+
for session in ordered_sessions
|
|
421
|
+
}
|
|
422
|
+
|
|
423
|
+
return ordered_sessions, session_indices
|
|
424
|
+
|
|
425
|
+
def _extract_timestamps(
|
|
426
|
+
self,
|
|
427
|
+
X: Union[pl.DataFrame, pd.DataFrame, "NDArray[Any]"],
|
|
428
|
+
timestamp_col: str | None = None,
|
|
429
|
+
) -> pd.DatetimeIndex | None:
|
|
430
|
+
"""Extract timestamps from data for time-based size calculations.
|
|
431
|
+
|
|
432
|
+
This method supports both Polars and pandas DataFrames, enabling
|
|
433
|
+
time-based test_size/train_size specifications (e.g., '4W', '3M').
|
|
434
|
+
|
|
435
|
+
Parameters
|
|
436
|
+
----------
|
|
437
|
+
X : polars.DataFrame, pandas.DataFrame, or numpy.ndarray
|
|
438
|
+
Input data.
|
|
439
|
+
timestamp_col : str or None
|
|
440
|
+
Column name containing timestamps for Polars DataFrames.
|
|
441
|
+
If None, falls back to pandas DatetimeIndex (backward compatible).
|
|
442
|
+
|
|
443
|
+
Returns
|
|
444
|
+
-------
|
|
445
|
+
timestamps : pandas.DatetimeIndex or None
|
|
446
|
+
Timestamps as a pandas DatetimeIndex for time-based calculations.
|
|
447
|
+
Returns None if timestamps cannot be extracted.
|
|
448
|
+
|
|
449
|
+
Notes
|
|
450
|
+
-----
|
|
451
|
+
For Polars DataFrames:
|
|
452
|
+
- Requires timestamp_col to be specified
|
|
453
|
+
- Column must be datetime type
|
|
454
|
+
- Converts to pandas DatetimeIndex for compatibility with time parsing
|
|
455
|
+
|
|
456
|
+
For pandas DataFrames:
|
|
457
|
+
- Uses DatetimeIndex if available
|
|
458
|
+
- Falls back to timestamp_col if index is not datetime
|
|
459
|
+
|
|
460
|
+
For numpy arrays:
|
|
461
|
+
- Returns None (no timestamp information available)
|
|
462
|
+
"""
|
|
463
|
+
# Polars DataFrame: extract from column
|
|
464
|
+
if isinstance(X, pl.DataFrame):
|
|
465
|
+
if timestamp_col is None:
|
|
466
|
+
return None
|
|
467
|
+
if timestamp_col not in X.columns:
|
|
468
|
+
raise ValueError(
|
|
469
|
+
f"timestamp_col='{timestamp_col}' not found in Polars DataFrame. "
|
|
470
|
+
f"Available columns: {X.columns}"
|
|
471
|
+
)
|
|
472
|
+
# Convert Polars datetime column to pandas DatetimeIndex
|
|
473
|
+
ts_series = X[timestamp_col].to_pandas()
|
|
474
|
+
if not pd.api.types.is_datetime64_any_dtype(ts_series):
|
|
475
|
+
raise ValueError(
|
|
476
|
+
f"timestamp_col='{timestamp_col}' must be datetime type, "
|
|
477
|
+
f"got {X[timestamp_col].dtype}"
|
|
478
|
+
)
|
|
479
|
+
idx = pd.DatetimeIndex(ts_series)
|
|
480
|
+
# Ensure timezone awareness (required for purging/embargo)
|
|
481
|
+
if idx.tz is None:
|
|
482
|
+
idx = idx.tz_localize("UTC")
|
|
483
|
+
return idx
|
|
484
|
+
|
|
485
|
+
# pandas DataFrame: prefer index, fallback to column
|
|
486
|
+
if isinstance(X, pd.DataFrame):
|
|
487
|
+
if isinstance(X.index, pd.DatetimeIndex):
|
|
488
|
+
return X.index
|
|
489
|
+
# Fallback: try timestamp_col if specified
|
|
490
|
+
if timestamp_col is not None and timestamp_col in X.columns:
|
|
491
|
+
ts_series = X[timestamp_col]
|
|
492
|
+
if pd.api.types.is_datetime64_any_dtype(ts_series):
|
|
493
|
+
return pd.DatetimeIndex(ts_series)
|
|
494
|
+
return None
|
|
495
|
+
|
|
496
|
+
# numpy array: no timestamp information
|
|
497
|
+
return None
|
|
498
|
+
|
|
499
|
+
def __repr__(self) -> str:
|
|
500
|
+
"""Return a string representation of the splitter."""
|
|
501
|
+
return f"{self.__class__.__name__}()"
|