ml4t-diagnostic 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (242)
  1. ml4t/diagnostic/AGENT.md +25 -0
  2. ml4t/diagnostic/__init__.py +166 -0
  3. ml4t/diagnostic/backends/__init__.py +10 -0
  4. ml4t/diagnostic/backends/adapter.py +192 -0
  5. ml4t/diagnostic/backends/polars_backend.py +899 -0
  6. ml4t/diagnostic/caching/__init__.py +40 -0
  7. ml4t/diagnostic/caching/cache.py +331 -0
  8. ml4t/diagnostic/caching/decorators.py +131 -0
  9. ml4t/diagnostic/caching/smart_cache.py +339 -0
  10. ml4t/diagnostic/config/AGENT.md +24 -0
  11. ml4t/diagnostic/config/README.md +267 -0
  12. ml4t/diagnostic/config/__init__.py +219 -0
  13. ml4t/diagnostic/config/barrier_config.py +277 -0
  14. ml4t/diagnostic/config/base.py +301 -0
  15. ml4t/diagnostic/config/event_config.py +148 -0
  16. ml4t/diagnostic/config/feature_config.py +404 -0
  17. ml4t/diagnostic/config/multi_signal_config.py +55 -0
  18. ml4t/diagnostic/config/portfolio_config.py +215 -0
  19. ml4t/diagnostic/config/report_config.py +391 -0
  20. ml4t/diagnostic/config/sharpe_config.py +202 -0
  21. ml4t/diagnostic/config/signal_config.py +206 -0
  22. ml4t/diagnostic/config/trade_analysis_config.py +310 -0
  23. ml4t/diagnostic/config/validation.py +279 -0
  24. ml4t/diagnostic/core/__init__.py +29 -0
  25. ml4t/diagnostic/core/numba_utils.py +315 -0
  26. ml4t/diagnostic/core/purging.py +372 -0
  27. ml4t/diagnostic/core/sampling.py +471 -0
  28. ml4t/diagnostic/errors/__init__.py +205 -0
  29. ml4t/diagnostic/evaluation/AGENT.md +26 -0
  30. ml4t/diagnostic/evaluation/__init__.py +437 -0
  31. ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
  32. ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
  33. ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
  34. ml4t/diagnostic/evaluation/dashboard.py +715 -0
  35. ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
  36. ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
  37. ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
  38. ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
  39. ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
  40. ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
  41. ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
  42. ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
  43. ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
  44. ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
  45. ml4t/diagnostic/evaluation/event_analysis.py +647 -0
  46. ml4t/diagnostic/evaluation/excursion.py +390 -0
  47. ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
  48. ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
  49. ml4t/diagnostic/evaluation/framework.py +935 -0
  50. ml4t/diagnostic/evaluation/metric_registry.py +255 -0
  51. ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
  52. ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
  53. ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
  54. ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
  55. ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
  56. ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
  57. ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
  58. ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
  59. ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
  60. ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
  61. ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
  62. ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
  63. ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
  64. ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
  65. ml4t/diagnostic/evaluation/multi_signal.py +550 -0
  66. ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
  67. ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
  68. ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
  69. ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
  70. ml4t/diagnostic/evaluation/report_generation.py +824 -0
  71. ml4t/diagnostic/evaluation/signal_selector.py +452 -0
  72. ml4t/diagnostic/evaluation/stat_registry.py +139 -0
  73. ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
  74. ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
  75. ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
  76. ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
  77. ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
  78. ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
  79. ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
  80. ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
  81. ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
  82. ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
  83. ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
  84. ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
  85. ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
  86. ml4t/diagnostic/evaluation/stats/moments.py +164 -0
  87. ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
  88. ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
  89. ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
  90. ml4t/diagnostic/evaluation/themes.py +330 -0
  91. ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
  92. ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
  93. ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
  94. ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
  95. ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
  96. ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
  97. ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
  98. ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
  99. ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
  100. ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
  101. ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
  102. ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
  103. ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
  104. ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
  105. ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
  106. ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
  107. ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
  108. ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
  109. ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
  110. ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
  111. ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
  112. ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
  113. ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
  114. ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
  115. ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
  116. ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
  117. ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
  118. ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
  119. ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
  120. ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
  121. ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
  122. ml4t/diagnostic/evaluation/validated_cv.py +535 -0
  123. ml4t/diagnostic/evaluation/visualization.py +1050 -0
  124. ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
  125. ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
  126. ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
  127. ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
  128. ml4t/diagnostic/integration/__init__.py +48 -0
  129. ml4t/diagnostic/integration/backtest_contract.py +671 -0
  130. ml4t/diagnostic/integration/data_contract.py +316 -0
  131. ml4t/diagnostic/integration/engineer_contract.py +226 -0
  132. ml4t/diagnostic/logging/__init__.py +77 -0
  133. ml4t/diagnostic/logging/logger.py +245 -0
  134. ml4t/diagnostic/logging/performance.py +234 -0
  135. ml4t/diagnostic/logging/progress.py +234 -0
  136. ml4t/diagnostic/logging/wandb.py +412 -0
  137. ml4t/diagnostic/metrics/__init__.py +9 -0
  138. ml4t/diagnostic/metrics/percentiles.py +128 -0
  139. ml4t/diagnostic/py.typed +1 -0
  140. ml4t/diagnostic/reporting/__init__.py +43 -0
  141. ml4t/diagnostic/reporting/base.py +130 -0
  142. ml4t/diagnostic/reporting/html_renderer.py +275 -0
  143. ml4t/diagnostic/reporting/json_renderer.py +51 -0
  144. ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
  145. ml4t/diagnostic/results/AGENT.md +24 -0
  146. ml4t/diagnostic/results/__init__.py +105 -0
  147. ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
  148. ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
  149. ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
  150. ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
  151. ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
  152. ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
  153. ml4t/diagnostic/results/barrier_results/validation.py +38 -0
  154. ml4t/diagnostic/results/base.py +177 -0
  155. ml4t/diagnostic/results/event_results.py +349 -0
  156. ml4t/diagnostic/results/feature_results.py +787 -0
  157. ml4t/diagnostic/results/multi_signal_results.py +431 -0
  158. ml4t/diagnostic/results/portfolio_results.py +281 -0
  159. ml4t/diagnostic/results/sharpe_results.py +448 -0
  160. ml4t/diagnostic/results/signal_results/__init__.py +74 -0
  161. ml4t/diagnostic/results/signal_results/ic.py +581 -0
  162. ml4t/diagnostic/results/signal_results/irtc.py +110 -0
  163. ml4t/diagnostic/results/signal_results/quantile.py +392 -0
  164. ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
  165. ml4t/diagnostic/results/signal_results/turnover.py +213 -0
  166. ml4t/diagnostic/results/signal_results/validation.py +147 -0
  167. ml4t/diagnostic/signal/AGENT.md +17 -0
  168. ml4t/diagnostic/signal/__init__.py +69 -0
  169. ml4t/diagnostic/signal/_report.py +152 -0
  170. ml4t/diagnostic/signal/_utils.py +261 -0
  171. ml4t/diagnostic/signal/core.py +275 -0
  172. ml4t/diagnostic/signal/quantile.py +148 -0
  173. ml4t/diagnostic/signal/result.py +214 -0
  174. ml4t/diagnostic/signal/signal_ic.py +129 -0
  175. ml4t/diagnostic/signal/turnover.py +182 -0
  176. ml4t/diagnostic/splitters/AGENT.md +19 -0
  177. ml4t/diagnostic/splitters/__init__.py +36 -0
  178. ml4t/diagnostic/splitters/base.py +501 -0
  179. ml4t/diagnostic/splitters/calendar.py +421 -0
  180. ml4t/diagnostic/splitters/calendar_config.py +91 -0
  181. ml4t/diagnostic/splitters/combinatorial.py +1064 -0
  182. ml4t/diagnostic/splitters/config.py +322 -0
  183. ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
  184. ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
  185. ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
  186. ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
  187. ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
  188. ml4t/diagnostic/splitters/group_isolation.py +329 -0
  189. ml4t/diagnostic/splitters/persistence.py +316 -0
  190. ml4t/diagnostic/splitters/utils.py +207 -0
  191. ml4t/diagnostic/splitters/walk_forward.py +757 -0
  192. ml4t/diagnostic/utils/__init__.py +42 -0
  193. ml4t/diagnostic/utils/config.py +542 -0
  194. ml4t/diagnostic/utils/dependencies.py +318 -0
  195. ml4t/diagnostic/utils/sessions.py +127 -0
  196. ml4t/diagnostic/validation/__init__.py +54 -0
  197. ml4t/diagnostic/validation/dataframe.py +274 -0
  198. ml4t/diagnostic/validation/returns.py +280 -0
  199. ml4t/diagnostic/validation/timeseries.py +299 -0
  200. ml4t/diagnostic/visualization/AGENT.md +19 -0
  201. ml4t/diagnostic/visualization/__init__.py +223 -0
  202. ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
  203. ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
  204. ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
  205. ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
  206. ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
  207. ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
  208. ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
  209. ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
  210. ml4t/diagnostic/visualization/barrier_plots.py +782 -0
  211. ml4t/diagnostic/visualization/core.py +1060 -0
  212. ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
  213. ml4t/diagnostic/visualization/dashboards/base.py +582 -0
  214. ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
  215. ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
  216. ml4t/diagnostic/visualization/dashboards.py +43 -0
  217. ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
  218. ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
  219. ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
  220. ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
  221. ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
  222. ml4t/diagnostic/visualization/feature_plots.py +888 -0
  223. ml4t/diagnostic/visualization/interaction_plots.py +618 -0
  224. ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
  225. ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
  226. ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
  227. ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
  228. ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
  229. ml4t/diagnostic/visualization/report_generation.py +1343 -0
  230. ml4t/diagnostic/visualization/signal/__init__.py +103 -0
  231. ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
  232. ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
  233. ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
  234. ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
  235. ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
  236. ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
  237. ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
  238. ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
  239. ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
  240. ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
  241. ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
  242. ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,129 @@
1
+ """Dashboard data types and configuration.
2
+
3
+ Provides unified data structures for the dashboard to eliminate dict/object branching.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ from dataclasses import dataclass, field
9
+ from typing import Any
10
+
11
+ import numpy as np
12
+ import pandas as pd
13
+
14
+
15
@dataclass
class DashboardConfig:
    """Settings that control how the dashboard is built.

    Attributes
    ----------
    allow_pickle_upload : bool
        Permit uploading pickle files. Off by default for security reasons
        (unpickling untrusted data can execute arbitrary code).
    styled : bool
        Apply the professional CSS styling.
    title : str
        Dashboard title.
    """

    # Security: pickle uploads must be opted into explicitly.
    allow_pickle_upload: bool = field(default=False)
    styled: bool = field(default=False)
    title: str = field(default="Trade SHAP Diagnostics")
32
+
33
+
34
@dataclass
class DashboardBundle:
    """Single, normalized data container shared by every dashboard tab.

    Inputs arrive in several shapes (dicts or objects, with differing field
    names). They are converted into this one representation so that tabs do
    not need to branch on the input format.

    Attributes
    ----------
    trades_df : pd.DataFrame
        One row per trade, sorted chronologically by ``entry_time`` for the
        time-series tests. Stable columns:

        - trade_id: str
        - entry_time: datetime
        - exit_time: datetime (optional)
        - pnl: float
        - return_pct: float (optional)
        - symbol: str (optional)
    returns : np.ndarray | None
        Per-trade returns. Taken from ``return_pct`` when available,
        otherwise from ``pnl``.
    returns_label : str
        Source of ``returns``: "return_pct", "pnl", or "none".
    explanations : list[dict]
        Normalized explanation dictionaries with stable keys:

        - trade_id: str
        - shap_vector: list[float]
        - top_features: list[tuple[str, float]]
        - trade_metrics: dict (optional)
    patterns_df : pd.DataFrame
        One row per error pattern. Stable columns:

        - cluster_id: int
        - n_trades: int
        - description: str
        - top_features: list[tuple]
        - hypothesis: str (optional)
        - actions: list[str] (optional)
        - confidence: float (optional)
    n_trades_analyzed : int
        Total number of trades analyzed.
    n_trades_explained : int
        Number of trades that were explained successfully.
    n_trades_failed : int
        Number of trades whose explanation failed.
    failed_trades : list[tuple[str, str]]
        One ``(trade_id, reason)`` pair per failed explanation.
    config : DashboardConfig
        Dashboard configuration. A fresh default ``DashboardConfig()`` is
        used when omitted.
    """

    # Trade-level data
    trades_df: pd.DataFrame
    returns: np.ndarray | None
    returns_label: str  # one of "return_pct", "pnl", "none"
    # SHAP explanations and discovered error patterns
    explanations: list[dict[str, Any]]
    patterns_df: pd.DataFrame
    # Bookkeeping counters
    n_trades_analyzed: int
    n_trades_explained: int
    n_trades_failed: int
    failed_trades: list[tuple[str, str]]
    # default_factory gives each bundle its own config instance
    config: DashboardConfig = field(default_factory=DashboardConfig)
93
+
94
+
95
@dataclass
class ReturnSummary:
    """Descriptive statistics of a return series.

    Attributes
    ----------
    n_samples : int
        Number of samples in the series.
    mean : float
        Average return.
    std : float
        Standard deviation of returns.
    sharpe : float
        Sharpe ratio, defined as mean / std.
    skewness : float
        Skewness of the distribution.
    kurtosis : float
        Raw (non-excess) kurtosis; a normal distribution gives 3.0.
    min_val : float
        Smallest value.
    max_val : float
        Largest value.
    win_rate : float
        Fraction of returns that are positive.
    """

    # Sample size and first two moments
    n_samples: int
    mean: float
    std: float
    # Risk-adjusted return
    sharpe: float
    # Distribution shape
    skewness: float
    kurtosis: float
    # Range
    min_val: float
    max_val: float
    # Share of positive returns
    win_rate: float
@@ -0,0 +1,102 @@
1
+ """Trade-level SHAP diagnostics models.
2
+
3
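+ This package bundles the Trade SHAP building blocks: timestamp alignment,
+ SHAP explanation, normalization, clustering, pattern characterization,
+ hypothesis generation, the end-to-end pipeline, and the result models.
+ The high-level TradeShapAnalyzer is exposed by the parent evaluation module.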
6
+
7
+ For analysis, import from the evaluation module:
8
+ >>> from ml4t.diagnostic.evaluation import TradeShapAnalyzer
9
+
10
+ For models only:
11
+ >>> from ml4t.diagnostic.evaluation.trade_shap import (
12
+ ... TradeShapResult,
13
+ ... ErrorPattern,
14
+ ... ClusteringResult,
15
+ ... TradeShapExplanation,
16
+ ... )
17
+ """
18
+
19
+ # Import models from the dedicated models module
20
+ from ml4t.diagnostic.evaluation.trade_shap.alignment import (
21
+ AlignmentResult,
22
+ TimestampAligner,
23
+ )
24
+ from ml4t.diagnostic.evaluation.trade_shap.characterize import (
25
+ CharacterizationConfig,
26
+ FeatureStatistics,
27
+ PatternCharacterizer,
28
+ benjamini_hochberg,
29
+ )
30
+ from ml4t.diagnostic.evaluation.trade_shap.cluster import (
31
+ ClusteringConfig,
32
+ HierarchicalClusterer,
33
+ compute_centroids,
34
+ compute_cluster_sizes,
35
+ find_optimal_clusters,
36
+ )
37
+ from ml4t.diagnostic.evaluation.trade_shap.explain import TradeShapExplainer
38
+ from ml4t.diagnostic.evaluation.trade_shap.hypotheses import (
39
+ HypothesisConfig,
40
+ HypothesisGenerator,
41
+ Template,
42
+ TemplateMatcher,
43
+ load_templates,
44
+ )
45
+ from ml4t.diagnostic.evaluation.trade_shap.models import (
46
+ ClusteringResult,
47
+ ErrorPattern,
48
+ TradeExplainFailure,
49
+ TradeShapExplanation,
50
+ TradeShapResult,
51
+ )
52
+ from ml4t.diagnostic.evaluation.trade_shap.normalize import (
53
+ NormalizationType,
54
+ normalize,
55
+ normalize_l1,
56
+ normalize_l2,
57
+ standardize,
58
+ )
59
+ from ml4t.diagnostic.evaluation.trade_shap.pipeline import (
60
+ TradeShapPipeline,
61
+ TradeShapPipelineConfig,
62
+ )
63
+
64
# Public API of the trade_shap package. This also controls what
# `from ml4t.diagnostic.evaluation.trade_shap import *` exports.
# Every name listed here is imported at the top of this module; keep the
# imports and this list in sync when adding or removing exports.
__all__ = [
    # Alignment
    "TimestampAligner",
    "AlignmentResult",
    # Explainer
    "TradeShapExplainer",
    # Normalization
    "normalize",
    "normalize_l1",
    "normalize_l2",
    "standardize",
    "NormalizationType",
    # Clustering
    "HierarchicalClusterer",
    "ClusteringConfig",
    "find_optimal_clusters",
    "compute_cluster_sizes",
    "compute_centroids",
    # Characterization
    "PatternCharacterizer",
    "CharacterizationConfig",
    "FeatureStatistics",
    "benjamini_hochberg",
    # Hypothesis generation
    "HypothesisGenerator",
    "HypothesisConfig",
    "TemplateMatcher",
    "Template",
    "load_templates",
    # Pipeline
    "TradeShapPipeline",
    "TradeShapPipelineConfig",
    # Result models
    "TradeShapResult",
    "TradeShapExplanation",
    "TradeExplainFailure",
    "ClusteringResult",
    "ErrorPattern",
]
@@ -0,0 +1,188 @@
1
+ """Fast timestamp alignment for trade SHAP analysis.
2
+
3
+ This module provides O(log n) timestamp lookup instead of O(n) linear scan,
4
+ using precomputed indices and binary search for nearest-match scenarios.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from dataclasses import dataclass, field
10
+ from datetime import datetime
11
+ from typing import TYPE_CHECKING
12
+
13
+ import numpy as np
14
+
15
+ if TYPE_CHECKING:
16
+ from numpy.typing import NDArray
17
+
18
+
19
@dataclass(frozen=True)
class AlignmentResult:
    """Result of timestamp alignment.

    Attributes:
        index: Positional index into the timestamp sequence the aligner was
            built from (e.g. the rows of a feature DataFrame), or None if no
            timestamp matched.
        exact: Whether the matched timestamp equals the target exactly.
        distance_seconds: Distance in seconds from the target (0.0 if exact).
            When ``index`` is None, this is either ``inf`` (no nearest-match
            search was performed) or the distance to the nearest timestamp,
            which exceeded the tolerance.
    """

    index: int | None
    exact: bool
    distance_seconds: float


@dataclass
class TimestampAligner:
    """Fast timestamp alignment using precomputed indices.

    Provides O(1) exact-match lookup via a dict and O(log n) nearest-match
    lookup via binary search on a sorted int64-nanosecond array.

    Duplicate timestamps resolve to their FIRST occurrence in the original
    sequence for both exact and nearest matches.

    Attributes:
        timestamps_ns: Sorted numpy array of timestamps as int64 nanoseconds.
        index_by_ts: Dict mapping datetime to original index for O(1) exact lookup.
        tolerance_seconds: Maximum allowed distance for a nearest match
            (0 means exact matches only).
        _sorted_indices: Original indices corresponding to ``timestamps_ns``;
            ``timestamps_ns[k]`` came from original position ``_sorted_indices[k]``.

    Example:
        >>> import pandas as pd
        >>> timestamps = pd.DatetimeIndex(['2024-01-01', '2024-01-02', '2024-01-03'])
        >>> aligner = TimestampAligner.from_datetime_index(timestamps, tolerance_seconds=3600)
        >>> result = aligner.align(datetime(2024, 1, 2))
        >>> result.index
        1
        >>> result.exact
        True
    """

    timestamps_ns: NDArray[np.int64]
    index_by_ts: dict[datetime, int] = field(default_factory=dict)
    tolerance_seconds: float = 0.0
    _sorted_indices: NDArray[np.intp] = field(default_factory=lambda: np.array([], dtype=np.intp))

    @classmethod
    def from_datetime_index(
        cls,
        timestamps: NDArray | list[datetime],
        tolerance_seconds: float = 0.0,
    ) -> TimestampAligner:
        """Create aligner from datetime index or array.

        Args:
            timestamps: DatetimeIndex, numpy datetime64 array, or list of datetimes
            tolerance_seconds: Maximum allowed distance for nearest match (default: 0 = exact only)

        Returns:
            TimestampAligner ready for fast lookups

        Raises:
            ValueError: If timestamps array is empty
        """
        # Convert to numpy datetime64[ns] if needed
        ts_array = np.asarray(timestamps, dtype="datetime64[ns]")

        if len(ts_array) == 0:
            raise ValueError("Cannot create aligner from empty timestamp array")

        # Convert to int64 nanoseconds for fast comparison
        ts_ns = ts_array.astype(np.int64)

        # A STABLE sort keeps equal timestamps in their original order. The first
        # element of each run of duplicates in the sorted array therefore maps to
        # the earliest original position, which is the same occurrence the
        # exact-match dict below keeps.
        sorted_indices = np.argsort(ts_ns, kind="stable")
        sorted_ts_ns = ts_ns[sorted_indices]

        # Build exact-match dict using original timestamps
        # For duplicates, keep FIRST occurrence (standard behavior)
        index_by_ts: dict[datetime, int] = {}
        for i, ts in enumerate(timestamps):
            if hasattr(ts, "to_pydatetime"):
                # pandas Timestamp
                dt = ts.to_pydatetime()
            elif isinstance(ts, np.datetime64):
                # numpy datetime64: go through microseconds, since datetime64[ns]
                # cannot be converted to a Python datetime directly
                dt = ts.astype("datetime64[us]").astype(datetime)
            else:
                dt = ts
            # Only store first occurrence of each timestamp
            if dt not in index_by_ts:
                index_by_ts[dt] = i

        return cls(
            timestamps_ns=sorted_ts_ns,
            index_by_ts=index_by_ts,
            tolerance_seconds=tolerance_seconds,
            _sorted_indices=sorted_indices,
        )

    def align(self, target: datetime) -> AlignmentResult:
        """Find index for target timestamp.

        First attempts an exact match via dict lookup (O(1)). If there is no
        exact match and tolerance > 0, uses binary search for the nearest
        timestamp (O(log n)). Ties between an earlier and a later neighbour
        at equal distance resolve to the earlier one. Duplicates resolve to
        their first original occurrence.

        Args:
            target: Target timestamp to align

        Returns:
            AlignmentResult with index (or None), exact flag, and distance.
            A nearest-match hit at zero distance is reported with ``exact=True``.
        """
        # Try exact match first (O(1))
        if target in self.index_by_ts:
            return AlignmentResult(index=self.index_by_ts[target], exact=True, distance_seconds=0.0)

        # No exact match - if no tolerance, return None
        if self.tolerance_seconds <= 0:
            return AlignmentResult(index=None, exact=False, distance_seconds=float("inf"))

        sorted_ns = self.timestamps_ns
        # Python int avoids int64 overflow in the distance arithmetic below
        target_ns = int(np.datetime64(target, "ns").astype(np.int64))
        insert_pos = int(np.searchsorted(sorted_ns, target_ns, side="left"))

        # Check neighbors (left candidate first so ties prefer the earlier timestamp)
        candidates: list[int] = []
        if insert_pos > 0:
            # The left neighbour may be the LAST element of a run of duplicates.
            # Step back to the run's first element so the mapped original index
            # is the first occurrence, consistent with index_by_ts.
            left_value = sorted_ns[insert_pos - 1]
            candidates.append(int(np.searchsorted(sorted_ns, left_value, side="left")))
        if insert_pos < len(sorted_ns):
            # side="left" already lands on the first element of a duplicate run
            candidates.append(insert_pos)

        if not candidates:
            return AlignmentResult(index=None, exact=False, distance_seconds=float("inf"))

        # Find closest; strict "<" keeps the earlier candidate on ties
        best_idx = candidates[0]
        best_distance_ns = abs(int(sorted_ns[best_idx]) - target_ns)
        for sorted_idx in candidates[1:]:
            distance_ns = abs(int(sorted_ns[sorted_idx]) - target_ns)
            if distance_ns < best_distance_ns:
                best_distance_ns = distance_ns
                best_idx = sorted_idx

        # Convert to seconds and check tolerance
        distance_seconds = best_distance_ns / 1e9

        if distance_seconds <= self.tolerance_seconds:
            # Map back to original index
            original_idx = int(self._sorted_indices[best_idx])
            return AlignmentResult(
                index=original_idx,
                # Zero distance means the timestamps are identical even though the
                # dict lookup missed (e.g. sub-microsecond precision, differing types)
                exact=best_distance_ns == 0,
                distance_seconds=distance_seconds,
            )

        return AlignmentResult(index=None, exact=False, distance_seconds=distance_seconds)

    def align_many(self, targets: list[datetime]) -> list[AlignmentResult]:
        """Align multiple timestamps.

        Args:
            targets: List of target timestamps

        Returns:
            List of AlignmentResult for each target, in input order
        """
        return [self.align(t) for t in targets]

    def __len__(self) -> int:
        """Number of timestamps in the aligner."""
        return len(self.timestamps_ns)