ml4t-diagnostic 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ml4t/diagnostic/AGENT.md +25 -0
- ml4t/diagnostic/__init__.py +166 -0
- ml4t/diagnostic/backends/__init__.py +10 -0
- ml4t/diagnostic/backends/adapter.py +192 -0
- ml4t/diagnostic/backends/polars_backend.py +899 -0
- ml4t/diagnostic/caching/__init__.py +40 -0
- ml4t/diagnostic/caching/cache.py +331 -0
- ml4t/diagnostic/caching/decorators.py +131 -0
- ml4t/diagnostic/caching/smart_cache.py +339 -0
- ml4t/diagnostic/config/AGENT.md +24 -0
- ml4t/diagnostic/config/README.md +267 -0
- ml4t/diagnostic/config/__init__.py +219 -0
- ml4t/diagnostic/config/barrier_config.py +277 -0
- ml4t/diagnostic/config/base.py +301 -0
- ml4t/diagnostic/config/event_config.py +148 -0
- ml4t/diagnostic/config/feature_config.py +404 -0
- ml4t/diagnostic/config/multi_signal_config.py +55 -0
- ml4t/diagnostic/config/portfolio_config.py +215 -0
- ml4t/diagnostic/config/report_config.py +391 -0
- ml4t/diagnostic/config/sharpe_config.py +202 -0
- ml4t/diagnostic/config/signal_config.py +206 -0
- ml4t/diagnostic/config/trade_analysis_config.py +310 -0
- ml4t/diagnostic/config/validation.py +279 -0
- ml4t/diagnostic/core/__init__.py +29 -0
- ml4t/diagnostic/core/numba_utils.py +315 -0
- ml4t/diagnostic/core/purging.py +372 -0
- ml4t/diagnostic/core/sampling.py +471 -0
- ml4t/diagnostic/errors/__init__.py +205 -0
- ml4t/diagnostic/evaluation/AGENT.md +26 -0
- ml4t/diagnostic/evaluation/__init__.py +437 -0
- ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
- ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
- ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
- ml4t/diagnostic/evaluation/dashboard.py +715 -0
- ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
- ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
- ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
- ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
- ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
- ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
- ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
- ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
- ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
- ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
- ml4t/diagnostic/evaluation/event_analysis.py +647 -0
- ml4t/diagnostic/evaluation/excursion.py +390 -0
- ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
- ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
- ml4t/diagnostic/evaluation/framework.py +935 -0
- ml4t/diagnostic/evaluation/metric_registry.py +255 -0
- ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
- ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
- ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
- ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
- ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
- ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
- ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
- ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
- ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
- ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
- ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
- ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
- ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
- ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
- ml4t/diagnostic/evaluation/multi_signal.py +550 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
- ml4t/diagnostic/evaluation/report_generation.py +824 -0
- ml4t/diagnostic/evaluation/signal_selector.py +452 -0
- ml4t/diagnostic/evaluation/stat_registry.py +139 -0
- ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
- ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
- ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
- ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
- ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
- ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
- ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
- ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
- ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
- ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
- ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
- ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
- ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
- ml4t/diagnostic/evaluation/stats/moments.py +164 -0
- ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
- ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
- ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
- ml4t/diagnostic/evaluation/themes.py +330 -0
- ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
- ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
- ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
- ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
- ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
- ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
- ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
- ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
- ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
- ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
- ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
- ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
- ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
- ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
- ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
- ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
- ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
- ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
- ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
- ml4t/diagnostic/evaluation/validated_cv.py +535 -0
- ml4t/diagnostic/evaluation/visualization.py +1050 -0
- ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
- ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
- ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
- ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
- ml4t/diagnostic/integration/__init__.py +48 -0
- ml4t/diagnostic/integration/backtest_contract.py +671 -0
- ml4t/diagnostic/integration/data_contract.py +316 -0
- ml4t/diagnostic/integration/engineer_contract.py +226 -0
- ml4t/diagnostic/logging/__init__.py +77 -0
- ml4t/diagnostic/logging/logger.py +245 -0
- ml4t/diagnostic/logging/performance.py +234 -0
- ml4t/diagnostic/logging/progress.py +234 -0
- ml4t/diagnostic/logging/wandb.py +412 -0
- ml4t/diagnostic/metrics/__init__.py +9 -0
- ml4t/diagnostic/metrics/percentiles.py +128 -0
- ml4t/diagnostic/py.typed +1 -0
- ml4t/diagnostic/reporting/__init__.py +43 -0
- ml4t/diagnostic/reporting/base.py +130 -0
- ml4t/diagnostic/reporting/html_renderer.py +275 -0
- ml4t/diagnostic/reporting/json_renderer.py +51 -0
- ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
- ml4t/diagnostic/results/AGENT.md +24 -0
- ml4t/diagnostic/results/__init__.py +105 -0
- ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
- ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
- ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
- ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
- ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
- ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
- ml4t/diagnostic/results/barrier_results/validation.py +38 -0
- ml4t/diagnostic/results/base.py +177 -0
- ml4t/diagnostic/results/event_results.py +349 -0
- ml4t/diagnostic/results/feature_results.py +787 -0
- ml4t/diagnostic/results/multi_signal_results.py +431 -0
- ml4t/diagnostic/results/portfolio_results.py +281 -0
- ml4t/diagnostic/results/sharpe_results.py +448 -0
- ml4t/diagnostic/results/signal_results/__init__.py +74 -0
- ml4t/diagnostic/results/signal_results/ic.py +581 -0
- ml4t/diagnostic/results/signal_results/irtc.py +110 -0
- ml4t/diagnostic/results/signal_results/quantile.py +392 -0
- ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
- ml4t/diagnostic/results/signal_results/turnover.py +213 -0
- ml4t/diagnostic/results/signal_results/validation.py +147 -0
- ml4t/diagnostic/signal/AGENT.md +17 -0
- ml4t/diagnostic/signal/__init__.py +69 -0
- ml4t/diagnostic/signal/_report.py +152 -0
- ml4t/diagnostic/signal/_utils.py +261 -0
- ml4t/diagnostic/signal/core.py +275 -0
- ml4t/diagnostic/signal/quantile.py +148 -0
- ml4t/diagnostic/signal/result.py +214 -0
- ml4t/diagnostic/signal/signal_ic.py +129 -0
- ml4t/diagnostic/signal/turnover.py +182 -0
- ml4t/diagnostic/splitters/AGENT.md +19 -0
- ml4t/diagnostic/splitters/__init__.py +36 -0
- ml4t/diagnostic/splitters/base.py +501 -0
- ml4t/diagnostic/splitters/calendar.py +421 -0
- ml4t/diagnostic/splitters/calendar_config.py +91 -0
- ml4t/diagnostic/splitters/combinatorial.py +1064 -0
- ml4t/diagnostic/splitters/config.py +322 -0
- ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
- ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
- ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
- ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
- ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
- ml4t/diagnostic/splitters/group_isolation.py +329 -0
- ml4t/diagnostic/splitters/persistence.py +316 -0
- ml4t/diagnostic/splitters/utils.py +207 -0
- ml4t/diagnostic/splitters/walk_forward.py +757 -0
- ml4t/diagnostic/utils/__init__.py +42 -0
- ml4t/diagnostic/utils/config.py +542 -0
- ml4t/diagnostic/utils/dependencies.py +318 -0
- ml4t/diagnostic/utils/sessions.py +127 -0
- ml4t/diagnostic/validation/__init__.py +54 -0
- ml4t/diagnostic/validation/dataframe.py +274 -0
- ml4t/diagnostic/validation/returns.py +280 -0
- ml4t/diagnostic/validation/timeseries.py +299 -0
- ml4t/diagnostic/visualization/AGENT.md +19 -0
- ml4t/diagnostic/visualization/__init__.py +223 -0
- ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
- ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
- ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
- ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
- ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
- ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
- ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
- ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
- ml4t/diagnostic/visualization/barrier_plots.py +782 -0
- ml4t/diagnostic/visualization/core.py +1060 -0
- ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
- ml4t/diagnostic/visualization/dashboards/base.py +582 -0
- ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
- ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
- ml4t/diagnostic/visualization/dashboards.py +43 -0
- ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
- ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
- ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
- ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
- ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
- ml4t/diagnostic/visualization/feature_plots.py +888 -0
- ml4t/diagnostic/visualization/interaction_plots.py +618 -0
- ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
- ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
- ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
- ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
- ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
- ml4t/diagnostic/visualization/report_generation.py +1343 -0
- ml4t/diagnostic/visualization/signal/__init__.py +103 -0
- ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
- ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
- ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
- ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
- ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
- ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
- ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
- ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
- ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
- ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
- ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
- ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,421 @@
|
|
|
1
|
+
"""Calendar-aware time parsing for financial data cross-validation.
|
|
2
|
+
|
|
3
|
+
This module provides calendar-aware time period calculations for time-series CV,
|
|
4
|
+
ensuring that train/test splits respect trading calendar boundaries (sessions, weeks).
|
|
5
|
+
|
|
6
|
+
Key Features:
|
|
7
|
+
-----------
|
|
8
|
+
- Uses pandas_market_calendars for accurate trading session detection
|
|
9
|
+
- For intraday data: Sessions are atomic units (don't split trading sessions)
|
|
10
|
+
- For 'D' selections: Select complete trading sessions
|
|
11
|
+
- For 'W' selections: Select complete trading weeks (groups of sessions)
|
|
12
|
+
- Handles varying data density (dollar bars, trade bars) correctly
|
|
13
|
+
|
|
14
|
+
Background:
|
|
15
|
+
----------
|
|
16
|
+
Traditional time-based CV approaches use fixed sample counts computed from
|
|
17
|
+
time periods, which fails for activity-based data (dollar bars, trade bars) where
|
|
18
|
+
sample density varies with market activity. This module ensures proper time-based
|
|
19
|
+
selection by using calendar boundaries as atomic units.
|
|
20
|
+
|
|
21
|
+
Example Issue (Dollar Bars):
|
|
22
|
+
- High volatility week: 100K samples in 7 calendar days
|
|
23
|
+
- Low volatility week: 65K samples in 7 calendar days
|
|
24
|
+
- Fixed sample approach: 82K samples (the median week) = 0.82 to 1.26 weeks depending on activity (WRONG!)
|
|
25
|
+
- Calendar approach: Exactly 7 calendar days with varying samples (CORRECT!)
|
|
26
|
+
"""
|
|
27
|
+
|
|
28
|
+
from typing import Any, cast
|
|
29
|
+
|
|
30
|
+
import numpy as np
|
|
31
|
+
import pandas as pd
|
|
32
|
+
import pytz
|
|
33
|
+
|
|
34
|
+
try:
|
|
35
|
+
import pandas_market_calendars as mcal
|
|
36
|
+
|
|
37
|
+
HAS_MARKET_CALENDARS = True
|
|
38
|
+
except ImportError:
|
|
39
|
+
HAS_MARKET_CALENDARS = False
|
|
40
|
+
|
|
41
|
+
from ml4t.diagnostic.splitters.calendar_config import CalendarConfig
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class TradingCalendar:
    """Trading calendar for session-aware time period calculations.

    This class handles timezone normalization and trading-session detection
    for financial time-series cross-validation.

    Parameters
    ----------
    config : CalendarConfig or str
        Calendar configuration, or an exchange name. A bare exchange name is
        expanded to ``CalendarConfig(exchange=name, timezone="UTC",
        localize_naive=True)``.

    Attributes
    ----------
    config : CalendarConfig
        Configuration for calendar and timezone handling
    calendar : mcal.MarketCalendar
        The underlying market calendar instance
    tz : pytz.timezone
        Timezone object for conversions

    Raises
    ------
    ImportError
        If ``pandas_market_calendars`` is not installed.
    """

    def __init__(self, config: CalendarConfig | str = "CME_Equity"):
        """Initialize trading calendar with configuration."""
        if not HAS_MARKET_CALENDARS:
            raise ImportError(
                "pandas_market_calendars is required for calendar-aware CV. "
                "Install with: pip install pandas_market_calendars"
            )

        # A bare exchange name gets a default (UTC, localize naive data) config.
        # CalendarConfig is already imported at module level.
        if isinstance(config, str):
            config = CalendarConfig(exchange=config, timezone="UTC", localize_naive=True)

        self.config = config
        self.calendar = mcal.get_calendar(config.exchange)
        self.tz = pytz.timezone(config.timezone)

    def _ensure_timezone_aware(self, timestamps: pd.DatetimeIndex) -> pd.DatetimeIndex:
        """Return ``timestamps`` expressed in the calendar's timezone.

        Parameters
        ----------
        timestamps : pd.DatetimeIndex
            Input timestamps (may be tz-naive or tz-aware)

        Returns
        -------
        pd.DatetimeIndex
            Timezone-aware timestamps in the calendar's timezone. Naive input is
            localized (interpreted as wall-clock time in that timezone); aware
            input is converted.

        Raises
        ------
        ValueError
            If ``timestamps`` is tz-naive and ``config.localize_naive`` is False.
        """
        if timestamps.tz is None:
            if self.config.localize_naive:
                return timestamps.tz_localize(self.tz)
            raise ValueError(
                "Data is timezone-naive but localize_naive=False in config. "
                f"Either localize data to {self.config.timezone} or set "
                "localize_naive=True in CalendarConfig."
            )
        return timestamps.tz_convert(self.tz)

    def get_sessions(
        self,
        timestamps: pd.DatetimeIndex,
    ) -> pd.Series:
        """Assign each timestamp to its trading session date (vectorized).

        A timestamp belongs to the session whose ``[market_open, market_close)``
        window contains it. Timestamps falling outside every session window
        (overnight gaps, weekends, holidays) are assigned to the next session
        that opens after them.

        A trading session for futures typically runs from the prior evening
        (e.g. 5pm CT) to the afternoon of the session date; for stocks it is the
        standard trading day.

        Parameters
        ----------
        timestamps : pd.DatetimeIndex
            Timestamps to assign to sessions (may be tz-naive or tz-aware; need
            not be sorted).

        Returns
        -------
        pd.Series
            Session dates (labels from the calendar schedule's index) for each
            timestamp, in the same order as ``timestamps`` and indexed by it.
            Empty input yields an empty series.

        Raises
        ------
        ValueError
            If ``timestamps`` is tz-naive and ``config.localize_naive`` is False.
        """
        timestamps_tz = self._ensure_timezone_aware(timestamps)

        if len(timestamps_tz) == 0:
            # Nothing to assign (and min()/max() below would be NaT).
            return pd.Series(pd.DatetimeIndex([]), index=timestamps, name="session_date")

        # min()/max() rather than [0]/[-1] so unsorted input still gets a schedule
        # that covers every timestamp. The 7-day buffer leaves room for the
        # previous/next session around the data edges.
        start_date = timestamps_tz.min().normalize() - pd.Timedelta(days=7)
        end_date = timestamps_tz.max().normalize() + pd.Timedelta(days=7)

        schedule = self.calendar.schedule(start_date=start_date, end_date=end_date)

        # Express session boundaries in the calendar timezone so comparisons with
        # timestamps_tz are like-for-like.
        if schedule["market_open"].dt.tz is None:
            schedule["market_open"] = schedule["market_open"].dt.tz_localize(self.tz)
            schedule["market_close"] = schedule["market_close"].dt.tz_localize(self.tz)
        else:
            schedule["market_open"] = schedule["market_open"].dt.tz_convert(self.tz)
            schedule["market_close"] = schedule["market_close"].dt.tz_convert(self.tz)

        # original_idx lets us restore the caller's order after sorting.
        df_ts = pd.DataFrame(
            {"timestamp": timestamps_tz, "original_idx": range(len(timestamps_tz))}
        )
        df_sessions = pd.DataFrame(
            {
                "session_date": schedule.index,
                "market_open": schedule["market_open"],
                "market_close": schedule["market_close"],
            }
        ).reset_index(drop=True)

        # merge_asof requires both sides sorted on their keys.
        df_ts_sorted = df_ts.sort_values("timestamp")
        df_sessions_sorted = df_sessions.sort_values("market_open")

        # Latest session that opened at or before each timestamp.
        df_merged = pd.merge_asof(
            df_ts_sorted,
            df_sessions_sorted,
            left_on="timestamp",
            right_on="market_open",
            direction="backward",
        )

        # That session only counts if it has not closed yet. A timestamp before
        # the first open has a NaT close; the comparison is False, so it is also
        # treated as outside.
        within_session = df_merged["timestamp"] < df_merged["market_close"]

        if not within_session.all():
            # Out-of-session timestamps go to the next session that opens.
            df_outside = df_merged[~within_session][["timestamp", "original_idx"]]
            df_outside_merged = pd.merge_asof(
                df_outside,
                df_sessions_sorted,
                left_on="timestamp",
                right_on="market_open",
                direction="forward",
            )
            # merge_asof keeps the left frame's row order, so a positional
            # assignment lines up with the ~within_session rows.
            df_merged.loc[~within_session, "session_date"] = df_outside_merged[
                "session_date"
            ].values

        return df_merged.sort_values("original_idx").set_index(timestamps)["session_date"]

    def count_samples_in_period(
        self,
        timestamps: pd.DatetimeIndex,
        period_spec: str,
    ) -> list[int]:
        """Count samples in each complete ``period_spec`` block across the dataset.

        This is the basis for calendar-aware fold sizing: each returned value is
        the number of samples in one block of ``n`` consecutive periods.

        Parameters
        ----------
        timestamps : pd.DatetimeIndex
            Full dataset timestamps (may be tz-naive or tz-aware)
        period_spec : str
            ``<n><unit>`` with ``n >= 1`` and unit ``D``, ``W`` or ``M``
            (case-insensitive, surrounding whitespace ignored), e.g. ``'1D'``,
            ``'4W'``, ``'3M'``.

        Returns
        -------
        list[int]
            Sample counts for each complete block, in chronological order.
            A trailing partial block is dropped.

        Raises
        ------
        ValueError
            If ``period_spec`` is malformed or ``n < 1``, or if ``timestamps`` is
            tz-naive while ``config.localize_naive`` is False.

        Notes
        -----
        - Intraday data with 'D': blocks of ``n`` consecutive trading sessions.
        - Intraday data with 'W': blocks of ``5 * n`` consecutive trading sessions.
        - Daily data, or any 'M' spec: blocks of ``n`` consecutive calendar periods
          (days, Monday-start weeks, months) that contain data.
        """
        import re

        timestamps_tz = self._ensure_timezone_aware(timestamps)

        # fullmatch, not match: a prefix match let e.g. '30min' (upper-cased to
        # '30MIN') silently parse as 30 months.
        match = re.fullmatch(r"(\d+)([DWM])", period_spec.strip().upper())
        if not match:
            raise ValueError(
                f"Invalid period specification '{period_spec}'. Use format like '1D', '4W', '3M'"
            )

        n_periods = int(match.group(1))
        if n_periods < 1:
            raise ValueError(
                f"Invalid period specification '{period_spec}'. Period count must be at least 1"
            )
        freq = match.group(2)

        # Intraday data = at least one calendar day holds more than one sample.
        # Cast to Any: DatetimeIndex.normalize() is valid but stubs miss it.
        is_intraday = bool(cast(Any, timestamps_tz).normalize().has_duplicates)

        if is_intraday and freq in ("D", "W"):
            return self._count_samples_by_sessions(timestamps_tz, freq, n_periods)
        return self._count_samples_by_calendar(timestamps_tz, freq, n_periods)

    @staticmethod
    def _sum_complete_blocks(counts: np.ndarray, block_size: int) -> list[int]:
        """Sum consecutive runs of ``block_size`` counts, dropping a trailing partial run."""
        values = np.asarray(counts, dtype=np.int64)
        n_blocks = len(values) // block_size
        if n_blocks == 0:
            return []
        blocks = values[: n_blocks * block_size].reshape(n_blocks, block_size)
        return blocks.sum(axis=1).tolist()

    def _count_samples_by_sessions(
        self,
        timestamps: pd.DatetimeIndex,
        freq: str,
        n_periods: int,
    ) -> list[int]:
        """Count samples per block of consecutive trading sessions.

        For 'D': each block is ``n_periods`` sessions.
        For 'W': each block is ``5 * n_periods`` sessions (a fixed 5-session
        trading week; holiday-shortened weeks are not special-cased).
        Only complete blocks are returned.
        """
        if freq == "D":
            sessions_per_block = n_periods
        elif freq == "W":
            sessions_per_week = 5  # Standard trading week (Mon-Fri)
            sessions_per_block = sessions_per_week * n_periods
        else:
            raise ValueError(f"Unsupported frequency: {freq}")

        sessions = self.get_sessions(timestamps)

        # Samples per session in chronological order: one pass instead of a
        # full scan of ``sessions`` for every unique session.
        per_session = sessions.value_counts(sort=False).sort_index().to_numpy()

        return self._sum_complete_blocks(per_session, sessions_per_block)

    def _count_samples_by_calendar(
        self,
        timestamps: pd.DatetimeIndex,
        freq: str,
        n_periods: int,
    ) -> list[int]:
        """Count samples per block of ``n_periods`` calendar periods.

        Used for daily data and monthly specs. Periods are calendar days, weeks
        (labelled by their Monday start) or months. Only periods containing data
        are considered; a trailing partial block is dropped.
        """
        if freq == "D":
            period_groups = cast(Any, timestamps).normalize()
        elif freq == "W":
            period_groups = timestamps.to_period("W").to_timestamp()
        elif freq == "M":
            period_groups = timestamps.to_period("M").to_timestamp()
        else:
            raise ValueError(f"Unsupported frequency: {freq}")

        # groupby sorts by period, so counts are chronological.
        counts = pd.DataFrame({"period": period_groups}).groupby("period").size()

        return self._sum_complete_blocks(counts.to_numpy(), n_periods)
