ml4t-diagnostic 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (242)
  1. ml4t/diagnostic/AGENT.md +25 -0
  2. ml4t/diagnostic/__init__.py +166 -0
  3. ml4t/diagnostic/backends/__init__.py +10 -0
  4. ml4t/diagnostic/backends/adapter.py +192 -0
  5. ml4t/diagnostic/backends/polars_backend.py +899 -0
  6. ml4t/diagnostic/caching/__init__.py +40 -0
  7. ml4t/diagnostic/caching/cache.py +331 -0
  8. ml4t/diagnostic/caching/decorators.py +131 -0
  9. ml4t/diagnostic/caching/smart_cache.py +339 -0
  10. ml4t/diagnostic/config/AGENT.md +24 -0
  11. ml4t/diagnostic/config/README.md +267 -0
  12. ml4t/diagnostic/config/__init__.py +219 -0
  13. ml4t/diagnostic/config/barrier_config.py +277 -0
  14. ml4t/diagnostic/config/base.py +301 -0
  15. ml4t/diagnostic/config/event_config.py +148 -0
  16. ml4t/diagnostic/config/feature_config.py +404 -0
  17. ml4t/diagnostic/config/multi_signal_config.py +55 -0
  18. ml4t/diagnostic/config/portfolio_config.py +215 -0
  19. ml4t/diagnostic/config/report_config.py +391 -0
  20. ml4t/diagnostic/config/sharpe_config.py +202 -0
  21. ml4t/diagnostic/config/signal_config.py +206 -0
  22. ml4t/diagnostic/config/trade_analysis_config.py +310 -0
  23. ml4t/diagnostic/config/validation.py +279 -0
  24. ml4t/diagnostic/core/__init__.py +29 -0
  25. ml4t/diagnostic/core/numba_utils.py +315 -0
  26. ml4t/diagnostic/core/purging.py +372 -0
  27. ml4t/diagnostic/core/sampling.py +471 -0
  28. ml4t/diagnostic/errors/__init__.py +205 -0
  29. ml4t/diagnostic/evaluation/AGENT.md +26 -0
  30. ml4t/diagnostic/evaluation/__init__.py +437 -0
  31. ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
  32. ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
  33. ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
  34. ml4t/diagnostic/evaluation/dashboard.py +715 -0
  35. ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
  36. ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
  37. ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
  38. ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
  39. ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
  40. ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
  41. ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
  42. ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
  43. ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
  44. ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
  45. ml4t/diagnostic/evaluation/event_analysis.py +647 -0
  46. ml4t/diagnostic/evaluation/excursion.py +390 -0
  47. ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
  48. ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
  49. ml4t/diagnostic/evaluation/framework.py +935 -0
  50. ml4t/diagnostic/evaluation/metric_registry.py +255 -0
  51. ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
  52. ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
  53. ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
  54. ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
  55. ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
  56. ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
  57. ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
  58. ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
  59. ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
  60. ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
  61. ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
  62. ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
  63. ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
  64. ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
  65. ml4t/diagnostic/evaluation/multi_signal.py +550 -0
  66. ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
  67. ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
  68. ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
  69. ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
  70. ml4t/diagnostic/evaluation/report_generation.py +824 -0
  71. ml4t/diagnostic/evaluation/signal_selector.py +452 -0
  72. ml4t/diagnostic/evaluation/stat_registry.py +139 -0
  73. ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
  74. ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
  75. ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
  76. ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
  77. ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
  78. ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
  79. ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
  80. ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
  81. ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
  82. ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
  83. ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
  84. ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
  85. ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
  86. ml4t/diagnostic/evaluation/stats/moments.py +164 -0
  87. ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
  88. ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
  89. ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
  90. ml4t/diagnostic/evaluation/themes.py +330 -0
  91. ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
  92. ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
  93. ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
  94. ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
  95. ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
  96. ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
  97. ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
  98. ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
  99. ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
  100. ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
  101. ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
  102. ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
  103. ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
  104. ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
  105. ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
  106. ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
  107. ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
  108. ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
  109. ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
  110. ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
  111. ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
  112. ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
  113. ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
  114. ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
  115. ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
  116. ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
  117. ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
  118. ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
  119. ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
  120. ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
  121. ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
  122. ml4t/diagnostic/evaluation/validated_cv.py +535 -0
  123. ml4t/diagnostic/evaluation/visualization.py +1050 -0
  124. ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
  125. ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
  126. ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
  127. ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
  128. ml4t/diagnostic/integration/__init__.py +48 -0
  129. ml4t/diagnostic/integration/backtest_contract.py +671 -0
  130. ml4t/diagnostic/integration/data_contract.py +316 -0
  131. ml4t/diagnostic/integration/engineer_contract.py +226 -0
  132. ml4t/diagnostic/logging/__init__.py +77 -0
  133. ml4t/diagnostic/logging/logger.py +245 -0
  134. ml4t/diagnostic/logging/performance.py +234 -0
  135. ml4t/diagnostic/logging/progress.py +234 -0
  136. ml4t/diagnostic/logging/wandb.py +412 -0
  137. ml4t/diagnostic/metrics/__init__.py +9 -0
  138. ml4t/diagnostic/metrics/percentiles.py +128 -0
  139. ml4t/diagnostic/py.typed +1 -0
  140. ml4t/diagnostic/reporting/__init__.py +43 -0
  141. ml4t/diagnostic/reporting/base.py +130 -0
  142. ml4t/diagnostic/reporting/html_renderer.py +275 -0
  143. ml4t/diagnostic/reporting/json_renderer.py +51 -0
  144. ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
  145. ml4t/diagnostic/results/AGENT.md +24 -0
  146. ml4t/diagnostic/results/__init__.py +105 -0
  147. ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
  148. ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
  149. ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
  150. ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
  151. ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
  152. ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
  153. ml4t/diagnostic/results/barrier_results/validation.py +38 -0
  154. ml4t/diagnostic/results/base.py +177 -0
  155. ml4t/diagnostic/results/event_results.py +349 -0
  156. ml4t/diagnostic/results/feature_results.py +787 -0
  157. ml4t/diagnostic/results/multi_signal_results.py +431 -0
  158. ml4t/diagnostic/results/portfolio_results.py +281 -0
  159. ml4t/diagnostic/results/sharpe_results.py +448 -0
  160. ml4t/diagnostic/results/signal_results/__init__.py +74 -0
  161. ml4t/diagnostic/results/signal_results/ic.py +581 -0
  162. ml4t/diagnostic/results/signal_results/irtc.py +110 -0
  163. ml4t/diagnostic/results/signal_results/quantile.py +392 -0
  164. ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
  165. ml4t/diagnostic/results/signal_results/turnover.py +213 -0
  166. ml4t/diagnostic/results/signal_results/validation.py +147 -0
  167. ml4t/diagnostic/signal/AGENT.md +17 -0
  168. ml4t/diagnostic/signal/__init__.py +69 -0
  169. ml4t/diagnostic/signal/_report.py +152 -0
  170. ml4t/diagnostic/signal/_utils.py +261 -0
  171. ml4t/diagnostic/signal/core.py +275 -0
  172. ml4t/diagnostic/signal/quantile.py +148 -0
  173. ml4t/diagnostic/signal/result.py +214 -0
  174. ml4t/diagnostic/signal/signal_ic.py +129 -0
  175. ml4t/diagnostic/signal/turnover.py +182 -0
  176. ml4t/diagnostic/splitters/AGENT.md +19 -0
  177. ml4t/diagnostic/splitters/__init__.py +36 -0
  178. ml4t/diagnostic/splitters/base.py +501 -0
  179. ml4t/diagnostic/splitters/calendar.py +421 -0
  180. ml4t/diagnostic/splitters/calendar_config.py +91 -0
  181. ml4t/diagnostic/splitters/combinatorial.py +1064 -0
  182. ml4t/diagnostic/splitters/config.py +322 -0
  183. ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
  184. ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
  185. ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
  186. ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
  187. ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
  188. ml4t/diagnostic/splitters/group_isolation.py +329 -0
  189. ml4t/diagnostic/splitters/persistence.py +316 -0
  190. ml4t/diagnostic/splitters/utils.py +207 -0
  191. ml4t/diagnostic/splitters/walk_forward.py +757 -0
  192. ml4t/diagnostic/utils/__init__.py +42 -0
  193. ml4t/diagnostic/utils/config.py +542 -0
  194. ml4t/diagnostic/utils/dependencies.py +318 -0
  195. ml4t/diagnostic/utils/sessions.py +127 -0
  196. ml4t/diagnostic/validation/__init__.py +54 -0
  197. ml4t/diagnostic/validation/dataframe.py +274 -0
  198. ml4t/diagnostic/validation/returns.py +280 -0
  199. ml4t/diagnostic/validation/timeseries.py +299 -0
  200. ml4t/diagnostic/visualization/AGENT.md +19 -0
  201. ml4t/diagnostic/visualization/__init__.py +223 -0
  202. ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
  203. ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
  204. ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
  205. ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
  206. ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
  207. ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
  208. ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
  209. ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
  210. ml4t/diagnostic/visualization/barrier_plots.py +782 -0
  211. ml4t/diagnostic/visualization/core.py +1060 -0
  212. ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
  213. ml4t/diagnostic/visualization/dashboards/base.py +582 -0
  214. ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
  215. ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
  216. ml4t/diagnostic/visualization/dashboards.py +43 -0
  217. ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
  218. ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
  219. ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
  220. ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
  221. ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
  222. ml4t/diagnostic/visualization/feature_plots.py +888 -0
  223. ml4t/diagnostic/visualization/interaction_plots.py +618 -0
  224. ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
  225. ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
  226. ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
  227. ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
  228. ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
  229. ml4t/diagnostic/visualization/report_generation.py +1343 -0
  230. ml4t/diagnostic/visualization/signal/__init__.py +103 -0
  231. ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
  232. ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
  233. ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
  234. ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
  235. ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
  236. ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
  237. ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
  238. ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
  239. ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
  240. ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
  241. ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
  242. ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,421 @@
