ml4t-diagnostic 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ml4t/diagnostic/AGENT.md +25 -0
- ml4t/diagnostic/__init__.py +166 -0
- ml4t/diagnostic/backends/__init__.py +10 -0
- ml4t/diagnostic/backends/adapter.py +192 -0
- ml4t/diagnostic/backends/polars_backend.py +899 -0
- ml4t/diagnostic/caching/__init__.py +40 -0
- ml4t/diagnostic/caching/cache.py +331 -0
- ml4t/diagnostic/caching/decorators.py +131 -0
- ml4t/diagnostic/caching/smart_cache.py +339 -0
- ml4t/diagnostic/config/AGENT.md +24 -0
- ml4t/diagnostic/config/README.md +267 -0
- ml4t/diagnostic/config/__init__.py +219 -0
- ml4t/diagnostic/config/barrier_config.py +277 -0
- ml4t/diagnostic/config/base.py +301 -0
- ml4t/diagnostic/config/event_config.py +148 -0
- ml4t/diagnostic/config/feature_config.py +404 -0
- ml4t/diagnostic/config/multi_signal_config.py +55 -0
- ml4t/diagnostic/config/portfolio_config.py +215 -0
- ml4t/diagnostic/config/report_config.py +391 -0
- ml4t/diagnostic/config/sharpe_config.py +202 -0
- ml4t/diagnostic/config/signal_config.py +206 -0
- ml4t/diagnostic/config/trade_analysis_config.py +310 -0
- ml4t/diagnostic/config/validation.py +279 -0
- ml4t/diagnostic/core/__init__.py +29 -0
- ml4t/diagnostic/core/numba_utils.py +315 -0
- ml4t/diagnostic/core/purging.py +372 -0
- ml4t/diagnostic/core/sampling.py +471 -0
- ml4t/diagnostic/errors/__init__.py +205 -0
- ml4t/diagnostic/evaluation/AGENT.md +26 -0
- ml4t/diagnostic/evaluation/__init__.py +437 -0
- ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
- ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
- ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
- ml4t/diagnostic/evaluation/dashboard.py +715 -0
- ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
- ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
- ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
- ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
- ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
- ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
- ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
- ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
- ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
- ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
- ml4t/diagnostic/evaluation/event_analysis.py +647 -0
- ml4t/diagnostic/evaluation/excursion.py +390 -0
- ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
- ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
- ml4t/diagnostic/evaluation/framework.py +935 -0
- ml4t/diagnostic/evaluation/metric_registry.py +255 -0
- ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
- ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
- ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
- ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
- ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
- ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
- ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
- ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
- ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
- ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
- ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
- ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
- ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
- ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
- ml4t/diagnostic/evaluation/multi_signal.py +550 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
- ml4t/diagnostic/evaluation/report_generation.py +824 -0
- ml4t/diagnostic/evaluation/signal_selector.py +452 -0
- ml4t/diagnostic/evaluation/stat_registry.py +139 -0
- ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
- ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
- ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
- ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
- ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
- ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
- ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
- ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
- ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
- ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
- ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
- ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
- ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
- ml4t/diagnostic/evaluation/stats/moments.py +164 -0
- ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
- ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
- ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
- ml4t/diagnostic/evaluation/themes.py +330 -0
- ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
- ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
- ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
- ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
- ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
- ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
- ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
- ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
- ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
- ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
- ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
- ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
- ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
- ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
- ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
- ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
- ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
- ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
- ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
- ml4t/diagnostic/evaluation/validated_cv.py +535 -0
- ml4t/diagnostic/evaluation/visualization.py +1050 -0
- ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
- ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
- ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
- ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
- ml4t/diagnostic/integration/__init__.py +48 -0
- ml4t/diagnostic/integration/backtest_contract.py +671 -0
- ml4t/diagnostic/integration/data_contract.py +316 -0
- ml4t/diagnostic/integration/engineer_contract.py +226 -0
- ml4t/diagnostic/logging/__init__.py +77 -0
- ml4t/diagnostic/logging/logger.py +245 -0
- ml4t/diagnostic/logging/performance.py +234 -0
- ml4t/diagnostic/logging/progress.py +234 -0
- ml4t/diagnostic/logging/wandb.py +412 -0
- ml4t/diagnostic/metrics/__init__.py +9 -0
- ml4t/diagnostic/metrics/percentiles.py +128 -0
- ml4t/diagnostic/py.typed +1 -0
- ml4t/diagnostic/reporting/__init__.py +43 -0
- ml4t/diagnostic/reporting/base.py +130 -0
- ml4t/diagnostic/reporting/html_renderer.py +275 -0
- ml4t/diagnostic/reporting/json_renderer.py +51 -0
- ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
- ml4t/diagnostic/results/AGENT.md +24 -0
- ml4t/diagnostic/results/__init__.py +105 -0
- ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
- ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
- ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
- ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
- ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
- ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
- ml4t/diagnostic/results/barrier_results/validation.py +38 -0
- ml4t/diagnostic/results/base.py +177 -0
- ml4t/diagnostic/results/event_results.py +349 -0
- ml4t/diagnostic/results/feature_results.py +787 -0
- ml4t/diagnostic/results/multi_signal_results.py +431 -0
- ml4t/diagnostic/results/portfolio_results.py +281 -0
- ml4t/diagnostic/results/sharpe_results.py +448 -0
- ml4t/diagnostic/results/signal_results/__init__.py +74 -0
- ml4t/diagnostic/results/signal_results/ic.py +581 -0
- ml4t/diagnostic/results/signal_results/irtc.py +110 -0
- ml4t/diagnostic/results/signal_results/quantile.py +392 -0
- ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
- ml4t/diagnostic/results/signal_results/turnover.py +213 -0
- ml4t/diagnostic/results/signal_results/validation.py +147 -0
- ml4t/diagnostic/signal/AGENT.md +17 -0
- ml4t/diagnostic/signal/__init__.py +69 -0
- ml4t/diagnostic/signal/_report.py +152 -0
- ml4t/diagnostic/signal/_utils.py +261 -0
- ml4t/diagnostic/signal/core.py +275 -0
- ml4t/diagnostic/signal/quantile.py +148 -0
- ml4t/diagnostic/signal/result.py +214 -0
- ml4t/diagnostic/signal/signal_ic.py +129 -0
- ml4t/diagnostic/signal/turnover.py +182 -0
- ml4t/diagnostic/splitters/AGENT.md +19 -0
- ml4t/diagnostic/splitters/__init__.py +36 -0
- ml4t/diagnostic/splitters/base.py +501 -0
- ml4t/diagnostic/splitters/calendar.py +421 -0
- ml4t/diagnostic/splitters/calendar_config.py +91 -0
- ml4t/diagnostic/splitters/combinatorial.py +1064 -0
- ml4t/diagnostic/splitters/config.py +322 -0
- ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
- ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
- ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
- ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
- ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
- ml4t/diagnostic/splitters/group_isolation.py +329 -0
- ml4t/diagnostic/splitters/persistence.py +316 -0
- ml4t/diagnostic/splitters/utils.py +207 -0
- ml4t/diagnostic/splitters/walk_forward.py +757 -0
- ml4t/diagnostic/utils/__init__.py +42 -0
- ml4t/diagnostic/utils/config.py +542 -0
- ml4t/diagnostic/utils/dependencies.py +318 -0
- ml4t/diagnostic/utils/sessions.py +127 -0
- ml4t/diagnostic/validation/__init__.py +54 -0
- ml4t/diagnostic/validation/dataframe.py +274 -0
- ml4t/diagnostic/validation/returns.py +280 -0
- ml4t/diagnostic/validation/timeseries.py +299 -0
- ml4t/diagnostic/visualization/AGENT.md +19 -0
- ml4t/diagnostic/visualization/__init__.py +223 -0
- ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
- ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
- ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
- ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
- ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
- ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
- ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
- ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
- ml4t/diagnostic/visualization/barrier_plots.py +782 -0
- ml4t/diagnostic/visualization/core.py +1060 -0
- ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
- ml4t/diagnostic/visualization/dashboards/base.py +582 -0
- ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
- ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
- ml4t/diagnostic/visualization/dashboards.py +43 -0
- ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
- ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
- ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
- ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
- ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
- ml4t/diagnostic/visualization/feature_plots.py +888 -0
- ml4t/diagnostic/visualization/interaction_plots.py +618 -0
- ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
- ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
- ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
- ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
- ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
- ml4t/diagnostic/visualization/report_generation.py +1343 -0
- ml4t/diagnostic/visualization/signal/__init__.py +103 -0
- ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
- ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
- ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
- ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
- ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
- ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
- ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
- ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
- ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
- ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
- ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
- ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,263 @@
|
|
|
1
|
+
"""Trade SHAP analysis pipeline.
|
|
2
|
+
|
|
3
|
+
This module provides the main TradeShapAnalyzer class that orchestrates
|
|
4
|
+
all components of trade SHAP analysis:
|
|
5
|
+
- TradeShapExplainer for individual trade explanations
|
|
6
|
+
- HierarchicalClusterer for error pattern clustering
|
|
7
|
+
- PatternCharacterizer for statistical characterization
|
|
8
|
+
- HypothesisGenerator for actionable insights
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
from dataclasses import dataclass, field
|
|
14
|
+
from typing import TYPE_CHECKING, Any
|
|
15
|
+
|
|
16
|
+
import numpy as np
|
|
17
|
+
|
|
18
|
+
from ml4t.diagnostic.evaluation.trade_shap.characterize import (
|
|
19
|
+
CharacterizationConfig,
|
|
20
|
+
PatternCharacterizer,
|
|
21
|
+
)
|
|
22
|
+
from ml4t.diagnostic.evaluation.trade_shap.cluster import (
|
|
23
|
+
ClusteringConfig,
|
|
24
|
+
HierarchicalClusterer,
|
|
25
|
+
)
|
|
26
|
+
from ml4t.diagnostic.evaluation.trade_shap.explain import TradeShapExplainer
|
|
27
|
+
from ml4t.diagnostic.evaluation.trade_shap.hypotheses import (
|
|
28
|
+
HypothesisConfig,
|
|
29
|
+
HypothesisGenerator,
|
|
30
|
+
)
|
|
31
|
+
from ml4t.diagnostic.evaluation.trade_shap.models import (
|
|
32
|
+
TradeExplainFailure,
|
|
33
|
+
TradeShapExplanation,
|
|
34
|
+
TradeShapResult,
|
|
35
|
+
)
|
|
36
|
+
from ml4t.diagnostic.evaluation.trade_shap.normalize import normalize
|
|
37
|
+
|
|
38
|
+
if TYPE_CHECKING:
|
|
39
|
+
import polars as pl
|
|
40
|
+
from numpy.typing import NDArray
|
|
41
|
+
|
|
42
|
+
from ml4t.diagnostic.evaluation.trade_analysis import TradeMetrics
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@dataclass
class TradeShapPipelineConfig:
    """Configuration for the trade SHAP analysis pipeline.

    Groups the alignment, explanation, and clustering knobs consumed by
    ``TradeShapPipeline`` together with the per-component sub-configs.

    Attributes:
        alignment_tolerance_seconds: Tolerance for timestamp alignment
        alignment_mode: 'entry' for exact match, 'nearest' for closest
        missing_value_strategy: How to handle alignment failures ('error', 'skip', 'zero')
        top_n_features: Number of top features in explanations
        normalization: Normalization method for clustering ('l1', 'l2', 'standardize', None)
        clustering: Clustering configuration
        characterization: Characterization configuration
        hypothesis: Hypothesis generation configuration
    """

    # Trade-to-feature-row alignment (forwarded to TradeShapExplainer).
    alignment_tolerance_seconds: float = 0.0
    alignment_mode: str = "entry"
    missing_value_strategy: str = "skip"
    # Number of top features reported per explanation.
    top_n_features: int = 10
    # Applied to SHAP vectors before clustering; None disables normalization.
    normalization: str | None = "l2"
    # Sub-component configs; default_factory keeps each instance independent.
    clustering: ClusteringConfig = field(default_factory=ClusteringConfig)
    characterization: CharacterizationConfig = field(default_factory=CharacterizationConfig)
    hypothesis: HypothesisConfig = field(default_factory=HypothesisConfig)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class TradeShapPipeline:
    """High-level orchestrator for trade SHAP analysis.

    Single entry point that wires together the refactored components:
    per-trade explanation (TradeShapExplainer), error-pattern clustering
    (HierarchicalClusterer), statistical characterization
    (PatternCharacterizer), and hypothesis generation (HypothesisGenerator).

    Attributes:
        features_df: Polars DataFrame with timestamp and feature columns
        shap_values: SHAP values array (n_samples x n_features)
        feature_names: List of feature column names
        config: Pipeline configuration

    Example:
        >>> pipeline = TradeShapPipeline(
        ...     features_df=features,
        ...     shap_values=shap_values,
        ...     feature_names=feature_names,
        ... )
        >>> result = pipeline.analyze_worst_trades(trades, n=20)
        >>> for pattern in result.error_patterns:
        ...     print(pattern.hypothesis)
        ...     print(pattern.actions)
    """

    def __init__(
        self,
        features_df: pl.DataFrame,
        shap_values: NDArray[np.floating[Any]],
        feature_names: list[str],
        config: TradeShapPipelineConfig | None = None,
    ) -> None:
        """Initialize pipeline.

        Args:
            features_df: Polars DataFrame with 'timestamp' column and feature columns
            shap_values: SHAP values array (n_samples x n_features)
            feature_names: List of feature column names
            config: Pipeline configuration (uses defaults if None)
        """
        cfg = config or TradeShapPipelineConfig()
        self.features_df = features_df
        self.shap_values = shap_values
        self.feature_names = feature_names
        self.config = cfg

        # Component wiring. Each collaborator receives only the slice of the
        # config it needs.
        self.explainer = TradeShapExplainer(
            features_df=features_df,
            shap_values=shap_values,
            feature_names=feature_names,
            tolerance_seconds=cfg.alignment_tolerance_seconds,
            top_n_features=cfg.top_n_features,
            alignment_mode=cfg.alignment_mode,
            missing_value_strategy=cfg.missing_value_strategy,
        )
        self.clusterer = HierarchicalClusterer(config=cfg.clustering)
        self.characterizer = PatternCharacterizer(
            feature_names=feature_names,
            config=cfg.characterization,
        )
        self.hypothesis_generator = HypothesisGenerator(config=cfg.hypothesis)

    def explain_trade(
        self,
        trade: TradeMetrics,
    ) -> TradeShapExplanation | TradeExplainFailure:
        """Explain a single trade.

        Args:
            trade: Trade to explain

        Returns:
            TradeShapExplanation on success, TradeExplainFailure on failure
        """
        return self.explainer.explain(trade)

    def explain_trades(
        self,
        trades: list[TradeMetrics],
    ) -> tuple[list[TradeShapExplanation], list[TradeExplainFailure]]:
        """Explain multiple trades.

        Args:
            trades: List of trades to explain

        Returns:
            Tuple of (successful explanations, failures)
        """
        return self.explainer.explain_many(trades)

    def analyze_worst_trades(
        self,
        trades: list[TradeMetrics],
        n: int | None = None,
    ) -> TradeShapResult:
        """Analyze worst trades with full pipeline.

        Main entry point: explains each trade, clusters the SHAP vectors,
        characterizes each cluster as an error pattern, and generates a
        hypothesis per pattern.

        Args:
            trades: List of trades (should be sorted by loss, worst first)
            n: Number of trades to analyze (defaults to all)

        Returns:
            TradeShapResult with explanations, error patterns, and insights
        """
        selected = trades if n is None else trades[:n]

        # Step 1: explain every selected trade, partitioning failures.
        explanations, failures = self.explain_trades(selected)
        failed = [(f.trade_id, f.reason) for f in failures]

        if not explanations:
            # Nothing explainable: report failures only, no patterns.
            return TradeShapResult(
                n_trades_analyzed=len(selected),
                n_trades_explained=0,
                n_trades_failed=len(failures),
                explanations=[],
                failed_trades=failed,
                error_patterns=[],
            )

        # Step 2: stack SHAP vectors and normalize when configured.
        vectors = np.array([exp.shap_vector for exp in explanations])
        if self.config.normalization:
            vectors = normalize(vectors, method=self.config.normalization)

        # Steps 3-5: cluster, characterize, and hypothesize (best-effort).
        patterns = self._derive_error_patterns(vectors, len(explanations))

        return TradeShapResult(
            n_trades_analyzed=len(selected),
            n_trades_explained=len(explanations),
            n_trades_failed=len(failures),
            explanations=explanations,
            failed_trades=failed,
            error_patterns=patterns,
        )

    def _derive_error_patterns(
        self,
        vectors: NDArray[np.floating[Any]],
        n_explained: int,
    ) -> list[Any]:
        """Cluster SHAP vectors and attach a hypothesis to each pattern.

        Best-effort: returns whatever patterns were produced before a
        ValueError (e.g. insufficient samples for clustering), which may be
        an empty list.
        """
        if n_explained < self.config.clustering.min_trades_for_clustering:
            return []

        collected: list[Any] = []
        try:
            clustering = self.clusterer.cluster(vectors)
            characterized = self.characterizer.characterize_all_clusters(
                shap_vectors=vectors,
                cluster_labels=clustering.cluster_assignments,
                n_clusters=clustering.n_clusters,
                centroids=clustering.centroids,
            )
            for pattern in characterized:
                collected.append(self.hypothesis_generator.generate_hypothesis(pattern))
        except ValueError:
            # Clustering can legitimately fail; the result is still useful
            # without error patterns, so keep whatever was collected.
            pass
        return collected

    def generate_actions(
        self,
        pattern_index: int = 0,
        max_actions: int | None = None,
    ) -> list[dict[str, Any]]:
        """Generate prioritized actions for an error pattern.

        Args:
            pattern_index: Index of pattern in last result (default: 0 = first)
            max_actions: Maximum actions to return

        Returns:
            List of action dictionaries

        Note:
            Must call analyze_worst_trades first.
        """
        # Convenience stub only — use the hypothesis generator directly with
        # an error pattern taken from the analysis result.
        raise NotImplementedError("Use hypothesis_generator.generate_actions(pattern) directly")
|
|
@@ -0,0 +1,283 @@
|
|
|
1
|
+
"""Streamlit dashboard for Trade-SHAP diagnostics.
|
|
2
|
+
|
|
3
|
+
This module provides an interactive Streamlit dashboard for visualizing
|
|
4
|
+
Trade-SHAP analysis results, including statistical validation, worst trades,
|
|
5
|
+
SHAP explanations, and error patterns.
|
|
6
|
+
|
|
7
|
+
The dashboard is designed for systematic trade debugging and continuous
|
|
8
|
+
improvement of ML trading strategies.
|
|
9
|
+
|
|
10
|
+
Usage:
|
|
11
|
+
# From command line
|
|
12
|
+
    streamlit run path/to/ml4t/diagnostic/evaluation/trade_shap_dashboard.py