|
|
327
|
+
|
|
328
|
+
|
|
329
|
+
def parse_time_size_calendar_aware(
    size_spec: str,
    timestamps: pd.DatetimeIndex,
    calendar: TradingCalendar | None = None,
) -> int:
    """Parse time-based size specification using calendar-aware logic.

    This function replaces the naive sample-counting approach with proper
    calendar-based selection that respects trading session boundaries.

    Parameters
    ----------
    size_spec : str
        Time period specification (e.g., '4W', '1D', '3M')
    timestamps : pd.DatetimeIndex
        Timestamps from the dataset
    calendar : TradingCalendar, optional
        Trading calendar to use. If None, uses naive time-based calculation.

    Returns
    -------
    int
        Number of samples corresponding to the time period. With a calendar,
        this is the median sample count over all complete periods found
        (truncated to an int).

    Raises
    ------
    ValueError
        If ``size_spec`` cannot be parsed, or (with a calendar) if the data
        contains no complete period of the requested length.

    Notes
    -----
    Key difference from naive approach:
    - Naive: Scales the period length by the average sampling rate
    - Calendar-aware: Counts samples inside actual calendar periods/sessions

    For activity-based data (dollar bars, trade bars), the calendar-aware
    approach lets sample counts reflect market activity per period.

    Examples
    --------
    >>> # ~3 months of minute bars, enough for several complete 4-week blocks
    >>> timestamps = pd.date_range('2024-01-01', '2024-03-31', freq='1min')
    >>> calendar = TradingCalendar('CME_Equity')  # doctest: +SKIP
    >>> # Median number of samples in 4 trading weeks
    >>> n_samples = parse_time_size_calendar_aware('4W', timestamps, calendar)  # doctest: +SKIP
    """
    if calendar is None:
        # Fallback to naive time-based calculation
        return _parse_time_size_naive(size_spec, timestamps)

    sample_counts = calendar.count_samples_in_period(timestamps, size_spec)

    if not sample_counts:
        raise ValueError(
            f"Could not find any complete periods matching '{size_spec}' in the provided timestamps"
        )

    # The median is a representative size that is robust to the period-to-period
    # variability of activity-based data (dollar/trade bars).
    median_count = int(np.median(sample_counts))

    return median_count