1
+ """Calendar-aware time parsing for financial data cross-validation.
2
+
3
+ This module provides calendar-aware time period calculations for time-series CV,
4
+ ensuring that train/test splits respect trading calendar boundaries (sessions, weeks).
5
+
6
+ Key Features:
7
+ -----------
8
+ - Uses pandas_market_calendars for accurate trading session detection
9
+ - For intraday data: Sessions are atomic units (don't split trading sessions)
10
+ - For 'D' selections: Select complete trading sessions
11
+ - For 'W' selections: Select complete trading weeks (groups of sessions)
12
+ - Handles varying data density (dollar bars, trade bars) correctly
13
+
14
+ Background:
15
+ ----------
16
+ Traditional time-based CV approaches use fixed sample counts computed from
17
+ time periods, which fails for activity-based data (dollar bars, trade bars) where
18
+ sample density varies with market activity. This module ensures proper time-based
19
+ selection by using calendar boundaries as atomic units.
20
+
21
+ Example Issue (Dollar Bars):
22
+ - High volatility week: 100K samples in 7 calendar days
23
+ - Low volatility week: 65K samples in 7 calendar days
24
+ - Fixed sample approach: 82K samples = 0.82 to 1.26 weeks (WRONG!)
25
+ - Calendar approach: Exactly 7 calendar days with varying samples (CORRECT!)
26
+ """
27
+
28
+ from typing import Any, cast
29
+
30
+ import numpy as np
31
+ import pandas as pd
32
+ import pytz
33
+
34
+ try:
35
+ import pandas_market_calendars as mcal
36
+
37
+ HAS_MARKET_CALENDARS = True
38
+ except ImportError:
39
+ HAS_MARKET_CALENDARS = False
40
+
41
+ from ml4t.diagnostic.splitters.calendar_config import CalendarConfig
42
+
43
+
44
class TradingCalendar:
    """Trading calendar for session-aware time period calculations.

    This class handles timezone conversion and trading session detection
    for financial time-series cross-validation.

    Parameters
    ----------
    config : CalendarConfig or str
        Calendar configuration, or an exchange name. A bare exchange name is
        wrapped in ``CalendarConfig(exchange=name, timezone="UTC",
        localize_naive=True)``.

    Attributes
    ----------
    config : CalendarConfig
        Configuration for calendar and timezone handling
    calendar : mcal.MarketCalendar
        The underlying market calendar instance
    tz : pytz.timezone
        Timezone object for conversions
    """

    # Fixed number of sessions treated as one trading week for 'W' specs on
    # intraday data. Holiday-shortened weeks are not special-cased, so block
    # boundaries can drift relative to calendar weeks.
    _SESSIONS_PER_WEEK = 5

    def __init__(self, config: "CalendarConfig | str" = "CME_Equity"):
        """Initialize trading calendar with configuration.

        Raises
        ------
        ImportError
            If ``pandas_market_calendars`` is not installed.
        """
        if not HAS_MARKET_CALENDARS:
            raise ImportError(
                "pandas_market_calendars is required for calendar-aware CV. "
                "Install with: pip install pandas_market_calendars"
            )

        # A bare exchange name gets a default config (UTC, localize naive data).
        if isinstance(config, str):
            config = CalendarConfig(exchange=config, timezone="UTC", localize_naive=True)

        self.config = config
        self.calendar = mcal.get_calendar(config.exchange)
        self.tz = pytz.timezone(config.timezone)

    def _ensure_timezone_aware(self, timestamps: pd.DatetimeIndex) -> pd.DatetimeIndex:
        """Return ``timestamps`` expressed in the calendar timezone.

        Parameters
        ----------
        timestamps : pd.DatetimeIndex
            Input timestamps (may be tz-naive or tz-aware)

        Returns
        -------
        pd.DatetimeIndex
            Timezone-aware timestamps in calendar's timezone. Tz-naive input is
            localized (interpreted as wall time in the calendar timezone);
            tz-aware input is converted.

        Raises
        ------
        ValueError
            If the input is tz-naive and ``config.localize_naive`` is False.
        """
        if timestamps.tz is not None:
            return timestamps.tz_convert(self.tz)
        if not self.config.localize_naive:
            raise ValueError(
                "Data is timezone-naive but localize_naive=False in config. "
                f"Either localize data to {self.config.timezone} or set "
                "localize_naive=True in CalendarConfig."
            )
        return timestamps.tz_localize(self.tz)

    def _to_calendar_tz(self, values: pd.Series) -> pd.Series:
        """Localize (if tz-naive) or convert (if tz-aware) a datetime Series to ``self.tz``."""
        if values.dt.tz is None:
            return values.dt.tz_localize(self.tz)
        return values.dt.tz_convert(self.tz)

    def get_sessions(
        self,
        timestamps: pd.DatetimeIndex,
    ) -> pd.Series:
        """Assign each timestamp to its trading session date (vectorized).

        A timestamp inside ``[market_open, market_close)`` of a scheduled
        session is assigned to that session. A timestamp outside every session
        (after a close, over a weekend, ...) is assigned to the next session
        that opens at or after it.

        Parameters
        ----------
        timestamps : pd.DatetimeIndex
            Timestamps to assign to sessions (may be tz-naive or tz-aware and
            need not be sorted)

        Returns
        -------
        pd.Series
            Session dates taken from the calendar schedule's index, named
            ``"session_date"`` and indexed by ``timestamps`` in their original
            order. Timestamps with no matching session get NaT.
        """
        timestamps_tz = self._ensure_timezone_aware(timestamps)

        if len(timestamps_tz) == 0:
            return pd.Series(pd.DatetimeIndex([]), index=timestamps, name="session_date")

        # Pad the schedule window by a week on each side so sessions adjacent to
        # the data range are available. min/max rather than [0]/[-1] because the
        # input is not required to be sorted.
        start_date = timestamps_tz.min().normalize() - pd.Timedelta(days=7)
        end_date = timestamps_tz.max().normalize() + pd.Timedelta(days=7)
        schedule = self.calendar.schedule(start_date=start_date, end_date=end_date)

        # Session boundaries in the calendar timezone (schedule is not mutated).
        df_sessions = (
            pd.DataFrame(
                {
                    "session_date": schedule.index,
                    "market_open": self._to_calendar_tz(schedule["market_open"]),
                    "market_close": self._to_calendar_tz(schedule["market_close"]),
                }
            )
            .reset_index(drop=True)
            .sort_values("market_open")
        )

        # merge_asof needs the left side sorted; original_idx restores order later.
        df_ts = pd.DataFrame(
            {"timestamp": timestamps_tz, "original_idx": np.arange(len(timestamps_tz))}
        ).sort_values("timestamp")

        # Pass 1: latest session that opened at or before each timestamp.
        df_merged = pd.merge_asof(
            df_ts,
            df_sessions,
            left_on="timestamp",
            right_on="market_open",
            direction="backward",
        )

        # At/after that session's close (or no prior open: market_close is NaT and
        # the comparison is False) means the timestamp is outside any session.
        outside = ~(df_merged["timestamp"] < df_merged["market_close"])
        if outside.any():
            # Pass 2: assign outside timestamps to the next session to open.
            df_forward = pd.merge_asof(
                df_merged.loc[outside, ["timestamp", "original_idx"]],
                df_sessions,
                left_on="timestamp",
                right_on="market_open",
                direction="forward",
            )
            df_merged.loc[outside, "session_date"] = df_forward["session_date"].to_numpy()

        # Restore the caller's original order and index.
        result = df_merged.sort_values("original_idx").set_index(timestamps)["session_date"]
        return result

    def count_samples_in_period(
        self,
        timestamps: pd.DatetimeIndex,
        period_spec: str,
    ) -> list[int]:
        """Count samples in consecutive complete periods of length ``period_spec``.

        The data is split into consecutive blocks of ``N`` base periods
        (sessions/weeks/months) and the number of samples in each complete block
        is returned, providing the basis for calendar-aware fold creation.

        Parameters
        ----------
        timestamps : pd.DatetimeIndex
            Full dataset timestamps (may be tz-naive or tz-aware)
        period_spec : str
            Period specification ``<N><D|W|M>`` with ``N >= 1`` (case-insensitive,
            e.g. '1D', '4W', '3M')

        Returns
        -------
        list[int]
            Sample counts for each complete block, in chronological order.
            A trailing block with fewer than ``N`` base periods is dropped.

        Raises
        ------
        ValueError
            If ``period_spec`` is not of the form ``<N><D|W|M>`` or ``N < 1``.

        Notes
        -----
        For intraday data with 'D' spec: blocks of N trading sessions.
        For intraday data with 'W' spec: blocks of 5*N trading sessions.
        Otherwise (daily data or 'M' spec): blocks of N calendar days/weeks/months
        that contain at least one sample.
        """
        import re

        timestamps_tz = self._ensure_timezone_aware(timestamps)

        # fullmatch: a prefix match would read e.g. '1min' as '1M' (one month).
        match = re.fullmatch(r"(\d+)([DWM])", period_spec.strip().upper())
        if not match:
            raise ValueError(
                f"Invalid period specification '{period_spec}'. Use format like '1D', '4W', '3M'"
            )

        n_periods = int(match.group(1))
        if n_periods < 1:
            raise ValueError(
                f"Invalid period specification '{period_spec}': the period count must be >= 1"
            )
        freq = match.group(2)

        # Intraday data has more than one sample on at least one calendar day.
        # Cast to Any for DatetimeIndex.normalize() which is valid but type stubs don't recognize
        is_intraday = cast(Any, timestamps_tz).normalize().duplicated().any()

        if is_intraday and freq in ("D", "W"):
            return self._count_samples_by_sessions(timestamps_tz, freq, n_periods)
        return self._count_samples_by_calendar(timestamps_tz, freq, n_periods)

    @staticmethod
    def _sum_complete_chunks(counts: np.ndarray, chunk_size: int) -> list[int]:
        """Sum consecutive runs of ``chunk_size`` counts, dropping a trailing partial run."""
        values = [int(c) for c in counts]
        return [
            sum(values[start : start + chunk_size])
            for start in range(0, len(values) - chunk_size + 1, chunk_size)
        ]

    def _count_samples_by_sessions(
        self,
        timestamps: pd.DatetimeIndex,
        freq: str,
        n_periods: int,
    ) -> list[int]:
        """Count samples in blocks of trading sessions.

        For 'D': each block is ``n_periods`` consecutive sessions.
        For 'W': each block is ``5 * n_periods`` consecutive sessions.
        Timestamps that could not be assigned to a session (NaT) are ignored.
        """
        sessions = self.get_sessions(timestamps)

        # Samples per session in chronological order (value_counts drops NaT).
        per_session = sessions.value_counts().sort_index()

        if freq == "D":
            sessions_per_block = n_periods
        elif freq == "W":
            sessions_per_block = self._SESSIONS_PER_WEEK * n_periods
        else:
            return []

        return self._sum_complete_chunks(per_session.to_numpy(), sessions_per_block)

    def _count_samples_by_calendar(
        self,
        timestamps: pd.DatetimeIndex,
        freq: str,
        _n_periods: int,
    ) -> list[int]:
        """Count samples in blocks of ``_n_periods`` calendar periods.

        Used for daily data or monthly specs. Periods are labelled by local wall
        time in ``timestamps``' timezone; only periods that contain samples are
        counted.
        """
        # Drop tz explicitly (keeping local wall time) so to_period does not warn
        # about discarding timezone information; the period labels are identical.
        local = timestamps.tz_localize(None) if timestamps.tz is not None else timestamps

        if freq == "D":
            period_groups = cast(Any, local).normalize()
        elif freq == "W":
            period_groups = local.to_period("W").to_timestamp()
        elif freq == "M":
            period_groups = local.to_period("M").to_timestamp()
        else:
            raise ValueError(f"Unsupported frequency: {freq}")

        # groupby sorts keys, so counts are chronological.
        counts = pd.DataFrame({"period": period_groups}).groupby("period").size()

        return self._sum_complete_chunks(counts.to_numpy(), _n_periods)
