ml4t_diagnostic-0.1.0a1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ml4t/diagnostic/AGENT.md +25 -0
- ml4t/diagnostic/__init__.py +166 -0
- ml4t/diagnostic/backends/__init__.py +10 -0
- ml4t/diagnostic/backends/adapter.py +192 -0
- ml4t/diagnostic/backends/polars_backend.py +899 -0
- ml4t/diagnostic/caching/__init__.py +40 -0
- ml4t/diagnostic/caching/cache.py +331 -0
- ml4t/diagnostic/caching/decorators.py +131 -0
- ml4t/diagnostic/caching/smart_cache.py +339 -0
- ml4t/diagnostic/config/AGENT.md +24 -0
- ml4t/diagnostic/config/README.md +267 -0
- ml4t/diagnostic/config/__init__.py +219 -0
- ml4t/diagnostic/config/barrier_config.py +277 -0
- ml4t/diagnostic/config/base.py +301 -0
- ml4t/diagnostic/config/event_config.py +148 -0
- ml4t/diagnostic/config/feature_config.py +404 -0
- ml4t/diagnostic/config/multi_signal_config.py +55 -0
- ml4t/diagnostic/config/portfolio_config.py +215 -0
- ml4t/diagnostic/config/report_config.py +391 -0
- ml4t/diagnostic/config/sharpe_config.py +202 -0
- ml4t/diagnostic/config/signal_config.py +206 -0
- ml4t/diagnostic/config/trade_analysis_config.py +310 -0
- ml4t/diagnostic/config/validation.py +279 -0
- ml4t/diagnostic/core/__init__.py +29 -0
- ml4t/diagnostic/core/numba_utils.py +315 -0
- ml4t/diagnostic/core/purging.py +372 -0
- ml4t/diagnostic/core/sampling.py +471 -0
- ml4t/diagnostic/errors/__init__.py +205 -0
- ml4t/diagnostic/evaluation/AGENT.md +26 -0
- ml4t/diagnostic/evaluation/__init__.py +437 -0
- ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
- ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
- ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
- ml4t/diagnostic/evaluation/dashboard.py +715 -0
- ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
- ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
- ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
- ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
- ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
- ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
- ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
- ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
- ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
- ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
- ml4t/diagnostic/evaluation/event_analysis.py +647 -0
- ml4t/diagnostic/evaluation/excursion.py +390 -0
- ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
- ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
- ml4t/diagnostic/evaluation/framework.py +935 -0
- ml4t/diagnostic/evaluation/metric_registry.py +255 -0
- ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
- ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
- ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
- ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
- ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
- ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
- ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
- ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
- ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
- ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
- ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
- ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
- ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
- ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
- ml4t/diagnostic/evaluation/multi_signal.py +550 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
- ml4t/diagnostic/evaluation/report_generation.py +824 -0
- ml4t/diagnostic/evaluation/signal_selector.py +452 -0
- ml4t/diagnostic/evaluation/stat_registry.py +139 -0
- ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
- ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
- ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
- ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
- ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
- ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
- ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
- ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
- ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
- ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
- ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
- ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
- ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
- ml4t/diagnostic/evaluation/stats/moments.py +164 -0
- ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
- ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
- ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
- ml4t/diagnostic/evaluation/themes.py +330 -0
- ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
- ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
- ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
- ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
- ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
- ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
- ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
- ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
- ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
- ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
- ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
- ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
- ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
- ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
- ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
- ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
- ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
- ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
- ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
- ml4t/diagnostic/evaluation/validated_cv.py +535 -0
- ml4t/diagnostic/evaluation/visualization.py +1050 -0
- ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
- ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
- ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
- ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
- ml4t/diagnostic/integration/__init__.py +48 -0
- ml4t/diagnostic/integration/backtest_contract.py +671 -0
- ml4t/diagnostic/integration/data_contract.py +316 -0
- ml4t/diagnostic/integration/engineer_contract.py +226 -0
- ml4t/diagnostic/logging/__init__.py +77 -0
- ml4t/diagnostic/logging/logger.py +245 -0
- ml4t/diagnostic/logging/performance.py +234 -0
- ml4t/diagnostic/logging/progress.py +234 -0
- ml4t/diagnostic/logging/wandb.py +412 -0
- ml4t/diagnostic/metrics/__init__.py +9 -0
- ml4t/diagnostic/metrics/percentiles.py +128 -0
- ml4t/diagnostic/py.typed +1 -0
- ml4t/diagnostic/reporting/__init__.py +43 -0
- ml4t/diagnostic/reporting/base.py +130 -0
- ml4t/diagnostic/reporting/html_renderer.py +275 -0
- ml4t/diagnostic/reporting/json_renderer.py +51 -0
- ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
- ml4t/diagnostic/results/AGENT.md +24 -0
- ml4t/diagnostic/results/__init__.py +105 -0
- ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
- ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
- ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
- ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
- ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
- ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
- ml4t/diagnostic/results/barrier_results/validation.py +38 -0
- ml4t/diagnostic/results/base.py +177 -0
- ml4t/diagnostic/results/event_results.py +349 -0
- ml4t/diagnostic/results/feature_results.py +787 -0
- ml4t/diagnostic/results/multi_signal_results.py +431 -0
- ml4t/diagnostic/results/portfolio_results.py +281 -0
- ml4t/diagnostic/results/sharpe_results.py +448 -0
- ml4t/diagnostic/results/signal_results/__init__.py +74 -0
- ml4t/diagnostic/results/signal_results/ic.py +581 -0
- ml4t/diagnostic/results/signal_results/irtc.py +110 -0
- ml4t/diagnostic/results/signal_results/quantile.py +392 -0
- ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
- ml4t/diagnostic/results/signal_results/turnover.py +213 -0
- ml4t/diagnostic/results/signal_results/validation.py +147 -0
- ml4t/diagnostic/signal/AGENT.md +17 -0
- ml4t/diagnostic/signal/__init__.py +69 -0
- ml4t/diagnostic/signal/_report.py +152 -0
- ml4t/diagnostic/signal/_utils.py +261 -0
- ml4t/diagnostic/signal/core.py +275 -0
- ml4t/diagnostic/signal/quantile.py +148 -0
- ml4t/diagnostic/signal/result.py +214 -0
- ml4t/diagnostic/signal/signal_ic.py +129 -0
- ml4t/diagnostic/signal/turnover.py +182 -0
- ml4t/diagnostic/splitters/AGENT.md +19 -0
- ml4t/diagnostic/splitters/__init__.py +36 -0
- ml4t/diagnostic/splitters/base.py +501 -0
- ml4t/diagnostic/splitters/calendar.py +421 -0
- ml4t/diagnostic/splitters/calendar_config.py +91 -0
- ml4t/diagnostic/splitters/combinatorial.py +1064 -0
- ml4t/diagnostic/splitters/config.py +322 -0
- ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
- ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
- ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
- ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
- ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
- ml4t/diagnostic/splitters/group_isolation.py +329 -0
- ml4t/diagnostic/splitters/persistence.py +316 -0
- ml4t/diagnostic/splitters/utils.py +207 -0
- ml4t/diagnostic/splitters/walk_forward.py +757 -0
- ml4t/diagnostic/utils/__init__.py +42 -0
- ml4t/diagnostic/utils/config.py +542 -0
- ml4t/diagnostic/utils/dependencies.py +318 -0
- ml4t/diagnostic/utils/sessions.py +127 -0
- ml4t/diagnostic/validation/__init__.py +54 -0
- ml4t/diagnostic/validation/dataframe.py +274 -0
- ml4t/diagnostic/validation/returns.py +280 -0
- ml4t/diagnostic/validation/timeseries.py +299 -0
- ml4t/diagnostic/visualization/AGENT.md +19 -0
- ml4t/diagnostic/visualization/__init__.py +223 -0
- ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
- ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
- ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
- ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
- ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
- ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
- ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
- ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
- ml4t/diagnostic/visualization/barrier_plots.py +782 -0
- ml4t/diagnostic/visualization/core.py +1060 -0
- ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
- ml4t/diagnostic/visualization/dashboards/base.py +582 -0
- ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
- ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
- ml4t/diagnostic/visualization/dashboards.py +43 -0
- ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
- ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
- ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
- ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
- ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
- ml4t/diagnostic/visualization/feature_plots.py +888 -0
- ml4t/diagnostic/visualization/interaction_plots.py +618 -0
- ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
- ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
- ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
- ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
- ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
- ml4t/diagnostic/visualization/report_generation.py +1343 -0
- ml4t/diagnostic/visualization/signal/__init__.py +103 -0
- ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
- ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
- ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
- ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
- ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
- ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
- ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
- ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
- ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
- ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
- ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
- ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
ml4t/diagnostic/splitters/cpcv/windows.py

@@ -0,0 +1,190 @@
"""Time window computation for CPCV purging.

This module handles computing purge windows from test indices:
- Timestamp windows from exact indices
- Contiguous segment detection
- Window merging for efficient purging
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING

import numpy as np
from numpy.typing import NDArray

if TYPE_CHECKING:
    import pandas as pd


@dataclass(frozen=True)
class TimeWindow:
    """A time window for purging, with exclusive end bound.

    Attributes
    ----------
    start : pd.Timestamp
        Start of the window (inclusive).
    end_exclusive : pd.Timestamp
        End of the window (exclusive).
    """

    start: pd.Timestamp
    end_exclusive: pd.Timestamp


def timestamp_window_from_indices(
    indices: NDArray[np.intp],
    timestamps: pd.DatetimeIndex,
) -> TimeWindow | None:
    """Compute timestamp window from actual indices.

    This is critical for correct purging in session-aligned mode. Instead of
    using (min_row_idx, max_row_idx) boundaries, which can span unrelated rows
    in interleaved data, we compute the actual timestamp bounds from the test
    indices.

    Parameters
    ----------
    indices : ndarray
        Row indices of test samples.
    timestamps : pd.DatetimeIndex
        Timestamps for all samples.

    Returns
    -------
    TimeWindow or None
        Window with (start_time, end_time_exclusive) if indices is non-empty.
        None if indices is empty (signals the caller to skip purging).

    Notes
    -----
    The end is made exclusive by adding 1 nanosecond. This handles the case
    of duplicate timestamps at the boundary.

    Examples
    --------
    >>> import pandas as pd
    >>> import numpy as np
    >>> timestamps = pd.date_range("2020-01-01", periods=10, freq="D", tz="UTC")
    >>> indices = np.array([2, 3, 4])
    >>> window = timestamp_window_from_indices(indices, timestamps)
    >>> window.start
    Timestamp('2020-01-03 00:00:00+0000', tz='UTC')
    """
    import pandas as pd

    if len(indices) == 0:
        # Empty indices - return None to signal callers to skip purging
        return None

    test_timestamps = timestamps.take(indices)
    start_time = test_timestamps.min()
    # Add 1 nanosecond to make end exclusive (handles duplicate timestamps)
    end_time_exclusive = test_timestamps.max() + pd.Timedelta(1, "ns")
    return TimeWindow(start=start_time, end_exclusive=end_time_exclusive)
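A TimeWindow is a half-open interval, so applying it is a pair of comparisons. The purge engine that actually consumes these windows lives in ml4t/diagnostic/splitters/cpcv/purge_engine.py and is not part of this hunk; the sketch below only illustrates the interval semantics of the function above:

    import numpy as np
    import pandas as pd
    from ml4t.diagnostic.splitters.cpcv.windows import timestamp_window_from_indices

    timestamps = pd.date_range("2020-01-01", periods=10, freq="D", tz="UTC")
    window = timestamp_window_from_indices(np.array([2, 3, 4]), timestamps)

    # Half-open interval [start, end_exclusive): the 1 ns bump means rows that
    # duplicate the last test timestamp are still caught by the mask.
    in_window = (timestamps >= window.start) & (timestamps < window.end_exclusive)
    assert in_window.sum() == 3  # exactly the rows at indices 2, 3, 4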
def find_contiguous_segments(
    test_groups_data: list[tuple[int, int, int, NDArray[np.intp] | None]],
    asset_indices: NDArray[np.intp],
) -> list[list[tuple[int, int, int, NDArray[np.intp]]]]:
    """Find contiguous segments of test groups for a given asset.

    Groups test data into contiguous segments based on temporal adjacency.
    This allows applying one purge window per segment instead of per group,
    which is more efficient and statistically correct.

    Parameters
    ----------
    test_groups_data : list of tuple
        Each tuple contains (group_idx, group_start, group_end, exact_indices).
        exact_indices is non-None for session-aligned mode.
    asset_indices : ndarray
        Indices belonging to this asset.

    Returns
    -------
    segments : list of list of tuple
        Each segment is a list of (group_idx, start, end, asset_test_indices).
        Segments are separated by gaps in the test groups.

    Notes
    -----
    In session-aligned mode, exact_indices should be used instead of
    generating indices via np.arange (which is wrong for interleaved data).
    """
    contiguous_segments: list[list[tuple[int, int, int, NDArray[np.intp]]]] = []
    current_segment: list[tuple[int, int, int, NDArray[np.intp]]] = []

    for group_idx, group_start, group_end, exact_indices in test_groups_data:
        # Get test indices for this asset in this group
        if exact_indices is not None:
            # Session-aligned mode: use exact indices
            group_test_indices = exact_indices
        else:
            # Standard mode: generate from boundaries
            group_test_indices = np.arange(group_start, group_end)
        asset_group_test_indices = np.intersect1d(group_test_indices, asset_indices)

        if len(asset_group_test_indices) == 0:
            # No test data for this asset in this group
            if current_segment:
                contiguous_segments.append(current_segment)
                current_segment = []
            continue

        # Check whether this group is contiguous with the previous segment.
        # current_segment[-1][2] is the previous group's end (exclusive), so a
        # gap exists when this group's start lies strictly beyond it.
        if current_segment and group_start > current_segment[-1][2]:  # Gap detected
            # Finish the current segment and start a new one
            contiguous_segments.append(current_segment)
            current_segment = [(group_idx, group_start, group_end, asset_group_test_indices)]
        else:
            # Add to current segment
            current_segment.append((group_idx, group_start, group_end, asset_group_test_indices))

    # Don't forget the last segment
    if current_segment:
        contiguous_segments.append(current_segment)

    return contiguous_segments
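Unlike its siblings, find_contiguous_segments ships without a doctest. A minimal illustrative run, assuming the module path ml4t.diagnostic.splitters.cpcv.windows implied by the file listing: two temporally adjacent groups collapse into one segment, and a gap starts a new one.

    import numpy as np
    from ml4t.diagnostic.splitters.cpcv.windows import find_contiguous_segments

    # Three test groups in standard mode (exact_indices=None); groups 0 and 1
    # are adjacent (rows 0-10 and 10-20), group 3 starts at row 30 after a gap.
    test_groups_data = [
        (0, 0, 10, None),
        (1, 10, 20, None),
        (3, 30, 40, None),
    ]
    asset_indices = np.arange(0, 40)  # this asset owns every row

    segments = find_contiguous_segments(test_groups_data, asset_indices)
    assert len(segments) == 2                      # the gap at rows 20-30 splits them
    assert [g[0] for g in segments[0]] == [0, 1]   # adjacent groups share a segment
    assert [g[0] for g in segments[1]] == [3]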
def merge_windows(windows: list[TimeWindow]) -> list[TimeWindow]:
    """Merge overlapping time windows.

    This can reduce the number of purge operations when windows overlap,
    and provides clearer semantics about what's being purged.

    Parameters
    ----------
    windows : list of TimeWindow
        Windows to merge.

    Returns
    -------
    merged : list of TimeWindow
        Non-overlapping windows covering the same time ranges.
    """
    if not windows:
        return []

    # Sort by start time
    sorted_windows = sorted(windows, key=lambda w: w.start)
    merged = [sorted_windows[0]]

    for window in sorted_windows[1:]:
        last = merged[-1]
        if window.start <= last.end_exclusive:
            # Overlapping - merge by extending end
            merged[-1] = TimeWindow(
                start=last.start,
                end_exclusive=max(last.end_exclusive, window.end_exclusive),
            )
        else:
            # Non-overlapping - add new window
            merged.append(window)

    return merged
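Taken together, the three helpers turn per-group test indices into a small set of disjoint purge intervals. A short sketch of that composition, again assuming the ml4t.diagnostic.splitters.cpcv.windows import path:

    import numpy as np
    import pandas as pd
    from ml4t.diagnostic.splitters.cpcv.windows import (
        merge_windows,
        timestamp_window_from_indices,
    )

    timestamps = pd.date_range("2020-01-01", periods=10, freq="D", tz="UTC")
    w1 = timestamp_window_from_indices(np.array([0, 1, 2]), timestamps)
    w2 = timestamp_window_from_indices(np.array([2, 3]), timestamps)  # overlaps w1
    w3 = timestamp_window_from_indices(np.array([7, 8]), timestamps)  # disjoint

    merged = merge_windows([w1, w2, w3])
    assert len(merged) == 2  # w1 and w2 collapse into one window
    assert merged[0].start == timestamps[0]
    assert merged[0].end_exclusive == timestamps[3] + pd.Timedelta(1, "ns")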
ml4t/diagnostic/splitters/group_isolation.py

@@ -0,0 +1,329 @@
"""Group isolation utilities for multi-asset cross-validation.

This module provides utilities to prevent the same asset (e.g., contract, symbol)
from appearing in both training and test sets during cross-validation. This is
critical for avoiding data leakage in multi-asset strategies.

Example Use Cases
-----------------
1. **Futures contracts**: Prevent ES_202312 from being in both train and test
2. **Multiple symbols**: Ensure AAPL data doesn't leak between folds
3. **Multi-strategy**: Isolate strategies to prevent cross-contamination

Integration with qdata
----------------------
The `groups` parameter should contain asset identifiers that come from your
data pipeline. Typically this would be a column like 'symbol', 'contract',
or 'asset_id' from your DataFrame.

Example::

    import polars as pl
    from ml4t.diagnostic.splitters import PurgedWalkForwardCV

    # Data with asset identifiers
    df = pl.DataFrame({
        'timestamp': [...],
        'symbol': ['AAPL', 'AAPL', 'MSFT', 'MSFT', ...],
        'returns': [...]
    })

    # Cross-validate with group isolation
    cv = PurgedWalkForwardCV(n_splits=5, isolate_groups=True)

    for train_idx, test_idx in cv.split(df, groups=df['symbol']):
        # Groups in test_idx will NEVER appear in train_idx
        train_symbols = df[train_idx]['symbol'].unique()
        test_symbols = df[test_idx]['symbol'].unique()
        assert len(set(train_symbols) & set(test_symbols)) == 0
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

import numpy as np
import pandas as pd
import polars as pl

from ml4t.diagnostic.backends.adapter import DataFrameAdapter

if TYPE_CHECKING:
    from numpy.typing import NDArray
def validate_group_isolation(
    train_indices: NDArray[np.intp],
    test_indices: NDArray[np.intp],
    groups: pl.Series | pd.Series | NDArray[Any],
) -> tuple[bool, set]:
    """Validate that train and test sets have no overlapping groups.

    Parameters
    ----------
    train_indices : ndarray
        Training set indices.

    test_indices : ndarray
        Test set indices.

    groups : array-like
        Group labels for each sample.

    Returns
    -------
    is_valid : bool
        True if no groups overlap between train and test.

    overlapping_groups : set
        Set of group IDs that appear in both train and test.
        Empty if is_valid=True.

    Examples
    --------
    >>> import numpy as np
    >>> train_idx = np.array([0, 1, 2, 3])
    >>> test_idx = np.array([4, 5, 6, 7])
    >>> groups = np.array(['A', 'A', 'B', 'B', 'C', 'C', 'D', 'D'])
    >>> is_valid, overlap = validate_group_isolation(train_idx, test_idx, groups)
    >>> assert is_valid  # Groups don't overlap
    >>> assert len(overlap) == 0
    """
    # Convert groups to numpy array
    groups_array = DataFrameAdapter.to_numpy(groups).flatten()

    # Get unique groups in train and test
    train_groups = set(groups_array[train_indices])
    test_groups = set(groups_array[test_indices])

    # Find overlap
    overlapping_groups = train_groups & test_groups

    return len(overlapping_groups) == 0, overlapping_groups
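The doctest above only exercises the clean case. For contrast, a leaking split, using the same hypothetical toy data:

    import numpy as np
    from ml4t.diagnostic.splitters.group_isolation import validate_group_isolation

    groups = np.array(['A', 'A', 'B', 'B', 'C', 'C', 'D', 'D'])
    train_idx = np.array([0, 1, 2, 3, 4])  # index 4 belongs to group 'C'
    test_idx = np.array([5, 6, 7])         # groups 'C' and 'D'

    is_valid, overlap = validate_group_isolation(train_idx, test_idx, groups)
    assert not is_valid
    assert overlap == {'C'}  # the leaked group is reported, not just a boolean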
def isolate_groups_from_train(
    train_indices: NDArray[np.intp],
    test_indices: NDArray[np.intp],
    groups: pl.Series | pd.Series | NDArray[Any],
) -> NDArray[np.intp]:
    """Remove samples from training set that share groups with test set.

    This function ensures strict group isolation by removing all training
    samples whose group appears anywhere in the test set.

    Parameters
    ----------
    train_indices : ndarray
        Initial training set indices.

    test_indices : ndarray
        Test set indices.

    groups : array-like
        Group labels for each sample.

    Returns
    -------
    clean_train_indices : ndarray
        Training indices with test groups removed.

    Examples
    --------
    >>> import numpy as np
    >>> train_idx = np.array([0, 1, 2, 3, 4, 5])
    >>> test_idx = np.array([6, 7])
    >>> groups = np.array(['A', 'A', 'B', 'B', 'C', 'C', 'C', 'C'])
    >>> clean_train = isolate_groups_from_train(train_idx, test_idx, groups)
    >>> # Removes indices 4, 5 because they share group 'C' with test indices 6, 7
    >>> assert all(groups[clean_train] != 'C')

    Notes
    -----
    This can significantly reduce training set size if groups are imbalanced.
    Consider using group-aware splitting strategies to maintain balanced folds.
    """
    # Convert groups to numpy array
    groups_array = DataFrameAdapter.to_numpy(groups).flatten()

    # Get unique groups in test set
    test_groups = set(groups_array[test_indices])

    # Filter train indices to exclude any samples from test groups
    clean_train_mask = np.array([groups_array[idx] not in test_groups for idx in train_indices])

    return train_indices[clean_train_mask]
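A typical repair loop chains the two functions: drop the offending rows, then re-validate. A sketch under the same assumptions (ml4t.diagnostic.splitters.group_isolation module path, toy data):

    import numpy as np
    from ml4t.diagnostic.splitters.group_isolation import (
        isolate_groups_from_train,
        validate_group_isolation,
    )

    groups = np.array(['A', 'A', 'B', 'B', 'C', 'C', 'C', 'C'])
    train_idx = np.array([0, 1, 2, 3, 4, 5])  # indices 4, 5 are group 'C'
    test_idx = np.array([6, 7])               # all group 'C'

    clean_train = isolate_groups_from_train(train_idx, test_idx, groups)
    is_valid, overlap = validate_group_isolation(clean_train, test_idx, groups)
    assert is_valid and overlap == set()
    assert clean_train.tolist() == [0, 1, 2, 3]  # the two 'C' rows were dropped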
def get_group_boundaries(
    groups: pl.Series | pd.Series | NDArray[Any],
    sorted_indices: NDArray[np.intp] | None = None,
) -> dict[Any, tuple[int, int]]:
    """Get start and end indices for each unique group in sorted data.

    This is useful for group-aware splitting where you want to keep groups
    contiguous and avoid splitting a group across train/test boundaries.

    Parameters
    ----------
    groups : array-like
        Group labels for each sample.

    sorted_indices : ndarray, optional
        Pre-sorted indices. If None, assumes data is already sorted by group.

    Returns
    -------
    boundaries : dict
        Mapping from group ID to (start_idx, end_idx) tuple.

    Examples
    --------
    >>> import numpy as np
    >>> groups = np.array(['A', 'A', 'A', 'B', 'B', 'C'])
    >>> boundaries = get_group_boundaries(groups)
    >>> assert boundaries['A'] == (0, 3)
    >>> assert boundaries['B'] == (3, 5)
    >>> assert boundaries['C'] == (5, 6)

    Notes
    -----
    This assumes groups are contiguous in the data. If groups are interleaved,
    provide `sorted_indices` to ensure correct boundary detection.
    """
    # Convert groups to numpy array
    groups_array = DataFrameAdapter.to_numpy(groups).flatten()

    # Apply sorting if provided
    if sorted_indices is not None:
        groups_array = groups_array[sorted_indices]

    # Find boundaries using change detection
    boundaries = {}
    unique_groups = []
    current_group = None
    start_idx = 0

    for i, group_id in enumerate(groups_array):
        if group_id != current_group:
            # Group changed - record previous group's boundary
            if current_group is not None:
                boundaries[current_group] = (start_idx, i)

            # Start new group
            current_group = group_id
            start_idx = i
            unique_groups.append(group_id)

    # Don't forget the last group
    if current_group is not None:
        boundaries[current_group] = (start_idx, len(groups_array))

    return boundaries
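Worth noting for the interleaved case: the boundaries are positions in the sorted view, not original row numbers, so mapping back goes through the sort order. A sketch:

    import numpy as np
    from ml4t.diagnostic.splitters.group_isolation import get_group_boundaries

    groups = np.array(['B', 'A', 'B', 'A'])    # interleaved groups
    order = np.argsort(groups, kind="stable")  # [1, 3, 0, 2] -> A, A, B, B

    boundaries = get_group_boundaries(groups, sorted_indices=order)
    assert boundaries == {'A': (0, 2), 'B': (2, 4)}

    # Positions index the sorted view; order[0:2] recovers the original 'A' rows.
    assert order[0:2].tolist() == [1, 3]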
def split_by_groups(
    n_samples: int,
    groups: pl.Series | pd.Series | NDArray[Any],
    test_group_indices: list[int],
    all_group_ids: list[Any],
) -> tuple[NDArray[np.intp], NDArray[np.intp]]:
    """Split samples into train/test based on group assignments.

    This creates a complete split where all samples from specified test groups
    go to the test set, and all other samples go to the training set.

    Parameters
    ----------
    n_samples : int
        Total number of samples.

    groups : array-like
        Group labels for each sample.

    test_group_indices : list of int
        Indices into `all_group_ids` specifying which groups go to test.

    all_group_ids : list
        Sorted list of all unique group IDs in the dataset.

    Returns
    -------
    train_indices : ndarray
        Indices of samples in training set.

    test_indices : ndarray
        Indices of samples in test set.

    Examples
    --------
    >>> import numpy as np
    >>> groups = np.array(['A', 'A', 'B', 'B', 'C', 'C'])
    >>> all_groups = ['A', 'B', 'C']
    >>> train_idx, test_idx = split_by_groups(
    ...     n_samples=6,
    ...     groups=groups,
    ...     test_group_indices=[2],  # Group 'C'
    ...     all_group_ids=all_groups
    ... )
    >>> assert set(groups[train_idx]) == {'A', 'B'}
    >>> assert set(groups[test_idx]) == {'C'}
    """
    # Convert groups to numpy array
    groups_array = DataFrameAdapter.to_numpy(groups).flatten()

    # Get test group IDs
    test_group_ids = {all_group_ids[i] for i in test_group_indices}

    # Create masks
    test_mask = np.isin(groups_array, list(test_group_ids))
    train_mask = ~test_mask

    # Get indices
    train_indices = np.where(train_mask)[0].astype(np.intp)
    test_indices = np.where(test_mask)[0].astype(np.intp)

    return train_indices, test_indices
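split_by_groups composes naturally into a leave-one-group-out loop. Nothing in this module provides that loop itself, so the following is an assumed usage pattern rather than library API:

    import numpy as np
    from ml4t.diagnostic.splitters.group_isolation import split_by_groups

    groups = np.array(['A', 'A', 'B', 'B', 'C', 'C'])
    all_groups = sorted(set(groups.tolist()))  # ['A', 'B', 'C']

    # Leave-one-group-out: each pass holds out exactly one asset's rows.
    for i, held_out in enumerate(all_groups):
        train_idx, test_idx = split_by_groups(
            n_samples=len(groups),
            groups=groups,
            test_group_indices=[i],
            all_group_ids=all_groups,
        )
        assert set(groups[test_idx]) == {held_out}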
def count_samples_per_group(
    groups: pl.Series | pd.Series | NDArray[Any],
) -> dict[Any, int]:
    """Count number of samples for each unique group.

    Useful for understanding group distribution and detecting imbalanced groups.

    Parameters
    ----------
    groups : array-like
        Group labels for each sample.

    Returns
    -------
    counts : dict
        Mapping from group ID to sample count.

    Examples
    --------
    >>> import numpy as np
    >>> groups = np.array(['A', 'A', 'A', 'B', 'B', 'C'])
    >>> counts = count_samples_per_group(groups)
    >>> assert counts == {'A': 3, 'B': 2, 'C': 1}
    """
    # Convert groups to numpy array
    groups_array = DataFrameAdapter.to_numpy(groups).flatten()

    # Count using numpy unique
    unique_groups, counts = np.unique(groups_array, return_counts=True)

    return dict(zip(unique_groups, counts, strict=False))


# Make functions available at module level
__all__ = [
    "validate_group_isolation",
    "isolate_groups_from_train",
    "get_group_boundaries",
    "split_by_groups",
    "count_samples_per_group",
]