|
|
386
|
+
|
|
387
|
+
|
|
388
|
+
def _parse_time_size_naive(
|
|
389
|
+
size_spec: str,
|
|
390
|
+
timestamps: pd.DatetimeIndex,
|
|
391
|
+
) -> int:
|
|
392
|
+
"""Naive time-based size calculation (fallback when no calendar provided).
|
|
393
|
+
|
|
394
|
+
This is the original ml4t-diagnostic logic - kept for backward compatibility.
|
|
395
|
+
"""
|
|
396
|
+
|
|
397
|
+
# Parse the time period
|
|
398
|
+
try:
|
|
399
|
+
time_delta = pd.Timedelta(size_spec)
|
|
400
|
+
except ValueError:
|
|
401
|
+
try:
|
|
402
|
+
offset = pd.tseries.frequencies.to_offset(size_spec)
|
|
403
|
+
ref_date = timestamps[0]
|
|
404
|
+
time_delta = (ref_date + offset) - ref_date
|
|
405
|
+
except Exception as e:
|
|
406
|
+
raise ValueError(
|
|
407
|
+
f"Invalid time specification '{size_spec}'. "
|
|
408
|
+
f"Use pandas offset aliases like '4W', '30D', '3M', '1Y'. "
|
|
409
|
+
f"Error: {e}"
|
|
410
|
+
) from e
|
|
411
|
+
|
|
412
|
+
# Simple proportion-based calculation
|
|
413
|
+
total_duration = timestamps[-1] - timestamps[0]
|
|
414
|
+
if total_duration.total_seconds() == 0:
|
|
415
|
+
raise ValueError("Cannot calculate time-based size for single-timestamp data")
|
|
416
|
+
|
|
417
|
+
n_samples = len(timestamps)
|
|
418
|
+
samples_per_second = n_samples / total_duration.total_seconds()
|
|
419
|
+
size_in_samples = int(samples_per_second * time_delta.total_seconds())
|
|
420
|
+
|
|
421
|
+
return size_in_samples
|
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
"""Configuration for calendar-aware cross-validation.
|
|
2
|
+
|
|
3
|
+
This module defines configuration schemas for trading calendar integration,
|
|
4
|
+
ensuring proper timezone handling and session awareness.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from pydantic import BaseModel, ConfigDict, Field
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class CalendarConfig(BaseModel):
    """Configuration for trading calendar in cross-validation.

    This configuration ensures proper handling of:
    - Trading sessions (don't split session boundaries)
    - Timezones (consistent tz-aware comparisons)
    - Market-specific calendars (CME, NYSE, LSE, etc.)

    Instances are immutable (frozen) after creation.

    Attributes
    ----------
    exchange : str
        Name of the exchange calendar from pandas_market_calendars.
        Examples: 'CME_Equity', 'NYSE', 'LSE', 'TSX', 'HKEX'
        See: https://pandas-market-calendars.readthedocs.io/

    timezone : str, default='UTC'
        Timezone for calendar operations. All timestamps will be converted
        to this timezone for calendar comparisons.
        - 'UTC': Coordinated Universal Time (default, safest)
        - 'America/New_York': US Eastern (NYSE, NASDAQ)
        - 'America/Chicago': US Central (CME futures)
        - 'Europe/London': UK (LSE)
        - See pytz documentation for full list

    localize_naive : bool, default=True
        If True, tz-naive data will be localized to the specified timezone
        (i.e. assumed to already be wall-clock time in that timezone).
        If False, tz-naive data will raise an error.

    Examples
    --------
    For CME futures (NQ, ES, etc.):
    >>> config = CalendarConfig(
    ...     exchange='CME_Equity',
    ...     timezone='America/Chicago'
    ... )

    For US equities:
    >>> config = CalendarConfig(
    ...     exchange='NYSE',
    ...     timezone='America/New_York'
    ... )

    For international markets:
    >>> config = CalendarConfig(
    ...     exchange='LSE',
    ...     timezone='Europe/London'
    ... )
    """

    # Pydantic v2 configuration; replaces the deprecated class-based ``Config``.
    model_config = ConfigDict(frozen=True)  # Immutable after creation

    exchange: str = Field(..., description="Exchange calendar name from pandas_market_calendars")

    timezone: str = Field(
        default="UTC", description="Timezone for calendar operations (pytz timezone name)"
    )

    localize_naive: bool = Field(
        default=True, description="Whether to localize tz-naive data to the specified timezone"
    )

    def __repr__(self) -> str:
        """String representation."""
        return (
            f"CalendarConfig(exchange='{self.exchange}', "
            f"timezone='{self.timezone}', "
            f"localize_naive={self.localize_naive})"
        )
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
# Ready-made configurations for common markets. Each preset pairs an exchange
# calendar with that exchange's home timezone and localizes tz-naive data to it.
CME_CONFIG = CalendarConfig(
    localize_naive=True,
    exchange="CME_Equity",
    timezone="America/Chicago",
)

NYSE_CONFIG = CalendarConfig(
    localize_naive=True,
    exchange="NYSE",
    timezone="America/New_York",
)

NASDAQ_CONFIG = CalendarConfig(
    localize_naive=True,
    exchange="NASDAQ",
    timezone="America/New_York",
)

LSE_CONFIG = CalendarConfig(
    localize_naive=True,
    exchange="LSE",
    timezone="Europe/London",
)
|