327
+
328
+
329
def parse_time_size_calendar_aware(
    size_spec: str,
    timestamps: pd.DatetimeIndex,
    calendar: "TradingCalendar | None" = None,
) -> int:
    """Convert a time-period spec into a representative sample count.

    With a trading calendar, the dataset is split into consecutive complete
    periods of the requested length (via ``calendar.count_samples_in_period``)
    and the median sample count across those periods is returned. Without a
    calendar, a naive proportional estimate is used instead.

    Parameters
    ----------
    size_spec : str
        Time period specification (e.g., '4W', '1D', '3M')
    timestamps : pd.DatetimeIndex
        Timestamps from the dataset
    calendar : TradingCalendar, optional
        Trading calendar to use. If None, uses naive time-based calculation.

    Returns
    -------
    int
        Median number of samples per complete period (truncated to int), or the
        naive proportional estimate when ``calendar`` is None.

    Raises
    ------
    ValueError
        If ``size_spec`` cannot be parsed, or if the data contains no complete
        period of the requested length (e.g. it spans less than ``size_spec``).

    Notes
    -----
    Key difference from naive approach:
    - Naive: assumes samples are spread uniformly over the data's time span
    - Calendar-aware: measures sample counts in actual calendar periods and
      uses their median

    For activity-based data (dollar bars, trade bars), the calendar-aware
    approach reflects how sample counts vary with market activity.

    Examples
    --------
    The data must span at least one complete period of the requested length;
    e.g. 10,000 one-minute bars (about one week) contain no complete '4W'
    period and would raise ``ValueError``.

    >>> timestamps = pd.date_range('2024-01-01', periods=60 * 24 * 60, freq='1min')  # doctest: +SKIP
    >>> calendar = TradingCalendar('CME_Equity')  # doctest: +SKIP
    >>> # Median sample count over the complete 4-week blocks in ~60 days of data
    >>> n_samples = parse_time_size_calendar_aware('4W', timestamps, calendar)  # doctest: +SKIP
    """
    if calendar is None:
        # Fallback to naive time-based calculation
        return _parse_time_size_naive(size_spec, timestamps)

    sample_counts = calendar.count_samples_in_period(timestamps, size_spec)

    if not sample_counts:
        raise ValueError(
            f"Could not find any complete periods matching '{size_spec}' in the provided timestamps"
        )

    # The median is robust to unusually busy or quiet periods in
    # activity-based data (dollar/trade bars).
    return int(np.median(sample_counts))