|
|
13
|
+
|
|
14
|
+
# Programmatically
|
|
15
|
+
from ml4t.diagnostic.evaluation.trade_shap_dashboard import run_diagnostics_dashboard
|
|
16
|
+
run_diagnostics_dashboard(result)
|
|
17
|
+
|
|
18
|
+
Example:
|
|
19
|
+
>>> from ml4t.diagnostic.evaluation import TradeShapAnalyzer
|
|
20
|
+
>>> from ml4t.diagnostic.evaluation.trade_shap_dashboard import run_diagnostics_dashboard
|
|
21
|
+
>>>
|
|
22
|
+
>>> # Analyze trades and get results
|
|
23
|
+
>>> analyzer = TradeShapAnalyzer(model, features_df, shap_values)
|
|
24
|
+
>>> result = analyzer.explain_worst_trades(worst_trades)
|
|
25
|
+
>>>
|
|
26
|
+
>>> # Launch interactive dashboard
|
|
27
|
+
>>> run_diagnostics_dashboard(result)
|
|
28
|
+
|
|
29
|
+
Note:
|
|
30
|
+
This module is a thin wrapper around the modular dashboard package.
|
|
31
|
+
The implementation has been refactored into:
|
|
32
|
+
- ml4t.diagnostic.evaluation.trade_dashboard.app (main orchestrator)
|
|
33
|
+
- ml4t.diagnostic.evaluation.trade_dashboard.tabs (tab modules)
|
|
34
|
+
- ml4t.diagnostic.evaluation.trade_dashboard.stats (statistical computations)
|
|
35
|
+
- ml4t.diagnostic.evaluation.trade_dashboard.export (export functions)
|
|
36
|
+
"""
|
|
37
|
+
|
|
38
|
+
from __future__ import annotations
|
|
39
|
+
|
|
40
|
+
from typing import TYPE_CHECKING, Any
|
|
41
|
+
|
|
42
|
+
# Re-export the main entry point for backward compatibility
|
|
43
|
+
from ml4t.diagnostic.evaluation.trade_dashboard import run_diagnostics_dashboard
|
|
44
|
+
|
|
45
|
+
if TYPE_CHECKING:
|
|
46
|
+
from ml4t.diagnostic.evaluation.trade_shap.models import TradeShapResult
|
|
47
|
+
|
|
48
|
+
# Import utilities for backward compatibility
|
|
49
|
+
from ml4t.diagnostic.evaluation.trade_dashboard.io import (
|
|
50
|
+
PickleDisabledError,
|
|
51
|
+
)
|
|
52
|
+
from ml4t.diagnostic.evaluation.trade_dashboard.io import (
|
|
53
|
+
load_result_from_upload as load_data_from_file,
|
|
54
|
+
)
|
|
55
|
+
|
|
56
|
+
__all__ = [
|
|
57
|
+
"run_diagnostics_dashboard",
|
|
58
|
+
"run_polished_dashboard",
|
|
59
|
+
"export_full_report_html",
|
|
60
|
+
"export_patterns_to_csv",
|
|
61
|
+
"export_trades_to_csv",
|
|
62
|
+
"load_data_from_file",
|
|
63
|
+
"PickleDisabledError",
|
|
64
|
+
"extract_trade_returns",
|
|
65
|
+
"extract_trade_data",
|
|
66
|
+
]
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def run_polished_dashboard(
    result: TradeShapResult | dict[str, Any] | None = None,
    title: str = "Trade-SHAP Diagnostics Dashboard",
) -> None:
    """Launch the diagnostics dashboard with styling enabled.

    Backward-compatible alias that forwards to ``run_diagnostics_dashboard``
    with ``styled=True``.

    Args:
        result: Analysis result to display, forwarded unchanged. ``None`` is
            accepted — presumably the dashboard then sources its own data;
            confirm against ``trade_dashboard.app``.
        title: Dashboard page title.
    """
    run_diagnostics_dashboard(result=result, title=title, styled=True)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
# Backward-compatible export functions that accept raw dicts/lists
|
|
78
|
+
def export_trades_to_csv(trades_data: list[dict[str, Any]]) -> str:
    """Serialize trade records to CSV text. Backward-compatible API.

    Parameters
    ----------
    trades_data : list of dict
        One dictionary per trade; keys become the CSV columns.

    Returns
    -------
    str
        CSV formatted string (no index column), or "" for no trades.
    """
    import pandas as pd

    if not trades_data:
        return ""
    frame = pd.DataFrame(trades_data)
    return frame.to_csv(index=False)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def export_patterns_to_csv(patterns: list[dict[str, Any]]) -> str:
    """Serialize error patterns to CSV text. Backward-compatible API.

    Parameters
    ----------
    patterns : list of dict
        One dictionary per pattern.

    Returns
    -------
    str
        CSV formatted string with headers Pattern ID, N Trades, Description,
        Hypothesis, Confidence; "" when there are no patterns.
    """
    import pandas as pd

    if not patterns:
        return ""

    # Project each pattern dict onto the fixed report columns, with the same
    # defaults the dashboard uses for missing keys.
    rows = [
        {
            "Pattern ID": p.get("cluster_id", 0),
            "N Trades": p.get("n_trades", 0),
            "Description": p.get("description", ""),
            "Hypothesis": p.get("hypothesis", ""),
            "Confidence": p.get("confidence", ""),
        }
        for p in patterns
    ]
    return pd.DataFrame(rows).to_csv(index=False)
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def export_full_report_html(result: dict[str, Any]) -> str:
    """Export full HTML report. Backward-compatible API.

    Pattern text (cluster id, description, hypothesis, actions) is
    HTML-escaped before interpolation so that markup characters in analysis
    strings cannot break the page layout or inject active content.

    Parameters
    ----------
    result : dict
        Analysis result dictionary.

    Returns
    -------
    str
        HTML report string.
    """
    import html
    from datetime import datetime

    patterns = result.get("error_patterns", [])
    n_analyzed = result.get("n_trades_analyzed", 0)
    n_explained = result.get("n_trades_explained", 0)
    n_failed = result.get("n_trades_failed", 0)

    patterns_html = ""
    for p in patterns:
        # Escape every user-/model-derived string before HTML interpolation.
        hypothesis = html.escape(str(p.get("hypothesis", "No hypothesis")))
        actions = p.get("actions", [])
        actions_html = "".join(f"<li>{html.escape(str(a))}</li>" for a in actions) if actions else ""

        patterns_html += f"""
        <div class="pattern">
            <h3>Pattern {html.escape(str(p.get("cluster_id", "N/A")))}: {p.get("n_trades", 0)} trades</h3>
            <p><strong>Description:</strong> {html.escape(str(p.get("description", "N/A")))}</p>
            <p><strong>Hypothesis:</strong> {hypothesis}</p>
            <ul>{actions_html}</ul>
        </div>
        """

    return f"""<!DOCTYPE html>
