ml4t-diagnostic 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (242) hide show
  1. ml4t/diagnostic/AGENT.md +25 -0
  2. ml4t/diagnostic/__init__.py +166 -0
  3. ml4t/diagnostic/backends/__init__.py +10 -0
  4. ml4t/diagnostic/backends/adapter.py +192 -0
  5. ml4t/diagnostic/backends/polars_backend.py +899 -0
  6. ml4t/diagnostic/caching/__init__.py +40 -0
  7. ml4t/diagnostic/caching/cache.py +331 -0
  8. ml4t/diagnostic/caching/decorators.py +131 -0
  9. ml4t/diagnostic/caching/smart_cache.py +339 -0
  10. ml4t/diagnostic/config/AGENT.md +24 -0
  11. ml4t/diagnostic/config/README.md +267 -0
  12. ml4t/diagnostic/config/__init__.py +219 -0
  13. ml4t/diagnostic/config/barrier_config.py +277 -0
  14. ml4t/diagnostic/config/base.py +301 -0
  15. ml4t/diagnostic/config/event_config.py +148 -0
  16. ml4t/diagnostic/config/feature_config.py +404 -0
  17. ml4t/diagnostic/config/multi_signal_config.py +55 -0
  18. ml4t/diagnostic/config/portfolio_config.py +215 -0
  19. ml4t/diagnostic/config/report_config.py +391 -0
  20. ml4t/diagnostic/config/sharpe_config.py +202 -0
  21. ml4t/diagnostic/config/signal_config.py +206 -0
  22. ml4t/diagnostic/config/trade_analysis_config.py +310 -0
  23. ml4t/diagnostic/config/validation.py +279 -0
  24. ml4t/diagnostic/core/__init__.py +29 -0
  25. ml4t/diagnostic/core/numba_utils.py +315 -0
  26. ml4t/diagnostic/core/purging.py +372 -0
  27. ml4t/diagnostic/core/sampling.py +471 -0
  28. ml4t/diagnostic/errors/__init__.py +205 -0
  29. ml4t/diagnostic/evaluation/AGENT.md +26 -0
  30. ml4t/diagnostic/evaluation/__init__.py +437 -0
  31. ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
  32. ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
  33. ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
  34. ml4t/diagnostic/evaluation/dashboard.py +715 -0
  35. ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
  36. ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
  37. ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
  38. ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
  39. ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
  40. ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
  41. ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
  42. ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
  43. ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
  44. ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
  45. ml4t/diagnostic/evaluation/event_analysis.py +647 -0
  46. ml4t/diagnostic/evaluation/excursion.py +390 -0
  47. ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
  48. ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
  49. ml4t/diagnostic/evaluation/framework.py +935 -0
  50. ml4t/diagnostic/evaluation/metric_registry.py +255 -0
  51. ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
  52. ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
  53. ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
  54. ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
  55. ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
  56. ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
  57. ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
  58. ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
  59. ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
  60. ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
  61. ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
  62. ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
  63. ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
  64. ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
  65. ml4t/diagnostic/evaluation/multi_signal.py +550 -0
  66. ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
  67. ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
  68. ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
  69. ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
  70. ml4t/diagnostic/evaluation/report_generation.py +824 -0
  71. ml4t/diagnostic/evaluation/signal_selector.py +452 -0
  72. ml4t/diagnostic/evaluation/stat_registry.py +139 -0
  73. ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
  74. ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
  75. ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
  76. ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
  77. ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
  78. ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
  79. ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
  80. ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
  81. ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
  82. ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
  83. ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
  84. ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
  85. ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
  86. ml4t/diagnostic/evaluation/stats/moments.py +164 -0
  87. ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
  88. ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
  89. ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
  90. ml4t/diagnostic/evaluation/themes.py +330 -0
  91. ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
  92. ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
  93. ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
  94. ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
  95. ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
  96. ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
  97. ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
  98. ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
  99. ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
  100. ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
  101. ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
  102. ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
  103. ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
  104. ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
  105. ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
  106. ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
  107. ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
  108. ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
  109. ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
  110. ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
  111. ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
  112. ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
  113. ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
  114. ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
  115. ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
  116. ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
  117. ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
  118. ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
  119. ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
  120. ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
  121. ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
  122. ml4t/diagnostic/evaluation/validated_cv.py +535 -0
  123. ml4t/diagnostic/evaluation/visualization.py +1050 -0
  124. ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
  125. ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
  126. ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
  127. ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
  128. ml4t/diagnostic/integration/__init__.py +48 -0
  129. ml4t/diagnostic/integration/backtest_contract.py +671 -0
  130. ml4t/diagnostic/integration/data_contract.py +316 -0
  131. ml4t/diagnostic/integration/engineer_contract.py +226 -0
  132. ml4t/diagnostic/logging/__init__.py +77 -0
  133. ml4t/diagnostic/logging/logger.py +245 -0
  134. ml4t/diagnostic/logging/performance.py +234 -0
  135. ml4t/diagnostic/logging/progress.py +234 -0
  136. ml4t/diagnostic/logging/wandb.py +412 -0
  137. ml4t/diagnostic/metrics/__init__.py +9 -0
  138. ml4t/diagnostic/metrics/percentiles.py +128 -0
  139. ml4t/diagnostic/py.typed +1 -0
  140. ml4t/diagnostic/reporting/__init__.py +43 -0
  141. ml4t/diagnostic/reporting/base.py +130 -0
  142. ml4t/diagnostic/reporting/html_renderer.py +275 -0
  143. ml4t/diagnostic/reporting/json_renderer.py +51 -0
  144. ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
  145. ml4t/diagnostic/results/AGENT.md +24 -0
  146. ml4t/diagnostic/results/__init__.py +105 -0
  147. ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
  148. ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
  149. ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
  150. ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
  151. ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
  152. ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
  153. ml4t/diagnostic/results/barrier_results/validation.py +38 -0
  154. ml4t/diagnostic/results/base.py +177 -0
  155. ml4t/diagnostic/results/event_results.py +349 -0
  156. ml4t/diagnostic/results/feature_results.py +787 -0
  157. ml4t/diagnostic/results/multi_signal_results.py +431 -0
  158. ml4t/diagnostic/results/portfolio_results.py +281 -0
  159. ml4t/diagnostic/results/sharpe_results.py +448 -0
  160. ml4t/diagnostic/results/signal_results/__init__.py +74 -0
  161. ml4t/diagnostic/results/signal_results/ic.py +581 -0
  162. ml4t/diagnostic/results/signal_results/irtc.py +110 -0
  163. ml4t/diagnostic/results/signal_results/quantile.py +392 -0
  164. ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
  165. ml4t/diagnostic/results/signal_results/turnover.py +213 -0
  166. ml4t/diagnostic/results/signal_results/validation.py +147 -0
  167. ml4t/diagnostic/signal/AGENT.md +17 -0
  168. ml4t/diagnostic/signal/__init__.py +69 -0
  169. ml4t/diagnostic/signal/_report.py +152 -0
  170. ml4t/diagnostic/signal/_utils.py +261 -0
  171. ml4t/diagnostic/signal/core.py +275 -0
  172. ml4t/diagnostic/signal/quantile.py +148 -0
  173. ml4t/diagnostic/signal/result.py +214 -0
  174. ml4t/diagnostic/signal/signal_ic.py +129 -0
  175. ml4t/diagnostic/signal/turnover.py +182 -0
  176. ml4t/diagnostic/splitters/AGENT.md +19 -0
  177. ml4t/diagnostic/splitters/__init__.py +36 -0
  178. ml4t/diagnostic/splitters/base.py +501 -0
  179. ml4t/diagnostic/splitters/calendar.py +421 -0
  180. ml4t/diagnostic/splitters/calendar_config.py +91 -0
  181. ml4t/diagnostic/splitters/combinatorial.py +1064 -0
  182. ml4t/diagnostic/splitters/config.py +322 -0
  183. ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
  184. ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
  185. ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
  186. ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
  187. ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
  188. ml4t/diagnostic/splitters/group_isolation.py +329 -0
  189. ml4t/diagnostic/splitters/persistence.py +316 -0
  190. ml4t/diagnostic/splitters/utils.py +207 -0
  191. ml4t/diagnostic/splitters/walk_forward.py +757 -0
  192. ml4t/diagnostic/utils/__init__.py +42 -0
  193. ml4t/diagnostic/utils/config.py +542 -0
  194. ml4t/diagnostic/utils/dependencies.py +318 -0
  195. ml4t/diagnostic/utils/sessions.py +127 -0
  196. ml4t/diagnostic/validation/__init__.py +54 -0
  197. ml4t/diagnostic/validation/dataframe.py +274 -0
  198. ml4t/diagnostic/validation/returns.py +280 -0
  199. ml4t/diagnostic/validation/timeseries.py +299 -0
  200. ml4t/diagnostic/visualization/AGENT.md +19 -0
  201. ml4t/diagnostic/visualization/__init__.py +223 -0
  202. ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
  203. ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
  204. ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
  205. ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
  206. ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
  207. ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
  208. ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
  209. ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
  210. ml4t/diagnostic/visualization/barrier_plots.py +782 -0
  211. ml4t/diagnostic/visualization/core.py +1060 -0
  212. ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
  213. ml4t/diagnostic/visualization/dashboards/base.py +582 -0
  214. ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
  215. ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
  216. ml4t/diagnostic/visualization/dashboards.py +43 -0
  217. ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
  218. ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
  219. ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
  220. ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
  221. ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
  222. ml4t/diagnostic/visualization/feature_plots.py +888 -0
  223. ml4t/diagnostic/visualization/interaction_plots.py +618 -0
  224. ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
  225. ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
  226. ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
  227. ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
  228. ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
  229. ml4t/diagnostic/visualization/report_generation.py +1343 -0
  230. ml4t/diagnostic/visualization/signal/__init__.py +103 -0
  231. ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
  232. ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
  233. ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
  234. ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
  235. ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
  236. ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
  237. ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
  238. ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
  239. ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
  240. ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
  241. ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
  242. ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,471 @@