386
+
387
+
388
+ def _parse_time_size_naive(
389
+ size_spec: str,
390
+ timestamps: pd.DatetimeIndex,
391
+ ) -> int:
392
+ """Naive time-based size calculation (fallback when no calendar provided).
393
+
394
+ This is the original ml4t-diagnostic logic - kept for backward compatibility.
395
+ """
396
+
397
+ # Parse the time period
398
+ try:
399
+ time_delta = pd.Timedelta(size_spec)
400
+ except ValueError:
401
+ try:
402
+ offset = pd.tseries.frequencies.to_offset(size_spec)
403
+ ref_date = timestamps[0]
404
+ time_delta = (ref_date + offset) - ref_date
405
+ except Exception as e:
406
+ raise ValueError(
407
+ f"Invalid time specification '{size_spec}'. "
408
+ f"Use pandas offset aliases like '4W', '30D', '3M', '1Y'. "
409
+ f"Error: {e}"
410
+ ) from e
411
+
412
+ # Simple proportion-based calculation
413
+ total_duration = timestamps[-1] - timestamps[0]
414
+ if total_duration.total_seconds() == 0:
415
+ raise ValueError("Cannot calculate time-based size for single-timestamp data")
416
+
417
+ n_samples = len(timestamps)
418
+ samples_per_second = n_samples / total_duration.total_seconds()
419
+ size_in_samples = int(samples_per_second * time_delta.total_seconds())
420
+
421
+ return size_in_samples
@@ -0,0 +1,91 @@
1
+ """Configuration for calendar-aware cross-validation.
2
+
3
+ This module defines configuration schemas for trading calendar integration,
4
+ ensuring proper timezone handling and session awareness.
5
+ """
6
+
7
+ from pydantic import BaseModel, Field
8
+
9
+
10
class CalendarConfig(BaseModel):
    """Trading-calendar settings used by calendar-aware cross-validation.

    The settings control:
    - Which exchange schedule defines trading sessions (so splits do not cut
      through a session)
    - Which timezone all timestamps are compared in (consistent tz-aware math)
    - How timezone-naive input data is treated

    Attributes
    ----------
    exchange : str
        Exchange calendar name understood by pandas_market_calendars,
        e.g. 'CME_Equity', 'NYSE', 'LSE', 'TSX', 'HKEX'.
        See: https://pandas-market-calendars.readthedocs.io/

    timezone : str, default='UTC'
        Timezone name used for calendar comparisons; timestamps are converted
        into it. Typical values:
        - 'UTC': Coordinated Universal Time (default)
        - 'America/New_York': US Eastern (NYSE, NASDAQ)
        - 'America/Chicago': US Central (CME futures)
        - 'Europe/London': UK (LSE)
        - See the pytz documentation for the full list

    localize_naive : bool, default=True
        When True, timezone-naive data is interpreted as wall time in
        ``timezone``. When False, timezone-naive data is rejected with an error.

    Notes
    -----
    Instances are immutable (pydantic ``frozen`` config).

    Examples
    --------
    CME futures (NQ, ES, ...):
    >>> config = CalendarConfig(
    ...     exchange='CME_Equity',
    ...     timezone='America/Chicago'
    ... )

    US equities:
    >>> config = CalendarConfig(
    ...     exchange='NYSE',
    ...     timezone='America/New_York'
    ... )

    Other markets:
    >>> config = CalendarConfig(
    ...     exchange='LSE',
    ...     timezone='Europe/London'
    ... )
    """

    exchange: str = Field(
        ...,
        description="Exchange calendar name from pandas_market_calendars",
    )

    timezone: str = Field(
        "UTC",
        description="Timezone for calendar operations (pytz timezone name)",
    )

    localize_naive: bool = Field(
        True,
        description="Whether to localize tz-naive data to the specified timezone",
    )

    class Config:
        """Pydantic configuration."""

        # Immutable after creation
        frozen = True

    def __repr__(self) -> str:
        """Return a constructor-style representation of the configuration."""
        fields = (
            f"exchange='{self.exchange}'",
            f"timezone='{self.timezone}'",
            f"localize_naive={self.localize_naive}",
        )
        return "CalendarConfig(" + ", ".join(fields) + ")"
82
+
83
+
84
# Ready-made configurations for common markets. Each uses the exchange's local
# timezone and localizes tz-naive data into it.
CME_CONFIG = CalendarConfig(
    timezone="America/Chicago",
    exchange="CME_Equity",
    localize_naive=True,
)

NYSE_CONFIG = CalendarConfig(
    timezone="America/New_York",
    exchange="NYSE",
    localize_naive=True,
)

NASDAQ_CONFIG = CalendarConfig(
    timezone="America/New_York",
    exchange="NASDAQ",
    localize_naive=True,
)

LSE_CONFIG = CalendarConfig(
    timezone="Europe/London",
    exchange="LSE",
    localize_naive=True,
)