ml4t-diagnostic 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ml4t/diagnostic/AGENT.md +25 -0
- ml4t/diagnostic/__init__.py +166 -0
- ml4t/diagnostic/backends/__init__.py +10 -0
- ml4t/diagnostic/backends/adapter.py +192 -0
- ml4t/diagnostic/backends/polars_backend.py +899 -0
- ml4t/diagnostic/caching/__init__.py +40 -0
- ml4t/diagnostic/caching/cache.py +331 -0
- ml4t/diagnostic/caching/decorators.py +131 -0
- ml4t/diagnostic/caching/smart_cache.py +339 -0
- ml4t/diagnostic/config/AGENT.md +24 -0
- ml4t/diagnostic/config/README.md +267 -0
- ml4t/diagnostic/config/__init__.py +219 -0
- ml4t/diagnostic/config/barrier_config.py +277 -0
- ml4t/diagnostic/config/base.py +301 -0
- ml4t/diagnostic/config/event_config.py +148 -0
- ml4t/diagnostic/config/feature_config.py +404 -0
- ml4t/diagnostic/config/multi_signal_config.py +55 -0
- ml4t/diagnostic/config/portfolio_config.py +215 -0
- ml4t/diagnostic/config/report_config.py +391 -0
- ml4t/diagnostic/config/sharpe_config.py +202 -0
- ml4t/diagnostic/config/signal_config.py +206 -0
- ml4t/diagnostic/config/trade_analysis_config.py +310 -0
- ml4t/diagnostic/config/validation.py +279 -0
- ml4t/diagnostic/core/__init__.py +29 -0
- ml4t/diagnostic/core/numba_utils.py +315 -0
- ml4t/diagnostic/core/purging.py +372 -0
- ml4t/diagnostic/core/sampling.py +471 -0
- ml4t/diagnostic/errors/__init__.py +205 -0
- ml4t/diagnostic/evaluation/AGENT.md +26 -0
- ml4t/diagnostic/evaluation/__init__.py +437 -0
- ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
- ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
- ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
- ml4t/diagnostic/evaluation/dashboard.py +715 -0
- ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
- ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
- ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
- ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
- ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
- ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
- ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
- ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
- ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
- ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
- ml4t/diagnostic/evaluation/event_analysis.py +647 -0
- ml4t/diagnostic/evaluation/excursion.py +390 -0
- ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
- ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
- ml4t/diagnostic/evaluation/framework.py +935 -0
- ml4t/diagnostic/evaluation/metric_registry.py +255 -0
- ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
- ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
- ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
- ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
- ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
- ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
- ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
- ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
- ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
- ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
- ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
- ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
- ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
- ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
- ml4t/diagnostic/evaluation/multi_signal.py +550 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
- ml4t/diagnostic/evaluation/report_generation.py +824 -0
- ml4t/diagnostic/evaluation/signal_selector.py +452 -0
- ml4t/diagnostic/evaluation/stat_registry.py +139 -0
- ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
- ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
- ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
- ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
- ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
- ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
- ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
- ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
- ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
- ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
- ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
- ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
- ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
- ml4t/diagnostic/evaluation/stats/moments.py +164 -0
- ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
- ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
- ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
- ml4t/diagnostic/evaluation/themes.py +330 -0
- ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
- ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
- ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
- ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
- ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
- ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
- ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
- ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
- ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
- ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
- ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
- ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
- ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
- ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
- ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
- ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
- ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
- ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
- ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
- ml4t/diagnostic/evaluation/validated_cv.py +535 -0
- ml4t/diagnostic/evaluation/visualization.py +1050 -0
- ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
- ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
- ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
- ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
- ml4t/diagnostic/integration/__init__.py +48 -0
- ml4t/diagnostic/integration/backtest_contract.py +671 -0
- ml4t/diagnostic/integration/data_contract.py +316 -0
- ml4t/diagnostic/integration/engineer_contract.py +226 -0
- ml4t/diagnostic/logging/__init__.py +77 -0
- ml4t/diagnostic/logging/logger.py +245 -0
- ml4t/diagnostic/logging/performance.py +234 -0
- ml4t/diagnostic/logging/progress.py +234 -0
- ml4t/diagnostic/logging/wandb.py +412 -0
- ml4t/diagnostic/metrics/__init__.py +9 -0
- ml4t/diagnostic/metrics/percentiles.py +128 -0
- ml4t/diagnostic/py.typed +1 -0
- ml4t/diagnostic/reporting/__init__.py +43 -0
- ml4t/diagnostic/reporting/base.py +130 -0
- ml4t/diagnostic/reporting/html_renderer.py +275 -0
- ml4t/diagnostic/reporting/json_renderer.py +51 -0
- ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
- ml4t/diagnostic/results/AGENT.md +24 -0
- ml4t/diagnostic/results/__init__.py +105 -0
- ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
- ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
- ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
- ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
- ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
- ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
- ml4t/diagnostic/results/barrier_results/validation.py +38 -0
- ml4t/diagnostic/results/base.py +177 -0
- ml4t/diagnostic/results/event_results.py +349 -0
- ml4t/diagnostic/results/feature_results.py +787 -0
- ml4t/diagnostic/results/multi_signal_results.py +431 -0
- ml4t/diagnostic/results/portfolio_results.py +281 -0
- ml4t/diagnostic/results/sharpe_results.py +448 -0
- ml4t/diagnostic/results/signal_results/__init__.py +74 -0
- ml4t/diagnostic/results/signal_results/ic.py +581 -0
- ml4t/diagnostic/results/signal_results/irtc.py +110 -0
- ml4t/diagnostic/results/signal_results/quantile.py +392 -0
- ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
- ml4t/diagnostic/results/signal_results/turnover.py +213 -0
- ml4t/diagnostic/results/signal_results/validation.py +147 -0
- ml4t/diagnostic/signal/AGENT.md +17 -0
- ml4t/diagnostic/signal/__init__.py +69 -0
- ml4t/diagnostic/signal/_report.py +152 -0
- ml4t/diagnostic/signal/_utils.py +261 -0
- ml4t/diagnostic/signal/core.py +275 -0
- ml4t/diagnostic/signal/quantile.py +148 -0
- ml4t/diagnostic/signal/result.py +214 -0
- ml4t/diagnostic/signal/signal_ic.py +129 -0
- ml4t/diagnostic/signal/turnover.py +182 -0
- ml4t/diagnostic/splitters/AGENT.md +19 -0
- ml4t/diagnostic/splitters/__init__.py +36 -0
- ml4t/diagnostic/splitters/base.py +501 -0
- ml4t/diagnostic/splitters/calendar.py +421 -0
- ml4t/diagnostic/splitters/calendar_config.py +91 -0
- ml4t/diagnostic/splitters/combinatorial.py +1064 -0
- ml4t/diagnostic/splitters/config.py +322 -0
- ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
- ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
- ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
- ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
- ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
- ml4t/diagnostic/splitters/group_isolation.py +329 -0
- ml4t/diagnostic/splitters/persistence.py +316 -0
- ml4t/diagnostic/splitters/utils.py +207 -0
- ml4t/diagnostic/splitters/walk_forward.py +757 -0
- ml4t/diagnostic/utils/__init__.py +42 -0
- ml4t/diagnostic/utils/config.py +542 -0
- ml4t/diagnostic/utils/dependencies.py +318 -0
- ml4t/diagnostic/utils/sessions.py +127 -0
- ml4t/diagnostic/validation/__init__.py +54 -0
- ml4t/diagnostic/validation/dataframe.py +274 -0
- ml4t/diagnostic/validation/returns.py +280 -0
- ml4t/diagnostic/validation/timeseries.py +299 -0
- ml4t/diagnostic/visualization/AGENT.md +19 -0
- ml4t/diagnostic/visualization/__init__.py +223 -0
- ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
- ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
- ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
- ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
- ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
- ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
- ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
- ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
- ml4t/diagnostic/visualization/barrier_plots.py +782 -0
- ml4t/diagnostic/visualization/core.py +1060 -0
- ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
- ml4t/diagnostic/visualization/dashboards/base.py +582 -0
- ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
- ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
- ml4t/diagnostic/visualization/dashboards.py +43 -0
- ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
- ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
- ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
- ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
- ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
- ml4t/diagnostic/visualization/feature_plots.py +888 -0
- ml4t/diagnostic/visualization/interaction_plots.py +618 -0
- ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
- ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
- ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
- ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
- ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
- ml4t/diagnostic/visualization/report_generation.py +1343 -0
- ml4t/diagnostic/visualization/signal/__init__.py +103 -0
- ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
- ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
- ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
- ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
- ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
- ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
- ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
- ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
- ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
- ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
- ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
- ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
ml4t/diagnostic/evaluation/trade_shap/explain.py
@@ -0,0 +1,208 @@
+"""Trade SHAP explanation logic.
+
+This module provides the TradeShapExplainer class that explains individual trades
+using SHAP values, with O(log n) timestamp alignment and efficient feature extraction.
+"""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+
+import numpy as np
+
+from ml4t.diagnostic.evaluation.trade_shap.alignment import TimestampAligner
+from ml4t.diagnostic.evaluation.trade_shap.models import (
+    TradeExplainFailure,
+    TradeShapExplanation,
+)
+
+if TYPE_CHECKING:
+    import polars as pl
+    from numpy.typing import NDArray
+
+    from ml4t.diagnostic.evaluation.trade_analysis import TradeMetrics
+
+
+class TradeShapExplainer:
+    """Explains individual trades using SHAP values.
+
+    Uses TimestampAligner for O(log n) timestamp lookup and extracts
+    feature values in a single row read for efficiency.
+
+    Returns TradeExplainFailure for expected failure cases instead of
+    throwing exceptions, enabling clean batch processing.
+
+    Attributes:
+        features_df: Polars DataFrame with timestamp and feature columns
+        shap_values: 2D numpy array of SHAP values (n_samples x n_features)
+        feature_names: List of feature column names
+        aligner: TimestampAligner for fast timestamp lookup
+        top_n_features: Number of top features to include in explanation
+
+    Example:
+        >>> explainer = TradeShapExplainer(
+        ...     features_df=features,
+        ...     shap_values=shap_values,
+        ...     feature_names=feature_names,
+        ...     tolerance_seconds=60.0,
+        ... )
+        >>> result = explainer.explain(trade)
+        >>> if isinstance(result, TradeShapExplanation):
+        ...     print(result.top_features[:3])
+        ... else:
+        ...     print(f"Failed: {result.reason}")
+    """
+
+    def __init__(
+        self,
+        features_df: pl.DataFrame,
+        shap_values: NDArray[np.floating[Any]],
+        feature_names: list[str],
+        tolerance_seconds: float = 0.0,
+        top_n_features: int | None = None,
+        alignment_mode: str = "entry",
+        missing_value_strategy: str = "skip",
+    ) -> None:
+        """Initialize the explainer.
+
+        Args:
+            features_df: Polars DataFrame with 'timestamp' column and feature columns
+            shap_values: SHAP values array (n_samples x n_features)
+            feature_names: List of feature column names matching shap_values columns
+            tolerance_seconds: Maximum seconds for nearest-match alignment (0 = exact only)
+            top_n_features: Number of top features to include (None = all)
+            alignment_mode: 'entry' for exact match, 'nearest' for closest within tolerance
+            missing_value_strategy: How to handle alignment failures ('error', 'skip', 'zero')
+
+        Raises:
+            ValueError: If shap_values shape doesn't match features_df rows or feature_names
+        """
+        self.features_df = features_df
+        self.shap_values = shap_values
+        self.feature_names = feature_names
+        self.top_n_features = top_n_features
+        self.alignment_mode = alignment_mode
+        self.missing_value_strategy = missing_value_strategy
+
+        # Validate shapes
+        n_rows = len(features_df)
+        n_features = len(feature_names)
+
+        if shap_values.shape[0] != n_rows:
+            raise ValueError(
+                f"SHAP values rows ({shap_values.shape[0]}) != features_df rows ({n_rows})"
+            )
+        if shap_values.shape[1] != n_features:
+            raise ValueError(
+                f"SHAP values columns ({shap_values.shape[1]}) != feature_names ({n_features})"
+            )
+
+        # Build aligner with appropriate tolerance
+        timestamps = features_df["timestamp"].to_list()
+        effective_tolerance = tolerance_seconds if alignment_mode == "nearest" else 0.0
+        self.aligner = TimestampAligner.from_datetime_index(
+            timestamps, tolerance_seconds=effective_tolerance
+        )
+
+        # Cache feature data as numpy for fast row extraction
+        self._feature_matrix = features_df.select(feature_names).to_numpy()
+
+    def explain(
+        self,
+        trade: TradeMetrics,
+    ) -> TradeShapExplanation | TradeExplainFailure:
+        """Explain a single trade.
+
+        Args:
+            trade: Trade to explain (must have timestamp and symbol attributes)
+
+        Returns:
+            TradeShapExplanation on success, TradeExplainFailure on expected failures
+        """
+        trade_id = f"{trade.symbol}_{trade.timestamp.isoformat()}"
+
+        # Align to timestamp
+        result = self.aligner.align(trade.timestamp)
+
+        if result.index is None:
+            # Handle alignment failure based on strategy
+            if self.missing_value_strategy == "error":
+                raise ValueError(
+                    f"Cannot align SHAP values for trade {trade_id}: "
+                    f"no timestamp within {self.aligner.tolerance_seconds}s "
+                    f"(nearest is {result.distance_seconds:.1f}s away)"
+                )
+            elif self.missing_value_strategy == "zero":
+                # Return zero SHAP vector
+                shap_vector = np.zeros(len(self.feature_names))
+                feature_values = dict.fromkeys(self.feature_names, 0.0)
+                top_features = [(name, 0.0) for name in self.feature_names]
+                return TradeShapExplanation(
+                    trade_id=trade_id,
+                    timestamp=trade.timestamp,
+                    top_features=top_features,
+                    feature_values=feature_values,
+                    shap_vector=shap_vector,
+                )
+            else:  # "skip" or default
+                return TradeExplainFailure(
+                    trade_id=trade_id,
+                    timestamp=trade.timestamp,
+                    reason="alignment_missing",
+                    details={
+                        "alignment_mode": self.alignment_mode,
+                        "tolerance_seconds": self.aligner.tolerance_seconds,
+                        "distance_seconds": result.distance_seconds,
+                    },
+                )
+
+        idx = result.index
+
+        # Extract SHAP vector for this row
+        shap_vector = np.asarray(self.shap_values[idx, :], dtype=np.float64)
+
+        # Extract feature values in one row read (not per-feature loop)
+        feature_row = self._feature_matrix[idx, :]
+        feature_values = {
+            name: float(val) for name, val in zip(self.feature_names, feature_row, strict=True)
+        }
+
+        # Get top N contributors by absolute SHAP value
+        top_n = self.top_n_features if self.top_n_features is not None else len(self.feature_names)
+
+        # Create (feature_name, shap_value) pairs and sort by |shap|
+        feature_shap_pairs = list(zip(self.feature_names, shap_vector.tolist(), strict=True))
+        feature_shap_pairs.sort(key=lambda x: abs(x[1]), reverse=True)
+        top_features = [(name, float(val)) for name, val in feature_shap_pairs[:top_n]]
+
+        return TradeShapExplanation(
+            trade_id=trade_id,
+            timestamp=trade.timestamp,
+            top_features=top_features,
+            feature_values=feature_values,
+            shap_vector=shap_vector,
+        )
+
+    def explain_many(
+        self,
+        trades: list[TradeMetrics],
+    ) -> tuple[list[TradeShapExplanation], list[TradeExplainFailure]]:
+        """Explain multiple trades.
+
+        Args:
+            trades: List of trades to explain
+
+        Returns:
+            Tuple of (successful explanations, failures)
+        """
+        explanations: list[TradeShapExplanation] = []
+        failures: list[TradeExplainFailure] = []
+
+        for trade in trades:
+            result = self.explain(trade)
+            if isinstance(result, TradeShapExplanation):
+                explanations.append(result)
+            else:
+                failures.append(result)
+
+        return explanations, failures
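For orientation, a minimal usage sketch of the explainer above. The synthetic polars frame, the random SHAP matrix, and the SimpleNamespace stand-ins for trades are invented for illustration; per the docstring, explain() only needs objects with symbol and timestamp attributes, and the constructor arguments and explain_many() return shape come from the diff itself.

from datetime import datetime, timedelta
from types import SimpleNamespace

import numpy as np
import polars as pl

from ml4t.diagnostic.evaluation.trade_shap.explain import TradeShapExplainer

# Synthetic feature history: 100 one-minute bars, 3 features (names are made up).
start = datetime(2024, 1, 2, 9, 30)
features = pl.DataFrame(
    {
        "timestamp": [start + timedelta(minutes=i) for i in range(100)],
        "momentum": np.random.randn(100),
        "volatility": np.random.rand(100),
        "spread": np.random.rand(100),
    }
)
shap_values = np.random.randn(100, 3)  # one SHAP row per bar

explainer = TradeShapExplainer(
    features_df=features,
    shap_values=shap_values,
    feature_names=["momentum", "volatility", "spread"],
    tolerance_seconds=60.0,
    alignment_mode="nearest",  # allow nearest-bar matches within 60s
    top_n_features=3,
)

# Any object with symbol/timestamp duck-types as a trade here.
trades = [
    SimpleNamespace(symbol="AAPL", timestamp=start + timedelta(minutes=17)),
    SimpleNamespace(symbol="AAPL", timestamp=start + timedelta(days=2)),  # beyond tolerance
]
explanations, failures = explainer.explain_many(trades)
print([e.top_features for e in explanations])
print([f.reason for f in failures])  # the out-of-range trade yields "alignment_missing"

With the default missing_value_strategy="skip", unalignable trades land in the failures list instead of raising, which is what makes the batch path clean.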
ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py
@@ -0,0 +1,23 @@
+"""Hypothesis generation for trade SHAP error patterns.
+
+This package provides template-based hypothesis generation for explaining
+why trading patterns cause losses, with templates stored as YAML data.
+"""
+
+from ml4t.diagnostic.evaluation.trade_shap.hypotheses.generator import (
+    HypothesisConfig,
+    HypothesisGenerator,
+)
+from ml4t.diagnostic.evaluation.trade_shap.hypotheses.matcher import (
+    Template,
+    TemplateMatcher,
+    load_templates,
+)
+
+__all__ = [
+    "HypothesisGenerator",
+    "HypothesisConfig",
+    "TemplateMatcher",
+    "Template",
+    "load_templates",
+]
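These re-exports let callers import the hypothesis API from the package rather than its submodules. A brief sketch (the tightened thresholds are arbitrary example values, not defaults):

from ml4t.diagnostic.evaluation.trade_shap.hypotheses import (
    HypothesisConfig,
    HypothesisGenerator,
)

# Require stronger template matches and cap suggestions at two actions.
config = HypothesisConfig(min_confidence=0.7, max_actions=2)
generator = HypothesisGenerator(config)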
ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py
@@ -0,0 +1,290 @@
+"""Hypothesis generator for trade SHAP error patterns.
+
+Generates actionable hypotheses and improvement suggestions based on
+template matching against error pattern features.
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any
+
+from ml4t.diagnostic.evaluation.trade_shap.hypotheses.matcher import (
+    TemplateMatcher,
+    load_templates,
+)
+
+if TYPE_CHECKING:
+    from ml4t.diagnostic.evaluation.trade_shap.models import ErrorPattern
+
+
+@dataclass
+class HypothesisConfig:
+    """Configuration for hypothesis generation.
+
+    Attributes:
+        template_library: Which template library to use ('comprehensive' or 'minimal')
+        min_confidence: Minimum confidence threshold for generating hypothesis
+        max_actions: Maximum number of actions to include
+    """
+
+    template_library: str = "comprehensive"
+    min_confidence: float = 0.5
+    max_actions: int = 4
+
+
+class HypothesisGenerator:
+    """Generates hypotheses for error patterns using template matching.
+
+    Matches error pattern features against a library of templates and
+    generates actionable hypotheses about why the pattern causes losses.
+
+    Attributes:
+        config: Hypothesis generation configuration
+        matcher: Template matcher
+
+    Example:
+        >>> generator = HypothesisGenerator()
+        >>> enriched = generator.generate_hypothesis(error_pattern)
+        >>> print(enriched.hypothesis)
+        >>> print(enriched.actions)
+    """
+
+    def __init__(self, config: HypothesisConfig | Any | None = None) -> None:
+        """Initialize generator.
+
+        Args:
+            config: Hypothesis configuration (uses defaults if None).
+                Accepts HypothesisConfig dataclass or HypothesisGenerationConfig Pydantic model.
+        """
+        # Normalize config to HypothesisConfig dataclass
+        self.config = self._normalize_config(config)
+
+        # Load templates and create matcher
+        templates = load_templates(self.config.template_library)
+        self.matcher = TemplateMatcher(templates)
+
+    def _normalize_config(self, config: Any) -> HypothesisConfig:
+        """Normalize config to HypothesisConfig dataclass.
+
+        Supports both HypothesisConfig dataclass and HypothesisGenerationConfig Pydantic model.
+        """
+        if config is None:
+            return HypothesisConfig()
+
+        if isinstance(config, HypothesisConfig):
+            return config
+
+        # Handle Pydantic HypothesisGenerationConfig or similar
+        return HypothesisConfig(
+            template_library=getattr(config, "template_library", "comprehensive"),
+            min_confidence=getattr(config, "min_confidence", 0.5),
+            max_actions=getattr(config, "max_actions", 4),
+        )
+
+    def generate_hypothesis(
+        self,
+        error_pattern: ErrorPattern,
+        feature_names: list[str] | None = None,
+    ) -> ErrorPattern:
+        """Generate hypothesis for an error pattern.
+
+        Args:
+            error_pattern: Error pattern to analyze
+            feature_names: Optional list of all feature names for context
+
+        Returns:
+            ErrorPattern with hypothesis, actions, and confidence fields populated
+        """
+        from ml4t.diagnostic.evaluation.trade_shap.models import ErrorPattern
+
+        # Parse top_features into dict format for matcher
+        pattern_features = [
+            {
+                "name": feat[0],
+                "mean_shap": feat[1],
+                "p_value_t": feat[2],
+                "p_value_mw": feat[3],
+                "is_significant": feat[4],
+            }
+            for feat in error_pattern.top_features
+        ]
+
+        # Try to match a template
+        match_result = self.matcher.match(pattern_features)
+
+        if match_result is None or match_result.confidence < self.config.min_confidence:
+            # No good match - return pattern unchanged
+            return error_pattern
+
+        # Format hypothesis from template
+        hypothesis = self._format_hypothesis(
+            match_result.template.hypothesis_template,
+            match_result.matched_features,
+        )
+
+        # Get actions (limit to max)
+        actions = match_result.template.actions[: self.config.max_actions]
+
+        # Adjust confidence based on pattern characteristics
+        adjusted_confidence = self._adjust_confidence(
+            match_result.confidence,
+            error_pattern.n_trades,
+            error_pattern.separation_score,
+        )
+
+        # Return enriched pattern
+        return ErrorPattern(
+            cluster_id=error_pattern.cluster_id,
+            n_trades=error_pattern.n_trades,
+            description=error_pattern.description,
+            top_features=error_pattern.top_features,
+            separation_score=error_pattern.separation_score,
+            distinctiveness=error_pattern.distinctiveness,
+            hypothesis=hypothesis,
+            actions=actions,
+            confidence=adjusted_confidence,
+        )
+
+    def _format_hypothesis(
+        self,
+        template: str,
+        matched_features: list[dict[str, Any]],
+    ) -> str:
+        """Format hypothesis string from template.
+
+        Substitutes {feature} placeholder with actual feature name(s).
+        """
+        if not matched_features:
+            return template.replace("{feature}", "the feature")
+
+        # Use first matched feature name
+        feature_name = matched_features[0]["name"]
+
+        # If multiple significant features, mention them
+        sig_features = [f for f in matched_features if f["is_significant"]]
+        if len(sig_features) > 1:
+            names = [f["name"] for f in sig_features[:2]]
+            feature_name = " and ".join(names)
+
+        return template.replace("{feature}", feature_name)
+
+    def _adjust_confidence(
+        self,
+        base_confidence: float,
+        n_trades: int,
+        separation_score: float,
+    ) -> float:
+        """Adjust confidence based on pattern characteristics.
+
+        - More trades = higher confidence (larger sample)
+        - Higher separation = higher confidence (more distinct pattern)
+        - Very small samples or poor separation get significant penalties
+        """
+        # Trade count adjustment - penalize small samples heavily
+        if n_trades >= 20:
+            trade_boost = 0.05
+        elif n_trades >= 10:
+            trade_boost = 0.02
+        elif n_trades >= 5:
+            trade_boost = -0.10
+        elif n_trades >= 2:
+            trade_boost = -0.25
+        else:
+            # Single trade - very unreliable
+            trade_boost = -0.50
+
+        # Separation score adjustment - penalize poor cluster separation
+        if separation_score >= 1.5:
+            sep_boost = 0.05
+        elif separation_score >= 1.0:
+            sep_boost = 0.02
+        elif separation_score >= 0.5:
+            sep_boost = -0.20  # Moderate separation needs noticeable penalty
+        elif separation_score >= 0.3:
+            sep_boost = -0.35
+        else:
+            # Very poor separation - cluster is not distinct
+            sep_boost = -0.50
+
+        adjusted = base_confidence + trade_boost + sep_boost
+        return max(0.0, min(1.0, adjusted))
+
+    def generate_actions(
+        self,
+        error_pattern: ErrorPattern,
+        max_actions: int | None = None,
+    ) -> list[dict[str, Any]]:
+        """Generate prioritized action suggestions for an error pattern.
+
+        Args:
+            error_pattern: Error pattern with hypothesis
+            max_actions: Maximum actions to return (defaults to config)
+
+        Returns:
+            List of action dictionaries with category, description, priority, etc.
+        """
+        if max_actions is None:
+            max_actions = self.config.max_actions
+
+        if not error_pattern.actions:
+            return []
+
+        # Categorize and prioritize actions
+        categorized_actions = []
+
+        for i, action in enumerate(error_pattern.actions[:max_actions]):
+            # Determine category from action text
+            category = self._categorize_action(action)
+
+            # Priority based on position and confidence
+            priority = self._determine_priority(i, error_pattern.confidence)
+
+            categorized_actions.append(
+                {
+                    "category": category,
+                    "description": action,
+                    "priority": priority,
+                    "implementation_difficulty": self._estimate_difficulty(action),
+                    "rationale": f"Based on pattern: {error_pattern.description}",
+                }
+            )
+
+        return categorized_actions
+
+    def _categorize_action(self, action: str) -> str:
+        """Categorize an action based on its text."""
+        action_lower = action.lower()
+
+        if any(word in action_lower for word in ["feature", "indicator", "add"]):
+            return "feature_engineering"
+        elif any(word in action_lower for word in ["filter", "regime", "threshold"]):
+            return "filter_regime"
+        elif any(word in action_lower for word in ["size", "position", "stop", "risk"]):
+            return "risk_management"
+        elif any(word in action_lower for word in ["tune", "parameter", "adjust"]):
+            return "model_adjustment"
+        else:
+            return "general"
+
+    def _determine_priority(self, position: int, confidence: float | None) -> str:
+        """Determine action priority."""
+        conf = confidence or 0.5
+
+        if position == 0 and conf >= 0.7:
+            return "high"
+        elif position <= 1 and conf >= 0.5:
+            return "medium"
+        else:
+            return "low"
+
+    def _estimate_difficulty(self, action: str) -> str:
+        """Estimate implementation difficulty from action text."""
+        action_lower = action.lower()
+
+        if any(word in action_lower for word in ["implement", "hmm", "model", "ensemble"]):
+            return "hard"
+        elif any(word in action_lower for word in ["add", "consider", "track"]):
+            return "medium"
+        else:
+            return "easy"
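To make the confidence arithmetic in _adjust_confidence concrete, a small worked example with assumed inputs: a template match at base confidence 0.80 on a 12-trade cluster with separation score 0.6 earns the modest sample boost but also the moderate-separation penalty, and the sum is clamped to [0, 1].

# Mirrors the branch logic above; the input values are assumptions.
base = 0.80
trade_boost = 0.02   # band for 10 <= n_trades < 20
sep_boost = -0.20    # band for 0.5 <= separation_score < 1.0
adjusted = max(0.0, min(1.0, base + trade_boost + sep_boost))
print(adjusted)  # 0.62

Note how a moderate-separation cluster loses nearly a full confidence band even with a decent sample size, so only well-separated, well-populated clusters keep their template confidence.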