ml4t-diagnostic 0.1.0a1 (ml4t_diagnostic-0.1.0a1-py3-none-any.whl)

This diff shows the content of a publicly available package version released to a supported registry. It is provided for informational purposes only and reflects the package contents as they appear in the public registry.
Files changed (242)
  1. ml4t/diagnostic/AGENT.md +25 -0
  2. ml4t/diagnostic/__init__.py +166 -0
  3. ml4t/diagnostic/backends/__init__.py +10 -0
  4. ml4t/diagnostic/backends/adapter.py +192 -0
  5. ml4t/diagnostic/backends/polars_backend.py +899 -0
  6. ml4t/diagnostic/caching/__init__.py +40 -0
  7. ml4t/diagnostic/caching/cache.py +331 -0
  8. ml4t/diagnostic/caching/decorators.py +131 -0
  9. ml4t/diagnostic/caching/smart_cache.py +339 -0
  10. ml4t/diagnostic/config/AGENT.md +24 -0
  11. ml4t/diagnostic/config/README.md +267 -0
  12. ml4t/diagnostic/config/__init__.py +219 -0
  13. ml4t/diagnostic/config/barrier_config.py +277 -0
  14. ml4t/diagnostic/config/base.py +301 -0
  15. ml4t/diagnostic/config/event_config.py +148 -0
  16. ml4t/diagnostic/config/feature_config.py +404 -0
  17. ml4t/diagnostic/config/multi_signal_config.py +55 -0
  18. ml4t/diagnostic/config/portfolio_config.py +215 -0
  19. ml4t/diagnostic/config/report_config.py +391 -0
  20. ml4t/diagnostic/config/sharpe_config.py +202 -0
  21. ml4t/diagnostic/config/signal_config.py +206 -0
  22. ml4t/diagnostic/config/trade_analysis_config.py +310 -0
  23. ml4t/diagnostic/config/validation.py +279 -0
  24. ml4t/diagnostic/core/__init__.py +29 -0
  25. ml4t/diagnostic/core/numba_utils.py +315 -0
  26. ml4t/diagnostic/core/purging.py +372 -0
  27. ml4t/diagnostic/core/sampling.py +471 -0
  28. ml4t/diagnostic/errors/__init__.py +205 -0
  29. ml4t/diagnostic/evaluation/AGENT.md +26 -0
  30. ml4t/diagnostic/evaluation/__init__.py +437 -0
  31. ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
  32. ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
  33. ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
  34. ml4t/diagnostic/evaluation/dashboard.py +715 -0
  35. ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
  36. ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
  37. ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
  38. ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
  39. ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
  40. ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
  41. ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
  42. ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
  43. ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
  44. ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
  45. ml4t/diagnostic/evaluation/event_analysis.py +647 -0
  46. ml4t/diagnostic/evaluation/excursion.py +390 -0
  47. ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
  48. ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
  49. ml4t/diagnostic/evaluation/framework.py +935 -0
  50. ml4t/diagnostic/evaluation/metric_registry.py +255 -0
  51. ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
  52. ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
  53. ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
  54. ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
  55. ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
  56. ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
  57. ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
  58. ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
  59. ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
  60. ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
  61. ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
  62. ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
  63. ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
  64. ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
  65. ml4t/diagnostic/evaluation/multi_signal.py +550 -0
  66. ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
  67. ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
  68. ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
  69. ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
  70. ml4t/diagnostic/evaluation/report_generation.py +824 -0
  71. ml4t/diagnostic/evaluation/signal_selector.py +452 -0
  72. ml4t/diagnostic/evaluation/stat_registry.py +139 -0
  73. ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
  74. ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
  75. ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
  76. ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
  77. ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
  78. ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
  79. ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
  80. ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
  81. ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
  82. ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
  83. ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
  84. ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
  85. ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
  86. ml4t/diagnostic/evaluation/stats/moments.py +164 -0
  87. ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
  88. ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
  89. ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
  90. ml4t/diagnostic/evaluation/themes.py +330 -0
  91. ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
  92. ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
  93. ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
  94. ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
  95. ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
  96. ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
  97. ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
  98. ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
  99. ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
  100. ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
  101. ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
  102. ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
  103. ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
  104. ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
  105. ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
  106. ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
  107. ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
  108. ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
  109. ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
  110. ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
  111. ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
  112. ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
  113. ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
  114. ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
  115. ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
  116. ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
  117. ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
  118. ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
  119. ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
  120. ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
  121. ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
  122. ml4t/diagnostic/evaluation/validated_cv.py +535 -0
  123. ml4t/diagnostic/evaluation/visualization.py +1050 -0
  124. ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
  125. ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
  126. ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
  127. ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
  128. ml4t/diagnostic/integration/__init__.py +48 -0
  129. ml4t/diagnostic/integration/backtest_contract.py +671 -0
  130. ml4t/diagnostic/integration/data_contract.py +316 -0
  131. ml4t/diagnostic/integration/engineer_contract.py +226 -0
  132. ml4t/diagnostic/logging/__init__.py +77 -0
  133. ml4t/diagnostic/logging/logger.py +245 -0
  134. ml4t/diagnostic/logging/performance.py +234 -0
  135. ml4t/diagnostic/logging/progress.py +234 -0
  136. ml4t/diagnostic/logging/wandb.py +412 -0
  137. ml4t/diagnostic/metrics/__init__.py +9 -0
  138. ml4t/diagnostic/metrics/percentiles.py +128 -0
  139. ml4t/diagnostic/py.typed +1 -0
  140. ml4t/diagnostic/reporting/__init__.py +43 -0
  141. ml4t/diagnostic/reporting/base.py +130 -0
  142. ml4t/diagnostic/reporting/html_renderer.py +275 -0
  143. ml4t/diagnostic/reporting/json_renderer.py +51 -0
  144. ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
  145. ml4t/diagnostic/results/AGENT.md +24 -0
  146. ml4t/diagnostic/results/__init__.py +105 -0
  147. ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
  148. ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
  149. ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
  150. ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
  151. ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
  152. ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
  153. ml4t/diagnostic/results/barrier_results/validation.py +38 -0
  154. ml4t/diagnostic/results/base.py +177 -0
  155. ml4t/diagnostic/results/event_results.py +349 -0
  156. ml4t/diagnostic/results/feature_results.py +787 -0
  157. ml4t/diagnostic/results/multi_signal_results.py +431 -0
  158. ml4t/diagnostic/results/portfolio_results.py +281 -0
  159. ml4t/diagnostic/results/sharpe_results.py +448 -0
  160. ml4t/diagnostic/results/signal_results/__init__.py +74 -0
  161. ml4t/diagnostic/results/signal_results/ic.py +581 -0
  162. ml4t/diagnostic/results/signal_results/irtc.py +110 -0
  163. ml4t/diagnostic/results/signal_results/quantile.py +392 -0
  164. ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
  165. ml4t/diagnostic/results/signal_results/turnover.py +213 -0
  166. ml4t/diagnostic/results/signal_results/validation.py +147 -0
  167. ml4t/diagnostic/signal/AGENT.md +17 -0
  168. ml4t/diagnostic/signal/__init__.py +69 -0
  169. ml4t/diagnostic/signal/_report.py +152 -0
  170. ml4t/diagnostic/signal/_utils.py +261 -0
  171. ml4t/diagnostic/signal/core.py +275 -0
  172. ml4t/diagnostic/signal/quantile.py +148 -0
  173. ml4t/diagnostic/signal/result.py +214 -0
  174. ml4t/diagnostic/signal/signal_ic.py +129 -0
  175. ml4t/diagnostic/signal/turnover.py +182 -0
  176. ml4t/diagnostic/splitters/AGENT.md +19 -0
  177. ml4t/diagnostic/splitters/__init__.py +36 -0
  178. ml4t/diagnostic/splitters/base.py +501 -0
  179. ml4t/diagnostic/splitters/calendar.py +421 -0
  180. ml4t/diagnostic/splitters/calendar_config.py +91 -0
  181. ml4t/diagnostic/splitters/combinatorial.py +1064 -0
  182. ml4t/diagnostic/splitters/config.py +322 -0
  183. ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
  184. ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
  185. ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
  186. ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
  187. ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
  188. ml4t/diagnostic/splitters/group_isolation.py +329 -0
  189. ml4t/diagnostic/splitters/persistence.py +316 -0
  190. ml4t/diagnostic/splitters/utils.py +207 -0
  191. ml4t/diagnostic/splitters/walk_forward.py +757 -0
  192. ml4t/diagnostic/utils/__init__.py +42 -0
  193. ml4t/diagnostic/utils/config.py +542 -0
  194. ml4t/diagnostic/utils/dependencies.py +318 -0
  195. ml4t/diagnostic/utils/sessions.py +127 -0
  196. ml4t/diagnostic/validation/__init__.py +54 -0
  197. ml4t/diagnostic/validation/dataframe.py +274 -0
  198. ml4t/diagnostic/validation/returns.py +280 -0
  199. ml4t/diagnostic/validation/timeseries.py +299 -0
  200. ml4t/diagnostic/visualization/AGENT.md +19 -0
  201. ml4t/diagnostic/visualization/__init__.py +223 -0
  202. ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
  203. ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
  204. ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
  205. ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
  206. ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
  207. ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
  208. ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
  209. ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
  210. ml4t/diagnostic/visualization/barrier_plots.py +782 -0
  211. ml4t/diagnostic/visualization/core.py +1060 -0
  212. ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
  213. ml4t/diagnostic/visualization/dashboards/base.py +582 -0
  214. ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
  215. ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
  216. ml4t/diagnostic/visualization/dashboards.py +43 -0
  217. ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
  218. ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
  219. ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
  220. ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
  221. ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
  222. ml4t/diagnostic/visualization/feature_plots.py +888 -0
  223. ml4t/diagnostic/visualization/interaction_plots.py +618 -0
  224. ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
  225. ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
  226. ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
  227. ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
  228. ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
  229. ml4t/diagnostic/visualization/report_generation.py +1343 -0
  230. ml4t/diagnostic/visualization/signal/__init__.py +103 -0
  231. ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
  232. ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
  233. ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
  234. ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
  235. ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
  236. ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
  237. ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
  238. ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
  239. ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
  240. ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
  241. ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
  242. ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py
@@ -0,0 +1,280 @@
+ """SHAP Analysis tab.
+
+ Displays individual trade SHAP explanations and global feature importance.
+ """
+
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING, Any, cast
+
+ import numpy as np
+ import pandas as pd
+
+ if TYPE_CHECKING:
+     from ml4t.diagnostic.evaluation.trade_dashboard.types import DashboardBundle
+
+
+ def render_tab(st: Any, bundle: DashboardBundle) -> None:
+     """Render the SHAP Analysis tab.
+
+     Parameters
+     ----------
+     st : streamlit
+         Streamlit module instance.
+     bundle : DashboardBundle
+         Normalized dashboard data.
+     """
+     st.header("SHAP Analysis")
+
+     st.info(
+         "Explore SHAP (SHapley Additive exPlanations) values for individual trades "
+         "to understand which features drove model predictions."
+     )
+
+     explanations = bundle.explanations
+
+     if not explanations:
+         st.warning("No trade explanations available.")
+         return
+
+     # Check for trade selected from worst trades tab
+     selected_from_tab2 = st.session_state.get("selected_trade_for_shap")
+     selected_trade_idx = 0
+
+     if selected_from_tab2:
+         for i, exp in enumerate(explanations):
+             if exp.get("trade_id") == selected_from_tab2:
+                 selected_trade_idx = i
+                 break
+
+     # Trade selector
+     st.subheader("Trade Selection")
+
+     if selected_from_tab2:
+         st.success(f"Currently viewing: **{selected_from_tab2}** (selected in Worst Trades tab)")
+
+     trade_options = [exp.get("trade_id", f"Trade_{i}") for i, exp in enumerate(explanations)]
+
+     selected_trade_idx = st.selectbox(
+         "Select trade to view SHAP explanation:",
+         range(len(trade_options)),
+         index=selected_trade_idx,
+         format_func=lambda x: trade_options[x],
+     )
+
+     if selected_trade_idx is not None:
+         _render_trade_shap(st, explanations[selected_trade_idx])
+
+     # Global feature importance
+     st.divider()
+     _render_global_importance(st, explanations)
+
+
+ def _render_trade_shap(st: Any, explanation: dict[str, Any]) -> None:
+     """Render SHAP explanation for a single trade."""
+     trade_id = explanation.get("trade_id", "Unknown")
+     timestamp = explanation.get("timestamp")
+     top_features = explanation.get("top_features", [])
+
+     st.divider()
+     st.subheader(f"Trade: {trade_id}")
+     if timestamp:
+         st.caption(f"Timestamp: {timestamp}")
+
+     # Note: Renamed from "Waterfall" - this is actually a bar chart
+     st.subheader("Top SHAP Contributions")
+
+     if not top_features:
+         st.warning("No SHAP features available for this trade.")
+         return
+
+     # Prepare data for visualization
+     features_data = []
+     cumulative = 0.0
+
+     for item in top_features[:15]:
+         if len(item) >= 2:
+             feature, shap_val = item[0], item[1]
+             cumulative += shap_val
+             features_data.append(
+                 {
+                     "Feature": feature,
+                     "SHAP Value": shap_val,
+                     "Cumulative": cumulative,
+                     "Impact": "Positive" if shap_val > 0 else "Negative",
+                 }
+             )
+
+     if not features_data:
+         st.warning("Could not parse SHAP features.")
+         return
+
+     df_shap = pd.DataFrame(features_data)
+
+     # Create bar chart
+     import plotly.graph_objects as go
+
+     colors = ["#FF6B6B" if val < 0 else "#51CF66" for val in df_shap["SHAP Value"]]
+
+     fig = go.Figure()
+
+     fig.add_trace(
+         go.Bar(
+             x=df_shap["SHAP Value"],
+             y=df_shap["Feature"],
+             orientation="h",
+             marker={"color": colors},
+             text=[f"{val:.4f}" for val in df_shap["SHAP Value"]],
+             textposition="auto",
+             hovertemplate="<b>%{y}</b><br>SHAP: %{x:.4f}<extra></extra>",
+         )
+     )
+
+     fig.update_layout(
+         title="SHAP Feature Contributions (Top 15 Features)",
+         xaxis_title="SHAP Value",
+         yaxis_title="Feature",
+         height=max(400, len(df_shap) * 30),
+         yaxis={"autorange": "reversed"},
+         showlegend=False,
+     )
+
+     st.plotly_chart(fig, use_container_width=True)
+
+     # Feature values table
+     st.subheader("Feature Values")
+
+     display_df = df_shap[["Feature", "SHAP Value", "Impact"]].copy()
+     display_df["SHAP Value"] = display_df["SHAP Value"].apply(lambda x: f"{x:.4f}")
+
+     st.dataframe(
+         display_df,
+         hide_index=True,
+         use_container_width=True,
+         column_config={
+             "Feature": st.column_config.TextColumn("Feature Name", width="medium"),
+             "SHAP Value": st.column_config.TextColumn("SHAP Value", width="small"),
+             "Impact": st.column_config.TextColumn("Impact", width="small"),
+         },
+     )
+
+     # Interpretation guide
+     with st.expander("How to Interpret SHAP Values"):
+         st.markdown(
+             """
+             **SHAP Value Interpretation:**
+
+             - **Positive SHAP value (green)**: Feature pushed prediction higher
+             - **Negative SHAP value (red)**: Feature pushed prediction lower
+             - **Magnitude**: Larger absolute values indicate stronger influence
+
+             **For a losing trade:**
+             - Large positive values contributed to an incorrect bullish prediction
+             - Large negative values contributed to an incorrect bearish prediction
+
+             **Actionable insights:**
+             - Identify which features consistently mislead the model
+             - Look for patterns across multiple losing trades (see Patterns tab)
+             """
+         )
+
+     # Summary statistics
+     st.divider()
+     st.subheader("SHAP Summary Statistics")
+
+     shap_values = [item[1] for item in top_features if len(item) >= 2]
+
+     col1, col2, col3, col4 = st.columns(4)
+
+     with col1:
+         total_shap = sum(shap_values)
+         st.metric("Total SHAP", f"{total_shap:.4f}")
+
+     with col2:
+         positive_shap = sum(v for v in shap_values if v > 0)
+         st.metric("Positive Contrib.", f"{positive_shap:.4f}")
+
+     with col3:
+         negative_shap = sum(v for v in shap_values if v < 0)
+         st.metric("Negative Contrib.", f"{negative_shap:.4f}")
+
+     with col4:
+         mean_abs_shap = float(np.mean([abs(v) for v in shap_values])) if shap_values else 0.0
+         st.metric("Mean Abs. SHAP", f"{mean_abs_shap:.4f}")
+
+
+ def _render_global_importance(st: Any, explanations: list[dict[str, Any]]) -> None:
+     """Render global feature importance across all trades."""
+     st.subheader("Global Feature Importance")
+
+     st.markdown(
+         "Aggregate SHAP importance across all analyzed trades to identify "
+         "which features are most influential overall."
+     )
+
+     # Calculate global importance
+     all_features: dict[str, list[float]] = {}
+
+     for exp in explanations:
+         top_features = exp.get("top_features", [])
+
+         for item in top_features:
+             if len(item) >= 2:
+                 feature, shap_val = item[0], item[1]
+                 if feature not in all_features:
+                     all_features[feature] = []
+                 all_features[feature].append(abs(shap_val))
+
+     if not all_features:
+         st.warning("No feature importance data available.")
+         return
+
+     # Calculate mean absolute SHAP for each feature
+     feature_importance = [
+         {
+             "Feature": feature,
+             "Mean Abs SHAP": float(np.mean(values)),
+             "Frequency": len(values),
+             "Total Impact": sum(values),
+         }
+         for feature, values in all_features.items()
+     ]
+
+     # Sort by mean absolute SHAP
+     feature_importance.sort(key=lambda x: cast(float, x["Mean Abs SHAP"]), reverse=True)
+
+     # Display top 20
+     df_importance = pd.DataFrame(feature_importance[:20])
+
+     # Create bar chart
+     import plotly.express as px
+
+     fig = px.bar(
+         df_importance,
+         x="Mean Abs SHAP",
+         y="Feature",
+         orientation="h",
+         title="Top 20 Most Important Features (Mean Absolute SHAP)",
+         color="Mean Abs SHAP",
+         color_continuous_scale="Blues",
+     )
+
+     fig.update_layout(
+         yaxis={"autorange": "reversed"},
+         height=600,
+     )
+
+     st.plotly_chart(fig, use_container_width=True)
+
+     # Display table
+     st.subheader("Feature Importance Table")
+
+     display_importance = df_importance.copy()
+     display_importance["Mean Abs SHAP"] = display_importance["Mean Abs SHAP"].apply(
+         lambda x: f"{x:.4f}"
+     )
+     display_importance["Total Impact"] = display_importance["Total Impact"].apply(
+         lambda x: f"{x:.4f}"
+     )
+
+     st.dataframe(display_importance, hide_index=True, use_container_width=True)
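For context, a minimal usage sketch (not part of the wheel) of how a tab module like this can be wired into a Streamlit script. The real DashboardBundle is defined in ml4t/diagnostic/evaluation/trade_dashboard/types.py and is presumably constructed by the dashboard's io/normalize helpers; the stand-in namespace below only mimics the explanations field that render_tab reads, and the feature names and values are purely illustrative.

from types import SimpleNamespace

import streamlit as st

from ml4t.diagnostic.evaluation.trade_dashboard.tabs import shap_analysis

# Stand-in for DashboardBundle: this tab only reads .explanations, a list of
# dicts with "trade_id", "timestamp", and (feature, shap_value) pairs.
bundle = SimpleNamespace(
    explanations=[
        {
            "trade_id": "T-001",
            "timestamp": "2024-03-01 14:30",
            "top_features": [("rsi_14", -0.042), ("atr_20", 0.017), ("volume_z", 0.008)],
        },
    ]
)

shap_analysis.render_tab(st, bundle)  # run with: streamlit run example_app.py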
ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py
@@ -0,0 +1,186 @@
+ """Statistical Validation tab.
+
+ Displays PSR (Probabilistic Sharpe Ratio), distribution tests, and time-series tests.
+ Uses PSR instead of DSR because this dashboard analyzes a single strategy.
+ """
+
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING, Any
+
+ if TYPE_CHECKING:
+     from ml4t.diagnostic.evaluation.trade_dashboard.types import DashboardBundle
+
+
+ def render_tab(st: Any, bundle: DashboardBundle) -> None:
+     """Render the Statistical Validation tab.
+
+     Parameters
+     ----------
+     st : streamlit
+         Streamlit module instance.
+     bundle : DashboardBundle
+         Normalized dashboard data.
+     """
+     from ml4t.diagnostic.evaluation.trade_dashboard.stats import (
+         compute_distribution_tests,
+         compute_return_summary,
+         compute_time_series_tests,
+         probabilistic_sharpe_ratio,
+     )
+
+     st.header("Statistical Validation")
+
+     st.info(
+         "Statistical validation ensures that identified patterns are "
+         "statistically significant and not due to random chance."
+     )
+
+     # Check if we have returns data
+     if bundle.returns is None or len(bundle.returns) == 0:
+         st.warning(
+             "No trade returns available for statistical analysis. "
+             "Ensure trade_metrics are attached to explanations."
+         )
+         return
+
+     returns = bundle.returns
+     summary = compute_return_summary(returns)
+
+     # Show warning if using PnL instead of return_pct
+     if bundle.returns_label == "pnl":
+         st.caption(
+             "Using PnL (dollar amounts) instead of normalized returns. "
+             "Sharpe ratio interpretation is limited."
+         )
+
+     # PSR section (replaces incorrect DSR usage)
+     st.subheader("Probabilistic Sharpe Ratio (PSR)")
+
+     st.markdown(
+         """
+         **What is PSR?**
+         The Probabilistic Sharpe Ratio (PSR) gives the probability that the true
+         Sharpe ratio exceeds a benchmark (typically 0), accounting for sample size
+         and return distribution characteristics.
+
+         *Note: DSR (Deflated Sharpe Ratio) was previously shown here but is not
+         applicable to single-strategy analysis. DSR requires K independent strategies
+         to compute the variance across trials.*
+
+         **Reference:** Bailey & Lopez de Prado (2012). "The Sharpe Ratio Efficient Frontier"
+         """
+     )
+
+     # Calculate PSR
+     psr_result = probabilistic_sharpe_ratio(
+         observed_sharpe=summary.sharpe,
+         benchmark_sharpe=0.0,
+         n_samples=summary.n_samples,
+         skewness=summary.skewness,
+         kurtosis=summary.kurtosis,
+         return_components=True,
+     )
+
+     # Display metrics
+     col1, col2, col3, col4 = st.columns(4)
+
+     with col1:
+         st.metric(
+             "Sharpe Ratio",
+             f"{summary.sharpe:.3f}",
+             help="Observed Sharpe ratio (mean / std)",
+         )
+
+     with col2:
+         st.metric(
+             "PSR (vs SR=0)",
+             f"{psr_result['psr']:.3f}",
+             help="Probability that true SR > 0",
+         )
+
+     with col3:
+         p_value = 1 - psr_result["psr"]
+         st.metric(
+             "P-Value",
+             f"{p_value:.4f}",
+             help="1 - PSR: probability true SR <= 0",
+         )
+
+     with col4:
+         st.metric("N Trades", summary.n_samples, help="Number of trades analyzed")
+
+     # Interpretation
+     psr = psr_result["psr"]
+     if psr >= 0.99:
+         st.success(f"Strong evidence SR > 0 (PSR = {psr:.3f} >= 0.99)")
+     elif psr >= 0.95:
+         st.success(f"Significant performance (PSR = {psr:.3f} >= 0.95)")
+     elif psr >= 0.90:
+         st.warning(f"Marginally significant (PSR = {psr:.3f} >= 0.90)")
+     elif psr >= 0.50:
+         st.warning(f"Weak evidence SR > 0 (PSR = {psr:.3f})")
+     else:
+         st.error(f"Evidence suggests SR <= 0 (PSR = {psr:.3f} < 0.50)")
+
+     # Return statistics
+     st.divider()
+     st.subheader("Return Statistics")
+
+     col1, col2, col3, col4 = st.columns(4)
+
+     with col1:
+         st.metric("Mean", f"{summary.mean:.4f}")
+
+     with col2:
+         st.metric("Std Dev", f"{summary.std:.4f}")
+
+     with col3:
+         st.metric("Win Rate", f"{summary.win_rate:.1%}")
+
+     with col4:
+         st.metric("Skewness", f"{summary.skewness:.3f}")
+
+     col1, col2, col3, col4 = st.columns(4)
+
+     with col1:
+         st.metric("Kurtosis", f"{summary.kurtosis:.3f}")
+
+     with col2:
+         st.metric("Min", f"{summary.min_val:.4f}")
+
+     with col3:
+         st.metric("Max", f"{summary.max_val:.4f}")
+
+     with col4:
+         pass  # Empty column for alignment
+
+     # Distribution tests
+     st.divider()
+     st.subheader("Distribution Tests")
+
+     dist_tests = compute_distribution_tests(returns)
+     if not dist_tests.empty:
+         st.dataframe(
+             dist_tests,
+             hide_index=True,
+             use_container_width=True,
+         )
+     else:
+         st.caption("Insufficient data for distribution tests.")
+
+     # Time-series tests
+     st.divider()
+     st.subheader("Time-Series Tests")
+
+     st.caption("These tests require chronologically ordered data. Trades are sorted by entry_time.")
+
+     ts_tests = compute_time_series_tests(returns)
+     if not ts_tests.empty:
+         st.dataframe(
+             ts_tests,
+             hide_index=True,
+             use_container_width=True,
+         )
+     else:
+         st.caption("Insufficient data for time-series tests (need 20+ observations).")
ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py
@@ -0,0 +1,236 @@
+ """Worst Trades tab.
+
+ Displays a table of trades with sorting/filtering and detailed view.
+ """
+
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING, Any
+
+ import pandas as pd
+
+ if TYPE_CHECKING:
+     from ml4t.diagnostic.evaluation.trade_dashboard.types import DashboardBundle
+
+
+ def render_tab(st: Any, bundle: DashboardBundle) -> None:
+     """Render the Worst Trades tab.
+
+     Parameters
+     ----------
+     st : streamlit
+         Streamlit module instance.
+     bundle : DashboardBundle
+         Normalized dashboard data.
+     """
+     st.header("Worst Trades Analysis")
+
+     st.info(
+         "This tab shows the trades analyzed for error patterns. "
+         "Select a trade to see detailed SHAP explanations."
+     )
+
+     trades_df = bundle.trades_df
+
+     if trades_df.empty:
+         st.warning("No trade data available.")
+         return
+
+     # Sidebar filters
+     with st.sidebar:
+         st.divider()
+         st.subheader("Trade Filters")
+
+         # Sort options
+         sort_options = ["PnL (Low to High)", "PnL (High to Low)", "Entry Time", "Return %"]
+         sort_by = st.selectbox("Sort by", options=sort_options, index=0)
+
+         # Max trades slider
+         max_trades = st.slider("Max trades to display", min_value=5, max_value=100, value=20)
+
+     # Apply sorting
+     sorted_df = trades_df.copy()
+
+     if sort_by == "PnL (Low to High)" and "pnl" in sorted_df.columns:
+         sorted_df = sorted_df.sort_values("pnl", ascending=True, na_position="last")
+     elif sort_by == "PnL (High to Low)" and "pnl" in sorted_df.columns:
+         sorted_df = sorted_df.sort_values("pnl", ascending=False, na_position="last")
+     elif sort_by == "Entry Time" and "entry_time" in sorted_df.columns:
+         sorted_df = sorted_df.sort_values("entry_time", ascending=True, na_position="last")
+     elif sort_by == "Return %" and "return_pct" in sorted_df.columns:
+         sorted_df = sorted_df.sort_values("return_pct", ascending=True, na_position="last")
+
+     # Limit display
+     sorted_df = sorted_df.head(max_trades)
+
+     # Build display DataFrame
+     display_columns = {
+         "trade_id": "Trade ID",
+         "symbol": "Symbol",
+         "entry_time": "Entry Time",
+         "pnl": "PnL",
+         "return_pct": "Return %",
+         "duration_days": "Duration (days)",
+         "top_feature": "Top Feature",
+         "top_shap_value": "Top SHAP",
+     }
+
+     display_df = sorted_df[[c for c in display_columns if c in sorted_df.columns]].copy()
+     display_df = display_df.rename(
+         columns={k: v for k, v in display_columns.items() if k in display_df.columns}
+     )
+
+     # Format timestamp for display
+     if "Entry Time" in display_df.columns:
+         display_df["Entry Time"] = display_df["Entry Time"].apply(
+             lambda x: x.strftime("%Y-%m-%d %H:%M") if pd.notna(x) else "N/A"
+         )
+
+     # Configure column formatting
+     column_config = {
+         "Trade ID": st.column_config.TextColumn("Trade ID", width="medium"),
+         "Symbol": st.column_config.TextColumn("Symbol", width="small"),
+         "Entry Time": st.column_config.TextColumn("Entry Time", width="medium"),
+         "PnL": st.column_config.NumberColumn(
+             "PnL",
+             format="%.2f",
+             help="Profit/Loss for this trade",
+         ),
+         "Return %": st.column_config.NumberColumn(
+             "Return %",
+             format="%.2f%%",
+             help="Return as percentage",
+         ),
+         "Duration (days)": st.column_config.NumberColumn(
+             "Duration (days)",
+             format="%.1f",
+             help="Trade duration in days",
+         ),
+         "Top Feature": st.column_config.TextColumn(
+             "Top Feature",
+             help="Feature with highest absolute SHAP value",
+         ),
+         "Top SHAP": st.column_config.NumberColumn(
+             "Top SHAP",
+             format="%.4f",
+             help="SHAP value for top feature",
+         ),
+     }
+
+     # Display table with selection
+     st.subheader("Trade Table")
+
+     # Initialize session state for selected trade
+     if "selected_trade_idx" not in st.session_state:
+         st.session_state.selected_trade_idx = None
+
+     # Use dataframe with on_select callback
+     event = st.dataframe(
+         display_df,
+         hide_index=True,
+         use_container_width=True,
+         column_config={k: v for k, v in column_config.items() if k in display_df.columns},
+         on_select="rerun",
+         selection_mode="single-row",
+     )
+
+     # Handle row selection
+     selection = getattr(event, "selection", None)
+     if selection is not None:
+         rows = getattr(selection, "rows", [])
+         if rows:
+             st.session_state.selected_trade_idx = rows[0]
+
+     # Display trade details if selected
+     if (
+         st.session_state.selected_trade_idx is not None
+         and st.session_state.selected_trade_idx < len(sorted_df)
+     ):
+         _render_trade_details(st, sorted_df, bundle, st.session_state.selected_trade_idx)
+
+
+ def _render_trade_details(
+     st: Any,
+     sorted_df: pd.DataFrame,
+     bundle: DashboardBundle,
+     selected_idx: int,
+ ) -> None:
+     """Render detailed view of selected trade."""
+     st.divider()
+     st.subheader("Trade Details")
+
+     row = sorted_df.iloc[selected_idx]
+     trade_id = row.get("trade_id", "")
+
+     # Find corresponding explanation
+     explanation = next(
+         (exp for exp in bundle.explanations if exp.get("trade_id") == trade_id),
+         None,
+     )
+
+     # Basic metrics
+     col1, col2, col3, col4 = st.columns(4)
+
+     with col1:
+         st.metric("Trade ID", trade_id)
+         if pd.notna(row.get("symbol")):
+             st.metric("Symbol", row["symbol"])
+
+     with col2:
+         pnl = row.get("pnl")
+         if pd.notna(pnl):
+             st.metric("PnL", f"${pnl:.2f}")
+         else:
+             st.metric("PnL", "N/A")
+
+     with col3:
+         return_pct = row.get("return_pct")
+         if pd.notna(return_pct):
+             st.metric("Return", f"{return_pct:.2f}%")
+         else:
+             st.metric("Return", "N/A")
+
+     with col4:
+         duration = row.get("duration_days")
+         if pd.notna(duration):
+             st.metric("Duration", f"{duration:.1f} days")
+         else:
+             st.metric("Duration", "N/A")
+
+     # Entry/Exit prices
+     col1, col2 = st.columns(2)
+
+     with col1:
+         entry_price = row.get("entry_price")
+         if pd.notna(entry_price):
+             st.metric("Entry Price", f"${entry_price:.4f}")
+         else:
+             st.caption("Entry price not available")
+
+     with col2:
+         exit_price = row.get("exit_price")
+         if pd.notna(exit_price):
+             st.metric("Exit Price", f"${exit_price:.4f}")
+         else:
+             st.caption("Exit price not available")
+
+     # Top features from explanation
+     if explanation and explanation.get("top_features"):
+         st.subheader("Top SHAP Contributions")
+
+         top_features = explanation["top_features"]
+         feature_data = [
+             {"Feature": f[0], "SHAP Value": f[1]}
+             for f in top_features[:10]  # Limit to top 10
+         ]
+
+         if feature_data:
+             st.dataframe(
+                 pd.DataFrame(feature_data),
+                 hide_index=True,
+                 use_container_width=True,
+                 column_config={
+                     "Feature": st.column_config.TextColumn("Feature"),
+                     "SHAP Value": st.column_config.NumberColumn("SHAP Value", format="%.4f"),
+                 },
+             )
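To make the table logic above concrete, a hypothetical illustration (not from the package) of the flat trade frame this tab reads as bundle.trades_df. Column names mirror the display_columns and row.get(...) keys in the code; missing columns are simply skipped by the display logic, and the values below are dummy data.

import pandas as pd

# Illustrative trades_df with the columns the Worst Trades tab knows how to show.
trades_df = pd.DataFrame(
    {
        "trade_id": ["T-001", "T-002"],
        "symbol": ["AAPL", "MSFT"],
        "entry_time": pd.to_datetime(["2024-03-01 14:30", "2024-03-04 10:05"]),
        "pnl": [-312.50, 148.20],
        "return_pct": [-1.84, 0.92],
        "duration_days": [2.5, 0.8],
        "top_feature": ["rsi_14", "atr_20"],
        "top_shap_value": [-0.042, 0.017],
        "entry_price": [171.25, 402.10],
        "exit_price": [168.10, 405.80],
    }
)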