1
+ """Stratified and subsampling logic for financial time-series.
2
+
3
+ This module provides sampling strategies that preserve important
4
+ characteristics of financial data while reducing computational load
5
+ or balancing classes.
6
+ """
7
+
8
+ from typing import TYPE_CHECKING, Any
9
+
10
+ import numpy as np
11
+ import pandas as pd
12
+ import polars as pl
13
+
14
+ if TYPE_CHECKING:
15
+ from numpy.typing import NDArray
16
+
17
+
18
def block_bootstrap(
    indices: "NDArray[np.intp]",
    n_samples: int,
    sample_length: int | None = None,
    random_state: int | None = None,
) -> "NDArray[np.intp]":
    """Draw a block bootstrap sample of indices for time-series data.

    Arguments are validated and defaulted here; the actual draw is delegated
    to ``block_bootstrap_numba`` from ``ml4t.diagnostic.core.numba_utils``.

    Parameters
    ----------
    indices : np.ndarray
        Array of indices to sample from.
    n_samples : int
        Number of bootstrap samples to generate.
    sample_length : int, optional
        Length of each sequential block. If None, 10% of ``len(indices)``
        is used, with a minimum of 1.
    random_state : int, optional
        Random seed for reproducibility. If None, a seed is drawn from
        NumPy's global random state.

    Returns
    -------
    np.ndarray
        Bootstrap sample indices, as returned by ``block_bootstrap_numba``.

    Raises
    ------
    ValueError
        If ``n_samples <= 0``, if ``indices`` is empty, or if an explicit
        ``sample_length`` is not positive.

    Examples
    --------
    >>> indices = np.arange(100)
    >>> bootstrap_idx = block_bootstrap(indices, n_samples=80, sample_length=5)
    >>> len(bootstrap_idx)
    80
    """
    # Imported lazily to avoid a circular dependency at module import time.
    from ml4t.diagnostic.core.numba_utils import block_bootstrap_numba

    if n_samples <= 0:
        raise ValueError(f"n_samples must be positive, got {n_samples}")

    population = len(indices)
    if population == 0:
        raise ValueError("indices array cannot be empty")

    if sample_length is None:
        # Default block length: a tenth of the data, but never below one.
        block_len = max(1, population // 10)
    elif sample_length <= 0:
        raise ValueError(f"sample_length must be positive, got {sample_length}")
    else:
        block_len = sample_length

    # The compiled routine needs a concrete integer seed.
    seed = np.random.randint(0, 2**31 - 1) if random_state is None else random_state

    return block_bootstrap_numba(indices, n_samples, block_len, seed)
83
+
84
+
85
def _sample_block_positions(
    n_rows: int,
    n_target: int,
    sample_frac: float,
    rng: np.random.RandomState,
) -> list[int]:
    """Choose positional row indices by keeping whole contiguous blocks at random.

    The rows ``0 .. n_rows - 1`` are cut into consecutive blocks and each full
    block is kept with probability ``sample_frac``. The kept positions are
    truncated to at most ``n_target``, so the result may hold fewer than
    ``n_target`` positions (possibly none).
    """
    block_size = max(1, n_rows // (n_target // 10 + 1))
    positions: list[int] = []

    # Only full blocks are visited; a trailing partial block is never sampled.
    for start in range(0, n_rows - block_size + 1, block_size):
        if rng.random() < sample_frac:
            positions.extend(range(start, start + block_size))

    return positions[:n_target]


def stratified_sample_time_series(
    data: pd.DataFrame | pl.DataFrame,
    stratify_column: str,
    sample_frac: float = 0.5,
    time_column: str | None = None,
    preserve_order: bool = True,
    random_state: int | None = None,
) -> pd.DataFrame | pl.DataFrame:
    """Stratified sampling that preserves time series properties.

    Each distinct value of ``stratify_column`` forms a stratum, from which up
    to ``int(len(stratum) * sample_frac)`` rows are drawn.

    Parameters
    ----------
    data : pd.DataFrame or pl.DataFrame
        Input data to sample from.
    stratify_column : str
        Column to use for stratification.
    sample_frac : float
        Fraction of data to sample from each stratum.
    time_column : str, optional
        Time column; when given together with ``preserve_order`` the final
        result is sorted by it.
    preserve_order : bool
        Whether to sample contiguous blocks of rows instead of individual
        rows. For pandas input, blocks are used whenever this is True. For
        Polars input, blocks are used only when ``time_column`` is also
        given; otherwise rows are drawn at random and kept in their
        original order.
    random_state : int, optional
        Random seed for reproducibility.

    Returns
    -------
    pd.DataFrame or pl.DataFrame
        Stratified sample preserving the input type. Block sampling is
        probabilistic, so a stratum may contribute fewer rows than its
        target. If nothing is selected, an empty frame with the input's
        columns is returned.

    Raises
    ------
    TypeError
        If ``data`` is neither a pandas nor a Polars DataFrame.

    Examples
    --------
    >>> df = pd.DataFrame({
    ...     'time': pd.date_range('2020-01-01', periods=1000),
    ...     'label': np.random.choice([-1, 0, 1], 1000),
    ...     'feature': np.random.randn(1000)
    ... })
    >>> sampled = stratified_sample_time_series(
    ...     df, stratify_column='label', sample_frac=0.3
    ... )
    """
    if not isinstance(data, (pl.DataFrame, pd.DataFrame)):
        raise TypeError(f"data must be pd.DataFrame or pl.DataFrame, got {type(data)}")

    rng = np.random.RandomState(random_state)

    if isinstance(data, pl.DataFrame):
        pl_parts: list[pl.DataFrame] = []

        for value in data[stratify_column].unique().to_list():
            stratum_pl = data.filter(pl.col(stratify_column) == value)
            n_rows_pl = len(stratum_pl)
            n_target_pl = int(n_rows_pl * sample_frac)
            if n_target_pl <= 0:
                continue

            if preserve_order and time_column:
                positions = _sample_block_positions(n_rows_pl, n_target_pl, sample_frac, rng)
            else:
                # Random rows, returned in their original relative order.
                positions = np.sort(
                    rng.choice(n_rows_pl, n_target_pl, replace=False)
                ).tolist()

            # Skip empty selections: indexing with an empty list is ambiguous
            # and contributes nothing to the result.
            if positions:
                pl_parts.append(stratum_pl[positions])

        if not pl_parts:
            # pl.concat rejects an empty list; return an empty frame instead.
            return data.head(0)

        result_pl = pl.concat(pl_parts)
        if time_column and preserve_order:
            result_pl = result_pl.sort(time_column)
        return result_pl

    # Pandas implementation (data is narrowed to pd.DataFrame here).
    pd_parts: list[pd.DataFrame] = []

    for val in data[stratify_column].unique():
        stratum_pd: pd.DataFrame = data[data[stratify_column] == val]
        n_rows_pd = len(stratum_pd)
        n_target_pd = int(n_rows_pd * sample_frac)
        if n_target_pd <= 0:
            continue

        if preserve_order:
            block_positions = _sample_block_positions(
                n_rows_pd, n_target_pd, sample_frac, rng
            )
            if not block_positions:
                continue
            # Positional selection: label-based .loc would duplicate rows
            # when the index contains repeated labels.
            chosen = stratum_pd.iloc[block_positions]
        else:
            chosen = stratum_pd.sample(n=n_target_pd, random_state=random_state)

        pd_parts.append(chosen)

    if not pd_parts:
        # pd.concat raises on an empty list; return an empty frame instead.
        return data.iloc[0:0]

    result_pd = pd.concat(pd_parts)
    if time_column and preserve_order:
        result_pd = result_pd.sort_values(time_column)
    return result_pd
207
+
208
+
209
def sample_weights_by_importance(
    returns: "NDArray[Any]",
    method: str = "return_magnitude",
    decay_factor: float = 0.94,
) -> "NDArray[Any]":
    """Calculate sampling weights based on importance criteria.

    Parameters
    ----------
    returns : np.ndarray
        Array of returns or outcomes.
    method : str
        Method for calculating importance weights:
        - 'return_magnitude': Weight by absolute return size
        - 'recency': Exponential decay weights (latest observation heaviest)
        - 'volatility': Weight by 20-period rolling standard deviation
    decay_factor : float
        Decay factor for recency weighting. Validated for every method.

    Returns
    -------
    np.ndarray
        Non-negative sampling weights that sum to 1. Equal weights are used
        whenever the chosen criterion carries no information (all-zero
        returns, a single observation, or no defined volatility).

    Raises
    ------
    ValueError
        If returns is empty, method is unknown, or decay_factor is not
        strictly between 0 and 1.

    Examples
    --------
    >>> returns = np.random.randn(100) * 0.02
    >>> weights = sample_weights_by_importance(returns, method='recency')
    >>> bool(np.isclose(weights.sum(), 1.0))
    True
    """
    if len(returns) == 0:
        raise ValueError("returns array cannot be empty")

    if not 0 < decay_factor < 1:
        raise ValueError(f"decay_factor must be in (0, 1), got {decay_factor}")

    valid_methods = ["return_magnitude", "recency", "volatility"]
    if method not in valid_methods:
        raise ValueError(f"method must be one of {valid_methods}, got '{method}'")

    n_samples = len(returns)

    if method == "return_magnitude":
        weights = np.abs(returns)

        # All-zero returns carry no ranking information: weight equally.
        if np.sum(weights) == 0:
            weights = np.ones(n_samples)

        weights = weights / weights.sum()

    elif method == "recency":
        # Exponent is 0 for the last observation, so it gets the top weight.
        time_weights = decay_factor ** np.arange(n_samples - 1, -1, -1)
        weights = time_weights / time_weights.sum()

    else:  # method == "volatility"
        if n_samples < 2:
            # A standard deviation needs at least two observations.
            weights = np.ones(n_samples) / n_samples
        else:
            volatility = pd.Series(returns).rolling(20, min_periods=1).std().to_numpy()
            usable = np.isfinite(volatility)

            if not usable.any() or float(volatility[usable].sum()) == 0:
                weights = np.ones(n_samples)
            else:
                # The first rolling std is always NaN. Fill undefined values
                # with the mean defined volatility so they get a neutral
                # weight; a constant such as 1.0 would dwarf typical return
                # volatilities and dominate the distribution.
                weights = np.where(usable, volatility, volatility[usable].mean())

            weights = weights / weights.sum()

    # Final safety net: make sure the weights are valid probabilities.
    weights = np.nan_to_num(weights, nan=1 / n_samples, posinf=1 / n_samples, neginf=0)

    weights_sum = weights.sum()
    if weights_sum <= 0:
        # Nothing usable survived: fall back to equal weights.
        weights = np.ones(n_samples) / n_samples
    else:
        weights = weights / weights_sum

    return weights
305
+
306
+
307
+ def balanced_subsample(
308
+ X: "NDArray[Any]",
309
+ y: "NDArray[Any]",
310
+ minority_weight: float = 1.0,
311
+ method: str = "undersample",
312
+ random_state: int | None = None,
313
+ ) -> tuple["NDArray[Any]", "NDArray[Any]"]:
314
+ """Balance classes through strategic subsampling.
315
+
316
+ Parameters
317
+ ----------
318
+ X : np.ndarray
319
+ Feature matrix
320
+ y : np.ndarray
321
+ Labels (assumed to be -1, 0, 1 for financial ML)
322
+ minority_weight : float
323
+ Weight given to minority class preservation
324
+ method : str
325
+ Balancing method:
326
+ - 'undersample': Undersample majority class
327
+ - 'hybrid': Combination of under and oversampling
328
+ random_state : int, optional
329
+ Random seed
330
+
331
+ Returns:
332
+ -------
333
+ X_balanced : np.ndarray
334
+ Balanced feature matrix
335
+ y_balanced : np.ndarray
336
+ Balanced labels
337
+ """
338
+ rng = np.random.RandomState(random_state)
339
+
340
+ # Get class counts
341
+ unique_labels, counts = np.unique(y, return_counts=True)
342
+ min_count = counts.min()
343
+ counts.max()
344
+
345
+ if method == "undersample":
346
+ # Undersample to match minority class
347
+ balanced_indices: list[int] = []
348
+
349
+ for label in unique_labels:
350
+ label_indices = np.where(y == label)[0]
351
+
352
+ if len(label_indices) > min_count:
353
+ # Undersample this class
354
+ if label == 0: # Neutral class in financial ML
355
+ # More aggressive undersampling for neutral class
356
+ n_sample = int(min_count * (2 - minority_weight))
357
+ else:
358
+ n_sample = min_count
359
+
360
+ sampled = rng.choice(label_indices, n_sample, replace=False)
361
+ else:
362
+ # Keep all minority samples
363
+ sampled = label_indices
364
+
365
+ balanced_indices.extend(sampled)
366
+
367
+ elif method == "hybrid":
368
+ # Combination approach
369
+ balanced_indices = []
370
+ target_count = int(min_count * (1 + minority_weight))
371
+
372
+ for label in unique_labels:
373
+ label_indices = np.where(y == label)[0]
374
+
375
+ if len(label_indices) > target_count:
376
+ # Undersample
377
+ sampled = rng.choice(label_indices, target_count, replace=False)
378
+ elif len(label_indices) < target_count:
379
+ # Oversample with replacement
380
+ sampled = rng.choice(label_indices, target_count, replace=True)
381
+ else:
382
+ sampled = label_indices
383
+
384
+ balanced_indices.extend(sampled)
385
+
386
+ else:
387
+ raise ValueError(f"Unknown method: {method}")
388
+
389
+ # Shuffle the indices
390
+ balanced_arr: NDArray[np.intp] = np.array(balanced_indices, dtype=np.intp)
391
+ rng.shuffle(balanced_arr)
392
+
393
+ return X[balanced_arr], y[balanced_arr]
394
+
395
+
396
def event_based_sample(
    data: pd.DataFrame | pl.DataFrame,
    event_column: str,
    n_samples: int | None = None,
    sample_frac: float | None = None,
    min_event_spacing: int | None = None,
    random_state: int | None = None,
) -> pd.DataFrame | pl.DataFrame:
    """Randomly pick event rows, optionally keeping them a minimum distance apart.

    Rows whose ``event_column`` is truthy are the candidate events. Events are
    drawn one at a time; when ``min_event_spacing`` is given, every candidate
    within that many row positions of a drawn event is discarded before the
    next draw.

    Parameters
    ----------
    data : pd.DataFrame or pl.DataFrame
        Input data.
    event_column : str
        Column indicating events (boolean or binary).
    n_samples : int, optional
        Number of events to sample. Takes precedence over ``sample_frac``.
    sample_frac : float, optional
        Fraction of events to sample, used when ``n_samples`` is None.
    min_event_spacing : int, optional
        Minimum spacing, in row positions, between sampled events.
    random_state : int, optional
        Random seed.

    Returns
    -------
    pd.DataFrame or pl.DataFrame
        Rows of ``data`` at the sampled event positions, in ascending
        position order. Fewer rows than requested are returned when the
        candidates run out.

    Raises
    ------
    ValueError
        If neither ``n_samples`` nor ``sample_frac`` is given.
    TypeError
        If ``data`` is neither a pandas nor a Polars DataFrame.
    """
    if n_samples is None and sample_frac is None:
        raise ValueError("Either n_samples or sample_frac must be specified")

    rng = np.random.RandomState(random_state)

    # Build a boolean event mask; the isinstance checks also narrow the type.
    if isinstance(data, pl.DataFrame):
        flags = data[event_column].cast(bool).to_numpy()
    elif isinstance(data, pd.DataFrame):
        flags = data[event_column].astype(bool).to_numpy()
    else:
        raise TypeError(f"data must be pd.DataFrame or pl.DataFrame, got {type(data)}")

    event_positions = np.nonzero(flags)[0]

    if n_samples is None:
        if sample_frac is None:
            # Already excluded by the first check; kept for type narrowing.
            raise ValueError("Either n_samples or sample_frac must be provided")
        n_samples = int(len(event_positions) * sample_frac)

    picked: list[int] = []
    pool = list(event_positions)

    while pool and len(picked) < n_samples:
        slot = rng.choice(len(pool))
        chosen = pool[slot]
        picked.append(chosen)

        if min_event_spacing is None:
            # No spacing rule: only the drawn event leaves the pool.
            pool.pop(slot)
        else:
            # Drop the drawn event and every candidate too close to it.
            pool = [pos for pos in pool if abs(pos - chosen) > min_event_spacing]

    ordered = sorted(picked)
    if isinstance(data, pl.DataFrame):
        return data[ordered]
    return data.iloc[ordered]
@@ -0,0 +1,205 @@
1
+ """
2
+ ML4T Diagnostic Error Handling Framework
3
+
4
+ Provides a comprehensive exception hierarchy for systematic error handling
5
+ across the ML4T Diagnostic library. All exceptions preserve context information and
6
+ provide actionable error messages.
7
+
8
+ Exception Hierarchy:
9
+ QEvalError (base)
10
+ ├── ConfigurationError # Configuration and setup errors
11
+ ├── ValidationError # Data validation failures
12
+ ├── ComputationError # Calculation and numerical errors
13
+ ├── DataError # Data access and format errors
14
+ └── IntegrationError # External library integration errors
15
+
16
+ Example:
17
+ >>> from ml4t.diagnostic.errors import ValidationError
18
+ >>> try:
19
+ ... validate_returns(returns)
20
+ ... except ValidationError as e:
21
+ ... print(f"Validation failed: {e}")
22
+ ... print(f"Context: {e.context}")
23
+ """
24
+
25
+ from typing import Any
26
+
27
+
28
+ class QEvalError(Exception):
29
+ """
30
+ Base exception for all ML4T Diagnostic errors.
31
+
32
+ All ML4T Diagnostic exceptions inherit from this base class, providing
33
+ consistent error handling and context preservation.
34
+
35
+ Attributes:
36
+ message: Human-readable error description
37
+ context: Additional error context (dict)
38
+ cause: Original exception if error was wrapped
39
+
40
+ Example:
41
+ >>> raise QEvalError(
42
+ ... "Operation failed",
43
+ ... context={"operation": "compute_sharpe", "reason": "insufficient_data"}
44
+ ... )
45
+ """
46
+
47
+ def __init__(
48
+ self,
49
+ message: str,
50
+ context: dict[str, Any] | None = None,
51
+ cause: Exception | None = None,
52
+ ):
53
+ """
54
+ Initialize ML4T Diagnostic error.
55
+
56
+ Args:
57
+ message: Error description
58
+ context: Additional error context
59
+ cause: Original exception (for error chaining)
60
+ """
61
+ super().__init__(message)
62
+ self.message = message
63
+ self.context = context or {}
64
+ self.cause = cause
65
+
66
+ def __str__(self) -> str:
67
+ """Format error message with context."""
68
+ parts = [self.message]
69
+
70
+ if self.context:
71
+ parts.append("\nContext:")
72
+ for key, value in self.context.items():
73
+ parts.append(f" {key}: {value}")
74
+
75
+ if self.cause:
76
+ parts.append(f"\nCaused by: {type(self.cause).__name__}: {self.cause}")
77
+
78
+ return "".join(parts)
79
+
80
+ def __repr__(self) -> str:
81
+ """Detailed representation."""
82
+ return f"{self.__class__.__name__}(message={self.message!r}, context={self.context!r}, cause={self.cause!r})"
83
+
84
+
85
class ConfigurationError(QEvalError):
    """
    Error in configuration or setup.

    Typical triggers:
        - A configuration value is invalid
        - A required configuration entry is missing
        - Settings are incompatible with one another
        - Setup or initialization fails

    Example:
        >>> from ml4t.diagnostic.config import QEvalConfig
        >>> try:
        ...     config = QEvalConfig(n_splits=-1)  # Invalid
        ... except ConfigurationError as e:
        ...     print(f"Configuration error: {e}")
    """
104
+
105
+
106
class ValidationError(QEvalError):
    """
    Error raised when data fails validation.

    Typical triggers:
        - Required columns are missing
        - Data types do not match what is expected
        - A value constraint is violated
        - Schema validation fails

    Note:
        Not to be confused with the ValidationError in
        ml4t.diagnostic.validation. The validation module uses this
        exception type for all validation failures.

    Example:
        >>> from ml4t.diagnostic.validation import validate_returns
        >>> try:
        ...     validate_returns(invalid_returns)
        ... except ValidationError as e:
        ...     print(f"Validation failed: {e}")
        ...     print(f"Details: {e.context}")
    """
130
+
131
+
132
class ComputationError(QEvalError):
    """
    Error in a calculation or numerical routine.

    Typical triggers:
        - Numerical instability (division by zero, overflow)
        - Too little data to perform the calculation
        - An algorithm fails to converge
        - A mathematical operation is invalid

    Example:
        >>> from ml4t.diagnostic.metrics import sharpe_ratio
        >>> try:
        ...     sr = sharpe_ratio([])  # Empty data
        ... except ComputationError as e:
        ...     print(f"Computation failed: {e}")
    """
151
+
152
+
153
class DataError(QEvalError):
    """
    Error in data access or data format.

    Typical triggers:
        - Data cannot be loaded
        - Data arrives in an unexpected format
        - Expected data is missing
        - Data is corrupted

    Example:
        >>> from ml4t.diagnostic.integration.qfeatures import load_features
        >>> try:
        ...     features = load_features("missing_file.parquet")
        ... except DataError as e:
        ...     print(f"Data error: {e}")
    """
172
+
173
+
174
class IntegrationError(QEvalError):
    """
    Error while integrating with an external library.

    Typical triggers:
        - QFeatures integration fails
        - QEngine integration fails
        - An external API reports an error
        - Library versions are incompatible

    Example:
        >>> from ml4t.diagnostic.integration.qfeatures import FeaturesAdapter
        >>> try:
        ...     adapter = FeaturesAdapter()
        ...     features = adapter.load("data.parquet")
        ... except IntegrationError as e:
        ...     print(f"Integration error: {e}")
        ...     print(f"Library: {e.context.get('library')}")
    """
195
+
196
+
197
+ # Public API
198
+ __all__ = [
199
+ "QEvalError",
200
+ "ConfigurationError",
201
+ "ValidationError",
202
+ "ComputationError",
203
+ "DataError",
204
+ "IntegrationError",
205
+ ]
@@ -0,0 +1,26 @@
1
+ # evaluation/ - Analysis Framework
2
+
3
+ ## Subdirectories
4
+
5
+ | Directory | Purpose |
6
+ |-----------|---------|
7
+ | [stats/](stats/AGENT.md) | DSR, RAS, FDR, HAC |
8
+ | [metrics/](metrics/AGENT.md) | IC, importance, interactions |
9
+ | distribution/ | Moments, tails, tests |
10
+ | drift/ | PSI, Wasserstein |
11
+ | stationarity/ | ADF, KPSS, PP |
12
+
13
+ ## Key Modules
14
+
15
+ | File | Lines | Purpose |
16
+ |------|-------|---------|
17
+ | framework.py | 935 | `Evaluator` class |
18
+ | validated_cv.py | ~200 | `ValidatedCrossValidation` |
19
+ | barrier_analysis.py | 1050 | `BarrierAnalysis` |
20
+ | binary_metrics.py | 910 | Classification metrics |
21
+ | trade_analysis.py | 1078 | Trade-level analysis |
22
+ | autocorrelation.py | 531 | ACF/PACF |
23
+
24
+ ## Key Classes
25
+
26
+ `Evaluator`, `ValidatedCrossValidation`, `BarrierAnalysis`, `FeatureDiagnostics`