ml4t_diagnostic-0.1.0a1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (242)
  1. ml4t/diagnostic/AGENT.md +25 -0
  2. ml4t/diagnostic/__init__.py +166 -0
  3. ml4t/diagnostic/backends/__init__.py +10 -0
  4. ml4t/diagnostic/backends/adapter.py +192 -0
  5. ml4t/diagnostic/backends/polars_backend.py +899 -0
  6. ml4t/diagnostic/caching/__init__.py +40 -0
  7. ml4t/diagnostic/caching/cache.py +331 -0
  8. ml4t/diagnostic/caching/decorators.py +131 -0
  9. ml4t/diagnostic/caching/smart_cache.py +339 -0
  10. ml4t/diagnostic/config/AGENT.md +24 -0
  11. ml4t/diagnostic/config/README.md +267 -0
  12. ml4t/diagnostic/config/__init__.py +219 -0
  13. ml4t/diagnostic/config/barrier_config.py +277 -0
  14. ml4t/diagnostic/config/base.py +301 -0
  15. ml4t/diagnostic/config/event_config.py +148 -0
  16. ml4t/diagnostic/config/feature_config.py +404 -0
  17. ml4t/diagnostic/config/multi_signal_config.py +55 -0
  18. ml4t/diagnostic/config/portfolio_config.py +215 -0
  19. ml4t/diagnostic/config/report_config.py +391 -0
  20. ml4t/diagnostic/config/sharpe_config.py +202 -0
  21. ml4t/diagnostic/config/signal_config.py +206 -0
  22. ml4t/diagnostic/config/trade_analysis_config.py +310 -0
  23. ml4t/diagnostic/config/validation.py +279 -0
  24. ml4t/diagnostic/core/__init__.py +29 -0
  25. ml4t/diagnostic/core/numba_utils.py +315 -0
  26. ml4t/diagnostic/core/purging.py +372 -0
  27. ml4t/diagnostic/core/sampling.py +471 -0
  28. ml4t/diagnostic/errors/__init__.py +205 -0
  29. ml4t/diagnostic/evaluation/AGENT.md +26 -0
  30. ml4t/diagnostic/evaluation/__init__.py +437 -0
  31. ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
  32. ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
  33. ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
  34. ml4t/diagnostic/evaluation/dashboard.py +715 -0
  35. ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
  36. ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
  37. ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
  38. ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
  39. ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
  40. ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
  41. ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
  42. ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
  43. ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
  44. ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
  45. ml4t/diagnostic/evaluation/event_analysis.py +647 -0
  46. ml4t/diagnostic/evaluation/excursion.py +390 -0
  47. ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
  48. ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
  49. ml4t/diagnostic/evaluation/framework.py +935 -0
  50. ml4t/diagnostic/evaluation/metric_registry.py +255 -0
  51. ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
  52. ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
  53. ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
  54. ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
  55. ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
  56. ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
  57. ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
  58. ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
  59. ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
  60. ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
  61. ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
  62. ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
  63. ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
  64. ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
  65. ml4t/diagnostic/evaluation/multi_signal.py +550 -0
  66. ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
  67. ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
  68. ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
  69. ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
  70. ml4t/diagnostic/evaluation/report_generation.py +824 -0
  71. ml4t/diagnostic/evaluation/signal_selector.py +452 -0
  72. ml4t/diagnostic/evaluation/stat_registry.py +139 -0
  73. ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
  74. ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
  75. ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
  76. ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
  77. ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
  78. ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
  79. ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
  80. ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
  81. ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
  82. ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
  83. ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
  84. ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
  85. ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
  86. ml4t/diagnostic/evaluation/stats/moments.py +164 -0
  87. ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
  88. ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
  89. ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
  90. ml4t/diagnostic/evaluation/themes.py +330 -0
  91. ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
  92. ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
  93. ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
  94. ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
  95. ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
  96. ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
  97. ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
  98. ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
  99. ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
  100. ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
  101. ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
  102. ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
  103. ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
  104. ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
  105. ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
  106. ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
  107. ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
  108. ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
  109. ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
  110. ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
  111. ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
  112. ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
  113. ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
  114. ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
  115. ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
  116. ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
  117. ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
  118. ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
  119. ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
  120. ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
  121. ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
  122. ml4t/diagnostic/evaluation/validated_cv.py +535 -0
  123. ml4t/diagnostic/evaluation/visualization.py +1050 -0
  124. ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
  125. ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
  126. ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
  127. ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
  128. ml4t/diagnostic/integration/__init__.py +48 -0
  129. ml4t/diagnostic/integration/backtest_contract.py +671 -0
  130. ml4t/diagnostic/integration/data_contract.py +316 -0
  131. ml4t/diagnostic/integration/engineer_contract.py +226 -0
  132. ml4t/diagnostic/logging/__init__.py +77 -0
  133. ml4t/diagnostic/logging/logger.py +245 -0
  134. ml4t/diagnostic/logging/performance.py +234 -0
  135. ml4t/diagnostic/logging/progress.py +234 -0
  136. ml4t/diagnostic/logging/wandb.py +412 -0
  137. ml4t/diagnostic/metrics/__init__.py +9 -0
  138. ml4t/diagnostic/metrics/percentiles.py +128 -0
  139. ml4t/diagnostic/py.typed +1 -0
  140. ml4t/diagnostic/reporting/__init__.py +43 -0
  141. ml4t/diagnostic/reporting/base.py +130 -0
  142. ml4t/diagnostic/reporting/html_renderer.py +275 -0
  143. ml4t/diagnostic/reporting/json_renderer.py +51 -0
  144. ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
  145. ml4t/diagnostic/results/AGENT.md +24 -0
  146. ml4t/diagnostic/results/__init__.py +105 -0
  147. ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
  148. ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
  149. ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
  150. ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
  151. ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
  152. ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
  153. ml4t/diagnostic/results/barrier_results/validation.py +38 -0
  154. ml4t/diagnostic/results/base.py +177 -0
  155. ml4t/diagnostic/results/event_results.py +349 -0
  156. ml4t/diagnostic/results/feature_results.py +787 -0
  157. ml4t/diagnostic/results/multi_signal_results.py +431 -0
  158. ml4t/diagnostic/results/portfolio_results.py +281 -0
  159. ml4t/diagnostic/results/sharpe_results.py +448 -0
  160. ml4t/diagnostic/results/signal_results/__init__.py +74 -0
  161. ml4t/diagnostic/results/signal_results/ic.py +581 -0
  162. ml4t/diagnostic/results/signal_results/irtc.py +110 -0
  163. ml4t/diagnostic/results/signal_results/quantile.py +392 -0
  164. ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
  165. ml4t/diagnostic/results/signal_results/turnover.py +213 -0
  166. ml4t/diagnostic/results/signal_results/validation.py +147 -0
  167. ml4t/diagnostic/signal/AGENT.md +17 -0
  168. ml4t/diagnostic/signal/__init__.py +69 -0
  169. ml4t/diagnostic/signal/_report.py +152 -0
  170. ml4t/diagnostic/signal/_utils.py +261 -0
  171. ml4t/diagnostic/signal/core.py +275 -0
  172. ml4t/diagnostic/signal/quantile.py +148 -0
  173. ml4t/diagnostic/signal/result.py +214 -0
  174. ml4t/diagnostic/signal/signal_ic.py +129 -0
  175. ml4t/diagnostic/signal/turnover.py +182 -0
  176. ml4t/diagnostic/splitters/AGENT.md +19 -0
  177. ml4t/diagnostic/splitters/__init__.py +36 -0
  178. ml4t/diagnostic/splitters/base.py +501 -0
  179. ml4t/diagnostic/splitters/calendar.py +421 -0
  180. ml4t/diagnostic/splitters/calendar_config.py +91 -0
  181. ml4t/diagnostic/splitters/combinatorial.py +1064 -0
  182. ml4t/diagnostic/splitters/config.py +322 -0
  183. ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
  184. ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
  185. ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
  186. ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
  187. ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
  188. ml4t/diagnostic/splitters/group_isolation.py +329 -0
  189. ml4t/diagnostic/splitters/persistence.py +316 -0
  190. ml4t/diagnostic/splitters/utils.py +207 -0
  191. ml4t/diagnostic/splitters/walk_forward.py +757 -0
  192. ml4t/diagnostic/utils/__init__.py +42 -0
  193. ml4t/diagnostic/utils/config.py +542 -0
  194. ml4t/diagnostic/utils/dependencies.py +318 -0
  195. ml4t/diagnostic/utils/sessions.py +127 -0
  196. ml4t/diagnostic/validation/__init__.py +54 -0
  197. ml4t/diagnostic/validation/dataframe.py +274 -0
  198. ml4t/diagnostic/validation/returns.py +280 -0
  199. ml4t/diagnostic/validation/timeseries.py +299 -0
  200. ml4t/diagnostic/visualization/AGENT.md +19 -0
  201. ml4t/diagnostic/visualization/__init__.py +223 -0
  202. ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
  203. ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
  204. ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
  205. ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
  206. ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
  207. ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
  208. ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
  209. ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
  210. ml4t/diagnostic/visualization/barrier_plots.py +782 -0
  211. ml4t/diagnostic/visualization/core.py +1060 -0
  212. ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
  213. ml4t/diagnostic/visualization/dashboards/base.py +582 -0
  214. ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
  215. ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
  216. ml4t/diagnostic/visualization/dashboards.py +43 -0
  217. ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
  218. ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
  219. ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
  220. ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
  221. ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
  222. ml4t/diagnostic/visualization/feature_plots.py +888 -0
  223. ml4t/diagnostic/visualization/interaction_plots.py +618 -0
  224. ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
  225. ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
  226. ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
  227. ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
  228. ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
  229. ml4t/diagnostic/visualization/report_generation.py +1343 -0
  230. ml4t/diagnostic/visualization/signal/__init__.py +103 -0
  231. ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
  232. ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
  233. ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
  234. ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
  235. ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
  236. ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
  237. ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
  238. ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
  239. ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
  240. ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
  241. ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
  242. ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
ml4t/diagnostic/config/signal_config.py
@@ -0,0 +1,206 @@
+ """Signal Analysis Configuration.
+
+ This module provides configuration for signal analysis including IC calculation,
+ quantile analysis, turnover metrics, and multi-signal batch analysis.
+
+ Consolidated Config:
+ - SignalConfig: Single config with analysis, visualization, and multi-signal settings
+
+ References
+ ----------
+ López de Prado, M. (2018). "Advances in Financial Machine Learning"
+ Paleologo, G. (2024). "Elements of Quantitative Investing"
+ """
+
+ from __future__ import annotations
+
+ from enum import Enum
+ from typing import Literal
+
+ from pydantic import Field, field_validator, model_validator
+
+ from ml4t.diagnostic.config.base import BaseConfig
+
+
+ class ICMethod(str, Enum):
+     """Information Coefficient calculation method."""
+
+     SPEARMAN = "spearman"
+     PEARSON = "pearson"
+
+
+ class QuantileMethod(str, Enum):
+     """Method for assigning quantile labels."""
+
+     QUANTILE = "quantile"  # pd.qcut - equal frequency
+     UNIFORM = "uniform"  # pd.cut - equal width bins
+
+
+ # =============================================================================
+ # Settings Classes (Single-Level Nesting)
+ # =============================================================================
+
+
+ class AnalysisSettings(BaseConfig):
+     """Settings for signal analysis calculations."""
+
+     quantiles: int = Field(default=5, ge=2, le=20, description="Number of quantile bins")
+     periods: tuple[int, ...] = Field(default=(1, 5, 10), description="Forward return periods")
+     max_loss: float = Field(
+         default=0.35, ge=0.0, le=1.0, description="Max fraction of data to drop"
+     )
+     filter_zscore: float | None = Field(
+         default=3.0, ge=1.0, le=10.0, description="Outlier z-score threshold"
+     )
+     zero_aware: bool = Field(default=False, description="Separate quantiles around zero")
+     ic_method: ICMethod = Field(default=ICMethod.SPEARMAN, description="IC correlation method")
+     ic_by_group: bool = Field(default=False, description="Compute IC by group")
+     hac_lags: int | None = Field(default=None, ge=0, description="Newey-West lags")
+     quantile_method: QuantileMethod = Field(
+         default=QuantileMethod.QUANTILE, description="Quantile binning method"
+     )
+     quantile_labels: list[str] | None = Field(default=None, description="Custom quantile labels")
+     cumulative_returns: bool = Field(default=True, description="Compute cumulative returns")
+     spread_confidence: float = Field(
+         default=0.95, ge=0.80, le=0.99, description="Spread confidence level"
+     )
+     compute_turnover: bool = Field(default=True, description="Compute turnover metrics")
+     autocorrelation_lags: int = Field(default=5, ge=1, le=20, description="Autocorrelation lags")
+     cost_per_trade: float = Field(default=0.001, ge=0.0, le=0.05, description="Transaction cost")
+     group_column: str | None = Field(default=None, description="Group column for analysis")
+
+     @field_validator("periods", mode="before")
+     @classmethod
+     def validate_periods(cls, v: tuple[int, ...] | list[int]) -> tuple[int, ...]:
+         """Ensure periods is a tuple of positive integers."""
+         if isinstance(v, list):
+             v = tuple(v)
+         if not all(isinstance(p, int) and p > 0 for p in v):
+             raise ValueError("All periods must be positive integers")
+         if len(v) == 0:
+             raise ValueError("At least one period is required")
+         return tuple(sorted(set(v)))
+
+
+ class RASSettings(BaseConfig):
+     """Settings for Rademacher Anti-Serum adjustment."""
+
+     enabled: bool = Field(default=True, description="Apply RAS adjustment")
+     delta: float = Field(default=0.05, ge=0.001, le=0.20, description="Significance level")
+     kappa: float = Field(default=0.02, ge=0.001, le=1.0, description="IC bound")
+     n_simulations: int = Field(
+         default=10000, ge=1000, le=100000, description="Monte Carlo simulations"
+     )
+
+
+ class VisualizationSettings(BaseConfig):
+     """Settings for signal tear sheet visualization."""
+
+     theme: Literal["default", "dark", "print", "presentation"] = Field(default="default")
+     width: int = Field(default=1000, ge=400, le=2000, description="Plot width")
+     height_multiplier: float = Field(default=1.0, ge=0.5, le=2.0, description="Height scaling")
+     include_ic_plots: bool = Field(default=True)
+     include_quantile_plots: bool = Field(default=True)
+     include_turnover_plots: bool = Field(default=True)
+     include_summary_table: bool = Field(default=True)
+     ic_rolling_window: int = Field(default=21, ge=5, le=252, description="IC rolling window")
+     ic_significance_bands: bool = Field(default=True)
+     ic_heatmap_freq: Literal["M", "Q", "Y"] = Field(default="M")
+     html_self_contained: bool = Field(default=True)
+     html_include_plotlyjs: Literal["cdn", "directory", True, False] = Field(default="cdn")
+     export_data: bool = Field(default=False)
+
+
+ class MultiSignalSettings(BaseConfig):
+     """Settings for multi-signal batch analysis."""
+
+     fdr_alpha: float = Field(default=0.05, ge=0.001, le=0.5, description="FDR alpha")
+     fwer_alpha: float = Field(default=0.05, ge=0.001, le=0.5, description="FWER alpha")
+     min_ic_threshold: float = Field(default=0.0, ge=-1.0, le=1.0, description="Min IC threshold")
+     min_observations: int = Field(default=100, ge=10, description="Min observations")
+     n_jobs: int = Field(default=-1, ge=-1, description="Parallel jobs")
+     backend: Literal["loky", "threading", "multiprocessing"] = Field(default="loky")
+     cache_enabled: bool = Field(default=True)
+     cache_max_items: int = Field(default=200, ge=10, le=10000)
+     cache_ttl: int | None = Field(default=3600, ge=60)
+     max_signals_summary: int = Field(default=200, ge=10, le=1000)
+     max_signals_comparison: int = Field(default=20, ge=2, le=50)
+     max_signals_heatmap: int = Field(default=100, ge=10, le=500)
+     default_selection_metric: str = Field(default="ic_ir")
+     default_correlation_threshold: float = Field(default=0.7, ge=0.0, le=1.0)
+
+     @field_validator("default_selection_metric")
+     @classmethod
+     def validate_selection_metric(cls, v: str) -> str:
+         """Validate selection metric."""
+         valid = {"ic_mean", "ic_ir", "ic_t_stat", "turnover_adj_ic", "quantile_spread"}
+         if v not in valid:
+             raise ValueError(f"Invalid selection metric '{v}'. Valid: {valid}")
+         return v
+
+
+ # =============================================================================
+ # Consolidated Config
+ # =============================================================================
+
+
+ class SignalConfig(BaseConfig):
+     """Consolidated configuration for signal analysis.
+
+     Combines analysis settings, RAS adjustment, visualization, and
+     multi-signal batch analysis into a single configuration class.
+
+     Examples
+     --------
+     >>> config = SignalConfig(
+     ...     analysis=AnalysisSettings(quantiles=10, periods=(1, 5)),
+     ...     visualization=VisualizationSettings(theme="dark"),
+     ... )
+     >>> config.to_yaml("signal_config.yaml")
+     """
+
+     analysis: AnalysisSettings = Field(
+         default_factory=AnalysisSettings, description="Analysis settings"
+     )
+     ras: RASSettings = Field(default_factory=RASSettings, description="RAS adjustment settings")
+     visualization: VisualizationSettings = Field(
+         default_factory=VisualizationSettings, description="Visualization settings"
+     )
+     multi: MultiSignalSettings = Field(
+         default_factory=MultiSignalSettings, description="Multi-signal settings"
+     )
+
+     signal_name: str = Field(default="signal", description="Signal name for reports")
+     return_pandas: bool = Field(default=False, description="Return pandas instead of Polars")
+
+     @model_validator(mode="after")
+     def validate_quantile_labels_count(self) -> SignalConfig:
+         """Ensure quantile_labels matches quantiles count if provided."""
+         if self.analysis.quantile_labels is not None:
+             if len(self.analysis.quantile_labels) != self.analysis.quantiles:
+                 raise ValueError(
+                     f"quantile_labels length ({len(self.analysis.quantile_labels)}) "
+                     f"must match quantiles ({self.analysis.quantiles})"
+                 )
+         return self
+
+     # Convenience properties
+     @property
+     def quantiles(self) -> int:
+         """Number of quantiles (shortcut)."""
+         return self.analysis.quantiles
+
+     @property
+     def periods(self) -> tuple[int, ...]:
+         """Forward return periods (shortcut)."""
+         return self.analysis.periods
+
+     @property
+     def filter_zscore(self) -> float | None:
+         """Outlier z-score threshold (shortcut)."""
+         return self.analysis.filter_zscore
+
+     @property
+     def compute_turnover(self) -> bool:
+         """Compute turnover metrics (shortcut)."""
+         return self.analysis.compute_turnover
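For orientation, here is a minimal usage sketch of the consolidated signal config added above. It is an assumption that the classes are imported from the module path shown in the file list (ml4t/diagnostic/config/signal_config.py); only fields, validators, and methods visible in this diff are used.

# Hedged sketch; import path assumed from the file list above.
from ml4t.diagnostic.config.signal_config import (
    AnalysisSettings,
    SignalConfig,
    VisualizationSettings,
)

# Override only the analysis and visualization settings; everything else keeps its default.
config = SignalConfig(
    analysis=AnalysisSettings(quantiles=10, periods=(1, 5, 21), cost_per_trade=0.0005),
    visualization=VisualizationSettings(theme="dark", width=1200),
    signal_name="momentum_12_1",  # hypothetical signal name for the report
)

print(config.quantiles)  # 10, via the convenience property
print(config.periods)    # (1, 5, 21) - sorted and de-duplicated by validate_periods
config.to_yaml("signal_config.yaml")  # serialization as shown in the class docstring

The convenience properties mean downstream code can read config.quantiles or config.periods without reaching into config.analysis, while the nested settings classes keep related options grouped in YAML output.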
ml4t/diagnostic/config/trade_analysis_config.py
@@ -0,0 +1,310 @@
+ """Trade Analysis Configuration.
+
+ This module provides consolidated configuration for trade-level analysis:
+ - Trade extraction (worst/best trades by PnL)
+ - Trade filtering (duration, regime, symbol)
+ - SHAP alignment (map SHAP values to trades)
+ - Error pattern clustering
+ - Automated hypothesis generation
+
+ Consolidated Config:
+ - TradeConfig: Single config with all trade analysis settings
+
+ References
+ ----------
+ López de Prado, M. (2018). "Advances in Financial Machine Learning"
+ Lundberg & Lee (2017). "A Unified Approach to Interpreting Model Predictions"
+ """
+
+ from __future__ import annotations
+
+ from datetime import timedelta
+ from typing import Literal
+
+ from pydantic import Field, field_validator, model_validator
+
+ from ml4t.diagnostic.config.base import BaseConfig
+ from ml4t.diagnostic.config.validation import (
+     ClusteringMethod,
+     DistanceMetric,
+     LinkageMethod,
+     NonNegativeInt,
+     PositiveInt,
+     Probability,
+ )
+
+ # =============================================================================
+ # Settings Classes (Single-Level Nesting)
+ # =============================================================================
+
+
+ class FilterSettings(BaseConfig):
+     """Settings for filtering trades before analysis."""
+
+     min_duration: timedelta | None = Field(None, description="Minimum trade duration")
+     max_duration: timedelta | None = Field(None, description="Maximum trade duration")
+     min_pnl: float | None = Field(None, description="Minimum PnL threshold")
+     exclude_symbols: list[str] | None = Field(None, description="Symbols to exclude")
+     regime_filter: str | None = Field(None, description="Only analyze specific regime")
+
+     @field_validator("min_duration", "max_duration")
+     @classmethod
+     def validate_duration_positive(cls, v: timedelta | None) -> timedelta | None:
+         """Ensure durations are positive if provided."""
+         if v is not None and v.total_seconds() <= 0:
+             raise ValueError("Duration must be positive")
+         return v
+
+     @field_validator("max_duration")
+     @classmethod
+     def validate_max_greater_than_min(cls, v: timedelta | None, info) -> timedelta | None:
+         """Ensure max_duration > min_duration if both provided."""
+         min_duration = info.data.get("min_duration")
+         if v is not None and min_duration is not None and v <= min_duration:
+             raise ValueError("max_duration must be greater than min_duration")
+         return v
+
+
+ class ExtractionSettings(BaseConfig):
+     """Settings for extracting worst/best trades."""
+
+     n_worst: PositiveInt = Field(20, description="Number of worst trades to extract")
+     n_best: NonNegativeInt = Field(10, description="Number of best trades for comparison")
+     percentile_mode: bool = Field(False, description="Interpret n_worst/n_best as percentiles")
+     compute_statistics: bool = Field(True, description="Compute aggregate statistics")
+     group_by_symbol: bool = Field(False, description="Group analysis by symbol")
+     group_by_regime: bool = Field(False, description="Group analysis by regime")
+
+     @field_validator("n_worst")
+     @classmethod
+     def check_n_worst_reasonable(cls, v: int, info) -> int:
+         """Warn if n_worst is very small or very large."""
+         import warnings
+
+         percentile_mode = info.data.get("percentile_mode", False)
+         if percentile_mode:
+             if v > 50:
+                 warnings.warn(
+                     f"n_worst={v}% includes majority of trades. Consider 5-20%.",
+                     stacklevel=2,
+                 )
+         else:
+             if v < 10:
+                 warnings.warn(
+                     f"n_worst={v} may be too few. Consider 20-50 for better signal.",
+                     stacklevel=2,
+                 )
+             elif v > 100:
+                 warnings.warn(
+                     f"n_worst={v} may dilute signal. Consider 20-50 for clearer patterns.",
+                     stacklevel=2,
+                 )
+         return v
+
+     @field_validator("n_best", "n_worst")
+     @classmethod
+     def validate_percentile_range(cls, v: int, info) -> int:
+         """Ensure percentile values are in valid range."""
+         percentile_mode = info.data.get("percentile_mode", False)
+         if percentile_mode and (v < 1 or v > 100):
+             raise ValueError(f"Percentile must be 1-100, got {v}")
+         return v
+
+
+ class AlignmentSettings(BaseConfig):
+     """Settings for aligning SHAP values to trade timestamps."""
+
+     mode: Literal["entry", "nearest", "average"] = Field("entry", description="Alignment mode")
+     tolerance: PositiveInt = Field(
+         300, description="Max time difference for 'nearest' mode (seconds)"
+     )
+     missing_strategy: Literal["error", "skip", "zero"] = Field(
+         "skip", description="Missing value handling"
+     )
+     top_n_features: PositiveInt | None = Field(
+         None, description="Top N features per trade (None=all)"
+     )
+
+     @field_validator("tolerance")
+     @classmethod
+     def warn_large_tolerance(cls, v: int) -> int:
+         """Warn if tolerance is very large."""
+         if v > 3600:
+             import warnings
+
+             warnings.warn(
+                 f"tolerance={v}s (>{v // 3600}h) may misalign SHAP values.",
+                 stacklevel=2,
+             )
+         return v
+
+
+ class ClusteringSettings(BaseConfig):
+     """Settings for clustering error patterns in trades."""
+
+     method: ClusteringMethod = Field(
+         ClusteringMethod.HIERARCHICAL, description="Clustering algorithm"
+     )
+     linkage: LinkageMethod = Field(LinkageMethod.WARD, description="Linkage for hierarchical")
+     distance_metric: DistanceMetric = Field(DistanceMetric.EUCLIDEAN, description="Distance metric")
+     min_cluster_size: PositiveInt = Field(5, description="Minimum trades per cluster")
+     max_clusters: PositiveInt | None = Field(None, description="Max clusters (None=auto)")
+     normalization: Literal["l2", "l1", "standardize", None] = Field(
+         "l2", description="SHAP normalization"
+     )
+
+     @model_validator(mode="after")
+     def validate_ward_requires_euclidean(self) -> ClusteringSettings:
+         """Ensure Ward linkage uses Euclidean distance."""
+         if self.linkage == LinkageMethod.WARD and self.distance_metric != DistanceMetric.EUCLIDEAN:
+             raise ValueError(
+                 f"Ward linkage requires Euclidean distance, got {self.distance_metric}. "
+                 "Use linkage='average' or 'complete' for other metrics."
+             )
+         return self
+
+     @field_validator("min_cluster_size")
+     @classmethod
+     def warn_small_cluster_size(cls, v: int) -> int:
+         """Warn if min_cluster_size is very small."""
+         if v < 3:
+             import warnings
+
+             warnings.warn(f"min_cluster_size={v} may not be reliable. Use >= 5.", stacklevel=2)
+         return v
+
+
+ class HypothesisSettings(BaseConfig):
+     """Settings for automated hypothesis generation."""
+
+     enabled: bool = Field(True, description="Generate hypotheses automatically")
+     min_confidence: Probability = Field(0.6, description="Minimum confidence threshold")
+     max_per_cluster: PositiveInt = Field(5, description="Max hypotheses per cluster")
+     include_interactions: bool = Field(True, description="Look for feature × regime interactions")
+     template_library: Literal["comprehensive", "minimal", "custom"] = Field(
+         "comprehensive", description="Template set"
+     )
+
+
+ # =============================================================================
+ # Consolidated Config
+ # =============================================================================
+
+
+ class TradeConfig(BaseConfig):
+     """Consolidated configuration for trade analysis.
+
+     Combines trade extraction, filtering, SHAP alignment, error pattern
+     clustering, and hypothesis generation into a single configuration.
+
+     Examples
+     --------
+     >>> config = TradeConfig(
+     ...     extraction=ExtractionSettings(n_worst=50),
+     ...     clustering=ClusteringSettings(min_cluster_size=10),
+     ... )
+     >>> config.to_yaml("trade_config.yaml")
+     """
+
+     extraction: ExtractionSettings = Field(
+         default_factory=ExtractionSettings, description="Trade extraction settings"
+     )
+     filter: FilterSettings = Field(
+         default_factory=FilterSettings, description="Trade filtering settings"
+     )
+     alignment: AlignmentSettings = Field(
+         default_factory=AlignmentSettings, description="SHAP alignment settings"
+     )
+     clustering: ClusteringSettings = Field(
+         default_factory=ClusteringSettings, description="Clustering settings"
+     )
+     hypothesis: HypothesisSettings = Field(
+         default_factory=HypothesisSettings, description="Hypothesis generation"
+     )
+
+     min_trades_for_clustering: PositiveInt = Field(
+         20, description="Minimum trades required for clustering"
+     )
+     generate_visualizations: bool = Field(True, description="Generate SHAP waterfall plots")
+     cache_shap_vectors: bool = Field(True, description="Cache SHAP vectors for performance")
+
+     # Convenience properties
+     @property
+     def n_worst(self) -> int:
+         """Number of worst trades (shortcut)."""
+         return self.extraction.n_worst
+
+     @property
+     def n_best(self) -> int:
+         """Number of best trades (shortcut)."""
+         return self.extraction.n_best
+
+     @field_validator("min_trades_for_clustering")
+     @classmethod
+     def warn_low_min_trades(cls, v: int) -> int:
+         """Warn if min_trades is very low."""
+         if v < 10:
+             import warnings
+
+             warnings.warn(
+                 f"min_trades_for_clustering={v} may not identify reliable patterns. Use >= 20.",
+                 stacklevel=2,
+             )
+         return v
+
+     @classmethod
+     def for_quick_diagnostics(cls) -> TradeConfig:
+         """Preset for quick diagnostics (minimal clustering)."""
+         return cls(
+             extraction=ExtractionSettings(n_worst=20, n_best=10),
+             alignment=AlignmentSettings(top_n_features=10),
+             clustering=ClusteringSettings(min_cluster_size=3, max_clusters=5),
+             hypothesis=HypothesisSettings(template_library="minimal", max_per_cluster=3),
+             min_trades_for_clustering=10,
+             generate_visualizations=False,
+         )
+
+     @classmethod
+     def for_deep_analysis(cls) -> TradeConfig:
+         """Preset for comprehensive analysis."""
+         return cls(
+             extraction=ExtractionSettings(n_worst=50, n_best=20, compute_statistics=True),
+             alignment=AlignmentSettings(top_n_features=None, mode="average"),
+             clustering=ClusteringSettings(
+                 method=ClusteringMethod.HIERARCHICAL,
+                 linkage=LinkageMethod.WARD,
+                 min_cluster_size=10,
+                 max_clusters=None,
+                 normalization="l2",
+             ),
+             hypothesis=HypothesisSettings(
+                 min_confidence=0.6,
+                 max_per_cluster=10,
+                 include_interactions=True,
+                 template_library="comprehensive",
+             ),
+             min_trades_for_clustering=30,
+             generate_visualizations=True,
+         )
+
+     @classmethod
+     def for_production(cls) -> TradeConfig:
+         """Preset for production monitoring (efficient, focused)."""
+         return cls(
+             extraction=ExtractionSettings(n_worst=20, n_best=5, group_by_symbol=True),
+             alignment=AlignmentSettings(top_n_features=15),
+             clustering=ClusteringSettings(min_cluster_size=5, max_clusters=8),
+             hypothesis=HypothesisSettings(min_confidence=0.7, max_per_cluster=3),
+             min_trades_for_clustering=15,
+             generate_visualizations=False,
+             cache_shap_vectors=True,
+         )
+
+
+ # Rebuild models
+ FilterSettings.model_rebuild()
+ ExtractionSettings.model_rebuild()
+ AlignmentSettings.model_rebuild()
+ ClusteringSettings.model_rebuild()
+ HypothesisSettings.model_rebuild()
+ TradeConfig.model_rebuild()
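As with the signal config, a brief hedged sketch of how the trade-analysis presets above might be used; the import path is an assumption taken from the file list, and only constructors, classmethods, and properties visible in this diff are called.

# Hedged sketch; module path assumed from the file list above.
from datetime import timedelta

from ml4t.diagnostic.config.trade_analysis_config import FilterSettings, TradeConfig

# Start from the quick-diagnostics preset defined in this file...
quick = TradeConfig.for_quick_diagnostics()
print(quick.n_worst, quick.n_best)  # 20 10, via the convenience properties

# ...or build a custom config. Note that Ward linkage requires Euclidean distance
# (enforced by validate_ward_requires_euclidean in ClusteringSettings).
custom = TradeConfig(
    filter=FilterSettings(min_duration=timedelta(minutes=5), exclude_symbols=["TEST"]),
    min_trades_for_clustering=25,
)
custom.to_yaml("trade_config.yaml")  # serialization as shown in the TradeConfig docstring

The three presets (for_quick_diagnostics, for_deep_analysis, for_production) trade off cluster granularity and hypothesis volume against runtime, so a caller can pick one and override individual nested settings rather than configuring every field by hand.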