<html>
<head>
    <title>Trade-SHAP Analysis Report</title>
    <style>
        body {{ font-family: sans-serif; max-width: 1000px; margin: 0 auto; padding: 20px; }}
        .header {{ background: #1f77b4; color: white; padding: 20px; }}
        .metrics {{ display: flex; gap: 20px; margin: 20px 0; }}
        .metric {{ background: #f0f0f0; padding: 15px; }}
        .pattern {{ border: 1px solid #ddd; padding: 15px; margin: 10px 0; }}
    </style>
</head>
<body>
    <div class="header">
        <h1>Trade-SHAP Analysis Report</h1>
        <p>Generated: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}</p>
    </div>
    <div class="metrics">
        <div class="metric"><strong>Analyzed:</strong> {n_analyzed}</div>
        <div class="metric"><strong>Explained:</strong> {n_explained}</div>
        <div class="metric"><strong>Failed:</strong> {n_failed}</div>
    </div>
    <h2>Error Patterns</h2>
    {patterns_html}
</body>
</html>"""
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def extract_trade_returns(result: dict[str, Any]) -> list[float]:
    """Extract trade PnL values from analysis result.

    Parameters
    ----------
    result : dict
        Analysis result dictionary with "explanations" key.

    Returns
    -------
    list of float
        List of PnL values from each trade (0.0 when a trade has no pnl).

    Examples
    --------
    >>> result = {"explanations": [{"trade_metrics": {"pnl": 100.0}}]}
    >>> extract_trade_returns(result)
    [100.0]
    """
    return [
        exp.get("trade_metrics", {}).get("pnl", 0.0)
        for exp in result.get("explanations", [])
    ]
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
def extract_trade_data(result: dict[str, Any]) -> list[dict[str, Any]]:
    """Extract trade data for display from analysis result.

    Parameters
    ----------
    result : dict
        Analysis result dictionary with "explanations" key.

    Returns
    -------
    list of dict
        List of trade data dictionaries with keys:
        - trade_id: Trade identifier
        - timestamp: Trade timestamp
        - symbol: Trading symbol
        - pnl: Profit/loss
        - return_pct: Return percentage
        - duration_days: Trade duration
        - entry_price: Entry price
        - exit_price: Exit price
        - top_feature: Most important feature (None if unavailable)
        - top_shap_value: SHAP value of top feature (None if unavailable)

    Examples
    --------
    >>> result = {"explanations": [{"trade_id": "T1", "trade_metrics": {"pnl": 100.0}}]}
    >>> data = extract_trade_data(result)
    >>> data[0]["trade_id"]
    'T1'
    """
    rows: list[dict[str, Any]] = []

    for exp in result.get("explanations", []):
        metrics = exp.get("trade_metrics", {})
        ranked = exp.get("top_features", [])
        # top_features is a ranked list of (name, shap_value) pairs; fall
        # back to (None, None) when no features were recorded.
        leader = ranked[0] if ranked else (None, None)

        rows.append(
            {
                "trade_id": exp.get("trade_id", ""),
                "timestamp": exp.get("timestamp", ""),
                "symbol": metrics.get("symbol", ""),
                "pnl": metrics.get("pnl", 0.0),
                "return_pct": metrics.get("return_pct", 0.0),
                "duration_days": metrics.get("duration_days", 0.0),
                "entry_price": metrics.get("entry_price", 0.0),
                "exit_price": metrics.get("exit_price", 0.0),
                "top_feature": leader[0],
                "top_shap_value": leader[1],
            }
        )

    return rows
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
# Allow running as a standalone Streamlit app. Called with no arguments, so
# run_diagnostics_dashboard receives result=None — presumably the dashboard
# then sources its own data; confirm against trade_dashboard.app.
if __name__ == "__main__":
    run_diagnostics_dashboard()
|