ml4t_diagnostic-0.1.0a1-py3-none-any.whl

Files changed (242)
  1. ml4t/diagnostic/AGENT.md +25 -0
  2. ml4t/diagnostic/__init__.py +166 -0
  3. ml4t/diagnostic/backends/__init__.py +10 -0
  4. ml4t/diagnostic/backends/adapter.py +192 -0
  5. ml4t/diagnostic/backends/polars_backend.py +899 -0
  6. ml4t/diagnostic/caching/__init__.py +40 -0
  7. ml4t/diagnostic/caching/cache.py +331 -0
  8. ml4t/diagnostic/caching/decorators.py +131 -0
  9. ml4t/diagnostic/caching/smart_cache.py +339 -0
  10. ml4t/diagnostic/config/AGENT.md +24 -0
  11. ml4t/diagnostic/config/README.md +267 -0
  12. ml4t/diagnostic/config/__init__.py +219 -0
  13. ml4t/diagnostic/config/barrier_config.py +277 -0
  14. ml4t/diagnostic/config/base.py +301 -0
  15. ml4t/diagnostic/config/event_config.py +148 -0
  16. ml4t/diagnostic/config/feature_config.py +404 -0
  17. ml4t/diagnostic/config/multi_signal_config.py +55 -0
  18. ml4t/diagnostic/config/portfolio_config.py +215 -0
  19. ml4t/diagnostic/config/report_config.py +391 -0
  20. ml4t/diagnostic/config/sharpe_config.py +202 -0
  21. ml4t/diagnostic/config/signal_config.py +206 -0
  22. ml4t/diagnostic/config/trade_analysis_config.py +310 -0
  23. ml4t/diagnostic/config/validation.py +279 -0
  24. ml4t/diagnostic/core/__init__.py +29 -0
  25. ml4t/diagnostic/core/numba_utils.py +315 -0
  26. ml4t/diagnostic/core/purging.py +372 -0
  27. ml4t/diagnostic/core/sampling.py +471 -0
  28. ml4t/diagnostic/errors/__init__.py +205 -0
  29. ml4t/diagnostic/evaluation/AGENT.md +26 -0
  30. ml4t/diagnostic/evaluation/__init__.py +437 -0
  31. ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
  32. ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
  33. ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
  34. ml4t/diagnostic/evaluation/dashboard.py +715 -0
  35. ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
  36. ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
  37. ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
  38. ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
  39. ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
  40. ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
  41. ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
  42. ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
  43. ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
  44. ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
  45. ml4t/diagnostic/evaluation/event_analysis.py +647 -0
  46. ml4t/diagnostic/evaluation/excursion.py +390 -0
  47. ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
  48. ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
  49. ml4t/diagnostic/evaluation/framework.py +935 -0
  50. ml4t/diagnostic/evaluation/metric_registry.py +255 -0
  51. ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
  52. ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
  53. ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
  54. ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
  55. ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
  56. ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
  57. ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
  58. ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
  59. ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
  60. ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
  61. ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
  62. ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
  63. ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
  64. ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
  65. ml4t/diagnostic/evaluation/multi_signal.py +550 -0
  66. ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
  67. ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
  68. ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
  69. ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
  70. ml4t/diagnostic/evaluation/report_generation.py +824 -0
  71. ml4t/diagnostic/evaluation/signal_selector.py +452 -0
  72. ml4t/diagnostic/evaluation/stat_registry.py +139 -0
  73. ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
  74. ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
  75. ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
  76. ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
  77. ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
  78. ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
  79. ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
  80. ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
  81. ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
  82. ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
  83. ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
  84. ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
  85. ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
  86. ml4t/diagnostic/evaluation/stats/moments.py +164 -0
  87. ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
  88. ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
  89. ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
  90. ml4t/diagnostic/evaluation/themes.py +330 -0
  91. ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
  92. ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
  93. ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
  94. ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
  95. ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
  96. ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
  97. ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
  98. ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
  99. ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
  100. ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
  101. ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
  102. ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
  103. ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
  104. ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
  105. ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
  106. ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
  107. ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
  108. ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
  109. ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
  110. ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
  111. ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
  112. ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
  113. ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
  114. ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
  115. ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
  116. ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
  117. ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
  118. ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
  119. ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
  120. ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
  121. ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
  122. ml4t/diagnostic/evaluation/validated_cv.py +535 -0
  123. ml4t/diagnostic/evaluation/visualization.py +1050 -0
  124. ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
  125. ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
  126. ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
  127. ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
  128. ml4t/diagnostic/integration/__init__.py +48 -0
  129. ml4t/diagnostic/integration/backtest_contract.py +671 -0
  130. ml4t/diagnostic/integration/data_contract.py +316 -0
  131. ml4t/diagnostic/integration/engineer_contract.py +226 -0
  132. ml4t/diagnostic/logging/__init__.py +77 -0
  133. ml4t/diagnostic/logging/logger.py +245 -0
  134. ml4t/diagnostic/logging/performance.py +234 -0
  135. ml4t/diagnostic/logging/progress.py +234 -0
  136. ml4t/diagnostic/logging/wandb.py +412 -0
  137. ml4t/diagnostic/metrics/__init__.py +9 -0
  138. ml4t/diagnostic/metrics/percentiles.py +128 -0
  139. ml4t/diagnostic/py.typed +1 -0
  140. ml4t/diagnostic/reporting/__init__.py +43 -0
  141. ml4t/diagnostic/reporting/base.py +130 -0
  142. ml4t/diagnostic/reporting/html_renderer.py +275 -0
  143. ml4t/diagnostic/reporting/json_renderer.py +51 -0
  144. ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
  145. ml4t/diagnostic/results/AGENT.md +24 -0
  146. ml4t/diagnostic/results/__init__.py +105 -0
  147. ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
  148. ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
  149. ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
  150. ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
  151. ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
  152. ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
  153. ml4t/diagnostic/results/barrier_results/validation.py +38 -0
  154. ml4t/diagnostic/results/base.py +177 -0
  155. ml4t/diagnostic/results/event_results.py +349 -0
  156. ml4t/diagnostic/results/feature_results.py +787 -0
  157. ml4t/diagnostic/results/multi_signal_results.py +431 -0
  158. ml4t/diagnostic/results/portfolio_results.py +281 -0
  159. ml4t/diagnostic/results/sharpe_results.py +448 -0
  160. ml4t/diagnostic/results/signal_results/__init__.py +74 -0
  161. ml4t/diagnostic/results/signal_results/ic.py +581 -0
  162. ml4t/diagnostic/results/signal_results/irtc.py +110 -0
  163. ml4t/diagnostic/results/signal_results/quantile.py +392 -0
  164. ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
  165. ml4t/diagnostic/results/signal_results/turnover.py +213 -0
  166. ml4t/diagnostic/results/signal_results/validation.py +147 -0
  167. ml4t/diagnostic/signal/AGENT.md +17 -0
  168. ml4t/diagnostic/signal/__init__.py +69 -0
  169. ml4t/diagnostic/signal/_report.py +152 -0
  170. ml4t/diagnostic/signal/_utils.py +261 -0
  171. ml4t/diagnostic/signal/core.py +275 -0
  172. ml4t/diagnostic/signal/quantile.py +148 -0
  173. ml4t/diagnostic/signal/result.py +214 -0
  174. ml4t/diagnostic/signal/signal_ic.py +129 -0
  175. ml4t/diagnostic/signal/turnover.py +182 -0
  176. ml4t/diagnostic/splitters/AGENT.md +19 -0
  177. ml4t/diagnostic/splitters/__init__.py +36 -0
  178. ml4t/diagnostic/splitters/base.py +501 -0
  179. ml4t/diagnostic/splitters/calendar.py +421 -0
  180. ml4t/diagnostic/splitters/calendar_config.py +91 -0
  181. ml4t/diagnostic/splitters/combinatorial.py +1064 -0
  182. ml4t/diagnostic/splitters/config.py +322 -0
  183. ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
  184. ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
  185. ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
  186. ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
  187. ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
  188. ml4t/diagnostic/splitters/group_isolation.py +329 -0
  189. ml4t/diagnostic/splitters/persistence.py +316 -0
  190. ml4t/diagnostic/splitters/utils.py +207 -0
  191. ml4t/diagnostic/splitters/walk_forward.py +757 -0
  192. ml4t/diagnostic/utils/__init__.py +42 -0
  193. ml4t/diagnostic/utils/config.py +542 -0
  194. ml4t/diagnostic/utils/dependencies.py +318 -0
  195. ml4t/diagnostic/utils/sessions.py +127 -0
  196. ml4t/diagnostic/validation/__init__.py +54 -0
  197. ml4t/diagnostic/validation/dataframe.py +274 -0
  198. ml4t/diagnostic/validation/returns.py +280 -0
  199. ml4t/diagnostic/validation/timeseries.py +299 -0
  200. ml4t/diagnostic/visualization/AGENT.md +19 -0
  201. ml4t/diagnostic/visualization/__init__.py +223 -0
  202. ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
  203. ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
  204. ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
  205. ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
  206. ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
  207. ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
  208. ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
  209. ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
  210. ml4t/diagnostic/visualization/barrier_plots.py +782 -0
  211. ml4t/diagnostic/visualization/core.py +1060 -0
  212. ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
  213. ml4t/diagnostic/visualization/dashboards/base.py +582 -0
  214. ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
  215. ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
  216. ml4t/diagnostic/visualization/dashboards.py +43 -0
  217. ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
  218. ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
  219. ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
  220. ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
  221. ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
  222. ml4t/diagnostic/visualization/feature_plots.py +888 -0
  223. ml4t/diagnostic/visualization/interaction_plots.py +618 -0
  224. ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
  225. ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
  226. ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
  227. ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
  228. ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
  229. ml4t/diagnostic/visualization/report_generation.py +1343 -0
  230. ml4t/diagnostic/visualization/signal/__init__.py +103 -0
  231. ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
  232. ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
  233. ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
  234. ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
  235. ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
  236. ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
  237. ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
  238. ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
  239. ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
  240. ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
  241. ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
  242. ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
ml4t/diagnostic/evaluation/metrics/monotonicity.py
@@ -0,0 +1,226 @@
+ """Monotonicity: Test monotonic relationship between feature values and outcomes.
+
+ Monotonicity is a key property for predictive features - we expect higher
+ (or lower) feature values to consistently correspond to higher outcomes.
+ """
+
+ from typing import TYPE_CHECKING, Any, Union
+
+ import numpy as np
+ import pandas as pd
+ import polars as pl
+ from scipy import stats
+ from scipy.stats import spearmanr
+
+ if TYPE_CHECKING:
+     from numpy.typing import NDArray
+
+
+ def compute_monotonicity(
+     features: Union[pl.DataFrame, pd.DataFrame, "NDArray[Any]"],
+     outcomes: Union[pl.DataFrame, pd.DataFrame, "NDArray[Any]"],
+     n_quantiles: int = 5,
+     feature_col: str | None = None,
+     outcome_col: str | None = None,
+     method: str = "spearman",
+ ) -> dict[str, Any]:
+     """Test monotonic relationship between feature values and outcomes.
+
+     Monotonicity is a key property for predictive features - we expect higher
+     (or lower) feature values to consistently correspond to higher outcomes.
+     Non-monotonic relationships often indicate:
+     1. Feature needs transformation (e.g., absolute value, log)
+     2. Feature has regime-dependent behavior
+     3. Feature is not truly predictive
+
+     This function bins features into quantiles and checks if mean outcomes
+     increase/decrease monotonically across bins.
+
+     Parameters
+     ----------
+     features : Union[pl.DataFrame, pd.DataFrame, np.ndarray]
+         Feature values to test
+     outcomes : Union[pl.DataFrame, pd.DataFrame, np.ndarray]
+         Outcome values (typically returns)
+     n_quantiles : int, default 5
+         Number of quantile bins (5 = quintiles, 10 = deciles)
+     feature_col : str | None, default None
+         Column name for features (if DataFrame)
+     outcome_col : str | None, default None
+         Column name for outcomes (if DataFrame)
+     method : str, default "spearman"
+         Correlation method: "spearman" or "pearson"
+
+     Returns
+     -------
+     dict[str, Any]
+         Dictionary with monotonicity analysis:
+         - correlation: Spearman/Pearson correlation
+         - p_value: Statistical significance of correlation
+         - quantile_means: Mean outcome per quantile
+         - quantile_labels: Quantile labels (Q1, Q2, ...)
+         - is_monotonic: Boolean, True if strictly monotonic
+         - monotonicity_score: Fraction of quantile pairs that are monotonic (0-1)
+         - direction: "increasing", "decreasing", "mostly_increasing", "mostly_decreasing", "non_monotonic", or "insufficient_data"
+         - n_observations: Total observations
+         - n_per_quantile: Observations per quantile
+
+     Examples
+     --------
+     >>> # Test if momentum predicts returns
+     >>> features = df['momentum']
+     >>> outcomes = df['forward_return']
+     >>> result = compute_monotonicity(features, outcomes, n_quantiles=5)
+     >>>
+     >>> print(f"Correlation: {result['correlation']:.3f}")
+     >>> print(f"P-value: {result['p_value']:.4f}")
+     >>> print(f"Monotonic: {result['is_monotonic']}")
+     >>> print(f"Direction: {result['direction']}")
+     >>> print(f"Quantile means: {result['quantile_means']}")
+     Correlation: 0.156
+     P-value: 0.0001
+     Monotonic: True
+     Direction: increasing
+     Quantile means: [-0.002, 0.001, 0.003, 0.005, 0.008]
+
+     Notes
+     -----
+     Monotonicity Score:
+     - 1.0: Perfect monotonicity (all adjacent quantiles ordered correctly)
+     - 0.8-1.0: Strong monotonicity (minor violations)
+     - 0.6-0.8: Moderate monotonicity
+     - <0.6: Weak or no monotonicity
+
+     Common Patterns:
+     - Monotonic increasing: Good positive predictor
+     - Monotonic decreasing: Good negative predictor (consider sign flip)
+     - U-shaped: Consider absolute value or squared feature
+     - Flat: Feature not predictive
+
+     References
+     ----------
+     .. [1] Kakushadze, Z., & Serur, J. A. (2018). "151 Trading Strategies."
+     """
+     # Extract feature and outcome arrays
+     feature_vals: NDArray[Any]
+     if isinstance(features, pl.DataFrame):
+         if feature_col is None:
+             raise ValueError("feature_col must be specified for DataFrame input")
+         feature_vals = features[feature_col].to_numpy()
+     elif isinstance(features, pd.DataFrame):
+         if feature_col is None:
+             raise ValueError("feature_col must be specified for DataFrame input")
+         feature_vals = features[feature_col].to_numpy()
+     else:
+         feature_vals = np.asarray(features).flatten()
+
+     outcome_vals: NDArray[Any]
+     if isinstance(outcomes, pl.DataFrame):
+         if outcome_col is None:
+             raise ValueError("outcome_col must be specified for DataFrame input")
+         outcome_vals = outcomes[outcome_col].to_numpy()
+     elif isinstance(outcomes, pd.DataFrame):
+         if outcome_col is None:
+             raise ValueError("outcome_col must be specified for DataFrame input")
+         outcome_vals = outcomes[outcome_col].to_numpy()
+     else:
+         outcome_vals = np.asarray(outcomes).flatten()
+
+     # Validate inputs
+     if len(feature_vals) != len(outcome_vals):
+         raise ValueError(
+             f"Features ({len(feature_vals)}) and outcomes ({len(outcome_vals)}) must have same length"
+         )
+
+     # Remove NaN values
+     valid_mask = ~(np.isnan(feature_vals.astype(float)) | np.isnan(outcome_vals.astype(float)))
+     feature_clean = feature_vals[valid_mask]
+     outcome_clean = outcome_vals[valid_mask]
+
+     n = len(feature_clean)
+     if n < n_quantiles * 2:
+         # Insufficient data for quantile analysis
+         return {
+             "correlation": np.nan,
+             "p_value": np.nan,
+             "quantile_means": [],
+             "quantile_labels": [],
+             "is_monotonic": False,
+             "monotonicity_score": 0.0,
+             "direction": "insufficient_data",
+             "n_observations": n,
+             "n_per_quantile": [],
+         }
+
+     # Compute correlation
+     if method == "spearman":
+         correlation, p_value = spearmanr(feature_clean, outcome_clean)
+     elif method == "pearson":
+         correlation, p_value = stats.pearsonr(feature_clean, outcome_clean)
+     else:
+         raise ValueError(f"Unknown method: {method}. Use 'spearman' or 'pearson'.")
+
+     # Create quantile bins
+     quantile_edges = np.linspace(0, 100, n_quantiles + 1)
+     quantile_bins = np.percentile(feature_clean, quantile_edges)
+
+     # Assign observations to quantiles
+     quantile_assignments = np.digitize(feature_clean, quantile_bins[1:-1])  # 0-indexed bins
+
+     # Compute mean outcome per quantile
+     quantile_means = []
+     n_per_quantile = []
+
+     for q in range(n_quantiles):
+         mask = quantile_assignments == q
+         if np.sum(mask) > 0:
+             quantile_means.append(float(np.mean(outcome_clean[mask])))
+             n_per_quantile.append(int(np.sum(mask)))
+         else:
+             quantile_means.append(np.nan)
+             n_per_quantile.append(0)
+
+     # Check monotonicity
+     # Count how many adjacent pairs are ordered correctly
+     monotonic_pairs = 0
+     total_pairs = 0
+
+     for i in range(len(quantile_means) - 1):
+         if not (np.isnan(quantile_means[i]) or np.isnan(quantile_means[i + 1])):
+             total_pairs += 1
+             # Check if ordered (either increasing or decreasing)
+             if correlation > 0:
+                 # Expect increasing
+                 if quantile_means[i + 1] > quantile_means[i]:
+                     monotonic_pairs += 1
+             # Expect decreasing
+             elif quantile_means[i + 1] < quantile_means[i]:
+                 monotonic_pairs += 1
+
+     monotonicity_score = monotonic_pairs / total_pairs if total_pairs > 0 else 0.0
+
+     # Strict monotonicity check (all pairs ordered correctly)
+     is_monotonic = monotonicity_score == 1.0
+
+     # Determine direction
+     if is_monotonic:
+         direction = "increasing" if correlation > 0 else "decreasing"
+     elif monotonicity_score >= 0.8:
+         direction = "mostly_" + ("increasing" if correlation > 0 else "decreasing")
+     else:
+         direction = "non_monotonic"
+
+     # Create quantile labels
+     quantile_labels = [f"Q{i + 1}" for i in range(n_quantiles)]
+
+     return {
+         "correlation": float(correlation),
+         "p_value": float(p_value),
+         "quantile_means": quantile_means,
+         "quantile_labels": quantile_labels,
+         "is_monotonic": is_monotonic,
+         "monotonicity_score": float(monotonicity_score),
+         "direction": direction,
+         "n_observations": n,
+         "n_per_quantile": n_per_quantile,
+     }
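
To sanity-check compute_monotonicity end to end, here is a minimal sketch with synthetic data (not part of the package; it assumes the import path mirrors the file layout listed above):

    import numpy as np
    from ml4t.diagnostic.evaluation.metrics.monotonicity import compute_monotonicity

    # Feature with a weak positive, monotone link to the outcome
    rng = np.random.default_rng(0)
    feature = rng.normal(size=5_000)
    outcome = 0.01 * feature + rng.normal(scale=0.05, size=5_000)

    result = compute_monotonicity(feature, outcome, n_quantiles=5)
    # With roughly 1,000 observations per quintile, the bin means are stable;
    # expect direction "increasing" and a monotonicity_score near 1.0
    print(result["direction"], result["monotonicity_score"])
    print([round(m, 4) for m in result["quantile_means"]])

Because the score counts only adjacent-pair orderings, five quantiles yield four pairs, so a single violation drops the score from 1.0 to 0.75.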
ml4t/diagnostic/evaluation/metrics/risk_adjusted.py
@@ -0,0 +1,324 @@
+ """Risk-adjusted performance metrics: Sharpe, Sortino, Maximum Drawdown.
+
+ This module provides standard risk-adjusted return metrics used in portfolio
+ and strategy evaluation.
+ """
+
+ from typing import TYPE_CHECKING, Any, Union
+
+ import numpy as np
+ import pandas as pd
+ import polars as pl
+
+ from ml4t.diagnostic.backends.adapter import DataFrameAdapter
+
+ if TYPE_CHECKING:
+     from numpy.typing import NDArray
+
+
+ def sharpe_ratio(
+     returns: Union[pl.Series, pd.Series, "NDArray[Any]"],
+     risk_free_rate: float = 0.0,
+     annualization_factor: float | None = None,
+     confidence_intervals: bool = False,
+     alpha: float = 0.05,
+     bootstrap_samples: int = 1000,
+     random_state: int | None = None,
+ ) -> float | dict[str, float]:
+     """Calculate Sharpe Ratio with optional confidence intervals.
+
+     The Sharpe Ratio measures risk-adjusted returns by dividing excess returns
+     by return volatility. Higher values indicate better risk-adjusted performance.
+
+     Parameters
+     ----------
+     returns : Union[pl.Series, pd.Series, np.ndarray]
+         Time series of returns
+     risk_free_rate : float, default 0.0
+         Risk-free rate (same frequency as returns)
+     annualization_factor : Optional[float], default None
+         Factor to annualize the ratio. If None, no annualization applied
+     confidence_intervals : bool, default False
+         Whether to compute bootstrap confidence intervals
+     alpha : float, default 0.05
+         Significance level for confidence intervals
+     bootstrap_samples : int, default 1000
+         Number of bootstrap samples for confidence intervals
+     random_state : Optional[int], default None
+         Random seed for reproducible bootstrap samples
+
+     Returns
+     -------
+     Union[float, dict]
+         If confidence_intervals=False: Sharpe ratio value
+         If confidence_intervals=True: dict with 'sharpe', 'lower_ci', 'upper_ci'
+
+     Examples
+     --------
+     >>> returns = np.array([0.01, 0.02, -0.01, 0.03, 0.00])
+     >>> sharpe = sharpe_ratio(returns, annualization_factor=252)
+     >>> print(f"Sharpe Ratio: {sharpe:.3f}")
+
+     >>> # With confidence intervals
+     >>> result = sharpe_ratio(returns, confidence_intervals=True, random_state=42)
+     >>> print(f"Sharpe: {result['sharpe']:.3f}")
+     """
+     if confidence_intervals:
+         return sharpe_ratio_with_ci(
+             returns, risk_free_rate, annualization_factor, alpha, bootstrap_samples, random_state
+         )
+     return _sharpe_ratio_core(returns, risk_free_rate, annualization_factor)
+
+
+ def _sharpe_ratio_core(
+     returns: Union[pl.Series, pd.Series, "NDArray[Any]"],
+     risk_free_rate: float = 0.0,
+     annualization_factor: float | None = None,
+ ) -> float:
+     """Calculate Sharpe Ratio (core calculation without confidence intervals).
+
+     Parameters
+     ----------
+     returns : Union[pl.Series, pd.Series, np.ndarray]
+         Time series of returns
+     risk_free_rate : float, default 0.0
+         Risk-free rate (same frequency as returns)
+     annualization_factor : Optional[float], default None
+         Factor to annualize the ratio
+
+     Returns
+     -------
+     float
+         Sharpe ratio value
+     """
+     ret_array = DataFrameAdapter.to_numpy(returns).flatten()
+     ret_clean = ret_array[~np.isnan(ret_array)]
+
+     if len(ret_clean) < 2:
+         return np.nan
+
+     excess_returns = ret_clean - risk_free_rate
+     mean_excess = np.mean(excess_returns)
+     std_excess = np.std(excess_returns, ddof=1)
+
+     if std_excess == 0:
+         if mean_excess > 0:
+             return np.inf
+         if mean_excess < 0:
+             return -np.inf
+         return np.nan
+
+     sharpe = mean_excess / std_excess
+
+     if annualization_factor is not None and not np.isinf(sharpe) and not np.isnan(sharpe):
+         sharpe *= np.sqrt(annualization_factor)
+
+     return float(sharpe) if not np.isinf(sharpe) else sharpe
+
+
+ def sharpe_ratio_with_ci(
+     returns: Union[pl.Series, pd.Series, "NDArray[Any]"],
+     risk_free_rate: float = 0.0,
+     annualization_factor: float | None = None,
+     alpha: float = 0.05,
+     bootstrap_samples: int = 1000,
+     random_state: int | None = None,
+ ) -> dict[str, float]:
+     """Calculate Sharpe Ratio with bootstrap confidence intervals.
+
+     Parameters
+     ----------
+     returns : Union[pl.Series, pd.Series, np.ndarray]
+         Time series of returns
+     risk_free_rate : float, default 0.0
+         Risk-free rate (same frequency as returns)
+     annualization_factor : Optional[float], default None
+         Factor to annualize the ratio
+     alpha : float, default 0.05
+         Significance level for confidence intervals
+     bootstrap_samples : int, default 1000
+         Number of bootstrap samples for confidence intervals
+     random_state : Optional[int], default None
+         Random seed for reproducible bootstrap samples
+
+     Returns
+     -------
+     dict[str, float]
+         Dict with 'sharpe', 'lower_ci', 'upper_ci' keys
+     """
+     sharpe = _sharpe_ratio_core(returns, risk_free_rate, annualization_factor)
+
+     if np.isnan(sharpe) or np.isinf(sharpe):
+         return {"sharpe": sharpe, "lower_ci": np.nan, "upper_ci": np.nan}
+
+     ret_array = DataFrameAdapter.to_numpy(returns).flatten()
+     ret_clean = ret_array[~np.isnan(ret_array)]
+
+     if len(ret_clean) < 10:
+         return {"sharpe": sharpe, "lower_ci": np.nan, "upper_ci": np.nan}
+
+     if random_state is not None:
+         np.random.seed(random_state)
+
+     bootstrap_sharpes = []
+     for _ in range(bootstrap_samples):
+         bootstrap_sample = np.random.choice(ret_clean, size=len(ret_clean), replace=True)
+         bootstrap_excess = bootstrap_sample - risk_free_rate
+         bootstrap_mean = np.mean(bootstrap_excess)
+         bootstrap_std = np.std(bootstrap_excess, ddof=1)
+
+         if bootstrap_std > 0:
+             bs_sharpe = bootstrap_mean / bootstrap_std
+             if annualization_factor is not None:
+                 bs_sharpe *= np.sqrt(annualization_factor)
+             bootstrap_sharpes.append(bs_sharpe)
+
+     if len(bootstrap_sharpes) == 0:
+         return {"sharpe": sharpe, "lower_ci": np.nan, "upper_ci": np.nan}
+
+     lower_ci = np.percentile(bootstrap_sharpes, (alpha / 2) * 100)
+     upper_ci = np.percentile(bootstrap_sharpes, (1 - alpha / 2) * 100)
+
+     return {"sharpe": sharpe, "lower_ci": float(lower_ci), "upper_ci": float(upper_ci)}
+
+
+ def maximum_drawdown(
+     returns: Union[pl.Series, pd.Series, "NDArray[Any]"],
+     cumulative: bool = False,
+ ) -> dict[str, float]:
+     """Calculate Maximum Drawdown and related statistics.
+
+     Maximum Drawdown measures the largest peak-to-trough decline in cumulative
+     returns. It represents the worst-case loss an investor would experience.
+
+     Parameters
+     ----------
+     returns : Union[pl.Series, pd.Series, np.ndarray]
+         Time series of returns (or cumulative returns if cumulative=True)
+     cumulative : bool, default False
+         Whether input is already cumulative returns
+
+     Returns
+     -------
+     dict
+         Dictionary with 'max_drawdown', 'max_drawdown_duration', 'peak_date', 'trough_date'
+
+     Examples
+     --------
+     >>> returns = np.array([0.10, -0.05, 0.08, -0.12, 0.03])
+     >>> dd = maximum_drawdown(returns)
+     >>> print(f"Max Drawdown: {dd['max_drawdown']:.3f}")
+     Max Drawdown: -0.120
+     """
+     # Import here to avoid circular dependency
+     from ml4t.diagnostic.core.numba_utils import calculate_drawdown_numba
+
+     # Convert to numpy array
+     ret_array = DataFrameAdapter.to_numpy(returns).flatten()
+
+     # Remove NaN values
+     ret_clean = ret_array[~np.isnan(ret_array)]
+
+     if len(ret_clean) == 0:
+         return {
+             "max_drawdown": np.nan,
+             "max_drawdown_duration": np.nan,
+             "peak_date": np.nan,
+             "trough_date": np.nan,
+         }
+
+     # Calculate cumulative returns if needed
+     if cumulative:
+         cum_returns = ret_clean
+     else:
+         cum_returns = np.cumprod(1 + ret_clean) - 1  # Compound returns
+
+     # Use Numba-optimized function
+     max_drawdown_val, dd_duration, peak_idx, trough_idx = calculate_drawdown_numba(cum_returns)
+
+     # Handle case where no drawdown was found
+     if peak_idx == -1:
+         return {
+             "max_drawdown": 0.0,
+             "max_drawdown_duration": 0,
+             "peak_date": 0,
+             "trough_date": 0,
+         }
+
+     return {
+         "max_drawdown": float(max_drawdown_val),
+         "max_drawdown_duration": int(dd_duration),
+         "peak_date": int(peak_idx),
+         "trough_date": int(trough_idx),
+     }
+
+
+ def sortino_ratio(
+     returns: Union[pl.Series, pd.Series, "NDArray[Any]"],
+     target_return: float = 0.0,
+     annualization_factor: float | None = None,
+ ) -> float:
+     """Calculate Sortino Ratio focusing on downside risk.
+
+     The Sortino Ratio is similar to the Sharpe Ratio but only penalizes downside
+     volatility, making it more appropriate for asymmetric return distributions.
+
+     Parameters
+     ----------
+     returns : Union[pl.Series, pd.Series, np.ndarray]
+         Time series of returns
+     target_return : float, default 0.0
+         Target return threshold (same frequency as returns)
+     annualization_factor : Optional[float], default None
+         Factor to annualize the ratio
+
+     Returns
+     -------
+     float
+         Sortino ratio value
+
+     Examples
+     --------
+     >>> returns = np.array([0.01, 0.02, -0.01, 0.03, -0.02])
+     >>> sortino = sortino_ratio(returns, annualization_factor=252)
+     >>> print(f"Sortino Ratio: {sortino:.3f}")
+     Sortino Ratio: 6.024
+     """
+     # Convert to numpy array
+     ret_array = DataFrameAdapter.to_numpy(returns).flatten()
+
+     # Remove NaN values
+     ret_clean = ret_array[~np.isnan(ret_array)]
+
+     if len(ret_clean) < 2:
+         return np.nan
+
+     # Calculate excess returns relative to target
+     excess_returns = ret_clean - target_return
+
+     # Calculate downside returns (only negative excess returns)
+     downside_returns = excess_returns[excess_returns < 0]
+
+     if len(downside_returns) == 0:
+         # No downside - infinite Sortino ratio if mean is positive
+         mean_excess = np.mean(excess_returns)
+         if mean_excess > 0:
+             return np.inf
+         if mean_excess < 0:
+             return -np.inf
+         return np.nan
+
+     # Calculate Sortino ratio
+     mean_excess = np.mean(excess_returns)
+     downside_std = np.sqrt(np.mean(downside_returns**2))  # Downside deviation (RMS over downside observations only)
+
+     if downside_std == 0:
+         return np.nan
+
+     sortino = mean_excess / downside_std
+
+     # Apply annualization if specified
+     if annualization_factor is not None:
+         sortino *= np.sqrt(annualization_factor)
+
+     return float(sortino)
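
A companion usage sketch for this module (again not part of the package, and assuming the import path follows the file layout above):

    import numpy as np
    from ml4t.diagnostic.evaluation.metrics.risk_adjusted import (
        maximum_drawdown,
        sharpe_ratio,
        sortino_ratio,
    )

    # One year of synthetic daily returns
    rng = np.random.default_rng(42)
    daily = rng.normal(loc=0.0005, scale=0.01, size=252)

    print(sharpe_ratio(daily, annualization_factor=252))   # scalar point estimate
    print(sharpe_ratio(daily, confidence_intervals=True,
                       random_state=7))                    # dict: sharpe, lower_ci, upper_ci
    print(sortino_ratio(daily, annualization_factor=252))
    print(maximum_drawdown(daily))                         # dict with drawdown stats

Note that the 'peak_date' and 'trough_date' fields returned by maximum_drawdown are integer positions into the NaN-cleaned return series, not calendar dates, so callers holding a pd.Series should map them back to the index themselves.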